serverdaun commited on
Commit
b3de77b
·
1 Parent(s): 4ed0f41

Refactor imports and improve code formatting across multiple files for better readability and organization.

Browse files
app.py CHANGED
@@ -1,27 +1,24 @@
1
- import gradio as gr
2
- import os
3
  import atexit
4
  import glob
 
5
  import shutil
6
- from src.config import (
7
- DOCS_DIR,
8
- COLLECTION_NAME,
9
- EMBEDDING_MODEL_NAME,
10
- MILVUS_DB_PATH,
11
- )
12
  from src.data_loader import load_data
13
  from src.embedding_generator import (
14
  generate_document_embeddings,
15
  generate_query_embeddings,
16
  )
 
17
  from src.vector_store import (
18
- get_milvus_client,
19
  create_collection_if_not_exists,
 
20
  insert_data,
21
  search,
22
  )
23
- from src.rag_pipeline import answer_question
24
- from llama_index.embeddings.huggingface import HuggingFaceEmbedding
25
 
26
  # Initialize models and clients
27
  embedding_model = HuggingFaceEmbedding(
@@ -32,16 +29,18 @@ embedding_model = HuggingFaceEmbedding(
32
 
33
  milvus_client = get_milvus_client(MILVUS_DB_PATH)
34
 
 
35
  # --- Cleanup Function ---
36
  def cleanup_documents():
37
  """Remove all files from the documents directory."""
38
  print("Cleaning up uploaded documents...")
39
- files = glob.glob(os.path.join(DOCS_DIR, '*'))
40
  for f in files:
41
  if os.path.isfile(f):
42
  os.remove(f)
43
  print("Cleanup complete.")
44
 
 
45
  # Register the cleanup function to run on exit
46
  atexit.register(cleanup_documents)
47
 
@@ -114,4 +113,4 @@ with gr.Blocks() as demo:
114
  if __name__ == "__main__":
115
  # Ensure the documents directory exists from the start
116
  os.makedirs(DOCS_DIR, exist_ok=True)
117
- demo.launch()
 
 
 
1
  import atexit
2
  import glob
3
+ import os
4
  import shutil
5
+
6
+ import gradio as gr
7
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
8
+
9
+ from src.config import COLLECTION_NAME, DOCS_DIR, EMBEDDING_MODEL_NAME, MILVUS_DB_PATH
 
10
  from src.data_loader import load_data
11
  from src.embedding_generator import (
12
  generate_document_embeddings,
13
  generate_query_embeddings,
14
  )
15
+ from src.rag_pipeline import answer_question
16
  from src.vector_store import (
 
17
  create_collection_if_not_exists,
18
+ get_milvus_client,
19
  insert_data,
20
  search,
21
  )
 
 
22
 
23
  # Initialize models and clients
24
  embedding_model = HuggingFaceEmbedding(
 
29
 
30
  milvus_client = get_milvus_client(MILVUS_DB_PATH)
31
 
32
+
33
  # --- Cleanup Function ---
34
  def cleanup_documents():
35
  """Remove all files from the documents directory."""
36
  print("Cleaning up uploaded documents...")
37
+ files = glob.glob(os.path.join(DOCS_DIR, "*"))
38
  for f in files:
39
  if os.path.isfile(f):
40
  os.remove(f)
41
  print("Cleanup complete.")
42
 
43
+
44
  # Register the cleanup function to run on exit
45
  atexit.register(cleanup_documents)
46
 
 
113
  if __name__ == "__main__":
114
  # Ensure the documents directory exists from the start
115
  os.makedirs(DOCS_DIR, exist_ok=True)
116
+ demo.launch()
src/config.py CHANGED
@@ -20,4 +20,4 @@ If the context information is not relevant to the user's query, say "I don't kno
20
  {query}
21
 
22
  # Answer
23
- """
 
20
  {query}
21
 
22
  # Answer
23
+ """
src/data_loader.py CHANGED
@@ -12,7 +12,11 @@ def load_data(data_dir: str) -> list:
12
  A list of documents
13
  """
14
  try:
15
- loader = SimpleDirectoryReader(input_dir=data_dir, required_exts=[".pdf", ".txt", ".md", ".docx", ".doc"], recursive=True)
 
 
 
 
16
  docs = loader.load_data()
17
  return docs
18
  except Exception as e:
 
12
  A list of documents
13
  """
14
  try:
15
+ loader = SimpleDirectoryReader(
16
+ input_dir=data_dir,
17
+ required_exts=[".pdf", ".txt", ".md", ".docx", ".doc"],
18
+ recursive=True,
19
+ )
20
  docs = loader.load_data()
21
  return docs
22
  except Exception as e:
src/embedding_generator.py CHANGED
@@ -1,4 +1,5 @@
1
  from typing import Any, Generator
 
2
  import numpy as np
3
 
4
 
@@ -15,10 +16,12 @@ def batch_iterate(items: Any, batch_size: int) -> Generator[Any, None, None]:
15
  A generator of batches
16
  """
17
  for i in range(0, len(items), batch_size):
18
- yield items[i:i + batch_size]
19
 
20
 
21
- def generate_document_embeddings(documents: list[str], embedding_model: Any) -> list[bytes]:
 
 
22
  """
23
  Generate document embeddings.
24
 
@@ -49,6 +52,7 @@ def generate_document_embeddings(documents: list[str], embedding_model: Any) ->
49
  print(f"Error generating document embeddings: {e}")
50
  return []
51
 
 
52
  def generate_query_embeddings(query: str, embdding_model: Any) -> bytes:
53
  """
54
  Generate query embeddings.
@@ -73,4 +77,3 @@ def generate_query_embeddings(query: str, embdding_model: Any) -> bytes:
73
  except Exception as e:
74
  print(f"Error generating query embeddings: {e}")
75
  return None
76
-
 
1
  from typing import Any, Generator
2
+
3
  import numpy as np
4
 
5
 
 
16
  A generator of batches
17
  """
18
  for i in range(0, len(items), batch_size):
19
+ yield items[i : i + batch_size]
20
 
21
 
22
+ def generate_document_embeddings(
23
+ documents: list[str], embedding_model: Any
24
+ ) -> list[bytes]:
25
  """
26
  Generate document embeddings.
27
 
 
52
  print(f"Error generating document embeddings: {e}")
53
  return []
54
 
55
+
56
  def generate_query_embeddings(query: str, embdding_model: Any) -> bytes:
57
  """
58
  Generate query embeddings.
 
77
  except Exception as e:
78
  print(f"Error generating query embeddings: {e}")
79
  return None
 
src/rag_pipeline.py CHANGED
@@ -1,9 +1,12 @@
1
- from langchain_core.messages import HumanMessage
2
  from langchain.chat_models import init_chat_model
3
- from .config import PROMPT, MODEL_NAME, TEMPERATURE, MODEL_PROVIDER
 
 
4
 
 
 
 
5
 
6
- llm = init_chat_model(MODEL_NAME, model_provider=MODEL_PROVIDER, temperature=TEMPERATURE)
7
 
8
  def answer_question(query: str, contexts: list[str]) -> str:
9
  """
 
 
1
  from langchain.chat_models import init_chat_model
2
+ from langchain_core.messages import HumanMessage
3
+
4
+ from .config import MODEL_NAME, MODEL_PROVIDER, PROMPT, TEMPERATURE
5
 
6
+ llm = init_chat_model(
7
+ MODEL_NAME, model_provider=MODEL_PROVIDER, temperature=TEMPERATURE
8
+ )
9
 
 
10
 
11
  def answer_question(query: str, contexts: list[str]) -> str:
12
  """
src/vector_store.py CHANGED
@@ -1,4 +1,4 @@
1
- from pymilvus import MilvusClient, DataType
2
 
3
 
4
  def get_milvus_client(db_path: str) -> MilvusClient:
@@ -14,12 +14,15 @@ def get_milvus_client(db_path: str) -> MilvusClient:
14
  try:
15
  client = MilvusClient(db_path)
16
  return client
17
-
18
  except Exception as e:
19
  print(f"Error getting Milvus client: {e}")
20
  return None
21
 
22
- def create_collection_if_not_exists(client: MilvusClient, collection_name: str, dim: int) -> None:
 
 
 
23
  """
24
  Create a collection in Milvus if it does not exist.
25
 
@@ -63,8 +66,8 @@ def create_collection_if_not_exists(client: MilvusClient, collection_name: str,
63
  index_params.add_index(
64
  field_name="binary_vector",
65
  index_name="binary_vector_index",
66
- index_type="BIN_FLAT", # Exact search for binary vectors
67
- metric_type="HAMMING", # Hamming distance for binary vectors
68
  )
69
  # Create collection with schema and index
70
  client.create_collection(
@@ -77,6 +80,7 @@ def create_collection_if_not_exists(client: MilvusClient, collection_name: str,
77
  print(f"Error creating collection: {e}")
78
  return None
79
 
 
80
  def insert_data(client: MilvusClient, collection_name: str, data: list[dict]):
81
  """
82
  Insert data into a collection in Milvus.
@@ -95,7 +99,9 @@ def insert_data(client: MilvusClient, collection_name: str, data: list[dict]):
95
  print(f"Error inserting data: {e}")
96
 
97
 
98
- def search(client: MilvusClient, collection_name: str, binary_query: bytes, limit: int = 5):
 
 
99
  """
100
  Search for data in a collection in Milvus.
101
  """
@@ -115,10 +121,10 @@ def search(client: MilvusClient, collection_name: str, binary_query: bytes, limi
115
  if not results:
116
  print("No search results found")
117
  return []
118
-
119
  contexts = [res.entity.context for res in results[0]]
120
  return contexts
121
 
122
  except Exception as e:
123
  print(f"Error searching for data: {e}")
124
- return []
 
1
+ from pymilvus import DataType, MilvusClient
2
 
3
 
4
  def get_milvus_client(db_path: str) -> MilvusClient:
 
14
  try:
15
  client = MilvusClient(db_path)
16
  return client
17
+
18
  except Exception as e:
19
  print(f"Error getting Milvus client: {e}")
20
  return None
21
 
22
+
23
+ def create_collection_if_not_exists(
24
+ client: MilvusClient, collection_name: str, dim: int
25
+ ) -> None:
26
  """
27
  Create a collection in Milvus if it does not exist.
28
 
 
66
  index_params.add_index(
67
  field_name="binary_vector",
68
  index_name="binary_vector_index",
69
+ index_type="BIN_FLAT", # Exact search for binary vectors
70
+ metric_type="HAMMING", # Hamming distance for binary vectors
71
  )
72
  # Create collection with schema and index
73
  client.create_collection(
 
80
  print(f"Error creating collection: {e}")
81
  return None
82
 
83
+
84
  def insert_data(client: MilvusClient, collection_name: str, data: list[dict]):
85
  """
86
  Insert data into a collection in Milvus.
 
99
  print(f"Error inserting data: {e}")
100
 
101
 
102
+ def search(
103
+ client: MilvusClient, collection_name: str, binary_query: bytes, limit: int = 5
104
+ ):
105
  """
106
  Search for data in a collection in Milvus.
107
  """
 
121
  if not results:
122
  print("No search results found")
123
  return []
124
+
125
  contexts = [res.entity.context for res in results[0]]
126
  return contexts
127
 
128
  except Exception as e:
129
  print(f"Error searching for data: {e}")
130
+ return []