/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.lightgbm;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.SimplePredicate;
import org.dmg.pmml.True;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PredicateManager;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.lightgbm.LightGBMUtil;
import org.jpmml.lightgbm.Section;

/**
 * A single decision tree read from a LightGBM model {@link Section}, together with the logic
 * that translates it into a PMML {@link TreeModel}.
 *
 * <p>Node indexing (as used by {@link #encodeNode}): a non-negative index addresses an internal
 * (split) node in the {@code *_child_}, {@code split_feature_real_}, {@code threshold_},
 * {@code decision_type_} and {@code internal_*} arrays; a negative index {@code i} addresses
 * leaf {@code ~i} in the {@code leaf_*} arrays.
 */
public class Tree {
    // Per-tree arrays, populated by load(Section).
    private int num_leaves_;
    private int num_cat_;
    private int[] left_child_;
    private int[] right_child_;
    private int[] split_feature_real_;
    private double[] threshold_;
    private int[] decision_type_;
    private double[] leaf_value_;
    private int[] leaf_count_;
    private double[] internal_value_;
    private int[] internal_count_;
    private int[] cat_boundaries_;
    private long[] cat_threshold_;

    /** Bit of {@code decision_type} marking a categorical (bitset) split. */
    private static final int MASK_CATEGORICAL = 1;
    /** Bit of {@code decision_type} that selects the left child as the PMML default child. */
    private static final int MASK_DEFAULT_LEFT = 2;

    /**
     * Populates this tree from a parsed model section.
     *
     * <p>Internal-node arrays are read with {@code num_leaves - 1} entries and leaf arrays with
     * {@code num_leaves} entries. The categorical bitset arrays are only read when
     * {@code num_cat > 0}; otherwise they stay {@code null}.
     *
     * @param section the section holding this tree's key/value entries
     */
    public void load(Section section) {
        this.num_leaves_ = section.getInt("num_leaves");
        this.num_cat_ = section.getInt("num_cat");

        int numInternal = this.num_leaves_ - 1;

        this.left_child_ = section.getIntArray("left_child", numInternal);
        this.right_child_ = section.getIntArray("right_child", numInternal);
        this.split_feature_real_ = section.getIntArray("split_feature", numInternal);
        this.threshold_ = section.getDoubleArray("threshold", numInternal);
        this.decision_type_ = section.getIntArray("decision_type", numInternal);
        this.leaf_value_ = section.getDoubleArray("leaf_value", this.num_leaves_);
        this.leaf_count_ = section.getIntArray("leaf_count", this.num_leaves_);
        this.internal_value_ = section.getDoubleArray("internal_value", numInternal);
        this.internal_count_ = section.getIntArray("internal_count", numInternal);

        if (this.num_cat_ > 0) {
            // num_cat + 1 boundaries: bitset k occupies words [cat_boundaries_[k], cat_boundaries_[k + 1]).
            this.cat_boundaries_ = section.getIntArray("cat_boundaries", this.num_cat_ + 1);
            // NOTE(review): -1 presumably means "length not known in advance" — confirm against Section.
            this.cat_threshold_ = section.getUnsignedIntArray("cat_threshold", -1);
        }
    }

    /**
     * Encodes this tree as a regression {@link TreeModel} with binary splits and the
     * {@code defaultChild} missing-value strategy.
     *
     * @param predicateManager factory used to create the split predicates
     * @param schema schema resolving split feature indices and providing the label
     * @return the encoded tree model, rooted at a node with a {@code True} predicate
     */
    public TreeModel encodeTreeModel(PredicateManager predicateManager, Schema schema) {
        Node root = new Node()
            .setPredicate(new True());

        encodeNode(root, predicateManager, Collections.<FieldName, List<String>>emptyMap(), 0, schema);

        return new TreeModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema((Label) schema.getLabel()), root)
            .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT)
            .setMissingValueStrategy(TreeModel.MissingValueStrategy.DEFAULT_CHILD);
    }

    /**
     * Recursively encodes the LightGBM node {@code index} (and its subtree) into {@code parent}.
     *
     * <p>For a leaf ({@code index < 0}) the parent receives the leaf score and record count.
     * For an internal node it receives a record count, a left and a right child, and a default
     * child id.
     *
     * @param parent the PMML node to populate (callers in this class set its predicate first)
     * @param predicateManager factory used to create the split predicates
     * @param fieldValues for categorical fields already split on along the current path, the
     *        category values still reachable at this node
     * @param index the LightGBM node index; negative values encode leaf {@code ~index}
     * @param schema schema resolving split feature indices to features
     * @throws IllegalArgumentException if a split's decision type or threshold is inconsistent
     *         with the kind of feature it splits on
     */
    public void encodeNode(Node parent, PredicateManager predicateManager, Map<FieldName, List<String>> fieldValues, int index, Schema schema) {
        parent.setId(String.valueOf(index));

        if (index < 0) {
            // Leaves are referenced by the bitwise complement of their position in the leaf arrays.
            int leafIndex = ~index;

            parent.setScore(ValueUtil.formatValue((Number) this.leaf_value_[leafIndex]));
            parent.setRecordCount(Double.valueOf(this.leaf_count_[leafIndex]));
            return;
        }

        parent.setScore(null);
        parent.setRecordCount(Double.valueOf(this.internal_count_[index]));

        Feature feature = schema.getFeature(this.split_feature_real_[index]);
        double threshold = this.threshold_[index];
        int decisionType = this.decision_type_[index];

        boolean defaultLeft = hasDefaultLeftMask(decisionType);

        Map<FieldName, List<String>> leftFieldValues = fieldValues;
        Map<FieldName, List<String>> rightFieldValues = fieldValues;

        Predicate leftPredicate;
        Predicate rightPredicate;

        if (feature instanceof BinaryFeature) {
            BinaryFeature binaryFeature = (BinaryFeature) feature;

            // Only a numeric split at 0.5 is supported: "!= value" goes left, "== value" goes right.
            if (hasCategoricalMask(decisionType) || threshold != 0.5) {
                throw new IllegalArgumentException("Unsupported split on binary feature " + binaryFeature.getName()
                    + " (decision_type=" + decisionType + ", threshold=" + threshold + ")");
            }

            String value = binaryFeature.getValue();

            leftPredicate = predicateManager.createSimplePredicate(binaryFeature, SimplePredicate.Operator.NOT_EQUAL, value);
            rightPredicate = predicateManager.createSimplePredicate(binaryFeature, SimplePredicate.Operator.EQUAL, value);
        } else if (feature instanceof CategoricalFeature) {
            CategoricalFeature categoricalFeature = (CategoricalFeature) feature;

            if (!hasCategoricalMask(decisionType)) {
                throw new IllegalArgumentException("Non-categorical split on categorical feature " + categoricalFeature.getName()
                    + " (decision_type=" + decisionType + ")");
            }

            FieldName name = categoricalFeature.getName();

            // Prefer the categories still reachable after earlier splits on the same field.
            List<String> values = fieldValues.get(name);
            if (values == null) {
                values = categoricalFeature.getValues();
            }

            // For categorical splits the threshold is an index into cat_boundaries_, not a value.
            int catIdx = ValueUtil.asInt((Number) threshold);

            List<String> leftValues = selectValues(values, catIdx, true);
            List<String> rightValues = selectValues(values, catIdx, false);

            leftFieldValues = new HashMap<>(fieldValues);
            leftFieldValues.put(name, leftValues);

            rightFieldValues = new HashMap<>(fieldValues);
            rightFieldValues.put(name, rightValues);

            leftPredicate = predicateManager.createSimpleSetPredicate(categoricalFeature, leftValues);
            rightPredicate = predicateManager.createSimpleSetPredicate(categoricalFeature, rightValues);

            // The default-left bit is ignored here: categorical splits always default to the right child.
            defaultLeft = false;
        } else {
            ContinuousFeature continuousFeature = feature.toContinuousFeature();

            if (hasCategoricalMask(decisionType)) {
                throw new IllegalArgumentException("Categorical split on continuous feature " + continuousFeature.getName()
                    + " (decision_type=" + decisionType + ")");
            }

            String value = ValueUtil.formatValue((Number) threshold);

            leftPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value);
            rightPredicate = predicateManager.createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value);
        }

        Node leftChild = new Node().setPredicate(leftPredicate);
        encodeNode(leftChild, predicateManager, leftFieldValues, this.left_child_[index], schema);

        Node rightChild = new Node().setPredicate(rightPredicate);
        encodeNode(rightChild, predicateManager, rightFieldValues, this.right_child_[index], schema);

        parent.addNodes(leftChild, rightChild);
        parent.setDefaultChild(defaultLeft ? leftChild.getId() : rightChild.getId());
    }

    /**
     * Computes the category list of one branch of categorical split {@code cat_idx}.
     *
     * @param values the category values reachable at the current node
     * @param cat_idx index into {@code cat_boundaries_} selecting this split's bitset
     * @param left {@code true} for the left branch, {@code false} for the right branch
     * @return for the left branch, the formatted category of every set bit (ascending bit order);
     *         for the right branch, a copy of {@code values} with those categories removed
     * @throws IllegalArgumentException if the tree has no categorical bitsets, if the left branch
     *         would be empty, or if the right branch would remove nothing from {@code values}
     */
    private List<String> selectValues(List<String> values, int cat_idx, boolean left) {
        if (this.cat_boundaries_ == null || this.cat_threshold_ == null) {
            throw new IllegalArgumentException("Categorical split encountered, but the tree declares no categorical bitsets (num_cat="
                + this.num_cat_ + ")");
        }

        List<String> result = left ? new ArrayList<String>() : new ArrayList<String>(values);

        int offset = this.cat_boundaries_[cat_idx];
        int n = this.cat_boundaries_[cat_idx + 1] - offset;

        for (int i = 0; i < n; i++) {
            for (int j = 0; j < 32; j++) {
                int cat = i * 32 + j;

                if (!findInBitset(this.cat_threshold_, offset, n, cat)) {
                    continue;
                }

                String value = (String) LightGBMUtil.CATEGORY_FORMATTER.apply(cat);

                if (left) {
                    result.add(value);
                } else {
                    result.remove(value);
                }
            }
        }

        if (left && result.isEmpty()) {
            throw new IllegalArgumentException("Categorical split " + cat_idx + " selects no categories for the left branch");
        }
        if (!left && result.equals(values)) {
            throw new IllegalArgumentException("Categorical split " + cat_idx + " removes no categories from the right branch");
        }

        return result;
    }

    /**
     * Reports whether every split on {@code feature} is a numeric split at threshold 0.5.
     *
     * @param feature the split feature index
     * @return {@code TRUE} if all splits on the feature match, {@code FALSE} if any split is
     *         categorical or uses another threshold, {@code null} if the tree never splits on it
     */
    Boolean isBinary(int feature) {
        Boolean result = null;

        for (int i = 0; i < this.split_feature_real_.length; i++) {
            if (this.split_feature_real_[i] != feature) {
                continue;
            }

            if (hasCategoricalMask(this.decision_type_[i]) || this.threshold_[i] != 0.5) {
                return Boolean.FALSE;
            }

            result = Boolean.TRUE;
        }

        return result;
    }

    /**
     * Reports whether every split on {@code feature} is a categorical split.
     *
     * @param feature the split feature index
     * @return {@code TRUE} if all splits on the feature are categorical, {@code FALSE} if any is
     *         not, {@code null} if the tree never splits on it
     */
    Boolean isCategorical(int feature) {
        Boolean result = null;

        for (int i = 0; i < this.split_feature_real_.length; i++) {
            if (this.split_feature_real_[i] != feature) {
                continue;
            }

            if (!hasCategoricalMask(this.decision_type_[i])) {
                return Boolean.FALSE;
            }

            result = Boolean.TRUE;
        }

        return result;
    }

    private static boolean hasCategoricalMask(int decision_type) {
        return getDecisionType(decision_type, MASK_CATEGORICAL) == MASK_CATEGORICAL;
    }

    private static boolean hasDefaultLeftMask(int decision_type) {
        return getDecisionType(decision_type, MASK_DEFAULT_LEFT) == MASK_DEFAULT_LEFT;
    }

    /**
     * Returns the bits of {@code decision_type} selected by {@code mask}.
     */
    static int getDecisionType(int decision_type, int mask) {
        return decision_type & mask;
    }

    /**
     * Returns the two-bit field stored in bits 2-3 of {@code decision_type}.
     * NOTE(review): presumably LightGBM's missing-value type — confirm against LightGBM's tree.h.
     */
    static int getMissingType(int decision_type) {
        return getDecisionType(decision_type >> 2, 3);
    }

    /**
     * Tests bit {@code pos} of the bitset stored as 32-bit words in
     * {@code bits[bitOffset .. bitOffset + n)}. Positions beyond the bitset report {@code false}.
     */
    private static boolean findInBitset(long[] bits, int bitOffset, int n, int pos) {
        int word = pos / 32;
        if (word >= n) {
            return false;
        }

        int bit = pos % 32;
        return ((bits[bitOffset + word] >> bit) & 1L) == 1L;
    }
}

