/*
 * Decompiled with CFR 0.152.
 */
package ai.machinelearning.bayes;

import ai.machinelearning.bayes.BayesianModel;
import ai.machinelearning.bayes.TrainingInstance;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import org.jdom.Element;
import rts.units.UnitType;
import rts.units.UnitTypeTable;
import util.XMLWriter;

/**
 * A {@link BayesianModel} that keeps one sub-model per {@link UnitType} plus a default model.
 *
 * <p>Every sub-model is a clone of the template model given to the constructor. During training,
 * calibration and cross-validated feature selection, each per-unit-type model receives only the
 * instances whose unit has that type, while the default model receives all instances.
 * {@link #predictDistribution} uses the model for the instance's unit type when one exists and
 * falls back to the default model otherwise.
 */
public class BayesianModelByUnitTypeWithDefaultModel extends BayesianModel {
    /** Prototype cloned to create every per-unit-type model and the default model. */
    BayesianModel templateModel = null;
    /** Per-unit-type models, created on demand the first time a unit type appears in data or is loaded. */
    HashMap<UnitType, BayesianModel> unitModels = new HashMap<UnitType, BayesianModel>();
    /** Model fed with all instances; null until the first training/calibration/feature-selection call or load. */
    BayesianModel defaultModel = null;

    /**
     * Creates an empty model whose sub-models will be cloned from {@code template} on demand.
     *
     * @param unitTypeTable unit type table handed to the superclass (also used to resolve names in {@link #load})
     * @param template template model; its feature generator is reused by this model
     * @param modelName name handed to the superclass
     */
    public BayesianModelByUnitTypeWithDefaultModel(UnitTypeTable unitTypeTable, BayesianModel template, String modelName) {
        super(unitTypeTable, template.featureGenerator, modelName);
        this.templateModel = template;
    }

    /**
     * Creates a model and immediately populates it from a saved XML element.
     *
     * @param element root element previously written by {@link #save}
     * @param unitTypeTable unit type table used to resolve saved unit type names
     * @param template template model cloned for every loaded sub-model
     * @param modelName name handed to the superclass
     * @throws Exception if the element is not a valid serialization of this class
     */
    public BayesianModelByUnitTypeWithDefaultModel(Element element, UnitTypeTable unitTypeTable, BayesianModel template, String modelName) throws Exception {
        super(unitTypeTable, template.featureGenerator, modelName);
        this.templateModel = template;
        this.load(element);
    }

    /**
     * Returns a new, untrained instance built from the same unit type table, template model and
     * name. Trained sub-models are not copied.
     */
    @Override
    public Object clone() {
        return new BayesianModelByUnitTypeWithDefaultModel(this.utt, this.templateModel, this.name);
    }

    /** Clears the training of every per-unit-type model and of the default model, if it exists. */
    @Override
    public void clearTraining() {
        for (BayesianModel model : this.unitModels.values()) {
            model.clearTraining();
        }
        // The default model only exists after a training/calibration/feature-selection call or load().
        if (this.defaultModel != null) {
            this.defaultModel.clearTraining();
        }
    }

    /**
     * Trains each per-unit-type model on the instances of its unit type, then trains the default
     * model on all instances. Models for newly seen unit types are cloned from the template.
     *
     * @param features feature vectors, parallel to {@code labels} and {@code instances}
     * @param labels labels, parallel to {@code features}
     * @param instances training instances; their unit's type selects the sub-model
     * @throws Exception propagated from the sub-models' training
     */
    @Override
    public void train(List<int[]> features, List<Integer> labels, List<TrainingInstance> instances) throws Exception {
        HashMap<UnitType, UnitTypeSplit> splits = partitionByUnitType(features, labels, instances);
        for (UnitType unitType : splits.keySet()) {
            UnitTypeSplit split = splits.get(unitType);
            unitModelFor(unitType).train(split.features, split.labels, split.instances);
        }
        ensureDefaultModel().train(features, labels, instances);
    }

    /**
     * Calibrates each per-unit-type model on the instances of its unit type, then calibrates the
     * default model on all instances. Models for newly seen unit types are cloned from the template.
     *
     * @param features feature vectors, parallel to {@code labels} and {@code instances}
     * @param labels labels, parallel to {@code features}
     * @param instances instances; their unit's type selects the sub-model
     * @throws Exception propagated from the sub-models' calibration
     */
    @Override
    public void calibrateProbabilities(List<int[]> features, List<Integer> labels, List<TrainingInstance> instances) throws Exception {
        HashMap<UnitType, UnitTypeSplit> splits = partitionByUnitType(features, labels, instances);
        for (UnitType unitType : splits.keySet()) {
            UnitTypeSplit split = splits.get(unitType);
            unitModelFor(unitType).calibrateProbabilities(split.features, split.labels, split.instances);
        }
        ensureDefaultModel().calibrateProbabilities(features, labels, instances);
    }

    /**
     * Runs gain-ratio feature selection on every existing per-unit-type model and on the default
     * model (creating the default model if needed).
     *
     * <p>NOTE(review): this signature carries no training instances, so the data cannot be split by
     * unit type; every per-unit-type model receives the full data set. Models for unit types not
     * seen yet are not created by this call.
     *
     * @param features feature vectors, parallel to {@code labels}
     * @param labels labels, parallel to {@code features}
     * @param d value forwarded unchanged to each sub-model's gain-ratio feature selection
     */
    @Override
    public void featureSelectionByGainRatio(List<int[]> features, List<Integer> labels, double d) {
        for (BayesianModel model : this.unitModels.values()) {
            model.featureSelectionByGainRatio(features, labels, d);
        }
        ensureDefaultModel().featureSelectionByGainRatio(features, labels, d);
    }

    /**
     * Runs cross-validated feature selection on each per-unit-type model using the instances of its
     * unit type, then on the default model using all instances. Models for newly seen unit types
     * are cloned from the template.
     *
     * @param features feature vectors, parallel to {@code labels} and {@code instances}
     * @param labels labels, parallel to {@code features}
     * @param instances instances; their unit's type selects the sub-model
     * @throws Exception propagated from the sub-models' feature selection
     */
    @Override
    public void featureSelectionByCrossValidation(List<int[]> features, List<Integer> labels, List<TrainingInstance> instances) throws Exception {
        HashMap<UnitType, UnitTypeSplit> splits = partitionByUnitType(features, labels, instances);
        for (UnitType unitType : splits.keySet()) {
            UnitTypeSplit split = splits.get(unitType);
            unitModelFor(unitType).featureSelectionByCrossValidation(split.features, split.labels, split.instances);
        }
        ensureDefaultModel().featureSelectionByCrossValidation(features, labels, instances);
    }

    /**
     * Predicts with the model for the instance's unit type, falling back to the default model.
     *
     * @param nArray feature vector of the instance
     * @param trainingInstance instance whose unit's type selects the sub-model
     * @return the distribution returned by the selected sub-model
     * @throws IllegalStateException if there is neither a model for the unit type nor a default model
     */
    @Override
    public double[] predictDistribution(int[] nArray, TrainingInstance trainingInstance) {
        UnitType unitType = trainingInstance.u.getType();
        BayesianModel model = this.unitModels.get(unitType);
        if (model == null) {
            model = this.defaultModel;
        }
        if (model == null) {
            throw new IllegalStateException("No model available for unit type '" + unitType.name
                    + "': " + this.getClass().getSimpleName() + " has not been trained or loaded");
        }
        return model.predictDistribution(nArray, trainingInstance);
    }

    /**
     * Writes this model as XML: one {@code UnitType} tag (with name and ID attributes) per
     * per-unit-type model, followed by a {@code defaultModel} tag, all wrapped in a tag named after
     * this class. The writer is flushed at the end.
     *
     * @throws IllegalStateException if the default model does not exist yet (checked before anything is written)
     * @throws Exception propagated from the writer or the sub-models
     */
    @Override
    public void save(XMLWriter xMLWriter) throws Exception {
        String rootTag = this.getClass().getSimpleName();
        // Fail before emitting any tags so a failed save never leaves half-written XML behind.
        if (this.defaultModel == null) {
            throw new IllegalStateException("Cannot save " + rootTag + ": default model has not been trained or loaded");
        }
        xMLWriter.tag(rootTag);
        for (UnitType unitType : this.unitModels.keySet()) {
            xMLWriter.tagWithAttributes("UnitType", "name=\"" + unitType.name + "\" ID=\"" + unitType.ID + "\"");
            this.unitModels.get(unitType).save(xMLWriter);
            xMLWriter.tag("/UnitType");
        }
        xMLWriter.tag("defaultModel");
        this.defaultModel.save(xMLWriter);
        xMLWriter.tag("/defaultModel");
        xMLWriter.tag("/" + rootTag);
        xMLWriter.flush();
    }

    /**
     * Loads sub-models from an element written by {@link #save}. Each {@code UnitType} child is
     * resolved by its {@code name} attribute through the unit type table; its first child element is
     * loaded into a fresh clone of the template. The {@code defaultModel} child is loaded the same way.
     *
     * @throws Exception if the root tag is wrong, a unit type name is unknown, the
     *     {@code defaultModel} tag is missing, a tag has no nested model, or a sub-model fails to load
     */
    @Override
    public void load(Element element) throws Exception {
        String rootTag = this.getClass().getSimpleName();
        if (!element.getName().equals(rootTag)) {
            throw new Exception("Head tag is not '" + rootTag + "'!");
        }
        for (Object child : element.getChildren("UnitType")) {
            Element unitTypeElement = (Element) child;
            String unitTypeName = unitTypeElement.getAttributeValue("name");
            UnitType unitType = this.utt.getUnitType(unitTypeName);
            if (unitType == null) {
                // A null key would never match at prediction time and would break save().
                throw new Exception("Unknown unit type '" + unitTypeName + "' in saved " + rootTag + "!");
            }
            BayesianModel model = (BayesianModel) this.templateModel.clone();
            model.load(firstChildElement(unitTypeElement));
            this.unitModels.put(unitType, model);
        }
        Element defaultElement = element.getChild("defaultModel");
        if (defaultElement == null) {
            throw new Exception("Missing 'defaultModel' tag in '" + rootTag + "'!");
        }
        this.defaultModel = (BayesianModel) this.templateModel.clone();
        this.defaultModel.load(firstChildElement(defaultElement));
    }

    /** Returns the model for {@code unitType}, cloning and registering one from the template if absent. */
    private BayesianModel unitModelFor(UnitType unitType) {
        BayesianModel model = this.unitModels.get(unitType);
        if (model == null) {
            model = (BayesianModel) this.templateModel.clone();
            this.unitModels.put(unitType, model);
        }
        return model;
    }

    /** Returns the default model, cloning it from the template first if it does not exist yet. */
    private BayesianModel ensureDefaultModel() {
        if (this.defaultModel == null) {
            this.defaultModel = (BayesianModel) this.templateModel.clone();
        }
        return this.defaultModel;
    }

    /** Returns the first child element of {@code parent}, failing with a descriptive error if it has none. */
    private static Element firstChildElement(Element parent) throws Exception {
        List<?> children = parent.getChildren();
        if (children.isEmpty()) {
            throw new Exception("Tag '" + parent.getName() + "' has no nested model element!");
        }
        return (Element) children.get(0);
    }

    /**
     * Splits the three parallel lists into per-unit-type parallel lists, keyed by the type of each
     * instance's unit. Iterates {@code features.size()} entries.
     */
    private static HashMap<UnitType, UnitTypeSplit> partitionByUnitType(List<int[]> features, List<Integer> labels, List<TrainingInstance> instances) {
        HashMap<UnitType, UnitTypeSplit> splits = new HashMap<UnitType, UnitTypeSplit>();
        for (int i = 0; i < features.size(); i++) {
            TrainingInstance instance = instances.get(i);
            UnitType unitType = instance.u.getType();
            UnitTypeSplit split = splits.get(unitType);
            if (split == null) {
                split = new UnitTypeSplit();
                splits.put(unitType, split);
            }
            split.features.add(features.get(i));
            split.labels.add(labels.get(i));
            split.instances.add(instance);
        }
        return splits;
    }

    /** Parallel feature/label/instance lists holding the data of one unit type. */
    private static final class UnitTypeSplit {
        final ArrayList<int[]> features = new ArrayList<int[]>();
        final ArrayList<Integer> labels = new ArrayList<Integer>();
        final ArrayList<TrainingInstance> instances = new ArrayList<TrainingInstance>();
    }
}

