/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight;

import java.util.List;
import java.util.Map;
import org.linqs.psl.application.learning.weight.WeightLearningApplication;
import org.linqs.psl.config.Options;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedGroundRule;
import org.linqs.psl.model.rule.WeightedRule;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for gradient-based ("voted perceptron") weight learning.
 *
 * <p>Each iteration compares, for every mutable rule, the summed incompatibility of its ground
 * rules under the labeled assignment ({@link #observedIncompatibility}) with the sum under an
 * inferred assignment ({@link #expectedIncompatibility}; this class uses the MPE state, see
 * {@link #computeExpectedIncompatibility()}). It then moves the rule's weight by the regularized,
 * scaled difference.
 *
 * <p>Optional behavior, all read from {@link Options} in the constructor:
 * <ul>
 *   <li>momentum;</li>
 *   <li>step-size decay;</li>
 *   <li>clipping to non-negative weights;</li>
 *   <li>halving the step size (and rolling back) when the training objective increases;</li>
 *   <li>averaging the weights over all iterations.</li>
 * </ul>
 */
public abstract class VotedPerceptron extends WeightLearningApplication {
    private static final Logger log = LoggerFactory.getLogger(VotedPerceptron.class);

    /** Per-rule sum of ground-rule incompatibilities under the labeled assignment (indexed like mutableRules). */
    protected double[] observedIncompatibility;
    /** Per-rule sum of ground-rule incompatibilities under the inferred assignment (indexed like mutableRules). */
    protected double[] expectedIncompatibility;
    protected final double l2Regularization;
    protected final double l1Regularization;
    /**
     * Read from {@link Options#WLA_VP_SCALE_GRADIENT}. It is not consulted in this class:
     * {@link #computeScalingFactor()} always scales by the ground rule count.
     * NOTE(review): confirm whether subclasses rely on this flag.
     */
    protected final boolean scaleGradient;
    /** Multiplier on every step; halved each time cutObjective rejects a step. */
    protected double baseStepSize;
    /** When set, the step taken at iteration {@code t} (0-based) is divided by {@code t + 1}. */
    protected boolean scaleStepSize;
    /** When set, the final weights are the mean of the weights in effect after each iteration. */
    protected boolean averageSteps;
    protected boolean zeroInitialWeights;
    /** When set, weights are clamped at 0; momentum is not applied in this mode. */
    protected boolean clipNegativeWeights;
    /** When set (and an evaluator exists), a step that increases the training objective is undone. */
    protected boolean cutObjective;
    /** Momentum coefficient applied to the previous step (only when not clipping). */
    protected double inertia;
    /** Configured number of iterations. */
    protected final int maxNumSteps;
    /** Number of iterations actually run; rescaled by {@link #setBudget(double)}. */
    protected int numSteps;
    /** Cached loss for the current iteration; NaN means "not computed yet". */
    private double currentLoss;

    public VotedPerceptron(List<Rule> rules, Database rvDB, Database observedDB) {
        super(rules, rvDB, observedDB);

        int numRules = this.mutableRules.size();
        this.observedIncompatibility = new double[numRules];
        this.expectedIncompatibility = new double[numRules];

        this.maxNumSteps = Options.WLA_VP_NUM_STEPS.getInt();
        this.numSteps = this.maxNumSteps;
        this.baseStepSize = Options.WLA_VP_STEP.getDouble();
        this.inertia = Options.WLA_VP_INERTIA.getDouble();
        this.l1Regularization = Options.WLA_VP_L1.getDouble();
        this.l2Regularization = Options.WLA_VP_L2.getDouble();
        this.scaleGradient = Options.WLA_VP_SCALE_GRADIENT.getBoolean();
        this.averageSteps = Options.WLA_VP_AVERAGE_STEPS.getBoolean();
        this.scaleStepSize = Options.WLA_VP_SCALE_STEP.getBoolean();
        this.zeroInitialWeights = Options.WLA_VP_ZERO_INITIAL_WEIGHTS.getBoolean();
        this.clipNegativeWeights = Options.WLA_VP_CLIP_NEGATIVE_WEIGHTS.getBoolean();
        this.cutObjective = Options.WLA_VP_CUT_OBJECTIVE.getBoolean();

        this.currentLoss = Double.NaN;
    }

    /** Warns when the training map contains latent variables, since they can distort the gradients. */
    @Override
    protected void postInitGroundModel() {
        if (this.trainingMap.getLatentVariables().size() > 0) {
            log.warn("Latent variable(s) found when using a VotedPerceptron-based weight learning method ({}). VotedPerceptron uses gradients to update weights, but latent variables may make the gradients less accurate. Weight learning may still perform sufficiently. Found {} latent variables. Example latent variable: [{}].", this.getClass().getName(), this.trainingMap.getLatentVariables().size(), this.trainingMap.getLatentVariables().get(0));
        }
    }

    /**
     * Runs {@link #numSteps} gradient iterations over the mutable rule weights.
     *
     * <p>If cutObjective is set and an iteration makes the training objective worse, that iteration's
     * weights are rolled back and {@link #baseStepSize} is halved. If averageSteps is set, the final
     * weights are the mean of the weights in effect at the end of every iteration.
     */
    @Override
    protected void doLearn() {
        int numRules = this.mutableRules.size();

        // Running sum of the weight in effect after each iteration (used by averageSteps).
        double[] avgWeights = new double[numRules];

        this.computeObservedIncompatibility();
        this.setDefaultRandomVariables();

        if (this.zeroInitialWeights) {
            for (WeightedRule rule : this.mutableRules) {
                rule.setWeight(0.0);
            }
        }

        if (log.isDebugEnabled() && this.evaluator != null) {
            this.computeMPEState();
            this.evaluator.compute(this.trainingMap);
            double objective = -1.0 * this.evaluator.getNormalizedRepMetric();
            log.debug("Initial Training Objective: {}", (Object)objective);
        }

        double[] scalingFactor = this.computeScalingFactor();

        // Previous step per rule, used for momentum.
        double[] lastSteps = new double[numRules];
        double lastObjective = -1.0;

        // Weights accepted at the end of the previous iteration, used for rollback.
        double[] lastWeights = new double[numRules];
        this.copyWeightsInto(lastWeights);

        for (int step = 0; step < this.numSteps; ++step) {
            log.debug("Starting iteration {}", (Object)step);
            this.currentLoss = Double.NaN;

            this.computeExpectedIncompatibility();
            double norm = this.takeGradientStep(step, scalingFactor, lastSteps, avgWeights);

            // The weights just changed, so any cached MPE state is stale.
            this.inMPEState = false;

            if (log.isDebugEnabled()) {
                // Populates currentLoss for the debug line below.
                this.getLoss();
            }

            double objective = -1.0;
            if ((this.cutObjective || log.isDebugEnabled()) && this.evaluator != null) {
                this.computeMPEState();
                this.evaluator.compute(this.trainingMap);
                objective = -1.0 * this.evaluator.getNormalizedRepMetric();

                if (this.cutObjective && step > 0 && objective > lastObjective) {
                    log.trace("Objective increased: {} -> {}, cutting step size: {} -> {}.", lastObjective, objective, this.baseStepSize, this.baseStepSize / 2.0);
                    this.baseStepSize /= 2.0;
                    objective = lastObjective;
                    this.rollBackStep(lastWeights, lastSteps, avgWeights);
                } else {
                    lastObjective = objective;
                }
            }

            this.copyWeightsInto(lastWeights);

            log.debug("Iteration {} complete. Likelihood: {}. Training Objective: {}, Icomp. L2-norm: {}", step, this.currentLoss, objective, norm);
            log.trace("Model {} ", (Object)this.mutableRules);
        }

        // With a zero step budget nothing was accumulated; dividing would set every weight to NaN.
        if (this.averageSteps && this.numSteps > 0) {
            for (int i = 0; i < numRules; ++i) {
                this.mutableRules.get(i).setWeight(avgWeights[i] / (double)this.numSteps);
            }
        }
    }

    /**
     * Applies one gradient step to every mutable rule weight.
     *
     * <p>Records the applied step in {@code lastSteps} and adds the new weight to {@code avgWeights}.
     *
     * @return the L2 norm of (expected - observed) incompatibility across all rules
     */
    private double takeGradientStep(int step, double[] scalingFactor, double[] lastSteps, double[] avgWeights) {
        double squaredNorm = 0.0;

        for (int i = 0; i < this.mutableRules.size(); ++i) {
            WeightedRule rule = this.mutableRules.get(i);
            double weight = rule.getWeight();
            double incompatibilityDiff = this.expectedIncompatibility[i] - this.observedIncompatibility[i];

            double rawStep = (incompatibilityDiff - this.l2Regularization * weight - this.l1Regularization) / scalingFactor[i];
            rawStep *= this.baseStepSize;
            if (this.scaleStepSize) {
                rawStep /= (double)(step + 1);
            }

            double currentStep;
            if (this.clipNegativeWeights) {
                // Momentum is deliberately not applied when clipping (unchanged behavior).
                currentStep = rawStep;
                weight = Math.max(0.0, weight + currentStep);
            } else {
                currentStep = rawStep + this.inertia * lastSteps[i];
                weight += currentStep;
            }

            log.trace("Gradient: {} (without momentun: {}), Expected Incomp.: {}, Observed Incomp.: {} -- ({}) {}", currentStep, rawStep, this.expectedIncompatibility[i], this.observedIncompatibility[i], i, rule);

            rule.setWeight(weight);
            lastSteps[i] = currentStep;
            avgWeights[i] += weight;
            squaredNorm += incompatibilityDiff * incompatibilityDiff;
        }

        return Math.sqrt(squaredNorm);
    }

    /**
     * Undoes the current iteration's step by restoring {@code lastWeights}.
     *
     * <p>Also clears momentum and swaps the rejected weight for the restored one in the running
     * average, so every iteration still contributes exactly one weight to the mean.
     */
    private void rollBackStep(double[] lastWeights, double[] lastSteps, double[] avgWeights) {
        for (int i = 0; i < this.mutableRules.size(); ++i) {
            WeightedRule rule = this.mutableRules.get(i);
            lastSteps[i] = 0.0;
            avgWeights[i] += lastWeights[i] - rule.getWeight();
            rule.setWeight(lastWeights[i]);
        }

        // The MPE state just computed belongs to the rejected weights; force recomputation.
        this.inMPEState = false;
    }

    /** Copies the current weight of every mutable rule into {@code target} (same indexing). */
    private void copyWeightsInto(double[] target) {
        for (int i = 0; i < this.mutableRules.size(); ++i) {
            target[i] = this.mutableRules.get(i).getWeight();
        }
    }

    /** Returns sum over rules of weight * (observed - expected) incompatibility. */
    protected double computeLoss() {
        double loss = 0.0;
        for (int i = 0; i < this.mutableRules.size(); ++i) {
            loss += this.mutableRules.get(i).getWeight() * (this.observedIncompatibility[i] - this.expectedIncompatibility[i]);
        }
        return loss;
    }

    /** Returns 0.5 * l2 * sum(w^2) + l1 * sum(|w|) over the mutable rule weights (0 when both are 0). */
    protected double computeRegularizer() {
        if (this.l1Regularization == 0.0 && this.l2Regularization == 0.0) {
            return 0.0;
        }

        double l2 = 0.0;
        double l1 = 0.0;
        for (WeightedRule rule : this.mutableRules) {
            l2 += Math.pow(rule.getWeight(), 2.0);
            l1 += Math.abs(rule.getWeight());
        }
        return 0.5 * this.l2Regularization * l2 + this.l1Regularization * l1;
    }

    /** Returns the loss for the current iteration, computing and caching it on first request. */
    public double getLoss() {
        if (Double.isNaN(this.currentLoss)) {
            this.currentLoss = this.computeLoss();
        }
        return this.currentLoss;
    }

    /** Returns, per mutable rule, its ground rule count (at least 1), used to divide each step. */
    protected double[] computeScalingFactor() {
        double[] factor = new double[this.mutableRules.size()];
        for (int i = 0; i < factor.length; ++i) {
            factor[i] = Math.max(1.0, (double)this.inference.getGroundRuleStore().count((Rule)this.mutableRules.get(i)));
        }
        return factor;
    }

    /** Sets the labeled random variables to their observed values and sums incompatibilities into {@link #observedIncompatibility}. */
    protected void computeObservedIncompatibility() {
        this.setLabeledRandomVariables();
        this.sumIncompatibilities(this.observedIncompatibility);
    }

    /** Computes the MPE state and sums incompatibilities into {@link #expectedIncompatibility}. */
    protected void computeExpectedIncompatibility() {
        this.computeMPEState();
        this.sumIncompatibilities(this.expectedIncompatibility);
    }

    /**
     * Resets {@code totals} to zero, then sums into it the incompatibility of every ground rule of each
     * mutable rule under the current random-variable values.
     */
    private void sumIncompatibilities(double[] totals) {
        for (int i = 0; i < totals.length; ++i) {
            totals[i] = 0.0;
        }

        for (int i = 0; i < this.mutableRules.size(); ++i) {
            for (GroundRule groundRule : this.inference.getGroundRuleStore().getGroundRules((Rule)this.mutableRules.get(i))) {
                totals[i] += ((WeightedGroundRule)groundRule).getIncompatibility();
            }
        }
    }

    /** Assigns each labeled random variable the value of its observed counterpart. */
    protected void setLabeledRandomVariables() {
        this.inMPEState = false;
        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : this.trainingMap.getLabelMap().entrySet()) {
            entry.getKey().setValue(entry.getValue().getValue());
        }
    }

    /** Resets every labeled and latent random variable to 0. */
    protected void setDefaultRandomVariables() {
        this.inMPEState = false;
        for (RandomVariableAtom atom : this.trainingMap.getLabelMap().keySet()) {
            atom.setValue(0.0f);
        }
        for (RandomVariableAtom atom : this.trainingMap.getLatentVariables()) {
            atom.setValue(0.0f);
        }
    }

    /** Scales the number of iterations to {@code ceil(budget * maxNumSteps)}. */
    @Override
    public void setBudget(double budget) {
        super.setBudget(budget);
        this.numSteps = (int)Math.ceil(budget * (double)this.maxNumSteps);
    }
}

