faq-rag-chatbot / src /embedding.py
Techbite's picture
initial commit
26d1a81
raw
history blame
3.24 kB
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from typing import List, Dict, Any
import torch
import gc
class FAQEmbedder:
    """Embed FAQ questions with a sentence-transformer and search them with FAISS.

    Typical use: construct once, call ``create_embeddings`` with the FAQ
    records, then call ``retrieve_relevant_faqs`` for each user query.

    Attributes:
        device: Device the embedding model runs on (always ``"cpu"``).
        model: The loaded ``SentenceTransformer`` instance.
        index: FAISS index over the FAQ question embeddings, or ``None``
            until ``create_embeddings`` has run.
        faqs: The FAQ records the index was built from, or ``None``.
        embeddings: ``float32`` matrix with one row per FAQ, or ``None``.
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        """
        Initialize the FAQ embedder with a sentence transformer model
        Optimized for memory efficiency

        Args:
            model_name: Name of the sentence-transformers model to load.
        """
        print(f"Initializing FAQ Embedder with model: {model_name}")
        # Use CPU for embedding model to save GPU memory for LLM
        self.device = "cpu"
        self.model = SentenceTransformer(model_name, device=self.device)
        self.index = None
        self.faqs = None
        self.embeddings = None

    def create_embeddings(self, faqs: List[Dict[str, Any]], batch_size: int = 32) -> None:
        """
        Create embeddings for all FAQs and build FAISS index
        Using batching for memory efficiency

        Args:
            faqs: FAQ records; each must contain a ``'question'`` entry, which
                is the text that gets embedded.
            batch_size: Number of questions encoded per model call.

        Raises:
            ValueError: If ``faqs`` is empty or ``batch_size`` is not positive.
            KeyError: If a record has no ``'question'`` entry.
        """
        # Validate up front so a bad call fails with a clear message and
        # leaves any previously built index/faqs pair untouched.
        if batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        if not faqs:
            raise ValueError("Cannot create embeddings from an empty FAQ list.")

        print(f"Creating embeddings for {len(faqs)} FAQs in batches of {batch_size}...")

        # Extract questions for embedding
        questions = [faq['question'] for faq in faqs]
        total_batches = (len(questions) + batch_size - 1) // batch_size

        # Process in batches to reduce memory usage
        all_embeddings = []
        for start in range(0, len(questions), batch_size):
            batch = questions[start:start + batch_size]
            print(f"Processing batch {start // batch_size + 1}/{total_batches}")
            # Create embeddings for this batch
            batch_embeddings = self.model.encode(batch, show_progress_bar=False, convert_to_numpy=True)
            all_embeddings.append(batch_embeddings)

        # Combine all batches; FAISS expects float32 input
        embeddings = np.vstack(all_embeddings).astype('float32')

        # Clear memory explicitly
        del all_embeddings
        gc.collect()

        # Create FAISS index
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings)

        # Publish the new state only once everything succeeded, so that
        # faqs / embeddings / index always describe the same data set.
        self.faqs = faqs
        self.embeddings = embeddings
        self.index = index

        print(f"Created embeddings of shape {self.embeddings.shape}")
        print(f"FAISS index contains {self.index.ntotal} vectors")

    def retrieve_relevant_faqs(self, query: str, k: int = 3) -> List[Dict[str, Any]]:
        """
        Retrieve top-k relevant FAQs for a given query

        Args:
            query: The user question to match against the indexed FAQs.
            k: Maximum number of FAQs to return.

        Returns:
            Up to ``k`` shallow copies of the matching FAQ records, in the
            order returned by the index search, each with an added
            ``'similarity'`` entry (a plain ``float`` in ``(0, 1]``; higher
            is better). Fewer than ``k`` records are returned when the index
            holds fewer vectors; an empty list is returned when ``k <= 0``.

        Raises:
            ValueError: If ``create_embeddings`` has not been called yet.
        """
        if self.index is None or self.faqs is None or self.embeddings is None:
            raise ValueError("Embeddings not created yet. Call create_embeddings first.")
        if k <= 0:
            return []

        # Embed the query
        query_embedding = self.model.encode([query], convert_to_numpy=True).astype('float32')

        # Search in FAISS
        distances, indices = self.index.search(query_embedding, k)

        # Get the relevant FAQs with their similarity scores
        relevant_faqs = []
        for distance, idx in zip(distances[0], indices[0]):
            # FAISS pads missing results with label -1 when fewer than k
            # vectors are available; a bare `idx < len(...)` check would let
            # -1 through and silently return the last FAQ.
            if not 0 <= idx < len(self.faqs):
                continue
            faq = self.faqs[int(idx)].copy()
            # Convert the index distance to a similarity score (higher is
            # better). float() yields a plain Python float rather than a
            # NumPy scalar, so the result is JSON-serializable.
            faq['similarity'] = 1.0 / (1.0 + float(distance))
            relevant_faqs.append(faq)
        return relevant_faqs