/*
 * Decompiled with CFR 0.152.
 */
package weka.gui.explorer;

import java.util.ArrayList;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.classifiers.IntervalEstimator;
import weka.classifiers.evaluation.NumericPrediction;
import weka.classifiers.evaluation.Prediction;
import weka.classifiers.misc.InputMappedClassifier;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
import weka.gui.explorer.AbstractPlotInstances;
import weka.gui.explorer.ExplorerDefaults;
import weka.gui.visualize.PlotData2D;

/**
 * Collects the data needed to plot classifier errors: a copy of the evaluated
 * instances with the prediction inserted next to the class attribute, plus one
 * shape code and one size entry per plotted point.
 * <p>
 * For a nominal class two columns are inserted directly before the class
 * attribute: the prediction margin, then the predicted class. For a numeric
 * class only the predicted value is inserted.
 */
public class ClassifierErrorsPlotInstances
extends AbstractPlotInstances {
    private static final long serialVersionUID = -3941976365792013279L;

    /** Shape code recorded when the actual class value or the prediction is missing. */
    private static final int SHAPE_MISSING = 2000;
    /** Shape code recorded when a nominal prediction differs from the actual class value. */
    private static final int SHAPE_ERROR = 1000;
    /** Shape code recorded for correct nominal predictions and non-missing numeric ones. */
    private static final int SHAPE_DEFAULT = -1;

    /** Smallest point size produced by {@link #scaleNumericPredictions()}. */
    protected int m_MinimumPlotSizeNumeric;
    /** Largest point size produced by {@link #scaleNumericPredictions()}. */
    protected int m_MaximumPlotSizeNumeric;
    /** Whether plot data is collected at all. */
    protected boolean m_SaveForVisualization;
    /** Whether nominal point sizes follow the prediction margin instead of correctness. */
    protected boolean m_pointSizeProportionalToMargin;
    /** One shape code per plotted point. */
    protected ArrayList<Integer> m_PlotShapes;
    /** One size entry per plotted point (Integer, Double, or null before scaling). */
    protected ArrayList<Object> m_PlotSizes;
    protected Classifier m_Classifier;
    protected int m_ClassIndex;
    protected Evaluation m_Evaluation;

    @Override
    protected void initialize() {
        super.initialize();
        this.m_PlotShapes = new ArrayList<Integer>();
        this.m_PlotSizes = new ArrayList<Object>();
        this.m_Classifier = null;
        this.m_ClassIndex = -1;
        this.m_Evaluation = null;
        this.m_SaveForVisualization = true;
        this.m_MinimumPlotSizeNumeric = ExplorerDefaults.getClassifierErrorsMinimumPlotSizeNumeric();
        this.m_MaximumPlotSizeNumeric = ExplorerDefaults.getClassifierErrorsMaximumPlotSizeNumeric();
    }

    /** @return the per-point shape codes collected so far */
    public ArrayList<Integer> getPlotShapes() {
        return this.m_PlotShapes;
    }

    /** @return the per-point size entries collected so far */
    public ArrayList<Object> getPlotSizes() {
        return this.m_PlotSizes;
    }

    /** @param plotShapes the per-point shape codes to use */
    public void setPlotShapes(ArrayList<Integer> plotShapes) {
        this.m_PlotShapes = plotShapes;
    }

    /** @param plotSizes the per-point size entries to use */
    public void setPlotSizes(ArrayList<Object> plotSizes) {
        this.m_PlotSizes = plotSizes;
    }

    /** @param value the classifier whose predictions are plotted */
    public void setClassifier(Classifier value) {
        this.m_Classifier = value;
    }

    /** @return the classifier whose predictions are plotted */
    public Classifier getClassifier() {
        return this.m_Classifier;
    }

    /** @param index the index of the class attribute in the source instances */
    public void setClassIndex(int index) {
        this.m_ClassIndex = index;
    }

    /** @return the index of the class attribute in the source instances */
    public int getClassIndex() {
        return this.m_ClassIndex;
    }

    /** @param value the evaluation whose predictions feed the prediction intervals */
    public void setEvaluation(Evaluation value) {
        this.m_Evaluation = value;
    }

    /** @return the evaluation set via {@link #setEvaluation(Evaluation)} */
    public Evaluation getEvaluation() {
        return this.m_Evaluation;
    }

    /** @param value whether plot data should be collected */
    public void setSaveForVisualization(boolean value) {
        this.m_SaveForVisualization = value;
    }

    /** @return whether plot data is collected */
    public boolean getSaveForVisualization() {
        return this.m_SaveForVisualization;
    }

    /** @param b whether nominal point sizes should follow the prediction margin */
    public void setPointSizeProportionalToMargin(boolean b) {
        this.m_pointSizeProportionalToMargin = b;
    }

    /** @return whether nominal point sizes follow the prediction margin */
    public boolean getPointSizeProportionalToMargin() {
        return this.m_pointSizeProportionalToMargin;
    }

    /**
     * Verifies that classifier, class index and evaluation have been set.
     *
     * @throws IllegalStateException if any of them is missing
     */
    @Override
    protected void check() {
        super.check();
        if (this.m_Classifier == null) {
            throw new IllegalStateException("No classifier set!");
        }
        if (this.m_ClassIndex == -1) {
            throw new IllegalStateException("No class index set!");
        }
        if (this.m_Evaluation == null) {
            throw new IllegalStateException("No evaluation set");
        }
    }

    /**
     * Creates the empty plot dataset. Its structure is the source structure with the
     * margin (nominal class only) and the predicted value inserted before the class attribute.
     */
    @Override
    protected void determineFormat() {
        if (!this.m_SaveForVisualization) {
            this.m_PlotInstances = null;
            return;
        }
        Attribute classAt = this.m_Instances.attribute(this.m_ClassIndex);
        boolean isNominal = classAt.isNominal();
        Attribute predictedClass;
        Attribute margin = null;
        if (isNominal) {
            ArrayList<String> attVals = new ArrayList<String>();
            for (int i = 0; i < classAt.numValues(); ++i) {
                attVals.add(classAt.value(i));
            }
            predictedClass = new Attribute("predicted " + classAt.name(), attVals);
            margin = new Attribute("prediction margin");
        } else {
            // Same "predicted <name>" spelling as the nominal branch.
            predictedClass = new Attribute("predicted " + classAt.name());
        }
        ArrayList<Attribute> hv = new ArrayList<Attribute>();
        for (int i = 0; i < this.m_Instances.numAttributes(); ++i) {
            if (i == this.m_Instances.classIndex()) {
                if (isNominal) {
                    hv.add(margin);
                }
                hv.add(predictedClass);
            }
            hv.add((Attribute) this.m_Instances.attribute(i).copy());
        }
        this.m_PlotInstances = new Instances(this.m_Instances.relationName() + "_predicted", hv, this.m_Instances.numInstances());
        // The class moved right by the number of inserted columns.
        this.m_PlotInstances.setClassIndex(this.m_ClassIndex + (isNominal ? 2 : 1));
    }

    /**
     * Processes a batch of pre-computed predictions. Each row is passed to
     * {@code eval.evaluationForSingleInstance} and then recorded for plotting.
     * Exceptions are printed and abort the remainder of the batch.
     *
     * @param batch       the instances that were predicted
     * @param predictions one row per instance in {@code batch}: the class distribution for a
     *                    nominal class, otherwise element 0 holds the numeric prediction
     * @param eval        the evaluation to update
     */
    public void process(Instances batch, double[][] predictions, Evaluation eval) {
        try {
            for (int j = 0; j < batch.numInstances(); ++j) {
                Instance toPredict = batch.instance(j);
                double[] preds = predictions[j];
                double pred;
                double probActual = 0.0;
                double probNext = 0.0;
                if (batch.classAttribute().isNominal()) {
                    pred = predictedIndex(preds);
                    int reference = referenceClassIndex(toPredict, preds);
                    probActual = Utils.sum(preds) == 0.0 ? Utils.missingValue() : preds[reference];
                    probNext = bestAlternativeProbability(toPredict, preds, reference);
                } else {
                    pred = preds[0];
                }
                eval.evaluationForSingleInstance(preds, toPredict, true);
                this.recordPlotPoint(toPredict, pred, probActual, probNext);
            }
        }
        catch (Exception ex) {
            // Best-effort: a failure here must not take down the caller; keep the original reporting.
            ex.printStackTrace();
        }
    }

    /**
     * Predicts a single instance with {@code classifier}, updates {@code eval} and
     * records the result for plotting. Exceptions are printed, not rethrown.
     *
     * @param toPredict  the instance to predict
     * @param classifier the trained classifier
     * @param eval       the evaluation to update
     */
    public void process(Instance toPredict, Classifier classifier, Evaluation eval) {
        try {
            double pred;
            double probActual = 0.0;
            double probNext = 0.0;
            Instance classMissing = (Instance) toPredict.copy();
            classMissing.setDataset(toPredict.dataset());
            if (classifier instanceof InputMappedClassifier && toPredict.classAttribute().isNominal()) {
                InputMappedClassifier mapper = (InputMappedClassifier) classifier;
                toPredict = mapper.constructMappedInstance((Instance) toPredict.copy());
                classMissing.setMissing(mapper.getMappedClassIndex());
            } else {
                classMissing.setClassMissing();
            }
            if (toPredict.classAttribute().isNominal()) {
                double[] preds = classifier.distributionForInstance(classMissing);
                pred = predictedIndex(preds);
                int reference = referenceClassIndex(toPredict, preds);
                probActual = Utils.sum(preds) == 0.0 ? Utils.missingValue() : preds[reference];
                probNext = bestAlternativeProbability(toPredict, preds, reference);
                eval.evaluationForSingleInstance(preds, toPredict, true);
            } else {
                pred = eval.evaluateModelOnceAndRecordPrediction(classifier, toPredict);
            }
            this.recordPlotPoint(toPredict, pred, probActual, probNext);
        }
        catch (Exception ex) {
            // Best-effort: a failure here must not take down the caller; keep the original reporting.
            ex.printStackTrace();
        }
    }

    /**
     * Returns the index of the most probable class, or missing when the distribution sums to 0.
     */
    private static double predictedIndex(double[] preds) {
        return Utils.sum(preds) == 0.0 ? Utils.missingValue() : (double) Utils.maxIndex(preds);
    }

    /**
     * Returns the class whose probability forms the positive side of the margin: the
     * recorded class value, or the most probable class when the class value is missing.
     */
    private static int referenceClassIndex(Instance toPredict, double[] preds) {
        return toPredict.isMissing(toPredict.classIndex())
            ? Utils.maxIndex(preds)
            : (int) toPredict.classValue();
    }

    /** Returns the largest probability among the class values other than {@code excluded} (at least 0). */
    private static double bestAlternativeProbability(Instance toPredict, double[] preds, int excluded) {
        double best = 0.0;
        for (int i = 0; i < toPredict.classAttribute().numValues(); ++i) {
            if (i != excluded && preds[i] > best) {
                best = preds[i];
            }
        }
        return best;
    }

    /**
     * Appends one plot row plus its shape code and size entry. Does nothing when
     * visualization is disabled or the plot format has not been created.
     */
    private void recordPlotPoint(Instance toPredict, double pred, double probActual, double probNext) {
        if (!this.m_SaveForVisualization || this.m_PlotInstances == null) {
            return;
        }
        boolean isNominal = toPredict.classAttribute().isNominal();
        int classIndex = toPredict.classIndex();
        int inserted = isNominal ? 2 : 1;
        double[] values = new double[this.m_PlotInstances.numAttributes()];
        for (int i = 0; i < values.length; ++i) {
            if (i < classIndex) {
                values[i] = toPredict.value(i);
            } else if (i == classIndex) {
                if (isNominal) {
                    values[i] = probActual - probNext;
                    values[i + 1] = pred;
                } else {
                    values[i] = pred;
                }
                values[i + inserted] = toPredict.value(i);
                i += inserted;
            } else {
                // Attributes after the class are shifted right by the inserted columns.
                values[i] = toPredict.value(i - inserted);
            }
        }
        this.m_PlotInstances.add(new DenseInstance(1.0, values));

        boolean classMissing = toPredict.isMissing(classIndex);
        if (isNominal) {
            if (classMissing || Utils.isMissingValue(pred)) {
                this.m_PlotShapes.add(Integer.valueOf(SHAPE_MISSING));
            } else if (pred != toPredict.classValue()) {
                this.m_PlotShapes.add(Integer.valueOf(SHAPE_ERROR));
            } else {
                this.m_PlotShapes.add(Integer.valueOf(SHAPE_DEFAULT));
            }
            if (this.m_pointSizeProportionalToMargin) {
                // Scaled later by scaleNumericPredictions().
                this.m_PlotSizes.add(Double.valueOf(probActual - probNext));
            } else {
                int sizeAdj = pred != toPredict.classValue() ? 1 : 0;
                this.m_PlotSizes.add(Integer.valueOf(2 + sizeAdj));
            }
        } else {
            // Raw signed error (null if unknown); scaled later by scaleNumericPredictions().
            Double errd = null;
            if (!classMissing && !Utils.isMissingValue(pred)) {
                errd = Double.valueOf(pred - toPredict.classValue());
                this.m_PlotShapes.add(Integer.valueOf(SHAPE_DEFAULT));
            } else {
                this.m_PlotShapes.add(Integer.valueOf(SHAPE_MISSING));
            }
            this.m_PlotSizes.add(errd);
        }
    }

    /**
     * Replaces the Double/null size entries with Integer point sizes.
     * <p>
     * Each size lies in [{@code m_MinimumPlotSizeNumeric}, {@code m_MaximumPlotSizeNumeric}]
     * and is proportional to the absolute value of the entry. For a nominal class the range
     * is fixed to [0, 1]; otherwise it is the observed min/max. Null entries, and every
     * entry when the range is empty, get the minimum size.
     */
    protected void scaleNumericPredictions() {
        double maxErr;
        double minErr;
        if (this.m_Instances.classAttribute().isNominal()) {
            maxErr = 1.0;
            minErr = 0.0;
        } else {
            maxErr = Double.NEGATIVE_INFINITY;
            minErr = Double.POSITIVE_INFINITY;
            for (Object entry : this.m_PlotSizes) {
                Double errd = (Double) entry;
                if (errd == null) {
                    continue;
                }
                double err = Math.abs(errd);
                if (err < minErr) {
                    minErr = err;
                }
                if (err > maxErr) {
                    maxErr = err;
                }
            }
        }
        double range = maxErr - minErr;
        int sizeSpan = this.m_MaximumPlotSizeNumeric - this.m_MinimumPlotSizeNumeric + 1;
        for (int i = 0; i < this.m_PlotSizes.size(); ++i) {
            Double errd = (Double) this.m_PlotSizes.get(i);
            int size = this.m_MinimumPlotSizeNumeric;
            if (errd != null && range > 0.0) {
                double temp = (Math.abs(errd) - minErr) / range * sizeSpan;
                // The largest error yields temp == sizeSpan, one past the maximum; clamp it.
                size = Math.min((int) temp + this.m_MinimumPlotSizeNumeric, this.m_MaximumPlotSizeNumeric);
            }
            this.m_PlotSizes.set(i, Integer.valueOf(size));
        }
    }

    /**
     * Appends lower boundary, upper boundary and width columns for every prediction
     * interval found in {@code m_Evaluation}'s predictions. Predictions are matched to
     * plot rows by position. Rows with fewer intervals, or with no matching prediction,
     * get missing values.
     */
    protected void addPredictionIntervals() {
        ArrayList<Prediction> preds = this.m_Evaluation.predictions();
        if (preds == null || this.m_PlotInstances == null) {
            // NOTE(review): assumes predictions() may return null when none were kept; nothing to attach then.
            return;
        }
        int maxNum = 0;
        for (Prediction p : preds) {
            maxNum = Math.max(maxNum, intervalsOf(p).length);
        }
        int base = this.m_PlotInstances.numAttributes();
        ArrayList<Attribute> atts = new ArrayList<Attribute>();
        for (int i = 0; i < base; ++i) {
            atts.add(this.m_PlotInstances.attribute(i));
        }
        for (int i = 0; i < maxNum; ++i) {
            atts.add(new Attribute("predictionInterval_" + (i + 1) + "-lowerBoundary"));
            atts.add(new Attribute("predictionInterval_" + (i + 1) + "-upperBoundary"));
            atts.add(new Attribute("predictionInterval_" + (i + 1) + "-width"));
        }
        Instances data = new Instances(this.m_PlotInstances.relationName(), atts, this.m_PlotInstances.numInstances());
        data.setClassIndex(this.m_PlotInstances.classIndex());
        for (int i = 0; i < this.m_PlotInstances.numInstances(); ++i) {
            Instance inst = this.m_PlotInstances.instance(i);
            double[] values = new double[data.numAttributes()];
            System.arraycopy(inst.toDoubleArray(), 0, values, 0, inst.numAttributes());
            double[][] predInt = i < preds.size() ? intervalsOf(preds.get(i)) : new double[0][];
            for (int n = 0; n < maxNum; ++n) {
                int col = base + n * 3;
                if (n < predInt.length) {
                    values[col] = predInt[n][0];
                    values[col + 1] = predInt[n][1];
                    values[col + 2] = predInt[n][1] - predInt[n][0];
                } else {
                    values[col] = Utils.missingValue();
                    values[col + 1] = Utils.missingValue();
                    values[col + 2] = Utils.missingValue();
                }
            }
            data.add(new DenseInstance(inst.weight(), values));
        }
        this.m_PlotInstances = data;
    }

    /** Returns the prediction intervals of a numeric prediction, or an empty array if there are none. */
    private static double[][] intervalsOf(Prediction p) {
        double[][] intervals = ((NumericPrediction) p).predictionIntervals();
        return intervals == null ? new double[0][] : intervals;
    }

    /**
     * Scales the point sizes (numeric class, or margin-proportional sizes) and adds
     * prediction-interval columns for a numeric class predicted by an {@link IntervalEstimator}.
     */
    @Override
    protected void finishUp() {
        super.finishUp();
        if (!this.m_SaveForVisualization) {
            return;
        }
        if (this.m_Instances.classAttribute().isNumeric() || this.m_pointSizeProportionalToMargin) {
            this.scaleNumericPredictions();
        }
        if (this.m_Instances.attribute(this.m_ClassIndex).isNumeric() && this.m_Classifier instanceof IntervalEstimator) {
            this.addPredictionIntervals();
        }
    }

    /**
     * Wraps the collected instances, shapes and sizes into plot data.
     *
     * @param name prefix of the plot name; the relation name is appended in parentheses
     * @return the plot data, or {@code null} when visualization is disabled
     */
    @Override
    protected PlotData2D createPlotData(String name) throws Exception {
        if (!this.m_SaveForVisualization) {
            return null;
        }
        PlotData2D result = new PlotData2D(this.m_PlotInstances);
        result.setShapeSize(this.m_PlotSizes);
        result.setShapeType(this.m_PlotShapes);
        result.setPlotName(name + " (" + this.m_Instances.relationName() + ")");
        return result;
    }

    /** Releases the classifier, evaluation and collected shapes/sizes. */
    @Override
    public void cleanUp() {
        super.cleanUp();
        this.m_Classifier = null;
        this.m_PlotShapes = null;
        this.m_PlotSizes = null;
        this.m_Evaluation = null;
    }
}

