|
import math
from abc import ABC
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from torch.nn.modules.loss import _Loss
from transformers import XLMRobertaXLPreTrainedModel, XLMRobertaXLModel, XLMRobertaXLConfig
from transformers import AutoModelForSequenceClassification, AutoConfig
from transformers.modeling_outputs import ModelOutput
from pytorch_metric_learning.losses import NTXentLoss
|
|
|
|
|
@dataclass |
|
class HierarchicalSequenceEmbedderOutput(ModelOutput): |
|
loss: Optional[torch.FloatTensor] = None |
|
embeddings: torch.FloatTensor = None |
|
layer_embeddings: torch.FloatTensor = None |
|
hidden_states: Optional[Tuple[torch.FloatTensor]] = None |
|
attentions: Optional[Tuple[torch.FloatTensor]] = None |
|
|
|
|
|
@dataclass |
|
class HierarchicalSequenceClassifierOutput(ModelOutput): |
|
loss: Optional[torch.FloatTensor] = None |
|
logits: torch.FloatTensor = None |
|
embeddings: torch.FloatTensor = None |
|
layer_embeddings: torch.FloatTensor = None |
|
hidden_states: Optional[Tuple[torch.FloatTensor]] = None |
|
attentions: Optional[Tuple[torch.FloatTensor]] = None |
|
|
|
|
|
class HierarchicalXLMRobertaXLConfig(XLMRobertaXLConfig): |
|
model_type = "hierarchical-xlm-roberta-xl" |
|
|
|
def __init__(self, label_smoothing: Optional[float] = None, **kwargs): |
|
super().__init__(**kwargs) |
|
self.label_smoothing = label_smoothing |
|
|
|
|
|
class XLMRobertaXLHierarchicalClassificationHead(torch.nn.Module): |
|
def __init__(self, config): |
|
super().__init__() |
|
self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size) |
|
classifier_dropout = ( |
|
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob |
|
) |
|
self.dropout = torch.nn.Dropout(classifier_dropout) |
|
self.out_proj = torch.nn.Linear(config.hidden_size, config.num_labels) |
|
|
|
def forward(self, features, **kwargs): |
|
x = self.dropout(features) |
|
x = self.dense(x) |
|
x = torch.tanh(x) |
|
x = self.dropout(x) |
|
x = self.out_proj(x) |
|
return x |
|
|
|
|
|
def distance_to_probability(distance: torch.Tensor, margin: float) -> torch.Tensor: |
|
margin = torch.full(size=distance.size(), fill_value=margin, |
|
dtype=distance.dtype, device=distance.device, requires_grad=False) |
|
p = (1.0 + torch.exp(-margin)) / (1.0 + torch.exp(distance - margin)) |
|
del margin |
|
return p |
|
|
|
|
|
class DistanceBasedLogisticLoss(_Loss): |
|
__constants__ = ['margin', 'reduction'] |
|
margin: float |
|
|
|
def __init__(self, margin: float = 1.0, size_average=None, reduce=None, reduction: str = 'mean'): |
|
super(DistanceBasedLogisticLoss, self).__init__(size_average, reduce, reduction) |
|
self.margin = margin |
|
|
|
def forward(self, inputs, targets): |
|
inputs = inputs.view(-1) |
|
targets = targets.to(inputs.dtype).view(-1) |
|
p = distance_to_probability(inputs, self.margin) |
|
return torch.nn.functional.binary_cross_entropy(input=p, target=targets, reduction=self.reduction) |
|
|
|
|
|
class LayerGatingNetwork(torch.nn.Module): |
|
__constants__ = ['in_features'] |
|
in_features: int |
|
weight: torch.Tensor |
|
|
|
def __init__(self, in_features: int, device=None, dtype=None) -> None: |
|
factory_kwargs = {'device': device, 'dtype': dtype} |
|
super().__init__() |
|
self.in_features = in_features |
|
self.weight = torch.nn.Parameter(torch.empty((1, in_features), **factory_kwargs)) |
|
self.reset_parameters() |
|
|
|
def reset_parameters(self) -> None: |
|
initial_layer_weights = np.array( |
|
[1.0 / (self.in_features - layer_idx) for layer_idx in range(self.in_features)], |
|
dtype=np.float32 |
|
) |
|
initial_layer_weights /= np.sum(initial_layer_weights) |
|
initial_layer_weights_pt = torch.tensor( |
|
initial_layer_weights.reshape((1, self.in_features)), |
|
dtype=self.weight.dtype, |
|
device=self.weight.device |
|
) |
|
del initial_layer_weights |
|
self.weight = torch.nn.Parameter(initial_layer_weights_pt) |
|
del initial_layer_weights_pt |
|
|
|
def forward(self, input: torch.Tensor) -> torch.Tensor: |
|
return torch.nn.functional.linear(input, torch.softmax(self.weight, dim=-1)) |
|
|
|
def extra_repr(self) -> str: |
|
return 'in_features={}'.format(self.in_features) |
|
|
|
|
|
class XLMRobertaXLForHierarchicalEmbedding(XLMRobertaXLPreTrainedModel, ABC): |
|
config_class = HierarchicalXLMRobertaXLConfig |
|
|
|
def __init__(self, config: HierarchicalXLMRobertaXLConfig): |
|
super().__init__(config) |
|
self.num_labels = config.num_labels |
|
self.temperature = config.temperature |
|
self.config = config |
|
|
|
self.roberta = XLMRobertaXLModel(config, add_pooling_layer=False) |
|
self.layer_weights = LayerGatingNetwork(in_features=config.num_hidden_layers) |
|
|
|
self.init_weights() |
|
|
|
def forward( |
|
self, |
|
input_ids: Optional[torch.LongTensor] = None, |
|
attention_mask: Optional[torch.FloatTensor] = None, |
|
right_input_ids: Optional[torch.LongTensor] = None, |
|
right_attention_mask: Optional[torch.LongTensor] = None, |
|
token_type_ids: Optional[torch.LongTensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
head_mask: Optional[torch.FloatTensor] = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
labels: Optional[torch.LongTensor] = None, |
|
output_attentions: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
) -> Union[Tuple, HierarchicalSequenceEmbedderOutput]: |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
outputs = self.roberta( |
|
input_ids, |
|
attention_mask=attention_mask, |
|
token_type_ids=token_type_ids, |
|
position_ids=position_ids, |
|
head_mask=head_mask, |
|
inputs_embeds=inputs_embeds, |
|
output_attentions=output_attentions, |
|
output_hidden_states=True, |
|
return_dict=False |
|
) |
|
cls_hidden_states = torch.stack( |
|
tensors=outputs[2][-self.config.num_hidden_layers:], |
|
dim=1 |
|
)[:, :, 0, :] |
|
cls_emb = self.layer_weights(cls_hidden_states.permute(0, 2, 1))[:, :, 0] |
|
|
|
loss = None |
|
if labels is not None: |
|
cls_emb_ = cls_emb.view(-1, self.config.hidden_size) |
|
emb_norm = torch.linalg.norm(cls_emb_, dim=-1, keepdim=True) + 1e-9 |
|
if (right_input_ids is not None) or (right_attention_mask is not None): |
|
if right_input_ids is None: |
|
raise ValueError(f'right_input_ids is not specified!') |
|
if right_attention_mask is None: |
|
raise ValueError(f'right_attention_mask is not specified!') |
|
right_outputs = self.roberta( |
|
right_input_ids, |
|
attention_mask=right_attention_mask, |
|
output_hidden_states=True, |
|
return_dict=False |
|
) |
|
right_cls_hidden_states = torch.stack( |
|
tensors=right_outputs[2][-self.config.num_hidden_layers:], |
|
dim=1 |
|
)[:, :, 0, :] |
|
right_cls_emb = self.layer_weights(right_cls_hidden_states.permute(0, 2, 1))[:, :, 0] |
|
right_cls_emb_ = right_cls_emb.view(-1, self.config.hidden_size) |
|
right_emb_norm = torch.linalg.norm(right_cls_emb_, dim=-1, keepdim=True) + 1e-9 |
|
distances = torch.norm(cls_emb_ / emb_norm - right_cls_emb_ / right_emb_norm, 2, dim=-1) |
|
loss_fct = DistanceBasedLogisticLoss(margin=1.0) |
|
loss = loss_fct(distances, labels.view(-1)) |
|
else: |
|
loss_fct = NTXentLoss(temperature=self.temperature) |
|
loss = loss_fct(cls_emb_ / emb_norm, labels.view(-1)) |
|
|
|
if not return_dict: |
|
output = (cls_emb, cls_hidden_states) + outputs[2:] |
|
return ((loss,) + output) if loss is not None else output |
|
|
|
return HierarchicalSequenceEmbedderOutput( |
|
loss=loss, |
|
embeddings=cls_emb, |
|
layer_embeddings=cls_hidden_states, |
|
hidden_states=outputs[2], |
|
attentions=outputs[3] if output_attentions else None, |
|
) |
|
|
|
@property |
|
def layer_importances(self) -> List[Tuple[int, float]]: |
|
with torch.no_grad(): |
|
importances = torch.softmax(self.layer_weights.weight, dim=-1).detach().cpu().numpy().flatten() |
|
indices_and_importances = [] |
|
for layer_idx in range(importances.shape[0]): |
|
indices_and_importances.append((layer_idx + 1, float(importances[layer_idx]))) |
|
indices_and_importances.sort(key=lambda it: (-it[1], it[0])) |
|
return indices_and_importances |
|
|
|
|
|
class XLMRobertaXLForHierarchicalSequenceClassification(XLMRobertaXLForHierarchicalEmbedding, ABC): |
|
def __init__(self, config: HierarchicalXLMRobertaXLConfig): |
|
super().__init__(config) |
|
self.num_labels = config.num_labels |
|
self.label_smoothing = config.label_smoothing |
|
self.config = config |
|
|
|
self.classifier = XLMRobertaXLHierarchicalClassificationHead(config) |
|
|
|
self.init_weights() |
|
|
|
def forward( |
|
self, |
|
input_ids: Optional[torch.LongTensor] = None, |
|
attention_mask: Optional[torch.FloatTensor] = None, |
|
right_input_ids: Optional[torch.LongTensor] = None, |
|
right_attention_mask: Optional[torch.LongTensor] = None, |
|
token_type_ids: Optional[torch.LongTensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
head_mask: Optional[torch.FloatTensor] = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
labels: Optional[torch.LongTensor] = None, |
|
output_attentions: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
) -> Union[Tuple, HierarchicalSequenceClassifierOutput]: |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
outputs = super().forward( |
|
input_ids, |
|
attention_mask=attention_mask, |
|
token_type_ids=token_type_ids, |
|
position_ids=position_ids, |
|
head_mask=head_mask, |
|
inputs_embeds=inputs_embeds, |
|
output_attentions=output_attentions, |
|
return_dict=return_dict, |
|
) |
|
sequence_output = outputs[0] |
|
logits = self.classifier(sequence_output) |
|
|
|
loss = None |
|
if labels is not None: |
|
if self.config.problem_type is None: |
|
if self.num_labels == 1: |
|
self.config.problem_type = "regression" |
|
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): |
|
self.config.problem_type = "single_label_classification" |
|
else: |
|
self.config.problem_type = "multi_label_classification" |
|
|
|
if self.config.problem_type == "regression": |
|
loss_fct = torch.nn.MSELoss() |
|
if self.num_labels == 1: |
|
loss = loss_fct(logits.squeeze(), labels.squeeze()) |
|
else: |
|
loss = loss_fct(logits, labels) |
|
elif self.config.problem_type == "single_label_classification": |
|
if self.label_smoothing is None: |
|
loss_fct = torch.nn.CrossEntropyLoss() |
|
else: |
|
loss_fct = torch.nn.CrossEntropyLoss(label_smoothing=self.label_smoothing) |
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) |
|
elif self.config.problem_type == "multi_label_classification": |
|
loss_fct = torch.nn.BCEWithLogitsLoss() |
|
loss = loss_fct(logits, labels) |
|
|
|
if not return_dict: |
|
output = (logits,) + outputs |
|
return ((loss,) + output) if loss is not None else output |
|
|
|
return HierarchicalSequenceClassifierOutput( |
|
loss=loss, |
|
logits=logits, |
|
embeddings=outputs.embeddings, |
|
layer_embeddings=outputs.layer_embeddings, |
|
hidden_states=outputs.hidden_states, |
|
attentions=outputs.attentions |
|
) |
|
|
|
|
|
AutoConfig.register("hierarchical-xlm-roberta-xl", HierarchicalXLMRobertaXLConfig) |
|
AutoModelForSequenceClassification.register( |
|
HierarchicalXLMRobertaXLConfig, |
|
XLMRobertaXLForHierarchicalSequenceClassification |
|
) |
|
|