/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight;

import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.linqs.psl.application.ModelApplication;
import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.config.Config;
import org.linqs.psl.database.Database;
import org.linqs.psl.database.atom.AtomManager;
import org.linqs.psl.database.atom.PersistedAtomManager;
import org.linqs.psl.evaluation.statistics.ContinuousEvaluator;
import org.linqs.psl.evaluation.statistics.Evaluator;
import org.linqs.psl.grounding.GroundRuleStore;
import org.linqs.psl.grounding.Grounding;
import org.linqs.psl.grounding.MemoryGroundRuleStore;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedGroundRule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.model.rule.misc.GroundValueConstraint;
import org.linqs.psl.reasoner.Reasoner;
import org.linqs.psl.reasoner.admm.ADMMReasoner;
import org.linqs.psl.reasoner.admm.term.ADMMTermGenerator;
import org.linqs.psl.reasoner.admm.term.ADMMTermStore;
import org.linqs.psl.reasoner.term.TermGenerator;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.util.RandUtils;
import org.linqs.psl.util.Reflection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class WeightLearningApplication
implements ModelApplication {
    private static final Logger log = LoggerFactory.getLogger(WeightLearningApplication.class);
    public static final String CONFIG_PREFIX = "weightlearning";
    public static final String REASONER_KEY = "weightlearning.reasoner";
    public static final String REASONER_DEFAULT = ADMMReasoner.class.getName();
    public static final String GROUND_RULE_STORE_KEY = "weightlearning.groundrulestore";
    public static final String GROUND_RULE_STORE_DEFAULT = MemoryGroundRuleStore.class.getName();
    public static final String TERM_STORE_KEY = "weightlearning.termstore";
    public static final String TERM_STORE_DEFAULT = ADMMTermStore.class.getName();
    public static final String TERM_GENERATOR_KEY = "weightlearning.termgenerator";
    public static final String TERM_GENERATOR_DEFAULT = ADMMTermGenerator.class.getName();
    public static final String EVALUATOR_KEY = "weightlearning.evaluator";
    public static final String EVALUATOR_DEFAULT = ContinuousEvaluator.class.getName();
    public static final String RANDOM_WEIGHTS_KEY = "weightlearning.randomweights";
    public static final boolean RANDOM_WEIGHTS_DEFAULT = false;
    public static final int MAX_RANDOM_WEIGHT = 100;
    public static final int MIN_ADMM_STEPS = 3;
    protected boolean supportsLatentVariables;
    protected Database rvDB;
    protected Database observedDB;
    protected PersistedAtomManager atomManager;
    protected List<Rule> allRules;
    protected List<WeightedRule> mutableRules;
    protected double[] observedIncompatibility;
    protected double[] expectedIncompatibility;
    protected TrainingMap trainingMap;
    protected Reasoner reasoner;
    protected GroundRuleStore groundRuleStore;
    protected GroundRuleStore latentGroundRuleStore;
    protected TermGenerator termGenerator;
    protected TermStore termStore;
    protected TermStore latentTermStore;
    protected Evaluator evaluator;
    private boolean groundModelInit;
    protected boolean inMPEState;
    protected boolean inLatentMPEState;

    public WeightLearningApplication(List<Rule> rules, Database rvDB, Database observedDB, boolean supportsLatentVariables) {
        this.rvDB = rvDB;
        this.observedDB = observedDB;
        this.supportsLatentVariables = supportsLatentVariables;
        this.allRules = new ArrayList<Rule>();
        this.mutableRules = new ArrayList<WeightedRule>();
        for (Rule rule : rules) {
            this.allRules.add(rule);
            if (!(rule instanceof WeightedRule)) continue;
            this.mutableRules.add((WeightedRule)rule);
        }
        this.observedIncompatibility = new double[this.mutableRules.size()];
        this.expectedIncompatibility = new double[this.mutableRules.size()];
        this.groundModelInit = false;
        this.inMPEState = false;
        this.inLatentMPEState = false;
        this.evaluator = (Evaluator)Config.getNewObject(EVALUATOR_KEY, EVALUATOR_DEFAULT);
    }

    public void learn() {
        this.initGroundModel();
        if (this.supportsLatentVariables) {
            this.initLatentGroundModel();
        }
        this.doLearn();
    }

    protected abstract void doLearn();

    public void setBudget(double budget) {
        if (this.reasoner instanceof ADMMReasoner) {
            int maxIterations = Config.getInt("admmreasoner.maxiterations", 25000);
            int iterations = (int)Math.ceil((double)maxIterations * budget);
            ((ADMMReasoner)this.reasoner).setMaxIter(Math.max(3, iterations));
            if (this.termStore instanceof ADMMTermStore) {
                ((ADMMTermStore)this.termStore).resetLocalVairables();
            }
        }
    }

    public GroundRuleStore getGroundRuleStore() {
        return this.groundRuleStore;
    }

    protected void initGroundModel() {
        if (this.groundModelInit) {
            return;
        }
        PersistedAtomManager atomManager = this.createAtomManager();
        this.ensureTargets(atomManager);
        GroundRuleStore groundRuleStore = (GroundRuleStore)Config.getNewObject(GROUND_RULE_STORE_KEY, GROUND_RULE_STORE_DEFAULT);
        log.info("Grounding out model.");
        int groundCount = Grounding.groundAll(this.allRules, (AtomManager)atomManager, groundRuleStore);
        this.initGroundModel(atomManager, groundRuleStore);
    }

    public void initGroundModel(GroundRuleStore groundRuleStore) {
        if (this.groundModelInit) {
            return;
        }
        this.initGroundModel(this.createAtomManager(), groundRuleStore);
    }

    private void initGroundModel(PersistedAtomManager atomManager, GroundRuleStore groundRuleStore) {
        if (this.groundModelInit) {
            return;
        }
        TermStore termStore = (TermStore)Config.getNewObject(TERM_STORE_KEY, TERM_STORE_DEFAULT);
        TermGenerator termGenerator = (TermGenerator)Config.getNewObject(TERM_GENERATOR_KEY, TERM_GENERATOR_DEFAULT);
        log.debug("Initializing objective terms for {} ground rules.", (Object)groundRuleStore.size());
        termStore.ensureVariableCapacity(atomManager.getCachedRVACount());
        int termCount = termGenerator.generateTerms(groundRuleStore, termStore);
        log.debug("Generated {} objective terms from {} ground rules.", (Object)termCount, (Object)groundRuleStore.size());
        TrainingMap trainingMap = new TrainingMap(atomManager, this.observedDB, false);
        if (!this.supportsLatentVariables && trainingMap.getLatentVariables().size() > 0) {
            Set<RandomVariableAtom> latentVariables = trainingMap.getLatentVariables();
            throw new IllegalArgumentException(String.format("All RandomVariableAtoms must have corresponding ObservedAtoms, found %d latent variables. Latent variables are not supported by this WeightLearningApplication (%s). Example latent variable: [%s].", latentVariables.size(), this.getClass().getName(), latentVariables.iterator().next()));
        }
        Reasoner reasoner = (Reasoner)Config.getNewObject(REASONER_KEY, REASONER_DEFAULT);
        this.initGroundModel(reasoner, groundRuleStore, termStore, termGenerator, atomManager, trainingMap);
    }

    public void initGroundModel(Reasoner reasoner, GroundRuleStore groundRuleStore, TermStore termStore, TermGenerator termGenerator, PersistedAtomManager atomManager, TrainingMap trainingMap) {
        if (this.groundModelInit) {
            return;
        }
        this.reasoner = reasoner;
        this.groundRuleStore = groundRuleStore;
        this.termStore = termStore;
        this.termGenerator = termGenerator;
        this.atomManager = atomManager;
        this.trainingMap = trainingMap;
        if (Config.getBoolean(RANDOM_WEIGHTS_KEY, false)) {
            this.initRandomWeights();
        }
        this.postInitGroundModel();
        this.groundModelInit = true;
    }

    private void initRandomWeights() {
        log.trace("Randomly Weighted Rules:");
        for (WeightedRule rule : this.mutableRules) {
            rule.setWeight(RandUtils.nextInt(100) + 1);
            log.trace("    " + rule.toString());
        }
    }

    protected void postInitGroundModel() {
    }

    protected void initLatentGroundModel() {
        this.latentGroundRuleStore = (GroundRuleStore)Config.getNewObject(GROUND_RULE_STORE_KEY, GROUND_RULE_STORE_DEFAULT);
        this.latentTermStore = (TermStore)Config.getNewObject(TERM_STORE_KEY, TERM_STORE_DEFAULT);
        log.info("Grounding out latent model.");
        int groundCount = Grounding.groundAll(this.allRules, (AtomManager)this.atomManager, this.latentGroundRuleStore);
        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : this.trainingMap.getTrainingMap().entrySet()) {
            this.latentGroundRuleStore.addGroundRule(new GroundValueConstraint(entry.getKey(), entry.getValue().getValue()));
        }
        log.debug("Initializing latent objective terms for {} ground rules.", (Object)(groundCount += this.trainingMap.getTrainingMap().size()));
        this.termStore.ensureVariableCapacity(this.atomManager.getCachedRVACount());
        int termCount = this.termGenerator.generateTerms(this.latentGroundRuleStore, this.latentTermStore);
        log.debug("Generated {} latent objective terms from {} ground rules.", (Object)termCount, (Object)groundCount);
    }

    protected void computeMPEState() {
        if (this.inMPEState) {
            return;
        }
        this.termStore.clear();
        this.termStore.ensureVariableCapacity(this.atomManager.getCachedRVACount());
        this.termGenerator.generateTerms(this.groundRuleStore, this.termStore);
        this.reasoner.optimize(this.termStore);
        this.inMPEState = true;
    }

    protected void computeLatentMPEState() {
        if (this.inLatentMPEState) {
            return;
        }
        this.termStore.clear();
        this.termStore.ensureVariableCapacity(this.atomManager.getCachedRVACount());
        this.termGenerator.generateTerms(this.groundRuleStore, this.termStore);
        this.reasoner.optimize(this.latentTermStore);
        this.inLatentMPEState = true;
    }

    protected void computeObservedIncompatibility() {
        int i;
        this.setLabeledRandomVariables();
        for (i = 0; i < this.observedIncompatibility.length; ++i) {
            this.observedIncompatibility[i] = 0.0;
        }
        for (i = 0; i < this.mutableRules.size(); ++i) {
            for (GroundRule groundRule : this.groundRuleStore.getGroundRules(this.mutableRules.get(i))) {
                int n = i;
                this.observedIncompatibility[n] = this.observedIncompatibility[n] + ((WeightedGroundRule)groundRule).getIncompatibility();
            }
        }
    }

    protected void computeExpectedIncompatibility() {
        int i;
        this.computeMPEState();
        for (i = 0; i < this.expectedIncompatibility.length; ++i) {
            this.expectedIncompatibility[i] = 0.0;
        }
        for (i = 0; i < this.mutableRules.size(); ++i) {
            for (GroundRule groundRule : this.groundRuleStore.getGroundRules(this.mutableRules.get(i))) {
                int n = i;
                this.expectedIncompatibility[n] = this.expectedIncompatibility[n] + ((WeightedGroundRule)groundRule).getIncompatibility();
            }
        }
    }

    public double computeLoss() {
        double loss = 0.0;
        for (int i = 0; i < this.mutableRules.size(); ++i) {
            loss += this.mutableRules.get(i).getWeight() * (this.observedIncompatibility[i] - this.expectedIncompatibility[i]);
        }
        return loss;
    }

    @Override
    public void close() {
        if (this.groundRuleStore != null) {
            this.groundRuleStore.close();
            this.groundRuleStore = null;
        }
        if (this.latentGroundRuleStore != null) {
            this.latentGroundRuleStore.close();
            this.latentGroundRuleStore = null;
        }
        if (this.termStore != null) {
            this.termStore.close();
            this.termStore = null;
        }
        if (this.latentTermStore != null) {
            this.latentTermStore.close();
            this.latentTermStore = null;
        }
        if (this.reasoner != null) {
            this.reasoner.close();
            this.reasoner = null;
        }
        this.termGenerator = null;
        this.trainingMap = null;
        this.atomManager = null;
        this.rvDB = null;
        this.observedDB = null;
    }

    protected void setLabeledRandomVariables() {
        this.inMPEState = false;
        this.inLatentMPEState = false;
        for (Map.Entry<RandomVariableAtom, ObservedAtom> entry : this.trainingMap.getTrainingMap().entrySet()) {
            entry.getKey().setValue(entry.getValue().getValue());
        }
    }

    protected void setDefaultRandomVariables() {
        this.inMPEState = false;
        this.inLatentMPEState = false;
        for (RandomVariableAtom atom : this.trainingMap.getTrainingMap().keySet()) {
            atom.setValue(0.0f);
        }
        for (RandomVariableAtom atom : this.trainingMap.getLatentVariables()) {
            atom.setValue(0.0f);
        }
    }

    protected PersistedAtomManager createAtomManager() {
        return new PersistedAtomManager(this.rvDB);
    }

    private void ensureTargets(PersistedAtomManager atomManager) {
        for (StandardPredicate predicate : this.observedDB.getDataStore().getRegisteredPredicates()) {
            if (this.observedDB.isClosed(predicate)) continue;
            for (ObservedAtom observedAtom : this.observedDB.getAllGroundObservedAtoms(predicate)) {
                GroundAtom otherAtom = atomManager.getAtom(observedAtom.getPredicate(), observedAtom.getArguments());
                if (otherAtom instanceof ObservedAtom) continue;
                RandomVariableAtom rvAtom = (RandomVariableAtom)otherAtom;
                rvAtom.setValue(0.0f);
            }
        }
        atomManager.commitPersistedAtoms();
    }

    public static WeightLearningApplication getWLA(String name, List<Rule> rules, Database randomVariableDatabase, Database observedTruthDatabase) {
        String className = Reflection.resolveClassName(name);
        if (className == null) {
            throw new IllegalArgumentException("Could not find class: " + name);
        }
        Class<?> classObject = null;
        try {
            Class<?> uncheckedClassObject;
            classObject = uncheckedClassObject = Class.forName(className);
        }
        catch (ClassNotFoundException ex) {
            throw new IllegalArgumentException("Could not find class: " + className, ex);
        }
        Constructor<?> constructor = null;
        try {
            constructor = classObject.getConstructor(List.class, Database.class, Database.class);
        }
        catch (NoSuchMethodException ex) {
            throw new IllegalArgumentException("No sutible constructor found for weight learner: " + className + ".", ex);
        }
        WeightLearningApplication wla = null;
        try {
            wla = (WeightLearningApplication)constructor.newInstance(rules, randomVariableDatabase, observedTruthDatabase);
        }
        catch (InstantiationException ex) {
            throw new RuntimeException("Unable to instantiate weight learner (" + className + ")", ex);
        }
        catch (IllegalAccessException ex) {
            throw new RuntimeException("Insufficient access to constructor for " + className, ex);
        }
        catch (InvocationTargetException ex) {
            throw new RuntimeException("Error thrown while constructing " + className, ex);
        }
        return wla;
    }
}

