CarlosMalaga
committed on
Delete examples
- examples/explore_faiss.md +0 -8
- examples/explore_faiss.py +0 -163
- examples/train_retriever.py +0 -45
examples/explore_faiss.md
DELETED
@@ -1,8 +0,0 @@
-# table to store results
-
-| Index | nprobe | Recall | Time |
-|----------------|--------|--------|-------|
-| Flat | 1 | 98.7 | 38.64 |
-| IVFx,Flat | 1 | 42.5 | 23.46 |
-| IVFx,Flat | 14 | 88.5 | 133 |
-| IVFx_HNSW,Flat | 1 | 88.5 | 133 |
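The deleted table compares FAISS index types by recall and retrieval time. As a point of reference only (not part of the deleted files), below is a minimal sketch of how such index variants can be built and timed with FAISS directly; the dimensionality, the random vectors, and the concrete factory strings ("IVF1024,Flat", "IVF1024_HNSW32,Flat" in place of the "IVFx" shorthand) are assumptions.

```python
# Sketch: build and time the index variants from the table above with FAISS.
# Assumptions: 768-dim embeddings and random vectors standing in for the real
# passage/query embeddings; "IVF1024" stands in for the "IVFx" shorthand.
import time

import faiss
import numpy as np

d = 768
xb = np.random.rand(100_000, d).astype("float32")  # passage embeddings
xq = np.random.rand(1_000, d).astype("float32")    # query embeddings

for factory_string, nprobe in [
    ("Flat", 1),
    ("IVF1024,Flat", 1),
    ("IVF1024,Flat", 14),
    ("IVF1024_HNSW32,Flat", 1),
]:
    index = faiss.index_factory(d, factory_string)
    if not index.is_trained:
        index.train(xb)  # IVF indexes need a training pass on the corpus
    index.add(xb)
    if "IVF" in factory_string:
        # inverted lists visited per query: higher = better recall, slower search
        faiss.extract_index_ivf(index).nprobe = nprobe
    start = time.time()
    distances, ids = index.search(xq, 100)  # top-100 neighbours per query
    print(f"{factory_string} nprobe={nprobe}: {time.time() - start:.2f}s")
```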
examples/explore_faiss.py
DELETED
@@ -1,163 +0,0 @@
-import argparse
-import json
-import logging
-import os
-from pathlib import Path
-import time
-from typing import Union
-
-import torch
-import tqdm
-
-from relik.retriever import GoldenRetriever
-from relik.common.log import get_logger
-from relik.retriever.common.model_inputs import ModelInputs
-from relik.retriever.data.base.datasets import BaseDataset
-from relik.retriever.indexers.base import BaseDocumentIndex
-from relik.retriever.indexers.faiss import FaissDocumentIndex
-
-logger = get_logger(level=logging.INFO)
-
-
-def compute_retriever_stats(dataset) -> None:
-    correct, total = 0, 0
-    for sample in dataset:
-        window_candidates = sample["window_candidates"]
-        window_candidates = [c.replace("_", " ").lower() for c in window_candidates]
-
-        for ss, se, label in sample["window_labels"]:
-            if label == "--NME--":
-                continue
-            if label.replace("_", " ").lower() in window_candidates:
-                correct += 1
-            total += 1
-
-    recall = correct / total
-    print("Recall:", recall)
-
-
-@torch.no_grad()
-def add_candidates(
-    retriever_name_or_path: Union[str, os.PathLike],
-    document_index_name_or_path: Union[str, os.PathLike],
-    input_path: Union[str, os.PathLike],
-    batch_size: int = 128,
-    num_workers: int = 4,
-    index_type: str = "Flat",
-    nprobe: int = 1,
-    device: str = "cpu",
-    precision: str = "fp32",
-    topics: bool = False,
-):
-    document_index = BaseDocumentIndex.from_pretrained(
-        document_index_name_or_path,
-        # config_kwargs={
-        #     "_target_": "relik.retriever.indexers.faiss.FaissDocumentIndex",
-        #     "index_type": index_type,
-        #     "nprobe": nprobe,
-        # },
-        device=device,
-        precision=precision,
-    )
-
-    retriever = GoldenRetriever(
-        question_encoder=retriever_name_or_path,
-        document_index=document_index,
-        device=device,
-        precision=precision,
-        index_device=device,
-        index_precision=precision,
-    )
-    retriever.eval()
-
-    logger.info(f"Loading from {input_path}")
-    with open(input_path) as f:
-        samples = [json.loads(line) for line in f.readlines()]
-
-    topics = topics and "doc_topic" in samples[0]
-
-    # get tokenizer
-    tokenizer = retriever.question_tokenizer
-    collate_fn = lambda batch: ModelInputs(
-        tokenizer(
-            [b["text"] for b in batch],
-            text_pair=[b["doc_topic"] for b in batch] if topics else None,
-            padding=True,
-            return_tensors="pt",
-            truncation=True,
-        )
-    )
-    logger.info(f"Creating dataloader with batch size {batch_size}")
-    dataloader = torch.utils.data.DataLoader(
-        BaseDataset(name="passage", data=samples),
-        batch_size=batch_size,
-        shuffle=False,
-        num_workers=num_workers,
-        pin_memory=False,
-        collate_fn=collate_fn,
-    )
-
-    # we also dump the candidates to a file after a while
-    retrieved_accumulator = []
-    with torch.inference_mode():
-        start = time.time()
-        num_completed_docs = 0
-
-        for documents_batch in tqdm.tqdm(dataloader):
-            retrieve_kwargs = {
-                **documents_batch,
-                "k": 100,
-                "precision": precision,
-            }
-            batch_out = retriever.retrieve(**retrieve_kwargs)
-            retrieved_accumulator.extend(batch_out)
-
-        end = time.time()
-
-        output_data = []
-        # get the correct document from the original dataset
-        # the dataloader is not shuffled, so we can just count the number of
-        # documents we have seen so far
-        for sample, retrieved in zip(
-            samples[
-                num_completed_docs : num_completed_docs + len(retrieved_accumulator)
-            ],
-            retrieved_accumulator,
-        ):
-            candidate_titles = [c.label.split(" <def>", 1)[0] for c in retrieved]
-            sample["window_candidates"] = candidate_titles
-            sample["window_candidates_scores"] = [c.score for c in retrieved]
-            output_data.append(sample)
-
-        # for sample in output_data:
-        #     f_out.write(json.dumps(sample) + "\n")
-
-        num_completed_docs += len(retrieved_accumulator)
-        retrieved_accumulator = []
-
-    compute_retriever_stats(output_data)
-    print(f"Retrieval took {end - start:.2f} seconds")
-
-
-if __name__ == "__main__":
-    # arg_parser = argparse.ArgumentParser()
-    # arg_parser.add_argument("--retriever_name_or_path", type=str, required=True)
-    # arg_parser.add_argument("--document_index_name_or_path", type=str, required=True)
-    # arg_parser.add_argument("--input_path", type=str, required=True)
-    # arg_parser.add_argument("--output_path", type=str, required=True)
-    # arg_parser.add_argument("--batch_size", type=int, default=128)
-    # arg_parser.add_argument("--device", type=str, default="cuda")
-    # arg_parser.add_argument("--index_device", type=str, default="cpu")
-    # arg_parser.add_argument("--precision", type=str, default="fp32")
-
-    # add_candidates(**vars(arg_parser.parse_args()))
-    add_candidates(
-        "/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/question_encoder",
-        "/root/relik-spaces/models/relik-retriever-small-aida-blink-pretrain-omniencoder/document_index_filtered",
-        "/root/relik-spaces/data/reader/aida/testa_windowed.jsonl",
-        # index_type="HNSW32",
-        # index_type="IVF1024,PQ8",
-        # nprobe=1,
-        topics=True,
-        device="cuda",
-    )
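The deleted script loads the pretrained index with its default indexer; the commented-out config_kwargs show how the same call was meant to be switched to a FAISS-backed index. Below is a minimal sketch along those lines, with placeholder paths and with the override keys copied from the commented code (not verified against the current relik API).

```python
# Sketch: load the pretrained document index as a FaissDocumentIndex, following
# the commented-out config_kwargs in the deleted script. Paths are placeholders;
# the override keys are assumptions taken from that comment.
from relik.retriever import GoldenRetriever
from relik.retriever.indexers.base import BaseDocumentIndex

document_index = BaseDocumentIndex.from_pretrained(
    "path/to/document_index",  # e.g. the document_index_filtered directory above
    config_kwargs={
        "_target_": "relik.retriever.indexers.faiss.FaissDocumentIndex",
        "index_type": "IVF1024,PQ8",  # or "Flat", "HNSW32", ...
        "nprobe": 14,
    },
    device="cuda",
    precision="fp32",
)

retriever = GoldenRetriever(
    question_encoder="path/to/question_encoder",  # placeholder
    document_index=document_index,
    device="cuda",
    precision="fp32",
)
retriever.eval()
```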
examples/train_retriever.py
DELETED
@@ -1,45 +0,0 @@
-from relik.retriever.trainer import RetrieverTrainer
-from relik import GoldenRetriever
-from relik.retriever.indexers.inmemory import InMemoryDocumentIndex
-from relik.retriever.data.datasets import AidaInBatchNegativesDataset
-
-if __name__ == "__main__":
-    # instantiate retriever
-    document_index = InMemoryDocumentIndex(
-        documents="/root/golden-retriever-v2/data/dpr-like/el/definitions.txt",
-        device="cuda",
-        precision="16",
-    )
-    retriever = GoldenRetriever(
-        question_encoder="intfloat/e5-small-v2", document_index=document_index
-    )
-
-    train_dataset = AidaInBatchNegativesDataset(
-        name="aida_train",
-        path="/root/golden-retriever-v2/data/dpr-like/el/aida_32_tokens_topic/train.jsonl",
-        tokenizer=retriever.question_tokenizer,
-        question_batch_size=64,
-        passage_batch_size=400,
-        max_passage_length=64,
-        use_topics=True,
-        shuffle=True,
-    )
-    val_dataset = AidaInBatchNegativesDataset(
-        name="aida_val",
-        path="/root/golden-retriever-v2/data/dpr-like/el/aida_32_tokens_topic/val.jsonl",
-        tokenizer=retriever.question_tokenizer,
-        question_batch_size=64,
-        passage_batch_size=400,
-        max_passage_length=64,
-        use_topics=True,
-    )
-
-    trainer = RetrieverTrainer(
-        retriever=retriever,
-        train_dataset=train_dataset,
-        val_dataset=val_dataset,
-        max_steps=25_000,
-        wandb_offline_mode=True,
-    )
-
-    trainer.train()