/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.graphscope.common.ir.meta.glogue.calcite.handler;

import com.alibaba.graphscope.common.ir.meta.glogue.DetailedSourceCost;
import com.alibaba.graphscope.common.ir.meta.glogue.PrimitiveCountEstimator;
import com.alibaba.graphscope.common.ir.meta.glogue.Utils;
import com.alibaba.graphscope.common.ir.rel.CommonTableScan;
import com.alibaba.graphscope.common.ir.rel.GraphExtendIntersect;
import com.alibaba.graphscope.common.ir.rel.GraphJoinDecomposition;
import com.alibaba.graphscope.common.ir.rel.GraphPattern;
import com.alibaba.graphscope.common.ir.rel.graph.AbstractBindableTableScan;
import com.alibaba.graphscope.common.ir.rel.graph.GraphLogicalPathExpand;
import com.alibaba.graphscope.common.ir.rel.graph.GraphLogicalSource;
import com.alibaba.graphscope.common.ir.rel.graph.GraphPhysicalExpand;
import com.alibaba.graphscope.common.ir.rel.graph.GraphPhysicalGetV;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.ExtendStep;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.GlogueQuery;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.pattern.FuzzyPatternVertex;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.pattern.Pattern;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.pattern.PatternEdge;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.pattern.PatternVertex;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.pattern.SinglePatternVertex;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.calcite.plan.GraphOptCluster;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.plan.volcano.RelSubset;
import org.apache.calcite.plan.volcano.VolcanoPlanner;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.core.Union;
import org.apache.calcite.rel.metadata.BuiltInMetadata;
import org.apache.calcite.rel.metadata.RelMdRowCount;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.rules.MultiJoin;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.commons.lang3.ObjectUtils;
import org.checkerframework.checker.nullness.qual.Nullable;

/**
 * Calcite row-count metadata handler for graph relational operators.
 *
 * <p>A {@link GraphPattern} is estimated, in order of preference, from:
 *
 * <ol>
 *   <li>the glogue-backed {@link PrimitiveCountEstimator};
 *   <li>a feasible {@link GraphExtendIntersect} or {@link GraphJoinDecomposition} already
 *       registered in the pattern's {@link VolcanoPlanner} subset;
 *   <li>an edge-product heuristic over the pattern's edges and vertices.
 * </ol>
 *
 * <p>Graph scans use (and, for {@link GraphLogicalSource}, lazily populate) their cached cost.
 * Joins use the cluster's cached cost when present. Common relational operators are delegated to
 * Calcite's {@link RelMdRowCount}.
 */
public class GraphRowCountHandler implements BuiltInMetadata.RowCount.Handler {
    private final PrimitiveCountEstimator countEstimator;
    private final RelOptPlanner optPlanner;
    private final RelMdRowCount mdRowCount;

    /**
     * @param optPlanner planner whose subsets are inspected for already-planned decompositions;
     *     only consulted when it is a {@link VolcanoPlanner}
     * @param glogueQuery statistics backing the {@link PrimitiveCountEstimator}
     */
    public GraphRowCountHandler(RelOptPlanner optPlanner, GlogueQuery glogueQuery) {
        this.optPlanner = optPlanner;
        this.countEstimator = new PrimitiveCountEstimator(glogueQuery);
        this.mdRowCount = new RelMdRowCount();
    }

    /**
     * Estimates the number of rows produced by {@code node}.
     *
     * @param node the relational node to estimate
     * @param mq metadata query used for recursive estimates
     * @return the estimated row count; may be {@code null} when a delegated Calcite estimate is
     *     unknown (e.g. a {@link Join} without a cached cost whose condition Calcite cannot estimate)
     * @throws IllegalArgumentException if the node type is unsupported, or it is a {@link MultiJoin}
     *     with no cached cost on its cluster
     */
    @Override
    public Double getRowCount(RelNode node, RelMetadataQuery mq) {
        if (node instanceof GraphPattern) {
            return getPatternRowCount((GraphPattern) node, mq);
        }
        if (node instanceof RelSubset) {
            return mq.getRowCount(((RelSubset) node).getOriginal());
        }
        if (node instanceof GraphExtendIntersect || node instanceof GraphJoinDecomposition) {
            return getDecompositionRowCount(node, mq);
        }
        if (node instanceof AbstractBindableTableScan) {
            return getScanRowCount((AbstractBindableTableScan) node, mq);
        }
        if (node instanceof GraphLogicalPathExpand
                || node instanceof GraphPhysicalExpand
                || node instanceof GraphPhysicalGetV) {
            // These operators compute their own estimate.
            return node.estimateRowCount(mq);
        }
        if (node instanceof MultiJoin) {
            RelOptCost cachedCost = getClusterCachedCost(node);
            if (cachedCost != null) {
                return cachedCost.getRows();
            }
            // No fallback exists for MultiJoin without a cached cost.
            throw unsupportedNode(node);
        }
        if (node instanceof Join) {
            RelOptCost cachedCost = getClusterCachedCost(node);
            if (cachedCost != null) {
                return cachedCost.getRows();
            }
            // RelMdRowCount may return null ("unknown"); propagate it rather than unboxing into an NPE.
            return mdRowCount.getRowCount((Join) node, mq);
        }
        if (node instanceof Union) {
            return mdRowCount.getRowCount((Union) node, mq);
        }
        if (node instanceof Filter) {
            return mdRowCount.getRowCount((Filter) node, mq);
        }
        if (node instanceof Aggregate) {
            return mdRowCount.getRowCount((Aggregate) node, mq);
        }
        if (node instanceof Sort) {
            return mdRowCount.getRowCount((Sort) node, mq);
        }
        if (node instanceof Project) {
            return mdRowCount.getRowCount((Project) node, mq);
        }
        if (node instanceof CommonTableScan) {
            return mdRowCount.getRowCount((TableScan) node, mq);
        }
        throw unsupportedNode(node);
    }

    /** Builds the error raised for nodes this handler cannot estimate. */
    private static IllegalArgumentException unsupportedNode(RelNode node) {
        return new IllegalArgumentException("can not estimate row count for the node=" + node);
    }

    /** Returns the cost cached in the node's {@link GraphOptCluster} local state, or null. */
    private static @Nullable RelOptCost getClusterCachedCost(RelNode node) {
        GraphOptCluster optCluster = (GraphOptCluster) node.getCluster();
        return optCluster.getLocalState().getCachedCost();
    }

    /**
     * Estimates a graph pattern, trying statistics first, then a planned decomposition, then the
     * edge-product heuristic.
     */
    private Double getPatternRowCount(GraphPattern graphPattern, RelMetadataQuery mq) {
        Pattern pattern = graphPattern.getPattern();
        Double countEstimate = countEstimator.estimate(pattern);
        if (countEstimate != null) {
            return countEstimate;
        }
        if (optPlanner instanceof VolcanoPlanner) {
            RelSubset subset = ((VolcanoPlanner) optPlanner).getSubset(graphPattern);
            if (subset != null) {
                Double planned = estimateFromPlannedDecomposition(graphPattern, pattern, subset, mq);
                if (planned != null) {
                    return planned;
                }
            }
        }
        return estimateByEdgeProduct(pattern);
    }

    /**
     * Estimates the pattern from an extend-intersect or join decomposition already registered in
     * {@code subset} whose inputs have a best plan.
     *
     * @return the estimate, or {@code null} if no feasible decomposition exists
     */
    private @Nullable Double estimateFromPlannedDecomposition(
            GraphPattern graphPattern, Pattern pattern, RelSubset subset, RelMetadataQuery mq) {
        GraphExtendIntersect extendIntersect = (GraphExtendIntersect) feasibleIntersects(subset);
        if (extendIntersect != null) {
            ExtendStep extendStep = extendIntersect.getGlogueEdge().getExtendStep();
            PatternVertex target = pattern.getVertexByOrder(extendStep.getTargetVertexOrder());
            Set<PatternEdge> adjacentEdges = pattern.getEdgesOf(target);
            // The sub-pattern made only of the edges touching the extension target.
            Pattern extendPattern = new Pattern();
            List<PatternVertex> extendFromVertices = Lists.newArrayList();
            for (PatternEdge edge : adjacentEdges) {
                extendPattern.addVertex(edge.getSrcVertex());
                extendPattern.addVertex(edge.getDstVertex());
                extendPattern.addEdge(edge.getSrcVertex(), edge.getDstVertex(), edge);
                extendFromVertices.add(Utils.getExtendFromVertex(edge, target));
            }
            GraphPattern subPattern = (GraphPattern) subGraphPattern(extendIntersect, 0);
            GraphPattern extendGraphPattern =
                    new GraphPattern(graphPattern.getCluster(), graphPattern.getTraitSet(), extendPattern);
            return getJoinedRowCount(subPattern, extendGraphPattern, extendFromVertices, mq);
        }
        GraphJoinDecomposition joinDecomposition =
                (GraphJoinDecomposition) feasibleJoinDecomposition(subset);
        if (joinDecomposition != null) {
            Pattern buildPattern = joinDecomposition.getBuildPattern();
            List<PatternVertex> jointVertices =
                    joinDecomposition.getJoinVertexPairs().stream()
                            .map(k -> buildPattern.getVertexByOrder(k.getRightOrderId()))
                            .collect(Collectors.toList());
            return getJoinedRowCount(
                    (GraphPattern) subGraphPattern(joinDecomposition, 0),
                    (GraphPattern) subGraphPattern(joinDecomposition, 1),
                    jointVertices,
                    mq);
        }
        return null;
    }

    /**
     * Heuristic estimate: the product of all edge counts, divided by {@code count(v)^(degree(v)-1)}
     * for every vertex with a positive degree in the pattern.
     */
    private double estimateByEdgeProduct(Pattern pattern) {
        double totalRowCount = 1.0;
        for (PatternEdge edge : pattern.getEdgeSet()) {
            totalRowCount *= countEstimator.estimate(edge);
        }
        for (PatternVertex vertex : pattern.getVertexSet()) {
            int degree = pattern.getEdgesOf(vertex).size();
            if (degree > 0) {
                totalRowCount /= Math.pow(countEstimator.estimate(vertex), degree - 1);
            }
        }
        return totalRowCount;
    }

    /**
     * Estimates an extend-intersect / join-decomposition node. If the planner holds a subset for it,
     * the subset is used. Otherwise the count is taken from the pattern it produces.
     */
    private Double getDecompositionRowCount(RelNode node, RelMetadataQuery mq) {
        if (optPlanner instanceof VolcanoPlanner) {
            RelSubset subset = ((VolcanoPlanner) optPlanner).getSubset(node);
            if (subset != null) {
                return mq.getRowCount(subset);
            }
        }
        Pattern original =
                node instanceof GraphExtendIntersect
                        ? ((GraphExtendIntersect) node).getGlogueEdge().getDstPattern()
                        : ((GraphJoinDecomposition) node).getParentPatten();
        return mq.getRowCount(new GraphPattern(node.getCluster(), node.getTraitSet(), original));
    }

    /**
     * Estimates a graph scan from its cached cost, computing and caching that cost first for a
     * {@link GraphLogicalSource} that has none.
     *
     * @throws IllegalArgumentException if no cached cost is available for the scan
     */
    private double getScanRowCount(AbstractBindableTableScan rel, RelMetadataQuery mq) {
        if (rel.getCachedCost() == null && rel instanceof GraphLogicalSource) {
            cacheSourceCost((GraphLogicalSource) rel, mq);
        }
        if (rel.getCachedCost() == null) {
            throw new IllegalArgumentException("can not estimate row count for the rel=" + rel);
        }
        return rel.estimateRowCount(mq);
    }

    /**
     * Computes the full vertex count of the source's type(s) and the filtered count after applying
     * its unique-key and regular filters. Stores both as a {@link DetailedSourceCost}.
     */
    private void cacheSourceCost(GraphLogicalSource source, RelMetadataQuery mq) {
        List<Integer> vertexTypeIds = Utils.getVertexTypeIds((RelNode) source);
        PatternVertex vertex;
        if (vertexTypeIds.size() == 1) {
            vertex = new SinglePatternVertex(vertexTypeIds.get(0));
        } else {
            vertex = new FuzzyPatternVertex(vertexTypeIds);
        }
        double fullCount =
                mq.getRowCount(new GraphPattern(source.getCluster(), source.getTraitSet(), new Pattern(vertex)));
        List<RexNode> sourceFilters = Lists.newArrayList();
        if (source.getUniqueKeyFilters() != null) {
            sourceFilters.add(source.getUniqueKeyFilters());
        }
        if (ObjectUtils.isNotEmpty(source.getFilters())) {
            sourceFilters.addAll(source.getFilters());
        }
        RexBuilder rexBuilder = source.getCluster().getRexBuilder();
        Double selectivity = mq.getSelectivity(source, RexUtil.composeConjunction(rexBuilder, sourceFilters));
        // getSelectivity may report "unknown" as null; treat that as no filtering instead of an NPE.
        double effectiveSelectivity = selectivity == null ? 1.0 : selectivity;
        source.setCachedCost(new DetailedSourceCost(fullCount, fullCount * effectiveSelectivity));
    }

    /**
     * Estimates the join of two sub-patterns on the given shared vertices as
     * {@code count(p1) * count(p2) / product(count(v) for v in jointVertices)}.
     */
    private double getJoinedRowCount(
            GraphPattern p1, GraphPattern p2, List<PatternVertex> jointVertices, RelMetadataQuery mq) {
        double count = getRowCount(p1, mq) * getRowCount(p2, mq);
        for (PatternVertex vertex : jointVertices) {
            count /= countEstimator.estimate(vertex);
        }
        return count;
    }

    /**
     * Finds an extend-intersect in {@code subSet} whose input is a subset that already has a best
     * plan.
     *
     * @return that node, or {@code null} if none exists
     */
    private @Nullable RelNode feasibleIntersects(RelSubset subSet) {
        for (RelNode rel : subSet.getRelList()) {
            if (!(rel instanceof GraphExtendIntersect)) {
                continue;
            }
            RelNode input = ((GraphExtendIntersect) rel).getInput(0);
            if (input instanceof RelSubset && ((RelSubset) input).getBest() != null) {
                return rel;
            }
        }
        return null;
    }

    /**
     * Returns input {@code subId} of {@code rel}. A {@link RelSubset} input is unwrapped to its
     * original rel.
     */
    private @Nullable RelNode subGraphPattern(RelNode rel, int subId) {
        RelNode input = rel.getInput(subId);
        return input instanceof RelSubset ? ((RelSubset) input).getOriginal() : input;
    }

    /**
     * Finds a join decomposition in {@code subSet} whose left and right inputs are subsets that both
     * already have a best plan.
     *
     * @return that node, or {@code null} if none exists
     */
    private @Nullable RelNode feasibleJoinDecomposition(RelSubset subSet) {
        for (RelNode rel : subSet.getRelList()) {
            if (!(rel instanceof GraphJoinDecomposition)) {
                continue;
            }
            GraphJoinDecomposition decomposition = (GraphJoinDecomposition) rel;
            if (!(decomposition.getLeft() instanceof RelSubset)
                    || !(decomposition.getRight() instanceof RelSubset)) {
                continue;
            }
            RelSubset left = (RelSubset) decomposition.getLeft();
            RelSubset right = (RelSubset) decomposition.getRight();
            if (left.getBest() != null && right.getBest() != null) {
                return rel;
            }
        }
        return null;
    }
}

