from typing import Annotated, Any, Literal, TypeAlias

from pydantic import (
    AliasChoices,
    AliasPath,
    BaseModel,
    ConfigDict,
    Field,
    TypeAdapter,
    computed_field,
)

from .helper import unix_ms


class _Base(BaseModel):
    # Shared root for every model in this module. extra="allow" makes pydantic
    # keep payload keys that are not declared on a model instead of rejecting
    # or dropping them.
    model_config = ConfigDict(extra="allow")


class TimeInfo(_Base):
    """Represents start and end timestamps for an event."""

    # Both bounds are required integers; duration_ms treats their difference
    # as a number of milliseconds.
    start: int
    end: int

    @property
    def duration_ms(self) -> int:
        """Return the elapsed time between ``start`` and ``end`` in milliseconds.

        The value is simply ``end - start`` with no clamping, so it can be
        negative in rare cases of clock skew between components.
        """
        began, finished = self.start, self.end
        return finished - began


class CacheInfo(_Base):
    """Token cache usage information."""

    # Both counters are required integers.
    # NOTE(review): presumably tokens served from / written to the cache —
    # confirm against the CLI's payload documentation.
    read: int
    write: int


class TokenUsage(_Base):
    """Detailed token usage for a step."""

    # Required counters.
    input: int
    output: int
    # Optional in the payload; treated as 0 when absent.
    reasoning: int = 0
    # Cache breakdown, or None when the payload carries no cache section.
    cache: CacheInfo | None = None


class ToolStateInput(_Base):
    """Input parameters for a tool call. Fields are optional as they vary by tool."""

    # Only these two keys are modelled explicitly, each defaulting to None.
    # Because _Base is configured with extra="allow", any other tool-specific
    # keys in the payload are still retained on the instance.
    command: str | None = None
    description: str | None = None


class ToolState(_Base):
    """The complete state of a tool's execution."""

    # Required: a free-form status string and the tool's input parameters.
    status: str
    input: ToolStateInput
    # Optional fields default to None when the payload omits them.
    output: str | None = None
    title: str | None = None
    # default_factory gives every instance its own empty dict.
    metadata: dict[str, Any] = Field(default_factory=dict)
    time: TimeInfo | None = None


class BasePart(_Base):
    """Base model for all event parts, containing common identifiers."""

    # All three identifiers are required on input, but repr=False hides them
    # from repr() and exclude=True omits them from model dumps.
    id: str = Field(repr=False, exclude=True)
    sessionID: str = Field(repr=False, exclude=True)
    messageID: str = Field(repr=False, exclude=True)


class StepStartPart(BasePart):
    """Part model for a 'step-start' event."""

    # The part-level tag uses a hyphen ("step-start"); the enclosing
    # StepStartEvent uses an underscore ("step_start").
    type: Literal["step-start"]
    # Optional; hidden from repr() and omitted from model dumps.
    snapshot: str | None = Field(default=None, repr=False, exclude=True)


class TextPart(BasePart):
    """Part model for a 'text' event."""

    type: Literal["text"]
    # The text carried by this part (required).
    text: str
    # Optional timing info; hidden from repr() and omitted from model dumps.
    time: TimeInfo | None = Field(default=None, repr=False, exclude=True)


class ToolUsePart(BasePart):
    """Part model for a 'tool_use' event."""

    # The part's own tag is "tool", while the enclosing ToolUseEvent is
    # tagged "tool_use".
    type: Literal["tool"]
    # Required on input; hidden from repr() and omitted from model dumps.
    callID: str = Field(repr=False, exclude=True)
    # Name of the tool and the state of its execution.
    tool: str
    state: ToolState


class StepFinishPart(BasePart):
    """Part model for a 'step-finish' event."""

    # The part-level tag uses a hyphen ("step-finish"); the enclosing
    # StepFinishEvent uses an underscore ("step_finish").
    type: Literal["step-finish"]
    # Optional; hidden from repr() and omitted from model dumps.
    snapshot: str | None = Field(default=None, repr=False, exclude=True)
    # Both are required: validation fails if either is missing.
    # NOTE(review): the currency/unit of cost is not visible here — confirm.
    cost: float
    tokens: TokenUsage


class BaseEvent(_Base):
    """Base structure shared by all events from the CLI (excludes `type`)."""

    # Defaults to 0 when the payload carries no sequence number.
    seq: int = Field(default=0)
    # Filled by calling unix_ms() (imported from .helper) when absent.
    # NOTE(review): presumably the current Unix time in milliseconds —
    # confirm against the helper.
    timestamp: int = Field(default_factory=unix_ms)
    # Defaults to ""; hidden from repr() and omitted from model dumps.
    sessionID: str = Field(default="", repr=False, exclude=True)


class StepStartEvent(BaseEvent):
    """Event representing the beginning of a processing step."""

    # Discriminator value used by the event union to select this model.
    type: Literal["step_start"]
    # Required nested part; hidden from repr() and omitted from model dumps.
    part: StepStartPart = Field(repr=False, exclude=True)


class TextEvent(BaseEvent):
    """Event representing a chunk of text generated by InnerLoop."""

    # Discriminator value used by the event union to select this model.
    type: Literal["text"]
    # Required nested part; hidden from repr() and omitted from model dumps.
    part: TextPart = Field(repr=False, exclude=True)

    # Surfaces the part's text as a serialized top-level field. Documented
    # with a comment rather than a docstring because pydantic uses a computed
    # field's docstring as its schema description.
    @computed_field
    def text(self) -> str:
        text_part = self.part
        return text_part.text


class ToolUseEvent(BaseEvent):
    """Event representing a tool invocation by InnerLoop (e.g., bash)."""

    # Discriminator value used by the event union to select this model.
    type: Literal["tool_use"]
    # Required nested part; hidden from repr() and omitted from model dumps.
    part: ToolUsePart = Field(repr=False, exclude=True)

    # The computed fields below lift values out of the excluded part so they
    # appear at the top level of dumps. They are documented with comments
    # rather than docstrings because pydantic uses a computed field's
    # docstring as its schema description.

    @computed_field
    def output(self) -> str:
        # A missing (None) or empty tool output is reported as "".
        captured = self.part.state.output
        return captured if captured else ""

    @computed_field
    def status(self) -> str:
        # Status string of the tool's execution state.
        tool_state = self.part.state
        return tool_state.status

    @computed_field
    def tool(self) -> str:
        # Name of the invoked tool.
        tool_part = self.part
        return tool_part.tool


class StepFinishEvent(BaseEvent):
    """Event representing the end of a processing step, with cost and token info."""

    # Discriminator value used by the event union to select this model.
    type: Literal["step_finish"]
    # Required nested part; hidden from repr() and omitted from model dumps.
    part: StepFinishPart = Field(repr=False, exclude=True)

    # NOTE(review): cost and tokens are plain properties, not computed
    # fields (unlike TextEvent / ToolUseEvent), so they are readable as
    # attributes but do not appear in model dumps — confirm this is intended.

    @property
    def cost(self) -> float:
        """Return the cost recorded on the step-finish part."""
        finish_part = self.part
        return finish_part.cost

    @property
    def tokens(self) -> TokenUsage:
        """Return the token usage recorded on the step-finish part."""
        finish_part = self.part
        return finish_part.tokens


class ErrorEvent(BaseEvent):
    """Typed error event emitted by the CLI.

    Error events differ from other events by using a flat payload rather than a
    nested `part` model. Only common fields are represented here; providers may
    omit optional fields.
    """

    # Discriminator value used by the event union to select this model.
    type: Literal["error"]
    # Human-readable error text. Validation tries the aliases in order: the
    # top-level "message" key, then error.message, then error.data.message.
    # If none of them is present the default "Unknown error" is used.
    message: str = Field(
        default="Unknown error",
        validation_alias=AliasChoices(
            "message",
            AliasPath("error", "message"),
            AliasPath("error", "data", "message"),
        ),
    )
    # Optional error identifier: the top-level "code" key or, failing that,
    # error.name. None when neither is present.
    code: str | None = Field(
        default=None,
        validation_alias=AliasChoices("code", AliasPath("error", "name")),
    )
    # Defaults to "error" when the payload carries no severity.
    severity: str | None = Field(default="error")

    @property
    def error_message(self) -> str:
        """Return ``message``; a read-only alias for the top-level field."""
        return self.message


# Union of the concrete event models defined in this module.
OpenCodeEvent: TypeAlias = (
    StepStartEvent | TextEvent | ToolUseEvent | StepFinishEvent | ErrorEvent
)

# Discriminated form of the union: pydantic selects the member model from the
# value of the payload's "type" field.
EventUnion = Annotated[OpenCodeEvent, Field(discriminator="type")]

# Module-level adapter that validates raw payloads against EventUnion.
EventAdapter: TypeAdapter[OpenCodeEvent] = TypeAdapter(EventUnion)


# Public API of this module. parse_event is defined further down in this file
# and is listed here so that star-imports expose the parser together with the
# models it returns.
__all__ = [
    "TimeInfo",
    "CacheInfo",
    "TokenUsage",
    "ToolStateInput",
    "ToolState",
    "BasePart",
    "StepStartPart",
    "TextPart",
    "ToolUsePart",
    "StepFinishPart",
    "BaseEvent",
    "StepStartEvent",
    "TextEvent",
    "ToolUseEvent",
    "StepFinishEvent",
    "ErrorEvent",
    "OpenCodeEvent",
    "EventUnion",
    "EventAdapter",
    "parse_event",
]


def parse_event(raw: dict[str, Any]) -> OpenCodeEvent:
    """Validate a raw CLI event, normalizing one known vendor variant first.

    Some vendors emit text events in a flat form::

        {"type": "text", "text": "..."}

    (optionally with an object under ``"text"``) instead of the canonical
    part-based structure. Such payloads are rewritten to the canonical shape
    on a shallow copy before validation; every other payload is handed to
    ``EventAdapter`` unchanged.

    Args:
        raw: Decoded event payload. It is never mutated.

    Returns:
        The validated event model selected by the ``type`` discriminator.

    Raises:
        pydantic.ValidationError: If the (possibly normalized) payload does
            not match any known event schema.
    """
    try:
        event_type = raw.get("type")
    except AttributeError:
        # Not a mapping at all: skip normalization and let the adapter raise
        # its usual validation error for the malformed payload.
        event_type = None

    # Anything other than the flat text variant is validated as-is.
    if event_type != "text" or "part" in raw or "text" not in raw:
        return EventAdapter.validate_python(raw)

    text_field = raw["text"]
    # Values are heterogeneous (strings plus an optional "time" dict).
    part: dict[str, Any] = {"type": "text"}
    if isinstance(text_field, dict):
        # The vendor supplied an object under "text": lift the text and the
        # identifiers out of it. A missing key or an explicit None becomes ""
        # so that an absent identifier cannot fail validation of the event;
        # the identifiers are excluded from public dumps anyway.
        for key in ("text", "id", "sessionID", "messageID"):
            value = text_field.get(key)
            part[key] = "" if value is None else value
        time_val = text_field.get("time")
        # Only forward timing info that has the shape TimeInfo can validate.
        if isinstance(time_val, dict):
            part["time"] = time_val
    else:
        # Plain value under "text": falsy values (None, "") become "".
        part["text"] = text_field or ""
        part["id"] = part["sessionID"] = part["messageID"] = ""

    # Work on a shallow copy so the caller's payload is left untouched.
    normalized = dict(raw)
    normalized.pop("text", None)
    normalized["part"] = part
    return EventAdapter.validate_python(normalized)
