pvanand commited on
Commit
be2f825
1 Parent(s): 7dc1282

Create embeddings.py

Browse files
Files changed (1) hide show
  1. embeddings.py +112 -0
embeddings.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import logging
4
+ from typing import List
5
+
6
+ from txtai.embeddings import Embeddings
7
+
8
+ # Set up logging
9
+ logging.basicConfig(level=logging.INFO)
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ class EmbeddingsManager:
14
+ def __init__(self, base_path: str = "./indexes", model_path: str = "avsolatorio/GIST-all-MiniLM-L6-v2"):
15
+ """
16
+ Initializes the EmbeddingsManager.
17
+
18
+ Args:
19
+ base_path (str): Base directory to store indices.
20
+ model_path (str): Path or identifier for the embeddings model.
21
+ """
22
+ self.base_path = base_path
23
+ os.makedirs(self.base_path, exist_ok=True)
24
+ self.model_path = model_path
25
+ self.embeddings = Embeddings({"path": self.model_path})
26
+ logger.info(f"Embeddings model loaded from '{self.model_path}'. Base path set to '{self.base_path}'.")
27
+
28
+ def create_index(self, index_id: str, documents: List[str]) -> None:
29
+ """
30
+ Creates a new embeddings index with the provided documents.
31
+
32
+ Args:
33
+ index_id (str): Unique identifier for the index.
34
+ documents (List[str]): List of documents to be indexed.
35
+
36
+ Raises:
37
+ ValueError: If the index already exists.
38
+ Exception: For any other errors during indexing or saving.
39
+ """
40
+ index_path = os.path.join(self.base_path, index_id)
41
+ if os.path.exists(index_path):
42
+ logger.error(f"Index with index_id '{index_id}' already exists at '{index_path}'.")
43
+ raise ValueError(f"Index with index_id '{index_id}' already exists.")
44
+
45
+ try:
46
+ # Prepare documents for txtai indexing
47
+ document_tuples = [(i, text, None) for i, text in enumerate(documents)]
48
+ self.embeddings.index(document_tuples)
49
+ logger.info(f"Documents indexed for index_id '{index_id}'.")
50
+
51
+ # Create index directory
52
+ os.makedirs(index_path, exist_ok=True)
53
+
54
+ # Save embeddings
55
+ self.embeddings.save(os.path.join(index_path, "embeddings"))
56
+ logger.info(f"Embeddings saved to '{os.path.join(index_path, 'embeddings')}'.")
57
+
58
+ # Save document list
59
+ with open(os.path.join(index_path, "document_list.json"), "w", encoding='utf-8') as f:
60
+ json.dump(documents, f, ensure_ascii=False, indent=4)
61
+ logger.info(f"Document list saved to '{os.path.join(index_path, 'document_list.json')}'.")
62
+
63
+ logger.info(f"Index '{index_id}' created and saved successfully.")
64
+ except Exception as e:
65
+ logger.error(f"Failed to create index '{index_id}': {e}")
66
+ raise Exception(f"Failed to create index '{index_id}': {e}")
67
+
68
+ def query_index(self, index_id: str, query: str, num_results: int = 5) -> List[str]:
69
+ """
70
+ Queries an existing embeddings index.
71
+
72
+ Args:
73
+ index_id (str): Unique identifier for the index to query.
74
+ query (str): The search query.
75
+ num_results (int): Number of top results to return.
76
+
77
+ Returns:
78
+ List[str]: List of top matching documents.
79
+
80
+ Raises:
81
+ FileNotFoundError: If the index does not exist.
82
+ Exception: For any other errors during querying.
83
+ """
84
+ index_path = os.path.join(self.base_path, index_id)
85
+ if not os.path.exists(index_path):
86
+ logger.error(f"Index '{index_id}' not found at '{index_path}'.")
87
+ raise FileNotFoundError(f"Index '{index_id}' not found.")
88
+
89
+ try:
90
+ # Load embeddings from the index
91
+ self.embeddings.load(os.path.join(index_path, "embeddings"))
92
+ logger.info(f"Embeddings loaded from '{os.path.join(index_path, 'embeddings')}' for index '{index_id}'.")
93
+
94
+ # Load document list
95
+ document_list_path = os.path.join(index_path, "document_list.json")
96
+ if not os.path.exists(document_list_path):
97
+ logger.error(f"Document list not found at '{document_list_path}'.")
98
+ raise FileNotFoundError(f"Document list not found for index '{index_id}'.")
99
+
100
+ with open(document_list_path, "r", encoding='utf-8') as f:
101
+ document_list = json.load(f)
102
+ logger.info(f"Document list loaded from '{document_list_path}'.")
103
+
104
+ # Perform the search
105
+ results = self.embeddings.search(query, num_results)
106
+ queried_texts = [document_list[idx[0]] for idx in results]
107
+ logger.info(f"Query executed successfully on index '{index_id}'. Retrieved {len(queried_texts)} results.")
108
+
109
+ return queried_texts
110
+ except Exception as e:
111
+ logger.error(f"Failed to query index '{index_id}': {e}")
112
+ raise Exception(f"Failed to query index '{index_id}': {e}")