import asyncio
import json
import weakref
from logging import getLogger

import bcrypt
from aiohttp import WSCloseCode, web

from riego.db import subscribe_db

# Logger for this module, named after the module itself.
_log = getLogger(name=__name__)


# Module-level singleton slot, read by get_websockets() and (re)assigned by
# setup_websockets() and the first Websockets() constructed.
_instance = None


def get_websockets():
    """Return the module-level Websockets singleton.

    :return: the registered Websockets instance, or None if none has been
             created yet
    """
    # Only reading the global, so no ``global`` declaration is required.
    return _instance


def setup_websockets(app=None, options=None, db=None):
    """Create a new Websockets instance and register it as the singleton.

    Any previously registered instance is replaced.

    :param app: aiohttp application passed through to Websockets
    :param options: options object passed through to Websockets
    :param db: database handle passed through to Websockets
    :return: the newly created instance
    """
    global _instance
    # Reset rather than ``del``: Websockets.__init__ reads the module global
    # ``_instance``, so deleting the name made every repeated call fail with
    # NameError.
    # NOTE(review): the previous instance's db subscription and route
    # registration are not undone here.
    _instance = None
    _instance = Websockets(app=app, options=options, db=db)
    return _instance


class Websockets:
    """Websocket endpoint for broadcasting DB changes and receiving messages.

    On construction it registers ``options.websocket_path`` on *app*,
    subscribes to database change notifications via ``subscribe_db`` and
    closes all connections on application shutdown. Incoming text messages
    are authenticated against the ``users_tokens`` table and forwarded to
    the callback registered for the message's ``scope``.
    """

    def __init__(self, app=None, options=None, db=None):
        global _instance
        # The first instance constructed becomes the module-level singleton.
        if _instance is None:
            _instance = self

        self._options = options
        self._db = db

        # Open connections; weak references so dead ones do not linger.
        self._ws_list = weakref.WeakSet()
        # scope name -> coroutine function called with the decoded message
        self._subscriptions = {}

        subscribe_db(self._db_update_cb)

        app.router.add_get(
            self._options.websocket_path, self._ws_handler, name="websockets"
        )
        app.on_shutdown.append(self._shutdown)

    async def _db_update_cb(self, table, action, old_rowid, new_rowid):
        """Tell every connected client that *table* changed.

        Only the table name is broadcast (as ``scope``); *action* and the
        row ids are accepted but not forwarded.
        """
        message = json.dumps({"action": "update", "scope": table})
        await self.send_to_all(message)

    async def send_to_all(self, message):
        """Send the text *message* to every open websocket.

        A failing send is logged and does not stop delivery to the
        remaining clients.

        :return: None
        """
        # Iterate over a snapshot: every ``await`` yields to the event loop,
        # where _ws_handler may add or discard connections. Mutating the
        # WeakSet while iterating it directly raises
        # "RuntimeError: Set changed size during iteration".
        for ws in list(self._ws_list):
            if ws.closed:
                _log.debug("send_to_all: websocket already closed")
                continue
            try:
                await ws.send_str(message)
            except Exception as e:
                _log.error(f"send_to_all: Exeption while send_str: {e}")
        return None

    def subscribe(self, scope: str, callback: callable) -> None:
        """Install a callback function for given scope.

        A later call with the same scope replaces the earlier callback.

        :param scope: name of scope that asks for websocket
        :type scope: str
        :param callback: coroutine function awaited with the decoded message
                         dict when an authenticated message for *scope*
                         arrives
        :type callback: function with parameters msg
        :return: None
        :rtype: None
        """
        self._subscriptions[scope] = callback
        return None

    async def _ws_handler(self, request) -> "web.WebSocketResponse":
        """Serve one websocket connection until it closes or misbehaves.

        Text frames are passed to _dispatch_message one at a time. Any other
        frame type, or an exception while handling a frame, is logged and
        ends the read loop after a 3 second delay. The connection is always
        removed from the broadcast set on exit.
        """
        max_msg_size = self._options.websockets_max_receive_size
        ws = web.WebSocketResponse(max_msg_size=max_msg_size)
        await ws.prepare(request)

        self._ws_list.add(ws)

        try:
            async for msg in ws:
                if msg.type == web.WSMsgType.TEXT:
                    # The raw payload carries the client's secret token, so it
                    # is not logged here; _dispatch_message logs a redacted copy.
                    await self._dispatch_message(msg.data)
                else:
                    _log.error(f"Unknown message type: {msg}")
                    await asyncio.sleep(3)
                    break
        except Exception as e:
            _log.error(f"Exeption while reading from websocket: {e}")
            await asyncio.sleep(3)
        finally:
            _log.debug(f"Removing closed websocket: {ws}")
            self._ws_list.discard(ws)
        return ws

    @staticmethod
    def _redact(msg):
        """Return *msg* with its ``token`` value masked, for safe logging.

        Non-dict values, and dicts without a ``token`` key, are returned
        unchanged. The input dict is never mutated.
        """
        if isinstance(msg, dict) and "token" in msg:
            return {**msg, "token": "***"}
        return msg

    async def _dispatch_message(self, msg: str) -> bool:
        """Authenticate one client message and pass it to its scope callback.

        The message is expected to be a JSON object with ``scope``,
        ``sequence`` and ``token``. ``sequence`` selects a row of
        ``users_tokens``; ``token`` is checked against that row's ``hash``
        with bcrypt. Rejected messages sleep 3 seconds before returning,
        which delays further reads on the sending connection.

        :param msg: raw JSON text received from the websocket
        :type msg: str
        :return: True if the scope callback ran without raising (and, for
                 now, unconditionally for scope ``authenticate_v1``);
                 False otherwise
        :rtype: bool
        :raises json.JSONDecodeError: if *msg* is not valid JSON
        """
        msg = json.loads(msg)
        # Never log the raw message: it carries the client's secret token.
        _log.debug(f"_dispatch_message: {self._redact(msg)}")
        scope = msg.get("scope", "")
        if scope == "authenticate_v1":
            # TODO Implement a one-time authentication
            return True

        if len(scope) == 0:
            _log.error(f"Message not for a scope: {self._redact(msg)}")
            await asyncio.sleep(3)
            return False

        token = msg.get("token", "")
        sequence = msg.get("sequence", "")

        if len(sequence) == 0 or len(token) == 0:
            _log.error(
                f"Websocket-Auth: missing var sequence={sequence!r}, "
                f"token={'present' if token else 'missing'}"
            )
            await asyncio.sleep(3)
            return False

        cursor = self._db.cursor()
        cursor.execute(
            """SELECT * FROM users_tokens
                    WHERE sequence = ?""",
            (sequence,),
        )
        item = cursor.fetchone()
        if item is None:
            _log.error(f"Websocket-Auth: unknown sequence {sequence}")
            await asyncio.sleep(3)
            return False

        # bcrypt is deliberately slow; run the check in the default executor
        # so the event loop keeps serving other connections meanwhile.
        loop = asyncio.get_running_loop()
        authenticated = await loop.run_in_executor(
            None, bcrypt.checkpw, token.encode("utf-8"), item["hash"]
        )
        if not authenticated:
            _log.error(f"no authenticate: sequence: {sequence}")
            await asyncio.sleep(3)
            return False
        _log.debug(f"authenticate: sequence: {sequence}")

        callback_func = self._subscriptions.get(scope)
        if callback_func is None:
            _log.error(f"Message for an unknown scope: {self._redact(msg)}")
            return False
        try:
            await callback_func(msg)
        except Exception as e:
            # _log.exception also records the failing callback's traceback.
            _log.exception(f"Exeption in {callback_func}: {e}")
            return False
        return True

    async def _shutdown(self, app) -> None:
        """Close every open websocket with GOING_AWAY on app shutdown."""
        # Copy first: closing ends each handler, which discards from the set.
        for ws in self._ws_list.copy():
            _log.debug(f"calling ws.close for: {ws}")
            await ws.close(code=WSCloseCode.GOING_AWAY, message="Server shutdown")
        return None
