import functools
import json
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from fnmatch import fnmatch
from random import getrandbits
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union

import attrs
import sqlalchemy as sa
from fsspec.callbacks import DEFAULT_CALLBACK, Callback

from dvcx.data_storage.warehouse import RANDOM_BITS
from dvcx.sql.types import JSON, Boolean, DateTime, Int, SQLType, String

if TYPE_CHECKING:
    from dvcx.catalog import Catalog
    from dvcx.dataset import DatasetRow as Row


class ColumnMeta(type):
    """Metaclass that turns class-level attribute access into construction.

    Looking up an attribute that is not otherwise defined on the class, for
    example ``Column.size``, instantiates the class with the attribute name
    and returns that instance (``Column("size")``).
    """

    def __getattr__(cls, name: str):
        """Return ``cls(name)`` for an attribute normal lookup did not find.

        Raises:
            AttributeError: if ``name`` is a dunder name (``__x__``).
        """
        # Dunder names are protocol probes, not column names: callers such as
        # ``inspect.unwrap`` (which checks ``__wrapped__``) rely on
        # AttributeError to learn that a hook is absent. Fabricating an
        # instance for them would make every ``hasattr`` probe succeed.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(
                f"type object {cls.__name__!r} has no attribute {name!r}"
            )
        return cls(name)


class Column(sa.ColumnClause, metaclass=ColumnMeta):
    """SQLAlchemy column clause identified by its name.

    Because of the ``ColumnMeta`` metaclass, a column can also be obtained
    through class-level attribute access, e.g. ``Column.size``.
    """

    # SQLAlchemy cache-key setting for this subclass (see the SQLAlchemy
    # documentation on ``inherit_cache``).
    inherit_cache: Optional[bool] = True

    def __init__(self, text, type_=None, is_literal=False, _selectable=None):
        """Create a column clause named ``text``.

        ``type_``, ``is_literal`` and ``_selectable`` are forwarded unchanged
        to ``sa.ColumnClause.__init__``.
        """
        # ``name`` is assigned before the parent initializer runs.
        self.name = text
        parent_kwargs = {
            "type_": type_,
            "is_literal": is_literal,
            "_selectable": _selectable,
        }
        super().__init__(text, **parent_kwargs)

    def glob(self, glob_str):
        """Return the expression ``<column> GLOB <glob_str>``."""
        glob_operator = self.op("GLOB")
        return glob_operator(glob_str)


class UDFParameter(ABC):
    """Base class for objects that compute one UDF argument from a row."""

    @abstractmethod
    def get_value(self, catalog: "Catalog", row: "Row", **kwargs) -> Any:
        """Return the argument value for ``row``; subclasses must implement."""

    async def get_value_async(
        self, catalog: "Catalog", row: "Row", mapper, **kwargs
    ) -> Any:
        """Async counterpart of ``get_value``.

        The default implementation calls ``get_value`` directly; ``mapper``
        is not used and is not forwarded.
        """
        value = self.get_value(catalog, row, **kwargs)
        return value


@attrs.define(slots=False)
class ColumnParameter(UDFParameter):
    """UDF parameter whose value is a single column of the row."""

    # Key used to look the value up in the row.
    name: str

    def get_value(self, catalog, row, **kwargs):
        """Return ``row[self.name]``; ``catalog`` and ``kwargs`` are unused."""
        column_name = self.name
        return row[column_name]


@attrs.define(slots=False)
class Object(UDFParameter):
    """
    Placeholder parameter standing for the stored object itself.

    The object is opened through the client obtained for the row's source and
    passed to ``reader``; whatever ``reader`` returns becomes the UDF argument.
    """

    # Callable invoked with the opened object; its result is the value.
    reader: Callable

    def get_value(
        self,
        catalog: "Catalog",
        row: "Row",
        *,
        cache: bool = False,
        cb: Callback = DEFAULT_CALLBACK,
        **kwargs,
    ) -> Any:
        """Open the row's object and return ``self.reader(<opened object>)``.

        When ``cache`` is true the object is downloaded first and then opened
        with ``use_cache=True``. ``cb`` is handed to both the download and the
        open call. Extra ``kwargs`` are ignored.
        """
        storage = catalog.get_client(row["source"])
        object_uid = row.as_uid()
        if cache:
            storage.download(object_uid, callback=cb)
        opened = storage.open_object(object_uid, use_cache=cache, cb=cb)
        with opened as stream:
            return self.reader(stream)

    async def get_value_async(
        self,
        catalog: "Catalog",
        row: "Row",
        mapper,
        *,
        cache: bool = False,
        cb: Callback = DEFAULT_CALLBACK,
        **kwargs,
    ) -> Any:
        """Async variant of ``get_value``.

        The download (only when ``cache`` is true) is awaited through
        ``client._download``; opening the object and calling ``reader`` are
        both dispatched through ``mapper.to_thread``.
        """
        storage = catalog.get_client(row["source"])
        object_uid = row.as_uid()
        if cache:
            await storage._download(object_uid, callback=cb)
        opener = functools.partial(
            storage.open_object, object_uid, use_cache=cache, cb=cb
        )
        opened = await mapper.to_thread(opener)
        # The reader is handed the opened object itself, inside its context.
        with opened:
            return await mapper.to_thread(self.reader, opened)


@attrs.define(slots=False)
class Stream(UDFParameter):
    """
    Parameter that hands the UDF the opened object for the row.

    Unlike a reader-based parameter, the result of ``client.open_object`` is
    returned as is, without being entered as a context manager or closed here.
    """

    def get_value(
        self,
        catalog: "Catalog",
        row: "Row",
        *,
        cache: bool = False,
        cb: Callback = DEFAULT_CALLBACK,
        **kwargs,
    ) -> Any:
        """Return ``client.open_object(...)`` for the row's object.

        When ``cache`` is true the object is downloaded first and opened with
        ``use_cache=True``. ``cb`` is handed to both calls. Extra ``kwargs``
        are ignored.
        """
        storage = catalog.get_client(row["source"])
        object_uid = row.as_uid()
        if cache:
            storage.download(object_uid, callback=cb)
        opened = storage.open_object(object_uid, use_cache=cache, cb=cb)
        return opened

    async def get_value_async(
        self,
        catalog: "Catalog",
        row: "Row",
        mapper,
        *,
        cache: bool = False,
        cb: Callback = DEFAULT_CALLBACK,
        **kwargs,
    ) -> Any:
        """Async variant of ``get_value``.

        The download (only when ``cache`` is true) is awaited through
        ``client._download``; the open call is dispatched through
        ``mapper.to_thread``.
        """
        storage = catalog.get_client(row["source"])
        object_uid = row.as_uid()
        if cache:
            await storage._download(object_uid, callback=cb)
        opener = functools.partial(
            storage.open_object, object_uid, use_cache=cache, cb=cb
        )
        return await mapper.to_thread(opener)


@attrs.define(slots=False)
class LocalFilename(UDFParameter):
    """
    Placeholder parameter resolving to the path of the object's cached copy.

    With ``glob`` unset, every row's object is downloaded and its cache path
    returned. With ``glob`` set, that only happens for rows whose ``name``
    matches the pattern; for any other row the value is None.
    """

    # Optional ``fnmatch`` pattern applied to the row's ``name``.
    glob: Optional[str] = None

    def get_value(
        self,
        catalog: "Catalog",
        row: "Row",
        *,
        cb: Callback = DEFAULT_CALLBACK,
        **kwargs,
    ) -> Optional[str]:
        """Download the row's object and return its path in the client cache.

        Returns None, without downloading, when ``glob`` is set and the row's
        ``name`` does not match it. Extra ``kwargs`` are ignored.
        """
        pattern = self.glob
        if pattern and not fnmatch(row["name"], pattern):  # type: ignore[type-var]
            # Filename rejected by the glob: nothing to fetch.
            return None
        storage = catalog.get_client(row["source"])
        object_uid = row.as_uid()
        storage.download(object_uid, callback=cb)
        local_path = storage.cache.get_path(object_uid)
        return local_path

    async def get_value_async(
        self,
        catalog: "Catalog",
        row: "Row",
        mapper,
        *,
        cache: bool = False,
        cb: Callback = DEFAULT_CALLBACK,
        **kwargs,
    ) -> Optional[str]:
        """Async variant of ``get_value``; awaits ``client._download``.

        ``cache`` and ``mapper`` are accepted but not consulted: the object
        is downloaded whenever the glob check passes.
        """
        pattern = self.glob
        if pattern and not fnmatch(row["name"], pattern):  # type: ignore[type-var]
            # Filename rejected by the glob: nothing to fetch.
            return None
        storage = catalog.get_client(row["source"])
        object_uid = row.as_uid()
        await storage._download(object_uid, callback=cb)
        local_path = storage.cache.get_path(object_uid)
        return local_path


# Accepted ways to specify a UDF parameter: a column name, a ``Column``, or a
# ready-made ``UDFParameter``. ``normalize_param`` converts any of these into
# a ``UDFParameter``.
UDFParamSpec = Union[str, Column, UDFParameter]


def normalize_param(param: UDFParamSpec) -> UDFParameter:
    """Convert a parameter spec into a ``UDFParameter``.

    A string or a ``Column`` becomes a ``ColumnParameter`` for that column
    name; an existing ``UDFParameter`` is returned unchanged.

    Raises:
        TypeError: if ``param`` is none of the supported kinds.
    """
    if isinstance(param, str):
        column_name = param
    elif isinstance(param, Column):
        column_name = param.name
    elif isinstance(param, UDFParameter):
        return param
    else:
        raise TypeError(f"Invalid UDF parameter: {param}")
    return ColumnParameter(column_name)


class DatasetRow:
    """Default dataset column schema plus helpers for building rows."""

    # Column name -> SQL type, listed in the same order as the tuple returned
    # by ``create``.
    schema: ClassVar[dict[str, type[SQLType]]] = {
        "source": String,
        "parent": String,
        "name": String,
        "size": Int,
        "location": JSON,
        "vtype": String,
        "dir_type": Int,
        "owner_name": String,
        "owner_id": String,
        "is_latest": Boolean,
        "last_modified": DateTime,
        "version": String,
        "etag": String,
        # system column
        "random": Int,
    }

    @staticmethod
    def create(
        name: str,
        source: str = "",
        parent: str = "",
        size: int = 0,
        location: Optional[dict[str, Any]] = None,
        vtype: str = "",
        dir_type: int = 0,
        owner_name: str = "",
        owner_id: str = "",
        is_latest: bool = True,
        last_modified: Optional[datetime] = None,
        version: str = "",
        etag: str = "",
    ) -> tuple[
        str,
        str,
        str,
        int,
        Optional[str],
        str,
        int,
        str,
        str,
        bool,
        datetime,
        str,
        str,
        int,
    ]:
        """Build a row tuple whose fields follow the order of ``schema``.

        A truthy ``location`` is serialized as a JSON array holding that one
        dict; a falsy one is passed through unchanged. A missing
        ``last_modified`` defaults to the current UTC time. The trailing
        ``random`` field is drawn with ``getrandbits(RANDOM_BITS)``.
        """
        encoded_location = json.dumps([location]) if location else location
        modified_at = last_modified or datetime.now(timezone.utc)
        random_bits = getrandbits(RANDOM_BITS)

        return (  # type: ignore [return-value]
            source,
            parent,
            name,
            size,
            encoded_location,
            vtype,
            dir_type,
            owner_name,
            owner_id,
            is_latest,
            modified_at,
            version,
            etag,
            random_bits,
        )

    @staticmethod
    def extend(**columns):
        """Return a new dict: ``schema`` plus ``columns``.

        Entries in ``columns`` override schema entries of the same name; the
        class-level ``schema`` itself is not modified.
        """
        return {**DatasetRow.schema, **columns}


# Short alias for ``Column``.
C = Column
