# SPDX-License-Identifier: LGPL-3.0-or-later
import functools
import logging
from typing import (
    Any,
    Callable,
    Optional,
    Union,
)

import paddle

from deepmd.dpmodel import (
    FittingOutputDef,
)
from deepmd.pd.model.descriptor.base_descriptor import (
    BaseDescriptor,
)
from deepmd.pd.model.task.base_fitting import (
    BaseFitting,
)
from deepmd.utils.path import (
    DPPath,
)
from deepmd.utils.version import (
    check_version_compatibility,
)

from .base_atomic_model import (
    BaseAtomicModel,
)

log = logging.getLogger(__name__)


@BaseAtomicModel.register("standard")
class DPAtomicModel(BaseAtomicModel):
    """Model give atomic prediction of some physical property.

    Parameters
    ----------
    descriptor
            Descriptor
    fitting
            Fitting net (stored on the instance as `fitting_net`)
    type_map
            Mapping atom type to the name (str) of the type.
            For example `type_map[1]` gives the name of the type 1.
    """

    def __init__(
        self,
        descriptor: BaseDescriptor,
        fitting: BaseFitting,
        type_map: list[str],
        **kwargs: Any,
    ) -> None:
        """Build the atomic model from a descriptor and a fitting net.

        Besides storing the two sub-modules, the constructor caches the
        descriptor's cut-off radius and neighbor selection, initialises the
        descriptor / fitting-last-layer evaluation hooks as disabled, and
        registers a set of ``buffer_*`` tensors mirroring scalar metadata of
        the model.

        Parameters
        ----------
        descriptor
            The descriptor; stored as ``self.descriptor``.
        fitting
            The fitting net; stored as ``self.fitting_net``.
        type_map
            Mapping atom type to the name (str) of the type.
        **kwargs
            Forwarded unchanged to the parent class constructor.
        """
        super().__init__(type_map, **kwargs)
        ntypes = len(type_map)
        self.type_map = type_map
        self.ntypes = ntypes
        self.descriptor = descriptor
        # Cache the descriptor's cut-off radius and neighbor selection.
        self.rcut = self.descriptor.get_rcut()
        self.sel = self.descriptor.get_sel()
        self.fitting_net = fitting
        # Invoked only after the descriptor and the fitting net are assigned.
        super().init_out_stat()
        # Both evaluation hooks start disabled, each with an empty cache.
        self.enable_eval_descriptor_hook = False
        self.enable_eval_fitting_last_layer_hook = False
        self.eval_descriptor_list = []
        self.eval_fitting_last_layer_list = []

        # register 'type_map' as buffer
        # Helper: encode text as the list of its characters' code points.
        # When a list is passed, `ord` is applied to each element, so every
        # element has to be a single character.
        def _string_to_array(s: Union[str, list[str]]) -> list[int]:
            return [ord(c) for c in s]

        # NOTE(review): `type_map` was already passed to `len()` above, so it
        # cannot be None when this check is reached.
        if type_map is not None:
            # The names are joined with single spaces before being encoded.
            self.register_buffer(
                "buffer_type_map",
                paddle.to_tensor(
                    _string_to_array(" ".join(self.type_map)), dtype="int32"
                ),
            )
            # Give the tensor an explicit name equal to its buffer key.
            self.buffer_type_map.name = "buffer_type_map"
        if hasattr(self.descriptor, "has_message_passing"):
            # Register 'has_message_passing' as a buffer. It is cast to int32
            # because problems may be met with vector<bool>.
            self.register_buffer(
                "buffer_has_message_passing",
                paddle.to_tensor(self.descriptor.has_message_passing(), dtype="int32"),
            )
            self.buffer_has_message_passing.name = "buffer_has_message_passing"
        # register 'ntypes' as buffer
        self.register_buffer(
            "buffer_ntypes", paddle.to_tensor(self.ntypes, dtype="int32")
        )
        self.buffer_ntypes.name = "buffer_ntypes"
        # register 'rcut' as buffer
        self.register_buffer(
            "buffer_rcut", paddle.to_tensor(self.rcut, dtype="float64")
        )
        self.buffer_rcut.name = "buffer_rcut"
        # The two parameter-dimension buffers are only registered when the
        # fitting net exposes the corresponding getter.
        if hasattr(self.fitting_net, "get_dim_fparam"):
            # register 'dfparam' as buffer
            self.register_buffer(
                "buffer_dfparam",
                paddle.to_tensor(self.fitting_net.get_dim_fparam(), dtype="int32"),
            )
            self.buffer_dfparam.name = "buffer_dfparam"
        if hasattr(self.fitting_net, "get_dim_aparam"):
            # register 'daparam' as buffer
            self.register_buffer(
                "buffer_daparam",
                paddle.to_tensor(self.fitting_net.get_dim_aparam(), dtype="int32"),
            )
            self.buffer_daparam.name = "buffer_daparam"
        # Register 'aparam_nall' as a buffer. The value is hard-coded to False
        # (stored as an int32), the same constant `is_aparam_nall` returns.
        self.register_buffer(
            "buffer_aparam_nall",
            paddle.to_tensor(False, dtype="int32"),
        )
        self.buffer_aparam_nall.name = "buffer_aparam_nall"

    # Class-level type declarations for the two hook caches. The lists
    # themselves are created per instance in `__init__` and are appended to
    # by `forward_atomic` while the corresponding hook is enabled.
    eval_descriptor_list: list[paddle.Tensor]
    eval_fitting_last_layer_list: list[paddle.Tensor]

    def set_eval_descriptor_hook(self, enable: bool) -> None:
        """Switch descriptor caching on or off and drop any cached descriptors.

        Parameters
        ----------
        enable
            Whether `forward_atomic` should record the descriptor it computes.
        """
        # Empty the existing list in place; rebinding the attribute to a new
        # list does not work (see #4533).
        self.eval_descriptor_list.clear()
        self.enable_eval_descriptor_hook = enable

    def eval_descriptor(self) -> paddle.Tensor:
        """Return every cached descriptor joined into one tensor via `paddle.concat`."""
        cached_descriptors = self.eval_descriptor_list
        return paddle.concat(cached_descriptors)

    def set_eval_fitting_last_layer_hook(self, enable: bool) -> None:
        """Switch caching of the fitting net's last-layer output on or off.

        The fitting net is told whether to return its middle output, and any
        previously cached outputs are discarded.

        Parameters
        ----------
        enable
            Whether `forward_atomic` should record the "middle_output" entry
            returned by the fitting net.
        """
        self.enable_eval_fitting_last_layer_hook = enable
        fitting = self.fitting_net
        fitting.set_return_middle_output(enable)
        # Empty the existing list in place; rebinding the attribute to a new
        # list does not work (see #4533).
        self.eval_fitting_last_layer_list.clear()

    def eval_fitting_last_layer(self) -> paddle.Tensor:
        """Return every cached last-layer output joined into one tensor via `paddle.concat`."""
        cached_outputs = self.eval_fitting_last_layer_list
        return paddle.concat(cached_outputs)

    def fitting_output_def(self) -> FittingOutputDef:
        """Return the output definition reported by the fitting net."""
        if self.fitting_net is None:
            # NOTE(review): `coord_denoise_net` is not assigned anywhere in
            # this class, so this fallback only works if a subclass provides
            # it -- confirm whether the branch is still needed.
            return self.coord_denoise_net.output_def()
        return self.fitting_net.output_def()

    def get_rcut(self) -> float:
        """Return the cut-off radius cached from the descriptor in `__init__`."""
        cutoff = self.rcut
        return cutoff

    def get_sel(self) -> list[int]:
        """Return the neighbor selection cached from the descriptor in `__init__`."""
        selection = self.sel
        return selection

    def get_buffer_type_map(self) -> paddle.Tensor:
        """
        Return the type map encoded as a buffer-style Tensor.

        The buffer is built in `__init__`: the type names (e.g. ['Ni', 'O']) are joined
        with single spaces (e.g. "Ni O"), every character of that string is converted to
        its code point with `ord()`, and the resulting integers are stored in an int32
        paddle.Tensor.

        Storing the names this way is intended to let the type map be saved together
        with the model as a raw buffer.
        """
        encoded_type_map = self.buffer_type_map
        return encoded_type_map

    def get_buffer_rcut(self) -> paddle.Tensor:
        """Return the descriptor's cut-off radius as a buffer-style Tensor.

        The value is obtained from `self.descriptor.get_buffer_rcut()`, not
        from the `buffer_rcut` registered on this model.
        """
        descriptor = self.descriptor
        return descriptor.get_buffer_rcut()

    def get_buffer_sel(self) -> paddle.Tensor:
        """Return the descriptor's neighbor selection as a buffer-style Tensor."""
        descriptor = self.descriptor
        return descriptor.get_buffer_sel()

    def set_case_embd(self, case_idx: int) -> None:
        """
        Select the case embedding of this atomic model by `case_idx`.

        The call is forwarded to the fitting net; the embedding is typically
        concatenated with the descriptor output before entering the fitting net.
        """
        fitting = self.fitting_net
        fitting.set_case_embd(case_idx)

    def mixed_types(self) -> bool:
        """Report whether the model works with mixed atom types.

        The answer is taken from the descriptor.

        If true, the model
        1. assumes the total number of atoms is aligned across frames;
        2. uses a neighbor list that does not distinguish different atomic types.

        If false, the model
        1. assumes the number of atoms of each atom type is aligned across frames;
        2. uses a neighbor list that distinguishes different atomic types.

        """
        descriptor = self.descriptor
        return descriptor.mixed_types()

    def change_type_map(
        self,
        type_map: list[str],
        model_with_new_type_stat: Optional["DPAtomicModel"] = None,
    ) -> None:
        """Switch the type-related parameters of the model to a new `type_map`.

        The parent class, the descriptor and the fitting net are all updated.
        When `type_map` contains new types, `model_with_new_type_stat` supplies
        the statistics for them; only its descriptor is handed to this model's
        descriptor.

        Parameters
        ----------
        type_map
            The new list of type names.
        model_with_new_type_stat
            Optional model carrying statistics for newly added types.
        """
        super().change_type_map(
            type_map=type_map, model_with_new_type_stat=model_with_new_type_stat
        )
        # NOTE(review): `buffer_type_map` and `buffer_ntypes`, registered in
        # `__init__`, are not refreshed here -- confirm whether that is intended.
        self.type_map = type_map
        self.ntypes = len(type_map)

        if model_with_new_type_stat is None:
            descriptor_with_new_stat = None
        else:
            descriptor_with_new_stat = model_with_new_type_stat.descriptor
        self.descriptor.change_type_map(
            type_map=type_map,
            model_with_new_type_stat=descriptor_with_new_stat,
        )
        self.fitting_net.change_type_map(type_map=type_map)

    def has_message_passing(self) -> bool:
        """Report whether the descriptor of this atomic model has message passing."""
        descriptor = self.descriptor
        return descriptor.has_message_passing()

    def need_sorted_nlist_for_lower(self) -> bool:
        """Report whether the descriptor needs a sorted nlist when using `forward_lower`."""
        descriptor = self.descriptor
        return descriptor.need_sorted_nlist_for_lower()

    def serialize(self) -> dict:
        """Serialize the model into a dict.

        The dict produced by `BaseAtomicModel.serialize` is extended with the
        class tag, the format version, the model type, the type map and the
        serialized descriptor and fitting net.
        """
        serialized = BaseAtomicModel.serialize(self)
        model_fields = {
            "@class": "Model",
            "@version": 2,
            "type": "standard",
            "type_map": self.type_map,
            "descriptor": self.descriptor.serialize(),
            "fitting": self.fitting_net.serialize(),
        }
        serialized.update(model_fields)
        return serialized

    @classmethod
    def deserialize(cls, data: dict) -> "DPAtomicModel":
        """Rebuild a model from a dict produced by `serialize`.

        The input dict is not modified; a shallow copy is consumed instead.
        The "@version" entry (default 1) is checked for compatibility, the
        "descriptor" and "fitting" entries are replaced by deserialized
        objects, and the result is passed to the parent class `deserialize`.
        """
        payload = data.copy()
        version = payload.pop("@version", 1)
        check_version_compatibility(version, 2, 1)
        # Bookkeeping keys are dropped before the data reaches the parent.
        for bookkeeping_key in ("@class", "type"):
            payload.pop(bookkeeping_key, None)
        payload["descriptor"] = BaseDescriptor.deserialize(payload.pop("descriptor"))
        payload["fitting"] = BaseFitting.deserialize(payload.pop("fitting"))
        return super().deserialize(payload)

    def enable_compression(
        self,
        min_nbor_dist: float,
        table_extrapolate: float = 5,
        table_stride_1: float = 0.01,
        table_stride_2: float = 0.1,
        check_frequency: int = -1,
    ) -> None:
        """Forward the compression settings to the descriptor's enable_compression().

        Parameters
        ----------
        min_nbor_dist
            The nearest distance between atoms
        table_extrapolate
            The scale of model extrapolation
        table_stride_1
            The uniform stride of the first table
        table_stride_2
            The uniform stride of the second table
        check_frequency
            The overflow check frequency
        """
        # The descriptor receives the settings positionally, in this order.
        compression_args = (
            min_nbor_dist,
            table_extrapolate,
            table_stride_1,
            table_stride_2,
            check_frequency,
        )
        self.descriptor.enable_compression(*compression_args)

    def forward_atomic(
        self,
        extended_coord: paddle.Tensor,
        extended_atype: paddle.Tensor,
        nlist: paddle.Tensor,
        mapping: Optional[paddle.Tensor] = None,
        fparam: Optional[paddle.Tensor] = None,
        aparam: Optional[paddle.Tensor] = None,
        comm_dict: Optional[dict[str, paddle.Tensor]] = None,
    ) -> dict[str, paddle.Tensor]:
        """Return atomic prediction.

        The descriptor is evaluated first and its output is fed to the
        fitting net together with the types of the local atoms.

        Parameters
        ----------
        extended_coord
            coordinates in extended region
        extended_atype
            atomic type in extended region
        nlist
            neighbor list. nf x nloc x nsel
        mapping
            maps the extended indices to local indices
        fparam
            frame parameter. nf x ndf
        aparam
            atomic parameter. nf x nloc x nda
        comm_dict
            optional dict of tensors, passed unchanged to the descriptor

        Returns
        -------
        result_dict
            the result dict, defined by the `FittingOutputDef`. When the
            fitting-last-layer hook is enabled, the "middle_output" entry is
            removed from the dict and cached on the model instead.

        """
        # Only the number of local atoms is used; the three-way unpacking is
        # kept so that `nlist` must still have exactly three dimensions.
        _, nloc, _ = nlist.shape
        # The local atoms are the first `nloc` entries of the extended region.
        atype = extended_atype[:, :nloc]
        if self.do_grad_r() or self.do_grad_c():
            # Derivatives w.r.t. the coordinates are requested, so gradient
            # tracking has to be switched on for them.
            extended_coord.stop_gradient = False
        # The fifth descriptor output is not used by this method.
        descriptor, rot_mat, g2, h2, _ = self.descriptor(
            extended_coord,
            extended_atype,
            nlist,
            mapping=mapping,
            comm_dict=comm_dict,
        )
        assert descriptor is not None
        if self.enable_eval_descriptor_hook:
            # Cache a detached copy for later retrieval by `eval_descriptor`.
            self.eval_descriptor_list.append(descriptor.detach())
        # Evaluate the fitting net on the descriptor of the local atoms.
        fit_ret = self.fitting_net(
            descriptor,
            atype,
            gr=rot_mat,
            g2=g2,
            h2=h2,
            fparam=fparam,
            aparam=aparam,
        )
        if self.enable_eval_fitting_last_layer_hook:
            assert "middle_output" in fit_ret, (
                "eval_fitting_last_layer not supported for this fitting net!"
            )
            # Move the middle output out of the result dict and into the cache.
            self.eval_fitting_last_layer_list.append(
                fit_ret.pop("middle_output").detach()
            )
        return fit_ret

    def get_out_bias(self) -> paddle.Tensor:
        """Return the `out_bias` attribute of this atomic model."""
        bias = self.out_bias
        return bias

    def compute_or_load_stat(
        self,
        sampled_func: Callable[[], list[dict]],
        stat_file_path: Optional[DPPath] = None,
        compute_or_load_out_stat: bool = True,
    ) -> None:
        """
        Compute or load the statistics parameters of the model,
        such as mean and standard deviation of descriptors or the energy bias of the fitting net.

        The sampler is wrapped so that `sampled_func` is called at most once;
        the wrapped sampler and `stat_file_path` are handed to the descriptor,
        the fitting net receives the wrapped sampler, and, if requested, the
        output statistics are computed or loaded as well.

        Parameters
        ----------
        sampled_func
            The lazy sampled function to get data frames from different data systems.
        stat_file_path
            The path to the statistics files. When given, a sub-path named after
            the space-joined type map is used.
        compute_or_load_out_stat : bool
            Whether to compute the output statistics.
            If False, it will only compute the input statistics (e.g. mean and standard deviation of descriptors).
        """
        if stat_file_path is not None:
            if self.type_map is not None:
                # descriptors and fitting net with different type_map
                # should not share the same parameters
                stat_file_path /= " ".join(self.type_map)

        def _attach_exclusion(samples, key, excluder):
            # Nothing is attached when the exclusion object is not set.
            if excluder is None:
                return
            excluded_types = excluder.get_exclude_types()
            for frame in samples:
                # Every frame receives its own list copy.
                frame[key] = list(excluded_types)

        # Cached so that the sampling function runs once even though several
        # consumers below call the sampler.
        @functools.lru_cache
        def cached_sampler():
            samples = sampled_func()
            _attach_exclusion(samples, "pair_exclude_types", self.pair_excl)
            _attach_exclusion(samples, "atom_exclude_types", self.atom_excl)
            return samples

        self.descriptor.compute_input_stats(cached_sampler, stat_file_path)
        self.fitting_net.compute_input_stats(
            cached_sampler, protection=self.data_stat_protect
        )
        if compute_or_load_out_stat:
            self.compute_or_load_out_stat(cached_sampler, stat_file_path)

    def get_dim_fparam(self) -> int:
        """Return the number (dimension) of frame parameters, as reported by the fitting net."""
        fitting = self.fitting_net
        return fitting.get_dim_fparam()

    def get_buffer_dim_fparam(self) -> paddle.Tensor:
        """Return the number (dimension) of frame parameters as a buffer-style Tensor from the fitting net."""
        fitting = self.fitting_net
        return fitting.get_buffer_dim_fparam()

    def has_default_fparam(self) -> bool:
        """Report whether the fitting net has default frame parameters."""
        fitting = self.fitting_net
        return fitting.has_default_fparam()

    def get_dim_aparam(self) -> int:
        """Return the number (dimension) of atomic parameters, as reported by the fitting net."""
        fitting = self.fitting_net
        return fitting.get_dim_aparam()

    def get_buffer_dim_aparam(self) -> paddle.Tensor:
        """Return the number (dimension) of atomic parameters as a buffer-style Tensor from the fitting net."""
        fitting = self.fitting_net
        return fitting.get_buffer_dim_aparam()

    def get_sel_type(self) -> list[int]:
        """Return the selected atom types of this model, as reported by the fitting net.

        Only atoms with selected atom types have atomic contribution
        to the result of the model.
        An empty list means that all atom types are selected.
        """
        fitting = self.fitting_net
        return fitting.get_sel_type()

    def is_aparam_nall(self) -> bool:
        """Report whether atomic parameters have shape (nframes, nall, ndim).

        This model always answers False, i.e. the shape is (nframes, nloc, ndim).
        """
        aparam_covers_all_atoms = False
        return aparam_covers_all_atoms
