sjrhuschlee committed on
Commit 86ae316
1 Parent(s): c834e50

Upload modeling_t5seq.py with huggingface_hub

Files changed (1)
modeling_t5seq.py +40 -107
modeling_t5seq.py CHANGED
@@ -7,28 +7,19 @@ from torch import nn
 from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
 
 from transformers import AutoModelForSequenceClassification
-from transformers.modeling_outputs import (
-    BaseModelOutput,
-    Seq2SeqSequenceClassifierOutput,
-)
+from transformers.modeling_outputs import Seq2SeqSequenceClassifierOutput
 from transformers.models.t5.configuration_t5 import T5Config
-from transformers.models.t5.modeling_t5 import T5PreTrainedModel, T5Stack
+from transformers.models.t5.modeling_t5 import T5PreTrainedModel, T5Model
 
 
 class T5ClassificationHead(nn.Module):
     """Head for sentence-level classification tasks."""
 
-    def __init__(
-        self,
-        input_dim: int,
-        inner_dim: int,
-        num_classes: int,
-        pooler_dropout: float,
-    ):
+    def __init__(self, config: T5Config):
         super().__init__()
-        self.dense = nn.Linear(input_dim, inner_dim)
-        self.dropout = nn.Dropout(p=pooler_dropout)
-        self.out_proj = nn.Linear(inner_dim, num_classes)
+        self.dense = nn.Linear(config.d_model, config.d_model)
+        self.dropout = nn.Dropout(p=config.classifier_dropout)
+        self.out_proj = nn.Linear(config.d_model, config.num_labels)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dropout(hidden_states)
@@ -45,50 +36,14 @@ class T5ForSequenceClassification(T5PreTrainedModel):
 
     def __init__(self, config: T5Config):
         super().__init__(config)
-        self.model_dim = config.d_model
-
-        self.shared = nn.Embedding(config.vocab_size, config.d_model)
-
-        encoder_config = copy.deepcopy(config)
-        encoder_config.is_decoder = False
-        encoder_config.use_cache = False
-        encoder_config.is_encoder_decoder = False
-        self.encoder = T5Stack(encoder_config, self.shared)
-
-        decoder_config = copy.deepcopy(config)
-        decoder_config.is_decoder = True
-        decoder_config.is_encoder_decoder = False
-        decoder_config.num_layers = config.num_decoder_layers
-        self.decoder = T5Stack(decoder_config, self.shared)
-
-        self.num_labels = config.num_labels
-
-        self.classification_head = T5ClassificationHead(
-            config.d_model,
-            config.d_model,
-            config.num_labels,
-            config.classifier_dropout,
-        )
+        self.transformer = T5Model(config)
+        self.classification_head = T5ClassificationHead(config)
 
         # Initialize weights and apply final processing
         self.post_init()
 
         self.model_parallel = False
 
-    def get_input_embeddings(self):
-        return self.shared
-
-    def set_input_embeddings(self, new_embeddings):
-        self.shared = new_embeddings
-        self.encoder.set_input_embeddings(new_embeddings)
-        self.decoder.set_input_embeddings(new_embeddings)
-
-    def get_encoder(self):
-        return self.encoder
-
-    def get_decoder(self):
-        return self.decoder
-
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -114,13 +69,16 @@ class T5ForSequenceClassification(T5PreTrainedModel):
         Returns:
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
        if labels is not None:
             use_cache = False
 
-        # Copied from models.bart.modeling_bart.BartModel.forward
-        # different to other models, T5 automatically creates decoder_input_ids from
-        # input_ids if no decoder_input_ids are provided
+        if input_ids is None and inputs_embeds is not None:
+            raise NotImplementedError(
+                f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
+            )
+
+        # Copied from models.bart.modeling_bart.BartModel.forward different to other models, T5 automatically creates
+        # decoder_input_ids from input_ids if no decoder_input_ids are provided
         if decoder_input_ids is None and decoder_inputs_embeds is None:
             if input_ids is None:
                 raise ValueError(
@@ -130,57 +88,30 @@ class T5ForSequenceClassification(T5PreTrainedModel):
                 )
             decoder_input_ids = self._shift_right(input_ids)
 
-        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
-        if head_mask is not None and decoder_head_mask is None:
-            if self.config.num_layers == self.config.num_decoder_layers:
-                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
-                decoder_head_mask = head_mask
-
-        # Encode if needed (training, first prediction pass)
-        if encoder_outputs is None:
-            encoder_outputs = self.encoder(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                inputs_embeds=inputs_embeds,
-                head_mask=head_mask,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-            )
-        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
-            encoder_outputs = BaseModelOutput(
-                last_hidden_state=encoder_outputs[0],
-                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
-                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
-            )
-
-        hidden_states = encoder_outputs[0]
-
-        # Decode
-        decoder_outputs = self.decoder(
-            input_ids=decoder_input_ids,
-            attention_mask=decoder_attention_mask,
-            inputs_embeds=decoder_inputs_embeds,
-            past_key_values=None,
-            encoder_hidden_states=hidden_states,
-            encoder_attention_mask=attention_mask,
-            head_mask=decoder_head_mask,
+        outputs = self.transformer(
+            input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            head_mask=head_mask,
+            decoder_head_mask=decoder_head_mask,
             cross_attn_head_mask=cross_attn_head_mask,
+            encoder_outputs=encoder_outputs,
+            inputs_embeds=inputs_embeds,
+            decoder_inputs_embeds=decoder_inputs_embeds,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-
-        sequence_output = decoder_outputs[0]
+        sequence_output = outputs[0]
 
         eos_mask = input_ids.eq(self.config.eos_token_id).to(sequence_output.device)
 
         if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
             raise ValueError("All examples must have the same number of <eos> tokens.")
-        sentence_representation = sequence_output[eos_mask, :].view(
-            sequence_output.size(0), -1, sequence_output.size(-1)
-        )[:, -1, :]
+        batch_size, _, hidden_size = sequence_output.shape
+        sentence_representation = sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :]
         logits = self.classification_head(sentence_representation)
 
         loss = None
@@ -207,21 +138,23 @@ class T5ForSequenceClassification(T5PreTrainedModel):
                 loss_fct = BCEWithLogitsLoss()
                 loss = loss_fct(logits, labels)
         if not return_dict:
-            output = (logits,) + decoder_outputs[1:] + encoder_outputs
+            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output
 
         return Seq2SeqSequenceClassifierOutput(
             loss=loss,
             logits=logits,
-            past_key_values=decoder_outputs.past_key_values,
-            decoder_hidden_states=decoder_outputs.hidden_states,
-            decoder_attentions=decoder_outputs.attentions,
-            cross_attentions=decoder_outputs.cross_attentions,
-            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
-            encoder_hidden_states=encoder_outputs.hidden_states,
-            encoder_attentions=encoder_outputs.attentions,
+            past_key_values=outputs.past_key_values,
+            decoder_hidden_states=outputs.decoder_hidden_states,
+            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
+            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+            encoder_hidden_states=outputs.encoder_hidden_states,
+            encoder_attentions=outputs.encoder_attentions,
         )
 
-
-AutoModelForSequenceClassification.register(T5Config, T5ForSequenceClassification)
+try:
+    AutoModelForSequenceClassification.register(T5Config, T5ForSequenceClassification)
+except ValueError:
+    pass
 
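
The try/except around the final registration guards against transformers releases that already map T5Config to a built-in sequence-classification class. Below is a minimal usage sketch, not part of this commit, of how the registered model might be loaded through the Auto API. It assumes modeling_t5seq.py is importable from the working directory; the checkpoint name "t5-small" and num_labels=2 are illustrative placeholders.

# Minimal usage sketch. Assumptions: modeling_t5seq.py is on the import path,
# and "t5-small" / num_labels=2 are placeholders, not part of this commit.
import modeling_t5seq  # noqa: F401  # importing the module runs the register() call

from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSequenceClassification.from_pretrained("t5-small", num_labels=2)

# The encoder/decoder weights come from the checkpoint; the classification head
# is newly initialized, so the model still needs fine-tuning before use.
inputs = tokenizer("This movie was great!", return_tensors="pt")
logits = model(**inputs).logits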