"""WHERE clause evaluator for NLQL.

This module evaluates WHERE conditions against text units.
"""

import operator
from typing import Any

from nlql.ast.nodes import (
    ASTNode,
    ComparisonExpr,
    FunctionCall,
    Identifier,
    Literal,
    LogicalExpr,
    OperatorCall,
    UnaryOp,
)
from nlql.errors import NLQLExecutionError
from nlql.registry.functions import get_function
from nlql.registry.operators import get_operator
from nlql.text.units import TextUnit


class WhereEvaluator:
    """Evaluates WHERE clause conditions against text units.

    Operators and functions are resolved first from the optional execution
    context's instance-level registries, then from the global registries.
    """

    # Binary comparison operators supported by ComparisonExpr nodes.
    _COMPARATORS = {
        "==": operator.eq,
        "!=": operator.ne,
        "<": operator.lt,
        ">": operator.gt,
        "<=": operator.le,
        ">=": operator.ge,
    }

    # Ordering comparisons. When either side is None (e.g. a metadata field
    # this unit does not have), these evaluate to False instead of raising
    # TypeError, so one incomplete unit cannot abort the whole query.
    _ORDERING_OPS = frozenset({"<", ">", "<=", ">="})

    def __init__(self, context: Any = None) -> None:
        """Initialize evaluator with optional execution context.

        Args:
            context: Optional execution context with instance-level registries
                (read via its ``function_registry`` / ``operator_registry``
                attributes when present)
        """
        self.context = context

    def _lookup_instance_registry(self, registry_attr: str, name: str) -> Any:
        """Look up ``name`` in the context registry stored at ``registry_attr``.

        Args:
            registry_attr: Attribute name of the registry on the context
            name: Entry name to look up

        Returns:
            The registered entry, or None when there is no context, the context
            has no such registry (or it is empty), or ``name`` is not in it
        """
        if not self.context:
            return None
        # getattr with a default tolerates contexts that do not define every
        # registry attribute instead of failing with AttributeError.
        registry = getattr(self.context, registry_attr, None)
        if not registry:
            return None
        return registry.get(name)

    def _get_function(self, name: str) -> Any:
        """Get function from instance registry or global registry.

        Instance-level registrations take precedence over global registrations.

        Args:
            name: Function name

        Returns:
            Function implementation or None if not found
        """
        func = self._lookup_instance_registry("function_registry", name)
        return func if func is not None else get_function(name)

    def _get_operator(self, name: str) -> Any:
        """Get operator from instance registry or global registry.

        Instance-level registrations take precedence over global registrations.

        Args:
            name: Operator name

        Returns:
            Operator implementation or None if not found
        """
        op = self._lookup_instance_registry("operator_registry", name)
        return op if op is not None else get_operator(name)

    def evaluate(self, condition: ASTNode, unit: TextUnit) -> bool:
        """Evaluate a WHERE condition against a text unit.

        Args:
            condition: AST node representing the condition
            unit: Text unit to evaluate against

        Returns:
            True if the condition is satisfied, False otherwise. Non-boolean
            results (e.g. a numeric score) are converted with ``bool()``.

        Raises:
            NLQLExecutionError: If evaluation fails
        """
        try:
            # bool() returns real booleans unchanged and applies Python
            # truthiness to anything else.
            return bool(self._eval_node(condition, unit))
        except Exception as e:
            raise NLQLExecutionError(
                f"Failed to evaluate WHERE condition: {e}"
            ) from e

    def _eval_node(self, node: ASTNode, unit: TextUnit) -> Any:
        """Recursively evaluate an AST node.

        Identifiers resolve to ``unit.content`` for the name ``content`` and to
        ``unit.metadata.get(name)`` otherwise (None when the key is absent).

        Args:
            node: AST node to evaluate
            unit: Text unit context

        Returns:
            Evaluation result

        Raises:
            NLQLExecutionError: If the node type is not supported
        """
        if isinstance(node, Literal):
            return node.value

        elif isinstance(node, Identifier):
            # 'content' is the main field; anything else is read from metadata.
            if node.name == "content":
                return unit.content
            return unit.metadata.get(node.name)

        elif isinstance(node, OperatorCall):
            return self._eval_operator(node, unit)

        elif isinstance(node, FunctionCall):
            return self._eval_function(node, unit)

        elif isinstance(node, ComparisonExpr):
            return self._eval_comparison(node, unit)

        elif isinstance(node, LogicalExpr):
            return self._eval_logical(node, unit)

        elif isinstance(node, UnaryOp):
            return self._eval_unary(node, unit)

        else:
            raise NLQLExecutionError(f"Unknown AST node type: {type(node)}")

    def _apply_operator(
        self, name: str, operator_func: Any, args: list, unit: TextUnit
    ) -> Any:
        """Invoke an operator on already-evaluated arguments.

        Shared by OperatorCall nodes and uppercase FunctionCall nodes that
        resolve to operators, so both forms get the same special handling.

        Args:
            name: Operator name as written in the query
            operator_func: Resolved operator implementation
            args: Evaluated argument values
            unit: Text unit context

        Returns:
            Operator result

        Raises:
            NLQLExecutionError: If META is not given exactly one argument
        """
        if name == "META":
            # META("field") is given the unit's metadata dict plus the field name.
            if len(args) != 1:
                raise NLQLExecutionError("META operator requires exactly 1 argument")
            return operator_func(unit.metadata, args[0])

        if name in ("CONTAINS", "MATCH"):
            # Shorthand CONTAINS("text") implicitly targets unit.content;
            # the explicit form CONTAINS(content, "text") passes args through.
            if len(args) == 1:
                return operator_func(unit.content, args[0])
            return operator_func(*args)

        if name == "SIMILAR_TO":
            # The similarity score is precomputed upstream and stored in
            # metadata["similarity"]; default to 0.0 when it is absent.
            return unit.metadata.get("similarity", 0.0)

        # Generic operator call
        return operator_func(*args)

    def _eval_operator(self, node: OperatorCall, unit: TextUnit) -> Any:
        """Evaluate an operator call.

        Args:
            node: OperatorCall node
            unit: Text unit context

        Returns:
            Operator result

        Raises:
            NLQLExecutionError: If the operator is not registered
        """
        operator_func = self._get_operator(node.operator)
        if operator_func is None:
            raise NLQLExecutionError(f"Unknown operator: {node.operator}")

        args = [self._eval_node(arg, unit) for arg in node.args]
        return self._apply_operator(node.operator, operator_func, args, unit)

    def _eval_function(self, node: FunctionCall, unit: TextUnit) -> Any:
        """Evaluate a function call.

        Uppercase names that resolve to a registered operator are treated as
        operator calls (custom operators are parsed as function calls).

        Args:
            node: FunctionCall node
            unit: Text unit context

        Returns:
            Function (or operator) result

        Raises:
            NLQLExecutionError: If neither an operator nor a function matches
        """
        if node.name.isupper():
            operator_func = self._get_operator(node.name)
            if operator_func is not None:
                args = [self._eval_node(arg, unit) for arg in node.args]
                return self._apply_operator(node.name, operator_func, args, unit)

        func = self._get_function(node.name)
        if func is None:
            raise NLQLExecutionError(f"Unknown function: {node.name}")

        args = [self._eval_node(arg, unit) for arg in node.args]
        return func(*args)

    def _eval_comparison(self, node: ComparisonExpr, unit: TextUnit) -> bool:
        """Evaluate a comparison expression.

        Ordering comparisons (<, >, <=, >=) where either operand is None yield
        False rather than raising, similar to SQL NULL semantics in WHERE.

        Args:
            node: ComparisonExpr node
            unit: Text unit context

        Returns:
            Comparison result

        Raises:
            NLQLExecutionError: If the comparison operator is unknown
        """
        left = self._eval_node(node.left, unit)
        right = self._eval_node(node.right, unit)

        compare = self._COMPARATORS.get(node.op)
        if compare is None:
            raise NLQLExecutionError(f"Unknown comparison operator: {node.op}")

        if node.op in self._ORDERING_OPS and (left is None or right is None):
            return False

        return compare(left, right)

    def _eval_logical(self, node: LogicalExpr, unit: TextUnit) -> bool:
        """Evaluate a logical expression (AND, OR) with short-circuiting.

        Args:
            node: LogicalExpr node
            unit: Text unit context

        Returns:
            Logical result

        Raises:
            NLQLExecutionError: If the logical operator is unknown
        """
        # Generators keep evaluation lazy, so all()/any() stop at the first
        # deciding operand just like an explicit loop would.
        operand_values = (self._eval_node(operand, unit) for operand in node.operands)

        if node.op == "AND":
            return all(operand_values)
        elif node.op == "OR":
            return any(operand_values)
        else:
            raise NLQLExecutionError(f"Unknown logical operator: {node.op}")

    def _eval_unary(self, node: UnaryOp, unit: TextUnit) -> Any:
        """Evaluate a unary operation (NOT, unary minus).

        Args:
            node: UnaryOp node
            unit: Text unit context

        Returns:
            Unary operation result

        Raises:
            NLQLExecutionError: If the unary operator is unknown
        """
        operand = self._eval_node(node.operand, unit)

        if node.op == "NOT":
            return not operand
        elif node.op == "-":
            return -operand
        else:
            raise NLQLExecutionError(f"Unknown unary operator: {node.op}")

