/*
 * Decompiled with CFR 0.152.
 */
package hex.genmodel.tools;

import au.com.bytecode.opencsv.CSVReader;
import hex.ModelCategory;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.algos.glrm.GlrmMojoModel;
import hex.genmodel.algos.tree.SharedTreeMojoModel;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.prediction.AbstractPrediction;
import hex.genmodel.easy.prediction.AnomalyDetectionPrediction;
import hex.genmodel.easy.prediction.AutoEncoderModelPrediction;
import hex.genmodel.easy.prediction.BinomialModelPrediction;
import hex.genmodel.easy.prediction.ClusteringModelPrediction;
import hex.genmodel.easy.prediction.CoxPHModelPrediction;
import hex.genmodel.easy.prediction.DimReductionModelPrediction;
import hex.genmodel.easy.prediction.MultinomialModelPrediction;
import hex.genmodel.easy.prediction.OrdinalModelPrediction;
import hex.genmodel.easy.prediction.RegressionModelPrediction;
import hex.genmodel.utils.ArrayUtils;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Reader;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

/**
 * Command-line tool that scores the rows of a CSV file with a POJO or MOJO model and writes one
 * output line per input row to another CSV file.
 *
 * <p>The first line of the input file must hold the column names. Cells equal to {@code ""},
 * {@code "NA"}, {@code "N/A"} or {@code "-"} are treated as missing and are not passed to the
 * model. Numbers are written in hexadecimal floating-point notation unless {@code --decimal} is
 * given; NaN is written as {@code NA}.
 *
 * <p>With {@code --testConcurrent N} the input is additionally scored by {@code N} threads that
 * share the main predictor's model wrapper, each writing to {@code <output>.<id>}.
 */
public class PredictCsv {
    private final String inputCSVFileName;
    private final String outputCSVFileName;
    /** Write numbers with {@link Double#toString} instead of {@link Double#toHexString}. */
    private final boolean useDecimalOutput;
    private final char separator;
    /** Passed to {@code Config.setConvertInvalidNumbersToNa}. */
    private final boolean setInvNumNA;
    /** Output leaf-node assignments (tree models) instead of predictions. */
    private final boolean getTreePath;
    /** Output prediction contributions instead of predictions. */
    private final boolean predictContributions;
    /** Append calibrated class probabilities to binomial predictions. */
    private final boolean predictCalibrated;
    /** For GLRM, output the reconstructed data instead of the X factor. */
    private final boolean returnGLRMReconstruct;
    /** GLRM iteration count; only applied when positive. */
    private final int glrmIterNumber;
    private final boolean outputHeader;
    private EasyPredictModelWrapper modelWrapper;

    private PredictCsv(String inputCSVFileName, String outputCSVFileName, boolean useDecimalOutput, char separator, boolean setInvNumNA, boolean getTreePath, boolean predictContributions, boolean predictCalibrated, boolean returnGLRMReconstruct, int glrmIterNumber, boolean outputHeader) {
        this.inputCSVFileName = inputCSVFileName;
        this.outputCSVFileName = outputCSVFileName;
        this.useDecimalOutput = useDecimalOutput;
        this.separator = separator;
        this.setInvNumNA = setInvNumNA;
        this.getTreePath = getTreePath;
        this.predictContributions = predictContributions;
        this.predictCalibrated = predictCalibrated;
        this.returnGLRMReconstruct = returnGLRMReconstruct;
        this.glrmIterNumber = glrmIterNumber;
        this.outputHeader = outputHeader;
    }

    /**
     * Command-line entry point: scores the input file, then (if requested) runs the concurrent
     * test predictors. Always terminates the JVM: exit code 0 on success, 1 on any failure.
     *
     * @param args command-line arguments, see {@link #usage()}
     */
    public static void main(String[] args) {
        PredictCsvCollection predictors = PredictCsv.buildPredictCsv(args);
        try {
            predictors.main.run();
        }
        catch (Exception e) {
            System.out.println("Predict error: " + e.getMessage());
            System.out.println();
            e.printStackTrace();
            System.exit(1);
        }
        if (predictors.concurrent.length > 0) {
            try {
                PredictCsv.runConcurrently(predictors.concurrent);
            }
            catch (Exception e) {
                System.out.println("Concurrent predict error: " + e.getMessage());
                System.out.println();
                e.printStackTrace();
                System.exit(1);
            }
        }
        System.exit(0);
    }

    /**
     * Runs every predictor on its own thread and waits for all of them.
     *
     * @throws Exception if any predictor failed (each failure's stack trace is printed first), or
     *     if waiting was interrupted
     */
    private static void runConcurrently(PredictCsv[] predictors) throws Exception {
        ExecutorService executor = Executors.newFixedThreadPool(predictors.length);
        try {
            List<PredictCsvCallable> callables = new ArrayList<>(predictors.length);
            for (PredictCsv predictor : predictors) {
                callables.add(new PredictCsvCallable(predictor));
            }
            int numExceptions = 0;
            for (Future<Exception> future : executor.invokeAll(callables)) {
                Exception failure = future.get();
                if (failure != null) {
                    failure.printStackTrace();
                    ++numExceptions;
                }
            }
            if (numExceptions > 0) {
                throw new Exception("Some predictors failed (#failed=" + numExceptions + ")");
            }
        }
        finally {
            // invokeAll has already waited for every task; release the worker threads.
            executor.shutdown();
        }
    }

    /**
     * Builds a predictor from command-line style arguments for programmatic use.
     *
     * <p>Note: invalid arguments print the usage text and terminate the JVM (see {@link #usage()}).
     *
     * @param args  arguments in the same format as {@link #main(String[])}
     * @param model model to score with; when {@code null}, whatever model {@code args} selected
     *              (if any) is kept
     * @return the configured predictor; call {@link #run()} to score
     * @throws UnsupportedOperationException if {@code --testConcurrent} requested concurrent predictors
     */
    public static PredictCsv make(String[] args, GenModel model) {
        PredictCsvCollection predictorCollection = PredictCsv.buildPredictCsv(args);
        if (predictorCollection.concurrent.length != 0) {
            throw new UnsupportedOperationException("Predicting with concurrent predictors is not supported in programmatic mode.");
        }
        PredictCsv predictor = predictorCollection.main;
        if (model != null) {
            try {
                predictor.setModelWrapper(model);
            }
            catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
        return predictor;
    }

    /**
     * Converts one parsed CSV line into a {@link RowData}, skipping missing-value cells.
     * Cells beyond the shorter of the two arrays are ignored.
     */
    private static RowData formatDataRow(String[] splitLine, String[] inputColumnNames) {
        RowData row = new RowData();
        int maxI = Math.min(inputColumnNames.length, splitLine.length);
        for (int i = 0; i < maxI; ++i) {
            String cellData = splitLine[i];
            if (!PredictCsv.isMissingValue(cellData)) {
                row.put(inputColumnNames[i], cellData);
            }
        }
        return row;
    }

    /** Returns true for the cell values this tool treats as missing. */
    private static boolean isMissingValue(String cellData) {
        switch (cellData) {
            case "":
            case "NA":
            case "N/A":
            case "-":
                return true;
            default:
                return false;
        }
    }

    /** Formats a number for output: {@code NA} for NaN, otherwise decimal or hexadecimal notation. */
    private String myDoubleToString(double d) {
        if (Double.isNaN(d)) {
            return "NA";
        }
        return this.useDecimalOutput ? Double.toString(d) : Double.toHexString(d);
    }

    /** Writes {@code values} separated by commas; writes nothing for an empty array. */
    private static void writeJoined(String[] values, BufferedWriter output) throws IOException {
        for (int i = 0; i < values.length; ++i) {
            if (i > 0) {
                output.write(",");
            }
            output.write(values[i]);
        }
    }

    /** Writes {@code values} formatted with {@link #myDoubleToString} and separated by commas. */
    private void writeDoubles(double[] values, BufferedWriter output) throws IOException {
        for (int i = 0; i < values.length; ++i) {
            if (i > 0) {
                output.write(",");
            }
            output.write(this.myDoubleToString(values[i]));
        }
    }

    /** Writes contribution values formatted with {@link #myDoubleToString} and separated by commas. */
    private void writeContributions(float[] contributions, BufferedWriter output) throws IOException {
        for (int i = 0; i < contributions.length; ++i) {
            if (i > 0) {
                output.write(",");
            }
            output.write(this.myDoubleToString(contributions[i]));
        }
    }

    /** Header for {@code --leafNodeAssignment}; requires a {@link SharedTreeMojoModel}. */
    private void writeTreePathNames(BufferedWriter output) throws Exception {
        String[] columnNames = ((SharedTreeMojoModel) this.modelWrapper.m).getDecisionPathNames();
        PredictCsv.writeJoined(columnNames, output);
    }

    /** Header for {@code --predictCalibrated}: the output names followed by {@code cal_<name>} for each name after the first. */
    private void writeCalibratedOutputNames(BufferedWriter output) throws Exception {
        String[] outputNames = this.modelWrapper.m.getOutputNames();
        String[] calibOutputNames = new String[Math.max(outputNames.length - 1, 0)];
        for (int i = 0; i < calibOutputNames.length; ++i) {
            // Skip outputNames[0] (the predicted label); only probability columns get a calibrated twin.
            calibOutputNames[i] = "cal_" + outputNames[i + 1];
        }
        PredictCsv.writeJoined(ArrayUtils.append(outputNames, calibOutputNames), output);
    }

    /** Header for {@code --predictContributions}. */
    private void writeContributionNames(BufferedWriter output) throws Exception {
        PredictCsv.writeJoined(this.modelWrapper.getContributionNames(), output);
    }

    /** Header for {@code --glrmReconstruct}: {@code reconstr_<name>} for each permuted GLRM data column. */
    private void writeGlrmReconstructNames(BufferedWriter output) throws Exception {
        String[] colnames = this.modelWrapper.m.getNames();
        int datawidth = ((GlrmMojoModel) this.modelWrapper.m)._permutation.length;
        String[] reconstructedNames = new String[datawidth];
        for (int index = 0; index < datawidth; ++index) {
            reconstructedNames[index] = "reconstr_" + colnames[index];
        }
        PredictCsv.writeJoined(reconstructedNames, output);
    }

    /**
     * Scores every data row of the input file and writes the results to the output file.
     *
     * <p>Both files are closed on every path, including when header writing fails.
     *
     * @throws Exception if a file cannot be opened, the header cannot be written, or scoring fails;
     *     scoring failures are wrapped in an exception naming the failing line
     */
    public void run() throws Exception {
        ModelCategory category = this.modelWrapper.getModelCategory();
        CSVReader reader = new CSVReader(new FileReader(this.inputCSVFileName), this.separator);
        try {
            BufferedWriter output = new BufferedWriter(new FileWriter(this.outputCSVFileName));
            try {
                if (this.outputHeader) {
                    this.writeOutputHeader(category, output);
                    output.write("\n");
                }
                this.writePredictions(category, reader, output);
            }
            finally {
                output.close();
            }
        }
        finally {
            reader.close();
        }
    }

    /** Writes the header line (without the trailing newline) matching the rows {@link #writePrediction} produces. */
    private void writeOutputHeader(ModelCategory category, BufferedWriter output) throws Exception {
        switch (category) {
            case Binomial:
            case Multinomial:
            case Regression: {
                if (this.getTreePath) {
                    this.writeTreePathNames(output);
                } else if (this.predictContributions) {
                    this.writeContributionNames(output);
                } else if (this.predictCalibrated) {
                    this.writeCalibratedOutputNames(output);
                } else {
                    PredictCsv.writeJoined(this.modelWrapper.m.getOutputNames(), output);
                }
                break;
            }
            case DimReduction: {
                if (this.returnGLRMReconstruct) {
                    this.writeGlrmReconstructNames(output);
                } else {
                    PredictCsv.writeJoined(this.modelWrapper.m.getOutputNames(), output);
                }
                break;
            }
            default: {
                PredictCsv.writeJoined(this.modelWrapper.m.getOutputNames(), output);
            }
        }
    }

    /**
     * Reads the column-name line, then scores and writes each remaining line.
     *
     * @throws Exception wrapping any failure as "Prediction failed on line N". N is the 1-based
     *     index of the failing data row, with the header not counted. N is 1 when the header itself
     *     is missing.
     */
    private void writePredictions(ModelCategory category, CSVReader reader, BufferedWriter output) throws Exception {
        int lineNum = 1;
        try {
            String[] inputColumnNames = reader.readNext();
            if (inputColumnNames == null) {
                throw new Exception("Input dataset file is empty!");
            }
            this.checkMissingColumns(inputColumnNames);
            String offsetColumn = this.modelWrapper.m.getOffsetName();
            String[] splitLine;
            while ((splitLine = reader.readNext()) != null) {
                RowData row = PredictCsv.formatDataRow(splitLine, inputColumnNames);
                double offset = offsetColumn == null ? 0.0 : PredictCsv.parseOffset(row, offsetColumn);
                this.writePrediction(category, row, offset, output);
                output.write("\n");
                ++lineNum;
            }
        }
        catch (Exception e) {
            throw new Exception("Prediction failed on line " + lineNum, e);
        }
    }

    /**
     * Reads the offset value of a row.
     *
     * @throws IllegalArgumentException if the offset cell is missing. Missing-value cells are
     *     dropped by {@link #formatDataRow}.
     */
    private static double parseOffset(RowData row, String offsetColumn) {
        Object value = row.get(offsetColumn);
        if (value == null) {
            throw new IllegalArgumentException("Missing value in offset column '" + offsetColumn + "'");
        }
        return Double.parseDouble((String) value);
    }

    /** Scores one row and writes its output columns (without the trailing newline). */
    private void writePrediction(ModelCategory category, RowData row, double offset, BufferedWriter output) throws Exception {
        switch (category) {
            case AutoEncoder: {
                AutoEncoderModelPrediction p = (AutoEncoderModelPrediction) this.modelWrapper.predictAutoEncoder(row);
                this.writeDoubles(p.reconstructed, output);
                break;
            }
            case Binomial: {
                BinomialModelPrediction p = (BinomialModelPrediction) this.modelWrapper.predictBinomial(row, offset);
                if (this.getTreePath) {
                    PredictCsv.writeJoined(p.leafNodeAssignments, output);
                } else if (this.predictContributions) {
                    this.writeContributions(p.contributions, output);
                } else {
                    output.write(p.label);
                    output.write(",");
                    this.writeDoubles(p.classProbabilities, output);
                    if (this.predictCalibrated) {
                        for (int i = 0; i < p.classProbabilities.length; ++i) {
                            output.write(",");
                            // Models without calibration report NA so the column count still matches the header.
                            double calibProb = p.calibratedClassProbabilities != null ? p.calibratedClassProbabilities[i] : Double.NaN;
                            output.write(this.myDoubleToString(calibProb));
                        }
                    }
                }
                break;
            }
            case Multinomial: {
                MultinomialModelPrediction p = (MultinomialModelPrediction) this.modelWrapper.predictMultinomial(row);
                if (this.getTreePath) {
                    PredictCsv.writeJoined(p.leafNodeAssignments, output);
                } else {
                    output.write(p.label);
                    output.write(",");
                    this.writeDoubles(p.classProbabilities, output);
                }
                break;
            }
            case Ordinal: {
                OrdinalModelPrediction p = (OrdinalModelPrediction) this.modelWrapper.predictOrdinal(row, offset);
                output.write(p.label);
                output.write(",");
                this.writeDoubles(p.classProbabilities, output);
                break;
            }
            case Clustering: {
                ClusteringModelPrediction p = (ClusteringModelPrediction) this.modelWrapper.predictClustering(row);
                output.write(this.myDoubleToString(p.cluster));
                break;
            }
            case Regression: {
                RegressionModelPrediction p = (RegressionModelPrediction) this.modelWrapper.predictRegression(row, offset);
                if (this.getTreePath) {
                    PredictCsv.writeJoined(p.leafNodeAssignments, output);
                } else if (this.predictContributions) {
                    this.writeContributions(p.contributions, output);
                } else {
                    output.write(this.myDoubleToString(p.value));
                }
                break;
            }
            case CoxPH: {
                CoxPHModelPrediction p = (CoxPHModelPrediction) this.modelWrapper.predictCoxPH(row, offset);
                output.write(this.myDoubleToString(p.value));
                break;
            }
            case DimReduction: {
                DimReductionModelPrediction p = (DimReductionModelPrediction) this.modelWrapper.predictDimReduction(row);
                this.writeDoubles(this.returnGLRMReconstruct ? p.reconstructed : p.dimensions, output);
                break;
            }
            case AnomalyDetection: {
                AnomalyDetectionPrediction p = (AnomalyDetectionPrediction) this.modelWrapper.predictAnomalyDetection(row);
                this.writeDoubles(p.toPreds(), output);
                break;
            }
            default: {
                throw new Exception("Unknown model category " + category);
            }
        }
    }

    /** Wraps {@code genModel} in an {@link EasyPredictModelWrapper} configured from this predictor's options. */
    private void setModelWrapper(GenModel genModel) throws IOException {
        EasyPredictModelWrapper.Config config = new EasyPredictModelWrapper.Config()
                .setModel(genModel)
                .setConvertUnknownCategoricalLevelsToNa(true)
                .setConvertInvalidNumbersToNa(this.setInvNumNA);
        if (this.getTreePath) {
            config.setEnableLeafAssignment(true);
        }
        if (this.predictContributions) {
            config.setEnableContributions(true);
        }
        if (this.returnGLRMReconstruct) {
            config.setEnableGLRMReconstrut(true);
        }
        if (this.glrmIterNumber > 0) {
            config.setGLRMIterNumber(this.glrmIterNumber);
        }
        this.setModelWrapper(new EasyPredictModelWrapper(config));
    }

    private void setModelWrapper(EasyPredictModelWrapper modelWrapper) {
        this.modelWrapper = modelWrapper;
    }

    /** Prints the command-line help and terminates the JVM with exit code 1. */
    private static void usage() {
        System.out.println();
        System.out.println("Usage:  java [...java args...] hex.genmodel.tools.PredictCsv --mojo mojoName");
        System.out.println("             --pojo pojoName --input inputFile --output outputFile --separator sepStr --decimal --setConvertInvalidNum");
        System.out.println();
        System.out.println("     --mojo    Name of the zip file containing model's MOJO.");
        System.out.println("     --pojo    Name of the java class containing the model's POJO. Either this ");
        System.out.println("               parameter, --mojo or --model must be specified.");
        System.out.println("     --model   Name of a MOJO zip file or POJO class; loading as a MOJO is tried first.");
        System.out.println("     --input   text file containing the test data set to score.");
        System.out.println("     --output  Name of the output CSV file with computed predictions.");
        System.out.println("     --separator Separator to be used in input file containing test data set.");
        System.out.println("     --decimal Use decimal numbers in the output (default is to use hexadecimal).");
        System.out.println("     --setConvertInvalidNum Will call .setConvertInvalidNumbersToNa(true) when loading models.");
        System.out.println("     --leafNodeAssignment will show the leaf node assignment for tree based models instead of prediction results");
        System.out.println("     --predictContributions will output prediction contributions (Shapley values) for tree based models instead of regular model predictions");
        System.out.println("     --predictCalibrated will append calibrated class probabilities for binomial models (NA when not available).");
        System.out.println("     --glrmReconstruct will return the reconstructed dataset for GLRM mojo instead of X factor derived from the dataset.");
        System.out.println("     --glrmIterNumber integer indicating number of iterations to go through when constructing X factor derived from the dataset.");
        System.out.println("     --outputHeader true|false whether to write a header line to the output file (default is true).");
        System.out.println("     --testConcurrent integer (for testing) number of concurrent threads that will be making predictions.");
        System.out.println();
        System.exit(1);
    }

    /**
     * Prints a report of model columns absent from the input (the response column is exempt) and
     * of input columns the model does not use. Only prints; never fails.
     */
    private void checkMissingColumns(String[] parsedColumnNamesArr) {
        String[] modelColumnNames = this.modelWrapper.m._names;
        String responseColumn = this.modelWrapper.m._responseColumn;
        // Starts as every input column; columns the model knows are removed, leaving the unused ones.
        Set<String> unusedColumns = new HashSet<>(parsedColumnNamesArr.length);
        Collections.addAll(unusedColumns, parsedColumnNamesArr);
        List<String> missingColumns = new ArrayList<>();
        for (String columnName : modelColumnNames) {
            if (!unusedColumns.contains(columnName) && !columnName.equals(responseColumn)) {
                missingColumns.add(columnName);
            } else {
                unusedColumns.remove(columnName);
            }
        }
        if (!missingColumns.isEmpty()) {
            System.out.println(PredictCsv.describeColumns("There were ", missingColumns, " missing columns found in the input data set: {"));
        }
        if (!unusedColumns.isEmpty()) {
            System.out.println(PredictCsv.describeColumns("Detected ", unusedColumns, " unused columns in the input data set: {"));
        }
    }

    /** Builds {@code <prefix><count><middle><col1>,<col2>,...}} for the column report. */
    private static String describeColumns(String prefix, Collection<String> columns, String middle) {
        StringBuilder message = new StringBuilder(prefix);
        message.append(columns.size());
        message.append(middle);
        boolean first = true;
        for (String column : columns) {
            if (!first) {
                message.append(',');
            }
            message.append(column);
            first = false;
        }
        message.append('}');
        return message.toString();
    }

    /**
     * Parses {@code args}, loads the selected model and creates the main predictor plus any
     * concurrent test predictors sharing its model wrapper. On any failure prints the stack trace
     * and usage, then terminates the JVM.
     */
    private static PredictCsvCollection buildPredictCsv(String[] args) {
        try {
            PredictCsvBuilder builder = new PredictCsvBuilder();
            builder.parseArgs(args);
            GenModel genModel;
            switch (builder.modelSource) {
                case EMBEDDED: {
                    genModel = null;
                    break;
                }
                case POJO: {
                    genModel = PredictCsv.loadPojo(builder.pojoMojoModelNames);
                    break;
                }
                case MOJO: {
                    genModel = PredictCsv.loadMojo(builder.pojoMojoModelNames);
                    break;
                }
                case MODEL: {
                    genModel = PredictCsv.loadModel(builder.pojoMojoModelNames);
                    break;
                }
                default: {
                    throw new IllegalStateException("Unexpected value of loadType = " + builder.modelSource);
                }
            }
            PredictCsv mainPredictCsv = builder.newPredictCsv();
            if (genModel != null) {
                mainPredictCsv.setModelWrapper(genModel);
            }
            PredictCsv[] concurrentPredictCsvs = new PredictCsv[builder.testConcurrent];
            for (int id = 0; id < concurrentPredictCsvs.length; ++id) {
                PredictCsv concurrentPredictCsv = builder.newConcurrentPredictCsv(id);
                concurrentPredictCsv.setModelWrapper(mainPredictCsv.modelWrapper);
                concurrentPredictCsvs[id] = concurrentPredictCsv;
            }
            return new PredictCsvCollection(mainPredictCsv, concurrentPredictCsvs);
        }
        catch (Exception e) {
            e.printStackTrace();
            PredictCsv.usage();
            throw new IllegalStateException("Should not be reachable");
        }
    }

    /** Instantiates a POJO class through its no-argument constructor. */
    private static GenModel loadPojo(String className) throws Exception {
        return Class.forName(className).asSubclass(GenModel.class).getDeclaredConstructor().newInstance();
    }

    private static GenModel loadMojo(String modelName) throws IOException {
        return MojoModel.load(modelName);
    }

    /**
     * Loads {@code modelName} as a MOJO, falling back to a POJO class name. If both fail, the POJO
     * failure is thrown with the MOJO failure attached as a suppressed exception.
     */
    private static GenModel loadModel(String modelName) throws Exception {
        try {
            return PredictCsv.loadMojo(modelName);
        }
        catch (IOException mojoFailure) {
            try {
                return PredictCsv.loadPojo(modelName);
            }
            catch (Exception pojoFailure) {
                pojoFailure.addSuppressed(mojoFailure);
                throw pojoFailure;
            }
        }
    }

    /** Where the model comes from, as selected on the command line. */
    private enum ModelSource {
        /** {@code --embedded}: nothing is loaded here; the model wrapper must be set afterwards (as {@code make} does). */
        EMBEDDED,
        /** {@code --pojo} (the default): a POJO class name. */
        POJO,
        /** {@code --mojo}: a MOJO zip file. */
        MOJO,
        /** {@code --model}: a MOJO file, falling back to a POJO class name. */
        MODEL
    }

    /** Runs one predictor and returns its failure instead of throwing it. */
    private static class PredictCsvCallable
    implements Callable<Exception> {
        private final PredictCsv predictCsv;

        private PredictCsvCallable(PredictCsv predictCsv) {
            this.predictCsv = predictCsv;
        }

        /** @return the exception thrown by {@link PredictCsv#run()}, or {@code null} on success */
        @Override
        public Exception call() {
            try {
                this.predictCsv.run();
            }
            catch (Exception e) {
                return e;
            }
            return null;
        }
    }

    /** Accumulates command-line options and creates {@link PredictCsv} instances from them. */
    private static class PredictCsvBuilder {
        private String inputCSVFileName;
        private String outputCSVFileName;
        private boolean useDecimalOutput;
        private char separator = ',';
        private boolean setInvNumNA;
        private boolean getTreePath;
        private boolean predictContributions;
        private boolean predictCalibrated;
        private boolean returnGLRMReconstruct;
        private int glrmIterNumber = -1;
        private boolean outputHeader = true;
        private ModelSource modelSource = ModelSource.POJO;
        private String pojoMojoModelNames = "";
        private int testConcurrent = 0;

        private PredictCsvBuilder() {
        }

        private PredictCsv newPredictCsv() {
            return this.newPredictCsv(this.outputCSVFileName);
        }

        /** Creates a test predictor writing to {@code <output>.<id>}. */
        private PredictCsv newConcurrentPredictCsv(int id) {
            return this.newPredictCsv(this.outputCSVFileName + "." + id);
        }

        private PredictCsv newPredictCsv(String outputFileName) {
            return new PredictCsv(this.inputCSVFileName, outputFileName, this.useDecimalOutput, this.separator, this.setInvNumNA, this.getTreePath, this.predictContributions, this.predictCalibrated, this.returnGLRMReconstruct, this.glrmIterNumber, this.outputHeader);
        }

        /**
         * Parses the arguments. Value-less flags are handled by {@link #parseFlag}. Every other
         * argument consumes the following one as its value; a missing value shows usage and exits.
         */
        private void parseArgs(String[] args) {
            for (int i = 0; i < args.length; ++i) {
                String s = args[i];
                if (this.parseFlag(s)) {
                    continue;
                }
                if (++i >= args.length) {
                    PredictCsv.usage();
                }
                this.parseOption(s, args[i]);
            }
        }

        /** Applies a value-less flag; returns false if {@code s} is not one. */
        private boolean parseFlag(String s) {
            switch (s) {
                case "--header":
                    // Accepted but ignored (presumably legacy); header output is controlled by --outputHeader.
                    return true;
                case "--decimal":
                    this.useDecimalOutput = true;
                    return true;
                case "--glrmReconstruct":
                    this.returnGLRMReconstruct = true;
                    return true;
                case "--setConvertInvalidNum":
                    this.setInvNumNA = true;
                    return true;
                case "--leafNodeAssignment":
                    this.getTreePath = true;
                    return true;
                case "--predictContributions":
                    this.predictContributions = true;
                    return true;
                case "--predictCalibrated":
                    this.predictCalibrated = true;
                    return true;
                case "--embedded":
                    this.modelSource = ModelSource.EMBEDDED;
                    return true;
                default:
                    return false;
            }
        }

        /** Applies an option that takes a value; unknown options print an error and usage, then exit. */
        private void parseOption(String s, String sarg) {
            switch (s) {
                case "--model":
                    this.pojoMojoModelNames = sarg;
                    this.modelSource = ModelSource.MODEL;
                    break;
                case "--mojo":
                    this.pojoMojoModelNames = sarg;
                    this.modelSource = ModelSource.MOJO;
                    break;
                case "--pojo":
                    this.pojoMojoModelNames = sarg;
                    this.modelSource = ModelSource.POJO;
                    break;
                case "--input":
                    this.inputCSVFileName = sarg;
                    break;
                case "--output":
                    this.outputCSVFileName = sarg;
                    break;
                case "--separator":
                    // Only the last character of the value is used.
                    this.separator = sarg.charAt(sarg.length() - 1);
                    break;
                case "--glrmIterNumber":
                    this.glrmIterNumber = Integer.parseInt(sarg);
                    break;
                case "--testConcurrent":
                    this.testConcurrent = Integer.parseInt(sarg);
                    break;
                case "--outputHeader":
                    this.outputHeader = Boolean.parseBoolean(sarg);
                    break;
                default:
                    System.out.println("ERROR: Unknown command line argument: " + s);
                    PredictCsv.usage();
            }
        }
    }

    /** The main predictor plus the (possibly empty) concurrent test predictors. */
    private static class PredictCsvCollection {
        private final PredictCsv main;
        private final PredictCsv[] concurrent;

        private PredictCsvCollection(PredictCsv main, PredictCsv[] concurrent) {
            this.main = main;
            this.concurrent = concurrent;
        }
    }
}

