efederici commited on
Commit
648578e
1 Parent(s): 190c13e

Update blocks.py

Browse files
Files changed (1) hide show
  1. blocks.py +2 -2
blocks.py CHANGED
@@ -33,9 +33,9 @@ class MPTBlock(nn.Module):
33
 
34
  def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
35
  a = self.norm_1(x)
36
- (b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
37
  x = x + self.resid_attn_dropout(b)
38
  m = self.norm_2(x)
39
  n = self.ffn(m)
40
  x = x + self.resid_ffn_dropout(n)
41
- return (x, past_key_value)
 
33
 
34
  def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
35
  a = self.norm_1(x)
36
+ (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
37
  x = x + self.resid_attn_dropout(b)
38
  m = self.norm_2(x)
39
  n = self.ffn(m)
40
  x = x + self.resid_ffn_dropout(n)
41
+ return (x, attn_weights, past_key_value)