riccorl's picture
first commit
history blame
No virus
14.7 kB
import contextlib
import logging
import math
import os
from dataclasses import dataclass
from typing import Callable, List, Optional, Union
import numpy
import torch
from pytorch_modules import RetrievedSample
from torch.utils.data import DataLoader
from tqdm import tqdm
from relik.common.log import get_logger
from relik.common.utils import is_package_available
from relik.retriever.common.model_inputs import ModelInputs
from relik.retriever.data.base.datasets import BaseDataset
from relik.retriever.data.labels import Labels
from relik.retriever.indexers.base import BaseDocumentIndex
from relik.retriever.pytorch_modules import PRECISION_MAP
from relik.retriever.pytorch_modules.model import GoldenRetriever
if is_package_available("faiss"):
import faiss
import faiss.contrib.torch_utils
# Module-level logger for this indexer, created through relik's get_logger
# helper with its level set to logging.INFO.
_LOG_LEVEL = logging.INFO
logger = get_logger(__name__, level=_LOG_LEVEL)
@dataclass
class FaissOutput:
    """Container pairing an index array with a distance array.

    Presumably these are the two outputs of a faiss ``search`` call. Per the
    annotations, each field may be either a ``torch.Tensor`` or a
    ``numpy.ndarray``.

    Without ``@dataclass`` the class has no generated ``__init__``, so it
    could not be built as ``FaissOutput(indices=..., distances=...)``.
    """

    # ids of the retrieved vectors
    indices: Union[torch.Tensor, numpy.ndarray]
    # scores/distances aligned element-wise with `indices`
    distances: Union[torch.Tensor, numpy.ndarray]
# FaissDocumentIndex: a BaseDocumentIndex whose `self.embeddings` attribute
# holds a faiss index object rather than a raw tensor of vectors. The index is
# either built by faiss.index_factory in _build_faiss_index or passed in
# directly.
#
# NOTE(review): this listing looks like a lossy extraction of the upstream
# file. Missing pieces include:
#   - indentation and the `self` parameters,
#   - closing brackets and `else:` lines,
#   - docstring quotes and several call lines.
# The class will not parse as shown, and the branch structure below cannot be
# read reliably. Diff against the upstream source before changing any logic.
class FaissDocumentIndex(BaseDocumentIndex):
# Default file names for persistence. Within this listing they are only
# referenced by the commented-out save/load helpers further down.
DOCUMENTS_FILE_NAME = "documents.json"
EMBEDDINGS_FILE_NAME = "embeddings.pt"
INDEX_FILE_NAME = "index.faiss"
# Create the document index.
#
# Args:
#     documents: the documents to index, as plain strings or a `Labels` object.
#     embeddings: optional precomputed document vectors. When given, a faiss
#         index is built from them via `_build_faiss_index`.
#     index_type: faiss index_factory description string (default "Flat").
#     metric: faiss metric constant (default inner product).
#     normalize: request L2 normalisation. It is only honoured for the
#         inner-product metric on non-tensor inputs (see _build_faiss_index).
#     device: "cpu" or "cuda". With "cuda" the index is moved to GPU 0.
#     name_or_dir: forwarded to BaseDocumentIndex.__init__.
def __init__(
documents: Union[List[str], Labels],
embeddings: Optional[Union[torch.Tensor, numpy.ndarray]] = None,
index_type: str = "Flat",
metric: int = faiss.METRIC_INNER_PRODUCT,
normalize: bool = False,
device: str = "cpu",
name_or_dir: Optional[Union[str, os.PathLike]] = None,
) -> None:
super().__init__(documents, embeddings, name_or_dir)
# Sanity check: exactly one vector per document.
# NOTE(review): `documents` is typed Union[List[str], Labels], but a plain list
# has no get_label_size(). Passing a list together with `embeddings` would
# raise AttributeError here. Presumably the converted `self.documents` from the
# base class should be used instead; confirm against BaseDocumentIndex.
if embeddings is not None and documents is not None:
logger.info("Both documents and embeddings are provided.")
if documents.get_label_size() != embeddings.shape[0]:
raise ValueError(
"The number of documents and embeddings must be the same."
# device to store the embeddings
self.device = device
# params
self.index_type = index_type
self.metric = metric
self.normalize = normalize
# Reuse an already-built faiss index when one is supplied.
# NOTE(review): `index` is not a parameter in the visible signature, so this
# line would raise NameError. The method `index` defined below is a class
# attribute, not a local name. The commented-out `load` passes `index=...`,
# which suggests the upstream signature has an `index` parameter that is
# missing from this listing.
if index is not None:
self.embeddings = index
if self.device == "cuda":
# use a single GPU
faiss_resource = faiss.StandardGpuResources()
self.embeddings = faiss.index_cpu_to_gpu(
faiss_resource, 0, self.embeddings
# Otherwise build a fresh index from the raw embeddings. This is presumably
# the else-branch of the check above; the `else:` line is not visible here.
if embeddings is not None:
# build the faiss index
logger.info("Building the index from the embeddings.")
# NOTE(review): the arguments and closing parenthesis of this call are
# missing from this listing.
self.embeddings = self._build_faiss_index(
# Create a faiss index for `embeddings`, assign it to self.embeddings and
# return it.
#
# Side effects: overwrites self.normalize, self.index_type and self.metric
# with the values actually used. index_type may have been rewritten below.
#
# Args:
#     embeddings: the vectors to index. The vector size is read from shape[1].
#     index_type: faiss index_factory description string.
#     normalize: whether to prepend an "L2norm" step to the factory string.
#     metric: faiss metric constant passed to index_factory.
# Returns:
#     the faiss index (the same object stored in self.embeddings).
def _build_faiss_index(
embeddings: Optional[Union[torch.Tensor, numpy.ndarray]],
index_type: str,
normalize: bool,
metric: int,
# build the faiss index
# Normalisation applies only to inner-product search with non-tensor input.
# The first operand of this expression (presumably `normalize`) is missing
# from the listing.
self.normalize = (
and metric == faiss.METRIC_INNER_PRODUCT
and not isinstance(embeddings, torch.Tensor)
# Prepend faiss's L2norm pre-processing step to the factory description.
if self.normalize:
index_type = f"L2norm,{index_type}"
faiss_vector_size = embeddings.shape[1]
# On CPU, an "x," placeholder becomes "x_HNSW32," (presumably an HNSW coarse
# quantizer for IVF). Every remaining "x" is then replaced by
# 4 * ceil(sqrt(vector size)), presumably the number of IVF lists. It is not
# visible whether the second replace sits inside the `device == "cpu"` branch.
# NOTE(review): faiss's index-selection guidelines size the number of IVF
# lists from the number of vectors N (about 4*sqrt(N) to 16*sqrt(N)). This
# code uses the vector dimensionality (embeddings.shape[1]) instead. Also,
# str.replace("x", ...) substitutes every lowercase "x" in the string.
if self.device == "cpu":
index_type = index_type.replace("x,", "x_HNSW32,")
index_type = index_type.replace(
"x", str(math.ceil(math.sqrt(faiss_vector_size)) * 4)
self.embeddings = faiss.index_factory(faiss_vector_size, index_type, metric)
# convert to GPU
if self.device == "cuda":
# use a single GPU
faiss_resource = faiss.StandardGpuResources()
self.embeddings = faiss.index_cpu_to_gpu(faiss_resource, 0, self.embeddings)
# move to CPU if embeddings is a torch.Tensor
embeddings = (
embeddings.cpu() if isinstance(embeddings, torch.Tensor) else embeddings
# convert to float32 if embeddings is a torch.Tensor and is float16
if isinstance(embeddings, torch.Tensor) and embeddings.dtype == torch.float16:
embeddings = embeddings.float()
# NOTE(review): no train(...) or add(...) call on self.embeddings is visible
# between the dtype conversion and the return. As listed, the index would be
# returned empty. Confirm that the upstream file populates the index here.
# save parameters for saving/loading
self.index_type = index_type
self.metric = metric
# clear the embeddings to free up memory
# NOTE: this only rebinds the local name. The caller's reference (e.g.
# passage_embeddings in index(), which is deleted there afterwards) keeps
# the tensor alive until the caller drops it.
embeddings = None
return self.embeddings
# Encode documents with `retriever.passage_encoder` and (re)build the faiss
# index from the resulting embeddings via _build_faiss_index. Returns self.
# The documentation lines below lost their surrounding docstring quotes in
# this listing.
def index(
retriever: GoldenRetriever,
documents: Optional[List[str]] = None,
batch_size: int = 32,
num_workers: int = 4,
max_length: Optional[int] = None,
collate_fn: Optional[Callable] = None,
encoder_precision: Optional[Union[str, int]] = None,
compute_on_cpu: bool = False,
force_reindex: bool = False,
) -> "FaissDocumentIndex":
Index the documents using the encoder.
retriever (:obj:`torch.nn.Module`):
The encoder to be used for indexing.
documents (:obj:`List[str]`, `optional`, defaults to None):
The documents to be indexed.
batch_size (:obj:`int`, `optional`, defaults to 32):
The batch size to be used for indexing.
num_workers (:obj:`int`, `optional`, defaults to 4):
The number of workers to be used for indexing.
max_length (:obj:`int`, `optional`, defaults to None):
The maximum length of the input to the encoder.
collate_fn (:obj:`Callable`, `optional`, defaults to None):
The collate function to be used for batching.
encoder_precision (:obj:`Union[str, int]`, `optional`, defaults to None):
The precision to be used for the encoder.
compute_on_cpu (:obj:`bool`, `optional`, defaults to False):
Whether to compute the embeddings on CPU.
force_reindex (:obj:`bool`, `optional`, defaults to False):
Whether to force reindexing.
:obj:`InMemoryIndexer`: The indexer object.
# Skip work when an index already exists and reindexing was not requested.
# NOTE(review): the logging call that should wrap the message below is missing
# from this listing. As shown, the bare string literal is a no-op statement.
if self.embeddings is not None and not force_reindex:
"Embeddings are already present and `force_reindex` is `False`. Skipping indexing."
if documents is None:
return self
# release the memory
# NOTE(review): the statement this comment describes is not visible here.
# Default collate function: tokenise a batch of passages with the retriever's
# passage tokenizer (length capped at `max_length`, or the tokenizer's
# model_max_length) and wrap the result in ModelInputs. Most of the tokenizer
# call is missing from this listing.
if collate_fn is None:
tokenizer = retriever.passage_tokenizer
def collate_fn(x):
return ModelInputs(
max_length=max_length or tokenizer.model_max_length,
# Choose the texts to encode. With force_reindex, use the stored documents
# (self.documents); otherwise use the `documents` argument wrapped in Labels,
# presumably returning early when `documents` is None.
# NOTE(review): the else-branches, and any step that merges `documents` into
# self.documents before reindexing, are not visible; confirm the upstream flow.
if force_reindex:
if documents is not None:
data = [k for k in self.documents.get_labels()]
if documents is not None:
data = [k for k in Labels(documents).get_labels()]
return self
# NOTE(review): the remaining DataLoader arguments (batch_size, num_workers,
# collate_fn, ...) are not visible in this listing.
dataloader = DataLoader(
BaseDataset(name="passage", data=data),
encoder = retriever.passage_encoder
# Create empty lists to store the passage embeddings and passage index
passage_embeddings: List[torch.Tensor] = []
encoder_device = "cpu" if compute_on_cpu else self.device
# torch.autocast only accepts a bare device type such as 'cpu' or 'cuda',
# so strip any device index from the encoder device (e.g. 'cuda:0' -> 'cuda').
device_type_for_autocast = str(encoder_device).split(":")[0]
# CPU autocast only supports bfloat16, so a different context is used when
# encoding on CPU. The context-manager expressions are missing from this
# listing: presumably contextlib.nullcontext() for CPU, and torch.autocast
# with PRECISION_MAP[encoder_precision] otherwise.
autocast_pssg_mngr = (
if device_type_for_autocast == "cpu"
else (
with autocast_pssg_mngr:
# Iterate through each batch in the dataloader
for batch in tqdm(dataloader, desc="Indexing"):
# Move the batch to the device
batch: ModelInputs = batch.to(encoder_device)
# Compute the passage embeddings
passage_outs = encoder(**batch)
# Append the passage embeddings to the list
# When the index lives on CPU, detach and move each embedding right away.
# Otherwise keep it on the encoder device until after the loop. The `else:`
# line is missing from this listing.
if self.device == "cpu":
passage_embeddings.extend([c.detach().cpu() for c in passage_outs])
passage_embeddings.extend([c for c in passage_outs])
# move the passage embeddings to the CPU if not already done
passage_embeddings = [c.detach().cpu() for c in passage_embeddings]
# stack it
passage_embeddings: torch.Tensor = torch.stack(passage_embeddings, dim=0)
# convert to float32 for faiss
# NOTE(review): no conversion statement follows the comment above in this
# listing. _build_faiss_index only upcasts torch.float16, so bfloat16
# embeddings (e.g. produced under bfloat16 autocast) would reach faiss
# unconverted. Confirm.
# index the embeddings
# NOTE(review): the arguments and closing parenthesis of this call are
# missing from this listing.
self.embeddings = self._build_faiss_index(
# free up memory from the unused variable
del passage_embeddings
return self
# Return the top-k documents for each query vector.
#
# Args:
#     query: batch of query vectors (a torch.Tensor, per the annotation).
#     k: number of results per query, clamped to the number of indexed vectors.
# Returns:
#     one list of RetrievedSample(label, index, score) per query row.
def search(self, query: torch.Tensor, k: int = 1) -> list[list[RetrievedSample]]:
k = min(k, self.embeddings.ntotal)
# NOTE(review): the body of this `if` (presumably an L2 normalisation of the
# query) is missing from this listing.
if self.normalize:
if isinstance(query, torch.Tensor) and self.device == "cpu":
query = query.detach().cpu()
# Retrieve the indices of the top k passage embeddings
retriever_out = self.embeddings.search(query, k)
# get int values (second element of the tuple)
batch_top_k: List[List[int]] = retriever_out[1].detach().cpu().tolist()
# get float values (first element of the tuple)
batch_scores: List[List[float]] = retriever_out[0].detach().cpu().tolist()
# NOTE(review): `.detach()` assumes faiss returned torch tensors (presumably
# via faiss.contrib.torch_utils for torch queries). A numpy query would yield
# numpy arrays and fail above.
# NOTE(review): faiss pads results with id -1 when fewer than k neighbours are
# found (possible with IVF-style indexes). Those ids are passed straight to
# get_label_from_index below.
# Retrieve the passages corresponding to the indices
batch_passages = [
[self.documents.get_label_from_index(i) for i in indices]
for indices in batch_top_k
# build the output object
batch_retrieved_samples = [
RetrievedSample(label=passage, index=index, score=score)
for passage, index, score in zip(passages, indices, scores)
for passages, indices, scores in zip(
batch_passages, batch_top_k, batch_scores
return batch_retrieved_samples
# NOTE(review): the commented-out save/load helpers below are kept for
# reference. Problems to fix if they are revived:
#   - `Path` is not imported in this listing.
#   - `load` passes `index=`, which the visible __init__ does not accept.
#   - Its final else-branch calls faiss.read_index on `embedding_path`, which
#     is only bound in the other branch; it presumably should use `index_path`.
# def save(self, saving_dir: Union[str, os.PathLike]):
# """
# Save the indexer to the disk.
# Args:
# saving_dir (:obj:`Union[str, os.PathLike]`):
# The directory where the indexer will be saved.
# """
# saving_dir = Path(saving_dir)
# # save the passage embeddings
# index_path = saving_dir / self.INDEX_FILE_NAME
# logger.info(f"Saving passage embeddings to {index_path}")
# faiss.write_index(self.embeddings, str(index_path))
# # save the passage index
# documents_path = saving_dir / self.DOCUMENTS_FILE_NAME
# logger.info(f"Saving passage index to {documents_path}")
# self.documents.save(documents_path)
# @classmethod
# def load(
# cls,
# loading_dir: Union[str, os.PathLike],
# device: str = "cpu",
# document_file_name: Optional[str] = None,
# embedding_file_name: Optional[str] = None,
# index_file_name: Optional[str] = None,
# **kwargs,
# ) -> "FaissDocumentIndex":
# loading_dir = Path(loading_dir)
# document_file_name = document_file_name or cls.DOCUMENTS_FILE_NAME
# embedding_file_name = embedding_file_name or cls.EMBEDDINGS_FILE_NAME
# index_file_name = index_file_name or cls.INDEX_FILE_NAME
# # load the documents
# documents_path = loading_dir / document_file_name
# if not documents_path.exists():
# raise ValueError(f"Document file `{documents_path}` does not exist.")
# logger.info(f"Loading documents from {documents_path}")
# documents = Labels.from_file(documents_path)
# index = None
# embeddings = None
# # try to load the index directly
# index_path = loading_dir / index_file_name
# if not index_path.exists():
# # try to load the embeddings
# embedding_path = loading_dir / embedding_file_name
# # run some checks
# if embedding_path.exists():
# logger.info(f"Loading embeddings from {embedding_path}")
# embeddings = torch.load(embedding_path, map_location="cpu")
# logger.warning(
# f"Index file `{index_path}` and embedding file `{embedding_path}` do not exist."
# )
# else:
# logger.info(f"Loading index from {index_path}")
# index = faiss.read_index(str(embedding_path))
# return cls(
# documents=documents,
# embeddings=embeddings,
# index=index,
# device=device,
# **kwargs,
# )
# Reconstruct the stored vector for document `index` from the faiss index.
#
# Raises ValueError if nothing has been indexed yet, or if `index` >= ntotal.
# NOTE(review): negative indices are not rejected. The error message calls
# ntotal "the maximum index", although the largest valid index is ntotal - 1.
# Some faiss index types presumably need extra setup before reconstruct()
# works (e.g. a direct map for IVF indexes); confirm for the index types used.
def get_embeddings_from_index(
self, index: int
) -> Union[torch.Tensor, numpy.ndarray]:
Get the document vector from the index.
index (`int`):
The index of the document.
`torch.Tensor`: The document vector.
if self.embeddings is None:
raise ValueError(
"The documents must be indexed before they can be retrieved."
if index >= self.embeddings.ntotal:
raise ValueError(
f"The index {index} is out of bounds. The maximum index is {self.embeddings.ntotal}."
return self.embeddings.reconstruct(index)