from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from typing import List, Dict, Any
import gc


class FAQEmbedder:
    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        """
        Initialize the FAQ embedder with a sentence transformer model.
        Optimized for memory efficiency.
        """
        print(f"Initializing FAQ Embedder with model: {model_name}")
        # Use CPU for the embedding model to save GPU memory for the LLM
        self.device = "cpu"
        self.model = SentenceTransformer(model_name, device=self.device)
        self.index = None
        self.faqs = None
        self.embeddings = None

    def create_embeddings(self, faqs: List[Dict[str, Any]], batch_size: int = 32) -> None:
        """
        Create embeddings for all FAQs and build the FAISS index.
        Uses batching for memory efficiency.
        """
        self.faqs = faqs
        print(f"Creating embeddings for {len(faqs)} FAQs in batches of {batch_size}...")

        # Extract questions for embedding
        questions = [faq['question'] for faq in faqs]

        # Process in batches to reduce peak memory usage
        all_embeddings = []
        num_batches = (len(questions) + batch_size - 1) // batch_size
        for i in range(0, len(questions), batch_size):
            batch = questions[i:i + batch_size]
            print(f"Processing batch {i // batch_size + 1}/{num_batches}")

            # Create embeddings for this batch
            batch_embeddings = self.model.encode(
                batch, show_progress_bar=False, convert_to_numpy=True
            )
            all_embeddings.append(batch_embeddings)

        # Combine all batches; FAISS requires float32
        self.embeddings = np.vstack(all_embeddings).astype('float32')

        # Release the per-batch arrays explicitly
        all_embeddings = None
        gc.collect()

        # Build a flat L2 index over the question embeddings
        dimension = self.embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dimension)
        self.index.add(self.embeddings)

        print(f"Created embeddings of shape {self.embeddings.shape}")
        print(f"FAISS index contains {self.index.ntotal} vectors")

    def retrieve_relevant_faqs(self, query: str, k: int = 3) -> List[Dict[str, Any]]:
        """
        Retrieve the top-k most relevant FAQs for a given query.
        """
        if self.index is None or self.faqs is None or self.embeddings is None:
            raise ValueError("Embeddings not created yet. Call create_embeddings first.")

        # Embed the query
        query_embedding = self.model.encode([query], convert_to_numpy=True).astype('float32')

        # Search the FAISS index
        distances, indices = self.index.search(query_embedding, k)

        # Collect the matching FAQs with their similarity scores
        relevant_faqs = []
        for i, idx in enumerate(indices[0]):
            # FAISS returns -1 for unfilled slots when k exceeds the index size,
            # so guard against negative indices as well as out-of-range ones
            if 0 <= idx < len(self.faqs):
                faq = self.faqs[idx].copy()
                # Convert L2 distance to a similarity score (higher is better)
                similarity = 1.0 / (1.0 + distances[0][i])
                faq['similarity'] = similarity
                relevant_faqs.append(faq)

        return relevant_faqs
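

# --- Usage sketch (not part of the original module) ---
# A minimal, hypothetical example of how FAQEmbedder might be wired up.
# The FAQ dicts and the query below are illustrative assumptions; only the
# 'question' key is required by create_embeddings, and 'answer' is simply
# carried along in the returned dicts.
if __name__ == "__main__":
    sample_faqs = [
        {"question": "How do I reset my password?",
         "answer": "Use the 'Forgot password' link on the login page."},
        {"question": "What payment methods are accepted?",
         "answer": "We accept credit cards and PayPal."},
        {"question": "How can I contact support?",
         "answer": "Email support@example.com."},
    ]

    embedder = FAQEmbedder()                  # defaults to all-MiniLM-L6-v2 on CPU
    embedder.create_embeddings(sample_faqs)   # embeds questions and builds the index

    for faq in embedder.retrieve_relevant_faqs("I forgot my login password", k=2):
        print(f"[{faq['similarity']:.3f}] {faq['question']} -> {faq['answer']}")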