# Modified from: keras/src/backend/common/keras_tensor.py
# Original authors: François Chollet et al. (Keras Team)
# License Apache 2.0: (c) 2025 Yoan Sallami (Synalinks Team)

import asyncio
import json

from synalinks.src import tree
from synalinks.src.api_export import synalinks_export
from synalinks.src.backend.common.json_schema_utils import is_schema_equal
from synalinks.src.backend.common.json_schema_utils import standardize_schema
from synalinks.src.utils.naming import auto_name


@synalinks_export("synalinks.SymbolicDataModel")
class SymbolicDataModel:
    """A symbolic backend-independent data model.

    A `SymbolicDataModel` is a container for a JSON schema and can be used to represent
        data structures in a backend-agnostic way. It can record history and is used in
        symbolic operations (in the Functional API and to compute output specs).

    A "symbolic data model" can be understood as a placeholder for data specification:
        it does not contain any actual data, only a schema. It can be used for building
        Functional models, but it cannot be used in actual computations.

    Args:
        data_model (DataModel): Optional. The data_model used to extract the schema.
        schema (dict): Optional. The JSON schema to be used. If the schema is not
            provided, the data_model argument should be used to infer it.
        record_history (bool): Optional. Boolean indicating if the history
            should be recorded. Defaults to `True`.
        name (str): Optional. A unique name for the data model. Automatically generated
            if not set.

    Raises:
        ValueError: If neither `data_model` nor `schema` is given, or if both are
            given but their schemas are not equal.

    Examples:

    **Creating a `SymbolicDataModel` with a data_model:**

    ```python
    class Query(synalinks.DataModel):
        query: str

    data_model = SymbolicDataModel(data_model=Query)
    ```

    **Creating a `SymbolicDataModel` with a data_model's schema:**

    ```python
    class Query(synalinks.DataModel):
        query: str

    data_model = SymbolicDataModel(schema=Query.schema())
    ```
    """

    def __init__(
        self,
        data_model=None,
        schema=None,
        record_history=True,
        name=None,
    ):
        self.name = name or auto_name(self.__class__.__name__)
        self.record_history = record_history
        self._schema = None
        # NOTE: truthiness checks, so an empty `schema` ({}) counts as missing.
        if not schema and not data_model:
            raise ValueError(
                "You should specify at least one argument between "
                "`data_model` or `schema`"
            )
        if schema and data_model:
            # Compute once: used both for the comparison and the error message.
            data_model_schema = data_model.schema()
            if not is_schema_equal(schema, data_model_schema):
                raise ValueError(
                    "Attempting to create a SymbolicDataModel "
                    "with both `schema` and `data_model` argument "
                    "but their schemas are incompatible "
                    f"received schema={schema} and "
                    f"data_model.schema()={data_model_schema}."
                )
            self._schema = standardize_schema(schema)
        elif schema:
            self._schema = standardize_schema(schema)
        else:
            # Exactly one of the two is set at this point, so it is `data_model`.
            self._schema = standardize_schema(data_model.schema())

    @staticmethod
    def _run_until_complete(coroutine):
        """Run a `symbolic_call()` coroutine to completion and return its result.

        Operators such as `+`, `&` and `|` and the helper methods below must
        return a value synchronously, so the coroutine is driven on the
        current event loop here.

        NOTE(review): `run_until_complete` raises if the loop is already
        running, and `asyncio.get_event_loop()` is deprecated when no loop is
        set in recent Python versions — consider migrating.
        """
        return asyncio.get_event_loop().run_until_complete(coroutine)

    @property
    def name(self):
        """The name of the data model."""
        return self._name

    @name.setter
    def name(self, value):
        self._name = value

    @property
    def record_history(self):
        """Whether the history is being recorded."""
        return self._record_history

    @record_history.setter
    def record_history(self, value):
        self._record_history = value

    def schema(self):
        """The JSON schema of the data model.

        Returns:
            dict: The JSON schema.
        """
        return self._schema

    def pretty_schema(self):
        """Get a pretty version of the JSON schema for display.

        Returns:
            str: The JSON schema serialized with an indentation of 2 spaces.
        """
        return json.dumps(self.schema(), indent=2)

    def __repr__(self):
        return f"<SymbolicDataModel schema={self._schema}, name={self._name}>"

    def __add__(self, other):
        """Concatenates this data model with another.

        Args:
            other (SymbolicDataModel): The other data model to concatenate with.

        Returns:
            SymbolicDataModel: The concatenated data model.
        """
        from synalinks.src import ops

        return self._run_until_complete(ops.Concat().symbolic_call(self, other))

    def __radd__(self, other):
        """Concatenates (reverse) another data model with this one.

        Args:
            other (SymbolicDataModel): The other data model to concatenate with.

        Returns:
            SymbolicDataModel: The concatenated data model.
        """
        from synalinks.src import ops

        return self._run_until_complete(ops.Concat().symbolic_call(other, self))

    def __and__(self, other):
        """Perform a logical `And` with another data model

        If one of them is None, output None. If both are provided,
        then concatenates the other data model with this one.

        Args:
            other (SymbolicDataModel): The other data model to concatenate with.

        Returns:
            SymbolicDataModel | None: The concatenated data model.
        """
        from synalinks.src import ops

        return self._run_until_complete(ops.And().symbolic_call(self, other))

    def __rand__(self, other):
        """Perform a logical `And` (reverse) with another data model

        If one of them is None, output None. If both are provided,
        then concatenates the other data model with this one.

        Args:
            other (SymbolicDataModel): The other data model to concatenate with.

        Returns:
            SymbolicDataModel | None: The concatenated data model.
        """
        from synalinks.src import ops

        return self._run_until_complete(ops.And().symbolic_call(other, self))

    def __or__(self, other):
        """Perform a logical `Or` with another data model

        If one of them is None, output the other one. If both are provided,
        then concatenates the other data model with this one.

        Args:
            other (SymbolicDataModel): The other data model to concatenate with.

        Returns:
            SymbolicDataModel | None: The concatenated data model.
        """
        from synalinks.src import ops

        return self._run_until_complete(ops.Or().symbolic_call(self, other))

    def __ror__(self, other):
        """Perform a logical `Or` (reverse) with another data model

        If one of them is None, output the other one. If both are provided,
        then concatenates the other data model with this one.

        Args:
            other (SymbolicDataModel): The other data model to concatenate with.

        Returns:
            SymbolicDataModel | None: The concatenated data model.
        """
        from synalinks.src import ops

        return self._run_until_complete(ops.Or().symbolic_call(other, self))

    def factorize(self):
        """Factorizes the data model.

        Returns:
            SymbolicDataModel: The factorized data model.
        """
        from synalinks.src import ops

        return self._run_until_complete(ops.Factorize().symbolic_call(self))

    def in_mask(self, mask=None, recursive=True):
        """Applies a mask to **keep only** specified keys of the data model.

        Args:
            mask (list): The mask to be applied (list of keys).
            recursive (bool): Optional. Whether to apply the mask recursively.
                Defaults to True.

        Returns:
            SymbolicDataModel: The data model with the input mask applied.
        """
        from synalinks.src import ops

        # Forward the caller's `recursive` flag (it was previously ignored).
        return self._run_until_complete(
            ops.InMask(mask=mask, recursive=recursive).symbolic_call(self)
        )

    def out_mask(self, mask=None, recursive=True):
        """Applies an output mask to **remove** specified keys of the data model.

        Args:
            mask (list): The mask to be applied (list of keys).
            recursive (bool): Optional. Whether to apply the mask recursively.
                Defaults to True.

        Returns:
            SymbolicDataModel: The data model with the output mask applied.
        """
        from synalinks.src import ops

        # Forward the caller's `recursive` flag (it was previously ignored).
        return self._run_until_complete(
            ops.OutMask(mask=mask, recursive=recursive).symbolic_call(self)
        )

    def prefix(self, prefix=None):
        """Add a prefix to **all** the data model fields (non-recursive).

        Args:
            prefix (str): the prefix to add

        Returns:
            SymbolicDataModel: The data model with the prefix added.
        """
        from synalinks.src import ops

        return self._run_until_complete(ops.Prefix(prefix=prefix).symbolic_call(self))

    def suffix(self, suffix=None):
        """Add a suffix to **all** the data model fields (non-recursive).

        Args:
            suffix (str): the suffix to add

        Returns:
            SymbolicDataModel: The data model with the suffix added.
        """
        from synalinks.src import ops

        return self._run_until_complete(ops.Suffix(suffix=suffix).symbolic_call(self))

    def get(self, key):
        """Get wrapper to make it easier to access fields.

        A symbolic data model holds no values, so this always fails.

        Args:
            key (str): The key to access.

        Raises:
            ValueError: Always.
        """
        raise ValueError(
            f"Attempting to get '{key}' from a symbolic data model "
            "this operation is not possible, make sure that your `call()` "
            "is correctly implemented, if so then you likely need to implement "
            " `compute_output_spec()` in your subclassed module."
        )

    def update(self, kv_dict):
        """Update wrapper to make it easier to modify fields.

        A symbolic data model holds no values, so this always fails.

        Args:
            kv_dict (dict): The key/value dict to update.

        Raises:
            ValueError: Always.
        """
        raise ValueError(
            f"Attempting to update keys '{list(kv_dict.keys())}' from a symbolic "
            "data model this operation is not possible, make sure that your `call()` "
            "is correctly implemented, if so then you likely need to implement "
            " `compute_output_spec()` in your subclassed module."
        )


def any_symbolic_data_models(args=None, kwargs=None):
    """Tell whether a call's arguments contain at least one symbolic data model.

    The positional and keyword arguments are flattened together with
    `tree.flatten`, and the leaves are scanned until the first symbolic data
    model is found.

    Args:
        args (tuple): Optional. Positional arguments to inspect; `None` is
            treated as empty.
        kwargs (dict): Optional. Keyword arguments to inspect; `None` is
            treated as empty.

    Returns:
        bool: `True` if any flattened leaf is a symbolic data model,
            `False` otherwise.
    """
    leaves = tree.flatten((args or (), kwargs or {}))
    return any(is_symbolic_data_model(leaf) for leaf in leaves)


@synalinks_export(
    [
        "synalinks.utils.is_symbolic_data_model",
        "synalinks.backend.is_symbolic_data_model",
    ]
)
def is_symbolic_data_model(x):
    """Check whether `x` is a symbolic data model.

    A symbolic data model (a `SymbolicDataModel` instance, e.g. one created via
    `Input()`) is a placeholder that carries only a schema and no actual
    data. It is meant for building Functional models and cannot take part in
    actual computations.

    Args:
        x (any): The object to inspect.

    Returns:
        bool: `True` when `x` is an instance of `SymbolicDataModel` (or a
            subclass), `False` otherwise.
    """
    is_symbolic = isinstance(x, SymbolicDataModel)
    return is_symbolic
