"""AST node definitions for NLQL."""

from dataclasses import dataclass, field
from typing import Any


@dataclass()
class ASTNode:
    """Common ancestor shared by every NLQL syntax-tree node.

    Declares no fields of its own; each subclass declares its data as
    dataclass fields.
    """


@dataclass()
class Literal(ASTNode):
    """A constant value written directly in a query.

    Fields:
        value: the literal's value, of any Python type.
        type: tag naming the literal's kind ("string", "number" or
            "boolean"); defaults to "string".
    """

    value: Any
    # Kind tag: "string", "number" or "boolean".
    type: str = field(default="string")


@dataclass()
class Identifier(ASTNode):
    """A bare name in a query, such as a variable or field name."""

    # The identifier's name.
    name: str


@dataclass()
class FunctionCall(ASTNode):
    """Invocation of a named function with a list of argument nodes.

    Fields:
        name: the called function's name.
        args: argument expressions; a fresh empty list when omitted.
    """

    name: str
    # default_factory gives every instance its own list.
    args: list[ASTNode] = field(default_factory=list)


@dataclass()
class OperatorCall(ASTNode):
    """Application of a named query operator (e.g. MATCH, SIMILAR_TO).

    Fields:
        operator: the operator's name.
        args: operand expressions; a fresh empty list when omitted.
    """

    operator: str
    # default_factory gives every instance its own list.
    args: list[ASTNode] = field(default_factory=list)


@dataclass()
class BinaryOp(ASTNode):
    """Two-operand operation such as arithmetic or comparison.

    Fields:
        op: operator symbol, one of +, -, *, /, <, >, <=, >=, ==, !=.
        left: left-hand operand.
        right: right-hand operand.
    """

    op: str
    left: ASTNode
    right: ASTNode


@dataclass()
class UnaryOp(ASTNode):
    """Single-operand operation such as NOT or negation.

    Fields:
        op: operator, either NOT or -.
        operand: the expression the operator applies to.
    """

    op: str
    operand: ASTNode


@dataclass()
class ComparisonExpr(ASTNode):
    """Comparison between two expressions.

    Fields:
        op: comparison operator, one of <, >, <=, >=, ==, !=.
        left: left-hand side.
        right: right-hand side.
    """

    op: str
    left: ASTNode
    right: ASTNode


@dataclass()
class LogicalExpr(ASTNode):
    """Boolean combination (AND / OR) over a list of operand expressions.

    Fields:
        op: connective, either AND or OR.
        operands: the combined expressions; a fresh empty list when omitted.
    """

    op: str
    # default_factory gives every instance its own list.
    operands: list[ASTNode] = field(default_factory=list)


@dataclass()
class WhereClause(ASTNode):
    """The WHERE part of a query, wrapping a single condition node."""

    # Root expression of the filter condition.
    condition: ASTNode


@dataclass()
class OrderByClause(ASTNode):
    """One sort key from an ORDER BY clause.

    Fields:
        field: the expression to sort on.
        direction: "ASC" or "DESC"; defaults to "ASC".
    """

    field: ASTNode
    # Sort direction: "ASC" or "DESC".
    direction: str = "ASC"


@dataclass()
class SelectStatement(ASTNode):
    """Root node of a full SELECT query.

    Fields:
        select_unit: granularity being selected: DOCUMENT, CHUNK,
            SENTENCE or SPAN. Defaults to "CHUNK".
        span_config: settings for the SPAN(unit, window=N) form, or None.
        where: the WHERE clause, or None when absent.
        order_by: ORDER BY sort keys; a fresh empty list when omitted.
        limit: the LIMIT value, or None when absent.
    """

    select_unit: str = field(default="CHUNK")
    span_config: dict[str, Any] | None = field(default=None)
    where: WhereClause | None = field(default=None)
    # default_factory gives every instance its own list.
    order_by: list[OrderByClause] = field(default_factory=list)
    limit: int | None = field(default=None)

