pseudotensor committed
Commit • cba2f63 • 1 Parent(s): 11599d0

Update modelling_RW.py

Files changed: modelling_RW.py (+49 -49)
modelling_RW.py CHANGED

@@ -52,11 +52,10 @@ class RotaryEmbedding(torch.nn.Module):
 
     def __init__(
         self,
-        config,
+        head_dim: int,
         base=10000,
+        use_cache=False,
     ):
-        head_dim = config.head_dim
-        self.use_cache = config.use_cache
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
@@ -65,6 +64,7 @@ class RotaryEmbedding(torch.nn.Module):
         self.batch_size_cached = None
         self.cos_cached: torch.Tensor | None = None
         self.sin_cached: torch.Tensor | None = None
+        self.use_cache = use_cache
 
     def cos_sin(
         self,
@@ -107,10 +107,7 @@ class RotaryEmbedding(torch.nn.Module):
     def forward(self, q, k):
         batch, seq_len, head_dim = q.shape
         cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
-        try:
-            return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
-        except Exception as e:
-            raise
+        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
 
 
 def _make_causal_mask(
@@ -187,7 +184,7 @@ class Attention(nn.Module):
                 f" {self.num_heads})."
            )
 
-        self.maybe_rotary = RotaryEmbedding(config) if config.rotary else lambda q, k: (q, k)
+        self.maybe_rotary = RotaryEmbedding(config.head_dim) if config.rotary else lambda q, k: (q, k)
 
         # Layer-wise attention scaling
         self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
@@ -195,34 +192,44 @@ class Attention(nn.Module):
 
         self.query_key_value = Linear(
             self.hidden_size,
-            3 * self.hidden_size if not config.multi_query else (self.hidden_size + 2 * self.head_dim),
+            (config.n_head_kv * 2 + config.n_head) * self.head_dim,
             bias=config.bias,
         )
-        self.multi_query = config.multi_query
         self.dense = Linear(self.hidden_size, self.hidden_size, bias=config.bias)
         self.attention_dropout = nn.Dropout(config.attention_dropout)
-        self.num_kv = config.n_head if not config.multi_query else 1
+        self.num_kv = config.n_head_kv
 
     def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
-        Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
+        Split the last dimension into (num_heads, head_dim), results share same memory
         storage as `fused_qkv`
 
         Args:
             fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
 
         Returns:
-            query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
+            query: [batch_size, seq_length, num_heads, head_dim]
+            key: [batch_size, seq_length, num_heads, head_dim]
             value: [batch_size, seq_length, num_heads, head_dim]
         """
-        if not self.multi_query:
-            batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
-            fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
-            return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
-        else:
-            batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
-            fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
-            return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]
+        batch, seq_len, _ = fused_qkv.shape
+        qkv = fused_qkv.view(batch, seq_len, -1, self.num_heads // self.num_kv + 2, 64)
+        q = qkv[:, :, :, :-2]
+        k = qkv[:, :, :, [-2]]
+        v = qkv[:, :, :, [-1]]
+        k = torch.broadcast_to(k, q.shape)
+        v = torch.broadcast_to(v, q.shape)
+
+        q, k, v = [
+            rearrange(
+                x,
+                "batch seq_len group num_heads head_dim ->\
+                batch seq_len (group num_heads) head_dim",
+                head_dim=self.head_dim,
+            )
+            for x in [q, k, v]
+        ]
+        return q, k, v
 
     def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -268,11 +275,11 @@ class Attention(nn.Module):
 
         query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
         key_layer = key_layer.transpose(1, 2).reshape(
-            batch_size * self.num_kv,
+            batch_size * self.num_heads,
             q_length,
             self.head_dim,
         )
-        value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_kv, q_length, self.head_dim)
+        value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
 
         query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
 
@@ -293,15 +300,12 @@ class Attention(nn.Module):
 
         if alibi is None:
             query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
-            key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
-            value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
+            key_layer_ = key_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
+            value_layer_ = value_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
 
-            try:
-                attn_output = F.scaled_dot_product_attention(
-                    query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
-                )
-            except Exception as e:
-                raise
+            attn_output = F.scaled_dot_product_attention(
+                query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
+            )
 
         x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
         x = x.permute(0, 2, 1, 3)
@@ -326,7 +330,8 @@ class Attention(nn.Module):
                 attention_scores = attention_scores.to(torch.float32)
             # attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
             attention_probs = F.softmax(
-                (attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)) * self.inv_norm_factor + attention_mask_float,
+                (attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)) * self.inv_norm_factor
+                + attention_mask_float,
                 dim=-1,
                 dtype=hidden_states.dtype,
             )
@@ -375,14 +380,12 @@ class DecoderLayer(nn.Module):
         super().__init__()
         hidden_size = config.hidden_size
 
-        self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+
         self.num_heads = config.n_head
         self.self_attention = Attention(config)
 
-        if not config.parallel_attn:
-            # unused if parallel attn
-            self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-
         self.mlp = MLP(config)
 
         self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
@@ -401,12 +404,14 @@ class DecoderLayer(nn.Module):
         output_attentions: bool = False,
     ):
 
-        layernorm_output = self.input_layernorm(hidden_states)
+        ln_attn = self.ln_attn(hidden_states)
+        ln_mlp = self.ln_mlp(hidden_states)
+
        residual = hidden_states
 
         # Self attention.
         attn_outputs = self.self_attention(
-            layernorm_output,
+            ln_attn,
             layer_past=layer_past,
             attention_mask=attention_mask,
             alibi=alibi,
@@ -417,19 +422,14 @@ class DecoderLayer(nn.Module):
 
         attention_output = attn_outputs[0]
 
-        if not self.config.parallel_attn:
-            residual = dropout_add(attention_output, residual, self.config.attention_dropout, training=self.training)
-            layernorm_output = self.post_attention_layernorm(residual)
-
         outputs = attn_outputs[1:]
 
         # MLP.
-        mlp_output = self.mlp(layernorm_output)
-
-        if self.config.parallel_attn:
-            mlp_output += attention_output
+        mlp_output = self.mlp(ln_mlp)
 
-        output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
+        output = dropout_add(
+            mlp_output + attention_output, residual, self.config.hidden_dropout, training=self.training
+        )
 
         if use_cache:
             outputs = (output,) + outputs
@@ -1120,4 +1120,4 @@ class RWForQuestionAnswering(RWPreTrainedModel):
             end_logits=end_logits,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-        )
+        )
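
Note on the new `_split_heads`: the hunk above packs, for every group of `n_head // n_head_kv` query heads, one shared key head and one shared value head into the fused QKV projection, then broadcasts K and V across their group. The following is a minimal, self-contained sketch of that shape logic, not part of the commit; the toy sizes `n_head=8`, `n_head_kv=2`, `head_dim=64` are assumptions chosen only for illustration.

import torch

# Toy configuration mirroring the grouped (multi-query) QKV layout.
batch, seq_len = 2, 5
n_head, n_head_kv, head_dim = 8, 2, 64

# Per KV group, the fused projection holds (n_head // n_head_kv) query heads plus 1 key and 1 value head.
fused_dim = (n_head_kv * 2 + n_head) * head_dim
fused_qkv = torch.randn(batch, seq_len, fused_dim)

# Same view as the new _split_heads: [batch, seq, groups, heads_per_group + 2, head_dim]
qkv = fused_qkv.view(batch, seq_len, -1, n_head // n_head_kv + 2, head_dim)
q = qkv[:, :, :, :-2]               # query heads of each group
k = qkv[:, :, :, [-2]]              # the single key head of each group
v = qkv[:, :, :, [-1]]              # the single value head of each group
k = torch.broadcast_to(k, q.shape)  # repeat K across the group's query heads
v = torch.broadcast_to(v, q.shape)  # repeat V across the group's query heads

# Flatten (group, heads_per_group) back into one head axis, as the rearrange call does in the diff.
q, k, v = [x.reshape(batch, seq_len, n_head, head_dim) for x in (q, k, v)]
print(q.shape, k.shape, v.shape)    # all torch.Size([2, 5, 8, 64])

After broadcasting, K and V share the `[batch, seq, num_heads, head_dim]` layout of Q, which is why the later reshapes in the diff use `batch_size * self.num_heads` for all three tensors.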
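
Similarly, the DecoderLayer hunks switch to an unconditional parallel-attention layout: two LayerNorms are applied to the same input, attention and the MLP run on their respective normalized copies, and their sum is dropout-added to the residual. A rough stand-in sketch of that wiring, with toy modules in place of the file's Attention and MLP and an arbitrary `hidden_size=16`:

import torch
import torch.nn.functional as F

hidden_size = 16
x = torch.randn(2, 5, hidden_size)         # [batch, seq, hidden]

ln_attn = torch.nn.LayerNorm(hidden_size)  # mirrors self.ln_attn
ln_mlp = torch.nn.LayerNorm(hidden_size)   # mirrors self.ln_mlp
attn = torch.nn.Identity()                 # stand-in for self.self_attention(ln_attn, ...)
mlp = torch.nn.Identity()                  # stand-in for self.mlp(ln_mlp)

residual = x
attention_output = attn(ln_attn(x))
mlp_output = mlp(ln_mlp(x))

# dropout_add in modelling_RW.py applies dropout to its first argument and adds the residual.
output = F.dropout(mlp_output + attention_output, p=0.0, training=False) + residual
print(output.shape)                        # torch.Size([2, 5, 16])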