# -*- coding: utf-8 -*-
"""
Created on Thu Sep  1 23:42:18 2022

@author: ScottStation
"""

import asyncio
import os
import pickle
import sys
import threading
import time
import zlib
from os import path

import msgpack
import nest_asyncio
import pandas as pd
import thriftpy2
from thriftpy2.rpc import make_aio_client

from .utils import get_mac_address

# Patch asyncio to permit re-entrant event loops, so the synchronous wrappers
# in this module (which call asyncio.run()) can also be used when an event
# loop is already running. NOTE(review): the interactive/Jupyter use case is
# an assumed motivation — confirm.
nest_asyncio.apply()

# Candidate servers, tried in list order by qedataClient.getClient(), which
# may also reorder this list. Prefer the packaged config (two servers); fall
# back to a single hard-coded LAN server when .config cannot be imported.
# ImportError is also raised if .config exists but lacks either name.
try:
    from .config import outside_server_config, second_server_config

    avail_servers = [outside_server_config, second_server_config]
    server_config = avail_servers[0]
    # NOTE(review): server_index is not read anywhere in the visible code.
    server_index = 0
except ImportError:
    server_config = {'host': '192.168.18.4', 'port': 6001}
    avail_servers = [server_config]
    server_index = 0

# Client version string; sent to the server by qedataClient.login()/auth().
__version__ = '0.0.1'

# Load the thrift IDL. sys.modules["ROOT_DIR"] must have been populated (with a
# directory path string) by whoever imports this module, otherwise this raises
# KeyError at import time. NOTE(review): storing a path in sys.modules is an
# unusual global-registry hack — presumably set by the package __init__; confirm.
thrift_path = path.join(sys.modules["ROOT_DIR"], "qedata.thrift")
thrift_path = path.abspath(thrift_path)
qedata_thrift = thriftpy2.load(thrift_path, module_name="qedata_thrift")
# Event loop captured at import time (after nest_asyncio has patched asyncio).
# NOTE(review): asyncio.get_event_loop() with no running loop is deprecated on
# newer Python versions and may warn.
loop = asyncio.get_event_loop()


def setTimeout(client, timeout):
    """Best-effort: set a timeout on the socket underneath a thrift client.

    The socket is looked up at ``client._iprot.trans._trans.sock`` (a
    transport wrapping another transport), falling back to
    ``client._iprot.trans.sock``. Whichever lookup succeeds, that socket's
    ``settimeout`` is called. Any failure (missing attributes, a ``None``
    socket, an object without ``settimeout``) is silently ignored, so this
    function never raises.

    Args:
        client: a thrift client object.
        timeout: passed unchanged to ``sock.settimeout``.
            NOTE(review): callers in this module pass milliseconds
            (``request_timeout = 30 * 1000``), while ``socket.settimeout``
            expects seconds — confirm the intended unit.
    """
    try:
        try:
            sock = client._iprot.trans._trans.sock
        except AttributeError:
            sock = client._iprot.trans.sock
        # Applied for both lookup paths, not only the fallback one.
        sock.settimeout(timeout)
    except Exception:
        # Deliberately best-effort: callers carry on without a socket timeout
        # if one cannot be set.
        pass


class qedataClient(object):
    """Async thrift client for the qedata service.

    Tokens obtained from :meth:`auth` (``_systoken``) and :meth:`login`
    (``_syssmtoken``) are stored on the class and attached to every
    :meth:`queryData` request. :meth:`instance` returns a shared instance, and
    calling an instance runs :meth:`queryData` to completion synchronously.
    """

    _systoken = ''       # set by auth(); sent with queries not containing "sm_get"
    _syssmtoken = ''     # set by login(); sent with queries containing "sm_get"
    _instance = None     # lazily created shared instance, see instance()
    # Passed as the thriftpy2 client timeout; presumably milliseconds (30 s),
    # per thriftpy2's convention — confirm.
    request_timeout = 30 * 1000

    def __init__(self):
        pass

    def __call__(self, method, **kwargs):
        """Run ``queryData(method, **kwargs)`` synchronously and return its result."""
        return asyncio.run(self.queryData(method, **kwargs))

    @classmethod
    def instance(cls):
        """Return the shared client instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @classmethod
    def check_auth(cls):
        """Return True if auth() has stored a non-empty token."""
        return cls._systoken != ''

    @classmethod
    def check_login(cls):
        """Return True if login() has stored a non-empty token."""
        return cls._syssmtoken != ''

    @classmethod
    def start_heartbeat(cls):
        """Start a daemon thread that runs :meth:`heartbeat` forever.

        ``heartbeat`` is a coroutine function, so the thread drives it on an
        event loop of its own. Passing the coroutine function directly as the
        thread target would only create an un-awaited coroutine object.
        """
        def _run():
            hb_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(hb_loop)
            try:
                hb_loop.run_until_complete(cls.heartbeat())
            finally:
                hb_loop.close()

        t = threading.Thread(target=_run)
        t.daemon = True
        t.start()

    @classmethod
    async def heartbeat(cls):
        """Send ``echo('heartbeat')`` to the server every 10 seconds, forever.

        Failures are printed and do not stop the loop.
        """
        while True:
            try:
                # The echo RPC serves as the heartbeat message.
                await cls.instance().echo('heartbeat')
                print('heartbeat')
            except Exception as e:
                print(f'Heartbeat failed: {e}')
            # Non-blocking sleep: time.sleep would stall the event loop.
            await asyncio.sleep(10)

    @staticmethod
    def _promote_server(servers, index):
        """Swap ``servers[index]`` with the first entry so it is tried first next time."""
        servers[0], servers[index] = servers[index], servers[0]

    @classmethod
    async def getClient(cls):
        """Connect to the first reachable server in ``avail_servers``.

        Servers are tried in list order, each up to 3 times. The module-level
        ``server_config`` is set to the server being tried. If a server other
        than the first one connects, it is swapped to the front of
        ``avail_servers``.

        Returns:
            A connected thriftpy2 async client, or ``None`` if every attempt
            failed (or ``avail_servers`` is empty).
        """
        global server_config  # updated to the server currently being tried

        last = len(avail_servers) - 1
        for i, config in enumerate(avail_servers):
            server_config = config
            retries_left = 3
            while retries_left > 0:
                try:
                    client = await make_aio_client(
                        qedata_thrift.TestService, config['host'], config['port'],
                        timeout=cls.request_timeout)
                except Exception as e:
                    retries_left -= 1
                    if retries_left == 0 and i == last:
                        print(f'ERROR: {e}')
                        return None
                    print('Connect failed, retrying...')
                    continue
                if i > 0:
                    cls._promote_server(avail_servers, i)
                return client
        return None

    async def echo(self, name):
        """Call the ``echo`` RPC with *name* and print the reply and the auth token.

        Does nothing if no connection could be made. Exceptions from the RPC
        propagate, and the connection is closed either way.
        """
        client = await self.getClient()
        if not client:
            return
        try:
            result = await client.echo(name)
            print(result, self._systoken)
        finally:
            client.close()

    @classmethod
    def _to_frame(cls, part):
        """Build a DataFrame from ``part['data']``, sort it by index, and set ``part['cols']`` as columns."""
        df = pd.DataFrame(part['data']).sort_index()
        df.columns = part['cols']
        return df

    @classmethod
    def _decode_payload(cls, data):
        """Turn a ``'type'``-tagged payload dict into DataFrame(s).

        * ``{'type': 'DataFrame', 'data': ..., 'cols': ...}`` -> one DataFrame.
        * ``{'type': 'DataFrameDict', key: {'data': ..., 'cols': ...}, ...}``
          -> ``{key: DataFrame}``. The ``'type'`` key is removed from *data*.
        * Anything else is returned unchanged.
        """
        if not (isinstance(data, dict) and 'type' in data):
            return data
        if data['type'] == 'DataFrame':
            return cls._to_frame(data)
        if data['type'] == 'DataFrameDict':
            del data['type']
            return {key: cls._to_frame(part) for key, part in data.items()}
        return data

    async def queryData(self, method, **kwargs):
        """Invoke the remote query *method* with *kwargs* and decode the result.

        The kwargs are msgpack-encoded and zlib-compressed. Methods whose name
        contains ``"sm_get"`` are sent with the login token; all others are
        sent with the auth token.

        Returns:
            the decoded payload (see ``_decode_payload``) on success;
            ``result.msg`` unchanged when the server reports failure;
            ``"Client load failed"`` if no connection could be made;
            ``'ERROR: <exception>'`` if anything raises.
        """
        try:
            client = await self.getClient()
            if not client:
                return "Client load failed"
            try:
                setTimeout(client, self.request_timeout)
                req = qedata_thrift.St_Query_Req()
                req.method_name = method
                req.params = zlib.compress(msgpack.dumps(kwargs, use_bin_type=True))
                req.token = self._syssmtoken if "sm_get" in method else self._systoken
                result = await client.query(req)
            finally:
                client.close()

            if not result.status:
                return result.msg
            # SECURITY NOTE(review): pickle.loads can execute arbitrary code
            # embedded in the payload; this fully trusts the server and the
            # network path to it.
            data = pickle.loads(zlib.decompress(result.msg))
            return self._decode_payload(data)
        except Exception as e:
            return f'ERROR: {e}'

    @classmethod
    async def login(cls, username, password):
        """Log in and store the returned token in ``_syssmtoken``.

        Returns:
            True on success; False on server-side failure or when no
            connection could be made; ``'ERROR: <exception>'`` if anything raises.
        """
        try:
            client = await cls.getClient()
            if not client:
                return False
            try:
                setTimeout(client, cls.request_timeout)
                result = await client.login(username, password, get_mac_address(), __version__)
            finally:
                client.close()
            if result.status:
                cls._syssmtoken = result.msg
                # NOTE(review): this prints the session token to stdout.
                print(result.msg, cls._syssmtoken)
                print('LOGIN SUCCEED')
                return True
            print(f'LOGIN FAILED : {result.msg}')
            return False
        except Exception as e:
            return f'ERROR: {e}'

    @classmethod
    async def auth(cls, username, authcode):
        """Authenticate with *authcode* and store the returned token in ``_systoken``.

        Returns:
            True on success; False on server-side failure or when no
            connection could be made; ``'ERROR: <exception>'`` if anything raises.
        """
        try:
            client = await cls.getClient()
            if not client:
                return False
            try:
                setTimeout(client, cls.request_timeout)
                result = await client.auth(username, authcode, False, get_mac_address(), __version__)
            finally:
                client.close()
            if result.status:
                cls._systoken = result.msg
                print('AUTH SUCCEED')
                return True
            print(f'AUTH FAILED : {result.msg}')
            return False
        except Exception as e:
            return f'ERROR: {e}'


def auth(username, authcode):
    """Synchronously authenticate *username* with *authcode*.

    Drives ``qedataClient.auth`` to completion and returns its result
    (True/False, or an ``'ERROR: ...'`` string).
    """
    pending = qedataClient.auth(username, authcode)
    return asyncio.run(pending)


def login(username, password):
    """Synchronously log in *username* with *password*.

    Drives ``qedataClient.login`` to completion and returns its result
    (True/False, or an ``'ERROR: ...'`` string).
    """
    pending = qedataClient.login(username, password)
    return asyncio.run(pending)


def check_auth():
    """Return True if ``qedataClient`` currently holds a non-empty auth token."""
    has_token = qedataClient.check_auth()
    return has_token


async def testClient():
    """Manual smoke test: authenticate, then fetch minute bars for one contract.

    Requires a reachable qedata server and an importable ``aio_api`` module.
    """
    # This coroutine already runs on an event loop, so await directly instead
    # of re-entering the loop with run_until_complete (which only worked
    # because nest_asyncio is applied).
    # SECURITY NOTE(review): credentials are hard-coded in source; consider
    # reading them from the environment instead.
    ret = await qedataClient.instance().auth('tester02', '$1$$hBe1fMaNHRGmjFXmE4Gwb/')
    print(ret)

    security = 'A2405.DCE'
    start_date = '2024-03-20'
    end_date = '2024-03-26'
    freq = 'minute'
    from aio_api import aio_get_price
    result = await aio_get_price(security, start_date, end_date, freq)
    print(result)


if __name__ == '__main__':
    # Running this file as a script executes the manual smoke test.
    main_coro = testClient()
    asyncio.run(main_coro)
