Added output_attentions: bool=False to GroupedQueryAttention.forward() as a temporary fix for AWQ
Browse files- attention.py +1 -1
attention.py
CHANGED
@@ -260,7 +260,7 @@ class GroupedQueryAttention(nn.Module):
|
|
260 |
self.out_proj = FC_CLASS_REGISTRY[fc_type](self.d_model, self.d_model, **fc_kwargs)
|
261 |
self.out_proj._is_residual = True
|
262 |
|
263 |
-
def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.Tensor]=None, is_causal: bool=True, needs_weights: bool=False) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
264 |
qkv = self.Wqkv(x)
|
265 |
if self.clip_qkv:
|
266 |
qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
|
|
|
260 |
self.out_proj = FC_CLASS_REGISTRY[fc_type](self.d_model, self.d_model, **fc_kwargs)
|
261 |
self.out_proj._is_residual = True
|
262 |
|
263 |
+
def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.Tensor]=None, is_causal: bool=True, output_attentions: bool=False, needs_weights: bool=False) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
264 |
qkv = self.Wqkv(x)
|
265 |
if self.clip_qkv:
|
266 |
qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
|