from typing import Any, List, Optional, Protocol, TypedDict, Union

import pandas as pd

from snowflake import snowpark


class LocalModelTransformHandlers(Protocol):
    """A protocol defining the behavior of a local execution model transformer.

    Implementations are constructed around a pandas DataFrame and expose batch inference and scoring over it.
    """

    def __init__(
        self,
        dataset: pd.DataFrame,
        estimator: object,
        class_name: str,
        subproject: str,
        autogenerated: Optional[bool] = False,
    ) -> None:
        """Initialize the handler.

        Args:
            dataset: The dataset to run transform functions on.
            estimator: The estimator used to run transforms.
            class_name: class name to be used in telemetry.
            subproject: subproject to be used in telemetry.
            autogenerated: Whether the class was autogenerated from a template.
        """
        ...

    def batch_inference(
        self,
        inference_method: str,
        input_cols: List[str],
        expected_output_cols: List[str],
        snowpark_input_cols: Optional[List[str]],
        drop_input_cols: Optional[bool] = False,
        *args: Any,
        **kwargs: Any,
    ) -> pd.DataFrame:
        """Run batch inference on the given dataset.

        Args:
            inference_method: the name of the method used by `estimator` to run inference.
            input_cols: column names of the input dataset.
            expected_output_cols: column names (in order) of the output dataset.
            snowpark_input_cols: list of input columns used if estimator is fit in snowpark.
            drop_input_cols: If set True, the response will not contain input columns.
            args: additional positional arguments.
            kwargs: additional keyword args.

        Returns:
            A new dataset of the same type as the input dataset.

        # noqa: DAR202
        (function in protocol definition does not actually return a value)
        """
        ...

    def score(
        self,
        input_cols: List[str],
        label_cols: List[str],
        sample_weight_col: Optional[str],
        *args: Any,
        **kwargs: Any,
    ) -> float:
        """Score the given test dataset.

        Args:
            input_cols: List of feature columns for scoring.
            label_cols: List of label columns for scoring.
            sample_weight_col: A column assigning relative weights to each row for scoring.
            args: additional positional arguments.
            kwargs: additional keyword args.

        Returns:
            An accuracy score for the model on the given test data.

        # noqa: DAR202
        (function in protocol definition does not actually return a value)
        """
        ...


class RemoteModelTransformHandlers(Protocol):
    """A protocol defining the behavior of a remote execution model transformer.

    Implementations are constructed around a Snowpark DataFrame, and their inference and scoring
    methods receive an active Snowpark session together with the dependencies they need.
    """

    def __init__(
        self,
        dataset: snowpark.DataFrame,
        estimator: object,
        class_name: str,
        subproject: str,
        autogenerated: Optional[bool] = False,
    ) -> None:
        """Initialize the handler.

        Args:
            dataset: The dataset to run transform functions on.
            estimator: The estimator used to run transforms.
            class_name: class name to be used in telemetry.
            subproject: subproject to be used in telemetry.
            autogenerated: Whether the class was autogenerated from a template.
        """
        ...

    def batch_inference(
        self,
        inference_method: str,
        input_cols: List[str],
        expected_output_cols: List[str],
        session: snowpark.Session,
        dependencies: List[str],
        drop_input_cols: Optional[bool] = False,
        expected_output_cols_type: Optional[str] = "",
        *args: Any,
        **kwargs: Any,
    ) -> snowpark.DataFrame:
        """Run batch inference on the given dataset.

        Args:
            inference_method: the name of the method used by `estimator` to run inference.
            input_cols: List of feature columns for inference.
            expected_output_cols: column names (in order) of the output dataset.
            session: An active Snowpark Session.
            dependencies: List of dependencies for the transformer.
            drop_input_cols: If set True, the response will not contain input columns.
            expected_output_cols_type: Expected type of the output columns.
            args: additional positional arguments.
            kwargs: additional keyword args.

        Returns:
            A new dataset of the same type as the input dataset.

        # noqa: DAR202
        (function in protocol definition does not actually return a value)
        """
        ...

    def score(
        self,
        input_cols: List[str],
        label_cols: List[str],
        session: snowpark.Session,
        dependencies: List[str],
        score_sproc_imports: List[str],
        sample_weight_col: Optional[str] = None,
        *args: Any,
        **kwargs: Any,
    ) -> float:
        """Score the given test dataset.

        Args:
            input_cols: List of feature columns for scoring.
            label_cols: List of label columns for scoring.
            session: An active Snowpark Session.
            dependencies: score function dependencies.
            score_sproc_imports: imports for score stored procedure.
            sample_weight_col: A column assigning relative weights to each row for scoring.
            args: additional positional arguments.
            kwargs: additional keyword args.

        Returns:
            An accuracy score for the model on the given test data.

        # noqa: DAR202
        (function in protocol definition does not actually return a value)
        """
        ...


# Either handler flavour: the pandas-backed local protocol or the Snowpark-backed remote protocol.
ModelTransformHandlers = Union[
    LocalModelTransformHandlers,
    RemoteModelTransformHandlers,
]


class BatchInferenceKwargsTypedDict(TypedDict, total=False):
    """A typed dict specifying all possible optional keyword args accepted by batch_inference() methods.

    Declared with ``total=False``, so every key may be omitted. The keys span both handler flavours:
    ``snowpark_input_cols`` appears in LocalModelTransformHandlers.batch_inference, while ``session``,
    ``dependencies`` and ``expected_output_cols_type`` appear in RemoteModelTransformHandlers.batch_inference.
    """

    # Local handler: list of input columns used if the estimator was fit in Snowpark.
    snowpark_input_cols: Optional[List[str]]
    # Both handlers: if True, the response will not contain the input columns.
    drop_input_cols: Optional[bool]
    # Remote handler: an active Snowpark session.
    session: snowpark.Session
    # Remote handler: list of dependencies for the transformer.
    dependencies: List[str]
    # Remote handler: expected type of the output columns.
    # NOTE(review): the remote signature declares this Optional[str] (default ""), but it is plain str here — confirm.
    expected_output_cols_type: str
    # Not named in either protocol signature in this file; presumably forwarded via **kwargs to
    # neighbor-based inference methods (e.g. a kneighbors-style call) — TODO confirm.
    n_neighbors: Optional[int]
    # Same as above: presumably forwarded via **kwargs to neighbor-based inference — TODO confirm.
    return_distance: bool


class ScoreKwargsTypedDict(TypedDict, total=False):
    """A typed dict specifying all possible optional keyword args accepted by score() methods.

    Declared with ``total=False``, so every key may be omitted. All keys correspond to parameters of
    RemoteModelTransformHandlers.score.
    """

    # An active Snowpark session.
    session: snowpark.Session
    # Score function dependencies.
    dependencies: List[str]
    # Imports for the score stored procedure.
    score_sproc_imports: List[str]
