/*
 * Decompiled with CFR 0.152.
 */
package infodynamics.measures.continuous.kraskov;

import infodynamics.measures.continuous.ConditionalMutualInfoMultiVariateCommon;
import infodynamics.utils.KdTree;
import infodynamics.utils.MathsUtils;
import infodynamics.utils.MatrixUtils;
import infodynamics.utils.NearestNeighbourSearcher;
import infodynamics.utils.UnivariateNearestNeighbourSearcher;
import java.util.Calendar;
import java.util.Random;

public abstract class ConditionalMutualInfoCalculatorMultiVariateKraskov
extends ConditionalMutualInfoMultiVariateCommon
implements Cloneable {
    protected int k = 4;
    protected int normType = 2;
    public static final String PROP_K = "k";
    public static final String PROP_NORM_TYPE = "NORM_TYPE";
    public static final String PROP_ADD_NOISE = "NOISE_LEVEL_TO_ADD";
    public static final String PROP_DYN_CORR_EXCL_TIME = "DYN_CORR_EXCL";
    public static final String PROP_NUM_THREADS = "NUM_THREADS";
    public static final String USE_ALL_THREADS = "USE_ALL";
    protected boolean addNoise = true;
    protected double noiseLevel = 1.0E-8;
    protected boolean dynCorrExcl = false;
    protected int dynCorrExclTime = 0;
    protected int numThreads = Runtime.getRuntime().availableProcessors();
    protected boolean isAlgorithm1 = false;
    protected KdTree kdTreeJoint;
    protected KdTree kdTreeVar1Conditional;
    protected UnivariateNearestNeighbourSearcher uniNNSearcherVar1;
    protected KdTree kdTreeVar2Conditional;
    protected UnivariateNearestNeighbourSearcher uniNNSearcherVar2;
    protected NearestNeighbourSearcher nnSearcherConditional;
    protected double digammaK;

    @Override
    public void initialise(int n, int n2, int n3) {
        this.kdTreeJoint = null;
        this.kdTreeVar1Conditional = null;
        this.kdTreeVar2Conditional = null;
        this.nnSearcherConditional = null;
        this.uniNNSearcherVar1 = null;
        this.uniNNSearcherVar2 = null;
        super.initialise(n, n2, n3);
    }

    @Override
    public void setProperty(String string, String string2) {
        if (string.equalsIgnoreCase(PROP_K)) {
            this.k = Integer.parseInt(string2);
        } else if (string.equalsIgnoreCase(PROP_NORM_TYPE)) {
            this.normType = KdTree.validateNormType(string2);
        } else if (string.equalsIgnoreCase(PROP_DYN_CORR_EXCL_TIME)) {
            this.dynCorrExclTime = Integer.parseInt(string2);
            this.dynCorrExcl = this.dynCorrExclTime > 0;
        } else if (string.equalsIgnoreCase(PROP_ADD_NOISE)) {
            if (string2.equals("0") || string2.equalsIgnoreCase("false")) {
                this.addNoise = false;
                this.noiseLevel = 0.0;
            } else {
                this.addNoise = true;
                this.noiseLevel = Double.parseDouble(string2);
            }
        } else if (string.equalsIgnoreCase(PROP_NUM_THREADS)) {
            this.numThreads = string2.equalsIgnoreCase(USE_ALL_THREADS) ? Runtime.getRuntime().availableProcessors() : Integer.parseInt(string2);
        } else {
            super.setProperty(string, string2);
        }
    }

    @Override
    public String getProperty(String string) {
        if (string.equalsIgnoreCase(PROP_K)) {
            return Integer.toString(this.k);
        }
        if (string.equalsIgnoreCase(PROP_NORM_TYPE)) {
            return KdTree.convertNormTypeToString(this.normType);
        }
        if (string.equalsIgnoreCase(PROP_DYN_CORR_EXCL_TIME)) {
            return Integer.toString(this.dynCorrExclTime);
        }
        if (string.equalsIgnoreCase(PROP_ADD_NOISE)) {
            return Double.toString(this.noiseLevel);
        }
        if (string.equalsIgnoreCase(PROP_NUM_THREADS)) {
            return Integer.toString(this.numThreads);
        }
        return super.getProperty(string);
    }

    @Override
    public void finaliseAddObservations() throws Exception {
        super.finaliseAddObservations();
        if (this.dynCorrExcl && this.addedMoreThanOneObservationSet) {
            throw new RuntimeException("Addition of multiple observation sets is not currently supported with property DYN_CORR_EXCL set");
        }
        if (this.totalObservations <= this.k + 2 * this.dynCorrExclTime) {
            throw new Exception("There are less observations provided (" + this.totalObservations + ") than required for the number of nearest neighbours parameter (" + this.k + ") and any dynamic correlation exclusion (" + this.dynCorrExclTime + ")");
        }
        if (this.addNoise) {
            Random random = new Random();
            for (int i = 0; i < this.var1Observations.length; ++i) {
                int n = 0;
                while (n < this.dimensionsVar1) {
                    double[] dArray = this.var1Observations[i];
                    int n2 = n++;
                    dArray[n2] = dArray[n2] + random.nextGaussian() * this.noiseLevel;
                }
                n = 0;
                while (n < this.dimensionsVar2) {
                    double[] dArray = this.var2Observations[i];
                    int n3 = n++;
                    dArray[n3] = dArray[n3] + random.nextGaussian() * this.noiseLevel;
                }
                n = 0;
                while (n < this.dimensionsCond) {
                    double[] dArray = this.condObservations[i];
                    int n4 = n++;
                    dArray[n4] = dArray[n4] + random.nextGaussian() * this.noiseLevel;
                }
            }
        }
        this.digammaK = MathsUtils.digamma(this.k);
    }

    @Override
    public double computeAverageLocalOfObservations() throws Exception {
        double d = Calendar.getInstance().getTimeInMillis();
        this.lastAverage = this.computeFromObservations(false, null)[0];
        this.condMiComputed = true;
        if (this.debug) {
            Calendar calendar = Calendar.getInstance();
            long l = calendar.getTimeInMillis();
            System.out.println("Calculation time: " + ((double)l - d) / 1000.0 + " sec");
        }
        return this.lastAverage;
    }

    @Override
    public double computeAverageLocalOfObservations(int n, int[] nArray) throws Exception {
        double[][] dArray;
        if (nArray == null) {
            return this.computeAverageLocalOfObservations();
        }
        KdTree kdTree = this.kdTreeJoint;
        this.kdTreeJoint = null;
        KdTree kdTree2 = this.kdTreeVar1Conditional;
        UnivariateNearestNeighbourSearcher univariateNearestNeighbourSearcher = this.uniNNSearcherVar1;
        KdTree kdTree3 = this.kdTreeVar2Conditional;
        UnivariateNearestNeighbourSearcher univariateNearestNeighbourSearcher2 = this.uniNNSearcherVar2;
        if (n == 1) {
            dArray = this.var1Observations;
            this.kdTreeVar1Conditional = null;
            this.uniNNSearcherVar1 = null;
            this.var1Observations = MatrixUtils.extractSelectedTimePointsReusingArrays(dArray, nArray);
        } else {
            dArray = this.var2Observations;
            this.kdTreeVar2Conditional = null;
            this.uniNNSearcherVar2 = null;
            this.var2Observations = MatrixUtils.extractSelectedTimePointsReusingArrays(dArray, nArray);
        }
        double d = this.computeFromObservations(false, null)[0];
        this.kdTreeJoint = kdTree;
        if (n == 1) {
            this.var1Observations = dArray;
            this.kdTreeVar1Conditional = kdTree2;
            this.uniNNSearcherVar1 = univariateNearestNeighbourSearcher;
        } else {
            this.var2Observations = dArray;
            this.kdTreeVar2Conditional = kdTree3;
            this.uniNNSearcherVar2 = univariateNearestNeighbourSearcher2;
        }
        return d;
    }

    @Override
    public double[] computeLocalOfPreviousObservations() throws Exception {
        double[] dArray = this.computeFromObservations(true, null);
        this.lastAverage = MatrixUtils.mean(dArray);
        this.condMiComputed = true;
        return dArray;
    }

    @Override
    public double[] computeLocalUsingPreviousObservations(double[][] dArray, double[][] dArray2, double[][] dArray3) throws Exception {
        double[][] dArray4;
        double[][] dArray5;
        double[][] dArray6;
        if (this.normalise) {
            dArray6 = MatrixUtils.normaliseIntoNewArray(dArray, this.var1Means, this.var1Stds);
            dArray5 = MatrixUtils.normaliseIntoNewArray(dArray2, this.var2Means, this.var2Stds);
            dArray4 = this.dimensionsCond != 0 ? MatrixUtils.normaliseIntoNewArray(dArray3, this.condMeans, this.condStds) : (double[][])null;
        } else {
            dArray6 = dArray;
            dArray5 = dArray2;
            dArray4 = dArray3;
        }
        double[] dArray7 = this.computeFromObservations(true, new double[][][]{dArray6, dArray5, dArray4});
        return dArray7;
    }

    protected double[] computeFromObservations(boolean bl, double[][][] dArray) throws Exception {
        int n;
        int n2 = this.var1Observations.length;
        double[] dArray2 = null;
        if (this.kdTreeJoint == null) {
            this.kdTreeJoint = new KdTree(new int[]{this.dimensionsVar1, this.dimensionsVar2, this.dimensionsCond}, new double[][][]{this.var1Observations, this.var2Observations, this.condObservations});
            this.kdTreeJoint.setNormType(this.normType);
        }
        if (this.dimensionsVar1 > 1) {
            if (this.kdTreeVar1Conditional == null) {
                this.kdTreeVar1Conditional = new KdTree(new int[]{this.dimensionsVar1, this.dimensionsCond}, new double[][][]{this.var1Observations, this.condObservations});
                this.kdTreeVar1Conditional.setNormType(this.normType);
            }
        } else if (this.uniNNSearcherVar1 == null) {
            this.uniNNSearcherVar1 = new UnivariateNearestNeighbourSearcher(this.var1Observations);
        }
        if (this.dimensionsVar2 > 1) {
            if (this.kdTreeVar2Conditional == null) {
                this.kdTreeVar2Conditional = new KdTree(new int[]{this.dimensionsVar2, this.dimensionsCond}, new double[][][]{this.var2Observations, this.condObservations});
                this.kdTreeVar2Conditional.setNormType(this.normType);
            }
        } else if (this.uniNNSearcherVar2 == null) {
            this.uniNNSearcherVar2 = new UnivariateNearestNeighbourSearcher(this.var2Observations);
        }
        if (this.nnSearcherConditional == null && this.dimensionsCond > 0) {
            this.nnSearcherConditional = NearestNeighbourSearcher.create(this.condObservations);
            this.nnSearcherConditional.setNormType(this.normType);
        }
        int n3 = n = dArray == null ? n2 : dArray[0].length;
        if (this.numThreads == 1) {
            dArray2 = dArray == null ? this.partialComputeFromObservations(0, n, bl) : this.partialComputeFromNewObservations(0, n, dArray[0], dArray[1], dArray[2], bl);
        } else {
            int n4;
            dArray2 = bl ? new double[n] : new double[6];
            int n5 = n / this.numThreads;
            int n6 = n % this.numThreads;
            if (this.debug) {
                System.out.printf("Computing Kraskov conditional MI with %d threads (%d timesteps each, plus %d residual)\n", this.numThreads, n5, n6);
            }
            Thread[] threadArray = new Thread[this.numThreads];
            CondMiKraskovThreadRunner[] condMiKraskovThreadRunnerArray = new CondMiKraskovThreadRunner[this.numThreads];
            for (n4 = 0; n4 < this.numThreads; ++n4) {
                int n7;
                int n8 = n4 == 0 ? 0 : n5 * n4 + n6;
                int n9 = n7 = n4 == 0 ? n5 + n6 : n5;
                if (this.debug) {
                    System.out.println(n4 + ".Thread: from " + n8 + " to " + (n8 + n7));
                }
                condMiKraskovThreadRunnerArray[n4] = new CondMiKraskovThreadRunner(this, n8, n7, dArray, bl);
                threadArray[n4] = new Thread(condMiKraskovThreadRunnerArray[n4]);
                threadArray[n4].start();
            }
            for (n4 = 0; n4 < this.numThreads; ++n4) {
                if (threadArray[n4] != null) {
                    threadArray[n4].join();
                }
                if (bl) {
                    System.arraycopy(condMiKraskovThreadRunnerArray[n4].getReturnValues(), 0, dArray2, condMiKraskovThreadRunnerArray[n4].myStartTimePoint, condMiKraskovThreadRunnerArray[n4].numberOfTimePoints);
                    continue;
                }
                MatrixUtils.addInPlace(dArray2, condMiKraskovThreadRunnerArray[n4].getReturnValues());
            }
        }
        if (bl) {
            return dArray2;
        }
        double d = dArray2[0] / (double)n;
        double d2 = dArray2[1] / (double)n;
        double d3 = dArray2[2] / (double)n;
        double d4 = dArray2[3] / (double)n;
        if (this.debug) {
            System.out.printf("<n_xz>=%.3f, <n_yz>=%.3f, <n_z>=%.3f\n", d2, d3, d4);
        }
        if (this.isAlgorithm1) {
            if (this.debug) {
                System.out.printf("Av = digamma(k)=%.3f + <digammas>=%.3f = %.3f \n", MathsUtils.digamma(this.k), d, MathsUtils.digamma(this.k) + d);
            }
            double[] dArray3 = new double[]{MathsUtils.digamma(this.k) + d};
            return dArray3;
        }
        double d5 = dArray2[4] / (double)n;
        double d6 = dArray2[5] / (double)n;
        double d7 = this.dimensionsCond > 0 ? 2.0 / (double)this.k : 1.0 / (double)this.k;
        double d8 = MathsUtils.digamma(this.k) - d7 + d + d5 + d6;
        if (this.debug) {
            System.out.printf("Av = digamma(k)=%.3f + <digammas>=%.3f +<inverses>=%.3f - $d/k=%.3f  = %.3f (<1/n_yz>=%.3f, <1/n_xz>=%.3f)\n", MathsUtils.digamma(this.k), d, d5 + d6, this.dimensionsCond > 0 ? 2 : 1, d7, d8, d6, d5);
        }
        double[] dArray4 = new double[]{d8};
        return dArray4;
    }

    protected abstract double[] partialComputeFromObservations(int var1, int var2, boolean var3) throws Exception;

    protected abstract double[] partialComputeFromNewObservations(int var1, int var2, double[][] var3, double[][] var4, double[][] var5, boolean var6) throws Exception;

    private class CondMiKraskovThreadRunner
    implements Runnable {
        protected ConditionalMutualInfoCalculatorMultiVariateKraskov condMiCalc;
        protected int myStartTimePoint;
        protected int numberOfTimePoints;
        protected double[][][] newObservations;
        protected boolean computeLocals;
        protected double[] returnValues = null;
        protected Exception problem = null;
        public static final int INDEX_SUM_DIGAMMAS = 0;
        public static final int INDEX_SUM_NXZ = 1;
        public static final int INDEX_SUM_NYZ = 2;
        public static final int INDEX_SUM_NZ = 3;
        public static final int INDEX_SUM_INV_NXZ = 4;
        public static final int INDEX_SUM_INV_NYZ = 5;
        public static final int RETURN_ARRAY_LENGTH = 6;

        public CondMiKraskovThreadRunner(ConditionalMutualInfoCalculatorMultiVariateKraskov conditionalMutualInfoCalculatorMultiVariateKraskov2, int n, int n2, double[][][] dArray, boolean bl) {
            this.condMiCalc = conditionalMutualInfoCalculatorMultiVariateKraskov2;
            this.myStartTimePoint = n;
            this.numberOfTimePoints = n2;
            this.computeLocals = bl;
            this.newObservations = dArray;
        }

        public double[] getReturnValues() throws Exception {
            if (this.problem != null) {
                throw this.problem;
            }
            return this.returnValues;
        }

        @Override
        public void run() {
            try {
                this.returnValues = this.newObservations == null ? this.condMiCalc.partialComputeFromObservations(this.myStartTimePoint, this.numberOfTimePoints, this.computeLocals) : this.condMiCalc.partialComputeFromNewObservations(this.myStartTimePoint, this.numberOfTimePoints, this.newObservations[0], this.newObservations[1], this.newObservations[2], this.computeLocals);
            }
            catch (Exception exception) {
                this.problem = exception;
                return;
            }
        }
    }
}

