import inspect
import sys
import re
import traceback
from functools import wraps
from velocity.db import exceptions
from velocity.db.core.transaction import Transaction


class Engine:
    """
    Encapsulates driver config, connection logic, error handling, and transaction decoration.

    Constructor arguments:
        driver: object whose ``connect`` method opens a connection.
        config: connection arguments for ``driver.connect`` -- a dict (passed as
            keyword arguments), a tuple/list (positional arguments) or a str
            (single argument).
        sql: SQL helper supplying the ``server`` name, statement builders
            (``version()``, ``tables()``, ...), ``get_error`` and the
            error-code collections consulted by ``process_error``.
    """

    # Maximum number of re-runs exec_function allows for each retryable error
    # kind (DbRetryTransaction, DbLockTimeoutError) before re-raising.
    MAX_RETRIES = 100

    # Ordered translation rules for process_error; the first match wins.
    # Entries are (kind, key, exception class name in velocity.db.exceptions):
    #   "code"  -> key names an error-code collection on ``self.sql`` that the
    #              driver's error code is tested against with ``in``;
    #   "regex" -> key is a pattern searched (re.M) in the lower-cased message.
    # Exception classes are referenced by name and resolved at call time.
    _ERROR_RULES = (
        ("code", "ApplicationErrorCodes", "DbApplicationError"),
        ("code", "ColumnMissingErrorCodes", "DbColumnMissingError"),
        ("code", "TableMissingErrorCodes", "DbTableMissingError"),
        ("code", "DatabaseMissingErrorCodes", "DbDatabaseMissingError"),
        ("code", "ForeignKeyMissingErrorCodes", "DbForeignKeyMissingError"),
        ("code", "TruncationErrorCodes", "DbTruncationError"),
        ("code", "DataIntegrityErrorCodes", "DbDataIntegrityError"),
        ("code", "ConnectionErrorCodes", "DbConnectionError"),
        ("code", "DuplicateKeyErrorCodes", "DbDuplicateKeyError"),
        ("regex", r"key \(sys_id\)=\(\d+\) already exists.", "DbDuplicateKeyError"),
        ("code", "DatabaseObjectExistsErrorCodes", "DbObjectExistsError"),
        ("code", "LockTimeoutErrorCodes", "DbLockTimeoutError"),
        ("code", "RetryTransactionCodes", "DbRetryTransaction"),
        ("regex", r"database.*does not exist", "DbDatabaseMissingError"),
        ("regex", r"no such database", "DbDatabaseMissingError"),
        ("regex", r"already exists", "DbObjectExistsError"),
        ("regex", r"server closed the connection unexpectedly", "DbConnectionError"),
        ("regex", r"no connection to the server", "DbConnectionError"),
        ("regex", r"connection timed out", "DbConnectionError"),
        ("regex", r"could not connect to server", "DbConnectionError"),
        ("regex", r"cannot connect to server", "DbConnectionError"),
        ("regex", r"connection already closed", "DbConnectionError"),
        ("regex", r"cursor already closed", "DbConnectionError"),
        ("regex", r"no such table:", "DbTableMissingError"),
    )

    def __init__(self, driver, config, sql):
        self.__config = config
        self.__sql = sql
        self.__driver = driver

    def __str__(self):
        return f"[{self.sql.server}] engine({self.config})"

    def connect(self):
        """
        Open and return a new driver connection.

        If connecting raises DbDatabaseMissingError, the database is created
        with ``create_database()`` and the connection is attempted once more.
        For the "SQLite3" server, the connection's ``isolation_level`` is set to
        ``None`` (presumably so transactions are controlled explicitly rather
        than implicitly by the driver -- confirm).
        """
        try:
            conn = self.__connect()
        except exceptions.DbDatabaseMissingError:
            self.create_database()
            conn = self.__connect()
        if self.sql.server == "SQLite3":
            conn.isolation_level = None
        return conn

    def __connect(self):
        """
        Call ``driver.connect`` with ``config`` and return the connection.

        Raises:
            TypeError: if ``config`` is not a dict, tuple/list or str.
            Exceptions from the driver are passed to ``process_error``, which
            re-raises them (translated where a rule matches).
        """
        config = self.config
        # Validate outside the try so a bad config is reported directly instead
        # of being fed to process_error as though it were a driver failure.
        if not isinstance(config, (dict, tuple, list, str)):
            raise TypeError("Unhandled configuration parameter.")
        try:
            if isinstance(config, dict):
                return self.driver.connect(**config)
            if isinstance(config, (tuple, list)):
                return self.driver.connect(*config)
            return self.driver.connect(config)
        except Exception:
            # KeyboardInterrupt/SystemExit are deliberately not intercepted.
            self.process_error()

    def transaction(self, func_or_cls=None):
        """
        Provide a Transaction directly, or decorate callables/classes with one.

        * ``engine.transaction()`` returns a new ``Transaction``.
        * Decorating a function/method that declares a ``tx`` parameter: a
          ``Transaction`` passed as ``tx`` (by keyword, or at tx's positional
          slot) is reused; otherwise a new ``Transaction`` is opened and
          injected as ``tx``. Either way the call runs through
          ``exec_function`` (retry handling). Functions without ``tx`` are
          called unchanged.
        * Decorating a ``classmethod``/``staticmethod`` wraps the underlying
          function and re-applies the same descriptor.
        * Decorating a class returns a same-named subclass whose non-dunder
          functions, classmethods, staticmethods and nested classes are wrapped.
        * Any other argument returns a new ``Transaction``.

        Raises:
            NameError: if a decorated function declares a ``_tx`` parameter.
            TypeError: at call time, if ``tx`` is passed by keyword but is not a
                ``Transaction``.
        """
        if func_or_cls is None:
            return Transaction(self)

        if isinstance(func_or_cls, classmethod):
            return classmethod(self.transaction(func_or_cls.__func__))

        if isinstance(func_or_cls, staticmethod):
            return staticmethod(self.transaction(func_or_cls.__func__))

        if inspect.isfunction(func_or_cls) or inspect.ismethod(func_or_cls):
            return self._wrap_function(func_or_cls)

        if inspect.isclass(func_or_cls):
            return self._wrap_class(func_or_cls)

        return Transaction(self)

    def _wrap_function(self, func):
        """Wrap a function/method so that its ``tx`` parameter is filled in automatically."""
        params = inspect.signature(func).parameters
        names = list(params)
        if "_tx" in names:
            # exec_function receives the transaction as ``_tx``; a parameter of
            # the same name would collide with it when kwargs are forwarded.
            raise NameError(
                f"In function {func.__name__}, '_tx' is not allowed as a parameter."
            )

        tx_param = params.get("tx")
        if tx_param is None:
            tx_pos = None
            tx_keyword_only = False
        else:
            tx_pos = names.index("tx")
            tx_keyword_only = tx_param.kind is inspect.Parameter.KEYWORD_ONLY

        @wraps(func)
        def new_function(*args, **kwds):
            if tx_param is None:
                # No ``tx`` parameter: nothing to inject, call straight through.
                return func(*args, **kwds)

            tx = None
            if "tx" in kwds:
                if not isinstance(kwds["tx"], Transaction):
                    raise TypeError(
                        f"In function {func.__name__}, keyword argument `tx` must be a Transaction object."
                    )
                tx = kwds["tx"]
            elif (
                not tx_keyword_only
                and len(args) > tx_pos
                and isinstance(args[tx_pos], Transaction)
            ):
                tx = args[tx_pos]

            if tx is not None:
                return self.exec_function(func, tx, *args, **kwds)

            with Transaction(self) as local_tx:
                if tx_keyword_only or len(args) < tx_pos:
                    # Inserting positionally would bind the transaction to the
                    # wrong parameter (or into *args), so pass it by name.
                    return self.exec_function(
                        func, local_tx, *args, tx=local_tx, **kwds
                    )
                # The caller omitted ``tx``: slot it in at its position and
                # shift the remaining positional arguments to the right.
                new_args = args[:tx_pos] + (local_tx,) + args[tx_pos:]
                return self.exec_function(func, local_tx, *new_args, **kwds)

        return new_function

    def _wrap_class(self, cls):
        """Return a same-named subclass of ``cls`` with its transactional members wrapped."""
        new_cls = type(
            cls.__name__,
            (cls,),
            {
                # Keep the identity of the decorated class instead of
                # inheriting this module's name / losing the docstring.
                "__module__": cls.__module__,
                "__qualname__": cls.__qualname__,
                "__doc__": cls.__doc__,
            },
        )

        for attr_name in dir(cls):
            if attr_name.startswith("__") and attr_name.endswith("__"):
                continue
            # getattr_static yields the raw class attribute (e.g. the
            # staticmethod/classmethod object rather than its bound result), so
            # the wrapper is re-installed with the original binding behavior.
            attr = inspect.getattr_static(cls, attr_name)
            if (
                inspect.isfunction(attr)
                or inspect.isclass(attr)
                or isinstance(attr, (classmethod, staticmethod))
            ):
                setattr(new_cls, attr_name, self.transaction(attr))
            # Other callables (builtins, partials, callable instances) are left
            # as inherited rather than replaced.

        return new_cls

    def exec_function(self, function, _tx, *args, **kwds):
        """
        Call ``function(*args, **kwds)`` on behalf of transaction ``_tx``.

        Only the outermost call for a given ``_tx`` retries: on
        DbRetryTransaction or DbLockTimeoutError it rolls ``_tx`` back and calls
        ``function`` again, up to MAX_RETRIES times per error kind, then
        re-raises. Nested calls (tracked with ``_tx._exec_function_depth``)
        just call ``function``.
        """
        depth = getattr(_tx, "_exec_function_depth", 0)
        _tx._exec_function_depth = depth + 1

        try:
            if depth > 0:
                # Inside an outer exec_function on the same transaction; the
                # outermost call owns the retry loop.
                return function(*args, **kwds)

            retry_count = 0
            lock_timeout_count = 0
            while True:
                try:
                    return function(*args, **kwds)
                except exceptions.DbRetryTransaction:
                    retry_count += 1
                    if retry_count > self.MAX_RETRIES:
                        raise
                    _tx.rollback()
                except exceptions.DbLockTimeoutError:
                    lock_timeout_count += 1
                    if lock_timeout_count > self.MAX_RETRIES:
                        raise
                    _tx.rollback()
        finally:
            _tx._exec_function_depth = depth

    @property
    def driver(self):
        return self.__driver

    @property
    def config(self):
        return self.__config

    @property
    def sql(self):
        return self.__sql

    def _fetch_scalar(self, build_statement):
        """Execute the (sql, vals) from ``build_statement()`` in a new Transaction and return ``.scalar()``."""
        with Transaction(self) as tx:
            sql, vals = build_statement()
            return tx.execute(sql, vals).scalar()

    def _fetch_rows(self, build_statement, convert):
        """Execute the (sql, vals) from ``build_statement()`` in a new Transaction and map ``convert`` over ``.as_tuple()`` rows."""
        with Transaction(self) as tx:
            sql, vals = build_statement()
            return [convert(row) for row in tx.execute(sql, vals).as_tuple()]

    @property
    def version(self):
        """Return the DB server version (scalar result of ``self.sql.version()``)."""
        return self._fetch_scalar(self.sql.version)

    @property
    def timestamp(self):
        """Return the current timestamp from the DB server (``self.sql.timestamp()``)."""
        return self._fetch_scalar(self.sql.timestamp)

    @property
    def user(self):
        """Return the current user as known by the DB server (``self.sql.user()``)."""
        return self._fetch_scalar(self.sql.user)

    @property
    def databases(self):
        """Return a list of available databases (first column of ``self.sql.databases()``)."""
        return self._fetch_rows(self.sql.databases, lambda row: row[0])

    @property
    def current_database(self):
        """Return the name of the current database (``self.sql.current_database()``)."""
        return self._fetch_scalar(self.sql.current_database)

    def create_database(self, name=None):
        """
        Run the statement from ``self.sql.create_database(name)`` in a new Transaction.

        When ``name`` is None, the database named in the config is created. To
        do so, the config's "database"/"dbname" entries are temporarily pointed
        at "postgres" so the connection does not target the missing database.
        They are restored afterwards, even if creation fails. Whether an
        already-existing database is an error depends on the SQL produced by
        ``self.sql.create_database``.

        Raises:
            KeyError: if ``name`` is None and the config names no database.

        Returns:
            self, for chaining.
        """
        saved = None
        if name is None:
            conf = self.config
            # Like switch_to_database, accept either key for the database name.
            keys = [key for key in ("database", "dbname") if key in conf] or ["database"]
            saved = {key: conf[key] for key in keys}
            name = saved[keys[0]]
            self.set_config({key: "postgres" for key in keys})
        try:
            with Transaction(self) as tx:
                sql, vals = self.sql.create_database(name)
                tx.execute(sql, vals, single=True)
        finally:
            if saved is not None:
                self.set_config(saved)
        return self

    def switch_to_database(self, database):
        """
        Point the config's "database" and/or "dbname" entries at ``database``.

        Only keys already present are changed; the config dict is mutated in
        place. Existing connections are not closed; the new name is used by
        subsequent ``connect()`` calls.

        Returns:
            self, for chaining.
        """
        conf = self.config
        for key in ("database", "dbname"):
            if key in conf:
                conf[key] = database
        return self

    def set_config(self, config):
        """Merge ``config`` into the internal config dict (in place, via ``update``)."""
        self.config.update(config)

    @property
    def schemas(self):
        """Return a list of schemas in the current database (first column of ``self.sql.schemas()``)."""
        return self._fetch_rows(self.sql.schemas, lambda row: row[0])

    @property
    def current_schema(self):
        """Return the current schema in use (``self.sql.current_schema()``)."""
        return self._fetch_scalar(self.sql.current_schema)

    @property
    def tables(self):
        """Return a list of 'schema.table' for all tables in the current DB."""
        return self._fetch_rows(self.sql.tables, lambda row: f"{row[0]}.{row[1]}")

    @property
    def views(self):
        """Return a list of 'schema.view' for all views in the current DB."""
        return self._fetch_rows(self.sql.views, lambda row: f"{row[0]}.{row[1]}")

    def process_error(self, sql_stmt=None, sql_params=None):
        """
        Translate the exception currently being handled into a velocity exception.

        Must be called from inside an ``except`` block. Exceptions that are
        already DbException instances are re-raised as-is. Otherwise the
        driver error code (from ``self.sql.get_error``) and the lower-cased
        message are checked against ``_ERROR_RULES`` in order, and the first
        match is raised (chained to the original). If nothing matches, a
        diagnostic report including ``sql_stmt``/``sql_params`` is printed and
        the original exception is re-raised.
        """
        e = sys.exc_info()[1]
        msg = str(e).strip().lower()

        if isinstance(e, exceptions.DbException):
            raise

        error_code, error_mesg = self.sql.get_error(e)

        for kind, key, exc_name in self._ERROR_RULES:
            if kind == "code":
                # Looked up lazily so only the collections actually reached
                # need to exist on self.sql.
                matched = error_code in getattr(self.sql, key)
            else:
                matched = re.search(key, msg, re.M) is not None
            if matched:
                raise getattr(exceptions, exc_name)(e) from e

        report = f"""
Unhandled/Unknown Error in engine.process_error
EXC_TYPE = {type(e)}
EXC_MSG = {str(e).strip()}

ERROR_CODE = {error_code}
ERROR_MSG = {error_mesg}

SQL_STMT = {sql_stmt}
SQL_PARAMS = {sql_params}

{traceback.format_exc()}
"""
        print(report)
        raise
