"""
This was initially generated by datamodel-codegen from the labware schema in
shared-data. It's been modified by hand to be more friendly.
"""

from __future__ import annotations

from enum import Enum
from functools import cached_property
from itertools import accumulate
from math import sqrt, asin
from typing import Final, get_args

from numpy import pi, trapz
from pydantic import (
    ConfigDict,
    BaseModel,
    Discriminator,
    Field,
    StrictInt,
    StrictFloat,
    TypeAdapter,
)
from typing_extensions import Annotated, Literal

from .types import LocatingFeatures
from .constants import (
    Conical,
    Cuboidal,
    RoundedCuboid,
    SquaredCone,
    Spherical,
    WellShape,
    Circular,
    Rectangular,
)

# Pattern that ``namespace`` and ``parameters.loadName`` must match:
# lowercase letters, digits, "." and "_" only.
SAFE_STRING_REGEX: Final = "^[a-z0-9._]+$"
# Volume tolerance at which ConicalFrustum.height_from_volume_search stops
# bisecting. NOTE(review): same units as the frustum volume formula
# (presumably uL) — confirm.
RECURSIVE_SEARCH_VOLUME_TOLERANCE: Final = 0.001

# TODO(jh, 2025-05-09): We need to handle both positive numbers (the schema 2 convention)
#  and negative numbers (the schema 3 convention) no matter what, so we could simply stop
#  enforcing sign at the schema level.

# Sign-constrained numbers validated in pydantic strict mode. Below, they are
# combined into int-or-float unions so integer inputs keep their int type.
_StrictNonNegativeInt = Annotated[int, Field(strict=True, ge=0)]
_StrictNonNegativeFloat = Annotated[float, Field(strict=True, ge=0.0)]
_StrictNonPositiveInt = Annotated[int, Field(strict=True, le=0)]
_StrictNonPositiveFloat = Annotated[float, Field(strict=True, le=0.0)]


# The number types below are unions of strict int and strict float rather
# than plain ``float``, so ``0`` is not turned into ``0.0``.
_Number = StrictInt | StrictFloat
"""JSON number type, written to preserve lack of decimal point.

For labware definition hashing, which is an older part of the codebase,
this ensures that Pydantic won't change `"someFloatField": 0` to
`"someFloatField": 0.0`, which would hash differently.
"""

_NonNegativeNumber = _StrictNonNegativeInt | _StrictNonNegativeFloat
"""Non-negative JSON number type, written to preserve lack of decimal point."""

_NonPositiveNumber = _StrictNonPositiveInt | _StrictNonPositiveFloat
"""Non-positive JSON number type, written to preserve lack of decimal point."""


# A 2D vector whose components are JSON numbers (see ``_Number``).
class Vector2D(BaseModel):
    x: _Number
    y: _Number


# A 3D vector whose components are JSON numbers (see ``_Number``).
# NOTE(review): units are not stated here (presumably mm) — confirm.
class Vector3D(BaseModel):
    x: _Number
    y: _Number
    z: _Number


# A 2D axis-aligned box described by two opposite corners.
class AxisAlignedBoundingBox2D(BaseModel):
    backLeft: Vector2D
    frontRight: Vector2D


# A 3D axis-aligned box described by two opposite corners.
class AxisAlignedBoundingBox3D(BaseModel):
    backLeftBottom: Vector3D
    frontRightTop: Vector3D


# Offsets applied when the gripper picks this labware up and when it drops it.
class GripperOffsets(BaseModel):
    pickUpOffset: Vector3D
    dropOffset: Vector3D


# Manufacturer information, used both for a whole labware definition and
# (optionally) for an individual well group.
class BrandData(BaseModel):
    brand: str
    # Optional brand/catalog identifiers.
    brandId: list[str] | None = None
    # Optional links (presumably product-page URLs; the format is not validated).
    links: list[str] | None = None


# Display category of a labware (``Metadata``) or a well group
# (``GroupMetadata``). Members are ``str`` subclasses whose values are the
# JSON strings.
class DisplayCategory(str, Enum):
    tipRack = "tipRack"
    tubeRack = "tubeRack"
    reservoir = "reservoir"
    trash = "trash"
    wellPlate = "wellPlate"
    aluminumBlock = "aluminumBlock"
    adapter = "adapter"
    other = "other"
    lid = "lid"
    system = "system"


# Roles a labware definition may list in ``allowedRoles``. Members are ``str``
# subclasses whose values are the JSON strings.
class LabwareRole(str, Enum):
    labware = "labware"
    fixture = "fixture"
    adapter = "adapter"
    maintenance = "maintenance"
    lid = "lid"
    system = "system"


# Known labware quirk names.
# NOTE(review): ``Parameters2.quirks`` is typed ``list[str]``, so this enum does
# not validate quirk entries anywhere in this file — confirm that is intended.
class Quirks(Enum):
    disableGeometryBasedGripCheck = "disableGeometryBasedGripCheck"


# Human-facing display metadata for a labware definition.
class Metadata(BaseModel):
    displayName: str
    displayCategory: DisplayCategory
    # Unit in which volumes are shown to users; one of three literal strings.
    displayVolumeUnits: Literal["µL", "mL", "L"]
    tags: list[str] | None = None


# Labware parameters for schema 2. Schema 3's ``Parameters3`` subclasses this
# without adding fields.
class Parameters2(BaseModel):
    format: Literal["96Standard", "384Standard", "trough", "irregular", "trash"]
    # NOTE(review): plain strings; entries are not checked against ``Quirks``.
    quirks: list[str] | None = None
    isTiprack: bool
    # Tip geometry; presumably only meaningful when ``isTiprack`` is true.
    tipLength: _NonNegativeNumber | None = None
    tipOverlap: _NonNegativeNumber | None = None
    # Must match SAFE_STRING_REGEX (lowercase letters, digits, "." and "_").
    loadName: Annotated[str, Field(pattern=SAFE_STRING_REGEX)]
    isMagneticModuleCompatible: bool
    isDeckSlotCompatible: bool | None = None
    magneticModuleEngageHeight: _NonNegativeNumber | None = None


# Labware parameters for schema 3. It has the same fields as Parameters2 for
# now; the separate class presumably lets the schemas diverge later.
# (Listing BaseModel as an extra base is redundant: Parameters2 already is one.)
class Parameters3(Parameters2):
    ...


# Outer dimensions of a schema-2 labware definition.
# NOTE: the y, z, x field order is unusual, but it is kept because field order
# determines the order of the serialized output.
class Dimensions(BaseModel):
    yDimension: _NonNegativeNumber
    zDimension: _NonNegativeNumber
    xDimension: _NonNegativeNumber


# Fields shared by every well definition in both schema versions. ``y`` is
# declared in the schema-specific subclasses because its allowed sign differs.
class _WellCommonMixin(BaseModel):
    # Unknown extra keys in a well definition are kept rather than rejected.
    model_config = ConfigDict(extra="allow")

    depth: _NonNegativeNumber
    totalLiquidVolume: _NonNegativeNumber
    x: _NonNegativeNumber
    z: _NonNegativeNumber
    # Optional key into the labware's ``innerLabwareGeometry`` mapping
    # (presumably — confirm against callers).
    geometryDefinitionId: str | None = None


# Schema-2 wells: ``y`` must be non-negative.
class _WellCommon2(_WellCommonMixin):
    y: _NonNegativeNumber


# Schema-3 wells: ``y`` must be non-positive.
class _WellCommon3(_WellCommonMixin):
    y: _NonPositiveNumber


# Shape-specific well fields. The literal ``shape`` value is the discriminator
# used by the WellDefinition2/WellDefinition3 unions.
class _CircularWellMixin(BaseModel):
    shape: Literal["circular"]
    diameter: _NonNegativeNumber


class _RectangularWellMixin(BaseModel):
    shape: Literal["rectangular"]
    xDimension: _NonNegativeNumber
    yDimension: _NonNegativeNumber


# Concrete schema-2 well definitions: common schema-2 fields plus one shape.
class CircularWellDefinition2(_WellCommon2, _CircularWellMixin):
    pass


class RectangularWellDefinition2(_WellCommon2, _RectangularWellMixin):
    pass


# A schema-2 well, selected by its ``shape`` field.
WellDefinition2 = Annotated[
    CircularWellDefinition2 | RectangularWellDefinition2, Discriminator("shape")
]


# Concrete schema-3 well definitions: common schema-3 fields plus one shape.
class CircularWellDefinition3(_WellCommon3, _CircularWellMixin):
    pass


class RectangularWellDefinition3(_WellCommon3, _RectangularWellMixin):
    pass


# A schema-3 well, selected by its ``shape`` field.
WellDefinition3 = Annotated[
    CircularWellDefinition3 | RectangularWellDefinition3, Discriminator("shape")
]


# A spherical section of an inner well geometry with the given
# ``radiusOfCurvature``, spanning ``bottomHeight`` to ``topHeight``.
# ``count`` is ``xCount * yCount`` (presumably the number of identical copies
# of this section — confirm).
class SphericalSegment(BaseModel):
    shape: Spherical
    radiusOfCurvature: _NonNegativeNumber
    topHeight: _NonNegativeNumber
    bottomHeight: _NonNegativeNumber
    xCount: _StrictNonNegativeInt = 1
    yCount: _StrictNonNegativeInt = 1

    @cached_property
    def count(self) -> int:
        """Return ``xCount * yCount``, computed once and cached per instance."""
        return self.xCount * self.yCount

    # Tell pydantic to leave cached_property attributes alone (not fields).
    model_config = ConfigDict(ignored_types=(cached_property,))


class ConicalFrustum(BaseModel):
    # A circular-frustum section of an inner well geometry. The section spans
    # ``bottomHeight`` to ``topHeight`` and its diameter changes linearly from
    # ``bottomDiameter`` to ``topDiameter`` (see volume_from_height_circular).
    # ``count`` is ``xCount * yCount``.
    # Comments rather than a class docstring: pydantic copies class docstrings
    # into the generated JSON schema as ``description``.
    shape: Conical
    bottomDiameter: _NonNegativeNumber
    topDiameter: _NonNegativeNumber
    topHeight: _NonNegativeNumber
    bottomHeight: _NonNegativeNumber
    xCount: _StrictNonNegativeInt = 1
    yCount: _StrictNonNegativeInt = 1

    def height_from_volume_search(self, target_volume: float) -> float:
        """Return the height within this section that holds ``target_volume``.

        Bisects ``[0, topHeight - bottomHeight]`` until the volume at the
        candidate height is within ``RECURSIVE_SEARCH_VOLUME_TOLERANCE`` of the
        target. Bisection is valid because a frustum's filled volume never
        decreases as the fill height grows.

        Args:
            target_volume: The volume to locate, in the units returned by
                ``volume_from_height_circular``.

        Returns:
            Height measured from the bottom of this section.

        Raises:
            ValueError: If ``target_volume`` lies outside this section's volume
                range by more than the search tolerance. Such a target can
                never be reached, so searching for it would never terminate.
        """
        total_height = self.topHeight - self.bottomHeight
        top_radius = self.topDiameter / 2
        bottom_radius = self.bottomDiameter / 2

        def _volume_at(height: float) -> float:
            return self.volume_from_height_circular(
                top_radius=top_radius,
                bottom_radius=bottom_radius,
                target_height=height,
                total_height=total_height,
            )

        min_height, max_height = 0.0, total_height
        volume_at_max_height = _volume_at(max_height)
        if target_volume == volume_at_max_height:
            return max_height
        volume_at_min_height = _volume_at(min_height)
        if target_volume == volume_at_min_height:
            return min_height

        # A target outside [min, max] can never satisfy the loop below. A miss
        # within tolerance (e.g. float rounding in a volume computed elsewhere)
        # snaps to the nearest bound instead of failing.
        if target_volume > volume_at_max_height:
            if target_volume - volume_at_max_height <= RECURSIVE_SEARCH_VOLUME_TOLERANCE:
                return max_height
            raise ValueError(
                f"Target volume {target_volume} exceeds the maximum volume"
                f" {volume_at_max_height} of this conical frustum."
            )
        if target_volume < volume_at_min_height:
            if volume_at_min_height - target_volume <= RECURSIVE_SEARCH_VOLUME_TOLERANCE:
                return min_height
            raise ValueError(
                f"Target volume {target_volume} is below the minimum volume"
                f" {volume_at_min_height} of this conical frustum."
            )

        low, high = min_height, max_height
        y = total_height / 2
        volume_at_y = _volume_at(y)
        while abs(volume_at_y - target_volume) > RECURSIVE_SEARCH_VOLUME_TOLERANCE:
            if volume_at_y < target_volume:
                low = y  # undershot: the answer lies above y
            else:
                high = y  # overshot: the answer lies below y
            next_y = (low + high) / 2
            if next_y == y:
                # The interval cannot shrink any further at float precision,
                # so y is as close as the search can get.
                break
            y = next_y
            volume_at_y = _volume_at(y)
        return y

    def volume_from_height_circular(
        self,
        top_radius: float,
        bottom_radius: float,
        target_height: float,
        total_height: float,
    ) -> float:
        """Return the volume of a circular frustum filled to ``target_height``.

        The radius grows linearly from ``bottom_radius`` at height 0 to
        ``top_radius`` at ``total_height``. The filled part is itself a
        frustum, so its volume is ``pi * h / 3 * (R0**2 + R0 * Rh + Rh**2)``,
        where ``h = target_height``, ``R0 = bottom_radius`` and ``Rh`` is the
        radius at ``target_height``.
        """
        if target_height == 0:
            # Nothing is filled. Returning early also avoids a 0 / 0 division
            # for a zero-height section.
            return 0.0
        r_y = (target_height / total_height) * (
            top_radius - bottom_radius
        ) + bottom_radius
        return (pi * target_height / 3) * (
            bottom_radius**2 + bottom_radius * r_y + r_y**2
        )

    @cached_property
    def count(self) -> int:
        """Return ``xCount * yCount``, computed once and cached per instance."""
        return self.xCount * self.yCount

    # Tell pydantic to leave cached_property attributes alone (not fields).
    model_config = ConfigDict(ignored_types=(cached_property,))


# A rectangular-frustum section of an inner well geometry. Its cross-section
# goes from ``bottomXDimension`` x ``bottomYDimension`` to ``topXDimension`` x
# ``topYDimension`` between ``bottomHeight`` and ``topHeight`` (presumably
# varying linearly — no volume math lives on this class).
# ``count`` is ``xCount * yCount``.
class CuboidalFrustum(BaseModel):
    shape: Cuboidal
    bottomXDimension: _NonNegativeNumber
    bottomYDimension: _NonNegativeNumber
    topXDimension: _NonNegativeNumber
    topYDimension: _NonNegativeNumber
    topHeight: _NonNegativeNumber
    bottomHeight: _NonNegativeNumber
    xCount: _StrictNonNegativeInt = 1
    yCount: _StrictNonNegativeInt = 1

    @cached_property
    def count(self) -> int:
        """Return ``xCount * yCount``, computed once and cached per instance."""
        return self.xCount * self.yCount

    # Tell pydantic to leave cached_property attributes alone (not fields).
    model_config = ConfigDict(ignored_types=(cached_property,))


# A squared cone is the intersection of a rectangular prism and a cone that
# share a central axis; it is a transitional shape between a cone and a
# pyramid.
# The string below is an OpenSCAD reference model of that shape, kept for
# documentation only; as a bare string expression it has no runtime effect.
# NOTE(review): in it, ``r2 = circle_radius/2`` halves the radius a second
# time — confirm that is intended.
"""
module RectangularPrismToCone(bottom_shape, diameter, x, y, z) {
    circle_radius = diameter/2;
    r1 = sqrt(x*x + y*y)/2;
    r2 = circle_radius/2;
    top_r = bottom_shape == "square" ? r1 : r2;
    bottom_r = bottom_shape == "square" ? r2 : r1;
    intersection() {
        cylinder(z,top_r,bottom_r,$fn=100);
        translate([0,0,z/2])cube([x, y, z], center=true);
    }
}
"""


class SquaredConeSegment(BaseModel):
    # A well section whose horizontal cross-section is the intersection of a
    # circle with a ``rectangleXDimension`` x ``rectangleYDimension`` rectangle
    # sharing its center. Between ``bottomHeight`` and ``topHeight`` the
    # circle's radius changes linearly between ``circleDiameter / 2`` and the
    # rectangle's half-diagonal. ``bottomCrossSection`` says which end is at
    # the bottom. ``count`` is ``xCount * yCount``.
    # Comments rather than a class docstring: pydantic copies class docstrings
    # into the generated JSON schema as ``description``.
    shape: SquaredCone
    bottomCrossSection: WellShape
    circleDiameter: _NonNegativeNumber
    rectangleXDimension: _NonNegativeNumber
    rectangleYDimension: _NonNegativeNumber
    topHeight: _NonNegativeNumber
    bottomHeight: _NonNegativeNumber
    xCount: _StrictNonNegativeInt = 1
    yCount: _StrictNonNegativeInt = 1

    @staticmethod
    def _area_trap_points(
        total_frustum_height: float,
        circle_diameter: float,
        rectangle_x: float,
        rectangle_y: float,
        dx: float,
    ) -> list[float]:
        """Sample the cross-sectional area every ``dx`` along the height.

        Sample ``i`` uses a circle radius interpolated linearly from
        ``circle_diameter / 2`` (at ``i == 0``) to the rectangle's
        half-diagonal (at ``total_frustum_height``). It records the area of
        that circle intersected with the rectangle.

        Returns:
            ``[0.0]`` followed by ``int(total_frustum_height / dx) + 1`` area
            samples, circle end first.
        """

        def _area_arcs(r: float, c: float, d: float) -> float:
            """Return the area of all 4 arc segments."""
            theta_y = asin(c / r)
            theta_x = asin(d / r)
            theta_arc = (pi / 2) - theta_y - theta_x
            # area of all 4 arcs is 4 * pi*r^2*(theta/2pi)
            return 2 * r**2 * theta_arc

        def _area(r: float) -> float:
            """Return the cross-sectional area for circle radius ``r``."""
            if r == 0:
                # A zero-radius circle meets the rectangle in a single point;
                # without this guard _area_arcs would divide by zero.
                return 0.0
            # distance from the center of y axis of the rectangle to where the arc intercepts that side
            c: float = (
                sqrt(r**2 - (rectangle_y / 2) ** 2) if (rectangle_y / 2) < r else 0
            )
            # distance from the center of x axis of the rectangle to where the arc intercepts that side
            d: float = (
                sqrt(r**2 - (rectangle_x / 2) ** 2) if (rectangle_x / 2) < r else 0
            )
            arc_area = _area_arcs(r, c, d)
            y_triangles: float = rectangle_y * c
            x_triangles: float = rectangle_x * d
            return arc_area + y_triangles + x_triangles

        r_0 = circle_diameter / 2
        r_h = sqrt(rectangle_x**2 + rectangle_y**2) / 2

        num_steps = int(total_frustum_height / dx)
        # NOTE(review): the leading 0.0 is not the area at any sampled height.
        # It shifts the integration in height_to_volume_table by one step; it
        # is kept so existing results do not change — confirm it is intended.
        points = [0.0]
        points.extend(
            _area((i * dx / total_frustum_height) * (r_h - r_0) + r_0)
            for i in range(num_steps + 1)
        )
        return points

    @cached_property
    def height_to_volume_table(self) -> dict[float, float]:
        """Return a lookup table of heights to volumes.

        Keys are heights above the segment bottom. They step by ``dx`` from 0
        (accumulated, so they carry float drift), followed by the exact total
        height. Each value is the trapezoid-rule integral of the sampled area
        up to that height.

        Raises:
            NotImplementedError: If ``bottomCrossSection`` is neither the
                ``Circular`` nor the ``Rectangular`` shape.
        """
        # the accuracy of this method is approximately +- 10*dx so for dx of 0.001 we have a +- 0.01 ul
        dx = 0.001
        total_height = self.topHeight - self.bottomHeight
        points = SquaredConeSegment._area_trap_points(
            total_height,
            self.circleDiameter,
            self.rectangleXDimension,
            self.rectangleYDimension,
            dx,
        )
        # Circular/Rectangular are presumably Literal types (like Spherical,
        # which is used as a field annotation). The validated field holds a
        # value, so it is compared with their literal values; an identity check
        # against the type objects themselves can never match.
        if self.bottomCrossSection in get_args(Rectangular):
            # The points function assumes the circle is at the bottom but if its flipped we just reverse the points
            points.reverse()
        elif self.bottomCrossSection not in get_args(Circular):
            raise NotImplementedError(
                "If you see this error a new well shape has been added without updating this code"
            )

        # running[k] == trapz(points[0 : k + 1], dx=dx). Accumulating once is
        # O(n); re-integrating a growing prefix for every row was O(n**2).
        segment_areas = (
            dx * (lower + upper) / 2.0 for lower, upper in zip(points, points[1:])
        )
        running = list(accumulate(segment_areas, initial=0.0))

        y = 0.0
        table: dict[float, float] = {}
        # fill in the table
        while y < total_height:
            # Integrate over the first int(y / dx) samples; fewer than two
            # samples integrate to zero, just as trapz would return.
            num_samples = min(int(y / dx), len(points))
            table[y] = running[num_samples - 1] if num_samples > 0 else 0.0
            y = y + dx

        # we always want to include the volume at the max height
        table[total_height] = trapz(points, dx=dx)
        return table

    @cached_property
    def volume_to_height_table(self) -> dict[float, float]:
        """Return the inverse of ``height_to_volume_table``.

        When several heights share one volume, the last one in table order
        (the greatest height) wins.
        """
        return {
            volume: height
            for height, volume in self.height_to_volume_table.items()
        }

    @cached_property
    def count(self) -> int:
        """Return ``xCount * yCount``, computed once and cached per instance."""
        return self.xCount * self.yCount

    # Tell pydantic to leave cached_property attributes alone (not fields).
    model_config = ConfigDict(ignored_types=(cached_property,))


# OpenSCAD reference model kept for documentation only; as a bare string
# expression it has no runtime effect. It appears to model the rounded-cuboid
# well section: a rectangle whose rounded corners' radius changes with height
# — NOTE(review): confirm. ("filited" presumably means "filleted".)
"""
module filitedCuboidSquare(bottom_shape, diameter, width, length, height, steps) {
    module _slice(depth, x, y, r) {
        echo("called with: ", depth, x, y, r);
        circle_centers = [
            [(x/2)-r, (y/2)-r, 0],
            [(-x/2)+r, (y/2)-r, 0],
            [(x/2)-r, (-y/2)+r, 0],
            [(-x/2)+r, (-y/2)+r, 0]

        ];
        translate([0,0,depth/2])cube([x-2*r,y,depth], center=true);
        translate([0,0,depth/2])cube([x,y-2*r,depth], center=true);
        for (center = circle_centers) {
            translate(center) cylinder(depth, r, r, $fn=100);
        }
    }
    for (slice_height = [0:height/steps:height]) {
        r = (diameter) * (slice_height/height);
        translate([0,0,slice_height]) {
             _slice(height/steps , width, length, r/2);
        }
    }
}
module filitedCuboidForce(bottom_shape, diameter, width, length, height, steps) {
    module single_cone(r,x,y,z) {
        r = diameter/2;
        circle_face = [[ for (i = [0:1: steps]) i ]];
        theta = 360/steps;
        circle_points = [for (step = [0:1:steps]) [r*cos(theta*step), r*sin(theta*step), z]];
        final_points = [[x,y,0]];
        all_points = concat(circle_points, final_points);
        triangles = [for (step = [0:1:steps-1]) [step, step+1, steps+1]];
        faces = concat(circle_face, triangles);
        polyhedron(all_points, faces);
    }
    module square_section(r, x, y, z) {
        points = [
            [x,y,0],
            [-x,y,0],
            [-x,-y,0],
            [x,-y,0],
            [r,0,z],
            [0,r,z],
            [-r,0,z],
            [0,-r,z],
        ];
        faces = [
            [0,1,2,3],
            [4,5,6,7],
            [4, 0, 3],
            [5, 0, 1],
            [6, 1, 2],
            [7, 2, 3],
        ];
        polyhedron(points, faces);
    }
    circle_height = bottom_shape == "square" ? height : -height;
    translate_height = bottom_shape == "square" ? 0 : height;
    translate ([0,0, translate_height]) {
        union() {
            single_cone(diameter/2, width/2, length/2, circle_height);
            single_cone(diameter/2, -width/2, length/2, circle_height);
            single_cone(diameter/2, width/2, -length/2, circle_height);
            single_cone(diameter/2, -width/2, -length/2, circle_height);
            square_section(diameter/2, width/2, length/2, circle_height);
        }
    }
}

module filitedCuboid(bottom_shape, diameter, width, length, height) {
    if (width == length && width == diameter) {
        filitedCuboidSquare(bottom_shape, diameter, width, length, height, 100);
    }
    else {
        filitedCuboidForce(bottom_shape, diameter, width, length, height, 100);
    }
}"""


# A well section that transitions between a circle of ``circleDiameter`` and a
# ``rectangleXDimension`` x ``rectangleYDimension`` rectangle with rounded
# corners, between ``bottomHeight`` and ``topHeight``. ``bottomCrossSection``
# says which end is at the bottom (presumably — no volume math lives on this
# class). ``count`` is ``xCount * yCount``.
class RoundedCuboidSegment(BaseModel):
    shape: RoundedCuboid
    bottomCrossSection: WellShape
    circleDiameter: _NonNegativeNumber
    rectangleXDimension: _NonNegativeNumber
    rectangleYDimension: _NonNegativeNumber
    topHeight: _NonNegativeNumber
    bottomHeight: _NonNegativeNumber
    xCount: _StrictNonNegativeInt = 1
    yCount: _StrictNonNegativeInt = 1

    @cached_property
    def count(self) -> int:
        """Return ``xCount * yCount``, computed once and cached per instance."""
        return self.xCount * self.yCount

    # Tell pydantic to leave cached_property attributes alone (not fields).
    model_config = ConfigDict(ignored_types=(cached_property,))


# Optional display metadata for a group of wells; every field may be omitted.
class GroupMetadata(BaseModel):
    displayName: str | None = None
    displayCategory: DisplayCategory | None = None
    wellBottomShape: Literal["flat", "u", "v"] | None = None


# A named set of wells (listed by well name) that share metadata and,
# optionally, brand information.
class Group(BaseModel):
    wells: list[str]
    metadata: GroupMetadata
    brand: BrandData | None = None


# One section of an inner well geometry, selected by its ``shape`` field.
WellSegment = Annotated[
    ConicalFrustum
    | CuboidalFrustum
    | SquaredConeSegment
    | RoundedCuboidSegment
    | SphericalSegment,
    Discriminator("shape"),
]


# One entry of a user-supplied height-to-volume map.
# NOTE(review): these use plain ``float`` rather than ``_Number``, so integer
# inputs may be dumped as floats — confirm this doesn't matter for definition
# hashing.
class HeightVolumePair(BaseModel):
    height: float
    volume: float


# Inner geometry of a well built from one or more sections; at least one
# section is required.
class InnerWellGeometry(BaseModel):
    sections: Annotated[list[WellSegment], Field(min_length=1)]


# Inner well volume given as an explicit height-to-volume map instead of
# geometric sections. Only LabwareDefinition2's ``innerLabwareGeometry``
# accepts it.
class UserDefinedVolumes(BaseModel):
    heightToVolumeMap: list[HeightVolumePair]


# Spatial extents of a schema-3 labware; currently only the total bounding box.
class Extents(BaseModel):
    total: AxisAlignedBoundingBox3D


# A schema-2 labware definition. It is one arm of the LabwareDefinition union,
# discriminated by ``schemaVersion``.
class LabwareDefinition2(BaseModel):
    schemaVersion: Literal[2]
    version: Annotated[int, Field(ge=1)]
    # Must match SAFE_STRING_REGEX (lowercase letters, digits, "." and "_").
    namespace: Annotated[str, Field(pattern=SAFE_STRING_REGEX)]
    metadata: Metadata
    brand: BrandData
    parameters: Parameters2
    cornerOffsetFromSlot: Vector3D
    # Well names arranged in nested lists (presumably columns of rows — confirm).
    ordering: list[list[str]]
    dimensions: Dimensions
    # Wells keyed by well name; y coordinates must be non-negative in schema 2.
    wells: dict[str, WellDefinition2]
    groups: list[Group]
    stackingOffsetWithLabware: dict[str, Vector3D] = Field(default_factory=dict)
    stackingOffsetWithModule: dict[str, Vector3D] = Field(default_factory=dict)
    allowedRoles: list[LabwareRole] = Field(default_factory=list)
    gripperOffsets: dict[str, GripperOffsets] = Field(default_factory=dict)
    gripForce: float | None = None
    gripHeightFromLabwareBottom: float | None = None
    stackLimit: int | None = None
    compatibleParentLabware: list[str] | None = None
    # Keyed presumably by the wells' ``geometryDefinitionId``. Unlike schema 3,
    # entries may also be UserDefinedVolumes.
    innerLabwareGeometry: dict[
        str, InnerWellGeometry | UserDefinedVolumes
    ] | None = None


# A schema-3 labware definition. It is one arm of the LabwareDefinition union,
# discriminated by ``schemaVersion``.
class LabwareDefinition3(BaseModel):
    # Read from the "$otSharedSchema" key (pydantic alias), since "$" cannot
    # appear in a Python attribute name.
    otSharedSchema: Annotated[
        Literal["#/labware/schemas/3"], Field(alias="$otSharedSchema")
    ]
    schemaVersion: Literal[3]
    version: Annotated[int, Field(ge=1)]
    # Must match SAFE_STRING_REGEX (lowercase letters, digits, "." and "_").
    namespace: Annotated[str, Field(pattern=SAFE_STRING_REGEX)]
    metadata: Metadata
    brand: BrandData
    parameters: Parameters3
    ordering: list[list[str]]
    features: LocatingFeatures
    extents: Extents
    # Wells keyed by well name; y coordinates must be non-positive in schema 3.
    wells: dict[str, WellDefinition3]
    groups: list[Group]
    stackingOffsetWithLabware: dict[str, Vector3D] = Field(default_factory=dict)
    legacyStackingOffsetWithLabware: dict[str, Vector3D] = Field(default_factory=dict)
    stackingOffsetWithModule: dict[str, Vector3D] = Field(default_factory=dict)
    allowedRoles: list[LabwareRole] = Field(default_factory=list)
    gripperOffsets: dict[str, GripperOffsets] = Field(default_factory=dict)
    gripForce: float | None = None
    # Schema 3 measures grip height from the labware origin; schema 2 uses
    # ``gripHeightFromLabwareBottom`` instead.
    gripHeightFromLabwareOrigin: float | None = None
    stackLimit: int | None = None
    compatibleParentLabware: list[str] | None = None
    # Only geometric InnerWellGeometry entries are accepted in schema 3.
    innerLabwareGeometry: dict[str, InnerWellGeometry] | None = None


# Any supported labware definition; the ``schemaVersion`` value selects the
# model to validate against.
LabwareDefinition = Annotated[
    LabwareDefinition2 | LabwareDefinition3,
    Discriminator("schemaVersion"),
]


# Shared adapter for validating/serializing a LabwareDefinition, which is a
# type alias rather than a BaseModel subclass.
labware_definition_type_adapter: Final = TypeAdapter[LabwareDefinition](
    LabwareDefinition
)
