/*
 * Decompiled with CFR 0.152.
 */
package sklego.meta;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import numpy.core.ScalarUtil;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.regression.RegressionModel;
import org.dmg.pmml.regression.RegressionTable;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.DiscreteLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.OrdinalLabel;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.converter.regression.RegressionModelUtil;
import org.jpmml.python.CastFunction;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Classifier;

public class OrdinalClassifier
extends Classifier {
    public OrdinalClassifier(String module, String name) {
        super(module, name);
    }

    public Model encodeModel(Schema schema) {
        Map<?, Classifier> estimators = this.getEstimators();
        Map<?, ?> estimatorCategories = this.getEstimatorCategories();
        SkLearnEncoder encoder = (SkLearnEncoder)schema.getEncoder();
        OrdinalLabel ordinalLabel = (OrdinalLabel)schema.getLabel();
        List features = schema.getFeatures();
        SchemaUtil.checkSize((int)(estimators.size() + 1), (DiscreteLabel)ordinalLabel);
        ArrayList<Object> models = new ArrayList<Object>();
        ArrayList probabilityFeatures = new ArrayList();
        int max = ordinalLabel.size() - 1;
        for (int i = 0; i < max; ++i) {
            String name;
            Classifier estimator;
            Object category = ordinalLabel.getValue(i);
            if (estimatorCategories != null && !estimatorCategories.isEmpty()) {
                Object estimatorCategory = estimatorCategories.get(category);
                if (estimatorCategory == null) {
                    throw new IllegalArgumentException();
                }
                estimator = estimators.get(estimatorCategory);
            } else {
                estimator = estimators.get(category);
            }
            if (estimator == null) {
                throw new IllegalArgumentException();
            }
            if (!estimator.hasProbabilityDistribution()) {
                throw new IllegalArgumentException();
            }
            CategoricalLabel segmentLabel = new CategoricalLabel(DataType.DOUBLE, Arrays.asList("<=" + ValueUtil.asString((Object)category), ">" + ValueUtil.asString((Object)category)));
            Schema segmentSchema = schema.toRelabeledSchema((Label)segmentLabel);
            Model model = estimator.encode(segmentSchema);
            List segmentFeatures = encoder.export(model, name = FieldNameUtil.create((String)"probability", (Object[])new Object[]{segmentLabel.getValue(1)}));
            if (segmentFeatures.size() != 1) {
                throw new IllegalArgumentException();
            }
            models.add(model);
            probabilityFeatures.addAll(segmentFeatures);
        }
        SchemaUtil.checkSize((int)estimators.size(), probabilityFeatures);
        ArrayList<RegressionTable> regressionTables = new ArrayList<RegressionTable>();
        for (int i = 0; i < estimators.size(); ++i) {
            RegressionTable regressionTable = RegressionModelUtil.createRegressionTable(Collections.singletonList((Feature)probabilityFeatures.get(i)), Collections.singletonList(1), (Number)0.0).setTargetCategory(ordinalLabel.getValue(i));
            regressionTables.add(regressionTable);
        }
        RegressionTable regressionTable = RegressionModelUtil.createRegressionTable(Collections.emptyList(), Collections.emptyList(), (Number)1.0).setTargetCategory(ordinalLabel.getValue(estimators.size()));
        regressionTables.add(regressionTable);
        RegressionModel regressionModel = new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema((Label)ordinalLabel), regressionTables).setNormalizationMethod(RegressionModel.NormalizationMethod.NONE);
        this.encodePredictProbaOutput((Model)regressionModel, DataType.DOUBLE, (DiscreteLabel)ordinalLabel);
        models.add(regressionModel);
        return MiningModelUtil.createModelChain(models, (Segmentation.MissingPredictionTreatment)Segmentation.MissingPredictionTreatment.RETURN_MISSING);
    }

    protected DiscreteLabel encodeLabel(String name, OpType opType, DataType dataType, List<?> categories, SkLearnEncoder encoder) {
        return super.encodeLabel(name, OpType.ORDINAL, DataType.STRING, categories, encoder);
    }

    public Classifier getEstimator() {
        return (Classifier)this.get("estimator", Classifier.class);
    }

    public Map<?, ? extends Classifier> getEstimators() {
        Map estimators = this.getDict("estimators_");
        Function<Object, Object> keyFunction = new Function<Object, Object>(){

            @Override
            public Object apply(Object object) {
                object = ScalarUtil.decode((Object)object);
                return Classifier.canonicalizeValue((Object)object);
            }
        };
        CastFunction<Classifier> valueFunction = new CastFunction<Classifier>(Classifier.class){

            protected String formatMessage(Object object) {
                return "The item value object (" + ClassDictUtil.formatClass((Object)object) + ") is not a supported Classifier";
            }
        };
        Map<Object, Classifier> result = estimators.entrySet().stream().collect(Collectors.toMap(entry -> keyFunction.apply(entry.getKey()), arg_0 -> OrdinalClassifier.lambda$getEstimators$1((Function)valueFunction, arg_0)));
        return result;
    }

    private Map<?, ?> getEstimatorCategories() {
        if (!this.hasattr("pmml_classes_")) {
            return null;
        }
        List classes = this.getClasses("classes_");
        List pmmlClasses = this.getClasses("pmml_classes_");
        ClassDictUtil.checkSize((Collection[])new Collection[]{classes, pmmlClasses});
        LinkedHashMap result = new LinkedHashMap();
        for (int i = 0; i < classes.size(); ++i) {
            result.put(pmmlClasses.get(i), classes.get(i));
        }
        return result;
    }

    private static /* synthetic */ Classifier lambda$getEstimators$1(Function valueFunction, Map.Entry entry) {
        return (Classifier)valueFunction.apply(entry.getValue());
    }
}

