OpenNLPLab committed
Commit def6113
1 Parent(s): 460b22e

Upload modeling_transnormer.py

Files changed (1)
  1. modeling_transnormer.py +156 -162
modeling_transnormer.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2023 OpenNLPLab
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
@@ -11,6 +11,7 @@
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
 
14
  # coding=utf-8
15
  """ PyTorch Transnormer model."""
16
  import math
@@ -52,8 +53,13 @@ logger = logging.get_logger(__name__)
52
 
53
  _CONFIG_FOR_DOC = "TransnormerConfig"
54
 
 
55
  use_triton = eval(os.environ.get("use_triton", default="True"))
56
  debug = eval(os.environ.get("debug", default="False"))
57
 
58
  if use_triton:
59
  try:
@@ -83,6 +89,7 @@ if not has_lightning_attention:
83
  ########## start Transnormer
84
  ##### Linearized Relative Positional Encoding: https://openreview.net/forum?id=xoLyps2qWc&referrer=%5BAuthor%20Console%5D(%2Fgroup%3Fid%3DTMLR%2FAuthors%23your-submissions)
85
  class Lrpe(nn.Module):
 
86
  def __init__(
87
  self,
88
  num_heads=8,
@@ -92,9 +99,8 @@ class Lrpe(nn.Module):
92
  d = num_heads * embed_dim
93
 
94
  self.index = torch.empty(0)
95
- self.theta = nn.Parameter(
96
- 10000 ** (-2 / d * torch.arange(d)).reshape(num_heads, 1, -1)
97
- )
98
 
99
  def extra_repr(self):
100
  return print_module(self)
@@ -113,6 +119,7 @@ class Lrpe(nn.Module):
113
 
114
 
115
  class GLU(nn.Module):
 
116
  def __init__(self, d1, d2, bias=False):
117
  super().__init__()
118
  if debug:
@@ -135,6 +142,7 @@ class GLU(nn.Module):
135
 
136
 
137
  class NormLinearAttention(nn.Module):
 
138
  def __init__(
139
  self,
140
  embed_dim,
@@ -181,7 +189,6 @@ class NormLinearAttention(nn.Module):
181
  use_cache: bool = False,
182
  slope_rate: Optional[torch.Tensor] = None,
183
  ):
184
- do_eval = eval(os.environ.get("do_eval", default="False"))
185
  if (not self.training) and (not do_eval):
186
  return self.inference(
187
  x,
@@ -198,8 +205,8 @@ class NormLinearAttention(nn.Module):
198
  q, k, v, u = self.qkvu_proj(x).chunk(4, dim=-1)
199
  # reshape
200
  q, k, v = map(
201
- lambda x: rearrange(x, "b n (h d) -> b h n d", h=self.num_heads), [q, k, v]
202
- )
203
  # act
204
  q = self.act(q)
205
  k = self.act(k)
@@ -217,24 +224,23 @@ class NormLinearAttention(nn.Module):
217
  # lrpe
218
  if self.linear_use_lrpe:
219
  q = self.lrpe(q, offset=q_offset)
220
- k = self.lrpe(k)
221
 
222
  if attn_mask == None:
223
  attn_mask = (torch.tril(torch.ones(n, n))).to(q)
224
 
225
  if attn_padding_mask is not None:
226
  v = v.masked_fill(
227
- (1 - attn_padding_mask).unsqueeze(1).unsqueeze(-1).to(torch.bool), 0
228
- )
229
 
230
  if not has_lightning_attention:
231
  if slope_rate != None:
232
  attn_mask = torch.exp(slope_rate * attn_mask)
233
  output = linear_attention(q, k, v, attn_mask)
234
  else:
235
- output = lightning_attention(
236
- q, k, v, True, slope_rate.squeeze(-1).squeeze(-1)
237
- )
238
 
239
  # reshape
240
  output = rearrange(output, "b h n d -> b n (h d)")
@@ -253,14 +259,14 @@ class NormLinearAttention(nn.Module):
253
  return output, attn_weights, past_key_value
254
 
255
  def inference(
256
- self,
257
- x,
258
- attn_mask: Optional[torch.Tensor] = None, # (b, h, n, m)
259
- attn_padding_mask: Optional[torch.Tensor] = None, # (b, m)
260
- output_attentions: bool = False,
261
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
262
- use_cache: bool = False,
263
- slope_rate: Optional[torch.Tensor] = None, # (h, 1, 1)
264
  ):
265
  # x: b n d
266
  n = x.shape[-2]
@@ -268,8 +274,8 @@ class NormLinearAttention(nn.Module):
268
  q, k, v, u = self.qkvu_proj(x).chunk(4, dim=-1)
269
  # reshape
270
  q, k, v = map(
271
- lambda x: rearrange(x, "b n (h d) -> b h n d", h=self.num_heads), [q, k, v]
272
- )
273
  # act
274
  q = self.act(q)
275
  k = self.act(k)
@@ -277,7 +283,7 @@ class NormLinearAttention(nn.Module):
277
  # rpe
278
  if self.linear_use_lrpe:
279
  q = self.lrpe(q, offset=self.offset)
280
- k = self.lrpe(k)
281
 
282
  if past_key_value == None:
283
  self.offset = q.shape[-2]
@@ -288,38 +294,47 @@ class NormLinearAttention(nn.Module):
288
 
289
  # only use for the first time
290
  if past_key_value == None:
291
- if attn_mask == None:
292
- attn_mask = (torch.tril(torch.ones(n, n))).to(q)
293
- if slope_rate != None:
294
- attn_mask = torch.exp(slope_rate * attn_mask)
295
-
296
  if attn_padding_mask is not None:
297
- attn_mask = attn_mask.masked_fill(
298
- (1 - attn_padding_mask).unsqueeze(1).unsqueeze(2).to(torch.bool),
299
- 0,
300
- )
301
- energy = torch.einsum("... n d, ... m d -> ... n m", q, k)
302
-
303
- if attn_mask != None:
304
- energy = energy * attn_mask
305
-
306
- output = torch.einsum("... n m, ... m d -> ... n d", energy, v)
307
-
308
- eval_and_not_generate = eval(
309
- os.environ.get("eval_and_not_generate", default="False")
310
- )
311
- if eval_and_not_generate:
312
- kv = None
313
- else:
314
- # b, h, n, e, d
315
- kv_outproduct = torch.einsum("... n e, ... n d -> ... n e d", k, v)
316
- # 1, 1, n, 1, 1
317
- index = torch.arange(n - 1, -1, -1).reshape(1, 1, -1, 1, 1).to(x)
318
- # (h, 1, 1) -> (1, h, 1, 1, 1); (1, h, 1, 1, 1), (1, 1, n, 1, 1) -> (1, h, n, 1, 1)
319
- decay = ratio.unsqueeze(0).unsqueeze(-1) ** index
320
-
321
- kv_outproduct_with_decay = kv_outproduct * decay
322
- kv = torch.sum(kv_outproduct_with_decay, dim=-3)
323
  else:
324
  kv = past_key_value
325
 
@@ -327,12 +342,11 @@ class NormLinearAttention(nn.Module):
327
  for i in range(n):
328
  kv = ratio * kv + torch.einsum(
329
  "... n d, ... n e -> ... d e",
330
- k[:, :, i : i + 1],
331
- v[:, :, i : i + 1],
332
- )
333
- qkv = torch.einsum(
334
- "... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv
335
  )
 
 
336
  output.append(qkv)
337
  output = torch.concat(output, dim=-2)
338
 
@@ -351,6 +365,7 @@ class NormLinearAttention(nn.Module):
351
 
352
 
353
  class TransnormerDecoderLayer(nn.Module):
 
354
  def __init__(self, config: TransnormerConfig):
355
  super().__init__()
356
  self.embed_dim = config.decoder_embed_dim
@@ -389,14 +404,14 @@ class TransnormerDecoderLayer(nn.Module):
389
  return residual + x
390
 
391
  def forward(
392
- self,
393
- x,
394
- attn_mask: Optional[torch.Tensor] = None,
395
- attn_padding_mask: Optional[torch.Tensor] = None,
396
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
397
- output_attentions: Optional[bool] = False,
398
- use_cache: Optional[bool] = False,
399
- slope_rate: Optional[torch.Tensor] = None, # (h, 1, 1)
400
  ):
401
  residual = x
402
  x = self.token_norm(x)
@@ -416,13 +431,13 @@ class TransnormerDecoderLayer(nn.Module):
416
  x = self.channel_mixer(x)
417
  x = self.residual_connection(x, residual)
418
 
419
- outputs = (x,)
420
 
421
  if output_attentions:
422
- outputs += (self_attn_weights,)
423
 
424
  if use_cache:
425
- outputs += (present_key_value,)
426
 
427
  return outputs
428
 
@@ -444,9 +459,7 @@ TRANSNORMER_START_DOCSTRING = r"""
444
  """
445
 
446
 
447
- @add_start_docstrings(
448
- TRANSNORMER_START_DOCSTRING,
449
- )
450
  class TransnormerPreTrainedModel(PreTrainedModel):
451
  config_class = TransnormerConfig
452
  base_model_prefix = "model"
@@ -531,9 +544,7 @@ TRANSNORMER_INPUTS_DOCSTRING = r"""
531
  """
532
 
533
 
534
- @add_start_docstrings(
535
- TRANSNORMER_START_DOCSTRING,
536
- )
537
  class TransnormerModel(TransnormerPreTrainedModel):
538
  """
539
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`TransnormerDecoderLayer`]
@@ -557,29 +568,31 @@ class TransnormerModel(TransnormerPreTrainedModel):
557
  self.slopes = self._build_slope_tensor(config.decoder_attention_heads)
558
 
559
  # params
560
- self.embed_tokens = nn.Embedding(
561
- config.vocab_size, config.decoder_embed_dim, self.padding_idx
562
- )
563
  self.layers = nn.ModuleList([])
564
  for i in range(config.decoder_layers):
565
  if len(self.linear_use_lrpe_list) > 0:
566
  config.linear_use_lrpe = self.linear_use_lrpe_list[i]
567
  self.layers.append(TransnormerDecoderLayer(config))
568
 
569
- self.final_norm = get_norm_fn(config.norm_type)(config.decoder_embed_dim)
 
570
  self.embed_dim = config.decoder_embed_dim
571
- self.embed_scale = (
572
- 1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim)
573
- )
574
 
575
  # Initialize weights and apply final processing
576
  self.post_init()
577
 
578
  @staticmethod
579
  def _build_slope_tensor(n_attention_heads: int):
 
580
  def get_slopes(n):
 
581
  def get_slopes_power_of_2(n):
582
- start = 2 ** (-(2 ** -(math.log2(n) - 3)))
583
  ratio = start
584
  return [start * ratio**i for i in range(n)]
585
 
@@ -588,18 +601,15 @@ class TransnormerModel(TransnormerPreTrainedModel):
588
  n
589
  ) # In the paper, we only train models that have 2^a heads for some a. This function has
590
  else: # some good properties that only occur when the input is a power of 2. To maintain that even
591
- closest_power_of_2 = 2 ** math.floor(
592
  math.log2(n)
593
  ) # when the number of heads is not a power of 2, we use this workaround.
594
- return (
595
- get_slopes_power_of_2(closest_power_of_2)
596
- + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
597
- )
598
 
599
  # h, 1, 1
600
  slopes = torch.tensor(get_slopes(n_attention_heads)).reshape(
601
- n_attention_heads, 1, 1
602
- )
603
 
604
  return slopes
605
 
@@ -612,26 +622,26 @@ class TransnormerModel(TransnormerPreTrainedModel):
612
  def set_input_embeddings(self, value):
613
  self.embed_tokens = value
614
 
615
- def _prepare_decoder_linear_attn_mask(
616
- self, input_shape, inputs_embeds, past_key_values_length
617
- ):
618
  bsz, tgt_len = input_shape
619
  src_len = tgt_len + past_key_values_length
620
 
621
  def power_log(x):
622
- return 2 ** (math.ceil(math.log(x, 2)))
623
 
624
  n = power_log(max(tgt_len, src_len))
625
  if self._linear_attn_mask.shape[-1] < n:
626
 
627
  def get_mask(n):
628
- mask = torch.triu(torch.zeros(n, n).float().fill_(float("-inf")), 1)
 
629
  # no slope version
630
  # -n, ..., -2, -1, 0
631
  for i in range(n):
632
  x = torch.arange(i + 1)
633
  y = x
634
- mask[i, : i + 1] = -torch.flip(y, [0])
635
 
636
  return mask
637
 
@@ -643,7 +653,8 @@ class TransnormerModel(TransnormerPreTrainedModel):
643
  linear_attn_mask = self._linear_attn_mask[:, -tgt_len:, -src_len:]
644
  num_heads = linear_attn_mask.shape[0]
645
 
646
- return linear_attn_mask[None, :, :, :].expand(bsz, num_heads, tgt_len, src_len)
 
647
 
648
  @add_start_docstrings_to_model_forward(TRANSNORMER_INPUTS_DOCSTRING)
649
  def forward(
@@ -657,21 +668,15 @@ class TransnormerModel(TransnormerPreTrainedModel):
657
  output_hidden_states: Optional[bool] = None,
658
  return_dict: Optional[bool] = None,
659
  ) -> Union[Tuple, BaseModelOutputWithPast]:
660
- output_attentions = (
661
- output_attentions
662
- if output_attentions is not None
663
- else self.config.output_attentions
664
- )
665
- output_hidden_states = (
666
- output_hidden_states
667
- if output_hidden_states is not None
668
- else self.config.output_hidden_states
669
- )
670
  use_cache = use_cache if use_cache is not None else self.config.use_cache
671
 
672
- return_dict = (
673
- return_dict if return_dict is not None else self.config.use_return_dict
674
- )
675
 
676
  # retrieve input_ids and inputs_embeds
677
  if input_ids is not None and inputs_embeds is not None:
@@ -693,7 +698,7 @@ class TransnormerModel(TransnormerPreTrainedModel):
693
  if past_key_values is not None:
694
  past_key_values_length = past_key_values[0][0].shape[-2]
695
  seq_length_with_past = seq_length_with_past + past_key_values_length
696
-
697
  if inputs_embeds is None:
698
  # !!! use embed_scale
699
  inputs_embeds = self.embed_scale * self.embed_tokens(input_ids)
@@ -715,23 +720,23 @@ class TransnormerModel(TransnormerPreTrainedModel):
715
  ##### norm linear layers
716
  linear_attn_padding_mask = attn_padding_mask
717
  linear_attn_mask = self._prepare_decoder_linear_attn_mask(
718
- (batch_size, seq_length), inputs_embeds, past_key_values_length
719
- )
720
 
721
- slope_rates = [self.slopes.to(input_ids.device) for _ in range(self.num_layers)]
 
 
722
 
723
  for idx, layer in enumerate(self.layers):
724
  if output_hidden_states:
725
- all_hidden_states += (hidden_states,)
726
 
727
- past_key_value = (
728
- past_key_values[idx] if past_key_values is not None else None
729
- )
730
 
731
  slope_rate = slope_rates[idx]
732
  slope_rate = slope_rate * (1 - idx / (self.num_layers - 1) + 1e-5)
733
  mask = linear_attn_mask
734
-
735
  layer_outputs = layer(
736
  hidden_states,
737
  attn_mask=mask,
@@ -745,24 +750,24 @@ class TransnormerModel(TransnormerPreTrainedModel):
745
  hidden_states = layer_outputs[0]
746
 
747
  if use_cache:
748
- next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
 
749
 
750
  if output_attentions:
751
- all_self_attns += (layer_outputs[1],)
752
 
753
  hidden_states = self.final_norm(hidden_states)
754
 
755
  # add hidden states from the last decoder layer
756
  if output_hidden_states:
757
- all_hidden_states += (hidden_states,)
758
 
759
  next_cache = next_decoder_cache if use_cache else None
760
  if not return_dict:
761
  return tuple(
762
- v
763
- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
764
- if v is not None
765
- )
766
  return BaseModelOutputWithPast(
767
  last_hidden_state=hidden_states,
768
  past_key_values=next_cache,
@@ -772,6 +777,7 @@ class TransnormerModel(TransnormerPreTrainedModel):
772
 
773
 
774
  class TransnormerForCausalLM(TransnormerPreTrainedModel):
 
775
  def __init__(self, config):
776
  super().__init__(config)
777
  self.model = TransnormerModel(config)
@@ -779,9 +785,9 @@ class TransnormerForCausalLM(TransnormerPreTrainedModel):
779
  logging_info(self.model)
780
 
781
  # the lm_head weight is automatically tied to the embed tokens weight
782
- self.lm_head = nn.Linear(
783
- config.decoder_embed_dim, config.vocab_size, bias=False
784
- )
785
 
786
  # Initialize weights and apply final processing
787
  self.post_init()
@@ -805,9 +811,8 @@ class TransnormerForCausalLM(TransnormerPreTrainedModel):
805
  return self.model
806
 
807
  @add_start_docstrings_to_model_forward(TRANSNORMER_INPUTS_DOCSTRING)
808
- @replace_return_docstrings(
809
- output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
810
- )
811
  def forward(
812
  self,
813
  input_ids: torch.LongTensor = None,
@@ -845,19 +850,13 @@ class TransnormerForCausalLM(TransnormerPreTrainedModel):
845
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
846
  "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
847
  ```"""
848
- output_attentions = (
849
- output_attentions
850
- if output_attentions is not None
851
- else self.config.output_attentions
852
- )
853
- output_hidden_states = (
854
- output_hidden_states
855
- if output_hidden_states is not None
856
- else self.config.output_hidden_states
857
- )
858
- return_dict = (
859
- return_dict if return_dict is not None else self.config.use_return_dict
860
- )
861
 
862
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
863
  outputs = self.model(
@@ -888,8 +887,8 @@ class TransnormerForCausalLM(TransnormerPreTrainedModel):
888
  loss = loss_fct(shift_logits, shift_labels)
889
 
890
  if not return_dict:
891
- output = (logits,) + outputs[1:]
892
- return (loss,) + output if loss is not None else output
893
 
894
  return CausalLMOutputWithPast(
895
  loss=loss,
@@ -916,23 +915,18 @@ class TransnormerForCausalLM(TransnormerPreTrainedModel):
916
  else:
917
  model_inputs = {"input_ids": input_ids}
918
 
919
- model_inputs.update(
920
- {
921
- "past_key_values": past_key_values,
922
- "use_cache": kwargs.get("use_cache"),
923
- "attention_mask": attention_mask,
924
- }
925
- )
926
  return model_inputs
927
 
928
  @staticmethod
929
  def _reorder_cache(past_key_values, beam_idx):
930
  reordered_past = ()
931
  for layer_past in past_key_values:
932
- reordered_past += (
933
- tuple(
934
- past_state.index_select(0, beam_idx) for past_state in layer_past
935
- ),
936
- )
937
  return reordered_past
938
-
 
1
+ # Copyright 2024 OpenNLPLab
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
 
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
+
15
  # coding=utf-8
16
  """ PyTorch Transnormer model."""
17
  import math
 
53
 
54
  _CONFIG_FOR_DOC = "TransnormerConfig"
55
 
56
+ # TODO: fix environment: https://huggingface.co/OpenNLPLab/TransNormerLLM-7B/discussions/1
57
  use_triton = eval(os.environ.get("use_triton", default="True"))
58
  debug = eval(os.environ.get("debug", default="False"))
59
+ do_eval = eval(os.environ.get("do_eval", default="False"))
60
+ eval_and_not_generate = eval(
61
+ os.environ.get("eval_and_not_generate", default="False"))
62
+ BLOCK = 256
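These flags are read with eval(os.environ.get(...)) once, at import time, so overriding them only works if the environment variables are exported before this module is imported. A minimal sketch, with illustrative values:

import os

# Set before transformers loads modeling_transnormer.py, because the flags
# above are evaluated at import time.
os.environ["use_triton"] = "False"   # skip the Triton / lightning-attention kernels
os.environ["debug"] = "False"
os.environ["do_eval"] = "False"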
63
 
64
  if use_triton:
65
  try:
 
89
  ########## start Transnormer
90
  ##### Linearized Relative Positional Encoding: https://openreview.net/forum?id=xoLyps2qWc&referrer=%5BAuthor%20Console%5D(%2Fgroup%3Fid%3DTMLR%2FAuthors%23your-submissions)
91
  class Lrpe(nn.Module):
92
+
93
  def __init__(
94
  self,
95
  num_heads=8,
 
99
  d = num_heads * embed_dim
100
 
101
  self.index = torch.empty(0)
102
+ self.theta = nn.Parameter(10000**(-2 / d * torch.arange(d)).reshape(
103
+ num_heads, 1, -1))
 
104
 
105
  def extra_repr(self):
106
  return print_module(self)
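Lrpe.forward is outside this hunk, so the following is only a hedged sketch: it mirrors the theta construction restored above and shows one common way such frequencies are applied, a cos/sin feature map that doubles the last dimension. The actual implementation in this file may differ.

import torch

# Illustrative shapes only; not the TransNormer implementation.
num_heads, embed_dim, seq_len = 8, 4, 6
d = num_heads * embed_dim
theta = 10000 ** (-2 / d * torch.arange(d)).reshape(num_heads, 1, -1)   # (h, 1, e)
index = torch.arange(seq_len).reshape(1, -1, 1)                         # (1, n, 1)
phase = index * theta                                                   # (h, n, e)
x = torch.randn(1, num_heads, seq_len, embed_dim)                       # (b, h, n, e)
x_lrpe = torch.cat([x * torch.cos(phase), x * torch.sin(phase)], dim=-1)
print(x_lrpe.shape)                                                     # (1, 8, 6, 8)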
 
119
 
120
 
121
  class GLU(nn.Module):
122
+
123
  def __init__(self, d1, d2, bias=False):
124
  super().__init__()
125
  if debug:
 
142
 
143
 
144
  class NormLinearAttention(nn.Module):
145
+
146
  def __init__(
147
  self,
148
  embed_dim,
 
189
  use_cache: bool = False,
190
  slope_rate: Optional[torch.Tensor] = None,
191
  ):
 
192
  if (not self.training) and (not do_eval):
193
  return self.inference(
194
  x,
 
205
  q, k, v, u = self.qkvu_proj(x).chunk(4, dim=-1)
206
  # reshape
207
  q, k, v = map(
208
+ lambda x: rearrange(x, "b n (h d) -> b h n d", h=self.num_heads),
209
+ [q, k, v])
210
  # act
211
  q = self.act(q)
212
  k = self.act(k)
 
224
  # lrpe
225
  if self.linear_use_lrpe:
226
  q = self.lrpe(q, offset=q_offset)
227
+ k = self.lrpe(k, offset=q_offset)
228
 
229
  if attn_mask == None:
230
  attn_mask = (torch.tril(torch.ones(n, n))).to(q)
231
 
232
  if attn_padding_mask is not None:
233
  v = v.masked_fill(
234
+ (1 - attn_padding_mask).unsqueeze(1).unsqueeze(-1).to(
235
+ torch.bool), 0)
236
 
237
  if not has_lightning_attention:
238
  if slope_rate != None:
239
  attn_mask = torch.exp(slope_rate * attn_mask)
240
  output = linear_attention(q, k, v, attn_mask)
241
  else:
242
+ output = lightning_attention(q, k, v, True,
243
+ slope_rate.squeeze(-1).squeeze(-1))
 
244
 
245
  # reshape
246
  output = rearrange(output, "b h n d -> b n (h d)")
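When the fused lightning_attention kernel is unavailable, the fallback path weights each query-key score by the exponentiated decay mask before aggregating the values. A hedged sketch of that weighting (linear_attention itself is defined elsewhere in this file; the shapes and slope value below are assumptions):

import torch

b, h, n, d = 1, 2, 5, 4
q, k, v = torch.randn(3, b, h, n, d)
slope = torch.full((h, 1, 1), 0.5)
dist = torch.arange(n)[None, :] - torch.arange(n)[:, None]    # j - i, <= 0 for past tokens
mask = torch.where(dist <= 0, dist.float(), torch.tensor(float("-inf")))
decay = torch.exp(slope * mask)                                # zero above the diagonal
scores = torch.einsum("bhnd,bhmd->bhnm", q, k) * decay         # decayed causal scores
out = torch.einsum("bhnm,bhmd->bhnd", scores, v)
print(out.shape)                                               # torch.Size([1, 2, 5, 4])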
 
259
  return output, attn_weights, past_key_value
260
 
261
  def inference(
262
+ self,
263
+ x,
264
+ attn_mask: Optional[torch.Tensor] = None, # (b, h, n, m)
265
+ attn_padding_mask: Optional[torch.Tensor] = None, # (b, m)
266
+ output_attentions: bool = False,
267
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
268
+ use_cache: bool = False,
269
+ slope_rate: Optional[torch.Tensor] = None, # (h, 1, 1)
270
  ):
271
  # x: b n d
272
  n = x.shape[-2]
 
274
  q, k, v, u = self.qkvu_proj(x).chunk(4, dim=-1)
275
  # reshape
276
  q, k, v = map(
277
+ lambda x: rearrange(x, "b n (h d) -> b h n d", h=self.num_heads),
278
+ [q, k, v])
279
  # act
280
  q = self.act(q)
281
  k = self.act(k)
 
283
  # rpe
284
  if self.linear_use_lrpe:
285
  q = self.lrpe(q, offset=self.offset)
286
+ k = self.lrpe(k, offset=self.offset)
287
 
288
  if past_key_value == None:
289
  self.offset = q.shape[-2]
 
294
 
295
  # only use for the first time
296
  if past_key_value == None:
297
+ slope_rate = slope_rate.to(torch.float32)
298
  if attn_padding_mask is not None:
299
+ v = v.masked_fill(
300
+ (1 - attn_padding_mask).unsqueeze(1).unsqueeze(-1).to(
301
+ torch.bool), 0)
302
+ NUM_BLOCK = (n + BLOCK - 1) // BLOCK
303
+ b, h, n, d = q.shape
304
+ e = v.shape[-1]
305
+ # other
306
+ array = torch.arange(BLOCK).to(q) + 1 ## !!!! important
307
+ q_decay = torch.exp(-slope_rate * array.reshape(-1, 1))
308
+ k_decay = torch.exp(-slope_rate * (BLOCK - array.reshape(-1, 1)))
309
+ index = array[:, None] - array[None, :]
310
+ s_index = slope_rate * index[
311
+ None,
312
+ None,
313
+ ]
314
+ s_index = torch.where(index >= 0, -s_index, float("-inf"))
315
+ diag_decay = torch.exp(s_index)
316
+
317
+ kv = torch.zeros(b, h, d, e).to(torch.float32).to(q.device)
318
+ output = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device)
319
+ for i in range(NUM_BLOCK):
320
+ si = i * BLOCK
321
+ ei = min(si + BLOCK, n)
322
+ m = ei - si
323
+
324
+ qi = q[:, :, si:ei].contiguous()
325
+ ki = k[:, :, si:ei].contiguous()
326
+ vi = v[:, :, si:ei].contiguous()
327
+ qkv_none_diag = torch.matmul(qi * q_decay[:, :m],
328
+ kv).to(torch.float32)
329
+
330
+ # diag
331
+ qk = torch.matmul(qi, ki.transpose(-1, -2)).to(
332
+ torch.float32) * diag_decay[:, :, :m, :m]
333
+ qkv_diag = torch.matmul(qk, vi.to(torch.float32))
334
+ block_decay = torch.exp(-slope_rate * m)
335
+ output[:, :, si:ei] = qkv_none_diag + qkv_diag
336
+ kv = block_decay * kv + torch.matmul(
337
+ (ki * k_decay[:, -m:]).transpose(-1, -2).to(vi.dtype), vi)
338
  else:
339
  kv = past_key_value
340
 
 
342
  for i in range(n):
343
  kv = ratio * kv + torch.einsum(
344
  "... n d, ... n e -> ... d e",
345
+ k[:, :, i:i + 1],
346
+ v[:, :, i:i + 1],
 
 
 
347
  )
348
+ qkv = torch.einsum("... n e, ... e d -> ... n d",
349
+ q[:, :, i:i + 1], kv)
350
  output.append(qkv)
351
  output = torch.concat(output, dim=-2)
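The cached-decoding loop above applies the same decay one token at a time: kv is scaled by ratio and updated with the current key-value outer product, and the query then reads from it. A small numerical check, on assumed toy shapes, that this recurrence matches the quadratic decay-masked form:

import torch

torch.manual_seed(0)
h, n, d = 1, 6, 4
q, k, v = torch.randn(3, 1, h, n, d)
slope = torch.full((h, 1, 1), 0.3)
ratio = torch.exp(-slope)                                      # per-head decay

# recurrent form, mirroring the decoding loop above
kv = torch.zeros(1, h, d, d)
rec = []
for t in range(n):
    kv = ratio * kv + torch.einsum("bhnd,bhne->bhde", k[:, :, t:t + 1], v[:, :, t:t + 1])
    rec.append(torch.einsum("bhnd,bhde->bhne", q[:, :, t:t + 1], kv))
rec = torch.cat(rec, dim=-2)

# quadratic form with an explicit decay mask
dist = torch.arange(n)[:, None] - torch.arange(n)[None, :]     # i - j
decay = torch.exp(-slope * dist.clamp(min=0)) * (dist >= 0)
quad = torch.einsum("bhnd,bhmd->bhnm", q, k) * decay
quad = torch.einsum("bhnm,bhmd->bhnd", quad, v)
print(torch.allclose(rec, quad, atol=1e-5))                    # True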
352
 
 
365
 
366
 
367
  class TransnormerDecoderLayer(nn.Module):
368
+
369
  def __init__(self, config: TransnormerConfig):
370
  super().__init__()
371
  self.embed_dim = config.decoder_embed_dim
 
404
  return residual + x
405
 
406
  def forward(
407
+ self,
408
+ x,
409
+ attn_mask: Optional[torch.Tensor] = None,
410
+ attn_padding_mask: Optional[torch.Tensor] = None,
411
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
412
+ output_attentions: Optional[bool] = False,
413
+ use_cache: Optional[bool] = False,
414
+ slope_rate: Optional[torch.Tensor] = None, # (h, 1, 1)
415
  ):
416
  residual = x
417
  x = self.token_norm(x)
 
431
  x = self.channel_mixer(x)
432
  x = self.residual_connection(x, residual)
433
 
434
+ outputs = (x, )
435
 
436
  if output_attentions:
437
+ outputs += (self_attn_weights, )
438
 
439
  if use_cache:
440
+ outputs += (present_key_value, )
441
 
442
  return outputs
443
 
 
459
  """
460
 
461
 
462
+ @add_start_docstrings(TRANSNORMER_START_DOCSTRING, )
 
 
463
  class TransnormerPreTrainedModel(PreTrainedModel):
464
  config_class = TransnormerConfig
465
  base_model_prefix = "model"
 
544
  """
545
 
546
 
547
+ @add_start_docstrings(TRANSNORMER_START_DOCSTRING, )
 
 
548
  class TransnormerModel(TransnormerPreTrainedModel):
549
  """
550
  Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`TransnormerDecoderLayer`]
 
568
  self.slopes = self._build_slope_tensor(config.decoder_attention_heads)
569
 
570
  # params
571
+ self.embed_tokens = nn.Embedding(config.vocab_size,
572
+ config.decoder_embed_dim,
573
+ self.padding_idx)
574
  self.layers = nn.ModuleList([])
575
  for i in range(config.decoder_layers):
576
  if len(self.linear_use_lrpe_list) > 0:
577
  config.linear_use_lrpe = self.linear_use_lrpe_list[i]
578
  self.layers.append(TransnormerDecoderLayer(config))
579
 
580
+ self.final_norm = get_norm_fn(config.norm_type)(
581
+ config.decoder_embed_dim)
582
  self.embed_dim = config.decoder_embed_dim
583
+ self.embed_scale = (1.0 if config.no_scale_embedding else math.sqrt(
584
+ self.embed_dim))
 
585
 
586
  # Initialize weights and apply final processing
587
  self.post_init()
588
 
589
  @staticmethod
590
  def _build_slope_tensor(n_attention_heads: int):
591
+
592
  def get_slopes(n):
593
+
594
  def get_slopes_power_of_2(n):
595
+ start = 2**(-(2**-(math.log2(n) - 3)))
596
  ratio = start
597
  return [start * ratio**i for i in range(n)]
598
 
 
601
  n
602
  ) # In the paper, we only train models that have 2^a heads for some a. This function has
603
  else: # some good properties that only occur when the input is a power of 2. To maintain that even
604
+ closest_power_of_2 = 2**math.floor(
605
  math.log2(n)
606
  ) # when the number of heads is not a power of 2, we use this workaround.
607
+ return (get_slopes_power_of_2(closest_power_of_2) + get_slopes(
608
+ 2 * closest_power_of_2)[0::2][:n - closest_power_of_2])
 
 
609
 
610
  # h, 1, 1
611
  slopes = torch.tensor(get_slopes(n_attention_heads)).reshape(
612
+ n_attention_heads, 1, 1)
 
613
 
614
  return slopes
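As a worked example of the construction above: for 8 heads (a power of two) the slopes come out to 1/2, 1/4, ..., 1/256, the standard ALiBi recipe, and _build_slope_tensor then reshapes them to (8, 1, 1).

import math

def get_slopes_power_of_2(n):
    start = 2 ** (-(2 ** -(math.log2(n) - 3)))   # 0.5 for n = 8
    ratio = start
    return [start * ratio ** i for i in range(n)]

print(get_slopes_power_of_2(8))
# [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]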
615
 
 
622
  def set_input_embeddings(self, value):
623
  self.embed_tokens = value
624
 
625
+ def _prepare_decoder_linear_attn_mask(self, input_shape, inputs_embeds,
626
+ past_key_values_length):
 
627
  bsz, tgt_len = input_shape
628
  src_len = tgt_len + past_key_values_length
629
 
630
  def power_log(x):
631
+ return 2**(math.ceil(math.log(x, 2)))
632
 
633
  n = power_log(max(tgt_len, src_len))
634
  if self._linear_attn_mask.shape[-1] < n:
635
 
636
  def get_mask(n):
637
+ mask = torch.triu(
638
+ torch.zeros(n, n).float().fill_(float("-inf")), 1)
639
  # no slope version
640
  # -n, ..., -2, -1, 0
641
  for i in range(n):
642
  x = torch.arange(i + 1)
643
  y = x
644
+ mask[i, :i + 1] = -torch.flip(y, [0])
645
 
646
  return mask
647
 
 
653
  linear_attn_mask = self._linear_attn_mask[:, -tgt_len:, -src_len:]
654
  num_heads = linear_attn_mask.shape[0]
655
 
656
+ return linear_attn_mask[None, :, :, :].expand(bsz, num_heads, tgt_len,
657
+ src_len)
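For reference, get_mask above builds, for each row i, the values -i, ..., -1, 0 on and below the diagonal and -inf above it; exponentiating slope * mask later turns this into the causal decay weighting. A quick illustration with n = 4 (output shown as a comment):

import torch

n = 4
mask = torch.triu(torch.zeros(n, n).float().fill_(float("-inf")), 1)
for i in range(n):
    mask[i, : i + 1] = -torch.flip(torch.arange(i + 1), [0])
print(mask)
# tensor([[ 0., -inf, -inf, -inf],
#         [-1.,  0., -inf, -inf],
#         [-2., -1.,  0., -inf],
#         [-3., -2., -1.,  0.]])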
658
 
659
  @add_start_docstrings_to_model_forward(TRANSNORMER_INPUTS_DOCSTRING)
660
  def forward(
 
668
  output_hidden_states: Optional[bool] = None,
669
  return_dict: Optional[bool] = None,
670
  ) -> Union[Tuple, BaseModelOutputWithPast]:
671
+ output_attentions = (output_attentions if output_attentions is not None
672
+ else self.config.output_attentions)
673
+ output_hidden_states = (output_hidden_states
674
+ if output_hidden_states is not None else
675
+ self.config.output_hidden_states)
676
  use_cache = use_cache if use_cache is not None else self.config.use_cache
677
 
678
+ return_dict = (return_dict if return_dict is not None else
679
+ self.config.use_return_dict)
 
680
 
681
  # retrieve input_ids and inputs_embeds
682
  if input_ids is not None and inputs_embeds is not None:
 
698
  if past_key_values is not None:
699
  past_key_values_length = past_key_values[0][0].shape[-2]
700
  seq_length_with_past = seq_length_with_past + past_key_values_length
701
+
702
  if inputs_embeds is None:
703
  # !!! use embed_scale
704
  inputs_embeds = self.embed_scale * self.embed_tokens(input_ids)
 
720
  ##### norm linear layers
721
  linear_attn_padding_mask = attn_padding_mask
722
  linear_attn_mask = self._prepare_decoder_linear_attn_mask(
723
+ (batch_size, seq_length), inputs_embeds, past_key_values_length)
 
724
 
725
+ slope_rates = [
726
+ self.slopes.to(input_ids.device) for _ in range(self.num_layers)
727
+ ]
728
 
729
  for idx, layer in enumerate(self.layers):
730
  if output_hidden_states:
731
+ all_hidden_states += (hidden_states, )
732
 
733
+ past_key_value = (past_key_values[idx]
734
+ if past_key_values is not None else None)
 
735
 
736
  slope_rate = slope_rates[idx]
737
  slope_rate = slope_rate * (1 - idx / (self.num_layers - 1) + 1e-5)
738
  mask = linear_attn_mask
739
+
740
  layer_outputs = layer(
741
  hidden_states,
742
  attn_mask=mask,
 
750
  hidden_states = layer_outputs[0]
751
 
752
  if use_cache:
753
+ next_decoder_cache += (
754
+ layer_outputs[2 if output_attentions else 1], )
755
 
756
  if output_attentions:
757
+ all_self_attns += (layer_outputs[1], )
758
 
759
  hidden_states = self.final_norm(hidden_states)
760
 
761
  # add hidden states from the last decoder layer
762
  if output_hidden_states:
763
+ all_hidden_states += (hidden_states, )
764
 
765
  next_cache = next_decoder_cache if use_cache else None
766
  if not return_dict:
767
  return tuple(
768
+ v for v in
769
+ [hidden_states, next_cache, all_hidden_states, all_self_attns]
770
+ if v is not None)
 
771
  return BaseModelOutputWithPast(
772
  last_hidden_state=hidden_states,
773
  past_key_values=next_cache,
 
777
 
778
 
779
  class TransnormerForCausalLM(TransnormerPreTrainedModel):
780
+
781
  def __init__(self, config):
782
  super().__init__(config)
783
  self.model = TransnormerModel(config)
 
785
  logging_info(self.model)
786
 
787
  # the lm_head weight is automatically tied to the embed tokens weight
788
+ self.lm_head = nn.Linear(config.decoder_embed_dim,
789
+ config.vocab_size,
790
+ bias=False)
791
 
792
  # Initialize weights and apply final processing
793
  self.post_init()
 
811
  return self.model
812
 
813
  @add_start_docstrings_to_model_forward(TRANSNORMER_INPUTS_DOCSTRING)
814
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast,
815
+ config_class=_CONFIG_FOR_DOC)
 
816
  def forward(
817
  self,
818
  input_ids: torch.LongTensor = None,
 
850
  >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
851
  "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
852
  ```"""
853
+ output_attentions = (output_attentions if output_attentions is not None
854
+ else self.config.output_attentions)
855
+ output_hidden_states = (output_hidden_states
856
+ if output_hidden_states is not None else
857
+ self.config.output_hidden_states)
858
+ return_dict = (return_dict if return_dict is not None else
859
+ self.config.use_return_dict)
860
 
861
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
862
  outputs = self.model(
 
887
  loss = loss_fct(shift_logits, shift_labels)
888
 
889
  if not return_dict:
890
+ output = (logits, ) + outputs[1:]
891
+ return (loss, ) + output if loss is not None else output
892
 
893
  return CausalLMOutputWithPast(
894
  loss=loss,
 
915
  else:
916
  model_inputs = {"input_ids": input_ids}
917
 
918
+ model_inputs.update({
919
+ "past_key_values": past_key_values,
920
+ "use_cache": kwargs.get("use_cache"),
921
+ "attention_mask": attention_mask,
922
+ })
 
 
923
  return model_inputs
924
 
925
  @staticmethod
926
  def _reorder_cache(past_key_values, beam_idx):
927
  reordered_past = ()
928
  for layer_past in past_key_values:
929
+ reordered_past += (tuple(
930
+ past_state.index_select(0, beam_idx)
931
+ for past_state in layer_past), )
 
 
932
  return reordered_past
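For completeness, a hedged usage sketch: custom modeling code on the Hub is loaded with trust_remote_code=True. The repository id is taken from the TODO link above; dtype, device and generation settings are illustrative assumptions.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "OpenNLPLab/TransNormerLLM-7B"   # adjust to the checkpoint you actually use
tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True,
                                             torch_dtype=torch.bfloat16)
inputs = tokenizer("Hey, are you conscious? Can you talk to me?", return_tensors="pt")
generate_ids = model.generate(inputs.input_ids, max_length=30)
print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0])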