import torch
import habana_frameworks.torch.core as htcore
import habana_frameworks.torch.utils.experimental as htexp
from .common import *

# Habana device-type identifiers; used as the device part of the keys of EXP_BIAS_SETS.
GAUDI2, GAUDI3 = (
  htexp.synDeviceType.synDeviceGaudi2,
  htexp.synDeviceType.synDeviceGaudi3,
)

EXP_WIDTH={torch.float32: 8, torch.bfloat16: 8, torch.float8_e4m3fn: 4, torch.float8_e5m2: 5}
def get_default_exp_bias(dtype):
  exp_width=EXP_WIDTH[dtype]
  return (2**(exp_width-1)-1)

# Exponent biases available in hardware for each (device, FP8 dtype) pair.
# get_fp8_hw_alligned_scales turns each bias eb into the scale 2**(default_bias - eb).
EXP_BIAS_SETS = {
  (GAUDI2, torch.float8_e4m3fn): [3, 7, 11, 15],
  (GAUDI2, torch.float8_e5m2): [15],
  # NOTE(review): range(63) yields biases 0..62 (63 itself is excluded) — confirm that
  # this is the intended upper bound for Gaudi3.
  (GAUDI3, torch.float8_e4m3fn): range(63),
  (GAUDI3, torch.float8_e5m2): range(63),
}

# Largest finite magnitude of each dtype under its default exponent bias, computed as
# 2**(2**exp_bits - 2 - bias) * (2 - 2**-mantissa_bits). The formula assumes the all-ones
# exponent is reserved (IEEE-style). For float8_e4m3fn this gives 240, not the 448
# reported by torch.finfo. Presumably that is intentional and matches the HW FP8 format; confirm.
MAX_RANGE = {
  dtype: 2**(2**exp_bits - 2 - get_default_exp_bias(dtype)) * (2 - 2**-mantissa_bits)
  for dtype, exp_bits, mantissa_bits in (
    (torch.float32, 8, 23),
    (torch.bfloat16, 8, 7),
    (torch.float8_e4m3fn, 4, 3),
    (torch.float8_e5m2, 5, 2),
  )
}

def get_fullscale(dtype, exp_bias=None):
  """Return the largest finite magnitude of ``dtype`` when encoded with ``exp_bias``.

  Moving the exponent bias from its default value rescales the representable range
  by a power of two, so the result is
  ``MAX_RANGE[dtype] * 2**(default_exp_bias - exp_bias)``.

  Args:
    dtype: a dtype present in both EXP_WIDTH and MAX_RANGE.
    exp_bias: exponent bias to assume; None selects the dtype's default bias.

  Returns:
    The full-scale value under the requested bias.
  """
  default_exp_bias = get_default_exp_bias(dtype)
  # `is None` rather than `== None`: an array-like bias would make `==` elementwise.
  if exp_bias is None:
    exp_bias = default_exp_bias
  return MAX_RANGE[dtype] * 2**(default_exp_bias - exp_bias)

def get_fullscales_by_expbias_set(dtype, expbias_set):
  """Return ``get_fullscale(dtype, exp_bias=b)`` for each bias ``b`` in ``expbias_set``.

  The result is a list in the iteration order of ``expbias_set``.
  """
  return [
    get_fullscale(dtype, exp_bias=bias)
    for bias in expbias_set
  ]

def get_fp8_hw_alligned_scales(dtype, device):
  """Return the hardware-aligned scales of ``dtype`` on ``device``.

  For every exponent bias ``eb`` listed in ``EXP_BIAS_SETS[(device, dtype)]`` the scale is
  ``get_fullscale(dtype, eb) / MAX_RANGE[dtype]``, i.e. ``2**(default_bias - eb)``.

  Returns:
    A list of scales in the order of the bias set, or None when EXP_BIAS_SETS has no
    entry for ``(device, dtype)``.
  """
  exp_bias_set = EXP_BIAS_SETS.get((device, dtype))
  if exp_bias_set is None:
    return None
  max_range = MAX_RANGE[dtype]
  return [fullscale / max_range for fullscale in get_fullscales_by_expbias_set(dtype, exp_bias_set)]

# Per-device exponent granularity of scales: scale_to_pow2_hw rounds log2(scale) up to a
# multiple of this factor. Gaudi2's biases [3, 7, 11, 15] give powers of 16 (factor 4);
# Gaudi3's biases give every power of two in its range (factor 1).
DEVICES_SCALE_FACTORS = {GAUDI2: 4, GAUDI3: 1}
# Hardware-aligned float8_e4m3fn scales available on each device.
FP8_143_SCALES = {
  device: get_fp8_hw_alligned_scales(torch.float8_e4m3fn, device)
  for device in DEVICES_SCALE_FACTORS
}
# Per device: (smallest scale, largest scale, scale factor), as unpacked by scale_to_pow2_hw.
FP8_143_SCALES_TRAITS = {
  device: (min(FP8_143_SCALES[device]), max(FP8_143_SCALES[device]), scale_factor)
  for device, scale_factor in DEVICES_SCALE_FACTORS.items()
}


def calc_maxabs_scale(xmaxabs, fullscale, backoff=1):
  """Return the scale that maps ``xmaxabs`` onto ``fullscale * backoff``.

  Works element-wise when ``xmaxabs`` is a tensor. A ``backoff`` below 1 shrinks the
  usable range, which yields a larger scale.
  """
  usable_range = fullscale * backoff
  return xmaxabs / usable_range

def scale_to_pow2(scale):
  """Round each element of the tensor ``scale`` up to a power of two: ``2**ceil(log2(scale))``.

  Elements that are already exact powers of two are returned unchanged.
  """
  exponent = torch.log2(scale).ceil()
  return 2 ** exponent

# Consider the range of HW-aligned scales: 2^a, 2^(a+1), ..., 2^b (a<b).
# We want to choose a scale s for maxabs m such that 2^a <= s=2^x <= 2^b (for integer a<=x<=b)
# and also 2^(x-1) < m <= 2^x.
# If m>=2^b then s=2^b, hence min(_, 2^b).
# If m<=2^a then s=2^a, hence max(_, 2^a) --> 2^a <= min(max(_,2^a),2^b) <= 2^b.
# If 2^a<m<2^b then m, being positive, can be written as m=2^y (y=log2(m)).
# If y is an integer then y=ceil(y), and we choose x=y, so s=2^x=2^y=2^ceil(y)=2^ceil(log2(m)).
# Otherwise we choose x=ceil(y), so a<=x-1<y<x<=b and s=2^x=2^ceil(y)=2^ceil(log2(m)).
# For Gaudi2 the range is 16^-2..16^1, so we replace 2 with 16 and recall that
# 16 = 2^4 and log16(m)=log2(m)/log2(16)=log2(m)/4, which gives:
# s=16^ceil(log16(m))=(2^4)^ceil(log2(m)/4)=2^(4*ceil(log2(m)/4))=2^(ceil(log2(m)/4)*4)
def scale_to_pow2_hw(scale, device_type):
  """Round ``scale`` up to the nearest scale the given device supports.

  The scale is first rounded up to a power of two. Its exponent is then rounded up to a
  multiple of the device's scale factor, and the result is clamped to
  [min_scale, max_scale]. All three values come from FP8_143_SCALES_TRAITS[device_type].
  The bounds are materialized with ``scale``'s dtype and device.
  """
  pow2_scale = scale_to_pow2(scale)
  lowest, highest, step = FP8_143_SCALES_TRAITS[device_type]
  # ceil(log2(s) / step) * step: the smallest multiple of `step` that is >= log2(s).
  hw_exponent = torch.ceil(torch.log2(pow2_scale) / step) * step
  unclamped = 2**hw_exponent
  lower_bound = torch.tensor(lowest, dtype=scale.dtype, device=scale.device)
  upper_bound = torch.tensor(highest, dtype=scale.dtype, device=scale.device)
  return torch.minimum(torch.maximum(unclamped, lower_bound), upper_bound)

def mmse_scale_multi(x, ref_scale, scales, lp_dtype, hp_dtype):
  """Choose, per channel (last dim of ``x``), the multiplier that minimizes quantization error.

  For every candidate ``s`` in ``scales``:
  1. ``x`` is scaled by ``ref_scale * s``.
  2. It is cast with ``cast_to_fp8_fcn(., lp_dtype)`` and back with ``cast_fcn(., hp_dtype)``.
  3. It is descaled.
  4. The squared error against ``x`` is summed over every axis except the last.

  For each channel the candidate with the smallest error is kept. The comparison is
  strict ``<``, so on ties the earliest candidate wins. scale_fcn / cast_to_fp8_fcn /
  cast_fcn / descale_fcn are not defined in this file (presumably from .common).

  Args:
    x: tensor to evaluate; it is moved to 'hpu'. Channels are along the last dimension.
    ref_scale: per-channel reference scale. Presumably 1-D of length x.shape[-1]; TODO confirm.
    scales: iterable of scalar candidate multipliers.
    lp_dtype: low-precision dtype passed to cast_to_fp8_fcn.
    hp_dtype: dtype of the error/scale buffers and of the cast back.

  Returns:
    ``opt_scale * ref_scale``, i.e. the chosen per-channel multiplier times ``ref_scale``.
    A channel whose error never dropped below the initial +inf (e.g. ``scales`` is empty
    or every error was NaN/inf) keeps the multiplier -1.
  """
  # TODO: SW-176672 move weights to hpu before the scale calculations
  x = x.to('hpu')
  Nch=x.shape[-1]
  # Best error so far per channel; +inf so the first finite error always wins.
  opt_err = torch.ones(Nch, dtype=hp_dtype, device=x.device)*torch.inf
  # Best multiplier so far per channel; -1 marks "no candidate accepted yet".
  opt_scale= torch.ones(Nch, dtype=hp_dtype, device=x.device)*-1
  # Reduce over every axis except the last (channel) axis.
  sum_axis=list(range(x.ndim-1))
  # For a 1-D ref_scale of length Nch this yields a (1, Nch) row that broadcasts over x's
  # last dimension (assumes ref_scale is 1-D — TODO confirm).
  rs=ref_scale.unsqueeze(dim=1).transpose(0, 1)
  for s in scales:
    sv = torch.ones(Nch, dtype=hp_dtype, device=x.device)*s
    xscales = rs*sv
    y=scale_fcn(x, xscales)
    y=cast_to_fp8_fcn(y, lp_dtype)
    htcore.mark_step() # we are measuring the error so we want to avoid fusion of the converts
    y=cast_fcn(y, hp_dtype)
    y=descale_fcn(y, xscales)
    err=torch.sum((x-y)**2, dim=sum_axis)
    # opt_scale is updated before opt_err, so both use the same (pre-update) comparison.
    opt_scale=torch.where(err<opt_err, sv, opt_scale)
    opt_err = torch.where(err < opt_err, err, opt_err)
    # NOTE(review): presumably flushes the accumulated lazy-mode graph once per candidate — confirm.
    htcore.mark_step()
  return opt_scale*ref_scale


def mmse_scale(x, scales, lp_dtype, hp_dtype):
  """Choose the single candidate scale that minimizes the quantization error of ``x``.

  For every candidate ``s`` in ``scales``:
  1. ``x`` is scaled by ``s``.
  2. It is cast with ``cast_to_fp8_fcn(., lp_dtype)`` and back with ``cast_fcn(., hp_dtype)``.
  3. It is descaled.
  4. It is compared to ``x`` with ``torch.norm(x - y)``.

  The comparison is ``<=``, so among equally good candidates the LAST one wins. This
  differs from mmse_scale_multi, which uses strict ``<``.

  Args:
    x: tensor to evaluate; it is moved to 'hpu'.
    scales: iterable of scalar candidate scales.
    lp_dtype: low-precision dtype passed to cast_to_fp8_fcn.
    hp_dtype: dtype of the error buffer and of the cast back.

  Returns:
    The best scale as produced by ``torch.where``. It is a 1-element tensor once any
    candidate has been evaluated, or the int -1 if ``scales`` is empty.
  """
  # TODO: SW-176672 move weights to hpu before the scale calculations
  x = x.to('hpu')
  # Shape (1,) buffer initialized to +inf so the first finite error is always accepted.
  opt_err = torch.ones(1, dtype=hp_dtype, device=x.device)*torch.inf
  # Sentinel returned unchanged when `scales` is empty.
  opt_scale= -1
  for s in scales:
    y=scale_fcn(x, s)
    y=cast_to_fp8_fcn(y, lp_dtype)
    htcore.mark_step() # we are measuring the error so we want to avoid fusion of the converts
    y=cast_fcn(y, hp_dtype)
    y=descale_fcn(y, s)
    err=torch.norm(x-y)
    # opt_scale is updated before opt_err, so both use the same (pre-update) comparison.
    opt_scale=torch.where(err<=opt_err, s, opt_scale)
    opt_err = torch.where(err<=opt_err, err, opt_err)
    # NOTE(review): presumably flushes the accumulated lazy-mode graph once per candidate — confirm.
    htcore.mark_step()
  return opt_scale

def manipulate_scales(scales, func):
  """Return a new module_config whose scales are ``func`` applied to those of ``scales``.

  ``func`` is applied to:
  - each element of ``scales.inputs``;
  - ``scales.outputs`` as a whole;
  - the 'weight' entry of ``scales.params``. A tensor/float weight is transformed directly;
    for a dict of weights each value is transformed and the keys are kept.

  The result is built as ``module_config(inputs, outputs, params)``. Its params hold only
  'weight'. Other keys of ``scales.params``, and a 'weight' of any other type, are not
  carried over.
  NOTE(review): e.g. an int weight scale is silently dropped — confirm that is intended.
  """
  mapped_inputs = [func(input_scale) for input_scale in scales.inputs]
  mapped_params = {}
  if 'weight' in scales.params.keys():
    weight_scale = scales.params['weight']
    if isinstance(weight_scale, (torch.Tensor, float)):
      mapped_params['weight'] = func(weight_scale)
    elif isinstance(weight_scale, dict):
      mapped_params['weight'] = {name: func(value) for name, value in weight_scale.items()}
  return module_config(mapped_inputs, func(scales.outputs), mapped_params)

def invert_scales(scales):
  """Return a copy of ``scales`` (via manipulate_scales) with every scale replaced by 1/scale.

  A list or tuple of scales is inverted element-wise and returned as a list.
  """
  def reciprocal(value):
    if not isinstance(value, (list, tuple)):
      return 1 / value
    return [1 / item for item in value]
  return manipulate_scales(scales, reciprocal)

