import sqlite3
from collections import namedtuple
import datetime
import itertools
import json
import pathlib

# One row of ``PRAGMA table_info``, field-for-field in SQLite's column order.
Column = namedtuple("Column", "cid name type notnull default_value is_pk")
# A foreign key: ``table.column`` references ``other_table.other_column``.
ForeignKey = namedtuple("ForeignKey", "table column other_table other_column")
# One row of ``PRAGMA index_list`` plus the list of indexed column names.
Index = namedtuple("Index", "seq name unique origin partial columns")


class Database:
    """Wrapper around a ``sqlite3`` connection.

    ``db[name]`` returns a ``Table`` for ``name`` whether or not that table
    exists yet.
    """

    def __init__(self, filename_or_conn):
        """Open a database file, or wrap an existing connection.

        :param filename_or_conn: a ``str`` or ``pathlib.Path`` naming the file
            to open with ``sqlite3.connect``; any other object (normally an
            open ``sqlite3.Connection``) is stored as ``self.conn`` unchanged.
        """
        if isinstance(filename_or_conn, (str, pathlib.Path)):
            self.conn = sqlite3.connect(str(filename_or_conn))
        else:
            self.conn = filename_or_conn

    def __getitem__(self, table_name):
        return Table(self, table_name)

    def __repr__(self):
        return "<Database {}>".format(self.conn)

    def table_names(self, fts4=False, fts5=False):
        """Return the names of the tables in ``sqlite_master``.

        :param fts4: only include tables whose SQL mentions ``FTS4``.
        :param fts5: only include tables whose SQL mentions ``FTS5``.
            When both flags are set the two conditions are ANDed.
        """
        where = ["type = 'table'"]
        if fts4:
            where.append("sql like '%FTS4%'")
        if fts5:
            where.append("sql like '%FTS5%'")
        sql = "select name from sqlite_master where {}".format(" AND ".join(where))
        return [r[0] for r in self.conn.execute(sql).fetchall()]

    @property
    def tables(self):
        """``Table`` objects for every name returned by ``table_names()``."""
        return [self[name] for name in self.table_names()]

    def execute_returning_dicts(self, sql, params=None):
        """Execute ``sql`` and return every row as a ``{column: value}`` dict."""
        cursor = self.conn.execute(sql, params or tuple())
        keys = [d[0] for d in cursor.description]
        return [dict(zip(keys, row)) for row in cursor.fetchall()]

    def create_table(
        self, name, columns, pk=None, foreign_keys=None, column_order=None
    ):
        """Create table ``name`` and return ``self[name]``.

        :param columns: dict mapping column name to a Python type; the type
            is translated to a SQLite column type via ``col_type_mapping``
            (unknown types raise ``KeyError``).
        :param pk: name of the single column to declare ``PRIMARY KEY``.
        :param foreign_keys: sequence of tuples whose items 0, 2 and 3 are
            the column name, referenced table and referenced column.
        :param column_order: column names to place first, in this order;
            columns not listed keep their original relative order after them.
        """
        foreign_keys_by_name = {fk[0]: fk for fk in (foreign_keys or [])}
        column_items = list(columns.items())
        if column_order is not None:
            # Rank by first position in column_order. Unlisted columns rank
            # after every listed one, whatever the length of column_order
            # (a fixed sentinel such as 999 breaks for longer orderings).
            rank = {}
            for position, col_name in enumerate(column_order):
                rank.setdefault(col_name, position)
            unlisted = len(column_order)
            column_items.sort(key=lambda item: rank.get(item[0], unlisted))
        col_type_mapping = {
            float: "FLOAT",
            int: "INTEGER",
            bool: "INTEGER",
            str: "TEXT",
            # NOTE: bytes.__class__ is ``type`` itself; the entry is kept only
            # so existing callers see no change in accepted keys.
            bytes.__class__: "BLOB",
            bytes: "BLOB",
            datetime.datetime: "TEXT",
            datetime.date: "TEXT",
            datetime.time: "TEXT",
            None.__class__: "TEXT",
        }
        columns_sql = ",\n".join(
            "   [{col_name}] {col_type} {primary_key} {references}".format(
                col_name=col_name,
                col_type=col_type_mapping[col_type],
                primary_key=" PRIMARY KEY" if (pk == col_name) else "",
                references=(
                    " REFERENCES [{other_table}]([{other_column}])".format(
                        other_table=foreign_keys_by_name[col_name][2],
                        other_column=foreign_keys_by_name[col_name][3],
                    )
                    if col_name in foreign_keys_by_name
                    else ""
                ),
            )
            for col_name, col_type in column_items
        )
        sql = """CREATE TABLE [{table}] (
            {columns_sql}
        );
        """.format(
            table=name, columns_sql=columns_sql
        )
        self.conn.execute(sql)
        return self[name]

    def create_view(self, name, sql):
        """Create a view called ``name`` defined by the SELECT in ``sql``."""
        # Quote the name like create_table does, so names with spaces work.
        self.conn.execute(
            """
            CREATE VIEW [{name}] AS {sql}
        """.format(
                name=name, sql=sql
            )
        )

    def vacuum(self):
        """Run ``VACUUM`` on the database."""
        self.conn.execute("VACUUM;")


class Table:
    """One table of a ``Database``; the table need not exist yet.

    Methods build SQL against ``self.name`` and run it on ``self.db.conn``.
    The table is created lazily by the first ``insert``/``insert_all``.
    """

    def __init__(self, db, name):
        self.db = db
        self.name = name
        # Snapshot at construction; create() and drop() keep it up to date.
        self.exists = self.name in self.db.table_names()

    def __repr__(self):
        return "<Table {}{}>".format(
            self.name, " (does not exist yet)" if not self.exists else ""
        )

    @property
    def count(self):
        """Number of rows in the table."""
        return self.db.conn.execute(
            "select count(*) from [{}]".format(self.name)
        ).fetchone()[0]

    @property
    def columns(self):
        """``Column`` tuples from ``PRAGMA table_info``; ``[]`` if not created."""
        if not self.exists:
            return []
        rows = self.db.conn.execute(
            "PRAGMA table_info([{}])".format(self.name)
        ).fetchall()
        return [Column(*row) for row in rows]

    @property
    def pks(self):
        """Names of the columns flagged as primary key."""
        return [column.name for column in self.columns if column.is_pk]

    @property
    def foreign_keys(self):
        """``ForeignKey`` tuples built from ``PRAGMA foreign_key_list``."""
        fks = []
        for row in self.db.conn.execute(
            "PRAGMA foreign_key_list([{}])".format(self.name)
        ).fetchall():
            # Row layout: id, seq, table, from, to, on_update, on_delete, match
            _, _, other_table, from_, to_, _, _, _ = row
            fks.append(
                ForeignKey(
                    table=self.name,
                    column=from_,
                    other_table=other_table,
                    other_column=to_,
                )
            )
        return fks

    @property
    def schema(self):
        """The ``CREATE`` SQL stored in ``sqlite_master`` for this name."""
        return self.db.conn.execute(
            "select sql from sqlite_master where name = ?", (self.name,)
        ).fetchone()[0]

    @property
    def indexes(self):
        """``Index`` tuples for every index on the table, with column names."""
        sql = 'PRAGMA index_list("{}")'.format(self.name)
        indexes = []
        for row in self.db.execute_returning_dicts(sql):
            index_name = row["name"]
            index_name_quoted = (
                '"{}"'.format(index_name)
                if not index_name.startswith('"')
                else index_name
            )
            column_sql = "PRAGMA index_info({})".format(index_name_quoted)
            columns = []
            for seqno, cid, name in self.db.conn.execute(column_sql).fetchall():
                columns.append(name)
            row["columns"] = columns
            # These columns may be missing on older SQLite versions:
            for key, default in {"origin": "c", "partial": 0}.items():
                if key not in row:
                    row[key] = default
            indexes.append(Index(**row))
        return indexes

    def create(self, columns, pk=None, foreign_keys=None, column_order=None):
        """Create the table from a ``{name: python_type}`` dict; return self."""
        columns = {name: value for (name, value) in columns.items()}
        with self.db.conn:
            self.db.create_table(
                self.name,
                columns,
                pk=pk,
                foreign_keys=foreign_keys,
                column_order=column_order,
            )
        self.exists = True
        return self

    def create_index(self, columns, index_name=None):
        """Create an index over ``columns``; return self.

        The default name is ``idx_<table>_<col1>_<col2>...`` with spaces in
        the table name replaced by underscores. Table and index names are
        quoted so names containing spaces work. Column entries are inserted
        verbatim (not quoted), so expressions such as ``"age DESC"`` still work.
        """
        if index_name is None:
            index_name = "idx_{}_{}".format(
                self.name.replace(" ", "_"), "_".join(columns)
            )
        sql = """
            CREATE INDEX [{index_name}]
                ON [{table_name}] ({columns});
        """.format(
            index_name=index_name, table_name=self.name, columns=", ".join(columns)
        )
        self.db.conn.execute(sql)
        return self

    def drop(self):
        """Drop the table and return the cursor from ``DROP TABLE``."""
        cursor = self.db.conn.execute("DROP TABLE [{}]".format(self.name))
        # Without this, a later insert would skip creation and hit
        # "no such table".
        self.exists = False
        return cursor

    def add_foreign_key(self, column, column_type, other_table, other_column):
        """Add ``column`` referencing ``other_table(other_column)``; commit."""
        sql = """
            ALTER TABLE [{table}] ADD COLUMN [{column}] {column_type}
            REFERENCES [{other_table}]([{other_column}]);
        """.format(
            table=self.name,
            column=column,
            column_type=column_type,
            other_table=other_table,
            other_column=other_column,
        )
        self.db.conn.execute(sql)
        self.db.conn.commit()
        return self

    def enable_fts(self, columns, fts_version="FTS5"):
        "Enables FTS on the specified columns"
        # Creates "<table>_fts" with content="<table>", then fills it.
        sql = """
            CREATE VIRTUAL TABLE "{table}_fts" USING {fts_version} (
                {columns},
                content="{table}"
            );
        """.format(
            table=self.name, columns=", ".join(columns), fts_version=fts_version
        )
        self.db.conn.executescript(sql)
        self.populate_fts(columns)
        return self

    def populate_fts(self, columns):
        """Copy rowid and ``columns`` from this table into ``<table>_fts``."""
        sql = """
            INSERT INTO "{table}_fts" (rowid, {columns})
                SELECT rowid, {columns} FROM [{table}];
        """.format(
            table=self.name, columns=", ".join(columns)
        )
        self.db.conn.executescript(sql)
        return self

    def detect_fts(self):
        "Detect if table has a corresponding FTS virtual table and return it"
        # Matches either a virtual FTS table declared with content="<table>"
        # or one whose tbl_name is this table. Values are bound as
        # parameters so quotes in the table name cannot break the SQL.
        rows = self.db.conn.execute(
            """
            SELECT name FROM sqlite_master
                WHERE rootpage = 0
                AND (
                    sql LIKE ?
                    OR (
                        tbl_name = ?
                        AND sql LIKE '%VIRTUAL TABLE%USING FTS%'
                    )
                )
            """,
            (
                '%VIRTUAL TABLE%USING FTS%content="{}"%'.format(self.name),
                self.name,
            ),
        ).fetchall()
        if not rows:
            return None
        return rows[0][0]

    def optimize(self):
        """Run the FTS 'optimize' command if an FTS table is detected."""
        fts_table = self.detect_fts()
        if fts_table is not None:
            # Single quotes: a double-quoted "optimize" is an identifier and
            # only works through SQLite's legacy string-literal fallback.
            self.db.conn.execute(
                """
                INSERT INTO [{table}] ([{table}]) VALUES ('optimize');
            """.format(
                    table=fts_table
                )
            )
        return self

    def detect_column_types(self, records):
        """Map each key seen in ``records`` to one Python type for its column.

        ``None`` values are ignored when a column also holds other values, so
        ``[1, None]`` yields ``int`` rather than ``str``. A column that only
        ever holds ``None`` yields ``type(None)``.
        """
        all_column_types = {}
        for record in records:
            for key, value in record.items():
                all_column_types.setdefault(key, set()).add(type(value))
        column_types = {}
        for key, types in all_column_types.items():
            if len(types) > 1:
                types = types - {type(None)}
            if len(types) == 1:
                t = next(iter(types))
                # list / tuple / dict values are stored as JSON text
                if t in (list, tuple, dict):
                    t = str
            elif {int, bool}.issuperset(types):
                t = int
            elif {int, float, bool}.issuperset(types):
                t = float
            elif {bytes, str}.issuperset(types):
                t = bytes
            else:
                t = str
            column_types[key] = t
        return column_types

    def search(self, q):
        """Return full rows whose ``<table>_fts`` entry matches ``q``."""
        sql = """
            select * from [{table}] where rowid in (
                select rowid from [{table}_fts]
                where [{table}_fts] match :search
            )
            order by rowid
        """.format(
            table=self.name
        )
        # Named placeholder requires a mapping (positional sequences for
        # named placeholders are deprecated since Python 3.12).
        return self.db.conn.execute(sql, {"search": q}).fetchall()

    def insert(
        self, record, pk=None, foreign_keys=None, upsert=False, column_order=None
    ):
        """Insert a single record dict; see ``insert_all``."""
        return self.insert_all(
            [record],
            pk=pk,
            foreign_keys=foreign_keys,
            upsert=upsert,
            column_order=column_order,
        )

    def insert_all(
        self,
        records,
        pk=None,
        foreign_keys=None,
        upsert=False,
        batch_size=100,
        column_order=None,
    ):
        """
        Like .insert() but takes a list of records.

        Records are written in batches of ``batch_size`` rows per INSERT.
        If the table does not exist it is created from the column types of
        the first batch. The column list of the INSERT is the sorted union
        of keys in the FIRST batch only; keys that appear only in later
        batches are not inserted, and missing keys are written as NULL.
        list / tuple / dict values are stored as JSON. ``upsert=True`` uses
        ``INSERT OR REPLACE``. ``self.last_id`` is set to the cursor's
        ``lastrowid`` after each batch.
        """
        all_columns = None
        first = True
        for chunk in chunks(records, batch_size):
            chunk = list(chunk)
            if first:
                if not self.exists:
                    # Use the first batch to derive the column types
                    self.create(
                        self.detect_column_types(chunk),
                        pk,
                        foreign_keys,
                        column_order=column_order,
                    )
                all_columns = set()
                for record in chunk:
                    all_columns.update(record.keys())
                all_columns = sorted(all_columns)
            first = False
            sql = """
                INSERT {upsert} INTO [{table}] ({columns}) VALUES {rows};
            """.format(
                upsert="OR REPLACE" if upsert else "",
                table=self.name,
                columns=", ".join("[{}]".format(c) for c in all_columns),
                rows=", ".join(
                    """
                    ({placeholders})
                """.format(
                        placeholders=", ".join(["?"] * len(all_columns))
                    )
                    for record in chunk
                ),
            )
            values = []
            for record in chunk:
                values.extend(
                    jsonify_if_needed(record.get(key, None)) for key in all_columns
                )
            with self.db.conn:
                result = self.db.conn.execute(sql, values)
                self.last_id = result.lastrowid
        return self

    def upsert(self, record, pk=None, foreign_keys=None, column_order=None):
        """``insert`` with ``upsert=True`` (INSERT OR REPLACE)."""
        return self.insert(
            record,
            pk=pk,
            foreign_keys=foreign_keys,
            upsert=True,
            column_order=column_order,
        )

    def upsert_all(self, records, pk=None, foreign_keys=None, column_order=None):
        """``insert_all`` with ``upsert=True`` (INSERT OR REPLACE)."""
        return self.insert_all(
            records,
            pk=pk,
            foreign_keys=foreign_keys,
            upsert=True,
            column_order=column_order,
        )


def chunks(sequence, size):
    """Lazily split ``sequence`` into consecutive groups of up to ``size`` items.

    Each yielded group is an iterator sharing the same underlying source, so
    it must be consumed before the next group is requested.
    """
    source = iter(sequence)
    while True:
        try:
            head = next(source)
        except StopIteration:
            return
        yield itertools.chain((head,), itertools.islice(source, size - 1))


def jsonify_if_needed(value):
    """Serialize dicts, lists and tuples to a JSON string; return others as-is."""
    needs_json = isinstance(value, (dict, list, tuple))
    return json.dumps(value) if needs_json else value
