# Source: relik-entity-linking / relik/reader/relik_reader_core.py
# Revision: "first commit" (626eca0) by riccorl -- 18.5 kB
import collections
from typing import Any, Dict, Iterator, List, Optional
import torch
from transformers import AutoModel
from transformers.activations import ClippedGELUActivation, GELUActivation
from transformers.modeling_utils import PoolerEndLogits
from relik.reader.data.relik_reader_sample import RelikReaderSample
# Registry of activation modules, keyed by the name accepted through the
# ``activation`` argument of ``RelikReaderCoreModel``. The instances are
# created once at import time and reused by every projection head that
# requests them.
activation2functions = dict(
    relu=torch.nn.ReLU(),
    gelu=GELUActivation(),
    gelu_10=ClippedGELUActivation(-10, 10),
)
class RelikReaderCoreModel(torch.nn.Module):
    """Transformer encoder with span-detection and entity-disambiguation heads.

    The module wraps a Hugging Face ``AutoModel`` encoder and adds:

    * ``ned_start_classifier`` / ``ned_end_classifier``: heads that predict
      the start and end token of each mention (named entity detection, NED);
    * ``ed_start_projector`` / ``ed_end_projector``: projections whose
      concatenated output is scored (dot product) against the representation
      of the special-symbol tokens selected by ``special_symbols_mask``
      (entity disambiguation, ED).

    NOTE(review): the constructor stores its ``training`` argument in
    ``self.training``, the same attribute ``torch.nn.Module.train()`` and
    ``eval()`` toggle, so calling those afterwards also switches which code
    paths ``forward`` takes.
    """

    def __init__(
        self,
        transformer_model: str,
        additional_special_symbols: int,
        num_layers: Optional[int] = None,
        activation: str = "gelu",
        linears_hidden_size: Optional[int] = 512,
        use_last_k_layers: int = 1,
        training: bool = False,
    ) -> None:
        """Build the encoder and the classification heads.

        Args:
            transformer_model: Name or path given to
                ``AutoModel.from_pretrained``.
            additional_special_symbols: Number of rows added to the token
                embedding matrix on top of ``config.vocab_size``.
            num_layers: When not ``None``, forwarded to ``from_pretrained``
                as ``num_hidden_layers``.
            activation: Key of ``activation2functions`` used inside the
                projection heads.
            linears_hidden_size: Hidden (and default output) size of the
                projection heads.
            use_last_k_layers: Number of final encoder hidden states that are
                concatenated to form the token features.
            training: Initial value of ``self.training``.
        """
        super().__init__()

        # Transformer model declaration
        self.transformer_model_name = transformer_model
        self.transformer_model = (
            AutoModel.from_pretrained(transformer_model)
            if num_layers is None
            else AutoModel.from_pretrained(
                transformer_model, num_hidden_layers=num_layers
            )
        )
        # make room for the special symbols appended to the vocabulary
        self.transformer_model.resize_token_embeddings(
            self.transformer_model.config.vocab_size + additional_special_symbols
        )

        self.activation = activation
        self.linears_hidden_size = linears_hidden_size
        self.use_last_k_layers = use_last_k_layers

        # named entity detection layers
        # (binary start / not-start decision for every token)
        self.ned_start_classifier = self._get_projection_layer(
            self.activation, last_hidden=2, layer_norm=False
        )
        self.ned_end_classifier = PoolerEndLogits(self.transformer_model.config)

        # entity disambiguation layers
        self.ed_start_projector = self._get_projection_layer(self.activation)
        self.ed_end_projector = self._get_projection_layer(self.activation)

        # overrides the default set by torch.nn.Module.__init__
        self.training = training

        # criterion
        self.criterion = torch.nn.CrossEntropyLoss()

    def _get_projection_layer(
        self,
        activation: str,
        last_hidden: Optional[int] = None,
        input_hidden=None,
        layer_norm: bool = True,
    ) -> torch.nn.Sequential:
        """Build a Dropout -> Linear -> activation -> Dropout -> Linear head.

        Args:
            activation: Key of ``activation2functions``.
            last_hidden: Output size; defaults to ``self.linears_hidden_size``.
            input_hidden: Input size; defaults to the encoder hidden size
                multiplied by ``self.use_last_k_layers``.
            layer_norm: Whether to append a ``LayerNorm`` over the output.

        Returns:
            The assembled ``torch.nn.Sequential`` head.
        """
        in_features = (
            self.transformer_model.config.hidden_size * self.use_last_k_layers
            if input_hidden is None
            else input_hidden
        )
        out_features = self.linears_hidden_size if last_hidden is None else last_hidden

        head_components = [
            torch.nn.Dropout(0.1),
            torch.nn.Linear(in_features, self.linears_hidden_size),
            activation2functions[activation],
            torch.nn.Dropout(0.1),
            torch.nn.Linear(self.linears_hidden_size, out_features),
        ]

        if layer_norm:
            head_components.append(
                torch.nn.LayerNorm(
                    out_features,
                    self.transformer_model.config.layer_norm_eps,
                )
            )

        return torch.nn.Sequential(*head_components)

    def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Push the logits of masked positions to a very large negative value.

        ``mask`` gets a trailing singleton dimension and is broadcast against
        ``logits``; where it equals 1 the logits are replaced by ``-65500``
        (float16 parameters) or ``-1e30`` (otherwise), where it equals 0 they
        are left untouched.

        NOTE(review): assumes ``mask`` is a numeric 0/1 tensor -- ``1 - mask``
        is not defined for boolean tensors; confirm against the caller.
        """
        mask = mask.unsqueeze(-1)
        if next(self.parameters()).dtype == torch.float16:
            # 1e30 is not representable in half precision
            logits = logits * (1 - mask) - 65500 * mask
        else:
            logits = logits * (1 - mask) - 1e30 * mask
        return logits

    def _get_model_features(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor],
    ):
        """Run the encoder and return the per-token features.

        With ``use_last_k_layers > 1`` the last ``k`` hidden states are
        concatenated along the feature dimension; otherwise the first element
        of the encoder output is returned.
        """
        model_input = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "output_hidden_states": self.use_last_k_layers > 1,
        }

        if token_type_ids is not None:
            model_input["token_type_ids"] = token_type_ids

        model_output = self.transformer_model(**model_input)

        if self.use_last_k_layers > 1:
            # Read the hidden states by name: for encoders that also return a
            # pooled output, position 1 of the output is the pooler output,
            # not the tuple of hidden states.
            hidden_states = getattr(model_output, "hidden_states", None)
            if hidden_states is None:
                # plain-tuple output: keep the historical positional lookup
                hidden_states = model_output[1]
            model_features = torch.cat(
                hidden_states[-self.use_last_k_layers :], dim=-1
            )
        else:
            model_features = model_output[0]

        return model_features

    def compute_ned_end_logits(
        self,
        start_predictions,
        start_labels,
        model_features,
        prediction_mask,
        batch_size,
    ) -> Optional[torch.Tensor]:
        """Compute end-boundary logits, one row per (sample, start) pair.

        The start positions come from ``start_labels`` when ``self.training``
        is true and from ``start_predictions`` otherwise. For every sample
        the features and the prediction mask are repeated once per start
        position, so the result has as many rows as there are positive
        entries in the start tensor.

        Returns:
            The output of ``self.ned_end_classifier`` or ``None`` when there
            is no start position in the whole batch.
        """
        # todo: maybe when constraining on the spans,
        #  we should not use a prediction_mask for the end tokens.
        #  at least we should not during training imo
        start_positions = start_labels if self.training else start_predictions
        is_start = start_positions > 0

        # token index of every start, flattened across the batch
        start_positions_indices = (
            torch.arange(start_positions.size(1), device=start_positions.device)
            .unsqueeze(0)
            .expand(batch_size, -1)[is_start]
        )

        if len(start_positions_indices) == 0:
            return None

        # how many spans start in each sample
        starts_per_sample = is_start.sum(dim=-1).tolist()

        expanded_features = torch.cat(
            [
                model_features[i].unsqueeze(0).expand(n_starts, -1, -1)
                for i, n_starts in enumerate(starts_per_sample)
                if n_starts > 0
            ],
            dim=0,
        ).to(start_positions_indices.device)

        expanded_prediction_mask = torch.cat(
            [
                prediction_mask[i].unsqueeze(0).expand(n_starts, -1)
                for i, n_starts in enumerate(starts_per_sample)
                if n_starts > 0
            ],
            dim=0,
        ).to(expanded_features.device)

        return self.ned_end_classifier(
            hidden_states=expanded_features,
            start_positions=start_positions_indices,
            p_mask=expanded_prediction_mask,
        )

    def compute_classification_logits(
        self,
        model_features,
        special_symbols_mask,
        prediction_mask,
        batch_size,
        start_positions=None,
        end_positions=None,
    ) -> torch.Tensor:
        """Score every token against the special-symbol representations.

        Each token is represented by the concatenation of its start and end
        projections; at start positions the end projection is replaced by the
        one found at the matching end position. The special-symbol rows are
        gathered with ``special_symbols_mask`` and the logits are the batched
        dot product between token and symbol representations, masked with
        ``prediction_mask``.

        NOTE(review): the gather/``view`` assumes every sample has the same
        number of special symbols as the first one -- confirm upstream.
        """
        if start_positions is None or end_positions is None:
            start_positions = torch.zeros_like(prediction_mask)
            end_positions = torch.zeros_like(prediction_mask)

        model_start_features = self.ed_start_projector(model_features)
        model_end_features = self.ed_end_projector(model_features)
        # copy the end-token projection onto the corresponding start token
        model_end_features[start_positions > 0] = model_end_features[end_positions > 0]

        model_ed_features = torch.cat(
            [model_start_features, model_end_features], dim=-1
        )

        # computing ed features
        classes_representations = torch.sum(special_symbols_mask, dim=1)[0].item()
        special_symbols_representation = model_ed_features[special_symbols_mask].view(
            batch_size, classes_representations, -1
        )

        logits = torch.bmm(
            model_ed_features,
            torch.permute(special_symbols_representation, (0, 2, 1)),
        )

        logits = self._mask_logits(logits, prediction_mask)

        return logits

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        start_labels: Optional[torch.Tensor] = None,
        end_labels: Optional[torch.Tensor] = None,
        use_predefined_spans: bool = False,
        *args,
        **kwargs,
    ) -> Dict[str, Any]:
        """Detect spans (unless predefined) and disambiguate them.

        Args:
            input_ids: Two-dimensional tensor of token ids.
            attention_mask: Attention mask forwarded to the encoder.
            token_type_ids: Optional segment ids forwarded to the encoder.
            prediction_mask: Mask handed to ``_mask_logits``; although it
                defaults to ``None`` it is required by the code paths below.
            special_symbols_mask: Mask selecting the special-symbol tokens.
            start_labels: Gold start labels (positive value = start,
                ``-100`` = ignored).
            end_labels: Gold end labels, same convention.
            use_predefined_spans: Skip span detection and use the labels
                (binarised) as the span boundaries.

        Returns:
            A dict with ``batch_size`` and the ``ned_start_*``, ``ned_end_*``
            and ``ed_*`` logits / probabilities / predictions. When
            ``self.training`` is true and both label tensors are given it
            also contains ``ned_start_loss``, ``ned_end_loss``, ``ed_loss``
            and their sum ``loss``.
        """
        batch_size, _ = input_ids.shape

        model_features = self._get_model_features(
            input_ids, attention_mask, token_type_ids
        )

        # Binarised start labels. Built before branching because the entity
        # disambiguation loss needs them even when spans are predefined.
        ned_start_labels = None
        if start_labels is not None:
            ned_start_labels = torch.zeros_like(start_labels)
            ned_start_labels[start_labels == -100] = -100
            ned_start_labels[start_labels > 0] = 1

        # named entity detection if required
        if use_predefined_spans:  # no need to compute spans
            ned_start_logits = ned_start_probabilities = None
            ned_end_logits = ned_end_probabilities = None
            ned_start_predictions = (
                torch.clone(start_labels)
                if start_labels is not None
                else torch.zeros_like(input_ids)
            )
            ned_end_predictions = (
                torch.clone(end_labels)
                if end_labels is not None
                else torch.zeros_like(input_ids)
            )
            ned_start_predictions[ned_start_predictions > 0] = 1
            ned_end_predictions[ned_end_predictions > 0] = 1
        else:  # compute spans
            # start boundary prediction
            ned_start_logits = self.ned_start_classifier(model_features)
            ned_start_logits = self._mask_logits(ned_start_logits, prediction_mask)
            ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
            ned_start_predictions = ned_start_probabilities.argmax(dim=-1)

            # end boundary prediction
            ned_end_logits = self.compute_ned_end_logits(
                ned_start_predictions,
                ned_start_labels,
                model_features,
                prediction_mask,
                batch_size,
            )

            if ned_end_logits is not None:
                ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
                ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1)
            else:
                ned_end_logits, ned_end_probabilities = None, None
                ned_end_predictions = ned_start_predictions.new_zeros(batch_size)

            # flattening end predictions
            # (flattening can happen only if the
            # end boundaries were not predicted using the gold labels)
            if not self.training:
                # start predictions are 0/1 (argmax over two classes), so the
                # flattened end map starts out as all zeros of the same shape
                flattened_end_predictions = torch.zeros_like(ned_start_predictions)

                batch_start_predictions = [
                    torch.where(ned_start_predictions[elem_idx] > 0)[0].tolist()
                    for elem_idx in range(batch_size)
                ]

                # check that the total number of start predictions
                # is equal to the end predictions
                total_start_predictions = sum(map(len, batch_start_predictions))
                total_end_predictions = len(ned_end_predictions)
                assert (
                    total_start_predictions == 0
                    or total_start_predictions == total_end_predictions
                ), (
                    f"Total number of start predictions = {total_start_predictions}. "
                    f"Total number of end predictions = {total_end_predictions}"
                )

                curr_end_pred_num = 0
                for elem_idx, bsp in enumerate(batch_start_predictions):
                    for sp in bsp:
                        ep = ned_end_predictions[curr_end_pred_num].item()
                        # an end before its start collapses to a 1-token span
                        if ep < sp:
                            ep = sp

                        # if we already set this span throw it (no overlap)
                        if flattened_end_predictions[elem_idx, ep] == 1:
                            ned_start_predictions[elem_idx, sp] = 0
                        else:
                            flattened_end_predictions[elem_idx, ep] = 1

                        curr_end_pred_num += 1

                ned_end_predictions = flattened_end_predictions

        start_position, end_position = (
            (start_labels, end_labels)
            if self.training
            else (ned_start_predictions, ned_end_predictions)
        )

        # Entity disambiguation
        ed_logits = self.compute_classification_logits(
            model_features,
            special_symbols_mask,
            prediction_mask,
            batch_size,
            start_position,
            end_position,
        )
        ed_probabilities = torch.softmax(ed_logits, dim=-1)
        ed_predictions = torch.argmax(ed_probabilities, dim=-1)

        # output build
        output_dict = dict(
            batch_size=batch_size,
            ned_start_logits=ned_start_logits,
            ned_start_probabilities=ned_start_probabilities,
            ned_start_predictions=ned_start_predictions,
            ned_end_logits=ned_end_logits,
            ned_end_probabilities=ned_end_probabilities,
            ned_end_predictions=ned_end_predictions,
            ed_logits=ed_logits,
            ed_probabilities=ed_probabilities,
            ed_predictions=ed_predictions,
        )

        # compute loss if labels
        if start_labels is not None and end_labels is not None and self.training:
            # named entity detection loss

            # start
            if ned_start_logits is not None:
                ned_start_loss = self.criterion(
                    ned_start_logits.view(-1, ned_start_logits.shape[-1]),
                    ned_start_labels.view(-1),
                )
            else:
                ned_start_loss = 0

            # end
            if ned_end_logits is not None:
                # one target per row of ned_end_logits: the token index of
                # every gold end, flattened across the batch
                ned_end_targets = (
                    torch.arange(end_labels.size(1), device=end_labels.device)
                    .unsqueeze(0)
                    .expand(batch_size, -1)[end_labels > 0]
                )
                ned_end_loss = self.criterion(ned_end_logits, ned_end_targets)
            else:
                ned_end_loss = 0

            # entity disambiguation loss
            # Work on a copy so the caller's start_labels tensor is left
            # untouched; every token that is not a gold start is ignored.
            ed_labels = torch.clone(start_labels)
            ed_labels[ned_start_labels != 1] = -100
            ed_labels[end_labels > 0] = end_labels[end_labels > 0]
            ed_loss = self.criterion(
                ed_logits.view(-1, ed_logits.shape[-1]),
                ed_labels.view(-1),
            )

            output_dict["ned_start_loss"] = ned_start_loss
            output_dict["ned_end_loss"] = ned_end_loss
            output_dict["ed_loss"] = ed_loss

            output_dict["loss"] = ned_start_loss + ned_end_loss + ed_loss

        return output_dict

    def batch_predict(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        sample: Optional[List[RelikReaderSample]] = None,
        top_k: int = 5,  # the amount of top-k most probable entities to predict
        *args,
        **kwargs,
    ) -> Iterator[RelikReaderSample]:
        """Run ``forward`` and attach the decoded predictions to each sample.

        For every sample a new entry ``ts._d["patches"][patch_offset]`` is
        written with the keys ``predicted_window_labels`` (candidate title ->
        list of ``[start, end]`` spans), ``span_title_probabilities``
        (``(start, end)`` -> list of ``(title, probability)``) and
        ``predictable_candidates``. Span indices are shifted by one because
        the first token of each sequence is skipped.

        Args:
            sample: The samples of the batch, aligned with the tensors.
            top_k: Number of most probable candidates stored per span;
                ``-1`` keeps all of them.
            **kwargs: Must contain ``predictable_candidates`` and
                ``patch_offset``, one item per sample.

        Yields:
            Each input sample after it has been updated in place.

        Raises:
            KeyError: If one of the required keyword arguments is missing.
        """
        forward_output = self.forward(
            input_ids,
            attention_mask,
            token_type_ids,
            prediction_mask,
            special_symbols_mask,
        )

        ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy()
        ned_end_predictions = forward_output["ned_end_predictions"].cpu().numpy()
        ed_predictions = forward_output["ed_predictions"].cpu().numpy()
        ed_probabilities = forward_output["ed_probabilities"].cpu().numpy()

        batch_predictable_candidates = kwargs["predictable_candidates"]
        patch_offset = kwargs["patch_offset"]

        for ts, ne_sp, ne_ep, edp, edpr, pred_cands, po in zip(
            sample,
            ned_start_predictions,
            ned_end_predictions,
            ed_predictions,
            ed_probabilities,
            batch_predictable_candidates,
            patch_offset,
        ):
            # position 0 is skipped, hence indices are offset by one
            ne_start_indices = [ti for ti, c in enumerate(ne_sp[1:]) if c > 0]
            ne_end_indices = [ti for ti, c in enumerate(ne_ep[1:]) if c > 0]

            final_class2predicted_spans = collections.defaultdict(list)
            spans2predicted_probabilities = dict()
            for start_token_index, end_token_index in zip(
                ne_start_indices, ne_end_indices
            ):
                # predicted candidate
                # NOTE(review): class ids are shifted by one before indexing
                # the candidates, so class 0 maps to pred_cands[-1] --
                # presumably intended; confirm against the data pipeline.
                token_class = edp[start_token_index + 1] - 1
                predicted_candidate_title = pred_cands[token_class]
                final_class2predicted_spans[predicted_candidate_title].append(
                    [start_token_index, end_token_index]
                )

                # candidates probabilities
                classes_probabilities = edpr[start_token_index + 1]
                classes_probabilities_best_indices = classes_probabilities.argsort()[
                    ::-1
                ]
                # resolved into a local so the ``top_k`` argument is not
                # rebound while iterating over spans and samples
                num_classes = len(classes_probabilities_best_indices)
                num_titles = num_classes if top_k == -1 else min(top_k, num_classes)
                titles_2_probs = [
                    (
                        pred_cands[class_index - 1],
                        classes_probabilities[class_index].item(),
                    )
                    for class_index in classes_probabilities_best_indices[:num_titles]
                ] if num_titles > 0 else []
                spans2predicted_probabilities[
                    (start_token_index, end_token_index)
                ] = titles_2_probs

            if "patches" not in ts._d:
                ts._d["patches"] = dict()

            ts._d["patches"][po] = dict()
            sample_patch = ts._d["patches"][po]

            sample_patch["predicted_window_labels"] = final_class2predicted_spans
            sample_patch["span_title_probabilities"] = spans2predicted_probabilities

            # additional info
            sample_patch["predictable_candidates"] = pred_cands

            yield ts