# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/05-orchestrator.ipynb (unless otherwise specified).

__all__ = ['retry_request', 'if_possible_parse_local_datetime', 'SP_and_date_request', 'handle_capping',
           'date_range_request', 'just_date_request', 'year_request', 'construct_year_month_pairs',
           'year_and_month_request', 'clean_year_week', 'construct_year_week_pairs', 'year_and_week_request',
           'non_temporal_request', 'query_orchestrator']

# Cell
import pandas as pd
from tqdm import tqdm
from warnings import warn
from requests.models import Response

from . import utils, raw

# Cell
def retry_request(raw, method, kwargs, n_attempts=3):
    """Call ``getattr(raw, method)(**kwargs)``, retrying when it fails.

    A call counts as failed when either the call itself or the subsequent
    ``utils.check_status`` on its result raises an exception.

    Args:
        raw: Object (here the ``raw`` module) exposing the request callables.
        method: Name of the attribute on ``raw`` to call.
        kwargs: Keyword arguments passed to the callable.
        n_attempts: Maximum number of attempts, must be at least 1.

    Returns:
        The result of the first call that passes ``utils.check_status``.

    Raises:
        ValueError: If ``n_attempts`` is smaller than 1.
        Exception: Whatever the final failed attempt raised.
    """
    # Without this guard the loop would never run and there would be
    # no response to return.
    if n_attempts < 1:
        raise ValueError('`n_attempts` must be at least 1')

    for attempt in range(1, n_attempts + 1):
        try:
            r = getattr(raw, method)(**kwargs)
            utils.check_status(r)
        except Exception:
            # Only the last failure is propagated; a bare `raise` keeps
            # the original traceback intact.
            if attempt == n_attempts:
                raise
        else:
            return r

def if_possible_parse_local_datetime(df):
    """Parse a local datetime when the frame has one date and one SP column.

    A column counts as a date column when its name contains ``date``
    (case-insensitive) or is one of a few known datetime columns that have
    ``Period`` in their name. A column counts as a settlement period column
    when its name contains ``period`` (case-insensitive) and is not one of
    those known datetime columns.

    Args:
        df: DataFrame whose column names are inspected.

    Returns:
        The result of ``utils.parse_local_datetime`` when exactly one column
        of each kind is found, otherwise ``df`` unchanged.
    """
    period_named_dt_cols = [
        'startTimeOfHalfHrPeriod',
        'initialForecastPublishingPeriodCommencingTime',
        'latestForecastPublishingPeriodCommencingTime',
        'outTurnPublishingPeriodCommencingTime',
    ]

    date_like, period_like = [], []

    for name in df.columns:
        lowered = name.lower()
        is_known_dt = name in period_named_dt_cols

        # The two checks are independent: a name may land in both lists
        if 'date' in lowered or is_known_dt:
            date_like.append(name)
        if 'period' in lowered and not is_known_dt:
            period_like.append(name)

    if len(date_like) != 1 or len(period_like) != 1:
        return df

    return utils.parse_local_datetime(df, dt_col=date_like[0], SP_col=period_like[0])

def SP_and_date_request(
    method: str,
    kwargs_map: dict,
    func_params: list,
    api_key: str,
    start_date: str,
    end_date: str,
    n_attempts: int=3,
    **kwargs
):
    """Make one request per (date, settlement period) and combine the results.

    Args:
        method: Name of the function on the ``raw`` module to call.
        kwargs_map: Maps the generic names ``'date'`` and ``'SP'`` to the
            keyword names expected by ``method``.
        func_params: Parameter names that must be supplied for ``method``.
        api_key: Value sent as ``APIKey``.
        start_date: Start of the range passed to ``utils.dt_rng_to_SPs``.
        end_date: End of the range passed to ``utils.dt_rng_to_SPs``.
        n_attempts: Maximum attempts per request.
        **kwargs: Extra keyword arguments forwarded to ``method``.

    Returns:
        The concatenated parsed responses after ``utils.expand_cols`` and,
        where possible, local datetime parsing.

    Raises:
        AssertionError: If a date is missing or a required kwarg is absent.
    """
    assert start_date is not None, '`start_date` must be specified'
    assert end_date is not None, '`end_date` must be specified'

    # Progress-bar label: the method name without its first `_` token
    stream = '_'.join(method.split('_')[1:])

    kwargs.update({
        'APIKey': api_key,
        'ServiceType': 'xml'
    })

    df_dates_SPs = utils.dt_rng_to_SPs(start_date, end_date)
    # The final tuple is dropped - presumably so that `end_date` itself is
    # exclusive; confirm against `utils.dt_rng_to_SPs`.
    date_SP_tuples = list(df_dates_SPs.reset_index().itertuples(index=False, name=None))[:-1]

    frames = []

    for local_dt, _query_date, SP in tqdm(date_SP_tuples, desc=stream, total=len(date_SP_tuples)):
        kwargs.update({
            kwargs_map['date']: local_dt.strftime('%Y-%m-%d'),
            kwargs_map['SP']: SP,
        })

        missing_kwargs = list(set(func_params) - set(['SP', 'date'] + list(kwargs.keys())))
        assert len(missing_kwargs) == 0, f"The following kwargs are missing: {', '.join(missing_kwargs)}"

        r = retry_request(raw, method, kwargs, n_attempts=n_attempts)
        frames.append(utils.parse_xml_response(r))

    # `DataFrame.append` was removed in pandas 2.0; concatenating once is
    # also cheaper than growing the frame on every iteration.
    df = pd.concat(frames) if frames else pd.DataFrame()

    df = utils.expand_cols(df)
    df = if_possible_parse_local_datetime(df)

    return df

# Cell
def handle_capping(
    r: Response,
    df: pd.DataFrame,
    method: str,
    kwargs_map: dict,
    func_params: list,
    api_key: str,
    end_date: str,
    request_type: str,
    **kwargs
):
    """Re-request the rest of a date range when the response was capped.

    When ``utils.check_capping`` reports that capping was applied, the latest
    value of the single date column in ``df`` becomes the new start date and
    ``date_range_request`` is called again; its result is concatenated onto
    ``df`` and duplicate rows are dropped.

    Args:
        r: Response whose metadata is checked for capping.
        df: Data already parsed from ``r``.
        method: Name of the function on the ``raw`` module to call.
        kwargs_map: Mapping of generic to method-specific keyword names.
        func_params: Parameter names that must be supplied for ``method``.
        api_key: Value sent as ``APIKey``.
        end_date: End of the originally requested range.
        request_type: Forwarded to ``date_range_request``.
        **kwargs: Extra keyword arguments forwarded to the rerun.

    Returns:
        ``df``, extended with the rerun data when capping was applied and a
        new start date could be determined.

    Raises:
        AssertionError: If the response carries no capping information.
    """
    capping_applied = utils.check_capping(r)
    assert capping_applied is not None, 'No information on whether or not capping limits had been breached could be found in the response metadata'

    # Equality (not truthiness) is kept deliberately: the return type of
    # `utils.check_capping` is not visible from here.
    if capping_applied == True: # only subset of date range returned
        dt_cols_with_period_in_name = ['startTimeOfHalfHrPeriod']
        dt_cols = [col for col in df.columns if ('date' in col.lower() or col in dt_cols_with_period_in_name) and ('end' not in col.lower())]

        if len(dt_cols) == 1:
            # Resume from the latest date that was actually returned
            start_date = pd.to_datetime(df[dt_cols[0]]).max().strftime('%Y-%m-%d')
            if 'start_time' in kwargs.keys():
                kwargs['start_time'] = '00:00'

            if pd.to_datetime(start_date) >= pd.to_datetime(end_date):
                # Only `warn` is imported in this module, `warnings.warn`
                # would raise a NameError here.
                warn(f'The `end_date` ({end_date}) was earlier than `start_date` ({start_date})\nThe `start_date` will be set one day earlier than the `end_date`.')
                start_date = (pd.to_datetime(end_date) - pd.Timedelta(days=1)).strftime('%Y-%m-%d')

            warn(f'Response was capped, request is rerunning for missing data from {start_date}')
            df_rerun = date_range_request(
                            method=method,
                            kwargs_map=kwargs_map,
                            func_params=func_params,
                            api_key=api_key,
                            start_date=start_date,
                            end_date=end_date,
                            request_type=request_type,
                            **kwargs
                        )

            # `DataFrame.append` was removed in pandas 2.0. The rerun starts
            # on a date that is already present, hence the de-duplication.
            df = pd.concat([df, df_rerun])
            df = df.drop_duplicates()

        else:
            warn(f'Response was capped: a new `start_date` to continue requesting could not be determined automatically, please handle manually for `{method}`')

    return df

def date_range_request(
    method: str,
    kwargs_map: dict,
    func_params: list,
    api_key: str,
    start_date: str,
    end_date: str,
    request_type: str,
    n_attempts: int=3,
    **kwargs
):
    """Request a whole date (or datetime) range in a single call.

    ``start_date`` and ``end_date`` are each split into a date part and a
    time part, which are sent under the keyword names given by
    ``kwargs_map``. The time parts are dropped again when ``request_type``
    is ``'date_range'``. The response is then passed to ``handle_capping``,
    which may issue follow-up requests.

    Args:
        method: Name of the function on the ``raw`` module to call.
        kwargs_map: Maps ``'start_date'``/``'end_date'`` (and optionally
            ``'start_time'``, ``'end_time'``, ``'SP'``) to the keyword names
            expected by ``method``. Not modified.
        func_params: Parameter names that must be supplied for ``method``.
            Not modified.
        api_key: Value sent as ``APIKey``.
        start_date: Start of the range, parsed with ``pd.to_datetime``.
        end_date: End of the range, parsed with ``pd.to_datetime``.
        request_type: ``'date_range'`` strips the time keywords before the
            request; any other value keeps them.
        n_attempts: Maximum attempts per request.
        **kwargs: Extra keyword arguments forwarded to ``method``.

    Returns:
        The parsed response, extended by ``handle_capping`` where needed.

    Raises:
        AssertionError: If a date is missing or a required kwarg is absent.
    """
    assert start_date is not None, '`start_date` must be specified'
    assert end_date is not None, '`end_date` must be specified'

    # Work on copies: mutating the caller's objects leaked state between
    # calls, and the capping rerun then failed on `func_params.remove('SP')`
    # because 'SP' had already been removed by the first pass.
    kwargs_map = dict(kwargs_map)
    func_params = list(func_params)

    kwargs.update({
        'APIKey': api_key,
        'ServiceType': 'xml'
    })

    for kwarg in ['start_time', 'end_time']:
        kwargs_map.setdefault(kwarg, kwarg)

    kwargs[kwargs_map['start_date']], kwargs[kwargs_map['start_time']] = pd.to_datetime(start_date).strftime('%Y-%m-%d %H:%M:%S').split(' ')
    kwargs[kwargs_map['end_date']], kwargs[kwargs_map['end_time']] = pd.to_datetime(end_date).strftime('%Y-%m-%d %H:%M:%S').split(' ')

    if 'SP' in kwargs_map:
        sp_kwarg = kwargs_map['SP']
        kwargs[sp_kwarg] = '*'

        # Swap the generic 'SP' requirement for the method-specific name
        if 'SP' in func_params:
            func_params.remove('SP')
        if sp_kwarg not in func_params:
            func_params.append(sp_kwarg)

    missing_kwargs = list(set(func_params) - set(['start_date', 'end_date', 'start_time', 'end_time'] + list(kwargs.keys())))
    assert len(missing_kwargs) == 0, f"The following kwargs are missing: {', '.join(missing_kwargs)}"

    if request_type == 'date_range':
        kwargs.pop(kwargs_map['start_time'])
        kwargs.pop(kwargs_map['end_time'])

    r = retry_request(raw, method, kwargs, n_attempts=n_attempts)

    df = utils.parse_xml_response(r)
    df = if_possible_parse_local_datetime(df)

    # Handling capping
    df = handle_capping(
        r,
        df,
        method=method,
        kwargs_map=kwargs_map,
        func_params=func_params,
        api_key=api_key,
        end_date=end_date,
        request_type=request_type,
        n_attempts=n_attempts,  # reaches the rerun through `**kwargs`
        **kwargs
    )

    return df

# Cell
def just_date_request(
    method: str,
    kwargs_map: dict,
    func_params: list,
    api_key: str,
    start_date: str,
    end_date: str,
    n_attempts: int=3,
    **kwargs
):
    """Make one request per day in the range and combine the results.

    Args:
        method: Name of the function on the ``raw`` module to call.
        kwargs_map: Maps the generic name ``'date'`` to the keyword name
            expected by ``method``.
        func_params: Parameter names that must be supplied for ``method``.
        api_key: Value sent as ``APIKey``.
        start_date: First day requested.
        end_date: Last day requested (``pd.date_range`` includes it).
        n_attempts: Maximum attempts per request.
        **kwargs: Extra keyword arguments forwarded to ``method``.

    Returns:
        The concatenated parsed responses after ``utils.expand_cols`` and,
        where possible, local datetime parsing.

    Raises:
        AssertionError: If a date is missing or a required kwarg is absent.
    """
    assert start_date is not None, '`start_date` must be specified'
    assert end_date is not None, '`end_date` must be specified'

    # Progress-bar label: the method name without its first `_` token
    stream = '_'.join(method.split('_')[1:])

    kwargs.update({
        'APIKey': api_key,
        'ServiceType': 'xml'
    })

    dt_rng = pd.date_range(start_date, end_date, freq='D')
    frames = []

    for day in tqdm(dt_rng, desc=stream):
        kwargs.update({
            kwargs_map['date']: day.strftime('%Y-%m-%d'),
        })

        missing_kwargs = list(set(func_params) - set(['date'] + list(kwargs.keys())))
        assert len(missing_kwargs) == 0, f"The following kwargs are missing: {', '.join(missing_kwargs)}"

        r = retry_request(raw, method, kwargs, n_attempts=n_attempts)
        frames.append(utils.parse_xml_response(r))

    # `DataFrame.append` was removed in pandas 2.0
    df = pd.concat(frames) if frames else pd.DataFrame()

    df = utils.expand_cols(df)
    df = if_possible_parse_local_datetime(df)

    return df

# Cell
def year_request(
    method: str,
    kwargs_map: dict,
    func_params: list,
    api_key: str,
    start_date: str,
    end_date: str,
    n_attempts: int=3,
    **kwargs
):
    """Make one request per calendar year in the range and combine them.

    Args:
        method: Name of the function on the ``raw`` module to call.
        kwargs_map: Maps the generic name ``'year'`` to the keyword name
            expected by ``method``.
        func_params: Parameter names that must be supplied for ``method``.
        api_key: Value sent as ``APIKey``.
        start_date: Date whose year is the first one requested.
        end_date: Date whose year is the last one requested (inclusive).
        n_attempts: Maximum attempts per request.
        **kwargs: Extra keyword arguments forwarded to ``method``.

    Returns:
        The concatenated parsed responses, with a local datetime parsed
        where possible.

    Raises:
        AssertionError: If a date is missing or a required kwarg is absent.
    """
    assert start_date is not None, '`start_date` must be specified'
    assert end_date is not None, '`end_date` must be specified'

    # Progress-bar label: the method name without its first `_` token
    stream = '_'.join(method.split('_')[1:])

    kwargs.update({
        'APIKey': api_key,
        'ServiceType': 'xml'
    })

    start_year = pd.to_datetime(start_date).year
    end_year = pd.to_datetime(end_date).year
    frames = []

    for year in tqdm(range(start_year, end_year+1), desc=stream):
        kwargs.update({kwargs_map['year']: year})

        missing_kwargs = list(set(func_params) - set(['year'] + list(kwargs.keys())))
        assert len(missing_kwargs) == 0, f"The following kwargs are missing: {', '.join(missing_kwargs)}"

        r = retry_request(raw, method, kwargs, n_attempts=n_attempts)
        frames.append(utils.parse_xml_response(r))

    # `DataFrame.append` was removed in pandas 2.0
    df = pd.concat(frames) if frames else pd.DataFrame()
    df = if_possible_parse_local_datetime(df)

    return df

# Cell
def construct_year_month_pairs(start_date, end_date):
    """List every ``(year, month)`` touched by the date range.

    Args:
        start_date: Start of the range, anything ``pd.to_datetime`` accepts.
        end_date: End of the range, anything ``pd.to_datetime`` accepts.

    Returns:
        A list of ``(year, month)`` tuples in chronological order, where
        ``year`` is an int and ``month`` the upper-cased ``%b`` abbreviation
        (e.g. ``'JAN'``). When the range is empty, the single pair for
        ``start_date`` is returned.
    """
    start = pd.to_datetime(start_date)
    end = pd.to_datetime(end_date)

    # Monthly periods cover the partial months at both ends of the range.
    # A month-end `date_range` silently dropped the final month whenever
    # `end_date` was not the last day of its month.
    months = list(pd.period_range(start, end, freq='M')) or [start]

    return [(month.year, month.strftime('%b').upper()) for month in months]

def year_and_month_request(
    method: str,
    kwargs_map: dict,
    func_params: list,
    api_key: str,
    start_date: str,
    end_date: str,
    n_attempts: int=3,
    **kwargs
):
    """Make one request per (year, month) pair and combine the results.

    The pairs come from ``construct_year_month_pairs(start_date, end_date)``.

    Args:
        method: Name of the function on the ``raw`` module to call.
        kwargs_map: Maps the generic names ``'year'`` and ``'month'`` to the
            keyword names expected by ``method``.
        func_params: Parameter names that must be supplied for ``method``.
        api_key: Value sent as ``APIKey``.
        start_date: Start of the requested range.
        end_date: End of the requested range.
        n_attempts: Maximum attempts per request.
        **kwargs: Extra keyword arguments forwarded to ``method``.

    Returns:
        The concatenated parsed responses, with a local datetime parsed
        where possible.

    Raises:
        AssertionError: If a date is missing or a required kwarg is absent.
    """
    assert start_date is not None, '`start_date` must be specified'
    assert end_date is not None, '`end_date` must be specified'

    # Progress-bar label: the method name without its first `_` token
    stream = '_'.join(method.split('_')[1:])

    kwargs.update({
        'APIKey': api_key,
        'ServiceType': 'xml'
    })

    year_month_pairs = construct_year_month_pairs(start_date, end_date)
    frames = []

    for year, month in tqdm(year_month_pairs, desc=stream):
        kwargs.update({
            kwargs_map['year']: year,
            kwargs_map['month']: month
        })

        missing_kwargs = list(set(func_params) - set(['year', 'month'] + list(kwargs.keys())))
        assert len(missing_kwargs) == 0, f"The following kwargs are missing: {', '.join(missing_kwargs)}"

        r = retry_request(raw, method, kwargs, n_attempts=n_attempts)
        frames.append(utils.parse_xml_response(r))

    # `DataFrame.append` was removed in pandas 2.0
    df = pd.concat(frames) if frames else pd.DataFrame()
    df = if_possible_parse_local_datetime(df)

    return df

# Cell
def clean_year_week(year, week):
    """Convert a ``%Y``/``%W`` pair into integers, remapping week zero.

    Args:
        year: Year as a string or int.
        week: Week number as a (possibly zero-padded) string or int.

    Returns:
        A ``(year, week)`` tuple of ints. Week ``0`` is mapped to week 52 of
        the previous year.
    """
    year = int(year)
    # `int` copes with zero padding on its own. Stripping '0' characters
    # also removed trailing zeros, turning week '10' into 1, '20' into 2 etc.
    week = int(week)

    if week == 0:
        return year - 1, 52

    return year, week

def construct_year_week_pairs(start_date, end_date):
    """List every ``(year, week)`` touched by the date range.

    Week numbers follow ``strftime('%W')`` and are passed through
    ``clean_year_week``.

    Args:
        start_date: Start of the range, anything ``pd.to_datetime`` accepts.
        end_date: End of the range, anything ``pd.to_datetime`` accepts.

    Returns:
        A list of unique ``(year, week)`` int tuples in chronological order.
        When the range is empty, the single pair for ``start_date`` is
        returned.
    """
    start = pd.to_datetime(start_date)
    end = pd.to_datetime(end_date)

    # Walk the range day by day: a weekly (`freq='W'`) range only contains
    # Sundays, so the trailing partial week after the last Sunday was
    # silently dropped. Normalising keeps the final day in the range even
    # when the inputs carry a time of day.
    days = pd.date_range(start.normalize(), end.normalize(), freq='D')

    if len(days) == 0:
        days = [start]

    year_week_pairs = (tuple(day.strftime('%Y %W').split(' ')) for day in days)
    cleaned_pairs = (clean_year_week(year, week) for year, week in year_week_pairs)

    # `dict.fromkeys` drops repeats while keeping first-seen order
    return list(dict.fromkeys(cleaned_pairs))

def year_and_week_request(
    method: str,
    kwargs_map: dict,
    func_params: list,
    api_key: str,
    start_date: str,
    end_date: str,
    n_attempts: int=3,
    **kwargs
):
    """Make one request per (year, week) pair and combine the results.

    The pairs come from ``construct_year_week_pairs(start_date, end_date)``.

    Args:
        method: Name of the function on the ``raw`` module to call.
        kwargs_map: Maps the generic names ``'year'`` and ``'week'`` to the
            keyword names expected by ``method``.
        func_params: Parameter names that must be supplied for ``method``.
        api_key: Value sent as ``APIKey``.
        start_date: Start of the requested range.
        end_date: End of the requested range.
        n_attempts: Maximum attempts per request.
        **kwargs: Extra keyword arguments forwarded to ``method``.

    Returns:
        The concatenated parsed responses, with a local datetime parsed
        where possible.

    Raises:
        AssertionError: If a date is missing or a required kwarg is absent.
    """
    assert start_date is not None, '`start_date` must be specified'
    assert end_date is not None, '`end_date` must be specified'

    # Progress-bar label: the method name without its first `_` token
    stream = '_'.join(method.split('_')[1:])

    kwargs.update({
        'APIKey': api_key,
        'ServiceType': 'xml'
    })

    year_week_pairs = construct_year_week_pairs(start_date, end_date)
    frames = []

    for year, week in tqdm(year_week_pairs, desc=stream):
        kwargs.update({
            kwargs_map['year']: year,
            kwargs_map['week']: week
        })

        missing_kwargs = list(set(func_params) - set(['year', 'week'] + list(kwargs.keys())))
        assert len(missing_kwargs) == 0, f"The following kwargs are missing: {', '.join(missing_kwargs)}"

        r = retry_request(raw, method, kwargs, n_attempts=n_attempts)
        frames.append(utils.parse_xml_response(r))

    # `DataFrame.append` was removed in pandas 2.0
    df = pd.concat(frames) if frames else pd.DataFrame()
    df = if_possible_parse_local_datetime(df)

    return df

# Cell
def non_temporal_request(
    method: str,
    api_key: str,
    n_attempts: int=3,
    **kwargs
):
    """Make a single request that takes no date parameters.

    Args:
        method: Name of the function on the ``raw`` module to call.
        api_key: Value sent as ``APIKey``.
        n_attempts: Maximum attempts for the request.
        **kwargs: Extra keyword arguments forwarded to ``method``.

    Returns:
        The parsed response, with a local datetime parsed where possible.
    """
    kwargs['APIKey'] = api_key
    kwargs['ServiceType'] = 'xml'

    response = retry_request(raw, method, kwargs, n_attempts=n_attempts)
    parsed = utils.parse_xml_response(response)

    return if_possible_parse_local_datetime(parsed)

# Cell
def query_orchestrator(
    method: str,
    api_key: str,
    request_type: str,
    kwargs_map: dict=None,
    func_params: list=None,
    start_date: str=None,
    end_date: str=None,
    n_attempts: int=3,
    **kwargs
):
    """Dispatch a query to the request function matching ``request_type``.

    Args:
        method: Name of the function on the ``raw`` module to call.
        api_key: Value sent as ``APIKey``.
        request_type: One of ``'SP_and_date'``, ``'just_date'``,
            ``'date_range'``, ``'date_time_range'``, ``'year'``,
            ``'year_and_month'``, ``'year_and_week'`` or ``'non_temporal'``.
        kwargs_map: Mapping of generic to method-specific keyword names;
            forwarded for every type except ``'non_temporal'``.
        func_params: Parameter names required by ``method``; forwarded for
            every type except ``'non_temporal'``.
        start_date: Start of the requested range (temporal types only).
        end_date: End of the requested range (temporal types only).
        n_attempts: Maximum attempts per request.
        **kwargs: Extra keyword arguments forwarded to the request function.

    Returns:
        The DataFrame from the selected request function, with its index
        reset.

    Raises:
        AssertionError: If ``request_type`` is not a recognised type.
    """
    dispatch = {
        'SP_and_date': SP_and_date_request,
        'just_date': just_date_request,
        'date_range': date_range_request,
        'date_time_range': date_range_request,
        'year': year_request,
        'year_and_month': year_and_month_request,
        'year_and_week': year_and_week_request,
        'non_temporal': non_temporal_request
    }

    # Every temporal request function takes the mapping and the date range
    if request_type != 'non_temporal':
        kwargs['kwargs_map'] = kwargs_map
        kwargs['func_params'] = func_params
        kwargs['start_date'] = start_date
        kwargs['end_date'] = end_date

    # Both range types share one function, which needs to tell them apart
    if request_type == 'date_range' or request_type == 'date_time_range':
        kwargs['request_type'] = request_type

    assert request_type in dispatch, f"{request_type} must be one of: {', '.join(dispatch)}"

    result = dispatch[request_type](
        method=method,
        api_key=api_key,
        n_attempts=n_attempts,
        **kwargs
    )

    return result.reset_index(drop=True)