/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.reasoner.admm;

import java.util.List;
import java.util.Set;
import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.config.Options;
import org.linqs.psl.evaluation.statistics.Evaluator;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.reasoner.Reasoner;
import org.linqs.psl.reasoner.admm.term.ADMMObjectiveTerm;
import org.linqs.psl.reasoner.admm.term.ADMMTermStore;
import org.linqs.psl.reasoner.admm.term.LocalVariable;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.MathUtils;
import org.linqs.psl.util.Parallel;

/**
 * Consensus ADMM reasoner over an {@link ADMMTermStore}.
 *
 * <p>Each iteration has two parallel phases. In the first ({@link TermWorker}), each term updates its
 * Lagrange multipliers and minimizes over its local variable copies. In the second
 * ({@link VariableWorker}), the local copies of each consensus variable are averaged (including
 * their scaled Lagrange terms), and the result is clamped to [0, 1] and stored as the new consensus
 * value. That phase also accumulates the primal/dual residuals used for the stopping test.
 */
public class ADMMReasoner
extends Reasoner {
    private static final Logger log = Logger.getLogger(ADMMReasoner.class);
    private static final float LOWER_BOUND = 0.0f;
    private static final float UPPER_BOUND = 1.0f;

    // Convergence bookkeeping and per-iteration logging only happen every computePeriod iterations.
    private int computePeriod;
    private final float stepSize;
    private double epsilonRel;
    private double epsilonAbs;

    // Per-iteration accumulators. Workers add into them via updateIterationVariables(); optimize()
    // turns the sums into residuals/tolerances after both parallel phases finish.
    private double primalRes;
    private double epsilonPrimal;
    private double dualRes;
    private double epsilonDual;
    private double AxNorm;
    private double AyNorm;
    private double BzNorm;
    private double lagrangePenalty;
    private double augmentedLagrangePenalty;

    private int maxIterations = Options.ADMM_MAX_ITER.getInt();
    private long termBlockSize;
    private long variableBlockSize;

    public ADMMReasoner() {
        this.stepSize = Options.ADMM_STEP_SIZE.getFloat();
        this.computePeriod = Options.ADMM_COMPUTE_PERIOD.getInt();
        this.epsilonAbs = Options.ADMM_EPSILON_ABS.getDouble();
        this.epsilonRel = Options.ADMM_EPSILON_REL.getDouble();
    }

    public double getEpsilonRel() {
        return this.epsilonRel;
    }

    public void setEpsilonRel(double epsilonRel) {
        this.epsilonRel = epsilonRel;
    }

    public double getEpsilonAbs() {
        return this.epsilonAbs;
    }

    public void setEpsilonAbs(double epsilonAbs) {
        this.epsilonAbs = epsilonAbs;
    }

    /** Returns the Lagrangian penalty accumulated during the most recent iteration. */
    public double getLagrangianPenalty() {
        return this.lagrangePenalty;
    }

    /** Returns the augmented Lagrangian penalty accumulated during the most recent iteration. */
    public double getAugmentedLagrangianPenalty() {
        return this.augmentedLagrangePenalty;
    }

    /**
     * Runs ADMM until a stopping condition holds, then syncs the consensus values back to the atoms.
     *
     * @param baseTermStore must be an {@link ADMMTermStore}
     * @return the objective value (sum over non-constraint terms) at the final consensus values
     * @throws IllegalArgumentException if {@code baseTermStore} is not an {@link ADMMTermStore}
     */
    @Override
    public double optimize(TermStore baseTermStore, List<Evaluator> evaluators, TrainingMap trainingMap, Set<StandardPredicate> evaluationPredicates) {
        if (!(baseTermStore instanceof ADMMTermStore)) {
            throw new IllegalArgumentException("ADMMReasoner requires an ADMMTermStore (found " + baseTermStore.getClass().getName() + ").");
        }
        ADMMTermStore termStore = (ADMMTermStore)baseTermStore;
        termStore.initForOptimization();

        long numTerms = termStore.size();
        int numVariables = termStore.getNumConsensusVariables();
        log.debug("Performing optimization with {} variables and {} terms.", numVariables, numTerms);

        // Split each parallel phase into about four blocks per thread.
        this.termBlockSize = numTerms / (long)(Parallel.getNumThreads() * 4) + 1L;
        this.variableBlockSize = numVariables / (Parallel.getNumThreads() * 4) + 1;
        long numTermBlocks = (long)Math.ceil((double)numTerms / (double)this.termBlockSize);
        long numVariableBlocks = (long)Math.ceil((double)numVariables / (double)this.variableBlockSize);

        double epsilonAbsTerm = Math.sqrt(termStore.getNumLocalVariables()) * this.epsilonAbs;

        ObjectiveResult objective = null;
        ObjectiveResult oldObjective = null;

        if (log.isTraceEnabled()) {
            // Deliberately NOT stored in `objective`: breakOptimization() reads `objective`,
            // so keeping it there would let the log level change when the loop stops.
            ObjectiveResult initialObjective = this.computeObjective(termStore);
            log.trace("Iteration {} -- Objective: {}, Feasible: {}.", 0, initialObjective.objective, initialObjective.violatedConstraints == 0L);
        }

        int iteration = 1;
        while (true) {
            this.resetIterationVariables();

            // Short-circuit keeps the modulo from running before the first non-convex period.
            boolean useNonConvex = iteration >= this.nonconvexPeriod && iteration % this.nonconvexPeriod < this.nonconvexRounds;

            Parallel.count(numTermBlocks, new TermWorker(termStore, this.termBlockSize, useNonConvex));
            Parallel.count(numVariableBlocks, new VariableWorker(termStore, this.variableBlockSize, useNonConvex));

            this.primalRes = Math.sqrt(this.primalRes);
            this.dualRes = (double)this.stepSize * Math.sqrt(this.dualRes);
            this.epsilonPrimal = epsilonAbsTerm + this.epsilonRel * Math.max(Math.sqrt(this.AxNorm), Math.sqrt(this.BzNorm));
            this.epsilonDual = epsilonAbsTerm + this.epsilonRel * Math.sqrt(this.AyNorm);

            if (iteration % this.computePeriod == 0) {
                if (!this.objectiveBreak) {
                    log.trace("Iteration {} -- Primal: {}, Dual: {}, Epsilon Primal: {}, Epsilon Dual: {}.", iteration, this.primalRes, this.dualRes, this.epsilonPrimal, this.epsilonDual);
                } else {
                    oldObjective = objective;
                    objective = this.computeObjective(termStore);
                    log.trace("Iteration {} -- Objective: {}, Feasible: {}, Primal: {}, Dual: {}, Epsilon Primal: {}, Epsilon Dual: {}.", iteration, objective.objective, objective.violatedConstraints == 0L, this.primalRes, this.dualRes, this.epsilonPrimal, this.epsilonDual);
                }
                this.evaluate(termStore, iteration, evaluators, trainingMap, evaluationPredicates);
                termStore.iterationComplete();
            }

            iteration++;

            // The first check is cheap (budget, residuals, objective change). Only when it says
            // "stop" do we pay for a fresh objective, so a stop is confirmed against the current
            // constraint violations.
            // NOTE(review): when objectiveBreak is false, `objective` is only refreshed here. After
            // one re-check that finds violated constraints, that stale result makes every later
            // cheap check return false until the iteration budget runs out. Confirm this is intended.
            if (this.breakOptimization(iteration, objective, oldObjective)) {
                objective = this.computeObjective(termStore);
                if (this.breakOptimization(iteration, objective, oldObjective)) {
                    break;
                }
            }
        }

        log.info("Optimization completed in {} iterations. Objective: {}, Feasible: {}, Primal res.: {}, Dual res.: {}", iteration - 1, objective.objective, objective.violatedConstraints == 0L, this.primalRes, this.dualRes);
        if (objective.violatedConstraints > 0L) {
            log.warn("No feasible solution found. {} constraints violated.", objective.violatedConstraints);
        }

        termStore.syncAtoms();
        return objective.objective;
    }

    /** Zeroes the per-iteration accumulators before the workers add into them. */
    private void resetIterationVariables() {
        this.primalRes = 0.0;
        this.dualRes = 0.0;
        this.AxNorm = 0.0;
        this.AyNorm = 0.0;
        this.BzNorm = 0.0;
        this.lagrangePenalty = 0.0;
        this.augmentedLagrangePenalty = 0.0;
    }

    /**
     * Decides whether optimization should stop after {@code iteration - 1} completed iterations.
     *
     * <p>Checks, in order:
     * <ol>
     *   <li>the iteration budget ({@code maxIterations * budget}) always stops;</li>
     *   <li>{@code runFullIterations} never stops early;</li>
     *   <li>known violated constraints prevent stopping;</li>
     *   <li>primal and dual residuals below their tolerances stop;</li>
     *   <li>otherwise, when {@code objectiveBreak} is set, the loop stops if the objective is
     *       unchanged within {@code tolerance}.</li>
     * </ol>
     *
     * @param objective the most recently computed objective, or {@code null} if none yet
     * @param oldObjective the objective computed before {@code objective}, or {@code null}
     */
    private boolean breakOptimization(int iteration, ObjectiveResult objective, ObjectiveResult oldObjective) {
        if (iteration > (int)((double)this.maxIterations * this.budget)) {
            return true;
        }
        if (this.runFullIterations) {
            return false;
        }
        if (objective != null && objective.violatedConstraints > 0L) {
            return false;
        }
        if (iteration > 1 && this.primalRes < this.epsilonPrimal && this.dualRes < this.epsilonDual) {
            return true;
        }
        return this.objectiveBreak
                && objective != null
                && oldObjective != null
                && MathUtils.equals(objective.objective, oldObjective.objective, (double)this.tolerance);
    }

    @Override
    public void close() {
    }

    /**
     * Evaluates every term at the current consensus values.
     *
     * <p>Constraint terms are not summed. They only count as violated when they evaluate to a value
     * strictly greater than zero. All other terms are summed into the objective.
     */
    private ObjectiveResult computeObjective(ADMMTermStore termStore) {
        double objective = 0.0;
        long violatedConstraints = 0L;
        float[] consensusValues = termStore.getConsensusValues();
        for (ADMMObjectiveTerm term : termStore) {
            float value = term.evaluate(consensusValues);
            if (!term.isConstraint()) {
                objective += (double)value;
            } else if (value > 0.0f) {
                violatedConstraints++;
            }
        }
        return new ObjectiveResult(objective, violatedConstraints);
    }

    /** Adds one worker block's partial sums into the shared per-iteration accumulators. */
    private synchronized void updateIterationVariables(double primalRes, double dualRes, double AxNorm, double BzNorm, double AyNorm, double lagrangePenalty, double augmentedLagrangePenalty) {
        this.primalRes += primalRes;
        this.dualRes += dualRes;
        this.AxNorm += AxNorm;
        this.AyNorm += AyNorm;
        this.BzNorm += BzNorm;
        this.lagrangePenalty += lagrangePenalty;
        this.augmentedLagrangePenalty += augmentedLagrangePenalty;
    }

    /** Objective value plus the number of constraint terms found violated. */
    private static class ObjectiveResult {
        public final double objective;
        public final long violatedConstraints;

        public ObjectiveResult(double objective, long violatedConstraints) {
            this.objective = objective;
            this.violatedConstraints = violatedConstraints;
        }
    }

    /**
     * Consensus-update phase for one contiguous block of consensus variables.
     *
     * <p>Each block writes only its own indices of the shared consensus array. Its partial
     * residual/penalty sums are merged through the synchronized updateIterationVariables().
     */
    private class VariableWorker
    extends Parallel.Worker<Long> {
        private final ADMMTermStore termStore;
        private final long blockSize;
        private final float[] consensusValues;
        private final boolean useNonConvex;

        public VariableWorker(ADMMTermStore termStore, long blockSize, boolean useNonConvex) {
            this.termStore = termStore;
            this.blockSize = blockSize;
            this.useNonConvex = useNonConvex;
            this.consensusValues = termStore.getConsensusValues();
        }

        @Override
        public Object clone() {
            return new VariableWorker(this.termStore, this.blockSize, this.useNonConvex);
        }

        @Override
        public void work(long blockIndex, Long ignore) {
            int numVariables = this.termStore.getNumConsensusVariables();
            double primalResInc = 0.0;
            double dualResInc = 0.0;
            double AxNormInc = 0.0;
            double BzNormInc = 0.0;
            double AyNormInc = 0.0;
            double lagrangePenaltyInc = 0.0;
            double augmentedLagrangePenaltyInc = 0.0;

            long start = blockIndex * this.blockSize;
            long end = Math.min(start + this.blockSize, (long)numVariables);
            for (long index = start; index < end; index++) {
                int variableIndex = (int)index;
                List<LocalVariable> localVariables = this.termStore.getLocalVariables(variableIndex);
                int numLocalVariables = localVariables.size();
                if (numLocalVariables == 0) {
                    // Nothing to average. 0/0 would make the consensus value (and dualRes) NaN,
                    // so leave this variable untouched.
                    continue;
                }

                // Average of (local value + scaled multiplier) over all local copies.
                double total = 0.0;
                for (LocalVariable localVariable : localVariables) {
                    float value = localVariable.getValue();
                    float lagrange = localVariable.getLagrange();
                    total += (double)(value + lagrange / ADMMReasoner.this.stepSize);
                    AxNormInc += (double)(value * value);
                    AyNormInc += (double)(lagrange * lagrange);
                }

                float newConsensusValue = (float)(total / (double)numLocalVariables);
                newConsensusValue = Math.max(Math.min(newConsensusValue, UPPER_BOUND), LOWER_BOUND);

                float consensusDiff = this.consensusValues[variableIndex] - newConsensusValue;
                dualResInc += (double)(consensusDiff * consensusDiff * (float)numLocalVariables);
                BzNormInc += (double)(newConsensusValue * newConsensusValue * (float)numLocalVariables);
                this.consensusValues[variableIndex] = newConsensusValue;

                for (LocalVariable localVariable : localVariables) {
                    float localDiff = localVariable.getValue() - newConsensusValue;
                    primalResInc += (double)(localDiff * localDiff);
                    lagrangePenaltyInc += (double)(localVariable.getLagrange() * localDiff);
                    augmentedLagrangePenaltyInc += 0.5 * (double)ADMMReasoner.this.stepSize * ((double)localDiff * (double)localDiff);
                }
            }

            ADMMReasoner.this.updateIterationVariables(primalResInc, dualResInc, AxNormInc, BzNormInc, AyNormInc, lagrangePenaltyInc, augmentedLagrangePenaltyInc);
        }
    }

    /**
     * Term-update phase for one contiguous block of terms: each selected term updates its Lagrange
     * multipliers and then minimizes. Non-convex terms run only when {@code useNonConvex} is set.
     */
    private class TermWorker
    extends Parallel.Worker<Long> {
        private final ADMMTermStore termStore;
        private final long blockSize;
        private final float[] consensusValues;
        private final boolean useNonConvex;

        public TermWorker(ADMMTermStore termStore, long blockSize, boolean useNonConvex) {
            this.termStore = termStore;
            this.blockSize = blockSize;
            this.useNonConvex = useNonConvex;
            this.consensusValues = termStore.getConsensusValues();
        }

        @Override
        public Object clone() {
            return new TermWorker(this.termStore, this.blockSize, this.useNonConvex);
        }

        @Override
        public void work(long blockIndex, Long ignore) {
            long numTerms = this.termStore.size();
            long start = blockIndex * this.blockSize;
            long end = Math.min(start + this.blockSize, numTerms);
            for (long termIndex = start; termIndex < end; termIndex++) {
                ADMMObjectiveTerm term = (ADMMObjectiveTerm)this.termStore.get(termIndex);
                if (this.useNonConvex || term.isConvex()) {
                    term.updateLagrange(ADMMReasoner.this.stepSize, this.consensusValues);
                    term.minimize(ADMMReasoner.this.stepSize, this.consensusValues);
                }
            }
        }
    }
}

