x54-729 committed on
Commit
c95809d
1 Parent(s): 2d83118

Update modeling_internlm.py

Files changed (1)
  1. modeling_internlm.py +52 -36
modeling_internlm.py CHANGED
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+# Copyright (c) InternLM. All rights reserved.
 #
 # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
 # and OPT implementations in this library. It has been modified from its
@@ -28,7 +28,6 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.activations import ACT2FN
-from transformers.generation.streamers import BaseStreamer
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -42,6 +41,11 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 
+try:
+    from transformers.generation.streamers import BaseStreamer
+except:  # noqa # pylint: disable=bare-except
+    BaseStreamer = None
+
 from .configuration_internlm import InternLMConfig
 
 logger = logging.get_logger(__name__)
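Note on the guarded import added here: BaseStreamer only exists in newer transformers releases, so the module falls back to a None sentinel and defers the hard failure to the call site (see the stream_chat hunk near the end of this diff). A minimal standalone sketch of the same pattern, using ImportError rather than the bare except in the diff:

try:
    from transformers.generation.streamers import BaseStreamer  # available in transformers >= 4.28.0
except ImportError:
    BaseStreamer = None  # sentinel: streaming support is unavailable

def stream_tokens():
    # Fail loudly only when the optional feature is actually requested.
    if BaseStreamer is None:
        raise ModuleNotFoundError("streaming requires transformers>=4.28.0")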
@@ -113,6 +117,7 @@ class InternLMRotaryEmbedding(torch.nn.Module):
         base (int, optional): The rotation position encodes the rotation Angle base number. Defaults to 10000.
         device (Any, optional): Running device. Defaults to None.
     """
+
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
@@ -124,8 +129,8 @@ class InternLMRotaryEmbedding(torch.nn.Module):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+        self.register_buffer("cos_cached", emb.cos().to(torch.float32), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(torch.float32), persistent=False)
 
     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
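Net effect of this hunk: the cos/sin caches are stored as plain [seq_len, dim] float32 tensors instead of the old [1, 1, seq_len, dim] layout. The cached values themselves do not change, only the indexing convention; a quick check that rebuilds both layouts the way __init__ does, with made-up sizes:

import torch

dim, max_pos, base = 8, 16, 10000
# Rebuild the cache exactly as the patched __init__ does.
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_pos, dtype=inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos_cached = emb.cos().to(torch.float32)      # new layout: [max_pos, dim]
old_cos_cached = emb.cos()[None, None, :, :]  # old layout: [1, 1, max_pos, dim]

seq_len = 5
# Identical values, different slicing convention.
assert torch.equal(cos_cached[:seq_len, ...], old_cos_cached[0, 0, :seq_len, :])
print(cos_cached.shape, old_cos_cached.shape)  # torch.Size([16, 8]) torch.Size([1, 1, 16, 8])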
@@ -136,11 +141,11 @@ class InternLMRotaryEmbedding(torch.nn.Module):
             freqs = torch.einsum("i,j->ij", t, self.inv_freq)
             # Different from paper, but it uses a different permutation in order to obtain the same calculation
             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+            self.register_buffer("cos_cached", emb.cos(), persistent=False)
+            self.register_buffer("sin_cached", emb.sin(), persistent=False)
         return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.cos_cached[:seq_len, ...].to(dtype=x.dtype),
+            self.sin_cached[:seq_len, ...].to(dtype=x.dtype),
         )
 
 
@@ -158,7 +163,7 @@ class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
-        self.register_buffer("inv_freq", inv_freq)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.dim = dim
         self.base = base
         self.scaling_factor = scaling_factor
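Registering inv_freq with persistent=False keeps it out of state_dict, so checkpoints neither save nor expect the buffer; it is simply recomputed from base and dim at construction time. A tiny illustration of what the flag does:

import torch
from torch import nn

class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        freq = torch.arange(4).float()
        self.register_buffer("persisted", freq)                        # saved in state_dict
        self.register_buffer("not_persisted", freq, persistent=False)  # recomputed, never saved

print(WithBuffer().state_dict().keys())  # odict_keys(['persisted'])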
@@ -170,8 +175,8 @@ class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+        self.register_buffer("cos_cached", emb.cos(), persistent=False)
+        self.register_buffer("sin_cached", emb.sin(), persistent=False)
 
     def _update_cached(self, x, seq_len=None):
         self.max_seq_len_cached = max(seq_len, self.max_position_embeddings)
@@ -185,8 +190,8 @@ class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module):
         t = torch.arange(self.max_seq_len_cached, device=inv_freq.device, dtype=inv_freq.dtype)
         freqs = torch.einsum("i,j->ij", t, inv_freq)
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+        self.register_buffer("cos_cached", emb.cos(), persistent=False)
+        self.register_buffer("sin_cached", emb.sin(), persistent=False)
 
     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
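_update_cached rebuilds the caches whenever the requested length exceeds what was precomputed, using an inv_freq that the unchanged lines just above this hunk derive from an NTK-rescaled base. That expression is outside the diff, so the following is only a sketch of the usual dynamic-NTK rule, not a quote of this file:

import torch

def ntk_scaled_inv_freq(dim, base, seq_len, max_position_embeddings, scaling_factor=1.0):
    # Dynamic NTK: grow the rotary base when the sequence outruns the training length,
    # leaving shorter sequences untouched.
    if seq_len > max_position_embeddings:
        base = base * (
            (scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)
        ) ** (dim / (dim - 2))
    return 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

print(ntk_scaled_inv_freq(dim=128, base=10000, seq_len=4096, max_position_embeddings=2048)[:3])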
@@ -199,8 +204,8 @@ class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module):
             self._update_cached(x, seq_len)
 
         return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.cos_cached[:seq_len, ...].to(dtype=x.dtype),
+            self.sin_cached[:seq_len, ...].to(dtype=x.dtype),
         )
 
 
@@ -210,23 +215,23 @@ def rotate_half(x):
     x2 = x[..., x.shape[-1] // 2 :]
     return torch.cat((-x2, x1), dim=-1)
 
-
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
-    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
-    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-    cos = cos.unsqueeze(0).unsqueeze(0).expand(len(position_ids), -1, -1, -1)
-    sin = sin.unsqueeze(0).unsqueeze(0).expand(len(position_ids), -1, -1, -1)
-    if q.size(2) == 1:
-        q_embed = (q * cos[:, :, -1, :]) + (rotate_half(q) * sin[:, :, -1, :])
+    if position_ids.size(1) == 1:
+        q_cos = cos[position_ids].unsqueeze(1).expand(q.shape)
+        q_sin = sin[position_ids].unsqueeze(1).expand(q.shape)
+        q_embed = (q * q_cos) + (rotate_half(q) * q_sin)
+
+        position_ids = position_ids.flatten() + 1
+        max_length = max(position_ids)
+        position_ids = torch.stack([torch.cat([torch.ones(max_length - w, dtype=torch.long), torch.arange(w)]) for w in position_ids])
+        k_cos = cos[position_ids].unsqueeze(1).expand(k.shape)
+        k_sin = sin[position_ids].unsqueeze(1).expand(k.shape)
+        k_embed = (k * k_cos) + (rotate_half(k) * k_sin)
     else:
+        cos = cos[position_ids].unsqueeze(1).expand(q.shape)
+        sin = sin[position_ids].unsqueeze(1).expand(q.shape)
         q_embed = (q * cos) + (rotate_half(q) * sin)
-
-    if k.size(2) == 1:
-        k_embed = (k * cos[:, :, -1, :]) + (rotate_half(k) * sin[:, :, -1, :])
-    else:
         k_embed = (k * cos) + (rotate_half(k) * sin)
-
     return q_embed, k_embed
 
 
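The rewritten apply_rotary_pos_emb assumes the new 2-D caches and gathers rows with cos[position_ids] before broadcasting across heads; the position_ids.size(1) == 1 branch covers single-token decoding with a KV cache, where q holds one position but k spans the whole cache. A shape walk-through of the else branch with made-up sizes:

import torch

bs, n_heads, q_len, head_dim, max_pos = 2, 4, 6, 8, 32
q = torch.randn(bs, n_heads, q_len, head_dim)
cos = torch.randn(max_pos, head_dim)                  # new cache layout: [seq_len, dim]
position_ids = torch.arange(q_len).expand(bs, q_len)  # [bs, q_len]

gathered = cos[position_ids]                        # [bs, q_len, head_dim]
broadcast = gathered.unsqueeze(1).expand(q.shape)   # [bs, n_heads, q_len, head_dim]
print(gathered.shape, broadcast.shape)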
@@ -270,7 +275,7 @@ class InternLMAttention(nn.Module):
         self.rotary_emb = self._init_rope()
 
     def _init_rope(self):
-        if self.config.rotary["type"] == "origin":
+        if self.config.rotary["type"] == "origin":
             self.rotary_emb = InternLMRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
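_init_rope dispatches on config.rotary["type"]: "origin" selects InternLMRotaryEmbedding, and the other branch (not shown in this hunk) presumably selects the dynamic-NTK variant defined above. The snippet below is purely illustrative of the expected shape of that field; the real values live in configuration_internlm.py and the model's config.json.

# Hypothetical examples of the dict read as config.rotary; keys follow the
# access pattern visible in the diff, values are illustrative only.
rotary_origin = {"type": "origin", "base": 10000}
rotary_dynamic = {"type": "dynamic", "base": 10000, "scaling_factor": 1.0}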
@@ -310,7 +315,6 @@ class InternLMAttention(nn.Module):
             key_states = torch.cat([past_key_value[0], key_states], dim=2)
             value_states = torch.cat([past_key_value[1], value_states], dim=2)
 
-        # print(use_cache)
         past_key_value = (key_states, value_states) if use_cache else None
 
         kv_seq_len = key_states.shape[-2]
@@ -851,12 +855,16 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
         for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
-
-    def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = []):
+
+    def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = [], meta_instruction=""):
         prompt = ""
+        if meta_instruction:
+            prompt += f"""<s><|System|>:{meta_instruction}\n"""
+        else:
+            prompt += "<s>"
         for record in history:
-            prompt += f"""<|User|>:{record[0]}<eoh>\n<|Bot|>:{record[1]}<eoa>\n"""
-        prompt += f"""<|User|>:{query}<eoh>\n<|Bot|>:"""
+            prompt += f"""<|User|>:{record[0]}\n<|Bot|>:{record[1]}<eoa>\n"""
+        prompt += f"""<|User|>:{query}\n<|Bot|>:"""
         return tokenizer([prompt], return_tensors="pt")
 
     @torch.no_grad()
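The new template prepends <s>, adds an optional <|System|> turn, and drops the <eoh> marker after user turns. Assembling the prompt by hand for one history turn (the template strings are copied from the diff; the query, history, and meta_instruction are made up):

meta_instruction = "You are an AI assistant."
history = [("Hi", "Hello! How can I help you?")]
query = "What is InternLM?"

prompt = ""
if meta_instruction:
    prompt += f"<s><|System|>:{meta_instruction}\n"
else:
    prompt += "<s>"
for user, bot in history:
    prompt += f"<|User|>:{user}\n<|Bot|>:{bot}<eoa>\n"
prompt += f"<|User|>:{query}\n<|Bot|>:"

print(prompt)
# <s><|System|>:You are an AI assistant.
# <|User|>:Hi
# <|Bot|>:Hello! How can I help you?<eoa>
# <|User|>:What is InternLM?
# <|Bot|>: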
@@ -870,9 +878,12 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
         do_sample: bool = True,
         temperature: float = 0.8,
         top_p: float = 0.8,
+        meta_instruction: str = "You are an AI assistant whose name is InternLM (书生·浦语).\n"
+        "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n"
+        "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user such as English and 中文.",
         **kwargs,
     ):
-        inputs = self.build_inputs(tokenizer, query, history)
+        inputs = self.build_inputs(tokenizer, query, history, meta_instruction)
         inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
         outputs = self.generate(
             **inputs,
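With the new parameter, callers can override the default system prompt per call; an empty string falls back to a bare <s> prefix via build_inputs. A hedged usage sketch: the checkpoint name is an example, and the trust_remote_code loading pattern is an assumption about how this remote-code model is normally used.

from transformers import AutoModelForCausalLM, AutoTokenizer

path = "internlm/internlm-chat-7b"  # example checkpoint name
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True).eval()

response, history = model.chat(
    tokenizer,
    "Hello",
    history=[],
    meta_instruction="You are a concise assistant.",  # new in this commit
)
print(response)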
@@ -907,6 +918,11 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
         ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')])
         ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')])
         """
+        if BaseStreamer is None:
+            raise ModuleNotFoundError(
+                "The version of `transformers` is too low. Please make sure "
+                "that you have installed `transformers>=4.28.0`."
+            )
 
         response_queue = queue.Queue(maxsize=20)
 
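The added check turns a missing BaseStreamer (see the guarded import at the top of this diff) into an explicit ModuleNotFoundError instead of a late NameError. Per the docstring above, stream_chat yields (response, history) pairs; a usage sketch under the same loading assumptions as the previous example:

# Assumes `model` and `tokenizer` were loaded as in the previous sketch and that
# transformers>=4.28.0 is installed; otherwise this raises ModuleNotFoundError.
for response, history in model.stream_chat(tokenizer, "Tell me about rotary embeddings."):
    print(response)  # progressively longer partial responses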
@@ -1083,4 +1099,4 @@ class InternLMForSequenceClassification(InternLMPreTrainedModel):
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
-        )
+        )