/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.h2o;

import hex.genmodel.MojoModel;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.dmg.pmml.MissingValueTreatmentMethod;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.h2o.Converter;
import org.jpmml.h2o.ImputerUtil;

/**
 * Base converter for H2O GLM MOJO models.
 *
 * <p>Model internals are read reflectively from the declared fields of
 * {@code hex.genmodel.algos.glm.GlmMojoModelBase} (presumably because they are not publicly
 * accessible). NOTE(review): field access is delegated to {@code getFieldValue}, which is
 * assumed to make the fields accessible — confirm against {@link Converter}.
 *
 * @param <M> the concrete MOJO model type handled by the subclass
 */
public abstract class GlmMojoModelBaseConverter<M extends MojoModel>
extends Converter<M> {
    private static final Class<?> CLASS_GLMMOJOMODELBASE;
    private static final Field FIELD_BETA;
    private static final Field FIELD_CATS;
    private static final Field FIELD_CATMODES;
    private static final Field FIELD_CATOFFSETS;
    private static final Field FIELD_FAMILY;
    private static final Field FIELD_MEANIMPUTATION;
    private static final Field FIELD_NUMS;
    private static final Field FIELD_NUMMEANS;
    private static final Field FIELD_USEALLFACTORLEVELS;

    public GlmMojoModelBaseConverter(M model) {
        super(model);
    }

    /**
     * Rearranges and expands the input schema into the feature layout used by the GLM model.
     *
     * <p>All categorical features are placed first, in their original relative order. They are
     * followed by all remaining features, each converted with {@link Feature#toContinuousFeature()}.
     * Every categorical feature is then expanded into one {@link BinaryFeature} per category
     * level. When the model does not use all factor levels, the first level is dropped.
     *
     * <p>If the model has mean imputation enabled, {@link ImputerUtil#encodeFeature} is called for
     * every categorical feature (with its mode value) and every continuous feature (with its mean
     * value).
     *
     * @param schema the input schema
     * @return a new schema with the same encoder and label and the expanded feature list
     * @throws IllegalArgumentException if the model's category offsets, modes or means do not
     *     match its categorical/numeric counts; {@code SchemaUtil.checkSize} presumably throws on
     *     other feature-count mismatches
     */
    @Override
    public Schema toMojoModelSchema(Schema schema) {
        // Typed as M (not Object) so it can be passed to the MojoModel-typed accessors below
        M model = this.getModel();

        int cats = GlmMojoModelBaseConverter.getCats(model);
        int[] catOffsets = GlmMojoModelBaseConverter.getCatOffsets(model);
        int nums = GlmMojoModelBaseConverter.getNums(model);
        boolean meanImputation = GlmMojoModelBaseConverter.getMeanImputation(model);
        boolean useAllFactorLevels = GlmMojoModelBaseConverter.getUseAllFactorLevels(model);

        PMMLEncoder encoder = schema.getEncoder();
        Label label = schema.getLabel();
        List<? extends Feature> features = schema.getFeatures();

        List<CategoricalFeature> categoricalFeatures = features.stream()
            .filter(feature -> feature instanceof CategoricalFeature)
            .map(feature -> (CategoricalFeature) feature)
            .collect(Collectors.toList());
        SchemaUtil.checkSize(cats, categoricalFeatures);

        // Each categorical feature i reads catOffsets[i] and catOffsets[i + 1]
        if (cats > 0 && catOffsets.length < cats + 1) {
            throw new IllegalArgumentException("Expected " + (cats + 1) + " category offsets, got " + catOffsets.length + " category offsets");
        }
        for (int i = 0; i < cats; i++) {
            // catOffsets[i + 1] - catOffsets[i] binary columns; one more level exists when the first is dropped
            int expectedLevels = (catOffsets[i + 1] - catOffsets[i]) + (useAllFactorLevels ? 0 : 1);
            SchemaUtil.checkSize(expectedLevels, categoricalFeatures.get(i));
        }

        List<ContinuousFeature> continuousFeatures = features.stream()
            .filter(feature -> !(feature instanceof CategoricalFeature))
            .map(Feature::toContinuousFeature)
            .collect(Collectors.toList());
        SchemaUtil.checkSize(nums, continuousFeatures);

        List<Feature> reorderedFeatures = new ArrayList<>(categoricalFeatures.size() + continuousFeatures.size());
        reorderedFeatures.addAll(categoricalFeatures);
        reorderedFeatures.addAll(continuousFeatures);

        if (meanImputation) {
            encodeMeanImputation(model, cats, nums, categoricalFeatures, continuousFeatures);
        }

        List<Feature> expandedFeatures = reorderedFeatures.stream()
            .flatMap(feature -> expandCategoricalFeature(encoder, feature, useAllFactorLevels))
            .collect(Collectors.toList());

        return new Schema(encoder, label, expandedFeatures);
    }

    /**
     * Applies the model's imputation values.
     *
     * <p>For each categorical feature, the category value at index {@code catModes[i]} is passed
     * to {@link ImputerUtil#encodeFeature} with {@code AS_MODE}. For each continuous feature,
     * {@code numMeans[i]} is passed with {@code AS_MEAN}. The return value of
     * {@code encodeFeature} is not used.
     *
     * @throws IllegalArgumentException if the mode/mean array lengths differ from cats/nums
     */
    private static void encodeMeanImputation(MojoModel model, int cats, int nums, List<CategoricalFeature> categoricalFeatures, List<ContinuousFeature> continuousFeatures) {
        int[] catModes = GlmMojoModelBaseConverter.getCatModes(model);
        double[] numMeans = GlmMojoModelBaseConverter.getNumMeans(model);
        if (catModes.length != cats) {
            throw new IllegalArgumentException("Expected " + cats + " mode values, got " + catModes.length + " mode values");
        }
        if (numMeans.length != nums) {
            throw new IllegalArgumentException("Expected " + nums + " mean values, got " + numMeans.length + " mean values");
        }
        for (int i = 0; i < cats; i++) {
            CategoricalFeature categoricalFeature = categoricalFeatures.get(i);
            List<?> values = categoricalFeature.getValues();
            ImputerUtil.encodeFeature((Feature) categoricalFeature, values.get(catModes[i]), MissingValueTreatmentMethod.AS_MODE);
        }
        for (int i = 0; i < nums; i++) {
            ContinuousFeature continuousFeature = continuousFeatures.get(i);
            ImputerUtil.encodeFeature((Feature) continuousFeature, numMeans[i], MissingValueTreatmentMethod.AS_MEAN);
        }
    }

    /**
     * Expands a categorical feature into one {@link BinaryFeature} per category level.
     *
     * <p>When {@code useAllFactorLevels} is false, the first level is skipped. Non-categorical
     * features are passed through unchanged.
     */
    private static Stream<Feature> expandCategoricalFeature(PMMLEncoder encoder, Feature feature, boolean useAllFactorLevels) {
        if (!(feature instanceof CategoricalFeature)) {
            return Stream.of(feature);
        }
        CategoricalFeature categoricalFeature = (CategoricalFeature) feature;
        List<?> values = categoricalFeature.getValues();
        if (!useAllFactorLevels) {
            values = values.subList(1, values.size());
        }
        return values.stream().<Feature>map(value -> new BinaryFeature(encoder, (Feature) categoricalFeature, value));
    }

    /** Returns the {@code _beta} array of the GLM MOJO model. */
    public static double[] getBeta(MojoModel model) {
        return (double[])GlmMojoModelBaseConverter.getFieldValue(FIELD_BETA, model);
    }

    /** Returns the {@code _cats} count; used as the expected number of categorical features. */
    public static int getCats(MojoModel model) {
        return (Integer)GlmMojoModelBaseConverter.getFieldValue(FIELD_CATS, model);
    }

    /** Returns the {@code _catModes} array; entry i indexes into the i-th categorical's values. */
    public static int[] getCatModes(MojoModel model) {
        return (int[])GlmMojoModelBaseConverter.getFieldValue(FIELD_CATMODES, model);
    }

    /** Returns the {@code _catOffsets} array; consecutive differences give per-categorical level counts. */
    public static int[] getCatOffsets(MojoModel model) {
        return (int[])GlmMojoModelBaseConverter.getFieldValue(FIELD_CATOFFSETS, model);
    }

    /** Returns the {@code _family} string of the GLM MOJO model. */
    public static String getFamily(MojoModel model) {
        return (String)GlmMojoModelBaseConverter.getFieldValue(FIELD_FAMILY, model);
    }

    /** Returns the {@code _meanImputation} flag of the GLM MOJO model. */
    public static boolean getMeanImputation(MojoModel model) {
        return (Boolean)GlmMojoModelBaseConverter.getFieldValue(FIELD_MEANIMPUTATION, model);
    }

    /** Returns the {@code _nums} count; used as the expected number of non-categorical features. */
    public static int getNums(MojoModel model) {
        return (Integer)GlmMojoModelBaseConverter.getFieldValue(FIELD_NUMS, model);
    }

    /** Returns the {@code _numMeans} array; entry i is the mean for the i-th continuous feature. */
    public static double[] getNumMeans(MojoModel model) {
        return (double[])GlmMojoModelBaseConverter.getFieldValue(FIELD_NUMMEANS, model);
    }

    /** Returns the {@code _useAllFactorLevels} flag; when false the first category level is dropped. */
    public static boolean getUseAllFactorLevels(MojoModel model) {
        return (Boolean)GlmMojoModelBaseConverter.getFieldValue(FIELD_USEALLFACTORLEVELS, model);
    }

    // Resolve the H2O class and its fields once; any failure makes this converter class unusable.
    static {
        try {
            CLASS_GLMMOJOMODELBASE = Class.forName("hex.genmodel.algos.glm.GlmMojoModelBase");

            FIELD_BETA = CLASS_GLMMOJOMODELBASE.getDeclaredField("_beta");
            FIELD_CATS = CLASS_GLMMOJOMODELBASE.getDeclaredField("_cats");
            FIELD_CATMODES = CLASS_GLMMOJOMODELBASE.getDeclaredField("_catModes");
            FIELD_CATOFFSETS = CLASS_GLMMOJOMODELBASE.getDeclaredField("_catOffsets");
            FIELD_FAMILY = CLASS_GLMMOJOMODELBASE.getDeclaredField("_family");
            FIELD_MEANIMPUTATION = CLASS_GLMMOJOMODELBASE.getDeclaredField("_meanImputation");
            FIELD_NUMS = CLASS_GLMMOJOMODELBASE.getDeclaredField("_nums");
            FIELD_NUMMEANS = CLASS_GLMMOJOMODELBASE.getDeclaredField("_numMeans");
            FIELD_USEALLFACTORLEVELS = CLASS_GLMMOJOMODELBASE.getDeclaredField("_useAllFactorLevels");
        }
        catch (ReflectiveOperationException roe) {
            throw new RuntimeException(roe);
        }
    }
}

