/*
 * Decompiled with CFR 0.152.
 */
package org.linqs.psl.grounding.collective;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.linqs.psl.config.Options;
import org.linqs.psl.database.Database;
import org.linqs.psl.database.DatabaseQuery;
import org.linqs.psl.database.rdbms.Formula2SQL;
import org.linqs.psl.database.rdbms.driver.DatabaseDriver;
import org.linqs.psl.grounding.collective.CandidateQuery;
import org.linqs.psl.grounding.collective.CandidateSearchNode;
import org.linqs.psl.grounding.collective.SearchFringe;
import org.linqs.psl.model.atom.Atom;
import org.linqs.psl.model.formula.Conjunction;
import org.linqs.psl.model.formula.Formula;
import org.linqs.psl.model.predicate.ExternalFunctionalPredicate;
import org.linqs.psl.model.predicate.GroundingOnlyPredicate;
import org.linqs.psl.model.predicate.StandardPredicate;
import org.linqs.psl.model.rule.Rule;
import org.linqs.psl.model.term.Variable;
import org.linqs.psl.util.BitUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class CandidateGeneration {
    private static final Logger log = LoggerFactory.getLogger(CandidateGeneration.class);

    /**
     * Factor applied to a node's atom count when a node is scored in explainNode().
     */
    public static final double CANDIDATE_SIZE_ADJUSTMENT = 1.0;

    /**
     * Weight of the EXPLAIN total cost in a node's optimistic cost (see explainNode()).
     */
    public static final double OPTIMISTIC_QUERY_COST_MULTIPLIER = 0.018;

    /**
     * Weight of the EXPLAIN row count in a node's optimistic cost (see explainNode()).
     */
    public static final double OPTIMISTIC_INSTANTIATION_COST_MULTIPLIER = 0.001;

    /**
     * Weight of the EXPLAIN total cost in a node's pessimistic cost (see explainNode()).
     */
    public static final double PESSIMISTIC_QUERY_COST_MULTIPLIER = 0.02;

    /**
     * Weight of the EXPLAIN row count in a node's pessimistic cost (see explainNode()).
     */
    public static final double PESSIMISTIC_INSTANTIATION_COST_MULTIPLIER = 0.002;

    /**
     * Which fringe implementation createFringe() builds; read once from the options.
     */
    private final SearchType searchType = SearchType.valueOf(Options.GROUNDING_COLLECTIVE_CANDIDATE_SEARCH_TYPE.getString());

    /**
     * Maximum number of (non-cached) explains a single search may issue.
     * A value less than or equal to zero disables the limit (see search()).
     */
    private final int budget = Options.GROUNDING_COLLECTIVE_CANDIDATE_SEARCH_BUDGET.getInt();

    /**
     * Cache of explain results keyed by the string form of the explained formula.
     * Never reassigned, hence final; entries are only ever added (see explainNode()).
     */
    private final Map<String, DatabaseDriver.ExplainResult> explains = new ConcurrentHashMap<String, DatabaseDriver.ExplainResult>();

    /**
     * Search for candidate queries for a rule and append the first ones to a collection.
     *
     * The candidates found by the search are sorted using the natural ordering of
     * CandidateQuery, and at most maxResults of them (from the front of the sorted list)
     * are added to results. A non-positive maxResults adds nothing.
     *
     * @param rule the rule to generate candidates for.
     * @param database the database used when searching.
     * @param maxResults the maximum number of candidates to add to results.
     * @param results the collection that receives the chosen candidates.
     */
    public void generateCandidates(Rule rule, Database database, int maxResults, Collection<CandidateQuery> results) {
        List<CandidateQuery> ranked = this.search(this.createFringe(), rule, database);
        Collections.sort(ranked);

        int remaining = Math.min(ranked.size(), maxResults);
        for (CandidateQuery candidate : ranked) {
            if (remaining <= 0) {
                break;
            }

            results.add(candidate);
            remaining--;
        }
    }

    /**
     * Explore subsets of the rule's atoms and collect one candidate query per explored node.
     *
     * The search starts from the node containing every (non-filtered) atom and generates
     * children by dropping one atom at a time. Children that validateAndCreateNode() rejects
     * (returns null for) are not pushed. The loop stops when the fringe is empty or when the
     * number of non-cached explains reaches the budget (a budget of zero or less means no limit).
     *
     * @param fringe the fringe that decides the exploration order; it is cleared first.
     * @param rule the rule being grounded.
     * @param database the database used to explain candidate queries.
     * @return every candidate that was popped from the fringe, in pop order.
     */
    private List<CandidateQuery> search(SearchFringe fringe, Rule rule, Database database) {
        fringe.clear();

        Formula baseFormula = rule.getRewritableGroundingFormula();
        DatabaseQuery.validate(baseFormula);

        // A single atom has no subsets to explore; it has its own code path.
        if (baseFormula instanceof Atom) {
            return this.singleAtomSearch(rule, baseFormula, database);
        }

        List<CandidateQuery> found = new ArrayList<CandidateQuery>();

        HashSet<Atom> atomBuffer = new HashSet<Atom>();
        baseFormula.getAtoms(atomBuffer);

        // filterSpecialAtoms() removes atoms from the buffer,
        // so it has to run before the usage mapping and the ordered atom list are built.
        Set<Atom> passthrough = this.filterSpecialAtoms(atomBuffer);
        Map<Variable, Set<Atom>> variableUsageMapping = this.getAllUsedVariables(atomBuffer);

        // Fix an order for the atoms (by string form) so each one has a stable bit index.
        List<Atom> atoms = new ArrayList<Atom>(atomBuffer);
        Collections.sort(atoms, new Comparator<Atom>() {
            @Override
            public int compare(Atom left, Atom right) {
                return left.toString().compareTo(right.toString());
            }
        });

        // From here on the buffer is only scratch space for node construction.
        atomBuffer.clear();

        Set<Long> seenNodes = new HashSet<Long>();

        // The root node has every atom switched on.
        boolean[] atomBits = new boolean[atoms.size()];
        for (int index = 0; index < atomBits.length; index++) {
            atomBits[index] = true;
        }

        CandidateSearchNode rootNode = this.validateAndCreateNode(atomBits, atoms, passthrough, variableUsageMapping, atomBuffer, 0.0, 0.0);
        fringe.push(rootNode);
        seenNodes.add(BitUtils.toBitSet(atomBits));

        int explainCount = 0;
        while (fringe.size() > 0) {
            if (this.budget > 0 && explainCount >= this.budget) {
                break;
            }

            CandidateSearchNode node = fringe.pop();
            BitUtils.toBits(node.atomsBitSet, atomBits);

            // Only explains that actually went to the database count against the budget.
            if (this.explainNode(node, database)) {
                explainCount++;
            }

            found.add(new CandidateQuery(rule, node.formula, node.optimisticCost));
            fringe.newPessimisticCost(node.pessimisticCost);

            // Each child drops exactly one of the atoms that is on in this node.
            for (int index = 0; index < atoms.size(); index++) {
                if (!atomBits[index]) {
                    continue;
                }

                // Temporarily switch the atom off; it is restored at the end of the iteration.
                atomBits[index] = false;

                // Skip the empty subset (id 0); add() returns false for ids already seen.
                Long childId = BitUtils.toBitSet(atomBits);
                if (childId != 0L && seenNodes.add(childId)) {
                    CandidateSearchNode child = this.validateAndCreateNode(atomBits, atoms, passthrough, variableUsageMapping, atomBuffer, node.optimisticCost, node.pessimisticCost);
                    if (child != null) {
                        fringe.push(child);
                    }
                }

                atomBits[index] = true;
            }
        }

        return found;
    }

    /**
     * Build the candidates for a rule whose grounding formula is a single atom.
     *
     * The base atom is always explained and returned as a candidate.
     * A second candidate is added only when the rule has exactly two core atoms,
     * at least one of them has a standard predicate that is not closed in the database,
     * and exactly one core atom is left after removing the base formula.
     * That remaining atom is then explained and added.
     *
     * @param rule the rule being grounded.
     * @param baseFormula the rule's grounding formula; must be an Atom.
     * @param database the database used for explains and closed-predicate checks.
     * @return a list with one or two candidates.
     */
    private List<CandidateQuery> singleAtomSearch(Rule rule, Formula baseFormula, Database database) {
        assert (baseFormula instanceof Atom);

        List<CandidateQuery> found = new ArrayList<CandidateQuery>(2);

        CandidateSearchNode baseNode = new CandidateSearchNode(0L, baseFormula, 1, 1.0, 1.0);
        this.explainNode(baseNode, database);
        found.add(new CandidateQuery(rule, baseNode.formula, baseNode.optimisticCost));

        HashSet<Atom> coreAtoms = new HashSet<Atom>();
        rule.getCoreAtoms(coreAtoms);
        if (coreAtoms.size() != 2) {
            return found;
        }

        // Count the core atoms whose (standard) predicate is open in this database.
        int openCount = 0;
        for (Atom atom : coreAtoms) {
            boolean standard = atom.getPredicate() instanceof StandardPredicate;
            if (standard && !database.isClosed((StandardPredicate)atom.getPredicate())) {
                openCount++;
            }
        }

        if (openCount == 0) {
            return found;
        }

        // The other candidate is whatever remains once the base atom is taken out.
        coreAtoms.remove(baseFormula);
        if (coreAtoms.size() != 1) {
            return found;
        }

        Formula otherAtom = (Formula)coreAtoms.iterator().next();
        CandidateSearchNode otherNode = new CandidateSearchNode(0L, otherAtom, 1, 1.0, 1.0);
        this.explainNode(otherNode, database);
        found.add(new CandidateQuery(rule, otherNode.formula, otherNode.optimisticCost));

        return found;
    }

    /**
     * Build the search node for the atoms selected by atomBits, or reject the selection.
     *
     * A selection is rejected (null is returned) when some variable in variableUsageMapping
     * is not used by any atom in the constructed formula.
     * Otherwise the node's initial costs are the parent's costs scaled by n / (n + 1),
     * where n is the number of atoms in the formula.
     * The atom buffer is always left empty on return.
     *
     * @param atomBits which entries of atoms are included.
     * @param atoms the ordered atoms that atomBits indexes into.
     * @param passthrough atoms that are included in every formula.
     * @param variableUsageMapping for each variable, the atoms that use it.
     * @param atomBuffer scratch set; must be empty on entry.
     * @param parentOptimisticCost the optimistic cost of the parent node.
     * @param parentPessimisticCost the pessimistic cost of the parent node.
     * @return the new node, or null if the selection does not cover every variable.
     */
    private CandidateSearchNode validateAndCreateNode(boolean[] atomBits, List<Atom> atoms, Set<Atom> passthrough, Map<Variable, Set<Atom>> variableUsageMapping, Set<Atom> atomBuffer, double parentOptimisticCost, double parentPessimisticCost) {
        // constructFormula() leaves the atoms it used in the buffer.
        Formula formula = this.constructFormula(atomBits, atoms, passthrough, atomBuffer);

        // Every variable needs at least one of its atoms to be present in the formula.
        for (Set<Atom> atomsUsingVariable : variableUsageMapping.values()) {
            if (Collections.disjoint(atomsUsingVariable, atomBuffer)) {
                atomBuffer.clear();
                return null;
            }
        }

        int numAtoms = atomBuffer.size();
        atomBuffer.clear();
        assert (numAtoms > 0);

        double nextSize = (double)(numAtoms + 1);
        double optimisticCost = parentOptimisticCost * (double)numAtoms / nextSize;
        double pessimisticCost = parentPessimisticCost * (double)numAtoms / nextSize;

        long bitSet = BitUtils.toBitSet(atomBits);
        return new CandidateSearchNode(bitSet, formula, numAtoms, optimisticCost, pessimisticCost);
    }

    /**
     * Score a node using the database's explain output for the node's formula.
     *
     * Explain results are cached by the formula's string form,
     * so the database is only asked when the formula has not been explained before.
     * The node's optimistic and pessimistic costs are overwritten and
     * its approximateCost flag is cleared.
     *
     * @param node the node to score; its cost fields are modified.
     * @param database the database used to build and explain the query.
     * @return true if an explain was actually issued, false if the cached result was used.
     */
    private boolean explainNode(CandidateSearchNode node, Database database) {
        String formulaString = node.formula.toString();

        // One lookup instead of containsKey() + get().
        // The cache is a ConcurrentHashMap, which cannot hold null values, so null means "absent".
        DatabaseDriver.ExplainResult result = this.explains.get(formulaString);
        boolean usedExplain = (result == null);

        if (usedExplain) {
            String sql = Formula2SQL.getQuery(node.formula, database, false);
            result = database.getDataStore().explain(sql);
            this.explains.put(formulaString, result);
        }

        // Use the named class constants rather than repeating their literal values.
        double sizeFactor = (double)node.numAtoms * CANDIDATE_SIZE_ADJUSTMENT;

        node.approximateCost = false;
        node.optimisticCost = (result.totalCost * OPTIMISTIC_QUERY_COST_MULTIPLIER + (double)result.rows * OPTIMISTIC_INSTANTIATION_COST_MULTIPLIER) * sizeFactor;
        node.pessimisticCost = (result.totalCost * PESSIMISTIC_QUERY_COST_MULTIPLIER + (double)result.rows * PESSIMISTIC_INSTANTIATION_COST_MULTIPLIER) * sizeFactor;

        if (usedExplain) {
            // Parameterized so the node is only stringified when trace logging is enabled.
            log.trace("Scored candidate: {}", node);
        }

        return usedExplain;
    }

    /**
     * Build the formula made of the passthrough atoms plus the atoms selected by atomBits.
     *
     * The atoms used are collected into atomBuffer and left there for the caller.
     * A single collected atom is returned as-is; otherwise the atoms are wrapped in a Conjunction.
     *
     * @param atomBits which entries of atoms are included.
     * @param atoms the ordered atoms that atomBits indexes into.
     * @param passthrough atoms that are always included.
     * @param atomBuffer output set for the atoms used; must be empty on entry.
     * @return the constructed formula.
     */
    private Formula constructFormula(boolean[] atomBits, List<Atom> atoms, Set<Atom> passthrough, Set<Atom> atomBuffer) {
        assert (atomBuffer.isEmpty());

        atomBuffer.addAll(passthrough);

        int index = 0;
        for (boolean selected : atomBits) {
            if (selected) {
                atomBuffer.add(atoms.get(index));
            }

            index++;
        }

        // A lone atom does not need to be wrapped.
        if (atomBuffer.size() == 1) {
            return (Formula)atomBuffer.iterator().next();
        }

        return new Conjunction(atomBuffer.toArray(new Formula[0]));
    }

    /**
     * Map each variable to the atoms that use it.
     * Only atoms whose predicate is a StandardPredicate are considered.
     *
     * @param atoms the atoms to inspect; not modified.
     * @return a new map from variable to the set of standard-predicate atoms containing it.
     */
    private Map<Variable, Set<Atom>> getAllUsedVariables(Set<Atom> atoms) {
        Map<Variable, Set<Atom>> variables = new HashMap<Variable, Set<Atom>>();

        for (Atom atom : atoms) {
            if (!(atom.getPredicate() instanceof StandardPredicate)) {
                continue;
            }

            for (Variable variable : atom.getVariables()) {
                // Typed get-or-create: avoids the raw HashSet / raw Set cast and repeated lookups.
                Set<Atom> atomsUsingVariable = variables.get(variable);
                if (atomsUsingVariable == null) {
                    atomsUsingVariable = new HashSet<Atom>();
                    variables.put(variable, atomsUsingVariable);
                }

                atomsUsingVariable.add(atom);
            }
        }

        return variables;
    }

    /**
     * Remove atoms that should not take part in the search from the given set.
     *
     * Atoms with an ExternalFunctionalPredicate are removed from atoms.
     * Atoms with a GroundingOnlyPredicate or StandardPredicate are left in place.
     * Any other predicate type is an error.
     *
     * NOTE(review): the returned passthrough set is never added to, so it is always empty.
     * Confirm whether grounding-only atoms were meant to be moved into it.
     *
     * @param atoms the atoms to filter; modified in place.
     * @return the set of passthrough atoms (currently always empty).
     * @throws IllegalStateException if an atom has an unrecognized predicate type.
     */
    private Set<Atom> filterSpecialAtoms(Set<Atom> atoms) {
        Set<Atom> passthrough = new HashSet<Atom>();
        Set<Atom> externalAtoms = new HashSet<Atom>();

        for (Atom atom : atoms) {
            if (atom.getPredicate() instanceof ExternalFunctionalPredicate) {
                // Collected now, removed after the loop to avoid modifying the set while iterating.
                externalAtoms.add(atom);
            } else if (!(atom.getPredicate() instanceof GroundingOnlyPredicate) && !(atom.getPredicate() instanceof StandardPredicate)) {
                throw new IllegalStateException("Unknown predicate type: " + atom.getPredicate().getClass().getName());
            }
        }

        atoms.removeAll(externalAtoms);
        return passthrough;
    }

    /**
     * Create a new, empty fringe of the kind selected by the configured search type.
     *
     * @return a fresh SearchFringe instance.
     * @throws IllegalStateException if the search type has no matching fringe.
     */
    private SearchFringe createFringe() {
        SearchType type = this.searchType;

        switch (type) {
            case BFS:
                return new SearchFringe.BFSSearchFringe();
            case DFS:
                return new SearchFringe.DFSSearchFringe();
            case UCS:
                return new SearchFringe.UCSSearchFringe();
            case BoundedUCS:
                return new SearchFringe.BoundedUCSSearchFringe();
            case BoundedDFS:
                return new SearchFringe.BoundedDFSSearchFringe();
            default:
                throw new IllegalStateException("Unknown search type: " + type);
        }
    }

    /**
     * The kinds of search fringe that createFringe() can build.
     * The configured value is parsed by name with SearchType.valueOf(),
     * so option values must match these constant names exactly.
     */
    public static enum SearchType {
        // Selects SearchFringe.BFSSearchFringe (presumably breadth-first search).
        BFS,
        // Selects SearchFringe.DFSSearchFringe (presumably depth-first search).
        DFS,
        // Selects SearchFringe.UCSSearchFringe (presumably uniform-cost search).
        UCS,
        // Selects SearchFringe.BoundedUCSSearchFringe.
        BoundedUCS,
        // Selects SearchFringe.BoundedDFSSearchFringe.
        BoundedDFS;

    }
}

