simonJJJ committed on
Commit 0288b7e
1 parent: a3d284e

Update modeling_qwen.py

Files changed (1): modeling_qwen.py (+115, -82)
modeling_qwen.py CHANGED
@@ -108,14 +108,6 @@ class QWenAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
 
-        max_positions = config.max_position_embeddings
-        self.register_buffer(
-            "bias",
-            torch.tril(
-                torch.ones((max_positions, max_positions), dtype=torch.bool)
-            ).view(1, 1, max_positions, max_positions),
-            persistent=False,
-        )
         self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
         self.seq_length = config.seq_length
 
@@ -142,20 +134,6 @@ class QWenAttention(nn.Module):
         self.is_fp32 = not (config.bf16 or config.fp16)
         self.bf16 = config.bf16
 
-        if config.rotary_pct == 1.0:
-            self.rotary_ndims = None
-        else:
-            assert config.rotary_pct < 1
-            self.rotary_ndims = int(
-                self.hidden_size_per_attention_head * config.rotary_pct
-            )
-        dim = (
-            self.rotary_ndims
-            if self.rotary_ndims is not None
-            else self.hidden_size_per_attention_head
-        )
-        self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
-
         self.use_dynamic_ntk = config.use_dynamic_ntk
         self.use_logn_attn = config.use_logn_attn
 
@@ -164,11 +142,10 @@ class QWenAttention(nn.Module):
             for i in range(1, 32768)
         ]
         self.logn_tensor = torch.tensor(logn_list)[None, :, None, None]
-        self._ntk_cached = 1.0
 
         self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
 
-    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
+    def _attn(self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None):
         attn_weights = torch.matmul(query, key.transpose(-1, -2))
 
         if self.scale_attn_weights:
@@ -206,7 +183,7 @@ class QWenAttention(nn.Module):
         return attn_output, attn_weights
 
     def _upcast_and_reordered_attn(
-        self, query, key, value, attention_mask=None, head_mask=None
+        self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None
     ):
         bsz, num_heads, q_seq_len, dk = query.size()
         _, _, k_seq_len, _ = key.size()
@@ -233,7 +210,7 @@ class QWenAttention(nn.Module):
         attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
 
         query_length, key_length = query.size(-2), key.size(-2)
-        causal_mask = self.bias[
+        causal_mask = registered_causal_mask[
             :, :, key_length - query_length : key_length, :key_length
         ]
         mask_value = torch.finfo(attn_weights.dtype).min
@@ -274,6 +251,8 @@ class QWenAttention(nn.Module):
     def forward(
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
+        rotary_pos_emb: Optional[List[torch.Tensor]] = None,
+        registered_causal_mask: Optional[torch.Tensor] = None,
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
@@ -284,43 +263,19 @@ class QWenAttention(nn.Module):
     ):
 
         mixed_x_layer = self.c_attn(hidden_states)
+
         query, key, value = mixed_x_layer.split(self.split_size, dim=2)
 
         query = self._split_heads(query, self.num_heads, self.head_dim)
         key = self._split_heads(key, self.num_heads, self.head_dim)
         value = self._split_heads(value, self.num_heads, self.head_dim)
 
-        kv_seq_len = hidden_states.size()[1]
-        if layer_past:
-            # layer past[0] shape: bs * seq_len * head_num * dim
-            kv_seq_len += layer_past[0].shape[1]
-        if (
-            self.use_dynamic_ntk
-            and kv_seq_len == hidden_states.size()[1]
-            and not self.training
-        ):
-            context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
-            ntk_alpha = 2 ** math.ceil(context_value) - 1
-            ntk_alpha = max(ntk_alpha, 1)
-            self._ntk_cached = ntk_alpha
-        else:
-            ntk_alpha = self._ntk_cached
-        rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha).to(
-            hidden_states.device
-        )
-
-        if rotary_pos_emb is not None:
-            if isinstance(rotary_pos_emb, tuple):
-                rotary_pos_emb = rotary_pos_emb
-            else:
-                rotary_pos_emb = (rotary_pos_emb,) * 2
-
         if rotary_pos_emb is not None:
+            cur_len = query.shape[1]
+            rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
+            rotary_pos_emb = (rotary_pos_emb,) * 2
             q_pos_emb, k_pos_emb = rotary_pos_emb
             # Slice the pos emb for current inference
-            cur_len = query.shape[1]
-            q_pos_emb = q_pos_emb[:, -cur_len:, :, :]
-            k_pos_emb = k_pos_emb[:, -cur_len:, :, :]
            query = apply_rotary_pos_emb(query, q_pos_emb)
             key = apply_rotary_pos_emb(key, k_pos_emb)
 
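Note: after this hunk, QWenAttention no longer builds rotary tables itself; it receives a precomputed [cos, sin] pair from QWenModel and keeps only the trailing cur_len positions, so incremental decoding rotates just the new tokens. A minimal standalone sketch of that slicing (shapes and tensor names are illustrative, not part of this diff):

```python
import torch

# Assume cos/sin were built once for the full kv_seq_len with shape
# (1, kv_seq_len, 1, rotary_dim), as QWenModel.forward now does.
kv_seq_len, rotary_dim = 12, 8
cos = torch.randn(1, kv_seq_len, 1, rotary_dim)
sin = torch.randn(1, kv_seq_len, 1, rotary_dim)
rotary_pos_emb = [cos, sin]

# During incremental decoding the query only covers the newest token(s),
# so only the last cur_len positions of the tables are kept.
cur_len = 1
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
q_pos_emb, k_pos_emb = (rotary_pos_emb,) * 2
print(q_pos_emb[0].shape)  # torch.Size([1, 1, 1, 8])
```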
@@ -346,13 +301,14 @@ class QWenAttention(nn.Module):
         key = key.permute(0, 2, 1, 3)
         value = value.permute(0, 2, 1, 3)
         attn_output, attn_weight = self._attn(
-            query, key, value, attention_mask, head_mask
+            query, key, value, registered_causal_mask, attention_mask, head_mask
         )
         context_layer = self._merge_heads(
             attn_output, self.num_heads, self.head_dim
         )
 
         attn_output = self.c_proj(context_layer)
+
         outputs = (attn_output, present)
         if output_attentions:
             outputs += (attn_weight,)
@@ -379,7 +335,6 @@ class QWenMLP(nn.Module):
         output = self.c_proj(intermediate_parallel)
         return output
 
-
 class QWenBlock(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -401,6 +356,8 @@ class QWenBlock(nn.Module):
     def forward(
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
+        rotary_pos_emb: Optional[List[torch.Tensor]] = None,
+        registered_causal_mask: Optional[torch.Tensor] = None,
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
@@ -413,6 +370,8 @@ class QWenBlock(nn.Module):
 
         attn_outputs = self.attn(
             layernorm_output,
+            rotary_pos_emb,
+            registered_causal_mask=registered_causal_mask,
             layer_past=layer_past,
             attention_mask=attention_mask,
             head_mask=head_mask,
@@ -488,14 +447,50 @@ class QWenModel(QWenPreTrainedModel):
         self.embed_dim = config.hidden_size
 
         self.gradient_checkpointing = False
+        self.use_dynamic_ntk = config.use_dynamic_ntk
+        self.seq_length = config.seq_length
 
         self.wte = nn.Embedding(self.vocab_size, self.embed_dim)
 
         self.drop = nn.Dropout(config.emb_dropout_prob)
+
+        if config.rotary_pct == 1.0:
+            self.rotary_ndims = None
+        else:
+            assert config.rotary_pct < 1
+            self.rotary_ndims = int(
+                config.kv_channels * config.rotary_pct
+            )
+        dim = (
+            self.rotary_ndims
+            if self.rotary_ndims is not None
+            else config.kv_channels
+        )
+        self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
+
+        self.use_flash_attn = config.use_flash_attn
+        self.is_fp32 = not (config.bf16 or config.fp16)
+        self.registered_causal_mask = None
+        # if (
+        #     self.use_flash_attn
+        #     and flash_attn_unpadded_func is not None
+        #     and not self.is_fp32
+        # ):
+        #     self.registered_causal_mask = None
+        # else:
+        #     max_positions = config.max_position_embeddings
+        #     self.register_buffer(
+        #         "registered_causal_mask",
+        #         torch.tril(
+        #             torch.ones((max_positions, max_positions), dtype=torch.bool)
+        #         ).view(1, 1, max_positions, max_positions),
+        #         persistent=False,
+        #     )
+
         self.h = nn.ModuleList(
             [
                 QWenBlock(
-                    config,
+                    config
                 )
                 for i in range(config.num_hidden_layers)
             ]
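Note: the causal-mask buffer moves from QWenAttention into QWenModel, and in this revision it is left as None (the buffer registration is kept commented out). For reference, a standalone sketch of the mask layout that commented-out code would register, and how the `_attn`/`_upcast_and_reordered_attn` hunks above slice it (sizes are illustrative):

```python
import torch

max_positions = 16
# Shape (1, 1, max_positions, max_positions), True on and below the diagonal.
registered_causal_mask = torch.tril(
    torch.ones((max_positions, max_positions), dtype=torch.bool)
).view(1, 1, max_positions, max_positions)

# Slice for a decoding step: query_length new tokens against key_length total keys.
query_length, key_length = 1, 10
causal_mask = registered_causal_mask[
    :, :, key_length - query_length : key_length, :key_length
]
print(causal_mask.shape)  # torch.Size([1, 1, 1, 10])
```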
@@ -556,7 +551,7 @@ class QWenModel(QWenPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ):
-        if past_key_values is None and input_ids is not None and torch.any(input_ids == self.config.visual['image_start_id']):
+        if past_key_values is None and torch.any(input_ids == self.config.visual['image_start_id']):
             bos_pos = torch.where(input_ids == self.config.visual['image_start_id'])
             eos_pos = torch.where(input_ids == self.config.visual['image_start_id'] + 1)
             assert (bos_pos[0] == eos_pos[0]).all()
@@ -637,6 +632,25 @@ class QWenModel(QWenPreTrainedModel):
 
         hidden_states = inputs_embeds
 
+        kv_seq_len = hidden_states.size()[1]
+        if past_key_values[0] is not None:
+            # past key values[0][0] shape: bs * seq_len * head_num * dim
+            kv_seq_len += past_key_values[0][0].shape[1]
+        if (
+            self.use_dynamic_ntk
+            and kv_seq_len == hidden_states.size()[1]
+            and not self.training
+        ):
+            context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
+            ntk_alpha = 2 ** math.ceil(context_value) - 1
+            ntk_alpha = max(ntk_alpha, 1)
+        else:
+            ntk_alpha = self.rotary_emb._ntk_alpha_cached
+
+        rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha)
+        for idx in range(len(rotary_pos_emb)):
+            rotary_pos_emb[idx] = rotary_pos_emb[idx].to(hidden_states.device)
+
         hidden_states = self.drop(hidden_states)
         if images is not None:
             for idx, (i, a, b) in enumerate(img_pos):
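Note: the dynamic-NTK alpha is now computed once per forward pass in QWenModel rather than per attention layer. A minimal sketch of the alpha schedule added above, written as a standalone helper (the function name and the 2048 training length are illustrative assumptions):

```python
import math

def get_ntk_alpha(kv_seq_len: int, train_seq_length: int) -> int:
    # Mirrors the logic added above: once the context exceeds the training
    # sequence length, alpha grows along the 2**ceil(...) - 1 schedule.
    context_value = math.log(kv_seq_len / train_seq_length, 2) + 1
    ntk_alpha = 2 ** math.ceil(context_value) - 1
    return max(ntk_alpha, 1)

print(get_ntk_alpha(1024, 2048))  # 1 (within the training length)
print(get_ntk_alpha(4096, 2048))  # 3
print(get_ntk_alpha(8192, 2048))  # 7
```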
@@ -670,6 +684,8 @@ class QWenModel(QWenPreTrainedModel):
                 outputs = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(block),
                     hidden_states,
+                    rotary_pos_emb,
+                    self.registered_causal_mask,
                     None,
                     attention_mask,
                     head_mask[i],
@@ -680,6 +696,8 @@ class QWenModel(QWenPreTrainedModel):
                 outputs = block(
                     hidden_states,
                     layer_past=layer_past,
+                    rotary_pos_emb=rotary_pos_emb,
+                    registered_causal_mask=self.registered_causal_mask,
                     attention_mask=attention_mask,
                     head_mask=head_mask[i],
                     encoder_hidden_states=encoder_hidden_states,
@@ -690,10 +708,10 @@ class QWenModel(QWenPreTrainedModel):
 
             hidden_states = outputs[0]
             if use_cache is True:
-                presents = presents + (outputs[2 if output_attentions else 1],)
+                presents = presents + (outputs[1],)
 
             if output_attentions:
-                all_self_attentions = all_self_attentions + (outputs[1],)
+                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
 
         hidden_states = self.ln_f(hidden_states)
         hidden_states = hidden_states.view(output_shape)
@@ -890,10 +908,13 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         append_history: bool = True,
         stream: Optional[bool] = _SENTINEL,
         stop_words_ids: Optional[List[List[int]]] = None,
+        generation_config: Optional[GenerationConfig] = None,
         **kwargs,
     ) -> Tuple[str, HistoryType]:
+        generation_config = generation_config if generation_config is not None else self.generation_config
+
         assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT
-        assert self.generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
+        assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
         if history is None:
             history = []
         if stop_words_ids is None:
@@ -901,24 +922,25 @@ class QWenLMHeadModel(QWenPreTrainedModel):
 
         max_window_size = kwargs.get('max_window_size', None)
         if max_window_size is None:
-            max_window_size = self.generation_config.max_window_size
+            max_window_size = generation_config.max_window_size
         raw_text, context_tokens = make_context(
             tokenizer,
             query,
             history=history,
             system=system,
             max_window_size=max_window_size,
-            chat_format=self.generation_config.chat_format,
+            chat_format=generation_config.chat_format,
         )
 
         stop_words_ids.extend(get_stop_words_ids(
-            self.generation_config.chat_format, tokenizer
+            generation_config.chat_format, tokenizer
         ))
         input_ids = torch.tensor([context_tokens]).to(self.device)
         outputs = self.generate(
             input_ids,
-            stop_words_ids = stop_words_ids,
-            return_dict_in_generate = False,
+            stop_words_ids=stop_words_ids,
+            return_dict_in_generate=False,
+            generation_config=generation_config,
             **kwargs,
         )
 
@@ -927,7 +949,7 @@ class QWenLMHeadModel(QWenPreTrainedModel):
             tokenizer,
             raw_text_len=len(raw_text),
             context_length=len(context_tokens),
-            chat_format=self.generation_config.chat_format,
+            chat_format=generation_config.chat_format,
             verbose=False,
             errors='replace'
         )
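Note: with generation_config threaded through chat, chat_stream, and generate, callers can pass per-call generation settings instead of mutating self.generation_config. A usage sketch; the checkpoint name, loading flags, and prompt are assumptions for illustration, not part of this diff:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig

model_id = "Qwen/Qwen-VL-Chat"  # assumed target checkpoint for this modeling file
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="auto", trust_remote_code=True
).eval()

# Per-call sampling settings, instead of editing model.generation_config in place.
gen_cfg = GenerationConfig.from_pretrained(model_id, trust_remote_code=True)
gen_cfg.top_p = 0.5

response, history = model.chat(
    tokenizer,
    "Describe the image.",
    history=None,
    generation_config=gen_cfg,
)
```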
@@ -945,9 +967,11 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         system: str = "You are a helpful assistant.",
         stop_words_ids: Optional[List[List[int]]] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
+        generation_config: Optional[GenerationConfig] = None,
         **kwargs,
     ) -> Generator[str, Any, None]:
-        assert self.generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
+        generation_config = generation_config if generation_config is not None else self.generation_config
+        assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
         if history is None:
             history = []
         if stop_words_ids is None:
@@ -955,23 +979,23 @@ class QWenLMHeadModel(QWenPreTrainedModel):
 
         max_window_size = kwargs.get('max_window_size', None)
         if max_window_size is None:
-            max_window_size = self.generation_config.max_window_size
+            max_window_size = generation_config.max_window_size
         raw_text, context_tokens = make_context(
             tokenizer,
             query,
             history=history,
             system=system,
             max_window_size=max_window_size,
-            chat_format=self.generation_config.chat_format,
+            chat_format=generation_config.chat_format,
         )
 
         stop_words_ids.extend(get_stop_words_ids(
-            self.generation_config.chat_format, tokenizer
+            generation_config.chat_format, tokenizer
         ))
         if stop_words_ids is not None:
             stop_words_logits_processor = StopWordsLogitsProcessor(
                 stop_words_ids=stop_words_ids,
-                eos_token_id=self.generation_config.eos_token_id,
+                eos_token_id=generation_config.eos_token_id,
             )
             if logits_processor is None:
                 logits_processor = LogitsProcessorList([stop_words_logits_processor])
@@ -982,7 +1006,8 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
         self.__class__.generate_stream = NewGenerationMixin.generate
         self.__class__.sample_stream = NewGenerationMixin.sample_stream
-        stream_config = StreamGenerationConfig(**self.generation_config.to_dict(), do_stream=True)
+        stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)
+
         def stream_generator():
             outputs = []
             for token in self.generate_stream(
@@ -1011,17 +1036,19 @@ class QWenLMHeadModel(QWenPreTrainedModel):
         streamer: Optional["BaseStreamer"] = None,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
+        generation_config = generation_config if generation_config is not None else self.generation_config
+
         # Process stop_words_ids.
         stop_words_ids = kwargs.pop("stop_words_ids", None)
         if stop_words_ids is None and generation_config is not None:
             stop_words_ids = getattr(generation_config, "stop_words_ids", None)
         if stop_words_ids is None:
-            stop_words_ids = getattr(self.generation_config, "stop_words_ids", None)
+            stop_words_ids = getattr(generation_config, "stop_words_ids", None)
 
         if stop_words_ids is not None:
             stop_words_logits_processor = StopWordsLogitsProcessor(
                 stop_words_ids=stop_words_ids,
-                eos_token_id=self.generation_config.eos_token_id,
+                eos_token_id=generation_config.eos_token_id,
             )
             if logits_processor is None:
                 logits_processor = LogitsProcessorList([stop_words_logits_processor])
@@ -1069,14 +1096,19 @@ class RotaryEmbedding(torch.nn.Module):
             self._ntk_alpha_cached = ntk_alpha
             seq = torch.arange(self._seq_len_cached, device=self.inv_freq.device)
             freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
+
             emb = torch.cat((freqs, freqs), dim=-1)
             from einops import rearrange
 
-            self._rotary_pos_emb_cache = rearrange(emb, "n d -> 1 n 1 d")
+            emb = rearrange(emb, "n d -> 1 n 1 d")
+
+            cos, sin = emb.cos(), emb.sin()
+            self._rotary_pos_emb_cache = [cos, sin]
 
     def forward(self, max_seq_len, offset=0, ntk_alpha=1.0):
         self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha)
-        return self._rotary_pos_emb_cache[:, offset : offset + max_seq_len]
+        cos, sin = self._rotary_pos_emb_cache
+        return [cos[:, offset : offset + max_seq_len], sin[:, offset : offset + max_seq_len]]
 
 
 def _rotate_half(x):
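Note: the rotary cache changes from a single raw angle tensor to a [cos, sin] pair, and forward() now returns ready-to-use slices of both tables. A standalone sketch of the new cache layout (dim, base, and length are illustrative; the NTK-aware base scaling done earlier in update_rotary_pos_emb_cache is omitted here):

```python
import torch
from einops import rearrange

dim, base, max_seq_len = 128, 10000.0, 32
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
seq = torch.arange(max_seq_len)
freqs = torch.outer(seq.float(), inv_freq)          # (max_seq_len, dim // 2)
emb = rearrange(torch.cat((freqs, freqs), dim=-1), "n d -> 1 n 1 d")
cos, sin = emb.cos(), emb.sin()                     # cached as [cos, sin]
print(cos.shape, sin.shape)                         # torch.Size([1, 32, 1, 128]) for both
```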
@@ -1088,19 +1120,20 @@ def _rotate_half(x):
 
 
 def apply_rotary_pos_emb(t, freqs):
+    cos, sin = freqs
     if apply_rotary_emb_func is not None and t.is_cuda:
         t_ = t.float()
-        freqs = freqs.squeeze(0).squeeze(1)
-        cos = freqs[:, : freqs.shape[-1] // 2].cos()
-        sin = freqs[:, : freqs.shape[-1] // 2].sin()
+        cos = cos.squeeze(0).squeeze(1)[:, : cos.shape[-1] // 2]
+        sin = sin.squeeze(0).squeeze(1)[:, : sin.shape[-1] // 2]
         output = apply_rotary_emb_func(t_, cos, sin).type_as(t)
         return output
     else:
-        rot_dim = freqs.shape[-1]
+        rot_dim = freqs[0].shape[-1]
+        cos, sin = freqs
         t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
         t_ = t_.float()
         t_pass_ = t_pass_.float()
-        t_ = (t_ * freqs.cos()) + (_rotate_half(t_) * freqs.sin())
+        t_ = (t_ * cos) + (_rotate_half(t_) * sin)
         return torch.cat((t_, t_pass_), dim=-1).type_as(t)
 
 
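Note: apply_rotary_pos_emb now expects freqs to be a [cos, sin] pair rather than a raw angle tensor. A self-contained sketch of the fallback (non-CUDA) path under that format; helper names and shapes here are illustrative, not the module's own:

```python
import torch

def rotate_half(x):
    # Functionally equivalent to the module's _rotate_half for an even rotary dim.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(t, freqs):
    cos, sin = freqs                        # freqs is a [cos, sin] pair
    rot_dim = cos.shape[-1]
    t_rot, t_pass = t[..., :rot_dim], t[..., rot_dim:]
    t_rot = (t_rot * cos) + (rotate_half(t_rot) * sin)
    return torch.cat((t_rot, t_pass), dim=-1)

# (batch, seq, heads, head_dim) query against broadcastable (1, seq, 1, dim) tables.
q = torch.randn(2, 32, 16, 128)
angles = torch.outer(torch.arange(32).float(),
                     1.0 / (10000 ** (torch.arange(0, 128, 2).float() / 128)))
emb = torch.cat((angles, angles), dim=-1).view(1, 32, 1, 128)
print(apply_rotary(q, [emb.cos(), emb.sin()]).shape)  # torch.Size([2, 32, 16, 128])
```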
 