# AUTOGENERATED! DO NOT EDIT! File to edit: ../00_utils.ipynb.

# %% auto 0
__all__ = ['get_datetime', 'Lag', 'list_SDMX_sources', 'list_all_dataflows', 'load_SDMX_data']

# %% ../00_utils.ipynb 5
#| include: false
import datetime
import os

# %% ../00_utils.ipynb 7
#| include: false
def get_datetime(): # The current local time formatted as "YYYY-MM-DD HH:MM:SS TZ"
    "Returns the time now"
    # `datetime.now()` alone is timezone-naive, for which `%Z` renders as an
    # empty string (leaving a dangling trailing space). `astimezone()` attaches
    # the system's local timezone so the zone name is actually included.
    now = datetime.datetime.now().astimezone()
    return now.strftime("%Y-%m-%d %H:%M:%S %Z")

# %% ../00_utils.ipynb 11
#| include: false
import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_is_fitted

class Lag(BaseEstimator, TransformerMixin):
    "A transformer that lags variables"
    def __init__(self, lags=1, jump=0, keep_contemporaneous_X=False):
        # lags: how many consecutive lagged copies of `X` are produced
        # jump: how many of the most recent lags are skipped, so the lags
        #       produced are `jump+1`, ..., `jump+lags`
        # keep_contemporaneous_X: whether the unlagged `X` is kept as the first columns
        self.lags = lags
        self.jump = jump
        self.keep_contemporaneous_X = keep_contemporaneous_X

    @staticmethod
    def _pandas_index(obj):
        "Return the index of a pandas object, or None for anything else"
        # `hasattr(obj, "index")` is not enough: plain lists also expose an
        # `index` attribute (the `list.index` method), which cannot be sliced.
        if isinstance(obj, (pd.DataFrame, pd.Series)):
            return obj.index
        return None

    def _check_X(self, X, reset):
        "Validate `X` with scikit-learn's input checks and return it as an array"
        try:
            # scikit-learn >= 1.6 exposes validation as a public function
            from sklearn.utils.validation import validate_data
        except ImportError:
            # older releases only have the private estimator method
            return self._validate_data(X, reset=reset)
        return validate_data(self, X, reset=reset)

    def fit(
        self, 
        X:np.ndarray, # Array-like data of shape (n_samples, n_features)
        y=None # Array-like data of shape (n_samples,) or (n_samples, n_targets) or None
        ): # A fitted version of the `Lag` instance
        "Fit the `Lag` transformer"
        # Remember a pandas index (from `X`, otherwise from `y`) so `transform`
        # can label its output when it is given a plain array of the same length.
        self.index = self._pandas_index(X)
        if self.index is None and y is not None:
            self.index = self._pandas_index(y)
        # `reset=True` records `n_features_in_` (and `feature_names_in_` when available)
        self._check_X(X, reset=True)

        self.effective_lags_ = self.lags + self.jump
        return self

    def transform(
        self, 
        X:np.ndarray, # Array-like data of shape (n_samples, n_features)
        ): # A lagged version of `X`
        "Lag the dataset `X`"
        # Check fitted state first: validating beforehand would itself set
        # fitted attributes and make an unfitted instance look fitted.
        check_is_fitted(self)
        X_index = self._pandas_index(X)
        # `reset=False` checks `X` against what was seen in `fit` instead of overwriting it
        X_arr = self._check_X(X, reset=False)
        n_samples = X_arr.shape[0]

        # Lags actually produced: jump+1, ..., jump+lags
        first_lag = max(self.jump, 0) + 1
        lag_counts = list(range(first_lag, self.effective_lags_ + 1))
        if not lag_counts:
            raise ValueError(
                "No lags to compute: `lags` must be at least 1 "
                f"(got lags={self.lags!r}, jump={self.jump!r})."
            )

        # `np.roll` wraps the last rows around to the top; those wrapped rows
        # are discarded below when the first `effective_lags_` rows are dropped.
        blocks = [np.roll(X_arr, lag_count, axis=0) for lag_count in lag_counts]
        if self.keep_contemporaneous_X:
            blocks = [X_arr] + blocks
        lagged = np.concatenate(blocks, axis=1)[self.effective_lags_:, :]

        # Column labels are only available when `fit` saw string column names
        columns = None
        if hasattr(self, "feature_names_in_"):
            names = list(self.feature_names_in_)
            columns = list(names) if self.keep_contemporaneous_X else []
            for lag_count in lag_counts:
                columns = columns + [col+"_lag_"+str(lag_count) for col in names]

        # Prefer the index of the data being transformed; fall back to the
        # index stored at fit time only when it matches the number of rows.
        index = X_index
        if index is None:
            fitted_index = getattr(self, "index", None)
            if fitted_index is not None and len(fitted_index) == n_samples:
                index = fitted_index
        if index is not None:
            index = index[self.effective_lags_:]

        return pd.DataFrame(lagged, index=index, columns=columns)

# %% ../00_utils.ipynb 23
#| include: false
import pandasdmx as sdmx

def list_SDMX_sources(): # The list of codes representing the SDMX sources available for data download
    "Return the codes of the SDMX sources reported by `sdmx.list_sources()`"
    available_sources = sdmx.list_sources()
    return available_sources

# %% ../00_utils.ipynb 26
#| include: false
import pandas as pd
import pandasdmx as sdmx

def list_all_dataflows(
    codes_only:bool=False, # Whether to return only the dataflow codes
    return_pandas:bool=True # Whether to return the result in a pandas DataFrame format
    ): # All available dataflows for all SDMX sources used by gingado
    "List all SDMX dataflows. Note: When using as a parameter to an `AugmentSDMX` object or to the `load_SDMX_data` function, set `codes_only=True`"
    dflows = {}
    for src in sdmx.list_sources():
        try:
            src_dflows = sdmx.to_pandas(sdmx.Request(src).dataflow().dataflow)
            if codes_only:
                # keep only the dataflow codes, which are the index labels
                src_dflows = src_dflows.index
        except Exception:
            # Best effort: a source whose dataflow listing cannot be fetched or
            # converted is skipped so the remaining sources are still listed.
            # `Exception` (not a bare `except`) lets Ctrl-C interrupt the loop.
            continue
        dflows[src] = src_dflows

    if not return_pandas:
        # dictionary keyed by source code
        return dflows
    if not dflows:
        # no source could be listed; `pd.concat` would raise on an empty mapping
        return pd.Series(name='dataflow', dtype=object)
    # One frame per source, stacked under the source code as outer index level;
    # column 0 holds the dataflow codes or names.
    return pd.concat({
        src: pd.DataFrame.from_dict(src_dflows)
        for src, src_dflows in dflows.items()
        })[0].rename('dataflow')

# %% ../00_utils.ipynb 37
#| include: false
import pandasdmx as sdmx

def load_SDMX_data(
    sources:dict, # A dictionary with the sources and dataflows per source
    keys:dict, # The keys to be used in the SDMX query
    params:dict, # The parameters to be used in the SDMX query
    verbose:bool=True # Whether to communicate download steps to the user
    ): # A pandas DataFrame with data from SDMX or None if no data matches the sources, keys and parameters
    "Loads datasets from SDMX."
    data_sdmx = {}
    for source, wanted in sources.items():
        src_conn = sdmx.Request(source)
        src_dflows = src_conn.dataflow()
        if wanted == 'all':
            dflows = dict(src_dflows.dataflow.items())
        else:
            if isinstance(wanted, str):
                # a single dataflow code: wrap it so `in` tests equality
                # instead of matching substrings of the code
                wanted = [wanted]
            dflows = {k: v for k, v in src_dflows.dataflow.items() if k in wanted}
        for dflow, dflow_info in dflows.items():
            if verbose: print(f"Querying data from {source}'s dataflow '{dflow}' - {dflow_info.dict()['name']}...")
            try:
                data = sdmx.to_pandas(src_conn.data(dflow, key=keys, params=params), datetime='TIME_PERIOD')
            except Exception:
                # Best effort: a dataflow that cannot serve this query is skipped.
                # `Exception` (not a bare `except`) lets Ctrl-C stop the download.
                if verbose: print("this dataflow does not have data in the desired frequency and time period.")
                continue
            # Flatten multi-level column labels into one string per column; a
            # single-level label is kept whole rather than joined char by char.
            data.columns = [
                '__'.join(col) if isinstance(col, tuple) else str(col)
                for col in data.columns.to_flat_index()
                ]
            data_sdmx[source+"__"+dflow] = data

    if not data_sdmx:
        # nothing was downloaded; `pd.concat` would raise on an empty mapping
        return None

    # Concatenating a dict adds its keys ("source__dataflow") as an outer
    # column level, which is then flattened into a single label.
    df = pd.concat(data_sdmx, axis=1)
    df.columns = ['_'.join(col) for col in df.columns.to_flat_index()]
    return df
