/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.reasoner;

import java.util.Arrays;
import java.util.List;
import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.config.Options;
import org.linqs.psl.evaluation.EvaluationInstance;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.predicate.DeepPredicate;
import org.linqs.psl.reasoner.term.ReasonerTerm;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.reasoner.term.streaming.StreamingTermStore;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.MathUtils;
import org.linqs.psl.util.Parallel;

public abstract class Reasoner<T extends ReasonerTerm> {
    // Logger shared by all Reasoner instances.
    private static final Logger log = Logger.getLogger(Reasoner.class);
    // Multiplier applied to maxIterations when breakOptimization() checks the iteration cap.
    protected double budget = 1.0;
    // When false, evaluate(...) returns immediately without computing any evaluation.
    protected boolean evaluate = Options.REASONER_EVALUATE.getBoolean();
    // Iteration cap consulted by breakOptimization(). Not assigned in this class
    // (presumably set by subclasses -- TODO confirm).
    protected int maxIterations;
    // When true, breakOptimization() only stops on the iteration cap.
    protected boolean runFullIterations = Options.REASONER_RUN_FULL_ITERATIONS.getBoolean();
    // Enables stopping when consecutive objectives are equal within objectiveTolerance.
    protected boolean objectiveBreak = Options.REASONER_OBJECTIVE_BREAK.getBoolean();
    protected float objectiveTolerance = Options.REASONER_OBJECTIVE_TOLERANCE.getFloat();
    // Enables stopping when the norm of the variable movement drops below variableMovementTolerance.
    protected boolean variableMovementBreak = Options.REASONER_VARIABLE_MOVEMENT_BREAK.getBoolean();
    protected float variableMovementTolerance = Options.REASONER_VARIABLE_MOVEMENT_TOLERANCE.getFloat();
    // The p value handed to MathUtils.pNorm() when measuring variable movement.
    protected float variableMovementNorm = Options.REASONER_VARIABLE_MOVEMENT_NORM.getFloat();
    // Copy of the variable values taken by the previous variable-movement check (null until the first check).
    protected float[] prevVariableValues = null;
    // Per-term-block gradient buffers, cached and reused across parallelComputeGradient() calls.
    protected float[][] workerRVAtomGradients = null;
    protected float[][] workerDeepGradients = null;

    /**
     * Convenience overload: runs {@link #optimize(TermStore, List, TrainingMap)}
     * with no evaluation instances and no training map.
     *
     * @param termStore the terms to optimize over.
     * @return whatever the three-argument overload returns.
     */
    public double optimize(TermStore<T> termStore) {
        final List<EvaluationInstance> noEvaluations = null;
        final TrainingMap noTrainingMap = null;

        return this.optimize(termStore, noEvaluations, noTrainingMap);
    }

    /**
     * Run the optimization over the given term store.
     * Implemented by concrete reasoners.
     *
     * The second and third arguments may be null; the single-argument overload passes null for both.
     * NOTE(review): the meaning of the returned double is defined by the implementations
     * (presumably the final objective value -- confirm against subclasses).
     */
    public abstract double optimize(TermStore<T> var1, List<EvaluationInstance> var2, TrainingMap var3);

    /**
     * Release any resources held by this reasoner.
     * The base implementation does nothing.
     */
    public void close() {
    }

    /**
     * Set the budget, the factor that breakOptimization() multiplies with maxIterations
     * to get the iteration cap.
     *
     * @param budget the new multiplier (stored as given, no validation).
     */
    public void setBudget(double budget) {
        this.budget = budget;
    }

    /**
     * Log the size of the problem and, when trace logging is on, the objective before any iteration has run.
     *
     * @param termStore the terms about to be optimized.
     */
    protected void initForOptimization(TermStore<T> termStore) {
        log.debug("Performing optimization with {} variables and {} terms.", termStore.getVariableCounts(), termStore.size());

        // The initial objective is only needed for the trace message, so skip the computation otherwise.
        if (!log.isTraceEnabled()) {
            return;
        }

        ObjectiveResult initialObjective;
        if (termStore instanceof StreamingTermStore) {
            // Streaming stores go through the serial (iterator based) computation.
            initialObjective = this.computeObjective(termStore);
        } else {
            initialObjective = this.parallelComputeObjective(termStore);
        }

        log.trace("Iteration {} -- Objective: {}, Violated Constraints: {}, Total Optimization Time: {}, Total Number of Iterations: {}.", 0, Float.valueOf(initialObjective.objective), initialObjective.violatedConstraints, 0, 0);
    }

    /**
     * Finish an optimization run: sync the term store and log the final statistics.
     *
     * @param termStore the optimized term store; sync() is called on it before anything is logged.
     * @param finalObjective the objective and violated constraint count to report.
     * @param totalTime the total optimization time to report.
     * @param iteration the number of iterations to report.
     */
    protected void optimizationComplete(TermStore<T> termStore, ObjectiveResult finalObjective, long totalTime, int iteration) {
        // The value returned by sync() is reported below as the movement from the initial state.
        float variableMovement = (float)termStore.sync();

        log.info("Final Objective: {}, Violated Constraints: {}, Total Optimization Time: {}, Total Number of Iterations: {}",
                Float.valueOf(finalObjective.objective),
                finalObjective.violatedConstraints,
                totalTime,
                iteration);
        log.debug("Movement of variables from initial state: {}", Float.valueOf(variableMovement));
    }

    /**
     * Decide whether the optimization loop should stop.
     *
     * Checks, in order:
     * the iteration cap (maxIterations scaled by budget),
     * the run-full-iterations flag (never stop early),
     * outstanding violated constraints (never stop early),
     * the objective change (when objectiveBreak is set),
     * and the variable movement (when variableMovementBreak is set).
     *
     * @param iteration the current iteration count.
     * @param termStore the term store whose variable values are used for the movement check.
     * @param objective the current objective, may be null.
     * @param oldObjective the previous objective, may be null.
     * @return true if optimization should stop.
     */
    protected boolean breakOptimization(int iteration, TermStore<T> termStore, ObjectiveResult objective, ObjectiveResult oldObjective) {
        int iterationCap = (int)((double)this.maxIterations * this.budget);
        if (iteration > iterationCap) {
            log.trace("Breaking optimization. Max iterations exceeded.");
            return true;
        }

        if (this.runFullIterations) {
            return false;
        }

        // Do not stop while constraints are still violated.
        if (objective != null && objective.violatedConstraints > 0L) {
            return false;
        }

        if (this.objectiveBreak && objective != null && oldObjective != null) {
            if (MathUtils.equals(objective.objective, oldObjective.objective, this.objectiveTolerance)) {
                float objectiveChange = Math.abs(objective.objective - oldObjective.objective);
                log.trace("Breaking optimization. Objective change: {} below tolerance: {}.", Float.valueOf(objectiveChange), Float.valueOf(this.objectiveTolerance));
                return true;
            }
        }

        if (!this.variableMovementBreak) {
            return false;
        }

        float[] currentValues = termStore.getVariableValues();
        float[] previousValues = this.prevVariableValues;

        if (previousValues != null) {
            float[] delta = new float[previousValues.length];
            for (int index = 0; index < previousValues.length; index++) {
                delta[index] = previousValues[index] - currentValues[index];
            }

            float movementNorm = MathUtils.pNorm(delta, this.variableMovementNorm);
            if (movementNorm < this.variableMovementTolerance) {
                log.trace("Breaking optimization. Movement of variables: {} below tolerance: {}.", Float.valueOf(movementNorm), Float.valueOf(this.variableMovementTolerance));
                return true;
            }
        }

        // Keep a private snapshot for the next call.
        this.prevVariableValues = currentValues.clone();

        return false;
    }

    /**
     * Compute the gradient with respect to the atoms and write it into the supplied arrays.
     * This implementation delegates directly to parallelComputeGradient().
     *
     * @param termStore the terms to compute the gradient over.
     * @param rvAtomGradient output array, overwritten by the call.
     * @param deepAtomGradient output array, overwritten by the call.
     */
    public void computeOptimalValueGradient(TermStore<T> termStore, float[] rvAtomGradient, float[] deepAtomGradient) {
        this.parallelComputeGradient(termStore, rvAtomGradient, deepAtomGradient);
    }

    /**
     * Compute the gradient over all terms in parallel.
     *
     * The terms are split into blocks, a GradientWorker accumulates each block into its own
     * row of the cached worker buffers, and the rows are then summed into the output arrays.
     * Both output arrays are zeroed before the sum, and both are indexed from 0 up to
     * termStore.getNumVariables(), so each must be at least that long.
     *
     * The worker buffers are cached on this instance and only reallocated when they have
     * too few rows or rows that are too short.
     * An empty term store (zero blocks) is handled: the outputs are simply zeroed.
     *
     * @param termStore the terms to compute the gradient over.
     * @param rvAtomGradient output array, overwritten by the call.
     * @param deepAtomGradient output array, overwritten by the call.
     */
    public void parallelComputeGradient(TermStore termStore, float[] rvAtomGradient, float[] deepAtomGradient) {
        int blockSize = (int)(termStore.size() / (long)(Parallel.getNumThreads() * 4) + 1L);
        int numTermBlocks = (int)Math.ceil((double)termStore.size() / (double)blockSize);

        boolean reallocate = gradientBuffersTooSmall(this.workerRVAtomGradients, numTermBlocks, rvAtomGradient.length)
                || gradientBuffersTooSmall(this.workerDeepGradients, numTermBlocks, deepAtomGradient.length);

        if (reallocate) {
            this.workerRVAtomGradients = new float[numTermBlocks][];
            this.workerDeepGradients = new float[numTermBlocks][];
            for (int i = 0; i < numTermBlocks; ++i) {
                this.workerRVAtomGradients[i] = new float[rvAtomGradient.length];
                this.workerDeepGradients[i] = new float[deepAtomGradient.length];
            }
        }

        Parallel.count(numTermBlocks, new GradientWorker(termStore, this.workerRVAtomGradients, this.workerDeepGradients, blockSize));

        Arrays.fill(rvAtomGradient, 0.0f);
        Arrays.fill(deepAtomGradient, 0.0f);

        // The variable count does not change during the reduction, so read it once.
        int numVariables = termStore.getNumVariables();

        // Only the first numTermBlocks rows were written by this call; cached buffers may hold more rows.
        for (int j = 0; j < numTermBlocks; ++j) {
            float[] rvBlockGradient = this.workerRVAtomGradients[j];
            float[] deepBlockGradient = this.workerDeepGradients[j];

            for (int i = 0; i < numVariables; ++i) {
                rvAtomGradient[i] += rvBlockGradient[i];
                deepAtomGradient[i] += deepBlockGradient[i];
            }
        }
    }

    /**
     * Check whether a cached set of per-block gradient buffers must be reallocated.
     * A buffer set with no rows is never inspected for its row width,
     * since it can only be sufficient when zero blocks are requested.
     */
    private static boolean gradientBuffersTooSmall(float[][] buffers, int numBlocks, int width) {
        if (buffers == null || buffers.length < numBlocks) {
            return true;
        }

        return buffers.length > 0 && buffers[0].length < width;
    }

    /**
     * Zero out, in place, every gradient component that points out of the [0, 1] box:
     * a positive component for a variable at 0, or a negative component for a variable at 1.
     * Variable values are compared to the bounds with MathUtils.equals().
     *
     * @param variableValues the current variable values, read at every index of gradient.
     * @param gradient the gradient to clip; modified in place.
     */
    protected void clipGradient(float[] variableValues, float[] gradient) {
        for (int index = 0; index < gradient.length; index++) {
            if (MathUtils.equals(variableValues[index], 0.0f) && gradient[index] > 0.0f) {
                gradient[index] = 0.0f;
            } else if (MathUtils.equals(variableValues[index], 1.0f) && gradient[index] < 0.0f) {
                gradient[index] = 0.0f;
            }
        }
    }

    /**
     * Serially compute the objective by iterating over the term store.
     *
     * Inactive terms are skipped.
     * A constraint term counts as violated when it evaluates to a value strictly greater than zero;
     * every other active term adds its evaluation to the objective.
     *
     * @param termStore the terms to evaluate.
     * @return the summed objective and the number of violated constraints.
     */
    protected ObjectiveResult computeObjective(TermStore<T> termStore) {
        final float[] values = termStore.getVariableValues();

        float totalObjective = 0.0f;
        long numViolated = 0L;

        for (ReasonerTerm term : termStore) {
            if (!term.isActive()) {
                continue;
            }

            boolean constraint = term.isConstraint();
            float termValue = term.evaluate(values);

            if (!constraint) {
                totalObjective += termValue;
            } else if (termValue > 0.0f) {
                numViolated++;
            }
        }

        return new ObjectiveResult(totalObjective, numViolated);
    }

    /**
     * Compute the objective in parallel.
     * The terms are split into blocks, an ObjectiveWorker evaluates each block,
     * and the per-block results are summed.
     * Must not be called with a StreamingTermStore (checked with an assert).
     *
     * @param termStore the terms to evaluate.
     * @return the summed objective and the number of violated constraints.
     */
    protected ObjectiveResult parallelComputeObjective(TermStore<T> termStore) {
        assert (!(termStore instanceof StreamingTermStore));

        int blockSize = (int)(termStore.size() / (long)(Parallel.getNumThreads() * 4) + 1L);
        int numTermBlocks = (int)Math.ceil((double)termStore.size() / (double)blockSize);

        // One slot per block, each written by exactly one work() call.
        float[] blockObjectives = new float[numTermBlocks];
        int[] blockViolations = new int[numTermBlocks];

        ObjectiveWorker worker = new ObjectiveWorker(termStore, blockObjectives, blockViolations, blockSize);
        Parallel.count(numTermBlocks, worker);

        float totalObjective = 0.0f;
        int totalViolations = 0;
        for (int block = 0; block < numTermBlocks; block++) {
            totalObjective += blockObjectives[block];
            totalViolations += blockViolations[block];
        }

        return new ObjectiveResult(totalObjective, totalViolations);
    }

    /**
     * Run every supplied evaluation and log its output.
     * Does nothing when evaluation is disabled, or when there is no training map or no evaluations.
     * The term store is synced once before the evaluations are computed.
     *
     * @param termStore the term store to sync before evaluating.
     * @param iteration the iteration number used in the log message.
     * @param evaluations the evaluations to compute, may be null or empty.
     * @param trainingMap the training map handed to each evaluation, may be null.
     */
    protected void evaluate(TermStore<T> termStore, int iteration, List<EvaluationInstance> evaluations, TrainingMap trainingMap) {
        if (!this.evaluate || trainingMap == null || evaluations == null || evaluations.isEmpty()) {
            return;
        }

        termStore.sync();

        for (EvaluationInstance instance : evaluations) {
            instance.compute(trainingMap);
            log.info("Iteration {} -- {}.", iteration, instance.getOutput());
        }
    }

    /**
     * Simple holder pairing an objective value with a count of violated constraints.
     * Both fields are public and mutable.
     */
    protected static class ObjectiveResult {
        // The summed value of the evaluated (non-constraint) terms.
        public float objective;
        // The number of constraint terms counted as violated.
        public long violatedConstraints;

        public ObjectiveResult(float objective, long violatedConstraints) {
            this.objective = objective;
            this.violatedConstraints = violatedConstraints;
        }
    }

    /**
     * Parallel worker that evaluates one block of terms per work() call.
     * Each call writes the block's objective and violated constraint count
     * into the slot of the shared output arrays indexed by the block index.
     */
    private static class ObjectiveWorker extends Parallel.Worker<Long> {
        private final TermStore<? extends ReasonerTerm> termStore;
        private final float[] variableValues;
        private final float[] objectives;
        private final int[] violatedConstraints;
        private final int blockSize;

        public ObjectiveWorker(TermStore<? extends ReasonerTerm> termStore, float[] objectives, int[] violatedConstraints, int blockSize) {
            this.termStore = termStore;
            this.variableValues = termStore.getVariableValues();
            this.objectives = objectives;
            this.violatedConstraints = violatedConstraints;
            this.blockSize = blockSize;
        }

        /**
         * Build a new worker over the same term store and the same output arrays.
         */
        @Override
        public Object clone() {
            return new ObjectiveWorker(this.termStore, this.objectives, this.violatedConstraints, this.blockSize);
        }

        /**
         * Evaluate the terms of one block.
         * Inactive terms are skipped, a constraint counts as violated when its evaluation is
         * not zero according to MathUtils.isZero(), and all other terms add to the objective.
         */
        @Override
        public void work(long blockIndex, Long ignore) {
            final int totalTerms = (int)this.termStore.size();
            final long blockStart = blockIndex * (long)this.blockSize;

            float blockObjective = 0.0f;
            int blockViolations = 0;

            for (int offset = 0; offset < this.blockSize; offset++) {
                int termIndex = (int)(blockStart + (long)offset);
                if (termIndex >= totalTerms) {
                    // The last block may be only partially filled.
                    break;
                }

                ReasonerTerm term = this.termStore.get(termIndex);
                if (!term.isActive()) {
                    continue;
                }

                if (term.isConstraint()) {
                    if (!MathUtils.isZero(term.evaluate(this.variableValues))) {
                        blockViolations++;
                    }
                } else {
                    blockObjective += term.evaluate(this.variableValues);
                }
            }

            this.objectives[(int)blockIndex] = blockObjective;
            this.violatedConstraints[(int)blockIndex] = blockViolations;
        }
    }

    /**
     * Parallel worker that accumulates the gradient contribution of one block of terms per work() call.
     * Each call zeroes and then fills the row of the shared gradient buffers indexed by the block index.
     */
    private static class GradientWorker extends Parallel.Worker<Long> {
        private final TermStore termStore;
        private final float[] variableValues;
        private final GroundAtom[] variableAtoms;
        private final float[][] rvAtomGradients;
        private final float[][] deepAtomGradients;
        private final int blockSize;

        public GradientWorker(TermStore termStore, float[][] rvAtomGradients, float[][] deepAtomGradients, int blockSize) {
            this.termStore = termStore;
            this.variableValues = termStore.getVariableValues();
            this.variableAtoms = termStore.getVariableAtoms();
            this.rvAtomGradients = rvAtomGradients;
            this.deepAtomGradients = deepAtomGradients;
            this.blockSize = blockSize;
        }

        /**
         * Build a new worker over the same term store and the same gradient buffers.
         */
        @Override
        public Object clone() {
            return new GradientWorker(this.termStore, this.rvAtomGradients, this.deepAtomGradients, this.blockSize);
        }

        /**
         * Accumulate the partials of every active, non-constraint term in the block.
         * Observed atoms are skipped, atoms with a DeepPredicate go to the deep gradient row,
         * and all other atoms go to the RV atom gradient row.
         */
        @Override
        public void work(long blockIndex, Long ignore) {
            final long totalTerms = this.termStore.size();

            final float[] rvRow = this.rvAtomGradients[(int)blockIndex];
            Arrays.fill(rvRow, 0.0f);
            final float[] deepRow = this.deepAtomGradients[(int)blockIndex];
            Arrays.fill(deepRow, 0.0f);

            final long blockStart = blockIndex * (long)this.blockSize;

            for (int offset = 0; offset < this.blockSize; offset++) {
                int termIndex = (int)(blockStart + (long)offset);
                if ((long)termIndex >= totalTerms) {
                    // The last block may be only partially filled.
                    break;
                }

                ReasonerTerm term = (ReasonerTerm)this.termStore.get(termIndex);
                if (!term.isActive() || term.isConstraint()) {
                    continue;
                }

                int[] atomIndexes = term.getAtomIndexes();
                float innerPotential = term.computeInnerPotential(this.variableValues);

                for (int i = 0; i < term.size(); i++) {
                    int atomIndex = atomIndexes[i];
                    GroundAtom atom = this.variableAtoms[atomIndex];

                    if (atom instanceof ObservedAtom) {
                        continue;
                    }

                    float[] targetRow = (atom.getPredicate() instanceof DeepPredicate) ? deepRow : rvRow;
                    targetRow[atomIndex] += term.computeVariablePartial(i, innerPotential);
                }
            }
        }
    }
}

