/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.functions;

import java.util.Collections;
import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.evaluation.RegressionAnalysis;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/**
 * Simple (single-attribute) linear regression.
 *
 * <p>Training fits one weighted least-squares line per non-class attribute
 * and keeps the attribute whose line gives the lowest sum of squared errors.
 * Instances whose class value is missing are ignored during training. For
 * instances where the chosen attribute is missing, the model predicts the
 * weighted class mean of the training instances that also had that attribute
 * missing.
 *
 * <p>Valid option: {@code -additional-stats} (see {@link #listOptions()}).
 */
public class SimpleLinearRegression
extends AbstractClassifier
implements WeightedInstancesHandler {
    static final long serialVersionUID = 1679336022895414137L;

    /** The attribute used by the model, or {@code null} if a constant is predicted. */
    private Attribute m_attribute;

    /** Index of {@link #m_attribute} in the training data (0 when no attribute was chosen). */
    private int m_attributeIndex;

    /** Slope of the fitted line. */
    private double m_slope;

    /** Intercept of the fitted line; the weighted class mean when no attribute was chosen. */
    private double m_intercept;

    /** Prediction returned when the chosen attribute's value is missing. */
    private double m_classMeanForMissing;

    /** Whether additional regression statistics are computed and printed. */
    protected boolean m_outputAdditionalStats;

    /** Degrees of freedom used for the additional statistics. */
    private int m_df;

    /** Standard error of the slope. */
    private double m_seSlope = Double.NaN;

    /** Standard error of the intercept. */
    private double m_seIntercept = Double.NaN;

    /** t-statistic of the slope. */
    private double m_tstatSlope = Double.NaN;

    /** t-statistic of the intercept. */
    private double m_tstatIntercept = Double.NaN;

    /** R^2 value of the fitted model. */
    private double m_rsquared = Double.NaN;

    /** Adjusted R^2 value of the fitted model. */
    private double m_rsquaredAdj = Double.NaN;

    /** F-statistic of the fitted model. */
    private double m_fstat = Double.NaN;

    /** If true, no message is printed to stderr when no useful attribute is found. */
    private boolean m_suppressErrorMessage = false;

    /**
     * Returns a short description of this classifier.
     *
     * @return the description text
     */
    public String globalInfo() {
        return "Learns a simple linear regression model. Picks the attribute that results in the lowest squared error. Can only deal with numeric attributes.";
    }

    /**
     * Lists the available options: {@code -additional-stats} followed by the
     * options of the superclass.
     *
     * @return an enumeration of the available options
     */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> options = new Vector<Option>();
        options.addElement(new Option("\tOutput additional statistics.", "additional-stats", 0, "-additional-stats"));
        options.addAll(Collections.list(super.listOptions()));
        return options.elements();
    }

    /**
     * Parses the given options. Reads the {@code -additional-stats} flag,
     * hands the rest to the superclass and then checks for leftover options.
     *
     * @param options the options to parse
     * @throws Exception if option parsing fails
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        this.setOutputAdditionalStats(Utils.getFlag("additional-stats", options));
        super.setOptions(options);
        Utils.checkForRemainingOptions(options);
    }

    /**
     * Returns the current option settings.
     *
     * @return {@code -additional-stats} (if enabled) followed by the
     *         superclass options
     */
    @Override
    public String[] getOptions() {
        Vector<String> result = new Vector<String>();
        if (this.getOutputAdditionalStats()) {
            result.add("-additional-stats");
        }
        Collections.addAll(result, super.getOptions());
        return result.toArray(new String[result.size()]);
    }

    /**
     * Returns the tip text for the "output additional stats" property.
     *
     * @return the tip text
     */
    public String outputAdditionalStatsTipText() {
        return "Output additional statistics (such as std deviation of coefficients and t-statistics)";
    }

    /**
     * Sets whether additional statistics are computed at build time and
     * included in {@link #toString()}.
     *
     * @param additional true to output additional statistics
     */
    public void setOutputAdditionalStats(boolean additional) {
        this.m_outputAdditionalStats = additional;
    }

    /**
     * Returns whether additional statistics are output.
     *
     * @return true if additional statistics are output
     */
    public boolean getOutputAdditionalStats() {
        return this.m_outputAdditionalStats;
    }

    /**
     * Predicts the class value for the given instance.
     *
     * @param inst the instance to predict for
     * @return the intercept if no attribute was chosen; the stored
     *         "missing" prediction if the chosen attribute is missing in
     *         {@code inst}; otherwise {@code intercept + slope * value}
     * @throws Exception declared by the overridden method
     */
    @Override
    public double classifyInstance(Instance inst) throws Exception {
        if (this.m_attribute == null) {
            return this.m_intercept;
        }
        if (inst.isMissing(this.m_attributeIndex)) {
            return this.m_classMeanForMissing;
        }
        return this.m_intercept + this.m_slope * inst.value(this.m_attributeIndex);
    }

    /**
     * Returns the capabilities of this classifier: numeric and date
     * attributes and classes, with missing values allowed in both.
     *
     * @return the capabilities
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        // attributes
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);
        // class
        result.enable(Capabilities.Capability.NUMERIC_CLASS);
        result.enable(Capabilities.Capability.DATE_CLASS);
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        return result;
    }

    /**
     * Builds the model: fits a weighted least-squares line for every
     * non-class attribute and keeps the one with the lowest squared error.
     * If no attribute qualifies, the model predicts the weighted class mean.
     *
     * <p>Additional statistics left over from a previous build are cleared
     * before fitting, so they never describe an older model.
     *
     * @param insts the training data
     * @throws Exception if the capabilities test fails, or if additional
     *         statistics are requested and some instance weight is not 1
     */
    @Override
    public void buildClassifier(Instances insts) throws Exception {
        this.getCapabilities().testWithFail(insts);
        if (this.m_outputAdditionalStats) {
            this.requireUnitWeights(insts);
        }
        // Clear only after the checks above, so a rejected call leaves the
        // previously built model untouched (as before).
        this.resetAdditionalStats();

        final int numAttributes = insts.numAttributes();
        final int numInstances = insts.numInstances();
        final int classIndex = insts.classIndex();

        // Pass 1: weighted sums and counts per attribute. For instances where
        // an attribute is missing, accumulate the class sums separately.
        double[] sum = new double[numAttributes];
        double[] count = new double[numAttributes];
        double[] classSumForMissing = new double[numAttributes];
        double[] classSumSquaredForMissing = new double[numAttributes];
        double classCount = 0.0;
        double classSum = 0.0;
        for (int j = 0; j < numInstances; ++j) {
            Instance inst = insts.instance(j);
            if (inst.classIsMissing()) {
                continue;
            }
            double weight = inst.weight();
            double classValue = inst.classValue();
            for (int i = 0; i < numAttributes; ++i) {
                if (inst.isMissing(i)) {
                    classSumForMissing[i] += classValue * weight;
                    classSumSquaredForMissing[i] += classValue * classValue * weight;
                } else {
                    sum[i] += weight * inst.value(i);
                    count[i] += weight;
                }
            }
            classCount += weight;
            classSum += weight * classValue;
        }

        // Means per attribute; entries stay 0 when there is no weight to
        // average over.
        double[] mean = new double[numAttributes];
        double[] classMeanForMissing = new double[numAttributes];
        double[] classMeanForKnown = new double[numAttributes];
        for (int i = 0; i < numAttributes; ++i) {
            if (i == classIndex) {
                continue;
            }
            if (count[i] > 0.0) {
                mean[i] = sum[i] / count[i];
                classMeanForKnown[i] = (classSum - classSumForMissing[i]) / count[i];
            }
            double missingWeight = classCount - count[i];
            if (missingWeight > 0.0) {
                classMeanForMissing[i] = classSumForMissing[i] / missingWeight;
            }
        }

        // Pass 2: weighted cross-products and squared deviations, using only
        // instances where both the class and the attribute are present.
        double[] slopeNumerators = new double[numAttributes];
        double[] sumWeightedDiffsSquared = new double[numAttributes];
        double[] sumWeightedClassDiffsSquared = new double[numAttributes];
        for (int j = 0; j < numInstances; ++j) {
            Instance inst = insts.instance(j);
            if (inst.classIsMissing()) {
                continue;
            }
            double weight = inst.weight();
            double classValue = inst.classValue();
            for (int i = 0; i < numAttributes; ++i) {
                if (i == classIndex || inst.isMissing(i)) {
                    continue;
                }
                double yDiff = classValue - classMeanForKnown[i];
                double weightedYDiff = weight * yDiff;
                double diff = inst.value(i) - mean[i];
                double weightedDiff = weight * diff;
                slopeNumerators[i] += weightedYDiff * diff;
                sumWeightedDiffsSquared[i] += weightedDiff * diff;
                sumWeightedClassDiffsSquared[i] += weightedYDiff * yDiff;
            }
        }

        // Pick the attribute with the lowest total squared error.
        double minSSE = Double.MAX_VALUE;
        this.m_attribute = null;
        int chosen = -1;
        double chosenSlope = Double.NaN;
        double chosenIntercept = Double.NaN;
        double chosenMeanForMissing = Double.NaN;
        for (int i = 0; i < numAttributes; ++i) {
            // An attribute with zero weighted variance cannot define a slope.
            if (i == classIndex || sumWeightedDiffsSquared[i] == 0.0) {
                continue;
            }
            // Squared error of predicting the "missing" mean for instances
            // where this attribute is missing.
            double sseForMissing = classSumSquaredForMissing[i] - classSumForMissing[i] * classMeanForMissing[i];
            double numerator = slopeNumerators[i];
            double slope = numerator / sumWeightedDiffsSquared[i];
            double intercept = classMeanForKnown[i] - slope * mean[i];
            // Residual error of the line on instances where the attribute is known.
            double sse = sumWeightedClassDiffsSquared[i] - slope * numerator;
            sse += sseForMissing;
            // A NaN error never compares as smaller, so such attributes are skipped.
            if (sse < minSSE) {
                minSSE = sse;
                chosen = i;
                chosenSlope = slope;
                chosenIntercept = intercept;
                chosenMeanForMissing = classMeanForMissing[i];
            }
        }

        if (chosen == -1) {
            if (!this.m_suppressErrorMessage) {
                System.err.println("----- no useful attribute found");
            }
            // Fall back to predicting the weighted class mean.
            this.m_attribute = null;
            this.m_attributeIndex = 0;
            this.m_slope = 0.0;
            this.m_intercept = classSum / classCount;
            this.m_classMeanForMissing = 0.0;
            return;
        }

        this.m_attribute = insts.attribute(chosen);
        this.m_attributeIndex = chosen;
        this.m_slope = chosenSlope;
        this.m_intercept = chosenIntercept;
        this.m_classMeanForMissing = chosenMeanForMissing;
        if (this.m_outputAdditionalStats) {
            this.computeAdditionalStats(insts);
        }
    }

    /**
     * Verifies that every instance has weight exactly 1.
     *
     * @param insts the data to check
     * @throws Exception if some instance has a weight other than 1
     */
    private void requireUnitWeights(Instances insts) throws Exception {
        for (int i = 0; i < insts.numInstances(); ++i) {
            if (insts.instance(i).weight() != 1.0) {
                throw new Exception("Can only compute additional statistics on unweighted data");
            }
        }
    }

    /** Restores the additional statistics to their "not computed" values. */
    private void resetAdditionalStats() {
        this.m_df = 0;
        this.m_seSlope = Double.NaN;
        this.m_seIntercept = Double.NaN;
        this.m_tstatSlope = Double.NaN;
        this.m_tstatIntercept = Double.NaN;
        this.m_rsquared = Double.NaN;
        this.m_rsquaredAdj = Double.NaN;
        this.m_fstat = Double.NaN;
    }

    /**
     * Computes standard errors, t-statistics, R^2, adjusted R^2 and the
     * F-statistic for the fitted model. Only instances where both the class
     * and the chosen attribute are present are used.
     *
     * @param insts the full training data
     * @throws Exception if a statistics routine fails
     */
    private void computeAdditionalStats(Instances insts) throws Exception {
        // Reduce the data to the instances the line was actually fitted on.
        Instances complete = new Instances(insts, insts.numInstances());
        for (int i = 0; i < insts.numInstances(); ++i) {
            Instance inst = insts.instance(i);
            if (inst.classIsMissing() || inst.isMissing(this.m_attributeIndex)) {
                continue;
            }
            complete.add(inst);
        }
        int usable = complete.numInstances();
        // Two estimated parameters: slope and intercept.
        this.m_df = usable - 2;
        double[] stdErrors = RegressionAnalysis.calculateStdErrorOfCoef(complete, this.m_attribute, this.m_slope, this.m_intercept, this.m_df);
        this.m_seSlope = stdErrors[0];
        this.m_seIntercept = stdErrors[1];
        double[] coef = new double[]{this.m_slope, this.m_intercept};
        double[] tStats = RegressionAnalysis.calculateTStats(coef, stdErrors, 2);
        this.m_tstatSlope = tStats[0];
        this.m_tstatIntercept = tStats[1];
        double ssr = RegressionAnalysis.calculateSSR(complete, this.m_attribute, this.m_slope, this.m_intercept);
        this.m_rsquared = RegressionAnalysis.calculateRSquared(complete, ssr);
        this.m_rsquaredAdj = RegressionAnalysis.calculateAdjRSquared(this.m_rsquared, usable, 2);
        this.m_fstat = RegressionAnalysis.calculateFStat(this.m_rsquared, usable, 2);
    }

    /**
     * Returns whether the last build selected an attribute.
     *
     * @return true if the model uses an attribute, false if it predicts a constant
     */
    public boolean foundUsefulAttribute() {
        return this.m_attribute != null;
    }

    /**
     * Returns the index of the attribute used by the model.
     *
     * @return the attribute index (0 if no attribute was chosen)
     */
    public int getAttributeIndex() {
        return this.m_attributeIndex;
    }

    /**
     * Returns the slope of the fitted line.
     *
     * @return the slope
     */
    public double getSlope() {
        return this.m_slope;
    }

    /**
     * Returns the intercept of the fitted line.
     *
     * @return the intercept
     */
    public double getIntercept() {
        return this.m_intercept;
    }

    /**
     * Sets whether the "no useful attribute found" message on stderr is
     * suppressed.
     *
     * @param s true to suppress the message
     */
    public void setSuppressErrorMessage(boolean s) {
        this.m_suppressErrorMessage = s;
    }

    /**
     * Returns a textual description of the model, including the additional
     * statistics table when that option is enabled and an attribute was chosen.
     *
     * @return the model description
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder();
        if (this.m_attribute == null) {
            text.append("Predicting constant ").append(this.m_intercept);
        } else {
            String attName = this.m_attribute.name();
            text.append("Linear regression on ").append(attName).append("\n\n");
            text.append(Utils.doubleToString(this.m_slope, 2)).append(" * ").append(attName);
            if (this.m_intercept > 0.0) {
                text.append(" + ").append(Utils.doubleToString(this.m_intercept, 2));
            } else {
                text.append(" - ").append(Utils.doubleToString(-this.m_intercept, 2));
            }
            text.append("\n\nPredicting ").append(Utils.doubleToString(this.m_classMeanForMissing, 2)).append(" if attribute value is missing.");
            if (this.m_outputAdditionalStats) {
                // First column is wide enough for the longer of the attribute
                // name and the "Variable" header, plus padding.
                int attNameLength = Math.max(attName.length(), "Variable".length()) + 3;
                text.append("\n\nRegression Analysis:\n\n").append(Utils.padRight("Variable", attNameLength)).append("  Coefficient     SE of Coef        t-Stat");
                text.append("\n").append(Utils.padRight(attName, attNameLength));
                text.append(Utils.doubleToString(this.m_slope, 12, 4));
                text.append("   ").append(Utils.doubleToString(this.m_seSlope, 12, 5));
                text.append("   ").append(Utils.doubleToString(this.m_tstatSlope, 12, 5));
                // The leading newline counts towards the padded width, hence +1.
                text.append(Utils.padRight("\nconst", attNameLength + 1)).append(Utils.doubleToString(this.m_intercept, 12, 4));
                text.append("   ").append(Utils.doubleToString(this.m_seIntercept, 12, 5));
                text.append("   ").append(Utils.doubleToString(this.m_tstatIntercept, 12, 5));
                text.append("\n\nDegrees of freedom = ").append(this.m_df);
                text.append("\nR^2 value = ").append(Utils.doubleToString(this.m_rsquared, 5));
                text.append("\nAdjusted R^2 = ").append(Utils.doubleToString(this.m_rsquaredAdj, 5));
                text.append("\nF-statistic = ").append(Utils.doubleToString(this.m_fstat, 5));
            }
        }
        text.append("\n");
        return text.toString();
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 11130 $");
    }

    /**
     * Command-line entry point; runs this classifier with the given arguments.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        SimpleLinearRegression.runClassifier(new SimpleLinearRegression(), argv);
    }
}

