/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.database.atom;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import org.linqs.psl.config.Options;
import org.linqs.psl.database.Database;
import org.linqs.psl.database.atom.AtomManager;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.predicate.Predicate;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.model.predicate.model.ModelPredicate;
import org.linqs.psl.model.term.Constant;
import org.linqs.psl.reasoner.InitialValue;
import org.linqs.psl.util.IteratorUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An {@link AtomManager} that tracks which {@link RandomVariableAtom}s are "persisted".
 *
 * <p>Persisted atoms are either discovered by scanning the database when the manager is
 * constructed, or added later through {@link #addToPersistedCache(RandomVariableAtom)}.
 * Requesting a non-persisted random variable atom through
 * {@link #getAtom(Predicate, Constant...)} is treated as an illegal access:
 * <ul>
 *   <li>the atom's value is initialized from the configured {@link InitialValue};</li>
 *   <li>the atom is flagged with an access exception;</li>
 *   <li>when access exceptions are enabled, it is reported according to
 *       {@code Options.PAM_THROW_ACCESS_EXCEPTION}, either by throwing a
 *       {@link PersistedAccessException} or by logging a single warning.</li>
 * </ul>
 */
public class PersistedAtomManager
extends AtomManager {
    private static final Logger log = LoggerFactory.getLogger(PersistedAtomManager.class);

    /** Whether an illegal access throws; read from the configuration when this manager is created. */
    private final boolean throwOnIllegalAccess = Options.PAM_THROW_ACCESS_EXCEPTION.getBoolean();

    /** Whether an illegal access should still be warned about; cleared after the first warning. */
    private boolean warnOnIllegalAccess = !this.throwOnIllegalAccess;

    /** Supplies the value assigned to a non-persisted RVA when it is first illegally accessed. */
    private final InitialValue initialValueOnIllegalAccess;

    /**
     * Number of persisted atoms. In prebuilt-cache mode this starts as
     * {@code db.getCachedRVACount()}; otherwise it is computed by scanning the database.
     */
    protected int persistedAtomCount;

    /**
     * Creates a manager that builds its persisted-atom cache by scanning {@code db}.
     *
     * @param db the database whose atoms this manager serves
     */
    public PersistedAtomManager(Database db) {
        this(db, false);
    }

    /**
     * Creates a manager that initializes illegally accessed atoms with {@link InitialValue#ATOM}.
     *
     * @param db the database whose atoms this manager serves
     * @param prebuiltCache if true, skip the scan and take the persisted count from
     *     {@code db.getCachedRVACount()}. NOTE(review): this assumes the database's atom cache
     *     already holds the persisted atoms; confirm at call sites.
     */
    public PersistedAtomManager(Database db, boolean prebuiltCache) {
        this(db, prebuiltCache, InitialValue.ATOM);
    }

    /**
     * Creates a manager with an explicit policy for initializing illegally accessed atoms.
     *
     * @param db the database whose atoms this manager serves
     * @param prebuiltCache if true, skip the scan and take the persisted count from
     *     {@code db.getCachedRVACount()}
     * @param initialValueOnIllegalAccess supplies the value assigned to a non-persisted RVA
     *     when it is first illegally accessed
     */
    public PersistedAtomManager(Database db, boolean prebuiltCache, InitialValue initialValueOnIllegalAccess) {
        super(db);
        this.initialValueOnIllegalAccess = initialValueOnIllegalAccess;
        if (prebuiltCache) {
            this.persistedAtomCount = db.getCachedRVACount();
        } else {
            this.buildPersistedAtomCache();
        }
    }

    /**
     * Scans every registered predicate and marks the random variable atoms of open,
     * non-model predicates as persisted. Resets and recomputes {@link #persistedAtomCount}.
     */
    private void buildPersistedAtomCache() {
        this.persistedAtomCount = 0;
        // Reused across predicates to batch the mirror atoms produced for one predicate.
        List<RandomVariableAtom> mirrorAtoms = new ArrayList<>();
        for (StandardPredicate predicate : this.db.getDataStore().getRegisteredPredicates()) {
            if (this.db.isClosed(predicate)) {
                // Result intentionally unused. NOTE(review): presumably called to load the atoms
                // into the database's cache; confirm.
                this.db.getAllGroundAtoms(predicate);
                continue;
            }
            // Open model predicates are skipped entirely; closed ones were handled above.
            if (predicate instanceof ModelPredicate) {
                continue;
            }
            for (RandomVariableAtom atom : this.db.getAllGroundRandomVariableAtoms(predicate)) {
                atom.setPersisted(true);
                ++this.persistedAtomCount;
                if (predicate.getMirror() == null) {
                    continue;
                }
                // Link the atom with its counterpart (same arguments) in the mirror predicate.
                // NOTE(review): the `true` flag presumably asks the database to create the atom
                // if it does not exist yet; confirm against Database.getAtom.
                RandomVariableAtom mirrorAtom = (RandomVariableAtom)this.db.getAtom(predicate.getMirror(), true, atom.getArguments());
                mirrorAtoms.add(mirrorAtom);
                atom.setMirror(mirrorAtom);
                mirrorAtom.setMirror(atom);
                mirrorAtom.setPersisted(true);
                ++this.persistedAtomCount;
            }
            // Result intentionally unused, as in the closed-predicate case above.
            this.db.getAllGroundObservedAtoms(predicate);
            if (!mirrorAtoms.isEmpty()) {
                this.db.commit(mirrorAtoms);
                mirrorAtoms.clear();
            }
        }
    }

    /**
     * Returns the ground atom for {@code predicate} and {@code arguments}.
     *
     * <p>Atoms that are not {@link RandomVariableAtom}s are returned unchanged. For a random
     * variable atom that is not persisted:
     * <ul>
     *   <li>unless it is already flagged with an access exception, its value is first set
     *       from the configured {@link InitialValue};</li>
     *   <li>it is then flagged with an access exception.</li>
     * </ul>
     * Any flagged atom is passed to {@link #reportAccessException} when access exceptions are
     * enabled and this manager still throws or warns.
     *
     * @throws PersistedAccessException if a flagged atom is reported while
     *     {@code Options.PAM_THROW_ACCESS_EXCEPTION} is set
     */
    @Override
    public GroundAtom getAtom(Predicate predicate, Constant ... arguments) {
        GroundAtom atom = this.db.getAtom(predicate, arguments);
        if (!(atom instanceof RandomVariableAtom)) {
            return atom;
        }
        RandomVariableAtom rvAtom = (RandomVariableAtom)atom;
        if (!rvAtom.getPersisted()) {
            // Only initialize the value on the first illegal access so later reads are stable.
            if (!rvAtom.getAccessException()) {
                rvAtom.setValue(this.initialValueOnIllegalAccess.getVariableValue(rvAtom));
            }
            rvAtom.setAccessException(true);
        }
        if (this.enableAccessExceptions && (this.throwOnIllegalAccess || this.warnOnIllegalAccess) && rvAtom.getAccessException()) {
            this.reportAccessException(null, rvAtom);
        }
        return rvAtom;
    }

    /**
     * Commits the database's cached atoms.
     * NOTE(review): the meaning of the {@code true} argument to {@code commitCachedAtoms} is
     * defined by {@code Database}; confirm before relying on it.
     */
    public void commitPersistedAtoms() {
        this.db.commitCachedAtoms(true);
    }

    /**
     * Returns the number of persisted atoms tracked by this manager.
     *
     * @return the persisted atom count
     */
    public int getPersistedCount() {
        return this.persistedAtomCount;
    }

    /**
     * Returns the database's cached random variable atoms, keeping only those marked persisted.
     *
     * @return the persisted random variable atoms
     */
    public Iterable<RandomVariableAtom> getPersistedRVAtoms() {
        return IteratorUtils.filter(this.db.getAllCachedRandomVariableAtoms(), new IteratorUtils.FilterFunction<RandomVariableAtom>(){

            @Override
            public boolean keep(RandomVariableAtom atom) {
                return atom.getPersisted();
            }
        });
    }

    /**
     * Marks every atom in {@code atoms} as persisted.
     *
     * @param atoms the atoms to mark
     */
    protected void addToPersistedCache(Set<RandomVariableAtom> atoms) {
        for (RandomVariableAtom atom : atoms) {
            this.addToPersistedCache(atom);
        }
    }

    /**
     * Marks {@code atom} as persisted. The count is incremented only if the atom was not
     * already persisted, so repeated calls for the same atom are counted once.
     *
     * @param atom the atom to mark
     */
    protected void addToPersistedCache(RandomVariableAtom atom) {
        if (!atom.getPersisted()) {
            atom.setPersisted(true);
            ++this.persistedAtomCount;
        }
    }

    /**
     * Reports an illegal access to {@code offendingAtom}.
     *
     * <p>If this manager throws on illegal access, {@code ex} is thrown. When {@code ex} is
     * null, a new {@link PersistedAccessException} is thrown instead. Otherwise, a warning is
     * logged on the first report and later reports are silent.
     *
     * @param ex the exception to throw, or null to throw a {@link PersistedAccessException}
     * @param offendingAtom the atom that was illegally accessed
     */
    @Override
    public void reportAccessException(RuntimeException ex, GroundAtom offendingAtom) {
        if (this.throwOnIllegalAccess) {
            if (ex == null) {
                // NOTE(review): assumes a null exception is only ever paired with a
                // RandomVariableAtom; any other atom type would raise ClassCastException here.
                ex = new PersistedAccessException((RandomVariableAtom)offendingAtom);
            }
            throw ex;
        }
        if (this.warnOnIllegalAccess) {
            this.warnOnIllegalAccess = false;
            log.warn("Found a non-persisted RVA ({}). If you do not understand the implications of this warning, check your configuration and set '{}' to true. This warning will only be logged once.", offendingAtom, Options.PAM_THROW_ACCESS_EXCEPTION.name());
        }
    }

    /**
     * Thrown when a non-persisted {@link RandomVariableAtom} is requested from a
     * {@link PersistedAtomManager} configured to throw on illegal access.
     */
    public static class PersistedAccessException
    extends IllegalArgumentException {
        /** The atom whose access triggered this exception. */
        public RandomVariableAtom atom;

        /**
         * @param atom the atom that was illegally accessed
         */
        public PersistedAccessException(RandomVariableAtom atom) {
            super("Can only call getAtom() on persisted RandomVariableAtoms (RVAs) using a PersistedAtomManager. Cannot access " + atom + ". This typically means that provided data is insufficient. An RVA (atom to be inferred (target)) was constructed during grounding that does not exist in the provided data.");
            this.atom = atom;
        }
    }
}

