#!/usr/bin/env python3

import json
import sys
import os
import re
import tempfile
from typing import Optional, List

try:
    from google import genai
    from google.genai import types
    from pydantic import BaseModel
except ImportError:
    print(
        "Error: google-genai or pydantic not installed. Run: pip install google-genai pydantic",
        file=sys.stderr,
    )
    sys.exit(0)


class SeverityBreakdown(BaseModel):
    # Findings grouped by severity; used as the type of
    # ValidationResponse.severity_breakdown. The three field names match the
    # BLOCK / WARN / INFO categories described in the validation prompt.
    # Every field is optional and defaults to an empty list.
    BLOCK: Optional[List[str]] = []
    WARN: Optional[List[str]] = []
    INFO: Optional[List[str]] = []


class ValidationResponse(BaseModel):
    # Structured-output schema handed to the model as ``response_schema`` by
    # ClaudeToolValidator.validate_tool_use.
    #
    # Only ``approved`` and ``reason`` are required; every other field is
    # optional and may be left out by the model.
    approved: bool
    reason: str
    suggestions: Optional[List[str]] = []
    detailed_analysis: Optional[str] = None
    # Used by validate_tool_use only as a fallback when no thought text was
    # collected from the response parts.
    thinking_process: Optional[str] = None
    # validate_tool_use reports the transcript context under the
    # "full_context" key itself rather than reading this field.
    full_context: Optional[str] = None
    performance_analysis: Optional[str] = None
    code_quality_analysis: Optional[str] = None
    alternative_approaches: Optional[List[str]] = []
    severity_breakdown: Optional[SeverityBreakdown] = None


class FileAnalysisResponse(BaseModel):
    # Structured-output schema handed to the model as ``response_schema`` by
    # ClaudeToolValidator.analyze_uploaded_file. All four fields are required.
    # A non-empty ``security_issues`` list causes validate_tool_use to block
    # the operation.
    security_issues: List[str]
    code_quality_concerns: List[str]
    risk_assessment: str
    recommendations: List[str]


class ClaudeToolValidator:
    def __init__(self, api_key: Optional[str] = None):
        """Initialise validator state.

        Args:
            api_key: Gemini API key. When it is missing or empty no client is
                created and ``self.client`` stays ``None``.
        """
        self.model_name = "gemini-2.5-pro"
        self.api_key = api_key

        # A client is only constructed when there is a key to give it.
        if api_key:
            self.client = genai.Client(api_key=api_key)
        else:
            self.client = None

        self.uploaded_files: List[dict] = []
        self.decision_history: List[dict] = []

    def upload_file_for_analysis(
        self, file_path: str, content: str
    ) -> Optional[object]:
        """Upload file content to Gemini for enhanced analysis"""
        if not self.client:
            return None
        try:
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=os.path.splitext(file_path)[1], delete=False
            ) as temp_file:
                temp_file.write(content)
                temp_file_path = temp_file.name

            uploaded_file = self.client.files.upload(file=temp_file_path)
            self.uploaded_files.append(
                {"file_obj": uploaded_file, "temp_path": temp_file_path}
            )
            return uploaded_file  # Return the file object directly
        except Exception:
            return None

    def analyze_uploaded_file(
        self, uploaded_file: object, file_path: str
    ) -> Optional[dict]:
        """Request a structured security review of an already uploaded file.

        Args:
            uploaded_file: File object to pass to the model next to the prompt.
            file_path: Original path; only its base name appears in the prompt.

        Returns:
            A dict with ``security_issues``, ``code_quality_concerns``,
            ``risk_assessment`` and ``recommendations`` when the response was
            parsed into ``FileAnalysisResponse``; otherwise the response text
            decoded as JSON. ``None`` when no client is configured or any
            exception occurs.
        """
        if not self.client:
            return None

        try:
            prompt = f"""Perform comprehensive security analysis of this file: {os.path.basename(file_path)}

Analyze for:
1. Security vulnerabilities (injections, exposures, dangerous functions)
2. Code quality issues (complexity, maintainability, best practices)
3. Configuration security (permissions, secrets, access controls)
4. Potential attack vectors and exploitation risks
5. Compliance with security standards

Focus on:
- Malicious code patterns
- Credential leaks or hardcoded secrets
- Unsafe file operations
- Network security issues
- Input validation gaps
- Authorization/authentication flaws

Provide structured assessment with specific security concerns and actionable recommendations."""

            generation_config = types.GenerateContentConfig(
                response_mime_type="application/json",
                response_schema=FileAnalysisResponse,
                thinking_config=types.ThinkingConfig(thinking_budget=24768),
            )

            # The uploaded file object travels alongside the prompt text.
            response = self.client.models.generate_content(
                model=self.model_name,
                contents=[prompt, uploaded_file],
                config=generation_config,
            )

            # No usable parsed object: fall back to decoding the raw text.
            if not (hasattr(response, "parsed") and response.parsed):
                return dict(json.loads(response.text))

            analysis = response.parsed
            return {
                "security_issues": analysis.security_issues,
                "code_quality_concerns": analysis.code_quality_concerns,
                "risk_assessment": analysis.risk_assessment,
                "recommendations": analysis.recommendations,
            }
        except Exception:
            return None

    def cleanup_uploaded_files(self) -> None:
        """Clean up uploaded files and temporary files"""
        for file_info in self.uploaded_files:
            try:
                if os.path.exists(file_info["temp_path"]):
                    os.unlink(file_info["temp_path"])
            except Exception:
                pass
        self.uploaded_files = []

    def before_tool_callback(self, tool_request: dict) -> Optional[dict]:
        """Validate one hook request and report whether it must be blocked.

        Args:
            tool_request: Hook payload; ``tool_name``, ``tool_input`` and
                ``transcript_path`` are read from it.

        Returns:
            ``None`` when the operation may proceed, including when validation
            itself raises (the hook fails open). Otherwise the full validation
            result with an added ``"error"`` key holding the block reason, so
            callers can use either ``"error"`` or ``"reason"`` / ``"approved"``
            / ``"suggestions"``.
        """
        tool_name = tool_request.get("tool_name", "")
        # A JSON null tool_input would otherwise break the rule-based checks.
        tool_input = tool_request.get("tool_input") or {}
        transcript_path = tool_request.get("transcript_path", "")

        try:
            context = self.extract_conversation_context(transcript_path)
            validation_result = self.validate_tool_use(tool_name, tool_input, context)
            approved = validation_result["approved"]
        except Exception:
            # Fail open: a broken validator must not stop the tool call.
            return None
        finally:
            self.cleanup_uploaded_files()

        if approved:
            return None

        reason = (
            validation_result.get("reason")
            or "Validation failed without a stated reason"
        )
        # Keep the whole result: the CLI report reads approved/reason/
        # suggestions, which a bare {"error": ...} would hide.
        return {**validation_result, "reason": reason, "error": reason}

    def extract_conversation_context(self, transcript_path: str) -> str:
        """Extract recent conversation context from transcript"""
        try:
            if os.path.exists(transcript_path):
                with open(transcript_path, "r", encoding="utf-8") as f:
                    content = f.read()
                return content
        except Exception:
            pass
        return ""

    def validate_tool_use(self, tool_name: str, tool_input: dict, context: str) -> dict:
        """Validate a tool request with rule-based checks, then Gemini.

        Flow:
            1. ``perform_quick_validation`` runs; a rejection is returned
               unchanged.
            2. Without an API key the quick-check result is returned.
            3. For Write/Edit/MultiEdit/Update requests whose ``content`` is
               longer than 500 characters, the content is uploaded and
               analysed; any reported security issue blocks the request.
            4. The full validation prompt is sent to the model only when a
               file analysis was produced and it reported no security issues.
               In every other case the quick-check result is returned.

        Args:
            tool_name: Name of the tool being invoked.
            tool_input: Parameters of the tool call.
            context: Conversation context embedded in the prompt.

        Returns:
            A dict that always contains ``"approved"``. If the model call
            fails, the request is approved with a "Validation service
            unavailable" reason (fail open). A model response that is not a
            JSON object with ``approved`` and ``reason`` is rejected.
        """
        quick_check = self.perform_quick_validation(tool_name, tool_input)
        if not quick_check["approved"]:
            return quick_check

        # Skip LLM-dependent analysis if no API key
        if not self.api_key:
            return quick_check

        file_analysis = None
        if (
            tool_name in ["Write", "Edit", "MultiEdit", "Update"]
            and "content" in tool_input
        ):
            file_path = tool_input.get("file_path", "")
            content = tool_input.get("content", "")

            if content and len(content) > 500:
                uploaded_file = self.upload_file_for_analysis(file_path, content)
                if uploaded_file:
                    file_analysis = self.analyze_uploaded_file(uploaded_file, file_path)
                    if file_analysis and file_analysis.get("security_issues"):
                        # str() guards against non-string items from the
                        # raw-JSON fallback of analyze_uploaded_file.
                        issues = ", ".join(
                            str(issue) for issue in file_analysis["security_issues"]
                        )
                        return {
                            "approved": False,
                            "reason": f"File analysis result: {issues}",
                            "suggestions": file_analysis.get("recommendations", []),
                            "file_analysis": file_analysis,
                        }

        # The quick check is known to be approved at this point, so without a
        # file analysis there is nothing further to evaluate.
        if not file_analysis:
            return quick_check

        try:
            prompt = self.build_validation_prompt(
                tool_name, tool_input, context, file_analysis
            )

            thinking_config = types.ThinkingConfig(thinking_budget=24576)

            assert self.client is not None  # Type guard for mypy
            response = self.client.models.generate_content(
                model=self.model_name,
                contents=prompt,
                config=types.GenerateContentConfig(
                    response_mime_type="application/json",
                    response_schema=ValidationResponse,
                    thinking_config=thinking_config,
                ),
            )

            raw_response = response.text if hasattr(response, "text") else str(response)

            # Collect thought text when the response carries any. In the
            # google-genai SDK ``part.thought`` is a boolean flag and the text
            # lives in ``part.text``; a string-valued ``thought`` is still
            # accepted in case another client version provides one.
            thinking_content = ""
            for candidate in getattr(response, "candidates", None) or []:
                candidate_content = getattr(candidate, "content", None)
                for part in getattr(candidate_content, "parts", None) or []:
                    thought_flag = getattr(part, "thought", None)
                    if not thought_flag:
                        continue
                    if isinstance(thought_flag, str):
                        thought_text = thought_flag
                    else:
                        thought_text = getattr(part, "text", None)
                    if thought_text:
                        thinking_content += thought_text + "\n"

            if hasattr(response, "parsed") and response.parsed:
                result = response.parsed
                severity = getattr(result, "severity_breakdown", None)
                return {
                    "approved": result.approved,
                    "reason": result.reason,
                    "suggestions": result.suggestions or [],
                    "detailed_analysis": getattr(result, "detailed_analysis", None),
                    "thinking_process": thinking_content
                    or getattr(result, "thinking_process", None),
                    "full_context": context,
                    "raw_response": raw_response,
                    "file_analysis": file_analysis,
                    "performance_analysis": getattr(
                        result, "performance_analysis", None
                    ),
                    "code_quality_analysis": getattr(
                        result, "code_quality_analysis", None
                    ),
                    "alternative_approaches": getattr(
                        result, "alternative_approaches", []
                    ),
                    "severity_breakdown": (
                        severity.model_dump() if severity is not None else None
                    ),
                }

            # No parsed object: decode the raw text and check its shape.
            result = json.loads(response.text)
            required_fields = ("approved", "reason")
            if not isinstance(result, dict) or not all(
                field in result for field in required_fields
            ):
                return {
                    "approved": False,
                    "reason": "Invalid response structure from validation service",
                    "raw_response": raw_response,
                    "full_context": context,
                }

            result["thinking_process"] = thinking_content
            result["full_context"] = context
            result["raw_response"] = raw_response
            result["file_analysis"] = file_analysis
            # Make sure the optional keys exist even when the model left
            # them out.
            result.setdefault("performance_analysis", None)
            result.setdefault("code_quality_analysis", None)
            result.setdefault("alternative_approaches", [])
            result.setdefault("severity_breakdown", None)
            return dict(result)
        except Exception as e:
            # Fail open: an unavailable validation service must not block work.
            return {
                "approved": True,
                "reason": f"Validation service unavailable: {str(e)}",
            }

    def perform_quick_validation(self, tool_name: str, tool_input: dict) -> dict:
        """Tier 1: route the request to the matching rule-based checker.

        Bash requests go to ``validate_bash_command``; Write, Edit, MultiEdit
        and Update requests go to ``validate_file_operation``; every other
        tool is approved without inspection.
        """
        if tool_name == "Bash":
            return self.validate_bash_command(tool_input)

        file_tools = ("Write", "Edit", "MultiEdit", "Update")
        if tool_name in file_tools:
            return self.validate_file_operation(tool_input)

        return {"approved": True}

    def validate_bash_command(self, tool_input: dict) -> dict:
        """Apply rule-based checks to a Bash command.

        Three rule sets are evaluated in order, all case-insensitively:
            1. Critical patterns (destructive or remote-code-execution
               commands) block the command.
            2. Tool-enforcement patterns block ``grep``, ``find -name`` and
               direct ``python`` when they start the command, and suggest a
               replacement.
            3. Warning patterns approve the command but tag it with
               ``risk_level: "high"``.

        Args:
            tool_input: Tool parameters; ``command`` is read from it. A
                missing or null command is treated as an empty string.

        Returns:
            A dict with ``approved`` and, depending on the rule hit,
            ``reason``, ``suggestions`` and ``risk_level``.
        """
        command = tool_input.get("command") or ""
        if not isinstance(command, str):
            command = str(command)

        critical_patterns = [
            r"rm\s+-rf\s+/",
            r"sudo\s+rm.*/",
            r"mkfs",
            r"dd\s+if=.*of=.*",
            # Piping a download into a shell. The word boundary keeps tools
            # such as sha256sum/shuf from matching "sh".
            r"(curl|wget).*\|\s*(bash|sh)\b",
            # Redirects into system directories, with or without a space
            # after ">".
            r">\s*/(etc|bin|usr)/",
        ]

        if any(
            re.search(pattern, command, re.IGNORECASE)
            for pattern in critical_patterns
        ):
            return {
                "approved": False,
                "reason": "Dangerous command pattern detected: potentially destructive operation",
            }

        # Performance/tool enforcement - block inefficient commands.
        # These only match at the very start of the command string.
        tool_enforcement = [
            (
                r"^\s*grep\b",
                "Use 'rg' (ripgrep) instead of 'grep' for better performance and features. Command blocked to enforce best practices.",
            ),
            (
                r"^\s*find\s+.*-name\b",
                "Use 'rg --files -g pattern' or 'rg --files | rg pattern' instead of 'find -name' for better performance. Command blocked to enforce best practices.",
            ),
            (
                r"^\s*python3?\b",
                "Use 'uv run python' instead of direct python for better dependency management and virtual environment handling. Command blocked to enforce best practices.",
            ),
        ]

        for pattern, suggestion in tool_enforcement:
            if re.search(pattern, command, re.IGNORECASE):
                return {
                    "approved": False,
                    "reason": suggestion,
                    # The suggestion list carries only the advice sentence.
                    "suggestions": [suggestion.split(". Command blocked")[0]],
                }

        warning_patterns = [
            r"\bsudo\b",  # whole word only, so "pseudo" is not flagged
            r"rm\s+-rf",
            r"git\s+reset\s+--hard",
            r"npm\s+uninstall",
            r"pip\s+uninstall",
        ]

        if any(
            re.search(pattern, command, re.IGNORECASE)
            for pattern in warning_patterns
        ):
            return {
                "approved": True,
                "reason": "Command requires elevated privileges or has destructive potential",
                "risk_level": "high",
            }

        return {"approved": True}

    def validate_file_operation(self, tool_input: dict) -> dict:
        """Apply rule-based checks to a file write or edit.

        Checks, in order: path traversal or system-directory targets,
        well-known secret formats in the text being written, and a crude
        shell-injection heuristic for Python files.

        The text that is scanned is everything the operation would write:
        ``content`` (Write), ``new_string`` (Edit) and each
        ``edits[i].new_string`` (MultiEdit).

        Args:
            tool_input: Tool parameters. Missing or null ``file_path`` and
                text fields are treated as empty.

        Returns:
            A dict with ``approved`` and, depending on the rule hit,
            ``reason``, ``suggestions`` and ``risk_level``.
        """
        file_path = tool_input.get("file_path") or ""
        if not isinstance(file_path, str):
            file_path = str(file_path)

        # Gather every text fragment the operation would write so that Edit
        # and MultiEdit payloads are scanned too, not only Write content.
        fragments = [tool_input.get("content"), tool_input.get("new_string")]
        edits = tool_input.get("edits")
        if isinstance(edits, list):
            fragments.extend(
                edit.get("new_string") for edit in edits if isinstance(edit, dict)
            )
        content = "\n".join(
            fragment for fragment in fragments if isinstance(fragment, str)
        )

        # /usr/ is included to match the redirect rules for Bash commands.
        if "../" in file_path or file_path.startswith(("/etc/", "/bin/", "/usr/")):
            return {
                "approved": False,
                "reason": "Potentially dangerous file path - outside project boundary or system directory",
                "risk_level": "critical",
            }

        # Basic secret detection patterns (LLM handles sophisticated cases)
        secret_patterns = [
            (r"sk_live_[a-zA-Z0-9]{24,}", "Stripe live secret key detected"),
            (r"sk_test_[a-zA-Z0-9]{24,}", "Stripe test secret key detected"),
            (r"AKIA[0-9A-Z]{16}", "AWS access key ID detected"),
            (r"ghp_[a-zA-Z0-9]{36}", "GitHub personal access token detected"),
            (r"gho_[a-zA-Z0-9]{36}", "GitHub OAuth token detected"),
            (r"ghr_[a-zA-Z0-9]{36}", "GitHub refresh token detected"),
            (
                r"eyJ[a-zA-Z0-9_-]+\.eyJ[a-zA-Z0-9_-]+\.[a-zA-Z0-9_-]+",
                "JWT token detected",
            ),
            (
                r"xoxb-[0-9]{10,13}-[0-9]{10,13}-[a-zA-Z0-9]{24}",
                "Slack bot token detected",
            ),
            (r"-----BEGIN [A-Z]+ PRIVATE KEY-----", "Private key detected"),
        ]

        for pattern, message in secret_patterns:
            if re.search(pattern, content):
                return {
                    "approved": False,
                    "reason": f"Security violation: {message}",
                    "suggestions": [
                        "Use environment variables for secrets",
                        "Consider using a secrets manager like AWS Secrets Manager",
                        "Never commit real credentials to version control",
                    ],
                }

        # Substring heuristic only: both markers anywhere in the text trigger it.
        if file_path.endswith(".py") and (
            "import os" in content and "system(" in content
        ):
            return {
                "approved": False,
                "reason": "Potential shell injection pattern in Python code",
                "risk_level": "high",
            }

        return {"approved": True, "risk_level": "low"}

    def build_validation_prompt(
        self,
        tool_name: str,
        tool_input: dict,
        context: str,
        file_analysis: Optional[dict] = None,
    ) -> str:
        """Assemble the Gemini validation prompt for one tool request.

        Args:
            tool_name: Name of the tool being invoked.
            tool_input: Tool parameters; embedded as indented JSON.
            context: Conversation context; embedded verbatim.
            file_analysis: Optional result of a prior file analysis. When
                given, its findings are added as a FILE_ANALYSIS_RESULTS
                section.

        Returns:
            The complete prompt text.
        """
        if not file_analysis:
            file_analysis_section = ""
        else:
            issues_json = json.dumps(
                file_analysis.get("security_issues", []), indent=2
            )
            quality_json = json.dumps(
                file_analysis.get("code_quality_concerns", []), indent=2
            )
            risk_text = file_analysis.get("risk_assessment", "Not available")
            advice_json = json.dumps(
                file_analysis.get("recommendations", []), indent=2
            )
            file_analysis_section = f"""
FILE_ANALYSIS_RESULTS:
Security Issues Detected: {issues_json}
Code Quality Concerns: {quality_json}
Risk Assessment: {risk_text}
Recommendations: {advice_json}
"""

        serialized_input = json.dumps(tool_input, indent=2)

        return f"""You are a comprehensive code analyzer for Claude Code. You provide security validation AND constructive feedback on performance, code quality, and best practices for ALL operations.

TOOL: {tool_name}
INPUT: {serialized_input}
RECENT_CONTEXT: {context}
{file_analysis_section}

ANALYSIS SCOPE:
1. **Security** - Identify vulnerabilities and dangerous patterns
2. **Performance** - Suggest faster tools and more efficient approaches
3. **Code Quality** - Promote clean, self-evident, pythonic code
4. **Best Practices** - Recommend modern patterns and industry standards
5. **Alternative Approaches** - Suggest different ways to achieve the same goal

CRITICAL SECURITY RULES - YOU MUST BLOCK THESE:

1. **ALWAYS BLOCK Real Credentials/Secrets**:
   - AWS keys (AKIA..., 40-char base64 strings)
   - GitHub tokens (ghp_..., gho_..., ghr_...)
   - Stripe keys (sk_live_..., pk_live_...)
   - JWT tokens (ey...ey...signature)
   - API keys that look real (long, random, not placeholders)
   - Actual passwords in configuration files
   - Private keys, certificates, or tokens

2. **ALWAYS BLOCK Dangerous Commands**:
   - System destruction (rm -rf /, mkfs, dd to devices)
   - Malicious downloads (curl | bash, wget | sh)
   - Privilege escalation attempts
   - Network exfiltration
   - System file modifications outside project

3. **ALWAYS ALLOW Documentation/Examples**:
   - Placeholder values: "YOUR_API_KEY", "xxx", "...", "<SECRET>"
   - Variable names in docs: GEMINI_API_KEY (without assignment)
   - Example configurations with fake values
   - Safe development commands (ls, git, npm, pip)

ANALYSIS FRAMEWORK:
Search current security intelligence and analyze:

**Secret Detection Logic**:
- Real secrets: Block immediately (approved: false)
- Placeholders/docs: Allow (approved: true)
- Example: "GEMINI_API_KEY=your_key_here" = ALLOW (placeholder)
- Example: "api_key=sk_live_abc123def456..." = BLOCK (real Stripe key)

**Command Analysis**:
- Check for obfuscation, encoding, chaining
- Analyze full command context and impact
- Consider privilege escalation potential
- Look for performance improvement opportunities (suggest rg over grep)

**Context Awareness**:
- Is this legitimate development activity?
- Does the action match user intent?
- Are there safer alternatives?

**Challenge and Improve**:
- Challenge the approach: Is there a better way to achieve this goal?
- Question assumptions: Are there hidden risks or better practices?
- Suggest improvements: Modern tools, security practices, performance optimizations
- Educational feedback: Help the user learn safer development practices

USE YOUR THINKING BUDGET to reason through complex scenarios. Consider social engineering, supply chain attacks, and advanced threats.

RESPONSE REQUIREMENTS:
1. **Decision**: approved: true/false with clear reasoning
2. **Risk Level**: low/medium/high/critical based on overall assessment
3. **Suggestions**: Provide 2-3 specific actionable improvements
4. **Detailed Analysis**: Comprehensive security evaluation
5. **Performance Analysis**: Evaluate efficiency and suggest optimizations:
   - Tool performance (e.g., "rg is 10x faster than grep")
   - Algorithm complexity
   - Resource usage
   - Caching opportunities
6. **Code Quality Analysis**: Assess clarity and maintainability:
   - Is the code self-evident?
   - Can it be simplified?
   - Does it follow language idioms?
   - Are comments necessary or can the code speak for itself?
7. **Alternative Approaches**: List different ways to achieve the same goal:
   - Modern tool alternatives
   - Different command patterns
   - Framework-specific solutions
8. **Severity Breakdown**: Categorize all findings:
   - BLOCK: Critical issues that prevent execution
   - WARN: Important issues that should be addressed
   - INFO: Helpful tips and optimizations
9. **Thinking Process**: Document your reasoning

ANALYSIS PRINCIPLES:
- Provide value on EVERY operation, not just failures
- Be a mentor, not just a guard
- Challenge approaches constructively
- Suggest modern, efficient alternatives
- Promote self-evident code that needs no comments
- Consider both immediate and long-term implications

DECISION CRITERIA:
- approved: true = Operation can proceed (may have WARN/INFO items)
- approved: false = Operation blocked (has BLOCK items)
- Always be educational and constructive
- Focus on helping developers write better code

Analyze comprehensively, teach continuously, and help developers level up their skills with every interaction."""


def main() -> None:
    """Run the validator as a Claude Code hook.

    Reads the hook request as JSON from stdin. Exits with status 0 when the
    input is not a JSON object, when ``GEMINI_API_KEY`` is not set, or when
    the validator returns ``None``. Otherwise a report describing the blocked
    operation is written to stderr and the process exits with status 2.
    """

    def emit(*parts: object) -> None:
        # Everything this hook reports goes to stderr.
        print(*parts, file=sys.stderr)

    def emit_delimited(title: str, text: object) -> None:
        # Free-form text framed by dashed rules.
        emit(f"\n{title}")
        emit("-" * 60)
        emit(text)
        emit("-" * 60)

    def emit_bullets(title: str, items: list) -> None:
        emit(f"\n{title}")
        for item in items:
            emit(f"  • {item}")

    try:
        hook_input = json.loads(sys.stdin.read())
    except json.JSONDecodeError:
        emit("Invalid JSON input")
        sys.exit(0)

    # Valid JSON that is not an object cannot be a hook request; exit the
    # same way as for unparsable input instead of crashing with a traceback.
    if not isinstance(hook_input, dict):
        emit("Hook input must be a JSON object")
        sys.exit(0)

    api_key = os.environ.get("GEMINI_API_KEY")
    if not api_key:
        emit("GEMINI_API_KEY not configured - allowing all operations")
        sys.exit(0)

    validator = ClaudeToolValidator(api_key)
    validation = validator.before_tool_callback(hook_input)

    if validation is None:
        sys.exit(0)

    # A non-None result always describes a blocked operation, so a missing
    # "approved" key must not be reported as an approval. The reason may be
    # carried under "reason" or only under "error".
    approved = bool(validation.get("approved", False))
    reason = (
        validation.get("reason")
        or validation.get("error")
        or "No reason provided"
    )

    emit("=" * 80)
    emit("CLAUDE CODE SECURITY VALIDATION ANALYSIS")
    emit("=" * 80)

    emit(f"\nDECISION: {'BLOCKED' if not approved else 'APPROVED'}")
    emit()
    emit(f"REASON: {reason}")

    if validation.get("suggestions"):
        emit("\nSUGGESTIONS:")
        for i, suggestion in enumerate(validation.get("suggestions", []), 1):
            emit(f"  {i}. {suggestion}")

    if validation.get("thinking_process"):
        emit_delimited("THINKING PROCESS:", validation.get("thinking_process"))

    file_analysis = validation.get("file_analysis")
    if file_analysis:
        emit("\nFILE ANALYSIS:")

        if file_analysis.get("security_issues"):
            emit_bullets("SECURITY ISSUES:", file_analysis.get("security_issues", []))

        if file_analysis.get("code_quality_concerns"):
            emit_bullets(
                "CODE QUALITY CONCERNS:",
                file_analysis.get("code_quality_concerns", []),
            )

        if file_analysis.get("risk_assessment"):
            emit("\nRISK ASSESSMENT:")
            emit(f"  {file_analysis.get('risk_assessment')}")

        if file_analysis.get("recommendations"):
            emit_bullets(
                "FILE RECOMMENDATIONS:", file_analysis.get("recommendations", [])
            )

    if validation.get("full_context"):
        emit_delimited("FULL CONTEXT:", validation.get("full_context"))

    if validation.get("raw_response"):
        emit_delimited("RAW LLM RESPONSE:", validation.get("raw_response"))

    emit("\nVALIDATION SUMMARY:")
    # default=str keeps an unexpected non-JSON value from raising here,
    # which would prevent the exit status below from being reached.
    emit(json.dumps(validation, indent=2, default=str))
    emit("=" * 80)

    sys.exit(2)


# Run the hook only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
