from __future__ import annotations

import os
from abc import abstractmethod
from collections.abc import Callable, Iterable, Mapping, MutableMapping
from dataclasses import replace
from datetime import timedelta
from functools import cached_property
from pathlib import Path
from subprocess import STDOUT, CalledProcessError, check_output
from types import TracebackType
from typing import (
    Any,
    Literal,
    Optional,
    TypeVar,
)

import pandas as pd
from atoti_core import (
    DEFAULT_QUERY_TIMEOUT,
    EMPTY_MAPPING,
    LICENSE_KEY,
    LICENSE_KEY_ENV_VAR_NAME,
    PLUGINS,
    ActiveViamClient,
    BaseSession,
    Context,
    CoordinatesT,
    DataType,
    MissingPluginError,
    PathLike,
    Plugin,
    doc,
)
from atoti_query import QuerySession
from atoti_query._internal import Security
from py4j.java_gateway import DEFAULT_PORT as _PY4J_DEFAULT_PORT

from ._basic_credentials import BasicCredentials
from ._docs_utils import EXPLAIN_QUERY_DOC
from ._endpoint import EndpointHandler
from ._exceptions import AtotiException, AtotiJavaException
from ._get_java_executable_path import get_java_executable_path
from ._is_jwt_expired import is_jwt_expired
from ._java_api import JavaApi
from ._local_cube import LocalCube
from ._local_cubes import LocalCubes
from ._path_utils import to_absolute_path
from ._query_plan import QueryAnalysis
from ._runtime_type_checking_utils import typecheck
from ._server_subprocess import ServerSubprocess
from .config._session_config import SessionConfig
from .table import Table, _LoadKafka, _LoadSql

# Type of the session's cube collection, as returned by ``LocalSession.cubes``;
# bound to ``LocalCubes`` of any ``LocalCube`` so subclasses can narrow it.
_LocalCubes = TypeVar("_LocalCubes", bound=LocalCubes[LocalCube[Any, Any, Any]])


def _add_plugin_app_extensions_to_config(
    config: SessionConfig, /, *, plugins: Mapping[str, Plugin]
) -> SessionConfig:
    config_app_extensions = {**(config.app_extensions or {})}
    app_extensions = config_app_extensions.copy()

    for plugin_key, plugin in plugins.items():
        for extension_name, extension_path in plugin.app_extensions.items():
            if extension_name in config_app_extensions:
                raise ValueError(
                    f"App extension `{extension_name}` is declared both in the session's configuration and in the plugin `{plugin_key}`."
                )
            if extension_name in app_extensions:
                raise ValueError(
                    f"App extension `{extension_name}` is declared in multiple plugins."
                )
            app_extensions[extension_name] = extension_path

    return replace(config, app_extensions=app_extensions)


@typecheck
class LocalSession(BaseSession[_LocalCubes, Security]):
    """Local session class.

    The session drives an Atoti server through a :class:`JavaApi`.
    The server is either a detached process already listening on the default Py4J port (useful for debugging) or a :class:`ServerSubprocess` started by the session itself.
    """

    def __init__(
        self,
        *,
        config: SessionConfig,
        distributed: bool,
        license_key: Optional[str],
        name: Optional[str],
        plugins: Optional[Mapping[str, Plugin]],
    ):
        """Connect to (or start) the server, initialize the plugins, and start the application.

        Args:
            config: The session configuration.
                The app extensions declared by *plugins* are added to it.
            distributed: Forwarded to :class:`JavaApi`.
            license_key: The license key.
                When falsy and ``LICENSE_KEY.use_env_var`` is true, it is read from the environment variable named ``LICENSE_KEY_ENV_VAR_NAME``.
            name: The name of the session.
            plugins: The plugins to initialize the session with.
                ``PLUGINS.default`` is used when ``None``.

        Raises:
            ValueError: If an app extension is declared more than once across *config* and *plugins*.
            AtotiException: If the server fails to configure the session.
        """
        super().__init__()

        self._name = name

        # Default loaders that raise ``MissingPluginError``.
        # NOTE(review): presumably the Kafka/SQL plugins install real loaders in ``init_session`` — confirm.
        def load_kafka(
            table: Table,  # noqa: ARG001
            /,
            bootstrap_server: str,  # noqa: ARG001
            topic: str,  # noqa: ARG001
            *,
            group_id: str,  # noqa: ARG001
            batch_duration: int,  # noqa: ARG001
            consumer_config: Mapping[str, str],  # noqa: ARG001
        ) -> None:
            raise MissingPluginError("kafka")

        self._load_kafka: _LoadKafka = load_kafka

        def load_sql(
            table: Table,  # noqa: ARG001
            /,
            sql: str,  # noqa: ARG001
            *,
            url: str,  # noqa: ARG001
            driver: Optional[str] = None,  # noqa: ARG001
        ) -> None:
            raise MissingPluginError("sql")

        self._load_sql: _LoadSql = load_sql

        if plugins is None:
            plugins = PLUGINS.default

        self._plugins = plugins
        self._config = _add_plugin_app_extensions_to_config(config, plugins=plugins)
        # Lazily generated and refreshed by the ``_jwt`` property.
        self.__jwt: Optional[str] = None

        if not license_key and LICENSE_KEY.use_env_var:
            license_key = os.environ.get(LICENSE_KEY_ENV_VAR_NAME)

        self._create_subprocess_and_java_api(
            distributed=distributed, license_key=license_key
        )

        for plugin in plugins.values():
            plugin.init_session(self)

        if license_key:
            self._java_api.gateway.jvm.io.atoti.plugins.PlusPlugin.init()
        try:
            self._start_application()
        except AtotiJavaException as ave:
            # ``logs_path`` asserts that this session owns a server subprocess.
            # Reading it unconditionally while attached to a detached process would
            # raise an ``AssertionError`` here and hide the actual configuration error.
            logs_hint = (
                f"The logs are available at {self.logs_path}"
                if self._server_subprocess
                else "The logs path is not available when using a detached server process."
            )
            raise AtotiException(
                f"{ave.java_traceback}\n"
                f"An error occurred while configuring the session.\n"
                f"{logs_hint}"
            ) from None

        self._closed = False

    def _create_subprocess_and_java_api(
        self, *, distributed: bool, license_key: Optional[str]
    ) -> None:
        """Set ``_java_api`` and ``_server_subprocess``.

        A detached process listening on the default Py4J port is used when one accepts the connection.
        Otherwise, a new :class:`ServerSubprocess` is started and ``_server_subprocess`` references it.
        """
        try:
            # Attempt to connect to an existing detached process (useful for debugging).
            # Failed attempts are very fast (usually less than 2ms): users won't notice them.
            self._java_api = JavaApi(
                py4j_java_port=_PY4J_DEFAULT_PORT,
                distributed=distributed,
            )
            self._server_subprocess = None
        except ConnectionRefusedError:
            # No available unauthenticated detached process: creating subprocess.
            process = ServerSubprocess(
                config=self._config,
                license_key=license_key,
                plugins=self._plugins,
                session_id=self._id,
            )
            self._java_api = JavaApi(
                auth_token=process.auth_token,
                py4j_java_port=process.py4j_java_port,
                distributed=distributed,
            )
            self._server_subprocess = process

    @cached_property
    def __client(self) -> ActiveViamClient:
        """HTTP client for the local server, authenticating requests with the session's JWT."""
        return ActiveViamClient(
            self._local_url,
            auth=lambda _url: self._generate_auth_headers(),
            certificate_authority=Path(self._config.https.certificate_authority)
            if self._config.https and self._config.https.certificate_authority
            else None,
        )

    @property
    def _client(self) -> ActiveViamClient:
        return self.__client

    @property
    def name(self) -> Optional[str]:
        """Name of the session."""
        return self._name

    @property
    @abstractmethod
    def cubes(self) -> _LocalCubes:
        """Cubes of the session."""

    @property
    def security(self) -> Security:
        return self._security

    @property
    def _security(self) -> Security:
        return self.__security

    @cached_property
    def __security(self) -> Security:
        # Built once, on first access, from the session's basic credentials and HTTP client.
        return Security(basic_credentials=self._basic_credentials, client=self._client)

    @property
    def closed(self) -> bool:
        """Return whether the session is closed or not."""
        return self._closed

    @cached_property
    def port(self) -> int:
        """Port on which the session is exposed.

        Can be configured with :func:`atoti.Session`'s *port* parameter.

        """
        return self._java_api.get_session_port()

    @cached_property
    def logs_path(self) -> Path:
        """Path to the session logs file.

        Only available when the session started its own server subprocess (not when attached to a detached server process).
        """
        assert (
            self._server_subprocess
        ), "The logs path is not available when using a detached server process."
        return self._server_subprocess.logs_path

    def _start_application(self) -> None:
        """Ask the server to start the application with the session's configuration."""
        self._java_api.start_application(self._config)

    def __exit__(  # pylint: disable=too-many-positional-parameters
        self,
        exc_type: Optional[type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        # Leaving a ``with`` block always closes the session, whether or not an exception occurred.
        self.close()

    def _clear(self) -> None:
        """Clear this session and free all the associated resources."""
        self._java_api.clear_session()

    def close(self) -> None:
        """Close this session and free all the associated resources.

        Calling this method on an already closed session does nothing.
        """
        # Avoid shutting down the Java API a second time, e.g. when ``close()``
        # is called explicitly inside a ``with`` block that also closes on exit.
        if self._closed:
            return
        self._java_api.shutdown()
        if self._server_subprocess:
            self.wait()
        self._closed = True

    def wait(self) -> None:
        """Wait for the underlying server subprocess to terminate.

        This will prevent the Python process from exiting.
        Only available when the session started its own server subprocess (not when attached to a detached server process).
        """
        assert self._server_subprocess
        self._server_subprocess.wait()

    @property
    def _location(self) -> Mapping[str, Any]:
        """Whether HTTPS is configured and the port the session is exposed on."""
        return {
            "https": self._config.https is not None,
            "port": self.port,
        }

    @cached_property
    def _basic_credentials(self) -> Optional[MutableMapping[str, str]]:
        return BasicCredentials(java_api=self._java_api)

    @property
    def _local_url(self) -> str:
        """Base URL of the server: the configured HTTPS domain when HTTPS is set, ``localhost`` over HTTP otherwise."""
        return (
            f"https://{self._config.https.domain}:{self.port}"
            if self._config.https is not None
            else f"http://localhost:{self.port}"
        )

    def _generate_token(self) -> str:
        """Return a token that can be used to authenticate against the server."""
        return self._java_api.generate_jwt()

    def _block_until_widget_loaded(self, widget_id: str) -> None:
        """Delegate to the Java API's ``block_until_widget_loaded`` for *widget_id*."""
        self._java_api.block_until_widget_loaded(widget_id)

    def _create_query_session(self) -> QuerySession:
        """Create a :class:`QuerySession` on the local server reusing this session's HTTP client."""
        return QuerySession(
            self._local_url,
            client=self._client,  # Sharing the client to avoid refetching the server versions.
        )

    def query_mdx(
        self,
        mdx: str,
        *,
        keep_totals: bool = False,
        timeout: timedelta = DEFAULT_QUERY_TIMEOUT,
        mode: Literal["pretty", "raw"] = "pretty",
        context: Context = EMPTY_MAPPING,
    ) -> pd.DataFrame:
        # Data types are resolved from the session's own cubes instead of over HTTP.
        def get_data_types(
            coordinates: Iterable[CoordinatesT], /, *, cube_name: str
        ) -> dict[CoordinatesT, DataType]:
            return self.cubes[cube_name]._get_data_types(coordinates)

        return self._create_query_session().query_mdx(
            mdx,
            get_data_types=get_data_types,
            keep_totals=keep_totals,
            timeout=timeout,
            session=self,
            mode=mode,
            context=context,
        )

    @doc(EXPLAIN_QUERY_DOC, corresponding_method="query_mdx")
    def explain_mdx_query(
        self, mdx: str, *, timeout: timedelta = DEFAULT_QUERY_TIMEOUT
    ) -> QueryAnalysis:
        return self._java_api.analyze_mdx(mdx, timeout=timeout)

    def _generate_auth_headers(self) -> dict[str, str]:
        """Return the HTTP headers authenticating a request with the session's JWT."""
        return {"Authorization": f"Jwt {self._jwt}"}

    @property
    def _jwt(self) -> str:
        """JWT for the server, (re)generated when missing or expired."""
        if not self.__jwt or is_jwt_expired(self.__jwt):
            self.__jwt = self._java_api.generate_jwt()
        return self.__jwt

    def endpoint(
        self, route: str, *, method: Literal["POST", "GET", "PUT", "DELETE"] = "GET"
    ) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
        """Create a custom endpoint at ``/atoti/pyapi/{route}``.

        This is useful to reuse Atoti's built-in server instead of adding a `FastAPI <https://fastapi.tiangolo.com/>`__ or `Flask <https://flask.palletsprojects.com/>`__ server to the project.
        This way, when deploying the project in a container or a VM, only one port (the one of the Atoti server) can be exposed instead of two.
        Since custom endpoints are exposed by Atoti's server, they automatically inherit from the configured :func:`atoti.Session`'s *authentication* and *https* parameters.

        The decorated function must take three parameters with types :class:`atoti.pyapi.User`, :class:`atoti.pyapi.HttpRequest`, and :class:`atoti.Session` and return a response body as a Python data structure that can be converted to JSON.

        Args:
            route: The path suffix after ``/atoti/pyapi/``.
                For instance, if ``custom/search`` is passed, a request to ``/atoti/pyapi/custom/search?query=test#results`` will match.
                The route should not contain the query (``?``) or fragment (``#``).

                Path parameters can be configured by wrapping their name in curly braces in the route.
            method: The HTTP method the request must be using to trigger this endpoint.
                ``DELETE``, ``POST``, and ``PUT`` requests can have a body but it must be JSON.

        Raises:
            ValueError: If *route* is empty, starts with ``/``, or contains ``?`` or ``#``.

        Example:
            .. doctest:: Session.endpoint
                :skipif: True

                >>> import requests
                >>> df = pd.DataFrame(
                ...     columns=["Year", "Month", "Day", "Quantity"],
                ...     data=[
                ...         (2019, 7, 1, 15),
                ...         (2019, 7, 2, 20),
                ...     ],
                ... )
                >>> table = session.read_pandas(df, table_name="Quantity")
                >>> table.head()
                Year  Month  Day  Quantity
                0  2019      7    1        15
                1  2019      7    2        20
                >>> endpoints_base_url = f"http://localhost:{session.port}/atoti/pyapi"
                >>> @session.endpoint("tables/{table_name}/size", method="GET")
                ... def get_table_size(request, user, session):
                ...     table_name = request.path_parameters["table_name"]
                ...     return len(session.tables[table_name])
                ...
                >>> requests.get(f"{endpoints_base_url}/tables/Quantity/size").json()
                2
                >>> @session.endpoint("tables/{table_name}/rows", method="POST")
                ... def append_rows_to_table(request, user, session):
                ...     rows = request.body
                ...     table_name = request.path_parameters["table_name"]
                ...     session.tables[table_name].append(*rows)
                ...
                >>> requests.post(
                ...     f"{endpoints_base_url}/tables/Quantity/rows",
                ...     json=[
                ...         {"Year": 2021, "Month": 5, "Day": 19, "Quantity": 50},
                ...         {"Year": 2021, "Month": 5, "Day": 20, "Quantity": 6},
                ...     ],
                ... ).status_code
                200
                >>> requests.get(f"{endpoints_base_url}/tables/Quantity/size").json()
                4
                >>> table.head()
                Year  Month  Day  Quantity
                0  2019      7    1        15
                1  2019      7    2        20
                2  2021      5   19        50
                3  2021      5   20         6

        """
        # Indexing ``route[0]`` would raise an ``IndexError`` on an empty route:
        # reject it explicitly with the same exception type as the other invalid routes.
        if not route:
            raise ValueError(f"Invalid route '{route}'. It should not be empty.")
        if route.startswith("/") or "?" in route or "#" in route:
            raise ValueError(
                f"Invalid route '{route}'. It should not start with '/' and not contain '?' or '#'."
            )

        def endpoint_decorator(callback: Callable[..., Any]) -> Callable[..., Any]:
            self._java_api.create_endpoint(
                http_method=method,
                route=route,
                handler=EndpointHandler(callback, session=self),
            )
            # The callback is returned unchanged so it stays callable from Python.
            return callback

        return endpoint_decorator

    def export_translations_template(self, path: PathLike) -> None:
        """Export a template containing all translatable values in the session's cubes.

        Args:
            path: The path at which to write the template.
        """
        self._java_api.export_i18n_template(path)

    def _get_jfr_command(self, jfr_action: str, *args: str) -> list[str]:
        """Build the ``jcmd`` command running ``JFR.<jfr_action>`` with *args* on the server subprocess."""
        assert (
            self._server_subprocess
        ), "Cannot create flight recording with detached process."

        return [
            str(get_java_executable_path(executable_name="jcmd")),
            str(self._server_subprocess.pid),
            f"JFR.{jfr_action}",
            *args,
        ]

    def _create_flight_recording(self, path: PathLike, *, duration: timedelta) -> None:
        """Create a recording file using Java Flight Recorder (JFR).

        This call is non-blocking: ``jcmd`` will continue writing to the file at the specified *path* for the given *duration* after this function returns.
        Call :func:`time.sleep` with ``duration.total_seconds()`` to block the current thread until the end of the recording.

        Args:
            path: The path (with a :guilabel:`.jfr` extension) at which the recording file should be written to.
            duration: The duration of the recording.

        Raises:
            RuntimeError: If ``jcmd`` exits with a non-zero status; its combined output is included in the message.
        """
        command = self._get_jfr_command(
            "start",
            f"duration={int(duration.total_seconds())}s",
            f"filename={to_absolute_path(path)}",
        )

        try:
            check_output(
                command,  # noqa: S603
                stderr=STDOUT,
                text=True,
            )
        except CalledProcessError as error:
            raise RuntimeError(
                f"Failed to create flight recording:\n{error.output}"
            ) from error
