import hashlib
import logging
import os
import pickle
from pathlib import Path
from typing import List, Optional

import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer

# Silence the fork-safety warning emitted by HuggingFace tokenizers.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

logger = logging.getLogger(__name__)


class EmbeddingGenerator:
    """Generate sentence embeddings with a persistent on-disk cache.

    Embeddings are keyed by a SHA-256 digest of the text. The builtin
    ``hash()`` must NOT be used here: string hashing is salted per process
    (PYTHONHASHSEED randomization), so keys written by one run would never
    match in the next run and the persisted cache would be useless.
    """

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2',
                 cache_dir: Optional[str] = None):
        """Load the sentence-transformer model and the embedding cache.

        Args:
            model_name: Name of the SentenceTransformer model to load.
            cache_dir: Directory for the pickle cache file; defaults to
                ``data/embedding_cache``.

        Raises:
            Exception: re-raised after logging if the model fails to load.
        """
        try:
            self.model_name = model_name
            self.model = SentenceTransformer(model_name)
            # Setup cache directory (created eagerly so later saves can't fail
            # on a missing parent).
            self.cache_dir = Path(cache_dir) if cache_dir else Path('data/embedding_cache')
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            # One cache file per model; '/' in hub model names is not
            # filesystem-safe, so replace it.
            self.cache_file = self.cache_dir / f"embeddings_cache_{model_name.replace('/', '_')}.pkl"
            # Load existing cache if available
            self.embedding_cache = self._load_cache()
            logger.info("Successfully loaded model: %s", model_name)
        except Exception as e:
            logger.error("Error loading model: %s", str(e))
            raise

    @staticmethod
    def _text_key(text: str) -> str:
        """Return a process-stable cache key for *text* (SHA-256 hex digest)."""
        return hashlib.sha256(text.encode('utf-8')).hexdigest()

    def _load_cache(self) -> dict:
        """Load embedding cache from file if it exists.

        Returns an empty dict on any failure (best-effort: a corrupt cache
        should never prevent startup).
        """
        try:
            if self.cache_file.exists():
                # NOTE(security): pickle.load executes arbitrary code if the
                # cache file is attacker-controlled; only load trusted files.
                with open(self.cache_file, 'rb') as f:
                    cache = pickle.load(f)
                logger.info("Loaded %d cached embeddings", len(cache))
                return cache
            return {}
        except Exception as e:
            logger.warning("Error loading cache, starting fresh: %s", str(e))
            return {}

    def _save_cache(self):
        """Save embedding cache to file (best-effort; errors are logged)."""
        try:
            with open(self.cache_file, 'wb') as f:
                pickle.dump(self.embedding_cache, f)
            logger.info("Saved %d embeddings to cache", len(self.embedding_cache))
        except Exception as e:
            logger.error("Error saving cache: %s", str(e))

    def generate_embeddings(self, texts: pd.Series) -> np.ndarray:
        """Embed *texts*, reusing cached vectors where possible.

        Only texts missing from the cache are sent to the model; new
        embeddings are written back to the on-disk cache. The returned array
        preserves the input order.

        Args:
            texts: Series of strings to embed.

        Returns:
            2-D array of embeddings, one row per input text.

        Raises:
            Exception: re-raised after logging on encoding failure.
        """
        try:
            text_list = texts.tolist()
            # Pre-size the result so each embedding lands at its original
            # position (avoids the O(n^2) list.insert pattern).
            results: List = [None] * len(text_list)
            texts_to_embed = []
            indices_to_embed = []

            # Check cache for existing embeddings
            for i, text in enumerate(text_list):
                cached = self.embedding_cache.get(self._text_key(text))
                if cached is not None:
                    results[i] = cached
                else:
                    texts_to_embed.append(text)
                    indices_to_embed.append(i)

            # Generate embeddings only for new texts
            if texts_to_embed:
                logger.info("Generating embeddings for %d new texts", len(texts_to_embed))
                new_embeddings = self.model.encode(
                    texts_to_embed,
                    show_progress_bar=True,
                    convert_to_numpy=True
                )
                # Cache new embeddings under their stable keys.
                for text, embedding in zip(texts_to_embed, new_embeddings):
                    self.embedding_cache[self._text_key(text)] = embedding
                # Save updated cache
                self._save_cache()
                # Place new embeddings at their original positions.
                for idx, embedding in zip(indices_to_embed, new_embeddings):
                    results[idx] = embedding
            else:
                logger.info("All embeddings found in cache")

            return np.array(results)
        except Exception as e:
            logger.error("Error generating embeddings: %s", str(e))
            raise

    def add_embeddings_to_df(self, df: pd.DataFrame,
                             text_column: str = 'description') -> pd.DataFrame:
        """Attach an ``embeddings`` column to *df* built from *text_column*.

        Mutates and returns the same DataFrame.

        Raises:
            Exception: re-raised after logging on failure.
        """
        try:
            embeddings = self.generate_embeddings(df[text_column])
            df['embeddings'] = list(embeddings)
            return df
        except Exception as e:
            logger.error("Error adding embeddings to DataFrame: %s", str(e))
            raise

    def clear_cache(self):
        """Clear the in-memory embedding cache and delete the cache file.

        Raises:
            Exception: re-raised after logging on failure.
        """
        try:
            self.embedding_cache = {}
            if self.cache_file.exists():
                self.cache_file.unlink()
            logger.info("Embedding cache cleared")
        except Exception as e:
            logger.error("Error clearing cache: %s", str(e))
            raise