# Scraped from a Hugging Face file page (uploader: riccorl).
# Commit 626eca0 ("first commit"); file size 23.6 kB; page scan reported "No virus".
import collections
import logging
from pathlib import Path
from typing import Any, Callable, Dict, Iterator, List, Union
import torch
import transformers as tr
from tqdm import tqdm
from transformers import AutoConfig
from relik.common.log import get_console_logger, get_logger
from relik.reader.data.relik_reader_data_utils import batchify, flatten
from relik.reader.data.relik_reader_sample import RelikReaderSample
from relik.reader.pytorch_modules.hf.modeling_relik import (
RelikReaderConfig,
RelikReaderSpanModel,
)
from relik.reader.relik_reader_predictor import RelikReaderPredictor
from relik.reader.utils.save_load_utilities import load_model_and_conf
from relik.reader.utils.special_symbols import NME_SYMBOL, get_special_symbols
# Module-level loggers: the one returned by `get_console_logger()` and a
# module-named logger at INFO level (created in that order).
console_logger, logger = get_console_logger(), get_logger(__name__, level=logging.INFO)
class RelikReaderForSpanExtraction(torch.nn.Module):
    """Span-extraction reader wrapping a ``RelikReaderSpanModel``.

    Builds the reader input (``[CLS] text [SEP] candidates [SEP]``), batches
    samples, runs the wrapped model and stores the predicted spans and the
    candidate probabilities in each sample's ``_d["patches"]`` dictionary.
    """

    def __init__(
        self,
        transformer_model: str | tr.PreTrainedModel | None = None,
        additional_special_symbols: int = 0,
        num_layers: int | None = None,
        activation: str = "gelu",
        linears_hidden_size: int | None = 512,
        use_last_k_layers: int = 1,
        training: bool = False,
        device: str | torch.device | None = None,
        tokenizer: str | tr.PreTrainedTokenizer | None = None,
        **kwargs,
    ) -> None:
        """
        Args:
            transformer_model: A model name/path or an already-built model.
                For a name/path whose config ``model_type`` contains
                ``"relik-reader"``, the reader is loaded with
                ``RelikReaderSpanModel.from_pretrained`` (``kwargs`` forwarded).
                Any other name/path gets a new ``RelikReaderSpanModel`` built
                from a ``RelikReaderConfig``. A non-string value is used as-is.
            additional_special_symbols: Forwarded to ``RelikReaderConfig``.
            num_layers: Forwarded to ``RelikReaderConfig``.
            activation: Forwarded to ``RelikReaderConfig``.
            linears_hidden_size: Forwarded to ``RelikReaderConfig``.
            use_last_k_layers: Forwarded to ``RelikReaderConfig``.
            training: Forwarded to ``RelikReaderConfig``. It is not assigned to
                ``torch.nn.Module.training``.
            device: Device the module is moved to (CPU when ``None``).
            tokenizer: Tokenizer instance or name/path. It is loaded lazily by
                the ``tokenizer`` property.
            **kwargs: Forwarded to ``RelikReaderSpanModel.from_pretrained``.
        """
        super().__init__()
        if isinstance(transformer_model, str):
            config = AutoConfig.from_pretrained(
                transformer_model, trust_remote_code=True
            )
            if "relik-reader" in config.model_type:
                transformer_model = RelikReaderSpanModel.from_pretrained(
                    transformer_model, **kwargs
                )
            else:
                reader_config = RelikReaderConfig(
                    transformer_model=transformer_model,
                    additional_special_symbols=additional_special_symbols,
                    num_layers=num_layers,
                    activation=activation,
                    linears_hidden_size=linears_hidden_size,
                    use_last_k_layers=use_last_k_layers,
                    training=training,
                )
                transformer_model = RelikReaderSpanModel(reader_config)
        self.relik_reader_model = transformer_model
        self._tokenizer = tokenizer
        # move the model to the device
        self.to(device or torch.device("cpu"))

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor,
        prediction_mask: torch.Tensor | None = None,
        special_symbols_mask: torch.Tensor | None = None,
        special_symbols_mask_entities: torch.Tensor | None = None,
        start_labels: torch.Tensor | None = None,
        end_labels: torch.Tensor | None = None,
        disambiguation_labels: torch.Tensor | None = None,
        relation_labels: torch.Tensor | None = None,
        is_validation: bool = False,
        is_prediction: bool = False,
        *args,
        **kwargs,
    ) -> Dict[str, Any]:
        """Forward every argument, positionally and in order, to the wrapped model."""
        return self.relik_reader_model(
            input_ids,
            attention_mask,
            token_type_ids,
            prediction_mask,
            special_symbols_mask,
            special_symbols_mask_entities,
            start_labels,
            end_labels,
            disambiguation_labels,
            relation_labels,
            is_validation,
            is_prediction,
            *args,
            **kwargs,
        )

    def batch_predict(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor | None = None,
        prediction_mask: torch.Tensor | None = None,
        special_symbols_mask: torch.Tensor | None = None,
        sample: List[RelikReaderSample] | None = None,
        top_k: int = 5,  # the amount of top-k most probable entities to predict
        *args,
        **kwargs,
    ) -> Iterator[RelikReaderSample]:
        """Run the model on one batch and attach the predictions to each sample.

        Args:
            input_ids: Batched input ids.
            attention_mask: Batched attention mask.
            token_type_ids: Batched token type ids.
            prediction_mask: Batched prediction mask.
            special_symbols_mask: Batched mask of the candidate special symbols.
            sample: The samples of the batch, one per row. Required.
            top_k: Number of most probable candidates kept per predicted span.
                ``-1`` keeps all of them.
            *args: Unused.
            **kwargs: Must contain ``predictable_candidates`` (one candidate list
                per sample). May contain ``patch_offset`` (one key per sample
                under which the predictions are stored) and
                ``special_symbols_mask_entities`` (forwarded to the model).

        Yields:
            Each sample, after ``sample._d["patches"][patch_offset]`` has been
            filled with ``predicted_window_labels``, ``span_title_probabilities``
            and ``predictable_candidates``.

        Raises:
            ValueError: If ``sample`` is not provided.
            KeyError: If ``predictable_candidates`` is missing from ``kwargs``.
        """
        if sample is None:
            raise ValueError("`sample` must be provided to attach the predictions.")
        forward_output = self.forward(
            input_ids,
            attention_mask,
            token_type_ids,
            prediction_mask,
            special_symbols_mask,
            special_symbols_mask_entities=kwargs.get("special_symbols_mask_entities"),
        )
        ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy()
        ned_end_predictions = forward_output["ned_end_predictions"].cpu().numpy()
        ed_predictions = forward_output["ed_predictions"].cpu().numpy()
        ed_probabilities = forward_output["ed_probabilities"].cpu().numpy()
        batch_predictable_candidates = kwargs["predictable_candidates"]
        patch_offset = kwargs.get("patch_offset")
        if patch_offset is None:
            # NOTE(review): batches built by `read` carry no patch offset. Each
            # sample is then treated as one patch stored under key 0; confirm
            # this matches how downstream code merges windowed samples.
            patch_offset = [0] * len(sample)
        for ts, ne_sp, ne_ep, edp, edpr, pred_cands, po in zip(
            sample,
            ned_start_predictions,
            ned_end_predictions,
            ed_predictions,
            ed_probabilities,
            batch_predictable_candidates,
            patch_offset,
        ):
            # Skip position 0 (the CLS token in inputs built by `_build_input`).
            # The indices below are therefore relative to position 1.
            ne_start_indices = [ti for ti, c in enumerate(ne_sp[1:]) if c > 0]
            ne_end_indices = [ti for ti, c in enumerate(ne_ep[1:]) if c > 0]
            final_class2predicted_spans = collections.defaultdict(list)
            spans2predicted_probabilities = dict()
            # Starts and ends are paired in order of occurrence.
            for start_token_index, end_token_index in zip(
                ne_start_indices, ne_end_indices
            ):
                # Predicted candidate: class c maps to pred_cands[c - 1].
                token_class = edp[start_token_index + 1] - 1
                predicted_candidate_title = pred_cands[token_class]
                final_class2predicted_spans[predicted_candidate_title].append(
                    [start_token_index, end_token_index]
                )
                # Candidate probabilities, most probable class first.
                classes_probabilities = edpr[start_token_index + 1]
                best_indices = classes_probabilities.argsort()[::-1]
                # Use a local so the caller's `top_k` is not overwritten
                # from one span to the next.
                k = (
                    len(best_indices)
                    if top_k == -1
                    else min(top_k, len(best_indices))
                )
                # NOTE(review): a best class of 0 maps to pred_cands[-1], the
                # last candidate. Confirm class 0 cannot rank in the top-k.
                spans2predicted_probabilities[
                    (start_token_index, end_token_index)
                ] = [
                    (
                        pred_cands[best_indices[i] - 1],
                        classes_probabilities[best_indices[i]].item(),
                    )
                    for i in range(k)
                ]
            if "patches" not in ts._d:
                ts._d["patches"] = dict()
            ts._d["patches"][po] = dict()
            sample_patch = ts._d["patches"][po]
            sample_patch["predicted_window_labels"] = final_class2predicted_spans
            sample_patch["span_title_probabilities"] = spans2predicted_probabilities
            # additional info
            sample_patch["predictable_candidates"] = pred_cands
            yield ts

    def _build_input(self, text: List[str], candidates: List[List[str]]) -> list[str]:
        """Build the word-level input ``[CLS] text [SEP] candidates [SEP]``.

        Each candidate is preceded by its special symbol. ``NME_SYMBOL`` is
        emitted on its own, without a symbol.
        """
        candidates_symbols = get_special_symbols(len(candidates))
        candidates = [
            [cs, ct] if ct != NME_SYMBOL else [NME_SYMBOL]
            for cs, ct in zip(candidates_symbols, candidates)
        ]
        return (
            [self.tokenizer.cls_token]
            + text
            + [self.tokenizer.sep_token]
            + flatten(candidates)
            + [self.tokenizer.sep_token]
        )

    @staticmethod
    def _compute_offsets(offsets_mapping):
        """Map text tokens to words and words to tokens.

        Args:
            offsets_mapping: Tensor of ``(start, end)`` offsets, one row per text
                token. A row whose start is 0 begins a new word. This presumably
                relies on word-relative offsets from ``is_split_into_words=True``
                (TODO confirm). The first row is expected to start a word.

        Returns:
            ``(token2word, word2token)``: ``token2word[i]`` is the word index of
            token ``i``, and ``word2token[w]`` lists the token indices of word ``w``.
        """
        token2word: List[int] = []
        word2token: Dict[int, List[int]] = {}
        for token_index, offset in enumerate(offsets_mapping.numpy()):
            if offset[0] == 0:
                # new word: its index is the number of words seen so far
                word_index = len(word2token)
                token2word.append(word_index)
                word2token[word_index] = [token_index]
            else:
                # continuation of the previous word
                token2word.append(token2word[-1])
                word2token[token2word[-1]].append(token_index)
        return token2word, word2token

    @staticmethod
    def _convert_tokens_to_word_annotations(sample: RelikReaderSample):
        """Convert token-level predictions on ``sample`` into word-level ones, in place.

        Entities become ``(start_word, end_word_exclusive, type)``. The type is
        taken from ``entity_candidates`` when present and is ``-1`` otherwise.
        Relations become dicts with ``subject``, ``relation`` and ``object``
        entries. ``predicted_relations_probabilities`` is reset to ``None``.
        """
        triplets = []
        entities = []
        for entity in sample.predicted_entities:
            # token indices are shifted by one to skip the leading CLS position
            if sample.entity_candidates:
                entities.append(
                    (
                        sample.token2word[entity[0] - 1],
                        sample.token2word[entity[1] - 1] + 1,
                        sample.entity_candidates[entity[2]],
                    )
                )
            else:
                entities.append(
                    (
                        sample.token2word[entity[0] - 1],
                        sample.token2word[entity[1] - 1] + 1,
                        -1,
                    )
                )
        for predicted_triplet, predicted_triplet_probabilities in zip(
            sample.predicted_relations, sample.predicted_relations_probabilities
        ):
            subject, object_, relation = predicted_triplet
            subject = entities[subject]
            object_ = entities[object_]
            relation = sample.candidates[relation]
            triplets.append(
                {
                    "subject": {
                        "start": subject[0],
                        "end": subject[1],
                        "type": subject[2],
                        "name": " ".join(sample.tokens[subject[0] : subject[1]]),
                    },
                    "relation": {
                        "name": relation,
                        "probability": float(predicted_triplet_probabilities.round(2)),
                    },
                    "object": {
                        "start": object_[0],
                        "end": object_[1],
                        "type": object_[2],
                        "name": " ".join(sample.tokens[object_[0] : object_[1]]),
                    },
                }
            )
        sample.predicted_entities = entities
        sample.predicted_relations = triplets
        sample.predicted_relations_probabilities = None

    def _prepare_sample_inputs(
        self, sample: RelikReaderSample, max_length: int | None
    ) -> Dict[str, Any]:
        """Tokenize one sample and build the masks the reader expects.

        Side effects on ``sample``: ``NME_SYMBOL`` is prepended to
        ``candidates`` (only once), and ``token2word`` / ``word2token`` are set.
        Returned tensors keep a leading batch dimension of 1.
        """
        # Prepend NME only once, so reading the same samples again does not
        # keep growing their candidate lists.
        if not sample.candidates or sample.candidates[0] != NME_SYMBOL:
            sample.candidates = [NME_SYMBOL] + sample.candidates
        inputs_text = self._build_input(sample.tokens, sample.candidates)
        model_inputs = self.tokenizer(
            inputs_text,
            is_split_into_words=True,
            add_special_tokens=False,
            padding=False,
            truncation=True,
            max_length=max_length or self.tokenizer.model_max_length,
            return_offsets_mapping=True,
            return_tensors="pt",
        )
        # ids above `tokenizer.vocab_size` are treated as candidate special symbols
        special_symbols_mask = model_inputs["input_ids"] > self.tokenizer.vocab_size
        model_inputs["special_symbols_mask"] = special_symbols_mask
        # token type is 0 until the first special symbol, 1 from there on
        model_inputs["token_type_ids"] = (
            torch.cumsum(special_symbols_mask, dim=1) > 0
        ).long()
        # Shift the token types left by one and force both ends to 1.
        # Positions left at 0 are the text tokens used for the word offsets.
        prediction_mask = model_inputs["token_type_ids"].roll(shifts=-1, dims=1)
        prediction_mask[:, -1] = 1
        prediction_mask[:, 0] = 1
        model_inputs["prediction_mask"] = prediction_mask
        seq_len = model_inputs["input_ids"].size(-1)
        assert special_symbols_mask.size(-1) == prediction_mask.size(-1) == seq_len
        model_inputs["sample"] = sample
        # Keep as many candidates as special symbols survive (possible) truncation.
        model_inputs["predictable_candidates"] = sample.candidates[
            : special_symbols_mask.sum().item()
        ]
        offsets = model_inputs.pop("offset_mapping")
        offsets = offsets[prediction_mask == 0]
        sample.token2word, sample.word2token = self._compute_offsets(offsets)
        return model_inputs

    def _collate_batch(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Merge per-sample inputs into one batch of ``batch_predict`` kwargs.

        Only fields listed in ``fields_batcher`` are batched. A field is dropped
        when no element has it, or when any element has it set to ``None``.
        ``predictable_candidates`` is always passed through as one list per
        element. Tensors are moved to ``self.device``.
        """
        candidates = [elem["predictable_candidates"] for elem in batch]
        candidate_lengths = [len(cands) for cands in candidates]
        assert len(set(candidate_lengths)) == 1, " ".join(map(str, candidate_lengths))
        fields_batcher = self.fields_batcher
        values_by_field = {
            fn: [elem[fn] for elem in batch if fn in elem] for fn in fields_batcher
        }
        # Drop fields for which no element provides a value.
        values_by_field = {fn: fvs for fn, fvs in values_by_field.items() if fvs}
        # Every remaining field must be present in the same number of elements.
        assert len({len(fvs) for fvs in values_by_field.values()}) <= 1
        # todo: maybe we should report the user about possible
        # fields filtering due to "None" instances
        values_by_field = {
            fn: fvs
            for fn, fvs in values_by_field.items()
            if all(fv is not None for fv in fvs)
        }
        batch_dict: Dict[str, Any] = {}
        for field_name, field_values in values_by_field.items():
            batcher = fields_batcher[field_name]
            # tensors carry a leading batch dim of 1, hence fv[0]
            batch_dict[field_name] = (
                batcher([fv[0] for fv in field_values])
                if batcher is not None
                else field_values
            )
        # `batch_predict` reads the candidate lists from its kwargs.
        batch_dict["predictable_candidates"] = candidates
        device = self.device
        return {
            k: v.to(device) if isinstance(v, torch.Tensor) else v
            for k, v in batch_dict.items()
        }

    @torch.no_grad()
    @torch.inference_mode()
    def read(
        self,
        text: List[str] | List[List[str]] | None = None,
        samples: List[RelikReaderSample] | None = None,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        prediction_mask: torch.Tensor | None = None,
        special_symbols_mask: torch.Tensor | None = None,
        special_symbols_mask_entities: torch.Tensor | None = None,
        candidates: List[List[str]] | None = None,
        max_length: int | None = 1024,
        max_batch_size: int | None = 64,
        token_batch_size: int | None = None,
        progress_bar: bool = False,
        *args,
        **kwargs,
    ) -> List[List[RelikReaderSample]]:
        """
        Reads the given text.

        Args:
            text: The text to read, already split into tokens. Pass one list of
                tokens, or a list of such lists.
            samples: Samples to read. When given, ``text`` is ignored.
            input_ids: The input ids of the text. Used only when neither
                ``text`` nor ``samples`` is given.
            attention_mask: The attention mask of the text.
            token_type_ids: The token type ids of the text.
            prediction_mask: The prediction mask of the text.
            special_symbols_mask: The special symbols mask of the text.
            special_symbols_mask_entities: The special symbols mask entities of
                the text, forwarded to the model.
            candidates: The candidates of the text. Pass one list per text, or a
                single list when ``text`` is a single text.
            max_length: The maximum length of the text, in tokens.
            max_batch_size: The maximum batch size.
            token_batch_size: The maximum number of (padded) tokens per batch.
            progress_bar: Whether to show a progress bar over the samples.

        Returns:
            A flat list with one sample per input, carrying the predictions.

        Raises:
            ValueError: If the inputs are missing or inconsistent.
        """
        if text is None and input_ids is None and samples is None:
            raise ValueError(
                "Either `text` or `input_ids` or `samples` must be provided."
            )
        if (input_ids is None and samples is None) and (
            text is None or candidates is None
        ):
            raise ValueError(
                "`text` and `candidates` must be provided to return the predictions when "
                "`input_ids` and `samples` is not provided."
            )
        if text is not None and samples is None:
            if candidates is None:
                raise ValueError("`candidates` must be provided together with `text`.")
            # Wrap a single tokenized text into a batch of one BEFORE comparing
            # lengths. Otherwise its token count is compared with its
            # candidate count.
            if text and isinstance(text[0], str):
                text = [text]
                candidates = [candidates]
            if len(text) != len(candidates):
                raise ValueError("`text` and `candidates` must have the same length.")
            samples = [
                RelikReaderSample(tokens=t, candidates=c)
                for t, c in zip(text, candidates)
            ]

        if samples is None:
            # Low-level path: the caller supplies the tensors. `sample`,
            # `predictable_candidates`, etc. must come through kwargs.
            return list(
                self.batch_predict(
                    input_ids,
                    attention_mask,
                    token_type_ids,
                    prediction_mask,
                    special_symbols_mask,
                    *args,
                    special_symbols_mask_entities=special_symbols_mask_entities,
                    **kwargs,
                )
            )

        current_batch: List[Dict[str, Any]] = []
        predictions: List[RelikReaderSample] = []
        current_cand_len = -1
        for sample in tqdm(samples, disable=not progress_bar):
            model_inputs = self._prepare_sample_inputs(sample, max_length)
            cand_len = len(model_inputs["predictable_candidates"])
            # Use the sequence length (last dim), not len(), which is the batch dim of 1.
            future_max_len = max(
                model_inputs["input_ids"].size(-1),
                max((b["input_ids"].size(-1) for b in current_batch), default=0),
            )
            future_tokens_per_batch = future_max_len * (len(current_batch) + 1)
            if current_batch and (
                (cand_len != current_cand_len and current_cand_len != -1)
                or (
                    isinstance(token_batch_size, int)
                    and future_tokens_per_batch >= token_batch_size
                )
                or len(current_batch) == max_batch_size
            ):
                # A batch must share one candidate count and fit the size budgets.
                predictions.extend(
                    self.batch_predict(**self._collate_batch(current_batch))
                )
                current_batch = []
            current_cand_len = cand_len
            current_batch.append(model_inputs)
        if current_batch:
            predictions.extend(self.batch_predict(**self._collate_batch(current_batch)))
        return predictions

    @property
    def device(self) -> torch.device:
        """
        The device of the model, i.e. the device of its first parameter.
        """
        return next(self.parameters()).device

    @property
    def tokenizer(self) -> tr.PreTrainedTokenizer:
        """The tokenizer, loaded lazily.

        A tokenizer given as a name/path is loaded on first access. When none
        was given, it is loaded from the wrapped model's ``config.name_or_path``.
        """
        if not self._tokenizer:
            self._tokenizer = tr.AutoTokenizer.from_pretrained(
                self.relik_reader_model.config.name_or_path
            )
        elif isinstance(self._tokenizer, str):
            self._tokenizer = tr.AutoTokenizer.from_pretrained(self._tokenizer)
        return self._tokenizer

    @property
    def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]:
        """Per-field collate functions. ``None`` means "keep the list as-is"."""
        fields_batchers = {
            "input_ids": lambda x: batchify(
                x, padding_value=self.tokenizer.pad_token_id
            ),
            "attention_mask": lambda x: batchify(x, padding_value=0),
            "token_type_ids": lambda x: batchify(x, padding_value=0),
            "prediction_mask": lambda x: batchify(x, padding_value=1),
            "global_attention": lambda x: batchify(x, padding_value=0),
            "token2word": None,
            "sample": None,
            "special_symbols_mask": lambda x: batchify(x, padding_value=False),
            "special_symbols_mask_entities": lambda x: batchify(x, padding_value=False),
        }
        # token type ids are not batched for RoBERTa-family models
        if "roberta" in self.relik_reader_model.config.model_type:
            del fields_batchers["token_type_ids"]
        return fields_batchers

    def save_pretrained(
        self,
        output_dir: str,
        model_name: str | None = None,
        push_to_hub: bool = False,
        **kwargs,
    ) -> None:
        """
        Saves the model (and the tokenizer) to ``output_dir / model_name``.

        Args:
            output_dir: The path to save the model to. It is created if missing.
            model_name: The name of the model. Defaults to
                ``"relik-reader-for-span-extraction"``.
            push_to_hub: Whether to push the model to the hub.
            **kwargs: Forwarded to both ``save_pretrained`` calls.
        """
        # create the output directory
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        model_name = model_name or "relik-reader-for-span-extraction"
        logger.info(f"Saving reader to {output_dir / model_name}")
        # save the model
        self.relik_reader_model.register_for_auto_class()
        self.relik_reader_model.save_pretrained(
            output_dir / model_name, push_to_hub=push_to_hub, **kwargs
        )
        logger.info("Saving reader to disk done.")
        if self.tokenizer:
            self.tokenizer.save_pretrained(
                output_dir / model_name, push_to_hub=push_to_hub, **kwargs
            )
            logger.info("Saving tokenizer to disk done.")
class RelikReader:
    """Entity-linking front-end around ``RelikReaderPredictor``.

    Loads a reader checkpoint with ``load_model_and_conf``, puts it in eval
    mode and prepares its validation-dataset config for prediction.
    """

    def __init__(self, model_path: str, predict_nmes: bool = False):
        """
        Args:
            model_path: Checkpoint passed to ``load_model_and_conf``.
            predict_nmes: Forwarded to ``RelikReaderPredictor``.
        """
        model, model_conf = load_model_and_conf(model_path)
        # inference only
        model.training = False
        model.eval()

        # Prepare the validation dataset config used by the predictor.
        dataset_conf = model_conf.data.val_dataset
        symbol_count = model_conf.model.entities_per_forward
        dataset_conf.special_symbols = get_special_symbols(symbol_count)
        dataset_conf.transformer_model = model_conf.model.model.transformer_model

        self.predictor = RelikReaderPredictor(
            model,
            dataset_conf=model_conf.data.val_dataset,
            predict_nmes=predict_nmes,
        )
        self.model_path = model_path

    def link_entities(
        self,
        dataset_path_or_samples: str | Iterator[RelikReaderSample],
        token_batch_size: int = 2048,
        progress_bar: bool = False,
    ) -> List[RelikReaderSample]:
        """Run the predictor on a dataset file or on in-memory samples.

        A ``str`` argument is passed as the predictor's first positional argument
        (with ``None`` second). Anything else is passed second (with ``None`` first).
        """
        if isinstance(dataset_path_or_samples, str):
            dataset_path, samples = dataset_path_or_samples, None
        else:
            dataset_path, samples = None, dataset_path_or_samples
        return self.predictor.predict(
            dataset_path,
            samples,
            dataset_conf=None,
            token_batch_size=token_batch_size,
            progress_bar=progress_bar,
        )

    # def save_pretrained(self, path: Union[str, Path]):
    #     self.predictor.save(path)
def main(argv: list[str] | None = None) -> None:
    """Link the entities of a JSONL dataset with a reader checkpoint and print them.

    The model and the dataset are command-line options. Their defaults are
    the values this script previously hard-coded.

    Args:
        argv: Command-line arguments (``sys.argv[1:]`` when ``None``).
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Run a RelikReader on a JSONL dataset and print the predictions."
    )
    parser.add_argument(
        "--model",
        default="riccorl/relik-reader-aida-deberta-small-old",
        help="Reader checkpoint (name or path).",
    )
    parser.add_argument(
        "--dataset",
        default="/Users/ric/Documents/PhD/Projects/relik/data/reader/aida/testa.jsonl",
        help="Path to the JSONL dataset to read.",
    )
    args = parser.parse_args(argv)

    rr = RelikReader(args.model, predict_nmes=True)
    predictions = rr.link_entities(args.dataset)
    print(predictions)


if __name__ == "__main__":
    main()