Sync from GitHub via hub-sync
- README.md +1 -1
- evals/run_eval.py +234 -63
- evals/sample_eval_set.json +53 -0
- requirements.txt +1 -0
- src/bedrock_claude.py +51 -0
- src/embeddings.py +100 -69
- src/rag_system.py +48 -37
README.md
CHANGED

@@ -16,6 +16,6 @@ Behavior:
 - Clones a public GitHub repo
 - Chunks it with tree-sitter
 - Builds retrieval state with a Qdrant adapter
-- Answers questions with Groq-hosted Llama
+- Answers questions with Groq-hosted Llama or Amazon Bedrock Claude depending on environment configuration
 - Deletes the cloned repo after indexing
 - Keeps only lightweight repo metadata in SQLite
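The "environment configuration" this bullet refers to is a handful of variables read at startup. A hedged sketch of the provider switch (variable names come from the code changes below; the values are examples only):

```python
import os

# Bedrock Claude, the default provider in this commit:
os.environ["LLM_PROVIDER"] = "bedrock"
os.environ["BEDROCK_LLM_MODEL"] = "anthropic.claude-sonnet-4-20250514-v1:0"
os.environ["AWS_REGION"] = "us-east-1"  # example region

# Or Groq-hosted Llama:
# os.environ["LLM_PROVIDER"] = "groq"
# os.environ["GROQ_MODEL"] = "llama-3.3-70b-versatile"

# Or Claude on Vertex AI (needs anthropic[vertex] and a GCP project):
# os.environ["LLM_PROVIDER"] = "vertex_ai"
# os.environ["GOOGLE_CLOUD_PROJECT"] = "my-project"  # hypothetical project id
# os.environ["VERTEX_LLM_MODEL"] = "claude-sonnet-4@20250514"
```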
evals/run_eval.py
CHANGED

@@ -3,6 +3,7 @@ import os
 import sys
 import asyncio
 import re
+import time
 from pathlib import Path
 from collections import Counter, defaultdict
 from statistics import mean

@@ -16,6 +17,7 @@ if str(SERVER_ROOT) not in sys.path:
 
 load_dotenv(SERVER_ROOT / ".env")
 
+from src.bedrock_claude import create_bedrock_runtime_client, generate_bedrock_claude_text
 from src.embeddings import EmbeddingGenerator
 
 
@@ -24,6 +26,8 @@ REPO_ID = int(os.getenv("CODEBASE_RAG_REPO_ID", "1"))
 SESSION_ID = os.getenv("CODEBASE_RAG_SESSION_ID", "eval-session")
 TOP_K = int(os.getenv("CODEBASE_RAG_TOP_K", "8"))
 QUERY_TIMEOUT_SECONDS = int(os.getenv("CODEBASE_RAG_QUERY_TIMEOUT_SECONDS", "180"))
+QUERY_MAX_RETRIES = int(os.getenv("CODEBASE_RAG_QUERY_MAX_RETRIES", "5"))
+QUERY_RETRY_BASE_SECONDS = float(os.getenv("CODEBASE_RAG_QUERY_RETRY_BASE_SECONDS", "2"))
 ENABLE_RAGAS = os.getenv("CODEBASE_RAG_ENABLE_RAGAS", "1").lower() not in {"0", "false", "no"}
 RAGAS_ASYNC = os.getenv("CODEBASE_RAG_RAGAS_ASYNC", "0").lower() in {"1", "true", "yes"}
 RAGAS_RAISE_EXCEPTIONS = os.getenv("CODEBASE_RAG_RAGAS_RAISE_EXCEPTIONS", "0").lower() in {

@@ -31,6 +35,8 @@ RAGAS_RAISE_EXCEPTIONS = os.getenv("CODEBASE_RAG_RAGAS_RAISE_EXCEPTIONS", "0").l
     "true",
     "yes",
 }
+MIN_REFERENCE_OVERLAP = float(os.getenv("CODEBASE_RAG_MIN_REFERENCE_OVERLAP", "0.2"))
+MIN_REFERENCE_TERM_MATCHES = int(os.getenv("CODEBASE_RAG_MIN_REFERENCE_TERM_MATCHES", "2"))
 EVAL_SET_PATH = Path(
     os.getenv(
         "CODEBASE_RAG_EVAL_SET",

@@ -43,6 +49,47 @@ def log(message: str):
     print(f"[eval] {message}", file=sys.stderr, flush=True)
 
 
+def get_app_model_config():
+    llm_provider = os.getenv("LLM_PROVIDER", "bedrock").lower()
+    if llm_provider == "groq":
+        llm_model = os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
+    elif llm_provider == "bedrock":
+        llm_model = os.getenv(
+            "BEDROCK_LLM_MODEL",
+            "anthropic.claude-sonnet-4-20250514-v1:0",
+        )
+    elif llm_provider == "vertex_ai":
+        llm_model = os.getenv("VERTEX_LLM_MODEL", "claude-sonnet-4@20250514")
+    else:
+        llm_model = "unknown"
+
+    embedding_provider = os.getenv("EMBEDDING_PROVIDER", "auto").lower()
+    if embedding_provider == "bedrock":
+        embedding_model = os.getenv("BEDROCK_EMBEDDING_MODEL", "cohere.embed-v4:0")
+    elif embedding_provider == "vertex_ai":
+        embedding_model = os.getenv("VERTEX_EMBEDDING_MODEL", "gemini-embedding-001")
+    elif embedding_provider == "openai":
+        embedding_model = os.getenv("EMBEDDING_MODEL", "text-embedding-3-small")
+    elif embedding_provider == "local":
+        embedding_model = os.getenv("EMBEDDING_MODEL") or os.getenv(
+            "LOCAL_EMBEDDING_MODEL", "nomic-ai/CodeRankEmbed"
+        )
+    else:
+        embedding_model = os.getenv("EMBEDDING_MODEL") or "auto"
+
+    eval_model = os.getenv(
+        "EVAL_MODEL",
+        os.getenv("BEDROCK_EVAL_MODEL", "anthropic.claude-opus-4-20250514-v1:0"),
+    )
+    return {
+        "llm_provider": llm_provider,
+        "llm_model": llm_model,
+        "embedding_provider": embedding_provider,
+        "embedding_model": embedding_model,
+        "eval_model": eval_model,
+    }
+
+
 def load_eval_rows():
     return json.loads(EVAL_SET_PATH.read_text())
 

@@ -54,24 +101,51 @@ def post_query(row):
         "top_k": TOP_K,
         "history": row.get("turns", []),
     }
+    case_id = row.get("id", row["question"])
+
+    for attempt in range(1, QUERY_MAX_RETRIES + 1):
+        response = requests.post(
+            f"{API_URL}/api/query",
+            json=payload,
+            headers={"X-Session-Id": SESSION_ID},
+            timeout=QUERY_TIMEOUT_SECONDS,
+        )
+        if response.ok:
+            return response.json()
+
         detail = response.text
         try:
             parsed = response.json()
             detail = parsed.get("detail") or parsed
         except Exception:
             pass
+
+        detail_text = str(detail)
+        is_retryable = response.status_code in {429, 500, 502, 503, 504} and any(
+            marker in detail_text
+            for marker in [
+                "ThrottlingException",
+                "Too many requests",
+                "timed out",
+                "timeout",
+                "ServiceUnavailable",
+            ]
+        )
+        if is_retryable and attempt < QUERY_MAX_RETRIES:
+            wait_seconds = QUERY_RETRY_BASE_SECONDS * (2 ** (attempt - 1))
+            log(
+                f"Retrying case {case_id} after transient query failure "
+                f"(attempt {attempt}/{QUERY_MAX_RETRIES}, wait={wait_seconds:.1f}s): {detail_text}"
+            )
+            time.sleep(wait_seconds)
+            continue
+
         raise RuntimeError(
-            f"Query failed for eval case {
+            f"Query failed for eval case {case_id!r} "
            f"with status {response.status_code}: {detail}"
         )
+
+    raise RuntimeError(f"Query failed for eval case {case_id!r}: exhausted retries")
 
 
 def normalize_path(path: str) -> str:

@@ -115,6 +189,18 @@ def tokenize_text(text: str):
     return re.findall(r"[a-z0-9_./+-]+", (text or "").lower())
 
 
+def normalize_keywords(keywords):
+    normalized = []
+    seen = set()
+    for keyword in keywords or []:
+        phrase = " ".join(tokenize_text(str(keyword)))
+        if not phrase or phrase in seen:
+            continue
+        seen.add(phrase)
+        normalized.append(phrase)
+    return normalized
+
+
 def compute_retrieval_metrics(expected_sources, actual_sources):
     expected = {normalize_path(path) for path in expected_sources}
     actual = [normalize_path(path) for path in actual_sources]

@@ -165,24 +251,52 @@ def compute_retrieval_metrics(expected_sources, actual_sources):
     }
 
 
-def 
-    keywords = 
+def keyword_match_details(row, answer: str):
+    keywords = normalize_keywords(row.get("must_include_any", []))
     if not keywords:
         return None
+
+    answer_tokens = tokenize_text(answer)
+    if not answer_tokens:
+        return {
+            "coverage": 0.0,
+            "matched_count": 0,
+            "total_keywords": len(keywords),
+            "matched_keywords": [],
+            "missing_keywords": keywords,
+        }
+
+    matched_keywords = []
+    for keyword in keywords:
+        keyword_tokens = keyword.split()
+        window = len(keyword_tokens)
+        if window == 1:
+            if keyword_tokens[0] in answer_tokens:
+                matched_keywords.append(keyword)
+            continue
+
+        for index in range(0, len(answer_tokens) - window + 1):
+            if answer_tokens[index : index + window] == keyword_tokens:
+                matched_keywords.append(keyword)
+                break
+
+    matched_set = set(matched_keywords)
+    missing_keywords = [keyword for keyword in keywords if keyword not in matched_set]
+    matched_count = len(matched_set)
+    return {
+        "coverage": matched_count / len(keywords),
+        "matched_count": matched_count,
+        "total_keywords": len(keywords),
+        "matched_keywords": sorted(matched_set),
+        "missing_keywords": missing_keywords,
+    }
 
 
-def keyword_pass(row, 
-    if 
+def keyword_pass(row, keyword_details):
+    if keyword_details is None:
         return None
     minimum = int(row.get("min_keyword_matches", 1))
-    if not keywords:
-        return None
-    matched = round(coverage * len(keywords))
-    return 1 if matched >= minimum else 0
+    return 1 if keyword_details["matched_count"] >= minimum else 0
 
 
 def answer_length_metrics(answer: str):

@@ -193,7 +307,7 @@ def answer_length_metrics(answer: str):
     }
 
 
-def lexical_overlap_ratio(reference: str, candidate: str):
+def reference_support_details(reference: str, candidate: str):
     reference_terms = {
         token for token in tokenize_text(reference)
         if len(token) > 2 and token not in STOPWORDS

@@ -201,8 +315,23 @@ def lexical_overlap_ratio(reference: str, candidate: str):
     if not reference_terms:
         return None
     candidate_terms = set(tokenize_text(candidate))
+    matched_terms = sorted(token for token in reference_terms if token in candidate_terms)
+    matched_count = len(matched_terms)
+    return {
+        "ratio": matched_count / len(reference_terms),
+        "matched_count": matched_count,
+        "reference_term_count": len(reference_terms),
+        "matched_terms": matched_terms,
+    }
+
+
+def reference_support_pass(reference_details):
+    if reference_details is None:
+        return None
+    return 1 if (
+        reference_details["ratio"] >= MIN_REFERENCE_OVERLAP
+        and reference_details["matched_count"] >= MIN_REFERENCE_TERM_MATCHES
+    ) else 0
 
 
 def validate_eval_rows(rows):

@@ -236,6 +365,13 @@ def validate_eval_rows(rows):
             errors.append(f"{row_id}: expected_sources must be a non-empty list")
         if must_include_any and not isinstance(must_include_any, list):
             errors.append(f"{row_id}: must_include_any must be a list when present")
+        if isinstance(must_include_any, list):
+            normalized_keywords = normalize_keywords(must_include_any)
+            if len(normalized_keywords) != len([keyword for keyword in must_include_any if str(keyword).strip()]):
+                warnings.append(
+                    f"{row_id}: duplicate or case-variant keywords were normalized; "
+                    "resume metrics are stricter than the raw checklist wording."
+                )
         if row.get("turns"):
             conversation_cases += 1
         expected_source_counts.append(len(expected_sources) if isinstance(expected_sources, list) else 0)

@@ -279,12 +415,16 @@ def validate_eval_rows(rows):
 def summarize_custom_metrics(details):
     keyword_coverages = [item["keyword_coverage"] for item in details if item["keyword_coverage"] is not None]
     keyword_passes = [item["keyword_pass"] for item in details if item["keyword_pass"] is not None]
+    reference_support_passes = [
+        item["reference_support_pass"] for item in details if item["reference_support_pass"] is not None
+    ]
     grounded_answer_passes = [
         1
         for item in details
         if item["retrieval_hit"] == 1
         and item["has_substantive_answer"] == 1
         and (item["keyword_pass"] in {None, 1})
+        and (item["reference_support_pass"] in {None, 1})
     ]
     exact_source_recall_cases = [1 for item in details if item["source_recall"] == 1.0]
     return {

@@ -296,6 +436,7 @@ def summarize_custom_metrics(details):
         "duplicate_source_rate": round(mean(item["duplicate_source_rate"] for item in details), 4),
         "keyword_coverage": round(mean(keyword_coverages), 4) if keyword_coverages else None,
         "keyword_pass_rate": round(mean(keyword_passes), 4) if keyword_passes else None,
+        "reference_support_rate": round(mean(reference_support_passes), 4) if reference_support_passes else None,
         "ground_truth_lexical_overlap": round(
             mean(item["ground_truth_lexical_overlap"] for item in details if item["ground_truth_lexical_overlap"] is not None),
             4,

@@ -323,10 +464,23 @@ def summarize_by_category(details):
             "source_recall": round(mean(item["source_recall"] for item in items), 4),
             "mrr": round(mean(item["mrr"] for item in items), 4),
             "keyword_pass_rate": round(mean(keyword_passes), 4) if keyword_passes else None,
+            "reference_support_rate": round(
+                mean(
+                    item["reference_support_pass"]
+                    for item in items
+                    if item["reference_support_pass"] is not None
+                ),
+                4,
+            )
+            if any(item["reference_support_pass"] is not None for item in items)
+            else None,
             "grounded_answer_rate": round(
                 mean(
                     1
-                    if item["retrieval_hit"] == 1
+                    if item["retrieval_hit"] == 1
+                    and item["has_substantive_answer"] == 1
+                    and item["keyword_pass"] in {None, 1}
+                    and item["reference_support_pass"] in {None, 1}
                     else 0
                     for item in items
                 ),

@@ -346,6 +500,7 @@ def build_headline_metrics(custom_metrics, audit):
         "source_recall": custom_metrics["source_recall"],
         "grounded_answer_rate": custom_metrics["grounded_answer_rate"],
         "keyword_pass_rate": custom_metrics["keyword_pass_rate"],
+        "reference_support_rate": custom_metrics["reference_support_rate"],
     }
 
 
@@ -361,9 +516,14 @@ def build_resume_summary(custom_metrics, audit, ragas_report, ragas_error):
             f"source recall {custom_metrics['source_recall']:.1%}."
         ),
         (
-            f"
+            f"Strict answer quality checks: grounded answer rate {custom_metrics['grounded_answer_rate']:.1%}"
             + (
-                f", keyword/checklist pass rate {custom_metrics['keyword_pass_rate']:.1%}
+                f", keyword/checklist pass rate {custom_metrics['keyword_pass_rate']:.1%}"
+                + (
+                    f", reference-support pass rate {custom_metrics['reference_support_rate']:.1%}."
+                    if custom_metrics["reference_support_rate"] is not None
+                    else "."
+                )
                 if custom_metrics["keyword_pass_rate"] is not None
                 else "."
             )

@@ -424,13 +584,12 @@ def maybe_write_report(report):
 
 
 def build_bedrock_ragas_llm(run_config):
-    import boto3
     from langchain_core.outputs import Generation, LLMResult
     from ragas.llms.base import BaseRagasLLM
 
     class BedrockRagasLLM(BaseRagasLLM):
-        def __init__(self, model: str, 
-            self.client = 
+        def __init__(self, model: str, run_config):
+            self.client = create_bedrock_runtime_client()
             self.model = model
             self.set_run_config(run_config)
 
@@ -445,35 +604,19 @@ def build_bedrock_ragas_llm(run_config):
 
         def _generate_once(self, prompt, n=1, temperature=1e-8, stop=None, callbacks=None):
             prompt_text = self._prompt_to_text(prompt)
-            response = self.client.converse(
-                modelId=self.model,
-                messages=[
-                    {
-                        "role": "user",
-                        "content": [{"text": prompt_text}],
-                    }
-                ],
-                inferenceConfig=inference_config,
+            text, _ = generate_bedrock_claude_text(
+                self.client,
+                self.model,
+                "Return only valid JSON or the exact structured output requested.",
+                prompt_text,
+                max_tokens=int(os.getenv("EVAL_MAX_OUTPUT_TOKENS", "2048")),
+                temperature=0.0,
             )
 
-            generations = []
-            output_message = (response.get("output") or {}).get("message") or {}
-            content_blocks = output_message.get("content") or []
-            text = "".join(
-                block.get("text", "") for block in content_blocks if isinstance(block, dict)
-            ).strip()
-            if text:
-                generations.append(Generation(text=text))
+            generations = [Generation(text=text)] if text else []
 
             if not generations:
-                raise RuntimeError("
+                raise RuntimeError("Bedrock Claude judge returned an empty response.")
 
             return LLMResult(generations=[generations])
 
@@ -496,12 +639,11 @@ def build_bedrock_ragas_llm(run_config):
                 callbacks,
             )
 
-    region = os.getenv("AWS_REGION", os.getenv("AWS_DEFAULT_REGION", "us-east-1"))
     model = os.getenv(
         "EVAL_MODEL",
-        os.getenv("BEDROCK_EVAL_MODEL", "
+        os.getenv("BEDROCK_EVAL_MODEL", "anthropic.claude-opus-4-20250514-v1:0"),
     )
-    return BedrockRagasLLM(model=model, 
+    return BedrockRagasLLM(model=model, run_config=run_config)
 
 
 def build_ragas_embeddings(run_config):

@@ -568,8 +710,8 @@ def run_ragas(rows, outputs):
         max_wait=int(os.getenv("EVAL_MAX_WAIT_SECONDS", "60")),
     )
     log(
-        "Using 
-        f"({os.getenv('EVAL_MODEL', os.getenv('BEDROCK_EVAL_MODEL', '
+        "Using Bedrock for RAGAS judge model "
+        f"({os.getenv('EVAL_MODEL', os.getenv('BEDROCK_EVAL_MODEL', 'anthropic.claude-opus-4-20250514-v1:0'))})"
     )
     log(
         f"RAGAS runtime: async={RAGAS_ASYNC}, raise_exceptions={RAGAS_RAISE_EXCEPTIONS}, "

@@ -596,10 +738,19 @@ def run():
     log(f"Loading eval set from {EVAL_SET_PATH}")
     rows = load_eval_rows()
     audit = validate_eval_rows(rows)
+    model_config = get_app_model_config()
     if audit["errors"]:
         raise RuntimeError("Eval set validation failed: " + "; ".join(audit["errors"]))
     for warning in audit["warnings"]:
         log(f"Eval set warning: {warning}")
+    log(
+        "Eval model config: "
+        f"qna_provider={model_config['llm_provider']}, "
+        f"qna_model={model_config['llm_model']}, "
+        f"embedding_provider={model_config['embedding_provider']}, "
+        f"embedding_model={model_config['embedding_model']}, "
+        f"judge_model={model_config['eval_model']}"
+    )
     log(
         f"Starting eval with api_url={API_URL}, repo_id={REPO_ID}, "
         f"session_id={SESSION_ID}, top_k={TOP_K}, cases={len(rows)}"

@@ -619,10 +770,13 @@ def run():
 
         cited_paths = [source["file_path"] for source in result.get("sources", [])]
         metrics = compute_retrieval_metrics(row.get("expected_sources", []), cited_paths)
+        keyword_details = keyword_match_details(row, result.get("answer", ""))
+        keyword_coverage = keyword_details["coverage"] if keyword_details else None
+        keyword_gate = keyword_pass(row, keyword_details)
         length_metrics = answer_length_metrics(result.get("answer", ""))
+        reference_details = reference_support_details(row.get("ground_truth", ""), result.get("answer", ""))
+        overlap = reference_details["ratio"] if reference_details else None
+        reference_gate = reference_support_pass(reference_details)
 
         details.append(
             {

@@ -640,7 +794,15 @@ def run():
                 "duplicate_source_rate": metrics["duplicate_source_rate"],
                 "keyword_coverage": keyword_coverage,
                 "keyword_pass": keyword_gate,
+                "matched_keyword_count": keyword_details["matched_count"] if keyword_details else None,
+                "total_keywords": keyword_details["total_keywords"] if keyword_details else None,
+                "matched_keywords": keyword_details["matched_keywords"] if keyword_details else [],
+                "missing_keywords": keyword_details["missing_keywords"] if keyword_details else [],
                 "ground_truth_lexical_overlap": overlap,
+                "reference_support_pass": reference_gate,
+                "reference_term_match_count": reference_details["matched_count"] if reference_details else None,
+                "reference_term_count": reference_details["reference_term_count"] if reference_details else None,
+                "matched_reference_terms": reference_details["matched_terms"] if reference_details else [],
                 **length_metrics,
             }
         )

@@ -659,8 +821,17 @@ def run():
             "repo_id": REPO_ID,
             "session_id": SESSION_ID,
             "top_k": TOP_K,
+            "qna_provider": model_config["llm_provider"],
+            "qna_model": model_config["llm_model"],
+            "embedding_provider": model_config["embedding_provider"],
+            "embedding_model": model_config["embedding_model"],
+            "eval_model": model_config["eval_model"],
             "query_timeout_seconds": QUERY_TIMEOUT_SECONDS,
+            "query_max_retries": QUERY_MAX_RETRIES,
+            "query_retry_base_seconds": QUERY_RETRY_BASE_SECONDS,
             "eval_set": str(EVAL_SET_PATH),
+            "min_reference_overlap": MIN_REFERENCE_OVERLAP,
+            "min_reference_term_matches": MIN_REFERENCE_TERM_MATCHES,
         },
         "eval_set_audit": audit,
         "headline_metrics": headline_metrics,
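The keyword and reference gates added above are deterministic, so they can be probed in isolation. A small usage sketch with a made-up row and answer, assuming `evals/run_eval.py` imports cleanly with its default environment settings:

```python
from evals.run_eval import (
    keyword_match_details,
    keyword_pass,
    reference_support_details,
    reference_support_pass,
)

# Hypothetical eval row and model answer, for illustration only.
row = {
    "must_include_any": ["create_engine", "Session", "select"],
    "min_keyword_matches": 2,
}
answer = "Call create_engine(...), open a Session, and query with select(User)."

details = keyword_match_details(row, answer)
print(details["matched_keywords"], details["coverage"])  # all three keywords match
print(keyword_pass(row, details))  # 1, since 3 >= min_keyword_matches

ground_truth = "Use create_engine, then open a Session and run select(User)."
support = reference_support_details(ground_truth, answer)
print(support["ratio"], reference_support_pass(support))  # gated by the MIN_REFERENCE_* knobs
```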
evals/sample_eval_set.json
CHANGED

@@ -669,5 +669,58 @@
       "FastAPI"
     ],
     "min_keyword_matches": 2
+  },
+  {
+    "id": "sqlmodel-sa-column-conflict-error",
+    "category": "error-handling",
+    "question": "What happens if you pass both sa_column and other Field options like primary_key or index in SQLModel?",
+    "ground_truth": "SQLModel raises a ValueError when sa_column is combined with other field-level options like primary_key, index, or foreign_key because sa_column is meant to be a fully self-contained SQLAlchemy column definition and mixing it with SQLModel field shortcuts creates an ambiguous configuration.",
+    "expected_sources": [
+      "sqlmodel/main.py"
+    ],
+    "must_include_any": [
+      "sa_column",
+      "primary_key",
+      "raise",
+      "conflict"
+    ],
+    "min_keyword_matches": 3
+  },
+  {
+    "id": "sqlmodel-codegen-basic-model",
+    "category": "code-generation",
+    "question": "Write a SQLModel table model for a User with an integer primary key, a required name string, and an optional email string",
+    "ground_truth": "A correct answer defines a class inheriting from SQLModel with table=True, uses Field(primary_key=True) on an integer id, declares name as a required str, and declares email as Optional[str] with a default of None.",
+    "expected_sources": [
+      "sqlmodel/main.py",
+      "README.md"
+    ],
+    "must_include_any": [
+      "SQLModel",
+      "Field",
+      "table=True",
+      "primary_key",
+      "Optional"
+    ],
+    "min_keyword_matches": 3
+  },
+  {
+    "id": "sqlmodel-codegen-session-query",
+    "category": "code-generation",
+    "question": "Write a SQLModel example that creates an engine, opens a session, inserts a User row, and queries all users",
+    "ground_truth": "A correct answer uses create_engine to set up the database, SQLModel.metadata.create_all to create tables, opens a Session using a context manager, adds and commits a User instance, then uses select(User) with session.exec to retrieve all rows.",
+    "expected_sources": [
+      "README.md",
+      "sqlmodel/__init__.py",
+      "sqlmodel/orm/session.py"
+    ],
+    "must_include_any": [
+      "create_engine",
+      "Session",
+      "select",
+      "exec",
+      "commit"
+    ],
+    "min_keyword_matches": 4
   }
 ]
requirements.txt
CHANGED

@@ -6,6 +6,7 @@ python-dotenv==1.0.1
 
 openai==1.109.1
 boto3==1.40.58
+anthropic[vertex]==0.73.0
 google-genai==1.12.1
 httpx==0.28.1
 numpy==1.26.4
src/bedrock_claude.py
ADDED

@@ -0,0 +1,51 @@
+import os
+from typing import Optional, Tuple
+
+
+def create_bedrock_runtime_client():
+    try:
+        import boto3
+    except ImportError as exc:
+        raise RuntimeError("Bedrock Claude support requires the `boto3` package.") from exc
+
+    return boto3.client(
+        "bedrock-runtime",
+        region_name=os.getenv("AWS_REGION", os.getenv("AWS_DEFAULT_REGION", "us-east-1")),
+    )
+
+
+def generate_bedrock_claude_text(
+    client,
+    model: str,
+    system_prompt: str,
+    user_prompt: str,
+    *,
+    max_tokens: int,
+    temperature: float,
+    top_p: Optional[float] = None,
+) -> Tuple[str, str]:
+    inference_config = {
+        "maxTokens": max_tokens,
+        "temperature": temperature,
+    }
+    if top_p is not None:
+        inference_config["topP"] = top_p
+
+    response = client.converse(
+        modelId=model,
+        system=[{"text": system_prompt.strip()}],
+        messages=[
+            {
+                "role": "user",
+                "content": [{"text": user_prompt.strip()}],
+            }
+        ],
+        inferenceConfig=inference_config,
+    )
+
+    content_blocks = (((response or {}).get("output") or {}).get("message") or {}).get("content") or []
+    text = "".join(block.get("text", "") for block in content_blocks if block.get("text")).strip()
+    if not text:
+        raise RuntimeError("Bedrock Claude returned an empty response.")
+
+    return text, str((response or {}).get("stopReason", "") or "")
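A minimal usage sketch of the new helper; it assumes AWS credentials with Bedrock access are configured and that the model ID is enabled in the target region. The prompts are illustrative.

```python
from src.bedrock_claude import create_bedrock_runtime_client, generate_bedrock_claude_text

client = create_bedrock_runtime_client()  # region comes from AWS_REGION / AWS_DEFAULT_REGION
text, stop_reason = generate_bedrock_claude_text(
    client,
    "anthropic.claude-sonnet-4-20250514-v1:0",
    "You are a concise assistant.",
    "Summarize what a vector database does in one sentence.",
    max_tokens=256,
    temperature=0.1,
)
print(stop_reason, text)
```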
src/embeddings.py
CHANGED

@@ -1,11 +1,10 @@
-import json
 import os
 import time
+import json
 from typing import Callable, List, Optional
 
 import numpy as np
 from openai import OpenAI
-from sentence_transformers import SentenceTransformer
 
 
 class EmbeddingGenerator:

@@ -17,6 +16,10 @@ class EmbeddingGenerator:
         self.device = os.getenv("EMBEDDING_DEVICE")
         self.client = None
         self.model = None
+        self.bedrock_client = None
+        self.bedrock_output_dimensionality = self._optional_int(
+            os.getenv("BEDROCK_EMBEDDING_OUTPUT_DIMENSIONALITY")
+        )
         self.vertex_task_type_document = os.getenv(
             "VERTEX_EMBEDDING_TASK_TYPE_DOCUMENT", "RETRIEVAL_DOCUMENT"
         )

@@ -42,6 +45,28 @@ class EmbeddingGenerator:
             )
             self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
             self.embedding_dim = int(os.getenv("OPENAI_EMBEDDING_DIM", "1536"))
+        elif self.provider == "bedrock":
+            print(
+                f"[embeddings] Initializing Bedrock embeddings with model={self.model_name}",
+                flush=True,
+            )
+            try:
+                import boto3
+            except ImportError as exc:
+                raise RuntimeError(
+                    "Bedrock embedding support requires the `boto3` package."
+                ) from exc
+
+            self.bedrock_client = boto3.client(
+                "bedrock-runtime",
+                region_name=os.getenv("AWS_REGION", os.getenv("AWS_DEFAULT_REGION", "us-east-1")),
+            )
+            self.embedding_dim = int(
+                os.getenv(
+                    "BEDROCK_EMBEDDING_DIM",
+                    str(self.bedrock_output_dimensionality or 1536),
+                )
+            )
         elif self.provider == "vertex_ai":
             print(
                 f"[embeddings] Initializing Vertex AI embeddings with model={self.model_name}",

@@ -55,7 +80,7 @@ class EmbeddingGenerator:
             ) from exc
 
             project = os.getenv("GOOGLE_CLOUD_PROJECT")
-            location = os.getenv("GOOGLE_CLOUD_LOCATION", "
+            location = os.getenv("GOOGLE_CLOUD_LOCATION", "global")
             if not project:
                 raise RuntimeError(
                     "GOOGLE_CLOUD_PROJECT must be set when using Vertex AI embeddings."

@@ -72,22 +97,13 @@ class EmbeddingGenerator:
                     str(self.vertex_output_dimensionality or 3072),
                 )
             )
-
-            print(
-                f"[embeddings] Initializing AWS Bedrock embeddings with model={self.model_name}",
-                flush=True,
-            )
+        else:
             try:
-                import 
+                from sentence_transformers import SentenceTransformer
             except ImportError as exc:
                 raise RuntimeError(
-                    "
+                    "Local embedding support requires the `sentence-transformers` package."
                 ) from exc
-
-            region = os.getenv("AWS_REGION", os.getenv("AWS_DEFAULT_REGION", "us-east-1"))
-            self.client = boto3.client("bedrock-runtime", region_name=region)
-            self.embedding_dim = int(os.getenv("BEDROCK_EMBEDDING_DIM", "1024"))
-        else:
             model_device = self.device or "cpu"
             print(
                 f"[embeddings] Loading local embedding model={self.model_name} on device={model_device}",

@@ -109,13 +125,16 @@ class EmbeddingGenerator:
     def embed_text(self, text: str) -> np.ndarray:
         if self.provider == "openai":
             return self.embed_batch([text])[0]
+        if self.provider == "bedrock":
+            return self._embed_with_bedrock(
+                [text],
+                input_type=os.getenv("BEDROCK_EMBEDDING_INPUT_TYPE_QUERY", "search_query"),
+            )[0]
         if self.provider == "vertex_ai":
             return self._embed_with_vertex(
                 [text],
                 task_type=self.vertex_task_type_query,
             )[0]
-        if self.provider == "bedrock":
-            return self._embed_with_bedrock(text)
         query_text = f"{self.query_prefix}: {text}" if self.query_prefix else text
         return self._encode_with_backoff([query_text], prompt_name=self.query_prompt_name)[0]
 
@@ -137,19 +156,18 @@ class EmbeddingGenerator:
             if progress_callback:
                 progress_callback(len(texts), len(texts))
             return np.array(embeddings, dtype="float32")
-        if self.provider == "
-            return self.
+        if self.provider == "bedrock":
+            return self._embed_batch_with_bedrock(
                 texts=texts,
                 batch_size=batch_size,
                 progress_callback=progress_callback,
             )
-        if self.provider == "
-            return self.
+        if self.provider == "vertex_ai":
+            return self._embed_batch_with_vertex(
                 texts=texts,
                 batch_size=batch_size,
                 progress_callback=progress_callback,
             )
         effective_batch_size = max(1, batch_size or self.batch_size)
         all_embeddings = []
         total = len(texts)

@@ -216,36 +234,6 @@ class EmbeddingGenerator:
 
         return np.vstack(all_embeddings).astype("float32")
 
-    def _embed_with_vertex(self, texts: List[str], task_type: str) -> np.ndarray:
-        config = {
-            "task_type": task_type,
-        }
-        if self.vertex_output_dimensionality:
-            config["output_dimensionality"] = self.vertex_output_dimensionality
-
-        response = self.client.models.embed_content(
-            model=self.model_name,
-            contents=texts,
-            config=config,
-        )
-        embeddings = getattr(response, "embeddings", None)
-        if not embeddings:
-            raise RuntimeError("Vertex AI embeddings returned an empty response.")
-
-        values = []
-        for item in embeddings:
-            if hasattr(item, "values"):
-                values.append(item.values)
-            elif isinstance(item, dict):
-                values.append(item.get("values"))
-            else:
-                values.append(getattr(item, "embedding", None))
-
-        if not values or any(vector is None for vector in values):
-            raise RuntimeError("Vertex AI embeddings response could not be parsed.")
-
-        return np.array(values, dtype="float32")
-
     def _embed_batch_with_bedrock(
         self,
         texts: List[str],

@@ -255,6 +243,7 @@ class EmbeddingGenerator:
         effective_batch_size = max(1, batch_size or self.batch_size)
         all_embeddings = []
         total = len(texts)
+        document_input_type = os.getenv("BEDROCK_EMBEDDING_INPUT_TYPE_DOCUMENT", "search_document")
 
         for start in range(0, total, effective_batch_size):
            batch = texts[start : start + effective_batch_size]

@@ -266,8 +255,11 @@ class EmbeddingGenerator:
                 flush=True,
             )
             started_at = time.perf_counter()
-            batch_embeddings = 
+            batch_embeddings = self._embed_with_bedrock(
+                batch,
+                input_type=document_input_type,
+            )
+            all_embeddings.append(batch_embeddings)
             elapsed = time.perf_counter() - started_at
             print(
                 f"[embeddings] Finished Bedrock batch {batch_number}/{total_batches} "

@@ -279,24 +271,63 @@ class EmbeddingGenerator:
 
         return np.vstack(all_embeddings).astype("float32")
 
+    def _embed_with_vertex(self, texts: List[str], task_type: str) -> np.ndarray:
+        config = {
+            "task_type": task_type,
+        }
+        if self.vertex_output_dimensionality:
+            config["output_dimensionality"] = self.vertex_output_dimensionality
+
+        response = self.client.models.embed_content(
+            model=self.model_name,
+            contents=texts,
+            config=config,
+        )
+        embeddings = getattr(response, "embeddings", None)
+        if not embeddings:
+            raise RuntimeError("Vertex AI embeddings returned an empty response.")
+
+        values = []
+        for item in embeddings:
+            if hasattr(item, "values"):
+                values.append(item.values)
+            elif isinstance(item, dict):
+                values.append(item.get("values"))
+            else:
+                values.append(getattr(item, "embedding", None))
+
+        if not values or any(vector is None for vector in values):
+            raise RuntimeError("Vertex AI embeddings response could not be parsed.")
+
+        return np.array(values, dtype="float32")
+
-    def 
-        response = self.client.
-            body=json.dumps(payload),
-            accept="application/json",
+    def _embed_with_bedrock(self, texts: List[str], input_type: str) -> np.ndarray:
+        response = self.bedrock_client.invoke_model(
             modelId=self.model_name,
             contentType="application/json",
+            accept="application/json",
+            body=json.dumps(self._build_bedrock_embedding_request(texts, input_type)),
         )
+        payload = json.loads(response["body"].read())
+        embeddings = payload.get("embeddings")
+
+        if isinstance(embeddings, dict):
+            embeddings = embeddings.get("float")
+
+        if not embeddings:
+            raise RuntimeError("Bedrock embeddings returned an empty response.")
+
+        return np.array(embeddings, dtype="float32")
+
+    def _build_bedrock_embedding_request(self, texts: List[str], input_type: str) -> dict:
+        payload = {
+            "texts": texts,
+            "input_type": input_type,
+            "embedding_types": ["float"],
+        }
+        if self.bedrock_output_dimensionality:
+            payload["output_dimension"] = self.bedrock_output_dimensionality
+        return payload
 
     def _encode_with_backoff(
         self,

@@ -349,7 +380,7 @@ class EmbeddingGenerator:
         if explicit_model:
             return explicit_model
         if self.provider == "bedrock":
-            return os.getenv("BEDROCK_EMBEDDING_MODEL", "
+            return os.getenv("BEDROCK_EMBEDDING_MODEL", "cohere.embed-v4:0")
         if self.provider == "vertex_ai":
             return os.getenv("VERTEX_EMBEDDING_MODEL", "gemini-embedding-001")
         if self._is_hf_space() or self._is_test_context():
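Driving the new Bedrock embedding path end to end should look roughly like this; a sketch under the assumption that `EmbeddingGenerator` takes no constructor arguments and reads its configuration from the environment, which the diff suggests but does not show.

```python
import os

# Illustrative configuration; the Cohere model must be enabled in your Bedrock account.
os.environ["EMBEDDING_PROVIDER"] = "bedrock"
os.environ["BEDROCK_EMBEDDING_MODEL"] = "cohere.embed-v4:0"

from src.embeddings import EmbeddingGenerator

generator = EmbeddingGenerator()  # assumed zero-argument constructor
query_vector = generator.embed_text("def main(): ...")  # sent with input_type="search_query"
doc_matrix = generator.embed_batch(["chunk one", "chunk two"])  # sent with "search_document"
print(query_vector.shape, doc_matrix.shape)
```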
src/rag_system.py
CHANGED

@@ -6,6 +6,7 @@ from typing import Dict, List, Optional
 from openai import OpenAI
 
 from src.code_parser import CodeParser
+from src.bedrock_claude import create_bedrock_runtime_client, generate_bedrock_claude_text
 from src.database import Repository, get_db_session, init_db, resolve_database_url
 from src.embeddings import EmbeddingGenerator
 from src.hybrid_search import HybridSearchEngine

@@ -526,6 +527,14 @@ Do not leave the answer unfinished.
     }
 
     def _configure_llm(self):
+        if self.llm_provider == "bedrock":
+            self.llm_client = create_bedrock_runtime_client()
+            self.llm_model = os.getenv(
+                "BEDROCK_LLM_MODEL",
+                "anthropic.claude-sonnet-4-20250514-v1:0",
+            )
+            return
+
         if self.llm_provider == "groq":
             self.llm_client = OpenAI(
                 api_key=os.getenv("GROQ_API_KEY"),

@@ -534,48 +543,53 @@ Do not leave the answer unfinished.
             self.llm_model = os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
             return
 
-        if self.llm_provider == "
+        if self.llm_provider == "vertex_ai":
+            project = os.getenv("GOOGLE_CLOUD_PROJECT")
+            location = os.getenv("GOOGLE_CLOUD_LOCATION", "global")
+            if not project:
                 raise RuntimeError(
+                    "GOOGLE_CLOUD_PROJECT must be set when using Vertex AI LLMs."
+                )
 
-            self.
+            self.llm_model = os.getenv("VERTEX_LLM_MODEL", "claude-sonnet-4@20250514")
+            if self.llm_model.startswith("claude-"):
+                try:
+                    from anthropic import AnthropicVertex
+                except ImportError as exc:
+                    raise RuntimeError(
+                        "Vertex AI Claude support requires the `anthropic[vertex]` package."
+                    ) from exc
+                self.llm_client = AnthropicVertex(project_id=project, region=location)
+                return
 
-        if self.llm_provider == "vertex_ai":
             try:
                 from google import genai
             except ImportError as exc:
                 raise RuntimeError(
-                    "Vertex AI 
-                    "Install server dependencies before running local or eval queries."
+                    "Vertex AI Gemini support requires the `google-genai` package."
                 ) from exc
 
-            project = os.getenv("GOOGLE_CLOUD_PROJECT")
-            location = os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1")
-            if not project:
-                raise RuntimeError(
-                    "GOOGLE_CLOUD_PROJECT must be set when using Vertex AI Gemini."
-                )
-
             self.llm_client = genai.Client(
                 vertexai=True,
                 project=project,
                 location=location,
             )
-            self.llm_model = os.getenv("VERTEX_LLM_MODEL", "gemini-2.5-pro")
             return
 
         raise RuntimeError(f"Unsupported LLM provider: {self.llm_provider}")
 
     def _generate_markdown_response(self, system_prompt: str, user_prompt: str) -> tuple[str, str]:
+        if self.llm_provider == "bedrock":
+            text, stop_reason = generate_bedrock_claude_text(
+                self.llm_client,
+                self.llm_model,
+                system_prompt,
+                user_prompt,
+                max_tokens=2200,
+                temperature=0.1,
+            )
+            return self._normalize_markdown_answer(text), stop_reason
+
         if self.llm_provider == "groq":
             response = self.llm_client.chat.completions.create(
                 model=self.llm_model,

@@ -590,29 +604,26 @@ Do not leave the answer unfinished.
             finish_reason = getattr(response.choices[0], "finish_reason", "") or ""
             return self._normalize_markdown_answer(content), str(finish_reason)
 
-        if self.llm_provider == "
-                system=
+        if self.llm_provider == "vertex_ai" and self.llm_model.startswith("claude-"):
+            message = self.llm_client.messages.create(
+                model=self.llm_model,
+                system=system_prompt.strip(),
+                max_tokens=2200,
+                temperature=0.1,
                 messages=[
                     {
                         "role": "user",
-                        "content": 
+                        "content": user_prompt.strip(),
                     }
                 ],
-                inferenceConfig={
-                    "temperature": 0.1,
-                    "maxTokens": 2200,
-                },
             )
-            content_blocks = output_message.get("content") or []
+            content_blocks = getattr(message, "content", None) or []
             text = "".join(
+                getattr(block, "text", "") for block in content_blocks if getattr(block, "text", "")
             )
             if not text.strip():
-                raise RuntimeError("
-            stop_reason = 
+                raise RuntimeError("Vertex AI Claude returned an empty response.")
+            stop_reason = getattr(message, "stop_reason", "") or ""
             return self._normalize_markdown_answer(text), str(stop_reason)
 
         response = self.llm_client.models.generate_content(