/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.reasoner.sgd.term;

import org.linqs.psl.config.Config;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.rule.GroundRule;
import org.linqs.psl.model.rule.WeightedGroundRule;
import org.linqs.psl.reasoner.function.FunctionComparator;
import org.linqs.psl.reasoner.sgd.term.SGDObjectiveTerm;
import org.linqs.psl.reasoner.term.Hyperplane;
import org.linqs.psl.reasoner.term.HyperplaneTermGenerator;
import org.linqs.psl.reasoner.term.TermStore;
import org.linqs.psl.reasoner.term.VariableTermStore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SGDTermGenerator
extends HyperplaneTermGenerator<SGDObjectiveTerm, RandomVariableAtom> {
    private static final Logger log = LoggerFactory.getLogger(SGDTermGenerator.class);
    private float learningRate = Config.getFloat("sgd.learningrate", 1.0f);

    @Override
    public Class<RandomVariableAtom> getLocalVariableType() {
        return RandomVariableAtom.class;
    }

    @Override
    public SGDObjectiveTerm createLossTerm(TermStore<SGDObjectiveTerm, RandomVariableAtom> baseTermStore, boolean isHinge, boolean isSquared, GroundRule groundRule, Hyperplane<RandomVariableAtom> hyperplane) {
        VariableTermStore termStore = (VariableTermStore)baseTermStore;
        float weight = (float)((WeightedGroundRule)groundRule).getWeight();
        return new SGDObjectiveTerm(termStore, isSquared, isHinge, hyperplane, weight, this.learningRate);
    }

    @Override
    public SGDObjectiveTerm createLinearConstraintTerm(TermStore<SGDObjectiveTerm, RandomVariableAtom> termStore, GroundRule groundRule, Hyperplane<RandomVariableAtom> hyperplane, FunctionComparator comparator) {
        log.warn("SGD does not support hard constraints, i.e. " + groundRule);
        return null;
    }
}

