Changes in modelling_RW.py to handle past_key_values for faster model generation

#64
by puru22 - opened
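This patch threads past_key_values through RWForCausalLM so that, once the prompt has been processed, each decoding step only feeds the newly generated token instead of re-running attention over the full prefix. A minimal usage sketch, assuming the patched modelling_RW.py is picked up via trust_remote_code (the checkpoint name below is illustrative, not part of this diff):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-7b"  # assumption: substitute the checkpoint this repository belongs to
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, trust_remote_code=True)

inputs = tokenizer("The quick brown fox", return_tensors="pt")
# use_cache=True makes generate() feed past_key_values back into the model,
# which is exactly what this patch teaches modelling_RW.py to accept.
output_ids = model.generate(**inputs, max_new_tokens=32, use_cache=True)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))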
Files changed (1)
  1. modelling_RW.py +42 -19
modelling_RW.py CHANGED
@@ -87,10 +87,18 @@ class RotaryEmbedding(torch.nn.Module):
 
         return self.cos_cached, self.sin_cached
 
-    def forward(self, q, k):
-        batch, seq_len, head_dim = q.shape
+    def forward(self, q, k, past_seq_length=None):
+        if past_seq_length is None:
+            batch, seq_len, head_dim = q.shape
+        else:
+            # print("past_seq_length", past_seq_length)
+            batch, input_seq_len, head_dim = q.shape
+            seq_len = past_seq_length + input_seq_len
         cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
-        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+        if past_seq_length is not None:
+            return (q * cos[:, past_seq_length:, :]) + (rotate_half(q) * sin[:, past_seq_length:, :]), (k * cos[:, past_seq_length:, :]) + (rotate_half(k) * sin[:, past_seq_length:, :])
+        else:
+            return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
 
 
 def _make_causal_mask(
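The new past_seq_length argument builds the cos/sin cache for the full length (cached plus new tokens) but applies only the slice starting at past_seq_length to the incoming q and k, so a new token is rotated with its true absolute position while cached tokens keep the rotations they were stored with. A standalone sketch of that slicing, with a simplified rotate_half and toy shapes (not the repository's exact implementation):

import torch

def rotate_half(x):
    # standard rotary helper: swap the two halves of the last dim and negate one
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(q, k, cos, sin, past_seq_length=None):
    # cos/sin cover positions 0..total_len-1; q/k only hold the new tokens
    if past_seq_length is not None:
        cos, sin = cos[:, past_seq_length:, :], sin[:, past_seq_length:, :]
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

head_dim, total_len, past_len = 8, 6, 4
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.einsum("i,j->ij", torch.arange(total_len).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)[None, :, :]   # [1, total_len, head_dim]
q = torch.randn(1, total_len - past_len, head_dim)     # two new tokens
k = torch.randn_like(q)
q_rot, k_rot = apply_rotary(q, k, emb.cos(), emb.sin(), past_seq_length=past_len)
print(q_rot.shape, k_rot.shape)  # torch.Size([1, 2, 8]) torch.Size([1, 2, 8])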
@@ -100,10 +108,10 @@ def _make_causal_mask(
     mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
     # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
     seq_ids = torch.arange(target_length, device=device)
-    mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
+    mask[:, past_key_values_length:] = seq_ids[:, None] >= seq_ids[None, :]
 
     if past_key_values_length > 0:
-        mask[:, :past_key_values_length] = False
+        mask[:, :past_key_values_length] = True
 
     expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
     return expanded_mask
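The mask semantics are flipped: True now means "may attend" (the lower triangle for the new tokens plus every cached column), whereas the old code built the complementary pattern with True above the diagonal and the cached columns set to False. The new convention is the one torch.nn.functional.scaled_dot_product_attention expects for a boolean attn_mask, which the Attention hunks below rely on. A simplified re-implementation showing the resulting pattern:

import torch

def make_causal_mask(target_length, past_key_values_length):
    # True = this (query, key) pair may attend; False = masked out
    mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool)
    seq_ids = torch.arange(target_length)
    mask[:, past_key_values_length:] = seq_ids[:, None] >= seq_ids[None, :]
    if past_key_values_length > 0:
        mask[:, :past_key_values_length] = True   # new tokens see every cached token
    return mask

print(make_causal_mask(target_length=3, past_key_values_length=2).int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])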
@@ -264,20 +272,27 @@ class Attention(nn.Module):
         )
         value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
 
-        query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
+        if layer_past is not None:
+            past_key, past_value = layer_past
+            past_kv_length = past_key.shape[2]
+            query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
+        else:
+            query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
 
         if layer_past is not None:
             past_key, past_value = layer_past
             # concatenate along seq_length dimension:
             #  - key: [batch_size * self.num_heads, head_dim, kv_length]
             #  - value: [batch_size * self.num_heads, kv_length, head_dim]
+            past_key = past_key.permute(0, 2, 1)
             key_layer = torch.cat((past_key, key_layer), dim=1)
             value_layer = torch.cat((past_value, value_layer), dim=1)
 
         _, kv_length, _ = key_layer.shape
 
         if use_cache is True:
-            present = (key_layer, value_layer)
+            key_layer_permute = key_layer.permute(0, 2, 1)
+            present = (key_layer_permute, value_layer)
         else:
             present = None
 
@@ -286,9 +301,14 @@ class Attention(nn.Module):
         key_layer_ = key_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
         value_layer_ = value_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
 
-        attn_output = F.scaled_dot_product_attention(
-            query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
-        )
+        if attention_mask is not None:
+            attn_output = F.scaled_dot_product_attention(
+                query_layer_, key_layer_, value_layer_, attention_mask, 0.0, is_causal=False
+            )
+        else:
+            attn_output = F.scaled_dot_product_attention(
+                query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
+            )
 
         x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
         x = x.permute(0, 2, 1, 3)
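Taken together, the two Attention hunks rotate only the new positions, splice the cached keys and values onto the fresh ones (cached keys are stored as [batch*heads, head_dim, kv_length] and values as [batch*heads, kv_length, head_dim], hence the permutes), and call scaled_dot_product_attention with the explicit boolean mask instead of is_causal=True, which for a single-token query would wrongly restrict attention to the first key. A shape-level sketch of that flow with illustrative dimensions (not the model's real sizes):

import torch
import torch.nn.functional as F

batch, num_heads, head_dim, past_len, q_len = 1, 2, 8, 5, 1
bh = batch * num_heads

past_key = torch.randn(bh, head_dim, past_len)    # cached key layout: [b*h, head_dim, kv_length]
past_value = torch.randn(bh, past_len, head_dim)  # cached value layout: [b*h, kv_length, head_dim]

query_layer = torch.randn(bh, q_len, head_dim)    # projections for the one new token
key_layer = torch.randn(bh, q_len, head_dim)
value_layer = torch.randn(bh, q_len, head_dim)

# concatenate along the sequence dimension; keys need a permute first
key_layer = torch.cat((past_key.permute(0, 2, 1), key_layer), dim=1)
value_layer = torch.cat((past_value, value_layer), dim=1)
present = (key_layer.permute(0, 2, 1), value_layer)   # cache handed to the next step

q_ = query_layer.reshape(batch, num_heads, q_len, head_dim)
k_ = key_layer.reshape(batch, num_heads, -1, head_dim)
v_ = value_layer.reshape(batch, num_heads, -1, head_dim)

# boolean mask, True = may attend: the single query sees all cached positions and itself
attn_mask = torch.ones(batch, 1, q_len, past_len + q_len, dtype=torch.bool)
attn_output = F.scaled_dot_product_attention(q_, k_, v_, attn_mask, 0.0, is_causal=False)
print(attn_output.shape, present[0].shape)  # torch.Size([1, 2, 1, 8]) torch.Size([2, 8, 6])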
@@ -528,10 +548,10 @@ class RWModel(RWPreTrainedModel):
         device = attention_mask.device
         _, src_length = input_shape
 
-        if src_length > 1:
-            combined_attention_mask = _make_causal_mask(
-                input_shape, device=device, past_key_values_length=past_key_values_length
-            )
+        # if src_length > 1:
+        combined_attention_mask = _make_causal_mask(
+            input_shape, device=device, past_key_values_length=past_key_values_length
+        )
 
         # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
         expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
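The src_length > 1 guard is dropped because a cached decoding step has src_length == 1, and the causal mask still has to be built so that its single query row covers the past_key_values_length cached columns. Inlining the mask logic from above for that case (a sketch, not the repository's helper):

import torch

target_length, past = 1, 4            # one new token, four cached tokens
mask = torch.empty((target_length, target_length + past), dtype=torch.bool)
seq_ids = torch.arange(target_length)
mask[:, past:] = seq_ids[:, None] >= seq_ids[None, :]
mask[:, :past] = True
print(mask.int())  # tensor([[1, 1, 1, 1, 1]])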
@@ -710,16 +730,19 @@ class RWForCausalLM(RWPreTrainedModel):
         **kwargs,
     ) -> dict:
         # only last token for input_ids if past is not None
-        if past:
+        if kwargs.get("past_key_values", None):
             input_ids = input_ids[:, -1].unsqueeze(-1)
-
+            past_key_values = kwargs["past_key_values"]
             # the cache may be in the standard format (e.g. in contrastive search), convert to our format if needed
-            if past[0][0].shape[0] == input_ids.shape[0]:
-                past = self._convert_to_rw_cache(past)
+            # if kwargs["past_key_values"][0][0].shape[0] == input_ids.shape[0]:
+            #     past_key_values = self._convert_to_rw_cache(kwargs["past_key_values"])
+            #     past_key_values = kwargs["past_key_values"]
+        else:
+            past_key_values = None
 
         return {
             "input_ids": input_ids,
-            "past_key_values": past,
+            "past_key_values": past_key_values,
             "use_cache": kwargs.get("use_cache"),
             "attention_mask": attention_mask,
         }
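prepare_inputs_for_generation now reads past_key_values from kwargs, trims input_ids to the last token once a cache exists, and passes the cache through unchanged (the _convert_to_rw_cache call is left commented out). That is the same pattern a hand-written greedy decode loop follows; a sketch of such a loop, with the same illustrative checkpoint assumption as in the usage example at the top:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-7b"  # assumption: substitute the checkpoint this repository belongs to
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("The quick brown fox", return_tensors="pt")
input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"]
past_key_values = None

with torch.no_grad():
    for _ in range(16):
        # first step: full prompt, no cache; later steps: only the newest token
        step_ids = input_ids if past_key_values is None else input_ids[:, -1:]
        out = model(input_ids=step_ids,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=True)
        past_key_values = out.past_key_values
        next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)

print(tokenizer.decode(input_ids[0], skip_special_tokens=True))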
 