OlivierDehaene
commited on
Commit
•
fba6000
1
Parent(s):
6102f15
fix forward
Browse files- modeling_gpt2_mq.py +1 -1
modeling_gpt2_mq.py
CHANGED
@@ -148,7 +148,7 @@ class GPT2MQAttention(nn.Module):
|
|
148 |
# (b, sq * num_heads, head_dim) x (b, head_dim, sk) -> (b, sq * num_heads, sk)
|
149 |
|
150 |
if self.scale_attn_weights:
|
151 |
-
query
|
152 |
|
153 |
attn_weights = torch.bmm(query, key)
|
154 |
|
|
|
148 |
# (b, sq * num_heads, head_dim) x (b, head_dim, sk) -> (b, sq * num_heads, sk)
|
149 |
|
150 |
if self.scale_attn_weights:
|
151 |
+
query = query * self.inv_norm_factor
|
152 |
|
153 |
attn_weights = torch.bmm(query, key)
|
154 |
|