gugarosa committed on
Commit ecfe56e
1 Parent(s): 759d148

fix(modeling_phi): Fixes cached generation when above maximum context length.

Files changed (1)
  1. modeling_phi.py +13 -22
modeling_phi.py CHANGED
@@ -170,11 +170,11 @@ def _apply_rotary_emb_qkv(
 
 class RotaryEmbedding(nn.Module):
     """Rotary positional embedding (RoPE).
-
+
     Reference:
         RoFormer: Enhanced Transformer with Rotary Position Embedding.
         https://arxiv.org/pdf/2104.09864.pdf.
-
+
     """
 
     def __init__(
@@ -261,32 +261,30 @@ class RotaryEmbedding(nn.Module):
         seqlen_offset: int = 0,
         **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        seq_start = seqlen_offset
-        seq_end = seq_start + qkv.shape[1]
-
         if (
-            self._cos_cached.device != qkv.device
+            self._seq_len_cached < qkv.shape[1] + seqlen_offset
+            or self._cos_cached.device != qkv.device
             or self._cos_cached.dtype != qkv.dtype
             or (self.training and self._cos_cached.is_inference())
         ):
-            self._update_cos_sin_cache(self.max_position_embeddings, device=qkv.device, dtype=qkv.dtype)
+            self._update_cos_sin_cache(qkv.shape[1] + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
 
         if kv is None:
             return _apply_rotary_emb_qkv(
                 qkv,
-                self._cos_cached[seq_start:seq_end],
-                self._sin_cached[seq_start:seq_end],
+                self._cos_cached[seqlen_offset:],
+                self._sin_cached[seqlen_offset:],
             )
         else:
             q = _apply_rotary_emb(
                 qkv,
-                self._cos_cached[seq_start:seq_end],
-                self._sin_cached[seq_start:seq_end],
+                self._cos_cached[seqlen_offset:],
+                self._sin_cached[seqlen_offset:],
             )
             kv = _apply_rotary_emb_kv(
                 kv,
-                self._cos_cached[seq_start:seq_end],
-                self._sin_cached[seq_start:seq_end],
+                self._cos_cached[seqlen_offset:],
+                self._sin_cached[seqlen_offset:],
             )
 
             return q, kv
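
The hunk above is the rotary-embedding half of the fix: the old code always built the cos/sin tables at `self.max_position_embeddings` and indexed them with `[seq_start:seq_end]`, so token positions at or past the maximum context length ran off the end of the tables. The new condition rebuilds the cache whenever `qkv.shape[1] + seqlen_offset` outgrows `self._seq_len_cached`, then slices from `seqlen_offset` onward. Below is a minimal, self-contained sketch of that grow-on-demand pattern; the class name `SimpleRotaryCache` and its constructor arguments are illustrative, and the frequency computation is the standard RoPE formula rather than a copy of the file's `_update_cos_sin_cache`.

import torch


class SimpleRotaryCache:
    """Grow-on-demand cos/sin cache for rotary embeddings (illustrative sketch)."""

    def __init__(self, dim: int, base: float = 10000.0) -> None:
        self.dim = dim
        self.base = base
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_cache(self, seqlen: int, device=None, dtype=torch.float32) -> None:
        # Rebuild the tables up to `seqlen` positions using the standard RoPE frequencies.
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
        t = torch.arange(seqlen, device=device, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)
        self._cos_cached = freqs.cos().to(dtype)
        self._sin_cached = freqs.sin().to(dtype)
        self._seq_len_cached = seqlen

    def forward(self, seqlen: int, seqlen_offset: int, device=None, dtype=torch.float32):
        # The patched condition: grow the cache when the absolute position
        # (new tokens + offset) exceeds what has been cached so far, instead of
        # always building it at a fixed maximum length.
        if self._seq_len_cached < seqlen + seqlen_offset:
            self._update_cos_sin_cache(seqlen + seqlen_offset, device=device, dtype=dtype)
        # Slice from the offset onward, mirroring `self._cos_cached[seqlen_offset:]`.
        return self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:]


# Generating the 2050th token with a nominal 2048-token window: the tables are
# simply extended to 2050 positions instead of running out of entries.
cache = SimpleRotaryCache(dim=64)
cos, sin = cache.forward(seqlen=1, seqlen_offset=2049, device="cpu", dtype=torch.float32)
print(cos.shape)  # torch.Size([1, 32])
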
@@ -498,9 +496,9 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, l
     sequence_end = sequence_start + kv.shape[1]
 
     # When the current sequence length is equal to or larger than the maximum sequence length,
-    # we need to roll the cache to the left and update it
+    # we need to concatenate the current `kv` with the cached `kv` to expand its length
     if sequence_end >= inference_params.max_seqlen:
-        inference_params.key_value_memory_dict[layer_idx] = inference_params.key_value_memory_dict[layer_idx].roll(-(sequence_end - sequence_start), 1)
+        inference_params.key_value_memory_dict[layer_idx] = torch.concatenate((inference_params.key_value_memory_dict[layer_idx], kv), dim=1)
 
     inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...] = kv
     kv = inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, :sequence_end, ...]
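
For the key/value cache, rolling the buffer left kept it at `max_seqlen` entries, so the write into `[..., sequence_start:sequence_end, ...]` on the following lines no longer fit once generation passed the maximum context length. Concatenating the incoming `kv` along the sequence dimension grows the buffer so that write always lands in bounds. A rough sketch of the before/after behaviour, with illustrative names (`append_to_kv_cache`, `cache`, `new_kv`, `seq_start`) and a cache layout of (batch, seqlen, 2, n_heads, head_dim) as suggested by the indexing above:

import torch


def append_to_kv_cache(cache: torch.Tensor, new_kv: torch.Tensor, seq_start: int, max_seqlen: int) -> torch.Tensor:
    seq_end = seq_start + new_kv.shape[1]
    # Old behaviour: `cache.roll(-(seq_end - seq_start), dims=1)` kept the buffer at
    # `max_seqlen` entries, so the in-place write below no longer fit once generation
    # passed the maximum context length.
    if seq_end >= max_seqlen:
        # New behaviour: grow the buffer along the sequence dimension so the write fits.
        cache = torch.concatenate((cache, new_kv), dim=1)
    cache[:, seq_start:seq_end] = new_kv
    return cache


# One decoding step that lands exactly on a 4-token context limit.
batch, max_seqlen, n_heads, head_dim = 1, 4, 2, 8
cache = torch.zeros(batch, max_seqlen, 2, n_heads, head_dim)
step_kv = torch.randn(batch, 1, 2, n_heads, head_dim)
cache = append_to_kv_cache(cache, step_kv, seq_start=4, max_seqlen=max_seqlen)
print(cache.shape)  # torch.Size([1, 5, 2, 2, 8])
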
@@ -864,13 +862,6 @@ class PhiPreTrainedModel(PreTrainedModel):
         attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
         **kwargs,
     ) -> Dict[str, Any]:
-        # Truncate `input_ids` and `attention_mask` (if necessary) to prevent exceeding
-        # the maximum sequence length
-        if input_ids.shape[1] > self.config.n_positions:
-            input_ids = input_ids[:, -self.config.n_positions :]
-            if attention_mask is not None:
-                attention_mask = attention_mask[:, -self.config.n_positions :]
-
         if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
             past_key_values = InferenceParams(
                 max_seqlen=self.config.n_positions,
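
With the rotary tables and the key/value cache now able to grow past `n_positions`, `prepare_inputs_for_generation` no longer needs to drop the oldest tokens, so the truncation block is removed instead of silently discarding context. A hypothetical end-to-end check of the commit's claim; the checkpoint name, `n_positions` value, and generation length are assumptions, not part of this change:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumption: a checkpoint that ships this modeling_phi.py via trust_remote_code,
# e.g. one of the Phi family repositories, with n_positions assumed to be 2048.
model_id = "microsoft/phi-1_5"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, trust_remote_code=True)

inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
# Asking for more tokens than n_positions exercises the three code paths patched above;
# per the commit message, cached generation previously broke once it crossed that limit.
outputs = model.generate(**inputs, max_length=2060, use_cache=True)
print(tokenizer.decode(outputs[0]))
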
 