/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.reasoner.sgd;

import java.util.Iterator;
import org.linqs.psl.config.Config;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.reasoner.Reasoner;
import org.linqs.psl.reasoner.sgd.term.SGDObjectiveTerm;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.reasoner.term.VariableTermStore;
import org.linqs.psl.util.IteratorUtils;
import org.linqs.psl.util.MathUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link Reasoner} that repeatedly sweeps over every {@link SGDObjectiveTerm}
 * in a {@link VariableTermStore}, asking each term to minimize itself against the
 * shared variable-value array, until an iteration cap is reached or (optionally)
 * the mean objective stops changing within a tolerance.
 *
 * All options are read from {@link Config} when the reasoner is constructed.
 */
public class SGDReasoner implements Reasoner {
    private static final Logger log = LoggerFactory.getLogger(SGDReasoner.class);

    /** Prefix shared by every configuration key of this reasoner. */
    public static final String CONFIG_PREFIX = "sgd";

    /** Config key: maximum number of passes over the term store. */
    public static final String MAX_ITER_KEY = "sgd.maxiterations";
    public static final int MAX_ITER_DEFAULT = 200;

    /** Config key: whether to stop early once the objective change is within tolerance. */
    public static final String OBJECTIVE_BREAK_KEY = "sgd.objectivebreak";
    public static final boolean OBJECTIVE_BREAK_DEFAULT = true;

    /** Config key: tolerance used when comparing successive objective values. */
    public static final String OBJ_TOL_KEY = "sgd.tolerance";
    public static final float OBJ_TOL_DEFAULT = 1.0E-5f;

    /** Config key for the learning rate. Declared here but not read by this class. */
    public static final String LEARNING_RATE_KEY = "sgd.learningrate";
    public static final float LEARNING_RATE_DEFAULT = 1.0f;

    /** Config key: whether to log the objective after every iteration. */
    public static final String PRINT_OBJECTIVE = "sgd.printobj";
    public static final boolean PRINT_OBJECTIVE_DEFAULT = true;

    /**
     * Config key: whether to also compute and log the objective before the first
     * iteration. Only consulted when objective printing is enabled.
     */
    public static final String PRINT_INITIAL_OBJECTIVE_KEY = "sgd.printinitialobj";
    public static final boolean PRINT_INITIAL_OBJECTIVE_DEFAULT = false;

    private int maxIter;
    private final float tolerance;
    private final boolean printObj;
    private final boolean printInitialObj;
    private final boolean objectiveBreak;

    /**
     * Creates a reasoner configured from {@link Config}, falling back to the
     * {@code *_DEFAULT} constants for any option that is not set.
     */
    public SGDReasoner() {
        this.maxIter = Config.getInt(MAX_ITER_KEY, MAX_ITER_DEFAULT);
        this.objectiveBreak = Config.getBoolean(OBJECTIVE_BREAK_KEY, OBJECTIVE_BREAK_DEFAULT);
        this.printObj = Config.getBoolean(PRINT_OBJECTIVE, PRINT_OBJECTIVE_DEFAULT);
        this.printInitialObj = Config.getBoolean(PRINT_INITIAL_OBJECTIVE_KEY, PRINT_INITIAL_OBJECTIVE_DEFAULT);
        this.tolerance = Config.getFloat(OBJ_TOL_KEY, OBJ_TOL_DEFAULT);
    }

    /**
     * @return the maximum number of iterations {@link #optimize(TermStore)} will run
     */
    public int getMaxIter() {
        return this.maxIter;
    }

    /**
     * @param maxIter the maximum number of iterations {@link #optimize(TermStore)} will run
     */
    public void setMaxIter(int maxIter) {
        this.maxIter = maxIter;
    }

    /**
     * Runs the optimization loop over the given term store and then calls
     * {@code syncAtoms()} on it.
     *
     * Each iteration calls {@code minimize(iteration, variableValues)} on every term,
     * then recomputes the mean objective. The loop ends after {@code maxIter}
     * iterations, or earlier when objective-break is enabled and two successive
     * objectives are equal within the configured tolerance. The convergence check is
     * never applied before the first iteration has run.
     *
     * @param baseTermStore the term store to optimize; must be a {@link VariableTermStore}
     * @throws IllegalArgumentException if {@code baseTermStore} is not a {@link VariableTermStore}
     */
    @Override
    public void optimize(TermStore baseTermStore) {
        if (!(baseTermStore instanceof VariableTermStore)) {
            throw new IllegalArgumentException("SGDReasoner requires an VariableTermStore (found " + baseTermStore.getClass().getName() + ").");
        }

        // The interface hands us an untyped store; the instanceof check above is the
        // only runtime guarantee available, so the parameterized cast is unchecked.
        @SuppressWarnings("unchecked")
        VariableTermStore<SGDObjectiveTerm, RandomVariableAtom> termStore =
                (VariableTermStore<SGDObjectiveTerm, RandomVariableAtom>)baseTermStore;

        float[] variableValues = termStore.getVariableValues();

        float objective = -1.0f;
        float oldObjective = Float.POSITIVE_INFINITY;

        int iteration = 1;
        if (this.printObj) {
            // NOTE(review): the header and initial row are logged at TRACE while the
            // per-iteration rows below are logged at INFO; levels are kept as found.
            log.trace("objective:Iterations,Time(ms),Objective");

            if (this.printInitialObj) {
                objective = this.computeObjective(termStore, variableValues);
                log.trace("objective:{},{},{}", 0, 0, objective);
            }
        }

        // Accumulated wall-clock time of the minimize sweeps only (objective
        // computation is excluded).
        long time = 0L;
        while (iteration <= this.maxIter) {
            // Skip the check on the first iteration: there is no previous objective yet.
            boolean converged = this.objectiveBreak
                    && iteration != 1
                    && MathUtils.equals(objective, oldObjective, this.tolerance);
            if (converged) {
                break;
            }

            long start = System.currentTimeMillis();
            for (SGDObjectiveTerm term : termStore) {
                term.minimize(iteration, variableValues);
            }
            long end = System.currentTimeMillis();

            oldObjective = objective;
            objective = this.computeObjective(termStore, variableValues);
            time += end - start;

            if (this.printObj) {
                log.info("objective:{},{},{}", iteration, time, objective);
            }

            ++iteration;
        }

        termStore.syncAtoms();

        log.info("Optimization completed in {} iterations. Objective.: {}", iteration - 1, objective);
        log.debug("Optimized with {} variables and {} terms.", termStore.getNumVariables(), termStore.size());
    }

    /**
     * Computes the mean of {@code term.evaluate(variableValues)} over all terms in the store.
     *
     * A loaded store is traversed through its {@code noWriteIterator()}; otherwise the
     * regular {@code iterator()} is used. If the store reports zero terms after the
     * traversal, the un-normalized sum is returned (0 for an empty store) instead of
     * dividing by zero and yielding NaN.
     *
     * @param termStore the store whose terms are evaluated
     * @param variableValues the current variable values passed to each term
     * @return the summed term values divided by the number of terms
     */
    public float computeObjective(VariableTermStore<SGDObjectiveTerm, RandomVariableAtom> termStore, float[] variableValues) {
        Iterator<SGDObjectiveTerm> termIterator = termStore.isLoaded()
                ? termStore.noWriteIterator()
                : termStore.iterator();

        float objective = 0.0f;
        for (SGDObjectiveTerm term : IteratorUtils.newIterable(termIterator)) {
            objective += term.evaluate(variableValues);
        }

        // Read the size only after iterating, matching the original evaluation order.
        float termCount = (float)termStore.size();
        if (termCount == 0.0f) {
            return objective;
        }

        return objective / termCount;
    }

    /**
     * Releases resources held by this reasoner. Nothing is held, so this is a no-op.
     */
    @Override
    public void close() {
    }
}

