/*
 * Decompiled with CFR 0.152.
 */
package ai.machinelearning.bayes;

import ai.machinelearning.bayes.BayesianModel;
import ai.machinelearning.bayes.DiscreteCPD;
import ai.machinelearning.bayes.FeatureSelection;
import ai.machinelearning.bayes.TrainingInstance;
import ai.machinelearning.bayes.featuregeneration.FeatureGenerator;
import ai.machinelearning.bayes.featuregeneration.FeatureGeneratorComplex;
import ai.machinelearning.bayes.featuregeneration.FeatureGeneratorEmpty;
import ai.machinelearning.bayes.featuregeneration.FeatureGeneratorSimple;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import org.jdom.Element;
import rts.UnitAction;
import rts.units.Unit;
import rts.units.UnitTypeTable;
import util.XMLWriter;

public class ActionInterdependenceModel
extends BayesianModel {
    int estimationMethod = 1;
    double calibrationFactor = 0.0;
    double[] prior_distribution = null;
    DiscreteCPD[] distributions = null;
    boolean[] selectedFeatures = null;
    int Ysize = 0;
    int YtypeSize = 0;
    int[] Xsizes;
    int[] action_allowed_counts_prior = null;
    int[][] selected_allowed_action_prior = null;
    List<Integer> allPossibleActionsTypes = null;
    int[] actiontypes_allowed_counts_prior = null;
    int[][] selected_allowed_actiontype_prior = null;
    boolean consider_individual_actions = false;
    boolean consider_action_types = true;

    public ActionInterdependenceModel(int[] nArray, int n, int n2, double d, UnitTypeTable unitTypeTable, FeatureGenerator featureGenerator, String string) {
        super(unitTypeTable, featureGenerator, string);
        this.Ysize = n;
        this.Xsizes = nArray;
        this.estimationMethod = n2;
        this.calibrationFactor = d;
        this.allPossibleActionsTypes = new ArrayList<Integer>();
        for (UnitAction unitAction : this.allPossibleActions) {
            this.allPossibleActionsTypes.add(unitAction.getType());
        }
        this.YtypeSize = 6;
        this.clearTraining();
    }

    @Override
    public Object clone() {
        ActionInterdependenceModel actionInterdependenceModel = new ActionInterdependenceModel(this.Xsizes, this.Ysize, this.estimationMethod, this.calibrationFactor, this.utt, this.featureGenerator, this.name);
        return actionInterdependenceModel;
    }

    @Override
    public void clearTraining() {
        this.action_allowed_counts_prior = null;
        this.selected_allowed_action_prior = null;
        this.actiontypes_allowed_counts_prior = null;
        this.selected_allowed_actiontype_prior = null;
        if (this.Xsizes != null) {
            int n = this.Xsizes.length;
            this.distributions = new DiscreteCPD[n];
            for (int j = 0; j < n; ++j) {
                this.distributions[j] = new DiscreteCPD(this.Ysize, this.Xsizes[j]);
            }
        } else {
            this.distributions = null;
        }
    }

    /*
     * Could not resolve type clashes
     */
    @Override
    public void train(List<int[]> list, List<Integer> list2, List<TrainingInstance> list3) throws Exception {
        int n;
        int n2 = this.distributions.length;
        this.prior_distribution = new double[this.Ysize];
        this.action_allowed_counts_prior = new int[this.Ysize];
        this.selected_allowed_action_prior = new int[this.Ysize][this.Ysize];
        this.actiontypes_allowed_counts_prior = new int[this.YtypeSize];
        this.selected_allowed_actiontype_prior = new int[this.YtypeSize][this.YtypeSize];
        for (n = 0; n < list.size(); ++n) {
            int n3;
            int n4;
            int[] nArray = list.get(n);
            int n5 = n4 = list2.get(n).intValue();
            this.prior_distribution[n5] = this.prior_distribution[n5] + 1.0;
            for (int j = 0; j < n2; ++j) {
                this.distributions[j].addObservation(n4, nArray[j]);
            }
            List<Integer> list4 = list3.get(n).getPossibleActions(this.allPossibleActions);
            Object object = list4.iterator();
            while (object.hasNext()) {
                int n6;
                int n7 = n6 = object.next().intValue();
                this.action_allowed_counts_prior[n7] = this.action_allowed_counts_prior[n7] + 1;
                if (n6 != n4) continue;
                Object object2 = list4.iterator();
                while (object2.hasNext()) {
                    n3 = (Integer)object2.next();
                    int[] nArray2 = this.selected_allowed_action_prior[n6];
                    int n8 = n3;
                    nArray2[n8] = nArray2[n8] + 1;
                }
            }
            object = new ArrayList();
            for (Object object2 : list4) {
                n3 = this.allPossibleActionsTypes.get((Integer)object2);
                if (object.contains(n3)) continue;
                object.add(n3);
            }
            Iterator<Integer> iterator = object.iterator();
            while (iterator.hasNext()) {
                int n9;
                int n10 = n9 = iterator.next().intValue();
                this.actiontypes_allowed_counts_prior[n10] = this.actiontypes_allowed_counts_prior[n10] + 1;
                if (n9 != this.allPossibleActionsTypes.get(n4)) continue;
                Iterator iterator2 = object.iterator();
                while (iterator2.hasNext()) {
                    int n11 = (Integer)iterator2.next();
                    int[] nArray3 = this.selected_allowed_actiontype_prior[n9];
                    int n12 = n11;
                    nArray3[n12] = nArray3[n12] + 1;
                }
            }
        }
        if (this.estimationMethod == 1) {
            n = 0;
            while (n < this.Ysize) {
                int n13 = n++;
                this.prior_distribution[n13] = this.prior_distribution[n13] / (double)list.size();
            }
        } else {
            for (n = 0; n < this.Ysize; ++n) {
                this.prior_distribution[n] = (this.prior_distribution[n] + 1.0) / (double)(list.size() + this.Ysize);
            }
        }
    }

    @Override
    public void calibrateProbabilities(List<int[]> list, List<Integer> list2, List<TrainingInstance> list3) throws Exception {
        double d = 0.0;
        double d2 = Double.NEGATIVE_INFINITY;
        for (double d3 = 0.0; d3 <= 1.05; d3 += 0.05) {
            this.calibrationFactor = d3;
            double d4 = 0.0;
            for (int j = 0; j < list.size(); ++j) {
                int n;
                Unit unit = list3.get((int)j).u;
                List<UnitAction> list4 = unit.getUnitActions(list3.get((int)j).gs);
                ArrayList<Integer> arrayList = new ArrayList<Integer>();
                for (UnitAction unitAction : list4) {
                    if (unitAction.getType() == 5) {
                        unitAction = new UnitAction(5, unitAction.getLocationX() - unit.getX(), unitAction.getLocationY() - unit.getY());
                    }
                    if ((n = this.allPossibleActions.indexOf(unitAction)) < 0) {
                        throw new Exception("Unknown action: " + unitAction);
                    }
                    arrayList.add(n);
                }
                if (list4.size() <= 1) continue;
                Object object = this.predictDistribution(list.get(j), list3.get(j));
                object = this.filterByPossibleActionIndexes((double[])object, arrayList);
                int n2 = list2.get(j);
                if (!arrayList.contains(n2)) continue;
                n = -1;
                Collections.shuffle(list4);
                Iterator iterator = arrayList.iterator();
                while (iterator.hasNext()) {
                    int n3 = (Integer)iterator.next();
                    if (n == -1) {
                        n = n3;
                        continue;
                    }
                    if (!(object[n3] > object[n])) continue;
                    n = n3;
                }
                double d5 = Math.log((double)object[n2]);
                if (Double.isInfinite(d5)) {
                    System.out.println(Arrays.toString((double[])object));
                    System.out.println(arrayList);
                    System.out.println(n2 + " : " + this.allPossibleActions.get(n2));
                    System.exit(1);
                }
                d4 += d5;
            }
            if (!(d4 > d2)) break;
            d = d3;
            d2 = d4;
        }
        System.out.println("best calibration factor = " + d);
        this.calibrationFactor = d;
    }

    @Override
    public void featureSelectionByGainRatio(List<int[]> list, List<Integer> list2, double d) {
        int n;
        ArrayList<Integer> arrayList = new ArrayList<Integer>();
        final ArrayList<Double> arrayList2 = new ArrayList<Double>();
        int n2 = this.distributions.length;
        this.selectedFeatures = new boolean[n2];
        for (n = 0; n < n2; ++n) {
            arrayList.add(n);
            arrayList2.add(FeatureSelection.featureGainRatio(list, list2, n));
            this.selectedFeatures[n] = false;
        }
        Collections.sort(arrayList, new Comparator<Integer>(){

            @Override
            public int compare(Integer n, Integer n2) {
                return Double.compare((Double)arrayList2.get(n2), (Double)arrayList2.get(n));
            }
        });
        n = 0;
        while ((double)n < d * (double)n2) {
            this.selectedFeatures[((Integer)arrayList.get((int)n)).intValue()] = true;
            ++n;
        }
    }

    @Override
    public double[] predictDistribution(int[] nArray, TrainingInstance trainingInstance) {
        return this.predictDistribution(nArray, trainingInstance, this.calibrationFactor);
    }

    public double[] predictDistribution(int[] nArray, TrainingInstance trainingInstance, double d) {
        int n;
        List<Integer> list = trainingInstance.getPossibleActions(this.allPossibleActions);
        double[] dArray = new double[this.Ysize];
        double d2 = 1.0;
        for (int j = 0; j < this.Ysize; ++j) {
            dArray[j] = 0.0;
        }
        for (int n2 : list) {
            if (this.prior_distribution == null) {
                dArray[n2] = 1.0;
                continue;
            }
            dArray[n2] = this.prior_distribution[n2];
        }
        for (int j = 0; j < nArray.length; ++j) {
            double[] dArray2;
            if (this.selectedFeatures != null && !this.selectedFeatures[j]) continue;
            d2 += 1.0;
            if (this.estimationMethod == 1) {
                for (int n3 : list) {
                    dArray2 = this.distributions[j].distribution(n3);
                    int n4 = n3;
                    dArray[n4] = dArray[n4] * dArray2[nArray[j]];
                }
                continue;
            }
            for (int n3 : list) {
                dArray2 = this.distributions[j].distributionLaplace(n3, 1.0);
                double d3 = 1.0;
                d3 = dArray2.length > nArray[j] ? dArray2[nArray[j]] : 1.0 / (double)this.Ysize;
                int n5 = n3;
                dArray[n5] = dArray[n5] * d3;
            }
        }
        if (this.consider_action_types && this.selected_allowed_actiontype_prior != null) {
            ArrayList<Integer> arrayList = new ArrayList<Integer>();
            for (Integer n6 : list) {
                int n7 = this.allPossibleActionsTypes.get(n6);
                if (arrayList.contains(n7)) continue;
                arrayList.add(n7);
            }
            d2 += (double)(arrayList.size() - 1);
            for (int n8 : list) {
                int n9 = this.allPossibleActionsTypes.get(n8);
                Iterator iterator = arrayList.iterator();
                while (iterator.hasNext()) {
                    double d4;
                    int n10 = (Integer)iterator.next();
                    if (n10 == n9) continue;
                    if (this.estimationMethod == 1) {
                        d4 = (double)this.selected_allowed_actiontype_prior[n9][n10] / (double)this.actiontypes_allowed_counts_prior[n9];
                        int n11 = n8;
                        dArray[n11] = dArray[n11] * d4;
                        continue;
                    }
                    d4 = (double)(this.selected_allowed_actiontype_prior[n9][n10] + 1) / (double)(this.actiontypes_allowed_counts_prior[n9] + 2);
                    int n12 = n8;
                    dArray[n12] = dArray[n12] * d4;
                }
            }
        }
        if (this.consider_individual_actions && this.selected_allowed_action_prior != null) {
            d2 += (double)(list.size() - 1);
            for (int n13 : list) {
                for (int n14 : list) {
                    double d5;
                    if (n14 == n13) continue;
                    if (this.estimationMethod == 1) {
                        d5 = (double)this.selected_allowed_action_prior[n13][n14] / (double)this.action_allowed_counts_prior[n13];
                        int n15 = n13;
                        dArray[n15] = dArray[n15] * d5;
                        continue;
                    }
                    d5 = (double)(this.selected_allowed_action_prior[n13][n14] + 1) / (double)(this.action_allowed_counts_prior[n13] + 2);
                    int n16 = n13;
                    dArray[n16] = dArray[n16] * d5;
                }
            }
        }
        double d6 = 0.0;
        for (n = 0; n < this.Ysize; ++n) {
            dArray[n] = Math.pow(dArray[n], 1.0 / (1.0 * (1.0 - d) + d2 * d));
            d6 += dArray[n];
        }
        if (d6 <= 0.0) {
            for (n = 0; n < this.Ysize; ++n) {
                dArray[n] = 1.0 / (double)this.Ysize;
            }
        } else {
            n = 0;
            while (n < this.Ysize) {
                int n17 = n++;
                dArray[n17] = dArray[n17] / d6;
            }
        }
        return dArray;
    }

    /*
     * WARNING - void declaration
     */
    @Override
    public void save(XMLWriter xMLWriter) throws Exception {
        void var2_11;
        xMLWriter.tagWithAttributes(this.getClass().getSimpleName(), "estimationMethod=\"" + this.estimationMethod + "\" Ysize=\"" + this.Ysize + "\" calibrationFactor=\"" + this.calibrationFactor + "\" nfeatures=\"" + this.distributions.length + "\" featureGenerationClass=\"" + this.featureGenerator.getClass().getSimpleName() + "\"");
        xMLWriter.tag("Xsizes");
        for (int n : this.Xsizes) {
            xMLWriter.rawXML(n + " ");
        }
        xMLWriter.rawXML("\n");
        xMLWriter.tag("/Xsizes");
        xMLWriter.tag("priorDistribution");
        for (double d : this.prior_distribution) {
            xMLWriter.rawXML(d + " ");
        }
        xMLWriter.rawXML("\n");
        xMLWriter.tag("/priorDistribution");
        if (this.selectedFeatures != null) {
            xMLWriter.tag("selectedFeatures");
            for (boolean bl : this.selectedFeatures) {
                xMLWriter.rawXML(bl + " ");
            }
            xMLWriter.rawXML("\n");
            xMLWriter.tag("/selectedFeatures");
        }
        xMLWriter.tag("action_allowed_counts_prior");
        for (int n : this.action_allowed_counts_prior) {
            xMLWriter.rawXML(n + " ");
        }
        xMLWriter.rawXML("\n");
        xMLWriter.tag("/action_allowed_counts_prior");
        xMLWriter.tag("selected_action_pairs_prior");
        int[][] nArray = this.selected_allowed_action_prior;
        int n = nArray.length;
        for (int j = 0; j < n; ++j) {
            int[] nArray2;
            for (int n2 : nArray2 = nArray[j]) {
                xMLWriter.rawXML(n2 + " ");
            }
            xMLWriter.rawXML("\n");
        }
        xMLWriter.tag("/selected_action_pairs_prior");
        xMLWriter.tag("actiontypes_allowed_counts_prior");
        for (int n3 : this.actiontypes_allowed_counts_prior) {
            xMLWriter.rawXML(n3 + " ");
        }
        xMLWriter.rawXML("\n");
        xMLWriter.tag("/actiontypes_allowed_counts_prior");
        xMLWriter.tag("selected_allowed_actiontype_prior");
        for (int[] nArray3 : this.selected_allowed_actiontype_prior) {
            for (int n2 : nArray3) {
                xMLWriter.rawXML(n2 + " ");
            }
            xMLWriter.rawXML("\n");
        }
        xMLWriter.tag("/selected_allowed_actiontype_prior");
        boolean bl = false;
        while (var2_11 < this.distributions.length) {
            this.distributions[var2_11].save(xMLWriter);
            ++var2_11;
        }
        xMLWriter.tag("/" + this.getClass().getSimpleName());
        xMLWriter.flush();
    }

    public ActionInterdependenceModel(Element element, UnitTypeTable unitTypeTable, String string) throws Exception {
        super(unitTypeTable, null, string);
        this.load(element);
    }

    @Override
    public void load(Element element) throws Exception {
        int n;
        int n2;
        int n3;
        if (!element.getName().equals(this.getClass().getSimpleName())) {
            throw new Exception("Head tag " + element.getName() + " is not '" + this.getClass().getSimpleName() + "'!");
        }
        this.allPossibleActionsTypes = new ArrayList<Integer>();
        for (UnitAction unitAction : this.allPossibleActions) {
            this.allPossibleActionsTypes.add(unitAction.getType());
        }
        String string = element.getAttributeValue("featureGenerationClass");
        if (string.contains("FeatureGeneratorEmpty")) {
            this.featureGenerator = new FeatureGeneratorEmpty();
        } else if (string.contains("FeatureGeneratorSimple")) {
            this.featureGenerator = new FeatureGeneratorSimple();
        } else if (string.contains("FeatureGeneratorComplex")) {
            this.featureGenerator = new FeatureGeneratorComplex();
        }
        this.YtypeSize = 6;
        this.estimationMethod = Integer.parseInt(element.getAttributeValue("estimationMethod"));
        this.Ysize = Integer.parseInt(element.getAttributeValue("Ysize"));
        this.calibrationFactor = Double.parseDouble(element.getAttributeValue("calibrationFactor"));
        int n4 = Integer.parseInt(element.getAttributeValue("nfeatures"));
        Element element2 = element.getChild("Xsizes");
        Object object = element2.getTextTrim();
        Object object2 = ((String)object).split(" ");
        this.Xsizes = new int[n4];
        for (int j = 0; j < n4; ++j) {
            this.Xsizes[j] = Integer.parseInt(object2[j]);
        }
        object = element.getChild("priorDistribution");
        object2 = ((Element)object).getTextTrim();
        Object object3 = ((String)object2).split(" ");
        this.prior_distribution = new double[this.Ysize];
        for (int j = 0; j < this.Ysize; ++j) {
            this.prior_distribution[j] = Double.parseDouble(object3[j]);
        }
        object2 = element.getChild("selectedFeatures");
        if (object2 != null) {
            object3 = ((Element)object2).getTextTrim();
            String[] stringArray = ((String)object3).split(" ");
            this.selectedFeatures = new boolean[n4];
            for (int j = 0; j < n4; ++j) {
                this.selectedFeatures[j] = Boolean.parseBoolean(stringArray[j]);
            }
        } else {
            this.selectedFeatures = null;
        }
        object3 = element.getChild("action_allowed_counts_prior");
        String string2 = ((Element)object3).getTextTrim();
        Object object4 = string2.split(" ");
        this.action_allowed_counts_prior = new int[this.Ysize];
        for (n3 = 0; n3 < this.Ysize; ++n3) {
            this.action_allowed_counts_prior[n3] = Integer.parseInt(object4[n3]);
        }
        object3 = element.getChild("selected_action_pairs_prior");
        string2 = ((Element)object3).getTextTrim();
        object4 = string2.split(" |\n");
        this.selected_allowed_action_prior = new int[this.Ysize][this.Ysize];
        n3 = 0;
        for (n2 = 0; n2 < this.Ysize; ++n2) {
            n = 0;
            while (n < this.Ysize) {
                while (object4[n3].equals("")) {
                    ++n3;
                }
                this.selected_allowed_action_prior[n2][n] = Integer.parseInt(object4[n3]);
                ++n;
                ++n3;
            }
        }
        object3 = element.getChild("actiontypes_allowed_counts_prior");
        string2 = ((Element)object3).getTextTrim();
        object4 = string2.split(" ");
        this.actiontypes_allowed_counts_prior = new int[this.YtypeSize];
        for (n3 = 0; n3 < this.YtypeSize; ++n3) {
            this.actiontypes_allowed_counts_prior[n3] = Integer.parseInt(object4[n3]);
        }
        object3 = element.getChild("selected_allowed_actiontype_prior");
        string2 = ((Element)object3).getTextTrim();
        object4 = string2.split(" |\n");
        this.selected_allowed_actiontype_prior = new int[this.YtypeSize][this.YtypeSize];
        n3 = 0;
        for (n2 = 0; n2 < this.YtypeSize; ++n2) {
            n = 0;
            while (n < this.YtypeSize) {
                while (object4[n3].equals("")) {
                    ++n3;
                }
                this.selected_allowed_actiontype_prior[n2][n] = Integer.parseInt(object4[n3]);
                ++n;
                ++n3;
            }
        }
        this.distributions = new DiscreteCPD[n4];
        object3 = element.getChildren("DiscreteCPD");
        for (int j = 0; j < n4; ++j) {
            object4 = (Element)object3.get(j);
            this.distributions[j] = new DiscreteCPD((Element)object4);
        }
    }

    @Override
    public void featureSelectionByCrossValidation(List<int[]> list, List<Integer> list2, List<TrainingInstance> list3) throws Exception {
        boolean bl;
        int n = this.distributions.length;
        System.out.println("featureSelectionByCrossValidation " + list.size());
        boolean[] blArray = new boolean[n];
        for (int j = 0; j < n; ++j) {
            blArray[j] = false;
        }
        this.selectedFeatures = blArray;
        double d = (Double)FeatureSelection.crossValidation((BayesianModel)this, list, list2, list3, (List<UnitAction>)this.allPossibleActions, (int)10).m_a;
        System.out.println("  loglikelihood with " + Arrays.toString(this.selectedFeatures) + ": " + d);
        do {
            bl = false;
            boolean[] blArray2 = blArray;
            for (int j = 0; j < n; ++j) {
                if (blArray[j]) continue;
                boolean[] blArray3 = new boolean[n];
                for (int k = 0; k < n; ++k) {
                    blArray3[k] = blArray[k];
                }
                blArray3[j] = true;
                this.selectedFeatures = blArray3;
                double d2 = (Double)FeatureSelection.crossValidation((BayesianModel)this, list, list2, list3, (List<UnitAction>)this.allPossibleActions, (int)10).m_a;
                System.out.println("  loglikelihood with " + Arrays.toString(this.selectedFeatures) + ": " + d2);
                if (!(d2 > d)) continue;
                blArray2 = blArray3;
                d = d2;
                bl = true;
            }
            blArray = blArray2;
        } while (bl);
        this.selectedFeatures = blArray;
        System.out.println("Selected features: " + Arrays.toString(this.selectedFeatures));
    }
}

