""" | |
Script to store FAISS and pickle files in Azure Cosmos DB MongoDB API | |
This script should be run only once to upload the vector database files. | |
run this at root of the project: | |
python -m app.services.chatbot.vectorDB_upload_script | |
""" | |
import asyncio
import logging
import os
from datetime import datetime, timezone

from dotenv import load_dotenv
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorGridFSBucket
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load environment variables from .env
load_dotenv()

# Database configuration
MONGO_URI = os.getenv("MONGODB_KEY")
DB_NAME = "SysmodelerDB"
BUCKET_NAME = "vector_db"

# File paths
FAISS_PATH = r"c:\Users\User\Downloads\faiss_index_sysml\index.faiss"
PICKLE_PATH = r"c:\Users\User\Downloads\faiss_index_sysml\index.pkl"
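# The connection string is read from a .env file at the project root. The
# variable name (MONGODB_KEY) comes from this script; the URI below is only an
# illustrative placeholder for the Cosmos DB MongoDB API connection-string
# format, not a real value:
#
#   MONGODB_KEY=mongodb://<account-name>:<primary-key>@<account-name>.mongo.cosmos.azure.com:10255/?ssl=true&retrywrites=false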
class VectorDBUploader:
    def __init__(self):
        self.client = None
        self.db = None
        self.fs = None

    async def connect(self):
        """Connect to MongoDB"""
        try:
            self.client = AsyncIOMotorClient(MONGO_URI)
            self.db = self.client[DB_NAME]
            self.fs = AsyncIOMotorGridFSBucket(self.db, bucket_name=BUCKET_NAME)
            # Test connection
            await self.client.admin.command('ping')
            logger.info("Successfully connected to MongoDB")
            return True
        except Exception as e:
            logger.error(f"Failed to connect to MongoDB: {e}")
            return False
    async def file_exists(self, filename: str) -> bool:
        """Check if file already exists in GridFS"""
        try:
            async for file_doc in self.fs.find({"filename": filename}):
                return True
            return False
        except Exception as e:
            logger.error(f"Error checking file existence: {e}")
            return False
    # NOTE: An earlier version of upload_file skipped the upload when a file
    # with the same name already existed in GridFS; the version below
    # overwrites it instead.
    async def upload_file(self, file_path: str, filename: str) -> bool:
        """Upload a file to GridFS, overwriting any existing files with the same name"""
        try:
            # Check if file exists locally
            if not os.path.exists(file_path):
                logger.error(f"Local file not found: {file_path}")
                return False

            # Remove any existing file with the same name
            async for file_doc in self.fs.find({"filename": filename}):
                await self.fs.delete(file_doc._id)
                logger.info(f"Deleted existing file: {filename} (ID: {file_doc._id})")

            # Get file size for logging
            file_size = os.path.getsize(file_path)
            logger.info(f"Uploading {filename} ({file_size} bytes)...")

            # Upload file
            with open(file_path, 'rb') as f:
                file_id = await self.fs.upload_from_stream(
                    filename,
                    f,
                    metadata={
                        # timezone-aware timestamp; datetime.utcnow() is deprecated
                        "uploaded_at": datetime.now(timezone.utc),
                        "original_path": file_path,
                        "file_size": file_size,
                        "description": f"Vector database file: {filename}"
                    }
                )
            logger.info(f"Successfully uploaded {filename} with ID: {file_id}")
            return True
        except Exception as e:
            logger.error(f"Failed to upload {filename}: {e}")
            return False
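    # Not part of the original upload flow: a minimal sketch of the retrieval
    # side, showing how a consumer (e.g. the chatbot service) could pull these
    # files back out of GridFS at startup. The method name and target_dir
    # parameter are assumptions for illustration.
    async def download_file(self, filename: str, target_dir: str) -> bool:
        """Download a file from GridFS into target_dir (illustrative sketch)."""
        try:
            os.makedirs(target_dir, exist_ok=True)
            target_path = os.path.join(target_dir, filename)
            # Streams the most recent revision of the named file
            with open(target_path, 'wb') as f:
                await self.fs.download_to_stream_by_name(filename, f)
            logger.info(f"Downloaded {filename} to {target_path}")
            return True
        except Exception as e:
            logger.error(f"Failed to download {filename}: {e}")
            return False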
    async def list_files(self):
        """List all files in the GridFS bucket"""
        try:
            logger.info(f"Files in {BUCKET_NAME} bucket:")
            async for file_doc in self.fs.find():
                logger.info(f"- {file_doc.filename} (ID: {file_doc._id}, Size: {file_doc.length} bytes)")
        except Exception as e:
            logger.error(f"Failed to list files: {e}")
    async def upload_vector_files(self):
        """Upload both FAISS and pickle files"""
        files_to_upload = [
            (FAISS_PATH, "index.faiss"),
            (PICKLE_PATH, "index.pkl")
        ]
        success_count = 0
        for file_path, filename in files_to_upload:
            if await self.upload_file(file_path, filename):
                success_count += 1
            else:
                logger.error(f"Failed to upload {filename}")
        logger.info(f"Upload completed: {success_count}/{len(files_to_upload)} files uploaded successfully")
        return success_count == len(files_to_upload)
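    # Also not in the original script: a hedged sketch of a post-upload sanity
    # check that compares the local file size against the size GridFS recorded.
    # The method name is an assumption; GridFS exposes the stored size as the
    # `length` attribute on each file document.
    async def verify_upload(self, file_path: str, filename: str) -> bool:
        """Return True if a stored GridFS file matches the local file's size."""
        local_size = os.path.getsize(file_path)
        async for file_doc in self.fs.find({"filename": filename}):
            if file_doc.length == local_size:
                return True
        return False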
    async def close(self):
        """Close database connection"""
        if self.client:
            self.client.close()
            logger.info("Database connection closed")

async def main():
    """Main function to upload vector database files"""
    uploader = VectorDBUploader()
    try:
        # Connect to database
        if not await uploader.connect():
            logger.error("Failed to connect to database. Exiting...")
            return

        # Upload files
        logger.info("Starting vector database files upload...")
        success = await uploader.upload_vector_files()
        if success:
            logger.info("All files uploaded successfully!")
        else:
            logger.error("Some files failed to upload")

        # List all files in the bucket
        await uploader.list_files()
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
    finally:
        await uploader.close()

if __name__ == "__main__":
    print("Vector Database File Uploader")
    print("=" * 50)
    print(f"Database: {DB_NAME}")
    print(f"Bucket: {BUCKET_NAME}")
    print(f"FAISS file: {FAISS_PATH}")
    print(f"Pickle file: {PICKLE_PATH}")
    print("=" * 50)

    # Run the upload process
    asyncio.run(main())