/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.graphscope.common.ir.rex;

import com.alibaba.graphscope.common.ir.rel.graph.AbstractBindableTableScan;
import com.alibaba.graphscope.common.ir.rel.graph.GraphLogicalSource;
import com.alibaba.graphscope.common.ir.rel.type.TableConfig;
import com.alibaba.graphscope.common.ir.rex.ClassifiedFilter;
import com.alibaba.graphscope.common.ir.rex.RexGraphVariable;
import com.alibaba.graphscope.common.ir.rex.RexVariableAliasCollector;
import com.alibaba.graphscope.common.ir.tools.GraphBuilder;
import com.alibaba.graphscope.common.ir.tools.Utils;
import com.alibaba.graphscope.common.ir.type.GraphLabelType;
import com.alibaba.graphscope.common.ir.type.GraphNameOrId;
import com.alibaba.graphscope.common.ir.type.GraphProperty;
import com.alibaba.graphscope.common.ir.type.GraphSchemaType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexDynamicParam;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitor;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Sarg;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.javatuples.Pair;

/**
 * Splits a filter condition into three groups: label filters, unique-key filters and
 * everything else.
 *
 * <p>The classifier is a {@link RexVisitorImpl} that maps every visited call to a
 * {@link Filter}. A {@code Filter} carries zero or more {@link Filter.SchemaFilter}s
 * (label / unique-key predicates, each bound to the alias id of the variable it tests)
 * plus an optional residual expression. {@link #classify(RexNode)} flattens that result
 * into a {@link ClassifiedFilter}.
 */
public class RexFilterClassifier
extends RexVisitorImpl<RexFilterClassifier.Filter> {
    private final GraphBuilder builder;
    private final @Nullable AbstractBindableTableScan tableScan;

    /**
     * @param builder used to combine sub-expressions with {@code and} / {@code or}
     * @param tableScan the scan the condition is applied to; when {@code null} no
     *     unique-key filter is ever extracted
     */
    public RexFilterClassifier(GraphBuilder builder, @Nullable AbstractBindableTableScan tableScan) {
        super(true);
        this.builder = builder;
        this.tableScan = tableScan;
    }

    /**
     * Classifies {@code condition} into label filters, the label values those filters
     * select, unique-key filters and the remaining (extra) filters.
     *
     * @param condition the filter expression to classify
     * @return the classified parts of {@code condition}
     * @throws IllegalArgumentException if a conjunction of label filters selects no
     *     common label value
     */
    public ClassifiedFilter classify(RexNode condition) {
        Filter filter = this.classifyNode(condition);
        List<RexNode> labelFilters = new ArrayList<>();
        List<RexNode> uniqueKeyFilters = new ArrayList<>();
        for (Filter.SchemaFilter schemaFilter : filter.getSchemaFilters()) {
            if (schemaFilter.getSchemaType() == Filter.SchemaType.LABEL) {
                labelFilters.add(schemaFilter.getFilter());
            } else {
                uniqueKeyFilters.add(schemaFilter.getFilter());
            }
        }
        List<RexNode> extraFilters = new ArrayList<>();
        if (filter.getOtherFilter() != null) {
            extraFilters.add(filter.getOtherFilter());
        }
        List<Comparable> labelValues = new ArrayList<>();
        for (RexNode labelFilter : labelFilters) {
            labelValues.addAll(this.getLabelValues(labelFilter));
        }
        return new ClassifiedFilter(labelFilters, labelValues, uniqueKeyFilters, extraFilters);
    }

    /**
     * Visits {@code node} with this classifier. The inherited visit methods return
     * {@code null} for anything that is not a call (literals, variables, ...); such a
     * node is reported as an unclassified "other" filter instead of {@code null}.
     */
    private Filter classifyNode(RexNode node) {
        Filter filter = node.accept(this);
        return filter != null ? filter : new Filter(ImmutableList.of(), node);
    }

    /** Collects the label values selected by a label filter. */
    private List<Comparable> getLabelValues(RexNode labelFilter) {
        return labelFilter.accept(new LabelValueCollector());
    }

    /**
     * Classifies a single call.
     *
     * <ul>
     *   <li>{@code AND}: classifies every operand and merges the results.
     *   <li>{@code OR}: becomes a schema filter only if all operands are schema filters
     *       of the same tag and type, otherwise stays unclassified.
     *   <li>{@code EQUALS} / {@code SEARCH}: becomes a label or unique-key schema filter
     *       when it has the matching shape.
     *   <li>anything else: returned unchanged as the "other" filter.
     * </ul>
     */
    @Override
    public Filter visitCall(RexCall call) {
        SqlOperator operator = call.getOperator();
        switch (operator.getKind()) {
            case AND: {
                List<Filter> operandFilters = new ArrayList<>();
                for (RexNode operand : call.getOperands()) {
                    operandFilters.add(this.classifyNode(operand));
                }
                return this.conjunctions(call, operandFilters);
            }
            case OR: {
                return this.disjunctions(call);
            }
            case EQUALS:
            case SEARCH: {
                if (this.isLabelEqualFilter(call)) {
                    return this.singleSchemaFilter(call, Filter.SchemaType.LABEL);
                }
                if (this.tableScan != null && this.isUniqueKeyEqualFilter(call)) {
                    return this.singleSchemaFilter(call, Filter.SchemaType.UNIQUE_KEY);
                }
                break;
            }
            default: {
                break;
            }
        }
        return new Filter(ImmutableList.of(), call);
    }

    /**
     * Wraps {@code call} into a filter holding exactly one schema filter, tagged with
     * the first alias id collected from the variables of {@code call}.
     */
    private Filter singleSchemaFilter(RexCall call, Filter.SchemaType schemaType) {
        RexVariableAliasCollector<Integer> aliasCollector =
                new RexVariableAliasCollector<Integer>(true, variable -> variable.getAliasId());
        List<Integer> aliasIds = call.accept(aliasCollector);
        Integer tagId = aliasIds.get(0);
        return new Filter(ImmutableList.of(new Filter.SchemaFilter(tagId, call, schemaType)), null);
    }

    /**
     * Merges the classified operands of an {@code AND}. Schema filters sharing the same
     * (tag id, schema type) are and-ed together, preserving first-seen order; all
     * residual filters are composed into one conjunction.
     *
     * @param original the {@code AND} call the filters were derived from (unused)
     * @param filters the classified operands
     */
    private Filter conjunctions(RexNode original, List<Filter> filters) {
        LinkedHashMap<Pair<Integer, Filter.SchemaType>, RexNode> mergedByKey = new LinkedHashMap<>();
        List<RexNode> otherFilters = new ArrayList<>();
        for (Filter filter : filters) {
            for (Filter.SchemaFilter schemaFilter : filter.getSchemaFilters()) {
                Pair<Integer, Filter.SchemaType> key =
                        Pair.with(schemaFilter.getTagId(), schemaFilter.getSchemaType());
                RexNode existing = mergedByKey.get(key);
                RexNode filtering = existing == null
                        ? schemaFilter.getFilter()
                        : this.builder.and(new RexNode[]{existing, schemaFilter.getFilter()});
                // skip the write when and-ing did not change the stored expression
                if (!filtering.equals(existing)) {
                    mergedByKey.put(key, filtering);
                }
            }
            if (filter.getOtherFilter() != null) {
                otherFilters.add(filter.getOtherFilter());
            }
        }
        List<Filter.SchemaFilter> andSchemaFilters = new ArrayList<>();
        mergedByKey.forEach((key, filtering) ->
                andSchemaFilters.add(new Filter.SchemaFilter(key.getValue0(), filtering, key.getValue1())));
        RexNode otherFilter = otherFilters.isEmpty()
                ? null
                : RexUtil.composeConjunction(this.builder.getRexBuilder(), otherFilters, false);
        return new Filter(andSchemaFilters, otherFilter);
    }

    /**
     * Classifies an {@code OR} call. The call becomes a single schema filter only when
     * every operand is an {@code EQUALS} / {@code SEARCH} that classifies to exactly one
     * schema filter, and all of those share the same tag id and schema type. In every
     * other case the whole call is returned as the "other" filter.
     */
    private Filter disjunctions(RexCall original) {
        Filter unclassified = new Filter(ImmutableList.of(), original);
        if (original.getOperator().getKind() != SqlKind.OR) {
            return unclassified;
        }
        Filter.SchemaFilter merged = null;
        for (RexNode operand : original.getOperands()) {
            if (operand.getKind() != SqlKind.EQUALS && operand.getKind() != SqlKind.SEARCH) {
                return unclassified;
            }
            List<Filter.SchemaFilter> curFilters = this.classifyNode(operand).getSchemaFilters();
            if (curFilters.size() != 1) {
                return unclassified;
            }
            Filter.SchemaFilter cur = curFilters.get(0);
            if (merged == null) {
                merged = cur;
                continue;
            }
            // tag ids are boxed Integers: compare by value, not by reference
            if (!Objects.equals(merged.getTagId(), cur.getTagId())
                    || merged.getSchemaType() != cur.getSchemaType()) {
                return unclassified;
            }
            merged = new Filter.SchemaFilter(
                    merged.getTagId(),
                    this.builder.or(new RexNode[]{merged.getFilter(), cur.getFilter()}),
                    merged.getSchemaType());
        }
        return merged == null ? unclassified : new Filter(ImmutableList.of(merged), null);
    }

    /** Returns whether {@code rexCall} compares a label-typed operand with a literal. */
    private boolean isLabelEqualFilter(RexCall rexCall) {
        return RexFilterClassifier.isLabelEqualFilter0(rexCall) != null;
    }

    /**
     * Returns the literal side of a label filter, or {@code null} if {@code condition}
     * is not one. A label filter is an {@code EQUALS} / {@code SEARCH} call whose one
     * operand has a {@link GraphLabelType} type and whose other operand is a literal;
     * a {@link Sarg} literal is accepted only if it consists of points.
     */
    private static @Nullable RexLiteral isLabelEqualFilter0(RexNode condition) {
        if (!(condition instanceof RexCall)) {
            return null;
        }
        RexCall rexCall = (RexCall) condition;
        SqlKind kind = rexCall.getOperator().getKind();
        if (kind != SqlKind.EQUALS && kind != SqlKind.SEARCH) {
            return null;
        }
        RexNode left = rexCall.getOperands().get(0);
        RexNode right = rexCall.getOperands().get(1);
        if (left.getType() instanceof GraphLabelType && right instanceof RexLiteral) {
            return RexFilterClassifier.pointLiteralOrNull((RexLiteral) right);
        }
        if (right.getType() instanceof GraphLabelType && left instanceof RexLiteral) {
            return RexFilterClassifier.pointLiteralOrNull((RexLiteral) left);
        }
        return null;
    }

    /** Returns {@code literal}, unless its value is a {@link Sarg} that is not made of points. */
    private static @Nullable RexLiteral pointLiteralOrNull(RexLiteral literal) {
        Comparable value = literal.getValue();
        if (value instanceof Sarg && !((Sarg) value).isPoints()) {
            return null;
        }
        return literal;
    }

    /**
     * Returns whether {@code condition} is an {@code EQUALS} / {@code SEARCH} call that
     * compares a unique key of the table scan with a literal or dynamic parameter.
     * Only applies when the table scan is a {@link GraphLogicalSource}; a {@link Sarg}
     * literal is accepted only if it consists of points.
     */
    private boolean isUniqueKeyEqualFilter(RexNode condition) {
        if (!(this.tableScan instanceof GraphLogicalSource) || !(condition instanceof RexCall)) {
            return false;
        }
        RexCall rexCall = (RexCall) condition;
        SqlKind kind = rexCall.getOperator().getKind();
        if (kind != SqlKind.EQUALS && kind != SqlKind.SEARCH) {
            return false;
        }
        RexNode left = rexCall.getOperands().get(0);
        RexNode right = rexCall.getOperands().get(1);
        if (this.isUniqueKey(left, (RelNode) this.tableScan) && this.isLiteralOrDynamicParams(right)) {
            return RexFilterClassifier.isPointValue(right);
        }
        if (this.isUniqueKey(right, (RelNode) this.tableScan) && this.isLiteralOrDynamicParams(left)) {
            return RexFilterClassifier.isPointValue(left);
        }
        return false;
    }

    /** Returns {@code false} only for a literal whose value is a {@link Sarg} that is not made of points. */
    private static boolean isPointValue(RexNode node) {
        if (!(node instanceof RexLiteral)) {
            return true;
        }
        Comparable value = ((RexLiteral) node).getValue();
        return !(value instanceof Sarg) || ((Sarg) value).isPoints();
    }

    /** Returns whether {@code rexNode} is a graph variable referring to a unique key of {@code tableScan}. */
    private boolean isUniqueKey(RexNode rexNode, RelNode tableScan) {
        if (rexNode instanceof RexGraphVariable) {
            return this.isUniqueKey((RexGraphVariable) rexNode, tableScan);
        }
        return false;
    }

    /**
     * Returns whether {@code variable} refers to a unique key of {@code tableScan}:
     * either its property is of kind {@code ID}, or it is of kind {@code KEY} and the
     * resolved property id is reported as a key by every table of the scan's table config.
     */
    private boolean isUniqueKey(RexGraphVariable variable, RelNode tableScan) {
        GraphProperty property = variable.getProperty();
        if (property == null) {
            return false;
        }
        switch (property.getOpt()) {
            case ID: {
                return true;
            }
            case KEY: {
                GraphSchemaType schemaType =
                        (GraphSchemaType) tableScan.getRowType().getFieldList().get(0).getType();
                ImmutableBitSet propertyIds = this.getPropertyIds(property, schemaType);
                TableConfig tableConfig = ((AbstractBindableTableScan) tableScan).getTableConfig();
                return !propertyIds.isEmpty()
                        && tableConfig.getTables().stream().allMatch(k -> k.isKey(propertyIds));
            }
            default: {
                return false;
            }
        }
    }

    /**
     * Resolves a {@code KEY} property to a single-element bit set: the key's id when it
     * is given by id, otherwise the index of the field in {@code schemaType} whose name
     * equals the key's name. Returns an empty set for non-{@code KEY} properties or when
     * no field matches.
     */
    private ImmutableBitSet getPropertyIds(GraphProperty property, GraphSchemaType schemaType) {
        if (property.getOpt() != GraphProperty.Opt.KEY) {
            return ImmutableBitSet.of();
        }
        GraphNameOrId key = property.getKey();
        if (key.getOpt() == GraphNameOrId.Opt.ID) {
            return ImmutableBitSet.of(new int[]{key.getId()});
        }
        List<RelDataTypeField> fields = schemaType.getFieldList();
        for (int i = 0; i < fields.size(); ++i) {
            if (fields.get(i).getName().equals(key.getName())) {
                return ImmutableBitSet.of(new int[]{i});
            }
        }
        return ImmutableBitSet.of();
    }

    /** Returns whether {@code node} is a literal or a dynamic parameter. */
    private boolean isLiteralOrDynamicParams(RexNode node) {
        return node instanceof RexLiteral || node instanceof RexDynamicParam;
    }

    /**
     * Result of classifying one expression: the schema filters extracted from it and the
     * residual expression (if any) that could not be classified.
     */
    public static class Filter {
        private final List<SchemaFilter> schemaFilters;
        private final @Nullable RexNode otherFilter;

        /**
         * @param schemaFilters the extracted label / unique-key filters
         * @param otherFilter the unclassified remainder, or {@code null} if there is none
         */
        public Filter(List<SchemaFilter> schemaFilters, @Nullable RexNode otherFilter) {
            this.schemaFilters = schemaFilters;
            this.otherFilter = otherFilter;
        }

        /** Returns a read-only view of the extracted schema filters. */
        public List<SchemaFilter> getSchemaFilters() {
            return Collections.unmodifiableList(this.schemaFilters);
        }

        /** Returns the unclassified remainder, or {@code null} if there is none. */
        public @Nullable RexNode getOtherFilter() {
            return this.otherFilter;
        }

        /** Kind of schema information a {@link SchemaFilter} constrains. */
        public static enum SchemaType {
            LABEL,
            UNIQUE_KEY;

        }

        /** A label or unique-key predicate bound to the alias id of the variable it tests. */
        public static class SchemaFilter {
            private final Integer tagId;
            private final RexNode filter;
            private final SchemaType schemaType;

            /**
             * @param tagId alias id of the variable the predicate refers to
             * @param filtering the predicate expression
             * @param schemaType whether the predicate is a label or a unique-key filter
             */
            public SchemaFilter(Integer tagId, RexNode filtering, SchemaType schemaType) {
                this.tagId = tagId;
                this.filter = filtering;
                this.schemaType = schemaType;
            }

            public Integer getTagId() {
                return this.tagId;
            }

            public RexNode getFilter() {
                return this.filter;
            }

            public SchemaType getSchemaType() {
                return this.schemaType;
            }
        }
    }

    /**
     * Computes the label values selected by a label filter: the intersection of the
     * operands' values for {@code AND}, their de-duplicated union for {@code OR}, and
     * the literal's values for a single {@code EQUALS} / {@code SEARCH} label filter.
     */
    private static class LabelValueCollector
    extends RexVisitorImpl<List<Comparable>> {
        public LabelValueCollector() {
            super(true);
        }

        /**
         * @throws IllegalArgumentException if the operands of an {@code AND} have no
         *     label value in common
         */
        @Override
        public List<Comparable> visitCall(RexCall call) {
            SqlOperator operator = call.getOperator();
            switch (operator.getKind()) {
                case AND: {
                    List<Comparable> andLabels = new ArrayList<>();
                    for (RexNode operand : call.getOperands()) {
                        List<Comparable> cur = operand.accept(this);
                        // snapshot for the error message: retainAll empties andLabels on a miss
                        List<Comparable> previous = new ArrayList<>(andLabels);
                        if (andLabels.isEmpty()) {
                            andLabels.addAll(cur);
                        } else {
                            andLabels.retainAll(cur);
                        }
                        if (andLabels.isEmpty()) {
                            throw new IllegalArgumentException("cannot find common labels between values=" + previous + " and values=" + cur);
                        }
                    }
                    // must return here: falling through would replace the intersection by the union
                    return andLabels;
                }
                case OR: {
                    List<Comparable> orLabels = new ArrayList<>();
                    for (RexNode operand : call.getOperands()) {
                        orLabels.addAll(operand.accept(this));
                    }
                    return orLabels.stream().distinct().collect(Collectors.toList());
                }
                case EQUALS:
                case SEARCH: {
                    RexLiteral labelLiteral = RexFilterClassifier.isLabelEqualFilter0(call);
                    if (labelLiteral == null) {
                        break;
                    }
                    return Utils.getValuesAsList(labelLiteral.getValueAs(Comparable.class));
                }
                default: {
                    break;
                }
            }
            return ImmutableList.of();
        }
    }
}

