from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Optional, Union
from node_graph.socket import TaggedValue
from node_graph.socket_spec import SocketSpec
from aiida_pythonjob.parsers.utils import parse_outputs


def _coerce_outputs_spec(spec: Any) -> Optional[SocketSpec]:
    if isinstance(spec, SocketSpec):
        return spec
    if isinstance(spec, dict):
        if not spec:
            return None
        return SocketSpec.from_dict(spec)
    to_dict = getattr(spec, "to_dict", None)
    if callable(to_dict):
        return SocketSpec.from_dict(to_dict())
    raise TypeError(f"Cannot coerce outputs spec from type: {type(spec)}")


@dataclass(frozen=True)
class NodeTaskMeta:
    """Immutable metadata describing a node task, built from a node via ``from_node``.

    NOTE: ``frozen=True`` (with the default ``eq=True``) makes dataclasses generate
    ``__hash__`` from all fields, so hashing an instance whose ``outputs_spec`` is a
    dict raises ``TypeError``.
    """

    # Name of the originating node; "<node>" when the node has no ``name`` attribute.
    node_name: str
    # Output socket spec in dict form, or None when the node declares no (or empty) outputs.
    outputs_spec: Optional[Dict[str, Any]]
    # Caller-supplied label kind, stored verbatim ("return" by default in from_node).
    label_kind: str
    # True when the node's spec reports a node_type equal to "graph" (case-insensitive).
    is_graph: bool

    def as_dict(self) -> Dict[str, Any]:
        """Return the metadata as a plain dict keyed by field name.

        ``outputs_spec`` is returned by reference, not copied.
        """
        return {
            "node_name": self.node_name,
            "outputs_spec": self.outputs_spec,
            "label_kind": self.label_kind,
            "is_graph": self.is_graph,
        }

    @staticmethod
    def _outputs_spec_as_dict(spec_obj: Any) -> Optional[Dict[str, Any]]:
        """Convert a node's declared outputs spec to a dict; None when absent or empty."""
        if spec_obj is None:
            return None
        if isinstance(spec_obj, dict):
            return spec_obj or None
        # SocketSpec and duck-typed specs are all handled through to_dict(); a separate
        # SocketSpec isinstance branch was unreachable because its to_dict is callable.
        to_dict = getattr(spec_obj, "to_dict", None)
        if callable(to_dict):
            return to_dict() or None
        # Anything else is treated as "no outputs spec".
        return None

    @classmethod
    def from_node(
        cls,
        node: Any,
        *,
        label_kind: str = "return",
    ) -> "NodeTaskMeta":
        """Build metadata from a node object.

        Args:
            node: object with a ``spec`` attribute. ``spec.outputs``, ``spec.node_type``
                and ``node.name`` are each optional. ``node_type`` is assumed to be a
                str when set (``.lower()`` is called on it).
            label_kind: stored verbatim on the returned instance.

        Returns:
            A new ``NodeTaskMeta``.

        Raises:
            AttributeError: if ``node`` has no ``spec`` attribute.
        """
        spec = node.spec
        outputs_spec = cls._outputs_spec_as_dict(getattr(spec, "outputs", None))
        node_type = getattr(spec, "node_type", "") or ""
        return cls(
            node_name=getattr(node, "name", "<node>"),
            outputs_spec=outputs_spec,
            label_kind=label_kind,
            is_graph=node_type.lower() == "graph",
        )


def normalize_outputs_for_spec(
    result: Any, outputs_spec: Union[SocketSpec, Dict[str, Any], None]
) -> Dict[str, TaggedValue]:
    """Parse a node's raw ``result`` into outputs matching ``outputs_spec``.

    Args:
        result: raw value to hand to ``parse_outputs``.
        outputs_spec: a ``SocketSpec``, a dict form of one, an object with
            ``to_dict()``, or None; coerced via ``_coerce_outputs_spec``.

    Returns:
        The outputs mapping produced by ``parse_outputs``. When ``parse_outputs``
        returns a tuple, this is its first element.

    Raises:
        TypeError: propagated from ``_coerce_outputs_spec`` for unsupported spec types.
        ValueError: if ``parse_outputs`` yields no outputs (None). Any non-None
            trailing tuple elements it returned are appended to the message.
    """
    spec = _coerce_outputs_spec(outputs_spec)
    parsed = parse_outputs(result, spec)
    outputs = parsed
    extra: tuple = ()
    if isinstance(parsed, tuple):
        # Presumably (outputs, exit_code)-style; keep the trailing part so a failure
        # can say why, instead of silently discarding it.
        outputs, extra = parsed[0], parsed[1:]
    if outputs is None:
        message = "Failed to normalise node outputs against socket spec"
        details = [str(item) for item in extra if item is not None]
        if details:
            message = f"{message}: {'; '.join(details)}"
        raise ValueError(message)
    return outputs


@dataclass(frozen=True)
class EngineNodeExecutor:
    """Pair a runner callable with the node metadata it receives on every call.

    ``invoke`` forwards everything to ``runner`` as keyword arguments only.
    NOTE: being frozen with a dict field (``static_kwargs``), instances are not
    hashable; ``hash()`` raises ``TypeError``.
    """

    # Function that performs the execution; called with keyword arguments only.
    runner: Callable[..., Any]
    # Metadata forwarded to ``runner`` under the ``_ng_meta`` keyword.
    meta: NodeTaskMeta
    # Optional callable forwarded to ``runner`` under the ``_ng_callable`` keyword.
    callable: Optional[Callable] = None
    # Keyword arguments fixed at construction and merged into every invocation.
    static_kwargs: Dict[str, Any] = field(default_factory=dict)

    def invoke(
        self,
        *,
        parent_pid: Optional[str],
        **inputs: Any,
    ) -> Any:
        """Call ``runner`` with runtime entries, static kwargs and per-call inputs.

        Precedence, lowest to highest:
          1. runtime entries ``parent_pid``, ``_ng_meta``, ``_ng_callable``
          2. ``static_kwargs``
          3. ``inputs``

        A key in ``static_kwargs`` or ``inputs`` that matches a runtime entry
        replaces it. ``static_kwargs`` itself is never mutated.

        Returns:
            Whatever ``runner`` returns.
        """
        call_kwargs: Dict[str, Any] = {
            "parent_pid": parent_pid,
            "_ng_meta": self.meta,
            "_ng_callable": self.callable,
        }
        # Apply the layers in ascending precedence; later updates win.
        call_kwargs.update(self.static_kwargs)
        call_kwargs.update(inputs)
        return self.runner(**call_kwargs)
