/*
 * Decompiled with CFR 0.152.
 */
package ai.mcts.naivemcts;

import ai.mcts.MCTSNode;
import ai.mcts.naivemcts.UnitActionTableEntry;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import rts.GameState;
import rts.PlayerAction;
import rts.PlayerActionGenerator;
import rts.ResourceUsage;
import rts.UnitAction;
import rts.units.Unit;
import util.Pair;
import util.Sampler;

/**
 * Search-tree node for the naive-sampling MCTS variant in this package.
 *
 * <p>Besides the statistics inherited from {@link MCTSNode}, every non-terminal node keeps one local
 * bandit table per unit that can act ({@link #unitActionTable}). New joint actions are built by sampling
 * each unit's action from its local table. Joint actions that were already sampled are re-selected with a
 * global policy: epsilon-greedy or UCB1. Children are indexed by a mixed-radix encoding of the joint action
 * (see {@link #multipliers}).
 *
 * <p>{@code type} is -1 for terminal nodes, 0 when {@code maxplayer} is to move and 1 when
 * {@code minplayer} is to move.
 */
public class NaiveMCTSNode
extends MCTSNode {
    /** {@code global_strategy} value selecting epsilon-greedy among already-sampled children. */
    public static final int E_GREEDY = 0;
    /** {@code global_strategy} value selecting UCB1 among already-sampled children. */
    public static final int UCB1 = 1;
    /** Verbosity level; at 3 or above the per-unit sampling distributions are printed. */
    public static int DEBUG = 0;
    /** Coefficient handed to {@link #selectFromAlreadySampledUCB1(float)} by {@link #selectLeaf}. */
    public static float C = 0.05f;
    /**
     * When true and a unit still has never-sampled actions, the already-sampled ones get zero weight,
     * so that unit samples only among the unsampled actions.
     */
    boolean forceExplorationOfNonSampledActions = true;
    /** Not read within this file. */
    boolean hasMoreActions = true;
    /** Generator of the acting player's per-unit choices; null for terminal nodes. */
    public PlayerActionGenerator moveGenerator;
    /** Children keyed by the encoded joint action that produced them (insertion-ordered). */
    HashMap<BigInteger, NaiveMCTSNode> childrenMap = new LinkedHashMap<>();
    /** One local bandit table per unit that can act; null for terminal nodes. */
    public List<UnitActionTableEntry> unitActionTable;
    /** Bound used to normalise evaluations in UCB1 (evaluations presumably lie in [-bound, bound]). */
    double evaluation_bound;
    /**
     * Mixed-radix place values: the joint action in which unit {@code k} picks action {@code c_k} is
     * encoded as {@code sum(c_k * multipliers[k])}.
     */
    public BigInteger[] multipliers;

    /**
     * Creates a node for {@code a_gs}, first advancing the state until the game ends or one of the two
     * players has a unit able to act.
     *
     * @param maxplayer          the maximising player
     * @param minplayer          the minimising player
     * @param a_gs               game state owned by this node (it is mutated by {@code cycle()})
     * @param a_parent           parent node, or null for the root
     * @param a_evaluation_bound bound used to normalise evaluations in UCB1
     * @param a_creation_ID      identifier stored in {@code creation_ID}
     * @param fensa              value for {@link #forceExplorationOfNonSampledActions}
     * @throws Exception propagated from {@link PlayerActionGenerator}
     */
    public NaiveMCTSNode(int maxplayer, int minplayer, GameState a_gs, NaiveMCTSNode a_parent, double a_evaluation_bound, int a_creation_ID, boolean fensa) throws Exception {
        this.parent = a_parent;
        this.gs = a_gs;
        this.depth = this.parent == null ? 0 : this.parent.depth + 1;
        this.evaluation_bound = a_evaluation_bound;
        this.creation_ID = a_creation_ID;
        this.forceExplorationOfNonSampledActions = fensa;
        // Skip cycles in which nobody can issue an action.
        while (this.gs.winner() == -1 && !this.gs.gameover()
                && !this.gs.canExecuteAnyAction(maxplayer) && !this.gs.canExecuteAnyAction(minplayer)) {
            this.gs.cycle();
        }
        if (this.gs.winner() != -1 || this.gs.gameover()) {
            this.type = -1;
        } else if (this.gs.canExecuteAnyAction(maxplayer)) {
            this.type = 0;
            this.initializeUnitActionTable(maxplayer);
        } else if (this.gs.canExecuteAnyAction(minplayer)) {
            this.type = 1;
            this.initializeUnitActionTable(minplayer);
        } else {
            // Unreachable given the loop above; kept as a defensive fallback.
            this.type = -1;
            System.err.println("NaiveMCTSNode: This should not have happened...");
        }
    }

    /** Builds the move generator, empty child lists, per-unit bandit tables and multipliers for {@code player}. */
    private void initializeUnitActionTable(int player) throws Exception {
        this.moveGenerator = new PlayerActionGenerator(this.gs, player);
        this.actions = new ArrayList<>();
        this.children = new ArrayList<>();
        this.unitActionTable = new LinkedList<UnitActionTableEntry>();
        this.multipliers = new BigInteger[this.moveGenerator.getChoices().size()];
        BigInteger baseMultiplier = BigInteger.ONE;
        int idx = 0;
        for (Pair<Unit, List<UnitAction>> choice : this.moveGenerator.getChoices()) {
            UnitActionTableEntry ae = new UnitActionTableEntry();
            ae.u = choice.m_a;
            ae.nactions = choice.m_b.size();
            ae.actions = choice.m_b;
            // Java zero-initialises new arrays, so all statistics start at 0.
            ae.accum_evaluation = new double[ae.nactions];
            ae.visit_count = new int[ae.nactions];
            this.unitActionTable.add(ae);
            this.multipliers[idx++] = baseMultiplier;
            baseMultiplier = baseMultiplier.multiply(BigInteger.valueOf(ae.nactions));
        }
    }

    /**
     * Descends from this node to the node that should be expanded or evaluated next.
     *
     * <p>Terminal nodes and nodes at {@code max_depth} are returned as-is. Otherwise, with probability
     * {@code 1 - epsilon_0} (and only if there are children), an already-sampled child is chosen with the
     * global strategy and the descent continues there. Otherwise a joint action is sampled from the local
     * per-unit tables.
     *
     * @throws IllegalArgumentException if a global choice is needed and {@code global_strategy} is neither
     *                                  {@link #E_GREEDY} nor {@link #UCB1}
     */
    public NaiveMCTSNode selectLeaf(int maxplayer, int minplayer, float epsilon_l, float epsilon_g, float epsilon_0, int global_strategy, int max_depth, int a_creation_ID) throws Exception {
        if (this.unitActionTable == null || this.depth >= max_depth) {
            return this;
        }
        if (this.children.size() > 0 && r.nextFloat() >= epsilon_0) {
            NaiveMCTSNode selected;
            if (global_strategy == E_GREEDY) {
                selected = this.selectFromAlreadySampledEpsilonGreedy(epsilon_g);
            } else if (global_strategy == UCB1) {
                selected = this.selectFromAlreadySampledUCB1(C);
            } else {
                // Previously fell through with selected == null and failed with a NullPointerException.
                throw new IllegalArgumentException("Unknown global_strategy: " + global_strategy);
            }
            return selected.selectLeaf(maxplayer, minplayer, epsilon_l, epsilon_g, epsilon_0, global_strategy, max_depth, a_creation_ID);
        }
        return this.selectLeafUsingLocalMABs(maxplayer, minplayer, epsilon_l, epsilon_g, epsilon_0, global_strategy, max_depth, a_creation_ID);
    }

    /**
     * Epsilon-greedy over existing children.
     *
     * <p>With probability {@code 1 - epsilon_g} it returns the child with the best average evaluation
     * (highest for max nodes, lowest for min nodes; the first one wins ties). Otherwise it returns a
     * uniformly random child. Assumes {@code children} is non-empty.
     */
    public NaiveMCTSNode selectFromAlreadySampledEpsilonGreedy(float epsilon_g) throws Exception {
        if (r.nextFloat() >= epsilon_g) {
            NaiveMCTSNode best = null;
            for (MCTSNode pate : this.children) {
                if (best == null || this.prefers(averageEvaluation(pate), averageEvaluation(best))) {
                    best = (NaiveMCTSNode)pate;
                }
            }
            return best;
        }
        return (NaiveMCTSNode)this.children.get(r.nextInt(this.children.size()));
    }

    /**
     * UCB1-style selection over existing children.
     *
     * <p>The score is {@code C * normalisedExploitation + sqrt(ln(N) / n)}. Note that {@code C} scales the
     * exploitation term here, not the exploration term. Ties keep the first child.
     */
    public NaiveMCTSNode selectFromAlreadySampledUCB1(float C) throws Exception {
        NaiveMCTSNode best = null;
        double bestScore = 0.0;
        for (MCTSNode pate : this.children) {
            double exploitation = averageEvaluation(pate);
            // Map into [0, 1] from the perspective of the player to move
            // (assumes |evaluation| <= evaluation_bound).
            exploitation = this.type == 0
                    ? (this.evaluation_bound + exploitation) / (2.0 * this.evaluation_bound)
                    : (this.evaluation_bound - exploitation) / (2.0 * this.evaluation_bound);
            double exploration = Math.sqrt(Math.log(this.visit_count) / (double)pate.visit_count);
            double score = (double)C * exploitation + exploration;
            if (best == null || score > bestScore) {
                best = (NaiveMCTSNode)pate;
                bestScore = score;
            }
        }
        return best;
    }

    /**
     * Samples a joint action unit by unit from the local bandit tables and returns the matching child.
     *
     * <p>Units are processed in random order. If a sampled action is not resource-consistent with the
     * actions already chosen, the unit is resampled without replacement until a consistent action is found.
     * An existing child for the same encoded joint action is descended into; otherwise a new child is
     * created and returned.
     */
    public NaiveMCTSNode selectLeafUsingLocalMABs(int maxplayer, int minplayer, float epsilon_l, float epsilon_g, float epsilon_0, int global_strategy, int max_depth, int a_creation_ID) throws Exception {
        // Random-access snapshots: unitActionTable is a LinkedList, and get(n) below is called once per unit.
        List<UnitActionTableEntry> entries = new ArrayList<>(this.unitActionTable);
        List<double[]> distributions = new ArrayList<>(entries.size());
        List<Integer> notSampledYet = new ArrayList<>(entries.size());
        for (UnitActionTableEntry ate : entries) {
            double[] dist = this.localDistribution(ate, epsilon_l);
            if (DEBUG >= 3) {
                printLocalDistribution(ate, dist);
            }
            notSampledYet.add(distributions.size());
            distributions.add(dist);
        }
        PlayerAction pa2 = new PlayerAction();
        BigInteger actionCode = BigInteger.ZERO;
        pa2.setResourceUsage(this.baseResourceUsage().clone());
        while (!notSampledYet.isEmpty()) {
            int n = notSampledYet.remove(r.nextInt(notSampledYet.size()));
            try {
                UnitActionTableEntry ate = entries.get(n);
                double[] distribution = distributions.get(n);
                int code = Sampler.weighted(distribution);
                UnitAction ua = ate.actions.get(code);
                ResourceUsage r2 = ua.resourceUsage(ate.u, this.gs.getPhysicalGameState());
                if (!pa2.getResourceUsage().consistentWith(r2, this.gs)) {
                    // Resample without replacement among the remaining actions until one fits.
                    ArrayList<Double> dist_l = new ArrayList<>(distribution.length);
                    ArrayList<Integer> dist_outputs = new ArrayList<>(distribution.length);
                    for (int j = 0; j < distribution.length; ++j) {
                        dist_l.add(distribution[j]);
                        dist_outputs.add(j);
                    }
                    do {
                        int idx = dist_outputs.indexOf(code);
                        dist_l.remove(idx);
                        dist_outputs.remove(idx);
                        code = (Integer)Sampler.weighted(dist_l, dist_outputs);
                        ua = ate.actions.get(code);
                        r2 = ua.resourceUsage(ate.u, this.gs.getPhysicalGameState());
                    } while (!pa2.getResourceUsage().consistentWith(r2, this.gs));
                }
                if (this.gs.getUnit(ate.u.getID()) == null) {
                    throw new Error("Issuing an action to an inexisting unit!!!");
                }
                pa2.getResourceUsage().merge(r2);
                pa2.addUnitAction(ate.u, ua);
                actionCode = actionCode.add(BigInteger.valueOf(code).multiply(this.multipliers[n]));
            }
            catch (Exception e) {
                // Best effort: this unit is left without an action.
                // NOTE(review): that unit then contributes 0 to actionCode, the same as choosing its
                // action 0, so the childrenMap lookup below may alias two different PlayerActions.
                e.printStackTrace();
            }
        }
        NaiveMCTSNode existing = this.childrenMap.get(actionCode);
        if (existing != null) {
            return existing.selectLeaf(maxplayer, minplayer, epsilon_l, epsilon_g, epsilon_0, global_strategy, max_depth, a_creation_ID);
        }
        GameState gs2 = this.gs.cloneIssue(pa2);
        NaiveMCTSNode node = new NaiveMCTSNode(maxplayer, minplayer, gs2.clone(), this, this.evaluation_bound, a_creation_ID, this.forceExplorationOfNonSampledActions);
        // Record the action only once the child exists, so actions and children stay index-aligned even if
        // the constructor throws (propagateEvaluation relies on that alignment).
        this.actions.add(pa2);
        this.childrenMap.put(actionCode, node);
        this.children.add(node);
        return node;
    }

    /**
     * Builds one unit's sampling distribution.
     *
     * <p>Each action gets weight {@code epsilon_l / nactions}. If every action has been tried, the best
     * average gains an extra {@code 1 - epsilon_l}. Otherwise (some action untried) and when
     * {@link #forceExplorationOfNonSampledActions} is set, already-tried actions get weight 0.
     */
    private double[] localDistribution(UnitActionTableEntry ate, float epsilon_l) {
        double[] dist = new double[ate.nactions];
        int bestIdx = -1;
        double bestEvaluation = 0.0;
        int visits = 0;
        for (int i = 0; i < ate.nactions; ++i) {
            // Once the incumbent is untried (visits == 0) it is never replaced; an untried action
            // replaces a tried incumbent.
            if (bestIdx == -1
                    || (visits != 0 && ate.visit_count[i] == 0)
                    || (visits != 0 && this.prefers(ate.accum_evaluation[i] / (double)ate.visit_count[i], bestEvaluation))) {
                bestIdx = i;
                bestEvaluation = ate.visit_count[i] > 0 ? ate.accum_evaluation[i] / (double)ate.visit_count[i] : 0.0;
                visits = ate.visit_count[i];
            }
            dist[i] = epsilon_l / (float)ate.nactions;
        }
        if (ate.visit_count[bestIdx] != 0) {
            dist[bestIdx] = 1.0f - epsilon_l + epsilon_l / (float)ate.nactions;
        } else if (this.forceExplorationOfNonSampledActions) {
            for (int j = 0; j < dist.length; ++j) {
                if (ate.visit_count[j] > 0) {
                    dist[j] = 0.0;
                }
            }
        }
        return dist;
    }

    /** Debug dump of one unit's (visits, average) pairs followed by its sampling weights. */
    private static void printLocalDistribution(UnitActionTableEntry ate, double[] dist) {
        System.out.print("[ ");
        for (int i = 0; i < ate.nactions; ++i) {
            System.out.print("(" + ate.visit_count[i] + "," + ate.accum_evaluation[i] / (double)ate.visit_count[i] + ")");
        }
        System.out.println("]");
        System.out.print("[ ");
        for (double v : dist) {
            System.out.print(v + " ");
        }
        System.out.println("]");
    }

    /** Resource usage of the actions that units in {@code gs} are already executing. */
    private ResourceUsage baseResourceUsage() {
        ResourceUsage base_ru = new ResourceUsage();
        for (Unit u : this.gs.getUnits()) {
            UnitAction ua = this.gs.getUnitAction(u);
            if (ua != null) {
                base_ru.merge(ua.resourceUsage(u, this.gs.getPhysicalGameState()));
            }
        }
        return base_ru;
    }

    /** True if {@code candidate} beats {@code incumbent} for the player to move (greater for max, smaller for min). */
    private boolean prefers(double candidate, double incumbent) {
        return this.type == 0 ? candidate > incumbent : candidate < incumbent;
    }

    /** Average evaluation of {@code node}; NaN or infinite when it has no visits. */
    private static double averageEvaluation(MCTSNode node) {
        return node.accum_evaluation / (double)node.visit_count;
    }

    /**
     * Returns the action-table entry for {@code u}, compared by reference identity.
     *
     * @throws Error if no entry refers to {@code u}
     */
    public UnitActionTableEntry getActionTableEntry(Unit u) {
        for (UnitActionTableEntry e : this.unitActionTable) {
            if (e.u == u) {
                return e;
            }
        }
        throw new Error("Could not find Action Table Entry!");
    }

    /**
     * Adds {@code evaluation} to this node and every ancestor.
     *
     * <p>When {@code child} is given, the local statistics of each unit action in the joint action that led to
     * {@code child} are updated as well.
     *
     * @throws IllegalStateException if a unit action of that joint action is missing from its unit's table
     */
    public void propagateEvaluation(double evaluation, NaiveMCTSNode child) {
        this.accum_evaluation += evaluation;
        ++this.visit_count;
        if (child != null) {
            // actions and children are index-aligned (see selectLeafUsingLocalMABs).
            PlayerAction pa = (PlayerAction)this.actions.get(this.children.indexOf(child));
            for (Pair<Unit, UnitAction> ua : pa.getActions()) {
                UnitActionTableEntry actionTable = this.getActionTableEntry((Unit)ua.m_a);
                int idx = actionTable.actions.indexOf(ua.m_b);
                if (idx == -1) {
                    // Previously printed this and then failed with ArrayIndexOutOfBoundsException(-1).
                    throw new IllegalStateException("Looking for action: " + ua.m_b + "; available actions are: " + actionTable.actions);
                }
                actionTable.accum_evaluation[idx] += evaluation;
                ++actionTable.visit_count[idx];
            }
        }
        if (this.parent != null) {
            ((NaiveMCTSNode)this.parent).propagateEvaluation(evaluation, this);
        }
    }

    /** Prints every unit's actions with visit counts and average evaluations (NaN for unvisited actions). */
    public void printUnitActionTable() {
        for (UnitActionTableEntry uat : this.unitActionTable) {
            System.out.println("Actions for unit " + uat.u);
            for (int i = 0; i < uat.nactions; ++i) {
                System.out.println("   " + uat.actions.get(i) + " visited " + uat.visit_count[i] + " with average evaluation " + uat.accum_evaluation[i] / (double)uat.visit_count[i]);
            }
        }
    }
}

