#!/usr/bin/env python3

# Convert a sector<N>.info file into:
# - sector_<N>.cpp
# - sector_<N>_*.cpp
# - sector_<N>_*.hpp
# - contour_deformation_sector_<N>_*.cpp
# - contour_deformation_sector_<N>_*.hpp
# - optimize_deformation_parameters_sector_<N>_*.cpp
# - optimize_deformation_parameters_sector_<N>_*.hpp
#
# Usage: python3 export_sector sector_<N>.info destination-dir

import collections
import contextlib
import glob
import os
import os.path
import re
import sys

def load_info(filename):
    """
    Load a dictionary from a FORM-formatted .info file. The
    syntax is:
        @key1=value1
        @key2=val
        ue 2
        @end
    The `@end` delimiter is optional; all whitespace is stripped
    from the values because FORM can't keep its hands away from
    it. A backslash at the end of a line is also stripped, because
    FORM does that too. Fail if the string `FAIL` appears in the
    file, as a precaution.

    Lines starting with `@` but without `=` (such as `@end`) only
    terminate the previous value and produce no entry.

    :param filename: path of the .info file to read.
    :returns: dict mapping each key (with surrounding spaces and
        the `=` removed) to its value with spaces, tabs, newlines
        and end-of-line backslashes removed.
    :raises ValueError: if the file contains the string `FAIL`.
    """
    with open(filename, "r") as f:
        text = f.read()
    # An explicit check instead of `assert`, which `python -O` strips.
    if "FAIL" in text:
        raise ValueError(f"{filename}: found 'FAIL' in the file")
    # With one capture group, re.split yields
    # [preamble, key1, value1, key2, value2, ...].
    parts = re.split(r"^@([a-zA-Z0-9_]* *=?)", text, flags=re.M)
    result = {}
    for key, val in zip(parts[1::2], parts[2::2]):
        if key.endswith("="):
            result[key.strip(" =")] = re.sub(r"\\$|[ \t\n]", "", val, flags=re.M)
    return result

def getlist(text, separator=","):
    """
    Split `text` on `separator` and drop empty items. Characters of
    `separator` are first stripped from both ends, so stray leading
    or trailing separators produce no items.
    """
    pieces = text.strip(separator).split(separator)
    return list(filter(None, pieces))

def getintlist(text, separator=","):
    """
    Split `text` on `separator` like `getlist` does and convert each
    non-empty item with `int`.
    """
    pieces = text.strip(separator).split(separator)
    return list(map(int, filter(None, pieces)))

def sed(text, rx, template):
    """
    Replace every match of the regex `rx` in `text` by `template`
    (group references allowed), in multiline mode so that `^` and
    `$` match at each line boundary. Returns the new text.
    """
    pattern = re.compile(rx, re.M)
    return pattern.sub(template, text)

def cleanup_code(text):
    """
    Take a code dump from FORM, reformat it into C++, rename
    local variables so that each one is only assigned once.

    Steps:
      1. Put each `;`-terminated statement on its own line.
      2. Rename FORM's abbreviation arrays to scalars (`tmp1_<k>`,
         `tmp2_<k>`), expand `pow(v,2)`, `pow(v,3)` and `pow(v,4)`
         of an alphanumeric `v` into products, drop a `.E+0`
         exponent that is not followed by a digit, and normalize
         spacing.
      3. Switch to static single assignment form, for variable type
         stability. Every `v = expr;` becomes `auto v2 = expr;`:
           - a plain alias `v = w;` emits nothing; later uses of `v`
             are replaced by the current name of `w`;
           - the first assignment to `v` keeps its name, with every
             `SecDecInternal` in it replaced by `_`;
           - each further assignment to `v` gets a fresh `tmp3_<n>`.
         On right-hand sides (and on non-assignment lines) every
         identifier-like token is replaced by its current name.
         Doing this with regular expressions is dodgy; it relies on
         the regular shape of FORM's output.

    :param text: the FORM code dump.
    :returns: the rewritten code, one statement per line.
    """
    # Ordered textual rewrites; each is applied in multiline mode.
    rewrite_rules = (
        (r"SecDecInternalAbbreviation\[([0-9]+)\]", r"tmp1_\1"),
        (r"SecDecInternalAbbreviations[0-9]+\(([0-9]+)\)", r"tmp1_\1"),
        (r"SecDecInternalSecondAbbreviation\[([0-9]+)\]", r"tmp2_\1"),
        (r"pow\(([a-zA-Z0-9]*),2\)", r"\1*\1"),
        (r"pow\(([a-zA-Z0-9]*),3\)", r"\1*\1*\1"),
        (r"pow\(([a-zA-Z0-9]*),4\)", r"(\1*\1)*(\1*\1)"),
        (r"[.]E[+]0([^0-9])", r"\1"),
        (r" *= *", " = "),
        # Not before a digit, so exponents like `E+05` stay intact.
        (r" *[+] *([^0-9])", r" + \1"),
        (r" *, *", ", "),
        (r"  *", r" "),
    )
    code = text.replace(";", ";\n").replace("\n\n", "\n")
    for pattern, replacement in rewrite_rules:
        code = re.sub(pattern, replacement, code, flags=re.M)

    alias_rx = re.compile(r"^([a-zA-Z0-9_]+) *= *([a-zA-Z0-9_]+);$")
    assign_rx = re.compile(r"^([a-zA-Z0-9_]+) *= *(.*)$")
    ident_rx = re.compile(r"[a-zA-Z0-9_]+")
    old2new = {}

    def rename(fragment):
        # Replace each identifier-like token by its current SSA name.
        return ident_rx.sub(lambda tok: old2new.get(tok.group(0), tok.group(0)), fragment)

    uniqindex = 1
    lines = []
    for line in code.splitlines():
        alias = alias_rx.match(line)
        if alias is not None:
            var, expr = alias.groups()
            old2new[var] = old2new.get(expr, expr)
            continue
        assign = assign_rx.match(line)
        if assign is None:
            lines.append(rename(line))
            continue
        var, expr = assign.groups()
        # Rename the right-hand side before rebinding `var`, so that
        # `x = x + 1` refers to the previous `x`.
        expr = rename(expr)
        if var in old2new:
            old2new[var] = f"tmp3_{uniqindex}"
            uniqindex += 1
        else:
            old2new[var] = var.replace("SecDecInternal", "_")
        lines.append(f"auto {old2new[var]} = {expr}")
    return "\n".join(lines)

def template_writer(template_source, *argnames):
    """
    A templating language: turns each `${code}` into `{code}`,
    wraps each line of the template in `_write(f"...")` except
    for lines that start with `@@ ` -- those are left as they
    are. Returns a function that runs the resulting code,
    and writes the output into a file.

    Details:
      - Literal `{` and `}` in output lines are doubled, so only
        `${...}` is substituted; the `...` must not contain `}`.
      - Every output line gets a trailing newline.
      - Output lines are indented like the body of the most recent
        `@@ ` statement (its indent, plus 4 if it ends with `:`),
        so an `@@ pass` at a lower indent closes a loop or `if`.

    :param template_source: the template text.
    :param argnames: names of the positional parameters the returned
        function accepts after the output object.
    :returns: `fn(output, *args)`, which writes via `output.write`.

    The generated code is run with `exec` using this module's
    globals, so templates may call the helpers defined here; only
    pass trusted template sources.
    """
    statements = []
    indent = 0
    for raw_line in template_source.splitlines():
        if raw_line.startswith("@@ "):
            # Python statement: copy verbatim and record the indent
            # that following output lines must use.
            stmt = raw_line[3:]
            statements.append(stmt)
            indent = len(stmt) - len(stmt.lstrip(" "))
            if stmt.endswith(":"):
                indent += 4
            continue
        # With one capture group, re.split alternates literal text
        # (even positions) and captured `{code}` (odd positions).
        # Only the literal pieces get their braces doubled.
        pieces = re.split(r"\$(\{[^}]*})", raw_line + "\n")
        pieces[0::2] = [p.replace("{", "{{").replace("}", "}}") for p in pieces[0::2]]
        fstring_body = "".join(pieces)
        statements.append(" " * indent + f"_write(f{fstring_body!r})")
    params = "".join(", " + name for name in argnames)
    body = "\n    ".join(statements)
    code = (
        f"def _template_fn(_output{params}):\n"
        "    _write = _output.write\n"
        f"    {body}\n"
        "    pass\n"
    )
    namespace = {}
    exec(code, None, namespace)
    return namespace["_template_fn"]

class DictionaryWrapper:
    """
    A small wrapper around dictionaries that allows one to access
    values as x.key in addition to x["key"].

    A missing key raises AttributeError through attribute access (so
    `hasattr`, `getattr(x, name, default)` and `copy.deepcopy` work
    normally) and KeyError through subscripting. The wrapped
    dictionary is not copied and stays available as `x.dict`.
    """
    def __init__(self, dict):
        self.dict = dict

    def __getattr__(self, key):
        # Only reached when normal lookup fails. Read the wrapped dict
        # through __dict__ so an instance without `dict` (e.g. created
        # via __new__ during copying) cannot recurse into __getattr__.
        try:
            return self.__dict__["dict"][key]
        except KeyError:
            raise AttributeError(key) from None

    def __getitem__(self, key):
        return self.dict[key]

def make_list(python_list):
    """Join the items of `python_list` with commas (no spaces), using `str` on each."""
    return ",".join(str(item) for item in python_list)

def make_CXX_Series_initialization(regulator_names, min_orders, max_orders, sector_ID, contour_deformation, numIV):
    '''
    Return the c++ code that initializes the container class
    (``Series<Series<...<Series<IntegrandContainer>>...>``).

    Each nesting level renders as ``{min,max,{...},true,"name"}``;
    the innermost level lists one integrand-container initializer
    per order, each wrapped in its own braces.

    :param regulator_names: regulator names, outermost first; must
        have at least ``len(min_orders)`` entries.
    :param min_orders: lowest order kept for each regulator.
    :param max_orders: highest order kept for each regulator.
    :param sector_ID: sector number used in the generated names.
    :param contour_deformation: if true, each container also names
        the order's contour-deformation polynomial and its
        maximal-allowed-deformation-parameters function.
    :param numIV: number of integration variables per order, keyed by
        the C++ order suffix built from the order multi-index
        (e.g. ``(-1,0) -> "n1_0"``).
    :raises ValueError: if `min_orders` and `max_orders` differ in
        length or are empty.
    '''
    if len(min_orders) != len(max_orders):
        raise ValueError("min_orders and max_orders must have the same length")
    if not min_orders:
        raise ValueError("at least one regulator order range is required")
    last_regulator_index = len(min_orders) - 1

    def multiindex_to_cpp_order(multiindex):
        '(-1,3,2,-4) --> n1_3_2_n4'
        return '_'.join(str(order).replace('-', 'n') for order in multiindex)

    def series_literal(regulator_index, contents):
        # One Series initializer: {min,max,{contents},true,"regulator"}.
        lo = min_orders[regulator_index]
        hi = max_orders[regulator_index]
        name = regulator_names[regulator_index]
        return "{" + f"{lo},{hi}," + "{" + contents + "}" + f',true,"{name}"' + "}"

    def container_entry(orders):
        # Arguments of the integrand container for one order multi-index;
        # the CUDA device-function getter is emitted under SECDEC_WITH_CUDA.
        cpp_order = multiindex_to_cpp_order(orders)
        order = make_list(orders)
        if contour_deformation:
            return (
                f'{sector_ID},{{{order}}},{numIV[cpp_order]},sector_{sector_ID}_order_{cpp_order}_integrand,\n'
                '#ifdef SECDEC_WITH_CUDA\n'
                f'get_device_sector_{sector_ID}_order_{cpp_order}_integrand,\n'
                '#endif\n'
                f'sector_{sector_ID}_order_{cpp_order}_contour_deformation_polynomial,'
                f'sector_{sector_ID}_order_{cpp_order}_maximal_allowed_deformation_parameters'
            )
        return (
            f'{sector_ID},{{{order}}},{numIV[cpp_order]},sector_{sector_ID}_order_{cpp_order}_integrand\n'
            '#ifdef SECDEC_WITH_CUDA\n'
            f',get_device_sector_{sector_ID}_order_{cpp_order}_integrand\n'
            '#endif\n'
        )

    # Orders of all regulators at the current point of the recursion.
    current_orders = list(min_orders)

    def recursion(regulator_index):
        orders_here = range(min_orders[regulator_index], max_orders[regulator_index] + 1)
        if regulator_index < last_regulator_index:
            children = []
            for this_regulator_order in orders_here:
                current_orders[regulator_index] = this_regulator_order
                children.append(recursion(regulator_index + 1))
            return series_literal(regulator_index, ",".join(children))
        # Last regulator: the coefficients are the integrand containers.
        entries = []
        for this_regulator_order in orders_here:
            current_orders[regulator_index] = this_regulator_order
            entries.append(container_entry(current_orders))
        return series_literal(regulator_index, "{" + "},{".join(entries) + "}")

    return recursion(0)

# Template for `sector_<N>.cpp`, called as `SECTOR_CPP(output_file, i)`
# where `i` exposes the .info keys both as attributes and by subscript
# (the __main__ section passes a DictionaryWrapper). The output includes
# each order's `sector_<N>_<order>.hpp` (plus the contour-deformation and
# deformation-parameter headers when `contourDeformation` is nonzero) and
# defines `get_integrand_of_sector_<N>()`, returning the nested Series
# initializer from make_CXX_Series_initialization: lowest orders are the
# negated `highestPoles`, highest orders are `requiredOrders`.
# NOTE(review): `order_numIV` and `sorder` are computed in the include loop
# but never used by this template.
# NOTE(review): `numIV` is keyed by `order<k>_name`, while
# make_CXX_Series_initialization looks entries up by the suffix it builds
# from the order multi-index (e.g. (-1,0) -> n1_0); the two must coincide.
SECTOR_CPP = template_writer("""\
#include <secdecutil/series.hpp>

@@ for oidx in range(1, int(i.numOrders) + 1):
@@     order_name = i[f"order{oidx}_name"]
@@     order_numIV = len(getlist(i[f"order{oidx}_integrationVariables"]))
@@     so = f"sector_{i.sector}_{order_name}"
@@     sorder = f"sector_{i.sector}_order_{order_name}"
#include "${so}.hpp"
@@     if int(i.contourDeformation):
#include "contour_deformation_${so}.hpp"
#include "optimize_deformation_parameters_${so}.hpp"
@@     pass
@@ pass

namespace ${i.namespace}
{
nested_series_t<sector_container_t> get_integrand_of_sector_${i.sector}()
{
@@ min_orders = [-o for o in getintlist(i.highestPoles)]
@@ numIV = {i[f"order{o}_name"] : len(getlist(i[f"order{o}_integrationVariables"])) for o in range(1, 1+int(i.numOrders))}
return ${make_CXX_Series_initialization(getlist(i.regulators), min_orders, getintlist(i.requiredOrders), int(i.sector), bool(int(i.contourDeformation)), numIV)};
}

}
""", "i")

# Template for `sector_<N>_<order>.cpp`, rendered once per order with `i`
# exposing that order's keys as `order_xxx` (see the key renaming in the
# __main__ section). It defines `sector_<N>_order_<order>_integrand`
# (marked __host__ __device__ under SECDEC_WITH_CUDA). Integration
# variables, real parameters, complex parameters and deformation
# parameters are bound by position to local `auto` variables. The body is
# `order_integrandBody` after cleanup_code. The `deformation_parameters`
# argument is only part of the signature when `contourDeformation` is
# nonzero. Under SECDEC_WITH_CUDA it also defines a __device__ pointer to
# the integrand, plus a host getter that copies that pointer with
# cudaMemcpyFromSymbol and throws secdecutil::cuda_error on failure.
# NOTE(review): the loop over `order_deformationParameters` is not guarded
# by `contourDeformation`; presumably that list is empty when contour
# deformation is off -- confirm, otherwise the generated code would use an
# undeclared `deformation_parameters`.
SECTOR_ORDER_CPP = template_writer("""\
@@ so = f"sector_{i.sector}_{i.order_name}"
@@ sorder = f"sector_{i.sector}_order_{i.order_name}"
#include "${so}.hpp"
namespace ${i.namespace}
{
#ifdef SECDEC_WITH_CUDA
__host__ __device__
#endif
integrand_return_t ${sorder}_integrand
(
    real_t const * restrict const integration_variables,
    real_t const * restrict const real_parameters,
    complex_t const * restrict const complex_parameters,
@@ if int(i.contourDeformation):
    real_t const * restrict const deformation_parameters,
@@ pass
    secdecutil::ResultInfo * restrict const result_info
)
{
@@ for j, v in enumerate(getlist(i.order_integrationVariables)):
    auto ${v} = integration_variables[${j}];
@@ for j, v in enumerate(getlist(i.realParameters)):
    auto ${v} = real_parameters[${j}];
@@ for j, v in enumerate(getlist(i.complexParameters)):
    auto ${v} = complex_parameters[${j}];
@@ for j, v in enumerate(getlist(i.order_deformationParameters)):
    auto ${v} = deformation_parameters[${j}];
@@ code = cleanup_code(i.order_integrandBody)
@@ for line in code.splitlines():
    ${line}
@@ pass
}
#ifdef SECDEC_WITH_CUDA
@@ if int(i.contourDeformation):
__device__ secdecutil::SectorContainerWithDeformation<real_t, complex_t>::DeformedIntegrandFunction* const device_${sorder}_integrand = ${sorder}_integrand;
secdecutil::SectorContainerWithDeformation<real_t, complex_t>::DeformedIntegrandFunction* get_device_${sorder}_integrand()
{
    using IntegrandFunction = secdecutil::SectorContainerWithDeformation<real_t, complex_t>::DeformedIntegrandFunction;
@@ else:
__device__ secdecutil::SectorContainerWithoutDeformation<real_t, complex_t, integrand_return_t>::IntegrandFunction* const device_${sorder}_integrand = ${sorder}_integrand;
secdecutil::SectorContainerWithoutDeformation<real_t, complex_t, integrand_return_t>::IntegrandFunction* get_device_${sorder}_integrand()
{
    using IntegrandFunction = secdecutil::SectorContainerWithoutDeformation<real_t, complex_t, integrand_return_t>::IntegrandFunction;
@@ pass
    IntegrandFunction* device_address_on_host;
    auto errcode = cudaMemcpyFromSymbol(&device_address_on_host,device_${sorder}_integrand, sizeof(IntegrandFunction*));
    if (errcode != cudaSuccess) throw secdecutil::cuda_error( cudaGetErrorString(errcode) );
    return device_address_on_host;
}
#endif
}
""", "i")

# Template for `sector_<N>_<order>.hpp`, an include-guarded header that
# includes `<namespace>.hpp` and `functions.hpp` (plus the matching
# contour-deformation header when `contourDeformation` is nonzero). It
# declares `sector_<N>_order_<order>_integrand` through the function type
# of secdecutil::SectorContainerWithDeformation (DeformedIntegrandFunction)
# or SectorContainerWithoutDeformation (IntegrandFunction). Under
# SECDEC_WITH_CUDA it also declares the `get_device_..._integrand` getter.
SECTOR_ORDER_HPP = template_writer("""\
@@ so = f"sector_{i.sector}_{i.order_name}"
@@ sorder = f"sector_{i.sector}_order_{i.order_name}"
#ifndef ${i.namespace}_codegen_${so}_hpp_included
#define ${i.namespace}_codegen_${so}_hpp_included
#include "${i.namespace}.hpp"
#include "functions.hpp"
@@ if int(i.contourDeformation):
#include "contour_deformation_${so}.hpp"
@@ pass
namespace ${i.namespace}
{
#ifdef SECDEC_WITH_CUDA
__host__ __device__
#endif
@@ if int(i.contourDeformation):
secdecutil::SectorContainerWithDeformation<real_t, complex_t>::DeformedIntegrandFunction ${sorder}_integrand;
#ifdef SECDEC_WITH_CUDA
secdecutil::SectorContainerWithDeformation<real_t, complex_t>::DeformedIntegrandFunction* get_device_${sorder}_integrand();
#endif
@@ else:
secdecutil::SectorContainerWithoutDeformation<real_t, complex_t, integrand_return_t>::IntegrandFunction ${sorder}_integrand;
#ifdef SECDEC_WITH_CUDA
secdecutil::SectorContainerWithoutDeformation<real_t, complex_t, integrand_return_t>::IntegrandFunction* get_device_${sorder}_integrand();
#endif
@@ pass
}
#endif
""", "i")

# Template for `contour_deformation_sector_<N>_<order>.cpp` (written only
# when `contourDeformation` is nonzero). It defines
# `sector_<N>_order_<order>_contour_deformation_polynomial`. Integration
# variables, real/complex parameters and deformation parameters are bound
# to locals, followed by `order_contourDeformationPolynomialBody` after
# cleanup_code. It also defines the macro SecDecInternalRealPart(x):
# `(complex_t{x}).real()` under SECDEC_WITH_CUDA, `std::real(x)` otherwise.
# Presumably the FORM-generated body uses this macro.
# NOTE(review): the definition's return type is `integrand_return_t`, while
# the header declares the function via DeformedIntegrandFunction; these
# must name the same type -- confirm.
CONTOUR_DEFORMATION_SECTOR_ORDER_CPP = template_writer("""\
@@ so = f"sector_{i.sector}_{i.order_name}"
@@ sorder = f"sector_{i.sector}_order_{i.order_name}"
#include "contour_deformation_${so}.hpp"
namespace ${i.namespace}
{
#ifdef SECDEC_WITH_CUDA
#define SecDecInternalRealPart(x) (complex_t{x}).real()
#else
#define SecDecInternalRealPart(x) std::real(x)
#endif
integrand_return_t ${sorder}_contour_deformation_polynomial
(
    real_t const * restrict const integration_variables,
    real_t const * restrict const real_parameters,
    complex_t const * restrict const complex_parameters,
    real_t const * restrict const deformation_parameters,
    secdecutil::ResultInfo * restrict const result_info
)
{
@@ for j, v in enumerate(getlist(i.order_integrationVariables)):
    auto ${v} = integration_variables[${j}];
@@ for j, v in enumerate(getlist(i.realParameters)):
    auto ${v} = real_parameters[${j}];
@@ for j, v in enumerate(getlist(i.complexParameters)):
    auto ${v} = complex_parameters[${j}];
@@ for j, v in enumerate(getlist(i.order_deformationParameters)):
    auto ${v} = deformation_parameters[${j}];
@@ code = cleanup_code(i.order_contourDeformationPolynomialBody)
@@ for line in code.splitlines():
    ${line}
@@ pass
}
}
""", "i")

# Template for `contour_deformation_sector_<N>_<order>.hpp`, an
# include-guarded header declaring
# `sector_<N>_order_<order>_contour_deformation_polynomial` with the
# DeformedIntegrandFunction type of secdecutil::SectorContainerWithDeformation.
# The `;` after the namespace's closing brace is an (allowed) empty declaration.
CONTOUR_DEFORMATION_SECTOR_ORDER_HPP = template_writer("""\
@@ sorder = f"sector_{i.sector}_order_{i.order_name}"
#ifndef ${i.namespace}_codegen_contour_deformation_${sorder}_hpp_included
#define ${i.namespace}_codegen_contour_deformation_${sorder}_hpp_included
#include "${i.namespace}.hpp"
#include "functions.hpp"
namespace ${i.namespace}
{
secdecutil::SectorContainerWithDeformation<real_t, complex_t>::DeformedIntegrandFunction ${sorder}_contour_deformation_polynomial;
};
#endif
""", "i")

# Template for `optimize_deformation_parameters_sector_<N>_<order>.cpp`
# (written only when `contourDeformation` is nonzero). It defines
# `sector_<N>_order_<order>_maximal_allowed_deformation_parameters`, which
# returns void and takes an output buffer as its first argument.
# Integration variables and real/complex parameters (no deformation
# parameters) are bound to locals, followed by
# `order_optimizeDeformationParametersBody` after cleanup_code.
# Presumably that FORM body fills `output_deformation_parameters` -- confirm.
OPTIMIZE_DEFORMATION_PARAMETERS_SECTOR_ORDER_CPP = template_writer("""\
@@ so = f"sector_{i.sector}_{i.order_name}"
@@ sorder = f"sector_{i.sector}_order_{i.order_name}"
#include "optimize_deformation_parameters_${so}.hpp"
namespace ${i.namespace}
{
void ${sorder}_maximal_allowed_deformation_parameters
(
    real_t * restrict const output_deformation_parameters,
    real_t const * restrict const integration_variables,
    real_t const * restrict const real_parameters,
    complex_t const * restrict const complex_parameters,
    secdecutil::ResultInfo * restrict const result_info
)
{
@@ for j, v in enumerate(getlist(i.order_integrationVariables)):
    auto ${v} = integration_variables[${j}];
@@ for j, v in enumerate(getlist(i.realParameters)):
    auto ${v} = real_parameters[${j}];
@@ for j, v in enumerate(getlist(i.complexParameters)):
    auto ${v} = complex_parameters[${j}];
@@ code = cleanup_code(i.order_optimizeDeformationParametersBody)
@@ for line in code.splitlines():
    ${line}
@@ pass
}
}
""", "i")

# Template for `optimize_deformation_parameters_sector_<N>_<order>.hpp`, an
# include-guarded header declaring
# `sector_<N>_order_<order>_maximal_allowed_deformation_parameters` with
# the MaximalDeformationFunction type of SectorContainerWithDeformation.
# It also includes <cmath>, <limits> and <vector>, presumably for the
# generated body in the matching .cpp.
OPTIMIZE_DEFORMATION_PARAMETERS_SECTOR_ORDER_HPP = template_writer("""\
@@ sorder = f"sector_{i.sector}_order_{i.order_name}"
#ifndef ${i.namespace}_codegen_optimize_deformation_parameters_${sorder}_hpp_included
#define ${i.namespace}_codegen_optimize_deformation_parameters_${sorder}_hpp_included
#include "${i.namespace}.hpp"
#include "functions.hpp"
#include <cmath>
#include <limits>
#include <vector>
namespace ${i.namespace}
{
secdecutil::SectorContainerWithDeformation<real_t, complex_t>::MaximalDeformationFunction ${sorder}_maximal_allowed_deformation_parameters;
};
#endif
""", "i")

if __name__ == "__main__":

    # Command line: <sector .info file> <destination directory>.
    if len(sys.argv) != 3:
        print(f"usage: {sys.argv[0]} sector.info destination-dir", file=sys.stderr)
        sys.exit(1)

    sectorfile = sys.argv[1]
    dstdir = sys.argv[2]

    info = load_info(sectorfile)

    # One sector_<N>.cpp holding the Series over all orders.
    fname = os.path.join(dstdir, "sector_" + info["sector"] + ".cpp")
    print("* Creating", fname)
    with open(fname, "w") as f:
        SECTOR_CPP(f, DictionaryWrapper(info))

    # Per-order sources; the contour-deformation ones only when enabled.
    for oidx in range(1, int(info["numOrders"]) + 1):
        so = "sector_" + info["sector"] + "_" + info[f"order{oidx}_name"]
        files = {
            f"{so}.cpp": SECTOR_ORDER_CPP,
            f"{so}.hpp": SECTOR_ORDER_HPP,
        }
        if int(info["contourDeformation"]):
            files.update({
                f"contour_deformation_{so}.cpp": CONTOUR_DEFORMATION_SECTOR_ORDER_CPP,
                f"contour_deformation_{so}.hpp": CONTOUR_DEFORMATION_SECTOR_ORDER_HPP,
                f"optimize_deformation_parameters_{so}.cpp": OPTIMIZE_DEFORMATION_PARAMETERS_SECTOR_ORDER_CPP,
                f"optimize_deformation_parameters_{so}.hpp": OPTIMIZE_DEFORMATION_PARAMETERS_SECTOR_ORDER_HPP,
            })
        # Rename the key names so that templates can access the
        # current order keys as `order_xxx` instead of `order<N>_xxx`,
        # which is much more awkward. Only a leading `order<N>_` is
        # rewritten; other keys are passed through unchanged.
        prefix = f"order{oidx}_"
        info_thisorder = DictionaryWrapper({
            ("order_" + key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in info.items()
        })
        for filename, template in files.items():
            fname = os.path.join(dstdir, filename)
            print("* Creating", fname)
            with open(fname, "w") as f:
                template(f, info_thisorder)
