import gdsfactory as gf
import numpy as np
from gdsfactory.component import Component
from gdsfactory.typings import ComponentSpec

from . import util
from .types import Int, OrientationChar, OrientationTransition, Um

# NOTE(review): ``__all__`` is empty, so ``from <this module> import *`` exports
# none of the cells defined below; callers must import them by name. Confirm
# this is intentional before relying on star-imports.
__all__ = []


@gf.cell
def straights(straight: ComponentSpec, num: Int, spacing: Um, **kwargs) -> Component:
    """Stack ``num`` copies of one straight with vertical pitch ``spacing``.

    Each copy is oriented with ``util.orient_east_at_origin`` and then shifted
    in y by ``(num - 1) * spacing / 2 - i * spacing``. The offsets are
    symmetric about zero, and copy 0 gets the largest one.

    Args:
        straight: spec of the straight, resolved with ``gf.get_component``.
            It must expose exactly two ports.
        num: number of copies. Copy 0 is indexed unconditionally, so ``num``
            must be at least 1.
        spacing: vertical pitch between neighbouring copies.
        **kwargs: forwarded to ``gf.get_component`` for the straight.

    Returns:
        A component carrying the straight's two port names. They are added
        from copy 0 right after it is oriented and before any copy is shifted.
    """
    bundle = gf.Component()
    unit = gf.get_component(straight, **kwargs)
    name_in, name_out = [port.name for port in unit.ports]

    instances = [bundle.add_ref(unit) for _ in range(num)]

    # Ports are taken from copy 0 while it still sits at the origin; the loop
    # below only moves the instances afterwards.
    head = instances[0]
    util.orient_east_at_origin(head)
    for port_name in (name_in, name_out):
        bundle.add_port(port_name, port=head.ports[port_name])

    top = (num - 1) * spacing / 2
    for index, inst in enumerate(instances):
        util.orient_east_at_origin(inst)
        inst.dmove((0.0, float(top - index * spacing)))
    return bundle


@gf.cell
def bends(
    bend: ComponentSpec, straight: ComponentSpec, num: Int, spacing: Um, **kwargs
) -> Component:
    """Nest ``num`` copies of a bend at pitch ``spacing`` and extend them with straights.

    Copy ``i`` of the bend is oriented with
    ``util.orient_east_to_north_at_origin`` and shifted by
    ``((num - 1 - i) * spacing, i * spacing - (num - 1) * spacing / 2)``.
    Every copy except the last gets a pair of straights, one connected to
    each bend port. The pair attached to copy ``num - 2 - j`` has length
    ``(j + 1) * spacing``.

    Args:
        bend: spec of the bend, resolved with ``gf.get_component(bend, **kwargs)``.
            It must expose exactly two ports.
        straight: spec of the extension straights. Only the ``cross_section``
            and ``width`` entries of ``kwargs`` are forwarded to it.
        num: number of nested bends. ``num == 1`` yields a single bend with no
            straights.
        spacing: pitch between neighbouring bends.
        **kwargs: forwarded in full to the bend. See ``straight`` for the
            subset the straights receive.

    Returns:
        A component with the bend's two port names. Both ports are added from
        the *unplaced* bend and then re-centred: the first to ``(0, 0)`` and
        the second to ``(c, c)`` with ``c = r + (num - 1) * spacing / 2``.
        Here ``r`` is ``util.extract_bend_radius(bend)`` divided by
        ``util.get_inv_dbu()``, presumably converting database units to um
        (TODO confirm).
    """
    c = gf.Component()
    cbend = gf.get_component(bend, **kwargs)
    # Only cross-section/width settings are shared with the straights.
    straight_kwargs = {k: kwargs[k] for k in ("cross_section", "width") if k in kwargs}
    cstraights = [
        gf.get_component(straight, length=spacing * i, **straight_kwargs)
        for i in range(1, num)
    ]
    b1, b2 = [p.name for p in cbend.ports]

    bend_refs = [c.add_ref(cbend) for _ in range(num)]
    for bend_ref in bend_refs:
        util.orient_east_to_north_at_origin(bend_ref)

    straight_refs = [(c.add_ref(s), c.add_ref(s)) for s in cstraights]
    for ref_a, ref_b in straight_refs:
        util.orient_east_at_origin(ref_a)
        util.orient_east_at_origin(ref_b)

    for i, bend_ref in enumerate(bend_refs):
        bend_ref.dmove(
            (
                float(-i * spacing + (num - 1) * spacing),
                float(i * spacing - (num - 1) * spacing / 2),
            )
        )

    # The straight's first port name is only needed (and only available) when
    # at least one straight exists. Indexing cstraights[0] unconditionally
    # raised IndexError for num == 1.
    if cstraights:
        o1 = str([p.name for p in cstraights[0].ports][0])
        for i, (ref_a, ref_b) in enumerate(straight_refs):
            target = bend_refs[num - 2 - i]
            ref_a.connect(o1, target.ports[b1])
            ref_b.connect(o1, target.ports[b2])

    inv_dbu = util.get_inv_dbu()
    radius = util.extract_bend_radius(cbend) / inv_dbu
    c.add_port(name=b1, port=cbend.ports[b1])
    c.ports[b1].dcenter = (
        0.0,
        0.0,
    )
    c.add_port(name=b2, port=cbend.ports[b2])
    corner = float(radius + (num - 1) * spacing / 2)
    c.ports[b2].dcenter = (corner, corner)
    return c


@gf.cell
def frame(
    width: Um = 100,
    height: Um = 50,
    frame_width: Um = 1.0,
    straight: ComponentSpec = "straight",
) -> Component:
    """Draw a rectangular outline from four straight bars.

    The two horizontal bars have length ``width``, and the two vertical bars
    have length ``height + 2 * frame_width``. All bars are ``frame_width``
    wide. Each bar is oriented with ``util.orient_east_at_origin``, the
    vertical bars are then rotated by 90 degrees, and every bar is finally
    shifted by a fixed offset. Bars are created in the order top, bottom,
    left, right.

    Assuming ``util.orient_east_at_origin`` puts a straight's input port at
    the origin pointing east (TODO confirm), the inner opening spans
    x in [frame_width / 2, width + frame_width / 2] and y in [0, height].
    """
    outline = gf.Component()
    bar_h = gf.get_component(straight, length=width, width=frame_width)
    bar_v = gf.get_component(
        straight, length=height + 2 * frame_width, width=frame_width
    )
    half = float(frame_width / 2)
    drop = -float(frame_width)
    # (bar, rotate_upright, offset), listed in creation order.
    layout = (
        (bar_h, False, (half, float(height + frame_width / 2))),  # top
        (bar_h, False, (half, -half)),  # bottom
        (bar_v, True, (0.0, drop)),  # left
        (bar_v, True, (float(width + frame_width), drop)),  # right
    )
    for bar, upright, offset in layout:
        inst = outline.add_ref(bar)
        util.orient_east_at_origin(inst)
        if upright:
            inst.drotate(90)
        inst.dmove(offset)
    return outline


@gf.cell
def field0(
    width: Um = 20, height: Um = 20, straight: ComponentSpec = "straight"
) -> Component:
    """Build an empty field: a ``frame(width, height)`` plus two 1x1 markers.

    The marker named ``"in"`` is shifted to ``(1.5, 2.5)``, and the marker
    named ``"out"`` to ``(width - 2.5, height - 2.5)``. Both are oriented
    with ``util.orient_east_at_origin`` first.

    NOTE(review): ``straight`` is used only for the two markers; the frame
    is drawn with ``frame``'s own default straight. Confirm this is intended.
    """
    field = gf.Component()
    field.add_ref(frame(width, height))
    markers = (
        ("in", (1.5, 2.5)),
        ("out", (float(width - 2.5), float(height - 2.5))),
    )
    for label, offset in markers:
        marker = field.add_ref(
            gf.get_component(straight, length=1, width=1), name=label
        )
        util.orient_east_at_origin(marker)
        marker.dmove(offset)
    return field


@gf.cell
def field1(
    width: Um = 50, height: Um = 50, straight: ComponentSpec = "straight"
) -> Component:
    """Build a field with three obstacle straights plus "in"/"out" markers.

    All obstacle positions scale with ``width`` and ``height``. Each straight
    is oriented with ``util.orient_east_at_origin`` and then shifted:

    * length ``0.8 * width``, width ``0.1 * height``, to
      ``(0.5, 29/50 * height)``;
    * length ``0.1 * width``, width ``0.1 * height``, to
      ``(0.9 * width + 0.5, 3/5 * height)``;
    * length ``0.8 * width``, width 1, to ``(width / 5 + 0.5, 4/5 * height)``;
    * 1x1 marker ``"in"`` to ``(1.5, 2.5)`` and ``"out"`` to
      ``(width - 2.5, height - 2.5)``.

    NOTE(review): ``straight`` is not forwarded to ``frame``, which draws with
    its own default straight.
    """
    c = gf.Component()
    c.add_ref(frame(width, height))
    ref = c.add_ref(
        gf.get_component(straight, length=4 / 5 * width, width=0.1 * height)
    )
    util.orient_east_at_origin(ref)
    ref.dmove((0.5, float(29 / 50 * height)))
    ref = c.add_ref(gf.get_component(straight, length=0.1 * width, width=0.1 * height))
    util.orient_east_at_origin(ref)
    ref.dmove((float(0.9 * width + 0.5), float(3 / 5 * height)))
    ref = c.add_ref(gf.get_component(straight, length=4 / 5 * width, width=1))
    util.orient_east_at_origin(ref)
    # Scale with ``height`` like every other feature. This was a fixed 40.0,
    # which equals 4/5 * height only for the default height of 50 and left
    # the straight outside the frame for smaller heights.
    ref.dmove((float(width - 4 / 5 * width + 0.5), float(4 / 5 * height)))
    ref = c.add_ref(gf.get_component(straight, length=1, width=1), name="in")
    util.orient_east_at_origin(ref)
    ref.dmove((1.5, 2.5))
    ref = c.add_ref(gf.get_component(straight, length=1, width=1), name="out")
    util.orient_east_at_origin(ref)
    ref.dmove((float(width - 2.5), float(height - 2.5)))
    return c


@gf.cell
def fanout_frame(
    width: Um = 100,
    num_inputs: Int = 5,
    input_spacing: Um = 40,
    straight: ComponentSpec = "straight",
    orientation: OrientationChar = "e",
    add_frame: bool = True,
) -> Component:
    """Place ``num_inputs`` length-1 stubs along one side of an optional frame.

    Stub ``i`` is an instance named ``f"in{i}"``, oriented with
    ``util.orient_at_origin(ref, orientation)``. With
    ``height = num_inputs * input_spacing``, its coordinate along the y axis
    (for ``"e"``/``"w"``) or the x axis (for ``"n"``/``"s"``) is
    ``height / 2 - ((i + 1) // 2) * input_spacing * (2 * (i % 2) - 1)``.
    Stub 0 sits at ``height / 2``; odd stubs step to lower coordinates and
    even stubs to higher ones. The other coordinate is 0 for ``"e"``/``"n"``
    and ``width`` for ``"w"``/``"s"``.

    If ``add_frame`` is true, a ``frame(width, height)`` is added for
    ``"e"``/``"w"``, or ``frame(height, width)`` for ``"n"``/``"s"``.

    Returns:
        A component with ports ``in0 .. in{num_inputs-1}``, each copied from
        the corresponding stub's ``"o2"`` port. This assumes the straight
        names its ports ``o1``/``o2``, as gdsfactory's ``"straight"`` does.
    """
    # West/south stubs sit on the far edge (offset ``width``); east/north on the near edge.
    add_width = float(orientation in "ws")
    idx = 1 if orientation in "we" else 0  # axis along which stubs are spread
    c = gf.Component()
    height = num_inputs * input_spacing
    cstraight = gf.get_component(straight, length=1)
    refs_in = [c.add_ref(cstraight, name=f"in{i}") for i in range(num_inputs)]
    for ref in refs_in:
        util.orient_at_origin(ref, orientation)
    for i, ref in enumerate(refs_in):
        move = np.array([add_width, add_width]) * width
        move[idx] = float(height / 2 - (i + 1) // 2 * input_spacing * (2 * (i % 2) - 1))
        ref.dmove((move[0], move[1]))

    if add_frame:
        if idx == 1:
            c.add_ref(frame(width=width, height=height))
        else:
            c.add_ref(frame(width=height, height=width))

    for i, ref in enumerate(refs_in):
        # Pass the port by keyword, as elsewhere in this module. The second
        # positional parameter of ``Component.add_port`` is ``center``, not ``port``.
        c.add_port(f"in{i}", port=ref.ports["o2"])
    return c


@gf.cell
def fanout_frame2(
    width: Um = 100,
    num_inputs: Int = 5,
    input_spacing: Um = 40,
    output_spacing: Um = 30,
    straight: ComponentSpec = "straight",
    transition: OrientationTransition = ("e", "w"),
    add_frame: bool = True,
):
    """Combine an input stub row and an output stub row, both from ``fanout_frame``.

    The input row uses ``input_spacing`` and orientation ``transition[0]``,
    and carries the frame when ``add_frame`` is true. The output row uses
    ``output_spacing`` and orientation ``transition[1]`` and never has a
    frame. Neither row is moved after insertion.

    Returns:
        A component with ports ``in0..`` (from the input row) followed by
        ``out0..`` (from the output row), numbered in each row's port order.
    """
    combined = gf.Component()
    source = combined << fanout_frame(
        width, num_inputs, input_spacing, straight, transition[0], add_frame
    )
    sink = combined << fanout_frame(
        width, num_inputs, output_spacing, straight, transition[1], False
    )
    for prefix, row in (("in", source), ("out", sink)):
        for k, port in enumerate(row.ports):
            combined.add_port(f"{prefix}{k}", port=port)
    return combined


def fanout_frame3(
    width: Um = 100,
    num_inputs: Int = 5,
    input_spacing: Um = 40,
    output_spacing: Um = 30,
    straight: ComponentSpec = "straight",
    transition: OrientationTransition = ("e", "w"),
):
    """Combine two frameless ``fanout_frame`` rows, the second rotated 90 degrees.

    The first row uses ``input_spacing`` and orientation ``transition[0]``
    and stays in place. The second row uses ``output_spacing`` and
    ``transition[1]``, is rotated by 90 degrees, and is then shifted by
    ``(h, h / 2)`` where ``h = num_inputs * input_spacing``.

    NOTE(review): unlike its siblings this function is not wrapped in
    ``@gf.cell``, so each call builds a fresh, uncached component. Confirm
    whether that is intended.

    Returns:
        A component with ports ``in0..`` (first row) followed by ``out0..``
        (second row).
    """
    assembly = gf.Component()
    first = assembly << fanout_frame(
        width, num_inputs, input_spacing, straight, transition[0], False
    )
    span = num_inputs * input_spacing
    second = assembly << fanout_frame(
        width, num_inputs, output_spacing, straight, transition[1], False
    )
    second.drotate(90.0)
    second.dmove((span, span / 2))
    for prefix, row in (("in", first), ("out", second)):
        for k, port in enumerate(row.ports):
            assembly.add_port(f"{prefix}{k}", port=port)
    return assembly


@gf.cell
def routing_frame(
    num_links: Int = 5,
    input_spacing: Um = 60,
    output_spacing: Um = 40,
    straight: ComponentSpec = "straight",
    bend_radius: Um = 5,
) -> Component:
    """Place input and output stubs on opposite sides of a fixed-width frame.

    With ``height = num_links * input_spacing`` and
    ``d(i, p) = ((i + 1) // 2) * p * (2 * (i % 2) - 1)``:

    * input ``i`` (instance ``in{i}``) is moved by
      ``(0, height / 2 - d(i, input_spacing))``;
    * output ``i`` (instance ``out{i}``) is moved by
      ``(0, 6 * bend_radius + height / 2 - d(i, output_spacing))`` and then by
      ``(width, 0)``, where ``width`` is currently fixed at 200.

    The stubs are length-1 straights and are not re-oriented, so they keep
    the straight's native placement. No routes between them are drawn here.

    Returns:
        A component with ports added in the order ``in0, out0, in1, out1, ...``.
        ``in{i}`` copies the input stub's ``"o2"`` port and ``out{i}`` copies
        the output stub's ``"o1"`` port. This assumes ``o1``/``o2`` port
        names, as gdsfactory's ``"straight"`` uses.
    """

    def _offset(i: int, pitch: float) -> float:
        # 0, +p, -p, +2p, -2p, ...: alternate outward from the centre.
        return (i + 1) // 2 * pitch * (2 * (i % 2) - 1)

    c = gf.Component()
    height = num_links * input_spacing
    cstraight = gf.get_component(straight, length=1)
    refs_in = [c.add_ref(cstraight, name=f"in{i}") for i in range(num_links)]
    for i, ref in enumerate(refs_in):
        ref.dmove((0, height / 2 - _offset(i, input_spacing)))
    refs_out = [c.add_ref(cstraight, name=f"out{i}") for i in range(num_links)]
    for i, ref in enumerate(refs_out):
        ref.dmove((0, 6 * bend_radius + height / 2 - _offset(i, output_spacing)))

    # TODO: derive the frame width from the link geometry instead of a fixed
    # 200. An earlier draft sorted the input "o2" and output "o1" port
    # centres, counted how many sorted inputs lie above their sorted output
    # ("above") versus below, and used
    # int(2 * bend_radius + max(above, below) * output_spacing).
    width = 200
    c.add_ref(frame(width=width, height=height))
    for ref in refs_out:
        ref.dmove((width, 0))

    for ref_in, ref_out, i in zip(refs_in, refs_out, range(num_links)):
        # Pass ports by keyword, as elsewhere in this module. The second
        # positional parameter of ``Component.add_port`` is ``center``.
        c.add_port(f"in{i}", port=ref_in.ports["o2"])
        c.add_port(f"out{i}", port=ref_out.ports["o1"])
    return c
