/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight;

import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.List;
import org.linqs.psl.application.ModelApplication;
import org.linqs.psl.application.inference.InferenceApplication;
import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.config.Options;
import org.linqs.psl.database.Database;
import org.linqs.psl.evaluation.statistics.Evaluator;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.util.RandUtils;
import org.linqs.psl.util.Reflection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class WeightLearningApplication
implements ModelApplication {
    private static final Logger log = LoggerFactory.getLogger(WeightLearningApplication.class);
    protected Database rvDB;
    protected Database observedDB;
    protected List<Rule> allRules;
    protected List<WeightedRule> mutableRules;
    protected TrainingMap trainingMap;
    protected InferenceApplication inference;
    protected Evaluator evaluator;
    private boolean groundModelInit;
    protected boolean inMPEState;

    public WeightLearningApplication(List<Rule> rules, Database rvDB, Database observedDB) {
        this.rvDB = rvDB;
        this.observedDB = observedDB;
        this.allRules = new ArrayList<Rule>();
        this.mutableRules = new ArrayList<WeightedRule>();
        for (Rule rule : rules) {
            this.allRules.add(rule);
            if (!(rule instanceof WeightedRule)) continue;
            this.mutableRules.add((WeightedRule)rule);
        }
        this.groundModelInit = false;
        this.inMPEState = false;
        this.evaluator = (Evaluator)Options.WLA_EVAL.getNewObject();
    }

    public void learn() {
        this.initGroundModel();
        this.doLearn();
    }

    protected abstract void doLearn();

    public void setBudget(double budget) {
        this.inference.setBudget(budget);
    }

    public InferenceApplication getInferenceApplication() {
        return this.inference;
    }

    protected void initGroundModel() {
        if (this.groundModelInit) {
            return;
        }
        InferenceApplication inference = InferenceApplication.getInferenceApplication(Options.WLA_INFERENCE.getString(), this.allRules, this.rvDB);
        this.initGroundModel(inference);
    }

    private void initGroundModel(InferenceApplication inference) {
        if (this.groundModelInit) {
            return;
        }
        TrainingMap trainingMap = new TrainingMap(inference.getAtomManager(), this.observedDB);
        this.initGroundModel(inference, trainingMap);
    }

    public void initGroundModel(InferenceApplication inference, TrainingMap trainingMap) {
        if (this.groundModelInit) {
            return;
        }
        this.inference = inference;
        this.trainingMap = trainingMap;
        if (Options.WLA_RANDOM_WEIGHTS.getBoolean()) {
            this.initRandomWeights();
        }
        this.postInitGroundModel();
        this.groundModelInit = true;
    }

    private void initRandomWeights() {
        log.trace("Randomly Weighted Rules:");
        for (WeightedRule rule : this.mutableRules) {
            rule.setWeight(RandUtils.nextFloat());
            log.trace("    " + rule.toString());
        }
    }

    protected void postInitGroundModel() {
    }

    protected void computeMPEState() {
        if (this.inMPEState) {
            return;
        }
        this.inference.inference(false, true);
        this.inMPEState = true;
    }

    @Override
    public void close() {
        if (this.inference != null) {
            this.inference.commit();
            this.inference.close();
            this.inference = null;
        }
        this.trainingMap = null;
        this.rvDB = null;
        this.observedDB = null;
    }

    public static WeightLearningApplication getWLA(String name, List<Rule> rules, Database randomVariableDatabase, Database observedTruthDatabase) {
        String className = Reflection.resolveClassName(name);
        if (className == null) {
            throw new IllegalArgumentException("Could not find class: " + name);
        }
        Class<?> classObject = null;
        try {
            Class<?> uncheckedClassObject;
            classObject = uncheckedClassObject = Class.forName(className);
        }
        catch (ClassNotFoundException ex) {
            throw new IllegalArgumentException("Could not find class: " + className, ex);
        }
        Constructor<?> constructor = null;
        try {
            constructor = classObject.getConstructor(List.class, Database.class, Database.class);
        }
        catch (NoSuchMethodException ex) {
            throw new IllegalArgumentException("No sutible constructor found for weight learner: " + className + ".", ex);
        }
        WeightLearningApplication wla = null;
        try {
            wla = (WeightLearningApplication)constructor.newInstance(rules, randomVariableDatabase, observedTruthDatabase);
        }
        catch (InstantiationException ex) {
            throw new RuntimeException("Unable to instantiate weight learner (" + className + ")", ex);
        }
        catch (IllegalAccessException ex) {
            throw new RuntimeException("Insufficient access to constructor for " + className, ex);
        }
        catch (InvocationTargetException ex) {
            throw new RuntimeException("Error thrown while constructing " + className, ex);
        }
        return wla;
    }
}

