from typing import Any, Dict, Optional
import torch
from transformers import AutoModel, PreTrainedModel
from transformers.activations import ClippedGELUActivation, GELUActivation
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PoolerEndLogits
from .configuration_relik import RelikReaderConfig
torch.set_float32_matmul_precision('medium')
def get_mention_features(
features: torch.Tensor,
starts: torch.Tensor,
ends: torch.Tensor,
batch_size: int,
) -> torch.Tensor:
    # Step 1: locate the start and end positions of each span
start_positions = starts.nonzero(as_tuple=True)
end_positions = ends.nonzero(as_tuple=True)
averaged_features = []
for batch_idx, (start, end) in enumerate(zip(start_positions[1], end_positions[1])):
# Select the batch where the span is located
batch_id = start_positions[0][batch_idx]
# Extract features for the span
span_features = features[batch_id, start:end+1]
# Compute the average of the features
span_avg = span_features.mean(dim=0)
averaged_features.append(span_avg)
averaged_features = torch.stack(averaged_features, dim=0)
    # regroup the span features per batch element and pad them to the same length
    # (torch.split expects a list of ints, so the per-batch span counts are converted with .tolist())
    averaged_features = torch.nn.utils.rnn.pad_sequence(
        torch.split(averaged_features, torch.sum(starts, dim=1).tolist(), dim=0),
    )
return averaged_features
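# Illustrative shapes for get_mention_features (explanatory note, not part of the original code):
# `features` is (batch, seq_len, hidden) and `starts`/`ends` are binary (batch, seq_len) masks
# marking span boundaries. Each span is mean-pooled over its tokens, and the resulting span
# vectors are regrouped per batch element and padded with pad_sequence.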
def custom_margin_ranking_loss(scores, labels, margin=1.0):
"""
Custom implementation of margin ranking loss for imbalanced positive/negative scores.
Parameters:
- scores: Tensor containing the scores for each example.
- labels: Tensor containing labels (1 for positive, 0 for negative examples).
- margin: Desired margin between positive and negative scores.
Returns:
- loss: Computed loss value.
"""
# Separate scores into positive and negative based on labels
positive_scores = scores[labels == 1]
negative_scores = scores[labels == 0]
# Compute all pairs of positive-negative scores
pos_neg_diff = positive_scores.unsqueeze(1) - negative_scores.unsqueeze(0)
# Calculate loss for all positive-negative pairs
loss_components = torch.clamp(margin - pos_neg_diff, min=0)
# Average loss across all comparisons
loss = torch.mean(loss_components)
return loss
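# Worked example (illustrative, with made-up numbers): for scores = [2.0, 0.5, -1.0] and
# labels = [1, 0, 0], the single positive score 2.0 is paired with both negatives; with
# margin=1.0 the hinge terms are max(0, 1 - (2.0 - 0.5)) = 0 and max(0, 1 - (2.0 - (-1.0))) = 0,
# so the loss is 0. If the positive score dropped to 0.2, the terms would be 1.3 and 0,
# giving a loss of 0.65.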
def split_and_process(tensor, projector_start, projector_end):
# Split the tensor along the last dimension
half = tensor.shape[-1] // 2
tensor_start, tensor_end = tensor[..., :half], tensor[..., half:]
# Apply the linear layers
tensor_start_processed = projector_start(tensor_start)
tensor_end_processed = projector_end(tensor_end)
return tensor_start_processed, tensor_end_processed
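# Usage sketch (hypothetical shapes): given a tensor of shape (..., 2 * h) and two
# nn.Linear(h, h') projectors, split_and_process returns two tensors of shape (..., h'),
# one projected from the first half (start features) and one from the second half (end features).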
def get_mention_features_vectorized(features, starts, ends, batch_size):
    # Create a range tensor that matches the size of the longest span
    max_length = int((ends - starts).max().item()) + 1
    range_tensor = torch.arange(max_length, device=features.device)
# Expand starts and range_tensor to compute a mask for each position in each span
expanded_starts = starts.unsqueeze(-1) # Adding an extra dimension for broadcasting
expanded_ends = ends.unsqueeze(-1)
range_mask = (range_tensor < (expanded_ends - expanded_starts + 1))
# Use the mask to select features, handling variable-length sequences with padding
span_lengths = (expanded_ends - expanded_starts).squeeze(-1) + 1
max_span_length = span_lengths.max()
padded_features = torch.zeros((batch_size, max_span_length, features.size(-1)), device=features.device)
for i in range(batch_size):
span = features[i, starts[i]:ends[i]+1]
padded_features[i, :span.size(0)] = span
# Compute the mean of features for each span, using the mask for correct averaging
span_means = (padded_features * range_mask.unsqueeze(-1)).sum(dim=1) / range_mask.sum(dim=1, keepdim=True)
return span_means
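# Note (explanatory, not part of the original code): unlike get_mention_features, this variant
# appears to assume one span per batch element (`starts`/`ends` are per-row indices rather than
# binary masks); it builds a (batch, max_span_length, hidden) buffer and then mean-pools each
# span under the length mask.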
def random_half_tensor_dropout(tensor, dropout_prob=0.5, is_training=True):
"""
Applies dropout to either the first half or the second half of the tensor with a specified probability.
Dropout is only applied during training.
Args:
tensor (torch.Tensor): The input tensor.
dropout_prob (float): The probability of dropping out half of the tensor.
is_training (bool): If True, apply dropout; if False, do not apply dropout.
Returns:
torch.Tensor: The tensor after applying dropout.
"""
assert 0 <= dropout_prob <= 1, "Dropout probability must be in the range [0, 1]"
if is_training:
# Size of the last dimension
last_dim_size = tensor.size(-1)
# Calculate the index for splitting the tensor into two halves
split_index = last_dim_size // 2
# Generate a random number and compare it with the dropout probability
if torch.rand(1).item() < dropout_prob:
# Randomly choose to drop the first half or the second half
if torch.rand(1).item() < 0.5:
# Set the first half to zero
tensor[..., :split_index] = 0
else:
# Set the second half to zero
tensor[..., split_index:] = 0
return tensor
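# Usage sketch (illustrative): for a (batch, 2 * h) tensor, random_half_tensor_dropout(x, 0.5)
# zeroes either x[..., :h] or x[..., h:] with probability 0.5 during training and leaves the
# tensor untouched otherwise. Note that the zeroing modifies the input tensor in place.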
class RelikReaderSample:
def __init__(self, **kwargs):
super().__setattr__("_d", {})
self._d = kwargs
def __getattribute__(self, item):
return super(RelikReaderSample, self).__getattribute__(item)
def __getattr__(self, item):
if item.startswith("__") and item.endswith("__"):
# this is likely some python library-specific variable (such as __deepcopy__ for copy)
# better follow standard behavior here
raise AttributeError(item)
elif item in self._d:
return self._d[item]
else:
return None
def __setattr__(self, key, value):
if key in self._d:
self._d[key] = value
else:
super().__setattr__(key, value)
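# Usage sketch (hypothetical fields): s = RelikReaderSample(text="Rome is in Italy", spans=[(0, 4)])
# gives s.text == "Rome is in Italy" and s.spans == [(0, 4)], while any unknown attribute such as
# s.predicted_spans resolves to None instead of raising AttributeError.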
activation2functions = {
"relu": torch.nn.ReLU(),
"gelu": GELUActivation(),
"gelu_10": ClippedGELUActivation(-10, 10),
}
class PoolerEndLogitsBi(PoolerEndLogits):
def __init__(self, config: PretrainedConfig):
super().__init__(config)
self.dense_1 = torch.nn.Linear(config.hidden_size, 2)
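        # note: unlike the parent PoolerEndLogits (one logit per position), this head produces
        # 2 logits per position, i.e. a binary end / non-end classification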
def forward(
self,
hidden_states: torch.FloatTensor,
start_states: Optional[torch.FloatTensor] = None,
start_positions: Optional[torch.LongTensor] = None,
p_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
if p_mask is not None:
p_mask = p_mask.unsqueeze(-1)
logits = super().forward(
hidden_states,
start_states,
start_positions,
p_mask,
)
return logits
class RelikReaderSpanModel(PreTrainedModel):
config_class = RelikReaderConfig
def __init__(self, config: RelikReaderConfig, *args, **kwargs):
super().__init__(config)
# Transformer model declaration
self.config = config
self.transformer_model = (
AutoModel.from_pretrained(self.config.transformer_model)
if self.config.num_layers is None
else AutoModel.from_pretrained(
self.config.transformer_model, num_hidden_layers=self.config.num_layers
)
)
self.transformer_model.resize_token_embeddings(
self.transformer_model.config.vocab_size
+ self.config.additional_special_symbols,
pad_to_multiple_of=8,
)
self.activation = self.config.activation
self.linears_hidden_size = self.config.linears_hidden_size
self.use_last_k_layers = self.config.use_last_k_layers
# named entity detection layers
self.ned_start_classifier = self._get_projection_layer(
self.activation, last_hidden=2, layer_norm=False
)
self.ned_end_classifier = PoolerEndLogits(self.transformer_model.config)
        # entity disambiguation (ED) projection layer
        self.ed_projector = self._get_projection_layer(
            self.activation,
            last_hidden=2 * self.linears_hidden_size,
            hidden=2 * self.linears_hidden_size,
        )
self.training = self.config.training
# criterion
self.criterion = torch.nn.CrossEntropyLoss()
def _get_projection_layer(
self,
activation: str,
last_hidden: Optional[int] = None,
hidden: Optional[int] = None,
input_hidden=None,
layer_norm: bool = True,
) -> torch.nn.Sequential:
head_components = [
torch.nn.Dropout(0.1),
torch.nn.Linear(
(
self.transformer_model.config.hidden_size * self.use_last_k_layers
if input_hidden is None
else input_hidden
),
self.linears_hidden_size if hidden is None else hidden,
),
activation2functions[activation],
torch.nn.Dropout(0.1),
torch.nn.Linear(
self.linears_hidden_size if hidden is None else hidden,
self.linears_hidden_size if last_hidden is None else last_hidden,
),
]
if layer_norm:
head_components.append(
torch.nn.LayerNorm(
self.linears_hidden_size if last_hidden is None else last_hidden,
self.transformer_model.config.layer_norm_eps,
)
)
return torch.nn.Sequential(*head_components)
def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
mask = mask.unsqueeze(-1)
if next(self.parameters()).dtype == torch.float16:
logits = logits * (1 - mask) - 65500 * mask
else:
logits = logits * (1 - mask) - 1e30 * mask
return logits
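    # Illustrative effect of _mask_logits (explanatory note): positions where mask == 1 receive a
    # large negative offset (-65500 in fp16, -1e30 otherwise), so a subsequent softmax assigns them
    # effectively zero probability; positions where mask == 0 keep their original logits.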
def _get_model_features(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
token_type_ids: Optional[torch.Tensor],
):
model_input = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"output_hidden_states": self.use_last_k_layers > 1,
}
if token_type_ids is not None:
model_input["token_type_ids"] = token_type_ids
model_output = self.transformer_model(**model_input)
if self.use_last_k_layers > 1:
model_features = torch.cat(
model_output[1][-self.use_last_k_layers :], dim=-1
)
else:
model_features = model_output[0]
return model_features
def compute_ned_end_logits(
self,
start_predictions,
start_labels,
model_features,
prediction_mask,
batch_size,
) -> Optional[torch.Tensor]:
# todo: maybe when constraining on the spans,
# we should not use a prediction_mask for the end tokens.
# at least we should not during training imo
start_positions = start_labels if self.training else start_predictions
start_positions_indices = (
torch.arange(start_positions.size(1), device=start_positions.device)
.unsqueeze(0)
.expand(batch_size, -1)[start_positions > 0]
).to(start_positions.device)
if len(start_positions_indices) > 0:
expanded_features = model_features.repeat_interleave(
torch.sum(start_positions > 0, dim=-1), dim=0
)
expanded_prediction_mask = prediction_mask.repeat_interleave(
torch.sum(start_positions > 0, dim=-1), dim=0
)
end_logits = self.ned_end_classifier(
hidden_states=expanded_features,
start_positions=start_positions_indices,
p_mask=expanded_prediction_mask,
)
return end_logits
return None
def compute_classification_logits(
self,
model_features,
special_symbols_mask,
prediction_mask,
batch_size,
start_positions=None,
end_positions=None,
attention_mask=None,
) -> torch.Tensor:
if start_positions is None or end_positions is None:
start_positions = torch.zeros_like(prediction_mask)
end_positions = torch.zeros_like(prediction_mask)
model_ed_features = self.ed_projector(model_features)
        # replace the second half of each start-token representation with the first half of the
        # corresponding end-token representation (single advanced-index assignment: chained
        # indexing would write into a copy and be discarded)
        half = model_ed_features.shape[-1] // 2
        model_ed_features[start_positions > 0, half:] = model_ed_features[
            end_positions > 0, :half
        ]
# computing ed features
classes_representations = torch.sum(special_symbols_mask, dim=1)[0].item()
special_symbols_mask_start = special_symbols_mask.roll(1, 1)
special_symbols_mask_start[:, :2] = torch.tensor([True, False], device=special_symbols_mask.device).expand_as(
special_symbols_mask_start[:, :2]
)
special_symbols_mask_end = special_symbols_mask.roll(-1, 1)
cumsum = special_symbols_mask_end.cumsum(dim=1)
# Identify the second True in each row (where cumulative sum equals 2)
special_symbols_mask_end[cumsum == 2] = False
special_symbols_mask_end[:, [0, -1]] = torch.tensor([True, False], device=special_symbols_mask.device).expand_as(
special_symbols_mask_end[:, [0, -1]]
)
        # mark the token right before the last non-padding token (attention_mask.sum(1) - 2),
        # closing the final special-symbol span
        last_token_ent = attention_mask.sum(1) - 2
        special_symbols_mask_end[
            torch.arange(
                special_symbols_mask_end.shape[0], device=special_symbols_mask_end.device
            ),
            last_token_ent,
        ] = True
special_symbols_representation_start = model_ed_features[special_symbols_mask_start][:,:model_ed_features.shape[-1] // 2].view(
batch_size, classes_representations, -1
)
special_symbols_representation_end = model_ed_features[special_symbols_mask_end][:,model_ed_features.shape[-1] // 2:].view(
batch_size, classes_representations, -1
)
# special_symbols_representation_start = self.ed_special_tokens_projector_start(special_symbols_representation_start)
# special_symbols_representation_end = self.ed_special_tokens_projector_end(special_symbols_representation_end)
special_symbols_representation = torch.cat(
[special_symbols_representation_start, special_symbols_representation_end, special_symbols_representation_end, special_symbols_representation_start], dim=-1
)
model_ed_features = torch.cat(
[model_ed_features, model_ed_features], dim=-1
)
logits = torch.bmm(
model_ed_features,
torch.permute(special_symbols_representation, (0, 2, 1)),
)
logits = self._mask_logits(logits, prediction_mask)
return logits
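    # Shape sketch for compute_classification_logits (explanatory note): the returned logits are
    # (batch, seq_len, num_candidates), computed as a batched dot product between the doubled
    # token representations and the concatenated start/end representations of the candidate
    # special symbols, with prediction_mask positions pushed to a large negative value.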
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
prediction_mask: Optional[torch.Tensor] = None,
special_symbols_mask: Optional[torch.Tensor] = None,
start_labels: Optional[torch.Tensor] = None,
end_labels: Optional[torch.Tensor] = None,
use_predefined_spans: bool = False,
*args,
**kwargs,
) -> Dict[str, Any]:
batch_size, seq_len = input_ids.shape
model_features = self._get_model_features(
input_ids, attention_mask, token_type_ids
)
ned_start_labels = None
# named entity detection if required
if use_predefined_spans: # no need to compute spans
ned_start_logits, ned_start_probabilities, ned_start_predictions = (
None,
None,
torch.clone(start_labels)
if start_labels is not None
else torch.zeros_like(input_ids),
)
ned_end_logits, ned_end_probabilities, ned_end_predictions = (
None,
None,
torch.clone(end_labels)
if end_labels is not None
else torch.zeros_like(input_ids),
)
ned_start_predictions[ned_start_predictions > 0] = 1
ned_end_predictions[ned_end_predictions > 0] = 1
else: # compute spans
# start boundary prediction
ned_start_logits = self.ned_start_classifier(model_features)
ned_start_logits = self._mask_logits(ned_start_logits, prediction_mask)
ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
ned_start_predictions = ned_start_probabilities.argmax(dim=-1)
# end boundary prediction
ned_start_labels = (
torch.zeros_like(start_labels) if start_labels is not None else None
)
if ned_start_labels is not None:
ned_start_labels[start_labels == -100] = -100
ned_start_labels[start_labels > 0] = 1
ned_end_logits = self.compute_ned_end_logits(
ned_start_predictions,
ned_start_labels,
model_features,
prediction_mask,
batch_size,
)
if ned_end_logits is not None:
ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1)
else:
ned_end_logits, ned_end_probabilities = None, None
ned_end_predictions = ned_start_predictions.new_zeros(batch_size)
# flattening end predictions
# (flattening can happen only if the
# end boundaries were not predicted using the gold labels)
if not self.training and ned_end_logits is not None:
                flattened_end_predictions = torch.zeros_like(ned_start_predictions)
                row_indices, start_positions = torch.where(ned_start_predictions > 0)
                # clamp end predictions that fall before their start position
                before_start = ned_end_predictions < start_positions
                ned_end_predictions[before_start] = start_positions[before_start]
                # flatten (row, end) pairs into a single index space and keep only spans whose
                # flattened end index strictly exceeds the running maximum so far
                # (drops nested or duplicate spans)
                end_spans_repeated = (row_indices + 1) * seq_len + ned_end_predictions
                cummax_values, _ = end_spans_repeated.cummax(dim=0)
                end_spans_repeated = end_spans_repeated > torch.cat(
                    (end_spans_repeated[:1], cummax_values[:-1])
                )
                end_spans_repeated[0] = True
                # discard the start predictions of the dropped spans
                ned_start_predictions[
                    row_indices[~end_spans_repeated], start_positions[~end_spans_repeated]
                ] = 0
                row_indices = row_indices[end_spans_repeated]
                start_positions = start_positions[end_spans_repeated]
                ned_end_predictions = ned_end_predictions[end_spans_repeated]
                flattened_end_predictions[row_indices, ned_end_predictions] = 1
total_start_predictions, total_end_predictions = ned_start_predictions.sum(), flattened_end_predictions.sum()
assert (
total_start_predictions == 0
or total_start_predictions == total_end_predictions
), (
f"Total number of start predictions = {total_start_predictions}. "
f"Total number of end predictions = {total_end_predictions}"
)
ned_end_predictions = flattened_end_predictions
else:
ned_end_predictions = torch.zeros_like(ned_start_predictions)
start_position, end_position = (
(start_labels, end_labels)
if self.training
else (ned_start_predictions, ned_end_predictions)
)
# Entity disambiguation
ed_logits = self.compute_classification_logits(
model_features,
special_symbols_mask,
prediction_mask,
batch_size,
start_position,
end_position,
attention_mask,
)
ed_probabilities = torch.softmax(ed_logits, dim=-1)
ed_predictions = torch.argmax(ed_probabilities, dim=-1)
# output build
output_dict = dict(
batch_size=batch_size,
ned_start_logits=ned_start_logits,
ned_start_probabilities=ned_start_probabilities,
ned_start_predictions=ned_start_predictions,
ned_end_logits=ned_end_logits,
ned_end_probabilities=ned_end_probabilities,
ned_end_predictions=ned_end_predictions,
ed_logits=ed_logits,
ed_probabilities=ed_probabilities,
ed_predictions=ed_predictions,
)
# compute loss if labels
if start_labels is not None and end_labels is not None and self.training:
# named entity detection loss
# start
if ned_start_logits is not None:
ned_start_loss = self.criterion(
ned_start_logits.view(-1, ned_start_logits.shape[-1]),
ned_start_labels.view(-1),
)
else:
ned_start_loss = 0
# end
if ned_end_logits is not None:
ned_end_labels = torch.zeros_like(end_labels)
ned_end_labels[end_labels == -100] = -100
ned_end_labels[end_labels > 0] = 1
ned_end_loss = self.criterion(
ned_end_logits,
(
torch.arange(
ned_end_labels.size(1), device=ned_end_labels.device
)
.unsqueeze(0)
.expand(batch_size, -1)[ned_end_labels > 0]
).to(ned_end_labels.device),
)
else:
ned_end_loss = 0
# entity disambiguation loss
start_labels[ned_start_labels != 1] = -100
ed_labels = torch.clone(start_labels)
ed_labels[end_labels > 0] = end_labels[end_labels > 0]
ed_loss = self.criterion(
ed_logits.view(-1, ed_logits.shape[-1]),
ed_labels.view(-1),
)
output_dict["ned_start_loss"] = ned_start_loss
output_dict["ned_end_loss"] = ned_end_loss
output_dict["ed_loss"] = ed_loss
output_dict["loss"] = ned_start_loss + ned_end_loss + ed_loss
return output_dict
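# Usage sketch for RelikReaderSpanModel (hypothetical identifiers; tokenization, collation and the
# batch keys below are assumptions, not defined in this file):
#   config = RelikReaderConfig.from_pretrained("path/to/span-reader")
#   model = RelikReaderSpanModel(config)
#   out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"],
#               prediction_mask=batch["prediction_mask"],
#               special_symbols_mask=batch["special_symbols_mask"])
#   # out["ned_start_predictions"] / out["ned_end_predictions"] mark predicted span boundaries;
#   # out["ed_predictions"] is the argmax over candidate logits at each token position.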
class RelikReaderREModel(PreTrainedModel):
config_class = RelikReaderConfig
def __init__(self, config, *args, **kwargs):
super().__init__(config)
# Transformer model declaration
# self.transformer_model_name = transformer_model
self.config = config
self.transformer_model = (
AutoModel.from_pretrained(config.transformer_model)
if config.num_layers is None
else AutoModel.from_pretrained(
config.transformer_model, num_hidden_layers=config.num_layers
)
)
self.transformer_model.resize_token_embeddings(
self.transformer_model.config.vocab_size
+ config.additional_special_symbols
+ config.additional_special_symbols_types,
pad_to_multiple_of=8,
)
# named entity detection layers
self.ned_start_classifier = self._get_projection_layer(
config.activation, last_hidden=2, layer_norm=False
)
self.ned_end_classifier = PoolerEndLogitsBi(self.transformer_model.config)
self.relation_disambiguation_loss = (
config.relation_disambiguation_loss
if hasattr(config, "relation_disambiguation_loss")
else False
)
if self.config.entity_type_loss and self.config.add_entity_embedding:
input_hidden_ents = 3 * self.config.linears_hidden_size
else:
input_hidden_ents = 2 * self.config.linears_hidden_size
        self.re_projector = self._get_projection_layer(
            config.activation,
            input_hidden=2 * self.transformer_model.config.hidden_size,
            hidden=input_hidden_ents,
            last_hidden=2 * self.config.linears_hidden_size,
        )
self.re_relation_projector = self._get_projection_layer(
config.activation, input_hidden=self.transformer_model.config.hidden_size,
)
if self.config.entity_type_loss or self.relation_disambiguation_loss:
self.re_entities_projector = self._get_projection_layer(
config.activation,
input_hidden=2 * self.transformer_model.config.hidden_size,
)
self.re_definition_projector = self._get_projection_layer(
config.activation,
)
self.re_classifier = self._get_projection_layer(
config.activation,
input_hidden=config.linears_hidden_size,
last_hidden=2,
layer_norm=False,
)
self.training = config.training
# criterion
self.criterion = torch.nn.CrossEntropyLoss()
self.criterion_type = torch.nn.BCEWithLogitsLoss()
def _get_projection_layer(
self,
activation: str,
last_hidden: Optional[int] = None,
hidden: Optional[int] = None,
input_hidden=None,
layer_norm: bool = True,
) -> torch.nn.Sequential:
head_components = [
torch.nn.Dropout(0.1),
torch.nn.Linear(
(
self.transformer_model.config.hidden_size * self.config.use_last_k_layers
if input_hidden is None
else input_hidden
),
self.config.linears_hidden_size if hidden is None else hidden,
),
activation2functions[activation],
torch.nn.Dropout(0.1),
torch.nn.Linear(
self.config.linears_hidden_size if hidden is None else hidden,
self.config.linears_hidden_size if last_hidden is None else last_hidden,
),
]
if layer_norm:
head_components.append(
torch.nn.LayerNorm(
self.config.linears_hidden_size if last_hidden is None else last_hidden,
self.transformer_model.config.layer_norm_eps,
)
)
return torch.nn.Sequential(*head_components)
def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
mask = mask.unsqueeze(-1)
if next(self.parameters()).dtype == torch.float16:
logits = logits * (1 - mask) - 65500 * mask
else:
logits = logits * (1 - mask) - 1e30 * mask
return logits
def _get_model_features(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
token_type_ids: Optional[torch.Tensor],
):
model_input = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"output_hidden_states": self.config.use_last_k_layers > 1,
}
if token_type_ids is not None:
model_input["token_type_ids"] = token_type_ids
model_output = self.transformer_model(**model_input)
if self.config.use_last_k_layers > 1:
model_features = torch.cat(
model_output[1][-self.config.use_last_k_layers :], dim=-1
)
else:
model_features = model_output[0]
return model_features
def compute_ned_end_logits(
self,
start_predictions,
start_labels,
model_features,
prediction_mask,
batch_size,
mask_preceding: bool = False,
) -> Optional[torch.Tensor]:
# todo: maybe when constraining on the spans,
# we should not use a prediction_mask for the end tokens.
# at least we should not during training imo
start_positions = start_labels if self.training else start_predictions
start_positions_indices = (
torch.arange(start_positions.size(1), device=start_positions.device)
.unsqueeze(0)
.expand(batch_size, -1)[start_positions > 0]
).to(start_positions.device)
if len(start_positions_indices) > 0:
expanded_features = model_features.repeat_interleave(
torch.sum(start_positions > 0, dim=-1), dim=0
)
expanded_prediction_mask = prediction_mask.repeat_interleave(
torch.sum(start_positions > 0, dim=-1), dim=0
)
if mask_preceding:
expanded_prediction_mask[
torch.arange(
expanded_prediction_mask.shape[1],
device=expanded_prediction_mask.device,
)
< start_positions_indices.unsqueeze(1)
] = 1
end_logits = self.ned_end_classifier(
hidden_states=expanded_features,
start_positions=start_positions_indices,
p_mask=expanded_prediction_mask,
)
return end_logits
return None
def compute_relation_logits(
self,
model_entity_features,
special_symbols_features,
) -> torch.Tensor:
model_subject_object_features = self.re_projector(model_entity_features)
model_subject_features = model_subject_object_features[
:, :, : model_subject_object_features.shape[-1] // 2
]
model_object_features = model_subject_object_features[
:, :, model_subject_object_features.shape[-1] // 2 :
]
special_symbols_start_representation = self.re_relation_projector(
special_symbols_features
)
re_logits = torch.einsum(
"bse,bde,bfe->bsdfe",
model_subject_features,
model_object_features,
special_symbols_start_representation,
)
re_logits = self.re_classifier(re_logits)
return re_logits
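    # Shape sketch for compute_relation_logits (explanatory note): the einsum "bse,bde,bfe->bsdfe"
    # multiplies subject features (b, s, e), object features (b, d, e) and relation-symbol
    # features (b, f, e) element-wise along the hidden dimension into a (b, s, d, f, e) tensor,
    # which re_classifier then maps to 2 logits per (subject, object, relation) triple.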
def compute_entity_logits(
self,
model_entity_features,
special_symbols_features,
) -> torch.Tensor:
model_ed_features = self.re_entities_projector(model_entity_features)
special_symbols_ed_representation = self.re_definition_projector(
special_symbols_features
)
logits = torch.bmm(
model_ed_features,
torch.permute(special_symbols_ed_representation, (0, 2, 1)),
)
logits = self._mask_logits(
logits, (model_entity_features == -100).all(2).long()
)
return logits
def compute_loss(self, logits, labels, mask=None):
logits = logits.reshape(-1, logits.shape[-1])
labels = labels.reshape(-1).long()
if mask is not None:
return self.criterion(logits[mask], labels[mask])
return self.criterion(logits, labels)
def compute_ned_type_loss(
self,
disambiguation_labels,
re_ned_entities_logits,
ned_type_logits,
re_entities_logits,
entity_types,
mask,
):
if self.config.entity_type_loss and self.relation_disambiguation_loss:
return self.criterion_type(
re_ned_entities_logits[disambiguation_labels != -100],
disambiguation_labels[disambiguation_labels != -100],
)
if self.config.entity_type_loss:
return self.criterion_type(
ned_type_logits[mask],
disambiguation_labels[:, :, :entity_types][mask],
)
if self.relation_disambiguation_loss:
return self.criterion_type(
re_entities_logits[disambiguation_labels != -100],
disambiguation_labels[disambiguation_labels != -100],
)
return 0
def compute_relation_loss(self, relation_labels, re_logits):
return self.compute_loss(
re_logits, relation_labels, relation_labels.view(-1) != -100
)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
token_type_ids: torch.Tensor,
prediction_mask: Optional[torch.Tensor] = None,
special_symbols_mask: Optional[torch.Tensor] = None,
special_symbols_mask_entities: Optional[torch.Tensor] = None,
start_labels: Optional[torch.Tensor] = None,
end_labels: Optional[torch.Tensor] = None,
disambiguation_labels: Optional[torch.Tensor] = None,
relation_labels: Optional[torch.Tensor] = None,
        relation_threshold: Optional[float] = None,
is_validation: bool = False,
is_prediction: bool = False,
use_predefined_spans: bool = False,
*args,
**kwargs,
) -> Dict[str, Any]:
        relation_threshold = (
            self.config.threshold if relation_threshold is None else relation_threshold
        )
batch_size = input_ids.shape[0]
model_features = self._get_model_features(
input_ids, attention_mask, token_type_ids
)
# named entity detection
if use_predefined_spans:
ned_start_logits, ned_start_probabilities, ned_start_predictions = (
None,
None,
torch.zeros_like(start_labels),
)
ned_end_logits, ned_end_probabilities, ned_end_predictions = (
None,
None,
torch.zeros_like(end_labels),
)
ned_start_predictions[start_labels > 0] = 1
ned_end_predictions[end_labels > 0] = 1
ned_end_predictions = ned_end_predictions[~(end_labels == -100).all(2)]
ned_start_labels = start_labels
ned_start_labels[start_labels > 0] = 1
else:
# start boundary prediction
ned_start_logits = self.ned_start_classifier(model_features)
if is_validation or is_prediction:
ned_start_logits = self._mask_logits(
ned_start_logits, prediction_mask
) # why?
ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
ned_start_predictions = ned_start_probabilities.argmax(dim=-1)
# end boundary prediction
ned_start_labels = (
torch.zeros_like(start_labels) if start_labels is not None else None
)
# start_labels contain entity id at their position, we just need 1 for start of entity
if ned_start_labels is not None:
ned_start_labels[start_labels == -100] = -100
ned_start_labels[start_labels > 0] = 1
# compute end logits only if there are any start predictions.
# For each start prediction, n end predictions are made
ned_end_logits = self.compute_ned_end_logits(
ned_start_predictions,
ned_start_labels,
model_features,
prediction_mask,
batch_size,
True,
)
if ned_end_logits is not None:
                # For each start prediction, n end predictions are made based on
                # binary classification, i.e. argmax at each position.
ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
ned_end_predictions = ned_end_probabilities.argmax(dim=-1)
else:
ned_end_logits, ned_end_probabilities = None, None
ned_end_predictions = torch.zeros_like(ned_start_predictions)
if is_prediction or is_validation:
end_preds_count = ned_end_predictions.sum(1)
# If there are no end predictions for a start prediction, remove the start prediction
if (end_preds_count == 0).any() and (ned_start_predictions > 0).any():
ned_start_predictions[ned_start_predictions == 1] = (
end_preds_count != 0
).long()
ned_end_predictions = ned_end_predictions[end_preds_count != 0]
if end_labels is not None:
end_labels = end_labels[~(end_labels == -100).all(2)]
start_position, end_position = (
(start_labels, end_labels)
if (not is_prediction and not is_validation)
else (ned_start_predictions, ned_end_predictions)
)
start_counts = (start_position > 0).sum(1)
if (start_counts > 0).any():
ned_end_predictions = ned_end_predictions.split(start_counts.tolist())
        # limit to 30 predictions per document using start_counts, by zeroing all start positions
        # once the cumulative count reaches 30 (disabled below)
# if is_validation or is_prediction:
# ned_start_predictions[ned_start_predictions == 1] = start_counts
# We can only predict relations if we have start and end predictions
if (end_position > 0).sum() > 0:
ends_count = (end_position > 0).sum(1)
model_subject_features = torch.cat(
[
torch.repeat_interleave(
model_features[start_position > 0], ends_count, dim=0
), # start position features
torch.repeat_interleave(model_features, start_counts, dim=0)[
end_position > 0
], # end position features
],
dim=-1,
)
ents_count = torch.nn.utils.rnn.pad_sequence(
torch.split(ends_count, start_counts.tolist()),
batch_first=True,
padding_value=0,
).sum(1)
model_subject_features = torch.nn.utils.rnn.pad_sequence(
torch.split(model_subject_features, ents_count.tolist()),
batch_first=True,
padding_value=-100,
)
# if is_validation or is_prediction:
# model_subject_features = model_subject_features[:, :30, :]
# entity disambiguation. Here relation_disambiguation_loss would only be useful to
# reduce the number of candidate relations for the next step, but currently unused.
if self.config.entity_type_loss or self.relation_disambiguation_loss:
(re_ned_entities_logits) = self.compute_entity_logits(
model_subject_features,
model_features[
special_symbols_mask | special_symbols_mask_entities
].view(batch_size, -1, model_features.shape[-1]),
)
entity_types = torch.sum(special_symbols_mask_entities, dim=1)[0].item()
ned_type_logits = re_ned_entities_logits[:, :, :entity_types]
re_entities_logits = re_ned_entities_logits[:, :, entity_types:]
if self.config.entity_type_loss:
ned_type_probabilities = torch.sigmoid(ned_type_logits)
ned_type_predictions = ned_type_probabilities.argmax(dim=-1)
if self.config.add_entity_embedding:
special_symbols_representation = model_features[
special_symbols_mask_entities
].view(batch_size, entity_types, -1)
entities_representation = torch.einsum(
"bsp,bpe->bse",
ned_type_probabilities,
special_symbols_representation,
)
model_subject_features = torch.cat(
[model_subject_features, entities_representation], dim=-1
)
re_entities_probabilities = torch.sigmoid(re_entities_logits)
re_entities_predictions = re_entities_probabilities.round()
else:
(
ned_type_logits,
ned_type_probabilities,
re_entities_logits,
re_entities_probabilities,
) = (None, None, None, None)
ned_type_predictions, re_entities_predictions = (
torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
)
# Compute relation logits
re_logits = self.compute_relation_logits(
model_subject_features,
model_features[special_symbols_mask].view(
batch_size, -1, model_features.shape[-1]
),
)
re_probabilities = torch.softmax(re_logits, dim=-1)
            # we use a threshold instead of argmax here because it may need to be tweaked
re_predictions = re_probabilities[:, :, :, :, 1] > relation_threshold
re_probabilities = re_probabilities[:, :, :, :, 1]
else:
(
ned_type_logits,
ned_type_probabilities,
re_entities_logits,
re_entities_probabilities,
) = (None, None, None, None)
ned_type_predictions, re_entities_predictions = (
torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
)
re_logits, re_probabilities, re_predictions = (
torch.zeros(
[batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
).to(input_ids.device),
torch.zeros(
[batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
).to(input_ids.device),
torch.zeros(
[batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
).to(input_ids.device),
)
# output build
output_dict = dict(
batch_size=batch_size,
ned_start_logits=ned_start_logits,
ned_start_probabilities=ned_start_probabilities,
ned_start_predictions=ned_start_predictions,
ned_end_logits=ned_end_logits,
ned_end_probabilities=ned_end_probabilities,
ned_end_predictions=ned_end_predictions,
ned_type_logits=ned_type_logits,
ned_type_probabilities=ned_type_probabilities,
ned_type_predictions=ned_type_predictions,
re_entities_logits=re_entities_logits,
re_entities_probabilities=re_entities_probabilities,
re_entities_predictions=re_entities_predictions,
re_logits=re_logits,
re_probabilities=re_probabilities,
re_predictions=re_predictions,
)
if (
start_labels is not None
and end_labels is not None
and relation_labels is not None
and is_prediction is False
):
ned_start_loss = self.compute_loss(ned_start_logits, ned_start_labels)
end_labels[end_labels > 0] = 1
ned_end_loss = self.compute_loss(ned_end_logits, end_labels)
if self.config.entity_type_loss or self.relation_disambiguation_loss:
ned_type_loss = self.compute_ned_type_loss(
disambiguation_labels,
re_ned_entities_logits,
ned_type_logits,
re_entities_logits,
entity_types,
(model_subject_features != -100).all(2),
)
relation_loss = self.compute_relation_loss(relation_labels, re_logits)
# compute loss. We can skip the relation loss if we are in the first epochs (optional)
if self.config.entity_type_loss or self.relation_disambiguation_loss:
output_dict["loss"] = (
ned_start_loss + ned_end_loss + relation_loss + ned_type_loss
) / 4
output_dict["ned_type_loss"] = ned_type_loss
else:
# output_dict["loss"] = ((1 / 4) * (ned_start_loss + ned_end_loss)) + (
# (1 / 2) * relation_loss
# )
output_dict["loss"] = ((1 / 20) * (ned_start_loss + ned_end_loss)) + (
(9 / 10) * relation_loss
)
output_dict["ned_start_loss"] = ned_start_loss
output_dict["ned_end_loss"] = ned_end_loss
output_dict["re_loss"] = relation_loss
return output_dict
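# Usage sketch for RelikReaderREModel (hypothetical identifiers; tokenization, collation and the
# batch keys below are assumptions, not defined in this file):
#   config = RelikReaderConfig.from_pretrained("path/to/re-reader")
#   model = RelikReaderREModel(config)
#   out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"],
#               token_type_ids=batch["token_type_ids"], prediction_mask=batch["prediction_mask"],
#               special_symbols_mask=batch["special_symbols_mask"], is_prediction=True)
#   # out["re_predictions"] is a boolean (batch, subjects, objects, relations) tensor obtained by
#   # thresholding out["re_probabilities"] with the configured relation threshold.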