# petProject/create_vector_stores.py
# Build BM25 and ChromaDB vector stores for dog-food product search.
# (Originally uploaded by nesanchezo, commit 857eb0d.)
import logging
from typing import List, Dict, Any
import pickle
import nltk
from nltk.tokenize import word_tokenize
from rank_bm25 import BM25Okapi
import chromadb
from chromadb.config import Settings
from openai import OpenAI
import pandas as pd
from tqdm import tqdm
from dotenv import load_dotenv
import os
# Application-wide logging: INFO level, timestamped single-line records.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)
class VectorStoreCreator:
    """Create and manage vector stores (BM25 + ChromaDB) for dog food product search."""

    def __init__(self, data_path: str):
        """
        Initialize the VectorStoreCreator.

        Args:
            data_path: Path to the pickle file containing the product data
                (a pandas DataFrame with description/product columns).
        """
        # Pull OPENAI_API_KEY (and any other settings) from a local .env if
        # present; the OpenAI() client below reads the key from the environment.
        load_dotenv()

        # Initialize OpenAI client (used for embedding requests).
        self.client = OpenAI()

        # Download NLTK tokenizer resources needed by word_tokenize.
        # NLTK >= 3.8.2 moved the tokenizer data to 'punkt_tab'; downloading
        # both keeps older and newer NLTK versions working (quiet on failure).
        nltk.download('punkt', quiet=True)
        nltk.download('punkt_tab', quiet=True)

        # Load product data. NOTE(review): pickle should only be loaded from
        # trusted files — it can execute arbitrary code.
        self.df = pd.read_pickle(data_path)

        # Lazily-populated stores, filled by prepare_data()/create_* methods.
        self.bm25_model = None
        self.chroma_collection = None
        self.chunks: List[str] = []
        self.metadata: List[Dict[str, Any]] = []

    def prepare_data(self) -> None:
        """Build text chunks and per-product metadata from the DataFrame.

        Safe to call more than once: the chunk/metadata lists are reset
        first instead of being appended to (repeated calls previously
        duplicated every entry).
        """
        logging.info("Preparing data for vector stores...")
        total_rows = len(self.df)
        logging.info(f"Total rows in DataFrame: {total_rows}")

        # Reset so repeated calls do not duplicate data.
        self.chunks = []
        self.metadata = []

        for _, row in self.df.iterrows():
            # One chunk per product: English + Spanish descriptions combined.
            self.chunks.append(f"{row['description_en']} {row['description_es']}")
            self.metadata.append({
                "product_name": row["product_name"],
                "brand": row["brand"],
                "dog_type": row["dog_type"],
                "food_type": row["food_type"],
                "weight": float(row["weight"]),
                "price": float(row["price"]),
                # Missing review counts become 0.0 so the metadata stays numeric.
                "reviews": float(row["reviews"]) if pd.notna(row["reviews"]) else 0.0,
            })

        logging.info(f"Total chunks created: {len(self.chunks)}")
        if len(self.chunks) != total_rows:
            logging.warning(
                f"Mismatch between DataFrame rows ({total_rows}) and chunks created ({len(self.chunks)})"
            )
        if self.chunks:
            logging.info(f"Sample of first chunk: {self.chunks[0][:200]}...")

    def create_bm25_index(self, save_path: str = "bm25_index.pkl") -> None:
        """
        Create and save BM25 index.

        Args:
            save_path: Path to save the pickled BM25 index.

        Raises:
            ValueError: If no chunks are available (BM25Okapi fails with an
                obscure ZeroDivisionError on an empty corpus).
        """
        if not self.chunks:
            raise ValueError("No chunks available; call prepare_data() first.")

        logging.info("Creating BM25 index...")
        # BM25 works on lowercased word tokens.
        tokenized_chunks = [word_tokenize(chunk.lower()) for chunk in self.chunks]
        self.bm25_model = BM25Okapi(tokenized_chunks)

        # Persist the model together with the chunks/metadata it indexes so the
        # loader can map scores back to products.
        with open(save_path, 'wb') as f:
            pickle.dump({
                'model': self.bm25_model,
                'chunks': self.chunks,
                'metadata': self.metadata,
            }, f)
        logging.info(f"BM25 index saved to {save_path}")

    def create_chroma_db(self, db_path: str = "chroma_db", batch_size: int = 10) -> None:
        """
        Create (or extend) the persistent ChromaDB collection of embeddings.

        Args:
            db_path: Directory for the persistent ChromaDB store.
            batch_size: Number of documents embedded and inserted per request.
        """
        logging.info("Creating ChromaDB database...")
        client = chromadb.PersistentClient(path=db_path)
        self.chroma_collection = client.get_or_create_collection(
            name="dog_food_descriptions"
        )

        for i in tqdm(range(0, len(self.chunks), batch_size)):
            batch_chunks = self.chunks[i:i + batch_size]
            batch_metadata = self.metadata[i:i + batch_size]
            batch_ids = [str(idx) for idx in range(i, i + len(batch_chunks))]

            # One embeddings request per batch — the API accepts a list input
            # and returns embeddings in input order — instead of one HTTP
            # round-trip per chunk.
            response = self.client.embeddings.create(
                model="text-embedding-ada-002",
                input=batch_chunks,
            )
            embeddings = [item.embedding for item in response.data]

            self.chroma_collection.add(
                embeddings=embeddings,
                metadatas=batch_metadata,
                documents=batch_chunks,
                ids=batch_ids,
            )
        logging.info(f"ChromaDB saved to {db_path}")
def main():
    """Build both vector stores from the enriched multilingual product pickle.

    Raises:
        Exception: Re-raises any failure after logging it, so the process
            exits non-zero instead of silently succeeding.
    """
    try:
        creator = VectorStoreCreator("3rd_clean_comida_dogs_enriched_multilingual_2.pkl")
        creator.prepare_data()
        creator.create_bm25_index()
        creator.create_chroma_db()
        logging.info("Vector stores created successfully!")
    except Exception:
        # Top-level boundary: logging.exception records the full traceback
        # (logging.error with the bare message discarded it); then re-raise.
        logging.exception("An error occurred while creating the vector stores")
        raise


if __name__ == "__main__":
    main()