Allow return attention weights

#1
by orby-yanan - opened
Files changed (2) hide show
  1. blocks.py +2 -2
  2. modeling_mpt.py +1 -1
blocks.py CHANGED
@@ -31,9 +31,9 @@ class MPTBlock(nn.Module):
31
  self.resid_attn_dropout = nn.Dropout(resid_pdrop)
32
  self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
33
 
34
- def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
35
  a = self.norm_1(x)
36
- (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
37
  x = x + self.resid_attn_dropout(b)
38
  m = self.norm_2(x)
39
  n = self.ffn(m)
 
31
  self.resid_attn_dropout = nn.Dropout(resid_pdrop)
32
  self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
33
 
34
+ def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True, needs_weights: bool=False) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
35
  a = self.norm_1(x)
36
+ (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal, needs_weights=needs_weights)
37
  x = x + self.resid_attn_dropout(b)
38
  m = self.norm_2(x)
39
  n = self.ffn(m)
modeling_mpt.py CHANGED
@@ -199,7 +199,7 @@ class MPTModel(MPTPreTrainedModel):
199
  assert all_hidden_states is not None
200
  all_hidden_states = all_hidden_states + (x,)
201
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
202
- (x, attn_weights, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
203
  if past_key_values is not None:
204
  past_key_values[b_idx] = past_key_value
205
  if output_attentions:
 
199
  assert all_hidden_states is not None
200
  all_hidden_states = all_hidden_states + (x,)
201
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
202
+ (x, attn_weights, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal, needs_weights=output_attentions)
203
  if past_key_values is not None:
204
  past_key_values[b_idx] = past_key_value
205
  if output_attentions: