# -*- coding: utf-8 -*-
"""
"""
from __future__ import (division, print_function, unicode_literals)
from ..utilities.future_from_2 import str, object

#from phasor.utilities.print import print
#import declarative
import collections
#import copy
import numpy as np

from . import ports
from ..utilities.print import pprint

from ..system.matrix_injections import (
    FactorCouplingBase,
)

def pk_prefs(*preflist):
    def key(pk):
        p, k = pk
        pk = p | k
        ksort = []
        for ktype in preflist:
            ksort.append(
                str(pk.get(ktype, None))
            )
        pk.purge_keys(*preflist)
        ksort.extend(sorted(pk.items()))
        return tuple(ksort)
    return key


def lt_mult(lt, lt2):
    def lt_mult_inner(lt, lt2):
        if isinstance(lt2, list):
            sublist = []
            for sublt2 in lt2:
                sublist = sublist + lt_mult_inner(lt, sublt2)
            return sublist
        elif isinstance(lt2, tuple):
            #multiply the gains and merge the indices
            return [(lt2[0] * lt[0],) + lt[1:] + lt2[1:]]
        else:
            raise RuntimeError("BOO")
    if isinstance(lt, list):
        sublist = []
        for sublt in lt:
            sublist = sublist + lt_mult(sublt, lt2)
        return sublist
    elif isinstance(lt, tuple):
        return lt_mult_inner(lt, lt2)
    else:
        raise RuntimeError("BOO")
    return


def lt_sort_collect(new_list):
    if not new_list:
        return new_list
    sNL = [(tuple(sorted(L[1:])), idx) for idx, L in enumerate(new_list)]
    sNL.sort()
    list_gen = []

    NL, idx = sNL[0]
    prev_NL = NL
    prev_gain = new_list[idx][0]
    for NL, idx in sNL[1:]:
        if prev_NL == NL:
            prev_gain += new_list[idx][0]
        else:
            list_gen.append(
                (prev_gain,) + tuple(prev_NL)
            )
            prev_NL = NL
            prev_gain = new_list[idx][0]
    list_gen.append(
        (prev_gain,) + tuple(prev_NL)
    )
    return list_gen


class ExpMatCoupling(FactorCouplingBase):
    """
    Generated by a dict-dict mapping ddlt[out][in] = list-tup-expr
    where list-tup-expr are a means of expressing addition and multiplication as lists and tups.

    lists represent series addition

    tups represent series multiplication, the first term is a raw number coefficient and any further are keys in the source
    vector
    """
    floating_req_set = None
    #holds a list of 3-tuples, each one carrying an in-set, out-set, and coupling-func
    #it behaves as though all in-set is connected to out-set for requirements analysis
    #but then during coupling matrix construction, only the edges returned from coupling-func
    #will be used
    floating_in_out_func_pairs = None

    #must redefine since it was a property
    edges_req_pkset_dict = None
    def __init__(
        self,
        dLt,
        in_map,
        out_map,
        N_ode = 1,
    ):
        self.N_ode     = N_ode
        self.dLt       = dLt
        self.out_map   = out_map
        self.in_map    = in_map
        self.solution  = dict()
        self.vals_prev = dict()
        self.vals_inj  = dict()

        #make the internal solution have no action initially (so that cavity feedback converges faster)
        for pk_internal, pk_in in self.in_map.items():
            pk_out = self.out_map.get(pk_internal, None)
            if pk_out is not None:
                self.solution[pk_in, pk_out] = 1

        #all edges are generated immediately. Currently assumes full density
        self.edges_NZ_pkset_dict = {}
        self.edges_pkpk_dict = {}
        self.edges_req_pkset_dict = {}

        def gen_edge_func(pk_in, pk_out):
            return lambda sV, sB: self.edge_func(pk_in, pk_out, sV, sB)

        self.floating_req_set = frozenset(self.in_map.values())
        ins = set(self.in_map.values())
        outs = set([pkt for pkt in self.out_map.values() if pkt is not None])
        self.floating_in_out_func_pairs = [(ins, outs, self.edge_mat_func)]

        def gen_src_func_out(pk_out):
            return lambda sV, sB: self.source_func_out(pk_out, sV, sB)
        self.sources_pk_dict = {}
        self.sources_NZ_pkset_dict = {}
        for pk_out in self.out_map.values():
            if pk_out is None:
                continue
            self.sources_pk_dict[pk_out] = gen_src_func_out(pk_out)
            self.sources_NZ_pkset_dict[pk_out] = frozenset()

        pks = set()
        def pks_grab(lt):
            if isinstance(lt, list):
                for sublt in lt:
                    pks_grab(sublt)
            elif isinstance(lt, tuple):
                for pk in lt[1:]:
                    pks.add(pk)
            else:
                raise RuntimeError("BOO")
            return
        for pk_out, lt in self.dLt.items():
            pks.add(pk_out)
            pks_grab(lt)

        #print(ins_p, sol_vector)
        self.pks = list(pks)
        #TODO: debug config
        print("Number of states: ", len(pks))
        #pprint(pks)
        self.pks.sort(key = pk_prefs(
            ports.QuantumKey,
            ports.ElementKey,
            ports.PortKey,
            ports.OpticalFreqKey,
            ports.ClassicalFreqKey,
            ports.PolKEY,
        ))
        self.pks_inv = dict()
        for idx, pk in enumerate(self.pks):
            self.pks_inv[pk] = idx

        h = 1 / self.N_ode
        #remap the index keys into integer indexes for speed
        dLt_accel = dict()
        def lt_remap(lt):
            if isinstance(lt, list):
                sublist = []
                for sublt in lt:
                    sublist.append(lt_remap(sublt))
                return sublist
            elif isinstance(lt, tuple):
                #get the gain
                newtup = [lt[0] * h]
                for pk in lt[1:]:
                    newtup.append(self.pks_inv[pk])
                return tuple(newtup)
            else:
                raise RuntimeError("BOO")
            return
        for pk_out, lt in self.dLt.items():
                dLt_accel[self.pks_inv[pk_out]] = lt_sort_collect(lt_remap(lt))
        self.dLt_accel = dLt_accel

        #pprint("PKS:")
        #pprint(pks)

    _prev_sol_vector = None

    def update_solution(self, sol_vector):
        if self._prev_sol_vector is None or sol_vector != self._prev_sol_vector:
            self._prev_sol_vector = sol_vector
            self.generate_solution(sol_vector)
        return

    def edge_mat_func(self, sol_vector, sB):
        self.update_solution(sol_vector)
        return self.solution

    def source_func_out(self, pk_out, sol_vector, sB):
        self.update_solution(sol_vector)
        return self.vals_inj.get(pk_out, 0)

    def generate_solution(self, sol_vector):
        pks = self.pks
        pkv = np.empty(len(pks), dtype=object)

        all_zeros = True
        for idx, pk in enumerate(pks):
            pk_in = self.in_map[pk]
            #print("PK_G: ", (ins_p, pk))
            sol_val = sol_vector.get(pk_in, 0)
            prev_val = self.vals_prev.get(pk_in, 0)
            tot_val = sol_val - prev_val
            new_val = np.copy(sol_val)
            if np.all(new_val == 0):
                new_val = 0
            pkv[idx] = new_val
            if np.any(abs(tot_val) > 1e-8 * abs(sol_val)):
                all_zeros = False
        if all_zeros:
            #TODO debug statements
            #print("ALL ZEROS")
            #no need to update the solution vector since nothing changed
            return
            pass
        else:
            #TODO debug statements
            #print("NOT ZEROS")
            pass
        #print("SOLVING ", self.pks)

        pk_original = pkv.copy()

        def lt_val(lt):
            assert(isinstance(lt, list))
            val = 0
            for sublt in lt:
                assert(isinstance(sublt, tuple))
                gain = np.copy(sublt[0])
                for pk_idx in sublt[1:]:
                    gain = gain * pkv[pk_idx]
                val += gain
            return val

        def lt_val_matrix_d(vec_d, lt_M_d):
            for pk_idx_from, lt in lt_M_d.items():
                assert(isinstance(lt, list))
                for sublt in lt:
                    assert(isinstance(sublt, tuple))
                    local_gain = np.copy(sublt[0])
                    for pk_idx in sublt[1:]:
                        val = pkv[pk_idx]
                        local_gain = local_gain * val
                    vec_d[pk_idx_from] += local_gain

        def lt_val_matrix_d_generate(pkv_nz_set, lt):
            lt_M_d = dict()
            full = 0
            reduced = 0
            assert(isinstance(lt, list))
            for sublt in lt:
                assert(isinstance(sublt, tuple))
                gain = sublt[0]
                for idx_idx, pk_idx_from in enumerate(sublt[1:]):
                    newtup = [gain]
                    full += 1
                    for idx_idx2, pk_idx in enumerate(sublt[1:]):
                        if idx_idx == idx_idx2:
                            continue
                        if pk_idx not in pkv_nz_set:
                            break
                        newtup.append(pk_idx)
                    else:
                        #only occurs if break was NOT called
                        reduced += 1
                        lt_inj = lt_M_d.setdefault(pk_idx_from, [])
                        lt_inj.append(
                            tuple(newtup)
                        )
            #print("REDUCED LT_d: ", reduced / full)
            return lt_M_d

        def lt_reduced_generate(pkv_nz_set, lt):
            newlt = []
            assert(isinstance(lt, list))
            for sublt in lt:
                assert(isinstance(sublt, tuple))
                newtup = [sublt[0]]
                for pk_idx in sublt[1:]:
                    if pk_idx not in pkv_nz_set:
                        break
                    newtup.append(pk_idx)
                else:
                    #only occurs if break was NOT called
                    newlt.append(
                        tuple(newtup)
                    )
            #print("REDUCED LT: ", len(newlt) / len(lt))
            return newlt

        #a double map matrix dMexp_s1[idx_out][idx_in]
        ##this dMexp_s1 does not have the one. For speed that is applied separately
        dMexp_s1 = collections.defaultdict(lambda : collections.defaultdict(lambda : 0))

        pkv_nz_set = set()
        for idx, val in enumerate(pkv):
            if np.any(val != 0):
                pkv_nz_set.add(idx)

        dLt_base = sorted(self.dLt_accel.items())
        dLt_idx_list = [T[0] for T in dLt_base]
        dLt_lt_list  = [lt_reduced_generate(pkv_nz_set, T[1]) for T in dLt_base]
        dLt_lt_M_list  = [lt_val_matrix_d_generate(pkv_nz_set, T[1]) for T in dLt_base]

        for idx_N in range(self.N_ode):
            idx_in_list = 0
            fullskip = 0
            while idx_in_list < len(dLt_idx_list):
                idx_pk = dLt_idx_list[idx_in_list]
                lt     = dLt_lt_list[idx_in_list]
                lt_M_d = dLt_lt_M_list[idx_in_list]
                idx_in_list += 1
                if not lt and not lt_M_d:
                    fullskip += 1

                dPK = lt_val(lt)
                Mexp_vec_d = collections.defaultdict(lambda : 0)
                lt_val_matrix_d(Mexp_vec_d, lt_M_d)

                if pkv[idx_pk] is 0:
                    if np.any(dPK != 0):
                        #print("STATUS CHANGE: ", idx_pk)
                        pkv_nz_set.add(idx_pk)
                        dLt_idx_list = [T[0] for T in dLt_base]
                        dLt_lt_list  = [lt_reduced_generate(pkv_nz_set, T[1]) for T in dLt_base]
                        dLt_lt_M_list  = [lt_val_matrix_d_generate(pkv_nz_set, T[1]) for T in dLt_base]
                if np.any(dPK != 0):
                    pkv[idx_pk] = pkv[idx_pk] + dPK

                dMexp_update = collections.defaultdict(lambda : 0)
                #print("LEN: ", len(dMexp_s1))
                #print("VEC: ", len(Mexp_vec_d))
                N_zero = 0
                N_skip = 0
                for idx_out, vec_val in Mexp_vec_d.items():
                    for idx_in, edge in dMexp_s1[idx_out].items():
                        if np.all(edge == 0):
                            N_zero += 1
                            continue
                        dMexp_update[idx_in] += (edge * vec_val)

                #print("NZERO: ", N_zero, N_zero / (.001+len(dMexp_s1)))
                #print("NSKIP: ", N_skip, N_skip / (.001+len(dMexp_s1)))

                for idx_in, vec_val in dMexp_update.items():
                    if np.any(vec_val != 0):
                        dMexp_s1[idx_pk][idx_in] += vec_val

                #applied the initial vector again since the dMexp_s1 didn't start with diagonal ones
                for idx_in, vec_val in Mexp_vec_d.items():
                    if np.any(vec_val != 0):
                        dMexp_s1[idx_pk][idx_in] += vec_val
            #print("FULLSKIP: ", fullskip, fullskip / len(dLt_idx_list))

        self.vals_prev = dict()
        #inject subtracted values at input to remove the DC values an only get the derivative
        for idx_in in range(len(pks)):
            pkin = self.in_map[pks[idx_in]]

            val = pk_original[idx_in]
            if np.any(val != 0):
                self.vals_prev[pkin] = val

        #print("START EDGES")
        #for idx_out in range(len(pks)):
        #    for idx_in in range(len(pks)):
        #        edge = dMexp_s1.get(idx_out, dict()).get(idx_in, None)
        #        if edge is not None:
        #            pkin = self.in_map[pks[idx_in]]
        #            pkout = self.out_map[pks[idx_out]]
        #            if idx_in == idx_out:
        #                edge = edge + 1

        #            print("IN: ", pkin[1])
        #            print("  OUT: ", pkout[1])
        #            print("EDGE: ", edge)
        #print("DONE EDGES")
        #dval_out holds the product of the input through the derivative matrix
        #this way it can cancel the forward propagation so that the output is correct assuming
        #the inputs do not change
        solution = dict()
        dval_out = collections.defaultdict(lambda : 0)

        #alter the derivative matrix to pass through the direct values
        for idx in range(len(pks)):
            in_map = dMexp_s1.get(idx, None)
            if in_map is None:
                dMexp_s1[idx] = {idx : 1}
                continue
            in_map[idx] = in_map.get(idx, 0) + 1

        for idx_out, in_map in dMexp_s1.items():
            for idx_in, edge in in_map.items():
                pkin = self.in_map[pks[idx_in]]
                pkout = self.out_map[pks[idx_out]]

                if pkin is not None and pkout is not None:
                    solution[pkin, pkout] = edge

                #also compute dval_out
                val_orig = pk_original[idx_in]
                prod = (edge * val_orig)
                if np.any(prod != 0):
                    dval_out[idx_out] += prod
        self.solution = solution

        vals_inj = dict()
        for idx_out in range(len(pks)):
            pkout = self.out_map[pks[idx_out]]
            if pkout is None:
                continue
            val = pkv[idx_out]
            altered_val = val - dval_out.get(idx_out, 0)
            #print("ALT: ", pkout, val, dval_out.get(idx, 0), altered_val)
            vals_inj[pkout] = altered_val
        self.vals_inj = vals_inj

        #print("DONE SOLVING ")
        return



#Old version
"""
class ExpMatCoupling(FactorCouplingBase):

    #must redefine since it was a property
    edges_req_pkset_dict = None
    def __init__(
        self,
        ddlt,
        in_map,
        out_map,
        N_ode = 1,
        order = 2,
        symplectify = True,
    ):
        self.N_ode   = N_ode
        self.order   = order
        self.ddlt    = ddlt
        self.out_map = out_map
        self.in_map  = in_map

        #all edges are generated immediately. Currently assumes full density
        self.edges_NZ_pkset_dict = {}
        self.edges_pkpk_dict = {}
        self.edges_req_pkset_dict = {}
        def gen_edge_func(pk_in, pk_out):
            return lambda sV, sB: self.edge_func(pk_in, pk_out, sV, sB)
        for pk_out in self.out_map.values():
            if pk_out is None:
                continue
            for pk_in in self.in_map.values():
                self.edges_NZ_pkset_dict[(pk_in, pk_out)] = frozenset()
                self.edges_pkpk_dict[(pk_in, pk_out)] = gen_edge_func(pk_in, pk_out)
                self.edges_req_pkset_dict[(pk_in, pk_out)] = frozenset(self.in_map.values())

        #Currently, nonlinear doesn't need to make any sources. It may in the future as that may be a more stable way to converge
        self.sources_pk_dict = {}
        self.sources_NZ_pkset_dict = {}

        pks = set()
        def pks_grab(lt):
            if isinstance(lt, list):
                for sublt in lt:
                    pks_grab(sublt)
            elif isinstance(lt, tuple):
                for pk in lt[1:]:
                    pks.add(pk)
            else:
                raise RuntimeError("BOO")
            return
        for pk_out, din in self.ddlt.items():
            pks.add(pk_out)
            for pk_in, lt in din.items():
                pks.add(pk_in)
                pks_grab(lt)

        #print(ins_p, sol_vector)
        self.pks = list(pks)
        print("OMG: ", len(pks))
        pprint(pks)
        self.pks.sort()
        self.pks_inv = dict()
        for idx, pk in enumerate(self.pks):
            self.pks_inv[pk] = idx

        h = 1 / self.N_ode
        #remap the index keys into integer indexes for speed
        ddlt_accel = dict()
        def ddlt_remap(lt):
            if isinstance(lt, list):
                sublist = []
                for sublt in lt:
                    sublist.append(ddlt_remap(sublt))
                return sublist
            elif isinstance(lt, tuple):
                #get the gain
                newtup = [lt[0] * h]
                for pk in lt[1:]:
                    newtup.append(self.pks_inv[pk])
                return tuple(newtup)
            else:
                raise RuntimeError("BOO")
            return
        for pk_out, din in self.ddlt.items():
            for pk_in, lt in din.items():
                ddlt_accel[self.pks_inv[pk_out], self.pks_inv[pk_in]] = ddlt_remap(lt)
        self.ddlt_accel = ddlt_accel

        def ddlt_mult(lt, lt2):
            def ddlt_mult2(lt, lt2):
                if isinstance(lt2, list):
                    sublist = []
                    for sublt2 in lt2:
                        sublist = sublist + ddlt_mult2(lt, sublt2)
                    return sublist
                elif isinstance(lt2, tuple):
                    #multiply the gains and merge the indices
                    return [(lt2[0] * lt[0],) + lt[1:] + lt2[1:]]
                else:
                    raise RuntimeError("BOO")
            if isinstance(lt, list):
                sublist = []
                for sublt in lt:
                    sublist = sublist + ddlt_mult(sublt, lt2)
                return sublist
            elif isinstance(lt, tuple):
                return ddlt_mult2(lt, lt2)
            else:
                raise RuntimeError("BOO")
            return
        def sort_collect(new_list):
            if not new_list:
                return new_list
            sNL = [(tuple(sorted(L[1:])), idx) for idx, L in enumerate(new_list)]
            sNL.sort()
            list_gen = []

            NL, idx = sNL[0]
            prev_NL = NL
            prev_gain = new_list[idx][0]
            for NL, idx in sNL[1:]:
                if prev_NL == NL:
                    prev_gain += new_list[idx][0]
                else:
                    list_gen.append(
                        (prev_gain,) + tuple(prev_NL)
                    )
                    prev_NL = NL
                    prev_gain = new_list[idx][0]
            list_gen.append(
                (prev_gain,) + tuple(prev_NL)
            )
            return list_gen


        self.ddlt_accel = ddlt_accel

        if symplectify:
            #print("IS SYMPLECTIC!")
            ddlt_accel_SE = dict()
            #put in the diagonals first
            for idx, pk in enumerate(self.pks):
                ddlt_accel_SE[idx, idx] = [(1,)]

            for pk_out, din in self.ddlt.items():
                pk_out_idx = self.pks_inv[pk_out]
                for pk_in, lt in din.items():
                    pk_in_idx = self.pks_inv[pk_in]
                    assert(pk_out_idx != pk_in_idx)
                    for col_idx, pk in enumerate(self.pks):
                        #needs to multiply everything
                        lt_rm = ddlt_remap(lt)
                        lt_keep = ddlt_accel_SE.get((pk_out_idx, col_idx), [])
                        lt_mult = ddlt_accel_SE.get((pk_in_idx, col_idx), [])
                        new_list = lt_keep + ddlt_mult(lt_mult, lt_rm)
                        #clear out now couplings for speed
                        new_list = [L for L in new_list if len(L[1:]) <= 3]
                        if not new_list or new_list == [(1,)]:
                            ddlt_accel_SE[pk_out_idx, col_idx] = new_list
                            continue

                        new_list = sort_collect(new_list)
                        #if new_list:
                            #print(pk_out_idx, col_idx)
                            #print([("{0:.2f}".format(np.log10(abs(L[0]) / h)), len(L[1:])) for L in new_list])
                        ddlt_accel_SE[pk_out_idx, col_idx] = new_list

            ddlt_accel_SE_use = dict()
            for (idx_out, idx_in), lt in ddlt_accel_SE.items():
                #TODO use a search to remove this since they have been sorted
                if idx_out == idx_in:
                    if lt and lt[0] == (1,):
                        lt = lt[1:]
                    else:
                        lt = lt + [(-1,)]
                if lt:
                    ddlt_accel_SE_use[idx_out, idx_in] = lt
            #pprint(ddlt_accel_SE_use)

            self.ddlt_accel = ddlt_accel_SE_use

            #add in linear term
            new_mat = copy.deepcopy(self.ddlt_accel)
            #add in diagonal (1) term
            for idx in range(len(self.pks)):
                lst = new_mat.setdefault((idx, idx), [])
                lst.append((1,))
            #add in second-order term
            for (idx_out, idx_in), lt1 in self.ddlt_accel.items():
                for idx_in2 in range(len(self.pks)):
                    lt2 = self.ddlt_accel.get((idx_in, idx_in2), None)
                    if lt2 is None:
                        continue
                    lt_mult = []
                    for l1 in lt1:
                        lt_mult.extend([(l1[0] * l2[0] / 2, ) + l1[1:] + l2[1:] for l2 in lt2])
                    #clear out now couplings for speed
                    lt_mult = sort_collect([L for L in lt_mult if abs(L[0]) >= 1e-8])
                    #if lt_mult:
                    #    print("1: ", lt1)
                    #    print("2: ", lt2)
                    #    print("LT_MULT: ", lt_mult)
                    nl = new_mat.setdefault((idx_out, idx_in2), [])
                    nl.extend(lt_mult)

            for (idx_out, idx_in) in list(new_mat.keys()):
                lt1 = new_mat[idx_out, idx_in]
                lt = sort_collect(lt1)
                nl = [L for L in nl if abs(L[0]) >= 1e-5]
                if lt:
                    lt1[:] = lt
                else:
                    del new_mat[idx_out, idx_in]
            self.ddlt_accel_premult2 = new_mat

        #pprint("PKS:")
        #pprint(pks)

    _prev_sol_vector = None

    def edge_func(self, pk_in, pk_out, sol_vector, sB):
        if sol_vector != self._prev_sol_vector:
            self._prev_sol_vector = sol_vector
            if self.order > 0:
                #self.generate_solution(sol_vector)
                self.generate_solution_premult(sol_vector)
            elif self.order == 0:
                self.generate_solution_RK(sol_vector)
        return self.solution.get((pk_in, pk_out), 0)

    def generate_solution_premult(self, sol_vector):
        pks = self.pks
        pkv = np.empty(len(pks), dtype=object)
        for idx, pk in enumerate(pks):
            #print("PK_G: ", (ins_p, pk))
            pkv[idx] = sol_vector.get(self.in_map[pk], 0)
        pkO = pkv.copy()
        #print("PKV: ", pkv)
        #try:
        #    import tabulate
        #    tabular_data = [[str(label)] + [pk] for label, pk in zip(pks, pkv)]
        #    print("PKs:")
        #    print(tabulate.tabulate(tabular_data))
        #except ImportError:
        #    print("XXXX")

        def lt_val(lt):
            if isinstance(lt, list):
                val = 0
                for sublt in lt:
                    val = lt_val(sublt) + val
            elif isinstance(lt, tuple):
                val = lt[0]
                #print("LT0: ", val)
                for pk_idx in lt[1:]:
                    val = val * pkv[pk_idx]  # sol_vector.get(pk, 0)
            else:
                raise RuntimeError("BOO")
            return val

        eye = np.eye(len(pks), dtype = object)
        Mexp_tot = eye
        for idx_N in range(self.N_ode):
            print('pe_A premult')
            Mexp = np.zeros([len(pks), len(pks)], dtype = object)
            for (idx_out, idx_in), lt in self.ddlt_accel_premult2.items():
                    val = lt_val(lt)
                    Mexp[idx_out, idx_in] = val
            #print("Mexp: ", Mexp)
            print('pe_B')
            Mexp_tot = np.dot(Mexp, Mexp_tot)
            print("Sparsity: ", idx_N, np.sum(Mexp_tot.flatten() != 0)/ len(Mexp_tot.flatten()))
            #print(Mexp.shape, pkv.shape)
            pkv = np.dot(Mexp, pkv.reshape(-1, 1)).reshape(-1)
            #try:
            #    import tabulate
            #    tabular_data = [[str(label)] + [str(pk), str(pkk)] for label, pk, pkk in zip(pks, pkv, pkX)]
            #    print("PKs2:")
            #    print(tabulate.tabulate(tabular_data))
            #except ImportError:
            #    print("XXXX")
            #print("pkv2:", type(pkv), pkv.shape)
            #print(pkv)

        #print(m1)
        #print(Mexp)
        try:
            import tabulate
            tabular_data = [[str(label)] + [str(pk)] for label, pk in zip(pks, pkv)]
            print("PKs2:")
            print(tabulate.tabulate(tabular_data))
        except ImportError:
            print("XXXX")
        try:
            import tabulate
            tabular_data = [[str(label)] + list(abs(x) for x in td) for idx, (label, td) in enumerate(zip(pks, Mexp_tot))]
            print("Mexp_tot", idx)
            print(Mexp.dtype)
            print(tabulate.tabulate(tabular_data))
        except ImportError:
            print("XXXX")

        N_sparsity = 0
        solution = dict()
        for idx_in in range(len(pks)):
            for idx_out in range(len(pks)):
                edge = Mexp_tot[idx_out, idx_in]
                if np.any(edge != 0):
                    N_sparsity += 1
                    #pk_in = pks[idx_in]
                    #pk_out = pks[idx_out]
                    ###TODO: add debug config reference for this print
                    #print(pk_in)
                    #print(pk_out)
                    #print(idx_in, idx_out, edge)
                    pkin = self.in_map[pks[idx_in]]
                    pkout = self.out_map[pks[idx_out]]
                    if pkout is None:
                        continue
                    solution[pkin, pkout] = edge
        #print("Sparsity: ", N_sparsity / len(Mexp_tot.flatten()))

        #pprint(pks)

        self.solution = solution

    def generate_solution(self, sol_vector):
        pks = self.pks
        pkv = np.empty(len(pks), dtype=object)
        for idx, pk in enumerate(pks):
            #print("PK_G: ", (ins_p, pk))
            pkv[idx] = sol_vector.get(self.in_map[pk], 0)
        pkO = pkv.copy()
        #print("PKV: ", pkv)
        #try:
        #    import tabulate
        #    tabular_data = [[str(label)] + [pk] for label, pk in zip(pks, pkv)]
        #    print("PKs:")
        #    print(tabulate.tabulate(tabular_data))
        #except ImportError:
        #    print("XXXX")

        def lt_val(lt):
            if isinstance(lt, list):
                val = 0
                for sublt in lt:
                    val = lt_val(sublt) + val
            elif isinstance(lt, tuple):
                val = lt[0]
                #print("LT0: ", val)
                for pk_idx in lt[1:]:
                    val = val * pkv[pk_idx]  # sol_vector.get(pk, 0)
            else:
                raise RuntimeError("BOO")
            return val

        eye = np.eye(len(pks), dtype = object)
        Mexp_tot = eye
        for idx_N in range(self.N_ode):
            print('pe_A')
            m1 = np.zeros([len(pks), len(pks)], dtype = object)
            for (idx_out, idx_in), lt in self.ddlt_accel.items():
                    val = lt_val(lt)
                    m1[idx_out, idx_in] = val
            #print("M1: ", m1)
            print('pe_B')
            Mexp = m1 + eye
            mmem = m1
            #try:
            #    import tabulate
            #    tabular_data = [[str(idx)] + list(td) for idx, (label, td) in enumerate(zip(pks, m1))]
            #    print("M1")
            #    print(m1.dtype)
            #    print(tabulate.tabulate(tabular_data))
            #except ImportError:
            #    print("XXXX")
            #try:
            #    import tabulate
            #    tabular_data = [[str(idx)] + list(str(x) for x in td) for idx, (label, td) in enumerate(zip(pks, Mexp))]
            #    print("Mexp", 1)
            #    print(Mexp.dtype)
            #    print(tabulate.tabulate(tabular_data))
            #except ImportError:
            #    print("XXXX")
            for idx in range(2, self.order+1):
                mmem = (1 / idx) * np.dot(m1, mmem)
                #try:
                #    import tabulate
                #    tabular_data = [[str(idx)] + list(td) for idx, (label, td) in enumerate(zip(pks, mmem))]
                #    print("mmem", idx)
                #    print(mmem.dtype)
                #    print(tabulate.tabulate(tabular_data))
                #except ImportError:
                #    print("XXXX")
                Mexp = Mexp + mmem
            print('pe_C')
            #try:
            #    import tabulate
            #    tabular_data = [[str(idx)] + list(str(x) for x in td) for idx, (label, td) in enumerate(zip(pks, Mexp))]
            #    print("Mexp", idx)
            #    print(Mexp.dtype)
            #    print(tabulate.tabulate(tabular_data))
            #except ImportError:
            #    print("XXXX")
            #import scipy.linalg
            #Mexpe_2 = scipy.linalg.expm(m1.astype(complex))
            #try:
            #    import tabulate
            #    tabular_data = [[str(idx)] + list(str(x) for x in td) for idx, (label, td) in enumerate(zip(pks, Mexpe_2))]
            #    print("Mexpe_2", idx)
            #    print(Mexpe_2.dtype)
            #    print(tabulate.tabulate(tabular_data))
            #except ImportError:
            #    print("XXXX")
            ### IMPROVE POWER CONSERVATION
            #for idx in range(len(pks)):
            #    NORMsq = np.dot(Mexp[idx], Mexp[idx].conjugate())
            #    #print("pwr ", idx, " VAL: ", NORMsq)
            #    Mexp[idx] = Mexp[idx] / (NORMsq.real)**.5
            Mexp_tot = np.dot(Mexp, Mexp_tot)
            print("Sparsity: ", idx_N, np.sum(Mexp_tot.flatten() != 0)/ len(Mexp_tot.flatten()))
            #print(Mexp.shape, pkv.shape)
            pkv = np.dot(Mexp, pkv.reshape(-1, 1)).reshape(-1)
            #try:
            #    import tabulate
            #    tabular_data = [[str(label)] + [str(pk), str(pkk)] for label, pk, pkk in zip(pks, pkv, pkX)]
            #    print("PKs2:")
            #    print(tabulate.tabulate(tabular_data))
            #except ImportError:
            #    print("XXXX")
            #print("pkv2:", type(pkv), pkv.shape)
            #print(pkv)

        #print(m1)
        #print(Mexp)
        #try:
        #    import tabulate
        #    tabular_data = [[str(label)] + [pk] for label, pk in zip(pks, pkv)]
        #    print("PKs2:")
        #    print(tabulate.tabulate(tabular_data))
        #except ImportError:
        #    print("XXXX")
        #try:
        #    import tabulate
        #    tabular_data = [[str(idx)] + list(str(x) for x in td) for idx, (label, td) in enumerate(zip(pks, Mexp_tot))]
        #    print("Mexp_tot", idx)
        #    print(Mexp.dtype)
        #    print(tabulate.tabulate(tabular_data))
        #except ImportError:
        #    print("XXXX")

        N_sparsity = 0
        solution = dict()
        for idx_in in range(len(pks)):
            for idx_out in range(len(pks)):
                edge = Mexp_tot[idx_out, idx_in]
                if np.any(edge != 0):
                    N_sparsity += 1
                    #pk_in = pks[idx_in]
                    #pk_out = pks[idx_out]
                    ###TODO: add debug config reference for this print
                    #print(pk_in)
                    #print(pk_out)
                    #print(idx_in, idx_out, edge)
                    pkin = self.in_map[pks[idx_in]]
                    pkout = self.out_map[pks[idx_out]]
                    if pkout is None:
                        continue
                    solution[pkin, pkout] = edge
        #print("Sparsity: ", N_sparsity / len(Mexp_tot.flatten()))

        #pprint(pks)

        self.solution = solution

    def generate_solution_RK(self, sol_vector):
        pks = self.pks
        pkv = np.empty(len(pks), dtype=object)
        for idx, pk in enumerate(pks):
            #print("PK_G: ", (ins_p, pk))
            pkv[idx] = sol_vector.get(self.in_map[pk], 0)

        def lt_val(lt, pkv):
            if isinstance(lt, list):
                val = 0
                for sublt in lt:
                    val = lt_val(sublt, pkv) + val
            elif isinstance(lt, tuple):
                val = lt[0]
                #print("LT0: ", val)
                for pk_idx in lt[1:]:
                    val = val * pkv[pk_idx]  # sol_vector.get(pk, 0)
            else:
                raise RuntimeError("BOO")
            return val

        eye = np.eye(len(pks), dtype = object)
        Mexp_tot = eye
        for idx_N in range(self.N_ode):
            mk1 = np.zeros([len(pks), len(pks)], dtype = object)
            #the current ddlt_accel already incorporates h, so we must reverse that for the Runge Kutta Solver
            h = 1 / self.N_ode
            for (idx_out, idx_in), lt in self.ddlt_accel.items():
                    val = lt_val(lt, pkv) / h
                    mk1[idx_out, idx_in] = val

            pkv_k1 = np.dot(mk1, pkv.reshape(-1, 1)).reshape(-1)
            mk2 = np.zeros([len(pks), len(pks)], dtype = object)
            for (idx_out, idx_in), lt in self.ddlt_accel.items():
                    val = lt_val(lt, pkv + h/2 * pkv_k1) / h
                    mk2[idx_out, idx_in] = val

            pkv_k2 = np.dot(mk2, pkv.reshape(-1, 1)).reshape(-1)
            mk3 = np.zeros([len(pks), len(pks)], dtype = object)
            for (idx_out, idx_in), lt in self.ddlt_accel.items():
                    val = lt_val(lt, pkv + h/2 * pkv_k2) / h
                    mk3[idx_out, idx_in] = val

            pkv_k3 = np.dot(mk3, pkv.reshape(-1, 1)).reshape(-1)
            mk4 = np.zeros([len(pks), len(pks)], dtype = object)
            for (idx_out, idx_in), lt in self.ddlt_accel.items():
                    val = lt_val(lt, pkv + h * pkv_k3) / h
                    mk4[idx_out, idx_in] = val

            #try:
            #    import tabulate
            #    tabular_data = [[str(idx)] + list(str(t) for t in td) for idx, (label, td) in enumerate(zip(pks, mk1))]
            #    print("MK1")
            #    print(tabulate.tabulate(tabular_data))
            #    tabular_data = [[str(idx)] + list(str(t) for t in td) for idx, (label, td) in enumerate(zip(pks, mk2))]
            #    print("MK2")
            #    print(tabulate.tabulate(tabular_data))
            #    tabular_data = [[str(idx)] + list(str(t) for t in td) for idx, (label, td) in enumerate(zip(pks, mk3))]
            #    print("MK3")
            #    print(tabulate.tabulate(tabular_data))
            #    tabular_data = [[str(idx)] + list(str(t) for t in td) for idx, (label, td) in enumerate(zip(pks, mk4))]
            #    print("MK4")
            #    print(tabulate.tabulate(tabular_data))
            #except ImportError:
            #    print("XXXX")

            Mexp = eye + h/6 * (mk1 + 2 * mk2 + 2 * mk3 + mk4)
            Mexp_tot = np.dot(Mexp, Mexp_tot)
            pkv = np.dot(Mexp, pkv.reshape(-1, 1)).reshape(-1)

        solution = dict()
        for idx_in in range(len(pks)):
            for idx_out in range(len(pks)):
                edge = Mexp_tot[idx_out, idx_in]
                if np.any(edge != 0):
                    #pk_in = pks[idx_in]
                    #pk_out = pks[idx_out]
                    ###TODO: add debug config reference for this print
                    #print(pk_in)
                    #print(pk_out)
                    #print(idx_in, idx_out, edge)
                    pkin = self.in_map[pks[idx_in]]
                    pkout = self.out_map[pks[idx_out]]
                    if pkout is None:
                        continue
                    solution[pkin, pkout] = edge

        #pprint(pks)

        self.solution = solution

"""
