/*
 * Decompiled with CFR 0.152.
 */
package ai.mcts.mlps;

import ai.mcts.MCTSNode;
import ai.montecarlo.lsi.Sampling;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import rts.GameState;
import rts.PlayerAction;
import rts.PlayerActionGenerator;
import rts.ResourceUsage;
import rts.UnitAction;
import rts.units.Unit;
import util.Pair;

/**
 * Search-tree node that keeps, for every unit that can act in its game state, a table of
 * per-action statistics (visit counts and accumulated evaluations) and uses those tables to
 * sample joint {@link PlayerAction}s one unit at a time.
 *
 * <p>The node {@code type} is set by the constructor: {@code 0} when {@code maxplayer} acts,
 * {@code 1} when {@code minplayer} acts and {@code -1} for a finished game (or when neither
 * player can act). Nodes of type {@code -1} have no action table and are returned unchanged by
 * {@link #selectLeaf}.
 */
public class MLPSNode extends MCTSNode {
    /** Verbosity of the diagnostic output written to {@code System.out} (0 = silent). */
    public static int DEBUG = 0;

    /** Number of joint actions sampled per {@link #selectLeaf} call; the best-scoring one is kept. */
    private static final int SAMPLING_REPEATS = 10;

    boolean hasMoreActions = true;
    /** Generator that provides the per-unit action choices of the player acting in this node. */
    public PlayerActionGenerator moveGenerator;
    /** Children indexed by the numeric code of the joint action that leads to them. */
    HashMap<Long, MLPSNode> childrenMap = new LinkedHashMap<Long, MLPSNode>();
    /** One entry per unit with a choice, in the order returned by the move generator. */
    public List<Sampling.UnitActionTableEntry> unitActionTable;
    /** Scratch arrays (parallel to {@link #unitActionTable}) refreshed on every selection. */
    public List<double[]> UCBExplorationScores;
    /** Scratch arrays (parallel to {@link #unitActionTable}) refreshed on every selection. */
    public List<double[]> UCBExploitationScores;
    /** Bound received at construction; stored here and handed on to child nodes. */
    double evaluation_bound = 0.0;
    /** Largest number of actions available to any single unit in this node. */
    int max_nactions = 0;
    /**
     * Mixed-radix weights: {@code multipliers[i]} is the product of the action counts of units
     * {@code 0..i-1}, so that the sum of {@code code_i * multipliers[i]} identifies a joint action.
     * NOTE(review): the running product is a plain {@code long} and can overflow with many units.
     */
    public long[] multipliers;

    /**
     * Creates a node for {@code a_gs}, first advancing the game state until the game ends or one
     * of the two players can issue an action, and then building the action table of the player
     * that acts (maxplayer takes precedence).
     *
     * @param maxplayer     id of the maximizing player
     * @param minplayer     id of the minimizing player
     * @param a_gs          game state owned by this node; it is cycled in place
     * @param a_parent      parent node, or {@code null} for the root
     * @param bound         evaluation bound stored in the node
     * @param a_creation_ID identifier recorded as the node's creation id
     * @throws Exception if the underlying game-state or generator calls fail
     */
    public MLPSNode(int maxplayer, int minplayer, GameState a_gs, MLPSNode a_parent, double bound, int a_creation_ID) throws Exception {
        this.parent = a_parent;
        this.gs = a_gs;
        this.depth = this.parent == null ? 0 : this.parent.depth + 1;
        this.evaluation_bound = bound;
        this.creation_ID = a_creation_ID;

        // Fast-forward until somebody can act or the game is over.
        while (this.gs.winner() == -1
                && !this.gs.gameover()
                && !this.gs.canExecuteAnyAction(maxplayer)
                && !this.gs.canExecuteAnyAction(minplayer)) {
            this.gs.cycle();
        }

        if (this.gs.winner() != -1 || this.gs.gameover()) {
            this.type = -1;
        } else if (this.gs.canExecuteAnyAction(maxplayer)) {
            this.type = 0;
            this.initActionTable(maxplayer);
        } else if (this.gs.canExecuteAnyAction(minplayer)) {
            this.type = 1;
            this.initActionTable(minplayer);
        } else {
            this.type = -1;
            System.err.println("MLPSNode: This should not have happened...");
        }
    }

    /**
     * Builds the move generator, the per-unit statistics table, the score scratch arrays and the
     * action-code multipliers for {@code player} on this node's game state.
     */
    private void initActionTable(int player) throws Exception {
        this.moveGenerator = new PlayerActionGenerator(this.gs, player);
        this.actions = new ArrayList();
        this.children = new ArrayList();
        this.unitActionTable = new ArrayList<Sampling.UnitActionTableEntry>();
        this.UCBExplorationScores = new ArrayList<double[]>();
        this.UCBExploitationScores = new ArrayList<double[]>();

        List<Pair<Unit, List<UnitAction>>> choices = this.moveGenerator.getChoices();
        this.multipliers = new long[choices.size()];
        long baseMultiplier = 1L;
        int idx = 0;
        for (Pair<Unit, List<UnitAction>> choice : choices) {
            Sampling.UnitActionTableEntry ae = new Sampling.UnitActionTableEntry();
            ae.u = choice.m_a;
            ae.actions = choice.m_b;
            ae.nactions = ae.actions.size();
            // Freshly allocated arrays are already zero-filled.
            ae.accum_evaluation = new double[ae.nactions];
            ae.visit_count = new int[ae.nactions];
            if (ae.nactions > this.max_nactions) {
                this.max_nactions = ae.nactions;
            }
            this.unitActionTable.add(ae);
            this.UCBExplorationScores.add(new double[ae.nactions]);
            this.UCBExploitationScores.add(new double[ae.nactions]);
            this.multipliers[idx] = baseMultiplier;
            baseMultiplier *= (long) ae.nactions;
            ++idx;
        }
    }

    /**
     * Returns the average evaluation recorded for {@code action} of entry {@code e}, or
     * {@code 0.0} when that action has never been visited.
     */
    public double actionExploitationValue(Sampling.UnitActionTableEntry e, int action) {
        if (e.visit_count[action] == 0) {
            return 0.0;
        }
        return e.accum_evaluation[action] / (double) e.visit_count[action];
    }

    /**
     * Computes the exploration term {@code M * sqrt((M + 1) * ln(n) / n_ij)}.
     *
     * @param M    scaling factor (callers in this class pass the node's {@code max_nactions})
     * @param n    visit count of the node
     * @param n_ij visit count of the individual unit action
     * @return {@link Double#MAX_VALUE} for an unvisited action, otherwise the term above
     */
    public double explorationValue(int M, int n, int n_ij) {
        if (n_ij == 0) {
            return Double.MAX_VALUE;
        }
        return (double) M * Math.sqrt((double) (M + 1) * Math.log(n) / (double) n_ij);
    }

    /**
     * Descends the tree from this node and returns the node to evaluate next.
     *
     * <p>At every level a fixed number of joint actions is sampled: units are visited in random
     * order and each receives its best-scoring action that is resource-consistent with the
     * actions already chosen. The best-scoring sample is followed; if it has no child yet, a new
     * child is created, registered and returned, otherwise selection recurses into that child.
     * Nodes without an action table, or at depth {@code max_depth} or deeper, are returned as is.
     *
     * @param maxplayer     id of the maximizing player (forwarded to new children)
     * @param minplayer     id of the minimizing player (forwarded to new children)
     * @param C             weight applied to the exploration term (and, as in the original
     *                      code, to the accumulated exploitation score)
     * @param max_depth     depth at which descent stops
     * @param a_creation_ID creation id given to a newly created child
     * @return the selected node
     * @throws Exception if issuing the action or constructing the child fails
     */
    public MLPSNode selectLeaf(int maxplayer, int minplayer, double C, int max_depth, int a_creation_ID) throws Exception {
        if (this.unitActionTable == null || this.depth >= max_depth) {
            return this;
        }
        if (DEBUG >= 1) {
            System.out.println("MLPSNode.selectLeaf...");
        }

        // Refresh the per-action score caches and collect the unit indexes to sample.
        List<Integer> allUnitIdxs = new ArrayList<Integer>(this.unitActionTable.size());
        for (int ate_idx = 0; ate_idx < this.unitActionTable.size(); ++ate_idx) {
            Sampling.UnitActionTableEntry ate = this.unitActionTable.get(ate_idx);
            double[] scoresExploitation = this.UCBExploitationScores.get(ate_idx);
            double[] scoresExploration = this.UCBExplorationScores.get(ate_idx);
            for (int i = 0; i < ate.nactions; ++i) {
                scoresExploitation[i] = this.actionExploitationValue(ate, i);
                scoresExploration[i] = this.explorationValue(this.max_nactions, this.visit_count, ate.visit_count[i]);
            }
            if (DEBUG >= 3) {
                System.out.print("[ ");
                for (int i = 0; i < ate.nactions; ++i) {
                    System.out.print("(" + ate.visit_count[i] + "," + scoresExploitation[i] + "," + scoresExploration[i] + ")");
                }
                System.out.println("]");
            }
            allUnitIdxs.add(ate_idx);
        }

        // Resources already committed by actions that are currently executing.
        ResourceUsage base_ru = new ResourceUsage();
        for (Unit u : this.gs.getUnits()) {
            UnitAction ua = this.gs.getUnitAction(u);
            if (ua == null) {
                continue;
            }
            base_ru.merge(ua.resourceUsage(u, this.gs.getPhysicalGameState()));
        }

        PlayerAction best_pa = null;
        long best_actionCode = -1L;
        double best_accumUCBScore = 0.0;
        for (int repeat = 0; repeat < SAMPLING_REPEATS; ++repeat) {
            PlayerAction pa2 = new PlayerAction();
            long actionCode = 0L;
            double accumUCBScore = 0.0;
            double maxExplorationScore = 0.0;
            pa2.setResourceUsage(base_ru.clone());
            List<Integer> pending = new ArrayList<Integer>(allUnitIdxs);
            while (!pending.isEmpty()) {
                if (DEBUG >= 2) {
                    System.out.println("notSampledYet: " + pending);
                }
                // remove(int) : removes the element at a random position and returns it.
                int i = pending.remove(r.nextInt(pending.size()));
                try {
                    Sampling.UnitActionTableEntry ate = this.unitActionTable.get(i);
                    double[] scoresExploitation = this.UCBExploitationScores.get(i);
                    double[] scoresExploration = this.UCBExplorationScores.get(i);

                    List<Integer> candidates = new ArrayList<Integer>(ate.nactions);
                    for (int j = 0; j < ate.nactions; ++j) {
                        candidates.add(j);
                    }

                    int code = bestScoringAction(candidates, scoresExploitation, scoresExploration, C, maxExplorationScore);
                    UnitAction ua = null;
                    ResourceUsage r2 = null;
                    if (code != -1) {
                        ua = ate.actions.get(code);
                        r2 = ua.resourceUsage(ate.u, this.gs.getPhysicalGameState());
                        if (!pa2.getResourceUsage().consistentWith(r2, this.gs)) {
                            // Preferred action clashes with what is already chosen:
                            // fall back through the remaining actions, best first.
                            candidates.remove(Integer.valueOf(code));
                            if (DEBUG >= 4) {
                                System.out.println("    unit " + i + ": trying " + code);
                            }
                            do {
                                code = bestScoringAction(candidates, scoresExploitation, scoresExploration, C, maxExplorationScore);
                                if (DEBUG >= 4) {
                                    System.out.println("    unit " + i + ": trying " + code);
                                }
                                if (code == -1) {
                                    break;
                                }
                                candidates.remove(Integer.valueOf(code));
                                ua = ate.actions.get(code);
                                r2 = ua.resourceUsage(ate.u, this.gs.getPhysicalGameState());
                            } while (!pa2.getResourceUsage().consistentWith(r2, this.gs));
                        }
                    }
                    if (code == -1) {
                        // Nothing fits: the unit is left out of this sample (previously this
                        // path was reached through a caught IndexOutOfBoundsException).
                        System.err.println("MLPSNode.selectLeaf: no resource-consistent action for unit " + i + ", skipping it");
                        continue;
                    }
                    if (DEBUG >= 3) {
                        System.out.println("  unit " + i + ": " + code);
                    }
                    accumUCBScore += C * scoresExploitation[code];
                    maxExplorationScore = Math.max(scoresExploration[code], maxExplorationScore);
                    pa2.getResourceUsage().merge(r2);
                    pa2.addUnitAction(ate.u, ua);
                    actionCode += (long) code * this.multipliers[i];
                }
                catch (Exception e) {
                    // Best effort, as before: a failure for one unit must not abort the sample.
                    e.printStackTrace();
                }
            }
            accumUCBScore += maxExplorationScore;
            if (DEBUG >= 1) {
                System.out.println("  accumUCBScore: " + accumUCBScore);
            }
            if (best_pa == null || accumUCBScore > best_accumUCBScore) {
                best_pa = pa2;
                best_accumUCBScore = accumUCBScore;
                best_actionCode = actionCode;
            }
        }

        MLPSNode existing = this.childrenMap.get(best_actionCode);
        if (existing != null) {
            return existing.selectLeaf(maxplayer, minplayer, C, max_depth, a_creation_ID);
        }
        this.actions.add(best_pa);
        GameState gs2 = this.gs.cloneIssue(best_pa);
        MLPSNode node = new MLPSNode(maxplayer, minplayer, gs2.clone(), this, this.evaluation_bound, a_creation_ID);
        this.childrenMap.put(best_actionCode, node);
        this.children.add(node);
        return node;
    }

    /**
     * Returns the candidate with the highest value of
     * {@code exploitation[j] + C * max(exploration[j], explorationFloor)}; the earliest candidate
     * wins ties. Returns {@code -1} when {@code candidates} is empty.
     */
    private static int bestScoringAction(List<Integer> candidates, double[] exploitation, double[] exploration, double C, double explorationFloor) {
        int best = -1;
        double bestScore = 0.0;
        for (int j : candidates) {
            double score = exploitation[j] + C * Math.max(exploration[j], explorationFloor);
            if (best == -1 || score > bestScore) {
                best = j;
                bestScore = score;
            }
        }
        return best;
    }

    /**
     * Finds the action-table entry whose unit is the very same object as {@code u}
     * (reference comparison).
     *
     * @return the matching entry, or {@code null} if there is none
     */
    public Sampling.UnitActionTableEntry getActionTableEntry(Unit u) {
        for (Sampling.UnitActionTableEntry e : this.unitActionTable) {
            if (e.u == u) {
                return e;
            }
        }
        return null;
    }

    /**
     * Records {@code evaluation} in this node and, when {@code child} is given, in the statistics
     * of every unit action of the joint action that leads to {@code child}; then repeats the
     * update on the parent with this node as the child.
     *
     * <p>A unit action that cannot be located in the table is reported on {@code System.out} and
     * skipped instead of indexing the statistics arrays with {@code -1}.
     *
     * @param evaluation value to accumulate
     * @param child      child through which the evaluation was obtained, or {@code null}
     */
    public void propagateEvaluation(float evaluation, MLPSNode child) {
        this.accum_evaluation += (double) evaluation;
        ++this.visit_count;
        if (child != null) {
            int childIdx = this.children.indexOf(child);
            PlayerAction pa = (PlayerAction) this.actions.get(childIdx);
            for (Pair<Unit, UnitAction> ua : pa.getActions()) {
                Sampling.UnitActionTableEntry actionTable = this.getActionTableEntry(ua.m_a);
                int actionIdx = actionTable == null ? -1 : actionTable.actions.indexOf(ua.m_b);
                if (actionIdx == -1) {
                    System.out.println("Looking for action: " + ua.m_b);
                    System.out.println("Available actions are: " + (actionTable == null ? null : actionTable.actions));
                    continue;
                }
                actionTable.accum_evaluation[actionIdx] += (double) evaluation;
                actionTable.visit_count[actionIdx] += 1;
            }
        }
        if (this.parent != null) {
            ((MLPSNode) this.parent).propagateEvaluation(evaluation, this);
        }
    }

    /**
     * Prints, for every unit in the action table, each action with its visit count and its
     * average evaluation (accumulated evaluation divided by visit count).
     */
    public void printUnitActionTable() {
        for (Sampling.UnitActionTableEntry uat : this.unitActionTable) {
            System.out.println("Actions for unit " + uat.u);
            for (int i = 0; i < uat.nactions; ++i) {
                System.out.println("   " + uat.actions.get(i) + " visited " + uat.visit_count[i] + " with average evaluation " + uat.accum_evaluation[i] / (double) uat.visit_count[i]);
            }
        }
    }
}

