# Demos/backend/classes/vector_database/milvus_vector_database.py
# Added G2.0 changes (nikhile-galileo, commit 753e3c5)
import os
import shutil
from typing import List
import pandas as pd
from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType
import logging
from backend.classes.vector_database.base_vector_database import VectorDatabaseConfig, VectorDatabase
logger = logging.getLogger(__name__)
class MilvusVectorDatabaseConfig(VectorDatabaseConfig):
    """Configuration for Milvus vector database."""
    # Local path/URI of the Milvus (Lite) database handed to MilvusClient.
    db_path: str
    # Name of the collection to create and query.
    collection_name: str
    # Dimensionality of the dense embedding vectors stored in the collection.
    vector_dimensions: int
    # If True, an existing collection of the same name is dropped on startup.
    drop_if_exists: bool = True
    class Config:
        # Permit non-pydantic field types on subclasses/base config.
        arbitrary_types_allowed = True
class MilvusVectorDatabase(VectorDatabase):
    """Vector database backed by Milvus (Milvus Lite via a local db path).

    Stores chunked text together with a dense embedding and a flexible JSON
    metadata field, and exposes vector-similarity search plus a (currently
    vector-only) hybrid search entry point.
    """

    def __init__(self, config: MilvusVectorDatabaseConfig):
        super().__init__(config)
        # Open the client first; collection creation needs it.
        self.client = self.connect()
        self.create_collection(config.drop_if_exists)

    def connect(self):
        """Open and return a MilvusClient for the configured db path."""
        logger.info(f"\nConnecting to Milvus at {self.config.db_path}...")
        client = MilvusClient(self.config.db_path)
        logger.info("Connected to Milvus.")
        return client

    def _define_schema(self) -> List[FieldSchema]:
        """
        Define the Milvus collection schema.

        - `id`: Auto-generated primary key uniquely identifying each chunk.
        - `text`: The chunked text content (VARCHAR, suitable for filtering).
        - `embedding`: Dense vector embedding used for similarity search.
        - `metadata`: JSON field holding flexible per-document metadata.
        """
        return [
            FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
            # NOTE(review): chunks longer than max_length are rejected by Milvus
            # at insert time — confirm 1024 chars is enough for the chunker used.
            FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=1024),
            FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=self.config.vector_dimensions),
            FieldSchema(name="metadata", dtype=DataType.JSON, description="Flexible JSON metadata for the document"),
        ]

    def create_collection(self, drop_if_exists: bool = True):
        """
        Create the Milvus collection with the defined schema and vector index.

        Args:
            drop_if_exists (bool): If True, drop the collection first when it
                already exists. Defaults to True.
        """
        # Only drop when the collection actually exists; dropping a missing
        # collection is at best a no-op and at worst an error.
        if drop_if_exists and self.client.has_collection(collection_name=self.config.collection_name):
            logger.info(f"Dropping existing collection '{self.config.collection_name}'...")
            self.client.drop_collection(collection_name=self.config.collection_name)

        # Vector index on the embedding field. The metric here (COSINE) must
        # match the metric used at query time in the search methods below.
        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="embedding",
            metric_type="COSINE",
            index_type="IVF_FLAT",
            params={"nlist": 128},
        )

        milvus_schema = CollectionSchema(
            fields=self._define_schema(),
            description="Hybrid search collection for Finance documents",
        )
        logger.info(f"Creating collection '{self.config.collection_name}'...")
        # `dimension` is intentionally not passed: it only applies when no
        # explicit schema is given, and the embedding dim is already fixed
        # by the FLOAT_VECTOR field above.
        self.client.create_collection(
            collection_name=self.config.collection_name,
            schema=milvus_schema,
            index_params=index_params,
        )

    def add_texts(self, df: pd.DataFrame, embeddings: list):
        """
        Add texts and their embeddings to the collection.

        Args:
            df: DataFrame whose columns match the collection schema
                (at least `text` and `metadata`).
            embeddings: List of embeddings, positionally aligned with df rows.

        Raises:
            ValueError: If df and embeddings have different lengths.
        """
        if len(df) != len(embeddings):
            raise ValueError(
                f"Row/embedding count mismatch: {len(df)} rows vs {len(embeddings)} embeddings"
            )
        # Pair rows with embeddings positionally; using the DataFrame label
        # index would break for any non-default (non-RangeIndex) index.
        data = []
        for pos, (_, row) in enumerate(df.iterrows()):
            record = row.to_dict()
            record["embedding"] = embeddings[pos]
            data.append(record)
        self.client.insert(collection_name=self.config.collection_name, data=data)

    def hybrid_search(self, query_embedding: list, query_text: str, limit: int = 5,
                      text_weight: float = 0.4, embedding_weight: float = 0.6) -> list:
        """
        Search intended to combine text-based and vector similarity scores.

        TODO(review): `query_text`, `text_weight` and `embedding_weight` are
        currently unused — the search is vector-only. The parameters are kept
        for interface stability until text scoring is implemented.

        Args:
            query_embedding: Embedding vector for similarity search.
            query_text: Text query (currently unused).
            limit: Maximum number of results to return.
            text_weight: Weight for text-based score (currently unused).
            embedding_weight: Weight for embedding score (currently unused).

        Returns:
            List of dicts with `id`, `distance`, `text` and `metadata`.
        """
        output_fields = ["text", "metadata"]
        search_results = self.client.search(
            collection_name=self.config.collection_name,
            data=[query_embedding],
            anns_field="embedding",
            # Must match the index metric (COSINE); searching with L2 against
            # a COSINE index is a metric mismatch.
            param={"metric_type": "COSINE", "params": {"nprobe": 10}},
            limit=limit * 2,  # over-fetch candidates for future score fusion
            output_fields=output_fields,
        )

        formatted_results = []
        if search_results and search_results[0]:
            for hit in search_results[0]:
                # MilvusClient hits expose output fields under 'entity',
                # not at the top level of the hit dict.
                entity = hit.get("entity") or {}
                result = {
                    "id": hit["id"],
                    "distance": hit["distance"],
                    "text": entity.get("text", "N/A"),
                    "metadata": entity.get("metadata", {}),
                }
                for field in output_fields:
                    if field not in result:
                        result[field] = entity.get(field)
                formatted_results.append(result)
        # Honor the documented `limit` despite the candidate over-fetch.
        return formatted_results[:limit]

    def search_similar_texts(self, query_embedding: list, limit: int = 5):
        """
        Search for similar texts based on a single embedding.

        Args:
            query_embedding: Embedding vector to search for.
            limit: Maximum number of results to return.

        Returns:
            List of dicts with `text` and `distance` keys.
        """
        search_results = self.client.search(
            collection_name=self.config.collection_name,
            # search expects a batch of query vectors, so wrap the single one.
            data=[query_embedding],
            anns_field="embedding",
            limit=limit,
            output_fields=["text"],
        )
        if not search_results or not search_results[0]:
            return []
        return [
            {
                # Output fields are nested under 'entity' in MilvusClient hits.
                "text": (hit.get("entity") or {}).get("text"),
                "distance": hit["distance"],
            }
            for hit in search_results[0]
        ]

    def drop_collection(self):
        """Drop the collection and remove the local Milvus Lite data."""
        # Drop the collection itself first (the method's stated purpose).
        if self.client.has_collection(collection_name=self.config.collection_name):
            logger.info(f"Dropping collection '{self.config.collection_name}'...")
            self.client.drop_collection(collection_name=self.config.collection_name)

        path = self.config.db_path
        if os.path.isdir(path):
            logger.info(f"Removing local Milvus Lite data directory: {path}...")
            shutil.rmtree(path)
            logger.info("Local data removed.")
        elif os.path.isfile(path):
            # Milvus Lite typically stores data in a single .db file;
            # rmtree on a file would raise NotADirectoryError.
            logger.info(f"Removing local Milvus Lite data file: {path}...")
            os.remove(path)
            logger.info("Local data removed.")
        else:
            logger.info(f"Local data directory '{path}' not found, nothing to clean.")