x54-729 commited on
Commit
2765861
1 Parent(s): 7e775a2

Support dynamic ntk rope

Browse files
Files changed (1) hide show
  1. modeling_internlm.py +163 -73
modeling_internlm.py CHANGED
@@ -19,26 +19,36 @@
19
  # limitations under the License.
20
  """ PyTorch InternLM model."""
21
  import math
 
 
22
  from typing import List, Optional, Tuple, Union
23
- import threading, queue
24
 
25
  import torch
26
  import torch.utils.checkpoint
27
  from torch import nn
28
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
-
30
  from transformers.activations import ACT2FN
31
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
32
- from transformers.modeling_utils import PreTrainedModel
33
  from transformers.generation.streamers import BaseStreamer
34
- from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
35
- from .configuration_internlm import InternLMConfig
 
 
 
 
 
 
 
 
 
 
36
 
 
37
 
38
  logger = logging.get_logger(__name__)
39
 
40
  _CONFIG_FOR_DOC = "InternLMConfig"
41
 
 
42
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
43
  def _make_causal_mask(
44
  input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
@@ -73,6 +83,8 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
73
 
74
 
75
  class InternLMRMSNorm(nn.Module):
 
 
76
  def __init__(self, hidden_size, eps=1e-6):
77
  """
78
  InternLMRMSNorm is equivalent to T5LayerNorm
@@ -93,6 +105,14 @@ class InternLMRMSNorm(nn.Module):
93
 
94
 
95
  class InternLMRotaryEmbedding(torch.nn.Module):
 
 
 
 
 
 
 
 
96
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
97
  super().__init__()
98
  inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
@@ -124,6 +144,66 @@ class InternLMRotaryEmbedding(torch.nn.Module):
124
  )
125
 
126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  def rotate_half(x):
128
  """Rotates half the hidden dims of the input."""
129
  x1 = x[..., : x.shape[-1] // 2]
@@ -135,10 +215,18 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
135
  # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
136
  cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
137
  sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
138
- cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
139
- sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
140
- q_embed = (q * cos) + (rotate_half(q) * sin)
141
- k_embed = (k * cos) + (rotate_half(k) * sin)
 
 
 
 
 
 
 
 
142
  return q_embed, k_embed
143
 
144
 
@@ -179,7 +267,25 @@ class InternLMAttention(nn.Module):
179
  self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
180
  self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
181
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
182
- self.rotary_emb = InternLMRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
185
  return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
@@ -199,20 +305,18 @@ class InternLMAttention(nn.Module):
199
  key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
200
  value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
201
 
202
- kv_seq_len = key_states.shape[-2]
203
- if past_key_value is not None:
204
- kv_seq_len += past_key_value[0].shape[-2]
205
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
206
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
207
- # [bsz, nh, t, hd]
208
-
209
  if past_key_value is not None:
210
  # reuse k, v, self_attention
211
  key_states = torch.cat([past_key_value[0], key_states], dim=2)
212
  value_states = torch.cat([past_key_value[1], value_states], dim=2)
213
 
 
214
  past_key_value = (key_states, value_states) if use_cache else None
215
 
 
 
 
 
216
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
217
 
218
  if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
@@ -322,11 +426,9 @@ INTERNLM_START_DOCSTRING = r"""
322
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
323
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
324
  etc.)
325
-
326
  This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
327
  Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
328
  and behavior.
329
-
330
  Parameters:
331
  config ([`InternLMConfig`]):
332
  Model configuration class with all the parameters of the model. Initializing with a config file does not
@@ -367,44 +469,34 @@ INTERNLM_INPUTS_DOCSTRING = r"""
367
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
368
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
369
  it.
370
-
371
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
372
  [`PreTrainedTokenizer.__call__`] for details.
373
-
374
  [What are input IDs?](../glossary#input-ids)
375
  attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
376
  Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
377
-
378
  - 1 for tokens that are **not masked**,
379
  - 0 for tokens that are **masked**.
380
-
381
  [What are attention masks?](../glossary#attention-mask)
382
-
383
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
384
  [`PreTrainedTokenizer.__call__`] for details.
385
-
386
  If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
387
  `past_key_values`).
388
-
389
  If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
390
  and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
391
  information on the default strategy.
392
-
393
  - 1 indicates the head is **not masked**,
394
  - 0 indicates the head is **masked**.
395
  position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
396
  Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
397
  config.n_positions - 1]`.
398
-
399
  [What are position IDs?](../glossary#position-ids)
400
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
 
401
  Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
402
  `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
403
  `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
404
-
405
  Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
406
  blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
407
-
408
  If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
409
  don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
410
  `decoder_input_ids` of shape `(batch_size, sequence_length)`.
@@ -433,10 +525,10 @@ INTERNLM_INPUTS_DOCSTRING = r"""
433
  class InternLMModel(InternLMPreTrainedModel):
434
  """
435
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`InternLMDecoderLayer`]
436
-
437
  Args:
438
  config: InternLMConfig
439
  """
 
440
  _auto_class = "AutoModel"
441
 
442
  def __init__(self, config: InternLMConfig):
@@ -662,20 +754,14 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
662
  Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
663
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
664
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
665
-
666
  Returns:
667
-
668
  Example:
669
-
670
  ```python
671
  >>> from transformers import AutoTokenizer, InternLMForCausalLM
672
-
673
  >>> model = InternLMForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
674
  >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
675
-
676
  >>> prompt = "Hey, are you consciours? Can you talk to me?"
677
  >>> inputs = tokenizer(prompt, return_tensors="pt")
678
-
679
  >>> # Generate
680
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
681
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
@@ -765,50 +851,56 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
765
  for layer_past in past_key_values:
766
  reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
767
  return reordered_past
768
-
769
  def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = []):
770
  prompt = ""
771
  for record in history:
772
  prompt += f"""<|User|>:{record[0]}<eoh>\n<|Bot|>:{record[1]}<eoa>\n"""
773
  prompt += f"""<|User|>:{query}<eoh>\n<|Bot|>:"""
774
  return tokenizer([prompt], return_tensors="pt")
775
-
776
  @torch.no_grad()
777
- def chat(self,
778
- tokenizer,
779
- query: str,
780
- history: List[Tuple[str, str]] = [],
781
- streamer: Optional[BaseStreamer] = None,
782
- max_new_tokens: int = 1024,
783
- do_sample: bool = True,
784
- temperature: float = 0.8,
785
- top_p: float = 0.8,
786
- **kwargs):
 
 
787
  inputs = self.build_inputs(tokenizer, query, history)
788
  inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
789
- outputs = self.generate(**inputs,
790
- streamer=streamer,
791
- max_new_tokens=max_new_tokens,
792
- do_sample=do_sample,
793
- temperature=temperature,
794
- top_p=top_p,
795
- **kwargs)
796
- outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]):]
 
 
797
  response = tokenizer.decode(outputs, skip_special_tokens=True)
798
  response = response.split("<eoa>")[0]
799
  history = history + [(query, response)]
800
  return response, history
801
-
802
  @torch.no_grad()
803
- def stream_chat(self,
804
- tokenizer,
805
- query: str,
806
- history: List[Tuple[str, str]] = [],
807
- max_new_tokens: int = 1024,
808
- do_sample: bool = True,
809
- temperature: float = 0.8,
810
- top_p: float = 0.8,
811
- **kwargs):
 
 
812
  """
813
  Return a generator in format: (response, history)
814
  Eg.
@@ -854,12 +946,12 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
854
  tokenizer=tokenizer,
855
  query=query,
856
  streamer=ChatStreamer(tokenizer=tokenizer),
857
- history=history,
858
  max_new_tokens=max_new_tokens,
859
  do_sample=do_sample,
860
  temperature=temperature,
861
  top_p=top_p,
862
- **kwargs
863
  )
864
 
865
  def consumer():
@@ -877,10 +969,8 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
877
  @add_start_docstrings(
878
  """
879
  The InternLM Model transformer with a sequence classification head on top (linear layer).
880
-
881
  [`InternLMForSequenceClassification`] uses the last token in order to do the classification, as other causal models
882
  (e.g. GPT-2) do.
883
-
884
  Since it does classification on the last token, it requires to know the position of the last token. If a
885
  `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
886
  no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
 
19
  # limitations under the License.
20
  """ PyTorch InternLM model."""
21
  import math
22
+ import queue
23
+ import threading
24
  from typing import List, Optional, Tuple, Union
 
25
 
26
  import torch
27
  import torch.utils.checkpoint
28
  from torch import nn
29
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
30
  from transformers.activations import ACT2FN
 
 
31
  from transformers.generation.streamers import BaseStreamer
32
+ from transformers.modeling_outputs import (
33
+ BaseModelOutputWithPast,
34
+ CausalLMOutputWithPast,
35
+ SequenceClassifierOutputWithPast,
36
+ )
37
+ from transformers.modeling_utils import PreTrainedModel
38
+ from transformers.utils import (
39
+ add_start_docstrings,
40
+ add_start_docstrings_to_model_forward,
41
+ logging,
42
+ replace_return_docstrings,
43
+ )
44
 
45
+ from .configuration_internlm import InternLMConfig
46
 
47
  logger = logging.get_logger(__name__)
48
 
49
  _CONFIG_FOR_DOC = "InternLMConfig"
50
 
51
+
52
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
53
  def _make_causal_mask(
54
  input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
 
83
 
84
 
85
  class InternLMRMSNorm(nn.Module):
86
+ """RMSNorm implemention."""
87
+
88
  def __init__(self, hidden_size, eps=1e-6):
89
  """
90
  InternLMRMSNorm is equivalent to T5LayerNorm
 
105
 
106
 
107
  class InternLMRotaryEmbedding(torch.nn.Module):
108
+ """Implement InternLM's rotary embedding.
109
+
110
+ Args:
111
+ dim (int): Characteristic dimension of each self-attentional head.
112
+ max_position_embeddings (int, optional): Model's training length. Defaults to 2048.
113
+ base (int, optional): The rotation position encodes the rotation Angle base number. Defaults to 10000.
114
+ device (Any, optional): Running device. Defaults to None.
115
+ """
116
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
117
  super().__init__()
118
  inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
 
144
  )
145
 
146
 
147
+ class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module):
148
+ """Implement InternLM's DyanmicNTK extrapolation method, thereby broadening the model support context to 16K.
149
+
150
+ Args:
151
+ dim (int): Characteristic dimension of each self-attentional head.
152
+ max_position_embeddings (int, optional): Model's training length. Defaults to 2048.
153
+ base (int, optional): The rotation position encodes the rotation Angle base number. Defaults to 10000.
154
+ device (Any, optional): Running device. Defaults to None.
155
+ scaling_factor (float, optional): NTK method extrapolation coefficient. Defaults to 1.0.
156
+ """
157
+
158
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
159
+ super().__init__()
160
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
161
+ self.register_buffer("inv_freq", inv_freq)
162
+ self.dim = dim
163
+ self.base = base
164
+ self.scaling_factor = scaling_factor
165
+
166
+ # Build here to make `torch.jit.trace` work.
167
+ self.max_position_embeddings = max_position_embeddings
168
+ self.max_seq_len_cached = max_position_embeddings
169
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
170
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
171
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
172
+ emb = torch.cat((freqs, freqs), dim=-1)
173
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
174
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
175
+
176
+ def _update_cached(self, x, seq_len=None):
177
+ self.max_seq_len_cached = max(seq_len, self.max_position_embeddings)
178
+ if seq_len > self.max_position_embeddings:
179
+ base = self.base * (
180
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
181
+ ) ** (self.dim / (self.dim - 2))
182
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
183
+ else:
184
+ inv_freq = self.inv_freq
185
+ t = torch.arange(self.max_seq_len_cached, device=inv_freq.device, dtype=inv_freq.dtype)
186
+ freqs = torch.einsum("i,j->ij", t, inv_freq)
187
+ emb = torch.cat((freqs, freqs), dim=-1)
188
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
189
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
190
+
191
+ def forward(self, x, seq_len=None):
192
+ # x: [bs, num_attention_heads, seq_len, head_size]
193
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
194
+ if seq_len <= self.max_position_embeddings:
195
+ # Reset the tables if the sequence length has changed,
196
+ if self.max_seq_len_cached > self.max_position_embeddings:
197
+ self._update_cached(x, seq_len)
198
+ else:
199
+ self._update_cached(x, seq_len)
200
+
201
+ return (
202
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
203
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
204
+ )
205
+
206
+
207
  def rotate_half(x):
208
  """Rotates half the hidden dims of the input."""
209
  x1 = x[..., : x.shape[-1] // 2]
 
215
  # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
216
  cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
217
  sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
218
+ cos = cos.unsqueeze(0).unsqueeze(0).expand(len(position_ids), -1, -1, -1)
219
+ sin = sin.unsqueeze(0).unsqueeze(0).expand(len(position_ids), -1, -1, -1)
220
+ if q.size(2) == 1:
221
+ q_embed = (q * cos[:, :, -1, :]) + (rotate_half(q) * sin[:, :, -1, :])
222
+ else:
223
+ q_embed = (q * cos) + (rotate_half(q) * sin)
224
+
225
+ if k.size(2) == 1:
226
+ k_embed = (k * cos[:, :, -1, :]) + (rotate_half(k) * sin[:, :, -1, :])
227
+ else:
228
+ k_embed = (k * cos) + (rotate_half(k) * sin)
229
+
230
  return q_embed, k_embed
231
 
232
 
 
267
  self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
268
  self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
269
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
270
+ self.rotary_emb = self._init_rope()
271
+
272
+ def _init_rope(self):
273
+ if self.config.rotary["type"] == "origin":
274
+ self.rotary_emb = InternLMRotaryEmbedding(
275
+ self.head_dim,
276
+ max_position_embeddings=self.max_position_embeddings,
277
+ base=self.config.rotary["base"],
278
+ )
279
+ elif self.config.rotary["type"] == "dynamic":
280
+ self.rotary_emb = InternLMDynamicNTKScalingRotaryEmbedding(
281
+ self.head_dim,
282
+ max_position_embeddings=self.max_position_embeddings,
283
+ base=self.config.rotary["base"],
284
+ scaling_factor=self.config.rotary.get("scaling_factor", 1.0),
285
+ )
286
+ else:
287
+ raise ValueError("Currently we only support rotary embedding's type being one of ('origin', 'dynamic').")
288
+ return self.rotary_emb
289
 
290
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
291
  return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
 
305
  key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
306
  value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
307
 
 
 
 
 
 
 
 
308
  if past_key_value is not None:
309
  # reuse k, v, self_attention
310
  key_states = torch.cat([past_key_value[0], key_states], dim=2)
311
  value_states = torch.cat([past_key_value[1], value_states], dim=2)
312
 
313
+ # print(use_cache)
314
  past_key_value = (key_states, value_states) if use_cache else None
315
 
316
+ kv_seq_len = key_states.shape[-2]
317
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
318
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
319
+
320
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
321
 
322
  if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
 
426
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
427
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
428
  etc.)
 
429
  This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
430
  Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
431
  and behavior.
 
432
  Parameters:
433
  config ([`InternLMConfig`]):
434
  Model configuration class with all the parameters of the model. Initializing with a config file does not
 
469
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
470
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
471
  it.
 
472
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
473
  [`PreTrainedTokenizer.__call__`] for details.
 
474
  [What are input IDs?](../glossary#input-ids)
475
  attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
476
  Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
 
477
  - 1 for tokens that are **not masked**,
478
  - 0 for tokens that are **masked**.
 
479
  [What are attention masks?](../glossary#attention-mask)
 
480
  Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
481
  [`PreTrainedTokenizer.__call__`] for details.
 
482
  If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
483
  `past_key_values`).
 
484
  If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
485
  and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
486
  information on the default strategy.
 
487
  - 1 indicates the head is **not masked**,
488
  - 0 indicates the head is **masked**.
489
  position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
490
  Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
491
  config.n_positions - 1]`.
 
492
  [What are position IDs?](../glossary#position-ids)
493
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or
494
+ when `config.use_cache=True`):
495
  Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
496
  `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
497
  `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
 
498
  Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
499
  blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
 
500
  If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
501
  don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
502
  `decoder_input_ids` of shape `(batch_size, sequence_length)`.
 
525
  class InternLMModel(InternLMPreTrainedModel):
526
  """
527
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`InternLMDecoderLayer`]
 
528
  Args:
529
  config: InternLMConfig
530
  """
531
+
532
  _auto_class = "AutoModel"
533
 
534
  def __init__(self, config: InternLMConfig):
 
754
  Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
755
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
756
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
 
757
  Returns:
 
758
  Example:
 
759
  ```python
760
  >>> from transformers import AutoTokenizer, InternLMForCausalLM
 
761
  >>> model = InternLMForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
762
  >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
 
763
  >>> prompt = "Hey, are you consciours? Can you talk to me?"
764
  >>> inputs = tokenizer(prompt, return_tensors="pt")
 
765
  >>> # Generate
766
  >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
767
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
 
851
  for layer_past in past_key_values:
852
  reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
853
  return reordered_past
854
+
855
  def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = []):
856
  prompt = ""
857
  for record in history:
858
  prompt += f"""<|User|>:{record[0]}<eoh>\n<|Bot|>:{record[1]}<eoa>\n"""
859
  prompt += f"""<|User|>:{query}<eoh>\n<|Bot|>:"""
860
  return tokenizer([prompt], return_tensors="pt")
861
+
862
  @torch.no_grad()
863
+ def chat(
864
+ self,
865
+ tokenizer,
866
+ query: str,
867
+ history: List[Tuple[str, str]] = [],
868
+ streamer: Optional[BaseStreamer] = None,
869
+ max_new_tokens: int = 1024,
870
+ do_sample: bool = True,
871
+ temperature: float = 0.8,
872
+ top_p: float = 0.8,
873
+ **kwargs,
874
+ ):
875
  inputs = self.build_inputs(tokenizer, query, history)
876
  inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
877
+ outputs = self.generate(
878
+ **inputs,
879
+ streamer=streamer,
880
+ max_new_tokens=max_new_tokens,
881
+ do_sample=do_sample,
882
+ temperature=temperature,
883
+ top_p=top_p,
884
+ **kwargs,
885
+ )
886
+ outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :]
887
  response = tokenizer.decode(outputs, skip_special_tokens=True)
888
  response = response.split("<eoa>")[0]
889
  history = history + [(query, response)]
890
  return response, history
891
+
892
  @torch.no_grad()
893
+ def stream_chat(
894
+ self,
895
+ tokenizer,
896
+ query: str,
897
+ history: List[Tuple[str, str]] = [],
898
+ max_new_tokens: int = 1024,
899
+ do_sample: bool = True,
900
+ temperature: float = 0.8,
901
+ top_p: float = 0.8,
902
+ **kwargs,
903
+ ):
904
  """
905
  Return a generator in format: (response, history)
906
  Eg.
 
946
  tokenizer=tokenizer,
947
  query=query,
948
  streamer=ChatStreamer(tokenizer=tokenizer),
949
+ history=history,
950
  max_new_tokens=max_new_tokens,
951
  do_sample=do_sample,
952
  temperature=temperature,
953
  top_p=top_p,
954
+ **kwargs,
955
  )
956
 
957
  def consumer():
 
969
  @add_start_docstrings(
970
  """
971
  The InternLM Model transformer with a sequence classification head on top (linear layer).
 
972
  [`InternLMForSequenceClassification`] uses the last token in order to do the classification, as other causal models
973
  (e.g. GPT-2) do.
 
974
  Since it does classification on the last token, it requires to know the position of the last token. If a
975
  `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
976
  no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the