OlivierDehaene committed on
Commit
fba6000
1 Parent(s): 6102f15

fix forward

Browse files
Files changed (1) hide show
  1. modeling_gpt2_mq.py +1 -1
modeling_gpt2_mq.py CHANGED
@@ -148,7 +148,7 @@ class GPT2MQAttention(nn.Module):
148
  # (b, sq * num_heads, head_dim) x (b, head_dim, sk) -> (b, sq * num_heads, sk)
149
 
150
  if self.scale_attn_weights:
151
- query *= self.inv_norm_factor
152
 
153
  attn_weights = torch.bmm(query, key)
154
 
 
148
  # (b, sq * num_heads, head_dim) x (b, head_dim, sk) -> (b, sq * num_heads, sk)
149
 
150
  if self.scale_attn_weights:
151
+ query = query * self.inv_norm_factor
152
 
153
  attn_weights = torch.bmm(query, key)
154