/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.reasoner.sgd.term;

import java.nio.ByteBuffer;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.rule.AbstractRule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.reasoner.term.Hyperplane;
import org.linqs.psl.reasoner.term.VariableTermStore;
import org.linqs.psl.reasoner.term.streaming.StreamingTerm;

/**
 * A single objective term of the form
 * {@code weight * f(sum_i(coefficients[i] * x[variableIndexes[i]]) - constant)},
 * where {@code f} is optionally a hinge ({@code max(0, .)}) and/or a square.
 *
 * <p>Instances are mutable: {@link #read(ByteBuffer, ByteBuffer)} overwrites every field
 * of the term with values taken from a buffer, reusing the backing arrays when they are
 * large enough. Because of that reuse, the backing arrays may be longer than {@link #size()};
 * only the first {@code size} entries are meaningful.
 */
public class SGDObjectiveTerm
implements StreamingTerm {
    private boolean squared;
    private boolean hinge;
    private WeightedRule rule;
    private float constant;
    // Number of meaningful entries in coefficients / variableIndexes.
    private short size;
    private float[] coefficients;
    private int[] variableIndexes;

    /**
     * Builds a term from a hyperplane.
     *
     * @param termStore used to translate each hyperplane variable into an integer index
     * @param rule the rule this term came from (supplies the weight)
     * @param squared whether the term is squared
     * @param hinge whether the term is hinged at zero
     * @param hyperplane supplies the size, coefficients, constant, and variables
     */
    public SGDObjectiveTerm(VariableTermStore<SGDObjectiveTerm, GroundAtom> termStore, WeightedRule rule, boolean squared, boolean hinge, Hyperplane<GroundAtom> hyperplane) {
        this.squared = squared;
        this.hinge = hinge;
        this.rule = rule;
        this.size = (short)hyperplane.size();
        // The coefficient array is taken as-is from the hyperplane (not copied).
        this.coefficients = hyperplane.getCoefficients();
        this.constant = hyperplane.getConstant();
        this.variableIndexes = new int[this.size];
        GroundAtom[] variables = (GroundAtom[])hyperplane.getVariables();
        for (int i = 0; i < this.size; ++i) {
            this.variableIndexes[i] = termStore.getVariableIndex(variables[i]);
        }
    }

    /**
     * @param i position within this term (0 to {@code size() - 1})
     * @return the variable index stored at position {@code i}
     */
    public int getVariableIndex(int i) {
        return this.variableIndexes[i];
    }

    /**
     * @return the number of variables in this term
     */
    @Override
    public int size() {
        return this.size;
    }

    /**
     * Updates the constant to {@code constant - oldValue + newValue}.
     */
    @Override
    public void adjustConstant(float oldValue, float newValue) {
        this.constant = this.constant - oldValue + newValue;
    }

    /**
     * @return always {@code true}
     */
    @Override
    public boolean isConvex() {
        return true;
    }

    /**
     * Computes the weighted value of this term.
     *
     * @param variableValues values indexed by the variable indexes held in this term
     * @return {@code weight * f(dot)}, where {@code f} applies the hinge (if enabled)
     *         and then the square (if enabled)
     */
    public float evaluate(float[] variableValues) {
        float dot = this.dot(variableValues);
        float weight = this.getWeight();

        // Apply the hinge first, then the square; this covers all four flag combinations.
        float value = this.hinge ? Math.max(0.0f, dot) : dot;
        if (this.squared) {
            value = (float)Math.pow(value, 2.0);
        }
        return weight * value;
    }

    /**
     * Computes the partial derivative of the weighted term with respect to one of its variables.
     *
     * @param varId index into this term's coefficient array (a position within the term)
     * @param dot the value previously computed by {@link #dot(float[])}
     * @param weight the weight to scale the partial by
     * @return {@code 0} when hinged and {@code dot <= 0}; otherwise
     *         {@code weight * 2 * dot * coefficient} when squared, or
     *         {@code weight * coefficient} when linear
     */
    public float computePartial(int varId, float dot, float weight) {
        if (this.hinge && dot <= 0.0f) {
            return 0.0f;
        }
        if (this.squared) {
            return weight * 2.0f * dot * this.coefficients[varId];
        }
        return weight * this.coefficients[varId];
    }

    /**
     * @param variableValues values indexed by the variable indexes held in this term
     * @return {@code sum_i(coefficients[i] * variableValues[variableIndexes[i]]) - constant}
     */
    public float dot(float[] variableValues) {
        float value = 0.0f;
        for (int i = 0; i < this.size; ++i) {
            value += this.coefficients[i] * variableValues[this.variableIndexes[i]];
        }
        return value - this.constant;
    }

    /**
     * @return the rule associated with this term
     */
    public WeightedRule getRule() {
        return this.rule;
    }

    /**
     * Returns the internal index array itself (not a copy). The array may be longer than
     * {@link #size()}; only the first {@code size()} entries are meaningful.
     */
    public int[] getVariableIndexes() {
        return this.variableIndexes;
    }

    /**
     * @return the number of bytes {@link #writeFixedValues(ByteBuffer)} writes for this term
     */
    @Override
    public int fixedByteSize() {
        // Header: squared (8) + hinge (8) + rule hash (32) + constant (32) + size (16) = 96 bits.
        // Per variable: coefficient (32) + variable index (32) = 64 bits.
        int bitSize = 96 + this.size * 64;
        return bitSize / 8;
    }

    /**
     * Serializes this term into {@code fixedBuffer}: the two flags as single bytes, the rule's
     * hash code, the constant, the size, and then one (coefficient, variable index) pair per
     * variable. {@link #read(ByteBuffer, ByteBuffer)} consumes the same layout.
     */
    @Override
    public void writeFixedValues(ByteBuffer fixedBuffer) {
        fixedBuffer.put((byte)(this.squared ? 1 : 0));
        fixedBuffer.put((byte)(this.hinge ? 1 : 0));
        fixedBuffer.putInt(this.rule.hashCode());
        fixedBuffer.putFloat(this.constant);
        fixedBuffer.putShort(this.size);
        for (int i = 0; i < this.size; ++i) {
            fixedBuffer.putFloat(this.coefficients[i]);
            fixedBuffer.putInt(this.variableIndexes[i]);
        }
    }

    /**
     * Overwrites this term with the values stored in {@code fixedBuffer}, in the layout
     * produced by {@link #writeFixedValues(ByteBuffer)}. The backing arrays are reused when
     * they are large enough and reallocated otherwise. {@code volatileBuffer} is not read.
     */
    @Override
    public void read(ByteBuffer fixedBuffer, ByteBuffer volatileBuffer) {
        this.squared = fixedBuffer.get() == 1;
        this.hinge = fixedBuffer.get() == 1;
        this.rule = (WeightedRule)AbstractRule.getRule(fixedBuffer.getInt());
        this.constant = fixedBuffer.getFloat();
        this.size = fixedBuffer.getShort();

        // The two arrays can have different lengths: the constructor takes the coefficient
        // array from the hyperplane but allocates the index array at exactly `size`.
        // Check each one on its own so a long coefficient array cannot mask a short index array.
        if (this.coefficients == null || this.coefficients.length < this.size) {
            this.coefficients = new float[this.size];
        }
        if (this.variableIndexes == null || this.variableIndexes.length < this.size) {
            this.variableIndexes = new int[this.size];
        }

        for (int i = 0; i < this.size; ++i) {
            this.coefficients[i] = fixedBuffer.getFloat();
            this.variableIndexes[i] = fixedBuffer.getInt();
        }
    }

    /**
     * Formats the term using {@code <index:N>} placeholders in place of variable values.
     */
    @Override
    public String toString() {
        return this.toString(null);
    }

    /**
     * Formats the term as a readable expression.
     *
     * @param termStore when non-null, each variable is rendered using
     *        {@code termStore.getVariableValue(index)}; when null, the raw index is rendered
     *        as {@code <index:N>}
     */
    public String toString(VariableTermStore<SGDObjectiveTerm, GroundAtom> termStore) {
        StringBuilder builder = new StringBuilder();
        builder.append(this.getWeight());
        builder.append(" * ");
        if (this.hinge) {
            builder.append("max(0.0, ");
        } else {
            builder.append("(");
        }
        for (int i = 0; i < this.size; ++i) {
            builder.append("(");
            builder.append(this.coefficients[i]);
            if (termStore == null) {
                builder.append(" * <index:");
                builder.append(this.variableIndexes[i]);
                builder.append(">)");
            } else {
                builder.append(" * ");
                builder.append(termStore.getVariableValue(this.variableIndexes[i]));
                builder.append(")");
            }
            // Separator between summands, but not after the last one.
            if (i != this.size - 1) {
                builder.append(" + ");
            }
        }
        builder.append(" - ");
        builder.append(this.constant);
        builder.append(")");
        if (this.squared) {
            builder.append(" ^2");
        }
        return builder.toString();
    }

    /**
     * @return the rule's weight when a rule is set and reports itself as weighted;
     *         otherwise {@link Float#POSITIVE_INFINITY}
     */
    private float getWeight() {
        if (this.rule != null && this.rule.isWeighted()) {
            return this.rule.getWeight();
        }
        return Float.POSITIVE_INFINITY;
    }
}

