import torch
from torch import nn
from typing import Optional, Union, Tuple
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.t5.modeling_t5 import T5Config, T5ClassificationHead, T5EncoderModel
from transformers.models.mt5.modeling_mt5 import MT5Config, MT5ClassificationHead, MT5EncoderModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig


def clean_t5_config(config: Union[T5Config, MT5Config], model_type: str):
    """Adapt a (m)T5 seq2seq configuration in place so it describes an encoder-only classifier."""
    assert model_type in ['t5', 'mt5']
    setattr(config, 'pooling_mode', 'mean')
    setattr(config, 'model_type', model_type)
    setattr(config, 'use_cache', False)
    setattr(config, 'is_encoder_decoder', False)
    setattr(config, 'num_decoder_layers', None)
    if hasattr(config, 'task_specific_params'):
        delattr(config, 'task_specific_params')


class T5EncoderForSequenceClassification(T5EncoderModel):
    """
    T5 encoder for sequence classification tasks.

    :param config: The T5 configuration object.
    """
    def __init__(self, config: T5Config):
        super().__init__(config)
        self.pool_layer = PoolLayer(config.pooling_mode)
        self.classification_head = T5ClassificationHead(config)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        """
        Forward pass of the T5 encoder for sequence classification tasks.

        :param input_ids: The input token IDs.
        :param attention_mask: The attention mask.
        :param head_mask: The head mask.
        :param inputs_embeds: The input embeddings.
        :param labels: The target labels (accepted for API compatibility; no loss is computed here).
        :param output_attentions: Whether to output attentions.
        :param output_hidden_states: Whether to output hidden states.
        :param return_dict: Whether to return a dictionary.
        :returns: The logits for the classification task or a dictionary containing the outputs.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        loss = None

        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # outputs[0] is the last hidden state whether the encoder returns a tuple or a ModelOutput.
        sequence_output = self.pool_layer(outputs[0], attention_mask)
        logits = self.classification_head(sequence_output)

        if not return_dict:
            # Keep hidden states/attentions (everything after the last hidden state) if they were requested.
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class MT5EncoderForSequenceClassification(MT5EncoderModel):
    """
    mT5 encoder for sequence classification tasks.

    :param config: The mT5 configuration object.
    """
    def __init__(self, config: MT5Config):
        super().__init__(config)
        self.pool_layer = PoolLayer(config.pooling_mode)
        self.classification_head = MT5ClassificationHead(config)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        """
        Forward pass of the mT5 encoder for sequence classification tasks.

        :param input_ids: The input token IDs.
        :param attention_mask: The attention mask.
        :param head_mask: The head mask.
        :param inputs_embeds: The input embeddings.
        :param labels: The target labels (accepted for API compatibility; no loss is computed here).
        :param output_attentions: Whether to output attentions.
        :param output_hidden_states: Whether to output hidden states.
        :param return_dict: Whether to return a dictionary.
        :returns: The logits for the classification task or a dictionary containing the outputs.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        loss = None

        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # outputs[0] is the last hidden state whether the encoder returns a tuple or a ModelOutput.
        sequence_output = self.pool_layer(outputs[0], attention_mask)
        logits = self.classification_head(sequence_output)

        if not return_dict:
            # Keep hidden states/attentions (everything after the last hidden state) if they were requested.
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class PoolLayer(nn.Module):
    """
    Pooling layer on top of the computed token embeddings.

    :param pooling_mode: The pooling strategy to use ('first', 'mean', or 'max').
    """
    def __init__(self, pooling_mode: str):
        super().__init__()
        assert pooling_mode in ['first', 'mean', 'max'], f"ERROR: Unknown pooling strategy '{pooling_mode}'"
        self.pooling_mode = pooling_mode

    def forward(self, token_embeddings: torch.Tensor, attention_masks: torch.Tensor) -> torch.Tensor:
        """
        Compute the passage vector by pooling the token embeddings.

        :param token_embeddings: A 3D tensor of size [batch_size, seq_len, d_model].
        :param attention_masks: A 2D tensor of size [batch_size, seq_len].
        :returns: A 2D tensor of size [batch_size, d_model].
        """
        if self.pooling_mode == 'first':
            # Use the embedding of the first token as the passage vector.
            text_vectors = token_embeddings[:, 0, :]
        elif self.pooling_mode == 'max':
            # Set all values of the [PAD] embeddings to large negative values (so that they are never considered as maximum for a channel).
            attention_masks_expanded = attention_masks.unsqueeze(-1).expand(token_embeddings.size())
            token_embeddings[attention_masks_expanded == 0] = -1e+9 if token_embeddings.dtype == torch.float32 else -1e+4
            # Compute the maxima along the 'seq_length' dimension (-> Tensor[batch_size, d_model]).
            text_vectors = torch.max(token_embeddings, dim=1).values
        else:
            # Set all values of the [PAD] embeddings to zeros (so that they are not taken into account in the sum for a channel).
            attention_masks_expanded = attention_masks.unsqueeze(-1).expand(token_embeddings.size())
            token_embeddings[attention_masks_expanded == 0] = 0.0
            # Compute the means by first summing along the 'seq_length' dimension (-> Tensor[batch_size, d_model]).
            sum_embeddings = torch.sum(token_embeddings, dim=1)
            # Then, divide all values of a passage vector by the original passage length.
            sum_mask = attention_masks_expanded.sum(dim=1)  # -> Tensor[batch_size, d_model] where each value is the length of the corresponding passage.
            sum_mask = torch.clamp(sum_mask, min=1e-7)  # Make sure not to have zeros by lower bounding all elements to 1e-7.
            text_vectors = sum_embeddings / sum_mask  # Divide each dimension by the sequence length.
        return text_vectors
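

# Illustrative sketch (not part of the original module): a tiny, self-contained check of the
# pooling behaviour on a toy batch. The helper name `_pooling_demo` and the toy values are
# assumptions added for demonstration purposes only; call it manually if you want to inspect
# the pooled vectors.
def _pooling_demo():
    # Two "passages" of length 3 with d_model=2; the second passage has one padded position.
    token_embeddings = torch.tensor([
        [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
        [[2.0, 2.0], [4.0, 4.0], [0.0, 0.0]],
    ])
    attention_masks = torch.tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 0.0]])
    # Mean pooling ignores the padded position: the second row becomes [3.0, 3.0], not [2.0, 2.0].
    print(PoolLayer('mean')(token_embeddings.clone(), attention_masks))
    # Max pooling never selects a padded position as a channel maximum: rows [5.0, 6.0] and [4.0, 4.0].
    print(PoolLayer('max')(token_embeddings.clone(), attention_masks))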


if __name__ == "__main__":
    model_name = "t5-base"  # Example checkpoint; replace with the model you want to load.

    config = AutoConfig.from_pretrained(model_name)
    if isinstance(config, T5Config):
        clean_t5_config(config, model_type='t5')
        model = T5EncoderForSequenceClassification.from_pretrained(model_name, config=config)
    elif isinstance(config, MT5Config):
        clean_t5_config(config, model_type='mt5')
        model = MT5EncoderForSequenceClassification.from_pretrained(model_name, config=config)
    else:
        model = AutoModelForSequenceClassification.from_pretrained(model_name, config=config)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
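    # Illustrative usage sketch (an assumption, not from the original file): run a single forward
    # pass with the freshly initialised classification head. The example sentence is arbitrary and
    # the head is untrained, so the predicted probabilities are meaningless until fine-tuning.
    inputs = tokenizer("This movie was surprisingly good.", return_tensors="pt")
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
    probabilities = torch.softmax(outputs.logits, dim=-1)
    print(probabilities)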