/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.graphscope.common.ir.rel.metadata.schema;

import com.alibaba.graphscope.common.ir.meta.IrMetaStats;
import com.alibaba.graphscope.common.ir.meta.glogue.Utils;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.pattern.PatternDirection;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.pattern.PatternEdge;
import com.alibaba.graphscope.common.ir.rel.metadata.glogue.pattern.PatternVertex;
import com.alibaba.graphscope.common.ir.rel.metadata.schema.EdgeTypeId;
import com.alibaba.graphscope.groot.common.schema.api.EdgeRelation;
import com.alibaba.graphscope.groot.common.schema.api.GraphEdge;
import com.alibaba.graphscope.groot.common.schema.api.GraphSchema;
import com.alibaba.graphscope.groot.common.schema.api.GraphStatistics;
import com.alibaba.graphscope.groot.common.schema.api.GraphVertex;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.google.common.util.concurrent.AtomicDouble;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import org.jgrapht.Graph;
import org.jgrapht.graph.DirectedPseudograph;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Schema-level view of a graph: a directed multigraph whose vertices are vertex label ids and
 * whose edges are {@link EdgeTypeId}s, together with a cardinality for every vertex type and
 * edge type.
 *
 * <p>Cardinalities come from one of three places, depending on the constructor used: maps
 * supplied by the caller, a default of {@code 1.0} for every type, or the counts reported by a
 * {@link GraphStatistics} (where a count of zero is stored as {@code 1.0}).
 */
public class GlogueSchema {
    private static final Logger logger = LoggerFactory.getLogger(GlogueSchema.class);

    /** Cardinality used when no count is known, or when the known count is zero. */
    private static final double DEFAULT_CARDINALITY = 1.0;

    private final Graph<Integer, EdgeTypeId> schemaGraph;
    private final HashMap<Integer, Double> vertexTypeCardinality;
    private final HashMap<EdgeTypeId, Double> edgeTypeCardinality;

    /**
     * Builds the schema graph from {@code graphSchema} and uses the given cardinality maps.
     *
     * <p>The maps are stored by reference, not copied, so later changes made by the caller are
     * visible through this instance.
     *
     * @param graphSchema source of vertex types and edge relations
     * @param vertexTypeCardinality cardinality per vertex label id
     * @param edgeTypeCardinality cardinality per edge type
     */
    public GlogueSchema(GraphSchema graphSchema, HashMap<Integer, Double> vertexTypeCardinality, HashMap<EdgeTypeId, Double> edgeTypeCardinality) {
        this.schemaGraph = new DirectedPseudograph<>(EdgeTypeId.class);
        this.vertexTypeCardinality = vertexTypeCardinality;
        this.edgeTypeCardinality = edgeTypeCardinality;
        for (GraphVertex vertex : graphSchema.getVertexList()) {
            this.addVertexType(vertex);
        }
        for (GraphEdge edge : graphSchema.getEdgeList()) {
            for (EdgeRelation relation : edge.getRelationList()) {
                this.addEdgeType(edge, relation);
            }
        }
    }

    /**
     * Builds the schema graph from {@code graphSchema} and assigns cardinality {@code 1.0} to
     * every vertex type and edge type.
     *
     * @param graphSchema source of vertex types and edge relations
     */
    public GlogueSchema(GraphSchema graphSchema) {
        this.schemaGraph = new DirectedPseudograph<>(EdgeTypeId.class);
        this.vertexTypeCardinality = new HashMap<>();
        this.edgeTypeCardinality = new HashMap<>();
        for (GraphVertex vertex : graphSchema.getVertexList()) {
            this.vertexTypeCardinality.put(this.addVertexType(vertex), DEFAULT_CARDINALITY);
        }
        for (GraphEdge edge : graphSchema.getEdgeList()) {
            for (EdgeRelation relation : edge.getRelationList()) {
                this.edgeTypeCardinality.put(this.addEdgeType(edge, relation), DEFAULT_CARDINALITY);
            }
        }
        logger.info("GlogueSchema created with default cardinality 1.0: {}", this);
    }

    /**
     * Builds the schema graph from {@code graphSchema} and takes cardinalities from
     * {@code statistics}. A reported count of zero is stored as {@code 1.0}.
     *
     * @param graphSchema source of vertex types and edge relations
     * @param statistics source of per-type counts
     * @throws IllegalArgumentException if {@code statistics} returns {@code null} for a vertex
     *     type or for an edge relation present in {@code graphSchema}
     */
    public GlogueSchema(GraphSchema graphSchema, GraphStatistics statistics) {
        logger.info("Creating GlogueSchema with statistics, vertex count: {}, edge count: {}", statistics.getVertexCount(), statistics.getEdgeCount());
        this.schemaGraph = new DirectedPseudograph<>(EdgeTypeId.class);
        this.vertexTypeCardinality = new HashMap<>();
        this.edgeTypeCardinality = new HashMap<>();
        for (GraphVertex vertex : graphSchema.getVertexList()) {
            int labelId = this.addVertexType(vertex);
            Long vertexTypeCount = statistics.getVertexTypeCount(labelId);
            if (vertexTypeCount == null) {
                throw new IllegalArgumentException("Vertex type count not found for vertex type: " + labelId);
            }
            this.vertexTypeCardinality.put(labelId, toCardinality(vertexTypeCount));
        }
        for (GraphEdge edge : graphSchema.getEdgeList()) {
            for (EdgeRelation relation : edge.getRelationList()) {
                int sourceType = relation.getSource().getLabelId();
                int targetType = relation.getTarget().getLabelId();
                EdgeTypeId edgeType = this.addEdgeType(edge, relation);
                Long edgeTypeCount = statistics.getEdgeTypeCount(Optional.of(sourceType), Optional.of(edge.getLabelId()), Optional.of(targetType));
                if (edgeTypeCount == null) {
                    throw new IllegalArgumentException("Edge type count not found for edge type: " + edge.getLabelId());
                }
                this.edgeTypeCardinality.put(edgeType, toCardinality(edgeTypeCount));
            }
        }
        logger.info("GlogueSchema created with statistics: {}", this);
    }

    /**
     * Creates a schema from {@code irMeta}, using its statistics when present and default
     * cardinalities otherwise.
     *
     * @param irMeta meta object providing the schema and, optionally, statistics
     * @return a schema built with statistics if {@code irMeta.getStatistics()} is non-null,
     *     otherwise one built with cardinality {@code 1.0} for every type
     */
    public static GlogueSchema fromMeta(IrMetaStats irMeta) {
        if (irMeta.getStatistics() == null) {
            return new GlogueSchema(irMeta.getSchema());
        }
        return new GlogueSchema(irMeta.getSchema(), irMeta.getStatistics());
    }

    /**
     * Computes the delta cost of the label constraints on {@code edge} when it is extended
     * towards {@code target}.
     *
     * <p>The direction is obtained from {@code Utils.getExtendDirection(edge, target)}. The
     * {@code OUT} contribution is added unless the direction is {@code IN}, and the {@code IN}
     * contribution is added unless the direction is {@code OUT}; any other direction therefore
     * receives both.
     *
     * @param edge pattern edge whose edge type ids carry the label constraints
     * @param target pattern vertex passed to {@code Utils.getExtendDirection}
     * @return the summed delta cost
     */
    public Double getLabelConstraintsDeltaCost(PatternEdge edge, PatternVertex target) {
        PatternDirection direction = Utils.getExtendDirection(edge, target);
        double deltaCost = 0.0;
        if (direction != PatternDirection.IN) {
            deltaCost += this.getLabelConstraintsDeltaCost(edge, PatternDirection.OUT);
        }
        if (direction != PatternDirection.OUT) {
            deltaCost += this.getLabelConstraintsDeltaCost(edge, PatternDirection.IN);
        }
        return deltaCost;
    }

    /**
     * Computes the delta cost for a single direction.
     *
     * <p>The edge type ids of {@code edge} are grouped by (source label, edge label) for
     * {@code OUT} and by (edge label, target label) for {@code IN}. For each group, all schema
     * edge types known to the cardinality map that share the same pair are collected. If the
     * pattern edge does not already contain every one of them, the cardinalities of all
     * collected types are added to the result.
     */
    private double getLabelConstraintsDeltaCost(PatternEdge edge, PatternDirection direction) {
        double deltaCost = 0.0;
        HashSet<EdgeTypeId> visited = new HashSet<>();
        for (EdgeTypeId edgeTypeId : edge.getEdgeTypeIds()) {
            // The key only serves as a dedup tuple for the group; -1 fills the slot of the end
            // that is ignored for this direction.
            // NOTE(review): the argument order here differs from the (source, target, edge)
            // order used when building schema edge types; kept as-is since it only needs to be
            // distinct per group - confirm against the EdgeTypeId constructor.
            EdgeTypeId key = direction == PatternDirection.OUT
                    ? new EdgeTypeId(edgeTypeId.getSrcLabelId(), edgeTypeId.getEdgeLabelId(), -1)
                    : new EdgeTypeId(-1, edgeTypeId.getEdgeLabelId(), edgeTypeId.getDstLabelId());
            if (!visited.add(key)) {
                // This group has already been accounted for.
                continue;
            }
            List<EdgeTypeId> candidates = new ArrayList<>();
            for (EdgeTypeId schemaType : this.edgeTypeCardinality.keySet()) {
                if (inSameGroup(edgeTypeId, schemaType, direction)) {
                    candidates.add(schemaType);
                }
            }
            if (!edge.getEdgeTypeIds().containsAll(candidates)) {
                // Sum per group first, then add to the total (keeps the accumulation order).
                double deltaSum = 0.0;
                for (EdgeTypeId candidate : candidates) {
                    deltaSum += this.getEdgeTypeCardinality(candidate);
                }
                deltaCost += deltaSum;
            }
        }
        return deltaCost;
    }

    /**
     * Tells whether two edge types share the edge label and the fixed end for the given
     * direction: the source label for {@code OUT}, the target label for {@code IN}. Any other
     * direction never matches.
     */
    private static boolean inSameGroup(EdgeTypeId a, EdgeTypeId b, PatternDirection direction) {
        switch (direction) {
            case OUT:
                return sameLabel(a.getSrcLabelId(), b.getSrcLabelId()) && sameLabel(a.getEdgeLabelId(), b.getEdgeLabelId());
            case IN:
                return sameLabel(a.getDstLabelId(), b.getDstLabelId()) && sameLabel(a.getEdgeLabelId(), b.getEdgeLabelId());
            default:
                return false;
        }
    }

    /**
     * Compares two label ids as primitive ints, so the result never depends on boxed identity
     * should the label getters return {@code Integer}.
     */
    private static boolean sameLabel(int left, int right) {
        return left == right;
    }

    /** Adds the vertex's label id to the schema graph and returns that label id. */
    private int addVertexType(GraphVertex vertex) {
        int labelId = vertex.getLabelId();
        this.schemaGraph.addVertex(labelId);
        return labelId;
    }

    /** Adds the edge type described by {@code edge} and {@code relation} to the schema graph and returns it. */
    private EdgeTypeId addEdgeType(GraphEdge edge, EdgeRelation relation) {
        int sourceType = relation.getSource().getLabelId();
        int targetType = relation.getTarget().getLabelId();
        EdgeTypeId edgeType = new EdgeTypeId(sourceType, targetType, edge.getLabelId());
        this.schemaGraph.addEdge(sourceType, targetType, edgeType);
        return edgeType;
    }

    /** Converts a non-null count to a cardinality, mapping zero to the default of 1.0. */
    private static double toCardinality(Long count) {
        return count == 0L ? DEFAULT_CARDINALITY : count.doubleValue();
    }

    /**
     * @return an unmodifiable snapshot of all vertex label ids in the schema graph
     */
    public List<Integer> getVertexTypes() {
        return List.copyOf(this.schemaGraph.vertexSet());
    }

    /**
     * @return an unmodifiable snapshot of all edge types in the schema graph
     */
    public List<EdgeTypeId> getEdgeTypes() {
        return List.copyOf(this.schemaGraph.edgeSet());
    }

    /**
     * Returns the edge types touching the given vertex type, as reported by the schema graph's
     * {@code edgesOf}.
     *
     * @param source vertex label id
     * @return an unmodifiable snapshot of the edge types touching {@code source}
     */
    public List<EdgeTypeId> getAdjEdgeTypes(Integer source) {
        return List.copyOf(this.schemaGraph.edgesOf(source));
    }

    /**
     * Returns the edge types going from {@code source} to {@code target}.
     *
     * @param source source vertex label id
     * @param target target vertex label id
     * @return an unmodifiable snapshot of the matching edge types; empty if there are none or if
     *     either label id is not a vertex of the schema graph
     */
    public List<EdgeTypeId> getEdgeTypes(Integer source, Integer target) {
        // Graph.getAllEdges returns null when either vertex is missing, which List.copyOf
        // would reject with a NullPointerException; answer with an empty list instead.
        if (!this.schemaGraph.containsVertex(source) || !this.schemaGraph.containsVertex(target)) {
            return List.of();
        }
        return List.copyOf(this.schemaGraph.getAllEdges(source, target));
    }

    /**
     * @param vertexType vertex label id
     * @return the stored cardinality, or {@code 1.0} if none is stored for {@code vertexType}
     */
    public Double getVertexTypeCardinality(Integer vertexType) {
        Double cardinality = this.vertexTypeCardinality.get(vertexType);
        if (cardinality == null) {
            logger.debug("Vertex type {} not found in schema, assuming cardinality 1.0", vertexType);
            return DEFAULT_CARDINALITY;
        }
        return cardinality;
    }

    /**
     * @param edgeType edge type
     * @return the stored cardinality, or {@code 1.0} if none is stored for {@code edgeType}
     */
    public Double getEdgeTypeCardinality(EdgeTypeId edgeType) {
        Double cardinality = this.edgeTypeCardinality.get(edgeType);
        if (cardinality == null) {
            logger.debug("Edge type {} not found in schema, assuming cardinality 1.0", edgeType);
            return DEFAULT_CARDINALITY;
        }
        return cardinality;
    }

    /**
     * Lists every vertex type and edge type of the schema graph with its stored cardinality,
     * one per line ({@code null} is printed when no cardinality is stored).
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder("GlogueSchema:\n");
        text.append("VertexTypes:\n");
        for (Integer v : this.schemaGraph.vertexSet()) {
            text.append(v).append(" ").append(this.vertexTypeCardinality.get(v)).append("\n");
        }
        text.append("\nEdgeTypes:\n");
        for (EdgeTypeId e : this.schemaGraph.edgeSet()) {
            text.append(e.toString()).append(" ").append(this.edgeTypeCardinality.get(e)).append("\n");
        }
        return text.toString();
    }
}

