/*
 * Decompiled with CFR 0.152.
 */
package infodynamics.measures.mixed.kraskov;

import infodynamics.measures.mixed.MutualInfoCalculatorMultiVariateWithDiscrete;
import infodynamics.utils.EmpiricalMeasurementDistribution;
import infodynamics.utils.KdTree;
import infodynamics.utils.MathsUtils;
import infodynamics.utils.MatrixUtils;
import infodynamics.utils.RandomGenerator;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Random;
import java.util.Vector;

/**
 * Computes the mutual information between a multivariate continuous variable and a
 * discrete variable with a Kraskov-style k-nearest-neighbour estimator.
 *
 * <p>For each pooled sample {@code t} with discrete value {@code b}:
 * <ul>
 *   <li>{@code eps} is the distance to its k-th nearest neighbour among the samples with
 *       discrete value {@code b} (searched in {@link #kdTreeBins}{@code [b]});</li>
 *   <li>{@code n_x} is the count returned by {@link #kdTreeJoint} for points within or on
 *       {@code eps} of sample {@code t};</li>
 *   <li>{@code n_y = counts[b] - 1}.</li>
 * </ul>
 * The local value is {@code digamma(k) + digamma(N) - digamma(n_x) - digamma(n_y)} and the
 * average MI is {@code digamma(k) + digamma(N) - mean(digamma(n_x) + digamma(n_y))}.
 *
 * <p>Typical use: {@link #initialise(int, int)}, then either
 * {@link #setObservations(double[][], int[])} or {@link #startAddObservations()},
 * {@link #addObservations(double[][], int[])} (repeatable) and
 * {@link #finaliseAddObservations()}; then one of the compute methods.
 */
public class MutualInfoCalculatorMultiVariateWithDiscreteKraskov
implements MutualInfoCalculatorMultiVariateWithDiscrete,
Cloneable {
    /** Number of nearest neighbours used by the estimator (property {@code k}). */
    protected int k = 4;
    /** Pooled, time-aligned continuous samples: one row per observation, {@code dimensions} columns. */
    protected double[][] continuousData;
    /** Pooled, time-aligned discrete samples, row-for-row with {@link #continuousData}. */
    protected int[] discreteData;
    /** Number of pooled samples taking each discrete value {@code 0..base-1}. */
    protected int[] counts;
    /** Number of possible discrete values; every discrete sample must lie in {@code 0..base-1}. */
    protected int base;
    /** Number of continuous variables (columns of the continuous data). */
    protected int dimensions;
    /** Number of pooled observations after time-difference alignment. */
    protected int totalObservations;
    /** When true, diagnostic values are printed to standard output during computation. */
    protected boolean debug;
    /** Most recently computed average MI. */
    protected double mi;
    /** Whether {@link #mi} has been computed for the currently finalised observations. */
    protected boolean miComputed;
    /** Cached {@code digamma(totalObservations)}, set when observations are finalised. */
    protected double digammaN;
    /** Cached {@code digamma(k)}, set when observations are finalised. */
    protected double digammaK;
    /** Continuous observation sets buffered between start- and finaliseAddObservations. */
    protected Vector<double[][]> vectorOfContinuousObservations;
    /** Discrete observation sets buffered between start- and finaliseAddObservations. */
    protected Vector<int[]> vectorOfDiscreteObservations;
    /** KdTree over all pooled continuous samples. */
    protected KdTree kdTreeJoint;
    /** One KdTree per discrete value, over the continuous samples having that value. */
    protected KdTree[] kdTreeBins;
    /** Norm type handed to every KdTree (property {@code NORM_TYPE}). */
    protected int normType = 2;
    /** Property name: number of nearest neighbours k. */
    public static final String PROP_K = "k";
    /** Property name: KdTree norm type, validated by {@code KdTree.validateNormType}. */
    public static final String PROP_NORM_TYPE = "NORM_TYPE";
    /** Property name: whether to normalise each continuous variable (boolean). */
    public static final String PROP_NORMALISE = "NORMALISE";
    /** Property name: lag between the continuous and discrete series. */
    public static final String PROP_TIME_DIFF = "TIME_DIFF";
    /** Property name: std-dev of Gaussian noise added to the continuous data ("0"/"false" disables). */
    public static final String PROP_ADD_NOISE = "NOISE_LEVEL_TO_ADD";
    /** Whether continuous data are normalised with their own means/std-devs when finalised. */
    protected boolean normalise = true;
    /** Per-column means of the pooled continuous data (only set when normalising). */
    protected double[] means;
    /** Per-column std-devs of the pooled continuous data (only set when normalising). */
    protected double[] stds;
    /** Lag used when pooling: continuous sample {@code t} is paired with discrete sample {@code t + timeDiff}. */
    protected int timeDiff = 0;
    /** Whether Gaussian noise is added to the continuous data when finalised. */
    protected boolean addNoise = true;
    /** Standard deviation of the added Gaussian noise. */
    protected double noiseLevel = 1.0E-8;
    /** Exclusion window passed to {@code KdTree.countPointsWithinOrOnR}; not changed within this class. */
    protected int dynCorrExclTime = 0;

    /**
     * Clears all previously supplied observations and results.
     *
     * @param numContinuousVariables number of continuous variables (columns) expected
     * @param discreteBase number of discrete values; values must lie in {@code 0..discreteBase-1}
     */
    @Override
    public void initialise(int numContinuousVariables, int discreteBase) {
        this.mi = 0.0;
        this.miComputed = false;
        this.totalObservations = 0;
        this.continuousData = null;
        this.means = null;
        this.stds = null;
        this.discreteData = null;
        this.dimensions = numContinuousVariables;
        this.base = discreteBase;
        this.kdTreeJoint = null;
        this.kdTreeBins = null;
    }

    /**
     * Sets one of the {@code PROP_*} properties (names compared case-insensitively).
     * Unrecognised property names are silently ignored.
     *
     * @param propertyName name of the property
     * @param propertyValue string value of the property
     * @throws Exception if the value cannot be parsed or the norm type is rejected
     */
    @Override
    public void setProperty(String propertyName, String propertyValue) throws Exception {
        if (propertyName.equalsIgnoreCase(PROP_K)) {
            this.k = Integer.parseInt(propertyValue);
        } else if (propertyName.equalsIgnoreCase(PROP_NORM_TYPE)) {
            this.normType = KdTree.validateNormType(propertyValue);
        } else if (propertyName.equalsIgnoreCase(PROP_NORMALISE)) {
            this.normalise = Boolean.parseBoolean(propertyValue);
        } else if (propertyName.equalsIgnoreCase(PROP_ADD_NOISE)) {
            // "0" or "false" switches noise off; anything else is parsed as the noise std-dev.
            if (propertyValue.equals("0") || propertyValue.equalsIgnoreCase("false")) {
                this.addNoise = false;
                this.noiseLevel = 0.0;
            } else {
                this.addNoise = true;
                this.noiseLevel = Double.parseDouble(propertyValue);
            }
        } else if (propertyName.equalsIgnoreCase(PROP_TIME_DIFF)) {
            this.timeDiff = Integer.parseInt(propertyValue);
        }
    }

    /**
     * Buffers one set of paired observations; must follow {@link #startAddObservations()}.
     * Sets with no more rows than {@code |timeDiff|} (including empty sets) are skipped,
     * since nothing would remain after time-difference alignment.
     *
     * @param continuousObservations rows of continuous samples, each of length {@code dimensions}
     * @param discreteObservations discrete samples, same length as {@code continuousObservations}
     * @throws Exception if the lengths differ or the rows have the wrong number of variables
     */
    public void addObservations(double[][] continuousObservations, int[] discreteObservations) throws Exception {
        if (this.vectorOfContinuousObservations == null) {
            throw new RuntimeException("User did not call startAddObservations before addObservations");
        }
        if (continuousObservations.length != discreteObservations.length) {
            throw new Exception("Time steps for observations2 " + discreteObservations.length + " does not match the length " + "of observations1 " + continuousObservations.length);
        }
        // Only inspect row 0 when it exists; empty sets fall through to the length filter below.
        if (continuousObservations.length > 0 && continuousObservations[0].length != this.dimensions) {
            throw new Exception("The continuous observations do not have the expected number of variables (" + this.dimensions + ")");
        }
        if (continuousObservations.length > Math.abs(this.timeDiff)) {
            this.vectorOfContinuousObservations.add(continuousObservations);
            this.vectorOfDiscreteObservations.add(discreteObservations);
        }
    }

    /** Not supported by this calculator; always throws {@link UnsupportedOperationException}. */
    public void addObservations(double[][] dArray, double[][] dArray2, int n, int n2) throws Exception {
        throw new UnsupportedOperationException("Not implemented yet");
    }

    /** Not supported by this calculator; always throws {@link UnsupportedOperationException}. */
    public void setObservations(double[][] dArray, double[][] dArray2, boolean[] blArray, boolean[] blArray2) throws Exception {
        throw new UnsupportedOperationException("Not implemented yet");
    }

    /** Not supported by this calculator; always throws {@link UnsupportedOperationException}. */
    public void setObservations(double[][] dArray, double[][] dArray2, boolean[][] blArray, boolean[][] blArray2) throws Exception {
        throw new UnsupportedOperationException("Not implemented yet");
    }

    /** Starts buffering observation sets for a later {@link #finaliseAddObservations()}. */
    public void startAddObservations() {
        this.vectorOfContinuousObservations = new Vector<double[][]>();
        this.vectorOfDiscreteObservations = new Vector<int[]>();
    }

    /**
     * Pools the buffered observation sets and prepares them for estimation.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>time-align each set (continuous {@code t} with discrete {@code t + timeDiff});</li>
     *   <li>normalise (if enabled);</li>
     *   <li>count the samples per discrete value;</li>
     *   <li>add noise (if enabled);</li>
     *   <li>cache the digamma constants;</li>
     *   <li>rebuild the KdTrees.</li>
     * </ol>
     *
     * @throws Exception if no observation sets were added
     * @throws RuntimeException if startAddObservations was not called, a discrete value is
     *     outside {@code 0..base-1}, or some discrete value has fewer than k samples
     */
    public void finaliseAddObservations() throws Exception {
        if (this.vectorOfContinuousObservations == null) {
            throw new RuntimeException("User did not call startAddObservations before finaliseAddObservations");
        }
        if (this.vectorOfContinuousObservations.size() < 1) {
            throw new Exception("Cannot compute MI with a null set of data");
        }
        int lag = Math.abs(this.timeDiff);
        this.totalObservations = 0;
        for (double[][] continuousSet : this.vectorOfContinuousObservations) {
            this.totalObservations += continuousSet.length - lag;
        }
        this.continuousData = new double[this.totalObservations][this.dimensions];
        this.discreteData = new int[this.totalObservations];
        // Pairing continuous[t] with discrete[t + timeDiff] means skipping the first |timeDiff|
        // discrete rows for a positive lag, or the first |timeDiff| continuous rows for a negative one.
        int continuousStart = Math.max(0, -this.timeDiff);
        int discreteStart = Math.max(0, this.timeDiff);
        int destRow = 0;
        Iterator<int[]> discreteSets = this.vectorOfDiscreteObservations.iterator();
        for (double[][] continuousSet : this.vectorOfContinuousObservations) {
            int[] discreteSet = discreteSets.next();
            int rows = continuousSet.length - lag;
            MatrixUtils.arrayCopy(continuousSet, continuousStart, 0, this.continuousData, destRow, 0, rows, this.dimensions);
            System.arraycopy(discreteSet, discreteStart, this.discreteData, destRow, rows);
            destRow += rows;
        }
        this.vectorOfContinuousObservations = null;
        this.vectorOfDiscreteObservations = null;
        if (this.normalise) {
            this.means = MatrixUtils.means(this.continuousData);
            this.stds = MatrixUtils.stdDevs(this.continuousData, this.means);
            MatrixUtils.normalise(this.continuousData, this.means, this.stds);
        }
        this.counts = new int[this.base];
        for (int value : this.discreteData) {
            if (value < 0 || value >= this.base) {
                this.totalObservations = 0;
                this.continuousData = null;
                this.discreteData = null;
                throw new RuntimeException("Values of the discrete variable must range from 0 to base-1");
            }
            this.counts[value]++;
        }
        // NOTE(review): the k-NN search for a sample presumably excludes the sample itself, in which
        // case a bin may need k+1 members rather than k -- confirm against KdTree.
        for (int bin = 0; bin < this.counts.length; ++bin) {
            if (this.counts[bin] < this.k) {
                throw new RuntimeException("This implementation assumes there are at least k items in each discrete bin");
            }
        }
        if (this.addNoise) {
            Random random = new Random();
            for (double[] row : this.continuousData) {
                for (int col = 0; col < this.dimensions; ++col) {
                    row[col] += random.nextGaussian() * this.noiseLevel;
                }
            }
        }
        this.digammaN = MathsUtils.digamma(this.totalObservations);
        this.digammaK = MathsUtils.digamma(this.k);
        // continuousData is a fresh array, so any trees built earlier index stale data and
        // any previously computed MI no longer applies.
        this.kdTreeJoint = null;
        this.kdTreeBins = null;
        this.miComputed = false;
        this.ensureKdTreesConstructed();
    }

    /**
     * Supplies a single set of paired observations (start, add, finalise in one call).
     *
     * @param dArray rows of continuous samples
     * @param nArray discrete samples, same length as {@code dArray}
     * @throws Exception as thrown by {@link #addObservations(double[][], int[])} or
     *     {@link #finaliseAddObservations()}
     */
    @Override
    public void setObservations(double[][] dArray, int[] nArray) throws Exception {
        this.startAddObservations();
        this.addObservations(dArray, nArray);
        this.finaliseAddObservations();
    }

    /**
     * Builds whichever KdTrees are not built yet; existing trees are reused as they are.
     * The per-bin trees are built from
     * {@code MatrixUtils.extractSelectedPointsMatchingCondition(continuousData, discreteData, bin)}
     * (presumably the rows whose discrete value equals {@code bin}, in their original order --
     * computeFromObservations relies on that order).
     */
    public void ensureKdTreesConstructed() {
        if (this.kdTreeJoint == null) {
            this.kdTreeJoint = new KdTree(this.continuousData);
            this.kdTreeJoint.setNormType(this.normType);
        }
        if (this.kdTreeBins == null) {
            this.kdTreeBins = new KdTree[this.base];
            for (int bin = 0; bin < this.base; ++bin) {
                this.kdTreeBins[bin] = new KdTree(MatrixUtils.extractSelectedPointsMatchingCondition(this.continuousData, this.discreteData, bin));
                this.kdTreeBins[bin].setNormType(this.normType);
            }
        }
    }

    /**
     * Computes the average MI after reordering the discrete data, leaving this instance's
     * data, trees and {@link #mi} untouched.
     *
     * @param newOrdering indices passed to {@code MatrixUtils.extractSelectedTimePoints} to
     *     reorder the discrete data (presumably a permutation of {@code 0..N-1}); {@code null}
     *     computes the MI of the original ordering
     * @return average MI of the reordered data
     * @throws Exception if the surrogate computation fails
     */
    public double computeAverageLocalOfObservations(int[] newOrdering) throws Exception {
        if (newOrdering == null) {
            return this.computeAverageLocalOfObservations();
        }
        // Shallow clone: initialise() then replaces (never mutates) the clone's references, and
        // finalising copies the continuous data into a new array, so this instance is unaffected.
        MutualInfoCalculatorMultiVariateWithDiscreteKraskov surrogate =
                (MutualInfoCalculatorMultiVariateWithDiscreteKraskov) this.clone();
        int[] reorderedDiscrete = MatrixUtils.extractSelectedTimePoints(this.discreteData, newOrdering);
        // The pooled data are already time-aligned, so no further lag may be applied.
        surrogate.setProperty(PROP_TIME_DIFF, "0");
        surrogate.initialise(this.dimensions, this.base);
        surrogate.setObservations(this.continuousData, reorderedDiscrete);
        return surrogate.computeAverageLocalOfObservations();
    }

    /**
     * Computes the average MI over the finalised observations and stores it in {@link #mi}.
     *
     * @return the average MI
     * @throws Exception as declared by {@link #computeFromObservations(boolean)}
     */
    @Override
    public double computeAverageLocalOfObservations() throws Exception {
        return this.computeFromObservations(false)[0];
    }

    /**
     * Computes the local MI value of every finalised observation (also updates {@link #mi}).
     *
     * @return one local value per pooled observation
     * @throws Exception as declared by {@link #computeFromObservations(boolean)}
     */
    public double[] computeLocalOfPreviousObservations() throws Exception {
        return this.computeFromObservations(true);
    }

    /**
     * Runs the estimator over the pooled observations, updating {@link #mi} and
     * {@link #miComputed}.
     *
     * @param returnLocals true to return every local value, false for just the average
     * @return the local values (length {@code totalObservations}) or a one-element array
     *     holding the average MI
     * @throws Exception as declared; may propagate from the KdTree searches
     */
    protected double[] computeFromObservations(boolean returnLocals) throws Exception {
        int n = this.totalObservations;
        // Position, within each bin's KdTree, of the next sample carrying that discrete value.
        int[] nextIndexInBin = new int[this.base];
        double sumDigammas = 0.0;
        double sumNx = 0.0;
        double sumNy = 0.0;
        double runningLocalTotal = 0.0;
        double[] locals = returnLocals ? new double[n] : null;
        for (int t = 0; t < n; ++t) {
            int bin = this.discreteData[t];
            int indexInBin = nextIndexInBin[bin]++;
            double kthNnDistance = this.kdTreeBins[bin].findKNearestNeighbours(this.k, indexInBin).poll().distance;
            int nX = this.kdTreeJoint.countPointsWithinOrOnR(t, kthNnDistance, this.dynCorrExclTime);
            // Other samples sharing this discrete value.
            int nY = this.counts[bin] - 1;
            sumNx += nX;
            sumNy += nY;
            double digammaSum = MathsUtils.digamma(nX) + MathsUtils.digamma(nY);
            sumDigammas += digammaSum;
            double local = this.digammaK + this.digammaN - digammaSum;
            if (returnLocals) {
                locals[t] = local;
            }
            if (!this.debug) {
                continue;
            }
            runningLocalTotal += local;
            if (this.dimensions == 1) {
                System.out.printf("t=%d: x=%.3f, eps_x=%.3f, n_x=%d, n_y=%d, local=%.3f, running total = %.5f\n", t, this.continuousData[t][0], kthNnDistance, nX, nY, local, runningLocalTotal);
            } else {
                System.out.printf("t=%d: eps_x=%.3f, n_x=%d, n_y=%d, local=%.3f, running total = %.5f\n", t, kthNnDistance, nX, nY, local, runningLocalTotal);
            }
        }
        double averageDigammas = sumDigammas / (double) n;
        if (this.debug) {
            double avNx = sumNx / (double) n;
            double avNy = sumNy / (double) n;
            System.out.println(String.format("Average n_x=%.3f (-> digam=%.3f %.3f), Average n_y=%.3f (-> digam=%.3f)", avNx, MathsUtils.digamma((int) avNx), MathsUtils.digamma((int) avNx - 1), avNy, MathsUtils.digamma((int) avNy)));
            System.out.printf("Independent average num in joint box is %.3f\n", avNx * avNy / (double) n);
            System.out.println(String.format("digamma(k)=%.3f - averageDiGammas=%.3f + digamma(N)=%.3f\n", this.digammaK, averageDigammas, this.digammaN));
        }
        this.mi = this.digammaK + this.digammaN - averageDigammas;
        this.miComputed = true;
        return returnLocals ? locals : new double[]{this.mi};
    }

    /**
     * Estimates the null distribution of the MI from {@code numPermutationsToCheck}
     * random reorderings of the discrete data.
     *
     * @param numPermutationsToCheck number of surrogate reorderings
     * @return the surrogate distribution, p-value and actual MI
     * @throws Exception if any surrogate computation fails
     */
    @Override
    public synchronized EmpiricalMeasurementDistribution computeSignificance(int numPermutationsToCheck) throws Exception {
        RandomGenerator generator = new RandomGenerator();
        int[][] orderings = generator.generateRandomPerturbations(this.continuousData.length, numPermutationsToCheck);
        return this.computeSignificance(orderings);
    }

    /**
     * Estimates the null distribution of the MI from the supplied reorderings of the
     * discrete data. The p-value is the fraction of surrogates whose MI is at least the
     * actual MI.
     *
     * @param newOrderings one reordering (see {@link #computeAverageLocalOfObservations(int[])}) per surrogate
     * @return the surrogate distribution, p-value and actual MI
     * @throws Exception if any surrogate computation fails
     */
    @Override
    public EmpiricalMeasurementDistribution computeSignificance(int[][] newOrderings) throws Exception {
        int numSurrogates = newOrderings.length;
        if (!this.miComputed) {
            this.computeAverageLocalOfObservations();
        }
        double actualMi = this.mi;
        EmpiricalMeasurementDistribution distribution = new EmpiricalMeasurementDistribution(numSurrogates);
        int countAtLeastActual = 0;
        for (int i = 0; i < numSurrogates; ++i) {
            double surrogateMi = this.computeAverageLocalOfObservations(newOrderings[i]);
            distribution.distribution[i] = surrogateMi;
            if (this.debug) {
                System.out.println("New MI was " + surrogateMi);
            }
            if (surrogateMi >= actualMi) {
                ++countAtLeastActual;
            }
        }
        this.mi = actualMi;
        distribution.pValue = (double) countAtLeastActual / (double) numSurrogates;
        distribution.actualValue = this.mi;
        return distribution;
    }

    /**
     * Computes local MI values for new samples against the finalised observations (the new
     * samples are queried against the existing KdTrees, not added to them).
     *
     * @param newContinuous rows of new continuous samples; normalised with the stored
     *     means/std-devs when normalisation is enabled
     * @param newDiscrete new discrete samples, same length as {@code newContinuous}
     * @return one local MI value per new sample
     * @throws Exception if the lengths differ or a discrete value is outside {@code 0..base-1}
     */
    @Override
    public double[] computeLocalUsingPreviousObservations(double[][] newContinuous, int[] newDiscrete) throws Exception {
        if (newContinuous.length != newDiscrete.length) {
            throw new Exception("Time steps for new discrete observations " + newDiscrete.length
                    + " does not match the length of new continuous observations " + newContinuous.length);
        }
        if (this.normalise) {
            newContinuous = MatrixUtils.normaliseIntoNewArray(newContinuous, this.means, this.stds);
        }
        double digammaKPlusN = this.digammaK + MathsUtils.digamma(this.totalObservations);
        double runningLocalTotal = 0.0;
        double sumNx = 0.0;
        double sumNy = 0.0;
        double[] locals = new double[newDiscrete.length];
        for (int t = 0; t < newDiscrete.length; ++t) {
            int bin = newDiscrete[t];
            if (bin < 0 || bin >= this.base) {
                throw new Exception("Values of the discrete variable must range from 0 to base-1");
            }
            double[][] query = new double[][]{newContinuous[t]};
            double kthNnDistance = this.kdTreeBins[bin].findKNearestNeighbours(this.k, query).poll().distance;
            int nX = this.kdTreeJoint.countPointsWithinR(query, kthNnDistance, true);
            // No "- 1" here, unlike computeFromObservations: the query is not one of the pooled samples.
            int nY = this.counts[bin];
            locals[t] = digammaKPlusN - MathsUtils.digamma(nX) - MathsUtils.digamma(nY);
            if (!this.debug) {
                continue;
            }
            runningLocalTotal += locals[t];
            sumNx += nX;
            sumNy += nY;
            if (this.dimensions == 1) {
                System.out.printf("t=%d: x=%.3f, eps_x=%.3f, n_x=%d, n_y=%d, local=%.3f, running total = %.5f\n", t, newContinuous[t][0], kthNnDistance, nX, nY, locals[t], runningLocalTotal);
            } else {
                System.out.printf("t=%d: eps_x=%.3f, n_x=%d, n_y=%d, local=%.3f, running total = %.5f\n", t, kthNnDistance, nX, nY, locals[t], runningLocalTotal);
            }
        }
        if (this.debug) {
            System.out.printf("Average n_x=%.3f, Average n_y=%.3f\n", sumNx / (double) newDiscrete.length, sumNy / (double) newDiscrete.length);
        }
        return locals;
    }

    /**
     * Enables or disables diagnostic printing to standard output.
     *
     * @param bl true to print diagnostics
     */
    @Override
    public void setDebug(boolean bl) {
        this.debug = bl;
    }

    /**
     * Returns the most recently computed average MI.
     *
     * @return the last average MI (0.0 straight after {@link #initialise(int, int)})
     */
    @Override
    public double getLastAverage() {
        return this.mi;
    }

    /**
     * Returns the number of pooled, time-aligned observations.
     *
     * @return number of pooled observations
     */
    @Override
    public int getNumObservations() {
        return this.totalObservations;
    }
}

