jbochi committed
Commit 9244048
1 Parent(s): 06e4cba

Fix activation and use rotary embeddings

Files changed (2)
  1. config.json +4 -2
  2. decoder_only_t5/modeling.py +425 -32
config.json CHANGED
@@ -9,7 +9,7 @@
   "decoder_start_token_id": 0,
   "pad_token_id": 1,
   "eos_token_id": 3,
-  "feed_forward_proj": "gated-gelu",
+  "feed_forward_proj": "gated-swish",
   "initializer_factor": 1.0,
   "is_encoder_decoder": false,
   "is_decoder_only": true,
@@ -29,5 +29,7 @@
   "vocab_size": 256512,
   "parallel_layers": true,
   "has_relative_attention_bias": false,
-  "multi_query_attention": true
+  "multi_query_attention": true,
+  "use_rotary_embedding": true,
+  "rotary_embedding_max_timescale": 1000
 }
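
These keys feed directly into the modeling changes below: `use_rotary_embedding` switches the attention layers to RoPE, `rotary_embedding_max_timescale` is passed as the rotary `base`, and the corrected `feed_forward_proj` selects a gated swish (SiLU) feed-forward in place of gated GELU. A minimal sketch of how the timescale sets the per-dimension rotation frequencies (the head dimension below is an illustrative assumption, not a value from this config):

```python
import torch

def rope_inv_freq(dim: int, base: float) -> torch.Tensor:
    # Same formula the rotary embedding in decoder_only_t5/modeling.py uses:
    # inv_freq[i] = 1 / base ** (2i / dim)
    return 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

head_dim = 128  # assumed head size, for illustration only
print(rope_inv_freq(head_dim, base=1000)[:4])   # base = rotary_embedding_max_timescale
print(rope_inv_freq(head_dim, base=10000)[:4])  # LLaMA-style default, for comparison
```

A smaller base shortens the longest rotation wavelength, which is on the order of 2π·base positions.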
decoder_only_t5/modeling.py CHANGED
@@ -36,6 +36,84 @@ class DecoderOnlyT5LayerFF(modeling_t5.T5LayerFF):
         self.dropout = nn.Dropout(config.dropout_rate)


+# LlamaRotaryEmbedding
+class T5DecoderOnlyRotaryEmbedding(nn.Module):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+        super().__init__()
+
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        inv_freq = 1.0 / (
+            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
+        )
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        # Build here to make `torch.jit.trace` work.
+        self._set_cos_sin_cache(
+            seq_len=max_position_embeddings,
+            device=self.inv_freq.device,
+            dtype=torch.get_default_dtype(),
+        )
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        t = torch.arange(
+            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
+        )
+
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if seq_len > self.max_seq_len_cached:
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+        return (
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
+        )
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`):
+            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+            used to pass offsetted position ids when working with a KV-cache.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
 # https://github.com/huggingface/transformers/blob/7ee995fd9c692761c4601ddbffa2ac2ec9f27b0b/src/transformers/models/llama/modeling_llama.py#L263
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
@@ -72,9 +150,16 @@ class DecoderOnlyT5Attention(modeling_t5.T5Attention):
         self.dropout = config.dropout_rate
         self.inner_dim = self.n_heads * self.key_value_proj_dim
         self.kv_inner_dim = self.n_kv_heads * self.key_value_proj_dim
+        if config.use_rotary_embedding:
+            self.rotary_embedding = T5DecoderOnlyRotaryEmbedding(
+                self.key_value_proj_dim,
+                max_position_embeddings=config.relative_attention_max_distance,
+                base=config.rotary_embedding_max_timescale,
+            )
+        else:
+            self.rotary_embedding = None

         # Mesh TensorFlow initialization to avoid scaling before softmax
-
         self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
         self.k = nn.Linear(self.d_model, self.kv_inner_dim, bias=False)
         self.v = nn.Linear(self.d_model, self.kv_inner_dim, bias=False)
@@ -93,6 +178,7 @@ class DecoderOnlyT5Attention(modeling_t5.T5Attention):
         mask=None,
         key_value_states=None,
         position_bias=None,
+        position_ids=None,
         past_key_value=None,
         layer_head_mask=None,
         query_length=None,
@@ -144,21 +230,25 @@ class DecoderOnlyT5Attention(modeling_t5.T5Attention):
                 # cross-attn
                 # (batch_size, n_kv_heads, seq_length, dim_per_head)
                 hidden_states = shape(proj_layer(key_value_states), self.n_kv_heads)
+            return hidden_states

-            if past_key_value is not None:
-                if key_value_states is None:
-                    # self-attn
-                    # (batch_size, n_kv_heads, key_length, dim_per_head)
-                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
-                elif past_key_value.shape[2] != key_value_states.shape[1]:
-                    # checking that the `sequence_length` of the `past_key_value` is the same as
-                    # the provided `key_value_states` to support prefix tuning
-                    # cross-attn
-                    # (batch_size, n_kv_heads, seq_length, dim_per_head)
-                    hidden_states = shape(proj_layer(key_value_states), self.n_kv_heads)
-                else:
-                    # cross-attn
-                    hidden_states = past_key_value
+        def concat_past_key_value(hidden_states, past_key_value, key_value_states):
+            if key_value_states is None:
+                # self-attn
+                # (batch_size, n_kv_heads, key_length, dim_per_head)
+                hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
+            elif past_key_value.shape[2] != key_value_states.shape[1]:
+                # checking that the `sequence_length` of the `past_key_value` is the same as
+                # the provided `key_value_states` to support prefix tuning
+                # cross-attn
+                # (batch_size, n_kv_heads, seq_length, dim_per_head)
+                raise NotImplementedError(
+                    "cross attention with RoPE and past KV is not implemented"
+                )
+                # hidden_states = shape(proj_layer(key_value_states), self.n_kv_heads)
+            else:
+                # cross-attn
+                hidden_states = past_key_value
             return hidden_states

         # get query states
@@ -167,24 +257,35 @@ class DecoderOnlyT5Attention(modeling_t5.T5Attention):
         ) # (batch_size, n_heads, seq_length, dim_per_head)

         # get key/value states
-        key_states = repeat_kv(
-            project(
-                hidden_states,
-                self.k,
-                key_value_states,
-                past_key_value[0] if past_key_value is not None else None,
-            ),
-            self.n_kv_groups,
-        )
-        value_states = repeat_kv(
-            project(
-                hidden_states,
-                self.v,
-                key_value_states,
-                past_key_value[1] if past_key_value is not None else None,
-            ),
-            self.n_kv_groups,
-        )
+        key_states = project(hidden_states, self.k, key_value_states, past_key_value)
+        value_states = project(hidden_states, self.v, key_value_states, past_key_value)
+
+        # RoPE
+        if self.rotary_embedding is not None:
+            kv_seq_len = key_states.shape[-2]
+            if past_key_value:
+                kv_seq_len += past_key_value[0].shape[-2]
+            cos, sin = self.rotary_embedding(query_states, seq_len=kv_seq_len)
+            query_states, key_states = apply_rotary_pos_emb(
+                query_states, key_states, cos, sin, position_ids
+            )
+
+        # concat past
+        if past_key_value is not None:
+            key_states = concat_past_key_value(
+                key_states,
+                past_key_value[0],
+                key_value_states,
+            )
+            value_states = concat_past_key_value(
+                value_states,
+                past_key_value[1],
+                key_value_states,
+            )
+
+        # MultiQueryDotProductAttention
+        key_states = repeat_kv(key_states, self.n_kv_groups)
+        value_states = repeat_kv(value_states, self.n_kv_groups)

         # compute scores
         scores = torch.matmul(
@@ -266,6 +367,7 @@ class DecoderOnlyT5LayerSelfAttention(modeling_t5.T5LayerSelfAttention):
         hidden_states,
         attention_mask=None,
         position_bias=None,
+        position_ids=None,
         layer_head_mask=None,
         past_key_value=None,
         use_cache=False,
@@ -279,6 +381,7 @@ class DecoderOnlyT5LayerSelfAttention(modeling_t5.T5LayerSelfAttention):
             x,
             mask=attention_mask,
             position_bias=position_bias,
+            position_ids=position_ids,
             layer_head_mask=layer_head_mask,
             past_key_value=past_key_value,
             use_cache=use_cache,
@@ -320,6 +423,7 @@ class DecoderOnlyT5Block(modeling_t5.T5Block):
         hidden_states,
         attention_mask=None,
         position_bias=None,
+        position_ids=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
         encoder_decoder_position_bias=None,
@@ -361,6 +465,7 @@ class DecoderOnlyT5Block(modeling_t5.T5Block):
             x,
             attention_mask=attention_mask,
             position_bias=position_bias,
+            position_ids=position_ids,
             layer_head_mask=layer_head_mask,
             past_key_value=self_attn_past_key_value,
             use_cache=use_cache,
@@ -398,6 +503,7 @@ class DecoderOnlyT5Block(modeling_t5.T5Block):
                 key_value_states=encoder_hidden_states,
                 attention_mask=encoder_attention_mask,
                 position_bias=encoder_decoder_position_bias,
+                # position_ids ?
                 layer_head_mask=cross_attn_layer_head_mask,
                 past_key_value=cross_attn_past_key_value,
                 query_length=query_length,
@@ -486,6 +592,284 @@ class DecoderOnlyT5Stack(modeling_t5.T5Stack):
         self.device_map = None
         self.gradient_checkpointing = False

+    def forward(
+        self,
+        input_ids=None,
+        position_ids=None,
+        attention_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        inputs_embeds=None,
+        head_mask=None,
+        cross_attn_head_mask=None,
+        past_key_values=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        # Model parallel
+        if self.model_parallel:
+            torch.cuda.set_device(self.first_device)
+            self.embed_tokens = self.embed_tokens.to(self.first_device)
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        if input_ids is not None and inputs_embeds is not None:
+            err_msg_prefix = "decoder_" if self.is_decoder else ""
+            raise ValueError(
+                f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
+            )
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            input_ids = input_ids.view(-1, input_shape[-1])
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            err_msg_prefix = "decoder_" if self.is_decoder else ""
+            raise ValueError(
+                f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds"
+            )
+
+        if position_ids is None:
+            seq_length = input_ids.shape[1]
+            past_key_values_length = (
+                0 if past_key_values is None else past_key_values[0][0].shape[2]
+            )
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length,
+                seq_length + past_key_values_length,
+                dtype=torch.long,
+                device=device,
+            )
+            position_ids = position_ids.unsqueeze(0)
+
+        if inputs_embeds is None:
+            if self.embed_tokens is None:
+                raise ValueError(
+                    "You have to initialize the model with valid token embeddings"
+                )
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        batch_size, seq_length = input_shape
+
+        # required mask seq length can be calculated via length of past
+        mask_seq_length = (
+            past_key_values[0][0].shape[2] + seq_length
+            if past_key_values is not None
+            else seq_length
+        )
+
+        if use_cache is True:
+            if not self.is_decoder:
+                raise ValueError(
+                    f"`use_cache` can only be set to `True` if {self} is used as a decoder"
+                )
+
+        if attention_mask is None:
+            attention_mask = torch.ones(
+                batch_size, mask_seq_length, device=inputs_embeds.device
+            )
+        if (
+            self.is_decoder
+            and encoder_attention_mask is None
+            and encoder_hidden_states is not None
+        ):
+            encoder_seq_length = encoder_hidden_states.shape[1]
+            encoder_attention_mask = torch.ones(
+                batch_size,
+                encoder_seq_length,
+                device=inputs_embeds.device,
+                dtype=torch.long,
+            )
+
+        # initialize past_key_values with `None` if past does not exist
+        if past_key_values is None:
+            past_key_values = [None] * len(self.block)
+
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        extended_attention_mask = self.get_extended_attention_mask(
+            attention_mask, input_shape
+        )
+
+        # If a 2D or 3D attention mask is provided for the cross-attention
+        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if self.is_decoder and encoder_hidden_states is not None:
+            (
+                encoder_batch_size,
+                encoder_sequence_length,
+                _,
+            ) = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+            if encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(
+                    encoder_hidden_shape, device=inputs_embeds.device
+                )
+            encoder_extended_attention_mask = self.invert_attention_mask(
+                encoder_attention_mask
+            )
+        else:
+            encoder_extended_attention_mask = None
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        # Prepare head mask if needed
+        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
+        cross_attn_head_mask = self.get_head_mask(
+            cross_attn_head_mask, self.config.num_layers
+        )
+        present_key_value_states = () if use_cache else None
+        all_hidden_states = () if output_hidden_states else None
+        all_attentions = () if output_attentions else None
+        all_cross_attentions = () if (output_attentions and self.is_decoder) else None
+        position_bias = None
+        encoder_decoder_position_bias = None
+
+        hidden_states = self.dropout(inputs_embeds)
+
+        for i, (layer_module, past_key_value) in enumerate(
+            zip(self.block, past_key_values)
+        ):
+            layer_head_mask = head_mask[i]
+            cross_attn_layer_head_mask = cross_attn_head_mask[i]
+            # Model parallel
+            if self.model_parallel:
+                torch.cuda.set_device(hidden_states.device)
+                # Ensure that attention_mask is always on the same device as hidden_states
+                if attention_mask is not None:
+                    attention_mask = attention_mask.to(hidden_states.device)
+                if position_bias is not None:
+                    position_bias = position_bias.to(hidden_states.device)
+                if encoder_hidden_states is not None:
+                    encoder_hidden_states = encoder_hidden_states.to(
+                        hidden_states.device
+                    )
+                if encoder_extended_attention_mask is not None:
+                    encoder_extended_attention_mask = (
+                        encoder_extended_attention_mask.to(hidden_states.device)
+                    )
+                if encoder_decoder_position_bias is not None:
+                    encoder_decoder_position_bias = encoder_decoder_position_bias.to(
+                        hidden_states.device
+                    )
+                if layer_head_mask is not None:
+                    layer_head_mask = layer_head_mask.to(hidden_states.device)
+                if cross_attn_layer_head_mask is not None:
+                    cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(
+                        hidden_states.device
+                    )
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    layer_module.forward,
+                    hidden_states,
+                    extended_attention_mask,
+                    position_bias,
+                    encoder_hidden_states,
+                    encoder_extended_attention_mask,
+                    encoder_decoder_position_bias,
+                    layer_head_mask,
+                    cross_attn_layer_head_mask,
+                    None,  # past_key_value is always None with gradient checkpointing
+                    use_cache,
+                    output_attentions,
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_states,
+                    attention_mask=extended_attention_mask,
+                    position_bias=position_bias,
+                    position_ids=position_ids,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_extended_attention_mask,
+                    encoder_decoder_position_bias=encoder_decoder_position_bias,
+                    layer_head_mask=layer_head_mask,
+                    cross_attn_layer_head_mask=cross_attn_layer_head_mask,
+                    past_key_value=past_key_value,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                )
+
+            # layer_outputs is a tuple with:
+            # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
+            if use_cache is False:
+                layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
+
+            hidden_states, present_key_value_state = layer_outputs[:2]
+
+            # We share the position biases between the layers - the first layer store them
+            # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
+            # (cross-attention position bias), (cross-attention weights)
+            position_bias = layer_outputs[2]
+            if self.is_decoder and encoder_hidden_states is not None:
+                encoder_decoder_position_bias = layer_outputs[
+                    4 if output_attentions else 3
+                ]
+            # append next layer key value states
+            if use_cache:
+                present_key_value_states = present_key_value_states + (
+                    present_key_value_state,
+                )
+
+            if output_attentions:
+                all_attentions = all_attentions + (layer_outputs[3],)
+                if self.is_decoder:
+                    all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
+
+            # Model Parallel: If it's the last layer for that device, put things on the next device
+            if self.model_parallel:
+                for k, v in self.device_map.items():
+                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
+                        hidden_states = hidden_states.to("cuda:" + str(k + 1))
+
+        hidden_states = self.final_layer_norm(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+
+        # Add last layer
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [
+                    hidden_states,
+                    present_key_value_states,
+                    all_hidden_states,
+                    all_attentions,
+                    all_cross_attentions,
+                ]
+                if v is not None
+            )
+        return modeling_t5.BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=hidden_states,
+            past_key_values=present_key_value_states,
+            hidden_states=all_hidden_states,
+            attentions=all_attentions,
+            cross_attentions=all_cross_attentions,
+        )
+

 class DecoderOnlyT5Model(modeling_t5.T5ForConditionalGeneration):
     def __init__(self, config: DecoderOnlyT5Config):
@@ -513,6 +897,14 @@ class DecoderOnlyT5Model(modeling_t5.T5ForConditionalGeneration):
         self.model_parallel = False
         self.device_map = None

+    def _tie_weights(self):
+        if not self.config.tie_word_embeddings:
+            return
+        if self.encoder:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+        if self.decoder:
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     @add_start_docstrings_to_model_forward(modeling_t5.T5_INPUTS_DOCSTRING)
     @replace_return_docstrings(
         output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
@@ -520,8 +912,8 @@ class DecoderOnlyT5Model(modeling_t5.T5ForConditionalGeneration):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
@@ -560,6 +952,7 @@ class DecoderOnlyT5Model(modeling_t5.T5ForConditionalGeneration):
         # Decode
         outputs = self.decoder(
             input_ids=input_ids,
+            position_ids=position_ids,
             attention_mask=attention_mask,
             inputs_embeds=inputs_embeds,
             past_key_values=past_key_values,
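
For reference, a self-contained sketch of the rotary application this commit introduces. The two helpers mirror `rotate_half` and `apply_rotary_pos_emb` added above; the tensor sizes and the base value are illustrative assumptions:

```python
import torch

def rotate_half(x):
    # Split the head dimension in two and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    # Gather the angles for the requested positions and broadcast over heads.
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

batch, heads, head_dim = 1, 4, 64  # assumed sizes
past_len, new_len = 7, 1           # decoding one token with a KV cache of length 7
base = 1000                        # rotary_embedding_max_timescale

# Build the cos/sin cache the same way T5DecoderOnlyRotaryEmbedding does.
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(past_len + new_len).float()
freqs = torch.einsum("i,j->ij", t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()

# Only the new token is projected; its position id is offset by the cache length,
# as in the position_ids computation added to DecoderOnlyT5Stack.forward.
q = torch.randn(batch, heads, new_len, head_dim)
k = torch.randn(batch, heads, new_len, head_dim)
position_ids = torch.arange(past_len, past_len + new_len).unsqueeze(0)

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
print(q_rot.shape, k_rot.shape)  # torch.Size([1, 4, 1, 64]) twice
```

Since config.json also sets "has_relative_attention_bias": false, this rotation is what carries position information into attention; in the attention code above, keys are rotated before being concatenated onto the cache, so cached entries never need to be re-rotated.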