# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/handlers/geotraces.ipynb.

# %% auto 0
# Public names exported by this module; auto-generated from the source notebook like the rest of the file.
__all__ = ['fname_in', 'fname_out_nc', 'zotero_key', 'load_data', 'common_coi', 'nuclides_pattern', 'phase', 'smp_method',
           'nuclides_name', 'units_lut', 'renaming_rules', 'lut_nuclides', 'kw', 'SelectColsOfInterestCB',
           'WideToLongCB', 'ExtractUnitCB', 'ExtractFilteringStatusCB', 'ExtractSamplingMethodCB', 'RenameNuclideCB',
           'StandardizeUnitCB', 'RenameColumnCB', 'UnshiftLongitudeCB', 'DispatchToGroupCB', 'ParseTimeCB', 'get_attrs',
           'encode']

# %% ../../nbs/handlers/geotraces.ipynb 6
import fastcore.all as fc
import pandas as pd
import numpy as np
import re

from marisco.callbacks import (
    Callback, 
    Transformer, 
    SanitizeLonLatCB, 
    EncodeTimeCB,
    RemapCB
)

from marisco.metadata import (
    GlobAttrsFeeder, 
    BboxCB,
    DepthRangeCB, 
    TimeRangeCB,
    ZoteroCB,
    KeyValuePairCB
)

from marisco.configs import (
    AVOGADRO,
    get_lut,
    lut_path,
    cfg
)

from ..netcdf2csv import decode
from ..encoders import NetCDFEncoder

# %% ../../nbs/handlers/geotraces.ipynb 10
# Default input CSV (GEOTRACES IDP2021 v2, seawater discrete samples) and output NetCDF path.
# Both are relative paths, hence resolved against the current working directory.
fname_in = ('../../_data/geotraces/GEOTRACES_IDP2021_v2/seawater/ascii/'
            'GEOTRACES_IDP2021_Seawater_Discrete_Sample_Data_v2.csv')
fname_out_nc = '../../_data/output/190-geotraces-2021.nc'
# Zotero item key used to fetch the dataset's bibliographic metadata (see `get_attrs`).
zotero_key = '97UIMEXN'

# %% ../../nbs/handlers/geotraces.ipynb 12
def load_data(fname):
    "Read the CSV file (path or file-like buffer) `fname` into a DataFrame."
    # The previous lambda ignored `fname` and always read the module-level `fname_in`.
    return pd.read_csv(fname)

# %% ../../nbs/handlers/geotraces.ipynb 17
# Metadata columns kept for every measurement (time, position, bottom/sample depth, BODC bottle number).
common_coi = ['yyyy-mm-ddThh:mm:ss.sss', 'Longitude [degrees_east]',
              'Latitude [degrees_north]', 'Bot. Depth [m]', 'DEPTH [m]', 'BODC Bottle Number:INTEGER']

# Regexes selecting the nuclide columns of interest. Every pattern is anchored with `^`
# ('Np_237' previously was not) so each one only matches at the start of a column name,
# whether it is used with `re.match` or `re.search`.
nuclides_pattern = ['^TRITI', '^Th_228', '^Th_23[024]', '^Pa_231',
                    '^U_236_[DT]', '^Be_', '^Cs_137', '^Pb_210', '^Po_210',
                    '^Ra_22[3468]', '^Np_237', '^Pu_239_[D]', '^Pu_240', '^Pu_239_Pu_240',
                    '^I_129', '^Ac_227']

class SelectColsOfInterestCB(Callback):
    "Keep the `common_coi` columns plus every column whose name matches one of `nuclides_pattern` (via `re.match`)."
    def __init__(self, common_coi, nuclides_pattern): fc.store_attr()

    def __call__(self, tfm):
        def is_nuclide(col):
            # A column is a nuclide column if any pattern matches at its start.
            return any(re.match(pat, col) for pat in self.nuclides_pattern)

        nuclide_cols = [col for col in tfm.df.columns if is_nuclide(col)]
        tfm.df = tfm.df[self.common_coi + nuclide_cols]

# %% ../../nbs/handlers/geotraces.ipynb 23
class WideToLongCB(Callback):
    """
    Melt the wide table into long format: each nuclide column name (selected with
    `nuclides_pattern`) becomes a value of `var_name`, so that unit, phase and sampling
    method can later be parsed from it, and the measurement goes into `value_name`.
    Rows whose measurement is missing are dropped.
    """
    def __init__(self, common_coi, nuclides_pattern, 
                 var_name='NUCLIDE', value_name='VALUE'): 
        fc.store_attr()

    def __call__(self, tfm):
        nuclide_cols = [col for col in tfm.df.columns
                        if any(re.match(pat, col) for pat in self.nuclides_pattern)]
        long_df = tfm.df.melt(id_vars=self.common_coi, value_vars=nuclide_cols,
                              var_name=self.var_name, value_name=self.value_name)
        long_df.dropna(subset=[self.value_name], inplace=True)
        tfm.df = long_df

# %% ../../nbs/handlers/geotraces.ipynb 28
class ExtractUnitCB(Callback):
    """
    Extract units from nuclide names: the text inside the first pair of square brackets
    (e.g. a name ending in `' [uBq/kg]'` yields `'uBq/kg'`) is written to `unit_col_name`.
    """
    def __init__(self, 
                 var_name='NUCLIDE', # Column holding the original nuclide (column) names
                 unit_col_name='UNIT' # Output column receiving the extracted unit
                 ): 
        # `fc.store_attr` saves every argument as an attribute of the same name.
        fc.store_attr()

    def extract_unit(self, s):
        "Return the text inside the first `[...]` of `s`, or None if `s` has no brackets."
        match = re.search(r'\[(.*?)\]', s)
        return match.group(1) if match else None

    def __call__(self, tfm):
        tfm.df[self.unit_col_name] = tfm.df[self.var_name].apply(self.extract_unit)

# %% ../../nbs/handlers/geotraces.ipynb 31
# Phase code (found between underscores in nuclide names, see `ExtractFilteringStatusCB`)
# -> filtering status id stored under 'FILT' and the sample group used to dispatch rows.
# NOTE(review): codes presumably stand for dissolved / total / (large, small) particulate — confirm.
phase = {
    'D':   dict(FILT=1, group='SEAWATER'),
    'T':   dict(FILT=2, group='SEAWATER'),
    'TP':  dict(FILT=1, group='SUSPENDED_MATTER'),
    'LPT': dict(FILT=1, group='SUSPENDED_MATTER'),
    'SPT': dict(FILT=1, group='SUSPENDED_MATTER'),
}

# %% ../../nbs/handlers/geotraces.ipynb 32
class ExtractFilteringStatusCB(Callback):
    """
    Derive the filtering status (column `FILT`) and the sample group (column `GROUP`)
    of each row from the phase code, a key of `phase` surrounded by underscores,
    found in its nuclide name. Names without a known phase code get None.
    """
    def __init__(self, phase, var_name='NUCLIDE'): 
        fc.store_attr()
        self.filt_col_name = 'FILT'

    def match(self, s):
        "Search `s` for `_<phase code>_`; return the match object or None."
        alternatives = '|'.join(self.phase.keys())
        return re.search(rf'_({alternatives})_', s)

    def _lookup(self, s, field):
        # Shared helper: read `field` from the `phase` entry of the code found in `s`.
        found = self.match(s)
        if not found:
            return None
        return self.phase[found.group(1)][field]

    def extract_filt_status(self, s):
        "Filtering status id for nuclide name `s` (None if no phase code is found)."
        return self._lookup(s, self.filt_col_name)

    def extract_group(self, s):
        "Sample group for nuclide name `s` (None if no phase code is found)."
        return self._lookup(s, 'group')

    def __call__(self, tfm):
        tfm.df[self.filt_col_name] = tfm.df[self.var_name].apply(self.extract_filt_status)
        tfm.df['GROUP'] = tfm.df[self.var_name].apply(self.extract_group)

# %% ../../nbs/handlers/geotraces.ipynb 35
# NOTE: the MARIS sampling-method ids below are still to be validated.
# Keys are the sampling-method tokens searched for in nuclide names (see `ExtractSamplingMethodCB`).
smp_method = dict(BOTTLE=1, FISH=18, PUMP=14, UWAY=24)

# %% ../../nbs/handlers/geotraces.ipynb 36
class ExtractSamplingMethodCB(Callback):
    "Extract sampling method from nuclide names: a `_<token> ` match, with `<token>` a key of `smp_method`, is mapped to its id (None if absent)."
    def __init__(self, 
                 smp_method:dict = smp_method, # Sampling method lookup table
                 var_name='NUCLIDE', # Column name containing nuclide names
                 smp_method_col_name = 'SAMP_MET' # Column name for sampling method in output df
                 ): 
        fc.store_attr()

    def extract_smp_method(self, s):
        tokens = '|'.join(self.smp_method.keys())
        # The token must be preceded by '_' and followed by a space.
        found = re.search(rf'_({tokens}) ', s)
        if found is None:
            return None
        return self.smp_method[found.group(1)]

    def __call__(self, tfm):
        nuclides = tfm.df[self.var_name]
        tfm.df[self.smp_method_col_name] = nuclides.apply(self.extract_smp_method)

# %% ../../nbs/handlers/geotraces.ipynb 40
# Explicit MARIS names for nuclides whose default name (lower-cased, underscores removed,
# see `RenameNuclideCB`) is not the expected one.
nuclides_name = dict(TRITIUM='h3', Pu_239_Pu_240='pu239_240_tot')

# %% ../../nbs/handlers/geotraces.ipynb 41
class RenameNuclideCB(Callback):
    """
    Remap nuclides name to MARIS standard: keep the part of the name before the first
    phase suffix, then use `nuclides_name` if it lists it, otherwise lower-case it and
    drop underscores (e.g. 'Cs_137' -> 'cs137').
    """
    def __init__(self, nuclides_name, var_name='NUCLIDE'): 
        fc.store_attr()
        self.patterns = ['_D', '_T', '_TP', '_LPT', '_SPT']

    def extract_nuclide_name(self, s):
        "Return the (shortest) prefix of `s` preceding one of `self.patterns`, or None."
        match = re.search(r'(.*?)(' + '|'.join(self.patterns) + ')', s)
        return match.group(1) if match else None

    def standardize_name(self, s):
        "Return the MARIS name for nuclide name `s`; raise ValueError if no phase suffix is found."
        name = self.extract_nuclide_name(s)
        if name is None:
            # Previously this crashed with an opaque `'NoneType' object has no attribute 'lower'`.
            raise ValueError(f"Cannot extract a nuclide name from {s!r}: "
                             f"no phase suffix among {self.patterns}.")
        return self.nuclides_name[name] if name in self.nuclides_name else name.lower().replace('_', '')

    def __call__(self, tfm):
        tfm.df[self.var_name] = tfm.df[self.var_name].apply(self.standardize_name)

# %% ../../nbs/handlers/geotraces.ipynb 49
# Unit -> MARIS unit id and the multiplicative factor applied to values (see `StandardizeUnitCB`).
# 'pmol/kg' is converted to the same id as 'atoms/kg' (9) via Avogadro's number.
# NOTE(review): id 3 (shared by uBq/kg and mBq/kg) is presumably Bq/kg given the factors — confirm.
units_lut = {
    'TU':       dict(id=7, factor=1),
    'uBq/kg':   dict(id=3, factor=1e-6),
    'atoms/kg': dict(id=9, factor=1),
    'mBq/kg':   dict(id=3, factor=1e-3),
    'pmol/kg':  dict(id=9, factor=1e-12 * AVOGADRO),
}

# %% ../../nbs/handlers/geotraces.ipynb 50
class StandardizeUnitCB(Callback):
    """
    Remap unit to MARIS standard ones and apply conversion where needed: values in
    `var_name` are multiplied by the unit's `factor`, and `unit_col_name` is replaced
    by the unit's MARIS `id`. Units absent from `units_lut` map to NaN in both columns.
    """
    def __init__(self, 
                 units_lut, 
                 unit_col_name='UNIT',
                 var_name='VALUE'): 
        fc.store_attr()

    def __call__(self, tfm):
        factor_by_unit = {unit: spec['factor'] for unit, spec in self.units_lut.items()}
        id_by_unit = {unit: spec['id'] for unit, spec in self.units_lut.items()}

        # Rescale first: the unit column still holds the original unit labels here.
        tfm.df[self.var_name] *= tfm.df[self.unit_col_name].map(factor_by_unit)
        tfm.df[self.unit_col_name] = tfm.df[self.unit_col_name].map(id_by_unit)

# %% ../../nbs/handlers/geotraces.ipynb 54
# Original column name -> MARIS standard column name (applied by `RenameColumnCB`).
renaming_rules = {
    'yyyy-mm-ddThh:mm:ss.sss':    'TIME',
    'Longitude [degrees_east]':   'LON',
    'Latitude [degrees_north]':   'LAT',
    'DEPTH [m]':                  'SMP_DEPTH',
    'Bot. Depth [m]':             'TOT_DEPTH',
    'BODC Bottle Number:INTEGER': 'SMP_ID',
}

# %% ../../nbs/handlers/geotraces.ipynb 55
class RenameColumnCB(Callback):
    "Rename columns to MARIS standard names using `lut`; columns absent from `lut` keep their name."
    def __init__(self, lut=renaming_rules): fc.store_attr()

    def __call__(self, tfm):
        lookup = self.lut
        tfm.df.columns = [lookup.get(col, col) for col in tfm.df.columns]

# %% ../../nbs/handlers/geotraces.ipynb 59
class UnshiftLongitudeCB(Callback):
    """
    Longitudes are coded between 0 and 360 in Geotraces. We rescale them between -180 and 180 instead:
    values above 180 (western longitudes) are shifted by -360, all others are left untouched.
    """
    def __init__(self, lon_col_name='LON'): 
        fc.store_attr()

    def __call__(self, tfm):
        lon = tfm.df[self.lon_col_name]
        # Subtracting a constant 180 (previous code) moved every point half-way around the
        # globe (e.g. 10 -> -170). Only values in (180, 360] need remapping: 350 -> -10.
        # Values already within [-180, 180] are unchanged; NaN stays NaN.
        tfm.df[self.lon_col_name] = lon.where(lon <= 180, lon - 360)

# %% ../../nbs/handlers/geotraces.ipynb 64
class DispatchToGroupCB(Callback):
    """
    Convert to a dictionary of dataframe with sample type (seawater,...) as keys:
    `tfm.dfs` maps each value of `group_name` to its rows, without the `group_name` column.
    Rows with a missing group are dropped (pandas `groupby` excludes NaN keys by default).
    """
    def __init__(self, group_name='GROUP'): 
        fc.store_attr()

    def __call__(self, tfm):
        tfm.dfs = {group: group_df.drop(columns=self.group_name)
                   for group, group_df in tfm.df.groupby(self.group_name)}

# %% ../../nbs/handlers/geotraces.ipynb 68
class ParseTimeCB(Callback):
    "Parse the ISO 8601 time column of every dataframe in `tfm.dfs` into pandas datetimes."
    def __init__(self, time_col_name='TIME'):
        # Configurable at construction time, like the other callbacks of this module.
        fc.store_attr()

    def __call__(self, tfm, time_col_name=None):
        # `time_col_name` is still accepted here for backward compatibility; when given,
        # it overrides the column configured in `__init__`.
        col = self.time_col_name if time_col_name is None else time_col_name
        for df in tfm.dfs.values():
            df[col] = pd.to_datetime(df[col], format='ISO8601')

# %% ../../nbs/handlers/geotraces.ipynb 77
def lut_nuclides():
    "Load the nuclide lookup table from `dbo_nuclide.xlsx` via `get_lut`, keyed on `nc_name` with `nuclide_id` values."
    return get_lut(lut_path(), 'dbo_nuclide.xlsx',
                   key='nc_name', value='nuclide_id', reverse=False)

# %% ../../nbs/handlers/geotraces.ipynb 85
# Dataset keywords, one per entry; `get_attrs` joins them with ', ' into the 'keywords' attribute.
# Hierarchical (GCMD-style) keywords use ' > ' as level separator.
kw = ['oceanography',
      'Earth Science > Oceans > Ocean Chemistry > Radionuclides',
      'Earth Science > Human Dimensions > Environmental Impacts > Nuclear Radiation Exposure',
      'Earth Science > Oceans > Ocean Chemistry > Ocean Tracers',
      'Earth Science > Oceans > Marine Sediments',
      'Earth Science > Oceans > Ocean Chemistry',
      'Earth Science > Oceans > Sea Ice > Isotopes',
      'Earth Science > Oceans > Water Quality > Ocean Contaminants',
      'Earth Science > Biological Classification > Animals/Vertebrates > Fish',
      'Earth Science > Biosphere > Ecosystems > Marine Ecosystems',
      'Earth Science > Biological Classification > Animals/Invertebrates > Mollusks',
      'Earth Science > Biological Classification > Animals/Invertebrates > Arthropods > Crustaceans',
      'Earth Science > Biological Classification > Plants > Macroalgae (Seaweeds)']

# %% ../../nbs/handlers/geotraces.ipynb 86
def get_attrs(tfm, zotero_key, kw=kw):
    """
    Retrieve global attributes from Geotraces dataset: bounding box, depth and time
    ranges computed from `tfm.dfs`, Zotero metadata for `zotero_key`, the keywords `kw`
    and the transformer logs `tfm.logs`, each list being joined with ', '.
    """
    dfs = tfm.dfs
    feeder_cbs = [
        BboxCB(),
        DepthRangeCB(),
        TimeRangeCB(),
        ZoteroCB(zotero_key, cfg=cfg()),
        KeyValuePairCB('keywords', ', '.join(kw)),
        KeyValuePairCB('publisher_postprocess_logs', ', '.join(tfm.logs)),
    ]
    feeder = GlobAttrsFeeder(dfs, cbs=feeder_cbs)
    return feeder()

# %% ../../nbs/handlers/geotraces.ipynb 89
def encode(fname_in, fname_out_nc, **kwargs):
    """
    Run the GEOTRACES pipeline on the CSV `fname_in` and write the result to the NetCDF
    file `fname_out_nc`. Global attributes use the module-level `zotero_key` and `kw`.
    The only keyword read from `kwargs` is `verbose` (default False), forwarded to
    `NetCDFEncoder`; any other keyword is ignored.
    """
    raw = pd.read_csv(fname_in)
    pipeline = [
        SelectColsOfInterestCB(common_coi, nuclides_pattern),
        WideToLongCB(common_coi, nuclides_pattern),
        ExtractUnitCB(),
        ExtractFilteringStatusCB(phase),
        ExtractSamplingMethodCB(smp_method),
        RenameNuclideCB(nuclides_name),
        StandardizeUnitCB(units_lut),
        RenameColumnCB(renaming_rules),
        UnshiftLongitudeCB(),
        DispatchToGroupCB(),
        ParseTimeCB(),
        EncodeTimeCB(),
        SanitizeLonLatCB(),
        RemapCB(fn_lut=lut_nuclides, col_remap='NUCLIDE', col_src='NUCLIDE'),
    ]
    tfm = Transformer(raw, cbs=pipeline)
    tfm()

    attrs = get_attrs(tfm, zotero_key=zotero_key, kw=kw)
    verbose = kwargs.get('verbose', False)
    encoder = NetCDFEncoder(tfm.dfs, dest_fname=fname_out_nc,
                            global_attrs=attrs, verbose=verbose)
    encoder.encode()
