/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight.search.grid;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.linqs.psl.application.learning.weight.search.grid.RandomGridSearch;
import org.linqs.psl.config.Options;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.util.Logger;
import org.linqs.psl.util.MathUtils;
import org.linqs.psl.util.StringUtils;

/**
 * A two-phase random grid search.
 *
 * <p>Seed phase: visit {@code numSeedLocations} distinct random grid configurations.
 * Explore phase: take the {@code numExploreLocations} seeds with the lowest objective values
 * (ascending by {@link MathUtils#compare}). Then visit their immediate grid neighbors, meaning
 * locations that differ by one index step in exactly one rule's weight.
 *
 * <p>Locations are {@code ":"}-joined index strings into {@code possibleWeights}, one index per
 * mutable rule.
 */
public class GuidedRandomGridSearch
extends RandomGridSearch {
    private static final Logger log = Logger.getLogger(GuidedRandomGridSearch.class);

    private final int maxNumSeedLocations;
    private int numSeedLocations;

    private final int maxNumExploreLocations;
    private int numExploreLocations;

    /** Candidate locations for the explore phase that have not been visited yet. */
    private final Set<String> toExplore;

    public GuidedRandomGridSearch(List<Rule> rules, Database trainTargetDatabase, Database trainTruthDatabase, Database validationTargetDatabase, Database validationTruthDatabase, boolean runValidation) {
        super(rules, trainTargetDatabase, trainTruthDatabase, validationTargetDatabase, validationTruthDatabase, runValidation);

        this.maxNumSeedLocations = Options.WLA_GRGS_SEED_LOCATIONS.getInt();
        this.numSeedLocations = this.maxNumSeedLocations;

        this.maxNumExploreLocations = Options.WLA_GRGS_EXPLORE_LOCATIONS.getInt();
        this.numExploreLocations = this.maxNumExploreLocations;

        // Never plan more iterations than this search can actually produce.
        this.numLocations = Math.min(this.numLocations,
                locationBound(this.numSeedLocations, this.numExploreLocations, this.mutableRules.size()));

        this.toExplore = new HashSet<String>(Math.max(10, this.numLocations - this.numSeedLocations));
    }

    /**
     * Pick the next location to evaluate and store it in {@code currentLocation}.
     *
     * <p>While fewer than {@code numSeedLocations} objectives are recorded, a random unvisited
     * configuration is chosen. The first call at exactly {@code numSeedLocations} objectives
     * builds the explore candidate set. After that, candidates are consumed one at a time.
     * NOTE(review): this assumes the parent records one objective per evaluated location
     * between calls; confirm against the base grid search.
     *
     * @return false when the explore phase has no candidates left, true otherwise
     */
    @Override
    protected boolean chooseNextLocation() {
        if (this.objectives.size() < this.numSeedLocations) {
            do {
                this.currentLocation = this.randomConfiguration();
            } while (this.objectives.containsKey(this.currentLocation));
            return true;
        }

        if (this.objectives.size() == this.numSeedLocations) {
            this.startExplorePhase();
        }

        if (this.toExplore.isEmpty()) {
            return false;
        }

        this.currentLocation = this.toExplore.iterator().next();
        this.toExplore.remove(this.currentLocation);
        return true;
    }

    /** Fill {@code toExplore} with the unvisited neighbors of the lowest-objective seeds. */
    private void startExplorePhase() {
        // Start from a clean slate so nothing left over from an earlier search is explored.
        this.toExplore.clear();

        List<Map.Entry<String, Double>> locations = new ArrayList<Map.Entry<String, Double>>(this.objectives.entrySet());
        Collections.sort(locations, new Comparator<Map.Entry<String, Double>>() {
            @Override
            public int compare(Map.Entry<String, Double> a, Map.Entry<String, Double> b) {
                return MathUtils.compare(a.getValue(), b.getValue());
            }
        });

        int numToExpand = Math.min(this.numExploreLocations, locations.size());
        for (int i = 0; i < numToExpand; i++) {
            log.trace("Adding neighbors for {}.", locations.get(i));
            this.addNeighbors(locations.get(i).getKey());
        }

        // Seeds may be neighbors of each other; never revisit an evaluated location.
        this.toExplore.removeAll(this.objectives.keySet());
        log.debug("Seed phase complete, starting explore phase with {} locations.", this.toExplore.size());
    }

    /**
     * Add every location one index step away from {@code location} in a single dimension.
     * Steps that would leave the range {@code [0, possibleWeights.length - 1]} are skipped.
     * Each call adds at most {@code 2 * mutableRules.size()} candidates.
     */
    private void addNeighbors(String location) {
        int[] indexes = StringUtils.splitInt(location, ":");
        assert indexes.length == this.mutableRules.size();

        int maxIndex = this.possibleWeights.length - 1;
        for (int i = 0; i < this.mutableRules.size(); i++) {
            if (indexes[i] != maxIndex) {
                indexes[i]++;
                this.toExplore.add(StringUtils.join(":", indexes));
                indexes[i]--;
            }

            if (indexes[i] != 0) {
                indexes[i]--;
                this.toExplore.add(StringUtils.join(":", indexes));
                indexes[i]++;
            }
        }
    }

    @Override
    public void setBudget(double budget) {
        super.setBudget(budget);

        this.numSeedLocations = (int)Math.ceil(budget * (double)this.maxNumSeedLocations);
        this.numExploreLocations = (int)Math.ceil(budget * (double)this.maxNumExploreLocations);

        this.numLocations = Math.min(this.numLocations,
                locationBound(this.numSeedLocations, this.numExploreLocations, this.mutableRules.size()));
    }

    /**
     * Compute {@code numSeed + numExplore * 2^numRules}, clamped to the {@code int} range.
     *
     * <p>The {@code 2^numRules} allowance per explored seed is never smaller than the
     * {@code 2 * numRules} neighbors {@link #addNeighbors} can add (for {@code numRules >= 1}).
     * The original int-only expression overflowed for models with roughly 28 or more mutable
     * rules, which yielded a negative or tiny bound. The math here is done in {@code long}
     * instead. The exponent is capped at 31: {@code 2^31} already exceeds
     * {@code Integer.MAX_VALUE}, so the clamped result is unchanged while the product stays
     * exact ({@code |numExplore| * 2^31 <= 2^62}).
     */
    private static int locationBound(int numSeed, int numExplore, int numRules) {
        long factor = 1L << Math.min(numRules, Integer.SIZE - 1);
        long bound = (long)numSeed + (long)numExplore * factor;
        return (int)Math.max(Integer.MIN_VALUE, Math.min(Integer.MAX_VALUE, bound));
    }
}

