pom committed
Commit c760666
1 Parent(s): 8528264

update files

Files changed (3):
  1. config.json +1 -1
  2. modeling_xverse.py +153 -103
  3. ms_wrapper.py +95 -0
config.json CHANGED
@@ -13,7 +13,7 @@
   "hidden_size": 5120,
   "initializer_range": 0.02,
   "intermediate_size": 13824,
-  "max_position_embeddings": 262144,
+  "max_position_embeddings": 32768,
   "max_tokenizer_truncation": 262144,
   "model_type": "xverse",
   "num_attention_heads": 40,
modeling_xverse.py CHANGED
@@ -33,6 +33,7 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from transformers.generation.utils import GenerationConfig
 from .configuration_xverse import XverseConfig
+from xformers import ops as xops


 logger = logging.get_logger(__name__)
@@ -48,13 +49,15 @@ def _make_causal_mask(
     Make causal mask used for bi-directional self-attention.
     """
     bsz, tgt_len = input_ids_shape
-    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
+    mask = torch.full((tgt_len, tgt_len), torch.tensor(
+        torch.finfo(dtype).min, device=device), device=device)
     mask_cond = torch.arange(mask.size(-1), device=device)
     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
     mask = mask.to(dtype)

     if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
+        mask = torch.cat([torch.zeros(
+            tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


@@ -66,7 +69,8 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
     bsz, src_len = mask.size()
     tgt_len = tgt_len if tgt_len is not None else src_len

-    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+    expanded_mask = mask[:, None, None, :].expand(
+        bsz, 1, tgt_len, src_len).to(dtype)

     inverted_mask = 1.0 - expanded_mask

@@ -84,8 +88,10 @@ class XverseRMSNorm(nn.Module):

     def forward(self, hidden_states):
         input_dtype = hidden_states.dtype
-        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        variance = hidden_states.to(torch.float32).pow(
+            2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * \
+            torch.rsqrt(variance + self.variance_epsilon)

         return (self.weight * hidden_states).to(input_dtype)

@@ -93,29 +99,47 @@ class XverseRMSNorm(nn.Module):
 class XverseRotaryEmbedding(torch.nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=500000, device=None):
         super().__init__()
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+        self.base = base
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        inv_freq = 1.0 / \
+            (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
         self.register_buffer("inv_freq", inv_freq)

         # Build here to make `torch.jit.trace` work.
         self.max_seq_len_cached = max_position_embeddings
-        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+
+        t = torch.arange(self.max_seq_len_cached,
+                         device=self.inv_freq.device, dtype=self.inv_freq.dtype)
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+        self.register_buffer("cos_cached", emb.cos()[
+                             None, None, :, :], persistent=False)
+        self.register_buffer("sin_cached", emb.sin()[
+                             None, None, :, :], persistent=False)

     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
         # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
         if seq_len > self.max_seq_len_cached:
-            self.max_seq_len_cached = seq_len
-            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
-            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            # Different from paper, but it uses a different permutation in order to obtain the same calculation
+
+            t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
+            dim = self.dim
+            alpha = (seq_len / (self.max_position_embeddings/2) - 1)
+            base = self.base * alpha ** (dim / (dim-2))
+            ntk_inv_freq = 1.0 / \
+                (base ** (torch.arange(0, dim, 2).float().to(x.device) / dim))
+
+            freqs = torch.einsum("i,j->ij", t, ntk_inv_freq)
             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+            cos_cached = emb.cos()[None, None, :, :]
+            sin_cached = emb.sin()[None, None, :, :]
+            return (
+                cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+                sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
+            )
+
         return (
             self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
             self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
@@ -125,7 +149,7 @@ class XverseRotaryEmbedding(torch.nn.Module):
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
+    x2 = x[..., x.shape[-1] // 2:]
     return torch.cat((-x2, x1), dim=-1)


@@ -173,11 +197,16 @@ class XverseAttention(nn.Module):
                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                 f" and `num_heads`: {self.num_heads})."
             )
-        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
-        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
-        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
-        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
-        self.rotary_emb = XverseRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
+        self.q_proj = nn.Linear(
+            self.hidden_size, self.num_heads * self.head_dim, bias=False)
+        self.k_proj = nn.Linear(
+            self.hidden_size, self.num_heads * self.head_dim, bias=False)
+        self.v_proj = nn.Linear(
+            self.hidden_size, self.num_heads * self.head_dim, bias=False)
+        self.o_proj = nn.Linear(
+            self.num_heads * self.head_dim, self.hidden_size, bias=False)
+        self.rotary_emb = XverseRotaryEmbedding(
+            self.head_dim, max_position_embeddings=self.max_position_embeddings)

     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
@@ -190,64 +219,50 @@
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        dropout: Optional[float] = 0.1,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        bsz, q_len, _ = hidden_states.size()
-
-        query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            kv_seq_len += past_key_value[0].shape[-2]
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-        # [bsz, nh, t, hd]
-
-        if past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
-        past_key_value = (key_states, value_states) if use_cache else None
-
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
-            )
-
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                )
-            attn_weights = attn_weights + attention_mask
-            attn_weights = torch.max(
-                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
-            )
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-        attn_output = torch.matmul(attn_weights, value_states)
-
-        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-            raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                f" {attn_output.size()}"
-            )
-
-        attn_output = attn_output.transpose(1, 2)
+
+        bsz, q_len, _ = hidden_states.size()  # [bsz, q_len, hidden_size]
+
+        query_states = self.q_proj(hidden_states).view(
+            bsz, q_len, self.num_heads, self.head_dim)
+        key_states = self.k_proj(hidden_states).view(
+            bsz, q_len, self.num_heads, self.head_dim)
+        value_states = self.v_proj(hidden_states).view(
+            bsz, q_len, self.num_heads, self.head_dim)
+        # [bsz, q_len, nh, hd]
+
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+        # [bsz, nh, q_len, hd]
+
+        kv_seq_len = key_states.shape[-2]  # q_len
+        n_head = key_states.shape[-3]  # nh
+
+        assert past_key_value is None, "past_key_value is not supported"
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states, key_states, cos, sin, position_ids)
+
+        assert not output_attentions, "output_attentions is not supported"
+        assert not use_cache, "use_cache is not supported"
+
+        """
+        Input tensors must be in format ``[B, M, H, K]``, where B is the batch size, M
+        the sequence length, H the number of heads, and K the embeding size per head
+        """
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        attn_output = xops.memory_efficient_attention(
+            query_states, key_states, value_states, attn_bias=xops.LowerTriangularMask(), p=dropout
+        )
+
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
         attn_output = self.o_proj(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights, past_key_value
+        return attn_output, None, None


 class XverseDecoderLayer(nn.Module):
@@ -260,8 +275,10 @@ class XverseDecoderLayer(nn.Module):
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
         )
-        self.input_layernorm = XverseRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = XverseRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.input_layernorm = XverseRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = XverseRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps)

     def forward(
         self,
@@ -426,6 +443,7 @@ XVERSE_INPUTS_DOCSTRING = r"""
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """

+
 @add_start_docstrings(
     "The bare Xverse Model outputting raw hidden-states without any specific head on top.",
     XVERSE_START_DOCSTRING,
@@ -443,8 +461,10 @@ class XverseModel(XversePreTrainedModel):
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size

-        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
-        self.layers = nn.ModuleList([XverseDecoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.embed_tokens = nn.Embedding(
+            config.vocab_size, config.hidden_size, self.padding_idx)
+        self.layers = nn.ModuleList(
+            [XverseDecoderLayer(config) for _ in range(config.num_hidden_layers)])
         self.norm = XverseRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

         self.gradient_checkpointing = False
@@ -476,7 +496,8 @@ class XverseModel(XversePreTrainedModel):
                 inputs_embeds.device
             )
             combined_attention_mask = (
-                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask +
+                combined_attention_mask
             )

         return combined_attention_mask
@@ -504,13 +525,15 @@ class XverseModel(XversePreTrainedModel):

         # retrieve input_ids and inputs_embeds
         if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+            raise ValueError(
+                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
         elif input_ids is not None:
             batch_size, seq_length = input_ids.shape
         elif inputs_embeds is not None:
             batch_size, seq_length, _ = inputs_embeds.shape
         else:
-            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+            raise ValueError(
+                "You have to specify either decoder_input_ids or decoder_inputs_embeds")

         seq_length_with_past = seq_length
         past_key_values_length = 0
@@ -536,7 +559,8 @@ class XverseModel(XversePreTrainedModel):
                 (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
             )
         attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+            attention_mask, (batch_size,
+                             seq_length), inputs_embeds, past_key_values_length
         )

         hidden_states = inputs_embeds
@@ -588,7 +612,8 @@ class XverseModel(XversePreTrainedModel):
             hidden_states = layer_outputs[0]

             if use_cache:
-                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+                next_decoder_cache += (
+                    layer_outputs[2 if output_attentions else 1],)

             if output_attentions:
                 all_self_attns += (layer_outputs[1],)
@@ -615,7 +640,8 @@ class XverseForCausalLM(XversePreTrainedModel):
         super().__init__(config)
         self.model = XverseModel(config)

-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.lm_head = nn.Linear(
+            config.hidden_size, config.vocab_size, bias=False)

         # Initialize weights and apply final processing
         self.post_init()
@@ -726,25 +752,32 @@ class XverseForCausalLM(XversePreTrainedModel):
             attentions=outputs.attentions,
         )

-    def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=2048):
+    def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int = 2048):
         max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens
         max_input_tokens = self.config.max_position_embeddings - max_new_tokens
-        max_input_tokens = max(self.config.max_position_embeddings // 2, max_input_tokens)
-        max_input_tokens = min(self.config.max_tokenizer_truncation, max_input_tokens)
+        max_input_tokens = max(
+            self.config.max_position_embeddings // 2, max_input_tokens)
+        max_input_tokens = min(
+            self.config.max_tokenizer_truncation, max_input_tokens)

         total_input, round_input = [], []
-        user_prompt_tokens = tokenizer.encode("Human: ", return_token_type_ids=False)
-        exec_prompt_tokens = tokenizer.encode("Exec: ", return_token_type_ids=False)
-        assist_prompt_tokens = tokenizer.encode("Assistant: ", return_token_type_ids=False)
+        user_prompt_tokens = tokenizer.encode(
+            "Human: ", return_token_type_ids=False)
+        exec_prompt_tokens = tokenizer.encode(
+            "Exec: ", return_token_type_ids=False)
+        assist_prompt_tokens = tokenizer.encode(
+            "Assistant: ", return_token_type_ids=False)
         assist_prompt_len = len(assist_prompt_tokens)

         for i, message in enumerate(messages[::-1]):
             if message['role'] == 'user' or message['role'] == 'exec':
                 user_content = f"{message['content']}\n\n"
                 content_tokens = user_prompt_tokens + tokenizer.encode(user_content, return_token_type_ids=False) if message['role'] == 'user' else \
-                    exec_prompt_tokens + tokenizer.encode(user_content, return_token_type_ids=False)
+                    exec_prompt_tokens + \
+                    tokenizer.encode(user_content, return_token_type_ids=False)
                 if i == 0:
-                    content_tokens = content_tokens[:max_input_tokens-assist_prompt_len]
+                    content_tokens = content_tokens[:max_input_tokens -
+                                                    assist_prompt_len]
                 content_tokens += assist_prompt_tokens
                 round_input = content_tokens + round_input

@@ -760,27 +793,33 @@ class XverseForCausalLM(XversePreTrainedModel):
                 round_input = []
             elif message['role'] == 'assistant':
                 assist_content = f"{message['content']}"
-                content_tokens = assist_prompt_tokens + tokenizer.encode(assist_content, return_token_type_ids=False)
-                round_input = content_tokens + [self.generation_config.eos_token_id] + round_input
+                content_tokens = assist_prompt_tokens + \
+                    tokenizer.encode(
+                        assist_content, return_token_type_ids=False)
+                round_input = content_tokens + \
+                    [self.generation_config.eos_token_id] + round_input
             elif message['role'] == 'system':
                 assert i == len(messages) - 1
                 user_content = f"{message['content']}\n"
-                content_tokens = tokenizer.encode(user_content, return_token_type_ids=False)
+                content_tokens = tokenizer.encode(
+                    user_content, return_token_type_ids=False)
                 round_input = user_prompt_tokens + content_tokens + round_input
                 if len(total_input) + len(round_input) > max_input_tokens:
                     break
                 else:
                     total_input = round_input + total_input
             else:
-                raise ValueError(f"message role not supported yet: {message['role']}")
+                raise ValueError(
+                    f"message role not supported yet: {message['role']}")
         total_input = torch.LongTensor([total_input]).to(self.device)
         return total_input

     @torch.no_grad()
     def chat(self, tokenizer, messages: List[dict], stream=False,
-             generation_config: Optional[GenerationConfig]=None):
+             generation_config: Optional[GenerationConfig] = None):
         generation_config = generation_config or self.generation_config
-        input_ids = self._build_chat_input(tokenizer, messages, generation_config.max_new_tokens)
+        input_ids = self._build_chat_input(
+            tokenizer, messages, generation_config.max_new_tokens)
         if stream:
             from transformers import TextIteratorStreamer
             from threading import Thread
@@ -788,7 +827,8 @@ class XverseForCausalLM(XversePreTrainedModel):
             self.__class__.generate = PreTrainedModel.generate

             def stream_generator():
-                generation_kwargs = dict(inputs=input_ids, generation_config=generation_config, streamer=streamer)
+                generation_kwargs = dict(
+                    inputs=input_ids, generation_config=generation_config, streamer=streamer)
                 thread = Thread(target=self.generate, kwargs=generation_kwargs)
                 thread.start()
                 for next_text in streamer:
@@ -797,8 +837,10 @@ class XverseForCausalLM(XversePreTrainedModel):
             return stream_generator()
         else:
             self.__class__.generate = PreTrainedModel.generate  # disable stream
-            outputs = self.generate(input_ids, generation_config=generation_config)
-            response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
+            outputs = self.generate(
+                input_ids, generation_config=generation_config)
+            response = tokenizer.decode(
+                outputs[0][len(input_ids[0]):], skip_special_tokens=True)
             return response

     def prepare_inputs_for_generation(
@@ -835,7 +877,8 @@ class XverseForCausalLM(XversePreTrainedModel):
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
         for layer_past in past_key_values:
-            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+            reordered_past += (tuple(past_state.index_select(0, beam_idx)
+                                     for past_state in layer_past),)
         return reordered_past

     def quantize(self, bit_length: int):
@@ -844,37 +887,44 @@ class XverseForCausalLM(XversePreTrainedModel):
         for layer in self.model.layers:
             layer.self_attn.q_proj = QuantizationLinear(
                 bit_length=bit_length,
-                weight=layer.self_attn.q_proj.weight.to(torch.cuda.current_device()),
+                weight=layer.self_attn.q_proj.weight.to(
+                    torch.cuda.current_device()),
                 device=layer.self_attn.q_proj.weight.device,
             )
             layer.self_attn.k_proj = QuantizationLinear(
                 bit_length=bit_length,
-                weight=layer.self_attn.k_proj.weight.to(torch.cuda.current_device()),
+                weight=layer.self_attn.k_proj.weight.to(
+                    torch.cuda.current_device()),
                 device=layer.self_attn.k_proj.weight.device
             )
             layer.self_attn.v_proj = QuantizationLinear(
                 bit_length=bit_length,
-                weight=layer.self_attn.v_proj.weight.to(torch.cuda.current_device()),
+                weight=layer.self_attn.v_proj.weight.to(
+                    torch.cuda.current_device()),
                 device=layer.self_attn.v_proj.weight.device
            )
             layer.self_attn.o_proj = QuantizationLinear(
                 bit_length=bit_length,
-                weight=layer.self_attn.o_proj.weight.to(torch.cuda.current_device()),
+                weight=layer.self_attn.o_proj.weight.to(
+                    torch.cuda.current_device()),
                 device=layer.self_attn.o_proj.weight.device
             )
             layer.mlp.gate_proj = QuantizationLinear(
                 bit_length=bit_length,
-                weight=layer.mlp.gate_proj.weight.to(torch.cuda.current_device()),
+                weight=layer.mlp.gate_proj.weight.to(
+                    torch.cuda.current_device()),
                 device=layer.mlp.gate_proj.weight.device
             )
             layer.mlp.down_proj = QuantizationLinear(
                 bit_length=bit_length,
-                weight=layer.mlp.down_proj.weight.to(torch.cuda.current_device()),
+                weight=layer.mlp.down_proj.weight.to(
+                    torch.cuda.current_device()),
                 device=layer.mlp.down_proj.weight.device
             )
             layer.mlp.up_proj = QuantizationLinear(
                 bit_length=bit_length,
-                weight=layer.mlp.up_proj.weight.to(torch.cuda.current_device()),
+                weight=layer.mlp.up_proj.weight.to(
+                    torch.cuda.current_device()),
                 device=layer.mlp.up_proj.weight.device
             )

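The largest functional change above is in `XverseRotaryEmbedding.forward`: instead of extending the cached cos/sin tables with the original frequencies, sequences longer than the cached window now get NTK-style scaling, i.e. the rotary base is enlarged before the inverse frequencies are recomputed (the attention path is also switched to `xops.memory_efficient_attention`). A standalone sketch of the base-rescaling arithmetic; the helper name and example length are illustrative, with `dim=128` (5120 hidden / 40 heads) and `base=500000` taken from this model's config and the `__init__` default:

```python
import torch

def ntk_scaled_inv_freq(seq_len, dim=128, base=500000, max_position_embeddings=32768):
    # Mirrors the branch added in XverseRotaryEmbedding.forward: once seq_len
    # exceeds the configured window, grow the base before recomputing inv_freq.
    alpha = seq_len / (max_position_embeddings / 2) - 1
    scaled_base = base * alpha ** (dim / (dim - 2))
    return 1.0 / (scaled_base ** (torch.arange(0, dim, 2).float() / dim))

# Example: a 65536-token sequence with the new 32768 window gives
# alpha = 65536 / 16384 - 1 = 3, so the base becomes 500000 * 3 ** (128 / 126).
inv_freq = ntk_scaled_inv_freq(65536)
```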
ms_wrapper.py ADDED
@@ -0,0 +1,95 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import torch
+
+from modelscope.models.base import TorchModel
+from modelscope.preprocessors.base import Preprocessor
+from modelscope.pipelines.base import Model, Pipeline
+from modelscope.utils.config import Config
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors.builder import PREPROCESSORS
+from modelscope.models.builder import MODELS
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+@MODELS.register_module('text-generation', module_name='XVERSE-13B')
+class XVERSE13BTextGeneration(TorchModel):
+
+    def __init__(self, model_dir, *args, **kwargs):
+        super().__init__(model_dir, *args, **kwargs)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
+        self.model = AutoModelForCausalLM.from_pretrained(model_dir, device_map="auto", torch_dtype=torch.float16, trust_remote_code=True)
+        self.model = self.model.eval()
+
+    def forward(self, inputs, **forward_params):
+        inputs = self.tokenizer(inputs, return_tensors='pt').input_ids
+        inputs = inputs.cuda()
+        generated_ids = self.model.generate(inputs, eos_token_id=self.tokenizer.eos_token_id, **forward_params)
+        return self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+
+
+@PIPELINES.register_module('text-generation', module_name='XVERSE-13B-pipeline')
+class XVERSE13BTextGenerationPipeline(Pipeline):
+    """ Give simple introduction to this pipeline.
+
+    Examples:
+
+    >>> from modelscope.pipelines import pipeline
+    >>> input = "Hello, ModelScope!"
+    >>> my_pipeline = pipeline('text-generation', 'xverse/XVERSE-13B')
+    >>> result = my_pipeline(input)
+
+    """
+
+    def __init__(self, model, **kwargs):
+        """
+        use `model` and `preprocessor` to create a custom pipeline for prediction
+        Args:
+            model: model id on modelscope hub.
+            preprocessor: the class of method be init_preprocessor
+        """
+        assert isinstance(model, str) or isinstance(model, Model), \
+            'model must be a single str or Model'
+        if isinstance(model, str):
+            pipe_model = Model.from_pretrained(model)
+        elif isinstance(model, Model):
+            pipe_model = model
+        else:
+            raise NotImplementedError
+
+        super().__init__(model=pipe_model, **kwargs)
+
+    def _sanitize_parameters(self, **pipeline_parameters):
+        """
+        this method should sanitize the keyword args to preprocessor params,
+        forward params and postprocess params on '__call__' or '_process_single' method
+        considered to be a normal classmethod with default implementation / output
+
+        Default Returns:
+            Dict[str, str]: preprocess_params = {}
+            Dict[str, str]: forward_params = {}
+            Dict[str, str]: postprocess_params = pipeline_parameters
+        """
+        return {}, pipeline_parameters, {}
+
+    def preprocess(self, inputs, **preprocess_params):
+        return inputs
+
+    def forward(self, inputs, **forward_params):
+        """ Provide default implementation using self.model and user can reimplement it
+        """
+        output = super().forward(inputs, **forward_params)
+        return {'text': output}
+
+    def postprocess(self, inputs):
+        """ If current pipeline support model reuse, common postprocess
+        code should be write here.
+
+        Args:
+            inputs: input data
+
+        Return:
+            dict of results: a dict containing outputs of model, each
+            output should have the standard output name.
+        """
+        return inputs
+
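
The new `ms_wrapper.py` registers the model and pipeline with ModelScope under the `text-generation` task; because `_sanitize_parameters` forwards all extra keyword arguments as `forward_params`, generation settings flow straight into `model.generate`. A hedged usage sketch built on the docstring example; it assumes the ModelScope repo's configuration wires the task to these registered modules, and the sampling values are illustrative:

```python
from modelscope.pipelines import pipeline

# Loads ms_wrapper.py via the registered 'XVERSE-13B' / 'XVERSE-13B-pipeline' modules.
text_generator = pipeline('text-generation', 'xverse/XVERSE-13B')

# Extra keyword arguments pass through _sanitize_parameters as forward_params,
# which XVERSE13BTextGeneration.forward hands to model.generate().
result = text_generator("Hello, ModelScope!", max_new_tokens=64, top_p=0.85)
print(result['text'])
```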