/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.trees;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.meta.Bagging;
import weka.classifiers.trees.RandomTree;
import weka.core.Capabilities;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.Utils;
import weka.core.WekaException;
import weka.gui.ProgrammaticProperty;

/**
 * Forest of {@link RandomTree}s built via {@link Bagging} (Breiman, 2001, "Random Forests").
 *
 * <p>The base learner is always a {@link RandomTree}: {@link #setClassifier(Classifier)} rejects
 * anything else, and the tree-specific properties exposed here (number of features, max depth,
 * tie breaking, ...) are forwarded to that template tree. Copies of instances in each bag are
 * always represented using weights.
 */
public class RandomForest
extends Bagging {
    /** Serialization identifier. */
    static final long serialVersionUID = 1116839470751428698L;

    /**
     * Whether the trees collect impurity-decrease statistics, which are needed by
     * {@link #computeAverageImpurityDecreasePerAttribute(double[])} and printed by
     * {@link #toString()}.
     */
    protected boolean m_computeAttributeImportance;

    /**
     * Returns the default ensemble size.
     *
     * @return 100
     */
    @Override
    protected int defaultNumberOfIterations() {
        return 100;
    }

    /**
     * Creates a forest whose template base learner is a {@link RandomTree} with capability
     * checking switched off, with copies represented via weights and the default number of
     * iterations.
     */
    public RandomForest() {
        RandomTree rTree = new RandomTree();
        rTree.setDoNotCheckCapabilities(true);
        // super.* deliberately bypasses the validating overrides in this class.
        super.setClassifier(rTree);
        super.setRepresentCopiesUsingWeights(true);
        this.setNumIterations(this.defaultNumberOfIterations());
    }

    /**
     * Returns the template base learner. The constructor installs a {@link RandomTree} and
     * {@link #setClassifier(Classifier)} only accepts RandomTrees, so the cast is expected to hold.
     */
    private RandomTree baseTree() {
        return (RandomTree) this.getClassifier();
    }

    /**
     * Returns the capabilities of a freshly constructed {@link RandomTree}. These are not the
     * configured template's capabilities (the template has capability checking disabled).
     *
     * @return the capabilities of a default RandomTree
     */
    @Override
    public Capabilities getCapabilities() {
        return new RandomTree().getCapabilities();
    }

    /**
     * Returns the class name of the default base learner.
     *
     * @return {@code "weka.classifiers.trees.RandomTree"}
     */
    @Override
    protected String defaultClassifierString() {
        return "weka.classifiers.trees.RandomTree";
    }

    /**
     * Returns the default options of the base learner.
     *
     * @return a new array containing only {@code -do-not-check-capabilities}
     */
    @Override
    protected String[] defaultClassifierOptions() {
        return new String[]{"-do-not-check-capabilities"};
    }

    /**
     * Returns a description of this classifier for display in the GUI.
     *
     * @return the description, including the bibliographic reference
     */
    @Override
    public String globalInfo() {
        return "Class for constructing a forest of random trees.\n\nFor more information see: \n\n" + this.getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference for the method (Breiman, Machine Learning 45(1), 2001).
     *
     * @return the technical information
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Leo Breiman");
        result.setValue(TechnicalInformation.Field.YEAR, "2001");
        result.setValue(TechnicalInformation.Field.TITLE, "Random Forests");
        result.setValue(TechnicalInformation.Field.JOURNAL, "Machine Learning");
        result.setValue(TechnicalInformation.Field.VOLUME, "45");
        result.setValue(TechnicalInformation.Field.NUMBER, "1");
        result.setValue(TechnicalInformation.Field.PAGES, "5-32");
        return result;
    }

    /**
     * Sets the template base learner.
     *
     * @param newClassifier the new template; must be a {@link RandomTree}
     * @throws IllegalArgumentException if {@code newClassifier} is not a RandomTree (including null)
     */
    @Override
    @ProgrammaticProperty
    public void setClassifier(Classifier newClassifier) {
        if (!(newClassifier instanceof RandomTree)) {
            throw new IllegalArgumentException("RandomForest: Argument of setClassifier() must be a RandomTree.");
        }
        super.setClassifier(newClassifier);
    }

    /**
     * Only {@code true} is accepted: the forest always represents copies using weights.
     *
     * @param representUsingWeights must be {@code true}
     * @throws IllegalArgumentException if {@code representUsingWeights} is {@code false}
     */
    @Override
    @ProgrammaticProperty
    public void setRepresentCopiesUsingWeights(boolean representUsingWeights) {
        if (!representUsingWeights) {
            throw new IllegalArgumentException("RandomForest: Argument of setRepresentCopiesUsingWeights() must be true.");
        }
        super.setRepresentCopiesUsingWeights(representUsingWeights);
    }

    /** @return the template tree's tip text for its K value */
    public String numFeaturesTipText() {
        return this.baseTree().KValueTipText();
    }

    /** @return the template tree's K value (number of features considered per split) */
    public int getNumFeatures() {
        return this.baseTree().getKValue();
    }

    /** @param newNumFeatures the K value to set on the template tree */
    public void setNumFeatures(int newNumFeatures) {
        this.baseTree().setKValue(newNumFeatures);
    }

    /** @return the tip text for the attribute-importance property */
    public String computeAttributeImportanceTipText() {
        return "Compute attribute importance via mean impurity decrease";
    }

    /**
     * Enables or disables collection of impurity-decrease statistics, forwarding the flag to the
     * template tree. Takes effect for trees built afterwards.
     *
     * @param computeAttributeImportance whether to compute attribute importance
     */
    public void setComputeAttributeImportance(boolean computeAttributeImportance) {
        this.m_computeAttributeImportance = computeAttributeImportance;
        this.baseTree().setComputeImpurityDecreases(computeAttributeImportance);
    }

    /** @return whether attribute importance is computed */
    public boolean getComputeAttributeImportance() {
        return this.m_computeAttributeImportance;
    }

    /** @return the template tree's tip text for maximum depth */
    public String maxDepthTipText() {
        return this.baseTree().maxDepthTipText();
    }

    /** @return the template tree's maximum depth */
    public int getMaxDepth() {
        return this.baseTree().getMaxDepth();
    }

    /** @param value the maximum depth to set on the template tree */
    public void setMaxDepth(int value) {
        this.baseTree().setMaxDepth(value);
    }

    /** @return the template tree's tip text for random tie breaking */
    public String breakTiesRandomlyTipText() {
        return this.baseTree().breakTiesRandomlyTipText();
    }

    /** @return whether the template tree breaks ties randomly */
    public boolean getBreakTiesRandomly() {
        return this.baseTree().getBreakTiesRandomly();
    }

    /** @param newBreakTiesRandomly whether the template tree should break ties randomly */
    public void setBreakTiesRandomly(boolean newBreakTiesRandomly) {
        this.baseTree().setBreakTiesRandomly(newBreakTiesRandomly);
    }

    /** Sets the debug flag on the forest and on the template tree. */
    @Override
    public void setDebug(boolean debug) {
        super.setDebug(debug);
        this.baseTree().setDebug(debug);
    }

    /** Sets the number of decimal places on the forest and on the template tree. */
    @Override
    public void setNumDecimalPlaces(int num) {
        super.setNumDecimalPlaces(num);
        this.baseTree().setNumDecimalPlaces(num);
    }

    /** Sets the batch size on the forest and on the template tree. */
    @Override
    public void setBatchSize(String size) {
        super.setBatchSize(size);
        this.baseTree().setBatchSize(size);
    }

    /** Sets the random seed on the forest and on the template tree. */
    @Override
    public void setSeed(int s) {
        super.setSeed(s);
        this.baseTree().setSeed(s);
    }

    /**
     * Returns a textual description of the forest. If attribute importance is enabled, also
     * returns the importance table, or a note explaining why it could not be computed.
     *
     * @return the model description, or a placeholder if no model has been built
     */
    @Override
    public String toString() {
        if (this.m_Classifiers == null) {
            return "RandomForest: No model built yet.";
        }
        StringBuilder buffer = new StringBuilder("RandomForest\n\n");
        buffer.append(super.toString());
        if (this.getComputeAttributeImportance()) {
            this.appendAttributeImportance(buffer);
        }
        return buffer.toString();
    }

    /**
     * Appends the attribute-importance table to {@code buffer}, most important attribute first,
     * skipping the class attribute.
     */
    private void appendAttributeImportance(StringBuilder buffer) {
        // m_data is inherited; assumed to hold at least the training header after building.
        double[] nodeCounts = new double[this.m_data.numAttributes()];
        double[] impurityScores;
        try {
            impurityScores = this.computeAverageImpurityDecreasePerAttribute(nodeCounts);
        }
        catch (WekaException e) {
            // Importance is an optional extra: report why it is missing instead of hiding it,
            // but never let toString() fail.
            buffer.append("\n\nAttribute importance could not be computed: ").append(e.getMessage()).append("\n");
            return;
        }
        int[] ascending = Utils.sort(impurityScores);
        int classIndex = this.m_data.classIndex();
        buffer.append("\n\nAttribute importance based on average impurity decrease (and number of nodes using that attribute)\n\n");
        for (int rank = ascending.length - 1; rank >= 0; --rank) {
            int index = ascending[rank];
            if (index == classIndex) {
                continue;
            }
            buffer.append(Utils.doubleToString(impurityScores[index], 10, this.getNumDecimalPlaces()))
                .append(" (")
                .append(Utils.doubleToString(nodeCounts[index], 6, 0))
                .append(")  ")
                .append(this.m_data.attribute(index).name())
                .append("\n");
        }
    }

    /**
     * Computes, for every attribute, the impurity decrease summed over all trees divided by the
     * number of nodes (over all trees) that use that attribute. Attributes used by no node keep a
     * score of 0.
     *
     * @param nodeCounts optional output array. Per-attribute node counts are ADDED to its existing
     *     entries, so pass a zero-filled array of length at least {@code numAttributes}. If null,
     *     the counts are computed internally and discarded.
     * @return a new array of average impurity decrease per attribute index
     * @throws WekaException if the forest has not been built, if importance computation is
     *     disabled, or if any tree has no impurity-decrease statistics
     * @throws IllegalArgumentException if {@code nodeCounts} is shorter than the number of
     *     attributes
     */
    public double[] computeAverageImpurityDecreasePerAttribute(double[] nodeCounts) throws WekaException {
        if (this.m_Classifiers == null) {
            throw new WekaException("Classifier has not been built yet!");
        }
        if (!this.getComputeAttributeImportance()) {
            throw new WekaException("Stats for attribute importance have not been collected!");
        }
        int numAttributes = this.m_data.numAttributes();
        if (nodeCounts == null) {
            nodeCounts = new double[numAttributes];
        } else if (nodeCounts.length < numAttributes) {
            throw new IllegalArgumentException("RandomForest: nodeCounts must have at least " + numAttributes + " entries but has " + nodeCounts.length + ".");
        }
        double[] impurityDecreases = new double[numAttributes];
        for (Classifier c : this.m_Classifiers) {
            // Indexing below assumes a [numAttributes][2] layout: [a][0] = summed impurity
            // decrease, [a][1] = node count.
            double[][] perTree = ((RandomTree) c).getImpurityDecreases();
            if (perTree == null) {
                // NOTE(review): presumably the tree was built without impurity tracking (e.g.
                // importance was enabled after building) -- TODO confirm against RandomTree.
                throw new WekaException("Stats for attribute importance have not been collected for every tree!");
            }
            for (int a = 0; a < numAttributes; ++a) {
                impurityDecreases[a] += perTree[a][0];
                nodeCounts[a] += perTree[a][1];
            }
        }
        for (int a = 0; a < numAttributes; ++a) {
            if (nodeCounts[a] > 0.0) {
                impurityDecreases[a] /= nodeCounts[a];
            }
        }
        return impurityDecreases;
    }

    /**
     * Lists the forest-level options, followed by the template tree's options.
     *
     * @return an enumeration of all available options
     */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> newVector = new Vector<Option>();
        newVector.addElement(new Option("\tSize of each bag, as a percentage of the\n\ttraining set size. (default 100)", "P", 1, "-P"));
        newVector.addElement(new Option("\tCalculate the out of bag error.", "O", 0, "-O"));
        newVector.addElement(new Option("\tWhether to store out of bag predictions in internal evaluation object.", "store-out-of-bag-predictions", 0, "-store-out-of-bag-predictions"));
        newVector.addElement(new Option("\tWhether to output complexity-based statistics when out-of-bag evaluation is performed.", "output-out-of-bag-complexity-statistics", 0, "-output-out-of-bag-complexity-statistics"));
        newVector.addElement(new Option("\tPrint the individual classifiers in the output", "print", 0, "-print"));
        newVector.addElement(new Option("\tCompute and output attribute importance (mean impurity decrease method)", "attribute-importance", 0, "-attribute-importance"));
        newVector.addElement(new Option("\tNumber of iterations.\n\t(current value " + this.getNumIterations() + ")", "I", 1, "-I <num>"));
        newVector.addElement(new Option("\tNumber of execution slots.\n\t(default 1 - i.e. no parallelism)\n\t(use 0 to auto-detect number of cores)", "num-slots", 1, "-num-slots <num>"));
        newVector.addAll(Collections.list(this.baseTree().listOptions()));
        return newVector.elements();
    }

    /**
     * Returns the current configuration as a command-line option array. The template tree's
     * options are appended without its {@code -do-not-check-capabilities} flag; that flag is
     * emitted only when the forest itself has capability checking disabled.
     *
     * @return the current options
     */
    @Override
    public String[] getOptions() {
        Vector<String> result = new Vector<String>();
        result.add("-P");
        result.add(String.valueOf(this.getBagSizePercent()));
        if (this.getCalcOutOfBag()) {
            result.add("-O");
        }
        if (this.getStoreOutOfBagPredictions()) {
            result.add("-store-out-of-bag-predictions");
        }
        if (this.getOutputOutOfBagComplexityStatistics()) {
            result.add("-output-out-of-bag-complexity-statistics");
        }
        if (this.getPrintClassifiers()) {
            result.add("-print");
        }
        if (this.getComputeAttributeImportance()) {
            result.add("-attribute-importance");
        }
        result.add("-I");
        result.add(String.valueOf(this.getNumIterations()));
        result.add("-num-slots");
        result.add(String.valueOf(this.getNumExecutionSlots()));
        if (this.getDoNotCheckCapabilities()) {
            result.add("-do-not-check-capabilities");
        }
        Vector<String> classifierOptions = new Vector<String>();
        Collections.addAll(classifierOptions, this.baseTree().getOptions());
        // The template tree always runs without capability checks; don't leak that internal flag.
        Option.deleteFlagString(classifierOptions, "-do-not-check-capabilities");
        result.addAll(classifierOptions);
        return result.toArray(new String[result.size()]);
    }

    /**
     * Parses the forest-level options, then builds a new template {@link RandomTree} from the
     * remaining options. The tree's seed, debug, decimal-places, batch-size and
     * capability-check settings are copied up to the forest.
     *
     * @param options the options to parse; consumed entries are removed by the Utils parsers
     * @throws Exception if an option value is malformed or unrecognised options remain
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        String bagSize = Utils.getOption('P', options);
        this.setBagSizePercent(bagSize.length() != 0 ? Integer.parseInt(bagSize) : 100);
        this.setCalcOutOfBag(Utils.getFlag('O', options));
        this.setStoreOutOfBagPredictions(Utils.getFlag("store-out-of-bag-predictions", options));
        this.setOutputOutOfBagComplexityStatistics(Utils.getFlag("output-out-of-bag-complexity-statistics", options));
        this.setPrintClassifiers(Utils.getFlag("print", options));
        this.setComputeAttributeImportance(Utils.getFlag("attribute-importance", options));
        String iterations = Utils.getOption('I', options);
        this.setNumIterations(iterations.length() != 0 ? Integer.parseInt(iterations) : this.defaultNumberOfIterations());
        String numSlots = Utils.getOption("num-slots", options);
        this.setNumExecutionSlots(numSlots.length() != 0 ? Integer.parseInt(numSlots) : 1);
        RandomTree classifier = (RandomTree)AbstractClassifier.forName(this.defaultClassifierString(), options);
        classifier.setComputeImpurityDecreases(this.m_computeAttributeImportance);
        // Settings shared between forest and tree are parsed by the tree and copied up. The
        // setters below also touch the OLD template, which is replaced just afterwards.
        this.setDoNotCheckCapabilities(classifier.getDoNotCheckCapabilities());
        this.setSeed(classifier.getSeed());
        this.setDebug(classifier.getDebug());
        this.setNumDecimalPlaces(classifier.getNumDecimalPlaces());
        this.setBatchSize(classifier.getBatchSize());
        classifier.setDoNotCheckCapabilities(true);
        this.setClassifier(classifier);
        Utils.checkForRemainingOptions(options);
    }

    /** @return the revision string */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 13295 $");
    }

    /**
     * Command-line entry point.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        RandomForest.runClassifier(new RandomForest(), argv);
    }
}

