|
import os |
|
from pathlib import Path |
|
from typing import Any, Callable, Dict, Optional, Union |
|
|
|
import hydra |
|
from omegaconf import OmegaConf |
|
from relik.retriever.indexers.faiss import FaissDocumentIndex |
|
from relik.retriever.pytorch_modules.hf import GoldenRetrieverModel |
|
from rich.pretty import pprint |
|
|
|
from relik.common.log import get_console_logger, get_logger |
|
from relik.common.upload import upload |
|
from relik.common.utils import CONFIG_NAME, from_cache, get_callable_from_string |
|
from relik.inference.data.objects import EntitySpan, RelikOutput |
|
from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer |
|
from relik.inference.data.window.manager import WindowManager |
|
from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction |
|
from relik.reader.relik_reader import RelikReader |
|
from relik.retriever.data.utils import batch_generator |
|
from relik.retriever.indexers.base import BaseDocumentIndex |
|
from relik.retriever.pytorch_modules.model import GoldenRetriever |
|
|
|
# Module-level logger (from `relik.common.log.get_logger`) used for status
# messages such as "Saving relik config to ...".
logger = get_logger(__name__)
# Console handed to rich's `pprint` to pretty-print configurations.
console_logger = get_console_logger()
|
|
|
|
|
class Relik:
    """
    Relik main class. It is a wrapper around a retriever and a reader.

    Args:
        retriever (`Optional[GoldenRetriever]`, `optional`):
            The retriever to use. If `None`, a retriever will be instantiated from the
            provided `question_encoder`, `passage_encoder` and `document_index`.
            Defaults to `None`.
        question_encoder (`Optional[Union[str, GoldenRetrieverModel]]`, `optional`):
            The question encoder to use. If `retriever` is `None`, a retriever will be
            instantiated from this parameter. Either this or `retriever` must be
            provided. Defaults to `None`.
        passage_encoder (`Optional[Union[str, GoldenRetrieverModel]]`, `optional`):
            The passage encoder to use. If `retriever` is `None`, a retriever will be
            instantiated from this parameter. Defaults to `None`.
        document_index (`Optional[Union[str, BaseDocumentIndex]]`, `optional`):
            The document index to use. If `retriever` is `None`, a retriever will be
            instantiated from this parameter. Defaults to `None`.
        reader (`Optional[Union[str, RelikReader]]`, `optional`):
            The reader to use. If a string is given, a `RelikReaderForSpanExtraction`
            is instantiated from it with `reader_kwargs`; any other value is used
            as-is. Defaults to `None`.
        device (`str`, `optional`, defaults to `cpu`):
            Fallback device for every component whose specific device is not set.
        retriever_device (`str`, `optional`):
            The device to use for the retriever. Defaults to `device`.
        document_index_device (`str`, `optional`):
            The device to use for the document index. Defaults to `device`.
        reader_device (`str`, `optional`):
            Stored as `self.reader_device`. Defaults to `device`.
        precision (`int`, `optional`, defaults to `32`):
            Fallback precision for every component whose specific precision is not set.
        retriever_precision (`int`, `optional`):
            The precision of the retriever. Defaults to `precision`.
        document_index_precision (`int`, `optional`):
            The precision of the document index. Defaults to `precision`.
        reader_precision (`int`, `optional`):
            Stored as `self.reader_precision`. Defaults to `precision`.
        reader_kwargs (`dict`, `optional`):
            Extra keyword arguments for `RelikReaderForSpanExtraction` when `reader`
            is a string.
        retriever_kwargs (`dict`, `optional`):
            Extra keyword arguments merged into (and overriding) the `GoldenRetriever`
            constructor arguments when `retriever` is `None`.
        candidates_preprocessing_fn (`str` or `Callable`, `optional`):
            Function applied to every retrieved candidate label before reading. A
            string is resolved with `get_callable_from_string`. Defaults to the
            identity function.
        top_k (`int`, `optional`):
            Default number of candidates retrieved per window in `__call__`.
        window_size (`int`, `optional`):
            Default window size used in `__call__`.
        window_stride (`int`, `optional`):
            Default window stride used in `__call__`.
        **kwargs:
            Accepted and ignored.
    """

    # Retriever constructor arguments that `save_pretrained` already writes as
    # top-level config entries (as names/paths). Inside `retriever_kwargs` their
    # values may be live model/index objects that OmegaConf cannot serialize, so
    # they are left out of the saved `retriever_kwargs`.
    _TOP_LEVEL_RETRIEVER_KEYS = ("question_encoder", "passage_encoder", "document_index")

    def __init__(
        self,
        retriever: GoldenRetriever | None = None,
        question_encoder: str | GoldenRetrieverModel | None = None,
        passage_encoder: str | GoldenRetrieverModel | None = None,
        document_index: str | BaseDocumentIndex | None = None,
        reader: str | RelikReader | None = None,
        device: str = "cpu",
        retriever_device: str | None = None,
        document_index_device: str | None = None,
        reader_device: str | None = None,
        precision: int = 32,
        retriever_precision: int | None = None,
        document_index_precision: int | None = None,
        reader_precision: int | None = None,
        reader_kwargs: dict | None = None,
        retriever_kwargs: dict | None = None,
        candidates_preprocessing_fn: str | Callable | None = None,
        top_k: int | None = None,
        window_size: int | None = None,
        window_stride: int | None = None,
        **kwargs,
    ) -> None:
        # Component-specific settings fall back to the global ones.
        retriever_device = retriever_device or device
        document_index_device = document_index_device or device
        retriever_precision = retriever_precision or precision
        document_index_precision = document_index_precision or precision

        if retriever is None and question_encoder is None:
            raise ValueError(
                "Either `retriever` or `question_encoder` must be provided"
            )

        if retriever is None:
            self.retriever_kwargs = dict(
                question_encoder=question_encoder,
                passage_encoder=passage_encoder,
                document_index=document_index,
                device=retriever_device,
                precision=retriever_precision,
                index_device=document_index_device,
                index_precision=document_index_precision,
            )
            # User-supplied kwargs take precedence over the defaults above.
            self.retriever_kwargs.update(retriever_kwargs or {})
            retriever = GoldenRetriever(**self.retriever_kwargs)
        else:
            # A ready-made retriever was given, so nothing is built from kwargs;
            # the attribute must still exist because `save_pretrained` reads it.
            self.retriever_kwargs = retriever_kwargs

        # Relik is an inference wrapper: always put the retriever in eval mode.
        retriever.training = False
        retriever.eval()
        self.retriever = retriever

        self.reader_device = reader_device or device
        self.reader_precision = reader_precision or precision
        self.reader_kwargs = reader_kwargs
        if isinstance(reader, str):
            # NOTE(review): `reader_device` / `reader_precision` are stored but not
            # forwarded to the reader constructor here — confirm whether they
            # should be passed through `reader_kwargs`.
            reader_kwargs = reader_kwargs or {}
            reader = RelikReaderForSpanExtraction(reader, **reader_kwargs)
        self.reader = reader

        self.tokenizer = SpacyTokenizer(language="en")
        # Created lazily on first use in `__call__`.
        self.window_manager: WindowManager | None = None

        candidates_preprocessing_fn = candidates_preprocessing_fn or (lambda x: x)
        if isinstance(candidates_preprocessing_fn, str):
            candidates_preprocessing_fn = get_callable_from_string(
                candidates_preprocessing_fn
            )
        self.candidates_preprocessing_fn = candidates_preprocessing_fn

        # Defaults for the corresponding `__call__` arguments.
        self.top_k = top_k
        self.window_size = window_size
        self.window_stride = window_stride

    def __call__(
        self,
        text: Union[str, list],
        top_k: Optional[int] = None,
        window_size: Optional[int] = None,
        window_stride: Optional[int] = None,
        retriever_batch_size: Optional[int] = 32,
        reader_batch_size: Optional[int] = 32,
        return_also_windows: bool = False,
        **kwargs,
    ) -> Union[RelikOutput, list[RelikOutput]]:
        """
        Annotate a text with entities.

        Args:
            text (`str` or `list`):
                The text to annotate. If a list is provided, each element of the list
                will be annotated separately.
            top_k (`int`, `optional`, defaults to `None`):
                The number of candidates to retrieve for each window. Falls back to
                the instance `top_k`, then to 100.
            window_size (`int`, `optional`, defaults to `None`):
                The size of the window. Falls back to the instance `window_size`; if
                that is also `None`, each text is annotated as a single window.
            window_stride (`int`, `optional`, defaults to `None`):
                The stride of the window. Falls back to the instance `window_stride`;
                if that is also `None`, there will be no overlap between windows
                (the stride equals the window size).
            retriever_batch_size (`int`, `optional`, defaults to `32`):
                The number of windows sent to the retriever at once.
            reader_batch_size (`int`, `optional`, defaults to `32`):
                The maximum batch size passed to the reader.
            return_also_windows (`bool`, `optional`, defaults to `False`):
                Whether to attach the merged windows of each text to its output as
                a `windows` attribute.
            **kwargs:
                Accepted for compatibility; currently ignored.

        Returns:
            `RelikOutput` or `list[RelikOutput]`:
                The annotated text. If a list was provided as input, a list of
                `RelikOutput` objects (in input order) will be returned.

        Raises:
            NotImplementedError: if `window_size == "sentence"`.
        """
        if top_k is None:
            top_k = self.top_k or 100
        if window_size is None:
            window_size = self.window_size
        if window_stride is None:
            window_stride = self.window_stride

        # A single string yields a single output; any list (even of length one)
        # yields a list, as documented.
        single_input = isinstance(text, str)
        if single_input:
            text = [text]

        if window_size == "sentence":
            raise NotImplementedError("Sentence windowizer not implemented yet")

        # The window manager is needed both to split and to merge windows.
        if self.window_manager is None:
            self.window_manager = WindowManager(self.tokenizer)

        windows = self._create_windows(text, window_size, window_stride)

        # Retrieve candidate labels for every window, batch by batch.
        windows_candidates = []
        for batch in batch_generator(windows, batch_size=retriever_batch_size):
            retriever_out = self.retriever.retrieve([b.text for b in batch], k=top_k)
            windows_candidates.extend(
                [[p.label for p in predictions] for predictions in retriever_out]
            )

        for window, candidates in zip(windows, windows_candidates):
            window.window_candidates = [
                self.candidates_preprocessing_fn(c) for c in candidates
            ]

        windows = self.reader.read(samples=windows, max_batch_size=reader_batch_size)
        windows = self.window_manager.merge_windows(windows)

        # Order by document so output[i] corresponds to text[i] regardless of the
        # order in which `merge_windows` returns its windows (sort is stable).
        merged_windows = sorted(windows, key=lambda w: w.doc_id)

        output = []
        for window in merged_windows:
            document = text[window.doc_id]
            spans = [
                EntitySpan(start=ss, end=se, label=sl, text=document[ss:se])
                for ss, se, sl in window.predicted_window_labels_chars
            ]
            sample_output = RelikOutput(
                text=document,
                labels=sorted(spans, key=lambda x: x.start),
            )
            if return_also_windows:
                sample_output.windows = [
                    w for w in merged_windows if w.doc_id == window.doc_id
                ]
            output.append(sample_output)

        if single_input and len(output) == 1:
            return output[0]
        return output

    def _create_windows(
        self,
        text: list,
        window_size: Optional[int],
        window_stride: Optional[int],
    ) -> list:
        """Split every document in `text` into windows tagged with its `doc_id`."""
        windows = []
        for doc_id, document in enumerate(text):
            if window_size is None:
                # "Whole text" mode: assuming every token spans at least one
                # character, a window of len(document) tokens (at least 1) is
                # expected to cover the document in one piece.
                doc_window_size = max(len(document), 1)
            else:
                doc_window_size = window_size
            # No stride means consecutive, non-overlapping windows.
            doc_stride = (
                window_stride if window_stride is not None else doc_window_size
            )
            windows.extend(
                self.window_manager.create_windows(
                    document,
                    window_size=doc_window_size,
                    stride=doc_stride,
                    doc_id=doc_id,
                )
            )
        return windows

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_dir: Union[str, os.PathLike],
        config_kwargs: Optional[Dict] = None,
        config_file_name: str = CONFIG_NAME,
        *args,
        **kwargs,
    ) -> "Relik":
        """
        Instantiate a `Relik` object from a saved configuration file.

        Args:
            model_name_or_dir (`str` or `os.PathLike`):
                Model name or directory, resolved to a local directory with
                `from_cache`.
            config_kwargs (`Optional[Dict]`, `optional`):
                Values merged on top of the loaded configuration. Defaults to `None`.
            config_file_name (`str`, `optional`, defaults to `CONFIG_NAME`):
                Name of the configuration file inside the model directory.
            *args, **kwargs:
                Forwarded to `hydra.utils.instantiate`, except `cache_dir` and
                `force_download`, which are consumed for `from_cache`.

        Returns:
            `Relik`: the object built by `hydra.utils.instantiate` from the config.

        Raises:
            FileNotFoundError: if the configuration file does not exist.
        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)

        model_dir = from_cache(
            model_name_or_dir,
            filenames=[config_file_name],
            cache_dir=cache_dir,
            force_download=force_download,
        )

        config_path = model_dir / config_file_name
        if not config_path.exists():
            raise FileNotFoundError(
                f"Model configuration file not found at {config_path}."
            )

        config = OmegaConf.load(config_path)
        if config_kwargs is not None:
            # Caller-provided values override the stored configuration.
            config = OmegaConf.merge(config, OmegaConf.create(config_kwargs))

        pprint(OmegaConf.to_container(config), console=console_logger, expand_all=True)

        relik = hydra.utils.instantiate(config, *args, **kwargs)
        return relik

    def save_pretrained(
        self,
        output_dir: Union[str, os.PathLike],
        config: Optional[Dict[str, Any]] = None,
        config_file_name: Optional[str] = None,
        save_weights: bool = False,
        push_to_hub: bool = False,
        model_id: Optional[str] = None,
        organization: Optional[str] = None,
        repo_name: Optional[str] = None,
        **kwargs,
    ):
        """
        Save the configuration of Relik to the specified directory as a YAML file.

        Args:
            output_dir (`str`):
                The directory to save the configuration file to.
            config (`Optional[Dict[str, Any]]`, `optional`):
                The configuration to save. If `None`, the current configuration will be
                saved. Defaults to `None`.
            config_file_name (`Optional[str]`, `optional`):
                The name of the configuration file. Defaults to `CONFIG_NAME`.
            save_weights (`bool`, `optional`):
                Whether to save the weights of the model. Defaults to `False`.
            push_to_hub (`bool`, `optional`):
                Whether to push the saved model to the hub. Defaults to `False`.
            model_id (`Optional[str]`, `optional`):
                The id of the model to push to the hub. If `None`, the name of the
                directory will be used. Defaults to `None`.
            organization (`Optional[str]`, `optional`):
                The organization to push the model to. Defaults to `None`.
            repo_name (`Optional[str]`, `optional`):
                The name of the repository to push the model to. Defaults to `None`.
            **kwargs:
                Additional keyword arguments forwarded to the retriever's and the
                reader's `save_pretrained` when `save_weights` is `True`.
        """
        if config is None:
            # Build a config that `from_pretrained` can feed to hydra.
            config = {
                "_target_": f"{self.__class__.__module__}.{self.__class__.__name__}"
            }
            if self.retriever is not None:
                if self.retriever.question_encoder is not None:
                    config[
                        "question_encoder"
                    ] = self.retriever.question_encoder.name_or_path
                if self.retriever.passage_encoder is not None:
                    config[
                        "passage_encoder"
                    ] = self.retriever.passage_encoder.name_or_path
                if self.retriever.document_index is not None:
                    config["document_index"] = self.retriever.document_index.name_or_dir
            if self.reader is not None:
                config["reader"] = self.reader.model_path

            config["retriever_kwargs"] = self._serializable_retriever_kwargs()
            config["reader_kwargs"] = self.reader_kwargs
            config[
                "candidates_preprocessing_fn"
            ] = self._candidates_preprocessing_fn_path()

            config["top_k"] = self.top_k
            config["window_size"] = self.window_size
            config["window_stride"] = self.window_stride

        config_file_name = config_file_name or CONFIG_NAME

        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f"Saving relik config to {output_dir / config_file_name}")
        pprint(config, console=console_logger, expand_all=True)
        OmegaConf.save(config, output_dir / config_file_name)

        if save_weights:
            model_id = model_id or output_dir.name
            retriever_model_id = model_id + "-retriever"

            logger.info(f"Saving retriever to {output_dir / retriever_model_id}")
            self.retriever.save_pretrained(
                output_dir / retriever_model_id,
                question_encoder_name=retriever_model_id + "-question-encoder",
                passage_encoder_name=retriever_model_id + "-passage-encoder",
                document_index_name=retriever_model_id + "-index",
                push_to_hub=push_to_hub,
                organization=organization,
                repo_name=repo_name,
                **kwargs,
            )
            reader_model_id = model_id + "-reader"
            logger.info(f"Saving reader to {output_dir / reader_model_id}")
            self.reader.save_pretrained(
                output_dir / reader_model_id,
                push_to_hub=push_to_hub,
                organization=organization,
                repo_name=repo_name,
                **kwargs,
            )

        if push_to_hub:
            logger.info("Pushing to hub")
            model_id = model_id or output_dir.name
            upload(output_dir, model_id, organization=organization, repo_name=repo_name)

    def _serializable_retriever_kwargs(self) -> Optional[Dict[str, Any]]:
        """Return `retriever_kwargs` without the keys saved at the config top level."""
        if self.retriever_kwargs is None:
            return None
        return {
            key: value
            for key, value in self.retriever_kwargs.items()
            if key not in self._TOP_LEVEL_RETRIEVER_KEYS
        }

    def _candidates_preprocessing_fn_path(self) -> Optional[str]:
        """Return an importable dotted path for the preprocessing fn, or `None`."""
        fn = self.candidates_preprocessing_fn
        name = getattr(fn, "__name__", None)
        qualname = getattr(fn, "__qualname__", name)
        module = getattr(fn, "__module__", None)
        if module is None or not name or "<" in (qualname or "<"):
            # Lambdas (including the default identity) and nested functions
            # cannot be re-imported by dotted path; `None` reloads as identity.
            logger.warning(
                f"`candidates_preprocessing_fn` ({qualname}) is not importable by "
                f"dotted path; saving `None`, which loads as the identity function."
            )
            return None
        return f"{module}.{name}"
|
|
|
|
|
def main(models_dir: str = "/root/relik-spaces/models"):
    """
    Run a small end-to-end Relik demo on a hard-coded news paragraph.

    Args:
        models_dir (`str`, `optional`):
            Directory that contains the pretrained retriever and reader
            checkpoints. Defaults to the original demo location.
    """
    # The stdlib pprint (not rich's) is used to print the resulting output.
    from pprint import pprint

    models_root = Path(models_dir)
    retriever_dir = models_root / "relik-retriever-small-aida-blink-pretrain-omniencoder"

    document_index = FaissDocumentIndex.from_pretrained(
        str(retriever_dir / "document_index"),
        config_kwargs={
            "_target_": "relik.retriever.indexers.faiss.FaissDocumentIndex",
            "index_type": "IVFx,Flat",
        },
    )

    relik = Relik(
        question_encoder=str(retriever_dir / "question_encoder"),
        document_index=document_index,
        reader=str(models_root / "relik-reader-aida-deberta-small"),
        device="cuda",
        precision=16,
        top_k=100,
        window_size=32,
        window_stride=16,
        candidates_preprocessing_fn="relik.inference.preprocessing.wikipedia_title_and_openings_preprocessing",
    )

    input_text = """
    Bernie Ecclestone, the former boss of Formula One, has admitted fraud after failing to declare more than £400m held in a trust in Singapore.
    The 92-year-old billionaire did not disclose the trust to the government in July 2015.
    Appearing at Southwark Crown Court on Thursday, he told the judge "I plead guilty" after having previously pleaded not guilty.
    Ecclestone had been due to go on trial next month.
    """

    preds = relik(input_text)
    pprint(preds)


if __name__ == "__main__":
    main()
|
|