/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.inference.online;

import java.util.List;
import java.util.Set;
import org.linqs.psl.application.inference.InferenceApplication;
import org.linqs.psl.application.inference.online.OnlineServer;
import org.linqs.psl.application.inference.online.messages.OnlineMessage;
import org.linqs.psl.application.inference.online.messages.actions.controls.Stop;
import org.linqs.psl.application.inference.online.messages.actions.model.AddAtom;
import org.linqs.psl.application.inference.online.messages.actions.model.QueryAtom;
import org.linqs.psl.application.inference.online.messages.responses.ActionStatus;
import org.linqs.psl.application.inference.online.messages.responses.QueryAtomResponse;
import org.linqs.psl.application.learning.weight.TrainingMap;
import org.linqs.psl.database.Database;
import org.linqs.psl.database.atom.OnlineAtomManager;
import org.linqs.psl.database.atom.PersistedAtomManager;
import org.linqs.psl.evaluation.statistics.Evaluator;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.reasoner.term.online.OnlineTermStore;
import org.linqs.psl.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for online inference.
 *
 * <p>After an initial optimization, the model is kept alive and driven by actions received
 * through an {@link OnlineServer}: {@link AddAtom} (insert or replace an atom),
 * {@link QueryAtom} (return an atom's current value) and {@link Stop} (end the session).
 * Re-optimization is lazy. It happens only when an atom is queried after the model was changed
 * by an {@link AddAtom} action.
 */
public abstract class OnlineInference extends InferenceApplication {
    private static final Logger log = LoggerFactory.getLogger(OnlineInference.class);

    // NOTE(review): these fields intentionally have no initializers; they are assigned in
    // initialize(). If initialize() is invoked from the InferenceApplication constructor (confirm),
    // field initializers would run afterwards and silently overwrite those values.

    /** Source of client actions and sink for their responses; null once closed. */
    private OnlineServer server;

    /** True when atoms were added since the last optimization, so the next query re-optimizes. */
    private boolean modelUpdates;

    /** Set by a {@link Stop} action or by {@link #close()}; terminates the action loop. */
    private boolean stopped;

    /** Objective value produced by the most recent optimization. */
    private double objective;

    // Evaluation arguments captured by internalInference() and reused for every re-optimization.
    private List<Evaluator> evaluators;
    private TrainingMap trainingMap;
    private Set<StandardPredicate> evaluationPredicates;

    protected OnlineInference(List<Rule> rules, Database database) {
        super(rules, database);
    }

    protected OnlineInference(List<Rule> rules, Database database, boolean relaxHardConstraints) {
        super(rules, database, relaxHardConstraints);
    }

    /**
     * Resets the online state, starts the action server and runs the parent initialization.
     * It then sizes the term store for every cached random-variable and observed atom.
     */
    @Override
    protected void initialize() {
        this.stopped = false;
        // The first call to optimize() must actually run the reasoner.
        this.modelUpdates = true;
        this.objective = 0.0;
        this.evaluators = null;
        this.trainingMap = null;
        this.evaluationPredicates = null;

        this.startServer();
        super.initialize();

        this.termStore.ensureVariableCapacity(this.atomManager.getCachedRVACount() + this.atomManager.getCachedObsCount());
    }

    /** Uses an {@link OnlineAtomManager} so atoms can be added and deleted while serving actions. */
    @Override
    protected PersistedAtomManager createAtomManager(Database database) {
        return new OnlineAtomManager(database, this.initialValue);
    }

    /** Stops the action loop, shuts down the server (if still open) and closes the parent. */
    @Override
    public void close() {
        this.stopped = true;
        this.closeServer();
        super.close();
    }

    /** Closes the server if it is open. Safe to call more than once. */
    private void closeServer() {
        if (this.server != null) {
            this.server.close();
            this.server = null;
        }
    }

    private void startServer() {
        this.server = new OnlineServer();
        this.server.start();
    }

    /** The atom manager as the {@link OnlineAtomManager} created by {@link #createAtomManager}. */
    private OnlineAtomManager onlineAtomManager() {
        return (OnlineAtomManager)this.atomManager;
    }

    /** The term store viewed as an {@link OnlineTermStore}; assumes subclasses supply one. */
    private OnlineTermStore onlineTermStore() {
        return (OnlineTermStore)this.termStore;
    }

    /**
     * Executes one action and reports a successful {@link ActionStatus} to the server.
     *
     * <p>Dispatch is on the exact runtime class, so subclasses of the supported actions are
     * rejected.
     *
     * @param action the action received from the server
     * @throws IllegalArgumentException if the action type is not supported
     */
    protected void executeAction(OnlineMessage action) {
        String response = null;
        if (action.getClass() == AddAtom.class) {
            response = this.doAddAtom((AddAtom)action);
        } else if (action.getClass() == QueryAtom.class) {
            response = this.doQueryAtom((QueryAtom)action);
        } else if (action.getClass() == Stop.class) {
            response = this.doStop();
        } else {
            throw new IllegalArgumentException("Unsupported action: " + action.getClass().getName() + ".");
        }

        this.server.onActionExecution(action, new ActionStatus(action, true, response));
    }

    /**
     * Adds an atom to the model, replacing it if the database already holds it.
     *
     * <p>An existing atom is first deleted from the atom manager and its variable is removed from
     * the term store. The new atom is observed when the action's partition name is
     * {@code "READ"} (case-insensitive), and a random variable otherwise. The model is then marked
     * as changed, so the next query re-optimizes.
     *
     * @param action the add request
     * @return a status message describing the added atom
     */
    protected String doAddAtom(AddAtom action) {
        OnlineAtomManager manager = this.onlineAtomManager();
        OnlineTermStore store = this.onlineTermStore();

        GroundAtom atom = null;
        if (this.atomManager.getDatabase().hasAtom(action.getPredicate(), action.getArguments())) {
            atom = manager.deleteAtom(action.getPredicate(), action.getArguments());
            store.deleteLocalVariable(atom);
        }

        if (action.getPartitionName().equalsIgnoreCase("READ")) {
            atom = manager.addObservedAtom(action.getPredicate(), action.getValue(), action.getArguments());
        } else {
            atom = manager.addRandomVariableAtom(action.getPredicate(), action.getValue(), action.getArguments());
        }
        store.createLocalVariable(atom);

        this.modelUpdates = true;
        return String.format("Added atom: %s", atom.toStringWithValue());
    }

    /**
     * Answers a query for a single atom's value.
     *
     * <p>A {@link QueryAtomResponse} is sent straight to the server. It carries the atom's value,
     * or {@code -1.0} if the atom is unknown. If the model changed since the last optimization, it
     * is re-optimized before the value is read. The returned string becomes the message of the
     * {@link ActionStatus} that {@link #executeAction} sends afterwards.
     *
     * @param action the query request
     * @return a status message saying whether the atom was found
     */
    protected String doQueryAtom(QueryAtom action) {
        if (!this.onlineAtomManager().hasAtom(action.getPredicate(), action.getArguments())) {
            this.server.onActionExecution(action, new QueryAtomResponse(action, -1.0));
            return String.format("Atom: %s not found.", describeAtom(action));
        }

        this.optimize();

        double atomValue = this.atomManager.getAtom(action.getPredicate(), action.getArguments()).getValue();
        this.server.onActionExecution(action, new QueryAtomResponse(action, atomValue));
        return String.format("Atom: %s found. Returned to client.", describeAtom(action));
    }

    /** Renders the queried atom as {@code predicate(arg1, arg2, ...)} for status messages. */
    private static String describeAtom(QueryAtom action) {
        // Cast to Object[] so the arguments are spread into the varargs, not wrapped as one element.
        return String.format("%s(%s)", action.getPredicate(), StringUtils.join(", ", (Object[])action.getArguments()));
    }

    /**
     * Requests termination of the action loop in {@link #internalInference}.
     *
     * @return a status message confirming the stop
     */
    protected String doStop() {
        this.stopped = true;
        return "OnlinePSL inference stopped.";
    }

    /**
     * Runs the reasoner over the term store and records the resulting objective. Does nothing if
     * the model has not changed since the last run.
     */
    private void optimize() {
        if (!this.modelUpdates) {
            return;
        }

        log.trace("Optimization Start");
        this.objective = this.reasoner.optimize(this.termStore, this.evaluators, this.trainingMap, this.evaluationPredicates);
        log.trace("Optimization End");

        this.modelUpdates = false;
    }

    /**
     * Returns a client-facing description of a failure that is never null.
     * {@link Throwable#getMessage()} is null for many runtime exceptions (e.g. a bare NPE).
     */
    private static String describeFailure(RuntimeException ex) {
        String message = ex.getMessage();
        return (message != null) ? message : ex.toString();
    }

    /**
     * Runs the initial optimization, then serves client actions until a {@link Stop} action or
     * {@link #close()}.
     *
     * <p>An {@link IllegalArgumentException} from an action is reported to the client as a failed
     * {@link ActionStatus}, and serving continues. Any other {@link RuntimeException} is reported
     * and then rethrown wrapped. In that case the server is left open until {@link #close()}.
     *
     * @return the objective of the most recent optimization
     * @throws RuntimeException if an action fails with anything other than an
     *     {@link IllegalArgumentException}
     */
    @Override
    public double internalInference(List<Evaluator> evaluators, TrainingMap trainingMap, Set<StandardPredicate> evaluationPredicates) {
        this.evaluators = evaluators;
        this.trainingMap = trainingMap;
        this.evaluationPredicates = evaluationPredicates;

        this.optimize();

        while (!this.stopped) {
            // NOTE(review): if getAction() can return null without blocking, this loop busy-spins;
            // confirm OnlineServer.getAction() semantics.
            OnlineMessage action = this.server.getAction();
            if (action == null) {
                continue;
            }

            try {
                // Parameterized logging: the action is only stringified when trace is enabled.
                log.trace("Executing action: {}", action);
                this.executeAction(action);
            } catch (IllegalArgumentException ex) {
                // Bad request (e.g. unsupported action type): report it and keep serving.
                this.server.onActionExecution(action, new ActionStatus(action, false, describeFailure(ex)));
            } catch (RuntimeException ex) {
                // Any other failure is fatal: tell the client first, then abort inference.
                this.server.onActionExecution(action, new ActionStatus(action, false, describeFailure(ex)));
                throw new RuntimeException(String.format("Critically failed to execute action: %s", action), ex);
            }
        }

        this.closeServer();
        return this.objective;
    }
}

