"""
Module: perplexity_search_tool
Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
SPDX-License-Identifier: Apache-2.0

This module provides classes for interacting with the Perplexity AI search API.
It defines data models for responses and a tool for executing web and conversational searches.
"""

import os
from typing import Annotated, Any, Optional, Union

import requests
from pydantic import BaseModel
import re

from autogen.tools import Tool
from autogen.cmbagent_utils import cmbagent_debug
from autogen import SwarmResult, AfterWorkOption
from autogen.agentchat.contrib.swarm_agent import (
    AfterWork,
    AfterWorkOption,
    OnCondition,
    OnContextCondition,
    SwarmResult,
)

class Message(BaseModel):
    """
    One message of a chat conversation.

    Used as the ``message`` field of a ``Choice``.

    Attributes:
        role (str): Who sent the message (e.g., "system", "user").
        content (str): The message text.
    """

    role: str
    content: str


class Usage(BaseModel):
    """
    Token accounting reported for a single completion.

    All four fields are declared without defaults, so each must be supplied
    when the model is constructed.

    Attributes:
        prompt_tokens (int): Number of tokens used for the prompt.
        completion_tokens (int): Number of tokens generated in the completion.
        total_tokens (int): Total number of tokens (prompt + completion).
        search_context_size (str): Size of the search context that was used (e.g., "high").
    """

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    search_context_size: str


class Choice(BaseModel):
    """
    A single completion alternative inside a Perplexity API response.

    Attributes:
        index (int): Position of this choice in the response's ``choices`` list.
        finish_reason (str): Why the API stopped generating this choice.
        message (Message): The generated message carrying the response text.
    """

    index: int
    finish_reason: str
    message: Message


class PerplexityChatCompletionResponse(BaseModel):
    """
    Typed view of a complete chat completion response from the Perplexity API.

    Every field is declared without a default, so validation fails if the
    response body omits any of them.

    NOTE(review): in the visible part of this file the only construction of
    this model is in commented-out code — confirm whether it is still needed.

    Attributes:
        id (str): Unique identifier of the response.
        model (str): Name of the model that generated the response.
        created (int): Creation timestamp of the response.
        usage (Usage): Token usage details.
        citations (list[str]): Citation strings included in the response.
        object (str): Type of the response object.
        choices (list[Choice]): Choices returned by the API.
    """

    id: str
    model: str
    created: int
    usage: Usage
    citations: list[str]
    object: str
    choices: list[Choice]


class SearchResponse(BaseModel):
    """
    Outcome of a search query.

    Each field accepts ``None`` but has no default value, so all three must
    be passed explicitly when building an instance.

    Attributes:
        content (Optional[str]): Text returned by the search, if any.
        citations (Optional[list[str]]): Citation URLs relevant to the search result, if any.
        error (Optional[str]): Error message when the search failed, otherwise ``None``.
    """

    content: Optional[str]
    citations: Optional[list[str]]
    error: Optional[str]


class PerplexitySearchTool(Tool):
    """
    Tool for interacting with the Perplexity AI search API.

    The tool posts a chat-completion request to the Perplexity API, strips any
    ``<think>...</think>`` block from the answer, turns numeric citation
    markers into Markdown links and hands the result back as a ``SwarmResult``.

    Attributes:
        url (str): API endpoint URL.
        model (str): Name of the model to be used.
        api_key (str): API key for authenticating with the Perplexity API
            (taken from the constructor argument or ``PERPLEXITY_API_KEY``).
        max_tokens (int): Validated and stored, but currently NOT sent in the
            request payload.
        search_domain_filters (Optional[list[str]]): Optional list of domain filters for the search.
        request_timeout (Optional[float]): Timeout in seconds for the HTTP
            request; ``None`` waits indefinitely.
    """

    # Matches a "<think>...</think>" block (non-greedy, spanning newlines)
    # together with the whitespace that follows it.
    _THINK_BLOCK_RE = re.compile(r'<think>.*?</think>\s*', flags=re.DOTALL)
    # Matches numeric citation markers such as "[3]".
    _CITATION_RE = re.compile(r'\[(\d+)\]')

    def __init__(
        self,
        model: str = "sonar-pro",
        api_key: Optional[str] = None,
        max_tokens: int = 1000,
        search_domain_filter: Optional[list[str]] = None,
        timeout: Optional[float] = None,
    ):
        """
        Initializes a new instance of the PerplexitySearchTool.

        Args:
            model (str, optional): The model to use. Defaults to "sonar-pro".
            api_key (Optional[str], optional): API key for authentication. When omitted,
                the ``PERPLEXITY_API_KEY`` environment variable is used.
            max_tokens (int, optional): Maximum number of tokens for the response. Defaults to 1000.
            search_domain_filter (Optional[list[str]], optional): List of domain filters to restrict search.
            timeout (Optional[float], optional): Timeout in seconds for the HTTP request.
                Defaults to None (no timeout).

        Raises:
            ValueError: If the API key is missing, the model is empty, max_tokens is not positive,
                        search_domain_filter is not a list when provided, or timeout is not positive
                        when provided.
        """
        # Resolve the key once and store the RESOLVED value; storing the raw
        # argument would discard a key that came from the environment.
        resolved_api_key = api_key or os.getenv("PERPLEXITY_API_KEY")
        self._validate_tool_config(model, resolved_api_key, max_tokens, search_domain_filter)
        if timeout is not None and timeout <= 0:
            raise ValueError("timeout must be positive")

        self.url = "https://api.perplexity.ai/chat/completions"
        self.model = model
        self.api_key = resolved_api_key
        self.max_tokens = max_tokens
        self.search_domain_filters = search_domain_filter
        self.request_timeout = timeout
        super().__init__(
            name="perplexity-search",
            description="Perplexity AI search tool for web search, news search, and conversational search "
            "for finding answers to everyday questions, conducting in-depth research and analysis.",
            func_or_tool=self.search,
        )

    @staticmethod
    def _validate_tool_config(
        model: str, api_key: Union[str, None], max_tokens: int, search_domain_filter: Union[list[str], None]
    ) -> None:
        """
        Validates the configuration parameters for the search tool.

        Args:
            model (str): The model to use.
            api_key (Union[str, None]): The API key for authentication.
            max_tokens (int): Maximum tokens allowed.
            search_domain_filter (Union[list[str], None]): Domain filters for search.

        Raises:
            ValueError: If the API key is missing, model is empty, max_tokens is not positive,
                        or search_domain_filter is not a list.
        """
        if not api_key:
            raise ValueError("Perplexity API key is missing")
        if not model:
            raise ValueError("model cannot be empty")
        if max_tokens <= 0:
            raise ValueError("max_tokens must be positive")
        if search_domain_filter is not None and not isinstance(search_domain_filter, list):
            raise ValueError("search_domain_filter must be a list")

    def _execute_query(self, payload: dict[str, Any]) -> dict[str, Any]:
        """
        Executes a query by sending a POST request to the Perplexity API.

        Args:
            payload (dict[str, Any]): The payload to send in the API request.

        Returns:
            dict[str, Any]: The decoded JSON body of the API response.

        Raises:
            requests.HTTPError: If the API answers with a 4xx/5xx status.
            requests.RequestException: On connection problems or timeouts.
            ValueError: If the response body is not valid JSON.
        """
        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
        response = requests.post(self.url, headers=headers, json=payload, timeout=self.request_timeout)
        # Surface HTTP failures (bad key, rate limit, ...) as such instead of
        # letting them show up later as a missing "choices" key.
        response.raise_for_status()
        return response.json()

    @classmethod
    def _link_citations(cls, text: str, citations: list[str]) -> str:
        """Rewrites ``[n]`` markers in *text* as Markdown links to the n-th (1-based) citation."""

        def citation_repl(match: "re.Match[str]") -> str:
            number_str = match.group(1)
            index = int(number_str) - 1  # Citation markers are 1-based.
            if 0 <= index < len(citations):
                return f'[[{number_str}]({citations[index]})]'
            # Out-of-range markers are left untouched.
            return match.group(0)

        return cls._CITATION_RE.sub(citation_repl, text)

    def search(self, query: Annotated[str, "The search query."], context_variables: dict) -> SearchResponse:
        """
        Perform a search query using the Perplexity AI API.

        Constructs the payload, executes the query, removes any
        ``<think>...</think>`` block from the answer and converts numeric
        citation markers into Markdown links.

        Args:
            query (str): The search query. Only used when ``context_variables``
                holds no non-empty "perplexity_query" entry.
            context_variables (dict): Shared context. "perplexity_query" is read
                as the query to run; on success "perplexity_citations" (list of
                citations) and "perplexity_response" (answer without the think
                block, before citation linking) are written back.

        Returns:
            On success, a ``SwarmResult`` with ``agent=AfterWorkOption.TERMINATE``,
            the Markdown-formatted answer as ``values`` and the updated context
            variables. On failure, a ``SearchResponse`` whose ``error`` field
            describes the problem (no exception is propagated).
            NOTE(review): the return annotation only covers the failure case;
            it is left unchanged in case the tool framework inspects it.
        """
        # The query stored in the shared context takes precedence; fall back to
        # the explicit argument instead of failing when the context has none.
        query = context_variables.get("perplexity_query") or query
        payload = {
            "model": self.model,
            "messages": [{"role": "system", "content": "Be precise and concise."}, {"role": "user", "content": query}],
            "search_domain_filter": self.search_domain_filters,
        }
        try:
            perplexity_response = self._execute_query(payload)
            content = perplexity_response["choices"][0]["message"]["content"]
            # A response without a "citations" field is treated as having no
            # citations rather than discarding an otherwise valid answer.
            citations = perplexity_response.get("citations") or []
            context_variables["perplexity_citations"] = citations

            # Drop the "<think>...</think>" block before storing the answer.
            cleaned_response = self._THINK_BLOCK_RE.sub('', content)
            context_variables["perplexity_response"] = cleaned_response

            markdown_response = self._link_citations(cleaned_response, citations)

            return SwarmResult(
                agent=AfterWorkOption.TERMINATE,
                values=markdown_response,
                context_variables=context_variables,
            )
        except Exception as e:
            # Tool boundary: report any failure to the caller as data.
            return SearchResponse(
                content=None, citations=None, error=f"PerplexitySearchTool failed to search. Error: {e}"
            )
