"""Kùzu graph store index."""
from typing import Any, Dict, List, Optional

from llama_index.graph_stores.types import GraphStore


class KuzuGraphStore(GraphStore):
    """Triplet store backed by a Kùzu database.

    Entities are rows of a single node table (one ``ID`` string column that
    is also the primary key) and every triplet is an edge of a single
    relationship table carrying a ``predicate`` string property.
    """

    def __init__(
        self,
        database: Any,
        node_table_name: str = "entity",
        rel_table_name: str = "links",
        **kwargs: Any,
    ) -> None:
        """Open a connection on ``database`` and make sure the schema exists.

        Args:
            database: Object handed to ``kuzu.Connection``.
            node_table_name: Name of the node table holding the entities.
            rel_table_name: Name of the relationship table holding the edges.
            **kwargs: Accepted for call-site compatibility; not used.

        Raises:
            ImportError: If the ``kuzu`` package is not installed.
        """
        try:
            import kuzu
        except ImportError as err:
            raise ImportError("Please install kuzu: pip install kuzu") from err
        self.database = database
        self.connection = kuzu.Connection(database)
        self.node_table_name = node_table_name
        self.rel_table_name = rel_table_name
        self.init_schema()

    def init_schema(self) -> None:
        """Initialize schema if the tables do not exist."""
        node_tables = self.connection._get_node_table_names()
        if self.node_table_name not in node_tables:
            self.connection.execute(
                "CREATE NODE TABLE %s (ID STRING, PRIMARY KEY(ID))"
                % self.node_table_name
            )
        # The relationship listing yields mappings; only the "name" entry of
        # each is needed here.
        rel_tables = [
            rel_table["name"]
            for rel_table in self.connection._get_rel_table_names()
        ]
        if self.rel_table_name not in rel_tables:
            self.connection.execute(
                "CREATE REL TABLE {} (FROM {} TO {}, predicate STRING)".format(
                    self.rel_table_name, self.node_table_name, self.node_table_name
                )
            )

    @property
    def client(self) -> Any:
        """Return the underlying Kùzu connection."""
        return self.connection

    def get(self, subj: str) -> List[List[str]]:
        """Get triplets.

        Args:
            subj: ID of the subject entity.

        Returns:
            One ``[predicate, object_id]`` pair per outgoing edge of ``subj``.
        """
        query = """
            MATCH (n1:%s)-[r:%s]->(n2:%s)
            WHERE n1.ID = $subj
            RETURN r.predicate, n2.ID;
        """
        prepared_statement = self.connection.prepare(
            query % (self.node_table_name, self.rel_table_name, self.node_table_name)
        )
        # Parameters are passed as a dict, like every other query in this class.
        query_result = self.connection.execute(prepared_statement, {"subj": subj})
        retval = []
        while query_result.has_next():
            row = query_result.get_next()
            retval.append([row[0], row[1]])
        return retval

    def get_rel_map(
        self, subjs: Optional[List[str]] = None, depth: int = 2, limit: int = 30
    ) -> Dict[str, List[List[str]]]:
        """Get depth-aware rel map.

        Args:
            subjs: Subject IDs to start from. ``None`` applies no subject
                filter; an empty list yields an empty map.
            depth: Maximum number of hops followed from a subject.
            limit: Maximum number of paths returned by the query.

        Returns:
            Mapping of subject ID to the paths found from it, each path being
            a flat list ``[node_id, predicate, node_id, ..., object_id]``.
        """
        if subjs is not None and not subjs:
            # Nothing to match; also avoids building a WHERE clause with no terms.
            return {}

        rel_wildcard = "r:%s*1..%d" % (self.rel_table_name, depth)
        match_clause = "MATCH (n1:{})-[{}]->(n2:{})".format(
            self.node_table_name,
            rel_wildcard,
            self.node_table_name,
        )
        return_clause = "RETURN n1, r, n2 LIMIT %d" % limit

        if subjs is None:
            where_clause = ""
            params: Optional[Dict[str, str]] = None
        else:
            # One positional-style parameter ($0, $1, ...) per subject.
            where_clause = "WHERE " + " OR ".join(
                "n1.ID = $%d" % i for i, _ in enumerate(subjs)
            )
            params = {str(i): curr_subj for i, curr_subj in enumerate(subjs)}

        query = f"{match_clause} {where_clause} {return_clause}"
        prepared_statement = self.connection.prepare(query)
        if params is None:
            query_result = self.connection.execute(prepared_statement)
        else:
            query_result = self.connection.execute(prepared_statement, params)

        retval: Dict[str, List[List[str]]] = {}
        while query_result.has_next():
            row = query_result.get_next()
            subj_id = row[0]["ID"]
            retval.setdefault(subj_id, []).append(self._path_from_row(row))
        return retval

    @staticmethod
    def _path_from_row(row: Any) -> List[str]:
        """Flatten one ``(n1, r, n2)`` result row into an ID/predicate path."""
        subj, recursive_rel, obj = row[0], row[1], row[2]

        def node_key(internal_id: Any) -> Any:
            # Internal IDs are (table, offset) pairs; used to resolve the
            # endpoints of each relationship back to entity IDs.
            return (internal_id["table"], internal_id["offset"])

        nodes_map = {
            node_key(subj["_id"]): subj["ID"],
            node_key(obj["_id"]): obj["ID"],
        }
        for node in recursive_rel["_nodes"]:
            nodes_map[node_key(node["_id"])] = node["ID"]

        curr_path = []
        for rel in recursive_rel["_rels"]:
            curr_path.append(nodes_map[node_key(rel["_src"])])
            curr_path.append(rel["predicate"])
        # Add the last node
        curr_path.append(obj["ID"])
        return curr_path

    def _entity_exists(self, entity: str) -> bool:
        """Return whether a node with ID ``entity`` is present."""
        result = self.connection.execute(
            "MATCH (n:%s) WHERE n.ID = $entity RETURN n.ID" % self.node_table_name,
            {"entity": entity},
        )
        return result.has_next()

    def _create_entity(self, entity: str) -> None:
        """Insert a node with ID ``entity``."""
        self.connection.execute(
            "CREATE (n:%s {ID: $entity})" % self.node_table_name,
            {"entity": entity},
        )

    def _rel_exists(self, subj: str, obj: str, rel: str) -> bool:
        """Return whether the edge ``subj -[rel]-> obj`` is present."""
        result = self.connection.execute(
            (
                "MATCH (n1:{})-[r:{}]->(n2:{}) WHERE n1.ID = $subj AND n2.ID = "
                "$obj AND r.predicate = $pred RETURN r.predicate"
            ).format(self.node_table_name, self.rel_table_name, self.node_table_name),
            {"subj": subj, "obj": obj, "pred": rel},
        )
        return result.has_next()

    def _create_rel(self, subj: str, obj: str, rel: str) -> None:
        """Insert the edge ``subj -[rel]-> obj`` between existing nodes."""
        self.connection.execute(
            (
                "MATCH (n1:{}), (n2:{}) WHERE n1.ID = $subj AND n2.ID = $obj "
                "CREATE (n1)-[r:{} {{predicate: $pred}}]->(n2)"
            ).format(self.node_table_name, self.node_table_name, self.rel_table_name),
            {"subj": subj, "obj": obj, "pred": rel},
        )

    def _delete_rel(self, subj: str, obj: str, rel: str) -> None:
        """Delete the edge ``subj -[rel]-> obj``."""
        self.connection.execute(
            (
                "MATCH (n1:{})-[r:{}]->(n2:{}) WHERE n1.ID = $subj AND n2.ID"
                " = $obj AND r.predicate = $pred DELETE r"
            ).format(self.node_table_name, self.rel_table_name, self.node_table_name),
            {"subj": subj, "obj": obj, "pred": rel},
        )

    def _delete_entity(self, entity: str) -> None:
        """Delete the node with ID ``entity``."""
        self.connection.execute(
            "MATCH (n:%s) WHERE n.ID = $entity DELETE n" % self.node_table_name,
            {"entity": entity},
        )

    def _has_edges(self, entity: str) -> bool:
        """Return whether ``entity`` still has an edge in either direction."""
        result = self.connection.execute(
            "MATCH (n1:{})-[r:{}]-(n2:{}) WHERE n2.ID = $entity RETURN r.predicate".format(
                self.node_table_name, self.rel_table_name, self.node_table_name
            ),
            {"entity": entity},
        )
        return result.has_next()

    def upsert_triplet(self, subj: str, rel: str, obj: str) -> None:
        """Add triplet.

        Missing entities are created first; the edge is only inserted when an
        identical ``subj -[rel]-> obj`` edge is not already stored.

        Args:
            subj: ID of the subject entity.
            rel: Predicate stored on the edge.
            obj: ID of the object entity.
        """
        subj_existed = self._entity_exists(subj)
        if not subj_existed:
            self._create_entity(subj)

        # Checked only after the subject has been created: for a self-loop
        # (subj == obj) on a new entity the node must not be inserted twice,
        # which would violate the primary key.
        obj_existed = self._entity_exists(obj)
        if not obj_existed:
            self._create_entity(obj)

        # An edge can only pre-exist when both endpoints already did.
        if subj_existed and obj_existed and self._rel_exists(subj, obj, rel):
            return

        self._create_rel(subj, obj, rel)

    def delete(self, subj: str, rel: str, obj: str) -> None:
        """Delete triplet.

        Removes the edge, then removes each endpoint that is left without
        any edge.

        Args:
            subj: ID of the subject entity.
            rel: Predicate of the edge to remove.
            obj: ID of the object entity.
        """
        self._delete_rel(subj, obj, rel)
        # dict.fromkeys de-duplicates while keeping order, so a self-loop
        # endpoint is only processed once.
        for entity in dict.fromkeys((subj, obj)):
            if not self._has_edges(entity):
                self._delete_entity(entity)

    @classmethod
    def from_persist_dir(
        cls,
        persist_dir: str,
        node_table_name: str = "entity",
        rel_table_name: str = "links",
    ) -> "KuzuGraphStore":
        """Load from persist dir.

        Args:
            persist_dir: Path handed to ``kuzu.Database``.
            node_table_name: Name of the node table holding the entities.
            rel_table_name: Name of the relationship table holding the edges.

        Returns:
            Graph store connected to the database at ``persist_dir``.

        Raises:
            ImportError: If the ``kuzu`` package is not installed.
        """
        try:
            import kuzu
        except ImportError as err:
            raise ImportError("Please install kuzu: pip install kuzu") from err
        database = kuzu.Database(persist_dir)
        return cls(database, node_table_name, rel_table_name)

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any]) -> "KuzuGraphStore":
        """Initialize graph store from configuration dictionary.

        Args:
            config_dict: Configuration dictionary.

        Returns:
            Graph store.
        """
        return cls(**config_dict)
