/*
 * Decompiled with CFR 0.152.
 */
package interpret.glassbox.ebm;

import com.google.common.collect.Iterables;
import interpret.glassbox.ebm.ExplainableBoostingUtil;
import interpret.glassbox.ebm.HasExplainableBooster;
import java.util.AbstractList;
import java.util.Arrays;
import java.util.List;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Model;
import org.dmg.pmml.regression.RegressionModel;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.DiscreteLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Schema;
import org.jpmml.converter.regression.RegressionModelUtil;
import org.jpmml.python.HasArray;
import sklearn.Classifier;

public class ExplainableBoostingClassifier
extends Classifier
implements HasExplainableBooster {
    private static final String LINK_LOGIT = "logit";

    public ExplainableBoostingClassifier(String module, String name) {
        super(module, name);
    }

    public Model encodeModel(Schema schema) {
        List<Number> intercept = this.getIntercept();
        RegressionModel.NormalizationMethod normalizationMethod = ExplainableBoostingClassifier.parseLink(this.getLink());
        CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();
        final List<Feature> features = ExplainableBoostingUtil.encodeExplainableBooster(this, schema);
        AbstractList<Number> coefficients = new AbstractList<Number>(){

            @Override
            public int size() {
                return features.size();
            }

            @Override
            public Number get(int index) {
                return 1.0;
            }
        };
        if (categoricalLabel.size() != 2) {
            throw new IllegalArgumentException();
        }
        RegressionModel regressionModel = RegressionModelUtil.createBinaryLogisticClassification(features, (List)coefficients, (Number)((Number)Iterables.getOnlyElement(intercept)), (RegressionModel.NormalizationMethod)normalizationMethod, (boolean)false, (Schema)schema);
        this.encodePredictProbaOutput((Model)regressionModel, DataType.DOUBLE, (DiscreteLabel)categoricalLabel);
        return regressionModel;
    }

    @Override
    public List<List<?>> getBins() {
        return this.getList("bins_", List.class);
    }

    @Override
    public List<String> getFeatureTypesIn() {
        return this.getEnumList("feature_types_in_", arg_0 -> ((ExplainableBoostingClassifier)this).getStringList(arg_0), Arrays.asList("continuous", "nominal"));
    }

    public List<Number> getIntercept() {
        return this.getNumberArray("intercept_");
    }

    public String getLink() {
        return (String)this.getEnum("link_", arg_0 -> ((ExplainableBoostingClassifier)this).getString(arg_0), Arrays.asList(LINK_LOGIT));
    }

    @Override
    public List<Object[]> getTermFeatures() {
        return this.getTupleList("term_features_");
    }

    @Override
    public List<HasArray> getTermScores() {
        return this.getArrayList("term_scores_");
    }

    private static RegressionModel.NormalizationMethod parseLink(String link) {
        switch (link) {
            case "logit": {
                return RegressionModel.NormalizationMethod.LOGIT;
            }
        }
        throw new IllegalArgumentException(link);
    }
}

