/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight.search.grid;

import java.util.List;
import org.linqs.psl.application.learning.weight.search.grid.BaseGridSearch;
import org.linqs.psl.config.Config;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.Model;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedGroundRule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.util.RandUtils;

public class ContinuousRandomGridSearch
extends BaseGridSearch {
    public static final String CONFIG_PREFIX = "continuousrandomgridsearch";
    public static final String MAX_LOCATIONS_KEY = "continuousrandomgridsearch.maxlocations";
    public static final int MAX_LOCATIONS_DEFAULT = 250;
    public static final String BASE_WEIGHT_KEY = "continuousrandomgridsearch.baseweight";
    public static final double BASE_WEIGHT_DEFAULT = 0.4;
    public static final String VARIANCE_KEY = "continuousrandomgridsearch.variance";
    public static final double VARIANCE_DEFAULT = 0.2;
    public static final String UNIFORM_BASE_KEY = "continuousrandomgridsearch.uniformbase";
    public static final boolean UNIFORM_BASE_DEFAULT = true;
    public static final String SCALE_ORDERS_KEY = "continuousrandomgridsearch.scaleorders";
    public static final int SCALE_ORDERS_DEFAULT = 0;
    public static final int SCALE_FACTOR = 10;
    private double[] weightMeans;
    private double baseWeight;
    private double variance;
    private boolean uniformBase;
    private int scaleOrder = Math.max(0, Config.getInt("continuousrandomgridsearch.scaleorders", 0));
    private int currentScale = 0;

    public ContinuousRandomGridSearch(Model model, Database rvDB, Database observedDB) {
        this(model.getRules(), rvDB, observedDB);
    }

    public ContinuousRandomGridSearch(List<Rule> rules, Database rvDB, Database observedDB) {
        super(rules, rvDB, observedDB);
        this.numLocations = Config.getInt(MAX_LOCATIONS_KEY, 250);
        if (this.scaleOrder > 0) {
            this.numLocations *= this.scaleOrder + 1;
        }
        this.baseWeight = Config.getDouble(BASE_WEIGHT_KEY, 0.4);
        this.variance = Config.getDouble(VARIANCE_KEY, 0.2);
        this.uniformBase = Config.getBoolean(UNIFORM_BASE_KEY, true);
        this.weightMeans = null;
    }

    @Override
    protected void postInitGroundModel() {
        this.computeWeightMeans();
    }

    @Override
    protected void getWeights(double[] weights) {
        if (this.currentScale == 0) {
            for (int i = 0; i < this.mutableRules.size(); ++i) {
                weights[i] = RandUtils.nextDouble() * Math.sqrt(this.variance) + this.weightMeans[i];
            }
        } else {
            int i = 0;
            while (i < this.mutableRules.size()) {
                int n = i++;
                weights[n] = weights[n] * 10.0;
            }
        }
        ++this.currentScale;
        if (this.currentScale > this.scaleOrder) {
            this.currentScale = 0;
        }
    }

    @Override
    protected boolean chooseNextLocation() {
        this.currentLocation = "" + this.objectives.size();
        return true;
    }

    private void computeWeightMeans() {
        int i;
        this.weightMeans = new double[this.mutableRules.size()];
        if (this.uniformBase) {
            for (int i2 = 0; i2 < this.mutableRules.size(); ++i2) {
                this.weightMeans[i2] = this.baseWeight;
            }
            return;
        }
        for (WeightedRule rule : this.mutableRules) {
            rule.setWeight(1.0);
        }
        this.inMPEState = false;
        this.computeMPEState();
        double smallestCompatability = 1.0;
        for (i = 0; i < this.mutableRules.size(); ++i) {
            int count = 0;
            for (GroundRule groundRule : this.groundRuleStore.getGroundRules((Rule)this.mutableRules.get(i))) {
                if (!(groundRule instanceof WeightedGroundRule)) continue;
                ++count;
                int n = i;
                this.weightMeans[n] = this.weightMeans[n] + (1.0 - ((WeightedGroundRule)groundRule).getIncompatibility());
            }
            if (count == 0) {
                this.weightMeans[i] = 0.0;
            } else {
                int n = i;
                this.weightMeans[n] = this.weightMeans[n] / (double)count;
            }
            if (!(this.weightMeans[i] < smallestCompatability)) continue;
            smallestCompatability = this.weightMeans[i];
        }
        for (i = 0; i < this.mutableRules.size(); ++i) {
            this.weightMeans[i] = this.baseWeight * this.weightMeans[i] / smallestCompatability;
        }
    }
}

