Changes in modelling_RW.py to handle past_key_values for faster model generation
#64
by puru22 - opened
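Editorial note, not part of the submitted patch: the point of the change is to let generate() feed past_key_values back into forward() so each decode step only processes the newest token instead of re-encoding the whole prefix. A minimal sketch of how that path would be exercised, assuming the patched modelling_RW.py is the remote code that gets loaded; the checkpoint name and prompt below are placeholders, not part of this PR:

# Illustrative only: assumes the patched modelling_RW.py is what trust_remote_code loads.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "tiiuae/falcon-7b"  # placeholder checkpoint that ships a modelling_RW.py
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint, torch_dtype=torch.bfloat16, trust_remote_code=True
)

inputs = tokenizer("Write a haiku about caching.", return_tensors="pt")
# With use_cache=True, generate() passes past_key_values back into forward(),
# which is the code path this patch adds support for.
output = model.generate(**inputs, max_new_tokens=32, use_cache=True)
print(tokenizer.decode(output[0], skip_special_tokens=True))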
modelling_RW.py  CHANGED  (+42 -19)

--- a/modelling_RW.py
+++ b/modelling_RW.py
@@ -87,10 +87,18 @@ class RotaryEmbedding(torch.nn.Module):
 
         return self.cos_cached, self.sin_cached
 
-    def forward(self, q, k):
-        batch, seq_len, head_dim = q.shape
+    def forward(self, q, k, past_seq_length=None):
+        if past_seq_length is None:
+            batch, seq_len, head_dim = q.shape
+        else:
+            # print("past_seq_length", past_seq_length)
+            batch, input_seq_len, head_dim = q.shape
+            seq_len = past_seq_length + input_seq_len
         cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
-        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+        if past_seq_length is not None:
+            return (q * cos[:, past_seq_length:, :]) + (rotate_half(q) * sin[:, past_seq_length:, :]), (k * cos[:, past_seq_length:, :]) + (rotate_half(k) * sin[:, past_seq_length:, :])
+        else:
+            return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
 
 
 def _make_causal_mask(
@@ -100,10 +108,10 @@ def _make_causal_mask(
     mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
     # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
     seq_ids = torch.arange(target_length, device=device)
-    mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
+    mask[:, past_key_values_length:] = seq_ids[:, None] >= seq_ids[None, :]
 
     if past_key_values_length > 0:
-        mask[:, :past_key_values_length] = False
+        mask[:, :past_key_values_length] = True
 
     expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
     return expanded_mask
@@ -264,20 +272,27 @@ class Attention(nn.Module):
         )
         value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
 
-        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
+        if layer_past is not None:
+            past_key, past_value = layer_past
+            past_kv_length = past_key.shape[2]
+            query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
+        else:
+            query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
 
         if layer_past is not None:
             past_key, past_value = layer_past
             # concatenate along seq_length dimension:
             #  - key: [batch_size * self.num_heads, head_dim, kv_length]
             #  - value: [batch_size * self.num_heads, kv_length, head_dim]
+            past_key = past_key.permute(0, 2, 1)
             key_layer = torch.cat((past_key, key_layer), dim=1)
             value_layer = torch.cat((past_value, value_layer), dim=1)
 
         _, kv_length, _ = key_layer.shape
 
         if use_cache is True:
-            present = (key_layer, value_layer)
+            key_layer_permute = key_layer.permute(0, 2, 1)
+            present = (key_layer_permute, value_layer)
         else:
             present = None
 
@@ -286,9 +301,14 @@ class Attention(nn.Module):
             key_layer_ = key_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
             value_layer_ = value_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
 
-            attn_output = F.scaled_dot_product_attention(
-                query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
-            )
+            if attention_mask is not None:
+                attn_output = F.scaled_dot_product_attention(
+                    query_layer_, key_layer_, value_layer_, attention_mask, 0.0, is_causal=False
+                )
+            else:
+                attn_output = F.scaled_dot_product_attention(
+                    query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
+                )
 
             x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
             x = x.permute(0, 2, 1, 3)
@@ -528,10 +548,10 @@ class RWModel(RWPreTrainedModel):
         device = attention_mask.device
         _, src_length = input_shape
 
-        if src_length > 1:
-            combined_attention_mask = _make_causal_mask(
-                input_shape, device=device, past_key_values_length=past_key_values_length
-            )
+        # if src_length > 1:
+        combined_attention_mask = _make_causal_mask(
+            input_shape, device=device, past_key_values_length=past_key_values_length
+        )
 
         # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
         expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
@@ -710,16 +730,19 @@ class RWForCausalLM(RWPreTrainedModel):
         **kwargs,
     ) -> dict:
         # only last token for input_ids if past is not None
-        if past:
+        if kwargs.get("past_key_values", None):
             input_ids = input_ids[:, -1].unsqueeze(-1)
-
+            past_key_values = kwargs["past_key_values"]
             # the cache may be in the stardard format (e.g. in contrastive search), convert to our's format if needed
-            if past[0][0].shape[0] == input_ids.shape[0]:
-                past = self._convert_to_rw_cache(past)
+            # if kwargs["past_key_values"][0][0].shape[0] == input_ids.shape[0]:
+            #     past_key_values = self._convert_to_rw_cache(kwargs["past_key_values"])
+            #     past_key_values = kwargs["past_key_values"]
+        else:
+            past_key_values = None
 
         return {
             "input_ids": input_ids,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "use_cache": kwargs.get("use_cache"),
             "attention_mask": attention_mask,
        }
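Two editorial notes that may help when reading the diff (mine, not the PR author's). First, the cache keeps the upstream key layout [batch_size * num_heads, head_dim, kv_length], so the patch permutes the cached key to [batch_size * num_heads, kv_length, head_dim] before concatenating and permutes back before storing it in present. Second, F.scaled_dot_product_attention treats a boolean attn_mask as "True means this position may be attended to", which is why _make_causal_mask now uses >= and fills the cached prefix with True, and why is_causal=False is passed whenever a mask is supplied. A small standalone check of that convention for the single-token decode step (this is not code from modelling_RW.py, just a sketch with arbitrary example shapes):

# Standalone sanity check: a cached decode step with an all-True boolean mask
# reproduces the last row of a full causal pass.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq_len, head_dim = 1, 2, 5, 8
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# Full, uncached pass over the whole sequence.
full = F.scaled_dot_product_attention(q, k, v, None, 0.0, is_causal=True)

# "Cached" pass: only the newest query, attending to every cached position.
# True in a boolean attn_mask means "may attend", matching the patched mask.
mask = torch.ones(1, 1, 1, seq_len, dtype=torch.bool)
cached = F.scaled_dot_product_attention(q[:, :, -1:], k, v, mask, 0.0, is_causal=False)

print(torch.allclose(full[:, :, -1:], cached, atol=1e-6))  # expected: True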