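"""Attach retrieved entity candidates to a windowed JSONL dataset.

Encodes each window with a GoldenRetriever question encoder, retrieves the
top candidates from a pre-built document index, stores the candidate titles
and scores on every sample, and reports the recall of the gold labels.
"""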
import argparse
import json
import logging
import os
import time
from typing import Union

import torch
import tqdm

from relik.retriever import GoldenRetriever
from relik.common.log import get_logger
from relik.retriever.common.model_inputs import ModelInputs
from relik.retriever.data.base.datasets import BaseDataset
from relik.retriever.indexers.base import BaseDocumentIndex
from relik.retriever.indexers.faiss import FaissDocumentIndex

logger = get_logger(level=logging.INFO)


def compute_retriever_stats(dataset) -> None:
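    """Print the recall of gold entity labels among each window's candidates."""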
    correct, total = 0, 0
    for sample in dataset:
        window_candidates = sample["window_candidates"]
        window_candidates = [c.replace("_", " ").lower() for c in window_candidates]

        for ss, se, label in sample["window_labels"]:
            if label == "--NME--":
                continue
            if label.replace("_", " ").lower() in window_candidates:
                correct += 1
            total += 1

    # guard against empty datasets (or datasets with only --NME-- labels)
    recall = correct / total if total > 0 else 0.0
    print("Recall:", recall)


@torch.no_grad()
def add_candidates(
    retriever_name_or_path: Union[str, os.PathLike],
    document_index_name_or_path: Union[str, os.PathLike],
    input_path: Union[str, os.PathLike],
    batch_size: int = 128,
    num_workers: int = 4,
    index_type: str = "Flat",
    nprobe: int = 1,
    device: str = "cpu",
    precision: str = "fp32",
    topics: bool = False,
    output_path: Union[str, os.PathLike, None] = None,
):
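    """Retrieve the top candidates for every window in ``input_path``.

    Candidate titles and scores are attached to each sample in place, recall
    against the gold labels is printed, and the augmented samples are
    optionally written to ``output_path``.
    """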
    document_index = BaseDocumentIndex.from_pretrained(
        document_index_name_or_path,
        # config_kwargs={
        #     "_target_": "relik.retriever.indexers.faiss.FaissDocumentIndex",
        #     "index_type": index_type,
        #     "nprobe": nprobe,
        # },
        device=device,
        precision=precision,
    )
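    # NOTE: index_type and nprobe only take effect through the commented-out
    # FaissDocumentIndex config above; the plain pretrained index ignores them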

    retriever = GoldenRetriever(
        question_encoder=retriever_name_or_path,
        document_index=document_index,
        device=device,
        precision=precision,
        index_device=device,
        index_precision=precision,
    )
    retriever.eval()

    logger.info(f"Loading from {input_path}")
    with open(input_path) as f:
        samples = [json.loads(line) for line in f]

    # pair questions with document topics only if requested and present in the data
    topics = topics and "doc_topic" in samples[0]

    # batch windows into padded tensors with the retriever's question tokenizer
    tokenizer = retriever.question_tokenizer

    def collate_fn(batch):
        return ModelInputs(
            tokenizer(
                [b["text"] for b in batch],
                text_pair=[b["doc_topic"] for b in batch] if topics else None,
                padding=True,
                return_tensors="pt",
                truncation=True,
            )
        )
    logger.info(f"Creating dataloader with batch size {batch_size}")
    dataloader = torch.utils.data.DataLoader(
        BaseDataset(name="passage", data=samples),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=False,
        collate_fn=collate_fn,
    )

    # accumulate the retrieved candidates across all batches
    retrieved_accumulator = []
    with torch.inference_mode():
        start = time.time()

        for documents_batch in tqdm.tqdm(dataloader):
            # retrieve the top-100 candidates for each window in the batch
            batch_out = retriever.retrieve(
                **documents_batch, k=100, precision=precision
            )
            retrieved_accumulator.extend(batch_out)

        end = time.time()

        output_data = []
        # the dataloader is not shuffled, so the retrieved results align
        # one-to-one with the input samples
        for sample, retrieved in zip(samples, retrieved_accumulator):
            candidate_titles = [c.label.split(" <def>", 1)[0] for c in retrieved]
            sample["window_candidates"] = candidate_titles
            sample["window_candidates_scores"] = [c.score for c in retrieved]
            output_data.append(sample)

        # optionally dump the augmented samples to a JSONL file
        if output_path is not None:
            logger.info(f"Writing candidates to {output_path}")
            with open(output_path, "w") as f_out:
                for sample in output_data:
                    f_out.write(json.dumps(sample) + "\n")

    compute_retriever_stats(output_data)
    print(f"Retrieval took {end - start:.2f} seconds")


if __name__ == "__main__":
    # arg_parser = argparse.ArgumentParser()
    # arg_parser.add_argument("--retriever_name_or_path", type=str, required=True)
    # arg_parser.add_argument("--document_index_name_or_path", type=str, required=True)
    # arg_parser.add_argument("--input_path", type=str, required=True)
    # arg_parser.add_argument("--output_path", type=str, required=True)
    # arg_parser.add_argument("--batch_size", type=int, default=128)
    # arg_parser.add_argument("--device", type=str, default="cuda")
    # arg_parser.add_argument("--index_device", type=str, default="cpu")
    # arg_parser.add_argument("--precision", type=str, default="fp32")

    # add_candidates(**vars(arg_parser.parse_args()))
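    # example invocation with hard-coded local paths; uncomment the argparse
    # block above to drive the script from the command line instead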
    add_candidates(
        "/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder",
        "/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index_filtered",
        "/root/relik-spaces/data/reader/aida/testa_windowed.jsonl",
        # index_type="HNSW32",
        # index_type="IVF1024,PQ8",
        # nprobe=1,
        topics=True,
        device="cuda",
    )