/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight.search;

import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;
import org.linqs.psl.application.learning.weight.WeightLearningApplication;
import org.linqs.psl.application.learning.weight.search.WeightSampler;
import org.linqs.psl.config.Options;
import org.linqs.psl.database.Database;
import org.linqs.psl.model.Model;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.util.MathUtils;
import org.linqs.psl.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Weight learning by Hyperband-style random search.
 *
 * <p>The search runs {@code numBrackets} brackets. Bracket {@code b} samples
 * {@code max(MIN_BRACKET_SIZE, ceil(survival^b / (b + 1) * baseBracketSize))} random weight
 * configurations. It then runs {@code b + 1} rounds of successive halving: round {@code r} evaluates
 * every surviving configuration with budget {@code survival^(r - b)}, clamped to
 * {@code [MIN_BUDGET_PROPORTION, 1.0]}. Only the best {@code floor(n / survival)} configurations
 * survive to the next round.
 *
 * <p>The configuration with the lowest objective seen in any round of any bracket is installed on
 * the mutable rules when learning finishes.
 */
public class Hyperband
extends WeightLearningApplication {
    private static final Logger log = LoggerFactory.getLogger(Hyperband.class);

    /** Lower bound applied to each round's budget before it is passed to {@code setBudget}. */
    public static final double MIN_BUDGET_PROPORTION = 0.001;

    /** Minimum number of configurations sampled for any bracket. */
    public static final int MIN_BRACKET_SIZE = 1;

    /** Elimination factor: each round keeps {@code floor(n / survival)} configurations. Must be >= 1. */
    private final int survival;
    private final WeightSampler weightSampler;
    private final int numBrackets;
    private final int baseBracketSize;

    public Hyperband(Model model, Database rvDB, Database observedDB) {
        this(model.getRules(), rvDB, observedDB);
    }

    /**
     * Creates the learner, reading survival, bracket count, and base bracket size from
     * {@link Options}.
     *
     * @throws IllegalArgumentException if the configured survival factor is less than 1
     *         (a factor of 0 would otherwise fail deep inside {@link #doLearn()}).
     */
    public Hyperband(List<Rule> rules, Database rvDB, Database observedDB) {
        super(rules, rvDB, observedDB);
        this.weightSampler = new WeightSampler(this.mutableRules.size());
        this.survival = Options.WLA_HB_SURVIVAL.getInt();
        this.numBrackets = Options.WLA_HB_NUM_BRACKETS.getInt();
        this.baseBracketSize = Options.WLA_HB_BRACKET_SIZE.getInt();

        if (this.survival < 1) {
            throw new IllegalArgumentException(
                    "Hyperband survival factor must be at least 1, got: " + this.survival);
        }
    }

    @Override
    protected void doLearn() {
        double bestObjective = 0.0;
        float[] bestWeights = null;
        double totalCost = 0.0;
        int numEvaluatedConfigs = 0;

        for (int bracket = 0; bracket < this.numBrackets; bracket++) {
            double bracketProportion = Math.pow(this.survival, bracket) / (double)(bracket + 1);
            int bracketSize = (int)Math.max(MIN_BRACKET_SIZE,
                    Math.ceil(bracketProportion * (double)this.baseBracketSize));
            numEvaluatedConfigs += bracketSize;
            double bracketBudget = Math.pow(this.survival, -1.0 * (double)bracket);
            log.debug("Bracket {} / {} -- Size: {} ({}), Budget: {}", bracket + 1, this.numBrackets, bracketSize, bracketProportion, bracketBudget);

            List<float[]> configs = this.chooseConfigs(bracketSize);
            for (int round = 0; round <= bracket; round++) {
                int roundSize = configs.size();
                double roundBudget = bracketBudget * Math.pow(this.survival, round);
                // The budget actually applied (and therefore actually spent) is the clamped one.
                double effectiveBudget = Math.max(MIN_BUDGET_PROPORTION, Math.min(1.0, roundBudget));
                this.setBudget(effectiveBudget);
                log.debug("  Round {} / {} -- Size: {}, Budget: {}", round + 1, bracket + 1, roundSize, roundBudget);

                // Ordered by ascending objective (see RunResult.compareTo), so poll() yields the best first.
                PriorityQueue<RunResult> results = new PriorityQueue<RunResult>();
                for (float[] config : configs) {
                    totalCost += effectiveBudget;
                    this.applyWeights(config);

                    String currentLocation = StringUtils.join(":", config);
                    log.trace("Weights: {}", currentLocation);

                    this.inMPEState = false;
                    double objective = this.run(config);
                    results.add(new RunResult(config, objective));
                    log.debug("Weights: {} -- objective: {}", currentLocation, objective);

                    if (bestWeights == null || isImprovement(objective, bestObjective)) {
                        bestObjective = objective;
                        bestWeights = config;
                    }
                    log.debug("Training Objective: {}, Weights: {}", objective, currentLocation);
                }

                // Successive halving: carry the best floor(roundSize / survival) configurations
                // forward. survival >= 1 (checked in the constructor) so this never exceeds the
                // number of results in the queue.
                configs.clear();
                int numSurvivors = roundSize / this.survival;
                for (int i = 0; i < numSurvivors; i++) {
                    configs.add(results.poll().weights);
                }
            }
        }

        if (bestWeights == null) {
            // Only possible when no bracket ran (numBrackets <= 0): every bracket evaluates >= 1 config.
            log.warn("Hyperband evaluated no configurations (brackets: {}); leaving rule weights unchanged.", this.numBrackets);
            return;
        }

        this.applyWeights(bestWeights);
        this.inMPEState = false;
        log.debug("Hyperband complete. Configurations examined: {}. Total budget: {}", numEvaluatedConfigs, totalCost);
    }

    /**
     * Returns true if {@code candidate} should replace {@code incumbent} as the best objective.
     * Lower is better. A NaN candidate never wins, and any non-NaN candidate beats a NaN incumbent.
     */
    private static boolean isImprovement(double candidate, double incumbent) {
        if (Double.isNaN(candidate)) {
            return false;
        }
        return Double.isNaN(incumbent) || candidate < incumbent;
    }

    /** Sets the i-th mutable rule's weight to {@code weights[i]}. */
    private void applyWeights(float[] weights) {
        for (int i = 0; i < this.mutableRules.size(); i++) {
            ((WeightedRule)this.mutableRules.get(i)).setWeight(weights[i]);
        }
    }

    /**
     * Samples {@code bracketSize} fresh weight configurations, one weight per mutable rule.
     * The arrays are filled by {@code WeightSampler.getRandomWeights} (presumably in place).
     */
    private List<float[]> chooseConfigs(int bracketSize) {
        List<float[]> configs = new ArrayList<float[]>(bracketSize);
        for (int i = 0; i < bracketSize; i++) {
            float[] config = new float[this.mutableRules.size()];
            this.weightSampler.getRandomWeights(config);
            configs.add(config);
        }
        return configs;
    }

    /**
     * Evaluates the currently installed weights and returns the objective to minimize.
     *
     * <p>The evaluator's normalized representative metric is negated, so that picking the lowest
     * objective in {@link #doLearn()} selects the configuration with the highest metric.
     * The {@code weights} argument is not read here; the caller installs them beforehand.
     */
    protected double run(float[] weights) {
        this.computeMPEState();
        this.evaluator.compute(this.trainingMap);
        return -1.0 * this.evaluator.getNormalizedRepMetric();
    }

    /**
     * A sampled configuration paired with its objective.
     *
     * <p>Natural order is ascending objective. {@code equals}/{@code hashCode} use the same
     * {@link Double#compare} semantics, so the three methods agree with each other.
     */
    private static final class RunResult
    implements Comparable<RunResult> {
        private final float[] weights;
        private final double objective;

        public RunResult(float[] weights, double objective) {
            this.weights = weights;
            this.objective = objective;
        }

        public float[] getWeights() {
            return this.weights;
        }

        @Override
        public int compareTo(RunResult other) {
            return Double.compare(this.objective, other.objective);
        }

        @Override
        public boolean equals(Object other) {
            if (this == other) {
                return true;
            }
            if (!(other instanceof RunResult)) {
                return false;
            }
            return Double.compare(this.objective, ((RunResult)other).objective) == 0;
        }

        @Override
        public int hashCode() {
            return Double.hashCode(this.objective);
        }
    }
}

