/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.graphscope.common.ir.meta.glogue.calcite.handler;

import com.alibaba.graphscope.common.ir.meta.glogue.Utils;
import com.alibaba.graphscope.common.ir.rel.GraphPattern;
import com.alibaba.graphscope.common.ir.rel.graph.AbstractBindableTableScan;
import com.alibaba.graphscope.common.ir.rel.graph.GraphLogicalExpand;
import com.alibaba.graphscope.common.ir.rel.graph.GraphLogicalGetV;
import com.alibaba.graphscope.common.ir.rel.graph.GraphLogicalSource;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.pattern.ElementDetails;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.pattern.FuzzyPatternEdge;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.pattern.FuzzyPatternVertex;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.pattern.Pattern;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.pattern.PatternEdge;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.pattern.PatternVertex;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.pattern.SinglePatternEdge;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.pattern.SinglePatternVertex;
import com.alibaba.graphscope.common.ir.rel.metadata.schema.EdgeTypeId;
import com.alibaba.graphscope.common.ir.rex.RexGraphVariable;
import com.alibaba.graphscope.common.ir.rex.RexVariableAliasCollector;
import com.alibaba.graphscope.common.ir.tools.config.GraphOpt;
import com.alibaba.graphscope.common.ir.type.GraphNameOrId;
import com.alibaba.graphscope.common.ir.type.GraphProperty;
import com.alibaba.graphscope.common.ir.type.GraphSchemaType;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rel.metadata.BuiltInMetadata;
import org.apache.calcite.rel.metadata.RelMdSelectivity;
import org.apache.calcite.rel.metadata.RelMdUtil;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Sarg;
import org.checkerframework.checker.nullness.qual.Nullable;

public class GraphSelectivityHandler
extends RelMdSelectivity
implements BuiltInMetadata.Selectivity.Handler {
    private static final double FACTOR = 1.2;

    public @Nullable Double getSelectivity(RelNode node, RelMetadataQuery mq, @Nullable RexNode condition) {
        if (node instanceof TableScan) {
            return this.getSelectivity((TableScan)node, mq, condition);
        }
        return RelMdUtil.guessSelectivity((RexNode)condition);
    }

    public Double getSelectivity(TableScan tableScan, RelMetadataQuery mq, RexNode condition) {
        double total = 1.0;
        if (condition == null || condition.isAlwaysTrue()) {
            return total;
        }
        for (RexNode conjunction : RelOptUtil.conjunctions((RexNode)condition)) {
            double perSelectivity = 0.0;
            for (RexNode disjunction : RelOptUtil.disjunctions((RexNode)conjunction)) {
                perSelectivity += this.guessSelectivity(tableScan, mq, disjunction);
            }
            total *= perSelectivity;
        }
        return total;
    }

    private double guessSelectivity(TableScan tableScan, RelMetadataQuery mq, RexNode condition) {
        RexVariableAliasCollector<Pair> varTableScanCollector = new RexVariableAliasCollector<Pair>(true, var -> {
            TableScan scanByAlias = this.getTableScanByAlias((RelNode)tableScan, var.getAliasId());
            Preconditions.checkArgument((scanByAlias != null ? 1 : 0) != 0, (Object)("can not find table scan for aliasId=" + var.getAliasId()));
            return Pair.of((Object)var, (Object)scanByAlias);
        });
        double maxCountForUniqueKeys = 0.0;
        double maxCount = 0.0;
        for (Pair varTableScan : (List)condition.accept(varTableScanCollector)) {
            RexGraphVariable var2 = (RexGraphVariable)((Object)varTableScan.left);
            TableScan scan = (TableScan)varTableScan.right;
            double count = this.getFullRowCount(scan, mq);
            if (this.isUniqueKey(var2, (RelNode)scan) && count > maxCountForUniqueKeys) {
                maxCountForUniqueKeys = count;
            }
            if (!(count > maxCount)) continue;
            maxCount = count;
        }
        if (Double.compare(maxCountForUniqueKeys, 0.0) != 0) {
            if (condition.isA(SqlKind.SEARCH)) {
                RexNode right = (RexNode)((RexCall)condition).getOperands().get(1);
                Sarg sarg = (Sarg)((RexLiteral)right).getValueAs(Sarg.class);
                return (double)sarg.pointCount / maxCountForUniqueKeys;
            }
            if (condition.isA(SqlKind.EQUALS)) {
                return 1.0 / maxCountForUniqueKeys;
            }
        }
        return Math.max(RelMdUtil.guessSelectivity((RexNode)condition), this.relax(1.0 / maxCount));
    }

    private double relax(double value) {
        double relaxValue = value * 1.2;
        return Double.compare(relaxValue, 1.0) > 0 ? 1.0 : relaxValue;
    }

    private boolean isUniqueKey(RexGraphVariable var, RelNode tableScan) {
        if (var.getProperty() == null) {
            return false;
        }
        switch (var.getProperty().getOpt()) {
            case ID: {
                return true;
            }
            case KEY: {
                GraphSchemaType schemaType = (GraphSchemaType)((RelDataTypeField)tableScan.getRowType().getFieldList().get(0)).getType();
                ImmutableBitSet propertyIds = this.getPropertyIds(var.getProperty(), schemaType);
                if (propertyIds.isEmpty() || !tableScan.getTable().isKey(propertyIds)) break;
                return true;
            }
        }
        return false;
    }

    private ImmutableBitSet getPropertyIds(GraphProperty property, GraphSchemaType schemaType) {
        if (property.getOpt() != GraphProperty.Opt.KEY) {
            return ImmutableBitSet.of();
        }
        GraphNameOrId key = property.getKey();
        if (key.getOpt() == GraphNameOrId.Opt.ID) {
            return ImmutableBitSet.of((int[])new int[]{key.getId()});
        }
        for (int i = 0; i < schemaType.getFieldList().size(); ++i) {
            RelDataTypeField field = (RelDataTypeField)schemaType.getFieldList().get(i);
            if (!field.getName().equals(key.getName())) continue;
            return ImmutableBitSet.of((int[])new int[]{i});
        }
        return ImmutableBitSet.of();
    }

    private TableScan getTableScanByAlias(RelNode top, int aliasId) {
        ArrayList queue = Lists.newArrayList((Object[])new RelNode[]{top});
        while (!queue.isEmpty()) {
            RelNode cur = (RelNode)queue.remove(0);
            if (cur instanceof AbstractBindableTableScan && (aliasId == -1 || ((AbstractBindableTableScan)cur).getAliasId() == aliasId)) {
                return (AbstractBindableTableScan)cur;
            }
            queue.addAll(cur.getInputs());
        }
        return null;
    }

    private double getFullRowCount(TableScan rel, RelMetadataQuery mq) {
        if (rel instanceof GraphLogicalSource || rel instanceof GraphLogicalGetV) {
            List<Integer> vertexTypeIds = Utils.getVertexTypeIds((RelNode)rel);
            PatternVertex vertex = vertexTypeIds.size() == 1 ? new SinglePatternVertex(vertexTypeIds.get(0)) : new FuzzyPatternVertex(vertexTypeIds);
            return mq.getRowCount((RelNode)new GraphPattern(rel.getCluster(), rel.getTraitSet(), new Pattern(vertex)));
        }
        if (rel instanceof GraphLogicalExpand) {
            List<EdgeTypeId> edgeTypeIds = Utils.getEdgeTypeIds((RelNode)rel);
            List<Integer> srcVertexTypeIds = edgeTypeIds.stream().map(k -> k.getSrcLabelId()).collect(Collectors.toList());
            List<Integer> dstVertexTypeIds = edgeTypeIds.stream().map(k -> k.getDstLabelId()).collect(Collectors.toList());
            PatternVertex srcVertex = srcVertexTypeIds.size() == 1 ? new SinglePatternVertex((Integer)srcVertexTypeIds.get(0), 0) : new FuzzyPatternVertex(srcVertexTypeIds, 0);
            PatternVertex dstVertex = dstVertexTypeIds.size() == 1 ? new SinglePatternVertex((Integer)dstVertexTypeIds.get(0), 1) : new FuzzyPatternVertex(dstVertexTypeIds, 1);
            boolean isBoth = ((GraphLogicalExpand)rel).getOpt() == GraphOpt.Expand.BOTH;
            PatternEdge edge = edgeTypeIds.size() == 1 ? new SinglePatternEdge(srcVertex, dstVertex, edgeTypeIds.get(0), 0, isBoth, new ElementDetails()) : new FuzzyPatternEdge(srcVertex, dstVertex, edgeTypeIds, 0, isBoth, new ElementDetails());
            Pattern pattern = new Pattern();
            pattern.addVertex(srcVertex);
            pattern.addVertex(dstVertex);
            pattern.addEdge(srcVertex, dstVertex, edge);
            return mq.getRowCount((RelNode)new GraphPattern(rel.getCluster(), rel.getTraitSet(), pattern));
        }
        throw new IllegalArgumentException("can not estimate row count for the rel=" + rel);
    }
}

