Spaces:
Sleeping
Sleeping
Commit
·
b3de77b
1
Parent(s):
4ed0f41
Refactor imports and improve code formatting across multiple files for better readability and organization.
Browse files
- app.py +12 -13
- src/config.py +1 -1
- src/data_loader.py +5 -1
- src/embedding_generator.py +6 -3
- src/rag_pipeline.py +6 -3
- src/vector_store.py +14 -8
app.py
CHANGED
|
@@ -1,27 +1,24 @@
|
|
| 1 |
-
import gradio as gr
|
| 2 |
-
import os
|
| 3 |
import atexit
|
| 4 |
import glob
|
|
|
|
| 5 |
import shutil
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
)
|
| 12 |
from src.data_loader import load_data
|
| 13 |
from src.embedding_generator import (
|
| 14 |
generate_document_embeddings,
|
| 15 |
generate_query_embeddings,
|
| 16 |
)
|
|
|
|
| 17 |
from src.vector_store import (
|
| 18 |
-
get_milvus_client,
|
| 19 |
create_collection_if_not_exists,
|
|
|
|
| 20 |
insert_data,
|
| 21 |
search,
|
| 22 |
)
|
| 23 |
-
from src.rag_pipeline import answer_question
|
| 24 |
-
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
| 25 |
|
| 26 |
# Initialize models and clients
|
| 27 |
embedding_model = HuggingFaceEmbedding(
|
|
@@ -32,16 +29,18 @@ embedding_model = HuggingFaceEmbedding(
|
|
| 32 |
|
| 33 |
milvus_client = get_milvus_client(MILVUS_DB_PATH)
|
| 34 |
|
|
|
|
| 35 |
# --- Cleanup Function ---
|
| 36 |
def cleanup_documents():
|
| 37 |
"""Remove all files from the documents directory."""
|
| 38 |
print("Cleaning up uploaded documents...")
|
| 39 |
-
files = glob.glob(os.path.join(DOCS_DIR,
|
| 40 |
for f in files:
|
| 41 |
if os.path.isfile(f):
|
| 42 |
os.remove(f)
|
| 43 |
print("Cleanup complete.")
|
| 44 |
|
|
|
|
| 45 |
# Register the cleanup function to run on exit
|
| 46 |
atexit.register(cleanup_documents)
|
| 47 |
|
|
@@ -114,4 +113,4 @@ with gr.Blocks() as demo:
|
|
| 114 |
if __name__ == "__main__":
|
| 115 |
# Ensure the documents directory exists from the start
|
| 116 |
os.makedirs(DOCS_DIR, exist_ok=True)
|
| 117 |
-
demo.launch()
|
|
|
|
|
|
|
|
|
|
| 1 |
import atexit
|
| 2 |
import glob
|
| 3 |
+
import os
|
| 4 |
import shutil
|
| 5 |
+
|
| 6 |
+
import gradio as gr
|
| 7 |
+
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
| 8 |
+
|
| 9 |
+
from src.config import COLLECTION_NAME, DOCS_DIR, EMBEDDING_MODEL_NAME, MILVUS_DB_PATH
|
|
|
|
| 10 |
from src.data_loader import load_data
|
| 11 |
from src.embedding_generator import (
|
| 12 |
generate_document_embeddings,
|
| 13 |
generate_query_embeddings,
|
| 14 |
)
|
| 15 |
+
from src.rag_pipeline import answer_question
|
| 16 |
from src.vector_store import (
|
|
|
|
| 17 |
create_collection_if_not_exists,
|
| 18 |
+
get_milvus_client,
|
| 19 |
insert_data,
|
| 20 |
search,
|
| 21 |
)
|
|
|
|
|
|
|
| 22 |
|
| 23 |
# Initialize models and clients
|
| 24 |
embedding_model = HuggingFaceEmbedding(
|
|
|
|
| 29 |
|
| 30 |
milvus_client = get_milvus_client(MILVUS_DB_PATH)
|
| 31 |
|
| 32 |
+
|
| 33 |
# --- Cleanup Function ---
|
| 34 |
def cleanup_documents():
|
| 35 |
"""Remove all files from the documents directory."""
|
| 36 |
print("Cleaning up uploaded documents...")
|
| 37 |
+
files = glob.glob(os.path.join(DOCS_DIR, "*"))
|
| 38 |
for f in files:
|
| 39 |
if os.path.isfile(f):
|
| 40 |
os.remove(f)
|
| 41 |
print("Cleanup complete.")
|
| 42 |
|
| 43 |
+
|
| 44 |
# Register the cleanup function to run on exit
|
| 45 |
atexit.register(cleanup_documents)
|
| 46 |
|
|
|
|
| 113 |
if __name__ == "__main__":
|
| 114 |
# Ensure the documents directory exists from the start
|
| 115 |
os.makedirs(DOCS_DIR, exist_ok=True)
|
| 116 |
+
demo.launch()
|
src/config.py
CHANGED
|
@@ -20,4 +20,4 @@ If the context information is not relevant to the user's query, say "I don't kno
|
|
| 20 |
{query}
|
| 21 |
|
| 22 |
# Answer
|
| 23 |
-
"""
|
|
|
|
| 20 |
{query}
|
| 21 |
|
| 22 |
# Answer
|
| 23 |
+
"""
|
src/data_loader.py
CHANGED
|
@@ -12,7 +12,11 @@ def load_data(data_dir: str) -> list:
|
|
| 12 |
A list of documents
|
| 13 |
"""
|
| 14 |
try:
|
| 15 |
-
loader = SimpleDirectoryReader(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
docs = loader.load_data()
|
| 17 |
return docs
|
| 18 |
except Exception as e:
|
|
|
|
| 12 |
A list of documents
|
| 13 |
"""
|
| 14 |
try:
|
| 15 |
+
loader = SimpleDirectoryReader(
|
| 16 |
+
input_dir=data_dir,
|
| 17 |
+
required_exts=[".pdf", ".txt", ".md", ".docx", ".doc"],
|
| 18 |
+
recursive=True,
|
| 19 |
+
)
|
| 20 |
docs = loader.load_data()
|
| 21 |
return docs
|
| 22 |
except Exception as e:
|
src/embedding_generator.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
from typing import Any, Generator
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
|
| 4 |
|
|
@@ -15,10 +16,12 @@ def batch_iterate(items: Any, batch_size: int) -> Generator[Any, None, None]:
|
|
| 15 |
A generator of batches
|
| 16 |
"""
|
| 17 |
for i in range(0, len(items), batch_size):
|
| 18 |
-
yield items[i:i + batch_size]
|
| 19 |
|
| 20 |
|
| 21 |
-
def generate_document_embeddings(documents: list[str], embedding_model: Any) -> list[bytes]:
|
|
|
|
|
|
|
| 22 |
"""
|
| 23 |
Generate document embeddings.
|
| 24 |
|
|
@@ -49,6 +52,7 @@ def generate_document_embeddings(documents: list[str], embedding_model: Any) ->
|
|
| 49 |
print(f"Error generating document embeddings: {e}")
|
| 50 |
return []
|
| 51 |
|
|
|
|
| 52 |
def generate_query_embeddings(query: str, embdding_model: Any) -> bytes:
|
| 53 |
"""
|
| 54 |
Generate query embeddings.
|
|
@@ -73,4 +77,3 @@ def generate_query_embeddings(query: str, embdding_model: Any) -> bytes:
|
|
| 73 |
except Exception as e:
|
| 74 |
print(f"Error generating query embeddings: {e}")
|
| 75 |
return None
|
| 76 |
-
|
|
|
|
| 1 |
from typing import Any, Generator
|
| 2 |
+
|
| 3 |
import numpy as np
|
| 4 |
|
| 5 |
|
|
|
|
| 16 |
A generator of batches
|
| 17 |
"""
|
| 18 |
for i in range(0, len(items), batch_size):
|
| 19 |
+
yield items[i : i + batch_size]
|
| 20 |
|
| 21 |
|
| 22 |
+
def generate_document_embeddings(
|
| 23 |
+
documents: list[str], embedding_model: Any
|
| 24 |
+
) -> list[bytes]:
|
| 25 |
"""
|
| 26 |
Generate document embeddings.
|
| 27 |
|
|
|
|
| 52 |
print(f"Error generating document embeddings: {e}")
|
| 53 |
return []
|
| 54 |
|
| 55 |
+
|
| 56 |
def generate_query_embeddings(query: str, embdding_model: Any) -> bytes:
|
| 57 |
"""
|
| 58 |
Generate query embeddings.
|
|
|
|
| 77 |
except Exception as e:
|
| 78 |
print(f"Error generating query embeddings: {e}")
|
| 79 |
return None
|
|
|
src/rag_pipeline.py
CHANGED
|
@@ -1,9 +1,12 @@
|
|
| 1 |
-
from langchain_core.messages import HumanMessage
|
| 2 |
from langchain.chat_models import init_chat_model
|
| 3 |
-
from .
|
|
|
|
|
|
|
| 4 |
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
llm = init_chat_model(MODEL_NAME, model_provider=MODEL_PROVIDER, temperature=TEMPERATURE)
|
| 7 |
|
| 8 |
def answer_question(query: str, contexts: list[str]) -> str:
|
| 9 |
"""
|
|
|
|
|
|
|
| 1 |
from langchain.chat_models import init_chat_model
|
| 2 |
+
from langchain_core.messages import HumanMessage
|
| 3 |
+
|
| 4 |
+
from .config import MODEL_NAME, MODEL_PROVIDER, PROMPT, TEMPERATURE
|
| 5 |
|
| 6 |
+
llm = init_chat_model(
|
| 7 |
+
MODEL_NAME, model_provider=MODEL_PROVIDER, temperature=TEMPERATURE
|
| 8 |
+
)
|
| 9 |
|
|
|
|
| 10 |
|
| 11 |
def answer_question(query: str, contexts: list[str]) -> str:
|
| 12 |
"""
|
src/vector_store.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from pymilvus import
|
| 2 |
|
| 3 |
|
| 4 |
def get_milvus_client(db_path: str) -> MilvusClient:
|
|
@@ -14,12 +14,15 @@ def get_milvus_client(db_path: str) -> MilvusClient:
|
|
| 14 |
try:
|
| 15 |
client = MilvusClient(db_path)
|
| 16 |
return client
|
| 17 |
-
|
| 18 |
except Exception as e:
|
| 19 |
print(f"Error getting Milvus client: {e}")
|
| 20 |
return None
|
| 21 |
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
| 23 |
"""
|
| 24 |
Create a collection in Milvus if it does not exist.
|
| 25 |
|
|
@@ -63,8 +66,8 @@ def create_collection_if_not_exists(client: MilvusClient, collection_name: str,
|
|
| 63 |
index_params.add_index(
|
| 64 |
field_name="binary_vector",
|
| 65 |
index_name="binary_vector_index",
|
| 66 |
-
index_type="BIN_FLAT",
|
| 67 |
-
metric_type="HAMMING",
|
| 68 |
)
|
| 69 |
# Create collection with schema and index
|
| 70 |
client.create_collection(
|
|
@@ -77,6 +80,7 @@ def create_collection_if_not_exists(client: MilvusClient, collection_name: str,
|
|
| 77 |
print(f"Error creating collection: {e}")
|
| 78 |
return None
|
| 79 |
|
|
|
|
| 80 |
def insert_data(client: MilvusClient, collection_name: str, data: list[dict]):
|
| 81 |
"""
|
| 82 |
Insert data into a collection in Milvus.
|
|
@@ -95,7 +99,9 @@ def insert_data(client: MilvusClient, collection_name: str, data: list[dict]):
|
|
| 95 |
print(f"Error inserting data: {e}")
|
| 96 |
|
| 97 |
|
| 98 |
-
def search(client: MilvusClient, collection_name: str, binary_query: bytes, limit: int = 5):
|
|
|
|
|
|
|
| 99 |
"""
|
| 100 |
Search for data in a collection in Milvus.
|
| 101 |
"""
|
|
@@ -115,10 +121,10 @@ def search(client: MilvusClient, collection_name: str, binary_query: bytes, limi
|
|
| 115 |
if not results:
|
| 116 |
print("No search results found")
|
| 117 |
return []
|
| 118 |
-
|
| 119 |
contexts = [res.entity.context for res in results[0]]
|
| 120 |
return contexts
|
| 121 |
|
| 122 |
except Exception as e:
|
| 123 |
print(f"Error searching for data: {e}")
|
| 124 |
-
return []
|
|
|
|
| 1 |
+
from pymilvus import DataType, MilvusClient
|
| 2 |
|
| 3 |
|
| 4 |
def get_milvus_client(db_path: str) -> MilvusClient:
|
|
|
|
| 14 |
try:
|
| 15 |
client = MilvusClient(db_path)
|
| 16 |
return client
|
| 17 |
+
|
| 18 |
except Exception as e:
|
| 19 |
print(f"Error getting Milvus client: {e}")
|
| 20 |
return None
|
| 21 |
|
| 22 |
+
|
| 23 |
+
def create_collection_if_not_exists(
|
| 24 |
+
client: MilvusClient, collection_name: str, dim: int
|
| 25 |
+
) -> None:
|
| 26 |
"""
|
| 27 |
Create a collection in Milvus if it does not exist.
|
| 28 |
|
|
|
|
| 66 |
index_params.add_index(
|
| 67 |
field_name="binary_vector",
|
| 68 |
index_name="binary_vector_index",
|
| 69 |
+
index_type="BIN_FLAT", # Exact search for binary vectors
|
| 70 |
+
metric_type="HAMMING", # Hamming distance for binary vectors
|
| 71 |
)
|
| 72 |
# Create collection with schema and index
|
| 73 |
client.create_collection(
|
|
|
|
| 80 |
print(f"Error creating collection: {e}")
|
| 81 |
return None
|
| 82 |
|
| 83 |
+
|
| 84 |
def insert_data(client: MilvusClient, collection_name: str, data: list[dict]):
|
| 85 |
"""
|
| 86 |
Insert data into a collection in Milvus.
|
|
|
|
| 99 |
print(f"Error inserting data: {e}")
|
| 100 |
|
| 101 |
|
| 102 |
+
def search(
|
| 103 |
+
client: MilvusClient, collection_name: str, binary_query: bytes, limit: int = 5
|
| 104 |
+
):
|
| 105 |
"""
|
| 106 |
Search for data in a collection in Milvus.
|
| 107 |
"""
|
|
|
|
| 121 |
if not results:
|
| 122 |
print("No search results found")
|
| 123 |
return []
|
| 124 |
+
|
| 125 |
contexts = [res.entity.context for res in results[0]]
|
| 126 |
return contexts
|
| 127 |
|
| 128 |
except Exception as e:
|
| 129 |
print(f"Error searching for data: {e}")
|
| 130 |
+
return []
|