jbochi committed
Commit a3055fa (parent: 334b292)

Update config and add decoding code (buggy)


It loads all weights and runs a forward pass, but the logits contain NaNs

Need to find the bug
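
A minimal sketch of the check described above, assuming the DecoderOnlyT5Config and DecoderOnlyT5Model classes added in this commit; the local checkpoint directory is hypothetical and not part of the commit:

import torch

from decoder_only_t5.config import DecoderOnlyT5Config
from decoder_only_t5.modeling import DecoderOnlyT5Model

# Hypothetical local directory with config.json and weights; replace with the real path.
checkpoint_dir = "./"
config = DecoderOnlyT5Config.from_pretrained(checkpoint_dir)
model = DecoderOnlyT5Model.from_pretrained(checkpoint_dir, config=config)
model.eval()

# Any in-vocabulary token ids work for a smoke test of the forward pass.
decoder_input_ids = torch.tensor([[0, 1, 2, 3]])
with torch.no_grad():
    outputs = model(decoder_input_ids=decoder_input_ids)

# This is where the NaNs reported above show up.
print(torch.isnan(outputs.logits).any())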

config.json CHANGED
@@ -23,8 +23,11 @@
   "relative_attention_max_distance": 128,
   "relative_attention_num_buckets": 32,
   "task_specific_params": {},
-  "tie_word_embeddings": false,
+  "tie_word_embeddings": true,
   "transformers_version": "4.23.1",
   "use_cache": true,
-  "vocab_size": 256512
+  "vocab_size": 256512,
+  "parallel_layers": true,
+  "has_relative_attention_bias": false,
+  "multi_query_attention": true
 }
decoder_only_t5/__init__.py ADDED
File without changes
decoder_only_t5/config.py ADDED
@@ -0,0 +1,11 @@
+from transformers.models.t5.configuration_t5 import T5Config
+
+
+class DecoderOnlyT5Config(T5Config):
+    is_decoder_only = True
+    # whether to call attention and mlp in parallel.
+    # https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/architectures/t5/t5_architecture.py#L384
+    parallel_layers = True
+    has_relative_attention_bias = False
+    # https://arxiv.org/abs/1911.02150
+    multi_query_attention = True
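
As a side note (not part of the commit), the new keys added to config.json should surface on this class when it is loaded through the standard from_pretrained machinery, since PretrainedConfig stores unrecognized config.json keys as instance attributes; the directory path below is hypothetical:

from decoder_only_t5.config import DecoderOnlyT5Config

# Hypothetical local directory containing the updated config.json.
config = DecoderOnlyT5Config.from_pretrained("./")
# Values read from config.json take precedence over the class-level defaults above.
print(config.parallel_layers, config.has_relative_attention_bias, config.multi_query_attention)
print(config.tie_word_embeddings, config.vocab_size)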
decoder_only_t5/modeling.py ADDED
@@ -0,0 +1,620 @@
+import copy
+from typing import Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from transformers.models.t5 import modeling_t5
+from transformers.modeling_outputs import Seq2SeqLMOutput
+from transformers.utils import (
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
+
+from decoder_only_t5.config import DecoderOnlyT5Config
+
+
+logger = logging.get_logger(__name__)
+_CONFIG_FOR_DOC = "DecoderOnlyT5Config"
+
+
+class DecoderOnlyT5LayerFF(modeling_t5.T5LayerFF):
+    def __init__(self, config: DecoderOnlyT5Config):
+        super(modeling_t5.T5LayerFF, self).__init__()
+        if config.is_gated_act:
+            self.DenseReluDense = modeling_t5.T5DenseGatedActDense(config)
+        else:
+            self.DenseReluDense = modeling_t5.T5DenseActDense(config)
+
+        if not config.parallel_layers:
+            self.layer_norm = modeling_t5.T5LayerNorm(
+                config.d_model, eps=config.layer_norm_epsilon
+            )
+        else:
+            self.layer_norm = nn.Identity()
+        self.dropout = nn.Dropout(config.dropout_rate)
+
+
+# https://github.com/huggingface/transformers/blob/7ee995fd9c692761c4601ddbffa2ac2ec9f27b0b/src/transformers/models/llama/modeling_llama.py#L263
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(
+        batch, num_key_value_heads, n_rep, slen, head_dim
+    )
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+class DecoderOnlyT5Attention(modeling_t5.T5Attention):
+    """
+    Supports both multi-head and multi-query attention.
+    https://arxiv.org/abs/1911.02150
+    https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/components/attention/dense_attention.py#L292
+    """
+
+    def __init__(self, config: DecoderOnlyT5Config, has_relative_attention_bias=False):
+        super(modeling_t5.T5Attention, self).__init__()
+        self.is_decoder = config.is_decoder
+        self.has_relative_attention_bias = has_relative_attention_bias
+        self.relative_attention_num_buckets = config.relative_attention_num_buckets
+        self.relative_attention_max_distance = config.relative_attention_max_distance
+        self.d_model = config.d_model
+        self.key_value_proj_dim = config.d_kv
+        self.n_heads = config.num_heads
+        self.n_kv_heads = 1 if config.multi_query_attention else self.n_heads
+        self.n_kv_groups = self.n_heads // self.n_kv_heads
+        self.dropout = config.dropout_rate
+        self.inner_dim = self.n_heads * self.key_value_proj_dim
+        self.kv_inner_dim = self.n_kv_heads * self.key_value_proj_dim
+
+        # Mesh TensorFlow initialization to avoid scaling before softmax
+
+        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
+        self.k = nn.Linear(self.d_model, self.kv_inner_dim, bias=False)
+        self.v = nn.Linear(self.d_model, self.kv_inner_dim, bias=False)
+        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
+
+        if self.has_relative_attention_bias:
+            self.relative_attention_bias = nn.Embedding(
+                self.relative_attention_num_buckets, self.n_heads
+            )
+        self.pruned_heads = set()
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states,
+        mask=None,
+        key_value_states=None,
+        position_bias=None,
+        past_key_value=None,
+        layer_head_mask=None,
+        query_length=None,
+        use_cache=False,
+        output_attentions=False,
+    ):
+        """
+        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
+        """
+        # Input is (batch_size, seq_length, dim)
+        # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
+        # past_key_value[0] is (batch_size, n_kv_heads, q_len - 1, dim_per_head)
+        batch_size, seq_length = hidden_states.shape[:2]
+
+        real_seq_length = seq_length
+
+        if past_key_value is not None:
+            if len(past_key_value) != 2:
+                raise ValueError(
+                    f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
+                )
+            real_seq_length += (
+                past_key_value[0].shape[2] if query_length is None else query_length
+            )
+
+        key_length = (
+            real_seq_length if key_value_states is None else key_value_states.shape[1]
+        )
+
+        def shape(states, n_heads):
+            """projection"""
+            return states.view(
+                batch_size, -1, n_heads, self.key_value_proj_dim
+            ).transpose(1, 2)
+
+        def unshape(states):
+            """reshape"""
+            return (
+                states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
+            )
+
+        def project(hidden_states, proj_layer, key_value_states, past_key_value):
+            """projects hidden states correctly to key/query states"""
+            if key_value_states is None:
+                # self-attn
+                # (batch_size, n_kv_heads, seq_length, dim_per_head)
+                hidden_states = shape(proj_layer(hidden_states), self.n_kv_heads)
+            elif past_key_value is None:
+                # cross-attn
+                # (batch_size, n_kv_heads, seq_length, dim_per_head)
+                hidden_states = shape(proj_layer(key_value_states), self.n_kv_heads)
+
+            if past_key_value is not None:
+                if key_value_states is None:
+                    # self-attn
+                    # (batch_size, n_kv_heads, key_length, dim_per_head)
+                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
+                elif past_key_value.shape[2] != key_value_states.shape[1]:
+                    # checking that the `sequence_length` of the `past_key_value` is the same as
+                    # the provided `key_value_states` to support prefix tuning
+                    # cross-attn
+                    # (batch_size, n_kv_heads, seq_length, dim_per_head)
+                    hidden_states = shape(proj_layer(key_value_states), self.n_kv_heads)
+                else:
+                    # cross-attn
+                    hidden_states = past_key_value
+            return hidden_states
+
+        # get query states
+        query_states = shape(
+            self.q(hidden_states), self.n_heads
+        )  # (batch_size, n_heads, seq_length, dim_per_head)
+
+        # get key/value states
+        key_states = project(
+            hidden_states,
+            self.k,
+            key_value_states,
+            past_key_value[0] if past_key_value is not None else None,
+        )
+        value_states = project(
+            hidden_states,
+            self.v,
+            key_value_states,
+            past_key_value[1] if past_key_value is not None else None,
+        )
+
+        # compute scores
+        scores = torch.matmul(
+            query_states, repeat_kv(key_states, self.n_kv_groups).transpose(3, 2)
+        )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
+
+        if position_bias is None:
+            if not self.has_relative_attention_bias:
+                position_bias = torch.zeros(
+                    (1, self.n_heads, real_seq_length, key_length),
+                    device=scores.device,
+                    dtype=scores.dtype,
+                )
+                if self.gradient_checkpointing and self.training:
+                    position_bias.requires_grad = True
+            else:
+                position_bias = self.compute_bias(
+                    real_seq_length, key_length, device=scores.device
+                )

+            # if key and values are already calculated
+            # we want only the last query position bias
+            if past_key_value is not None:
+                position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
+
+            if mask is not None:
+                position_bias = (
+                    position_bias + mask
+                )  # (batch_size, n_heads, seq_length, key_length)
+
+        if self.pruned_heads:
+            mask = torch.ones(position_bias.shape[1])
+            mask[list(self.pruned_heads)] = 0
+            position_bias_masked = position_bias[:, mask.bool()]
+        else:
+            position_bias_masked = position_bias
+
+        scores += position_bias_masked
+        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
+            scores
+        )  # (batch_size, n_heads, seq_length, key_length)
+        attn_weights = nn.functional.dropout(
+            attn_weights, p=self.dropout, training=self.training
+        )  # (batch_size, n_heads, seq_length, key_length)
+
+        # Mask heads if we want to
+        if layer_head_mask is not None:
+            attn_weights = attn_weights * layer_head_mask
+
+        attn_output = unshape(
+            torch.matmul(attn_weights, value_states)
+        )  # (batch_size, seq_length, dim)
+        attn_output = self.o(attn_output)
+
+        present_key_value_state = (
+            (key_states, value_states) if (self.is_decoder and use_cache) else None
+        )
+        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
+
+        if output_attentions:
+            outputs = outputs + (attn_weights,)
+        return outputs
+
+
+class DecoderOnlyT5LayerSelfAttention(modeling_t5.T5LayerSelfAttention):
+    def __init__(self, config, has_relative_attention_bias=False):
+        super(modeling_t5.T5LayerSelfAttention, self).__init__()
+        self.SelfAttention = DecoderOnlyT5Attention(
+            config, has_relative_attention_bias=has_relative_attention_bias
+        )
+        self.layer_norm = modeling_t5.T5LayerNorm(
+            config.d_model, eps=config.layer_norm_epsilon
+        )
+        self.dropout = nn.Dropout(config.dropout_rate)
+        self.parallel_layers = config.parallel_layers
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        position_bias=None,
+        layer_head_mask=None,
+        past_key_value=None,
+        use_cache=False,
+        output_attentions=False,
+    ):
+        if not self.parallel_layers:
+            x = self.layer_norm(hidden_states)
+        else:
+            x = hidden_states
+        attention_output = self.SelfAttention(
+            x,
+            mask=attention_mask,
+            position_bias=position_bias,
+            layer_head_mask=layer_head_mask,
+            past_key_value=past_key_value,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+        )
+        if not self.parallel_layers:
+            # When parallel_layers is True, the residual connection is applied
+            # in the decoder block instead of here.
+            hidden_states = hidden_states + self.dropout(attention_output[0])
+        else:
+            hidden_states = attention_output[0]
+        outputs = (hidden_states,) + attention_output[
+            1:
+        ]  # add attentions if we output them
+        return outputs
+
+
+class DecoderOnlyT5Block(modeling_t5.T5Block):
+    def __init__(self, config, has_relative_attention_bias=False):
+        super(modeling_t5.T5Block, self).__init__()
+        self.is_decoder = config.is_decoder
+        self.is_decoder_only = config.is_decoder_only
+        self.layer = nn.ModuleList()
+        self.layer.append(
+            DecoderOnlyT5LayerSelfAttention(
+                config, has_relative_attention_bias=has_relative_attention_bias
+            )
+        )
+        if self.is_decoder:
+            if config.is_decoder_only:
+                self.layer.append(nn.Identity())
+            else:
+                self.layer.append(modeling_t5.T5LayerCrossAttention(config))
+        self.parallel_layers = config.parallel_layers
+        self.layer.append(DecoderOnlyT5LayerFF(config))
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        position_bias=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        encoder_decoder_position_bias=None,
+        layer_head_mask=None,
+        cross_attn_layer_head_mask=None,
+        past_key_value=None,
+        use_cache=False,
+        output_attentions=False,
+        return_dict=True,
+    ):
+        if past_key_value is not None:
+            if not self.is_decoder:
+                logger.warning(
+                    "`past_key_values` is passed to the encoder. Please make sure this is intended."
+                )
+            expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
+
+            if len(past_key_value) != expected_num_past_key_values:
+                raise ValueError(
+                    f"There should be {expected_num_past_key_values} past states. "
+                    f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
+                    f"Got {len(past_key_value)} past key / value states"
+                )
+
+            self_attn_past_key_value = past_key_value[:2]
+            cross_attn_past_key_value = past_key_value[2:]
+        else:
+            self_attn_past_key_value, cross_attn_past_key_value = None, None
+
+        ff_layer = self.layer[-1]
+        if self.parallel_layers:
+            x = self.layer[0].layer_norm(hidden_states)
+            ff_output = ff_layer(hidden_states)
+        else:
+            x = hidden_states
+
+        self_attention_outputs = self.layer[0](
+            x,
+            attention_mask=attention_mask,
+            position_bias=position_bias,
+            layer_head_mask=layer_head_mask,
+            past_key_value=self_attn_past_key_value,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+        )
+        x, present_key_value_state = self_attention_outputs[:2]
+        attention_outputs = self_attention_outputs[
+            2:
+        ]  # Keep self-attention outputs and relative position weights
+
+        # clamp inf values to enable fp16 training
+        if x.dtype == torch.float16:
+            clamp_value = torch.where(
+                torch.isinf(x).any(),
+                torch.finfo(x.dtype).max - 1000,
+                torch.finfo(x.dtype).max,
+            )
+            x = torch.clamp(x, min=-clamp_value, max=clamp_value)
+
+        do_cross_attention = (
+            self.is_decoder
+            and not self.is_decoder_only
+            and encoder_hidden_states is not None
+        )
+        if do_cross_attention:
+            # the actual query length is unknown for cross attention
+            # if using past key value states. Need to inject it here
+            if present_key_value_state is not None:
+                query_length = present_key_value_state[0].shape[2]
+            else:
+                query_length = None
+
+            cross_attention_outputs = self.layer[1](
+                x,
+                key_value_states=encoder_hidden_states,
+                attention_mask=encoder_attention_mask,
+                position_bias=encoder_decoder_position_bias,
+                layer_head_mask=cross_attn_layer_head_mask,
+                past_key_value=cross_attn_past_key_value,
+                query_length=query_length,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+            )
+            x = cross_attention_outputs[0]
+
+            # clamp inf values to enable fp16 training
+            if x.dtype == torch.float16:
+                clamp_value = torch.where(
+                    torch.isinf(x).any(),
+                    torch.finfo(x.dtype).max - 1000,
+                    torch.finfo(x.dtype).max,
+                )
+                x = torch.clamp(x, min=-clamp_value, max=clamp_value)
+
+            # Combine self attn and cross attn key value states
+            if present_key_value_state is not None:
+                present_key_value_state = (
+                    present_key_value_state + cross_attention_outputs[1]
+                )
+
+            # Keep cross-attention outputs and relative position weights
+            attention_outputs = attention_outputs + cross_attention_outputs[2:]
+
+        if self.parallel_layers:
+            # https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/architectures/t5/t5_architecture.py#L295
+            hidden_states = x + ff_output
+            hidden_states *= 2**-0.5
+            hidden_states = hidden_states + self.layer[0].dropout(hidden_states)
+        else:
+            hidden_states = ff_layer(x)
+
+        # clamp inf values to enable fp16 training
+        if hidden_states.dtype == torch.float16:
+            clamp_value = torch.where(
+                torch.isinf(hidden_states).any(),
+                torch.finfo(hidden_states.dtype).max - 1000,
+                torch.finfo(hidden_states.dtype).max,
+            )
+            hidden_states = torch.clamp(
+                hidden_states, min=-clamp_value, max=clamp_value
+            )
+
+        outputs = (hidden_states,)
+
+        if use_cache:
+            outputs = outputs + (present_key_value_state,) + attention_outputs
+        else:
+            outputs = outputs + attention_outputs
+
+        return outputs  # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
+
+
+class DecoderOnlyT5Stack(modeling_t5.T5Stack):
+    def __init__(self, config, embed_tokens=None):
+        super(modeling_t5.T5Stack, self).__init__(config)
+
+        self.embed_tokens = embed_tokens
+        self.is_decoder = config.is_decoder
+
+        self.block = nn.ModuleList(
+            [
+                DecoderOnlyT5Block(
+                    config,
+                    has_relative_attention_bias=(
+                        config.has_relative_attention_bias and bool(i == 0)
+                    ),
+                )
+                for i in range(config.num_layers)
+            ]
+        )
+        if not config.parallel_layers:
+            self.final_layer_norm = modeling_t5.T5LayerNorm(
+                config.d_model, eps=config.layer_norm_epsilon
+            )
+        else:
+            self.final_layer_norm = nn.Identity()
+        self.dropout = nn.Dropout(config.dropout_rate)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+        # Model parallel
+        self.model_parallel = False
+        self.device_map = None
+        self.gradient_checkpointing = False
+
+
+class DecoderOnlyT5Model(modeling_t5.T5ForConditionalGeneration):
+    def __init__(self, config: DecoderOnlyT5Config):
+        super(modeling_t5.T5ForConditionalGeneration, self).__init__(config)
+        self.model_dim = config.d_model
+
+        self.shared = nn.Embedding(config.vocab_size, config.d_model)
+        assert (
+            self.config.num_layers == 0
+        ), "Decoder only model cannot have encoder layers"
+        self.encoder = None
+
+        decoder_config = copy.deepcopy(config)
+        decoder_config.is_decoder = True
+        decoder_config.is_encoder_decoder = False
+        decoder_config.num_layers = config.num_decoder_layers
+        self.decoder = DecoderOnlyT5Stack(decoder_config, self.shared)
+
+        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+        # Model parallel
+        self.model_parallel = False
+        self.device_map = None
+
+    @add_start_docstrings_to_model_forward(modeling_t5.T5_INPUTS_DOCSTRING)
+    @replace_return_docstrings(
+        output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
+    )
+    def forward(
+        self,
+        _input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.BoolTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        decoder_head_mask: Optional[torch.FloatTensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        _inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
+            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
+            labels in `[0, ..., config.vocab_size]`
+
+        Returns:
+
+        Examples:
+
+        ```"""
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        if self.model_parallel:
+            torch.cuda.set_device(self.decoder.first_device)
+
+        if (
+            labels is not None
+            and decoder_input_ids is None
+            and decoder_inputs_embeds is None
+        ):
+            # get decoder inputs from shifting lm labels to the right
+            decoder_input_ids = self._shift_right(labels)
+
+        # Set device for model parallelism
+        if self.model_parallel:
+            torch.cuda.set_device(self.decoder.first_device)
+            if decoder_input_ids is not None:
+                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
+            if attention_mask is not None:
+                attention_mask = attention_mask.to(self.decoder.first_device)
+            if decoder_attention_mask is not None:
+                decoder_attention_mask = decoder_attention_mask.to(
+                    self.decoder.first_device
+                )
+
+        # Decode
+        decoder_outputs = self.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            inputs_embeds=decoder_inputs_embeds,
+            past_key_values=past_key_values,
+            # encoder_hidden_states=hidden_states,
+            encoder_attention_mask=attention_mask,
+            head_mask=decoder_head_mask,
+            cross_attn_head_mask=cross_attn_head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = decoder_outputs[0]
+
+        # Set device for model parallelism
+        if self.model_parallel:
+            torch.cuda.set_device(self.decoder.first_device)
+            self.lm_head = self.lm_head.to(self.decoder.first_device)
+            sequence_output = sequence_output.to(self.lm_head.weight.device)
+
+        if self.config.tie_word_embeddings:
+            # Rescale output before projecting on vocab
+            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+            sequence_output = sequence_output * (self.model_dim**-0.5)
+
+        lm_logits = self.lm_head(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss(ignore_index=-100)
+            # move labels to correct device to enable PP
+            labels = labels.to(lm_logits.device)
+            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
+            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
+
+        if not return_dict:
+            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
+            return ((loss,) + output) if loss is not None else output
+
+        return Seq2SeqLMOutput(
+            loss=loss,
+            logits=lm_logits,
+            past_key_values=decoder_outputs.past_key_values,
+            decoder_hidden_states=decoder_outputs.hidden_states,
+            decoder_attentions=decoder_outputs.attentions,
+        )