import argparse
import json
import logging
import os
import time
from pathlib import Path
from typing import Union

import torch
import tqdm

from relik.common.log import get_logger
from relik.retriever import GoldenRetriever
from relik.retriever.common.model_inputs import ModelInputs
from relik.retriever.data.base.datasets import BaseDataset
from relik.retriever.indexers.base import BaseDocumentIndex
from relik.retriever.indexers.faiss import FaissDocumentIndex

logger = get_logger(level=logging.INFO)


def compute_retriever_stats(dataset) -> None:
    """Compute and print candidate recall against the gold window labels."""
    correct, total = 0, 0
    for sample in dataset:
        window_candidates = sample["window_candidates"]
        window_candidates = [c.replace("_", " ").lower() for c in window_candidates]

        for ss, se, label in sample["window_labels"]:
            # skip mentions without a gold entity
            if label == "--NME--":
                continue
            if label.replace("_", " ").lower() in window_candidates:
                correct += 1
            total += 1

    recall = correct / total
    print("Recall:", recall)


@torch.no_grad()
def add_candidates(
    retriever_name_or_path: Union[str, os.PathLike],
    document_index_name_or_path: Union[str, os.PathLike],
    input_path: Union[str, os.PathLike],
    batch_size: int = 128,
    num_workers: int = 4,
    index_type: str = "Flat",
    nprobe: int = 1,
    device: str = "cpu",
    precision: str = "fp32",
    topics: bool = False,
):
    """Retrieve the top candidates for each window in `input_path` and report recall."""
    document_index = BaseDocumentIndex.from_pretrained(
        document_index_name_or_path,
        # config_kwargs={
        #     "_target_": "relik.retriever.indexers.faiss.FaissDocumentIndex",
        #     "index_type": index_type,
        #     "nprobe": nprobe,
        # },
        device=device,
        precision=precision,
    )

    retriever = GoldenRetriever(
        question_encoder=retriever_name_or_path,
        document_index=document_index,
        device=device,
        precision=precision,
        index_device=device,
        index_precision=precision,
    )
    retriever.eval()

    logger.info(f"Loading from {input_path}")
    with open(input_path) as f:
        samples = [json.loads(line) for line in f.readlines()]

    # only use document topics if requested and present in the data
    topics = topics and "doc_topic" in samples[0]

    # get tokenizer
    tokenizer = retriever.question_tokenizer
    collate_fn = lambda batch: ModelInputs(
        tokenizer(
            [b["text"] for b in batch],
            text_pair=[b["doc_topic"] for b in batch] if topics else None,
            padding=True,
            return_tensors="pt",
            truncation=True,
        )
    )

    logger.info(f"Creating dataloader with batch size {batch_size}")
    dataloader = torch.utils.data.DataLoader(
        BaseDataset(name="passage", data=samples),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=False,
        collate_fn=collate_fn,
    )

    # we also dump the candidates to a file after a while
    retrieved_accumulator = []
    with torch.inference_mode():
        start = time.time()
        num_completed_docs = 0

        for documents_batch in tqdm.tqdm(dataloader):
            retrieve_kwargs = {
                **documents_batch,
                "k": 100,
                "precision": precision,
            }
            batch_out = retriever.retrieve(**retrieve_kwargs)
            retrieved_accumulator.extend(batch_out)

        end = time.time()

        output_data = []
        # get the correct document from the original dataset
        # the dataloader is not shuffled, so we can just count the number of
        # documents we have seen so far
        for sample, retrieved in zip(
            samples[
                num_completed_docs : num_completed_docs + len(retrieved_accumulator)
            ],
            retrieved_accumulator,
        ):
            candidate_titles = [c.label.split(" ", 1)[0] for c in retrieved]
            sample["window_candidates"] = candidate_titles
            sample["window_candidates_scores"] = [c.score for c in retrieved]
            output_data.append(sample)

        # for sample in output_data:
        #     f_out.write(json.dumps(sample) + "\n")

        num_completed_docs += len(retrieved_accumulator)
        retrieved_accumulator = []

    compute_retriever_stats(output_data)
    print(f"Retrieval took {end - start:.2f} seconds")
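
# The function below is an optional sketch, not part of the original script: a
# command-line entry point mirroring the commented-out argparse code in the
# __main__ guard, restricted to the arguments `add_candidates` actually accepts
# (the commented parser also lists --output_path and --index_device, which the
# function does not take). Call `cli()` instead of the hard-coded invocation
# below if command-line usage is preferred.
def cli() -> None:
    parser = argparse.ArgumentParser(
        description="Retrieve candidates for a windowed JSONL dataset and report recall."
    )
    parser.add_argument("--retriever_name_or_path", type=str, required=True)
    parser.add_argument("--document_index_name_or_path", type=str, required=True)
    parser.add_argument("--input_path", type=str, required=True)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--index_type", type=str, default="Flat")
    parser.add_argument("--nprobe", type=int, default=1)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--precision", type=str, default="fp32")
    parser.add_argument("--topics", action="store_true")
    add_candidates(**vars(parser.parse_args()))
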
"__main__": # arg_parser = argparse.ArgumentParser() # arg_parser.add_argument("--retriever_name_or_path", type=str, required=True) # arg_parser.add_argument("--document_index_name_or_path", type=str, required=True) # arg_parser.add_argument("--input_path", type=str, required=True) # arg_parser.add_argument("--output_path", type=str, required=True) # arg_parser.add_argument("--batch_size", type=int, default=128) # arg_parser.add_argument("--device", type=str, default="cuda") # arg_parser.add_argument("--index_device", type=str, default="cpu") # arg_parser.add_argument("--precision", type=str, default="fp32") # add_candidates(**vars(arg_parser.parse_args())) add_candidates( "/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder", "/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index_filtered", "/root/relik-spaces/data/reader/aida/testa_windowed.jsonl", # index_type="HNSW32", # index_type="IVF1024,PQ8", # nprobe=1, topics=True, device="cuda", )