/*
 * Decompiled with CFR 0.152.
 */
package tests.bayesianmodels;

import ai.machinelearning.bayes.ActionInterdependenceModel;
import ai.machinelearning.bayes.BayesianModel;
import ai.machinelearning.bayes.BayesianModelByUnitTypeWithDefaultModel;
import ai.machinelearning.bayes.TrainingInstance;
import ai.machinelearning.bayes.featuregeneration.FeatureGenerator;
import ai.machinelearning.bayes.featuregeneration.FeatureGeneratorSimple;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.jdom.input.SAXBuilder;
import rts.UnitAction;
import rts.units.Unit;
import rts.units.UnitType;
import rts.units.UnitTypeTable;
import tests.bayesianmodels.PretrainNaiveBayesModels;

/**
 * Evaluates a pretrained Bayesian unit-action model by k-fold cross-validation over training
 * instances extracted from recorded traces, printing per-unit-type accuracy and log-likelihood.
 */
public class TestPretrainedBayesianModel {
    // Action-type codes used by this file (the decompiler had inlined them as bare literals).
    // NOTE(review): these presumably mirror rts.UnitAction's TYPE_* constants; confirm and switch
    // to direct UnitAction references.
    private static final int TYPE_NONE = 0;
    private static final int TYPE_MOVE = 1;
    private static final int TYPE_HARVEST = 2;
    private static final int TYPE_RETURN = 3;
    private static final int TYPE_PRODUCE = 4;
    private static final int TYPE_ATTACK_LOCATION = 5;

    // Attack targets are enumerated as offsets (dx, dy) with 0 < dx^2 + dy^2 <= MAX_ATTACK_RANGE^2.
    private static final int MAX_ATTACK_RANGE = 3;

    private static final String PRETRAINED_MODEL_XML =
            "data/bayesianmodels/pretrained/ActionInterdependenceModel-WR.xml";
    private static final String TRACES_PATH = "data/bayesianmodels/trainingdata/learning-traces-500";
    private static final String AI_NAME = "AI0";
    private static final int DEFAULT_FOLDS = 10;

    /**
     * Loads the pretrained "AIM_WR" model from XML and cross-validates it on the bundled traces.
     *
     * @param args ignored
     * @throws Exception if the model XML or the traces cannot be loaded, or evaluation fails
     */
    public static void main(String[] args) throws Exception {
        UnitTypeTable unitTypeTable = new UnitTypeTable();
        FeatureGeneratorSimple featureGenerator = new FeatureGeneratorSimple();
        BayesianModelByUnitTypeWithDefaultModel model = new BayesianModelByUnitTypeWithDefaultModel(
                new SAXBuilder().build(PRETRAINED_MODEL_XML).getRootElement(),
                unitTypeTable,
                // NOTE(review): presumably the fallback model for unit types without a dedicated
                // entry in the XML; constructor arguments are kept exactly as in the original.
                new ActionInterdependenceModel(null, 0, 0, 0.0, unitTypeTable, featureGenerator, ""),
                "AIM_WR");
        test(model, TRACES_PATH, AI_NAME, unitTypeTable, featureGenerator);
    }

    /**
     * Builds training instances, encodes them as integer feature vectors and action-label indexes,
     * and runs a verbose, uncalibrated {@value #DEFAULT_FOLDS}-fold cross-validation of the model.
     *
     * @param bayesianModel model under evaluation
     * @param tracesPath first argument to {@code PretrainNaiveBayesModels.generateInstances}
     *     (presumably the traces folder — TODO confirm)
     * @param aiName second argument to {@code PretrainNaiveBayesModels.generateInstances}
     *     (presumably the AI whose actions are learned — TODO confirm)
     * @param unitTypeTable unit types used to enumerate the label space
     * @param featureGenerator converts each instance into a list of Integer features
     * @throws Exception if an instance's action is not in the enumerated label space, or if
     *     instance generation / evaluation fails
     */
    public static void test(BayesianModel bayesianModel, String tracesPath, String aiName,
                            UnitTypeTable unitTypeTable, FeatureGenerator featureGenerator) throws Exception {
        List<TrainingInstance> instances = PretrainNaiveBayesModels.generateInstances(tracesPath, aiName);
        System.out.println(instances.size() + " instances generated.");

        // Encode each instance as an int[] feature vector (features are expected to be Integers).
        List<int[]> featureVectors = new ArrayList<int[]>(instances.size());
        for (TrainingInstance instance : instances) {
            List<?> features = featureGenerator.generateFeatures(instance);
            int[] vector = new int[features.size()];
            for (int i = 0; i < vector.length; ++i) {
                vector[i] = (Integer) features.get(i);
            }
            featureVectors.add(vector);
        }

        List<UnitAction> allActions = generateAllPossibleUnitActions(unitTypeTable);
        System.out.println(allActions.size() + " labels: " + allActions);

        // The label of each instance is the index of its action in the enumerated label space.
        List<Integer> actionIndexes = new ArrayList<Integer>(instances.size());
        for (TrainingInstance instance : instances) {
            int index = allActions.indexOf(instance.ua);
            if (index < 0) {
                throw new Exception("Undefined action " + instance.ua);
            }
            actionIndexes.add(index);
        }

        System.out.println(" /------------ start testing " + tracesPath + " - " + aiName + " --------------\\ ");
        crossValidation(bayesianModel, featureVectors, actionIndexes, instances, allActions, DEFAULT_FOLDS, true, false);
        System.out.println(" \\------------ end testing " + tracesPath + " - " + aiName + " --------------/ ");
    }

    /**
     * Runs k-fold cross-validation of {@code bayesianModel}. Instances are assigned to folds
     * uniformly at random. For each fold the model is cleared, trained on the other folds
     * (optionally calibrated), and evaluated on the held-out fold.
     *
     * <p>Only test instances whose unit has more than one possible action are scored. The
     * prediction is the most probable of those possible actions, with ties broken at random.
     *
     * <p>NOTE(review): {@code clearTraining()} is called before every fold, so any pretrained
     * parameters are presumably discarded; confirm that is intended for a "pretrained" test.
     *
     * <p>Ratios for unit types with no scored instances come out as NaN. The process is terminated
     * with {@code System.exit(1)} if the model assigns probability 0 to an actual action.
     *
     * @param bayesianModel model to train and evaluate
     * @param featureVectors one feature vector per instance; each must be at least as long as the
     *     first one
     * @param labels index into {@code allActions} of the action taken in each instance
     * @param instances the instances, parallel to {@code featureVectors} and {@code labels}
     * @param allActions label space (as produced by {@link #generateAllPossibleUnitActions})
     * @param folds number of folds; must be at least 1
     * @param verbose whether to print per-fold and final statistics
     * @param calibrate whether to call {@code calibrateProbabilities} after training each fold
     * @return average log-likelihood of the actual action over all scored test instances
     * @throws IllegalArgumentException if there are no instances or {@code folds < 1}
     * @throws Exception if a possible action of a unit is not in {@code allActions}, or the model
     *     throws during training or prediction
     */
    public static double crossValidation(BayesianModel bayesianModel, List<int[]> featureVectors,
                                         List<Integer> labels, List<TrainingInstance> instances,
                                         List<UnitAction> allActions, int folds, boolean verbose,
                                         boolean calibrate) throws Exception {
        if (featureVectors.isEmpty()) {
            throw new IllegalArgumentException("crossValidation requires at least one instance");
        }
        if (folds < 1) {
            throw new IllegalArgumentException("folds must be >= 1, got " + folds);
        }
        Random random = new Random();
        int nFeatures = featureVectors.get(0).length;
        UnitTypeTable unitTypeTable = instances.get(0).gs.getUnitTypeTable();
        int nUnitTypes = unitTypeTable.getUnitTypes().size();

        // Randomly assign instances to folds and record the observed value range of every feature
        // and of the label (printed for information only).
        List<List<Integer>> foldMembers = new ArrayList<List<Integer>>(folds);
        for (int f = 0; f < folds; ++f) {
            foldMembers.add(new ArrayList<Integer>());
        }
        int[] featureSizes = new int[nFeatures];
        int labelSize = 0;
        for (int i = 0; i < featureVectors.size(); ++i) {
            foldMembers.get(random.nextInt(folds)).add(i);
            int[] vector = featureVectors.get(i);
            for (int j = 0; j < nFeatures; ++j) {
                featureSizes[j] = Math.max(featureSizes[j], vector[j] + 1);
            }
            labelSize = Math.max(labelSize, labels.get(i) + 1);
        }
        if (verbose) {
            System.out.println("Xsizes: " + Arrays.toString(featureSizes));
            System.out.println("Ysize: " + labelSize);
        }

        double[] totalCorrect = new double[nUnitTypes];
        double[] totalScored = new double[nUnitTypes];
        double[] totalLogLikelihood = new double[nUnitTypes];

        for (int fold = 0; fold < folds; ++fold) {
            if (verbose) {
                System.out.println("Evaluating fold " + (fold + 1) + "/" + folds + ":");
            }
            List<int[]> trainX = new ArrayList<int[]>();
            List<Integer> trainY = new ArrayList<Integer>();
            List<TrainingInstance> trainInstances = new ArrayList<TrainingInstance>();
            List<int[]> testX = new ArrayList<int[]>();
            List<Integer> testY = new ArrayList<Integer>();
            List<TrainingInstance> testInstances = new ArrayList<TrainingInstance>();
            for (int k = 0; k < folds; ++k) {
                boolean heldOut = k == fold;
                for (int idx : foldMembers.get(k)) {
                    (heldOut ? testX : trainX).add(featureVectors.get(idx));
                    (heldOut ? testY : trainY).add(labels.get(idx));
                    (heldOut ? testInstances : trainInstances).add(instances.get(idx));
                }
            }
            if (verbose) {
                System.out.println("  training/test split is " + trainX.size() + "/" + testX.size());
            }

            bayesianModel.clearTraining();
            bayesianModel.train(trainX, trainY, trainInstances);
            if (calibrate) {
                bayesianModel.calibrateProbabilities(trainX, trainY, trainInstances);
            }

            int[] correct = new int[nUnitTypes];
            int[] scored = new int[nUnitTypes];
            double[] logLikelihood = new double[nUnitTypes];
            double possibleActionsSum = 0.0;
            for (int k = 0; k < testX.size(); ++k) {
                TrainingInstance instance = testInstances.get(k);
                Unit unit = instance.u;
                List<UnitAction> possibleActions = unit.getUnitActions(instance.gs);
                List<Integer> candidates = possibleActionIndexes(unit, possibleActions, allActions);
                if (possibleActions.size() <= 1) {
                    continue; // nothing to choose between: not scored
                }
                possibleActionsSum += possibleActions.size();

                double[] distribution = bayesianModel.predictDistribution(testX.get(k), instance);
                distribution = bayesianModel.filterByPossibleActionIndexes(distribution, candidates);
                int actual = testY.get(k);
                if (!candidates.contains(actual)) {
                    System.out.println("Actual action in the dataset is not possible!");
                    continue;
                }

                // Shuffle so that ties in the distribution are broken at random instead of always
                // favouring whichever action the unit happened to list first.
                Collections.shuffle(candidates, random);
                int predicted = argmax(distribution, candidates);

                int typeId = unit.getType().ID;
                if (predicted == actual) {
                    correct[typeId]++;
                }
                scored[typeId]++;
                double logProbability = Math.log(distribution[actual]);
                if (Double.isInfinite(logProbability)) {
                    // A zero-probability actual action would make the log-likelihood -Infinity:
                    // dump diagnostics and abort the whole run.
                    System.out.println(Arrays.toString(distribution));
                    System.out.println(candidates);
                    System.out.println(actual + " : " + allActions.get(actual));
                    System.exit(1);
                }
                logLikelihood[typeId] += logProbability;
            }

            if (verbose) {
                // Divides by every test instance, including those skipped above.
                System.out.println("Average possible actions: " + possibleActionsSum / (double) testX.size());
            }
            for (int t = 0; t < nUnitTypes; ++t) {
                double accuracy = (double) correct[t] / (double) scored[t];
                if (verbose) {
                    System.out.println("Fold accuracy (" + unitTypeTable.getUnitTypes().get(t).name + "): "
                            + accuracy + "   (" + correct[t] + "/" + scored[t] + ")");
                }
                totalCorrect[t] += correct[t];
                totalScored[t] += scored[t];
            }
            for (int t = 0; t < nUnitTypes; ++t) {
                if (verbose) {
                    System.out.println("Fold loglikelihood (" + unitTypeTable.getUnitTypes().get(t).name + "): "
                            + logLikelihood[t] + " (average: " + logLikelihood[t] / (double) scored[t] + ")");
                }
                totalLogLikelihood[t] += logLikelihood[t];
            }
        }

        if (verbose) {
            System.out.println(" ---------- ");
        }
        double correctSum = 0.0;
        double scoredSum = 0.0;
        double logLikelihoodSum = 0.0;
        for (int t = 0; t < nUnitTypes; ++t) {
            double accuracy = totalCorrect[t] / totalScored[t];
            if (verbose) {
                System.out.println("Final accuracy (" + unitTypeTable.getUnitTypes().get(t).name + "): "
                        + accuracy + "   (" + totalCorrect[t] + "/" + totalScored[t] + ")");
            }
            correctSum += totalCorrect[t];
            scoredSum += totalScored[t];
        }
        for (int t = 0; t < nUnitTypes; ++t) {
            if (verbose) {
                System.out.println("Final loglikelihood (" + unitTypeTable.getUnitTypes().get(t).name + "): "
                        + totalLogLikelihood[t] + " (average: " + totalLogLikelihood[t] / totalScored[t] + ")");
            }
            logLikelihoodSum += totalLogLikelihood[t];
        }
        if (verbose) {
            System.out.println("Final accuracy: " + correctSum / scoredSum);
            System.out.println("Final loglikelihood: " + logLikelihoodSum + " (average " + logLikelihoodSum / scoredSum + ")");
        }
        return logLikelihoodSum / scoredSum;
    }

    /**
     * Maps each action a unit can perform to its index in {@code allActions}. Attack-location
     * actions are first rewritten with their target expressed relative to the unit's position,
     * which is how they appear in the label space.
     *
     * @throws Exception if some possible action has no counterpart in {@code allActions}
     */
    private static List<Integer> possibleActionIndexes(Unit unit, List<UnitAction> possibleActions,
                                                       List<UnitAction> allActions) throws Exception {
        List<Integer> indexes = new ArrayList<Integer>(possibleActions.size());
        for (UnitAction action : possibleActions) {
            UnitAction label = action;
            if (action.getType() == TYPE_ATTACK_LOCATION) {
                label = new UnitAction(TYPE_ATTACK_LOCATION,
                        action.getLocationX() - unit.getX(), action.getLocationY() - unit.getY());
            }
            int index = allActions.indexOf(label);
            if (index < 0) {
                throw new Exception("Unknown action: " + label);
            }
            indexes.add(index);
        }
        return indexes;
    }

    /**
     * Returns the candidate with the highest value in {@code distribution}; on ties the earliest
     * candidate wins. Returns -1 if {@code candidates} is empty.
     */
    private static int argmax(double[] distribution, List<Integer> candidates) {
        int best = -1;
        for (int candidate : candidates) {
            if (best == -1 || distribution[candidate] > distribution[best]) {
                best = candidate;
            }
        }
        return best;
    }

    /**
     * Enumerates the fixed label space of unit actions, in this order:
     * <ul>
     *   <li>a single no-op action;</li>
     *   <li>move, harvest and return in each of the four directions 0..3;</li>
     *   <li>produce in each direction for every unit type;</li>
     *   <li>attack-location for every nonzero offset (dx, dy) with dx^2 + dy^2 <= 9.</li>
     * </ul>
     *
     * @param unitTypeTable unit types that can be produced
     * @return a new mutable list of all label actions
     */
    public static List<UnitAction> generateAllPossibleUnitActions(UnitTypeTable unitTypeTable) {
        List<UnitAction> actions = new ArrayList<UnitAction>();
        int[] directions = {0, 1, 2, 3};
        // NOTE(review): the second argument (10) is presumably the no-op's duration — confirm.
        actions.add(new UnitAction(TYPE_NONE, 10));
        for (int direction : directions) {
            actions.add(new UnitAction(TYPE_MOVE, direction));
        }
        for (int direction : directions) {
            actions.add(new UnitAction(TYPE_HARVEST, direction));
        }
        for (int direction : directions) {
            actions.add(new UnitAction(TYPE_RETURN, direction));
        }
        for (int direction : directions) {
            for (UnitType unitType : unitTypeTable.getUnitTypes()) {
                actions.add(new UnitAction(TYPE_PRODUCE, direction, unitType));
            }
        }
        int maxDistanceSquared = MAX_ATTACK_RANGE * MAX_ATTACK_RANGE;
        for (int dx = -MAX_ATTACK_RANGE; dx <= MAX_ATTACK_RANGE; ++dx) {
            for (int dy = -MAX_ATTACK_RANGE; dy <= MAX_ATTACK_RANGE; ++dy) {
                int distanceSquared = dx * dx + dy * dy;
                if (distanceSquared > 0 && distanceSquared <= maxDistanceSquared) {
                    actions.add(new UnitAction(TYPE_ATTACK_LOCATION, dx, dy));
                }
            }
        }
        return actions;
    }
}

