from abc import ABC
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from torch.nn.modules.loss import _Loss
from transformers import XLMRobertaXLPreTrainedModel, XLMRobertaXLModel, XLMRobertaXLConfig
from transformers import AutoModelForSequenceClassification, AutoConfig
from transformers.modeling_outputs import ModelOutput
from pytorch_metric_learning.losses import NTXentLoss


@dataclass
class HierarchicalSequenceEmbedderOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    embeddings: torch.FloatTensor = None
    layer_embeddings: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class HierarchicalSequenceClassifierOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    embeddings: torch.FloatTensor = None
    layer_embeddings: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class HierarchicalXLMRobertaXLConfig(XLMRobertaXLConfig):
    model_type = "hierarchical-xlm-roberta-xl"

    def __init__(self, label_smoothing: Optional[float] = None, temperature: float = 0.1, **kwargs):
        super().__init__(**kwargs)
        self.label_smoothing = label_smoothing
        # Temperature of the NT-Xent contrastive loss; the embedding model reads
        # `config.temperature` in its constructor, so it must always be present.
        self.temperature = temperature


class XLMRobertaXLHierarchicalClassificationHead(torch.nn.Module):
    """Classification head applied to the layer-weighted [CLS] embedding."""

    def __init__(self, config):
        super().__init__()
        self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = torch.nn.Dropout(classifier_dropout)
        self.out_proj = torch.nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        x = self.dropout(features)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


def distance_to_probability(distance: torch.Tensor, margin: float) -> torch.Tensor:
    """Map a non-negative distance d to a match probability p = (1 + e^{-m}) / (1 + e^{d - m})."""
    margin_t = torch.full_like(distance, fill_value=margin, requires_grad=False)
    return (1.0 + torch.exp(-margin_t)) / (1.0 + torch.exp(distance - margin_t))


class DistanceBasedLogisticLoss(_Loss):
    """Binary cross-entropy on probabilities derived from embedding distances."""

    __constants__ = ['margin', 'reduction']
    margin: float

    def __init__(self, margin: float = 1.0, size_average=None, reduce=None, reduction: str = 'mean'):
        super().__init__(size_average, reduce, reduction)
        self.margin = margin

    def forward(self, inputs, targets):
        inputs = inputs.view(-1)
        targets = targets.to(inputs.dtype).view(-1)
        p = distance_to_probability(inputs, self.margin)
        return torch.nn.functional.binary_cross_entropy(input=p, target=targets, reduction=self.reduction)


class LayerGatingNetwork(torch.nn.Module):
    """Learns a softmax-normalized convex combination of the transformer layers."""

    __constants__ = ['in_features']
    in_features: int
    weight: torch.Tensor

    def __init__(self, in_features: int, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in_features = in_features
        self.weight = torch.nn.Parameter(torch.empty((1, in_features), **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Initialize so that deeper layers receive larger pre-softmax weights.
        initial_layer_weights = np.array(
            [1.0 / (self.in_features - layer_idx) for layer_idx in range(self.in_features)],
            dtype=np.float32
        )
        initial_layer_weights /= np.sum(initial_layer_weights)
        initial_layer_weights_pt = torch.tensor(
            initial_layer_weights.reshape((1, self.in_features)),
            dtype=self.weight.dtype, device=self.weight.device
        )
        self.weight = torch.nn.Parameter(initial_layer_weights_pt)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.linear(input, torch.softmax(self.weight, dim=-1))

    def extra_repr(self) -> str:
        return 'in_features={}'.format(self.in_features)
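
# Example (illustrative, arbitrary shapes): LayerGatingNetwork mixes the per-layer
# [CLS] states with softmax-normalized weights. With the layer axis last, e.g. an
# input of shape (batch, hidden, num_layers), the output has shape (batch, hidden, 1):
#
#     gate = LayerGatingNetwork(in_features=4)
#     mixed = gate(torch.randn(2, 16, 4))  # -> torch.Size([2, 16, 1])
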
class XLMRobertaXLForHierarchicalEmbedding(XLMRobertaXLPreTrainedModel, ABC):
    config_class = HierarchicalXLMRobertaXLConfig

    def __init__(self, config: HierarchicalXLMRobertaXLConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.temperature = config.temperature
        self.config = config
        self.roberta = XLMRobertaXLModel(config, add_pooling_layer=False)
        self.layer_weights = LayerGatingNetwork(in_features=config.num_hidden_layers)
        self.init_weights()

    def forward(
            self,
            input_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            right_input_ids: Optional[torch.LongTensor] = None,
            right_attention_mask: Optional[torch.LongTensor] = None,
            token_type_ids: Optional[torch.LongTensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            head_mask: Optional[torch.FloatTensor] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            output_attentions: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, HierarchicalSequenceEmbedderOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=False
        )
        # Stack the [CLS] vectors of the last `num_hidden_layers` hidden states:
        # shape (batch, num_hidden_layers, hidden_size).
        cls_hidden_states = torch.stack(
            tensors=outputs[2][-self.config.num_hidden_layers:],
            dim=1
        )[:, :, 0, :]
        # Mix the layers with the learned gating weights: shape (batch, hidden_size).
        cls_emb = self.layer_weights(cls_hidden_states.permute(0, 2, 1))[:, :, 0]
        loss = None
        if labels is not None:
            cls_emb_ = cls_emb.view(-1, self.config.hidden_size)
            emb_norm = torch.linalg.norm(cls_emb_, dim=-1, keepdim=True) + 1e-9
            if (right_input_ids is not None) or (right_attention_mask is not None):
                if right_input_ids is None:
                    raise ValueError('right_input_ids is not specified!')
                if right_attention_mask is None:
                    raise ValueError('right_attention_mask is not specified!')
                right_outputs = self.roberta(
                    right_input_ids,
                    attention_mask=right_attention_mask,
                    output_hidden_states=True,
                    return_dict=False
                )
                right_cls_hidden_states = torch.stack(
                    tensors=right_outputs[2][-self.config.num_hidden_layers:],
                    dim=1
                )[:, :, 0, :]
                right_cls_emb = self.layer_weights(right_cls_hidden_states.permute(0, 2, 1))[:, :, 0]
                right_cls_emb_ = right_cls_emb.view(-1, self.config.hidden_size)
                right_emb_norm = torch.linalg.norm(right_cls_emb_, dim=-1, keepdim=True) + 1e-9
                # Pairwise mode: Euclidean distance between L2-normalized embeddings,
                # penalized with the distance-based logistic loss.
                distances = torch.norm(cls_emb_ / emb_norm - right_cls_emb_ / right_emb_norm, 2, dim=-1)
                loss_fct = DistanceBasedLogisticLoss(margin=1.0)
                loss = loss_fct(distances, labels.view(-1))
            else:
                # Contrastive mode: labels are treated as group ids for the NT-Xent loss.
                loss_fct = NTXentLoss(temperature=self.temperature)
                loss = loss_fct(cls_emb_ / emb_norm, labels.view(-1))
        if not return_dict:
            output = (cls_emb, cls_hidden_states) + outputs[2:]
            return ((loss,) + output) if loss is not None else output
        return HierarchicalSequenceEmbedderOutput(
            loss=loss,
            embeddings=cls_emb,
            layer_embeddings=cls_hidden_states,
            hidden_states=outputs[2],
            attentions=outputs[3] if output_attentions else None,
        )

    @property
    def layer_importances(self) -> List[Tuple[int, float]]:
        with torch.no_grad():
            importances = torch.softmax(self.layer_weights.weight, dim=-1).detach().cpu().numpy().flatten()
        indices_and_importances = []
        for layer_idx in range(importances.shape[0]):
            indices_and_importances.append((layer_idx + 1, float(importances[layer_idx])))
        indices_and_importances.sort(key=lambda it: (-it[1], it[0]))
        return indices_and_importances
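
# Example (illustrative): the embedder has two supervised modes. With
# `right_input_ids`/`right_attention_mask` it scores pairs through the distance-based
# logistic loss, where a zero distance maps to probability 1.0 and larger distances
# decay towards 0; without a right-hand input, the integer labels act as group ids
# for the NT-Xent contrastive loss.
#
#     distance_to_probability(torch.tensor([0.0, 2.0]), margin=1.0)
#     # -> approximately tensor([1.0000, 0.3679])
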
class XLMRobertaXLForHierarchicalSequenceClassification(XLMRobertaXLForHierarchicalEmbedding, ABC):
    def __init__(self, config: HierarchicalXLMRobertaXLConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.label_smoothing = config.label_smoothing
        self.config = config
        self.classifier = XLMRobertaXLHierarchicalClassificationHead(config)
        self.init_weights()

    def forward(
            self,
            input_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            right_input_ids: Optional[torch.LongTensor] = None,
            right_attention_mask: Optional[torch.LongTensor] = None,
            token_type_ids: Optional[torch.LongTensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            head_mask: Optional[torch.FloatTensor] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            output_attentions: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, HierarchicalSequenceClassifierOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = super().forward(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            return_dict=return_dict,
        )
        # The first output is the layer-weighted [CLS] embedding of shape (batch, hidden_size).
        sequence_output = outputs[0]
        logits = self.classifier(sequence_output)
        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"
            if self.config.problem_type == "regression":
                loss_fct = torch.nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                if self.label_smoothing is None:
                    loss_fct = torch.nn.CrossEntropyLoss()
                else:
                    loss_fct = torch.nn.CrossEntropyLoss(label_smoothing=self.label_smoothing)
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = torch.nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs
            return ((loss,) + output) if loss is not None else output
        return HierarchicalSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            embeddings=outputs.embeddings,
            layer_embeddings=outputs.layer_embeddings,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions
        )


AutoConfig.register("hierarchical-xlm-roberta-xl", HierarchicalXLMRobertaXLConfig)
AutoModelForSequenceClassification.register(
    HierarchicalXLMRobertaXLConfig, XLMRobertaXLForHierarchicalSequenceClassification
)
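
# Minimal usage sketch (illustrative only): builds a small randomly initialized
# configuration instead of a real XLM-RoBERTa-XL checkpoint, runs one forward pass,
# and inspects the learned layer importances. The hyperparameter values below are
# arbitrary and chosen only to keep the example fast.
if __name__ == "__main__":
    config = HierarchicalXLMRobertaXLConfig(
        vocab_size=1000,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=64,
        num_labels=3,
        label_smoothing=0.1,
        temperature=0.1,
    )
    model = XLMRobertaXLForHierarchicalSequenceClassification(config).eval()
    input_ids = torch.randint(low=5, high=config.vocab_size, size=(2, 8))
    attention_mask = torch.ones_like(input_ids)
    labels = torch.tensor([0, 2])
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    print(outputs.logits.shape)     # torch.Size([2, 3])
    print(float(outputs.loss))      # cross-entropy with label smoothing
    print(model.layer_importances)  # [(layer_idx, weight), ...] sorted by descending weight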