#!python

import argparse
import collections
import fnmatch
import os
import sys

import toml
from hgdb import DebugSymbolTable

# Chisel doesn't escape SystemVerilog keywords before handing names over to FIRRTL.
# This is just a selection of keywords that are commonly used.
SV_KEYWORD = {
    "begin",
    "case",
    "end",
    "if",
    "interface",
    "logic",
    "modport",
    "module",
    "reg",
    "wire",
}


def get_args():
    """Parse and validate the command line arguments.

    The positional ``input`` must name an existing path; otherwise a message
    is printed to stderr and the process exits with status 1. If the
    ``output`` path already exists it is removed before returning.

    Returns:
        The ``argparse.Namespace`` holding ``input``, ``output``,
        ``disable_firrtl_fix`` and ``detail`` (a list of patterns or None).
    """
    # the first positional parameter of ArgumentParser is ``prog``, so the text
    # has to be passed as ``description`` to keep the usage line sensible
    parser = argparse.ArgumentParser(description="Convert toml debug file to hgdb")
    parser.add_argument("input", help="Input toml file", type=str)
    parser.add_argument("output", help="Output symbol table file", type=str)
    parser.add_argument("--disable-firrtl-fix", action="store_true", dest="disable_firrtl_fix",
                        help="If set, no symbol mapping fix will be performed")
    parser.add_argument("--detail", help="If set, creates variables for child instances as well",
                        type=str, nargs='+', required=False)
    args = parser.parse_args()
    if not os.path.exists(args.input):
        print(args.input, "does not exist", file=sys.stderr)
        # the bare exit() helper is injected by the site module and is not
        # guaranteed to exist (e.g. python -S); sys.exit always is
        sys.exit(1)
    if os.path.exists(args.output):
        os.remove(args.output)
    return args


def match_name(def_name, detail_names):
    """Return True if ``def_name`` matches any shell-style pattern given.

    ``detail_names`` is an iterable of ``fnmatch`` patterns, or None when no
    pattern was supplied, in which case nothing matches.
    """
    if detail_names is None:
        return False
    return any(fnmatch.fnmatch(def_name, pattern) for pattern in detail_names)


def parse_toml(toml_filename):
    """Load ``toml_filename`` with ``toml.load`` and return the parsed data."""
    return toml.load(toml_filename)


def get_tops(design):
    """Return the sorted names of definitions nobody instantiates.

    ``design`` maps a definition name to its details; a detail may carry an
    ``"instances"`` mapping whose values are the instantiated definition names.
    """
    # every definition that shows up as a child of some other definition
    instantiated = {
        child
        for detail in design.values()
        if "instances" in detail
        for child in detail["instances"].values()
    }
    # sorted so that the output ordering is guaranteed
    return sorted(name for name in design if name not in instantiated)


# module-wide running ids, bumped by the write_out_* functions as entries are stored
var_count = breakpoint_id_count = instance_count = scope_count = 0


# that's firrtl name scheme
def escape_rtl_name(name):
    """Append "_" to ``name`` when it is one of ``SV_KEYWORD``; otherwise return it as is."""
    return name + "_" if name in SV_KEYWORD else name


def fix_variable_name(variables):
    """Heuristically turn flattened bundle names back into dotted names, in place.

    Chisel blows bundled variables (regs or nodes) up into ``a_b`` style names
    before handing them to firrtl. A name containing "_" is rewritten with "."
    in place of "_" (a leading "_" is kept) when at least one *other* name
    shares its prefix. Names that already contain "." or "[", names without
    "_", and names whose first token is itself a variable name are kept.

    ``variables`` is emptied and refilled, so the entry order follows the
    original key order.
    """
    original_names = list(variables.keys())
    renamed = collections.OrderedDict()
    for name in original_names:
        value = variables[name]
        # already hierarchical/indexed, or nothing to split on
        if "." in name or "[" in name or "_" not in name:
            renamed[name] = value
            continue
        # try to merge
        leading_underscore = name[0] == "_"
        head = name.split("_")[0]
        if leading_underscore:
            # splitting on "_" leaves an empty first token here, so the
            # prefix that gets compared is just "_"
            prefix = "_" + head
        elif head in original_names:
            # naming conflict with a real variable called ``head``
            renamed[name] = value
            continue
        else:
            prefix = head + "_"
        if not any(other.startswith(prefix) for other in original_names if other != name):
            renamed[name] = value
        elif leading_underscore:
            # change _ to ., except for the leading one
            renamed["_" + name[1:].replace("_", ".")] = value
        else:
            renamed[name.replace("_", ".")] = value

    variables.clear()
    variables.update(renamed)


def levenshtein_distance(s1, s2):
    """Return the edit distance (insert/delete/substitute) between two sequences."""
    # keep the row as short as possible by iterating the longer one outside
    shorter, longer = (s1, s2) if len(s1) <= len(s2) else (s2, s1)

    previous_row = list(range(len(shorter) + 1))
    for row_idx, long_ch in enumerate(longer, start=1):
        current_row = [row_idx]
        for col_idx, short_ch in enumerate(shorter):
            if short_ch == long_ch:
                cost = previous_row[col_idx]
            else:
                cost = 1 + min(previous_row[col_idx], previous_row[col_idx + 1], current_row[-1])
            current_row.append(cost)
        previous_row = current_row
    return previous_row[-1]


def get_filename(def_name, local_vars):
    """Pick the source filename that most likely belongs to ``def_name``.

    The last element of every entry in ``local_vars`` is a location string
    whose first token is a filename; tokens may be separated by " " or ":".
    If only one filename occurs it is returned. Otherwise the two most
    frequent filenames are compared and the one with the smaller edit
    distance to ``def_name`` wins.

    Returns:
        The chosen filename, or None when ``local_vars`` has no entries
        (callers test for None to skip the definition).
    """
    # normalise ":" to " " the same way the callers do before splitting, so
    # the filename returned here is comparable with the one they extract
    filename_count = collections.Counter(
        entry[-1].replace(":", " ").split(" ")[0] for entry in local_vars.values()
    )
    if not filename_count:
        # nothing to vote on
        return None
    # ascending stable sort, so the two most frequent names end up last
    ranked = sorted(filename_count.items(), key=lambda item: item[1])
    # if there is only one, we are good
    if len(ranked) == 1:
        return ranked[0][0]
    # pick top 2 and use the one closest to the definition name; min() keeps
    # the first candidate on ties
    candidates = [name for name, _ in ranked[-2:]]
    return min(candidates, key=lambda name: levenshtein_distance(def_name, name))


def write_out_assignments(db, assignments, breakpoint_index):
    """Store every assignment whose location matches a known breakpoint.

    Each item of ``assignments`` is ``(location, target, value, cond)``. The
    location is split on " " and ":" into a tuple and looked up in
    ``breakpoint_index``; assignments without a matching breakpoint are
    skipped.
    """
    for location, target, value, cond in assignments:
        key = tuple(location.replace(":", " ").split(" "))
        if key in breakpoint_index:
            db.store_assignment(target, value, breakpoint_index[key], cond)


def write_out_local_scope(def_name, local_vars, db, breakpoint_index, fix_variable):
    """Store the variables of one compiler-provided local scope.

    ``local_vars`` maps a variable name to an entry whose first element is the
    RTL name and whose last element is a location string. It must also contain
    a ``"#range"`` key holding the inclusive ``(min_line, max_line)`` of the
    scope. A variable is attached to every breakpoint in ``breakpoint_index``
    whose line lies inside that range and strictly after the variable's own
    line. Finally the scope itself is stored with the matched breakpoint ids.

    ``local_vars`` is not modified: the design data is shared by every
    instance of a definition, so this function must be callable more than
    once with the same mapping.

    Raises:
        Exception: if ``local_vars`` has no ``"#range"`` entry.
    """
    global var_count, scope_count
    # since the compiler already gives us the scope information, we don't need
    # to do guess work any more. we still need to compute the file name, since
    # compilers such as Chisel don't give proper names
    # notice that local scope must specify line number range
    if "#range" not in local_vars:
        raise Exception("Local scope must specify line number range")
    ln_min, ln_max = local_vars["#range"]
    ln_range = range(ln_min, ln_max + 1)
    # work on a copy without the "#range" marker instead of popping it from
    # the caller's data; otherwise the second instance of the same definition
    # would trip over the missing key above
    scoped_vars = collections.OrderedDict(
        (name, entry) for name, entry in local_vars.items() if name != "#range")
    current_filename = get_filename(def_name, scoped_vars)
    if current_filename is None:
        # nothing to be done here
        return
    # we hope that dictionary is in order, otherwise it's going to be bad, really bad
    if fix_variable:
        fix_variable_name(scoped_vars)

    breakpoint_ids = []
    scope = {}
    for name, entry in scoped_vars.items():
        fn_info = entry[-1].replace(":", " ")
        info = tuple(fn_info.split(" "))
        # there is no guess work so we simply try to match with declared breakpoints
        ln = int(info[1])
        rtl_name = escape_rtl_name(entry[0])
        # find matched breakpoint index
        for bp_info, bp_id in breakpoint_index.items():
            bp_ln = int(bp_info[1])
            # notice that we don't do == since we assume when the breakpoint hits,
            # the variable hasn't been declared yet
            if bp_ln > ln and bp_ln in ln_range:
                # it's a match
                if bp_id not in breakpoint_ids:
                    breakpoint_ids.append(bp_id)
                    scope[bp_id] = []
                # a fresh variable id is stored for every (variable, breakpoint) pair
                db.store_variable(var_count, rtl_name)
                scope[bp_id].append((name, var_count))
                var_count += 1

    for breakpoint_id, entries in scope.items():
        for n, var_id in entries:
            db.store_context_variable(n, breakpoint_id, var_id)

    # store local scopes as well
    # assume the breakpoints are created in order
    breakpoint_ids.sort()
    db.store_scope(scope_count, *breakpoint_ids)
    scope_count += 1


def write_out_locals(def_name, local_vars, db, breakpoint_index, fix_variable, detailed_local):
    """Store local variables and attach them to breakpoints using line heuristics.

    ``local_vars`` maps a variable name to an entry whose first element is the
    RTL name and whose last element is a location string ("file line ..." with
    " " or ":" separators). ``breakpoint_index`` maps location tuples to
    breakpoint ids. The work is done in three passes:

    1. collect the lines that sit within ``skip_line_threshold`` lines before
       any breakpoint,
    2. store every variable declared on such a line in the dominant file
       (see ``get_filename``) and group the variables by location,
    3. walk the groups and, for each accepted breakpoint in the same file,
       store the accumulated variables as context variables.

    ``detailed_local`` raises the threshold from 10 to 10000 and allows the
    same breakpoint line to be used more than once. When ``fix_variable`` is
    set, ``local_vars`` is renamed in place by ``fix_variable_name``.
    """
    global var_count
    # first pass to determine the actual filename
    # for some reason chisel inserts the original function
    # call file info
    current_filename = get_filename(def_name, local_vars)
    if current_filename is None:
        # nothing to be done here
        return
    # (name, rtl_name) pairs accumulated across groups in the third pass
    scope = []
    # if the line number difference is more than this threshold, the lines
    # are treated as unrelated
    skip_line_threshold = 10000 if detailed_local else 10
    # we hope that dictionary is in order, otherwise it's going to be bad, really bad
    if fix_variable:
        fix_variable_name(local_vars)
    # compute range since not all of the variables creation has direct correspondence to
    # a breakpoint
    # bp_lines holds every line in [bp_line - threshold, bp_line) for each breakpoint
    bp_lines = set()
    for info in breakpoint_index:
        ln = int(info[1])
        start_idx = max(0, ln - skip_line_threshold)
        for i in range(start_idx, ln):
            bp_lines.add(i)
    # second pass to group line numbers
    # targets: location tuple -> list of (name, rtl_name) declared there
    targets = {}
    # var_id_mapping: rtl_name -> variable id; a repeated rtl_name keeps the last id
    var_id_mapping = {}
    for name, entry in local_vars.items():
        fn_info = entry[-1].replace(":", " ")
        info = tuple(fn_info.split(" "))
        filename = info[0]
        ln = int(info[1])
        if ln not in bp_lines:
            continue
        # there is a "bug"/feature in the current chisel that use the first file info, as compared
        # to the last, which is what people want
        if current_filename != filename:
            # nothing we can do
            continue
        if info not in targets:
            targets[info] = []
        rtl_name = entry[0]
        rtl_name = escape_rtl_name(rtl_name)
        db.store_variable(var_count, rtl_name)
        targets[info].append((name, rtl_name))
        var_id_mapping[rtl_name] = var_count
        var_count += 1

    # reorder breakpoint index based on line number. I believe this is due to different nodes in FIRRTL?
    breakpoint_index_ = []
    for bp_info, breakpoint_id in breakpoint_index.items():
        breakpoint_index_.append((bp_info, breakpoint_id))
    breakpoint_index_.sort(key=lambda x: int(x[0][1]))

    # third pass to actually produce the table
    pre_ln = None
    # breakpoint lines that already received variables
    ln_set = set()
    for info, entries in targets.items():
        ln = int(info[1])
        fn = info[0]
        for bp_info, breakpoint_id in breakpoint_index_:
            bp_ln = int(bp_info[1])
            bp_fn = bp_info[0]
            if bp_fn != fn:
                continue
            # NOTE(review): once `bp_ln >= ln` is false, `bp_ln - ln` is negative, so the
            #   threshold clause can never trigger as written. Only breakpoints on lines
            #   *before* the variable get past this check, whereas the first pass collects
            #   lines before a breakpoint -- confirm the intended comparison direction.
            if bp_ln >= ln or (bp_ln - ln) > skip_line_threshold:
                continue
            # FIXME: change the implementation such that we define
            #   a scope if not defined and then we add the variables up in the runtime
            #   instead of storing them in the database
            if not detailed_local and bp_ln in ln_set:
                continue
            ln_set.add(bp_ln)
            # start a fresh scope when the previous group was declared more than
            # the threshold below the current one
            if pre_ln is not None and (pre_ln - ln) > skip_line_threshold:
                scope.clear()
            pre_ln = ln
            for entry in entries:
                scope.append(entry)
            for name, rtl_name in scope:
                var_id = var_id_mapping[rtl_name]
                db.store_context_variable(name, breakpoint_id, var_id)


def write_out_variables(db, variables, fix_variables, prefix=""):
    """Store generator variables for the instance currently being written.

    ``variables`` maps a source-level name to its RTL name. Both names get
    ``prefix`` prepended (the RTL name after keyword escaping), and every
    variable is registered under the current value of the global
    ``instance_count``. When ``fix_variables`` is set, ``variables`` is first
    renamed in place by ``fix_variable_name``.
    """
    global var_count, instance_count
    if fix_variables:
        fix_variable_name(variables)
    for source_name, rtl_signal in variables.items():
        generator_name = prefix + source_name
        scoped_rtl_name = prefix + escape_rtl_name(rtl_signal)
        db.store_variable(var_count, scoped_rtl_name)
        db.store_generator_variable(generator_name, instance_id=instance_count,
                                    variable_id=var_count)
        var_count += 1


def write_out_design(def_name, inst_name, design_data, db, fix_variable, detailed_def_names):
    """Recursively write one instance and all of its children into ``db``.

    Args:
        def_name: key into ``design_data`` naming the definition to write.
        inst_name: full hierarchical name of this instance.
        design_data: mapping of definition name to its parsed details
            (optional keys used here: ``variables``, ``instances``,
            ``breakpoints``, ``assignments``, ``locals`` and ``locals0``,
            ``locals1``, ...).
        db: symbol table the entries are stored into.
        fix_variable: whether ``fix_variable_name`` is applied to variables.
        detailed_def_names: patterns (or None) selecting definitions whose
            direct children's variables are also exposed on the parent, and
            whose locals use the wide line threshold.

    Raises:
        ValueError: if a breakpoint entry or its location string does not
            have 2 or 3 fields.
    """
    # this is recursive. no compression so far
    global var_count, breakpoint_id_count, instance_count
    db.store_instance(instance_count, inst_name)
    def_data = design_data[def_name]
    # create variable
    if "variables" in def_data:
        write_out_variables(db, def_data["variables"], fix_variable)
        if match_name(def_name, detailed_def_names) and "instances" in def_data:
            # the loop names must not reuse ``inst_name``: it is still needed
            # below to build the hierarchical names of the children
            for child_inst_name, child_def_name in def_data["instances"].items():
                child_def = design_data[child_def_name]
                if "variables" in child_def:
                    write_out_variables(db, child_def["variables"], fix_variable,
                                        prefix=child_inst_name + ".")

    # location tuple (as string tokens) -> breakpoint id, for this instance only
    breakpoint_index = {}

    if "breakpoints" in def_data:
        for entry in def_data["breakpoints"]:
            if len(entry) == 2:
                fn_info, cond = entry
                trigger = ""
            elif len(entry) == 3:
                fn_info, cond, trigger = entry
            else:
                raise ValueError(f"Breakpoint entry must have 2 or 3 fields, got {entry!r}")
            fn_info = fn_info.replace(":", " ")
            info = fn_info.split(" ")
            if len(info) == 2:
                fn, ln = info
                cn = 0
            elif len(info) == 3:
                fn, ln, cn = info
            else:
                raise ValueError(f"Breakpoint location must have 2 or 3 fields, got {fn_info!r}")
            db.store_breakpoint(breakpoint_id_count, instance_count, fn, int(ln), column_num=int(cn),
                                condition=cond, trigger=trigger)
            breakpoint_index[tuple(info)] = breakpoint_id_count
            breakpoint_id_count += 1
    if "assignments" in def_data:
        write_out_assignments(db, def_data["assignments"], breakpoint_index)

    # everything keyed by the instance id has been stored at this point
    instance_count += 1

    # handle local variables
    # notice that firrtl doesn't have scope function for now
    # so again, we have to play with heuristics
    # we can approximate the scope as a hack
    if "locals" in def_data:
        write_out_locals(def_name, def_data["locals"], db, breakpoint_index, fix_variable,
                         match_name(def_name, detailed_def_names))
    # could be scoped as well: locals0, locals1, ... until the first gap
    local_id = 0
    while True:
        local_name = "locals" + str(local_id)
        if local_name not in def_data:
            break
        write_out_local_scope(def_name, def_data[local_name], db, breakpoint_index, fix_variable)
        local_id += 1

    # recursively call the child instances
    if "instances" in def_data:
        for i_name, d_name in def_data["instances"].items():
            write_out_design(d_name, f"{inst_name}.{i_name}", design_data, db, fix_variable, detailed_def_names)


def main():
    """Convert the toml file named on the command line into an hgdb symbol table.

    Every top-level definition of the design is written in its own
    transaction.
    """
    args = get_args()
    design = parse_toml(args.input)
    top_names = get_tops(design)
    db = DebugSymbolTable(args.output)
    apply_firrtl_fix = not args.disable_firrtl_fix
    for name in top_names:
        db.begin_transaction()
        # a top-level instance is named after its definition
        write_out_design(name, name, design, db, apply_firrtl_fix,
                         detailed_def_names=args.detail)
        db.end_transaction()


# only run the conversion when executed as a script, not when imported
if __name__ == "__main__":
    main()
