Commit 538d769 (verified) · Parent: 64fcde1
technophyle committed

Sync from GitHub via hub-sync

Files changed (4):
  1. README.md +1 -1
  2. requirements.txt +1 -0
  3. src/embeddings.py +79 -1
  4. src/rag_system.py +41 -1
README.md CHANGED
@@ -16,6 +16,6 @@ Behavior:
  - Clones a public GitHub repo
  - Chunks it with tree-sitter
  - Builds retrieval state with a Qdrant adapter
- - Answers questions with Groq-hosted Llama or Vertex AI Gemini depending on environment configuration
+ - Answers questions with Groq-hosted Llama, AWS Bedrock, or Vertex AI Gemini depending on environment configuration
  - Deletes the cloned repo after indexing
  - Keeps only lightweight repo metadata in SQLite
requirements.txt CHANGED
@@ -5,6 +5,7 @@ pydantic==2.6.1
  python-dotenv==1.0.1

  openai==1.109.1
+ boto3==1.40.58
  google-genai==1.12.1
  httpx==0.28.1
  numpy==1.26.4
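The Bedrock paths added in this commit are configured entirely through environment variables, with credentials resolved by boto3's standard chain (AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY, shared config files, or an attached IAM role). A minimal sketch of a configuration that routes both embeddings and LLM calls through Bedrock; the values are illustrative, but the variable names are the ones read by the code below:

import os

# Illustrative values; these are the variables the new code reads.
os.environ["LLM_PROVIDER"] = "bedrock"          # also the new default in rag_system.py
os.environ["AWS_REGION"] = "us-east-1"          # falls back to AWS_DEFAULT_REGION, then "us-east-1"
os.environ["BEDROCK_EMBEDDING_MODEL"] = "amazon.titan-embed-text-v2:0"
os.environ["BEDROCK_EMBEDDING_DIM"] = "1024"    # Titan v2 vectors: 256, 512, or 1024 dimensions
os.environ["BEDROCK_LLM_MODEL"] = "us.meta.llama3-3-70b-instruct-v1:0"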
src/embeddings.py CHANGED
@@ -1,3 +1,4 @@
+ import json
  import os
  import time
  from typing import Callable, List, Optional
@@ -71,6 +72,21 @@ class EmbeddingGenerator:
                      str(self.vertex_output_dimensionality or 3072),
                  )
              )
+         elif self.provider == "bedrock":
+             print(
+                 f"[embeddings] Initializing AWS Bedrock embeddings with model={self.model_name}",
+                 flush=True,
+             )
+             try:
+                 import boto3
+             except ImportError as exc:
+                 raise RuntimeError(
+                     "AWS Bedrock embedding support requires the `boto3` package."
+                 ) from exc
+
+             region = os.getenv("AWS_REGION", os.getenv("AWS_DEFAULT_REGION", "us-east-1"))
+             self.client = boto3.client("bedrock-runtime", region_name=region)
+             self.embedding_dim = int(os.getenv("BEDROCK_EMBEDDING_DIM", "1024"))
          else:
              model_device = self.device or "cpu"
              print(
@@ -98,6 +114,8 @@
                  [text],
                  task_type=self.vertex_task_type_query,
              )[0]
+         if self.provider == "bedrock":
+             return self._embed_with_bedrock(text)
          query_text = f"{self.query_prefix}: {text}" if self.query_prefix else text
          return self._encode_with_backoff([query_text], prompt_name=self.query_prompt_name)[0]

@@ -125,6 +143,12 @@
                  batch_size=batch_size,
                  progress_callback=progress_callback,
              )
+         if self.provider == "bedrock":
+             return self._embed_batch_with_bedrock(
+                 texts=texts,
+                 batch_size=batch_size,
+                 progress_callback=progress_callback,
+             )

          effective_batch_size = max(1, batch_size or self.batch_size)
          all_embeddings = []
@@ -222,6 +246,58 @@

          return np.array(values, dtype="float32")

+     def _embed_batch_with_bedrock(
+         self,
+         texts: List[str],
+         batch_size: int = None,
+         progress_callback: Optional[Callable[[int, int], None]] = None,
+     ) -> np.ndarray:
+         effective_batch_size = max(1, batch_size or self.batch_size)
+         all_embeddings = []
+         total = len(texts)
+
+         for start in range(0, total, effective_batch_size):
+             batch = texts[start : start + effective_batch_size]
+             batch_number = (start // effective_batch_size) + 1
+             total_batches = (total + effective_batch_size - 1) // effective_batch_size
+             print(
+                 f"[embeddings] Bedrock batch {batch_number}/{total_batches} "
+                 f"items={len(batch)} progress={start}/{total}",
+                 flush=True,
+             )
+             started_at = time.perf_counter()
+             batch_embeddings = [self._embed_with_bedrock(text) for text in batch]
+             all_embeddings.append(np.vstack(batch_embeddings))
+             elapsed = time.perf_counter() - started_at
+             print(
+                 f"[embeddings] Finished Bedrock batch {batch_number}/{total_batches} "
+                 f"elapsed={elapsed:.2f}s progress={min(start + len(batch), total)}/{total}",
+                 flush=True,
+             )
+             if progress_callback:
+                 progress_callback(min(start + len(batch), total), total)
+
+         return np.vstack(all_embeddings).astype("float32")
+
+     def _embed_with_bedrock(self, text: str) -> np.ndarray:
+         payload = {"inputText": text, "normalize": True}
+         if self.embedding_dim in {256, 512, 1024}:
+             payload["dimensions"] = self.embedding_dim
+
+         response = self.client.invoke_model(
+             modelId=self.model_name,
+             body=json.dumps(payload),
+             accept="application/json",
+             contentType="application/json",
+         )
+         body = json.loads(response["body"].read())
+         values = body.get("embedding")
+         if values is None:
+             values = (body.get("embeddingsByType") or {}).get("float")
+         if not values:
+             raise RuntimeError("AWS Bedrock embeddings returned an empty response.")
+         return np.array(values, dtype="float32")
+
      def _encode_with_backoff(
          self,
          texts: List[str],
@@ -266,12 +342,14 @@
              return configured_provider
          if self._is_hf_space() or self._is_test_context():
              return "local"
-         return "vertex_ai"
+         return "bedrock"

      def _resolve_model_name(self) -> str:
          explicit_model = os.getenv("EMBEDDING_MODEL")
          if explicit_model:
              return explicit_model
+         if self.provider == "bedrock":
+             return os.getenv("BEDROCK_EMBEDDING_MODEL", "amazon.titan-embed-text-v2:0")
          if self.provider == "vertex_ai":
              return os.getenv("VERTEX_EMBEDDING_MODEL", "gemini-embedding-001")
          if self._is_hf_space() or self._is_test_context():
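Note that _embed_batch_with_bedrock issues one invoke_model call per text, so the batch size only controls logging and progress granularity, not request fan-in. As a reference for the request/response shape _embed_with_bedrock relies on, here is a self-contained sketch of the same Titan Text Embeddings V2 call; the region, model ID, and input are illustrative, and the embedding / embeddingsByType fallback mirrors the handling above:

import json

import boto3
import numpy as np

# Region and model ID are illustrative; credentials come from boto3's default chain.
client = boto3.client("bedrock-runtime", region_name="us-east-1")

# Titan Text Embeddings V2 request; "dimensions" accepts 256, 512, or 1024.
payload = {"inputText": "def add(a, b): return a + b", "normalize": True, "dimensions": 1024}
response = client.invoke_model(
    modelId="amazon.titan-embed-text-v2:0",
    body=json.dumps(payload),
    accept="application/json",
    contentType="application/json",
)
body = json.loads(response["body"].read())

# The vector normally arrives under "embedding"; when explicit embedding types are
# requested it can arrive under "embeddingsByType", hence the fallback in the diff.
values = body.get("embedding") or (body.get("embeddingsByType") or {}).get("float")
vector = np.array(values, dtype="float32")
print(vector.shape)  # (1024,)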
src/rag_system.py CHANGED
@@ -45,7 +45,7 @@ class CodebaseRAGSystem:
              )
          )
          self.app_env = os.getenv("APP_ENV", os.getenv("ENVIRONMENT", "local")).lower()
-         self.llm_provider = os.getenv("LLM_PROVIDER", "vertex_ai").lower()
+         self.llm_provider = os.getenv("LLM_PROVIDER", "bedrock").lower()
          self.llm_client = None
          self.llm_model = ""
          self._configure_llm()
@@ -534,6 +534,21 @@ Do not leave the answer unfinished.
              self.llm_model = os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
              return

+         if self.llm_provider == "bedrock":
+             try:
+                 import boto3
+             except ImportError as exc:
+                 raise RuntimeError(
+                     "AWS Bedrock LLM support requires the `boto3` package."
+                 ) from exc
+
+             region = os.getenv("AWS_REGION", os.getenv("AWS_DEFAULT_REGION", "us-east-1"))
+             self.llm_client = boto3.client("bedrock-runtime", region_name=region)
+             self.llm_model = os.getenv(
+                 "BEDROCK_LLM_MODEL", "us.meta.llama3-3-70b-instruct-v1:0"
+             )
+             return
+
          if self.llm_provider == "vertex_ai":
              try:
                  from google import genai
@@ -575,6 +590,31 @@ Do not leave the answer unfinished.
              finish_reason = getattr(response.choices[0], "finish_reason", "") or ""
              return self._normalize_markdown_answer(content), str(finish_reason)

+         if self.llm_provider == "bedrock":
+             response = self.llm_client.converse(
+                 modelId=self.llm_model,
+                 system=[{"text": system_prompt.strip()}],
+                 messages=[
+                     {
+                         "role": "user",
+                         "content": [{"text": user_prompt.strip()}],
+                     }
+                 ],
+                 inferenceConfig={
+                     "temperature": 0.1,
+                     "maxTokens": 2200,
+                 },
+             )
+             output_message = (response.get("output") or {}).get("message") or {}
+             content_blocks = output_message.get("content") or []
+             text = "".join(
+                 block.get("text", "") for block in content_blocks if isinstance(block, dict)
+             )
+             if not text.strip():
+                 raise RuntimeError("AWS Bedrock returned an empty response.")
+             stop_reason = response.get("stopReason", "") or ""
+             return self._normalize_markdown_answer(text), str(stop_reason)
+
          response = self.llm_client.models.generate_content(
              model=self.llm_model,
              contents=f"{system_prompt.strip()}\n\n{user_prompt.strip()}",
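For reference, a self-contained sketch of the Converse call the new Bedrock branch makes, showing the two response fields it parses (output.message.content[].text and stopReason); the region, model ID, and prompts are illustrative:

import boto3

# Region is illustrative; the model ID is the commit's default cross-region inference profile.
client = boto3.client("bedrock-runtime", region_name="us-east-1")

response = client.converse(
    modelId="us.meta.llama3-3-70b-instruct-v1:0",
    system=[{"text": "You answer questions about a codebase."}],
    messages=[{"role": "user", "content": [{"text": "What does src/embeddings.py do?"}]}],
    inferenceConfig={"temperature": 0.1, "maxTokens": 2200},
)

# Converse returns the assistant turn as a list of content blocks; text blocks carry "text".
message = response["output"]["message"]
answer = "".join(block.get("text", "") for block in message["content"])

# stopReason is "end_turn" on a clean finish or "max_tokens" when the cap was hit.
print(response["stopReason"], answer[:80])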