/*
 * Decompiled with CFR 0.152.
 */
package ai.montecarlo.lsi;

import ai.core.AI;
import ai.evaluation.EvaluationFunction;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import rts.GameState;
import rts.PhysicalGameState;
import rts.PlayerAction;
import rts.ResourceUsage;
import rts.UnitAction;
import rts.UnitActionAssignment;
import rts.units.Unit;
import util.CartesianProduct;
import util.Pair;
import util.Sampler;

/**
 * Sampling and evaluation helpers for joint player actions built from per-agent action
 * distributions.
 *
 * <p>Instances keep an unsynchronized playout counter, so they should not be shared across
 * threads without external synchronization.
 */
public class Sampling {
    /** How agents are ordered in {@link #generatePlayerActionGivenDist} when no order is forced. */
    private final AgentOrderingType agentOrderingType;
    /** How far (in game-time units) each evaluation playout is simulated past its start. */
    private final int lookAhead;
    /** Scores the state reached at the end of a playout. */
    private final EvaluationFunction evalFunction;
    /** Policy that picks the actions of both players (0 and 1) during playouts. */
    private final AI simulationAi;
    /** Number of playouts started since the last {@link #resetSimulationCount()}. */
    private int simulationCount = 0;

    public Sampling(AgentOrderingType agentOrderingType, int lookAhead, AI simulationAi, EvaluationFunction evalFunction) {
        this.agentOrderingType = agentOrderingType;
        this.lookAhead = lookAhead;
        this.evalFunction = evalFunction;
        this.simulationAi = simulationAi;
    }

    /**
     * Estimates the value of {@code playerAction} for {@code player} by averaging
     * {@code numEval} playouts.
     *
     * <p>Each playout issues the action on a copy of {@code gs}, simulates up to
     * {@code lookAhead} time units with the simulation AI, evaluates the result for
     * {@code player} against {@code 1 - player}, and multiplies it by
     * {@code 0.99^(elapsed / 10)} so that later outcomes count slightly less.
     *
     * @return the mean discounted evaluation, or {@code 0.0} when {@code numEval <= 0}
     */
    public double evaluatePlayerAction(int player, GameState gs, PlayerAction playerAction, int numEval) throws Exception {
        double evalMean = 0.0;
        for (int step = 0; step < numEval; ++step) {
            // gs is reused on every step; cloneIssue yields the per-playout copy that is simulated.
            GameState playout = gs.cloneIssue(playerAction);
            int startTime = playout.getTime();
            this.simulate(playout, startTime + this.lookAhead);
            int time = playout.getTime() - startTime;
            double eval = this.evalFunction.evaluate(player, 1 - player, playout) * Math.pow(0.99, (double) time / 10.0);
            // Incremental mean, so individual samples need not be stored.
            evalMean = ((double) step * evalMean + eval) / (double) (step + 1);
        }
        return evalMean;
    }

    /**
     * Advances {@code gs} in place until the game ends or its time reaches {@code lookaheadTime}.
     * The body runs at least once. Whenever not every unit has an action assigned
     * ({@code isComplete()} is false), both players are asked for actions; otherwise the
     * state is cycled.
     */
    private void simulate(GameState gs, int lookaheadTime) throws Exception {
        ++this.simulationCount;
        boolean gameover = false;
        do {
            if (gs.isComplete()) {
                gameover = gs.cycle();
            } else {
                gs.issue(this.simulationAi.getAction(0, gs));
                gs.issue(this.simulationAi.getAction(1, gs));
            }
        } while (!gameover && gs.getTime() < lookaheadTime);
    }

    /**
     * Sums the resource usage of every action already in progress in {@code gameState}, so new
     * actions are checked against what is already reserved.
     */
    private static ResourceUsage baseResourceUsage(GameState gameState, PhysicalGameState pgs) {
        ResourceUsage base = new ResourceUsage();
        for (Unit u : pgs.getUnits()) {
            UnitActionAssignment uaa = gameState.getUnitActions().get(u);
            if (uaa != null) {
                base.merge(uaa.action.resourceUsage(u, pgs));
            }
        }
        return base;
    }

    /**
     * Samples a joint action by visiting agents one at a time. Each agent's action is drawn
     * from that agent's own distribution, and the draw is repeated (without the rejected
     * action) whenever it conflicts with resources already claimed.
     *
     * @param unitActionTable  one entry per agent; entry {@code i} holds the actions weighted by
     *                         {@code distributions.get(i)}
     * @param player           unused here; kept for interface compatibility
     * @param gameState        current state; actions already in progress are reserved first
     * @param distributions    per-agent weights over that agent's actions
     * @param forcedAgentOrder agent indices in visiting order, or {@code null} to use the
     *                         configured {@link AgentOrderingType}
     * @return the sampled action, with unit actions listed in {@code unitActionTable} order and
     *         carrying the merged resource usage; an agent for which no action is consistent is
     *         left without an action
     * @throws RuntimeException if no order is forced and the ordering type is unknown
     */
    public PlayerAction generatePlayerActionGivenDist(List<UnitActionTableEntry> unitActionTable, int player, GameState gameState, List<double[]> distributions, List<Integer> forcedAgentOrder) throws Exception {
        PhysicalGameState pgs = gameState.getPhysicalGameState();
        PlayerAction pa = new PlayerAction();
        pa.setResourceUsage(baseResourceUsage(gameState, pgs));

        for (int agent : this.agentOrder(distributions, forcedAgentOrder)) {
            double[] distribution = distributions.get(agent);
            UnitActionTableEntry ate = unitActionTable.get(agent);
            int code = Sampler.weighted(distribution);
            UnitAction ua = ate.actions.get(code);
            ResourceUsage r2 = ua.resourceUsage(ate.u, pgs);
            if (!pa.getResourceUsage().consistentWith(r2, gameState)) {
                // Redraw among the actions not yet rejected until one fits or none remain.
                List<Double> remainingWeights = new ArrayList<>(distribution.length);
                List<Integer> remainingCodes = new ArrayList<>(distribution.length);
                for (int j = 0; j < distribution.length; ++j) {
                    remainingWeights.add(distribution[j]);
                    remainingCodes.add(j);
                }
                do {
                    int idx = remainingCodes.indexOf(code);
                    remainingWeights.remove(idx); // index-based removal (idx is an int)
                    remainingCodes.remove(idx);
                    if (remainingCodes.isEmpty()) {
                        ua = null;
                        break;
                    }
                    code = (Integer) Sampler.weighted(remainingWeights, remainingCodes);
                    ua = ate.actions.get(code);
                    r2 = ua.resourceUsage(ate.u, pgs);
                } while (!pa.getResourceUsage().consistentWith(r2, gameState));
                if (ua == null) {
                    // Every action conflicts: leave this agent unassigned instead of failing.
                    continue;
                }
            }
            pa.getResourceUsage().merge(r2);
            pa.addUnitAction(ate.u, ua);
        }

        // difference() looks actions up by agent index, so list them in unitActionTable order.
        PlayerAction orderedPA = new PlayerAction();
        orderedPA.setResourceUsage(pa.getResourceUsage());
        for (UnitActionTableEntry entry : unitActionTable) {
            for (Pair<Unit, UnitAction> assignment : pa.getActions()) {
                if (assignment.m_a.equals(entry.u)) {
                    orderedPA.addUnitAction(assignment.m_a, assignment.m_b);
                }
            }
        }
        return orderedPA;
    }

    /**
     * Returns the agent visiting order for {@link #generatePlayerActionGivenDist}: the forced
     * order when given, otherwise a random permutation (RANDOM) or the agents sorted by
     * increasing entropy of their distributions (ENTROPY, most peaked first; ties keep
     * index order).
     */
    private List<Integer> agentOrder(List<double[]> distributions, List<Integer> forcedAgentOrder) {
        if (forcedAgentOrder != null) {
            return forcedAgentOrder;
        }
        List<Integer> order = new ArrayList<>(distributions.size());
        switch (this.agentOrderingType) {
            case RANDOM: {
                for (int j = 0; j < distributions.size(); ++j) {
                    order.add(j);
                }
                Collections.shuffle(order);
                break;
            }
            case ENTROPY: {
                List<Pair<Integer, Double>> byEntropy = new ArrayList<>(distributions.size());
                for (int j = 0; j < distributions.size(); ++j) {
                    byEntropy.add(new Pair<>(j, this.entropy(distributions.get(j))));
                }
                byEntropy.sort(Comparator.comparingDouble(p -> p.m_b));
                for (Pair<Integer, Double> p : byEntropy) {
                    order.add(p.m_a);
                }
                break;
            }
            default: {
                throw new RuntimeException("Unknown AgentOrderingType");
            }
        }
        return order;
    }

    /**
     * Samples a joint action from one combined distribution over every (agent, action) pair.
     *
     * <p>Each round draws a pair with probability proportional to its weight among all agents
     * still unassigned. If the drawn action is consistent with the resources claimed so far, it
     * is assigned and the agent leaves the pool. Otherwise only that action is discarded and
     * the draw is repeated. An agent whose actions are all discarded, or that has no actions,
     * is left without an action.
     *
     * @param player unused here; kept for interface compatibility
     * @return the sampled action, with unit actions in the order they were assigned
     */
    public PlayerAction generatePlayerActionOneDist(List<UnitActionTableEntry> unitActionTable, int player, GameState gameState, List<double[]> distributions) throws Exception {
        PhysicalGameState pgs = gameState.getPhysicalGameState();
        PlayerAction pa = new PlayerAction();
        pa.setResourceUsage(baseResourceUsage(gameState, pgs));

        List<AgentRow> rows = new ArrayList<>(distributions.size());
        for (int agent = 0; agent < distributions.size(); ++agent) {
            double[] distribution = distributions.get(agent);
            if (distribution.length > 0) {
                rows.add(new AgentRow(agent, distribution));
            }
        }

        Random gen = new Random();
        while (!rows.isEmpty()) {
            // Recomputed every round so that removals never leave a stale total behind.
            double density = 0.0;
            for (AgentRow r : rows) {
                density += r.sum;
            }
            double d = gen.nextDouble() * density;

            int x = 0;
            while (x < rows.size() && d > rows.get(x).sum) {
                d -= rows.get(x).sum;
                ++x;
            }
            if (x == rows.size()) {
                continue; // rounding pushed d past the last agent: draw again
            }
            AgentRow row = rows.get(x);
            int y = 0;
            while (y < row.weights.size() && d > row.weights.get(y)) {
                d -= row.weights.get(y);
                ++y;
            }
            if (y == row.weights.size()) {
                continue; // rounding pushed d past the agent's last action: draw again
            }

            UnitActionTableEntry ate = unitActionTable.get(row.agent);
            UnitAction ua = ate.actions.get(row.actionIndices.get(y));
            ResourceUsage r2 = ua.resourceUsage(ate.u, pgs);
            if (pa.getResourceUsage().consistentWith(r2, gameState)) {
                pa.getResourceUsage().merge(r2);
                pa.addUnitAction(ate.u, ua);
                rows.remove(x);
            } else {
                row.removeAction(y);
                if (row.weights.isEmpty()) {
                    rows.remove(x);
                }
            }
        }
        return pa;
    }

    /** Remaining sampling mass of one agent in {@link #generatePlayerActionOneDist}. */
    private static final class AgentRow {
        /** Index of the agent in the unit-action table. */
        final int agent;
        /** Weights of the actions that have not been rejected yet. */
        final List<Double> weights = new ArrayList<>();
        /** Index of each remaining weight's action in the agent's action list. */
        final List<Integer> actionIndices = new ArrayList<>();
        /** Sum of {@link #weights}, recomputed whenever an action is removed. */
        double sum;

        AgentRow(int agent, double[] distribution) {
            this.agent = agent;
            for (int a = 0; a < distribution.length; ++a) {
                this.weights.add(distribution[a]);
                this.actionIndices.add(a);
            }
            this.sum = total(this.weights);
        }

        /** Drops the action at position {@code position} (index-based, not by value). */
        void removeAction(int position) {
            this.weights.remove(position);
            this.actionIndices.remove(position);
            this.sum = total(this.weights);
        }
    }

    /** Sums {@code values} left to right. */
    private static double total(List<Double> values) {
        double s = 0.0;
        for (double v : values) {
            s += v;
        }
        return s;
    }

    /**
     * Enumerates every joint action in the Cartesian product of the agents' action indices
     * whose combined resource usage is consistent. Each accepted unit action's usage is merged
     * before the next unit is checked, so units cannot claim conflicting resources.
     *
     * @param player       unused here; kept for interface compatibility
     * @param includeNoops when false, actions whose {@code getType()} is 0 are excluded from
     *                     each agent's domain
     * @return the consistent joint actions; a single empty {@link PlayerAction} if there are none
     */
    public Set<PlayerAction> generatePlayerActionAll(List<UnitActionTableEntry> unitActionTable, int player, GameState gameState, boolean includeNoops) throws Exception {
        PhysicalGameState pgs = gameState.getPhysicalGameState();
        ResourceUsage baseRu = baseResourceUsage(gameState, pgs);

        List<Set<Integer>> domains = new ArrayList<>(unitActionTable.size());
        for (UnitActionTableEntry entry : unitActionTable) {
            Set<Integer> domain = new HashSet<>();
            for (int i = 0; i < entry.nactions; ++i) {
                // Type 0 is presumably UnitAction.TYPE_NONE (no-op) -- confirm.
                if (!includeNoops && entry.actions.get(i).getType() == 0) {
                    continue;
                }
                domain.add(i);
            }
            domains.add(domain);
        }

        Set<PlayerAction> actionSet = new HashSet<>();
        CartesianProduct product = new CartesianProduct(domains);
        int size = product.size();
        for (int elementIndex = 0; elementIndex < size; ++elementIndex) {
            List<Integer> element = product.element(elementIndex);
            PlayerAction pa = new PlayerAction();
            pa.setResourceUsage(baseRu.clone());
            boolean isValid = true;
            for (int i = 0; i < element.size(); ++i) {
                UnitActionTableEntry entry = unitActionTable.get(i);
                UnitAction unitAction = entry.actions.get(element.get(i));
                ResourceUsage r2 = unitAction.resourceUsage(entry.u, pgs);
                if (!pa.getResourceUsage().consistentWith(r2, gameState)) {
                    isValid = false;
                    break;
                }
                // Reserve this unit's resources so that later units are checked against them.
                pa.getResourceUsage().merge(r2);
                pa.addUnitAction(entry.u, unitAction);
            }
            if (isValid) {
                actionSet.add(pa);
            }
        }
        if (actionSet.isEmpty()) {
            actionSet.add(new PlayerAction());
        }
        return actionSet;
    }

    /**
     * Runs {@code num} more playouts for every candidate and keeps the better half (plus one).
     *
     * <p>Each candidate's {@code m_b} holds (accumulated evaluation, accumulated playout count).
     * The new mean from {@link #evaluatePlayerAction} is added to the first field, and
     * {@code num} is added to the second. Candidates are sorted in place by their ratio,
     * descending.
     * NOTE(review): the mean (not {@code mean * num}) is accumulated, so rounds of different
     * sizes get equal weight -- confirm this is intended.
     *
     * @return a view of the first {@code min(size, size / 2 + 1)} entries of {@code actionList}
     */
    public List<Pair<PlayerAction, Pair<Double, Integer>>> halvedSampling(List<Pair<PlayerAction, Pair<Double, Integer>>> actionList, GameState gameState, int player, int num) throws Exception {
        for (Pair<PlayerAction, Pair<Double, Integer>> pair : actionList) {
            double eval = this.evaluatePlayerAction(player, gameState, pair.m_a, num);
            Pair<Double, Integer> stats = pair.m_b;
            stats.m_a = stats.m_a + eval;
            stats.m_b = stats.m_b + num;
        }
        actionList.sort((p1, p2) -> Double.compare(
                p2.m_b.m_a / (double) p2.m_b.m_b,
                p1.m_b.m_a / (double) p1.m_b.m_b));
        // Guard so that an empty list yields an empty view instead of throwing.
        return actionList.subList(0, Math.min(actionList.size(), actionList.size() / 2 + 1));
    }

    /**
     * Updates each candidate's running mean with {@code numEval} new playouts and keeps the
     * better half (plus one), sorted by mean, descending.
     *
     * @param numEvalPrevious number of playouts already folded into each stored mean
     * @return a view of the first {@code min(size, size / 2 + 1)} entries of {@code actionList}
     */
    public List<Pair<PlayerAction, Double>> halvedOriginalSampling(List<Pair<PlayerAction, Double>> actionList, GameState gameState, int player, int numEval, int numEvalPrevious) throws Exception {
        this.updateMeansAndSort(actionList, gameState, player, numEval, numEvalPrevious);
        return actionList.subList(0, Math.min(actionList.size(), actionList.size() / 2 + 1));
    }

    /**
     * Same as {@link #halvedOriginalSampling} but keeps exactly {@code size / 2} candidates
     * (so a single candidate is dropped).
     */
    public List<Pair<PlayerAction, Double>> halvedOriginalSamplingFill(List<Pair<PlayerAction, Double>> actionList, GameState gameState, int player, int numEval, int numEvalPrevious) throws Exception {
        this.updateMeansAndSort(actionList, gameState, player, numEval, numEvalPrevious);
        return actionList.subList(0, actionList.size() / 2);
    }

    /**
     * Folds {@code numEval} fresh playouts into each candidate's stored mean, weighting the old
     * mean by {@code numEvalPrevious}, then sorts {@code actionList} in place by mean, descending.
     */
    private void updateMeansAndSort(List<Pair<PlayerAction, Double>> actionList, GameState gameState, int player, int numEval, int numEvalPrevious) throws Exception {
        for (Pair<PlayerAction, Double> pair : actionList) {
            double eval = this.evaluatePlayerAction(player, gameState, pair.m_a, numEval);
            pair.m_b = (pair.m_b * (double) numEvalPrevious + eval * (double) numEval) / (double) (numEvalPrevious + numEval);
        }
        actionList.sort((p1, p2) -> p2.m_b.compareTo(p1.m_b));
    }

    /**
     * Shannon entropy, in bits, of {@code distribution} after normalizing it by its sum.
     * Zero entries contribute nothing.
     */
    public double entropy(double[] distribution) {
        double sum = 0.0;
        for (double prob : distribution) {
            sum += prob;
        }
        double ent = 0.0;
        for (double prob : distribution) {
            if (prob == 0.0) {
                continue;
            }
            ent += -1.0 * (prob / sum) * Sampling.log(prob / sum, 2.0);
        }
        return ent;
    }

    /**
     * Returns the weight of the action that {@code playerAction} assigns to agent
     * {@code agentIndex}, minus the weight of that agent's last action.
     * Assumes {@code playerAction}'s actions are listed in agent order, as
     * {@link #generatePlayerActionGivenDist} produces them.
     *
     * @throws IllegalArgumentException if the assigned action is not among the agent's actions
     */
    public double difference(List<UnitActionTableEntry> unitActionTable, List<double[]> distributions, PlayerAction playerAction, int agentIndex) {
        UnitAction chosen = playerAction.getActions().get(agentIndex).m_b;
        int j = unitActionTable.get(agentIndex).actions.indexOf(chosen);
        if (j < 0) {
            throw new IllegalArgumentException("action of agent " + agentIndex + " is not in its action table");
        }
        double[] distribution = distributions.get(agentIndex);
        return distribution[j] - distribution[distribution.length - 1];
    }

    public void resetSimulationCount() {
        this.simulationCount = 0;
    }

    public int getSimulationCount() {
        return this.simulationCount;
    }

    /** Logarithm of {@code x} in the given {@code base}. */
    public static double log(double x, double base) {
        return Math.log(x) / Math.log(base);
    }

    /** Adds {@code d} to the playout counter, truncating the sum toward zero. */
    public void increaseSimulationCount(double d) {
        this.simulationCount = (int) ((double) this.simulationCount + d);
    }

    /** Agent visiting order used by {@link #generatePlayerActionGivenDist}. */
    public static enum AgentOrderingType {
        /** Uniformly random permutation of the agents. */
        RANDOM,
        /** Increasing entropy of each agent's action distribution. */
        ENTROPY;

    }

    /** Per-agent row of the unit-action table: a unit and its candidate actions. */
    public static class UnitActionTableEntry {
        /** Not read in this class; presumably the entry's position in its table -- confirm. */
        public int idx;
        /** The unit these actions belong to. */
        public Unit u;
        /** Number of entries of {@link #actions} considered by generatePlayerActionAll. */
        public int nactions = 0;
        /** Candidate actions; distribution index {@code i} weights {@code actions.get(i)}. */
        public List<UnitAction> actions;
        /** Not read in this class. */
        public double[] accum_evaluation;
        /** Not read in this class. */
        public int[] visit_count;
    }
}

