Changes in modelling_RW.py to be able to handle past_key_values for faster model generations
#60
by purunfer22 - opened

modelling_RW.py  CHANGED  (+72 -36)
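The patch threads `past_key_values` through the RW (Falcon-style) attention path so that, during generation, each step only encodes the newly produced token instead of re-running the whole prefix. A minimal usage sketch of the intended effect, assuming the patched modelling_RW.py is the remote code loaded for the checkpoint (the checkpoint name below is illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint name; any RW/Falcon checkpoint that ships this
# modelling_RW.py via trust_remote_code should take the same code path.
model_name = "tiiuae/falcon-7b"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto"
)

inputs = tokenizer("The quick brown fox", return_tensors="pt").to(model.device)

# use_cache=True makes generate() feed past_key_values back into the model,
# which is exactly the path this patch enables.
output_ids = model.generate(**inputs, max_new_tokens=32, use_cache=True, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```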
```diff
@@ -11,7 +11,9 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
 from torch.nn import functional as F
-
+import pdb
+import os
+import pickle
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
```
```diff
@@ -87,10 +89,19 @@ class RotaryEmbedding(torch.nn.Module):

         return self.cos_cached, self.sin_cached

-    def forward(self, q, k):
-        batch, seq_len, head_dim = q.shape
+    def forward(self, q, k, past_seq_length=None):
+        if past_seq_length == None :
+            batch, seq_len, head_dim = q.shape
+        else :
+            # print("past_seq_length", past_seq_length)
+            batch, input_seq_len, head_dim = q.shape
+            seq_len = past_seq_length + input_seq_len
         cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
-        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+        if past_seq_length != None :
+            return (q * cos[:, past_seq_length:, :]) + (rotate_half(q) * sin[:, past_seq_length:, :]), (k * cos[:, past_seq_length:, :]) + (rotate_half(k) * sin[:, past_seq_length:, :])
+        else :
+            return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+


 def _make_causal_mask(
```
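The rotary embedding now takes an optional `past_seq_length`: the cos/sin tables are built for the full length (cached plus new tokens) and then sliced from `past_seq_length` onward, so the incoming query/key chunk is rotated at its absolute positions rather than starting again from position 0. A self-contained sketch of that slicing idea (this is not the repo's RotaryEmbedding class, just the offset logic with the same rotate_half convention):

```python
import torch

def rotate_half(x):
    # Same helper convention as the model file: split the last dim in two halves and swap them with a sign flip.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(q, k, cos, sin, past_seq_length=0):
    # cos/sin are built for the *total* length (past + new tokens); slicing with
    # past_seq_length keeps only the angles for the new positions, which is what
    # the patched forward() does when a cache is present.
    cos, sin = cos[:, past_seq_length:, :], sin[:, past_seq_length:, :]
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

# Toy check: rotating the last token alone (with the offset) must match rotating
# the full sequence and taking its last position.
head_dim, total_len = 8, 5
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(total_len).float()
freqs = torch.einsum("i,j->ij", t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)[None, :, :]          # [1, total_len, head_dim]
cos, sin = emb.cos(), emb.sin()

q_full = torch.randn(1, total_len, head_dim)
q_last = q_full[:, -1:, :]

full_rot, _ = apply_rotary(q_full, q_full, cos, sin)
last_rot, _ = apply_rotary(q_last, q_last, cos, sin, past_seq_length=total_len - 1)
print(torch.allclose(full_rot[:, -1:, :], last_rot, atol=1e-6))  # True
```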
```diff
@@ -100,10 +111,10 @@ def _make_causal_mask(
     mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
     # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
     seq_ids = torch.arange(target_length, device=device)
-    mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
+    mask[:, past_key_values_length:] = seq_ids[:, None] >= seq_ids[None, :]

     if past_key_values_length > 0:
-        mask[:, :past_key_values_length] = False
+        mask[:, :past_key_values_length] = True

     expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
     return expanded_mask
```
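The comparison flips from `<` to `>=` and the cached columns from False to True because the mask is now consumed by `torch.nn.functional.scaled_dot_product_attention`, where a boolean mask means "True = may attend" rather than "True = masked out". A toy rendering of the new mask for two query tokens over three cached positions (standalone sketch, without the batch expansion):

```python
import torch

def make_causal_mask(target_length, past_key_values_length, device="cpu"):
    # Same construction as the patched _make_causal_mask, minus the batch expansion:
    # True means "this query position may attend to this key position".
    mask = torch.empty((target_length, target_length + past_key_values_length),
                       dtype=torch.bool, device=device)
    seq_ids = torch.arange(target_length, device=device)
    mask[:, past_key_values_length:] = seq_ids[:, None] >= seq_ids[None, :]
    if past_key_values_length > 0:
        mask[:, :past_key_values_length] = True
    return mask

# Two new query tokens attending over three cached tokens plus themselves:
print(make_causal_mask(target_length=2, past_key_values_length=3).int())
# tensor([[1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]], dtype=torch.int32)
```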
```diff
@@ -150,6 +161,10 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training:
     out = residual + out
     return out

+def dump_value(name, tensor) :
+    with open("/home/purushottam/inspect_falcon/{}".format(name), "wb") as f :
+        pickle.dump(tensor, f)
+

 class Attention(nn.Module):
     def __init__(self, config: RWConfig):
```
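`dump_value` is a debugging aid that pickles intermediate tensors for offline inspection; note that it writes to a hard-coded local directory. A self-contained variant of the same idea with the output directory made a parameter (the directory handling here is illustrative, not part of the patch):

```python
import os
import pickle
import tempfile

import torch

def dump_value(name, tensor, directory=None):
    # Serialize a tensor (moved to CPU first) so it can be inspected offline,
    # e.g. to compare activations between the cached and uncached code paths.
    directory = directory or tempfile.gettempdir()
    path = os.path.join(directory, name)
    with open(path, "wb") as f:
        pickle.dump(tensor.detach().cpu(), f)
    return path

path = dump_value("query_layer.pkl", torch.randn(2, 3))
with open(path, "rb") as f:
    print(pickle.load(f).shape)  # torch.Size([2, 3])
```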
```diff
@@ -239,9 +254,8 @@ class Attention(nn.Module):
         use_cache: bool = False,
         output_attentions: bool = False,
     ):
+
         fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
-
-        # 3 x [batch_size, seq_length, num_heads, head_dim]
         (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

         batch_size, q_length, _, _ = query_layer.shape
```
```diff
@@ -254,20 +268,27 @@ class Attention(nn.Module):
         )
         value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_kv, q_length, self.head_dim)

-        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
+        if layer_past is not None :
+            past_key, past_value = layer_past
+            past_kv_length = past_key.shape[2]
+            query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
+        else :
+            query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
+
+

         if layer_past is not None:
             past_key, past_value = layer_past
-            # concatenate along seq_length dimension:
-            #  - key: [batch_size * self.num_heads, head_dim, kv_length]
-            #  - value: [batch_size * self.num_heads, kv_length, head_dim]
+            past_key = past_key.permute(0, 2, 1)
             key_layer = torch.cat((past_key, key_layer), dim=1)
             value_layer = torch.cat((past_value, value_layer), dim=1)
+

         _, kv_length, _ = key_layer.shape

         if use_cache is True:
-            present = (key_layer, value_layer)
+            key_layer_permute = key_layer.permute(0, 2, 1)
+            present = (key_layer_permute, value_layer)
         else:
             present = None

```
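When a cache is present, the cached key is stored as `[batch * num_kv, head_dim, kv_length]` (the same layout used for `present` just above), so the code reads the cached length from `past_key.shape[2]` for the rotary offset, permutes the key back to `[batch * num_kv, kv_length, head_dim]`, and concatenates key and value with the new tokens along the sequence dimension. A shape-only walk-through of one decode step (the dimensions are illustrative, not a specific Falcon configuration):

```python
import torch

# Assumed cache layout (as in the patch): key [batch * num_kv, head_dim, kv_length],
# value [batch * num_kv, kv_length, head_dim].
batch_times_num_kv, head_dim = 2, 64
past_len, q_len = 10, 1

past_key = torch.randn(batch_times_num_kv, head_dim, past_len)    # layer_past[0]
past_value = torch.randn(batch_times_num_kv, past_len, head_dim)  # layer_past[1]
key_layer = torch.randn(batch_times_num_kv, q_len, head_dim)      # keys for the new token(s)
value_layer = torch.randn(batch_times_num_kv, q_len, head_dim)

past_kv_length = past_key.shape[2]                       # rotary offset passed to maybe_rotary
past_key = past_key.permute(0, 2, 1)                     # -> [batch * num_kv, past_len, head_dim]
key_layer = torch.cat((past_key, key_layer), dim=1)      # -> [batch * num_kv, past_len + q_len, head_dim]
value_layer = torch.cat((past_value, value_layer), dim=1)

# What gets returned as `present` for the next step: the key goes back to the stored layout.
present = (key_layer.permute(0, 2, 1), value_layer)
print(past_kv_length, present[0].shape, present[1].shape)
# 10 torch.Size([2, 64, 11]) torch.Size([2, 11, 64])
```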
```diff
@@ -275,10 +296,16 @@ class Attention(nn.Module):
             query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
             key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
             value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
+

-            attn_output = F.scaled_dot_product_attention(
-                query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
-            )
+            if attention_mask is not None :
+                attn_output = F.scaled_dot_product_attention(
+                    query_layer_, key_layer_, value_layer_, attention_mask, 0.0, is_causal=False
+                )
+            else :
+                attn_output = F.scaled_dot_product_attention(
+                    query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
+                )

             x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
             x = x.permute(0, 2, 1, 3)
```
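With a cache the query and key lengths differ, so the implicit causal pattern of `is_causal=True` no longer applies; the patch therefore passes the explicit boolean mask with `is_causal=False` whenever a mask is available. For the square, no-cache case the two forms agree, which the following standalone check illustrates:

```python
import torch
import torch.nn.functional as F

# Sanity check (no cache): passing the full boolean causal mask with is_causal=False
# should match letting SDPA build the causal pattern itself with is_causal=True.
torch.manual_seed(0)
b, h, t, d = 1, 2, 4, 8
q = torch.randn(b, h, t, d)
k = torch.randn(b, h, t, d)
v = torch.randn(b, h, t, d)

causal_bool = torch.ones(t, t).tril().bool()  # True = may attend

out_masked = F.scaled_dot_product_attention(q, k, v, causal_bool, 0.0, is_causal=False)
out_causal = F.scaled_dot_product_attention(q, k, v, None, 0.0, is_causal=True)
print(torch.allclose(out_masked, out_causal, atol=1e-6))  # True
```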
```diff
@@ -475,8 +502,8 @@ class RWPreTrainedModel(PreTrainedModel):
     def _convert_to_rw_cache(
         past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
-        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
-        batch_size_times_num_heads = batch_size * num_heads
+        batch_size, seq_length, head_dim = past_key_value[0][0].shape
+        batch_size_times_num_heads = batch_size
         # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
         # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
         return tuple(
```
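The comments describe two cache layouts: the "standard" per-head layout `[batch, num_heads, head_dim, seq_length]` and the fused "rw" layout `[batch * num_heads, head_dim, seq_length]` that this Attention implementation consumes. Since `prepare_inputs_for_generation` below no longer converts between them, the cache appears to stay in the fused layout end to end under this patch, and the unpack here now reads a 3-D shape. For reference, a standalone sketch of the standard-to-fused flattening the comments refer to (not the patched function body):

```python
import torch

# The two cache layouts the comments refer to, for one layer's key tensor.
# "standard" (what most HF code passes around): [batch, num_heads, head_dim, seq_length]
# "rw" / fused (what this Attention consumes):  [batch * num_heads, head_dim, seq_length]
batch, num_heads, head_dim, seq_length = 2, 4, 16, 5

key_standard = torch.randn(batch, num_heads, head_dim, seq_length)
key_fused = key_standard.reshape(batch * num_heads, head_dim, seq_length)   # standard -> rw
key_roundtrip = key_fused.view(batch, num_heads, head_dim, seq_length)      # rw -> standard

print(key_fused.shape, torch.equal(key_standard, key_roundtrip))
# torch.Size([8, 16, 5]) True
```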
```diff
@@ -488,6 +515,7 @@ class RWPreTrainedModel(PreTrainedModel):
         )


+
 class RWModel(RWPreTrainedModel):
     def __init__(self, config: RWConfig):
         super().__init__(config)
```
```diff
@@ -522,10 +550,11 @@ class RWModel(RWPreTrainedModel):
         device = attention_mask.device
         _, src_length = input_shape

-        if src_length > 1:
-            combined_attention_mask = _make_causal_mask(
-                input_shape, device=device, past_key_values_length=past_key_values_length
-            )
+
+        # if src_length > 1:
+        combined_attention_mask = _make_causal_mask(
+            input_shape, device=device, past_key_values_length=past_key_values_length
+        )

         # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
         expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
```
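The `if src_length > 1` guard is commented out so the causal mask is also built for single-token decode steps. That matters because, with a cache, the query length (1) and key length (past + 1) differ, and in current PyTorch `is_causal=True` aligns its implicit triangular mask to the top-left corner, which would let the new token see only the first cached position; the explicit mask is what lets the decode step attend to the whole cache. A small standalone demonstration:

```python
import torch
import torch.nn.functional as F

# Why the mask is now built even for a 1-token decode step: with a KV cache the
# shapes are non-square, and is_causal=True is not equivalent to "attend to everything so far".
torch.manual_seed(0)
b, h, d = 1, 2, 8
past_len = 5
q = torch.randn(b, h, 1, d)                 # single new token
k = torch.randn(b, h, past_len + 1, d)      # cached keys + the new one
v = torch.randn(b, h, past_len + 1, d)

# Decode-step mask under the patched convention: the new token may attend to everything.
decode_mask = torch.ones(1, past_len + 1, dtype=torch.bool)

out_masked = F.scaled_dot_product_attention(q, k, v, decode_mask, 0.0, is_causal=False)
out_causal = F.scaled_dot_product_attention(q, k, v, None, 0.0, is_causal=True)
print(torch.allclose(out_masked, out_causal))  # False: the two are not interchangeable here
```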
```diff
@@ -560,7 +589,7 @@ class RWModel(RWPreTrainedModel):
         )
         if len(deprecated_arguments) > 0:
             raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
-
+        # pdb.set_trace()
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
```
```diff
@@ -616,6 +645,7 @@ class RWModel(RWPreTrainedModel):
             input_shape=(batch_size, seq_length),
             past_key_values_length=past_key_values_length,
         )
+        # print("causal_mask", causal_mask)

         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):

```
```diff
@@ -646,16 +676,18 @@ class RWModel(RWPreTrainedModel):
                 )
             else:
                 outputs = block(
-                    hidden_states,
-                    layer_past=layer_past,
-                    attention_mask=causal_mask,
-                    head_mask=head_mask[i],
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                    alibi=alibi,
-                )
+                    hidden_states,
+                    layer_past=layer_past,
+                    attention_mask=causal_mask,
+                    head_mask=head_mask[i],
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                    alibi=alibi,
+                )
+

             hidden_states = outputs[0]
+
             if use_cache is True:
                 presents = presents + (outputs[1],)

```
```diff
@@ -704,16 +736,20 @@ class RWForCausalLM(RWPreTrainedModel):
         **kwargs,
     ) -> dict:
         # only last token for input_ids if past is not None
-        if past:
+        # only last token for input_ids if past is not None
+        if kwargs.get("past_key_values", None) :
             input_ids = input_ids[:, -1].unsqueeze(-1)
-
+            past_key_values = kwargs["past_key_values"]
             # the cache may be in the stardard format (e.g. in contrastive search), convert to our's format if needed
-            if past[0][0].shape[0] == input_ids.shape[0]:
-                past = self._convert_to_rw_cache(past)
+            # if kwargs["past_key_values"][0][0].shape[0] == input_ids.shape[0]:
+            #     past_key_values = self._convert_to_rw_cache(kwargs["past_key_values"])
+            #     past_key_values = kwargs["past_key_values"]
+        else :
+            past_key_values = None

         return {
             "input_ids": input_ids,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "use_cache": kwargs.get("use_cache"),
             "attention_mask": attention_mask,
         }
```
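`prepare_inputs_for_generation` now reads the cache from `kwargs["past_key_values"]` (the name newer `generate()` implementations pass it under), keeps only the last token of `input_ids` when a cache exists, and forwards the cache untouched; the standard-to-rw conversion is left commented out. A manual two-step decoding sketch of the same contract outside `generate()` (checkpoint name illustrative; assumes the patched file is the remote code loaded via trust_remote_code):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Manual version of what generate() does with this prepare_inputs_for_generation:
# step 1 runs the full prompt, step 2 feeds only the last token plus the cache.
model_name = "tiiuae/falcon-7b"  # illustrative; any checkpoint using this modelling_RW.py
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
model.eval()

input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids

with torch.no_grad():
    out = model(input_ids=input_ids, use_cache=True)           # prefill over the whole prompt
    next_id = out.logits[:, -1:].argmax(-1)

    # Decode step: only the new token goes in, together with the cached keys/values.
    out = model(input_ids=next_id, past_key_values=out.past_key_values, use_cache=True)
    next_id = out.logits[:, -1:].argmax(-1)

print(tokenizer.decode(next_id[0]))
```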