import asyncio, random, logging, threading, time, functools, traceback, grpc, json
from urllib.parse import urlparse
from proto import base_pb2, queues_pb2, workitems_pb2, querys_pb2, watch_pb2
from proto.base_pb2_grpc import FlowServiceStub
class protowrap():
    """Static helpers that drive a bidirectional gRPC stream for a client object.

    ``Connect`` starts a daemon thread that (re)connects to ``client.url`` and
    feeds ``FlowService.SetupStream`` from ``client.grpcqueue``. Every incoming
    ``base_pb2.Envelope`` is decoded and then either resolves a pending RPC
    future, is appended to a download buffer, or is dispatched to the client's
    callbacks on ``client.loop``.

    Attributes read/written on ``client``: ``url``, ``chan``, ``connected``,
    ``loop`` (asyncio event loop), ``grpcqueue`` (presumably a blocking
    ``queue.Queue``-like object -- TODO confirm), ``onconnected``,
    ``onmessage``, ``watches`` and ``messagequeues``.
    """
    # Futures for in-flight requests keyed by request id (shared by every client).
    pending:dict = {}
    # Download buffers (bytearray) keyed by request id, filled by "stream" messages.
    streams:dict = {}
    # Last id handed out by uniqueid(); starts at a random 32-bit value.
    seed:int = random.getrandbits(32)
    # Ids are generated from more than one thread (gRPC threads and the event loop).
    _id_lock = threading.Lock()
    # Id of the current connection attempt; request iterators of older attempts stop.
    connectonid:str = ""

    @staticmethod
    def Connect(client):
        """Start the background connection thread if ``client.url`` uses the ``grpc`` scheme.

        Other schemes are ignored. The thread is a daemon and reconnects forever.
        """
        uri = urlparse(client.url)
        if(uri.scheme == "grpc"):
            threading.Thread(target=protowrap.__grpc_listen_for_messages, args=(client, ), daemon=True).start()

    @staticmethod
    def __grpc_connect_and_listen(client, itr):
        """Open one channel, run ``SetupStream`` until it ends, then clean up.

        TLS is used when the URL port is 443, plaintext otherwise. When the
        stream ends, either normally or through an exception, this method
        resets the client's queue and watch registrations. It fails and
        removes every pending request future with
        ``ValueError("Channel closed")`` and closes the channel.

        Args:
            client: the client object (see class docstring).
            itr: iterator of outgoing Envelopes consumed by ``SetupStream``.
        """
        try:
            uri = urlparse(client.url)
            if(uri.port == 443):
                logging.info(f"Connecting to {uri.hostname}:{uri.port} using ssl credentials")
                credentials = grpc.ssl_channel_credentials()
                client.chan = grpc.secure_channel(f"{uri.hostname}:{uri.port}", credentials, options=(('grpc.ssl_target_name_override', uri.hostname),))
            else:
                logging.info(f"Connecting to {uri.hostname}:{uri.port}")
                client.chan = grpc.insecure_channel(f"{uri.hostname}:{uri.port}")
            # No timeout: this waits until the channel reports ready.
            fut = grpc.channel_ready_future(client.chan)
            while not fut.done():
                logging.debug("channel is not ready")
                time.sleep(1)
            client.connected = True
            logging.info(f"Connected to {uri.hostname}:{uri.port}")
            asyncio.run_coroutine_threadsafe(client.onconnected(client), client.loop)
            logging.debug(f"Create stub and connect streams")
            stub = FlowServiceStub(client.chan)
            for message in stub.SetupStream(itr):
                logging.debug(f"RCV[{message.id}][{message.rid}][{message.command}]")
                protowrap.parse_message(client, message)
        except Exception as e:
            # Only report errors from an established connection; failed
            # (re)connect attempts stay quiet.
            if(client.connected == True):
                client.connected = False
                print(repr(e))
                traceback.print_tb(e.__traceback__)
        logging.debug(f"Close channels")
        client.messagequeues = {}
        client.watches = {}
        protowrap.__fail_pending(client)
        # The channel may never have been created if the failure happened early.
        chan = getattr(client, "chan", None)
        if chan is not None:
            chan.close()

    @staticmethod
    def __fail_pending(client):
        """Remove every pending request future and fail it with ``ValueError("Channel closed")``."""
        # Snapshot the keys: RPC() may add entries from the event-loop thread meanwhile.
        for rid in list(protowrap.pending):
            future = protowrap.pending.pop(rid, None)
            if future is not None:
                client.loop.call_soon_threadsafe(protowrap.__reject, future, ValueError("Channel closed"))

    @staticmethod
    def __resolve(future, value):
        """Set ``value`` on ``future`` unless it is already done (e.g. cancelled by its awaiter)."""
        if not future.done():
            future.set_result(value)

    @staticmethod
    def __reject(future, err):
        """Set exception ``err`` on ``future`` unless it is already done."""
        if not future.done():
            future.set_exception(err)

    @staticmethod
    def __grpc_request_iterator(client, connectonid:str):
        """Block for the next outgoing Envelope on ``client.grpcqueue``.

        Used through ``iter(callable, None)``, so returning ``None`` ends the
        request stream. That happens in two cases:

        * The message was dequeued by the iterator of an older connection.
          The message is re-queued for the current one.
        * An error occurs. The channel is closed in that case.

        Messages without an id get one from ``uniqueid()`` before sending.
        """
        try:
            logging.debug(f"Waiting for message for connecton id {connectonid}")
            message = client.grpcqueue.get()
            if(connectonid != protowrap.connectonid):
                # Stale connection: hand the message back for the current stream.
                client.grpcqueue.put(message)
                return None
            logging.debug(f"Process sending message for connecton id {connectonid}")
            if not message.id:
                message.id = str(next(protowrap.uniqueid()))
            logging.debug(f"SND[{message.id}][{message.rid}][{message.command}]")
            return message
        except Exception as e:
            print(repr(e))
            traceback.print_tb(e.__traceback__)
            client.chan.close()
            return None

    @staticmethod
    def __grpc_listen_for_messages(client):
        """Reconnect loop run on the daemon thread started by ``Connect``; never returns.

        Each attempt gets a fresh connection id, runs one connection until it
        ends, marks the client disconnected and waits 2 seconds.
        """
        count = 0
        while True:
            protowrap.connectonid = str(next(protowrap.uniqueid()))
            logging.debug(f"Estabilish connecton id {protowrap.connectonid}")
            try:
                protowrap.__grpc_connect_and_listen(client,
                    iter(functools.partial(protowrap.__grpc_request_iterator, client, protowrap.connectonid), None)
                )
            except Exception:
                # Never let an unexpected cleanup error kill the reconnect thread.
                logging.exception("Unexpected error in gRPC connection loop")
            client.connected = False
            count += 1
            logging.debug(f"Reconnect number {count}")
            time.sleep(2)

    @staticmethod
    def RPC(client, request:base_pb2.Envelope, id:str = ""):
        """Send ``request`` and return a future for its reply.

        Args:
            client: the client object.
            request: Envelope to send; its ``id`` is overwritten.
            id: request id to use; a new one from ``uniqueid()`` when empty.

        Returns:
            An ``asyncio.Future`` that completes in one of these ways:

            * It resolves with the decoded reply whose ``rid`` equals the id.
            * It resolves with a "queueevent" whose ``correlationId`` matches
              the id.
            * It fails with ``ValueError`` on an "error" reply.
            * It fails with ``ValueError`` when the channel closes.
        """
        if(id == ""):
            id = str(next(protowrap.uniqueid()))
        request.id = id
        # NOTE(review): asyncio.Future() binds to the current event loop, but replies
        # are delivered via client.loop.call_soon_threadsafe -- assumes the same loop.
        future = asyncio.Future()
        protowrap.pending[id] = future
        # NOTE(review): if client.grpcqueue is None the request is dropped and this
        # future only completes when the channel-closed cleanup runs.
        protowrap.sendMesssag(client, request, id)
        return future

    @staticmethod
    def sendMesssag(client, request:base_pb2.Envelope, id:str):
        """Queue ``request`` for sending, assigning a fresh id if it has none.

        The ``id`` argument is not used for sending; the request's own ``id``
        field decides. Nothing is queued when ``client.grpcqueue`` is None.
        """
        if(request.id == None or request.id == ""):
            id = str(next(protowrap.uniqueid()))
            request.id = id
        if(client.grpcqueue != None):
            client.grpcqueue.put(request)

    @staticmethod
    def SetStream(rid:str):
        """Create (or reset) the empty download buffer for request id ``rid``."""
        protowrap.streams[rid] = bytearray(0)

    @staticmethod
    def uniqueid():
        """Return a generator of request ids.

        All generators share one counter. The counter starts at a random 32-bit
        value and wraps within 32 bits. Callers typically use
        ``next(protowrap.uniqueid())``, and ids from separate calls do not
        repeat until 2**32 ids have been issued.
        """
        while True:
            with protowrap._id_lock:
                protowrap.seed = (protowrap.seed + 1) & 0xFFFFFFFF
                value = protowrap.seed
            yield value

    @staticmethod
    async def DownloadFile(client, Id:str=None, Filename:str=None):
        """Send a "download" request and write the streamed bytes to disk.

        The request is a ``DownloadRequest`` with ``Id``/``Filename``. Chunks
        arriving as "stream" messages for this request are buffered. When the
        reply names a file, the buffer is written to that path. The buffer is
        released even if the request fails.

        Returns:
            The decoded ``DownloadResponse``.

        Raises:
            ValueError: on a server "error" reply or when the channel closes.
        """
        request = base_pb2.Envelope(command="download")
        request.data.Pack(base_pb2.DownloadRequest(filename=Filename,id=Id))
        rid = str(next(protowrap.uniqueid()))
        request.id = rid
        protowrap.SetStream(rid)
        try:
            result:base_pb2.DownloadResponse = await protowrap.RPC(client, request, rid)
            # NOTE(review): the output path comes from the server reply and is used
            # as-is (relative to the CWD, not sanitised) -- confirm the server is trusted.
            if(result.filename != None and result.filename != ""):
                with open(result.filename, "wb") as out_file:
                    out_file.write(protowrap.streams[rid])
        finally:
            protowrap.streams.pop(rid, None)
        return result

    @staticmethod
    async def __handle_onmessage_callback(client, message, msg):
        """Pass an unrouted message to ``client.onmessage`` and send its reply.

        The reply is sent only when the incoming message has an empty ``rid``
        and the reply's command is not "noop".
        """
        reply = await client.onmessage(client, message.command, message.id, msg)
        if(reply != None and message.rid == ""):
            if(reply.command != "noop"):
                protowrap.sendMesssag(client, reply, reply.id)

    @staticmethod
    async def __handle_watchevent_callback(client, msg):
        """Decode the JSON document of a WatchEvent and call the callback registered for ``msg.id``."""
        doc = json.loads(msg.document)
        await client.watches[msg.id](client, msg.operation, doc)

    @staticmethod
    async def __handle_queueevent_callback(client, msg):
        """Run the callback registered for ``msg.queuename`` and send its result back if asked.

        The message body is JSON. When it is an object with a "payload" key,
        only that value is passed to the callback. A non-None return value is
        JSON-encoded and sent as a "queuemessage" to ``msg.replyto``, carrying
        the same ``correlationId``, if a reply queue is set.
        """
        payload = json.loads(msg.data)
        # Only unwrap JSON objects; `in` on a str/list would test substring/membership.
        if(isinstance(payload, dict) and "payload" in payload):
            payload = payload["payload"]
        payload = await client.messagequeues[msg.queuename](client, msg, payload)
        if(payload is not None and msg.replyto):
            reply = base_pb2.Envelope(command="queuemessage")
            res = json.dumps(payload)
            reply.data.Pack(queues_pb2.QueueMessageRequest(queuename=msg.replyto, data=res, striptoken=True, correlationId=msg.correlationId))
            protowrap.sendMesssag(client, reply, reply.id)

    @staticmethod
    def parse_message(client, message:base_pb2.Envelope):
        """Public entry point: route one incoming Envelope (see ``__parse_message``)."""
        return protowrap.__parse_message(client, message=message)

    @staticmethod
    def __parse_message(client, message:base_pb2.Envelope):
        """Decode ``message`` and route it. Runs on the gRPC receive thread.

        Routing order:

        1. A "queueevent" without ``replyto`` whose ``correlationId`` is
           pending resolves that request.
        2. Stream messages for a registered download are buffered.
        3. A message whose ``rid`` is pending resolves (or, for "error",
           fails) that request.
        4. Anything else is scheduled on ``client.loop``:
           * watch events go to their watch callbacks;
           * queue events go to their queue callbacks;
           * ping, pong and "queuemessagereply" are ignored;
           * everything else goes to ``client.onmessage``.
        """
        msg = protowrap.__Unpack(message)
        if(message.command == "queueevent" and msg.correlationId in protowrap.pending and msg.replyto == ""):
            future = protowrap.pending.pop(msg.correlationId, None)
            if future is not None:
                client.loop.call_soon_threadsafe(protowrap.__resolve, future, msg)
        elif(message.rid in protowrap.streams and message.command in ("beginstream", "stream", "endstream")):
            # Only "stream" carries data; begin/end markers are consumed silently.
            if(message.command == "stream"):
                protowrap.streams[message.rid].extend(msg.data)
        elif(message.rid in protowrap.pending):
            future = protowrap.pending.pop(message.rid, None)
            if future is None:
                return
            if(message.command == "error"):
                client.loop.call_soon_threadsafe(protowrap.__reject, future, ValueError(f"SERVER ERROR {msg.message}\n{msg.stack}" ))
            else:
                client.loop.call_soon_threadsafe(protowrap.__resolve, future, msg)
        elif(message.command == "watchevent" and msg.id in client.watches):
            asyncio.run_coroutine_threadsafe(protowrap.__handle_watchevent_callback(client, msg), client.loop)
        elif(message.command == "queueevent" and msg.queuename in client.messagequeues):
            asyncio.run_coroutine_threadsafe(protowrap.__handle_queueevent_callback(client, msg), client.loop)
        elif(message.command in ("ping", "pong", "queuemessagereply")):
            pass
        else:
            asyncio.run_coroutine_threadsafe(protowrap.__handle_onmessage_callback(client, message, msg), client.loop)

    # command -> (proto module, message class name). The class is looked up with
    # getattr at decode time, so a type missing from the generated stubs only
    # fails when that command actually arrives.
    _UNPACK_TYPES:dict = {
        "getelement": (base_pb2, "GetElementResponse"),
        "signinreply": (base_pb2, "SigninResponse"),
        "registerqueuereply": (queues_pb2, "RegisterQueueResponse"),
        "queuemessagereply": (queues_pb2, "QueueMessageResponse"),
        "queueevent": (queues_pb2, "QueueEvent"),
        "pushworkitemreply": (workitems_pb2, "PushWorkitemResponse"),
        "popworkitemreply": (workitems_pb2, "PopWorkitemResponse"),
        "updateworkitemreply": (workitems_pb2, "UpdateWorkitemResponse"),
        "pong": (base_pb2, "PingResponse"),
        "ping": (base_pb2, "PingRequest"),
        "error": (base_pb2, "ErrorResponse"),
        "download": (base_pb2, "DownloadRequest"),
        "downloadreply": (base_pb2, "DownloadResponse"),
        "stream": (base_pb2, "Stream"),
        "beginstream": (base_pb2, "BeginStream"),
        "endstream": (base_pb2, "EndStream"),
        "queryreply": (querys_pb2, "QueryResponse"),
        "listcollectionsreply": (querys_pb2, "ListCollectionsResponse"),
        "dropcollectionreply": (querys_pb2, "DropCollectionResponse"),
        "getdocumentversionreply": (querys_pb2, "GetDocumentVersionResponse"),
        "countreply": (querys_pb2, "CountResponse"),
        "aggregatereply": (querys_pb2, "AggregateResponse"),
        "insertonereply": (querys_pb2, "InsertOneResponse"),
        "insertmanyreply": (querys_pb2, "InsertManyResponse"),
        "updateonereply": (querys_pb2, "UpdateOneResponse"),
        "updatedocumentreply": (querys_pb2, "UpdateDocumentResponse"),
        "insertorupdateonereply": (querys_pb2, "InsertOrUpdateOneResponse"),
        "insertorupdatemanyreply": (querys_pb2, "InsertOrUpdateManyResponse"),
        "deleteonereply": (querys_pb2, "DeleteOneResponse"),
        "deletemanyreply": (querys_pb2, "DeleteManyResponse"),
        "watchreply": (watch_pb2, "WatchResponse"),
        "watchevent": (watch_pb2, "WatchEvent"),
        "unwatchreply": (watch_pb2, "UnWatchResponse"),
    }
    # Replies whose payload may be empty; they decode to None in that case.
    _EMPTY_AS_NONE = frozenset(("popworkitemreply", "updateworkitemreply"))

    @staticmethod
    def __Unpack(message:base_pb2.Envelope):
        """Decode ``message.data`` into the protobuf type registered for ``message.command``.

        Returns:
            The parsed message. Returns None for an unknown command (after
            logging an error). Also returns None for a command in
            ``_EMPTY_AS_NONE`` when its payload is empty.
        """
        entry = protowrap._UNPACK_TYPES.get(message.command)
        if entry is None:
            logging.error(f"Got unknown {message.command} message")
            return None
        module, type_name = entry
        msg = getattr(module, type_name)()
        if message.command in protowrap._EMPTY_AS_NONE and not message.data.value:
            return None
        msg.ParseFromString(message.data.value)
        return msg