from __future__ import annotations

import os

from typing import Any, Dict, Union, TYPE_CHECKING


import numpy as np
import pandas as pd
import sqlalchemy as alch
from sqlalchemy.orm import backref, relationship
from sqlalchemy.ext.declarative import declarative_base
import pyodbc

from subtypes import Frame, AutoEnum
from pathmagic import File
from miscutils import lazy_property
from iotools.serializer import LostObject

from .custom import Model, ModelMeta, AutoModel, Query, Session, StringLiteral, BitLiteral
from .expression import Select, Update, Insert, Delete, SelectInto
from .utils import StoredProcedure, Script
from .log import SqlLog
from .database import Database, Schemas
from .config import Config

if TYPE_CHECKING:
    import alembic

assert pyodbc


class Sql:
    """
    Provides access to the complete sqlalchemy API, with custom functionality added for logging and pandas integration. Handles authentication through config settings.
    The custom expression classes provided have additional useful methods and are modified by the 'autocommit' attribute to facilitate human-supervised queries.
    The custom query class provided by this object's 'session' attribute also has additional methods. Many commonly used sqlalchemy objects are bound to this object as attributes for easy access.
    The 'Sql.orm' and 'Sql.objects' attributes provide access via attribute or item access to the reflected database models and underlying table objects, respectively.
    """

    class IfExists(AutoEnum):
        FAIL, REPLACE, APPEND  # noqa

    def __init__(self, connection: str = None, database: str = None, log: File = None, autocommit: bool = False) -> None:
        """
        Build the engine, session and Database wrapper, then bind commonly used sqlalchemy names to this object.

        :param connection: Passed to Config.generate_url to build the connection URL.
        :param database: Passed to Config.generate_url to build the connection URL.
        :param log: Path used to create this object's SqlLog (see the 'log' property setter).
        :param autocommit: Stored on the object; read by the custom expression classes.
        """
        self.config = Config()

        self.engine = self._create_engine(connection=connection, database=database)
        # Back-reference so code holding only the engine can reach this Sql object.
        self.engine.sql = self

        self.session = Session(bind=self.engine, query_cls=Query)
        self.database = Database(self)

        self.log, self.autocommit = log, autocommit

        self.Select, self.SelectInto, self.Update = Select, SelectInto, Update
        self.Insert, self.Delete = Insert, Delete
        self.StoredProcedure, self.Script = StoredProcedure.from_sql(self), Script.from_sql(self)

        self.text, self.literal = alch.text, alch.literal
        self.AND, self.OR, self.CAST, self.CASE, self.TRUE, self.FALSE = alch.and_, alch.or_, alch.cast, alch.case, alch.true(), alch.false()

        self.Table, self.Column, self.Relationship, self.Backref, self.ForeignKey, self.Index, self.CheckConstraint = alch.Table, alch.Column, relationship, backref, alch.ForeignKey, alch.Index, alch.CheckConstraint
        self.type, self.func, self.sqlalchemy = alch.types, alch.func, alch

    def __repr__(self) -> str:
        return f"{type(self).__name__}(engine={repr(self.engine)}, database={repr(self.database)})"

    def __len__(self) -> int:
        """Return the number of tables currently held in the database metadata."""
        return len(self.database.meta.tables)

    def __enter__(self) -> Sql:
        """Roll back any pending session state and return this object for use in a 'with' block."""
        self.session.rollback()
        return self

    def __exit__(self, ex_type: Any, ex_value: Any, ex_traceback: Any) -> None:
        """Commit the session if the 'with' block exited cleanly, otherwise roll it back. Exceptions are not suppressed."""
        if ex_type is None:
            self.session.commit()
        else:
            self.session.rollback()

    def __getstate__(self) -> dict:
        """Return the picklable state; the engine and database are stored as LostObject placeholders rather than themselves."""
        return {"engine": LostObject(self.engine), "database": LostObject(self.database), "autocommit": self.autocommit, "_log": self.log}

    def __setstate__(self, attrs: dict) -> None:
        """Restore exactly the attributes produced by __getstate__ (the engine and database remain LostObject placeholders)."""
        self.__dict__ = attrs

    @property
    def Model(self) -> Model:
        """Custom base class for declarative and automap bases to inherit from. Represents a mapped table in a sql database."""
        return self.database.declaration

    @lazy_property
    def AutoModel(self) -> AutoModel:
        """Declarative base bound to this engine and database metadata, using AutoModel as base class and ModelMeta as metaclass."""
        return declarative_base(bind=self.engine, metadata=self.database.meta, cls=AutoModel, metaclass=ModelMeta, name="AutoModel", class_registry=self.database._registry)

    @property
    def orm(self) -> Schemas:
        """Property controlling access to mapped models. Models will only appear for tables that have a primary key, and never for views. Schemas must be accessed before tables: E.g. Sql().orm.some_schema.some_table"""
        return self.database.orm

    @property
    def objects(self) -> Schemas:
        """Property controlling access to raw database objects. Schemas must be accessed before tables: E.g. Sql().objects.some_schema.some_table"""
        return self.database.objects

    @lazy_property
    def operations(self) -> alembic.operations.Operations:
        """Property controlling access to alembic operations, bound to a connection opened from this object's engine."""
        from alembic.migration import MigrationContext
        from alembic.operations import Operations

        return Operations(MigrationContext.configure(self.engine.connect()))

    @property
    def log(self) -> SqlLog:
        """Property controlling access to a special SqlLog class. When setting, a simple file path should be used and an appropriate SqlLog will be created."""
        return self._log

    @log.setter
    def log(self, val: File) -> None:
        self._log = SqlLog(path=val, active=False)

    def initialize_log(self, logname: str, logdir: str = None) -> SqlLog:
        """Create an inactive SqlLog from a log name and an optional directory, bind it to this object's 'log' attribute and return it."""
        self._log = SqlLog.from_details(log_name=logname, log_dir=logdir, active=False)
        return self._log

    # Conversion Methods

    def query_to_frame(self, query: Query, labels: bool = False) -> Frame:
        """Convert a sqlalchemy.orm.Query object to a Frame, optionally applying table labels to the column names."""
        query = query.with_labels() if labels else query

        result = self.session.execute(query.statement)
        cols = [col[0] for col in result.cursor.description]
        return Frame(result.fetchall(), columns=cols)

    def plaintext_query_to_frame(self, query: str) -> Frame:
        """Convert plaintext SQL to a pandas DataFrame. The SQL statement must be a SELECT that returns rows."""
        return Frame(pd.read_sql_query(query, self.engine))

    def table_to_frame(self, table: str, schema: str = None) -> Frame:
        """Reads the target table or view (from the specified schema) into a pandas DataFrame."""
        return Frame(pd.read_sql_table(table, self.engine, schema=schema))

    def excel_to_table(self, filepath: os.PathLike, table: str = "temp", schema: str = None, if_exists: str = "fail", primary_key: str = "id", **kwargs: Any) -> Model:
        """Bulk insert the contents of the target '.xlsx' file to the specified table. Extra kwargs are passed to Frame.from_excel."""
        return self.frame_to_table(dataframe=Frame.from_excel(filepath, **kwargs), table=table, schema=schema, if_exists=if_exists, primary_key=primary_key)

    def frame_to_table(self, dataframe: pd.DataFrame, table: str, schema: str = None, if_exists: str = "fail", primary_key: str = "id") -> Model:
        """
        Bulk insert the contents of a pandas DataFrame to the specified table and return its refreshed ORM model.

        Primary key handling:
            - None: the frame's index is moved into the first column, which becomes the primary key.
            - a column name present in the frame: that column becomes the primary key.
            - any other name: a 1-based integer column with that name is generated and typed INT.
        The caller's DataFrame is never modified.
        """
        # Copy first: wrapping without copying can share the caller's underlying data, and the in-place index
        # operations below would otherwise leak into the caller's DataFrame.
        dataframe = Frame(dataframe.copy())

        has_identity_pk = False
        if primary_key is None:
            dataframe.reset_index(inplace=True)
            primary_key = dataframe.iloc[:, 0].name
        else:
            if primary_key in dataframe.columns:
                dataframe.set_index(primary_key, inplace=True)
            else:
                has_identity_pk = True
                dataframe.reset_index(inplace=True, drop=True)
                dataframe.index.names = [primary_key]
                dataframe.index += 1

            dataframe.reset_index(inplace=True)

        dtypes = self._sql_dtype_dict_from_frame(dataframe)
        if has_identity_pk:
            # Force a plain INT for the generated key instead of the narrowest range-based type.
            dtypes[primary_key] = alch.types.INT

        dataframe.to_sql(engine=self.engine, name=table, if_exists=if_exists, index=False, index_label=None, primary_key=primary_key, schema=schema, dtype=dtypes)

        table_object = self.orm[schema][table]
        self.refresh_table(table=table_object)
        return self.orm[schema][table]

    @staticmethod
    def orm_to_frame(orm_objects: Any) -> Frame:
        """
        Convert a homogeneous list of sqlalchemy.orm instance objects (or a single one) to a pandas DataFrame.

        Columns are taken from the mapped class's __table__ and values are read with getattr using each column's name.

        :raises TypeError: If the objects are not all of exactly the same type.
        """
        if not isinstance(orm_objects, list):
            orm_objects = [orm_objects]

        first_type = type(orm_objects[0])
        if any(type(item) is not first_type for item in orm_objects):
            raise TypeError("All sqlalchemy.orm mapped objects passed into this function must have the same type.")

        cols = [col.name for col in first_type.__table__.columns]
        vals = [[getattr(item, col) for col in cols] for item in orm_objects]

        return Frame(vals, columns=cols)

    def create_table(self, table: Union[Model, alch.schema.Table]) -> None:
        """Create a table or the table belonging to an ORM class (delegates to Database.create_table)."""
        self.database.create_table(table)

    def drop_table(self, table: Union[Model, alch.schema.Table]) -> None:
        """Drop a table or the table belonging to an ORM class and remove it from the metadata."""
        self.database.drop_table(table)

    def refresh_table(self, table: Union[Model, alch.schema.Table]) -> None:
        """Refresh the schema of the table passed by reflecting the database definition again."""
        self.database.refresh_table(table=table)

    def clear_metadata(self) -> None:
        """Clear the metadata held by this Sql object's database."""
        self.database.clear()

    # Private internal methods

    def _create_engine(self, connection: str, database: str) -> alch.engine.base.Engine:
        """Create an engine for the URL produced by the config, using a LiteralDialect derived from the URL's dialect."""
        url = self.config.generate_url(connection=connection, database=database)
        return alch.create_engine(str(url), echo=False, dialect=self._create_literal_dialect(url.get_dialect()))

    def _create_literal_dialect(self, dialect_class: alch.engine.default.DefaultDialect) -> alch.engine.default.DefaultDialect:
        """
        Return an instance of a subclass of 'dialect_class' whose string/date/null types map to StringLiteral
        (and, for the mssql and sqlite dialects, BIT maps to BitLiteral), with multi-values INSERT enabled.
        """
        from sqlalchemy.dialects.mssql import dialect as mssql
        from sqlalchemy.dialects.sqlite import dialect as sqlite

        class LiteralDialect(dialect_class):
            supports_multivalues_insert = True

            def __init__(self, *args: Any, **kwargs: Any) -> None:
                super().__init__(*args, **kwargs)
                # 'colspecs' is normally a class-level dict shared by every instance of the parent dialect. Build a new
                # dict and bind it to this instance so the literal overrides don't leak into unrelated engines.
                colspecs = {
                    **self.colspecs,
                    alch.sql.sqltypes.String: StringLiteral,
                    alch.sql.sqltypes.DateTime: StringLiteral,
                    alch.sql.sqltypes.DATETIME: StringLiteral,
                    alch.sql.sqltypes.Date: StringLiteral,
                    alch.sql.sqltypes.DATE: StringLiteral,
                    alch.sql.sqltypes.NullType: StringLiteral,
                }

                if dialect_class in (mssql, sqlite):
                    colspecs[alch.dialects.mssql.BIT] = BitLiteral

                self.colspecs = colspecs

        return LiteralDialect()

    @staticmethod
    def _sql_dtype_dict_from_frame(frame: Frame) -> Dict[str, Any]:
        """
        Map each integer or object column of 'frame' (after infer_objects) to a sqlalchemy type.

        Integer columns ('int64' or nullable 'Int64') get the narrowest of TINYINT / SmallInteger / Integer / BigInteger
        that fits their non-null values (Integer if every value is null). Object columns get a String whose length is the
        longest value's length rounded up to the next multiple of 50. Columns of any other dtype are omitted.
        """
        # Imported explicitly so this works even if the mssql dialect module hasn't been loaded yet.
        from sqlalchemy.dialects.mssql import TINYINT

        def sqlalchemy_dtype_from_series(series: pd.Series) -> Any:
            if series.dtype.name in ("int64", "Int64"):
                # dropna() handles NaN and pd.NA alike; np.isnan(pd.NA) cannot be truth-tested and used to raise here.
                nums = series.dropna()
                if nums.empty:
                    return alch.types.Integer

                minimum, maximum = nums.min(), nums.max()

                # Signed ranges: the upper limit is one less than the magnitude of the lower limit.
                if 0 <= minimum and maximum <= 255:
                    return TINYINT
                if -2**15 <= minimum and maximum < 2**15:
                    return alch.types.SmallInteger
                if -2**31 <= minimum and maximum < 2**31:
                    return alch.types.Integer
                return alch.types.BigInteger

            if series.dtype.name == "object":
                longest = series.fillna("").astype(str).str.len().max()
                # An empty column has no maximum (NaN); treat it as length 0 so it still gets the minimum width.
                longest = 0 if pd.isna(longest) else int(longest)
                return alch.types.String((longest // 50 + 1) * 50)

            raise TypeError(f"Don't know how to process column type '{series.dtype}' of '{series.name}'.")

        return {name: sqlalchemy_dtype_from_series(col) for name, col in frame.infer_objects().items() if col.dtype.name in ("int64", "Int64", "object")}
