import contextlib
import logging
import os
from dataclasses import dataclass
from typing import Callable, List, Optional, Union

import numpy
import psutil
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from relik.common.log import get_logger
from relik.common.utils import is_package_available
from relik.retriever.common.model_inputs import ModelInputs
from relik.retriever.data.base.datasets import BaseDataset
from relik.retriever.indexers.base import BaseDocumentIndex
from relik.retriever.indexers.document import Document, DocumentStore
from relik.retriever.pytorch_modules import PRECISION_MAP, RetrievedSample
from relik.retriever.pytorch_modules.model import GoldenRetriever
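
# `faiss.contrib.torch_utils` patches the faiss index methods (`train`,
# `add`, `search`, ...) so they also accept and return torch tensors
# instead of only numpy arrays.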
if is_package_available("faiss"):
    import faiss
    import faiss.contrib.torch_utils

logger = get_logger(__name__, level=logging.INFO)


@dataclass
class FaissOutput:
    indices: Union[torch.Tensor, numpy.ndarray]
    distances: Union[torch.Tensor, numpy.ndarray]


class FaissDocumentIndex(BaseDocumentIndex):
    DOCUMENTS_FILE_NAME = "documents.json"
    EMBEDDINGS_FILE_NAME = "embeddings.pt"
    INDEX_FILE_NAME = "index.faiss"

    def __init__(
        self,
        documents: str
        | List[str]
        | os.PathLike
        | List[os.PathLike]
        | DocumentStore
        | None = None,
        embeddings: torch.Tensor | numpy.ndarray | None = None,
        metadata_fields: List[str] | None = None,
        separator: str = "<def>",
        name_or_path: str | os.PathLike | None = None,
        device: str = "cpu",
        index=None,
        index_type: str = "Flat",
        nprobe: int = 1,
        metric: int = faiss.METRIC_INNER_PRODUCT,
        normalize: bool = False,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(
            documents, embeddings, metadata_fields, separator, name_or_path, device
        )
        if embeddings is not None and documents is not None:
            logger.info("Both documents and embeddings are provided.")
            if len(documents) != embeddings.shape[0]:
                raise ValueError(
                    "The number of documents and embeddings must be the same."
                )

        # Use one OpenMP thread per physical core; hyper-threads rarely help
        # faiss and can oversubscribe the CPU.
        faiss.omp_set_num_threads(psutil.cpu_count(logical=False))

        self.index_type = index_type
        self.metric = metric
        self.normalize = normalize
        # Keep nprobe around so that `index()` can rebuild the index later.
        self.nprobe = nprobe

        if index is not None:
            # A pre-built faiss index was passed in: use it as-is, moving it
            # to the first GPU if one is available.
            self.embeddings = index
            if self.device == "cuda":
                faiss_resource = faiss.StandardGpuResources()
                self.embeddings = faiss.index_cpu_to_gpu(
                    faiss_resource, 0, self.embeddings
                )
        else:
            if embeddings is not None:
                logger.info("Building the index from the embeddings.")
                self.embeddings = self._build_faiss_index(
                    embeddings=embeddings,
                    index_type=index_type,
                    nprobe=nprobe,
                    normalize=normalize,
                    metric=metric,
                )

    def to(
        self, device_or_precision: str | torch.device | torch.dtype
    ) -> "BaseDocumentIndex":
        """
        Move the retriever to the specified device or precision.

        Args:
            device_or_precision (`str` | `torch.device` | `torch.dtype`):
                The device or precision to move the retriever to.

        Returns:
            `BaseDocumentIndex`: The retriever.
        """
        if isinstance(device_or_precision, torch.dtype):
            logger.warning(
                "FaissDocumentIndex does not support precision conversion. Ignoring."
            )
        if device_or_precision == "cuda" and self.device == "cpu":
            faiss_resource = faiss.StandardGpuResources()
            self.embeddings = faiss.index_cpu_to_gpu(faiss_resource, 0, self.embeddings)
        elif device_or_precision == "cpu" and self.device == "cuda":
            self.embeddings = faiss.index_gpu_to_cpu(self.embeddings)
        else:
            logger.warning(
                f"Provided device `{device_or_precision}` is the same as the current device `{self.device}`."
            )
        return self

    @property
    def device(self):
        # Heuristic: report "cuda" whenever faiss can see at least one GPU.
        # This does not track where the index actually lives.
        if faiss.get_num_gpus() > 0:
            return "cuda"
        return "cpu"

    def _build_faiss_index(
        self,
        embeddings: Optional[Union[torch.Tensor, numpy.ndarray]],
        index_type: str,
        nprobe: int,
        normalize: bool,
        metric: int,
    ):
        # L2-normalization makes inner-product search equivalent to cosine
        # similarity; it is applied here (via faiss's `L2norm` pre-transform)
        # only for numpy inputs with an inner-product metric.
        self.normalize = (
            normalize
            and metric == faiss.METRIC_INNER_PRODUCT
            and not isinstance(embeddings, torch.Tensor)
        )
        if self.normalize:
            index_type = f"L2norm,{index_type}"
        faiss_vector_size = embeddings.shape[1]
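
        # `index_type` is a faiss index-factory string, e.g. "Flat" for exact
        # search, "IVF1024,Flat" for an inverted-file index, or "HNSW32" for a
        # graph-based index (examples are illustrative, not exhaustive).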
        self.embeddings = faiss.index_factory(faiss_vector_size, index_type, metric)

        if self.device == "cuda":
            faiss_resource = faiss.StandardGpuResources()
            self.embeddings = faiss.index_cpu_to_gpu(faiss_resource, 0, self.embeddings)
        else:
            embeddings = (
                embeddings.cpu() if isinstance(embeddings, torch.Tensor) else embeddings
            )

        # faiss has no half-precision support; cast float16 tensors to float32.
        if isinstance(embeddings, torch.Tensor) and embeddings.dtype == torch.float16:
            embeddings = embeddings.float()

        logger.info("Training the index.")
        self.embeddings.train(embeddings)

        logger.info("Adding the embeddings to the index.")
        self.embeddings.add(embeddings)

        # nprobe controls how many inverted lists are visited at query time;
        # it only affects IVF-style indexes.
        self.embeddings.nprobe = nprobe

        self.index_type = index_type
        self.metric = metric

        # Drop the local reference; the vectors now live inside the index.
        embeddings = None

        return self.embeddings

    @torch.no_grad()
    @torch.inference_mode()
    def index(
        self,
        retriever: GoldenRetriever,
        documents: Optional[List[Document]] = None,
        batch_size: int = 32,
        num_workers: int = 4,
        max_length: Optional[int] = None,
        collate_fn: Optional[Callable] = None,
        encoder_precision: Optional[Union[str, int]] = None,
        compute_on_cpu: bool = False,
        force_reindex: bool = False,
        *args,
        **kwargs,
    ) -> "FaissDocumentIndex":
        """
        Index the documents using the encoder.

        Args:
            retriever (:obj:`GoldenRetriever`):
                The retriever whose passage encoder is used for indexing.
            documents (:obj:`List[Document]`, `optional`, defaults to None):
                The documents to be indexed.
            batch_size (:obj:`int`, `optional`, defaults to 32):
                The batch size to be used for indexing.
            num_workers (:obj:`int`, `optional`, defaults to 4):
                The number of workers to be used for indexing.
            max_length (:obj:`int`, `optional`, defaults to None):
                The maximum length of the input to the encoder.
            collate_fn (:obj:`Callable`, `optional`, defaults to None):
                The collate function to be used for batching.
            encoder_precision (:obj:`Union[str, int]`, `optional`, defaults to None):
                The precision to be used for the encoder.
            compute_on_cpu (:obj:`bool`, `optional`, defaults to False):
                Whether to compute the embeddings on CPU.
            force_reindex (:obj:`bool`, `optional`, defaults to False):
                Whether to force reindexing.

        Returns:
            :obj:`FaissDocumentIndex`: The index itself.
        """
        if self.embeddings is not None and not force_reindex:
            logger.info(
                "Embeddings are already present and `force_reindex` is `False`. Skipping indexing."
            )
            if documents is None:
                return self

        if collate_fn is None:
            tokenizer = retriever.passage_tokenizer

            def collate_fn(x):
                return ModelInputs(
                    tokenizer(
                        x,
                        padding=True,
                        return_tensors="pt",
                        truncation=True,
                        max_length=max_length or tokenizer.model_max_length,
                    )
                )

        if force_reindex:
            if documents is not None:
                self.documents.add_document(documents)
            data = [k for k in self.get_passages()]
        else:
            if documents is not None:
                data = [k for k in self.get_passages(DocumentStore(documents))]
            else:
                return self

        dataloader = DataLoader(
            BaseDataset(name="passage", data=data),
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=False,
            collate_fn=collate_fn,
        )

        encoder = retriever.passage_encoder

        passage_embeddings: List[torch.Tensor] = []

        encoder_device = "cpu" if compute_on_cpu else self.device
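
        # Mixed-precision autocast only applies on GPU; on CPU fall back to a
        # no-op context manager.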
        device_type_for_autocast = str(encoder_device).split(":")[0]
        autocast_pssg_mngr = (
            contextlib.nullcontext()
            if device_type_for_autocast == "cpu"
            else (
                torch.autocast(
                    device_type=device_type_for_autocast,
                    dtype=PRECISION_MAP[encoder_precision],
                )
            )
        )
        with autocast_pssg_mngr:
            for batch in tqdm(dataloader, desc="Indexing"):
                batch: ModelInputs = batch.to(encoder_device)
                passage_outs = encoder(**batch)
                if self.device == "cpu":
                    passage_embeddings.extend([c.detach().cpu() for c in passage_outs])
                else:
                    passage_embeddings.extend([c for c in passage_outs])

        # Move everything to CPU and stack into a single (num_passages, dim)
        # tensor; `Tensor.to` is not in-place, so keep the converted tensor.
        passage_embeddings = [c.detach().cpu() for c in passage_embeddings]
        passage_embeddings: torch.Tensor = torch.stack(passage_embeddings, dim=0)
        passage_embeddings = passage_embeddings.to(PRECISION_MAP["float32"])

        self.embeddings = self._build_faiss_index(
            embeddings=passage_embeddings,
            index_type=self.index_type,
            nprobe=self.nprobe,
            normalize=self.normalize,
            metric=self.metric,
        )

        del passage_embeddings

        return self

    @torch.no_grad()
    @torch.inference_mode()
    def search(self, query: torch.Tensor, k: int = 1) -> list[list[RetrievedSample]]:
        """
        Search the index for the ``k`` passages closest to each query vector.

        Args:
            query (`torch.Tensor`):
                A batch of query vectors of shape ``(num_queries, dim)``.
            k (`int`, `optional`, defaults to 1):
                The number of results to return per query.

        Returns:
            `list[list[RetrievedSample]]`: The retrieved documents and scores,
            one list per query.
        """
        k = min(k, self.embeddings.ntotal)

        if self.normalize:
            faiss.normalize_L2(query)
        if isinstance(query, torch.Tensor) and self.device == "cpu":
            query = query.detach().cpu()

        # faiss returns (distances, indices); an index of -1 marks an empty
        # result slot, so those entries are filtered out below.
        retriever_out = self.embeddings.search(query, k)

        batch_top_k: List[List[int]] = retriever_out[1].detach().cpu().tolist()
        batch_scores: List[List[float]] = retriever_out[0].detach().cpu().tolist()

        batch_docs = [
            [self.documents.get_document_from_id(i) for i in indices if i != -1]
            for indices in batch_top_k
        ]

        batch_retrieved_samples = [
            [
                RetrievedSample(document=doc, score=score)
                for doc, score in zip(docs, scores)
            ]
            for docs, scores in zip(batch_docs, batch_scores)
        ]
        return batch_retrieved_samples

    def get_embeddings_from_index(
        self, index: int
    ) -> Union[torch.Tensor, numpy.ndarray]:
        """
        Get the document vector from the index.

        Args:
            index (`int`):
                The index of the document.

        Returns:
            `torch.Tensor`: The document vector.
        """
        if self.embeddings is None:
            raise ValueError(
                "The documents must be indexed before they can be retrieved."
            )
        if index >= self.embeddings.ntotal:
            raise ValueError(
                f"The index {index} is out of bounds. The maximum index is {self.embeddings.ntotal}."
            )
        return self.embeddings.reconstruct(index)
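

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): build an exact "Flat" index
    # over random vectors and query it. Shapes and document texts are made
    # up, and a faiss build exposing `get_num_gpus` (used by the `device`
    # property) is assumed.
    rng_embeddings = numpy.random.rand(100, 64).astype("float32")
    doc_index = FaissDocumentIndex(
        documents=[f"document {i}" for i in range(100)],
        embeddings=rng_embeddings,
        index_type="Flat",
    )
    # Retrieve the 3 highest inner-product matches for the first two vectors.
    for samples in doc_index.search(torch.from_numpy(rng_embeddings[:2]), k=3):
        for sample in samples:
            print(sample.document, sample.score)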