ClueAI committed on
Commit 9f52142
1 Parent(s): 0051c59

Update modeling_t5.py

Files changed (1)
  1. modeling_t5.py +6 -1939
modeling_t5.py CHANGED
@@ -1,1842 +1,18 @@
1
- # coding=utf-8
2
- # Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """ PyTorch T5 model."""
16
-
17
-
18
- import copy
19
- import math
20
- import os
21
- import warnings
22
- from typing import Optional, Tuple, Union
23
- from typing import Optional, Tuple, Union, List, Callable
24
-
25
- import torch
26
- from torch import nn
27
- from torch.nn import CrossEntropyLoss
28
- from torch.utils.checkpoint import checkpoint
29
-
30
- from transformers.activations import ACT2FN
31
- from transformers.adapters.composition import adjust_tensors_for_parallel
32
- from transformers.adapters.context import ForwardContext
33
- from transformers.adapters.lora import Linear as LoRALinear
34
- from transformers.adapters.mixins.t5 import (
35
- T5CrossAttentionLayerAdaptersMixin,
36
- T5FFLayerAdaptersMixin,
37
- T5ModelAdaptersMixin,
38
- T5ModelWithHeadsAdaptersMixin,
39
- T5SelfAttentionLayerAdaptersMixin,
40
- )
41
- from transformers.adapters.model_mixin import InvertibleAdaptersMixin
42
- from transformers.adapters.prefix_tuning import PrefixTuningShim
43
- from transformers.modeling_outputs import (
44
- BaseModelOutput,
45
- BaseModelOutputWithPastAndCrossAttentions,
46
- Seq2SeqLMOutput,
47
- Seq2SeqModelOutput,
48
- )
49
- from transformers.modeling_utils import PreTrainedModel
50
- from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
51
- from transformers.utils import (
52
- DUMMY_INPUTS,
53
- DUMMY_MASK,
54
- add_start_docstrings,
55
- add_start_docstrings_to_model_forward,
56
- is_torch_fx_proxy,
57
- logging,
58
- replace_return_docstrings,
59
- )
60
- from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
61
  from transformers.models.t5.configuration_t5 import T5Config
 
62
 
63
 
64
- logger = logging.get_logger(__name__)
65
-
66
- _CONFIG_FOR_DOC = "T5Config"
67
- _CHECKPOINT_FOR_DOC = "t5-small"
68
-
69
- ####################################################
70
- # This dict contains ids and associated url
71
- # for the pretrained weights provided with the models
72
- ####################################################
73
- T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
74
- "t5-small",
75
- "t5-base",
76
- "t5-large",
77
- "t5-3b",
78
- "t5-11b",
79
- # See all T5 models at https://huggingface.co/models?filter=t5
80
- ]
81
-
82
-
83
- ####################################################
84
- # This is a conversion method from TF 1.0 to PyTorch
85
- # More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28
86
- ####################################################
87
- def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
88
- """Load tf checkpoints in a pytorch model."""
89
- try:
90
- import re
91
-
92
- import numpy as np
93
- import tensorflow as tf
94
- except ImportError:
95
- logger.error(
96
- "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
97
- "https://www.tensorflow.org/install/ for installation instructions."
98
- )
99
- raise
100
- tf_path = os.path.abspath(tf_checkpoint_path)
101
- logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
102
- # Load weights from TF model
103
- init_vars = tf.train.list_variables(tf_path)
104
- names = []
105
- tf_weights = {}
106
- for name, shape in init_vars:
107
- logger.info(f"Loading TF weight {name} with shape {shape}")
108
- array = tf.train.load_variable(tf_path, name)
109
- names.append(name)
110
- tf_weights[name] = array
111
-
112
- for txt_name in names:
113
- name = txt_name.split("/")
114
- # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
115
- # which are not required for using pretrained model
116
- if any(
117
- n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
118
- for n in name
119
- ):
120
- logger.info(f"Skipping {'/'.join(name)}")
121
- tf_weights.pop(txt_name, None)
122
- continue
123
- if "_slot_" in name[-1]:
124
- logger.info(f"Skipping {'/'.join(name)}")
125
- tf_weights.pop(txt_name, None)
126
- continue
127
- pointer = model
128
- array = tf_weights[txt_name]
129
-
130
- for m_name in name:
131
- if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
132
- scope_names = re.split(r"_(\d+)", m_name)
133
- else:
134
- scope_names = [m_name]
135
- if scope_names[0] in ["kernel", "scale", "embedding"]:
136
- pointer = getattr(pointer, "weight")
137
- elif scope_names[0] == "self_attention":
138
- pointer = getattr(pointer, "layer")
139
- pointer = pointer[0]
140
- elif scope_names[0] == "enc_dec_attention":
141
- pointer = getattr(pointer, "layer")
142
- pointer = pointer[1]
143
- elif scope_names[0] == "dense_relu_dense":
144
- pointer = getattr(pointer, "layer")
145
- pointer = pointer[2]
146
- elif scope_names[0] == "rms_norm":
147
- if hasattr(pointer, "layer_norm"):
148
- pointer = getattr(pointer, "layer_norm")
149
- elif hasattr(pointer, "final_layer_norm"):
150
- pointer = getattr(pointer, "final_layer_norm")
151
- elif scope_names[0] == "scale":
152
- pointer = getattr(pointer, "weight")
153
- elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
154
- pointer = getattr(pointer, "bias")
155
- elif scope_names[0] == "squad":
156
- pointer = getattr(pointer, "classifier")
157
- elif scope_names[0] == "decoder" and name[1] == "logits":
158
- continue
159
- elif scope_names[0] == "logits":
160
- pointer = getattr(pointer, "lm_head")
161
- elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit():
162
- pointer = getattr(pointer, f"wi_{scope_names[1]}")
163
- continue
164
- else:
165
- try:
166
- pointer = getattr(pointer, scope_names[0])
167
- except AttributeError:
168
- logger.info(f"Skipping {'/'.join(name)}")
169
- continue
170
- if len(scope_names) >= 2:
171
- num = int(scope_names[1])
172
- pointer = pointer[num]
173
- if scope_names[0] not in ["kernel", "scale", "embedding"]:
174
- pointer = getattr(pointer, "weight")
175
- if scope_names[0] != "embedding":
176
- logger.info(f"Transposing numpy weight of shape {array.shape} for {name}")
177
- array = np.transpose(array)
178
- try:
179
- assert (
180
- pointer.shape == array.shape
181
- ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
182
- except AssertionError as e:
183
- e.args += (pointer.shape, array.shape)
184
- raise
185
- logger.info(f"Initialize PyTorch weight {name}")
186
- pointer.data = torch.from_numpy(array.astype(np.float32))
187
- tf_weights.pop(txt_name, None)
188
-
189
- logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.")
190
- return model
191
-
192
-
193
- ####################################################
194
- # PyTorch Models are constructed by sub-classing
195
- # - torch.nn.Module for the layers and
196
- # - PreTrainedModel for the models (it-self a sub-class of nn.Module)
197
- ####################################################
198
- PARALLELIZE_DOCSTRING = r"""
199
- This is an experimental feature and is a subject to change at a moment's notice.
200
-
201
- Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
202
- it will evenly distribute blocks across all devices.
203
-
204
- Args:
205
- device_map (`Dict[int, list]`, optional, defaults to None):
206
- A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
207
- automatically mapped to the first device (for esoteric reasons). That means that the first device should
208
- have fewer attention modules mapped to it than other devices. For reference, the t5 models have the
209
- following number of attention modules:
210
-
211
- - t5-small: 6
212
- - t5-base: 12
213
- - t5-large: 24
214
- - t5-3b: 24
215
- - t5-11b: 24
216
-
217
- Example:
218
-
219
- ```python
220
- # Here is an example of a device map on a machine with 4 GPUs using t5-3b, which has a total of 24 attention modules:
221
- model = T5ForConditionalGeneration.from_pretrained("t5-3b")
222
- device_map = {
223
- 0: [0, 1, 2],
224
- 1: [3, 4, 5, 6, 7, 8, 9],
225
- 2: [10, 11, 12, 13, 14, 15, 16],
226
- 3: [17, 18, 19, 20, 21, 22, 23],
227
- }
228
- model.parallelize(device_map)
229
- ```
230
- """
231
- DEPARALLELIZE_DOCSTRING = r"""
232
- Moves the model to cpu from a model parallel state.
233
-
234
- Example:
235
-
236
- ```python
237
- # On a 4 GPU machine with t5-3b:
238
- model = T5ForConditionalGeneration.from_pretrained("t5-3b")
239
- device_map = {
240
- 0: [0, 1, 2],
241
- 1: [3, 4, 5, 6, 7, 8, 9],
242
- 2: [10, 11, 12, 13, 14, 15, 16],
243
- 3: [17, 18, 19, 20, 21, 22, 23],
244
- }
245
- model.parallelize(device_map) # Splits the model across several devices
246
- model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
247
- ```
248
- """
249
-
250
-
251
- class T5LayerNorm(nn.Module):
252
- def __init__(self, hidden_size, eps=1e-6):
253
- """
254
- Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
255
- """
256
- super().__init__()
257
- self.weight = nn.Parameter(torch.ones(hidden_size))
258
- self.variance_epsilon = eps
259
-
260
- def forward(self, hidden_states):
261
-
262
- # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
263
- # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated
264
- # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
265
- # half-precision inputs is done in fp32
266
-
267
- variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
268
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
269
-
270
- # convert into half-precision if necessary
271
- if self.weight.dtype in [torch.float16, torch.bfloat16]:
272
- hidden_states = hidden_states.to(self.weight.dtype)
273
-
274
- return self.weight * hidden_states
275
-
276
-
277
- # try:
278
- # from apex.normalization import FusedRMSNorm
279
-
280
- # T5LayerNorm = FusedRMSNorm # noqa
281
-
282
- # logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm")
283
- # except ImportError:
284
- # # using the normal T5LayerNorm
285
- # pass
286
- # except Exception:
287
- # logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm")
288
- # pass
289
-
290
- ALL_LAYERNORM_LAYERS.append(T5LayerNorm)
291
-
292
-
293
- class T5DenseActDense(nn.Module):
294
- def __init__(self, config: T5Config):
295
- super().__init__()
296
- self.wi = LoRALinear(config.d_model, config.d_ff, "intermediate", config, bias=False)
297
- self.wo = LoRALinear(config.d_ff, config.d_model, "output", config, bias=False)
298
- self.dropout = nn.Dropout(config.dropout_rate)
299
- self.act = ACT2FN[config.dense_act_fn]
300
-
301
- def forward(self, hidden_states):
302
- hidden_states = self.wi(hidden_states)
303
- hidden_states = self.act(hidden_states)
304
- hidden_states = self.dropout(hidden_states)
305
- if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
306
- hidden_states = hidden_states.to(self.wo.weight.dtype)
307
- hidden_states = self.wo(hidden_states)
308
- return hidden_states
309
-
310
-
311
- class T5DenseGatedActDense(nn.Module):
312
- def __init__(self, config: T5Config):
313
- super().__init__()
314
- self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
315
- self.wi_1 = LoRALinear(config.d_model, config.d_ff, "intermediate", config, bias=False)
316
- self.wo = LoRALinear(config.d_ff, config.d_model, "output", config, bias=False)
317
- self.dropout = nn.Dropout(config.dropout_rate)
318
- self.act = ACT2FN[config.dense_act_fn]
319
-
320
- def forward(self, hidden_states):
321
- hidden_gelu = self.act(self.wi_0(hidden_states))
322
- hidden_linear = self.wi_1(hidden_states)
323
- hidden_states = hidden_gelu * hidden_linear
324
- hidden_states = self.dropout(hidden_states)
325
-
326
- # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
327
- # See https://github.com/huggingface/transformers/issues/20287
328
- # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None``
329
- if hidden_states.dtype != self.wo.weight.dtype and self.wo.weight.dtype != torch.int8:
330
- hidden_states = hidden_states.to(self.wo.weight.dtype)
331
-
332
- hidden_states = self.wo(hidden_states)
333
- return hidden_states
334
-
335
-
336
- class T5LayerFF(T5FFLayerAdaptersMixin, nn.Module):
337
- def __init__(self, config: T5Config):
338
- super().__init__()
339
- self.config = config
340
- if config.is_gated_act:
341
- self.DenseReluDense = T5DenseGatedActDense(config)
342
- else:
343
- self.DenseReluDense = T5DenseActDense(config)
344
-
345
- self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
346
- self.dropout = nn.Dropout(config.dropout_rate)
347
- self._init_adapter_modules()
348
-
349
- def forward(self, hidden_states):
350
- forwarded_states = self.layer_norm(hidden_states)
351
- forwarded_states = self.DenseReluDense(forwarded_states)
352
- hidden_states = self.adapter_layer_forward(
353
- hidden_states=self.dropout(forwarded_states), residual_input=hidden_states, layer_norm=None
354
- )
355
- return hidden_states
356
-
357
-
358
- class T5Attention(nn.Module):
359
- def __init__(self, config: T5Config, has_relative_attention_bias=False, location_key: Optional[str] = None):
360
- super().__init__()
361
- self.is_decoder = config.is_decoder
362
- self.has_relative_attention_bias = has_relative_attention_bias
363
- self.relative_attention_num_buckets = config.relative_attention_num_buckets
364
- self.relative_attention_max_distance = config.relative_attention_max_distance
365
- self.d_model = config.d_model
366
- self.key_value_proj_dim = config.d_kv
367
- self.n_heads = config.num_heads
368
- self.dropout = config.dropout_rate
369
- self.inner_dim = self.n_heads * self.key_value_proj_dim
370
-
371
- # Mesh TensorFlow initialization to avoid scaling before softmax
372
- self.q = LoRALinear(self.d_model, self.inner_dim, "selfattn", config, attn_key="q", bias=False)
373
- self.k = LoRALinear(self.d_model, self.inner_dim, "selfattn", config, attn_key="k", bias=False)
374
- self.v = LoRALinear(self.d_model, self.inner_dim, "selfattn", config, attn_key="v", bias=False)
375
- self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
376
-
377
- if self.has_relative_attention_bias:
378
- self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
379
- self.pruned_heads = set()
380
- self.gradient_checkpointing = False
381
-
382
- self.prefix_tuning = PrefixTuningShim(location_key + "_prefix" if location_key else None, config)
383
-
384
- def prune_heads(self, heads):
385
- if len(heads) == 0:
386
- return
387
- heads, index = find_pruneable_heads_and_indices(
388
- heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
389
- )
390
- # Prune linear layers
391
- self.q = prune_linear_layer(self.q, index)
392
- self.k = prune_linear_layer(self.k, index)
393
- self.v = prune_linear_layer(self.v, index)
394
- self.o = prune_linear_layer(self.o, index, dim=1)
395
- # Update hyper params
396
- self.n_heads = self.n_heads - len(heads)
397
- self.inner_dim = self.key_value_proj_dim * self.n_heads
398
- self.pruned_heads = self.pruned_heads.union(heads)
399
-
400
- @staticmethod
401
- def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
402
- """
403
- Adapted from Mesh Tensorflow:
404
- https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
405
-
406
- Translate relative position to a bucket number for relative attention. The relative position is defined as
407
- memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
408
- position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
409
- small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
410
- positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
411
- This should allow for more graceful generalization to longer sequences than the model has been trained on
412
-
413
- Args:
414
- relative_position: an int32 Tensor
415
- bidirectional: a boolean - whether the attention is bidirectional
416
- num_buckets: an integer
417
- max_distance: an integer
418
-
419
- Returns:
420
- a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
421
- """
422
- relative_buckets = 0
423
- if bidirectional:
424
- num_buckets //= 2
425
- relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
426
- relative_position = torch.abs(relative_position)
427
- else:
428
- relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
429
- # now relative_position is in the range [0, inf)
430
-
431
- # half of the buckets are for exact increments in positions
432
- max_exact = num_buckets // 2
433
- is_small = relative_position < max_exact
434
-
435
- # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
436
- relative_position_if_large = max_exact + (
437
- torch.log(relative_position.float() / max_exact)
438
- / math.log(max_distance / max_exact)
439
- * (num_buckets - max_exact)
440
- ).to(torch.long)
441
- relative_position_if_large = torch.min(
442
- relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
443
- )
444
-
445
- relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
446
- return relative_buckets
447
-
448
- def compute_bias(self, query_length, key_length, device=None):
449
- """Compute binned relative position bias"""
450
- if device is None:
451
- device = self.relative_attention_bias.weight.device
452
- context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
453
- memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
454
- relative_position = memory_position - context_position # shape (query_length, key_length)
455
- relative_position_bucket = self._relative_position_bucket(
456
- relative_position, # shape (query_length, key_length)
457
- bidirectional=(not self.is_decoder),
458
- num_buckets=self.relative_attention_num_buckets,
459
- max_distance=self.relative_attention_max_distance,
460
- )
461
- values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
462
- values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
463
- return values
464
-
465
- def forward(
466
- self,
467
- hidden_states,
468
- mask=None,
469
- key_value_states=None,
470
- position_bias=None,
471
- past_key_value=None,
472
- layer_head_mask=None,
473
- query_length=None,
474
- use_cache=False,
475
- output_attentions=False,
476
- ):
477
- """
478
- Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
479
- """
480
- # Input is (batch_size, seq_length, dim)
481
- # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
482
- # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
483
- batch_size, seq_length = hidden_states.shape[:2]
484
-
485
- real_seq_length = seq_length
486
-
487
- if past_key_value is not None:
488
- assert (
489
- len(past_key_value) == 2
490
- ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
491
- real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
492
-
493
- key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
494
-
495
- def shape(states):
496
- """projection"""
497
- return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
498
-
499
- def unshape(states):
500
- """reshape"""
501
- return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
502
-
503
- def project(hidden_states, proj_layer, key_value_states, past_key_value):
504
- """projects hidden states correctly to key/query states"""
505
- if key_value_states is None:
506
- # self-attn
507
- # (batch_size, n_heads, seq_length, dim_per_head)
508
- hidden_states = shape(proj_layer(hidden_states))
509
- elif past_key_value is None:
510
- # cross-attn
511
- # (batch_size, n_heads, seq_length, dim_per_head)
512
- hidden_states = shape(proj_layer(key_value_states))
513
-
514
- if past_key_value is not None:
515
- if key_value_states is None:
516
- # self-attn
517
- # (batch_size, n_heads, key_length, dim_per_head)
518
- hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
519
- elif past_key_value.shape[2] != key_value_states.shape[1]:
520
- # checking that the `sequence_length` of the `past_key_value` is the same as
521
- # the provided `key_value_states` to support prefix tuning
522
- # cross-attn
523
- # (batch_size, n_heads, seq_length, dim_per_head)
524
- hidden_states = shape(proj_layer(key_value_states))
525
- else:
526
- # cross-attn
527
- hidden_states = past_key_value
528
- return hidden_states
529
-
530
- # get query states
531
- query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head)
532
-
533
- # get key/value states
534
- key_states = project(
535
- hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
536
- )
537
- value_states = project(
538
- hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
539
- )
540
-
541
- present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
542
-
543
- key_states, value_states, mask = self.prefix_tuning(key_states, value_states, hidden_states, mask)
544
- (query_states,) = adjust_tensors_for_parallel(key_states, query_states)
545
- batch_size, key_length = key_states.shape[0], key_states.shape[2]
546
-
547
- # compute scores
548
- scores = torch.matmul(
549
- query_states, key_states.transpose(3, 2)
550
- ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
551
-
552
- if position_bias is None:
553
- if not self.has_relative_attention_bias:
554
- position_bias = torch.zeros(
555
- (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
556
- )
557
- if self.gradient_checkpointing and self.training:
558
- position_bias.requires_grad = True
559
- else:
560
- position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
561
-
562
- # if key and values are already calculated
563
- # we want only the last query position bias
564
- if past_key_value is not None:
565
- position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
566
-
567
- if mask is not None:
568
- position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
569
-
570
- if self.pruned_heads:
571
- mask = torch.ones(position_bias.shape[1])
572
- mask[list(self.pruned_heads)] = 0
573
- position_bias_masked = position_bias[:, mask.bool()]
574
- else:
575
- position_bias_masked = position_bias
576
-
577
- scores += position_bias_masked
578
- attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
579
- scores
580
- ) # (batch_size, n_heads, seq_length, key_length)
581
- attn_weights = nn.functional.dropout(
582
- attn_weights, p=self.dropout, training=self.training
583
- ) # (batch_size, n_heads, seq_length, key_length)
584
-
585
- # Mask heads if we want to
586
- if layer_head_mask is not None:
587
- attn_weights = attn_weights * layer_head_mask
588
-
589
- attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim)
590
- attn_output = self.o(attn_output)
591
-
592
- outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
593
-
594
- if output_attentions:
595
- outputs = outputs + (attn_weights,)
596
- return outputs
597
-
598
-
599
- class T5LayerSelfAttention(T5SelfAttentionLayerAdaptersMixin, nn.Module):
600
- def __init__(self, config, has_relative_attention_bias=False, location_key: Optional[str] = None):
601
- super().__init__()
602
- self.config = config
603
- self.SelfAttention = T5Attention(
604
- config, has_relative_attention_bias=has_relative_attention_bias, location_key=location_key
605
- )
606
- self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
607
- self.dropout = nn.Dropout(config.dropout_rate)
608
- self._init_adapter_modules()
609
-
610
- def forward(
611
- self,
612
- hidden_states,
613
- attention_mask=None,
614
- position_bias=None,
615
- layer_head_mask=None,
616
- past_key_value=None,
617
- use_cache=False,
618
- output_attentions=False,
619
- ):
620
- normed_hidden_states = self.layer_norm(hidden_states)
621
- attention_output = self.SelfAttention(
622
- normed_hidden_states,
623
- mask=attention_mask,
624
- position_bias=position_bias,
625
- layer_head_mask=layer_head_mask,
626
- past_key_value=past_key_value,
627
- use_cache=use_cache,
628
- output_attentions=output_attentions,
629
- )
630
- hidden_states = self.adapter_layer_forward(
631
- hidden_states=self.dropout(attention_output[0]), residual_input=hidden_states, layer_norm=None
632
- )
633
- outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
634
- return outputs
635
-
636
-
637
- class T5LayerCrossAttention(T5CrossAttentionLayerAdaptersMixin, nn.Module):
638
- def __init__(self, config):
639
- super().__init__()
640
- self.config = config
641
- self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False, location_key="cross")
642
- self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
643
- self.dropout = nn.Dropout(config.dropout_rate)
644
- self._init_adapter_modules()
645
-
646
- def forward(
647
- self,
648
- hidden_states,
649
- key_value_states,
650
- attention_mask=None,
651
- position_bias=None,
652
- layer_head_mask=None,
653
- past_key_value=None,
654
- use_cache=False,
655
- query_length=None,
656
- output_attentions=False,
657
- ):
658
- normed_hidden_states = self.layer_norm(hidden_states)
659
- attention_output = self.EncDecAttention(
660
- normed_hidden_states,
661
- mask=attention_mask,
662
- key_value_states=key_value_states,
663
- position_bias=position_bias,
664
- layer_head_mask=layer_head_mask,
665
- past_key_value=past_key_value,
666
- use_cache=use_cache,
667
- query_length=query_length,
668
- output_attentions=output_attentions,
669
- )
670
- layer_output = self.adapter_layer_forward(
671
- hidden_states=self.dropout(attention_output[0]), residual_input=hidden_states, layer_norm=None
672
- )
673
- outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
674
- return outputs
675
-
676
-
677
- class T5Block(nn.Module):
678
- def __init__(self, config, has_relative_attention_bias=False):
679
- super().__init__()
680
- self.is_decoder = config.is_decoder
681
- self.layer = nn.ModuleList()
682
- location_key = "self" if self.is_decoder else "encoder"
683
- self.layer.append(
684
- T5LayerSelfAttention(
685
- config, has_relative_attention_bias=has_relative_attention_bias, location_key=location_key
686
- )
687
- )
688
- if self.is_decoder:
689
- self.layer.append(T5LayerCrossAttention(config))
690
-
691
- self.layer.append(T5LayerFF(config))
692
-
693
- def forward(
694
- self,
695
- hidden_states,
696
- attention_mask=None,
697
- position_bias=None,
698
- encoder_hidden_states=None,
699
- encoder_attention_mask=None,
700
- encoder_decoder_position_bias=None,
701
- layer_head_mask=None,
702
- cross_attn_layer_head_mask=None,
703
- past_key_value=None,
704
- use_cache=False,
705
- output_attentions=False,
706
- return_dict=True,
707
- ):
708
-
709
- if past_key_value is not None:
710
- if not self.is_decoder:
711
- logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
712
- expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
713
-
714
- if len(past_key_value) != expected_num_past_key_values:
715
- raise ValueError(
716
- f"There should be {expected_num_past_key_values} past states. "
717
- f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
718
- f"Got {len(past_key_value)} past key / value states"
719
- )
720
-
721
- self_attn_past_key_value = past_key_value[:2]
722
- cross_attn_past_key_value = past_key_value[2:]
723
- else:
724
- self_attn_past_key_value, cross_attn_past_key_value = None, None
725
-
726
- self_attention_outputs = self.layer[0](
727
- hidden_states,
728
- attention_mask=attention_mask,
729
- position_bias=position_bias,
730
- layer_head_mask=layer_head_mask,
731
- past_key_value=self_attn_past_key_value,
732
- use_cache=use_cache,
733
- output_attentions=output_attentions,
734
- )
735
- hidden_states, present_key_value_state = self_attention_outputs[:2]
736
- attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
737
-
738
- # clamp inf values to enable fp16 training
739
- if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
740
- clamp_value = torch.finfo(hidden_states.dtype).max - 1000
741
- hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
742
-
743
- do_cross_attention = self.is_decoder and encoder_hidden_states is not None
744
- if do_cross_attention:
745
- # the actual query length is unknown for cross attention
746
- # if using past key value states. Need to inject it here
747
- if present_key_value_state is not None:
748
- query_length = present_key_value_state[0].shape[2]
749
- else:
750
- query_length = None
751
-
752
- cross_attention_outputs = self.layer[1](
753
- hidden_states,
754
- key_value_states=encoder_hidden_states,
755
- attention_mask=encoder_attention_mask,
756
- position_bias=encoder_decoder_position_bias,
757
- layer_head_mask=cross_attn_layer_head_mask,
758
- past_key_value=cross_attn_past_key_value,
759
- query_length=query_length,
760
- use_cache=use_cache,
761
- output_attentions=output_attentions,
762
- )
763
- hidden_states = cross_attention_outputs[0]
764
-
765
- # clamp inf values to enable fp16 training
766
- if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
767
- clamp_value = torch.finfo(hidden_states.dtype).max - 1000
768
- hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
769
-
770
- # Combine self attn and cross attn key value states
771
- if present_key_value_state is not None:
772
- present_key_value_state = present_key_value_state + cross_attention_outputs[1]
773
-
774
- # Keep cross-attention outputs and relative position weights
775
- attention_outputs = attention_outputs + cross_attention_outputs[2:]
776
-
777
- # Apply Feed Forward layer
778
- hidden_states = self.layer[-1](hidden_states)
779
-
780
- # clamp inf values to enable fp16 training
781
- if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
782
- clamp_value = torch.finfo(hidden_states.dtype).max - 1000
783
- hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
784
-
785
- outputs = (hidden_states,)
786
-
787
- if use_cache:
788
- outputs = outputs + (present_key_value_state,) + attention_outputs
789
- else:
790
- outputs = outputs + attention_outputs
791
-
792
- return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
793
-
794
-
795
- class T5PreTrainedModel(PreTrainedModel):
796
- """
797
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
798
- models.
799
- """
800
-
801
- config_class = T5Config
802
- load_tf_weights = load_tf_weights_in_t5
803
- base_model_prefix = "transformer"
804
- is_parallelizable = True
805
- supports_gradient_checkpointing = True
806
- _no_split_modules = ["T5Block"]
807
- _keep_in_fp32_modules = ["wo"]
808
-
809
- @property
810
- def dummy_inputs(self):
811
- input_ids = torch.tensor(DUMMY_INPUTS)
812
- input_mask = torch.tensor(DUMMY_MASK)
813
- dummy_inputs = {
814
- "decoder_input_ids": input_ids,
815
- "input_ids": input_ids,
816
- "decoder_attention_mask": input_mask,
817
- }
818
- return dummy_inputs
819
-
820
- def _init_weights(self, module):
821
- """Initialize the weights"""
822
- factor = self.config.initializer_factor # Used for testing weights initialization
823
- if isinstance(module, T5LayerNorm):
824
- module.weight.data.fill_(factor * 1.0)
825
- elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)):
826
- # Mesh TensorFlow embeddings initialization
827
- # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
828
- module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
829
- if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
830
- module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
831
- elif isinstance(module, T5DenseActDense):
832
- # Mesh TensorFlow FF initialization
833
- # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
834
- # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
835
- module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
836
- if hasattr(module.wi, "bias") and module.wi.bias is not None:
837
- module.wi.bias.data.zero_()
838
- module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
839
- if hasattr(module.wo, "bias") and module.wo.bias is not None:
840
- module.wo.bias.data.zero_()
841
- elif isinstance(module, T5DenseGatedActDense):
842
- module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
843
- if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
844
- module.wi_0.bias.data.zero_()
845
- module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
846
- if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
847
- module.wi_1.bias.data.zero_()
848
- module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
849
- if hasattr(module.wo, "bias") and module.wo.bias is not None:
850
- module.wo.bias.data.zero_()
851
- elif isinstance(module, T5Attention):
852
- # Mesh TensorFlow attention initialization to avoid scaling before softmax
853
- # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
854
- d_model = self.config.d_model
855
- key_value_proj_dim = self.config.d_kv
856
- n_heads = self.config.num_heads
857
- module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
858
- module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
859
- module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
860
- module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
861
- if module.has_relative_attention_bias:
862
- module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
863
-
864
- def _set_gradient_checkpointing(self, module, value=False):
865
- if isinstance(module, (T5Attention, T5Stack)):
866
- module.gradient_checkpointing = value
867
-
868
- def _shift_right(self, input_ids):
869
- decoder_start_token_id = self.config.decoder_start_token_id
870
- pad_token_id = self.config.pad_token_id
871
-
872
- assert decoder_start_token_id is not None, (
873
- "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id."
874
- " See T5 docs for more information"
875
- )
876
-
877
- # shift inputs to the right
878
- if is_torch_fx_proxy(input_ids):
879
- # Item assignment is not supported natively for proxies.
880
- shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
881
- shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
882
- else:
883
- shifted_input_ids = input_ids.new_zeros(input_ids.shape)
884
- shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
885
- shifted_input_ids[..., 0] = decoder_start_token_id
886
-
887
- assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
888
- # replace possible -100 values in labels by `pad_token_id`
889
- shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
890
-
891
- return shifted_input_ids
892
-
893
-
894
- class T5Stack(InvertibleAdaptersMixin, T5PreTrainedModel):
895
- def __init__(self, config, embed_tokens=None):
896
- super().__init__(config)
897
-
898
- self.embed_tokens = embed_tokens
899
- self.is_decoder = config.is_decoder
900
-
901
- self.block = nn.ModuleList(
902
- [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
903
- )
904
- self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
905
- self.dropout = nn.Dropout(config.dropout_rate)
906
-
907
- # Initialize weights and apply final processing
908
- self.post_init()
909
- # Model parallel
910
- self.model_parallel = False
911
- self.device_map = None
912
- self.gradient_checkpointing = False
913
-
914
- @add_start_docstrings(PARALLELIZE_DOCSTRING)
915
- def parallelize(self, device_map=None):
916
- # Check validity of device_map
917
- self.device_map = (
918
- get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
919
- )
920
- assert_device_map(self.device_map, len(self.block))
921
- self.model_parallel = True
922
- self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
923
- self.last_device = "cuda:" + str(max(self.device_map.keys()))
924
- # Load onto devices
925
- for k, v in self.device_map.items():
926
- for layer in v:
927
- cuda_device = "cuda:" + str(k)
928
- self.block[layer] = self.block[layer].to(cuda_device)
929
-
930
- # Set embed_tokens to first layer
931
- self.embed_tokens = self.embed_tokens.to(self.first_device)
932
- # Set final layer norm to last device
933
- self.final_layer_norm = self.final_layer_norm.to(self.last_device)
934
-
935
- @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
936
- def deparallelize(self):
937
- self.model_parallel = False
938
- self.device_map = None
939
- self.first_device = "cpu"
940
- self.last_device = "cpu"
941
- for i in range(len(self.block)):
942
- self.block[i] = self.block[i].to("cpu")
943
- self.embed_tokens = self.embed_tokens.to("cpu")
944
- self.final_layer_norm = self.final_layer_norm.to("cpu")
945
- torch.cuda.empty_cache()
946
-
947
- def get_input_embeddings(self):
948
- return self.embed_tokens
949
-
950
- def set_input_embeddings(self, new_embeddings):
951
- self.embed_tokens = new_embeddings
952
-
953
- def forward(
954
- self,
955
- input_ids=None,
956
- attention_mask=None,
957
- encoder_hidden_states=None,
958
- encoder_attention_mask=None,
959
- inputs_embeds=None,
960
- head_mask=None,
961
- cross_attn_head_mask=None,
962
- past_key_values=None,
963
- use_cache=None,
964
- output_attentions=None,
965
- output_hidden_states=None,
966
- return_dict=None,
967
- ):
968
- # Model parallel
969
- if self.model_parallel:
970
- torch.cuda.set_device(self.first_device)
971
- self.embed_tokens = self.embed_tokens.to(self.first_device)
972
- use_cache = use_cache if use_cache is not None else self.config.use_cache
973
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
974
- output_hidden_states = (
975
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
976
- )
977
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
978
- if self.is_decoder and encoder_hidden_states is not None:
979
- input_ids, encoder_attention_mask = adjust_tensors_for_parallel(
980
- encoder_hidden_states, input_ids, encoder_attention_mask
981
- )
982
-
983
- if input_ids is not None and inputs_embeds is not None:
984
- err_msg_prefix = "decoder_" if self.is_decoder else ""
985
- raise ValueError(
986
- f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
987
- )
988
- elif input_ids is not None:
989
- input_shape = input_ids.size()
990
- input_ids = input_ids.view(-1, input_shape[-1])
991
- elif inputs_embeds is not None:
992
- input_shape = inputs_embeds.size()[:-1]
993
- else:
994
- err_msg_prefix = "decoder_" if self.is_decoder else ""
995
- raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
996
-
997
- if inputs_embeds is None:
998
- assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
999
- inputs_embeds = self.embed_tokens(input_ids)
1000
-
1001
- batch_size, seq_length = input_shape
1002
-
1003
- # required mask seq length can be calculated via length of past
1004
- mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
1005
-
1006
- if use_cache is True:
1007
- assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"
1008
-
1009
- if attention_mask is None:
1010
- attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
1011
- if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
1012
- encoder_seq_length = encoder_hidden_states.shape[1]
1013
- encoder_attention_mask = torch.ones(
1014
- batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
1015
- )
1016
-
1017
- # initialize past_key_values with `None` if past does not exist
1018
- if past_key_values is None:
1019
- past_key_values = [None] * len(self.block)
1020
-
1021
- # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
1022
- # ourselves in which case we just need to make it broadcastable to all heads.
1023
- extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
1024
-
1025
- # If a 2D or 3D attention mask is provided for the cross-attention
1026
- # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
1027
- if self.is_decoder and encoder_hidden_states is not None:
1028
- encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
1029
- encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
1030
- if encoder_attention_mask is None:
1031
- encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
1032
- encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
1033
- else:
1034
- encoder_extended_attention_mask = None
1035
-
1036
- # Prepare head mask if needed
1037
- head_mask = self.get_head_mask(head_mask, self.config.num_layers)
1038
- cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
1039
- present_key_value_states = () if use_cache else None
1040
- all_hidden_states = () if output_hidden_states else None
1041
- all_attentions = () if output_attentions else None
1042
- all_cross_attentions = () if (output_attentions and self.is_decoder) else None
1043
- position_bias = None
1044
- encoder_decoder_position_bias = None
1045
-
1046
- hidden_states = self.dropout(inputs_embeds)
1047
- if not self.is_decoder:
1048
- hidden_states = self.invertible_adapters_forward(hidden_states)
1049
-
1050
- for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
1051
- layer_head_mask = head_mask[i]
1052
- cross_attn_layer_head_mask = cross_attn_head_mask[i]
1053
- # Model parallel
1054
- if self.model_parallel:
1055
- torch.cuda.set_device(hidden_states.device)
1056
- # Ensure that attention_mask is always on the same device as hidden_states
1057
- if attention_mask is not None:
1058
- attention_mask = attention_mask.to(hidden_states.device)
1059
- if position_bias is not None:
1060
- position_bias = position_bias.to(hidden_states.device)
1061
- if encoder_hidden_states is not None:
1062
- encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
1063
- if encoder_extended_attention_mask is not None:
1064
- encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
1065
- if encoder_decoder_position_bias is not None:
1066
- encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
1067
- if layer_head_mask is not None:
1068
- layer_head_mask = layer_head_mask.to(hidden_states.device)
1069
- if cross_attn_layer_head_mask is not None:
1070
- cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
1071
- if output_hidden_states:
1072
- all_hidden_states = all_hidden_states + (hidden_states,)
1073
-
1074
- if self.gradient_checkpointing and self.training:
1075
- if use_cache:
1076
- logger.warning(
1077
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1078
- )
1079
- use_cache = False
1080
-
1081
- def create_custom_forward(module):
1082
- def custom_forward(*inputs):
1083
- return tuple(module(*inputs, use_cache, output_attentions))
1084
-
1085
- return custom_forward
1086
-
1087
- layer_outputs = checkpoint(
1088
- create_custom_forward(layer_module),
1089
- hidden_states,
1090
- extended_attention_mask,
1091
- position_bias,
1092
- encoder_hidden_states,
1093
- encoder_extended_attention_mask,
1094
- encoder_decoder_position_bias,
1095
- layer_head_mask,
1096
- cross_attn_layer_head_mask,
1097
- None, # past_key_value is always None with gradient checkpointing
1098
- )
1099
- else:
1100
- layer_outputs = layer_module(
1101
- hidden_states,
1102
- attention_mask=extended_attention_mask,
1103
- position_bias=position_bias,
1104
- encoder_hidden_states=encoder_hidden_states,
1105
- encoder_attention_mask=encoder_extended_attention_mask,
1106
- encoder_decoder_position_bias=encoder_decoder_position_bias,
1107
- layer_head_mask=layer_head_mask,
1108
- cross_attn_layer_head_mask=cross_attn_layer_head_mask,
1109
- past_key_value=past_key_value,
1110
- use_cache=use_cache,
1111
- output_attentions=output_attentions,
1112
- )
1113
-
1114
- # layer_outputs is a tuple with:
1115
- # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
1116
- if use_cache is False:
1117
- layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
1118
-
1119
- hidden_states, present_key_value_state = layer_outputs[:2]
1120
-
1121
- attention_mask, extended_attention_mask = adjust_tensors_for_parallel(
1122
- hidden_states, attention_mask, extended_attention_mask
1123
- )
1124
-
1125
- # We share the position biases between the layers - the first layer store them
1126
- # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
1127
- # (cross-attention position bias), (cross-attention weights)
1128
- position_bias = layer_outputs[2]
1129
- if self.is_decoder and encoder_hidden_states is not None:
1130
- encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
1131
- # append next layer key value states
1132
- if use_cache:
1133
- present_key_value_states = present_key_value_states + (present_key_value_state,)
1134
-
1135
- if position_bias is not None:
1136
- position_bias = adjust_tensors_for_parallel(hidden_states, position_bias)[0]
1137
- if encoder_decoder_position_bias is not None:
1138
- encoder_decoder_position_bias = adjust_tensors_for_parallel(
1139
- hidden_states, encoder_decoder_position_bias
1140
- )[0]
1141
-
1142
- if output_attentions:
1143
- all_attentions = all_attentions + (layer_outputs[3],)
1144
- if self.is_decoder:
1145
- all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
1146
-
1147
- # Model Parallel: If it's the last layer for that device, put things on the next device
1148
- if self.model_parallel:
1149
- for k, v in self.device_map.items():
1150
- if i == v[-1] and "cuda:" + str(k) != self.last_device:
1151
- hidden_states = hidden_states.to("cuda:" + str(k + 1))
1152
-
1153
- hidden_states = self.final_layer_norm(hidden_states)
1154
- hidden_states = self.dropout(hidden_states)
1155
-
1156
- # Add last layer
1157
- if output_hidden_states:
1158
- all_hidden_states = all_hidden_states + (hidden_states,)
1159
-
1160
- if not return_dict:
1161
- return tuple(
1162
- v
1163
- for v in [
1164
- hidden_states,
1165
- present_key_value_states,
1166
- all_hidden_states,
1167
- all_attentions,
1168
- all_cross_attentions,
1169
- ]
1170
- if v is not None
1171
- )
1172
- return BaseModelOutputWithPastAndCrossAttentions(
1173
- last_hidden_state=hidden_states,
1174
- past_key_values=present_key_value_states,
1175
- hidden_states=all_hidden_states,
1176
- attentions=all_attentions,
1177
- cross_attentions=all_cross_attentions,
1178
- )
1179
-
1180
-
1181
- T5_START_DOCSTRING = r"""
1182
-
1183
- The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text
1184
- Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan
1185
- Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a
1186
- text-to-text denoising generative setting.
1187
-
1188
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1189
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1190
- etc.)
1191
-
1192
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1193
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1194
- and behavior.
1195
-
1196
- Parameters:
1197
- config ([`T5Config`]): Model configuration class with all the parameters of the model.
1198
- Initializing with a config file does not load the weights associated with the model, only the
1199
- configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1200
- """
1201
-
1202
- T5_INPUTS_DOCSTRING = r"""
1203
- Args:
1204
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1205
- Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
1206
- should be able to pad the inputs on both the right and the left.
1207
-
1208
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1209
- [`PreTrainedTokenizer.__call__`] for detail.
1210
-
1211
- [What are input IDs?](../glossary#input-ids)
1212
-
1213
- To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training).
1214
- attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1215
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1216
-
1217
- - 1 for tokens that are **not masked**,
1218
- - 0 for tokens that are **masked**.
1219
-
1220
- [What are attention masks?](../glossary#attention-mask)
1221
- decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
1222
- Indices of decoder input sequence tokens in the vocabulary.
1223
-
1224
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1225
- [`PreTrainedTokenizer.__call__`] for details.
1226
-
1227
- [What are decoder input IDs?](../glossary#decoder-input-ids)
1228
-
1229
- T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
1230
- is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
-
- To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
- Training](./t5#training).
- decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
- Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
- be used by default.
- head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
- Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0,
- 1]`:
-
- - 1 indicates the head is **not masked**,
- - 0 indicates the head is **masked**.
-
- decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
- Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
- 1]`:
-
- - 1 indicates the head is **not masked**,
- - 0 indicates the head is **masked**.
-
- cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
- Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
- `[0, 1]`:
-
- - 1 indicates the head is **not masked**,
- - 0 indicates the head is **masked**.
-
- encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
- Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*)
- `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
- the output of the last layer of the encoder. Used in the cross-attention of the decoder.
- past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
- Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
-
- If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
- don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
- `decoder_input_ids` of shape `(batch_size, sequence_length)`.
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
- model's internal embedding lookup matrix.
- decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
- Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
- representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
- input (see `past_key_values`). This is useful if you want more control over how to convert
- `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
-
- If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
- of `inputs_embeds`.
-
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
- `past_key_values`).
-
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
- tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
- more detail.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- """
-
- T5_ENCODER_INPUTS_DOCSTRING = r"""
- Args:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
- should be able to pad the inputs on both the right and the left.
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for detail.
-
- To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training).
- attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
-
- [What are attention masks?](../glossary#attention-mask)
- head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
- Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
-
- - 1 indicates the head is **not masked**,
- - 0 indicates the head is **masked**.
-
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
- model's internal embedding lookup matrix.
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
- tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
- more detail.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- """
-
- # Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
- __HEAD_MASK_WARNING_MSG = """
- The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
- `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
- If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
- num_heads)`.
- """
-
-
- @add_start_docstrings(
- "The bare T5 Model transformer outputting raw hidden-states without any specific head on top.",
- T5_START_DOCSTRING,
- )
- class T5Model(T5ModelAdaptersMixin, T5PreTrainedModel):
- _keys_to_ignore_on_load_missing = [
- r"encoder.embed_tokens.weight",
- r"decoder.embed_tokens.weight",
- ]
- _keys_to_ignore_on_load_unexpected = [
- r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
- ]
-
- def __init__(self, config: T5Config):
- super().__init__(config)
- self.shared = nn.Embedding(config.vocab_size, config.d_model)
-
- encoder_config = copy.deepcopy(config)
- encoder_config.is_decoder = False
- encoder_config.use_cache = False
- encoder_config.is_encoder_decoder = False
- encoder_config.adapters = config.adapters
- self.encoder = T5Stack(encoder_config, self.shared)
-
- decoder_config = copy.deepcopy(config)
- decoder_config.is_decoder = True
- decoder_config.is_encoder_decoder = False
- decoder_config.num_layers = config.num_decoder_layers
- decoder_config.adapters = config.adapters
- self.decoder = T5Stack(decoder_config, self.shared)
-
- self._init_adapter_modules()
-
- # Initialize weights and apply final processing
- self.post_init()
-
- # Model parallel
- self.model_parallel = False
- self.device_map = None
-
- @add_start_docstrings(PARALLELIZE_DOCSTRING)
- def parallelize(self, device_map=None):
- self.device_map = (
- get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
- if device_map is None
- else device_map
- )
- assert_device_map(self.device_map, len(self.encoder.block))
- self.encoder.parallelize(self.device_map)
- self.decoder.parallelize(self.device_map)
- self.model_parallel = True
-
- @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
- def deparallelize(self):
- self.encoder.deparallelize()
- self.decoder.deparallelize()
- self.encoder = self.encoder.to("cpu")
- self.decoder = self.decoder.to("cpu")
- self.model_parallel = False
- self.device_map = None
- torch.cuda.empty_cache()
-
- def get_input_embeddings(self):
- return self.shared
-
- def set_input_embeddings(self, new_embeddings):
- self.shared = new_embeddings
- self.encoder.set_input_embeddings(new_embeddings)
- self.decoder.set_input_embeddings(new_embeddings)
-
- def get_encoder(self):
- return self.encoder
-
- def get_decoder(self):
- return self.decoder
-
- def _prune_heads(self, heads_to_prune):
- """
- Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
- class PreTrainedModel
- """
- for layer, heads in heads_to_prune.items():
- self.encoder.layer[layer].attention.prune_heads(heads)
-
- @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
- @ForwardContext.wrap
- def forward(
- self,
- input_ids: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- decoder_input_ids: Optional[torch.LongTensor] = None,
- decoder_attention_mask: Optional[torch.BoolTensor] = None,
- head_mask: Optional[torch.FloatTensor] = None,
- decoder_head_mask: Optional[torch.FloatTensor] = None,
- cross_attn_head_mask: Optional[torch.Tensor] = None,
- encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
- inputs_embeds: Optional[torch.Tensor] = None,
- decoder_inputs_embeds: Optional[torch.Tensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
- r"""
- Returns:
-
- Example:
-
- ```python
- >>> from transformers import AutoTokenizer, T5Model
-
- >>> tokenizer = AutoTokenizer.from_pretrained("t5-small")
- >>> model = T5Model.from_pretrained("t5-small")
-
- >>> input_ids = tokenizer(
- ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
- ... ).input_ids # Batch size 1
- >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
-
- >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model.
- >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg.
- >>> decoder_input_ids = model._shift_right(decoder_input_ids)
-
- >>> # forward pass
- >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
- >>> last_hidden_states = outputs.last_hidden_state
- ```"""
- use_cache = use_cache if use_cache is not None else self.config.use_cache
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
- if head_mask is not None and decoder_head_mask is None:
- if self.config.num_layers == self.config.num_decoder_layers:
- warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
- decoder_head_mask = head_mask
-
- # Encode if needed (training, first prediction pass)
- if encoder_outputs is None:
- encoder_outputs = self.encoder(
- input_ids=input_ids,
- attention_mask=attention_mask,
- inputs_embeds=inputs_embeds,
- head_mask=head_mask,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
- encoder_outputs = BaseModelOutput(
- last_hidden_state=encoder_outputs[0],
- hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
- attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
- )
-
- hidden_states = encoder_outputs[0]
-
- # Set device for model parallelism
- if self.model_parallel:
- torch.cuda.set_device(self.decoder.first_device)
- hidden_states = hidden_states.to(self.decoder.first_device)
- if decoder_input_ids is not None:
- decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
- if attention_mask is not None:
- attention_mask = attention_mask.to(self.decoder.first_device)
- if decoder_attention_mask is not None:
- decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
-
- # Decode
- decoder_outputs = self.decoder(
- input_ids=decoder_input_ids,
- attention_mask=decoder_attention_mask,
- inputs_embeds=decoder_inputs_embeds,
- past_key_values=past_key_values,
- encoder_hidden_states=hidden_states,
- encoder_attention_mask=attention_mask,
- head_mask=decoder_head_mask,
- cross_attn_head_mask=cross_attn_head_mask,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
-
- if not return_dict:
- return decoder_outputs + encoder_outputs

- return Seq2SeqModelOutput(
- last_hidden_state=decoder_outputs.last_hidden_state,
- past_key_values=decoder_outputs.past_key_values,
- decoder_hidden_states=decoder_outputs.hidden_states,
- decoder_attentions=decoder_outputs.attentions,
- cross_attentions=decoder_outputs.cross_attentions,
- encoder_last_hidden_state=encoder_outputs.last_hidden_state,
- encoder_hidden_states=encoder_outputs.hidden_states,
- encoder_attentions=encoder_outputs.attentions,
- )

- @add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING)
- class T5ForConditionalGeneration(T5ModelWithHeadsAdaptersMixin, T5ModelAdaptersMixin, T5PreTrainedModel):
- _keys_to_ignore_on_load_missing = [
- r"encoder.embed_tokens.weight",
- r"decoder.embed_tokens.weight",
- r"lm_head.weight",
- ]
- _keys_to_ignore_on_load_unexpected = [
- r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
- ]

 def __init__(self, config: T5Config):
 super().__init__(config)
- self.model_dim = config.d_model
-
- self.shared = nn.Embedding(config.vocab_size, config.d_model)
-
- encoder_config = copy.deepcopy(config)
- encoder_config.is_decoder = False
- encoder_config.use_cache = False
- encoder_config.is_encoder_decoder = False
- encoder_config.adapters = config.adapters
- self.encoder = T5Stack(encoder_config, self.shared)
-
- decoder_config = copy.deepcopy(config)
- decoder_config.is_decoder = True
- decoder_config.is_encoder_decoder = False
- decoder_config.num_layers = config.num_decoder_layers
- decoder_config.adapters = config.adapters
- self.decoder = T5Stack(decoder_config, self.shared)
-
- self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
-
- self._init_adapter_modules()
-
- # Initialize weights and apply final processing
- self.post_init()
-
- # Model parallel
- self.model_parallel = False
- self.device_map = None
-
- @add_start_docstrings(PARALLELIZE_DOCSTRING)
- def parallelize(self, device_map=None):
- self.device_map = (
- get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
- if device_map is None
- else device_map
- )
- assert_device_map(self.device_map, len(self.encoder.block))
- self.encoder.parallelize(self.device_map)
- self.decoder.parallelize(self.device_map)
- self.lm_head = self.lm_head.to(self.decoder.first_device)
- self.model_parallel = True
-
- @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
- def deparallelize(self):
- self.encoder.deparallelize()
- self.decoder.deparallelize()
- self.encoder = self.encoder.to("cpu")
- self.decoder = self.decoder.to("cpu")
- self.lm_head = self.lm_head.to("cpu")
- self.model_parallel = False
- self.device_map = None
- torch.cuda.empty_cache()
-
- def get_input_embeddings(self):
- return self.shared
-
- def set_input_embeddings(self, new_embeddings):
- self.shared = new_embeddings
- self.encoder.set_input_embeddings(new_embeddings)
- self.decoder.set_input_embeddings(new_embeddings)
-
- def set_output_embeddings(self, new_embeddings):
- self.lm_head = new_embeddings
-
- def get_output_embeddings(self):
- return self.lm_head
-
- def get_encoder(self):
- return self.encoder
-
- def get_decoder(self):
- return self.decoder
-
- @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
- @ForwardContext.wrap
- def forward(
- self,
- input_ids: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- decoder_input_ids: Optional[torch.LongTensor] = None,
- decoder_attention_mask: Optional[torch.BoolTensor] = None,
- head_mask: Optional[torch.FloatTensor] = None,
- decoder_head_mask: Optional[torch.FloatTensor] = None,
- cross_attn_head_mask: Optional[torch.Tensor] = None,
- encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
- past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
- r"""
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
- Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
- config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
- labels in `[0, ..., config.vocab_size]`
-
- Returns:
-
- Examples:
-
- ```python
- >>> from transformers import AutoTokenizer, T5ForConditionalGeneration
-
- >>> tokenizer = AutoTokenizer.from_pretrained("t5-small")
- >>> model = T5ForConditionalGeneration.from_pretrained("t5-small")
-
- >>> # training
- >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
- >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
- >>> outputs = model(input_ids=input_ids, labels=labels)
- >>> loss = outputs.loss
- >>> logits = outputs.logits
-
- >>> # inference
- >>> input_ids = tokenizer(
- ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
- ... ).input_ids # Batch size 1
- >>> outputs = model.generate(input_ids)
- >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
- >>> # studies have shown that owning a dog is good for you.
- ```"""
- use_cache = use_cache if use_cache is not None else self.config.use_cache
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
- if head_mask is not None and decoder_head_mask is None:
- if self.config.num_layers == self.config.num_decoder_layers:
- warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
- decoder_head_mask = head_mask
-
- # Encode if needed (training, first prediction pass)
- if encoder_outputs is None:
- # Convert encoder inputs in embeddings if needed
- encoder_outputs = self.encoder(
- input_ids=input_ids,
- attention_mask=attention_mask,
- inputs_embeds=inputs_embeds,
- head_mask=head_mask,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
- encoder_outputs = BaseModelOutput(
- last_hidden_state=encoder_outputs[0],
- hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
- attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
- )
-
- hidden_states = encoder_outputs[0]
-
- if self.model_parallel:
- torch.cuda.set_device(self.decoder.first_device)
-
- if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
- # get decoder inputs from shifting lm labels to the right
- decoder_input_ids = self._shift_right(labels)
-
- # Set device for model parallelism
- if self.model_parallel:
- torch.cuda.set_device(self.decoder.first_device)
- hidden_states = hidden_states.to(self.decoder.first_device)
- if decoder_input_ids is not None:
- decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
- if attention_mask is not None:
- attention_mask = attention_mask.to(self.decoder.first_device)
- if decoder_attention_mask is not None:
- decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
-
- # Decode
- decoder_outputs = self.decoder(
- input_ids=decoder_input_ids,
- attention_mask=decoder_attention_mask,
- inputs_embeds=decoder_inputs_embeds,
- past_key_values=past_key_values,
- encoder_hidden_states=hidden_states,
- encoder_attention_mask=attention_mask,
- head_mask=decoder_head_mask,
- cross_attn_head_mask=cross_attn_head_mask,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
-
- sequence_output = decoder_outputs[0]
-
- # Set device for model parallelism
- if self.model_parallel:
- torch.cuda.set_device(self.encoder.first_device)
- self.lm_head = self.lm_head.to(self.encoder.first_device)
- sequence_output = sequence_output.to(self.lm_head.weight.device)
-
- if self.config.tie_word_embeddings:
- # Rescale output before projecting on vocab
- # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
- sequence_output = sequence_output * (self.model_dim**-0.5)
-
- projected_output = self.encoder.invertible_adapters_forward(sequence_output, rev=True)
-
- self.invertible_adapters_forward(projected_output, rev=True)
-
- lm_logits = self.lm_head(projected_output)
-
- loss = None
- if labels is not None:
- loss_fct = CrossEntropyLoss(ignore_index=-100)
- loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
- # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
-
- if not return_dict:
- output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
- return ((loss,) + output) if loss is not None else output
-
- return Seq2SeqLMOutput(
- loss=loss,
- logits=lm_logits,
- past_key_values=decoder_outputs.past_key_values,
- decoder_hidden_states=decoder_outputs.hidden_states,
- decoder_attentions=decoder_outputs.attentions,
- cross_attentions=decoder_outputs.cross_attentions,
- encoder_last_hidden_state=encoder_outputs.last_hidden_state,
- encoder_hidden_states=encoder_outputs.hidden_states,
- encoder_attentions=encoder_outputs.attentions,
- )
-
- def prepare_inputs_for_generation(
- self,
- input_ids,
- past_key_values=None,
- attention_mask=None,
- head_mask=None,
- decoder_head_mask=None,
- cross_attn_head_mask=None,
- use_cache=None,
- encoder_outputs=None,
- **kwargs
- ):
-
- # cut decoder_input_ids if past is used
- if past_key_values is not None:
- input_ids = input_ids[:, -1:]
-
- return {
- "decoder_input_ids": input_ids,
- "past_key_values": past_key_values,
- "encoder_outputs": encoder_outputs,
- "attention_mask": attention_mask,
- "head_mask": head_mask,
- "decoder_head_mask": decoder_head_mask,
- "cross_attn_head_mask": cross_attn_head_mask,
- "use_cache": use_cache,
- }
-
- def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
- return self._shift_right(labels)
-
- def _reorder_cache(self, past, beam_idx):
- # if decoder past is not included in output
- # speedy decoding is disabled and no need to reorder
- if past is None:
- logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
- return past
-
- reordered_decoder_past = ()
- for layer_past_states in past:
- # get the correct batch idx from layer past batch dim
- # batch dim of `past` is at 2nd position
- reordered_layer_past_states = ()
- for layer_past_state in layer_past_states:
- # need to set correct `past` for each of the four key / value states
- reordered_layer_past_states = reordered_layer_past_states + (
- layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
- )
-
- assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
- assert len(reordered_layer_past_states) == len(layer_past_states)
-
- reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
- return reordered_decoder_past
-
 def preprocess(self,text):
 text = text.replace("\n", "\\n").replace("\t", "\\t")
 return text
@@ -1876,112 +52,3 @@ class T5ForConditionalGeneration(T5ModelWithHeadsAdaptersMixin, T5ModelAdaptersM
 history.append((query, response))
 return response,history
-
- @add_start_docstrings(
- "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
- T5_START_DOCSTRING,
- )
- class T5EncoderModel(T5ModelAdaptersMixin, T5PreTrainedModel):
- authorized_missing_keys = [
- r"encoder.embed_tokens.weight",
- ]
- _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]
-
- def __init__(self, config: T5Config):
- super().__init__(config)
- self.shared = nn.Embedding(config.vocab_size, config.d_model)
-
- encoder_config = copy.deepcopy(config)
- encoder_config.use_cache = False
- encoder_config.is_encoder_decoder = False
- encoder_config.adapters = config.adapters
- self.encoder = T5Stack(encoder_config, self.shared)
-
- # Initialize weights and apply final processing
- self.post_init()
-
- # Model parallel
- self.model_parallel = False
- self.device_map = None
-
- self._init_adapter_modules()
-
- @add_start_docstrings(PARALLELIZE_DOCSTRING)
- def parallelize(self, device_map=None):
- self.device_map = (
- get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
- if device_map is None
- else device_map
- )
- assert_device_map(self.device_map, len(self.encoder.block))
- self.encoder.parallelize(self.device_map)
- self.model_parallel = True
-
- @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
- def deparallelize(self):
- self.encoder.deparallelize()
- self.encoder = self.encoder.to("cpu")
- self.model_parallel = False
- self.device_map = None
- torch.cuda.empty_cache()
-
- def get_input_embeddings(self):
- return self.shared
-
- def set_input_embeddings(self, new_embeddings):
- self.shared = new_embeddings
- self.encoder.set_input_embeddings(new_embeddings)
-
- def get_encoder(self):
- return self.encoder
-
- def _prune_heads(self, heads_to_prune):
- """
- Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
- class PreTrainedModel
- """
- for layer, heads in heads_to_prune.items():
- self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
-
- @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
- @ForwardContext.wrap
- def forward(
- self,
- input_ids: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- head_mask: Optional[torch.FloatTensor] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
- r"""
- Returns:
-
- Example:
-
- ```python
- >>> from transformers import AutoTokenizer, T5EncoderModel
-
- >>> tokenizer = AutoTokenizer.from_pretrained("t5-small")
- >>> model = T5EncoderModel.from_pretrained("t5-small")
- >>> input_ids = tokenizer(
- ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
- ... ).input_ids # Batch size 1
- >>> outputs = model(input_ids=input_ids)
- >>> last_hidden_states = outputs.last_hidden_state
- ```"""
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- encoder_outputs = self.encoder(
- input_ids=input_ids,
- attention_mask=attention_mask,
- inputs_embeds=inputs_embeds,
- head_mask=head_mask,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
-
- return encoder_outputs
 
+ from transformers import T5ForConditionalGeneration as t5FCG
 from transformers.models.t5.configuration_t5 import T5Config
+ from typing import Optional, Tuple, Union, List, Callable

+ class T5ForConditionalGeneration(t5FCG):
+
  def __init__(self, config: T5Config):
  super().__init__(config)
+
+
  def preprocess(self,text):
  text = text.replace("\n", "\\n").replace("\t", "\\t")
  return text
 
  history.append((query, response))
  return response,history
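
For reference, a minimal usage sketch of the slimmed-down module committed here. It is not part of the diff: the repository id, prompt, and generation settings are placeholders, and it assumes the repository is wired to load this `T5ForConditionalGeneration(t5FCG)` subclass (for example via `auto_map` plus `trust_remote_code=True`); the chat/history helpers whose tail is visible above live in the elided middle of the new file and are not called directly.

```python
# Hypothetical usage sketch -- repo id, prompt, and generation settings are
# illustrative placeholders, not taken from this commit.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

repo_id = "<this-model-repo>"  # placeholder for the Hub repo that ships this modeling_t5.py

tokenizer = AutoTokenizer.from_pretrained(repo_id)
# trust_remote_code=True lets transformers import the subclass defined in this file
# instead of the built-in T5ForConditionalGeneration (assuming the config maps to it).
model = AutoModelForSeq2SeqLM.from_pretrained(repo_id, trust_remote_code=True)

query = "first line\nsecond line"
# preprocess() (kept as context in the diff above) escapes newlines and tabs so that
# multi-line prompts reach the tokenizer as single-line strings.
input_ids = tokenizer(model.preprocess(query), return_tensors="pt").input_ids
output_ids = model.generate(input_ids, max_new_tokens=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

Since the subclass only layers text pre/post-processing and chat-history helpers on top of the upstream class, standard training, generation, and serialization code paths are unchanged.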