import os
import ujson
import torch
import random

from collections import defaultdict, OrderedDict

from colbert.parameters import DEVICE
from colbert.modeling.colbert import ColBERT
from colbert.utils.utils import print_message, load_checkpoint
from colbert.evaluation.load_model import load_model
from colbert.utils.runs import Run


def _summarize_Ks(Ks):
    """Return ``(max, avg)`` of the per-query depths ``Ks``, or ``(0, 0.0)`` when empty."""
    if not Ks:
        return 0, 0.0
    return max(Ks), round(sum(Ks) / len(Ks), 2)


def load_queries(queries_path):
    """Load a queries TSV into an ``OrderedDict`` mapping integer QID -> query text.

    Each line must begin with ``qid<TAB>query``; any further tab-separated
    columns are ignored. The file is decoded as UTF-8.

    Raises:
        AssertionError: if a QID appears more than once.
    """
    queries = OrderedDict()

    print_message("#> Loading the queries from", queries_path, "...")

    with open(queries_path, encoding="utf-8") as f:
        for line in f:
            qid, query, *_ = line.strip().split('\t')
            qid = int(qid)

            assert (qid not in queries), ("Query QID", qid, "is repeated!")
            queries[qid] = query

    print_message("#> Got", len(queries), "queries. All QIDs are unique.\n")

    return queries


def load_qrels(qrels_path):
    """Load a qrels TSV into an ``OrderedDict`` mapping QID -> list of positive PIDs.

    Each line must be four integers ``qid<TAB>0<TAB>pid<TAB>1``. Repeated
    (qid, pid) pairs are collapsed into one entry.

    Returns:
        The mapping, or ``None`` when ``qrels_path`` is ``None``.

    Raises:
        AssertionError: if the second column is not 0 or the fourth is not 1.
    """
    if qrels_path is None:
        return None

    print_message("#> Loading qrels from", qrels_path, "...")

    qrels = OrderedDict()
    with open(qrels_path, mode='r', encoding="utf-8") as f:
        for line in f:
            qid, x, pid, y = map(int, line.strip().split('\t'))
            assert x == 0 and y == 1
            qrels.setdefault(qid, []).append(pid)

    # The input may list the same (qid, pid) pair more than once; keep one copy.
    for qid in qrels:
        qrels[qid] = list(set(qrels[qid]))

    # Guard the average so an empty qrels file does not divide by zero.
    if qrels:
        avg_positive = round(sum(len(pids) for pids in qrels.values()) / len(qrels), 2)
    else:
        avg_positive = 0.0

    print_message("#> Loaded qrels for", len(qrels), "unique queries with",
                  avg_positive, "positives per query on average.\n")

    return qrels


def load_topK(topK_path):
    """Load a top-K TSV of ``qid<TAB>pid<TAB>query<TAB>passage`` lines.

    Returns:
        ``(queries, topK_docs, topK_pids)``, three ``OrderedDict``s keyed by QID:
        the query text, the list of passage texts, and the list of PIDs, with the
        two lists in file order.

    Raises:
        AssertionError: if one QID appears with two different query texts, or
        if a PID is repeated within a query's list.
    """
    queries = OrderedDict()
    topK_docs = OrderedDict()
    topK_pids = OrderedDict()

    print_message("#> Loading the top-k per query from", topK_path, "...")

    with open(topK_path, encoding="utf-8") as f:
        for line_idx, line in enumerate(f):
            if line_idx and line_idx % (10*1000*1000) == 0:
                print(line_idx, end=' ', flush=True)

            # Drop the line terminator so it does not end up inside the passage text.
            qid, pid, query, passage = line.rstrip('\r\n').split('\t')
            qid, pid = int(qid), int(pid)

            assert (qid not in queries) or (queries[qid] == query)
            queries[qid] = query
            topK_docs.setdefault(qid, []).append(passage)
            topK_pids.setdefault(qid, []).append(pid)

    print()

    assert all(len(topK_pids[qid]) == len(set(topK_pids[qid])) for qid in topK_pids)

    max_K, avg_K = _summarize_Ks([len(pids) for pids in topK_pids.values()])
    print_message("#> max(Ks) =", max_K, ", avg(Ks) =", avg_K)
    print_message("#> Loaded the top-k per query for", len(queries), "unique queries.\n")

    return queries, topK_docs, topK_pids


def load_topK_pids(topK_path, qrels):
    """Load a top-K TSV of ``qid<TAB>pid<TAB>...`` lines, optionally with labels.

    Each line has 1–3 columns after ``qid`` and ``pid``. When there are at least
    two, the last one is read as a 0/1 relevance label. PIDs labelled 1 become
    that query's positives.

    Args:
        topK_path: path to the top-K file.
        qrels: positives from ``load_qrels`` (or ``None``). This is used only
            when the top-K file carries no labels.

    Returns:
        ``(topK_pids, topK_positives)``. ``topK_pids`` is a ``defaultdict(list)``
        of QID -> PIDs in file order. ``topK_positives`` is one of:
        - a dict of QID -> ``set`` of positive PIDs, when the file is annotated
          (every QID in ``topK_pids`` is present);
        - ``qrels`` as passed in, otherwise.

    Raises:
        AssertionError: on malformed column counts or labels, on repeated PIDs
        within a query, or when both ``qrels`` and file labels are supplied.
    """
    topK_pids = defaultdict(list)
    topK_positives = defaultdict(list)

    print_message("#> Loading the top-k PIDs per query from", topK_path, "...")

    with open(topK_path, encoding="utf-8") as f:
        for line_idx, line in enumerate(f):
            if line_idx and line_idx % (10*1000*1000) == 0:
                print(line_idx, end=' ', flush=True)

            qid, pid, *rest = line.strip().split('\t')
            qid, pid = int(qid), int(pid)

            topK_pids[qid].append(pid)

            assert len(rest) in [1, 2, 3]

            if len(rest) > 1:
                *_, label = rest
                label = int(label)
                assert label in [0, 1]

                if label >= 1:
                    topK_positives[qid].append(pid)

    print()

    assert all(len(topK_pids[qid]) == len(set(topK_pids[qid])) for qid in topK_pids)
    assert all(len(topK_positives[qid]) == len(set(topK_positives[qid])) for qid in topK_positives)

    # Make them sets for fast lookups later.
    topK_positives = {qid: set(pids) for qid, pids in topK_positives.items()}

    max_K, avg_K = _summarize_Ks([len(pids) for pids in topK_pids.values()])
    print_message("#> max(Ks) =", max_K, ", avg(Ks) =", avg_K)
    print_message("#> Loaded the top-k per query for", len(topK_pids), "unique queries.\n")

    if len(topK_positives) == 0:
        topK_positives = None
    else:
        assert len(topK_pids) >= len(topK_positives)

        # Queries with no labelled positives get an empty *set*, matching every other entry.
        for qid in set(topK_pids) - set(topK_positives):
            topK_positives[qid] = set()

        assert len(topK_pids) == len(topK_positives)

        avg_positive = round(sum(len(pids) for pids in topK_positives.values()) / len(topK_pids), 2)

        print_message("#> Concurrently got annotations for", len(topK_positives), "unique queries with",
                      avg_positive, "positives per query on average.\n")

    assert qrels is None or topK_positives is None, "Cannot have both qrels and an annotated top-K file!"

    if topK_positives is None:
        topK_positives = qrels

    return topK_pids, topK_positives


def load_collection(collection_path):
    """Load a collection TSV into a list of passage texts indexed by line number.

    Each line is ``pid<TAB>passage[<TAB>title...]``. When a title column is
    present, the stored text is ``title + ' | ' + passage``. Every pid must
    equal its 0-based line index. A row whose pid field is the literal ``'id'``
    (a header) skips that check but is still appended at its position.

    Raises:
        AssertionError: if a pid does not match its line index.
    """
    print_message("#> Loading collection...")

    collection = []

    with open(collection_path, encoding="utf-8") as f:
        for line_idx, line in enumerate(f):
            if line_idx % (1000*1000) == 0:
                print(f'{line_idx // 1000 // 1000}M', end=' ', flush=True)

            pid, passage, *rest = line.strip('\n\r ').split('\t')
            assert pid == 'id' or int(pid) == line_idx

            if len(rest) >= 1:
                title = rest[0]
                passage = title + ' | ' + passage

            collection.append(passage)

    print()

    return collection


def load_colbert(args, do_print=True):
    """Load a ColBERT model via ``load_model`` and report hyperparameter drift.

    For a fixed set of hyperparameters, this calls ``Run.warn`` whenever the
    value stored in ``checkpoint['arguments']`` differs from the matching
    attribute on ``args``. If the checkpoint has ``'arguments'`` and
    ``args.rank < 1``, it also pretty-prints those arguments as JSON.

    Returns:
        ``(colbert, checkpoint)`` exactly as returned by ``load_model``.
    """
    colbert, checkpoint = load_model(args, do_print)

    # TODO: If the parameters below were not specified on the command line, their *checkpoint* values should be used.
    #       I.e., not their purely (i.e., training) default values.

    if 'arguments' in checkpoint:
        checkpoint_args = checkpoint['arguments']

        for k in ['query_maxlen', 'doc_maxlen', 'dim', 'similarity', 'amp']:
            if hasattr(args, k) and k in checkpoint_args and checkpoint_args[k] != getattr(args, k):
                a, b = checkpoint_args[k], getattr(args, k)
                Run.warn(f"Got checkpoint['arguments']['{k}'] != args.{k} (i.e., {a} != {b})")

        if args.rank < 1:
            print(ujson.dumps(checkpoint_args, indent=4))

    if do_print:
        print('\n')

    return colbert, checkpoint