# -*- coding: utf-8 -*-

from __future__ import absolute_import, unicode_literals

try:
    from itertools import zip_longest
except ImportError:
    from itertools import izip_longest as zip_longest

import django
from django.db.models.sql import compiler
import re

# True when the running Django is older than 1.7, compared on (major, minor).
# NOTE(review): not referenced by the code in this module as shown; presumably
# read by other backend modules -- confirm before removing.
NEEDS_AGGREGATES_FIX = tuple(django.VERSION[0:2]) < (1, 7)

# NOTE: no query_class / 'SqlServerQuery' class is defined in this module; the
# SQL Server customizations live on the compiler classes below, which subclass
# Django's django.db.models.sql.compiler classes.
#
# The compiler overrides:
# ...insert queries to add "SET IDENTITY_INSERT" if needed.
# ...select queries to emulate LIMIT/OFFSET for sliced queries.
# ...update queries to turn NOCOUNT off so the affected row count is returned.

# Pattern to scan a column data type string and split the data type from any
# constraints or other included parts of a column definition. Based upon
# <column_definition> from http://msdn.microsoft.com/en-us/library/ms174979.aspx
_re_data_type_terminator = re.compile(
    r'\s*\b(?:' +
    r'filestream|collate|sparse|not|null|constraint|default|identity|rowguidcol' +
    r'|primary|unique|clustered|nonclustered|with|on|foreign|references|check' +
    ')',
    re.IGNORECASE,
)

# Pattern used in column aliasing to find sub-select placeholders
_re_col_placeholder = re.compile(r'\{_placeholder_(\d+)\}')

# Pattern to find the quoted column name at the end of a field specification
_re_pat_col = re.compile(r"\[([^\[]+)\]$")

_re_order_limit_offset = re.compile(
    r'(?:ORDER BY\s+(.+?))?\s*(?:LIMIT\s+(\d+))?\s*(?:OFFSET\s+(\d+))?$')

_re_find_order_direction = re.compile(r'\s+(asc|desc)\s*$', re.IGNORECASE)


def _get_order_limit_offset(sql):
    """Return the ``(order_by, limit, offset)`` strings found at the end of *sql*.

    Each element is ``None`` when that clause is absent.
    """
    tail_match = _re_order_limit_offset.search(sql)
    order_by, limit, offset = tail_match.groups()
    return order_by, limit, offset


def _remove_order_limit_offset(sql):
    """Strip any trailing ORDER BY / LIMIT / OFFSET from *sql* and drop its
    first whitespace-delimited word (the leading ``SELECT``)."""
    without_tail = _re_order_limit_offset.sub('', sql)
    _first_word, remainder = without_tail.split(None, 1)
    return remainder


def _break(s, find):
    """Break a string s into the part before the substring to find,
    and the part including and after the substring."""
    i = s.find(find)
    return s[:i], s[i:]


class SQLCompiler(compiler.SQLCompiler):
    """
    SELECT compiler for SQL Server.

    Emulates LIMIT/OFFSET (``TOP`` for a pure upper bound, ``ROW_NUMBER()``
    when there is an offset), suppresses ORDER BY where SQL Server rejects
    it, and applies the backend's ``_sql_function_overrides`` to expressions
    and aggregates.
    """

    def resolve_columns(self, row, fields=()):
        """
        Convert the raw values in ``row`` with the backend's converters.

        The first ``len(self.query.extra_select)`` values are returned
        unchanged. Each remaining value is paired positionally with
        ``fields`` and passed through ``connection.ops.convert_values``;
        values without a corresponding field are left as-is, and a
        ``ValueError`` from the converter keeps the raw value.
        """
        index_extra_select = len(self.query.extra_select)
        values = []
        for value, field in zip_longest(row[index_extra_select:], fields):
            if field:
                try:
                    value = self.connection.ops.convert_values(value, field)
                except ValueError:
                    # Best effort: keep the raw value if it can't be converted.
                    pass
            values.append(value)
        return row[:index_extra_select] + tuple(values)

    def compile(self, node, select_format=False):
        """
        Compile an expression node (expression API added in Django 1.7).

        If the node's ``function`` appears in the backend's
        ``_sql_function_overrides``, its function name and/or template are
        replaced (in place on the node) before delegating to Django.
        """
        sql_function = getattr(node, 'function', None)
        if sql_function and sql_function in self.connection.ops._sql_function_overrides:
            sql_function, sql_template = self.connection.ops._sql_function_overrides[sql_function]
            if sql_function:
                node.function = sql_function
            if sql_template:
                node.template = sql_template
        if select_format:
            # Django 1.8+ asks for select_format on SELECT-list columns so the
            # output field can wrap the SQL; forward it instead of dropping it.
            return super(SQLCompiler, self).compile(node, select_format=select_format)
        # Django 1.7's compile() accepts only the node.
        return super(SQLCompiler, self).compile(node)

    def _fix_aggregates(self):
        """
        MSSQL doesn't match the behavior of the other backends on a few of
        the aggregate functions; different return type behavior, different
        function names, etc.

        For example, MSSQL's AVG keeps the argument's data type, so AVG over
        integers drops the remainder. To match the other Django backends,
        AVG([1, 2]) needs to yield 1.5, not 1.

        Each selected annotation whose ``sql_function`` is listed in the
        backend's ``_sql_function_overrides`` has its ``sql_function``
        and/or ``sql_template`` replaced in place.
        """
        # NOTE(review): Query.annotation_select exists on Django 1.8+, whose
        # aggregates expose ``function`` (handled in compile()) rather than
        # ``sql_function``; older versions call this mapping
        # ``aggregate_select``. Confirm which Django versions this targets.
        for aggregate in self.query.annotation_select.values():
            sql_function = getattr(aggregate, 'sql_function', None)
            if not sql_function or sql_function not in self.connection.ops._sql_function_overrides:
                continue

            sql_function, sql_template = self.connection.ops._sql_function_overrides[sql_function]
            if sql_function:
                aggregate.sql_function = sql_function
            if sql_template:
                aggregate.sql_template = sql_template

    def as_sql(self, with_limits=True, with_col_aliases=False, subquery=False):
        """
        Return ``(sql, params)`` for the SELECT, emulating LIMIT/OFFSET.

        * With limits and ``low_mark == high_mark``: returns ``('', ())`` so
          no query is executed (Django #12192).
        * Without slicing: delegates to Django, flagging the query so
          ``get_ordering`` omits ORDER BY when ``with_col_aliases`` is true.
        * Upper bound only: injects ``TOP n`` after ``SELECT [DISTINCT]``.
        * With an offset: wraps the query in a derived table numbered with
          ``ROW_NUMBER() OVER (ORDER BY ...)`` and filters on ``_row_num``.

        ``subquery`` is accepted for signature compatibility and is unused.
        ``self._using_row_number`` records whether the ROW_NUMBER form was
        produced.
        """
        # Django #12192 - Don't execute any DB query when QS slicing results in limit 0
        if with_limits and self.query.low_mark == self.query.high_mark:
            return '', ()

        self._using_row_number = False

        # Get out of the way if we're not a select query or there's no limiting involved.
        check_limits = with_limits and (self.query.low_mark or self.query.high_mark is not None)
        if not check_limits:
            # The ORDER BY clause is invalid in views, inline functions,
            # derived tables, subqueries, and common table expressions,
            # unless TOP or FOR XML is also specified.
            # Set the flag before entering the try so the finally's delattr
            # can never fail on a missing attribute.
            self.query._mssql_ordering_not_allowed = with_col_aliases
            try:
                result = super(SQLCompiler, self).as_sql(with_limits, with_col_aliases)
            finally:
                # remove in case the query is ever reused
                delattr(self.query, '_mssql_ordering_not_allowed')
            return result

        raw_sql, params = super(SQLCompiler, self).as_sql(False, with_col_aliases)

        # Check for high mark only and replace with "TOP"
        if self.query.high_mark is not None and not self.query.low_mark:
            _select = 'SELECT'
            if self.query.distinct:
                _select += ' DISTINCT'

            sql = re.sub(
                r'(?i)^{0}'.format(_select),
                '{0} TOP {1}'.format(_select, self.query.high_mark),
                raw_sql,
                count=1,
            )
            return sql, params

        # Else we have limits; rewrite the query using ROW_NUMBER()
        self._using_row_number = True

        # Lop off ORDER... and the initial "SELECT"
        inner_select = _remove_order_limit_offset(raw_sql)
        outer_fields, inner_select = self._alias_columns(inner_select)

        order = _get_order_limit_offset(raw_sql)[0]

        inner_table_name = self.connection.ops.quote_name('AAAA')

        outer_fields, inner_select, order = self._fix_slicing_order(outer_fields, inner_select, order, inner_table_name)

        # map a copy of outer_fields for injected subselect: keep only the
        # alias (text after " AS ") or the last dotted component, re-pointed
        # at the derived table.
        subselect_fields = []
        for x in outer_fields.split(','):
            i = x.upper().find(' AS ')
            if i != -1:
                x = x[i + 4:]
            col = x.rsplit('.', 1)[-1]
            subselect_fields.append('{0}.{1}'.format(inner_table_name, col.strip()))

        # inject a subselect to get around OVER requiring ORDER BY to come from FROM
        inner_select = '{fields} FROM ( SELECT {inner} ) AS {inner_as}'.format(
            fields=', '.join(subselect_fields),
            inner=inner_select,
            inner_as=inner_table_name,
        )

        where_row_num = '{0} < _row_num'.format(self.query.low_mark)
        if self.query.high_mark:
            where_row_num += ' and _row_num <= {0}'.format(self.query.high_mark)

        sql = """SELECT _row_num, {outer}
FROM ( SELECT ROW_NUMBER() OVER ( ORDER BY {order}) as _row_num, {inner}) as QQQ
WHERE {where}""".format(
            outer=outer_fields,
            order=order,
            inner=inner_select,
            where=where_row_num,
        )

        return sql, params

    def _fix_slicing_order(self, outer_fields, inner_select, order, inner_table_name):
        """
        Build the ORDER BY list for ``ROW_NUMBER() OVER (...)`` against the
        derived table ``inner_table_name``.

        * With no ORDER BY (``order is None``) the primary key column is
          used, ascending.
        * Otherwise each term is re-pointed at ``inner_table_name``, keeping
          its ASC/DESC direction (default ASC). A term whose text does not
          appear in ``inner_select`` is prepended to the inner select list
          under an alias ``[<col>___o<n>]`` so the outer ORDER BY can
          reference it.

        Returns ``(outer_fields, inner_select, order)``; ``outer_fields`` is
        passed through unchanged.
        """
        # Using ROW_NUMBER requires an ordering
        if order is None:
            meta = self.query.get_meta()
            column = meta.pk.db_column or meta.pk.get_attname()
            order = '{0}.{1} ASC'.format(
                inner_table_name,
                self.connection.ops.quote_name(column),
            )
        else:
            alias_id = 0
            # remap order for injected subselect
            # NOTE(review): splitting on ',' breaks ORDER BY terms that
            # themselves contain commas (e.g. function calls).
            new_order = []
            for x in order.split(','):
                # find the ordering direction
                m = _re_find_order_direction.search(x)
                direction = m.group(1) if m else 'ASC'
                # remove the ordering direction
                x = _re_find_order_direction.sub('', x)
                # remove any namespacing or table name from the column name
                col = x.rsplit('.', 1)[-1]
                # Is the ordering column missing from the inner select?
                # 'inner_select' contains the full query without the leading 'SELECT '.
                # It's possible that this can get a false hit if the ordering
                # column is used in the WHERE while not being in the SELECT. It's
                # not worth the complexity to properly handle that edge case.
                if x not in inner_select:
                    # Ordering requires the column to be selected by the inner select
                    alias_id += 1
                    # alias column name
                    col = '[{0}___o{1}]'.format(
                        col.strip('[]'),
                        alias_id,
                    )
                    # add alias to inner_select
                    inner_select = '({0}) AS {1}, {2}'.format(x, col, inner_select)
                new_order.append('{0}.{1} {2}'.format(inner_table_name, col, direction))
            order = ', '.join(new_order)
        return outer_fields, inner_select, order

    def _alias_columns(self, sql):
        """
        Split ``sql`` (a SELECT without its leading "SELECT ") into outer and
        inner column lists, aliasing duplicate column names.

        Parenthesised sub-expressions are temporarily replaced by
        ``{_placeholder_N}`` tokens so that commas inside them don't split
        the select list; they are expanded again in the returned strings.

        Returns ``(outer, inner)``: ``outer`` is a comma-separated list of
        quoted column names, with repeats (compared case-insensitively)
        renamed ``<name>___<n>``; ``inner`` is the select list carrying the
        matching ``as`` aliases, followed by the FROM clause.

        Raises ``Exception`` if a select item does not end in ``[column]``.
        """
        qn = self.connection.ops.quote_name

        outer = []
        inner = []
        # times each lower-cased column name has been seen so far
        seen_counts = {}

        # replace all parens with placeholders
        # NOTE(review): str.format is used for expansion, so literal '{' or
        # '}' in the SQL text would break this.
        paren_depth, paren_buf = 0, ['']
        parens, i = {}, 0
        for ch in sql:
            if ch == '(':
                i += 1
                paren_depth += 1
                paren_buf.append('')
            elif ch == ')':
                paren_depth -= 1
                key = '_placeholder_{0}'.format(i)
                buf = paren_buf.pop()

                # store the expanded paren string
                parens[key] = buf.format(**parens)
                paren_buf[paren_depth] += '({' + key + '})'
            else:
                paren_buf[paren_depth] += ch

        def _replace_sub(col):
            """Replace all placeholders with expanded values"""
            while _re_col_placeholder.search(col):
                col = col.format(**parens)
            return col

        temp_sql = ''.join(paren_buf)

        select_list, from_clause = _break(temp_sql, ' FROM [')

        for col in [x.strip() for x in select_list.split(',')]:
            match = _re_pat_col.search(col)
            if not match:
                raise Exception('Unable to find a column name when parsing SQL: {0}'.format(col))

            col_name = match.group(1)
            col_key = col_name.lower()
            seen = seen_counts.get(col_key, 0)

            if seen:
                alias = qn('{0}___{1}'.format(col_name, seen))
                outer.append(alias)
                inner.append('{0} as {1}'.format(_replace_sub(col), alias))
            else:
                outer.append(qn(col_name))
                inner.append(_replace_sub(col))

            seen_counts[col_key] = seen + 1

        return ', '.join(outer), ', '.join(inner) + from_clause.format(**parens)

    def get_ordering(self):
        """
        Return Django's ordering, or an empty ordering while the query has
        ``_mssql_ordering_not_allowed`` set to a true value.

        The empty result matches the tuple arity of the running Django:
        before 1.6 ``get_ordering`` returns ``(ordering, group_by)``; from
        1.6 it returns ``(ordering, params, group_by)``.
        """
        # The ORDER BY clause is invalid in views, inline functions,
        # derived tables, subqueries, and common table expressions,
        # unless TOP or FOR XML is also specified.
        if getattr(self.query, '_mssql_ordering_not_allowed', False):
            if django.VERSION[:2] < (1, 6):
                return (None, [])
            return (None, [], [])
        return super(SQLCompiler, self).get_ordering()


class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
    """
    INSERT compiler that adapts Django's generated statements for SQL Server.

    Explicit auto-field values are wrapped in ``SET IDENTITY_INSERT``, a
    value-less auto-field insert becomes ``DEFAULT VALUES``, and, when the
    inserted id must be returned, an ``OUTPUT INSERTED.<pk> INTO`` clause
    captures it in a table variable that is selected at the end.
    """

    # Locates the end of the table/column list (")" or "]"), an optional
    # DEFAULT keyword and the VALUES keyword, with any "(" that follows it...
    _re_values_sub = re.compile(
        r'(?P<prefix>\)|\])(?P<default>\s*|\s*default\s*)values(?P<suffix>\s*|\s+\()?',
        re.IGNORECASE
    )
    # ...so the OUTPUT clause can be spliced in front of VALUES / DEFAULT VALUES.
    # ``{col}`` is filled in with the quoted primary key column name.
    _values_repl = r'\g<prefix> OUTPUT INSERTED.{col} INTO @sqlserver_ado_return_id\g<default>VALUES\g<suffix>'

    def as_sql(self, *args, **kwargs):
        """Compile the INSERT(s) via Django, then apply ``_fix_insert`` to each."""
        # Django ticket #14019: return_id may not have been set on this compiler.
        if not hasattr(self, 'return_id'):
            self.return_id = False

        compiled = super(SQLInsertCompiler, self).as_sql(*args, **kwargs)
        if not isinstance(compiled, list):
            statement, statement_params = compiled
            return self._fix_insert(statement, statement_params)

        # Django 1.4 wraps the return in a list of (sql, params) pairs.
        return [self._fix_insert(pair[0], pair[1]) for pair in compiled]

    def _fix_insert(self, sql, params):
        """
        Adapt one compiled INSERT statement for SQL Server.

        When the model has an auto field:
          * no fields, or only the auto field with no params, becomes
            ``INSERT INTO <table> DEFAULT VALUES`` (params reset to ``[]``);
          * a statement that supplies the auto field is wrapped in
            ``SET IDENTITY_INSERT <table> ON`` / ``OFF``.
        When ``self.return_id`` is set and the backend reports
        ``can_return_id_from_insert``, the statement is prefixed with
        ``SET NOCOUNT ON`` and a table-variable declaration, gets an OUTPUT
        clause before VALUES, and ends with a SELECT of the captured id.

        Returns the (possibly rewritten) ``(sql, params)`` pair.
        """
        meta = self.query.get_meta()

        if meta.has_auto_field:
            if not hasattr(self.query, 'fields'):
                # < django 1.4: the query lists column names
                fields = self.query.columns
                auto_field = meta.auto_field.db_column or meta.auto_field.column
            else:
                # django 1.4 replaced columns with fields
                fields = self.query.fields
                auto_field = meta.auto_field

            auto_in_fields = auto_field in fields
            table = self.connection.ops.quote_name(meta.db_table)

            if not fields or (auto_in_fields and len(fields) == 1 and not params):
                # Only the primary key, with no value given: let SQL Server
                # generate the row.
                sql, params = 'INSERT INTO {0} DEFAULT VALUES'.format(table), []
            elif auto_in_fields:
                # An explicit identity value requires IDENTITY_INSERT.
                sql = 'SET IDENTITY_INSERT {table} ON;{sql};SET IDENTITY_INSERT {table} OFF'.format(
                    table=table,
                    sql=sql,
                )

        wants_id = self.return_id and self.connection.features.can_return_id_from_insert
        if not wants_id:
            return sql, params

        # Return the new id via the OUTPUT clause:
        # http://msdn.microsoft.com/en-us/library/ms177564.aspx
        pk = meta.pk
        quoted_pk = self.connection.ops.quote_name(pk.db_column or pk.get_attname())

        # Keep only the bare data type for the table variable's column,
        # dropping IDENTITY / NOT NULL / PRIMARY KEY and similar parts.
        pk_type = _re_data_type_terminator.split(pk.db_type(self.connection))[0]
        declare = "DECLARE @sqlserver_ado_return_id table ({col_name} {pk_type})".format(
            col_name=quoted_pk,
            pk_type=pk_type,
        )

        # NOCOUNT ON to prevent additional trigger/stored proc related resultsets
        sql = 'SET NOCOUNT ON;{declare_table_var};{sql};{select_return_id}'.format(
            sql=sql,
            declare_table_var=declare,
            select_return_id="SELECT * FROM @sqlserver_ado_return_id",
        )

        output_clause = self._values_repl.format(col=quoted_pk)
        return self._re_values_sub.sub(output_clause, sql), params


class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
    """DELETE compiler combining Django's delete compiler with this module's SQLCompiler."""


class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
    """UPDATE compiler that makes SQL Server report the affected row count."""

    def as_sql(self, with_limits=True, with_col_aliases=False, subquery=False):
        # The parameters exist for signature compatibility only; the base
        # UPDATE compiler is always called without them.
        statement, params = super(SQLUpdateCompiler, self).as_sql()
        if not statement:
            return statement, params
        # Need the NOCOUNT OFF so UPDATE returns a count, instead of -1
        wrapped = 'SET NOCOUNT OFF; {0}; SET NOCOUNT ON'.format(statement)
        return wrapped, params


class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler):
    """Aggregate compiler that applies the backend's aggregate overrides first."""

    def as_sql(self, with_limits=True, with_col_aliases=False, subquery=False):
        # Swap in the backend's aggregate function names/templates, then let
        # Django compile. The parameters are accepted for signature
        # compatibility and are not forwarded.
        self._fix_aggregates()
        base_compiler = super(SQLAggregateCompiler, self)
        return base_compiler.as_sql()
