/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.reasoner.dcd;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.config.Options;
import org.linqs.psl.evaluation.statistics.Evaluator;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.reasoner.Reasoner;
import org.linqs.psl.reasoner.dcd.term.DCDObjectiveTerm;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.reasoner.term.VariableTermStore;
import org.linqs.psl.util.IteratorUtils;
import org.linqs.psl.util.MathUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Reasoner that optimizes by coordinate steps on per-term Lagrange (dual) variables.
 *
 * <p>Every {@link DCDObjectiveTerm} carries a Lagrange multiplier. One sweep visits each term in
 * the term store once. For each term it takes a projected step on that multiplier, clipped into
 * [0, lim], and pushes the resulting change into the values of the term's non-observed variables.
 * Sweeps repeat until {@link #breakOptimization} reports that optimization should stop.
 *
 * <p>Only {@link VariableTermStore}s are supported; any other {@link TermStore} is rejected.
 */
public class DCDReasoner extends Reasoner {
    private static final Logger log = LoggerFactory.getLogger(DCDReasoner.class);

    /** Base sweep cap; it is scaled by the inherited {@code budget} in {@link #breakOptimization}. */
    private final int maxIterations = Options.DCD_MAX_ITER.getInt();

    /**
     * Scaling constant. Rule weights are multiplied by it to form the adjusted weight used in the
     * updates, and term values are divided by it when the objective is computed.
     */
    private final float c = Options.DCD_C.getFloat();

    /**
     * If true, each updated variable is clipped into [0, 1] right after every term update.
     * If false, all variables are clipped once at the end of each sweep.
     */
    private final boolean truncateEveryStep = Options.DCD_TRUNCATE_EVERY_STEP.getBoolean();

    /**
     * Runs sweeps over the term store until the stopping criteria are met, then syncs the
     * optimized values back to the atoms.
     *
     * <p>The per-sweep objective lags one sweep behind. During sweep k (k > 1), each term is
     * evaluated against a snapshot of the variable values taken at the end of sweep k - 1.
     * Sweep 1 therefore contributes no objective.
     *
     * @param baseTermStore the terms to optimize; must be a {@link VariableTermStore}
     * @param evaluators passed to the inherited {@code evaluate} hook after every sweep
     * @param trainingMap passed through to {@code evaluate}
     * @param evaluationPredicates passed through to {@code evaluate}
     * @return the objective recomputed over all terms after the final sweep
     * @throws IllegalArgumentException if {@code baseTermStore} is not a {@link VariableTermStore}
     */
    @Override
    public double optimize(TermStore baseTermStore, List<Evaluator> evaluators, TrainingMap trainingMap, Set<StandardPredicate> evaluationPredicates) {
        if (!(baseTermStore instanceof VariableTermStore)) {
            throw new IllegalArgumentException("DCDReasoner requires an VariableTermStore (found " + baseTermStore.getClass().getName() + ").");
        }

        // Typed view of the store. With a raw type, iterating it yields Object and the
        // DCDObjectiveTerm loop below would not compile.
        @SuppressWarnings("unchecked")
        VariableTermStore<DCDObjectiveTerm, GroundAtom> termStore = (VariableTermStore<DCDObjectiveTerm, GroundAtom>)baseTermStore;
        termStore.initForOptimization();

        long termCount = 0L;
        double objective;
        double oldObjective = Double.POSITIVE_INFINITY;
        float[] oldVariableValues = null;
        long totalTime = 0L;
        boolean breakDCD = false;
        int iteration = 1;

        while (!breakDCD) {
            long start = System.currentTimeMillis();

            termCount = 0L;
            objective = 0.0;

            for (DCDObjectiveTerm term : termStore) {
                // Lagged objective: measured against the previous sweep's snapshot, which only exists after sweep 1.
                if (iteration > 1) {
                    objective += term.evaluate(oldVariableValues) / c;
                }

                termCount++;
                variableUpdate(term, termStore);
            }

            if (!truncateEveryStep) {
                clipVariableValues(termStore);
            }

            evaluate(termStore, iteration, evaluators, trainingMap, evaluationPredicates);
            termStore.iterationComplete();

            breakDCD = breakOptimization(iteration, objective, oldObjective, termCount);

            // Snapshot the current values for the next sweep's objective.
            float[] variableValues = termStore.getVariableValues();
            if (iteration == 1) {
                oldVariableValues = Arrays.copyOf(variableValues, variableValues.length);
            } else {
                System.arraycopy(variableValues, 0, oldVariableValues, 0, oldVariableValues.length);
                oldObjective = objective;
            }

            long end = System.currentTimeMillis();
            totalTime += end - start;

            // The objective computed in this pass belongs to the previous sweep, hence iteration - 1.
            if (iteration > 1 && log.isTraceEnabled()) {
                log.trace("Iteration {} -- Objective: {}, Normalized Objective: {}, Iteration Time: {}, Total Optimization Time: {}", iteration - 1, objective, objective / (double)termCount, end - start, totalTime);
            }

            iteration++;
        }

        objective = computeObjective(termStore);
        double change = termStore.syncAtoms();

        // iteration was incremented after the last sweep, so the number of sweeps run is iteration - 1.
        log.info("Final Objective: {}, Final Normalized Objective: {}, Total Optimization Time: {}, Total Number of Iterations: {}", objective, objective / (double)termCount, totalTime, iteration - 1);
        log.debug("Movement of variables from initial state: {}", change);
        log.debug("Optimized with {} variables and {} terms.", termStore.getNumRandomVariables(), termCount);

        return objective;
    }

    /**
     * Clips every variable value in the store into [0, 1].
     */
    private static void clipVariableValues(VariableTermStore<DCDObjectiveTerm, GroundAtom> termStore) {
        float[] variableValues = termStore.getVariableValues();
        for (int i = 0; i < termStore.getNumVariables(); i++) {
            variableValues[i] = Math.max(0.0f, Math.min(1.0f, variableValues[i]));
        }
    }

    /**
     * Decides whether to stop after the given sweep.
     *
     * <p>The rules are checked in order:
     * <ol>
     *   <li>Always stop once {@code iteration} exceeds {@code maxIterations * budget}.</li>
     *   <li>Otherwise never stop early when the inherited {@code runFullIterations} is set.</li>
     *   <li>Otherwise stop only if {@code objectiveBreak} is set and the per-term normalized
     *       objectives of this sweep and the previous one agree within {@code tolerance}.</li>
     * </ol>
     * The fields {@code budget}, {@code runFullIterations}, {@code objectiveBreak} and
     * {@code tolerance} are inherited from {@link Reasoner}.
     */
    private boolean breakOptimization(int iteration, double objective, double oldObjective, long termCount) {
        if (iteration > (int)((double)maxIterations * budget)) {
            return true;
        }

        if (runFullIterations) {
            return false;
        }

        return objectiveBreak && MathUtils.equals(objective / (double)termCount, oldObjective / (double)termCount, (double)tolerance);
    }

    /**
     * Sums {@code term.evaluate(values) / c} over every term at the store's current variable values.
     *
     * <p>Uses {@code noWriteIterator()} when the store reports {@code isLoaded()}, and the regular
     * iterator otherwise. The motivation is presumably to avoid write-back once terms are
     * resident; confirm against VariableTermStore.
     */
    private double computeObjective(VariableTermStore<DCDObjectiveTerm, GroundAtom> termStore) {
        Iterator<DCDObjectiveTerm> termIterator = termStore.isLoaded() ? termStore.noWriteIterator() : termStore.iterator();

        double objective = 0.0;
        for (DCDObjectiveTerm term : IteratorUtils.newIterable(termIterator)) {
            // The values array is re-fetched per term on purpose. Hoisting it is only safe if iterating
            // the store can never reallocate the array, which cannot be verified from this class.
            objective += term.evaluate(termStore.getVariableValues()) / c;
        }
        return objective;
    }

    /**
     * Performs one coordinate step on {@code term}'s Lagrange multiplier.
     *
     * <p>The step uses the rule weight scaled by {@code c}. Squared terms add
     * {@code lagrange / (2 * adjustedWeight)} to the gradient and have no upper bound on the
     * multiplier. Non-squared terms bound the multiplier above by the adjusted weight.
     */
    private void variableUpdate(DCDObjectiveTerm term, VariableTermStore<DCDObjectiveTerm, GroundAtom> termStore) {
        GroundAtom[] variableAtoms = termStore.getVariableAtoms();
        float[] variableValues = termStore.getVariableValues();

        WeightedRule rule = term.getRule();
        float adjustedWeight = rule.getWeight() * c;
        float gradient = term.computeGradient(variableValues);

        if (term.isSquared()) {
            gradient += term.getLagrange() / (2.0f * adjustedWeight);
            variableUpdate(term, gradient, adjustedWeight, Float.POSITIVE_INFINITY, variableValues, variableAtoms);
        } else {
            variableUpdate(term, gradient, adjustedWeight, adjustedWeight, variableValues, variableAtoms);
        }
    }

    /**
     * Moves {@code term}'s Lagrange multiplier by {@code -gradient / Qii}, clipped into
     * [0, {@code lim}], then applies the multiplier change to the term's non-observed variables.
     *
     * <p>Variables are moved by {@code -(delta * coefficient)}. When {@link #truncateEveryStep}
     * is set, each updated value is also clipped into [0, 1].
     */
    private void variableUpdate(DCDObjectiveTerm term, float gradient, float adjustedWeight, float lim, float[] variableValues, GroundAtom[] variableAtoms) {
        float oldLagrange = term.getLagrange();

        // Projected gradient.
        // At the lower bound 0, a positive gradient would push the multiplier below 0, so only the
        // negative part counts. At a finite upper bound, only the positive part counts.
        // A zero projected gradient means no feasible move.
        float pg = gradient;
        if (MathUtils.isZero(oldLagrange)) {
            pg = Math.min(0.0f, gradient);
        }
        if (MathUtils.equals(lim, adjustedWeight) && MathUtils.equals(oldLagrange, adjustedWeight)) {
            pg = Math.max(0.0f, gradient);
        }
        if (MathUtils.isZero(pg)) {
            return;
        }

        term.setLagrange(Math.min(lim, Math.max(0.0f, oldLagrange - gradient / term.getQii())));

        // The multiplier does not change inside the loop, so the delta is computed once.
        float delta = term.getLagrange() - oldLagrange;
        int[] variableIndexes = term.getVariableIndexes();
        float[] coefficients = term.getCoefficients();

        for (int i = 0; i < term.size(); i++) {
            int index = variableIndexes[i];
            if (variableAtoms[index] instanceof ObservedAtom) {
                continue;
            }

            float val = variableValues[index] - delta * coefficients[i];
            if (truncateEveryStep) {
                val = Math.max(0.0f, Math.min(1.0f, val));
            }
            variableValues[index] = val;
        }
    }

    /** This reasoner holds no resources of its own, so there is nothing to release. */
    @Override
    public void close() {
    }
}

