import concurrent.futures
import logging
import math
import mmap
import multiprocessing
import os
import pathlib
import tempfile
import typing
from urllib.parse import parse_qs, urlparse

import requests

from ideas import environments, exceptions
from ideas.session import Session, get_default_session
from ideas.utils import api_types, checksum, display, http
from ideas.utils.api_types import (
    File,
    FileMetadata,
    FilePart,
    FileParts,
    Project,
    Tenant,
)
from ideas.utils.custom_click_types import KeyValueOptionValue
from ideas.utils.presigned_urls import PresignedUrlsManager
from ideas.utils.upload import get_part_size, thread_count

# Shared logger for this module. `getLogger()` is called without a name, so
# this is the root logger rather than a module-specific one.
logger = logging.getLogger()


def get_environments(include_dev: bool, environments_url: str | None = None):
    """
    List the names of the supported environments, for use with the `--env` argument.

    Both arguments are forwarded unchanged to `environments.get_environments`; the
    keys of the mapping it returns are the environment names.
    """
    available = environments.get_environments(
        include_dev=include_dev, environments_url=environments_url
    )
    return [*available.keys()]


def get_tenants(
    filters, *, session: typing.Optional[Session] = None
) -> typing.Iterator[Tenant]:
    """
    Yield tenants from the tenant listing endpoint.

    `filters` is handed to `http.handle_pagination` as-is. When `session` is
    omitted, the default session is used. This is a generator, so nothing is
    requested until the result is iterated.
    """
    active = session or get_default_session()
    endpoint = f"{active.base_url}/api/{http.IDEAS_API_VERSION}/tenant/tenants/"
    yield from http.handle_pagination(
        http.get, endpoint, active.headers, active.auth, filters
    )


def get_projects(
    filters, *, session: typing.Optional[Session] = None
) -> typing.Iterator[Project]:
    """
    Yield projects from the project listing endpoint.

    `filters` is handed to `http.handle_pagination` as-is. When `session` is
    omitted, the default session is used. This is a generator, so nothing is
    requested until the result is iterated.
    """
    active = session or get_default_session()
    endpoint = f"{active.base_url}/api/{http.IDEAS_API_VERSION}/library/projects/"
    yield from http.handle_pagination(
        http.get, endpoint, active.headers, active.auth, filters
    )


def get_files(
    filters, *, session: typing.Optional[Session] = None
) -> typing.Iterator[File]:
    """
    Yield files from the file listing endpoint.

    `filters` is handed to `http.handle_pagination` as-is. When `session` is
    omitted, the default session is used. This is a generator, so nothing is
    requested until the result is iterated.
    """
    active = session or get_default_session()
    endpoint = f"{active.base_url}/api/{http.IDEAS_API_VERSION}/drs/files/"
    yield from http.handle_pagination(
        http.get, endpoint, active.headers, active.auth, filters
    )


def complete_upload(
    file_id, session: typing.Optional[Session] = None
) -> list[FilePart]:
    """
    POST to the `complete_upload` endpoint of `file_id` and return the file parts
    from the response.

    It may be called more than once, e.g. to find out which parts already exist
    before resuming an upload. An empty list is returned when the response
    carries no parts.
    """
    session = session or get_default_session()
    endpoint = (
        f"{session.base_url}/api/{http.IDEAS_API_VERSION}/drs/files/{file_id}/complete_upload/"
    )
    payload = typing.cast(
        FileParts,
        http.post(endpoint, session.headers, session.auth, data={}),
    )

    # The API can answer with None (or nothing) for "parts"; treat that as
    # "no parts uploaded yet".
    parts = payload["parts"]
    if parts:
        return parts
    return []


def create_file(
    project: str,
    filepath: pathlib.Path,
    metadata: typing.Optional[tuple[KeyValueOptionValue]] = None,
    session: typing.Optional[Session] = None,
) -> File:
    """
    Register a new file record for the local file at `filepath` in `project`.

    The record is created with the file's base name, its size in bytes and the
    part size chosen by `get_part_size`. Each `(key, value)` pair in `metadata`
    is then posted as a separate file-metadata entry. Finally the record is
    fetched again and returned.

    Raises:
        ValueError: if `project` is None.
    """
    session = session or get_default_session()
    size_in_bytes = os.path.getsize(filepath)
    name = os.path.basename(filepath)
    part_size = get_part_size(size_in_bytes)

    if project is None:
        raise ValueError("Project must be specified")

    payload = {
        "name": name,
        "part_size": part_size,
        "size": size_in_bytes,
        "project": project,
    }

    api_root = f"{session.base_url}/api/{http.IDEAS_API_VERSION}"
    created = typing.cast(
        File,
        http.post(f"{api_root}/drs/files/", session.headers, session.auth, payload),
    )

    # One request per metadata pair; the responses are not used.
    entries = () if metadata is None else metadata
    for key, value in entries:
        http.post(
            f"{api_root}/metadata/filemetadata/",
            session.headers,
            session.auth,
            data={"file": created["id"], "key": key, "value": value},
        )

    # Re-read the record so the caller gets its current state.
    refreshed = http.get(
        f"{api_root}/drs/files/{created['id']}",
        session.headers,
        session.auth,
    )
    return typing.cast(File, refreshed)


def multipart_upload_source(
    filepath: pathlib.Path,
    project: str,
    upload_progress_filepath: typing.Optional[pathlib.Path] = None,
    metadata: typing.Optional[tuple[KeyValueOptionValue]] = None,
    file_id: typing.Optional[str] = None,
    max_label_length: int = 40,
    upload_threads: typing.Optional[int] = None,
    session: typing.Optional[Session] = None,
) -> File:
    """
    Uploads a file in parts using a multipart upload, sending the parts concurrently from a pool
    of worker threads.

    A part whose upload raises is resubmitted once; if the resubmitted attempt raises as well, the
    whole upload fails. Any lower-level retrying is up to `http.put` and is not visible here.

    If `file_id` is given the upload is resumed: the parts reported by `complete_upload` are not
    sent again. Otherwise a new file record is created with `create_file` (using `project` and
    `metadata`).

    If `upload_progress_filepath` is not set, progress is logged to
    `ideas-cli-upload_progress.txt` in the system temporary directory. The file is truncated at
    the start of every upload.

    `upload_threads` sets the number of worker threads; when it is not set (or is 0),
    `thread_count()` decides.

    Returns the file record, fetched again after the upload was completed.

    Raises:
        RuntimeError: if a part fails again after being retried.
        exceptions.UnhandledError: if the number of bytes accounted for differs from the file
            size once all parts are done.
    """
    session = session or get_default_session()
    file_size = os.path.getsize(filepath)  # file size in bytes
    part_size = get_part_size(file_size)  # part size in bytes
    # 1-indexed numbers of the parts that are already uploaded (only non-empty when resuming).
    # A set keeps the per-part membership test in the workers cheap.
    uploaded_parts = set()

    if upload_progress_filepath is None:
        upload_progress_filepath = (
            pathlib.Path(tempfile.gettempdir()) / "ideas-cli-upload_progress.txt"
        )

    if file_id is not None:
        file_parts = complete_upload(file_id, session)
        uploaded_parts = {part["PartNumber"] for part in file_parts}
    else:
        file_id = create_file(
            project,
            filepath,
            metadata,
            session,
        )["id"]

    # TODO in the event that all parts have already been uploaded, but the upload hasn't been
    # completed, this will basically restart the upload from scratch, but leave it in an 'Uploading'
    # state
    # TODO pass the whole session
    presigned_urls = PresignedUrlsManager(
        session.base_url, session.headers, session.auth, file_id
    )
    uploaded_bytes = multiprocessing.Value("L", 0)

    # Set up a pool of threads to upload multiple parts at once. Note that because of the GIL, only
    # one thread is actually running Python code at once, but because uploads are almost entirely
    # I/O-bound, this still gives us a substantial speed-boost (I can saturate a 1Gbps upload link
    # with eight worker threads).
    #
    # One word of caution here: each worker maps `part_size` bytes of the file, and since we have
    # `n` workers, we could use a large amount of memory; however, each part is released as it
    # completes, so unless this is running on a machine with many, many cores but slow internet, we
    # shouldn't see much memory usage.
    num_workers = upload_threads or thread_count()
    logger.debug("Using %s thread(s) to upload file.", num_workers)
    with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
        # Bound before the `try` so the `finally` below never touches an unbound name if opening
        # the progress file fails.
        upload_progress_file = None
        try:
            # Open and truncate existing upload progress file, so it only contains logs
            # for this particular upload. Set the buffering mode to line-buffered to
            # flush writes to this file by line, instead of using the system default
            upload_progress_file = open(upload_progress_filepath, "w", buffering=1)

            with display.file_progress(
                file_size=file_size,
                label=os.path.basename(filepath),
                max_label_length=max_label_length,
                progress_file=upload_progress_file,
            ) as progress:

                def upload_part(
                    part_number: int,
                    file_descriptor: int,
                    uploaded_bytes: multiprocessing.Value,
                ) -> tuple[int, int]:
                    """
                    Executed in a worker thread; uploads a single part to s3 and adds its size to
                    the shared counter of total bytes uploaded (under the counter's lock). Note
                    that this is executed lazily by the pool as workers free up.

                    Parts listed in `uploaded_parts` are not sent again, but their size is still
                    added to the counter so the final size check holds for resumed uploads.

                    Returns a tuple of the part number (1-indexed) and how many bytes were uploaded
                    for this part, which is then used by the main thread to update progress bars and
                    log.

                    Any failure is logged and re-raised so the main thread can retry the part.

                    Defined inline for convenience so we have access to the `progress` object.
                    """
                    # File access is 0-indexed like normal, but AWS multipart uploads are 1-indexed
                    api_part_no = part_number + 1

                    try:
                        try:
                            # TODO this is a hacky workaround to the fact that we can't mmap the
                            # entire file in on arm64 architectures (our target for IDAS). This
                            # workaround is to instead have each worker mmap its own chunk, which
                            # adds a lot of overhead to the uploading.
                            #
                            # A better fix is to mmap the file in chunks (in the main thread) and
                            # pass each worker the offset and reference to the mmap, so they can
                            # access their particular chunk.
                            #
                            # CAVEAT: the offset (a multiple of part_size) must be a multiple of
                            # the host's mmap allocation granularity. Presumably get_part_size
                            # guarantees this - TODO confirm.
                            #
                            # https://inscopix.atlassian.net/browse/ID-2656
                            file_data = mmap.mmap(
                                file_descriptor,
                                # Take a chunk of length part_size
                                length=part_size,
                                prot=mmap.PROT_READ,
                                # Start at our part offset
                                offset=part_number * part_size,
                            )
                        except ValueError:
                            # We are at the end of the file, we can't take length part_size, so use
                            # length=0 which lets mmap calculate the appropriate length to the end
                            # of the file
                            file_data = mmap.mmap(
                                file_descriptor,
                                length=0,
                                prot=mmap.PROT_READ,
                                # Start at our part offset
                                offset=part_number * part_size,
                            )

                        try:
                            # Read the size now; the mapping is closed before we return.
                            bytes_uploaded = len(file_data)

                            if api_part_no not in uploaded_parts:
                                presigned_url = presigned_urls.get_url_for_part(
                                    api_part_no
                                )
                                try:
                                    http.put(presigned_url, data=file_data)
                                except requests.exceptions.SSLError:
                                    # In an obscure case where the token used to generate the
                                    # presigned URL expires before the presigned URL does, our
                                    # request will be quite rudely interrupted, causing an
                                    # SSLError. We specifically log here to see if this could be
                                    # caught in the future.
                                    logger.exception(
                                        "Got an SSLError while uploading part %s",
                                        api_part_no,
                                    )
                                    # Invalidate the cache of URLs so the next part to try will
                                    # refresh them, then re-raise: the part was NOT uploaded, so
                                    # it must not be counted, and the main thread has to see the
                                    # failure in order to retry it.
                                    presigned_urls.invalidate()
                                    raise
                        finally:
                            # Release the mapping whether or not the upload worked, instead of
                            # leaving it to garbage collection.
                            file_data.close()

                        # `+=` on a shared Value is a separate read and write, so hold the lock
                        # across both to avoid losing updates from other workers.
                        with uploaded_bytes.get_lock():
                            uploaded_bytes.value += bytes_uploaded

                        return api_part_no, bytes_uploaded
                    except Exception:
                        logger.exception("Failed to upload part %s", api_part_no)
                        raise

                # Keep file open for worker threads to access. It has to stay open until the
                # retries are done as well, since they use the same file descriptor.
                with open(filepath, "rb") as source_file:
                    file_descriptor = source_file.fileno()

                    # Calculate how many parts we have to upload based on the part size. The last part
                    # may be between one and part_size bytes.
                    parts = math.ceil(file_size / part_size)

                    # Submit one task to the pool for each part
                    tasks = {
                        executor.submit(
                            upload_part, part, file_descriptor, uploaded_bytes
                        ): part
                        for part in range(0, parts)
                    }

                    retries = {}

                    # This will wait until all parts are completed (or cancelled). If any parts fail,
                    # `result()` should raise an exception which we handle in the main thread.
                    for future in concurrent.futures.as_completed(tasks):
                        # TODO do this lazily instead of all at once, we can have up to 10k parts
                        try:
                            part_no, bytes_uploaded = future.result()
                            progress.update(
                                bytes_uploaded, (part_no, uploaded_bytes.value)
                            )
                        except Exception:
                            # NOTE: `part` is the 0-indexed part number here
                            part = tasks[future]
                            logger.warning("Part %s failed, retrying", part)
                            retry_future = executor.submit(
                                upload_part, part, file_descriptor, uploaded_bytes
                            )
                            retries[retry_future] = part

                    # Wait on any retries that we submitted, but this time fail loudly
                    for future in concurrent.futures.as_completed(retries):
                        try:
                            part_no, bytes_uploaded = future.result()
                            progress.update(
                                bytes_uploaded, (part_no, uploaded_bytes.value)
                            )
                        except Exception as e:
                            part = retries[future]
                            raise RuntimeError(f"Part {part} failed") from e

                # A failsafe—the API already won't complete the upload, but better to get a specific
                # error here than just have the file stuck in 'Uploading' state forever.
                if uploaded_bytes.value != file_size:
                    raise exceptions.UnhandledError(
                        f"Uploaded bytes {uploaded_bytes.value} doesn't match file size {file_size}",
                    )

            # Complete multipart upload
            complete_upload(file_id, session)
        except KeyboardInterrupt:
            # Abort any in-progress or unstarted upload tasks
            executor.shutdown(wait=False, cancel_futures=True)
            display.abort(
                f"Add --resume-file-id {file_id} to resume this upload.",
                log="Keyboard interrupt during file upload",
                exit_code=display.EXIT_STATUS.UNSPECIFIED_ERROR,
            )
        finally:
            if upload_progress_file is not None:
                upload_progress_file.close()

    # Retrieve the file again to return up-to-date status
    file = typing.cast(
        File,
        http.get(
            f"{session.base_url}/api/{http.IDEAS_API_VERSION}/drs/files/{file_id}/",
            session.headers,
            session.auth,
        ),
    )
    return file


def download_file(
    file_id: str,
    download_dir: typing.Optional[str] = None,
    download_progress_filepath: typing.Optional[pathlib.Path] = None,
    session: typing.Optional[Session] = None,
) -> str:
    """
    Download the file `file_id` to disk and verify it against the ETag sent with the download.

    The file is written to `download_dir` (or the current working directory when that is not
    set), under the name given by the download's content disposition. Only the base name of that
    value is used, so a name containing path separators cannot place the file elsewhere.

    If `download_progress_filepath` is not set, progress is logged to
    `ideas-cli-download_progress.txt` in the system temporary directory. The file is truncated at
    the start of every download.

    The command is aborted through `display.abort` when the file's status is not 3 (Available),
    or when the checksum of the downloaded data does not match; in the latter case the
    downloaded file is removed on a best-effort basis.

    Returns the absolute path of the downloaded file.
    """
    session = session or get_default_session()
    if download_progress_filepath is None:
        download_progress_filepath = (
            pathlib.Path(tempfile.gettempdir()) / "ideas-cli-download_progress.txt"
        )

    file = typing.cast(
        File,
        http.get(
            f"{session.base_url}/api/{http.IDEAS_API_VERSION}/drs/files/{file_id}",
            session.headers,
            session.auth,
        ),
    )

    part_size = file["part_size"]

    if file["status"] != 3:  # Available
        display.abort(
            f"File {file_id} is not available and can not be downloaded.",
            log="File status is not available for download",
            exit_code=display.EXIT_STATUS.UNSPECIFIED_ERROR,
        )

    download_info = typing.cast(
        api_types.FileDownload,
        http.get(
            f"{session.base_url}/api/{http.IDEAS_API_VERSION}/drs/files/{file_id}/download_url/",
            session.headers,
            session.auth,
        ),
    )

    url = download_info["Url"]

    # Use the streamed response as a context manager so its connection is released on every
    # exit path, including errors and aborts.
    with requests.get(url, stream=True) as response:
        response.raise_for_status()

        try:
            content_disposition = response.headers["Content-Disposition"]
        except KeyError:
            # If the s3 backend we are talking to doesn't return the requested
            # content-disposition in the headers, we can retrieve it from the url
            parsed_url = urlparse(url)
            content_disposition = parse_qs(parsed_url.query)[
                "response-content-disposition"
            ][0]

        # The name comes from the remote side; keep only its final path component so it can't
        # point outside the download directory.
        filename = os.path.basename(
            content_disposition.split("filename=")[1].strip('"')
        )

        filepath = (
            filename if not download_dir else os.path.join(download_dir, filename)
        )
        file_size = file["size"]
        etag = response.headers["ETag"].strip('"')

        # Open and truncate existing download progress file, so it only contains logs
        # for this particular download. Set the buffering mode to line-buffered to
        # flush writes to this file by line, instead of using the system default
        with open(
            download_progress_filepath, "w", buffering=1
        ) as download_progress_file:
            with display.file_progress(
                file_size=file_size,
                label=filename,
                max_label_length=len(filename),
                progress_file=download_progress_file,
                operation="download",
            ) as progress:
                etag_checksum = checksum.ETagChecksum(etag)
                downloaded_bytes = 0
                part = 0

                with open(filepath, "wb") as f:
                    for chunk in response.iter_content(chunk_size=part_size):
                        if chunk:  # filter out keep-alive new chunks
                            part += 1
                            f.write(chunk)
                            chunk_length = len(chunk)
                            downloaded_bytes += chunk_length
                            # Advance by what was actually received; the last chunk is usually
                            # shorter than part_size.
                            progress.update(chunk_length, (part, downloaded_bytes))
                            etag_checksum.add_chunk_md5(chunk)

                # The output file is closed at this point, so it can be removed if it turns
                # out to be corrupt.
                try:
                    etag_checksum.verify()
                except exceptions.MismatchedChecksumError as e:
                    # Attempt to delete failed downloaded file to leave no trace behind
                    try:
                        os.remove(filepath)
                    except OSError:
                        # Best effort only; the abort below is what matters.
                        logger.debug(
                            "Could not remove partially downloaded file %s",
                            filepath,
                            exc_info=True,
                        )

                    display.abort(
                        f"Checksum mismatch when downloading {file_id}: {e}",
                        log="Failed to download file due to checksum mismatch",
                        exit_code=display.EXIT_STATUS.UNSPECIFIED_ERROR,
                    )

    return os.path.abspath(filepath)


def validate_tenant(tenant_id: int, session: typing.Optional[Session] = None):
    """
    Return the tenant with id `tenant_id` as listed by `get_tenants` for this session.

    Raises:
        exceptions.UserNotTenantMemberError: if filtering the tenants by this id
            does not give exactly one result.
    """
    session = session or get_default_session()
    matches = list(get_tenants(filters=[("id", tenant_id)], session=session))
    if len(matches) != 1:
        raise exceptions.UserNotTenantMemberError(
            f"User is not a member of tenant with id {tenant_id}"
        )
    return matches[0]
