# Define Base Class for Models
from abc import ABC, abstractmethod
from datetime import datetime

from pyarrow import RecordBatch as rb
from qablet_contracts.timetable import TS_EVENT_SCHEMA, py_to_ts


def convert_to_arrow(timetable, dataset):
    """Normalize python-native inputs before they are handed to a model.

    Two fields are inspected:

    * ``dataset["PRICING_TS"]`` - when it is a ``datetime`` it is replaced
      by ``py_to_ts(...).value``.
    * ``timetable["events"]`` - when it is a ``list`` it is replaced by a
      pyarrow ``RecordBatch`` built with ``TS_EVENT_SCHEMA``.

    The caller's dicts are never modified: a shallow copy is made of any
    dict that needs a replacement, and a dict that needs none is returned
    as the very same object.

    Parameters:
        timetable (dict): must contain the key ``"events"``.
        dataset (dict): must contain the key ``"PRICING_TS"``.

    Returns:
        tuple: ``(timetable, dataset)`` ready for use by a model.
    """
    pricing_ts = dataset["PRICING_TS"]
    if isinstance(pricing_ts, datetime):
        # Work on a shallow copy so the caller's dataset stays untouched.
        converted_dataset = dataset.copy()
        converted_dataset["PRICING_TS"] = py_to_ts(pricing_ts).value
        dataset = converted_dataset

    events = timetable["events"]
    if isinstance(events, list):
        # Same copy-on-convert treatment for the timetable.
        converted_timetable = timetable.copy()
        converted_timetable["events"] = rb.from_pylist(
            events, schema=TS_EVENT_SCHEMA
        )
        timetable = converted_timetable

    return timetable, dataset


# Define Base Class for State Object for all Models
class ModelStateBase(ABC):
    """Class to maintain the state during a model execution.

    Each state object carries its own ``stats`` dict, filled through
    :meth:`set_stat` while a model runs.
    """

    # Class-level fallback only, so that ``.stats`` can still be read on an
    # instance whose subclass ``__init__`` skipped ``super().__init__()``.
    # It is never written to: a dict mutated here would be shared by every
    # instance of every subclass and leak stats between pricing runs.
    stats: dict = {}

    def __init__(self, timetable, dataset):
        """Create the state with an empty, per-instance ``stats`` dict.

        Parameters:
            timetable (dict): timetable for the contract (unused here,
                accepted so subclasses share one constructor signature).
            dataset (dict): dataset for the model (unused here).
        """
        self.stats = {}

    def set_stat(self, key: str, val):
        """Record ``val`` under ``key`` in this instance's ``stats``.

        Parameters:
            key (str): name of the statistic.
            val: value to store; replaces any earlier value for ``key``.
        """
        # Guard for subclasses that never ran the base __init__: give the
        # instance its own dict rather than writing to the class fallback.
        if "stats" not in vars(self):
            self.stats = {}
        self.stats[key] = val


class Model(ABC):
    """Base class for all models.

    A concrete model supplies two things: the class used to hold state
    during a run (:meth:`state_class`) and the callable that performs the
    pricing (:meth:`price_method`). :meth:`price` wires them together.
    """

    @abstractmethod
    def state_class(self):
        """Return the class whose instances hold state for this model.

        It is called as ``state_class()(timetable, dataset)``.
        """
        ...

    @abstractmethod
    def price_method(self):
        """Return the callable that calculates the price.

        It is called as
        ``price_method()(events, model_state, dataset, expressions)``.
        """
        ...

    def price(self, timetable, dataset):
        """Calculate price of contract.

        Parameters:
            timetable (dict): timetable for the contract.
            dataset (dict): dataset for the model.

        Returns:
            price (float): price of contract
            stats (dict): stats such as standard error

        """
        # Inputs given in python-native form are converted first.
        tt, ds = convert_to_arrow(timetable, dataset)

        # Build the state object for this run, then fetch the pricer.
        state = self.state_class()(tt, ds)
        pricer = self.price_method()

        events = tt["events"]
        # "expressions" is optional in a timetable; default to none.
        expressions = tt.get("expressions", {})
        value = pricer(events, state, ds, expressions)

        return value, state.stats
