dylanglenister committed
Commit 4ca8eaf · Parent: e5c9fd8

REFACTOR: RAG-ready embedding.


Reworked the embedding file to match the embedding model, pooling, and normalization used to build the knowledge base, so that RAG retrieval can be implemented correctly.

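The principle behind the change: RAG retrieval compares query vectors against knowledge-base vectors, so both sides must be produced by the same checkpoint with the same pooling and normalization. A minimal sketch of that shared recipe, mirroring the reworked embeddings.py below (the knowledge-base build script itself is not part of this commit):

import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

# Both the KB builder and the query path must use this exact recipe;
# mixing checkpoints, or skipping the L2 normalization, makes cosine
# similarity between query and KB vectors meaningless.
MODEL = "abhinand/MedEmbed-large-v0.1"
tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModel.from_pretrained(MODEL).eval()

def embed(texts: list[str]) -> torch.Tensor:
    inputs = tokenizer(texts, padding=True, truncation=True,
                       max_length=512, return_tensors="pt")
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state
    mask = inputs["attention_mask"].unsqueeze(-1).float()
    pooled = (hidden * mask).sum(1) / mask.sum(1).clamp(min=1e-9)  # mean pooling
    return F.normalize(pooled, p=2, dim=1)                         # L2 normalize
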
scripts/download_model.py CHANGED
@@ -6,7 +6,7 @@ import os
 from huggingface_hub import snapshot_download
 
 # Set up paths
-MODEL_REPO = "sentence-transformers/all-MiniLM-L6-v2"
+MODEL_REPO = "abhinand/MedEmbed-large-v0.1"
 MODEL_CACHE_DIR = "/app/model_cache"
 HF_CACHE_DIR = os.getenv("HF_HOME", "/home/user/.cache/huggingface")
 
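The rest of download_model.py is not shown in this diff. A typical invocation, assuming the script feeds these constants straight to snapshot_download (the argument wiring here is an assumption; snapshot_download and cache_dir are real huggingface_hub API):

# Hypothetical continuation of the script above.
snapshot_download(
    repo_id=MODEL_REPO,          # now "abhinand/MedEmbed-large-v0.1"
    cache_dir=MODEL_CACHE_DIR,   # "/app/model_cache"
)
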
src/config/settings.py CHANGED
@@ -8,7 +8,8 @@ class Settings:
     DEFAULT_TOP_K: int = 5
     SEMANTIC_CONTEXT_SIZE: int = 17
     SIMILARITY_THRESHOLD: float = 0.15
+    EMBEDDING_MODEL_NAME: str = "MedEmbed-large-v0.1"
 
     # Safety Guard settings
     SAFETY_GUARD_ENABLED: bool = os.getenv("SAFETY_GUARD_ENABLED", "true").lower() == "true"
     SAFETY_GUARD_TIMEOUT: int = int(os.getenv("SAFETY_GUARD_TIMEOUT", "30"))
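
Unlike the Safety Guard settings next to it, the new EMBEDDING_MODEL_NAME is hard-coded. If it ever needs to vary per deployment, the file's own os.getenv pattern would apply; a hypothetical sketch, not part of this commit:

import os

class Settings:
    # Hypothetical variant: same os.getenv pattern as the SAFETY_GUARD_*
    # settings, defaulting to the value added in this commit.
    EMBEDDING_MODEL_NAME: str = os.getenv("EMBEDDING_MODEL_NAME", "MedEmbed-large-v0.1")
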
src/core/state.py CHANGED
@@ -1,5 +1,6 @@
 # src/core/state.py
 
+from src.config.settings import settings
 from src.core.memory_manager import MemoryManager
 from src.utils.embeddings import EmbeddingClient
 from src.utils.rotator import APIKeyRotator
@@ -34,7 +35,7 @@ class AppState:
     def initialize(self):
         """Initializes all core application components in the correct order."""
         # Initialize components with no dependencies first
-        self.embedding_client = EmbeddingClient(model_name="all-MiniLM-L6-v2", dimension=384)
+        self.embedding_client = EmbeddingClient(model_name=settings.EMBEDDING_MODEL_NAME)
         self.gemini_rotator = APIKeyRotator("GEMINI_API_", max_slots=5)
         self.nvidia_rotator = APIKeyRotator("NVIDIA_API_", max_slots=5)
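The dimension=384 argument disappears here because the reworked client (next diff) derives the dimension from the model at load time. A hypothetical smoke test, using only what is visible in this commit:

from src.core.state import AppState

state = AppState()
state.initialize()

info = state.embedding_client.get_model_info()
# MedEmbed-large-v0.1 is a BGE-large-style checkpoint, so the derived
# dimension should come back as 1024 rather than the old hard-coded 384
# (an assumption -- the client now reports whatever the model produces).
print(info["model_name"], info["dimension"], info["device"])
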
src/utils/embeddings.py CHANGED
@@ -1,125 +1,163 @@
 # src/utils/embeddings.py
 
 import numpy as np
+import torch
+import torch.nn.functional as F
 from numpy.typing import NDArray
+from transformers import (AutoModel, AutoTokenizer, PreTrainedModel,
+                          PreTrainedTokenizer)
 
 from src.config.settings import settings
 from src.utils.logger import logger
 
 
 class EmbeddingClient:
-    """A simple embedding client with a fallback mechanism."""
+    """
+    An embedding client that generates vector embeddings for text using a
+    transformer model, mirroring the logic used for knowledge base creation.
+    """
 
-    def __init__(self, model_name: str = "default", dimension: int = 384):
+    def __init__(self, model_name: str):
         self.model_name = model_name
-        self.dimension = dimension
-        self.model = None
-        self._fallback_mode = True
-        self._init_embedding_model()
-
-    def _init_embedding_model(self):
-        """Initializes the sentence-transformer embedding model."""
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.tokenizer: PreTrainedTokenizer | None = None
+        self.model: PreTrainedModel | None = None
+        self.dimension: int | None = None
+        self._available = self._init_embedding_model()
+
+    def _init_embedding_model(self) -> bool:
+        """Initializes the transformer model and tokenizer."""
         try:
-            from sentence_transformers import SentenceTransformer  # type: ignore
-            self.model = SentenceTransformer(self.model_name)
-            self._fallback_mode = False
-            logger().info(f"Successfully loaded embedding model: {self.model_name}")
-        except ImportError:
-            logger().warning("sentence-transformers not found, using fallback embedding mode.")
+            logger().info(f"Loading embedding model '{self.model_name}' on {self.device}")
+            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+            self.model = AutoModel.from_pretrained(self.model_name).to(self.device)
+            self.model.eval()
+
+            # Dynamically determine the embedding dimension
+            self.dimension = self._get_embedding_dimension()
+            logger().info(f"Successfully loaded model. Embedding dimension: {self.dimension}")
+            return True
         except Exception as e:
-            logger().error(f"Error loading embedding model '{self.model_name}': {e}")
+            logger().error(f"Failed to load embedding model '{self.model_name}': {e}")
+            return False
+
+    def _get_embedding_dimension(self) -> int:
+        """Runs a test input to determine the model's output dimension."""
+        if not self.tokenizer or not self.model:
+            raise RuntimeError("Model and tokenizer must be initialized.")
+
+        test_input = self.tokenizer(
+            "test", return_tensors="pt", truncation=True, padding=True
+        ).to(self.device)
+
+        with torch.no_grad():
+            test_output = self.model(**test_input)
+        test_embedding = self._mean_pooling(
+            test_output.last_hidden_state, test_input["attention_mask"]
+        )
+        return test_embedding.shape[1]
+
+    def _mean_pooling(
+        self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
+    ) -> torch.Tensor:
+        """Performs mean pooling on token embeddings using an attention mask."""
+        input_mask_expanded = (
+            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+        )
+        masked_embeddings = token_embeddings * input_mask_expanded
+        summed_embeddings = torch.sum(masked_embeddings, 1)
+        summed_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+        return summed_embeddings / summed_mask
 
-    def embed(self, texts: str | list[str]) -> list[list[float]]:
-        """Generates embeddings for the given texts."""
+    def embed(self, texts: str | list[str], batch_size: int = 64) -> list[list[float]]:
+        """
+        Generates normalized, mean-pooled embeddings for the given texts.
+        Returns an empty list if the model is not available or an error occurs.
+        """
+        if not self.is_available() or not self.tokenizer or not self.model:
+            logger().error("Embedding model is not available, cannot generate embeddings.")
+            return [[] for _ in range(len(texts) if isinstance(texts, list) else 1)]
+
         if isinstance(texts, str):
             texts = [texts]
-        return self._fallback_embed(texts) if self._fallback_mode else self._proper_embed(texts)
 
-    def _proper_embed(self, texts: list[str]) -> list[list[float]]:
-        """Generates embeddings using the sentence-transformer model."""
-        try:
-            embeddings = self.model.encode(texts, convert_to_numpy=True)  # type: ignore
-            return embeddings.tolist()
-        except Exception as e:
-            logger().error(f"Error during embedding generation: {e}")
-            return self._fallback_embed(texts)
-
-    def _fallback_embed(self, texts: list[str]) -> list[list[float]]:
-        """Generates deterministic, hash-based embeddings as a fallback."""
-        embeddings = []
-        for text in texts:
-            # Create a deterministic hash-based embedding
-            text_hash = hash(text) % (2**32)
-            np.random.seed(text_hash)
-            vector = np.random.normal(0, 1, self.dimension)
-            norm = np.linalg.norm(vector)
-            if norm > 0:
-                vector /= norm
-            embeddings.append(vector.tolist())
-        return embeddings
+        all_embeddings = []
+        for i in range(0, len(texts), batch_size):
+            batch_texts = texts[i : i + batch_size]
+            try:
+                inputs = self.tokenizer(
+                    batch_texts,
+                    truncation=True,
+                    padding=True,
+                    max_length=512,
+                    return_tensors="pt",
+                ).to(self.device)
+
+                with torch.no_grad():
+                    outputs = self.model(**inputs)
+
+                attention_mask = inputs["attention_mask"]
+                chunk_embeddings = self._mean_pooling(
+                    outputs.last_hidden_state, attention_mask
+                )
+
+                # L2 normalization: critical for compatibility with the knowledge base
+                normalized_embeddings = F.normalize(chunk_embeddings, p=2, dim=1)
+
+                all_embeddings.extend(normalized_embeddings.cpu().numpy().tolist())
+
+            except Exception as e:
+                logger().error(f"Error during embedding generation for a batch: {e}")
+                # Add empty embeddings for the failed batch
+                all_embeddings.extend([[] for _ in batch_texts])
+
+        return all_embeddings
 
     def is_available(self) -> bool:
-        """Checks if the proper embedding model is available."""
-        return not self._fallback_mode
+        """Checks if the embedding model was loaded successfully."""
+        return self._available
 
     def semantic_search(
         self,
         query: str,
        candidates: list[str],
         top_k: int = settings.SEMANTIC_CONTEXT_SIZE,
-        threshold: float = settings.SIMILARITY_THRESHOLD
+        threshold: float = settings.SIMILARITY_THRESHOLD,
     ) -> list[str]:
         """Finds semantically similar texts using embedding-based search."""
-        if not candidates:
+        if not self.is_available() or not candidates:
             return []
 
         query_vector = np.array(self.embed(query)[0], dtype="float32")
-        candidate_vectors = self.embed([s.strip() for s in candidates])
+        if query_vector.size == 0:
+            return []
+
+        candidate_vectors = self.embed(candidates)
 
         similarities = [
-            (self._cosine_similarity(query_vector, np.array(vec, dtype="float32")), text)
-            for vec, text in zip(candidate_vectors, candidates)
+            (
+                self._cosine_similarity(query_vector, np.array(vec, dtype="float32")),
+                text,
+            )
+            for vec, text in zip(candidate_vectors, candidates) if vec
         ]
 
         similarities.sort(key=lambda x: x[0], reverse=True)
         return [text for score, text in similarities[:top_k] if score > threshold]
 
-    def similarity(self, text1: str, text2: str) -> float:
-        """Calculate cosine similarity between two texts."""
-        emb1 = self.embed([text1])[0]
-        emb2 = self.embed([text2])[0]
-
-        # Convert to numpy arrays
-        emb1_np = np.array(emb1)
-        emb2_np = np.array(emb2)
-
-        return self._cosine_similarity(emb1_np, emb2_np)
-
-    def batch_similarity(self, query: str, candidates: list[str]) -> list[float]:
-        """Calculate similarity between a query and multiple candidate texts."""
-        query_emb = self.embed([query])[0]
-        candidate_embs = self.embed(candidates)
-
-        similarities = []
-        query_emb_np = np.array(query_emb)
-
-        for candidate_emb in candidate_embs:
-            candidate_emb_np = np.array(candidate_emb)
-            similarities.append(self._cosine_similarity(query_emb_np, candidate_emb_np))
-
-        return similarities
-
     def get_model_info(self) -> dict:
-        """Get information about the current embedding model"""
+        """Get information about the current embedding model."""
         return {
             "model_name": self.model_name,
             "dimension": self.dimension,
-            "fallback_mode": self._fallback_mode,
-            "available": self.is_available()
+            "device": str(self.device),
+            "available": self.is_available(),
         }
 
     @staticmethod
-    def _cosine_similarity(vec_a: NDArray[np.float32], vec_b: NDArray[np.float32]) -> float:
+    def _cosine_similarity(
+        vec_a: NDArray[np.float32], vec_b: NDArray[np.float32]
+    ) -> float:
         """Calculates the cosine similarity between two vectors."""
         norm_a = np.linalg.norm(vec_a)
         norm_b = np.linalg.norm(vec_b)
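
A hypothetical end-to-end check of the reworked client, using only methods visible in this commit (model download and logger setup assumed already done):

from src.utils.embeddings import EmbeddingClient

# Note: the download script pins the namespaced hub id used here, while
# settings.EMBEDDING_MODEL_NAME stores the bare name -- use whichever
# resolves in your environment (hub id or local cache path).
client = EmbeddingClient(model_name="abhinand/MedEmbed-large-v0.1")

if client.is_available():
    vectors = client.embed(["chest pain", "myocardial infarction"])
    print(len(vectors), len(vectors[0]))  # 2 x model dimension

    hits = client.semantic_search(
        query="heart attack symptoms",
        candidates=["chest pain and shortness of breath", "broken arm"],
        top_k=1,
    )
    print(hits)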