/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.reasoner.admm.term;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Semaphore;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.WeightedGroundRule;
import org.linqs.psl.reasoner.admm.term.ADMMObjectiveTerm;
import org.linqs.psl.reasoner.admm.term.Hyperplane;
import org.linqs.psl.reasoner.admm.term.LocalVariable;
import org.linqs.psl.util.FloatMatrix;
import org.linqs.psl.util.HashCode;

/**
 * Base class for ADMM objective terms built on a weighted squared hyperplane.
 *
 * <p>{@link #minWeightedSquaredHyperplane} sets the local variables x to the solution of
 * (2 * weight * c * c^T + stepSize * I) x = stepSize * z - lagrange + 2 * weight * c * constant.
 * Here c is {@link #coefficients} and z is the consensus value of each variable. That system
 * is the stationarity condition of
 * weight * (c . x - constant)^2 + (stepSize / 2) * ||x - z + lagrange / stepSize||^2.
 * How subclasses combine this with their own objective is decided outside this class.
 */
public abstract class SquaredHyperplaneTerm extends ADMMObjectiveTerm {
    protected final float[] coefficients;
    protected final float constant;

    /**
     * Cholesky factor used by the general (size &gt; 2) solve, or null until first needed.
     * The same instance may be shared with other terms through lowerTriangleCache,
     * so it is only ever read, never modified.
     */
    private FloatMatrix lowerTriangle;

    /** Weight that {@link #lowerTriangle} was factored with. */
    private float lowerTriangleWeight;

    /** Step size that {@link #lowerTriangle} was factored with. */
    private float lowerTriangleStepSize;

    /**
     * Factorizations shared by all terms with identical weight, step size, and coefficients.
     * Concurrent so that lock-free lookups are safe while another thread inserts, and so that
     * a published matrix is fully visible to readers.
     */
    private static final Map<FactorizationKey, FloatMatrix> lowerTriangleCache =
            new ConcurrentHashMap<FactorizationKey, FloatMatrix>();

    /** Serializes factorization so each distinct key is computed at most once. */
    private static final Semaphore matrixSemaphore = new Semaphore(1);

    public SquaredHyperplaneTerm(GroundRule groundRule, Hyperplane hyperplane) {
        super(hyperplane, groundRule);
        this.coefficients = hyperplane.getCoefficients();
        this.constant = hyperplane.getConstant();
        this.lowerTriangle = null;
    }

    /**
     * Loads (from the shared cache, or by computing it) the Cholesky factor for the given
     * weight and step size, and records which parameters it belongs to.
     */
    private void initLowerTriangle(float weight, float stepSize) {
        FactorizationKey key = new FactorizationKey(weight, stepSize, this.coefficients, this.size);

        // Fast path: an unlocked lookup is safe on a ConcurrentHashMap.
        FloatMatrix factor = lowerTriangleCache.get(key);
        if (factor == null) {
            factor = this.computeLowerTriangle(key);
        }

        this.lowerTriangle = factor;
        this.lowerTriangleWeight = weight;
        this.lowerTriangleStepSize = stepSize;
    }

    /**
     * Builds 2 * weight * c * c^T + stepSize * I, factors it in place, and caches it.
     * The cache is re-checked under the semaphore because another thread may have inserted
     * the same key after our unlocked miss.
     */
    private FloatMatrix computeLowerTriangle(FactorizationKey key) {
        matrixSemaphore.acquireUninterruptibly();
        try {
            FloatMatrix existing = lowerTriangleCache.get(key);
            if (existing != null) {
                return existing;
            }

            float[] coeffs = key.coefficients;
            int n = coeffs.length;
            FloatMatrix matrix = FloatMatrix.zeroes(n, n);
            for (int i = 0; i < n; ++i) {
                matrix.set(i, i, 2.0f * key.weight * coeffs[i] * coeffs[i] + key.stepSize);
                for (int j = i + 1; j < n; ++j) {
                    float coeff = 2.0f * key.weight * coeffs[i] * coeffs[j];
                    matrix.set(i, j, coeff);
                    matrix.set(j, i, coeff);
                }
            }

            matrix.choleskyDecomposition(true);
            lowerTriangleCache.put(key, matrix);
            return matrix;
        } finally {
            matrixSemaphore.release();
        }
    }

    /**
     * Returns coefficients . x - constant using the current local variable values.
     * The value is neither squared nor weighted here.
     */
    @Override
    public float evaluate() {
        float value = 0.0f;
        for (int i = 0; i < this.size; ++i) {
            value += this.coefficients[i] * this.variables[i].getValue();
        }
        return value - this.constant;
    }

    /**
     * Overwrites each local variable's value with the solution of
     * (2 * weight * c * c^T + stepSize * I) x = b,
     * where b_i = stepSize * z_i - lagrange_i + 2 * weight * c_i * constant.
     *
     * <p>Sizes 1 and 2 are solved in closed form. Larger sizes use a cached Cholesky factor,
     * which is recomputed whenever the weight or step size differs from the one it was built for.
     *
     * @param stepSize ADMM step size (must be non-zero; it is used as a divisor)
     * @param consensusValues consensus values indexed by each variable's global id
     */
    protected void minWeightedSquaredHyperplane(float stepSize, float[] consensusValues) {
        float weight = (float)((WeightedGroundRule)this.groundRule).getWeight();

        // Store the right-hand side b in the variables; the solves below work in place on it.
        for (int i = 0; i < this.size; ++i) {
            LocalVariable variable = this.variables[i];
            float value = stepSize * (consensusValues[variable.getGlobalId()] - variable.getLagrange() / stepSize);
            value += 2.0f * weight * this.coefficients[i] * this.constant;
            variable.setValue(value);
        }

        if (this.size == 1) {
            LocalVariable variable = this.variables[0];
            float coeff = this.coefficients[0];
            variable.setValue(variable.getValue() / (2.0f * weight * coeff * coeff + stepSize));
            return;
        }

        if (this.size == 2) {
            // 2x2 symmetric system [[a0, a1b0], [a1b0, b1]] solved by Gaussian elimination.
            LocalVariable variable0 = this.variables[0];
            LocalVariable variable1 = this.variables[1];
            float coeff0 = this.coefficients[0];
            float coeff1 = this.coefficients[1];
            float a0 = 2.0f * weight * coeff0 * coeff0 + stepSize;
            float b1 = 2.0f * weight * coeff1 * coeff1 + stepSize;
            float a1b0 = 2.0f * weight * coeff0 * coeff1;
            variable1.setValue(variable1.getValue() - a1b0 * variable0.getValue() / a0);
            variable1.setValue(variable1.getValue() / (b1 - a1b0 * a1b0 / a0));
            variable0.setValue((variable0.getValue() - a1b0 * variable1.getValue()) / a0);
            return;
        }

        // A factor built for a different weight or step size would give a wrong solution.
        if (this.lowerTriangle == null
                || weight != this.lowerTriangleWeight
                || stepSize != this.lowerTriangleStepSize) {
            this.initLowerTriangle(weight, stepSize);
        }

        // Forward substitution: L y = b (only entries on or below the diagonal are read).
        for (int row = 0; row < this.size; ++row) {
            float newValue = this.variables[row].getValue();
            for (int col = 0; col < row; ++col) {
                newValue -= this.lowerTriangle.get(row, col) * this.variables[col].getValue();
            }
            this.variables[row].setValue(newValue / this.lowerTriangle.get(row, row));
        }

        // Back substitution: L^T x = y, reading L transposed.
        for (int row = this.size - 1; row >= 0; --row) {
            float newValue = this.variables[row].getValue();
            for (int col = this.size - 1; col > row; --col) {
                newValue -= this.lowerTriangle.get(col, row) * this.variables[col].getValue();
            }
            this.variables[row].setValue(newValue / this.lowerTriangle.get(row, row));
        }
    }

    /**
     * Exact cache key for a factorization: weight, step size, and the first {@code size}
     * coefficients. Equality compares the values themselves, so hash collisions can never
     * hand a term another term's matrix.
     */
    private static final class FactorizationKey {
        private final float weight;
        private final float stepSize;
        private final float[] coefficients;
        private final int hash;

        FactorizationKey(float weight, float stepSize, float[] coefficients, int size) {
            this.weight = weight;
            this.stepSize = stepSize;
            // Defensive copy, truncated to the terms actually used by the solve.
            this.coefficients = Arrays.copyOf(coefficients, size);

            int h = Float.floatToIntBits(weight);
            h = 31 * h + Float.floatToIntBits(stepSize);
            h = 31 * h + Arrays.hashCode(this.coefficients);
            this.hash = h;
        }

        @Override
        public boolean equals(Object other) {
            if (this == other) {
                return true;
            }
            if (!(other instanceof FactorizationKey)) {
                return false;
            }
            FactorizationKey that = (FactorizationKey) other;
            return this.hash == that.hash
                    && Float.floatToIntBits(this.weight) == Float.floatToIntBits(that.weight)
                    && Float.floatToIntBits(this.stepSize) == Float.floatToIntBits(that.stepSize)
                    && Arrays.equals(this.coefficients, that.coefficients);
        }

        @Override
        public int hashCode() {
            return this.hash;
        }
    }
}

