geekyrakshit committed
Commit cce1c58 · Parent: 0d77bb1

add: NVEmbed2Retriever

docs/retreival/nv_embed_2.md ADDED
@@ -0,0 +1,3 @@
+ # NV-Embed-v2 Retrieval
+
+ ::: medrag_multi_modal.retrieval.nv_embed_2
medrag_multi_modal/retrieval/__init__.py CHANGED
@@ -3,6 +3,7 @@ from .colpali_retrieval import CalPaliRetriever
  from .common import SimilarityMetric
  from .contriever_retrieval import ContrieverRetriever
  from .medcpt_retrieval import MedCPTRetriever
+ from .nv_embed_2 import NVEmbed2Retriever
 
  __all__ = [
      "CalPaliRetriever",
@@ -10,4 +11,5 @@ __all__ = [
      "ContrieverRetriever",
      "SimilarityMetric",
      "MedCPTRetriever",
+     "NVEmbed2Retriever",
  ]
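With the re-export above, `NVEmbed2Retriever` becomes importable from the package root alongside the existing retrievers. A minimal sketch of the resulting import path (the model checkpoint name is reused from the examples in the new module; all values are illustrative):

```python
# Import path exposed by the updated __init__.py
from medrag_multi_modal.retrieval import NVEmbed2Retriever

# Model name taken from the docstring examples in nv_embed_2.py
retriever = NVEmbed2Retriever(model_name="nvidia/NV-Embed-v2")
```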
medrag_multi_modal/retrieval/nv_embed_2.py ADDED
@@ -0,0 +1,278 @@
+ import os
+ from typing import Optional
+
+ import safetensors
+ import torch
+ import torch.nn.functional as F
+ import weave
+ from sentence_transformers import SentenceTransformer
+
+ from ..utils import get_torch_backend, get_wandb_artifact
+ from .common import SimilarityMetric, argsort_scores, save_vector_index
+
+
+ class NVEmbed2Retriever(weave.Model):
+     """
+     `NVEmbed2Retriever` is a class for retrieving relevant text chunks from a dataset using the
+     [NV-Embed-v2](https://huggingface.co/nvidia/NV-Embed-v2) model.
+
+     This class leverages the SentenceTransformer model to encode text chunks into vector representations and
+     performs similarity-based retrieval. It supports indexing a dataset of text chunks, saving the vector index,
+     and retrieving the most relevant chunks for a given query.
+
+     Args:
+         model_name (str): The name of the pre-trained model to use for encoding.
+         vector_index (Optional[torch.Tensor]): The tensor containing the vector representations of the indexed chunks.
+         chunk_dataset (Optional[list[dict]]): The dataset of text chunks to be indexed.
+     """
+
+     model_name: str
+     _chunk_dataset: Optional[list[dict]]
+     _model: SentenceTransformer
+     _vector_index: Optional[torch.Tensor]
+
+     def __init__(
+         self,
+         model_name: str = "sentence-transformers/nvembed2-nli-v1",
+         vector_index: Optional[torch.Tensor] = None,
+         chunk_dataset: Optional[list[dict]] = None,
+     ):
+         super().__init__(model_name=model_name)
+         self._model = SentenceTransformer(
+             self.model_name,
+             trust_remote_code=True,
+             model_kwargs={"torch_dtype": torch.float16},
+             device=get_torch_backend(),
+         )
+         self._model.max_seq_length = 32768
+         self._model.tokenizer.padding_side = "right"
+         self._vector_index = vector_index
+         self._chunk_dataset = chunk_dataset
+
+     def add_eos(self, input_examples):
+         input_examples = [
+             input_example + self._model.tokenizer.eos_token
+             for input_example in input_examples
+         ]
+         return input_examples
+
+     def index(self, chunk_dataset_name: str, index_name: Optional[str] = None):
+         """
+         Indexes a dataset of text chunks and optionally saves the vector index to a file.
+
+         This method retrieves a dataset of text chunks from a Weave reference, encodes the
+         text chunks into vector representations using the NV-Embed-v2 model, and stores the
+         resulting vector index. If an index name is provided, the vector index is saved to
+         a file in the safetensors format. Additionally, if a Weights & Biases run is active,
+         the vector index file is logged as an artifact to W&B.
+
+         !!! example "Example Usage"
+             ```python
+             import weave
+             from dotenv import load_dotenv
+
+             import wandb
+             from medrag_multi_modal.retrieval import NVEmbed2Retriever
+
+             load_dotenv()
+             weave.init(project_name="ml-colabs/medrag-multi-modal")
+             wandb.init(project="medrag-multi-modal", entity="ml-colabs", job_type="nvembed2-index")
+             retriever = NVEmbed2Retriever(model_name="nvidia/NV-Embed-v2")
+             retriever.index(
+                 chunk_dataset_name="grays-anatomy-chunks:v0",
+                 index_name="grays-anatomy-nvembed2",
+             )
+             ```
+
+         Args:
+             chunk_dataset_name (str): The name of the Weave dataset containing the text chunks
+                 to be indexed.
+             index_name (Optional[str]): The name of the index artifact to be saved. If provided,
+                 the vector index is saved to a file and logged as an artifact to W&B.
+         """
+         self._chunk_dataset = weave.ref(chunk_dataset_name).get().rows
+         corpus = [row["text"] for row in self._chunk_dataset]
+         self._vector_index = self._model.encode(
+             self.add_eos(corpus), batch_size=len(corpus), normalize_embeddings=True
+         )
+         with torch.no_grad():
+             if index_name:
+                 save_vector_index(
+                     torch.from_numpy(self._vector_index),
+                     "nvembed2-index",
+                     index_name,
+                     {"model_name": self.model_name},
+                 )
+
+     @classmethod
+     def from_wandb_artifact(cls, chunk_dataset_name: str, index_artifact_address: str):
+         """
+         Creates an instance of the class from a Weights & Biases (W&B) artifact.
+
+         This method retrieves a vector index and metadata from an artifact stored in
+         Weights & Biases (wandb). It also retrieves a dataset of text chunks from a Weave
+         reference. The vector index is loaded from a safetensors file and moved to the
+         appropriate device (CPU or GPU). The text chunks are converted into a list of
+         dictionaries. The method then returns an instance of the class initialized with
+         the retrieved model name, vector index, and chunk dataset.
+
+         !!! example "Example Usage"
+             ```python
+             import weave
+             from dotenv import load_dotenv
+
+             import wandb
+             from medrag_multi_modal.retrieval import NVEmbed2Retriever
+
+             load_dotenv()
+             weave.init(project_name="ml-colabs/medrag-multi-modal")
+             retriever = NVEmbed2Retriever(model_name="nvidia/NV-Embed-v2")
+             retriever.index(
+                 chunk_dataset_name="grays-anatomy-chunks:v0",
+                 index_name="grays-anatomy-nvembed2",
+             )
+             retriever = NVEmbed2Retriever.from_wandb_artifact(
+                 chunk_dataset_name="grays-anatomy-chunks:v0",
+                 index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-nvembed2:v0",
+             )
+             ```
+
+         Args:
+             chunk_dataset_name (str): The name of the Weave dataset containing the text chunks.
+             index_artifact_address (str): The address of the W&B artifact containing the
+                 vector index.
+
+         Returns:
+             An instance of the class initialized with the retrieved model name, vector index,
+             and chunk dataset.
+         """
+         artifact_dir, metadata = get_wandb_artifact(
+             index_artifact_address, "nvembed2-index", get_metadata=True
+         )
+         with safetensors.torch.safe_open(
+             os.path.join(artifact_dir, "vector_index.safetensors"), framework="pt"
+         ) as f:
+             vector_index = f.get_tensor("vector_index")
+         device = torch.device(get_torch_backend())
+         vector_index = vector_index.to(device)
+         chunk_dataset = [dict(row) for row in weave.ref(chunk_dataset_name).get().rows]
+         return cls(
+             model_name=metadata["model_name"],
+             vector_index=vector_index,
+             chunk_dataset=chunk_dataset,
+         )
+
+     @weave.op()
+     def retrieve(
+         self,
+         query: list[str],
+         top_k: int = 2,
+         metric: SimilarityMetric = SimilarityMetric.COSINE,
+     ):
+         """
+         Retrieves the top-k most relevant chunks for a given query using the specified similarity metric.
+
+         This method encodes the input query into an embedding and computes similarity scores between
+         the query embedding and the precomputed vector index. The similarity metric can be either
+         cosine similarity or Euclidean distance. The top-k chunks with the highest similarity scores
+         are returned as a list of dictionaries, each containing a chunk and its corresponding score.
+
+         !!! example "Example Usage"
+             ```python
+             import weave
+             from dotenv import load_dotenv
+
+             import wandb
+             from medrag_multi_modal.retrieval import NVEmbed2Retriever
+
+             load_dotenv()
+             weave.init(project_name="ml-colabs/medrag-multi-modal")
+             retriever = NVEmbed2Retriever(model_name="nvidia/NV-Embed-v2")
+             retriever.index(
+                 chunk_dataset_name="grays-anatomy-chunks:v0",
+                 index_name="grays-anatomy-nvembed2",
+             )
+             retriever = NVEmbed2Retriever.from_wandb_artifact(
+                 chunk_dataset_name="grays-anatomy-chunks:v0",
+                 index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-nvembed2:v0",
+             )
+             ```
+
+         Args:
+             query (list[str]): The input query strings to search for relevant chunks.
+             top_k (int, optional): The number of top relevant chunks to retrieve.
+             metric (SimilarityMetric, optional): The similarity metric to use for scoring.
+
+         Returns:
+             list: A list of dictionaries, each containing a retrieved chunk and its relevance score.
+         """
+         device = torch.device(get_torch_backend())
+         with torch.no_grad():
+             query_embedding = self._model.encode(
+                 self.add_eos(query), normalize_embeddings=True
+             )
+             query_embedding = torch.from_numpy(query_embedding).to(device)
+             if metric == SimilarityMetric.EUCLIDEAN:
+                 scores = torch.squeeze(query_embedding @ self._vector_index.T)
+             else:
+                 scores = F.cosine_similarity(query_embedding, self._vector_index)
+             scores = scores.cpu().numpy().tolist()
+         scores = argsort_scores(scores, descending=True)[:top_k]
+         retrieved_chunks = []
+         for score in scores:
+             retrieved_chunks.append(
+                 {
+                     "chunk": self._chunk_dataset[score["original_index"]],
+                     "score": score["item"],
+                 }
+             )
+         return retrieved_chunks
+
+     @weave.op()
+     def predict(
+         self,
+         query: str,
+         top_k: int = 2,
+         metric: SimilarityMetric = SimilarityMetric.COSINE,
+     ):
+         """
+         Predicts the top-k most relevant chunks for a given query using the specified similarity metric.
+
+         This method formats the input query string by prepending an instruction prompt and then calls the
+         `retrieve` method to get the most relevant chunks. The similarity metric can be either cosine similarity
+         or Euclidean distance. The top-k chunks with the highest similarity scores are returned.
+
+         !!! example "Example Usage"
+             ```python
+             import weave
+             from dotenv import load_dotenv
+
+             import wandb
+             from medrag_multi_modal.retrieval import NVEmbed2Retriever
+
+             load_dotenv()
+             weave.init(project_name="ml-colabs/medrag-multi-modal")
+             retriever = NVEmbed2Retriever(model_name="nvidia/NV-Embed-v2")
+             retriever.index(
+                 chunk_dataset_name="grays-anatomy-chunks:v0",
+                 index_name="grays-anatomy-nvembed2",
+             )
+             retriever = NVEmbed2Retriever.from_wandb_artifact(
+                 chunk_dataset_name="grays-anatomy-chunks:v0",
+                 index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-nvembed2:v0",
+             )
+             retriever.predict(query="What are Ribosomes?")
+             ```
+
+         Args:
+             query (str): The input query string to search for relevant chunks.
+             top_k (int, optional): The number of top relevant chunks to retrieve.
+             metric (SimilarityMetric, optional): The similarity metric to use for scoring.
+
+         Returns:
+             list: A list of dictionaries, each containing a retrieved chunk and its relevance score.
+         """
+         query = [
+             f"Instruct: Given a question, retrieve passages that answer the question\nQuery: {query}"
+         ]
+         return self.retrieve(query, top_k, metric)
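Note that the `index`, `from_wandb_artifact`, and `retrieve` docstring examples above stop after loading the index; only the `predict` example shows the final retrieval call. A short sketch of that last step, reusing the artifact names from the docstrings (it assumes the chunk dataset and index artifact already exist):

```python
import weave

from medrag_multi_modal.retrieval import NVEmbed2Retriever, SimilarityMetric

weave.init(project_name="ml-colabs/medrag-multi-modal")

# Load a previously built index (artifact names reused from the docstring examples).
retriever = NVEmbed2Retriever.from_wandb_artifact(
    chunk_dataset_name="grays-anatomy-chunks:v0",
    index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-nvembed2:v0",
)

# `predict` prepends the instruction prompt before delegating to `retrieve`.
chunks = retriever.predict(query="What are Ribosomes?", top_k=2)

# `retrieve` can also be called directly with a list of pre-formatted queries.
chunks = retriever.retrieve(
    query=["What are Ribosomes?"],
    top_k=2,
    metric=SimilarityMetric.COSINE,
)

for chunk in chunks:
    print(chunk["score"], chunk["chunk"]["text"])
```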
mkdocs.yml CHANGED
@@ -82,5 +82,6 @@ nav:
  - ColPali: 'retreival/colpali.md'
  - Contriever: 'retreival/contriever.md'
  - MedCPT: 'retreival/medcpt.md'
+ - NV-Embed-v2: 'retreival/nv_embed_2.md'
 
  repo_url: https://github.com/soumik12345/medrag-multi-modal
pyproject.toml CHANGED
@@ -7,6 +7,8 @@ requires-python = ">=3.10"
  dependencies = [
      "adapters>=1.0.0",
      "bm25s[full]>=0.2.2",
+     "datasets>=3.0.1",
+     "einops>=0.8.0",
      "firerequests>=0.0.7",
      "jax[cpu]>=0.4.34",
      "pdf2image>=1.17.0",
@@ -35,12 +37,15 @@ dependencies = [
      "pdfplumber>=0.11.4",
      "semchunk>=2.2.0",
      "tiktoken>=0.8.0",
+     "sentence-transformers>=3.2.0",
  ]
 
  [project.optional-dependencies]
  core = [
      "adapters>=1.0.0",
      "bm25s[full]>=0.2.2",
+     "datasets>=3.0.1",
+     "einops>=0.8.0",
      "firerequests>=0.0.7",
      "jax[cpu]>=0.4.34",
      "marker-pdf>=0.2.17",
@@ -55,6 +60,7 @@ core = [
      "tiktoken>=0.8.0",
      "torch>=2.4.1",
      "weave>=0.51.14",
+     "sentence-transformers>=3.2.0",
  ]
 
  dev = ["pytest>=8.3.3", "isort>=5.13.2", "black>=24.10.0", "ruff>=0.6.9"]