hackyon commited on
Commit
c88c1b7
1 Parent(s): 70b5857

Upload EncT5ForSequenceClassification

Browse files
Files changed (5) hide show
  1. README.md +201 -0
  2. config.json +68 -0
  3. configuration_enct5.py +136 -0
  4. model.safetensors +3 -0
  5. modeling_enct5.py +324 -0
README.md ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
200
+
201
+
config.json ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/Users/hckyn/Desktop/enct5-glue-sst2/",
3
+ "architectures": [
4
+ "EncT5ForSequenceClassification"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_enct5.EncT5Config",
8
+ "AutoModelForSequenceClassification": "modeling_enct5.EncT5ForSequenceClassification"
9
+ },
10
+ "classifier_dropout": 0.0,
11
+ "d_ff": 3072,
12
+ "d_kv": 64,
13
+ "d_model": 768,
14
+ "decoder_start_token_id": 0,
15
+ "decoder_vocab_size": 1,
16
+ "dense_act_fn": "relu",
17
+ "dropout_rate": 0.1,
18
+ "eos_token_id": 1,
19
+ "feed_forward_proj": "relu",
20
+ "initializer_factor": 1.0,
21
+ "is_encoder_decoder": true,
22
+ "is_gated_act": false,
23
+ "layer_norm_epsilon": 1e-06,
24
+ "model_type": "enct5",
25
+ "n_positions": 512,
26
+ "num_decoder_layers": 1,
27
+ "num_heads": 12,
28
+ "num_layers": 12,
29
+ "output_past": true,
30
+ "pad_token_id": 0,
31
+ "problem_type": "single_label_classification",
32
+ "relative_attention_max_distance": 128,
33
+ "relative_attention_num_buckets": 32,
34
+ "task_specific_params": {
35
+ "summarization": {
36
+ "early_stopping": true,
37
+ "length_penalty": 2.0,
38
+ "max_length": 200,
39
+ "min_length": 30,
40
+ "no_repeat_ngram_size": 3,
41
+ "num_beams": 4,
42
+ "prefix": "summarize: "
43
+ },
44
+ "translation_en_to_de": {
45
+ "early_stopping": true,
46
+ "max_length": 300,
47
+ "num_beams": 4,
48
+ "prefix": "translate English to German: "
49
+ },
50
+ "translation_en_to_fr": {
51
+ "early_stopping": true,
52
+ "max_length": 300,
53
+ "num_beams": 4,
54
+ "prefix": "translate English to French: "
55
+ },
56
+ "translation_en_to_ro": {
57
+ "early_stopping": true,
58
+ "max_length": 300,
59
+ "num_beams": 4,
60
+ "prefix": "translate English to Romanian: "
61
+ }
62
+ },
63
+ "tie_word_embeddings": false,
64
+ "torch_dtype": "float32",
65
+ "transformers_version": "4.37.1",
66
+ "use_cache": true,
67
+ "vocab_size": 32128
68
+ }
configuration_enct5.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """ EncT5 model configuration"""
15
+
16
+ from transformers.configuration_utils import PretrainedConfig
17
+
18
+
19
+ class EncT5Config(PretrainedConfig):
20
+ r"""
21
+ This is the configuration class to store the configuration of a [`EncT5`]. It is used to instantiate a EncT5 model
22
+ according to the specified arguments, defining the model architecture. Instantiating a configuration with the
23
+ defaults will yield a similar configuration to that of the T5 [t5-small](https://huggingface.co/t5-small)
24
+ architecture.
25
+
26
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
27
+ documentation from [`PretrainedConfig`] for more information.
28
+
29
+ Arguments:
30
+ vocab_size (`int`, *optional*, defaults to 32128):
31
+ Vocabulary size of the EncT5 model. Defines the number of different tokens that can be represented by the
32
+ `inputs_ids` passed when calling [`T5Model`] or [`TFT5Model`].
33
+ decoder_vocab_size (`int`, *optional*, defaults to 1):
34
+ Decoder vocabulary size of the EncT5 model. For regression and single-label classification, this should just
35
+ be 1 (the default). For multi-label classification, this should be the number of labels.
36
+ d_model (`int`, *optional*, defaults to 512):
37
+ Size of the encoder layers and the pooler layer.
38
+ d_kv (`int`, *optional*, defaults to 64):
39
+ Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will
40
+ be defined as `num_heads * d_kv`.
41
+ d_ff (`int`, *optional*, defaults to 2048):
42
+ Size of the intermediate feed forward layer in each `T5Block`.
43
+ num_layers (`int`, *optional*, defaults to 6):
44
+ Number of hidden layers in the Transformer encoder.
45
+ num_decoder_layers (`int`, *optional*, defaults to 1):
46
+ Number of hidden layers in the Transformer decoder.
47
+ num_heads (`int`, *optional*, defaults to 8):
48
+ Number of attention heads for each attention layer in the Transformer encoder.
49
+ relative_attention_num_buckets (`int`, *optional*, defaults to 32):
50
+ The number of buckets to use for each attention layer.
51
+ relative_attention_max_distance (`int`, *optional*, defaults to 128):
52
+ The maximum distance of the longer sequences for the bucket separation.
53
+ dropout_rate (`float`, *optional*, defaults to 0.1):
54
+ The ratio for all dropout layers.
55
+ classifier_dropout (`float`, *optional*, defaults to 0.0):
56
+ The dropout ratio for classifier.
57
+ layer_norm_eps (`float`, *optional*, defaults to 1e-6):
58
+ The epsilon used by the layer normalization layers.
59
+ initializer_factor (`float`, *optional*, defaults to 1):
60
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
61
+ testing).
62
+ feed_forward_proj (`string`, *optional*, defaults to `"relu"`):
63
+ Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. T5v1.1 uses the
64
+ `"gated-gelu"` feed forward projection. Original T5 uses `"relu"`.
65
+ use_cache (`bool`, *optional*, defaults to `True`):
66
+ Whether or not the model should return the last key/values attentions (not used by all models).
67
+ """
68
+
69
+ model_type = "enct5"
70
+ keys_to_ignore_at_inference = ["past_key_values"]
71
+ attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
72
+
73
+ def __init__(
74
+ self,
75
+ vocab_size=32128,
76
+ decoder_vocab_size=1,
77
+ d_model=512,
78
+ d_kv=64,
79
+ d_ff=2048,
80
+ num_layers=6,
81
+ num_decoder_layers=1,
82
+ num_heads=8,
83
+ relative_attention_num_buckets=32,
84
+ relative_attention_max_distance=128,
85
+ dropout_rate=0.1,
86
+ layer_norm_epsilon=1e-6,
87
+ initializer_factor=1.0,
88
+ feed_forward_proj="relu",
89
+ is_encoder_decoder=True,
90
+ use_cache=True,
91
+ pad_token_id=0,
92
+ eos_token_id=1,
93
+ classifier_dropout=0.0,
94
+ **kwargs,
95
+ ):
96
+ self.vocab_size = vocab_size
97
+ self.decoder_vocab_size = decoder_vocab_size
98
+ self.d_model = d_model
99
+ self.d_kv = d_kv
100
+ self.d_ff = d_ff
101
+ self.num_layers = num_layers
102
+ self.num_decoder_layers = num_decoder_layers
103
+ self.num_heads = num_heads
104
+ self.relative_attention_num_buckets = relative_attention_num_buckets
105
+ self.relative_attention_max_distance = relative_attention_max_distance
106
+ self.dropout_rate = dropout_rate
107
+ self.classifier_dropout = classifier_dropout
108
+ self.layer_norm_epsilon = layer_norm_epsilon
109
+ self.initializer_factor = initializer_factor
110
+ self.feed_forward_proj = feed_forward_proj
111
+ self.use_cache = use_cache
112
+
113
+ act_info = self.feed_forward_proj.split("-")
114
+ self.dense_act_fn = act_info[-1]
115
+ self.is_gated_act = act_info[0] == "gated"
116
+
117
+ if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
118
+ raise ValueError(
119
+ f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. "
120
+ "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
121
+ "'gated-gelu' or 'relu'"
122
+ )
123
+
124
+ # for backwards compatibility
125
+ if feed_forward_proj == "gated-gelu":
126
+ self.dense_act_fn = "gelu_new"
127
+
128
+ super().__init__(
129
+ pad_token_id=pad_token_id,
130
+ eos_token_id=eos_token_id,
131
+ is_encoder_decoder=is_encoder_decoder,
132
+ **kwargs,
133
+ )
134
+
135
+ # Override the default behavior to tie word embeddings.
136
+ self.tie_word_embeddings = False
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:84f94588b91759227e5d958ff472fb38b2e24374a766f4d6402928414b631031
3
+ size 476301088
modeling_enct5.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ """ EncT5 model (based on HuggingFace T5 Model) """
14
+
15
+ from typing import Optional, List, Tuple, Union
16
+
17
+ import torch
18
+ from torch import nn
19
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
20
+ from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Model
21
+ from transformers.modeling_outputs import Seq2SeqSequenceClassifierOutput
22
+
23
+ from .configuration_enct5 import EncT5Config
24
+
25
+
26
+ class EncT5ClassificationHead(nn.Module):
27
+ """Head for sentence-level classification tasks."""
28
+
29
+ def __init__(self, config: EncT5Config):
30
+ super().__init__()
31
+ self.dropout = nn.Dropout(p=config.classifier_dropout)
32
+ self.out_proj = nn.Linear(config.d_model, config.num_labels)
33
+
34
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
35
+ hidden_states = self.dropout(hidden_states)
36
+ hidden_states = self.out_proj(hidden_states)
37
+ return hidden_states
38
+
39
+
40
+ class EncT5MultiLabelClassificationHead(nn.Module):
41
+ """Head for multi-label sentence-level classification tasks."""
42
+
43
+ def __init__(self, config: EncT5Config):
44
+ super().__init__()
45
+ self.weights = nn.Parameter(torch.Tensor(config.num_labels, config.d_model))
46
+ self.biases = nn.Parameter(torch.Tensor(config.num_labels))
47
+ self.dropout = nn.Dropout(p=config.classifier_dropout)
48
+
49
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
50
+ # The input hidden_states shape should be (batch_size, num_labels, d_model)
51
+ hidden_states = self.dropout(hidden_states)
52
+ # The following element-wise multiplication simulates multiple per-label classification heads (one head per
53
+ # label). The element-wise multiplication of the weights, followed by a summation and addition of biases, is
54
+ # equivalent to a linear projection from d_model down to 1 for each label (but with vectorization).
55
+ hidden_states = torch.sum(hidden_states * self.weights, dim=-1) + self.biases # (batch_size, num_labels)
56
+ return hidden_states
57
+
58
+
59
+ class EncT5PreTrainedModel(T5PreTrainedModel):
60
+ def _init_weights(self, module):
61
+ """Initialize the weights"""
62
+ factor = self.config.initializer_factor # Used for testing weights initialization
63
+ if isinstance(module, EncT5ClassificationHead):
64
+ module.out_proj.weight.data.normal_(mean=0.0, std=factor * (self.config.d_model ** -0.5))
65
+ if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None:
66
+ module.out_proj.bias.data.zero_()
67
+ elif isinstance(module, EncT5MultiLabelClassificationHead):
68
+ module.weights.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
69
+ module.biases.data.zero_()
70
+ super()._init_weights(module)
71
+
72
+
73
+ class EncT5ForSequenceClassification(EncT5PreTrainedModel):
74
+ r"""
75
+ The EncT5 model was proposed in [EncT5: A Framework for Fine-tuning T5 as Non-autoregressive
76
+ Models](https://arxiv.org/abs/2110.08426) by Frederick Liu, Terry Huang, Shihang Lyu, Siamak Shakeri, Hongkun Yu,
77
+ Jing Li.
78
+
79
+ EncT5 is a variant of T5 that uses mainly the encoder for non-autoregressive tasks. There are several special
80
+ features to EncT5: 1) there are less decoder layers (defaulting to 1 decoder layer), 2) there is a separate decoder
81
+ word embedding, with the decoder input ids being predefined constants, and 3) there is a classification head on top
82
+ of the output. Research has shown that this model can be more efficient and usable over T5 and BERT for
83
+ non-autoregressive tasks such as classification and regression.
84
+ """
85
+ config_class = EncT5Config
86
+ _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
87
+
88
+ def __init__(self, config: EncT5Config):
89
+ super().__init__(config)
90
+
91
+ # Initialize the base T5 model.
92
+ self.transformer = T5Model(T5Config.from_dict(config.to_dict()))
93
+
94
+ # Initiate decoder embedding from scratch and define the corresponding latent vector vocabulary size.
95
+ self.decoder_embeddings = nn.Embedding(config.decoder_vocab_size, config.d_model)
96
+ self.transformer.get_decoder().set_input_embeddings(self.decoder_embeddings)
97
+
98
+ # Initiate decoder projection head from scratch.
99
+ if config.problem_type == "multi_label_classification":
100
+ self.classification_head = EncT5MultiLabelClassificationHead(config)
101
+ else:
102
+ self.classification_head = EncT5ClassificationHead(config)
103
+
104
+ # Initialize weights and apply final processing
105
+ self.post_init()
106
+
107
+ self.model_parallel = False
108
+
109
+ def load_weights_from_pretrained_t5(self, model_path: str):
110
+ pretrained_t5_model = T5Model.from_pretrained(model_path)
111
+
112
+ # Override the decoder embedding weights to make them the correct shape.
113
+ pretrained_state_dict = pretrained_t5_model.state_dict()
114
+ pretrained_state_dict["decoder.embed_tokens.weight"] = self.decoder_embeddings.state_dict()["weight"]
115
+
116
+ self.transformer.load_state_dict(pretrained_state_dict, strict=False)
117
+
118
+ def prepare_for_fine_tuning(self):
119
+ r"""
120
+ Prepares the model for fine-tuning by re-initializing the necessary weights for fine-tuning. This step should be
121
+ performed after loading the pre-trained T5 model but before fine-tuning.
122
+ """
123
+ self.transformer.get_decoder().apply(self._init_weights)
124
+ self._init_weights(self.classification_head)
125
+
126
+ def forward(
127
+ self,
128
+ input_ids: Optional[torch.LongTensor] = None,
129
+ attention_mask: Optional[torch.Tensor] = None,
130
+ decoder_input_ids: Optional[torch.LongTensor] = None,
131
+ decoder_attention_mask: Optional[torch.LongTensor] = None,
132
+ head_mask: Optional[torch.Tensor] = None,
133
+ decoder_head_mask: Optional[torch.Tensor] = None,
134
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
135
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
136
+ inputs_embeds: Optional[torch.FloatTensor] = None,
137
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
138
+ labels: Optional[torch.LongTensor] = None,
139
+ use_cache: Optional[bool] = None,
140
+ output_attentions: Optional[bool] = None,
141
+ output_hidden_states: Optional[bool] = None,
142
+ return_dict: Optional[bool] = None,
143
+ ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
144
+ r"""
145
+ Arguments:
146
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
147
+ Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so
148
+ you should be able to pad the inputs on both the right and the left.
149
+
150
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
151
+ [`PreTrainedTokenizer.__call__`] for detail.
152
+
153
+ [What are input IDs?](../glossary#input-ids)
154
+
155
+ To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training).
156
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
157
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
158
+
159
+ - 1 for tokens that are **not masked**,
160
+ - 0 for tokens that are **masked**.
161
+
162
+ [What are attention masks?](../glossary#attention-mask)
163
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
164
+ Indices of decoder input sequence tokens in the vocabulary.
165
+
166
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
167
+ [`PreTrainedTokenizer.__call__`] for details.
168
+
169
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
170
+
171
+ T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
172
+ `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
173
+ `past_key_values`).
174
+
175
+ To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
176
+ Training](./t5#training).
177
+ decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
178
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will
179
+ also be used by default.
180
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
181
+ Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in
182
+ `[0, 1]`:
183
+
184
+ - 1 indicates the head is **not masked**,
185
+ - 0 indicates the head is **masked**.
186
+
187
+ decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
188
+ Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in
189
+ `[0, 1]`:
190
+
191
+ - 1 indicates the head is **not masked**,
192
+ - 0 indicates the head is **masked**.
193
+
194
+ cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
195
+ Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected
196
+ in `[0, 1]`:
197
+
198
+ - 1 indicates the head is **not masked**,
199
+ - 0 indicates the head is **masked**.
200
+
201
+ encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
202
+ Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*)
203
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states
204
+ at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
205
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4
206
+ tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
207
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up
208
+ decoding.
209
+
210
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
211
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
212
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
213
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
214
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
215
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
216
+ than the model's internal embedding lookup matrix.
217
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
218
+ Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
219
+ representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to
220
+ be input (see `past_key_values`). This is useful if you want more control over how to convert
221
+ `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
222
+
223
+ If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the
224
+ value of `inputs_embeds`.
225
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
226
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
227
+ config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
228
+ use_cache (`bool`, *optional*):
229
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
230
+ (see `past_key_values`).
231
+ output_attentions (`bool`, *optional*):
232
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
233
+ tensors for more detail.
234
+ output_hidden_states (`bool`, *optional*):
235
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
236
+ more detail.
237
+ return_dict (`bool`, *optional*):
238
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
239
+ Returns:
240
+ """
241
+
242
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
243
+ if labels is not None:
244
+ use_cache = False
245
+
246
+ if input_ids is None and inputs_embeds is None:
247
+ raise ValueError("You have to specify either input_ids or inputs_embeds.")
248
+ batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
249
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
250
+
251
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
252
+ if self.config.problem_type == "multi_label_classification":
253
+ decoder_input_ids = torch.arange(end=self.config.num_labels, device=device, dtype=torch.long)
254
+ decoder_input_ids = decoder_input_ids.repeat(batch_size, 1) # Shape: (batch_size, num_labels)
255
+ # Provide a 3-dimensional attention mask by default to suppress the default causal mask.
256
+ if decoder_attention_mask is None:
257
+ decoder_attention_mask = torch.ones(
258
+ (batch_size, self.config.num_labels, self.config.num_labels), device=device, dtype=torch.long
259
+ )
260
+ else:
261
+ decoder_input_ids = torch.zeros(batch_size, 1, device=device, dtype=torch.long)
262
+
263
+ outputs = self.transformer(
264
+ input_ids=input_ids,
265
+ attention_mask=attention_mask,
266
+ decoder_input_ids=decoder_input_ids,
267
+ decoder_attention_mask=decoder_attention_mask,
268
+ head_mask=head_mask,
269
+ decoder_head_mask=decoder_head_mask,
270
+ cross_attn_head_mask=cross_attn_head_mask,
271
+ encoder_outputs=encoder_outputs,
272
+ inputs_embeds=inputs_embeds,
273
+ decoder_inputs_embeds=decoder_inputs_embeds,
274
+ use_cache=use_cache,
275
+ output_attentions=output_attentions,
276
+ output_hidden_states=output_hidden_states,
277
+ return_dict=return_dict,
278
+ )
279
+ sequence_output = outputs[0] # Shape: (batch_size, 1 or num_labels, d_model)
280
+
281
+ logits = self.classification_head(sequence_output)
282
+
283
+ loss = None
284
+ if labels is not None:
285
+ labels = labels.to(logits.device)
286
+ if self.config.problem_type is None:
287
+ if self.config.num_labels == 1:
288
+ self.config.problem_type = "regression"
289
+ elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
290
+ self.config.problem_type = "single_label_classification"
291
+ else:
292
+ # The classification head for multi-label classification is different, and so we need the
293
+ # problem_type to be set during initialization to select the proper classification head.
294
+ raise ValueError(
295
+ "For multi-label classification, the config.problem_type must be set to "
296
+ "'multi_label_classification' when initializing the model.")
297
+
298
+ if self.config.problem_type == "regression":
299
+ loss_fct = MSELoss()
300
+ if self.config.num_labels == 1:
301
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
302
+ else:
303
+ loss = loss_fct(logits, labels)
304
+ elif self.config.problem_type == "single_label_classification":
305
+ loss_fct = CrossEntropyLoss()
306
+ loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
307
+ else:
308
+ loss_fct = BCEWithLogitsLoss()
309
+ loss = loss_fct(logits, labels)
310
+ if not return_dict:
311
+ output = (logits,) + outputs[1:]
312
+ return ((loss,) + output) if loss is not None else output
313
+
314
+ return Seq2SeqSequenceClassifierOutput(
315
+ loss=loss,
316
+ logits=logits,
317
+ past_key_values=outputs.past_key_values,
318
+ decoder_hidden_states=outputs.decoder_hidden_states,
319
+ decoder_attentions=outputs.decoder_attentions,
320
+ cross_attentions=outputs.cross_attentions,
321
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
322
+ encoder_hidden_states=outputs.encoder_hidden_states,
323
+ encoder_attentions=outputs.encoder_attentions,
324
+ )