/*
 * Decompiled with CFR 0.152.
 */
package ai.machinelearning.bayes;

import ai.machinelearning.bayes.BayesianModel;
import ai.machinelearning.bayes.DiscreteCPD;
import ai.machinelearning.bayes.FeatureSelection;
import ai.machinelearning.bayes.TrainingInstance;
import ai.machinelearning.bayes.featuregeneration.FeatureGenerator;
import ai.machinelearning.bayes.featuregeneration.FeatureGeneratorComplex;
import ai.machinelearning.bayes.featuregeneration.FeatureGeneratorEmpty;
import ai.machinelearning.bayes.featuregeneration.FeatureGeneratorSimple;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import org.jdom.Element;
import rts.UnitAction;
import rts.units.Unit;
import rts.units.UnitTypeTable;
import util.XMLWriter;

/**
 * Bayesian action model that combines a naive-Bayes product over per-feature tables
 * with "interdependence" factors: each candidate action is also scored by how often its
 * type was chosen while each of the other currently-available action types was
 * available too (and, when {@link #consider_individual_actions} is set, the same for
 * individual actions).
 *
 * <p>{@link #estimationMethod} {@code == 1} uses unsmoothed relative frequencies; any other
 * value uses add-one (Laplace) smoothing. The product of all factors for an action is
 * raised to {@code 1 / ((1 - c) + nFactors * c)}, where {@code c} is the calibration
 * factor, and the result is normalized over all {@link #Ysize} actions.
 */
public class ActionInterdependenceModel
extends BayesianModel {
    /** {@link #estimationMethod} value selecting unsmoothed relative frequencies. */
    private static final int ESTIMATION_MAXIMUM_LIKELIHOOD = 1;
    /**
     * Size of the action-type dimension of the type tables. Assumed to exceed every value
     * returned by {@code UnitAction.getType()} -- TODO confirm against rts.UnitAction.
     */
    private static final int NUMBER_OF_ACTION_TYPES = 6;
    /**
     * Action type whose location is converted to coordinates relative to the acting unit
     * before it is looked up in {@code allPossibleActions} during calibration.
     */
    private static final int RELATIVE_LOCATION_ACTION_TYPE = 5;
    /** Calibration factors tried are {@code k / CALIBRATION_STEPS} for {@code k = 0..CALIBRATION_STEPS}. */
    private static final int CALIBRATION_STEPS = 20;
    /** Last argument of {@code FeatureSelection.crossValidation} (presumably the number of folds). */
    private static final int CROSS_VALIDATION_FOLDS = 10;

    /** 1 = unsmoothed relative frequencies; anything else = add-one (Laplace) smoothing. */
    int estimationMethod = ESTIMATION_MAXIMUM_LIKELIHOOD;
    /** Blends the tempering exponent between 1 (c = 0) and 1 / nFactors (c = 1). */
    double calibrationFactor = 0.0;
    /** Prior over actions (length {@link #Ysize}); null until trained or loaded. */
    double[] prior_distribution;
    /** One table per feature, fed with (action, feature value) observations. */
    DiscreteCPD[] distributions;
    /** Mask of features used at prediction time; null means "use every feature". */
    boolean[] selectedFeatures;
    /** Number of actions; action indexes refer to positions in {@code allPossibleActions}. */
    int Ysize = 0;
    /** Number of action types (see {@link #NUMBER_OF_ACTION_TYPES}). */
    int YtypeSize = 0;
    /** Number of distinct values of each feature. */
    int[] Xsizes;
    /** [a] = number of training instances in which action a was possible. */
    int[] action_allowed_counts_prior;
    /** [a][b] = number of training instances in which a was chosen and b was possible. */
    int[][] selected_allowed_action_prior;
    /** Type of each action, indexed like {@code allPossibleActions}. */
    List<Integer> allPossibleActionsTypes;
    /** [t] = number of training instances in which some action of type t was possible. */
    int[] actiontypes_allowed_counts_prior;
    /** [t][u] = number of training instances whose chosen action had type t while type u was possible. */
    int[][] selected_allowed_actiontype_prior;
    /** Whether individual-action interdependence factors are used at prediction time. */
    boolean consider_individual_actions = false;
    /** Whether action-type interdependence factors are used at prediction time. */
    boolean consider_action_types = true;

    /**
     * Creates an untrained model.
     *
     * @param a_Xsizes number of values of each feature
     * @param a_Ysize number of actions
     * @param estimation 1 for unsmoothed frequencies, anything else for Laplace smoothing
     * @param a_correctionFactor initial calibration factor
     * @param utt unit type table, forwarded to {@link BayesianModel}
     * @param fg feature generator, forwarded to {@link BayesianModel}
     * @param a_name model name, forwarded to {@link BayesianModel}
     */
    public ActionInterdependenceModel(int[] a_Xsizes, int a_Ysize, int estimation, double a_correctionFactor, UnitTypeTable utt, FeatureGenerator fg, String a_name) {
        super(utt, fg, a_name);
        this.Ysize = a_Ysize;
        this.Xsizes = a_Xsizes;
        this.estimationMethod = estimation;
        this.calibrationFactor = a_correctionFactor;
        this.buildActionTypeTable();
        this.YtypeSize = NUMBER_OF_ACTION_TYPES;
        this.clearTraining();
    }

    /**
     * Returns a new, untrained model with the same sizes, estimation method, calibration
     * factor, unit type table, feature generator and name.
     * NOTE(review): {@link #selectedFeatures} and the consider_* flags are not copied --
     * confirm that callers (e.g. cross-validation) do not rely on them being carried over.
     */
    @Override
    public Object clone() {
        return new ActionInterdependenceModel(this.Xsizes, this.Ysize, this.estimationMethod, this.calibrationFactor, this.utt, this.featureGenerator, this.name);
    }

    /**
     * Discards everything learned by {@link #train}: priors, interdependence counts and the
     * per-feature tables. The feature mask in {@link #selectedFeatures} is kept.
     */
    @Override
    public void clearTraining() {
        // Also drop the prior so an untrained model never mixes stale priors with empty tables.
        this.prior_distribution = null;
        this.action_allowed_counts_prior = null;
        this.selected_allowed_action_prior = null;
        this.actiontypes_allowed_counts_prior = null;
        this.selected_allowed_actiontype_prior = null;
        if (this.Xsizes == null) {
            this.distributions = null;
            return;
        }
        this.distributions = new DiscreteCPD[this.Xsizes.length];
        for (int f = 0; f < this.Xsizes.length; ++f) {
            this.distributions[f] = new DiscreteCPD(this.Ysize, this.Xsizes[f]);
        }
    }

    /**
     * Estimates priors, per-feature tables and interdependence counts from a training set.
     *
     * @param x_l feature vectors, one per instance
     * @param y_l index of the chosen action for each instance
     * @param i_l training instances, used to obtain the set of possible actions
     * @throws Exception propagated from the training instances
     */
    @Override
    public void train(List<int[]> x_l, List<Integer> y_l, List<TrainingInstance> i_l) throws Exception {
        int nfeatures = this.distributions.length;
        int ninstances = x_l.size();
        this.prior_distribution = new double[this.Ysize];
        this.action_allowed_counts_prior = new int[this.Ysize];
        this.selected_allowed_action_prior = new int[this.Ysize][this.Ysize];
        this.actiontypes_allowed_counts_prior = new int[this.YtypeSize];
        this.selected_allowed_actiontype_prior = new int[this.YtypeSize][this.YtypeSize];
        for (int i = 0; i < ninstances; ++i) {
            int[] x = x_l.get(i);
            int y = y_l.get(i);
            this.prior_distribution[y] += 1.0;
            for (int f = 0; f < nfeatures; ++f) {
                this.distributions[f].addObservation(y, x[f]);
            }
            List<Integer> allowed = i_l.get(i).getPossibleActions(this.allPossibleActions);
            this.countCooccurrences(allowed, y, this.action_allowed_counts_prior, this.selected_allowed_action_prior);
            this.countCooccurrences(this.distinctActionTypes(allowed), this.allPossibleActionsTypes.get(y), this.actiontypes_allowed_counts_prior, this.selected_allowed_actiontype_prior);
        }
        if (this.estimationMethod == ESTIMATION_MAXIMUM_LIKELIHOOD) {
            // Guard against 0/0 on an empty training set: all-zero priors make
            // predictDistribution fall back to the uniform distribution instead of NaN.
            double total = Math.max(1, ninstances);
            for (int a = 0; a < this.Ysize; ++a) {
                this.prior_distribution[a] = this.prior_distribution[a] / total;
            }
        } else {
            for (int a = 0; a < this.Ysize; ++a) {
                this.prior_distribution[a] = (this.prior_distribution[a] + 1.0) / (double)(ninstances + this.Ysize);
            }
        }
    }

    /**
     * Increments co-occurrence counts for one instance: every possible item gets its
     * "allowed" count bumped, and the row of the chosen item is bumped for every possible
     * item (including itself).
     */
    private void countCooccurrences(List<Integer> possible, int chosen, int[] allowedCounts, int[][] chosenWithAllowed) {
        for (int item : possible) {
            allowedCounts[item]++;
            if (item != chosen) {
                continue;
            }
            for (int other : possible) {
                chosenWithAllowed[item][other]++;
            }
        }
    }

    /** Returns the distinct types of the given action indexes, in first-seen order. */
    private List<Integer> distinctActionTypes(List<Integer> actionIndexes) {
        List<Integer> types = new ArrayList<Integer>();
        for (int idx : actionIndexes) {
            Integer type = this.allPossibleActionsTypes.get(idx);
            if (!types.contains(type)) {
                types.add(type);
            }
        }
        return types;
    }

    /** Rebuilds {@link #allPossibleActionsTypes} from {@code allPossibleActions}. */
    private void buildActionTypeTable() {
        this.allPossibleActionsTypes = new ArrayList<Integer>();
        for (UnitAction ua : this.allPossibleActions) {
            this.allPossibleActionsTypes.add(ua.getType());
        }
    }

    /**
     * Chooses the calibration factor by trying {@code k / CALIBRATION_STEPS} for increasing k
     * and stopping at the first value that does not improve the log-likelihood of the
     * observed actions (this early stop assumes the likelihood is unimodal in the factor).
     * The chosen factor is printed and stored in {@link #calibrationFactor}.
     *
     * @throws Exception if a possible action is not in {@code allPossibleActions}, or if an
     *     observed action receives zero probability
     */
    @Override
    public void calibrateProbabilities(List<int[]> x_l, List<Integer> y_l, List<TrainingInstance> i_l) throws Exception {
        double best_c = 0.0;
        double best_ll = Double.NEGATIVE_INFINITY;
        // Integer steps avoid the drift of repeatedly adding 0.05 in floating point.
        for (int step = 0; step <= CALIBRATION_STEPS; ++step) {
            double c = (double)step / (double)CALIBRATION_STEPS;
            this.calibrationFactor = c;
            double loglikelihood = this.calibrationLogLikelihood(x_l, y_l, i_l);
            if (!(loglikelihood > best_ll)) {
                break;
            }
            best_c = c;
            best_ll = loglikelihood;
        }
        System.out.println("best calibration factor = " + best_c);
        this.calibrationFactor = best_c;
    }

    /**
     * Sum of log-probabilities of the observed actions under the current calibration factor,
     * skipping instances with at most one possible action or whose observed action is not
     * among the possible ones.
     */
    private double calibrationLogLikelihood(List<int[]> x_l, List<Integer> y_l, List<TrainingInstance> i_l) throws Exception {
        double loglikelihood = 0.0;
        for (int i = 0; i < x_l.size(); ++i) {
            TrainingInstance ti = i_l.get(i);
            Unit u = ti.u;
            List<UnitAction> possibleUnitActions = u.getUnitActions(ti.gs);
            ArrayList<Integer> possibleUnitActionIndexes = new ArrayList<Integer>();
            for (UnitAction ua : possibleUnitActions) {
                UnitAction key = ua;
                if (ua.getType() == RELATIVE_LOCATION_ACTION_TYPE) {
                    key = new UnitAction(RELATIVE_LOCATION_ACTION_TYPE, ua.getLocationX() - u.getX(), ua.getLocationY() - u.getY());
                }
                int idx = this.allPossibleActions.indexOf(key);
                if (idx < 0) {
                    throw new Exception("Unknown action: " + key);
                }
                possibleUnitActionIndexes.add(idx);
            }
            if (possibleUnitActions.size() <= 1) {
                continue;
            }
            double[] predicted = this.predictDistribution(x_l.get(i), ti);
            predicted = this.filterByPossibleActionIndexes(predicted, possibleUnitActionIndexes);
            int actual_y = y_l.get(i);
            if (!possibleUnitActionIndexes.contains(actual_y)) {
                continue;
            }
            double ll = Math.log(predicted[actual_y]);
            if (Double.isInfinite(ll)) {
                // Report to the caller instead of terminating the whole JVM.
                throw new Exception("Zero probability for observed action " + actual_y + " : " + this.allPossibleActions.get(actual_y)
                        + " (calibration factor " + this.calibrationFactor + "); distribution = " + Arrays.toString(predicted)
                        + ", possible actions = " + possibleUnitActionIndexes);
            }
            loglikelihood += ll;
        }
        return loglikelihood;
    }

    /**
     * Keeps the {@code fractionOfFeaturesToKeep * nfeatures} features (rounded up, capped at
     * the number of features) with the highest gain ratio.
     */
    @Override
    public void featureSelectionByGainRatio(List<int[]> x_l, List<Integer> y_l, double fractionOfFeaturesToKeep) {
        int nfeatures = this.distributions.length;
        this.selectedFeatures = new boolean[nfeatures];
        final List<Double> featureGR = new ArrayList<Double>(nfeatures);
        List<Integer> featureIndexes = new ArrayList<Integer>(nfeatures);
        for (int f = 0; f < nfeatures; ++f) {
            featureIndexes.add(f);
            featureGR.add(FeatureSelection.featureGainRatio(x_l, y_l, f));
        }
        // Highest gain ratio first.
        featureIndexes.sort((a, b) -> Double.compare(featureGR.get(b), featureGR.get(a)));
        // Bound by nfeatures so a fraction above 1 cannot index past the list.
        for (int rank = 0; rank < nfeatures && rank < fractionOfFeaturesToKeep * nfeatures; ++rank) {
            this.selectedFeatures[featureIndexes.get(rank)] = true;
        }
    }

    /** Predicts the action distribution using the stored calibration factor. */
    @Override
    public double[] predictDistribution(int[] x, TrainingInstance ti) {
        return this.predictDistribution(x, ti, this.calibrationFactor);
    }

    /**
     * Predicts a distribution over all {@link #Ysize} actions. Actions that are not possible
     * in {@code ti} get probability 0 unless every score is 0, in which case the result is
     * uniform over all actions.
     *
     * @param x feature vector
     * @param ti instance providing the possible actions
     * @param correction calibration factor controlling the tempering exponent
     * @return normalized distribution of length {@link #Ysize}
     */
    public double[] predictDistribution(int[] x, TrainingInstance ti, double correction) {
        List<Integer> allowed = ti.getPossibleActions(this.allPossibleActions);
        double[] d = new double[this.Ysize];
        double n_factors = 1.0;
        for (int a : allowed) {
            d[a] = this.prior_distribution == null ? 1.0 : this.prior_distribution[a];
        }
        for (int f = 0; f < x.length; ++f) {
            if (this.selectedFeatures != null && !this.selectedFeatures[f]) {
                continue;
            }
            n_factors += 1.0;
            for (int a : allowed) {
                d[a] *= this.featureLikelihood(f, a, x[f]);
            }
        }
        if (this.consider_action_types && this.selected_allowed_actiontype_prior != null) {
            List<Integer> types = this.distinctActionTypes(allowed);
            n_factors += (double)(types.size() - 1);
            for (int a : allowed) {
                int aType = this.allPossibleActionsTypes.get(a);
                for (int t : types) {
                    if (t == aType) {
                        continue;
                    }
                    d[a] *= this.interdependenceFactor(this.selected_allowed_actiontype_prior[aType][t], this.actiontypes_allowed_counts_prior[aType]);
                }
            }
        }
        if (this.consider_individual_actions && this.selected_allowed_action_prior != null) {
            n_factors += (double)(allowed.size() - 1);
            for (int a : allowed) {
                for (int b : allowed) {
                    if (b == a) {
                        continue;
                    }
                    d[a] *= this.interdependenceFactor(this.selected_allowed_action_prior[a][b], this.action_allowed_counts_prior[a]);
                }
            }
        }
        double exponent = 1.0 / ((1.0 - correction) + n_factors * correction);
        double accum = 0.0;
        for (int a = 0; a < this.Ysize; ++a) {
            d[a] = Math.pow(d[a], exponent);
            accum += d[a];
        }
        if (accum <= 0.0) {
            Arrays.fill(d, 1.0 / (double)this.Ysize);
        } else {
            for (int a = 0; a < this.Ysize; ++a) {
                d[a] = d[a] / accum;
            }
        }
        return d;
    }

    /**
     * Likelihood factor of {@code value} for feature {@code feature} given {@code action}.
     * With Laplace smoothing, a value beyond the table falls back to {@code 1 / Ysize}.
     */
    private double featureLikelihood(int feature, int action, int value) {
        if (this.estimationMethod == ESTIMATION_MAXIMUM_LIKELIHOOD) {
            return this.distributions[feature].distribution(action)[value];
        }
        double[] smoothed = this.distributions[feature].distributionLaplace(action, 1.0);
        return smoothed.length > value ? smoothed[value] : 1.0 / (double)this.Ysize;
    }

    /**
     * Interdependence factor {@code chosenWithOther / allowedCount}, Laplace-smoothed as
     * {@code (chosenWithOther + 1) / (allowedCount + 2)} unless maximum likelihood is used.
     */
    private double interdependenceFactor(int chosenWithOther, int allowedCount) {
        if (this.estimationMethod == ESTIMATION_MAXIMUM_LIKELIHOOD) {
            // No evidence at all: use a neutral factor instead of 0/0 = NaN, which would
            // otherwise poison the normalization of the whole distribution.
            if (allowedCount == 0) {
                return 1.0;
            }
            return (double)chosenWithOther / (double)allowedCount;
        }
        return (double)(chosenWithOther + 1) / (double)(allowedCount + 2);
    }

    /** Writes the model as XML in the format read back by {@link #load(Element)}. */
    @Override
    public void save(XMLWriter w) throws Exception {
        w.tagWithAttributes(this.getClass().getSimpleName(), "estimationMethod=\"" + this.estimationMethod + "\" Ysize=\"" + this.Ysize + "\" calibrationFactor=\"" + this.calibrationFactor + "\" nfeatures=\"" + this.distributions.length + "\" featureGenerationClass=\"" + this.featureGenerator.getClass().getSimpleName() + "\"");
        w.tag("Xsizes");
        writeRow(w, this.Xsizes);
        w.tag("/Xsizes");
        w.tag("priorDistribution");
        for (double p : this.prior_distribution) {
            w.rawXML(p + " ");
        }
        w.rawXML("\n");
        w.tag("/priorDistribution");
        if (this.selectedFeatures != null) {
            w.tag("selectedFeatures");
            for (boolean selected : this.selectedFeatures) {
                w.rawXML(selected + " ");
            }
            w.rawXML("\n");
            w.tag("/selectedFeatures");
        }
        w.tag("action_allowed_counts_prior");
        writeRow(w, this.action_allowed_counts_prior);
        w.tag("/action_allowed_counts_prior");
        w.tag("selected_action_pairs_prior");
        for (int[] row : this.selected_allowed_action_prior) {
            writeRow(w, row);
        }
        w.tag("/selected_action_pairs_prior");
        w.tag("actiontypes_allowed_counts_prior");
        writeRow(w, this.actiontypes_allowed_counts_prior);
        w.tag("/actiontypes_allowed_counts_prior");
        w.tag("selected_allowed_actiontype_prior");
        for (int[] row : this.selected_allowed_actiontype_prior) {
            writeRow(w, row);
        }
        w.tag("/selected_allowed_actiontype_prior");
        for (DiscreteCPD cpd : this.distributions) {
            cpd.save(w);
        }
        w.tag("/" + this.getClass().getSimpleName());
        w.flush();
    }

    /** Writes each value followed by a space, then a newline. */
    private static void writeRow(XMLWriter w, int[] values) throws Exception {
        for (int v : values) {
            w.rawXML(v + " ");
        }
        w.rawXML("\n");
    }

    /** Creates a model from XML previously produced by {@link #save(XMLWriter)}. */
    public ActionInterdependenceModel(Element e, UnitTypeTable utt, String a_name) throws Exception {
        super(utt, null, a_name);
        this.load(e);
    }

    /**
     * Restores the model from XML produced by {@link #save(XMLWriter)}. If the feature
     * generator class is not recognized, {@code featureGenerator} is left unchanged.
     *
     * @throws Exception if the root tag does not match this class
     */
    @Override
    public void load(Element e) throws Exception {
        if (!e.getName().equals(this.getClass().getSimpleName())) {
            throw new Exception("Head tag " + e.getName() + " is not '" + this.getClass().getSimpleName() + "'!");
        }
        this.buildActionTypeTable();
        String fgclass = e.getAttributeValue("featureGenerationClass");
        if (fgclass.contains("FeatureGeneratorEmpty")) {
            this.featureGenerator = new FeatureGeneratorEmpty();
        } else if (fgclass.contains("FeatureGeneratorSimple")) {
            this.featureGenerator = new FeatureGeneratorSimple();
        } else if (fgclass.contains("FeatureGeneratorComplex")) {
            this.featureGenerator = new FeatureGeneratorComplex();
        }
        this.YtypeSize = NUMBER_OF_ACTION_TYPES;
        this.estimationMethod = Integer.parseInt(e.getAttributeValue("estimationMethod"));
        this.Ysize = Integer.parseInt(e.getAttributeValue("Ysize"));
        this.calibrationFactor = Double.parseDouble(e.getAttributeValue("calibrationFactor"));
        int nfeatures = Integer.parseInt(e.getAttributeValue("nfeatures"));
        this.Xsizes = parseInts(tokensOf(e.getChild("Xsizes")), nfeatures);
        String[] priorTokens = tokensOf(e.getChild("priorDistribution"));
        this.prior_distribution = new double[this.Ysize];
        for (int a = 0; a < this.Ysize; ++a) {
            this.prior_distribution[a] = Double.parseDouble(priorTokens[a]);
        }
        Element sf_xml = e.getChild("selectedFeatures");
        if (sf_xml != null) {
            String[] maskTokens = tokensOf(sf_xml);
            this.selectedFeatures = new boolean[nfeatures];
            for (int f = 0; f < nfeatures; ++f) {
                this.selectedFeatures[f] = Boolean.parseBoolean(maskTokens[f]);
            }
        } else {
            this.selectedFeatures = null;
        }
        this.action_allowed_counts_prior = parseInts(tokensOf(e.getChild("action_allowed_counts_prior")), this.Ysize);
        this.selected_allowed_action_prior = parseSquareMatrix(tokensOf(e.getChild("selected_action_pairs_prior")), this.Ysize);
        this.actiontypes_allowed_counts_prior = parseInts(tokensOf(e.getChild("actiontypes_allowed_counts_prior")), this.YtypeSize);
        this.selected_allowed_actiontype_prior = parseSquareMatrix(tokensOf(e.getChild("selected_allowed_actiontype_prior")), this.YtypeSize);
        this.distributions = new DiscreteCPD[nfeatures];
        List<?> cpdElements = e.getChildren("DiscreteCPD");
        for (int f = 0; f < nfeatures; ++f) {
            this.distributions[f] = new DiscreteCPD((Element)cpdElements.get(f));
        }
    }

    /** Splits an element's trimmed text on any run of whitespace (spaces and newlines). */
    private static String[] tokensOf(Element element) {
        return element.getTextTrim().split("\\s+");
    }

    /** Parses the first {@code count} tokens as ints. */
    private static int[] parseInts(String[] tokens, int count) {
        int[] values = new int[count];
        for (int k = 0; k < count; ++k) {
            values[k] = Integer.parseInt(tokens[k]);
        }
        return values;
    }

    /** Parses {@code size * size} tokens, in row-major order, into a square int matrix. */
    private static int[][] parseSquareMatrix(String[] tokens, int size) {
        int[][] matrix = new int[size][size];
        int k = 0;
        for (int row = 0; row < size; ++row) {
            for (int col = 0; col < size; ++col) {
                matrix[row][col] = Integer.parseInt(tokens[k++]);
            }
        }
        return matrix;
    }

    /**
     * Greedy forward feature selection. Starting from no features, each round tries adding
     * every unselected feature and keeps the one with the best cross-validation score,
     * provided it beats the best score so far. Rounds repeat until nothing improves.
     */
    @Override
    public void featureSelectionByCrossValidation(List<int[]> x_l, List<Integer> y_l, List<TrainingInstance> i_l) throws Exception {
        int nfeatures = this.distributions.length;
        System.out.println("featureSelectionByCrossValidation " + x_l.size());
        boolean[] bestSelection = new boolean[nfeatures];
        this.selectedFeatures = bestSelection;
        double best_score = this.crossValidationScore(x_l, y_l, i_l);
        System.out.println("  loglikelihood with " + Arrays.toString(this.selectedFeatures) + ": " + best_score);
        boolean change;
        do {
            change = false;
            boolean[] bestThisRound = bestSelection;
            for (int f = 0; f < nfeatures; ++f) {
                if (bestSelection[f]) {
                    continue;
                }
                boolean[] candidate = Arrays.copyOf(bestSelection, nfeatures);
                candidate[f] = true;
                this.selectedFeatures = candidate;
                double score = this.crossValidationScore(x_l, y_l, i_l);
                System.out.println("  loglikelihood with " + Arrays.toString(this.selectedFeatures) + ": " + score);
                if (!(score > best_score)) {
                    continue;
                }
                bestThisRound = candidate;
                best_score = score;
                change = true;
            }
            bestSelection = bestThisRound;
        } while (change);
        this.selectedFeatures = bestSelection;
        System.out.println("Selected features: " + Arrays.toString(this.selectedFeatures));
    }

    /** Cross-validation score of this model with the current {@link #selectedFeatures}. */
    private double crossValidationScore(List<int[]> x_l, List<Integer> y_l, List<TrainingInstance> i_l) throws Exception {
        return (Double)FeatureSelection.crossValidation((BayesianModel)this, x_l, y_l, i_l, (List<UnitAction>)this.allPossibleActions, CROSS_VALIDATION_FOLDS).m_a;
    }
}

