# MIT License
#
# Copyright (c) 2023 Transmute AI Lab
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import math
from typing import List
import torch
from torch.nn import BatchNorm2d, Conv2d, Linear
import numpy as np
import pandas as pd

# Background on replacing a parameter's `.data` / weights in place: https://discuss.pytorch.org/t/how-to-delete-every-grad-after-training/63644/8

# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    # training / inspection helpers
    'update_bn_grad', 'summary_model',
    # in-place layer surgery
    'is_depthwise_conv2d', 'prune_conv2d', 'prune_bn2d', 'prune_fc',
    # channel selection
    'cal_threshold_by_bn2d_weights', 'mask2idxes', 'top_k_idxes',
    # integer rounding
    'ceil', 'round_up_to_power_of_2',
]


def update_bn_grad(model, s=0.0001):
    """Add the L1 sparsity sub-gradient ``s * sign(gamma)`` to every BatchNorm2d scale.

    Used for sparsity training of BN layers through their gamma (weight)
    coefficients. Call it after ``loss.backward()`` and before the optimizer
    step, so the extra term is included in the update.

    Args:
        model: nn.Module whose BatchNorm2d sub-modules are updated in place.
        s: weight of the sparsity penalty.

    Returns:
        None. Gradients are modified in place.
    """
    for m in model.modules():
        if not isinstance(m, BatchNorm2d):
            continue
        weight = m.weight
        # affine=False layers have no gamma at all; frozen or unused layers
        # have no gradient to add to. Skip both instead of crashing.
        if weight is None or weight.grad is None:
            continue
        with torch.no_grad():
            weight.grad.add_(s * torch.sign(weight))


def summary_model(model: torch.nn.Module,
                  prune_related_layer_types=(Conv2d, BatchNorm2d, Linear)):
    """Print, and return, the layers of ``model`` that are relevant to pruning.

    Args:
        model: module whose ``named_modules()`` are scanned.
        prune_related_layer_types: layer classes to list. Matching uses the
            exact type, so subclasses of these classes are not listed.

    Returns:
        pd.DataFrame with columns ``name`` and ``module``, one row per
        matching layer.
    """
    info = [
        {'name': name, 'module': module}
        for name, module in model.named_modules()
        if type(module) in prune_related_layer_types
    ]

    df = pd.DataFrame(info)
    # Keep the column layout stable even when no layer matched.
    df = df.reindex(columns=['name', 'module'])

    try:
        text = df.to_markdown()
    except ImportError:
        # DataFrame.to_markdown needs the optional 'tabulate' package;
        # fall back to plain-text output instead of failing.
        text = df.to_string()
    print(text)
    return df


def is_depthwise_conv2d(module: torch.nn.Module) -> bool:
    """Return True for a Conv2d whose in_channels and out_channels both equal its groups.

    Only the one-filter-per-input-channel form (channel multiplier 1) is
    recognised. Any non-Conv2d module yields False.
    """
    if not isinstance(module, torch.nn.Conv2d):
        return False
    groups = module.groups
    return module.in_channels == groups and module.out_channels == groups


def prune_bn2d(module: BatchNorm2d, keep_idxes):
    """Keep only the channels ``keep_idxes`` of a BatchNorm2d layer, in place.

    Args:
        module: layer to prune. Optional parts are skipped when absent:
            weight/bias for ``affine=False``, running stats for
            ``track_running_stats=False``.
        keep_idxes: channel indices to keep (a list or an integer array).
    """
    module.num_features = len(keep_idxes)

    if module.weight is not None:
        # Recreate the Parameter (its shape changes) but keep its trainability.
        module.weight = torch.nn.Parameter(
            module.weight.data[keep_idxes],
            requires_grad=module.weight.requires_grad)
    if module.bias is not None:
        module.bias = torch.nn.Parameter(
            module.bias.data[keep_idxes],
            requires_grad=module.bias.requires_grad)

    # Assigning a tensor to a registered buffer name replaces the buffer.
    if module.running_mean is not None:
        module.running_mean = module.running_mean[keep_idxes]
    if module.running_var is not None:
        module.running_var = module.running_var[keep_idxes]


def prune_conv2d(module: Conv2d, in_keep_idxes=None, out_keep_idxes=None):
    """Keep only selected input/output channels of a Conv2d, in place.

    Args:
        module: layer to prune. For a depthwise conv (see
            ``is_depthwise_conv2d``) the weight is sliced along dim 0 only,
            and ``groups`` is set to the number of kept channels.
        in_keep_idxes: input-channel indices to keep. ``None`` keeps all of
            them; for a depthwise conv it mirrors ``out_keep_idxes``.
        out_keep_idxes: output-channel indices to keep. ``None`` keeps all of
            them; for a depthwise conv it mirrors ``in_keep_idxes``.

    Returns:
        ``(in_keep_idxes, out_keep_idxes)`` with any defaults filled in.

    Raises:
        AssertionError: if the index lists do not fit the layer.
    """
    is_depthwise = is_depthwise_conv2d(module)

    if is_depthwise:
        # Each depthwise filter reads exactly one input channel, so both lists
        # describe the same channels. A missing side mirrors the other instead
        # of defaulting to range(weight.shape[1]) == range(1).
        if in_keep_idxes is None and out_keep_idxes is None:
            in_keep_idxes = list(range(module.in_channels))
            out_keep_idxes = list(range(module.out_channels))
        elif in_keep_idxes is None:
            in_keep_idxes = out_keep_idxes
        elif out_keep_idxes is None:
            out_keep_idxes = in_keep_idxes
        assert len(in_keep_idxes) == len(out_keep_idxes)
    else:
        # NOTE(review): weight.shape[1] is in_channels // groups, so grouped
        # (non-depthwise) convs are not pruned correctly by this path.
        if in_keep_idxes is None:
            in_keep_idxes = list(range(module.weight.shape[1]))
        if out_keep_idxes is None:
            out_keep_idxes = list(range(module.weight.shape[0]))
        assert (
            len(in_keep_idxes) <= module.weight.shape[1]
        ), f'len(in_keep_idxes): {len(in_keep_idxes)}, module.weight.shape[1]: {module.weight.shape[1]}'

    assert (
        len(out_keep_idxes) <= module.weight.shape[0]
    ), f'len(out_keep_idxes): {len(out_keep_idxes)}, module.weight.shape[0]: {module.weight.shape[0]}'

    # Slice in one pass and wrap once; keep the original trainability.
    weight = module.weight
    new_weight = weight.data[out_keep_idxes]
    if not is_depthwise:
        new_weight = new_weight[:, in_keep_idxes]
    module.weight = torch.nn.Parameter(new_weight,
                                       requires_grad=weight.requires_grad)

    if module.bias is not None:
        module.bias = torch.nn.Parameter(
            module.bias.data[out_keep_idxes],
            requires_grad=module.bias.requires_grad)

    # Update metadata only after validation, so a failed call leaves it intact.
    if is_depthwise:
        module.groups = len(in_keep_idxes)
    module.out_channels = len(out_keep_idxes)
    module.in_channels = len(in_keep_idxes)

    return in_keep_idxes, out_keep_idxes


def prune_fc(module: Linear,
             keep_idxes: List[int],
             bn_num_channels: int = None):
    """Keep only selected input features of a Linear layer, in place.

    Args:
        module: Linear layer to prune. Only ``weight`` columns change; the
            bias is per output feature and is left untouched.
        keep_idxes: indices to keep. Without ``bn_num_channels`` they index
            input features directly. With it they index channels of the
            previous BN layer, and each channel is expanded to its block of
            ``in_features // bn_num_channels`` consecutive features.
        bn_num_channels: channel count of the previous BN layer (before
            pruning). ``in_features`` must be a multiple of it.

    Returns:
        The input-feature indices that were kept.
    """
    if bn_num_channels is not None:
        assert module.in_features % bn_num_channels == 0

        channel_step = module.in_features // bn_num_channels

        # Channel c owns features [c * step, (c + 1) * step). This presumably
        # matches a channel-major flatten of the BN output (e.g. NCHW);
        # confirm against the caller.
        expanded = []
        for idx in keep_idxes:
            start = int(idx) * channel_step
            expanded.extend(range(start, start + channel_step))
        keep_idxes = expanded

    weight = module.weight
    module.in_features = len(keep_idxes)
    # Recreate the Parameter (its shape changes) but keep its trainability.
    module.weight = torch.nn.Parameter(weight.data[:, keep_idxes],
                                       requires_grad=weight.requires_grad)
    return keep_idxes


def cal_threshold_by_bn2d_weights(bn2d_list: List[BatchNorm2d],
                                  sparsity: float):
    """Return a global |gamma| threshold pooled over several BatchNorm2d layers.

    All absolute BN scale factors are gathered on the CPU and sorted in
    ascending order. The value at 0-based position ``int(total * sparsity)``
    is returned. Ignoring ties, that many channels lie strictly below it.

    Args:
        bn2d_list: BatchNorm2d layers whose weights are pooled.
        sparsity: fraction of channels to prune; must satisfy 0 < sparsity < 1.

    Returns:
        0-d CPU tensor holding the threshold.
    """
    assert 0 < sparsity < 1

    magnitudes = torch.cat([bn.weight.data.cpu().abs() for bn in bn2d_list])
    cut = int(magnitudes.shape[0] * sparsity)
    return torch.sort(magnitudes).values[cut]


def mask2idxes(mask):
    """Return the positions where ``mask`` is truthy, as a numpy array.

    For a 1-D mask this is a 1-D array of indices. ``np.argwhere`` yields one
    row per hit, and squeezing collapses a single hit to a 0-d array, which is
    lifted back to shape (1,) so callers always get something indexable.
    """
    positions = np.argwhere(mask).squeeze()
    return positions.reshape(1) if positions.size == 1 else positions


def top_k_idxes(module, ratio: float = None, k: int = None):
    """Return indices of the largest-magnitude entries of ``module.weight``.

    The weight is flattened before selection. Indices come back in
    descending order of |weight| (``torch.topk`` sorts by default).

    Args:
        module: module with a ``weight`` tensor.
        ratio: if given, overrides ``k`` with ``int(weight.shape[0] * ratio)``.
            That value is raised to at least 2, but never exceeds the number
            of weight elements.
        k: number of indices to return; must be positive when given.

    Returns:
        numpy int array of flat indices.
    """
    if k is not None:
        assert k > 0
    assert ratio is not None or k is not None

    weights = module.weight.detach().abs()
    flat = weights.view(-1)

    if ratio is not None:
        # The floor of 2 must not exceed what exists (e.g. a 1-channel layer),
        # otherwise torch.topk raises.
        k = min(max(int(weights.shape[0] * ratio), 2), flat.numel())

    idxes = torch.topk(flat, k, largest=True)[1]
    return idxes.cpu().numpy()


def ceil(num: int, val: int) -> int:
    """Round ``num`` up to the nearest multiple of ``val``.

    Uses exact integer floor division rather than ``math.ceil(num / val)``.
    The float division silently loses precision once ``num`` exceeds 2**53.

    Example: ``ceil(10, 8) == 16`` and ``ceil(16, 8) == 16``.

    Raises:
        ZeroDivisionError: if ``val`` is 0.
    """
    # Bit-trick equivalent, valid only when val is a power of two:
    #   (num + (val - 1)) & ~(val - 1)
    return int(-(-num // val) * val)


def round_up_to_power_of_2(num: int) -> int:
    """Return the smallest power of two that is >= ``num`` (1 for num <= 1).

    The classic 32-bit "smear the high bit" trick is wrong for num > 2**32,
    because Python ints are unbounded. ``int.bit_length`` is exact for any
    size.

    Raises:
        TypeError: if ``num`` is not an integer type.
    """
    import operator

    # operator.index accepts int-like values (e.g. numpy ints) and rejects floats.
    num = operator.index(num)
    if num <= 1:
        return 1
    return 1 << (num - 1).bit_length()
