# Source: relik-entity-linking/relik/reader/relik_reader_re.py
# Uploaded by: riccorl
# Commit message: first commit
# Commit: 626eca0
import logging
from pathlib import Path
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
import numpy as np
import torch
import transformers as tr
from reader.data.relik_reader_data_utils import batchify, flatten
from reader.data.relik_reader_sample import RelikReaderSample
from reader.pytorch_modules.hf.modeling_relik import (
RelikReaderConfig,
RelikReaderREModel,
)
from tqdm import tqdm
from transformers import AutoConfig
from relik.common.log import get_console_logger, get_logger
from relik.reader.utils.special_symbols import NME_SYMBOL, get_special_symbols_re
# Console logger returned by the project's logging helper.
console_logger = get_console_logger()
# Logger for this module, requested with level INFO.
logger = get_logger(__name__, level=logging.INFO)
class RelikReaderForTripletExtraction(torch.nn.Module):
    """
    Wrapper around a ``RelikReaderREModel`` for triplet extraction.

    On top of the wrapped model it provides input construction and
    tokenization, batching, decoding of the model outputs into entities and
    triplets, and saving of model and tokenizer.
    """

    def __init__(
        self,
        transformer_model: Optional[Union[str, "tr.PreTrainedModel"]] = None,
        additional_special_symbols: Optional[int] = 0,
        num_layers: Optional[int] = None,
        activation: str = "gelu",
        linears_hidden_size: Optional[int] = 512,
        use_last_k_layers: int = 1,
        training: bool = False,
        device: Optional[Union[str, torch.device]] = None,
        tokenizer: Optional[Union[str, "tr.PreTrainedTokenizer"]] = None,
        **kwargs,
    ) -> None:
        """
        Builds the reader.

        Args:
            transformer_model: Either a model name/path or an already built
                model. When it is a string, its config is loaded; if the
                config's ``model_type`` contains ``"relik_reader"`` the model
                is loaded with ``RelikReaderREModel.from_pretrained``
                (``kwargs`` are forwarded), otherwise a new
                ``RelikReaderREModel`` is created from a ``RelikReaderConfig``
                built with the arguments below. Any other value is stored
                as-is.
            additional_special_symbols: Forwarded to ``RelikReaderConfig``.
            num_layers: Forwarded to ``RelikReaderConfig``.
            activation: Forwarded to ``RelikReaderConfig``.
            linears_hidden_size: Forwarded to ``RelikReaderConfig``.
            use_last_k_layers: Forwarded to ``RelikReaderConfig``.
            training: Forwarded to ``RelikReaderConfig``.
            device: Device the module is moved to (CPU when not given).
            tokenizer: Tokenizer instance or tokenizer name/path. When not
                given, it is loaded lazily by the ``tokenizer`` property.
        """
        super().__init__()

        if isinstance(transformer_model, str):
            config = AutoConfig.from_pretrained(
                transformer_model, trust_remote_code=True
            )
            if "relik_reader" in config.model_type:
                transformer_model = RelikReaderREModel.from_pretrained(
                    transformer_model, **kwargs
                )
            else:
                reader_config = RelikReaderConfig(
                    transformer_model=transformer_model,
                    additional_special_symbols=additional_special_symbols,
                    num_layers=num_layers,
                    activation=activation,
                    linears_hidden_size=linears_hidden_size,
                    use_last_k_layers=use_last_k_layers,
                    training=training,
                )
                transformer_model = RelikReaderREModel(reader_config)

        self.relik_reader_re_model = transformer_model
        self._tokenizer = tokenizer

        # move the model to the device
        self.to(device or torch.device("cpu"))

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        special_symbols_mask_entities: Optional[torch.Tensor] = None,
        start_labels: Optional[torch.Tensor] = None,
        end_labels: Optional[torch.Tensor] = None,
        disambiguation_labels: Optional[torch.Tensor] = None,
        relation_labels: Optional[torch.Tensor] = None,
        is_validation: bool = False,
        is_prediction: bool = False,
        *args,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Delegates to the wrapped ``RelikReaderREModel``.

        All arguments are passed through positionally, in the order of this
        signature, followed by ``*args`` and ``**kwargs``.

        Returns:
            Whatever the wrapped model returns.
        """
        return self.relik_reader_re_model(
            input_ids,
            attention_mask,
            token_type_ids,
            prediction_mask,
            special_symbols_mask,
            special_symbols_mask_entities,
            start_labels,
            end_labels,
            disambiguation_labels,
            relation_labels,
            is_validation,
            is_prediction,
            *args,
            **kwargs,
        )

    def batch_predict(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        special_symbols_mask_entities: Optional[torch.Tensor] = None,
        sample: Optional[List["RelikReaderSample"]] = None,
        *args,
        **kwargs,
    ) -> Iterator["RelikReaderSample"]:
        """
        Predicts the labels for a batch of samples.

        Args:
            input_ids: The input ids of the batch.
            attention_mask: The attention mask of the batch.
            token_type_ids: The token type ids of the batch.
            prediction_mask: The prediction mask of the batch.
            special_symbols_mask: The special symbols mask of the batch.
            special_symbols_mask_entities: The special symbols mask entities of the batch.
            sample: The samples of the batch. When omitted, one empty
                ``RelikReaderSample`` is created per row of ``input_ids``.

        Yields:
            Each sample, with ``predicted_entities``, ``predicted_relations``
            and ``predicted_relations_probabilities`` stored on it. When the
            sample has a truthy ``token2word``, the predictions are also
            converted to word-level annotations before the sample is yielded.
        """
        forward_output = self.forward(
            input_ids,
            attention_mask,
            token_type_ids,
            prediction_mask,
            special_symbols_mask,
            special_symbols_mask_entities,
            is_prediction=True,
        )

        ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy()
        # kept as returned by the model; each element is moved to numpy below
        ned_end_predictions = forward_output["ned_end_predictions"]
        ed_predictions = forward_output["re_entities_predictions"].cpu().numpy()
        ned_type_predictions = forward_output["ned_type_predictions"].cpu().numpy()
        re_predictions = forward_output["re_predictions"].cpu().numpy()
        re_probabilities = forward_output["re_probabilities"].detach().cpu().numpy()

        if sample is None:
            sample = [RelikReaderSample() for _ in range(len(input_ids))]

        for ts, ne_st, ne_end, re_pred, re_prob, edp, ne_et in zip(
            sample,
            ned_start_predictions,
            ned_end_predictions,
            re_predictions,
            re_probabilities,
            ed_predictions,
            ned_type_predictions,
        ):
            ne_end = ne_end.cpu().numpy()
            with_entity_types = self.relik_reader_re_model.entity_type_loss

            # one entity per (start, end) pair; the running entity index
            # selects the predicted type of that entity
            entities = []
            entity_index = 0
            for start, end_row in zip(np.argwhere(ne_st), ne_end):
                for end in np.argwhere(end_row):
                    if with_entity_types:
                        entities.append([start[0], end[0], ne_et[entity_index]])
                    else:
                        entities.append([start[0], end[0]])
                    entity_index += 1

            # restrict the relation predictions to the entities found above
            num_entities = len(entities)
            edp = edp[:num_entities]
            re_pred = re_pred[:num_entities, :num_entities]
            re_prob = re_prob[:num_entities, :num_entities]

            predicted_triplets = []
            predicted_triplets_prob = []
            for i, j, r in np.argwhere(re_pred):
                if self.relik_reader_re_model.relation_disambiguation_loss:
                    # keep the triplet only if subject and object differ, both
                    # are flagged for relation r and neither is flagged for
                    # index 0 (presumably the "no relation" symbol -- confirm)
                    consistent = (
                        i != j
                        and edp[i, r] == 1
                        and edp[j, r] == 1
                        and edp[i, 0] == 0
                        and edp[j, 0] == 0
                    )
                    if not consistent:
                        continue
                predicted_triplets.append([i, j, r])
                predicted_triplets_prob.append(re_prob[i, j, r])

            ts._d["predicted_relations"] = predicted_triplets
            ts._d["predicted_entities"] = entities
            ts._d["predicted_relations_probabilities"] = predicted_triplets_prob

            if ts.token2word:
                self._convert_tokens_to_word_annotations(ts)
            yield ts

    def _build_input(self, text: List[str], candidates: List[str]) -> List[str]:
        """
        Builds the token sequence fed to the tokenizer for one sample.

        The result is ``[CLS] text [SEP] candidates [SEP]`` where every
        candidate is preceded by its special symbol, except ``NME_SYMBOL``
        which is inserted on its own.

        Args:
            text: The tokens of the sample.
            candidates: The candidates of the sample.

        Returns:
            The list of string tokens.
        """
        candidates_symbols = get_special_symbols_re(len(candidates))
        candidate_pieces = [
            [symbol, candidate] if candidate != NME_SYMBOL else [NME_SYMBOL]
            for symbol, candidate in zip(candidates_symbols, candidates)
        ]
        return (
            [self.tokenizer.cls_token]
            + text
            + [self.tokenizer.sep_token]
            + flatten(candidate_pieces)
            + [self.tokenizer.sep_token]
        )

    @staticmethod
    def _compute_offsets(offsets_mapping):
        """
        Maps sub-word tokens to words and back.

        A token whose offset starts at 0 opens a new word; any other token is
        attached to the word of the previous token.

        Args:
            offsets_mapping: Tensor of per-token offsets; only the first value
                of each row is read.

        Returns:
            A pair ``(token2word, word2token)``: a list with the word index of
            each token, and a dict from word index to its token indices.
        """
        offsets_mapping = offsets_mapping.numpy()
        token2word = []
        word2token = {}
        continuation_count = 0
        for token_index, offset in enumerate(offsets_mapping):
            if offset[0] == 0:
                word_index = token_index - continuation_count
                token2word.append(word_index)
                word2token[word_index] = [token_index]
            else:
                token2word.append(token2word[-1])
                word2token[token2word[-1]].append(token_index)
                continuation_count += 1
        return token2word, word2token

    @staticmethod
    def _convert_tokens_to_word_annotations(sample: "RelikReaderSample"):
        """
        Rewrites the token-level predictions of ``sample`` at word level.

        ``sample.predicted_entities`` becomes a list of
        ``(word_start, word_end_exclusive, type)`` tuples (``type`` is ``-1``
        when the sample has no ``entity_candidates``),
        ``sample.predicted_relations`` becomes a list of dicts describing
        subject, relation and object, and
        ``sample.predicted_relations_probabilities`` is reset to ``None``.
        """
        entities = []
        for entity in sample.predicted_entities:
            # predicted positions are shifted by one with respect to
            # token2word (presumably because of the leading CLS token)
            word_start = sample.token2word[entity[0] - 1]
            word_end = sample.token2word[entity[1] - 1] + 1
            if sample.entity_candidates:
                entity_type = sample.entity_candidates[entity[2]]
            else:
                entity_type = -1
            entities.append((word_start, word_end, entity_type))

        triplets = []
        for predicted_triplet, predicted_triplet_probabilities in zip(
            sample.predicted_relations, sample.predicted_relations_probabilities
        ):
            subject_index, object_index, relation_index = predicted_triplet
            subject = entities[subject_index]
            object_ = entities[object_index]
            relation = sample.candidates[relation_index]
            triplets.append(
                {
                    "subject": {
                        "start": subject[0],
                        "end": subject[1],
                        "type": subject[2],
                        "name": " ".join(sample.tokens[subject[0] : subject[1]]),
                    },
                    "relation": {
                        "name": relation,
                        "probability": float(predicted_triplet_probabilities.round(2)),
                    },
                    "object": {
                        "start": object_[0],
                        "end": object_[1],
                        "type": object_[2],
                        "name": " ".join(sample.tokens[object_[0] : object_[1]]),
                    },
                }
            )

        sample.predicted_entities = entities
        sample.predicted_relations = triplets
        sample.predicted_relations_probabilities = None

    @staticmethod
    def _prepare_text_and_candidates(text, candidates):
        """
        Brings ``text`` and ``candidates`` to batch form and validates them.

        A single sample (``text`` is a list of string tokens) is wrapped into
        a batch of one BEFORE the lengths are compared; otherwise the number
        of tokens would be compared with the number of candidates.

        Returns:
            The pair ``(text, candidates)``, both as lists with one entry per
            sample.

        Raises:
            ValueError: If ``candidates`` is missing or the number of texts
                and candidate lists differ.
        """
        if candidates is None:
            raise ValueError("`candidates` must be provided when `text` is provided.")
        if text and isinstance(text[0], str):  # change to list of text
            text = [text]
            # wrap the candidates too, unless they are already given per sample
            if not candidates or isinstance(candidates[0], str):
                candidates = [candidates]
        if len(text) != len(candidates):
            raise ValueError("`text` and `candidates` must have the same length.")
        return text, candidates

    def _encode_sample(
        self, sample: "RelikReaderSample", max_length: Optional[int]
    ) -> Dict[str, Any]:
        """
        Tokenizes one sample and computes the masks needed by the model.

        Side effects on ``sample``: ``NME_SYMBOL`` is prepended to its
        candidates, and ``token2word`` / ``word2token`` are set.

        Returns:
            The tokenizer output enriched with ``special_symbols_mask``,
            ``token_type_ids``, ``prediction_mask``, ``sample`` and
            ``predictable_candidates``.
        """
        sample.candidates = [NME_SYMBOL] + sample.candidates
        inputs_text = self._build_input(sample.tokens, sample.candidates)
        model_inputs = self.tokenizer(
            inputs_text,
            is_split_into_words=True,
            add_special_tokens=False,
            padding=False,
            truncation=True,
            max_length=max_length or self.tokenizer.model_max_length,
            return_offsets_mapping=True,
            return_tensors="pt",
        )
        model_inputs["special_symbols_mask"] = (
            model_inputs["input_ids"] > self.tokenizer.vocab_size
        )
        # prediction mask is 0 until the first special symbol
        model_inputs["token_type_ids"] = (
            torch.cumsum(model_inputs["special_symbols_mask"], dim=1) > 0
        ).long()
        # shift prediction_mask to the left
        model_inputs["prediction_mask"] = model_inputs["token_type_ids"].roll(
            shifts=-1, dims=1
        )
        model_inputs["prediction_mask"][:, -1] = 1
        model_inputs["prediction_mask"][:, 0] = 1

        # compare full shapes: these tensors are 2-D, so `len()` would only
        # compare their first dimension
        assert (
            model_inputs["special_symbols_mask"].shape
            == model_inputs["prediction_mask"].shape
            == model_inputs["input_ids"].shape
        )

        model_inputs["sample"] = sample

        # compute cand_len using special_symbols_mask
        model_inputs["predictable_candidates"] = sample.candidates[
            : model_inputs["special_symbols_mask"].sum().item()
        ]

        offsets = model_inputs.pop("offset_mapping")
        offsets = offsets[model_inputs["prediction_mask"] == 0]
        sample.token2word, sample.word2token = self._compute_offsets(offsets)
        return model_inputs

    def _collate_batch(self, batch_elements: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Creates a model-ready batch from a list of encoded samples.

        Only fields listed in ``fields_batcher`` are kept. Fields missing from
        every element, or holding a ``None`` value, are dropped. Tensors are
        moved to the device of the model.
        """
        candidate_counts = [
            len(elem["predictable_candidates"]) for elem in batch_elements
        ]
        assert len(set(candidate_counts)) == 1, " ".join(map(str, candidate_counts))

        # the property builds a new dict on every access: fetch it once
        fields_batcher = self.fields_batcher

        values_by_field = {}
        for field_name in fields_batcher:
            values = [elem[field_name] for elem in batch_elements if field_name in elem]
            # in case you provide fields batchers but in the batch
            # there are no elements for that field
            if len(values) > 0:
                values_by_field[field_name] = values

        # NOTE: this only checks that at least one field is present, it does
        # not check that all fields have the same number of values
        assert len(set([len(v) for v in values_by_field.values()]))

        batch_dict = dict()
        for field_name, field_values in values_by_field.items():
            # todo: maybe we should report the user about possible
            # fields filtering due to "None" instances
            if any(fv is None for fv in field_values):
                continue
            batcher = fields_batcher[field_name]
            if batcher is not None:
                # every value carries a leading batch dimension of size one
                batch_dict[field_name] = batcher([fv[0] for fv in field_values])
            else:
                batch_dict[field_name] = field_values

        return {
            k: v.to(self.device) if isinstance(v, torch.Tensor) else v
            for k, v in batch_dict.items()
        }

    @torch.no_grad()
    @torch.inference_mode()
    def read(
        self,
        text: Optional[Union[List[str], List[List[str]]]] = None,
        samples: Optional[List["RelikReaderSample"]] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        special_symbols_mask_entities: Optional[torch.Tensor] = None,
        candidates: Optional[List[List[str]]] = None,
        max_length: Optional[int] = 1024,
        max_batch_size: Optional[int] = 64,
        token_batch_size: Optional[int] = None,
        progress_bar: bool = False,
        *args,
        **kwargs,
    ) -> List["RelikReaderSample"]:
        """
        Reads the given text.

        Args:
            text: The text to read in tokens. Either the tokens of a single
                sample or a list with the tokens of each sample.
            samples: Already built samples. When given, ``text`` is ignored.
            input_ids: The input ids of the text.
            attention_mask: The attention mask of the text.
            token_type_ids: The token type ids of the text.
            prediction_mask: The prediction mask of the text.
            special_symbols_mask: The special symbols mask of the text.
            special_symbols_mask_entities: The special symbols mask entities of the text.
            candidates: The candidates of the text.
            max_length: The maximum length of the text.
            max_batch_size: The maximum batch size.
            token_batch_size: The maximum number of tokens per batch.
            progress_bar: Whether to show a progress bar over the samples.

        Returns:
            The predicted labels for each sample.

        Raises:
            ValueError: If none of `text`, `input_ids`, `samples` is given, if
                `text` is given without `candidates`, or if `text` and
                `candidates` have different lengths.
        """
        if text is None and input_ids is None and samples is None:
            raise ValueError(
                "Either `text` or `input_ids` or `samples` must be provided."
            )
        if (input_ids is None and samples is None) and (
            text is None or candidates is None
        ):
            raise ValueError(
                "`text` and `candidates` must be provided to return the predictions when `input_ids` and `samples` is not provided."
            )

        if text is not None and samples is None:
            text, candidates = self._prepare_text_and_candidates(text, candidates)
            samples = [
                RelikReaderSample(tokens=t, candidates=c)
                for t, c in zip(text, candidates)
            ]

        if samples is None:
            # pre-computed tensors: a single batch
            return list(
                self.batch_predict(
                    input_ids,
                    attention_mask,
                    token_type_ids,
                    prediction_mask,
                    special_symbols_mask,
                    special_symbols_mask_entities,
                    *args,
                    **kwargs,
                )
            )

        current_batch = []
        predictions = []
        current_cand_len = -1

        for sample in tqdm(samples, disable=not progress_bar):
            model_inputs = self._encode_sample(sample, max_length)
            cand_len = len(model_inputs["predictable_candidates"])

            # `input_ids` is (1, sequence_length): the number of tokens is the
            # size of the last dimension, `len()` would always be 1
            future_max_len = max(
                model_inputs["input_ids"].shape[-1],
                max([b["input_ids"].shape[-1] for b in current_batch], default=0),
            )
            future_tokens_per_batch = future_max_len * (len(current_batch) + 1)

            # flush the current batch when the new sample cannot join it
            if len(current_batch) > 0 and (
                (cand_len != current_cand_len and current_cand_len != -1)
                or (
                    isinstance(token_batch_size, int)
                    and future_tokens_per_batch >= token_batch_size
                )
                or len(current_batch) == max_batch_size
            ):
                batch_inputs = self._collate_batch(current_batch)
                current_batch = []
                predictions.extend(self.batch_predict(**batch_inputs))

            current_cand_len = cand_len
            current_batch.append(model_inputs)

        if current_batch:
            batch_inputs = self._collate_batch(current_batch)
            predictions.extend(self.batch_predict(**batch_inputs))

        return predictions

    @property
    def device(self) -> torch.device:
        """
        The device of the model.
        """
        return next(self.parameters()).device

    @property
    def tokenizer(self) -> "tr.PreTrainedTokenizer":
        """
        The tokenizer.

        Loaded on first access: from the name/path given at construction time
        when the tokenizer was passed as a string, or from the
        ``name_or_path`` of the wrapped model's config when none was given.
        """
        if not self._tokenizer:
            self._tokenizer = tr.AutoTokenizer.from_pretrained(
                self.relik_reader_re_model.config.name_or_path
            )
        elif isinstance(self._tokenizer, str):
            # the constructor accepts a tokenizer name: resolve it here so
            # callers never receive the raw string
            self._tokenizer = tr.AutoTokenizer.from_pretrained(self._tokenizer)
        return self._tokenizer

    @property
    def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]:
        """
        The collate function of each batch field.

        A value of ``None`` means the field values are kept as a plain list.
        ``token_type_ids`` is omitted when the wrapped model's ``model_type``
        contains ``"roberta"``.
        """
        fields_batchers = {
            "input_ids": lambda x: batchify(
                x, padding_value=self.tokenizer.pad_token_id
            ),
            "attention_mask": lambda x: batchify(x, padding_value=0),
            "token_type_ids": lambda x: batchify(x, padding_value=0),
            "prediction_mask": lambda x: batchify(x, padding_value=1),
            "global_attention": lambda x: batchify(x, padding_value=0),
            "token2word": None,
            "sample": None,
            "special_symbols_mask": lambda x: batchify(x, padding_value=False),
            "special_symbols_mask_entities": lambda x: batchify(x, padding_value=False),
        }
        if "roberta" in self.relik_reader_re_model.config.model_type:
            del fields_batchers["token_type_ids"]

        return fields_batchers

    def save_pretrained(
        self,
        output_dir: str,
        model_name: Optional[str] = None,
        push_to_hub: bool = False,
        **kwargs,
    ) -> None:
        """
        Saves the model to the given path.

        Args:
            output_dir: The path to save the model to.
            model_name: The name of the model. Defaults to
                ``"relik_reader_for_triplet_extraction"``.
            push_to_hub: Whether to push the model to the hub.
            **kwargs: Forwarded to the ``save_pretrained`` of both the model
                and the tokenizer.
        """
        # create the output directory
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        model_name = model_name or "relik_reader_for_triplet_extraction"
        target_dir = output_dir / model_name

        logger.info(f"Saving reader to {target_dir}")

        # save the model
        self.relik_reader_re_model.register_for_auto_class()
        self.relik_reader_re_model.save_pretrained(
            target_dir, push_to_hub=push_to_hub, **kwargs
        )

        logger.info("Saving reader to disk done.")

        if self.tokenizer:
            self.tokenizer.save_pretrained(
                target_dir, push_to_hub=push_to_hub, **kwargs
            )
            logger.info("Saving tokenizer to disk done.")