/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight.gradient;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.linqs.psl.application.inference.InferenceApplication;
import org.linqs.psl.application.learning.weight.WeightLearningApplication;
import org.linqs.psl.config.Options;
import org.linqs.psl.database.AtomStore;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.deep.DeepModelPredicate;
import org.linqs.psl.model.predicate.DeepPredicate;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.reasoner.InitialValue;
import org.linqs.psl.reasoner.term.ReasonerTerm;
import org.linqs.psl.reasoner.term.TermState;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.MathUtils;

public abstract class GradientDescent
extends WeightLearningApplication {
    private static final Logger log = Logger.getLogger(GradientDescent.class);
    // Which weight update rule to use; parsed from the (upper-cased) WLA_GRADIENT_DESCENT_EXTENSION option.
    protected GDExtension gdExtension = GDExtension.valueOf(Options.WLA_GRADIENT_DESCENT_EXTENSION.getString().toUpperCase());
    // Position of each mutable rule in mutableRules (filled in by the constructor).
    protected Map<WeightedRule, Integer> ruleIndexMap = new HashMap<WeightedRule, Integer>(this.mutableRules.size());
    // Gradient of the objective with respect to each mutable rule's weight (one slot per mutable rule).
    protected float[] weightGradient;
    // Atom-indexed gradient buffers; null until postInitGroundModel() sizes them to the train atom value array.
    // deepAtomGradient is handed to the deep predicates in atomGradientStep() and is part of computeGradientNorm().
    protected float[] rvAtomGradient;
    protected float[] deepAtomGradient;
    protected float[] MAPRVAtomGradient;
    protected float[] MAPDeepAtomGradient;
    // Warm-start state (term state plus atom values) reused between MAP inference calls, kept separately
    // for the train and validation inference applications.
    protected TermState[] trainMAPTermState;
    protected float[] trainMAPAtomValueState;
    protected TermState[] validationMAPTermState;
    protected float[] validationMAPAtomValueState;
    // Validation tracking: whether to keep the best weights, the weights seen at the best validation
    // metric so far, and the current/best metric values (higher is treated as better).
    protected boolean saveBestValidationWeights;
    protected float[] bestValidationWeights;
    double currentValidationEvaluationMetric;
    double bestValidationEvaluationMetric;
    // Step size settings: base step and whether it is divided by (iteration + 1).
    protected float baseStepSize;
    protected boolean scaleStepSize;
    // Gradient clipping: the magnitude limit and the p used for the p-norm that measures the magnitude.
    protected float maxGradientMagnitude;
    protected float maxGradientNorm;
    // The p used for the p-norm of the boundary-clipped weight gradient in plain gradient descent.
    protected float stoppingGradientNorm;
    protected boolean clipWeightGradient;
    // Stopping criteria consulted by breakOptimization().
    protected int maxNumSteps;
    protected boolean runFullIterations;
    protected boolean movementBreak;
    // Total parameter movement produced by the most recent gradientStep().
    protected float parameterMovement;
    protected float movementTolerance;
    protected boolean normBreak;
    protected float normTolerance;
    protected boolean objectiveBreak;
    protected float objectiveTolerance;
    // Coefficients of the L2, negative-log, and entropy regularizers on the rule weights.
    protected float l2Regularization;
    protected float logRegularization;
    protected float entropyRegularization;

    public GradientDescent(List<Rule> rules, Database trainTargetDatabase, Database trainTruthDatabase, Database validationTargetDatabase, Database validationTruthDatabase, boolean runValidation) {
        super(rules, trainTargetDatabase, trainTruthDatabase, validationTargetDatabase, validationTruthDatabase, runValidation);
        for (int i = 0; i < this.mutableRules.size(); ++i) {
            this.ruleIndexMap.put((WeightedRule)this.mutableRules.get(i), i);
        }
        this.weightGradient = new float[this.mutableRules.size()];
        this.rvAtomGradient = null;
        this.deepAtomGradient = null;
        this.MAPRVAtomGradient = null;
        this.MAPDeepAtomGradient = null;
        this.trainMAPTermState = null;
        this.trainMAPAtomValueState = null;
        this.validationMAPTermState = null;
        this.validationMAPAtomValueState = null;
        this.saveBestValidationWeights = Options.WLA_GRADIENT_DESCENT_SAVE_BEST_VALIDATION_WEIGHTS.getBoolean();
        this.bestValidationWeights = null;
        this.currentValidationEvaluationMetric = Double.NEGATIVE_INFINITY;
        this.bestValidationEvaluationMetric = Double.NEGATIVE_INFINITY;
        if (this.saveBestValidationWeights && !this.runValidation) {
            throw new IllegalArgumentException("If saveBestValidationWeights is true, then runValidation must also be true.");
        }
        this.baseStepSize = Options.WLA_GRADIENT_DESCENT_STEP_SIZE.getFloat();
        this.scaleStepSize = Options.WLA_GRADIENT_DESCENT_SCALE_STEP.getBoolean();
        this.clipWeightGradient = Options.WLA_GRADIENT_DESCENT_CLIP_GRADIENT.getBoolean();
        this.maxGradientMagnitude = Options.WLA_GRADIENT_DESCENT_MAX_GRADIENT.getFloat();
        this.maxGradientNorm = Options.WLA_GRADIENT_DESCENT_MAX_GRADIENT_NORM.getFloat();
        this.maxNumSteps = Options.WLA_GRADIENT_DESCENT_NUM_STEPS.getInt();
        this.runFullIterations = Options.WLA_GRADIENT_DESCENT_RUN_FULL_ITERATIONS.getBoolean();
        this.movementBreak = Options.WLA_GRADIENT_DESCENT_MOVEMENT_BREAK.getBoolean();
        this.parameterMovement = Float.POSITIVE_INFINITY;
        this.movementTolerance = Options.WLA_GRADIENT_DESCENT_MOVEMENT_TOLERANCE.getFloat();
        this.normBreak = Options.WLA_GRADIENT_DESCENT_NORM_BREAK.getBoolean();
        this.normTolerance = Options.WLA_GRADIENT_DESCENT_NORM_TOLERANCE.getFloat();
        this.objectiveBreak = Options.WLA_GRADIENT_DESCENT_OBJECTIVE_BREAK.getBoolean();
        this.objectiveTolerance = Options.WLA_GRADIENT_DESCENT_OBJECTIVE_TOLERANCE.getFloat();
        this.stoppingGradientNorm = Options.WLA_GRADIENT_DESCENT_STOPPING_GRADIENT_NORM.getFloat();
        this.l2Regularization = Options.WLA_GRADIENT_DESCENT_L2_REGULARIZATION.getFloat();
        this.logRegularization = Options.WLA_GRADIENT_DESCENT_LOG_REGULARIZATION.getFloat();
        this.entropyRegularization = Options.WLA_GRADIENT_DESCENT_ENTROPY_REGULARIZATION.getFloat();
    }

    /**
     * Validates the validation setup and allocates everything that depends on the ground model:
     * the MAP warm-start snapshots for both inference applications and the atom gradient buffers.
     *
     * @throws IllegalArgumentException if validation is enabled but no evaluator is available.
     * @throws IllegalStateException if validation is enabled but the validation atom store is empty.
     */
    @Override
    protected void postInitGroundModel() {
        super.postInitGroundModel();

        if (runValidation) {
            if (evaluation == null) {
                throw new IllegalArgumentException("If validation is being run, then an evaluator must be specified for predicates.");
            }
            if (validationInferenceApplication.getDatabase().getAtomStore().size() <= 0) {
                throw new IllegalStateException("If validation is being run, then validation data must be provided in the runtime.json file.");
            }
        }

        // NOTE(review): presumably makes inference start from the atoms' current values so that
        // warm-start values survive a reset -- confirm against InferenceApplication.
        trainInferenceApplication.setInitialValue(InitialValue.ATOM);
        validationInferenceApplication.setInitialValue(InitialValue.ATOM);

        // Snapshot the term state of both applications for later warm starts.
        trainMAPTermState = trainInferenceApplication.getTermStore().saveState();
        validationMAPTermState = validationInferenceApplication.getTermStore().saveState();

        // Snapshot the atom values as private copies so later inference cannot alias them.
        float[] trainAtomValues = trainInferenceApplication.getDatabase().getAtomStore().getAtomValues();
        trainMAPAtomValueState = trainAtomValues.clone();
        float[] validationAtomValues = validationInferenceApplication.getDatabase().getAtomStore().getAtomValues();
        validationMAPAtomValueState = validationAtomValues.clone();

        // One gradient slot per entry of the train atom value array.
        final int numTrainAtomValues = trainAtomValues.length;
        rvAtomGradient = new float[numTrainAtomValues];
        deepAtomGradient = new float[numTrainAtomValues];
        MAPRVAtomGradient = new float[numTrainAtomValues];
        MAPDeepAtomGradient = new float[numTrainAtomValues];
    }

    /**
     * Prepares the learner for a fresh run of doLearn().
     *
     * For the simplex-constrained extensions the weights are first rescaled to sum to one.
     * Every deep predicate then produces predictions, the current weights are recorded as the
     * best validation weights seen so far, and the validation metrics are reset.
     */
    protected void initForLearning() {
        switch (gdExtension) {
            case PROJECTED_GRADIENT:
            case MIRROR_DESCENT:
                simplexScaleWeights();
                break;
            default:
                // Unconstrained descent needs no rescaling.
                break;
        }

        for (DeepPredicate deepPredicate : deepPredicates) {
            deepPredicate.predictDeepModel(true);
        }

        // Until a validation pass says otherwise, the starting weights are the best known ones.
        float[] startingWeights = new float[mutableRules.size()];
        int slot = 0;
        for (WeightedRule mutableRule : mutableRules) {
            startingWeights[slot++] = mutableRule.getWeight();
        }
        bestValidationWeights = startingWeights;

        currentValidationEvaluationMetric = Double.NEGATIVE_INFINITY;
        bestValidationEvaluationMetric = Double.NEGATIVE_INFINITY;
    }

    /**
     * Runs the gradient descent loop until breakOptimization() reports convergence.
     *
     * Each iteration optionally evaluates the training and validation MAP states, computes the
     * iteration statistics, the total loss, and the weight and atom gradients, optionally clips
     * the weight gradient, and takes one gradient step. After the loop the best validation
     * weights are restored (when requested), final metrics are logged, and every deep predicate
     * is saved and closed.
     */
    @Override
    protected void doLearn() {
        boolean breakGD = false;
        float objective = 0.0f;
        float oldObjective = Float.POSITIVE_INFINITY;

        log.info("Gradient Descent Weight Learning Start.");
        initForLearning();

        long totalTime = 0L;
        int iteration = 0;
        while (!breakGD) {
            long start = System.currentTimeMillis();

            log.trace("Model:");
            for (WeightedRule weightedRule : mutableRules) {
                log.trace("{}", weightedRule);
            }

            // Training evaluation is only for diagnostics, so skip the inference unless it will be logged.
            if (log.isTraceEnabled() && evaluation != null) {
                runMAPEvaluation();
                log.trace("MAP State Training Evaluation Metric: {}", evaluation.getNormalizedRepMetric());
            }

            if (runValidation) {
                runValidationEvaluation();
                log.debug("Current MAP State Validation Evaluation Metric: {}", currentValidationEvaluationMetric);
            }

            computeIterationStatistics();
            objective = computeTotalLoss();
            computeTotalWeightGradient();
            computeTotalAtomGradient();
            if (clipWeightGradient) {
                clipWeightGradient();
            }
            gradientStep(iteration);

            long end = System.currentTimeMillis();
            totalTime += end - start;

            // Compare against the PREVIOUS iteration's objective, and only then roll it forward.
            // Updating oldObjective first would make the objective-change test compare a value with itself.
            breakGD = breakOptimization(iteration, objective, oldObjective);
            oldObjective = objective;

            // The gradient norm is only needed for this message, so do not compute it when trace is off.
            if (log.isTraceEnabled()) {
                log.trace("Iteration {} -- Weight Learning Objective: {}, Gradient Magnitude: {}, Parameter Movement: {}, Iteration Time: {}", iteration, objective, computeGradientNorm(), parameterMovement, end - start);
            }

            ++iteration;
        }

        log.info("Gradient Descent Weight Learning Finished.");

        if (saveBestValidationWeights) {
            // Restore the weights recorded by runValidationEvaluation() at the best validation metric.
            for (int i = 0; i < mutableRules.size(); ++i) {
                mutableRules.get(i).setWeight(bestValidationWeights[i]);
            }

            // The weights changed, so any cached MAP state is stale (same bookkeeping as weightGradientStep()).
            inTrainingMAPState = false;
            inValidationMAPState = false;
        }

        if (evaluation != null) {
            runMAPEvaluation();
            log.info("Final MAP State Evaluation Metric: {}", evaluation.getNormalizedRepMetric());
        }

        if (runValidation) {
            runValidationEvaluation();
            log.info("Final MAP State Validation Evaluation Metric: {}", evaluation.getNormalizedRepMetric());
        }

        log.info("Final Model {} ", mutableRules);
        log.info("Final Weight Learning Loss: {}, Final Gradient Magnitude: {}, Total optimization time: {}", computeTotalLoss(), computeGradientNorm(), totalTime);

        for (DeepPredicate deepPredicate : deepPredicates) {
            deepPredicate.saveDeepModel();
            deepPredicate.close();
        }
    }

    /**
     * Computes the warm-started MAP state of the training model and evaluates it against the
     * training map. Afterwards every deep predicate is put through evalDeepModel(), and then
     * every deep predicate produces predictions again.
     */
    protected void runMAPEvaluation() {
        log.trace("Running MAP Inference.");
        computeMAPStateWithWarmStart(trainInferenceApplication, trainMAPTermState, trainMAPAtomValueState);
        evaluation.compute(trainingMap);

        // Two separate passes: all models are evaluated before any of them predicts again.
        final int numDeepPredicates = deepPredicates.size();
        for (int i = 0; i < numDeepPredicates; ++i) {
            deepPredicates.get(i).evalDeepModel();
        }
        for (int i = 0; i < numDeepPredicates; ++i) {
            deepPredicates.get(i).predictDeepModel(true);
        }
    }

    /**
     * Evaluates the current weights on the validation data.
     *
     * Each deep predicate is temporarily switched to its validation deep model, the
     * warm-started validation MAP state is computed and scored, and the current weights are
     * recorded as the best ones when the metric strictly improves. The deep predicates are then
     * switched back to their training deep models.
     */
    protected void runValidationEvaluation() {
        // Swap in the validation deep models.
        for (int i = 0; i < deepPredicates.size(); ++i) {
            DeepPredicate validationPredicate = deepPredicates.get(i);
            validationPredicate.setDeepModel((DeepModelPredicate)validationDeepModelPredicates.get(i));
            validationPredicate.predictDeepModel(false);
        }

        log.trace("Running Validation Inference.");
        computeMAPStateWithWarmStart(validationInferenceApplication, validationMAPTermState, validationMAPAtomValueState);
        evaluation.compute(validationMap);
        currentValidationEvaluationMetric = evaluation.getNormalizedRepMetric();

        boolean improved = currentValidationEvaluationMetric > bestValidationEvaluationMetric;
        if (improved) {
            bestValidationEvaluationMetric = currentValidationEvaluationMetric;
            int slot = 0;
            for (WeightedRule mutableRule : mutableRules) {
                bestValidationWeights[slot] = mutableRule.getWeight();
                slot++;
            }
            log.debug("New Best Validation Model: {}", mutableRules);
        }
        log.debug("MAP State Best Validation Evaluation Metric: {}", bestValidationEvaluationMetric);

        // Swap the training deep models back in.
        for (int i = 0; i < deepPredicates.size(); ++i) {
            DeepPredicate trainPredicate = deepPredicates.get(i);
            trainPredicate.setDeepModel((DeepModelPredicate)deepModelPredicates.get(i));
            trainPredicate.predictDeepModel(true);
        }
    }

    /**
     * Decides whether the descent loop should stop.
     *
     * The step budget is checked first and always applies. When full iterations are requested
     * no other criterion is consulted. Otherwise the enabled criteria are tested in order:
     * parameter movement, gradient norm, and objective change, each against its own tolerance.
     *
     * NOTE(review): doLearn() passes the zero-based index of the step that just ran, so the
     * strict comparison below lets the loop run maxNumSteps + 2 steps -- confirm that is intended.
     *
     * @param iteration the iteration index supplied by the caller.
     * @param objective the objective value of the current iteration.
     * @param oldObjective the objective value to compare against.
     * @return true if optimization should stop.
     */
    protected boolean breakOptimization(int iteration, float objective, float oldObjective) {
        if (iteration > maxNumSteps) {
            log.trace("Breaking Weight Learning. Reached maximum number of iterations: {}", maxNumSteps);
            return true;
        }

        if (runFullIterations) {
            return false;
        }

        if (movementBreak) {
            if (MathUtils.equals(parameterMovement, 0.0f, movementTolerance)) {
                log.trace("Breaking Weight Learning. Parameter Movement: {} is within tolerance: {}", parameterMovement, movementTolerance);
                return true;
            }
        }

        if (normBreak) {
            if (MathUtils.equals(computeGradientNorm(), 0.0f, normTolerance)) {
                log.trace("Breaking Weight Learning. Gradient norm: {} is within tolerance: {}", computeGradientNorm(), normTolerance);
                return true;
            }
        }

        if (objectiveBreak) {
            if (MathUtils.equals(objective, oldObjective, objectiveTolerance)) {
                log.trace("Breaking Weight Learning. Objective change: {} is within tolerance: {}", Math.abs(objective - oldObjective), objectiveTolerance);
                return true;
            }
        }

        return false;
    }

    /**
     * Rescales the weight gradient so that its p-norm (p = maxGradientNorm) does not exceed
     * maxGradientMagnitude. Gradients already within the limit are left untouched.
     */
    private void clipWeightGradient() {
        final float gradientMagnitude = MathUtils.pNorm(weightGradient, maxGradientNorm);
        if (!(gradientMagnitude > maxGradientMagnitude)) {
            return;
        }

        log.trace("Clipping gradient. Original gradient magnitude: {} exceeds limit: {} in L_{} space.", gradientMagnitude, maxGradientMagnitude, maxGradientNorm);

        int index = 0;
        while (index < mutableRules.size()) {
            // Multiply before dividing, exactly like the unclipped-to-clipped ratio is defined.
            weightGradient[index] = maxGradientMagnitude * weightGradient[index] / gradientMagnitude;
            index++;
        }
    }

    /**
     * Takes one full gradient step and records how far the parameters moved.
     *
     * The movement is the sum of the amounts reported by the weight step, the internal
     * parameter step, and the atom step, taken in that order.
     *
     * @param iteration the iteration index forwarded to the individual steps.
     */
    protected void gradientStep(int iteration) {
        // Start from zero so the field only ever reflects the most recent step.
        parameterMovement = 0.0f;

        parameterMovement += weightGradientStep(iteration);
        parameterMovement += internalParameterGradientStep(iteration);
        parameterMovement += atomGradientStep();
    }

    /**
     * Hook for steps on parameters other than rule weights and atoms.
     * This base implementation changes nothing and reports zero movement;
     * gradientStep() adds the returned value to the total parameter movement.
     *
     * @param iteration the iteration index passed in by gradientStep().
     * @return the amount of parameter movement caused by this step (always 0 here).
     */
    protected float internalParameterGradientStep(int iteration) {
        return 0.0f;
    }

    /**
     * Updates every mutable rule weight using the current weight gradient.
     *
     * <ul>
     *   <li>MIRROR_DESCENT: each weight is multiplied by exp(-stepSize * gradient) and the
     *       results are normalized to sum to one.</li>
     *   <li>PROJECTED_GRADIENT: a plain gradient step followed by a projection onto the unit simplex.</li>
     *   <li>Otherwise: a plain gradient step, clipped so that no weight becomes negative.</li>
     * </ul>
     *
     * The clip in the unconstrained case keeps weights in the region the rest of this class
     * assumes: computeGradientDescentNorm() treats weight 0 as a boundary, and
     * computeRegularization() takes the log of the weight, which is NaN for negative values.
     *
     * The cached MAP state flags are cleared because the weights have changed.
     *
     * @param iteration the iteration index, used to compute the step size.
     * @return the summed absolute change over all weights.
     */
    protected float weightGradientStep(int iteration) {
        final int numRules = mutableRules.size();

        float[] oldWeights = new float[numRules];
        for (int i = 0; i < numRules; ++i) {
            oldWeights[i] = mutableRules.get(i).getWeight();
        }

        float stepSize = computeStepSize(iteration);

        switch (gdExtension) {
            case MIRROR_DESCENT: {
                // Normalizer of the exponentiated update (accumulated in double, stored as float).
                float exponentiatedGradientSum = 0.0f;
                for (int j = 0; j < numRules; ++j) {
                    exponentiatedGradientSum += mutableRules.get(j).getWeight() * Math.exp(-1.0f * stepSize * weightGradient[j]);
                }

                for (int j = 0; j < numRules; ++j) {
                    WeightedRule rule = mutableRules.get(j);
                    rule.setWeight((float)(rule.getWeight() * Math.exp(-1.0f * stepSize * weightGradient[j]) / exponentiatedGradientSum));
                }
                break;
            }
            case PROJECTED_GRADIENT: {
                for (int j = 0; j < numRules; ++j) {
                    WeightedRule rule = mutableRules.get(j);
                    rule.setWeight(rule.getWeight() - weightGradient[j] * stepSize);
                }

                // Bring the weights back onto the unit simplex.
                simplexProjectWeights();
                break;
            }
            default: {
                for (int j = 0; j < numRules; ++j) {
                    WeightedRule rule = mutableRules.get(j);
                    // Clip negative weights.
                    rule.setWeight(Math.max(0.0f, rule.getWeight() - weightGradient[j] * stepSize));
                }
                break;
            }
        }

        // Any previously computed MAP state no longer matches the weights.
        inTrainingMAPState = false;
        inValidationMAPState = false;

        float weightChange = 0.0f;
        for (int i = 0; i < numRules; ++i) {
            weightChange += Math.abs(oldWeights[i] - mutableRules.get(i).getWeight());
        }
        return weightChange;
    }

    /**
     * Fits every deep predicate with the deep atom gradient and lets it predict again.
     *
     * @return the sum of the values returned by the predictions, which gradientStep()
     *     counts as parameter movement.
     */
    protected float atomGradientStep() {
        float totalChange = 0.0f;

        for (DeepPredicate predicate : deepPredicates) {
            // Fit first, then predict with the freshly fitted model.
            predicate.fitDeepPredicate(deepAtomGradient);
            totalChange += predicate.predictDeepModel(true);
        }

        return totalChange;
    }

    /**
     * Computes the step size for an iteration.
     *
     * @param iteration the zero-based iteration index.
     * @return the base step size, divided by (iteration + 1) when step scaling is enabled.
     */
    protected float computeStepSize(int iteration) {
        if (!scaleStepSize) {
            return baseStepSize;
        }

        return baseStepSize / (float)(iteration + 1);
    }

    /**
     * Computes the size of the current gradient.
     *
     * The weight part depends on the configured extension; the 2-norm of the deep atom
     * gradient is added on top.
     *
     * @return the extension-specific weight gradient norm plus the deep atom gradient 2-norm.
     */
    protected float computeGradientNorm() {
        final float weightNorm;
        switch (gdExtension) {
            case MIRROR_DESCENT:
                weightNorm = computeMirrorDescentNorm();
                break;
            case PROJECTED_GRADIENT:
                weightNorm = computeProjectedGradientDescentNorm();
                break;
            default:
                weightNorm = computeGradientDescentNorm();
                break;
        }

        return weightNorm + MathUtils.pNorm(deepAtomGradient, 2.0f);
    }

    /**
     * Gradient "norm" used with mirror descent.
     *
     * The weight gradient is mapped through a softmax, and the result is the sum over rules
     * of p * log(p * numRules), i.e. the KL divergence of that distribution from the uniform
     * distribution over the mutable rules.
     *
     * @return the divergence described above (0 when there are no mutable rules).
     */
    private float computeMirrorDescentNorm() {
        final int numRules = mutableRules.size();

        // Softmax normalizer; accumulated in double and stored back as float each step.
        float normalizer = 0.0f;
        for (int i = 0; i < numRules; ++i) {
            normalizer += Math.exp(weightGradient[i]);
        }

        float divergence = 0.0f;
        for (int i = 0; i < numRules; ++i) {
            float probability = (float)Math.exp(weightGradient[i]) / normalizer;
            float logRatio = (float)Math.log(probability * (float)numRules);
            divergence += probability * logRatio;
        }

        return divergence;
    }

    /**
     * Gradient "norm" used with projected gradient descent.
     *
     * When no log regularization is used, gradient entries that push a weight out of [0, 1]
     * (positive gradient at weight 0, negative gradient at weight 1) are clipped to zero, and
     * entries that are zero are left out entirely. The remaining entries are mapped through a
     * softmax and compared to the uniform distribution over the remaining entries, using the
     * sum of p * log(p * count). With log regularization nothing is clipped or left out.
     *
     * @return the divergence described above.
     */
    private float computeProjectedGradientDescentNorm() {
        final boolean unregularized = (logRegularization == 0.0f);
        final int numRules = mutableRules.size();

        float[] simplexClippedGradients = weightGradient.clone();
        int numNonZeroGradients = 0;
        for (int i = 0; i < simplexClippedGradients.length; ++i) {
            boolean pushesBelowZero = unregularized
                    && MathUtils.equalsStrict(mutableRules.get(i).getWeight(), 0.0f)
                    && weightGradient[i] > 0.0f;
            boolean pushesAboveOne = !pushesBelowZero
                    && unregularized
                    && MathUtils.equalsStrict(mutableRules.get(i).getWeight(), 1.0f)
                    && weightGradient[i] < 0.0f;

            if (pushesBelowZero || pushesAboveOne) {
                simplexClippedGradients[i] = 0.0f;
            } else {
                simplexClippedGradients[i] = weightGradient[i];
                if (!(unregularized && MathUtils.isZero((double)simplexClippedGradients[i], 1.0E-8))) {
                    ++numNonZeroGradients;
                }
            }
        }

        // Softmax normalizer over the entries that were kept.
        float normalizer = 0.0f;
        for (int i = 0; i < numRules; ++i) {
            if (unregularized && MathUtils.isZero((double)simplexClippedGradients[i], 1.0E-8)) {
                continue;
            }
            normalizer += Math.exp(weightGradient[i]);
        }

        float divergence = 0.0f;
        for (int i = 0; i < numRules; ++i) {
            if (unregularized && MathUtils.isZero((double)simplexClippedGradients[i], 1.0E-8)) {
                continue;
            }
            float probability = (float)Math.exp(weightGradient[i]) / normalizer;
            float logRatio = (float)Math.log(probability * (float)numNonZeroGradients);
            divergence += probability * logRatio;
        }

        return divergence;
    }

    /**
     * Gradient norm used with plain gradient descent.
     *
     * A positive gradient on a weight that is already zero is ignored, and the p-norm
     * (p = stoppingGradientNorm) of what remains is returned.
     *
     * @return the p-norm of the boundary-clipped weight gradient.
     */
    private float computeGradientDescentNorm() {
        // Start from a copy so only the entries at the boundary need to be touched.
        float[] boundaryClippedGradients = weightGradient.clone();

        for (int i = 0; i < boundaryClippedGradients.length; ++i) {
            boolean atZero = MathUtils.equals(mutableRules.get(i).getWeight(), 0.0f);
            if (atZero && weightGradient[i] > 0.0f) {
                boundaryClippedGradients[i] = 0.0f;
            }
        }

        return MathUtils.pNorm(boundaryClippedGradients, stoppingGradientNorm);
    }

    /**
     * Projects the mutable rule weights onto the unit simplex.
     *
     * The weights are sorted, a threshold tau is found by scanning them from largest to
     * smallest, and every weight is replaced by max(0, weight - tau).
     */
    public void simplexProjectWeights() {
        final int numWeights = mutableRules.size();

        float[] sortedWeights = new float[numWeights];
        int slot = 0;
        for (WeightedRule mutableRule : mutableRules) {
            sortedWeights[slot++] = mutableRule.getWeight();
        }
        Arrays.sort(sortedWeights);

        float tau = 0.0f;
        float cumulativeWeightSum = 0.0f;
        // The array is ascending, so index (numWeights - rank) walks from the largest weight down.
        for (int rank = 1; rank <= numWeights; ++rank) {
            float weight = sortedWeights[numWeights - rank];
            float nextCumulativeWeightSum = cumulativeWeightSum + weight;
            float nextTau = (nextCumulativeWeightSum - 1.0f) / (float)rank;

            // Stop as soon as this weight would no longer stay positive after shifting.
            if (nextTau >= weight) {
                break;
            }

            cumulativeWeightSum = nextCumulativeWeightSum;
            tau = nextTau;
        }

        for (WeightedRule mutableRule : mutableRules) {
            mutableRule.setWeight(Math.max(0.0f, mutableRule.getWeight() - tau));
        }
    }

    /**
     * Rescales the mutable rule weights so that they sum to one.
     *
     * When the weights sum to exactly zero, dividing would turn every weight into NaN
     * (0 / 0), so all weights are set to the uniform value 1 / numRules instead.
     * Nothing happens when there are no mutable rules.
     */
    private void simplexScaleWeights() {
        if (mutableRules.isEmpty()) {
            return;
        }

        float totalWeight = 0.0f;
        for (WeightedRule mutableRule : mutableRules) {
            totalWeight += mutableRule.getWeight();
        }

        if (totalWeight == 0.0f) {
            // No proportions to preserve: start from the center of the simplex.
            float uniformWeight = 1.0f / (float)mutableRules.size();
            log.debug("Rule weights sum to zero. Initializing all weights to: {}", uniformWeight);
            for (WeightedRule mutableRule : mutableRules) {
                mutableRule.setWeight(uniformWeight);
            }
            return;
        }

        for (WeightedRule mutableRule : mutableRules) {
            mutableRule.setWeight(mutableRule.getWeight() / totalWeight);
        }
    }

    /**
     * Computes a MAP state, starting inference from a previously saved state.
     *
     * The saved term state is loaded, every atom that is not fixed is reset to its saved
     * value, and inference is run. The resulting term state and atom values are then written
     * back into the warm-start arguments so the next call starts from this solution.
     *
     * @param inferenceApplication the inference application to run.
     * @param warmStartTermState the saved term state; overwritten with the new state.
     * @param warmStartAtomValueState the saved atom values; overwritten with the new values.
     */
    protected void computeMAPStateWithWarmStart(InferenceApplication inferenceApplication, TermState[] warmStartTermState, float[] warmStartAtomValueState) {
        // Restore the saved starting point.
        inferenceApplication.getTermStore().loadState(warmStartTermState);

        AtomStore atomStore = inferenceApplication.getDatabase().getAtomStore();
        float[] atomValues = atomStore.getAtomValues();
        int atomIndex = 0;
        while (atomIndex < atomStore.size()) {
            // Fixed atoms keep whatever value they currently have.
            if (!atomStore.getAtom(atomIndex).isFixed()) {
                atomValues[atomIndex] = warmStartAtomValueState[atomIndex];
            }
            atomIndex++;
        }
        atomStore.sync();

        computeMAPState(inferenceApplication);

        // Remember the solution for the next warm start. The value array is fetched again
        // after inference rather than reusing the reference obtained before it.
        inferenceApplication.getTermStore().saveState(warmStartTermState);
        float[] mpeAtomValues = inferenceApplication.getDatabase().getAtomStore().getAtomValues();
        System.arraycopy(mpeAtomValues, 0, warmStartAtomValueState, 0, mpeAtomValues.length);
    }

    /**
     * Sums, per mutable rule, the incompatibility of its ground terms at the current training
     * atom values.
     *
     * Terms whose rule is not a WeightedRule, or is not one of the indexed mutable rules,
     * are skipped.
     *
     * @param incompatibilityArray output array indexed like mutableRules; cleared before use.
     */
    protected void computeCurrentIncompatibility(float[] incompatibilityArray) {
        Arrays.fill(incompatibilityArray, 0.0f);

        float[] atomValues = trainInferenceApplication.getDatabase().getAtomStore().getAtomValues();

        for (Object rawTerm : trainInferenceApplication.getTermStore()) {
            ReasonerTerm term = (ReasonerTerm)rawTerm;

            if (!(term.getRule() instanceof WeightedRule)) {
                continue;
            }

            Integer ruleIndex = ruleIndexMap.get((WeightedRule)term.getRule());
            if (ruleIndex == null) {
                continue;
            }

            incompatibilityArray[ruleIndex] += term.evaluateIncompatibility(atomValues);
        }
    }

    /**
     * Computes whatever per-iteration statistics the concrete learner needs.
     * doLearn() calls this once per iteration, before the total loss and the gradients are computed.
     */
    protected abstract void computeIterationStatistics();

    /**
     * Computes the full objective value.
     *
     * @return the learning loss plus the weight regularization.
     */
    protected float computeTotalLoss() {
        final float loss = computeLearningLoss();
        final float penalty = computeRegularization();

        log.trace("Learning Loss: {}, Regularization: {}", loss, penalty);

        return loss + penalty;
    }

    /**
     * Computes the learning loss of the concrete learner, without regularization.
     * computeTotalLoss() adds the regularization term to the returned value.
     *
     * @return the learning loss at the current parameters.
     */
    protected abstract float computeLearningLoss();

    /**
     * Computes the regularization term of the objective.
     *
     * For each mutable rule with weight w this adds
     * l2 * w^2 - logCoefficient * log(w) + entropyCoefficient * w * log(w),
     * where log(w) is bounded below by log(1e-8).
     *
     * @return the summed regularization over all mutable rules.
     */
    protected float computeRegularization() {
        // Lower bound on log(w) so that a zero weight does not produce negative infinity.
        final double minLogWeight = Math.log(1.0E-8);

        float total = 0.0f;
        for (WeightedRule mutableRule : mutableRules) {
            float weight = mutableRule.getWeight();
            float logWeight = (float)Math.max(Math.log(weight), minLogWeight);

            float l2Term = l2Regularization * (float)Math.pow(weight, 2.0);
            float logTerm = logRegularization * logWeight;
            float entropyTerm = entropyRegularization * weight * logWeight;

            total += l2Term - logTerm + entropyTerm;
        }

        return total;
    }

    /**
     * Recomputes the weight gradient from scratch: the buffer is zeroed, then the learning
     * loss gradient and the regularization gradient are accumulated into it, in that order.
     */
    protected void computeTotalWeightGradient() {
        // Both contributions below add into the buffer, so it has to start at zero.
        Arrays.fill(weightGradient, 0.0f);

        addLearningLossWeightGradient();
        addRegularizationWeightGradient();
    }

    /**
     * Contributes the learning loss part of the weight gradient.
     * computeTotalWeightGradient() calls this right after zeroing weightGradient and before
     * the regularization gradient is added; implementations presumably accumulate into
     * weightGradient -- confirm against the concrete learners.
     */
    protected abstract void addLearningLossWeightGradient();

    /**
     * Adds the gradient of the regularization term to the weight gradient.
     *
     * For each mutable rule with weight w this adds
     * 2 * l2 * w - logCoefficient / max(w, 1e-8) + entropyCoefficient * (log(max(w, 1e-8)) + 1).
     */
    protected void addRegularizationWeightGradient() {
        int index = 0;
        for (WeightedRule mutableRule : mutableRules) {
            float weight = mutableRule.getWeight();
            // Keep the weight away from zero before dividing by it or taking its log.
            double boundedWeight = Math.max((double)weight, 1.0E-8);
            float logWeight = (float)Math.log(boundedWeight);

            double l2Term = 2.0f * l2Regularization * weight;
            double logTerm = (double)logRegularization / boundedWeight;
            double entropyTerm = entropyRegularization * (logWeight + 1.0f);

            weightGradient[index] += l2Term - logTerm + entropyTerm;
            index++;
        }
    }

    /**
     * Computes the gradient with respect to atom values for the concrete learner.
     * doLearn() calls this once per iteration, after the total weight gradient and before the
     * gradient step; implementations presumably fill rvAtomGradient and deepAtomGradient --
     * confirm against the concrete learners.
     */
    protected abstract void computeTotalAtomGradient();

    /**
     * The weight update rules supported by this learner.
     * The active value is parsed from the WLA_GRADIENT_DESCENT_EXTENSION option.
     */
    public static enum GDExtension {
        // Multiplicative (exponentiated gradient) update, normalized so the weights sum to one.
        MIRROR_DESCENT,
        // Plain gradient step followed by a projection of the weights onto the unit simplex.
        PROJECTED_GRADIENT,
        // No extension: handled by the default branches, i.e. a plain gradient step.
        NONE;

    }
}

