/*
 * Decompiled with CFR 0.152.
 */
package sklearn2pmml.ensemble;

import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.Visitable;
import org.dmg.pmml.VisitorAction;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelEncoder;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.model.visitors.AbstractVisitor;
import org.jpmml.python.ClassDictUtil;
import sklearn.Estimator;
import sklearn.EstimatorUtil;
import sklearn.preprocessing.MultiOneHotEncoder;

public class GBDTUtil {
    private GBDTUtil() {
    }

    public static MiningModel encodeModel(Estimator gbdt, MultiOneHotEncoder ohe, List<? extends Number> coef, Number intercept, Schema schema) {
        Map<Integer, Number> nodeScores;
        Model model = EstimatorUtil.encodeNativeLike(gbdt, schema);
        final ArrayList treeModels = new ArrayList();
        AbstractVisitor modelVisitor = new AbstractVisitor(){

            public VisitorAction visit(TreeModel treeModel) {
                treeModels.add(treeModel);
                return super.visit(treeModel);
            }
        };
        modelVisitor.applyTo((Visitable)model);
        List<List<Object>> treeCategories = ohe.getCategories();
        ClassDictUtil.checkSize((Collection[])new Collection[]{treeModels, treeCategories});
        ArrayList treeNodeScores = new ArrayList();
        int coefOffset = 0;
        for (List<Object> treeCategory : treeCategories) {
            nodeScores = new LinkedHashMap();
            for (int i = 0; i < treeCategory.size(); ++i) {
                Integer id = ValueUtil.asInteger((Number)((Number)treeCategory.get(i)));
                Number score = coef.get(coefOffset + i);
                if (ValueUtil.isZeroLike((Number)score)) {
                    score = 0.0;
                }
                nodeScores.put(id, score);
            }
            treeNodeScores.add(nodeScores);
            coefOffset += treeCategory.size();
        }
        ClassDictUtil.checkSize((int)coefOffset, (Collection[])new Collection[]{coef});
        for (int i = 0; i < treeModels.size(); ++i) {
            TreeModel treeModel = (TreeModel)treeModels.get(i);
            nodeScores = (Map)treeNodeScores.get(i);
            treeModel.setMiningFunction(MiningFunction.REGRESSION).setMathContext(null);
            AbstractVisitor treeModelVisitor = new AbstractVisitor(){

                public VisitorAction visit(Node node) {
                    Object id = node.getId();
                    if (id instanceof String) {
                        String string = (String)id;
                        id = Integer.parseInt(string);
                    } else if (id instanceof Number) {
                        Number number = (Number)id;
                        id = ValueUtil.asInteger((Number)number);
                    } else {
                        throw new IllegalArgumentException(String.valueOf(id));
                    }
                    if (node.hasScoreDistributions()) {
                        List scoreDistributions = node.getScoreDistributions();
                        scoreDistributions.clear();
                    }
                    Number score = (Number)nodeScores.get((Integer)id);
                    node.setScore((Object)score);
                    return super.visit(node);
                }
            };
            treeModelVisitor.applyTo((Visitable)treeModel);
        }
        ModelEncoder encoder = schema.getEncoder();
        Label label = schema.getLabel();
        ContinuousLabel continuousLabel = label instanceof ContinuousLabel ? (ContinuousLabel)label : new ContinuousLabel(DataType.DOUBLE);
        MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema((Label)continuousLabel)).setSegmentation(MiningModelUtil.createSegmentation((Segmentation.MultipleModelMethod)Segmentation.MultipleModelMethod.SUM, (Segmentation.MissingPredictionTreatment)Segmentation.MissingPredictionTreatment.RETURN_MISSING, treeModels)).setTargets(ModelUtil.createRescaleTargets(null, (Number)intercept, (ContinuousLabel)continuousLabel));
        encoder.transferContent(model, (Model)miningModel);
        return miningModel;
    }
}

