/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight.search;

import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;
import org.linqs.psl.application.learning.weight.WeightLearningApplication;
import org.linqs.psl.config.Config;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.Model;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.util.RandUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Weight learning via a Hyperband-style random search over rule weights.
 *
 * <p>For bracket {@code b} (0-based, {@code numBrackets} in total), about
 * {@code baseBracketSize * survival^b / (b + 1)} random weight configurations are sampled
 * (at least {@link #MIN_BRACKET_SIZE}). The bracket then runs {@code b + 1} rounds, starting
 * at a budget proportion of {@code survival^-b} and multiplying it by {@code survival} each
 * round (clamped to [{@link #MIN_BUDGET_PROPORTION}, 1]).
 *
 * <p>In each round, every surviving configuration is applied to the mutable rules and scored
 * with {@link #run(double[])}, where lower is better. Only the best {@code 1 / survival} of them
 * (but always at least one) advance to the next round. When all brackets are done, the
 * configuration with the lowest objective seen in any round, at any budget, is left applied
 * to the rules.
 *
 * <p>Configuration keys: {@value #SURVIVAL_KEY}, {@value #BASE_BRACKET_SIZE_KEY},
 * {@value #NUM_BRACKETS_KEY}.
 */
public class Hyperband extends WeightLearningApplication {
    private static final Logger log = LoggerFactory.getLogger(Hyperband.class);

    public static final String CONFIG_PREFIX = "hyperband";

    /** Factor by which configurations are culled, and budgets grown, between rounds. */
    public static final String SURVIVAL_KEY = "hyperband.survival";
    public static final int SURVIVAL_DEFAULT = 4;

    /** Base number of configurations sampled per bracket (scaled per bracket, see class doc). */
    public static final String BASE_BRACKET_SIZE_KEY = "hyperband.basebracketsize";
    public static final int BASE_BRACKET_SIZE_DEFAULT = 10;

    /** Number of brackets to run. */
    public static final String NUM_BRACKETS_KEY = "hyperband.numbrackets";
    public static final int NUM_BRACKETS_DEFAULT = 4;

    /** Lower clamp on the budget proportion handed to {@code setBudget}. */
    public static final double MIN_BUDGET_PROPORTION = 0.001;

    /** Lower bound on the number of configurations sampled for a bracket. */
    public static final int MIN_BRACKET_SIZE = 1;

    // NOTE(review): chooseConfigs uses MEAN as an offset and sqrt(VARIANCE) as a scale on
    // RandUtils.nextDouble(). If nextDouble() is uniform on [0, 1) (not verified here), the
    // sampled weights do NOT actually have this mean/variance.
    public static final double MEAN = 0.5;
    public static final double VARIANCE = 0.1;

    private final int survival;
    private final int numBrackets;
    private final int baseBracketSize;

    public Hyperband(Model model, Database rvDB, Database observedDB) {
        this(model.getRules(), rvDB, observedDB);
    }

    /**
     * Reads and validates the Hyperband settings from {@link Config}.
     *
     * @throws IllegalArgumentException if the survival factor, the number of brackets, or the
     *     base bracket size is less than one.
     */
    public Hyperband(List<Rule> rules, Database rvDB, Database observedDB) {
        super(rules, rvDB, observedDB, false);

        survival = Config.getInt(SURVIVAL_KEY, SURVIVAL_DEFAULT);
        if (survival < 1) {
            throw new IllegalArgumentException("Need at least one survival porportion.");
        }

        numBrackets = Config.getInt(NUM_BRACKETS_KEY, NUM_BRACKETS_DEFAULT);
        if (numBrackets < 1) {
            throw new IllegalArgumentException("Need at least one bracket.");
        }

        baseBracketSize = Config.getInt(BASE_BRACKET_SIZE_KEY, BASE_BRACKET_SIZE_DEFAULT);
        if (baseBracketSize < 1) {
            throw new IllegalArgumentException("Need at least one bracket size.");
        }
    }

    @Override
    protected void doLearn() {
        double bestObjective = -1.0;
        double[] bestWeights = null;

        computeObservedIncompatibility();

        double totalCost = 0.0;
        int numEvaluatedConfigs = 0;

        for (int bracket = 0; bracket < numBrackets; bracket++) {
            double bracketProportion = Math.pow(survival, bracket) / (bracket + 1);
            int bracketSize = (int) Math.max(MIN_BRACKET_SIZE, Math.ceil(bracketProportion * baseBracketSize));
            numEvaluatedConfigs += bracketSize;

            double bracketBudget = Math.pow(survival, -1.0 * bracket);
            log.debug("Bracket {} / {} -- Size: {} ({}), Budget: {}",
                    bracket + 1, numBrackets, bracketSize, bracketProportion, bracketBudget);

            List<double[]> configs = chooseConfigs(bracketSize);
            for (int round = 0; round <= bracket; round++) {
                int roundSize = configs.size();

                // Use the clamped budget everywhere (inference, accounting, logging) so the
                // reported total matches what was actually requested.
                double roundBudget = Math.max(MIN_BUDGET_PROPORTION,
                        Math.min(1.0, bracketBudget * Math.pow(survival, round)));
                setBudget(roundBudget);
                log.debug("  Round {} / {} -- Size: {}, Budget: {}", round + 1, bracket + 1, roundSize, roundBudget);

                // Natural ordering of RunResult puts the lowest (best) objective at the head.
                PriorityQueue<RunResult> results = new PriorityQueue<RunResult>();
                for (double[] config : configs) {
                    totalCost += roundBudget;

                    applyConfigWeights(config);
                    double objective = run(config);
                    results.add(new RunResult(config, objective));

                    if (bestWeights == null || objective < bestObjective) {
                        bestObjective = objective;
                        bestWeights = config;
                    }

                    log.debug("Training Objective: {}, Weights: {}", objective, config);
                }

                // Keep the best 1/survival, but never drop to zero: otherwise the remaining
                // (larger-budget) rounds of this bracket would evaluate nothing.
                // roundSize >= 1 here, so 1 <= numSurvivors <= results.size().
                int numSurvivors = Math.max(1, roundSize / survival);
                configs.clear();
                for (int i = 0; i < numSurvivors; i++) {
                    configs.add(results.poll().weights);
                }
            }
        }

        // Bracket 0 always evaluates at least MIN_BRACKET_SIZE configurations, so bestWeights is set.
        applyConfigWeights(bestWeights);

        log.debug("Hyperband complete. Configurations examined: {}. Total budget: {}", numEvaluatedConfigs, totalCost);
    }

    /**
     * Writes {@code weights[i]} onto the i-th mutable rule and clears the
     * {@code inMPEState} / {@code inLatentMPEState} flags (presumably so that the next
     * evaluation re-runs inference under the new weights).
     */
    private void applyConfigWeights(double[] weights) {
        for (int i = 0; i < mutableRules.size(); i++) {
            ((WeightedRule) mutableRules.get(i)).setWeight(weights[i]);
        }
        inMPEState = false;
        inLatentMPEState = false;
    }

    /**
     * Samples {@code bracketSize} random weight configurations, one weight per mutable rule.
     * Each weight is {@code RandUtils.nextDouble() * sqrt(VARIANCE) + MEAN}.
     */
    private List<double[]> chooseConfigs(int bracketSize) {
        int numWeights = mutableRules.size();
        double scale = Math.sqrt(VARIANCE);

        List<double[]> configs = new ArrayList<double[]>(bracketSize);
        for (int i = 0; i < bracketSize; i++) {
            double[] config = new double[numWeights];
            for (int weightIndex = 0; weightIndex < numWeights; weightIndex++) {
                config[weightIndex] = RandUtils.nextDouble() * scale + MEAN;
            }
            configs.add(config);
        }
        return configs;
    }

    /**
     * Scores the weights currently applied to the rules.
     *
     * <p>{@code weights} is not read here; {@link #doLearn()} applies it to the rules before
     * calling. The evaluator's representative metric is negated when higher is better, so the
     * returned value is always "lower is better".
     *
     * @param weights the configuration being scored (informational for overriders).
     * @return the objective for the current weights; lower is better.
     */
    protected double run(double[] weights) {
        setDefaultRandomVariables();
        computeExpectedIncompatibility();

        evaluator.compute(trainingMap);
        double score = evaluator.getRepresentativeMetric();
        return evaluator.isHigherRepresentativeBetter() ? -1.0 * score : score;
    }

    /** A scored configuration; natural ordering is ascending by objective. */
    private static final class RunResult implements Comparable<RunResult> {
        public final double[] weights;
        public final double objective;

        public RunResult(double[] weights, double objective) {
            this.weights = weights;
            this.objective = objective;
        }

        @Override
        public int compareTo(RunResult other) {
            return Double.compare(objective, other.objective);
        }
    }
}

