/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight.gradient.optimalvalue;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.linqs.psl.application.learning.weight.gradient.GradientDescent;
import org.linqs.psl.database.AtomStore;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.reasoner.term.TermState;
import org.linqs.psl.util.Logger;

public abstract class OptimalValue
extends GradientDescent {
    private static final Logger log = Logger.getLogger(GradientDescent.class);
    protected float[] latentInferenceIncompatibility;
    protected TermState[] latentInferenceTermState;
    protected float[] latentInferenceAtomValueState;
    protected float[] rvLatentAtomGradient;
    protected float[] deepLatentAtomGradient;

    public OptimalValue(List<Rule> rules, Database trainTargetDatabase, Database trainTruthDatabase, Database validationTargetDatabase, Database validationTruthDatabase, boolean runValidation) {
        super(rules, trainTargetDatabase, trainTruthDatabase, validationTargetDatabase, validationTruthDatabase, runValidation);
        this.latentInferenceIncompatibility = new float[this.mutableRules.size()];
        this.latentInferenceTermState = null;
        this.latentInferenceAtomValueState = null;
    }

    @Override
    protected void postInitGroundModel() {
        super.postInitGroundModel();
        this.latentInferenceTermState = this.trainInferenceApplication.getTermStore().saveState();
        float[] atomValues = this.trainInferenceApplication.getDatabase().getAtomStore().getAtomValues();
        this.latentInferenceAtomValueState = Arrays.copyOf(atomValues, atomValues.length);
        this.rvLatentAtomGradient = new float[atomValues.length];
        this.deepLatentAtomGradient = new float[atomValues.length];
    }

    protected void computeLatentInferenceIncompatibility() {
        this.fixLabeledRandomVariables();
        log.trace("Running Latent Inference.");
        this.computeMAPStateWithWarmStart(this.trainInferenceApplication, this.latentInferenceTermState, this.latentInferenceAtomValueState);
        this.inTrainingMAPState = true;
        this.computeCurrentIncompatibility(this.latentInferenceIncompatibility);
        this.trainInferenceApplication.getReasoner().computeOptimalValueGradient(this.trainInferenceApplication.getTermStore(), this.rvLatentAtomGradient, this.deepLatentAtomGradient);
        for (int i = 0; i < this.mutableRules.size(); ++i) {
            log.trace("Rule: {} , Latent inference incompatibility: {}", this.mutableRules.get(i), Float.valueOf(this.latentInferenceIncompatibility[i]));
        }
        this.unfixLabeledRandomVariables();
    }

    protected void fixLabeledRandomVariables() {
        AtomStore atomStore = this.trainInferenceApplication.getTermStore().getDatabase().getAtomStore();
        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : this.trainingMap.getLabelMap().entrySet()) {
            RandomVariableAtom randomVariableAtom = entry.getKey();
            ObservedAtom observedAtom = entry.getValue();
            int atomIndex = atomStore.getAtomIndex(randomVariableAtom);
            atomStore.getAtoms()[atomIndex] = observedAtom;
            atomStore.getAtomValues()[atomIndex] = observedAtom.getValue();
            this.latentInferenceAtomValueState[atomIndex] = observedAtom.getValue();
            randomVariableAtom.setValue(observedAtom.getValue());
        }
        this.inTrainingMAPState = false;
    }

    protected void unfixLabeledRandomVariables() {
        AtomStore atomStore = this.trainInferenceApplication.getDatabase().getAtomStore();
        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : this.trainingMap.getLabelMap().entrySet()) {
            RandomVariableAtom randomVariableAtom = entry.getKey();
            int atomIndex = atomStore.getAtomIndex(randomVariableAtom);
            atomStore.getAtoms()[atomIndex] = randomVariableAtom;
        }
        this.inTrainingMAPState = false;
    }
}

