/*
 * Decompiled with CFR 0.152.
 */
package infodynamics.measures.continuous.kraskov;

import infodynamics.measures.continuous.MutualInfoCalculatorMultiVariate;
import infodynamics.measures.continuous.MutualInfoMultiVariateCommon;
import infodynamics.utils.EmpiricalMeasurementDistribution;
import infodynamics.utils.KdTree;
import infodynamics.utils.MathsUtils;
import infodynamics.utils.MatrixUtils;
import infodynamics.utils.NativeUtils;
import infodynamics.utils.NearestNeighbourSearcher;
import infodynamics.utils.NeighbourNodeData;
import java.util.Calendar;
import java.util.PriorityQueue;
import java.util.Random;

/**
 * Abstract base for Kraskov-style k-nearest-neighbour estimators of the mutual information
 * between a multivariate continuous source and a multivariate continuous destination.
 *
 * <p>This class handles:
 * <ul>
 *   <li>property management;</li>
 *   <li>optional normalisation and noise addition on {@link #finaliseAddObservations()};</li>
 *   <li>lazy construction of the search structures: a {@link KdTree} over the joint
 *       (source, destination) space, plus one {@link NearestNeighbourSearcher} for each
 *       marginal space;</li>
 *   <li>splitting the per-sample work across CPU threads, or delegating it to a native
 *       GPU routine.</li>
 * </ul>
 * Subclasses implement {@link #partialComputeFromObservations(int, int, boolean)} for a
 * contiguous range of samples.
 *
 * <p>From the per-sample sum at index 0 of the partial results, the average is assembled as
 * {@code digamma(k) - sum/N + digamma(N)} when {@link #isAlgorithm1} is set. Otherwise an extra
 * {@code - 1/k} term is included.
 */
public abstract class MutualInfoCalculatorMultiVariateKraskov
extends MutualInfoMultiVariateCommon
implements MutualInfoCalculatorMultiVariate {
    protected int k = 4;
    protected int normType = 2;
    public static final String PROP_K = "k";
    public static final String PROP_NORM_TYPE = "NORM_TYPE";
    public static final String PROP_NORMALISE = "NORMALISE";
    public static final String PROP_ADD_NOISE = "NOISE_LEVEL_TO_ADD";
    public static final String PROP_DYN_CORR_EXCL_TIME = "DYN_CORR_EXCL";
    public static final String PROP_NUM_THREADS = "NUM_THREADS";
    public static final String USE_ALL_THREADS = "USE_ALL";
    public static final String PROP_USE_GPU = "USE_GPU";
    public static final String PROP_GPU_LIBRARY_PATH = "GPU_LIBRARY_PATH";
    protected boolean normalise = true;
    protected boolean addNoise = true;
    protected double noiseLevel = 1.0E-8;
    protected boolean dynCorrExcl = false;
    protected int dynCorrExclTime = 0;
    // Values below 1 are treated as single-threaded by the compute methods.
    protected int numThreads = Runtime.getRuntime().availableProcessors();
    protected boolean isAlgorithm1 = false;
    protected boolean useGPU = false;
    protected String gpuLibraryPath = "";
    protected boolean cudaLibraryLoaded = false;
    // Built lazily by ensureKdTreesConstructed(); reset to null by initialise().
    protected KdTree kdTreeJoint;
    protected NearestNeighbourSearcher nnSearcherSource;
    protected NearestNeighbourSearcher nnSearcherDest;
    // digamma(k) and digamma(totalObservations), cached by finaliseAddObservations().
    protected double digammaK;
    protected double digammaN;

    /**
     * Discards any previously built search structures, then delegates to the superclass.
     *
     * @param n  first dimension argument, passed straight to the superclass
     * @param n2 second dimension argument, passed straight to the superclass
     */
    @Override
    public void initialise(int n, int n2) {
        this.kdTreeJoint = null;
        this.nnSearcherSource = null;
        this.nnSearcherDest = null;
        super.initialise(n, n2);
    }

    /**
     * Sets a property of this calculator. Unrecognised names are passed to the superclass.
     *
     * <ul>
     *   <li>{@code k}: number of nearest neighbours.</li>
     *   <li>{@code NORM_TYPE}: validated by {@link KdTree#validateNormType}.</li>
     *   <li>{@code NORMALISE}: whether to normalise the data when observations are finalised.</li>
     *   <li>{@code DYN_CORR_EXCL}: dynamic correlation exclusion window. A value above 0
     *       enables the exclusion.</li>
     *   <li>{@code NOISE_LEVEL_TO_ADD}: standard deviation of the Gaussian noise to add.
     *       {@code "0"} or {@code "false"} disables noise addition.</li>
     *   <li>{@code NUM_THREADS}: thread count, or {@code USE_ALL} for all available
     *       processors.</li>
     *   <li>{@code USE_GPU}: whether to use the native GPU routine.</li>
     *   <li>{@code GPU_LIBRARY_PATH}: full path to the native GPU library. Empty means the
     *       library is loaded from the jar.</li>
     * </ul>
     *
     * @param propertyName  property name (case-insensitive)
     * @param propertyValue property value as a string
     * @throws Exception if a value cannot be parsed or the superclass rejects the property
     */
    @Override
    public void setProperty(String propertyName, String propertyValue) throws Exception {
        boolean handledHere = true;
        if (propertyName.equalsIgnoreCase(PROP_K)) {
            this.k = Integer.parseInt(propertyValue);
        } else if (propertyName.equalsIgnoreCase(PROP_NORM_TYPE)) {
            this.normType = KdTree.validateNormType(propertyValue);
        } else if (propertyName.equalsIgnoreCase(PROP_NORMALISE)) {
            this.normalise = Boolean.parseBoolean(propertyValue);
        } else if (propertyName.equalsIgnoreCase(PROP_DYN_CORR_EXCL_TIME)) {
            this.dynCorrExclTime = Integer.parseInt(propertyValue);
            this.dynCorrExcl = this.dynCorrExclTime > 0;
        } else if (propertyName.equalsIgnoreCase(PROP_ADD_NOISE)) {
            if (propertyValue.equals("0") || propertyValue.equalsIgnoreCase("false")) {
                this.addNoise = false;
                this.noiseLevel = 0.0;
            } else {
                this.addNoise = true;
                this.noiseLevel = Double.parseDouble(propertyValue);
            }
        } else if (propertyName.equalsIgnoreCase(PROP_NUM_THREADS)) {
            this.numThreads = propertyValue.equalsIgnoreCase(USE_ALL_THREADS)
                    ? Runtime.getRuntime().availableProcessors()
                    : Integer.parseInt(propertyValue);
        } else if (propertyName.equalsIgnoreCase(PROP_USE_GPU)) {
            this.useGPU = Boolean.parseBoolean(propertyValue);
        } else if (propertyName.equalsIgnoreCase(PROP_GPU_LIBRARY_PATH)) {
            this.gpuLibraryPath = propertyValue;
        } else {
            handledHere = false;
            super.setProperty(propertyName, propertyValue);
        }
        if (this.debug && handledHere) {
            System.out.println(this.getClass().getSimpleName() + ": Set property " + propertyName + " to " + propertyValue);
        }
    }

    /**
     * Returns the current value of a property as a string. Unrecognised names are passed to
     * the superclass.
     *
     * @param propertyName property name (case-insensitive)
     * @return the property value
     * @throws Exception if the superclass rejects the property
     */
    @Override
    public String getProperty(String propertyName) throws Exception {
        if (propertyName.equalsIgnoreCase(PROP_K)) {
            return Integer.toString(this.k);
        }
        if (propertyName.equalsIgnoreCase(PROP_NORM_TYPE)) {
            return KdTree.convertNormTypeToString(this.normType);
        }
        if (propertyName.equalsIgnoreCase(PROP_NORMALISE)) {
            return Boolean.toString(this.normalise);
        }
        if (propertyName.equalsIgnoreCase(PROP_DYN_CORR_EXCL_TIME)) {
            return Integer.toString(this.dynCorrExclTime);
        }
        if (propertyName.equalsIgnoreCase(PROP_ADD_NOISE)) {
            return Double.toString(this.noiseLevel);
        }
        if (propertyName.equalsIgnoreCase(PROP_NUM_THREADS)) {
            return Integer.toString(this.numThreads);
        }
        if (propertyName.equalsIgnoreCase(PROP_USE_GPU)) {
            return Boolean.toString(this.useGPU);
        }
        // Mirrors setProperty, which accepts this property.
        if (propertyName.equalsIgnoreCase(PROP_GPU_LIBRARY_PATH)) {
            return this.gpuLibraryPath;
        }
        return super.getProperty(propertyName);
    }

    /**
     * Finalises the supplied observations. In order, this method:
     * <ol>
     *   <li>validates the observation count against k and the dynamic correlation exclusion
     *       window;</li>
     *   <li>optionally normalises both variables;</li>
     *   <li>optionally adds Gaussian noise to both variables;</li>
     *   <li>caches {@code digamma(k)} and {@code digamma(N)}.</li>
     * </ol>
     *
     * @throws RuntimeException if several observation sets were added while dynamic
     *         correlation exclusion is enabled
     * @throws Exception if there are too few observations
     */
    @Override
    public void finaliseAddObservations() throws Exception {
        super.finaliseAddObservations();
        if (this.dynCorrExcl && this.addedMoreThanOneObservationSet) {
            throw new RuntimeException("Addition of multiple observation sets is not currently supported with property DYN_CORR_EXCL set");
        }
        if (this.totalObservations <= this.k + 2 * this.dynCorrExclTime) {
            throw new Exception("There are less observations provided (" + this.totalObservations + ") than required for the number of nearest neighbours parameter (" + this.k + ") and any dynamic correlation exclusion (" + this.dynCorrExclTime + ")");
        }
        if (this.normalise) {
            MatrixUtils.normalise(this.sourceObservations);
            MatrixUtils.normalise(this.destObservations);
        }
        if (this.addNoise) {
            Random random = new Random();
            // Per sample: source dimensions first, then destination dimensions.
            for (int t = 0; t < this.sourceObservations.length; ++t) {
                double[] sourceRow = this.sourceObservations[t];
                for (int d = 0; d < this.dimensionsSource; ++d) {
                    sourceRow[d] += random.nextGaussian() * this.noiseLevel;
                }
                double[] destRow = this.destObservations[t];
                for (int d = 0; d < this.dimensionsDest; ++d) {
                    destRow[d] += random.nextGaussian() * this.noiseLevel;
                }
            }
        }
        this.digammaK = MathsUtils.digamma(this.k);
        this.digammaN = MathsUtils.digamma(this.totalObservations);
    }

    /**
     * Computes the average MI over all stored observations. The result is recorded as
     * {@code lastAverage}.
     *
     * @return the average MI estimate
     * @throws Exception if the underlying computation fails
     */
    @Override
    public double computeAverageLocalOfObservations() throws Exception {
        long startMillis = System.currentTimeMillis();
        this.lastAverage = this.computeFromObservations(false)[0];
        this.miComputed = true;
        if (this.debug) {
            long endMillis = System.currentTimeMillis();
            System.out.println("Calculation time: " + (double) (endMillis - startMillis) / 1000.0 + " sec");
        }
        return this.lastAverage;
    }

    /**
     * Computes the average MI with the destination observations reordered by the given time
     * indices. This is presumably a permutation used for surrogate generation; confirm with
     * callers.
     *
     * <p>The stored destination data, the joint kd-tree and the destination searcher are
     * always restored afterwards, even when the computation throws. {@code lastAverage} is
     * not modified.
     *
     * @param nArray time indices used to reorder the destination rows, or {@code null} to use
     *        the original order
     * @return the average MI for the reordered data
     * @throws Exception if the underlying computation fails
     */
    @Override
    public double computeAverageLocalOfObservations(int[] nArray) throws Exception {
        if (nArray == null) {
            return this.computeAverageLocalOfObservations();
        }
        // These structures depend on the destination ordering, so they must be rebuilt for the
        // reordered data and then put back.
        KdTree savedJointTree = this.kdTreeJoint;
        NearestNeighbourSearcher savedDestSearcher = this.nnSearcherDest;
        double[][] savedDest = this.destObservations;
        this.kdTreeJoint = null;
        this.nnSearcherDest = null;
        try {
            this.destObservations = MatrixUtils.extractSelectedTimePointsReusingArrays(savedDest, nArray);
            return this.computeFromObservations(false)[0];
        } finally {
            this.destObservations = savedDest;
            this.kdTreeJoint = savedJointTree;
            this.nnSearcherDest = savedDestSearcher;
        }
    }

    /**
     * Computes the local MI value for every stored observation. Their mean is recorded as
     * {@code lastAverage}.
     *
     * @return one local MI value per observation
     * @throws Exception if the underlying computation fails
     */
    @Override
    public double[] computeLocalOfPreviousObservations() throws Exception {
        double[] locals = this.computeFromObservations(true);
        this.lastAverage = MatrixUtils.mean(locals);
        this.miComputed = true;
        return locals;
    }

    /**
     * Not supported by this estimator.
     *
     * @throws Exception always
     */
    @Override
    public double[] computeLocalUsingPreviousObservations(double[][] dArray, double[][] dArray2) throws Exception {
        throw new Exception("Local method not implemented yet");
    }

    /**
     * Runs the estimator over all stored observations. The work goes to the GPU, a single
     * thread, or several worker threads, depending on the configuration.
     *
     * @param bl {@code true} to return one local value per observation; {@code false} to
     *        return a single-element array holding the average
     * @return the local values, or {@code {average}}
     * @throws Exception if any partial computation fails
     */
    protected double[] computeFromObservations(boolean bl) throws Exception {
        int n = this.sourceObservations.length;
        double[] results;
        this.ensureKdTreesConstructed();
        if (this.useGPU) {
            results = this.gpuComputeFromObservations(0, n, bl);
        } else if (this.numThreads <= 1) {
            // Also covers invalid thread counts (0 or negative), which would otherwise
            // divide by zero or create a negative-sized array below.
            results = this.partialComputeFromObservations(0, n, bl);
        } else {
            results = bl ? new double[n] : new double[MiKraskovThreadRunner.RETURN_ARRAY_LENGTH];
            int perThread = n / this.numThreads;
            int remainder = n % this.numThreads;
            if (this.debug) {
                System.out.printf("Computing Kraskov MI with %d threads (%d timesteps each, plus %d residual)\n", this.numThreads, perThread, remainder);
            }
            Thread[] threads = new Thread[this.numThreads];
            MiKraskovThreadRunner[] runners = new MiKraskovThreadRunner[this.numThreads];
            for (int t = 0; t < this.numThreads; ++t) {
                // Thread 0 takes the remainder, so every later thread starts after it.
                int start = t == 0 ? 0 : perThread * t + remainder;
                int count = t == 0 ? perThread + remainder : perThread;
                if (this.debug) {
                    System.out.println(t + ".Thread: from " + start + " to " + (start + count));
                }
                runners[t] = new MiKraskovThreadRunner(this, start, count, bl);
                threads[t] = new Thread(runners[t]);
                threads[t].start();
            }
            // Join every worker before reading any result, so that a failure in one worker
            // never leaves the others running unattended.
            for (Thread thread : threads) {
                thread.join();
            }
            for (MiKraskovThreadRunner runner : runners) {
                double[] partial = runner.getReturnValues();
                if (bl) {
                    System.arraycopy(partial, 0, results, runner.myStartTimePoint, runner.numberOfTimePoints);
                } else {
                    MatrixUtils.addInPlace(results, partial);
                }
            }
        }
        if (bl) {
            return results;
        }
        double avgDigammaSum = results[MiKraskovThreadRunner.INDEX_SUM_DIGAMMAS] / (double) n;
        double avgNx = results[MiKraskovThreadRunner.INDEX_SUM_NX] / (double) n;
        double avgNy = results[MiKraskovThreadRunner.INDEX_SUM_NY] / (double) n;
        if (this.debug) {
            System.out.println(String.format("Average n_x=%.3f, Average n_y=%.3f", avgNx, avgNy));
        }
        if (this.isAlgorithm1) {
            return new double[]{this.digammaK - avgDigammaSum + this.digammaN};
        }
        return new double[]{this.digammaK - 1.0 / (double) this.k - avgDigammaSum + this.digammaN};
    }

    /**
     * Computes the partial results for time points
     * {@code [startTimePoint, startTimePoint + numTimePoints)}.
     *
     * <p>When locals are not requested, the result must have length
     * {@code MiKraskovThreadRunner.RETURN_ARRAY_LENGTH}. Index 0 holds the digamma sum, and
     * indices 1 and 2 hold the sums of the marginal neighbour counts. When locals are
     * requested, the result holds one value per time point.
     */
    protected abstract double[] partialComputeFromObservations(int var1, int var2, boolean var3) throws Exception;

    /**
     * Runs the estimator in the native GPU library.
     *
     * <p>If the native call fails during a plain computation (no surrogates requested), a
     * warning is printed and the computation falls back to the CPU. If surrogates were
     * requested, the failure is rethrown instead. The CPU partial computation returns sums in
     * a different layout, not an {@code [actual, surrogates...]} array, so a fallback would
     * silently produce a wrong distribution.
     *
     * @param n   start time point (not used by the native call)
     * @param n2  number of time points (not used by the native call)
     * @param bl  whether local values are requested
     * @param n3  number of surrogates to compute on the GPU (0 for none)
     * @param nArray explicit surrogate permutations, or {@code null}
     * @return the native result, or the CPU partial result after a fallback
     * @throws Exception if the library cannot be loaded, the norm type is unsupported, or a
     *         surrogate computation fails on the GPU
     */
    protected double[] gpuComputeFromObservations(int n, int n2, boolean bl, int n3, int[][] nArray) throws Exception {
        if (this.debug) {
            System.out.println("Start GPU calculation");
        }
        this.ensureCudaLibraryLoaded();
        // normType 2 (this class's default, presumably the max norm) maps to true.
        // 0 and 3 (presumably the Euclidean variants) map to false. Confirm against KdTree.
        boolean normFlag;
        if (this.normType == 2) {
            normFlag = true;
        } else if (this.normType == 0 || this.normType == 3) {
            normFlag = false;
        } else {
            throw new Exception("Only max and square norms are implemented. Abort.");
        }
        boolean surrogatesRequested = n3 > 0 || nArray != null;
        double[] results;
        try {
            if (this.debug) {
                System.out.printf("Calling GPU calculation with returnLocals=%b and nb_surrogates=%d\n", bl, n3);
            }
            results = this.MIKraskov(this.totalObservations, this.sourceObservations, this.dimensionsSource, this.destObservations, this.dimensionsDest, this.k, this.dynCorrExclTime, bl, normFlag, this.isAlgorithm1, n3, nArray != null, nArray);
            if (this.debug) {
                System.out.println("GPU calculation finished successfully. Returning results");
            }
        }
        catch (Throwable throwable) {
            // Throwable is caught on purpose: failures in native code may surface as Errors
            // such as UnsatisfiedLinkError.
            if (surrogatesRequested) {
                throw new Exception("Error in GPU code while computing " + n3 + " surrogates; the CPU fallback cannot produce surrogate results", throwable);
            }
            System.out.println("WARNING. Error in GPU code. Reverting back to CPU.");
            throwable.printStackTrace();
            // Callers such as computeSignificance may reach here before the trees exist.
            this.ensureKdTreesConstructed();
            results = this.partialComputeFromObservations(0, this.totalObservations, bl);
        }
        return results;
    }

    /** Runs the GPU estimator without surrogates. */
    protected double[] gpuComputeFromObservations(int n, int n2, boolean bl) throws Exception {
        return this.gpuComputeFromObservations(n, n2, bl, 0, null);
    }

    /** Runs the GPU estimator with one surrogate per supplied permutation. */
    protected double[] gpuComputeFromObservations(int n, int n2, boolean bl, int[][] nArray) throws Exception {
        return this.gpuComputeFromObservations(n, n2, bl, nArray.length, nArray);
    }

    // Native entry point in the CUDA library. The arguments follow the call in
    // gpuComputeFromObservations: N, source, source dims, dest, dest dims, k, exclusion window,
    // return locals, norm flag, algorithm-1 flag, surrogate count, whether permutations are
    // supplied, and the permutations.
    private native double[] MIKraskov(int var1, double[][] var2, int var3, double[][] var4, int var5, int var6, int var7, boolean var8, boolean var9, boolean var10, int var11, boolean var12, int[][] var13);

    /**
     * Lazily builds the joint kd-tree and the two marginal nearest-neighbour searchers from
     * the current observations, applying the configured norm type to each.
     */
    protected void ensureKdTreesConstructed() throws Exception {
        if (this.kdTreeJoint == null) {
            this.kdTreeJoint = new KdTree(new int[]{this.dimensionsSource, this.dimensionsDest}, new double[][][]{this.sourceObservations, this.destObservations});
            this.kdTreeJoint.setNormType(this.normType);
        }
        if (this.nnSearcherSource == null) {
            this.nnSearcherSource = NearestNeighbourSearcher.create(this.sourceObservations);
            this.nnSearcherSource.setNormType(this.normType);
        }
        if (this.nnSearcherDest == null) {
            this.nnSearcherDest = NearestNeighbourSearcher.create(this.destObservations);
            this.nnSearcherDest.setNormType(this.normType);
        }
    }

    /**
     * Loads the native GPU library once. It is loaded from the jar when no path is
     * configured, otherwise from {@link #gpuLibraryPath}.
     *
     * @throws Exception describing how to obtain the library; the loading failure is
     *         attached as the cause
     */
    protected void ensureCudaLibraryLoaded() throws Exception {
        if (this.cudaLibraryLoaded) {
            return;
        }
        try {
            if (this.gpuLibraryPath.length() < 1) {
                NativeUtils.loadLibraryFromJar("/cuda/libKraskov.so");
            } else {
                System.load(this.gpuLibraryPath);
            }
        }
        catch (Throwable throwable) {
            String message = "GPU library not found. To compile GPU code set the enablegpu flag to true in build.xml";
            if (this.gpuLibraryPath.length() > 0) {
                message = message + "\nGPU library was not found in the path provided. Provide full path including library file name.";
                message = message + "\nExample: /home/johndoe/myfolder/libKraskov.so";
            }
            throw new Exception(message, throwable);
        }
        this.cudaLibraryLoaded = true;
    }

    /**
     * Computes a surrogate distribution with {@code n} surrogates. On the GPU this is a single
     * native call, which is expected to return the actual value at index 0 followed by one
     * value per surrogate. Otherwise the superclass implementation is used.
     */
    @Override
    public EmpiricalMeasurementDistribution computeSignificance(int n) throws Exception {
        if (this.useGPU) {
            double[] results = this.gpuComputeFromObservations(0, this.totalObservations, false, n, null);
            return new EmpiricalMeasurementDistribution(MatrixUtils.select(results, 1, results.length - 1), results[0]);
        }
        return super.computeSignificance(n);
    }

    /**
     * Computes a surrogate distribution using the given permutations. On the GPU this is a
     * single native call, which is expected to return the actual value at index 0 followed by
     * one value per permutation. Otherwise the superclass implementation is used.
     */
    @Override
    public EmpiricalMeasurementDistribution computeSignificance(int[][] nArray) throws Exception {
        if (this.useGPU) {
            double[] results = this.gpuComputeFromObservations(0, this.totalObservations, false, nArray.length, nArray);
            return new EmpiricalMeasurementDistribution(MatrixUtils.select(results, 1, results.length - 1), results[0]);
        }
        return super.computeSignificance(nArray);
    }

    /** Computes the prediction errors using the calculator's own {@code k} neighbours. */
    public double[] computePredictionErrorsFromObservations(boolean bl) throws Exception {
        return this.computePredictionErrorsFromObservations(bl, this.k);
    }

    /**
     * Computes the summed squared error, per dimension, of predicting one variable by
     * averaging its values over the k nearest neighbours found in the other variable's
     * space.
     *
     * @param bl {@code true} to predict the source (variable 1) from the destination (variable
     *        2); {@code false} for the reverse
     * @param n  number of nearest neighbours to average over
     * @return per-dimension sums of squared prediction errors
     * @throws Exception if any partial computation fails
     */
    public double[] computePredictionErrorsFromObservations(boolean bl, int n) throws Exception {
        int numSamples = this.sourceObservations.length;
        double[] errors;
        this.ensureKdTreesConstructed();
        if (this.numThreads <= 1) {
            // Also covers invalid (non-positive) thread counts.
            errors = this.partialComputePredictionErrorFromObservations(0, numSamples, n, bl);
        } else {
            errors = new double[bl ? this.dimensionsSource : this.dimensionsDest];
            int perThread = numSamples / this.numThreads;
            int remainder = numSamples % this.numThreads;
            if (this.debug) {
                System.out.printf("Computing prediction errors for variable %d from variable %d with %d threads (%d timesteps each, plus %d residual)\n", bl ? 1 : 2, bl ? 2 : 1, this.numThreads, perThread, remainder);
            }
            Thread[] threads = new Thread[this.numThreads];
            MiKraskovPredictionThreadRunner[] runners = new MiKraskovPredictionThreadRunner[this.numThreads];
            for (int t = 0; t < this.numThreads; ++t) {
                int start = t == 0 ? 0 : perThread * t + remainder;
                int count = t == 0 ? perThread + remainder : perThread;
                if (this.debug) {
                    System.out.println(t + ".Thread: from " + start + " to " + (start + count));
                }
                runners[t] = new MiKraskovPredictionThreadRunner(this, start, count, n, bl);
                threads[t] = new Thread(runners[t]);
                threads[t].start();
            }
            // Join all workers before collecting, so that no worker is left running on failure.
            for (Thread thread : threads) {
                thread.join();
            }
            for (MiKraskovPredictionThreadRunner runner : runners) {
                MatrixUtils.addInPlace(errors, runner.getReturnValues());
            }
        }
        if (this.debug) {
            System.out.printf("Total prediction error from variable %d to variable %d=", bl ? 2 : 1, bl ? 1 : 2);
            MatrixUtils.printArray(System.out, 3, errors);
        }
        return errors;
    }

    /**
     * Computes the per-dimension summed squared prediction error for time points
     * {@code [n, n + n2)}. For each time point:
     * <ol>
     *   <li>find {@code n3} neighbours in the predictor's space, honouring the dynamic
     *       correlation exclusion window;</li>
     *   <li>average the target variable's values at those neighbours;</li>
     *   <li>accumulate the squared difference from the actual target value.</li>
     * </ol>
     *
     * @param n  first time point
     * @param n2 number of time points
     * @param n3 number of nearest neighbours
     * @param bl {@code true} to predict the source from destination-space neighbours;
     *        {@code false} for the reverse
     * @return per-dimension sums of squared errors over the range
     */
    protected double[] partialComputePredictionErrorFromObservations(int n, int n2, int n3, boolean bl) throws Exception {
        long startMillis = System.currentTimeMillis();
        double[][] targetObservations = bl ? this.sourceObservations : this.destObservations;
        NearestNeighbourSearcher predictorSearcher = bl ? this.nnSearcherDest : this.nnSearcherSource;
        int targetDims = bl ? this.dimensionsSource : this.dimensionsDest;
        double[] sumSquaredErrors = new double[targetDims];
        for (int t = n; t < n + n2; ++t) {
            double[] actual = targetObservations[t];
            PriorityQueue<NeighbourNodeData> neighbours = predictorSearcher.findKNearestNeighbours(n3, t, this.dynCorrExclTime);
            double[] prediction = new double[targetDims];
            for (NeighbourNodeData neighbour : neighbours) {
                double[] neighbourValues = targetObservations[neighbour.sampleIndex];
                for (int d = 0; d < targetDims; ++d) {
                    prediction[d] += neighbourValues[d];
                }
            }
            for (int d = 0; d < targetDims; ++d) {
                prediction[d] /= (double) n3;
                double error = actual[d] - prediction[d];
                sumSquaredErrors[d] += error * error;
            }
        }
        if (this.debug) {
            long endMillis = System.currentTimeMillis();
            System.out.println("Subset " + n + ":" + (n + n2) + " Calculation time: " + (double) (endMillis - startMillis) / 1000.0 + " sec");
        }
        return sumSquaredErrors;
    }

    /**
     * Worker that runs {@code partialComputePredictionErrorFromObservations} on one range of
     * time points. Any exception is stored and rethrown by {@link #getReturnValues()}.
     * Results must only be read after the thread has been joined.
     */
    private static class MiKraskovPredictionThreadRunner
    implements Runnable {
        protected MutualInfoCalculatorMultiVariateKraskov miCalc;
        protected int myStartTimePoint;
        protected int numberOfTimePoints;
        protected int kNNs;
        protected boolean predictFirstVariable;
        protected double[] returnValues = null;
        protected Exception problem = null;

        public MiKraskovPredictionThreadRunner(MutualInfoCalculatorMultiVariateKraskov calculator, int startTimePoint, int numTimePoints, int kNearest, boolean predictFirst) {
            this.miCalc = calculator;
            this.myStartTimePoint = startTimePoint;
            this.numberOfTimePoints = numTimePoints;
            this.kNNs = kNearest;
            this.predictFirstVariable = predictFirst;
        }

        public double[] getReturnValues() throws Exception {
            if (this.problem != null) {
                throw this.problem;
            }
            return this.returnValues;
        }

        @Override
        public void run() {
            try {
                this.returnValues = this.miCalc.partialComputePredictionErrorFromObservations(this.myStartTimePoint, this.numberOfTimePoints, this.kNNs, this.predictFirstVariable);
            }
            catch (Exception exception) {
                this.problem = exception;
            }
        }
    }

    /**
     * Worker that runs {@code partialComputeFromObservations} on one range of time points.
     * Any exception is stored and rethrown by {@link #getReturnValues()}. Results must only be
     * read after the thread has been joined. The constants give the layout of the
     * non-local partial-sum array.
     */
    private static class MiKraskovThreadRunner
    implements Runnable {
        protected MutualInfoCalculatorMultiVariateKraskov miCalc;
        protected int myStartTimePoint;
        protected int numberOfTimePoints;
        protected boolean computeLocals;
        protected double[] returnValues = null;
        protected Exception problem = null;
        public static final int INDEX_SUM_DIGAMMAS = 0;
        public static final int INDEX_SUM_NX = 1;
        public static final int INDEX_SUM_NY = 2;
        public static final int RETURN_ARRAY_LENGTH = 3;

        public MiKraskovThreadRunner(MutualInfoCalculatorMultiVariateKraskov calculator, int startTimePoint, int numTimePoints, boolean locals) {
            this.miCalc = calculator;
            this.myStartTimePoint = startTimePoint;
            this.numberOfTimePoints = numTimePoints;
            this.computeLocals = locals;
        }

        public double[] getReturnValues() throws Exception {
            if (this.problem != null) {
                throw this.problem;
            }
            return this.returnValues;
        }

        @Override
        public void run() {
            try {
                this.returnValues = this.miCalc.partialComputeFromObservations(this.myStartTimePoint, this.numberOfTimePoints, this.computeLocals);
            }
            catch (Exception exception) {
                this.problem = exception;
            }
        }
    }
}

