# Captured page header: "Spaces: Sleeping"
from sentence_transformers import SentenceTransformer | |
import faiss | |
import numpy as np | |
from typing import List, Dict, Any | |
import torch | |
import gc | |
class FAQEmbedder:
    """Embed FAQ questions with a sentence-transformer model and search them with FAISS.

    Call ``create_embeddings`` once with the FAQ list, then call
    ``retrieve_relevant_faqs`` for each incoming query.
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        """
        Initialize the FAQ embedder with a sentence transformer model.

        The embedding model is placed on the CPU so GPU memory stays available
        for the LLM.

        Args:
            model_name: Model name or path passed to ``SentenceTransformer``.
        """
        print(f"Initializing FAQ Embedder with model: {model_name}")
        # Use CPU for embedding model to save GPU memory for LLM
        self.device = "cpu"
        self.model = SentenceTransformer(model_name, device=self.device)
        self.index = None       # faiss.IndexFlatL2, built by create_embeddings
        self.faqs = None        # FAQ dicts, row i of the index <-> faqs[i]
        self.embeddings = None  # float32 matrix, one row per FAQ question

    def create_embeddings(self, faqs: List[Dict[str, Any]], batch_size: int = 32) -> None:
        """
        Create embeddings for all FAQs and build a FAISS index over them.

        Questions are encoded in batches to bound peak memory. Instance state
        (``faqs``, ``embeddings``, ``index``) is replaced only after the new
        index is fully built. A failure part-way through therefore leaves any
        previously built index intact and aligned with its FAQs.

        Args:
            faqs: FAQ records. Each must have a ``'question'`` key whose text is
                embedded. Row ``i`` of the index corresponds to ``faqs[i]``.
            batch_size: Number of questions encoded per ``model.encode`` call.

        Raises:
            ValueError: If ``faqs`` is empty or ``batch_size`` is not positive.
            KeyError: If an FAQ has no ``'question'`` key.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        if not faqs:
            # np.vstack([]) would otherwise fail with an unhelpful message
            raise ValueError("Cannot create embeddings: no FAQs were provided.")

        print(f"Creating embeddings for {len(faqs)} FAQs in batches of {batch_size}...")
        # Only the question text is embedded; answers are returned as-is at retrieval
        questions = [faq['question'] for faq in faqs]
        num_batches = (len(questions) + batch_size - 1) // batch_size

        # Process in batches to reduce memory usage
        all_embeddings = []
        for batch_num, start in enumerate(range(0, len(questions), batch_size), start=1):
            print(f"Processing batch {batch_num}/{num_batches}")
            batch = questions[start:start + batch_size]
            all_embeddings.append(
                self.model.encode(batch, show_progress_bar=False, convert_to_numpy=True)
            )

        # FAISS expects float32 input
        embeddings = np.vstack(all_embeddings).astype('float32')
        # Release the per-batch arrays before building the index to lower peak memory
        del all_embeddings
        gc.collect()

        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)

        # Commit all state together so faqs / embeddings / index stay aligned
        self.faqs = faqs
        self.embeddings = embeddings
        self.index = index

        print(f"Created embeddings of shape {self.embeddings.shape}")
        print(f"FAISS index contains {self.index.ntotal} vectors")

    def retrieve_relevant_faqs(self, query: str, k: int = 3) -> List[Dict[str, Any]]:
        """
        Retrieve the top-k FAQs most similar to ``query``.

        Args:
            query: Free-text user question.
            k: Maximum number of FAQs to return. It is clamped to the number of
                indexed FAQs, and ``k <= 0`` yields an empty list.

        Returns:
            Shallow copies of the matching FAQ dicts, nearest first. Each copy
            has an added ``'similarity'`` float in (0, 1], where 1.0 means
            distance 0.

        Raises:
            ValueError: If ``create_embeddings`` has not completed successfully.
        """
        if self.index is None or self.faqs is None or self.embeddings is None:
            raise ValueError("Embeddings not created yet. Call create_embeddings first.")

        # Requesting more neighbours than the index holds makes FAISS pad with -1 ids
        k = min(k, self.index.ntotal)
        if k <= 0:
            return []

        # Embed the query (FAISS expects float32)
        query_embedding = self.model.encode([query], convert_to_numpy=True).astype('float32')
        distances, indices = self.index.search(query_embedding, k)
        return self._collect_hits(self.faqs, distances, indices)

    @staticmethod
    def _collect_hits(faqs, distances, indices) -> List[Dict[str, Any]]:
        """Turn a single-query FAISS search result into annotated FAQ copies.

        Args:
            faqs: FAQ dicts aligned with the index rows.
            distances: Distance array from ``Index.search``; row 0 is used.
            indices: Id array from ``Index.search``; row 0 is used. ``-1``
                marks an empty result slot.

        Returns:
            A list of shallow FAQ copies with a ``'similarity'`` key, in the
            order FAISS returned them.
        """
        hits = []
        for distance, idx in zip(distances[0], indices[0]):
            # Reject -1 padding explicitly: `idx < len(faqs)` alone would accept
            # it and silently return faqs[-1].
            if not 0 <= idx < len(faqs):
                continue
            faq = faqs[idx].copy()  # copy so the stored FAQ is not mutated
            # Convert L2 distance to similarity score (higher is better);
            # float() yields a JSON-serializable Python float, not a NumPy scalar.
            faq['similarity'] = 1.0 / (1.0 + float(distance))
            hits.append(faq)
        return hits