/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.pmml.consumer;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;
import weka.classifiers.pmml.consumer.NeuralNetwork;
import weka.classifiers.pmml.consumer.PMMLClassifier;
import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.pmml.MiningSchema;
import weka.core.pmml.TargetMetaInfo;
import weka.core.pmml.VectorDictionary;
import weka.core.pmml.VectorInstance;
import weka.gui.Logger;

/**
 * Consumer for a PMML {@code SupportVectorMachineModel}.
 * <p>
 * At construction time the model element is parsed into a {@link Kernel}, an
 * optional {@link VectorDictionary} (support-vector representation only) and one
 * {@link SupportVectorMachine} per {@code SupportVectorMachine} child element. At
 * prediction time every binary machine writes its output into a shared array,
 * which is then combined according to the classification method.
 */
public class SupportVectorMachineModel
extends PMMLClassifier
implements Serializable {
    private static final long serialVersionUID = 6225095165118374296L;

    /** Sentinel for a class entry that no binary machine has written to yet. */
    private static final double UNSET = -1.0;

    /** Classification or regression, taken from the {@code functionName} attribute. */
    protected NeuralNetwork.MiningFunction m_functionType = NeuralNetwork.MiningFunction.CLASSIFICATION;
    /** How the outputs of several binary machines are combined for a nominal class. */
    protected classificationMethod m_classificationMethod = classificationMethod.NONE;
    /** Value of the optional {@code modelName} attribute. */
    protected String m_modelName;
    /** Value of the optional {@code algorithmName} attribute. */
    protected String m_algorithmName;
    /** Support vectors referenced by the machines; stays null for the coefficient representation. */
    protected VectorDictionary m_vectorDictionary;
    /** Kernel shared by all binary machines. */
    protected Kernel m_kernel;
    /** The binary SVMs, in document order. */
    protected List<SupportVectorMachine> m_machines = new ArrayList<SupportVectorMachine>();
    /** Class index of the {@code alternateBinaryTargetCategory} attribute, or -1 if absent. */
    protected int m_alternateBinaryTargetCategory = -1;
    /** Whether machines are expressed through support vectors or plain coefficients. */
    protected SVM_representation m_svmRepresentation = SVM_representation.SUPPORT_VECTORS;
    /** Model-level decision threshold used when voting (default 0). */
    protected double m_threshold = 0.0;

    /**
     * Builds the model from its PMML element.
     *
     * @param model the {@code SupportVectorMachineModel} element
     * @param dataDictionary the data dictionary of the PMML document
     * @param miningSchema the mining schema of this model
     * @throws Exception if the alternate binary target category is not a class value,
     *         the coefficient representation is used with a non-linear kernel, no
     *         binary SVMs are defined, or one of the binary SVMs cannot be parsed
     */
    public SupportVectorMachineModel(Element model, Instances dataDictionary, MiningSchema miningSchema) throws Exception {
        super(dataDictionary, miningSchema);

        if ("regression".equals(model.getAttribute("functionName"))) {
            this.m_functionType = NeuralNetwork.MiningFunction.REGRESSION;
        }
        String modelName = model.getAttribute("modelName");
        if (modelName != null && modelName.length() > 0) {
            this.m_modelName = modelName;
        }
        String algoName = model.getAttribute("algorithmName");
        if (algoName != null && algoName.length() > 0) {
            this.m_algorithmName = algoName;
        }
        if ("Coefficients".equals(model.getAttribute("svmRepresentation"))) {
            this.m_svmRepresentation = SVM_representation.COEFFICIENTS;
        }
        String altTargetCat = model.getAttribute("alternateBinaryTargetCategory");
        if (altTargetCat != null && altTargetCat.length() > 0) {
            int altTargetInd = this.m_miningSchema.getFieldsAsInstances().classAttribute().indexOfValue(altTargetCat);
            if (altTargetInd < 0) {
                throw new Exception("[SupportVectorMachineModel] can't find alternate target value " + altTargetCat);
            }
            this.m_alternateBinaryTargetCategory = altTargetInd;
        }
        String thresholdS = model.getAttribute("threshold");
        if (thresholdS != null && thresholdS.length() > 0) {
            this.m_threshold = Double.parseDouble(thresholdS);
        }
        // PMML 4.x documents default to one-against-all; an explicit
        // "OneAgainstOne" attribute overrides whatever default applies.
        if (this.getPMMLVersion().startsWith("4.")) {
            this.m_classificationMethod = classificationMethod.ONE_AGAINST_ALL;
        }
        if ("OneAgainstOne".equals(model.getAttribute("classificationMethod"))) {
            this.m_classificationMethod = classificationMethod.ONE_AGAINST_ONE;
        }

        if (this.m_svmRepresentation == SVM_representation.SUPPORT_VECTORS) {
            this.m_vectorDictionary = VectorDictionary.getVectorDictionary(model, miningSchema);
        }
        this.m_kernel = Kernel.getKernel(model, this.m_log);
        if (this.m_svmRepresentation == SVM_representation.COEFFICIENTS && !(this.m_kernel instanceof LinearKernel)) {
            throw new Exception("[SupportVectorMachineModel] representation is coefficients, but kernel is not linear!");
        }

        NodeList machineL = model.getElementsByTagName("SupportVectorMachine");
        if (machineL.getLength() == 0) {
            throw new Exception("[SupportVectorMachineModel] No binary SVMs defined in model file!");
        }
        for (int i = 0; i < machineL.getLength(); ++i) {
            Element machine = (Element) machineL.item(i);
            this.m_machines.add(new SupportVectorMachine(machine, this.m_miningSchema, this.m_vectorDictionary,
                this.m_svmRepresentation, this.m_alternateBinaryTargetCategory, this.m_log));
        }
    }

    /**
     * Computes the prediction for an instance.
     * <p>
     * For a numeric class the returned array holds the single regression value. For
     * a nominal class it holds one entry per class value: a one-hot vector for
     * one-against-all, normalized vote counts for one-against-one (also used when a
     * one-against-all model contains a single binary machine), and normalized
     * per-machine outputs otherwise. If a non-class input value is missing, the
     * target meta data's default value / prior probabilities are returned when
     * available; otherwise NaN (numeric class) or all zeros (other class types).
     *
     * @param inst the instance to score
     * @return the prediction array described above
     * @throws Exception if the mapping to the mining schema or a machine evaluation fails
     */
    @Override
    public double[] distributionForInstance(Instance inst) throws Exception {
        if (!this.m_initialized) {
            this.mapToMiningSchema(inst.dataset());
        }
        Instances miningSchemaI = this.m_miningSchema.getFieldsAsInstances();
        Attribute classAtt = miningSchemaI.classAttribute();
        int numOutputs = classAtt.isNumeric() ? 1 : classAtt.numValues();

        double[] incoming = this.m_fieldsMap.instanceToSchema(inst, this.m_miningSchema);
        if (hasMissingInput(incoming, miningSchemaI.classIndex())) {
            return this.predictForMissingInput(classAtt, numOutputs);
        }

        classificationMethod method = this.m_classificationMethod;
        if (classAtt.isNominal() && method == classificationMethod.ONE_AGAINST_ALL && this.m_machines.size() == 1) {
            // A single binary machine has nothing to be ranked against, so decide
            // between its target and alternate category using the threshold.
            method = classificationMethod.ONE_AGAINST_ONE;
        }

        double[] preds = new double[numOutputs];
        if (!classAtt.isNumeric() && method != classificationMethod.ONE_AGAINST_ONE) {
            // Mark every class as "not computed yet"; one-against-one accumulates
            // votes instead and therefore has to start counting from zero.
            for (int i = 0; i < preds.length; ++i) {
                preds[i] = UNSET;
            }
        }

        for (SupportVectorMachine m : this.m_machines) {
            m.distributionForInstance(incoming, this.m_kernel, this.m_vectorDictionary, preds, method, this.m_threshold);
        }

        if (classAtt.isNominal() && method == classificationMethod.ONE_AGAINST_ALL) {
            // The class whose machine produced the smallest value wins.
            int minI = Utils.minIndex(preds);
            preds = new double[preds.length];
            preds[minI] = 1.0;
        } else if (method == classificationMethod.NONE && this.m_machines.size() == preds.length - 1) {
            fillInMissingClass(preds);
        }

        if (preds.length > 1) {
            Utils.normalize(preds);
        }
        return preds;
    }

    /**
     * Returns true if any non-class entry of {@code values} is NaN (missing).
     *
     * @param values the incoming values mapped to the mining schema
     * @param classIndex index of the class attribute, which is ignored
     */
    private static boolean hasMissingInput(double[] values, int classIndex) {
        for (int i = 0; i < values.length; ++i) {
            if (i != classIndex && Double.isNaN(values[i])) {
                return true;
            }
        }
        return false;
    }

    /**
     * Builds the fallback prediction used when an input value is missing: the target
     * meta data's default value / prior probabilities if present, otherwise a warning
     * is logged and NaN (numeric class) or all zeros (other class types) is returned.
     *
     * @param classAtt the class attribute of the mining schema
     * @param numOutputs length of the prediction array
     * @throws Exception if the target meta data cannot supply a value
     */
    private double[] predictForMissingInput(Attribute classAtt, int numOutputs) throws Exception {
        double[] preds = new double[numOutputs];
        if (!this.m_miningSchema.hasTargetMetaData()) {
            String message = "[SupportVectorMachineModel] WARNING: Instance to predict has missing value(s) but there is no missing value handling meta data and no prior probabilities/default value to fall back to. No prediction will be made (" + (classAtt.isNominal() || classAtt.isString() ? "zero probabilities output)." : "NaN output).");
            if (this.m_log == null) {
                System.err.println(message);
            } else {
                this.m_log.logMessage(message);
            }
            if (classAtt.isNumeric()) {
                preds[0] = Utils.missingValue();
            }
            return preds;
        }
        TargetMetaInfo targetData = this.m_miningSchema.getTargetMetaData();
        if (classAtt.isNumeric()) {
            preds[0] = targetData.getDefaultValue();
        } else {
            for (int i = 0; i < classAtt.numValues(); ++i) {
                preds[i] = targetData.getPriorProbability(classAtt.value(i));
            }
        }
        return preds;
    }

    /**
     * Gives the probability mass not claimed by the other classes to the class that
     * no machine wrote to (used when there is one machine fewer than class values
     * and no multi-class classification method).
     *
     * @param preds per-class outputs, with {@link #UNSET} marking the missing class
     * @throws Exception if the known outputs already sum to more than 1
     */
    private static void fillInMissingClass(double[] preds) throws Exception {
        double total = 0.0;
        int unset = -1;
        for (int i = 0; i < preds.length; ++i) {
            if (preds[i] != UNSET) {
                total += preds[i];
            } else {
                unset = i;
            }
        }
        if (total > 1.0) {
            throw new Exception("[SupportVectorMachineModel] total of probabilities is greater than 1!");
        }
        // unset stays -1 if several machines targeted the same class.
        if (unset >= 0) {
            preds[unset] = 1.0 - total;
        }
    }

    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 8034 $");
    }

    /**
     * Describes the PMML version, creator application, mining schema, kernel,
     * multi-class method and every binary machine.
     */
    @Override
    public String toString() {
        StringBuilder temp = new StringBuilder();
        temp.append("PMML version " + this.getPMMLVersion());
        if (!this.getCreatorApplication().equals("?")) {
            temp.append("\nApplication: " + this.getCreatorApplication());
        }
        temp.append("\nPMML Model: Support Vector Machine Model");
        temp.append("\n\n");
        temp.append(this.m_miningSchema);
        temp.append("Kernel: \n\t");
        temp.append(this.m_kernel);
        temp.append("\n");
        if (this.m_classificationMethod != classificationMethod.NONE) {
            temp.append("Multi-class classification using ");
            temp.append(this.m_classificationMethod == classificationMethod.ONE_AGAINST_ALL
                ? "one-against-all" : "one-against-one");
            temp.append("\n\n");
        }
        for (SupportVectorMachine v : this.m_machines) {
            temp.append("\n" + v);
        }
        return temp.toString();
    }

    /** Strategy for combining several binary machines for a nominal class. */
    static enum classificationMethod {
        NONE,
        ONE_AGAINST_ALL,
        ONE_AGAINST_ONE;

    }

    /** How a binary machine is expressed in the PMML document. */
    static enum SVM_representation {
        SUPPORT_VECTORS,
        COEFFICIENTS;

    }

    /**
     * One binary SVM: a weighted sum of kernel evaluations against support vectors
     * (or a weighted sum of the raw inputs for the coefficient representation) plus
     * an intercept.
     */
    static class SupportVectorMachine
    implements Serializable {
        private static final long serialVersionUID = -7650496802836815608L;
        /** Value of the {@code targetCategory} attribute, or null if absent. */
        protected String m_targetCategory;
        /** Model-level alternate category index, used when no local one is given (-1 = none). */
        protected int m_globalAlternateTargetCategoryIndex = -1;
        /** Class index this machine's output is written to / votes for. */
        protected int m_targetCategoryIndex = -1;
        /** Index of this machine's {@code alternateTargetCategory}, or -1 if absent. */
        protected int m_localAlternateTargetCategoryIndex = -1;
        /** This machine's {@code threshold} attribute; MAX_VALUE means "use the model threshold". */
        protected double m_localThreshold = Double.MAX_VALUE;
        protected MiningSchema m_miningSchema;
        protected Logger m_log;
        /** True for the coefficient representation (no support vectors). */
        protected boolean m_coeffsOnly = false;
        protected List<VectorInstance> m_supportVectors = new ArrayList<VectorInstance>();
        protected double m_intercept = 0.0;
        protected double[] m_coefficients;

        /**
         * Evaluates this machine and writes its contribution into {@code preds}.
         * <p>
         * With no multi-class method (or a numeric class) the entry for the target
         * category receives 1/0 (nominal: output below 0 or not) or the raw output
         * (numeric). With one-against-all it receives the raw output. With
         * one-against-one the target category gets a vote if the output is below the
         * threshold, otherwise the alternate category gets the vote.
         *
         * @param input incoming values mapped to the mining schema
         * @param kernel kernel used with the support vectors
         * @param vecDict dictionary that maps the input to vector fields (unused for coefficients)
         * @param preds shared prediction array that is updated in place
         * @param cMethod multi-class combination method
         * @param globalThreshold model-level threshold used when no local one is set
         * @throws Exception if evaluation fails or no alternate category is available to vote for
         */
        public void distributionForInstance(double[] input, Kernel kernel, VectorDictionary vecDict, double[] preds, classificationMethod cMethod, double globalThreshold) throws Exception {
            Attribute classAtt = this.m_miningSchema.getFieldsAsInstances().classAttribute();
            double[] x = this.m_coeffsOnly ? input : vecDict.incomingInstanceToVectorFieldVals(input);
            int targetIndex = classAtt.isNominal() ? this.m_targetCategoryIndex : 0;

            double result = 0.0;
            for (int i = 0; i < this.m_coefficients.length; ++i) {
                double val = this.m_coeffsOnly ? x[i] : kernel.evaluate(this.m_supportVectors.get(i), x);
                result += val * this.m_coefficients[i];
            }
            result += this.m_intercept;

            if (cMethod == classificationMethod.NONE || classAtt.isNumeric()) {
                preds[targetIndex] = classAtt.isNominal() ? (result < 0.0 ? 1.0 : 0.0) : result;
            } else if (cMethod == classificationMethod.ONE_AGAINST_ALL) {
                preds[targetIndex] = result;
            } else {
                double threshold = this.m_localThreshold < Double.MAX_VALUE ? this.m_localThreshold : globalThreshold;
                int votedFor = result < threshold ? targetIndex : this.alternateCategoryIndex(targetIndex, preds.length);
                preds[votedFor] += 1.0;
            }
        }

        /**
         * Resolves the category that receives the vote when the output is not below
         * the threshold: the local alternate, else the model-level one, else (binary
         * class only) the other class value.
         */
        private int alternateCategoryIndex(int targetIndex, int numClasses) throws Exception {
            int altCat = this.m_localAlternateTargetCategoryIndex != -1
                ? this.m_localAlternateTargetCategoryIndex : this.m_globalAlternateTargetCategoryIndex;
            if (altCat < 0 && numClasses == 2) {
                altCat = 1 - targetIndex;
            }
            if (altCat < 0) {
                throw new Exception("[SupportVectorMachine] : no alternate target category to vote for!");
            }
            return altCat;
        }

        /**
         * Parses a {@code SupportVectorMachine} element.
         *
         * @param machineElement the element to parse
         * @param miningSchema the mining schema of the enclosing model
         * @param dictionary vector dictionary used to resolve support vector IDs
         * @param svmRep representation declared by the enclosing model
         * @param altCategoryInd model-level alternate binary target index, or -1
         * @param log optional logger (may be null)
         * @throws Exception if a category or vector ID cannot be resolved, the
         *         coefficients are missing or malformed, or the number of coefficients
         *         does not match the number of support vectors
         */
        public SupportVectorMachine(Element machineElement, MiningSchema miningSchema, VectorDictionary dictionary, SVM_representation svmRep, int altCategoryInd, Logger log) throws Exception {
            this.m_miningSchema = miningSchema;
            this.m_log = log;
            Attribute classAtt = this.m_miningSchema.getFieldsAsInstances().classAttribute();

            String targetCat = machineElement.getAttribute("targetCategory");
            if (targetCat != null && targetCat.length() > 0) {
                this.m_targetCategory = targetCat;
                if (!classAtt.isNominal()) {
                    throw new Exception("[SupportVectorMachine] : target category supplied but class attribute is numeric!");
                }
                int index = classAtt.indexOfValue(this.m_targetCategory);
                if (index < 0) {
                    throw new Exception("[SupportVectorMachine] : can't find target category: " + this.m_targetCategory + " in the class attribute!");
                }
                this.m_targetCategoryIndex = index;
                String altTargetCat = machineElement.getAttribute("alternateTargetCategory");
                if (altTargetCat != null && altTargetCat.length() > 0) {
                    index = classAtt.indexOfValue(altTargetCat);
                    if (index < 0) {
                        throw new Exception("[SupportVectorMachine] : can't find alternate target category: " + altTargetCat + " in the class attribute!");
                    }
                    this.m_localAlternateTargetCategoryIndex = index;
                } else {
                    this.m_globalAlternateTargetCategoryIndex = altCategoryInd;
                }
            } else if (classAtt.isNominal()) {
                // No explicit target: use class value 0, or 1 if 0 is the model-level alternate.
                this.m_targetCategoryIndex = altCategoryInd == 0 ? 1 : 0;
                this.m_globalAlternateTargetCategoryIndex = altCategoryInd;
            }

            // Optional per-machine threshold; otherwise the model-level one applies.
            String localThreshold = machineElement.getAttribute("threshold");
            if (localThreshold != null && localThreshold.length() > 0) {
                this.m_localThreshold = Double.parseDouble(localThreshold);
            }

            if (svmRep == SVM_representation.SUPPORT_VECTORS) {
                NodeList vectorsL = machineElement.getElementsByTagName("SupportVectors");
                if (vectorsL.getLength() > 0) {
                    Element vectors = (Element) vectorsL.item(0);
                    NodeList allTheVectorsL = vectors.getElementsByTagName("SupportVector");
                    for (int i = 0; i < allTheVectorsL.getLength(); ++i) {
                        String vecId = ((Element) allTheVectorsL.item(i)).getAttribute("vectorId");
                        VectorInstance suppV = dictionary.getVector(vecId);
                        if (suppV == null) {
                            throw new Exception("[SupportVectorMachine] : can't find vector with ID: " + vecId + " in the vector dictionary!");
                        }
                        this.m_supportVectors.add(suppV);
                    }
                }
            } else {
                this.m_coeffsOnly = true;
            }

            NodeList coefficientsL = machineElement.getElementsByTagName("Coefficients");
            if (coefficientsL.getLength() != 1) {
                throw new Exception("[SupportVectorMachine] Should be just one list of coefficients per binary SVM!");
            }
            Element cL = (Element) coefficientsL.item(0);
            String intercept = cL.getAttribute("absoluteValue");
            if (intercept != null && intercept.length() > 0) {
                this.m_intercept = Double.parseDouble(intercept);
            }
            NodeList coeffL = cL.getElementsByTagName("Coefficient");
            if (coeffL.getLength() == 0) {
                throw new Exception("[SupportVectorMachine] No coefficients defined!");
            }
            this.m_coefficients = new double[coeffL.getLength()];
            for (int i = 0; i < coeffL.getLength(); ++i) {
                Element coeff = (Element) coeffL.item(i);
                this.m_coefficients[i] = Double.parseDouble(coeff.getAttribute("value"));
            }

            // Evaluation pairs coefficient i with support vector i, so the counts must agree.
            if (!this.m_coeffsOnly && this.m_coefficients.length != this.m_supportVectors.size()) {
                throw new Exception("[SupportVectorMachine] number of coefficients (" + this.m_coefficients.length
                    + ") does not match number of support vectors (" + this.m_supportVectors.size() + ")!");
            }
        }

        /** Describes the target/alternate categories, the weighted terms and the intercept. */
        @Override
        public String toString() {
            StringBuilder temp = new StringBuilder();
            temp.append("Binary SVM");
            Attribute classAtt = this.m_miningSchema.getFieldsAsInstances().classAttribute();
            if (classAtt.isNominal()) {
                String target = this.m_targetCategory != null ? this.m_targetCategory
                    : (this.m_targetCategoryIndex >= 0 ? classAtt.value(this.m_targetCategoryIndex) : null);
                temp.append(" (target category = " + target + ")");
                if (this.m_localAlternateTargetCategoryIndex != -1) {
                    temp.append("\n (alternate category = " + classAtt.value(this.m_localAlternateTargetCategoryIndex) + ")");
                }
            }
            temp.append("\n\n");
            if (this.m_coeffsOnly) {
                for (int i = 0; i < this.m_coefficients.length; ++i) {
                    temp.append("\n" + this.m_coefficients[i] + " * X[" + i + "]");
                }
            } else {
                for (int i = 0; i < this.m_supportVectors.size(); ++i) {
                    temp.append("\n" + this.m_coefficients[i] + " * [" + this.m_supportVectors.get(i).getValues() + " * X]");
                }
            }
            if (this.m_intercept >= 0.0) {
                temp.append("\n +" + this.m_intercept);
            } else {
                temp.append("\n " + this.m_intercept);
            }
            return temp.toString();
        }
    }

    /** Sigmoid kernel: K(x,y) = tanh(gamma * &lt;x,y&gt; + coef0). */
    static class SigmoidKernel
    extends Kernel
    implements Serializable {
        private static final long serialVersionUID = 8713475894705750117L;
        protected double m_gamma = 1.0;
        protected double m_coef0 = 1.0;

        public SigmoidKernel(Element sigElement) {
            this(sigElement, null);
        }

        /**
         * Reads {@code gamma} and {@code coef0}; unparsable values keep the default
         * of 1 and produce a warning.
         */
        public SigmoidKernel(Element sigElement, Logger log) {
            super(log);
            this.m_gamma = this.parseDoubleAttribute(sigElement, "gamma", this.m_gamma,
                "[SigmoidKernel] : WARNING, can't parse gamma attribute. Using default value of 1.");
            this.m_coef0 = this.parseDoubleAttribute(sigElement, "coef0", this.m_coef0,
                "[SigmoidKernel] : WARNING, can't parse coef0 attribute. Using default value of 1.");
        }

        @Override
        public double evaluate(VectorInstance x, VectorInstance y) throws Exception {
            // Math.tanh stays finite where (e^z - e^-z)/(e^z + e^-z) overflows to NaN.
            return Math.tanh(this.m_gamma * x.dotProduct(y) + this.m_coef0);
        }

        @Override
        public double evaluate(VectorInstance x, double[] y) throws Exception {
            return Math.tanh(this.m_gamma * x.dotProduct(y) + this.m_coef0);
        }

        @Override
        public String toString() {
            return "Sigmoid kernel: K(x,y) = tanh(" + this.m_gamma + " * <x,y> + " + this.m_coef0 + ")";
        }
    }

    /** Radial basis kernel: K(x,y) = exp(-gamma * ||x - y||^2). */
    static class RadialBasisKernel
    extends Kernel
    implements Serializable {
        private static final long serialVersionUID = -3834238621822239042L;
        protected double m_gamma = 1.0;

        public RadialBasisKernel(Element radialElement) {
            this(radialElement, null);
        }

        /** Reads {@code gamma}; an unparsable value keeps the default of 1 and produces a warning. */
        public RadialBasisKernel(Element radialElement, Logger log) {
            super(log);
            this.m_gamma = this.parseDoubleAttribute(radialElement, "gamma", this.m_gamma,
                "[RadialBasisKernel] : WARNING, can't parse gamma attribute. Using default value of 1.");
        }

        @Override
        public double evaluate(VectorInstance x, VectorInstance y) throws Exception {
            VectorInstance diff = x.subtract(y);
            double result = -this.m_gamma * diff.dotProduct(diff);
            return Math.exp(result);
        }

        @Override
        public double evaluate(VectorInstance x, double[] y) throws Exception {
            VectorInstance diff = x.subtract(y);
            double result = -this.m_gamma * diff.dotProduct(diff);
            return Math.exp(result);
        }

        @Override
        public String toString() {
            return "Radial kernel: K(x,y) = exp(-" + this.m_gamma + " * ||x - y||^2)";
        }
    }

    /** Polynomial kernel: K(x,y) = (gamma * &lt;x,y&gt; + coef0)^degree. */
    static class PolynomialKernel
    extends Kernel
    implements Serializable {
        private static final long serialVersionUID = -616176630397865281L;
        protected double m_gamma = 1.0;
        protected double m_coef0 = 1.0;
        protected double m_degree = 1.0;

        public PolynomialKernel(Element polyNode) {
            this(polyNode, null);
        }

        /**
         * Reads {@code gamma}, {@code coef0} and {@code degree}; unparsable values
         * keep the default of 1 and produce a warning.
         */
        public PolynomialKernel(Element polyNode, Logger log) {
            super(log);
            this.m_gamma = this.parseDoubleAttribute(polyNode, "gamma", this.m_gamma,
                "[PolynomialKernel] : WARNING, can't parse gamma attribute. Using default value of 1.");
            this.m_coef0 = this.parseDoubleAttribute(polyNode, "coef0", this.m_coef0,
                "[PolynomialKernel] : WARNING, can't parse coef0 attribute. Using default value of 1.");
            this.m_degree = this.parseDoubleAttribute(polyNode, "degree", this.m_degree,
                "[PolynomialKernel] : WARNING, can't parse degree attribute. Using default value of 1.");
        }

        @Override
        public double evaluate(VectorInstance x, VectorInstance y) throws Exception {
            double dotProd = x.dotProduct(y);
            return Math.pow(this.m_gamma * dotProd + this.m_coef0, this.m_degree);
        }

        @Override
        public double evaluate(VectorInstance x, double[] y) throws Exception {
            double dotProd = x.dotProduct(y);
            return Math.pow(this.m_gamma * dotProd + this.m_coef0, this.m_degree);
        }

        @Override
        public String toString() {
            return "Polynomial kernel: K(x,y) = (" + this.m_gamma + " * <x,y> + " + this.m_coef0 + ")^" + this.m_degree;
        }
    }

    /** Linear kernel: K(x,y) = &lt;x,y&gt;. */
    static class LinearKernel
    extends Kernel
    implements Serializable {
        private static final long serialVersionUID = 8991716708484953837L;

        public LinearKernel(Logger log) {
            super(log);
        }

        public LinearKernel() {
            super(null);
        }

        @Override
        public double evaluate(VectorInstance x, VectorInstance y) throws Exception {
            return x.dotProduct(y);
        }

        @Override
        public double evaluate(VectorInstance x, double[] y) throws Exception {
            return x.dotProduct(y);
        }

        @Override
        public String toString() {
            return "Linear kernel: K(x,y) = <x,y>";
        }
    }

    /** Base class of the PMML kernel types; also provides the factory {@link #getKernel}. */
    static abstract class Kernel
    implements Serializable {
        /** Optional logger; warnings go to stderr when it is null. */
        protected Logger m_log = null;
        private static final long serialVersionUID = -6696443459968934767L;

        protected Kernel(Logger log) {
            this.m_log = log;
        }

        public abstract double evaluate(VectorInstance var1, VectorInstance var2) throws Exception;

        public abstract double evaluate(VectorInstance var1, double[] var2) throws Exception;

        /** Sends a warning to the logger, or to stderr when no logger is set. */
        protected final void logWarning(String message) {
            if (this.m_log == null) {
                System.err.println(message);
            } else {
                this.m_log.logMessage(message);
            }
        }

        /**
         * Reads a numeric attribute of {@code element}.
         *
         * @param element element holding the attribute
         * @param attName attribute name
         * @param defaultValue value returned when the attribute is absent or unparsable
         * @param warning message logged when the attribute cannot be parsed
         * @return the parsed value, or {@code defaultValue}
         */
        protected final double parseDoubleAttribute(Element element, String attName, double defaultValue, String warning) {
            String value = element.getAttribute(attName);
            if (value == null || value.length() == 0) {
                return defaultValue;
            }
            try {
                return Double.parseDouble(value);
            } catch (NumberFormatException e) {
                this.logWarning(warning);
                return defaultValue;
            }
        }

        /**
         * Creates the kernel declared under the given model element, checking for
         * linear, polynomial, radial basis and sigmoid kernel types in that order.
         *
         * @throws Exception if none of these kernel types is present
         */
        public static Kernel getKernel(Element svmMachineModelElement, Logger log) throws Exception {
            NodeList kList = svmMachineModelElement.getElementsByTagName("LinearKernelType");
            if (kList.getLength() > 0) {
                return new LinearKernel(log);
            }
            kList = svmMachineModelElement.getElementsByTagName("PolynomialKernelType");
            if (kList.getLength() > 0) {
                return new PolynomialKernel((Element) kList.item(0), log);
            }
            kList = svmMachineModelElement.getElementsByTagName("RadialBasisKernelType");
            if (kList.getLength() > 0) {
                return new RadialBasisKernel((Element) kList.item(0), log);
            }
            kList = svmMachineModelElement.getElementsByTagName("SigmoidKernelType");
            if (kList.getLength() > 0) {
                return new SigmoidKernel((Element) kList.item(0), log);
            }
            throw new Exception("[Kernel] Can't find a kernel that I recognize!");
        }
    }
}

