"""Merge multiple (climate) datasets into one using quantile reduction."""

from pathlib import Path

from pydantic import Field, model_validator

from hydroflows._typing import ListOfPath, OutputDirPath, WildcardPath
from hydroflows.methods.raster.merge_utils import merge_raster_datasets
from hydroflows.methods.utils.io import to_netcdf
from hydroflows.workflow.method import ReduceMethod
from hydroflows.workflow.method_parameters import Parameters

# Public API of this module: the method class and its three parameter models.
__all__ = ["MergeGriddedDatasets", "Input", "Output", "Params"]


class Input(Parameters):
    """Input parameters of the :py:class:`MergeGriddedDatasets` method.

    Holds the dataset path(s) that the method hands over to the merge routine.
    """

    datasets: ListOfPath | WildcardPath
    """The path(s) to the change factor datasets of the different climate models."""


class Output(Parameters):
    """Output parameters of the :py:class:`MergeGriddedDatasets` method.

    Holds the location of the data generated by the method.
    """

    merged_dataset: Path
    """Path to the single output merged dataset."""


class Params(Parameters):
    """Parameters for the :py:class:`MergeGriddedDatasets`.

    Instances of this class are used in the :py:class:`MergeGriddedDatasets`
    method to define the required settings.
    """

    output_dir: OutputDirPath
    """
    The output directory of the dataset.
    """

    output_name: str | None = None
    """
    The name of the output file. When left as None, a name is derived from
    the quantile, e.g. "merged_q50.nc" for ``quantile=0.5``.
    """

    aligned: bool = False
    """Whether the datasets are already aligned or not."""

    res: float = 0.25
    """Resolution (in degrees) of the resulting dataset. Default is 0.25 degrees."""

    quantile: float = Field(0.5, ge=0, le=1)
    """The quantile to reduce the input datasets."""

    reduce_dim: str = "model"
    """The dimension to reduce the datasets along. Default is "model"."""

    @model_validator(mode="after")
    def check_output_name(self):
        """Fill in a default output name when none was provided.

        The default is ``merged_q<percent>.nc`` where ``<percent>`` is the
        quantile expressed as a whole percentage (fractions are truncated).

        Returns
        -------
        Params
            The validated instance, with ``output_name`` always set.
        """
        if self.output_name is None:
            # Round before truncating: binary floats make e.g. 0.29 * 100
            # evaluate to 28.999999999999996, which a bare int() would turn
            # into 28 instead of the intended 29.
            percent = int(round(self.quantile * 100, 6))
            self.output_name = f"merged_q{percent}.nc"
        return self


class MergeGriddedDatasets(ReduceMethod):
    """Merge multiple (climate) datasets into one using quantile reduction.

    Parameters
    ----------
    datasets : ListOfPath, WildcardPath
        List of paths of the datasets for merging.
    output_dir : Path
        The output directory of reduced dataset.
    output_name : str, optional
        The name of the output file. Default is None, in which case the name
        is derived from the quantile (see :py:class:`Params`).
    reduce_dim : str
        The dimension to reduce the datasets along. Default is "model"
        This dimension will be added if not present in the datasets.
    quantile : float
        The quantile (between 0 and 1) used to reduce the input datasets
        into the merged dataset. Default is 0.5.
    **params
        Additional parameters to pass to the MergeGriddedDatasets instance.
        See :py:class:`merge Params <hydroflows.methods.raster.merge.Params>`.

    See Also
    --------
    :py:class:`merge Input <hydroflows.methods.raster.merge.Input>`
    :py:class:`merge Output <hydroflows.methods.raster.merge.Output>`
    :py:class:`merge Params <hydroflows.methods.raster.merge.Params>`
    """

    name: str = "merge_gridded_datasets"

    _test_kwargs = {
        "datasets": [Path("change1.nc"), Path("change2.nc")],
        "output_dir": Path("data"),
    }

    def __init__(
        self,
        datasets: ListOfPath | WildcardPath,
        output_dir: Path,
        output_name: str | None = None,
        reduce_dim: str = "model",
        quantile: float = 0.5,
        **params,
    ) -> None:
        """Set up the input, params and output models of the method."""
        self.input: Input = Input(datasets=datasets)
        self.params: Params = Params(
            output_dir=output_dir,
            quantile=quantile,
            reduce_dim=reduce_dim,
            output_name=output_name,
            **params,
        )

        # Params guarantees output_name is set (a default is generated when
        # None is passed), so the output path can always be composed here.
        self.output: Output = Output(
            merged_dataset=self.params.output_dir / self.params.output_name
        )

    def _run(self):
        """Run the merge datasets method.

        Merges the input datasets with the configured settings and writes
        the result to ``self.output.merged_dataset``.
        """
        merged_ds = merge_raster_datasets(
            self.input.datasets,
            # Forward the user-selected reduce dimension; previously it was
            # stored in Params but never handed to the merge routine, so any
            # non-default value was silently ignored.
            # NOTE(review): assumes merge_raster_datasets accepts a
            # ``reduce_dim`` keyword - confirm against merge_utils.
            reduce_dim=self.params.reduce_dim,
            aligned=self.params.aligned,
            res=self.params.res,
            quantile=self.params.quantile,
        )

        to_netcdf(
            merged_ds,
            file_name=self.output.merged_dataset.name,
            output_dir=self.output.merged_dataset.parent,
        )
