"""AST visitor pattern implementation."""

from typing import Any

from nlql.ast.nodes import (
    ASTNode,
    BinaryOp,
    ComparisonExpr,
    FunctionCall,
    Identifier,
    Literal,
    LogicalExpr,
    OperatorCall,
    OrderByClause,
    SelectStatement,
    UnaryOp,
    WhereClause,
)


class ASTVisitor:
    """Base visitor for traversing AST nodes.

    ``visit`` dispatches on the node's class name: a node of class ``Foo`` is
    handled by ``visit_Foo`` if the visitor defines it, otherwise by
    ``generic_visit``. The default ``visit_*`` implementations recurse into the
    node's children, discard the children's results, and return the node
    itself unchanged.

    Subclasses can override visit_* methods to implement custom logic.
    """

    def visit(self, node: ASTNode) -> Any:
        """Visit a node and dispatch to the appropriate visit_* method.

        Args:
            node: AST node to visit

        Returns:
            Result of the visit method
        """
        method_name = f"visit_{node.__class__.__name__}"
        visitor = getattr(self, method_name, self.generic_visit)
        return visitor(node)

    def generic_visit(self, node: ASTNode) -> Any:
        """Default visitor for nodes without a specific visit_* method.

        Does not recurse into children; returns the node unchanged.
        """
        return node

    def _visit_all(self, nodes: Any) -> None:
        """Visit every node in ``nodes``; a ``None`` collection is treated as empty."""
        # `or ()` keeps traversal from crashing when an optional child list
        # was left as None instead of an empty sequence.
        for child in nodes or ():
            self.visit(child)

    def visit_SelectStatement(self, node: SelectStatement) -> Any:
        """Visit a SELECT statement: its WHERE clause first, then each ORDER BY clause."""
        # Compare against None explicitly: "is there a WHERE clause" must not
        # depend on whether the clause node happens to define __bool__/__len__.
        if node.where is not None:
            self.visit(node.where)
        self._visit_all(node.order_by)
        return node

    def visit_WhereClause(self, node: WhereClause) -> Any:
        """Visit a WHERE clause by visiting its condition."""
        self.visit(node.condition)
        return node

    def visit_OrderByClause(self, node: OrderByClause) -> Any:
        """Visit an ORDER BY clause by visiting its field."""
        self.visit(node.field)
        return node

    def visit_LogicalExpr(self, node: LogicalExpr) -> Any:
        """Visit a logical expression by visiting each operand in order."""
        self._visit_all(node.operands)
        return node

    def visit_ComparisonExpr(self, node: ComparisonExpr) -> Any:
        """Visit a comparison expression: left side, then right side."""
        self.visit(node.left)
        self.visit(node.right)
        return node

    def visit_BinaryOp(self, node: BinaryOp) -> Any:
        """Visit a binary operation: left operand, then right operand."""
        self.visit(node.left)
        self.visit(node.right)
        return node

    def visit_UnaryOp(self, node: UnaryOp) -> Any:
        """Visit a unary operation by visiting its single operand."""
        self.visit(node.operand)
        return node

    def visit_FunctionCall(self, node: FunctionCall) -> Any:
        """Visit a function call by visiting each argument in order."""
        self._visit_all(node.args)
        return node

    def visit_OperatorCall(self, node: OperatorCall) -> Any:
        """Visit an operator call by visiting each argument in order."""
        self._visit_all(node.args)
        return node

    def visit_Literal(self, node: Literal) -> Any:
        """Visit a literal (leaf node; nothing to recurse into)."""
        return node

    def visit_Identifier(self, node: Identifier) -> Any:
        """Visit an identifier (leaf node; nothing to recurse into)."""
        return node

