|
import logging |
|
from pathlib import Path |
|
from typing import Any, Callable, Dict, Iterator, List, Optional, Union |
|
|
|
import numpy as np |
|
import torch |
|
import transformers as tr |
|
from reader.data.relik_reader_data_utils import batchify, flatten |
|
from reader.data.relik_reader_sample import RelikReaderSample |
|
from reader.pytorch_modules.hf.modeling_relik import ( |
|
RelikReaderConfig, |
|
RelikReaderREModel, |
|
) |
|
from tqdm import tqdm |
|
from transformers import AutoConfig |
|
|
|
from relik.common.log import get_console_logger, get_logger |
|
from relik.reader.utils.special_symbols import NME_SYMBOL, get_special_symbols_re |
|
|
|
# Module-level loggers obtained from relik's logging helpers. `console_logger`
# is not referenced in the code visible here; `logger` (named after this module,
# INFO level) is used by `save_pretrained` to report progress.
console_logger = get_console_logger()

logger = get_logger(__name__, level=logging.INFO)
|
|
|
|
|
class RelikReaderForTripletExtraction(torch.nn.Module):
    """
    Wrapper around ``RelikReaderREModel`` for triplet (relation) extraction.

    It builds or loads the underlying model, lazily resolves a tokenizer,
    turns tokenized texts plus relation candidates into model inputs, batches
    them, and converts the model's token-level predictions into word-level
    entity spans and subject/relation/object triplets stored on
    ``RelikReaderSample`` objects.
    """

    def __init__(
        self,
        transformer_model: Optional[Union[str, tr.PreTrainedModel]] = None,
        additional_special_symbols: Optional[int] = 0,
        num_layers: Optional[int] = None,
        activation: str = "gelu",
        linears_hidden_size: Optional[int] = 512,
        use_last_k_layers: int = 1,
        training: bool = False,
        device: Optional[Union[str, torch.device]] = None,
        tokenizer: Optional[Union[str, tr.PreTrainedTokenizer]] = None,
        **kwargs,
    ) -> None:
        """
        Args:
            transformer_model: A model name/path or an already built model.
                For a string whose config ``model_type`` contains
                ``"relik_reader"``, a pretrained ``RelikReaderREModel`` is
                loaded from it. Any other string is used as the backbone name
                of a new ``RelikReaderREModel`` built from a
                ``RelikReaderConfig``. Non-string values are used as-is.
            additional_special_symbols: Passed to ``RelikReaderConfig`` when a
                new model is built.
            num_layers: Passed to ``RelikReaderConfig`` when a new model is built.
            activation: Passed to ``RelikReaderConfig`` when a new model is built.
            linears_hidden_size: Passed to ``RelikReaderConfig`` when a new
                model is built.
            use_last_k_layers: Passed to ``RelikReaderConfig`` when a new model
                is built.
            training: Passed to ``RelikReaderConfig`` when a new model is built.
            device: Device the module is moved to; defaults to CPU.
            tokenizer: A tokenizer instance or a name/path to load one from.
                When omitted, it is loaded lazily from the model config's
                ``name_or_path`` (see the ``tokenizer`` property).
            **kwargs: Forwarded to ``RelikReaderREModel.from_pretrained`` when
                a pretrained ReLiK reader is loaded.
        """
        super().__init__()

        if isinstance(transformer_model, str):
            config = AutoConfig.from_pretrained(
                transformer_model, trust_remote_code=True
            )
            if "relik_reader" in config.model_type:
                transformer_model = RelikReaderREModel.from_pretrained(
                    transformer_model, **kwargs
                )
            else:
                reader_config = RelikReaderConfig(
                    transformer_model=transformer_model,
                    additional_special_symbols=additional_special_symbols,
                    num_layers=num_layers,
                    activation=activation,
                    linears_hidden_size=linears_hidden_size,
                    use_last_k_layers=use_last_k_layers,
                    training=training,
                )
                transformer_model = RelikReaderREModel(reader_config)

        self.relik_reader_re_model = transformer_model

        # A tokenizer instance, a name/path (resolved lazily by the
        # `tokenizer` property), or None.
        self._tokenizer = tokenizer

        self.to(device or torch.device("cpu"))

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        special_symbols_mask_entities: Optional[torch.Tensor] = None,
        start_labels: Optional[torch.Tensor] = None,
        end_labels: Optional[torch.Tensor] = None,
        disambiguation_labels: Optional[torch.Tensor] = None,
        relation_labels: Optional[torch.Tensor] = None,
        is_validation: bool = False,
        is_prediction: bool = False,
        *args,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Run the underlying ``RelikReaderREModel``.

        Every argument is forwarded positionally in the order of this
        signature, followed by any extra ``*args`` / ``**kwargs``.

        Returns:
            Whatever the underlying model returns (a dict of outputs).
        """
        return self.relik_reader_re_model(
            input_ids,
            attention_mask,
            token_type_ids,
            prediction_mask,
            special_symbols_mask,
            special_symbols_mask_entities,
            start_labels,
            end_labels,
            disambiguation_labels,
            relation_labels,
            is_validation,
            is_prediction,
            *args,
            **kwargs,
        )

    def batch_predict(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        special_symbols_mask_entities: Optional[torch.Tensor] = None,
        sample: Optional[List[RelikReaderSample]] = None,
        *args,
        **kwargs,
    ) -> Iterator[RelikReaderSample]:
        """
        Predict entities and relation triplets for a batch and attach them to
        the samples.

        This is a generator: the forward pass runs only when it is first
        iterated.

        Args:
            input_ids: The input ids of the batch.
            attention_mask: The attention mask of the batch.
            token_type_ids: The token type ids of the batch.
            prediction_mask: The prediction mask of the batch.
            special_symbols_mask: The special symbols mask of the batch.
            special_symbols_mask_entities: The special symbols mask entities of
                the batch.
            sample: The samples of the batch. Empty ``RelikReaderSample``
                objects are created (one per row of ``input_ids``) when omitted.
            *args: Accepted for call compatibility and ignored.
            **kwargs: Accepted for call compatibility and ignored.

        Yields:
            Each sample with ``predicted_entities``, ``predicted_relations``
            and ``predicted_relations_probabilities`` set. Entities are
            token-level ``[start, end]`` pairs, with a type prediction appended
            when the model has ``entity_type_loss`` enabled. Relations are
            ``[subject_idx, object_idx, relation_idx]`` indices. When the
            sample carries ``token2word``, they are converted to word-level
            annotations.
        """
        forward_output = self.forward(
            input_ids,
            attention_mask,
            token_type_ids,
            prediction_mask,
            special_symbols_mask,
            special_symbols_mask_entities,
            is_prediction=True,
        )

        ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy()
        # Left as returned by the model; each sample's entry is moved to numpy
        # inside the loop.
        ned_end_predictions = forward_output["ned_end_predictions"]
        ed_predictions = forward_output["re_entities_predictions"].cpu().numpy()
        ned_type_predictions = forward_output["ned_type_predictions"].cpu().numpy()
        re_predictions = forward_output["re_predictions"].cpu().numpy()
        re_probabilities = forward_output["re_probabilities"].detach().cpu().numpy()

        if sample is None:
            sample = [RelikReaderSample() for _ in range(len(input_ids))]

        # Model-level switches; they do not change across the batch.
        use_entity_types = self.relik_reader_re_model.entity_type_loss
        use_relation_disambiguation = (
            self.relik_reader_re_model.relation_disambiguation_loss
        )

        for ts, ne_st, ne_end, re_pred, re_prob, edp, ne_et in zip(
            sample,
            ned_start_predictions,
            ned_end_predictions,
            re_predictions,
            re_probabilities,
            ed_predictions,
            ned_type_predictions,
        ):
            ne_end = ne_end.cpu().numpy()

            # Each predicted start is paired, in order, with one row of end
            # predictions. Every end flagged in that row yields one entity.
            entities = []
            entity_index = 0
            for start, end in zip(np.argwhere(ne_st), ne_end):
                for e in np.argwhere(end):
                    if use_entity_types:
                        # Type predictions are indexed by entity order.
                        entities.append([start[0], e[0], ne_et[entity_index]])
                    else:
                        entities.append([start[0], e[0]])
                    entity_index += 1

            # Keep only the rows/columns that correspond to decoded entities.
            num_entities = len(entities)
            edp = edp[:num_entities]
            re_pred = re_pred[:num_entities, :num_entities]
            re_prob = re_prob[:num_entities, :num_entities]

            predicted_triplets = []
            predicted_triplets_prob = []
            for i, j, r in np.argwhere(re_pred):
                # With relation disambiguation, keep a triplet only if subject
                # and object differ, both entity predictions agree on relation
                # `r`, and neither entity has column 0 set.
                if use_relation_disambiguation and not (
                    i != j
                    and edp[i, r] == 1
                    and edp[j, r] == 1
                    and edp[i, 0] == 0
                    and edp[j, 0] == 0
                ):
                    continue
                predicted_triplets.append([i, j, r])
                predicted_triplets_prob.append(re_prob[i, j, r])

            ts._d["predicted_relations"] = predicted_triplets
            ts._d["predicted_entities"] = entities
            ts._d["predicted_relations_probabilities"] = predicted_triplets_prob
            if ts.token2word:
                self._convert_tokens_to_word_annotations(ts)
            yield ts

    def _build_input(self, text: List[str], candidates: List[str]) -> List[str]:
        """
        Build the word sequence fed to the tokenizer.

        The layout is ``[CLS] text [SEP] candidates [SEP]``. Each candidate is
        preceded by its special relation symbol, except ``NME_SYMBOL``, which
        is emitted alone.

        Args:
            text: The words of the text.
            candidates: The relation candidate strings (``NME_SYMBOL`` allowed).

        Returns:
            The flat list of words/symbols to tokenize.
        """
        candidates_symbols = get_special_symbols_re(len(candidates))
        candidates = [
            [cs, ct] if ct != NME_SYMBOL else [NME_SYMBOL]
            for cs, ct in zip(candidates_symbols, candidates)
        ]
        return (
            [self.tokenizer.cls_token]
            + text
            + [self.tokenizer.sep_token]
            + flatten(candidates)
            + [self.tokenizer.sep_token]
        )

    @staticmethod
    def _compute_offsets(offsets_mapping):
        """
        Map sub-word tokens to word indices using tokenizer offset mappings.

        A token whose start offset is 0 opens a new word; any other token
        continues the previous word. This presumably relies on offsets being
        word-relative when tokenizing with ``is_split_into_words=True`` —
        TODO confirm for every tokenizer in use.

        Args:
            offsets_mapping: ``(num_tokens, 2)`` start/end offsets, as a torch
                tensor (any device), numpy array, or sequence of pairs.

        Returns:
            A tuple ``(token2word, word2token)``. ``token2word[t]`` is the word
            index of token ``t``; ``word2token[w]`` lists the tokens of word ``w``.
        """
        if isinstance(offsets_mapping, torch.Tensor):
            offsets_mapping = offsets_mapping.cpu().numpy()

        token2word = []
        word2token = {}
        for token_index, offset in enumerate(offsets_mapping):
            # A leading token that does not start at 0 still opens the first
            # word instead of failing on an empty `token2word`.
            if offset[0] == 0 or not token2word:
                word_index = len(word2token)
                token2word.append(word_index)
                word2token[word_index] = [token_index]
            else:
                token2word.append(token2word[-1])
                word2token[token2word[-1]].append(token_index)
        return token2word, word2token

    @staticmethod
    def _convert_tokens_to_word_annotations(sample: RelikReaderSample):
        """
        Convert a sample's token-level predictions to word-level annotations,
        in place.

        Entities become ``(start_word, end_word_exclusive, type)`` tuples. The
        type is looked up in ``sample.entity_candidates`` when both those
        candidates and a type prediction are available, and is ``-1``
        otherwise. Relations become dicts with ``subject``, ``relation`` and
        ``object`` entries. ``predicted_relations_probabilities`` is cleared
        because each probability is moved into its relation dict.
        """
        entities = []
        for entity in sample.predicted_entities:
            # Token positions count the leading [CLS]; `token2word` covers only
            # the text tokens, hence the -1. The end token is inclusive, so +1
            # gives an exclusive word end.
            start_word = sample.token2word[entity[0] - 1]
            end_word = sample.token2word[entity[1] - 1] + 1
            if sample.entity_candidates and len(entity) > 2:
                entity_type = sample.entity_candidates[entity[2]]
            else:
                entity_type = -1
            entities.append((start_word, end_word, entity_type))

        triplets = []
        for predicted_triplet, predicted_triplet_probabilities in zip(
            sample.predicted_relations, sample.predicted_relations_probabilities
        ):
            subject, object_, relation = predicted_triplet
            subject = entities[subject]
            object_ = entities[object_]
            relation = sample.candidates[relation]
            triplets.append(
                {
                    "subject": {
                        "start": subject[0],
                        "end": subject[1],
                        "type": subject[2],
                        "name": " ".join(sample.tokens[subject[0] : subject[1]]),
                    },
                    "relation": {
                        "name": relation,
                        "probability": float(predicted_triplet_probabilities.round(2)),
                    },
                    "object": {
                        "start": object_[0],
                        "end": object_[1],
                        "type": object_[2],
                        "name": " ".join(sample.tokens[object_[0] : object_[1]]),
                    },
                }
            )

        sample.predicted_entities = entities
        sample.predicted_relations = triplets
        sample.predicted_relations_probabilities = None

    @torch.no_grad()
    @torch.inference_mode()
    def read(
        self,
        text: Optional[Union[List[str], List[List[str]]]] = None,
        samples: Optional[List[RelikReaderSample]] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        prediction_mask: Optional[torch.Tensor] = None,
        special_symbols_mask: Optional[torch.Tensor] = None,
        special_symbols_mask_entities: Optional[torch.Tensor] = None,
        candidates: Optional[List[List[str]]] = None,
        max_length: Optional[int] = 1024,
        max_batch_size: Optional[int] = 64,
        token_batch_size: Optional[int] = None,
        progress_bar: bool = False,
        *args,
        **kwargs,
    ) -> List[RelikReaderSample]:
        """
        Read the given input and return samples annotated with predictions.

        Args:
            text: A single tokenized text (list of words) or a batch of
                tokenized texts.
            samples: Pre-built samples. NOTE: each sample is mutated in place
                (``NME_SYMBOL`` is prepended to its candidates and
                ``token2word``/``word2token`` are set), so reading the same
                sample twice prepends ``NME_SYMBOL`` twice.
            input_ids: The input ids of the text (used when neither ``samples``
                nor ``text`` + ``candidates`` is given).
            attention_mask: The attention mask of the text.
            token_type_ids: The token type ids of the text.
            prediction_mask: The prediction mask of the text.
            special_symbols_mask: The special symbols mask of the text.
            special_symbols_mask_entities: The special symbols mask entities of
                the text.
            candidates: The relation candidates: one list for a single text,
                or one list per text for a batch.
            max_length: The maximum tokenized length (truncation); falls back
                to the tokenizer's ``model_max_length`` when falsy.
            max_batch_size: The maximum number of samples per batch.
            token_batch_size: The maximum number of (padded) tokens per batch.
            progress_bar: Whether to show a tqdm progress bar over samples.

        Returns:
            The samples, in input order, annotated with their predictions.

        Raises:
            ValueError: If no input is given, if ``text`` is given without
                ``candidates`` and no tensors/samples are provided, or if
                ``text`` and ``candidates`` have different lengths.
        """
        if text is None and input_ids is None and samples is None:
            raise ValueError(
                "Either `text` or `input_ids` or `samples` must be provided."
            )
        if (input_ids is None and samples is None) and (
            text is None or candidates is None
        ):
            raise ValueError(
                "`text` and `candidates` must be provided to return the predictions when `input_ids` and `samples` is not provided."
            )

        # Without candidates, `text` cannot be used; the check above guarantees
        # `input_ids` is available in that case, so fall through to it.
        if text is not None and samples is None and candidates is not None:
            # Wrap a single tokenized text BEFORE comparing lengths; otherwise
            # the check would compare #tokens against #candidates.
            if len(text) > 0 and isinstance(text[0], str):
                text = [text]
                candidates = [candidates]
            if len(text) != len(candidates):
                raise ValueError("`text` and `candidates` must have the same length.")
            samples = [
                RelikReaderSample(tokens=t, candidates=c)
                for t, c in zip(text, candidates)
            ]

        if samples is None:
            return list(
                self.batch_predict(
                    input_ids,
                    attention_mask,
                    token_type_ids,
                    prediction_mask,
                    special_symbols_mask,
                    special_symbols_mask_entities,
                    *args,
                    **kwargs,
                )
            )

        def _collate(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
            """Pad/stack per-sample inputs into one batch on ``self.device``."""
            # All samples in a batch must expose the same number of
            # predictable candidates; the batching loop below enforces this.
            candidate_lengths = [len(elem["predictable_candidates"]) for elem in batch]
            assert len(set(candidate_lengths)) == 1, " ".join(
                map(str, candidate_lengths)
            )

            batchers = self.fields_batcher
            values_by_field = {
                fn: [elem[fn] for elem in batch if fn in elem] for fn in batchers
            }
            # Drop fields that no element provides.
            values_by_field = {
                fn: fvs for fn, fvs in values_by_field.items() if len(fvs) > 0
            }
            # Every kept field must have been collected from the same number
            # of elements.
            assert len({len(fvs) for fvs in values_by_field.values()}) == 1
            # Drop fields where any element holds None.
            values_by_field = {
                fn: fvs
                for fn, fvs in values_by_field.items()
                if all(fv is not None for fv in fvs)
            }

            batch_dict = {}
            for field_name, field_values in values_by_field.items():
                batcher = batchers[field_name]
                # Tokenized tensors are (1, seq_len) (single input with
                # return_tensors="pt"), so fv[0] takes the sequence row.
                batch_dict[field_name] = (
                    batcher([fv[0] for fv in field_values])
                    if batcher is not None
                    else field_values
                )

            device = self.device
            return {
                k: v.to(device) if isinstance(v, torch.Tensor) else v
                for k, v in batch_dict.items()
            }

        current_batch = []
        predictions = []
        current_cand_len = -1

        for sample in tqdm(samples, disable=not progress_bar):
            sample.candidates = [NME_SYMBOL] + sample.candidates
            inputs_text = self._build_input(sample.tokens, sample.candidates)
            model_inputs = self.tokenizer(
                inputs_text,
                is_split_into_words=True,
                add_special_tokens=False,
                padding=False,
                truncation=True,
                max_length=max_length or self.tokenizer.model_max_length,
                return_offsets_mapping=True,
                return_tensors="pt",
            )
            # Ids strictly greater than the base vocab size are treated as the
            # added special symbols. NOTE(review): the id equal to
            # `vocab_size` is deliberately(?) excluded — confirm per tokenizer.
            model_inputs["special_symbols_mask"] = (
                model_inputs["input_ids"] > self.tokenizer.vocab_size
            )
            # 0 before the first special symbol, 1 from it onward.
            model_inputs["token_type_ids"] = (
                torch.cumsum(model_inputs["special_symbols_mask"], dim=1) > 0
            ).long()
            # Shifting left also marks the token right before the first special
            # symbol; first and last positions are forced to 1, so 0 remains
            # only on the text tokens.
            model_inputs["prediction_mask"] = model_inputs["token_type_ids"].roll(
                shifts=-1, dims=1
            )
            model_inputs["prediction_mask"][:, -1] = 1
            model_inputs["prediction_mask"][:, 0] = 1

            assert (
                model_inputs["special_symbols_mask"].shape
                == model_inputs["prediction_mask"].shape
                == model_inputs["input_ids"].shape
            )

            model_inputs["sample"] = sample
            # Only as many candidates as special symbols survived truncation
            # can be predicted.
            model_inputs["predictable_candidates"] = sample.candidates[
                : model_inputs["special_symbols_mask"].sum().item()
            ]

            offsets = model_inputs.pop("offset_mapping")
            offsets = offsets[model_inputs["prediction_mask"] == 0]
            sample.token2word, sample.word2token = self._compute_offsets(offsets)

            # Tensors are (1, seq_len): the token count is on the last dim
            # (len() of the tensor would always be 1).
            future_max_len = max(
                model_inputs["input_ids"].shape[-1],
                max((b["input_ids"].shape[-1] for b in current_batch), default=0),
            )
            future_tokens_per_batch = future_max_len * (len(current_batch) + 1)

            # Flush before adding this sample when the candidate count changes,
            # the padded token budget would be reached, or the batch is full.
            if len(current_batch) > 0 and (
                (
                    len(model_inputs["predictable_candidates"]) != current_cand_len
                    and current_cand_len != -1
                )
                or (
                    isinstance(token_batch_size, int)
                    and future_tokens_per_batch >= token_batch_size
                )
                or len(current_batch) == max_batch_size
            ):
                batch_inputs = _collate(current_batch)
                current_batch = []
                predictions.extend(list(self.batch_predict(**batch_inputs)))
            current_cand_len = len(model_inputs["predictable_candidates"])
            current_batch.append(model_inputs)

        if current_batch:
            batch_inputs = _collate(current_batch)
            predictions.extend(list(self.batch_predict(**batch_inputs)))

        return predictions

    @property
    def device(self) -> torch.device:
        """
        The device of the model's first parameter, or CPU when the module has
        no parameters.
        """
        first_param = next(self.parameters(), None)
        if first_param is None:
            return torch.device("cpu")
        return first_param.device

    @property
    def tokenizer(self) -> tr.PreTrainedTokenizer:
        """
        The tokenizer, loaded lazily and cached.

        A tokenizer given as a name/path is loaded with
        ``AutoTokenizer.from_pretrained``. When none was given, the model
        config's ``name_or_path`` is used.
        """
        if not self._tokenizer:
            self._tokenizer = self.relik_reader_re_model.config.name_or_path
        if isinstance(self._tokenizer, str):
            self._tokenizer = tr.AutoTokenizer.from_pretrained(self._tokenizer)
        return self._tokenizer

    @property
    def fields_batcher(self) -> Dict[str, Union[None, Callable[[list], Any]]]:
        """
        Per-field collate functions used when batching in ``read``.

        A callable pads/stacks the field's values. ``None`` keeps the values
        as a plain list. Padded positions get ``prediction_mask`` 1, the same
        value ``read`` uses for non-text positions. ``token_type_ids`` is not
        batched for RoBERTa-type backbones (presumably because they do not use
        token type embeddings).
        """
        fields_batchers = {
            "input_ids": lambda x: batchify(
                x, padding_value=self.tokenizer.pad_token_id
            ),
            "attention_mask": lambda x: batchify(x, padding_value=0),
            "token_type_ids": lambda x: batchify(x, padding_value=0),
            "prediction_mask": lambda x: batchify(x, padding_value=1),
            "global_attention": lambda x: batchify(x, padding_value=0),
            "token2word": None,
            "sample": None,
            "special_symbols_mask": lambda x: batchify(x, padding_value=False),
            "special_symbols_mask_entities": lambda x: batchify(x, padding_value=False),
        }
        if "roberta" in self.relik_reader_re_model.config.model_type:
            del fields_batchers["token_type_ids"]

        return fields_batchers

    def save_pretrained(
        self,
        output_dir: str,
        model_name: Optional[str] = None,
        push_to_hub: bool = False,
        **kwargs,
    ) -> None:
        """
        Save the model and the tokenizer to ``output_dir / model_name``.

        Args:
            output_dir: The directory to save into; created if missing.
            model_name: The sub-directory name; defaults to
                ``"relik_reader_for_triplet_extraction"``.
            push_to_hub: Whether to push the model and tokenizer to the hub.
            **kwargs: Forwarded to both the model's and the tokenizer's
                ``save_pretrained``.
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        model_name = model_name or "relik_reader_for_triplet_extraction"

        logger.info(f"Saving reader to {output_dir / model_name}")

        # Registers the model class for auto-class loading before saving.
        self.relik_reader_re_model.register_for_auto_class()
        self.relik_reader_re_model.save_pretrained(
            output_dir / model_name, push_to_hub=push_to_hub, **kwargs
        )

        logger.info("Saving reader to disk done.")

        if self.tokenizer:
            self.tokenizer.save_pretrained(
                output_dir / model_name, push_to_hub=push_to_hub, **kwargs
            )
            logger.info("Saving tokenizer to disk done.")
|