# Modified from: keras/src/layers/core/input_layer.py
# Original authors: François Chollet et al. (Keras Team)
# License Apache 2.0: (c) 2025 Yoan Sallami (Synalinks Team)

from synalinks.src import backend
from synalinks.src.api_export import synalinks_export
from synalinks.src.backend import is_schema_equal
from synalinks.src.backend import is_symbolic_data_model
from synalinks.src.backend import standardize_schema
from synalinks.src.modules.module import Module
from synalinks.src.ops.node import Node


@synalinks_export("synalinks.modules.InputModule")
class InputModule(Module):
    """Module that defines the input data model of a program.

    At least one of `schema` or `input_data_model` must be given. When only a
    schema is given, a new `SymbolicDataModel` is created from the
    standardized schema. When an existing symbolic data model is given, it is
    used as-is and its schema becomes the module's schema.

    Args:
        schema (dict): Optional JSON schema of the input. If
            `input_data_model` is also given, both schemas must be equal
            according to `is_schema_equal`.
        input_data_model (SymbolicDataModel): Optional existing symbolic data
            model to use as the output of this module.
        optional (bool): Stored on the module as `optional` and included
            in the config.
        name (string): Optional name forwarded to the base `Module` and to
            the `SymbolicDataModel` created when none is provided.
        **kwargs: Accepted and ignored. This lets `from_config` pass back the
            `"description"` entry that `get_config` emits.

    Raises:
        ValueError: If `input_data_model` is not a symbolic data model, if
            `schema` is incompatible with it, or if neither argument is given.
    """

    def __init__(
        self,
        schema=None,
        input_data_model=None,
        optional=False,
        name=None,
        **kwargs,
    ):
        super().__init__(
            name=name,
            description="Defines the input data model for a program.",
        )
        if input_data_model is None:
            # Without a data model to read it from, the schema is mandatory.
            if schema is None:
                raise ValueError("You must pass a `schema` argument.")
        else:
            if not is_symbolic_data_model(input_data_model):
                raise ValueError(
                    "Argument `input_data_model` must be a SymbolicDataModel. "
                    f"Received invalid type: input_data_model={input_data_model} "
                    f"(of type {type(input_data_model)})"
                )
            schema_conflicts = schema is not None and not is_schema_equal(
                schema, input_data_model.schema()
            )
            if schema_conflicts:
                raise ValueError(
                    "When providing the `input_data_model` argument, you "
                    "cannot provide an incompatible `schema` argument."
                )
            # The provided data model is the source of truth for the schema.
            schema = input_data_model.schema()

        self._schema = standardize_schema(schema)
        symbolic_input = (
            input_data_model
            if input_data_model is not None
            else backend.SymbolicDataModel(schema=self._schema, name=name)
        )
        self._input_data_model = symbolic_input
        # Register a node with no call arguments whose output is the
        # symbolic data model of this module.
        Node(
            operation=self,
            call_args={},
            call_kwargs={},
            outputs=symbolic_input,
        )
        self.built = True
        self.optional = optional

    @property
    def input_schema(self):
        """The standardized schema of this input."""
        return self._schema

    async def call(self):
        """Do nothing and return `None`."""
        return

    def schema(self):
        """Return the standardized schema of this input."""
        return self._schema

    def get_config(self):
        """Return the config dict consumed by `from_config`."""
        config = {"schema": self._schema, "optional": self.optional}
        config["name"] = self.name
        config["description"] = self.description
        return config

    @classmethod
    def from_config(cls, config):
        """Build a module by passing the `config` entries as keyword arguments."""
        return cls(**config)


@synalinks_export(["synalinks.modules.Input", "synalinks.Input"])
def Input(
    schema=None,
    data_model=None,
    optional=False,
    name=None,
):
    """Used to instantiate a `SymbolicDataModel`.

    A `SymbolicDataModel` is a symbolic data_model-like object, which we augment with
    certain attributes that allow us to build a Synalinks `Program` just by knowing the
    inputs and outputs of the program.

    Example:

    ```python
    import synalinks

    class Query(synalinks.DataModel):
        query: str

    inputs = synalinks.Input(data_model=Query)

    # You can also create it using a JSON schema like this:

    inputs = synalinks.Input(schema=Query.schema())

    # Or using a symbolic data model:

    inputs = synalinks.Input(data_model=Query.to_symbolic_data_model())
    ```

    Args:
        schema (dict): A JSON schema of the data model.
            If not provided, the schema of the `data_model` argument is used.
            If both are provided, `InputModule` raises a `ValueError` when
            they are incompatible.
        data_model (DataModel): Optional existing data model to wrap into
            the `Input` module. It must expose `to_symbolic_data_model()`.
            If set, the module uses the resulting symbolic data model rather
            than creating a new placeholder data model.
        optional (bool): Whether the input is optional or not.
            An optional input can accept `None` values.
        name (string): Optional name string for the module.
            Should be unique in a program (do not reuse the same name twice).
            It will be autogenerated if it isn't provided.

    Returns:
        (SymbolicDataModel): The symbolic data model corresponding to
            the given data model/schema.

    Raises:
        ValueError: Raised by `InputModule` when neither `schema` nor
            `data_model` is given, or when both are given but incompatible.
    """
    # Compare against `None` explicitly: a truthiness test would silently
    # discard a provided data model that happens to evaluate as falsy.
    if data_model is not None:
        input_data_model = data_model.to_symbolic_data_model()
    else:
        input_data_model = None
    module = InputModule(
        schema=schema,
        input_data_model=input_data_model,
        optional=optional,
        name=name,
    )
    return module.output
