Spaces:
Sleeping
Sleeping
File size: 4,239 Bytes
5ff6b14 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
"""
Embedding Manager Module
Handles creation of embeddings for text chunks using sentence transformers.
"""
import asyncio
import numpy as np
from typing import List
from sentence_transformers import SentenceTransformer
from config.config import EMBEDDING_MODEL, BATCH_SIZE
class EmbeddingManager:
    """Handles embedding creation for text chunks using a sentence-transformer model.

    The model named by ``EMBEDDING_MODEL`` (from the project config) is loaded
    eagerly at construction time so the first embedding call does not pay the
    model-load cost.
    """

    def __init__(self):
        """Initialize the embedding manager and load the model immediately."""
        self.embedding_model = None  # set by _init_embedding_model()
        self._init_embedding_model()

    def _init_embedding_model(self):
        """Load the sentence-transformer model named in the project config."""
        print(f"🔄 Loading embedding model: {EMBEDDING_MODEL}")
        self.embedding_model = SentenceTransformer(EMBEDDING_MODEL)
        print("✅ Embedding model loaded successfully")

    async def create_embeddings(self, chunks: List[str]) -> np.ndarray:
        """
        Create L2-normalized embeddings for text chunks.

        Encoding runs in the default thread-pool executor so the slow,
        compute-bound model call does not block the event loop.

        Args:
            chunks: List of text chunks to embed.

        Returns:
            np.ndarray: float32 array of shape (num_chunks, embedding_dim).

        Raises:
            ValueError: If ``chunks`` is empty.
        """
        # Validate before announcing work, so we never print a progress
        # message for a request that is immediately rejected.
        if not chunks:
            raise ValueError("No chunks provided for embedding creation")
        print(f"🧠 Creating embeddings for {len(chunks)} chunks")

        def create_embeddings_sync():
            """Synchronous embedding creation, run in a worker thread."""
            embeddings = self.embedding_model.encode(
                chunks,
                batch_size=BATCH_SIZE,
                show_progress_bar=True,
                normalize_embeddings=True,
            )
            # Downstream consumers (e.g. FAISS-style indexes) expect float32.
            return np.asarray(embeddings, dtype=np.float32)

        # get_running_loop() is the correct call inside a coroutine;
        # get_event_loop() here is deprecated since Python 3.10.
        loop = asyncio.get_running_loop()
        embeddings = await loop.run_in_executor(None, create_embeddings_sync)
        print(f"✅ Created embeddings with shape: {embeddings.shape}")
        return embeddings

    def get_embedding_dimension(self) -> int:
        """
        Return the dimension of embeddings produced by the model.

        Returns:
            int: Embedding dimension.

        Raises:
            RuntimeError: If the embedding model is not initialized.
        """
        if self.embedding_model is None:
            raise RuntimeError("Embedding model not initialized")
        return self.embedding_model.get_sentence_embedding_dimension()

    def validate_embeddings(self, embeddings: np.ndarray, expected_count: int) -> bool:
        """
        Validate that embeddings have the expected shape and properties.

        Args:
            embeddings: The embeddings array to validate.
            expected_count: Expected number of embedding rows.

        Returns:
            bool: True if the embeddings are valid, False otherwise.
        """
        if embeddings is None:
            return False
        if embeddings.shape[0] != expected_count:
            print(f"❌ Embedding count mismatch: expected {expected_count}, got {embeddings.shape[0]}")
            return False
        if embeddings.dtype != np.float32:
            print(f"❌ Embedding dtype mismatch: expected float32, got {embeddings.dtype}")
            return False
        # Single pass rejects both NaN and +/-Inf: isfinite is False for each.
        if not np.all(np.isfinite(embeddings)):
            print("❌ Embeddings contain NaN or infinite values")
            return False
        print(f"✅ Embeddings validation passed: {embeddings.shape}")
        return True

    def get_model_info(self) -> dict:
        """
        Return information about the embedding model.

        Returns:
            dict: Model name and status; when loaded, also the embedding
            dimension and the model's max sequence length (or 'unknown').
        """
        if self.embedding_model is None:
            return {"model_name": EMBEDDING_MODEL, "status": "not_loaded"}
        return {
            "model_name": EMBEDDING_MODEL,
            "embedding_dimension": self.get_embedding_dimension(),
            "max_sequence_length": getattr(self.embedding_model, 'max_seq_length', 'unknown'),
            "status": "loaded",
        }
|