import sys
import time
from modules.python.TextColor import TextColor
from modules.python.CallConsensusInterface import call_consensus
from modules.python.StitchInterface import perform_stitch
from modules.python.FileManager import FileManager
"""
The Call Consensus method generates base predictions for images generated through MarginPolish. This script reads
hdf5 files generated by MarginPolish and produces another hdf5 file that holds all predictions. The generated hdf5 file
is given to stitch.py, which then stitches the segments using an alignment to produce a polished sequence.

The algorithm is described here:

  1) INPUTS:
    - directory path to the image files generated by MarginPolish
    - model path directing to a trained model
    - batch size for mini-batch prediction
    - num workers for mini-batch processing threads
    - output directory path to where the output hdf5 will be saved
    - gpu mode indicating if GPU will be used
  2) METHOD:
    - Call predict function that loads the neural network and generates base predictions and saves it into a hdf5 file
        - Loads the model
        - Iterates over the input images in minibatch
        - For each image, uses a sliding window method to slide over the image sequence
        - Aggregate the predictions to get sequence prediction for the entire image sequence
        - Save all the predictions to a file
  3) OUTPUT:
    - A hdf5 file containing all the base predictions   
"""


def get_elapsed_time_string(start_time, end_time):
    """
    Get a string representing the elapsed time given a start and end time.

    The interval is truncated to whole seconds and split into hours, minutes (0-59) and seconds (0-59).
    A negative interval (end before start) is reported as zero elapsed time.
    :param start_time: Start time in seconds (e.g. a time.time() value)
    :param end_time: End time in seconds (e.g. a time.time() value)
    :return: A string of the form "<H> HOURS <M> MINS <S> SECS."
    """
    # Clamp to zero so an end time earlier than the start time does not yield negative or wrapped fields.
    total_seconds = max(0, int(end_time - start_time))

    # Peel off whole hours first, then split what is left into minutes and seconds.
    hours, remainder = divmod(total_seconds, 3600)
    mins, secs = divmod(remainder, 60)

    time_string = str(hours) + " HOURS " + str(mins) + " MINS " + str(secs) + " SECS."

    return time_string


def polish_genome(image_dir, model_path, batch_size, num_workers, threads, output_dir, output_prefix, gpu_mode,
                  device_ids, callers):
    """
    Run the two polishing stages back to back: call_consensus followed by perform_stitch.

    Predictions are written to a timestamped "predictions_<run-id>" directory under the output directory, and that
    directory is then handed to the stitch step. Progress and timing information is written to stderr.
    :param image_dir: Path to directory where all MarginPolish images are saved.
    :param model_path: Path to a trained model.
    :param batch_size: Batch size for minibatch processing.
    :param num_workers: Number of workers for minibatch processing.
    :param threads: Number of threads for pytorch; also forwarded to the stitch step.
    :param output_dir: Path to the output directory.
    :param output_prefix: Prefix of the output HDF5 file.
    :param gpu_mode: If true, predict method will use GPU.
    :param device_ids: List of GPU devices.
    :param callers: Total number of callers to use.
    :return: None
    """
    def log_info(message):
        # All status lines go to stderr wrapped in the green colour codes.
        sys.stderr.write(TextColor.GREEN + message + TextColor.END)

    output_dir = FileManager.handle_output_directory(output_dir)

    # The timestamp doubles as the run identifier and as the suffix of the prediction directory.
    run_id = time.strftime("%m%d%Y_%H%M%S")
    prediction_output_directory = FileManager.handle_output_directory(
        output_dir + "/predictions_" + str(run_id) + "/"
    )

    log_info("INFO: RUN-ID: " + str(run_id) + "\n")
    log_info("INFO: PREDICTION OUTPUT DIRECTORY: " + str(prediction_output_directory) + "\n")

    # Stage 1: generate base predictions.
    consensus_began = time.time()
    log_info("INFO: CALL CONSENSUS STARTING\n")
    call_consensus(image_dir, model_path, batch_size, num_workers, threads,
                   prediction_output_directory, output_prefix, gpu_mode, device_ids, callers)
    consensus_finished = time.time()

    # Stage 2: stitch the predictions produced by stage 1.
    stitch_began = time.time()
    log_info("INFO: STITCH STARTING\n")
    print(prediction_output_directory)
    perform_stitch(prediction_output_directory, output_dir, output_prefix, threads)
    stitch_finished = time.time()

    # Overall time spans from the start of consensus calling to the end of stitching.
    call_consensus_time = get_elapsed_time_string(consensus_began, consensus_finished)
    stitch_time = get_elapsed_time_string(stitch_began, stitch_finished)
    overall_time = get_elapsed_time_string(consensus_began, stitch_finished)

    log_info("INFO: FINISHED PROCESSING.\n")
    log_info("INFO: TOTAL TIME ELAPSED: " + str(overall_time) + "\n")
    log_info("INFO: PREDICTION TIME: " + str(call_consensus_time) + "\n")
    log_info("INFO: STITCH TIME: " + str(stitch_time) + "\n")

