"""Algorithm to evaluate a model on remote data."""
from __future__ import annotations

from typing import Any, Dict, List, Type, cast

from marshmallow import Schema as MarshmallowSchema
from marshmallow import fields, post_load
import numpy as np

from bitfount.federated.algorithms.base import _BaseAlgorithmFactory
from bitfount.federated.algorithms.model_algorithms.base import (
    _BaseModelAlgorithmFactory,
    _BaseModellerModelAlgorithm,
    _BaseWorkerModelAlgorithm,
)
from bitfount.federated.logging import _get_federated_logger
from bitfount.hub.api import BitfountHub
from bitfount.metrics import MetricCollection

logger = _get_federated_logger(__name__)


class _ModellerSide(_BaseModellerModelAlgorithm):
    """Modeller-side half of the ModelEvaluation algorithm.

    No computation happens here: the metrics it receives are handed straight
    back to the caller.
    """

    def run(self, results: List[Dict[str, float]]) -> List[Dict[str, float]]:
        """Pass the received metrics through unchanged.

        Args:
            results: A list of metric-name -> value dictionaries (presumably
                one per worker; confirm against the protocol that calls this).

        Returns:
            The very same list object that was passed in.
        """
        passthrough = results
        return passthrough


class _WorkerSide(_BaseWorkerModelAlgorithm):
    """Worker-side half of the ModelEvaluation algorithm."""

    def run(self) -> Dict[str, float]:
        """Evaluate the model and compute metrics from its output.

        Returns:
            The metric-name -> value mapping produced by
            `MetricCollection.compute` for this model's predictions and targets.
        """
        predictions, targets = self.model.evaluate()
        metric_collection = MetricCollection.create_from_model(self.model)
        # TODO: [BIT-1604] Remove these cast statements once they become superfluous.
        # `cast` is a runtime no-op; it only narrows the static types for mypy.
        return metric_collection.compute(
            cast(np.ndarray, targets), cast(np.ndarray, predictions)
        )


class ModelEvaluation(_BaseModelAlgorithmFactory):
    """Algorithm for evaluating a model and returning metrics.

    Args:
        model: The model to evaluate on remote data.

    Attributes:
        name: The name of the algorithm.
        model: The model to evaluate on remote data.

    :::note

    The metrics cannot currently be specified by the user.

    :::
    """

    def modeller(self, **kwargs: Any) -> _ModellerSide:
        """Returns the modeller side of the ModelEvaluation algorithm.

        Args:
            **kwargs: Additional keyword arguments forwarded unchanged to the
                `_ModellerSide` constructor.

        Returns:
            A `_ModellerSide` built around the model resolved from this
            factory's model reference.
        """
        model = self._get_model_from_reference()
        return _ModellerSide(model=model, **kwargs)

    def worker(self, hub: BitfountHub, **kwargs: Any) -> _WorkerSide:
        """Returns the worker side of the ModelEvaluation algorithm.

        Args:
            hub: `BitfountHub` object to use for communication with the hub.
            **kwargs: Additional keyword arguments forwarded unchanged to the
                `_WorkerSide` constructor.

        Returns:
            A `_WorkerSide` built around the model resolved from this
            factory's model reference.
        """
        # Unlike `modeller`, the hub is passed to the model lookup here.
        model = self._get_model_from_reference(hub=hub)
        return _WorkerSide(model=model, **kwargs)

    @staticmethod
    def get_schema(
        model_schema: Type[MarshmallowSchema], **kwargs: Any
    ) -> Type[MarshmallowSchema]:
        """Returns the schema for ModelEvaluation.

        Args:
            model_schema: The schema for the underlying model.
            **kwargs: Accepted but not used by this implementation.

        Returns:
            A marshmallow schema class whose `model` field is nested using
            `model_schema` and whose `post_load` hook rebuilds a
            `ModelEvaluation` from the loaded data.
        """

        class Schema(_BaseAlgorithmFactory._Schema):
            """Marshmallow schema for a `ModelEvaluation` factory."""

            model = fields.Nested(model_schema)

            @post_load
            def recreate_factory(self, data: dict, **_kwargs: Any) -> ModelEvaluation:
                """Construct a `ModelEvaluation` from the deserialized fields."""
                return ModelEvaluation(**data)

        return Schema
