import collections
import contextlib
import logging
from typing import Any, Dict, Iterator, List

import torch
import transformers as tr
from lightning_fabric.utilities import move_data_to_device
from torch.utils.data import DataLoader, IterableDataset
from tqdm import tqdm

from relik.common.log import get_console_logger, get_logger
from relik.common.utils import get_callable_from_string
from relik.reader.data.relik_reader_sample import RelikReaderSample
from relik.reader.pytorch_modules.base import RelikReaderBase
from relik.reader.utils.special_symbols import get_special_symbols
from relik.retriever.pytorch_modules import PRECISION_MAP

console_logger = get_console_logger()
logger = get_logger(__name__, level=logging.INFO)


class RelikReaderForSpanExtraction(RelikReaderBase):
"""
A class for the RelikReader model for span extraction.
Args:
transformer_model (:obj:`str` or :obj:`transformers.PreTrainedModel` or :obj:`None`, `optional`):
The transformer model to use. If `None`, the default model is used.
additional_special_symbols (:obj:`int`, `optional`, defaults to 0):
The number of additional special symbols to add to the tokenizer.
num_layers (:obj:`int`, `optional`):
The number of layers to use. If `None`, all layers are used.
activation (:obj:`str`, `optional`, defaults to "gelu"):
The activation function to use.
        linears_hidden_size (:obj:`int`, `optional`, defaults to 512):
            The hidden size of the linear projection layers.
use_last_k_layers (:obj:`int`, `optional`, defaults to 1):
The number of last layers to use.
training (:obj:`bool`, `optional`, defaults to False):
Whether the model is in training mode.
device (:obj:`str` or :obj:`torch.device` or :obj:`None`, `optional`):
The device to use. If `None`, the default device is used.
tokenizer (:obj:`str` or :obj:`transformers.PreTrainedTokenizer` or :obj:`None`, `optional`):
The tokenizer to use. If `None`, the default tokenizer is used.
dataset (:obj:`IterableDataset` or :obj:`str` or :obj:`None`, `optional`):
The dataset to use. If `None`, the default dataset is used.
dataset_kwargs (:obj:`Dict[str, Any]` or :obj:`None`, `optional`):
The keyword arguments to pass to the dataset class.
default_reader_class (:obj:`str` or :obj:`transformers.PreTrainedModel` or :obj:`None`, `optional`):
The default reader class to use. If `None`, the default reader class is used.
**kwargs:
Keyword arguments.
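    Example::

        # a minimal construction sketch; the checkpoint path is hypothetical
        reader = RelikReaderForSpanExtraction(
            "path/to/relik-reader-span-checkpoint", device="cuda"
        )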
"""
default_reader_class: str = (
"relik.reader.pytorch_modules.hf.modeling_relik.RelikReaderSpanModel"
)
    default_data_class: str = "relik.reader.data.relik_reader_data.RelikDataset"

    def __init__(
self,
transformer_model: str | tr.PreTrainedModel | None = None,
additional_special_symbols: int = 0,
num_layers: int | None = None,
activation: str = "gelu",
linears_hidden_size: int | None = 512,
use_last_k_layers: int = 1,
training: bool = False,
device: str | torch.device | None = None,
tokenizer: str | tr.PreTrainedTokenizer | None = None,
dataset: IterableDataset | str | None = None,
dataset_kwargs: Dict[str, Any] | None = None,
default_reader_class: tr.PreTrainedModel | str | None = None,
**kwargs,
):
super().__init__(
transformer_model=transformer_model,
additional_special_symbols=additional_special_symbols,
num_layers=num_layers,
activation=activation,
linears_hidden_size=linears_hidden_size,
use_last_k_layers=use_last_k_layers,
training=training,
device=device,
tokenizer=tokenizer,
dataset=dataset,
default_reader_class=default_reader_class,
**kwargs,
)
        # instantiate the dataset class if none was provided
self.dataset = dataset
if self.dataset is None:
default_data_kwargs = dict(
dataset_path=None,
materialize_samples=False,
transformer_model=self.tokenizer,
special_symbols=get_special_symbols(
self.relik_reader_model.config.additional_special_symbols
),
for_inference=True,
)
# merge the default data kwargs with the ones passed to the model
default_data_kwargs.update(dataset_kwargs or {})
self.dataset = get_callable_from_string(self.default_data_class)(
**default_data_kwargs
            )

    @torch.inference_mode()  # supersedes no_grad(); one decorator suffices
    def _read(
self,
samples: List[RelikReaderSample] | None = None,
input_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
token_type_ids: torch.Tensor | None = None,
prediction_mask: torch.Tensor | None = None,
special_symbols_mask: torch.Tensor | None = None,
max_length: int = 1000,
max_batch_size: int = 128,
token_batch_size: int = 2048,
        precision: int | str = 32,
annotation_type: str = "char",
progress_bar: bool = False,
*args: object,
**kwargs: object,
) -> List[RelikReaderSample] | List[List[RelikReaderSample]]:
"""
A wrapper around the forward method that returns the predicted labels for each sample.
Args:
            samples (:obj:`List[RelikReaderSample]`, `optional`):
                The samples to read. If provided, the tensor inputs below are ignored.
input_ids (:obj:`torch.Tensor`, `optional`):
The input ids of the text. If `samples` is provided, this is ignored.
attention_mask (:obj:`torch.Tensor`, `optional`):
The attention mask of the text. If `samples` is provided, this is ignored.
token_type_ids (:obj:`torch.Tensor`, `optional`):
The token type ids of the text. If `samples` is provided, this is ignored.
prediction_mask (:obj:`torch.Tensor`, `optional`):
The prediction mask of the text. If `samples` is provided, this is ignored.
special_symbols_mask (:obj:`torch.Tensor`, `optional`):
The special symbols mask of the text. If `samples` is provided, this is ignored.
max_length (:obj:`int`, `optional`, defaults to 1000):
The maximum length of the text.
max_batch_size (:obj:`int`, `optional`, defaults to 128):
The maximum batch size.
            token_batch_size (:obj:`int`, `optional`, defaults to 2048):
                The token batch size.
            precision (:obj:`int` or :obj:`str`, `optional`, defaults to 32):
                The precision to use for the model.
            annotation_type (:obj:`str`, `optional`, defaults to "char"):
                The annotation type to use. It can be either "char", "token" or "word".
            progress_bar (:obj:`bool`, `optional`, defaults to False):
                Whether to show a progress bar.
*args:
Positional arguments.
**kwargs:
Keyword arguments.
Returns:
:obj:`List[RelikReaderSample]` or :obj:`List[List[RelikReaderSample]]`:
The predicted labels for each sample.
"""
precision = precision or self.precision
if samples is not None:
def _read_iterator():
def samples_it():
for i, sample in enumerate(samples):
assert sample._mixin_prediction_position is None
sample._mixin_prediction_position = i
yield sample
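                # batches may finish out of their original order, so buffer
                # predicted samples by position and yield them back in order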
next_prediction_position = 0
position2predicted_sample = {}
                # the dataset must have been built in __init__; fail fast otherwise
if self.dataset is None:
raise ValueError(
"You need to pass a dataset to the model in order to predict"
)
self.dataset.samples = samples_it()
self.dataset.model_max_length = max_length
self.dataset.tokens_per_batch = token_batch_size
self.dataset.max_batch_size = max_batch_size
# instantiate dataloader
iterator = DataLoader(
self.dataset, batch_size=None, num_workers=0, shuffle=False
)
if progress_bar:
iterator = tqdm(iterator, desc="Predicting with RelikReader")
            # autocast only accepts bare device-type strings such as "cpu" or
            # "cuda", so strip any device index (e.g. "cuda:0") first
            device_type_for_autocast = str(self.device).split(":")[0]
            # on CPU, autocast supports only bfloat16, so skip it there entirely
autocast_mngr = (
contextlib.nullcontext()
if device_type_for_autocast == "cpu"
else (
torch.autocast(
device_type=device_type_for_autocast,
dtype=PRECISION_MAP[precision],
)
)
)
with autocast_mngr:
for batch in iterator:
batch = move_data_to_device(batch, self.device)
batch_out = self._batch_predict(**batch)
for sample in batch_out:
if (
sample._mixin_prediction_position
>= next_prediction_position
):
position2predicted_sample[
sample._mixin_prediction_position
] = sample
                    # yield back, in input order, every sample that is ready
while next_prediction_position in position2predicted_sample:
yield position2predicted_sample[next_prediction_position]
del position2predicted_sample[next_prediction_position]
next_prediction_position += 1
outputs = list(_read_iterator())
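            # stitch per-window (patch) predictions back together and convert
            # token-level spans into character-level annotations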
for sample in outputs:
self.dataset.merge_patches_predictions(sample)
self.dataset.convert_tokens_to_char_annotations(sample)
else:
outputs = list(
self._batch_predict(
input_ids,
attention_mask,
token_type_ids,
prediction_mask,
special_symbols_mask,
*args,
**kwargs,
)
)
        return outputs

    def _batch_predict(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
token_type_ids: torch.Tensor | None = None,
prediction_mask: torch.Tensor | None = None,
special_symbols_mask: torch.Tensor | None = None,
sample: List[RelikReaderSample] | None = None,
top_k: int = 5, # the amount of top-k most probable entities to predict
*args,
**kwargs,
) -> Iterator[RelikReaderSample]:
"""
        Runs the forward pass on a single batch and attaches the predicted
        labels to the corresponding samples, yielding them one by one.
Args:
input_ids (:obj:`torch.Tensor`):
The input ids of the text.
attention_mask (:obj:`torch.Tensor`):
The attention mask of the text.
token_type_ids (:obj:`torch.Tensor`, `optional`):
The token type ids of the text.
prediction_mask (:obj:`torch.Tensor`, `optional`):
The prediction mask of the text.
special_symbols_mask (:obj:`torch.Tensor`, `optional`):
The special symbols mask of the text.
            sample (:obj:`List[RelikReaderSample]`, `optional`):
                The samples corresponding to the batch; predictions are attached to them in place.
top_k (:obj:`int`, `optional`, defaults to 5):
The amount of top-k most probable entities to predict.
*args:
Positional arguments.
**kwargs:
Keyword arguments.
        Returns:
            :obj:`Iterator[RelikReaderSample]`:
                The samples, enriched in place with the predicted labels.
"""
forward_output = self.forward(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
prediction_mask=prediction_mask,
special_symbols_mask=special_symbols_mask,
)
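        # move the model outputs to CPU numpy arrays for per-sample decoding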
ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy()
ned_end_predictions = forward_output["ned_end_predictions"].cpu().numpy()
ed_predictions = forward_output["ed_predictions"].cpu().numpy()
ed_probabilities = forward_output["ed_probabilities"].cpu().numpy()
batch_predictable_candidates = kwargs["predictable_candidates"]
patch_offset = kwargs["patch_offset"]
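        # `predictable_candidates` and `patch_offset` are expected in kwargs,
        # injected per batch by the dataset's collation step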
for ts, ne_sp, ne_ep, edp, edpr, pred_cands, po in zip(
sample,
ned_start_predictions,
ned_end_predictions,
ed_predictions,
ed_probabilities,
batch_predictable_candidates,
patch_offset,
):
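            # position 0 holds the first special token, so predictions are read
            # from index 1 onward (hence the [1:] here and the +1 offsets below)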
ne_start_indices = [ti for ti, c in enumerate(ne_sp[1:]) if c > 0]
ne_end_indices = [ti for ti, c in enumerate(ne_ep[1:]) if c > 0]
final_class2predicted_spans = collections.defaultdict(list)
spans2predicted_probabilities = dict()
for start_token_index, end_token_index in zip(
ne_start_indices, ne_end_indices
):
# predicted candidate
token_class = edp[start_token_index + 1] - 1
predicted_candidate_title = pred_cands[token_class]
final_class2predicted_spans[predicted_candidate_title].append(
[start_token_index, end_token_index]
)
# candidates probabilities
classes_probabilities = edpr[start_token_index + 1]
classes_probabilities_best_indices = classes_probabilities.argsort()[
::-1
]
titles_2_probs = []
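                # top_k == -1 means "keep all candidates"; otherwise clamp it
                # to the number of available candidates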
top_k = (
min(
top_k,
len(classes_probabilities_best_indices),
)
if top_k != -1
else len(classes_probabilities_best_indices)
)
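                # index 0 of the probability vector appears to be the
                # "no entity" class (see the -1 shift on `edp` above), hence
                # the -1 offset into the candidate list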
for i in range(top_k):
titles_2_probs.append(
(
pred_cands[classes_probabilities_best_indices[i] - 1],
classes_probabilities[
classes_probabilities_best_indices[i]
].item(),
)
)
spans2predicted_probabilities[
(start_token_index, end_token_index)
] = titles_2_probs
if "patches" not in ts._d:
ts._d["patches"] = dict()
ts._d["patches"][po] = dict()
sample_patch = ts._d["patches"][po]
sample_patch["predicted_window_labels"] = final_class2predicted_spans
sample_patch["span_title_probabilities"] = spans2predicted_probabilities
# additional info
sample_patch["predictable_candidates"] = pred_cands
yield ts
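

if __name__ == "__main__":
    # Minimal smoke-test sketch. It assumes the base class exposes a public
    # `read()` accepting raw `text` and entity `candidates`; the checkpoint
    # path and the example inputs below are hypothetical.
    reader = RelikReaderForSpanExtraction(
        "path/to/relik-reader-span-checkpoint", device="cpu"
    )
    predictions = reader.read(
        text=["Michael Jordan was one of the best players in the NBA."],
        candidates=[["Michael Jordan", "NBA", "Basketball"]],
        progress_bar=True,
    )
    print(predictions)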