/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight.em;

import java.util.Arrays;
import java.util.List;
import org.linqs.psl.application.learning.weight.em.ExpectationMaximization;
import org.linqs.psl.config.Config;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.Model;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.reasoner.admm.ADMMReasoner;
import org.linqs.psl.reasoner.admm.term.ADMMTermStore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Expectation-maximization weight learner that pairs short ADMM runs with projected
 * subgradient weight updates.
 *
 * <p>During learning, the {@link ADMMReasoner}'s iteration cap is set to
 * {@code pairedduallearner.admmsteps}, and the previous cap is restored afterwards. Each learning
 * step optimizes both the latent and the full problem with that cap. It then scores every
 * mutable rule by summing {@link ADMMReasoner#getDualIncompatibility} over the rule's ground
 * rules, and folds the reasoner's (augmented) Lagrangian penalties into the objective.
 *
 * <p>Requires an {@link ADMMReasoner} and {@link ADMMTermStore}s for both the full and the latent
 * problem; {@link #doLearn()} rejects any other configuration.
 */
public class PairedDualLearner extends ExpectationMaximization {
    private static final Logger log = LoggerFactory.getLogger(PairedDualLearner.class);

    /** Prefix shared by all config keys of this learner. */
    public static final String CONFIG_PREFIX = "pairedduallearner";

    /** Config key: number of warm-up rounds (each optimizes the full and latent stores once). */
    public static final String WARMUP_ROUNDS_KEY = "pairedduallearner.warmuprounds";
    public static final int WARMUP_ROUNDS_DEFAULT = 0;

    /** Config key: ADMM iteration cap applied to the reasoner while learning. */
    public static final String ADMM_STEPS_KEY = "pairedduallearner.admmsteps";
    public static final int ADMM_STEPS_DEFAULT = 1;

    /** Number of warm-up rounds; validated to be {@code >= 0}. */
    private final int warmupRounds;
    /** ADMM iteration cap used during learning; validated to be {@code >= 1}. */
    private final int admmIterations;

    /**
     * Creates a learner over all rules of {@code model}.
     *
     * @param model model whose rules are learned
     * @param rvDB database holding the random variables
     * @param observedDB database holding the observed (training) values
     * @throws IllegalArgumentException if a config value is out of range
     */
    public PairedDualLearner(Model model, Database rvDB, Database observedDB) {
        this(model.getRules(), rvDB, observedDB);
    }

    /**
     * Creates a learner over the given rules and reads this learner's config values.
     *
     * @param rules rules to learn
     * @param rvDB database holding the random variables
     * @param observedDB database holding the observed (training) values
     * @throws IllegalArgumentException if {@code pairedduallearner.warmuprounds} is negative or
     *     {@code pairedduallearner.admmsteps} is less than one
     */
    public PairedDualLearner(List<Rule> rules, Database rvDB, Database observedDB) {
        super(rules, rvDB, observedDB);

        warmupRounds = Config.getInt(WARMUP_ROUNDS_KEY, WARMUP_ROUNDS_DEFAULT);
        if (warmupRounds < 0) {
            throw new IllegalArgumentException(WARMUP_ROUNDS_KEY + " must be a nonnegative integer.");
        }

        admmIterations = Config.getInt(ADMM_STEPS_KEY, ADMM_STEPS_DEFAULT);
        if (admmIterations < 1) {
            throw new IllegalArgumentException(ADMM_STEPS_KEY + " must be a positive integer.");
        }
    }

    /**
     * Recomputes {@code expectedIncompatibility}.
     *
     * <p>Calls {@code computeMPEState()} first. Each entry {@code i} is then the sum of
     * {@link ADMMReasoner#getDualIncompatibility} over the ground rules of mutable rule {@code i}
     * in the full (non-latent) ground rule store.
     */
    @Override
    protected void computeExpectedIncompatibility() {
        computeMPEState();
        Arrays.fill(expectedIncompatibility, 0.0);

        ADMMReasoner admmReasoner = (ADMMReasoner)reasoner;
        ADMMTermStore admmTermStore = (ADMMTermStore)termStore;
        // One buffer, sized to the store's global variables, is shared by every call below
        // (presumably scratch space for the reasoner -- confirm against ADMMReasoner).
        float[] consensusBuffer = new float[admmTermStore.getNumGlobalVariables()];

        for (int i = 0; i < mutableRules.size(); i++) {
            for (GroundRule groundRule : groundRuleStore.getGroundRules((Rule)mutableRules.get(i))) {
                expectedIncompatibility[i] += admmReasoner.getDualIncompatibility(groundRule, admmTermStore, consensusBuffer);
            }
        }
    }

    /**
     * Recomputes {@code observedIncompatibility}.
     *
     * <p>Calls {@code setLabeledRandomVariables()} and {@code computeLatentMPEState()} first.
     * Each entry {@code i} is then the sum of {@link ADMMReasoner#getDualIncompatibility} over
     * the ground rules of mutable rule {@code i} in the latent ground rule store.
     */
    @Override
    protected void computeObservedIncompatibility() {
        setLabeledRandomVariables();
        computeLatentMPEState();
        Arrays.fill(observedIncompatibility, 0.0);

        ADMMReasoner admmReasoner = (ADMMReasoner)reasoner;
        ADMMTermStore admmLatentTermStore = (ADMMTermStore)latentTermStore;
        // Shared buffer sized to the latent store's global variables (see computeExpectedIncompatibility).
        float[] consensusBuffer = new float[admmLatentTermStore.getNumGlobalVariables()];

        for (int i = 0; i < mutableRules.size(); i++) {
            for (GroundRule groundRule : latentGroundRuleStore.getGroundRules((Rule)mutableRules.get(i))) {
                observedIncompatibility[i] += admmReasoner.getDualIncompatibility(groundRule, admmLatentTermStore, consensusBuffer);
            }
        }
    }

    /**
     * Validates the reasoner/term-store setup, runs the optional warm-up, then learns weights.
     *
     * <p>The reasoner's iteration cap is set to {@code admmIterations} for the duration of
     * learning. It is always restored to its previous value, even if learning throws.
     *
     * @throws IllegalArgumentException if the reasoner is not an {@link ADMMReasoner} or either
     *     term store is not an {@link ADMMTermStore}
     */
    @Override
    protected void doLearn() {
        if (!(reasoner instanceof ADMMReasoner)) {
            throw new IllegalArgumentException(String.format("PairedDualLearning can only be used with ADMMReasoner, found %s.", reasoner.getClass().getName()));
        }
        if (!(termStore instanceof ADMMTermStore)) {
            throw new IllegalArgumentException(String.format("PairedDualLearning can only be used with ADMMTermStore, found %s.", termStore.getClass().getName()));
        }
        if (!(latentTermStore instanceof ADMMTermStore)) {
            throw new IllegalArgumentException(String.format("PairedDualLearning (latent) can only be used with ADMMTermStore, found %s.", latentTermStore.getClass().getName()));
        }

        ADMMReasoner admmReasoner = (ADMMReasoner)reasoner;
        int oldMaxIter = admmReasoner.getMaxIter();
        admmReasoner.setMaxIter(admmIterations);
        try {
            if (warmupRounds > 0) {
                log.debug("Warming up optimizer with {} iterations.", warmupRounds * admmIterations);
                for (int round = 0; round < warmupRounds; round++) {
                    reasoner.optimize(termStore);
                    reasoner.optimize(latentTermStore);
                }
            }

            subgrad();
        } finally {
            // The reasoner is shared state; never leave it with the learning-time iteration cap.
            admmReasoner.setMaxIter(oldMaxIter);
        }
    }

    /**
     * Runs projected subgradient descent on the rule weights and writes the result to the rules.
     *
     * <p>Stops after {@code iterations} steps, or earlier once the L2 norm of a step's weight
     * change falls below {@code tolerance}. If {@code averageSteps} is set, the running average
     * of the iterates is written instead of the last iterate. Both MPE-state flags are cleared at
     * the end because the rule weights have changed.
     */
    private void subgrad() {
        log.info("Starting optimization");

        int numRules = mutableRules.size();
        double[] weights = new double[numRules];
        for (int i = 0; i < numRules; i++) {
            weights[i] = ((WeightedRule)mutableRules.get(i)).getWeight();
        }

        // Holds the gradient after getValueAndGradient() and the per-rule weight change after each
        // step. Starting at 1.0 makes the first getValueAndGradient() push every weight to its rule.
        double[] gradient = new double[numRules];
        Arrays.fill(gradient, 1.0);

        double[] avgWeights = new double[numRules];
        double stepSize = baseStepSize;
        double objective = 0.0;

        for (emIteration = 0; emIteration < iterations; emIteration++) {
            objective = getValueAndGradient(gradient, weights);

            double gradNorm = 0.0;
            double change = 0.0;
            // Weight of this iterate in the running mean avgWeights.
            double avgFactor = 1.0 / (emIteration + 1.0);

            for (int i = 0; i < numRules; i++) {
                // Distance between w and its nonnegativity-projected step max(0, w - g).
                gradNorm += Math.pow(weights[i] - Math.max(0.0, weights[i] - gradient[i]), 2.0);

                // Projected step: the weight is never moved below zero.
                double delta = Math.max(-weights[i], -stepSize * gradient[i]);
                weights[i] += delta;

                // Intentionally reuse the gradient array to record the change; getValueAndGradient()
                // skips pushing weights whose entry is exactly zero.
                gradient[i] = delta;
                change += Math.pow(delta, 2.0);

                avgWeights[i] = (1.0 - avgFactor) * avgWeights[i] + avgFactor * weights[i];
            }

            gradNorm = Math.sqrt(gradNorm);
            change = Math.sqrt(change);
            log.debug("Iter {}, obj: {}, norm grad: {}, change: {}", emIteration, objective, gradNorm, change);

            if (change < tolerance) {
                log.info("Change in w ({}) is less than tolerance. Finishing subgrad.", change);
                break;
            }
        }

        log.info("Learning finished with final objective value {}", objective);

        for (int i = 0; i < numRules; i++) {
            double finalWeight = averageSteps ? avgWeights[i] : weights[i];
            ((WeightedRule)mutableRules.get(i)).setWeight(finalWeight);
        }

        inMPEState = false;
        inLatentMPEState = false;
    }

    /**
     * Pushes {@code weights} to the rules and evaluates the objective. Optionally writes the
     * gradient.
     *
     * @param gradient on input, a per-rule change vector: rules whose entry is exactly
     *     {@code 0.0} are not re-weighted. On output, overwritten with the (optionally scaled and
     *     regularized) gradient. May be {@code null`, in which case every weight is pushed and no
     *     gradient is written.
     * @param weights current weight of each mutable rule
     * @return sum over rules of {@code weight * (observed - expected incompatibility)}, plus the
     *     E-step minus M-step (augmented) Lagrangian penalties, plus the regularizer
     */
    private double getValueAndGradient(double[] gradient, double[] weights) {
        int numRules = mutableRules.size();

        for (int i = 0; i < numRules; i++) {
            if (gradient != null && gradient[i] == 0.0) {
                continue;
            }
            ((WeightedRule)mutableRules.get(i)).setWeight(weights[i]);
        }
        inMPEState = false;
        inLatentMPEState = false;

        ADMMReasoner admmReasoner = (ADMMReasoner)reasoner;

        // The penalties are read right after each solve so that they refer to that solve.
        computeObservedIncompatibility();
        double eStepLagrangianPenalty = admmReasoner.getLagrangianPenalty();
        double eStepAugLagrangianPenalty = admmReasoner.getAugmentedLagrangianPenalty();

        computeExpectedIncompatibility();
        double mStepLagrangianPenalty = admmReasoner.getLagrangianPenalty();
        double mStepAugLagrangianPenalty = admmReasoner.getAugmentedLagrangianPenalty();

        double loss = 0.0;
        for (int i = 0; i < numRules; i++) {
            loss += weights[i] * (observedIncompatibility[i] - expectedIncompatibility[i]);
        }
        loss += eStepLagrangianPenalty + eStepAugLagrangianPenalty - mStepLagrangianPenalty - mStepAugLagrangianPenalty;

        if (log.isDebugEnabled()) {
            for (int i = 0; i < numRules; i++) {
                log.debug("Incompatibility for rule {}", mutableRules.get(i));
                log.debug("Truth incompatbility {}, expected incompatibility {}", observedIncompatibility[i], expectedIncompatibility[i]);
            }
        }
        log.debug("E Penalty: {}, E Aug Penalty: {}, M Penalty: {}, M Aug Penalty: {}", eStepLagrangianPenalty, eStepAugLagrangianPenalty, mStepLagrangianPenalty, mStepAugLagrangianPenalty);

        double regularizer = computeRegularizer();

        if (gradient != null) {
            for (int i = 0; i < numRules; i++) {
                gradient[i] = observedIncompatibility[i] - expectedIncompatibility[i];
                if (scaleGradient) {
                    double groundRuleCount = groundRuleStore.count((Rule)mutableRules.get(i));
                    if (groundRuleCount > 0.0) {
                        gradient[i] /= groundRuleCount;
                    }
                }
                gradient[i] += l2Regularization * weights[i] + l1Regularization;
            }
        }

        return loss + regularizer;
    }
}

