/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.reasoner.admm.term;

import java.util.HashSet;
import org.linqs.psl.application.groundrulestore.GroundRuleStore;
import org.linqs.psl.config.Config;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.UnweightedGroundRule;
import org.linqs.psl.model.rule.WeightedGroundRule;
import org.linqs.psl.model.rule.WeightedRule;
import org.linqs.psl.reasoner.admm.term.ADMMObjectiveTerm;
import org.linqs.psl.reasoner.admm.term.ADMMTermStore;
import org.linqs.psl.reasoner.admm.term.HingeLossTerm;
import org.linqs.psl.reasoner.admm.term.Hyperplane;
import org.linqs.psl.reasoner.admm.term.LinearConstraintTerm;
import org.linqs.psl.reasoner.admm.term.LinearLossTerm;
import org.linqs.psl.reasoner.admm.term.LocalVariable;
import org.linqs.psl.reasoner.admm.term.SquaredHingeLossTerm;
import org.linqs.psl.reasoner.admm.term.SquaredLinearLossTerm;
import org.linqs.psl.reasoner.function.ConstraintTerm;
import org.linqs.psl.reasoner.function.FunctionTerm;
import org.linqs.psl.reasoner.function.GeneralFunction;
import org.linqs.psl.reasoner.term.TermGenerator;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.util.MathUtils;
import org.linqs.psl.util.Parallel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link TermGenerator} that converts the ground rules held in a
 * {@link GroundRuleStore} into {@link ADMMObjectiveTerm}s and adds them to an
 * {@link ADMMTermStore}.
 *
 * Weighted ground rules become one of the loss terms (hinge, squared hinge,
 * linear, squared linear) and unweighted ground rules become linear constraint terms.
 * Ground rules with a negative weight are either negated and converted, or skipped,
 * depending on the {@link #INVERT_NEGATIVE_WEIGHTS_KEY} configuration option.
 */
public class ADMMTermGenerator
implements TermGenerator<ADMMObjectiveTerm> {
    private static final Logger log = LoggerFactory.getLogger(ADMMTermGenerator.class);

    /** Prefix shared by the configuration keys of this class. */
    public static final String CONFIG_PREFIX = "admmtermgenerator";

    /**
     * Configuration key deciding what happens to ground rules with a negative weight:
     * if true they are negated before terms are created, if false they are skipped.
     */
    public static final String INVERT_NEGATIVE_WEIGHTS_KEY = "admmtermgenerator.invertnegativeweights";

    /** Default for {@link #INVERT_NEGATIVE_WEIGHTS_KEY}. */
    public static final boolean INVERT_NEGATIVE_WEIGHTS_DEFAULT = false;

    // Read once at construction; use the declared constants rather than repeating the literals.
    private final boolean invertNegativeWeight = Config.getBoolean(INVERT_NEGATIVE_WEIGHTS_KEY, INVERT_NEGATIVE_WEIGHTS_DEFAULT);

    /**
     * Generate terms for every ground rule in the store without reserving any
     * variable capacity up front (delegates with an rvaCount of 0).
     *
     * @param ruleStore the ground rules to convert
     * @param termStore the store receiving the terms; must be an {@link ADMMTermStore}
     * @return the number of terms added to the term store
     * @throws IllegalArgumentException if termStore is not an {@link ADMMTermStore}
     */
    @Override
    public int generateTerms(GroundRuleStore ruleStore, TermStore<ADMMObjectiveTerm> termStore) {
        return this.generateTerms(ruleStore, termStore, 0);
    }

    /**
     * Generate terms for every ground rule in the store.
     * The ground rules are handed to {@link Parallel#foreach}, so terms are created
     * and added to the term store from within that call's workers.
     *
     * @param ruleStore the ground rules to convert
     * @param termStore the store receiving the terms; must be an {@link ADMMTermStore}
     * @param rvaCount passed to {@link ADMMTermStore#ensureVariableCapacity} before generation
     * @return the growth in size of the term store caused by this call
     * @throws IllegalArgumentException if termStore is not an {@link ADMMTermStore}
     */
    public int generateTerms(GroundRuleStore ruleStore, final TermStore<ADMMObjectiveTerm> termStore, int rvaCount) {
        if (!(termStore instanceof ADMMTermStore)) {
            throw new IllegalArgumentException("ADMMTermGenerator requires an ADMMTermStore");
        }

        int initialSize = termStore.size();
        termStore.ensureCapacity(initialSize + ruleStore.size());
        ((ADMMTermStore)termStore).ensureVariableCapacity(rvaCount);

        // The warning states that the rule is being skipped, which is only true when
        // inversion is disabled. When inversion is enabled the rule is negated instead,
        // so neither the scan nor the warning applies.
        if (!this.invertNegativeWeight) {
            this.warnAboutSkippedNegativeWeights(ruleStore);
        }

        Parallel.foreach(ruleStore.getGroundRules(), new Parallel.Worker<GroundRule>(){

            @Override
            public void work(int index, GroundRule rule) {
                boolean negativeWeight = rule instanceof WeightedGroundRule
                        && ((WeightedGroundRule)rule).getWeight() < 0.0;

                if (!negativeWeight) {
                    ADMMTermGenerator.this.addTerm(rule, rule, termStore);
                    return;
                }

                if (!ADMMTermGenerator.this.invertNegativeWeight) {
                    // Negative weight and inversion disabled: skip this ground rule.
                    return;
                }

                // Terms are built from the negated rules, but registered under the original rule.
                for (GroundRule negatedRule : rule.negate()) {
                    ADMMTermGenerator.this.addTerm(rule, negatedRule, termStore);
                }
            }
        });

        return termStore.size() - initialSize;
    }

    /**
     * Log one warning for each distinct weighted rule (the parent of at least one
     * weighted ground rule in the store) whose weight is negative.
     * Only meaningful when negative weights are not being inverted.
     */
    private void warnAboutSkippedNegativeWeights(GroundRuleStore ruleStore) {
        // Collect parent rules in a set so each rule is reported once,
        // no matter how many groundings it has.
        HashSet<WeightedRule> rules = new HashSet<WeightedRule>();
        for (GroundRule groundRule : ruleStore.getGroundRules()) {
            if (groundRule instanceof WeightedGroundRule) {
                rules.add((WeightedRule)groundRule.getRule());
            }
        }

        for (WeightedRule weightedRule : rules) {
            if (weightedRule.getWeight() < 0.0) {
                log.warn("Found a rule with a negative weight, but config says not to invert it... skipping: {}", weightedRule);
            }
        }
    }

    /**
     * Create a term from termSource and, if it is non-null and non-empty,
     * add it to the term store keyed by owner.
     *
     * @param owner the ground rule the term is registered under
     * @param termSource the ground rule the term is built from (the owner itself, or one of its negations)
     * @param termStore the store receiving the term; already verified to be an {@link ADMMTermStore}
     */
    private void addTerm(GroundRule owner, GroundRule termSource, TermStore<ADMMObjectiveTerm> termStore) {
        ADMMObjectiveTerm term = this.createTerm(termSource, (ADMMTermStore)termStore);
        if (term != null && term.size() > 0) {
            termStore.add(owner, term);
        }
    }

    /**
     * Build the objective term for a single ground rule.
     *
     * @param groundRule the ground rule to convert
     * @param termStore the store used to create local variables
     * @return the new term, or null if no hyperplane could be built for the rule
     *         (see {@link #processHyperplane})
     * @throws IllegalArgumentException if the ground rule is neither weighted nor unweighted
     */
    private ADMMObjectiveTerm createTerm(GroundRule groundRule, ADMMTermStore termStore) {
        if (groundRule instanceof WeightedGroundRule) {
            GeneralFunction function = ((WeightedGroundRule)groundRule).getFunctionDefinition();
            Hyperplane hyperplane = this.processHyperplane(function, termStore);
            if (hyperplane == null) {
                return null;
            }

            // The kind of loss is chosen by the (nonNegative, squared) flags of the function.
            if (function.isNonNegative()) {
                if (function.isSquared()) {
                    return new SquaredHingeLossTerm(groundRule, hyperplane);
                }
                return new HingeLossTerm(groundRule, hyperplane);
            }

            if (function.isSquared()) {
                // NOTE(review): the constant is discarded for squared linear loss;
                // the reason is not visible in this file -- confirm against SquaredLinearLossTerm.
                hyperplane.setConstant(0.0f);
                return new SquaredLinearLossTerm(groundRule, hyperplane);
            }

            return new LinearLossTerm(groundRule, hyperplane);
        }

        if (groundRule instanceof UnweightedGroundRule) {
            ConstraintTerm constraint = ((UnweightedGroundRule)groundRule).getConstraintDefinition();
            GeneralFunction function = constraint.getFunction();
            Hyperplane hyperplane = this.processHyperplane(function, termStore);
            if (hyperplane == null) {
                return null;
            }

            // Fold the constraint's value into the hyperplane's constant.
            hyperplane.setConstant(constraint.getValue() + hyperplane.getConstant());
            return new LinearConstraintTerm(groundRule, hyperplane, constraint.getComparator());
        }

        throw new IllegalArgumentException("Unsupported ground rule: " + groundRule);
    }

    /**
     * Convert a general function (a sum of coefficient/term pairs plus a constant)
     * into a {@link Hyperplane} over local variables.
     *
     * The hyperplane's constant starts as the negated constant of the sum, and each
     * constant summand subtracts (coefficient * value) from it. Each random variable
     * atom contributes a local variable; when the same local variable is seen again,
     * its coefficient is passed to {@link Hyperplane#appendCoefficient} instead of adding a new term.
     *
     * @param sum the function to convert
     * @param termStore the store used to create local variables
     * @return the hyperplane, or null if the sum is non-negative and a repeated variable
     *         appears with coefficients whose signs do not match
     * @throws IllegalArgumentException if a summand is neither a random variable atom nor a constant
     */
    private Hyperplane processHyperplane(GeneralFunction sum, ADMMTermStore termStore) {
        Hyperplane hyperplane = new Hyperplane(sum.size(), -1.0f * sum.getConstant());

        for (int i = 0; i < sum.size(); ++i) {
            float coefficient = sum.getCoefficient(i);
            FunctionTerm term = sum.getTerm(i);

            if (term instanceof RandomVariableAtom) {
                LocalVariable variable = termStore.createLocalVariable((RandomVariableAtom)term);
                int localIndex = hyperplane.indexOfVariable(variable);

                if (localIndex == -1) {
                    // First occurrence of this variable in the hyperplane.
                    hyperplane.addTerm(variable, coefficient);
                } else {
                    // Repeated variable: give up on the whole hyperplane when the sum is
                    // non-negative and the signs of the two coefficients disagree.
                    if (sum.isNonNegative() && !MathUtils.signsMatch(hyperplane.getCoefficient(localIndex), coefficient)) {
                        return null;
                    }
                    hyperplane.appendCoefficient(localIndex, coefficient);
                }
            } else if (term.isConstant()) {
                hyperplane.setConstant(hyperplane.getConstant() - coefficient * term.getValue());
            } else {
                throw new IllegalArgumentException("Unexpected summand: " + sum + "[" + i + "] (" + term + ").");
            }
        }

        return hyperplane;
    }
}

