/*
 * Decompiled with CFR 0.152.
 */
package tests.bayesianmodels;

import ai.machinelearning.bayes.ActionInterdependenceModel;
import ai.machinelearning.bayes.BayesianModel;
import ai.machinelearning.bayes.BayesianModelByUnitTypeWithDefaultModel;
import ai.machinelearning.bayes.CalibratedNaiveBayes;
import ai.machinelearning.bayes.TrainingInstance;
import ai.machinelearning.bayes.featuregeneration.FeatureGenerator;
import ai.machinelearning.bayes.featuregeneration.FeatureGeneratorSimple;
import java.io.File;
import java.io.FileWriter;
import java.util.ArrayList;
import java.util.List;
import java.util.StringTokenizer;
import org.jdom.input.SAXBuilder;
import rts.GameState;
import rts.Trace;
import rts.TraceEntry;
import rts.UnitAction;
import rts.units.Unit;
import rts.units.UnitTypeTable;
import util.Pair;
import util.XMLWriter;

/**
 * Offline trainer for the Bayesian unit-action models.
 * <p>
 * Reads recorded game traces ({@code .xml} files) from a directory and extracts a
 * {@link TrainingInstance} for every legal, non-trivial unit action taken by a given AI. It
 * encodes the instances with a {@link FeatureGenerator}, trains the selected model type and
 * writes the trained model to an XML file.
 */
public class PretrainNaiveBayesModels {
    /** {@link #pretrain} model type: a single {@link CalibratedNaiveBayes} named "CNB". */
    public static final int CALIBRATED_NAIVE_BAYES = 0;
    /** {@link #pretrain} model type: a single {@link ActionInterdependenceModel} named "AIM". */
    public static final int ACTION_INTERDEPENDENCE_MODEL = 1;
    /**
     * {@link #pretrain} model type: a {@link BayesianModelByUnitTypeWithDefaultModel} whose
     * default model is a {@link CalibratedNaiveBayes}.
     */
    public static final int CALIBRATED_NAIVE_BAYES_BY_UNIT_TYPE = 2;
    /**
     * {@link #pretrain} model type: a {@link BayesianModelByUnitTypeWithDefaultModel} whose
     * default model is an {@link ActionInterdependenceModel}.
     */
    public static final int ACTION_INTERDEPENDENCE_MODEL_BY_UNIT_TYPE = 3;

    private static final String TRACES_500 = "data/bayesianmodels/trainingdata/learning-traces-500";
    private static final String TRACES_10000 = "data/bayesianmodels/trainingdata/learning-traces-10000";
    private static final String OUTPUT_DIR = "data/bayesianmodels/pretrained/";

    /**
     * Trains and saves an {@link #ACTION_INTERDEPENDENCE_MODEL_BY_UNIT_TYPE} model for each
     * (trace directory, AI name, output file) job listed below.
     */
    public static void main(String[] args) throws Exception {
        // {trace directory, AI name as it appears in trace file names, output file name}
        String[][] jobs = {
            {TRACES_500, "AI0", "ActionInterdependenceModel-WR.xml"},
            {TRACES_500, "AI1", "ActionInterdependenceModel-LR.xml"},
            {TRACES_500, "AI2", "ActionInterdependenceModel-HR.xml"},
            {TRACES_500, "AI3", "ActionInterdependenceModel-RR.xml"},
            {TRACES_500, "AI4", "ActionInterdependenceModel-LSI500.xml"},
            {TRACES_500, "AI5", "ActionInterdependenceModel-NaiveMCTS500.xml"},
            {TRACES_10000, "AI4", "ActionInterdependenceModel-LSI10000.xml"},
            {TRACES_10000, "AI5", "ActionInterdependenceModel-NaiveMCTS10000.xml"},
        };
        for (String[] job : jobs) {
            pretrain(job[0], job[1], OUTPUT_DIR + job[2],
                    ACTION_INTERDEPENDENCE_MODEL_BY_UNIT_TYPE, new FeatureGeneratorSimple());
        }
    }

    /**
     * Builds a training set from the traces in {@code tracesDir}, trains a model of the given
     * type on it and saves the model to {@code outputFile} as XML.
     *
     * @param tracesDir        directory containing {@code .xml} game traces
     * @param aiName           AI whose actions are used as training data; matched against the
     *                         player names embedded in the trace file names
     * @param outputFile       path of the XML file the trained model is written to
     * @param modelType        one of the model-type constants of this class
     * @param featureGenerator encodes each training instance as a list of {@code Integer} features
     * @throws IllegalArgumentException if no training instances are found, or if
     *                                  {@code modelType} is not one of the known constants
     * @throws Exception if an instance's action is not among the generated labels, or if
     *                   trace parsing, training or writing fails
     */
    public static void pretrain(String tracesDir, String aiName, String outputFile, int modelType,
                                FeatureGenerator featureGenerator) throws Exception {
        UnitTypeTable utt = new UnitTypeTable();
        List<TrainingInstance> instances = generateInstances(tracesDir, aiName);
        System.out.println(instances.size() + " instances generated.");
        if (instances.isEmpty()) {
            // Without at least one instance the feature-vector length is unknown.
            throw new IllegalArgumentException(
                    "No training instances for " + aiName + " found in " + tracesDir);
        }

        // Encode every instance as an int[] feature vector. For each feature position, track its
        // cardinality as (max observed value + 1). The first vector fixes the feature count.
        List<int[]> featureVectors = new ArrayList<int[]>(instances.size());
        int[] featureSizes = null;
        for (TrainingInstance instance : instances) {
            List<Object> features = featureGenerator.generateFeatures(instance);
            if (featureSizes == null) {
                featureSizes = new int[features.size()];
            }
            int[] vector = new int[features.size()];
            for (int i = 0; i < vector.length; i++) {
                vector[i] = (Integer) features.get(i);
                if (vector[i] >= featureSizes[i]) {
                    featureSizes[i] = vector[i] + 1;
                }
            }
            featureVectors.add(vector);
        }

        // The label of an instance is the index of its action in the list of all possible actions.
        List<UnitAction> allActions = BayesianModel.generateAllPossibleUnitActions(utt);
        System.out.println(allActions.size() + " labels: " + allActions);
        List<Integer> labels = new ArrayList<Integer>(instances.size());
        for (TrainingInstance instance : instances) {
            int label = allActions.indexOf(instance.ua);
            if (label < 0) {
                throw new Exception("Undefined action " + instance.ua);
            }
            labels.add(label);
        }
        System.out.println("Dataset generated, ready to learn");

        int nLabels = allActions.size();
        BayesianModel model;
        switch (modelType) {
            case CALIBRATED_NAIVE_BAYES:
                model = new CalibratedNaiveBayes(featureSizes, nLabels, 2, 0.0, utt, featureGenerator, "CNB");
                break;
            case ACTION_INTERDEPENDENCE_MODEL:
                model = new ActionInterdependenceModel(featureSizes, nLabels, 2, 0.0, utt, featureGenerator, "AIM");
                break;
            case CALIBRATED_NAIVE_BAYES_BY_UNIT_TYPE:
                model = new BayesianModelByUnitTypeWithDefaultModel(utt,
                        new CalibratedNaiveBayes(featureSizes, nLabels, 2, 0.0, utt, featureGenerator, "CNB"), "CNB");
                break;
            case ACTION_INTERDEPENDENCE_MODEL_BY_UNIT_TYPE:
                model = new BayesianModelByUnitTypeWithDefaultModel(utt,
                        new ActionInterdependenceModel(featureSizes, nLabels, 2, 0.0, utt, featureGenerator, "AIM"), "AIM");
                break;
            default:
                throw new IllegalArgumentException("Unknown model type: " + modelType);
        }

        model.featureSelectionByCrossValidation(featureVectors, labels, instances);
        model.train(featureVectors, labels, instances);
        model.calibrateProbabilities(featureVectors, labels, instances);

        // Close the underlying file even if saving fails. On the normal path xmlWriter.close()
        // runs first; closing an already-closed FileWriter again has no effect.
        FileWriter fileWriter = new FileWriter(outputFile);
        try {
            XMLWriter xmlWriter = new XMLWriter(fileWriter);
            model.save(xmlWriter);
            xmlWriter.close();
        } finally {
            fileWriter.close();
        }
    }

    /**
     * Extracts training instances for {@code aiName} from every {@code .xml} trace in a directory.
     * <p>
     * A trace is used only if its file name names {@code aiName} as one of the two players (see
     * {@link #playerIndexFromFileName}). An instance is produced for each action of one of that
     * player's units, provided the unit had more than one available action and the recorded
     * action is among them. Actions not among the available ones are reported on stdout and
     * skipped.
     *
     * @param tracesDir directory containing the trace files
     * @param aiName    AI name to look for in the trace file names
     * @return the extracted instances, in directory-listing order; never {@code null}
     * @throws IllegalArgumentException if {@code tracesDir} is not a readable directory
     * @throws Exception if a trace file cannot be parsed
     */
    public static List<TrainingInstance> generateInstances(String tracesDir, String aiName) throws Exception {
        File[] files = new File(tracesDir).listFiles();
        if (files == null) {
            // listFiles() returns null when the path is not a directory or cannot be read.
            throw new IllegalArgumentException("Not a readable directory: " + tracesDir);
        }
        List<TrainingInstance> instances = new ArrayList<TrainingInstance>();
        for (File file : files) {
            String path = file.getAbsolutePath();
            if (!path.endsWith(".xml")) {
                continue;
            }
            int player = playerIndexFromFileName(file.getName(), aiName);
            if (player < 0) {
                continue;
            }
            Trace trace = new Trace(new SAXBuilder().build(path).getRootElement());
            for (TraceEntry entry : trace.getEntries()) {
                GameState gameState = trace.getGameStateAtCycle(entry.getTime());
                for (Pair<Unit, UnitAction> pair : entry.getActions()) {
                    Unit unit = pair.m_a;
                    UnitAction action = pair.m_b;
                    if (unit.getPlayer() != player) {
                        continue;
                    }
                    List<UnitAction> available = unit.getUnitActions(gameState);
                    // Units with at most one option carry no decision to learn from.
                    if (available.size() <= 1) {
                        continue;
                    }
                    if (!available.contains(action)) {
                        System.out.println("invalid instance...: " + action);
                        continue;
                    }
                    instances.add(new TrainingInstance(gameState, unit.getID(), action));
                }
            }
        }
        return instances;
    }

    /**
     * Determines which player slot {@code aiName} occupies in a trace file name.
     * <p>
     * The name is split on {@code '-'}. The first token is ignored. The map token is the second
     * token if it starts with {@code "map"}, otherwise the third. The two tokens after the map
     * token are the AI names of player 0 and player 1.
     *
     * @return 0 or 1 for the matching player, or -1 if {@code aiName} does not appear or the
     *         name has too few tokens (the latter is reported on stdout)
     */
    private static int playerIndexFromFileName(String fileName, String aiName) {
        StringTokenizer tokenizer = new StringTokenizer(fileName, "-");
        List<String> tokens = new ArrayList<String>();
        while (tokenizer.hasMoreTokens()) {
            tokens.add(tokenizer.nextToken());
        }
        int mapIndex = (tokens.size() > 1 && tokens.get(1).startsWith("map")) ? 1 : 2;
        if (tokens.size() < mapIndex + 3) {
            System.out.println("skipping trace with unexpected file name: " + fileName);
            return -1;
        }
        int player = -1;
        if (tokens.get(mapIndex + 1).equals(aiName)) {
            player = 0;
        }
        // Checked second on purpose: if both players are aiName (self-play), player 1 is used,
        // as in the original implementation.
        if (tokens.get(mapIndex + 2).equals(aiName)) {
            player = 1;
        }
        return player;
    }
}

