gonglinyuan committed on
Commit df37bca
Parent: ce3f731

Upload FairseqT5ForConditionalGeneration

README.md ADDED
@@ -0,0 +1,201 @@
+ ---
+ library_name: transformers
+ tags: []
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
+
+
config.json ADDED
@@ -0,0 +1,31 @@
+ {
+   "architectures": [
+     "FairseqT5ForConditionalGeneration"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_fairseq_t5.FairseqT5Config",
+     "AutoModelForSeq2SeqLM": "modeling_fairseq_t5.FairseqT5ForConditionalGeneration"
+   },
+   "d_ff": 3072,
+   "d_kv": 64,
+   "d_model": 768,
+   "decoder_start_token_id": 2,
+   "dropout_rate": 0.1,
+   "eos_token_id": 2,
+   "feed_forward_proj": "relu",
+   "initializer_factor": 1.0,
+   "is_encoder_decoder": true,
+   "layer_norm_epsilon": 1e-05,
+   "max_positions": 1024,
+   "model_type": "fairseq_t5",
+   "num_decoder_layers": 12,
+   "num_heads": 12,
+   "num_layers": 12,
+   "pad_token_id": 1,
+   "relative_attention_max_distance": 128,
+   "relative_attention_num_buckets": 128,
+   "torch_dtype": "float32",
+   "transformers_version": "4.37.2",
+   "use_cache": true,
+   "vocab_size": 101265
+ }
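
Since `config.json` registers the custom classes through `auto_map`, the checkpoint loads through the standard Auto classes once remote code is trusted. A minimal sketch (the repo id below is a placeholder; substitute this repository's actual path):

```python
from transformers import AutoConfig, AutoModelForSeq2SeqLM

# "user/fairseq-t5" is a hypothetical repo id, standing in for this repository.
repo = "user/fairseq-t5"
config = AutoConfig.from_pretrained(repo, trust_remote_code=True)
# trust_remote_code=True lets auto_map resolve FairseqT5Config and
# FairseqT5ForConditionalGeneration from the .py files in this commit.
model = AutoModelForSeq2SeqLM.from_pretrained(repo, trust_remote_code=True)
print(config.model_type)  # "fairseq_t5"
```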
configuration_fairseq_t5.py ADDED
@@ -0,0 +1,53 @@
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class FairseqT5Config(PretrainedConfig):
+     model_type = "fairseq_t5"
+     keys_to_ignore_at_inference = ["past_key_values"]
+     attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
+
+     def __init__(
+         self,
+         vocab_size=64518,
+         d_model=768,
+         d_kv=64,
+         d_ff=3072,
+         num_layers=6,
+         num_decoder_layers=None,
+         num_heads=8,
+         relative_attention_num_buckets=32,
+         relative_attention_max_distance=128,
+         max_positions=1024,
+         dropout_rate=0.1,
+         layer_norm_epsilon=1e-6,
+         initializer_factor=1.0,
+         feed_forward_proj="relu",
+         is_encoder_decoder=True,
+         use_cache=True,
+         pad_token_id=1,
+         eos_token_id=2,
+         **kwargs
+     ):
+         self.vocab_size = vocab_size
+         self.d_model = d_model
+         self.d_kv = d_kv
+         self.d_ff = d_ff
+         self.num_layers = num_layers
+         self.num_decoder_layers = (
+             num_decoder_layers if num_decoder_layers is not None else self.num_layers
+         )  # default = symmetry
+         self.num_heads = num_heads
+         self.relative_attention_num_buckets = relative_attention_num_buckets
+         self.relative_attention_max_distance = relative_attention_max_distance
+         self.max_positions = max_positions
+         self.dropout_rate = dropout_rate
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.initializer_factor = initializer_factor
+         self.feed_forward_proj = feed_forward_proj
+         self.use_cache = use_cache
+         super().__init__(
+             pad_token_id=pad_token_id,
+             eos_token_id=eos_token_id,
+             is_encoder_decoder=is_encoder_decoder,
+             **kwargs,
+         )
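
Note that the defaults above describe a smaller 6-layer, 8-head model; the uploaded checkpoint overrides several of them (compare `config.json`). A sketch of building the checkpoint's configuration by hand, passing only the values that differ from the defaults:

```python
from configuration_fairseq_t5 import FairseqT5Config

# Overrides mirror the committed config.json; everything else keeps the
# defaults from FairseqT5Config.__init__ (d_model=768, d_ff=3072, ...).
cfg = FairseqT5Config(
    vocab_size=101265,
    num_layers=12,
    num_decoder_layers=12,
    num_heads=12,
    relative_attention_num_buckets=128,
    layer_norm_epsilon=1e-5,
    decoder_start_token_id=2,  # forwarded to PretrainedConfig via **kwargs
)
assert cfg.hidden_size == 768  # attribute_map aliases hidden_size -> d_model
```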
generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "_from_model_config": true,
+   "decoder_start_token_id": 2,
+   "eos_token_id": 2,
+   "pad_token_id": 1,
+   "transformers_version": "4.37.2"
+ }
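
These defaults are picked up by `generate()` automatically: decoding starts from token id 2, stops at the same id (EOS), and pads with id 1. A minimal sketch, assuming `model` was loaded as above and the input ids come from a tokenizer matching this vocabulary (the tokenizer is not part of this commit):

```python
import torch

input_ids = torch.tensor([[5, 6, 7, 2]])  # hypothetical token ids ending in EOS (2)
out = model.generate(input_ids, max_new_tokens=20)
# decoder_start_token_id, eos_token_id and pad_token_id are read from
# generation_config.json, so no extra arguments are needed.
```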
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2bb954ff00b70d43f5b622a2745253e61fb49eadebba3649b0b63320e637565f
+ size 1111439856
modeling_fairseq_t5.py ADDED
@@ -0,0 +1,1585 @@
+ import copy
+ import math
+ from typing import Dict, Optional
+
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+ from torch.utils.checkpoint import checkpoint
+ from transformers.activations import ACT2FN
+ from transformers.file_utils import DUMMY_INPUTS, DUMMY_MASK, is_torch_fx_proxy
+ from transformers.modeling_outputs import (
+     BaseModelOutput,
+     BaseModelOutputWithPastAndCrossAttentions,
+     Seq2SeqLMOutput,
+     Seq2SeqModelOutput,
+ )
+ from transformers.modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+ from transformers.utils import logging
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
+
+ from .configuration_fairseq_t5 import FairseqT5Config
+
+ logger = logging.get_logger(__name__)
+
+
+ def make_positions(tensor, padding_idx: int, onnx_trace: bool = False):
+     """Replace non-padding symbols with their position numbers.
+
+     Position numbers begin at padding_idx+1. Padding symbols are ignored.
+     """
+     # The series of casts and type-conversions here are carefully
+     # balanced to both work with ONNX export and XLA. In particular XLA
+     # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
+     # how to handle the dtype kwarg in cumsum.
+     mask = tensor.ne(padding_idx).int()
+     return (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx
+
+
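
A quick worked example of `make_positions` (values computed by hand, using `padding_idx=1` as in the config):

```python
import torch

# One sequence with three real tokens and two pads (pad id = 1).
tokens = torch.tensor([[5, 6, 7, 1, 1]])
# mask = [1, 1, 1, 0, 0]; cumsum = [1, 2, 3, 3, 3]; masked -> [1, 2, 3, 0, 0];
# adding padding_idx yields [2, 3, 4, 1, 1]: real tokens count up from
# padding_idx + 1, while padding slots collapse back to padding_idx.
print(make_positions(tokens, padding_idx=1))  # tensor([[2, 3, 4, 1, 1]])
```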
+ class LearnedPositionalEmbedding(nn.Embedding):
+     """
+     This module learns positional embeddings up to a fixed maximum size.
+     Padding ids are ignored by either offsetting based on padding_idx
+     or by setting padding_idx to None and ensuring that the appropriate
+     position ids are passed to the forward function.
+     """
+
+     def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
+         super().__init__(num_embeddings, embedding_dim, padding_idx)
+         self.onnx_trace = False
+         if self.padding_idx is not None:
+             self.max_positions = self.num_embeddings - self.padding_idx - 1
+         else:
+             self.max_positions = self.num_embeddings
+
+     def forward(
+         self,
+         input: Tensor,
+         incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+         positions: Optional[Tensor] = None,
+         offset=0,
+     ):
+         """Input is expected to be of size [bsz x seqlen]."""
+         assert (positions is None) or (
+             self.padding_idx is None
+         ), "If positions is pre-computed then padding_idx should not be set."
+
+         if positions is None:
+             if incremental_state is not None:
+                 # positions is the same for every token when decoding a single step
+                 # Without the int() cast, it doesn't work in some cases when exporting to ONNX
+                 positions = torch.zeros(
+                     (1, 1), device=input.device, dtype=input.dtype
+                 ).fill_(int(self.padding_idx + input.size(1)))
+             else:
+                 positions = make_positions(
+                     input, self.padding_idx, onnx_trace=self.onnx_trace
+                 )
+         if offset > 0 and positions.size(1) == 1:
+             positions = positions + offset
+         return nn.functional.embedding(
+             positions,
+             self.weight,
+             self.padding_idx,
+             self.max_norm,
+             self.norm_type,
+             self.scale_grad_by_freq,
+             self.sparse,
+         )
+
+
+ def PositionalEmbedding(
+     num_embeddings: int,
+     embedding_dim: int,
+     padding_idx: int,
+ ):
+     # if padding_idx is specified then offset the embedding ids by
+     # this index and adjust num_embeddings appropriately
+     # TODO: The right place for this offset would be inside
+     # LearnedPositionalEmbedding. Move this there for a cleaner implementation.
+     if padding_idx is not None:
+         num_embeddings = num_embeddings + padding_idx + 1
+     m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx)
+     nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
+     if padding_idx is not None:
+         nn.init.constant_(m.weight[padding_idx], 0)
+     return m
+
+
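
A small sketch of how the padding offset plays out with the checkpoint's settings (`padding_idx=1`, 1024 positions): the table gains `padding_idx + 1` extra rows so that real positions start after the reserved padding slot.

```python
# 1024 requested positions + padding_idx + 1 = 1026 embedding rows;
# row 1 is the zeroed padding embedding, rows 2..1025 hold learned positions.
emb = PositionalEmbedding(num_embeddings=1024, embedding_dim=768, padding_idx=1)
print(emb.weight.shape)   # torch.Size([1026, 768])
print(emb.max_positions)  # 1024
```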
+ class T5LayerNorm(nn.Module):
+     def __init__(self, hidden_size, eps=1e-5):
+         """
+         Construct a layer norm module in the T5 style: activations are rescaled
+         by their root mean square, with no subtraction of the mean. Unlike the
+         original T5 layer norm, this variant also carries a learned bias.
+         """
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(hidden_size))
+         self.bias = nn.Parameter(torch.ones(hidden_size))
+         self.variance_epsilon = eps
+
+     def forward(self, hidden_states):
+         # layer norm should always be calculated in float32
+         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+         # convert into half-precision if necessary
+         if self.weight.dtype in [torch.float16, torch.bfloat16]:
+             hidden_states = hidden_states.to(self.weight.dtype)
+
+         return self.weight * hidden_states + self.bias
+
+
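
In effect the module computes `y = weight * x / sqrt(mean(x**2) + eps) + bias`. A quick numerical sanity check (illustrative only):

```python
import torch

ln = T5LayerNorm(hidden_size=4, eps=0.0)
x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
rms = x.pow(2).mean(-1, keepdim=True).sqrt()
# weight and bias are both initialized to ones, so the output is x / rms + 1.
assert torch.allclose(ln(x), x / rms + 1.0)
```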
+ def FST5LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
+     return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
+
+
+ class T5DenseReluDense(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         if_bias = True
+         self.wi = nn.Linear(config.d_model, config.d_ff, bias=if_bias)
+         self.wo = nn.Linear(config.d_ff, config.d_model, bias=if_bias)
+         self.dropout = nn.Dropout(config.dropout_rate)
+
+     def forward(self, hidden_states):
+         hidden_states = self.wi(hidden_states)
+         hidden_states = nn.functional.relu(hidden_states)
+         hidden_states = self.dropout(hidden_states)
+         hidden_states = self.wo(hidden_states)
+         return hidden_states
+
+
+ class T5DenseGatedGeluDense(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
+         self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
+         self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
+         self.dropout = nn.Dropout(config.dropout_rate)
+         self.gelu_act = ACT2FN["gelu_new"]
+
+     def forward(self, hidden_states):
+         hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
+         hidden_linear = self.wi_1(hidden_states)
+         hidden_states = hidden_gelu * hidden_linear
+         hidden_states = self.dropout(hidden_states)
+         hidden_states = self.wo(hidden_states)
+         return hidden_states
+
+
+ class T5LayerFF(nn.Module):
+     def __init__(self, config, normalize_before=False):
+         super().__init__()
+         if config.feed_forward_proj == "relu":
+             self.DenseReluDense = T5DenseReluDense(config)
+         elif config.feed_forward_proj == "gated-gelu":
+             self.DenseReluDense = T5DenseGatedGeluDense(config)
+         else:
+             # Note: this originally read `self.config.feed_forward_proj`, which
+             # would raise AttributeError here since `self.config` is never set.
+             raise ValueError(
+                 f"{config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`"
+             )
+
+         self.layer_norm = FST5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+         self.dropout = nn.Dropout(config.dropout_rate)
+
+         self.normalize_before = normalize_before
+
+     def forward(self, hidden_states):
+         if self.normalize_before:
+             forwarded_states = self.layer_norm(hidden_states)
+         else:
+             forwarded_states = hidden_states
+         forwarded_states = self.DenseReluDense(forwarded_states)
+         hidden_states = hidden_states + self.dropout(forwarded_states)
+
+         if not self.normalize_before:
+             hidden_states = self.layer_norm(hidden_states)
+         return hidden_states
+
+
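
The `normalize_before` flag is fairseq's usual pre-norm/post-norm switch: pre-norm normalizes the input to the feed-forward block, post-norm normalizes after the residual add. A shape-level sketch with a hypothetical stand-in config (dropout disabled for determinism):

```python
import torch

class _Cfg:  # hypothetical stand-in for FairseqT5Config, for shape checks only
    d_model, d_ff, dropout_rate = 768, 3072, 0.0
    feed_forward_proj = "relu"
    layer_norm_epsilon = 1e-5

ff_post = T5LayerFF(_Cfg(), normalize_before=False)  # residual add, then LayerNorm
ff_pre = T5LayerFF(_Cfg(), normalize_before=True)    # LayerNorm, then residual add
x = torch.randn(2, 5, 768)
print(ff_post(x).shape, ff_pre(x).shape)  # both torch.Size([2, 5, 768])
```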
+ class T5Attention(nn.Module):
+     def __init__(self, config: FairseqT5Config, has_relative_attention_bias=False):
+         super().__init__()
+         self.is_decoder = config.is_decoder
+         self.has_relative_attention_bias = has_relative_attention_bias
+
+         self.relative_attention_num_buckets = config.relative_attention_num_buckets
+         self.relative_attention_max_distance = config.relative_attention_max_distance
+         self.max_positions = config.max_positions
+         self.d_model = config.d_model
+         self.key_value_proj_dim = config.d_kv
+         self.n_heads = config.num_heads
+         self.dropout = config.dropout_rate
+         self.inner_dim = self.n_heads * self.key_value_proj_dim
+
+         # Unlike HF's T5 (which folds the scaling into the initialization),
+         # this fairseq-style variant uses biased projections and scales the
+         # queries explicitly by head_dim ** -0.5 (see `self.scaling` below).
+         if_bias = True
+         self.q = nn.Linear(self.d_model, self.inner_dim, bias=if_bias)
+         self.k = nn.Linear(self.d_model, self.inner_dim, bias=if_bias)
+         self.v = nn.Linear(self.d_model, self.inner_dim, bias=if_bias)
+         self.o = nn.Linear(self.inner_dim, self.d_model, bias=if_bias)
+
+         if self.has_relative_attention_bias:
+             self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
+         self.pruned_heads = set()
+         self.gradient_checkpointing = getattr(config, "gradient_checkpointing", False)
+
+         # Precompute the relative-position bucket table, fairseq style: a
+         # (max_positions x max_positions) matrix of signed relative positions
+         # mapped into buckets and shifted to be non-negative.
+         relative_position = (
+             torch.arange(self.max_positions, dtype=torch.long)[None, :]
+             - torch.arange(self.max_positions, dtype=torch.long)[:, None]
+         )
+         self.rp_bucket = self.relative_position_bucket(
+             relative_position,
+             num_buckets=self.relative_attention_num_buckets,
+             max_distance=self.relative_attention_max_distance
+         )
+         self.rp_bucket -= self.rp_bucket.min()
+
+         self.head_dim = self.d_model // self.n_heads
+         self.scaling = self.head_dim ** -0.5
+
+     def prune_heads(self, heads):
+         if len(heads) == 0:
+             return
+         heads, index = find_pruneable_heads_and_indices(
+             heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
+         )
+         # Prune linear layers
+         self.q = prune_linear_layer(self.q, index)
+         self.k = prune_linear_layer(self.k, index)
+         self.v = prune_linear_layer(self.v, index)
+         self.o = prune_linear_layer(self.o, index, dim=1)
+         # Update hyper params
+         self.n_heads = self.n_heads - len(heads)
+         self.inner_dim = self.key_value_proj_dim * self.n_heads
+         self.pruned_heads = self.pruned_heads.union(heads)
+
+     @staticmethod
+     def relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
+         sign = torch.sign(relative_position)
+         num_buckets //= 2
+         n = torch.abs(relative_position)
+
+         # half of the buckets are for exact increments in positions
+         max_exact = num_buckets // 2
+         is_small = n < max_exact
+         max_bucket_val = num_buckets - 1 - max_exact
+         # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
+         val_if_large = max_exact + torch.ceil(
+             torch.log(n.float() / max_exact)
+             / math.log((max_distance - 1) / max_exact)
+             * max_bucket_val
+         ).long()
+         val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
+         ret = torch.where(is_small, n, val_if_large) * sign
+         return ret
+
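
An illustration of the signed bucketing with the checkpoint's settings (`num_buckets=128`, `max_distance=128`): offsets with |n| < 32 each get their own bucket, larger offsets share logarithmically sized buckets, and the sign survives until `__init__` shifts the whole table to be non-negative.

```python
import torch

rel = torch.tensor([-64, -8, -1, 0, 1, 8, 64])
buckets = T5Attention.relative_position_bucket(rel, num_buckets=128, max_distance=128)
# Small offsets map to themselves (+/- n); |64| lands in shared bucket 48.
print(buckets)  # tensor([-48, -8, -1, 0, 1, 8, 48])
```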
+     def compute_bias(self, query_length, key_length):
+         relative_position_bucket = self.rp_bucket[:query_length, :key_length]
+         relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
+         values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
+         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
+         return values
+
+     def forward(
+         self,
+         hidden_states,
+         mask=None,
+         key_value_states=None,
+         position_bias=None,
+         past_key_value=None,
+         layer_head_mask=None,
+         query_length=None,
+         use_cache=False,
+         output_attentions=False,
+         key_padding_mask=None,
+     ):
+         """
+         Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
+         """
+         # Input is (batch_size, seq_length, dim)
+         # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
+         # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
+         batch_size, seq_length = hidden_states.shape[:2]
+
+         int_seq_length = int(seq_length)
+
+         real_seq_length = seq_length
+
+         if past_key_value is not None:
+             assert (
+                 len(past_key_value) == 2
+             ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
+             real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
+
+         key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
+
+         def shape(states):
+             """projection"""
+             return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+
+         def unshape(states):
+             """reshape"""
+             return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
+
+         def project(hidden_states, proj_layer, key_value_states, past_key_value):
+             """projects hidden states correctly to key/query states"""
+             if key_value_states is None:
+                 # self-attn
+                 # (batch_size, n_heads, seq_length, dim_per_head)
+                 hidden_states = shape(proj_layer(hidden_states))
+             elif past_key_value is None:
+                 # cross-attn
+                 # (batch_size, n_heads, seq_length, dim_per_head)
+                 hidden_states = shape(proj_layer(key_value_states))
+
+             if past_key_value is not None:
+                 if key_value_states is None:
+                     # self-attn
+                     # (batch_size, n_heads, key_length, dim_per_head)
+                     hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
+                 else:
+                     # cross-attn
+                     hidden_states = past_key_value
+             return hidden_states
+
+         # get query states
+         query_states = shape(self.q(hidden_states)) * self.scaling  # (batch_size, n_heads, seq_length, dim_per_head)
+
+         # get key/value states
+         key_states = project(
+             hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
+         )
+         value_states = project(
+             hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
+         )
+
+         # compute scores
+         scores = torch.matmul(
+             query_states, key_states.transpose(3, 2)
+         )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
+
+         if position_bias is None:
+             if not self.has_relative_attention_bias:
+                 position_bias = torch.zeros(
+                     (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
+                 )
+                 if self.gradient_checkpointing and self.training:
+                     position_bias.requires_grad = True
+             else:
+                 position_bias = self.compute_bias(real_seq_length, key_length)
+
+             # if key and values are already calculated
+             # we want only the last query position bias
+             if past_key_value is not None:
+                 position_bias = position_bias[:, :, -int_seq_length:, :]
+
+             if mask is not None:
+                 position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)
+
+         scores += position_bias
+
+         if key_padding_mask is not None:
+             scores = scores.masked_fill(
+                 key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
+                 float("-inf"),
+             )
+
+         attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
+             scores
+         )  # (batch_size, n_heads, seq_length, key_length)
+         attn_weights = nn.functional.dropout(
+             attn_weights, p=self.dropout, training=self.training
+         )  # (batch_size, n_heads, seq_length, key_length)
+
+         # Mask heads if we want to
+         if layer_head_mask is not None:
+             attn_weights = attn_weights * layer_head_mask
+
+         attn_output = unshape(torch.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim)
+         attn_output = self.o(attn_output)
+
+         present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
+         outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
+
+         if output_attentions:
+             outputs = outputs + (attn_weights,)
+         return outputs
+
+
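
A hedged smoke test of the attention block's output contract (self-attention with relative bias, checkpoint dimensions, hypothetical minimal config):

```python
import torch

class _AttnCfg:  # hypothetical minimal config, for a standalone check only
    is_decoder = False
    relative_attention_num_buckets = 128
    relative_attention_max_distance = 128
    max_positions = 1024
    d_model, d_kv, num_heads = 768, 64, 12
    dropout_rate = 0.0

attn = T5Attention(_AttnCfg(), has_relative_attention_bias=True)
x = torch.randn(2, 10, 768)
attn_output, present_kv, position_bias = attn(x)  # 3-tuple when output_attentions=False
print(attn_output.shape)    # torch.Size([2, 10, 768])
print(present_kv)           # None: caching only applies to decoder self-attention
print(position_bias.shape)  # torch.Size([1, 12, 10, 10])
```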
+ class T5LayerSelfAttention(nn.Module):
+     def __init__(self, config, has_relative_attention_bias=False, normalize_before=False):
+         super().__init__()
+         self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
+         # self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+         self.layer_norm = FST5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+         self.dropout = nn.Dropout(config.dropout_rate)
+         self.normalize_before = normalize_before
+         self.has_relative_attention_bias = has_relative_attention_bias
+
+     def forward(
+         self,
+         hidden_states,
+         attention_mask=None,
+         position_bias=None,
+         layer_head_mask=None,
+         past_key_value=None,
+         use_cache=False,
+         output_attentions=False,
+         key_padding_mask=None,
+     ):
+         if self.normalize_before:
+             normed_hidden_states = self.layer_norm(hidden_states)
+         else:
+             normed_hidden_states = hidden_states
+
+         attention_output = self.SelfAttention(
+             normed_hidden_states,
+             mask=attention_mask,
+             position_bias=position_bias,
+             layer_head_mask=layer_head_mask,
+             past_key_value=past_key_value,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             key_padding_mask=key_padding_mask,
+         )
+         hidden_states = hidden_states + self.dropout(attention_output[0])
+
+         if not self.normalize_before:
+             hidden_states = self.layer_norm(hidden_states)
+
+         outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
+         return outputs
+
+
+ class T5LayerCrossAttention(nn.Module):
+     def __init__(self, config, normalize_before=False):
+         super().__init__()
+         self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False)
+         # self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+         self.layer_norm = FST5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+         self.dropout = nn.Dropout(config.dropout_rate)
+
+         self.normalize_before = normalize_before
+
+     def forward(
+         self,
+         hidden_states,
+         key_value_states,
+         attention_mask=None,
+         position_bias=None,
+         layer_head_mask=None,
+         past_key_value=None,
+         use_cache=False,
+         query_length=None,
+         output_attentions=False,
+     ):
+         if self.normalize_before:
+             normed_hidden_states = self.layer_norm(hidden_states)
+         else:
+             normed_hidden_states = hidden_states
+
+         attention_output = self.EncDecAttention(
+             normed_hidden_states,
+             mask=attention_mask,
+             key_value_states=key_value_states,
+             position_bias=position_bias,
+             layer_head_mask=layer_head_mask,
+             past_key_value=past_key_value,
+             use_cache=use_cache,
+             query_length=query_length,
+             output_attentions=output_attentions,
+         )