import os
from threading import Lock
from typing import List, Optional

import yaml
from pydantic import BaseModel

from morphius.core.base.evaluator import BaseEvaluator
from morphius.core.models.evaluator import EvaluationModelName, EvaluationModelType
from morphius.core.models.providers.hf import HFModels, HuggingFaceProvider
from morphius.plugins.redteam.dataset.hf.hf_prompts_repo import (
    HFPromptsGenerator,
)
from morphius.plugins.redteam.evaluators.hf.toxicity.granite_toxicity_evaluator import (
    GraniteToxicityEvaluator,
)
from morphius.plugins.redteam.evaluators.llamaguard._ollama import Ollama_LlamaGuard
from morphius.plugins.redteam.evaluators.xray.any_json_path_expr_match import (
    AnyJsonPathExpressionMatch,
)
from morphius.plugins.redteam.evaluators.xray.any_keyword_match import AnyKeywordMatch
from morphius.plugins.redteam.tactics.repo import TacticRepo

from .providers import OllamaLlamaGuardConfig, OpenAIConfig


class DependencyFactory:
    """Holds helper objects needed by red-team evaluation.

    Attributes:
        hf_prompt_generator: An ``HFPromptsGenerator`` created once when the
            factory is constructed.
    """

    def __init__(self):
        generator = HFPromptsGenerator()
        self.hf_prompt_generator = generator


# Serializes ensure_initialized() so concurrent callers cannot construct
# duplicate singletons. NOTE(review): threading.Lock is not re-entrant; if any
# component constructor run while this lock is held calls one of the get_*
# accessors, the calling thread would deadlock. Confirm none do (or use RLock).
_initialize_lock = Lock()


# Module-level singletons. Each stays None until ensure_initialized() creates
# it; callers read them through the get_* accessor functions.
global_config: Optional["GlobalConfig"] = None
deps_factory: Optional["DependencyFactory"] = None
eval_factory: Optional["ModelEvaluatorsFactory"] = None
llm_models: Optional["HFModelsRepo"] = None
tactics_repo: Optional[TacticRepo] = None


def ensure_initialized(
    init_global_config: bool = False,
    init_deps_factory: bool = False,
    init_eval_factory: bool = False,
    init_llm_models: bool = False,
    init_tactics_repo: bool = False,
) -> None:
    """Idempotently create the module-level dependency singletons.

    If any ``init_*`` flag is True, only the selected components are created.
    The one exception is ``init_eval_factory``, which also creates the
    ``GlobalConfig`` because ``ModelEvaluatorsFactory`` is built from it.
    If no flag is set, every component is created. A component that already
    exists (is not None) is never rebuilt.

    All construction happens while holding ``_initialize_lock`` so that
    concurrent callers do not build duplicate instances.

    Args:
        init_global_config: Create ``global_config`` if missing.
        init_deps_factory: Create ``deps_factory`` if missing.
        init_eval_factory: Create ``eval_factory`` (and ``global_config``) if missing.
        init_llm_models: Create ``llm_models`` if missing.
        init_tactics_repo: Create ``tactics_repo`` if missing.
    """
    global global_config, deps_factory, eval_factory, llm_models, tactics_repo

    want_all = not any(
        (
            init_global_config,
            init_deps_factory,
            init_eval_factory,
            init_llm_models,
            init_tactics_repo,
        )
    )
    want_eval_factory = want_all or init_eval_factory
    # The evaluator factory is constructed from the global config, so selecting
    # it must also guarantee the config exists (otherwise it would get None).
    want_global_config = want_all or init_global_config or want_eval_factory

    with _initialize_lock:  # Ensures thread safety
        # Construction order matters: global_config must precede eval_factory.
        if want_global_config and global_config is None:
            global_config = GlobalConfig()
        if (want_all or init_deps_factory) and deps_factory is None:
            deps_factory = DependencyFactory()
        if want_eval_factory and eval_factory is None:
            eval_factory = ModelEvaluatorsFactory(config=global_config)
        if (want_all or init_llm_models) and llm_models is None:
            llm_models = HFModelsRepo()
        if (want_all or init_tactics_repo) and tactics_repo is None:
            tactics_repo = TacticRepo()


def get_global_config(only: bool = False) -> "GlobalConfig":
    """Return the module-level ``GlobalConfig`` singleton, creating it if needed.

    Args:
        only: When truthy, initialize just the global config; otherwise make
            sure every shared dependency is initialized first.
    """
    ensure_initialized(init_global_config=bool(only))
    return global_config


def get_deps_factory(only: bool = False) -> "DependencyFactory":
    """Return the module-level ``DependencyFactory`` singleton, creating it if needed.

    Args:
        only: When truthy, initialize just the dependency factory; otherwise
            make sure every shared dependency is initialized first.
    """
    ensure_initialized(init_deps_factory=bool(only))
    return deps_factory


def get_eval_factory(only: bool = False) -> "ModelEvaluatorsFactory":
    """Return the module-level ``ModelEvaluatorsFactory`` singleton, creating it if needed.

    Args:
        only: When truthy, initialize just the evaluator factory; otherwise
            make sure every shared dependency is initialized first.
    """
    ensure_initialized(init_eval_factory=bool(only))
    return eval_factory


def get_llm_models(only: bool = False) -> "HFModelsRepo":
    """Return the module-level ``HFModelsRepo`` singleton, creating it if needed.

    Args:
        only: When truthy, initialize just the model repository; otherwise
            make sure every shared dependency is initialized first.
    """
    ensure_initialized(init_llm_models=bool(only))
    return llm_models


def get_tactics_repo(only: bool = False) -> TacticRepo:
    """Return the module-level ``TacticRepo`` singleton, creating it if needed.

    Args:
        only: When truthy, initialize just the tactics repository; otherwise
            make sure every shared dependency is initialized first.
    """
    ensure_initialized(init_tactics_repo=bool(only))
    return tactics_repo


class GlobalConfig(BaseModel):
    """Top-level settings model grouping provider-specific configurations.

    ModelEvaluatorsFactory scans these fields and gives an evaluator the first
    field value that is an instance of that evaluator's ``config_class``.
    """

    # OpenAI provider settings. The default OpenAIConfig() is evaluated once,
    # when this class body runs (i.e. at module import).
    openai: Optional[OpenAIConfig] = OpenAIConfig()
    # Ollama LlamaGuard settings; matched to the LlamaGuard evaluator by type
    # (OllamaLlamaGuardConfig), not by field name.
    ollamaguard: Optional[OllamaLlamaGuardConfig] = OllamaLlamaGuardConfig()


class HFModelsRepo:
    """Catalog of Hugging Face model providers loaded from a YAML file.

    The file is read once at construction. Its top-level ``huggingface`` list
    holds mappings that are passed as keyword arguments to
    ``HuggingFaceProvider``.
    """

    def __init__(self, models_path=None):
        """Load the provider catalog.

        Args:
            models_path: Path to the YAML catalog. A relative path is resolved
                against this module's directory; an absolute path is used as-is
                (``os.path.join`` discards earlier components when a later one
                is absolute). Defaults to ``../core/repo/hf_models.yml``.
        """
        script_dir = os.path.dirname(os.path.abspath(__file__))
        self._models_path = os.path.join(
            script_dir, models_path or "../core/repo/hf_models.yml"
        )
        self.models = self._load_from_file()

    def _load_from_file(self) -> HFModels:
        """Parse the YAML catalog into an ``HFModels`` instance.

        An empty file, or a ``huggingface`` key whose value is null, yields an
        empty provider list instead of raising.
        """
        # YAML is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(self._models_path, "r", encoding="utf-8") as file:
            # safe_load returns None for an empty document.
            data = yaml.safe_load(file) or {}
        entries = data.get("huggingface") or []
        return HFModels(
            huggingface=[HuggingFaceProvider(**entry) for entry in entries]
        )

    def get_huggingface_model(self, model_name: str) -> Optional[HuggingFaceProvider]:
        """Return the first provider whose ``model`` equals *model_name*, or None."""
        return next(
            (
                provider
                for provider in self.models.huggingface
                if provider.model == model_name
            ),
            None,
        )


class ModelEvaluatorsFactory:
    """Factory that builds the known evaluators and looks them up by name or type.

    On construction, every entry in ``EVALUATORS`` is instantiated. Only the
    evaluators whose ``is_available()`` returns a truthy value are kept.
    """

    # (evaluator class, constructor spec) pairs. In a spec, the special key
    # "config_class" names a config model type. The first GlobalConfig field
    # holding an instance of that type is dumped with model_dump(), and its
    # fields are passed as keyword arguments. Every other spec key is passed
    # through to the constructor unchanged.
    EVALUATORS = [
        (AnyKeywordMatch, {}),
        (AnyJsonPathExpressionMatch, {}),
        (
            GraniteToxicityEvaluator,
            {
                "model_name": "ibm-granite/granite-guardian-hap-125m",
                "eval_name": EvaluationModelName.IBM_GRANITE_TOXICITY_HAP_125M,
            },
        ),
        (
            GraniteToxicityEvaluator,
            {
                "model_name": "ibm-granite/granite-guardian-hap-38m",
                "eval_name": EvaluationModelName.IBM_GRANITE_TOXICITY_HAP_38M,
            },
        ),
        (Ollama_LlamaGuard, {"config_class": OllamaLlamaGuardConfig}),
    ]

    def __init__(self, config: GlobalConfig):
        """Build and register every available evaluator.

        Args:
            config: Global settings searched for evaluator-specific config
                models. ``None`` is tolerated; evaluators that declare a
                ``config_class`` then receive ``config=None``.
        """
        self._evaluators: List[BaseEvaluator] = []
        self._init_evaluators(config)

    def _add_evaluator(self, evaluator: BaseEvaluator):
        """Register an evaluator instance."""
        self._evaluators.append(evaluator)

    def get_evaluator_by_name(
        self, name: EvaluationModelName
    ) -> Optional[BaseEvaluator]:
        """Return the first registered evaluator whose ``get_name()`` equals *name*, or None."""
        return next(
            (
                evaluator
                for evaluator in self._evaluators
                if evaluator.get_name() == name
            ),
            None,
        )

    def get_evaluators_by_type(
        self, eval_type: EvaluationModelType
    ) -> List[BaseEvaluator]:
        """Return all registered evaluators whose ``get_type()`` equals *eval_type*, in registration order."""
        return [
            evaluator
            for evaluator in self._evaluators
            if evaluator.get_type() == eval_type
        ]

    def select_evaluator(
        self, eval_model_name: EvaluationModelName, eval_model_type: EvaluationModelType
    ) -> Optional[BaseEvaluator]:
        """Pick an evaluator, preferring an exact name match.

        Args:
            eval_model_name: Preferred evaluator name; ignored when falsy.
            eval_model_type: Fallback type. The first registered evaluator of
                this type is used when no name match is found.

        Returns:
            The selected evaluator, or None if neither lookup finds one.
        """
        evaluator = (
            self.get_evaluator_by_name(eval_model_name) if eval_model_name else None
        )
        if not evaluator:
            evaluators = self.get_evaluators_by_type(eval_model_type)
            evaluator = evaluators[0] if evaluators else None
        return evaluator

    @staticmethod
    def _find_config(config: Optional[GlobalConfig], config_class: type):
        """Return the first field value of *config* that is an instance of *config_class*.

        Returns None when *config* is None or no field matches.
        """
        if config is None:
            return None
        # Read model_fields from the class: instance access is deprecated
        # since Pydantic 2.11.
        for field_name in type(config).model_fields:
            value = getattr(config, field_name, None)
            if isinstance(value, config_class):
                return value
        return None

    @classmethod
    def _build_kwargs(cls, spec: dict, config: Optional[GlobalConfig]) -> dict:
        """Translate an ``EVALUATORS`` spec into constructor keyword arguments.

        Plain spec keys are passed through. When the spec has ``config_class``
        and a matching config model is found, that model's fields are merged
        in (overriding same-named spec keys). If no match is found,
        ``config=None`` is passed instead, as before.
        """
        kwargs = {key: value for key, value in spec.items() if key != "config_class"}
        if "config_class" in spec:
            resolved = cls._find_config(config, spec["config_class"])
            if isinstance(resolved, BaseModel):
                # Merge instead of replace, so the other spec kwargs are not lost.
                kwargs.update(resolved.model_dump())
            else:
                kwargs["config"] = resolved
        return kwargs

    def _init_evaluators(self, config: GlobalConfig):
        """Instantiate each ``EVALUATORS`` entry and keep the available ones."""
        for evaluator_class, spec in self.EVALUATORS:
            evaluator = evaluator_class(**self._build_kwargs(spec, config))
            if evaluator.is_available():
                self._add_evaluator(evaluator)
