import requests
from tqdm import tqdm
import zipfile
import tarfile
import os
import shutil

def _create_if_not_exists(path: str, remove=True) -> None:
    """
    Make sure the directory implied by ``path`` exists.

    ``path`` may name either a directory or a file. The distinction is a
    heuristic: when the last path component contains a ``.`` it is treated as
    a file name and only its parent directory is created; otherwise ``path``
    itself is created as a directory (including missing parents).

    :param path: File or directory path. ``None`` is accepted and ignored.
    :param remove: What to do when ``path`` already exists. If True, the
        existing file or directory tree is deleted before the directory is
        (re)created. If False, an existing ``path`` is left untouched and the
        function returns immediately.
    """
    if path is None:
        return

    if os.path.exists(path):
        if not remove:
            return
        if os.path.isdir(path):
            shutil.rmtree(path)
        else:
            os.remove(path)

    parent, name = os.path.split(path)

    # Decide which directory has to exist *before* checking for an empty
    # parent, so that a bare relative directory name such as "data" is still
    # created. Only a bare file name ("archive.zip") or an empty path leaves
    # nothing to create.
    target_dir = parent if '.' in name else path
    if target_dir == '':
        return

    os.makedirs(target_dir, exist_ok=True)

def _save_response(response, filename: str, total_size_in_bytes: int, block_size: int = 1024) -> None:
    """
    Stream the body of ``response`` into ``filename``.

    Whatever currently sits at ``filename`` is removed first and the parent
    directory is created when missing. A progress bar is shown when the
    announced size exceeds 1024 bytes; in that case the number of received
    bytes is also compared against the announced size.

    If the transfer fails, is interrupted, or the size check fails, the
    partially written file is deleted so that a later call does not mistake
    it for a complete download.

    :param response: Streaming ``requests`` response to read from.
    :param filename: Destination file path.
    :param total_size_in_bytes: Size announced by the server (0 if unknown).
    :param block_size: Chunk size used while reading the response body.
    """
    # The target is known to be a file here, so it is prepared explicitly
    # instead of guessing file-vs-directory from the presence of a dot.
    if os.path.exists(filename):
        if os.path.isdir(filename):
            shutil.rmtree(filename)
        else:
            os.remove(filename)

    parent = os.path.dirname(filename)
    if parent:
        os.makedirs(parent, exist_ok=True)

    progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) if total_size_in_bytes > 1024 else None

    try:
        with open(filename, 'wb') as f:
            for data in response.iter_content(block_size):
                if progress_bar is not None:
                    progress_bar.update(len(data))
                f.write(data)

        if progress_bar is not None and total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
            raise Exception(
                f"ERROR, something went wrong during download!\nExpected {total_size_in_bytes} Bytes, got {progress_bar.n} Bytes.")
    except BaseException:
        # Includes KeyboardInterrupt: never leave a truncated file behind.
        if os.path.isfile(filename):
            os.remove(filename)
        raise
    finally:
        if progress_bar is not None:
            progress_bar.close()


def download(url: str, filename: str, unzip=True, unzip_path: str = None, force_download=False, force_unzip=False, clean=False) -> str:
    """
    Download a file from a OneDrive url and optionally extract it.

    ``url`` may be a share link, an ``embed`` link, or a whole ``<iframe>``
    snippet (the link is then taken from its ``src`` attribute). Embed links
    are rewritten to download links; any other link gets its query string
    replaced by ``?download=1``.

    :param url: The url to download from (should end with '?download=1').
    :param filename: The filename to save the file as. If its last component
        is empty or contains no '.', it is treated as a directory and the
        file name derived from the final (redirected) url is appended.
    :param unzip: Whether to extract the file when it ends with '.zip' or '.tar.gz'.
    :param unzip_path: The path to extract to. Default is the directory of the downloaded file.
    :param force_download: Whether to force download the file even if it already exists.
    :param force_unzip: Whether to force extraction and overwrite files that already exist.
    :param clean: Whether to delete the downloaded archive after extracting it.

    :returns: Path of the downloaded file, or the extraction directory if the file was extracted.
    :raises Exception: If the request, the download or the extraction fails.
        The original error is printed and attached as the cause.
    """
    assert url is not None, "URL cannot be None!"
    assert filename is not None, "Parameter filename cannot be None!"

    embed = False

    if 'iframe' in url:
        url = url.split('src="')[1].split('"')[0]

    if 'embed' in url:
        # Only the first occurrence (the path segment) is rewritten, so query
        # values that happen to contain "embed" are left intact.
        url = url.replace('embed', 'download', 1)
        embed = True
    elif not url.endswith("?download=1"):
        # replace everything after the first ? with ?download=1
        url = url.split("?")[0] + "?download=1"

    try:
        # The context manager releases the connection even when the body is
        # never read (file already present) or an error occurs.
        with requests.get(url, stream=True) as response:
            # The file name is derived from the response, so an error
            # response must not be trusted or saved to disk.
            response.raise_for_status()

            final_url = response.url
            if not embed and 'id=' in final_url and '&' in final_url:
                fname = final_url.split("id=")[-1].split("&")[0].split("%2F")[-1].split('?')[0]
            else:
                fname = final_url.split('/')[-1].split('?')[0]

            if '.' not in os.path.basename(filename):
                filename = os.path.join(filename, fname)

            if os.path.basename(filename) == '':
                # Without a file name the target would be the directory
                # itself, which must not be wiped and cannot be opened.
                raise ValueError(f"Could not determine a file name from '{final_url}'.")

            total_size_in_bytes = int(response.headers.get('content-length', 0))

            if force_download or not os.path.exists(filename):
                _save_response(response, filename, total_size_in_bytes)

        ret_path = filename

        # unzip file if necessary
        if unzip and filename.endswith((".zip", ".tar.gz")):
            print("Unzipping file...")
            unzip_path = unzip_path if unzip_path is not None else os.path.split(filename)[0]
            # Only wipe the extraction directory when its real path does not
            # occur inside the archive's real path, so the archive itself is
            # not deleted along with it.
            clean_unzip_path = force_unzip and os.path.realpath(unzip_path) not in os.path.realpath(filename)
            ret_path = unzip_path

            _create_if_not_exists(unzip_path, remove=clean_unzip_path)

            if force_unzip:
                print("Warning: overwriting existing files!")

            if filename.endswith(".zip"):
                with zipfile.ZipFile(filename, 'r') as zip_ref:
                    members = zip_ref.namelist()
                    for member in tqdm(iterable=members, total=len(members), desc="Extracting files"):
                        if force_unzip or not os.path.exists(os.path.join(unzip_path, member)):
                            zip_ref.extract(member=member, path=unzip_path)
            else:
                # Downloaded archives are untrusted: use the 'data' extraction
                # filter where this Python version provides it, so members
                # cannot escape unzip_path.
                extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
                with tarfile.open(filename, 'r:gz') as tar_ref:
                    members = tar_ref.getnames()
                    for member in tqdm(iterable=members, total=len(members), desc="Extracting files"):
                        if force_unzip or not os.path.exists(os.path.join(unzip_path, member)):
                            tar_ref.extract(member=member, path=unzip_path, **extract_kwargs)

            if clean:
                os.remove(filename)

        return ret_path
    except Exception as e:
        print(e)
        raise Exception("ERROR, something went wrong, see error above!") from e


if __name__ == "__main__":
    # Manual smoke test with two sample inputs: a SharePoint share link
    # (currently unused) and a OneDrive <iframe> embed snippet.
    sharepoint_link = 'https://unimore365-my.sharepoint.com/:u:/g/personal/215580_unimore_it/Eb4BgDZ5g_1Imuwz_PJAmdgBc8k9I_P5p0Y-A97edhsxIw?e=WmlZZc'
    iframe_snippet = '<iframe src="https://onedrive.live.com/embed?cid=D3924A2D106E0039&resid=D3924A2D106E0039%21110&authkey=AIEfi5nlRyY1yaE" width="98" height="120" frameborder="0" scrolling="no"></iframe>'

    print('Downloading dataset')
    # Save into ./tmp/ and delete the archive once it has been extracted.
    result_path = download(iframe_snippet, filename="./tmp/", clean=True)
    print(result_path)