gonglinyuan commited on
Commit
9979d35
·
1 Parent(s): 677d0db

Upload FairseqT5ForConditionalGeneration

Browse files
config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "FairseqT5ForConditionalGeneration"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_fairseq_t5.FairseqT5Config",
7
+ "AutoModelForSeq2SeqLM": "modeling_fairseq_t5.FairseqT5ForConditionalGeneration"
8
+ },
9
+ "d_ff": 3072,
10
+ "d_kv": 64,
11
+ "d_model": 768,
12
+ "decoder_start_token_id": 2,
13
+ "dropout_rate": 0.1,
14
+ "eos_token_id": 2,
15
+ "feed_forward_proj": "relu",
16
+ "initializer_factor": 1.0,
17
+ "is_encoder_decoder": true,
18
+ "layer_norm_epsilon": 1e-05,
19
+ "max_positions": 1024,
20
+ "model_type": "fairseq_t5",
21
+ "num_decoder_layers": 12,
22
+ "num_heads": 12,
23
+ "num_layers": 12,
24
+ "pad_token_id": 1,
25
+ "relative_attention_max_distance": 128,
26
+ "relative_attention_num_buckets": 32,
27
+ "torch_dtype": "float32",
28
+ "transformers_version": "4.28.1",
29
+ "use_cache": true,
30
+ "vocab_size": 64512
31
+ }
configuration_fairseq_t5.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.configuration_utils import PretrainedConfig
2
+
3
+
4
+ class FairseqT5Config(PretrainedConfig):
5
+ model_type = "fairseq_t5"
6
+ keys_to_ignore_at_inference = ["past_key_values"]
7
+ attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
8
+
9
+ def __init__(
10
+ self,
11
+ vocab_size=64518,
12
+ d_model=768,
13
+ d_kv=64,
14
+ d_ff=3072,
15
+ num_layers=6,
16
+ num_decoder_layers=None,
17
+ num_heads=8,
18
+ relative_attention_num_buckets=32,
19
+ relative_attention_max_distance=128,
20
+ max_positions=1024,
21
+ dropout_rate=0.1,
22
+ layer_norm_epsilon=1e-6,
23
+ initializer_factor=1.0,
24
+ feed_forward_proj="relu",
25
+ is_encoder_decoder=True,
26
+ use_cache=True,
27
+ pad_token_id=1,
28
+ eos_token_id=2,
29
+ **kwargs
30
+ ):
31
+ self.vocab_size = vocab_size
32
+ self.d_model = d_model
33
+ self.d_kv = d_kv
34
+ self.d_ff = d_ff
35
+ self.num_layers = num_layers
36
+ self.num_decoder_layers = (
37
+ num_decoder_layers if num_decoder_layers is not None else self.num_layers
38
+ ) # default = symmetry
39
+ self.num_heads = num_heads
40
+ self.relative_attention_num_buckets = relative_attention_num_buckets
41
+ self.relative_attention_max_distance = relative_attention_max_distance
42
+ self.max_positions = max_positions
43
+ self.dropout_rate = dropout_rate
44
+ self.layer_norm_epsilon = layer_norm_epsilon
45
+ self.initializer_factor = initializer_factor
46
+ self.feed_forward_proj = feed_forward_proj
47
+ self.use_cache = use_cache
48
+ super().__init__(
49
+ pad_token_id=pad_token_id,
50
+ eos_token_id=eos_token_id,
51
+ is_encoder_decoder=is_encoder_decoder,
52
+ **kwargs,
53
+ )
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "decoder_start_token_id": 2,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 1,
6
+ "transformers_version": "4.28.1"
7
+ }
modeling_fairseq_t5.py ADDED
@@ -0,0 +1,1585 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import math
3
+ from typing import Dict, Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch import Tensor
8
+ from torch.utils.checkpoint import checkpoint
9
+ from transformers.activations import ACT2FN
10
+ from transformers.file_utils import DUMMY_INPUTS, DUMMY_MASK, is_torch_fx_proxy
11
+ from transformers.modeling_outputs import (
12
+ BaseModelOutput,
13
+ BaseModelOutputWithPastAndCrossAttentions,
14
+ Seq2SeqLMOutput,
15
+ Seq2SeqModelOutput,
16
+ )
17
+ from transformers.modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
18
+ from transformers.utils import logging
19
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
20
+
21
+ from .configuration_fairseq_t5 import FairseqT5Config
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
+ def make_positions(tensor, padding_idx: int, onnx_trace: bool = False):
27
+ """Replace non-padding symbols with their position numbers.
28
+ Position numbers begin at padding_idx+1. Padding symbols are ignored.
29
+ """
30
+ # The series of casts and type-conversions here are carefully
31
+ # balanced to both work with ONNX export and XLA. In particular XLA
32
+ # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
33
+ # how to handle the dtype kwarg in cumsum.
34
+ mask = tensor.ne(padding_idx).int()
35
+ return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx
36
+
37
+
38
+ class LearnedPositionalEmbedding(nn.Embedding):
39
+ """
40
+ This module learns positional embeddings up to a fixed maximum size.
41
+ Padding ids are ignored by either offsetting based on padding_idx
42
+ or by setting padding_idx to None and ensuring that the appropriate
43
+ position ids are passed to the forward function.
44
+ """
45
+
46
+ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
47
+ super().__init__(num_embeddings, embedding_dim, padding_idx)
48
+ self.onnx_trace = False
49
+ if self.padding_idx is not None:
50
+ self.max_positions = self.num_embeddings - self.padding_idx - 1
51
+ else:
52
+ self.max_positions = self.num_embeddings
53
+
54
+ def forward(
55
+ self,
56
+ input: Tensor,
57
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
58
+ positions: Optional[Tensor] = None,
59
+ offset=0,
60
+ ):
61
+ """Input is expected to be of size [bsz x seqlen]."""
62
+ assert (positions is None) or (
63
+ self.padding_idx is None
64
+ ), "If positions is pre-computed then padding_idx should not be set."
65
+
66
+ if positions is None:
67
+ if incremental_state is not None:
68
+ # positions is the same for every token when decoding a single step
69
+ # Without the int() cast, it doesn't work in some cases when exporting to ONNX
70
+ positions = torch.zeros(
71
+ (1, 1), device=input.device, dtype=input.dtype
72
+ ).fill_(int(self.padding_idx + input.size(1)))
73
+ else:
74
+ positions = make_positions(
75
+ input, self.padding_idx, onnx_trace=self.onnx_trace
76
+ )
77
+ if offset > 0 and positions.size(1) == 1:
78
+ positions = positions + offset
79
+ return nn.functional.embedding(
80
+ positions,
81
+ self.weight,
82
+ self.padding_idx,
83
+ self.max_norm,
84
+ self.norm_type,
85
+ self.scale_grad_by_freq,
86
+ self.sparse,
87
+ )
88
+
89
+
90
+ def PositionalEmbedding(
91
+ num_embeddings: int,
92
+ embedding_dim: int,
93
+ padding_idx: int,
94
+ ):
95
+ # if padding_idx is specified then offset the embedding ids by
96
+ # this index and adjust num_embeddings appropriately
97
+ # TODO: The right place for this offset would be inside
98
+ # LearnedPositionalEmbedding. Move this there for a cleaner implementation.
99
+ if padding_idx is not None:
100
+ num_embeddings = num_embeddings + padding_idx + 1
101
+ m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)
102
+ nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
103
+ if padding_idx is not None:
104
+ nn.init.constant_(m.weight[padding_idx], 0)
105
+ return m
106
+
107
+
108
+ class T5LayerNorm(nn.Module):
109
+ def __init__(self, hidden_size, eps=1e-5):
110
+ """
111
+ Construct a layernorm module in the T5 style No bias and no subtraction of mean.
112
+ """
113
+ super().__init__()
114
+ self.weight = nn.Parameter(torch.ones(hidden_size))
115
+ self.bias = nn.Parameter(torch.ones(hidden_size))
116
+ self.variance_epsilon = eps
117
+
118
+ def forward(self, hidden_states):
119
+ # layer norm should always be calculated in float32
120
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
121
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
122
+
123
+ # convert into half-precision if necessary
124
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
125
+ hidden_states = hidden_states.to(self.weight.dtype)
126
+
127
+ return self.weight * hidden_states + self.bias
128
+
129
+
130
+ def FST5LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
131
+ return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
132
+
133
+
134
+ class T5DenseReluDense(nn.Module):
135
+ def __init__(self, config):
136
+ super().__init__()
137
+ if_bias = True
138
+ self.wi = nn.Linear(config.d_model, config.d_ff, bias=if_bias) #
139
+ self.wo = nn.Linear(config.d_ff, config.d_model, bias=if_bias) #
140
+ self.dropout = nn.Dropout(config.dropout_rate)
141
+
142
+ def forward(self, hidden_states):
143
+ hidden_states = self.wi(hidden_states)
144
+ hidden_states = nn.functional.relu(hidden_states)
145
+ hidden_states = self.dropout(hidden_states)
146
+ hidden_states = self.wo(hidden_states)
147
+ return hidden_states
148
+
149
+
150
+ class T5DenseGatedGeluDense(nn.Module):
151
+ def __init__(self, config):
152
+ super().__init__()
153
+ self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
154
+ self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
155
+ self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
156
+ self.dropout = nn.Dropout(config.dropout_rate)
157
+ self.gelu_act = ACT2FN["gelu_new"]
158
+
159
+ def forward(self, hidden_states):
160
+ hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
161
+ hidden_linear = self.wi_1(hidden_states)
162
+ hidden_states = hidden_gelu * hidden_linear
163
+ hidden_states = self.dropout(hidden_states)
164
+ hidden_states = self.wo(hidden_states)
165
+ return hidden_states
166
+
167
+
168
+ class T5LayerFF(nn.Module):
169
+ def __init__(self, config, normalize_before=False):
170
+ super().__init__()
171
+ if config.feed_forward_proj == "relu":
172
+ self.DenseReluDense = T5DenseReluDense(config)
173
+ elif config.feed_forward_proj == "gated-gelu":
174
+ self.DenseReluDense = T5DenseGatedGeluDense(config)
175
+ else:
176
+ raise ValueError(
177
+ f"{self.config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`"
178
+ )
179
+
180
+ self.layer_norm = FST5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
181
+ self.dropout = nn.Dropout(config.dropout_rate)
182
+
183
+ self.normalize_before = normalize_before
184
+
185
+ def forward(self, hidden_states):
186
+ if self.normalize_before:
187
+ forwarded_states = self.layer_norm(hidden_states)
188
+ else:
189
+ forwarded_states = hidden_states
190
+ forwarded_states = self.DenseReluDense(forwarded_states)
191
+ hidden_states = hidden_states + self.dropout(forwarded_states)
192
+
193
+ if not self.normalize_before:
194
+ hidden_states = self.layer_norm(hidden_states)
195
+ return hidden_states
196
+
197
+
198
+ class T5Attention(nn.Module):
199
+ def __init__(self, config: FairseqT5Config, has_relative_attention_bias=False):
200
+ super().__init__()
201
+ self.is_decoder = config.is_decoder
202
+ self.has_relative_attention_bias = has_relative_attention_bias
203
+
204
+ self.relative_attention_num_buckets = config.relative_attention_num_buckets
205
+ self.relative_attention_max_distance = config.relative_attention_max_distance
206
+ self.max_positions = config.max_positions
207
+ self.d_model = config.d_model
208
+ self.key_value_proj_dim = config.d_kv
209
+ self.n_heads = config.num_heads
210
+ self.dropout = config.dropout_rate
211
+ self.inner_dim = self.n_heads * self.key_value_proj_dim
212
+
213
+ # Mesh TensorFlow initialization to avoid scaling before softmax
214
+ if_bias = True
215
+ self.q = nn.Linear(self.d_model, self.inner_dim, bias=if_bias)
216
+ self.k = nn.Linear(self.d_model, self.inner_dim, bias=if_bias)
217
+ self.v = nn.Linear(self.d_model, self.inner_dim, bias=if_bias)
218
+ self.o = nn.Linear(self.inner_dim, self.d_model, bias=if_bias)
219
+
220
+ if self.has_relative_attention_bias:
221
+ self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
222
+ self.pruned_heads = set()
223
+ self.gradient_checkpointing = getattr(config, "gradient_checkpointing", False)
224
+
225
+ # rp from fs
226
+ relative_position = (
227
+ torch.arange(self.max_positions, dtype=torch.long)[None, :]
228
+ - torch.arange(self.max_positions, dtype=torch.long)[:, None]
229
+ )
230
+ self.rp_bucket = self.relative_position_bucket(
231
+ relative_position,
232
+ num_buckets=self.relative_attention_num_buckets,
233
+ max_distance=self.relative_attention_max_distance
234
+ )
235
+ self.rp_bucket -= self.rp_bucket.min()
236
+
237
+ self.head_dim = self.d_model // self.n_heads
238
+ self.scaling = self.head_dim ** -0.5
239
+
240
+ def prune_heads(self, heads):
241
+ if len(heads) == 0:
242
+ return
243
+ heads, index = find_pruneable_heads_and_indices(
244
+ heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
245
+ )
246
+ # Prune linear layers
247
+ self.q = prune_linear_layer(self.q, index)
248
+ self.k = prune_linear_layer(self.k, index)
249
+ self.v = prune_linear_layer(self.v, index)
250
+ self.o = prune_linear_layer(self.o, index, dim=1)
251
+ # Update hyper params
252
+ self.n_heads = self.n_heads - len(heads)
253
+ self.inner_dim = self.key_value_proj_dim * self.n_heads
254
+ self.pruned_heads = self.pruned_heads.union(heads)
255
+
256
+ @staticmethod
257
+ def relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
258
+ sign = torch.sign(relative_position)
259
+ num_buckets //= 2
260
+ n = torch.abs(relative_position)
261
+
262
+ # half of the buckets are for exact increments in positions
263
+ max_exact = num_buckets // 2
264
+ is_small = n < max_exact
265
+ max_bucket_val = num_buckets - 1 - max_exact
266
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
267
+ val_if_large = max_exact + torch.ceil(
268
+ torch.log(n.float() / max_exact)
269
+ / math.log((max_distance - 1) / max_exact)
270
+ * max_bucket_val
271
+ ).long()
272
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
273
+ ret = torch.where(is_small, n, val_if_large) * sign
274
+ return ret
275
+
276
+ def compute_bias(self, query_length, key_length):
277
+ relative_position_bucket = self.rp_bucket[:query_length, :key_length]
278
+ relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
279
+ values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
280
+ values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
281
+ return values
282
+
283
+ def forward(
284
+ self,
285
+ hidden_states,
286
+ mask=None,
287
+ key_value_states=None,
288
+ position_bias=None,
289
+ past_key_value=None,
290
+ layer_head_mask=None,
291
+ query_length=None,
292
+ use_cache=False,
293
+ output_attentions=False,
294
+ key_padding_mask=None,
295
+ ):
296
+ """
297
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
298
+ """
299
+ # Input is (batch_size, seq_length, dim)
300
+ # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
301
+ # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
302
+ batch_size, seq_length = hidden_states.shape[:2]
303
+
304
+ int_seq_length = int(seq_length)
305
+
306
+ real_seq_length = seq_length
307
+
308
+ if past_key_value is not None:
309
+ assert (
310
+ len(past_key_value) == 2
311
+ ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
312
+ real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
313
+
314
+ key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
315
+
316
+ def shape(states):
317
+ """projection"""
318
+ return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
319
+
320
+ def unshape(states):
321
+ """reshape"""
322
+ return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
323
+
324
+ def project(hidden_states, proj_layer, key_value_states, past_key_value):
325
+ """projects hidden states correctly to key/query states"""
326
+ if key_value_states is None:
327
+ # self-attn
328
+ # (batch_size, n_heads, seq_length, dim_per_head)
329
+ hidden_states = shape(proj_layer(hidden_states))
330
+ elif past_key_value is None:
331
+ # cross-attn
332
+ # (batch_size, n_heads, seq_length, dim_per_head)
333
+ hidden_states = shape(proj_layer(key_value_states))
334
+
335
+ if past_key_value is not None:
336
+ if key_value_states is None:
337
+ # self-attn
338
+ # (batch_size, n_heads, key_length, dim_per_head)
339
+ hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
340
+ else:
341
+ # cross-attn
342
+ hidden_states = past_key_value
343
+ return hidden_states
344
+
345
+ # get query states
346
+ query_states = shape(self.q(hidden_states)) * self.scaling # (batch_size, n_heads, seq_length, dim_per_head)
347
+
348
+ # get key/value states
349
+ key_states = project(
350
+ hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
351
+ )
352
+ value_states = project(
353
+ hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
354
+ )
355
+
356
+ # compute scores
357
+ scores = torch.matmul(
358
+ query_states, key_states.transpose(3, 2)
359
+ ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
360
+
361
+ if position_bias is None:
362
+ if not self.has_relative_attention_bias:
363
+ position_bias = torch.zeros(
364
+ (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
365
+ )
366
+ if self.gradient_checkpointing and self.training:
367
+ position_bias.requires_grad = True
368
+ else:
369
+ position_bias = self.compute_bias(real_seq_length, key_length)
370
+
371
+ # if key and values are already calculated
372
+ # we want only the last query position bias
373
+ if past_key_value is not None:
374
+ position_bias = position_bias[:, :, -int_seq_length:, :]
375
+
376
+ if mask is not None:
377
+ position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
378
+
379
+ scores += position_bias
380
+
381
+ if key_padding_mask is not None:
382
+ scores = scores.masked_fill(
383
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
384
+ float("-inf"),
385
+ )
386
+
387
+ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
388
+ scores
389
+ ) # (batch_size, n_heads, seq_length, key_length)
390
+ attn_weights = nn.functional.dropout(
391
+ attn_weights, p=self.dropout, training=self.training
392
+ ) # (batch_size, n_heads, seq_length, key_length)
393
+
394
+ # Mask heads if we want to
395
+ if layer_head_mask is not None:
396
+ attn_weights = attn_weights * layer_head_mask
397
+
398
+ attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim)
399
+ attn_output = self.o(attn_output)
400
+
401
+ present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
402
+ outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
403
+
404
+ if output_attentions:
405
+ outputs = outputs + (attn_weights,)
406
+ return outputs
407
+
408
+
409
+ class T5LayerSelfAttention(nn.Module):
410
+ def __init__(self, config, has_relative_attention_bias=False, normalize_before=False):
411
+ super().__init__()
412
+ self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
413
+ # self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
414
+ self.layer_norm = FST5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
415
+ self.dropout = nn.Dropout(config.dropout_rate)
416
+ self.normalize_before = normalize_before
417
+ self.has_relative_attention_bias = has_relative_attention_bias
418
+
419
+ def forward(
420
+ self,
421
+ hidden_states,
422
+ attention_mask=None,
423
+ position_bias=None,
424
+ layer_head_mask=None,
425
+ past_key_value=None,
426
+ use_cache=False,
427
+ output_attentions=False,
428
+ key_padding_mask=None,
429
+ ):
430
+ if self.normalize_before:
431
+ normed_hidden_states = self.layer_norm(hidden_states)
432
+ else:
433
+ normed_hidden_states = hidden_states
434
+
435
+ attention_output = self.SelfAttention(
436
+ normed_hidden_states,
437
+ mask=attention_mask,
438
+ position_bias=position_bias,
439
+ layer_head_mask=layer_head_mask,
440
+ past_key_value=past_key_value,
441
+ use_cache=use_cache,
442
+ output_attentions=output_attentions,
443
+ key_padding_mask=key_padding_mask,
444
+ )
445
+ hidden_states = hidden_states + self.dropout(attention_output[0])
446
+
447
+ if not self.normalize_before:
448
+ hidden_states = self.layer_norm(hidden_states)
449
+
450
+ outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
451
+ return outputs
452
+
453
+
454
+ class T5LayerCrossAttention(nn.Module):
455
+ def __init__(self, config, normalize_before=False):
456
+ super().__init__()
457
+ self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False)
458
+ # self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
459
+ self.layer_norm = FST5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
460
+ self.dropout = nn.Dropout(config.dropout_rate)
461
+
462
+ self.normalize_before = normalize_before
463
+
464
+ def forward(
465
+ self,
466
+ hidden_states,
467
+ key_value_states,
468
+ attention_mask=None,
469
+ position_bias=None,
470
+ layer_head_mask=None,
471
+ past_key_value=None,
472
+ use_cache=False,
473
+ query_length=None,
474
+ output_attentions=False,
475
+ ):
476
+ if self.normalize_before:
477
+ normed_hidden_states = self.layer_norm(hidden_states)
478
+ else:
479
+ normed_hidden_states = hidden_states
480
+
481
+ attention_output = self.EncDecAttention(
482
+ normed_hidden_states,
483
+ mask=attention_mask,
484
+ key_value_states=key_value_states,
485
+ position_bias=position_bias,
486
+ layer_head_mask=layer_head_mask,
487
+ past_key_value=past_key_value,
488
+ use_cache=use_cache,
489
+ query_length=query_length,
490
+ output_attentions=output_attentions,
491
+ )
492
+ layer_output = hidden_states + self.dropout(attention_output[0])
493
+
494
+ if not self.normalize_before:
495
+ layer_output = self.layer_norm(layer_output)
496
+
497
+ outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
498
+ return outputs
499
+
500
+
501
+ class T5Block(nn.Module):
502
+ def __init__(self, config, has_relative_attention_bias=False):
503
+ super().__init__()
504
+ self.is_decoder = config.is_decoder
505
+ self.layer = nn.ModuleList()
506
+ self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
507
+ if self.is_decoder:
508
+ self.layer.append(T5LayerCrossAttention(config))
509
+
510
+ self.layer.append(T5LayerFF(config))
511
+
512
+ def forward(
513
+ self,
514
+ hidden_states,
515
+ attention_mask=None,
516
+ position_bias=None,
517
+ encoder_hidden_states=None,
518
+ encoder_attention_mask=None,
519
+ encoder_decoder_position_bias=None,
520
+ layer_head_mask=None,
521
+ cross_attn_layer_head_mask=None,
522
+ past_key_value=None,
523
+ use_cache=False,
524
+ output_attentions=False,
525
+ return_dict=True,
526
+ key_padding_mask=None,
527
+ ):
528
+
529
+ if past_key_value is not None:
530
+ assert self.is_decoder, "Only decoder can use `past_key_values`"
531
+ expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
532
+
533
+ if len(past_key_value) != expected_num_past_key_values:
534
+ raise ValueError(
535
+ f"There should be {expected_num_past_key_values} past states. "
536
+ f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
537
+ f"Got {len(past_key_value)} past key / value states"
538
+ )
539
+
540
+ self_attn_past_key_value = past_key_value[:2]
541
+ cross_attn_past_key_value = past_key_value[2:]
542
+ else:
543
+ self_attn_past_key_value, cross_attn_past_key_value = None, None
544
+
545
+ self_attention_outputs = self.layer[0](
546
+ hidden_states,
547
+ attention_mask=attention_mask,
548
+ position_bias=position_bias,
549
+ layer_head_mask=layer_head_mask,
550
+ past_key_value=self_attn_past_key_value,
551
+ use_cache=use_cache,
552
+ output_attentions=output_attentions,
553
+ key_padding_mask=key_padding_mask,
554
+ )
555
+ hidden_states, present_key_value_state = self_attention_outputs[:2]
556
+ attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
557
+
558
+ # clamp inf values to enable fp16 training
559
+ if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
560
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
561
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
562
+
563
+ do_cross_attention = self.is_decoder and encoder_hidden_states is not None
564
+ if do_cross_attention:
565
+ # the actual query length is unknown for cross attention
566
+ # if using past key value states. Need to inject it here
567
+ if present_key_value_state is not None:
568
+ query_length = present_key_value_state[0].shape[2]
569
+ else:
570
+ query_length = None
571
+
572
+ cross_attention_outputs = self.layer[1](
573
+ hidden_states,
574
+ key_value_states=encoder_hidden_states,
575
+ attention_mask=encoder_attention_mask,
576
+ position_bias=encoder_decoder_position_bias,
577
+ layer_head_mask=cross_attn_layer_head_mask,
578
+ past_key_value=cross_attn_past_key_value,
579
+ query_length=query_length,
580
+ use_cache=use_cache,
581
+ output_attentions=output_attentions,
582
+ )
583
+ hidden_states = cross_attention_outputs[0]
584
+
585
+ # clamp inf values to enable fp16 training
586
+ if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
587
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
588
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
589
+
590
+ # Combine self attn and cross attn key value states
591
+ if present_key_value_state is not None:
592
+ present_key_value_state = present_key_value_state + cross_attention_outputs[1]
593
+
594
+ # Keep cross-attention outputs and relative position weights
595
+ attention_outputs = attention_outputs + cross_attention_outputs[2:]
596
+
597
+ # Apply Feed Forward layer
598
+ hidden_states = self.layer[-1](hidden_states)
599
+
600
+ # clamp inf values to enable fp16 training
601
+ if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
602
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
603
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
604
+
605
+ outputs = (hidden_states,)
606
+
607
+ if use_cache:
608
+ outputs = outputs + (present_key_value_state,) + attention_outputs
609
+ else:
610
+ outputs = outputs + attention_outputs
611
+
612
+ return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
613
+
614
+
615
+ class FairseqT5PreTrainedModel(PreTrainedModel):
616
+ """
617
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
618
+ models.
619
+ """
620
+
621
+ config_class = FairseqT5Config
622
+ load_tf_weights = None
623
+ base_model_prefix = "transformer"
624
+ is_parallelizable = True
625
+ supports_gradient_checkpointing = True
626
+
627
+ @property
628
+ def dummy_inputs(self):
629
+ input_ids = torch.tensor(DUMMY_INPUTS)
630
+ input_mask = torch.tensor(DUMMY_MASK)
631
+ dummy_inputs = {
632
+ "decoder_input_ids": input_ids,
633
+ "input_ids": input_ids,
634
+ "decoder_attention_mask": input_mask,
635
+ }
636
+ return dummy_inputs
637
+
638
+ def _init_weights(self, module):
639
+ """Initialize the weights"""
640
+ factor = self.config.initializer_factor # Used for testing weights initialization
641
+ if isinstance(module, T5LayerNorm) or isinstance(module, torch.nn.LayerNorm):
642
+ module.weight.data.fill_(factor * 1.0)
643
+ elif isinstance(module, (FairseqT5Model, FairseqT5ForConditionalGeneration, FairseqT5EncoderModel)):
644
+ # Mesh TensorFlow embeddings initialization
645
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
646
+ module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
647
+ elif isinstance(module, T5DenseReluDense):
648
+ # Mesh TensorFlow FF initialization
649
+ # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
650
+ # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
651
+ module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
652
+ if hasattr(module.wi, "bias") and module.wi.bias is not None:
653
+ module.wi.bias.data.zero_()
654
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
655
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
656
+ module.wo.bias.data.zero_()
657
+ elif isinstance(module, T5DenseGatedGeluDense):
658
+ module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
659
+ if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
660
+ module.wi_0.bias.data.zero_()
661
+ module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
662
+ if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
663
+ module.wi_1.bias.data.zero_()
664
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
665
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
666
+ module.wo.bias.data.zero_()
667
+ elif isinstance(module, T5Attention):
668
+ # Mesh TensorFlow attention initialization to avoid scaling before softmax
669
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
670
+ d_model = self.config.d_model
671
+ key_value_proj_dim = self.config.d_kv
672
+ n_heads = self.config.num_heads
673
+ module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
674
+ module.k.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
675
+ module.v.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
676
+ module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
677
+ if module.has_relative_attention_bias:
678
+ module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
679
+
680
+ def _set_gradient_checkpointing(self, module, value=False):
681
+ if isinstance(module, (T5Attention, FairseqT5Stack)):
682
+ module.gradient_checkpointing = value
683
+
684
+ def _shift_right(self, input_ids):
685
+ decoder_start_token_id = self.config.decoder_start_token_id
686
+ pad_token_id = self.config.pad_token_id
687
+
688
+ assert (
689
+ decoder_start_token_id is not None
690
+ ), "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. See T5 docs for more information"
691
+
692
+ # shift inputs to the right
693
+ if is_torch_fx_proxy(input_ids):
694
+ # Item assignment is not supported natively for proxies.
695
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
696
+ shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
697
+ else:
698
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
699
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
700
+ shifted_input_ids[..., 0] = decoder_start_token_id
701
+
702
+ assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
703
+ # replace possible -100 values in labels by `pad_token_id`
704
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
705
+
706
+ assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"
707
+
708
+ return shifted_input_ids
709
+
710
+
711
+ class FairseqT5Stack(FairseqT5PreTrainedModel):
712
+ def __init__(self, config, embed_tokens=None):
713
+ super().__init__(config)
714
+
715
+ self.embed_tokens = embed_tokens
716
+ self.pos_embed = PositionalEmbedding(
717
+ 1024,
718
+ config.d_model,
719
+ config.pad_token_id,
720
+ )
721
+ self.is_decoder = config.is_decoder
722
+
723
+ # self.first_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) # final_layer_norm -> first layer norm
724
+ self.first_layer_norm = FST5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
725
+ self.dropout = nn.Dropout(config.dropout_rate) #
726
+
727
+ # modified
728
+ if not self.is_decoder:
729
+ self.block = nn.ModuleList(
730
+ # [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
731
+ [T5Block(config, has_relative_attention_bias=True) for i in range(config.num_layers)]
732
+ )
733
+ else:
734
+ self.block = nn.ModuleList(
735
+ # [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
736
+ [T5Block(config, has_relative_attention_bias=False) for i in range(config.num_layers)]
737
+ )
738
+
739
+ self.init_weights()
740
+ # Model parallel
741
+ self.model_parallel = False
742
+ self.device_map = None
743
+ self.gradient_checkpointing = False
744
+
745
+ self.padding_idx = self.config.pad_token_id
746
+
747
+ def parallelize(self, device_map=None):
748
+ # Check validity of device_map
749
+ self.device_map = (
750
+ get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
751
+ )
752
+ assert_device_map(self.device_map, len(self.block))
753
+ self.model_parallel = True
754
+ self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
755
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
756
+ # Load onto devices
757
+ for k, v in self.device_map.items():
758
+ for layer in v:
759
+ cuda_device = "cuda:" + str(k)
760
+ self.block[layer] = self.block[layer].to(cuda_device)
761
+
762
+ # Set embed_tokens to first layer
763
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
764
+ self.pos_embed = self.pos_embed.to(self.first_device)
765
+ # Set first layer norm to first device
766
+ self.first_layer_norm = self.first_layer_norm.to(self.first_device)
767
+
768
+ def deparallelize(self):
769
+ self.model_parallel = False
770
+ self.device_map = None
771
+ self.first_device = "cpu"
772
+ self.last_device = "cpu"
773
+ for i in range(len(self.block)):
774
+ self.block[i] = self.block[i].to("cpu")
775
+ self.embed_tokens = self.embed_tokens.to("cpu")
776
+ self.first_layer_norm = self.first_layer_norm.to("cpu")
777
+ torch.cuda.empty_cache()
778
+
779
+ def get_input_embeddings(self):
780
+ return self.embed_tokens
781
+
782
+ def set_input_embeddings(self, new_embeddings):
783
+ self.embed_tokens = new_embeddings
784
+
785
+ def forward(
786
+ self,
787
+ input_ids=None,
788
+ attention_mask=None,
789
+ encoder_hidden_states=None,
790
+ encoder_attention_mask=None,
791
+ inputs_embeds=None,
792
+ head_mask=None,
793
+ cross_attn_head_mask=None,
794
+ past_key_values=None,
795
+ use_cache=None,
796
+ output_attentions=None,
797
+ output_hidden_states=None,
798
+ return_dict=None,
799
+ pos_offset=0,
800
+ ):
801
+ # Model parallel
802
+ if self.model_parallel:
803
+ torch.cuda.set_device(self.first_device)
804
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
805
+ self.pos_embed = self.pos_embed.to(self.first_device)
806
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
807
+
808
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
809
+ output_hidden_states = (
810
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
811
+ )
812
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
813
+
814
+ if input_ids is not None and inputs_embeds is not None:
815
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
816
+ raise ValueError(
817
+ f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
818
+ )
819
+ elif input_ids is not None:
820
+ input_shape = input_ids.size()
821
+ input_ids = input_ids.view(-1, input_shape[-1])
822
+ elif inputs_embeds is not None:
823
+ input_shape = inputs_embeds.size()[:-1]
824
+ else:
825
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
826
+ raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
827
+
828
+ if inputs_embeds is None:
829
+ assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
830
+ inputs_embeds = self.embed_tokens(input_ids)
831
+
832
+ batch_size, seq_length = input_shape
833
+
834
+ # required mask seq length can be calculated via length of past
835
+ mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
836
+
837
+ if use_cache is True:
838
+ assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"
839
+
840
+ if attention_mask is None:
841
+ attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
842
+ if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
843
+ encoder_seq_length = encoder_hidden_states.shape[1]
844
+ encoder_attention_mask = torch.ones(
845
+ batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
846
+ )
847
+
848
+ # initialize past_key_values with `None` if past does not exist
849
+ if past_key_values is None:
850
+ past_key_values = [None] * len(self.block)
851
+
852
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
853
+ # ourselves in which case we just need to make it broadcastable to all heads.
854
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device)
855
+
856
+ # If a 2D or 3D attention mask is provided for the cross-attention
857
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
858
+ if self.is_decoder and encoder_attention_mask is not None:
859
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
860
+ else:
861
+ encoder_extended_attention_mask = None
862
+
863
+ # Prepare head mask if needed
864
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
865
+ cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
866
+ present_key_value_states = () if use_cache else None
867
+ all_hidden_states = () if output_hidden_states else None
868
+ all_attentions = () if output_attentions else None
869
+ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
870
+ position_bias = None
871
+ encoder_decoder_position_bias = None
872
+
873
+ # modified: position embedding
874
+ # if input_ids is not None:
875
+ # include position offset for decoding
876
+ pos_embeds = self.pos_embed(input_ids, offset=pos_offset)
877
+ inputs_embeds = inputs_embeds + pos_embeds
878
+
879
+ # hidden_states = self.dropout(inputs_embeds)
880
+ hidden_states = self.first_layer_norm(inputs_embeds) # modified: first layer_norm
881
+ hidden_states = self.dropout(hidden_states)
882
+
883
+ key_padding_mask: Optional[Tensor] = None
884
+ if self.is_decoder:
885
+ if input_ids.eq(self.padding_idx).any():
886
+ key_padding_mask = input_ids.eq(self.padding_idx)
887
+
888
+ for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
889
+ layer_head_mask = head_mask[i]
890
+ cross_attn_layer_head_mask = cross_attn_head_mask[i]
891
+ # Model parallel
892
+ if self.model_parallel:
893
+ torch.cuda.set_device(hidden_states.device)
894
+ # Ensure that attention_mask is always on the same device as hidden_states
895
+ if attention_mask is not None:
896
+ attention_mask = attention_mask.to(hidden_states.device)
897
+ if position_bias is not None:
898
+ position_bias = position_bias.to(hidden_states.device)
899
+ if encoder_hidden_states is not None:
900
+ encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
901
+ if encoder_extended_attention_mask is not None:
902
+ encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
903
+ if encoder_decoder_position_bias is not None:
904
+ encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
905
+ if layer_head_mask is not None:
906
+ layer_head_mask = layer_head_mask.to(hidden_states.device)
907
+ if cross_attn_layer_head_mask is not None:
908
+ cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
909
+ if output_hidden_states:
910
+ all_hidden_states = all_hidden_states + (hidden_states,)
911
+
912
+ if self.gradient_checkpointing and self.training:
913
+ if use_cache:
914
+ logger.warn(
915
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
916
+ )
917
+ use_cache = False
918
+
919
+ def create_custom_forward(module):
920
+ def custom_forward(*inputs):
921
+ return tuple(module(*inputs, use_cache, output_attentions))
922
+
923
+ return custom_forward
924
+
925
+ layer_outputs = checkpoint(
926
+ create_custom_forward(layer_module),
927
+ hidden_states,
928
+ extended_attention_mask,
929
+ position_bias,
930
+ encoder_hidden_states,
931
+ encoder_extended_attention_mask,
932
+ encoder_decoder_position_bias,
933
+ layer_head_mask,
934
+ cross_attn_layer_head_mask,
935
+ None, # past_key_value is always None with gradient checkpointing
936
+ key_padding_mask=key_padding_mask,
937
+ )
938
+ else:
939
+ layer_outputs = layer_module(
940
+ hidden_states,
941
+ attention_mask=extended_attention_mask,
942
+ position_bias=position_bias,
943
+ encoder_hidden_states=encoder_hidden_states,
944
+ encoder_attention_mask=encoder_extended_attention_mask,
945
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
946
+ layer_head_mask=layer_head_mask,
947
+ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
948
+ past_key_value=past_key_value,
949
+ use_cache=use_cache,
950
+ output_attentions=output_attentions,
951
+ key_padding_mask=key_padding_mask,
952
+ )
953
+
954
+ # layer_outputs is a tuple with:
955
+ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
956
+ if use_cache is False:
957
+ layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
958
+
959
+ hidden_states, present_key_value_state = layer_outputs[:2]
960
+
961
+ # We share the position biases between the layers - the first layer store them
962
+ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
963
+ # (cross-attention position bias), (cross-attention weights)
964
+ # position_bias = layer_outputs[2]
965
+ if self.is_decoder and encoder_hidden_states is not None:
966
+ encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
967
+ # append next layer key value states
968
+ if use_cache:
969
+ present_key_value_states = present_key_value_states + (present_key_value_state,)
970
+
971
+ if output_attentions:
972
+ all_attentions = all_attentions + (layer_outputs[3],)
973
+ if self.is_decoder:
974
+ all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
975
+
976
+ # Model Parallel: If it's the last layer for that device, put things on the next device
977
+ if self.model_parallel:
978
+ for k, v in self.device_map.items():
979
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
980
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
981
+
982
+ # modified: no final_layer_norm
983
+ # hidden_states = self.final_layer_norm(hidden_states)
984
+ # hidden_states = self.dropout(hidden_states)
985
+
986
+ # Add last layer
987
+ if output_hidden_states:
988
+ all_hidden_states = all_hidden_states + (hidden_states,)
989
+
990
+ if not return_dict:
991
+ return tuple(
992
+ v
993
+ for v in [
994
+ hidden_states,
995
+ present_key_value_states,
996
+ all_hidden_states,
997
+ all_attentions,
998
+ all_cross_attentions,
999
+ ]
1000
+ if v is not None
1001
+ )
1002
+ return BaseModelOutputWithPastAndCrossAttentions(
1003
+ last_hidden_state=hidden_states,
1004
+ past_key_values=present_key_value_states,
1005
+ hidden_states=all_hidden_states,
1006
+ attentions=all_attentions,
1007
+ cross_attentions=all_cross_attentions,
1008
+ )
1009
+
1010
+
1011
+ class FairseqT5Model(FairseqT5PreTrainedModel):
1012
+ _keys_to_ignore_on_load_missing = [
1013
+ r"encoder\.embed_tokens\.weight",
1014
+ r"decoder\.embed_tokens\.weight",
1015
+ ]
1016
+ _keys_to_ignore_on_load_unexpected = [
1017
+ r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
1018
+ ]
1019
+
1020
+ def __init__(self, config: FairseqT5Config):
1021
+ super().__init__(config)
1022
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1023
+
1024
+ encoder_config = copy.deepcopy(config)
1025
+ encoder_config.is_decoder = False
1026
+ encoder_config.use_cache = False
1027
+ encoder_config.is_encoder_decoder = False
1028
+ self.encoder = FairseqT5Stack(encoder_config, self.shared)
1029
+
1030
+ decoder_config = copy.deepcopy(config)
1031
+ decoder_config.is_decoder = True
1032
+ decoder_config.is_encoder_decoder = False
1033
+ decoder_config.num_layers = config.num_decoder_layers
1034
+ self.decoder = FairseqT5Stack(decoder_config, self.shared)
1035
+
1036
+ # Initialize weights and apply final processing
1037
+ self.init_weights()
1038
+
1039
+ # Model parallel
1040
+ self.model_parallel = False
1041
+ self.device_map = None
1042
+
1043
+ def parallelize(self, device_map=None):
1044
+ self.device_map = (
1045
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
1046
+ if device_map is None
1047
+ else device_map
1048
+ )
1049
+ assert_device_map(self.device_map, len(self.encoder.block))
1050
+ self.encoder.parallelize(self.device_map)
1051
+ self.decoder.parallelize(self.device_map)
1052
+ self.model_parallel = True
1053
+
1054
+ def deparallelize(self):
1055
+ self.encoder.deparallelize()
1056
+ self.decoder.deparallelize()
1057
+ self.encoder = self.encoder.to("cpu")
1058
+ self.decoder = self.decoder.to("cpu")
1059
+ self.model_parallel = False
1060
+ self.device_map = None
1061
+ torch.cuda.empty_cache()
1062
+
1063
+ def get_input_embeddings(self):
1064
+ return self.shared
1065
+
1066
+ def set_input_embeddings(self, new_embeddings):
1067
+ self.shared = new_embeddings
1068
+ self.encoder.set_input_embeddings(new_embeddings)
1069
+ self.decoder.set_input_embeddings(new_embeddings)
1070
+
1071
+ def get_encoder(self):
1072
+ return self.encoder
1073
+
1074
+ def get_decoder(self):
1075
+ return self.decoder
1076
+
1077
+ def _prune_heads(self, heads_to_prune):
1078
+ """
1079
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1080
+ class PreTrainedModel
1081
+ """
1082
+ for layer, heads in heads_to_prune.items():
1083
+ self.encoder.layer[layer].attention.prune_heads(heads)
1084
+
1085
+ def forward(
1086
+ self,
1087
+ input_ids=None,
1088
+ attention_mask=None,
1089
+ decoder_input_ids=None,
1090
+ decoder_attention_mask=None,
1091
+ head_mask=None,
1092
+ decoder_head_mask=None,
1093
+ cross_attn_head_mask=None,
1094
+ encoder_outputs=None,
1095
+ past_key_values=None,
1096
+ inputs_embeds=None,
1097
+ decoder_inputs_embeds=None,
1098
+ use_cache=None,
1099
+ output_attentions=None,
1100
+ output_hidden_states=None,
1101
+ return_dict=None,
1102
+ ):
1103
+ r"""
1104
+ Returns: Seq2SeqModelOutput
1105
+
1106
+ Example:
1107
+
1108
+ ```python
1109
+ >>> from transformers import T5Tokenizer, T5Model
1110
+
1111
+ >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
1112
+ >>> model = FairseqT5Model.from_pretrained("t5-small")
1113
+
1114
+ >>> input_ids = tokenizer(
1115
+ ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
1116
+ >>> ).input_ids # Batch size 1
1117
+ >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
1118
+
1119
+ >>> # forward pass
1120
+ >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
1121
+ >>> last_hidden_states = outputs.last_hidden_state
1122
+ ```"""
1123
+ use_cache = False
1124
+
1125
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1126
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1127
+
1128
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
1129
+ if head_mask is not None and decoder_head_mask is None:
1130
+ if self.config.num_layers == self.config.num_decoder_layers:
1131
+ decoder_head_mask = head_mask
1132
+
1133
+ # Encode if needed (training, first prediction pass)
1134
+ if encoder_outputs is None:
1135
+ encoder_outputs = self.encoder(
1136
+ input_ids=input_ids,
1137
+ attention_mask=attention_mask,
1138
+ inputs_embeds=inputs_embeds,
1139
+ head_mask=head_mask,
1140
+ output_attentions=output_attentions,
1141
+ output_hidden_states=output_hidden_states,
1142
+ return_dict=return_dict,
1143
+ )
1144
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1145
+ encoder_outputs = BaseModelOutput(
1146
+ last_hidden_state=encoder_outputs[0],
1147
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1148
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1149
+ )
1150
+
1151
+ hidden_states = encoder_outputs[0]
1152
+ if self.model_parallel:
1153
+ torch.cuda.set_device(self.decoder.first_device)
1154
+ # Set device for model parallelism
1155
+ if self.model_parallel:
1156
+ torch.cuda.set_device(self.decoder.first_device)
1157
+ hidden_states = hidden_states.to(self.decoder.first_device)
1158
+ if decoder_input_ids is not None:
1159
+ decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
1160
+ if attention_mask is not None:
1161
+ attention_mask = attention_mask.to(self.decoder.first_device)
1162
+ if decoder_attention_mask is not None:
1163
+ decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
1164
+
1165
+ # Decode
1166
+ decoder_outputs = self.decoder(
1167
+ input_ids=decoder_input_ids,
1168
+ attention_mask=decoder_attention_mask,
1169
+ inputs_embeds=decoder_inputs_embeds,
1170
+ past_key_values=past_key_values,
1171
+ encoder_hidden_states=hidden_states,
1172
+ encoder_attention_mask=attention_mask,
1173
+ head_mask=decoder_head_mask,
1174
+ cross_attn_head_mask=cross_attn_head_mask,
1175
+ use_cache=use_cache,
1176
+ output_attentions=output_attentions,
1177
+ output_hidden_states=output_hidden_states,
1178
+ return_dict=return_dict,
1179
+ )
1180
+
1181
+ if not return_dict:
1182
+ return decoder_outputs + encoder_outputs
1183
+
1184
+ return Seq2SeqModelOutput(
1185
+ last_hidden_state=decoder_outputs.last_hidden_state,
1186
+ past_key_values=decoder_outputs.past_key_values,
1187
+ decoder_hidden_states=decoder_outputs.hidden_states,
1188
+ decoder_attentions=decoder_outputs.attentions,
1189
+ cross_attentions=decoder_outputs.cross_attentions,
1190
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1191
+ encoder_hidden_states=encoder_outputs.hidden_states,
1192
+ encoder_attentions=encoder_outputs.attentions,
1193
+ )
1194
+
1195
+
1196
+ class FairseqT5ForConditionalGeneration(FairseqT5PreTrainedModel):
1197
+ _keys_to_ignore_on_load_missing = [
1198
+ r"encoder\.embed_tokens\.weight",
1199
+ r"decoder\.embed_tokens\.weight",
1200
+ r"lm_head\.weight",
1201
+ ]
1202
+ _keys_to_ignore_on_load_unexpected = [
1203
+ r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
1204
+ ]
1205
+
1206
+ def __init__(self, config):
1207
+ super().__init__(config)
1208
+ self.model_dim = config.d_model
1209
+
1210
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1211
+
1212
+ encoder_config = copy.deepcopy(config)
1213
+ encoder_config.is_decoder = False
1214
+ encoder_config.use_cache = False
1215
+ encoder_config.is_encoder_decoder = False
1216
+ self.encoder = FairseqT5Stack(encoder_config, self.shared)
1217
+
1218
+ decoder_config = copy.deepcopy(config)
1219
+ decoder_config.is_decoder = True
1220
+ decoder_config.is_encoder_decoder = False
1221
+ decoder_config.num_layers = config.num_decoder_layers
1222
+ self.decoder = FairseqT5Stack(decoder_config, self.shared)
1223
+
1224
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
1225
+
1226
+ # Initialize weights and apply final processing
1227
+ self.init_weights()
1228
+
1229
+ # Model parallel
1230
+ self.model_parallel = False
1231
+ self.device_map = None
1232
+
1233
+ def parallelize(self, device_map=None):
1234
+ self.device_map = (
1235
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
1236
+ if device_map is None
1237
+ else device_map
1238
+ )
1239
+ assert_device_map(self.device_map, len(self.encoder.block))
1240
+ self.encoder.parallelize(self.device_map)
1241
+ self.decoder.parallelize(self.device_map)
1242
+ self.lm_head = self.lm_head.to(self.decoder.first_device)
1243
+ self.model_parallel = True
1244
+
1245
+ def deparallelize(self):
1246
+ self.encoder.deparallelize()
1247
+ self.decoder.deparallelize()
1248
+ self.encoder = self.encoder.to("cpu")
1249
+ self.decoder = self.decoder.to("cpu")
1250
+ self.lm_head = self.lm_head.to("cpu")
1251
+ self.model_parallel = False
1252
+ self.device_map = None
1253
+ torch.cuda.empty_cache()
1254
+
1255
+ def get_input_embeddings(self):
1256
+ return self.shared
1257
+
1258
+ def set_input_embeddings(self, new_embeddings):
1259
+ self.shared = new_embeddings
1260
+ self.encoder.set_input_embeddings(new_embeddings)
1261
+ self.decoder.set_input_embeddings(new_embeddings)
1262
+
1263
+ def set_output_embeddings(self, new_embeddings):
1264
+ self.lm_head = new_embeddings
1265
+
1266
+ def get_output_embeddings(self):
1267
+ return self.lm_head
1268
+
1269
+ def get_encoder(self):
1270
+ return self.encoder
1271
+
1272
+ def get_decoder(self):
1273
+ return self.decoder
1274
+
1275
+ def forward(
1276
+ self,
1277
+ input_ids=None,
1278
+ attention_mask=None,
1279
+ decoder_input_ids=None,
1280
+ decoder_attention_mask=None,
1281
+ head_mask=None,
1282
+ decoder_head_mask=None,
1283
+ cross_attn_head_mask=None,
1284
+ encoder_outputs=None,
1285
+ past_key_values=None,
1286
+ inputs_embeds=None,
1287
+ decoder_inputs_embeds=None,
1288
+ labels=None,
1289
+ use_cache=None,
1290
+ output_attentions=None,
1291
+ output_hidden_states=None,
1292
+ return_dict=None,
1293
+ pos_offset=0,
1294
+ ):
1295
+ r"""
1296
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1297
+ Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
1298
+ config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
1299
+ labels in `[0, ..., config.vocab_size]`
1300
+
1301
+ Returns: Seq2SeqLMOutput
1302
+
1303
+ Examples:
1304
+
1305
+ ```python
1306
+ >>> from transformers import T5Tokenizer, T5ForConditionalGeneration
1307
+
1308
+ >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
1309
+ >>> model = FairseqT5ForConditionalGeneration.from_pretrained("t5-small")
1310
+
1311
+ >>> # training
1312
+ >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
1313
+ >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
1314
+ >>> outputs = model(input_ids=input_ids, labels=labels)
1315
+ >>> loss = outputs.loss
1316
+ >>> logits = outputs.logits
1317
+
1318
+ >>> # inference
1319
+ >>> input_ids = tokenizer(
1320
+ ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
1321
+ >>> ).input_ids # Batch size 1
1322
+ >>> outputs = model.generate(input_ids)
1323
+ >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
1324
+ >>> # studies have shown that owning a dog is good for you.
1325
+ ```"""
1326
+
1327
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1328
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1329
+
1330
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
1331
+ if head_mask is not None and decoder_head_mask is None:
1332
+ if self.config.num_layers == self.config.num_decoder_layers:
1333
+ decoder_head_mask = head_mask
1334
+
1335
+ # Encode if needed (training, first prediction pass)
1336
+ if encoder_outputs is None:
1337
+ # Convert encoder inputs in embeddings if needed
1338
+ encoder_outputs = self.encoder(
1339
+ input_ids=input_ids,
1340
+ attention_mask=attention_mask,
1341
+ inputs_embeds=inputs_embeds,
1342
+ head_mask=head_mask,
1343
+ output_attentions=output_attentions,
1344
+ output_hidden_states=output_hidden_states,
1345
+ return_dict=return_dict,
1346
+ )
1347
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1348
+ encoder_outputs = BaseModelOutput(
1349
+ last_hidden_state=encoder_outputs[0],
1350
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1351
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1352
+ )
1353
+
1354
+ hidden_states = encoder_outputs[0]
1355
+
1356
+ if self.model_parallel:
1357
+ torch.cuda.set_device(self.decoder.first_device)
1358
+
1359
+ if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
1360
+ # get decoder inputs from shifting lm labels to the right
1361
+ decoder_input_ids = self._shift_right(labels)
1362
+
1363
+ # If decoding with past key value states, only the last tokens
1364
+ # should be given as an input
1365
+ if past_key_values is not None:
1366
+ assert labels is None, "Decoder should not use cached key value states when training."
1367
+ if decoder_input_ids is not None:
1368
+ decoder_input_ids = decoder_input_ids[:, -1:]
1369
+ if decoder_inputs_embeds is not None:
1370
+ decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]
1371
+
1372
+ # Set device for model parallelism
1373
+ if self.model_parallel:
1374
+ torch.cuda.set_device(self.decoder.first_device)
1375
+ hidden_states = hidden_states.to(self.decoder.first_device)
1376
+ if decoder_input_ids is not None:
1377
+ decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
1378
+ if attention_mask is not None:
1379
+ attention_mask = attention_mask.to(self.decoder.first_device)
1380
+ if decoder_attention_mask is not None:
1381
+ decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
1382
+
1383
+ # Decode
1384
+ decoder_outputs = self.decoder(
1385
+ input_ids=decoder_input_ids,
1386
+ attention_mask=decoder_attention_mask,
1387
+ inputs_embeds=decoder_inputs_embeds,
1388
+ past_key_values=past_key_values,
1389
+ encoder_hidden_states=hidden_states,
1390
+ encoder_attention_mask=attention_mask,
1391
+ head_mask=decoder_head_mask,
1392
+ cross_attn_head_mask=cross_attn_head_mask,
1393
+ use_cache=use_cache,
1394
+ output_attentions=output_attentions,
1395
+ output_hidden_states=output_hidden_states,
1396
+ return_dict=return_dict,
1397
+ pos_offset=pos_offset,
1398
+ )
1399
+
1400
+ sequence_output = decoder_outputs[0]
1401
+
1402
+ # Set device for model parallelism
1403
+ if self.model_parallel:
1404
+ torch.cuda.set_device(self.encoder.first_device)
1405
+ self.lm_head = self.lm_head.to(self.encoder.first_device)
1406
+ sequence_output = sequence_output.to(self.lm_head.weight.device)
1407
+
1408
+ lm_logits = self.lm_head(sequence_output)
1409
+
1410
+ loss = None
1411
+ if labels is not None:
1412
+ loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
1413
+ loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
1414
+ # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
1415
+
1416
+ if not return_dict:
1417
+ output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
1418
+ return ((loss,) + output) if loss is not None else output
1419
+
1420
+ return Seq2SeqLMOutput(
1421
+ loss=loss,
1422
+ logits=lm_logits,
1423
+ past_key_values=decoder_outputs.past_key_values,
1424
+ decoder_hidden_states=decoder_outputs.hidden_states,
1425
+ decoder_attentions=decoder_outputs.attentions,
1426
+ cross_attentions=decoder_outputs.cross_attentions,
1427
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1428
+ encoder_hidden_states=encoder_outputs.hidden_states,
1429
+ encoder_attentions=encoder_outputs.attentions,
1430
+ )
1431
+
1432
+ def prepare_inputs_for_generation(
1433
+ self,
1434
+ input_ids,
1435
+ past=None,
1436
+ attention_mask=None,
1437
+ head_mask=None,
1438
+ decoder_head_mask=None,
1439
+ cross_attn_head_mask=None,
1440
+ use_cache=None,
1441
+ encoder_outputs=None,
1442
+ **kwargs
1443
+ ):
1444
+ # cut decoder_input_ids if past is used
1445
+ offset = 0
1446
+ if past is not None:
1447
+ offset = max(0, int(input_ids.size(1)) - 1)
1448
+ input_ids = input_ids[:, -1:]
1449
+
1450
+ return {
1451
+ "decoder_input_ids": input_ids,
1452
+ "past_key_values": past,
1453
+ "encoder_outputs": encoder_outputs,
1454
+ "attention_mask": attention_mask,
1455
+ "head_mask": head_mask,
1456
+ "decoder_head_mask": decoder_head_mask,
1457
+ "cross_attn_head_mask": cross_attn_head_mask,
1458
+ "use_cache": use_cache,
1459
+ "pos_offset": offset,
1460
+ }
1461
+
1462
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
1463
+ return self._shift_right(labels)
1464
+
1465
+ def _reorder_cache(self, past, beam_idx):
1466
+ # if decoder past is not included in output
1467
+ # speedy decoding is disabled and no need to reorder
1468
+ if past is None:
1469
+ logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
1470
+ return past
1471
+
1472
+ reordered_decoder_past = ()
1473
+ for layer_past_states in past:
1474
+ # get the correct batch idx from layer past batch dim
1475
+ # batch dim of `past` is at 2nd position
1476
+ reordered_layer_past_states = ()
1477
+ for layer_past_state in layer_past_states:
1478
+ # need to set correct `past` for each of the four key / value states
1479
+ reordered_layer_past_states = reordered_layer_past_states + (
1480
+ layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
1481
+ )
1482
+
1483
+ assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
1484
+ assert len(reordered_layer_past_states) == len(layer_past_states)
1485
+
1486
+ reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
1487
+ return reordered_decoder_past
1488
+
1489
+
1490
+ class FairseqT5EncoderModel(FairseqT5PreTrainedModel):
1491
+ authorized_missing_keys = [
1492
+ r"encoder\.embed_tokens\.weight",
1493
+ ]
1494
+
1495
+ def __init__(self, config: FairseqT5Config):
1496
+ super().__init__(config)
1497
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1498
+
1499
+ encoder_config = copy.deepcopy(config)
1500
+ encoder_config.is_decoder = False
1501
+ encoder_config.use_cache = False
1502
+ encoder_config.is_encoder_decoder = False
1503
+ self.encoder = FairseqT5Stack(encoder_config, self.shared)
1504
+
1505
+ # Initialize weights and apply final processing
1506
+ self.init_weights()
1507
+
1508
+ # Model parallel
1509
+ self.model_parallel = False
1510
+ self.device_map = None
1511
+
1512
+ def parallelize(self, device_map=None):
1513
+ self.device_map = (
1514
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
1515
+ if device_map is None
1516
+ else device_map
1517
+ )
1518
+ assert_device_map(self.device_map, len(self.encoder.block))
1519
+ self.encoder.parallelize(self.device_map)
1520
+ self.model_parallel = True
1521
+
1522
+ def deparallelize(self):
1523
+ self.encoder.deparallelize()
1524
+ self.encoder = self.encoder.to("cpu")
1525
+ self.model_parallel = False
1526
+ self.device_map = None
1527
+ torch.cuda.empty_cache()
1528
+
1529
+ def get_input_embeddings(self):
1530
+ return self.shared
1531
+
1532
+ def set_input_embeddings(self, new_embeddings):
1533
+ self.shared = new_embeddings
1534
+ self.encoder.set_input_embeddings(new_embeddings)
1535
+
1536
+ def get_encoder(self):
1537
+ return self.encoder
1538
+
1539
+ def _prune_heads(self, heads_to_prune):
1540
+ """
1541
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1542
+ class PreTrainedModel
1543
+ """
1544
+ for layer, heads in heads_to_prune.items():
1545
+ self.encoder.layer[layer].attention.prune_heads(heads)
1546
+
1547
+ def forward(
1548
+ self,
1549
+ input_ids=None,
1550
+ attention_mask=None,
1551
+ head_mask=None,
1552
+ inputs_embeds=None,
1553
+ output_attentions=None,
1554
+ output_hidden_states=None,
1555
+ return_dict=None,
1556
+ ):
1557
+ r"""
1558
+ Returns: BaseModelOutput
1559
+
1560
+ Example:
1561
+
1562
+ ```python
1563
+ >>> from transformers import T5Tokenizer, T5EncoderModel
1564
+
1565
+ >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
1566
+ >>> model = FairseqT5EncoderModel.from_pretrained("t5-small")
1567
+ >>> input_ids = tokenizer(
1568
+ ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
1569
+ >>> ).input_ids # Batch size 1
1570
+ >>> outputs = model(input_ids=input_ids)
1571
+ >>> last_hidden_states = outputs.last_hidden_state
1572
+ ```"""
1573
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1574
+
1575
+ encoder_outputs = self.encoder(
1576
+ input_ids=input_ids,
1577
+ attention_mask=attention_mask,
1578
+ inputs_embeds=inputs_embeds,
1579
+ head_mask=head_mask,
1580
+ output_attentions=output_attentions,
1581
+ output_hidden_states=output_hidden_states,
1582
+ return_dict=return_dict,
1583
+ )
1584
+
1585
+ return encoder_outputs
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b2e26afe9280954c88f17214b435f6d15d4d4e663dabeca1f95c64e6394082f
3
+ size 998586155