from functools import partial

import gdsfactory as gf
import numpy as np
from gdsfactory.typings import ComponentSpec

from . import pcells, util
from .types import validate_position_with_orientation, validate_orientation
from .fanin import add_fan_in
from .routing import add_route_astar
from .types import Int, LayerLike, PortLike, Um


def add_bundle_astar(
    component: gf.Component,
    ports1: list[PortLike],
    ports2: list[PortLike],
    spacing: Um,
    bend: ComponentSpec,
    straight: ComponentSpec,
    layer: LayerLike,
    grid_unit: Int = 500,
):
    """Route a bundle of connections from ``ports1`` to ``ports2`` with the A* router.

    With a single port pair, one route is drawn directly between the two ports.
    With several pairs, each port group is first passed through ``add_fan_in``.
    A single route is then drawn between the mean of the positions returned
    for each side. For that route, ``bend`` and ``straight`` are wrapped with
    ``pcells.bends`` / ``pcells.straights`` together with ``num_ports`` and
    ``spacing`` (presumably yielding multi-lane cells — confirm in ``pcells``).

    Args:
        component: component that the fan-ins and the route are added to.
        ports1: start ports; all must share the same orientation.
        ports2: end ports; all must share the same orientation, and that
            orientation must differ from the start orientation.
        spacing: pitch between adjacent connections; converted to database
            units (via ``util.get_inv_dbu()``) for the fan-in.
        bend: bend component spec.
        straight: straight component spec.
        layer: layer forwarded to ``add_route_astar``.
        grid_unit: grid unit forwarded to ``add_route_astar``.

    Returns:
        A list containing ``None`` once per port pair.

    Raises:
        ValueError: if ``ports1`` and ``ports2`` differ in length or are empty,
            if the orientations within either list are not all equal, or if the
            start and end orientations are equal.
    """
    if len(ports1) != len(ports2):
        raise ValueError("Number of start ports is different than number of end ports")
    num_ports = len(ports1)
    if num_ports == 0:
        raise ValueError("No input/output ports given")

    # validate_position_with_orientation yields (x, y, orientation) triples.
    xyo1 = [validate_position_with_orientation(p) for p in ports1]
    xyo2 = [validate_position_with_orientation(p) for p in ports2]
    os1 = [o for _, _, o in xyo1]
    os2 = [o for _, _, o in xyo2]
    if not all(o == os1[0] for o in os1):
        raise ValueError(f"Input port orientations are not all equal. Got: {os1}.")
    if not all(o == os2[0] for o in os2):
        # Report the *output* orientations here (previously reported os1 by mistake).
        raise ValueError(f"Output port orientations are not all equal. Got: {os2}.")

    o1 = validate_orientation(os1[0])
    o2 = validate_orientation(os2[0])
    if o1 == o2:
        # FIXME: this check seems necessary because the router doesn't seem to find a solution anyway in this case :(
        raise ValueError(
            f"The port orientation at the input needs to be different from the port orientation at the output. Got: {o1!r}=={o2!r}."
        )

    if num_ports == 1:
        # NOTE(review): the boolean second argument presumably requests an
        # inverted orientation for the stop port. That would mirror the
        # util.invert_orientation(o2) call in the multi-port branch — confirm.
        start = validate_position_with_orientation(ports1[0], False)
        stop = validate_position_with_orientation(ports2[0], True)
    else:
        inv_dbu = util.get_inv_dbu()
        spacing_dbu = round(spacing * inv_dbu)
        starts = add_fan_in(
            c=component,
            inputs=ports1,
            straight=straight,
            bend=bend,
            spacing_dbu=spacing_dbu,
        )
        stops = add_fan_in(
            c=component,
            inputs=ports2,
            straight=straight,
            bend=bend,
            spacing_dbu=spacing_dbu,
        )
        # The bundle is routed as one path between the mean (axis 0) of the
        # fanned-in positions on each side.
        start = (*np.mean(starts, 0), o1)
        stop = (*np.mean(stops, 0), util.invert_orientation(o2))
        bend = partial(pcells.bends, bend, straight, num_ports, spacing)
        straight = partial(pcells.straights, straight, num_ports, spacing)

    add_route_astar(
        c=component,
        start=start,
        stop=stop,
        layer=layer,
        straight=straight,
        bend=bend,
        grid_unit=grid_unit,
    )
    return [None] * num_ports
