import uuid
from typing import Iterator, List, Optional

from pydantic import BaseModel

from dtx.core import logging
from dtx.core.converters.prompts import PromptVariableSubstitutor
from dtx.core.exceptions.agents import BaseAgentException, UnknownAgentException
from dtx.core.exceptions.base import FeatureNotImplementedError
from dtx_models.analysis import (
    PromptDataset,
    TestPromptWithEvalCriteria,
    TestPromptWithModEval,
    TestSuitePrompts,
)
from dtx_models.evaluator import TypeAndNameBasedEvaluator
from dtx_models.prompts import (
    BaseMultiTurnAgentResponse,
    BaseMultiTurnConversation,
    BaseMultiTurnResponse,
    BaseTestPrompt,
    MultiTurnTestPrompt,
    Turn,
)
from dtx_models.results import (
    AttemptsBenchmarkBuilder,
    AttemptsBenchmarkStats,
    EvalResult,
    ResponseEvaluationStatus,
)
from dtx_models.tactic import PromptMutationTactic
from dtx.plugins.redteam.tactics.generator import TacticalPromptGenerator
from dtx.plugins.redteam.tactics.repo import TacticRepo

from ...plugins.providers.base.agent import BaseAgent
from .evaluator import EvaluatorRouter


class AdvOptions(BaseModel):
    """
    Advanced execution options for a scan.
    """

    attempts: int  # Number of times the same prompt is sent to the agent (for benchmarking)
    threads: int  # Requested number of concurrent threads for sending requests; currently ignored


class EngineConfig:
    """
    Configuration class for setting up the evaluation engine.

    Attributes:
        evaluator_router (EvaluatorRouter): The router responsible for delegating evaluations.
        test_suites (List[TestSuitePrompts]): A list of test suites containing prompt-based tests.
        global_evaluator (Optional[TypeAndNameBasedEvaluator]):
            An optional evaluator to override all evaluation methods globally.
        max_per_tactic (int): Limit on tactic-generated prompt variations; the scanner
            forwards it to the tactical prompt generator.
        tactics (Optional[List[PromptMutationTactic]]): Tactics used to mutate the prompts.
        tactics_repo (Optional[TacticRepo]): Repository handed to the tactical prompt generator.
        adv_options (AdvOptions): Advanced options for configuring execution settings like
            number of attempts and threads.
    """

    def __init__(
        self,
        evaluator_router: EvaluatorRouter,
        test_suites: List[TestSuitePrompts],
        global_evaluator: Optional[TypeAndNameBasedEvaluator] = None,
        max_per_tactic: int = 5,
        tactics: Optional[List[PromptMutationTactic]] = None,
        tactics_repo: Optional[TacticRepo] = None,
        adv_options: Optional[AdvOptions] = None,
    ):
        """
        Initializes the EngineConfig with the given parameters.

        Args:
            evaluator_router (EvaluatorRouter): The evaluator router responsible for evaluation dispatch.
            test_suites (List[TestSuitePrompts]): List of test suites containing evaluation prompts.
            global_evaluator (Optional[TypeAndNameBasedEvaluator]):
                If provided, this evaluator will override all evaluation methods.
            max_per_tactic (int): Limit on tactic-generated prompt variations. Defaults to 5.
            tactics (Optional[List[PromptMutationTactic]]):
                Tactics provided to change the prompts.
            tactics_repo (Optional[TacticRepo]): Repository of tactics; stored as given.
            adv_options (Optional[AdvOptions]): Advanced execution settings (e.g., attempts, threading).
                Defaults to one attempt and single-threaded execution.
        """
        self.evaluator_router = evaluator_router
        self.test_suites = test_suites
        self.global_evaluator = global_evaluator
        self.max_per_tactic = max_per_tactic
        self.tactics = tactics
        self.tactics_repo = tactics_repo
        # Build the single-attempt, single-thread default only when the caller
        # did not supply any advanced options.
        self.adv_options = adv_options or AdvOptions(attempts=1, threads=1)


class Prompt2TacticalVariations:
    """
    Expand a prompt into itself plus tactic-based variations.

    The original prompt is always produced first. For multi-turn test prompts,
    up to ``max_per_tactic`` additional variations are taken from the
    tactical prompt generator.
    """

    def __init__(
        self,
        max_per_tactic: int,
        tactics: Optional[List[PromptMutationTactic]] = None,
        tactics_repo: Optional[TacticRepo] = None,
    ):
        """
        Args:
            max_per_tactic: Cap on the number of variations yielded per prompt;
                also forwarded to the tactical prompt generator.
            tactics: Optional tactics forwarded to the tactical prompt generator.
            tactics_repo: Optional tactic repository forwarded to the generator.
        """
        self.max_per_tactic = max_per_tactic
        self.tactics = tactics
        self.generator = TacticalPromptGenerator(
            tactic_repo=tactics_repo, max_per_tactic=max_per_tactic, tactics=tactics
        )

    def generate(self, prompt: BaseTestPrompt) -> Iterator[BaseMultiTurnAgentResponse]:
        """
        Yield ``prompt`` unchanged, followed by its tactical variations.

        Variations are only generated when ``prompt`` is a
        ``MultiTurnTestPrompt``; any other prompt yields just itself.

        Args:
            prompt: The prompt to expand.

        Yields:
            The original prompt, then at most ``max_per_tactic`` variations.
        """
        yield prompt

        if self.max_per_tactic <= 0 or not isinstance(prompt, MultiTurnTestPrompt):
            return

        variations = self.generator.generate_variations(base_prompt=prompt)
        for produced, variation in enumerate(variations, start=1):
            yield variation
            if produced >= self.max_per_tactic:
                # Stop pulling from the generator once the cap is reached;
                # anything produced beyond it would be discarded anyway.
                return


class TestPrompt2Turns:
    """
    Turn test prompts into conversations that can be sent to an agent.

    Template prompts are expanded with their variables, plain prompts are
    wrapped as a single user turn, and multi-turn prompts pass through as-is.
    """

    def generate(self, prompt: BaseTestPrompt) -> Iterator[BaseMultiTurnAgentResponse]:
        """
        Yield the conversation(s) derived from ``prompt``.

        Raises:
            FeatureNotImplementedError: If the prompt type is not supported.
        """
        if isinstance(prompt, TestPromptWithEvalCriteria):
            # One conversation per substituted prompt text.
            substitutor = PromptVariableSubstitutor(prompt.variables)
            for filled_text in substitutor.convert(prompt=prompt.prompt):
                yield self._single_user_turn(filled_text)
            return

        if isinstance(prompt, TestPromptWithModEval):
            # The prompt text is used verbatim as the only user turn.
            yield self._single_user_turn(prompt.prompt)
            return

        if isinstance(prompt, MultiTurnTestPrompt):
            yield prompt
            return

        raise FeatureNotImplementedError(
            f"Prompt of type {type(prompt)} is not handled"
        )

    def _single_user_turn(self, message) -> BaseMultiTurnAgentResponse:
        """
        Wrap ``message`` in a conversation consisting of one USER turn.
        """
        user_turn = Turn(role="USER", message=message)
        return BaseMultiTurnAgentResponse(turns=[user_turn])

    def _remove_any_assistant_turn(self, turns: List[Turn]) -> List[Turn]:
        """
        Return the turns of the conversation that are not Assistant turns.
        """
        kept: List[Turn] = []
        for turn in turns:
            if turn.role != "ASSISTANT":
                kept.append(turn)
        return kept


class MultiTurnScanner:
    """
    Runs the test prompts of the configured test suites against an agent and
    yields one evaluated result per executed prompt variation.
    """

    logger = logging.getLogger(__name__)

    def __init__(self, config: EngineConfig):
        """
        Args:
            config: Engine configuration providing test suites, evaluator
                router, tactics and advanced options.
        """
        self.config = config
        # NOTE(review): the two failure counters below are initialised here but
        # not read anywhere in this class.
        self.consecutive_failures = 0
        self.failure_threshold = 5  # Define threshold for consecutive failures
        # Builder holding the statistics of the most recently processed prompt.
        # Initialised here so that get_attempts() works before any prompt has run.
        self.attempts_builder = AttemptsBenchmarkBuilder()

    def scan(
        self, agent: BaseAgent, max_prompts: int = 1000000
    ) -> Iterator[EvalResult]:
        """
        Iterates through test suites and executes test prompts.

        Args:
            agent: The agent that receives each prompt variation.
            max_prompts: Upper bound on the number of prompt variations processed.

        Yields:
            One EvalResult per processed prompt variation.
        """
        i = 0  # Number of prompts executed
        p2str = TestPrompt2Turns()
        p2vars = Prompt2TacticalVariations(
            max_per_tactic=self.config.max_per_tactic,
            tactics=self.config.tactics,
            tactics_repo=self.config.tactics_repo,
        )

        for test_suite in self.config.test_suites:
            for risk_prompt in test_suite.risk_prompts:
                for test_prompt in risk_prompt.test_prompts:
                    for prompt_with_values in p2str.generate(test_prompt):
                        if not self._should_continue(i, max_prompts):
                            return

                        self.logger.info("Executing prompt number - %s", i + 1)
                        for prompt_variation in p2vars.generate(prompt_with_values):
                            self.logger.debug("Prompt Variation: %s", prompt_variation)
                            yield from self._process_prompt(
                                test_suite.dataset,
                                agent,
                                test_prompt,
                                prompt_variation,
                            )
                            i += 1
                            # Every variation counts towards the limit.
                            if not self._should_continue(i, max_prompts):
                                return

    def _should_continue(self, i: int, max_prompts: int) -> bool:
        """
        Checks if the scanning process should continue.

        Returns:
            True while fewer than ``max_prompts`` prompts have been executed.
        """
        return i < max_prompts

    def _process_prompt(
        self,
        dataset: PromptDataset,
        agent: BaseAgent,
        test_prompt: BaseTestPrompt,
        prompt_with_values: BaseMultiTurnConversation,
    ) -> Iterator[EvalResult]:
        """
        Executes prompts, collects responses, evaluates them, and tracks statistics.

        Args:
            dataset: Dataset of the test suite the prompt belongs to.
            agent: The agent to converse with.
            test_prompt: The originating test prompt; supplies the evaluation
                method unless a global evaluator is configured.
            prompt_with_values: The concrete conversation sent to the agent.

        Yields:
            A single EvalResult describing all attempts for this prompt.
        """
        attempts_builder = AttemptsBenchmarkBuilder()
        run_id = str(uuid.uuid4())

        # A configured global evaluator takes precedence over the prompt's own.
        evaluation_method = (
            self.config.global_evaluator or test_prompt.evaluation_method
        )

        responses = self._collect_responses(agent, prompt_with_values)
        evaluation_results = self._evaluate_responses(
            dataset,
            responses,
            evaluation_method=evaluation_method,
        )

        response_evaluation_statuses = self._build_response_statuses(
            responses, evaluation_results
        )
        self._update_attempts(attempts_builder, evaluation_results)
        # Expose the finished statistics through get_attempts().
        self.attempts_builder = attempts_builder
        yield EvalResult(
            run_id=run_id,
            prompt=prompt_with_values,
            evaluation_method=evaluation_method,
            responses=response_evaluation_statuses,
            attempts=attempts_builder.get_attempts(),
        )

    def _collect_responses(
        self, agent: BaseAgent, prompt_with_values: BaseMultiTurnConversation
    ) -> List[BaseMultiTurnResponse]:
        """
        Collects responses from the agent.

        The same prompt is sent ``adv_options.attempts`` times.
        """
        return [
            self._get_response(agent, prompt_with_values)
            for _ in range(self.config.adv_options.attempts)
        ]

    def _evaluate_responses(
        self,
        dataset: PromptDataset,
        responses: List[BaseMultiTurnResponse],
        evaluation_method: TypeAndNameBasedEvaluator,
    ):
        """
        Evaluates all collected responses.

        Returns:
            One evaluation result per response, in the same order, as produced
            by the configured evaluator router.
        """

        return [
            self.config.evaluator_router.evaluate_conversation(
                dataset=dataset,
                response=response,
                evaluation_method=evaluation_method,
            )
            for response in responses
        ]

    def _build_response_statuses(
        self, responses: List[BaseMultiTurnResponse], evaluation_results
    ) -> List[ResponseEvaluationStatus]:
        """
        Builds response evaluation status objects.

        Pairs each response with its evaluation result by position.
        """
        return [
            ResponseEvaluationStatus(
                response=response,
                success=eval_result.success,
                description=eval_result.description,
            )
            for response, eval_result in zip(responses, evaluation_results)
        ]

    def _update_attempts(
        self, attempts_builder: AttemptsBenchmarkBuilder, evaluation_results
    ) -> None:
        """
        Updates failure statistics.

        Each evaluation result is recorded as failed when it was not
        successful; none is recorded as an error. The failure rate is
        calculated once all results have been added.
        """
        for eval_result in evaluation_results:
            failed = not eval_result.success
            attempts_builder.add_result(failed=failed, error=False)

        attempts_builder.calculate_failure_rate()

    def _get_response(
        self, agent: BaseAgent, prompt: BaseMultiTurnConversation
    ) -> BaseMultiTurnResponse:
        """
        Retrieves the agent's response for a given prompt.

        Raises:
            BaseAgentException: Propagated unchanged when raised by the agent.
            UnknownAgentException: Wraps any other exception raised by the agent.
        """
        try:
            return agent.converse(prompt)
        except BaseAgentException:
            # Known agent errors are passed through untouched.
            raise
        except Exception as e:
            self.logger.warning(e)
            # Chain the original error so its traceback is preserved.
            raise UnknownAgentException(
                f"Unknown Error while invoking the agent: {str(e)}"
            ) from e

    def get_attempts(self) -> AttemptsBenchmarkStats:
        """
        Returns the evaluation statistics.

        The statistics are those of the most recently processed prompt, or of
        an empty builder when no prompt has been processed yet.
        """
        return self.attempts_builder.get_attempts()
