"""

Embedding Manager Module



Handles creation of embeddings for text chunks using sentence transformers.

"""

import asyncio
import numpy as np
from typing import List
from sentence_transformers import SentenceTransformer
from config.config import EMBEDDING_MODEL, BATCH_SIZE


class EmbeddingManager:
    """Handles embedding creation for text chunks."""
    
    def __init__(self):
        """Initialize the embedding manager."""
        self.embedding_model = None
        self._init_embedding_model()
    
    def _init_embedding_model(self):
        """Initialize the embedding model."""
        print(f"πŸ”„ Loading embedding model: {EMBEDDING_MODEL}")
        self.embedding_model = SentenceTransformer(EMBEDDING_MODEL)
        print(f"βœ… Embedding model loaded successfully")
    
    async def create_embeddings(self, chunks: List[str]) -> np.ndarray:
        """

        Create embeddings for text chunks.

        

        Args:

            chunks: List of text chunks to embed

            

        Returns:

            np.ndarray: Array of embeddings with shape (num_chunks, embedding_dim)

        """
        print(f"🧠 Creating embeddings for {len(chunks)} chunks")
        
        if not chunks:
            raise ValueError("No chunks provided for embedding creation")
        
        def create_embeddings_sync():
            """Synchronous embedding creation to run in thread pool."""
            embeddings = self.embedding_model.encode(
                chunks,
                batch_size=BATCH_SIZE,
                show_progress_bar=True,
                normalize_embeddings=True
            )
            return np.array(embeddings).astype("float32")
        
        # Run in thread pool to avoid blocking the event loop
        loop = asyncio.get_running_loop()
        embeddings = await loop.run_in_executor(None, create_embeddings_sync)
        
        print(f"βœ… Created embeddings with shape: {embeddings.shape}")
        return embeddings
    
    def get_embedding_dimension(self) -> int:
        """

        Get the dimension of embeddings produced by the model.

        

        Returns:

            int: Embedding dimension

        """
        if self.embedding_model is None:
            raise RuntimeError("Embedding model not initialized")
        
        # Get dimension from model
        return self.embedding_model.get_sentence_embedding_dimension()
    
    def validate_embeddings(self, embeddings: np.ndarray, expected_count: int) -> bool:
        """

        Validate that embeddings have the expected shape and properties.

        

        Args:

            embeddings: The embeddings array to validate

            expected_count: Expected number of embeddings

            

        Returns:

            bool: True if embeddings are valid, False otherwise

        """
        if embeddings is None:
            return False
        
        if embeddings.shape[0] != expected_count:
            print(f"❌ Embedding count mismatch: expected {expected_count}, got {embeddings.shape[0]}")
            return False
        
        if embeddings.dtype != np.float32:
            print(f"❌ Embedding dtype mismatch: expected float32, got {embeddings.dtype}")
            return False
        
        # Check for NaN or infinite values
        if np.any(np.isnan(embeddings)) or np.any(np.isinf(embeddings)):
            print("❌ Embeddings contain NaN or infinite values")
            return False
        
        print(f"βœ… Embeddings validation passed: {embeddings.shape}")
        return True
    
    def get_model_info(self) -> dict:
        """

        Get information about the embedding model.

        

        Returns:

            dict: Model information

        """
        if self.embedding_model is None:
            return {"model_name": EMBEDDING_MODEL, "status": "not_loaded"}
        
        return {
            "model_name": EMBEDDING_MODEL,
            "embedding_dimension": self.get_embedding_dimension(),
            "max_sequence_length": getattr(self.embedding_model, 'max_seq_length', 'unknown'),
            "status": "loaded"
        }
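

# Minimal usage sketch (illustrative only): drives EmbeddingManager end to end,
# embedding a couple of sample chunks, validating the result, and printing the
# model info. The sample chunks below are hypothetical; EMBEDDING_MODEL and
# BATCH_SIZE are assumed to be provided by config.config as imported above.
async def _example() -> None:
    manager = EmbeddingManager()
    sample_chunks = [
        "Sentence transformers map text to dense vectors.",
        "Embeddings can be compared with cosine similarity.",
    ]
    embeddings = await manager.create_embeddings(sample_chunks)
    if manager.validate_embeddings(embeddings, expected_count=len(sample_chunks)):
        print(manager.get_model_info())


if __name__ == "__main__":
    asyncio.run(_example())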