# xlm-roberta-xl-hallucination-detector / modeling_hierarchical_classifier.py
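"""Hierarchical XLM-RoBERTa-XL models for hallucination detection.

A learnable softmax gate (LayerGatingNetwork) mixes the [CLS] hidden states of
all transformer layers into a single sentence embedding, which is used either
for metric learning (XLMRobertaXLForHierarchicalEmbedding) or fed into an MLP
head for sequence classification
(XLMRobertaXLForHierarchicalSequenceClassification).
"""
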
from abc import ABC
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from torch.nn.modules.loss import _Loss
from transformers import XLMRobertaXLPreTrainedModel, XLMRobertaXLModel, XLMRobertaXLConfig
from transformers import AutoModelForSequenceClassification, AutoConfig
from transformers.modeling_outputs import ModelOutput
from pytorch_metric_learning.losses import NTXentLoss


@dataclass
class HierarchicalSequenceEmbedderOutput(ModelOutput):
loss: Optional[torch.FloatTensor] = None
embeddings: torch.FloatTensor = None
layer_embeddings: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class HierarchicalSequenceClassifierOutput(ModelOutput):
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
embeddings: torch.FloatTensor = None
layer_embeddings: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None


class HierarchicalXLMRobertaXLConfig(XLMRobertaXLConfig):
    model_type = "hierarchical-xlm-roberta-xl"

    def __init__(self, label_smoothing: Optional[float] = None, temperature: Optional[float] = None, **kwargs):
        super().__init__(**kwargs)
        self.label_smoothing = label_smoothing
        # Temperature of the NT-Xent loss used by the embedding model; declared here so
        # that `config.temperature` is always defined for XLMRobertaXLForHierarchicalEmbedding.
        self.temperature = temperature


class XLMRobertaXLHierarchicalClassificationHead(torch.nn.Module):
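    """Classification head on top of the layer-gated [CLS] embedding.

    Applies dropout -> dense -> tanh -> dropout -> output projection, following the
    standard RoBERTa classification-head layout.
    """
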
def __init__(self, config):
super().__init__()
self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size)
classifier_dropout = (
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
)
self.dropout = torch.nn.Dropout(classifier_dropout)
self.out_proj = torch.nn.Linear(config.hidden_size, config.num_labels)
def forward(self, features, **kwargs):
x = self.dropout(features)
x = self.dense(x)
x = torch.tanh(x)
x = self.dropout(x)
x = self.out_proj(x)
return x


def distance_to_probability(distance: torch.Tensor, margin: float) -> torch.Tensor:
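    """Map a batch of embedding distances to match probabilities.

    Implements p = (1 + exp(-margin)) / (1 + exp(distance - margin)), so that p
    equals 1 at zero distance and decays towards 0 as the distance exceeds the margin.
    """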
margin = torch.full(size=distance.size(), fill_value=margin,
dtype=distance.dtype, device=distance.device, requires_grad=False)
p = (1.0 + torch.exp(-margin)) / (1.0 + torch.exp(distance - margin))
del margin
return p


class DistanceBasedLogisticLoss(_Loss):
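    """Binary logistic loss over embedding distances.

    Distances are converted to match probabilities with distance_to_probability and
    compared against binary targets (1 = a pair that should be close in the
    embedding space) using binary cross-entropy.
    """
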
__constants__ = ['margin', 'reduction']
margin: float
def __init__(self, margin: float = 1.0, size_average=None, reduce=None, reduction: str = 'mean'):
super(DistanceBasedLogisticLoss, self).__init__(size_average, reduce, reduction)
self.margin = margin
def forward(self, inputs, targets):
inputs = inputs.view(-1)
targets = targets.to(inputs.dtype).view(-1)
p = distance_to_probability(inputs, self.margin)
return torch.nn.functional.binary_cross_entropy(input=p, target=targets, reduction=self.reduction)


class LayerGatingNetwork(torch.nn.Module):
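    """Learnable softmax gate over transformer layers.

    Holds one weight per layer; forward() mixes the per-layer features with the
    softmax-normalised weights. Weights are initialised proportionally to
    1 / (num_layers - layer_index), i.e. upper layers start with larger gates.
    """
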
__constants__ = ['in_features']
in_features: int
weight: torch.Tensor
def __init__(self, in_features: int, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.in_features = in_features
self.weight = torch.nn.Parameter(torch.empty((1, in_features), **factory_kwargs))
self.reset_parameters()
def reset_parameters(self) -> None:
initial_layer_weights = np.array(
[1.0 / (self.in_features - layer_idx) for layer_idx in range(self.in_features)],
dtype=np.float32
)
initial_layer_weights /= np.sum(initial_layer_weights)
initial_layer_weights_pt = torch.tensor(
initial_layer_weights.reshape((1, self.in_features)),
dtype=self.weight.dtype,
device=self.weight.device
)
del initial_layer_weights
self.weight = torch.nn.Parameter(initial_layer_weights_pt)
del initial_layer_weights_pt
def forward(self, input: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.linear(input, torch.softmax(self.weight, dim=-1))
def extra_repr(self) -> str:
return 'in_features={}'.format(self.in_features)


class XLMRobertaXLForHierarchicalEmbedding(XLMRobertaXLPreTrainedModel, ABC):
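    """XLM-RoBERTa-XL encoder with a hierarchical (layer-gated) sentence embedding.

    The [CLS] hidden states of the last `num_hidden_layers` layers are stacked and
    mixed by LayerGatingNetwork into one embedding per input. When labels are given,
    the loss is either a distance-based logistic loss on text pairs (when
    right_input_ids / right_attention_mask are provided) or an NT-Xent contrastive
    loss on the L2-normalised embeddings.
    """
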
config_class = HierarchicalXLMRobertaXLConfig
def __init__(self, config: HierarchicalXLMRobertaXLConfig):
super().__init__(config)
self.num_labels = config.num_labels
self.temperature = config.temperature
self.config = config
self.roberta = XLMRobertaXLModel(config, add_pooling_layer=False)
self.layer_weights = LayerGatingNetwork(in_features=config.num_hidden_layers)
self.init_weights()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
right_input_ids: Optional[torch.LongTensor] = None,
right_attention_mask: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, HierarchicalSequenceEmbedderOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.roberta(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=True,
return_dict=False
)
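        # Stack the [CLS] hidden state of the last `num_hidden_layers` layers into a
        # (batch, num_layers, hidden) tensor, then mix the layers with the softmax
        # gate to obtain a single (batch, hidden) sentence embedding.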
cls_hidden_states = torch.stack(
tensors=outputs[2][-self.config.num_hidden_layers:],
dim=1
)[:, :, 0, :]
cls_emb = self.layer_weights(cls_hidden_states.permute(0, 2, 1))[:, :, 0]
loss = None
if labels is not None:
cls_emb_ = cls_emb.view(-1, self.config.hidden_size)
emb_norm = torch.linalg.norm(cls_emb_, dim=-1, keepdim=True) + 1e-9
            if (right_input_ids is not None) or (right_attention_mask is not None):
                # Pairwise mode: both texts of the pair are encoded and the distance
                # between their L2-normalised embeddings is scored with the
                # distance-based logistic loss.
                if right_input_ids is None:
                    raise ValueError('right_input_ids is not specified!')
                if right_attention_mask is None:
                    raise ValueError('right_attention_mask is not specified!')
right_outputs = self.roberta(
right_input_ids,
attention_mask=right_attention_mask,
output_hidden_states=True,
return_dict=False
)
right_cls_hidden_states = torch.stack(
tensors=right_outputs[2][-self.config.num_hidden_layers:],
dim=1
)[:, :, 0, :]
right_cls_emb = self.layer_weights(right_cls_hidden_states.permute(0, 2, 1))[:, :, 0]
right_cls_emb_ = right_cls_emb.view(-1, self.config.hidden_size)
right_emb_norm = torch.linalg.norm(right_cls_emb_, dim=-1, keepdim=True) + 1e-9
distances = torch.norm(cls_emb_ / emb_norm - right_cls_emb_ / right_emb_norm, 2, dim=-1)
loss_fct = DistanceBasedLogisticLoss(margin=1.0)
loss = loss_fct(distances, labels.view(-1))
            else:
                # Single-input mode: labels group the embeddings into positives for the
                # NT-Xent contrastive loss on the L2-normalised embeddings.
                loss_fct = NTXentLoss(temperature=self.temperature)
                loss = loss_fct(cls_emb_ / emb_norm, labels.view(-1))
if not return_dict:
output = (cls_emb, cls_hidden_states) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return HierarchicalSequenceEmbedderOutput(
loss=loss,
embeddings=cls_emb,
layer_embeddings=cls_hidden_states,
hidden_states=outputs[2],
attentions=outputs[3] if output_attentions else None,
)
@property
def layer_importances(self) -> List[Tuple[int, float]]:
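        """Softmax weights of the layer gate as (1-based layer index, weight) pairs,
        sorted by descending weight."""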
with torch.no_grad():
importances = torch.softmax(self.layer_weights.weight, dim=-1).detach().cpu().numpy().flatten()
indices_and_importances = []
for layer_idx in range(importances.shape[0]):
indices_and_importances.append((layer_idx + 1, float(importances[layer_idx])))
indices_and_importances.sort(key=lambda it: (-it[1], it[0]))
return indices_and_importances


class XLMRobertaXLForHierarchicalSequenceClassification(XLMRobertaXLForHierarchicalEmbedding, ABC):
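    """Sequence classifier on top of the hierarchical embedding model.

    The layer-gated [CLS] embedding is passed through
    XLMRobertaXLHierarchicalClassificationHead; the loss follows the standard
    Transformers problem_type convention (regression, single-label or multi-label),
    with optional label smoothing for single-label classification.
    """
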
def __init__(self, config: HierarchicalXLMRobertaXLConfig):
super().__init__(config)
self.num_labels = config.num_labels
self.label_smoothing = config.label_smoothing
self.config = config
self.classifier = XLMRobertaXLHierarchicalClassificationHead(config)
self.init_weights()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
right_input_ids: Optional[torch.LongTensor] = None,
right_attention_mask: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, HierarchicalSequenceClassifierOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = super().forward(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
return_dict=return_dict,
)
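        # The parent forward is called without labels, so outputs[0] is always the
        # layer-gated [CLS] embedding (tuple and ModelOutput indexing agree here).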
sequence_output = outputs[0]
logits = self.classifier(sequence_output)
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = torch.nn.MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
if self.label_smoothing is None:
loss_fct = torch.nn.CrossEntropyLoss()
else:
loss_fct = torch.nn.CrossEntropyLoss(label_smoothing=self.label_smoothing)
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = torch.nn.BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs
return ((loss,) + output) if loss is not None else output
return HierarchicalSequenceClassifierOutput(
loss=loss,
logits=logits,
embeddings=outputs.embeddings,
layer_embeddings=outputs.layer_embeddings,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions
)


AutoConfig.register("hierarchical-xlm-roberta-xl", HierarchicalXLMRobertaXLConfig)
AutoModelForSequenceClassification.register(
HierarchicalXLMRobertaXLConfig,
XLMRobertaXLForHierarchicalSequenceClassification
)
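

# ---------------------------------------------------------------------------
# Minimal usage sketch (kept as comments so importing this module has no side
# effects). The checkpoint id below is an assumption for illustration; any
# repository whose config has model_type "hierarchical-xlm-roberta-xl" works.
# ---------------------------------------------------------------------------
# from transformers import AutoTokenizer
#
# repo_id = 'bond005/xlm-roberta-xl-hallucination-detector'  # assumed checkpoint id
# tokenizer = AutoTokenizer.from_pretrained(repo_id)
# model = AutoModelForSequenceClassification.from_pretrained(repo_id, trust_remote_code=True)
# model.eval()
# batch = tokenizer(['An answer whose faithfulness should be checked.'], return_tensors='pt')
# with torch.no_grad():
#     outputs = model(**batch)
# predicted_label = int(outputs.logits.argmax(dim=-1)[0])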