#!/usr/bin/env python3

# Convert a sector_<N>.info file into (depending on --subset):
# - src/sector_<N>.cpp
# - src/sector_<N>_*.cpp
# - src/sector_<N>_*.hpp
# - src/contour_deformation_sector_<N>_*.cpp
# - src/contour_deformation_sector_<N>_*.hpp
# - src/optimize_deformation_parameters_sector_<N>_*.cpp
# - src/optimize_deformation_parameters_sector_<N>_*.hpp
# - distsrc/sector_<N>_*.cpp
# - distsrc/sector_<N>_*.cu
# - mma/sector_<N>_*.m
#
# Usage: python3 export_sector [--subset=cpp,disteval,mma] sector_<N>.info destination-dir

import collections
import contextlib
import getopt
import glob
import os
import os.path
import re
import sys

def load_info(filename):
    """
    Parse a FORM-formatted .info file into a dictionary.

    The file consists of records of the form
        @key1=value1
        @key2=val
        ue 2
        @end
    Each record starts with `@name=` at the beginning of a line
    and runs until the next `@...` marker. Markers that carry no
    `=` (such as the optional `@end`) are ignored.

    Values are returned with every space, tab and newline removed,
    and with a backslash at the end of a line removed too, since
    FORM inserts both at will. As a precaution the function fails
    with an AssertionError if the word `FAIL` appears anywhere in
    the file.
    """
    with open(filename, "r") as info_file:
        contents = info_file.read()
    assert "FAIL" not in contents
    # re.split with one capture group yields
    #   [preamble, marker1, body1, marker2, body2, ...]
    chunks = re.split("^@([a-zA-Z0-9_]* *=?)", contents, flags=re.M)
    markers = chunks[1::2]
    bodies = chunks[2::2]
    noise = re.compile(r"\\$|[ \t\n]", re.M)
    return {
        marker.strip(" ="): noise.sub("", body)
        for marker, body in zip(markers, bodies)
        if marker.endswith("=")
    }

def getlist(text, separator=","):
    """
    Split `text` on `separator` and return the non-empty items
    as a list of strings.
    """
    pieces = text.strip(separator).split(separator)
    return list(filter(None, pieces))

def getintlist(text, separator=","):
    """
    Split `text` on `separator` and return the non-empty items
    converted to integers.
    """
    fields = text.strip(separator).split(separator)
    return list(map(int, filter(None, fields)))

def sed(text, rx, template):
    """
    Replace every match of the regular expression `rx` in `text`
    by `template`, with `^` and `$` matching at each line.
    """
    pattern = re.compile(rx, re.M)
    return pattern.sub(template, text)

def grep(text, rx):
    """
    Return all matches of the regular expression `rx` in `text`
    (as `re.findall` does), with `^` and `$` matching at each line.
    """
    pattern = re.compile(rx, re.M)
    return pattern.findall(text)

def cleanup_code(text):
    """
    Take a code dump from FORM, reformat it into C++, rename
    local variables so that each one is only assigned once.

    Returns the rewritten code as a single string with one
    statement per line. Raises AssertionError if, after the number
    cleanup, the code still contains a `.`, something that looks
    like a floating point exponent, or a `/` that is not a ratio
    of two integer literals.
    """
    # One statement per line: drop spaces before each `;`, break the
    # line after it, then collapse the resulting runs of newlines.
    code = sed(text, " *;", ";\n")
    code = sed(code, "\n+", "\n")
    # Turn indexed abbreviation tables into plain scalar names.
    code = sed(code, r"SecDecInternalAbbreviation\[([0-9]+)\]", r"tmp1_\1")
    code = sed(code, r"SecDecInternalAbbreviations[0-9]+\(([0-9]+)\)", r"tmp1_\1")
    code = sed(code, r"SecDecInternalSecondAbbreviation\[([0-9]+)\]", r"tmp2_\1")
    code = sed(code, r"SecDecInternalSecondAbbreviation\(([0-9]+)\)", r"tmp2_\1")
    # A "term" is an optionally signed product of identifiers and
    # integers; only such simple bases get the specialized power forms.
    term = "[+-]?[a-zA-Z0-9_*]+"
    code = sed(code, fr"\bpow\(({term}),2\)", r"SecDecInternalSqr(\1)")
    code = sed(code, fr"\bpow\(({term}),3\)", r"SecDecInternalSqr(\1)*\1")
    code = sed(code, fr"\bpow\(({term}),4\)", r"SecDecInternalSqr(SecDecInternalSqr(\1))")
    code = sed(code, fr"\bpow\(({term}),([0-9]+)\)", r"SecDecInternalNPow(\1,\2)")
    # Every pow() that was not matched above falls back to the general form.
    code = sed(code, r"\bpow\(", r"SecDecInternalPow(")
    code = sed(code, r"\blog\(", r"SecDecInternalLog(")
    # `if (cond) SecDecInternalSignCheckErrorXXX(` becomes
    # `SecDecInternalSignCheckXXX(cond,`.
    code = sed(code, r"if *\((.*)\) *SecDecInternalSignCheckError(.*)\(", r"SecDecInternalSignCheck\2(\1,")
    # Convert i_*expr into SecDecInternalI(expr)
    code = sed(code, fr"= *({term})\*i_\*({term});", r"= SecDecInternalI(\1*\2);")
    code = sed(code, fr"= *({term})\*i_;", r"= SecDecInternalI(\1);")
    code = sed(code, fr"= *([+-]?)i_\*({term});", r"= SecDecInternalI(\1\2);")
    code = sed(code, fr"= *([+-]?)i_;", r"= SecDecInternalI(\g<1>1);")
    # Make numbers either integers or rationals (via SecDecInternalQuo)
    # First drop a `.E+0` suffix, then any `.` not followed by a digit.
    code = sed(code, r"[.]E[+]0([^0-9])", r"\1")
    code = sed(code, r"[.]([^0-9])", r"\1")
    assert "." not in code
    assert re.search(r"\b[0-9]*[eE][0-9+-]", code) is None
    code = sed(code, r"([0-9]+)/([0-9]+)", r"SecDecInternalQuo(\1,\2)")
    assert "/" not in code
    # Put some cosmetic spaces back in for prettier formatting
    code = sed(code, fr"\b *= *\b", r" = ")
    code = sed(code, r"\b *[+] *\b", r" + ")
    code = sed(code, r" *, *", ", ")
    code = sed(code, r"  *", r" ")
    # Switch to static single assignment form, for variable type
    # stability. Doing this with regular expressions is super
    # dodgy...
    old2new = {}    # original variable name -> name to use from now on
    uniqindex = 1   # counter for the fresh tmp3_<N> names
    lines = []
    for line in code.splitlines():
        # A pure alias `a = b;` produces no output line: later uses
        # of `a` are replaced by whatever `b` currently maps to.
        m = re.match("^([a-zA-Z0-9_]+) *= *([a-zA-Z0-9_]+);$", line)
        if m is not None:
            var, expr = m.groups()
            old2new[var] = old2new.get(expr, expr)
            continue
        m = re.match("^([a-zA-Z0-9_]+) *= *(.*)$", line)
        if m is None:
            # Not an assignment: only apply the renames collected so far.
            line = re.sub("[a-zA-Z0-9_]+", lambda m: old2new.get(m.group(0), m.group(0)), line)
        else:
            var, expr = m.groups()
            # Rename inside the right-hand side before `var` is remapped,
            # so that `x = x + 1` still refers to the previous `x`.
            expr = re.sub("[a-zA-Z0-9_]+", lambda m: old2new.get(m.group(0), m.group(0)), expr)
            if var in old2new:
                # Reassignment: give this definition a fresh name.
                old2new[var] = f"tmp3_{uniqindex}"
                uniqindex += 1
            else:
                # First assignment: keep the name, shortening the
                # `SecDecInternal` prefix to `_`.
                old2new[var] = var.replace("SecDecInternal", "_")
            line = f"auto {old2new[var]} = {expr}"
        lines.append(line)
    code = "\n".join(lines)
    # Declarations of type result_t that mention SecDecInternal...Part
    # are retyped as real_t.
    code = sed(code, r"^result_t (.*SecDecInternal.*Part.*)$", r"real_t \1")
    return code

# Import-time self-checks of cleanup_code(): float-formatted integers
# become plain integers, integer ratios become SecDecInternalQuo(),
# and neither an identifier that merely looks like an exponent (x1e2)
# nor the operators `<=` and `+=` are mistaken for numbers or
# assignments.
assert cleanup_code("x=1*2./3.E+0*4.E+0;") == "auto x = 1*SecDecInternalQuo(2, 3)*4;"
assert cleanup_code("x=1/2*2/3;") == "auto x = SecDecInternalQuo(1, 2)*SecDecInternalQuo(2, 3);"
assert cleanup_code("x1e2=x+y;x<=1;y+=z;") == "auto x1e2 = x + y;\nx<=1;\ny+=z;"

def code_variables(code):
    """
    Return the names of all variables declared in `code` by lines
    of the form `auto <name> ...`, each name listed once.

    The names are returned in order of first appearance, so that
    files generated from the result are reproducible from run to
    run (a plain set of strings has no stable iteration order).
    """
    names = re.findall(r"^auto ([^ ]+) ", code, flags=re.M)
    # dict.fromkeys de-duplicates while keeping insertion order.
    return list(dict.fromkeys(names))

def code_to_mma(code):
    """
    Translate code produced by cleanup_code() into Mathematica
    syntax: strip the `auto` declarations, capitalize `return(`,
    use square brackets for calls, replace `_` by `$` in names,
    spell the imaginary unit as `I`, and turn a leading `!(` of a
    call argument into `Not[`.
    """
    # The order matters: brackets must be converted before the
    # `[![` pattern can be recognized.
    rewrites = (
        (r"^auto ", ""),
        (r"^return\(", "Return("),
        (r"\(", "["),
        (r"\)", "]"),
        (r"_", "$"),
        (r"\bi\b", "I"),
        (r"\[\!\[", "[Not["),
    )
    for rx, replacement in rewrites:
        code = sed(code, rx, replacement)
    return code

# Import-time self-check of code_to_mma(): calls use square brackets
# and a negated argument is wrapped in Not[...].
assert code_to_mma("SignCheckCD(!(Val<=0),1)") == "SignCheckCD[Not[Val<=0],1]"

def template_writer(template_source, *argnames):
    """
    Compile a template into a function that writes its expansion.

    The template language is line based:
    - a line starting with `@@ ` is Python code and is copied into
      the generated function as it is (without the `@@ ` marker);
    - any other line is output text, in which each `${code}` is
      replaced by the value of `code`; all other braces are
      literal.
    Output lines are placed at the indentation of the most recent
    `@@ ` line, one level deeper if that line ends with `:`, so
    a block is closed by a less indented `@@ ` line (e.g. `pass`).

    Returns a function `fn(output, *args)` taking an object with
    a `write` method followed by one value per name in `argnames`.
    """
    placeholder = re.compile(r"\$(\{[^}]*})")

    def as_fstring_body(text):
        # Odd-numbered pieces are the `{code}` captures and stay as
        # they are; the literal text in between gets its braces doubled.
        return "".join(
            piece if position % 2 else piece.replace("{", "{{").replace("}", "}}")
            for position, piece in enumerate(placeholder.split(text))
        )

    statements = []
    depth = 0
    for raw in template_source.splitlines():
        if not raw.startswith("@@ "):
            fstring = as_fstring_body(raw + '\n')
            statements.append(" " * depth + f"_write(f{fstring!r})")
            continue
        statement = raw[3:]
        statements.append(statement)
        depth = len(statement) - len(statement.lstrip(" "))
        if statement.endswith(":"):
            depth += 4

    parameters = "_output" + "".join(", " + name for name in argnames)
    source = (
        f"def _template_fn({parameters}):\n"
        "    _write = _output.write\n"
        "    " + "\n    ".join(statements) + "\n"
        "    pass\n"
    )
    namespace = {}
    # The generated function resolves global names in this module.
    exec(source, None, namespace)
    return namespace["_template_fn"]

class DictionaryWrapper:
    """
    A small wrapper around dictionaries that allows one to access
    values as x.key in addition to x["key"].

    A missing key raises KeyError when accessed as x["key"], and
    AttributeError when accessed as x.key, so that `hasattr` and
    three-argument `getattr` work as usual.
    """
    def __init__(self, dict):
        # The wrapped dictionary; stays available as `x.dict`.
        self.dict = dict
    def __getattr__(self, key):
        # Only called when normal attribute lookup fails. Read the
        # wrapped dictionary through __dict__, because `self.dict`
        # would re-enter __getattr__ forever on an instance whose
        # `dict` attribute has not been set.
        try:
            return self.__dict__["dict"][key]
        except KeyError:
            raise AttributeError(key) from None
    def __getitem__(self, key):
        return self.dict[key]

def make_list(python_list):
    """
    Return the items of `python_list`, converted to strings and
    joined by commas (without spaces).
    """
    return ','.join(map(str, python_list))

def make_CXX_Series_initialization(regulator_names, min_orders, max_orders, sector_ID, contour_deformation, numIV):
    '''
    Return the c++ code that initializes the container class
    (``Series<Series<...<Series<IntegrandContainer>>...>``).

    One series level is produced per regulator, each one spanning
    the orders `min_orders[k]` to `max_orders[k]` (inclusive) of
    regulator `k`. The innermost level holds one brace-enclosed
    entry per multi-index of orders, naming the functions that
    belong to `sector_ID` at that order; `numIV` is looked up by
    the c++ spelling of the multi-index (e.g. `n1_3_2_n4`). With
    `contour_deformation` enabled the entries also name the
    contour deformation polynomial and the function giving the
    maximal allowed deformation parameters.
    '''
    assert len(min_orders) == len(max_orders)
    innermost = len(min_orders) - 1
    # The multi-index being expanded; slot `level` is overwritten
    # while the series of regulator `level` is being built.
    selected = min_orders.copy()

    def leaf_entry():
        # (-1,3,2,-4) --> n1_3_2_n4
        cpp_order = '_'.join(str(value).replace('-', 'n') for value in selected)
        order = make_list(selected)
        if not contour_deformation:
            return (
                f'{sector_ID},{{{order}}},{numIV[cpp_order]},sector_{sector_ID}_order_{cpp_order}_integrand\n'
                f'#ifdef SECDEC_WITH_CUDA\n'
                f',get_device_sector_{sector_ID}_order_{cpp_order}_integrand\n'
                f'#endif\n'
            )
        return (
            f'{sector_ID},{{{order}}},{numIV[cpp_order]},sector_{sector_ID}_order_{cpp_order}_integrand,\n'
            f'#ifdef SECDEC_WITH_CUDA\n'
            f'get_device_sector_{sector_ID}_order_{cpp_order}_integrand,\n'
            f'#endif\n'
            f'sector_{sector_ID}_order_{cpp_order}_contour_deformation_polynomial,'
            f'sector_{sector_ID}_order_{cpp_order}_maximal_allowed_deformation_parameters'
        )

    def expand(level):
        lowest, highest = min_orders[level], max_orders[level]
        at_last_regulator = level >= innermost
        opening = '{%i,%i,{' % (lowest, highest)
        entries = []
        for value in range(lowest, highest + 1):
            selected[level] = value
            entries.append(leaf_entry() if at_last_regulator else expand(level + 1))
        closing = '},true,"%s"}' % (regulator_names[level],)
        if at_last_regulator:
            # Each innermost entry gets its own pair of braces.
            return opening + '{' + '},{'.join(entries) + '}' + closing
        return opening + ','.join(entries) + closing

    return expand(0)

# Template for src/sector_<N>.cpp. It includes the header of every
# order of the sector (plus the contour deformation headers if
# `contourDeformation` is set), and defines
# get_integrand_of_sector_<N>(), which returns the nested series
# produced by make_CXX_Series_initialization(). The template
# argument `i` is the DictionaryWrapper around the whole .info
# dictionary. Note that `order_numIV` and `sorder` are computed in
# the loop but not used in the output.
SECTOR_CPP = template_writer("""\
#include <secdecutil/series.hpp>

@@ for oidx in range(1, int(i.numOrders) + 1):
@@     order_name = i[f"order{oidx}_name"]
@@     order_numIV = len(getlist(i[f"order{oidx}_integrationVariables"]))
@@     so = f"sector_{i.sector}_{order_name}"
@@     sorder = f"sector_{i.sector}_order_{order_name}"
#include "${so}.hpp"
@@     if int(i.contourDeformation):
#include "contour_deformation_${so}.hpp"
#include "optimize_deformation_parameters_${so}.hpp"
@@     pass
@@ pass

namespace ${i.namespace}
{
nested_series_t<sector_container_t> get_integrand_of_sector_${i.sector}()
{
@@ min_orders = [-o for o in getintlist(i.highestPoles)]
@@ numIV = {i[f"order{o}_name"] : len(getlist(i[f"order{o}_integrationVariables"])) for o in range(1, 1+int(i.numOrders))}
return ${make_CXX_Series_initialization(getlist(i.regulators), min_orders, getintlist(i.requiredOrders), int(i.sector), bool(int(i.contourDeformation)), numIV)};
}

}
""", "i")

# Template for src/sector_<N>_<order>.cpp: the integrand function
# of one order of the sector. The function unpacks the integration
# variables and the parameters into named constants, followed by
# the integrand body as rewritten by cleanup_code(). Under
# SECDEC_WITH_CUDA it also emits get_device_<...>_integrand(),
# which fetches the device address of the integrand with
# cudaMemcpyFromSymbol. The template argument `i` is the per-order
# info dictionary (keys of the current order spelled `order_xxx`).
SECTOR_ORDER_CPP = template_writer("""\
@@ so = f"sector_{i.sector}_{i.order_name}"
@@ sorder = f"sector_{i.sector}_order_{i.order_name}"
#include "${so}.hpp"
namespace ${i.namespace}
{
#ifdef SECDEC_WITH_CUDA
__host__ __device__
#endif
integrand_return_t ${sorder}_integrand
(
    real_t const * restrict const integration_variables,
    real_t const * restrict const real_parameters,
    complex_t const * restrict const complex_parameters,
@@ if int(i.contourDeformation):
    real_t const * restrict const deformation_parameters,
@@ pass
    secdecutil::ResultInfo * restrict const result_info
)
{
@@ for j, v in enumerate(getlist(i.order_integrationVariables)):
    const auto ${v} = integration_variables[${j}]; (void)${v};
@@ for j, v in enumerate(getlist(i.realParameters)):
    const auto ${v} = real_parameters[${j}]; (void)${v};
@@ for j, v in enumerate(getlist(i.complexParameters)):
    const auto ${v} = complex_parameters[${j}]; (void)${v};
@@ for j, v in enumerate(getlist(i.order_deformationParameters)):
    const auto ${v} = deformation_parameters[${j}]; (void)${v};
@@ code = cleanup_code(i.order_integrandBody)
@@ for line in code.splitlines():
    ${line}
@@ pass
}
#ifdef SECDEC_WITH_CUDA
@@ if int(i.contourDeformation):
__device__ secdecutil::SectorContainerWithDeformation<real_t, complex_t>::DeformedIntegrandFunction* const device_${sorder}_integrand = ${sorder}_integrand;
secdecutil::SectorContainerWithDeformation<real_t, complex_t>::DeformedIntegrandFunction* get_device_${sorder}_integrand()
{
    using IntegrandFunction = secdecutil::SectorContainerWithDeformation<real_t, complex_t>::DeformedIntegrandFunction;
@@ else:
__device__ secdecutil::SectorContainerWithoutDeformation<real_t, complex_t, integrand_return_t>::IntegrandFunction* const device_${sorder}_integrand = ${sorder}_integrand;
secdecutil::SectorContainerWithoutDeformation<real_t, complex_t, integrand_return_t>::IntegrandFunction* get_device_${sorder}_integrand()
{
    using IntegrandFunction = secdecutil::SectorContainerWithoutDeformation<real_t, complex_t, integrand_return_t>::IntegrandFunction;
@@ pass
    IntegrandFunction* device_address_on_host;
    auto errcode = cudaMemcpyFromSymbol(&device_address_on_host,device_${sorder}_integrand, sizeof(IntegrandFunction*));
    if (errcode != cudaSuccess) throw secdecutil::cuda_error( cudaGetErrorString(errcode) );
    return device_address_on_host;
}
#endif
}
""", "i")

# Template for src/sector_<N>_<order>.hpp: the include-guarded
# declarations of the integrand function defined by
# SECTOR_ORDER_CPP and, under SECDEC_WITH_CUDA, of its
# get_device_<...>_integrand() companion. The function type
# depends on whether `contourDeformation` is set. The template
# argument `i` is the per-order info dictionary.
SECTOR_ORDER_HPP = template_writer("""\
@@ so = f"sector_{i.sector}_{i.order_name}"
@@ sorder = f"sector_{i.sector}_order_{i.order_name}"
#ifndef ${i.namespace}_codegen_${so}_hpp_included
#define ${i.namespace}_codegen_${so}_hpp_included
#include "${i.namespace}.hpp"
#include "functions.hpp"
@@ if int(i.contourDeformation):
#include "contour_deformation_${so}.hpp"
@@ pass
namespace ${i.namespace}
{
#ifdef SECDEC_WITH_CUDA
__host__ __device__
#endif
@@ if int(i.contourDeformation):
secdecutil::SectorContainerWithDeformation<real_t, complex_t>::DeformedIntegrandFunction ${sorder}_integrand;
#ifdef SECDEC_WITH_CUDA
secdecutil::SectorContainerWithDeformation<real_t, complex_t>::DeformedIntegrandFunction* get_device_${sorder}_integrand();
#endif
@@ else:
secdecutil::SectorContainerWithoutDeformation<real_t, complex_t, integrand_return_t>::IntegrandFunction ${sorder}_integrand;
#ifdef SECDEC_WITH_CUDA
secdecutil::SectorContainerWithoutDeformation<real_t, complex_t, integrand_return_t>::IntegrandFunction* get_device_${sorder}_integrand();
#endif
@@ pass
}
#endif
""", "i")

# Template for src/contour_deformation_sector_<N>_<order>.cpp: the
# function evaluating the contour deformation polynomial of one
# order. It has the same parameter unpacking as the integrand
# function, followed by `order_contourDeformationPolynomialBody`
# as rewritten by cleanup_code(). SecDecInternalRealPart is
# defined here as a macro, differently with and without
# SECDEC_WITH_CUDA. The template argument `i` is the per-order
# info dictionary.
CONTOUR_DEFORMATION_SECTOR_ORDER_CPP = template_writer("""\
@@ so = f"sector_{i.sector}_{i.order_name}"
@@ sorder = f"sector_{i.sector}_order_{i.order_name}"
#include "contour_deformation_${so}.hpp"
namespace ${i.namespace}
{
#ifdef SECDEC_WITH_CUDA
#define SecDecInternalRealPart(x) (complex_t{x}).real()
#else
#define SecDecInternalRealPart(x) std::real(x)
#endif
integrand_return_t ${sorder}_contour_deformation_polynomial
(
    real_t const * restrict const integration_variables,
    real_t const * restrict const real_parameters,
    complex_t const * restrict const complex_parameters,
    real_t const * restrict const deformation_parameters,
    secdecutil::ResultInfo * restrict const result_info
)
{
@@ for j, v in enumerate(getlist(i.order_integrationVariables)):
    const auto ${v} = integration_variables[${j}]; (void)${v};
@@ for j, v in enumerate(getlist(i.realParameters)):
    const auto ${v} = real_parameters[${j}]; (void)${v};
@@ for j, v in enumerate(getlist(i.complexParameters)):
    const auto ${v} = complex_parameters[${j}]; (void)${v};
@@ for j, v in enumerate(getlist(i.order_deformationParameters)):
    const auto ${v} = deformation_parameters[${j}]; (void)${v};
@@ code = cleanup_code(i.order_contourDeformationPolynomialBody)
@@ for line in code.splitlines():
    ${line}
@@ pass
}
}
""", "i")

# Template for src/contour_deformation_sector_<N>_<order>.hpp: the
# include-guarded declaration of the contour deformation
# polynomial function of one order. The template argument `i` is
# the per-order info dictionary.
CONTOUR_DEFORMATION_SECTOR_ORDER_HPP = template_writer("""\
@@ sorder = f"sector_{i.sector}_order_{i.order_name}"
#ifndef ${i.namespace}_codegen_contour_deformation_${sorder}_hpp_included
#define ${i.namespace}_codegen_contour_deformation_${sorder}_hpp_included
#include "${i.namespace}.hpp"
#include "functions.hpp"
namespace ${i.namespace}
{
secdecutil::SectorContainerWithDeformation<real_t, complex_t>::DeformedIntegrandFunction ${sorder}_contour_deformation_polynomial;
};
#endif
""", "i")

# Template for
# src/optimize_deformation_parameters_sector_<N>_<order>.cpp: the
# function <...>_maximal_allowed_deformation_parameters of one
# order. It unpacks the integration variables and the parameters
# into named constants, followed by
# `order_optimizeDeformationParametersBody` as rewritten by
# cleanup_code(). The template argument `i` is the per-order info
# dictionary.
OPTIMIZE_DEFORMATION_PARAMETERS_SECTOR_ORDER_CPP = template_writer("""\
@@ so = f"sector_{i.sector}_{i.order_name}"
@@ sorder = f"sector_{i.sector}_order_{i.order_name}"
#include "optimize_deformation_parameters_${so}.hpp"
namespace ${i.namespace}
{
void ${sorder}_maximal_allowed_deformation_parameters
(
    real_t * restrict const output_deformation_parameters,
    real_t const * restrict const integration_variables,
    real_t const * restrict const real_parameters,
    complex_t const * restrict const complex_parameters,
    secdecutil::ResultInfo * restrict const result_info
)
{
@@ for j, v in enumerate(getlist(i.order_integrationVariables)):
    const auto ${v} = integration_variables[${j}]; (void)${v};
@@ for j, v in enumerate(getlist(i.realParameters)):
    const auto ${v} = real_parameters[${j}]; (void)${v};
@@ for j, v in enumerate(getlist(i.complexParameters)):
    const auto ${v} = complex_parameters[${j}]; (void)${v};
@@ code = cleanup_code(i.order_optimizeDeformationParametersBody)
@@ for line in code.splitlines():
    ${line}
@@ pass
}
}
""", "i")

# Template for
# src/optimize_deformation_parameters_sector_<N>_<order>.hpp: the
# include-guarded declaration of the function
# <...>_maximal_allowed_deformation_parameters of one order. The
# template argument `i` is the per-order info dictionary.
OPTIMIZE_DEFORMATION_PARAMETERS_SECTOR_ORDER_HPP = template_writer("""\
@@ sorder = f"sector_{i.sector}_order_{i.order_name}"
#ifndef ${i.namespace}_codegen_optimize_deformation_parameters_${sorder}_hpp_included
#define ${i.namespace}_codegen_optimize_deformation_parameters_${sorder}_hpp_included
#include "${i.namespace}.hpp"
#include "functions.hpp"
#include <cmath>
#include <limits>
#include <vector>
namespace ${i.namespace}
{
secdecutil::SectorContainerWithDeformation<real_t, complex_t>::MaximalDeformationFunction ${sorder}_maximal_allowed_deformation_parameters;
};
#endif
""", "i")

# Number of lattice points handled per loop iteration in the code
# generated from DIST_SECTOR_ORDER_CPP: the loops advance `index`
# by VECSIZE, and each integration variable is a realvec_t
# initialized with VECSIZE values.
VECSIZE = 4
# Template for distsrc/sector_<N>_<order>.cpp. It emits an
# extern "C" function <namespace>__sector_<N>_order_<order> that
# loops over the lattice points index1..index2, evaluates the
# integrand body (rewritten by cleanup_code(), with `return(`
# turned into an accumulation into `acc` weighted by `w`), and
# stores the component sum in *presult. The sign check macros
# make it store REAL_NAN and return 1 or 2 instead. The weight of
# the lanes at or beyond index2 is set to zero in each iteration.
# If `contourDeformation` is set, two more functions are emitted:
# <...>__maxdeformp, which reduces the deformation parameters
# reported through SecDecInternalOutputDeformationParameters to
# their minimum over the points (starting from 10.0), and
# <...>__fpolycheck, which returns 1 as soon as the imaginary part
# of the contour deformation polynomial is not <= 0 at some point.
# The template argument `i` is the per-order info dictionary.
DIST_SECTOR_ORDER_CPP = template_writer("""\
@@ complex = i.complexParameters or int(i.contourDeformation) or int(i.enforceComplex)
#define SECDEC_RESULT_IS_COMPLEX ${1 if complex else 0}
#include "common_cpu.h"

#define SecDecInternalSignCheckPositivePolynomial(cond, id) if (unlikely(cond)) {*presult = REAL_NAN; return 1; }
#define SecDecInternalSignCheckContourDeformation(cond, id) if (unlikely(cond)) {*presult = REAL_NAN; return 2; }

extern "C" int
${i.namespace}__sector_${i.sector}_order_${i.order_name}(
    result_t * restrict presult,
    const uint64_t lattice,
    const uint64_t index1,
    const uint64_t index2,
    const uint64_t * restrict genvec,
    const real_t * restrict shift,
    const real_t * restrict realp,
    const complex_t * restrict complexp,
    const real_t * restrict deformp
)
{
@@ for j, v in enumerate(getlist(i.realParameters)):
    const real_t ${v} = realp[${j}]; (void)${v};
@@ for j, v in enumerate(getlist(i.complexParameters)):
    const complex_t ${v} = complexp[${j}]; (void)${v};
@@ for j, v in enumerate(getlist(i.order_deformationParameters)):
    const real_t ${v} = deformp[${j}];
@@ pass
    const real_t invlattice = SecDecInternalDenominator((real_t)(double)lattice);
    resultvec_t acc = RESULTVEC_ZERO;
    uint64_t index = index1;
@@ intvars = getlist(i.order_integrationVariables)
@@ for j, v in enumerate(intvars):
    int_t li_${v} = mulmod(genvec[${j}], index, lattice);
@@ pass
    for (; index < index2; index += ${VECSIZE}) {
@@ for j, v in enumerate(intvars):
@@     for k in range(VECSIZE):
        int_t li_${v}_${k} = li_${v}; li_${v} = warponce_i(li_${v} + genvec[${j}], lattice);
@@     li_list = ", ".join(f"li_{v}_{k}*invlattice" for k in range(VECSIZE))
        realvec_t ${v} = {{ ${li_list} }};
        ${v} = warponce(${v} + shift[${j}], 1);
@@ pass
@@ for j, v in enumerate(intvars):
        auto w_${v} = ${i.qmcTransform}_w(${v});
@@ pass
        realvec_t w = ${"*".join("w_" + v for v in intvars) if intvars else "REALVEC_CONST(1)"};
@@ for k in range(1, VECSIZE):
        if (unlikely(index + ${k} >= index2)) w.x[${k}] = 0;
@@ for j, v in enumerate(intvars):
        ${v} = ${i.qmcTransform}_f(${v});
@@ for line in cleanup_code(i.order_integrandBody).splitlines():
        ${line.replace("return(", "acc = acc + w*(")}
@@ pass
    }
    *presult = componentsum(acc);
    return 0;
}

@@ if int(i.contourDeformation):
#define SecDecInternalOutputDeformationParameters(i, v) deformp[i] = vec_min(deformp[i], v);

extern "C" void
${i.namespace}__sector_${i.sector}_order_${i.order_name}__maxdeformp(
    real_t * restrict maxdeformp,
    const uint64_t lattice,
    const uint64_t index1,
    const uint64_t index2,
    const uint64_t * restrict genvec,
    const real_t * restrict shift,
    const real_t * restrict realp,
    const complex_t * restrict complexp
)
{
@@     for j, v in enumerate(getlist(i.realParameters)):
    const real_t ${v} = realp[${j}]; (void)${v};
@@     for j, v in enumerate(getlist(i.complexParameters)):
    const complex_t ${v} = complexp[${j}]; (void)${v};
@@     ndeformp = len(getlist(i.order_deformationParameters))
@@     deformp_init = ', '.join(['REALVEC_CONST(10.0)']*ndeformp)
    const real_t invlattice = SecDecInternalDenominator((real_t)(double)lattice);
    realvec_t deformp[${ndeformp}] = { ${deformp_init} };
    uint64_t index = index1;
@@     intvars = getlist(i.order_integrationVariables)
@@     for j, v in enumerate(intvars):
    int_t li_${v} = mulmod(genvec[${j}], index, lattice);
@@     pass
    for (; index < index2; index += ${VECSIZE}) {
@@     for j, v in enumerate(intvars):
@@         for k in range(VECSIZE):
        int_t li_${v}_${k} = li_${v}; li_${v} = warponce_i(li_${v} + genvec[${j}], lattice);
@@         li_list = ", ".join(f"li_{v}_{k}*invlattice" for k in range(VECSIZE))
        realvec_t ${v} = {{ ${li_list} }};
        ${v} = warponce(${v} + shift[${j}], 1);
@@     for j, v in enumerate(intvars):
        ${v} = ${i.qmcTransform}_f(${v});
@@     code = cleanup_code(i.order_optimizeDeformationParametersBody)
@@     for line in code.splitlines():
        ${line}
@@     pass
    }
@@     for j, v in enumerate(getlist(i.order_deformationParameters)):
    maxdeformp[${j}] = componentmin(deformp[${j}]);
@@     pass
}

extern "C" int
${i.namespace}__sector_${i.sector}_order_${i.order_name}__fpolycheck(
    const uint64_t lattice,
    const uint64_t index1,
    const uint64_t index2,
    const uint64_t * restrict genvec,
    const real_t * restrict shift,
    const real_t * restrict realp,
    const complex_t * restrict complexp,
    const real_t * restrict deformp
)
{
@@     for j, v in enumerate(getlist(i.realParameters)):
    const real_t ${v} = realp[${j}]; (void)${v};
@@     for j, v in enumerate(getlist(i.complexParameters)):
    const complex_t ${v} = complexp[${j}]; (void)${v};
@@     for j, v in enumerate(getlist(i.order_deformationParameters)):
    const real_t ${v} = deformp[${j}]; (void)${v};
@@     pass
    const real_t invlattice = SecDecInternalDenominator((real_t)(double)lattice);
    uint64_t index = index1;
@@     intvars = getlist(i.order_integrationVariables)
@@     for j, v in enumerate(intvars):
    int_t li_${v} = mulmod(genvec[${j}], index, lattice);
@@     pass
    for (; index < index2; index += ${VECSIZE}) {
@@     for j, v in enumerate(intvars):
@@         for k in range(VECSIZE):
        int_t li_${v}_${k} = li_${v}; li_${v} = warponce_i(li_${v} + genvec[${j}], lattice);
@@         li_list = ", ".join(f"li_{v}_{k}*invlattice" for k in range(VECSIZE))
        realvec_t ${v} = {{ ${li_list} }};
        ${v} = warponce(${v} + shift[${j}], 1);
@@     for j, v in enumerate(intvars):
        ${v} = ${i.qmcTransform}_f(${v});
@@     code = cleanup_code(i.order_contourDeformationPolynomialBody)
@@     for line in code.splitlines():
        ${line.replace("return(", "auto fpoly_im = SecDecInternalImagPart(")}
@@     pass
        if (unlikely(!(fpoly_im <= 0))) return 1;
    }
    return 0;
}
""", "i")

# Template for distsrc/sector_<N>_<order>.cu: the CUDA kernel
# <namespace>__sector_<N>_order_<order>. Each thread starts at
# index1 + (blockIdx.x*128 + threadIdx.x)*8 and evaluates up to 8
# consecutive lattice points below index2, with `return(` in the
# integrand body (rewritten by cleanup_code()) turned into an
# accumulation into `val` weighted by `w`. The sign check macros
# set `val` to REAL_NAN and leave the loop. The per-thread values
# are summed with cub::BlockReduce over a block of 128 threads, and
# thread 0 stores the sum in result[blockIdx.x]. The template
# argument `i` is the per-order info dictionary.
DIST_SECTOR_ORDER_CU = template_writer("""\
@@ complex = i.complexParameters or int(i.contourDeformation) or int(i.enforceComplex)
#define SECDEC_RESULT_IS_COMPLEX ${1 if complex else 0}
#include "common_cuda.h"

#define SecDecInternalSignCheckPositivePolynomial(cond, id) if (unlikely(cond)) {val = REAL_NAN; break;}
#define SecDecInternalSignCheckContourDeformation(cond, id) if (unlikely(cond)) {val = REAL_NAN; break;}

extern "C" __global__ void
${i.namespace}__sector_${i.sector}_order_${i.order_name}(
    result_t * __restrict__ result,
    const uint64_t lattice,
    const uint64_t index1,
    const uint64_t index2,
    const uint64_t * __restrict__ genvec,
    const real_t * __restrict__ shift,
    const real_t * __restrict__ realp,
    const complex_t * __restrict__ complexp,
    const real_t * __restrict__ deformp
)
{
    // assert(blockDim.x == 128);
    const uint64_t bid = blockIdx.x;
    const uint64_t tid = threadIdx.x;
@@ for j, v in enumerate(getlist(i.realParameters)):
    const real_t ${v} = realp[${j}]; (void)${v};
@@ for j, v in enumerate(getlist(i.complexParameters)):
    const complex_t ${v} = complexp[${j}]; (void)${v};
@@ for j, v in enumerate(getlist(i.order_deformationParameters)):
    const real_t ${v} = deformp[${j}]; (void)${v};
@@ pass
    const real_t invlattice = SecDecInternalDenominator((real_t)(double)lattice);
    result_t val = (result_t)0;
    uint64_t index = index1 + (bid*128 + tid)*8;
@@ intvars = getlist(i.order_integrationVariables)
@@ for j, v in enumerate(intvars):
    uint64_t li_${v} = mulmod(index, genvec[${j}], lattice);
@@ pass
    for (uint64_t i = 0; (i < 8) && (index < index2); i++, index++) {
@@ for j, v in enumerate(intvars):
        real_t ${v} = warponce(invlattice*(double)li_${v} + shift[${j}], 1.0);
        li_${v} = warponce_i(li_${v} + genvec[${j}], lattice);
@@ for j, v in enumerate(intvars):
        real_t w_${v} = ${i.qmcTransform}_w(${v});
@@ pass
        real_t w = ${"*".join("w_" + v for v in intvars) if intvars else "1"};
@@ for j, v in enumerate(intvars):
        ${v} = ${i.qmcTransform}_f(${v});
@@ for line in cleanup_code(i.order_integrandBody).splitlines():
        ${line.replace("return(", "val = val + w*(")}
@@ pass
    }
    // Sum up 128*8=1024 values across 4 warps.
    typedef cub::BlockReduce<result_t, 128, cub::BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY> Reduce;
    __shared__ typename Reduce::TempStorage shared;
    result_t sum = Reduce(shared).Sum(val);
    if (tid == 0) result[bid] = sum;
}
""", "i")

# Template for mma/sector_<N>_<order>.m: Mathematica definitions
# of the SecDecInternal* helper functions, followed by four
# functions for one order of the sector (their names are prefixed
# by the namespace, with `_` replaced by `$`):
# - ...Integrand: the integrand body;
# - ...MaxDeformP: the deformation parameters reported through
#   SecDecInternalOutputDeformationParameters (`{}` without
#   contour deformation);
# - ...F: the contour deformation polynomial (`None` without
#   contour deformation);
# - ...DeformOK: whether the imaginary part of ...F is <= 0
#   (`True` without contour deformation).
# The bodies are produced by cleanup_code() and code_to_mma(). The
# two SecDecInternalSignCheck* helpers are defined with `:=` like
# every other helper; written without it, Mathematica would read
# each of these lines as a product of two expressions and define
# nothing. The template argument `i` is the per-order info
# dictionary.
SECTOR_ORDER_MMA = template_writer("""\
SecDecInternalQuo[n_,d_] := n/d;
SecDecInternalI[ex_] := I * ex;
SecDecInternalRealPart[ex_] := Re[ex];
SecDecInternalImagPart[ex_] := Im[ex];
SecDecInternalAbs[ex_] := Abs[ex];
SecDecInternalDenominator[ex_] := 1/ex;
SecDecInternalNPow[ex_, n_] := ex^n;
SecDecInternalLog[ex_] := Conjugate[Log[Conjugate[ex]]];
SecDecInternalPow[ex_, n_] := Conjugate[Power[Conjugate[ex],n]];
SecDecInternalSignCheckPositivePolynomial[cond_, id_] := If[cond, Throw[$PositivePolynomial[id]]];
SecDecInternalSignCheckContourDeformation[cond_, id_] := If[cond, Throw[$ContourDeformation[id]]];
SecDecInternalSqr[ex_] := ex*ex;

${i.namespace.replace("_", "$")}$sector$${i.sector}$order$${i.order_name}Integrand[
    {${", ".join([x + "_" for x in getlist(i.order_integrationVariables)])}, ___},
    {${", ".join([x + "_" for x in getlist(i.realParameters) + getlist(i.complexParameters)])}},
    {${", ".join([x + "_" for x in getlist(i.order_deformationParameters)])}}
] :=
@@ code = cleanup_code(i.order_integrandBody).replace("_", "$")
@@ varnames = code_variables(code)
Module[{${", ".join(varnames)}},
@@ for line in code_to_mma(code).splitlines():
    ${line}
@@ pass
];

${i.namespace.replace("_", "$")}$sector$${i.sector}$order$${i.order_name}MaxDeformP[
    {${", ".join([x + "_" for x in getlist(i.order_integrationVariables)])}, ___},
    {${", ".join([x + "_" for x in getlist(i.realParameters) + getlist(i.complexParameters)])}}
@@ if int(i.contourDeformation):
] :=
@@     code = cleanup_code(i.order_optimizeDeformationParametersBody).replace("_", "$")
@@     varnames = code_variables(code)
Module[{${", ".join(varnames)}, SecDecInternalOutputDeformationParameters, deformp},
    deformp = Table[0, ${len(getlist(i.order_deformationParameters))}];
    SecDecInternalOutputDeformationParameters[i_, v_] := (deformp[[i+1]] = v);
@@     for line in code_to_mma(code).splitlines():
    ${line}
@@     pass
    deformp
];
@@ else:
] := {}
@@ pass

${i.namespace.replace("_", "$")}$sector$${i.sector}$order$${i.order_name}F[
    {${", ".join([x + "_" for x in getlist(i.order_integrationVariables)])}, ___},
    {${", ".join([x + "_" for x in getlist(i.realParameters) + getlist(i.complexParameters)])}},
    {${", ".join([x + "_" for x in getlist(i.order_deformationParameters)])}}
@@ if int(i.contourDeformation):
] :=
@@     code = cleanup_code(i.order_contourDeformationPolynomialBody).replace("_", "$")
@@     varnames = code_variables(code)
Module[{${", ".join(varnames)}},
@@     for line in code_to_mma(code).splitlines():
    ${line}
@@     pass
];
@@ else:
] := None
@@ pass

${i.namespace.replace("_", "$")}$sector$${i.sector}$order$${i.order_name}DeformOK[
    integrationVariables_List,
    parameters_List,
    deformationParameters_List
@@ if int(i.contourDeformation):
] :=
  SecDecInternalImagPart[
    ${i.namespace.replace("_", "$")}$sector$${i.sector}$order$${i.order_name}F[
      integrationVariables, parameters, deformationParameters]] <= 0
@@ else:
] := True
@@ pass
""", "i")

if __name__ == "__main__":

    # Command line:
    #   export_sector [--subset=cpp,disteval,mma] sector.info destination-dir
    # `--subset` selects which groups of files are generated; by
    # default the C++ sources and the disteval sources are.
    subset = ["cpp", "disteval"]
    optlist, arglist = getopt.gnu_getopt(sys.argv[1:], "", ["subset="])
    for opt, val in optlist:
        if opt == "--subset": subset = val.split(",")
    if len(arglist) != 2:
        print(f"usage: {sys.argv[0]} [--subset=cpp,disteval,mma] sector.info destination-dir", file=sys.stderr)
        sys.exit(1)
    sectorfile, dstdir = arglist

    info = load_info(sectorfile)

    # All output goes below the destination directory, so the
    # subdirectories must be created there as well (and not in the
    # current working directory).
    for group, subdir in (("cpp", "src"), ("disteval", "distsrc"), ("mma", "mma")):
        if group in subset:
            os.makedirs(os.path.join(dstdir, subdir), exist_ok=True)

    # The per-sector file, listing all the orders.
    if "cpp" in subset:
        fname = os.path.join(dstdir, "src/sector_" + info["sector"] + ".cpp")
        with open(fname, "w") as f:
            SECTOR_CPP(f, DictionaryWrapper(info))

    # The per-order files.
    for oidx in range(1, int(info["numOrders"]) + 1):
        so = "sector_" + info["sector"] + "_" + info[f"order{oidx}_name"]
        # Maps each output file name (relative to dstdir) to the
        # template that produces it.
        files = {}
        if "cpp" in subset:
            files.update({
                f"src/{so}.cpp": SECTOR_ORDER_CPP,
                f"src/{so}.hpp": SECTOR_ORDER_HPP,
            })
            if int(info["contourDeformation"]):
                files.update({
                    f"src/contour_deformation_{so}.cpp": CONTOUR_DEFORMATION_SECTOR_ORDER_CPP,
                    f"src/contour_deformation_{so}.hpp": CONTOUR_DEFORMATION_SECTOR_ORDER_HPP,
                    f"src/optimize_deformation_parameters_{so}.cpp": OPTIMIZE_DEFORMATION_PARAMETERS_SECTOR_ORDER_CPP,
                    f"src/optimize_deformation_parameters_{so}.hpp": OPTIMIZE_DEFORMATION_PARAMETERS_SECTOR_ORDER_HPP,
                })
        if "disteval" in subset:
            files.update({
                f"distsrc/{so}.cpp": DIST_SECTOR_ORDER_CPP,
                f"distsrc/{so}.cu": DIST_SECTOR_ORDER_CU,
            })
        if "mma" in subset:
            files.update({
                f"mma/{so}.m": SECTOR_ORDER_MMA,
            })
        # Munge the key names so that templates could access
        # the current order keys as `order_xxx`, instead of
        # `order<N>_xxx`, which is much more awkward.
        info_thisorder = DictionaryWrapper({
            key.replace(f"order{oidx}_", "order_") : value
            for key, value in info.items()
        })
        for filename, template in files.items():
            fname = os.path.join(dstdir, filename)
            with open(fname, "w") as f:
                template(f, info_thisorder)
