/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.graphscope.common.ir.meta.glogue.calcite.handler;

import com.alibaba.graphscope.common.config.PlannerConfig;
import com.alibaba.graphscope.common.ir.rel.GraphExtendIntersect;
import com.alibaba.graphscope.common.ir.rel.GraphJoinDecomposition;
import com.alibaba.graphscope.common.ir.rel.GraphPattern;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.GlogueExtendIntersectEdge;
import java.util.List;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptCostFactory;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.metadata.BuiltInMetadata;
import org.apache.calcite.rel.metadata.RelMetadataQuery;

public class GraphNonCumulativeCostHandler
implements BuiltInMetadata.NonCumulativeCost.Handler {
    private final RelOptPlanner optPlanner;
    private final RelOptCostFactory costFactory;
    private final PlannerConfig plannerConfig;

    public GraphNonCumulativeCostHandler(RelOptPlanner optPlanner, PlannerConfig plannerConfig) {
        this.optPlanner = optPlanner;
        this.costFactory = optPlanner.getCostFactory();
        this.plannerConfig = plannerConfig;
    }

    public RelOptCost getNonCumulativeCost(RelNode node, RelMetadataQuery mq) {
        if (node instanceof GraphExtendIntersect) {
            GlogueExtendIntersectEdge glogueEdge = ((GraphExtendIntersect)node).getGlogueEdge();
            double weight = glogueEdge.getExtendStep().getWeight();
            double srcPatternCount = mq.getRowCount(node.getInput(0));
            double dRows = weight * srcPatternCount;
            if (glogueEdge.getExtendStep().getExtendEdges().size() > 1) {
                dRows *= (double)this.plannerConfig.getIntersectCostFactor();
            }
            double dCpu = dRows + 1.0;
            double dIo = mq.getRowCount(node);
            return this.costFactory.makeCost(dRows, dCpu, dIo);
        }
        if (node instanceof GraphPattern) {
            int patternSize = ((GraphPattern)node).getPattern().getVertexNumber();
            if (patternSize <= 1) {
                double dRows = mq.getRowCount(node);
                return this.costFactory.makeCost(dRows, dRows + 1.0, dRows);
            }
            return this.costFactory.makeInfiniteCost();
        }
        if (node instanceof GraphJoinDecomposition) {
            GraphJoinDecomposition decomposition = (GraphJoinDecomposition)node;
            double probeCount = mq.getRowCount(decomposition.getLeft());
            double buildCount = mq.getRowCount(decomposition.getRight());
            List<GraphJoinDecomposition.JoinVertexPair> joinVertexPairs = decomposition.getJoinVertexPairs();
            double dRows = joinVertexPairs.stream().allMatch(k -> k.isForeignKey()) ? Math.min(probeCount, buildCount) * 2.0 : (double)this.plannerConfig.getJoinCostFactor1() * probeCount + (double)this.plannerConfig.getJoinCostFactor2() * buildCount;
            double dCpu = dRows + 1.0;
            double dIo = mq.getRowCount(node);
            return this.costFactory.makeCost(dRows, dCpu, dIo);
        }
        return node.computeSelfCost(this.optPlanner, mq);
    }
}

