/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.graphscope.common.ir.planner.rules;

import com.alibaba.graphscope.common.ir.meta.glogue.CountHandler;
import com.alibaba.graphscope.common.ir.meta.glogue.ExtendWeightEstimator;
import com.alibaba.graphscope.common.ir.meta.glogue.Utils;
import com.alibaba.graphscope.common.ir.meta.glogue.calcite.GraphRelMetadataQuery;
import com.alibaba.graphscope.common.ir.rel.GraphExtendIntersect;
import com.alibaba.graphscope.common.ir.rel.GraphPattern;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.ExtendEdge;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.ExtendStep;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.GlogueExtendIntersectEdge;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.pattern.Pattern;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.pattern.PatternEdge;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.pattern.PatternVertex;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.tools.RelBuilderFactory;
import org.checkerframework.checker.nullness.qual.Nullable;

public class ExtendIntersectRule<C extends Config>
extends RelRule<C> {
    protected ExtendIntersectRule(C config) {
        super(config);
    }

    public void onMatch(RelOptRuleCall call) {
        List<GraphExtendIntersect> edges = this.getExtendIntersectEdges((GraphPattern)call.rel(0), (GraphRelMetadataQuery)call.getMetadataQuery());
        for (GraphExtendIntersect edge : edges) {
            call.transformTo((RelNode)edge);
        }
    }

    private List<GraphExtendIntersect> getExtendIntersectEdges(final GraphPattern graphPattern, final GraphRelMetadataQuery mq) {
        HeuristicComparator comparator = new HeuristicComparator(graphPattern);
        ExtendWeightEstimator estimator = new ExtendWeightEstimator(new CountHandler(){

            @Override
            public double handle(Pattern pattern) {
                return mq.getRowCount((RelNode)new GraphPattern(graphPattern.getCluster(), graphPattern.getTraitSet(), pattern));
            }

            @Override
            public double labelConstraintsDeltaCost(PatternEdge edge, PatternVertex target) {
                return ((Config)ExtendIntersectRule.this.config).labelConstraintsEnabled() ? mq.getGlogueQuery().getLabelConstraintsDeltaCost(edge, target) : 0.0;
            }
        });
        Pattern pattern = graphPattern.getPattern();
        int patternSize = pattern.getVertexNumber();
        ArrayList edges = Lists.newArrayList();
        if (patternSize <= 1) {
            return edges;
        }
        PruningStrategy pruningStrategy = new PruningStrategy(pattern);
        for (PatternVertex vertex : pattern.getVertexSet()) {
            if (pruningStrategy.toPrune(vertex)) continue;
            edges.add(this.createExtendIntersect(graphPattern, vertex, estimator));
        }
        Collections.sort(edges, comparator.getEdgeComparator());
        return edges;
    }

    private GraphExtendIntersect createExtendIntersect(GraphPattern graphPattern, PatternVertex target, ExtendWeightEstimator estimator) {
        Pattern dst = graphPattern.getPattern();
        Pattern src = new Pattern(dst);
        src.setPatternId(UUID.randomUUID().hashCode());
        src.removeVertex(target);
        ArrayList adjacentEdges = Lists.newArrayList(dst.getEdgesOf(target));
        double totalWeight = estimator.estimate(adjacentEdges, target);
        List<ExtendEdge> extendEdges = adjacentEdges.stream().map(k -> {
            PatternVertex extendFrom = Utils.getExtendFromVertex(k, target);
            return new ExtendEdge(src.getVertexOrder(extendFrom), k.getEdgeTypeIds(), Utils.getExtendDirection(k, target), estimator.estimate((PatternEdge)k, target), k.getElementDetails());
        }).collect(Collectors.toList());
        ExtendStep extendStep = new ExtendStep(target.getVertexTypeIds(), dst.getVertexOrder(target), extendEdges, totalWeight);
        GlogueExtendIntersectEdge glogueEdge = new GlogueExtendIntersectEdge(src, dst, extendStep, this.getOrderMapping(src, dst));
        return new GraphExtendIntersect(graphPattern.getCluster(), graphPattern.getTraitSet(), (RelNode)new GraphPattern(graphPattern.getCluster(), graphPattern.getTraitSet(), src), glogueEdge);
    }

    private Map<Integer, Integer> getOrderMapping(Pattern src, Pattern dst) {
        HashMap srcToDstOrderMap = Maps.newHashMap();
        for (PatternVertex vertex : src.getVertexSet()) {
            Integer dstOrder = dst.getVertexOrder(vertex);
            Preconditions.checkArgument((dstOrder != null ? 1 : 0) != 0, (String)"vertex %s is not in dst pattern %s", (Object)vertex, (Object)dst);
            srcToDstOrderMap.put(src.getVertexOrder(vertex), dstOrder);
        }
        return srcToDstOrderMap;
    }

    public static class Config
    implements RelRule.Config {
        public static Config DEFAULT = new Config().withOperandSupplier(b0 -> b0.operand(GraphPattern.class).anyInputs());
        private RelRule.OperandTransform operandSupplier;
        private @Nullable String description;
        private RelBuilderFactory builderFactory;
        private int maxPatternSizeInGlogue;
        private boolean labelConstraintsEnabled;

        public RelRule toRule() {
            return new ExtendIntersectRule<Config>(this);
        }

        public Config withRelBuilderFactory(RelBuilderFactory relBuilderFactory) {
            this.builderFactory = relBuilderFactory;
            return this;
        }

        public Config withDescription(@Nullable String s) {
            this.description = s;
            return this;
        }

        public Config withOperandSupplier(RelRule.OperandTransform operandTransform) {
            this.operandSupplier = operandTransform;
            return this;
        }

        public Config withMaxPatternSizeInGlogue(int maxPatternSizeInGlogue) {
            this.maxPatternSizeInGlogue = maxPatternSizeInGlogue;
            return this;
        }

        public Config withLabelConstraintsEnabled(boolean labelConstraintsEnabled) {
            this.labelConstraintsEnabled = labelConstraintsEnabled;
            return this;
        }

        public boolean labelConstraintsEnabled() {
            return this.labelConstraintsEnabled;
        }

        public RelRule.OperandTransform operandSupplier() {
            return this.operandSupplier;
        }

        public @Nullable String description() {
            return this.description;
        }

        public RelBuilderFactory relBuilderFactory() {
            return this.builderFactory;
        }

        public int getMaxPatternSizeInGlogue() {
            return this.maxPatternSizeInGlogue;
        }
    }

    private static class PruningStrategy {
        private final List<Predicate<PatternVertex>> predicates = Lists.newArrayList();

        public PruningStrategy(Pattern pattern) {
            this.predicates.add(v -> {
                Pattern clone = new Pattern(pattern);
                List<Set<PatternVertex>> connectedSets = clone.removeVertex((PatternVertex)v);
                return connectedSets.size() != 1;
            });
            List optionalVertices = pattern.getVertexSet().stream().filter(k -> k.getElementDetails().isOptional()).collect(Collectors.toList());
            if (!optionalVertices.isEmpty()) {
                this.predicates.add(v -> !optionalVertices.contains(v));
            } else {
                this.predicates.add(v -> {
                    Pattern clone = new Pattern(pattern);
                    clone.removeVertex((PatternVertex)v);
                    return clone.getEdgeSet().stream().anyMatch(k -> k.getElementDetails().isOptional());
                });
            }
        }

        public boolean toPrune(PatternVertex target) {
            for (Predicate<PatternVertex> predicate : this.predicates) {
                if (!predicate.test(target)) continue;
                return true;
            }
            return false;
        }
    }

    private static class HeuristicComparator {
        private final GraphPattern graphPattern;

        public HeuristicComparator(GraphPattern graphPattern) {
            this.graphPattern = graphPattern;
        }

        public Comparator<PatternVertex> getVertexComparator() {
            Pattern pattern = this.graphPattern.getPattern();
            return (v1, v2) -> pattern.getDegree((PatternVertex)v2) - pattern.getDegree((PatternVertex)v1);
        }

        public Comparator<GraphExtendIntersect> getEdgeComparator() {
            return (e1, e2) -> {
                ExtendStep step1 = e1.getGlogueEdge().getExtendStep();
                ExtendStep step2 = e2.getGlogueEdge().getExtendStep();
                int compareWeight = Double.compare(step1.getWeight(), step2.getWeight());
                if (compareWeight != 0) {
                    return compareWeight;
                }
                PatternVertex targetVertex1 = e1.getGlogueEdge().getDstPattern().getVertexByOrder(step1.getTargetVertexOrder());
                PatternVertex targetVertex2 = e2.getGlogueEdge().getDstPattern().getVertexByOrder(step2.getTargetVertexOrder());
                return targetVertex2.getId() - targetVertex1.getId();
            };
        }
    }
}

