yrobel-lima committed on
Commit
5c4df90
1 Parent(s): 956157f

Delete utils/update_vector_database.py

Browse files
Files changed (1) hide show
  1. utils/update_vector_database.py +0 -258
utils/update_vector_database.py DELETED
@@ -1,258 +0,0 @@
1
- import json
2
- import os
3
- import sys
4
- from functools import cache
5
- from pathlib import Path
6
-
7
- import torch
8
- from langchain_community.retrievers import QdrantSparseVectorRetriever
9
- from langchain_community.vectorstores import Qdrant
10
- from langchain_core.documents import Document
11
- from langchain_openai.embeddings import OpenAIEmbeddings
12
- from qdrant_client import QdrantClient, models
13
- from transformers import AutoModelForMaskedLM, AutoTokenizer
14
-
15
- from data_processing import excel_to_dataframe
16
-
17
-
18
class DataProcessor:
    """Turn the raw files found in ``data_dir`` into LangChain ``Document`` lists.

    Two sources are handled: the practitioners spreadsheet (read through the
    project helper ``excel_to_dataframe``) and a single JSON file of
    key/value entries.
    """

    def __init__(self, data_dir: Path):
        """Remember the directory that holds the source files."""
        self.data_dir = data_dir

    def load_practitioners_data(self):
        """Return one ``Document`` per row of the practitioners table.

        Every row is flattened into ``"column: value"`` pairs joined by
        ``". "``; the row's index is kept under ``metadata["row"]``.

        Exits the process (``sys.exit``) when ``excel_to_dataframe`` raises
        ``FileNotFoundError``.
        """
        # Only the loader call is guarded, so the handler cannot mask an
        # unrelated error raised while building the documents.
        try:
            # presumably a pandas DataFrame (it is iterated with .iterrows())
            df = excel_to_dataframe(self.data_dir)
        except FileNotFoundError:
            sys.exit(
                "Directory or Excel file not found. Please check the path and try again."
            )

        return [
            Document(
                # A dot is used as the field separator for text embeddings.
                page_content=". ".join(
                    f"{key}: {value}" for key, value in row.items()
                ),
                metadata={"row": idx},
            )
            for idx, row in df.iterrows()
        ]

    def load_tall_tree_data(self):
        """Load the single ``.json`` file in ``data_dir`` as a list of documents.

        Raises:
            FileNotFoundError: no entry with a ``.json`` suffix exists.
            ValueError: more than one ``.json`` entry exists, or the file is
                not valid JSON.
        """
        json_files = [
            entry for entry in self.data_dir.iterdir() if entry.suffix == ".json"
        ]

        if not json_files:
            raise FileNotFoundError("No JSON files found in the specified directory.")
        if len(json_files) > 1:
            raise ValueError(
                "More than one JSON file found in the specified directory."
            )

        data = self.load_json_file(json_files[0])
        return self.process_json_data(data)

    def load_json_file(self, path):
        """Parse ``path`` as JSON and return the decoded object.

        The file is read as UTF-8 explicitly so the result does not depend on
        the platform's default encoding.

        Raises:
            ValueError: the content is not valid JSON; the underlying
                ``json.JSONDecodeError`` is attached as ``__cause__``.
        """
        try:
            with open(path, "r", encoding="utf-8") as f:
                return json.load(f)
        except json.JSONDecodeError as err:
            raise ValueError(f"The file {path} is not a valid JSON file.") from err

    def process_json_data(self, data):
        """Convert a decoded JSON object into one ``Document`` per key.

        Each document's text is ``"key: value"`` and its position in the
        mapping is stored under ``metadata["row"]``.

        Raises:
            ValueError: ``data`` is not a JSON object (``dict``).
        """
        # Fail with a clear message instead of an AttributeError on `.items()`.
        if not isinstance(data, dict):
            raise ValueError(
                f"Expected a JSON object at the top level, got {type(data).__name__}."
            )

        return [
            Document(page_content=f"{key}: {value}", metadata={"row": idx})
            for idx, (key, value) in enumerate(data.items())
        ]
71
-
72
-
73
class ValidateQdrantClient:
    """Shared base that refuses to initialise unless the Qdrant settings exist."""

    def __init__(self):
        self.validate_environment_variables()

    def validate_environment_variables(self):
        """Raise ``EnvironmentError`` naming every Qdrant variable that is unset.

        A variable that is set to an empty string is reported as missing too.
        """
        absent = []
        for name in ("QDRANT_API_KEY", "QDRANT_URL"):
            if os.getenv(name):
                continue
            absent.append(name)

        if not absent:
            return

        listing = ", ".join(absent)
        raise EnvironmentError(f"Missing environment variable(s): {listing}")
87
-
88
-
89
class DenseVectorStore(ValidateQdrantClient):
    """Embed documents with an OpenAI model and upload them to a Qdrant collection."""

    # Embedding model names accepted by the constructor.
    TEXT_EMBEDDING_MODELS = [
        "text-embedding-ada-002",
        "text-embedding-3-small",
        "text-embedding-3-large",
    ]

    def __init__(
        self,
        documents: list[Document],
        embeddings_model: str = "text-embedding-3-small",
        collection_name: str = "practitioners_db",
    ):
        """Validate the environment and the model name; nothing is uploaded yet.

        Raises:
            EnvironmentError: a Qdrant environment variable is missing
                (checked first, by the base class).
            ValueError: ``embeddings_model`` is not in ``TEXT_EMBEDDING_MODELS``.
        """
        super().__init__()

        model_is_known = embeddings_model in self.TEXT_EMBEDDING_MODELS
        if not model_is_known:
            valid_options = ", ".join(self.TEXT_EMBEDDING_MODELS)
            raise ValueError(
                f"Invalid embeddings model: {embeddings_model}. Valid options are {valid_options}."
            )

        self.collection_name = collection_name
        self.embeddings_model = embeddings_model
        self.documents = documents
        self._qdrant_db = None  # filled in on first access of `qdrant_db`

    @property
    def qdrant_db(self):
        """Return the Qdrant vector store, building it on first access.

        The first read calls ``Qdrant.from_documents`` with
        ``force_recreate=True``; later reads return the stored object.
        """
        if self._qdrant_db is not None:
            return self._qdrant_db

        embedder = OpenAIEmbeddings(model=self.embeddings_model)
        store = Qdrant.from_documents(
            self.documents,
            embedder,
            url=os.getenv("QDRANT_URL"),
            api_key=os.getenv("QDRANT_API_KEY"),
            prefer_grpc=True,
            collection_name=self.collection_name,
            force_recreate=True,
        )
        self._qdrant_db = store
        return store
127
-
128
-
129
class SparseVectorStore(ValidateQdrantClient):
    """Store sparse vectors in Qdrant vector database using SPLADE neural retrieval model."""

    def __init__(
        self,
        documents: list[Document],
        collection_name: str,
        vector_name: str,
        k: int = 4,
        splade_model_id: str = "naver/splade-cocondenser-ensembledistil",
    ):
        """Recreate the sparse collection and upload ``documents`` into it.

        Args:
            documents: Documents to encode and store.
            collection_name: Qdrant collection to (re)create.
            vector_name: Name of the sparse vector inside the collection.
            k: Passed to the retriever as its ``k`` setting.
            splade_model_id: Model id handed to the ``transformers`` loaders.

        Raises:
            EnvironmentError: a Qdrant environment variable is missing.
        """
        # Validate Qdrant client
        super().__init__()
        self.client = QdrantClient(
            url=os.getenv("QDRANT_URL"),
            api_key=os.getenv("QDRANT_API_KEY"),
        )  # TODO: prefer_grpc=True is not working
        self.model_id = splade_model_id
        self._tokenizer = None
        self._model = None
        self.collection_name = collection_name
        self.vector_name = vector_name
        self.k = k
        self.sparse_retriever = self.create_sparse_retriever()
        self.add_documents(documents)

    # The tokenizer and model are memoised on the instance. A module-level
    # `functools.cache` on top of that would be redundant and would keep every
    # instance (and its loaded model) alive for the life of the process.
    @property
    def tokenizer(self):
        """Return the tokenizer, loading it on first access."""
        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        return self._tokenizer

    @property
    def model(self):
        """Return the SPLADE masked-LM model, loading it on first access."""
        if self._model is None:
            self._model = AutoModelForMaskedLM.from_pretrained(self.model_id)
        return self._model

    def sparse_encoder(self, text: str) -> tuple[list[int], list[float]]:
        """Encode the input text into a sparse vector.

        Returns:
            ``(indices, values)``: the positions of the non-zero weights and
            the weights themselves. Both are always lists, including when zero
            or exactly one weight is non-zero.
        """
        tokens = self.tokenizer(
            text,
            return_tensors="pt",
            max_length=512,
            padding="max_length",
            truncation=True,
        )

        with torch.no_grad():
            logits = self.model(**tokens).logits

        # log(1 + relu(logit)), with padded positions zeroed by the mask.
        relu_log = torch.log1p(torch.relu(logits))
        weighted_log = relu_log * tokens.attention_mask.unsqueeze(-1)

        # Max over dim 1, then drop only the leading dimension.
        # assumes a single input text, i.e. a leading batch dimension of 1 —
        # TODO confirm against the tokenizer output.
        max_val = torch.max(weighted_log, dim=1).values.squeeze(0)

        # `nonzero` yields shape (count, 1). Squeezing only the last dimension
        # keeps a 1-D tensor even when count == 1; a bare `.squeeze()` would
        # collapse it to a scalar and break the list return type.
        indices = torch.nonzero(max_val, as_tuple=False).squeeze(-1)
        values = max_val[indices]
        return indices.cpu().tolist(), values.cpu().tolist()

    def create_sparse_retriever(self):
        """Recreate the collection with one sparse vector and return a retriever on it."""
        self.client.recreate_collection(
            self.collection_name,
            vectors_config={},
            sparse_vectors_config={
                self.vector_name: models.SparseVectorParams(
                    index=models.SparseIndexParams(
                        on_disk=False,
                    )
                )
            },
        )

        return QdrantSparseVectorRetriever(
            client=self.client,
            collection_name=self.collection_name,
            sparse_vector_name=self.vector_name,
            sparse_encoder=self.sparse_encoder,
            k=self.k,
        )

    def add_documents(self, documents):
        """Encode and upload ``documents`` through the sparse retriever."""
        self.sparse_retriever.add_documents(documents)
216
-
217
-
218
def main():
    """Rebuild the Qdrant collections from the files in the ``data`` directory.

    Steps: load the practitioners and Tall Tree datasets, upload both as dense
    vectors, then upload the practitioners dataset as sparse vectors.
    Exits the process if the data directory does not exist.
    """
    # NOTE: resolved against the current working directory (its parent's
    # "data" folder), not against this file's location.
    data_dir = Path().resolve().parent / "data"
    if not data_dir.exists():
        sys.exit(f"The directory {data_dir} does not exist.")

    processor = DataProcessor(data_dir)

    print("Loading and cleaning Practitioners data...")
    practitioners_dataset = processor.load_practitioners_data()

    print("Loading Tall Tree data from json file...")
    tall_tree_dataset = processor.load_tall_tree_data()

    # Set OpenAI embeddings model
    # TODO: Test new OpenAI text embeddings models
    # text-embedding-3-large
    # text-embedding-3-small
    embeddings_model = "text-embedding-3-small"

    # Store both datasets in Qdrant
    print(f"Storing dense vectors in Qdrant using {embeddings_model}...")
    dense_targets = (
        (practitioners_dataset, "practitioners_db"),
        (tall_tree_dataset, "tall_tree_db"),
    )
    for dataset, collection_name in dense_targets:
        store = DenseVectorStore(
            dataset, embeddings_model, collection_name=collection_name
        )
        # Reading the lazy property is what builds and uploads the collection;
        # the returned handle itself is not needed here.
        _ = store.qdrant_db

    print("Storing sparse vectors in Qdrant using SPLADE neural retrieval model...")
    # The constructor recreates the collection and uploads the documents.
    SparseVectorStore(
        documents=practitioners_dataset,
        collection_name="practitioners_db_sparse_collection",
        vector_name="sparse_vector",
        k=15,
        splade_model_id="naver/splade-cocondenser-ensembledistil",
    )


# Run the update only when executed as a script, not on import.
if __name__ == "__main__":
    main()