/*
 * Decompiled with CFR 0.152.
 */
package ai.machinelearning.bayes;

import ai.machinelearning.bayes.TrainingInstance;
import ai.machinelearning.bayes.featuregeneration.FeatureGenerator;
import ai.stochastic.UnitActionProbabilityDistribution;
import java.util.ArrayList;
import java.util.List;
import org.jdom.Element;
import rts.GameState;
import rts.UnitAction;
import rts.units.Unit;
import rts.units.UnitType;
import rts.units.UnitTypeTable;
import util.Sampler;
import util.XMLWriter;

/**
 * Base class for Bayesian models that yield a probability distribution over
 * unit actions.
 *
 * <p>The distributions handled here are indexed by position in
 * {@link #allPossibleActions}, the fixed catalogue built by
 * {@link #generateAllPossibleUnitActions(UnitTypeTable)}. Subclasses supply
 * the actual model (training, prediction from a feature vector, persistence).
 */
public abstract class BayesianModel
extends UnitActionProbabilityDistribution {
    // NOTE(review): the three constants below are not referenced in this class;
    // presumably they are consumed by subclasses to select/parameterise the
    // probability estimation method - confirm against the subclasses.
    public static final int ESTIMATION_COUNTS = 1;
    public static final int ESTIMATION_LAPLACE = 2;
    public static final double laplaceBeta = 1.0;

    /** Catalogue of actions; a predicted distribution is indexed by position in this list. */
    protected List<UnitAction> allPossibleActions;
    /** Turns a {@link TrainingInstance} into the integer feature vector fed to the model. */
    protected FeatureGenerator featureGenerator;
    /** Name of this model; returned by {@link #toString()}. */
    protected String name;

    /**
     * Creates a model and builds its action catalogue from the given unit type table.
     *
     * @param unitTypeTable unit types, passed to the superclass and used to build the action catalogue
     * @param featureGenerator generator used by the {@code predictDistribution} overloads taking a unit
     * @param string name of the model
     */
    public BayesianModel(UnitTypeTable unitTypeTable, FeatureGenerator featureGenerator, String string) {
        super(unitTypeTable);
        this.allPossibleActions = BayesianModel.generateAllPossibleUnitActions(unitTypeTable);
        this.featureGenerator = featureGenerator;
        this.name = string;
    }

    /** Returns a copy of this model; the copy semantics are defined by the subclass. */
    public abstract Object clone();

    /** Discards whatever the model has learned; details are defined by the subclass. */
    public abstract void clearTraining();

    // NOTE(review): in the methods below the three lists are presumably parallel
    // (feature vectors, labels, source instances) - confirm against implementations.

    /** Trains the model from the given data. */
    public abstract void train(List<int[]> var1, List<Integer> var2, List<TrainingInstance> var3) throws Exception;

    /**
     * Hook for probability calibration. The default implementation does nothing;
     * subclasses may override it.
     */
    public void calibrateProbabilities(List<int[]> list, List<Integer> list2, List<TrainingInstance> list3) throws Exception {
    }

    /** Performs feature selection by cross validation; defined by the subclass. */
    public abstract void featureSelectionByCrossValidation(List<int[]> var1, List<Integer> var2, List<TrainingInstance> var3) throws Exception;

    /** Performs feature selection by gain ratio; defined by the subclass. */
    public abstract void featureSelectionByGainRatio(List<int[]> var1, List<Integer> var2, double var3);

    /**
     * Predicts the distribution for a unit in a game state, without restricting
     * it to any particular set of actions.
     *
     * @param unit the unit whose ID identifies the training instance
     * @param gameState the state the instance is built from
     * @return the result of {@link #predictDistribution(int[], TrainingInstance)}
     */
    public double[] predictDistribution(Unit unit, GameState gameState) throws Exception {
        TrainingInstance instance = new TrainingInstance(gameState, unit.getID(), null);
        int[] features = this.featureGenerator.generateFeaturesAsArray(instance);
        return this.predictDistribution(features, instance);
    }

    /**
     * Predicts the distribution for a unit and restricts it to the given actions.
     *
     * @param unit the unit whose ID identifies the training instance
     * @param gameState the state the instance is built from
     * @param list the actions to restrict the prediction to
     * @return one probability per element of {@code list}, see
     *         {@link #filterByPossibleActions(double[], Unit, List)}
     */
    @Override
    public double[] predictDistribution(Unit unit, GameState gameState, List<UnitAction> list) throws Exception {
        TrainingInstance instance = new TrainingInstance(gameState, unit.getID(), null);
        int[] features = this.featureGenerator.generateFeaturesAsArray(instance);
        double[] prediction = this.predictDistribution(features, instance);
        return this.filterByPossibleActions(prediction, unit, list);
    }

    /**
     * Predicts a distribution from a feature vector. The result is indexed by
     * position in {@link #allPossibleActions} when used by
     * {@link #filterByPossibleActions(double[], Unit, List)}.
     */
    public abstract double[] predictDistribution(int[] var1, TrainingInstance var2);

    /**
     * Returns the index of the largest entry of the predicted distribution.
     * Ties resolve to the lowest index; an empty distribution yields 0.
     */
    public int predictMax(int[] nArray, TrainingInstance trainingInstance) {
        double[] prediction = this.predictDistribution(nArray, trainingInstance);
        int best = 0;
        for (int i = 1; i < prediction.length; ++i) {
            if (prediction[i] > prediction[best]) {
                best = i;
            }
        }
        return best;
    }

    /**
     * Returns the index chosen by {@code Sampler.weighted} for the predicted
     * distribution (presumably sampled proportionally to the weights - confirm).
     */
    public int predictSample(int[] nArray, TrainingInstance trainingInstance) throws Exception {
        double[] prediction = this.predictDistribution(nArray, trainingInstance);
        return Sampler.weighted(prediction);
    }

    /**
     * Restricts a distribution to the given indexes and renormalises it.
     *
     * <p>Entries whose index is not listed become 0. Listed entries are divided
     * by the total mass of the listed entries; if that mass is not positive the
     * listed entries share the probability uniformly (the same fallback used by
     * {@link #filterByPossibleActions(double[], Unit, List)}) instead of
     * producing NaN. Listed indexes outside the array, and {@code null}
     * elements, are ignored.
     *
     * @param dArray the distribution to filter; not modified
     * @param list the indexes that remain possible
     * @return a new array of the same length as {@code dArray}
     */
    public double[] filterByPossibleActionIndexes(double[] dArray, List<Integer> list) {
        int size = dArray.length;
        // Build a membership mask once rather than scanning the list for every index.
        boolean[] allowed = new boolean[size];
        int allowedCount = 0;
        for (Integer index : list) {
            if (index == null || index < 0 || index >= size || allowed[index]) {
                continue;
            }
            allowed[index] = true;
            ++allowedCount;
        }
        double total = 0.0;
        for (int i = 0; i < size; ++i) {
            if (allowed[i]) {
                total += dArray[i];
            }
        }
        double[] filtered = new double[size];
        if (allowedCount == 0) {
            return filtered;
        }
        boolean hasMass = total > 0.0;
        double uniform = 1.0 / (double)allowedCount;
        for (int i = 0; i < size; ++i) {
            if (allowed[i]) {
                filtered[i] = hasMass ? dArray[i] / total : uniform;
            }
        }
        return filtered;
    }

    /**
     * Projects a distribution over {@link #allPossibleActions} onto a concrete
     * list of actions and renormalises it.
     *
     * <p>Actions of type 5 are first rewritten with their location relative to
     * the unit, which is how type-5 actions are stored in the catalogue.
     * An action that is not present in the catalogue gets probability 0. If the
     * collected mass is not positive, the result is uniform over {@code list}.
     *
     * @param dArray distribution indexed by position in {@link #allPossibleActions}
     * @param unit the unit the actions belong to; its position is used for type-5 actions
     * @param list the actions to keep
     * @return a new array with one probability per element of {@code list}
     */
    public double[] filterByPossibleActions(double[] dArray, Unit unit, List<UnitAction> list) {
        int count = list.size();
        double[] filtered = new double[count];
        double total = 0.0;
        for (int i = 0; i < count; ++i) {
            UnitAction action = list.get(i);
            if (action.getType() == 5) {
                // Catalogue entries of type 5 hold offsets, so convert the absolute location.
                action = new UnitAction(5, action.getLocationX() - unit.getX(), action.getLocationY() - unit.getY());
            }
            int index = this.allPossibleActions.indexOf(action);
            if (index < 0) {
                // Not in the catalogue: indexOf returned -1, which must not be used as an array index.
                continue;
            }
            filtered[i] = dArray[index];
            total += dArray[index];
        }
        if (total > 0.0) {
            for (int i = 0; i < count; ++i) {
                filtered[i] = filtered[i] / total;
            }
        } else {
            for (int i = 0; i < count; ++i) {
                filtered[i] = 1.0 / (double)count;
            }
        }
        return filtered;
    }

    /**
     * Builds the action catalogue for a unit type table.
     *
     * <p>The order is significant because distributions are indexed by position
     * in this list: one action {@code (0, 10)}; then types 1, 2 and 3 for each of
     * the directions 0..3; then type 4 for every direction and unit type; then
     * type 5 for every non-zero offset whose squared length does not exceed the
     * squared maximum attack range (at least 1) found in the table.
     *
     * <p>NOTE(review): the numeric type codes presumably correspond to the
     * {@code UnitAction.TYPE_*} constants - confirm and replace the literals.
     *
     * @param unitTypeTable the unit types to enumerate
     * @return a new mutable list of actions
     */
    public static List<UnitAction> generateAllPossibleUnitActions(UnitTypeTable unitTypeTable) {
        List<UnitAction> actions = new ArrayList<UnitAction>();
        int[] directions = new int[]{0, 1, 2, 3};
        int maxAttackRange = 1;
        for (UnitType unitType : unitTypeTable.getUnitTypes()) {
            if (unitType.attackRange > maxAttackRange) {
                maxAttackRange = unitType.attackRange;
            }
        }
        actions.add(new UnitAction(0, 10));
        for (int direction : directions) {
            actions.add(new UnitAction(1, direction));
        }
        for (int direction : directions) {
            actions.add(new UnitAction(2, direction));
        }
        for (int direction : directions) {
            actions.add(new UnitAction(3, direction));
        }
        for (int direction : directions) {
            for (UnitType unitType : unitTypeTable.getUnitTypes()) {
                actions.add(new UnitAction(4, direction, unitType));
            }
        }
        int maxSquaredDistance = maxAttackRange * maxAttackRange;
        for (int dx = -maxAttackRange; dx <= maxAttackRange; ++dx) {
            for (int dy = -maxAttackRange; dy <= maxAttackRange; ++dy) {
                int squaredDistance = dx * dx + dy * dy;
                // Skip the zero offset and anything outside the attack radius.
                if (squaredDistance > 0 && squaredDistance <= maxSquaredDistance) {
                    actions.add(new UnitAction(5, dx, dy));
                }
            }
        }
        return actions;
    }

    /** Writes the model to the given XML writer; format is defined by the subclass. */
    public abstract void save(XMLWriter var1) throws Exception;

    /** Restores the model from the given XML element; format is defined by the subclass. */
    public abstract void load(Element var1) throws Exception;

    /** Returns the model name given at construction. */
    @Override
    public String toString() {
        return this.name;
    }
}

