# Source: Hugging Face file view — uploaded by CarlosMalaga
# ("Upload 201 files", commit 2f044c1, verified), file size 23.7 kB.
import logging
import os
import platform
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
import torch
import torch.nn.functional as F
import transformers as tr
from torch.utils.data import DataLoader
from tqdm import tqdm
from relik.common.log import get_logger
from relik.common.torch_utils import (
get_autocast_context,
) # , # load_ort_optimized_hf_model
from relik.common.utils import is_package_available, to_config
from relik.retriever.common.model_inputs import ModelInputs
from relik.retriever.data.base.datasets import BaseDataset
from relik.retriever.data.labels import Labels
from relik.retriever.indexers.base import BaseDocumentIndex
from relik.retriever.indexers.document import Document
from relik.retriever.indexers.inmemory import InMemoryDocumentIndex
from relik.retriever.pytorch_modules import PRECISION_MAP, RetrievedSample
from relik.retriever.pytorch_modules.hf import GoldenRetrieverModel
logger = get_logger(__name__, level=logging.INFO)

# Optional ONNX Runtime support. `ORTModel` is provided by the `optimum`
# package, which is installed separately from `onnxruntime`; having
# `onnxruntime` available does not guarantee this import succeeds, so treat it
# as best-effort instead of failing the whole module import.
if is_package_available("onnxruntime"):
    try:
        from optimum.onnxruntime import ORTModel
    except ImportError:
        logger.debug(
            "`onnxruntime` is installed but `optimum.onnxruntime` could not be "
            "imported; ONNX Runtime support is disabled."
        )
@dataclass
class GoldenRetrieverOutput(tr.file_utils.ModelOutput):
    """
    Output container returned by `GoldenRetriever.forward`.

    Every field defaults to `None`; `forward` only fills the entries it computed.
    """

    # similarity scores between questions and passages
    logits: Optional[torch.FloatTensor] = None
    # training loss, set only when the loss was requested and labels were given
    loss: Optional[torch.FloatTensor] = None
    # question embeddings, set only when encodings were requested
    question_encodings: Optional[torch.FloatTensor] = None
    # passage embeddings, set only when encodings were requested
    passages_encodings: Optional[torch.FloatTensor] = None
class GoldenRetriever(torch.nn.Module):
    """
    Dense retriever built from a question encoder, a passage encoder and a
    document index.

    When no passage encoder is given, the question encoder is reused for the
    passages. The document index holds the passages searched by `retrieve`.
    """

    def __init__(
        self,
        question_encoder: Union[str, tr.PreTrainedModel],
        loss_type: Optional[torch.nn.Module] = None,
        passage_encoder: Optional[Union[str, tr.PreTrainedModel]] = None,
        document_index: Optional[Union[str, BaseDocumentIndex]] = None,
        question_tokenizer: Optional[Union[str, tr.PreTrainedTokenizer]] = None,
        passage_tokenizer: Optional[Union[str, tr.PreTrainedTokenizer]] = None,
        device: Optional[Union[str, torch.device]] = "cpu",
        precision: Optional[Union[str, int]] = None,
        index_precision: Optional[Union[str, int]] = None,
        index_device: Optional[Union[str, torch.device]] = None,
        *args,
        **kwargs,
    ):
        """
        Args:
            question_encoder (`Union[str, tr.PreTrainedModel]`):
                The question encoder, or a name/path loaded with
                `GoldenRetrieverModel.from_pretrained`.
            loss_type (`Optional[torch.nn.Module]`):
                Loss module used by `forward` when `return_loss=True`.
            passage_encoder (`Optional[Union[str, tr.PreTrainedModel]]`):
                The passage encoder, or a name/path. If `None`, the question
                encoder is shared for passages.
            document_index (`Optional[Union[str, BaseDocumentIndex]]`):
                The document index, or a name/path loaded with
                `BaseDocumentIndex.from_pretrained`. If `None`, an empty
                `InMemoryDocumentIndex` is created.
            question_tokenizer (`Optional[Union[str, tr.PreTrainedTokenizer]]`):
                Tokenizer (or name/path) for questions; loaded lazily if missing.
            passage_tokenizer (`Optional[Union[str, tr.PreTrainedTokenizer]]`):
                Tokenizer (or name/path) for passages; loaded lazily if missing.
            device (`Optional[Union[str, torch.device]]`):
                Device the model is moved to (CPU if falsy).
            precision (`Optional[Union[str, int]]`):
                Precision stored on the retriever; used by default by `index`.
            index_precision (`Optional[Union[str, int]]`):
                Precision of the document index; defaults to `precision`.
            index_device (`Optional[Union[str, torch.device]]`):
                Device of the document index; defaults to `device`.
            **kwargs:
                Forwarded to the encoder and document index constructors.
        """
        super().__init__()

        self.passage_encoder_is_question_encoder = False
        # question encoder model
        if isinstance(question_encoder, str):
            question_encoder = GoldenRetrieverModel.from_pretrained(
                question_encoder, **kwargs
            )
        self.question_encoder = question_encoder
        if passage_encoder is None:
            # if no passage encoder is provided,
            # share the weights of the question encoder
            passage_encoder = question_encoder
            # keep track of the fact that the passage encoder is the same as the question encoder
            self.passage_encoder_is_question_encoder = True
        if isinstance(passage_encoder, str):
            passage_encoder = GoldenRetrieverModel.from_pretrained(
                passage_encoder, **kwargs
            )
        # passage encoder model
        self.passage_encoder = passage_encoder

        # loss function
        self.loss_type = loss_type

        # indexer stuff
        index_device = index_device or device
        index_precision = index_precision or precision
        if document_index is None:
            # if no indexer is provided, create a new one
            document_index = InMemoryDocumentIndex(
                device=index_device, precision=index_precision, **kwargs
            )
        if isinstance(document_index, str):
            document_index = BaseDocumentIndex.from_pretrained(
                document_index, device=index_device, precision=index_precision, **kwargs
            )
        self.document_index = document_index

        # lazy load the tokenizer for inference
        self._question_tokenizer = question_tokenizer
        self._passage_tokenizer = passage_tokenizer

        # move the model to the device
        self.to(device or torch.device("cpu"))

        # set the precision
        self.precision = precision

    def forward(
        self,
        questions: Optional[Dict[str, torch.Tensor]] = None,
        passages: Optional[Dict[str, torch.Tensor]] = None,
        labels: Optional[torch.Tensor] = None,
        question_encodings: Optional[torch.Tensor] = None,
        passages_encodings: Optional[torch.Tensor] = None,
        passages_per_question: Optional[List[int]] = None,
        return_loss: bool = False,
        return_encodings: bool = False,
        *args,
        **kwargs,
    ) -> GoldenRetrieverOutput:
        """
        Forward pass of the model.

        Args:
            questions (`Dict[str, torch.Tensor]`):
                The tokenized questions to encode (ignored if
                `question_encodings` is given).
            passages (`Dict[str, torch.Tensor]`):
                The tokenized passages to encode (ignored if
                `passages_encodings` is given).
            labels (`torch.Tensor`):
                Target scores for each question/passage pair. With `NLLLoss`
                they are reduced with `argmax(dim=1)` to one index per question.
            question_encodings (`torch.Tensor`):
                Pre-computed 2-D question encodings.
            passages_encodings (`torch.Tensor`):
                Pre-computed 2-D passage encodings.
            passages_per_question (`List[int]`):
                Number of consecutive passages belonging to each question. If
                given, each question is scored only against its own passages;
                every entry must be equal (`torch.stack` requirement).
            return_loss (`bool`):
                Whether to compute the loss (requires `labels` and `loss_type`).
            return_encodings (`bool`):
                Whether to return the encodings.

        Returns:
            `GoldenRetrieverOutput`: the logits, plus the loss and/or the
            encodings when requested.

        Raises:
            ValueError: if neither raw inputs nor encodings are given for the
                questions or the passages, or if the loss is requested without
                a `loss_type`.
        """
        if questions is None and question_encodings is None:
            raise ValueError(
                "Either `questions` or `question_encodings` must be provided"
            )
        if passages is None and passages_encodings is None:
            raise ValueError(
                "Either `passages` or `passages_encodings` must be provided"
            )

        if question_encodings is None:
            question_encodings = self.question_encoder(**questions).pooler_output
        if passages_encodings is None:
            passages_encodings = self.passage_encoder(**passages).pooler_output

        if passages_per_question is not None:
            # split -> stack gives (num_questions, passages, hidden);
            # transpose -> (num_questions, hidden, passages) for the batched matmul
            concatenated_passages = torch.stack(
                torch.split(passages_encodings, passages_per_question)
            ).transpose(1, 2)
            if isinstance(self.loss_type, torch.nn.BCEWithLogitsLoss):
                # normalize the encodings for cosine similarity: after the
                # transpose the hidden dimension of the passages is dim=1
                concatenated_passages = F.normalize(concatenated_passages, p=2, dim=1)
                question_encodings = F.normalize(question_encodings, p=2, dim=1)
            # (num_questions, 1, hidden) @ (num_questions, hidden, passages)
            logits = torch.bmm(
                question_encodings.unsqueeze(1), concatenated_passages
            ).view(question_encodings.shape[0], -1)
        else:
            if isinstance(self.loss_type, torch.nn.BCEWithLogitsLoss):
                # normalize the encodings for cosine similarity
                question_encodings = F.normalize(question_encodings, p=2, dim=1)
                passages_encodings = F.normalize(passages_encodings, p=2, dim=1)

            # every question scored against every passage
            logits = torch.matmul(question_encodings, passages_encodings.T)

        output = dict(logits=logits)

        if return_loss and labels is not None:
            if self.loss_type is None:
                raise ValueError(
                    "If `return_loss` is set to `True`, `loss_type` must be provided"
                )
            if isinstance(self.loss_type, torch.nn.NLLLoss):
                # NLLLoss expects class indices and log-probabilities
                labels = labels.argmax(dim=1)
                logits = F.log_softmax(logits, dim=1)
                if len(question_encodings.size()) > 1:
                    logits = logits.view(question_encodings.size(0), -1)

            output["loss"] = self.loss_type(logits, labels)

        if return_encodings:
            output["question_encodings"] = question_encodings
            output["passages_encodings"] = passages_encodings

        return GoldenRetrieverOutput(**output)

    @torch.no_grad()
    @torch.inference_mode()
    def index(
        self,
        batch_size: int = 32,
        num_workers: int = 4,
        max_length: int | None = None,
        collate_fn: Optional[Callable] = None,
        force_reindex: bool = False,
        compute_on_cpu: bool = False,
        precision: Optional[Union[str, int]] = None,
        *args,
        **kwargs,
    ):
        """
        Index the passages for later retrieval by delegating to the document index.

        Args:
            batch_size (`int`):
                The batch size to use for the indexing.
            num_workers (`int`):
                The number of workers to use for the indexing.
            max_length (`int | None`):
                The maximum length of the passages.
            collate_fn (`Callable`):
                The collate function to use for the indexing.
            force_reindex (`bool`):
                Whether to force reindexing even if the passages are already indexed.
            compute_on_cpu (`bool`):
                Whether to move the index to the CPU after the indexing.
            precision (`Optional[Union[str, int]]`):
                The precision to use for the encoder; defaults to `self.precision`.

        Returns:
            Whatever `self.document_index.index` returns.

        Raises:
            ValueError: if the retriever has no document index.
        """
        if self.document_index is None:
            raise ValueError(
                "The retriever must be initialized with an indexer to index "
                "the passages within the retriever."
            )
        # TODO: add kwargs
        # NOTE(review): values in `*args` are passed positionally and fill the
        # leading parameters of `document_index.index`, which may clash with
        # the keyword arguments below — confirm against its signature.
        return self.document_index.index(
            retriever=self,
            batch_size=batch_size,
            num_workers=num_workers,
            max_length=max_length,
            collate_fn=collate_fn,
            encoder_precision=precision or self.precision,
            compute_on_cpu=compute_on_cpu,
            force_reindex=force_reindex,
            *args,
            **kwargs,
        )

    @torch.no_grad()
    @torch.inference_mode()
    def retrieve(
        self,
        text: Optional[Union[str, List[str]]] = None,
        text_pair: Optional[Union[str, List[str]]] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        k: int | None = None,
        max_length: int | None = None,
        precision: Optional[Union[str, int]] = None,
        collate_fn: Optional[Callable] = None,
        batch_size: int | None = None,
        num_workers: int = 4,
        progress_bar: bool = False,
        **kwargs,
    ) -> List[List[RetrievedSample]]:
        """
        Retrieve the passages for the questions.

        Args:
            text (`Optional[Union[str, List[str]]]`):
                The questions to retrieve the passages for.
            text_pair (`Optional[Union[str, List[str]]]`):
                Optional second segment paired with each question.
            input_ids (`torch.Tensor`):
                The input ids of the questions (used when `text` is `None`).
            attention_mask (`torch.Tensor`):
                The attention mask of the questions.
            token_type_ids (`torch.Tensor`):
                The token type ids of the questions.
            k (`int`):
                The number of top passages to retrieve.
            max_length (`int | None`):
                The maximum length of the questions.
            precision (`Optional[Union[str, int]]`):
                The precision to use for the model.
            collate_fn (`Callable`):
                The collate function to use for the retrieval.
            batch_size (`int | None`):
                The batch size to use for the retrieval. If `None`, all the
                questions in `text` are encoded in a single batch.
            num_workers (`int`):
                The number of workers to use for the retrieval.
            progress_bar (`bool`):
                Whether to show a progress bar.

        Returns:
            `List[List[RetrievedSample]]`: The retrieved passages and their indices.

        Raises:
            ValueError: if there is no document index, if neither `text` nor
                `input_ids` is given, or (on macOS) if the DataLoader workers
                fail with an `AttributeError`.
        """
        if self.document_index is None:
            raise ValueError(
                "The indexer must be indexed before it can be used within the retriever."
            )
        if text is None and input_ids is None:
            raise ValueError(
                "Either `text` or `input_ids` must be provided to retrieve the passages."
            )

        if text is not None:
            if isinstance(text, str):
                text = [text]
            if text_pair is not None:
                if isinstance(text_pair, str):
                    text_pair = [text_pair]
            else:
                text_pair = [None] * len(text)

            if collate_fn is None:
                tokenizer = self.question_tokenizer
                collate_fn = partial(
                    self.default_collate_fn, max_length=max_length, tokenizer=tokenizer
                )
            # `DataLoader(batch_size=None)` disables automatic batching and would
            # hand single `(text, text_pair)` tuples to `collate_fn`, which
            # expects a list of them. Default to one batch holding every
            # question (at least 1, since `DataLoader` rejects `batch_size=0`).
            if batch_size is None:
                batch_size = max(len(text), 1)
            dataloader = DataLoader(
                BaseDataset(name="questions", data=list(zip(text, text_pair))),
                batch_size=batch_size,
                shuffle=False,
                num_workers=num_workers,
                pin_memory=False,
                collate_fn=collate_fn,
            )
        else:
            # already tokenized input: a single pre-built batch
            model_inputs = ModelInputs(dict(input_ids=input_ids))
            if attention_mask is not None:
                model_inputs["attention_mask"] = attention_mask
            if token_type_ids is not None:
                model_inputs["token_type_ids"] = token_type_ids

            dataloader = [model_inputs]

        if progress_bar:
            dataloader = tqdm(dataloader, desc="Retrieving passages")

        retrieved = []
        try:
            with get_autocast_context(self.device, precision):
                for batch in dataloader:
                    batch = batch.to(self.device)
                    question_encodings = self.question_encoder(**batch).pooler_output
                    retrieved += self.document_index.search(question_encodings, k)
        except AttributeError as e:
            # apparently num_workers > 0 gives some issue on MacOS as of now
            if "mac" in platform.platform().lower():
                raise ValueError(
                    "DataLoader with num_workers > 0 is not supported on MacOS. "
                    "Please set num_workers=0 or try to run on a different machine."
                ) from e
            raise
        finally:
            # close the bar even when the retrieval fails
            if progress_bar:
                dataloader.close()

        return retrieved

    @staticmethod
    def default_collate_fn(
        x: tuple, tokenizer: tr.PreTrainedTokenizer, max_length: int | None = None
    ) -> ModelInputs:
        """
        Tokenize a batch of `(text, text_pair)` samples into `ModelInputs`.

        Args:
            x (`tuple`):
                Sequence of `(text, text_pair)` samples.
            tokenizer (`tr.PreTrainedTokenizer`):
                Tokenizer applied to the batch.
            max_length (`int | None`):
                Truncation length; defaults to `tokenizer.model_max_length`.

        Returns:
            `ModelInputs`: the padded, truncated PyTorch tensors.
        """
        # get text and text pair
        # TODO: check if only retriever is used
        _text = [sample[0] for sample in x]
        _text_pair = [sample[1] for sample in x]
        # the tokenizer takes either a pair for every sample or no pairs at all
        if any(t is None for t in _text_pair):
            _text_pair = None
        return ModelInputs(
            tokenizer(
                _text,
                text_pair=_text_pair,
                padding=True,
                return_tensors="pt",
                truncation=True,
                max_length=max_length or tokenizer.model_max_length,
            )
        )

    def _check_document_index(self) -> None:
        """Raise a `ValueError` if the retriever has no document index."""
        if self.document_index is None:
            raise ValueError(
                "The passages must be indexed before they can be retrieved."
            )

    def get_document_from_index(self, index: int) -> Document:
        """
        Get the document from its index.

        Args:
            index (`int`):
                The index of the document.

        Returns:
            `Document`: The document.
        """
        self._check_document_index()
        return self.document_index.get_document_from_index(index)

    def get_document_from_passage(self, passage: str) -> Document:
        """
        Get the document from its text.

        Args:
            passage (`str`):
                The passage of the document.

        Returns:
            `Document`: The document.
        """
        self._check_document_index()
        return self.document_index.get_document_from_passage(passage)

    def get_index_from_passage(self, passage: str) -> int:
        """
        Get the index of the passage.

        Args:
            passage (`str`):
                The passage to get the index for.

        Returns:
            `int`: The index of the passage.
        """
        self._check_document_index()
        return self.document_index.get_index_from_passage(passage)

    def get_passage_from_index(self, index: int) -> str:
        """
        Get the passage from the index.

        Args:
            index (`int`):
                The index of the passage.

        Returns:
            `str`: The passage.
        """
        self._check_document_index()
        return self.document_index.get_passage_from_index(index)

    def get_vector_from_index(self, index: int) -> torch.Tensor:
        """
        Get the passage vector from the index.

        Args:
            index (`int`):
                The index of the passage.

        Returns:
            `torch.Tensor`: The passage vector.
        """
        self._check_document_index()
        return self.document_index.get_embeddings_from_index(index)

    def get_vector_from_passage(self, passage: str) -> torch.Tensor:
        """
        Get the passage vector from the passage.

        Args:
            passage (`str`):
                The passage.

        Returns:
            `torch.Tensor`: The passage vector.
        """
        self._check_document_index()
        return self.document_index.get_embeddings_from_passage(passage)

    @property
    def passage_embeddings(self) -> torch.Tensor:
        """
        The passage embeddings.

        NOTE(review): `_passage_embeddings` is never assigned in this class, so
        unless other code sets it this raises `AttributeError`; the embeddings
        appear to live in `self.document_index` — TODO confirm and delegate.
        """
        return self._passage_embeddings

    @property
    def passage_index(self) -> Labels:
        """
        The passage index.

        NOTE(review): `_passage_index` is never assigned in this class, so unless
        other code sets it this raises `AttributeError` — TODO confirm.
        """
        return self._passage_index

    @property
    def device(self) -> torch.device:
        """
        The device of the model (taken from its first parameter).
        """
        return next(self.parameters()).device

    def _encoders_share_checkpoint(self) -> bool:
        """Whether both encoders were loaded from the same name or path."""
        return (
            self.question_encoder.config.name_or_path
            == self.passage_encoder.config.name_or_path
        )

    @property
    def question_tokenizer(self) -> tr.PreTrainedTokenizer:
        """
        The question tokenizer, loaded lazily.

        A string passed at construction is loaded with `AutoTokenizer`. If none
        was given, the tokenizer is loaded from the question encoder's
        checkpoint. When both encoders share a checkpoint, it is also reused as
        the passage tokenizer, unless one was already set.
        """
        if isinstance(self._question_tokenizer, str):
            self._question_tokenizer = tr.AutoTokenizer.from_pretrained(
                self._question_tokenizer
            )
        if self._question_tokenizer is not None:
            return self._question_tokenizer

        self._question_tokenizer = tr.AutoTokenizer.from_pretrained(
            self.question_encoder.config.name_or_path
        )
        if self._passage_tokenizer is None and self._encoders_share_checkpoint():
            self._passage_tokenizer = self._question_tokenizer
        return self._question_tokenizer

    @property
    def passage_tokenizer(self) -> tr.PreTrainedTokenizer:
        """
        The passage tokenizer, loaded lazily.

        A string passed at construction is loaded with `AutoTokenizer`. If none
        was given and both encoders share a checkpoint, the question tokenizer
        is reused; otherwise it is loaded from the passage encoder's checkpoint.
        """
        if isinstance(self._passage_tokenizer, str):
            self._passage_tokenizer = tr.AutoTokenizer.from_pretrained(
                self._passage_tokenizer
            )
        if self._passage_tokenizer is not None:
            return self._passage_tokenizer

        if self._encoders_share_checkpoint():
            self._passage_tokenizer = self.question_tokenizer
            return self._passage_tokenizer

        self._passage_tokenizer = tr.AutoTokenizer.from_pretrained(
            self.passage_encoder.config.name_or_path
        )
        return self._passage_tokenizer

    def save_pretrained(
        self,
        output_dir: Union[str, os.PathLike],
        question_encoder_name: str | None = None,
        passage_encoder_name: str | None = None,
        document_index_name: str | None = None,
        push_to_hub: bool = False,
        **kwargs,
    ):
        """
        Save the retriever to a directory.

        The question encoder and its tokenizer are always saved. The passage
        encoder and its tokenizer are saved only when they are not shared with
        the question encoder. The document index is saved when present.

        Args:
            output_dir (`str`):
                The directory to save the retriever to (created if missing).
            question_encoder_name (`str | None`):
                Sub-directory name of the question encoder.
            passage_encoder_name (`str | None`):
                Sub-directory name of the passage encoder.
            document_index_name (`str | None`):
                Sub-directory name of the document index.
            push_to_hub (`bool`):
                Whether to push the model to the hub.
            **kwargs:
                Forwarded to every `save_pretrained` call.
        """
        # create the output directory
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f"Saving retriever to {output_dir}")

        question_encoder_name = question_encoder_name or "question_encoder"
        passage_encoder_name = passage_encoder_name or "passage_encoder"
        document_index_name = document_index_name or "document_index"

        logger.info(
            f"Saving question encoder state to {output_dir / question_encoder_name}"
        )
        # self.question_encoder.config._name_or_path = question_encoder_name
        self.question_encoder.register_for_auto_class()
        self.question_encoder.save_pretrained(
            str(output_dir / question_encoder_name), push_to_hub=push_to_hub, **kwargs
        )
        self.question_tokenizer.save_pretrained(
            str(output_dir / question_encoder_name), push_to_hub=push_to_hub, **kwargs
        )
        if not self.passage_encoder_is_question_encoder:
            logger.info(
                f"Saving passage encoder state to {output_dir / passage_encoder_name}"
            )
            # self.passage_encoder.config._name_or_path = passage_encoder_name
            self.passage_encoder.register_for_auto_class()
            self.passage_encoder.save_pretrained(
                str(output_dir / passage_encoder_name),
                push_to_hub=push_to_hub,
                **kwargs,
            )
            self.passage_tokenizer.save_pretrained(
                str(output_dir / passage_encoder_name),
                push_to_hub=push_to_hub,
                **kwargs,
            )

        if self.document_index is not None:
            # save the indexer
            self.document_index.save_pretrained(
                str(output_dir / document_index_name), push_to_hub=push_to_hub, **kwargs
            )

        logger.info("Saving retriever to disk done.")

    def to_config(self, *args, **kwargs):
        """
        Build a config dictionary describing this retriever.

        This is an instance method because the config reads the retriever's own
        encoders and document index; `retriever.to_config()` calls are unchanged.

        Returns:
            `dict`: `_target_` (fully-qualified class path), the question
            encoder name/path, the passage encoder name/path (`None` when it is
            shared with the question encoder), and the document index config.
        """
        cls = type(self)
        return {
            "_target_": f"{cls.__module__}.{cls.__name__}",
            "question_encoder": self.question_encoder.config.name_or_path,
            "passage_encoder": (
                None
                if self.passage_encoder_is_question_encoder
                else self.passage_encoder.config.name_or_path
            ),
            # module-level `to_config` helper, not this method
            "document_index": to_config(self.document_index),
        }