"""Module containing tasks and flows for interacting with Census sync runs"""
import asyncio
from enum import Enum
from typing import Any, Dict, Tuple

from httpx import HTTPStatusError
from prefect import flow, task
from prefect.logging import get_run_logger

from prefect_census.credentials import CensusCredentials
from prefect_census.utils import extract_user_message


class CensusSyncRunFailed(RuntimeError):
    """
    Error signalling that a Census sync run could not be retrieved.
    """


class CensusSyncRunTimeout(RuntimeError):
    """
    Raised when a Census sync run does not reach a terminal status within the
    configured maximum wait time (`max_wait_seconds`).
    """


class CensusGetSyncRunInfoFailed(RuntimeError):
    """Raised when retrieving information about a Census sync run fails."""


class CensusSyncRunCancelled(Exception):
    """
    Error signalling that a triggered Census sync run was cancelled.
    """


class CensusSyncRunStatus(Enum):
    """Enumeration of the status strings a Census sync run can report."""

    CANCELLED = "cancelled"
    WORKING = "working"
    FAILED = "failed"
    COMPLETED = "completed"
    SKIPPED = "skipped"
    QUEUED = "queued"

    @classmethod
    def is_terminal_status_code(cls, status_code: str) -> bool:
        """
        Check whether a raw status string marks the end of a sync run.

        Args:
            status_code: The status value reported for a sync run.

        Returns:
            True when `status_code` equals the value of CANCELLED, FAILED,
            COMPLETED or SKIPPED; False for anything else (e.g. WORKING,
            QUEUED, or an unrecognised value).
        """
        finished_states = (cls.CANCELLED, cls.FAILED, cls.COMPLETED, cls.SKIPPED)
        return any(state.value == status_code for state in finished_states)


@task(
    name="Get Census sync run details",
    description="Retrieves details of the Census sync run with the given run_id.",
    retries=3,
    retry_delay_seconds=10,
)
async def get_census_sync_run_info(
    credentials: CensusCredentials, run_id: int
) -> Dict[str, Any]:
    """
    A task to retrieve information about a Census sync run.

    The task is configured with 3 retries, 10 seconds apart.

    Args:
        credentials: Credentials for authenticating with Census.
        run_id: The ID of the sync run to retrieve.

    Raises:
        CensusGetSyncRunInfoFailed: If the request raises an
            `httpx.HTTPStatusError`; the original error is chained as the cause.

    Returns:
        The `data` field of the Census API response, as a dict with the
        following shape:
            ```
            {
                "id": 94,
                "sync_id": 52,
                "source_record_count": 1,
                "records_processed": 1,
                "records_updated": 1,
                "records_failed": 0,
                "records_invalid": 0,
                "created_at": "2021-10-20T02:51:07.546Z",
                "updated_at": "2021-10-20T02:52:29.236Z",
                "completed_at": "2021-10-20T02:52:29.234Z",
                "scheduled_execution_time": null,
                "error_code": null,
                "error_message": null,
                "error_detail": null,
                "status": "completed",
                "canceled": false,
                "full_sync": true,
                "sync_trigger_reason": {
                    "ui_tag": "Manual",
                    "ui_detail": "Manually triggered by test@getcensus.com"
                }
            }
            ```


    Example:
        Get Census sync run info:
        ```python
        from prefect import flow

        from prefect_census import CensusCredentials
        from prefect_census.runs import get_census_sync_run_info

        @flow
        def get_sync_run_info_flow():
            credentials = CensusCredentials(api_key="my_api_key")

            return get_census_sync_run_info(
                credentials=credentials,
                run_id=42
            )

        get_sync_run_info_flow()
        ```
    """  # noqa
    try:
        async with credentials.get_client() as client:
            response = await client.get_run_info(run_id)
    except HTTPStatusError as e:
        # Re-raise as a domain error with a readable message, keeping the
        # HTTP error as the cause for debugging.
        raise CensusGetSyncRunInfoFailed(extract_user_message(e)) from e

    return response.json()["data"]


@flow(
    name="Wait for Census sync run",
    description="Waits for the Census sync run to finish running.",
)
async def wait_census_sync_completion(
    run_id: int,
    credentials: CensusCredentials,
    max_wait_seconds: int = 60,
    poll_frequency_seconds: int = 5,
) -> Tuple[CensusSyncRunStatus, Dict[str, Any]]:
    """
    Wait for the given Census sync run to finish running.

    The run is polled via `get_census_sync_run_info` every
    `poll_frequency_seconds` until it reports a terminal status (cancelled,
    failed, completed or skipped) or `max_wait_seconds` is exceeded.

    Args:
        run_id: The ID of the sync run to wait for.
        credentials: Credentials for authenticating with Census.
        max_wait_seconds: Maximum number of seconds to wait for sync to complete.
        poll_frequency_seconds: Number of seconds to wait in between checks for
            run completion.

    Raises:
        CensusSyncRunTimeout: When the elapsed wait time exceeds `max_wait_seconds`.
        CensusGetSyncRunInfoFailed: Presumably propagated from
            `get_census_sync_run_info` when retrieving run info keeps failing
            (surfaced via the task future's result).

    Returns:
        run_status: An enum representing the final Census sync run status.
        run_data: A dictionary containing information about the run after completion
            in the following shape:
            ```
            {
                "id": 94,
                "sync_id": 52,
                "source_record_count": 1,
                "records_processed": 1,
                "records_updated": 1,
                "records_failed": 0,
                "records_invalid": 0,
                "created_at": "2021-10-20T02:51:07.546Z",
                "updated_at": "2021-10-20T02:52:29.236Z",
                "completed_at": "2021-10-20T02:52:29.234Z",
                "scheduled_execution_time": null,
                "error_code": null,
                "error_message": null,
                "error_detail": null,
                "status": "completed",
                "canceled": false,
                "full_sync": true,
                "sync_trigger_reason": {
                    "ui_tag": "Manual",
                    "ui_detail": "Manually triggered by test@getcensus.com"
                }
            }
            ```

    """
    logger = get_run_logger()
    seconds_waited_for_run_completion = 0
    wait_for = []
    while seconds_waited_for_run_completion <= max_wait_seconds:
        run_data_future = await get_census_sync_run_info.submit(
            credentials=credentials,
            run_id=run_id,
            wait_for=wait_for,
        )
        run_data = await run_data_future.result()
        run_status = run_data.get("status")

        if CensusSyncRunStatus.is_terminal_status_code(run_status):
            return CensusSyncRunStatus(run_status), run_data

        # If the next poll would fall outside the allowed window, time out now
        # instead of sleeping one more interval for nothing.
        next_poll_at = seconds_waited_for_run_completion + poll_frequency_seconds
        if next_poll_at > max_wait_seconds:
            break

        # Chain the next status check after this one.
        wait_for = [run_data_future]

        # An unrecognised (or missing) status must not crash the wait loop
        # just because it cannot be mapped to an enum member for logging.
        try:
            status_label = CensusSyncRunStatus(run_status).name
        except ValueError:
            status_label = run_status

        logger.info(
            "Census sync run with ID %i has status %s. Waiting for %i seconds.",
            run_id,
            status_label,
            poll_frequency_seconds,
        )
        await asyncio.sleep(poll_frequency_seconds)
        seconds_waited_for_run_completion = next_poll_at

    raise CensusSyncRunTimeout(
        f"Max wait time of {max_wait_seconds} seconds exceeded while waiting "
        f"for sync run with ID {run_id}"
    )
