/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.reasoner.sgd;

import java.util.Iterator;
import org.linqs.psl.config.Options;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.reasoner.Reasoner;
import org.linqs.psl.reasoner.sgd.term.SGDObjectiveTerm;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.reasoner.term.VariableTermStore;
import org.linqs.psl.util.IteratorUtils;
import org.linqs.psl.util.MathUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Reasoner that optimizes a {@link VariableTermStore} of {@link SGDObjectiveTerm}s by repeatedly
 * sweeping over all terms and letting each term update the shared variable values
 * (via {@link SGDObjectiveTerm#minimize}).
 *
 * <p>Iteration stops when the iteration budget is exhausted, or (unless movement watching says
 * the variables are still moving too much) when the mean objective stops changing within the
 * configured tolerance.
 */
public class SGDReasoner extends Reasoner {
    private static final Logger log = LoggerFactory.getLogger(SGDReasoner.class);

    private final int maxIterations = Options.SGD_MAX_ITER.getInt();
    private final boolean watchMovement = Options.SGD_MOVEMENT.getBoolean();
    private final float movementThreshold = Options.SGD_MOVEMENT_THRESHOLD.getFloat();

    /**
     * Runs the optimization loop over the given term store and syncs the results back to atoms.
     *
     * @param baseTermStore the store holding the terms to optimize; must be a {@link VariableTermStore}
     * @throws IllegalArgumentException if {@code baseTermStore} is not a {@link VariableTermStore}
     */
    @Override
    public void optimize(TermStore baseTermStore) {
        if (!(baseTermStore instanceof VariableTermStore)) {
            throw new IllegalArgumentException("SGDReasoner requires a VariableTermStore (found " + baseTermStore.getClass().getName() + ").");
        }

        // The abstract TermStore parameter carries no type arguments, so this cast is unchecked.
        // NOTE(review): assumes the store was populated with SGDObjectiveTerms over
        // RandomVariableAtoms, as computeObjective's signature requires.
        @SuppressWarnings("unchecked")
        VariableTermStore<SGDObjectiveTerm, RandomVariableAtom> termStore =
                (VariableTermStore<SGDObjectiveTerm, RandomVariableAtom>)baseTermStore;

        termStore.initForOptimization();

        // Unless the initial objective is computed below, -1.0 serves as the "previous objective"
        // for the first convergence check.
        float objective = -1.0f;
        if (printInitialObj && log.isTraceEnabled()) {
            objective = computeObjective(termStore);
            log.trace("Iteration {} -- Objective: {}, Mean Movement: {}, Iteration Time: {}, Total Optimiztion Time: {}",
                    0, objective, 0.0f, 0, 0);
        }

        // Both are assigned on every pass of the loop body before the loop condition reads them.
        float oldObjective;
        float movement;

        int iteration = 1;
        long totalTime = 0L;
        do {
            long start = System.currentTimeMillis();

            movement = 0.0f;
            float[] variableValues = termStore.getVariableValues();
            for (SGDObjectiveTerm term : termStore) {
                movement += term.minimize(iteration, variableValues);
            }

            // Report movement as a per-variable mean; skip the division when there are no variables.
            if (variableValues.length != 0) {
                movement /= variableValues.length;
            }

            // The objective evaluation below is not included in the per-iteration time.
            long end = System.currentTimeMillis();

            oldObjective = objective;
            objective = computeObjective(termStore);
            totalTime += end - start;

            if (log.isTraceEnabled()) {
                log.trace("Iteration {} -- Objective: {}, Mean Movement: {}, Iteration Time: {}, Total Optimiztion Time: {}",
                        iteration, objective, movement, end - start, totalTime);
            }

            termStore.iterationComplete();
        } while (!breakOptimization(++iteration, objective, oldObjective, movement));

        termStore.syncAtoms();

        // iteration was pre-incremented once past the last completed iteration.
        log.info("Optimization completed in {} iterations. Objective: {}, Total Optimiztion Time: {}",
                iteration - 1, objective, totalTime);
        log.debug("Optimized with {} variables and {} terms.", termStore.getNumVariables(), termStore.size());
    }

    /**
     * Decides whether the optimization loop should stop before starting {@code iteration}.
     *
     * @param iteration the (1-based) iteration that would run next
     * @param objective the mean objective after the iteration just completed
     * @param oldObjective the mean objective before the iteration just completed
     * @param movement the mean per-variable movement of the iteration just completed
     * @return true to stop optimizing
     */
    private boolean breakOptimization(int iteration, float objective, float oldObjective, float movement) {
        // Hard stop: the configured iteration cap, scaled by the inherited budget factor.
        if (iteration > (int)((double)maxIterations * budget)) {
            return true;
        }

        // While variables are still moving more than the threshold, keep going regardless of the objective.
        if (watchMovement && movement > movementThreshold) {
            return false;
        }

        // Otherwise stop once the objective has stopped changing (within tolerance), if enabled.
        return objectiveBreak && MathUtils.equals(objective, oldObjective, tolerance);
    }

    /**
     * Computes the mean objective value over all terms in the store, evaluated at the store's
     * current variable values.
     *
     * @param termStore the store whose terms are evaluated
     * @return the sum of all term values divided by the number of terms, or 0 if the store has no terms
     */
    public float computeObjective(VariableTermStore<SGDObjectiveTerm, RandomVariableAtom> termStore) {
        // Once the store reports itself as loaded, use its no-write iterator rather than the regular one.
        // NOTE(review): presumably the regular iterator may write/cache terms (e.g. in a streaming store) - confirm.
        Iterator<SGDObjectiveTerm> termIterator = termStore.isLoaded() ? termStore.noWriteIterator() : termStore.iterator();

        float objective = 0.0f;
        float[] variableValues = termStore.getVariableValues();
        for (SGDObjectiveTerm term : IteratorUtils.newIterable(termIterator)) {
            objective += term.evaluate(variableValues);
        }

        // Read the size only after iterating, as the original did, in case iteration affects it.
        long termCount = termStore.size();
        if (termCount == 0) {
            // Avoid 0/0 = NaN. A NaN would propagate into the logs and (assuming MathUtils.equals uses
            // ordinary float comparison) never satisfy the objective convergence check.
            return 0.0f;
        }

        return objective / termCount;
    }

    /**
     * No-op: this reasoner holds only configuration values and allocates nothing that needs releasing.
     */
    @Override
    public void close() {
    }
}

