|
import logging |
|
import os |
|
import platform |
|
from dataclasses import dataclass |
|
from functools import partial |
|
from pathlib import Path |
|
from typing import Callable, Dict, List, Optional, Union |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
import transformers as tr |
|
from torch.utils.data import DataLoader |
|
from tqdm import tqdm |
|
|
|
from relik.common.log import get_logger |
|
from relik.common.torch_utils import ( |
|
get_autocast_context, |
|
) |
|
from relik.common.utils import is_package_available, to_config |
|
from relik.retriever.common.model_inputs import ModelInputs |
|
from relik.retriever.data.base.datasets import BaseDataset |
|
from relik.retriever.data.labels import Labels |
|
from relik.retriever.indexers.base import BaseDocumentIndex |
|
from relik.retriever.indexers.document import Document |
|
from relik.retriever.indexers.inmemory import InMemoryDocumentIndex |
|
from relik.retriever.pytorch_modules import PRECISION_MAP, RetrievedSample |
|
from relik.retriever.pytorch_modules.hf import GoldenRetrieverModel |
|
|
|
|
|
if is_package_available("onnxruntime"): |
|
from optimum.onnxruntime import ORTModel |
|
|
|
logger = get_logger(__name__, level=logging.INFO) |
|
|
|
|
|
@dataclass |
|
class GoldenRetrieverOutput(tr.file_utils.ModelOutput): |
|
"""Class for model's outputs.""" |
|
|
|
logits: Optional[torch.FloatTensor] = None |
|
loss: Optional[torch.FloatTensor] = None |
|
question_encodings: Optional[torch.FloatTensor] = None |
|
passages_encodings: Optional[torch.FloatTensor] = None |
|
|
|
|
|
class GoldenRetriever(torch.nn.Module): |
|
def __init__( |
|
self, |
|
question_encoder: Union[str, tr.PreTrainedModel], |
|
loss_type: Optional[torch.nn.Module] = None, |
|
passage_encoder: Optional[Union[str, tr.PreTrainedModel]] = None, |
|
document_index: Optional[Union[str, BaseDocumentIndex]] = None, |
|
question_tokenizer: Optional[Union[str, tr.PreTrainedTokenizer]] = None, |
|
passage_tokenizer: Optional[Union[str, tr.PreTrainedTokenizer]] = None, |
|
device: Optional[Union[str, torch.device]] = "cpu", |
|
precision: Optional[Union[str, int]] = None, |
|
index_precision: Optional[Union[str, int]] = None, |
|
index_device: Optional[Union[str, torch.device]] = None, |
|
*args, |
|
**kwargs, |
|
): |
|
super().__init__() |
|
|
|
self.passage_encoder_is_question_encoder = False |
|
|
|
if isinstance(question_encoder, str): |
|
question_encoder = GoldenRetrieverModel.from_pretrained( |
|
question_encoder, **kwargs |
|
) |
|
self.question_encoder = question_encoder |
|
if passage_encoder is None: |
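            # no dedicated passage encoder was given: share the question encoder weights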
|
|
|
|
|
passage_encoder = question_encoder |
|
|
|
self.passage_encoder_is_question_encoder = True |
|
if isinstance(passage_encoder, str): |
|
passage_encoder = GoldenRetrieverModel.from_pretrained( |
|
passage_encoder, **kwargs |
|
) |
|
|
|
self.passage_encoder = passage_encoder |
|
|
|
|
|
self.loss_type = loss_type |
|
|
|
|
|
index_device = index_device or device |
|
index_precision = index_precision or precision |
|
if document_index is None: |
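            # no index provided: fall back to an empty in-memory index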
|
|
|
document_index = InMemoryDocumentIndex( |
|
device=index_device, precision=index_precision, **kwargs |
|
) |
|
if isinstance(document_index, str): |
|
document_index = BaseDocumentIndex.from_pretrained( |
|
document_index, device=index_device, precision=index_precision, **kwargs |
|
) |
|
self.document_index = document_index |
|
|
|
|
|
self._question_tokenizer = question_tokenizer |
|
self._passage_tokenizer = passage_tokenizer |
|
|
|
|
|
self.to(device or torch.device("cpu")) |
|
|
|
|
|
self.precision = precision |
|
|
|
def forward( |
|
self, |
|
questions: Optional[Dict[str, torch.Tensor]] = None, |
|
passages: Optional[Dict[str, torch.Tensor]] = None, |
|
labels: Optional[torch.Tensor] = None, |
|
question_encodings: Optional[torch.Tensor] = None, |
|
passages_encodings: Optional[torch.Tensor] = None, |
|
passages_per_question: Optional[List[int]] = None, |
|
return_loss: bool = False, |
|
return_encodings: bool = False, |
|
*args, |
|
**kwargs, |
|
) -> GoldenRetrieverOutput: |
|
""" |
|
Forward pass of the model. |
|
|
|
Args: |
|
questions (`Dict[str, torch.Tensor]`): |
|
The questions to encode. |
|
passages (`Dict[str, torch.Tensor]`): |
|
The passages to encode. |
|
labels (`torch.Tensor`): |
|
The labels of the sentences. |
|
question_encodings (`torch.Tensor`): |
|
The encodings of the questions. |
|
passages_encodings (`torch.Tensor`): |
|
The encodings of the passages. |
|
passages_per_question (`List[int]`): |
|
The number of passages per question. |
|
return_loss (`bool`): |
|
Whether to compute the loss. |
|
return_encodings (`bool`): |
|
Whether to return the encodings. |
|
|
|
Returns: |
|
            `GoldenRetrieverOutput`: The outputs of the model.
|
""" |
|
if questions is None and question_encodings is None: |
|
raise ValueError( |
|
"Either `questions` or `question_encodings` must be provided" |
|
) |
|
if passages is None and passages_encodings is None: |
|
raise ValueError( |
|
"Either `passages` or `passages_encodings` must be provided" |
|
) |
|
|
|
if question_encodings is None: |
|
question_encodings = self.question_encoder(**questions).pooler_output |
|
if passages_encodings is None: |
|
passages_encodings = self.passage_encoder(**passages).pooler_output |
|
|
|
if passages_per_question is not None: |
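            # regroup the flat passage matrix into one block per question and move the
            # hidden dimension to the middle for the batched matmul below;
            # note: torch.stack assumes every question has the same number of passages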
|
|
|
concatenated_passages = torch.stack( |
|
torch.split(passages_encodings, passages_per_question) |
|
).transpose(1, 2) |
|
if isinstance(self.loss_type, torch.nn.BCEWithLogitsLoss): |
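                # L2-normalize both sides so the dot products below are cosine similarities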
|
|
|
concatenated_passages = F.normalize(concatenated_passages, p=2, dim=2) |
|
question_encodings = F.normalize(question_encodings, p=2, dim=1) |
|
logits = torch.bmm( |
|
question_encodings.unsqueeze(1), concatenated_passages |
|
).view(question_encodings.shape[0], -1) |
|
else: |
|
if isinstance(self.loss_type, torch.nn.BCEWithLogitsLoss): |
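                # L2-normalize so the similarity matrix below contains cosine similarities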
|
|
|
question_encodings = F.normalize(question_encodings, p=2, dim=1) |
|
passages_encodings = F.normalize(passages_encodings, p=2, dim=1) |
|
|
|
logits = torch.matmul(question_encodings, passages_encodings.T) |
|
|
|
output = dict(logits=logits) |
|
|
|
if return_loss and labels is not None: |
|
if self.loss_type is None: |
|
raise ValueError( |
|
"If `return_loss` is set to `True`, `loss_type` must be provided" |
|
) |
|
if isinstance(self.loss_type, torch.nn.NLLLoss): |
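                # NLLLoss expects class indices and log-probabilities: the target is
                # the position of the gold passage in each row of `labels`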
|
labels = labels.argmax(dim=1) |
|
logits = F.log_softmax(logits, dim=1) |
|
if len(question_encodings.size()) > 1: |
|
logits = logits.view(question_encodings.size(0), -1) |
|
|
|
output["loss"] = self.loss_type(logits, labels) |
|
|
|
if return_encodings: |
|
output["question_encodings"] = question_encodings |
|
output["passages_encodings"] = passages_encodings |
|
|
|
return GoldenRetrieverOutput(**output) |
|
|
|
@torch.inference_mode() |
|
def index( |
|
self, |
|
batch_size: int = 32, |
|
num_workers: int = 4, |
|
max_length: int | None = None, |
|
collate_fn: Optional[Callable] = None, |
|
force_reindex: bool = False, |
|
compute_on_cpu: bool = False, |
|
precision: Optional[Union[str, int]] = None, |
|
*args, |
|
**kwargs, |
|
): |
|
""" |
|
Index the passages for later retrieval. |
|
|
|
Args: |
|
batch_size (`int`): |
|
The batch size to use for the indexing. |
|
num_workers (`int`): |
|
The number of workers to use for the indexing. |
|
max_length (`int | None`): |
|
The maximum length of the passages. |
|
collate_fn (`Callable`): |
|
The collate function to use for the indexing. |
|
force_reindex (`bool`): |
|
Whether to force reindexing even if the passages are already indexed. |
|
compute_on_cpu (`bool`): |
|
                Whether to compute the passage embeddings on the CPU during indexing.
|
precision (`Optional[Union[str, int]]`): |
|
The precision to use for the model. |
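
        Example:
            # a minimal call sketch; argument values are illustrative only
            retriever.index(batch_size=64, force_reindex=True)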
|
""" |
|
if self.document_index is None: |
|
raise ValueError( |
|
"The retriever must be initialized with an indexer to index " |
|
"the passages within the retriever." |
|
) |
|
|
|
return self.document_index.index( |
|
retriever=self, |
|
batch_size=batch_size, |
|
num_workers=num_workers, |
|
max_length=max_length, |
|
collate_fn=collate_fn, |
|
encoder_precision=precision or self.precision, |
|
compute_on_cpu=compute_on_cpu, |
|
force_reindex=force_reindex, |
|
*args, |
|
**kwargs, |
|
) |
|
|
|
@torch.inference_mode() |
|
def retrieve( |
|
self, |
|
text: Optional[Union[str, List[str]]] = None, |
|
text_pair: Optional[Union[str, List[str]]] = None, |
|
input_ids: Optional[torch.Tensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
token_type_ids: Optional[torch.Tensor] = None, |
|
k: int | None = None, |
|
max_length: int | None = None, |
|
precision: Optional[Union[str, int]] = None, |
|
collate_fn: Optional[Callable] = None, |
|
batch_size: int | None = None, |
|
num_workers: int = 4, |
|
progress_bar: bool = False, |
|
**kwargs, |
|
) -> List[List[RetrievedSample]]: |
|
""" |
|
Retrieve the passages for the questions. |
|
|
|
Args: |
|
text (`Optional[Union[str, List[str]]]`): |
|
The questions to retrieve the passages for. |
|
text_pair (`Optional[Union[str, List[str]]]`): |
|
                The second texts of the pairs, if the questions are text pairs.
|
input_ids (`torch.Tensor`): |
|
The input ids of the questions. |
|
attention_mask (`torch.Tensor`): |
|
The attention mask of the questions. |
|
token_type_ids (`torch.Tensor`): |
|
The token type ids of the questions. |
|
k (`int`): |
|
The number of top passages to retrieve. |
|
max_length (`int | None`): |
|
The maximum length of the questions. |
|
precision (`Optional[Union[str, int]]`): |
|
The precision to use for the model. |
|
collate_fn (`Callable`): |
|
The collate function to use for the retrieval. |
|
batch_size (`int`): |
|
The batch size to use for the retrieval. |
|
num_workers (`int`): |
|
The number of workers to use for the retrieval. |
|
progress_bar (`bool`): |
|
Whether to show a progress bar. |
|
|
|
Returns: |
|
`List[List[RetrievedSample]]`: The retrieved passages and their indices. |
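
        Example:
            # a minimal sketch; assumes `retriever` already holds an indexed
            # collection of passages (see `index`)
            samples = retriever.retrieve("Who wrote the novel?", k=5)
            best = samples[0][0]  # top RetrievedSample for the first question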
|
""" |
|
if self.document_index is None: |
|
raise ValueError( |
|
"The indexer must be indexed before it can be used within the retriever." |
|
) |
|
if text is None and input_ids is None: |
|
raise ValueError( |
|
"Either `text` or `input_ids` must be provided to retrieve the passages." |
|
) |
|
|
|
if text is not None: |
|
if isinstance(text, str): |
|
text = [text] |
|
if text_pair is not None: |
|
if isinstance(text_pair, str): |
|
text_pair = [text_pair] |
|
else: |
|
text_pair = [None] * len(text) |
|
|
|
if collate_fn is None: |
|
tokenizer = self.question_tokenizer |
|
collate_fn = partial( |
|
self.default_collate_fn, max_length=max_length, tokenizer=tokenizer |
|
) |
|
|
|
dataloader = DataLoader( |
|
BaseDataset(name="questions", data=list(zip(text, text_pair))), |
|
batch_size=batch_size, |
|
shuffle=False, |
|
num_workers=num_workers, |
|
pin_memory=False, |
|
collate_fn=collate_fn, |
|
) |
|
else: |
|
model_inputs = ModelInputs(dict(input_ids=input_ids)) |
|
if attention_mask is not None: |
|
model_inputs["attention_mask"] = attention_mask |
|
if token_type_ids is not None: |
|
model_inputs["token_type_ids"] = token_type_ids |
|
|
|
dataloader = [model_inputs] |
|
|
|
if progress_bar: |
|
dataloader = tqdm(dataloader, desc="Retrieving passages") |
|
|
|
retrieved = [] |
|
try: |
|
with get_autocast_context(self.device, precision): |
|
for batch in dataloader: |
|
batch = batch.to(self.device) |
|
question_encodings = self.question_encoder(**batch).pooler_output |
|
retrieved += self.document_index.search(question_encodings, k) |
|
except AttributeError as e: |
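            # DataLoader workers can raise AttributeError on macOS; give a clearer hint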
|
|
|
if "mac" in platform.platform().lower(): |
|
raise ValueError( |
|
"DataLoader with num_workers > 0 is not supported on MacOS. " |
|
"Please set num_workers=0 or try to run on a different machine." |
|
) from e |
|
else: |
|
raise e |
|
|
|
if progress_bar: |
|
dataloader.close() |
|
|
|
return retrieved |
|
|
|
@staticmethod |
|
def default_collate_fn( |
|
x: tuple, tokenizer: tr.PreTrainedTokenizer, max_length: int | None = None |
|
) -> ModelInputs: |
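        """Tokenize a batch of ``(text, text_pair)`` tuples into padded ``ModelInputs``."""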
|
|
|
|
|
_text = [sample[0] for sample in x] |
|
_text_pair = [sample[1] for sample in x] |
|
        _text_pair = None if any(t is None for t in _text_pair) else _text_pair
|
return ModelInputs( |
|
tokenizer( |
|
_text, |
|
text_pair=_text_pair, |
|
padding=True, |
|
return_tensors="pt", |
|
truncation=True, |
|
max_length=max_length or tokenizer.model_max_length, |
|
) |
|
) |
|
|
|
def get_document_from_index(self, index: int) -> Document: |
|
""" |
|
Get the document from its ID. |
|
|
|
Args: |
|
id (`int`): |
|
The ID of the document. |
|
|
|
Returns: |
|
`str`: The document. |
|
""" |
|
if self.document_index is None: |
|
raise ValueError( |
|
"The passages must be indexed before they can be retrieved." |
|
) |
|
return self.document_index.get_document_from_index(index) |
|
|
|
def get_document_from_passage(self, passage: str) -> Document: |
|
""" |
|
Get the document from its text. |
|
|
|
Args: |
|
passage (`str`): |
|
The passage of the document. |
|
|
|
Returns: |
|
            `Document`: The document.
|
""" |
|
if self.document_index is None: |
|
raise ValueError( |
|
"The passages must be indexed before they can be retrieved." |
|
) |
|
return self.document_index.get_document_from_passage(passage) |
|
|
|
def get_index_from_passage(self, passage: str) -> int: |
|
""" |
|
Get the index of the passage. |
|
|
|
Args: |
|
passage (`str`): |
|
The passage to get the index for. |
|
|
|
Returns: |
|
`int`: The index of the passage. |
|
""" |
|
if self.document_index is None: |
|
raise ValueError( |
|
"The passages must be indexed before they can be retrieved." |
|
) |
|
return self.document_index.get_index_from_passage(passage) |
|
|
|
def get_passage_from_index(self, index: int) -> str: |
|
""" |
|
Get the passage from the index. |
|
|
|
Args: |
|
index (`int`): |
|
The index of the passage. |
|
|
|
Returns: |
|
`str`: The passage. |
|
""" |
|
if self.document_index is None: |
|
raise ValueError( |
|
"The passages must be indexed before they can be retrieved." |
|
) |
|
return self.document_index.get_passage_from_index(index) |
|
|
|
def get_vector_from_index(self, index: int) -> torch.Tensor: |
|
""" |
|
Get the passage vector from the index. |
|
|
|
Args: |
|
index (`int`): |
|
The index of the passage. |
|
|
|
Returns: |
|
`torch.Tensor`: The passage vector. |
|
""" |
|
if self.document_index is None: |
|
raise ValueError( |
|
"The passages must be indexed before they can be retrieved." |
|
) |
|
return self.document_index.get_embeddings_from_index(index) |
|
|
|
def get_vector_from_passage(self, passage: str) -> torch.Tensor: |
|
""" |
|
Get the passage vector from the passage. |
|
|
|
Args: |
|
passage (`str`): |
|
The passage. |
|
|
|
Returns: |
|
`torch.Tensor`: The passage vector. |
|
""" |
|
if self.document_index is None: |
|
raise ValueError( |
|
"The passages must be indexed before they can be retrieved." |
|
) |
|
return self.document_index.get_embeddings_from_passage(passage) |
|
|
|
@property |
|
def passage_embeddings(self) -> torch.Tensor: |
|
""" |
|
The passage embeddings. |
|
""" |
|
return self._passage_embeddings |
|
|
|
@property |
|
def passage_index(self) -> Labels: |
|
""" |
|
The passage index. |
|
""" |
|
return self._passage_index |
|
|
|
@property |
|
def device(self) -> torch.device: |
|
""" |
|
The device of the model. |
|
""" |
|
return next(self.parameters()).device |
|
|
|
@property |
|
def question_tokenizer(self) -> tr.PreTrainedTokenizer: |
|
""" |
|
The question tokenizer. |
|
""" |
|
if self._question_tokenizer: |
|
return self._question_tokenizer |
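
        # when both encoders come from the same checkpoint, share a single tokenizer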
|
|
|
if ( |
|
self.question_encoder.config.name_or_path |
|
            == self.passage_encoder.config.name_or_path
|
): |
|
if not self._question_tokenizer: |
|
self._question_tokenizer = tr.AutoTokenizer.from_pretrained( |
|
self.question_encoder.config.name_or_path |
|
) |
|
self._passage_tokenizer = self._question_tokenizer |
|
return self._question_tokenizer |
|
|
|
if not self._question_tokenizer: |
|
self._question_tokenizer = tr.AutoTokenizer.from_pretrained( |
|
self.question_encoder.config.name_or_path |
|
) |
|
return self._question_tokenizer |
|
|
|
@property |
|
def passage_tokenizer(self) -> tr.PreTrainedTokenizer: |
|
""" |
|
The passage tokenizer. |
|
""" |
|
if self._passage_tokenizer: |
|
return self._passage_tokenizer |
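
        # reuse the shared question tokenizer when both encoders come from the same checkpoint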
|
|
|
if ( |
|
self.question_encoder.config.name_or_path |
|
== self.passage_encoder.config.name_or_path |
|
): |
|
if not self._question_tokenizer: |
|
self._question_tokenizer = tr.AutoTokenizer.from_pretrained( |
|
self.question_encoder.config.name_or_path |
|
) |
|
self._passage_tokenizer = self._question_tokenizer |
|
return self._passage_tokenizer |
|
|
|
if not self._passage_tokenizer: |
|
self._passage_tokenizer = tr.AutoTokenizer.from_pretrained( |
|
self.passage_encoder.config.name_or_path |
|
) |
|
return self._passage_tokenizer |
|
|
|
def save_pretrained( |
|
self, |
|
output_dir: Union[str, os.PathLike], |
|
question_encoder_name: str | None = None, |
|
passage_encoder_name: str | None = None, |
|
document_index_name: str | None = None, |
|
push_to_hub: bool = False, |
|
**kwargs, |
|
): |
|
""" |
|
Save the retriever to a directory. |
|
|
|
Args: |
|
output_dir (`str`): |
|
The directory to save the retriever to. |
|
question_encoder_name (`str | None`): |
|
The name of the question encoder. |
|
passage_encoder_name (`str | None`): |
|
The name of the passage encoder. |
|
document_index_name (`str | None`): |
|
The name of the document index. |
|
push_to_hub (`bool`): |
|
Whether to push the model to the hub. |
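
        Example:
            # a minimal sketch; the output path is a placeholder
            retriever.save_pretrained("outputs/my_retriever")
            # writes question_encoder/, passage_encoder/ (if distinct) and
            # document_index/ under outputs/my_retriever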
|
""" |
|
|
|
|
|
output_dir = Path(output_dir) |
|
output_dir.mkdir(parents=True, exist_ok=True) |
|
|
|
logger.info(f"Saving retriever to {output_dir}") |
|
|
|
question_encoder_name = question_encoder_name or "question_encoder" |
|
passage_encoder_name = passage_encoder_name or "passage_encoder" |
|
document_index_name = document_index_name or "document_index" |
|
|
|
logger.info( |
|
f"Saving question encoder state to {output_dir / question_encoder_name}" |
|
) |
|
|
|
self.question_encoder.register_for_auto_class() |
|
self.question_encoder.save_pretrained( |
|
str(output_dir / question_encoder_name), push_to_hub=push_to_hub, **kwargs |
|
) |
|
self.question_tokenizer.save_pretrained( |
|
str(output_dir / question_encoder_name), push_to_hub=push_to_hub, **kwargs |
|
) |
|
if not self.passage_encoder_is_question_encoder: |
|
logger.info( |
|
f"Saving passage encoder state to {output_dir / passage_encoder_name}" |
|
) |
|
|
|
self.passage_encoder.register_for_auto_class() |
|
self.passage_encoder.save_pretrained( |
|
str(output_dir / passage_encoder_name), |
|
push_to_hub=push_to_hub, |
|
**kwargs, |
|
) |
|
self.passage_tokenizer.save_pretrained( |
|
            str(output_dir / passage_encoder_name), push_to_hub=push_to_hub, **kwargs
|
) |
|
|
|
if self.document_index is not None: |
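            # persist the document index (passages and embeddings) next to the encoders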
|
|
|
self.document_index.save_pretrained( |
|
str(output_dir / document_index_name), push_to_hub=push_to_hub, **kwargs |
|
) |
|
|
|
logger.info("Saving retriever to disk done.") |
|
|
|
    def to_config(self, *args, **kwargs):
        config = {
            "_target_": f"{self.__class__.__module__}.{self.__class__.__name__}",
            "question_encoder": self.question_encoder.config.name_or_path,
            "passage_encoder": (
                self.passage_encoder.config.name_or_path
                if not self.passage_encoder_is_question_encoder
                else None
            ),
            "document_index": to_config(self.document_index),
        }
        return config
|
|