from cmath import nan
from copy import deepcopy
from logging import error

from qosm import (Vec3, Quaternion, Surface, Item, GaussianBeam, gbtc_beam_tracing, gbtc_compute_coupling,
                  gbtc_apply_abcd)
from numpy import sqrt, isnan, array, eye, angle, unwrap

from qosm.propagation.TRL import TRLCalibration
from qosm.propagation.PW import simulate as simulate_pw


def simulate(params, frequency_GHz: float, ax=None) -> 'tuple[complex, complex, complex, complex]':
    """Simulate the calibrated S-parameters of a quasi-optical bench with a multilayer sample.

    The bench is built from ``params``. It has:

    * one TX port at the origin, pointing along +z;
    * one or more RX ports, each tagged 'S11' or 'S21';
    * one lens per port;
    * the multilayer sample ('mut'), centred on the TX axis at
      ``distance_port_lens + distance_lens_sample`` of the TX port.

    The TX Gaussian beam is traced through the scene with ``gbtc_beam_tracing`` and coupled into
    each RX port with ``gbtc_compute_coupling``. The raw couplings are then calibrated in one of
    two ways:

    * with a TRL procedure when ``params['calibration'] == 'trl'`` and all standards are available;
    * otherwise by normalising with the THRU (no sample) and REFLECT (mirror) standards.

    Note: the port dicts of ``params['ports']`` are updated in place (the 'lens', 'port_type' and
    's_parameter' keys may be filled in).

    Args:
        params: bench configuration dict. Keys read: 'ports' and 'mut'; optionally 'calibration',
            'lens', 'num_reflections', 'sample_attitude' (or legacy 'angle_deg') and 'trl'.
        frequency_GHz: simulation frequency in GHz.
        ax: optional matplotlib-like axes. When given, the geometry is drawn (z horizontal,
            x vertical) and debug information is printed.

    Returns:
        ``(s11, s12, s21, s22)``. In THRU mode ``s12 == s21`` and ``s22 == s11``, and a value is
        NaN when its calibration standard is unavailable or zero. In TRL mode, the result of
        ``TRLCalibration.apply_correction``.

    Raises:
        ValueError: if no port has type 'TX' or 'B'.
        RuntimeError: if no beam leaves the TX lens exit surface.
        SystemExit: (code 42) if a sample layer has a zero thickness.
    """
    tx_port = None
    rx_port_list = []

    calibration_mode = params.get('calibration', 'thru').lower()

    # lists of items
    sim_items = []  # Sim with sample
    thru_items = []  # THRU cal step simulation
    line_items = []  # LINE cal step simulation
    refl_items = []  # REFL cal step simulation

    # A top-level 'lens' entry overrides the lens of every port
    common_lens = params.get('lens', None)

    # Port roles used when 'port_type' is missing, keyed by the number of ports
    # ('B' = one port acting both as TX and RX)
    legacy_port_type = {2: ('B', 'RX'), 3: ('TX', 'RX', 'RX')}
    for i, port in enumerate(params['ports']):
        if common_lens is not None:
            port['lens'] = common_lens
        port_type = port.get('port_type', None)
        if port_type is None:
            port_type = legacy_port_type[len(params['ports'])][i]
            port['port_type'] = port_type
        if port_type == 'TX':
            tx_port = port
        elif port_type == 'B':
            tx_port = port
            rx_port_list.append(port)
        elif port_type == 'RX':
            rx_port_list.append(port)

    if tx_port is None:
        raise ValueError("No TX port defined: one port must have port_type 'TX' or 'B'")

    num_refl = params.get('num_reflections', 5)

    # TX always +z (and at origin)
    r_tx = Vec3(0, 0, 0)

    if ax:
        ax.plot(r_tx[2], r_tx[0], '+', color='green', label='TX')

    # Sample centre on the TX axis, at the TX port->lens + lens->sample distance
    sample_position = Vec3(0, 0, tx_port['distance_port_lens'] + tx_port['distance_lens_sample'])
    if 'sample_attitude' in params:
        sample_attitude_deg = Vec3(params['sample_attitude'])
    else:
        # legacy configuration compatibility (assume y-axis rotation)
        sample_attitude_deg = Vec3(0, params['angle_deg'], 0)
    sample_pose = (sample_position, sample_attitude_deg)

    # Add TX lens (surface ids 0 and 1)
    cur_id_surf = 0
    tx_lens = create_lens(tx_port, sample_pose, force_tx=True, id_surf=cur_id_surf)
    cur_id_surf += 2

    # Display TX lens
    if ax:
        for surf in tx_lens.surfaces:
            ax.plot(surf.centre[2], surf.centre[0], 'x', color='yellow')

    # Add RX lenses (two surface ids per lens)
    attitude_rx_port = None
    # S-parameter used when 's_parameter' is missing, indexed by the rank of the RX port
    legacy_port_s_param = ('S11', 'S21')
    for i, port in enumerate(rx_port_list):
        rx_lens = create_lens(port, sample_pose, id_surf=cur_id_surf)
        cur_id_surf += 2

        s_parameter = port.get('s_parameter', None)
        if s_parameter is None:
            s_parameter = legacy_port_s_param[i]
            port['s_parameter'] = s_parameter
        if s_parameter == 'S21':
            sim_items.append(rx_lens)
            thru_items.append(rx_lens)
            if calibration_mode == 'trl':
                # LINE standard: same RX lens shifted by -line_offset along the port direction
                rx_lens_line = create_lens(port, sample_pose, line_offset=params['trl']['line_offset'])
                line_items.append(rx_lens_line)
            if ax:
                _u_rx, _r_rx, _ = get_port_pointing_direction(port, sample_pose)
                dz = _u_rx[2] * .02
                dx = _u_rx[0] * .02
                ax.plot(_r_rx[2], _r_rx[0], 'x', label='RX S21', color='orange')
                ax.quiver(_r_rx[2], _r_rx[0], dz, dx, width=0.004, color='orange')
        elif s_parameter == 'S11':
            sim_items.append(rx_lens)
            refl_items.append(rx_lens)
            if ax:
                _u_rx, _r_rx, _ = get_port_pointing_direction(port, sample_pose)
                dz = _u_rx[2] * .02
                dx = _u_rx[0] * .02
                ax.plot(_r_rx[2], _r_rx[0], 'o', label='RX S11', color='pink')
                ax.quiver(_r_rx[2], _r_rx[0], dz, dx, width=0.004, color='pink')

    def get_beam_tx(id_surf):
        """Trace the TX beam through the TX lens and return the beam leaving its exit surface.

        The returned beam gets ``id_surface = id_surf``.
        """
        beam_tx0 = GaussianBeam(frequency_GHz=frequency_GHz, w0=tx_port['w0'], z0=0, ori=r_tx, dir=Vec3(0, 0, 1))
        beams_tx = gbtc_beam_tracing(beam=beam_tx0, items=(tx_lens,), num_reflections=0)

        beam_tx = None
        for beam in beams_tx:
            # keep the beam emerging from the TX lens exit surface (id 1)
            if beam.id_surface == 1:
                beam_tx = beam
                # small +z nudge of the origin, presumably to avoid re-hitting the lens surface
                beam_tx.ori += Vec3(0, 0, 0.001)

        if beam_tx is None:
            raise RuntimeError('TX beam not found after the TX lens (no beam on surface 1)')

        beam_tx.id_surface = id_surf
        if ax:
            print('Get tx beam: ', beam_tx)
        return beam_tx

    # Multilayer sample construction
    if len(params['mut']) > 0:
        # z position of the current interface in the sample frame (first layer may be offset)
        slab_thickness = params['mut'][0].get('offset', 0)
        ior_list = []

        # Calculate refractive indices for each layer
        for n, slab in enumerate(params['mut']):
            if slab.get('thickness', 0) == 0:
                error(f'\n\nMedium of index {n}: Thickness cannot be 0 ! Please comment or remove this sample.\n')
                import sys
                sys.exit(42)

            eps_slab = slab['epsilon_r'] + 0j
            ior_slab = sqrt(eps_slab)
            ior_list.append(ior_slab)

        # First interface (air -> first layer)
        surfaces = [Surface(id=cur_id_surf, centre=Vec3(0, 0, slab_thickness), normal=Vec3(0, 0, -1), ior1=1.0 - 0j,
                            ior2=ior_list[0], curvature=nan, max_radius=1, allow_reflection=True, allow_refraction=True)]
        cur_id_surf += 1

        # Create interfaces between layers and from last layer to air
        for n, slab in enumerate(params['mut']):
            cur_ior = ior_list[n]  # Current layer index

            slab_thickness += slab['thickness']  # Update total thickness
            next_ior = ior_list[n + 1] if n < len(ior_list) - 1 else 1.0 - 0j  # Next index or air

            # Exit interface of current layer
            last_surf = Surface(id=cur_id_surf, centre=Vec3(0, 0, slab_thickness), normal=Vec3(0, 0, -1), ior1=cur_ior,
                                ior2=next_ior, curvature=nan, max_radius=1, allow_reflection=True,
                                allow_refraction=True)
            cur_id_surf += 1
            surfaces.append(last_surf)

        # Assemble multilayer sample with angular orientation
        slabs = Item(pos=sample_position, attitude_deg=sample_attitude_deg, surfaces=surfaces)

        sim_items.append(slabs)

    if ax:
        # Debug display: print and plot every surface of each scene; only the first point is labelled
        for title, items, marker, color, label in (('SIM', sim_items, 'x', 'k', 'sim'),
                                                   ('THRU', thru_items, '.', 'lime', 'thru'),
                                                   ('REFL', refl_items, '.', 'red', 'refl')):
            print(f'=====\n{title}\n=====')
            first = True
            for item in items:
                for surf in item.surfaces:
                    print(surf)
                    if first:
                        first = False
                        ax.plot(surf.centre[2], surf.centre[0], marker, color=color, label=label)
                    else:
                        ax.plot(surf.centre[2], surf.centre[0], marker, color=color)

    beams_slab = gbtc_beam_tracing(beam=get_beam_tx(2), items=sim_items, num_reflections=num_refl)

    # TRL Calibration
    # --------------------------------
    # Assume:
    # --------------------------------
    #   - S11 THRU = 0
    #   - S21 REFLECT = 0
    #   - S11 LINE = 0
    #   - S12 = S21 for all standards
    #   - S11 = S22 for all standards
    # --------------------------------
    params_trl = params.get('trl', None)

    s11_sys = s21_sys = s21_line = s11_refl = s21_thru = nan
    for port in rx_port_list:
        u_rx, pos_rx, attitude_rx_port = get_port_pointing_direction(port, sample_pose)
        if port['s_parameter'] == 'S21':
            beams_thru = gbtc_beam_tracing(beam=get_beam_tx(2), items=thru_items, num_reflections=0)
            if calibration_mode == 'trl':
                beams_line = gbtc_beam_tracing(beam=get_beam_tx(2), items=line_items, num_reflections=1)
                pos_rx_line = pos_rx - u_rx * params['trl']['line_offset']
                s21_line, _, _, _ = gbtc_compute_coupling(beams=beams_line, w0_tx=tx_port['w0'], w0_rx=port['w0'],
                                                          r_rx=pos_rx_line, u_rx=u_rx, max_radius=.1, id_beam=0)

            s21_sys, _, _, _ = gbtc_compute_coupling(beams=beams_slab, w0_tx=tx_port['w0'], w0_rx=port['w0'],
                                                     r_rx=pos_rx, u_rx=u_rx, max_radius=.1, id_beam=0)
            s21_thru, _, _, _ = gbtc_compute_coupling(beams=beams_thru, w0_tx=tx_port['w0'], w0_rx=port['w0'],
                                                      r_rx=pos_rx, u_rx=u_rx, max_radius=.1, id_beam=0)
        elif port['s_parameter'] == 'S11':
            # Mirror for reflection measurements (calibration), tilted by half the RX port attitude.
            # ior2=1e30 with refraction disabled presumably models a perfect reflector.
            mirror = Item(pos=sample_position, attitude_deg=attitude_rx_port * .5, surfaces=[
                Surface(id=4242, centre=Vec3(), normal=Vec3(0, 0, -1), ior1=1.0 - 0.j, ior2=1e30 - 0j, curvature=nan,
                        max_radius=.1, allow_reflection=True, allow_refraction=False),
            ])
            refl_items.append(mirror)

            beams_refl = gbtc_beam_tracing(beam=get_beam_tx(2), items=refl_items, num_reflections=1)
            s11_sys, _, _, _ = gbtc_compute_coupling(beams=beams_slab, w0_tx=tx_port['w0'], w0_rx=port['w0'],
                                                     r_rx=pos_rx, u_rx=u_rx, max_radius=.1, id_beam=0)
            s11_refl, _, _, _ = gbtc_compute_coupling(beams=beams_refl, w0_tx=tx_port['w0'], w0_rx=port['w0'],
                                                      r_rx=pos_rx, u_rx=u_rx, max_radius=.1, id_beam=0)

    # TRL needs all three standards; otherwise fall back to THRU/REFLECT normalisation
    if calibration_mode == 'trl' and (isnan(s21_line) or isnan(s21_thru) or isnan(s11_refl)):
        calibration_mode = 'thru'
        print('info: switched to THRU calibration mode')

    if calibration_mode == 'trl':
        trl = TRLCalibration(params_trl)
        trl.extract_error_terms(s21_thru, s11_refl, s21_line)
        return trl.apply_correction(s11_sys, s21_sys, s21_sys, s11_sys)
    else:
        # The minus sign presumably compensates the -1 reflection coefficient of the ideal mirror
        s11 = - s11_sys / s11_refl if not isnan(s11_refl) and abs(s11_refl) != 0 else nan
        s21 = s21_sys / s21_thru if not isnan(s21_thru) and abs(s21_thru) != 0 else nan
        s22 = s11
        s12 = s21
        return s11, s12, s21, s22

def get_port_pointing_direction(port_data: dict, sample_pose: tuple, force_tx: bool = False) -> tuple:
    """Return the pointing direction, position and attitude of a port.

    The TX port (``port_type == 'TX'`` or ``force_tx``) is returned as
    ``(Vec3(0, 0, 1), Vec3(), Vec3())``.

    Other ports are placed in one of two ways:

    * explicit placement (no ``'distance_lens_sample'`` entry, or None): the
      port sits at ``port_data['position']`` with attitude
      ``port_data['rotation']`` and points towards the sample centre;
    * specular placement: the port sits at
      ``distance_lens_sample + distance_port_lens`` behind the sample centre
      along -z. That offset and the +z direction are rotated about the sample
      attitude axis by twice the sample attitude angle (degrees).

    Args:
        port_data: port description dict.
        sample_pose: ``(position, attitude_deg)`` of the sample, both Vec3.
        force_tx: treat the port as the TX port regardless of its type.

    Returns:
        ``(u_port, pos_port, att_port)``: pointing direction, position and
        attitude (rotation vector in degrees) of the port.
    """
    if port_data['port_type'] == 'TX' or force_tx:
        return Vec3(0, 0, 1), Vec3(), Vec3()

    pos_sample, att_sample = sample_pose
    if port_data.get('distance_lens_sample', None) is None:
        pos_port = Vec3(port_data['position'])
        att_port = Vec3(port_data['rotation'])
        u_port = (pos_sample - pos_port).normalised()
    else:
        dist_sample_lens = port_data['distance_lens_sample']
        distance_port_lens = port_data['distance_port_lens']
        pos_port_norot = Vec3(0, 0, -dist_sample_lens - distance_port_lens)

        # named rot_angle_deg to avoid shadowing numpy's `angle` imported at module level
        rot_angle_deg = 2. * att_sample.norm()
        if rot_angle_deg == 0:
            # Untilted sample: identity rotation (avoids normalising a zero attitude vector)
            pos_port = pos_port_norot + pos_sample
            att_port = Vec3()
            u_port = Vec3(0, 0, 1)
        else:
            axis = att_sample.normalised()
            q_rot = Quaternion(angle=rot_angle_deg, axis=axis, deg=True)
            pos_port = q_rot.rotate(pos_port_norot) + pos_sample
            att_port = axis * rot_angle_deg
            u_port = q_rot.rotate(Vec3(0, 0, 1))

    return u_port, pos_port, att_port

def create_lens(port_data: dict, sample_pose: tuple, force_tx: bool = False, id_surf: int = -1,
                line_offset: float = 0) -> Item:
    """Build the two-surface lens Item placed in front of a port.

    Lens parameters are read from ``port_data['lens']``:

    * 'ior', 'focal', 'thickness', 'radius';
    * optional 'R1' / 'R2' curvature radii. A radius of 0, or a missing entry,
      means a flat surface (NaN curvature).

    The item is positioned with the port pose from
    ``get_port_pointing_direction``, then shifted by ``-line_offset`` along
    the port pointing direction (used for the TRL LINE standard).

    Args:
        port_data: port description dict (needs 'lens' and 'distance_port_lens').
        sample_pose: ``(position, attitude_deg)`` of the sample.
        force_tx: place the lens as for the TX port.
        id_surf: id of the first surface; the second one gets ``id_surf + 1``.
        line_offset: shift of the lens along the port pointing direction.

    Returns:
        The lens ``Item`` made of the two refracting surfaces.
    """
    lens_data = port_data['lens']
    n_ = lens_data['ior']
    f = lens_data['focal']
    h = lens_data['thickness']
    r1 = lens_data.get('R1', 0)
    r2 = lens_data.get('R2', 0)
    R1 = nan if r1 == 0 else r1
    R2 = nan if r2 == 0 else r2
    radius = lens_data['radius']

    n = n_.real
    # Principal-plane offset term. It tends to 0 for a flat second surface (R2 -> infinity);
    # evaluating it with R2 = NaN would put the second surface at a NaN position.
    h2 = 0. if isnan(R2) else - f * (n - 1) * h / (n * R2)

    u_port, pos_port, att_port = get_port_pointing_direction(port_data, sample_pose, force_tx)
    pos_int1 = Vec3(0, 0, port_data['distance_port_lens'])
    pos_int2 = Vec3(0, 0, port_data['distance_port_lens'] + (h - h2))

    s1 = Surface(id=id_surf, centre=pos_int1, normal=Vec3(0, 0, -1), ior1=1.0 - 0.j, ior2=n_, curvature=R1,
                 max_radius=radius, allow_reflection=False, allow_refraction=True)
    s2 = Surface(id=id_surf+1, centre=pos_int2, normal=Vec3(0, 0, 1), ior1=1.0 - 0.j, ior2=n_, curvature=R2,
                 max_radius=radius, allow_reflection=False, allow_refraction=True)

    return Item(pos_port - u_port*line_offset, att_port, (s1, s2))


if __name__ == "__main__":
    # Demo: draw the bench geometry, then compare GBTC and plane-wave (PW) S-parameters over a band.
    # The builtin abs() is used below: it maps to numpy's absolute value on arrays.
    import matplotlib
    from matplotlib import pyplot as plt
    from numpy import linspace, zeros_like, log10

    def setup_matplotlib():
        """Select the first usable matplotlib backend and return its name ('Agg' as last resort)."""
        # Try the backends in order of preference
        backends = ['Qt5Agg', 'TkAgg', 'GTK3Agg', 'Agg']

        for backend_name in backends:
            try:
                matplotlib.use(backend_name)
                test_fig = plt.figure()
                plt.close(test_fig)
                return backend_name
            except Exception:
                continue
        matplotlib.use('Agg')
        return 'Agg'

    backend = setup_matplotlib()

    dist_lens_sample = 0.2
    dist_port_lens = 0.095

    config = {
        'frequency_GHz': 275.0,
        'ports': [
            {   'port_type': 'TX',
                'w0': 0.0023, 'z0': 0.0,
                'distance_port_lens': dist_port_lens,
                'distance_lens_sample': dist_lens_sample,
            },
            {   'port_type': 'RX',
                'w0': 0.0023, 'z0': 0.0,
                'distance_port_lens': dist_port_lens,
                'distance_lens_sample': dist_lens_sample,
            },
            {   'port_type': 'RX',
                'w0': 0.0023, 'z0': 0.0,
                'distance_port_lens': dist_port_lens,
                'position': (0, 0, 0.6),
                'rotation': (0, 180, 0),
            }
        ],
        'angle_deg': 1.0,
        'lens': {'focal': 0.1, 'R1': 0.0, 'R2': -0.04, 'radius': 0.05, 'thickness': 0.013,
                 'ior': 1.4 },
        'mut': [
            {'epsilon_r': 2.59 - 0.0256j, 'thickness': 0.003},
            {'epsilon_r': 7.5 - 10.4j, 'thickness': 0.0006},
            {'epsilon_r': 2.59 - 0.0256j, 'thickness': 0.003}
        ],
        'calibration': 'trl',
        'num_reflections': 4,
        'trl': {
            'line_offset': 0.00025,
            'type_reflector': 'CC'
        }
    }

    # Geometry view at the nominal frequency
    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(111)

    simulate(config, frequency_GHz=config['frequency_GHz'], ax=ax)

    ax.legend()
    ax.set_aspect('equal')
    ax.set_ylim(-.21, .21)

    # Frequency sweep; both simulators return (s11, s12, s21, s22)
    freq_array_GHz = linspace(220, 330, 501)
    s11 = zeros_like(freq_array_GHz, dtype=complex)
    s21 = zeros_like(freq_array_GHz, dtype=complex)
    s11_pw = zeros_like(freq_array_GHz, dtype=complex)
    s21_pw = zeros_like(freq_array_GHz, dtype=complex)
    for i, freq_GHz in enumerate(freq_array_GHz):
        s11[i], _, s21[i], _ = simulate(config, frequency_GHz=freq_GHz)
        s11_pw[i], _, s21_pw[i], _ = simulate_pw(config, frequency_GHz=freq_GHz)

    _, axes = plt.subplots(2, 1)
    axes[0].plot(freq_array_GHz, 20 * log10(abs(s11)), '--', color='blue')
    axes[0].plot(freq_array_GHz, 20 * log10(abs(s11_pw)), '--', color='red')
    axes[0].plot(freq_array_GHz, 20 * log10(abs(s21)), color='blue', label='GBTC')
    axes[0].plot(freq_array_GHz, 20 * log10(abs(s21_pw)), color='red', label='PW')
    axes[0].legend()

    axes[1].plot(freq_array_GHz, unwrap(angle(s11)), '--', color='blue')
    axes[1].plot(freq_array_GHz, unwrap(angle(s11_pw)), '--', color='red')
    axes[1].plot(freq_array_GHz, unwrap(angle(s21)), color='blue', label='GBTC')
    axes[1].plot(freq_array_GHz, unwrap(angle(s21_pw)), color='red', label='PW')
    axes[1].legend()

    plt.show()
