# SPDX-FileCopyrightText: 2025 Qoro Quantum Ltd <divi@qoroquantum.de>
#
# SPDX-License-Identifier: Apache-2.0

import re
from functools import partial
from itertools import product
from warnings import warn

import cirq
import numpy as np
import pennylane as qml
from pennylane.tape import QuantumScript
from pennylane.wires import Wires
from sympy import Symbol

from divi.exp.cirq import cirq_circuit_from_qasm
from divi.qem import QEMProtocol

# Lookup table from PennyLane operation names (``op.name``) to the gate
# identifiers written into the generated OpenQASM. The serializer decomposes
# circuits until every operation name is a key of this table, and rejects any
# operation that is still missing. Note that both ``U1`` and ``PhaseShift``
# are emitted as ``u1``.
OPENQASM_GATES = dict(
    [
        # Two-qubit entangling gates
        ("CNOT", "cx"),
        ("CZ", "cz"),
        # Generic single-qubit unitaries
        ("U3", "u3"),
        ("U2", "u2"),
        ("U1", "u1"),
        # Fixed single-qubit gates
        ("Identity", "id"),
        ("PauliX", "x"),
        ("PauliY", "y"),
        ("PauliZ", "z"),
        ("Hadamard", "h"),
        ("S", "s"),
        ("Adjoint(S)", "sdg"),
        ("T", "t"),
        ("Adjoint(T)", "tdg"),
        # Parametrized rotations
        ("RX", "rx"),
        ("RY", "ry"),
        ("RZ", "rz"),
        ("CRX", "crx"),
        ("CRY", "cry"),
        ("CRZ", "crz"),
        # Permutation / multi-controlled gates
        ("SWAP", "swap"),
        ("Toffoli", "ccx"),
        ("CSWAP", "cswap"),
        # Phase gate, serialized with the same identifier as U1
        ("PhaseShift", "u1"),
    ]
)


def _ops_to_qasm(operations, precision, wires):
    # create the QASM code representing the operations
    qasm_str = ""

    for op in operations:
        try:
            gate = OPENQASM_GATES[op.name]
        except KeyError as e:
            raise ValueError(
                f"Operation {op.name} not supported by the QASM serializer"
            ) from e

        wire_labels = ",".join([f"q[{wires.index(w)}]" for w in op.wires.tolist()])
        params = ""

        if op.num_params > 0:
            # If the operation takes parameters, construct a string
            # with parameter values.
            if precision is not None:
                params = (
                    "(" + ",".join([f"{p:.{precision}}" for p in op.parameters]) + ")"
                )
            else:
                # use default precision
                params = "(" + ",".join([str(p) for p in op.parameters]) + ")"

        qasm_str += f"{gate}{params} {wire_labels};\n"

    return qasm_str


def _cirq_circuit_to_qasm2(circ) -> str:
    """Export a Cirq circuit as QASM 2.0 text ready to be concatenated with the
    measurement snippets built by ``to_openqasm``."""
    # Convert back to QASM 2.0 code, keeping the symbolic parameters.
    qasm_str = cirq.qasm(circ)
    # Remove redundant newlines.
    qasm_str = re.sub(r"\n+", "\n", qasm_str)
    # Remove `//` comment lines.
    qasm_str = re.sub(r"^//.*\n?", "", qasm_str, flags=re.MULTILINE)
    # Declare the classical register `c` (which the measurement snippets write
    # to) right after the quantum register declaration.
    return re.sub(r"qreg q\[(\d+)\];", r"qreg q[\1];creg c[\1];", qasm_str)


def to_openqasm(
    main_qscript,
    measurement_groups: list[list[qml.measurements.ExpectationMP]],
    measure_all: bool = True,
    precision: int | None = None,
    return_measurements_separately: bool = False,
    symbols: list[Symbol] | None = None,
    qem_protocol: QEMProtocol | None = None,
) -> list[str] | tuple[list[str], list[str]] | str:
    """
    Serialize a circuit as OpenQASM 2.0 programs, one per measurement group.

    A modified version of PennyLane's ``to_openqasm`` that supports several
    groups of commuting measurements and incorporates the modifications
    introduced by splitting transforms. It can also apply an error-mitigation
    protocol (e.g. folding) to the circuit body.

    By default every qubit is measured. Set ``measure_all=False`` to measure
    only the wires used by ``main_qscript.measurements``.

    .. note::

        Without QEM, the serialized programs assume the gate definitions in
        ``qelib1.inc`` are available. With ``qem_protocol``, the circuit body
        goes through several steps:

        - It is first written as OpenQASM 3.0 (``stdgates.inc`` plus
          ``input angle[32]`` declarations for ``symbols``).
        - It is parsed into Cirq and transformed by the protocol.
        - It is exported back to QASM 2.0 with ``cirq.qasm``.

    Args:
        main_qscript (QuantumScript): the circuit to convert. Operations whose
            parameters are Sympy objects are modified *in place*: their
            ``data`` is wrapped in a numpy array.
        measurement_groups (list[list]): groups of commuting observables, e.g.
            produced by PennyLane's grouping transforms. Each group contributes
            its diagonalizing gates followed by the measurement instructions.
        measure_all (bool): measure every qubit (default), or only the wires of
            ``main_qscript.measurements``.
        precision (int | None): if given, parameters are formatted with the
            format spec ``.{precision}`` (significant digits). Otherwise
            ``str`` is used.
        return_measurements_separately (bool): if True, the measurement code
            (diagonalizations + measure statements) is not appended to the
            circuit bodies. Bodies and measurement code are returned separately.
        symbols (list[Symbol] | None): Sympy symbols present in the circuit.
            Required when ``qem_protocol`` is given.
        qem_protocol (QEMProtocol | None): optional error-mitigation protocol.
            The circuits yielded by its ``modify_circuit`` replace the single
            circuit body.

    Returns:
        The return value depends on the inputs:

        - ``list[str]`` (default): one complete program per (circuit body,
          measurement group) pair, in body-major order.
        - ``tuple[list[str], list[str]]``: ``(circuit_bodies,
          measurement_snippets)`` when ``return_measurements_separately`` is
          True.
        - ``list[str]`` of circuit bodies only, with a warning, when
          ``measurement_groups`` is empty.
        - ``str``: only the header, when the circuit has no wires.

    Raises:
        ValueError: if ``qem_protocol`` is given without ``symbols``, or if an
            operation cannot be expressed with ``OPENQASM_GATES``.
    """

    if qem_protocol and symbols is None:
        raise ValueError(
            "When passing a QEMProtocol instance, the Sympy symbols in the circuit should be provided for the openqasm 3 conversion."
        )

    wires = main_qscript.wires
    n_wires = len(wires)

    _to_qasm = partial(_ops_to_qasm, precision=precision, wires=wires)

    # QASM headers: OpenQASM 3 is only used as an intermediate format for Cirq
    # parsing when a QEM protocol is applied (it can declare symbolic inputs).
    main_qasm_str = (
        'OPENQASM 3.0;\ninclude "stdgates.inc";\n'
        if qem_protocol
        else 'OPENQASM 2.0;\ninclude "qelib1.inc";\n'
    )

    if main_qscript.num_wires == 0:
        # Empty circuit. NOTE(review): this returns a bare header string rather
        # than the list/tuple documented for the other cases; kept for
        # backward compatibility.
        return main_qasm_str

    # Symbolic inputs followed by the quantum and classical registers.
    if qem_protocol:
        main_qasm_str += "".join(f"input angle[32] {str(symbol)};\n" for symbol in symbols)
        main_qasm_str += f"qubit[{n_wires}] q;\nbit[{n_wires}] c;\n"
    else:
        main_qasm_str += f"qreg q[{n_wires}];\ncreg c[{n_wires}];\n"

    # Wrap Sympy parameters in a numpy object to bypass PennyLane's
    # sanitization. NOTE: this mutates the caller's operations in place.
    for op in main_qscript.operations:
        if qml.math.get_interface(*op.data) == "sympy":
            op.data = np.array(op.data)

    [transformed_tape], _ = qml.transforms.convert_to_numpy_parameters(main_qscript)

    # Decompose until every operation is serializable.
    [decomposed_tape], _ = qml.transforms.decompose(
        QuantumScript(transformed_tape.operations),
        gate_set=lambda obj: obj.name in OPENQASM_GATES,
    )

    main_qasm_str += _to_qasm(decomposed_tape.operations)

    if qem_protocol:
        main_qasm_strs = [
            _cirq_circuit_to_qasm2(circ)
            for circ in qem_protocol.modify_circuit(
                cirq_circuit_from_qasm(main_qasm_str)
            )
        ]
    else:
        main_qasm_strs = [main_qasm_str]

    # The measured wires do not depend on the measurement group, so build the
    # measure statements once instead of once per group.
    if measure_all:
        measured_indices = range(n_wires)
    else:
        measured_indices = [
            wires.index(w)
            for w in Wires.all_wires([m.wires for m in main_qscript.measurements])
        ]
    measure_qasm_str = "".join(
        f"measure q[{idx}] -> c[{idx}];\n" for idx in measured_indices
    )

    # One measurement snippet per group: its diagonalizing gates (if any)
    # followed by the measure statements.
    measurement_qasms = []
    for meas_group in measurement_groups:
        diag_ops = QuantumScript(measurements=meas_group).diagonalizing_gates
        diag_qasm_str = _to_qasm(diag_ops) if diag_ops else ""
        measurement_qasms.append(diag_qasm_str + measure_qasm_str)

    if not measurement_groups:
        warn(
            "No measurement groups provided. Returning the QASM of the circuit operations only."
        )
        return list(main_qasm_strs)

    if return_measurements_separately:
        return list(main_qasm_strs), measurement_qasms

    # Concatenate each circuit body with each measurement snippet into a full
    # program. The original code returned (body, snippet) 2-tuples here, which
    # are not valid QASM programs.
    return [
        body + measurement
        for body, measurement in product(main_qasm_strs, measurement_qasms)
    ]
