"""Data pipeline (handler) to convert TEPCO dataset ([Source](https://radioactivity.nsr.go.jp/ja/list/349/list-1.html)) to `NetCDF` format"""

# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/handlers/tepco.ipynb.

# %% auto 0
__all__ = ['fname_coastal_water', 'fname_clos1F', 'fname_iaea_orbs', 'fname_out', 'common_coi', 'nuclides_pattern',
           'unit_mapping', 'nuclide_mapping', 'kw', 'find_location_section', 'fix_sampling_time',
           'get_coastal_water_df', 'get_locs_coastal_water', 'get_clos1F_df', 'get_locs_clos1F', 'get_locs_orbs',
           'concat_locs', 'align_dfs', 'concat_dfs', 'georef_data', 'load_data', 'FixMissingValuesCB',
           'RemoveJapanaseCharCB', 'FixRangeValueStringCB', 'SelectColsOfInterestCB', 'WideToLongCB', 'RemapUnitNameCB',
           'RemapNuclideNameCB', 'RemapDLCB', 'ParseTimeCB', 'get_attrs', 'encode']

# %% ../../nbs/handlers/tepco.ipynb 3
import warnings
warnings.filterwarnings('ignore')

# %% ../../nbs/handlers/tepco.ipynb 4
import pandas as pd
import re
import numpy as np
import fastcore.all as fc
from tqdm import tqdm
from collections import defaultdict

from marisco.callbacks import (
    Callback, 
    Transformer,
    EncodeTimeCB, 
    SanitizeLonLatCB,
    EncodeTimeCB, 
    )

from ..configs import cfg
from ..encoders import NetCDFEncoder

from marisco.metadata import (
    GlobAttrsFeeder, 
    BboxCB,
    TimeRangeCB,
    ZoteroCB, 
    KeyValuePairCB    
    )

from ..netcdf2csv import decode

# %% ../../nbs/handlers/tepco.ipynb 6
# Remote inputs read by the loaders below.
# CSV holding the coastal water measurements followed by a station locations section.
fname_coastal_water = 'https://radioactivity.nra.go.jp/cont/en/results/sea/coastal_water.csv'
# Excel workbook `close1F_water.xlsx`; the loaders iterate over every sheet.
fname_clos1F = 'https://radioactivity.nra.go.jp/cont/en/results/sea/close1F_water.xlsx'
# Station points CSV hosted in the `iaea.orbs` repository.
fname_iaea_orbs = 'https://raw.githubusercontent.com/RML-IAEA/iaea.orbs/refs/heads/main/src/iaea/orbs/stations/station_points.csv'

# Path of the NetCDF output file.
fname_out = '../../_data/output/tepco.nc'

# %% ../../nbs/handlers/tepco.ipynb 10
def find_location_section(df, 
                          col_idx=0,
                          pattern='Sampling point number'
                          ):
    """Return the index label of the first row whose `col_idx`-th column equals `pattern`.

    Returns -1 when no row matches.
    """
    is_marker = df.iloc[:, col_idx] == pattern
    matching_labels = df.index[is_marker.to_numpy()]
    if len(matching_labels) == 0:
        return -1
    return matching_labels[0]

# %% ../../nbs/handlers/tepco.ipynb 13
def fix_sampling_time(x):
    """Normalize a sampling time to a zero-padded `HH:MM:00` string.

    Missing values (`NaN`/`None`) become `'00:00:00'`. Otherwise only the
    first two colon-separated fields (hour, minute) are kept and the seconds
    are forced to `00`. A value without any `:` raises `ValueError`.
    """
    if pd.isna(x):
        return '00:00:00'
    # `minute` (not `min`) so the builtin is not shadowed; strip stray blanks
    # so that e.g. ' 9:05' is padded like '9:05'.
    hour, minute = (part.strip() for part in str(x).split(':')[:2])
    return f"{hour.zfill(2)}:{minute.zfill(2)}:00"

# %% ../../nbs/handlers/tepco.ipynb 14
def get_coastal_water_df(fname_coastal_water):
    """Get the measurements dataframe from the `coastal_water.csv` file.

    The file is read twice: once in full to locate the row where the
    locations section starts, then again limited to the measurement rows
    above it. `Sampling date` and `Sampling time` are merged into a single
    `TIME` string column (`date time`, with `/` as date separator) and then
    dropped.
    """
    # First pass: find where the locations section begins.
    locs_idx = find_location_section(pd.read_csv(fname_coastal_water, 
                                      skiprows=0, low_memory=False))
    
    # Second pass: skip the first line and stop before the locations section.
    df = pd.read_csv(fname_coastal_water, skiprows=1, 
                     nrows=locs_idx - 1,
                     low_memory=False)
    df.dropna(subset=['Sampling point number'], inplace=True)
    df['Sampling time'] = df['Sampling time'].map(fix_sampling_time)
    
    # `Series.replace('-', '/')` only swaps cells that are exactly '-';
    # `.str.replace` performs the intended substring substitution.
    sampling_date = df['Sampling date'].str.replace('-', '/', regex=False)
    df['TIME'] = sampling_date + ' ' + df['Sampling time']
    
    df = df.drop(columns=['Sampling date', 'Sampling time'])
    return df

# %% ../../nbs/handlers/tepco.ipynb 17
def get_locs_coastal_water(fname_coastal_water):
    """Get the station locations dataframe from the `coastal_water.csv` file.

    Keeps the first three columns of the locations section, renames them to
    `station`, `LON`, `LAT`, drops rows without `LAT` and tags the rows with
    their origin in an `org` column.
    """
    raw = pd.read_csv(fname_coastal_water, skiprows=0, low_memory=False)
    n_skip = find_location_section(raw) + 1

    locs = pd.read_csv(fname_coastal_water, skiprows=n_skip, low_memory=False)
    locs = locs.iloc[:, :3]
    locs.columns = ['station', 'LON', 'LAT']
    return locs.dropna(subset=['LAT']).assign(org='coastal_seawater.csv')

# %% ../../nbs/handlers/tepco.ipynb 21
def get_clos1F_df(fname_clos1F):
    """Get measurements dataframe from close1F_water.xlsx file and parse datetime.

    Every sheet of the workbook is read twice: once to locate the row where
    the locations section starts, then again limited to the measurement rows
    above it. Rows lacking a sampling point number or a sampling date are
    dropped. `Sampling date` and `Sampling time` are merged into a single
    `TIME` string column (`date time`, with `/` as date separator) and then
    dropped.
    """
    excel_file = pd.ExcelFile(fname_clos1F)
    sheets = []
    
    for sheet_name in tqdm(excel_file.sheet_names):
        locs_idx = find_location_section(pd.read_excel(excel_file, 
                                                       sheet_name=sheet_name,
                                                       skiprows=1))
        df = pd.read_excel(excel_file, 
                   sheet_name=sheet_name, 
                   skiprows=1,
                   nrows=locs_idx-1)
        
        # Missing dates must be dropped *before* the string cast below: once
        # cast they are the literal text 'nan'/'NaT' and `dropna` no longer
        # recognises them as missing.
        df = df.dropna(subset=['Sampling point number', 'Sampling date'])
        # Keep the date part only (text before the first space), using '/'.
        df['Sampling date'] = df['Sampling date']\
            .astype(str)\
            .apply(lambda x: x.split(' ')[0]\
            .replace('-', '/'))
            
        sheets.append(df)
    
    df = pd.concat(sheets, ignore_index=True)
    df['TIME'] = df['Sampling date'] + ' ' + df['Sampling time'].astype(str)
    df = df.drop(columns=['Sampling date', 'Sampling time'])
    return df

# %% ../../nbs/handlers/tepco.ipynb 24
def get_locs_clos1F(fname_clos1F):
    """Get the station locations dataframe gathered from every sheet of close1F_water.xlsx.

    The first three columns of the concatenated locations sections are kept,
    rows without a north latitude are dropped, the columns are renamed to
    `station`, `LON`, `LAT` and an `org` column records the origin.
    """
    excel_file = pd.ExcelFile(fname_clos1F)
    per_sheet = []

    for sheet_name in tqdm(excel_file.sheet_names):
        raw = pd.read_excel(excel_file, sheet_name=sheet_name, skiprows=1)
        n_skip = find_location_section(raw) + 2
        per_sheet.append(pd.read_excel(excel_file,
                                       sheet_name=sheet_name,
                                       skiprows=n_skip))

    locs = pd.concat(per_sheet, ignore_index=True).iloc[:, :3]
    locs = locs.dropna(subset=['Sampling coordinate North latitude (Decimal)'])
    # NOTE(review): columns are renamed by position; confirm that the 2nd and
    # 3rd source columns really are longitude then latitude.
    locs.columns = ['station', 'LON', 'LAT']
    locs['org'] = 'close1F.csv'
    return locs

# %% ../../nbs/handlers/tepco.ipynb 27
def get_locs_orbs(fname_iaea_orbs):
    "Read the IAEA ORBS station points CSV and rename its four columns to `org`, `station`, `LON`, `LAT`."
    stations = pd.read_csv(fname_iaea_orbs)
    return stations.set_axis(['org', 'station', 'LON', 'LAT'], axis=1)

# %% ../../nbs/handlers/tepco.ipynb 29
def concat_locs(dfs):
    """Concatenate location dataframes and keep a single row per `station`.

    When a station appears in several sources, rows whose `org` is neither
    'coastal_seawater.csv' nor 'close1F.csv' (i.e. the ORBS stations) win,
    then 'coastal_seawater.csv', then 'close1F.csv'. The result is sorted by
    `station`.
    """
    priority = {'coastal_seawater.csv': 1, 'close1F.csv': 2}
    merged = pd.concat(dfs)
    # Temporary sort key: any other origin gets rank 0 and therefore comes first.
    merged['org_grp'] = merged['org'].apply(lambda name: priority.get(name, 0))
    ranked = merged.sort_values('org_grp', ascending=True)
    unique = ranked.drop_duplicates(subset='station', keep='first')
    unique = unique.drop(columns=['org_grp'])
    return unique.sort_values('station', ascending=True)

# %% ../../nbs/handlers/tepco.ipynb 31
def align_dfs(df_from, df_to):
    """Align columns structure of df_from to df_to.

    Returns a new dataframe with exactly the columns of `df_to`, in the same
    order. Columns also present in `df_from` carry its values; the others are
    filled with `NaN`. The result has as many rows as `df_from` and a fresh
    0..n-1 index.
    """
    # `np.nan` rather than the `np.NAN` alias, which NumPy 2.0 removed.
    data = {c: df_from[c].values if c in df_from.columns else np.nan
            for c in df_to.columns}
    # An explicit index keeps the row count even when no column is shared
    # (a dict made only of scalars cannot be turned into a DataFrame otherwise).
    return pd.DataFrame(data, index=pd.RangeIndex(len(df_from)))

# %% ../../nbs/handlers/tepco.ipynb 33
def concat_dfs(df_coastal_water, df_clos1F):
    "Stack the coastal water and close1F measurements after aligning the latter to the columns of the former."
    aligned_clos1F = align_dfs(df_clos1F, df_coastal_water)
    return pd.concat([df_coastal_water, aligned_clos1F])

# %% ../../nbs/handlers/tepco.ipynb 35
def georef_data(df_meas, df_locs):
    """Attach station coordinates to the measurements.

    Inner-joins `df_meas` (key `Sampling point number`) with `df_locs`
    (key `station`); measurements without a matching station are dropped.
    """
    assert "Sampling point number" in df_meas.columns
    assert "station" in df_locs.columns
    return df_meas.merge(df_locs, how="inner",
                         left_on='Sampling point number', right_on='station')

# %% ../../nbs/handlers/tepco.ipynb 37
def load_data(fname_coastal_water, fname_clos1F, fname_iaea_orbs):
    """Load, align and georeference TEPCO data.

    Returns a dict with a single `SEAWATER` entry: the merged measurements
    joined with the de-duplicated station locations.
    """
    locs_sources = [
        get_locs_coastal_water(fname_coastal_water),
        get_locs_clos1F(fname_clos1F),
        get_locs_orbs(fname_iaea_orbs),
    ]
    df_locs = concat_locs(locs_sources)

    coastal = get_coastal_water_df(fname_coastal_water)
    clos1F = get_clos1F_df(fname_clos1F)
    df_meas = concat_dfs(coastal, clos1F)
    df_meas = df_meas.dropna(subset=['Sampling point number'])

    return {'SEAWATER': georef_data(df_meas, df_locs)}

# %% ../../nbs/handlers/tepco.ipynb 44
class FixMissingValuesCB(Callback):
    "Replace every cell equal to `ND` (not detected) by `NaN`, in place, in all dataframes - to be confirmed "
    def __call__(self, tfm): 
        for df in tfm.dfs.values():
            df[df == 'ND'] = np.nan

# %% ../../nbs/handlers/tepco.ipynb 48
class RemoveJapanaseCharCB(Callback):
    "Strip the 約 (about) character from the object-typed `(Bq/L)` columns"
    def _transform_if_about(self, value, about_char='約'):
        "Return `value` without `about_char`; missing values and values not containing it pass through unchanged."
        if pd.isna(value):
            return value
        if about_char in str(value):
            return value.replace(about_char, '')
        return value
    
    def __call__(self, tfm): 
        for df in tfm.dfs.values():
            # Only text columns can carry the character.
            cols_rdn = [c for c in df.columns
                        if ('(Bq/L)' in c) and (df[c].dtype == 'object')]
            df[cols_rdn] = df[cols_rdn].map(self._transform_if_about)

# %% ../../nbs/handlers/tepco.ipynb 52
class FixRangeValueStringCB(Callback):
    "Replace range values (e.g '4.0E+00<&<8.0E+00' or '1.0～2.7') by their mean and cast the columns to float"
    
    def _extract_and_calculate_mean(self, s):
        "Return the mean of all numbers found in `s`, or `s` itself when none is found."
        float_strings = re.findall(r"[+-]?\d+\.?\d*E?[+-]?\d*", s)
        if not float_strings:
            return s
        return np.array(float_strings, dtype=float).mean()
    
    def _transform_if_range(self, value):
        "Average range strings; return any other non-missing value as a string."
        if pd.isna(value):
            return value
        text = str(value)
        # Two range notations are recognised.
        is_range = any(marker in text for marker in ('<&<', '～'))
        return self._extract_and_calculate_mean(text) if is_range else text

    def __call__(self, tfm): 
        for df in tfm.dfs.values():
            cols_rdn = [c for c in df.columns
                        if ('(Bq/L)' in c) and (df[c].dtype == 'object')]
            df[cols_rdn] = df[cols_rdn].map(self._transform_if_range).astype(float)

# %% ../../nbs/handlers/tepco.ipynb 56
# Non-measurement columns kept by `SelectColsOfInterestCB`.
common_coi = ['org', 'LON', 'LAT', 'TIME', 'station']
# Substring identifying the measurement columns to keep.
nuclides_pattern = '(Bq/L)'

# %% ../../nbs/handlers/tepco.ipynb 57
class SelectColsOfInterestCB(Callback):
    """Select columns of interest.

    Keeps the `common_coi` columns followed by every column of the
    `SEAWATER` dataframe whose name contains `nuclides_pattern`.
    """
    def __init__(self, common_coi, nuclides_pattern): fc.store_attr()
    def __call__(self, tfm):
        df = tfm.dfs['SEAWATER']
        # Use the pattern given to the constructor, not the module-level
        # variable of the same name.
        nuc_of_interest = [c for c in df.columns if self.nuclides_pattern in c]
        tfm.dfs['SEAWATER'] = df[[*self.common_coi, *nuc_of_interest]]

# %% ../../nbs/handlers/tepco.ipynb 60
class WideToLongCB(Callback):
    """
    Parse TEPCO measurement columns to extract nuclide name, measurement value, 
    detection limit and uncertainty.

    The wide `SEAWATER` dataframe (one column per nuclide and quantity) is
    melted, each former column name is decomposed into `NUCLIDE`, `UNIT` and a
    quantity type, and the result is pivoted so that each row holds the
    quantities of one nuclide at one station and time. An `ID` column
    (0..n-1) is added.
    """
    def __init__(self): fc.store_attr()
    
    
    def _melt(self, df):
        "Melt dataframe to long format."
        # NOTE(review): every column not listed in `id_vars` (e.g. `org`, when
        # present) is melted into `variable`/`value` as well - confirm intended.
        return df.melt(id_vars=['LON', 'LAT', 'TIME', 'station'])
        
    def _extract_nuclide(self, text):
        "Return the first word of `text`, or the first two when the second is 'alpha'/'beta'."
        words = text.split(' ')
        # Handle special cases for alpha/beta
        if len(words) >= 2 and words[1].lower() in ['alpha', 'beta']:
            return f"{words[0]} {words[1]}"
        return words[0]
    
    def _nuclide_name(self, df):
        "Extract nuclide name from the melted column names."
        df['NUCLIDE'] = df['variable'].map(self._extract_nuclide)
        return df
    
    def _type_indicator(self, df):
        "Flag rows as concentration, detection limit or uncertainty from the column name."
        df['is_concentration'] = df['variable'].str.contains('radioactivity concentration')
        df['is_dl'] = df['variable'].str.contains('detection limit')
        df['is_unc'] = df['variable'].str.contains('statistical error')
        return df
    
    def _unit(self, df):
        "Extract unit (first parenthesised text) from the melted column names."
        df['UNIT'] = df['variable'].str.extract(r'\((.*?)\)')
        return df
    
    def _type_column(self, df):
        "Collapse the three indicator columns into a single `type` column."
        conditions = [
            df['is_concentration'],
            df['is_dl'],
            df['is_unc']
        ]
        choices = ['VALUE', 'DL', 'UNC']
        # `np.select` defaults to the integer 0, which recent NumPy versions
        # refuse to promote with string choices (TypeError). Passing the
        # string '0' keeps the label older versions produced by coercion.
        df['type'] = np.select(conditions, choices, default='0')
        df = df.drop(['is_concentration', 'is_dl', 'is_unc'], axis=1)
        return df
    
    def __call__(self, tfm):
        df = self._melt(tfm.dfs['SEAWATER'])
        df = self._nuclide_name(df)
        df = self._type_indicator(df)
        df = self._unit(df)
        df = self._type_column(df)
        # One row per (location, time, station, nuclide, unit); one column per type.
        df = pd.pivot_table(
            df,
            values='value',
            index=['LON', 'LAT', 'TIME', 'station', 'NUCLIDE', 'UNIT'],
            columns='type',
            aggfunc='first'
        ).reset_index()
        # reset the index and rename it ID
        df.reset_index(inplace=True)
        df.rename(columns={'index': 'ID'}, inplace=True)
        tfm.dfs['SEAWATER'] = df
        

# %% ../../nbs/handlers/tepco.ipynb 63
# Unit label found in the `UNIT` column -> MARIS unit id (used by `RemapUnitNameCB`).
unit_mapping = {'Bq/L': 3}

# %% ../../nbs/handlers/tepco.ipynb 64
class RemapUnitNameCB(Callback):
    """
    Replace the `UNIT` labels of the `SEAWATER` dataframe by their MARIS id,
    as given by `unit_mapping`.
    """
    def __init__(self, unit_mapping): fc.store_attr()
    def __call__(self, tfm):
        df = tfm.dfs['SEAWATER']
        df['UNIT'] = df['UNIT'].map(self.unit_mapping)


# %% ../../nbs/handlers/tepco.ipynb 67
# Nuclide label found in the `NUCLIDE` column -> MARIS nuclide id
# (used by `RemapNuclideNameCB`).
nuclide_mapping = {
    '131I': 29,
    '134Cs': 31,
    '137Cs': 33,
    '125Sb': 24,
    'Total beta': 103,
    '238Pu': 67,
    '239Pu+240Pu': 77,
    '3H': 1,
    '89Sr': 11,
    '90Sr': 12,
    'Total alpha': 104,
    '132I': 100,
    '136Cs': 102,
    '58Co': 8,
    '105Ru': 97,
    '106Ru': 17,
    '140La': 35,
    '140Ba': 34,
    '132Te': 99,
    '60Co': 9,
    '144Ce': 37,
    '54Mn': 6
}

# %% ../../nbs/handlers/tepco.ipynb 68
class RemapNuclideNameCB(Callback):
    """
    Replace the `NUCLIDE` labels of the `SEAWATER` dataframe by their MARIS id,
    as given by `nuclide_mapping`.
    """
    def __init__(self, nuclide_mapping): fc.store_attr()
    def __call__(self, tfm):
        df = tfm.dfs['SEAWATER']
        df['NUCLIDE'] = df['NUCLIDE'].map(self.nuclide_mapping)

# %% ../../nbs/handlers/tepco.ipynb 72
class RemapDLCB(Callback):
    """
    Replace the `DL` column of the `SEAWATER` dataframe by a flag:
    1 where no detection limit is present (`NaN`), 2 otherwise.
    """
    def __init__(self): fc.store_attr()
    def dl_mapping(self, value):
        "Return 2 when `value` is present, 1 when it is missing."
        return 2 if pd.notna(value) else 1
    def __call__(self, tfm): 
        # NOTE(review): the numeric detection limit is overwritten by the flag - confirm intended.
        df = tfm.dfs['SEAWATER']
        df['DL'] = df['DL'].map(self.dl_mapping)

# %% ../../nbs/handlers/tepco.ipynb 75
class ParseTimeCB(Callback):
    "Convert the `time_name` column of the `SEAWATER` dataframe to datetimes; values not matching `%Y/%m/%d %H:%M:%S` become `NaT`."
    def __init__(self,
                 time_name='TIME'):
        fc.store_attr()
        
    def __call__(self, tfm):
        df = tfm.dfs['SEAWATER']
        raw_times = df[self.time_name]
        df[self.time_name] = pd.to_datetime(raw_times,
                                            format='%Y/%m/%d %H:%M:%S',
                                            errors='coerce')

# %% ../../nbs/handlers/tepco.ipynb 82
# Keywords joined with ', ' into the `keywords` global attribute by `get_attrs`.
# NOTE(review): the first hierarchy lacks a space before '>' ('Ocean Chemistry> Radionuclides')
# and two entries already contain two comma-separated hierarchies - confirm intended.
kw = ['oceanography', 'Earth Science > Oceans > Ocean Chemistry> Radionuclides',
      'Earth Science > Human Dimensions > Environmental Impacts > Nuclear Radiation Exposure',
      'Earth Science > Oceans > Ocean Chemistry > Ocean Tracers, Earth Science > Oceans > Marine Sediments',
      'Earth Science > Oceans > Ocean Chemistry, Earth Science > Oceans > Sea Ice > Isotopes',
      'Earth Science > Oceans > Water Quality > Ocean Contaminants',
      'Earth Science > Biological Classification > Animals/Vertebrates > Fish',
      'Earth Science > Biosphere > Ecosystems > Marine Ecosystems',
      'Earth Science > Biological Classification > Animals/Invertebrates > Mollusks',
      'Earth Science > Biological Classification > Animals/Invertebrates > Arthropods > Crustaceans',
      'Earth Science > Biological Classification > Plants > Macroalgae (Seaweeds)']

# %% ../../nbs/handlers/tepco.ipynb 83
def get_attrs(tfm, zotero_key, kw=kw):
    """Build the global attributes by running the metadata callbacks over `tfm.dfs`.

    `kw` is joined with ', ' into the `keywords` attribute and `tfm.logs`
    into `publisher_postprocess_logs`.
    """
    callbacks = [
        BboxCB(),
        TimeRangeCB(),
        ZoteroCB(zotero_key, cfg=cfg()),
        KeyValuePairCB('keywords', ', '.join(kw)),
        KeyValuePairCB('publisher_postprocess_logs', ', '.join(tfm.logs)),
    ]
    feeder = GlobAttrsFeeder(tfm.dfs, cbs=callbacks)
    return feeder()

# %% ../../nbs/handlers/tepco.ipynb 85
def encode(
    fname_out: str, # Path of the NetCDF output (passed to the encoder as `dest_fname`)
    **kwargs # Additional keyword arguments; only `verbose` (default False) is read
    ):
    "Load the TEPCO sources, run the transformation callbacks and encode the result to NetCDF."
    dfs = load_data(fname_coastal_water, fname_clos1F, fname_iaea_orbs)

    # Callbacks run in this order.
    pipeline = [
        FixMissingValuesCB(),
        RemoveJapanaseCharCB(),
        FixRangeValueStringCB(),
        SelectColsOfInterestCB(common_coi, nuclides_pattern),
        WideToLongCB(),
        RemapUnitNameCB(unit_mapping),
        RemapNuclideNameCB(nuclide_mapping),
        RemapDLCB(),
        ParseTimeCB(),
        EncodeTimeCB(),
        SanitizeLonLatCB(),
    ]
    tfm = Transformer(dfs, cbs=pipeline)
    tfm()

    global_attrs = get_attrs(tfm, zotero_key='JEV6HP5A', kw=kw)
    verbose = kwargs.get('verbose', False)
    encoder = NetCDFEncoder(tfm.dfs,
                            dest_fname=fname_out,
                            global_attrs=global_attrs,
                            verbose=verbose)
    encoder.encode()
