/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.lightgbm;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.Interval;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.TypeDefinitionField;
import org.dmg.pmml.Visitable;
import org.dmg.pmml.Visitor;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Decorator;
import org.jpmml.converter.Feature;
import org.jpmml.converter.ImportanceDecorator;
import org.jpmml.converter.Label;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.Schema;
import org.jpmml.lightgbm.BinomialLogisticRegression;
import org.jpmml.lightgbm.LightGBMEncoder;
import org.jpmml.lightgbm.LightGBMUtil;
import org.jpmml.lightgbm.MultinomialLogisticRegression;
import org.jpmml.lightgbm.ObjectiveFunction;
import org.jpmml.lightgbm.PoissonRegression;
import org.jpmml.lightgbm.Regression;
import org.jpmml.lightgbm.Section;
import org.jpmml.lightgbm.Tree;
import org.jpmml.lightgbm.visitors.TreeModelCompactor;

public class GBDT {

    /** Value of the header's {@code max_feature_idx} entry; the per-feature arrays hold {@code max_feature_idx_ + 1} items. */
    private int max_feature_idx_;

    /** Value of the header's {@code label_index} entry. */
    private int label_idx_;

    /** Feature names, indexed by feature position. */
    private String[] feature_names_;

    /** Per-feature info strings, indexed by feature position; interpreted by {@link LightGBMUtil} as a value list or an interval. */
    private String[] feature_infos_;

    /** {@code true} when the header section contains a {@code boost_from_average} key. */
    private boolean boost_from_average_;

    /** Objective function parsed from the header's {@code objective} entry. */
    private ObjectiveFunction object_function_;

    /** Trees, in the order of their {@code Tree=N} sections. */
    private Tree[] models_;

    /** Raw feature importance strings keyed by feature name; empty when the model has no importance section. */
    private Map<String, String> feature_importances = Collections.emptyMap();

    /** Category code that is removed from every categorical field's value domain (named as the "missing" category). */
    private static final Integer CATEGORY_MISSING = -1;

    /**
     * Loads the model from its parsed sections.
     *
     * <p>The first section must have the id {@code tree}; it supplies the header values. It may be followed
     * by sections with ids {@code Tree=0}, {@code Tree=1}, ... (each loaded as a {@link Tree}), and then
     * optionally by a {@code feature importances:} section. Any further sections are ignored.
     *
     * @param sections the parsed sections of a model file
     * @throws IllegalArgumentException if {@code sections} is empty or its first section is not {@code tree}
     */
    public void load(List<Section> sections) {
        if (sections.isEmpty()) {
            throw new IllegalArgumentException("Expected at least one section, got none");
        }

        int index = 0;

        Section header = sections.get(index);
        if (!header.checkId("tree")) {
            throw new IllegalArgumentException("Expected the first section to be \"tree\"");
        }

        this.max_feature_idx_ = header.getInt("max_feature_idx");
        this.label_idx_ = header.getInt("label_index");
        this.feature_names_ = header.getStringArray("feature_names", this.max_feature_idx_ + 1);
        this.feature_infos_ = header.getStringArray("feature_infos", this.max_feature_idx_ + 1);
        this.boost_from_average_ = header.containsKey("boost_from_average");
        this.object_function_ = parseObjectiveFunction(header.getString("objective"));

        index++;

        // Tree sections are numbered from zero, one position behind their list index (the header is at index 0)
        List<Tree> trees = new ArrayList<Tree>();
        while (index < sections.size()) {
            Section section = sections.get(index);
            if (!section.checkId("Tree=" + (index - 1))) {
                break;
            }

            Tree tree = new Tree();
            tree.load(section);
            trees.add(tree);

            index++;
        }
        this.models_ = trees.toArray(new Tree[trees.size()]);

        // Reset first, so that reloading a model without an importance section does not keep stale values
        this.feature_importances = Collections.emptyMap();
        if (index < sections.size()) {
            Section section = sections.get(index);
            if (section.checkId("feature importances:")) {
                this.feature_importances = loadFeatureSection(section);
            }
        }
    }

    /**
     * Encodes this model as a PMML document.
     *
     * <p>Each feature becomes a data field. Categorical features get a sorted integer value domain (without
     * {@link #CATEGORY_MISSING}). Binary features get the values {@code "0"} and {@code "1"}. All other
     * features become continuous double fields annotated with the interval parsed from their feature info.
     * Every field is decorated with its importance (possibly {@code null} when unknown).
     *
     * @param targetField name of the target field, or {@code null} to use {@code _target}
     * @param targetCategories target categories, passed to the objective function's label encoding
     * @param numIteration passed through to {@link #encodeMiningModel(Integer, boolean, Schema)}
     * @param transform whether to apply post-processing visitors to the mining model
     * @return the encoded PMML document
     * @throws IllegalArgumentException if a feature is detected as both categorical and binary
     */
    public PMML encodePMML(FieldName targetField, List<String> targetCategories, Integer numIteration, boolean transform) {
        LightGBMEncoder encoder = new LightGBMEncoder();

        if (targetField == null) {
            targetField = FieldName.create("_target");
        }

        Label label = this.object_function_.encodeLabel(targetField, targetCategories, encoder);

        List<Feature> features = new ArrayList<Feature>();

        for (int i = 0; i < this.feature_names_.length; i++) {
            String featureName = this.feature_names_[i];
            String featureInfo = this.feature_infos_[i];

            // An undecided (null) binary verdict means "not binary"; an undecided categorical
            // verdict falls back to whether the feature info is a value list
            boolean binary = Boolean.TRUE.equals(isBinary(i));

            Boolean categoricalVerdict = isCategorical(i);
            boolean categorical = (categoricalVerdict != null) ? categoricalVerdict.booleanValue() : LightGBMUtil.isValues(featureInfo);

            FieldName activeField = FieldName.create(featureName);

            if (categorical) {
                if (binary) {
                    throw new IllegalArgumentException("Feature " + featureName + " is both categorical and binary");
                }

                List<Integer> categories = new ArrayList<Integer>(LightGBMUtil.parseValues(featureInfo));
                // Drop every occurrence of the missing-category code from the value domain
                categories.removeAll(Collections.singleton(CATEGORY_MISSING));
                Collections.sort(categories);

                DataField dataField = encoder.createDataField(activeField, OpType.CATEGORICAL, DataType.INTEGER);
                PMMLUtil.addValues(dataField, (List)Lists.transform(categories, LightGBMUtil.CATEGORY_FORMATTER));

                features.add(new CategoricalFeature(encoder, dataField));
            } else if (binary) {
                DataField dataField = encoder.createDataField(activeField, OpType.CATEGORICAL, DataType.INTEGER, Arrays.asList("0", "1"));

                features.add(new BinaryFeature(encoder, dataField, "1"));
            } else {
                Interval interval = LightGBMUtil.parseInterval(featureInfo);

                DataField dataField = encoder.createDataField(activeField, OpType.CONTINUOUS, DataType.DOUBLE);
                PMMLUtil.addIntervals(dataField, Arrays.asList(interval));

                features.add(new ContinuousFeature(encoder, dataField));
            }

            ImportanceDecorator importanceDecorator = new ImportanceDecorator().setImportance(getFeatureImportance(featureName));
            encoder.addDecorator(activeField, importanceDecorator);
        }

        Schema schema = new Schema(label, features);

        MiningModel miningModel = encodeMiningModel(numIteration, transform, schema);

        return encoder.encodePMML(miningModel);
    }

    /**
     * Encodes the loaded trees as a mining model via the objective function.
     *
     * @param numIteration passed through to the objective function's mining model encoding
     * @param transform when {@code true}, a {@link TreeModelCompactor} is applied to the result
     * @param schema the label and features the trees refer to
     * @return the (optionally compacted) mining model
     */
    public MiningModel encodeMiningModel(Integer numIteration, boolean transform, Schema schema) {
        MiningModel miningModel = this.object_function_.encodeMiningModel(Arrays.asList(this.models_), numIteration, schema);

        if (transform) {
            Visitor compactor = new TreeModelCompactor();
            compactor.applyTo(miningModel);
        }

        return miningModel;
    }

    /**
     * Returns the feature names.
     *
     * @return the internal array (not a copy); callers should not modify it
     */
    public String[] getFeatureNames() {
        return this.feature_names_;
    }

    /**
     * Returns the per-feature info strings.
     *
     * @return the internal array (not a copy); callers should not modify it
     */
    public String[] getFeatureInfos() {
        return this.feature_infos_;
    }

    /**
     * Decides whether a feature is binary.
     *
     * @param feature feature position
     * @return {@code FALSE} if the feature info is not a binary interval or any tree reports the feature as
     *         non-binary; {@code TRUE} if at least one tree reports it as binary and none reports otherwise;
     *         {@code null} if no tree has an opinion
     */
    Boolean isBinary(int feature) {
        if (!LightGBMUtil.isBinaryInterval(this.feature_infos_[feature])) {
            return Boolean.FALSE;
        }

        Boolean result = null;
        for (Tree tree : this.models_) {
            Boolean verdict = tree.isBinary(feature);
            if (verdict == null) {
                continue;
            }
            if (!verdict.booleanValue()) {
                return Boolean.FALSE;
            }
            result = Boolean.TRUE;
        }
        return result;
    }

    /**
     * Decides whether a feature is categorical.
     *
     * @param feature feature position
     * @return {@code FALSE} if the feature info is not a value list or any tree reports the feature as
     *         non-categorical; {@code TRUE} if at least one tree reports it as categorical and none reports
     *         otherwise; {@code null} if no tree has an opinion
     */
    Boolean isCategorical(int feature) {
        if (!LightGBMUtil.isValues(this.feature_infos_[feature])) {
            return Boolean.FALSE;
        }

        Boolean result = null;
        for (Tree tree : this.models_) {
            Boolean verdict = tree.isCategorical(feature);
            if (verdict == null) {
                continue;
            }
            if (!verdict.booleanValue()) {
                return Boolean.FALSE;
            }
            result = Boolean.TRUE;
        }
        return result;
    }

    /**
     * Looks up the importance of a feature.
     *
     * @param featureName feature name
     * @return the parsed importance, or {@code null} if none was loaded for this feature
     * @throws NumberFormatException if the stored value is not a valid double
     */
    Double getFeatureImportance(String featureName) {
        String value = this.feature_importances.get(featureName);
        return (value != null) ? Double.valueOf(value) : null;
    }

    /**
     * Copies the entries of a feature importance section, keeping only keys that are known feature names.
     * The section's iteration order is preserved.
     */
    private Map<String, String> loadFeatureSection(Section section) {
        Map<String, String> result = new LinkedHashMap<String, String>(section);
        // A hash set keeps retainAll linear instead of scanning the name array for every key
        result.keySet().retainAll(new HashSet<String>(Arrays.asList(this.feature_names_)));
        return result;
    }

    /**
     * Parses an objective function specification.
     *
     * <p>The first token names the objective. The remaining tokens are added to a parameter {@code Section}
     * using {@code ':'} as the separator (presumably {@code key:value} pairs). Supported objectives are the
     * regression aliases (mapped to {@link Regression}), {@code poisson}, {@code binary} (reads the
     * {@code sigmoid} parameter) and {@code multiclass} (reads the {@code num_class} parameter).
     *
     * @param string the objective specification
     * @return the corresponding objective function
     * @throws IllegalArgumentException if the specification has no tokens or names an unsupported objective
     */
    public static ObjectiveFunction parseObjectiveFunction(String string) {
        String[] tokens = LightGBMUtil.parseStringArray(string, -1);
        if (tokens.length == 0) {
            throw new IllegalArgumentException(string);
        }

        String objective = tokens[0];

        Section section = new Section();
        for (int i = 1; i < tokens.length; i++) {
            section.put(tokens[i], ':');
        }

        switch (objective) {
            case "regression":
            case "regression_l2":
            case "mean_squared_error":
            case "mse":
            case "regression_l1":
            case "mean_absolute_error":
            case "mae":
            case "huber":
            case "fair":
                return new Regression();
            case "poisson":
                return new PoissonRegression();
            case "binary":
                return new BinomialLogisticRegression(section.getDouble("sigmoid"));
            case "multiclass":
                return new MultinomialLogisticRegression(section.getInt("num_class"));
            default:
                throw new IllegalArgumentException(objective);
        }
    }
}

