/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.reasoner.admm.term;

import java.util.Collection;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.rule.arithmetic.AbstractArithmeticRule;
import org.linqs.psl.reasoner.admm.term.ADMMObjectiveTerm;
import org.linqs.psl.reasoner.admm.term.LocalVariable;
import org.linqs.psl.reasoner.function.FunctionComparator;
import org.linqs.psl.reasoner.term.Hyperplane;
import org.linqs.psl.reasoner.term.HyperplaneTermGenerator;
import org.linqs.psl.reasoner.term.ReasonerLocalVariable;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.util.MathUtils;

/**
 * Term generator that turns ground rules into {@link ADMMObjectiveTerm}s defined over
 * {@link LocalVariable}s.
 *
 * <p>The DeTer-related switches read here ({@code addDeterTerms}, {@code collectiveDeter},
 * {@code deterWeight}, {@code deterEpsilon}, {@code deterConstant}) are not declared in this
 * class; they are inherited from the {@link HyperplaneTermGenerator} hierarchy.
 */
public class ADMMTermGenerator
extends HyperplaneTermGenerator<ADMMObjectiveTerm, LocalVariable> {
    /**
     * Creates a generator with {@code mergeConstants} set to {@code true}.
     */
    public ADMMTermGenerator() {
        this(true);
    }

    /**
     * Creates a generator.
     *
     * @param mergeConstants passed unchanged to the {@link HyperplaneTermGenerator} constructor.
     */
    public ADMMTermGenerator(boolean mergeConstants) {
        super(mergeConstants);
    }

    /**
     * @return {@code LocalVariable.class}, the variable type used by ADMM terms.
     */
    @Override
    public Class<LocalVariable> getLocalVariableType() {
        return LocalVariable.class;
    }

    /**
     * Adds exactly one loss term for {@code hyperplane}. The term's shape is selected from the
     * two flags: squared hinge, hinge, squared linear, or linear.
     *
     * <p>{@code termStore} is not consulted by this implementation.
     *
     * @return the number of terms added to {@code newTerms} (always 1).
     */
    @Override
    public int createLossTerm(Collection<ADMMObjectiveTerm> newTerms, TermStore<ADMMObjectiveTerm, LocalVariable> termStore, boolean isHinge, boolean isSquared, GroundRule groundRule, Hyperplane<LocalVariable> hyperplane) {
        Rule rule = groundRule.getRule();

        if (isHinge) {
            if (isSquared) {
                newTerms.add(ADMMObjectiveTerm.createSquaredHingeLossTerm(hyperplane, rule));
            } else {
                newTerms.add(ADMMObjectiveTerm.createHingeLossTerm(hyperplane, rule));
            }
        } else {
            if (isSquared) {
                newTerms.add(ADMMObjectiveTerm.createSquaredLinearLossTerm(hyperplane, rule));
            } else {
                newTerms.add(ADMMObjectiveTerm.createLinearLossTerm(hyperplane, rule));
            }
        }

        return 1;
    }

    /**
     * Adds a linear constraint term for {@code hyperplane}. When enabled, also adds DeTer terms.
     *
     * <p>Extra terms are added only when {@code addDeterTerms} is set and the ground rule's rule
     * is an {@link AbstractArithmeticRule} whose expression reports
     * {@code looksLikeFunctionalConstraint()}. In that case one of the following is added:
     * <ul>
     *   <li>one collective DeTer term over the whole hyperplane, if {@code collectiveDeter} is
     *       set; or</li>
     *   <li>one independent DeTer term per hyperplane variable. Each is built on a
     *       single-variable hyperplane. If {@code deterConstant} is zero, {@code 1 / size} is
     *       used instead.</li>
     * </ul>
     *
     * <p>{@code termStore} is not consulted by this implementation.
     *
     * @return the number of terms added to {@code newTerms}: 1 (constraint only), 2 (constraint
     *     plus collective DeTer term), or {@code 1 + hyperplane.size()} (constraint plus
     *     independent DeTer terms).
     */
    @Override
    public int createLinearConstraintTerm(Collection<ADMMObjectiveTerm> newTerms, TermStore<ADMMObjectiveTerm, LocalVariable> termStore, GroundRule groundRule, Hyperplane<LocalVariable> hyperplane, FunctionComparator comparator) {
        Rule rawRule = groundRule.getRule();
        newTerms.add(ADMMObjectiveTerm.createLinearConstraintTerm(hyperplane, rawRule, comparator));

        if (!this.addDeterTerms || !isFunctionalArithmeticRule(rawRule)) {
            return 1;
        }

        if (this.collectiveDeter) {
            newTerms.add(ADMMObjectiveTerm.createCollectiveDeterTerm(hyperplane, this.deterWeight, this.deterEpsilon));
            return 2;
        }

        int size = hyperplane.size();

        // A zero deterConstant is replaced by 1 / (number of variables in the hyperplane).
        float activeDeterConstant = this.deterConstant;
        if (MathUtils.isZero(activeDeterConstant)) {
            activeDeterConstant = 1.0f / (float)size;
        }

        for (int i = 0; i < size; i++) {
            // Single-variable hyperplane; the trailing arguments are presumably
            // coefficients {1}, constant 0, and size 1 (per the Hyperplane constructor).
            Hyperplane<LocalVariable> independentHyperplane = new Hyperplane<LocalVariable>(
                    new LocalVariable[]{hyperplane.getVariable(i)}, new float[]{1.0f}, 0.0f, 1);
            newTerms.add(ADMMObjectiveTerm.createIndependentDeterTerm(independentHyperplane, this.deterWeight, activeDeterConstant));
        }

        return 1 + size;
    }

    /**
     * Returns true when {@code rule} is an {@link AbstractArithmeticRule} whose expression
     * reports {@code looksLikeFunctionalConstraint()}. Returns false for {@code null} and for
     * any other rule type, because {@code instanceof} is false for {@code null}.
     */
    private static boolean isFunctionalArithmeticRule(Rule rule) {
        if (!(rule instanceof AbstractArithmeticRule)) {
            return false;
        }
        return ((AbstractArithmeticRule)rule).getExpression().looksLikeFunctionalConstraint();
    }
}

