import requests
from requests.adapters import HTTPAdapter, Retry
import pandas as pd
from .agsi_mappings import AGSICompany, AGSIStorage, AGSICountry, lookup_company, lookup_storage, lookup_country
from .alsi_mappings import ALSITerminal, ALSILSO, ALSICountry, lookup_terminal, lookup_lso, \
    lookup_country as lookup_country_alsi
from .exceptions import *
from enum import Enum

# Package metadata. __version__ is also embedded in the user-agent header that
# GieRawClient sends with every request.
__title__ = "gie-py"
__version__ = "0.4.6"
__author__ = "Frank Boerman"
__license__ = "MIT"


class APIType(str, Enum):
    """The GIE APIs this module can query; each member's value is the endpoint URL.

    The client issues its GET requests directly against ``member.value``.
    Because of the ``str`` mix-in, members also compare equal to their URL strings.
    """
    AGSI = "https://agsi.gie.eu/api"  # used by the gas storage (query_gas_*) methods
    ALSI = "https://alsi.gie.eu/api"  # used by the LNG (query_lng_*) methods


class GieRawClient:
    """Client for the GIE AGSI (gas storage) and ALSI (LNG) REST APIs.

    Every ``query_*`` method returns the raw list of record dicts taken from the
    API's ``data`` field, concatenated across all result pages.

    The client owns a ``requests.Session``. Call :meth:`close`, or use the client
    as a context manager, to release its pooled connections.
    """

    def __init__(self, api_key, *, timeout: float | tuple[float, float] | None = None):
        """Create a client with its own HTTP session.

        :param api_key: GIE API key, sent in the ``x-key`` header of every request.
        :param timeout: optional request timeout in seconds, or a
            ``(connect, read)`` tuple, passed to every HTTP request. The default
            ``None`` means requests never time out (the historical behaviour).
            Setting a value prevents a stalled connection from blocking forever.
        """
        self.timeout = timeout
        self.s = requests.Session()
        # Retry transient server-side failures with a short exponential backoff.
        retries = Retry(total=5,
                        backoff_factor=0.1,
                        status_forcelist=[500, 502, 503, 504])
        self.s.mount('http://', HTTPAdapter(max_retries=retries))
        self.s.mount('https://', HTTPAdapter(max_retries=retries))
        self.s.headers.update({
            'user-agent': f'gie-py v{__version__} (github.com/fboerman/gie-py)',
            'x-key': api_key
        })

    def close(self):
        """Close the underlying HTTP session and release its pooled connections."""
        self.s.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Returning None lets any exception raised in the with-block propagate.
        self.close()

    def _fetch(self, obj, t: APIType,
               start: pd.Timestamp | str, end: pd.Timestamp | str):
        """Download every result page for ``obj`` between ``start`` and ``end``.

        :param obj: lookup result whose ``get_params()`` returns a dict of
            object-specific query parameters. These take precedence over the
            generic date/paging parameters on key collisions.
        :param t: API to query; its value is used as the request URL.
        :param start: first day; anything ``pd.Timestamp`` accepts.
        :param end: last day; anything ``pd.Timestamp`` accepts.
        :return: list of record dicts gathered from all pages.
        :raises requests.HTTPError: if any page returns an error status.
        :raises NoMatchingDataError: if the API returns no records at all.
        """
        # pd.Timestamp() accepts strings, datetimes and existing Timestamps alike,
        # so no type check is needed before normalizing.
        start = pd.Timestamp(start)
        end = pd.Timestamp(end)
        obj_params = obj.get_params()

        def _fetch_page(page):
            """Fetch one result page (300 records per page) and return its JSON body."""
            r = self.s.get(t.value, params={
                'from': start.strftime('%Y-%m-%d'),
                'till': end.strftime('%Y-%m-%d'),
                'size': 300,
                'page': page,
            } | obj_params, timeout=self.timeout)
            r.raise_for_status()
            return r.json()

        first = _fetch_page(1)
        data = first['data']
        # Page 1 is already in hand; append the remaining pages, if any.
        for page in range(2, first.get('last_page', 1) + 1):
            data += _fetch_page(page)['data']

        if not data:
            raise NoMatchingDataError

        return data

    def query_gas_storage(self, storage: AGSIStorage | str,
                          start: pd.Timestamp | str, end: pd.Timestamp | str) -> list[dict]:
        """Query AGSI records for a single storage facility.

        :param storage: ``AGSIStorage`` member or a string resolved by ``lookup_storage``.
        :param start: first day of the period.
        :param end: last day of the period.
        :return: list of raw record dicts.
        """
        storage = lookup_storage(storage)
        return self._fetch(storage, APIType.AGSI, start=start, end=end)

    def query_gas_company(self, company: AGSICompany | str,
                          start: pd.Timestamp | str, end: pd.Timestamp | str) -> list[dict]:
        """Query AGSI records for a storage company.

        :param company: ``AGSICompany`` member or a string resolved by ``lookup_company``.
        :param start: first day of the period.
        :param end: last day of the period.
        :return: list of raw record dicts.
        """
        company = lookup_company(company)
        return self._fetch(company, APIType.AGSI, start=start, end=end)

    def query_gas_country(self, country: AGSICountry | str,
                          start: pd.Timestamp | str, end: pd.Timestamp | str) -> list[dict]:
        """Query AGSI records for a country.

        :param country: ``AGSICountry`` member or a string resolved by ``lookup_country``.
        :param start: first day of the period.
        :param end: last day of the period.
        :return: list of raw record dicts.
        """
        country = lookup_country(country)
        return self._fetch(country, APIType.AGSI, start=start, end=end)

    def query_lng_terminal(self, terminal: ALSITerminal | str,
                           start: pd.Timestamp | str, end: pd.Timestamp | str) -> list[dict]:
        """Query ALSI records for a single LNG terminal.

        :param terminal: ``ALSITerminal`` member or a string resolved by ``lookup_terminal``.
        :param start: first day of the period.
        :param end: last day of the period.
        :return: list of raw record dicts.
        """
        terminal = lookup_terminal(terminal)
        return self._fetch(terminal, APIType.ALSI, start=start, end=end)

    def query_lng_lso(self, lso: ALSILSO | str,
                      start: pd.Timestamp | str, end: pd.Timestamp | str) -> list[dict]:
        """Query ALSI records for an LSO.

        :param lso: ``ALSILSO`` member or a string resolved by ``lookup_lso``.
        :param start: first day of the period.
        :param end: last day of the period.
        :return: list of raw record dicts.
        """
        lso = lookup_lso(lso)
        return self._fetch(lso, APIType.ALSI, start=start, end=end)

    def query_lng_country(self, country: ALSICountry | str,
                          start: pd.Timestamp | str, end: pd.Timestamp | str) -> list[dict]:
        """Query ALSI records for a country.

        :param country: ``ALSICountry`` member or a string resolved by the ALSI
            ``lookup_country`` (imported here as ``lookup_country_alsi``).
        :param start: first day of the period.
        :param end: last day of the period.
        :return: list of raw record dicts.
        """
        country = lookup_country_alsi(country)
        return self._fetch(country, APIType.ALSI, start=start, end=end)


class GiePandasClient(GieRawClient):
    """GIE client whose ``query_*`` methods return cleaned pandas DataFrames.

    Each query delegates to the same-named :class:`GieRawClient` method and
    converts the returned records with :meth:`_fix_dataframe`.
    """

    @staticmethod
    def _fix_dataframe(data: list[dict]) -> pd.DataFrame:
        """Turn raw API records into a numeric DataFrame indexed by gas day.

        Processing steps:

        * nested ``inventory`` / ``dtmi`` dicts are reduced to their ``lng`` entry;
        * descriptive columns (``name``, ``code``, ``url``, ``info``, ``type``)
          are dropped when present;
        * rows whose ``status`` is ``'N'`` are removed;
        * ``gasDayStart`` is parsed to datetimes and becomes the index;
        * remaining columns are converted to numbers where possible, with
          ``'-'`` treated as ``0``; columns that cannot be converted are kept as-is;
        * ``status`` and ``updatedAt`` (parsed to datetimes) are re-appended as
          the last two columns.

        The input records are not modified.

        :param data: list of record dicts, as returned by the raw client.
        :return: the cleaned DataFrame.
        """
        def _fix_values(record):
            # Work on a shallow copy so the caller's record dicts are left untouched.
            record = dict(record)
            if 'inventory' in record:
                record['inventory'] = record['inventory']['lng']
            if 'dtmi' in record:
                record['dtmi'] = record['dtmi']['lng']
            return record

        df = pd.DataFrame([_fix_values(x) for x in data])
        df = df.drop(columns=['name', 'code', 'url', 'info'], errors='ignore')
        # .copy() makes the filtered frame independent, so the column assignment
        # below does not operate on (and warn about) a slice of a temporary frame.
        df = df.loc[df['status'] != 'N'].copy()
        df['gasDayStart'] = pd.to_datetime(df['gasDayStart'])
        df = df.set_index('gasDayStart')
        # status is a string column and updatedAt a timestamp. Set them aside so
        # the numeric conversion below only touches the measurement columns.
        status = df['status'].copy()
        updated_at = pd.to_datetime(df['updatedAt'])
        # status/updatedAt are known to exist (they were read just above);
        # 'type' is only present in some responses.
        df = df.drop(columns=['status', 'updatedAt', 'type'], errors='ignore')
        for column in df.columns:
            try:
                df[column] = pd.to_numeric(df[column].replace('-', 0).infer_objects())
            except (ValueError, TypeError):
                # ValueError: unparsable strings. TypeError: non-scalar cells
                # (e.g. nested dicts/lists). Either way, keep the column unchanged.
                pass
        df['status'] = status
        df['updatedAt'] = updated_at
        return df

    def query_gas_storage(self, storage: AGSIStorage | str,
                          start: pd.Timestamp | str, end: pd.Timestamp | str) -> pd.DataFrame:
        """Query AGSI data for a storage facility as a cleaned DataFrame."""
        return self._fix_dataframe(
            super().query_gas_storage(storage=storage, start=start, end=end)
        )

    def query_gas_company(self, company: AGSICompany | str,
                          start: pd.Timestamp | str, end: pd.Timestamp | str) -> pd.DataFrame:
        """Query AGSI data for a storage company as a cleaned DataFrame."""
        return self._fix_dataframe(
            super().query_gas_company(company=company, start=start, end=end)
        )

    def query_gas_country(self, country: AGSICountry | str,
                          start: pd.Timestamp | str, end: pd.Timestamp | str) -> pd.DataFrame:
        """Query AGSI data for a country as a cleaned DataFrame."""
        return self._fix_dataframe(
            super().query_gas_country(country=country, start=start, end=end)
        )

    def query_lng_terminal(self, terminal: ALSITerminal | str,
                           start: pd.Timestamp | str, end: pd.Timestamp | str) -> pd.DataFrame:
        """Query ALSI data for an LNG terminal as a cleaned DataFrame."""
        return self._fix_dataframe(
            super().query_lng_terminal(terminal=terminal, start=start, end=end)
        )

    def query_lng_lso(self, lso: ALSILSO | str,
                      start: pd.Timestamp | str, end: pd.Timestamp | str) -> pd.DataFrame:
        """Query ALSI data for an LSO as a cleaned DataFrame."""
        return self._fix_dataframe(
            super().query_lng_lso(lso=lso, start=start, end=end)
        )

    def query_lng_country(self, country: ALSICountry | str,
                          start: pd.Timestamp | str, end: pd.Timestamp | str) -> pd.DataFrame:
        """Query ALSI data for a country as a cleaned DataFrame."""
        return self._fix_dataframe(
            super().query_lng_country(country=country, start=start, end=end)
        )
