import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import hydra
import numpy
import torch
from omegaconf import OmegaConf
from rich.pretty import pprint

from relik.common import upload
from relik.common.log import get_console_logger, get_logger
from relik.common.utils import (
    from_cache,
    is_remote_url,
    is_str_a_path,
    relative_to_absolute_path,
    sapienzanlp_model_urls,
)
from relik.retriever.data.labels import Labels

# from relik.retriever.models.model import GoldenRetriever, RetrievedSample

logger = get_logger(__name__)
console_logger = get_console_logger()
@dataclass
class IndexerOutput:
    indices: Union[torch.Tensor, numpy.ndarray]
    distances: Union[torch.Tensor, numpy.ndarray]
class BaseDocumentIndex:
    CONFIG_NAME = "config.yaml"
    DOCUMENTS_FILE_NAME = "documents.json"
    EMBEDDINGS_FILE_NAME = "embeddings.pt"
    def __init__(
        self,
        documents: Optional[
            Union[str, List[str], Labels, os.PathLike, List[os.PathLike]]
        ] = None,
        embeddings: Optional[torch.Tensor] = None,
        name_or_dir: Optional[Union[str, os.PathLike]] = None,
    ) -> None:
        if documents is not None:
            if isinstance(documents, Labels):
                self.documents = documents
            else:
                documents_are_paths = False
                # normalize the documents to a list if not already
                if not isinstance(documents, list):
                    documents = [documents]
                # now check if the documents are a list of paths (either str or os.PathLike)
                if isinstance(documents[0], (str, os.PathLike)):
                    # check if the str is a path
                    documents_are_paths = is_str_a_path(documents[0])
                # if the documents are a list of paths, load them
                if documents_are_paths:
                    logger.info("Loading documents from paths")
                    _documents = []
                    for doc in documents:
                        with open(relative_to_absolute_path(doc)) as f:
                            _documents += [line.strip() for line in f.readlines()]
                    # remove duplicates
                    documents = list(set(_documents))
                self.documents = Labels()
                self.documents.add_labels(documents)
        else:
            self.documents = Labels()
        self.embeddings = embeddings
        self.name_or_dir = name_or_dir
    def __iter__(self):
        # make this class iterable
        for i in range(len(self)):
            yield self[i]

    def __len__(self):
        return self.documents.get_label_size()

    def __getitem__(self, index):
        return self.get_passage_from_index(index)
    @property
    def config(self) -> Dict[str, Any]:
        """
        The configuration of the document index.

        Returns:
            `Dict[str, Any]`: The configuration of the document index.
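
        Example (a minimal sketch; `InMemoryDocumentIndex` and its module path
        are hypothetical, used only for illustration):

            >>> index = InMemoryDocumentIndex(documents=["doc a", "doc b"])
            >>> index.config["_target_"]
            'relik.retriever.indexers.inmemory.InMemoryDocumentIndex'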
""" | |
def obj_to_dict(obj): | |
match obj: | |
case dict(): | |
data = {} | |
for k, v in obj.items(): | |
data[k] = obj_to_dict(v) | |
return data | |
case list() | tuple(): | |
return [obj_to_dict(x) for x in obj] | |
case object(__dict__=_): | |
data = { | |
"_target_": f"{obj.__class__.__module__}.{obj.__class__.__name__}", | |
} | |
for k, v in obj.__dict__.items(): | |
if not k.startswith("_"): | |
data[k] = obj_to_dict(v) | |
return data | |
case _: | |
return obj | |
return obj_to_dict(self) | |
    def index(
        self,
        retriever,
        *args,
        **kwargs,
    ) -> "BaseDocumentIndex":
        # to be implemented by concrete subclasses
        raise NotImplementedError

    def search(self, query: Any, k: int = 1, *args, **kwargs) -> List:
        # to be implemented by concrete subclasses
        raise NotImplementedError
    def get_index_from_passage(self, document: str) -> int:
        """
        Get the index of the passage.

        Args:
            document (`str`):
                The document to get the index for.

        Returns:
            `int`: The index of the document.
        """
        return self.documents.get_index_from_label(document)

    def get_passage_from_index(self, index: int) -> str:
        """
        Get the document from the index.

        Args:
            index (`int`):
                The index of the document.

        Returns:
            `str`: The document.
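
        Example (a minimal sketch; `index` is an already-built document index
        and the passage text is hypothetical):

            >>> index.get_passage_from_index(index.get_index_from_passage("doc a"))
            'doc a'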
""" | |
return self.documents.get_label_from_index(index) | |
    def get_embeddings_from_index(self, index: int) -> torch.Tensor:
        """
        Get the document vector from the index.

        Args:
            index (`int`):
                The index of the document.

        Returns:
            `torch.Tensor`: The document vector.
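
        Example (a sketch, assuming the index has been built so that row `i`
        of `embeddings` holds the vector for document `i`):

            >>> vec = index.get_embeddings_from_index(0)
            >>> vec.ndim  # one row of the embedding matrix, i.e. a single vector
            1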
""" | |
if self.embeddings is None: | |
raise ValueError( | |
"The documents must be indexed before they can be retrieved." | |
) | |
if index >= self.embeddings.shape[0]: | |
raise ValueError( | |
f"The index {index} is out of bounds. The maximum index is {len(self.embeddings) - 1}." | |
) | |
return self.embeddings[index] | |
    def get_embeddings_from_passage(self, document: str) -> torch.Tensor:
        """
        Get the document vector from the document label.

        Args:
            document (`str`):
                The document to get the vector for.

        Returns:
            `torch.Tensor`: The document vector.
        """
        if self.embeddings is None:
            raise ValueError(
                "The documents must be indexed before they can be retrieved."
            )
        return self.get_embeddings_from_index(self.get_index_from_passage(document))
    def save_pretrained(
        self,
        output_dir: Union[str, os.PathLike],
        config: Optional[Dict[str, Any]] = None,
        config_file_name: Optional[str] = None,
        document_file_name: Optional[str] = None,
        embedding_file_name: Optional[str] = None,
        push_to_hub: bool = False,
        model_id: Optional[str] = None,
        **kwargs,
    ):
        """
        Save the document index to a directory.

        Args:
            output_dir (`str`):
                The directory to save the index to.
            config (`Optional[Dict[str, Any]]`, `optional`):
                The configuration to save. If `None`, the current configuration of the index will be
                saved. Defaults to `None`.
            config_file_name (`Optional[str]`, `optional`):
                The name of the configuration file. Defaults to `config.yaml`.
            document_file_name (`Optional[str]`, `optional`):
                The name of the document file. Defaults to `documents.json`.
            embedding_file_name (`Optional[str]`, `optional`):
                The name of the embedding file. Defaults to `embeddings.pt`.
            push_to_hub (`bool`, `optional`):
                Whether to push the saved index to the hub. Defaults to `False`.
            model_id (`Optional[str]`, `optional`):
                The model id to use when pushing to the hub. Defaults to the name of `output_dir`.
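
        Example (a minimal sketch; the output path and model id are
        illustrative):

            >>> index.save_pretrained("outputs/my_index")
            >>> # optionally upload to the hub under an explicit id
            >>> index.save_pretrained("outputs/my_index", push_to_hub=True, model_id="my_index")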
""" | |
if config is None: | |
# create a default config | |
config = self.config | |
config_file_name = config_file_name or self.CONFIG_NAME | |
document_file_name = document_file_name or self.DOCUMENTS_FILE_NAME | |
embedding_file_name = embedding_file_name or self.EMBEDDINGS_FILE_NAME | |
# create the output directory | |
output_dir = Path(output_dir) | |
output_dir.mkdir(parents=True, exist_ok=True) | |
logger.info(f"Saving retriever to {output_dir}") | |
logger.info(f"Saving config to {output_dir / config_file_name}") | |
# pretty print the config | |
pprint(config, console=console_logger, expand_all=True) | |
OmegaConf.save(config, output_dir / config_file_name) | |
# save the current state of the retriever | |
embedding_path = output_dir / embedding_file_name | |
logger.info(f"Saving retriever state to {output_dir / embedding_path}") | |
torch.save(self.embeddings, embedding_path) | |
# save the passage index | |
documents_path = output_dir / document_file_name | |
logger.info(f"Saving passage index to {documents_path}") | |
self.documents.save(documents_path) | |
logger.info("Saving document index to disk done.") | |
if push_to_hub: | |
# push to hub | |
logger.info(f"Pushing to hub") | |
model_id = model_id or output_dir.name | |
upload(output_dir, model_id, **kwargs) | |
    @classmethod
    def from_pretrained(
        cls,
        name_or_dir: Union[str, os.PathLike],
        device: str = "cpu",
        precision: Optional[str] = None,
        config_file_name: Optional[str] = None,
        document_file_name: Optional[str] = None,
        embedding_file_name: Optional[str] = None,
        config_kwargs: Optional[Dict[str, Any]] = None,
        *args,
        **kwargs,
    ) -> "BaseDocumentIndex":
        """
        Load a document index from a local directory or a remote model id.

        Returns:
            `BaseDocumentIndex`: The loaded document index.
        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)

        config_file_name = config_file_name or cls.CONFIG_NAME
        document_file_name = document_file_name or cls.DOCUMENTS_FILE_NAME
        embedding_file_name = embedding_file_name or cls.EMBEDDINGS_FILE_NAME

        model_dir = from_cache(
            name_or_dir,
            filenames=[config_file_name, document_file_name, embedding_file_name],
            cache_dir=cache_dir,
            force_download=force_download,
        )

        config_path = model_dir / config_file_name
        if not config_path.exists():
            raise FileNotFoundError(
                f"Model configuration file not found at {config_path}."
            )

        config = OmegaConf.load(config_path)
        # override the config with the kwargs
        if config_kwargs is not None:
            config = OmegaConf.merge(config, OmegaConf.create(config_kwargs))
        pprint(OmegaConf.to_container(config), console=console_logger, expand_all=True)

        # load the documents
        documents_path = model_dir / document_file_name
        if not documents_path.exists():
            raise ValueError(f"Document file `{documents_path}` does not exist.")
        logger.info(f"Loading documents from {documents_path}")
        documents = Labels.from_file(documents_path)

        # load the passage embeddings
        embedding_path = model_dir / embedding_file_name
        embeddings = None
        if embedding_path.exists():
            logger.info(f"Loading embeddings from {embedding_path}")
            embeddings = torch.load(embedding_path, map_location="cpu")
        else:
            logger.warning(f"Embedding file `{embedding_path}` does not exist.")

        document_index = hydra.utils.instantiate(
            config,
            documents=documents,
            embeddings=embeddings,
            device=device,
            precision=precision,
            name_or_dir=name_or_dir,
            *args,
            **kwargs,
        )
        return document_index
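

# Usage sketch. `BaseDocumentIndex` is abstract (`index` and `search` raise
# NotImplementedError), so a concrete subclass is assumed below; the subclass
# name, import path, file paths, and query are illustrative only:
#
#     from relik.retriever.indexers.inmemory import InMemoryDocumentIndex
#
#     index = InMemoryDocumentIndex.from_pretrained("path/to/saved/index")
#     hits = index.search("who wrote the Iliad?", k=5)
#     index.save_pretrained("path/to/output")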