riccorl's picture
first commit
626eca0
raw
history blame
No virus
20 kB
import contextlib
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
import torch
import torch.nn.functional as F
import transformers as tr
from relik.common.log import get_console_logger, get_logger
from relik.retriever.common.model_inputs import ModelInputs
from relik.retriever.data.labels import Labels
from relik.retriever.indexers.base import BaseDocumentIndex
from relik.retriever.indexers.inmemory import InMemoryDocumentIndex
from relik.retriever.pytorch_modules import PRECISION_MAP, RetrievedSample
from relik.retriever.pytorch_modules.hf import GoldenRetrieverModel
# Console logger obtained from the project's logging helpers; it is not
# referenced again in the visible part of this module.
console_logger = get_console_logger()
# Module-level logger (INFO level), used to report progress when saving the retriever.
logger = get_logger(__name__, level=logging.INFO)
@dataclass
class GoldenRetrieverOutput(tr.file_utils.ModelOutput):
    """
    Output container returned by the retriever's forward pass.

    Every field is optional and defaults to `None`; the forward pass fills
    in only the entries it actually computed.
    """

    # question/passage similarity scores
    logits: Optional[torch.FloatTensor] = None
    # value of the configured loss, present only when a loss was requested
    loss: Optional[torch.FloatTensor] = None
    # encodings of the questions, present only when encodings were requested
    question_encodings: Optional[torch.FloatTensor] = None
    # encodings of the passages, present only when encodings were requested
    passages_encodings: Optional[torch.FloatTensor] = None
class GoldenRetriever(torch.nn.Module):
    """
    Dual-encoder retriever that scores questions against passages.

    The module bundles a question encoder, a passage encoder (which may be the
    very same model object as the question encoder), an optional loss module
    and a document index that backs `index` and `retrieve`.
    """

    def __init__(
        self,
        question_encoder: Union[str, tr.PreTrainedModel],
        loss_type: Optional[torch.nn.Module] = None,
        passage_encoder: Optional[Union[str, tr.PreTrainedModel]] = None,
        document_index: Optional[Union[str, BaseDocumentIndex]] = None,
        question_tokenizer: Optional[Union[str, tr.PreTrainedTokenizer]] = None,
        passage_tokenizer: Optional[Union[str, tr.PreTrainedTokenizer]] = None,
        device: Optional[Union[str, torch.device]] = None,
        precision: Optional[Union[str, int]] = None,
        index_precision: Optional[Union[str, int]] = 32,
        index_device: Optional[Union[str, torch.device]] = "cpu",
        *args,
        **kwargs,
    ):
        """
        Build the retriever.

        Args:
            question_encoder (`Union[str, tr.PreTrainedModel]`):
                The question encoder, or a name/path loaded with
                `GoldenRetrieverModel.from_pretrained`.
            loss_type (`Optional[torch.nn.Module]`):
                The loss module used by `forward` when a loss is requested.
            passage_encoder (`Optional[Union[str, tr.PreTrainedModel]]`):
                The passage encoder, or a name/path loaded with
                `GoldenRetrieverModel.from_pretrained`. When `None`, the question
                encoder object is reused for the passages.
            document_index (`Optional[Union[str, BaseDocumentIndex]]`):
                The document index, or a name/path loaded with
                `BaseDocumentIndex.from_pretrained`. When `None`, a new
                `InMemoryDocumentIndex` is created.
            question_tokenizer (`Optional[Union[str, tr.PreTrainedTokenizer]]`):
                The tokenizer for the questions. It is resolved lazily by the
                `question_tokenizer` property.
            passage_tokenizer (`Optional[Union[str, tr.PreTrainedTokenizer]]`):
                The tokenizer for the passages. It is resolved lazily by the
                `passage_tokenizer` property.
            device (`Optional[Union[str, torch.device]]`):
                The device to move the module to. Defaults to the CPU.
            precision (`Optional[Union[str, int]]`):
                The default precision used when indexing and retrieving.
            index_precision (`Optional[Union[str, int]]`):
                The precision handed to the document index when it is created or loaded here.
            index_device (`Optional[Union[str, torch.device]]`):
                The device handed to the document index when it is created or loaded here.
            **kwargs:
                Forwarded to `GoldenRetrieverModel.from_pretrained` and to the
                document index constructor/loader whenever those are invoked here.
        """
        super().__init__()

        self.passage_encoder_is_question_encoder = False
        # question encoder model
        if isinstance(question_encoder, str):
            question_encoder = GoldenRetrieverModel.from_pretrained(
                question_encoder, **kwargs
            )
        self.question_encoder = question_encoder

        if passage_encoder is None:
            # no dedicated passage encoder: share the question encoder weights
            passage_encoder = question_encoder
            # remember the sharing so that `save_pretrained` stores the model only once
            self.passage_encoder_is_question_encoder = True
        if isinstance(passage_encoder, str):
            passage_encoder = GoldenRetrieverModel.from_pretrained(
                passage_encoder, **kwargs
            )
        # passage encoder model
        self.passage_encoder = passage_encoder

        # loss function
        self.loss_type = loss_type

        # document index
        if document_index is None:
            # if no index is provided, create a new one
            document_index = InMemoryDocumentIndex(
                device=index_device, precision=index_precision, **kwargs
            )
        if isinstance(document_index, str):
            document_index = BaseDocumentIndex.from_pretrained(
                document_index, device=index_device, precision=index_precision, **kwargs
            )
        self.document_index = document_index

        # tokenizers are resolved lazily by the corresponding properties
        self._question_tokenizer = question_tokenizer
        self._passage_tokenizer = passage_tokenizer

        # move the model to the device
        self.to(device or torch.device("cpu"))

        # default precision for indexing and retrieval
        self.precision = precision

    def forward(
        self,
        questions: Optional[Dict[str, torch.Tensor]] = None,
        passages: Optional[Dict[str, torch.Tensor]] = None,
        labels: Optional[torch.Tensor] = None,
        question_encodings: Optional[torch.Tensor] = None,
        passages_encodings: Optional[torch.Tensor] = None,
        passages_per_question: Optional[List[int]] = None,
        return_loss: bool = False,
        return_encodings: bool = False,
        *args,
        **kwargs,
    ) -> GoldenRetrieverOutput:
        """
        Forward pass of the model.

        Args:
            questions (`Dict[str, torch.Tensor]`):
                The questions to encode.
            passages (`Dict[str, torch.Tensor]`):
                The passages to encode.
            labels (`torch.Tensor`):
                The labels of the sentences.
            question_encodings (`torch.Tensor`):
                The encodings of the questions. When given, `questions` is not encoded.
            passages_encodings (`torch.Tensor`):
                The encodings of the passages. When given, `passages` is not encoded.
            passages_per_question (`List[int]`):
                The number of passages per question. When given, each question is
                scored only against its own group of passages.
            return_loss (`bool`):
                Whether to compute the loss (requires `labels` and a `loss_type`).
            return_encodings (`bool`):
                Whether to return the encodings.

        Returns:
            `GoldenRetrieverOutput`: The logits and, when requested, the loss and
            the encodings.

        Raises:
            ValueError: If neither the raw inputs nor the encodings are provided for
                the questions or for the passages, or if a loss is requested but no
                `loss_type` was configured.
        """
        if questions is None and question_encodings is None:
            raise ValueError(
                "Either `questions` or `question_encodings` must be provided"
            )
        if passages is None and passages_encodings is None:
            raise ValueError(
                "Either `passages` or `passages_encodings` must be provided"
            )

        if question_encodings is None:
            question_encodings = self.question_encoder(**questions).pooler_output
        if passages_encodings is None:
            passages_encodings = self.passage_encoder(**passages).pooler_output

        # with a BCE-with-logits loss the scores are cosine similarities, so the
        # encodings are L2-normalized before the dot product
        use_cosine = isinstance(self.loss_type, torch.nn.BCEWithLogitsLoss)

        if passages_per_question is not None:
            # group the passages by question and move the embedding axis to dim=1,
            # so that `bmm` multiplies each question with its own passages;
            # `torch.stack` needs every group to have the same number of passages
            concatenated_passages = torch.stack(
                torch.split(passages_encodings, passages_per_question)
            ).transpose(1, 2)
            if use_cosine:
                # after the transpose the embedding axis is dim=1 (it is the axis
                # contracted by `bmm` below): normalize along it, not along the
                # passage axis
                concatenated_passages = F.normalize(concatenated_passages, p=2, dim=1)
                question_encodings = F.normalize(question_encodings, p=2, dim=1)
            logits = torch.bmm(
                question_encodings.unsqueeze(1), concatenated_passages
            ).view(question_encodings.shape[0], -1)
        else:
            if use_cosine:
                question_encodings = F.normalize(question_encodings, p=2, dim=1)
                passages_encodings = F.normalize(passages_encodings, p=2, dim=1)
            # score every question against every passage
            logits = torch.matmul(question_encodings, passages_encodings.T)

        output = dict(logits=logits)

        if return_loss and labels is not None:
            if self.loss_type is None:
                raise ValueError(
                    "If `return_loss` is set to `True`, `loss_type` must be provided"
                )
            if isinstance(self.loss_type, torch.nn.NLLLoss):
                # the NLL loss is fed class indices and log-probabilities; the
                # logits stored in `output` are left untouched
                labels = labels.argmax(dim=1)
                logits = F.log_softmax(logits, dim=1)
                if len(question_encodings.size()) > 1:
                    logits = logits.view(question_encodings.size(0), -1)

            output["loss"] = self.loss_type(logits, labels)

        if return_encodings:
            output["question_encodings"] = question_encodings
            output["passages_encodings"] = passages_encodings

        return GoldenRetrieverOutput(**output)

    @torch.no_grad()
    @torch.inference_mode()
    def index(
        self,
        batch_size: int = 32,
        num_workers: int = 4,
        max_length: Optional[int] = None,
        collate_fn: Optional[Callable] = None,
        force_reindex: bool = False,
        compute_on_cpu: bool = False,
        precision: Optional[Union[str, int]] = None,
    ):
        """
        Index the passages for later retrieval.

        The work is delegated to the `index` method of the document index.

        Args:
            batch_size (`int`):
                The batch size to use for the indexing.
            num_workers (`int`):
                The number of workers to use for the indexing.
            max_length (`Optional[int]`):
                The maximum length of the passages.
            collate_fn (`Callable`):
                The collate function to use for the indexing.
            force_reindex (`bool`):
                Whether to force reindexing even if the passages are already indexed.
            compute_on_cpu (`bool`):
                Whether to move the index to the CPU after the indexing.
            precision (`Optional[Union[str, int]]`):
                The precision to use for the model. Falls back to the precision
                the retriever was created with.

        Returns:
            Whatever the document index returns from its `index` method.

        Raises:
            ValueError: If the retriever has no document index.
        """
        if self.document_index is None:
            raise ValueError(
                "The retriever must be initialized with an indexer to index "
                "the passages within the retriever."
            )
        return self.document_index.index(
            retriever=self,
            batch_size=batch_size,
            num_workers=num_workers,
            max_length=max_length,
            collate_fn=collate_fn,
            encoder_precision=precision or self.precision,
            compute_on_cpu=compute_on_cpu,
            force_reindex=force_reindex,
        )

    @torch.no_grad()
    @torch.inference_mode()
    def retrieve(
        self,
        text: Optional[Union[str, List[str]]] = None,
        text_pair: Optional[Union[str, List[str]]] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        k: Optional[int] = None,
        max_length: Optional[int] = None,
        precision: Optional[Union[str, int]] = None,
    ) -> List[List[RetrievedSample]]:
        """
        Retrieve the passages for the questions.

        Args:
            text (`Optional[Union[str, List[str]]]`):
                The questions to retrieve the passages for.
            text_pair (`Optional[Union[str, List[str]]]`):
                The optional second sequence, handed to the tokenizer as `text_pair`.
            input_ids (`torch.Tensor`):
                The input ids of the questions. Used only when `text` is `None`.
            attention_mask (`torch.Tensor`):
                The attention mask of the questions. Used only when `text` is `None`.
            token_type_ids (`torch.Tensor`):
                The token type ids of the questions. Used only when `text` is `None`.
            k (`int`):
                The number of top passages to retrieve.
            max_length (`Optional[int]`):
                The maximum length of the questions. Defaults to the tokenizer's
                `model_max_length`.
            precision (`Optional[Union[str, int]]`):
                The precision to use for the model. Falls back to the precision
                the retriever was created with.

        Returns:
            `List[List[RetrievedSample]]`: The retrieved passages and their indices.

        Raises:
            ValueError: If the retriever has no document index, or if neither
                `text` nor `input_ids` is provided.
        """
        if self.document_index is None:
            raise ValueError(
                "The indexer must be indexed before it can be used within the retriever."
            )
        if text is None and input_ids is None:
            raise ValueError(
                "Either `text` or `input_ids` must be provided to retrieve the passages."
            )

        if text is not None:
            if isinstance(text, str):
                text = [text]
            if text_pair is not None and isinstance(text_pair, str):
                text_pair = [text_pair]
            tokenizer = self.question_tokenizer
            model_inputs = ModelInputs(
                tokenizer(
                    text,
                    text_pair=text_pair,
                    padding=True,
                    return_tensors="pt",
                    truncation=True,
                    max_length=max_length or tokenizer.model_max_length,
                )
            )
        else:
            model_inputs = ModelInputs(dict(input_ids=input_ids))
            if attention_mask is not None:
                model_inputs["attention_mask"] = attention_mask
            if token_type_ids is not None:
                model_inputs["token_type_ids"] = token_type_ids

        model_inputs.to(self.device)

        # autocast only accepts plain device types such as "cpu" or "cuda",
        # so strip any device index (e.g. "cuda:0" -> "cuda")
        device_type_for_autocast = str(self.device).split(":")[0]
        # autocast is skipped on the CPU; elsewhere use the requested precision and
        # fall back to the retriever default, the same way `index` does
        autocast_pssg_mngr = (
            contextlib.nullcontext()
            if device_type_for_autocast == "cpu"
            else (
                torch.autocast(
                    device_type=device_type_for_autocast,
                    dtype=PRECISION_MAP[precision or self.precision],
                )
            )
        )
        with autocast_pssg_mngr:
            question_encodings = self.question_encoder(**model_inputs).pooler_output

        # TODO: fix if encoder and index are on different device
        return self.document_index.search(question_encodings, k)

    def get_index_from_passage(self, passage: str) -> int:
        """
        Get the index of the passage.

        Args:
            passage (`str`):
                The passage to get the index for.

        Returns:
            `int`: The index of the passage.

        Raises:
            ValueError: If the retriever has no document index.
        """
        if self.document_index is None:
            raise ValueError(
                "The passages must be indexed before they can be retrieved."
            )
        return self.document_index.get_index_from_passage(passage)

    def get_passage_from_index(self, index: int) -> str:
        """
        Get the passage from the index.

        Args:
            index (`int`):
                The index of the passage.

        Returns:
            `str`: The passage.

        Raises:
            ValueError: If the retriever has no document index.
        """
        if self.document_index is None:
            raise ValueError(
                "The passages must be indexed before they can be retrieved."
            )
        return self.document_index.get_passage_from_index(index)

    def get_vector_from_index(self, index: int) -> torch.Tensor:
        """
        Get the passage vector from the index.

        Args:
            index (`int`):
                The index of the passage.

        Returns:
            `torch.Tensor`: The passage vector.

        Raises:
            ValueError: If the retriever has no document index.
        """
        if self.document_index is None:
            raise ValueError(
                "The passages must be indexed before they can be retrieved."
            )
        return self.document_index.get_embeddings_from_index(index)

    def get_vector_from_passage(self, passage: str) -> torch.Tensor:
        """
        Get the passage vector from the passage.

        Args:
            passage (`str`):
                The passage.

        Returns:
            `torch.Tensor`: The passage vector.

        Raises:
            ValueError: If the retriever has no document index.
        """
        if self.document_index is None:
            raise ValueError(
                "The passages must be indexed before they can be retrieved."
            )
        return self.document_index.get_embeddings_from_passage(passage)

    @property
    def passage_embeddings(self) -> torch.Tensor:
        """
        The passage embeddings.
        """
        # NOTE(review): `_passage_embeddings` is never assigned in this class, so
        # this access fails unless the attribute is set from outside - confirm.
        return self._passage_embeddings

    @property
    def passage_index(self) -> Labels:
        """
        The passage index.
        """
        # NOTE(review): `_passage_index` is never assigned in this class, so this
        # access fails unless the attribute is set from outside - confirm.
        return self._passage_index

    @property
    def device(self) -> torch.device:
        """
        The device of the model, i.e. the device of its first parameter.
        """
        return next(self.parameters()).device

    @property
    def question_tokenizer(self) -> tr.PreTrainedTokenizer:
        """
        The question tokenizer.

        A tokenizer given as a string is loaded on first access. When no tokenizer
        was given, it is loaded from the name/path of the question encoder; if the
        passage encoder has the same name/path and no passage tokenizer is set,
        the same tokenizer object is shared with the passages.
        """
        if isinstance(self._question_tokenizer, str):
            # the constructor accepts a tokenizer name/path: resolve it lazily
            self._question_tokenizer = tr.AutoTokenizer.from_pretrained(
                self._question_tokenizer
            )
        if self._question_tokenizer is None:
            question_name = self.question_encoder.config.name_or_path
            self._question_tokenizer = tr.AutoTokenizer.from_pretrained(question_name)
            if (
                self._passage_tokenizer is None
                and question_name == self.passage_encoder.config.name_or_path
            ):
                # same underlying model: one tokenizer serves both sides
                self._passage_tokenizer = self._question_tokenizer
        return self._question_tokenizer

    @property
    def passage_tokenizer(self) -> tr.PreTrainedTokenizer:
        """
        The passage tokenizer.

        A tokenizer given as a string is loaded on first access. When no tokenizer
        was given, the question tokenizer is reused if both encoders have the same
        name/path; otherwise it is loaded from the name/path of the passage encoder.
        """
        if isinstance(self._passage_tokenizer, str):
            # the constructor accepts a tokenizer name/path: resolve it lazily
            self._passage_tokenizer = tr.AutoTokenizer.from_pretrained(
                self._passage_tokenizer
            )
        if self._passage_tokenizer is None:
            passage_name = self.passage_encoder.config.name_or_path
            if self.question_encoder.config.name_or_path == passage_name:
                # same underlying model: reuse (and load if needed) the question tokenizer
                self._passage_tokenizer = self.question_tokenizer
            else:
                self._passage_tokenizer = tr.AutoTokenizer.from_pretrained(
                    passage_name
                )
        return self._passage_tokenizer

    def save_pretrained(
        self,
        output_dir: Union[str, os.PathLike],
        question_encoder_name: Optional[str] = None,
        passage_encoder_name: Optional[str] = None,
        document_index_name: Optional[str] = None,
        push_to_hub: bool = False,
        **kwargs,
    ):
        """
        Save the retriever to a directory.

        The question encoder and its tokenizer are always saved. The passage encoder
        and its tokenizer are saved only when the passage encoder is not the question
        encoder itself. The document index is saved when present.

        Args:
            output_dir (`str`):
                The directory to save the retriever to. It is created if missing.
            question_encoder_name (`Optional[str]`):
                The name of the question encoder sub-directory.
                Defaults to `question_encoder`.
            passage_encoder_name (`Optional[str]`):
                The name of the passage encoder sub-directory.
                Defaults to `passage_encoder`.
            document_index_name (`Optional[str]`):
                The name of the document index sub-directory.
                Defaults to `document_index`.
            push_to_hub (`bool`):
                Whether to push the model to the hub.
            **kwargs:
                Forwarded to every `save_pretrained` call made here.
        """
        # create the output directory
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        logger.info(f"Saving retriever to {output_dir}")

        question_encoder_name = question_encoder_name or "question_encoder"
        passage_encoder_name = passage_encoder_name or "passage_encoder"
        document_index_name = document_index_name or "document_index"

        logger.info(
            f"Saving question encoder state to {output_dir / question_encoder_name}"
        )
        self.question_encoder.register_for_auto_class()
        self.question_encoder.save_pretrained(
            output_dir / question_encoder_name, push_to_hub=push_to_hub, **kwargs
        )
        self.question_tokenizer.save_pretrained(
            output_dir / question_encoder_name, push_to_hub=push_to_hub, **kwargs
        )
        if not self.passage_encoder_is_question_encoder:
            logger.info(
                f"Saving passage encoder state to {output_dir / passage_encoder_name}"
            )
            self.passage_encoder.register_for_auto_class()
            self.passage_encoder.save_pretrained(
                output_dir / passage_encoder_name, push_to_hub=push_to_hub, **kwargs
            )
            self.passage_tokenizer.save_pretrained(
                output_dir / passage_encoder_name, push_to_hub=push_to_hub, **kwargs
            )

        if self.document_index is not None:
            # save the document index
            self.document_index.save_pretrained(
                output_dir / document_index_name, push_to_hub=push_to_hub, **kwargs
            )

        logger.info("Saving retriever to disk done.")