import os
import re
from threading import Lock
from typing import List, Optional

import yaml
from huggingface_hub import HfApi
from huggingface_hub.utils import HfHubHTTPError
from pydantic import BaseModel, Field

from dtx.core.exceptions.base import ModelNotFoundError
from dtx.core.base.evaluator import BaseEvaluator
from dtx_models.evaluator import EvaluationModelName, EvaluationModelType
from dtx_models.providers.hf import HFModels, HuggingFaceProviderConfig, HuggingFaceTask
from dtx.plugins.redteam.dataset.hf.hf_prompts_repo import (
    HFPromptsGenerator,
)
from dtx.plugins.redteam.evaluators.hf.toxicity.granite_toxicity_evaluator import (
    GraniteToxicityEvaluator,
)
from dtx.plugins.redteam.evaluators.llamaguard._ollama import Ollama_LlamaGuard
from dtx.plugins.redteam.evaluators.openai.pbe import OpenAIBasedPolicyEvaluator
from dtx.plugins.redteam.evaluators.xray.any_json_path_expr_match import (
    AnyJsonPathExpressionMatch,
)
from dtx.plugins.redteam.evaluators.xray.any_keyword_match import AnyKeywordMatch
from dtx.plugins.redteam.tactics.repo import TacticRepo
from dtx.plugins.redteam.dataset.base.adv_repo import AdvBenchRepo

from .providers import OllamaLlamaGuardConfig, OpenAIConfig


class DependencyFactory:
    """Container for helper objects shared by the red-team evaluation flow.

    Attributes:
        hf_prompt_generator: an ``HFPromptsGenerator`` created once per
            factory instance.
    """

    def __init__(self):
        generator = HFPromptsGenerator()
        self.hf_prompt_generator = generator


# Lock to ensure thread-safe initialization
_initialize_lock = Lock()


# Global variables
global_config: Optional["GlobalConfig"] = None
deps_factory: Optional["DependencyFactory"] = None
eval_factory: Optional["ModelEvaluatorsFactory"] = None
llm_models: Optional["HFModelsRepo"] = None
tactics_repo: Optional[TacticRepo] = None
adv_repo: Optional[AdvBenchRepo] = None


def ensure_initialized(
    init_global_config: bool = False,
    init_deps_factory: bool = False,
    init_eval_factory: bool = False,
    init_llm_models: bool = False,
    init_tactics_repo: bool = False,
    init_adv_repo: bool = False,
):
    """Idempotently initialize the module-level dependencies.

    If any ``init_*`` flag is True, only the flagged components are created.
    ``global_config`` is also created when ``init_eval_factory`` is set,
    because ``ModelEvaluatorsFactory`` is built from it. If no flag is set,
    every component is created. A component that already exists (is not
    None) is never rebuilt.

    All construction happens while holding ``_initialize_lock`` so that
    concurrent callers cannot build the same component twice.
    """
    global global_config, deps_factory, eval_factory, llm_models, tactics_repo, adv_repo

    # No explicit selection means "initialize everything".
    init_all = not any(
        (
            init_global_config,
            init_deps_factory,
            init_eval_factory,
            init_llm_models,
            init_tactics_repo,
            init_adv_repo,
        )
    )

    with _initialize_lock:
        # The evaluator factory reads its sections from the global config, so
        # the config must exist first; otherwise it would be handed None.
        if (init_all or init_global_config or init_eval_factory) and global_config is None:
            global_config = GlobalConfig()
        if (init_all or init_deps_factory) and deps_factory is None:
            deps_factory = DependencyFactory()
        if (init_all or init_eval_factory) and eval_factory is None:
            eval_factory = ModelEvaluatorsFactory(config=global_config)
        if (init_all or init_llm_models) and llm_models is None:
            llm_models = HFModelsRepo()
        if (init_all or init_tactics_repo) and tactics_repo is None:
            tactics_repo = TacticRepo()
        if (init_all or init_adv_repo) and adv_repo is None:
            adv_repo = AdvBenchRepo()


# Add getter function for AdvRepo
def get_adv_repo(only: bool = False) -> AdvBenchRepo:
    ensure_initialized(init_adv_repo=True if only else False)
    return adv_repo


def get_global_config(only: bool = False) -> "GlobalConfig":
    """Return the shared ``GlobalConfig`` singleton, creating it if needed.

    Args:
        only: when True, request initialization of just the global config;
            when False, every global dependency is initialized.
    """
    ensure_initialized(init_global_config=bool(only))
    return global_config


def get_deps_factory(only: bool = False) -> "DependencyFactory":
    """Return the shared ``DependencyFactory`` singleton, creating it if needed.

    Args:
        only: when True, request initialization of just the dependency
            factory; when False, every global dependency is initialized.
    """
    ensure_initialized(init_deps_factory=bool(only))
    return deps_factory


def get_eval_factory(only: bool = False) -> "ModelEvaluatorsFactory":
    """Return the shared ``ModelEvaluatorsFactory`` singleton, creating it if needed.

    Args:
        only: when True, request initialization of the evaluator factory
            specifically; when False, every global dependency is initialized.
    """
    ensure_initialized(init_eval_factory=bool(only))
    return eval_factory


def get_llm_models(only: bool = False) -> "HFModelsRepo":
    """Return the shared ``HFModelsRepo`` singleton, creating it if needed.

    Args:
        only: when True, request initialization of just the model registry;
            when False, every global dependency is initialized.
    """
    ensure_initialized(init_llm_models=bool(only))
    return llm_models


def get_tactics_repo(only: bool = False) -> TacticRepo:
    """Return the shared ``TacticRepo`` singleton, creating it if needed.

    Args:
        only: when True, request initialization of just the tactics repo;
            when False, every global dependency is initialized.
    """
    ensure_initialized(init_tactics_repo=bool(only))
    return tactics_repo


class GlobalConfig(BaseModel):
    """Top-level provider settings handed to ``ModelEvaluatorsFactory``.

    Each section defaults to a freshly constructed instance of its config
    class. ``default_factory`` defers that construction until
    ``GlobalConfig()`` itself is instantiated, rather than running it once
    when this module is imported.
    """

    openai: Optional[OpenAIConfig] = Field(default_factory=OpenAIConfig)
    ollamaguard: Optional[OllamaLlamaGuardConfig] = Field(
        default_factory=OllamaLlamaGuardConfig
    )


# class HFModelsRepo:
#     def __init__(self, models_path=None):
#         script_dir = os.path.dirname(os.path.abspath(__file__))
#         self._models_path = os.path.join(
#             script_dir, models_path or "../core/repo/hf_models.yml"
#         )
#         self.models = self._load_from_file()

#     def _load_from_file(self) -> HFModels:
#         with open(self._models_path, "r") as file:
#             data = yaml.safe_load(file)
#         return HFModels(
#             huggingface=[
#                 HuggingFaceProviderConfig(**model) for model in data.get("huggingface", [])
#             ]
#         )

#     def get_huggingface_model(self, model_name: str) -> Optional[HuggingFaceProviderConfig]:
#         """Retrieve a HuggingFaceProviderConfig instance by model name."""
#         return next(
#             (
#                 provider
#                 for provider in self.models.huggingface
#                 if provider.model == model_name
#             ),
#             None,
#         )

class HFModelsRepo:
    """Registry of Hugging Face provider configurations.

    The registry is seeded from a YAML file (``huggingface:`` list). Unknown
    models are looked up on the Hugging Face Hub and cached in memory on
    first request.
    """

    # Hub tag -> task enum, used to infer the task of models not in the YAML.
    TASK_TAGS_TO_ENUM = {
        "text-generation": HuggingFaceTask.TEXT_GENERATION,
        "text2text-generation": HuggingFaceTask.TEXT2TEXT_GENERATION,
        "text-classification": HuggingFaceTask.TEXT_CLASSIFICATION,
        "token-classification": HuggingFaceTask.TOKEN_CLASSIFICATION,
        "fill-mask": HuggingFaceTask.FILL_MASK,
        "feature-extraction": HuggingFaceTask.FEATURE_EXTRACTION,
        "sentence-similarity": HuggingFaceTask.SENTENCE_SIMILARITY,
    }

    # Generation defaults applied to models discovered on the Hub.
    DEFAULT_CONFIGS = {
        HuggingFaceTask.TEXT_GENERATION: {"max_new_tokens": 512, "temperature": 0.7, "top_p": 0.9},
        HuggingFaceTask.TEXT2TEXT_GENERATION: {"max_new_tokens": 512, "temperature": 0.7, "top_p": 0.9},
        HuggingFaceTask.FILL_MASK: {},
        HuggingFaceTask.TEXT_CLASSIFICATION: {},
        HuggingFaceTask.TOKEN_CLASSIFICATION: {},
        HuggingFaceTask.FEATURE_EXTRACTION: {},
        HuggingFaceTask.SENTENCE_SIMILARITY: {},
    }

    # Tags suggesting a chat/instruction-tuned model; compiled once.
    _MULTI_TURN_TAG_RE = re.compile(r"(chat|dialog|instruct|conversational)", re.IGNORECASE)

    def __init__(self, models_path=None):
        """Load the registry.

        Args:
            models_path: path to the YAML registry, resolved relative to this
                module's directory. Defaults to ``../core/repo/hf_models.yml``.
        """
        script_dir = os.path.dirname(os.path.abspath(__file__))
        self._models_path = os.path.join(script_dir, models_path or "../core/repo/hf_models.yml")
        self.models = self._load_from_file()

    def _load_from_file(self) -> HFModels:
        """Parse the YAML registry; a missing file or empty section yields an empty registry."""
        if not os.path.exists(self._models_path):
            return HFModels(huggingface=[])
        with open(self._models_path, "r", encoding="utf-8") as file:
            data = yaml.safe_load(file) or {}
        # `or []` also covers an explicit `huggingface:` key with a null value.
        entries = data.get("huggingface") or []
        return HFModels(
            huggingface=[HuggingFaceProviderConfig(**model) for model in entries]
        )

    def get_huggingface_model(self, model_name: str) -> HuggingFaceProviderConfig:
        """Return the provider config for ``model_name``.

        A locally registered model is returned directly. Otherwise the model's
        metadata is fetched from the Hugging Face Hub. A config is built from
        that metadata, appended to the in-memory registry, and returned.

        Raises:
            ModelNotFoundError: if the Hub lookup fails with an HTTP error.
        """
        model = next(
            (provider for provider in self.models.huggingface if provider.model == model_name),
            None,
        )
        if model is not None:
            return model

        # Only the network call can raise HfHubHTTPError; keep the try narrow.
        try:
            model_info = HfApi().model_info(model_name)
        except HfHubHTTPError as e:
            raise ModelNotFoundError(model_name) from e

        # The Hub may report tags as None for sparse model cards.
        tags = list(model_info.tags or [])
        pipeline_tag = getattr(model_info, "pipeline_tag", None)
        # Prefer the Hub's declared pipeline tag, then fall back to plain tags.
        candidates = ([pipeline_tag] if pipeline_tag else []) + tags

        task = next(
            (self.TASK_TAGS_TO_ENUM[tag] for tag in candidates if tag in self.TASK_TAGS_TO_ENUM),
            HuggingFaceTask.TEXT_GENERATION,
        )
        support_multi_turn = any(self._MULTI_TURN_TAG_RE.search(tag) for tag in tags)

        new_provider = HuggingFaceProviderConfig(
            model=model_name,
            task=task.value,
            support_multi_turn=support_multi_turn,
            supported_input_format="openai",
            config=self.DEFAULT_CONFIGS.get(task, {}),
        )
        self.models.huggingface.append(new_provider)
        return new_provider


class ModelEvaluatorsFactory:
    """Factory that instantiates and indexes the available evaluation models.

    Each entry in ``EVALUATORS`` is instantiated at construction time. Only
    evaluators whose ``is_available()`` returns True are kept.
    """

    # (evaluator class, constructor params). The special key "config_class"
    # means: find the GlobalConfig field whose value is an instance of that
    # class and build the evaluator from that section's fields.
    EVALUATORS = [
        (AnyKeywordMatch, {}),
        (AnyJsonPathExpressionMatch, {}),
        (
            GraniteToxicityEvaluator,
            {
                "model_name": "ibm-granite/granite-guardian-hap-125m",
                "eval_name": EvaluationModelName.IBM_GRANITE_TOXICITY_HAP_125M,
            },
        ),
        (
            GraniteToxicityEvaluator,
            {
                "model_name": "ibm-granite/granite-guardian-hap-38m",
                "eval_name": EvaluationModelName.IBM_GRANITE_TOXICITY_HAP_38M,
            },
        ),
        (Ollama_LlamaGuard, {"config_class": OllamaLlamaGuardConfig}),
        (OpenAIBasedPolicyEvaluator, {"model_name": "gpt-4o-mini"}),
    ]

    def __init__(self, config: GlobalConfig):
        """Initialize the evaluator factory with the given configuration."""
        self._evaluators: List[BaseEvaluator] = []
        self._init_evaluators(config)

    def _add_evaluator(self, evaluator: BaseEvaluator):
        """Append an evaluator to the internal list."""
        self._evaluators.append(evaluator)

    def get_evaluator_by_name(
        self, name: EvaluationModelName
    ) -> Optional[BaseEvaluator]:
        """Return the first registered evaluator whose name matches, or None."""
        return next(
            (
                evaluator
                for evaluator in self._evaluators
                if evaluator.get_name() == name
            ),
            None,
        )

    def get_evaluators_by_type(
        self, eval_type: EvaluationModelType
    ) -> List[BaseEvaluator]:
        """Return all registered evaluators of the given type, in registration order."""
        return [
            evaluator
            for evaluator in self._evaluators
            if evaluator.get_type() == eval_type
        ]

    def select_evaluator(
        self, eval_model_name: EvaluationModelName, eval_model_type: EvaluationModelType
    ) -> Optional[BaseEvaluator]:
        """Select an evaluator by name, falling back to the first one of a type.

        Returns:
            The evaluator matching ``eval_model_name`` if one is registered.
            Otherwise the first evaluator of ``eval_model_type``, or None if
            neither lookup matches.
        """
        evaluator = (
            self.get_evaluator_by_name(eval_model_name) if eval_model_name else None
        )
        if not evaluator:
            evaluators = self.get_evaluators_by_type(eval_model_type)
            evaluator = evaluators[0] if evaluators else None
        return evaluator

    @staticmethod
    def _find_config_section(config: Optional[GlobalConfig], config_class: type):
        """Return the first field value of ``config`` that is a ``config_class`` instance.

        Returns None when ``config`` is None or no field matches.
        """
        if config is None:
            return None
        # Read model_fields from the class: instance access is deprecated in
        # Pydantic 2.11.
        return next(
            (
                getattr(config, field_name)
                for field_name in type(config).model_fields
                if isinstance(getattr(config, field_name, None), config_class)
            ),
            None,
        )

    def _init_evaluators(self, config: GlobalConfig):
        """Instantiate every entry of ``EVALUATORS`` and keep the available ones."""
        for evaluator_class, param_data in self.EVALUATORS:
            params = {}
            for key, value in param_data.items():
                if key == "config_class":
                    params["config"] = self._find_config_section(config, value)
                else:
                    params[key] = value

            section = params.get("config")
            if isinstance(section, BaseModel):
                # Config-driven evaluators are built from the section's fields.
                evaluator = evaluator_class(**section.model_dump())
            else:
                evaluator = evaluator_class(**params)

            if evaluator.is_available():
                self._add_evaluator(evaluator)
