from __future__ import annotations

from typing import Any, Callable, Dict, Optional

from prefect import flow, task
from prefect.cache_policies import NO_CACHE
from node_graph import NodeGraph

from ..core.base import BaseEngine
from ..core.execution import (
    compute_graph_outputs,
    execute_node_job,
    mark_process_failure,
    mark_process_success,
    prepare_graph_run,
)
from ..core.task import EngineNodeExecutor, NodeTaskMeta
from ..core.utils import (
    update_nested_dict_with_special_keys,
    get_nested_dict,
    _collect_literals,
    _is_encoded_tagged,
    _scan_links_topology,
    close_threadlocal_aiida_session,
    get_default_user_email,
    load_default_user,
)
from prefect.task_runners import ThreadPoolTaskRunner
from prefect.futures import PrefectFuture
from prefect.states import State
from aiida import orm
from node_graph.node_graph import BUILTIN_NODES


@task(name="ng:get_nested", cache_policy=NO_CACHE)
def _prefect_get_nested(d: Dict[str, Any], dotted: str, default=None):
    """Prefect task: look up the path ``dotted`` inside ``d`` via ``get_nested_dict``.

    The engine submits this on an upstream task's future to pull out a single
    socket value, so the lookup itself is part of the Prefect task graph.
    ``default`` is forwarded to ``get_nested_dict`` (presumably returned when
    the path is missing — see ``core.utils``).
    """
    looked_up = get_nested_dict(d, dotted, default=default)
    return looked_up


@task(name="node:generic", cache_policy=NO_CACHE)
def _node_task(
    _ng_meta: NodeTaskMeta,
    _ng_callable: Optional[Callable] = None,
    _ng_engine_name: Optional[str] = None,
    _ng_node_inputs: Optional[Dict[str, Any]] = None,
    _ng_node_outputs: Optional[Dict[str, Any]] = None,
    _ng_task_runner=None,
    _ng_flow_name: Optional[str] = None,
    _ng_use_analysis: bool = False,
    _ng_parent_pid: Optional[str] = None,
    _ng_default_user_email: Optional[str] = None,
    **kwargs: Any,
):
    """Prefect task that executes a single NodeGraph node.

    All ``_ng_*`` arguments are engine plumbing supplied by
    ``_prefect_node_job`` and the executor's static kwargs. Every remaining
    keyword argument is a node runtime input and is forwarded to
    ``execute_node_job`` as ``runtime_inputs``.

    Returns:
        Whatever ``execute_node_job`` returns for this node.
    """
    # Engine name for this node: explicit flow name, then engine name, then a
    # fixed fallback.
    flow_base = _ng_flow_name or _ng_engine_name or "prefect-flow"

    def _build_sub_engine(name: str) -> "PrefectEngine":
        # Factory handed to execute_node_job (presumably for sub-graph nodes);
        # the new engine inherits this run's task runner, analysis flag and
        # default user email.
        return PrefectEngine(
            flow_name=name,
            task_runner=_ng_task_runner,
            use_analysis=_ng_use_analysis,
            _default_user_email=_ng_default_user_email,
        )

    try:
        # Load the user inside the try so the thread-local AiiDA session
        # cleanup in ``finally`` also runs when the user lookup itself fails.
        user = load_default_user(_ng_default_user_email)
        return execute_node_job(
            parent_pid=_ng_parent_pid,
            meta=_ng_meta,
            callable_payload=_ng_callable,
            runtime_inputs=kwargs,
            engine_name=flow_base,
            node_inputs=_ng_node_inputs,
            node_outputs=_ng_node_outputs,
            build_sub_engine=_build_sub_engine,
            user=user,
        )
    finally:
        # Tasks run on task-runner worker threads; release this thread's session.
        close_threadlocal_aiida_session()


def _prefect_node_job(
    parent_pid: Optional[str],
    _ng_meta,
    _ng_callable=None,
    _ng_task: Optional[Any] = None,
    _ng_engine_name: Optional[str] = None,
    _ng_default_user_email: Optional[str] = None,
    **kwargs: Any,
):
    if _ng_task is None:
        raise RuntimeError("Prefect task handle is not available for execution")

    submit_kwargs = dict(kwargs)
    submit_kwargs.update(
        {
            "_ng_meta": _ng_meta,
            "_ng_callable": _ng_callable,
            "_ng_parent_pid": parent_pid,
            "_ng_engine_name": _ng_engine_name,
            "_ng_default_user_email": _ng_default_user_email,
        }
    )
    return _ng_task.submit(**submit_kwargs)


class PrefectEngine(BaseEngine):
    """
    Prefect engine for NodeGraph.
      - Builds a Flow from nodes/links with Kahn topological order.
      - Supports multi-fan-in by bundling into a dict with "{fromNode}_{fromSocket}" keys.
      - Submits every non-builtin node as a Prefect task; explicit ``_wait``
        links are handed to Prefect as ``wait_for`` dependencies.
    """

    engine_kind = "prefect"

    def __init__(
        self,
        flow_name: str = "node-graph-flow",
        use_analysis: bool = False,
        task_runner=None,
        *,
        _default_user_email: Optional[str] = None,
    ):
        """Create the engine.

        Args:
            flow_name: Name of the generated Prefect flow (also passed to
                ``BaseEngine``).
            use_analysis: Passed through to node tasks, which hand it to the
                sub-engines they build.
            task_runner: Prefect task runner for the flow; a new
                ``ThreadPoolTaskRunner`` is used when not given.
            _default_user_email: Email passed to ``load_default_user``; falls
                back to ``get_default_user_email()`` when not given.
        """
        super().__init__(flow_name)
        self.flow_name = flow_name
        self.use_analysis = use_analysis
        self.task_runner = task_runner or ThreadPoolTaskRunner()
        self._default_user_email = _default_user_email or get_default_user_email()

    def _link_socket_value(
        self, from_name: str, from_socket: str, source_map: Dict[str, Any]
    ) -> Any:
        """Return the value flowing out of ``from_name.from_socket``.

        If the upstream entry is a Prefect future, the (possibly dotted) socket
        lookup is submitted as its own task. The result therefore stays a future
        and Prefect tracks the dependency. Any other upstream value is handled by
        the base implementation.
        """
        upstream = source_map[from_name]
        if isinstance(upstream, PrefectFuture):
            return _prefect_get_nested.submit(upstream, from_socket, default=None)
        return super()._link_socket_value(from_name, from_socket, source_map)

    def _build_node_executor(
        self,
        node,
        label_kind: str,
    ) -> EngineNodeExecutor:
        """Wrap ``node`` in an executor that submits ``_node_task`` for it.

        The task is renamed ``node:<name>`` and submitted through
        ``_prefect_node_job``. ``label_kind`` is forwarded to
        ``_build_node_task_meta``.
        """
        executor = node.spec.executor.to_dict()
        meta = self._build_node_task_meta(node, label_kind)
        task_obj = _node_task.with_options(name=f"node:{node.name}")
        static_kwargs = {
            "_ng_engine_name": self.flow_name,
            "_ng_node_inputs": (node.spec.inputs.to_dict() if node.spec.inputs else {}),
            "_ng_node_outputs": (
                node.spec.outputs.to_dict() if node.spec.outputs else {}
            ),
            "_ng_task_runner": self.task_runner,
            "_ng_flow_name": self.flow_name,
            "_ng_use_analysis": self.use_analysis,
            "_ng_task": task_obj,
            "_ng_default_user_email": self._default_user_email,
        }
        return EngineNodeExecutor(
            runner=_prefect_node_job,
            meta=meta,
            callable=executor,
            static_kwargs=static_kwargs,
        )

    def to_flow(self, ng: NodeGraph, values: Dict[str, Any]):
        """Build, but do not run, a Prefect flow that schedules every node of ``ng``.

        Args:
            ng: The graph to execute.
            values: Graph-level values. The flow writes each node's future into
                this same dict under the node's name.

        Returns:
            The ``@flow``-decorated function. Calling it submits all node tasks
            and returns the populated ``values`` mapping.
        """
        order, incoming, required_out_sockets = _scan_links_topology(ng)

        executors: Dict[str, EngineNodeExecutor] = {}
        for name in ng.get_node_names():
            if name in BUILTIN_NODES:
                continue
            node = ng.nodes[name]
            node_type = getattr(node.spec, "node_type", "") or ""
            label_kind = "return" if node_type.upper() == "GRAPH" else "create"
            executors[name] = self._build_node_executor(node, label_kind)

        @flow(name=self.flow_name, task_runner=self.task_runner)  # <-- concurrency ON
        def adapted_flow():
            literals = {n: _collect_literals(ng.nodes[n]) for n in ng.get_node_names()}

            # Deliberately aliases ``values``. Graph-level entries are visible as
            # link sources, and node futures are written back into the caller's
            # dict.
            all_task_future: Dict[str, Any] = values

            for n in order:
                if n in BUILTIN_NODES:
                    continue

                links = incoming.get(n, [])
                kw = dict(literals[n])
                kw.update(
                    self._build_link_kwargs(
                        target_name=n,
                        links=links,
                        source_map=all_task_future,
                    )
                )
                kw = update_nested_dict_with_special_keys(kw)

                # Collect explicit wait deps from _wait edges. Each one depends on
                # the WHOLE upstream task future; only futures can be waited on.
                wait_deps = []
                for lk in links:
                    if lk.from_socket._scoped_name != "_wait":
                        continue
                    up = all_task_future.get(lk.from_node.name)
                    if isinstance(up, PrefectFuture):
                        wait_deps.append(up)
                if wait_deps:
                    # Hand the deps to Prefect so _wait edges actually order
                    # execution. ``wait_for`` is a keyword of Prefect's
                    # ``Task.submit``; ``_prefect_node_job`` forwards unknown
                    # kwargs straight into ``_ng_task.submit``.
                    # NOTE(review): assumes EngineNodeExecutor.invoke forwards
                    # extra kwargs to its runner unchanged, and that no node has
                    # an input literally named "wait_for".
                    kw["wait_for"] = wait_deps

                # Schedule with explicit dependencies (does not block others).
                node_future = executors[n].invoke(
                    parent_pid=self._graph_pid,
                    **kw,
                )
                all_task_future[n] = node_future

            return all_task_future

        return adapted_flow

    def run(self, ng: NodeGraph, parent_pid: Optional[str] = None) -> Dict[str, Any]:
        """Build the flow, execute it synchronously, and return the graph outputs.

        A process node is created via ``prepare_graph_run``. On success it is
        marked successful with the computed graph outputs. On any exception it
        is marked failed and the exception is re-raised. It is always sealed
        afterwards.
        """
        context = prepare_graph_run(
            ng,
            parent_pid=parent_pid,
            user=load_default_user(self._default_user_email),
            encode_graph_inputs=True,
        )
        self._graph_pid = context.process_node.uuid
        values = context.values
        flow_fn = self.to_flow(ng, values)
        try:
            state_map = flow_fn()
            # Replace futures/states with concrete results before computing outputs.
            values.update(
                {name: self._resolve_state(value) for name, value in state_map.items()}
            )
            graph_outputs = compute_graph_outputs(
                incoming=context.incoming,
                values=values,
                link_builder=self._build_link_kwargs,
            )
            mark_process_success(context.process_node, graph_outputs)
            return graph_outputs
        except Exception as e:
            mark_process_failure(context.process_node, e)
            raise
        finally:
            context.process_node.seal()

    def _resolve_state(self, value: Any) -> Any:
        """Turn Prefect futures/states (possibly nested in dicts) into concrete values.

        A future is waited on first. A ``State`` yields its result.
        ``orm.Data`` and encoded-tagged values are returned unchanged. Dicts are
        resolved entry by entry.

        Raises:
            TypeError: for any other kind of value.
        """
        if isinstance(value, PrefectFuture):
            value = value.result()
        if isinstance(value, State):
            return value.result()
        if isinstance(value, orm.Data):
            return value
        if _is_encoded_tagged(value):
            return value
        if isinstance(value, dict):
            return {k: self._resolve_state(v) for k, v in value.items()}
        raise TypeError(f"Cannot resolve Prefect state for value: {value}")