import os
import re
from abc import ABC, abstractmethod
from collections import deque, defaultdict
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List


def check_path(func):
    """
    ? A decorator for raising an OSError if a path argument does not exist.
    
    ! The path argument MUST be the first argument of the method (i.e. the one right after `self`).
    
    * The path may be passed positionally or by keyword.
    * The wrapped method keeps its own `__name__`, `__doc__`, and signature metadata.
    
    Args:
        func: the method to guard; called as `func(self, path, ...)`.
    
    Returns:
        A wrapper that raises `OSError` when the path does not exist and otherwise delegates to `func`.
    """
    # Imported locally so this decorator does not depend on extra module-level imports.
    import functools
    import inspect

    # Work out the name of the path parameter once, so a keyword-passed path can be found.
    try:
        parameter_names = list(inspect.signature(func).parameters)
    except (TypeError, ValueError):
        # Some callables cannot be introspected; fall back to positional-only handling.
        parameter_names = []
    path_parameter = parameter_names[1] if len(parameter_names) > 1 else None

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if len(args) > 1:
            path = args[1]
        elif path_parameter is not None and path_parameter in kwargs:
            path = kwargs[path_parameter]
        else:
            # No path was supplied at all - let the wrapped method raise its own, more
            # descriptive, error about the missing argument.
            return func(*args, **kwargs)
        if not os.path.exists(path):
            raise OSError(f'The provided path, "{path}", could not be found.')
        return func(*args, **kwargs)
    return wrapper


class InvalidDBScriptFormatError(Exception):
    """
    ! Signals that a database script's contents could not be processed with the regex patterns of its DB flavor.
    """


class NotMatchingDBFlavorError(Exception):
    """
    ! Signals that a `DBScript` of one DB flavor was appended to a `DBScripts` collection of a different DB flavor.
    """


class DBScriptAlreadyPresentError(Exception):
    """
    ! Signals that a `DBScript` instance was appended to a `DBScripts` collection that already contains it.
    
    * Raised by the `DBScriptsAppendErrorOnDuplicates` appending strategy.
    """


class CyclicalDependenciesError(Exception):
    """
    ! Signals that `DBScripts.calculate_dependencies()` found scripts that depend on each other in a cycle.
    """


class DBObjectTypes(Enum):
    """
    ? The kinds of database object a script can define; each member's value is a lowercase, human-readable label.
    """
    # Relations
    TABLE = "table"
    VIEW = "view"

    # Programmable objects
    TRIGGER = "trigger"
    STORED_PROCEDURE = "stored procedure"

    # Functions, split by what they return
    SCALAR_FUNCTION = "scalar function"
    TABLE_FUNCTION = "table function"


@dataclass
class DBScriptMetadata:
    """
    ? Describes the database object that a database script defines.
    
    * Fields are positional in this order: object name, object type, schema name.
    """
    obj_name: str  # the object's own name
    obj_type: DBObjectTypes  # the kind of object (table, view, ...)
    obj_schema: str  # the schema the object belongs to


class DBScript:
    """
    ? A single 'database script' - a file, usually generated by a DBMS, whose text defines one database object.
    
    * Attributes set on construction: `path`, `flavor`, `contents` (the full file text), `metadata` (as returned by the flavor), and `dependencies` (starts out empty).
    """
    @check_path
    def __init__(self, path: str, flavor: "IDBFlavor"):
        """
        ? Reads the script at `path` and asks `flavor` for its metadata.
        
        Args:
            path (str): the path to the database script.
            flavor (IDBFlavor): the flavor of the script - currently, only `DBFlavor_MSSQL()` is available.
        
        ! `check_path` raises an `OSError` if `path` does not exist.
        ! Whatever `flavor.get_dbscript_metadata` raises is propagated.
        """
        self.path = path
        self.flavor = flavor

        # The whole file is read up front; the flavor parses `self.contents` right after.
        with open(path, 'r') as script_file:
            text = script_file.read()
        self.contents = text

        self.metadata = flavor.get_dbscript_metadata(self)
        self.dependencies: List[DBScript] = []


class IDBScriptsAppender(ABC):
    """
    ? Interface for the strategies that decide how a `DBScript` is added to a `DBScripts` collection.
    """
    @abstractmethod
    def append(self, dbscripts: "DBScripts", dbscript: DBScript) -> None:
        """
        ? Adds `dbscript` to the `dbscripts` collection according to the strategy's rules.
        
        # ! Prefer this to calling dbscripts.scripts.append(dbscript) directly !
        
        Args:
            dbscripts (DBScripts): the collection being appended to.
            dbscript (DBScript): the script to add.
        """


class DBScriptsAppendRegular(IDBScriptsAppender):
    """
    ? The unconditional strategy: every script is added, so the same script may end up in the collection more than once.
    """
    def append(self, dbscripts: "DBScripts", dbscript: DBScript) -> None:
        """
        ? Adds `dbscript` to the list of scripts and (re)points its object name at it in the name mapping.
        """
        obj_name = dbscript.metadata.obj_name
        dbscripts.scripts.append(dbscript)
        # A later script with the same object name replaces the earlier mapping entry.
        dbscripts.obj_name_to_dbscript_mapping[obj_name] = dbscript


class DBScriptsAppendIgnoreDuplicates(IDBScriptsAppender):
    """
    ? The lenient strategy: a script that is already in the collection is silently skipped; anything else is added.
    """
    def append(self, dbscripts: "DBScripts", dbscript: DBScript) -> None:
        """
        ? Adds `dbscript` and records it in the name mapping, unless the list of scripts already contains it.
        """
        # Presence is decided by a plain `in` test against the list of scripts.
        if dbscript not in dbscripts.scripts:
            dbscripts.scripts.append(dbscript)
            obj_name = dbscript.metadata.obj_name
            dbscripts.obj_name_to_dbscript_mapping[obj_name] = dbscript


class DBScriptsAppendErrorOnDuplicates(IDBScriptsAppender):
    """
    ? The strict strategy: adding a script that is already in the collection is treated as an error.
    """
    def append(self, dbscripts: "DBScripts", dbscript: DBScript) -> None:
        """
        ? Adds `dbscript` and records it in the name mapping.
        
        ! Raises a `DBScriptAlreadyPresentError` if the list of scripts already contains `dbscript`; nothing is modified in that case.
        """
        existing_scripts = dbscripts.scripts
        # Presence is decided by a plain `in` test against the list of scripts.
        already_present = dbscript in existing_scripts
        if already_present:
            raise DBScriptAlreadyPresentError(f'The database script at "{dbscript.path}" is already present in the DBScripts collection.')
        existing_scripts.append(dbscript)
        dbscripts.obj_name_to_dbscript_mapping[dbscript.metadata.obj_name] = dbscript


class DBScripts:
    """
    ? A collection of `DBScript` instances - provides methods that, one way or another, act upon multiple database scripts.
    
    * `scripts` holds the instances in insertion order.
    * `obj_name_to_dbscript_mapping` maps an object name to its script - it is keyed by object name only, so two scripts
    * that define the same object name (e.g. in different schemas) share one entry and the last one appended wins.
    """
    def __init__(self, flavor: "IDBFlavor", appender: IDBScriptsAppender):
        """
        Args:
            flavor (IDBFlavor): the flavor of the scripts - currently, only `DBFlavor_MSSQL()` is available.
            appender (IDBScriptsAppender): the appending strategy to utilize. Choose from `DBScriptsAppendRegular()`, `DBScriptsAppendIgnoreDuplicates()`, or `DBScriptsAppendErrorOnDuplicates()`.
        """
        self.flavor = flavor
        self.appender = appender
        self.scripts: List[DBScript] = []
        self.obj_name_to_dbscript_mapping: Dict[str, DBScript] = {}
    
    def append(self, dbscript: DBScript) -> None:
        """
        ? Adds a `DBScript` instance to the collection through the configured appending strategy.
        
        ! Raises a `NotMatchingDBFlavorError` if the script's flavor is of a different type than the collection's flavor.
        ! Whatever the appending strategy raises (e.g. `DBScriptAlreadyPresentError`) is propagated.
        """
        # Flavors are compared by type so that separate instances of the same flavor class are interchangeable.
        if type(dbscript.flavor) is not type(self.flavor):
            raise NotMatchingDBFlavorError(
                f'The database script at "{dbscript.path}" has the DB flavor {type(dbscript.flavor).__name__}, '
                f'but this DBScripts collection uses the DB flavor {type(self.flavor).__name__}.'
            )
        self.appender.append(self, dbscript)
        # The collection may have changed, so a previously calculated execution order can no longer be trusted.
        self._invalidate_execution_order()
    
    def _invalidate_execution_order(self) -> None:
        """
        ? Forgets the cached execution order (if any) so that it is recalculated on the next request.
        """
        self.__dict__.pop('_safe_execution_order', None)
    
    @check_path
    def populate_from_dir(self, dir: str) -> None:
        """
        ? Populates the collection of `DBScript` instances with an instance for every .sql file in a directory, including its subdirectories.
        
        ! You must ensure that all `.sql` files in the directory are 'database scripts' as formatted with the starting format of the `IDBFlavor` implementation used.
        ! `check_path` raises an `OSError` if `dir` does not exist.
        """
        # os.walk descends into subdirectories, so every file must be joined to the directory it was
        # actually found in (`dirpath`), not to the top-level directory.
        for dirpath, _, filenames in os.walk(dir):
            for filename in filenames:
                # Surrounding whitespace is ignored only for the extension test; the path itself is
                # built from the real file name so that it points at the file on disk.
                if filename.strip().endswith('.sql'):
                    self.append(DBScript(os.path.join(dirpath, filename), self.flavor))
    
    def calculate_dependencies(self) -> None:
        """
        ? Populates the dependencies attribute for all the `DBScript` instances in the collection and derives a safe execution order via Kahn's topological sort.
        
        * A script depends on another script when `flavor.is_valid_ref` reports a reference to that script's object name.
        * On success the resulting order is cached for `safe_execution_order`.
        
        ! Raises a `CyclicalDependenciesError` if cyclical dependencies are detected following the dependency calculation.
        ! In that case no execution order is cached, so a later `safe_execution_order` call recalculates (and raises again) instead of returning a partial order.
        """
        graph = defaultdict(list)
        in_degree = defaultdict(int)
        
        # Drop the old result first: if anything below fails, no stale or partial order is left behind.
        self._invalidate_execution_order()
        
        for script in self.scripts:
            script_name = script.metadata.obj_name
            in_degree[script_name] = 0
            # Built from scratch on every call so that recalculating never accumulates duplicates.
            dependencies: List[DBScript] = []
            for obj_name, referenced_script in self.obj_name_to_dbscript_mapping.items():
                if obj_name != script_name and self.flavor.is_valid_ref(obj_name, script):
                    # Edge: referenced object -> the script that needs it.
                    graph[obj_name].append(script_name)
                    in_degree[script_name] += 1
                    dependencies.append(referenced_script)
            script.dependencies = dependencies
        
        # Kahn's algorithm: repeatedly emit scripts that have no unmet dependencies left.
        execution_order: List[DBScript] = []
        queue = deque(script for script in self.scripts if in_degree[script.metadata.obj_name] == 0)
        while queue:
            current = queue.popleft()
            execution_order.append(current)
            for dependent_obj_name in graph[current.metadata.obj_name]:
                in_degree[dependent_obj_name] -= 1
                if in_degree[dependent_obj_name] == 0:
                    queue.append(self.obj_name_to_dbscript_mapping[dependent_obj_name])
        
        # Scripts that were never emitted are stuck behind each other, i.e. they form a cycle.
        if len(execution_order) != len(self.scripts):
            raise CyclicalDependenciesError('Cyclical dependencies were detected when calculating dependencies within the DBScript collection.')
        
        self._safe_execution_order = execution_order

    def safe_execution_order(self, recalculate_dependencies: bool) -> List[DBScript]:
        """
        ? Returns the list of `DBScript` instances in the collection in an order safe to execute without missing dependencies.
        
        Args:
            recalculate_dependencies (bool): when True, dependencies are always recalculated; when False, a cached order is reused if one is available.
        
        Returns:
            A new list on every call, so callers may modify it without affecting the cached order.
        
        ! Note that, if some dependencies were not included in the `DBScripts` instance to begin with, these will not be accounted for! 
        ! Raises a `CyclicalDependenciesError` if a (re)calculation detects cyclical dependencies.
        """
        if recalculate_dependencies or not hasattr(self, '_safe_execution_order'):
            self.calculate_dependencies()
        return list(self._safe_execution_order)


class IDBFlavor(ABC):
    """
    ? Interface for a 'DB flavor' - the dialect-specific logic used to read database scripts and detect references between them.
    """
    @abstractmethod
    def get_dbscript_metadata(self, dbscript: DBScript) -> DBScriptMetadata:
        """
        ? Builds the `DBScriptMetadata` describing the object that the given `DBScript` defines.
        
        ! Implementations raise an `InvalidDBScriptFormatError` if the script does not match the correct starting format for the `IDBFlavor` implementation in use.
        """
    
    @staticmethod
    @abstractmethod
    def cleaned_contents(contents: str) -> str:
        """
        ? Strips comments and string literals from the text of a database script and returns what is left.
        """
    
    @abstractmethod
    def is_valid_ref(self, obj_name: str, dbscript: DBScript) -> bool:
        """
        ? Tells a proper reference to the database object `obj_name` inside `dbscript` apart from a false positive.
        
        Returns:
            True for a proper reference, False otherwise.
        """


class DBFlavor_MSSQL(IDBFlavor):
    """
    ? Provides regex patterns and methods that correspond to Microsoft SQL Server syntax and SQL.
    
    ! `DBScript` objects MUST contain a statement of the following format in order to be read:
    * "<CREATE|ALTER|CREATE OR ALTER> <TABLE|VIEW|TRIGGER|FUNCTION|PROCEDURE> [SCHEMA].[OBJ_NAME]"
    * When several such statements are present, the one that appears FIRST in the script decides what the script defines.
    
    ! Failure to ensure this will raise an `InvalidDBScriptFormatError` on `DBScript.__init__` when this flavor is passed.
    
    * In every entry of `patterns`, group 2 is the schema and group 3 is the object name.
    """
    # Shared regex fragments. `_HEADER` keeps "OR ALTER" as capture group 1 so the schema and
    # object name stay in groups 2 and 3.
    _HEADER = r'(?:CREATE\s+(OR\s+ALTER\s+)?|ALTER\s+)'
    _QUALIFIED_NAME = r'\s+\[([a-zA-Z0-9_]+)\]\.\[([a-zA-Z0-9_]+)\]'
    _FUNCTION_SIGNATURE = r'\s*\([\w\s,@]*\)\s*RETURNS\s+'
    
    # NOTE: the order matters - when two patterns match at the same position the earlier entry wins,
    # which is what lets TABLE_FUNCTION take precedence over the broader SCALAR_FUNCTION pattern.
    patterns = {
        DBObjectTypes.TABLE: re.compile(_HEADER + r'TABLE' + _QUALIFIED_NAME, re.IGNORECASE),
        DBObjectTypes.VIEW: re.compile(_HEADER + r'VIEW' + _QUALIFIED_NAME, re.IGNORECASE),
        DBObjectTypes.TRIGGER: re.compile(_HEADER + r'TRIGGER' + _QUALIFIED_NAME, re.IGNORECASE),
        # Covers both inline ("RETURNS TABLE") and multi-statement ("RETURNS @var TABLE") table functions.
        DBObjectTypes.TABLE_FUNCTION: re.compile(
            _HEADER + r'FUNCTION' + _QUALIFIED_NAME + _FUNCTION_SIGNATURE + r'(?:@\w+\s+)?TABLE', re.IGNORECASE
        ),
        DBObjectTypes.SCALAR_FUNCTION: re.compile(
            _HEADER + r'FUNCTION' + _QUALIFIED_NAME + _FUNCTION_SIGNATURE + r'[a-zA-Z]', re.IGNORECASE
        ),
        DBObjectTypes.STORED_PROCEDURE: re.compile(_HEADER + r'PROCEDURE' + _QUALIFIED_NAME, re.IGNORECASE)
    }
    # Keywords that, when directly followed by an object name, make the occurrence a real reference.
    valid_context_keywords = (
        'JOIN', 'FROM', 'INTO', 'UPDATE', 'DELETE', 'INSERT', 'EXEC', 'CALL',
        'EXECUTE', 'REFERENCES', 'APPLY', 'ON'
    )
    
    # Comments and quoted text, matched in ONE left-to-right pass so that whichever construct starts
    # first wins (a "--" inside a string literal is not a comment, a quote inside a comment is not a string).
    _REMOVABLE = re.compile(
        r"--[^\n]*"       # line comment, up to the end of the line
        r"|/\*.*?\*/"     # block comment, possibly spanning lines
        r"|'[^']*'"       # single-quoted text
        r'|"[^"]*"',      # double-quoted text
        re.DOTALL
    )
    # An optional "schema." / "[schema]." style prefix in front of an object name.
    _QUALIFIER = r'(?:(?:\[[a-zA-Z0-9_]+\]|[a-zA-Z0-9_]+)\.)'
    
    def get_dbscript_metadata(self, dbscript: "DBScript") -> DBScriptMetadata:
        """
        ? Returns a `DBScriptMetadata` object for the FIRST object-defining statement found in the script.
        
        * Choosing by position (rather than by object type) keeps e.g. a stored procedure that creates a table
        * in its body from being mistaken for that table.
        
        ! Raises an `InvalidDBScriptFormatError` if no pattern in `patterns` matches the script's contents.
        """
        first_type = None
        first_match = None
        for obj_type, pattern in self.patterns.items():
            match = pattern.search(dbscript.contents)
            # Strict "<" keeps the earlier `patterns` entry when two patterns match at the same position.
            if match is not None and (first_match is None or match.start() < first_match.start()):
                first_type, first_match = obj_type, match
        if first_match is None:
            raise InvalidDBScriptFormatError(f'Invalid contents format for the DB script at "{dbscript.path}".')
        return DBScriptMetadata(first_match.group(3), first_type, first_match.group(2))

    @staticmethod
    def cleaned_contents(contents: str) -> str:
        """
        ? Returns the contents of a database script with comments and string literals removed.
        
        * Removes `--` line comments, `/* */` block comments, and single- and double-quoted text.
        """
        return DBFlavor_MSSQL._REMOVABLE.sub('', contents)
    
    def is_valid_ref(self, obj_name: str, dbscript: DBScript) -> bool:
        """
        ? Returns whether or not the script contains a proper reference to the database object `obj_name`.
        
        * Comments and string literals are ignored. An occurrence counts as a proper reference when either:
        *   1. it directly follows one of `valid_context_keywords` (optionally qualified, e.g. `FROM [dbo].[Orders]`), or
        *   2. it is a qualified name that is directly followed by `(` (e.g. the function call `dbo.fnTotal(...)`).
        * The name must match as a whole: `Order` is NOT found in `Orders`, and a name that is itself only a
        * qualifier (directly followed by `.`) does not count.
        
        ! This is a heuristic - references in other positions (e.g. `EXEC @rc = dbo.proc`) are not detected.
        """
        processed_script = self.cleaned_contents(dbscript.contents)
        escaped_obj_name = re.escape(obj_name)
        
        # Rule 2: a qualified name followed by an opening parenthesis.
        alternatives = [rf'{self._QUALIFIER}(?:\[{escaped_obj_name}\]|{escaped_obj_name})\s*\(']
        
        # Rule 1: keyword, whitespace, optional qualifiers, then the whole name. The keyword alternation
        # is wrapped in its own group so the context requirement applies to every form of the name.
        keywords = '|'.join(re.escape(keyword) for keyword in self.valid_context_keywords)
        if keywords:
            alternatives.append(
                rf'\b(?:{keywords})\s+{self._QUALIFIER}*'
                rf'(?:\[{escaped_obj_name}\](?!\.)|{escaped_obj_name}(?![a-zA-Z0-9_.]))'
            )
        
        pattern = re.compile('|'.join(alternatives), re.IGNORECASE)
        return pattern.search(processed_script) is not None
