#!/usr/bin/env python3
try:
    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
except ImportError:
    rank = 0
from teca import *
import sys
import argparse
import numpy as np

# start the profiler and open the 'deeplab ar_detect' event; the event is
# closed by the matching end_event call at the end of this script
teca_profiler.initialize()
teca_profiler.start_event('deeplab ar_detect')

# Define the command line interface.
#
# The formatter appends each option's default value to its help text and
# keeps the help column narrow (max_help_position=4) so the long
# descriptions wrap cleanly within 100 columns.
parser = argparse.ArgumentParser(
    formatter_class=lambda prog: argparse.ArgumentDefaultsHelpFormatter(
                       prog, max_help_position=4, width=100))

# --- input selection: exactly one of these two is expected ---
parser.add_argument('--input_file', type=str, required=False,
    help='a teca_multi_cf_reader configuration file identifying the set'
         ' of NetCDF CF2 files to process. When present data is read'
         ' using the teca_multi_cf_reader. Use one of either --input_file'
         ' or --input_regex.')

parser.add_argument('--input_regex', type=str, required=False,
    help='a teca_cf_reader regex identifying the'
         ' set of NetCDF CF2 files to process. When present data is read'
         ' using the teca_cf_reader. Use one of either --input_file or'
         ' --input_regex.')

# --- IVT magnitude and vector options ---
parser.add_argument('--ivt', type=str, required=False, default='IVT',
    help='name of variable with integrated vapor transport magnitude')

parser.add_argument('--compute_ivt_magnitude', action='store_true',
    help='when this flag is present magnitude of vector IVT is calculated.'
         ' use --ivt_u and --ivt_v to set the name of the IVT vector'
         ' components if needed.')

parser.add_argument('--ivt_u', type=str, required=False, default='IVT_U',
    help='name of variable with longitudinal component of the integrated vapor'
         ' transport vector.')

parser.add_argument('--ivt_v', type=str, required=False, default='IVT_V',
    help='name of variable with latitudinal component of the integrated vapor'
         ' transport vector.')

parser.add_argument('--write_ivt_magnitude', action='store_true',
    help='when this flag is present IVT magnitude is written to disk with the'
         ' AR detector results')

parser.add_argument('--compute_ivt', action='store_true',
    help='when this flag is present IVT vector is calculated from'
         ' specific humidity, and wind vector components. use'
         ' --specific_humidity --wind_u and --wind_v to set the name of the'
         ' specific humidity and wind vector components, and --ivt_u and'
         ' --ivt_v to control the names of the results, if needed.')

parser.add_argument('--specific_humidity', type=str, required=False,
    default='Q', help='name of variable with the 3D specific humidity field.')

parser.add_argument('--wind_u', type=str, required=False, default='U',
    help='name of variable with the 3D longitudinal component of the wind'
         ' vector.')

parser.add_argument('--wind_v', type=str, required=False, default='V',
    help='name of variable with the 3D latitudinal component of the wind'
         ' vector.')

parser.add_argument('--write_ivt', action='store_true', required=False,
    help='when this flag is present IVT vector is written to disk with'
         ' the result')

# --- coordinate axis names ---
parser.add_argument('--x_axis_variable', type=str, default='lon',
    required=False, help='name of x coordinate variable')

parser.add_argument('--y_axis_variable', type=str, default='lat',
    required=False, help='name of y coordinate variable')

parser.add_argument('--z_axis_variable', type=str, default='plev',
    required=False, help='name of z coordinate variable')

# --- output files ---
# note: '%%' is an escaped literal '%' in argparse help strings
parser.add_argument('--output_file', type=str, required=True,
    help='A path and file name pattern for the output NetCDF files. %%t%% is'
         ' replaced with a human readable date and time corresponding to the'
         ' time of the first time step in the file. Use --date_format to change'
         ' the formatting')

parser.add_argument('--file_layout', type=str, default='monthly',
    help='Selects the size and layout of the set of output'
         ' files. May be one of number_of_steps, daily,'
         ' monthly, seasonal, or yearly. Files are structured'
         ' such that each file contains one of the selected'
         ' interval. For the number_of_steps option use'
         ' --steps_per_file.')

parser.add_argument('--steps_per_file', type=int, required=False, default=128,
    help='number of time steps per output file')

# --- execution control ---
parser.add_argument('--target_device', type=str, default='cpu',
    help='set the execution target. May be one of "cpu", or "cuda"')

parser.add_argument('--n_threads', type=int, default=-1,
    help='Sets the thread pool size on each MPI rank. When the default'
         ' value of -1 is used TECA will coordinate the thread pools across'
         ' ranks such that each thread is bound to a unique physical core.')

parser.add_argument('--n_threads_max', type=int, default=4,
    help='Sets the max thread pool size on each MPI rank. Set to -1'
         ' to use all available cores.')

# --- detector options ---
parser.add_argument('--binary_ar_threshold', type=float,
    default=(2.0/3.0), help='probability threshold for segmenting'
    ' ar_probability to produce ar_binary_tag')

parser.add_argument('--pytorch_model', type=str, required=False,
    help='path to the pytorch model file')

# --- time axis overrides ---
parser.add_argument('--t_axis_variable', type=str, required=False,
    help='time dimension name')

parser.add_argument('--calendar', type=str, required=False,
    help='time calendar')

parser.add_argument('--t_units', type=str, required=False,
    help='time unit')

parser.add_argument('--filename_time_template', type=str, required=False,
    help='filename time template')

parser.add_argument('--date_format', type=str, required=False,
    help='A strftime format used when encoding dates into the output'
         ' file names (%%F-%%HZ). %%t%% in the file name is replaced with date/time'
         ' of the first time step in the file using this format specifier.')

# --- subsetting in time, by step index or by date ---
parser.add_argument('--first_step', type=int, default=0, required=False,
    help='first time step to process')

parser.add_argument('--last_step', type=int, default=-1, required=False,
    help='last time step to process')

parser.add_argument('--start_date', type=str, required=False,
    help='first time to process in "YYYY-MM-DD hh:mm:ss" format')

parser.add_argument('--end_date', type=str, required=False,
    help='end time to process in "YYYY-MM-DD hh:mm:ss" format')

parser.add_argument('--verbose', action='store_true',
    help='Enable verbose output')


# Parse the command line. To prevent spew when running under MPI only rank 0
# re-raises an unexpected exception, so a single traceback is printed. The
# other ranks exit quietly instead of falling through with `args` undefined.
# Note: argparse reports usage errors (and --help) by raising SystemExit,
# which is not an Exception subclass and therefore passes through this
# handler unchanged on every rank.
try:
    args = parser.parse_args()
except Exception:
    if rank == 0:
        raise
    sys.exit(-1)

# Configure the reader. Exactly one of --input_file (multi_cf_reader
# configuration) or --input_regex (cf_reader regex) selects the reader type.
if args.input_file and not args.input_regex:
    reader = teca_multi_cf_reader.New()
    reader.set_input_file(args.input_file)
elif args.input_regex and not args.input_file:
    reader = teca_cf_reader.New()
    reader.set_files_regex(args.input_regex)
else:
    # Only rank 0 reports the error to avoid one traceback per MPI rank.
    # The remaining ranks must stop as well: without a reader they would
    # otherwise fail on the next statement with a NameError.
    if rank == 0:
        raise RuntimeError('Exactly one of --input_file or --input_regex'
                           ' must be provided')
    sys.exit(-1)

reader.set_x_axis_variable(args.x_axis_variable)
reader.set_y_axis_variable(args.y_axis_variable)

# the remaining reader properties are only overridden when the user
# supplied a value on the command line
if args.t_axis_variable is not None:
    reader.set_t_axis_variable(args.t_axis_variable)

if args.calendar:
    reader.set_calendar(args.calendar)

if args.t_units:
    reader.set_t_units(args.t_units)

if args.filename_time_template:
    reader.set_filename_time_template(args.filename_time_template)

# `head` tracks the most downstream stage of the pipeline built so far
head = reader

# optional stage: compute the IVT vector from specific humidity and wind
if args.compute_ivt:
    # the z axis is only configured on the reader for this case
    reader.set_z_axis_variable(args.z_axis_variable)

    ivt_stage = teca_integrated_vapor_transport.New()
    ivt_stage.set_wind_u_variable(args.wind_u)
    ivt_stage.set_wind_v_variable(args.wind_v)
    ivt_stage.set_specific_humidity_variable(args.specific_humidity)
    ivt_stage.set_ivt_u_variable(args.ivt_u)
    ivt_stage.set_ivt_v_variable(args.ivt_v)
    ivt_stage.set_input_connection(reader.get_output_port())

    # the integrator becomes the new end of the pipeline
    head = ivt_stage

# optional stage: compute the IVT magnitude from the vector components. This
# is needed whenever the vector was just computed or the user asked for it.
need_magnitude = args.compute_ivt or args.compute_ivt_magnitude
if need_magnitude:
    norm_stage = teca_l2_norm.New()
    norm_stage.set_component_0_variable(args.ivt_u)
    norm_stage.set_component_1_variable(args.ivt_v)
    norm_stage.set_l2_norm_variable(args.ivt)
    norm_stage.set_input_connection(head.get_output_port())

    # the norm becomes the new end of the pipeline
    head = norm_stage

# coordinate normalization
# inserted between the upstream stages (reader, optional IVT and norm) and
# the detector
coords = teca_normalize_coordinates.New()
coords.set_input_connection(head.get_output_port())

# ar detector
# consumes the variable named by --ivt and is configured from the execution
# related command line options
ar_detect = teca_deeplab_ar_detect.New()
ar_detect.set_input_connection(coords.get_output_port())
ar_detect.set_verbose(args.verbose)
ar_detect.set_ivt_variable(args.ivt)
ar_detect.set_target_device(args.target_device)
ar_detect.set_thread_pool_size(args.n_threads)
ar_detect.set_max_thread_pool_size(args.n_threads_max)
# NOTE(review): --pytorch_model is optional with no default, so None may be
# passed here -- confirm load_model handles that case
ar_detect.load_model(args.pytorch_model)

# post detection segmentation
# attributes attached to the segmentation variable, including a note that
# records the threshold used
seg_atts = teca_metadata()
for att_name, att_value in (
        ("long_name", "binary indicator of atmospheric river"),
        ("description", "binary indicator of atmospheric river"),
        ("scheme", "deeplab"),
        ("version", "0.0"),
        ("note", "derived by thresholding ar_probability >= %f"
                 % args.binary_ar_threshold)):
    seg_atts[att_name] = att_value

# threshold ar_probability by value to produce ar_binary_tag
ar_tag = teca_binary_segmentation.New()
ar_tag.set_input_connection(ar_detect.get_output_port())
ar_tag.set_threshold_mode(ar_tag.BY_VALUE)
ar_tag.set_threshold_variable("ar_probability")
ar_tag.set_segmentation_variable("ar_binary_tag")
ar_tag.set_low_threshold_value(args.binary_ar_threshold)
ar_tag.set_segmentation_variable_attributes(seg_atts)

# configure the writer, the last stage of the pipeline
index_exec = teca_index_executive.New()
writer = teca_cf_writer.New()
writer.set_input_connection(ar_tag.get_output_port())
writer.set_executive(index_exec)
writer.set_thread_pool_size(1)
writer.set_file_name(args.output_file)
writer.set_layout(args.file_layout)
writer.set_steps_per_file(args.steps_per_file)
writer.set_first_step(args.first_step)
writer.set_last_step(args.last_step)

# the detector outputs are always written; the IVT vector and magnitude are
# added only when they were computed here and the user asked for them
point_arrays = ['ar_probability', 'ar_binary_tag']

if args.compute_ivt and args.write_ivt:
    point_arrays += [args.ivt_u, args.ivt_v]

magnitude_computed = args.compute_ivt or args.compute_ivt_magnitude
if magnitude_computed and args.write_ivt_magnitude:
    point_arrays.append(args.ivt)

writer.set_point_arrays(point_arrays)

# only override the date format when one was given
if args.date_format:
    writer.set_date_format(args.date_format)

if args.start_date or args.end_date:

    # run the metadata reporting phase of the pipeline on the reader
    md = reader.update_metadata()

    # calendar and units come from the time axis array attributes
    # NOTE(review): the attribute key 'time' is hard-coded; confirm this is
    # still correct when --t_axis_variable names a different variable
    time_atts = md['attributes']['time']
    calendar = time_atts['calendar']
    units = time_atts['units']

    # values of the time axis
    t_values = md['coordinates']['t']

    # a start date overrides the writer's first step
    if args.start_date:
        writer.set_first_step(
            coordinate_util.time_step_of(t_values, True, True, calendar,
                                         units, args.start_date))

    # an end date overrides the writer's last step
    if args.end_date:
        writer.set_last_step(
            coordinate_util.time_step_of(t_values, False, True, calendar,
                                         units, args.end_date))

# run the pipeline
# execution is driven from the writer, the most downstream stage
writer.update()

# close the event opened at the top of the script and shut the profiler down
teca_profiler.end_event('deeplab ar_detect')
teca_profiler.finalize()
