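"""
RelikReader module for relation extraction (triplet reading): wraps a
`RelikReaderREModel` with tokenization, dynamic batching, and decoding of the
predicted entity spans and relation triplets.
"""
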
import logging
from pathlib import Path
from typing import Any, Callable, Dict, Iterator, List, Optional, Union

import numpy as np
import torch
import transformers as tr
from tqdm import tqdm
from transformers import AutoConfig

from relik.common.log import get_console_logger, get_logger
from relik.reader.data.relik_reader_data_utils import batchify, flatten
from relik.reader.data.relik_reader_sample import RelikReaderSample
from relik.reader.pytorch_modules.hf.modeling_relik import (
    RelikReaderConfig,
    RelikReaderREModel,
)
from relik.reader.utils.special_symbols import NME_SYMBOL, get_special_symbols_re

console_logger = get_console_logger()
logger = get_logger(__name__, level=logging.INFO)


class RelikReaderForTripletExtraction(torch.nn.Module):
    def __init__(
        self,
        transformer_model: Optional[Union[str, tr.PreTrainedModel]] = None,
        additional_special_symbols: Optional[int] = 0,
        num_layers: Optional[int] = None,
        activation: str = "gelu",
        linears_hidden_size: Optional[int] = 512,
        use_last_k_layers: int = 1,
        training: bool = False,
        device: Optional[Union[str, torch.device]] = None,
        tokenizer: Optional[Union[str, tr.PreTrainedTokenizer]] = None,
        **kwargs,
    ) -> None:
        super().__init__()

        if isinstance(transformer_model, str):
            config = AutoConfig.from_pretrained(
                transformer_model, trust_remote_code=True
            )
            if "relik_reader" in config.model_type:
                # the string points at an already-trained RelikReader checkpoint
                transformer_model = RelikReaderREModel.from_pretrained(
                    transformer_model, **kwargs
                )
            else:
                # the string points at a plain transformer encoder:
                # wrap it in a freshly initialized RelikReader model
                reader_config = RelikReaderConfig(
                    transformer_model=transformer_model,
                    additional_special_symbols=additional_special_symbols,
                    num_layers=num_layers,
                    activation=activation,
                    linears_hidden_size=linears_hidden_size,
                    use_last_k_layers=use_last_k_layers,
                    training=training,
                )
                transformer_model = RelikReaderREModel(reader_config)

        self.relik_reader_re_model = transformer_model
        self._tokenizer = tokenizer

        # move the model to the device
        self.to(device or torch.device("cpu"))
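
    # Two construction paths (sketches; the model identifiers and the symbol
    # count below are illustrative, not tested values):
    #   RelikReaderForTripletExtraction("an/already-trained-relik-reader")
    #   RelikReaderForTripletExtraction(
    #       "microsoft/deberta-v3-base", additional_special_symbols=101
    #   )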

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        special_symbols_mask_entities: Optional[torch.Tensor] = None,
        start_labels: Optional[torch.Tensor] = None,
        end_labels: Optional[torch.Tensor] = None,
        disambiguation_labels: Optional[torch.Tensor] = None,
        relation_labels: Optional[torch.Tensor] = None,
        is_validation: bool = False,
        is_prediction: bool = False,
        *args,
        **kwargs,
    ) -> Dict[str, Any]:
        return self.relik_reader_re_model(
            input_ids,
            attention_mask,
            token_type_ids,
            prediction_mask,
            special_symbols_mask,
            special_symbols_mask_entities,
            start_labels,
            end_labels,
            disambiguation_labels,
            relation_labels,
            is_validation,
            is_prediction,
            *args,
            **kwargs,
        )
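
    # The underlying HF model returns a dict of outputs; `batch_predict` below
    # consumes the keys "ned_start_predictions", "ned_end_predictions",
    # "re_entities_predictions", "ned_type_predictions", "re_predictions"
    # and "re_probabilities".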

    def batch_predict(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        special_symbols_mask_entities: Optional[torch.Tensor] = None,
        sample: Optional[List[RelikReaderSample]] = None,
        *args,
        **kwargs,
    ) -> Iterator[RelikReaderSample]:
        """
        Predicts the labels for a batch of samples.

        Args:
            input_ids: The input ids of the batch.
            attention_mask: The attention mask of the batch.
            token_type_ids: The token type ids of the batch.
            prediction_mask: The prediction mask of the batch.
            special_symbols_mask: The special symbols mask of the batch.
            special_symbols_mask_entities: The special symbols mask for entities of the batch.
            sample: The samples of the batch.

        Returns:
            The samples of the batch, enriched with the predicted entities and relations.
        """
        forward_output = self.forward(
            input_ids,
            attention_mask,
            token_type_ids,
            prediction_mask,
            special_symbols_mask,
            special_symbols_mask_entities,
            is_prediction=True,
        )

        ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy()
        # end predictions are moved to CPU per sample inside the loop below
        ned_end_predictions = forward_output["ned_end_predictions"]
        ed_predictions = forward_output["re_entities_predictions"].cpu().numpy()
        ned_type_predictions = forward_output["ned_type_predictions"].cpu().numpy()
        re_predictions = forward_output["re_predictions"].cpu().numpy()
        re_probabilities = forward_output["re_probabilities"].detach().cpu().numpy()

        if sample is None:
            sample = [RelikReaderSample() for _ in range(len(input_ids))]

        for ts, ne_st, ne_end, re_pred, re_prob, edp, ne_et in zip(
            sample,
            ned_start_predictions,
            ned_end_predictions,
            re_predictions,
            re_probabilities,
            ed_predictions,
            ned_type_predictions,
        ):
            ne_end = ne_end.cpu().numpy()

            # decode entity spans from the start/end masks: every predicted
            # start is paired with the predicted ends on its row
            entities = []
            if self.relik_reader_re_model.entity_type_loss:
                starts = np.argwhere(ne_st)
                i = 0
                for start, end in zip(starts, ne_end):
                    ends = np.argwhere(end)
                    for e in ends:
                        entities.append([start[0], e[0], ne_et[i]])
                        i += 1
            else:
                starts = np.argwhere(ne_st)
                for start, end in zip(starts, ne_end):
                    ends = np.argwhere(end)
                    for e in ends:
                        entities.append([start[0], e[0]])

            edp = edp[: len(entities)]
            re_pred = re_pred[: len(entities), : len(entities)]
            re_prob = re_prob[: len(entities), : len(entities)]
            possible_re = np.argwhere(re_pred)

            predicted_triplets = []
            predicted_triplets_prob = []
            for i, j, r in possible_re:
                if self.relik_reader_re_model.relation_disambiguation_loss:
                    # keep (subject i, object j, relation r) only if the two
                    # entities are distinct, both disambiguate to relation r,
                    # and neither disambiguates to NME (index 0)
                    if not (
                        i != j
                        and edp[i, r] == 1
                        and edp[j, r] == 1
                        and edp[i, 0] == 0
                        and edp[j, 0] == 0
                    ):
                        continue
                predicted_triplets.append([i, j, r])
                predicted_triplets_prob.append(re_prob[i, j, r])

            ts._d["predicted_relations"] = predicted_triplets
            ts._d["predicted_entities"] = entities
            ts._d["predicted_relations_probabilities"] = predicted_triplets_prob

            if ts.token2word:
                self._convert_tokens_to_word_annotations(ts)

            yield ts
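
    # `_build_input` lays out one word-level token sequence:
    #   [CLS] <text tokens> [SEP] <candidate symbols and names> [SEP]
    # Each candidate relation is preceded by its own special symbol; the NME
    # candidate appears alone, with no symbol. An illustrative result (the
    # special symbol names depend on `get_special_symbols_re`):
    #   ["[CLS]", "Rome", "is", "in", "Italy", "[SEP]",
    #    "--NME--", "[R-1]", "located in", "[SEP]"]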

    def _build_input(self, text: List[str], candidates: List[str]) -> List[str]:
        candidates_symbols = get_special_symbols_re(len(candidates))
        # pair each candidate with its dedicated special symbol; the NME
        # candidate gets no symbol of its own
        candidates = [
            [cs, ct] if ct != NME_SYMBOL else [NME_SYMBOL]
            for cs, ct in zip(candidates_symbols, candidates)
        ]
        return (
            [self.tokenizer.cls_token]
            + text
            + [self.tokenizer.sep_token]
            + flatten(candidates)
            + [self.tokenizer.sep_token]
        )

    @staticmethod
    def _compute_offsets(offsets_mapping):
        offsets_mapping = offsets_mapping.numpy()
        token2word = []
        word2token = {}
        count = 0
        for i, offset in enumerate(offsets_mapping):
            if offset[0] == 0:
                # character offset 0 marks the first sub-token of a new word
                token2word.append(i - count)
                word2token[i - count] = [i]
            else:
                # continuation sub-token: attach it to the current word
                token2word.append(token2word[-1])
                word2token[token2word[-1]].append(i)
                count += 1
        return token2word, word2token
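
    # For example, offsets like [(0, 2), (2, 4), (0, 3)] (two sub-tokens for
    # the first word, one for the second) produce:
    #   token2word = [0, 0, 1]
    #   word2token = {0: [0, 1], 1: [2]}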

    @staticmethod
    def _convert_tokens_to_word_annotations(sample: RelikReaderSample):
        triplets = []
        entities = []
        for entity in sample.predicted_entities:
            # map token-level spans to word-level ones (start inclusive, end
            # exclusive); the -1 accounts for the [CLS] token, which
            # token2word does not cover
            if sample.entity_candidates:
                entities.append(
                    (
                        sample.token2word[entity[0] - 1],
                        sample.token2word[entity[1] - 1] + 1,
                        sample.entity_candidates[entity[2]],
                    )
                )
            else:
                entities.append(
                    (
                        sample.token2word[entity[0] - 1],
                        sample.token2word[entity[1] - 1] + 1,
                        -1,
                    )
                )
        for predicted_triplet, predicted_triplet_probabilities in zip(
            sample.predicted_relations, sample.predicted_relations_probabilities
        ):
            subject, object_, relation = predicted_triplet
            subject = entities[subject]
            object_ = entities[object_]
            relation = sample.candidates[relation]
            triplets.append(
                {
                    "subject": {
                        "start": subject[0],
                        "end": subject[1],
                        "type": subject[2],
                        "name": " ".join(sample.tokens[subject[0] : subject[1]]),
                    },
                    "relation": {
                        "name": relation,
                        "probability": float(predicted_triplet_probabilities.round(2)),
                    },
                    "object": {
                        "start": object_[0],
                        "end": object_[1],
                        "type": object_[2],
                        "name": " ".join(sample.tokens[object_[0] : object_[1]]),
                    },
                }
            )
        sample.predicted_entities = entities
        sample.predicted_relations = triplets
        sample.predicted_relations_probabilities = None
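
    # Typical call (a sketch; tokens and relation candidates are illustrative):
    #   predictions = reader.read(
    #       text=[["Rome", "is", "the", "capital", "of", "Italy"]],
    #       candidates=[["capital of", "located in"]],
    #   )
    #   predictions[0].predicted_relations  # word-level triplet dicts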

    def read(
        self,
        text: Optional[Union[List[str], List[List[str]]]] = None,
        samples: Optional[List[RelikReaderSample]] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        special_symbols_mask_entities: Optional[torch.Tensor] = None,
        candidates: Optional[List[List[str]]] = None,
        max_length: Optional[int] = 1024,
        max_batch_size: Optional[int] = 64,
        token_batch_size: Optional[int] = None,
        progress_bar: bool = False,
        *args,
        **kwargs,
    ) -> List[RelikReaderSample]:
        """
        Reads the given text.

        Args:
            text: The text to read, as a list of tokens (or a list of token lists).
            samples: Pre-built samples to read instead of raw text.
            input_ids: The input ids of the text.
            attention_mask: The attention mask of the text.
            token_type_ids: The token type ids of the text.
            prediction_mask: The prediction mask of the text.
            special_symbols_mask: The special symbols mask of the text.
            special_symbols_mask_entities: The special symbols mask for entities of the text.
            candidates: The candidate relations for each text.
            max_length: The maximum length of the text.
            max_batch_size: The maximum batch size.
            token_batch_size: The maximum number of tokens per batch.
            progress_bar: Whether to show a progress bar.

        Returns:
            The predicted labels for each sample.
        """
        if text is None and input_ids is None and samples is None:
            raise ValueError(
                "Either `text`, `input_ids` or `samples` must be provided."
            )
        if (input_ids is None and samples is None) and (
            text is None or candidates is None
        ):
            raise ValueError(
                "`text` and `candidates` must both be provided when neither "
                "`input_ids` nor `samples` is given."
            )
        if text is not None and samples is None:
            if isinstance(text[0], str):  # a single text: wrap it in a batch of one
                text = [text]
                candidates = [candidates]
            if len(text) != len(candidates):
                raise ValueError("`text` and `candidates` must have the same length.")
            samples = [
                RelikReaderSample(tokens=t, candidates=c)
                for t, c in zip(text, candidates)
            ]

        if samples is not None:
            # function that creates a batch from the 'current_batch' list
            def output_batch() -> Dict[str, Any]:
                # every sample in a batch must have the same number of
                # predictable candidates, since that number determines how
                # many special symbols appear in the input
                assert (
                    len(
                        set(
                            [
                                len(elem["predictable_candidates"])
                                for elem in current_batch
                            ]
                        )
                    )
                    == 1
                ), " ".join(
                    map(
                        str,
                        [len(elem["predictable_candidates"]) for elem in current_batch],
                    )
                )

                batch_dict = dict()
                de_values_by_field = {
                    fn: [de[fn] for de in current_batch if fn in de]
                    for fn in self.fields_batcher
                }

                # in case you provide fields batchers but in the batch
                # there are no elements for that field
                de_values_by_field = {
                    fn: fvs for fn, fvs in de_values_by_field.items() if len(fvs) > 0
                }
                assert len(set([len(v) for v in de_values_by_field.values()])) == 1

                # todo: maybe we should warn the user about fields being
                # filtered out due to "None" instances
                de_values_by_field = {
                    fn: fvs
                    for fn, fvs in de_values_by_field.items()
                    if all([fv is not None for fv in fvs])
                }

                for field_name, field_values in de_values_by_field.items():
                    # the tokenizer returned tensors with a leading batch
                    # dimension of 1, hence the fv[0]
                    field_batch = (
                        self.fields_batcher[field_name]([fv[0] for fv in field_values])
                        if self.fields_batcher[field_name] is not None
                        else field_values
                    )
                    batch_dict[field_name] = field_batch

                batch_dict = {
                    k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch_dict.items()
                }
                return batch_dict

            current_batch = []
            predictions = []
            current_cand_len = -1

            for sample in tqdm(samples, disable=not progress_bar):
                # prepend the NME symbol so that index 0 always means "no relation"
                sample.candidates = [NME_SYMBOL] + sample.candidates
                inputs_text = self._build_input(sample.tokens, sample.candidates)
                model_inputs = self.tokenizer(
                    inputs_text,
                    is_split_into_words=True,
                    add_special_tokens=False,
                    padding=False,
                    truncation=True,
                    max_length=max_length or self.tokenizer.model_max_length,
                    return_offsets_mapping=True,
                    return_tensors="pt",
                )
                # ids above the base vocabulary size belong to the added
                # candidate special symbols
                model_inputs["special_symbols_mask"] = (
                    model_inputs["input_ids"] > self.tokenizer.vocab_size
                )
                # token_type_ids are 0 up to the first special symbol, 1 afterwards
                model_inputs["token_type_ids"] = (
                    torch.cumsum(model_inputs["special_symbols_mask"], dim=1) > 0
                ).long()

                # the prediction mask is 0 exactly on the text tokens: shift
                # token_type_ids one position to the left, then mask [CLS]
                # and the last position as well
                model_inputs["prediction_mask"] = model_inputs["token_type_ids"].roll(
                    shifts=-1, dims=1
                )
                model_inputs["prediction_mask"][:, -1] = 1
                model_inputs["prediction_mask"][:, 0] = 1

                assert (
                    len(model_inputs["special_symbols_mask"])
                    == len(model_inputs["prediction_mask"])
                    == len(model_inputs["input_ids"])
                )

                model_inputs["sample"] = sample

                # keep only the candidates whose special symbols survived
                # truncation, counting them via the special_symbols_mask
                model_inputs["predictable_candidates"] = sample.candidates[
                    : model_inputs["special_symbols_mask"].sum().item()
                ]

                # keep the offsets of the text tokens only
                offsets = model_inputs.pop("offset_mapping")
                offsets = offsets[model_inputs["prediction_mask"] == 0]
                sample.token2word, sample.word2token = self._compute_offsets(offsets)

                future_max_len = max(
                    model_inputs["input_ids"].shape[-1],
                    max([b["input_ids"].shape[-1] for b in current_batch], default=0),
                )
                future_tokens_per_batch = future_max_len * (len(current_batch) + 1)

                # flush the current batch when the candidate count changes,
                # the token budget would be exceeded, or the batch is full
                if len(current_batch) > 0 and (
                    (
                        len(model_inputs["predictable_candidates"]) != current_cand_len
                        and current_cand_len != -1
                    )
                    or (
                        isinstance(token_batch_size, int)
                        and future_tokens_per_batch >= token_batch_size
                    )
                    or len(current_batch) == max_batch_size
                ):
                    batch_inputs = output_batch()
                    current_batch = []
                    predictions.extend(list(self.batch_predict(**batch_inputs)))
                current_cand_len = len(model_inputs["predictable_candidates"])
                current_batch.append(model_inputs)

            if current_batch:
                batch_inputs = output_batch()
                predictions.extend(list(self.batch_predict(**batch_inputs)))
        else:
            predictions = list(
                self.batch_predict(
                    input_ids,
                    attention_mask,
                    token_type_ids,
                    prediction_mask,
                    special_symbols_mask,
                    special_symbols_mask_entities,
                    *args,
                    **kwargs,
                )
            )
        return predictions

    @property
    def device(self) -> torch.device:
        """
        The device of the model.
        """
        return next(self.parameters()).device

    @property
    def tokenizer(self) -> tr.PreTrainedTokenizer:
        """
        The tokenizer.
        """
        if self._tokenizer:
            return self._tokenizer

        self._tokenizer = tr.AutoTokenizer.from_pretrained(
            self.relik_reader_re_model.config.name_or_path
        )
        return self._tokenizer

    @property
    def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]:
        fields_batchers = {
            "input_ids": lambda x: batchify(
                x, padding_value=self.tokenizer.pad_token_id
            ),
            "attention_mask": lambda x: batchify(x, padding_value=0),
            "token_type_ids": lambda x: batchify(x, padding_value=0),
            "prediction_mask": lambda x: batchify(x, padding_value=1),
            "global_attention": lambda x: batchify(x, padding_value=0),
            "token2word": None,
            "sample": None,
            "special_symbols_mask": lambda x: batchify(x, padding_value=False),
            "special_symbols_mask_entities": lambda x: batchify(x, padding_value=False),
        }
        if "roberta" in self.relik_reader_re_model.config.model_type:
            # RoBERTa-based encoders do not use token_type_ids
            del fields_batchers["token_type_ids"]
        return fields_batchers

    def save_pretrained(
        self,
        output_dir: str,
        model_name: Optional[str] = None,
        push_to_hub: bool = False,
        **kwargs,
    ) -> None:
        """
        Saves the model to the given path.

        Args:
            output_dir: The path to save the model to.
            model_name: The name of the model.
            push_to_hub: Whether to push the model to the hub.
        """
        # create the output directory
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        model_name = model_name or "relik_reader_for_triplet_extraction"

        logger.info(f"Saving reader to {output_dir / model_name}")

        # save the model
        self.relik_reader_re_model.register_for_auto_class()
        self.relik_reader_re_model.save_pretrained(
            output_dir / model_name, push_to_hub=push_to_hub, **kwargs
        )
        logger.info("Saving reader to disk done.")

        if self.tokenizer:
            self.tokenizer.save_pretrained(
                output_dir / model_name, push_to_hub=push_to_hub, **kwargs
            )
            logger.info("Saving tokenizer to disk done.")