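"""modeling_relik.py - PyTorch modeling code for the ReLiK reader.

This module defines two ``PreTrainedModel`` subclasses that share a Hugging Face
transformer encoder:

* ``RelikReaderSpanModel``: predicts entity-mention spans (named entity detection
  start/end boundaries) and disambiguates each span against candidate entities
  represented by special symbols in the input sequence.
* ``RelikReaderREModel``: additionally scores subject/object span pairs against
  candidate relations (and, optionally, entity types) for relation extraction.
"""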
from typing import Optional, Dict, Any

import torch
from transformers import AutoModel, PreTrainedModel
from transformers.activations import GELUActivation, ClippedGELUActivation
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PoolerEndLogits

from .configuration_relik import RelikReaderConfig


class RelikReaderSample:
    def __init__(self, **kwargs):
        super().__setattr__("_d", {})
        self._d = kwargs

    def __getattribute__(self, item):
        return super(RelikReaderSample, self).__getattribute__(item)

    def __getattr__(self, item):
        if item.startswith("__") and item.endswith("__"):
            # this is likely some python library-specific variable (such as __deepcopy__ for copy)
            # better follow standard behavior here
            raise AttributeError(item)
        elif item in self._d:
            return self._d[item]
        else:
            return None

    def __setattr__(self, key, value):
        if key in self._d:
            self._d[key] = value
        else:
            super().__setattr__(key, value)


activation2functions = {
    "relu": torch.nn.ReLU(),
    "gelu": GELUActivation(),
    "gelu_10": ClippedGELUActivation(-10, 10),
}


class PoolerEndLogitsBi(PoolerEndLogits):
    def __init__(self, config: PretrainedConfig):
        super().__init__(config)
        self.dense_1 = torch.nn.Linear(config.hidden_size, 2)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        start_states: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        p_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        if p_mask is not None:
            p_mask = p_mask.unsqueeze(-1)
        logits = super().forward(
            hidden_states,
            start_states,
            start_positions,
            p_mask,
        )
        return logits


class RelikReaderSpanModel(PreTrainedModel):
    config_class = RelikReaderConfig

    def __init__(self, config: RelikReaderConfig, *args, **kwargs):
        super().__init__(config)
        # Transformer model declaration
        self.config = config
        self.transformer_model = (
            AutoModel.from_pretrained(self.config.transformer_model)
            if self.config.num_layers is None
            else AutoModel.from_pretrained(
                self.config.transformer_model, num_hidden_layers=self.config.num_layers
            )
        )
        self.transformer_model.resize_token_embeddings(
            self.transformer_model.config.vocab_size
            + self.config.additional_special_symbols
        )

        self.activation = self.config.activation
        self.linears_hidden_size = self.config.linears_hidden_size
        self.use_last_k_layers = self.config.use_last_k_layers

        # named entity detection layers
        self.ned_start_classifier = self._get_projection_layer(
            self.activation, last_hidden=2, layer_norm=False
        )
        self.ned_end_classifier = PoolerEndLogits(self.transformer_model.config)

        # END entity disambiguation layer
        self.ed_start_projector = self._get_projection_layer(self.activation)
        self.ed_end_projector = self._get_projection_layer(self.activation)

        self.training = self.config.training

        # criterion
        self.criterion = torch.nn.CrossEntropyLoss()

    def _get_projection_layer(
        self,
        activation: str,
        last_hidden: Optional[int] = None,
        input_hidden=None,
        layer_norm: bool = True,
    ) -> torch.nn.Sequential:
        head_components = [
            torch.nn.Dropout(0.1),
            torch.nn.Linear(
                self.transformer_model.config.hidden_size * self.use_last_k_layers
                if input_hidden is None
                else input_hidden,
                self.linears_hidden_size,
            ),
            activation2functions[activation],
            torch.nn.Dropout(0.1),
            torch.nn.Linear(
                self.linears_hidden_size,
                self.linears_hidden_size if last_hidden is None else last_hidden,
            ),
        ]

        if layer_norm:
            head_components.append(
                torch.nn.LayerNorm(
                    self.linears_hidden_size if last_hidden is None else last_hidden,
                    self.transformer_model.config.layer_norm_eps,
                )
            )

        return torch.nn.Sequential(*head_components)

    def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        mask = mask.unsqueeze(-1)
        if next(self.parameters()).dtype == torch.float16:
            logits = logits * (1 - mask) - 65500 * mask
        else:
            logits = logits * (1 - mask) - 1e30 * mask
        return logits

    def _get_model_features(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor],
    ):
        model_input = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "output_hidden_states": self.use_last_k_layers > 1,
        }

        if token_type_ids is not None:
            model_input["token_type_ids"] = token_type_ids

        model_output = self.transformer_model(**model_input)

        if self.use_last_k_layers > 1:
            model_features = torch.cat(
                model_output[1][-self.use_last_k_layers :], dim=-1
            )
        else:
            model_features = model_output[0]

        return model_features

    def compute_ned_end_logits(
        self,
        start_predictions,
        start_labels,
        model_features,
        prediction_mask,
        batch_size,
    ) -> Optional[torch.Tensor]:
        # todo: maybe when constraining on the spans,
        # we should not use a prediction_mask for the end tokens.
        # at least we should not during training imo
        start_positions = start_labels if self.training else start_predictions
        start_positions_indices = (
            torch.arange(start_positions.size(1), device=start_positions.device)
            .unsqueeze(0)
            .expand(batch_size, -1)[start_positions > 0]
        ).to(start_positions.device)

        if len(start_positions_indices) > 0:
            expanded_features = torch.cat(
                [
                    model_features[i].unsqueeze(0).expand(x, -1, -1)
                    for i, x in enumerate(torch.sum(start_positions > 0, dim=-1))
                    if x > 0
                ],
                dim=0,
            ).to(start_positions_indices.device)

            expanded_prediction_mask = torch.cat(
                [
                    prediction_mask[i].unsqueeze(0).expand(x, -1)
                    for i, x in enumerate(torch.sum(start_positions > 0, dim=-1))
                    if x > 0
                ],
                dim=0,
            ).to(expanded_features.device)

            end_logits = self.ned_end_classifier(
                hidden_states=expanded_features,
                start_positions=start_positions_indices,
                p_mask=expanded_prediction_mask,
            )

            return end_logits

        return None

    def compute_classification_logits(
        self,
        model_features,
        special_symbols_mask,
        prediction_mask,
        batch_size,
        start_positions=None,
        end_positions=None,
    ) -> torch.Tensor:
        if start_positions is None or end_positions is None:
            start_positions = torch.zeros_like(prediction_mask)
            end_positions = torch.zeros_like(prediction_mask)

        model_start_features = self.ed_start_projector(model_features)
        model_end_features = self.ed_end_projector(model_features)
        model_end_features[start_positions > 0] = model_end_features[end_positions > 0]

        model_ed_features = torch.cat(
            [model_start_features, model_end_features], dim=-1
        )

        # computing ed features
        classes_representations = torch.sum(special_symbols_mask, dim=1)[0].item()
        special_symbols_representation = model_ed_features[special_symbols_mask].view(
            batch_size, classes_representations, -1
        )

        logits = torch.bmm(
            model_ed_features,
            torch.permute(special_symbols_representation, (0, 2, 1)),
        )

        logits = self._mask_logits(logits, prediction_mask)

        return logits

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        start_labels: Optional[torch.Tensor] = None,
        end_labels: Optional[torch.Tensor] = None,
        use_predefined_spans: bool = False,
        *args,
        **kwargs,
    ) -> Dict[str, Any]:
        batch_size, seq_len = input_ids.shape

        model_features = self._get_model_features(
            input_ids, attention_mask, token_type_ids
        )

        ned_start_labels = None

        # named entity detection if required
        if use_predefined_spans:  # no need to compute spans
            ned_start_logits, ned_start_probabilities, ned_start_predictions = (
                None,
                None,
                torch.clone(start_labels)
                if start_labels is not None
                else torch.zeros_like(input_ids),
            )
            ned_end_logits, ned_end_probabilities, ned_end_predictions = (
                None,
                None,
                torch.clone(end_labels)
                if end_labels is not None
                else torch.zeros_like(input_ids),
            )

            ned_start_predictions[ned_start_predictions > 0] = 1
            ned_end_predictions[ned_end_predictions > 0] = 1

        else:  # compute spans
            # start boundary prediction
            ned_start_logits = self.ned_start_classifier(model_features)
            ned_start_logits = self._mask_logits(ned_start_logits, prediction_mask)
            ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
            ned_start_predictions = ned_start_probabilities.argmax(dim=-1)

            # end boundary prediction
            ned_start_labels = (
                torch.zeros_like(start_labels) if start_labels is not None else None
            )

            if ned_start_labels is not None:
                ned_start_labels[start_labels == -100] = -100
                ned_start_labels[start_labels > 0] = 1

            ned_end_logits = self.compute_ned_end_logits(
                ned_start_predictions,
                ned_start_labels,
                model_features,
                prediction_mask,
                batch_size,
            )

            if ned_end_logits is not None:
                ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
                ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1)
            else:
                ned_end_logits, ned_end_probabilities = None, None
                ned_end_predictions = ned_start_predictions.new_zeros(batch_size)

            # flattening end predictions
            # (flattening can happen only if the
            # end boundaries were not predicted using the gold labels)
            if not self.training:
                flattened_end_predictions = torch.clone(ned_start_predictions)
                flattened_end_predictions[flattened_end_predictions > 0] = 0

                batch_start_predictions = list()
                for elem_idx in range(batch_size):
                    batch_start_predictions.append(
                        torch.where(ned_start_predictions[elem_idx] > 0)[0].tolist()
                    )

                # check that the total number of start predictions
                # is equal to the end predictions
                total_start_predictions = sum(map(len, batch_start_predictions))
                total_end_predictions = len(ned_end_predictions)
                assert (
                    total_start_predictions == 0
                    or total_start_predictions == total_end_predictions
                ), (
                    f"Total number of start predictions = {total_start_predictions}. "
                    f"Total number of end predictions = {total_end_predictions}"
                )

                curr_end_pred_num = 0
                for elem_idx, bsp in enumerate(batch_start_predictions):
                    for sp in bsp:
                        ep = ned_end_predictions[curr_end_pred_num].item()
                        if ep < sp:
                            ep = sp

                        # if this end position is already taken by a previous span,
                        # discard the new span (no overlaps allowed)
                        if flattened_end_predictions[elem_idx, ep] == 1:
                            ned_start_predictions[elem_idx, sp] = 0
                        else:
                            flattened_end_predictions[elem_idx, ep] = 1

                        curr_end_pred_num += 1

                ned_end_predictions = flattened_end_predictions

        start_position, end_position = (
            (start_labels, end_labels)
            if self.training
            else (ned_start_predictions, ned_end_predictions)
        )

        # Entity disambiguation
        ed_logits = self.compute_classification_logits(
            model_features,
            special_symbols_mask,
            prediction_mask,
            batch_size,
            start_position,
            end_position,
        )
        ed_probabilities = torch.softmax(ed_logits, dim=-1)
        ed_predictions = torch.argmax(ed_probabilities, dim=-1)

        # output build
        output_dict = dict(
            batch_size=batch_size,
            ned_start_logits=ned_start_logits,
            ned_start_probabilities=ned_start_probabilities,
            ned_start_predictions=ned_start_predictions,
            ned_end_logits=ned_end_logits,
            ned_end_probabilities=ned_end_probabilities,
            ned_end_predictions=ned_end_predictions,
            ed_logits=ed_logits,
            ed_probabilities=ed_probabilities,
            ed_predictions=ed_predictions,
        )

        # compute loss if labels
        if start_labels is not None and end_labels is not None and self.training:
            # named entity detection loss

            # start
            if ned_start_logits is not None:
                ned_start_loss = self.criterion(
                    ned_start_logits.view(-1, ned_start_logits.shape[-1]),
                    ned_start_labels.view(-1),
                )
            else:
                ned_start_loss = 0

            # end
            if ned_end_logits is not None:
                ned_end_labels = torch.zeros_like(end_labels)
                ned_end_labels[end_labels == -100] = -100
                ned_end_labels[end_labels > 0] = 1

                ned_end_loss = self.criterion(
                    ned_end_logits,
                    (
                        torch.arange(
                            ned_end_labels.size(1), device=ned_end_labels.device
                        )
                        .unsqueeze(0)
                        .expand(batch_size, -1)[ned_end_labels > 0]
                    ).to(ned_end_labels.device),
                )
            else:
                ned_end_loss = 0

            # entity disambiguation loss
            start_labels[ned_start_labels != 1] = -100
            ed_labels = torch.clone(start_labels)
            ed_labels[end_labels > 0] = end_labels[end_labels > 0]
            ed_loss = self.criterion(
                ed_logits.view(-1, ed_logits.shape[-1]),
                ed_labels.view(-1),
            )

            output_dict["ned_start_loss"] = ned_start_loss
            output_dict["ned_end_loss"] = ned_end_loss
            output_dict["ed_loss"] = ed_loss

            output_dict["loss"] = ned_start_loss + ned_end_loss + ed_loss

        return output_dict


class RelikReaderREModel(PreTrainedModel):
    config_class = RelikReaderConfig

    def __init__(self, config, *args, **kwargs):
        super().__init__(config)
        # Transformer model declaration
        # self.transformer_model_name = transformer_model
        self.config = config
        self.transformer_model = (
            AutoModel.from_pretrained(config.transformer_model)
            if config.num_layers is None
            else AutoModel.from_pretrained(
                config.transformer_model, num_hidden_layers=config.num_layers
            )
        )
        self.transformer_model.resize_token_embeddings(
            self.transformer_model.config.vocab_size + config.additional_special_symbols
        )

        # named entity detection layers
        self.ned_start_classifier = self._get_projection_layer(
            config.activation, last_hidden=2, layer_norm=False
        )

        self.ned_end_classifier = PoolerEndLogitsBi(self.transformer_model.config)

        self.entity_type_loss = (
            config.entity_type_loss if hasattr(config, "entity_type_loss") else False
        )
        self.relation_disambiguation_loss = (
            config.relation_disambiguation_loss
            if hasattr(config, "relation_disambiguation_loss")
            else False
        )

        input_hidden_ents = 2 * self.transformer_model.config.hidden_size

        self.re_subject_projector = self._get_projection_layer(
            config.activation, input_hidden=input_hidden_ents
        )
        self.re_object_projector = self._get_projection_layer(
            config.activation, input_hidden=input_hidden_ents
        )
        self.re_relation_projector = self._get_projection_layer(config.activation)

        if self.entity_type_loss or self.relation_disambiguation_loss:
            self.re_entities_projector = self._get_projection_layer(
                config.activation,
                input_hidden=2 * self.transformer_model.config.hidden_size,
            )
            self.re_definition_projector = self._get_projection_layer(
                config.activation,
            )

        self.re_classifier = self._get_projection_layer(
            config.activation,
            input_hidden=config.linears_hidden_size,
            last_hidden=2,
            layer_norm=False,
        )

        if self.entity_type_loss or self.relation_disambiguation_loss:
            self.re_ed_classifier = self._get_projection_layer(
                config.activation,
                input_hidden=config.linears_hidden_size,
                last_hidden=2,
                layer_norm=False,
            )

        self.training = config.training

        # criterion
        self.criterion = torch.nn.CrossEntropyLoss()

    def _get_projection_layer(
        self,
        activation: str,
        last_hidden: Optional[int] = None,
        input_hidden=None,
        layer_norm: bool = True,
    ) -> torch.nn.Sequential:
        head_components = [
            torch.nn.Dropout(0.1),
            torch.nn.Linear(
                self.transformer_model.config.hidden_size
                * self.config.use_last_k_layers
                if input_hidden is None
                else input_hidden,
                self.config.linears_hidden_size,
            ),
            activation2functions[activation],
            torch.nn.Dropout(0.1),
            torch.nn.Linear(
                self.config.linears_hidden_size,
                self.config.linears_hidden_size if last_hidden is None else last_hidden,
            ),
        ]

        if layer_norm:
            head_components.append(
                torch.nn.LayerNorm(
                    self.config.linears_hidden_size
                    if last_hidden is None
                    else last_hidden,
                    self.transformer_model.config.layer_norm_eps,
                )
            )

        return torch.nn.Sequential(*head_components)

    def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        mask = mask.unsqueeze(-1)
        if next(self.parameters()).dtype == torch.float16:
            logits = logits * (1 - mask) - 65500 * mask
        else:
            logits = logits * (1 - mask) - 1e30 * mask
        return logits

    def _get_model_features(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor],
    ):
        model_input = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "output_hidden_states": self.config.use_last_k_layers > 1,
        }

        if token_type_ids is not None:
            model_input["token_type_ids"] = token_type_ids

        model_output = self.transformer_model(**model_input)

        if self.config.use_last_k_layers > 1:
            model_features = torch.cat(
                model_output[1][-self.config.use_last_k_layers :], dim=-1
            )
        else:
            model_features = model_output[0]

        return model_features

    def compute_ned_end_logits(
        self,
        start_predictions,
        start_labels,
        model_features,
        prediction_mask,
        batch_size,
    ) -> Optional[torch.Tensor]:
        # todo: maybe when constraining on the spans,
        # we should not use a prediction_mask for the end tokens.
        # at least we should not during training imo
        start_positions = start_labels if self.training else start_predictions
        start_positions_indices = (
            torch.arange(start_positions.size(1), device=start_positions.device)
            .unsqueeze(0)
            .expand(batch_size, -1)[start_positions > 0]
        ).to(start_positions.device)

        if len(start_positions_indices) > 0:
            expanded_features = torch.cat(
                [
                    model_features[i].unsqueeze(0).expand(x, -1, -1)
                    for i, x in enumerate(torch.sum(start_positions > 0, dim=-1))
                    if x > 0
                ],
                dim=0,
            ).to(start_positions_indices.device)

            expanded_prediction_mask = torch.cat(
                [
                    prediction_mask[i].unsqueeze(0).expand(x, -1)
                    for i, x in enumerate(torch.sum(start_positions > 0, dim=-1))
                    if x > 0
                ],
                dim=0,
            ).to(expanded_features.device)

            # mask all tokens that come before the corresponding start index, i.e.
            # set positions with index < start_positions_indices[i] to 1
            # (equivalent to [range(x) for x in start_positions_indices])
            expanded_prediction_mask = torch.stack(
                [
                    torch.cat(
                        [
                            torch.ones(x, device=expanded_features.device),
                            expanded_prediction_mask[i, x:],
                        ]
                    )
                    for i, x in enumerate(start_positions_indices)
                    if x > 0
                ],
                dim=0,
            ).to(expanded_features.device)

            end_logits = self.ned_end_classifier(
                hidden_states=expanded_features,
                start_positions=start_positions_indices,
                p_mask=expanded_prediction_mask,
            )

            return end_logits

        return None

    def compute_relation_logits(
        self,
        model_entity_features,
        special_symbols_features,
    ) -> torch.Tensor:
        model_subject_features = self.re_subject_projector(model_entity_features)
        model_object_features = self.re_object_projector(model_entity_features)
        special_symbols_start_representation = self.re_relation_projector(
            special_symbols_features
        )
        re_logits = torch.einsum(
            "bse,bde,bfe->bsdfe",
            model_subject_features,
            model_object_features,
            special_symbols_start_representation,
        )
        re_logits = self.re_classifier(re_logits)

        return re_logits

    def compute_entity_logits(
        self,
        model_entity_features,
        special_symbols_features,
    ) -> torch.Tensor:
        model_ed_features = self.re_entities_projector(model_entity_features)
        special_symbols_ed_representation = self.re_definition_projector(
            special_symbols_features
        )

        logits = torch.einsum(
            "bce,bde->bcde",
            model_ed_features,
            special_symbols_ed_representation,
        )
        logits = self.re_ed_classifier(logits)
        # mask padded entity rows (rows filled with -100), repeating the mask across
        # the candidate dimension (one column per special symbol)
        logits = self._mask_logits(
            logits,
            (model_entity_features == -100)
            .all(2)
            .long()
            .unsqueeze(2)
            .repeat(1, 1, special_symbols_features.shape[1]),
        )
        return logits

    def compute_loss(self, logits, labels, mask=None):
        logits = logits.view(-1, logits.shape[-1])
        labels = labels.view(-1).long()
        if mask is not None:
            return self.criterion(logits[mask], labels[mask])
        return self.criterion(logits, labels)

    def compute_ned_end_loss(self, ned_end_logits, end_labels):
        if ned_end_logits is None:
            return 0
        ned_end_labels = torch.zeros_like(end_labels)
        ned_end_labels[end_labels == -100] = -100
        ned_end_labels[end_labels > 0] = 1
        return self.compute_loss(ned_end_logits, ned_end_labels)

    def compute_ned_type_loss(
        self,
        disambiguation_labels,
        re_ned_entities_logits,
        ned_type_logits,
        re_entities_logits,
        entity_types,
    ):
        # compute_loss expects (logits, labels), so the logits tensor is passed
        # first and the disambiguation labels second
        if self.entity_type_loss and self.relation_disambiguation_loss:
            return self.compute_loss(re_ned_entities_logits, disambiguation_labels)
        if self.entity_type_loss:
            return self.compute_loss(
                ned_type_logits, disambiguation_labels[:, :, :entity_types]
            )
        if self.relation_disambiguation_loss:
            return self.compute_loss(re_entities_logits, disambiguation_labels)
        return 0

    def compute_relation_loss(self, relation_labels, re_logits):
        return self.compute_loss(
            re_logits, relation_labels, relation_labels.view(-1) != -100
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        special_symbols_mask_entities: Optional[torch.Tensor] = None,
        start_labels: Optional[torch.Tensor] = None,
        end_labels: Optional[torch.Tensor] = None,
        disambiguation_labels: Optional[torch.Tensor] = None,
        relation_labels: Optional[torch.Tensor] = None,
        is_validation: bool = False,
        is_prediction: bool = False,
        *args,
        **kwargs,
    ) -> Dict[str, Any]:
        batch_size = input_ids.shape[0]

        model_features = self._get_model_features(
            input_ids, attention_mask, token_type_ids
        )

        # named entity detection
        if is_prediction and start_labels is not None:
            ned_start_logits, ned_start_probabilities, ned_start_predictions = (
                None,
                None,
                torch.zeros_like(start_labels),
            )
            ned_end_logits, ned_end_probabilities, ned_end_predictions = (
                None,
                None,
                torch.zeros_like(end_labels),
            )

            ned_start_predictions[start_labels > 0] = 1
            ned_end_predictions[end_labels > 0] = 1
            ned_end_predictions = ned_end_predictions[~(end_labels == -100).all(2)]
        else:
            # start boundary prediction
            ned_start_logits = self.ned_start_classifier(model_features)
            ned_start_logits = self._mask_logits(
                ned_start_logits, prediction_mask
            )  # why?
            ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
            ned_start_predictions = ned_start_probabilities.argmax(dim=-1)

            # end boundary prediction
            ned_start_labels = (
                torch.zeros_like(start_labels) if start_labels is not None else None
            )

            # start_labels contain the entity id at each start position;
            # here we only need a 1 marking the start of an entity
            if ned_start_labels is not None:
                ned_start_labels[start_labels > 0] = 1

            # compute end logits only if there are any start predictions.
            # For each start prediction, n end predictions are made
            ned_end_logits = self.compute_ned_end_logits(
                ned_start_predictions,
                ned_start_labels,
                model_features,
                prediction_mask,
                batch_size,
            )

            # For each start prediction, n end predictions are made based on
            # binary classification, i.e. argmax at each position.
            ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
            ned_end_predictions = ned_end_probabilities.argmax(dim=-1)

            if is_prediction or is_validation:
                end_preds_count = ned_end_predictions.sum(1)
                # If there are no end predictions for a start prediction, remove the start prediction
                ned_start_predictions[ned_start_predictions == 1] = (
                    end_preds_count != 0
                ).long()
                ned_end_predictions = ned_end_predictions[end_preds_count != 0]

        if end_labels is not None:
            end_labels = end_labels[~(end_labels == -100).all(2)]

        start_position, end_position = (
            (start_labels, end_labels)
            if (not is_prediction and not is_validation)
            else (ned_start_predictions, ned_end_predictions)
        )

        start_counts = (start_position > 0).sum(1)
        ned_end_predictions = ned_end_predictions.split(start_counts.tolist())

        # We can only predict relations if we have start and end predictions
        if (end_position > 0).sum() > 0:
            ends_count = (end_position > 0).sum(1)
            model_subject_features = torch.cat(
                [
                    torch.repeat_interleave(
                        model_features[start_position > 0], ends_count, dim=0
                    ),  # start position features
                    torch.repeat_interleave(model_features, start_counts, dim=0)[
                        end_position > 0
                    ],  # end position features
                ],
                dim=-1,
            )
            ents_count = torch.nn.utils.rnn.pad_sequence(
                torch.split(ends_count, start_counts.tolist()),
                batch_first=True,
                padding_value=0,
            ).sum(1)
            model_subject_features = torch.nn.utils.rnn.pad_sequence(
                torch.split(model_subject_features, ents_count.tolist()),
                batch_first=True,
                padding_value=-100,
            )

            if is_validation or is_prediction:
                model_subject_features = model_subject_features[:, :30, :]

            # entity disambiguation. Here relation_disambiguation_loss would only be useful
            # to reduce the number of candidate relations for the next step, but it is
            # currently unused.
            if self.entity_type_loss or self.relation_disambiguation_loss:
                re_ned_entities_logits = self.compute_entity_logits(
                    model_subject_features,
                    model_features[
                        special_symbols_mask | special_symbols_mask_entities
                    ].view(batch_size, -1, model_features.shape[-1]),
                )
                entity_types = torch.sum(special_symbols_mask_entities, dim=1)[0].item()
                ned_type_logits = re_ned_entities_logits[:, :, :entity_types]
                re_entities_logits = re_ned_entities_logits[:, :, entity_types:]

                if self.entity_type_loss:
                    ned_type_probabilities = torch.softmax(ned_type_logits, dim=-1)
                    ned_type_predictions = ned_type_probabilities.argmax(dim=-1)
                    ned_type_predictions = ned_type_predictions.argmax(dim=-1)

                re_entities_probabilities = torch.softmax(re_entities_logits, dim=-1)
                re_entities_predictions = re_entities_probabilities.argmax(dim=-1)
            else:
                (
                    ned_type_logits,
                    ned_type_probabilities,
                    re_entities_logits,
                    re_entities_probabilities,
                ) = (None, None, None, None)
                ned_type_predictions, re_entities_predictions = (
                    torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
                    torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
                )

            # Compute relation logits
            re_logits = self.compute_relation_logits(
                model_subject_features,
                model_features[special_symbols_mask].view(
                    batch_size, -1, model_features.shape[-1]
                ),
            )

            re_probabilities = torch.softmax(re_logits, dim=-1)
            # we use a threshold instead of argmax in case it needs to be tweaked
            re_predictions = re_probabilities[:, :, :, :, 1] > 0.5
            # re_predictions = re_probabilities.argmax(dim=-1)
            re_probabilities = re_probabilities[:, :, :, :, 1]
        else:
            (
                ned_type_logits,
                ned_type_probabilities,
                re_entities_logits,
                re_entities_probabilities,
            ) = (None, None, None, None)
            ned_type_predictions, re_entities_predictions = (
                torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
                torch.zeros([batch_size, 1], dtype=torch.long).to(input_ids.device),
            )
            re_logits, re_probabilities, re_predictions = (
                torch.zeros(
                    [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
                ).to(input_ids.device),
                torch.zeros(
                    [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
                ).to(input_ids.device),
                torch.zeros(
                    [batch_size, 1, 1, special_symbols_mask.sum(1)[0]], dtype=torch.long
                ).to(input_ids.device),
            )

        # output build
        output_dict = dict(
            batch_size=batch_size,
            ned_start_logits=ned_start_logits,
            ned_start_probabilities=ned_start_probabilities,
            ned_start_predictions=ned_start_predictions,
            ned_end_logits=ned_end_logits,
            ned_end_probabilities=ned_end_probabilities,
            ned_end_predictions=ned_end_predictions,
            ned_type_logits=ned_type_logits,
            ned_type_probabilities=ned_type_probabilities,
            ned_type_predictions=ned_type_predictions,
            re_entities_logits=re_entities_logits,
            re_entities_probabilities=re_entities_probabilities,
            re_entities_predictions=re_entities_predictions,
            re_logits=re_logits,
            re_probabilities=re_probabilities,
            re_predictions=re_predictions,
        )

        if (
            start_labels is not None
            and end_labels is not None
            and relation_labels is not None
        ):
            ned_start_loss = self.compute_loss(ned_start_logits, ned_start_labels)
            ned_end_loss = self.compute_ned_end_loss(ned_end_logits, end_labels)
            if self.entity_type_loss or self.relation_disambiguation_loss:
                ned_type_loss = self.compute_ned_type_loss(
                    disambiguation_labels,
                    re_ned_entities_logits,
                    ned_type_logits,
                    re_entities_logits,
                    entity_types,
                )
            relation_loss = self.compute_relation_loss(relation_labels, re_logits)
            # compute loss. We can skip the relation loss if we are in the first epochs (optional)
            if self.entity_type_loss or self.relation_disambiguation_loss:
                output_dict["loss"] = (
                    ned_start_loss + ned_end_loss + relation_loss + ned_type_loss
                ) / 4
                output_dict["ned_type_loss"] = ned_type_loss
            else:
                output_dict["loss"] = (
                    ned_start_loss + ned_end_loss + relation_loss
                ) / 3
            output_dict["ned_start_loss"] = ned_start_loss
            output_dict["ned_end_loss"] = ned_end_loss
            output_dict["re_loss"] = relation_loss

        return output_dict
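

if __name__ == "__main__":
    # Illustrative usage sketch (not part of the original modeling file): a minimal
    # smoke test of RelikReaderSpanModel on random inputs. The checkpoint path and
    # the tensor shapes below are placeholders / assumptions; a real checkpoint
    # directory ships its own config.json (RelikReaderConfig) and trained weights.
    checkpoint_dir = "path/to/relik-reader-checkpoint"  # hypothetical local path
    model = RelikReaderSpanModel.from_pretrained(checkpoint_dir)
    model.eval()

    batch_size, seq_len = 1, 32
    vocab_size = model.transformer_model.config.vocab_size
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    # prediction_mask: 1 marks positions that must NOT be predicted as boundaries
    prediction_mask = torch.zeros(batch_size, seq_len, dtype=torch.long)
    # pretend the last 4 tokens are the candidate-entity special symbols
    special_symbols_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    special_symbols_mask[:, -4:] = True

    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            prediction_mask=prediction_mask,
            special_symbols_mask=special_symbols_mask,
        )
    # print the shape of every tensor in the output dictionary
    print({k: tuple(v.shape) for k, v in outputs.items() if isinstance(v, torch.Tensor)})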