"""Claude MCP Agent implementation."""

from __future__ import annotations

import copy
import logging
import re
from inspect import cleandoc
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast

from anthropic import Anthropic, AsyncAnthropic, Omit
from anthropic.types import (
    CacheControlEphemeralParam,
)
from anthropic.types.beta import (
    BetaBase64ImageSourceParam,
    BetaContentBlockParam,
    BetaImageBlockParam,
    BetaMessageParam,
    BetaTextBlockParam,
    BetaToolBash20250124Param,
    BetaToolComputerUse20250124Param,
    BetaToolParam,
    BetaToolResultBlockParam,
    BetaToolTextEditor20250728Param,
    BetaToolUnionParam,
)

import hud

if TYPE_CHECKING:
    from hud.datasets import Task

import mcp.types as types

from hud.settings import settings
from hud.tools.computer.settings import computer_settings
from hud.types import AgentResponse, MCPToolCall, MCPToolResult
from hud.utils.hud_console import HUDConsole

from .base import MCPAgent

logger = logging.getLogger(__name__)


class ClaudeAgent(MCPAgent):
    """
    Claude agent that uses MCP servers for tool execution.

    This agent uses Claude's native tool calling capabilities but executes
    tools through MCP servers instead of direct implementation.
    """

    # Display geometry advertised for the native computer tool (same values are sent in
    # the computer tool definition built by _convert_tools_for_claude).
    metadata: ClassVar[dict[str, Any]] = {
        "display_width": computer_settings.ANTHROPIC_COMPUTER_WIDTH,
        "display_height": computer_settings.ANTHROPIC_COMPUTER_HEIGHT,
    }

    def __init__(
        self,
        model_client: AsyncAnthropic | None = None,
        model: str = "claude-sonnet-4-5",
        max_tokens: int = 16384,
        use_computer_beta: bool = True,
        validate_api_key: bool = True,
        computer_tool_regex: str = r"(^|_)(anthropic_computer|computer_anthropic|computer)$",
        **kwargs: Any,
    ) -> None:
        """
        Initialize Claude MCP agent.

        Args:
            model_client: AsyncAnthropic client (created from ``settings.anthropic_api_key``
                if not provided)
            model: Claude model to use
            max_tokens: Maximum tokens for response
            use_computer_beta: Whether to use computer-use beta features. Stored on the
                instance; in this class the beta header is sent whenever a computer tool
                was discovered.
            validate_api_key: If True, issue a blocking ``models.list()`` request against
                the client's endpoint to verify the key up front
            computer_tool_regex: regex (applied with ``re.fullmatch``) identifying the MCP
                tool that is exposed to Claude as its native ``computer`` tool
            **kwargs: Additional arguments passed to MCPAgent (including mcp_client)

        Raises:
            ValueError: if no API key is configured, or validation of the key fails.
        """
        super().__init__(**kwargs)

        # Initialize client if not provided
        if model_client is None:
            api_key = settings.anthropic_api_key
            if not api_key:
                raise ValueError("Anthropic API key not found. Set ANTHROPIC_API_KEY.")
            model_client = AsyncAnthropic(api_key=api_key)

        # validate api key if requested
        if validate_api_key:
            try:
                # Throwaway sync client aimed at the same endpoint as model_client, so
                # custom base URLs / proxies are validated against the right host.
                Anthropic(
                    api_key=model_client.api_key,
                    base_url=model_client.base_url,
                ).models.list()
            except Exception as e:
                raise ValueError(f"Anthropic API key is invalid: {e}") from e

        self.anthropic_client = model_client
        self.model = model
        self.max_tokens = max_tokens
        self.use_computer_beta = use_computer_beta
        self.hud_console = HUDConsole(logger=logger)

        self.model_name = "Claude"
        self.checkpoint_name = self.model

        self.computer_tool_regex = computer_tool_regex

        # these will be initialized in _convert_tools_for_claude
        self.has_computer_tool = False
        # Claude-facing tool name -> MCP tool name
        self.tool_mapping: dict[str, str] = {}
        self.claude_tools: list[BetaToolUnionParam] = []

    async def initialize(self, task: str | Task | None = None) -> None:
        """Initialize the agent, then build the Claude tool list from the discovered tools."""
        await super().initialize(task)
        # Build tool mappings after tools are discovered
        self._convert_tools_for_claude()

    async def get_system_messages(self) -> list[Any]:
        """Return no system messages; the system prompt is passed via ``system`` in get_response."""
        return []

    @staticmethod
    def _image_to_content_block(image: types.ImageContent) -> BetaImageBlockParam:
        """Convert MCP image content to a base64 Claude image block, keeping its MIME type."""
        return BetaImageBlockParam(
            type="image",
            source=BetaBase64ImageSourceParam(
                type="base64",
                media_type=cast(
                    "Literal['image/jpeg', 'image/png', 'image/gif', 'image/webp']",
                    image.mimeType,
                ),
                data=image.data,
            ),
        )

    async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[Any]:
        """Wrap MCP content blocks in a single Claude user message.

        Raises:
            ValueError: for MCP block types other than text and image.
        """
        anthropic_blocks: list[BetaContentBlockParam] = []

        for block in blocks:
            if isinstance(block, types.TextContent):
                # Only include fields that Anthropic expects
                anthropic_blocks.append(BetaTextBlockParam(type="text", text=block.text))
            elif isinstance(block, types.ImageContent):
                anthropic_blocks.append(self._image_to_content_block(block))
            else:
                raise ValueError(f"Unknown content block type: {type(block)}")

        return [BetaMessageParam(role="user", content=anthropic_blocks)]

    @hud.instrument(
        span_type="agent",
        record_args=False,  # Messages can be large
        record_result=True,
    )
    async def get_response(self, messages: list[BetaMessageParam]) -> AgentResponse:
        """Get response from Claude including any tool calls.

        Side effect: the assistant reply is appended to ``messages`` (the caller's list).
        The request itself uses a cache-annotated deep copy of ``messages``.
        """
        messages_cached = self._add_prompt_caching(messages)

        # The API only accepts tool_choice together with tools, so omit both when empty.
        has_tools = bool(self.claude_tools)
        response = await self.anthropic_client.beta.messages.create(
            model=self.model,
            system=self.system_prompt if self.system_prompt is not None else Omit(),
            max_tokens=self.max_tokens,
            messages=messages_cached,
            tools=self.claude_tools if has_tools else Omit(),
            tool_choice=(
                {"type": "auto", "disable_parallel_tool_use": True} if has_tools else Omit()
            ),
            betas=["computer-use-2025-01-24"] if self.has_computer_tool else Omit(),
        )

        messages.append(
            BetaMessageParam(
                role="assistant",
                content=response.content,
            )
        )

        # done stays True unless at least one tool call is requested
        result = AgentResponse(content="", tool_calls=[], done=True)

        text_content = ""
        thinking_content = ""

        for block in response.content:
            if block.type == "tool_use":
                tool_call = MCPToolCall(
                    id=block.id,
                    # map Claude's name (e.g. "computer") back to the MCP tool name
                    name=self.tool_mapping.get(block.name, block.name),
                    arguments=block.input,
                )
                result.tool_calls.append(tool_call)
                result.done = False
            elif block.type == "text":
                text_content += block.text
            elif block.type == "thinking":
                thinking_content += f"Thinking: {block.thinking}\n"

        result.content = thinking_content + text_content

        return result

    async def format_tool_results(
        self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult]
    ) -> list[BetaMessageParam]:
        """Format tool results as one Claude user message of ``tool_result`` blocks.

        Calls without an id are skipped with a warning. Failed results carry
        ``is_error: True`` and an ``"Error: ..."`` text built from their first text part.
        """
        user_content: list[BetaToolResultBlockParam] = []

        for tool_call, result in zip(tool_calls, tool_results, strict=True):
            tool_use_id = tool_call.id
            if not tool_use_id:
                self.hud_console.warning(f"No tool_use_id found for {tool_call.name}")
                continue

            claude_blocks: list[BetaTextBlockParam | BetaImageBlockParam] = []
            if result.isError:
                # Use the first text part as the error message, if any.
                error_msg = next(
                    (c.text for c in result.content if isinstance(c, types.TextContent)),
                    "Tool execution failed",
                )
                claude_blocks.append(text_to_content_block(f"Error: {error_msg}"))
            else:
                for content in result.content:
                    if isinstance(content, types.TextContent):
                        claude_blocks.append(text_to_content_block(content.text))
                    elif isinstance(content, types.ImageContent):
                        # Keep the real MIME type instead of assuming PNG.
                        claude_blocks.append(self._image_to_content_block(content))

            tool_result = tool_use_content_block(tool_use_id, claude_blocks)
            if result.isError:
                # Lets Claude tell failures apart from ordinary tool output.
                tool_result["is_error"] = True
            user_content.append(tool_result)

        # Return as a user message containing all tool results
        return [BetaMessageParam(role="user", content=user_content)]

    async def create_user_message(self, text: str) -> BetaMessageParam:
        """Create a plain-text user message in Claude's format."""
        return BetaMessageParam(role="user", content=text)

    def _convert_tools_for_claude(self) -> None:
        """Rebuild ``claude_tools``, ``tool_mapping`` and ``has_computer_tool`` from MCP tools."""

        def to_api_tool(tool: types.Tool) -> BetaToolUnionParam:
            """Convert a single MCP tool to the Claude API format.

            Known names map to Anthropic's native tool types; anything else becomes a
            custom tool built from the MCP description and input schema.
            """
            if tool.name == "str_replace_based_edit_tool":
                return BetaToolTextEditor20250728Param(
                    type="text_editor_20250728",
                    name="str_replace_based_edit_tool",
                )
            if tool.name == "bash":
                return BetaToolBash20250124Param(type="bash_20250124", name="bash")
            if re.fullmatch(self.computer_tool_regex, tool.name):
                # Claude only knows this tool as "computer"; tool_mapping restores the MCP name.
                return BetaToolComputerUse20250124Param(
                    type="computer_20250124",
                    name="computer",
                    display_number=1,
                    display_width_px=computer_settings.ANTHROPIC_COMPUTER_WIDTH,
                    display_height_px=computer_settings.ANTHROPIC_COMPUTER_HEIGHT,
                )

            if tool.description is None or tool.inputSchema is None:
                raise ValueError(
                    cleandoc(f"""MCP tool {tool.name} requires both a description and inputSchema.
                    Add these by:
                    1. Adding a docstring to your @mcp.tool decorated function for the description
                    2. Using pydantic Field() annotations on function parameters for the schema
                    """)
                )
            return BetaToolParam(
                name=tool.name,
                description=tool.description,
                input_schema=tool.inputSchema,
            )

        self.has_computer_tool = False
        self.tool_mapping = {}
        self.claude_tools = []
        for tool in self.get_available_tools():
            claude_tool = to_api_tool(tool)
            # Only one native computer tool may be exposed; keep the first one found.
            if claude_tool["name"] == "computer":
                if self.has_computer_tool:
                    logger.warning(
                        "Multiple computer tools found. Ignoring %s since %s is already present",
                        tool.name,
                        self.tool_mapping["computer"],
                    )
                    continue
                self.has_computer_tool = True
            self.tool_mapping[claude_tool["name"]] = tool.name
            self.claude_tools.append(claude_tool)

        # A single breakpoint on the last tool caches the whole tool list. The API allows
        # at most 4 cache_control blocks per request, so one per tool would fail with >4 tools.
        if self.claude_tools:
            self.claude_tools[-1]["cache_control"] = CacheControlEphemeralParam(type="ephemeral")

    def _add_prompt_caching(self, messages: list[BetaMessageParam]) -> list[BetaMessageParam]:
        """Return a deep copy of ``messages`` with a cache breakpoint on the last user turn.

        Only the final cacheable block of the last message is marked, and only when that
        message is a user message with list content. One breakpoint caches everything
        before it, and the API caps requests at 4 ``cache_control`` blocks. The caller's
        ``messages`` are not modified.
        """
        messages_cached = copy.deepcopy(messages)
        if not messages_cached:
            return messages_cached

        last_message = messages_cached[-1]
        if not isinstance(last_message, dict) or last_message.get("role") != "user":
            return messages_cached

        last_content = last_message["content"]
        # Content is a list of blocks when built by format_blocks / format_tool_results;
        # plain-string content (create_user_message) has no block to annotate.
        if not isinstance(last_content, list):
            return messages_cached

        for block in reversed(last_content):
            # Thinking blocks cannot carry cache_control; non-dict blocks are left untouched.
            if isinstance(block, dict) and block.get("type") not in (
                "thinking",
                "redacted_thinking",
            ):
                block["cache_control"] = CacheControlEphemeralParam(type="ephemeral")
                break

        return messages_cached


def base64_to_content_block(
    base64: str,
    media_type: Literal["image/jpeg", "image/png", "image/gif", "image/webp"] = "image/png",
) -> BetaImageBlockParam:
    """Convert base64-encoded image data to a Claude image content block.

    Args:
        base64: Base64-encoded image bytes.
        media_type: MIME type of the encoded image. Defaults to ``"image/png"`` for
            backward compatibility; pass the real type for JPEG/GIF/WebP data.

    Returns:
        A ``{"type": "image", "source": {"type": "base64", ...}}`` block.
    """
    return {
        "type": "image",
        "source": {"type": "base64", "media_type": media_type, "data": base64},
    }


def text_to_content_block(text: str) -> BetaTextBlockParam:
    """Wrap a plain string in a Claude ``text`` content block.

    Args:
        text: The text to send.

    Returns:
        A ``{"type": "text", "text": text}`` dict.
    """
    block: BetaTextBlockParam = {"type": "text", "text": text}
    return block


def tool_use_content_block(
    tool_use_id: str, content: list[BetaTextBlockParam | BetaImageBlockParam]
) -> BetaToolResultBlockParam:
    """Build a ``tool_result`` block answering the ``tool_use`` with id ``tool_use_id``.

    Args:
        tool_use_id: ID of the tool_use block this result responds to.
        content: Text/image blocks carrying the tool output (stored by reference).

    Returns:
        A ``tool_result`` content block dict.
    """
    result_block: BetaToolResultBlockParam = {
        "type": "tool_result",
        "tool_use_id": tool_use_id,
        "content": content,
    }
    return result_block
