/*
 * Decompiled with CFR 0.152.
 */
package h2o.estimators;

import hex.genmodel.MojoModel;
import hex.genmodel.algos.glm.GlmOrdinalMojoModel;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FeatureList;
import org.jpmml.converter.FeatureUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelEncoder;
import org.jpmml.converter.OrdinalLabel;
import org.jpmml.converter.ScalarLabelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.TypeUtil;
import org.jpmml.h2o.Converter;
import org.jpmml.h2o.ConverterFactory;
import org.jpmml.h2o.H2OEncoder;
import org.jpmml.h2o.MojoModelUtil;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.sklearn.Encodable;
import org.jpmml.sklearn.SkLearnEncoder;
import org.jpmml.sklearn.SkLearnException;
import sklearn.Classifier;
import sklearn.Estimator;
import sklearn.HasClasses;

public class H2OEstimator
extends Estimator
implements HasClasses,
Encodable {
    private MojoModel mojoModel = null;
    private static final String TYPE_CLASSIFIER = "classifier";
    private static final String TYPE_REGRESSOR = "regressor";

    public H2OEstimator(String module, String name) {
        super(module, name);
    }

    public MiningFunction getMiningFunction() {
        String estimatorType;
        switch (estimatorType = this.getEstimatorType()) {
            case "classifier": {
                return MiningFunction.CLASSIFICATION;
            }
            case "regressor": {
                return MiningFunction.REGRESSION;
            }
        }
        throw new IllegalArgumentException(estimatorType);
    }

    public boolean isSupervised() {
        String estimatorType;
        switch (estimatorType = this.getEstimatorType()) {
            case "classifier": 
            case "regressor": {
                return true;
            }
        }
        throw new IllegalArgumentException(estimatorType);
    }

    public int getNumberOfOutputs() {
        String estimatorType;
        switch (estimatorType = this.getEstimatorType()) {
            case "classifier": 
            case "regressor": {
                return 1;
            }
        }
        throw new IllegalArgumentException(estimatorType);
    }

    public List<?> getClasses() {
        MojoModel mojoModel = this.getMojoModel();
        if (this.hasattr("pmml_classes_")) {
            List values = this.getListLike("pmml_classes_");
            return Classifier.canonicalizeValues((List)values);
        }
        int responseIdx = mojoModel.getResponseIdx();
        String[] responseValues = mojoModel.getDomainValues(responseIdx);
        if (responseValues == null) {
            throw new IllegalArgumentException();
        }
        return Arrays.asList(responseValues);
    }

    public boolean hasProbabilityDistribution() {
        String estimatorType;
        switch (estimatorType = this.getEstimatorType()) {
            case "classifier": {
                return true;
            }
            case "regressor": {
                return false;
            }
        }
        throw new IllegalArgumentException(estimatorType);
    }

    public Label encodeLabel(List<String> names, SkLearnEncoder encoder) {
        String estimatorType = this.getEstimatorType();
        MojoModel mojoModel = this.getMojoModel();
        ClassDictUtil.checkSize((int)1, (Collection[])new Collection[]{names});
        String name = names.get(0);
        switch (estimatorType) {
            case "classifier": {
                List<?> categories = this.getClasses();
                OpType opType = OpType.CATEGORICAL;
                DataType dataType = TypeUtil.getDataType(categories, (DataType)DataType.STRING);
                if (mojoModel instanceof GlmOrdinalMojoModel) {
                    opType = OpType.ORDINAL;
                }
                if (name != null) {
                    DataField dataField = encoder.createDataField(name, opType, dataType, categories);
                    return ScalarLabelUtil.createScalarLabel((DataField)dataField);
                }
                switch (opType) {
                    case CATEGORICAL: {
                        return new CategoricalLabel(dataType, categories);
                    }
                    case ORDINAL: {
                        return new OrdinalLabel(dataType, categories);
                    }
                }
                throw new IllegalArgumentException();
            }
            case "regressor": {
                if (name != null) {
                    DataField dataField = encoder.createDataField(name, OpType.CONTINUOUS, DataType.DOUBLE);
                    return ScalarLabelUtil.createScalarLabel((DataField)dataField);
                }
                return new ContinuousLabel(DataType.DOUBLE);
            }
        }
        throw new IllegalArgumentException(estimatorType);
    }

    public Model encodeModel(Schema schema) {
        Converter<?> converter = this.createConverter();
        ModelEncoder encoder = schema.getEncoder();
        Label label = schema.getLabel();
        List features = schema.getFeatures();
        H2OEncoder h2oEncoder = new H2OEncoder();
        Schema h2oSchema = converter.encodeSchema(h2oEncoder);
        List h2oFeatures = h2oSchema.getFeatures();
        ArrayList<Feature> reorderedFeatures = new ArrayList<Feature>();
        for (Feature h2oFeature : h2oFeatures) {
            Feature feature;
            String name = h2oFeature.getName();
            if (features instanceof FeatureList) {
                FeatureList namedFeatures = (FeatureList)features;
                feature = namedFeatures.resolveFeature(name);
            } else {
                feature = FeatureUtil.findFeature((List)features, (String)name);
                if (feature == null) {
                    int index = Integer.parseInt(name.substring(1)) - 1;
                    feature = (Feature)features.get(index);
                }
            }
            reorderedFeatures.add(feature);
        }
        Schema mojoModelSchema = converter.toMojoModelSchema(new Schema(encoder, label, reorderedFeatures));
        return converter.encodeModel(mojoModelSchema);
    }

    public PMML encodePMML() {
        Converter<?> converter = this.createConverter();
        return converter.encodePMML();
    }

    public String getEstimatorType() {
        return (String)this.getEnum("_estimator_type", arg_0 -> ((H2OEstimator)this).getString(arg_0), Arrays.asList(TYPE_CLASSIFIER, TYPE_REGRESSOR));
    }

    public byte[] getMojoBytes() {
        return (byte[])this.get("_mojo_bytes", byte[].class);
    }

    public String getMojoPath() {
        return this.getString("_mojo_path");
    }

    public H2OEstimator setMojoPath(String mojoPath) {
        this.setattr("_mojo_path", mojoPath);
        return this;
    }

    private Converter<?> createConverter() {
        MojoModel mojoModel = this.getMojoModel();
        try {
            ConverterFactory converterFactory = ConverterFactory.newConverterFactory();
            return converterFactory.newConverter(mojoModel);
        }
        catch (Exception e) {
            throw new SkLearnException("Failed to create H2O.ai converter", (Throwable)e);
        }
    }

    private MojoModel getMojoModel() {
        if (this.mojoModel == null) {
            this.mojoModel = this.loadMojoModel();
        }
        return this.mojoModel;
    }

    /*
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     */
    private MojoModel loadMojoModel() {
        if (this.hasattr("_mojo_bytes")) {
            byte[] mojoBytes = this.getMojoBytes();
            try (ByteArrayInputStream is = new ByteArrayInputStream(mojoBytes);){
                MojoModel mojoModel = MojoModelUtil.readFrom((InputStream)is);
                return mojoModel;
            }
            catch (Exception e) {
                throw new SkLearnException("Failed to load H2O.ai MOJO object", (Throwable)e);
            }
        }
        String mojoPath = this.getMojoPath();
        try {
            return MojoModelUtil.readFrom((File)new File(mojoPath), (boolean)false);
        }
        catch (Exception e) {
            throw new SkLearnException("Failed to load H2O.ai MOJO object", (Throwable)e);
        }
    }
}

