/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.application.learning.weight;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import org.linqs.psl.database.Database;
import org.linqs.psl.database.atom.PersistedAtomManager;
import org.linqs.psl.model.atom.GroundAtom;
import org.linqs.psl.model.atom.ObservedAtom;
import org.linqs.psl.model.atom.RandomVariableAtom;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.util.IteratorUtils;

public class TrainingMap {
    private final Map<RandomVariableAtom, ObservedAtom> labelMap;
    private final Map<ObservedAtom, ObservedAtom> observedMap;
    private final List<RandomVariableAtom> latentVariables;
    private final List<ObservedAtom> missingLabels;
    private final List<ObservedAtom> missingTargets;

    public TrainingMap(PersistedAtomManager targets, Database truthDatabase) {
        HashMap<RandomVariableAtom, ObservedAtom> tempLabelMap = new HashMap<RandomVariableAtom, ObservedAtom>(targets.getPersistedCount());
        HashMap<ObservedAtom, ObservedAtom> tempObservedMap = new HashMap<ObservedAtom, ObservedAtom>();
        ArrayList<RandomVariableAtom> tempLatentVariables = new ArrayList<RandomVariableAtom>();
        ArrayList<ObservedAtom> tempMissingLabels = new ArrayList<ObservedAtom>();
        ArrayList<ObservedAtom> tempMissingTargets = new ArrayList<ObservedAtom>();
        HashSet<ObservedAtom> seenTruthAtoms = new HashSet<ObservedAtom>();
        this.prefetchTruthAtoms(truthDatabase);
        for (GroundAtom targetAtom : targets.getDatabase().getAllCachedAtoms()) {
            GroundAtom truthAtom = null;
            if (truthDatabase.hasCachedAtom((StandardPredicate)targetAtom.getPredicate(), targetAtom.getArguments())) {
                truthAtom = truthDatabase.getAtom((StandardPredicate)targetAtom.getPredicate(), false, targetAtom.getArguments());
            }
            if (truthAtom != null && !(truthAtom instanceof ObservedAtom)) continue;
            if (targetAtom instanceof RandomVariableAtom) {
                if (truthAtom == null) {
                    tempLatentVariables.add((RandomVariableAtom)targetAtom);
                    continue;
                }
                seenTruthAtoms.add((ObservedAtom)truthAtom);
                tempLabelMap.put((RandomVariableAtom)targetAtom, (ObservedAtom)truthAtom);
                continue;
            }
            if (truthAtom == null) {
                tempMissingLabels.add((ObservedAtom)targetAtom);
                continue;
            }
            seenTruthAtoms.add((ObservedAtom)truthAtom);
            tempObservedMap.put((ObservedAtom)targetAtom, (ObservedAtom)truthAtom);
        }
        for (GroundAtom truthAtom : truthDatabase.getAllCachedAtoms()) {
            if (!(truthAtom instanceof ObservedAtom) || seenTruthAtoms.contains(truthAtom)) continue;
            boolean hasAtom = targets.getDatabase().hasAtom((StandardPredicate)truthAtom.getPredicate(), truthAtom.getArguments());
            if (hasAtom) {
                throw new IllegalStateException("Un-persisted target atom: " + truthAtom);
            }
            tempMissingTargets.add((ObservedAtom)truthAtom);
        }
        this.labelMap = Collections.unmodifiableMap(tempLabelMap);
        this.observedMap = Collections.unmodifiableMap(tempObservedMap);
        this.latentVariables = Collections.unmodifiableList(tempLatentVariables);
        this.missingLabels = Collections.unmodifiableList(tempMissingLabels);
        this.missingTargets = Collections.unmodifiableList(tempMissingTargets);
    }

    public Map<RandomVariableAtom, ObservedAtom> getLabelMap() {
        return this.labelMap;
    }

    public Map<ObservedAtom, ObservedAtom> getObservedMap() {
        return this.observedMap;
    }

    public List<RandomVariableAtom> getLatentVariables() {
        return this.latentVariables;
    }

    public List<ObservedAtom> getMissingLabels() {
        return this.missingLabels;
    }

    public List<ObservedAtom> getMissingTargets() {
        return this.missingTargets;
    }

    public Iterable<RandomVariableAtom> getAllPredictions() {
        return IteratorUtils.join(this.labelMap.keySet(), this.latentVariables);
    }

    public Iterable<GroundAtom> getAllTargets() {
        return IteratorUtils.join(this.labelMap.keySet(), this.observedMap.keySet(), this.latentVariables, this.missingLabels);
    }

    public Iterable<GroundAtom> getAllTruths() {
        return IteratorUtils.join(this.labelMap.values(), this.observedMap.values(), this.missingTargets);
    }

    public Iterable<Map.Entry<GroundAtom, GroundAtom>> getFullMap() {
        Iterable<Map.Entry<GroundAtom, GroundAtom>> temp = IteratorUtils.join(this.labelMap.entrySet(), this.observedMap.entrySet());
        return temp;
    }

    public String toString() {
        return String.format("Training Map -- Label Map: %d, Observed Map: %d, Latent Variables: %d, Missing Labels: %d, Missing Targets: %d", this.labelMap.size(), this.observedMap.size(), this.latentVariables.size(), this.missingLabels.size(), this.missingTargets.size());
    }

    private void prefetchTruthAtoms(Database truthDatabase) {
        for (StandardPredicate predicate : truthDatabase.getDataStore().getRegisteredPredicates()) {
            truthDatabase.getAllGroundAtoms(predicate);
        }
    }
}

