# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license

import os
import platform

import torch

from .. import __version__
from . import logger as LOGGER


def select_device(device="", batch=0, newline=False, verbose=True):
    """Select the torch device to run on and optionally log a summary of it.

    Args:
        device: Requested device. Accepted forms are None or "" (auto-select:
            CUDA if torch reports it available, otherwise CPU), "cpu", "mps",
            a single index such as 0 or "0", or several indices such as
            "0,1,2,3". Decorated forms like "cuda:0", "(0, 1)" or "[0, 1]" are
            normalised to the bare comma-separated index list.
        batch: Batch size. When more than one GPU is requested and batch > 0,
            it must be a multiple of the number of requested GPUs.
        newline: If True, keep the trailing newline of the logged summary;
            if False, strip it before logging.
        verbose: If True, log the summary through LOGGER.info.

    Returns:
        torch.device: torch.device("cuda:0"), torch.device("mps") or
        torch.device("cpu").

    Raises:
        ValueError: If explicit CUDA device(s) are requested but torch reports
            CUDA as unavailable or sees fewer devices than requested, or if
            ``batch`` is not divisible by the number of requested GPUs.

    Note:
        This function mutates ``os.environ["CUDA_VISIBLE_DEVICES"]``: it is set
        to "-1" for "cpu"/"mps" and to the normalised index list when explicit
        CUDA devices are requested.
    """
    s = f"Yolo Tracking v{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} "

    # Normalise to a bare string: 'cuda:0' -> '0', '(0, 1)' -> '0,1', None -> ''.
    device = str(device).lower()
    for remove in "cuda:", "none", "(", ")", "[", "]", "'", " ":
        device = device.replace(remove, "")
    # Drop empty entries left by stray commas, e.g. the tuple (0,) -> '0,' -> '0'.
    device = ",".join(part for part in device.split(",") if part)

    cpu = device == "cpu"
    mps = device == "mps"  # Apple Metal Performance Shaders (MPS)
    if cpu or mps:
        # Force torch.cuda.is_available() = False.
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    elif device:  # explicit CUDA device(s) requested
        # Remember the previous value so the error message can report it.
        visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
        # Must be set before torch.cuda.is_available() is queried.
        os.environ["CUDA_VISIBLE_DEVICES"] = device
        # Count comma-separated entries, not characters, so multi-digit
        # indices such as '10' count as a single device.
        requested = len(device.split(","))
        if not (torch.cuda.is_available() and torch.cuda.device_count() >= requested):
            install = (
                "See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no "
                "CUDA devices are seen by torch.\n"
                if torch.cuda.device_count() == 0
                else ""
            )
            raise ValueError(
                f"Invalid CUDA 'device={device}' requested."
                f" Use 'device=cpu' or pass valid CUDA device(s) if available,"
                f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
                f"\ntorch.cuda.is_available(): {torch.cuda.is_available()}"
                f"\ntorch.cuda.device_count(): {torch.cuda.device_count()}"
                f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"
                f"{install}"
            )

    if not cpu and not mps and torch.cuda.is_available():  # prefer GPU if available
        # With no explicit request, report only the first device.
        devices = device.split(",") if device else ["0"]
        n = len(devices)  # device count
        if n > 1 and batch > 0 and batch % n != 0:
            # Batch size must be divisible by the number of requested GPUs.
            raise ValueError(
                f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
                f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}."
            )
        space = " " * (len(s) + 1)  # indent used for the 2nd and later device lines
        for i, d in enumerate(devices):
            # Properties are looked up by position i; d is the label that was requested.
            p = torch.cuda.get_device_properties(i)
            s += f"{'' if i == 0 else space}CUDA:{d} ({p.name}, {p.total_memory / (1 << 20):.0f}MiB)\n"  # bytes to MiB
        arg = "cuda:0"
    elif mps:
        s += "MPS\n"
        arg = "mps"
    else:  # revert to CPU
        s += "CPU\n"
        arg = "cpu"

    if verbose:
        LOGGER.info(s if newline else s.rstrip())
    return torch.device(arg)
