/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.meta;

import java.util.Collections;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.IterativeClassifier;
import weka.classifiers.RandomizableIteratedSingleClassifierEnhancer;
import weka.classifiers.Sourcable;
import weka.classifiers.rules.ZeroR;
import weka.classifiers.trees.DecisionStump;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.Randomizable;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/**
 * Boosts a nominal-class base classifier with the AdaBoost.M1 algorithm.
 *
 * <p>Training proceeds iteratively through the {@link IterativeClassifier} contract:
 * {@link #initializeClassifier(Instances)} prepares a working copy of the data, every call to
 * {@link #next()} trains one more base classifier and re-weights the working data, and
 * {@link #done()} releases the training state. If no base classifier ends up in the model
 * (for example because the data contains only the class attribute), predictions fall back to a
 * ZeroR model.
 */
public class AdaBoostM1
extends RandomizableIteratedSingleClassifierEnhancer
implements WeightedInstancesHandler,
Sourcable,
TechnicalInformationHandler,
IterativeClassifier {
    static final long serialVersionUID = -1178107808933117974L;

    /** Maximum number of resampling attempts made while a resampled base model has zero error. */
    private static final int MAX_NUM_RESAMPLING_ITERATIONS = 10;

    /** Voting weight, log((1 - error) / error), of each base classifier in the model. */
    protected double[] m_Betas;
    /** Number of leading entries of {@code m_Classifiers} that belong to the model. */
    protected int m_NumIterationsPerformed;
    /** Percentage of the weight mass used to train each base classifier (100 = all data). */
    protected int m_WeightThreshold = 100;
    /** Whether resampling is used instead of reweighting. */
    protected boolean m_UseResampling;
    /** Number of class values seen in the training data. */
    protected int m_NumClasses;
    /** Fallback model, kept only when no base classifier made it into the model. */
    protected Classifier m_ZeroR;
    /** Working copy of the training data whose weights are updated after each iteration; null after {@link #done()}. */
    protected Instances m_TrainingData;
    /** Random source used for resampling and for seeding randomizable base classifiers. */
    protected Random m_RandomInstance;

    /** Creates a booster that uses a DecisionStump as its default base classifier. */
    public AdaBoostM1() {
        this.m_Classifier = new DecisionStump();
    }

    /**
     * Returns a description of this classifier for display in the GUI.
     *
     * @return the description, including the bibliographic reference
     */
    public String globalInfo() {
        return "Class for boosting a nominal class classifier using the Adaboost M1 method. Only nominal class problems can be tackled. Often dramatically improves performance, but sometimes overfits.\n\nFor more information, see\n\n" + getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference for the AdaBoost.M1 paper.
     *
     * @return the technical information about this class
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Yoav Freund and Robert E. Schapire");
        result.setValue(TechnicalInformation.Field.TITLE, "Experiments with a new boosting algorithm");
        result.setValue(TechnicalInformation.Field.BOOKTITLE, "Thirteenth International Conference on Machine Learning");
        result.setValue(TechnicalInformation.Field.YEAR, "1996");
        result.setValue(TechnicalInformation.Field.PAGES, "148-156");
        result.setValue(TechnicalInformation.Field.PUBLISHER, "Morgan Kaufmann");
        result.setValue(TechnicalInformation.Field.ADDRESS, "San Francisco");
        return result;
    }

    /**
     * Returns the class name of the default base classifier.
     *
     * @return the fully qualified name of DecisionStump
     */
    @Override
    protected String defaultClassifierString() {
        return "weka.classifiers.trees.DecisionStump";
    }

    /**
     * Selects the heaviest instances whose weights together exceed the given fraction of the total
     * weight mass. Instances tied in weight with the last selected instance are also kept, so the
     * cut never splits a group of equal weights.
     *
     * @param data     the weighted instances to select from
     * @param quantile fraction (0..1) of the total weight mass to cover
     * @return a new dataset containing copies of the selected instances, heaviest first
     */
    protected Instances selectWeightQuantile(Instances data, double quantile) {
        int numInstances = data.numInstances();
        Instances trainData = new Instances(data, numInstances);
        double[] weights = new double[numInstances];
        double totalWeight = 0.0;
        for (int i = 0; i < numInstances; i++) {
            weights[i] = data.instance(i).weight();
            totalWeight += weights[i];
        }
        double weightMassToSelect = totalWeight * quantile;
        int[] sortedIndices = Utils.sort(weights);
        double selectedWeight = 0.0;
        // Walk from the heaviest instance downwards.
        for (int i = numInstances - 1; i >= 0; i--) {
            int index = sortedIndices[i];
            trainData.add((Instance) data.instance(index).copy());
            selectedWeight += weights[index];
            // Stop once enough mass is covered, unless the next instance has the same weight.
            if (selectedWeight > weightMassToSelect && i > 0 && weights[index] != weights[sortedIndices[i - 1]]) {
                break;
            }
        }
        if (m_Debug) {
            System.err.println("Selected " + trainData.numInstances() + " out of " + numInstances);
        }
        return trainData;
    }

    /**
     * Returns the options understood by this classifier, followed by those of the superclass.
     *
     * @return an enumeration of all available options
     */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> newVector = new Vector<Option>();
        newVector.addElement(new Option("\tPercentage of weight mass to base training on.\n\t(default 100, reduce to around 90 speed up)", "P", 1, "-P <num>"));
        newVector.addElement(new Option("\tUse resampling for boosting.", "Q", 0, "-Q"));
        newVector.addAll(Collections.list(super.listOptions()));
        return newVector.elements();
    }

    /**
     * Parses the options: {@code -P <num>} (weight threshold, default 100) and {@code -Q}
     * (use resampling), then the superclass options.
     *
     * @param options the list of options
     * @throws Exception if an option is invalid or unrecognized options remain
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        String thresholdString = Utils.getOption('P', options);
        if (thresholdString.length() != 0) {
            setWeightThreshold(Integer.parseInt(thresholdString));
        } else {
            setWeightThreshold(100);
        }
        setUseResampling(Utils.getFlag('Q', options));
        super.setOptions(options);
        Utils.checkForRemainingOptions(options);
    }

    /**
     * Returns the current settings as an option array.
     *
     * @return the options of this classifier, followed by those of the superclass
     */
    @Override
    public String[] getOptions() {
        Vector<String> result = new Vector<String>();
        if (getUseResampling()) {
            result.add("-Q");
        }
        result.add("-P");
        result.add("" + getWeightThreshold());
        Collections.addAll(result, super.getOptions());
        return result.toArray(new String[result.size()]);
    }

    /** @return tip text for the weight threshold property */
    public String weightThresholdTipText() {
        return "Weight threshold for weight pruning.";
    }

    /** @param threshold percentage of weight mass used to train each base classifier */
    public void setWeightThreshold(int threshold) {
        this.m_WeightThreshold = threshold;
    }

    /** @return percentage of weight mass used to train each base classifier */
    public int getWeightThreshold() {
        return this.m_WeightThreshold;
    }

    /** @return tip text for the use-resampling property */
    public String useResamplingTipText() {
        return "Whether resampling is used instead of reweighting.";
    }

    /** @param r whether resampling is used instead of reweighting */
    public void setUseResampling(boolean r) {
        this.m_UseResampling = r;
    }

    /** @return whether resampling is used instead of reweighting */
    public boolean getUseResampling() {
        return this.m_UseResampling;
    }

    /**
     * Returns the capabilities of the base setup restricted to nominal/binary class problems.
     *
     * @return the capabilities of this classifier
     */
    @Override
    public Capabilities getCapabilities() {
        // Query an unmodified copy; `result` is stripped of all class capabilities below.
        Capabilities baseCapabilities = super.getCapabilities();
        Capabilities result = super.getCapabilities();
        result.disableAllClasses();
        result.disableAllClassDependencies();
        if (baseCapabilities.handles(Capabilities.Capability.NOMINAL_CLASS)) {
            result.enable(Capabilities.Capability.NOMINAL_CLASS);
        }
        if (baseCapabilities.handles(Capabilities.Capability.BINARY_CLASS)) {
            result.enable(Capabilities.Capability.BINARY_CLASS);
        }
        return result;
    }

    /**
     * Builds the boosted ensemble by running all boosting iterations.
     *
     * @param data the training data
     * @throws Exception if the classifier cannot be built
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {
        initializeClassifier(data);
        while (next()) {
            // each successful call trains and adds one more base classifier
        }
        done();
    }

    /**
     * Prepares iterative training: validates the data, drops instances with a missing class,
     * builds the ZeroR fallback and creates the weighted working copy of the data.
     *
     * @param data the training data (not modified)
     * @throws Exception if the data cannot be handled
     */
    @Override
    public void initializeClassifier(Instances data) throws Exception {
        super.buildClassifier(data);
        getCapabilities().testWithFail(data);
        data = new Instances(data);
        data.deleteWithMissingClass();
        m_ZeroR = new ZeroR();
        m_ZeroR.buildClassifier(data);
        m_NumClasses = data.numClasses();
        m_Betas = new double[m_Classifiers.length];
        m_NumIterationsPerformed = 0;
        m_TrainingData = new Instances(data);
        m_RandomInstance = new Random(m_Seed);
        if (usesResampling()) {
            // Resampling treats the weights as a probability distribution, so normalise to sum 1.
            double sumProbs = m_TrainingData.sumOfWeights();
            for (int i = 0; i < m_TrainingData.numInstances(); i++) {
                m_TrainingData.instance(i).setWeight(m_TrainingData.instance(i).weight() / sumProbs);
            }
        }
    }

    /**
     * Performs one boosting iteration: trains the next base classifier, computes its voting weight
     * and re-weights the working data.
     *
     * @return true if another iteration may be performed, false if boosting has stopped
     * @throws Exception if a base classifier cannot be built or evaluated
     */
    @Override
    public boolean next() throws Exception {
        if (m_NumIterationsPerformed >= m_NumIterations) {
            return false;
        }
        // Only the class attribute is present: nothing to learn, the ZeroR fallback is used.
        if (m_TrainingData.numAttributes() == 1) {
            return false;
        }
        if (m_Debug) {
            System.err.println("Training classifier " + (m_NumIterationsPerformed + 1));
        }
        Instances trainData = m_WeightThreshold < 100
            ? selectWeightQuantile(m_TrainingData, m_WeightThreshold / 100.0)
            : new Instances(m_TrainingData);
        double epsilon = usesResampling() ? buildByResampling(trainData) : buildByReweighting(trainData);

        // AdaBoost.M1 needs 0 < error < 0.5. Otherwise stop; the very first model is still kept.
        if (Utils.grOrEq(epsilon, 0.5) || Utils.eq(epsilon, 0.0)) {
            if (m_NumIterationsPerformed == 0) {
                m_NumIterationsPerformed = 1;
            }
            return false;
        }
        double reweight = (1.0 - epsilon) / epsilon;
        m_Betas[m_NumIterationsPerformed] = Math.log(reweight);
        if (m_Debug) {
            System.err.println("\terror rate = " + epsilon + "  beta = " + m_Betas[m_NumIterationsPerformed]);
        }
        setWeights(m_TrainingData, reweight);
        m_NumIterationsPerformed++;
        return true;
    }

    /** Returns true when base classifiers are trained on weighted resamples instead of weights. */
    private boolean usesResampling() {
        return m_UseResampling || !(m_Classifier instanceof WeightedInstancesHandler);
    }

    /**
     * Trains the current base classifier on resamples of trainData drawn according to its weights,
     * retrying (up to MAX_NUM_RESAMPLING_ITERATIONS times) while the training error is zero.
     *
     * @return the weighted error rate of the trained model on the full working data
     */
    private double buildByResampling(Instances trainData) throws Exception {
        Classifier current = m_Classifiers[m_NumIterationsPerformed];
        double[] weights = new double[trainData.numInstances()];
        for (int i = 0; i < weights.length; i++) {
            weights[i] = trainData.instance(i).weight();
        }
        double epsilon;
        int attempts = 0;
        do {
            Instances sample = trainData.resampleWithWeights(m_RandomInstance, weights);
            current.buildClassifier(sample);
            epsilon = errorOnTrainingData(current);
            attempts++;
        } while (Utils.eq(epsilon, 0.0) && attempts < MAX_NUM_RESAMPLING_ITERATIONS);
        return epsilon;
    }

    /**
     * Trains the current base classifier directly on the weighted trainData, reseeding it first
     * if it is randomizable.
     *
     * @return the weighted error rate of the trained model on the full working data
     */
    private double buildByReweighting(Instances trainData) throws Exception {
        Classifier current = m_Classifiers[m_NumIterationsPerformed];
        if (current instanceof Randomizable) {
            ((Randomizable) current).setSeed(m_RandomInstance.nextInt());
        }
        current.buildClassifier(trainData);
        return errorOnTrainingData(current);
    }

    /** Returns the error rate of the given model on the weighted working data. */
    private double errorOnTrainingData(Classifier model) throws Exception {
        Evaluation evaluation = new Evaluation(m_TrainingData);
        evaluation.evaluateModel(model, m_TrainingData);
        return evaluation.errorRate();
    }

    /**
     * Ends iterative training: drops the working data, and drops the ZeroR fallback once at
     * least one base classifier is part of the model.
     */
    @Override
    public void done() {
        m_TrainingData = null;
        if (m_NumIterationsPerformed > 0) {
            m_ZeroR = null;
        }
    }

    /**
     * Multiplies the weight of every instance misclassified by the current base classifier by
     * {@code reweight}, then rescales all weights so their total is unchanged.
     *
     * @param training the weighted data to update in place
     * @param reweight factor applied to misclassified instances
     * @throws Exception if an instance cannot be classified
     */
    protected void setWeights(Instances training, double reweight) throws Exception {
        Classifier current = m_Classifiers[m_NumIterationsPerformed];
        double oldSumOfWeights = training.sumOfWeights();
        for (int i = 0; i < training.numInstances(); i++) {
            Instance instance = training.instance(i);
            if (!Utils.eq(current.classifyInstance(instance), instance.classValue())) {
                instance.setWeight(instance.weight() * reweight);
            }
        }
        double newSumOfWeights = training.sumOfWeights();
        for (int i = 0; i < training.numInstances(); i++) {
            Instance instance = training.instance(i);
            instance.setWeight(instance.weight() * oldSumOfWeights / newSumOfWeights);
        }
    }

    /**
     * Computes class membership probabilities by a weighted vote of the base classifiers.
     *
     * @param instance the instance to classify
     * @return the predicted class distribution
     * @throws Exception if no model has been built or the instance cannot be classified
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        if (m_NumIterationsPerformed == 0) {
            // No base classifier in the model: use the ZeroR fallback if there is one.
            if (m_ZeroR == null) {
                throw new Exception("No model built");
            }
            return m_ZeroR.distributionForInstance(instance);
        }
        if (m_NumIterationsPerformed == 1) {
            return m_Classifiers[0].distributionForInstance(instance);
        }
        double[] sums = new double[instance.numClasses()];
        for (int i = 0; i < m_NumIterationsPerformed; i++) {
            int predicted = (int) m_Classifiers[i].classifyInstance(instance);
            sums[predicted] += m_Betas[i];
        }
        return Utils.logs2probs(sums);
    }

    /**
     * Generates Java source code equivalent to the trained ensemble.
     *
     * @param className name of the generated class
     * @return the generated source
     * @throws Exception if no model has been built or the base learner is not Sourcable
     */
    @Override
    public String toSource(String className) throws Exception {
        if (m_NumIterationsPerformed == 0) {
            throw new Exception("No model built yet");
        }
        if (!(m_Classifiers[0] instanceof Sourcable)) {
            throw new Exception("Base learner " + m_Classifier.getClass().getName() + " is not Sourcable");
        }
        StringBuilder text = new StringBuilder("class ");
        text.append(className).append(" {\n\n");
        text.append("  public static double classify(Object[] i) {\n");
        if (m_NumIterationsPerformed == 1) {
            text.append("    return " + className + "_0.classify(i);\n");
        } else {
            text.append("    double [] sums = new double [" + m_NumClasses + "];\n");
            for (int i = 0; i < m_NumIterationsPerformed; i++) {
                text.append("    sums[(int) " + className + '_' + i + ".classify(i)] += " + m_Betas[i] + ";\n");
            }
            text.append("    double maxV = sums[0];\n    int maxI = 0;\n    for (int j = 1; j < " + m_NumClasses + "; j++) {\n      if (sums[j] > maxV) { maxV = sums[j]; maxI = j; }\n    }\n    return (double) maxI;\n");
        }
        text.append("  }\n}\n");
        // Only the first m_NumIterationsPerformed base classifiers were trained (boosting can stop
        // early), so only those are exported; the remaining array entries are unbuilt.
        for (int i = 0; i < m_NumIterationsPerformed; i++) {
            text.append(((Sourcable) m_Classifiers[i]).toSource(className + '_' + i));
        }
        return text.toString();
    }

    /**
     * Returns a textual description of the model.
     *
     * @return the base classifiers with their weights, or the ZeroR fallback description
     */
    @Override
    public String toString() {
        if (m_NumIterationsPerformed == 0) {
            StringBuilder buf = new StringBuilder();
            if (m_ZeroR == null) {
                buf.append("AdaBoostM1: No model built yet.\n");
            } else {
                buf.append(getClass().getName().replaceAll(".*\\.", "") + "\n");
                buf.append(getClass().getName().replaceAll(".*\\.", "").replaceAll(".", "=") + "\n\n");
                buf.append("Warning: No model could be built, hence ZeroR model is used:\n\n");
                buf.append(m_ZeroR.toString());
            }
            return buf.toString();
        }
        StringBuilder text = new StringBuilder();
        if (m_NumIterationsPerformed == 1) {
            text.append("AdaBoostM1: No boosting possible, one classifier used!\n");
            text.append(m_Classifiers[0].toString() + "\n");
        } else {
            text.append("AdaBoostM1: Base classifiers and their weights: \n\n");
            for (int i = 0; i < m_NumIterationsPerformed; i++) {
                text.append(m_Classifiers[i].toString() + "\n\n");
                text.append("Weight: " + Utils.roundDouble(m_Betas[i], 2) + "\n\n");
            }
            text.append("Number of performed Iterations: " + m_NumIterationsPerformed + "\n");
        }
        return text.toString();
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 10969 $");
    }

    /**
     * Runs the classifier from the command line.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        runClassifier(new AdaBoostM1(), argv);
    }
}

