# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/05_methods.base.ipynb.

# %% ../../nbs/05_methods.base.ipynb 3
from __future__ import annotations
from ..import_essentials import *
from ..data import TabularDataModule
from ..trainer import TrainingConfigs
from copy import deepcopy

# %% auto 0
# Public names exported by `from <this module> import *` (list is autogenerated).
__all__ = ['BaseCFModule', 'BaseParametricCFModule', 'BasePredFnCFModule']

# %% ../../nbs/05_methods.base.ipynb 4
class BaseCFModule(ABC):
    """Base CF Explanation Module.

    Subclasses implement `name` and `generate_cfs`. A `TabularDataModule`
    can be attached with `hook_data_module`; it is then exposed through
    the `data_module` property.
    """
    # Declared but not assigned here: the attribute only exists after
    # `hook_data_module` has been called.
    _data_module: TabularDataModule

    @property
    @abstractmethod
    def name(self):
        """Name of the CF Explanation Module."""
        raise NotImplementedError
    
    @property
    def data_module(self) -> TabularDataModule:
        """Bound `TabularDataModule`.

        Raises `AttributeError` (with a hint to call `hook_data_module`)
        if no data module has been bound yet.
        """
        try:
            return self._data_module
        except AttributeError:
            # Same exception type as before (so `hasattr`/`getattr` keep
            # working), but with an actionable message instead of the
            # internal `_data_module` attribute name.
            raise AttributeError(
                f"`{type(self).__name__}` has no data module bound; "
                "call `hook_data_module(data_module)` first."
            ) from None

    @abstractmethod
    def generate_cfs(
        self,
        X: jnp.ndarray, # Input to be explained
        pred_fn: Callable | None = None # Predictive function
    ) -> jnp.ndarray: # Generated counterfactuals
        """Abstract method to generate counterfactuals for `X`."""
        pass

    def hook_data_module(self, data_module: TabularDataModule):
        """Bind `data_module` to `self._data_module`."""
        self._data_module = data_module


# %% ../../nbs/05_methods.base.ipynb 9
class BaseParametricCFModule(ABC):
    """Base class of CF Modules that expose a `train` method.

    Subclasses implement `train` and `_is_module_trained`.
    """
    @abstractmethod
    def train(
        self, 
        datamodule: TabularDataModule, # data module
        t_configs: TrainingConfigs | dict | None = None, # training configs; see docs in `TrainingConfigs`
        pred_fn: Callable | None = None # predictive function
    ): 
        """Abstract training entry point; implemented by subclasses."""
        pass

    @abstractmethod
    def _is_module_trained(self) -> bool:
        """Return whether the module has been trained; implemented by subclasses."""
        pass

# %% ../../nbs/05_methods.base.ipynb 11
class BasePredFnCFModule(ABC):
    """Base class of CF Module with a predictive module."""
    @abstractmethod
    def pred_fn(
        self, 
        X: jnp.ndarray  # input `X`
    ) -> jnp.ndarray:   # prediction
        """Abstract predictive function mapping input `X` to a prediction.

        Annotated with `jnp.ndarray` (consistent with `generate_cfs`) rather
        than the removed `jnp.DeviceArray` alias.
        """
        raise NotImplementedError
