/*
 * Decompiled with CFR 0.152.
 */
package tests.bayesianmodels;

import ai.machinelearning.bayes.ActionInterdependenceModel;
import ai.machinelearning.bayes.BayesianModel;
import ai.machinelearning.bayes.BayesianModelByUnitTypeWithDefaultModel;
import ai.machinelearning.bayes.TrainingInstance;
import ai.machinelearning.bayes.featuregeneration.FeatureGenerator;
import ai.machinelearning.bayes.featuregeneration.FeatureGeneratorSimple;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.jdom.input.SAXBuilder;
import rts.UnitAction;
import rts.units.Unit;
import rts.units.UnitType;
import rts.units.UnitTypeTable;
import tests.bayesianmodels.PretrainNaiveBayesModels;

/**
 * Command-line test that evaluates a pretrained {@link BayesianModel} with k-fold cross-validation
 * on training instances generated from recorded game traces, reporting accuracy and
 * log-likelihood per unit type.
 */
public class TestPretrainedBayesianModel {

    // Numeric UnitAction type codes used to build the label set and to normalise attack actions.
    // NOTE(review): these presumably mirror rts.UnitAction's TYPE_NONE / TYPE_MOVE / TYPE_HARVEST /
    // TYPE_RETURN / TYPE_PRODUCE / TYPE_ATTACK_LOCATION constants — confirm against that class.
    private static final int TYPE_NONE = 0;
    private static final int TYPE_MOVE = 1;
    private static final int TYPE_HARVEST = 2;
    private static final int TYPE_RETURN = 3;
    private static final int TYPE_PRODUCE = 4;
    private static final int TYPE_ATTACK_LOCATION = 5;

    /** Direction codes enumerated for the move, harvest, return and produce labels. */
    private static final int[] DIRECTIONS = {0, 1, 2, 3};

    /** Largest |dx| / |dy| of the relative attack offsets included in the label set. */
    private static final int MAX_ATTACK_OFFSET = 3;

    /**
     * Loads the pretrained model from
     * {@code data/bayesianmodels/pretrained/ActionInterdependenceModel-WR.xml} and evaluates it on
     * the traces in {@code data/bayesianmodels/trainingdata/learning-traces-500} for AI "AI0".
     */
    public static void main(String[] args) throws Exception {
        UnitTypeTable utt = new UnitTypeTable();
        FeatureGeneratorSimple fg = new FeatureGeneratorSimple();
        BayesianModel model = new BayesianModelByUnitTypeWithDefaultModel(
                new SAXBuilder().build("data/bayesianmodels/pretrained/ActionInterdependenceModel-WR.xml").getRootElement(),
                utt,
                // Passed as the "default model"; presumably the fallback for unit types without a
                // dedicated model in the XML file — confirm in BayesianModelByUnitTypeWithDefaultModel.
                new ActionInterdependenceModel(null, 0, 0, 0.0, utt, fg, ""),
                "AIM_WR");
        test(model, "data/bayesianmodels/trainingdata/learning-traces-500", "AI0", utt, fg);
    }

    /**
     * Generates training instances, encodes their features and labels, and runs 10-fold
     * cross-validation of {@code model} with debug output and without calibration.
     *
     * @param model        model to evaluate; it is cleared and retrained on each fold
     * @param tracesFolder folder passed to {@code PretrainNaiveBayesModels.generateInstances}
     * @param AIname       AI name passed to {@code PretrainNaiveBayesModels.generateInstances}
     * @param utt          unit type table used to enumerate the label set
     * @param fg           feature generator; every generated feature must be an {@link Integer}
     * @throws Exception if no instances are generated, or an instance's action is not in the
     *                   label set produced by {@link #generateAllPossibleUnitActions}
     */
    public static void test(BayesianModel model, String tracesFolder, String AIname, UnitTypeTable utt, FeatureGenerator fg) throws Exception {
        List<TrainingInstance> instances = PretrainNaiveBayesModels.generateInstances(tracesFolder, AIname);
        System.out.println(instances.size() + " instances generated.");
        if (instances.isEmpty()) {
            throw new Exception("No training instances generated from " + tracesFolder + " for " + AIname);
        }

        // Encode each instance's feature list as an int vector.
        ArrayList<int[]> X_l = new ArrayList<int[]>(instances.size());
        for (TrainingInstance ti : instances) {
            List<Object> featureList = fg.generateFeatures(ti);
            int[] x = new int[featureList.size()];
            for (int i = 0; i < x.length; i++) {
                x[i] = (Integer) featureList.get(i);
            }
            X_l.add(x);
        }

        // Encode each instance's action as its index in the label set.
        List<UnitAction> allPossibleActions = generateAllPossibleUnitActions(utt);
        System.out.println(allPossibleActions.size() + " labels: " + allPossibleActions);
        ArrayList<Integer> Y_l = new ArrayList<Integer>(instances.size());
        for (TrainingInstance ti : instances) {
            int idx = allPossibleActions.indexOf(ti.ua);
            if (idx < 0) {
                throw new Exception("Undefined action " + ti.ua);
            }
            Y_l.add(idx);
        }

        System.out.println(" /------------ start testing " + tracesFolder + " - " + AIname + " --------------\\ ");
        crossValidation(model, X_l, Y_l, instances, allPossibleActions, 10, true, false);
        System.out.println(" \\------------ end testing " + tracesFolder + " - " + AIname + " --------------/ ");
    }

    /**
     * Runs k-fold cross-validation of {@code model}. Instances are assigned to folds uniformly at
     * random with an unseeded {@link Random}, so results vary between runs.
     *
     * <p>For each fold the model is cleared, trained on the other folds, optionally calibrated on
     * the same training data, and evaluated on the held-out fold. Test instances are skipped when
     * the unit has at most one possible action, or when the recorded action is not among the
     * unit's possible actions. The prediction is the most probable possible action; ties are
     * broken at random.
     *
     * <p>When {@code DEBUG} is set, per-fold and final accuracy and log-likelihood are printed per
     * unit type. Unit types with no evaluated instances print {@code NaN}.
     *
     * @param model              model under evaluation
     * @param X_l                feature vectors, parallel to {@code Y_l} and {@code instances}
     * @param Y_l                label indexes into {@code allPossibleActions}
     * @param instances          the training instances the vectors/labels were built from
     * @param allPossibleActions label set; attack actions use unit-relative offsets
     * @param nfolds             number of folds (must be at least 1)
     * @param DEBUG              whether to print progress and statistics
     * @param calibrate          whether to call {@code calibrateProbabilities} after training
     * @return total log-likelihood divided by the number of evaluated instances
     * @throws IllegalArgumentException if {@code nfolds < 1}
     * @throws Exception if a unit's possible action is not in the label set, or the model assigns
     *                   zero probability to an observed action (infinite log-likelihood)
     */
    public static double crossValidation(BayesianModel model, List<int[]> X_l, List<Integer> Y_l, List<TrainingInstance> instances, List<UnitAction> allPossibleActions, int nfolds, boolean DEBUG, boolean calibrate) throws Exception {
        if (nfolds < 1) {
            throw new IllegalArgumentException("nfolds must be at least 1, got " + nfolds);
        }
        Random r = new Random();
        int nfeatures = X_l.get(0).length;
        int[] Xsizes = new int[nfeatures];
        int Ysize = 0;
        UnitTypeTable utt = instances.get(0).gs.getUnitTypeTable();
        List<UnitType> unitTypes = utt.getUnitTypes();
        int nUnitTypes = unitTypes.size();

        // Randomly assign every instance to a fold; track value ranges for the debug output.
        List<List<Integer>> folds = new ArrayList<List<Integer>>(nfolds);
        for (int f = 0; f < nfolds; f++) {
            folds.add(new ArrayList<Integer>());
        }
        for (int i = 0; i < X_l.size(); i++) {
            folds.get(r.nextInt(nfolds)).add(i);
            int[] x = X_l.get(i);
            for (int j = 0; j < nfeatures; j++) {
                Xsizes[j] = Math.max(Xsizes[j], x[j] + 1);
            }
            Ysize = Math.max(Ysize, Y_l.get(i) + 1);
        }
        if (DEBUG) {
            System.out.println("Xsizes: " + Arrays.toString(Xsizes));
            System.out.println("Ysize: " + Ysize);
        }

        double[] correct_per_unit = new double[nUnitTypes];
        double[] total_per_unit = new double[nUnitTypes];
        double[] loglikelihood_per_unit = new double[nUnitTypes];
        for (int fold = 0; fold < nfolds; fold++) {
            if (DEBUG) {
                System.out.println("Evaluating fold " + (fold + 1) + "/" + nfolds + ":");
            }
            // Split: the current fold is the test set, all others form the training set.
            ArrayList<int[]> X_training = new ArrayList<int[]>();
            ArrayList<Integer> Y_training = new ArrayList<Integer>();
            ArrayList<TrainingInstance> i_training = new ArrayList<TrainingInstance>();
            ArrayList<int[]> X_test = new ArrayList<int[]>();
            ArrayList<Integer> Y_test = new ArrayList<Integer>();
            ArrayList<TrainingInstance> i_test = new ArrayList<TrainingInstance>();
            for (int f = 0; f < nfolds; f++) {
                boolean isTestFold = f == fold;
                for (int idx : folds.get(f)) {
                    if (isTestFold) {
                        X_test.add(X_l.get(idx));
                        Y_test.add(Y_l.get(idx));
                        i_test.add(instances.get(idx));
                    } else {
                        X_training.add(X_l.get(idx));
                        Y_training.add(Y_l.get(idx));
                        i_training.add(instances.get(idx));
                    }
                }
            }
            if (DEBUG) {
                System.out.println("  training/test split is " + X_training.size() + "/" + X_test.size());
            }

            model.clearTraining();
            model.train(X_training, Y_training, i_training);
            if (calibrate) {
                model.calibrateProbabilities(X_training, Y_training, i_training);
            }

            int[] fold_correct_per_unit = new int[nUnitTypes];
            int[] fold_total_per_unit = new int[nUnitTypes];
            double[] fold_loglikelihood_per_unit = new double[nUnitTypes];
            double numPossibleActionsAccum = 0.0;
            int numDecisionPoints = 0;
            for (int i = 0; i < X_test.size(); i++) {
                TrainingInstance ti = i_test.get(i);
                Unit u = ti.u;
                List<UnitAction> possibleUnitActions = u.getUnitActions(ti.gs);
                ArrayList<Integer> possibleUnitActionIndexes = possibleActionIndexes(u, possibleUnitActions, allPossibleActions);
                if (possibleUnitActions.size() <= 1) {
                    continue; // no real decision to predict
                }
                numPossibleActionsAccum += possibleUnitActions.size();
                numDecisionPoints++;

                double[] predicted_distribution = model.predictDistribution(X_test.get(i), ti);
                predicted_distribution = model.filterByPossibleActionIndexes(predicted_distribution, possibleUnitActionIndexes);
                int actual_y = Y_test.get(i);
                if (!possibleUnitActionIndexes.contains(actual_y)) {
                    System.out.println("Actual action in the dataset is not possible!");
                    continue;
                }

                // Shuffle the candidates so that ties in predicted probability are broken at random
                // instead of always favouring the earliest label.
                Collections.shuffle(possibleUnitActionIndexes, r);
                int predicted_y = argmax(predicted_distribution, possibleUnitActionIndexes);

                int type = u.getType().ID;
                if (predicted_y == actual_y) {
                    fold_correct_per_unit[type]++;
                }
                fold_total_per_unit[type]++;
                double loglikelihood = Math.log(predicted_distribution[actual_y]);
                if (Double.isInfinite(loglikelihood)) {
                    throw new Exception("Infinite loglikelihood for action " + actual_y + " : " + allPossibleActions.get(actual_y)
                            + "; distribution " + Arrays.toString(predicted_distribution)
                            + "; possible action indexes " + possibleUnitActionIndexes);
                }
                fold_loglikelihood_per_unit[type] += loglikelihood;
            }

            if (DEBUG) {
                // Averaged only over the test instances that actually contributed to the accumulator.
                System.out.println("Average possible actions: " + numPossibleActionsAccum / (double) numDecisionPoints);
            }
            for (int t = 0; t < nUnitTypes; t++) {
                if (DEBUG) {
                    double fold_accuracy = (double) fold_correct_per_unit[t] / (double) fold_total_per_unit[t];
                    System.out.println("Fold accuracy (" + unitTypes.get(t).name + "): " + fold_accuracy + "   (" + fold_correct_per_unit[t] + "/" + fold_total_per_unit[t] + ")");
                }
                correct_per_unit[t] += fold_correct_per_unit[t];
                total_per_unit[t] += fold_total_per_unit[t];
            }
            for (int t = 0; t < nUnitTypes; t++) {
                if (DEBUG) {
                    System.out.println("Fold loglikelihood (" + unitTypes.get(t).name + "): " + fold_loglikelihood_per_unit[t] + " (average: " + fold_loglikelihood_per_unit[t] / (double) fold_total_per_unit[t] + ")");
                }
                loglikelihood_per_unit[t] += fold_loglikelihood_per_unit[t];
            }
        }

        if (DEBUG) {
            System.out.println(" ---------- ");
        }
        double correct = 0.0;
        double total = 0.0;
        double loglikelihood = 0.0;
        for (int t = 0; t < nUnitTypes; t++) {
            if (DEBUG) {
                double accuracy_per_unit = correct_per_unit[t] / total_per_unit[t];
                System.out.println("Final accuracy (" + unitTypes.get(t).name + "): " + accuracy_per_unit + "   (" + correct_per_unit[t] + "/" + total_per_unit[t] + ")");
            }
            correct += correct_per_unit[t];
            total += total_per_unit[t];
        }
        for (int t = 0; t < nUnitTypes; t++) {
            if (DEBUG) {
                System.out.println("Final loglikelihood (" + unitTypes.get(t).name + "): " + loglikelihood_per_unit[t] + " (average: " + loglikelihood_per_unit[t] / total_per_unit[t] + ")");
            }
            loglikelihood += loglikelihood_per_unit[t];
        }
        double accuracy = correct / total;
        if (DEBUG) {
            System.out.println("Final accuracy: " + accuracy);
            System.out.println("Final loglikelihood: " + loglikelihood + " (average " + loglikelihood / total + ")");
        }
        return loglikelihood / total;
    }

    /**
     * Maps each of a unit's currently possible actions to its index in the label set. Attack
     * actions are first converted from absolute target coordinates to offsets relative to the
     * unit, which is how the label set stores them.
     *
     * @throws Exception if an action has no matching label
     */
    private static ArrayList<Integer> possibleActionIndexes(Unit u, List<UnitAction> unitActions, List<UnitAction> allPossibleActions) throws Exception {
        ArrayList<Integer> indexes = new ArrayList<Integer>(unitActions.size());
        for (UnitAction ua : unitActions) {
            if (ua.getType() == TYPE_ATTACK_LOCATION) {
                ua = new UnitAction(TYPE_ATTACK_LOCATION, ua.getLocationX() - u.getX(), ua.getLocationY() - u.getY());
            }
            int idx = allPossibleActions.indexOf(ua);
            if (idx < 0) {
                throw new Exception("Unknown action: " + ua);
            }
            indexes.add(idx);
        }
        return indexes;
    }

    /**
     * Returns the candidate index with the highest probability in {@code distribution}. Ties go to
     * the earliest candidate in iteration order. Returns -1 if there are no candidates.
     */
    private static int argmax(double[] distribution, List<Integer> candidates) {
        int best = -1;
        for (int idx : candidates) {
            if (best == -1 || distribution[idx] > distribution[best]) {
                best = idx;
            }
        }
        return best;
    }

    /**
     * Enumerates the label set, in this order:
     * <ol>
     *   <li>a single "none" action (parameter 10);</li>
     *   <li>move, harvest and return in each of the four directions;</li>
     *   <li>produce for every direction and unit type;</li>
     *   <li>attack at every relative offset (dx, dy) with 0 &lt; dx&sup2; + dy&sup2; &le; 9.</li>
     * </ol>
     * The order defines the label indexes used by {@link #test} and {@link #crossValidation}.
     */
    public static List<UnitAction> generateAllPossibleUnitActions(UnitTypeTable utt) {
        List<UnitAction> labels = new ArrayList<UnitAction>();
        labels.add(new UnitAction(TYPE_NONE, 10));
        for (int type : new int[]{TYPE_MOVE, TYPE_HARVEST, TYPE_RETURN}) {
            for (int d : DIRECTIONS) {
                labels.add(new UnitAction(type, d));
            }
        }
        for (int d : DIRECTIONS) {
            for (UnitType ut : utt.getUnitTypes()) {
                labels.add(new UnitAction(TYPE_PRODUCE, d, ut));
            }
        }
        int maxSquaredDistance = MAX_ATTACK_OFFSET * MAX_ATTACK_OFFSET;
        for (int ox = -MAX_ATTACK_OFFSET; ox <= MAX_ATTACK_OFFSET; ox++) {
            for (int oy = -MAX_ATTACK_OFFSET; oy <= MAX_ATTACK_OFFSET; oy++) {
                int squaredDistance = ox * ox + oy * oy;
                if (squaredDistance > 0 && squaredDistance <= maxSquaredDistance) {
                    labels.add(new UnitAction(TYPE_ATTACK_LOCATION, ox, oy));
                }
            }
        }
        return labels;
    }
}

