# Base directory of the workflow as reported by Snakemake; the bundled
# resources below are all resolved relative to it.
SNAKE_DIR = Path(workflow.basedir)
TEMPLATE_DIR = SNAKE_DIR.joinpath("templates")
SCRIPT_DIR = SNAKE_DIR.joinpath("scripts")
ENVS_DIR = SNAKE_DIR.joinpath("envs")

# Root directory for everything the workflow writes (from the config).
OUTDIR = Path(config["output_dir"])


from math import ceil
from beast_skygrid.workflow.utils import taxa_from_fasta, decimal_year_to_date

import sys

# Parse the taxa (each exposing a `.date`) from the configured alignment.
taxa = taxa_from_fasta(config["alignment"])
if len(taxa) == 0:
    # Fail with an explicit message instead of the opaque
    # "max() arg is an empty sequence" raised further down.
    raise ValueError(f"No taxa found in alignment {config['alignment']!r}")

# Diagnostics go to stderr: anything a Snakefile prints to stdout ends up in
# machine-readable output such as `snakemake --dag | dot`.
print(f"Found {len(taxa)} taxa in alignment", file=sys.stderr)
most_recent_sampling_date = max(taxon.date for taxon in taxa)
print(
    f"Sample dates range from {min(taxon.date for taxon in taxa)} to {most_recent_sampling_date}",
    file=sys.stderr,
)

# Aggregate target rule: requests the skygrid plot and the rendered MCC tree.
# It is the first rule in this file, so it is presumably the default target
# when snakemake is invoked without an explicit target.
rule all:
    input:
        OUTDIR / "skygrid.svg",
        OUTDIR / "beast" / "skygrid.mcc.tree.svg",

# Pull in additional rules. `rules.treetime`, referenced further down, is not
# defined in this file — presumably it comes from this include (TODO confirm).
include: "rules/rttr.smk"

rule beast:
    """
    Runs BEAST on the generated skygrid XML, producing the parameter log and the sampled trees.
    """
    input:
        beast_XML_file = OUTDIR / "beast" / "skygrid.xml",
    output:
        log_file = OUTDIR / "beast" / "skygrid.log",
        trees_file = OUTDIR / "beast" / "skygrid.trees",
    params:
        # extra command-line options from the config; inserted unquoted on
        # purpose so that several flags can be supplied in one string
        beast = config["beast_params"],
    log:
        OUTDIR / "beast" / "skygrid.out",
    conda:
        # same environment file, referenced the same way, as the treeannotator rule
        ENVS_DIR / "beast.yaml",
    shell:
        # -working makes BEAST run in the XML file's directory; paths are
        # shell-quoted (:q) so directories containing spaces do not break the call
        """
        beast \
        {params.beast} \
        -working \
        {input.beast_XML_file:q} \
        > {log:q} 2>&1
        """

# Turn the comma-separated `constant_sites` config value into the
# space-separated form that is later passed on via --constant-sites.
# Stays None when the option is unset or empty.
beast_constant_sites = (
    config.get("constant_sites").replace(",", " ")
    if config.get("constant_sites")
    else None
)

def get_cutoff():
    """
    Return the cutoff passed to the skygrid template.

    A truthy ``cutoff`` in the config is returned unchanged. Otherwise the
    cutoff is estimated as the most recent sampling date minus the value in
    the last tab-separated column of the second line of treetime's dates
    file (presumably the root / tMRCA date — verify against the treetime
    output format).

    When ``discard_outliers`` is set, the most recent sampling date is
    recomputed from the filtered alignment; otherwise the module-level
    ``most_recent_sampling_date`` is used.
    """
    if config.get("cutoff"):
        return config.get("cutoff")
    with open(rules.treetime.output.dates) as f:
        # skip the first line; the value of interest is on the second one
        f.readline()
        tmrca = f.readline().split("\t")[-1].strip()
    # Use a distinct local name here. Assigning to `most_recent_sampling_date`
    # inside this function would make that name local for the whole function,
    # so the module-level value could no longer be read and the non-filtered
    # path would fail with UnboundLocalError.
    latest_sampling_date = most_recent_sampling_date
    if config.get("discard_outliers"):
        filtered_taxa = taxa_from_fasta(OUTDIR / "filtered.fasta")
        latest_sampling_date = max(taxon.date for taxon in filtered_taxa)
    return latest_sampling_date - float(tmrca)

def get_dimensions():
    """
    Return the skygrid dimension count as an int.

    The cutoff from ``get_cutoff()`` is rounded up to a whole number and
    multiplied by the configured ``transition_points_per_year``.
    """
    whole_cutoff = ceil(get_cutoff())
    points_per_year = config.get("transition_points_per_year")
    return int(whole_cutoff * points_per_year)


rule discard_outliers:
    """
    Removes the sequences listed in treetime's outliers file from the alignment.
    """
    input:
        alignment = config["alignment"],
        outliers = rules.treetime.output.outliers,
    output:
        alignment = OUTDIR / "filtered.fasta",
    conda:
        ENVS_DIR / "seqkit.yaml",
    shell:
        # The name list lives in a mktemp file that is removed on exit, so no
        # stray file is left beside the input if a command fails. `wc -l < file`
        # prints only the count, without the file name.
        """
        outlier_names=$(mktemp)
        trap 'rm -f "$outlier_names"' EXIT
        # extract the first column of the outliers file
        cut -f 1 {input.outliers:q} > "$outlier_names"
        echo "Removing $(wc -l < "$outlier_names" | tr -d '[:space:]') outliers from the alignment"
        cat "$outlier_names"
        # remove the outliers from the alignment
        seqkit grep -v -n -f "$outlier_names" {input.alignment:q} > {output.alignment:q}
        """


# Fills the skygrid Jinja template via scripts/populate_skygrid_template.py to
# produce the BEAST XML that the `beast` rule consumes.
rule create_beast_xml:
    input:
        # the outlier-filtered alignment when `discard_outliers` is set, else the raw one
        alignment = rules.discard_outliers.output.alignment if config.get("discard_outliers") else config["alignment"], 
        # not used in the shell command; get_cutoff() reads this file when the
        # params below are evaluated
        treetime_dates = rules.treetime.output.dates,
    output:
        beast_XML_file = OUTDIR / "beast" / "skygrid.xml",
    params:
        template = TEMPLATE_DIR / "skygrid.jinja.xml",
        # callables, so they are evaluated per job rather than when the
        # Snakefile is parsed
        dimensions = lambda w: get_dimensions(),
        cutoff = lambda w: get_cutoff(),
        # NOTE(review): these use config.get() without a default, so a missing
        # key is rendered as the literal "None" on the command line
        clock = config.get("clock"),
        chain_length = config.get("chain_length"),
        samples = config.get("samples"),
        # whole option string, or empty when no constant sites are configured
        constant_sites = f'--constant-sites "{beast_constant_sites}"' if beast_constant_sites  else "",
    shell:
        """
        python {SCRIPT_DIR}/populate_skygrid_template.py \
            {params.template} \
            {input.alignment} \
            --dimensions {params.dimensions} \
            --cutoff {params.cutoff} \
            --output {output.beast_XML_file} \
            --clock {params.clock} \
            --chain-length {params.chain_length} \
            --samples {params.samples} \
            {params.constant_sites}
        """


rule max_clade_credibility_tree:
    """
    Summarises the trees sampled by BEAST into a maximum clade credibility (MCC) tree using treeannotator.
    """
    input:
        rules.beast.output.trees_file,
    output:
        tree = OUTDIR / "beast" / "skygrid.mcc.tree",
    params:
        # number of leading trees to discard: the configured `burnin` fraction
        # (default 0.1) of the configured `samples`, truncated to an int
        burnin_trees = int(int(config['samples']) * config.get("burnin", 0.1)),
    conda:
        ENVS_DIR / "beast.yaml"
    shell:
        # paths are shell-quoted (:q) so directories containing spaces work
        """
        treeannotator -burninTrees {params.burnin_trees} -heights keep {input:q} {output.tree:q}
        """

rule max_clade_credibility_tree_render:
    """
    Renders the MCC tree in SVG format.
    """
    input:
        rules.max_clade_credibility_tree.output.tree,
    output:
        # none of these paths is passed to the script directly; presumably
        # plotMccTree.R derives them from --output-prefix (TODO confirm)
        svg = OUTDIR / "beast" / "skygrid.mcc.tree.svg",
        hpd = OUTDIR / "beast" / "skygrid.mcc.tree.height_0.95_HPD.svg",
        posterior = OUTDIR / "beast" / "skygrid.mcc.tree.posterior.svg",
    params:
        # NOTE(review): computed from the unfiltered alignment when the
        # Snakefile is parsed. With `discard_outliers` set, get_cutoff()
        # recomputes the date from filtered.fasta but this value does not —
        # confirm the two cannot disagree.
        mrsd = decimal_year_to_date(most_recent_sampling_date),
        prefix = OUTDIR / "beast" / "skygrid.mcc.tree",
    conda:
        ENVS_DIR / "ggtree.yaml"
    shell:
        # ${{CONDA_PREFIX}} is brace-escaped so the shell, not Snakemake, expands it
        "${{CONDA_PREFIX}}/bin/Rscript {SCRIPT_DIR}/plotMccTree.R --input {input} --output-prefix {params.prefix} --mrsd {params.mrsd}"

rule plot_skygrid:
    """
    Plots the skygrid from the BEAST output.
    """
    input:
        log = rules.beast.output.log_file,
    output:
        # presumably plotSkygrid.R appends these extensions to the --output
        # prefix (TODO confirm against the script)
        svg = OUTDIR / "skygrid.svg",
        pdf = OUTDIR / "skygrid.pdf",
        png = OUTDIR / "skygrid.png",
    conda:
        ENVS_DIR / "ggtree.yaml"
    params:
        # NOTE(review): taken from the unfiltered alignment when the Snakefile
        # is parsed, and passed as-is (no decimal_year_to_date conversion,
        # unlike the MCC tree rendering rule). With `discard_outliers` set it
        # may differ from the date get_cutoff() uses — confirm.
        mrsd = most_recent_sampling_date,
        output = OUTDIR / "skygrid",
    shell:
        """
        ${{CONDA_PREFIX}}/bin/Rscript {SCRIPT_DIR}/plotSkygrid.R --input {input.log} --output {params.output} --mrsd {params.mrsd} 
        """