/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.reasoner.sgd;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.config.Options;
import org.linqs.psl.evaluation.statistics.Evaluator;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.reasoner.Reasoner;
import org.linqs.psl.reasoner.sgd.term.SGDObjectiveTerm;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.reasoner.term.VariableTermStore;
import org.linqs.psl.util.ArrayUtils;
import org.linqs.psl.util.IteratorUtils;
import org.linqs.psl.util.MathUtils;
import org.linqs.psl.util.RandUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A reasoner that minimizes the objective of a {@link VariableTermStore} of {@link SGDObjectiveTerm}s
 * by stochastic gradient descent, making one pass over all terms per iteration and clipping every
 * variable to [0, 1] after each step. The per-variable step can optionally be adapted with AdaGrad or Adam
 * (see {@link SGDExtension}), and the base learning rate follows a {@link SGDLearningSchedule}.
 */
public class SGDReasoner extends Reasoner {
    private static final Logger log = LoggerFactory.getLogger(SGDReasoner.class);

    /** Added to the adaptive step denominators (AdaGrad, Adam) to avoid division by zero. */
    private static final float EPSILON = 1.0E-8f;

    /** Maximum number of passes over the terms (scaled by the inherited budget). */
    private final int maxIterations = Options.SGD_MAX_ITER.getInt();
    /** When set, optimization continues while the mean per-term movement exceeds movementThreshold. */
    private final boolean watchMovement = Options.SGD_MOVEMENT.getBoolean();
    private final float movementThreshold = Options.SGD_MOVEMENT_THRESHOLD.getFloat();
    private final float initialLearningRate = Options.SGD_LEARNING_RATE.getFloat();
    /** Exponent of the iteration count used by the STEPDECAY schedule. */
    private final float learningRateInverseScaleExp = Options.SGD_INVERSE_TIME_EXP.getFloat();
    private final float adamBeta1 = Options.SGD_ADAM_BETA_1.getFloat();
    private final float adamBeta2 = Options.SGD_ADAM_BETA_2.getFloat();
    /** When set, a term's dot product is recomputed after each of its variables is updated. */
    private final boolean coordinateStep = Options.SGD_COORDINATE_STEP.getBoolean();
    private final SGDLearningSchedule learningSchedule = SGDLearningSchedule.valueOf(Options.SGD_LEARNING_SCHEDULE.getString().toUpperCase());
    private final SGDExtension sgdExtension = SGDExtension.valueOf(Options.SGD_EXTENSION.getString().toUpperCase());

    // Per-variable optimizer state; only the arrays needed by the active extension are allocated,
    // and all of them are released when optimization completes.
    private float[] accumulatedGradientSquares = null;
    private float[] accumulatedGradientMean = null;
    private float[] accumulatedGradientVariance = null;

    /**
     * Runs SGD until {@link #breakOptimization} signals a stop, then writes the variable values back to the atoms.
     *
     * @param baseTermStore must be a {@link VariableTermStore} of {@link SGDObjectiveTerm}s
     * @param evaluators evaluators forwarded to the inherited per-iteration evaluation hook
     * @param trainingMap forwarded to the inherited per-iteration evaluation hook
     * @param evaluationPredicates forwarded to the inherited per-iteration evaluation hook
     * @return the objective evaluated at the final variable values
     * @throws IllegalArgumentException if {@code baseTermStore} is not a {@link VariableTermStore}
     */
    @Override
    public double optimize(TermStore baseTermStore, List<Evaluator> evaluators, TrainingMap trainingMap, Set<StandardPredicate> evaluationPredicates) {
        if (!(baseTermStore instanceof VariableTermStore)) {
            throw new IllegalArgumentException("SGDReasoner requires a VariableTermStore (found " + baseTermStore.getClass().getName() + ").");
        }

        // A typed view is required so that iterating the store yields SGDObjectiveTerms (a raw type yields Object).
        @SuppressWarnings("unchecked")
        VariableTermStore<SGDObjectiveTerm, GroundAtom> termStore = (VariableTermStore<SGDObjectiveTerm, GroundAtom>) baseTermStore;

        termStore.initForOptimization();
        initForOptimization(termStore);

        long termCount = 0L;
        double objective = 0.0;
        double oldObjective = Double.POSITIVE_INFINITY;
        // Copy of the variable values at the end of the previous pass. The objective accumulated during
        // pass i is evaluated against this copy, so it describes the state produced by pass i - 1.
        float[] oldVariableValues = null;
        long totalTime = 0L;

        int iteration = 1;
        boolean breakSGD = false;
        while (!breakSGD) {
            long start = System.currentTimeMillis();

            termCount = 0L;
            objective = 0.0;
            float meanMovement = 0.0f;
            float learningRate = calculateAnnealedLearningRate(iteration);

            for (SGDObjectiveTerm term : termStore) {
                // No snapshot exists during the first pass, so no objective is accumulated then.
                if (iteration > 1) {
                    objective += term.evaluate(oldVariableValues);
                }

                termCount++;
                meanMovement += variableUpdate(term, termStore, iteration, learningRate);
            }

            evaluate(termStore, iteration, evaluators, trainingMap, evaluationPredicates);
            termStore.iterationComplete();

            if (termCount != 0L) {
                meanMovement /= (float) termCount;
            }

            breakSGD = breakOptimization(iteration, objective, oldObjective, meanMovement, termCount);

            float[] variableValues = termStore.getVariableValues();
            if (iteration == 1) {
                oldVariableValues = Arrays.copyOf(variableValues, variableValues.length);
            } else {
                System.arraycopy(variableValues, 0, oldVariableValues, 0, oldVariableValues.length);
                oldObjective = objective;
            }

            long end = System.currentTimeMillis();
            totalTime += end - start;

            if (iteration > 1 && log.isTraceEnabled()) {
                log.trace("Iteration {} -- Objective: {}, Normalized Objective: {}, Iteration Time: {}, Total Optimization Time: {}",
                        iteration - 1, objective, normalizeObjective(objective, termCount), end - start, totalTime);
            }

            iteration++;
        }

        optimizationComplete();

        objective = computeObjective(termStore);
        double change = termStore.syncAtoms();

        // The counter is advanced after every pass, so the number of completed passes is one less.
        int completedIterations = iteration - 1;
        log.info("Final Objective: {}, Final Normalized Objective: {}, Total Optimization Time: {}, Total Number of Iterations: {}",
                objective, normalizeObjective(objective, termCount), totalTime, completedIterations);
        log.debug("Movement of variables from initial state: {}", change);
        log.debug("Optimized with {} variables and {} terms.", termStore.getNumRandomVariables(), termCount);

        return objective;
    }

    /**
     * Allocates the per-variable state needed by the configured {@link SGDExtension}.
     *
     * @throws IllegalArgumentException if the extension is not handled here
     */
    private void initForOptimization(VariableTermStore<SGDObjectiveTerm, GroundAtom> termStore) {
        int numVariables = termStore.getNumRandomVariables();
        switch (sgdExtension) {
            case NONE:
                break;
            case ADAGRAD:
                accumulatedGradientSquares = new float[numVariables];
                break;
            case ADAM:
                accumulatedGradientMean = new float[numVariables];
                accumulatedGradientVariance = new float[numVariables];
                break;
            default:
                throw new IllegalArgumentException(String.format("Unsupported SGD Extensions: '%s'", sgdExtension));
        }
    }

    /** Releases all per-variable optimizer state. */
    private void optimizationComplete() {
        accumulatedGradientSquares = null;
        accumulatedGradientMean = null;
        accumulatedGradientVariance = null;
    }

    /**
     * Decides whether optimization should stop after the given pass.
     * Stops once the (budget-scaled) iteration limit is exceeded. Otherwise, unless full iterations are forced
     * or the variables are still moving more than the threshold, stops when the normalized objective has
     * converged within the inherited tolerance (if objective-based stopping is enabled).
     */
    private boolean breakOptimization(int iteration, double objective, double oldObjective, float movement, long termCount) {
        if (iteration > (int) ((double) maxIterations * budget)) {
            return true;
        }

        if (runFullIterations) {
            return false;
        }

        if (watchMovement && movement > movementThreshold) {
            return false;
        }

        return objectiveBreak
                && MathUtils.equals(normalizeObjective(objective, termCount), normalizeObjective(oldObjective, termCount), (double) tolerance);
    }

    /**
     * Divides an objective by the number of terms. With no terms the objective is returned unchanged,
     * avoiding NaN/infinite results (0/0) for an empty term store.
     */
    private static double normalizeObjective(double objective, long termCount) {
        if (termCount == 0L) {
            return objective;
        }
        return objective / (double) termCount;
    }

    /**
     * Sums the values of all terms at the store's current variable values.
     * Uses the store's {@code noWriteIterator()} once it reports being loaded, otherwise its regular iterator.
     */
    private double computeObjective(VariableTermStore<SGDObjectiveTerm, GroundAtom> termStore) {
        Iterator<SGDObjectiveTerm> termIterator = termStore.isLoaded() ? termStore.noWriteIterator() : termStore.iterator();
        float[] variableValues = termStore.getVariableValues();

        double objective = 0.0;
        for (SGDObjectiveTerm term : IteratorUtils.newIterable(termIterator)) {
            objective += term.evaluate(variableValues);
        }
        return objective;
    }

    /**
     * Returns the base learning rate for the given (1-based) iteration.
     * CONSTANT keeps the initial rate; STEPDECAY divides it by {@code iteration ^ learningRateInverseScaleExp}.
     *
     * @throws IllegalArgumentException if the schedule is not handled here
     */
    private float calculateAnnealedLearningRate(int iteration) {
        switch (learningSchedule) {
            case CONSTANT:
                return initialLearningRate;
            case STEPDECAY:
                return initialLearningRate / (float) Math.pow(iteration, learningRateInverseScaleExp);
            default:
                throw new IllegalArgumentException(String.format("Illegal value found for SGD learning schedule: '%s'", learningSchedule));
        }
    }

    /**
     * Takes one gradient step on every non-observed variable of a term, clipping results to [0, 1].
     * Terms with a non-zero deter epsilon are handled by {@link #updateDeter} instead.
     *
     * @return the total absolute change of the term's variables
     */
    private float variableUpdate(SGDObjectiveTerm term, VariableTermStore<SGDObjectiveTerm, GroundAtom> termStore, int iteration, float learningRate) {
        if (!MathUtils.isZero(term.getDeterEpsilon())) {
            return updateDeter(term, termStore);
        }

        GroundAtom[] variableAtoms = termStore.getVariableAtoms();
        float[] variableValues = termStore.getVariableValues();
        int[] variableIndexes = term.getVariableIndexes();
        WeightedRule rule = term.getRule();
        int size = term.size();

        float movement = 0.0f;
        float dot = term.dot(variableValues);
        for (int i = 0; i < size; i++) {
            int variableIndex = variableIndexes[i];

            // Observed atoms are fixed and never moved.
            if (variableAtoms[variableIndex] instanceof ObservedAtom) {
                continue;
            }

            float partial = term.computePartial(i, dot, rule.getWeight());
            float step = computeVariableStep(variableIndex, iteration, learningRate, partial);
            float newValue = Math.max(0.0f, Math.min(1.0f, variableValues[variableIndex] - step));

            movement += Math.abs(newValue - variableValues[variableIndex]);
            variableValues[variableIndex] = newValue;

            // In coordinate-step mode later variables of this term see the already-updated value.
            if (coordinateStep) {
                dot = term.dot(variableValues);
            }
        }

        return movement;
    }

    /**
     * Converts a partial derivative into a step for one variable according to the configured extension:
     * plain SGD, AdaGrad (accumulated squared gradients), or Adam (bias-corrected moment estimates).
     *
     * @throws IllegalArgumentException if the extension is not handled here
     */
    private float computeVariableStep(int variableIndex, int iteration, float learningRate, float partial) {
        switch (sgdExtension) {
            case NONE:
                return partial * learningRate;
            case ADAGRAD: {
                // NOTE(review): ensureCapacity receives the index itself; assumed to guarantee index < length — confirm.
                accumulatedGradientSquares = ArrayUtils.ensureCapacity(accumulatedGradientSquares, variableIndex);
                accumulatedGradientSquares[variableIndex] = accumulatedGradientSquares[variableIndex] + partial * partial;

                float adaptedLearningRate = learningRate / (float) Math.sqrt(accumulatedGradientSquares[variableIndex] + EPSILON);
                return partial * adaptedLearningRate;
            }
            case ADAM: {
                accumulatedGradientMean = ArrayUtils.ensureCapacity(accumulatedGradientMean, variableIndex);
                accumulatedGradientMean[variableIndex] = adamBeta1 * accumulatedGradientMean[variableIndex] + (1.0f - adamBeta1) * partial;

                accumulatedGradientVariance = ArrayUtils.ensureCapacity(accumulatedGradientVariance, variableIndex);
                accumulatedGradientVariance[variableIndex] = adamBeta2 * accumulatedGradientVariance[variableIndex] + (1.0f - adamBeta2) * partial * partial;

                // Bias correction uses the pass number, not a per-variable update count.
                float correctedMean = accumulatedGradientMean[variableIndex] / (1.0f - (float) Math.pow(adamBeta1, iteration));
                float correctedVariance = accumulatedGradientVariance[variableIndex] / (1.0f - (float) Math.pow(adamBeta2, iteration));

                float adaptedLearningRate = learningRate / ((float) Math.sqrt(correctedVariance) + EPSILON);
                return correctedMean * adaptedLearningRate;
            }
            default:
                throw new IllegalArgumentException(String.format("Unsupported SGD Extensions: '%s'", sgdExtension));
        }
    }

    /**
     * If a term's variables sit (on average) within the term's deter epsilon of the uniform value 1/size,
     * snaps one randomly chosen variable to 1 and the rest to 0; otherwise leaves them untouched.
     *
     * @return the total absolute change of the term's variables (0 when nothing changed)
     */
    private float updateDeter(SGDObjectiveTerm term, VariableTermStore<SGDObjectiveTerm, GroundAtom> termStore) {
        float[] variableValues = termStore.getVariableValues();
        int[] variableIndexes = term.getVariableIndexes();
        int size = term.size();
        float deterValue = 1.0f / (float) size;

        float distance = 0.0f;
        for (int i = 0; i < size; i++) {
            distance += Math.abs(deterValue - variableValues[variableIndexes[i]]);
        }
        distance /= (float) size;

        if (distance > term.getDeterEpsilon()) {
            return 0.0f;
        }

        int upPoint = RandUtils.nextInt(size);
        float movement = 0.0f;
        for (int i = 0; i < size; i++) {
            float newValue = (i == upPoint) ? 1.0f : 0.0f;
            movement += Math.abs(newValue - variableValues[variableIndexes[i]]);
            variableValues[variableIndexes[i]] = newValue;
        }

        return movement;
    }

    /** This reasoner holds no resources beyond a single {@link #optimize} call. */
    @Override
    public void close() {
    }

    /** How the base learning rate evolves over iterations. */
    public static enum SGDLearningSchedule {
        CONSTANT,
        STEPDECAY;
    }

    /** Optional per-variable step-size adaptation applied on top of plain SGD. */
    public static enum SGDExtension {
        NONE,
        ADAGRAD,
        ADAM;
    }
}

