/*
 * Decompiled with CFR 0.152.
 */
package ai.machinelearning.bayes;

import ai.machinelearning.bayes.BayesianModel;
import ai.machinelearning.bayes.DiscreteCPD;
import ai.machinelearning.bayes.FeatureSelection;
import ai.machinelearning.bayes.TrainingInstance;
import ai.machinelearning.bayes.featuregeneration.FeatureGenerator;
import ai.machinelearning.bayes.featuregeneration.FeatureGeneratorComplex;
import ai.machinelearning.bayes.featuregeneration.FeatureGeneratorEmpty;
import ai.machinelearning.bayes.featuregeneration.FeatureGeneratorSimple;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import org.jdom.Element;
import rts.UnitAction;
import rts.units.Unit;
import rts.units.UnitTypeTable;
import util.XMLWriter;

public class CalibratedNaiveBayes
extends BayesianModel {
    // Selects how probabilities are estimated: 1 uses raw relative frequencies,
    // any other value uses add-one (Laplace) smoothing (see train / predictDistribution).
    int estimationMethod = 1;
    // Controls the exponent applied to the unnormalised class scores in
    // predictDistribution: exponent = 1 / ((1 - c) + k * c), where k is one plus the
    // number of features used. 0 leaves the scores untouched; 1 takes their k-th root.
    double calibrationFactor = 0.0;
    // Class prior; stays null until train() or load() fills it in.
    double[] prior_distribution = null;
    // One conditional table per feature, in the same order as Xsizes.
    DiscreteCPD[] distributions = null;
    // Per-feature inclusion mask; null means every feature is used.
    boolean[] selectedFeatures = null;
    // Number of class labels (length of prior_distribution).
    int Ysize = 0;
    // Per-feature size passed to the DiscreteCPD constructor — presumably the number
    // of distinct values each feature can take; confirm against DiscreteCPD.
    int[] Xsizes;

    /**
     * Builds an untrained model with empty conditional tables.
     *
     * @param nArray per-feature sizes, stored as {@code Xsizes} (kept by reference, not copied)
     * @param n number of class labels, stored as {@code Ysize}
     * @param n2 estimation method (1 = raw frequencies, otherwise Laplace smoothing)
     * @param d initial calibration factor
     * @param unitTypeTable forwarded to the {@code BayesianModel} constructor
     * @param featureGenerator forwarded to the {@code BayesianModel} constructor
     * @param string forwarded to the {@code BayesianModel} constructor
     */
    public CalibratedNaiveBayes(int[] nArray, int n, int n2, double d, UnitTypeTable unitTypeTable, FeatureGenerator featureGenerator, String string) {
        super(unitTypeTable, featureGenerator, string);
        estimationMethod = n2;
        calibrationFactor = d;
        Xsizes = nArray;
        Ysize = n;
        // Must run last: it sizes the tables from Xsizes and Ysize.
        clearTraining();
    }

    /**
     * Returns a new model with the same configuration as this one.
     *
     * Only the constructor arguments are carried over: the copy gets fresh, empty
     * conditional tables, a null prior and no feature mask. The {@code Xsizes}
     * array is shared with this instance.
     */
    @Override
    public Object clone() {
        return new CalibratedNaiveBayes(Xsizes, Ysize, estimationMethod, calibrationFactor, utt, featureGenerator, name);
    }

    /**
     * Replaces every per-feature conditional table with a new, empty one.
     *
     * The prior distribution and the feature mask are left as they are.
     */
    @Override
    public void clearTraining() {
        distributions = new DiscreteCPD[Xsizes.length];
        int feature = 0;
        for (int featureSize : Xsizes) {
            distributions[feature] = new DiscreteCPD(Ysize, featureSize);
            feature++;
        }
    }

    /**
     * Counts label and feature occurrences, then turns the label counts into a prior.
     *
     * Observations are added to the existing conditional tables; call
     * {@link #clearTraining()} first to start from scratch. The prior array is
     * always rebuilt.
     *
     * @param list feature vectors, one {@code int[]} per instance
     * @param list2 label of each instance, parallel to {@code list}
     * @param list3 not used by this implementation
     * @throws Exception declared by the overridden method
     */
    @Override
    public void train(List<int[]> list, List<Integer> list2, List<TrainingInstance> list3) throws Exception {
        final int featureCount = distributions.length;
        prior_distribution = new double[Ysize];
        final double[] prior = prior_distribution;

        // Pass 1: raw counts for the prior and for every feature table.
        for (int i = 0; i < list.size(); i++) {
            final int[] features = list.get(i);
            final int label = list2.get(i);
            prior[label] += 1.0;
            for (int f = 0; f < featureCount; f++) {
                distributions[f].addObservation(label, features[f]);
            }
        }

        // Pass 2: normalise the counts, with add-one smoothing unless method 1 is selected.
        final boolean rawFrequencies = estimationMethod == 1;
        for (int y = 0; y < Ysize; y++) {
            if (rawFrequencies) {
                prior[y] = prior[y] / (double)list.size();
            } else {
                prior[y] = (prior[y] + 1.0) / (double)(list.size() + Ysize);
            }
        }
    }

    /**
     * Chooses the calibration factor by a line search that maximises the
     * log-likelihood of the observed actions.
     *
     * Candidates start at 0.0 and grow in steps of 0.05 while they are at most 1.05.
     * The search stops at the first candidate that does not strictly improve on the
     * best log-likelihood so far. Instances whose unit has at most one possible
     * action, or whose recorded action is not among the possible ones, contribute
     * nothing to the score.
     *
     * If the search fails, the calibration factor in effect before the call is
     * restored before the exception propagates.
     *
     * @param list feature vectors, one per instance
     * @param list2 index of the action actually taken, parallel to {@code list}
     * @param list3 training instances supplying the unit and game state, parallel to {@code list}
     * @throws Exception if a possible action is not in {@code allPossibleActions}, or if
     *         the observed action's log-probability is infinite
     */
    @Override
    public void calibrateProbabilities(List<int[]> list, List<Integer> list2, List<TrainingInstance> list3) throws Exception {
        final double previousFactor = this.calibrationFactor;
        double bestFactor = 0.0;
        double bestLogLikelihood = Double.NEGATIVE_INFINITY;
        try {
            for (double candidate = 0.0; candidate <= 1.05; candidate += 0.05) {
                // predictDistribution(int[], TrainingInstance) reads this field.
                this.calibrationFactor = candidate;
                double logLikelihood = 0.0;
                for (int i = 0; i < list.size(); ++i) {
                    TrainingInstance instance = list3.get(i);
                    Unit unit = instance.u;
                    List<UnitAction> possibleActions = unit.getUnitActions(instance.gs);
                    ArrayList<Integer> possibleIndexes = new ArrayList<Integer>();
                    for (UnitAction action : possibleActions) {
                        UnitAction lookup = action;
                        // Type 5 actions are re-expressed relative to the unit's position
                        // before the lookup (presumably "attack location" - confirm in UnitAction).
                        if (action.getType() == 5) {
                            lookup = new UnitAction(5, action.getLocationX() - unit.getX(), action.getLocationY() - unit.getY());
                        }
                        int index = this.allPossibleActions.indexOf(lookup);
                        if (index < 0) {
                            throw new Exception("Unknown action: " + lookup);
                        }
                        possibleIndexes.add(index);
                    }
                    if (possibleActions.size() <= 1) {
                        continue;
                    }
                    double[] predicted = this.predictDistribution(list.get(i), instance);
                    predicted = this.filterByPossibleActionIndexes(predicted, possibleIndexes);
                    int actual = list2.get(i);
                    if (!possibleIndexes.contains(actual)) {
                        continue;
                    }
                    double instanceLogLikelihood = Math.log(predicted[actual]);
                    if (Double.isInfinite(instanceLogLikelihood)) {
                        // Report the problem to the caller instead of terminating the JVM.
                        throw new Exception("Infinite log-likelihood for action " + actual + " : " + this.allPossibleActions.get(actual)
                                + "; distribution = " + Arrays.toString(predicted)
                                + "; possible action indexes = " + possibleIndexes);
                    }
                    logLikelihood += instanceLogLikelihood;
                }
                // Written as a negated '>' so that a NaN score also ends the search.
                if (!(logLikelihood > bestLogLikelihood)) {
                    break;
                }
                bestFactor = candidate;
                bestLogLikelihood = logLikelihood;
            }
        } catch (Exception e) {
            // Do not leave the model stuck on a half-evaluated trial value.
            this.calibrationFactor = previousFactor;
            throw e;
        }
        System.out.println("best calibration factor = " + bestFactor);
        this.calibrationFactor = bestFactor;
    }

    /**
     * Keeps the features with the highest gain ratio and masks out the rest.
     *
     * Features are ranked by {@code FeatureSelection.featureGainRatio} in descending
     * order, and the first {@code ceil(d * featureCount)} of them are selected. The
     * count is capped at the number of features, so a fraction above 1.0 selects
     * everything instead of running past the end of the ranking.
     *
     * @param list feature vectors, one per instance
     * @param list2 label of each instance, parallel to {@code list}
     * @param d fraction of the features to keep; values of 0 or less (or NaN) select none
     */
    @Override
    public void featureSelectionByGainRatio(List<int[]> list, List<Integer> list2, double d) {
        final int featureCount = this.distributions.length;
        // A new boolean[] starts all-false, i.e. nothing selected.
        this.selectedFeatures = new boolean[featureCount];

        final List<Integer> ranking = new ArrayList<Integer>(featureCount);
        final List<Double> gainRatios = new ArrayList<Double>(featureCount);
        for (int feature = 0; feature < featureCount; ++feature) {
            ranking.add(feature);
            gainRatios.add(FeatureSelection.featureGainRatio(list, list2, feature));
        }

        // Highest gain ratio first.
        Collections.sort(ranking, new Comparator<Integer>() {
            @Override
            public int compare(Integer left, Integer right) {
                return Double.compare(gainRatios.get(right), gainRatios.get(left));
            }
        });

        for (int rank = 0; rank < featureCount && (double)rank < d * (double)featureCount; ++rank) {
            this.selectedFeatures[ranking.get(rank)] = true;
        }
    }

    /**
     * Predicts the label distribution using the model's current calibration factor.
     *
     * @param nArray feature vector of the instance
     * @param trainingInstance passed through to the three-argument overload
     * @return the distribution computed by the three-argument overload
     */
    @Override
    public double[] predictDistribution(int[] nArray, TrainingInstance trainingInstance) {
        final double currentFactor = calibrationFactor;
        return predictDistribution(nArray, trainingInstance, currentFactor);
    }

    /**
     * Predicts the label distribution for one feature vector with an explicit
     * calibration factor.
     *
     * Each label's score starts at its prior (or 1.0 when no prior exists) and is
     * multiplied by the conditional probability of every selected feature value.
     * The scores are then raised to {@code 1 / ((1 - d) + k * d)}, where {@code k}
     * is one plus the number of features used, and normalised to sum to one. If the
     * scores sum to zero or less, a uniform distribution is returned.
     *
     * @param nArray feature vector of the instance
     * @param trainingInstance not used by this implementation
     * @param d calibration factor to apply
     * @return an array of length {@code Ysize}
     */
    public double[] predictDistribution(int[] nArray, TrainingInstance trainingInstance, double d) {
        final double[] scores = new double[Ysize];
        for (int y = 0; y < Ysize; y++) {
            if (prior_distribution == null) {
                scores[y] = 1.0;
            } else {
                scores[y] = prior_distribution[y];
            }
        }

        // Counts the multiplied factors; starts at one to account for the prior.
        int factorCount = 1;
        for (int f = 0; f < nArray.length; f++) {
            final boolean used = selectedFeatures == null || selectedFeatures[f];
            if (!used) {
                continue;
            }
            factorCount++;
            for (int y = 0; y < Ysize; y++) {
                final double likelihood;
                if (estimationMethod == 1) {
                    final double[] table = distributions[f].distribution(y);
                    likelihood = table[nArray[f]];
                } else {
                    final double[] smoothed = distributions[f].distributionLaplace(y, 1.0);
                    // Values beyond the table fall back to 1 / Ysize.
                    likelihood = nArray[f] < smoothed.length ? smoothed[nArray[f]] : 1.0 / (double)Ysize;
                }
                scores[y] *= likelihood;
            }
        }

        final double exponent = 1.0 / ((1.0 - d) + (double)factorCount * d);
        double total = 0.0;
        for (int y = 0; y < Ysize; y++) {
            scores[y] = Math.pow(scores[y], exponent);
            total += scores[y];
        }

        // Keep the '<=' form: a NaN total must take the division branch.
        if (total <= 0.0) {
            for (int y = 0; y < Ysize; y++) {
                scores[y] = 1.0 / (double)Ysize;
            }
        } else {
            for (int y = 0; y < Ysize; y++) {
                scores[y] /= total;
            }
        }
        return scores;
    }

    /*
     * WARNING - void declaration
     */
    @Override
    public void save(XMLWriter xMLWriter) throws Exception {
        void var2_7;
        xMLWriter.tagWithAttributes(this.getClass().getSimpleName(), "estimationMethod=\"" + this.estimationMethod + "\" Ysize=\"" + this.Ysize + "\" calibrationFactor=\"" + this.calibrationFactor + "\" nfeatures=\"" + this.distributions.length + "\" featureGenerationClass=\"" + this.featureGenerator.getClass().getSimpleName() + "\"");
        xMLWriter.tag("Xsizes");
        for (int n : this.Xsizes) {
            xMLWriter.rawXML(n + " ");
        }
        xMLWriter.rawXML("\n");
        xMLWriter.tag("/Xsizes");
        xMLWriter.tag("priorDistribution");
        for (double d : this.prior_distribution) {
            xMLWriter.rawXML(d + " ");
        }
        xMLWriter.rawXML("\n");
        xMLWriter.tag("/priorDistribution");
        if (this.selectedFeatures != null) {
            xMLWriter.tag("selectedFeatures");
            for (boolean bl : this.selectedFeatures) {
                xMLWriter.rawXML(bl + " ");
            }
            xMLWriter.rawXML("\n");
            xMLWriter.tag("/selectedFeatures");
        }
        boolean bl = false;
        while (var2_7 < this.distributions.length) {
            this.distributions[var2_7].save(xMLWriter);
            ++var2_7;
        }
        xMLWriter.tag("/SimpleNaiveBayes");
        xMLWriter.flush();
    }

    /**
     * Restores a model from XML.
     *
     * The superclass is constructed with a null feature generator; {@code load}
     * then sets it from the element's {@code featureGenerationClass} attribute.
     *
     * @param element root element of a saved model
     * @param unitTypeTable forwarded to the {@code BayesianModel} constructor
     * @param string forwarded to the {@code BayesianModel} constructor
     * @throws Exception if {@code load} rejects the element
     */
    public CalibratedNaiveBayes(Element element, UnitTypeTable unitTypeTable, String string) throws Exception {
        super(unitTypeTable, null, string);
        load(element);
    }

    /**
     * Restores the model state from an element written by {@code save}.
     *
     * Reads the scalar settings from attributes, then the {@code Xsizes},
     * {@code priorDistribution} and optional {@code selectedFeatures} children, and
     * finally one {@code DiscreteCPD} child per feature. Number lists are split on
     * any run of whitespace. If the {@code featureGenerationClass} attribute matches
     * none of the known generators, the current feature generator is left unchanged.
     *
     * @param element root element; its name must equal this class's simple name
     * @throws Exception if the root element has a different name
     */
    @Override
    public void load(Element element) throws Exception {
        final String expectedTag = this.getClass().getSimpleName();
        if (!element.getName().equals(expectedTag)) {
            throw new Exception("Head tag is not '" + expectedTag + "'!");
        }

        String generatorClass = element.getAttributeValue("featureGenerationClass");
        if (generatorClass.contains("FeatureGeneratorEmpty")) {
            this.featureGenerator = new FeatureGeneratorEmpty();
        } else if (generatorClass.contains("FeatureGeneratorSimple")) {
            this.featureGenerator = new FeatureGeneratorSimple();
        } else if (generatorClass.contains("FeatureGeneratorComplex")) {
            this.featureGenerator = new FeatureGeneratorComplex();
        }

        this.estimationMethod = Integer.parseInt(element.getAttributeValue("estimationMethod"));
        this.Ysize = Integer.parseInt(element.getAttributeValue("Ysize"));
        final int featureCount = Integer.parseInt(element.getAttributeValue("nfeatures"));
        this.calibrationFactor = Double.parseDouble(element.getAttributeValue("calibrationFactor"));

        String[] sizeTokens = element.getChild("Xsizes").getTextTrim().split("\\s+");
        this.Xsizes = new int[featureCount];
        for (int j = 0; j < featureCount; ++j) {
            this.Xsizes[j] = Integer.parseInt(sizeTokens[j]);
        }

        String[] priorTokens = element.getChild("priorDistribution").getTextTrim().split("\\s+");
        this.prior_distribution = new double[this.Ysize];
        for (int j = 0; j < this.Ysize; ++j) {
            this.prior_distribution[j] = Double.parseDouble(priorTokens[j]);
        }

        // The mask is optional: save() only writes it when one is set.
        Element maskElement = element.getChild("selectedFeatures");
        if (maskElement != null) {
            String[] maskTokens = maskElement.getTextTrim().split("\\s+");
            this.selectedFeatures = new boolean[featureCount];
            for (int j = 0; j < featureCount; ++j) {
                this.selectedFeatures[j] = Boolean.parseBoolean(maskTokens[j]);
            }
        } else {
            this.selectedFeatures = null;
        }

        this.distributions = new DiscreteCPD[featureCount];
        List<?> tableElements = element.getChildren("DiscreteCPD");
        for (int j = 0; j < featureCount; ++j) {
            this.distributions[j] = new DiscreteCPD((Element)tableElements.get(j));
        }
    }

    /**
     * Selects features greedily, scoring each candidate mask by 10-fold
     * cross-validation.
     *
     * Starts with no features selected. Each pass tries adding every unselected
     * feature to the mask held at the start of the pass and keeps the candidate
     * with the highest score, provided it beats the best score so far. Passes
     * repeat until one brings no improvement. Every evaluated mask and its score
     * is printed.
     *
     * @param list feature vectors, one per instance
     * @param list2 label of each instance, parallel to {@code list}
     * @param list3 training instances, parallel to {@code list}
     * @throws Exception if cross-validation fails
     */
    @Override
    public void featureSelectionByCrossValidation(List<int[]> list, List<Integer> list2, List<TrainingInstance> list3) throws Exception {
        final int featureCount = this.distributions.length;
        System.out.println("featureSelectionByCrossValidation " + list.size());

        // A new boolean[] is all-false: the baseline uses no features at all.
        boolean[] current = new boolean[featureCount];
        this.selectedFeatures = current;
        double bestScore = this.crossValidatedScore(list, list2, list3);
        System.out.println("  loglikelihood with " + Arrays.toString(this.selectedFeatures) + ": " + bestScore);

        boolean improved = true;
        while (improved) {
            improved = false;
            boolean[] bestThisPass = current;
            for (int feature = 0; feature < featureCount; ++feature) {
                if (current[feature]) {
                    continue;
                }
                boolean[] candidate = Arrays.copyOf(current, featureCount);
                candidate[feature] = true;
                this.selectedFeatures = candidate;
                double score = this.crossValidatedScore(list, list2, list3);
                System.out.println("  loglikelihood with " + Arrays.toString(this.selectedFeatures) + ": " + score);
                if (score > bestScore) {
                    bestThisPass = candidate;
                    bestScore = score;
                    improved = true;
                }
            }
            current = bestThisPass;
        }

        this.selectedFeatures = current;
        System.out.println("Selected features: " + Arrays.toString(this.selectedFeatures));
    }

    /** Runs 10-fold cross-validation with the current feature mask and returns the first element of the result pair. */
    private double crossValidatedScore(List<int[]> list, List<Integer> list2, List<TrainingInstance> list3) throws Exception {
        return (Double)FeatureSelection.crossValidation((BayesianModel)this, list, list2, list3, (List<UnitAction>)this.allPossibleActions, 10).m_a;
    }
}

