# Snakefile
import pandas as pd
import os
from pathlib import Path

# Load the workflow configuration and derive the module-level settings
# that the helper functions and rules below rely on.
configfile: "config.yaml"

# Settings read straight from the config file.
outdir = Path(config["outdir"])
chrom_list = config["chrom_list"]

# Sample sheet (tab-separated); only rows flagged singlecell == "Y" are processed.
samples = pd.read_csv(config["metadata_path"], sep="\t")
sample_ids = samples.loc[samples["singlecell"] == "Y", "bamID"].tolist()

def check_bins_exist(sample, chrom_list):
    """Report whether every chromosome of a sample has a non-empty bin file.

    Looks under ``outdir / "bins" / sample`` for one ``<chrom>.bin`` file per
    entry of ``chrom_list``.

    Returns:
        True only if the sample's bins directory exists and each expected
        bin file exists with a size greater than zero; False otherwise.
    """
    sample_bins = outdir / "bins" / sample
    if not sample_bins.exists():
        return False

    expected = (sample_bins / f"{chrom}.bin" for chrom in chrom_list)
    # all() short-circuits on the first missing or empty file.
    return all(
        path.exists() and path.stat().st_size != 0
        for path in expected
    )

# Define outputs for normalization
def get_normalization_output():
    """Return the flat list of target files for the normalization step.

    The list contains, for every sample in ``sample_ids``:
      * the normalization and segmentation config files, and
      * one read-position file and one binned file per chromosome in
        ``chrom_list``.

    Returns:
        A flat ``list[str]`` of paths under ``outdir``.
    """
    targets = []

    # Per-sample config files.
    # NOTE: extend() is given the expanded paths directly so the result is a
    # flat list of strings rather than a list of lists.
    targets.extend(expand(str(outdir / "cfg/{sample}.cfg"), sample=sample_ids))
    targets.extend(
        expand(str(outdir / "segcfg/{sample}.seg.cfg"), sample=sample_ids)
    )

    # Per-sample, per-chromosome read positions
    targets.extend(
        expand(
            str(outdir / "readpos/{sample}/{chrom}.readpos.seq"),
            sample=sample_ids,
            chrom=chrom_list,
        )
    )

    # Per-sample, per-chromosome binned data
    targets.extend(
        expand(
            str(outdir / "bins/{sample}/{chrom}.bin"),
            sample=sample_ids,
            chrom=chrom_list,
        )
    )
    return targets

# Default rule - run normalization.
# Requests every file returned by get_normalization_output(); the rule has
# no outputs or commands of its own and only serves as the overall target.
rule all:
    input:
        get_normalization_output()

# Write the two per-sample, tab-separated config tables:
#   cfg    - one row per chromosome with the fasta, mappability, read-position
#            and normalized-bin file paths (consumed by run_bicseq_norm)
#   segcfg - one row per chromosome with the normalized-bin file path
rule create_config_files:
    output:
        cfg=str(outdir / "cfg/{sample}.cfg"),
        segcfg=str(outdir / "segcfg/{sample}.seg.cfg")
    run:
        # The normalized bin path appears in both tables, so compute it once.
        bin_paths = {
            chrom: str(outdir / f"bins/{wildcards.sample}/{chrom}.bin")
            for chrom in chrom_list
        }

        # Normalization config: one row per chromosome.
        norm_columns = ['chrom_name', 'fa_file', 'mappability',
                        'readPosFile', 'bin_file_normalized']
        norm_rows = [
            [
                chrom,
                f"{config['fasta_folder']}/{chrom}.fasta",
                f"{config['mappability_folder_stem']}chr{chrom}.txt",
                str(outdir / f"readpos/{wildcards.sample}/{chrom}.readpos.seq"),
                bin_paths[chrom],
            ]
            for chrom in chrom_list
        ]
        norm_table = pd.DataFrame(norm_rows, columns=norm_columns)
        norm_table.to_csv(output.cfg, index=None, sep='\t')

        # Segmentation config: chromosome name plus its normalized bin file.
        seg_rows = [[chrom, bin_paths[chrom]] for chrom in chrom_list]
        seg_table = pd.DataFrame(seg_rows, columns=['chromName', 'binFileNorm'])
        seg_table.to_csv(output.segcfg, index=None, sep='\t')

# Extract read positions from BAM files.
# For one sample and one chromosome, writes the 1-based alignment start
# (SAM column 4) of every retained read, one position per line.
#   -q 30   keep reads with mapping quality >= 30
#   -F 1284 drop unmapped (4), secondary (256) and duplicate (1024) reads
# The BAM path is looked up in the sample sheet by bamID.
# NOTE(review): querying a region with `samtools view` needs an indexed BAM;
# the index is not declared as an input here - confirm it always exists.
rule get_read_pos:
    input:
        bam=lambda wildcards: samples.set_index("bamID").loc[wildcards.sample, "bam"]
    output:
        readpos=str(outdir / "readpos/{sample}/{chrom}.readpos.seq")
    log:
        str(outdir / "logs/get_read_pos_{sample}_{chrom}.log")
    # The pipeline runs in a subshell so that stderr of BOTH samtools and perl
    # is captured in the log; a trailing `2>` on the perl command alone would
    # leave samtools errors out of the log file.
    shell:
        """
        mkdir -p $(dirname {output.readpos})
        (samtools view -q 30 -F 1284 {input.bam} {wildcards.chrom} | \
            perl -ane 'print $F[3], "\\n";' > {output.readpos}) 2> {log}
        """

# Run normalization using BICseq.
# Runs the normalization executable once per sample, driven by the sample's
# cfg file (written by create_config_files). That cfg file lists, per
# chromosome, the read-position file to read and the normalized bin file
# path, which is why the read-position files are declared as inputs and the
# bin files as outputs even though neither appears on the command line.
rule run_bicseq_norm:
    input:
        config=str(outdir / "cfg/{sample}.cfg"),
        # All per-chromosome read-position files of this sample.
        readpos=lambda wildcards: expand(
            str(outdir / "readpos/{sample}/{chrom}.readpos.seq"),
            sample=[wildcards.sample],
            chrom=chrom_list
        )
    output:
        # {{sample}} stays a wildcard; only {chrom} is expanded here.
        bins=expand(
            str(outdir / "bins/{{sample}}/{chrom}.bin"),
            chrom=chrom_list
        ),
        # Output file given to the tool as its final positional argument.
        temp=str(outdir / "temp/{sample}.temp")
    params:
        binsize=config["binsize"],
        bicseq_norm=config["bicseq_norm"],
        # Optional config keys; the fallbacks are 150 and 300.
        read_length=config.get("read_length", 150),
        fragment_size=config.get("fragment_size", 300),
        tmp_dir=str(outdir / "temp")
    log:
        str(outdir / "logs/bicseq_norm_{sample}.log")
    # read_length and fragment_size are passed explicitly (-l / -s) so that
    # the configured values are used instead of the tool's built-in defaults.
    shell:
        """
        mkdir -p {params.tmp_dir}
        {params.bicseq_norm} \
            -l={params.read_length} \
            -s={params.fragment_size} \
            -b={params.binsize} \
            --gc_bin \
            -p=0.0001 \
            --tmp={params.tmp_dir} \
            {input.config} \
            {output.temp} 2> {log}
        """