from abc import ABC, abstractmethod
from typing import Union, Iterable

import pyarrow as pa

from .CDCTransforms import CHANGE_TYPE
from .Dataflow import logger
from .Metadata import Dataset, OperationalMetadata, create_join_condition, Step, Table
from .SQLUtils import convert_list_to_str, quote_str


class Loader(Step, ABC):
    """Abstract dataflow step representing a target table; concrete loaders implement the hooks."""

    def __init__(self, source: Dataset, table_name: str, name: Union[None, str] = None,
                 pk_list: Union[None, Iterable[str]] = None, allow_evolution: bool = False,
                 is_cdc: bool = False, generated_key_column: Union[None, str] = None,
                 start_value: Union[None, int] = None):
        """
        A Loader terminates a dataflow: it is the target table the data ends up in.

        Intended load behaviour, depending on source and target:
        - CDC source, non-CDC target: the recorded changes are applied to the target.
        - CDC source, CDC target: the rows are appended.
        - Non-CDC source, target with a primary key: an upsert is performed.
        - Non-CDC source, no primary key information: the rows can only be appended.
        A surrogate key can additionally be filled when a generated key column is named.

        :param source: The input of this step
        :param table_name: the table name of this target table
        :param name: the step name; defaults to "Target table <table_name>" when omitted
        :param pk_list: the target table's (assumed) primary key
        :param allow_evolution: does the table support schema evolution?
        :param is_cdc: is this table one with CDC information
        :param generated_key_column: optional name of the generated key column, which is filled for all insert records
        :param start_value: If provided, this is the start value for the generated key column, else the max() is read from the table
        """
        step_name = f"Target table {table_name}" if name is None else name
        super().__init__(step_name)
        self.table_name = table_name
        self.pk_list = pk_list
        self.allow_evolution = allow_evolution
        # Defaults used by the show-related methods: all columns, no filter.
        self.show_projection = "*"
        self.where_clause = None
        self.is_cdc = is_cdc
        # Built up incrementally through add_column().
        self.schema: Union[None, pa.Schema] = None
        self.add_input(source)
        self.source = source
        self.generated_key_column = generated_key_column
        self.start_value = start_value

    def set_pk_list(self, pk_list: Union[None, Iterable[str]] = None):
        """Replace the primary key column list; calling without arguments clears it."""
        self.pk_list = pk_list

    @abstractmethod
    def get_schema(self, duckdb):
        """Subclass hook - presumably returns the target table's schema; contract not defined here."""

    @abstractmethod
    def get_cols(self, db) -> set[str]:
        """Subclass hook - returns a set of column names (per the return annotation)."""

    def add_column(self, field: pa.Field):
        """Append ``field`` to this loader's schema, creating the schema on first use."""
        if self.schema is not None:
            self.schema = self.schema.append(field)
            return
        self.schema = pa.schema([field], None)

    def set_show_columns(self, projection: list[str]):
        """Store the given column list, converted to its string form, as the show projection."""
        projection_str = convert_list_to_str(projection)
        self.show_projection = projection_str

    def set_show_where_clause(self, clause):
        """Store the where clause kept for the show-related methods."""
        self.where_clause = clause

    @abstractmethod
    def show(self, duckdb, heading: Union[None, str] = None):
        """Subclass hook - presumably displays the table's data; contract not defined here."""

    @abstractmethod
    def get_show_data(self, duckdb):
        """Subclass hook - presumably returns the data that show() displays; contract not defined here."""

    def add_default_columns(self):
        """
        Add the columns implied by the loader's configuration to the schema.

        With a generated key column configured, a nullable int32 field of that name is
        added and the primary key list is replaced by that single column. For a CDC
        table a string field named by CHANGE_TYPE is added as well.
        """
        key_column = self.generated_key_column
        if key_column is not None:
            key_field = pa.field(key_column, pa.int32(), True)
            self.add_column(key_field)
            self.set_pk_list([key_column])
        if self.is_cdc:
            change_type_field = pa.field(CHANGE_TYPE, pa.string())
            self.add_column(change_type_field)

    @abstractmethod
    def get_table_primary_key(self, db) -> Union[None, set[str]]:
        """Subclass hook - returns a set of primary key column names or None (per the return annotation)."""

    @abstractmethod
    def create_table(self, duckdb):
        """Subclass hook - presumably creates the target table; contract not defined here."""

class DuckDBTable(Table):
    """Dataflow target table whose load is performed through SQL statements run on a DuckDB connection."""

    def __init__(self, source: Dataset, table_name: str, name: Union[None, str] = None,
                 pk_list: Union[None, Iterable[str]] = None, allow_evolution: bool = False,
                 is_cdc: bool = False, generated_key_column: Union[None, str] = None, start_value: Union[None, int] = None):
        """
        Write the data into the target table. If the source dataset is a CDC source, perform the insert-update-delete
        statements, else an upsert.
        The primary key is either read from the target table or must be provided. The logical pk of the source is
        not used, it must be the physical pk of the target!

        :param source: source dataset
        :param table_name: name of the target table
        :param name: step name, passed through to the base class
        :param pk_list: optional pk list in case the target does not support PKs
        :param allow_evolution: stored on the instance; not evaluated by the methods of this class
        :param is_cdc: whether the target table itself carries CDC information
        :param generated_key_column: optional column that is filled from a sequence for inserted rows
        :param start_value: optional first value of that sequence; when omitted, max(column) + 1 is read from the table
        """
        super().__init__(name, table_name, is_cdc, pk_list)
        self.add_input(source)
        self.start_value = start_value
        self.allow_evolution = allow_evolution
        self.generated_key_column = generated_key_column
        self.source = source

    def get_generated_key_start(self, duckdb):
        """
        Determine the first value for the generated key sequence.

        :param duckdb: connection used to query the target table
        :return: the configured start value if one was given; otherwise max(key column) + 1 read
                 from the target table (1 if that max is NULL); None when neither a start value
                 nor a generated key column is configured
        """
        if self.start_value is not None:
            return self.start_value
        if self.generated_key_column is None:
            return None
        sql = f"select max({quote_str(self.generated_key_column)}) from {quote_str(self.table_name)}"
        logger.debug(
            f"DuckDBTable() - No start value provided, reading the max({self.generated_key_column}) value "
            f"from {self.table_name}: <{sql}>")
        res = duckdb.execute(sql).fetchall()
        current_max = res[0][0]
        # max() yields NULL when there is no non-NULL key value yet - start counting at 1.
        return 1 if current_max is None else current_max + 1

    def add_default_columns(self):
        """Add the generated key column (int32) to the schema when one is configured."""
        if self.generated_key_column is not None:
            self.add_column(pa.field(self.generated_key_column, pa.int32()))

    def _build_update_set_clause(self, cols) -> str:
        """Return the quoted ``col = s.col`` assignments for all non-key columns ('' if there are none)."""
        return ", ".join(
            f"{quote_str(col)} = s.{quote_str(col)}"
            for col in cols
            if col not in self.pk_list and col != self.generated_key_column
        )

    def _record_processed_rows(self, duckdb):
        """Count the rows of the source dataset, record them on last_execution and log the result."""
        # Same CTE form as the data statements in execute(), so the count is valid whenever they are.
        res = duckdb.execute(f"""
            with source as {self.source.get_sub_select_clause()}
            select count(*) from source""").fetchall()
        self.last_execution.processed(res[0][0])
        logger.info(f"DuckDBTable() - {self.last_execution}")

    def execute(self, duckdb):
        """
        Load the source dataset into the target table.

        The statement(s) are chosen as follows (first match wins):
        1. Source is CDC, target is not, and a primary key is known: rows with change type 'I' are
           inserted, 'U' rows update the matching target rows, 'D' rows delete them.
        2. The primary key in use is the target table's own primary key: INSERT OR REPLACE.
        3. A primary key was provided that differs from the table's: UPDATE the matching rows, then
           INSERT the rows whose key is not yet in the target.
        4. No primary key at all: plain INSERT (append).

        When a generated key column is configured, a sequence ``<table_name>_seq`` is (re)created on
        every call and used to fill that column for rows inserted by 1, 3 and 4.
        Afterwards the number of source rows is recorded in ``self.last_execution``.

        :param duckdb: connection on which all statements are executed
        """
        self.last_execution = OperationalMetadata()
        target_table = self.table_name
        target_table_name = quote_str(target_table)
        table_pk_list = self.get_table_primary_key(duckdb)
        if self.pk_list is None:
            logger.debug(f"DuckDBTable() - No logical primary key provided, reading the pk "
                              f"of the target table {target_table}...")
            self.pk_list = table_pk_list
            if self.pk_list is None:
                logger.debug(f"DuckDBTable() - Target table {target_table} has no primary "
                                  f"key columns - data will be appended")
                use_table_pk = False
            else:
                logger.debug(f"DuckDBTable() - Target table {target_table} has the primary "
                                  f"key columns {self.pk_list}")
                use_table_pk = True
        elif self.pk_list == table_pk_list:
            # NOTE(review): plain equality - a list never equals a set, so if get_table_primary_key()
            # returns a set (not visible here), a list-typed pk_list always takes the logical-key path.
            use_table_pk = True
        else:
            use_table_pk = False

        cols = set(self.source.get_cols(duckdb))
        # Only carry the change type column along if the target table actually has it.
        if CHANGE_TYPE not in self.get_cols(duckdb):
            cols.discard(CHANGE_TYPE)

        gen_key_str = ""
        seq_value_str = ""
        if self.generated_key_column is not None:
            sequence_name = self.table_name + "_seq"
            sql = f"create or replace sequence {quote_str(sequence_name)} start {self.get_generated_key_start(duckdb)}"
            logger.debug(f"DuckDBTable() - Creating the sequence for the key: <{sql}>")
            duckdb.execute(sql)
            gen_key_str = ", " + quote_str(self.generated_key_column)
            seq_value_str = f", nextval('{sequence_name}')"
            # The key is never copied from the source; it is appended separately via the sequence.
            cols.discard(self.generated_key_column)
        cols_str = convert_list_to_str(cols)

        if self.source.is_cdc and not self.is_cdc and self.pk_list is not None:
            update_set_str = self._build_update_set_clause(cols)
            pk_list_str = convert_list_to_str(self.pk_list)
            join_condition = create_join_condition(self.pk_list, 's', target_table_name)
            sql = f"""with source as {self.source.get_sub_select_clause()} 
                   INSERT INTO {target_table_name}({cols_str}{gen_key_str})
                   SELECT {cols_str}{seq_value_str} from source
                   where \"__change_type\" = 'I'
                """
            logger.debug(f"DuckDBTable() - Insert all __change_type='I' rows via the SQL <{sql}>")
            duckdb.execute(sql)
            if update_set_str:
                sql = f"""with source as {self.source.get_sub_select_clause()}
                       UPDATE {target_table_name} set {update_set_str} from source s
                       where {join_condition} and s.\"__change_type\" = 'U'
                    """
                logger.debug(f"DuckDBTable() - Update all __change_type='U' rows in the target via the SQL <{sql}>")
                duckdb.execute(sql)
            else:
                # An UPDATE with an empty SET list is invalid SQL; with key columns only there is nothing to change.
                logger.debug("DuckDBTable() - No non-key columns to update, skipping the update statement")
            # Key list parenthesised so a composite key forms a row value instead of a syntax error;
            # identical for single-column keys. NOTE(review): confirm the DuckDB version in use accepts
            # a multi-column row value on the left of IN.
            sql = f"""with source as {self.source.get_sub_select_clause()}
                   DELETE FROM {target_table_name}
                   where ({pk_list_str}) in (SELECT {pk_list_str} from source where \"__change_type\" = 'D')
                """
            logger.debug(f"DuckDBTable() - Delete all __change_type='D' rows in the target via the SQL <{sql}>")
            duckdb.execute(sql)
            self._record_processed_rows(duckdb)
        elif use_table_pk:
            # Upsert using the primary key
            sql = f"""with source as {self.source.get_sub_select_clause()} 
                   INSERT OR REPLACE INTO {target_table_name}({cols_str})
                   SELECT {cols_str} from source
                """
            logger.debug(f"DuckDBTable() - Upsert all rows via the SQL <{sql}>")
            duckdb.execute(sql)
            self._record_processed_rows(duckdb)
        elif self.pk_list is not None:
            # Upsert using a logical primary key
            update_set_str = self._build_update_set_clause(cols)
            pk_list_str = convert_list_to_str(self.pk_list)
            join_condition = create_join_condition(self.pk_list, 's', target_table_name)

            if update_set_str:
                sql = f"""with source as {self.source.get_sub_select_clause()}
                       UPDATE {target_table_name} set {update_set_str} from source s
                       where {join_condition}
                    """
                logger.debug(f"DuckDBTable() - Updated all matching existing rows via the SQL <{sql}>")
                duckdb.execute(sql)
            else:
                # An UPDATE with an empty SET list is invalid SQL; with key columns only there is nothing to change.
                logger.debug("DuckDBTable() - No non-key columns to update, skipping the update statement")

            # New rows get their generated key from the sequence, like in the CDC-insert and append
            # paths (both strings are empty when no generated key column is configured).
            # Key list parenthesised for composite keys - see the NOTE(review) on the CDC delete.
            sql = f"""with source as {self.source.get_sub_select_clause()} 
                   INSERT INTO {target_table_name}({cols_str}{gen_key_str})
                   SELECT {cols_str}{seq_value_str} from source 
                   where ({pk_list_str}) not in (select {pk_list_str} from {target_table_name})
                """
            logger.debug(f"DuckDBTable() - Inserted all new rows via the SQL <{sql}>")
            duckdb.execute(sql)
            self._record_processed_rows(duckdb)
        else:
            # No primary key information at all: append only.
            sql = f"""with source as {self.source.get_sub_select_clause()} 
                   INSERT INTO {target_table_name}({cols_str}{gen_key_str})
                   SELECT {cols_str}{seq_value_str} from source
                """
            logger.debug(f"DuckDBTable() - Insert all rows via the SQL <{sql}>")
            duckdb.execute(sql)
            self._record_processed_rows(duckdb)


