/*
 * Decompiled with CFR 0.152.
 */
package ai.machinelearning.bayes;

import ai.machinelearning.bayes.TrainingInstance;
import ai.machinelearning.bayes.featuregeneration.FeatureGenerator;
import ai.stochastic.UnitActionProbabilityDistribution;
import java.util.ArrayList;
import java.util.List;
import org.jdom.Element;
import rts.GameState;
import rts.UnitAction;
import rts.units.Unit;
import rts.units.UnitType;
import rts.units.UnitTypeTable;
import util.Sampler;
import util.XMLWriter;

/**
 * Base class for models that predict a probability distribution over a fixed, precomputed list of
 * unit actions ({@link #allPossibleActions}) from integer feature vectors produced by a
 * {@link FeatureGenerator}.
 *
 * <p>Subclasses supply the actual training, feature selection, prediction and (de)serialization.
 */
public abstract class BayesianModel
extends UnitActionProbabilityDistribution {
    /** Estimation-mode constant; not read in this class (presumably used by subclasses — TODO confirm). */
    public static final int ESTIMATION_COUNTS = 1;
    /** Estimation-mode constant; not read in this class (presumably used by subclasses — TODO confirm). */
    public static final int ESTIMATION_LAPLACE = 2;
    /** Laplace smoothing constant; not read in this class (presumably used by subclasses — TODO confirm). */
    public static final double laplaceBeta = 1.0;

    /** Every action the model can predict; a prediction's index i refers to element i of this list. */
    protected List<UnitAction> allPossibleActions;
    /** Converts a {@link TrainingInstance} into the integer feature vector the model consumes. */
    protected FeatureGenerator featureGenerator;
    /** Human-readable model name, returned by {@link #toString()}. */
    protected String name;

    /**
     * Creates a model whose action space is derived from {@code utt}.
     *
     * @param utt    unit type table used to enumerate all possible unit actions
     * @param fg     feature generator used to turn game states into feature vectors
     * @param a_name name reported by {@link #toString()}
     */
    public BayesianModel(UnitTypeTable utt, FeatureGenerator fg, String a_name) {
        super(utt);
        this.allPossibleActions = BayesianModel.generateAllPossibleUnitActions(utt);
        this.featureGenerator = fg;
        this.name = a_name;
    }

    /** Returns a copy of this model. */
    @Override
    public abstract Object clone();

    /** Discards whatever has been learned so far. */
    public abstract void clearTraining();

    /**
     * Trains the model.
     *
     * @param x_l feature vectors, one per training example
     * @param y_l target action indexes (presumably into {@link #allPossibleActions} — TODO confirm)
     * @param i_l the training instances the features were generated from
     * @throws Exception if training fails
     */
    public abstract void train(List<int[]> x_l, List<Integer> y_l, List<TrainingInstance> i_l) throws Exception;

    /**
     * Optional post-training calibration hook. The default implementation does nothing.
     *
     * @param x_l feature vectors
     * @param y_l target action indexes
     * @param i_l training instances
     * @throws Exception if calibration fails
     */
    public void calibrateProbabilities(List<int[]> x_l, List<Integer> y_l, List<TrainingInstance> i_l) throws Exception {
    }

    /**
     * Selects features using cross-validation.
     *
     * @param x_l feature vectors
     * @param y_l target action indexes
     * @param i_l training instances
     * @throws Exception if selection fails
     */
    public abstract void featureSelectionByCrossValidation(List<int[]> x_l, List<Integer> y_l, List<TrainingInstance> i_l) throws Exception;

    /**
     * Selects features using their gain ratio.
     *
     * @param x_l                feature vectors
     * @param y_l                target action indexes
     * @param selectionParameter numeric selection parameter; its meaning is defined by the subclass
     *                           (presumably a threshold or fraction — TODO confirm)
     */
    public abstract void featureSelectionByGainRatio(List<int[]> x_l, List<Integer> y_l, double selectionParameter);

    /**
     * Predicts a distribution over {@link #allPossibleActions} for unit {@code u} in {@code gs}.
     *
     * @param u  the unit whose action is predicted
     * @param gs the current game state
     * @return one probability per entry of {@link #allPossibleActions}
     * @throws Exception if feature generation fails
     */
    public double[] predictDistribution(Unit u, GameState gs) throws Exception {
        // NOTE(review): third argument presumably is the (unknown) action of the instance — confirm.
        TrainingInstance ti = new TrainingInstance(gs, u.getID(), null);
        int[] x = this.featureGenerator.generateFeaturesAsArray(ti);
        return this.predictDistribution(x, ti);
    }

    /**
     * Predicts a distribution restricted to {@code actions2}.
     *
     * @param u        the unit whose action is predicted
     * @param gs       the current game state
     * @param actions2 the candidate actions to score
     * @return one probability per entry of {@code actions2}, aligned by position
     * @throws Exception if feature generation fails
     */
    @Override
    public double[] predictDistribution(Unit u, GameState gs, List<UnitAction> actions2) throws Exception {
        TrainingInstance ti = new TrainingInstance(gs, u.getID(), null);
        int[] x = this.featureGenerator.generateFeaturesAsArray(ti);
        double[] prediction = this.predictDistribution(x, ti);
        return this.filterByPossibleActions(prediction, u, actions2);
    }

    /**
     * Predicts a distribution over {@link #allPossibleActions} from a feature vector.
     *
     * @param x  feature vector
     * @param ti the instance {@code x} was generated from
     * @return one probability per entry of {@link #allPossibleActions}
     */
    public abstract double[] predictDistribution(int[] x, TrainingInstance ti);

    /**
     * Returns the index of the most probable action. Ties go to the lowest index.
     *
     * @param x  feature vector
     * @param ti the instance {@code x} was generated from
     * @return index into the predicted distribution
     */
    public int predictMax(int[] x, TrainingInstance ti) {
        double[] d = this.predictDistribution(x, ti);
        int argmax = 0;
        for (int i = 1; i < d.length; ++i) {
            if (d[i] > d[argmax]) {
                argmax = i;
            }
        }
        return argmax;
    }

    /**
     * Samples an action index from the predicted distribution.
     *
     * @param x  feature vector
     * @param ti the instance {@code x} was generated from
     * @return sampled index into the predicted distribution
     * @throws Exception propagated from {@code Sampler.weighted}
     */
    public int predictSample(int[] x, TrainingInstance ti) throws Exception {
        double[] d = this.predictDistribution(x, ti);
        return Sampler.weighted(d);
    }

    /**
     * Zeroes out every entry whose index is not listed and renormalises the rest to sum to 1.
     *
     * <p>If the allowed entries have zero total mass, they receive a uniform share instead of the
     * {@code NaN} that a plain {@code 0/0} division would produce. Indices that are {@code null}
     * or outside the array are ignored.
     *
     * @param predicted_distribution    the full distribution
     * @param possibleUnitActionIndexes indices that remain allowed
     * @return a new array of the same length as {@code predicted_distribution}
     */
    public double[] filterByPossibleActionIndexes(double[] predicted_distribution, List<Integer> possibleUnitActionIndexes) {
        int n = predicted_distribution.length;
        // A mask makes each membership test O(1), unlike List.contains.
        boolean[] allowed = new boolean[n];
        for (Integer idx : possibleUnitActionIndexes) {
            if (idx != null && idx >= 0 && idx < n) {
                allowed[idx] = true;
            }
        }
        double accum = 0.0;
        int allowedCount = 0;
        for (int i = 0; i < n; ++i) {
            if (allowed[i]) {
                accum += predicted_distribution[i];
                ++allowedCount;
            }
        }
        double[] d = new double[n];
        for (int i = 0; i < n; ++i) {
            if (!allowed[i]) {
                continue;
            }
            // Zero total mass: fall back to uniform over the allowed entries rather than NaN.
            d[i] = accum > 0.0 ? predicted_distribution[i] / accum : 1.0 / allowedCount;
        }
        return d;
    }

    /**
     * Maps a distribution over {@link #allPossibleActions} onto the candidate list {@code l} and
     * renormalises it.
     *
     * <p>Type-5 actions in {@code l} carry an absolute target location. They are converted to
     * offsets relative to {@code u} because that is how {@link #generateAllPossibleUnitActions}
     * stores them. Candidates not found in {@link #allPossibleActions} (or whose index falls
     * outside {@code d}) get probability 0. If the total mass is 0, the result is uniform.
     *
     * @param d the distribution over {@link #allPossibleActions}
     * @param u the acting unit, used to make type-5 target locations relative
     * @param l the candidate actions
     * @return one probability per entry of {@code l}, aligned by position
     */
    public double[] filterByPossibleActions(double[] d, Unit u, List<UnitAction> l) {
        int count = l.size();
        double[] filtered = new double[count];
        double total = 0.0;
        for (int i = 0; i < count; ++i) {
            UnitAction ua = l.get(i);
            if (ua.getType() == 5) {
                ua = new UnitAction(5, ua.getLocationX() - u.getX(), ua.getLocationY() - u.getY());
            }
            // NOTE(review): relies on UnitAction.equals for the lookup — confirm it compares by value.
            int idx = this.allPossibleActions.indexOf(ua);
            // Unknown action (indexOf == -1) would otherwise throw ArrayIndexOutOfBoundsException.
            double p = (idx >= 0 && idx < d.length) ? d[idx] : 0.0;
            filtered[i] = p;
            total += p;
        }
        if (total > 0.0) {
            for (int i = 0; i < count; ++i) {
                filtered[i] /= total;
            }
        } else {
            for (int i = 0; i < count; ++i) {
                filtered[i] = 1.0 / (double) count;
            }
        }
        return filtered;
    }

    /**
     * Enumerates the fixed action space used by all models, in this order:
     * <ol>
     *   <li>a single type-0 action with parameter 10;</li>
     *   <li>types 1, 2 and 3 for each of the directions 0..3;</li>
     *   <li>type 4 for each direction, combined with each unit type whose {@code producedBy} is non-empty;</li>
     *   <li>type 5 for every non-zero offset (ox, oy) with ox² + oy² ≤ r², where r is the largest
     *       {@code attackRange} in {@code utt} (at least 1).</li>
     * </ol>
     *
     * @param utt the unit type table
     * @return a new mutable list of actions; list position is the action's index in predictions
     */
    public static List<UnitAction> generateAllPossibleUnitActions(UnitTypeTable utt) {
        ArrayList<UnitAction> l = new ArrayList<UnitAction>();
        int maxAttackRange = 1;
        int[] directions = new int[]{0, 1, 2, 3};
        for (UnitType ut : utt.getUnitTypes()) {
            if (ut.attackRange > maxAttackRange) {
                maxAttackRange = ut.attackRange;
            }
        }
        l.add(new UnitAction(0, 10));
        for (int direction : directions) {
            l.add(new UnitAction(1, direction));
        }
        for (int direction : directions) {
            l.add(new UnitAction(2, direction));
        }
        for (int direction : directions) {
            l.add(new UnitAction(3, direction));
        }
        for (int direction : directions) {
            for (UnitType ut : utt.getUnitTypes()) {
                if (ut.producedBy.isEmpty()) {
                    continue;
                }
                l.add(new UnitAction(4, direction, ut));
            }
        }
        // Type-5 actions are stored as offsets relative to the acting unit, within a disc of radius maxAttackRange.
        int maxSquared = maxAttackRange * maxAttackRange;
        for (int ox = -maxAttackRange; ox <= maxAttackRange; ++ox) {
            for (int oy = -maxAttackRange; oy <= maxAttackRange; ++oy) {
                int squared = ox * ox + oy * oy;
                if (squared <= 0 || squared > maxSquared) {
                    continue;
                }
                l.add(new UnitAction(5, ox, oy));
            }
        }
        return l;
    }

    /**
     * Serialises the model.
     *
     * @param w destination writer
     * @throws Exception if writing fails
     */
    public abstract void save(XMLWriter w) throws Exception;

    /**
     * Restores the model from XML.
     *
     * @param e source element
     * @throws Exception if the element cannot be parsed
     */
    public abstract void load(Element e) throws Exception;

    /** Returns the model's name. */
    @Override
    public String toString() {
        return this.name;
    }
}

