# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import re

from aiq.builder.builder import Builder
from aiq.builder.framework_enum import LLMFrameworkEnum
from aiq.cli.register_workflow import register_its_strategy
from aiq.data_models.its_strategy import ITSStrategyBaseConfig
from aiq.experimental.inference_time_scaling.models.its_item import ITSItem
from aiq.experimental.inference_time_scaling.models.selection_config import LLMBasedAgentOutputSelectionConfig
from aiq.experimental.inference_time_scaling.models.stage_enums import PipelineTypeEnum
from aiq.experimental.inference_time_scaling.models.stage_enums import StageTypeEnum
from aiq.experimental.inference_time_scaling.models.strategy_base import StrategyBase
from aiq.utils.io.model_processing import remove_r1_think_tags

logger = logging.getLogger(__name__)


class LLMBasedAgentOutputSelector(StrategyBase):
    """
    Selection strategy that asks an LLM to choose the best agent output.

    Every candidate output is rendered into a 1-based numbered list and
    substituted, together with the agent context and the original user prompt,
    into ``config.selection_template``. The filled prompt is sent to
    ``config.selection_llm``, whose reply must start with
    ``SELECTED ITEM: <n>`` where ``<n>`` is the position of the chosen
    candidate in that list.
    """

    # Expected reply format. Whitespace after the colon and letter case are
    # tolerated because chat models commonly vary them.
    _SELECTION_PATTERN = re.compile(r'^\s*SELECTED ITEM:\s*(\d+)', re.IGNORECASE)

    def __init__(self, config: ITSStrategyBaseConfig) -> None:
        super().__init__(config)
        # Populated by `build_components`; stays None until then.
        self.llm_bound = None

    async def build_components(self, builder: Builder) -> None:
        """
        Build the components required for the selector.

        Resolves ``config.selection_llm`` through the builder, wrapped as a
        LangChain LLM, and stores it on ``self.llm_bound``.
        """
        self.llm_bound = await builder.get_llm(self.config.selection_llm, wrapper_type=LLMFrameworkEnum.LANGCHAIN)

    def supported_pipeline_types(self) -> list[PipelineTypeEnum]:
        return [PipelineTypeEnum.AGENT_EXECUTION]

    def stage_type(self) -> StageTypeEnum:
        return StageTypeEnum.SELECTION

    async def ainvoke(self,
                      items: list[ITSItem],
                      original_prompt: str | None = None,
                      agent_context: str | None = None,
                      **kwargs) -> list[ITSItem]:
        """
        Ask the selection LLM to pick one item from ``items``.

        Args:
            items (list[ITSItem]): The candidate items to select from.
            original_prompt (str | None): The prompt the user provided the agent;
                fills the template's ``input`` variable.
            agent_context (str | None): The context of the agent, if applicable;
                fills the template's ``objective`` variable.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            list[ITSItem]: A single-element list containing the selected item.

        Raises:
            ImportError: If ``langchain-core`` is not installed.
            ValueError: If the selection LLM is not a ``BaseChatModel``, ``items``
                is empty, or the LLM reply cannot be parsed into a valid item
                number.
        """
        try:
            from langchain_core.language_models import BaseChatModel
            from langchain_core.prompts import PromptTemplate
        except ImportError as e:
            raise ImportError("langchain-core is not installed. Please install it to use LLMBasedAgentOutputSelector.\n"
                              "This error can be resolved by installing aiqtoolkit-langchain.") from e

        from pydantic import BaseModel

        if not isinstance(self.llm_bound, BaseChatModel):
            raise ValueError("The `selection_llm` must be an instance of `BaseChatModel`.")

        # Fail fast: with no candidates any number the LLM returns is invalid,
        # so don't spend an LLM call discovering that.
        if not items:
            raise ValueError("No items were provided to select from.")

        model: BaseChatModel = self.llm_bound

        # Render candidates as a 1-based numbered list; these are the numbers the
        # LLM is expected to answer with.
        rendered = []
        for position, item in enumerate(items, start=1):
            item_str = str(item.output.model_dump()) if isinstance(item.output, BaseModel) else str(item.output)
            rendered.append(f"{position}. {remove_r1_think_tags(item_str)}\n\n")
        results = "".join(rendered)

        prompt_template = PromptTemplate(
            template=self.config.selection_template,
            input_variables=["objective", "input", "results"],
            validate_template=True,
        )

        prompt = (await prompt_template.ainvoke(input={
            "objective": agent_context, "input": original_prompt, "results": results
        })).to_string()

        response = remove_r1_think_tags((await model.ainvoke(prompt)).content)

        # The model is expected to reply with 'SELECTED ITEM: {item number}'.
        # Extract the number with a regex after stripping reasoning tags.
        if not isinstance(response, str):
            logger.warning("Invalid response from LLM for selected plan index: %s.", response)
            raise ValueError("Unable to parse the selected plan index.")
        response = response.strip()

        match = self._SELECTION_PATTERN.match(response)
        if not match:
            logger.warning("Could not parse the selected plan index from the response: %s.", response)
            raise ValueError("The response format for selecting the item is incorrect.")

        selected_number = int(match.group(1))
        if not 1 <= selected_number <= len(items):
            logger.warning("Selected item %d is out of range; expected a number between 1 and %d.",
                           selected_number, len(items))
            raise ValueError(f"Failed to parse the selected plan index from the LLM response: {response}. "
                             "Ensure the response follows the expected format.")

        return [items[selected_number - 1]]


@register_its_strategy(config_type=LLMBasedAgentOutputSelectionConfig)
async def register_llm_based_agent_output_selector(config: LLMBasedAgentOutputSelectionConfig, builder: Builder):
    """
    Create an LLMBasedAgentOutputSelector for ``config``, let it resolve its
    selection LLM through ``builder``, and yield the ready selector.
    """
    output_selector = LLMBasedAgentOutputSelector(config)
    await output_selector.build_components(builder=builder)
    yield output_selector
