/*
 * Decompiled with CFR 0.152.
 */
package ai.machinelearning.bayes;

import ai.machinelearning.bayes.BayesianModel;
import ai.machinelearning.bayes.TrainingInstance;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import org.jdom.Element;
import rts.units.UnitType;
import rts.units.UnitTypeTable;
import util.XMLWriter;

/**
 * A {@link BayesianModel} that keeps a separate sub-model for each {@link UnitType}, plus a
 * "default" sub-model trained on every instance regardless of its unit type.
 *
 * <p>Every sub-model is created by cloning {@link #templateModel}. When predicting, the
 * sub-model for the instance's unit type is used if one exists; otherwise the prediction
 * falls back to the default sub-model.
 */
public class BayesianModelByUnitTypeWithDefaultModel
extends BayesianModel {
    /** Prototype cloned to create every per-unit-type sub-model and the default sub-model. */
    BayesianModel templateModel;
    /** Per-unit-type sub-models, created lazily the first time data for a unit type is seen. */
    HashMap<UnitType, BayesianModel> unitModels = new HashMap<>();
    /**
     * Sub-model trained on all instances. Null until the first call to train,
     * calibrateProbabilities, a feature-selection method, or load.
     */
    BayesianModel defaultModel;

    /**
     * Creates an empty model with no sub-models yet.
     *
     * @param utt    unit type table used to resolve unit types
     * @param tm     template model; its feature generator is shared, and it is cloned for each sub-model
     * @param a_name name of this model
     */
    public BayesianModelByUnitTypeWithDefaultModel(UnitTypeTable utt, BayesianModel tm, String a_name) {
        super(utt, tm.featureGenerator, a_name);
        this.templateModel = tm;
    }

    /**
     * Creates a model and restores its sub-models from XML produced by {@link #save(XMLWriter)}.
     *
     * @param e      root element of the saved model
     * @param utt    unit type table used to resolve the saved unit type names
     * @param tm     template model cloned for each restored sub-model
     * @param a_name name of this model
     * @throws Exception if the XML does not describe a model of this class
     */
    public BayesianModelByUnitTypeWithDefaultModel(Element e, UnitTypeTable utt, BayesianModel tm, String a_name) throws Exception {
        this(utt, tm, a_name);
        this.load(e);
    }

    /**
     * Returns a new, untrained model with the same unit type table, template and name.
     * Trained sub-models are not copied.
     */
    @Override
    public Object clone() {
        return new BayesianModelByUnitTypeWithDefaultModel(this.utt, this.templateModel, this.name);
    }

    /** Clears the training of every existing sub-model, including the default one if present. */
    @Override
    public void clearTraining() {
        for (BayesianModel model : this.unitModels.values()) {
            model.clearTraining();
        }
        // The default model only exists after training, calibration, feature selection or load.
        if (this.defaultModel != null) {
            this.defaultModel.clearTraining();
        }
    }

    /**
     * Trains one sub-model per unit type on that unit type's instances, and trains the
     * default sub-model on all instances.
     *
     * @param x_l feature vectors
     * @param y_l labels, parallel to {@code x_l}
     * @param i_l training instances, parallel to {@code x_l}; the unit type of each instance selects its sub-model
     */
    @Override
    public void train(List<int[]> x_l, List<Integer> y_l, List<TrainingInstance> i_l) throws Exception {
        HashMap<UnitType, Partition> partitions = partitionByUnitType(x_l, y_l, i_l);
        for (UnitType ut : partitions.keySet()) {
            Partition p = partitions.get(ut);
            this.getOrCreateUnitModel(ut).train(p.x, p.y, p.instances);
        }
        this.getOrCreateDefaultModel().train(x_l, y_l, i_l);
    }

    /**
     * Calibrates each per-unit-type sub-model on that unit type's instances, and calibrates
     * the default sub-model on all instances.
     */
    @Override
    public void calibrateProbabilities(List<int[]> x_l, List<Integer> y_l, List<TrainingInstance> i_l) throws Exception {
        HashMap<UnitType, Partition> partitions = partitionByUnitType(x_l, y_l, i_l);
        for (UnitType ut : partitions.keySet()) {
            Partition p = partitions.get(ut);
            this.getOrCreateUnitModel(ut).calibrateProbabilities(p.x, p.y, p.instances);
        }
        this.getOrCreateDefaultModel().calibrateProbabilities(x_l, y_l, i_l);
    }

    /**
     * Runs gain-ratio feature selection on every existing per-unit-type sub-model and on the
     * default sub-model, creating the default sub-model if needed.
     *
     * <p>NOTE(review): unlike the other training methods, this one has no
     * {@link TrainingInstance} list, so it cannot split the data by unit type. Every sub-model
     * sees the full data set, and unit types without a sub-model yet are not given one here.
     */
    @Override
    public void featureSelectionByGainRatio(List<int[]> x_l, List<Integer> y_l, double fractionOfFeaturesToKeep) {
        for (BayesianModel model_ut : this.unitModels.values()) {
            model_ut.featureSelectionByGainRatio(x_l, y_l, fractionOfFeaturesToKeep);
        }
        this.getOrCreateDefaultModel().featureSelectionByGainRatio(x_l, y_l, fractionOfFeaturesToKeep);
    }

    /**
     * Runs cross-validation feature selection on each per-unit-type sub-model using that unit
     * type's instances, and on the default sub-model using all instances.
     */
    @Override
    public void featureSelectionByCrossValidation(List<int[]> x_l, List<Integer> y_l, List<TrainingInstance> i_l) throws Exception {
        HashMap<UnitType, Partition> partitions = partitionByUnitType(x_l, y_l, i_l);
        for (UnitType ut : partitions.keySet()) {
            Partition p = partitions.get(ut);
            this.getOrCreateUnitModel(ut).featureSelectionByCrossValidation(p.x, p.y, p.instances);
        }
        this.getOrCreateDefaultModel().featureSelectionByCrossValidation(x_l, y_l, i_l);
    }

    /**
     * Predicts with the sub-model for {@code ti}'s unit type, or with the default sub-model if
     * that unit type has no sub-model.
     *
     * @throws IllegalStateException if no sub-model applies because the model was never trained or loaded
     */
    @Override
    public double[] predictDistribution(int[] x, TrainingInstance ti) {
        UnitType ut = ti.u.getType();
        BayesianModel model_ut = this.unitModels.get(ut);
        if (model_ut != null) {
            return model_ut.predictDistribution(x, ti);
        }
        if (this.defaultModel == null) {
            throw new IllegalStateException("No model for unit type '" + ut.name
                    + "' and no default model (the model has not been trained or loaded)");
        }
        return this.defaultModel.predictDistribution(x, ti);
    }

    /**
     * Writes this model as XML. Each per-unit-type sub-model is written inside a
     * {@code UnitType} element carrying its name and ID. The default sub-model, if present,
     * is written inside a {@code defaultModel} element.
     */
    @Override
    public void save(XMLWriter w) throws Exception {
        String rootTag = this.getClass().getSimpleName();
        w.tag(rootTag);
        for (UnitType ut : this.unitModels.keySet()) {
            w.tagWithAttributes("UnitType", "name=\"" + ut.name + "\" ID=\"" + ut.ID + "\"");
            this.unitModels.get(ut).save(w);
            w.tag("/UnitType");
        }
        // An untrained model has no default sub-model; omit the element rather than crash.
        if (this.defaultModel != null) {
            w.tag("defaultModel");
            this.defaultModel.save(w);
            w.tag("/defaultModel");
        }
        w.tag("/" + rootTag);
        w.flush();
    }

    /**
     * Replaces this model's sub-models with those described by {@code e}, which is expected to
     * be the format written by {@link #save(XMLWriter)}. Unit types are resolved by name.
     *
     * @throws Exception if the root tag does not match this class, a unit type name cannot be
     *                   resolved, or a sub-model element is empty
     */
    @Override
    public void load(Element e) throws Exception {
        if (!e.getName().equals(this.getClass().getSimpleName())) {
            throw new Exception("Head tag is not '" + this.getClass().getSimpleName() + "'!");
        }
        // Loading replaces the previous state rather than merging into it.
        this.unitModels.clear();
        this.defaultModel = null;
        List<?> models = e.getChildren("UnitType");
        for (Object o : models) {
            Element ut_xml = (Element) o;
            String utName = ut_xml.getAttributeValue("name");
            UnitType ut = this.utt.getUnitType(utName);
            if (ut == null) {
                throw new Exception("Unknown unit type '" + utName + "' in saved model!");
            }
            BayesianModel model = (BayesianModel) this.templateModel.clone();
            model.load(firstChildElement(ut_xml));
            this.unitModels.put(ut, model);
        }
        Element dm_xml = e.getChild("defaultModel");
        if (dm_xml != null) {
            this.defaultModel = (BayesianModel) this.templateModel.clone();
            this.defaultModel.load(firstChildElement(dm_xml));
        }
    }

    /** Returns the sub-model for {@code ut}, cloning the template on first use. */
    private BayesianModel getOrCreateUnitModel(UnitType ut) {
        BayesianModel model = this.unitModels.get(ut);
        if (model == null) {
            model = (BayesianModel) this.templateModel.clone();
            this.unitModels.put(ut, model);
        }
        return model;
    }

    /** Returns the default sub-model, cloning the template on first use. */
    private BayesianModel getOrCreateDefaultModel() {
        if (this.defaultModel == null) {
            this.defaultModel = (BayesianModel) this.templateModel.clone();
        }
        return this.defaultModel;
    }

    /** Returns the first child element of {@code parent}, failing with a clear message if there is none. */
    private static Element firstChildElement(Element parent) throws Exception {
        List<?> children = parent.getChildren();
        if (children.isEmpty()) {
            throw new Exception("Element '" + parent.getName() + "' has no child model element!");
        }
        return (Element) children.get(0);
    }

    /**
     * Splits parallel lists of features, labels and instances by each instance's unit type,
     * keeping the original relative order within each unit type. Only the first
     * {@code x_l.size()} entries are considered.
     */
    private static HashMap<UnitType, Partition> partitionByUnitType(List<int[]> x_l, List<Integer> y_l, List<TrainingInstance> i_l) {
        HashMap<UnitType, Partition> partitions = new HashMap<>();
        for (int i = 0; i < x_l.size(); ++i) {
            TrainingInstance ti = i_l.get(i);
            UnitType ut = ti.u.getType();
            Partition p = partitions.get(ut);
            if (p == null) {
                p = new Partition();
                partitions.put(ut, p);
            }
            p.x.add(x_l.get(i));
            p.y.add(y_l.get(i));
            p.instances.add(ti);
        }
        return partitions;
    }

    /** The training data for a single unit type, stored as three parallel lists. */
    private static final class Partition {
        final ArrayList<int[]> x = new ArrayList<>();
        final ArrayList<Integer> y = new ArrayList<>();
        final ArrayList<TrainingInstance> instances = new ArrayList<>();
    }
}

