"""Airflow integration for NodeGraph Engine."""

from __future__ import annotations

import base64
import inspect
import json
import logging
import os
import re
import textwrap
import time
import uuid
from datetime import datetime, timedelta, timezone
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

from aiida import orm
from node_graph import NodeGraph
from node_graph.node_graph import BUILTIN_NODES

from ..core.base import BaseEngine
from ..core.execution import (
    compute_graph_outputs,
    execute_node_job,
    iterate_node_order,
    mark_process_failure,
    mark_process_success,
    prepare_graph_run,
)
from ..core.utils import (
    _collect_literals,
    _scan_links_topology,
    get_nested_dict,
    update_nested_dict_with_special_keys,
    close_threadlocal_aiida_session,
    get_default_user_email,
    load_default_user,
)

from airflow import DAG
from airflow.providers.standard.operators.python import PythonOperator


# Plain-dict description of one link feeding a node. The keys used in this
# module are "from" (source node name), "from_socket", "target" (destination
# node name) and "target_socket".
IncomingSpec = Dict[str, str]


_SANITIZE_PATTERN = re.compile(r"[^A-Za-z0-9_]")


def _sanitize_dag_id(value: str) -> str:
    sanitized = _SANITIZE_PATTERN.sub("_", value)
    if not sanitized:
        sanitized = "node_graph"
    if sanitized[0].isdigit():
        sanitized = f"ng_{sanitized}"
    return sanitized


@dataclass
class SchedulerPaths:
    """Resolved file-system locations used for scheduler-backed runs."""

    # Airflow home directory; the run root below is created underneath it.
    airflow_home: Path
    # Airflow DAGs folder that is scanned when registering generated DAGs.
    dags_dir: Path
    # Root under which per-DAG / per-run result directories are created.
    run_root: Path


@dataclass
class SchedulerRunArtifacts:
    """Metadata for a single scheduler-triggered run."""

    # Identifier of the DAG run these artifacts belong to.
    run_id: str
    # ``conf`` mapping handed to the triggered DAG run (carries the result
    # path, the sub-graph run id and, when known, the parent process id).
    run_conf: Dict[str, Any]
    # File the finalize task writes the serialized outputs to.
    result_path: Path


@dataclass
class TaskSdkInterfaces:
    """Bundle of optional Task SDK handles; a slot stays ``None`` when missing."""

    supervisor_comms: Any = None
    trigger_cls: Any = None
    get_state_cls: Any = None
    ok_response_cls: Any = None
    error_response_cls: Any = None
    error_type_cls: Any = None

    @property
    def available(self) -> bool:
        """Whether both the comms channel and the trigger message class are set."""
        return not (self.supervisor_comms is None or self.trigger_cls is None)


def _resolve_scheduler_paths(dag_id: str) -> SchedulerPaths:
    """Determine key Airflow directories for a generated sub-graph DAG.

    The Airflow home comes from ``AIRFLOW_HOME`` (default ``~/airflow``). The
    DAGs folder comes from ``AIRFLOW__CORE__DAGS_FOLDER``, then from Airflow's
    ``[core] dags_folder`` setting, and finally falls back to
    ``<airflow_home>/dags``. Both the DAGs folder and the
    ``<airflow_home>/ng_subgraph_runs`` run root are created if missing.

    ``dag_id`` is accepted for interface symmetry but is not used here.
    """

    # Expand ``~`` here as well (the DAGs folder below is expanded too);
    # otherwise ``AIRFLOW_HOME=~/airflow`` would create a literal ``~``
    # directory relative to the current working directory.
    airflow_home = Path(
        os.environ.get("AIRFLOW_HOME", Path.home() / "airflow")
    ).expanduser()

    dags_dir_setting = os.environ.get("AIRFLOW__CORE__DAGS_FOLDER")
    if not dags_dir_setting:
        try:  # pragma: no cover - requires Airflow runtime configuration
            from airflow.configuration import conf as airflow_conf

            dags_dir_setting = airflow_conf.get("core", "dags_folder")
        except Exception:
            # Best effort: without a usable Airflow config use the default.
            dags_dir_setting = None

    if dags_dir_setting:
        dags_dir = Path(dags_dir_setting).expanduser()
    else:
        dags_dir = airflow_home / "dags"

    dags_dir.mkdir(parents=True, exist_ok=True)

    run_root = airflow_home / "ng_subgraph_runs"
    run_root.mkdir(parents=True, exist_ok=True)

    return SchedulerPaths(
        airflow_home=airflow_home, dags_dir=dags_dir, run_root=run_root
    )


def _build_scheduler_payload(
    *,
    dag_id: str,
    ng: NodeGraph,
    engine_config: Dict[str, Any],
    default_user_email: Optional[str],
    schedule_subgraphs: bool,
    serialize_fn: Callable[[Any], Any],
) -> bytes:
    """Serialize the information required to rebuild a sub-graph DAG."""

    payload = {
        "dag_id": dag_id,
        "engine_config": engine_config,
        "graph": ng.to_dict(include_sockets=True, should_serialize=True),
        "default_user_email": default_user_email,
        "schedule_subgraphs": bool(schedule_subgraphs),
    }

    serialized = serialize_fn(payload)
    if not isinstance(serialized, (str, bytes)):
        serialized = json.dumps(serialized)

    if isinstance(serialized, str):
        return serialized.encode("utf-8")
    return serialized


_DAG_TEMPLATE = textwrap.dedent(
    '''
    """Auto-generated NodeGraph Engine DAG for sub-graph execution."""

    from __future__ import annotations

    import base64

    from aiida import load_profile
    from aiida.orm.utils.serialize import deserialize_unsafe
    from node_graph import NodeGraph

    load_profile()

    from node_graph_engine.engines.airflow import AirflowEngine

    _PAYLOAD_DATA = """
    __PAYLOAD_BLOB__
    """.encode("utf-8")

    payload = deserialize_unsafe(base64.b64decode(_PAYLOAD_DATA))
    ng = NodeGraph.from_dict(payload["graph"])

    engine = AirflowEngine(
        dag_id=payload["dag_id"],
        default_args=payload["engine_config"].get("default_args"),
        schedule=payload["engine_config"].get("schedule"),
        start_date=payload["engine_config"].get("start_date"),
        catchup=payload["engine_config"].get("catchup", False),
        max_active_runs=payload["engine_config"].get("max_active_runs", 1),
        _default_user_email=payload["default_user_email"],
        schedule_subgraphs=payload.get("schedule_subgraphs", False),
    )

    dag = engine.build_dag(ng)
    '''
)


def _render_generated_dag(payload_blob: bytes) -> str:
    """Embed the serialized payload inside the generated DAG template."""

    encoded_payload = base64.b64encode(payload_blob).decode("ascii")
    wrapped_payload = textwrap.fill(encoded_payload, width=76)
    return _DAG_TEMPLATE.replace("__PAYLOAD_BLOB__", wrapped_payload)


def _write_dag_file(dag_path: Path, dag_source: str) -> None:
    """Persist the generated DAG to disk using an atomic replace."""

    dag_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = dag_path.with_suffix(f"{dag_path.suffix}.tmp")
    tmp_path.write_text(dag_source)
    os.replace(tmp_path, dag_path)


def _register_generated_dag(dag_id: str, dag_path: Path, dags_dir: Path) -> "DAG":
    """Parse the generated DAG file with a DagBag and return the DAG object.

    A sync of the bag to the metadata database is attempted; a failure there
    is only logged at debug level. ``RuntimeError`` is raised when the bag
    does not contain ``dag_id`` afterwards.
    """

    from airflow.models import DagBag  # type: ignore  # pragma: no cover - runtime import

    bag = DagBag(dag_folder=str(dags_dir), include_examples=False)
    bag.process_file(str(dag_path))

    try:  # pragma: no cover - requires Airflow metadata database
        bag.sync_to_db(bundle_name=dag_id, bundle_version=None)
    except Exception:
        sync_logger = logging.getLogger(__name__)
        sync_logger.debug(
            "Failed to sync DagBag for '%s' to metadata database", dag_id, exc_info=True
        )

    registered = bag.dags.get(dag_id)
    if registered is not None:
        return registered

    raise RuntimeError(
        f"Failed to register Airflow DAG '{dag_id}' for sub-graph execution"
    )


def _prepare_run_artifacts(
    *, paths: SchedulerPaths, dag_id: str, run_id: str, parent_pid: Optional[str]
) -> SchedulerRunArtifacts:
    """Create the run directory and assemble the ``conf`` for a scheduler run.

    The directory ``<run_root>/<dag_id>/<run_id>`` is created, and the
    returned ``run_conf`` points at ``result.json`` inside it. The parent
    process id is only included when one is given.
    """

    run_dir = paths.run_root.joinpath(dag_id, run_id)
    run_dir.mkdir(parents=True, exist_ok=True)
    result_file = run_dir / "result.json"

    conf: Dict[str, Any] = (
        {} if parent_pid is None else {"ng_parent_pid": parent_pid}
    )
    conf["ng_result_path"] = str(result_file)
    conf["ng_subgraph_run_id"] = run_id

    return SchedulerRunArtifacts(
        run_id=run_id, run_conf=conf, result_path=result_file
    )


def _load_task_sdk_interfaces(
    task_context: Optional[Dict[str, Any]]
) -> TaskSdkInterfaces:
    """Collect Task SDK communication handles when running inside Airflow.

    The supervisor channel is looked up on the task instance found in
    ``task_context`` first (``ti._task_runner.supervisor_comms``). If the
    Task SDK modules cannot be imported, only that channel (possibly
    ``None``) is returned; otherwise the module-level ``SUPERVISOR_COMMS``
    serves as fallback and the message classes are filled in.
    """

    # ``getattr`` with a default tolerates ``None`` at every hop of the chain.
    task_instance = task_context.get("ti") if task_context is not None else None
    runner = getattr(task_instance, "_task_runner", None)
    comms = getattr(runner, "supervisor_comms", None)

    try:  # pragma: no cover - requires Airflow runtime
        from airflow.sdk.execution_time.task_runner import (
            SUPERVISOR_COMMS as _SUPERVISOR_COMMS,
        )
        from airflow.sdk.execution_time.comms import (
            ErrorResponse as SDKErrorResponse,
            GetDagRunState as SDKGetDagRunState,
            OKResponse as SDKOKResponse,
            TriggerDagRun as SDKTriggerDagRun,
        )
        from airflow.sdk.exceptions import ErrorType as SDKErrorType
    except Exception:
        return TaskSdkInterfaces(supervisor_comms=comms)

    if comms is None:
        comms = _SUPERVISOR_COMMS

    return TaskSdkInterfaces(
        supervisor_comms=comms,
        trigger_cls=SDKTriggerDagRun,
        get_state_cls=SDKGetDagRunState,
        ok_response_cls=SDKOKResponse,
        error_response_cls=SDKErrorResponse,
        error_type_cls=SDKErrorType,
    )


def _trigger_via_task_sdk(
    interfaces: TaskSdkInterfaces, dag_id: str, run_id: str, run_conf: Dict[str, Any]
) -> Tuple[bool, Optional[BaseException]]:
    """Attempt to trigger a DAG through the Task SDK, returning success status."""

    if not interfaces.available:
        return False, None

    try:  # pragma: no cover - requires Airflow runtime
        response = interfaces.supervisor_comms.send(
            interfaces.trigger_cls(
                dag_id=dag_id,
                run_id=run_id,
                conf=run_conf,
                logical_date=None,
                reset_dag_run=True,
            )
        )
    except Exception as exc:  # pragma: no cover - runtime-specific failure
        logging.getLogger(__name__).debug(
            "TriggerDagRun request via Task SDK failed for '%s': %s",
            dag_id,
            exc,
            exc_info=True,
        )
        return False, exc

    if interfaces.error_response_cls is not None and isinstance(
        response, interfaces.error_response_cls
    ):
        error_detail = getattr(response, "error", None)
        if (
            interfaces.error_type_cls is not None
            and error_detail == interfaces.error_type_cls.DAGRUN_ALREADY_EXISTS
        ):
            raise RuntimeError(
                f"Airflow DAG '{dag_id}' already has a run with id '{run_id}'"
            )

        message = getattr(response, "message", None)
        detail = message or error_detail or response
        return False, RuntimeError(
            f"Failed to trigger Airflow DAG '{dag_id}' via Task SDK: {detail}"
        )

    if interfaces.ok_response_cls is not None and isinstance(
        response, interfaces.ok_response_cls
    ):
        return True, None

    return True, None


def _get_refresh_interval() -> Optional[float]:
    """Fetch Airflow's DAG processor refresh interval when available."""

    try:  # pragma: no cover - requires Airflow runtime configuration
        from airflow.configuration import conf as airflow_conf

        return airflow_conf.getfloat("dag_processor", "refresh_interval")
    except Exception:
        return None


def _trigger_dag_with_retries(
    *,
    dag_id: str,
    run_id: str,
    run_conf: Dict[str, Any],
    discovery_timeout: float,
    interfaces: TaskSdkInterfaces,
) -> None:
    """Trigger a DAG run, retrying until success or timeout.

    When ``interfaces.available`` is true every attempt goes through the Task
    SDK (``_trigger_via_task_sdk``); otherwise ``airflow.api.common.trigger_dag``
    is imported and called directly.

    Timing:
        * The overall budget defaults to ``discovery_timeout``, raised to
          twice the DAG processor refresh interval when that is known, and can
          be overridden with ``NG_AIRFLOW_TRIGGER_TIMEOUT``.
        * The pause between attempts comes from ``NG_AIRFLOW_TRIGGER_INTERVAL``
          (default 5.0) and is never shorter than 0.5 seconds.

    Raises:
        RuntimeError: the run id already exists (raised immediately, without
            further retries), or no attempt succeeded before the deadline. In
            the latter case the last recorded error is chained as the cause.
        ValueError: one of the environment overrides is not a valid float.
    """

    logger = logging.getLogger(__name__)

    # Give the DAG processor at least two refresh cycles to pick up the file.
    refresh_interval = _get_refresh_interval()
    default_trigger_timeout = discovery_timeout
    if refresh_interval:
        default_trigger_timeout = max(default_trigger_timeout, refresh_interval * 2.0)

    trigger_timeout = float(
        os.environ.get("NG_AIRFLOW_TRIGGER_TIMEOUT", default_trigger_timeout)
    )
    trigger_interval = float(os.environ.get("NG_AIRFLOW_TRIGGER_INTERVAL", 5.0))

    deadline = time.monotonic() + trigger_timeout
    attempt = 0
    last_error: Optional[BaseException] = None

    while time.monotonic() < deadline:
        attempt += 1

        # Path 1: inside an Airflow task, talk to the supervisor via the Task SDK.
        if interfaces.available:
            try:
                triggered, sdk_error = _trigger_via_task_sdk(
                    interfaces, dag_id, run_id, run_conf
                )
            except RuntimeError as exc:
                # Raised for an already-existing run id; not retryable.
                last_error = exc
                raise

            if triggered:
                logger.info(
                    "Successfully triggered DAG '%s' via Task SDK on attempt %d",
                    dag_id,
                    attempt,
                )
                return

            if sdk_error is not None:
                last_error = sdk_error
                logger.info(
                    "Task SDK trigger attempt %d for DAG '%s' failed (%s); retrying in %.1fs",
                    attempt,
                    dag_id,
                    sdk_error,
                    trigger_interval,
                )
            else:
                logger.info(
                    "Task SDK trigger attempt %d for DAG '%s' did not succeed; retrying in %.1fs",
                    attempt,
                    dag_id,
                    trigger_interval,
                )

            time.sleep(max(0.5, trigger_interval))
            continue

        # Path 2: no Task SDK available, call Airflow's trigger API directly.
        try:
            from airflow.api.common.trigger_dag import trigger_dag
        except Exception as import_exc:  # pragma: no cover - requires Airflow runtime
            last_error = import_exc
            logger.info(
                "Trigger mechanism unavailable for DAG '%s'; retrying in %.1fs (attempt %d)",
                dag_id,
                trigger_interval,
                attempt,
            )
            time.sleep(max(0.5, trigger_interval))
            continue

        # ``trigger_dag`` differs between Airflow versions, so the keyword
        # arguments are adapted to what the installed signature accepts.
        trigger_sig = inspect.signature(trigger_dag)
        trigger_kwargs: Dict[str, Any] = {
            "dag_id": dag_id,
            "run_id": run_id,
            "conf": run_conf,
            "logical_date": None,
            "replace_microseconds": False,
        }

        if "triggered_by" in trigger_sig.parameters:
            try:
                from airflow.utils.types import DagRunTriggeredByType

                trigger_kwargs["triggered_by"] = DagRunTriggeredByType.OPERATOR
            except Exception:
                # Nothing was set yet at this point, so this pop is a no-op
                # safeguard that leaves the argument out.
                trigger_kwargs.pop("triggered_by", None)
        else:
            # Signatures without ``triggered_by`` do not get an explicit
            # ``logical_date=None``.
            trigger_kwargs.pop("logical_date", None)

        # Drop anything the installed signature does not know about.
        supported_kwargs = {
            name: value
            for name, value in trigger_kwargs.items()
            if name in trigger_sig.parameters
        }

        try:
            trigger_dag(**supported_kwargs)
        except Exception as exc:  # pragma: no cover - requires Airflow runtime
            last_error = exc
            # A duplicate run id is detected by class name or message text
            # and reported immediately instead of being retried.
            if (
                exc.__class__.__name__ == "DagRunAlreadyExists"
                or "already exists" in str(exc).lower()
            ):
                raise RuntimeError(
                    f"Airflow DAG '{dag_id}' already has a run with id '{run_id}'"
                ) from exc

            logger.info(
                "Trigger attempt %d for DAG '%s' failed (%s); retrying in %.1fs",
                attempt,
                dag_id,
                exc,
                trigger_interval,
            )
            time.sleep(max(0.5, trigger_interval))
            continue

        logger.info(
            "Successfully triggered DAG '%s' after %d attempt(s)", dag_id, attempt
        )
        return

    # Deadline reached without success.
    if last_error is not None:
        raise RuntimeError(
            f"Failed to trigger Airflow DAG '{dag_id}' after {attempt} attempt(s)"
        ) from last_error

    raise RuntimeError(
        f"Failed to trigger Airflow DAG '{dag_id}' within {trigger_timeout:.1f}s"
    )


def _fetch_run_state(
    interfaces: TaskSdkInterfaces, dag_id: str, run_id: str
) -> Optional[str]:
    """Fetch the latest DAG run state via the Task SDK when available."""

    if not interfaces.available or interfaces.get_state_cls is None:
        return None

    try:  # pragma: no cover - requires Airflow runtime
        response = interfaces.supervisor_comms.send(
            interfaces.get_state_cls(dag_id=dag_id, run_id=run_id)
        )
    except Exception:
        return None

    if interfaces.error_response_cls is not None and isinstance(
        response, interfaces.error_response_cls
    ):
        return None

    state = getattr(response, "state", None)
    if state is None:
        return None
    return getattr(state, "value", state)


def _poll_for_result(
    *,
    dag_id: str,
    run_id: str,
    result_path: Path,
    poll_interval: float,
    deserialize_fn: Callable[[bytes], Any],
    interfaces: TaskSdkInterfaces,
    failure_state_values: Iterable[Optional[str]],
    success_state_values: Iterable[Optional[str]],
) -> Dict[str, Any]:
    """Poll the scheduler result file until outputs are available or a failure occurs."""

    deadline = time.monotonic() + max(poll_interval * 10, 3600)
    logger = logging.getLogger(__name__)

    while time.monotonic() < deadline:
        if result_path.exists():
            serialized_result = result_path.read_text()
            try:
                payload = deserialize_fn(serialized_result.encode("utf-8"))
            except Exception as exc:  # pragma: no cover - runtime specific
                raise RuntimeError(
                    f"Failed to load scheduler result payload for run '{run_id}'"
                ) from exc

            return payload if isinstance(payload, dict) else {}

        dag_state = _fetch_run_state(interfaces, dag_id, run_id)
        if dag_state in failure_state_values:
            raise RuntimeError(
                f"Airflow scheduler reported failure for DAG '{dag_id}' run '{run_id}'"
            )

        if dag_state in success_state_values and not result_path.exists():
            logger.debug(
                "Waiting for result payload for DAG '%s' run '%s' despite successful state",
                dag_id,
                run_id,
            )

        time.sleep(min(1.0, poll_interval))

    raise RuntimeError(
        "Airflow scheduler run completed but no result payload was produced"
    )


def _ensure_scheduler_detects_dag(
    *,
    dag_id: str,
    dag_path: Path,
    airflow_home: Path,
    dags_dir: Path,
    poll_interval: float,
    timeout: float,
) -> None:
    try:  # pragma: no cover - requires Airflow runtime
        from airflow.models.dagbag import DagBag
    except Exception:
        return

    os.environ.setdefault("AIRFLOW_HOME", str(airflow_home))
    os.environ.setdefault("AIRFLOW__CORE__DAGS_FOLDER", str(dags_dir))

    logger = logging.getLogger(__name__)
    if logger.isEnabledFor(logging.INFO):
        log_method = logger.info
    elif logger.isEnabledFor(logging.WARNING):
        log_method = logger.warning
    else:
        log_method = logger.error

    def log(message: str, *args: Any) -> None:
        log_method(message, *args)

    if not dag_path.exists():
        log("DAG path '%s' is missing while waiting for scheduler detection", dag_path)

    log(
        "Waiting for Airflow scheduler to parse DAG '%s' located at '%s'",
        dag_id,
        dag_path,
    )

    deadline = time.monotonic() + timeout
    last_error: Optional[Exception] = None
    attempt = 0

    while time.monotonic() < deadline:
        attempt += 1
        try:
            dagbag = DagBag(
                dag_folder=str(dags_dir),
                include_examples=False,
                safe_mode=False,
            )
        except Exception as exc:
            last_error = exc
            log(
                "DagBag refresh failed while waiting for '%s' (attempt %d): %s",
                dag_id,
                attempt,
                exc,
            )
            time.sleep(max(0.5, poll_interval))
            continue

        if getattr(dagbag, "import_errors", {}):
            log(
                "DagBag import errors while polling for '%s' (attempt %d): %s",
                dag_id,
                attempt,
                dagbag.import_errors,
            )

        if dag_id in getattr(dagbag, "dags", {}):
            log("DagBag detected DAG '%s' after %d attempt(s)", dag_id, attempt)
            return

        log(
            "DagBag did not detect DAG '%s' on attempt %d; sleeping for %.1fs",
            dag_id,
            attempt,
            max(0.5, poll_interval),
        )

        time.sleep(max(0.5, poll_interval))

    if last_error is not None:
        log(
            "Unable to confirm DAG '%s' registration before timeout after %d attempt(s): %s",
            dag_id,
            attempt,
            last_error,
        )
    else:
        log(
            "Timed out waiting for DAG '%s' to be detected after %d attempt(s)",
            dag_id,
            attempt,
        )


def _build_runtime_kwargs(
    *,
    incoming: Iterable[IncomingSpec],
    source_map: Dict[str, Any],
    target_name: str,
) -> Dict[str, Any]:
    grouped: Dict[str, List[IncomingSpec]] = {}
    for spec in incoming:
        if spec["target"] != target_name:
            continue
        grouped.setdefault(spec["target_socket"], []).append(spec)

    kwargs: Dict[str, Any] = {}
    for to_socket, specs in grouped.items():
        active_links = [s for s in specs if s["from_socket"] != "_wait"]
        if not active_links:
            continue
        if len(active_links) == 1:
            spec = active_links[0]
            from_payload = source_map.get(spec["from"], {})
            from_socket = spec["from_socket"]
            if from_socket == "_outputs":
                kwargs[to_socket] = from_payload
            else:
                kwargs[to_socket] = get_nested_dict(
                    from_payload, from_socket, default=None
                )
            continue
        bundle: Dict[str, Any] = {}
        for spec in active_links:
            from_socket = spec["from_socket"]
            if from_socket in ("_wait", "_outputs"):
                continue
            from_payload = source_map.get(spec["from"], {})
            bundle_key = f"{spec['from']}_{from_socket}"
            bundle[bundle_key] = get_nested_dict(
                from_payload, from_socket, default=None
            )
        if bundle:
            kwargs[to_socket] = bundle

    return kwargs


def _ensure_meta(meta: Any) -> "NodeTaskMeta":
    """Coerce ``meta`` into a ``NodeTaskMeta`` instance.

    Existing instances pass through untouched; dicts and objects exposing
    ``as_dict()`` are expanded into keyword arguments. Anything else raises
    ``TypeError``.
    """
    from ..core.task import NodeTaskMeta

    if isinstance(meta, NodeTaskMeta):
        return meta

    if isinstance(meta, dict):
        fields = meta
    elif hasattr(meta, "as_dict"):
        fields = meta.as_dict()
    else:
        raise TypeError(f"Unsupported metadata payload: {meta!r}")

    return NodeTaskMeta(**fields)


def _execute_node(
    *,
    parent_pid: Optional[str],
    meta: Any,
    callable_payload: Optional[Dict[str, Any]],
    literals: Dict[str, Any],
    incoming: Iterable[IncomingSpec],
    source_map: Dict[str, Any],
    engine_name: str,
    node_inputs: Optional[Dict[str, Any]],
    node_outputs: Optional[Dict[str, Any]],
    default_user_email: str,
    sub_engine_config: Dict[str, Any],
    schedule_subgraphs: bool,
    task_context: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Assemble a node's runtime inputs and run it through ``execute_node_job``.

    Literal inputs are overlaid with the values resolved from incoming links,
    then passed through ``update_nested_dict_with_special_keys``. Nested
    graphs get an ``AirflowEngine`` built from ``sub_engine_config``. The
    thread-local AiiDA session is closed once the job returns or raises.
    """
    meta_obj = _ensure_meta(meta)

    merged_inputs = dict(literals)
    linked_inputs = _build_runtime_kwargs(
        incoming=incoming,
        source_map=source_map,
        target_name=meta_obj.node_name,
    )
    merged_inputs.update(linked_inputs)
    runtime_inputs = update_nested_dict_with_special_keys(merged_inputs)

    user = load_default_user(default_user_email)

    def _build_sub_engine(name: str) -> "AirflowEngine":
        # Sub-graph engines reuse the parent's settings under a sanitized id.
        cfg = sub_engine_config
        engine_kwargs = {
            "dag_id": _sanitize_dag_id(name),
            "default_args": cfg.get("default_args"),
            "schedule": cfg.get("schedule"),
            "start_date": cfg.get("start_date"),
            "catchup": cfg.get("catchup", False),
            "max_active_runs": cfg.get("max_active_runs", 1),
            "_default_user_email": default_user_email,
            "schedule_subgraphs": cfg.get("schedule_subgraphs", schedule_subgraphs),
        }
        return AirflowEngine(**engine_kwargs)

    try:
        outputs = execute_node_job(
            parent_pid=parent_pid,
            meta=meta_obj,
            callable_payload=callable_payload,
            runtime_inputs=runtime_inputs,
            engine_name=engine_name,
            node_inputs=node_inputs,
            node_outputs=node_outputs,
            build_sub_engine=_build_sub_engine,
            user=user,
            schedule_subgraphs=schedule_subgraphs,
            task_context=task_context,
        )
    finally:
        close_threadlocal_aiida_session()

    return outputs


def _airflow_node_task(**context: Any) -> Dict[str, Any]:
    """Airflow task callable that executes a single NodeGraph node.

    The Airflow context is merged with the ``_ng_*`` entries supplied by the
    engine. The parent process id is resolved from, in order:
    ``_ng_parent_pid``, the init task's runtime context (``graph_pid``), and
    the DAG run ``conf`` (``ng_parent_pid``). Upstream results are pulled
    from XCom for every non-builtin source node named in ``_ng_incoming``.

    On failure the graph's process node (when known) is marked as failed and
    sealed before the exception is re-raised. Errors in that bookkeeping are
    logged but never replace the node's own exception.

    Returns:
        The node result, which is also pushed to XCom as ``return_value``.
    """
    runtime_context_task_id = context.get("_ng_runtime_context_task_id")

    ti = context.get("ti")
    runtime_context: Optional[Dict[str, Any]] = None
    if runtime_context_task_id and ti is not None:
        runtime_context = ti.xcom_pull(task_ids=runtime_context_task_id)

    parent_pid = context.get("_ng_parent_pid")
    if parent_pid is None and runtime_context:
        parent_pid = runtime_context.get("graph_pid")
    if parent_pid is None:
        # ``dag_run.conf`` can be ``None`` for runs triggered without a conf,
        # so guard it the same way the finalize task does.
        run_conf = getattr(context.get("dag_run"), "conf", None) or {}
        parent_pid = run_conf.get("ng_parent_pid")

    default_user_email = context.get("_ng_default_user_email")
    if default_user_email is None:
        default_user_email = get_default_user_email()

    base_values: Dict[str, Any]
    if runtime_context and "values" in runtime_context:
        base_values = dict(runtime_context.get("values", {}))
    else:
        base_values = dict(context.get("_ng_base_values", {}))

    source_map = dict(base_values)
    # Materialize once: the specs are walked here and again in _execute_node.
    incoming_specs: List[IncomingSpec] = list(context.get("_ng_incoming", []))
    upstream_ids = {spec["from"] for spec in incoming_specs}
    if ti is not None:
        for task_id in upstream_ids:
            if task_id in BUILTIN_NODES:
                continue
            pulled = ti.xcom_pull(task_ids=task_id)
            if pulled is not None:
                source_map[task_id] = pulled

    graph_pid = runtime_context.get("graph_pid") if runtime_context else None

    try:
        result = _execute_node(
            parent_pid=parent_pid,
            meta=context["_ng_meta"],
            callable_payload=context.get("_ng_callable"),
            literals=context.get("_ng_literals", {}),
            incoming=incoming_specs,
            source_map=source_map,
            engine_name=context.get("_ng_engine_name", "airflow"),
            node_inputs=context.get("_ng_node_inputs"),
            node_outputs=context.get("_ng_node_outputs"),
            default_user_email=default_user_email,
            sub_engine_config=context.get("_ng_engine_config", {}),
            schedule_subgraphs=context.get("_ng_schedule_subgraphs", False),
            task_context=context,
        )
    except Exception as exc:
        if graph_pid:
            try:
                process_node = orm.load_node(graph_pid)
                # Another failing task may already have sealed the node;
                # touching it again would raise and hide ``exc``.
                if not process_node.is_sealed:
                    mark_process_failure(process_node, exc)
                    process_node.seal()
            except Exception:
                logging.getLogger(__name__).exception(
                    "Failed to record node failure on graph process '%s'",
                    graph_pid,
                )
        raise

    if ti is not None:
        ti.xcom_push(key="return_value", value=result)
    return result


def _airflow_init_task(**context: Any) -> Dict[str, Any]:
    """Airflow task callable that sets up the graph-level run.

    Calls ``prepare_graph_run`` for the graph passed as ``_ng_graph`` and
    seeds the resulting values with the ``_ng_builtins`` entries that are not
    already present. The parent process id comes from ``_ng_parent_pid`` or,
    failing that, from the DAG run ``conf`` (``ng_parent_pid``).

    Returns:
        ``{"graph_pid": <process node uuid>, "values": <initial values>}``,
        which downstream tasks pull from XCom as their runtime context.
    """
    ng: NodeGraph = context["_ng_graph"]
    default_user_email = context.get("_ng_default_user_email")
    if default_user_email is None:
        default_user_email = get_default_user_email()
    user = load_default_user(default_user_email)

    parent_pid = context.get("_ng_parent_pid")
    if parent_pid is None:
        # ``dag_run.conf`` can be ``None`` for runs triggered without a conf,
        # so guard it the same way the finalize task does.
        run_conf = getattr(context.get("dag_run"), "conf", None) or {}
        parent_pid = run_conf.get("ng_parent_pid")

    graph_context = prepare_graph_run(
        ng,
        parent_pid=parent_pid,
        user=user,
        encode_graph_inputs=True,
    )

    # Values produced by prepare_graph_run win over the static builtins.
    builtins = dict(context.get("_ng_builtins", {}))
    for key, value in builtins.items():
        graph_context.values.setdefault(key, value)

    return {
        "graph_pid": graph_context.process_node.uuid,
        "values": dict(graph_context.values),
    }


def _airflow_finalize_task(**context: Any) -> Dict[str, Any]:
    """Airflow task callable that collects graph outputs and closes the run.

    Pulls the init task's runtime context and every non-builtin node result
    from XCom, computes the graph outputs, marks the process node as
    succeeded and, when the DAG run ``conf`` carries ``ng_result_path``,
    writes the serialized outputs to that file. The process node is sealed in
    every case.

    Raises:
        RuntimeError: no task instance is available, the runtime context or
            its graph PID is missing, or the process node is already marked
            as excepted.
        Exception: anything raised while computing or recording the outputs
            is re-raised after the process node has been marked as failed.
    """
    runtime_context_task_id = context["_ng_context_task_id"]
    node_task_ids: Iterable[str] = context.get("_ng_node_task_ids", [])
    incoming_specs: Dict[str, List[IncomingSpec]] = context.get("_ng_incoming", {})

    ti = context.get("ti")
    if ti is None:
        raise RuntimeError("Task instance context is required for finalize task")

    dag_run = context.get("dag_run")
    result_path_str: Optional[str] = None
    if dag_run is not None and getattr(dag_run, "conf", None):
        result_path_str = dag_run.conf.get("ng_result_path")

    runtime_context = ti.xcom_pull(task_ids=runtime_context_task_id)
    if not runtime_context:
        raise RuntimeError(
            "Missing NodeGraph runtime context; ensure init task succeeded"
        )

    graph_pid = runtime_context.get("graph_pid")
    if not graph_pid:
        raise RuntimeError("NodeGraph runtime context did not provide a graph PID")

    process_node = orm.load_node(graph_pid)

    if process_node.is_excepted:
        if not process_node.is_sealed:
            process_node.seal()
        raise RuntimeError(
            "NodeGraph execution failed; see upstream task logs for details"
        )

    values: Dict[str, Any] = dict(runtime_context.get("values", {}))
    for task_id in node_task_ids:
        if task_id in BUILTIN_NODES:
            continue
        pulled = ti.xcom_pull(task_ids=task_id)
        if pulled is not None:
            values[task_id] = pulled

    try:
        graph_outputs = compute_graph_outputs(
            incoming=incoming_specs,
            values=values,
            link_builder=lambda target, links, source_map: _build_runtime_kwargs(
                incoming=links,
                source_map=source_map,
                target_name=target,
            ),
        )
        mark_process_success(process_node, graph_outputs)
        if result_path_str:
            # Persisting the result file is best effort: a failure is logged
            # but does not fail an otherwise successful graph.
            try:
                from aiida.orm.utils.serialize import serialize

                result_payload = serialize(
                    {"graph_pid": graph_pid, "outputs": graph_outputs}
                )
                if not isinstance(result_payload, (str, bytes)):
                    result_payload = json.dumps(result_payload)
                if isinstance(result_payload, bytes):
                    result_payload = result_payload.decode("utf-8")

                result_path = Path(result_path_str)
                result_path.parent.mkdir(parents=True, exist_ok=True)
                # Stage next to the target and move into place so a poller
                # that checks for the file never reads a partial payload.
                staging_path = result_path.with_name(f"{result_path.name}.tmp")
                staging_path.write_text(result_payload)
                os.replace(staging_path, result_path)
            except Exception:
                logging.getLogger(__name__).exception(
                    "Failed to persist NodeGraph scheduler results to %s",
                    result_path_str,
                )
        return graph_outputs
    except Exception as exc:
        mark_process_failure(process_node, exc)
        raise
    finally:
        if not process_node.is_sealed:
            process_node.seal()


@dataclass
class _CompiledDag:
    """Result of compiling a NodeGraph into an Airflow DAG."""

    # The Airflow DAG object.
    dag: "DAG"
    # Node names in execution order - presumably the order returned by
    # ``_scan_links_topology``; confirm against ``AirflowEngine._compile``.
    order: Tuple[str, ...]
    # Incoming links per target, as returned by ``_scan_links_topology``
    # (assumed; the assignment is outside the visible part of ``_compile``).
    incoming: Dict[str, Any]
    # The same links flattened to plain ``IncomingSpec`` dicts per target.
    incoming_specs: Dict[str, List[IncomingSpec]]
    # Per-task configuration, keyed by task id (TODO confirm key semantics).
    task_configs: Dict[str, Dict[str, Any]]
    # Values for builtin nodes (TODO confirm against ``_compile``).
    builtins: Dict[str, Any]


class AirflowEngine(BaseEngine):
    """Build Airflow DAGs from NodeGraph workflows."""

    engine_kind = "airflow"

    def __init__(
        self,
        dag_id: str = "node-graph-dag",
        *,
        default_args: Optional[Dict[str, Any]] = None,
        schedule: Optional[Any] = None,
        start_date: Optional[Any] = None,
        catchup: bool = False,
        max_active_runs: int = 1,
        _default_user_email: Optional[str] = None,
        schedule_subgraphs: bool = False,
    ) -> None:
        """Store the DAG settings this engine applies when compiling graphs.

        Args:
            dag_id: Identifier used for the engine and for generated DAGs.
            default_args: Optional mapping of DAG default arguments; a shallow
                copy is kept on the instance.
            schedule: Optional schedule value handed to the DAG constructor.
            start_date: Optional DAG start date.
            catchup: Value for the DAG ``catchup`` option.
            max_active_runs: Limit on concurrently active DAG runs.
            _default_user_email: Email of the default user; when falsy it is
                looked up through ``get_default_user_email()``.
            schedule_subgraphs: Default for whether sub-graphs are scheduled.
        """
        super().__init__(dag_id)
        self.dag_id = dag_id

        # Shallow copy, so the caller's mapping is never aliased by the engine.
        self.default_args = dict(default_args or {})
        self.schedule, self.start_date = schedule, start_date
        self.catchup, self.max_active_runs = catchup, max_active_runs
        self.schedule_subgraphs = schedule_subgraphs

        # Only query the profile's default user when no email was supplied.
        if _default_user_email:
            self._default_user_email = _default_user_email
        else:
            self._default_user_email = get_default_user_email()

    def _compile(
        self,
        ng: NodeGraph,
        *,
        base_values: Optional[Dict[str, Any]] = None,
        runtime_context_task_id: Optional[str] = None,
        schedule_subgraphs: Optional[bool] = None,
    ) -> _CompiledDag:
        """Build the Airflow DAG for ``ng`` together with its run metadata.

        Args:
            ng: Graph to compile.
            base_values: Optional starting values; they are layered on top of
                the builtin snapshot. When omitted the snapshot is used as is.
            runtime_context_task_id: When given, each node task receives this
                task id (``_ng_runtime_context_task_id``) instead of the
                static ``_ng_base_values``.
            schedule_subgraphs: Overrides the engine-level default when not
                ``None``.

        Returns:
            A ``_CompiledDag`` holding the DAG, the node order, the raw and
            dict-form incoming links, the per-node task configuration and the
            builtin snapshot.

        Raises:
            RuntimeError: If a schedule is configured but the DAG constructor
                exposes none of the known schedule parameters.
        """
        if schedule_subgraphs is None:
            schedule_subgraphs = self.schedule_subgraphs

        order, incoming, _ = _scan_links_topology(ng)

        # Plain-dict description of every link, grouped by the receiving node.
        incoming_specs: Dict[str, List[IncomingSpec]] = {
            target: [
                {
                    "from": link.from_node.name,
                    "from_socket": link.from_socket._scoped_name,
                    "target": link.to_node.name,
                    "target_socket": link.to_socket._scoped_name,
                }
                for link in links
            ]
            for target, links in incoming.items()
        }

        # Inspect the DAG constructor so only keyword arguments it can accept
        # are passed; a ``**kwargs`` catch-all means everything is accepted.
        dag_params = inspect.signature(DAG.__init__).parameters
        takes_var_keyword = False
        for parameter in dag_params.values():
            if parameter.kind == inspect.Parameter.VAR_KEYWORD:
                takes_var_keyword = True
                break

        def _is_supported(option: str) -> bool:
            return takes_var_keyword or option in dag_params

        dag_kwargs: Dict[str, Any] = {"dag_id": self.dag_id}
        if _is_supported("default_args"):
            dag_kwargs["default_args"] = dict(self.default_args)
        if _is_supported("catchup"):
            dag_kwargs["catchup"] = self.catchup

        # First matching spelling of the run-limit option wins.
        for limit_option in ("max_active_runs", "max_active_runs_per_dag"):
            if _is_supported(limit_option):
                dag_kwargs[limit_option] = self.max_active_runs
                break

        # Scheduled sub-graphs without an explicit start date start slightly
        # in the past.
        start = self.start_date
        if start is None and schedule_subgraphs:
            start = datetime.now(timezone.utc) - timedelta(minutes=1)
        if start is not None and _is_supported("start_date"):
            dag_kwargs["start_date"] = start
        if _is_supported("is_paused_upon_creation"):
            dag_kwargs["is_paused_upon_creation"] = False

        schedule_value = self.schedule
        if schedule_value is not None:
            schedule_option = next(
                (
                    option
                    for option in ("schedule", "schedule_interval", "timetable")
                    if _is_supported(option)
                ),
                None,
            )
            if schedule_option is None:
                raise RuntimeError(
                    "Unsupported DAG signature: unable to determine schedule parameter"
                )
            dag_kwargs[schedule_option] = schedule_value

        dag = DAG(**dag_kwargs)

        # Caller-supplied values take precedence over the builtin snapshot.
        builtin_snapshot = self._snapshot_builtins(ng)
        if base_values is not None:
            layered_values = dict(builtin_snapshot)
            layered_values.update(base_values)
            base_values = layered_values
        else:
            base_values = builtin_snapshot

        task_configs: Dict[str, Dict[str, Any]] = {}
        tasks: Dict[str, "PythonOperator"] = {}

        for name in ng.get_node_names():
            if name in BUILTIN_NODES:
                continue
            node = ng.nodes[name]
            spec = node.spec

            is_graph_node = (getattr(spec, "node_type", "") or "").upper() == "GRAPH"
            label_kind = "return" if is_graph_node else "create"

            executor = getattr(spec, "executor", None)
            callable_payload = None
            if executor is not None:
                callable_payload = executor.to_dict()

            node_links = incoming_specs.get(name, [])
            meta = self._build_node_task_meta(node, label_kind).as_dict()
            inputs_payload = spec.inputs.to_dict() if spec.inputs else None
            outputs_payload = spec.outputs.to_dict() if spec.outputs else None
            literals = _collect_literals(node)
            engine_config = {
                "default_args": self.default_args,
                "schedule": self.schedule,
                "start_date": self.start_date,
                "catchup": self.catchup,
                "max_active_runs": self.max_active_runs,
                "schedule_subgraphs": schedule_subgraphs,
            }

            op_kwargs = {
                "_ng_meta": meta,
                "_ng_callable": callable_payload,
                "_ng_engine_name": self.dag_id,
                "_ng_node_inputs": inputs_payload,
                "_ng_node_outputs": outputs_payload,
                "_ng_incoming": node_links,
                "_ng_literals": literals,
                "_ng_default_user_email": self._default_user_email,
                "_ng_engine_config": engine_config,
                "_ng_schedule_subgraphs": schedule_subgraphs,
            }
            # A runtime-context task replaces the statically embedded values.
            if runtime_context_task_id is None:
                op_kwargs["_ng_base_values"] = base_values
            else:
                op_kwargs["_ng_runtime_context_task_id"] = runtime_context_task_id

            tasks[name] = PythonOperator(
                task_id=name,
                dag=dag,
                python_callable=_airflow_node_task,
                op_kwargs=op_kwargs,
            )

            task_configs[name] = {
                "meta": meta,
                "callable": callable_payload,
                "literals": literals,
                "incoming": node_links,
                "node_inputs": inputs_payload,
                "node_outputs": outputs_payload,
                "engine_config": engine_config,
                "base_values": base_values,
                # Builtin nodes are not tasks, so they are not upstream tasks.
                "upstream": {
                    link.from_node.name
                    for link in incoming.get(name, [])
                    if link.from_node.name not in BUILTIN_NODES
                },
                "schedule_subgraphs": schedule_subgraphs,
            }

        # Wire dependencies only after every operator exists.
        for downstream_name, downstream_task in tasks.items():
            for link in incoming.get(downstream_name, []):
                upstream_task = tasks.get(link.from_node.name)
                if upstream_task is not None:
                    upstream_task >> downstream_task

        return _CompiledDag(
            dag=dag,
            order=tuple(order),
            incoming=incoming,
            incoming_specs=incoming_specs,
            task_configs=task_configs,
            builtins=builtin_snapshot,
        )

    def build_dag(self, ng: NodeGraph) -> DAG:
        """Compile ``ng`` into an Airflow DAG without running any node.

        The node tasks are bracketed by two extra tasks: ``engine__init``,
        which every node task depends on, and ``engine__finalize``, which
        depends on the init task and on every node task. Sub-graph scheduling
        is always enabled for DAGs built this way.
        """
        init_id = "engine__init"
        finalize_id = "engine__finalize"

        compiled = self._compile(
            ng,
            runtime_context_task_id=init_id,
            schedule_subgraphs=True,
        )
        dag = compiled.dag

        init_kwargs = {
            "_ng_graph": ng,
            "_ng_default_user_email": self._default_user_email,
            "_ng_builtins": compiled.builtins,
        }
        init_task = PythonOperator(
            task_id=init_id,
            dag=dag,
            python_callable=_airflow_init_task,
            op_kwargs=init_kwargs,
        )

        # Keep the compiled order, restricted to nodes that became tasks.
        node_task_ids = []
        for node_name in compiled.order:
            if node_name in BUILTIN_NODES:
                continue
            if node_name in dag.task_dict:
                node_task_ids.append(node_name)

        finalize_kwargs = {
            "_ng_context_task_id": init_id,
            "_ng_node_task_ids": node_task_ids,
            "_ng_incoming": compiled.incoming_specs,
        }
        # "all_done" lets the finalize step run even when a node task failed.
        finalize_task = PythonOperator(
            task_id=finalize_id,
            dag=dag,
            python_callable=_airflow_finalize_task,
            op_kwargs=finalize_kwargs,
            trigger_rule="all_done",
        )

        init_task >> finalize_task
        for node_task_id in node_task_ids:
            init_task >> dag.task_dict[node_task_id] >> finalize_task

        return dag

    def run(self, ng: NodeGraph, parent_pid: Optional[str] = None) -> Dict[str, Any]:
        """Execute ``ng`` in the current process, node by node.

        The graph is compiled to obtain the node order and per-node
        configuration, but the nodes are executed directly here rather than
        being handed to an Airflow scheduler.

        Args:
            ng: Graph to execute.
            parent_pid: Optional identifier of a parent process, forwarded to
                ``prepare_graph_run``.

        Returns:
            The graph outputs computed from the collected node results.

        Raises:
            Exception: Any error raised while executing a node or computing
                the outputs is recorded on the process node and re-raised.
        """
        context = prepare_graph_run(
            ng,
            parent_pid=parent_pid,
        )
        process_node = context.process_node
        self._graph_pid = process_node.uuid

        compiled = self._compile(
            ng, base_values=context.values, schedule_subgraphs=self.schedule_subgraphs
        )
        values = dict(context.values)
        # In-process stand-in for the XCom results of already executed nodes.
        xcom_store: Dict[str, Any] = {}

        try:
            for name in iterate_node_order(compiled.order):
                if name in BUILTIN_NODES:
                    continue
                config = compiled.task_configs[name]
                source_map = dict(values)
                source_map.update({k: xcom_store.get(k) for k in config["upstream"]})
                result = _execute_node(
                    parent_pid=process_node.uuid,
                    meta=config["meta"],
                    callable_payload=config["callable"],
                    literals=config["literals"],
                    incoming=config["incoming"],
                    source_map=source_map,
                    engine_name=self.dag_id,
                    node_inputs=config["node_inputs"],
                    node_outputs=config["node_outputs"],
                    default_user_email=self._default_user_email,
                    sub_engine_config=config["engine_config"],
                    schedule_subgraphs=config.get("schedule_subgraphs", False),
                )
                xcom_store[name] = result
                values[name] = result

            graph_outputs = compute_graph_outputs(
                incoming=compiled.incoming_specs,
                values=values,
                link_builder=lambda target, links, source_map: _build_runtime_kwargs(
                    incoming=links,
                    source_map=source_map,
                    target_name=target,
                ),
            )
            mark_process_success(process_node, graph_outputs)
            return graph_outputs
        except Exception as exc:
            mark_process_failure(process_node, exc)
            raise
        finally:
            # Same guard as the scheduler finalize path: never attempt to seal
            # a node that is already sealed, so cleanup cannot interfere with
            # the result or with an exception that is propagating.
            if not process_node.is_sealed:
                process_node.seal()

    def run_via_scheduler(
        self,
        ng: NodeGraph,
        parent_pid: Optional[str] = None,
        task_context: Optional[Dict[str, Any]] = None,
        wait: bool = False,
    ) -> Optional[Dict[str, Any]]:
        """Write a DAG file for ``ng`` and trigger it through an Airflow scheduler.

        The engine's ``dag_id`` is sanitized first (and updated on the
        instance if it changed). A generated DAG file is then written, the
        method waits for the scheduler to detect it, and a run is triggered.

        Environment variables read (all parsed as floats, in seconds —
        presumably; confirm against the polling helpers):
        ``NG_AIRFLOW_DAG_DISCOVERY_INTERVAL`` (default 2.0),
        ``NG_AIRFLOW_DAG_DISCOVERY_TIMEOUT`` (default 60.0) and
        ``NG_AIRFLOW_POLL_INTERVAL`` (default 2.0).

        Args:
            ng: Graph to schedule.
            parent_pid: Optional parent process identifier stored with the
                run artifacts.
            task_context: Optional Airflow task context used to load the Task
                SDK communication interfaces.
            wait: When true, block until a result payload is available and
                return its outputs; otherwise return right after triggering.

        Returns:
            ``None`` when ``wait`` is false. Otherwise the ``"outputs"`` entry
            of the result payload, or an empty dict if that entry is not a
            dict.

        Raises:
            RuntimeError: If the required Airflow or AiiDA modules cannot be
                imported.
        """
        try:
            from airflow.utils.state import DagRunState
            from aiida.orm.utils.serialize import deserialize_unsafe, serialize
        except Exception as exc:  # pragma: no cover - requires Airflow runtime
            raise RuntimeError(
                "Airflow scheduler components are required to schedule sub-graphs"
            ) from exc

        logger = logging.getLogger(__name__)

        dag_id = _sanitize_dag_id(self.dag_id)
        if dag_id != self.dag_id:
            self.dag_id = dag_id

        paths = _resolve_scheduler_paths(dag_id)
        dag_path = paths.dags_dir / f"{dag_id}.py"

        # Without an explicit start date, start slightly in the past.
        effective_start_date = self.start_date or (
            datetime.now(timezone.utc) - timedelta(minutes=1)
        )

        payload_blob = _build_scheduler_payload(
            dag_id=dag_id,
            ng=ng,
            engine_config={
                "default_args": self.default_args,
                "schedule": self.schedule,
                "start_date": effective_start_date,
                "catchup": self.catchup,
                "max_active_runs": self.max_active_runs,
            },
            default_user_email=self._default_user_email,
            schedule_subgraphs=self.schedule_subgraphs,
            serialize_fn=serialize,
        )

        dag_source = _render_generated_dag(payload_blob)
        _write_dag_file(dag_path, dag_source)

        discovery_interval = float(
            os.environ.get("NG_AIRFLOW_DAG_DISCOVERY_INTERVAL", 2.0)
        )
        discovery_timeout = float(
            os.environ.get("NG_AIRFLOW_DAG_DISCOVERY_TIMEOUT", 60.0)
        )

        _ensure_scheduler_detects_dag(
            dag_id=dag_id,
            dag_path=dag_path,
            airflow_home=paths.airflow_home,
            dags_dir=paths.dags_dir,
            poll_interval=discovery_interval,
            timeout=discovery_timeout,
        )

        _register_generated_dag(dag_id, dag_path, paths.dags_dir)

        run_id = f"ng_subgraph__{uuid.uuid4().hex}"
        artifacts = _prepare_run_artifacts(
            paths=paths, dag_id=dag_id, run_id=run_id, parent_pid=parent_pid
        )

        # Parsed before triggering so an invalid value fails before a run starts.
        poll_interval = float(os.environ.get("NG_AIRFLOW_POLL_INTERVAL", 2.0))

        # Resolve terminal states defensively: not every Airflow version
        # defines all of these members, so missing ones are simply dropped.
        failure_state_values = {
            getattr(state, "value", state)
            for state in (
                getattr(DagRunState, member, None)
                for member in ("FAILED", "CANCELLED", "REMOVED")
            )
            if state is not None
        }
        success_value = getattr(getattr(DagRunState, "SUCCESS", None), "value", None)
        success_state_values = {success_value} if success_value else set()

        interfaces = _load_task_sdk_interfaces(task_context)

        _trigger_dag_with_retries(
            dag_id=dag_id,
            run_id=run_id,
            run_conf=artifacts.run_conf,
            discovery_timeout=discovery_timeout,
            interfaces=interfaces,
        )

        if not wait:
            logger.info(
                "Triggered Airflow DAG '%s' run '%s' without waiting for completion.",
                dag_id,
                run_id,
            )
            logger.info("Visit the Airflow UI to monitor progress.")
            return None

        payload = _poll_for_result(
            dag_id=dag_id,
            run_id=run_id,
            result_path=artifacts.result_path,
            poll_interval=poll_interval,
            deserialize_fn=deserialize_unsafe,
            interfaces=interfaces,
            failure_state_values=failure_state_values,
            success_state_values=success_state_values,
        )

        # Remember which graph process the scheduled run created, if reported.
        graph_pid = payload.get("graph_pid")
        if graph_pid:
            self._graph_pid = graph_pid

        outputs = payload.get("outputs")
        if not isinstance(outputs, dict):
            outputs = {}
        return outputs
