"""Various utilities"""

# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/api/utils.ipynb.

# %% auto 0
# Public API of this module: the names exported by a star-import.
__all__ = ['NA', 'get_unique_across_dfs', 'Remapper', 'has_valid_varname', 'get_bbox', 'ddmm_to_dd', 'download_files_in_folder',
           'download_file', 'match_worms', 'Match', 'match_maris_lut', 'test_dfs', 'ExtractNetcdfContents']

# %% ../nbs/api/utils.ipynb 2
from pathlib import Path
from math import modf
from netCDF4 import Dataset
from fastcore.test import test_eq
import fastcore.all as fc
import pandas as pd
import numpy as np
from tqdm import tqdm 
import requests
from shapely import MultiPoint
import jellyfish as jf
from dataclasses import dataclass
from ast import literal_eval


from typing import List, Dict, Callable, Tuple, Optional, Union
from .configs import cache_path

from marisco.configs import (
    NC_VARS,
    )
# from collections.abc import Callable

# %% ../nbs/api/utils.ipynb 4
# TBD: move to configs
# Label for unavailable values (presumably; it is not referenced in the visible part of this module -- confirm usage).
NA = 'Not available'

# %% ../nbs/api/utils.ipynb 7
def get_unique_across_dfs(dfs: Dict[str, pd.DataFrame],  # Dictionary of dataframes
                          col_name: str='NUCLIDE', # Column name to extract unique values from
                          as_df: bool=False, # Return a DataFrame of unique values
                          include_nchars: bool=False # Add a column with the number of characters in the value
                          ) -> List[str]: # Returns a list of unique column values across dataframes
    "Collect the distinct values of `col_name` over all dataframes having that column, as a list or (if `as_df`) a DataFrame."
    pooled = set()
    for frame in dfs.values():
        # Dataframes without the requested column are simply skipped
        if col_name in frame.columns:
            pooled.update(frame[col_name].unique())
    uniques = list(pooled)

    if as_df:
        # `reset_index` exposes the positional index as an `index` column next to `value`
        table = pd.DataFrame(uniques, columns=['value']).reset_index()
        if include_nchars:
            table['n_chars'] = table['value'].str.len()
        return table
    return uniques

# %% ../nbs/api/utils.ipynb 13
class Remapper():
    "Remap a data provider lookup table to a MARIS lookup table using fuzzy matching."
    def __init__(self,
                 provider_lut_df: pd.DataFrame, # Data provider lookup table to be remapped
                 maris_lut_fn: Union[Callable, pd.DataFrame], # MARIS lookup table or function returning the path
                 maris_col_id: str, # MARIS lookup table column name for the id
                 maris_col_name: str, # MARIS lookup table column name for the name
                 provider_col_to_match: str, # Data provider lookup table column name for the name to match
                 provider_col_key: str, # Data provider lookup table column name for the key
                 fname_cache: str # Cache file name
                 ):
        fc.store_attr()
        self.cache_file = cache_path() / fname_cache
        # `maris_lut_fn` is either the lookup table itself or a callable whose
        # return value is used as the lookup table
        self.maris_lut = maris_lut_fn() if callable(maris_lut_fn) else maris_lut_fn
        self.lut = {}
        # Defaults so that `select_match` / `_format_output` can be used even
        # when `generate_lookup_table` has not been called yet
        self.fixes = {}
        self.as_df = True

    def generate_lookup_table(self, 
                              fixes: Optional[Dict[str, str]] = None, # Lookup table fixes
                              as_df: bool = True, # Whether to return a DataFrame
                              overwrite: bool = True # Whether to rebuild the table even if a cache file exists
                              ):
        "Generate a lookup table from a data provider lookup table to a MARIS lookup table using fuzzy matching."
        # `None` instead of a shared `{}` default avoids the mutable-default pitfall
        self.fixes = {} if fixes is None else fixes
        self.as_df = as_df
        if overwrite or not self.cache_file.exists():
            self._create_lookup_table()
            fc.save_pickle(self.cache_file, self.lut)
        else:
            # Re-use the lookup table pickled by a previous run
            self.lut = fc.load_pickle(self.cache_file)

        return self._format_output()

    def _create_lookup_table(self):
        "Fill `self.lut` with one `Match` per row of the provider lookup table."
        df = self.provider_lut_df
        for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing"): 
            self._process_row(row)

    def _process_row(self, row):
        "Store in `self.lut` the best MARIS match for a single provider row, keyed by the provider key."
        value_to_match = row[self.provider_col_to_match]
        key = row[self.provider_col_key]
        if isinstance(value_to_match, str):  # Only process if value is a string
            # If value is in fixes, use the fixed value
            name_to_match = self.fixes.get(value_to_match, value_to_match)
            # Results are sorted by ascending score: the first row is the closest match
            result = match_maris_lut(self.maris_lut, name_to_match, self.maris_col_id, self.maris_col_name).iloc[0]
            self.lut[key] = Match(result[self.maris_col_id], result[self.maris_col_name],
                                  value_to_match, result['score'])
        else:
            # Handle non-string values (e.g., NaN)
            self.lut[key] = Match(-1, "Unknown", value_to_match, 0)
            
    def select_match(self,
                     match_score_threshold:int=1, # Entries with a score below this value are discarded
                     verbose:bool=False # Whether to print how many entries fall on each side of the threshold
                     ):
        "Keep in `self.lut` only the entries whose match score is at least `match_score_threshold` and return them."
        if verbose:
            matched_len = sum(1 for v in self.lut.values() if v.match_score < match_score_threshold)
            print(f"{matched_len} entries matched the criteria, while {len(self.lut) - matched_len} entries had a match score of {match_score_threshold} or higher.")
        
        self.lut = {k: v for k, v in self.lut.items() if v.match_score >= match_score_threshold}
        return self._format_output()

    def _format_output(self):
        "Return `self.lut` as is, or as a DataFrame sorted by descending match score when `self.as_df` is set."
        if not self.as_df: return self.lut
        df_lut = pd.DataFrame.from_dict(self.lut, orient='index', 
                                        columns=['matched_maris_name', 'source_name', 'match_score'])
        df_lut.index.name = 'source_key'
        return df_lut.sort_values(by='match_score', ascending=False)

# %% ../nbs/api/utils.ipynb 16
# TBD: Assess if still needed
def has_valid_varname(
    var_names: List[str], # variable names
    cdl_path: str, # Path to MARIS CDL file (point of truth)
    group: Optional[str] = None, # Check if the variable names is contained in the group
) -> bool: # `True` when every name is allowed, `False` otherwise
    "Check that proposed variable names are in MARIS CDL"
    with Dataset(cdl_path) as nc:
        # Variable names declared in each group of the CDL
        vars_by_group = {grp.name: list(grp.variables.keys()) for grp in nc.groups.values()}

    if group is not None:
        # An unknown group raises `KeyError`
        allowed_vars = set(vars_by_group[group])
    else:
        # No group requested: accept a name declared in any group
        allowed_vars = {name for names in vars_by_group.values() for name in names}

    has_valid = True
    for name in var_names:
        if name in allowed_vars:
            continue
        has_valid = False
        # Every offending name is reported, not only the first one
        if group is not None:
            print(f'"{name}" variable name not found in group "{group}" of MARIS CDL')
        else:
            print(f'"{name}" variable name not found in MARIS CDL')
    return has_valid

# %% ../nbs/api/utils.ipynb 20
def get_bbox(df: pd.DataFrame, # DataFrame holding the coordinate columns
             coord_cols: Tuple[str, str] = ('LON', 'LAT') # Names of the x and y coordinate columns, in that order
            ):
    "Get the bounding box of a DataFrame."
    x, y = coord_cols
    # Pair the two columns directly rather than building a Series per row with `iterrows`
    points = list(zip(df[x], df[y]))
    # The envelope of the point collection is its bounding geometry
    return MultiPoint(points).envelope

# %% ../nbs/api/utils.ipynb 26
def ddmm_to_dd(
    ddmmmm: float # Coordinates in degrees/minutes decimal format
    ) -> float: # Coordinates in degrees decimal format
    "Convert degrees/minutes decimal (`dd.mmmm`) to degrees decimal, rounded to 6 decimal places."
    # The fractional part carries the minutes (`.34` -> 34 minutes), the integral part the degrees
    fractional, whole = modf(ddmmmm)
    minutes = fractional * 100
    decimal_degrees = int(whole) + (minutes / 60)
    return round(decimal_degrees, 6)

# %% ../nbs/api/utils.ipynb 29
def download_files_in_folder(
    owner: str, # GitHub owner
    repo: str, # GitHub repository
    src_dir: str, # Source directory
    dest_dir: str # Destination directory
    ):
    "Download every file located directly in `src_dir` of a GitHub repository into `dest_dir`."
    # NOTE(review): a second definition with the same name appears later in this
    # module and is the one that takes effect at import time.
    # The GitHub contents API lists the entries of the folder
    url = f"https://api.github.com/repos/{owner}/{repo}/contents/{src_dir}"
    # A timeout keeps the call from blocking forever on an unresponsive server
    response = requests.get(url, timeout=30)

    if response.status_code != 200:
        print(f"Error: {response.status_code}")
        return

    # Only plain files are fetched; other entry types are skipped (no recursion)
    for item in response.json():
        if item["type"] == "file":
            download_file(owner, repo, src_dir, dest_dir, item["name"])

def download_file(
    owner: str, # GitHub owner
    repo: str, # GitHub repository
    src_dir: str, # Source directory
    dest_dir: str, # Destination directory
    fname: str, # Name of the file to download
    branch: str = 'master' # Branch to download the file from
    ):
    "Download the raw content of `fname` from `src_dir` of a GitHub repository and save it in `dest_dir`."
    # NOTE(review): a second definition with the same name appears later in this
    # module and is the one that takes effect at import time.
    url = f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{src_dir}/{fname}"
    # A timeout keeps the call from blocking forever on an unresponsive server
    response = requests.get(url, timeout=30)

    if response.status_code != 200:
        print(f"Error: {response.status_code}")
        return

    # Save the file locally
    with open(Path(dest_dir) / fname, "wb") as file:
        file.write(response.content)
    print(f"{fname} downloaded successfully.")

# %% ../nbs/api/utils.ipynb 31
def match_worms(
    name: str # Name of species to look up in WoRMS
    ):
    "Lookup `name` in WoRMS (fuzzy match)."
    url = 'https://www.marinespecies.org/rest/AphiaRecordsByMatchNames'
    params = {
        'scientificnames[]': [name],
        'marine_only': 'true'
    }
    headers = {
        'accept': 'application/json'
    }

    # A timeout keeps the call from blocking forever on an unresponsive server
    response = requests.get(url, params=params, headers=headers, timeout=30)

    # Any status other than 200 is reported with the `-1` sentinel,
    # which is kept for backward compatibility
    if response.status_code != 200:
        return -1
    return response.json()

# %% ../nbs/api/utils.ipynb 36
@dataclass
class Match:
    "Match between a data provider name and a MARIS lookup table."
    matched_id: int  # Id of the matched MARIS lookup table entry (`Remapper` stores -1 for non-string source values)
    matched_maris_name: str  # Name of the matched MARIS lookup table entry (`Remapper` stores "Unknown" for non-string source values)
    source_name: str  # Data provider value that was looked up
    match_score: int  # Score computed by the distance function of `match_maris_lut` (`Remapper` stores 0 for non-string source values)

# %% ../nbs/api/utils.ipynb 37
def match_maris_lut(
    lut: Union[str, pd.DataFrame, Path], # Either str, Path or DataFrame
    data_provider_name: str, # Name of data provider nomenclature item to look up 
    maris_id: str, # Id of MARIS lookup table nomenclature item to match
    maris_name: str, # Name of MARIS lookup table nomenclature item to match
    dist_fn: Callable = jf.levenshtein_distance, # Distance function
    nresults: int = 10 # Maximum number of results to return
) -> pd.DataFrame:
    "Rank the entries of a MARIS lookup table by their distance to a data provider name (e.g biota species, sediments, ...)."
    if isinstance(lut, pd.DataFrame):
        table = lut  # Already loaded
    elif isinstance(lut, (str, Path)):
        table = pd.read_excel(lut)  # A path was given: load the lookup table from it
    else:
        raise ValueError("lut must be either a file path or a DataFrame")

    def _distance(candidate):
        # Both sides are lower-cased so that the comparison is case-insensitive
        return dist_fn(data_provider_name.lower(), candidate)

    # Entries without a name cannot be scored; ids are cast to integers
    table = table.dropna(subset=[maris_name]).astype({maris_id: 'int'})
    table['score'] = table[maris_name].str.lower().apply(_distance)
    # Smallest score first, truncated to the requested number of results
    ranked = table.sort_values(by='score', ascending=True)[:nresults]
    return ranked[[maris_id, maris_name, 'score']]

# %% ../nbs/api/utils.ipynb 46
def download_files_in_folder(
    owner: str, # GitHub owner
    repo: str, # GitHub repository
    src_dir: str, # Source directory
    dest_dir: str # Destination directory
    ):
    "Download every file located directly in `src_dir` of a GitHub repository into `dest_dir`."
    # NOTE(review): this re-declares a function of the same name defined earlier
    # in this module; being the later definition, it is the one in effect.
    # The GitHub contents API lists the entries of the folder
    url = f"https://api.github.com/repos/{owner}/{repo}/contents/{src_dir}"
    # A timeout keeps the call from blocking forever on an unresponsive server
    response = requests.get(url, timeout=30)

    if response.status_code != 200:
        print(f"Error: {response.status_code}")
        return

    # Only plain files are fetched; other entry types are skipped (no recursion)
    for item in response.json():
        if item["type"] == "file":
            download_file(owner, repo, src_dir, dest_dir, item["name"])

def download_file(
    owner: str, # GitHub owner
    repo: str, # GitHub repository
    src_dir: str, # Source directory
    dest_dir: str, # Destination directory
    fname: str, # Name of the file to download
    branch: str = 'master' # Branch to download the file from
    ):
    "Download the raw content of `fname` from `src_dir` of a GitHub repository and save it in `dest_dir`."
    # NOTE(review): this re-declares a function of the same name defined earlier
    # in this module; being the later definition, it is the one in effect.
    url = f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/{src_dir}/{fname}"
    # A timeout keeps the call from blocking forever on an unresponsive server
    response = requests.get(url, timeout=30)

    if response.status_code != 200:
        print(f"Error: {response.status_code}")
        return

    # Save the file locally
    with open(Path(dest_dir) / fname, "wb") as file:
        file.write(response.content)
    print(f"{fname} downloaded successfully.")

# %% ../nbs/api/utils.ipynb 48
def test_dfs(
    dfs1: Dict[str, pd.DataFrame], # First dictionary of DataFrames to compare 
    dfs2: Dict[str, pd.DataFrame] # Second dictionary of DataFrames to compare
    ) -> None: # It raises an `AssertionError` if the DataFrames are not equal
    "For each group of `dfs1`, assert equality with the same group of `dfs2` after sorting rows by index and aligning the columns of `dfs2` on those of `dfs1`."
    for grp in dfs1:
        expected = dfs1[grp].sort_index()
        actual = dfs2[grp].sort_index()
        # Re-ordering the columns of the second frame makes the comparison insensitive to column order
        fc.test_eq(expected, actual.reindex(columns=expected.columns))

# %% ../nbs/api/utils.ipynb 51
class ExtractNetcdfContents:
    "Load the groups of a NetCDF file as DataFrames, along with their enums, custom maps and the global attributes."
    def __init__(self, filename: str, verbose: bool = False):
        "Initialize and extract data from a NetCDF file."
        self.filename, self.verbose = filename, verbose
        # Containers filled by `extract_all`; the per-group ones are keyed by upper-cased group name
        self.dfs = {}  # DataFrames extracted from the NetCDF file
        self.enum_dicts = {}  # Enum dictionaries extracted from the NetCDF file
        self.global_attrs = {}  # Global attributes extracted from the NetCDF file
        self.custom_maps = {}  # Custom maps extracted from the NetCDF file
        self.extract_all()

    def extract_all(self):
        "Extract data, enums, and global attributes from the NetCDF file."
        # A missing file is only reported: the containers are left empty
        if not Path(self.filename).exists():
            print(f'File {self.filename} not found.')
            return

        with Dataset(self.filename, 'r') as nc:
            self.global_attrs = self.extract_global_attributes(nc)
            for name, grp in nc.groups.items():
                key = name.upper()
                self.dfs[key] = self.extract_data(grp)
                self.enum_dicts[key] = self.extract_enums(grp, name)
                self.custom_maps[key] = self.extract_custom_maps(grp, name)

            if self.verbose:
                print("Data extraction complete.")

    def extract_data(self, group) -> pd.DataFrame:
        "Extract data from a group and convert to DataFrame."
        columns = {}
        for var_name, var in group.variables.items():
            # Variables named after a dimension are left out
            if var_name in group.dimensions:
                continue
            columns[var_name] = var[:]
        df = pd.DataFrame(columns)
        # Rename the NetCDF variable names present in the frame to their `NC_VARS` keys
        to_column = {nc_var: col for col, nc_var in NC_VARS.items() if nc_var in df.columns}
        return df.rename(columns=to_column)

    def extract_enums(self, group, group_name: str) -> Dict:
        "Extract enum dictionaries for variables in a group."
        found = {}
        for var_name, var in group.variables.items():
            # Only variables whose datatype carries an `enum_dict` are of interest
            if not hasattr(var.datatype, 'enum_dict'):
                continue
            # Keys and values are both stored as strings
            found[var_name] = {str(k): str(v) for k, v in var.datatype.enum_dict.items()}
            if self.verbose:
                print(f"Extracted enum_dict for {var_name} in {group_name}")
        return found

    def extract_global_attributes(self, nc) -> Dict:
        "Extract global attributes from the NetCDF file."
        attrs = {}
        for attr in nc.ncattrs():
            attrs[attr] = getattr(nc, attr)
        return attrs
    
    def extract_custom_maps(self, group, group_name: str) -> Dict:
        "Extract custom maps from the NetCDF file."
        # Inverse of `NC_VARS`: NetCDF variable name -> `NC_VARS` key
        key_by_nc_var = {nc_var: col for col, nc_var in NC_VARS.items()}
        maps = {}
        for var_name, var in group.variables.items():
            map_attr = f"{var_name}_map"
            if not hasattr(var, map_attr):
                continue
            # The attribute holds the map as a Python literal in text form
            maps[key_by_nc_var[var.name]] = literal_eval(getattr(var, map_attr))
        return maps
