#!python
import argparse
import logging
import math
import sys
import struct
import os

import numpy as np
import pandas as pd
import tskit

### importing egrm
from egrm import varGRM, varGRM_C, Gmap


def get_logger(name, path=None):
    """
    Set up (or fetch) a logger that writes to the screen and, optionally, a file.

    Handlers are attached only the first time a given logger name is requested;
    later calls with the same name return the already-configured logger so that
    messages are not emitted more than once.

    :param: name str name of the logger, passed to ``logging.getLogger``
    :param: path optional str prefix of the log file; messages are written to
        ``<path>.log`` (truncated on open). No file is written when None.
    :returns: the configured ``logging.Logger``
    """
    logger = logging.getLogger(name)
    if logger.handlers:
        # Already configured by an earlier call; do not stack duplicate handlers.
        return logger

    # Prevent logging from propagating to the root logger
    logger.propagate = False

    formatter = logging.Formatter(
        fmt="[%(asctime)s - %(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    console = logging.StreamHandler()
    console.setFormatter(formatter)
    logger.addHandler(console)

    if path is not None:
        # FileHandler owns the underlying file, so it is flushed and closed by
        # logging.shutdown() instead of being left open for the process lifetime.
        disk_handler = logging.FileHandler("{}.log".format(path), mode="w")
        disk_handler.setFormatter(formatter)
        logger.addHandler(disk_handler)

    return logger


def write_gcta_bin(K, mu, ids, output):
    """
    Write out eGRM in GCTA binary format.

    Three files are produced from the ``output`` prefix:

    * ``<output>.grm.bin``: diagonal + lower triangle of ``K`` as 4-byte floats,
      row by row (``K[0,0], K[1,0], K[1,1], K[2,0], ...``)
    * ``<output>.grm.N.bin``: ``mu`` repeated once per entry of the file above,
      as 4-byte floats
    * ``<output>.grm.id``: tab-separated text with two columns, a family ID that
      is always 0 and the individual ID

    :param: K square numpy.ndarray of expected relatedness
    :param: mu floating point number of expected mutations
    :param: ids numpy.ndarray/list of individual IDs, at least one per row of K
    :param: output str prefix of the output files
    :returns: None
    :raises ValueError: if K is not a square 2-D matrix or there are fewer IDs
        than rows in K
    """
    K = np.asarray(K)
    # Validate before touching the disk so a bad call leaves no partial output.
    if K.ndim != 2 or K.shape[0] != K.shape[1]:
        raise ValueError(
            "K must be a square 2-D matrix, got shape {}".format(K.shape)
        )
    n = K.shape[0]
    if len(ids) < n:
        raise ValueError(
            "expected at least {} individual IDs, got {}".format(n, len(ids))
        )

    # tril_indices enumerates (row, col) pairs row by row with col <= row, which
    # is the diagonal + lower-triangle order written to the binary files.
    lower = K[np.tril_indices(n)].astype(np.float32)
    with open("{}.grm.bin".format(output), "wb") as grmfile:
        grmfile.write(lower.tobytes())

    counts = np.full(lower.shape, mu, dtype=np.float32)
    with open("{}.grm.N.bin".format(output), "wb") as grmfile:
        grmfile.write(counts.tobytes())

    # The file is opened in text mode, which already translates "\n" into the
    # platform line ending; writing os.linesep here would double the "\r" on
    # Windows.
    with open("{}.grm.id".format(output), "w") as grmfile:
        grmfile.writelines(
            "\t".join(["0", str(ids[idx])]) + "\n" for idx in range(n)
        )

    return


def _build_parser():
    """Build the argparse parser describing the command-line interface."""
    parser = argparse.ArgumentParser(
        description="Construct eGRM matrix from tree sequence data",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "input", type=argparse.FileType("r"), help="path to ts-kit tree sequence file"
    )
    parser.add_argument(
        "--output", "--o", default="egrm", type=str, help="output file prefix"
    )
    parser.add_argument(
        "--c-extension", "--c", help="acceleration by C extension", action="store_true"
    )
    parser.add_argument(
        "--skip-first-tree",
        "--sft",
        help="discard the first tree in the tree sequence",
        action="store_true",
    )
    parser.add_argument(
        "--run-var",
        "--var",
        default=False,
        help="compute varGRM in addition to eGRM",
        action="store_true",
    )
    parser.add_argument(
        "--genetic-map", "--map", type=str, default=None, help="map file fullname"
    )
    parser.add_argument(
        "--left",
        "--l",
        type=int,
        default=0,
        help="leftmost genomic position to be included",
    )
    parser.add_argument(
        "--right",
        "--r",
        type=int,
        default=math.inf,
        help="rightmost genomic position to be included",
    )
    parser.add_argument("--rlim", type=float, default=0, help="most recent time limit")
    parser.add_argument(
        "--alim", type=float, default=math.inf, help="most ancient time limit"
    )
    parser.add_argument(
        "--verbose",
        default=False,
        action="store_true",
        help="verbose logging. Includes debug info.",
    )
    parser.add_argument(
        "--haploid",
        default=False,
        action="store_true",
        help="output eGRM over haploids. Default is diploid/genotype eGRM.",
    )
    parser.add_argument(
        "--output-format",
        "--f",
        choices=["gcta", "numpy"],
        default="gcta",
        help="output format of eGRM",
    )
    return parser


def _collapse_to_diploid(egrm):
    """Combine each (even, odd) pair of haplotype rows/columns into one diploid entry."""
    n_haplotypes = egrm.shape[0]
    if n_haplotypes % 2 != 0:
        raise ValueError(
            f"cannot build a diploid eGRM from an odd number of haplotypes "
            f"({n_haplotypes}); use --haploid instead"
        )
    # Even indices are treated as one haplotype of each individual and odd
    # indices as the other; the four haplotype-pair blocks are summed and halved.
    return 0.5 * (
        egrm[0::2, 0::2] + egrm[0::2, 1::2] + egrm[1::2, 0::2] + egrm[1::2, 1::2]
    )


def main(args):
    """
    Command-line entry point: estimate an eGRM from a tree sequence and save it.

    Parses the arguments, loads the tree sequence with tskit, runs the Python or
    C-extension estimator, optionally collapses the haploid eGRM to diploid, and
    writes the result in numpy or GCTA format under the ``--output`` prefix. A
    log is written to ``<output>.log``.

    :param: args list of str command-line arguments (without the program name)
    :returns: int exit status, 0 on success
    :raises ValueError: if a diploid eGRM is requested but the estimated eGRM
        has an odd number of rows
    """
    # parse command line
    args = _build_parser().parse_args(args)

    # get_logger appends ".log" itself, so hand it the bare prefix; passing
    # "<output>.log" would create "<output>.log.log".
    log = get_logger(__name__, args.output)
    log.setLevel(logging.DEBUG if args.verbose else logging.INFO)

    # load tree sequence data
    tree_input = f"{args.input.name}"
    log.info(f"Beginning importing tree sequence at {tree_input}")
    trees = tskit.load(args.input)
    log.info(f"Finished importing tree sequence at {tree_input}")

    log.info("Constructing genetic map")
    gmap = Gmap(args.genetic_map)

    # determine which backend to use
    if args.c_extension:
        log.debug("Using C-extension to estimate eGRM")
        construct = varGRM_C
    else:
        log.debug("Using native Python to estimate eGRM")
        construct = varGRM

    # eGRM!!!
    log.info("Beginning eGRM estimation")
    egrm, vargrm, egrm_mu = construct(
        trees,
        log=None,
        rlim=args.rlim,
        alim=args.alim,
        left=args.left,
        right=args.right,
        gmap=gmap,
        var=args.run_var,
        sft=args.skip_first_tree,
    )
    log.info("Finished eGRM estimation")

    if not args.haploid:
        egrm = _collapse_to_diploid(egrm)

    # output/save our eGRM
    log.info("Beginning export of eGRM")
    if args.output_format == "numpy":
        np.save(args.output + ".npy", egrm)
        np.save(args.output + "_mu.npy", egrm_mu)
        if args.run_var:
            np.save(args.output + "_var.npy", vargrm)
    elif args.output_format == "gcta":
        # TODO: we need a robust way to map to IDs
        ids = [f"id_{idx}" for idx in range(egrm.shape[0])]
        write_gcta_bin(egrm, egrm_mu, ids, args.output)
        if args.run_var:
            # Export the variance matrix itself (not a second copy of the eGRM).
            # As in the numpy branch, vargrm is written as returned by the
            # estimator and is not collapsed to diploid, so its IDs are sized
            # from its own shape.
            var_ids = [f"id_{idx}" for idx in range(vargrm.shape[0])]
            write_gcta_bin(vargrm, egrm_mu, var_ids, args.output + "_var")
    log.info("Finished export of eGRM")
    log.info("Finished! :D")

    return 0


if __name__ == "__main__":
    # Run the CLI on everything after the program name and use main's return
    # value as the process exit status.
    exit_status = main(sys.argv[1:])
    sys.exit(exit_status)
