# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/00_base.ipynb.

# %% auto 0
__all__ = ['BaseConfig', 'BaseModule', 'PredFnMixedin', 'TrainableMixedin']

# %% ../nbs/00_base.ipynb 2
from pydantic import BaseModel as BasePydanticModel
import json
from pathlib import Path


# %% ../nbs/00_base.ipynb 4
class BaseConfig(BasePydanticModel):
    """Base class for all config classes.

    Adds JSON persistence on top of the pydantic model: `save` writes the
    model's fields to a `.json` file and `load_from_json` builds an instance
    from such a file.
    """

    def save(self, path):
        """Write this config to `path` as JSON indented by 4 spaces.

        Args:
            path: Destination file. Its string form must end with `.json`.
                Missing parent directories are created.

        Raises:
            ValueError: If `path` does not end with `.json`.
        """
        p = Path(path)
        if not str(p).endswith('.json'):
            raise ValueError(f"Path must end with `.json`, but got: {p}")
        # `exist_ok=True` makes directory creation safe when the directory
        # already exists or is created concurrently by another process.
        p.parent.mkdir(parents=True, exist_ok=True)
        # Prefer `model_dump` (pydantic v2) when the model provides it and
        # fall back to `dict` (pydantic v1) otherwise.
        to_dict = getattr(self, 'model_dump', None) or self.dict
        # Explicit encoding so the file does not depend on the platform locale.
        with open(p, 'w', encoding='utf-8') as f:
            json.dump(to_dict(), f, indent=4)

    @classmethod
    def load_from_json(cls, path):
        """Build a config from the JSON file at `path`.

        The top-level JSON object is passed to the class constructor as
        keyword arguments.

        Args:
            path: Location of the JSON file to read.

        Returns:
            An instance of `cls`.

        Raises:
            FileNotFoundError: If `path` does not exist.
        """
        p = Path(path)
        if not p.exists():
            raise FileNotFoundError(f"File not found: {p}")
        # Explicit encoding so files containing non-ASCII text are read the
        # same way on every platform.
        with open(p, 'r', encoding='utf-8') as f:
            return cls(**json.load(f))

# %% ../nbs/00_base.ipynb 6
class BaseModule:
    """Common base for modules: holds a config and an optional display name.

    `save` and `load_from_path` are placeholders that raise
    `NotImplementedError` until a subclass overrides them.
    """

    def __init__(self, config, *, name=None):
        """Store `config` and the optional keyword-only `name`."""
        self._name = name
        self.config = config

    @property
    def name(self):
        """Return the explicit name if it is truthy, else the class name."""
        explicit = self._name
        if explicit:
            return explicit
        return self.__class__.__name__

    def save(self, path):
        """Persist the module to `path`; not implemented on the base class."""
        raise NotImplementedError()

    @classmethod
    def load_from_path(cls, path):
        """Load a module from `path`; not implemented on the base class."""
        raise NotImplementedError()

# %% ../nbs/00_base.ipynb 8
class PredFnMixedin:
    """Mixin declaring a `pred_fn` method for modules that make predictions."""

    # Names of the members this mixin declares.
    # NOTE(review): how `__ALL__` is consumed is not visible here — confirm.
    __ALL__ = ['pred_fn']

    def pred_fn(self, x):
        """Return the model's prediction/probability for `x`.

        The mixin itself only raises `NotImplementedError`.
        """
        raise NotImplementedError()

# %% ../nbs/00_base.ipynb 9
class TrainableMixedin:
    """Mixin declaring the training interface: `is_trained` and `train`."""

    # Names of the members this mixin declares.
    # NOTE(review): how `__ALL__` is consumed is not visible here — confirm.
    __ALL__ = ['is_trained', 'train']

    def train(self, data, **kwargs):
        """Train the module on `data`.

        The mixin itself only raises `NotImplementedError`.
        """
        raise NotImplementedError()

    @property
    def is_trained(self) -> bool:
        """Whether the module has been trained.

        The mixin itself only raises `NotImplementedError`.
        """
        raise NotImplementedError()
