import numpy as np
from gdsfactory.component import Component
from gdsfactory.typings import ComponentSpec

from . import util
from .routing import add_route_from_steps
from .types import (
    Dbu,
    OrientationChar,
    PointsDbu,
    PortLike,
    StepDbu,
    validate_orientation,
    validate_position_with_orientation,
)


def add_fan_in(
    c: Component,
    inputs: list[PortLike],
    straight: ComponentSpec,
    bend: ComponentSpec,
    x_bundle_dbu: Dbu | None = None,
    y_bundle_dbu: Dbu | None = None,
    spacing_dbu: Dbu | None = None,
    start_dir: OrientationChar | None = None,
):
    """Fan a set of equally oriented ports into a bundle and draw the routes.

    Args:
        c: component handed to ``add_route_from_steps`` for every route.
        inputs: port-like objects; each is resolved with
            ``validate_position_with_orientation`` and all of them must
            resolve to the same orientation.
        straight: straight spec, forwarded to the step helper and the router.
        bend: bend spec, forwarded to the step helper and the router.
        x_bundle_dbu: forwarded unchanged to the direction-specific helper.
        y_bundle_dbu: forwarded unchanged to the direction-specific helper.
        spacing_dbu: forwarded unchanged to the direction-specific helper.
        start_dir: direction in which the routes leave the ports. It is
            normalised with ``validate_orientation``; a result of ``"o"`` is
            treated as "not specified". It is required when the inputs carry
            no orientation and must otherwise match the inputs' orientation.

    Returns:
        The stop positions produced by the direction-specific helper, in the
        order of the sorted starts.

    Raises:
        ValueError: if the inputs have differing orientations, if neither the
            inputs nor ``start_dir`` provide a direction, or if ``start_dir``
            disagrees with the orientation of the inputs.
    """
    oriented_starts = [validate_position_with_orientation(port) for port in inputs]
    orientations = [orientation for _, _, orientation in oriented_starts]
    reference = orientations[0]
    if any(orientation != reference for orientation in orientations):
        raise ValueError("start ports all need to have the same orientation.")

    d0 = reference
    start_dir = validate_orientation(start_dir)
    if d0 == "o":
        # The ports give no direction, so the caller has to.
        if start_dir == "o":
            raise ValueError(
                "Please specify a start direction if your inputs don't have an orientation."
            )
        d0 = start_dir
    elif start_dir != "o" and start_dir != d0:
        # NOTE: a start_dir opposite to the port orientation would be a
        # legitimate request (inverted fan-in), but it is not enabled yet.
        raise ValueError("Invalid start direction.")

    # From here on d0 is one of "n", "e", "s", "w".
    # East/west bundles are ordered along y, north/south bundles along x.
    axis = 1 if d0 in "ew" else 0
    xy_pairs = [(x, y) for x, y, _ in oriented_starts]
    xy_pairs.sort(key=lambda point: point[axis])
    sorted_starts = np.array(xy_pairs)  # type: ignore

    build_steps = {
        "n": _fan_in_north_steps,
        "e": _fan_in_east_steps,
        "s": _fan_in_south_steps,
        "w": _fan_in_west_steps,
    }[d0]

    per_route_steps, stops = build_steps(
        [(px, py) for px, py in sorted_starts],
        straight,
        bend,
        x_bundle_dbu,
        y_bundle_dbu,
        spacing_dbu,
        invert_direction=False,  # inverted fan-in is never requested from here
    )

    for start, stop, steps in zip(sorted_starts, stops, per_route_steps):
        add_route_from_steps(
            c=c,
            start=start,
            stop=stop,
            steps=steps,
            straight=straight,
            bend=bend,
        )
    return stops


def _fan_in_east_steps(
    starts: PointsDbu,
    straight: ComponentSpec,
    bend: ComponentSpec,
    x_bundle_dbu: Dbu | None = None,
    y_bundle_dbu: Dbu | None = None,
    spacing_dbu: Dbu | None = None,
    invert_direction: bool = False,
) -> tuple[list[list[StepDbu]], np.ndarray]:
    """Compute route steps and stop positions for an east-going fan-in.

    Args:
        starts: (x, y) start positions; all must share the same x. The i-th
            start is paired with the i-th stop in ascending-y order, so the
            starts are expected to be sorted by y (``add_fan_in`` does this).
        straight: spec from which the waveguide width is extracted; it sets
            the default spacing.
        bend: spec from which the bend radius is extracted.
        x_bundle_dbu: x coordinate of all stops. ``None`` selects a default
            that leaves room for two bend radii plus the widest stagger.
            An explicit ``0`` is honoured as a coordinate.
        y_bundle_dbu: y coordinate around which the stops are distributed.
            ``None`` selects the mean of the start y values, shifted by half
            the gap between two neighbouring starts when the number of
            starts is odd.
        spacing_dbu: pitch between neighbouring stops. ``None`` (or ``0``)
            falls back to twice the waveguide width.
        invert_direction: experimental, see the FIXME below.

    Returns:
        ``(steps, stops)``: ``steps`` holds one ``[{"dx": ...}, {"y": ...}]``
        list per route and ``stops`` is an integer array of shape
        ``(len(starts), 2)`` with the stop positions.

    Raises:
        ValueError: if the starts do not all share the same x coordinate.
    """
    # FIXME: invert direction not working properly.
    _starts = np.asarray(starts, dtype=np.int_)
    num_links = _starts.shape[0]
    wg_width_dbu = util.extract_waveguide_width(straight)
    _spacing_dbu = spacing_dbu or 2 * wg_width_dbu

    xs_start = _starts[:, 0]
    ys_start = _starts[:, 1]

    if y_bundle_dbu is None:
        _y_bundle_dbu = ys_start.mean()
        if num_links % 2:
            # Odd count: move the bundle centre by half the gap next to the
            # start closest to the mean.
            # NOTE(review): with a single start there is no neighbouring gap
            # and the lookup below raises IndexError; pass y_bundle_dbu
            # explicitly in that case.
            i = int(np.argmin(np.abs(ys_start - _y_bundle_dbu)))
            dys_start = ys_start[1:] - ys_start[:-1]
            dy = dys_start[min(i, dys_start.shape[0] - 1)]
            _y_bundle_dbu += dy / 2
    else:
        _y_bundle_dbu = int(y_bundle_dbu)

    if (xs_start[:] != xs_start[0]).any():
        raise ValueError("start ports need to be vertically or horizontally aligned.")

    # Stops alternate below/above the bundle centre with growing distance and
    # are then sorted so that they pair up with the y-sorted starts.
    ys_stop = np.sort(
        np.asarray(
            [
                _y_bundle_dbu - (i + 1) // 2 * _spacing_dbu * (2 * (i % 2) - 1)
                for i in range(num_links)
            ],
            dtype=np.int_,
        )
    )
    # "above": routes whose start lies above their stop; "below": all others.
    above = ((ys_start - ys_stop) > 0).sum()
    below = num_links - above
    radius_dbu = util.extract_bend_radius(bend)
    sign = 1 - 2 * invert_direction

    # Only fall back to the default when no coordinate was given at all;
    # a plain truthiness test would discard a legitimate x of 0.
    if x_bundle_dbu is None:
        _x_bundle_dbu = xs_start[0] + (1 - invert_direction) * (
            sign * 2 * radius_dbu + (max(above, below) - 1) * _spacing_dbu
        )
    else:
        _x_bundle_dbu = x_bundle_dbu
    xs_stop = np.broadcast_to(np.int_(_x_bundle_dbu), ys_stop.shape)
    stops = np.stack([xs_stop, ys_stop], 1)

    # The first horizontal run grows by one spacing per route within each
    # group (below / above), so every route turns at its own x position.
    steps = [
        [
            {"dx": radius_dbu + max((below - i - 1), (i - below)) * _spacing_dbu},
            {"y": stop[1]},
        ]
        for i, stop in enumerate(stops)
    ]
    return steps, stops


def _fan_in_west_steps(
    starts: PointsDbu,
    straight: ComponentSpec,
    bend: ComponentSpec,
    x_bundle_dbu: Dbu | None = None,
    y_bundle_dbu: Dbu | None = None,
    spacing_dbu: Dbu | None = None,
    invert_direction: bool = False,
) -> tuple[list[list[StepDbu]], np.ndarray]:
    """Compute route steps and stop positions for a west-going fan-in.

    The east solution is computed and then mirrored about the x coordinate
    of the starts: stop x values are reflected and every ``dx`` step is
    negated.

    Args:
        starts: (x, y) start positions, passed on to the east solver.
        straight: straight spec, passed on to the east solver.
        bend: bend spec, passed on to the east solver.
        x_bundle_dbu: requested x coordinate of the stops. It is reflected
            about the start column before being handed to the east solver,
            so that reflecting the result puts the stops back at this x.
        y_bundle_dbu: passed on unchanged (mirroring does not affect y).
        spacing_dbu: passed on unchanged.
        invert_direction: passed on unchanged.

    Returns:
        ``(steps, stops)`` in the same layout the east solver returns.
    """
    _starts = np.asarray(starts, dtype=np.int_)

    # The east solver works in the mirrored frame. Without reflecting the
    # requested bundle x first, the final mirror would place the stops at
    # 2 * x_start - x_bundle_dbu, on the wrong side of the starts.
    if x_bundle_dbu is None:
        mirrored_x_bundle_dbu = None
    else:
        mirrored_x_bundle_dbu = 2 * _starts[0, 0] - x_bundle_dbu

    # strategy: calculate east and adjust coordinates accordingly
    steps, stops = _fan_in_east_steps(
        starts,
        straight,
        bend,
        mirrored_x_bundle_dbu,
        y_bundle_dbu,
        spacing_dbu,
        invert_direction,
    )
    stops[:, 0] = 2 * _starts[:, 0] - stops[:, 0]
    for _steps in steps:
        for step in _steps:
            if "dx" in step:
                step["dx"] = -(step["dx"] or 0)
    return steps, stops


def _fan_in_north_steps(
    starts: PointsDbu,
    straight: ComponentSpec,
    bend: ComponentSpec,
    x_bundle_dbu: Dbu | None = None,
    y_bundle_dbu: Dbu | None = None,
    spacing_dbu: Dbu | None = None,
    invert_direction: bool = False,
) -> tuple[list[list[StepDbu]], np.ndarray]:
    """Compute route steps and stop positions for a north-going fan-in.

    The problem is solved as an east-going fan-in with the x and y axes
    exchanged: starts and bundle coordinates are swapped on the way in, and
    stops and step keys are swapped back on the way out.

    Args:
        starts: (x, y) start positions.
        straight: straight spec, passed on to the east solver.
        bend: bend spec, passed on to the east solver.
        x_bundle_dbu: handed to the east solver as its ``y_bundle_dbu``.
        y_bundle_dbu: handed to the east solver as its ``x_bundle_dbu``.
        spacing_dbu: passed on unchanged.
        invert_direction: passed on unchanged.

    Returns:
        ``(steps, stops)`` where ``stops`` has its columns back in (x, y)
        order and each step uses ``"dy"`` / ``"x"`` instead of the east
        solver's ``"dx"`` / ``"y"``.
    """
    swapped_starts = [(y, x) for (x, y) in starts]
    east_steps, east_stops = _fan_in_east_steps(
        swapped_starts,
        straight,
        bend,
        y_bundle_dbu,
        x_bundle_dbu,
        spacing_dbu,
        invert_direction,
    )

    # Put the stop columns back in (x, y) order.
    stops = np.stack([east_stops[:, 1], east_stops[:, 0]], axis=1)

    # Rename the step keys from the swapped frame back to the real axes.
    key_renames = (("dx", "dy"), ("y", "x"))
    for route in east_steps:
        for step in route:
            for swapped_key, real_key in key_renames:
                if swapped_key in step:
                    step[real_key] = step.pop(swapped_key)  # type: ignore
    return east_steps, stops


def _fan_in_south_steps(
    starts: PointsDbu,
    straight: ComponentSpec,
    bend: ComponentSpec,
    x_bundle_dbu: Dbu | None = None,
    y_bundle_dbu: Dbu | None = None,
    spacing_dbu: Dbu | None = None,
    invert_direction: bool = False,
) -> tuple[list[list[StepDbu]], np.ndarray]:
    """Compute route steps and stop positions for a south-going fan-in.

    The north solution is computed and then mirrored about the y coordinate
    of the starts: stop y values are reflected and every ``dy`` step is
    negated.

    Args:
        starts: (x, y) start positions, passed on to the north solver.
        straight: straight spec, passed on to the north solver.
        bend: bend spec, passed on to the north solver.
        x_bundle_dbu: passed on unchanged (mirroring does not affect x).
        y_bundle_dbu: requested y coordinate of the stops. It is reflected
            about the start row before being handed to the north solver,
            so that reflecting the result puts the stops back at this y.
        spacing_dbu: passed on unchanged.
        invert_direction: passed on unchanged.

    Returns:
        ``(steps, stops)`` in the same layout the north solver returns.
    """
    _starts = np.asarray(starts, dtype=np.int_)

    # The north solver works in the mirrored frame. Without reflecting the
    # requested bundle y first, the final mirror would place the stops at
    # 2 * y_start - y_bundle_dbu, on the wrong side of the starts.
    if y_bundle_dbu is None:
        mirrored_y_bundle_dbu = None
    else:
        mirrored_y_bundle_dbu = 2 * _starts[0, 1] - y_bundle_dbu

    # strategy: calculate north and adjust coordinates accordingly
    steps, stops = _fan_in_north_steps(
        starts,
        straight,
        bend,
        x_bundle_dbu,
        mirrored_y_bundle_dbu,
        spacing_dbu,
        invert_direction,
    )
    stops[:, 1] = 2 * _starts[:, 1] - stops[:, 1]
    for _steps in steps:
        for step in _steps:
            if "dy" in step:
                step["dy"] = -(step["dy"] or 0)
    return steps, stops
