# file: embedding.py

import torch
from sentence_transformers import SentenceTransformer
from typing import List

# --- Configuration ---
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"

class EmbeddingClient:
    """A client for generating text embeddings using a local sentence transformer model."""

    def __init__(self, model_name: str = EMBEDDING_MODEL_NAME):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = SentenceTransformer(model_name, device=self.device)
        print(f"EmbeddingClient initialized with model '{model_name}' on device '{self.device}'.")

    def create_embeddings(self, texts: List[str]) -> torch.Tensor:
        """
        Generates embeddings for a list of text chunks.

        Args:
            texts: A list of strings to be embedded.

        Returns:
            A torch.Tensor containing the generated embeddings.
        """
        if not texts:
            # Return an empty tensor shaped (0, embedding_dim) so callers always
            # receive a two-dimensional result, even for empty input.
            return torch.empty(
                (0, self.model.get_sentence_embedding_dimension()), device=self.device
            )

        print(f"Generating embeddings for {len(texts)} text chunks on {self.device}...")
        try:
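            # convert_to_tensor=True makes encode() return a single stacked
            # torch.Tensor on the model's device rather than a NumPy array.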
            embeddings = self.model.encode(
                texts, convert_to_tensor=True, show_progress_bar=False
            )
            print("Embeddings generated successfully.")
            return embeddings
        except Exception as e:
            print(f"An error occurred during embedding generation: {e}")
            raise
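

# --- Example usage (illustrative sketch, not part of the original module) ---
# Instantiates the client and embeds a few hypothetical sample sentences.
# Assumes the model named above can be downloaded via sentence-transformers.
if __name__ == "__main__":
    client = EmbeddingClient()
    sample_texts = [
        "Sentence transformers map text to dense vectors.",
        "Embeddings enable semantic similarity search.",
    ]
    vectors = client.create_embeddings(sample_texts)
    # Expected shape: (len(sample_texts), embedding_dim), e.g. (2, 384)
    # for all-MiniLM-L6-v2.
    print(f"Embedding tensor shape: {tuple(vectors.shape)}")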