import math
from typing import Optional, Tuple, Union

import torch


def gaudi_vit_self_attention_forward(
    self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
    """
    Self-attention forward pass adapted from
    transformers.models.vit.modeling_vit.ViTSelfAttention.forward.

    The single deviation from that method is where the scaling happens: the transposed keys are divided by
    sqrt(attention_head_size) *before* the query/key matmul instead of the division being a separate step.
    This ordering gives better performance on HPU.

    Args:
        hidden_states: input passed to the `query`, `key` and `value` projections of `self`.
        head_mask: optional tensor multiplied into the attention probabilities after dropout; skipped when `None`.
        output_attentions: when true, the attention probabilities are returned alongside the context.

    Returns:
        `(context,)`, or `(context, attention_probs)` when `output_attentions` is set. `context` has its two
        trailing dimensions merged into a single one of size `self.all_head_size`.
    """

    # The query projection is computed first and only reshaped after the key/value tensors are ready.
    query_proj = self.query(hidden_states)

    keys = self.transpose_for_scores(self.key(hidden_states))
    values = self.transpose_for_scores(self.value(hidden_states))
    queries = self.transpose_for_scores(query_proj)

    # Raw attention scores are the query/key dot products.
    # The scale is folded into the key operand of the matmul because that performs better on HPU.
    scale = math.sqrt(self.attention_head_size)
    scaled_keys_t = keys.transpose(-1, -2) / scale
    scores = torch.matmul(queries, scaled_keys_t)

    # Turn the scores into probabilities over the last dimension.
    probs = torch.nn.functional.softmax(scores, dim=-1)

    # Dropout here removes whole tokens from being attended to; unusual-looking,
    # but it follows the original Transformer paper.
    probs = self.dropout(probs)

    # Optional per-head masking.
    if head_mask is not None:
        probs = probs * head_mask

    # Weighted sum of values, then swap dims 1 and 2 and collapse the last two dims into `all_head_size`.
    context = torch.matmul(probs, values).permute(0, 2, 1, 3).contiguous()
    merged_shape = context.size()[:-2] + (self.all_head_size,)
    context = context.view(merged_shape)

    if output_attentions:
        return (context, probs)
    return (context,)
