# mColBERT: colbert/ranking/retrieval.py
import os
import time
import faiss
import random
import torch
import itertools
from colbert.utils.runs import Run
from multiprocessing import Pool
from colbert.modeling.inference import ModelInference
from colbert.evaluation.ranking_logger import RankingLogger
from colbert.utils.utils import print_message, batch
from colbert.ranking.rankers import Ranker
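

# retrieve(): encode each query with ColBERT and rank candidate passages pulled from the
# FAISS index (up to args.faiss_depth), logging the top args.depth results per query.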
def retrieve(args):
    inference = ModelInference(args.colbert, amp=args.amp)
    ranker = Ranker(args, inference, faiss_depth=args.faiss_depth)

    ranking_logger = RankingLogger(Run.path, qrels=None)

    milliseconds = 0
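
    # Stream rankings to 'ranking.tsv' under the current Run path via the RankingLogger context.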
    with ranking_logger.context('ranking.tsv', also_save_annotations=False) as rlogger:
        queries = args.queries
        qids_in_order = list(queries.keys())

        for qoffset, qbatch in batch(qids_in_order, 100, provide_offset=True):
            qbatch_text = [queries[qid] for qid in qbatch]

            rankings = []
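
            # Encode and rank one query at a time, timing each end-to-end pass on the GPU
            # (the synchronize calls bracket the measurement) and printing a running average.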
            for query_idx, q in enumerate(qbatch_text):
                torch.cuda.synchronize('cuda:0')
                s = time.time()

                Q = ranker.encode([q])
                pids, scores = ranker.rank(Q)

                torch.cuda.synchronize()
                milliseconds += (time.time() - s) * 1000.0

                if len(pids):
                    print(qoffset+query_idx, q, len(scores), len(pids), scores[0], pids[0],
                          milliseconds / (qoffset+query_idx+1), 'ms')

                rankings.append(zip(pids, scores))
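
            # Truncate each ranking to the top args.depth results and log them for the corresponding qid.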
            for query_idx, (qid, ranking) in enumerate(zip(qbatch, rankings)):
                query_idx = qoffset + query_idx

                if query_idx % 100 == 0:
                    print_message(f"#> Logging query #{query_idx} (qid {qid}) now...")

                ranking = [(score, pid, None) for pid, score in itertools.islice(ranking, args.depth)]
                rlogger.log(qid, ranking, is_ranked=True)

    print('\n\n')
    print(ranking_logger.filename)
    print("#> Done.")
    print('\n\n')