import errno
import hashlib
import os
import re
import shutil
import sys
import tempfile
import zipfile
import tarfile
import warnings
from urllib.parse import urlparse
from urllib.request import Request, urlopen
from tqdm import tqdm
from typing import Any, Callable, Optional, Tuple

import oneflow as flow
import oneflow.nn as nn

# Captures the (possibly empty) run of lowercase hex digits that immediately
# precedes an underscore. load_state_dict_from_url searches the cached file
# name with this pattern to obtain the expected SHA256 prefix when
# ``check_hash`` is set.
HASH_REGEX = re.compile(r"([a-f0-9]*)_")


def _is_legacy_tar_format(filename):
    """Report whether ``filename`` is readable as a tar archive.

    Thin wrapper over :func:`tarfile.is_tarfile`; the result is returned
    unchanged.
    """
    looks_like_tar = tarfile.is_tarfile(filename)
    return looks_like_tar


def _legacy_tar_load(filename, model_dir, map_location):
    """Extract a tar checkpoint into ``model_dir`` and load it with ``flow.load``.

    The path handed to ``flow.load`` is ``model_dir`` joined with the name of
    the first member of the archive.

    Args:
        filename (string): path of the tar archive.
        model_dir (string): directory to extract into; created (including
            missing parents) when it does not exist.
        map_location: accepted for signature parity with the zip loader but
            not used; it is not forwarded to ``flow.load``.

    Returns:
        Whatever ``flow.load`` returns for the extracted path.

    Raises:
        RuntimeError: if the archive contains no members.
    """
    with tarfile.open(filename) as f:
        members = f.getnames()
        if not members:
            # Fail with a clear message instead of an IndexError on members[0].
            raise RuntimeError('tar archive "{}" is empty'.format(filename))
        extracted_file = os.path.join(model_dir, members[0])
        os.makedirs(model_dir, exist_ok=True)
        # The archive usually comes from a download, so refuse members that
        # would land outside model_dir when the interpreter supports
        # extraction filters; older interpreters keep the previous behaviour.
        if hasattr(tarfile, "data_filter"):
            f.extractall(model_dir, filter="data")
        else:
            f.extractall(model_dir)
    return flow.load(extracted_file)


def _is_legacy_zip_format(filename):
    """Report whether ``filename`` is recognised as a zip archive.

    Thin wrapper over :func:`zipfile.is_zipfile`; the result is returned
    unchanged.
    """
    looks_like_zip = zipfile.is_zipfile(filename)
    return looks_like_zip


def _legacy_zip_load(filename, model_dir, map_location):
    """Extract a zip checkpoint into ``model_dir`` and load it with ``flow.load``.

    The path handed to ``flow.load`` is ``model_dir`` joined with the name of
    the first entry of the archive.

    Args:
        filename (string): path of the zip archive.
        model_dir (string): directory to extract into; created (including
            missing parents) when it does not exist.
        map_location: currently unused, see the TODO below.

    Returns:
        Whatever ``flow.load`` returns for the extracted path.

    Raises:
        RuntimeError: if the archive contains no entries.
    """
    # Note: extractall() overwrites existing files and creates any directories
    #       the members need, so nothing has to be cleaned up or pre-created
    #       below model_dir. Pre-creating the first member's path as a
    #       directory would break archives whose first entry is a regular file.
    with zipfile.ZipFile(filename) as f:
        members = f.infolist()
        if not members:
            # Fail with a clear message instead of an IndexError on members[0].
            raise RuntimeError('zip archive "{}" is empty'.format(filename))
        extracted_file = os.path.join(model_dir, members[0].filename)
        os.makedirs(model_dir, exist_ok=True)
        f.extractall(model_dir)
    # TODO: flow.load doesn't have map_location
    # return flow.load(extracted_file, map_location=map_location)
    return flow.load(extracted_file)


def load_state_dict_from_url(
    url,
    model_dir="./checkpoints",
    map_location=None,
    progress=True,
    check_hash=False,
    file_name=None,
):
    r"""Loads the OneFlow serialized object at the given URL.

    If the downloaded file is a zip or tar archive, it is automatically
    extracted into `model_dir` before loading.

    If the object is already present in `model_dir`, it is not downloaded
    again; the cached copy is deserialized and returned.

    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object.
            A leading ``~`` is expanded and the directory is created if needed.
        map_location (optional): not supported yet; passing a value only emits a warning.
        progress (bool, optional): whether or not to display a progress bar to stderr.
            Default: ``True``
        check_hash(bool, optional): If True, the expected hash prefix is taken from the
            file name (``file_name`` when given, otherwise the last path component of
            ``url``) as the run of lowercase hex digits immediately preceding an
            underscore, and the SHA256 of the downloaded content must start with it.
            If the file name contains no such pattern, no verification is performed.
            The check only happens when the file is actually downloaded.
            Default: ``False``
        file_name (string, optional): name for the downloaded file. Filename from `url` will be used if not set

    Returns:
        The object produced by ``flow.load`` for the cached (or extracted) file.
    """

    if map_location is not None:
        warnings.warn("Map location is not supported yet.")

    # download_url_to_file expands "~" in its destination; expand it here as
    # well so that directory creation, the cache lookup and the download
    # target all refer to the same location.
    model_dir = os.path.expanduser(model_dir)
    os.makedirs(model_dir, exist_ok=True)

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)

    # Zip is probed before tar; anything else is loaded directly.
    if _is_legacy_zip_format(cached_file):
        return _legacy_zip_load(cached_file, model_dir, map_location)
    if _is_legacy_tar_format(cached_file):
        return _legacy_tar_load(cached_file, model_dir, map_location)
    return flow.load(cached_file)


def download_url_to_file(url, dst, hash_prefix=None, progress=True):
    r"""Download object at the given URL to a local path.

    The payload is streamed into a temporary file created next to ``dst`` and
    moved into place only after the download (and the optional hash check)
    succeeded, so an existing file at ``dst`` is never replaced by a broken
    download. The temporary file is removed on failure.

    Args:
        url (string): URL of the object to download
        dst (string): Full path where object will be saved, e.g. `/tmp/temporary_file`.
            A leading ``~`` is expanded.
        hash_prefix (string, optional): If not None, the SHA256 downloaded file should start with `hash_prefix`.
            Default: ``None``
        progress (bool, optional): whether or not to display a progress bar to stderr
            Default: ``True``

    Raises:
        RuntimeError: if ``hash_prefix`` is given and the SHA256 digest of the
            downloaded content does not start with it.
    """
    dst = os.path.expanduser(dst)
    dst_dir = os.path.dirname(dst)
    sha256 = hashlib.sha256() if hash_prefix is not None else None

    # TODO: if there is a backend requirements, we should add headers in Request module
    # The response is used as a context manager so the connection is always
    # released, including when the download or the hash check fails.
    with urlopen(Request(url)) as u:
        content_length = u.info().get_all("Content-Length")
        # Without a Content-Length header the progress bar has no total.
        file_size = int(content_length[0]) if content_length else None

        f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
        try:
            with tqdm(
                total=file_size,
                disable=not progress,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as pbar:
                while True:
                    buffer = u.read(8192)
                    if not buffer:
                        break
                    f.write(buffer)
                    if sha256 is not None:
                        sha256.update(buffer)
                    pbar.update(len(buffer))

            # Flush everything to disk before the file is moved.
            f.close()
            if sha256 is not None:
                digest = sha256.hexdigest()
                if digest[: len(hash_prefix)] != hash_prefix:
                    raise RuntimeError(
                        'invalid hash value (expected "{}", got "{}")'.format(
                            hash_prefix, digest
                        )
                    )
            shutil.move(f.name, dst)
        finally:
            f.close()
            # Still present only when the move did not happen.
            if os.path.exists(f.name):
                os.remove(f.name)


def named_apply(
    fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
) -> nn.Module:
    """Call ``fn(module=..., name=...)`` on the descendants of ``module``.

    Descendants are reached recursively through ``named_children()`` and their
    names are the dot-joined path from ``name``. ``include_root`` only decides
    whether ``module`` itself is passed to ``fn``; every descendant always is.

    Args:
        fn: callable invoked with the keyword arguments ``module`` and ``name``.
        module: root of the traversal.
        name: name prefix of ``module``; an empty prefix adds no leading dot.
        depth_first: when true a module is visited after its children,
            otherwise before them.
        include_root: whether ``fn`` is also applied to ``module`` itself.

    Returns:
        ``module``, unchanged by this function itself.
    """
    visit_before_children = include_root and not depth_first
    visit_after_children = include_root and depth_first

    if visit_before_children:
        fn(module=module, name=name)

    for key, submodule in module.named_children():
        if name:
            qualified = ".".join([name, key])
        else:
            qualified = key
        named_apply(
            fn=fn,
            module=submodule,
            name=qualified,
            depth_first=depth_first,
            include_root=True,
        )

    if visit_after_children:
        fn(module=module, name=name)
    return module


def named_modules(module: nn.Module, name="", depth_first=True, include_root=False):
    """Yield ``(name, module)`` pairs for the descendants of ``module``.

    Descendants are reached recursively through ``named_children()`` and their
    names are the dot-joined path from ``name``. ``include_root`` only decides
    whether ``module`` itself is yielded; every descendant always is.

    Args:
        module: root of the traversal.
        name: name prefix of ``module``; an empty prefix adds no leading dot.
        depth_first: when true a module is yielded after its children,
            otherwise before them.
        include_root: whether the pair for ``module`` itself is yielded.
    """
    emit_before_children = include_root and not depth_first
    emit_after_children = include_root and depth_first

    if emit_before_children:
        yield name, module

    for key, submodule in module.named_children():
        if name:
            qualified = ".".join([name, key])
        else:
            qualified = key
        yield from named_modules(
            module=submodule,
            name=qualified,
            depth_first=depth_first,
            include_root=True,
        )

    if emit_after_children:
        yield name, module
