from os.path import join
import sys

sys.path.append(".")
from pkg.io import files_match_pattern


"""
This workflow performs single-sample or multi-sample segmentation of MSI data (provided as imzML files)

input:
    data folder must contain the following directories:
    - "msi": directory with imzML files named group_sampleno.imzML

output:
    - "msi/segmented": various visualizations of segmentation and UMAP
"""


def file_size_resources(wildcards, input_file):
    """Estimate the memory a job needs from the size of a file on disk.

    Args:
        wildcards: Wildcards of the job requesting resources. Unused, but
            kept so the signature can be called from a resources callable.
        input_file: Path of the file whose size drives the estimate.

    Returns:
        A dict ``{"mem_mb": <int>}`` holding ten times the file size in MiB,
        rounded up to a whole number.

    Raises:
        OSError: If ``input_file`` does not exist or cannot be inspected.
    """
    # Imported locally so the function does not depend on ``os`` happening
    # to be present in the surrounding namespace.
    import math
    import os

    size_in_bytes = os.path.getsize(input_file)
    size_in_mb = size_in_bytes / (1024 * 1024)

    # Reserve 10x the on-disk size. Resource values are expected to be
    # integers, so round up instead of handing back a float.
    return {"mem_mb": math.ceil(size_in_mb * 10)}


def get_input_all(msi_dir, files, multi_sample, input_list=None):
    """Collect the final targets requested by ``rule all``.

    Args:
        msi_dir: Directory that contains the ``segmented`` output folder.
        files: Sample names (without extension) substituted for ``{fl}``.
        multi_sample: If truthy, request the ``{fl}.png`` targets; otherwise
            request the ``{fl}_cluster_image.svg`` targets.
        input_list: Optional list to append to. It is modified in place when
            given; a new list is created when omitted.

    Returns:
        ``input_list`` with the result of ``expand(...)`` appended as a
        single element.
    """
    # Create the list per call: a mutable default would be shared between
    # calls and keep accumulating targets from earlier invocations.
    if input_list is None:
        input_list = []

    suffix = '.png' if multi_sample else '_cluster_image.svg'
    input_list.append(expand(msi_dir + '/segmented/{fl}' + suffix, fl=files))
    return input_list


configfile: 'data/config.yaml'

# imzML files must be named <group>_<sampleno>.imzML, where both parts are
# purely alphanumeric. The character class is spelled A-Z on purpose: the
# range A-z would also admit the punctuation between 'Z' and 'a'
# ([ \ ] ^ _ `), letting names with extra underscores slip through.
imzML_pattern = r"^[a-zA-Z0-9]+_[a-zA-Z0-9]+\.imzML$"

MSI_SEG_SCRIPT_PATH = 'msi_segmentation_flow/scripts'
DATA_PATH = config['data']
# BENCHMARK_PATH = join(DATA_PATH, 'benchmark')
# BM_REPEAT = 5
MSI_DIR = join(DATA_PATH, 'msi')
# Presumably the names of the files in MSI_DIR that match the pattern --
# confirm against pkg.io.files_match_pattern.
MSI_FILES = files_match_pattern(MSI_DIR, imzML_pattern)
# Drop the extension; the pattern allows no '.' before '.imzML', so the
# first dot-separated field is the whole sample name.
MSI_FILES = [f.split('.')[0] for f in MSI_FILES]
MULTI = config['general']['multi_sample']

# Default target: requests the segmentation images of every sample in
# MSI_FILES, in the single- or multi-sample flavour selected by MULTI.
rule all:
    input:
        get_input_all(MSI_DIR, MSI_FILES, MULTI)

# Runs single_sample_segmentation.py once per imzML file and produces that
# sample's cluster image in MSI_DIR/segmented.
rule single_sample_segmentation:
    input: MSI_DIR + '/{fl}.imzML'
    output: MSI_DIR + '/segmented/{fl}_cluster_image.svg'
    params:
        # Passed through params so it can be shell-quoted like {input:q};
        # a data path containing spaces would otherwise split the argument.
        result_dir=MSI_DIR + '/segmented'
    resources:
        # Memory request is derived from the size of the sample's .ibd file
        # (presumably the binary companion of the imzML file -- confirm).
        mem_mb=lambda wildcards: file_size_resources(wildcards,input_file="{}/{}.ibd".format(MSI_DIR, wildcards.fl))["mem_mb"]
    # benchmark: repeat(BENCHMARK_PATH + "/{fl}.single_sample_segmentation.benchmark.txt",BM_REPEAT)
    shell:
        'python ' + MSI_SEG_SCRIPT_PATH + '/single_sample_segmentation.py {input:q} -result_dir {params.result_dir:q} '
        '-method {config[dim_reduction][method]} -n_components {config[dim_reduction][n_components]} '
        '-metric {config[umap][metric]} -n_neighbors {config[umap][n_neighbors]} -min_dist {config[umap][min_dist]} '
        '-cluster {config[clustering][method]} -n_clusters {config[clustering][n_clusters]} '
        '-min_cluster_size {config[hdbscan][min_cluster_size]} -min_samples {config[hdbscan][min_samples]} '
        '-dot_size {config[general][dot_size]}'

# Runs multi_sample_segmentation.py once over the whole MSI directory and
# produces one PNG per sample in MSI_DIR/segmented.
rule multi_sample_segmentation:
    input: expand(MSI_DIR + '/{fl}.imzML', fl=MSI_FILES)
    output: expand(MSI_DIR + '/segmented/{fl}.png', fl=MSI_FILES)
    params:
        # Passed through params so both directories can be shell-quoted;
        # a data path containing spaces would otherwise split the arguments.
        msi_dir=MSI_DIR,
        result_dir=MSI_DIR + '/segmented'
    # benchmark: repeat(BENCHMARK_PATH + "/multi_sample_segmentation.benchmark.txt",BM_REPEAT)
    shell:
        'python ' + MSI_SEG_SCRIPT_PATH + '/multi_sample_segmentation.py {params.msi_dir:q} -result_dir '
        '{params.result_dir:q} -method {config[dim_reduction][method]} -dist_metric {config[umap][metric]} '
        '-n_neighbors {config[umap][n_neighbors]} -min_dist {config[umap][min_dist]} '
        '-clustering_method {config[clustering][method]} -min_cluster_size {config[hdbscan][min_cluster_size]} '
        '-min_samples {config[hdbscan][min_samples]} -n_clusters {config[clustering][n_clusters]} '
        '-dot_size {config[general][dot_size]} -batch_correct {config[general][batch_correct]}'