/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight;

import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.linqs.psl.database.Database;
import org.linqs.psl.database.atom.PersistedAtomManager;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.predicate.Predicate;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.util.IteratorUtils;

/**
 * Pairs the atoms tracked by a {@link PersistedAtomManager} with the atoms found at the
 * same predicate and arguments in a second ("observed") database.
 *
 * <p>Three collections are built once, in the constructor, and exposed as unmodifiable views:
 * <ul>
 *   <li>a map from each persisted random variable atom to its observed counterpart,</li>
 *   <li>the set of persisted random variable atoms that have no observed counterpart,</li>
 *   <li>optionally, a map between observed atoms on the atom-manager side and observed
 *       atoms in the observed database.</li>
 * </ul>
 */
public class TrainingMap {
    /** Persisted random variable atoms whose counterpart in the observed database is an ObservedAtom. */
    private final Map<RandomVariableAtom, ObservedAtom> trainingMap;

    /** Persisted random variable atoms whose counterpart in the observed database is not an ObservedAtom. */
    private final Set<RandomVariableAtom> latentVariables;

    /**
     * Observed atom returned by the atom manager (key) mapped to the matching observed atom
     * from the observed database (value). Stays empty unless observed pairs were requested.
     */
    private final Map<ObservedAtom, ObservedAtom> observedMap;

    /**
     * Builds all mappings eagerly.
     *
     * @param rvAtomManager source of the persisted random variable atoms; also queried for the
     *        counterpart of each otherwise unmatched observed atom when {@code fetchObservedPairs} is set
     * @param observedDB database that is queried with the predicate and arguments of every
     *        persisted random variable atom
     * @param fetchObservedPairs when {@code true}, additionally walk every ground atom of every
     *        predicate registered with {@code observedDB}'s data store and fill the observed map
     * @throws IllegalStateException if, during the observed-pair pass, the atom manager returns
     *         something other than an {@link ObservedAtom} for an atom that was not matched
     *         in the first pass
     */
    public TrainingMap(PersistedAtomManager rvAtomManager, Database observedDB, boolean fetchObservedPairs) {
        Map<RandomVariableAtom, ObservedAtom> tempTrainingMap = new HashMap<RandomVariableAtom, ObservedAtom>(rvAtomManager.getPersistedCount());
        Map<ObservedAtom, ObservedAtom> tempObservedMap = new HashMap<ObservedAtom, ObservedAtom>();
        Set<RandomVariableAtom> tempLatentVariables = new HashSet<RandomVariableAtom>();

        // Only needed (and only allocated) when the second pass will run.
        Set<ObservedAtom> seenTruthAtoms = null;
        if (fetchObservedPairs) {
            seenTruthAtoms = new HashSet<ObservedAtom>(rvAtomManager.getPersistedCount());
        }

        // Pass 1: classify every persisted random variable atom as paired or latent.
        for (RandomVariableAtom rvAtom : rvAtomManager.getPersistedRVAtoms()) {
            GroundAtom otherAtom = observedDB.getAtom((Predicate)rvAtom.getPredicate(), rvAtom.getArguments());

            if (!(otherAtom instanceof ObservedAtom)) {
                tempLatentVariables.add(rvAtom);
                continue;
            }

            ObservedAtom truthAtom = (ObservedAtom)otherAtom;
            tempTrainingMap.put(rvAtom, truthAtom);

            // Remember matched truth atoms so pass 2 does not process them again.
            if (fetchObservedPairs) {
                seenTruthAtoms.add(truthAtom);
            }
        }

        // Pass 2 (optional): pair up the remaining observed atoms of the observed database.
        if (fetchObservedPairs) {
            for (StandardPredicate predicate : observedDB.getDataStore().getRegisteredPredicates()) {
                for (GroundAtom atom : observedDB.getAllGroundAtoms(predicate)) {
                    if (!(atom instanceof ObservedAtom)) {
                        continue;
                    }

                    ObservedAtom truthAtom = (ObservedAtom)atom;
                    if (seenTruthAtoms.contains(truthAtom)) {
                        continue;
                    }

                    GroundAtom otherAtom;
                    try {
                        otherAtom = rvAtomManager.getAtom(truthAtom.getPredicate(), truthAtom.getArguments());
                    } catch (PersistedAtomManager.PersistedAccessException ignored) {
                        // Deliberately skipped: the atom manager refused access to this atom.
                        // NOTE(review): presumably this means the atom is not a persisted target - confirm.
                        continue;
                    }

                    if (!(otherAtom instanceof ObservedAtom)) {
                        throw new IllegalStateException("Found a non-observed atom after we got all the RVA... was the data store changed under us?");
                    }

                    tempObservedMap.put((ObservedAtom)otherAtom, truthAtom);
                }
            }
        }

        this.trainingMap = Collections.unmodifiableMap(tempTrainingMap);
        this.observedMap = Collections.unmodifiableMap(tempObservedMap);
        this.latentVariables = Collections.unmodifiableSet(tempLatentVariables);
    }

    /**
     * @return unmodifiable map from persisted random variable atoms to their observed counterparts
     */
    public Map<RandomVariableAtom, ObservedAtom> getTrainingMap() {
        return this.trainingMap;
    }

    /**
     * @return unmodifiable map between observed atoms from the atom manager (keys) and observed
     *         atoms from the observed database (values); empty when observed pairs were not fetched
     */
    public Map<ObservedAtom, ObservedAtom> getObservedMap() {
        return this.observedMap;
    }

    /**
     * @return unmodifiable set of persisted random variable atoms without an observed counterpart
     */
    public Set<RandomVariableAtom> getLatentVariables() {
        return this.latentVariables;
    }

    /**
     * Same as {@code getTargetAtoms(false)}: latent variables are left out.
     *
     * @return the keys of the training map followed by the keys of the observed map
     */
    public Iterable<GroundAtom> getTargetAtoms() {
        return this.getTargetAtoms(false);
    }

    /**
     * Joins the key sets of the training and observed maps, optionally with the latent variables.
     *
     * @param includeLatent whether the latent variables are appended to the joined iterable
     * @return the joined iterable produced by {@code IteratorUtils.join}
     */
    public Iterable<GroundAtom> getTargetAtoms(boolean includeLatent) {
        if (includeLatent) {
            return IteratorUtils.join(this.trainingMap.keySet(), this.observedMap.keySet(), this.latentVariables);
        }
        return IteratorUtils.join(this.trainingMap.keySet(), this.observedMap.keySet());
    }

    /**
     * @return the values of the training map joined with the values of the observed map
     */
    public Iterable<GroundAtom> getTruthAtoms() {
        return IteratorUtils.join(this.trainingMap.values(), this.observedMap.values());
    }

    /**
     * Joins the entries of the training map with the entries of the observed map.
     *
     * <p>{@code Map.Entry<RandomVariableAtom, ObservedAtom>} is not a subtype of
     * {@code Map.Entry<GroundAtom, GroundAtom>} because generic types are invariant, so the
     * joined iterable is widened through a raw type. Both backing maps are unmodifiable views,
     * but an entry's {@code setValue} behaviour is outside this class's control, so callers
     * should treat the entries as read-only.
     *
     * @return every (target, truth) entry from both maps
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public Iterable<Map.Entry<GroundAtom, GroundAtom>> getFullMap() {
        // Raw hop: the element types are only compatible after erasure.
        Iterable joined = IteratorUtils.join(this.trainingMap.entrySet(), this.observedMap.entrySet());
        return (Iterable<Map.Entry<GroundAtom, GroundAtom>>)joined;
    }
}

