from typing import Optional, Tuple, Union

import torch
from torch import nn

from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.mt5.modeling_mt5 import MT5ClassificationHead, MT5Config, MT5EncoderModel
from transformers.models.t5.modeling_t5 import T5ClassificationHead, T5Config, T5EncoderModel


def clean_t5_config(config: Union[T5Config, MT5Config], model_type: str) -> None:
    """
    Strip decoder-specific settings from a (m)T5 config, in place, so that it can
    drive an encoder-only classification model.

    :param config: The T5 or mT5 configuration object to clean.
    :param model_type: The model family, either 't5' or 'mt5'.
    """
    assert model_type in ['t5', 'mt5']
    config.pooling_mode = 'mean'
    config.model_type = model_type
    config.use_cache = False
    config.is_encoder_decoder = False
    config.num_decoder_layers = None
    if hasattr(config, 'task_specific_params'):
        delattr(config, 'task_specific_params')
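
# Usage sketch for clean_t5_config (the "t5-base" checkpoint and num_labels=2 are
# illustrative choices, not requirements):
#
#   config = AutoConfig.from_pretrained("t5-base", num_labels=2)
#   clean_t5_config(config, model_type='t5')
#   model = T5EncoderForSequenceClassification.from_pretrained("t5-base", config=config)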


class T5EncoderForSequenceClassification(T5EncoderModel):
    """
    T5 encoder for sequence classification tasks.

    :param config: The T5 configuration object.
    """

    def __init__(self, config: T5Config):
        super().__init__(config)
        self.pool_layer = PoolLayer(config.pooling_mode)
        self.classification_head = T5ClassificationHead(config)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutput]:
        """
        Forward pass of the T5 encoder for sequence classification tasks.

        :param input_ids: The input token IDs.
        :param attention_mask: The attention mask.
        :param head_mask: The head mask.
        :param inputs_embeds: The input embeddings.
        :param labels: The target labels for computing the classification loss.
        :param output_attentions: Whether to output attentions.
        :param output_hidden_states: Whether to output hidden states.
        :param return_dict: Whether to return a SequenceClassifierOutput instead of a plain tuple.
        :returns: The classification logits (and the loss, when labels are provided).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Index with [0] so both the plain-tuple and the ModelOutput return types work.
        sequence_output = self.pool_layer(outputs[0], attention_mask)
        logits = self.classification_head(sequence_output)

        loss = None
        if labels is not None:
            # Standard single-label cross-entropy over config.num_labels classes.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class MT5EncoderForSequenceClassification(MT5EncoderModel):
    """
    mT5 encoder for sequence classification tasks.

    :param config: The mT5 configuration object.
    """

    def __init__(self, config: MT5Config):
        super().__init__(config)
        self.pool_layer = PoolLayer(config.pooling_mode)
        self.classification_head = MT5ClassificationHead(config)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutput]:
        """
        Forward pass of the mT5 encoder for sequence classification tasks.

        :param input_ids: The input token IDs.
        :param attention_mask: The attention mask.
        :param head_mask: The head mask.
        :param inputs_embeds: The input embeddings.
        :param labels: The target labels for computing the classification loss.
        :param output_attentions: Whether to output attentions.
        :param output_hidden_states: Whether to output hidden states.
        :param return_dict: Whether to return a SequenceClassifierOutput instead of a plain tuple.
        :returns: The classification logits (and the loss, when labels are provided).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Index with [0] so both the plain-tuple and the ModelOutput return types work.
        sequence_output = self.pool_layer(outputs[0], attention_mask)
        logits = self.classification_head(sequence_output)

        loss = None
        if labels is not None:
            # Standard single-label cross-entropy over config.num_labels classes.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class PoolLayer(nn.Module):
    """
    Pooling layer on top of the computed token embeddings.

    :param pooling_mode: The pooling strategy to use ('first', 'mean', or 'max').
    """

    def __init__(self, pooling_mode: str):
        super().__init__()
        assert pooling_mode in ['first', 'mean', 'max'], f"ERROR: Unknown pooling strategy '{pooling_mode}'"
        self.pooling_mode = pooling_mode

    def forward(self, token_embeddings: torch.Tensor, attention_masks: torch.Tensor) -> torch.Tensor:
        """
        Compute the passage vector by pooling the token embeddings.

        :param token_embeddings: A 3D tensor of size [batch_size, seq_len, d_model].
        :param attention_masks: A 2D tensor of size [batch_size, seq_len].
        :returns: A 2D tensor of size [batch_size, d_model].
        """
        if self.pooling_mode == 'first':
            # Use the representation of the first token.
            text_vectors = token_embeddings[:, 0, :]
        elif self.pooling_mode == 'max':
            # Mask padding positions with a large negative value before taking the max,
            # without mutating the encoder output in place.
            attention_masks_expanded = attention_masks.unsqueeze(-1).expand(token_embeddings.size())
            fill_value = -1e+9 if token_embeddings.dtype == torch.float32 else -1e+4
            masked_embeddings = token_embeddings.masked_fill(attention_masks_expanded == 0, fill_value)
            text_vectors = torch.max(masked_embeddings, dim=1).values
        else:
            # Mean pooling: zero out padding positions, then divide the summed token
            # embeddings by the number of real (non-padding) tokens.
            attention_masks_expanded = attention_masks.unsqueeze(-1).expand(token_embeddings.size()).float()
            sum_embeddings = torch.sum(token_embeddings * attention_masks_expanded, dim=1)
            sum_mask = torch.clamp(attention_masks_expanded.sum(dim=1), min=1e-7)
            text_vectors = sum_embeddings / sum_mask
        return text_vectors
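
# Shape sketch for PoolLayer: mean-pooling a padded batch down to one vector per
# sequence (all values below are illustrative):
#
#   pool = PoolLayer('mean')
#   embeddings = torch.randn(2, 4, 8)                   # [batch_size, seq_len, d_model]
#   masks = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # 1 = real token, 0 = padding
#   vectors = pool(embeddings, masks)                   # -> shape [2, 8]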


if __name__ == "__main__":
    # Illustrative checkpoint name and tokenizer kwargs; swap in your own.
    model_name = "t5-base"
    tokenizer_args = {}

    config = AutoConfig.from_pretrained(model_name)
    if isinstance(config, T5Config):
        clean_t5_config(config, model_type='t5')
        model = T5EncoderForSequenceClassification.from_pretrained(model_name, config=config)
    elif isinstance(config, MT5Config):
        clean_t5_config(config, model_type='mt5')
        model = MT5EncoderForSequenceClassification.from_pretrained(model_name, config=config)
    else:
        model = AutoModelForSequenceClassification.from_pretrained(model_name, config=config)
    tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_args)
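    # Quick smoke test (sketch): tokenize a toy batch and run one forward pass;
    # the sentences are placeholders and logits has shape [batch_size, num_labels].
    inputs = tokenizer(["a first example", "a second example"], padding=True, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    print(outputs.logits.shape)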