InternLM-Math commited on
Commit
c314c82
1 Parent(s): 96baa57

Update modeling_internlm2.py

Browse files
Files changed (1) hide show
  1. modeling_internlm2.py +196 -76
modeling_internlm2.py CHANGED
@@ -1,10 +1,6 @@
1
- # coding=utf-8
2
- # # Copyright (c) InternLM. All rights reserved.
3
  #
4
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
- # and OPT implementations in this library. It has been modified from its
6
- # original forms to accommodate minor architectural differences compared
7
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
  #
9
  # Licensed under the Apache License, Version 2.0 (the "License");
10
  # you may not use this file except in compliance with the License.
@@ -25,6 +21,7 @@ import warnings
25
  from typing import List, Optional, Tuple, Union
26
 
27
  import torch
 
28
  import torch.utils.checkpoint
29
  from einops import rearrange
30
  from torch import nn
@@ -54,6 +51,31 @@ logger = logging.get_logger(__name__)
54
 
55
  _CONFIG_FOR_DOC = "InternLM2Config"
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
59
  def _make_causal_mask(
@@ -88,6 +110,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
88
  return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
89
 
90
 
 
91
  class InternLM2RMSNorm(nn.Module):
92
  def __init__(self, hidden_size, eps=1e-6):
93
  """
@@ -105,6 +128,7 @@ class InternLM2RMSNorm(nn.Module):
105
  return self.weight * hidden_states.to(input_dtype)
106
 
107
 
 
108
  class InternLM2RotaryEmbedding(nn.Module):
109
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
110
  super().__init__()
@@ -133,7 +157,7 @@ class InternLM2RotaryEmbedding(nn.Module):
133
  def forward(self, x, seq_len=None):
134
  # x: [bs, num_attention_heads, seq_len, head_size]
135
  if seq_len > self.max_seq_len_cached:
136
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
137
 
138
  return (
139
  self.cos_cached[:seq_len].to(dtype=x.dtype),
@@ -141,6 +165,7 @@ class InternLM2RotaryEmbedding(nn.Module):
141
  )
142
 
143
 
 
144
  class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
145
  """InternLM2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
146
 
@@ -160,6 +185,7 @@ class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
160
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
161
 
162
 
 
163
  class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
164
  """InternLM2RotaryEmbedding extended with Dynamic NTK scaling.
165
  Credits to the Reddit users /u/bloc97 and /u/emozilla.
@@ -188,6 +214,7 @@ class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
188
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
189
 
190
 
 
191
  def rotate_half(x):
192
  """Rotates half the hidden dims of the input."""
193
  x1 = x[..., : x.shape[-1] // 2]
@@ -195,22 +222,13 @@ def rotate_half(x):
195
  return torch.cat((-x2, x1), dim=-1)
196
 
197
 
198
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
199
- # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
200
- cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
201
- sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
202
- cos = cos.unsqueeze(0).unsqueeze(0).expand(len(position_ids), -1, -1, -1)
203
- sin = sin.unsqueeze(0).unsqueeze(0).expand(len(position_ids), -1, -1, -1)
204
- if q.size(2) == 1:
205
- q_embed = (q * cos[:, :, -1, :]) + (rotate_half(q) * sin[:, :, -1, :])
206
- else:
207
- q_embed = (q * cos) + (rotate_half(q) * sin)
208
-
209
- if k.size(2) == 1:
210
- k_embed = (k * cos[:, :, -1, :]) + (rotate_half(k) * sin[:, :, -1, :])
211
- else:
212
- k_embed = (k * cos) + (rotate_half(k) * sin)
213
-
214
  return q_embed, k_embed
215
 
216
 
@@ -231,6 +249,7 @@ class InternLM2MLP(nn.Module):
231
  return down_proj
232
 
233
 
 
234
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
235
  """
236
  This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -243,6 +262,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
243
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
244
 
245
 
 
246
  class InternLM2Attention(nn.Module):
247
  """Multi-headed attention from 'Attention Is All You Need' paper"""
248
 
@@ -287,10 +307,17 @@ class InternLM2Attention(nn.Module):
287
  self.head_dim,
288
  max_position_embeddings=self.max_position_embeddings,
289
  base=self.config.rope_theta,
290
- scaling_factor=scaling_factor
 
 
 
 
 
 
 
291
  )
292
  else:
293
- raise ValueError("Currently we only support rotary embedding's type being 'dynamic'.")
294
  return self.rotary_emb
295
 
296
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
@@ -384,6 +411,7 @@ class InternLM2Attention(nn.Module):
384
  return attn_output, attn_weights, past_key_value
385
 
386
 
 
387
  class InternLM2FlashAttention2(InternLM2Attention):
388
  """
389
  InternLM2 flash attention module. This module inherits from `InternLM2Attention` as the weights of the module stays
@@ -420,9 +448,8 @@ class InternLM2FlashAttention2(InternLM2Attention):
420
  qkv_states = rearrange(
421
  qkv_states,
422
  "b q (h gs d) -> b q h gs d",
423
- gs=self.num_heads + 2 * self.num_key_value_heads,
424
  d=self.head_dim,
425
- q=q_len,
426
  )
427
 
428
  query_states = qkv_states[..., : self.num_key_value_groups, :]
@@ -430,6 +457,10 @@ class InternLM2FlashAttention2(InternLM2Attention):
430
  key_states = qkv_states[..., -2, :]
431
  value_states = qkv_states[..., -1, :]
432
 
 
 
 
 
433
  kv_seq_len = key_states.shape[-2]
434
  if past_key_value is not None:
435
  kv_seq_len += past_key_value[0].shape[-2]
@@ -449,36 +480,9 @@ class InternLM2FlashAttention2(InternLM2Attention):
449
  key_states = key_states.transpose(1, 2)
450
  value_states = value_states.transpose(1, 2)
451
 
452
- dropout_rate = 0.0 if not self.training else self.attention_dropout
453
-
454
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
455
- # therefore the input hidden states gets silently casted in float32. Hence, we need
456
- # cast them back in the correct dtype just to be sure everything works as expected.
457
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms
458
- # in fp32. (InternLM2RMSNorm handles it correctly)
459
-
460
- input_dtype = query_states.dtype
461
- if input_dtype == torch.float32:
462
- # Handle the case where the model is quantized
463
- if hasattr(self.config, "_pre_quantization_dtype"):
464
- target_dtype = self.config._pre_quantization_dtype
465
- else:
466
- target_dtype = self.q_proj.weight.dtype
467
-
468
- logger.warning_once(
469
- f"The input hidden states seems to be silently casted in float32, this might be related to"
470
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back "
471
- f"the input in {target_dtype}."
472
- )
473
-
474
- query_states = query_states.to(target_dtype)
475
- key_states = key_states.to(target_dtype)
476
- value_states = value_states.to(target_dtype)
477
-
478
  attn_output = self._flash_attention_forward(
479
- query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
480
  )
481
-
482
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
483
  attn_output = self.wo(attn_output)
484
 
@@ -487,16 +491,112 @@ class InternLM2FlashAttention2(InternLM2Attention):
487
 
488
  return attn_output, attn_weights, past_key_value
489
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
490
 
 
491
  class InternLM2DecoderLayer(nn.Module):
492
  def __init__(self, config: InternLM2Config):
493
  super().__init__()
494
  self.hidden_size = config.hidden_size
495
- self.attention = (
496
- InternLM2Attention(config=config)
497
- if not getattr(config, "_flash_attn_2_enabled", False)
498
- else InternLM2FlashAttention2(config=config)
499
- )
500
  self.feed_forward = InternLM2MLP(config)
501
  self.attention_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
502
  self.ffn_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -581,6 +681,7 @@ InternLM2_START_DOCSTRING = r"""
581
  """
582
 
583
 
 
584
  @add_start_docstrings(
585
  "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
586
  InternLM2_START_DOCSTRING,
@@ -591,7 +692,6 @@ class InternLM2PreTrainedModel(PreTrainedModel):
591
  supports_gradient_checkpointing = True
592
  _no_split_modules = ["InternLM2DecoderLayer"]
593
  _skip_keys_device_placement = "past_key_values"
594
- _supports_flash_attn_2 = True
595
 
596
  def _init_weights(self, module):
597
  std = self.config.initializer_range
@@ -670,6 +770,7 @@ InternLM2_INPUTS_DOCSTRING = r"""
670
  """
671
 
672
 
 
673
  @add_start_docstrings(
674
  "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
675
  InternLM2_START_DOCSTRING,
@@ -688,8 +789,10 @@ class InternLM2Model(InternLM2PreTrainedModel):
688
  super().__init__(config)
689
  self.padding_idx = config.pad_token_id
690
  self.vocab_size = config.vocab_size
 
691
 
692
  self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
 
693
  self.layers = nn.ModuleList([InternLM2DecoderLayer(config) for _ in range(config.num_hidden_layers)])
694
  self.norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
695
 
@@ -703,7 +806,6 @@ class InternLM2Model(InternLM2PreTrainedModel):
703
  def set_input_embeddings(self, value):
704
  self.tok_embeddings = value
705
 
706
- # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
707
  def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
708
  # create causal mask
709
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -748,6 +850,9 @@ class InternLM2Model(InternLM2PreTrainedModel):
748
 
749
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
750
 
 
 
 
751
  # retrieve input_ids and inputs_embeds
752
  if input_ids is not None and inputs_embeds is not None:
753
  raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -773,14 +878,18 @@ class InternLM2Model(InternLM2PreTrainedModel):
773
 
774
  if inputs_embeds is None:
775
  inputs_embeds = self.tok_embeddings(input_ids)
776
- # embed positions
777
- if attention_mask is None:
778
- attention_mask = torch.ones(
779
- (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
 
 
 
 
 
 
 
780
  )
781
- attention_mask = self._prepare_decoder_attention_mask(
782
- attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
783
- )
784
 
785
  # embed positions
786
  hidden_states = inputs_embeds
@@ -854,6 +963,7 @@ class InternLM2Model(InternLM2PreTrainedModel):
854
  )
855
 
856
 
 
857
  class InternLM2ForCausalLM(InternLM2PreTrainedModel):
858
  _auto_class = "AutoModelForCausalLM"
859
 
@@ -1023,12 +1133,15 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
1023
  )
1024
  return reordered_past
1025
 
1026
- def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = []):
1027
  prompt = ""
 
 
 
 
1028
  for record in history:
1029
- prompt += f"""[UNUSED_TOKEN_146]user\n{record[0]}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n{record[1]}[UNUSED_TOKEN_145]\n"""
1030
- prompt += f"""[UNUSED_TOKEN_146]user\n{query}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n"""
1031
- print(prompt)
1032
  return tokenizer([prompt], return_tensors="pt")
1033
 
1034
  @torch.no_grad()
@@ -1042,10 +1155,15 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
1042
  do_sample: bool = True,
1043
  temperature: float = 0.8,
1044
  top_p: float = 0.8,
 
 
 
1045
  **kwargs,
1046
  ):
1047
- inputs = self.build_inputs(tokenizer, query, history)
1048
  inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
 
 
1049
  outputs = self.generate(
1050
  **inputs,
1051
  streamer=streamer,
@@ -1053,11 +1171,12 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
1053
  do_sample=do_sample,
1054
  temperature=temperature,
1055
  top_p=top_p,
 
1056
  **kwargs,
1057
  )
1058
  outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :]
1059
  response = tokenizer.decode(outputs, skip_special_tokens=True)
1060
- response = response.split("[UNUSED_TOKEN_145]")[0]
1061
  history = history + [(query, response)]
1062
  return response, history
1063
 
@@ -1110,7 +1229,7 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
1110
  return
1111
 
1112
  token = self.tokenizer.decode([value[-1]], skip_special_tokens=True)
1113
- if token.strip() != "[UNUSED_TOKEN_145]":
1114
  self.response = self.response + token
1115
  history = self.history + [(self.query, self.response)]
1116
  self.queue.put((self.response, history))
@@ -1143,6 +1262,7 @@ class InternLM2ForCausalLM(InternLM2PreTrainedModel):
1143
  return consumer()
1144
 
1145
 
 
1146
  @add_start_docstrings(
1147
  """
1148
  The InternLM2 Model transformer with a sequence classification head on top (linear layer).
 
1
+ # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
 
2
  #
3
+ # This code is based on transformers/src/transformers/models/llama/modeling_llama.py
 
 
 
4
  #
5
  # Licensed under the Apache License, Version 2.0 (the "License");
6
  # you may not use this file except in compliance with the License.
 
21
  from typing import List, Optional, Tuple, Union
22
 
23
  import torch
24
+ import torch.nn.functional as F
25
  import torch.utils.checkpoint
26
  from einops import rearrange
27
  from torch import nn
 
51
 
52
  _CONFIG_FOR_DOC = "InternLM2Config"
53
 
54
+ flash_attn_func, flash_attn_varlen_func = None, None
55
+ pad_input, index_first_axis, unpad_input = None, None, None
56
+ def _import_flash_attn():
57
+ global flash_attn_func, flash_attn_varlen_func
58
+ global pad_input, index_first_axis, unpad_input
59
+ try:
60
+ from flash_attn import flash_attn_func as _flash_attn_func, flash_attn_varlen_func as _flash_attn_varlen_func
61
+ from flash_attn.bert_padding import pad_input as _pad_input, index_first_axis as _index_first_axis, unpad_input as _unpad_input
62
+ flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
63
+ pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
64
+ except ImportError:
65
+ raise ImportError("flash_attn is not installed.")
66
+
67
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
68
+ def _get_unpad_data(attention_mask):
69
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
70
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
71
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
72
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
73
+ return (
74
+ indices,
75
+ cu_seqlens,
76
+ max_seqlen_in_batch,
77
+ )
78
+
79
 
80
  # Copied from transformers.models.bart.modeling_bart._make_causal_mask
81
  def _make_causal_mask(
 
110
  return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
111
 
112
 
113
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->InternLM2
114
  class InternLM2RMSNorm(nn.Module):
115
  def __init__(self, hidden_size, eps=1e-6):
116
  """
 
128
  return self.weight * hidden_states.to(input_dtype)
129
 
130
 
131
+ # Copied from transformers.model.llama.modeling_llama.LlamaRotaryEmbedding with Llama->InternLM2
132
  class InternLM2RotaryEmbedding(nn.Module):
133
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
134
  super().__init__()
 
157
  def forward(self, x, seq_len=None):
158
  # x: [bs, num_attention_heads, seq_len, head_size]
159
  if seq_len > self.max_seq_len_cached:
160
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.float32)
161
 
162
  return (
163
  self.cos_cached[:seq_len].to(dtype=x.dtype),
 
165
  )
166
 
167
 
168
+ # Copied from transformers.model.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->InternLM2
169
  class InternLM2LinearScalingRotaryEmbedding(InternLM2RotaryEmbedding):
170
  """InternLM2RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
171
 
 
185
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
186
 
187
 
188
+ # Copied from transformers.model.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->InternLM2
189
  class InternLM2DynamicNTKScalingRotaryEmbedding(InternLM2RotaryEmbedding):
190
  """InternLM2RotaryEmbedding extended with Dynamic NTK scaling.
191
  Credits to the Reddit users /u/bloc97 and /u/emozilla.
 
214
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
215
 
216
 
217
+ # Copied from transformers.model.llama.modeling_llama.rotate_half
218
  def rotate_half(x):
219
  """Rotates half the hidden dims of the input."""
220
  x1 = x[..., : x.shape[-1] // 2]
 
222
  return torch.cat((-x2, x1), dim=-1)
223
 
224
 
225
+ # Copied from transformers.model.llama.modeling_llama.apply_rotary_pos_emb
226
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
227
+ """Applies Rotary Position Embedding to the query and key tensors."""
228
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
229
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
230
+ q_embed = (q * cos) + (rotate_half(q) * sin)
231
+ k_embed = (k * cos) + (rotate_half(k) * sin)
 
 
 
 
 
 
 
 
 
232
  return q_embed, k_embed
233
 
234
 
 
249
  return down_proj
250
 
251
 
252
+ # Copied from transformers.model.llama.modeling_llama.repeat_kv
253
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
254
  """
255
  This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
 
262
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
263
 
264
 
265
+ # Modified from transformers.model.llama.modeling_llama.LlamaAttention
266
  class InternLM2Attention(nn.Module):
267
  """Multi-headed attention from 'Attention Is All You Need' paper"""
268
 
 
307
  self.head_dim,
308
  max_position_embeddings=self.max_position_embeddings,
309
  base=self.config.rope_theta,
310
+ scaling_factor=scaling_factor,
311
+ )
312
+ elif scaling_type == "linear":
313
+ self.rotary_emb = InternLM2LinearScalingRotaryEmbedding(
314
+ self.head_dim,
315
+ max_position_embeddings=self.max_position_embeddings,
316
+ base=self.config.rope_theta,
317
+ scaling_factor=scaling_factor,
318
  )
319
  else:
320
+ raise ValueError("Currently we only support rotary embedding's type being 'dynamic' or 'linear'.")
321
  return self.rotary_emb
322
 
323
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
 
411
  return attn_output, attn_weights, past_key_value
412
 
413
 
414
+ # Modified from transformers.model.llama.modeling_llama.InternLM2FlashAttention2
415
  class InternLM2FlashAttention2(InternLM2Attention):
416
  """
417
  InternLM2 flash attention module. This module inherits from `InternLM2Attention` as the weights of the module stays
 
448
  qkv_states = rearrange(
449
  qkv_states,
450
  "b q (h gs d) -> b q h gs d",
451
+ gs=2 + self.num_key_value_groups,
452
  d=self.head_dim,
 
453
  )
454
 
455
  query_states = qkv_states[..., : self.num_key_value_groups, :]
 
457
  key_states = qkv_states[..., -2, :]
458
  value_states = qkv_states[..., -1, :]
459
 
460
+ query_states = query_states.transpose(1, 2)
461
+ key_states = key_states.transpose(1, 2)
462
+ value_states = value_states.transpose(1, 2)
463
+
464
  kv_seq_len = key_states.shape[-2]
465
  if past_key_value is not None:
466
  kv_seq_len += past_key_value[0].shape[-2]
 
480
  key_states = key_states.transpose(1, 2)
481
  value_states = value_states.transpose(1, 2)
482
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
483
  attn_output = self._flash_attention_forward(
484
+ query_states, key_states, value_states, attention_mask, q_len
485
  )
 
486
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
487
  attn_output = self.wo(attn_output)
488
 
 
491
 
492
  return attn_output, attn_weights, past_key_value
493
 
494
+ def _flash_attention_forward(
495
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
496
+ ):
497
+ """
498
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
499
+ first unpad the input, then computes the attention scores and pad the final attention scores.
500
+
501
+ Args:
502
+ query_states (`torch.Tensor`):
503
+ Input query states to be passed to Flash Attention API
504
+ key_states (`torch.Tensor`):
505
+ Input key states to be passed to Flash Attention API
506
+ value_states (`torch.Tensor`):
507
+ Input value states to be passed to Flash Attention API
508
+ attention_mask (`torch.Tensor`):
509
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
510
+ position of padding tokens and 1 for the position of non-padding tokens.
511
+ dropout (`int`, *optional*):
512
+ Attention dropout
513
+ softmax_scale (`float`, *optional*):
514
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
515
+ """
516
+ # Contains at least one padding token in the sequence
517
+ causal = self.is_causal and query_length != 1
518
+ if attention_mask is not None:
519
+ batch_size = query_states.shape[0]
520
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input(
521
+ query_states, key_states, value_states, attention_mask, query_length
522
+ )
523
+
524
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
525
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
526
+
527
+ attn_output_unpad = flash_attn_varlen_func(
528
+ query_states,
529
+ key_states,
530
+ value_states,
531
+ cu_seqlens_q=cu_seqlens_q,
532
+ cu_seqlens_k=cu_seqlens_k,
533
+ max_seqlen_q=max_seqlen_in_batch_q,
534
+ max_seqlen_k=max_seqlen_in_batch_k,
535
+ dropout_p=dropout,
536
+ softmax_scale=softmax_scale,
537
+ causal=causal,
538
+ )
539
+
540
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
541
+ else:
542
+ attn_output = flash_attn_func(
543
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
544
+ )
545
+
546
+ return attn_output
547
+
548
+ def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
549
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
550
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
551
+
552
+ key_layer = index_first_axis(
553
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
554
+ )
555
+ value_layer = index_first_axis(
556
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
557
+ )
558
+
559
+ if query_length == kv_seq_len:
560
+ query_layer = index_first_axis(
561
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
562
+ )
563
+ cu_seqlens_q = cu_seqlens_k
564
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
565
+ indices_q = indices_k
566
+ elif query_length == 1:
567
+ max_seqlen_in_batch_q = 1
568
+ cu_seqlens_q = torch.arange(
569
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
570
+ ) # There is a memcpy here, that is very bad.
571
+ indices_q = cu_seqlens_q[:-1]
572
+ query_layer = query_layer.squeeze(1)
573
+ else:
574
+ # The -q_len: slice assumes left padding.
575
+ attention_mask = attention_mask[:, -query_length:]
576
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
577
+
578
+ return (
579
+ query_layer,
580
+ key_layer,
581
+ value_layer,
582
+ indices_q.to(torch.int64),
583
+ (cu_seqlens_q, cu_seqlens_k),
584
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
585
+ )
586
+
587
+ INTERNLM2_ATTENTION_CLASSES = {
588
+ "eager": InternLM2Attention,
589
+ "flash_attention_2": InternLM2FlashAttention2,
590
+ }
591
 
592
+ # Modified from transformers.model.llama.modeling_llama.LlamaDecoderLayer
593
  class InternLM2DecoderLayer(nn.Module):
594
  def __init__(self, config: InternLM2Config):
595
  super().__init__()
596
  self.hidden_size = config.hidden_size
597
+
598
+ self.attention = INTERNLM2_ATTENTION_CLASSES[config.attn_implementation](config=config)
599
+
 
 
600
  self.feed_forward = InternLM2MLP(config)
601
  self.attention_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
602
  self.ffn_norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
681
  """
682
 
683
 
684
+ # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->InternLM2
685
  @add_start_docstrings(
686
  "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
687
  InternLM2_START_DOCSTRING,
 
692
  supports_gradient_checkpointing = True
693
  _no_split_modules = ["InternLM2DecoderLayer"]
694
  _skip_keys_device_placement = "past_key_values"
 
695
 
696
  def _init_weights(self, module):
697
  std = self.config.initializer_range
 
770
  """
771
 
772
 
773
+ # Modified from transformers.model.llama.modeling_llama.LlamaModel
774
  @add_start_docstrings(
775
  "The bare InternLM2 Model outputting raw hidden-states without any specific head on top.",
776
  InternLM2_START_DOCSTRING,
 
789
  super().__init__(config)
790
  self.padding_idx = config.pad_token_id
791
  self.vocab_size = config.vocab_size
792
+ self.config = config
793
 
794
  self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
795
+
796
  self.layers = nn.ModuleList([InternLM2DecoderLayer(config) for _ in range(config.num_hidden_layers)])
797
  self.norm = InternLM2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
798
 
 
806
  def set_input_embeddings(self, value):
807
  self.tok_embeddings = value
808
 
 
809
  def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
810
  # create causal mask
811
  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
 
850
 
851
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
852
 
853
+ if self.config.attn_implementation == "flash_attention_2":
854
+ _import_flash_attn()
855
+
856
  # retrieve input_ids and inputs_embeds
857
  if input_ids is not None and inputs_embeds is not None:
858
  raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
 
878
 
879
  if inputs_embeds is None:
880
  inputs_embeds = self.tok_embeddings(input_ids)
881
+
882
+ if self.config.attn_implementation == "flash_attention_2":
883
+ # 2d mask is passed through the layers
884
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
885
+ else:
886
+ if attention_mask is None:
887
+ attention_mask = torch.ones(
888
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
889
+ )
890
+ attention_mask = self._prepare_decoder_attention_mask(
891
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
892
  )
 
 
 
893
 
894
  # embed positions
895
  hidden_states = inputs_embeds
 
963
  )
964
 
965
 
966
+ # Modified from transformers.model.llama.modeling_llama.LlamaForCausalLM
967
  class InternLM2ForCausalLM(InternLM2PreTrainedModel):
968
  _auto_class = "AutoModelForCausalLM"
969
 
 
1133
  )
1134
  return reordered_past
1135
 
1136
+ def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = [], meta_instruction=""):
1137
  prompt = ""
1138
+ if meta_instruction:
1139
+ prompt += f"""<s><|im_start|>system\n{meta_instruction}<|im_end|>\n"""
1140
+ else:
1141
+ prompt += "<s>"
1142
  for record in history:
1143
+ prompt += f"""<|im_start|>user\n{record[0]}<|im_end|>\n<|im_start|>assistant\n{record[1]}<|im_end|>\n"""
1144
+ prompt += f"""<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n"""
 
1145
  return tokenizer([prompt], return_tensors="pt")
1146
 
1147
  @torch.no_grad()
 
1155
  do_sample: bool = True,
1156
  temperature: float = 0.8,
1157
  top_p: float = 0.8,
1158
+ meta_instruction: str = "You are an AI assistant whose name is InternLM (书生·浦语).\n"
1159
+ "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
1160
+ "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文.",
1161
  **kwargs,
1162
  ):
1163
+ inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
1164
  inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
1165
+ # also add end-of-assistant token in eos token id to avoid unnecessary generation
1166
+ eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(["<|im_end|>"])[0]]
1167
  outputs = self.generate(
1168
  **inputs,
1169
  streamer=streamer,
 
1171
  do_sample=do_sample,
1172
  temperature=temperature,
1173
  top_p=top_p,
1174
+ eos_token_id=eos_token_id,
1175
  **kwargs,
1176
  )
1177
  outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :]
1178
  response = tokenizer.decode(outputs, skip_special_tokens=True)
1179
+ response = response.split("<|im_end|>")[0]
1180
  history = history + [(query, response)]
1181
  return response, history
1182
 
 
1229
  return
1230
 
1231
  token = self.tokenizer.decode([value[-1]], skip_special_tokens=True)
1232
+ if token.strip() != "<|im_end|>":
1233
  self.response = self.response + token
1234
  history = self.history + [(self.query, self.response)]
1235
  self.queue.put((self.response, history))
 
1262
  return consumer()
1263
 
1264
 
1265
+ # Copied from transformers.model.llama.modeling_llama.LlamaForSequenceClassification with Llama->InternLM2
1266
  @add_start_docstrings(
1267
  """
1268
  The InternLM2 Model transformer with a sequence classification head on top (linear layer).