|
import collections |
|
import contextlib |
|
import logging |
|
from typing import Any, Dict, Iterator, List |
|
|
|
import torch |
|
import transformers as tr |
|
from lightning_fabric.utilities import move_data_to_device |
|
from torch.utils.data import DataLoader, IterableDataset |
|
from tqdm import tqdm |
|
|
|
from relik.common.log import get_console_logger, get_logger |
|
from relik.common.utils import get_callable_from_string |
|
from relik.reader.data.relik_reader_sample import RelikReaderSample |
|
from relik.reader.pytorch_modules.base import RelikReaderBase |
|
from relik.reader.utils.special_symbols import get_special_symbols |
|
from relik.retriever.pytorch_modules import PRECISION_MAP |
|
|
|
# Module-level loggers, both obtained from `relik.common.log`: a console
# logger and a named logger for this module at INFO level.
console_logger, logger = (
    get_console_logger(),
    get_logger(__name__, level=logging.INFO),
)
|
|
|
|
|
class RelikReaderForSpanExtraction(RelikReaderBase):
    """
    A class for the RelikReader model for span extraction.

    Args:
        transformer_model (:obj:`str` or :obj:`transformers.PreTrainedModel` or :obj:`None`, `optional`):
            The transformer model to use. If `None`, the default model is used.
        additional_special_symbols (:obj:`int`, `optional`, defaults to 0):
            The number of additional special symbols to add to the tokenizer.
        num_layers (:obj:`int`, `optional`):
            The number of layers to use. If `None`, all layers are used.
        activation (:obj:`str`, `optional`, defaults to "gelu"):
            The activation function to use.
        linears_hidden_size (:obj:`int`, `optional`, defaults to 512):
            The hidden size of the linears.
        use_last_k_layers (:obj:`int`, `optional`, defaults to 1):
            The number of last layers to use.
        training (:obj:`bool`, `optional`, defaults to False):
            Whether the model is in training mode.
        device (:obj:`str` or :obj:`torch.device` or :obj:`None`, `optional`):
            The device to use. If `None`, the default device is used.
        tokenizer (:obj:`str` or :obj:`transformers.PreTrainedTokenizer` or :obj:`None`, `optional`):
            The tokenizer to use. If `None`, the default tokenizer is used.
        dataset (:obj:`IterableDataset` or :obj:`str` or :obj:`None`, `optional`):
            The dataset used to batch samples at prediction time. If `None`, an
            instance of `default_data_class` is built for inference.
        dataset_kwargs (:obj:`Dict[str, Any]` or :obj:`None`, `optional`):
            Keyword arguments that override the defaults used to build the
            inference dataset when `dataset` is `None`.
        default_reader_class (:obj:`str` or :obj:`transformers.PreTrainedModel` or :obj:`None`, `optional`):
            The default reader class to use. If `None`, the default reader class is used.
        **kwargs:
            Keyword arguments forwarded to `RelikReaderBase.__init__`.
    """

    # Dotted paths resolved lazily (the data class via `get_callable_from_string`).
    default_reader_class: str = (
        "relik.reader.pytorch_modules.hf.modeling_relik.RelikReaderSpanModel"
    )
    default_data_class: str = "relik.reader.data.relik_reader_data.RelikDataset"

    def __init__(
        self,
        transformer_model: str | tr.PreTrainedModel | None = None,
        additional_special_symbols: int = 0,
        num_layers: int | None = None,
        activation: str = "gelu",
        linears_hidden_size: int | None = 512,
        use_last_k_layers: int = 1,
        training: bool = False,
        device: str | torch.device | None = None,
        tokenizer: str | tr.PreTrainedTokenizer | None = None,
        dataset: IterableDataset | str | None = None,
        dataset_kwargs: Dict[str, Any] | None = None,
        default_reader_class: tr.PreTrainedModel | str | None = None,
        **kwargs,
    ):
        super().__init__(
            transformer_model=transformer_model,
            additional_special_symbols=additional_special_symbols,
            num_layers=num_layers,
            activation=activation,
            linears_hidden_size=linears_hidden_size,
            use_last_k_layers=use_last_k_layers,
            training=training,
            device=device,
            tokenizer=tokenizer,
            dataset=dataset,
            default_reader_class=default_reader_class,
            **kwargs,
        )

        self.dataset = dataset
        if self.dataset is None:
            # `self.tokenizer` and `self.relik_reader_model` are presumably set
            # up by `RelikReaderBase.__init__` above — TODO confirm.
            default_data_kwargs = dict(
                dataset_path=None,
                materialize_samples=False,
                transformer_model=self.tokenizer,
                special_symbols=get_special_symbols(
                    self.relik_reader_model.config.additional_special_symbols
                ),
                for_inference=True,
            )
            # Caller-supplied kwargs take precedence over the defaults above.
            default_data_kwargs.update(dataset_kwargs or {})
            self.dataset = get_callable_from_string(self.default_data_class)(
                **default_data_kwargs
            )

    @torch.no_grad()
    @torch.inference_mode()
    def _read(
        self,
        samples: List[RelikReaderSample] | None = None,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        prediction_mask: torch.Tensor | None = None,
        special_symbols_mask: torch.Tensor | None = None,
        max_length: int = 1000,
        max_batch_size: int = 128,
        token_batch_size: int = 2048,
        precision: str | int = 32,
        annotation_type: str = "char",
        progress_bar: bool = False,
        *args: object,
        **kwargs: object,
    ) -> List[RelikReaderSample] | List[List[RelikReaderSample]]:
        """
        A wrapper around the forward method that returns the predicted labels for each sample.

        With `samples`, the samples are batched through `self.dataset`, predicted
        batch by batch, returned in their original order, and post-processed with
        `self.dataset.merge_patches_predictions` and
        `self.dataset.convert_tokens_to_char_annotations`. Without `samples`, the
        tensor arguments are passed straight to `_batch_predict`.

        Args:
            samples (:obj:`List[RelikReaderSample]`, `optional`):
                The samples to read. If provided, the tensor arguments are ignored.
                Each sample gets `_mixin_prediction_position` set to its index.
            input_ids (:obj:`torch.Tensor`, `optional`):
                The input ids of the text. If `samples` is provided, this is ignored.
            attention_mask (:obj:`torch.Tensor`, `optional`):
                The attention mask of the text. If `samples` is provided, this is ignored.
            token_type_ids (:obj:`torch.Tensor`, `optional`):
                The token type ids of the text. If `samples` is provided, this is ignored.
            prediction_mask (:obj:`torch.Tensor`, `optional`):
                The prediction mask of the text. If `samples` is provided, this is ignored.
            special_symbols_mask (:obj:`torch.Tensor`, `optional`):
                The special symbols mask of the text. If `samples` is provided, this is ignored.
            max_length (:obj:`int`, `optional`, defaults to 1000):
                Assigned to `self.dataset.model_max_length`.
            max_batch_size (:obj:`int`, `optional`, defaults to 128):
                Assigned to `self.dataset.max_batch_size`.
            token_batch_size (:obj:`int`, `optional`, defaults to 2048):
                Assigned to `self.dataset.tokens_per_batch`.
            precision (:obj:`str` or :obj:`int`, `optional`, defaults to 32):
                Key into `PRECISION_MAP` selecting the autocast dtype on non-CPU
                devices. Falls back to `self.precision` when falsy.
            annotation_type (:obj:`str`, `optional`, defaults to "char"):
                Documented as "char", "token" or "word". NOTE(review): this
                argument is not used by this method; outputs are always passed
                through `convert_tokens_to_char_annotations`.
            progress_bar (:obj:`bool`, `optional`, defaults to False):
                Whether to show a progress bar.
            *args:
                Positional arguments, forwarded to `_batch_predict` only when
                `samples` is `None`.
            **kwargs:
                Keyword arguments, forwarded to `_batch_predict` only when
                `samples` is `None` (which then needs `sample`,
                `predictable_candidates` and `patch_offset`).

        Returns:
            :obj:`List[RelikReaderSample]` or :obj:`List[List[RelikReaderSample]]`:
                The predicted samples.

        Raises:
            ValueError: if `samples` is given but `self.dataset` is `None`.
        """
        precision = precision or self.precision
        if samples is not None:

            def _read_iterator():
                def samples_it():
                    # Tag each sample with its input index so predictions can be
                    # re-ordered below. NOTE(review): `assert` is stripped under
                    # `python -O`, and re-reading the same sample objects trips it
                    # unless something resets the position in between.
                    for i, sample in enumerate(samples):
                        assert sample._mixin_prediction_position is None
                        sample._mixin_prediction_position = i
                        yield sample

                next_prediction_position = 0
                position2predicted_sample = {}

                if self.dataset is None:
                    raise ValueError(
                        "You need to pass a dataset to the model in order to predict"
                    )
                self.dataset.samples = samples_it()
                self.dataset.model_max_length = max_length
                self.dataset.tokens_per_batch = token_batch_size
                self.dataset.max_batch_size = max_batch_size

                # batch_size=None: the dataset yields ready-made batches itself.
                iterator = DataLoader(
                    self.dataset, batch_size=None, num_workers=0, shuffle=False
                )
                if progress_bar:
                    iterator = tqdm(iterator, desc="Predicting with RelikReader")

                # Autocast only off-CPU; on CPU run in the model's own dtype.
                device_type_for_autocast = str(self.device).split(":")[0]
                autocast_mngr = (
                    contextlib.nullcontext()
                    if device_type_for_autocast == "cpu"
                    else torch.autocast(
                        device_type=device_type_for_autocast,
                        dtype=PRECISION_MAP[precision],
                    )
                )

                with autocast_mngr:
                    for batch in iterator:
                        batch = move_data_to_device(batch, self.device)
                        batch_out = self._batch_predict(**batch)

                        for sample in batch_out:
                            # Positions already emitted are not buffered again.
                            # `_batch_predict` writes patch results into the
                            # sample in place, so if the dataset re-yields the
                            # same object per patch (presumably — TODO confirm)
                            # those results still land on the emitted sample.
                            if (
                                sample._mixin_prediction_position
                                >= next_prediction_position
                            ):
                                position2predicted_sample[
                                    sample._mixin_prediction_position
                                ] = sample

                        # Emit the longest contiguous run of ready positions so
                        # output order matches input order.
                        while next_prediction_position in position2predicted_sample:
                            yield position2predicted_sample[next_prediction_position]
                            del position2predicted_sample[next_prediction_position]
                            next_prediction_position += 1
                # NOTE(review): anything still buffered here (a position that
                # never arrived) is silently not yielded.

            outputs = list(_read_iterator())
            for sample in outputs:
                self.dataset.merge_patches_predictions(sample)
                self.dataset.convert_tokens_to_char_annotations(sample)

        else:
            outputs = list(
                self._batch_predict(
                    input_ids,
                    attention_mask,
                    token_type_ids,
                    prediction_mask,
                    special_symbols_mask,
                    *args,
                    **kwargs,
                )
            )
        return outputs

    def _batch_predict(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: torch.Tensor | None = None,
        prediction_mask: torch.Tensor | None = None,
        special_symbols_mask: torch.Tensor | None = None,
        sample: List[RelikReaderSample] | None = None,
        top_k: int = 5,
        *args,
        **kwargs,
    ) -> Iterator[RelikReaderSample]:
        """
        A wrapper around the forward method that returns the predicted labels for each sample.

        It also adds the predicted labels to the samples: for each sample, a dict
        is stored under `sample._d["patches"][patch_offset]` with keys
        `predicted_window_labels`, `span_title_probabilities` and
        `predictable_candidates`.

        Args:
            input_ids (:obj:`torch.Tensor`):
                The input ids of the text.
            attention_mask (:obj:`torch.Tensor`):
                The attention mask of the text.
            token_type_ids (:obj:`torch.Tensor`, `optional`):
                The token type ids of the text.
            prediction_mask (:obj:`torch.Tensor`, `optional`):
                The prediction mask of the text.
            special_symbols_mask (:obj:`torch.Tensor`, `optional`):
                The special symbols mask of the text.
            sample (:obj:`List[RelikReaderSample]`):
                The batch's samples, aligned with the tensors. Although it
                defaults to `None`, it is iterated unconditionally, so it must
                be provided.
            top_k (:obj:`int`, `optional`, defaults to 5):
                The amount of top-k most probable entities to keep per span;
                `-1` keeps all of them.
            *args:
                Positional arguments (unused).
            **kwargs:
                Must contain `predictable_candidates` (per-sample candidate
                lists) and `patch_offset` (per-sample patch keys).

        Returns:
            An iterator over the samples, updated in place.
        """
        forward_output = self.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            prediction_mask=prediction_mask,
            special_symbols_mask=special_symbols_mask,
        )

        ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy()
        ned_end_predictions = forward_output["ned_end_predictions"].cpu().numpy()
        ed_predictions = forward_output["ed_predictions"].cpu().numpy()
        ed_probabilities = forward_output["ed_probabilities"].cpu().numpy()

        batch_predictable_candidates = kwargs["predictable_candidates"]
        patch_offset = kwargs["patch_offset"]
        for ts, ne_sp, ne_ep, edp, edpr, pred_cands, po in zip(
            sample,
            ned_start_predictions,
            ned_end_predictions,
            ed_predictions,
            ed_probabilities,
            batch_predictable_candidates,
            patch_offset,
        ):
            # Position 0 is skipped (presumably the CLS token), so span indices
            # are one less than the model-output positions they come from.
            ne_start_indices = [ti for ti, c in enumerate(ne_sp[1:]) if c > 0]
            ne_end_indices = [ti for ti, c in enumerate(ne_ep[1:]) if c > 0]

            final_class2predicted_spans = collections.defaultdict(list)
            spans2predicted_probabilities = dict()
            for start_token_index, end_token_index in zip(
                ne_start_indices, ne_end_indices
            ):
                # `+ 1` undoes the CLS skip above. The `- 1` shifts model class
                # ids onto `pred_cands`. NOTE(review): a class id of 0 maps to
                # `pred_cands[-1]` (the last candidate) — confirm that is intended.
                token_class = edp[start_token_index + 1] - 1
                predicted_candidate_title = pred_cands[token_class]
                final_class2predicted_spans[predicted_candidate_title].append(
                    [start_token_index, end_token_index]
                )

                # Rank all classes for this span by probability, highest first.
                classes_probabilities = edpr[start_token_index + 1]
                classes_probabilities_best_indices = classes_probabilities.argsort()[
                    ::-1
                ]
                num_classes = len(classes_probabilities_best_indices)
                # Clamp into a local so the `top_k` argument is not overwritten
                # for the remaining spans and samples.
                k = num_classes if top_k == -1 else min(top_k, num_classes)
                titles_2_probs = []
                for i in range(k):
                    class_index = classes_probabilities_best_indices[i]
                    titles_2_probs.append(
                        (
                            # same `- 1` shift (and wrap caveat) as above
                            pred_cands[class_index - 1],
                            classes_probabilities[class_index].item(),
                        )
                    )
                spans2predicted_probabilities[
                    (start_token_index, end_token_index)
                ] = titles_2_probs

            if "patches" not in ts._d:
                ts._d["patches"] = dict()
            # One entry per patch; patches are later combined by the dataset's
            # `merge_patches_predictions`.
            ts._d["patches"][po] = dict()
            sample_patch = ts._d["patches"][po]
            sample_patch["predicted_window_labels"] = final_class2predicted_spans
            sample_patch["span_title_probabilities"] = spans2predicted_probabilities
            sample_patch["predictable_candidates"] = pred_cands

            yield ts
|
|