from typing import Optional
from typing import Tuple

import torch
from torch import Tensor
from torch.nn import Linear
from torch.nn import Module
from torch.nn.init import constant_
from torch.nn.init import xavier_normal_
from torch.nn.init import xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter

from torch.nn import functional as F
from AR.modules.patched_mha_with_cache_onnx import multi_head_attention_forward_patched


class MultiheadAttention(Module):
    """Multi-head attention whose forward pass delegates to the ONNX-export
    kernel ``multi_head_attention_forward_patched``, threading an attention
    cache through to the patched call."""

    __constants__ = ["batch_first"]
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=False,
        linear1_cls=Linear,
        linear2_cls=Linear,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None

        if linear1_cls == Linear:
            if not self._qkv_same_embed_dim:
                # Separate q/k/v projection weights when key/value dims differ
                # from embed_dim.
                self.q_proj_weight = Parameter(
                    torch.empty((embed_dim, embed_dim), **factory_kwargs)
                )
                self.k_proj_weight = Parameter(
                    torch.empty((embed_dim, self.kdim), **factory_kwargs)
                )
                self.v_proj_weight = Parameter(
                    torch.empty((embed_dim, self.vdim), **factory_kwargs)
                )
                self.register_parameter("in_proj_weight", None)
            else:
                # Single packed projection weight for q, k and v.
                self.in_proj_weight = Parameter(
                    torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
                )
                self.register_parameter("q_proj_weight", None)
                self.register_parameter("k_proj_weight", None)
                self.register_parameter("v_proj_weight", None)

            if bias:
                self.in_proj_bias = Parameter(
                    torch.empty(3 * embed_dim, **factory_kwargs)
                )
            else:
                self.register_parameter("in_proj_bias", None)
            self.out_proj = NonDynamicallyQuantizableLinear(
                embed_dim, embed_dim, bias=bias, **factory_kwargs
            )

            self._reset_parameters()
        else:
            if not self._qkv_same_embed_dim:
                raise NotImplementedError
            else:
                self.in_proj_linear = linear1_cls(
                    embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
                )
                self.in_proj_weight = self.in_proj_linear.weight

                self.register_parameter("q_proj_weight", None)
                self.register_parameter("k_proj_weight", None)
                self.register_parameter("v_proj_weight", None)

                if bias:
                    self.in_proj_bias = self.in_proj_linear.bias
                else:
                    self.register_parameter("in_proj_bias", None)

                self.out_proj = linear2_cls(
                    embed_dim, embed_dim, bias=bias, **factory_kwargs
                )

                if self.bias_k is not None:
                    xavier_normal_(self.bias_k)
                if self.bias_v is not None:
                    xavier_normal_(self.bias_v)

        self.add_zero_attn = add_zero_attn

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.0)
            constant_(self.out_proj.bias, 0.0)

        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Older checkpoints may not carry `_qkv_same_embed_dim`; default it to
        # True for backward compatibility before delegating to Module.
        if "_qkv_same_embed_dim" not in state:
            state["_qkv_same_embed_dim"] = True

        super(MultiheadAttention, self).__setstate__(state)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        average_attn_weights: bool = True,
        cache=None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        # The ONNX path performs self-attention only: the incoming key/value
        # arguments are ignored and replaced with `query`, and the first two
        # dimensions are swapped around the patched call.
        any_nested = query.is_nested or key.is_nested or value.is_nested  # unused in the ONNX export path
        query = key = value = query.transpose(1, 0)
        attn_output = multi_head_attention_forward_patched(
            query,
            key,
            value,
            self.embed_dim,
            self.num_heads,
            self.in_proj_weight,
            self.in_proj_bias,
            self.bias_k,
            self.bias_v,
            self.add_zero_attn,
            self.dropout,
            self.out_proj.weight,
            self.out_proj.bias,
            training=self.training,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            average_attn_weights=average_attn_weights,
            cache=cache,
        )
        return attn_output.transpose(1, 0)
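

# Usage sketch (not part of the original module): a minimal, hedged example
# that only constructs the layer and inspects the packed projection
# parameters. The embed_dim / num_heads values are illustrative assumptions;
# forward() is left unexercised here because it expects the `cache` object
# produced by the AR decoding loop for multi_head_attention_forward_patched.
if __name__ == "__main__":
    attn = MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
    print(attn.in_proj_weight.shape)   # torch.Size([1536, 512]) -- packed q/k/v weights
    print(attn.out_proj.weight.shape)  # torch.Size([512, 512])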