[hotfix] update gqa impl
Files changed: modeling_grok1.py (+20, -0)

modeling_grok1.py CHANGED
@@ -74,6 +74,21 @@ def load_balancing_loss_func(
     ) * (num_experts**2)


+# Copied from transformers.models.llama.modeling_llama.repeat_kv
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(
+        batch, num_key_value_heads, n_rep, slen, head_dim
+    )
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
 class RMSNorm(nn.Module):
     def __init__(
         self,
@@ -194,6 +209,7 @@ class MultiHeadAttention(nn.Module):
         if num_key_value_heads is None:
             num_key_value_heads = num_heads
         self.num_key_value_heads = num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.attn_output_multiplier = attn_output_multiplier
         self.max_attn_val = max_attn_val

@@ -259,6 +275,10 @@ class MultiHeadAttention(nn.Module):

         past_key_value = (key_states, value_states) if use_cache else None

+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)).to(
             torch.float
         )
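
For reference, here is a minimal standalone sketch of what the added GQA path does: the key/value tensors, which carry only num_key_value_heads heads, are tiled up to the full query-head count before the existing matmul runs. The head counts and shapes below are illustrative only, not Grok-1's actual configuration; the repeat_kv body mirrors the function added in this commit.

import torch

# Hypothetical shapes for illustration only (not Grok-1's real config).
batch, n_heads, n_kv_heads, seqlen, head_dim = 2, 8, 2, 5, 16
n_rep = n_heads // n_kv_heads  # plays the role of self.num_key_value_groups


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Same logic as the function added in this commit.
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


query_states = torch.randn(batch, n_heads, seqlen, head_dim)
key_states = torch.randn(batch, n_kv_heads, seqlen, head_dim)
value_states = torch.randn(batch, n_kv_heads, seqlen, head_dim)

# Tile KV heads so they line up with the query heads.
key_states = repeat_kv(key_states, n_rep)
value_states = repeat_kv(value_states, n_rep)
assert key_states.shape == (batch, n_heads, seqlen, head_dim)

# The attention matmul from the forward pass now sees matching head counts.
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
assert attn_weights.shape == (batch, n_heads, seqlen, seqlen)

# Sanity check of the docstring's claim: repeat_kv matches repeat_interleave on dim=1.
x = torch.arange(24.0).reshape(1, 2, 3, 4)
assert torch.equal(repeat_kv(x, 3), torch.repeat_interleave(x, repeats=3, dim=1))

Tiling the KV heads keeps the rest of the attention code unchanged at the cost of materializing the repeated tensors, mirroring the Llama implementation that the copied-from comment cites.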
|