/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight.search.grid;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.linqs.psl.application.learning.weight.WeightLearningApplication;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.Model;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.util.MathUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for weight learning that searches over a sequence of candidate weight
 * configurations ("locations"), scores each one, and finally keeps the configuration with
 * the lowest objective.
 *
 * <p>Subclasses decide which location is visited next via {@link #chooseNextLocation()} and
 * translate the current location into concrete rule weights via {@link #getWeights(float[])}.
 * By default a location is scored by {@link #inspectLocation(float[])}.
 */
public abstract class BaseGridSearch extends WeightLearningApplication {
    private static final Logger log = LoggerFactory.getLogger(BaseGridSearch.class);

    /** Reported in log output only; this class does not otherwise read it. */
    protected int maxNumLocations;

    /** Maximum number of locations {@link #doLearn()} will inspect. */
    protected int numLocations;

    /** Objective of every inspected location, keyed by {@link #currentLocation}. */
    protected Map<String, Double> objectives;

    /** Key of the location currently being inspected; {@code null} before the first one. */
    protected String currentLocation;

    /**
     * Creates a search over the rules of {@code model}.
     *
     * @param model supplies the rules whose weights are searched
     * @param rvDB database passed through to {@link WeightLearningApplication}
     * @param observedDB database passed through to {@link WeightLearningApplication}
     */
    public BaseGridSearch(Model model, Database rvDB, Database observedDB) {
        this(model.getRules(), rvDB, observedDB);
    }

    /**
     * Creates a search over {@code rules}.
     *
     * <p>Both location counters start at 0. Subclasses are expected to set
     * {@link #numLocations} (and optionally {@link #maxNumLocations}). Otherwise
     * {@link #doLearn()} inspects nothing.
     *
     * @param rules rules whose weights are searched
     * @param rvDB database passed through to {@link WeightLearningApplication}
     * @param observedDB database passed through to {@link WeightLearningApplication}
     */
    public BaseGridSearch(List<Rule> rules, Database rvDB, Database observedDB) {
        super(rules, rvDB, observedDB);
        this.numLocations = 0;
        this.maxNumLocations = 0;
        this.objectives = new HashMap<String, Double>();
        this.currentLocation = null;
    }

    /**
     * Inspects up to {@link #numLocations} locations and leaves the mutable rules set to the
     * weights of the location with the lowest objective.
     *
     * <p>The search stops early when {@link #chooseNextLocation()} returns {@code false}. Every
     * inspected location's objective is recorded in {@link #objectives}. If no location is
     * inspected at all, the rule weights are left unchanged rather than being zeroed.
     */
    @Override
    protected void doLearn() {
        int numRules = this.mutableRules.size();
        float[] weights = new float[numRules];
        float[] bestWeights = new float[numRules];
        double bestObjective = Double.NaN;
        boolean haveBest = false;

        for (int iteration = 0; iteration < this.numLocations; iteration++) {
            if (!this.chooseNextLocation()) {
                log.debug("Stopping search.");
                break;
            }

            log.debug("Iteration {} / {} ({}) -- Inspecting location {}",
                    iteration, this.numLocations, this.maxNumLocations, this.currentLocation);

            this.getWeights(weights);
            assignMutableRuleWeights(weights);
            log.trace("Weights: {}", weights);

            // Rule weights just changed; presumably this flag marks any cached MPE state as stale.
            this.inMPEState = false;
            double objective = this.inspectLocation(weights);
            this.objectives.put(this.currentLocation, Double.valueOf(objective));

            // A NaN best would make every later `<` comparison false and lock the search
            // onto that location, so any subsequent objective is allowed to replace it.
            if (!haveBest || objective < bestObjective || Double.isNaN(bestObjective)) {
                bestObjective = objective;
                System.arraycopy(weights, 0, bestWeights, 0, numRules);
                haveBest = true;
            }

            log.debug("Weights: {} -- objective: {}", this.currentLocation, objective);
        }

        if (haveBest) {
            assignMutableRuleWeights(bestWeights);
        } else {
            // bestWeights was never filled; applying it would silently zero every rule weight.
            log.warn("No grid search locations were inspected; rule weights left unchanged.");
        }

        this.inMPEState = false;
    }

    /**
     * Sets each mutable rule's weight to the matching entry of {@code weights}.
     *
     * @param weights one weight per mutable rule, in {@code mutableRules} order
     */
    private void assignMutableRuleWeights(float[] weights) {
        for (int i = 0; i < this.mutableRules.size(); i++) {
            ((WeightedRule)this.mutableRules.get(i)).setWeight(weights[i]);
        }
    }

    /**
     * Scores the current location; lower is better.
     *
     * <p>The default recomputes the MPE state, runs the evaluator against the training map,
     * and returns the negated normalized representative metric. The negation presumably
     * turns a higher-is-better metric into a minimization objective — confirm against the
     * evaluator in use.
     *
     * @param weights the weights already applied to the mutable rules (unused by the default)
     * @return the objective for the current location
     */
    protected double inspectLocation(float[] weights) {
        this.computeMPEState();
        this.evaluator.compute(this.trainingMap);
        return -1.0 * this.evaluator.getNormalizedRepMetric();
    }

    /**
     * Writes the weights for {@link #currentLocation} into {@code weights}.
     *
     * @param weights output array with one slot per mutable rule
     */
    protected abstract void getWeights(float[] weights);

    /**
     * Advances to the next location to inspect, updating {@link #currentLocation}.
     *
     * @return {@code false} to stop the search, {@code true} if a location was chosen
     */
    protected abstract boolean chooseNextLocation();
}

