Changes in modelling_RW.py to handle past_key_values for faster model generation

#60
by purunfer22 - opened
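For reference, the point of passing past_key_values back into the model is that each decoding step only has to run attention over the newly generated token instead of re-encoding the whole prefix. A minimal usage sketch of the cached-generation path this PR targets (not part of the diff; the checkpoint id and prompt are illustrative, and any RW/Falcon checkpoint that ships this modelling_RW.py and is loaded with trust_remote_code=True would do):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-7b"  # illustrative: any checkpoint served by this modelling_RW.py
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, trust_remote_code=True, torch_dtype=torch.bfloat16
)

inputs = tokenizer("The quick brown fox", return_tensors="pt")
# use_cache=True makes generate() feed past_key_values back into forward() at every
# step, which is exactly the code path these changes enable.
output_ids = model.generate(**inputs, max_new_tokens=32, use_cache=True)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))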
Files changed (1)
  1. modelling_RW.py +72 -36
modelling_RW.py CHANGED
@@ -11,7 +11,9 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
 from torch.nn import functional as F
-
+import pdb
+import os
+import pickle
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
@@ -87,10 +89,19 @@ class RotaryEmbedding(torch.nn.Module):
 
         return self.cos_cached, self.sin_cached
 
-    def forward(self, q, k):
-        batch, seq_len, head_dim = q.shape
+    def forward(self, q, k, past_seq_length=None):
+        if past_seq_length == None :
+            batch, seq_len, head_dim = q.shape
+        else :
+            # print("past_seq_length", past_seq_length)
+            batch, input_seq_len, head_dim = q.shape
+            seq_len = past_seq_length + input_seq_len
         cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
-        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+        if past_seq_length != None :
+            return (q * cos[:, past_seq_length:, :]) + (rotate_half(q) * sin[:, past_seq_length:, :]), (k * cos[:, past_seq_length:, :]) + (rotate_half(k) * sin[:, past_seq_length:, :])
+        else :
+            return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+
 
 
 def _make_causal_mask(
@@ -100,10 +111,10 @@ def _make_causal_mask(
     mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
     # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
     seq_ids = torch.arange(target_length, device=device)
-    mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
+    mask[:, past_key_values_length:] = seq_ids[:, None] >= seq_ids[None, :]
 
     if past_key_values_length > 0:
-        mask[:, :past_key_values_length] = False
+        mask[:, :past_key_values_length] = True
 
     expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
     return expanded_mask
@@ -150,6 +161,10 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training:
     out = residual + out
     return out
 
+def dump_value(name, tensor) :
+    with open("/home/purushottam/inspect_falcon/{}".format(name), "wb") as f :
+        pickle.dump(tensor, f)
+
 
 class Attention(nn.Module):
     def __init__(self, config: RWConfig):
@@ -239,9 +254,8 @@ class Attention(nn.Module):
         use_cache: bool = False,
        output_attentions: bool = False,
    ):
+
         fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
-
-        # 3 x [batch_size, seq_length, num_heads, head_dim]
         (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
 
         batch_size, q_length, _, _ = query_layer.shape
@@ -254,20 +268,27 @@ class Attention(nn.Module):
         )
         value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_kv, q_length, self.head_dim)
 
-        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
+        if layer_past is not None :
+            past_key, past_value = layer_past
+            past_kv_length = past_key.shape[2]
+            query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
+        else :
+            query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
+
+
 
         if layer_past is not None:
             past_key, past_value = layer_past
-            # concatenate along seq_length dimension:
-            #  - key: [batch_size * self.num_heads, head_dim, kv_length]
-            #  - value: [batch_size * self.num_heads, kv_length, head_dim]
+            past_key = past_key.permute(0, 2, 1)
             key_layer = torch.cat((past_key, key_layer), dim=1)
             value_layer = torch.cat((past_value, value_layer), dim=1)
+
 
         _, kv_length, _ = key_layer.shape
 
         if use_cache is True:
-            present = (key_layer, value_layer)
+            key_layer_permute = key_layer.permute(0, 2, 1)
+            present = (key_layer_permute, value_layer)
         else:
             present = None
 
@@ -275,10 +296,16 @@ class Attention(nn.Module):
             query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
             key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
             value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
+
 
-            attn_output = F.scaled_dot_product_attention(
-                query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
-            )
+            if attention_mask is not None :
+                attn_output = F.scaled_dot_product_attention(
+                    query_layer_, key_layer_, value_layer_, attention_mask, 0.0, is_causal=False
+                )
+            else :
+                attn_output = F.scaled_dot_product_attention(
+                    query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
+                )
 
             x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
             x = x.permute(0, 2, 1, 3)
@@ -475,8 +502,8 @@ class RWPreTrainedModel(PreTrainedModel):
     def _convert_to_rw_cache(
         past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
-        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
-        batch_size_times_num_heads = batch_size * num_heads
+        batch_size, seq_length, head_dim = past_key_value[0][0].shape
+        batch_size_times_num_heads = batch_size
         # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
         # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
         return tuple(
@@ -488,6 +515,7 @@ class RWPreTrainedModel(PreTrainedModel):
         )
 
 
+
 class RWModel(RWPreTrainedModel):
     def __init__(self, config: RWConfig):
         super().__init__(config)
@@ -522,10 +550,11 @@ class RWModel(RWPreTrainedModel):
         device = attention_mask.device
         _, src_length = input_shape
 
-        if src_length > 1:
-            combined_attention_mask = _make_causal_mask(
-                input_shape, device=device, past_key_values_length=past_key_values_length
-            )
+
+        # if src_length > 1:
+        combined_attention_mask = _make_causal_mask(
+            input_shape, device=device, past_key_values_length=past_key_values_length
+        )
 
         # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
         expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
@@ -560,7 +589,7 @@ class RWModel(RWPreTrainedModel):
             )
         if len(deprecated_arguments) > 0:
             raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
-
+        # pdb.set_trace()
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -616,6 +645,7 @@ class RWModel(RWPreTrainedModel):
            input_shape=(batch_size, seq_length),
            past_key_values_length=past_key_values_length,
        )
+        # print("causal_mask", causal_mask)
 
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
 
@@ -646,16 +676,18 @@ class RWModel(RWPreTrainedModel):
                 )
             else:
                 outputs = block(
-                    hidden_states,
-                    layer_past=layer_past,
-                    attention_mask=causal_mask,
-                    head_mask=head_mask[i],
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                    alibi=alibi,
-                )
+                    hidden_states,
+                    layer_past=layer_past,
+                    attention_mask=causal_mask,
+                    head_mask=head_mask[i],
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                    alibi=alibi,
+                )
+
 
             hidden_states = outputs[0]
+
             if use_cache is True:
                 presents = presents + (outputs[1],)
 
@@ -704,16 +736,20 @@ class RWForCausalLM(RWPreTrainedModel):
         **kwargs,
     ) -> dict:
         # only last token for input_ids if past is not None
-        if past:
+        # only last token for input_ids if past is not None
+        if kwargs.get("past_key_values", None) :
             input_ids = input_ids[:, -1].unsqueeze(-1)
-
+            past_key_values = kwargs["past_key_values"]
             # the cache may be in the stardard format (e.g. in contrastive search), convert to our's format if needed
-            if past[0][0].shape[0] == input_ids.shape[0]:
-                past = self._convert_to_rw_cache(past)
+            # if kwargs["past_key_values"][0][0].shape[0] == input_ids.shape[0]:
+            #     past_key_values = self._convert_to_rw_cache(kwargs["past_key_values"])
+            # past_key_values = kwargs["past_key_values"]
+        else :
+            past_key_values = None
 
         return {
             "input_ids": input_ids,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "use_cache": kwargs.get("use_cache"),
             "attention_mask": attention_mask,
         }
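
A note on the rotary change, since it is the core of the diff: with a KV cache, the forward pass only sees the tokens generated since the last step, so their rotary angles must start at the cached length rather than at position 0, which is why the new forward() slices the cos/sin tables with past_seq_length. A self-contained sketch of that offset logic (not the PR's code; the helper names are illustrative):

import torch

def rotate_half(x):
    # split the head dimension in two and rotate: (x1, x2) -> (-x2, x1)
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def rotary_tables(total_len, head_dim, dtype=torch.float32):
    # standard RoPE cos/sin tables for positions 0 .. total_len-1
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(total_len, dtype=torch.float32)
    freqs = torch.einsum("i,j->ij", t, inv_freq)   # [total_len, head_dim // 2]
    emb = torch.cat((freqs, freqs), dim=-1)        # [total_len, head_dim]
    return emb.cos().to(dtype), emb.sin().to(dtype)

def apply_rotary_with_offset(q, k, past_len=0):
    # q, k: [batch * heads, new_len, head_dim]; positions run past_len .. past_len + new_len - 1
    new_len, head_dim = q.shape[1], q.shape[2]
    cos, sin = rotary_tables(past_len + new_len, head_dim, dtype=q.dtype)
    cos, sin = cos[past_len:], sin[past_len:]      # keep only the new positions
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

# one cached decoding step: a single new token after a 16-token prefix
q = torch.randn(2, 1, 64)
k = torch.randn(2, 1, 64)
q_rot, k_rot = apply_rotary_with_offset(q, k, past_len=16)

The same reasoning explains the mask changes: during cached decoding the query length is 1, so the old `if src_length > 1` guard would skip the causal mask entirely; the PR builds it unconditionally from past_key_values_length and passes it explicitly to scaled_dot_product_attention with is_causal=False. The flip from `<`/False to `>=`/True appears to match SDPA's boolean-mask convention, where True marks positions that may be attended.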