/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.xgboost;

import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonPrimitive;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.tree.BranchNode;
import org.dmg.pmml.tree.LeafNode;
import org.dmg.pmml.tree.SimpleNode;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.CategoryManager;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.MissingValueFeature;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PredicateManager;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ThresholdFeature;
import org.jpmml.converter.ThresholdFeatureUtil;
import org.jpmml.converter.ValueUtil;
import org.jpmml.xgboost.BinaryLoadable;
import org.jpmml.xgboost.BinaryNode;
import org.jpmml.xgboost.BinaryNodeStat;
import org.jpmml.xgboost.JSONLoadable;
import org.jpmml.xgboost.JSONNode;
import org.jpmml.xgboost.JSONUtil;
import org.jpmml.xgboost.Node;
import org.jpmml.xgboost.NodeStat;
import org.jpmml.xgboost.XGBoostDataInput;

/**
 * A single XGBoost regression tree.
 *
 * <p>The tree can be loaded either from the binary model format ({@link #loadBinary(XGBoostDataInput)})
 * or from the JSON model format ({@link #loadJSON(JsonObject)}), and can then be converted into a
 * PMML {@link TreeModel} via {@link #encodeTreeModel(boolean, PredicateManager, Schema)}.
 * The node at index 0 of the node array is treated as the root.
 */
public class RegTree implements BinaryLoadable, JSONLoadable {

    private int num_roots;
    private int num_nodes;
    private int num_deleted;
    private int max_depth;
    private int num_feature;
    private int size_leaf_vector;

    private Node[] nodes;

    // Only populated by loadBinary; loadJSON does not read node statistics.
    private NodeStat[] stats;

    /**
     * Reads this tree from the binary model format.
     *
     * <p>Reads six {@code int} tree parameters, skips the reserved area via
     * {@code input.readReserved(31)}, then reads {@code num_nodes} node records followed by
     * {@code num_nodes} node statistics records.
     *
     * @param input the binary model input
     * @throws IOException if reading from {@code input} fails
     */
    @Override
    public void loadBinary(XGBoostDataInput input) throws IOException {
        this.num_roots = input.readInt();
        this.num_nodes = input.readInt();
        this.num_deleted = input.readInt();
        this.max_depth = input.readInt();
        this.num_feature = input.readInt();
        this.size_leaf_vector = input.readInt();

        input.readReserved(31);

        this.nodes = (Node[])input.readObjectArray(BinaryNode.class, this.num_nodes);
        this.stats = (NodeStat[])input.readObjectArray(BinaryNodeStat.class, this.num_nodes);
    }

    /**
     * Reads this tree from the JSON model format.
     *
     * <p>Tree parameters come from the {@code tree_param} object. Per-node attributes are stored
     * as parallel arrays (one entry per node). Each node's entries are repacked into a small
     * per-node {@link JsonObject} and loaded into a {@link JSONNode}.
     * {@code num_roots} and {@code max_depth} are not read, and node statistics are not loaded.
     *
     * @param tree the JSON object describing one tree
     * @throws IllegalArgumentException if any node has a {@code split_type} other than 0
     */
    @Override
    public void loadJSON(JsonObject tree) {
        JsonObject treeParam = tree.getAsJsonObject("tree_param");

        this.num_nodes = treeParam.getAsJsonPrimitive("num_nodes").getAsInt();
        this.num_deleted = treeParam.getAsJsonPrimitive("num_deleted").getAsInt();
        this.num_feature = treeParam.getAsJsonPrimitive("num_feature").getAsInt();
        this.size_leaf_vector = treeParam.getAsJsonPrimitive("size_leaf_vector").getAsInt();

        int[] parents = JSONUtil.toIntArray(tree.getAsJsonArray("parents"));
        int[] left_children = JSONUtil.toIntArray(tree.getAsJsonArray("left_children"));
        int[] right_children = JSONUtil.toIntArray(tree.getAsJsonArray("right_children"));
        boolean[] default_left = JSONUtil.toBooleanArray(tree.getAsJsonArray("default_left"));
        int[] split_indices = JSONUtil.toIntArray(tree.getAsJsonArray("split_indices"));
        int[] split_type = JSONUtil.toIntArray(tree.getAsJsonArray("split_type"));
        float[] split_conditions = JSONUtil.toFloatArray(tree.getAsJsonArray("split_conditions"));

        this.nodes = new Node[this.num_nodes];

        for (int i = 0; i < this.num_nodes; ++i) {
            // Only split type 0 is supported; any other split type (presumably categorical) is rejected
            if (split_type[i] != 0) {
                throw new IllegalArgumentException("Node " + i + " has unsupported split type " + split_type[i] + " (only split type 0 is supported)");
            }

            JsonObject node = new JsonObject();
            node.addProperty("parent", parents[i]);
            node.addProperty("left_child", left_children[i]);
            node.addProperty("right_child", right_children[i]);
            node.addProperty("default_left", default_left[i]);
            node.addProperty("split_index", split_indices[i]);
            node.addProperty("split_condition", split_conditions[i]);

            JSONNode jsonNode = new JSONNode();
            jsonNode.loadJSON(node);

            this.nodes[i] = jsonNode;
        }
    }

    /**
     * Returns the value of the root node when the whole tree consists of a single leaf.
     *
     * @return the root leaf value, or {@code null} if the root node is a split node
     */
    public Float getLeafValue() {
        Node root = this.nodes[0];

        return root.is_leaf() ? Float.valueOf(root.leaf_value()) : null;
    }

    /**
     * Converts this tree into a PMML regression {@link TreeModel}.
     *
     * <p>The model uses binary splits, default-child missing value handling and
     * {@code float} math context.
     *
     * @param numeric if {@code true}, {@link ThresholdFeature}s are split as continuous features
     *     instead of by explicit value lists
     * @param predicateManager factory/cache for the generated predicates
     * @param schema the feature schema; split indices are resolved against it
     * @return the encoded tree model
     */
    public TreeModel encodeTreeModel(boolean numeric, PredicateManager predicateManager, Schema schema) {
        org.dmg.pmml.tree.Node root = encodeNode(0, True.INSTANCE, numeric, new CategoryManager(), predicateManager, schema);

        return new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()), root)
            .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT)
            .setMissingValueStrategy(TreeModel.MissingValueStrategy.DEFAULT_CHILD)
            .setMathContext(MathContext.FLOAT);
    }

    /**
     * Recursively encodes the node at {@code index}, together with its subtree, as a PMML node.
     *
     * @param index index of the node in the node array; also used as the PMML node id
     * @param predicate predicate guarding the encoded node ({@code True} for the root)
     * @param numeric if {@code true}, {@link ThresholdFeature}s are treated as continuous features
     * @param categoryManager values of threshold-encoded fields that are still admissible on the
     *     current path; it is forked at every {@link ThresholdFeature} split
     * @param predicateManager factory/cache for the generated predicates
     * @param schema the feature schema; {@code split_index} is resolved against it
     * @return the encoded PMML node
     * @throws IllegalArgumentException if a continuous split feature is neither integer- nor float-typed
     */
    private org.dmg.pmml.tree.Node encodeNode(int index, Predicate predicate, boolean numeric, CategoryManager categoryManager, PredicateManager predicateManager, Schema schema) {
        Integer id = index;

        Node node = this.nodes[index];

        if (node.is_leaf()) {
            // Adding +0.0f turns a negative zero leaf value into positive zero
            Float value = Float.valueOf(node.leaf_value() + 0.0f);

            return new LeafNode(value, predicate).setId(id);
        }

        Feature feature = schema.getFeature(node.split_index());
        boolean defaultLeft = node.default_left();

        Predicate leftPredicate;
        Predicate rightPredicate;

        boolean swapChildren = false;

        CategoryManager leftCategoryManager = categoryManager;
        CategoryManager rightCategoryManager = categoryManager;

        if (feature instanceof BinaryFeature) {
            BinaryFeature binaryFeature = (BinaryFeature)feature;

            // The split condition is not consulted: left means "not equal", right means "equal"
            Object value = binaryFeature.getValue();

            leftPredicate = predicateManager.createSimplePredicate(binaryFeature, SimplePredicate.Operator.NOT_EQUAL, value);
            rightPredicate = predicateManager.createSimplePredicate(binaryFeature, SimplePredicate.Operator.EQUAL, value);
        } else if (feature instanceof MissingValueFeature) {
            MissingValueFeature missingValueFeature = (MissingValueFeature)feature;

            leftPredicate = predicateManager.createSimplePredicate(missingValueFeature, SimplePredicate.Operator.IS_NOT_MISSING, null);
            rightPredicate = predicateManager.createSimplePredicate(missingValueFeature, SimplePredicate.Operator.IS_MISSING, null);
        } else if (feature instanceof ThresholdFeature && !numeric) {
            ThresholdFeature thresholdFeature = (ThresholdFeature)feature;

            FieldName name = thresholdFeature.getName();
            Object missingValue = thresholdFeature.getMissingValue();

            float splitValue = Float.intBitsToFloat(node.split_cond());

            java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(name);

            // Unless the missing value itself is NaN, NaN values are excluded from both branches
            if (!ValueUtil.isNaN(missingValue)) {
                valueFilter = valueFilter.and(value -> !ValueUtil.isNaN(value));
            }

            List leftValues = thresholdFeature.getValues(value -> value.floatValue() < splitValue).stream()
                .filter(valueFilter)
                .collect(Collectors.toList());
            List rightValues = thresholdFeature.getValues(value -> value.floatValue() >= splitValue).stream()
                .filter(valueFilter)
                .collect(Collectors.toList());

            leftCategoryManager = leftCategoryManager.fork(name, leftValues);
            rightCategoryManager = rightCategoryManager.fork(name, rightValues);

            leftPredicate = ThresholdFeatureUtil.createPredicate(thresholdFeature, leftValues, missingValue, predicateManager);
            rightPredicate = ThresholdFeatureUtil.createPredicate(thresholdFeature, rightValues, missingValue, predicateManager);

            // List the missing-value-safe predicate first when only the right one is safe
            if (!ThresholdFeatureUtil.isMissingValueSafe(leftPredicate) && ThresholdFeatureUtil.isMissingValueSafe(rightPredicate)) {
                swapChildren = true;
            }
        } else {
            ContinuousFeature continuousFeature = feature.toContinuousFeature();

            float splitCondition = Float.intBitsToFloat(node.split_cond());

            Number splitValue;

            DataType dataType = continuousFeature.getDataType();
            switch (dataType) {
                case INTEGER:
                    // For an integer x, "x < c" is equivalent to "x < ceil(c)" for every real c.
                    // (The former "(int)(c + 1)" was off by one for integral c, and wrong for
                    // negative c because the int cast truncates toward zero.)
                    splitValue = (int)Math.ceil(splitCondition);
                    break;
                case FLOAT:
                    splitValue = splitCondition;
                    break;
                default:
                    throw new IllegalArgumentException("Expected integer or float data type for continuous feature " + continuousFeature.getName() + ", got " + dataType.value() + " data type");
            }

            leftPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_THAN, splitValue);
            rightPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_OR_EQUAL, splitValue);
        }

        org.dmg.pmml.tree.Node leftChild = encodeNode(node.left_child(), leftPredicate, numeric, leftCategoryManager, predicateManager, schema);
        org.dmg.pmml.tree.Node rightChild = encodeNode(node.right_child(), rightPredicate, numeric, rightCategoryManager, predicateManager, schema);

        org.dmg.pmml.tree.Node result = new BranchNode(null, predicate)
            .setId(id)
            .setDefaultChild(defaultLeft ? leftChild.getId() : rightChild.getId())
            .addNodes(leftChild, rightChild);

        if (swapChildren) {
            Collections.swap(result.getNodes(), 0, 1);
        }

        return result;
    }
}

