#!/Users/mathewj2/repos/SPT/venv/bin/python
import argparse
import os
from os.path import join

import pandas as pd

import spatialprofilingtoolbox as spt
from spatialprofilingtoolbox.dataset_designs.multiplexed_imaging.halo_cell_metadata_provider import HALOCellMetadata
from spatialprofilingtoolbox.dataset_designs.multiplexed_imaging.halo_cell_metadata_design import HALOCellMetadataDesign
from spatialprofilingtoolbox.environment.file_io import get_outcomes_files


def do_aggregation(argv=None):
    """Aggregate all input cell data from an spt-pipeline run and subsample it.

    Command-line options:
        --max-per-sample  Maximum number of cells to draw from each
                          sample/image (default: 100).
        --omit-column     Name of a data column to omit (default: ``None``,
                          passed through unchanged to ``write_subsampled``).

    All other configuration (at least ``input_path`` and
    ``file_manifest_file``) comes from ``spt.get_config_parameters()``. That
    same parameter dict is also used to construct the
    ``HALOCellMetadataDesign``.

    Args:
        argv: Optional list of command-line argument strings. When ``None``
            (the default), argparse reads ``sys.argv[1:]``, which matches the
            original script behaviour.

    Raises:
        ValueError: If ``get_outcomes_files`` reports no outcomes file for the
            dataset settings.
    """
    parser = argparse.ArgumentParser(
        # Join with a space so the two sentence fragments don't run together
        # ("...run into..." rather than "...runinto...").
        description=' '.join([
            'This script aggregates all input cell data from an spt-pipeline run',
            'into a single batch, and does subsampling.',
        ])
    )
    parser.add_argument(
        '--max-per-sample',
        dest='max_per_sample',
        type=int,
        required=False,
        default=100,
        help='The maximum number of cells to draw from each sample/image.',
    )
    parser.add_argument(
        '--omit-column',
        dest='omit_column',
        type=str,
        required=False,
        default=None,
        help='Name of a data column to omit from the subsampled output.',
    )
    args = parser.parse_args(argv)

    parameters = spt.get_config_parameters()
    # Resolve the manifest path relative to the input directory. The resolved
    # value is what both the dataset design (via **parameters) and the cell
    # metadata object receive.
    parameters['file_manifest_file'] = join(
        parameters['input_path'],
        parameters['file_manifest_file'],
    )

    dataset_design = HALOCellMetadataDesign(**parameters)
    cell_data = HALOCellMetadata(
        input_files_path=parameters['input_path'],
        dataset_design=dataset_design,
        file_manifest_file=parameters['file_manifest_file'],
    )
    cell_data.initialize()

    outcomes_files = get_outcomes_files(dataset_design.dataset_settings)
    # Fail with a descriptive message rather than a bare IndexError when no
    # outcomes file is available.
    if not outcomes_files:
        raise ValueError(
            'No outcomes file was found for the dataset settings; '
            'cannot write subsampled cell data.'
        )
    # Only the first outcomes file is used.
    cell_data.write_subsampled(
        max_per_sample=args.max_per_sample,
        outcomes_file=outcomes_files[0],
        omit_column=args.omit_column,
    )

# Run the aggregation only when executed as a script, not on import.
if __name__ == '__main__':
    do_aggregation()
