from typing import Optional, Union

import click

from clinica import option
from clinica.pipelines import cli_param
from clinica.pipelines.engine import clinica_pipeline
from clinica.utils.atlas import T1AndPetVolumeAtlasName
from clinica.utils.pet import SUVRReferenceRegion, Tracer

# Name under which the command below is registered with click
# (passed as `name=` to `click.command`).
pipeline_name = "machinelearning-classification"


@clinica_pipeline
@click.command(name=pipeline_name)
@cli_param.argument.caps_directory
@cli_param.argument.group_label
@click.argument(
    "orig_input_data",
    type=click.Choice(["VoxelBased", "RegionBased"]),
)
@click.argument(
    "image_type",
    type=click.Choice(["T1w", "PET"]),
)
@click.argument(
    "algorithm",
    type=click.Choice(["DualSVM", "LogisticRegression", "RandomForest"]),
)
@click.argument(
    "validation",
    type=click.Choice(["RepeatedHoldOut", "RepeatedKFoldCV"]),
)
@click.argument(
    "subjects_visits_tsv",
    type=click.Path(exists=True, resolve_path=True),
)
@click.argument(
    "diagnoses_tsv",
    type=click.Path(exists=True, resolve_path=True),
)
@click.argument(
    "output_directory", type=click.Path(exists=False, writable=True, resolve_path=True)
)
@cli_param.option_group.pipeline_specific_options
@cli_param.option.acq_label
@cli_param.option.suvr_reference_region
@cli_param.option_group.option(
    "-atlas",
    "--atlas",
    type=click.Choice(T1AndPetVolumeAtlasName),
    help="One of the atlases generated by t1-volume or pet-volume pipeline.",
)
@option.global_option_group
@option.n_procs
def cli(
    caps_directory: str,
    group_label: str,
    orig_input_data: str,
    image_type: str,
    algorithm: str,
    validation: str,
    subjects_visits_tsv: str,
    diagnoses_tsv: str,
    output_directory: str,
    acq_label: Optional[Union[str, Tracer]] = None,
    suvr_reference_region: Optional[Union[str, SUVRReferenceRegion]] = None,
    atlas: Optional[str] = None,
    n_procs: Optional[int] = None,
) -> None:
    """Classification based on machine learning using scikit-learn.

    Parameters
    ----------
    caps_directory : str
        The CAPS directory handed to the selected workflow.

    group_label : str
        String defining the group label for the current analysis, which helps you keep track of different analyses.

    orig_input_data : str
        Defines the type of features for classification.
        It can be 'RegionBased' or 'VoxelBased'.

    image_type : str
        Defines the studied modality ('T1w' or 'PET')

    algorithm : str
        Defines the algorithm. It can be 'DualSVM', 'LogisticRegression' or 'RandomForest'.
        'LogisticRegression' and 'RandomForest' are only available with
        'RegionBased' features and 'RepeatedHoldOut' validation.

    validation : str
        Defines the validation method. It can be 'RepeatedHoldOut' or 'RepeatedKFoldCV'.

    subjects_visits_tsv : str
        TSV file containing the participant_id and the session_id columns.

    diagnoses_tsv : str
        TSV file where the diagnosis for each participant (identified by a participant ID) is reported (e.g. AD, CN).
        It allows the algorithm to perform the dual classification (between the two labels reported).

    output_directory : str
        The output folder path.

    acq_label : str, optional
        Acquisition label of the PET images. Required when image_type is 'PET'.

    suvr_reference_region : str, optional
        SUVR reference region of the PET images. Required when image_type is 'PET'.

    atlas : str, optional
        One of the atlases generated by t1-volume or pet-volume pipeline.
        Required when orig_input_data is 'RegionBased'; it is not passed to
        voxel-based workflows.

    n_procs : int, optional
        The number of processes to be used by the pipeline.

    Raises
    ------
    ClinicaException
        If a required option is missing for the selected inputs, or if the
        requested combination of features, validation and algorithm is not
        supported.

    Notes
    -----
    See https://aramislab.paris.inria.fr/clinica/docs/public/latest/Pipelines/MachineLearning_Classification/
    """
    # Imported lazily so that merely loading the CLI module stays cheap.
    from clinica.utils.exceptions import ClinicaException

    from .ml_workflows import (
        RegionBasedRepHoldOutDualSVM,
        RegionBasedRepHoldOutLogisticRegression,
        RegionBasedRepHoldOutRandomForest,
        RegionBasedRepKFoldDualSVM,
        VoxelBasedRepHoldOutDualSVM,
        VoxelBasedRepKFoldDualSVM,
    )

    # --- Validate options that are only mandatory for some inputs. ---
    if image_type == "PET":
        if not acq_label:
            raise ClinicaException(
                "You selected PET inputs without setting --acq_label flag. "
                "Clinica will now exit."
            )
        if not suvr_reference_region:
            raise ClinicaException(
                "You selected PET inputs without setting --suvr_reference_region flag. "
                "Clinica will now exit."
            )

    if orig_input_data == "RegionBased" and not atlas:
        raise ClinicaException(
            "You selected region-based inputs without setting --atlas flag. "
            "Clinica will now exit."
        )

    if algorithm in ("LogisticRegression", "RandomForest"):
        # Both conditions must hold: region-based features AND hold-out validation.
        if orig_input_data != "RegionBased" or validation != "RepeatedHoldOut":
            raise ClinicaException(
                "LogisticRegression or RandomForest algorithm can only work "
                "on region-based features with RepeatedHoldOut validation. "
                "Clinica will now exit."
            )

    # --- Select the workflow matching (features, validation, algorithm). ---
    workflows = {
        ("RegionBased", "RepeatedHoldOut", "DualSVM"): RegionBasedRepHoldOutDualSVM,
        ("RegionBased", "RepeatedKFoldCV", "DualSVM"): RegionBasedRepKFoldDualSVM,
        (
            "RegionBased",
            "RepeatedHoldOut",
            "LogisticRegression",
        ): RegionBasedRepHoldOutLogisticRegression,
        (
            "RegionBased",
            "RepeatedHoldOut",
            "RandomForest",
        ): RegionBasedRepHoldOutRandomForest,
        ("VoxelBased", "RepeatedHoldOut", "DualSVM"): VoxelBasedRepHoldOutDualSVM,
        ("VoxelBased", "RepeatedKFoldCV", "DualSVM"): VoxelBasedRepKFoldDualSVM,
    }
    workflow_class = workflows.get((orig_input_data, validation, algorithm))
    if workflow_class is None:
        raise ClinicaException(
            "Unknown combination of machine learning classification."
        )

    workflow_kwargs = dict(
        caps_directory=caps_directory,
        subjects_visits_tsv=subjects_visits_tsv,
        diagnoses_tsv=diagnoses_tsv,
        group_label=group_label,
        image_type=image_type,
        output_dir=output_directory,
        acq_label=acq_label,
        suvr_reference_region=suvr_reference_region,
        n_threads=n_procs,
    )
    # Only region-based workflows are given an atlas.
    if orig_input_data == "RegionBased":
        workflow_kwargs["atlas"] = atlas

    pipeline = workflow_class(**workflow_kwargs)
    pipeline.run()


# Invoke the click command when this module is executed directly as a script.
if __name__ == "__main__":
    cli()
