/*
 * Decompiled with CFR 0.152.
 */
package ai.mcts.uct;

import ai.RandomBiasedAI;
import ai.core.AI;
import ai.core.AIWithComputationBudget;
import ai.core.InterruptibleAI;
import ai.core.ParameterSpecification;
import ai.evaluation.EvaluationFunction;
import ai.evaluation.SimpleSqrtEvaluationFunction3;
import ai.mcts.uct.UCTNode;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import rts.GameState;
import rts.PlayerAction;
import rts.units.UnitTypeTable;

/**
 * Monte Carlo Tree Search AI that uses the UCT (Upper Confidence bounds applied to Trees)
 * tree policy, for two-player games whose players are indexed 0 and 1.
 *
 * <p>Each search iteration ({@link #monteCarloRun(int, long)}) does the following:
 * <ol>
 *   <li>Selects a leaf with {@code UCTNode.UCTSelectLeaf}.</li>
 *   <li>Plays the game forward from that leaf with the default policy for both players.</li>
 *   <li>Scores the reached state with the evaluation function, discounted by elapsed game
 *       time.</li>
 *   <li>Adds that score to every node on the path back to the root.</li>
 * </ol>
 * The action finally returned is the one leading to the most visited root child.
 *
 * <p>Two limits bound each call to {@link #computeDuringOneGameFrame()}:
 * {@code TIME_BUDGET} (milliseconds) and {@code ITERATIONS_BUDGET} (playouts). A
 * non-positive value disables the corresponding limit.
 */
public class UCT extends AIWithComputationBudget implements InterruptibleAI {
    /** Console verbosity: 0 = silent, >= 1 per-move summary, >= 2 search trace and tree dump. */
    public static int DEBUG = 0;
    /** Scores the state reached at the end of each playout. */
    EvaluationFunction ef;
    /** Not read by the code in this class. */
    Random r = new Random();
    /** Default (playout) policy; chooses the actions of both players during simulation. */
    AI randomAI = new RandomBiasedAI();
    /** Not read by the code in this class. */
    long max_actions_so_far = 0L;
    /** Root state of the current search; playout time discounts are measured from it. */
    GameState gs_to_start_from;
    /** Root of the current search tree, or {@code null} when no search has been started. */
    public UCTNode tree;
    /** Playouts performed over the lifetime of this instance (not cleared by {@link #reset()}). */
    public long total_runs = 0L;
    /** Number of calls to {@link #computeDuringOneGameFrame()}. */
    public long total_cycles_executed = 0L;
    /** Number of calls to {@link #getBestActionSoFar()}. */
    public long total_actions_issued = 0L;
    /** Playouts performed since the current search was started. */
    long total_runs_this_move = 0L;
    /** Playout lookahead: maximum number of game cycles simulated past a selected leaf. */
    int MAXSIMULATIONTIME = 1024;
    /** Maximum depth handed to the tree policy when selecting a leaf. */
    int MAX_TREE_DEPTH = 10;
    /** Player the current search is computing an action for. */
    int playerForThisComputation;

    /**
     * Creates a UCT AI with default settings.
     *
     * <p>The defaults are: 100 ms per frame, no playout limit, lookahead of 100 cycles,
     * tree depth 10, a {@link RandomBiasedAI} playout policy and
     * {@link SimpleSqrtEvaluationFunction3}.
     *
     * @param utt unit type table; not used by this constructor
     */
    public UCT(UnitTypeTable utt) {
        this(100, -1, 100, 10, new RandomBiasedAI(), new SimpleSqrtEvaluationFunction3());
    }

    /**
     * Creates a fully configured UCT AI.
     *
     * @param available_time time budget per frame in milliseconds (non-positive = unlimited)
     * @param max_playouts   playout budget per frame (non-positive = unlimited)
     * @param lookahead      maximum number of game cycles simulated in each playout
     * @param max_depth      maximum tree depth used when selecting leaves
     * @param policy         default policy used for both players during playouts
     * @param a_ef           evaluation function applied to playout end states
     */
    public UCT(int available_time, int max_playouts, int lookahead, int max_depth, AI policy, EvaluationFunction a_ef) {
        super(available_time, max_playouts);
        this.MAXSIMULATIONTIME = lookahead;
        this.randomAI = policy;
        this.MAX_TREE_DEPTH = max_depth;
        this.ef = a_ef;
    }

    /**
     * Returns the average number of playouts per frame and per issued action.
     *
     * <p>The values are {@code NaN} or {@code Infinity} if no frames or actions have
     * been counted yet.
     */
    @Override
    public String statisticsString() {
        return "Average runs per cycle: " + (double)this.total_runs / (double)this.total_cycles_executed + ", Average runs per action: " + (double)this.total_runs / (double)this.total_actions_issued;
    }

    /** Prints the playout averages, but only once both counters are non-zero. */
    @Override
    public void printStats() {
        if (this.total_cycles_executed > 0L && this.total_actions_issued > 0L) {
            System.out.println("Average runs per cycle: " + (double)this.total_runs / (double)this.total_cycles_executed);
            System.out.println("Average runs per action: " + (double)this.total_runs / (double)this.total_actions_issued);
        }
    }

    /** Discards the current search tree; lifetime statistics are kept. */
    @Override
    public void reset() {
        this.gs_to_start_from = null;
        this.tree = null;
        this.total_runs_this_move = 0L;
    }

    /**
     * Returns a new UCT with the same budgets, lookahead and depth.
     *
     * <p>Note that the playout policy and evaluation function instances are shared with
     * this object, not copied.
     */
    @Override
    public AI clone() {
        return new UCT(this.TIME_BUDGET, this.ITERATIONS_BUDGET, this.MAXSIMULATIONTIME, this.MAX_TREE_DEPTH, this.randomAI, this.ef);
    }

    /**
     * Runs one full search from a copy of {@code gs} and returns the chosen action.
     *
     * <p>If {@code player} cannot act in {@code gs}, an empty action is returned without
     * searching.
     */
    @Override
    public PlayerAction getAction(int player, GameState gs) throws Exception {
        if (!gs.canExecuteAnyAction(player)) {
            return new PlayerAction();
        }
        this.startNewComputation(player, gs.clone());
        this.computeDuringOneGameFrame();
        return this.getBestActionSoFar();
    }

    /**
     * Starts a fresh search rooted at {@code gs} for {@code a_player}.
     *
     * <p>The opponent is taken to be {@code 1 - a_player}. The state is stored as-is
     * (not cloned).
     */
    @Override
    public void startNewComputation(int a_player, GameState gs) throws Exception {
        float evaluation_bound = this.ef.upperBound(gs);
        this.playerForThisComputation = a_player;
        this.tree = new UCTNode(this.playerForThisComputation, 1 - this.playerForThisComputation, gs, null, evaluation_bound);
        this.gs_to_start_from = gs;
        this.total_runs_this_move = 0L;
    }

    /** Discards the current search tree (same effect as {@link #reset()}, with a debug trace). */
    public void resetSearch() {
        if (DEBUG >= 2) {
            System.out.println("Resetting search...");
        }
        this.tree = null;
        this.gs_to_start_from = null;
        this.total_runs_this_move = 0L;
    }

    /**
     * Performs playouts on the current tree until the time or playout budget is exhausted.
     *
     * <p>Requires a prior call to {@link #startNewComputation(int, GameState)}.
     * NOTE(review): if both {@code TIME_BUDGET} and {@code ITERATIONS_BUDGET} are
     * non-positive, this loop has no exit condition.
     */
    @Override
    public void computeDuringOneGameFrame() throws Exception {
        if (DEBUG >= 2) {
            System.out.println("Search...");
        }
        long start = System.currentTimeMillis();
        // A cut-off of 0 means "no time limit".
        long cutOffTime = this.TIME_BUDGET > 0 ? start + (long)this.TIME_BUDGET : 0L;
        int nPlayouts = 0;
        while (true) {
            if (cutOffTime > 0L && System.currentTimeMillis() > cutOffTime) {
                break;
            }
            // ">=" so that exactly ITERATIONS_BUDGET playouts run; ">" would run one extra.
            if (this.ITERATIONS_BUDGET > 0 && nPlayouts >= this.ITERATIONS_BUDGET) {
                break;
            }
            this.monteCarloRun(this.playerForThisComputation, cutOffTime);
            ++nPlayouts;
        }
        ++this.total_cycles_executed;
    }

    /**
     * Performs one select / simulate / evaluate / backpropagate iteration.
     *
     * @param player     player whose point of view the evaluation takes (opponent is {@code 1 - player})
     * @param cutOffTime wall-clock deadline forwarded to the tree policy (0 = none)
     * @return the discounted evaluation that was backpropagated, or 0.0 if no leaf was selected
     */
    public double monteCarloRun(int player, long cutOffTime) throws Exception {
        UCTNode leaf = this.tree.UCTSelectLeaf(player, 1 - player, cutOffTime, this.MAX_TREE_DEPTH);
        if (leaf == null) {
            System.err.println(this.getClass().getSimpleName() + ": claims there are no more leafs to explore...");
            return 0.0;
        }
        GameState gs2 = leaf.gs.clone();
        this.simulate(gs2, gs2.getTime() + this.MAXSIMULATIONTIME);
        // Discount by game time elapsed since the search root: 0.99 per 10 cycles.
        int time = gs2.getTime() - this.gs_to_start_from.getTime();
        double evaluation = (double)this.ef.evaluate(player, 1 - player, gs2) * Math.pow(0.99, (double)time / 10.0);
        for (UCTNode node = leaf; node != null; node = node.parent) {
            node.accum_evaluation = (float)((double)node.accum_evaluation + evaluation);
            ++node.visit_count;
        }
        ++this.total_runs;
        ++this.total_runs_this_move;
        return evaluation;
    }

    /**
     * Returns the action leading to the most visited root child.
     *
     * <p>Ties in visit count are broken in favour of the higher accumulated evaluation;
     * among exact ties the earliest child wins. If there is no search tree, or the root
     * has no children, an empty action is returned. Each call increments
     * {@code total_actions_issued}.
     */
    @Override
    public PlayerAction getBestActionSoFar() {
        ++this.total_actions_issued;
        // Also covers an empty child list, which previously crashed the DEBUG >= 1
        // summary below (actions.get(-1) / null mostVisited).
        if (this.tree == null || this.tree.children == null || this.tree.children.size() == 0) {
            if (DEBUG >= 1) {
                System.out.println(this.getClass().getSimpleName() + " no children selected. Returning an empty asction");
            }
            return new PlayerAction();
        }
        int mostVisitedIdx = -1;
        UCTNode mostVisited = null;
        for (int i = 0; i < this.tree.children.size(); ++i) {
            UCTNode child = this.tree.children.get(i);
            boolean better = mostVisited == null
                    || child.visit_count > mostVisited.visit_count
                    || (child.visit_count == mostVisited.visit_count
                        && child.accum_evaluation > mostVisited.accum_evaluation);
            if (better) {
                mostVisited = child;
                mostVisitedIdx = i;
            }
        }
        if (DEBUG >= 2) {
            this.tree.showNode(0, 1);
        }
        if (DEBUG >= 1) {
            System.out.println(this.getClass().getSimpleName() + " performed " + this.total_runs_this_move + " playouts.");
            System.out.println(this.getClass().getSimpleName() + " selected children " + this.tree.actions.get(mostVisitedIdx) + " explored " + mostVisited.visit_count + " Avg evaluation: " + (double)mostVisited.accum_evaluation / (double)mostVisited.visit_count);
        }
        return this.tree.actions.get(mostVisitedIdx);
    }

    /**
     * Estimates the value of the current best action by averaging {@code N} playouts.
     *
     * <p>Each playout starts from {@code gs} with that action issued.
     *
     * <p>Side effect: calls {@link #getBestActionSoFar()}, which increments
     * {@code total_actions_issued}.
     *
     * @param gs     state in which to issue the action
     * @param player evaluating player (opponent is {@code 1 - player})
     * @param N      number of playouts to average
     * @return the mean discounted evaluation, or 0 if there is no action or {@code N <= 0}
     */
    public float getBestActionEvaluation(GameState gs, int player, int N) throws Exception {
        PlayerAction pa = this.getBestActionSoFar();
        // Returning 0 for N <= 0 avoids the 0/0 = NaN the average would otherwise produce.
        if (pa == null || N <= 0) {
            return 0.0f;
        }
        float accum = 0.0f;
        for (int i = 0; i < N; ++i) {
            GameState gs2 = gs.cloneIssue(pa);
            GameState gs3 = gs2.clone();
            this.simulate(gs3, gs3.getTime() + this.MAXSIMULATIONTIME);
            int time = gs3.getTime() - gs2.getTime();
            accum += (float)((double)this.ef.evaluate(player, 1 - player, gs3) * Math.pow(0.99, (double)time / 10.0));
        }
        return accum / (float)N;
    }

    /**
     * Advances {@code gs} in place, using the default policy for players 0 and 1.
     *
     * <p>Simulation stops when the game ends or the game time reaches {@code time}.
     * Whenever the state is not complete, both players first issue actions; then the
     * state is cycled.
     */
    public void simulate(GameState gs, int time) throws Exception {
        boolean gameover = false;
        do {
            if (gs.isComplete()) {
                gameover = gs.cycle();
            } else {
                gs.issue(this.randomAI.getAction(0, gs));
                gs.issue(this.randomAI.getAction(1, gs));
            }
        } while (!gameover && gs.getTime() < time);
    }

    @Override
    public String toString() {
        return this.getClass().getSimpleName() + "(" + this.TIME_BUDGET + ", " + this.ITERATIONS_BUDGET + ", " + this.MAXSIMULATIONTIME + ", " + this.MAX_TREE_DEPTH + ", " + this.randomAI + ", " + this.ef + ")";
    }

    /**
     * Lists the tunable parameters.
     *
     * <p>The defaults match {@link #UCT(UnitTypeTable)}, except that the default policy
     * is reported as this instance's current policy.
     */
    @Override
    public List<ParameterSpecification> getParameters() {
        List<ParameterSpecification> parameters = new ArrayList<ParameterSpecification>();
        parameters.add(new ParameterSpecification("TimeBudget", Integer.TYPE, 100));
        parameters.add(new ParameterSpecification("IterationsBudget", Integer.TYPE, -1));
        parameters.add(new ParameterSpecification("PlayoutLookahead", Integer.TYPE, 100));
        parameters.add(new ParameterSpecification("MaxTreeDepth", Integer.TYPE, 10));
        parameters.add(new ParameterSpecification("DefaultPolicy", AI.class, this.randomAI));
        parameters.add(new ParameterSpecification("EvaluationFunction", EvaluationFunction.class, new SimpleSqrtEvaluationFunction3()));
        return parameters;
    }

    /** @return maximum number of game cycles simulated per playout */
    public int getPlayoutLookahead() {
        return this.MAXSIMULATIONTIME;
    }

    /** @param a_pola maximum number of game cycles simulated per playout */
    public void setPlayoutLookahead(int a_pola) {
        this.MAXSIMULATIONTIME = a_pola;
    }

    /** @return maximum tree depth used when selecting leaves */
    public int getMaxTreeDepth() {
        return this.MAX_TREE_DEPTH;
    }

    /** @param a_mtd maximum tree depth used when selecting leaves */
    public void setMaxTreeDepth(int a_mtd) {
        this.MAX_TREE_DEPTH = a_mtd;
    }

    /** @return policy used for both players during playouts */
    public AI getDefaultPolicy() {
        return this.randomAI;
    }

    /** @param a_dp policy used for both players during playouts */
    public void setDefaultPolicy(AI a_dp) {
        this.randomAI = a_dp;
    }

    /** @return evaluation function applied to playout end states */
    public EvaluationFunction getEvaluationFunction() {
        return this.ef;
    }

    /** @param a_ef evaluation function applied to playout end states */
    public void setEvaluationFunction(EvaluationFunction a_ef) {
        this.ef = a_ef;
    }
}

