import logging
import re
from datetime import datetime
from typing import (
    Iterable,
    List,
    Literal,
    NamedTuple,
    Optional,
    TypeAlias,
    TypeGuard,
    get_args,
)

import pandas as pd
import requests

from epx.core.cloud.auth import platform_api_headers
from epx.core.utils.config import read_auth_config
from epx.core.utils.http_request import retry_request
from epx.run.exec.cloud.strategy import ForbiddenResponse, UnauthorizedUserError
from epx.run.fred_run import FREDRun
from epx.core.types.common import RunInfo, UserRequests

# Logger named after this module, used for debug output below.
logger = logging.getLogger(__name__)

# Labels a job-level status summary can carry (see `_JobExecutionStatus.value`).
StatusName = Literal["NOT_STARTED", "RUNNING", "ERROR", "QUEUED", "DONE"]
# Level labels accepted when parsing a FRED log line.
LogLevel = Literal["DEBUG", "INFO", "WARNING", "ERROR"]
# A (run id, run object) pair, as consumed by `FREDJobStatus`.
RunWithId: TypeAlias = tuple[int, FREDRun]


class _LogItem(NamedTuple):
    """One parsed line of a FRED log file.

    Instances are immutable tuples ordered as ``(level, time, message)``.

    Attributes
    ----------
    level : LogLevel
        Severity label of the entry, e.g. `INFO` or `ERROR`.
    time : datetime
        Timestamp at which the entry was reported.
    message : str
        Text of the entry.
    """

    level: LogLevel
    time: datetime
    message: str


class _JobExecutionStatus:
    """Aggregate status of a job together with per-run tallies.

    Attributes
    ----------
    value : StatusName
        Overall status label of the job.
    errors : list[int]
        Run ids associated with errors; an empty list when none were given.
    runs_done_count : int
        Number of runs counted as done.
    runs_executing_count : int
        Number of runs counted as executing.
    runs_total_count : int
        Total number of runs considered.
    """

    def __init__(
        self,
        value: StatusName,
        errors: Optional[list[int]] = None,
        runs_done_count: int = 0,
        runs_executing_count: int = 0,
        runs_total_count: int = 0,
    ):
        # Build a fresh list per instance so no default is shared between
        # instances.
        if errors is None:
            errors = []

        self.value = value
        self.errors = errors
        self.runs_done_count = runs_done_count
        self.runs_executing_count = runs_executing_count
        self.runs_total_count = runs_total_count

    def __repr__(self) -> str:
        # Note: the status label is rendered under the key "status", not
        # "value".
        rendered_fields = ", ".join(
            (
                f"status={self.value}",
                f"errors={self.errors}",
                f"runs_done_count={self.runs_done_count}",
                f"runs_executing_count={self.runs_executing_count}",
                f"runs_total_count={self.runs_total_count}",
            )
        )
        return f"_JobExecutionStatus({rendered_fields})"

    def __str__(self) -> str:
        # The plain string form is just the status label.
        return self.value


class FREDJobStatus:
    def __init__(self, job_name: str, _run_with_ids: Iterable[RunWithId]):
        """Create a status handle for the job called ``job_name``.

        Parameters
        ----------
        job_name : str
            Name of the job; sent as the ``job_name`` query parameter when
            the job's runs are requested.
        _run_with_ids : Iterable[RunWithId]
            ``(run_id, run)`` pairs whose log files are read by ``logs``.
        """
        self.job_name = job_name
        # NOTE(review): the iterable is stored as given and re-iterated on
        # every `logs` access; a one-shot iterator would be exhausted after
        # the first access -- confirm callers always pass a re-iterable
        # collection.
        self._run_with_ids = _run_with_ids
        # Expected line layout: "[YYYY-MM-DDTHH:MM:SS.mmmZ] LEVEL: message".
        # The dot before the milliseconds is escaped so that only a literal
        # "." is accepted there; unescaped it would match any character.
        self._log_re = re.compile(
            r"^\[(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z)\] ([A-Z]*): (.*)$"
        )

    def get_run_requests(self) -> List[RunInfo]:
        """Fetch the run requests for this job from the ``/runs`` endpoint.

        Issues a ``GET`` request filtered by ``job_name`` and parses the
        response body as ``UserRequests``.

        Returns
        -------
        List[RunInfo]
            The ``runs`` field of the parsed response.

        Raises
        ------
        UnauthorizedUserError
            If the server answers with HTTP 403 (forbidden), i.e. the user
            does not have sufficient privileges to perform the action on
            FRED Cloud.
        requests.exceptions.RequestException
            If the request itself fails (for example a
            ``requests.exceptions.ConnectionError`` on network problems);
            the exception propagates unchanged.
        RuntimeError
            If the server answers with any other non-OK status code.
        """

        endpoint_url = f"{read_auth_config('api-url')}/runs"
        params = {"job_name": str(self.job_name)}
        headers = platform_api_headers()

        # Request failures are allowed to propagate as-is; wrapping the call
        # in a try/except that only re-raises adds nothing.
        response = retry_request(
            method="GET",
            url=endpoint_url,
            headers=headers,
            params=params,
        )

        # Map HTTP error statuses onto the exceptions callers expect.
        if not response.ok:
            if response.status_code == requests.codes.forbidden:
                raise UnauthorizedUserError(
                    ForbiddenResponse.model_validate_json(response.text).description
                )
            raise RuntimeError(f"FRED Cloud error code: {response.status_code}")

        response_payload = response.text
        # Log before validating so the raw payload is still available in the
        # debug output when parsing fails.
        logger.debug("Payload: %s", response_payload)

        return UserRequests.model_validate_json(response_payload).runs

    @property
    def logs(self) -> pd.DataFrame:
        """Return the log entries written by FRED for every run of the job.

        For each ``(run_id, run)`` pair the file ``logs.txt`` inside
        ``run.output_result_cache_dir`` is read and parsed line by line. A
        run whose log file does not exist contributes no rows.

        Returns
        -------
        pd.DataFrame
            One row per log line, with columns ``run_id``, ``level``,
            ``time`` and ``message``. Empty (but with the same columns) when
            there are no runs.

        Raises
        ------
        ValueError
            If a line does not match the expected log layout, if its
            timestamp cannot be parsed, or if its level is not one of the
            ``LogLevel`` values.
        """

        def process_log_line(line: str) -> _LogItem:
            # Parse one raw line into a `_LogItem`, validating its level.
            m = self._log_re.match(line)
            if m is None:
                raise ValueError(f"Invalid logline: {line}")
            level = m.group(2)
            # Validate explicitly instead of with `assert`: assertions are
            # removed under `python -O`, which would let unknown levels
            # through unchecked.
            if not self._is_valid_log_level(level):
                raise ValueError(f"Invalid log level: {level}")
            return _LogItem(
                level,
                datetime.strptime(m.group(1), "%Y-%m-%dT%H:%M:%S.%fZ"),
                m.group(3),
            )

        data_frames: List[pd.DataFrame] = []
        for run_id, run in self._run_with_ids:
            try:
                p = f"{run.output_result_cache_dir}/logs.txt"
                # NOTE(review): opened with the platform default encoding --
                # confirm the encoding FRED writes its logs in.
                with open(p, "r") as f:
                    # Iterate the file object directly rather than loading
                    # every line into a list first.
                    log_records = tuple(process_log_line(line) for line in f)
            except FileNotFoundError:
                # No log file for this run: treat it as having no entries.
                log_records = ()
            logs = pd.DataFrame.from_records(
                log_records, columns=["level", "time", "message"]
            )
            data_frames.append(logs.assign(run_id=run_id))

        log_cols = ["run_id", "level", "time", "message"]
        if data_frames:
            # Each per-run frame keeps its own index, so index values repeat
            # across runs in the concatenated result.
            return pd.concat(data_frames).loc[:, log_cols]
        return pd.DataFrame.from_records(tuple(), columns=log_cols)

    @property
    def name(self) -> _JobExecutionStatus:
        """Summarize the job's status from the statuses of its runs.

        Every access fetches the job's runs via ``get_run_requests``.

        Returns
        -------
        _JobExecutionStatus
            Object whose ``value`` is one of `"NOT_STARTED"`, `"RUNNING"`,
            `"ERROR"`, `"QUEUED"` or `"DONE"`, together with the ids of the
            runs in `"ERROR"`, the number of `"RUNNING"` and `"DONE"` runs,
            and the total number of runs.

        Raises
        ------
        UnauthorizedUserError
            If the user does not have sufficient privileges to perform the
            specified action on FRED Cloud.
        ConnectionError
            If network connection issues.
        RuntimeError
            If a FRED Cloud server error occurs.
        """

        runs = self.get_run_requests()

        # Tally the runs in a single pass.
        status_names = set()
        failed_run_ids: list[int] = []
        executing = 0
        finished = 0
        for run in runs:
            run_status = run.status
            status_names.add(run_status)

            if run_status == "ERROR":
                failed_run_ids.append(run.id)
            elif run_status == "RUNNING":
                executing += 1
            elif run_status == "DONE":
                finished += 1

        # Pick the job-level label; the first matching rule wins:
        #   1. every run NOT_STARTED (or there are no runs) -> NOT_STARTED
        #   2. any run in ERROR                              -> ERROR
        #   3. any run RUNNING                               -> RUNNING
        #   4. any run QUEUED                                -> QUEUED
        #   5. every run DONE                                -> DONE
        #   6. anything else (e.g. a NOT_STARTED/DONE mix)   -> RUNNING
        overall: StatusName
        if all(sn == "NOT_STARTED" for sn in status_names):
            overall = "NOT_STARTED"
        elif "ERROR" in status_names:
            overall = "ERROR"
        elif "RUNNING" in status_names:
            overall = "RUNNING"
        elif "QUEUED" in status_names:
            overall = "QUEUED"
        elif all(sn == "DONE" for sn in status_names):
            overall = "DONE"
        else:
            logger.debug(f"Job {self.job_name} all runs are not complete {status_names}")
            overall = "RUNNING"

        return _JobExecutionStatus(
            overall,
            errors=failed_run_ids,
            runs_done_count=finished,
            runs_executing_count=executing,
            runs_total_count=len(runs),
        )

    @staticmethod
    def _is_valid_log_level(level: str) -> TypeGuard[LogLevel]:
        """Tell whether ``level`` is one of the literals making up ``LogLevel``.

        Declared as a ``TypeGuard`` so that a successful check also narrows
        ``level`` to ``LogLevel`` for static type checkers.
        """
        permitted_levels = get_args(LogLevel)
        return level in permitted_levels

    def __repr__(self) -> str:
        """Return ``JobStatus(<job_name>)``; performs no request."""
        label = self.job_name
        return f"JobStatus({label})"

    def __str__(self) -> str:
        """Return the job's current status label.

        Evaluates the ``name`` property, so converting to ``str`` fetches the
        job's runs.
        """
        current_status = self.name
        return current_status.value
