/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.functions;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.RandomizableClassifier;
import weka.classifiers.UpdateableBatchProcessor;
import weka.classifiers.UpdateableClassifier;
import weka.classifiers.functions.SGD;
import weka.core.Aggregateable;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.stemmers.NullStemmer;
import weka.core.stemmers.Stemmer;
import weka.core.stopwords.Null;
import weka.core.stopwords.StopwordsHandler;
import weka.core.tokenizers.Tokenizer;
import weka.core.tokenizers.WordTokenizer;

public class SGDText
extends RandomizableClassifier
implements UpdateableClassifier,
UpdateableBatchProcessor,
WeightedInstancesHandler,
Aggregateable<SGDText> {
    /** Serialization version identifier. */
    private static final long serialVersionUID = 7200171484002029584L;

    /** Periodic pruning interval in instances; values <= 0 disable periodic pruning. */
    protected int m_periodicP;

    /** Dictionary entries whose count is below this value are ignored / pruned. */
    protected double m_minWordP = 3.0;

    /** Dictionary entries whose absolute weight is below this value are ignored / pruned. */
    protected double m_minAbsCoefficient = 0.001;

    /** Use word counts as feature values instead of a constant 1.0. */
    protected boolean m_wordFrequencies;

    /** Whether feature values are scaled by the document norm in the dot product. */
    protected boolean m_normalize;

    /** Target norm used when normalizing. */
    protected double m_norm = 1.0;

    /** Exponent of the L-norm used when normalizing. */
    protected double m_lnorm = 2.0;

    /** Maps each word to its accumulated count and model weight. */
    protected LinkedHashMap<String, Count> m_dictionary;

    /** Decides which tokens are stopwords. */
    protected StopwordsHandler m_StopwordsHandler = new Null();

    /** Splits string attribute values into tokens. */
    protected Tokenizer m_tokenizer = new WordTokenizer();

    /** Whether tokens are lowercased before stemming. */
    protected boolean m_lowercaseTokens;

    /** Stemmer applied to every token. */
    protected Stemmer m_stemmer = new NullStemmer();

    /** Regularization constant. */
    protected double m_lambda = 1.0E-4;

    /** Learning rate. */
    protected double m_learningRate = 0.01;

    /** Update counter; set to 1.0 by reset() and incremented after every update. */
    protected double m_t;

    /** Bias term of the linear model. */
    protected double m_bias;

    /** Number of instances passed to buildClassifier. */
    protected double m_numInstances;

    /** Empty copy (header only) of the training data. */
    protected Instances m_data;

    /** Number of passes over the data in batch training. */
    protected int m_epochs = 500;

    /** Word counts of the instance tokenized most recently (not serialized). */
    protected transient LinkedHashMap<String, Count> m_inputVector;

    /** Loss identifier: hinge loss. */
    public static final int HINGE = 0;

    /** Loss identifier: log loss. */
    public static final int LOGLOSS = 1;

    /** Currently selected loss function. */
    protected int m_loss = HINGE;

    /** Selectable loss functions. */
    public static final Tag[] TAGS_SELECTION = {
        new Tag(HINGE, "Hinge loss (SVM)"),
        new Tag(LOGLOSS, "Log loss (logistic regression)")
    };

    /** Model fitted to the raw SVM output to obtain probability estimates. */
    protected SGD m_svmProbs;

    /** Whether to fit {@link #m_svmProbs} when hinge loss is used. */
    protected boolean m_fitLogistic;

    /** Header of the two-attribute (prediction, class) data fed to {@link #m_svmProbs}. */
    protected Instances m_fitLogisticStructure;

    /** Number of models aggregated into this one since the last finalization. */
    protected int m_numModels;

    /**
     * Computes the loss-derivative factor for the margin {@code z}.
     *
     * <p>For hinge loss this is 1 when {@code z < 1} and 0 otherwise. For any
     * other loss setting it is {@code 1 / (1 + exp(z))}, evaluated in one of two
     * algebraically equal forms depending on the sign of {@code z}.
     *
     * @param z the margin value
     * @return the derivative factor
     */
    protected double dloss(double z) {
        if (m_loss == HINGE) {
            if (z < 1.0) {
                return 1.0;
            }
            return 0.0;
        }
        if (z < 0.0) {
            return 1.0 / (Math.exp(z) + 1.0);
        }
        // Non-negative margin: work with exp(-z) instead of exp(z).
        double expNegZ = Math.exp(-z);
        return expNegZ / (expNegZ + 1.0);
    }

    /**
     * Returns the capabilities of this classifier: string, nominal, date and
     * numeric attributes, missing values, a binary class and missing class
     * values, with no minimum number of instances.
     *
     * @return the capabilities
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities caps = super.getCapabilities();
        caps.disableAll();

        // Enabled in this fixed order.
        Capabilities.Capability[] supported = {
            Capabilities.Capability.STRING_ATTRIBUTES,
            Capabilities.Capability.NOMINAL_ATTRIBUTES,
            Capabilities.Capability.DATE_ATTRIBUTES,
            Capabilities.Capability.NUMERIC_ATTRIBUTES,
            Capabilities.Capability.MISSING_VALUES,
            Capabilities.Capability.BINARY_CLASS,
            Capabilities.Capability.MISSING_CLASS_VALUES
        };
        for (Capabilities.Capability capability : supported) {
            caps.enable(capability);
        }

        caps.setMinimumNumberInstances(0);
        return caps;
    }

    /**
     * Sets the stemmer; {@code null} selects a {@link NullStemmer}.
     *
     * @param value the stemmer to use, may be {@code null}
     */
    public void setStemmer(Stemmer value) {
        if (value == null) {
            m_stemmer = new NullStemmer();
        } else {
            m_stemmer = value;
        }
    }

    /**
     * Returns the stemmer in use.
     *
     * @return the current stemmer
     */
    public Stemmer getStemmer() {
        return m_stemmer;
    }

    /**
     * Returns the help text for the stemmer property.
     *
     * @return the tip text
     */
    public String stemmerTipText() {
        final String tip = "The stemming algorithm to use on the words.";
        return tip;
    }

    /**
     * Sets the tokenizer used to split string values into words.
     *
     * <p>A {@code null} argument falls back to a {@link WordTokenizer}, in line
     * with how {@code setStemmer} and {@code setStopwordsHandler} treat
     * {@code null}; storing {@code null} would make tokenizing an instance and
     * {@code getOptions()} fail with a NullPointerException.
     *
     * @param value the tokenizer to use, may be {@code null}
     */
    public void setTokenizer(Tokenizer value) {
        this.m_tokenizer = value != null ? value : new WordTokenizer();
    }

    /**
     * Returns the tokenizer in use.
     *
     * @return the current tokenizer
     */
    public Tokenizer getTokenizer() {
        return m_tokenizer;
    }

    /**
     * Returns the help text for the tokenizer property.
     *
     * @return the tip text
     */
    public String tokenizerTipText() {
        final String tip = "The tokenizing algorithm to use on the strings.";
        return tip;
    }

    /**
     * Returns the help text for the word-frequencies property.
     *
     * @return the tip text
     */
    public String useWordFrequenciesTipText() {
        final String tip = "Use word frequencies rather than binary bag of words representation";
        return tip;
    }

    /**
     * Sets whether word counts are used as feature values instead of 1.0.
     *
     * @param u {@code true} to use word frequencies
     */
    public void setUseWordFrequencies(boolean u) {
        m_wordFrequencies = u;
    }

    /**
     * Returns whether word frequencies are used.
     *
     * @return {@code true} if word frequencies are used
     */
    public boolean getUseWordFrequencies() {
        return m_wordFrequencies;
    }

    /**
     * Returns the help text for the lowercase-tokens property.
     *
     * @return the tip text
     */
    public String lowercaseTokensTipText() {
        final String tip = "Whether to convert all tokens to lowercase";
        return tip;
    }

    /**
     * Sets whether tokens are converted to lowercase.
     *
     * @param l {@code true} to lowercase tokens
     */
    public void setLowercaseTokens(boolean l) {
        m_lowercaseTokens = l;
    }

    /**
     * Returns whether tokens are converted to lowercase.
     *
     * @return {@code true} if tokens are lowercased
     */
    public boolean getLowercaseTokens() {
        return m_lowercaseTokens;
    }

    /**
     * Sets the stopwords handler; {@code null} selects the {@link Null} handler.
     *
     * @param value the handler to use, may be {@code null}
     */
    public void setStopwordsHandler(StopwordsHandler value) {
        if (value == null) {
            m_StopwordsHandler = new Null();
        } else {
            m_StopwordsHandler = value;
        }
    }

    /**
     * Returns the stopwords handler in use.
     *
     * @return the current stopwords handler
     */
    public StopwordsHandler getStopwordsHandler() {
        return m_StopwordsHandler;
    }

    /**
     * Returns the help text for the stopwords-handler property.
     *
     * @return the tip text
     */
    public String stopwordsHandlerTipText() {
        final String tip = "The stopwords handler to use (Null means no stopwords are used).";
        return tip;
    }

    /**
     * Returns the help text for the periodic-pruning property.
     *
     * @return the tip text
     */
    public String periodicPruningTipText() {
        return "How often (number of instances) to prune the dictionary of low frequency terms. "
            + "0 means don't prune. "
            + "Setting a positive integer n means prune after every n instances";
    }

    /**
     * Sets the periodic pruning interval, in instances.
     *
     * @param p the interval; values <= 0 disable periodic pruning
     */
    public void setPeriodicPruning(int p) {
        m_periodicP = p;
    }

    /**
     * Returns the periodic pruning interval.
     *
     * @return the interval in instances
     */
    public int getPeriodicPruning() {
        return m_periodicP;
    }

    /**
     * Returns the help text for the minimum-word-frequency property.
     *
     * @return the tip text
     */
    public String minWordFrequencyTipText() {
        return "Ignore any words that don't occur at least min frequency times in the training data. "
            + "If periodic pruning is turned on, then the dictionary is pruned according to this value";
    }

    /**
     * Sets the minimum count a word needs in order to be used.
     *
     * @param minFreq the minimum word frequency
     */
    public void setMinWordFrequency(double minFreq) {
        m_minWordP = minFreq;
    }

    /**
     * Returns the minimum word frequency.
     *
     * @return the minimum word frequency
     */
    public double getMinWordFrequency() {
        return m_minWordP;
    }

    /**
     * Returns the help text for the minimum-absolute-coefficient property.
     *
     * @return the tip text
     */
    public String minAbsoluteCoefficientValueTipText() {
        return "The minimum absolute magnitude for model coefficients. "
            + "Terms with weights smaller than this value are ignored. "
            + "If periodic pruning is turned on then this is also used to determine if a word should be removed from the dictionary.";
    }

    /**
     * Sets the minimum absolute weight a word needs in order to be used.
     *
     * @param minCoeff the minimum absolute coefficient value
     */
    public void setMinAbsoluteCoefficientValue(double minCoeff) {
        m_minAbsCoefficient = minCoeff;
    }

    /**
     * Returns the minimum absolute coefficient value.
     *
     * @return the minimum absolute coefficient value
     */
    public double getMinAbsoluteCoefficientValue() {
        return m_minAbsCoefficient;
    }

    /**
     * Returns the help text for the normalize-document-length property.
     *
     * @return the tip text
     */
    public String normalizeDocLengthTipText() {
        final String tip = "If true then document length is normalized according to the settings for norm and lnorm";
        return tip;
    }

    /**
     * Sets whether document length is normalized.
     *
     * @param norm {@code true} to normalize document length
     */
    public void setNormalizeDocLength(boolean norm) {
        m_normalize = norm;
    }

    /**
     * Returns whether document length is normalized.
     *
     * @return {@code true} if document length is normalized
     */
    public boolean getNormalizeDocLength() {
        return m_normalize;
    }

    /**
     * Returns the help text for the norm property.
     *
     * @return the tip text
     */
    public String normTipText() {
        final String tip = "The norm of the instances after normalization.";
        return tip;
    }

    /**
     * Returns the target norm used for normalization.
     *
     * @return the norm
     */
    public double getNorm() {
        return m_norm;
    }

    /**
     * Sets the target norm used for normalization.
     *
     * @param newNorm the norm
     */
    public void setNorm(double newNorm) {
        m_norm = newNorm;
    }

    /**
     * Returns the help text for the L-norm property.
     *
     * @return the tip text
     */
    public String LNormTipText() {
        final String tip = "The LNorm to use for document length normalization.";
        return tip;
    }

    /**
     * Returns the L-norm exponent used for normalization.
     *
     * @return the L-norm
     */
    public double getLNorm() {
        return m_lnorm;
    }

    /**
     * Sets the L-norm exponent used for normalization.
     *
     * @param newLNorm the L-norm
     */
    public void setLNorm(double newLNorm) {
        m_lnorm = newLNorm;
    }

    /**
     * Returns the help text for the lambda property.
     *
     * @return the tip text
     */
    public String lambdaTipText() {
        final String tip = "The regularization constant. (default = 0.0001)";
        return tip;
    }

    /**
     * Sets the regularization constant.
     *
     * @param lambda the regularization constant
     */
    public void setLambda(double lambda) {
        m_lambda = lambda;
    }

    /**
     * Returns the regularization constant.
     *
     * @return the regularization constant
     */
    public double getLambda() {
        return m_lambda;
    }

    /**
     * Sets the learning rate.
     *
     * @param lr the learning rate
     */
    public void setLearningRate(double lr) {
        m_learningRate = lr;
    }

    /**
     * Returns the learning rate.
     *
     * @return the learning rate
     */
    public double getLearningRate() {
        return m_learningRate;
    }

    /**
     * Returns the help text for the learning-rate property.
     *
     * @return the tip text
     */
    public String learningRateTipText() {
        final String tip = "The learning rate.";
        return tip;
    }

    /**
     * Returns the help text for the epochs property.
     *
     * @return the tip text
     */
    public String epochsTipText() {
        return "The number of epochs to perform (batch learning). "
            + "The total number of iterations is epochs * num instances.";
    }

    /**
     * Sets the number of epochs used in batch training.
     *
     * @param e the number of epochs
     */
    public void setEpochs(int e) {
        m_epochs = e;
    }

    /**
     * Returns the number of epochs used in batch training.
     *
     * @return the number of epochs
     */
    public int getEpochs() {
        return m_epochs;
    }

    /**
     * Sets the loss function from a selected tag.
     *
     * <p>The selection is ignored unless its tag array is this class's own
     * {@link #TAGS_SELECTION} (compared by identity).
     *
     * @param function the selected loss function
     */
    public void setLossFunction(SelectedTag function) {
        if (function.getTags() != TAGS_SELECTION) {
            return;
        }
        m_loss = function.getSelectedTag().getID();
    }

    /**
     * Returns the current loss function as a selected tag over
     * {@link #TAGS_SELECTION}.
     *
     * @return the selected loss function
     */
    public SelectedTag getLossFunction() {
        SelectedTag selected = new SelectedTag(m_loss, TAGS_SELECTION);
        return selected;
    }

    /**
     * Returns the help text for the loss-function property.
     *
     * <p>Only the two losses listed in {@code TAGS_SELECTION} (hinge and log
     * loss) are mentioned; the earlier text also advertised a squared loss that
     * this class does not offer.
     *
     * @return the tip text
     */
    public String lossFunctionTipText() {
        return "The loss function to use. Hinge loss (SVM) or log loss (logistic regression).";
    }

    /**
     * Sets whether a model is fitted to the SVM output to produce
     * probability estimates.
     *
     * @param o {@code true} to fit the probability model
     */
    public void setOutputProbsForSVM(boolean o) {
        m_fitLogistic = o;
    }

    /**
     * Returns whether a probability model is fitted to the SVM output.
     *
     * @return {@code true} if the probability model is fitted
     */
    public boolean getOutputProbsForSVM() {
        return m_fitLogistic;
    }

    /**
     * Returns the help text for the output-probabilities property.
     *
     * @return the tip text
     */
    public String outputProbsForSVMTipText() {
        final String tip = "Fit a logistic regression to the output of SVM for producing probability estimates";
        return tip;
    }

    /**
     * Lists the command-line options understood by this classifier, followed by
     * the options of the superclass.
     *
     * <p>Option names are given without a leading dash, and every synopsis
     * matches the flag that {@code setOptions} parses and {@code getOptions}
     * emits ({@code -output-probs}, {@code -stopwords-handler}).
     *
     * @return an enumeration of all available options
     */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> newVector = new Vector<Option>();
        newVector.add(new Option("\tSet the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression)\n\t(default = 0)", "F", 1, "-F <integer>"));
        newVector.add(new Option("\tOutput probabilities for SVMs (fits a logistic\n\tmodel to the output of the SVM)", "output-probs", 0, "-output-probs"));
        newVector.add(new Option("\tThe learning rate (default = 0.01).", "L", 1, "-L <double>"));
        newVector.add(new Option("\tThe lambda regularization constant (default = 0.0001)", "R", 1, "-R <double>"));
        newVector.add(new Option("\tThe number of epochs to perform (batch learning only, default = 500)", "E", 1, "-E <integer>"));
        newVector.add(new Option("\tUse word frequencies instead of binary bag of words.", "W", 0, "-W"));
        newVector.add(new Option("\tHow often to prune the dictionary of low frequency words (default = 0, i.e. don't prune)", "P", 1, "-P <# instances>"));
        newVector.add(new Option("\tMinimum word frequency. Words with less than this frequency are ignored.\n\tIf periodic pruning is turned on then this is also used to determine which\n\twords to remove from the dictionary (default = 3).", "M", 1, "-M <double>"));
        newVector.add(new Option("\tMinimum absolute value of coefficients in the model.\n\tIf periodic pruning is turned on then this\n\tis also used to prune words from the dictionary\n\t(default = 0.001)", "min-coeff", 1, "-min-coeff <double>"));
        newVector.add(new Option("\tNormalize document length (use in conjunction with -norm and -lnorm)", "normalize", 0, "-normalize"));
        newVector.add(new Option("\tSpecify the norm that each instance must have (default 1.0)", "norm", 1, "-norm <num>"));
        newVector.add(new Option("\tSpecify L-norm to use (default 2.0)", "lnorm", 1, "-lnorm <num>"));
        newVector.add(new Option("\tConvert all tokens to lowercase before adding to the dictionary.", "lowercase", 0, "-lowercase"));
        // The option name must not carry the leading dash; only the synopsis does.
        newVector.add(new Option("\tThe stopwords handler to use (default Null).", "stopwords-handler", 1, "-stopwords-handler <spec>"));
        newVector.add(new Option("\tThe tokenizing algorithm (classname plus parameters) to use.\n\t(default: " + WordTokenizer.class.getName() + ")", "tokenizer", 1, "-tokenizer <spec>"));
        newVector.add(new Option("\tThe stemming algorithm (classname plus parameters) to use.", "stemmer", 1, "-stemmer <spec>"));
        newVector.addAll(Collections.list(super.listOptions()));
        return newVector.elements();
    }

    /**
     * Parses the given options and configures this classifier accordingly.
     *
     * <p>The model is reset first. Missing flags switch the corresponding
     * boolean settings off; a missing stemmer or stopwords-handler
     * specification passes {@code null} to the respective setter, and a missing
     * tokenizer specification selects a {@link WordTokenizer}. Remaining
     * options are handed to the superclass, and anything left over afterwards
     * is checked via {@code Utils.checkForRemainingOptions}.
     *
     * @param options the options to parse
     * @throws Exception if a component specification is empty or an option cannot be handled
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        reset();

        String lossArg = Utils.getOption('F', options);
        if (!lossArg.isEmpty()) {
            setLossFunction(new SelectedTag(Integer.parseInt(lossArg), TAGS_SELECTION));
        }

        setOutputProbsForSVM(Utils.getFlag("output-probs", options));

        String lambdaArg = Utils.getOption('R', options);
        if (!lambdaArg.isEmpty()) {
            setLambda(Double.parseDouble(lambdaArg));
        }

        String learningRateArg = Utils.getOption('L', options);
        if (!learningRateArg.isEmpty()) {
            setLearningRate(Double.parseDouble(learningRateArg));
        }

        String epochsArg = Utils.getOption("E", options);
        if (!epochsArg.isEmpty()) {
            setEpochs(Integer.parseInt(epochsArg));
        }

        setUseWordFrequencies(Utils.getFlag("W", options));

        String pruneArg = Utils.getOption("P", options);
        if (!pruneArg.isEmpty()) {
            setPeriodicPruning(Integer.parseInt(pruneArg));
        }

        String minFreqArg = Utils.getOption("M", options);
        if (!minFreqArg.isEmpty()) {
            setMinWordFrequency(Double.parseDouble(minFreqArg));
        }

        String minCoeffArg = Utils.getOption("min-coeff", options);
        if (!minCoeffArg.isEmpty()) {
            setMinAbsoluteCoefficientValue(Double.parseDouble(minCoeffArg));
        }

        setNormalizeDocLength(Utils.getFlag("normalize", options));

        String normArg = Utils.getOption("norm", options);
        if (!normArg.isEmpty()) {
            setNorm(Double.parseDouble(normArg));
        }

        String lnormArg = Utils.getOption("lnorm", options);
        if (!lnormArg.isEmpty()) {
            setLNorm(Double.parseDouble(lnormArg));
        }

        setLowercaseTokens(Utils.getFlag("lowercase", options));

        // Stemmer: "classname [options]"; absent means null (setter picks the default).
        String stemmerArg = Utils.getOption("stemmer", options);
        Stemmer chosenStemmer = null;
        if (!stemmerArg.isEmpty()) {
            String[] parts = Utils.splitOptions(stemmerArg);
            if (parts.length == 0) {
                throw new Exception("Invalid stemmer specification string");
            }
            String className = parts[0];
            parts[0] = "";
            chosenStemmer = (Stemmer) Utils.forName(Class.forName("weka.core.stemmers.Stemmer"), className, parts);
        }
        setStemmer(chosenStemmer);

        // Stopwords handler: same specification format as the stemmer.
        String stopwordsArg = Utils.getOption("stopwords-handler", options);
        StopwordsHandler chosenHandler = null;
        if (!stopwordsArg.isEmpty()) {
            String[] parts = Utils.splitOptions(stopwordsArg);
            if (parts.length == 0) {
                throw new Exception("Invalid StopwordsHandler specification string");
            }
            String className = parts[0];
            parts[0] = "";
            chosenHandler = (StopwordsHandler) Utils.forName(Class.forName("weka.core.stopwords.StopwordsHandler"), className, parts);
        }
        setStopwordsHandler(chosenHandler);

        // Tokenizer: absent means a fresh WordTokenizer.
        String tokenizerArg = Utils.getOption("tokenizer", options);
        if (tokenizerArg.isEmpty()) {
            setTokenizer(new WordTokenizer());
        } else {
            String[] parts = Utils.splitOptions(tokenizerArg);
            if (parts.length == 0) {
                throw new Exception("Invalid tokenizer specification string");
            }
            String className = parts[0];
            parts[0] = "";
            setTokenizer((Tokenizer) Utils.forName(Class.forName("weka.core.tokenizers.Tokenizer"), className, parts));
        }

        super.setOptions(options);
        Utils.checkForRemainingOptions(options);
    }

    /**
     * Returns the current settings of this classifier as an option array,
     * followed by the options of the superclass.
     *
     * @return the current options
     */
    @Override
    public String[] getOptions() {
        ArrayList<String> opts = new ArrayList<String>();

        opts.add("-F");
        opts.add(String.valueOf(getLossFunction().getSelectedTag().getID()));
        if (getOutputProbsForSVM()) {
            opts.add("-output-probs");
        }
        opts.add("-L");
        opts.add(String.valueOf(getLearningRate()));
        opts.add("-R");
        opts.add(String.valueOf(getLambda()));
        opts.add("-E");
        opts.add(String.valueOf(getEpochs()));
        if (getUseWordFrequencies()) {
            opts.add("-W");
        }
        opts.add("-P");
        opts.add(String.valueOf(getPeriodicPruning()));
        opts.add("-M");
        opts.add(String.valueOf(getMinWordFrequency()));
        opts.add("-min-coeff");
        opts.add(String.valueOf(getMinAbsoluteCoefficientValue()));
        if (getNormalizeDocLength()) {
            opts.add("-normalize");
        }
        opts.add("-norm");
        opts.add(String.valueOf(getNorm()));
        opts.add("-lnorm");
        opts.add(String.valueOf(getLNorm()));
        if (getLowercaseTokens()) {
            opts.add("-lowercase");
        }

        // Component specifications: class name plus the component's own options, if any.
        StopwordsHandler stopwords = getStopwordsHandler();
        if (stopwords != null) {
            opts.add("-stopwords-handler");
            String stopwordsSpec = stopwords.getClass().getName();
            if (stopwords instanceof OptionHandler) {
                stopwordsSpec = stopwordsSpec + " " + Utils.joinOptions(((OptionHandler) ((Object) stopwords)).getOptions());
            }
            opts.add(stopwordsSpec.trim());
        }

        opts.add("-tokenizer");
        String tokenizerSpec = getTokenizer().getClass().getName();
        if (getTokenizer() instanceof OptionHandler) {
            tokenizerSpec = tokenizerSpec + " " + Utils.joinOptions(getTokenizer().getOptions());
        }
        opts.add(tokenizerSpec.trim());

        Stemmer stemmer = getStemmer();
        if (stemmer != null) {
            opts.add("-stemmer");
            String stemmerSpec = stemmer.getClass().getName();
            if (stemmer instanceof OptionHandler) {
                stemmerSpec = stemmerSpec + " " + Utils.joinOptions(((OptionHandler) ((Object) stemmer)).getOptions());
            }
            opts.add(stemmerSpec.trim());
        }

        Collections.addAll(opts, super.getOptions());
        // The list always holds at least "-F <id>", so a zero-length seed array yields the same result.
        return opts.toArray(new String[0]);
    }

    /**
     * Returns a short description of this classifier.
     *
     * @return the description text
     */
    public String globalInfo() {
        return "Implements stochastic gradient descent for learning a linear binary class SVM or binary class logistic regression on text data. "
            + "Operates directly (and only) on String attributes. "
            + "Other types of input attributes are accepted but ignored during training and classification.";
    }

    /**
     * Discards the learned model: drops the dictionary, zeroes the bias and
     * restarts the update counter at 1.
     */
    public void reset() {
        m_dictionary = null;
        m_bias = 0.0;
        m_t = 1.0;
    }

    /**
     * Builds the classifier from scratch on the given data.
     *
     * <p>The model is reset, the data is checked against the capabilities, and
     * a header-only copy is kept. Training runs on a shuffled copy of the data
     * (seeded with {@code getSeed()}), after which the dictionary is pruned.
     * With no instances only the empty model structures are set up.
     *
     * @param data the training data
     * @throws Exception if the data fails the capabilities test or training fails
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {
        reset();
        getCapabilities().testWithFail(data);

        m_dictionary = new LinkedHashMap<String, Count>(10000);
        m_numInstances = data.numInstances();
        m_data = new Instances(data, 0);

        // Work on a copy so the caller's data is not reordered.
        Instances trainingCopy = new Instances(data);
        if (m_loss == HINGE && m_fitLogistic) {
            initializeSVMProbs(trainingCopy);
        }
        if (trainingCopy.numInstances() <= 0) {
            return;
        }
        trainingCopy.randomize(new Random(getSeed()));
        train(trainingCopy);
        pruneDictionary(true);
    }

    /**
     * Sets up the auxiliary {@link SGD} model that maps raw SVM outputs to
     * probability estimates.
     *
     * <p>The auxiliary model shares this classifier's learning rate, lambda and
     * epochs, and is initialized on an empty two-attribute dataset: the numeric
     * SVM output ("pred") and a nominal class with the first two labels of the
     * class attribute of {@code data}.
     *
     * <p>The log-loss selection is built from {@code SGD}'s own constants.
     * This class's {@code setLossFunction} ignores any tag array other than its
     * own (identity comparison); {@code SGD} is expected to behave the same
     * way, so passing this class's {@code TAGS_SELECTION} would leave the
     * auxiliary model on its default loss.
     *
     * @param data the training data, used only for its class labels
     * @throws Exception if the auxiliary model cannot be initialized
     */
    protected void initializeSVMProbs(Instances data) throws Exception {
        this.m_svmProbs = new SGD();
        this.m_svmProbs.setLossFunction(new SelectedTag(SGD.LOGLOSS, SGD.TAGS_SELECTION));
        this.m_svmProbs.setLearningRate(this.m_learningRate);
        this.m_svmProbs.setLambda(this.m_lambda);
        this.m_svmProbs.setEpochs(this.m_epochs);

        ArrayList<Attribute> atts = new ArrayList<Attribute>(2);
        atts.add(new Attribute("pred"));
        ArrayList<String> attVals = new ArrayList<String>(2);
        attVals.add(data.classAttribute().value(0));
        attVals.add(data.classAttribute().value(1));
        atts.add(new Attribute("class", attVals));

        this.m_fitLogisticStructure = new Instances("data", atts, 0);
        this.m_fitLogisticStructure.setClassIndex(1);
        this.m_svmProbs.buildClassifier(this.m_fitLogisticStructure);
    }

    /**
     * Runs the configured number of epochs over the data. The dictionary is
     * only updated during the first epoch; later epochs update weights only.
     *
     * @param data the training data
     * @throws Exception if an update fails
     */
    protected void train(Instances data) throws Exception {
        for (int epoch = 0; epoch < m_epochs; epoch++) {
            boolean firstPass = epoch == 0;
            for (int index = 0; index < data.numInstances(); index++) {
                updateClassifier(data.instance(index), firstPass);
            }
        }
    }

    /**
     * Updates the model with a single instance, also adding its words to the
     * dictionary.
     *
     * @param instance the instance to learn from
     * @throws Exception if the update fails
     */
    @Override
    public void updateClassifier(Instance instance) throws Exception {
        final boolean updateDictionary = true;
        updateClassifier(instance, updateDictionary);
    }

    /**
     * Performs one stochastic gradient step for the given instance.
     *
     * <p>Instances with a missing class are skipped entirely. Otherwise the
     * instance is tokenized, the auxiliary probability model (if enabled for
     * hinge loss) is updated with the current raw output, all dictionary
     * weights are shrunk by the regularization factor, and the weights of the
     * words in the instance plus the bias are moved along the loss gradient.
     *
     * @param instance the instance to learn from
     * @param updateDictionary whether the instance's words are added to the dictionary
     * @throws Exception if updating the auxiliary model fails
     */
    protected void updateClassifier(Instance instance, boolean updateDictionary) throws Exception {
        if (instance.classIsMissing()) {
            return;
        }

        tokenizeInstance(instance, updateDictionary);

        if (m_loss == HINGE && m_fitLogistic) {
            double rawOutput = svmOutput();
            double[] metaValues = new double[]{rawOutput, instance.classValue()};
            DenseInstance metaInstance = new DenseInstance(instance.weight(), metaValues);
            metaInstance.setDataset(m_fitLogisticStructure);
            m_svmProbs.updateClassifier(metaInstance);
        }

        double wx = dotProd(m_inputVector);
        // Class value 0 maps to -1, everything else to +1.
        double y;
        if (instance.classValue() == 0.0) {
            y = -1.0;
        } else {
            y = 1.0;
        }
        double z = y * (wx + m_bias);

        // Regularization shrink: divide by the update counter when no batch size is known.
        double denominator = (m_numInstances == 0.0) ? m_t : m_numInstances;
        double shrink = 1.0 - m_learningRate * m_lambda / denominator;
        for (Count entry : m_dictionary.values()) {
            entry.m_weight *= shrink;
        }

        if (m_loss != HINGE || z < 1.0) {
            double step = m_learningRate * y * dloss(z);
            for (Map.Entry<String, Count> feature : m_inputVector.entrySet()) {
                Count known = m_dictionary.get(feature.getKey());
                if (known != null) {
                    double featureValue = m_wordFrequencies ? feature.getValue().m_count : 1.0;
                    known.m_weight += step * featureValue;
                }
            }
            m_bias += step;
        }

        m_t += 1.0;
    }

    /**
     * Tokenizes all non-missing string attributes of the instance into
     * {@code m_inputVector}, which is cleared (or created) first.
     *
     * <p>Each token is optionally lowercased, then stemmed, and dropped if the
     * stopwords handler flags it. Counts are incremented by the instance
     * weight. When {@code updateDictionary} is set, the dictionary counts are
     * incremented as well and a non-forced pruning pass is requested afterwards.
     *
     * @param instance the instance to tokenize
     * @param updateDictionary whether to add the words to the dictionary
     */
    protected void tokenizeInstance(Instance instance, boolean updateDictionary) {
        if (m_inputVector != null) {
            m_inputVector.clear();
        } else {
            m_inputVector = new LinkedHashMap<String, Count>();
        }

        for (int attIndex = 0; attIndex < instance.numAttributes(); attIndex++) {
            if (instance.attribute(attIndex).isString() && !instance.isMissing(attIndex)) {
                m_tokenizer.tokenize(instance.stringValue(attIndex));
                while (m_tokenizer.hasMoreElements()) {
                    String token = m_tokenizer.nextElement();
                    if (m_lowercaseTokens) {
                        token = token.toLowerCase();
                    }
                    token = m_stemmer.stem(token);
                    if (m_StopwordsHandler.isStopword(token)) {
                        continue;
                    }

                    // Per-document count.
                    Count inDocument = m_inputVector.get(token);
                    if (inDocument != null) {
                        inDocument.m_count += instance.weight();
                    } else {
                        m_inputVector.put(token, new Count(instance.weight()));
                    }

                    // Global dictionary count.
                    if (updateDictionary) {
                        Count inDictionary = m_dictionary.get(token);
                        if (inDictionary != null) {
                            inDictionary.m_count += instance.weight();
                        } else {
                            m_dictionary.put(token, new Count(instance.weight()));
                        }
                    }
                }
            }
        }

        if (updateDictionary) {
            pruneDictionary(false);
        }
    }

    /**
     * Removes dictionary entries whose count is below the minimum word
     * frequency or whose absolute weight is below the minimum coefficient.
     *
     * <p>Unless forced, pruning only happens when periodic pruning is enabled
     * and the update counter is a multiple of the pruning interval.
     *
     * @param force prune regardless of the periodic schedule
     */
    protected void pruneDictionary(boolean force) {
        if (!force) {
            boolean due = m_periodicP > 0 && !(m_t % (double) m_periodicP > 0.0);
            if (!due) {
                return;
            }
        }

        Iterator<Map.Entry<String, Count>> it = m_dictionary.entrySet().iterator();
        while (it.hasNext()) {
            Count stats = it.next().getValue();
            boolean tooRare = stats.m_count < m_minWordP;
            if (tooRare || Math.abs(stats.m_weight) < m_minAbsCoefficient) {
                it.remove();
            }
        }
    }

    /**
     * Computes the raw linear output (dot product plus bias) for the instance
     * currently held in {@code m_inputVector}.
     *
     * @return the raw model output
     */
    protected double svmOutput() {
        return dotProd(m_inputVector) + m_bias;
    }

    /**
     * Computes the class distribution for an instance.
     *
     * <p>With hinge loss and probability fitting enabled, the raw output is
     * passed to the auxiliary model. With log loss the distribution is the
     * logistic function of the raw output. Otherwise all mass goes to class 0
     * when the raw output is {@code <= 0} and to class 1 when it is not.
     *
     * @param inst the instance to classify
     * @return a two-element class distribution
     * @throws Exception if the auxiliary model fails
     */
    @Override
    public double[] distributionForInstance(Instance inst) throws Exception {
        tokenizeInstance(inst, false);
        double z = dotProd(m_inputVector) + m_bias;

        if (m_loss == HINGE && m_fitLogistic) {
            double[] metaValues = new double[]{z, Utils.missingValue()};
            DenseInstance metaInstance = new DenseInstance(inst.weight(), metaValues);
            metaInstance.setDataset(m_fitLogisticStructure);
            return m_svmProbs.distributionForInstance(metaInstance);
        }

        double[] dist = new double[2];
        boolean nonPositive = z <= 0.0;
        if (m_loss == LOGLOSS) {
            // Compute the probability of the predicted side first, as the exponent is then <= 0.
            if (nonPositive) {
                dist[0] = 1.0 / (1.0 + Math.exp(z));
                dist[1] = 1.0 - dist[0];
            } else {
                dist[1] = 1.0 / (1.0 + Math.exp(-z));
                dist[0] = 1.0 - dist[1];
            }
        } else if (nonPositive) {
            dist[0] = 1.0;
        } else {
            dist[1] = 1.0;
        }
        return dist;
    }

    /**
     * Computes the dot product between a tokenized document and the model
     * weights.
     *
     * <p>Feature values are the word counts (when word frequencies are enabled)
     * or 1.0. With normalization enabled they are scaled by
     * {@code m_norm / ||document||}, using the configured L-norm. Words that are
     * missing from the dictionary, or whose dictionary count or absolute weight
     * is below the configured minimum, do not contribute.
     *
     * <p>A document norm of exactly zero (every feature value is zero) leaves
     * the values unscaled: dividing by it would turn each zero feature into
     * {@code 0 * Infinity = NaN} and poison the whole result.
     *
     * @param document the word counts of the document
     * @return the dot product
     */
    protected double dotProd(Map<String, Count> document) {
        double scale = 1.0;
        if (this.m_normalize) {
            double powerSum = 0.0;
            for (Count count : document.values()) {
                double fv = this.m_wordFrequencies ? count.m_count : 1.0;
                powerSum += Math.pow(Math.abs(fv), this.m_lnorm);
            }
            double docNorm = Math.pow(powerSum, 1.0 / this.m_lnorm);
            // "!= 0" rather than "> 0" so that a NaN norm still propagates as before.
            if (docNorm != 0.0) {
                scale = this.m_norm / docNorm;
            }
        }

        double result = 0.0;
        for (Map.Entry<String, Count> entry : document.entrySet()) {
            Count weight = this.m_dictionary.get(entry.getKey());
            if (weight == null) {
                continue;
            }
            if (weight.m_count >= this.m_minWordP && Math.abs(weight.m_weight) >= this.m_minAbsCoefficient) {
                double freq = this.m_wordFrequencies ? entry.getValue().m_count : 1.0;
                // Multiplying by a scale of 1.0 is exact, so the unnormalized case is unchanged.
                freq *= scale;
                result += freq * weight.m_weight;
            }
        }
        return result;
    }

    /**
     * Returns a textual description of the model: the loss function, the
     * number of usable dictionary entries, and one line per usable word with
     * its weight and count, followed by the bias.
     *
     * <p>A word is usable when its count and absolute weight reach the
     * configured minimums. A bias of exactly zero is printed with a plus sign.
     *
     * @return the model description, or a notice if no model has been built
     */
    @Override
    public String toString() {
        if (this.m_dictionary == null) {
            return "SGDText: No model built yet.\n";
        }
        StringBuilder buff = new StringBuilder();
        buff.append("SGDText:\n\n");
        buff.append("Loss function: ");
        if (this.m_loss == HINGE) {
            buff.append("Hinge loss (SVM)\n\n");
        } else {
            buff.append("Log loss (logistic regression)\n\n");
        }
        // getDictionarySize() applies the same count/weight filter as the loop below.
        buff.append("Dictionary size: ").append(this.getDictionarySize()).append("\n\n");
        buff.append(this.m_data.classAttribute().name()).append(" = \n\n");

        boolean first = true;
        for (Map.Entry<String, Count> entry : this.m_dictionary.entrySet()) {
            Count stats = entry.getValue();
            if (!(stats.m_count >= this.m_minWordP) || !(Math.abs(stats.m_weight) >= this.m_minAbsCoefficient)) {
                continue;
            }
            buff.append(first ? "   " : " + ");
            buff.append(Utils.doubleToString(stats.m_weight, 12, 4))
                .append(" ").append(entry.getKey())
                .append(" ").append(stats.m_count).append("\n");
            first = false;
        }

        if (this.m_bias >= 0.0) {
            buff.append(" + ").append(Utils.doubleToString(this.m_bias, 12, 4));
        } else {
            buff.append(" - ").append(Utils.doubleToString(-this.m_bias, 12, 4));
        }
        return buff.toString();
    }

    /**
     * Returns the dictionary itself (not a copy); {@code null} if no model
     * has been built since the last reset.
     *
     * @return the word dictionary
     */
    public LinkedHashMap<String, Count> getDictionary() {
        return m_dictionary;
    }

    /**
     * Counts the dictionary entries whose count and absolute weight both reach
     * the configured minimums.
     *
     * @return the number of usable entries, 0 if there is no dictionary
     */
    public int getDictionarySize() {
        if (m_dictionary == null) {
            return 0;
        }
        int usable = 0;
        for (Count stats : m_dictionary.values()) {
            if (stats.m_count >= m_minWordP && Math.abs(stats.m_weight) >= m_minAbsCoefficient) {
                usable++;
            }
        }
        return usable;
    }

    /**
     * Returns the bias term of the model.
     *
     * @return the bias
     */
    public double bias() {
        return m_bias;
    }

    /**
     * Sets the bias term of the model.
     *
     * @param bias the new bias
     */
    public void setBias(double bias) {
        m_bias = bias;
    }

    /**
     * Returns the revision string.
     *
     * @return the revision extracted from the embedded revision keyword
     */
    @Override
    public String getRevision() {
        final String revisionKeyword = "$Revision: 13280 $";
        return RevisionUtils.extract(revisionKeyword);
    }

    /**
     * Adds another model's dictionary counts, weights and bias to this one.
     *
     * <p>Values are summed here; {@code finalizeAggregation()} turns the sums
     * into averages.
     *
     * @param toAggregate the model to merge into this one
     * @return this model
     * @throws Exception if either model has no dictionary (not built yet) or
     *         {@code toAggregate} is {@code null}
     */
    @Override
    public SGDText aggregate(SGDText toAggregate) throws Exception {
        if (this.m_dictionary == null) {
            throw new Exception("No model built yet, can't aggregate");
        }
        // Fail with a clear message instead of a NullPointerException.
        if (toAggregate == null || toAggregate.getDictionary() == null) {
            throw new Exception("Model to aggregate has not been built yet, can't aggregate");
        }
        LinkedHashMap<String, Count> tempDict = toAggregate.getDictionary();
        for (Map.Entry<String, Count> entry : tempDict.entrySet()) {
            Count other = entry.getValue();
            Count masterCount = this.m_dictionary.get(entry.getKey());
            if (masterCount == null) {
                // Copy rather than share the other model's Count object.
                masterCount = new Count(other.m_count);
                masterCount.m_weight = other.m_weight;
                this.m_dictionary.put(entry.getKey(), masterCount);
                continue;
            }
            masterCount.m_count += other.m_count;
            masterCount.m_weight += other.m_weight;
        }
        this.m_bias += toAggregate.bias();
        ++this.m_numModels;
        return this;
    }

    /**
     * Completes an aggregation: prunes the dictionary, then divides all counts,
     * weights and the bias by the number of aggregated models plus one, and
     * resets the aggregated-model counter.
     *
     * @throws Exception if no models have been aggregated
     */
    @Override
    public void finalizeAggregation() throws Exception {
        if (m_numModels == 0) {
            throw new Exception("Unable to finalize aggregation - haven't seen any models to aggregate");
        }
        pruneDictionary(true);

        // NOTE(review): the "+ 1" presumably accounts for this model itself - confirm.
        final double divisor = (double) (m_numModels + 1);
        for (Count stats : m_dictionary.values()) {
            stats.m_count /= divisor;
            stats.m_weight /= divisor;
        }
        m_bias /= divisor;
        m_numModels = 0;
    }

    /**
     * Signals the end of a batch of updates: forces a pruning pass over the
     * dictionary.
     *
     * <p>Does nothing when there is no dictionary yet (no model built since the
     * last reset), instead of failing with a NullPointerException.
     *
     * @throws Exception declared by the interface; not thrown here
     */
    @Override
    public void batchFinished() throws Exception {
        if (this.m_dictionary == null) {
            return;
        }
        this.pruneDictionary(true);
    }

    /**
     * Command-line entry point: runs a fresh classifier with the given
     * arguments.
     *
     * @param args the command-line arguments
     */
    public static void main(String[] args) {
        SGDText classifier = new SGDText();
        SGDText.runClassifier(classifier, args);
    }

    /**
     * Mutable holder pairing a word's accumulated count with its model weight.
     */
    public static class Count
    implements Serializable {
        /** Serialization version identifier. */
        private static final long serialVersionUID = 2104201532017340967L;

        /** Accumulated (weighted) count of the word. */
        public double m_count;

        /** Model weight of the word; starts at the default of 0. */
        public double m_weight;

        /**
         * Creates a holder with the given initial count.
         *
         * @param c the initial count
         */
        public Count(double c) {
            m_count = c;
        }
    }
}

