#!/usr/bin/env python3

import json
import re
from typing import Optional, Dict, Any, List
from pydantic import BaseModel

try:
    from google import genai
    from google.genai import types
except ImportError:
    genai = None

from .tdd_prompts import (
    TDDCorePrompt,
    EditAnalysisPrompt,
    WriteAnalysisPrompt,
    MultiEditAnalysisPrompt,
    TDDContextFormatter,
)
from .config import GEMINI_MODEL, FILE_CATEGORIZATION_MODEL, TDD_THINKING_BUDGET


class TDDValidationResponse(BaseModel):
    """TDD-specific validation response model"""

    # This model is handed to Gemini as `response_schema` in
    # TDDValidator.execute_tdd_validation, and its fields are copied into the
    # dict that method returns.
    # NOTE(review): the class docstring and field names may be surfaced in the
    # schema sent to the model — confirm before rewording or renaming them.

    # Final verdict: True lets the operation proceed, False blocks it.
    approved: bool
    # Kind of TDD violation detected, if any; None when nothing was flagged.
    violation_type: Optional[str] = (
        None  # "multiple_tests", "over_implementation", "premature_implementation"
    )
    # Number of new tests the model detected in the operation.
    test_count: Optional[int] = None
    # Files the verdict applies to; when left empty, execute_tdd_validation
    # substitutes the file list supplied by its caller.
    affected_files: List[str] = []
    tdd_phase: str = "unknown"  # "red", "green", "refactor", "unknown"
    # Explanation of the decision.
    reason: str = ""
    # Actionable TDD improvements proposed by the model.
    suggestions: List[str] = []
    # Optional longer-form assessment.
    detailed_analysis: Optional[str] = None


class FileCategorizationResponse(BaseModel):
    """Response model for file categorization"""

    # This model is handed to Gemini as `response_schema` in
    # TDDValidator.categorize_file.
    # NOTE(review): the class docstring and field names may be surfaced in the
    # schema sent to the model — confirm before rewording or renaming them.

    category: str  # "config", "docs", "data", "test", "implementation"
    # Explanation for the chosen category.
    reason: str
    # When False, TDDValidator.validate approves the operation without running
    # the operation-specific TDD checks.
    requires_tdd: bool


class TDDValidator:
    """
    TDDValidator enforces Test-Driven Development principles through:
    - Operation-specific validation (Edit, Write, MultiEdit)
    - Red-Green-Refactor cycle enforcement
    - New test count detection
    - Over-implementation prevention
    """

    def __init__(self, api_key: Optional[str] = None):
        """Remember the API key and build a Gemini client when one can be made.

        Args:
            api_key: Gemini API key. When it is missing/empty, or the
                google-genai package could not be imported, no client is
                created and ``self.client`` is None.
        """
        self.api_key = api_key
        self.model_name = GEMINI_MODEL

        # A client needs both a key and a successfully imported SDK.
        if api_key and genai:
            self.client = genai.Client(api_key=api_key)
        else:
            self.client = None

    def categorize_file(self, file_path: str, content: str = "") -> dict:
        """
        Categorize a file so callers can decide whether TDD rules apply.

        With an API key and client available, the decision is delegated to the
        LLM. Otherwise a simple extension check is used. Any failure while
        talking to the LLM (or an unusable reply) is treated conservatively:
        the file is reported as implementation code that requires TDD.

        Args:
            file_path: Path of the file being written or edited.
            content: Optional file content; only the first 500 characters are
                included in the prompt.

        Returns dict with:
            category: 'config', 'docs', 'data', 'test', 'implementation'
            requires_tdd: bool
            reason: str
        """
        if not self.api_key or not self.client:
            # Fallback to basic categorization if no API
            import os

            ext = os.path.splitext(file_path)[1].lower()
            if ext in (".py", ".js", ".ts", ".java", ".go"):
                return {
                    "category": "implementation",
                    "requires_tdd": True,
                    "reason": "Code file",
                }
            return {
                "category": "data",
                "requires_tdd": False,
                "reason": "Non-code file",
            }

        prompt = f"""Analyze this file and categorize it for TDD (Test-Driven Development) validation purposes.

FILE PATH: {file_path}
CONTENT PREVIEW (first 500 chars): {content[:500] if content else "No content provided"}

Categories:
- 'config': Configuration files (pyproject.toml, package.json, .env, etc.) - NO TDD REQUIRED
- 'docs': Documentation files (README.md, docs/, etc.) - NO TDD REQUIRED  
- 'data': Data/resource files (CSV, JSON data, images, etc.) - NO TDD REQUIRED
- 'test': Test files (test_*.py, *.test.js, etc.) - SPECIAL TDD RULES
- 'implementation': Implementation code files - TDD REQUIRED

Determine the category based on:
1. File extension and naming patterns
2. File location in project structure
3. Content analysis (if provided)
4. Common development practices

Be intelligent about edge cases:
- setup.py, __init__.py are usually config/infrastructure, not implementation
- Migration files, schema files are data/config
- Example/demo files might be docs
- Only actual business logic should be 'implementation'

Provide a clear reason for your categorization."""

        try:
            # NOTE(review): some Gemini models reject requests that combine a
            # tool (google_search) with a JSON response_mime_type — confirm
            # FILE_CATEGORIZATION_MODEL supports it; a rejection would land in
            # the except branch below and classify everything as
            # implementation.
            response = self.client.models.generate_content(
                model=FILE_CATEGORIZATION_MODEL,
                contents=prompt,
                config=types.GenerateContentConfig(
                    response_mime_type="application/json",
                    response_schema=FileCategorizationResponse,
                    tools=[types.Tool(google_search=types.GoogleSearch())],
                ),
            )

            parsed = getattr(response, "parsed", None)
            if parsed:
                return {
                    "category": parsed.category,
                    "requires_tdd": parsed.requires_tdd,
                    "reason": parsed.reason,
                }

            # Fallback parsing of the raw text. Callers read the result with
            # .get(), so anything other than a JSON object is unusable and is
            # routed to the conservative default below.
            payload = json.loads(response.text)
            if not isinstance(payload, dict):
                raise ValueError("categorization response is not a JSON object")
            return payload

        except Exception:
            # Fallback on error: assume TDD applies rather than skipping it.
            return {
                "category": "implementation",
                "requires_tdd": True,
                "reason": "Categorization failed, assuming implementation",
            }

    def validate(
        self,
        tool_name: str,
        tool_input: dict,
        context: str,
        tdd_context: Dict[str, Any],
    ) -> dict:
        """
        Main TDD validation entry point.

        The whole pipeline (input unpacking, file categorization and the
        operation-specific check) runs under one fail-safe: any exception
        results in an approval carrying the error text, so a validator
        failure never blocks the operation.

        Args:
            tool_name: The Claude tool being executed
            tool_input: The tool's input parameters
            context: Conversation context (not used by this method)
            tdd_context: TDD-specific context (test results, todos, modifications)

        Returns:
            TDDValidationResponse dict with TDD compliance status
        """

        # Skip TDD validation if no API key
        if not self.api_key:
            return {
                "approved": True,
                "reason": "TDD validation service unavailable",
                "tdd_phase": "unknown",
            }

        try:
            # Check file categorization for file-based operations
            if tool_name in ("Write", "Edit", "MultiEdit"):
                file_path = tool_input.get("file_path", "")
                # Only Write carries the full file content.
                content = tool_input.get("content", "") if tool_name == "Write" else ""

                # Get file category using LLM
                categorization = self.categorize_file(file_path, content)

                # Skip TDD validation for non-implementation files
                if not categorization.get("requires_tdd", True):
                    return {
                        "approved": True,
                        "reason": f"TDD validation not required for {categorization.get('category', 'unknown')} files: {categorization.get('reason', '')}",
                        "tdd_phase": "not_applicable",
                        "file_category": categorization.get("category"),
                    }

            # Route to operation-specific validation
            if tool_name == "Edit":
                return self.validate_edit_operation(tool_input, tdd_context)
            if tool_name == "Write":
                return self.validate_write_operation(tool_input, tdd_context)
            if tool_name == "MultiEdit":
                return self.validate_multi_edit_operation(tool_input, tdd_context)

            # Other operations (Bash, etc.) don't need TDD validation
            return {
                "approved": True,
                "reason": f"{tool_name} operation doesn't require TDD validation",
                "tdd_phase": "unknown",
            }

        except Exception as e:
            # Fail-safe: allow operation if TDD validation fails
            return {
                "approved": True,
                "reason": f"TDD validation service error: {str(e)}",
                "tdd_phase": "unknown",
            }

    def validate_edit_operation(
        self, tool_input: dict, tdd_context: Dict[str, Any]
    ) -> dict:
        """Check a single Edit against the TDD rules.

        Reads ``file_path``, ``old_string`` and ``new_string`` from
        ``tool_input`` (each defaulting to an empty string), builds the Edit
        prompt from them and returns the result of running that prompt
        through the validator, with the edited file as the affected file.
        """
        target = tool_input.get("file_path", "")
        before = tool_input.get("old_string", "")
        after = tool_input.get("new_string", "")

        return self.execute_tdd_validation(
            self.build_edit_validation_prompt(before, after, target, tdd_context),
            [target],
        )

    def validate_write_operation(
        self, tool_input: dict, tdd_context: Dict[str, Any]
    ) -> dict:
        """Check a Write (whole-file) operation against the TDD rules.

        Reads ``file_path`` and ``content`` from ``tool_input`` (each
        defaulting to an empty string), builds the Write prompt and returns
        the validator's result, with the written file as the affected file.
        """
        target = tool_input.get("file_path", "")
        body = tool_input.get("content", "")

        return self.execute_tdd_validation(
            self.build_write_validation_prompt(target, body, tdd_context),
            [target],
        )

    def validate_multi_edit_operation(
        self, tool_input: dict, tdd_context: Dict[str, Any]
    ) -> dict:
        """Check a MultiEdit operation against the TDD rules.

        Reads ``edits`` (default: empty list) and ``file_path`` (default:
        empty string) from ``tool_input``, builds the MultiEdit prompt and
        returns the validator's result, with that file as the affected file.
        """
        edit_list = tool_input.get("edits", [])
        target = tool_input.get("file_path", "")

        return self.execute_tdd_validation(
            self.build_multi_edit_validation_prompt(edit_list, target, tdd_context),
            [target],
        )

    def build_edit_validation_prompt(
        self,
        old_content: str,
        new_content: str,
        file_path: str,
        tdd_context: Dict[str, Any],
    ) -> str:
        """Build validation prompt for Edit operations

        The prompt is assembled from three generated sections (the shared TDD
        principles, the Edit-specific analysis of the old/new text, and the
        formatted TDD context) followed by fixed instructions describing the
        approve/block criteria and the expected response fields.

        Args:
            old_content: Text being replaced by the edit.
            new_content: Replacement text.
            file_path: Path of the file being edited.
            tdd_context: TDD-specific context passed to the context formatter.

        Returns:
            The complete prompt string to send to the model.
        """

        # Generated sections; the literal below only interpolates them.
        tdd_principles = TDDCorePrompt.get_tdd_principles()
        edit_analysis = EditAnalysisPrompt.get_analysis_prompt(
            old_content, new_content, file_path
        )
        context_info = TDDContextFormatter.format_tdd_context(tdd_context)

        return f"""You are a TDD compliance validator. Analyze this Edit operation for Test-Driven Development violations.

{tdd_principles}

{edit_analysis}

{context_info}

## VALIDATION REQUIREMENTS

Your task is to determine if this Edit operation violates TDD principles. Focus on:

1. **New Test Count**: How many completely new tests are being added?
2. **Implementation Scope**: Is the implementation minimal and test-driven?
3. **TDD Phase Compliance**: Does this follow Red-Green-Refactor properly?
4. **Over-implementation**: Are features being added beyond test requirements?

## DECISION FRAMEWORK

**APPROVE** if:
- Zero or one new test being added
- Implementation is minimal and addresses specific test failures
- Changes follow Red-Green-Refactor discipline
- No premature optimization or over-engineering

**BLOCK** if:
- Multiple new tests being added in single operation
- Over-implementation beyond current test requirements
- Implementation without corresponding test failures
- Features added that aren't tested

## RESPONSE FORMAT

Provide structured TDD validation response with:
- **approved**: boolean decision
- **violation_type**: specific TDD violation if any
- **test_count**: number of new tests detected
- **tdd_phase**: current phase (red/green/refactor)
- **reason**: clear explanation of decision
- **suggestions**: actionable TDD improvements
- **detailed_analysis**: comprehensive TDD assessment

Analyze thoroughly and enforce TDD discipline to maintain code quality and test coverage."""

    def build_write_validation_prompt(
        self, file_path: str, content: str, tdd_context: Dict[str, Any]
    ) -> str:
        """Build validation prompt for Write operations

        The prompt is assembled from three generated sections (the shared TDD
        principles, the Write-specific analysis of the file path and content,
        and the formatted TDD context) followed by fixed instructions
        describing the approve/block criteria for whole-file writes.

        Args:
            file_path: Path of the file being written.
            content: Full content being written to the file.
            tdd_context: TDD-specific context passed to the context formatter.

        Returns:
            The complete prompt string to send to the model.
        """

        # Generated sections; the literal below only interpolates them.
        tdd_principles = TDDCorePrompt.get_tdd_principles()
        write_analysis = WriteAnalysisPrompt.get_analysis_prompt(file_path, content)
        context_info = TDDContextFormatter.format_tdd_context(tdd_context)

        return f"""You are a TDD compliance validator. Analyze this Write operation for Test-Driven Development violations.

{tdd_principles}

{write_analysis}

{context_info}

## VALIDATION REQUIREMENTS

Your task is to determine if this Write operation violates TDD principles. Focus on:

1. **File Type**: Is this a test file or implementation file?
2. **Test Count**: If test file, count how many NEW test functions are being added (CRITICAL: only ONE allowed)
3. **Test Coverage**: For implementation files, are there corresponding tests?
4. **Implementation Justification**: Is implementation driven by test failures?
5. **Scope Assessment**: Is implementation minimal and focused?

## DECISION FRAMEWORK

**APPROVE** if:
- Writing test files with ONLY ONE new test at a time
- Writing minimal implementation to address specific test failures
- Creating infrastructure/setup code that supports testing
- Implementation scope matches test requirements

**BLOCK** if:
- Writing multiple tests in a single operation (even in test files)
- Creating implementation files without corresponding tests
- Over-implementing beyond current test requirements
- Writing speculative code not driven by test failures
- Implementing multiple features without adequate test coverage

## RESPONSE FORMAT

Provide structured TDD validation response focusing on file creation compliance with TDD workflow."""

    def build_multi_edit_validation_prompt(
        self, edits: List[Dict[str, Any]], file_path: str, tdd_context: Dict[str, Any]
    ) -> str:
        """Build validation prompt for MultiEdit operations

        The prompt is assembled from three generated sections (the shared TDD
        principles, the MultiEdit-specific analysis of the edit list, and the
        formatted TDD context) followed by fixed instructions that stress the
        cumulative new-test limit across all edits.

        Args:
            edits: The list of edits from the MultiEdit tool input.
            file_path: Path of the file being edited. Accepted for signature
                parity with the other builders; not referenced in this body.
            tdd_context: TDD-specific context passed to the context formatter.

        Returns:
            The complete prompt string to send to the model.
        """

        # Generated sections; the literal below only interpolates them.
        tdd_principles = TDDCorePrompt.get_tdd_principles()
        multi_edit_analysis = MultiEditAnalysisPrompt.get_analysis_prompt(edits)
        context_info = TDDContextFormatter.format_tdd_context(tdd_context)

        return f"""You are a TDD compliance validator. Analyze this MultiEdit operation for Test-Driven Development violations.

{tdd_principles}

{multi_edit_analysis}

{context_info}

## VALIDATION REQUIREMENTS

Your task is to determine if this MultiEdit operation violates TDD principles. Focus on:

1. **Cumulative New Tests**: Total new tests across ALL edits
2. **Sequential Implementation**: Is each edit minimal and justified?
3. **Scope Coherence**: Do all edits work toward single test goal?
4. **Progressive Compliance**: Does each edit maintain TDD discipline?

## CRITICAL RULE

**CUMULATIVE NEW TEST COUNT** across all edits must not exceed 1. This is the most important check for MultiEdit operations.

## DECISION FRAMEWORK

**APPROVE** if:
- Total new tests across all edits ≤ 1
- Each edit contributes to minimal implementation
- Sequential changes maintain test-driven approach
- No over-implementation or feature sprawl

**BLOCK** if:
- Total new tests across all edits > 1
- Edits implement features beyond test requirements
- Sequential changes show scope creep or over-engineering
- MultiEdit is being used to circumvent single-test rule

## RESPONSE FORMAT

Provide structured TDD validation response with special attention to cumulative effects of multiple edits."""

    def execute_tdd_validation(self, prompt: str, affected_files: List[str]) -> dict:
        """Execute TDD validation using Gemini with structured response

        Sends ``prompt`` to the configured model and converts the reply into a
        plain dict. Every failure (no client, API error, unparseable or
        non-object reply) is handled by the fail-safe: the operation is
        approved and the error text is reported in ``reason``.

        Args:
            prompt: Fully built validation prompt.
            affected_files: Files touched by the operation; used whenever the
                model's reply does not name any files itself.

        Returns:
            Dict with the TDDValidationResponse fields on success, or a
            fail-safe approval dict (``approved``, ``reason``, ``tdd_phase``,
            ``affected_files``) on error.
        """

        try:
            # Explicit guard instead of `assert`: asserts are stripped under
            # -O and carry no message for the fail-safe reason below.
            if self.client is None:
                raise RuntimeError("Gemini client is not initialized")

            thinking_config = types.ThinkingConfig(thinking_budget=TDD_THINKING_BUDGET)

            # NOTE(review): some Gemini models reject requests that combine a
            # tool (google_search) with a JSON response_mime_type — confirm
            # the configured model supports it; a rejection ends up in the
            # fail-safe branch and approves the operation.
            response = self.client.models.generate_content(
                model=self.model_name,
                contents=prompt,
                config=types.GenerateContentConfig(
                    response_mime_type="application/json",
                    response_schema=TDDValidationResponse,
                    thinking_config=thinking_config,
                    tools=[types.Tool(google_search=types.GoogleSearch())],
                ),
            )

            parsed = getattr(response, "parsed", None)
            if parsed:
                return {
                    "approved": parsed.approved,
                    "violation_type": parsed.violation_type,
                    "test_count": parsed.test_count,
                    "affected_files": parsed.affected_files or affected_files,
                    "tdd_phase": parsed.tdd_phase,
                    "reason": parsed.reason,
                    "suggestions": parsed.suggestions or [],
                    "detailed_analysis": parsed.detailed_analysis,
                }

            # Fallback JSON parsing of the raw text.
            payload = json.loads(response.text)
            if not isinstance(payload, dict):
                raise ValueError("TDD validation response is not a JSON object")
            # `or` (not a .get default) so that null / [] in the reply falls
            # back to the caller's list, matching the parsed path above.
            payload["affected_files"] = payload.get("affected_files") or affected_files
            return payload

        except Exception as e:
            # Fail-safe: allow operation if TDD validation fails
            return {
                "approved": True,
                "reason": f"TDD validation service error: {str(e)}",
                "tdd_phase": "unknown",
                "affected_files": affected_files,
            }

    def detect_test_files(self, file_path: str, content: str = "") -> bool:
        """Detect if a file is a test file based on path and content"""

        # Path-based detection
        test_path_patterns = [
            r"test.*\.py$",
            r".*_test\.py$",
            r".*\.test\.py$",
            r"test.*\.js$",
            r".*\.test\.js$",
            r".*\.spec\.js$",
            r"test.*\.go$",
            r".*_test\.go$",
            r"Test.*\.java$",
            r".*Test\.java$",
            r"/tests?/",
            r"\\tests?\\",
        ]

        for pattern in test_path_patterns:
            if re.search(pattern, file_path, re.IGNORECASE):
                return True

        # Content-based detection
        if content:
            test_content_patterns = [
                r"def test_",
                r"class Test",
                r"import unittest",
                r"test\(",
                r"describe\(",
                r"it\(",
                r"expect\(",
                r"func Test",
                r"@Test",
                r"@pytest",
            ]

            for pattern in test_content_patterns:
                if re.search(pattern, content):
                    return True

        return False

    def count_new_tests(self, old_content: str, new_content: str) -> int:
        """Count new test functions added (character-by-character comparison)"""

        # Extract test functions from both contents
        old_tests = self.extract_test_functions(old_content)
        new_tests = self.extract_test_functions(new_content)

        # Count tests that exist in new but not in old
        new_test_count = 0
        for test in new_tests:
            if test not in old_tests:
                new_test_count += 1

        return new_test_count

    def extract_test_functions(self, content: str) -> List[str]:
        """Extract test function names from code content"""

        test_patterns = [
            r"def (test_\w+)",  # Python test functions
            r"def (should_\w+)",  # Python BDD-style tests
            r'test\s*\(\s*[\'"]([^\'"]+)[\'"]',  # JavaScript test()
            r'it\s*\(\s*[\'"]([^\'"]+)[\'"]',  # JavaScript it()
            r"func (Test\w+)",  # Go test functions
            r"@Test\s+\w+\s+(\w+)",  # Java test methods
        ]

        test_functions = []
        for pattern in test_patterns:
            matches = re.findall(pattern, content, re.MULTILINE)
            test_functions.extend(matches)

        return test_functions
