technophyle commited on
Commit
60b97da
·
verified ·
1 Parent(s): 9d09b0a

Sync from GitHub via hub-sync

Browse files
.dockerignore ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ *.pyo
4
+ *.pyd
5
+ .venv/
6
+ venv/
7
+ .env
8
+ .git/
9
+ .gitignore
10
+ *.db
11
+ faiss/
12
+ uploads/
13
+ temp_uploads/
14
+ data/
15
+ rag_system.db
16
+
Dockerfile ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ RUN apt-get update && apt-get install -y --no-install-recommends \
4
+ build-essential \
5
+ && rm -rf /var/lib/apt/lists/*
6
+
7
+ WORKDIR /app
8
+
9
+ COPY requirements.txt /app/requirements.txt
10
+ RUN pip install --no-cache-dir -r /app/requirements.txt
11
+
12
+ COPY . /app
13
+
14
+ ENV PYTHONUNBUFFERED=1
15
+ EXPOSE 7860
16
+
17
+ CMD ["uvicorn", "server_app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,10 +1,21 @@
1
  ---
2
- title: Code Compass
3
- emoji: 📈
4
  colorFrom: blue
5
- colorTo: gray
6
  sdk: docker
7
- pinned: false
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Code Compass API
3
+ emoji: 🚀
4
  colorFrom: blue
5
+ colorTo: indigo
6
  sdk: docker
7
+ app_port: 7860
8
  ---
9
 
10
+ # Code Compass Backend
11
+
12
+ FastAPI backend for a session-oriented GitHub repo QA tool.
13
+
14
+ Behavior:
15
+
16
+ - Clones a public GitHub repo
17
+ - Chunks it with tree-sitter
18
+ - Builds retrieval state with a Qdrant adapter
19
+ - Answers questions with Groq-hosted Llama or Vertex AI Gemini, depending on environment configuration
20
+ - Deletes the cloned repo after indexing
21
+ - Keeps only lightweight repo metadata in SQLite
evals/run_eval.py ADDED
@@ -0,0 +1,688 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import sys
4
+ import asyncio
5
+ import re
6
+ from pathlib import Path
7
+ from collections import Counter, defaultdict
8
+ from statistics import mean
9
+
10
+ import requests
11
+ from dotenv import load_dotenv
12
+
13
+ SERVER_ROOT = Path(__file__).resolve().parents[1]
14
+ if str(SERVER_ROOT) not in sys.path:
15
+ sys.path.insert(0, str(SERVER_ROOT))
16
+
17
+ load_dotenv(SERVER_ROOT / ".env")
18
+
19
+ from src.embeddings import EmbeddingGenerator
20
+
21
+
22
+ API_URL = os.getenv("CODEBASE_RAG_API_URL", "http://localhost:8000")
23
+ REPO_ID = int(os.getenv("CODEBASE_RAG_REPO_ID", "1"))
24
+ SESSION_ID = os.getenv("CODEBASE_RAG_SESSION_ID", "eval-session")
25
+ TOP_K = int(os.getenv("CODEBASE_RAG_TOP_K", "8"))
26
+ QUERY_TIMEOUT_SECONDS = int(os.getenv("CODEBASE_RAG_QUERY_TIMEOUT_SECONDS", "180"))
27
+ ENABLE_RAGAS = os.getenv("CODEBASE_RAG_ENABLE_RAGAS", "1").lower() not in {"0", "false", "no"}
28
+ RAGAS_ASYNC = os.getenv("CODEBASE_RAG_RAGAS_ASYNC", "0").lower() in {"1", "true", "yes"}
29
+ RAGAS_RAISE_EXCEPTIONS = os.getenv("CODEBASE_RAG_RAGAS_RAISE_EXCEPTIONS", "0").lower() in {
30
+ "1",
31
+ "true",
32
+ "yes",
33
+ }
34
+ EVAL_SET_PATH = Path(
35
+ os.getenv(
36
+ "CODEBASE_RAG_EVAL_SET",
37
+ Path(__file__).with_name("sample_eval_set.json"),
38
+ )
39
+ )
40
+
41
+
42
def log(message: str):
    """Write one progress line to stderr, tagged with the eval prefix."""
    sys.stderr.write(f"[eval] {message}\n")
    sys.stderr.flush()
44
+
45
+
46
def load_eval_rows():
    """Parse the eval set file (a JSON list of case dicts) and return it."""
    with EVAL_SET_PATH.open() as handle:
        return json.load(handle)
48
+
49
+
50
def post_query(row):
    """Send one eval case to the /api/query endpoint and return its JSON body.

    Raises RuntimeError (with the server-provided detail when available)
    on any non-2xx response.
    """
    body = {
        "repo_id": REPO_ID,
        "question": row["question"],
        "top_k": TOP_K,
        "history": row.get("turns", []),
    }
    response = requests.post(
        f"{API_URL}/api/query",
        json=body,
        headers={"X-Session-Id": SESSION_ID},
        timeout=QUERY_TIMEOUT_SECONDS,
    )
    if response.ok:
        return response.json()

    # Prefer the structured "detail" field when the server returned JSON;
    # fall back to the raw response text otherwise.
    detail = response.text
    try:
        parsed = response.json()
        detail = parsed.get("detail") or parsed
    except Exception:
        pass
    raise RuntimeError(
        f"Query failed for eval case {row.get('id', row['question'])!r} "
        f"with status {response.status_code}: {detail}"
    )
75
+
76
+
77
def normalize_path(path: str) -> str:
    """Normalize a file path for comparison.

    Trims surrounding whitespace, removes any leading "./" segments and
    leading slashes, and lowercases the result.

    Bug fix: the previous ``lstrip("./")`` stripped the *character set*
    {'.', '/'}, so dotfiles lost their dot (".env" became "env" and
    "..hidden" became "hidden"), breaking expected-source matching for
    such paths. Only the literal "./" prefix and leading slashes are
    removed now.
    """
    cleaned = path.strip()
    # Drop repeated "./" prefixes without eating dots that belong to names.
    while cleaned.startswith("./"):
        cleaned = cleaned[2:]
    return cleaned.lstrip("/").lower()
79
+
80
+
81
# Common English function words ignored when comparing answer text
# against reference text (see lexical_overlap_ratio).
STOPWORDS = set(
    "a an and are as at be by for from how in into is it its of on or "
    "that the their this to what when where which with".split()
)
112
+
113
+
114
def tokenize_text(text: str):
    """Lowercase *text* and return identifier-ish tokens.

    Tokens are runs of letters, digits, and the path/punctuation
    characters ``_ . / + -``, so file paths survive as single tokens.
    """
    lowered = text.lower() if text else ""
    return re.findall(r"[a-z0-9_./+-]+", lowered)
116
+
117
+
118
def compute_retrieval_metrics(expected_sources, actual_sources):
    """Compute deterministic retrieval metrics for one eval case.

    Args:
        expected_sources: paths the case expects to be cited.
        actual_sources: paths actually cited, in rank order (may repeat).

    An actual path matches an expected entry when they are equal after
    normalization, or when the expected entry contains no "/" (treated as
    a directory name) and the actual path lives under that directory.

    Returns a dict with retrieval_hit, source_recall, mrr, top1_hit,
    unique_source_precision, and duplicate_source_rate.

    Refactor note: the directory-prefix matching rule was previously
    duplicated verbatim inside the recall loop; it is now a single helper
    so the two code paths cannot drift apart.
    """
    expected = {normalize_path(path) for path in expected_sources}
    actual = [normalize_path(path) for path in actual_sources]
    unique_actual = list(dict.fromkeys(actual))  # order-preserving dedupe

    def pair_matches(actual_path: str, expected_path: str) -> bool:
        # Exact match, or "bare name" expected entries act as directory prefixes.
        if actual_path == expected_path:
            return True
        return "/" not in expected_path and actual_path.startswith(
            expected_path.rstrip("/") + "/"
        )

    def matches_expected(actual_path: str) -> bool:
        return any(pair_matches(actual_path, exp) for exp in expected)

    hit = 1 if any(matches_expected(path) for path in actual) else 0

    recall = 0.0
    if expected:
        matched_expected = {
            exp
            for exp in expected
            if any(pair_matches(path, exp) for path in actual)
        }
        recall = len(matched_expected) / len(expected)

    # Reciprocal rank of the first matching citation (0 when none match).
    mrr = 0.0
    for rank, path in enumerate(actual, start=1):
        if matches_expected(path):
            mrr = 1.0 / rank
            break

    matched_unique = sum(1 for path in unique_actual if matches_expected(path))
    return {
        "retrieval_hit": hit,
        "source_recall": recall,
        "mrr": mrr,
        "top1_hit": 1 if actual and matches_expected(actual[0]) else 0,
        "unique_source_precision": (
            matched_unique / len(unique_actual) if unique_actual else 0.0
        ),
        "duplicate_source_rate": (
            (len(actual) - len(unique_actual)) / len(actual) if actual else 0.0
        ),
    }
166
+
167
+
168
def keyword_match_ratio(row, answer: str):
    """Return the fraction of the case's ``must_include_any`` keywords that
    appear in *answer* (case-insensitive substring match).

    Returns None when the case defines no non-blank keywords, so callers
    can distinguish "check not applicable" from "no keywords matched".

    Fix: keywords are coerced with ``str()`` before filtering/lowering,
    matching sibling ``keyword_pass`` and avoiding an AttributeError on
    non-string entries in the eval set.
    """
    keywords = [
        str(keyword).lower()
        for keyword in row.get("must_include_any", [])
        if str(keyword).strip()
    ]
    if not keywords:
        return None
    lowered = answer.lower()
    matched = sum(1 for keyword in keywords if keyword in lowered)
    return matched / len(keywords)
175
+
176
+
177
+ def keyword_pass(row, answer: str, coverage: float | None):
178
+ if coverage is None:
179
+ return None
180
+ minimum = int(row.get("min_keyword_matches", 1))
181
+ keywords = [keyword for keyword in row.get("must_include_any", []) if str(keyword).strip()]
182
+ if not keywords:
183
+ return None
184
+ matched = round(coverage * len(keywords))
185
+ return 1 if matched >= minimum else 0
186
+
187
+
188
def answer_length_metrics(answer: str):
    """Token-count stats for an answer; "substantive" means >= 40 tokens."""
    word_count = len(tokenize_text(answer))
    return {
        "answer_word_count": word_count,
        "has_substantive_answer": int(word_count >= 40),
    }
194
+
195
+
196
def lexical_overlap_ratio(reference: str, candidate: str):
    """Fraction of *reference* content terms also present in *candidate*.

    Content terms are tokens longer than two characters that are not
    stopwords. Returns None when the reference yields no content terms.
    """
    reference_terms = {
        token
        for token in tokenize_text(reference)
        if len(token) > 2 and token not in STOPWORDS
    }
    if not reference_terms:
        return None
    candidate_terms = set(tokenize_text(candidate))
    return len(reference_terms & candidate_terms) / len(reference_terms)
206
+
207
+
208
def validate_eval_rows(rows):
    """Audit the eval set and return a summary dict.

    Collects hard errors (missing question/ground_truth, malformed
    expected_sources/must_include_any, duplicate ids) and soft warnings
    about dataset scope (size, category breadth, multi-turn coverage,
    dominant id prefix). The returned dict includes counts, averages,
    the error/warning lists, and ``is_valid`` (True when no errors).
    """
    errors = []
    warnings = []
    category_counts = Counter()
    id_counts = Counter()
    id_prefix_counts = Counter()
    expected_source_counts = []
    keyword_counts = []
    conversation_cases = 0

    for index, row in enumerate(rows, start=1):
        # Fall back to a positional id so error messages always name a case.
        row_id = row.get("id") or f"row-{index}"
        id_counts[row_id] += 1
        # The id prefix (text before the first "-") is used below to detect
        # a benchmark focused on a single target project.
        prefix = row_id.split("-", 1)[0].lower()
        if prefix:
            id_prefix_counts[prefix] += 1
        category_counts[row.get("category", "general")] += 1

        question = str(row.get("question", "")).strip()
        ground_truth = str(row.get("ground_truth", "")).strip()
        expected_sources = row.get("expected_sources", [])
        must_include_any = row.get("must_include_any", [])

        if not question:
            errors.append(f"{row_id}: missing question")
        if not ground_truth:
            errors.append(f"{row_id}: missing ground_truth")
        if not isinstance(expected_sources, list) or not expected_sources:
            errors.append(f"{row_id}: expected_sources must be a non-empty list")
        if must_include_any and not isinstance(must_include_any, list):
            errors.append(f"{row_id}: must_include_any must be a list when present")
        if row.get("turns"):
            conversation_cases += 1
        expected_source_counts.append(len(expected_sources) if isinstance(expected_sources, list) else 0)
        keyword_counts.append(len(must_include_any) if isinstance(must_include_any, list) else 0)

    duplicate_ids = sorted(row_id for row_id, count in id_counts.items() if count > 1)
    if duplicate_ids:
        errors.append(f"duplicate ids found: {', '.join(duplicate_ids)}")

    # Soft warnings: the set is still usable, but its scope is limited.
    if len(rows) < 25:
        warnings.append(
            "Eval set has fewer than 25 cases. Good for iteration, but light for resume-grade benchmarking."
        )
    if len(category_counts) < 4:
        warnings.append("Eval set covers fewer than 4 categories, so breadth is limited.")
    if conversation_cases < 2:
        warnings.append("Eval set has very little multi-turn coverage.")
    if category_counts and min(category_counts.values()) < 2:
        sparse = sorted(category for category, count in category_counts.items() if count < 2)
        warnings.append(f"Some categories are underrepresented: {', '.join(sparse)}.")

    if id_prefix_counts:
        # When >= 80% of ids share one prefix, the benchmark likely targets
        # a single project rather than diverse repositories.
        dominant_prefix, dominant_count = id_prefix_counts.most_common(1)[0]
        if dominant_count / len(rows) >= 0.8:
            warnings.append(
                f"Most cases share the same id prefix ({dominant_prefix}), which suggests a benchmark focused on one target project."
            )

    return {
        "case_count": len(rows),
        "category_counts": dict(sorted(category_counts.items())),
        "conversation_case_count": conversation_cases,
        "average_expected_sources": round(mean(expected_source_counts), 2) if expected_source_counts else 0.0,
        "average_keywords_per_case": round(mean(keyword_counts), 2) if keyword_counts else 0.0,
        "errors": errors,
        "warnings": warnings,
        "is_valid": not errors,
    }
277
+
278
+
279
def summarize_custom_metrics(details):
    """Aggregate per-case metrics (from run()) into dataset-level averages.

    Metrics whose per-case value can be None (keyword coverage/pass,
    lexical overlap) are averaged only over the cases where the check
    applied, and reported as None when it applied to no case.
    """
    keyword_coverages = [item["keyword_coverage"] for item in details if item["keyword_coverage"] is not None]
    keyword_passes = [item["keyword_pass"] for item in details if item["keyword_pass"] is not None]
    # "Grounded" = retrieval hit + substantive answer + keyword gate not failed
    # (a missing keyword check counts as not failed).
    grounded_answer_passes = [
        1
        for item in details
        if item["retrieval_hit"] == 1
        and item["has_substantive_answer"] == 1
        and (item["keyword_pass"] in {None, 1})
    ]
    exact_source_recall_cases = [1 for item in details if item["source_recall"] == 1.0]
    return {
        "retrieval_hit_rate": round(mean(item["retrieval_hit"] for item in details), 4),
        "top1_hit_rate": round(mean(item["top1_hit"] for item in details), 4),
        "source_recall": round(mean(item["source_recall"] for item in details), 4),
        "mrr": round(mean(item["mrr"] for item in details), 4),
        "unique_source_precision": round(mean(item["unique_source_precision"] for item in details), 4),
        "duplicate_source_rate": round(mean(item["duplicate_source_rate"] for item in details), 4),
        "keyword_coverage": round(mean(keyword_coverages), 4) if keyword_coverages else None,
        "keyword_pass_rate": round(mean(keyword_passes), 4) if keyword_passes else None,
        "ground_truth_lexical_overlap": round(
            mean(item["ground_truth_lexical_overlap"] for item in details if item["ground_truth_lexical_overlap"] is not None),
            4,
        )
        if any(item["ground_truth_lexical_overlap"] is not None for item in details)
        else None,
        "substantive_answer_rate": round(mean(item["has_substantive_answer"] for item in details), 4),
        # These two are rates over ALL cases, hence sum/len rather than mean
        # over a filtered subset.
        "grounded_answer_rate": round(sum(grounded_answer_passes) / len(details), 4) if details else 0.0,
        "exact_source_recall_rate": round(sum(exact_source_recall_cases) / len(details), 4) if details else 0.0,
    }
309
+
310
+
311
def summarize_by_category(details):
    """Aggregate per-case metrics into per-category averages.

    Returns a dict keyed by category (sorted), each value holding the
    case count, retrieval averages, keyword pass rate (None when no case
    in the category had a keyword check), and grounded answer rate.
    """
    grouped = defaultdict(list)
    for item in details:
        grouped[item["category"]].append(item)

    def grounded(item) -> int:
        # Retrieval hit + substantive answer + keyword gate not failed.
        passed = (
            item["retrieval_hit"] == 1
            and item["has_substantive_answer"] == 1
            and item["keyword_pass"] in {None, 1}
        )
        return 1 if passed else 0

    summary = {}
    for category in sorted(grouped):
        items = grouped[category]
        gate_scores = [it["keyword_pass"] for it in items if it["keyword_pass"] is not None]
        summary[category] = {
            "case_count": len(items),
            "retrieval_hit_rate": round(mean(it["retrieval_hit"] for it in items), 4),
            "top1_hit_rate": round(mean(it["top1_hit"] for it in items), 4),
            "source_recall": round(mean(it["source_recall"] for it in items), 4),
            "mrr": round(mean(it["mrr"] for it in items), 4),
            "keyword_pass_rate": round(mean(gate_scores), 4) if gate_scores else None,
            "grounded_answer_rate": round(mean(grounded(it) for it in items), 4),
        }
    return summary
337
+
338
+
339
def build_headline_metrics(custom_metrics, audit):
    """Select the headline numbers shown at the top of the eval report."""
    headline = {
        "sample_size": audit["case_count"],
        "category_count": len(audit["category_counts"]),
    }
    # Copy the deterministic metrics through in report order.
    for key in (
        "retrieval_hit_rate",
        "top1_hit_rate",
        "mrr",
        "source_recall",
        "grounded_answer_rate",
        "keyword_pass_rate",
    ):
        headline[key] = custom_metrics[key]
    return headline
350
+
351
+
352
def build_resume_summary(custom_metrics, audit, ragas_report, ragas_error):
    """Build a one-paragraph English summary of the eval run.

    Leads with dataset size and the deterministic retrieval/answer
    metrics; appends LLM-judge numbers only when RAGAS ran cleanly, and
    surfaces up to two dataset warnings as a closing caveat.
    """
    lines = [
        (
            f"Evaluated on {audit['case_count']} repo-QA cases across "
            f"{len(audit['category_counts'])} categories."
        ),
        (
            f"Deterministic retrieval metrics: hit@{TOP_K} {custom_metrics['retrieval_hit_rate']:.1%}, "
            f"top-1 hit {custom_metrics['top1_hit_rate']:.1%}, MRR {custom_metrics['mrr']:.3f}, "
            f"source recall {custom_metrics['source_recall']:.1%}."
        ),
        (
            f"Answer quality checks: grounded answer rate {custom_metrics['grounded_answer_rate']:.1%}"
            # Keyword pass rate is only mentioned when the check applied.
            + (
                f", keyword/checklist pass rate {custom_metrics['keyword_pass_rate']:.1%}."
                if custom_metrics["keyword_pass_rate"] is not None
                else "."
            )
        ),
    ]

    if ragas_report and not ragas_error:
        lines.append(
            "LLM-judge metrics (supporting signal, not primary headline): "
            f"faithfulness {ragas_report.get('faithfulness', 0.0):.3f}, "
            f"answer relevancy {ragas_report.get('answer_relevancy', 0.0):.3f}, "
            f"context precision {ragas_report.get('context_precision', 0.0):.3f}."
        )
    else:
        lines.append("LLM-judge metrics were skipped or unstable, so headline metrics rely on deterministic checks.")

    if audit["warnings"]:
        # Cap at two warnings to keep the summary to one paragraph.
        lines.append(
            "Benchmark caveat: "
            + " ".join(audit["warnings"][:2])
        )

    return " ".join(lines)
390
+
391
+
392
def benchmark_readiness(audit, ragas_error):
    """Classify the run as presentation-ready or an internal/demo benchmark.

    Any dataset-scope shortfall or an unexpected RAGAS failure downgrades
    the status; the reasons list explains each downgrade.
    """
    reasons = []
    if audit["case_count"] < 25:
        reasons.append("small_sample")
    if len(audit["category_counts"]) < 4:
        reasons.append("limited_category_coverage")
    if audit["conversation_case_count"] < 2:
        reasons.append("limited_multi_turn_coverage")
    if audit["warnings"]:
        reasons.append("dataset_scope_warnings")
    # "disabled" is a deliberate skip, not instability.
    if ragas_error is not None and ragas_error != "disabled":
        reasons.append("ragas_instability")

    status = "internal_or_demo_benchmark" if reasons else "presentation_ready"
    return {"status": status, "reasons": reasons}
414
+
415
+
416
def maybe_write_report(report):
    """Persist the report as JSON when CODEBASE_RAG_EVAL_OUTPUT is set.

    Creates parent directories as needed. Returns the written path as a
    string, or None when no output path is configured.
    """
    configured = os.getenv("CODEBASE_RAG_EVAL_OUTPUT")
    if not configured:
        return None
    target = Path(configured)
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(json.dumps(report, indent=2))
    return str(target)
424
+
425
+
426
def build_vertex_ragas_llm(run_config):
    """Build a RAGAS-compatible LLM adapter backed by Vertex AI Gemini.

    Imports the third-party SDKs lazily so the module loads even when the
    RAGAS/Vertex extras are not installed. Reads GOOGLE_CLOUD_PROJECT,
    GOOGLE_CLOUD_LOCATION, and EVAL_MODEL/VERTEX_LLM_MODEL from the
    environment; raises RuntimeError when the project is unset.
    """
    from google import genai
    from langchain_core.outputs import Generation, LLMResult
    from ragas.llms.base import BaseRagasLLM

    class VertexRagasLLM(BaseRagasLLM):
        # Adapter: translates RAGAS prompt objects into google-genai
        # generate_content calls and wraps responses as LLMResult.
        def __init__(self, model: str, project: str, location: str, run_config):
            self.client = genai.Client(
                vertexai=True,
                project=project,
                location=location,
            )
            self.model = model
            self.set_run_config(run_config)

        def _prompt_to_text(self, prompt):
            # Prepend a strict-JSON instruction: RAGAS parses the judge's
            # output, so markdown fences or prose would break scoring.
            prefix = (
                "Return only valid JSON or the exact structured output requested by the prompt. "
                "Do not add markdown fences, explanations, or extra prose.\n\n"
            )
            if hasattr(prompt, "to_string"):
                return prefix + prompt.to_string()
            return prefix + str(prompt)

        def _generate_once(self, prompt, n=1, temperature=1e-8, stop=None, callbacks=None):
            # Single synchronous generation; the incoming `temperature` is
            # intentionally ignored in favor of a deterministic 0.0 judge.
            prompt_text = self._prompt_to_text(prompt)
            config = {
                "temperature": 0.0,
                "candidate_count": max(1, n),
                "max_output_tokens": int(os.getenv("EVAL_MAX_OUTPUT_TOKENS", "2048")),
                "response_mime_type": "application/json",
            }
            if stop:
                config["stop_sequences"] = stop

            response = self.client.models.generate_content(
                model=self.model,
                contents=prompt_text,
                config=config,
            )

            # Extract text per candidate; fall back to joining content parts
            # when the candidate has no direct .text attribute.
            candidates = getattr(response, "candidates", None) or []
            generations = []
            if candidates:
                for candidate in candidates[: max(1, n)]:
                    text = getattr(candidate, "text", None)
                    if text is None and hasattr(candidate, "content"):
                        parts = getattr(candidate.content, "parts", None) or []
                        text = "".join(getattr(part, "text", "") for part in parts if getattr(part, "text", ""))
                    generations.append(Generation(text=(text or "").strip()))
            elif getattr(response, "text", None):
                generations.append(Generation(text=response.text.strip()))

            if not generations:
                raise RuntimeError("Vertex AI judge returned an empty response.")

            return LLMResult(generations=[generations])

        def generate_text(self, prompt, n=1, temperature=1e-8, stop=None, callbacks=None):
            return self._generate_once(
                prompt=prompt,
                n=n,
                temperature=temperature,
                stop=stop,
                callbacks=callbacks,
            )

        async def agenerate_text(self, prompt, n=1, temperature=1e-8, stop=None, callbacks=None):
            # The genai client call is blocking; run it off the event loop.
            return await asyncio.to_thread(
                self._generate_once,
                prompt,
                n,
                temperature,
                stop,
                callbacks,
            )

    project = os.getenv("GOOGLE_CLOUD_PROJECT")
    location = os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1")
    model = os.getenv("EVAL_MODEL", os.getenv("VERTEX_LLM_MODEL", "gemini-2.5-pro"))
    if not project:
        raise RuntimeError("GOOGLE_CLOUD_PROJECT must be set for Vertex AI RAGAS evaluation.")
    return VertexRagasLLM(model=model, project=project, location=location, run_config=run_config)
509
+
510
+
511
def build_ragas_embeddings(run_config):
    """Wrap the app's EmbeddingGenerator as a RAGAS embeddings provider.

    Reuses the same embedding model the application indexes with, so
    RAGAS context metrics are computed in the same vector space.
    """
    from ragas.embeddings.base import BaseRagasEmbeddings

    class AppEmbeddingWrapper(BaseRagasEmbeddings):
        def __init__(self, generator, run_config):
            self.generator = generator
            self.set_run_config(run_config)

        def embed_query(self, text):
            # .tolist(): RAGAS expects plain Python lists, not arrays.
            return self.generator.embed_text(text).tolist()

        def embed_documents(self, texts):
            vectors = self.generator.embed_batch(list(texts))
            return vectors.tolist()

        async def aembed_query(self, text):
            # Embedding is CPU/blocking work; keep it off the event loop.
            return await asyncio.to_thread(self.embed_query, text)

        async def aembed_documents(self, texts):
            return await asyncio.to_thread(self.embed_documents, texts)

    return AppEmbeddingWrapper(EmbeddingGenerator(), run_config=run_config)
533
+
534
+
535
def run_ragas(rows, outputs):
    """Run RAGAS LLM-judge metrics over the collected query outputs.

    Returns a (metrics, error) pair: ({metric: float}, None) on success,
    (None, "disabled") when turned off via env, (None, "import_error: …")
    when the optional dependencies are missing, or (None, str(exc)) when
    evaluation itself fails — callers treat any error as "judge metrics
    unavailable" rather than aborting the eval.
    """
    if not ENABLE_RAGAS:
        log("RAGAS disabled via CODEBASE_RAG_ENABLE_RAGAS=0. Reporting custom metrics only.")
        return None, "disabled"

    # Lazy imports: ragas/datasets are heavy optional dependencies.
    try:
        from datasets import Dataset
        from ragas import evaluate
        from ragas.metrics import answer_relevancy, context_precision, faithfulness
        from ragas.run_config import RunConfig
    except Exception as exc:
        log(f"Skipping RAGAS because the evaluation dependencies could not be loaded: {exc}")
        return None, f"import_error: {exc}"

    def build_ragas_dataset():
        # One sample per eval case: question, generated answer, the cited
        # snippets as contexts, and the reference ground truth.
        samples = []
        for row, result in zip(rows, outputs):
            samples.append(
                {
                    "question": row["question"],
                    "answer": result["answer"],
                    "contexts": [source["snippet"] for source in result.get("sources", [])],
                    "ground_truth": row["ground_truth"],
                }
            )
        return Dataset.from_list(samples)

    log("Running RAGAS metrics. This can take a while.")
    try:
        # Timeouts/concurrency are all env-tunable for slow judge models.
        timeout_seconds = int(os.getenv("EVAL_TIMEOUT_SECONDS", "180"))
        thread_timeout_seconds = float(os.getenv("EVAL_THREAD_TIMEOUT_SECONDS", str(max(timeout_seconds, 240))))
        max_workers = int(os.getenv("EVAL_MAX_WORKERS", "4"))
        run_config = RunConfig(
            timeout=timeout_seconds,
            thread_timeout=thread_timeout_seconds,
            max_workers=max_workers,
            max_retries=int(os.getenv("EVAL_MAX_RETRIES", "3")),
            max_wait=int(os.getenv("EVAL_MAX_WAIT_SECONDS", "60")),
        )
        log(
            "Using Vertex AI for RAGAS judge model "
            f"({os.getenv('EVAL_MODEL', os.getenv('VERTEX_LLM_MODEL', 'gemini-2.5-pro'))})"
        )
        log(
            f"RAGAS runtime: async={RAGAS_ASYNC}, raise_exceptions={RAGAS_RAISE_EXCEPTIONS}, "
            f"timeout={timeout_seconds}s, thread_timeout={thread_timeout_seconds}s, max_workers={max_workers}"
        )
        llm = build_vertex_ragas_llm(run_config)
        embeddings = build_ragas_embeddings(run_config)
        ragas_report = evaluate(
            build_ragas_dataset(),
            metrics=[faithfulness, answer_relevancy, context_precision],
            llm=llm,
            embeddings=embeddings,
            run_config=run_config,
            is_async=RAGAS_ASYNC,
            raise_exceptions=RAGAS_RAISE_EXCEPTIONS,
        )
        # Coerce to plain floats so the report stays JSON-serializable.
        return {key: float(value) for key, value in ragas_report.items()}, None
    except Exception as exc:
        log(f"RAGAS evaluation failed: {exc}")
        return None, str(exc)
597
+
598
+
599
def run():
    """Run the full evaluation: validate the set, query every case,
    compute deterministic and (optionally) RAGAS metrics, and emit the
    JSON report to stdout (and optionally to a file).

    Raises RuntimeError when the eval set fails validation; warnings are
    only logged.
    """
    log(f"Loading eval set from {EVAL_SET_PATH}")
    rows = load_eval_rows()
    audit = validate_eval_rows(rows)
    if audit["errors"]:
        raise RuntimeError("Eval set validation failed: " + "; ".join(audit["errors"]))
    for warning in audit["warnings"]:
        log(f"Eval set warning: {warning}")
    log(
        f"Starting eval with api_url={API_URL}, repo_id={REPO_ID}, "
        f"session_id={SESSION_ID}, top_k={TOP_K}, cases={len(rows)}"
    )
    outputs = []   # raw API responses, in eval-set order (fed to RAGAS)
    details = []   # per-case metric dicts (fed to the summaries)

    for index, row in enumerate(rows, start=1):
        case_id = row.get("id", row["question"])
        log(f"[{index}/{len(rows)}] Querying case {case_id}")
        result = post_query(row)
        outputs.append(result)
        log(
            f"[{index}/{len(rows)}] Received answer for {case_id} "
            f"with {len(result.get('sources', []))} sources"
        )

        # Deterministic per-case checks: retrieval, keywords, length, overlap.
        cited_paths = [source["file_path"] for source in result.get("sources", [])]
        metrics = compute_retrieval_metrics(row.get("expected_sources", []), cited_paths)
        keyword_coverage = keyword_match_ratio(row, result.get("answer", ""))
        keyword_gate = keyword_pass(row, result.get("answer", ""), keyword_coverage)
        length_metrics = answer_length_metrics(result.get("answer", ""))
        overlap = lexical_overlap_ratio(row.get("ground_truth", ""), result.get("answer", ""))

        details.append(
            {
                "id": row.get("id", row["question"]),
                "category": row.get("category", "general"),
                "question": row["question"],
                "answer": result.get("answer", ""),
                "expected_sources": row.get("expected_sources", []),
                "retrieved_sources": cited_paths,
                "retrieval_hit": metrics["retrieval_hit"],
                "source_recall": metrics["source_recall"],
                "mrr": metrics["mrr"],
                "top1_hit": metrics["top1_hit"],
                "unique_source_precision": metrics["unique_source_precision"],
                "duplicate_source_rate": metrics["duplicate_source_rate"],
                "keyword_coverage": keyword_coverage,
                "keyword_pass": keyword_gate,
                "ground_truth_lexical_overlap": overlap,
                **length_metrics,
            }
        )

    log("Finished query loop. Computing aggregate metrics.")
    custom_metrics = summarize_custom_metrics(details)
    category_breakdown = summarize_by_category(details)
    ragas_report, ragas_error = run_ragas(rows, outputs)
    headline_metrics = build_headline_metrics(custom_metrics, audit)
    resume_summary = build_resume_summary(custom_metrics, audit, ragas_report, ragas_error)
    readiness = benchmark_readiness(audit, ragas_error)

    report = {
        "config": {
            "api_url": API_URL,
            "repo_id": REPO_ID,
            "session_id": SESSION_ID,
            "top_k": TOP_K,
            "query_timeout_seconds": QUERY_TIMEOUT_SECONDS,
            "eval_set": str(EVAL_SET_PATH),
        },
        "eval_set_audit": audit,
        "headline_metrics": headline_metrics,
        "benchmark_readiness": readiness,
        "ragas": ragas_report,
        "ragas_error": ragas_error,
        "custom_metrics": custom_metrics,
        "category_breakdown": category_breakdown,
        "resume_summary": resume_summary,
        "cases": details,
    }
    output_path = maybe_write_report(report)
    if output_path:
        log(f"Wrote JSON report to {output_path}")

    # Logs go to stderr, so stdout stays a clean machine-readable report.
    log("Eval complete. Printing JSON report.")
    print(json.dumps(report, indent=2))
685
+
686
+
687
+ if __name__ == "__main__":
688
+ run()
evals/sample_eval_set.json ADDED
@@ -0,0 +1,673 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "id": "sqlmodel-purpose",
4
+ "category": "architecture",
5
+ "question": "What is SQLModel and how is it positioned relative to Pydantic and SQLAlchemy?",
6
+ "ground_truth": "SQLModel is a thin layer designed to combine Pydantic-style data modeling with SQLAlchemy ORM and SQL expression features. The project presents itself as a library for SQL databases in Python that emphasizes simplicity, compatibility, and robustness while being built on top of Pydantic and SQLAlchemy.",
7
+ "expected_sources": [
8
+ "README.md",
9
+ "sqlmodel/__init__.py",
10
+ "sqlmodel/main.py"
11
+ ],
12
+ "must_include_any": [
13
+ "Pydantic",
14
+ "SQLAlchemy",
15
+ "thin layer"
16
+ ],
17
+ "min_keyword_matches": 2
18
+ },
19
+ {
20
+ "id": "sqlmodel-core-model-class",
21
+ "category": "architecture",
22
+ "question": "Where is the core SQLModel base class defined and what is its role?",
23
+ "ground_truth": "The SQLModel base class is defined in sqlmodel/main.py. It acts as the main model base that bridges typed field definitions, Pydantic-compatible validation behavior, and SQLAlchemy table or ORM metadata.",
24
+ "expected_sources": [
25
+ "sqlmodel/main.py"
26
+ ],
27
+ "must_include_any": [
28
+ "SQLModel",
29
+ "base class",
30
+ "Pydantic",
31
+ "SQLAlchemy"
32
+ ],
33
+ "min_keyword_matches": 2
34
+ },
35
+ {
36
+ "id": "sqlmodel-field-helper",
37
+ "category": "architecture",
38
+ "question": "How does SQLModel expose field declarations for model attributes?",
39
+ "ground_truth": "SQLModel exposes a Field helper in sqlmodel/main.py and re-exports it at the package level. Field collects model metadata such as defaults, primary key flags, indexes, foreign keys, nullability, and other column-related settings used when building SQL-backed models.",
40
+ "expected_sources": [
41
+ "sqlmodel/main.py",
42
+ "sqlmodel/__init__.py"
43
+ ],
44
+ "must_include_any": [
45
+ "Field",
46
+ "primary key",
47
+ "foreign key",
48
+ "re-export"
49
+ ],
50
+ "min_keyword_matches": 2
51
+ },
52
+ {
53
+ "id": "sqlmodel-relationship-helper",
54
+ "category": "architecture",
55
+ "question": "How are relationships modeled in SQLModel?",
56
+ "ground_truth": "Relationships are declared through the Relationship helper and associated metadata in sqlmodel/main.py. SQLModel captures relationship configuration separately from normal field definitions so relationship behavior can be translated into SQLAlchemy ORM relationship setup.",
57
+ "expected_sources": [
58
+ "sqlmodel/main.py",
59
+ "sqlmodel/__init__.py"
60
+ ],
61
+ "must_include_any": [
62
+ "Relationship",
63
+ "relationship",
64
+ "SQLAlchemy",
65
+ "metadata"
66
+ ],
67
+ "min_keyword_matches": 2
68
+ },
69
+ {
70
+ "id": "sqlmodel-field-function",
71
+ "category": "specific-function",
72
+ "question": "What does the Field() function do in SQLModel?",
73
+ "ground_truth": "Field defines metadata for a model attribute, including validation defaults and SQL column configuration such as primary_key, foreign_key, index, nullable, sa_type, or sa_column options. SQLModel uses that metadata when constructing models that can also map to tables.",
74
+ "expected_sources": [
75
+ "sqlmodel/main.py"
76
+ ],
77
+ "must_include_any": [
78
+ "Field",
79
+ "primary_key",
80
+ "nullable",
81
+ "column"
82
+ ],
83
+ "min_keyword_matches": 2
84
+ },
85
+ {
86
+ "id": "sqlmodel-relationship-function",
87
+ "category": "specific-function",
88
+ "question": "What does Relationship() do in SQLModel?",
89
+ "ground_truth": "Relationship captures relationship-specific configuration for ORM links between models, such as back_populates and SQLAlchemy relationship arguments. It provides structured metadata that SQLModel can later translate into SQLAlchemy relationship objects.",
90
+ "expected_sources": [
91
+ "sqlmodel/main.py"
92
+ ],
93
+ "must_include_any": [
94
+ "Relationship",
95
+ "back_populates",
96
+ "relationship",
97
+ "metadata"
98
+ ],
99
+ "min_keyword_matches": 2
100
+ },
101
+ {
102
+ "id": "sqlmodel-session-exec",
103
+ "category": "specific-function",
104
+ "question": "What is special about Session.exec() in SQLModel?",
105
+ "ground_truth": "SQLModel provides a Session class with an exec helper that offers a friendlier typed wrapper around SQLAlchemy execution patterns, especially for SQLModel select statements. It is intended to make common query execution more ergonomic than raw SQLAlchemy session.execute calls.",
106
+ "expected_sources": [
107
+ "sqlmodel/orm/session.py",
108
+ "sqlmodel/__init__.py"
109
+ ],
110
+ "must_include_any": [
111
+ "Session",
112
+ "exec",
113
+ "execute",
114
+ "typed"
115
+ ],
116
+ "min_keyword_matches": 2
117
+ },
118
+ {
119
+ "id": "sqlmodel-async-session-exec",
120
+ "category": "specific-function",
121
+ "question": "How does async query execution work in SQLModel?",
122
+ "ground_truth": "SQLModel provides async session support under sqlmodel.ext.asyncio.session, including an async session wrapper that supports exec-style query execution for SQLModel statements in asynchronous applications.",
123
+ "expected_sources": [
124
+ "sqlmodel/ext/asyncio/session.py"
125
+ ],
126
+ "must_include_any": [
127
+ "async",
128
+ "AsyncSession",
129
+ "exec",
130
+ "greenlet"
131
+ ],
132
+ "min_keyword_matches": 2
133
+ },
134
+ {
135
+ "id": "sqlmodel-select-export",
136
+ "category": "specific-function",
137
+ "question": "How is select exposed to users in SQLModel?",
138
+ "ground_truth": "SQLModel re-exports a select helper from its SQL expression layer so users can write typed select statements directly from the sqlmodel package instead of importing SQLAlchemy primitives manually.",
139
+ "expected_sources": [
140
+ "sqlmodel/__init__.py",
141
+ "sqlmodel/sql/expression.py"
142
+ ],
143
+ "must_include_any": [
144
+ "select",
145
+ "re-export",
146
+ "expression",
147
+ "sqlmodel"
148
+ ],
149
+ "min_keyword_matches": 2
150
+ },
151
+ {
152
+ "id": "sqlmodel-create-engine-export",
153
+ "category": "specific-function",
154
+ "question": "How does SQLModel expose create_engine to application code?",
155
+ "ground_truth": "SQLModel re-exports create_engine from SQLAlchemy at the package level so users can import it directly from sqlmodel while using SQLModel models and sessions together.",
156
+ "expected_sources": [
157
+ "sqlmodel/__init__.py"
158
+ ],
159
+ "must_include_any": [
160
+ "create_engine",
161
+ "re-export",
162
+ "SQLAlchemy",
163
+ "sqlmodel"
164
+ ],
165
+ "min_keyword_matches": 2
166
+ },
167
+ {
168
+ "id": "sqlmodel-metadata-create-all",
169
+ "category": "config-setup",
170
+ "question": "How are database tables created when using SQLModel?",
171
+ "ground_truth": "Table creation typically happens by calling SQLModel.metadata.create_all(engine). SQLModel models register table metadata in a way that allows SQLAlchemy metadata creation workflows to build the underlying database tables.",
172
+ "expected_sources": [
173
+ "README.md",
174
+ "sqlmodel/main.py",
175
+ "docs_src"
176
+ ],
177
+ "must_include_any": [
178
+ "metadata",
179
+ "create_all",
180
+ "engine",
181
+ "table"
182
+ ],
183
+ "min_keyword_matches": 2
184
+ },
185
+ {
186
+ "id": "sqlmodel-package-exports",
187
+ "category": "config-setup",
188
+ "question": "What does sqlmodel.__init__ export for end users?",
189
+ "ground_truth": "The package initializer re-exports core user-facing APIs from SQLAlchemy and SQLModel, including create_engine, Session, SQLModel, Field, Relationship, and select-related helpers so application code can import most common primitives directly from sqlmodel.",
190
+ "expected_sources": [
191
+ "sqlmodel/__init__.py"
192
+ ],
193
+ "must_include_any": [
194
+ "Session",
195
+ "SQLModel",
196
+ "Field",
197
+ "create_engine"
198
+ ],
199
+ "min_keyword_matches": 3
200
+ },
201
+ {
202
+ "id": "sqlmodel-readme-basic-flow",
203
+ "category": "config-setup",
204
+ "question": "What basic database workflow does the README show for SQLModel?",
205
+ "ground_truth": "The README demonstrates defining a SQLModel table model, creating an engine, creating tables with metadata.create_all, opening a Session, inserting rows, committing, and then selecting rows with select and session.exec.",
206
+ "expected_sources": [
207
+ "README.md"
208
+ ],
209
+ "must_include_any": [
210
+ "create_engine",
211
+ "Session",
212
+ "create_all",
213
+ "select"
214
+ ],
215
+ "min_keyword_matches": 3
216
+ },
217
+ {
218
+ "id": "sqlmodel-column-options-errors",
219
+ "category": "error-handling",
220
+ "question": "How does SQLModel guard against conflicting or invalid Field configuration?",
221
+ "ground_truth": "SQLModel performs validation around Field configuration in its core model code and raises errors when incompatible options are combined or when SQLAlchemy-specific arguments conflict with other field settings.",
222
+ "expected_sources": [
223
+ "sqlmodel/main.py"
224
+ ],
225
+ "must_include_any": [
226
+ "raise",
227
+ "Field",
228
+ "conflict",
229
+ "sa_column"
230
+ ],
231
+ "min_keyword_matches": 2
232
+ },
233
+ {
234
+ "id": "sqlmodel-relationship-errors",
235
+ "category": "error-handling",
236
+ "question": "Where would SQLModel enforce invalid relationship configuration?",
237
+ "ground_truth": "Relationship configuration is handled in the core SQLModel model layer, where relationship metadata is collected and incompatible combinations are guarded before being translated to SQLAlchemy ORM behavior.",
238
+ "expected_sources": [
239
+ "sqlmodel/main.py"
240
+ ],
241
+ "must_include_any": [
242
+ "Relationship",
243
+ "metadata",
244
+ "SQLAlchemy",
245
+ "raise"
246
+ ],
247
+ "min_keyword_matches": 2
248
+ },
249
+ {
250
+ "id": "sqlmodel-session-cross-file",
251
+ "category": "cross-file",
252
+ "question": "How do SQLModel models flow into query execution with Session.exec()?",
253
+ "ground_truth": "Models are defined in the core SQLModel layer, queries are built through the SQL expression helpers such as select, and then those statements are executed through the SQLModel Session.exec wrapper, which ties model definitions and typed query execution together.",
254
+ "expected_sources": [
255
+ "sqlmodel/main.py",
256
+ "sqlmodel/sql/expression.py",
257
+ "sqlmodel/orm/session.py"
258
+ ],
259
+ "must_include_any": [
260
+ "select",
261
+ "Session",
262
+ "exec",
263
+ "model"
264
+ ],
265
+ "min_keyword_matches": 3
266
+ },
267
+ {
268
+ "id": "sqlmodel-sync-async-cross-file",
269
+ "category": "cross-file",
270
+ "question": "How does SQLModel support both sync and async session patterns across files?",
271
+ "ground_truth": "SQLModel exposes synchronous session helpers in its ORM session module and asynchronous support in the ext.asyncio package, giving similar exec-oriented ergonomics across both sync and async query paths.",
272
+ "expected_sources": [
273
+ "sqlmodel/orm/session.py",
274
+ "sqlmodel/ext/asyncio/session.py",
275
+ "sqlmodel/__init__.py"
276
+ ],
277
+ "must_include_any": [
278
+ "sync",
279
+ "async",
280
+ "Session",
281
+ "exec"
282
+ ],
283
+ "min_keyword_matches": 3
284
+ },
285
+ {
286
+ "id": "sqlmodel-field-to-table-flow",
287
+ "category": "cross-file",
288
+ "question": "How do typed Field declarations become SQL table columns in SQLModel?",
289
+ "ground_truth": "Typed model attributes and Field metadata are collected in the SQLModel core model layer, where SQLModel builds SQLAlchemy-compatible field and table metadata so the resulting class can participate in SQLAlchemy table creation and ORM mapping.",
290
+ "expected_sources": [
291
+ "sqlmodel/main.py",
292
+ "sqlmodel/_compat.py"
293
+ ],
294
+ "must_include_any": [
295
+ "Field",
296
+ "column",
297
+ "table",
298
+ "metadata"
299
+ ],
300
+ "min_keyword_matches": 3
301
+ },
302
+ {
303
+ "id": "sqlmodel-docs-fastapi-positioning",
304
+ "category": "docs",
305
+ "question": "How does the project describe SQLModel's relationship to FastAPI in its docs or README?",
306
+ "ground_truth": "The project describes SQLModel as being designed to simplify SQL database work in FastAPI applications and emphasizes that it is created by the same author, with strong compatibility between FastAPI, Pydantic, and SQLAlchemy.",
307
+ "expected_sources": [
308
+ "README.md",
309
+ "docs"
310
+ ],
311
+ "must_include_any": [
312
+ "FastAPI",
313
+ "same author",
314
+ "compatibility"
315
+ ],
316
+ "min_keyword_matches": 2
317
+ },
318
+ {
319
+ "id": "sqlmodel-followup-show-session-code",
320
+ "category": "conversation",
321
+ "turns": [
322
+ {
323
+ "role": "user",
324
+ "content": "How does SQLModel make query execution easier than raw SQLAlchemy?"
325
+ },
326
+ {
327
+ "role": "assistant",
328
+ "content": "It provides a Session.exec helper and package-level exports to simplify common query patterns."
329
+ }
330
+ ],
331
+ "question": "show me the code path for that",
332
+ "ground_truth": "The follow-up should stay anchored to Session.exec and SQLModel query ergonomics, retrieving code from the session wrapper and related SQLModel exports instead of drifting to README-only results.",
333
+ "expected_sources": [
334
+ "sqlmodel/orm/session.py",
335
+ "sqlmodel/__init__.py",
336
+ "sqlmodel/sql/expression.py"
337
+ ],
338
+ "must_include_any": [
339
+ "Session",
340
+ "exec",
341
+ "select"
342
+ ],
343
+ "min_keyword_matches": 2
344
+ },
345
+ {
346
+ "id": "sqlmodel-select-implementation-layer",
347
+ "category": "specific-function",
348
+ "question": "Where is select implemented under the hood and how is that different from how it is exposed publicly?",
349
+ "ground_truth": "SQLModel exposes select through package-level imports such as sqlmodel.__init__ and sqlmodel.sql.expression, while the implementation details and overload-heavy generation live in lower-level SQL expression modules like _expression_select_gen.py and related select classes.",
350
+ "expected_sources": [
351
+ "sqlmodel/__init__.py",
352
+ "sqlmodel/sql/expression.py",
353
+ "sqlmodel/sql/_expression_select_gen.py",
354
+ "sqlmodel/sql/_expression_select_cls.py"
355
+ ],
356
+ "must_include_any": [
357
+ "select",
358
+ "public",
359
+ "implementation",
360
+ "re-export"
361
+ ],
362
+ "min_keyword_matches": 2
363
+ },
364
+ {
365
+ "id": "sqlmodel-async-session-delegation",
366
+ "category": "cross-file",
367
+ "question": "How does AsyncSession.exec reuse the synchronous Session.exec path?",
368
+ "ground_truth": "The async session layer delegates execution to the synchronous Session.exec logic rather than duplicating it. AsyncSession uses greenlet-based bridging so async callers can reuse the sync execution wrapper and still get SQLModel-style exec ergonomics.",
369
+ "expected_sources": [
370
+ "sqlmodel/ext/asyncio/session.py",
371
+ "sqlmodel/orm/session.py"
372
+ ],
373
+ "must_include_any": [
374
+ "AsyncSession",
375
+ "Session",
376
+ "greenlet",
377
+ "exec"
378
+ ],
379
+ "min_keyword_matches": 3
380
+ },
381
+ {
382
+ "id": "sqlmodel-select-tutorial-usage",
383
+ "category": "docs",
384
+ "question": "How do the docs teach people to use select together with Session.exec?",
385
+ "ground_truth": "The tutorials show users building a statement with select(...) and then executing it through Session.exec(...), positioning exec as the ergonomic query entry point for SQLModel statements.",
386
+ "expected_sources": [
387
+ "docs/tutorial/select.md",
388
+ "README.md"
389
+ ],
390
+ "must_include_any": [
391
+ "select",
392
+ "Session",
393
+ "exec",
394
+ "statement"
395
+ ],
396
+ "min_keyword_matches": 3
397
+ },
398
+ {
399
+ "id": "sqlmodel-fastapi-response-model-docs",
400
+ "category": "docs",
401
+ "question": "How do the FastAPI docs describe using SQLModel models as response models?",
402
+ "ground_truth": "The FastAPI-focused docs explain that SQLModel classes can participate in API request and response modeling because they build on Pydantic, letting applications reuse models or related model variants in response_model patterns.",
403
+ "expected_sources": [
404
+ "docs/tutorial/fastapi/response-model.md",
405
+ "README.md"
406
+ ],
407
+ "must_include_any": [
408
+ "FastAPI",
409
+ "response_model",
410
+ "Pydantic",
411
+ "model"
412
+ ],
413
+ "min_keyword_matches": 2
414
+ },
415
+ {
416
+ "id": "sqlmodel-independent-library-positioning",
417
+ "category": "docs",
418
+ "question": "Does the project describe SQLModel as FastAPI-only or as a standalone library too?",
419
+ "ground_truth": "The docs position SQLModel as especially strong with FastAPI, but still as an independent library that can be used outside FastAPI. It is not described as FastAPI-only.",
420
+ "expected_sources": [
421
+ "README.md",
422
+ "docs/features.md",
423
+ "docs/index.md"
424
+ ],
425
+ "must_include_any": [
426
+ "FastAPI",
427
+ "independent",
428
+ "library",
429
+ "not"
430
+ ],
431
+ "min_keyword_matches": 2
432
+ },
433
+ {
434
+ "id": "sqlmodel-sa-relationship-test-guard",
435
+ "category": "tests",
436
+ "question": "What invalid Relationship combinations are guarded by tests?",
437
+ "ground_truth": "The relationship tests cover invalid combinations where a pre-built sa_relationship is mixed with sa_relationship_args or sa_relationship_kwargs, confirming that SQLModel raises when overlapping relationship configuration styles are combined.",
438
+ "expected_sources": [
439
+ "tests/test_field_sa_relationship.py",
440
+ "sqlmodel/main.py"
441
+ ],
442
+ "must_include_any": [
443
+ "sa_relationship",
444
+ "args",
445
+ "kwargs",
446
+ "raise"
447
+ ],
448
+ "min_keyword_matches": 3
449
+ },
450
+ {
451
+ "id": "sqlmodel-ondelete-nullable-test",
452
+ "category": "tests",
453
+ "question": "What does the project test about ondelete and nullable relationship fields?",
454
+ "ground_truth": "The test suite checks that using ondelete='SET NULL' on a non-nullable relationship field is invalid. The model layer should raise because SET NULL requires the underlying foreign key column to be nullable.",
455
+ "expected_sources": [
456
+ "tests/test_ondelete_raises.py",
457
+ "sqlmodel/main.py"
458
+ ],
459
+ "must_include_any": [
460
+ "ondelete",
461
+ "SET NULL",
462
+ "nullable",
463
+ "raise"
464
+ ],
465
+ "min_keyword_matches": 3
466
+ },
467
+ {
468
+ "id": "sqlmodel-type-validation-test",
469
+ "category": "tests",
470
+ "question": "What do the tests suggest about invalid SQLAlchemy or field type combinations in SQLModel?",
471
+ "ground_truth": "The tests indicate that SQLModel raises when unsupported or ambiguous field type combinations are mapped into SQLAlchemy columns, reinforcing that not every Python type annotation can become a database column shape automatically.",
472
+ "expected_sources": [
473
+ "tests/test_sqlalchemy_type_errors.py",
474
+ "sqlmodel/main.py"
475
+ ],
476
+ "must_include_any": [
477
+ "type",
478
+ "SQLAlchemy",
479
+ "raise",
480
+ "column"
481
+ ],
482
+ "min_keyword_matches": 2
483
+ },
484
+ {
485
+ "id": "sqlmodel-readme-engine-session-imports",
486
+ "category": "config-setup",
487
+ "question": "What top-level imports does the README encourage for getting started with SQLModel?",
488
+ "ground_truth": "The README encourages importing SQLModel, Field, Session, create_engine, and select from the top-level sqlmodel package so users can define models, create tables, and run queries with a unified import style.",
489
+ "expected_sources": [
490
+ "README.md",
491
+ "sqlmodel/__init__.py"
492
+ ],
493
+ "must_include_any": [
494
+ "SQLModel",
495
+ "Field",
496
+ "Session",
497
+ "create_engine",
498
+ "select"
499
+ ],
500
+ "min_keyword_matches": 4
501
+ },
502
+ {
503
+ "id": "sqlmodel-many-to-many-link-model-docs",
504
+ "category": "docs",
505
+ "question": "How do the relationship docs explain link_model for many-to-many mappings?",
506
+ "ground_truth": "The relationship docs explain that link_model is used as an association or link table model for many-to-many relationships, letting SQLModel connect two models through an explicit intermediary model.",
507
+ "expected_sources": [
508
+ "docs/tutorial/many-to-many/create-models-with-link.md",
509
+ "sqlmodel/main.py"
510
+ ],
511
+ "must_include_any": [
512
+ "link_model",
513
+ "many-to-many",
514
+ "association",
515
+ "relationship"
516
+ ],
517
+ "min_keyword_matches": 3
518
+ },
519
+ {
520
+ "id": "sqlmodel-followup-async-code-path",
521
+ "category": "conversation",
522
+ "turns": [
523
+ {
524
+ "role": "user",
525
+ "content": "How does async query execution work in SQLModel?"
526
+ },
527
+ {
528
+ "role": "assistant",
529
+ "content": "It uses AsyncSession and bridges into the sync session execution path."
530
+ }
531
+ ],
532
+ "question": "show me where that bridge happens",
533
+ "ground_truth": "The follow-up should stay on the async execution path and retrieve the async session module together with the sync session module it delegates to, rather than drifting to docs-only summaries.",
534
+ "expected_sources": [
535
+ "sqlmodel/ext/asyncio/session.py",
536
+ "sqlmodel/orm/session.py"
537
+ ],
538
+ "must_include_any": [
539
+ "AsyncSession",
540
+ "greenlet",
541
+ "Session",
542
+ "exec"
543
+ ],
544
+ "min_keyword_matches": 3
545
+ },
546
+ {
547
+ "id": "sqlmodel-followup-field-column-path",
548
+ "category": "conversation",
549
+ "turns": [
550
+ {
551
+ "role": "user",
552
+ "content": "How do typed Field declarations become SQL table columns in SQLModel?"
553
+ },
554
+ {
555
+ "role": "assistant",
556
+ "content": "The metaclass and field helpers translate Field metadata into SQLAlchemy Column objects."
557
+ }
558
+ ],
559
+ "question": "show me the main code path for that conversion",
560
+ "ground_truth": "The follow-up should stay anchored to the field-to-column conversion path in the core model implementation instead of drifting to tutorial prose alone.",
561
+ "expected_sources": [
562
+ "sqlmodel/main.py",
563
+ "sqlmodel/_compat.py"
564
+ ],
565
+ "must_include_any": [
566
+ "Field",
567
+ "Column",
568
+ "metaclass",
569
+ "conversion"
570
+ ],
571
+ "min_keyword_matches": 3
572
+ },
573
+ {
574
+ "id": "sqlmodel-followup-select-public-path",
575
+ "category": "conversation",
576
+ "turns": [
577
+ {
578
+ "role": "user",
579
+ "content": "How is select exposed to users in SQLModel?"
580
+ },
581
+ {
582
+ "role": "assistant",
583
+ "content": "It is re-exported for public use from the SQLModel package and expression layer."
584
+ }
585
+ ],
586
+ "question": "and where is the lower-level implementation behind that?",
587
+ "ground_truth": "The follow-up should connect the public export path to the lower-level select generator and select class implementation files instead of repeating only the package-level export story.",
588
+ "expected_sources": [
589
+ "sqlmodel/sql/expression.py",
590
+ "sqlmodel/sql/_expression_select_gen.py",
591
+ "sqlmodel/sql/_expression_select_cls.py"
592
+ ],
593
+ "must_include_any": [
594
+ "select",
595
+ "implementation",
596
+ "expression",
597
+ "class"
598
+ ],
599
+ "min_keyword_matches": 2
600
+ },
601
+ {
602
+ "id": "sqlmodel-test-vs-core-evidence-balance",
603
+ "category": "cross-file",
604
+ "question": "When explaining configuration errors in SQLModel, how should core implementation and tests complement each other?",
605
+ "ground_truth": "The core implementation in sqlmodel/main.py is the canonical source for behavior, while tests such as relationship and ondelete checks provide evidence that those guards are enforced in concrete scenarios. A good answer should balance both without treating tests as the primary implementation source.",
606
+ "expected_sources": [
607
+ "sqlmodel/main.py",
608
+ "tests/test_field_sa_relationship.py",
609
+ "tests/test_ondelete_raises.py"
610
+ ],
611
+ "must_include_any": [
612
+ "main.py",
613
+ "tests",
614
+ "canonical",
615
+ "guard"
616
+ ],
617
+ "min_keyword_matches": 2
618
+ },
619
+ {
620
+ "id": "sqlmodel-docs-vs-core-select-balance",
621
+ "category": "cross-file",
622
+ "question": "For explaining select in SQLModel, which files are canonical implementation sources and which are usage-oriented docs?",
623
+ "ground_truth": "The canonical implementation path is in sqlmodel.__init__, sqlmodel.sql.expression, and the lower-level select generator or select class modules, while files like README and docs/tutorial/select.md are usage-oriented documentation rather than the implementation itself.",
624
+ "expected_sources": [
625
+ "sqlmodel/__init__.py",
626
+ "sqlmodel/sql/expression.py",
627
+ "sqlmodel/sql/_expression_select_gen.py",
628
+ "docs/tutorial/select.md",
629
+ "README.md"
630
+ ],
631
+ "must_include_any": [
632
+ "canonical",
633
+ "implementation",
634
+ "docs",
635
+ "usage"
636
+ ],
637
+ "min_keyword_matches": 2
638
+ },
639
+ {
640
+ "id": "sqlmodel-features-doc-positioning",
641
+ "category": "docs",
642
+ "question": "What themes does the features documentation emphasize about SQLModel's value proposition?",
643
+ "ground_truth": "The features docs emphasize reduced duplication, editor friendliness, compatibility across Pydantic and SQLAlchemy, and an ergonomic way to work with SQL databases using standard Python type hints and models.",
644
+ "expected_sources": [
645
+ "docs/features.md",
646
+ "README.md"
647
+ ],
648
+ "must_include_any": [
649
+ "duplication",
650
+ "editor",
651
+ "compatibility",
652
+ "Python"
653
+ ],
654
+ "min_keyword_matches": 2
655
+ },
656
+ {
657
+ "id": "sqlmodel-docs-index-overview",
658
+ "category": "docs",
659
+ "question": "What kind of project overview should a user get from the docs index for SQLModel?",
660
+ "ground_truth": "The docs index should frame SQLModel as a Python SQL library that combines data modeling and database access patterns, point users toward tutorials or feature explanations, and reinforce its relationship to Pydantic, SQLAlchemy, and FastAPI.",
661
+ "expected_sources": [
662
+ "docs/index.md",
663
+ "README.md"
664
+ ],
665
+ "must_include_any": [
666
+ "overview",
667
+ "Pydantic",
668
+ "SQLAlchemy",
669
+ "FastAPI"
670
+ ],
671
+ "min_keyword_matches": 2
672
+ }
673
+ ]
requirements.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.109.2
2
+ uvicorn[standard]==0.27.1
3
+ sqlalchemy==2.0.25
4
+ pydantic==2.6.1
5
+ python-dotenv==1.0.1
6
+
7
+ openai==1.12.0
8
+ google-genai==1.12.1
9
+ httpx==0.28.1
10
+ numpy==1.26.4
11
+ rank-bm25==0.2.2
12
+ qdrant-client==1.15.1
13
+ sentence-transformers==2.7.0
14
+ einops==0.8.1
15
+ tree-sitter==0.21.3
16
+ tree-sitter-languages==1.10.2
17
+
18
+ ragas==0.1.10
19
+ datasets==2.18.0
20
+ pandas==2.2.0
server_app.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import List, Literal, Optional
4
+
5
+ from fastapi import BackgroundTasks, Depends, FastAPI, Header, HTTPException, Query
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from pydantic import BaseModel, Field, HttpUrl
8
+ from dotenv import load_dotenv
9
+
10
+ from src.rag_system import CodebaseRAGSystem
11
+
12
+ load_dotenv(Path(__file__).with_name(".env"))
13
+
14
+
15
# FastAPI application object; served by uvicorn (see the Dockerfile CMD).
app = FastAPI(
    version="2.0.0",
    title="Codebase RAG API",
    description="Index GitHub repositories and answer natural-language questions with grounded citations.",
)
20
+
21
# Comma-separated allow-list of browser origins; defaults to the local dev
# frontend when CORS_ORIGINS is unset. Blank entries are dropped.
_raw_origins = os.getenv("CORS_ORIGINS", "http://localhost:3000")
cors_origins = [part.strip() for part in _raw_origins.split(",") if part.strip()]
26
+
27
# Allow the configured browser origins to call the API, including with
# credentials and arbitrary methods/headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=cors_origins,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
)

# Shared backend instance; populated once in the startup hook.
rag_system: Optional[CodebaseRAGSystem] = None
36
+
37
+
38
class RepoIndexRequest(BaseModel):
    """Request payload for queueing a repository index job."""

    # Pydantic validates that this is a well-formed http(s) URL.
    github_url: HttpUrl
40
+
41
+
42
class MessageTurn(BaseModel):
    """One prior turn of the conversation history sent with a query."""

    role: Literal["user", "assistant"]
    content: str = Field(..., min_length=1, max_length=4000)


class QueryRequest(BaseModel):
    """Request payload for asking a question about an indexed repository."""

    repo_id: int = Field(..., ge=1)
    question: str = Field(..., min_length=3)
    # Retrieval depth is clamped server-side to a sane range.
    top_k: int = Field(8, ge=3, le=12)
    # MessageTurn is defined above, so the fragile string forward reference
    # (List["MessageTurn"] to a later-defined class) is no longer needed.
    history: List[MessageTurn] = Field(default_factory=list, max_length=8)
52
+
53
+
54
def require_session_id(x_session_id: Optional[str] = Header(None, alias="X-Session-Id")) -> str:
    """FastAPI dependency: return the trimmed X-Session-Id header, or fail 400."""
    session_id = (x_session_id or "").strip()
    if not session_id:
        raise HTTPException(status_code=400, detail="Missing session id")
    return session_id
58
+
59
+
60
@app.on_event("startup")
def startup():
    """Build the shared RAG system once the application boots."""
    global rag_system
    # Make sure the on-disk data directory exists before the system uses it.
    Path("./data").mkdir(exist_ok=True)
    rag_system = CodebaseRAGSystem()
65
+
66
+
67
+ @app.get("/")
68
+ async def root():
69
+ return {
70
+ "status": "online",
71
+ "message": "Codebase RAG API is running",
72
+ }
73
+
74
+
75
+ @app.get("/api/health")
76
+ async def health():
77
+ return {
78
+ "status": "ok",
79
+ }
80
+
81
+
82
+ @app.get("/api/repos")
83
+ async def list_repositories(session_id: str = Depends(require_session_id)):
84
+ return rag_system.list_repositories_for_session(session_id)
85
+
86
+
87
+ @app.get("/api/repos/{repo_id}")
88
+ async def get_repository(repo_id: int, session_id: str = Depends(require_session_id)):
89
+ repo = rag_system.get_repository_for_session(repo_id, session_id)
90
+ if not repo:
91
+ raise HTTPException(status_code=404, detail="Repository not found")
92
+ return repo
93
+
94
+
95
+ @app.post("/api/repos/index")
96
+ async def queue_repository_index(
97
+ request: RepoIndexRequest,
98
+ background_tasks: BackgroundTasks,
99
+ session_id: str = Depends(require_session_id),
100
+ ):
101
+ try:
102
+ repo = rag_system.create_or_reset_repository(str(request.github_url), session_id)
103
+ background_tasks.add_task(rag_system.index_repository, repo.id)
104
+ return {
105
+ "success": True,
106
+ "message": "Repository indexing started",
107
+ "repo": rag_system.get_repository_for_session(repo.id, session_id),
108
+ }
109
+ except Exception as exc:
110
+ raise HTTPException(status_code=400, detail=str(exc))
111
+
112
+
113
+ @app.post("/api/query")
114
+ async def query_repository(request: QueryRequest, session_id: str = Depends(require_session_id)):
115
+ try:
116
+ return rag_system.answer_question(
117
+ repo_id=request.repo_id,
118
+ session_key=session_id,
119
+ question=request.question.strip(),
120
+ top_k=request.top_k,
121
+ history=request.history,
122
+ )
123
+ except ValueError as exc:
124
+ raise HTTPException(status_code=400, detail=str(exc))
125
+ except Exception as exc:
126
+ raise HTTPException(status_code=500, detail=str(exc))
127
+
128
+
129
+ @app.post("/api/session/end")
130
+ async def end_session(session_id: str = Query(..., min_length=8)):
131
+ rag_system.end_session(session_id)
132
+ return {"success": True}
133
+
134
+
135
if __name__ == "__main__":
    # Local development entry point; the container uses the Dockerfile CMD
    # (uvicorn on port 7860) instead.
    import uvicorn

    uvicorn.run("server_app:app", host="0.0.0.0", port=8000, reload=True)
src/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Codebase RAG backend package.
3
+ """
4
+
5
+ from .code_parser import CodeParser
6
+ from .embeddings import EmbeddingGenerator
7
+ from .hybrid_search import HybridSearchEngine
8
+ from .rag_system import CodebaseRAGSystem
9
+ from .repo_fetcher import RepoFetcher
10
+ from .vector_store import QdrantVectorStore
11
+
12
__version__ = "2.0.0"
# Names re-exported as the package's public API.
__all__ = [
    "CodeParser",
    "CodebaseRAGSystem",
    "EmbeddingGenerator",
    "QdrantVectorStore",
    "HybridSearchEngine",
    "RepoFetcher",
]
src/code_parser.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from pathlib import Path
3
+ from typing import Dict, List, Optional
4
+
5
+ from tree_sitter_languages import get_parser
6
+
7
+
8
# File-extension → tree-sitter language name. Extensions not listed here are
# treated as plain text by CodeParser.detect_language.
LANGUAGE_BY_EXTENSION = {
    ".py": "python",
    ".js": "javascript",
    ".jsx": "javascript",  # JSX parses with the plain JavaScript grammar
    ".ts": "typescript",
    ".tsx": "tsx",  # TSX has its own grammar, distinct from typescript
    ".java": "java",
    ".go": "go",
    ".rs": "rust",
}
18
+
19
# Symbol-like node types captured for TypeScript; TSX uses the same set.
_TS_SYMBOL_TYPES = {
    "function_declaration",
    "class_declaration",
    "method_definition",
    "interface_declaration",
    "type_alias_declaration",
    "lexical_declaration",
    "variable_statement",
}

# tree-sitter node types treated as chunk boundaries, per language.
SYMBOL_NODE_TYPES = {
    "python": {"function_definition", "class_definition"},
    "javascript": {
        "function_declaration",
        "class_declaration",
        "method_definition",
        "generator_function_declaration",
        "lexical_declaration",
        "variable_declaration",
    },
    "typescript": set(_TS_SYMBOL_TYPES),
    "tsx": set(_TS_SYMBOL_TYPES),
    "java": {
        "class_declaration",
        "method_declaration",
        "interface_declaration",
        "enum_declaration",
    },
    "go": {
        "function_declaration",
        "method_declaration",
        "type_declaration",
    },
    "rust": {
        "function_item",
        "impl_item",
        "struct_item",
        "enum_item",
        "trait_item",
    },
}
66
+
67
# Node types that can carry a symbol's name when the grammar exposes no
# explicit `name` field (used by the identifier-search fallback).
IDENTIFIER_TYPES = {
    "identifier",
    "property_identifier",
    "type_identifier",
    "field_identifier",
}
73
+
74
+
75
class CodeParser:
    """Chunk source files into symbol-level pieces for retrieval.

    Tree-sitter is used when a grammar exists for the detected language;
    otherwise (or when parsing yields no symbols) files are split
    heuristically by line count and definition-looking prefixes.
    """

    def __init__(self):
        # Lazily built cache of tree-sitter parsers, keyed by language name.
        self.parsers = {}

    def detect_language(self, file_path: str) -> str:
        """Map a file extension to a language name, defaulting to 'text'."""
        return LANGUAGE_BY_EXTENSION.get(Path(file_path).suffix.lower(), "text")

    def _get_parser(self, language: str):
        """Return a cached tree-sitter parser, or None for plain text."""
        if language == "text":
            return None
        if language not in self.parsers:
            self.parsers[language] = get_parser(language)
        return self.parsers[language]

    def chunk_file(self, file_path: str, repo_root: str) -> List[Dict]:
        """Split one file (under *repo_root*) into a list of chunk dicts.

        Returns an empty list for blank files. Symbol-level chunks are
        preferred; heuristic chunking is the fallback when no grammar is
        available or no symbol nodes were captured.
        """
        language = self.detect_language(file_path)
        source = Path(file_path).read_text(encoding="utf-8", errors="ignore")
        relative_path = str(Path(file_path).resolve().relative_to(Path(repo_root).resolve()))

        if not source.strip():
            return []

        parser = self._get_parser(language)
        if parser is None:
            return self._fallback_chunks(source, relative_path, language)

        tree = parser.parse(bytes(source, "utf-8"))
        lines = source.splitlines()
        chunks = []
        capture_types = SYMBOL_NODE_TYPES.get(language, set())

        def visit(node):
            # Capture the outermost symbol node and stop descending, so
            # nested definitions stay inside their parent's chunk.
            if node.type in capture_types:
                chunk = self._build_chunk(node, source, lines, relative_path, language)
                if chunk:
                    chunks.append(chunk)
                return
            for child in node.children:
                visit(child)

        visit(tree.root_node)

        if not chunks:
            return self._fallback_chunks(source, relative_path, language)

        return chunks

    def _build_chunk(self, node, source: str, lines: List[str], relative_path: str, language: str) -> Optional[Dict]:
        """Turn one tree-sitter symbol node into a chunk dict.

        Returns None for one-line symbols, which carry too little signal
        to index on their own.
        """
        start_line = node.start_point[0] + 1  # tree-sitter rows are 0-based
        end_line = node.end_point[0] + 1
        snippet = "\n".join(lines[start_line - 1 : end_line]).strip()
        if len(snippet.splitlines()) < 2:
            return None

        # Prefer the grammar's "name" field; fall back to a BFS for any
        # identifier-like descendant.
        name_node = node.child_by_field_name("name")
        symbol_name = None
        if name_node is not None:
            symbol_name = source[name_node.start_byte : name_node.end_byte].strip()
        if not symbol_name:
            symbol_name = self._find_identifier(node, source)

        signature = lines[start_line - 1].strip() if start_line - 1 < len(lines) else ""
        searchable_text = "\n".join(
            part for part in [relative_path, symbol_name or "", signature, snippet] if part
        )

        return {
            "file_path": relative_path,
            "language": language,
            "symbol_name": symbol_name or relative_path.split("/")[-1],
            "symbol_type": node.type,
            "line_start": start_line,
            "line_end": end_line,
            "signature": signature,
            "content": snippet,
            "searchable_text": searchable_text,
            "metadata_json": {
                "parser": "tree-sitter",
            },
        }

    def _find_identifier(self, node, source: str) -> Optional[str]:
        """Breadth-first search for the first identifier-like descendant.

        Uses an index cursor over a growing list instead of ``list.pop(0)``
        so the scan is O(n) over the subtree rather than O(n^2).
        """
        queue = list(node.children)
        cursor = 0
        while cursor < len(queue):
            current = queue[cursor]
            cursor += 1
            if current.type in IDENTIFIER_TYPES:
                return source[current.start_byte : current.end_byte].strip()
            queue.extend(current.children)
        return None

    def _fallback_chunks(self, source: str, relative_path: str, language: str) -> List[Dict]:
        """Heuristic line-based chunking for unparseable or plain-text files.

        Text files break on long runs (>= 60 lines) or '#' headings; code
        files break when a definition-looking line arrives after a run of
        more than 8 lines, or after 80 lines regardless. Note the line that
        triggers a break is included at the END of the current buffer.
        """
        blocks = []
        lines = source.splitlines()
        buffer = []
        start_line = 1
        for index, line in enumerate(lines, start=1):
            if not buffer:
                start_line = index
            buffer.append(line)
            if language == "text":
                trigger = len(buffer) >= 60 or (line.startswith("#") and len(buffer) > 8)
            else:
                trigger = (
                    re.match(r"^\s*(def |class |function |const |export |interface |type )", line)
                    and len(buffer) > 8
                ) or len(buffer) >= 80

            if trigger:
                chunk = self._make_fallback_chunk(buffer, relative_path, language, start_line, index)
                if chunk:
                    blocks.append(chunk)
                buffer = []

        if buffer:
            chunk = self._make_fallback_chunk(buffer, relative_path, language, start_line, len(lines))
            if chunk:
                blocks.append(chunk)
        return blocks

    def _make_fallback_chunk(self, buffer: List[str], relative_path: str, language: str, start_line: int, end_line: int) -> Optional[Dict]:
        """Build one fallback chunk dict, or None when the buffer is blank."""
        chunk_text = "\n".join(buffer).strip()
        if not chunk_text:
            return None
        return {
            "file_path": relative_path,
            "language": language,
            "symbol_name": f"{Path(relative_path).name}:{start_line}",
            "symbol_type": "fallback_chunk",
            "line_start": start_line,
            "line_end": end_line,
            "signature": buffer[0].strip(),
            "content": chunk_text,
            "searchable_text": f"{relative_path}\n{chunk_text}",
            "metadata_json": {
                "parser": "fallback",
            },
        }
src/database.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from datetime import datetime
3
+ from pathlib import Path
4
+
5
+ from sqlalchemy import (
6
+ JSON,
7
+ Column,
8
+ DateTime,
9
+ Float,
10
+ ForeignKey,
11
+ Integer,
12
+ String,
13
+ Text,
14
+ create_engine,
15
+ inspect,
16
+ text,
17
+ )
18
+ from sqlalchemy.orm import declarative_base, relationship, sessionmaker
19
+
20
+ Base = declarative_base()
21
+ _ENGINE_CACHE = {}
22
+ _SESSION_FACTORY_CACHE = {}
23
+ SERVER_DIR = Path(__file__).resolve().parents[1]
24
+
25
+
26
class Repository(Base):
    # One indexed repository, scoped to a user session. Rows are keyed by
    # `github_url`, which callers populate with a session-qualified registry
    # key (the raw URL the user submitted lives in `source_url`).
    __tablename__ = "repositories"

    id = Column(Integer, primary_key=True)
    # Unique registry key for this (session, repo) combination.
    github_url = Column(String(1024), nullable=False, unique=True)
    # Original URL as submitted by the user.
    source_url = Column(String(1024))
    # Session ownership + expiry, used to reap abandoned sessions.
    session_key = Column(String(255), index=True)
    session_expires_at = Column(DateTime)
    owner = Column(String(255), nullable=False)
    name = Column(String(255), nullable=False)
    branch = Column(String(255), nullable=False, default="main")
    # Checkout location while indexing; cleared once the clone is deleted.
    local_path = Column(String(1024))
    # Lifecycle: "queued" -> "indexing" -> terminal state (see rag_system).
    status = Column(String(64), nullable=False, default="queued")
    error_message = Column(Text)
    file_count = Column(Integer, nullable=False, default=0)
    chunk_count = Column(Integer, nullable=False, default=0)
    indexed_at = Column(DateTime)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    # Children are deleted with the repository row (delete-orphan cascade).
    chunks = relationship(
        "CodeChunk", back_populates="repository", cascade="all, delete-orphan"
    )
    chat_turns = relationship(
        "ChatTurn", back_populates="repository", cascade="all, delete-orphan"
    )
52
+
53
+
54
class CodeChunk(Base):
    # One retrievable chunk of a repository: either a tree-sitter symbol or
    # a heuristic fallback block (see `symbol_type` / `metadata_json`).
    __tablename__ = "code_chunks"

    id = Column(Integer, primary_key=True)
    repository_id = Column(Integer, ForeignKey("repositories.id"), nullable=False)
    # Path relative to the repository root.
    file_path = Column(String(1024), nullable=False)
    language = Column(String(64), nullable=False)
    symbol_name = Column(String(255))
    # Tree-sitter node type (e.g. "function_definition") or "fallback_chunk".
    symbol_type = Column(String(128), nullable=False, default="chunk")
    # 1-based inclusive line span of the chunk within the file.
    line_start = Column(Integer, nullable=False)
    line_end = Column(Integer, nullable=False)
    signature = Column(Text)
    # Raw source snippet; `searchable_text` additionally prepends path/name
    # so lexical search can match on them.
    content = Column(Text, nullable=False)
    searchable_text = Column(Text, nullable=False)
    metadata_json = Column(JSON, nullable=False, default=dict)
    # Identifier of the corresponding vector in the vector store, if any.
    embedding_id = Column(Integer)
    rerank_score = Column(Float)
    created_at = Column(DateTime, default=datetime.utcnow)

    repository = relationship("Repository", back_populates="chunks")
74
+
75
+
76
class ChatTurn(Base):
    # One message in a repository's Q&A transcript ("user" or assistant
    # role in `role`; structured answer payload, when present, in
    # `answer_json`).
    __tablename__ = "chat_turns"

    id = Column(Integer, primary_key=True)
    repository_id = Column(Integer, ForeignKey("repositories.id"), nullable=False)
    role = Column(String(32), nullable=False)
    content = Column(Text, nullable=False)
    answer_json = Column(JSON)
    created_at = Column(DateTime, default=datetime.utcnow)

    repository = relationship("Repository", back_populates="chat_turns")
87
+
88
+
89
def init_db(database_url: str = None):
    """Return a cached ``(engine, session_factory)`` pair for *database_url*.

    On first use of a given URL this creates the schema and applies the
    lightweight column backfills; later calls reuse the cached engine.
    Falls back to ``DATABASE_URL`` (or a local SQLite file) when no URL is
    given.
    """
    if database_url is None:
        database_url = os.getenv("DATABASE_URL", "sqlite:///./codebase_rag.db")

    url = resolve_database_url(database_url)
    cached_engine = _ENGINE_CACHE.get(url)
    if cached_engine is not None:
        return cached_engine, _SESSION_FACTORY_CACHE[url]

    # SQLite restricts connections to their creating thread by default;
    # relax that since the engine is shared across worker threads.
    connect_args = {"check_same_thread": False} if url.startswith("sqlite") else {}
    engine = create_engine(url, echo=False, connect_args=connect_args)
    Base.metadata.create_all(engine)
    _ensure_runtime_columns(engine)
    factory = sessionmaker(bind=engine)
    _ENGINE_CACHE[url] = engine
    _SESSION_FACTORY_CACHE[url] = factory
    return engine, factory
105
+
106
+
107
def resolve_database_url(database_url: str) -> str:
    """Normalize a database URL, anchoring relative SQLite paths.

    Non-SQLite URLs and ``:memory:`` databases pass through untouched.
    Relative SQLite file paths are resolved under ``SERVER_DIR`` and the
    file (plus parent directories) is created eagerly so the engine can
    open it.
    """
    prefix = "sqlite:///"
    if not database_url.startswith(prefix):
        return database_url

    raw_path = database_url[len(prefix):]
    if raw_path == ":memory:":
        return database_url

    resolved = Path(raw_path)
    if not resolved.is_absolute():
        resolved = SERVER_DIR / resolved
    resolved.parent.mkdir(parents=True, exist_ok=True)
    resolved.touch(exist_ok=True)
    return f"sqlite:///{resolved.resolve()}"
121
+
122
+
123
def _ensure_runtime_columns(engine):
    """Backfill columns added after the initial schema (minimal migration).

    SQLAlchemy's ``create_all`` never alters existing tables, so columns
    introduced later are added here with explicit ALTER TABLE statements
    when missing. No-op if the table itself does not exist yet.
    """
    inspector = inspect(engine)
    if "repositories" not in inspector.get_table_names():
        return

    present = {col["name"] for col in inspector.get_columns("repositories")}
    wanted = (
        ("source_url", "ALTER TABLE repositories ADD COLUMN source_url VARCHAR(1024)"),
        ("session_key", "ALTER TABLE repositories ADD COLUMN session_key VARCHAR(255)"),
        ("session_expires_at", "ALTER TABLE repositories ADD COLUMN session_expires_at DATETIME"),
    )

    with engine.begin() as connection:
        for name, ddl in wanted:
            if name not in present:
                connection.execute(text(ddl))
139
+
140
+
141
def get_db_session(database_url: str = None):
    """Open a fresh ORM session bound to the cached engine for *database_url*."""
    _, factory = init_db(database_url)
    return factory()
src/document_processor.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ from typing import List, Tuple
3
+ from pathlib import Path
4
+ import pypdf
5
+
6
+
7
class DocumentProcessor:
    """Extract text from PDF/TXT files and split it into overlapping chunks."""

    def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50):
        # Target chunk length in characters, and how many characters each
        # chunk re-reads from the end of the previous one.
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap

    def extract_text_from_pdf(self, file_path: str) -> str:
        """Return the concatenated text of every page in the PDF.

        Raises ValueError (chained to the underlying error) when the file
        cannot be read or parsed.
        """
        text = ""
        try:
            with open(file_path, "rb") as file:
                pdf_reader = pypdf.PdfReader(file)
                for page in pdf_reader.pages:
                    # extract_text() can yield None/empty for image-only pages.
                    text += (page.extract_text() or "") + "\n"
        except Exception as e:
            raise ValueError(f"Error reading PDF: {str(e)}") from e

        return text.strip()

    def chunk_text(self, text: str) -> List[str]:
        """Split *text* into ~chunk_size-character chunks with overlap.

        When a sentence end or newline falls in the second half of the
        window, the chunk breaks there instead of mid-sentence. Returns an
        empty list for empty input; blank chunks are dropped.
        """
        if not text:
            return []

        chunks = []
        start = 0
        text_length = len(text)

        while start < text_length:
            end = start + self.chunk_size
            chunk = text[start:end]

            if end < text_length:
                last_period = chunk.rfind(".")
                last_newline = chunk.rfind("\n")
                break_point = max(last_period, last_newline)

                if break_point > self.chunk_size * 0.5:
                    chunk = chunk[: break_point + 1]
                    end = start + break_point + 1

            chunks.append(chunk.strip())

            # Always advance past the previous start: without the max(), a
            # chunk_overlap >= chunk_size would make this loop spin forever.
            start = max(end - self.chunk_overlap, start + 1)

        return [c for c in chunks if c]

    def process_document(self, file_path: str) -> Tuple[str, List[str]]:
        """Extract and chunk a document; returns ``(full_text, chunks)``.

        Supports .pdf and .txt files; raises ValueError for anything else.
        """
        file_ext = Path(file_path).suffix.lower()

        if file_ext == ".pdf":
            text = self.extract_text_from_pdf(file_path)
        elif file_ext == ".txt":
            with open(file_path, "r", encoding="utf-8") as f:
                text = f.read()
        else:
            raise ValueError(f"Unsupported file type: {file_ext}")

        chunks = self.chunk_text(text)

        return text, chunks

    @staticmethod
    def compute_file_hash(file_path: str) -> str:
        """MD5 of the file contents, read in 4 KiB blocks.

        Used as a cheap dedup fingerprint only — MD5 is not suitable for
        security purposes.
        """
        hash_md5 = hashlib.md5()
        with open(file_path, "rb") as f:
            for chunk in iter(lambda: f.read(4096), b""):
                hash_md5.update(chunk)
        return hash_md5.hexdigest()
src/embeddings.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ from typing import Callable, List, Optional
4
+
5
+ import numpy as np
6
+ from openai import OpenAI
7
+ from sentence_transformers import SentenceTransformer
8
+
9
+
10
class EmbeddingGenerator:
    """Produce dense embeddings via OpenAI, Vertex AI, or a local model.

    Provider resolution: an explicit ``provider`` argument wins, then the
    ``EMBEDDING_PROVIDER`` env var; ``auto`` resolves to "local" on
    HF Spaces / test runs and "vertex_ai" otherwise (see
    ``_resolve_provider``).
    """

    def __init__(self, provider: str = None, model_name: str = None):
        configured_provider = (provider or os.getenv("EMBEDDING_PROVIDER", "auto")).lower()
        self.provider = self._resolve_provider(configured_provider)
        self.model_name = model_name or self._resolve_model_name()
        self.batch_size = int(os.getenv("EMBEDDING_BATCH_SIZE", "8"))
        self.device = os.getenv("EMBEDDING_DEVICE")
        self.client = None
        self.model = None
        # Vertex AI distinguishes document- vs. query-side embeddings by
        # task type; both are overridable through the environment.
        self.vertex_task_type_document = os.getenv(
            "VERTEX_EMBEDDING_TASK_TYPE_DOCUMENT", "RETRIEVAL_DOCUMENT"
        )
        self.vertex_task_type_query = os.getenv(
            "VERTEX_EMBEDDING_TASK_TYPE_QUERY", "RETRIEVAL_QUERY"
        )
        self.vertex_output_dimensionality = self._optional_int(
            os.getenv("VERTEX_EMBEDDING_OUTPUT_DIMENSIONALITY")
        )
        self.query_prefix = os.getenv("EMBEDDING_QUERY_PREFIX", "").strip()
        normalized_model_name = self.model_name.lower()
        # Only the nomic-embed-code / CodeRankEmbed family gets a named
        # query prompt; other sentence-transformers models receive none.
        self.query_prompt_name = (
            os.getenv("EMBEDDING_QUERY_PROMPT_NAME", "query")
            if "nomic-embed-code" in normalized_model_name
            or "coderankembed" in normalized_model_name
            else None
        )

        if self.provider == "openai":
            print(
                f"[embeddings] Initializing OpenAI embeddings with model={self.model_name}",
                flush=True,
            )
            self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
            # Dimension is taken on trust from the env var, not probed.
            self.embedding_dim = int(os.getenv("OPENAI_EMBEDDING_DIM", "1536"))
        elif self.provider == "vertex_ai":
            print(
                f"[embeddings] Initializing Vertex AI embeddings with model={self.model_name}",
                flush=True,
            )
            # Imported lazily so local/OpenAI deployments don't need the SDK.
            try:
                from google import genai
            except ImportError as exc:
                raise RuntimeError(
                    "Vertex AI embedding support requires the `google-genai` package."
                ) from exc

            project = os.getenv("GOOGLE_CLOUD_PROJECT")
            location = os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1")
            if not project:
                raise RuntimeError(
                    "GOOGLE_CLOUD_PROJECT must be set when using Vertex AI embeddings."
                )

            self.client = genai.Client(
                vertexai=True,
                project=project,
                location=location,
            )
            self.embedding_dim = int(
                os.getenv(
                    "VERTEX_EMBEDDING_DIM",
                    str(self.vertex_output_dimensionality or 3072),
                )
            )
        else:
            # Local provider: load a SentenceTransformer and read the true
            # dimension from the model itself.
            model_device = self.device or "cpu"
            print(
                f"[embeddings] Loading local embedding model={self.model_name} on device={model_device}",
                flush=True,
            )
            started_at = time.perf_counter()
            self.model = SentenceTransformer(
                self.model_name,
                trust_remote_code=True,
                device=model_device,
            )
            self.embedding_dim = self.model.get_sentence_embedding_dimension()
            elapsed = time.perf_counter() - started_at
            print(
                f"[embeddings] Model ready dim={self.embedding_dim} load_time={elapsed:.2f}s",
                flush=True,
            )

    def embed_text(self, text: str) -> np.ndarray:
        """Embed a single query string and return its vector.

        Query-side handling differs per provider: Vertex uses the query
        task type; local models get the optional query prefix/prompt.
        """
        if self.provider == "openai":
            return self.embed_batch([text])[0]
        if self.provider == "vertex_ai":
            return self._embed_with_vertex(
                [text],
                task_type=self.vertex_task_type_query,
            )[0]
        query_text = f"{self.query_prefix}: {text}" if self.query_prefix else text
        return self._encode_with_backoff([query_text], prompt_name=self.query_prompt_name)[0]

    def embed_batch(
        self,
        texts: List[str],
        batch_size: int = None,
        progress_callback: Optional[Callable[[int, int], None]] = None,
    ) -> np.ndarray:
        """Embed *texts* (document side), reporting progress as (done, total).

        Returns a float32 array; empty input yields an empty array.
        """
        if not texts:
            return np.array([], dtype="float32")

        if self.provider == "openai":
            # OpenAI accepts the whole list in one request; no sub-batching.
            response = self.client.embeddings.create(
                model=self.model_name or "text-embedding-3-small",
                input=texts,
            )
            embeddings = [item.embedding for item in response.data]
            if progress_callback:
                progress_callback(len(texts), len(texts))
            return np.array(embeddings, dtype="float32")
        if self.provider == "vertex_ai":
            return self._embed_batch_with_vertex(
                texts=texts,
                batch_size=batch_size,
                progress_callback=progress_callback,
            )

        # Local path: encode in fixed-size batches with progress logging.
        effective_batch_size = max(1, batch_size or self.batch_size)
        all_embeddings = []
        total = len(texts)

        for start in range(0, total, effective_batch_size):
            batch = texts[start : start + effective_batch_size]
            batch_number = (start // effective_batch_size) + 1
            total_batches = (total + effective_batch_size - 1) // effective_batch_size
            print(
                f"[embeddings] Encoding batch {batch_number}/{total_batches} "
                f"items={len(batch)} progress={start}/{total}",
                flush=True,
            )
            started_at = time.perf_counter()
            batch_embeddings = self._encode_with_backoff(
                batch,
                batch_size=min(effective_batch_size, len(batch)),
            )
            all_embeddings.append(batch_embeddings)
            elapsed = time.perf_counter() - started_at
            print(
                f"[embeddings] Finished batch {batch_number}/{total_batches} "
                f"elapsed={elapsed:.2f}s progress={min(start + len(batch), total)}/{total}",
                flush=True,
            )
            if progress_callback:
                progress_callback(min(start + len(batch), total), total)

        return np.vstack(all_embeddings).astype("float32")

    def _embed_batch_with_vertex(
        self,
        texts: List[str],
        batch_size: int = None,
        progress_callback: Optional[Callable[[int, int], None]] = None,
    ) -> np.ndarray:
        """Vertex AI counterpart of the local batching loop in embed_batch."""
        effective_batch_size = max(1, batch_size or self.batch_size)
        all_embeddings = []
        total = len(texts)

        for start in range(0, total, effective_batch_size):
            batch = texts[start : start + effective_batch_size]
            batch_number = (start // effective_batch_size) + 1
            total_batches = (total + effective_batch_size - 1) // effective_batch_size
            print(
                f"[embeddings] Vertex batch {batch_number}/{total_batches} "
                f"items={len(batch)} progress={start}/{total}",
                flush=True,
            )
            started_at = time.perf_counter()
            batch_embeddings = self._embed_with_vertex(
                batch,
                task_type=self.vertex_task_type_document,
            )
            all_embeddings.append(batch_embeddings)
            elapsed = time.perf_counter() - started_at
            print(
                f"[embeddings] Finished Vertex batch {batch_number}/{total_batches} "
                f"elapsed={elapsed:.2f}s progress={min(start + len(batch), total)}/{total}",
                flush=True,
            )
            if progress_callback:
                progress_callback(min(start + len(batch), total), total)

        return np.vstack(all_embeddings).astype("float32")

    def _embed_with_vertex(self, texts: List[str], task_type: str) -> np.ndarray:
        """Call the Vertex embed_content API once for *texts*.

        Tolerates the several response item shapes the SDK may return
        (attribute, dict, or legacy "embedding" field); raises RuntimeError
        when the response is empty or unparseable.
        """
        config = {
            "task_type": task_type,
        }
        if self.vertex_output_dimensionality:
            config["output_dimensionality"] = self.vertex_output_dimensionality

        response = self.client.models.embed_content(
            model=self.model_name,
            contents=texts,
            config=config,
        )
        embeddings = getattr(response, "embeddings", None)
        if not embeddings:
            raise RuntimeError("Vertex AI embeddings returned an empty response.")

        values = []
        for item in embeddings:
            if hasattr(item, "values"):
                values.append(item.values)
            elif isinstance(item, dict):
                values.append(item.get("values"))
            else:
                values.append(getattr(item, "embedding", None))

        if not values or any(vector is None for vector in values):
            raise RuntimeError("Vertex AI embeddings response could not be parsed.")

        return np.array(values, dtype="float32")

    def _encode_with_backoff(
        self,
        texts: List[str],
        batch_size: int = None,
        prompt_name: str = None,
    ) -> np.ndarray:
        """Encode locally, halving the batch size on memory errors.

        Retries until the encode succeeds or the batch size is already 1,
        at which point the error is re-raised. Embeddings are L2-normalized.
        """
        effective_batch_size = max(1, batch_size or self.batch_size)

        while True:
            try:
                encode_kwargs = {
                    "sentences": texts,
                    "batch_size": effective_batch_size,
                    "show_progress_bar": len(texts) > effective_batch_size,
                    "convert_to_numpy": True,
                    "normalize_embeddings": True,
                }
                if prompt_name:
                    encode_kwargs["prompt_name"] = prompt_name

                embeddings = self.model.encode(
                    **encode_kwargs,
                )
                return embeddings.astype("float32")
            except RuntimeError as exc:
                # Heuristic: CUDA/MPS OOM surfaces as RuntimeError with
                # "out of memory" (or an MPS-specific message) in the text.
                message = str(exc).lower()
                is_memory_error = "out of memory" in message or "mps" in message
                if not is_memory_error or effective_batch_size == 1:
                    raise
                print(
                    f"[embeddings] Retrying batch with smaller size due to memory pressure: "
                    f"{effective_batch_size} -> {max(1, effective_batch_size // 2)}",
                    flush=True,
                )
                effective_batch_size = max(1, effective_batch_size // 2)

    def get_embedding_dim(self) -> int:
        """Dimensionality of the vectors this generator produces."""
        return self.embedding_dim

    def _resolve_provider(self, configured_provider: str) -> str:
        """Resolve "auto" to a concrete provider; pass others through."""
        if configured_provider != "auto":
            return configured_provider
        if self._is_hf_space() or self._is_test_context():
            return "local"
        return "vertex_ai"

    def _resolve_model_name(self) -> str:
        """Pick a model name from env vars with provider-aware defaults."""
        explicit_model = os.getenv("EMBEDDING_MODEL")
        if explicit_model:
            return explicit_model
        if self.provider == "vertex_ai":
            return os.getenv("VERTEX_EMBEDDING_MODEL", "gemini-embedding-001")
        if self._is_hf_space() or self._is_test_context():
            # Constrained environments default to a small CPU-friendly model.
            return os.getenv(
                "LIGHTWEIGHT_LOCAL_EMBEDDING_MODEL",
                "sentence-transformers/all-MiniLM-L6-v2",
            )
        return os.getenv("LOCAL_EMBEDDING_MODEL", "nomic-ai/CodeRankEmbed")

    def _is_hf_space(self) -> bool:
        """True when running inside a Hugging Face Space (env-var sniff)."""
        return bool(os.getenv("SPACE_ID") or os.getenv("HF_SPACE_ID"))

    def _is_test_context(self) -> bool:
        """True under pytest or when APP_ENV/ENVIRONMENT is "test"."""
        app_env = os.getenv("APP_ENV", os.getenv("ENVIRONMENT", "")).lower()
        return app_env == "test" or bool(os.getenv("PYTEST_CURRENT_TEST"))

    def _optional_int(self, value: Optional[str]) -> Optional[int]:
        """Parse an optional int env value; blank/None becomes None."""
        if value is None or not str(value).strip():
            return None
        return int(value)
src/hybrid_search.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import re
3
+ from collections import defaultdict
4
+ from typing import List
5
+
6
+ from rank_bm25 import BM25Okapi
7
+ from sentence_transformers import CrossEncoder
8
+
9
+
10
+ TOKEN_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_./:-]*")
11
+
12
+
13
+ def tokenize(text: str) -> List[str]:
14
+ return [token.lower() for token in TOKEN_RE.findall(text)]
15
+
16
+
17
class HybridSearchEngine:
    """Lexical (BM25) + semantic retrieval utilities with reranking."""

    def __init__(self, reranker_model: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        # Cross-encoder scores (query, passage) pairs for the final rerank.
        self.reranker = CrossEncoder(reranker_model)

    def build_for_repository(self, repo_id: int, chunks: List[dict]):
        """No-op: BM25 state is rebuilt per query from the chunk list."""
        return None

    def remove_repository(self, repo_id: int):
        """No-op counterpart of build_for_repository."""
        return None

    def bm25_search(self, chunks: List[dict], query: str, top_k: int = 12) -> List[dict]:
        """Score *chunks* against *query* with BM25; return the best top_k.

        Each hit is a copy of the chunk dict with bm25_score/bm25_rank added.
        """
        if not chunks:
            return []

        query_tokens = tokenize(query)
        if not query_tokens:
            return []

        corpus = [tokenize(chunk["searchable_text"]) for chunk in chunks]
        if not corpus:
            return []

        scores = BM25Okapi(corpus).get_scores(query_tokens)
        best = sorted(
            zip(chunks, scores),
            key=lambda pair: pair[1],
            reverse=True,
        )[:top_k]

        return [
            {**chunk, "bm25_score": float(score), "bm25_rank": position}
            for position, (chunk, score) in enumerate(best, start=1)
        ]

    def reciprocal_rank_fusion(
        self,
        lexical_results: List[dict],
        semantic_results: List[dict],
        top_k: int = 10,
        k: int = 60,
    ) -> List[dict]:
        """Fuse two ranked lists with reciprocal-rank fusion.

        Each appearance of an id contributes 1/(k + rank); entries carry
        the merged chunk fields plus the accumulated "rrf_score".
        """
        pooled = {}
        for ranked_list in (lexical_results, semantic_results):
            for position, item in enumerate(ranked_list, start=1):
                entry = pooled.setdefault(item["id"], {"rrf_score": 0.0})
                entry["rrf_score"] += 1.0 / (k + position)
                entry.update(item)

        ordered = sorted(pooled.values(), key=lambda entry: entry["rrf_score"], reverse=True)
        return ordered[:top_k]

    def rerank(self, query: str, candidates: List[dict], top_k: int = 6) -> List[dict]:
        """Re-score *candidates* with the cross-encoder; keep the best top_k."""
        if not candidates:
            return []

        passages = [
            [query, f'{item["file_path"]}\n{item.get("signature") or ""}\n{item["content"]}']
            for item in candidates
        ]
        scores = self.reranker.predict(passages)

        scored = [
            {**item, "rerank_score": float(score)}
            for item, score in zip(candidates, scores)
        ]
        return sorted(scored, key=lambda item: item["rerank_score"], reverse=True)[:top_k]

    @staticmethod
    def normalize_semantic_results(results: List[dict]) -> List[dict]:
        """Attach 1-based semantic_rank and a float semantic_score to hits."""
        return [
            {
                **item,
                "semantic_rank": position,
                "semantic_score": float(item.get("semantic_score", 0.0)),
            }
            for position, item in enumerate(results, start=1)
        ]
+ return normalized
src/rag_system.py ADDED
@@ -0,0 +1,1145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from datetime import datetime, timedelta
4
+ from typing import Dict, List, Optional
5
+
6
+ from openai import OpenAI
7
+
8
+ from src.code_parser import CodeParser
9
+ from src.database import Repository, get_db_session, init_db, resolve_database_url
10
+ from src.embeddings import EmbeddingGenerator
11
+ from src.hybrid_search import HybridSearchEngine
12
+ from src.repo_fetcher import RepoFetcher
13
+ from src.vector_store import QdrantVectorStore
14
+
15
+
16
class SessionCancelledError(RuntimeError):
    """Raised when a repository's indexing session is cancelled mid-flight."""
18
+
19
+
20
+ class CodebaseRAGSystem:
21
+ def __init__(
22
+ self,
23
+ database_url: str = None,
24
+ repo_dir: str = None,
25
+ index_path: str = None,
26
+ ):
27
+ self.database_url = database_url or os.getenv(
28
+ "DATABASE_URL", "sqlite:///./codebase_rag.db"
29
+ )
30
+ self.database_url = resolve_database_url(self.database_url)
31
+ init_db(self.database_url)
32
+ print(f"[database] Using database_url={self.database_url}", flush=True)
33
+
34
+ self.repo_fetcher = RepoFetcher(base_dir=repo_dir)
35
+ self.parser = CodeParser()
36
+ self.embedder = EmbeddingGenerator()
37
+ self.vector_store = QdrantVectorStore(
38
+ embedding_dim=self.embedder.get_embedding_dim(),
39
+ index_path=index_path or "./data/faiss/codebase_index",
40
+ persist=False,
41
+ )
42
+ self.hybrid_search = HybridSearchEngine(
43
+ reranker_model=os.getenv(
44
+ "RERANKER_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2"
45
+ )
46
+ )
47
+ self.app_env = os.getenv("APP_ENV", os.getenv("ENVIRONMENT", "local")).lower()
48
+ self.llm_provider = os.getenv("LLM_PROVIDER", "vertex_ai").lower()
49
+ self.llm_client = None
50
+ self.llm_model = ""
51
+ self._configure_llm()
52
+ self.session_ttl_minutes = int(os.getenv("SESSION_TTL_MINUTES", "120"))
53
+ self.indexing_progress: Dict[int, dict] = {}
54
+ self.repo_chunks: Dict[int, List[dict]] = {}
55
+ self.cancelled_repo_ids = set()
56
+ self.rebuild_indexes()
57
+
58
+ def rebuild_indexes(self):
59
+ session = get_db_session(self.database_url)
60
+ try:
61
+ self.vector_store.clear()
62
+ self.repo_chunks.clear()
63
+ self.indexing_progress.clear()
64
+ self.cancelled_repo_ids.clear()
65
+ repos = session.query(Repository).all()
66
+ self._delete_repositories(session, repos, track_cancellation=False)
67
+ self.cancelled_repo_ids.clear()
68
+ session.commit()
69
+ finally:
70
+ session.close()
71
+
72
    def create_or_reset_repository(self, github_url: str, session_key: str) -> Repository:
        """Create a queued Repository row for this session/URL pair, or reset an existing one.

        The registry key is built from the session key plus the URL, so the same
        GitHub repo can be indexed independently by different sessions. Resetting
        an existing row clears prior status, counters, and any cached index state
        (hybrid search, vector store, in-memory chunks).
        """
        info = self.repo_fetcher.parse_github_url(github_url)
        registry_key = self._build_registry_key(session_key, github_url)
        session = get_db_session(self.database_url)
        try:
            self._cleanup_expired_sessions(session)
            repo = session.query(Repository).filter_by(github_url=registry_key).first()
            if repo is None:
                repo = Repository(
                    github_url=registry_key,
                    source_url=github_url,
                    session_key=session_key,
                    session_expires_at=self._session_expiry(),
                    owner=info["owner"],
                    name=info["repo"],
                    branch=info["branch"],
                    status="queued",
                )
                session.add(repo)
                # Flush so the autogenerated id exists before touching caches.
                session.flush()
                self.cancelled_repo_ids.discard(repo.id)
            else:
                # Re-queue the existing row and wipe any previous index state.
                repo.source_url = github_url
                repo.session_key = session_key
                repo.session_expires_at = self._session_expiry()
                repo.owner = info["owner"]
                repo.name = info["repo"]
                repo.branch = info["branch"]
                repo.status = "queued"
                repo.error_message = None
                repo.file_count = 0
                repo.chunk_count = 0
                repo.indexed_at = None
                self.cancelled_repo_ids.discard(repo.id)
                self.hybrid_search.remove_repository(repo.id)
                self.vector_store.remove_repository(repo.id)
                self.repo_chunks.pop(repo.id, None)

            session.commit()
            session.refresh(repo)
            return repo
        finally:
            session.close()
115
+
116
    def index_repository(self, repo_id: int):
        """Clone, chunk, embed, and index a queued repository end to end.

        Phases (each reported via ``_set_progress``): clone the repo, parse its
        source files into chunks, embed the chunks, upload embeddings to the
        vector store, then persist counters and cache serialized chunks in
        memory. The cloned working tree is deleted after indexing. Cancellation
        (session teardown) is checked between phases via
        ``_ensure_repo_not_cancelled``; on any failure the partial index state
        is rolled back and the row is marked "failed" (or deleted if cancelled).

        Raises:
            ValueError: if no Repository row exists for ``repo_id``.
        """
        session = get_db_session(self.database_url)
        try:
            self._cleanup_expired_sessions(session)
            repo = session.query(Repository).filter_by(id=repo_id).first()
            if repo is None:
                raise ValueError("Repository not found")
            self._ensure_repo_not_cancelled(repo.id)
            print(f"[indexing] Starting repository index repo_id={repo.id}", flush=True)

            # Mark the row as in-flight and extend the session TTL.
            repo.status = "indexing"
            repo.error_message = None
            repo.session_expires_at = self._session_expiry()
            session.commit()
            self._set_progress(repo.id, phase="cloning", message="Cloning repository")

            clone_info = self.repo_fetcher.clone_repository(repo.source_url or repo.github_url)
            self._ensure_repo_not_cancelled(repo.id)
            # The clone is temporary; never persist a local path on the row.
            repo.local_path = None
            repo.branch = clone_info["branch"]
            print(
                f"[indexing] Repository cloned repo_id={repo.id} branch={repo.branch} "
                f"path={clone_info['local_path']}",
                flush=True,
            )

            source_files = list(self.repo_fetcher.iter_source_files(clone_info["local_path"]))
            total_files = len(source_files)
            print(
                f"[indexing] Found {total_files} source files for repo_id={repo.id}",
                flush=True,
            )
            self._set_progress(
                repo.id,
                phase="parsing",
                message=f"Scanning {total_files} source files",
                total_files=total_files,
                processed_files=0,
                discovered_chunks=0,
            )

            # Parse each file into chunks; files that yield no chunks are not
            # counted toward file_count but still advance the progress bar.
            chunk_payloads = []
            file_count = 0
            for index, file_path in enumerate(source_files, start=1):
                file_chunks = self.parser.chunk_file(str(file_path), clone_info["local_path"])
                if not file_chunks:
                    self._set_progress(
                        repo.id,
                        phase="parsing",
                        message=f"Parsed {index}/{total_files} files",
                        total_files=total_files,
                        processed_files=index,
                        discovered_chunks=len(chunk_payloads),
                    )
                    continue
                file_count += 1
                chunk_payloads.extend(file_chunks)
                self._set_progress(
                    repo.id,
                    phase="parsing",
                    message=f"Parsed {index}/{total_files} files",
                    total_files=total_files,
                    processed_files=index,
                    discovered_chunks=len(chunk_payloads),
                )

            searchable_texts = [chunk["searchable_text"] for chunk in chunk_payloads]
            print(
                f"[indexing] Parsed repo_id={repo.id} files={file_count} chunks={len(searchable_texts)}",
                flush=True,
            )
            self._set_progress(
                repo.id,
                phase="embedding",
                message=f"Embedding {len(searchable_texts)} chunks",
                total_files=total_files,
                processed_files=total_files,
                discovered_chunks=len(chunk_payloads),
                total_chunks=len(chunk_payloads),
                embedded_chunks=0,
            )
            # The embedder reports progress per batch via the callback.
            embeddings = self.embedder.embed_batch(
                searchable_texts,
                progress_callback=lambda completed, total: self._set_progress(
                    repo.id,
                    phase="embedding",
                    message=f"Embedding chunks ({completed}/{total})",
                    total_files=total_files,
                    processed_files=total_files,
                    discovered_chunks=len(chunk_payloads),
                    total_chunks=total,
                    embedded_chunks=completed,
                ),
            )
            self._ensure_repo_not_cancelled(repo.id)

            # Metadata stored alongside each vector for filtered search.
            vector_metadata = []
            for chunk in chunk_payloads:
                vector_metadata.append(
                    {
                        "repository_id": repo.id,
                        "file_path": chunk["file_path"],
                        "language": chunk["language"],
                        "symbol_name": chunk["symbol_name"],
                        "symbol_type": chunk["symbol_type"],
                        "line_start": chunk["line_start"],
                        "line_end": chunk["line_end"],
                        "signature": chunk["signature"],
                        "content": chunk["content"],
                    }
                )

            embedding_ids = self.vector_store.add_embeddings(embeddings, vector_metadata)
            print(
                f"[indexing] Uploaded {len(embedding_ids)} embeddings to vector store for repo_id={repo.id}",
                flush=True,
            )
            self._set_progress(
                repo.id,
                phase="saving",
                message="Saving chunks and search indexes",
                total_files=total_files,
                processed_files=total_files,
                discovered_chunks=len(chunk_payloads),
            )

            created_rows = []
            for chunk, embedding_id in zip(chunk_payloads, embedding_ids):
                row = {
                    **chunk,
                    "id": embedding_id,
                    "repository_id": repo.id,
                    "embedding_id": embedding_id,
                }
                created_rows.append(row)

            repo.status = "indexed"
            repo.file_count = file_count
            repo.chunk_count = len(created_rows)
            repo.indexed_at = datetime.utcnow()
            repo.session_expires_at = self._session_expiry()
            # Last-chance checks before committing the "indexed" state.
            self._ensure_repo_still_exists(session, repo.id)
            self._ensure_repo_not_cancelled(repo.id)
            session.commit()

            serialized = [self._serialize_chunk(chunk) for chunk in created_rows]
            self.repo_chunks[repo.id] = serialized
            self.vector_store.save()
            self.indexing_progress.pop(repo.id, None)
            self.cancelled_repo_ids.discard(repo.id)
            # The working tree is no longer needed once chunks are cached.
            self.repo_fetcher.cleanup_repository(clone_info["local_path"])
            print(f"[indexing] Repository index complete repo_id={repo.id}", flush=True)
        except Exception as exc:
            print(f"[indexing] Repository index failed repo_id={repo_id} error={exc}", flush=True)
            session.rollback()
            # Undo any partial index state for this repo.
            self.vector_store.remove_repository(repo_id)
            self.repo_chunks.pop(repo_id, None)
            self.hybrid_search.remove_repository(repo_id)
            repo = session.query(Repository).filter_by(id=repo_id).first()
            if repo:
                if repo_id in self.cancelled_repo_ids:
                    # Session ended mid-index: remove the row entirely.
                    session.delete(repo)
                else:
                    repo.status = "failed"
                    repo.error_message = str(exc)
                session.commit()
            try:
                # clone_info only exists if the clone phase was reached.
                if "clone_info" in locals():
                    self.repo_fetcher.cleanup_repository(clone_info["local_path"])
            except Exception:
                pass
            self.indexing_progress.pop(repo_id, None)
            # Cancellation is an expected exit, not an error worth re-raising.
            if isinstance(exc, SessionCancelledError):
                return
            raise
        finally:
            session.close()
293
+
294
    def list_repositories(self) -> List[dict]:
        """Unscoped listing is intentionally unsupported; use list_repositories_for_session."""
        raise NotImplementedError
296
+
297
+ def list_repositories_for_session(self, session_key: str) -> List[dict]:
298
+ session = get_db_session(self.database_url)
299
+ try:
300
+ self._cleanup_expired_sessions(session)
301
+ repos = (
302
+ session.query(Repository)
303
+ .filter_by(session_key=session_key)
304
+ .order_by(Repository.updated_at.desc())
305
+ .all()
306
+ )
307
+ self._touch_session(session, session_key)
308
+ return [self._serialize_repo(repo) for repo in repos]
309
+ finally:
310
+ session.close()
311
+
312
    def get_repository(self, repo_id: int) -> Optional[dict]:
        """Unscoped lookup is intentionally unsupported; use get_repository_for_session."""
        raise NotImplementedError
314
+
315
+ def get_repository_for_session(self, repo_id: int, session_key: str) -> Optional[dict]:
316
+ session = get_db_session(self.database_url)
317
+ try:
318
+ self._cleanup_expired_sessions(session)
319
+ repo = (
320
+ session.query(Repository)
321
+ .filter_by(id=repo_id, session_key=session_key)
322
+ .first()
323
+ )
324
+ self._touch_session(session, session_key)
325
+ return self._serialize_repo(repo) if repo else None
326
+ finally:
327
+ session.close()
328
+
329
    def answer_question(
        self,
        repo_id: int,
        session_key: str,
        question: str,
        top_k: int = 8,
        history: Optional[List[object]] = None,
    ) -> dict:
        """Answer a question about an indexed repository owned by this session.

        Runs hybrid retrieval (dense vector search + BM25, fused with reciprocal
        rank fusion, then reranked and intent-prioritized) over the repo's
        cached chunks, and asks the LLM to compose a grounded answer.

        Raises:
            ValueError: if the repo is missing for this session, not yet
                indexed, or its in-memory chunk cache is gone (e.g. after a
                process restart).
        """
        session = get_db_session(self.database_url)
        try:
            self._cleanup_expired_sessions(session)
            repo = (
                session.query(Repository)
                .filter_by(id=repo_id, session_key=session_key)
                .first()
            )
            if repo is None:
                raise ValueError("Repository not found")
            if repo.status != "indexed":
                raise ValueError("Repository is not ready for questions yet")
            if repo_id not in self.repo_chunks:
                raise ValueError("Session cache expired. Re-index the repository and try again.")
            self._touch_session(session, session_key)

            normalized_history = self._normalize_history(history or [])
            question_intent = self._question_intent(question)
            # Code-centric intents get a deeper candidate pool before pruning.
            search_depth = top_k * 4 if question_intent in {"api", "implementation", "cross_file", "setup"} else top_k * 2
            retrieval_query = self._build_retrieval_query(question, normalized_history)
            query_embedding = self.embedder.embed_text(retrieval_query)
            semantic_hits = []
            for score, meta in self.vector_store.search(query_embedding, k=search_depth, repo_filter=repo_id):
                serialized = dict(meta)
                serialized["semantic_score"] = score
                semantic_hits.append(serialized)

            lexical_hits = self.hybrid_search.bm25_search(
                self.repo_chunks[repo_id],
                retrieval_query,
                top_k=search_depth,
            )
            semantic_hits = self.hybrid_search.normalize_semantic_results(semantic_hits)
            fused = self.hybrid_search.reciprocal_rank_fusion(lexical_hits, semantic_hits, top_k=search_depth)
            # Rerank against the expanded query only for code-centric intents;
            # otherwise the raw question gives the reranker a cleaner signal.
            rerank_query = retrieval_query if question_intent in {"api", "implementation", "cross_file", "setup"} else question
            reranked = self.hybrid_search.rerank(rerank_query, fused, top_k=search_depth)
            reranked = self._prioritize_results(question, retrieval_query, reranked, top_k=top_k)
            reranked = self._select_answer_sources(question, reranked, top_k=top_k)

            answer = self._generate_answer(repo, question, reranked, normalized_history)

            return answer
        finally:
            session.close()
381
+
382
+ def end_session(self, session_key: str):
383
+ session = get_db_session(self.database_url)
384
+ try:
385
+ repos = session.query(Repository).filter_by(session_key=session_key).all()
386
+ self._delete_repositories(session, repos)
387
+ session.commit()
388
+ finally:
389
+ session.close()
390
+
391
    def _generate_answer(
        self,
        repo: Repository,
        question: str,
        sources: List[dict],
        history: Optional[List[dict]] = None,
    ) -> dict:
        """Compose the final LLM answer payload from retrieved sources.

        Builds a context window of up to 2500 chars per source, selects a
        system prompt variant based on question intent, and retries up to two
        times when the generation looks truncated (once asking for a rewrite,
        then once asking for a shorter answer). Returns answer text plus
        confidence, summary, citations, slim sources, and the serialized repo.
        """
        if not sources:
            return {
                "answer": "I could not find enough grounded evidence in the indexed codebase to answer that confidently.",
                "confidence": "low",
                "sources": [],
                "repo": self._serialize_repo(repo),
            }

        # context_blocks feed the prompt; slim_sources go back to the client.
        context_blocks = []
        slim_sources = []
        for index, source in enumerate(sources, start=1):
            context_blocks.append(
                "\n".join(
                    [
                        f"[Source {index}]",
                        f"File: {source['file_path']}",
                        f"Symbol: {source['symbol_name']}",
                        f"Lines: {source['line_start']}-{source['line_end']}",
                        source["content"][:2500],
                    ]
                )
            )
            slim_sources.append(
                {
                    "file_path": source["file_path"],
                    "language": source["language"],
                    "symbol_name": source["symbol_name"],
                    "symbol_type": source["symbol_type"],
                    "line_start": source["line_start"],
                    "line_end": source["line_end"],
                    "signature": source["signature"],
                    "snippet": source["content"],
                    "semantic_score": round(float(source.get("semantic_score", 0.0)), 4),
                    "bm25_score": round(float(source.get("bm25_score", 0.0)), 4),
                    "rrf_score": round(float(source.get("rrf_score", 0.0)), 4),
                    "rerank_score": round(float(source.get("rerank_score", 0.0)), 4),
                }
            )

        wants_repo_overview = self._is_repo_overview_question(question)
        question_intent = self._question_intent(question)

        system_prompt = """
You are answering questions as a knowledgeable teammate who has carefully read this repository.

Rules:
1. Use only the supplied repository context.
2. Answer conversationally and directly, as if the repo is explaining itself to the user.
3. Do not say "Based on the provided context", "The repository is about", or similar throat-clearing phrases.
4. Be concrete about files, functions, and behavior.
5. If evidence is partial, clearly separate what is certain from what is inferred.
6. Respond in Markdown, not JSON.
7. Keep the answer complete. Do not stop mid-sentence.
8. Use short sections or bullets only when they genuinely help readability.
9. Do not leave unfinished headings, dangling bullets, or trailing markdown markers like #, ##, or ###.
10. Do not include inline citation markers like [Source 1] in the prose. The UI already shows sources separately.
11. Do not make claims that are not directly supported by the supplied sources.
12. Prefer the most canonical source files for API and implementation questions, such as package exports, core modules, and session/query code, over tutorial prose when they disagree in specificity.
13. Keep the answer tight. Lead with the direct answer, then add only the most important supporting detail.
"""

        # Intent-specific rule extensions appended to the base prompt.
        if wants_repo_overview:
            system_prompt += """
14. For repository overview questions, lead with a direct one or two sentence summary of what the repo does.
15. Prioritize README and top-level documentation when they are present, then use code to support the explanation.
16. Mention the main workflow, core stack, and any important product constraints the user would care about.
17. Keep the answer polished and self-contained, like the overview a real user expects when they ask what a repo is about.
"""
        elif question_intent in {"api", "implementation", "cross_file", "error_handling", "setup"}:
            system_prompt += """
14. For API, implementation, setup, and cross-file questions, prefer the smallest correct answer that is directly supported by code.
15. If a detail comes only from docs or examples and not from the canonical implementation, say that clearly instead of presenting it as core behavior.
16. When describing exports or code paths, name the file first and keep the explanation precise.
17. Default to one short paragraph plus at most 3 short bullets. Avoid long explanatory walkthroughs unless the question explicitly asks for depth.
"""

        joined_context = "\n\n".join(context_blocks)
        user_prompt = f"""
Repository: {repo.owner}/{repo.name}
Question: {question}
Recent conversation:
{self._format_history(history or [])}

Context:
{joined_context}
"""

        answer_text, finish_reason = self._generate_markdown_response(system_prompt, user_prompt)
        # First retry: ask the model to rewrite its own truncated draft.
        if self._looks_incomplete(answer_text, finish_reason):
            repair_prompt = f"""
The draft answer below appears to be cut off or incomplete.
Rewrite it into a complete final answer using the same repository context and rules.

Draft answer:
{answer_text}
"""
            answer_text, finish_reason = self._generate_markdown_response(
                system_prompt,
                f"{user_prompt.strip()}\n\n{repair_prompt.strip()}",
            )
            # Second retry: fall back to a deliberately shorter answer.
            if self._looks_incomplete(answer_text, finish_reason):
                short_prompt = f"""
Answer the question again, but keep it concise and complete.
Use 2 short paragraphs or 4-6 bullets max.
Do not leave the answer unfinished.
"""
                answer_text, _ = self._generate_markdown_response(
                    system_prompt,
                    f"{user_prompt.strip()}\n\n{short_prompt.strip()}",
                )
        answer_text = self._finalize_answer(answer_text)
        confidence = self._estimate_confidence(sources)
        summary = " ".join(answer_text.split())[:160] if answer_text else ""
        # Cite at most the top 4 sources back to the client.
        citations = [
            {
                "source": index,
                "reason": f"Relevant context from {source['file_path']}",
            }
            for index, source in enumerate(sources[: min(len(sources), 4)], start=1)
        ]

        return {
            "answer": answer_text,
            "confidence": confidence,
            "summary": summary,
            "citations": citations,
            "sources": slim_sources,
            "repo": self._serialize_repo(repo),
        }
527
+
528
    def _configure_llm(self):
        """Initialize the chat LLM client/model from environment configuration.

        Two providers are supported, selected by ``LLM_PROVIDER``:
        "groq" uses the OpenAI-compatible client against Groq's endpoint;
        "vertex_ai" uses the google-genai client in Vertex mode.

        Raises:
            RuntimeError: when the google-genai package is missing, the
                required GOOGLE_CLOUD_PROJECT is unset, or the provider
                string is unrecognized.
        """
        if self.llm_provider == "groq":
            self.llm_client = OpenAI(
                api_key=os.getenv("GROQ_API_KEY"),
                base_url=os.getenv("GROQ_BASE_URL", "https://api.groq.com/openai/v1"),
            )
            self.llm_model = os.getenv("GROQ_MODEL", "llama-3.3-70b-versatile")
            return

        if self.llm_provider == "vertex_ai":
            # Imported lazily so Groq-only deployments do not need google-genai.
            try:
                from google import genai
            except ImportError as exc:
                raise RuntimeError(
                    "Vertex AI LLM support requires the `google-genai` package. "
                    "Install server dependencies before running local or eval queries."
                ) from exc

            project = os.getenv("GOOGLE_CLOUD_PROJECT")
            location = os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1")
            if not project:
                raise RuntimeError(
                    "GOOGLE_CLOUD_PROJECT must be set when using Vertex AI Gemini."
                )

            self.llm_client = genai.Client(
                vertexai=True,
                project=project,
                location=location,
            )
            self.llm_model = os.getenv("VERTEX_LLM_MODEL", "gemini-2.5-pro")
            return

        raise RuntimeError(f"Unsupported LLM provider: {self.llm_provider}")
562
+
563
    def _generate_markdown_response(self, system_prompt: str, user_prompt: str) -> tuple[str, str]:
        """Call the configured LLM and return (normalized_markdown, finish_reason).

        finish_reason is provider-specific (lower/upper case varies) and may be
        empty; callers use it to detect truncated generations.
        """
        if self.llm_provider == "groq":
            response = self.llm_client.chat.completions.create(
                model=self.llm_model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                temperature=0.1,
                max_tokens=1600,
            )
            content = response.choices[0].message.content
            finish_reason = getattr(response.choices[0], "finish_reason", "") or ""
            return self._normalize_markdown_answer(content), str(finish_reason)

        # Vertex AI Gemini path: one combined prompt, no separate system role.
        response = self.llm_client.models.generate_content(
            model=self.llm_model,
            contents=f"{system_prompt.strip()}\n\n{user_prompt.strip()}",
            config={
                "temperature": 0.1,
                "max_output_tokens": 2200,
            },
        )
        if not getattr(response, "text", None):
            raise RuntimeError("Vertex AI Gemini returned an empty response.")
        finish_reason = ""
        candidates = getattr(response, "candidates", None) or []
        if candidates:
            finish_reason = str(getattr(candidates[0], "finish_reason", "") or "")
        return self._normalize_markdown_answer(response.text), finish_reason
593
+
594
+ @staticmethod
595
+ def _normalize_markdown_answer(raw_text: str) -> str:
596
+ cleaned = (raw_text or "").strip()
597
+ cleaned = re.sub(r"^```(?:markdown|md)?\s*|\s*```$", "", cleaned, flags=re.IGNORECASE)
598
+ cleaned = re.sub(r"\s*\[(?:Source\s+\d+(?:\s*,\s*Source\s+\d+)*)\]", "", cleaned, flags=re.IGNORECASE)
599
+ cleaned = re.sub(
600
+ r"^(?:based on the provided context[,:\s-]*|from the provided context[,:\s-]*)",
601
+ "",
602
+ cleaned,
603
+ flags=re.IGNORECASE,
604
+ ).strip()
605
+ cleaned = re.sub(
606
+ r"\n(?:#{1,6}|[-*])\s*$",
607
+ "",
608
+ cleaned,
609
+ flags=re.MULTILINE,
610
+ ).strip()
611
+ cleaned = re.sub(r"(?:\n\s*){3,}", "\n\n", cleaned)
612
+ cleaned = cleaned.strip()
613
+ if not cleaned:
614
+ return "I found relevant code context, but the model returned an empty response."
615
+ return cleaned
616
+
617
+ @staticmethod
618
+ def _finalize_answer(answer_text: str) -> str:
619
+ cleaned = (answer_text or "").strip()
620
+ if not cleaned:
621
+ return "I found relevant code context, but the model returned an empty response."
622
+
623
+ # If the tail still looks truncated, trim back to the last complete sentence or list item
624
+ if CodebaseRAGSystem._looks_incomplete(cleaned):
625
+ sentence_match = re.search(r"(?s)^.*[.!?](?:['\"\)`\]]+)?", cleaned)
626
+ if sentence_match:
627
+ trimmed = sentence_match.group(0).strip()
628
+ if len(trimmed.split()) >= 12:
629
+ return trimmed
630
+
631
+ lines = cleaned.splitlines()
632
+ while lines and CodebaseRAGSystem._looks_incomplete(lines[-1]):
633
+ lines.pop()
634
+ candidate = "\n".join(line for line in lines if line.strip()).strip()
635
+ if candidate:
636
+ return candidate
637
+
638
+ return cleaned
639
+
640
+ @staticmethod
641
+ def _looks_incomplete(answer_text: str, finish_reason: str = "") -> bool:
642
+ cleaned = (answer_text or "").strip()
643
+ if not cleaned:
644
+ return True
645
+ finish_reason = (finish_reason or "").strip().lower()
646
+ if finish_reason and finish_reason not in {"stop", "stopsequence", "finish_reason_unspecified"}:
647
+ return True
648
+ if cleaned.endswith(("#", "-", "*", ":", "(", "[", "/", "`")):
649
+ return True
650
+ if cleaned.endswith(("[source", "[source 1", "[source 2", "[source 3", "[source 4")):
651
+ return True
652
+ if cleaned.count("```") % 2 != 0:
653
+ return True
654
+ if cleaned.count("(") > cleaned.count(")"):
655
+ return True
656
+ if cleaned.count("[") > cleaned.count("]"):
657
+ return True
658
+ tokens = re.findall(r"\b[\w'-]+\b", cleaned.lower())
659
+ if not tokens:
660
+ return True
661
+ if tokens[-1] in {"a", "an", "the", "to", "for", "with", "of", "in", "on", "from", "about"}:
662
+ return True
663
+ if len(tokens) >= 20 and cleaned[-1] not in {".", "!", "?", "\"", "'", "`"}:
664
+ return True
665
+ return False
666
+
667
+ @staticmethod
668
+ def _estimate_confidence(sources: List[dict]) -> str:
669
+ if not sources:
670
+ return "low"
671
+
672
+ top = sources[0]
673
+ rerank = float(top.get("rerank_score", 0.0))
674
+ semantic = float(top.get("semantic_score", 0.0))
675
+
676
+ if len(sources) >= 3 and (rerank >= 0.2 or semantic >= 0.75):
677
+ return "high"
678
+ if rerank >= 0.05 or semantic >= 0.45:
679
+ return "medium"
680
+ return "low"
681
+
682
+ def _serialize_repo(self, repo: Repository) -> dict:
683
+ payload = {
684
+ "id": repo.id,
685
+ "github_url": repo.source_url or repo.github_url,
686
+ "owner": repo.owner,
687
+ "name": repo.name,
688
+ "branch": repo.branch,
689
+ "local_path": repo.local_path,
690
+ "status": repo.status,
691
+ "error_message": repo.error_message,
692
+ "file_count": repo.file_count,
693
+ "chunk_count": repo.chunk_count,
694
+ "indexed_at": repo.indexed_at.isoformat() if repo.indexed_at else None,
695
+ "created_at": repo.created_at.isoformat() if repo.created_at else None,
696
+ "updated_at": repo.updated_at.isoformat() if repo.updated_at else None,
697
+ }
698
+ progress = self.indexing_progress.get(repo.id)
699
+ if progress:
700
+ payload["progress"] = progress
701
+ return payload
702
+
703
+ def _set_progress(self, repo_id: int, **progress):
704
+ self.indexing_progress[repo_id] = {
705
+ **self.indexing_progress.get(repo_id, {}),
706
+ **progress,
707
+ "updated_at": datetime.utcnow().isoformat(),
708
+ }
709
+
710
+ def _touch_session(self, session, session_key: str):
711
+ expiry = self._session_expiry()
712
+ repos = session.query(Repository).filter_by(session_key=session_key).all()
713
+ for repo in repos:
714
+ repo.session_expires_at = expiry
715
+ session.commit()
716
+
717
+ def _cleanup_expired_sessions(self, session):
718
+ now = datetime.utcnow()
719
+ expired = (
720
+ session.query(Repository)
721
+ .filter(Repository.session_expires_at.is_not(None))
722
+ .filter(Repository.session_expires_at < now)
723
+ .all()
724
+ )
725
+ if not expired:
726
+ return
727
+ self._delete_repositories(session, expired)
728
+ session.commit()
729
+
730
+ def _delete_repositories(
731
+ self,
732
+ session,
733
+ repos: List[Repository],
734
+ track_cancellation: bool = True,
735
+ ):
736
+ repo_ids = [repo.id for repo in repos]
737
+ for repo_id in repo_ids:
738
+ if track_cancellation:
739
+ self.cancelled_repo_ids.add(repo_id)
740
+ self.hybrid_search.remove_repository(repo_id)
741
+ self.vector_store.remove_repository(repo_id)
742
+ self.repo_chunks.pop(repo_id, None)
743
+ self.indexing_progress.pop(repo_id, None)
744
+ for repo in repos:
745
+ session.delete(repo)
746
+
747
+ def _ensure_repo_not_cancelled(self, repo_id: int):
748
+ if repo_id in self.cancelled_repo_ids:
749
+ raise SessionCancelledError("Session ended before indexing completed.")
750
+
751
    def _build_retrieval_query(self, question: str, history: List[dict]) -> str:
        """Build the retrieval query text, folding in conversation context.

        Overview questions get fixed overview keywords appended. Otherwise the
        question may be expanded with intent hints, and short/code-seeking
        questions are treated as follow-ups: the most recent user turn and a
        prefix of the last substantive assistant answer are appended so
        retrieval keeps the conversational topic.
        """
        normalized = " ".join(question.strip().split())
        if self._is_repo_overview_question(normalized):
            return "\n".join(
                [
                    normalized,
                    "repository overview purpose main workflow architecture README features stack",
                ]
            )
        if not history:
            return normalized

        # Most-recent-first lists of non-empty user and assistant turns.
        recent_user = [
            turn["content"].strip()
            for turn in reversed(history)
            if turn.get("role") == "user" and turn.get("content", "").strip()
        ]
        recent_assistant = [
            turn["content"].strip()
            for turn in reversed(history)
            if turn.get("role") == "assistant" and turn.get("content", "").strip()
            and self._is_substantive_assistant_message(turn.get("content", ""))
        ]

        # Heuristic: short questions, interrogative openers, or explicit
        # code/snippet requests are treated as follow-ups to the last turn.
        is_follow_up = (
            len(normalized.split()) <= 6
            or bool(re.fullmatch(r"(give|show|where|which|how|what)(?:\s+.+)?", normalized.lower()))
            or any(token in normalized.lower() for token in {"code", "snippet", "implementation"})
        )
        if not is_follow_up or not recent_user:
            return self._expand_query_for_intent(normalized)

        parts = [self._expand_query_for_intent(normalized)]
        if recent_user:
            parts.append(f"Follow-up to: {recent_user[0]}")
        if recent_assistant:
            # Cap the previous answer to 300 chars to bound query length.
            parts.append(f"Previous answer: {recent_assistant[0][:300]}")
        return "\n".join(parts)
789
+
790
    def _prioritize_results(
        self,
        question: str,
        retrieval_query: str,
        results: List[dict],
        top_k: int,
    ) -> List[dict]:
        """Reorder fused retrieval results so intent-matching sources lead.

        Sorts by a tuple of intent-driven boosts (canonical file paths, doc
        priority, docs-vs-code preference) ahead of the raw retrieval scores.
        For code questions, code chunks fill the list first and doc chunks are
        limited — but at least one doc chunk is kept as backup context.
        """
        combined_query = f"{question} {retrieval_query}".lower()
        wants_code = any(
            token in combined_query
            for token in {"code", "snippet", "implementation", "function", "class", "import"}
        )
        wants_docs = self._is_documentation_query(combined_query)
        wants_repo_overview = self._is_repo_overview_question(question) or self._is_repo_overview_question(
            retrieval_query
        )
        question_intent = self._question_intent(question)

        def sort_key(item: dict):
            # Higher tuples sort first (reverse=True); boost components
            # deliberately dominate the raw score components at the tail.
            is_doc = self._is_doc_source(item)
            return (
                self._canonical_path_priority(item, question),
                self._doc_priority(item),
                1 if wants_repo_overview and is_doc else 0,
                1 if (wants_docs and is_doc) or (not wants_docs and not is_doc) else 0,
                1 if wants_code and not is_doc else 0,
                1 if question_intent in {"api", "implementation", "cross_file", "error_handling", "setup"} and not is_doc else 0,
                float(item.get("rerank_score", 0.0)),
                float(item.get("semantic_score", 0.0)),
                float(item.get("bm25_score", 0.0)),
            )

        ranked = sorted(results, key=sort_key, reverse=True)
        if wants_docs or wants_repo_overview:
            return ranked[:top_k]

        # Code questions: code chunks first, then at most a few doc chunks.
        selected = []
        doc_items = []
        for item in ranked:
            if self._is_doc_source(item):
                doc_items.append(item)
                continue
            selected.append(item)
            if len(selected) == top_k:
                return selected

        selected.extend(doc_items[: max(1, top_k - len(selected))])
        return selected[:top_k]
838
+
839
+ def _select_answer_sources(
840
+ self,
841
+ question: str,
842
+ results: List[dict],
843
+ top_k: int,
844
+ ) -> List[dict]:
845
+ if not results:
846
+ return []
847
+
848
+ intent = self._question_intent(question)
849
+ max_per_file = 2 if intent in {"overview", "docs"} else 1
850
+ selected = []
851
+ file_counts = {}
852
+
853
+ for item in results:
854
+ file_path = item.get("file_path", "")
855
+ count = file_counts.get(file_path, 0)
856
+ if count >= max_per_file:
857
+ continue
858
+ selected.append(item)
859
+ file_counts[file_path] = count + 1
860
+ if len(selected) == top_k:
861
+ break
862
+
863
+ if len(selected) < top_k:
864
+ for item in results:
865
+ if item in selected:
866
+ continue
867
+ selected.append(item)
868
+ if len(selected) == top_k:
869
+ break
870
+
871
+ return selected
872
+
873
+ @staticmethod
874
+ def _is_documentation_query(query: str) -> bool:
875
+ return any(
876
+ token in query
877
+ for token in {
878
+ "readme",
879
+ "docs",
880
+ "documentation",
881
+ "setup",
882
+ "install",
883
+ "installation",
884
+ "usage",
885
+ "overview",
886
+ "what is this repo",
887
+ "what is the repository about",
888
+ "what is the repo about",
889
+ "what does the repo do",
890
+ "what does this repo do",
891
+ "repo summary",
892
+ "repository summary",
893
+ "project summary",
894
+ "feature",
895
+ "features",
896
+ "architecture",
897
+ }
898
+ )
899
+
900
+ @staticmethod
901
+ def _question_intent(question: str) -> str:
902
+ normalized = " ".join((question or "").lower().split())
903
+ if not normalized:
904
+ return "general"
905
+ if CodebaseRAGSystem._is_repo_overview_question(normalized):
906
+ return "overview"
907
+ if any(token in normalized for token in {"error", "invalid", "conflict", "raises", "guard against"}):
908
+ return "error_handling"
909
+ if any(token in normalized for token in {"how are", "how does", "flow", "across files", "code path"}):
910
+ return "cross_file"
911
+ if any(token in normalized for token in {"export", "expose", "import", "public api"}):
912
+ return "api"
913
+ if any(token in normalized for token in {"create", "setup", "install", "configuration", "metadata", "table"}):
914
+ return "setup"
915
+ if any(token in normalized for token in {"function", "method", "class", "implementation", "does ", "what is special"}):
916
+ return "implementation"
917
+ if CodebaseRAGSystem._is_documentation_query(normalized):
918
+ return "docs"
919
+ return "general"
920
+
921
+ def _expand_query_for_intent(self, question: str) -> str:
922
+ normalized = " ".join((question or "").split())
923
+ lowered = normalized.lower()
924
+ hints = []
925
+
926
+ if any(token in lowered for token in {"export", "expose", "import"}):
927
+ hints.extend(["package exports", "__init__.py", "public api", "re-export"])
928
+ if "how is select exposed to users in sqlmodel" in lowered:
929
+ hints.extend(
930
+ [
931
+ "sqlmodel/__init__.py",
932
+ "sqlmodel/sql/expression.py",
933
+ "select re-export",
934
+ "top-level select import",
935
+ ]
936
+ )
937
+ if "select" in lowered:
938
+ hints.extend(
939
+ [
940
+ "select",
941
+ "expression",
942
+ "query builder",
943
+ "public api",
944
+ "sqlmodel/sql/expression.py",
945
+ "sqlmodel/__init__.py",
946
+ "re-export",
947
+ "top-level import",
948
+ ]
949
+ )
950
+ if "session.exec" in lowered or ("session" in lowered and "exec" in lowered):
951
+ hints.extend(["session exec", "orm/session.py", "asyncio/session.py"])
952
+ if "relationship" in lowered:
953
+ hints.extend(["relationship", "Relationship", "main.py"])
954
+ if "field" in lowered:
955
+ hints.extend(["Field", "FieldInfo", "main.py"])
956
+ if "create_engine" in lowered:
957
+ hints.extend(["create_engine", "__init__.py", "re-export"])
958
+ if "create_all" in lowered or "metadata" in lowered:
959
+ hints.extend(
960
+ [
961
+ "metadata create_all",
962
+ "table creation",
963
+ "engine",
964
+ "SQLModel.metadata",
965
+ "README.md",
966
+ "sqlmodel/main.py",
967
+ "docs_src",
968
+ ]
969
+ )
970
+ if "__init__" in lowered or "exports" in lowered:
971
+ hints.extend(["sqlmodel/__init__.py", "package exports", "public api"])
972
+
973
+ if not hints:
974
+ return normalized
975
+ return "\n".join([normalized, " ".join(hints)])
976
+
977
+ @staticmethod
978
+ def _is_repo_overview_question(question: str) -> bool:
979
+ normalized = " ".join((question or "").lower().split())
980
+ return any(
981
+ phrase in normalized
982
+ for phrase in {
983
+ "what is the repo about",
984
+ "what is this repo about",
985
+ "what does the repo do",
986
+ "what does this repo do",
987
+ "what is the repository about",
988
+ "what does the repository do",
989
+ "what is this project about",
990
+ "what does this project do",
991
+ "repo summary",
992
+ "repository summary",
993
+ "project summary",
994
+ "summarize the repo",
995
+ "summarize this repo",
996
+ "repo overview",
997
+ "repository overview",
998
+ "project overview",
999
+ }
1000
+ )
1001
+
1002
+ @staticmethod
1003
+ def _is_doc_source(item: dict) -> bool:
1004
+ file_path = (item.get("file_path") or "").lower()
1005
+ language = (item.get("language") or "").lower()
1006
+ return language == "text" or file_path.endswith(".md") or "/readme" in file_path
1007
+
1008
+ @staticmethod
1009
+ def _doc_priority(item: dict) -> int:
1010
+ file_path = (item.get("file_path") or "").lower()
1011
+ if file_path in {"readme.md", "readme"}:
1012
+ return 3
1013
+ if file_path.startswith("docs/") or "/docs/" in file_path:
1014
+ return 2
1015
+ if file_path.endswith(".md"):
1016
+ return 1
1017
+ return 0
1018
+
1019
+ def _canonical_path_priority(self, item: dict, question: str) -> int:
1020
+ file_path = (item.get("file_path") or "").lower()
1021
+ normalized = " ".join((question or "").lower().split())
1022
+ score = 0
1023
+
1024
+ if file_path == "sqlmodel/__init__.py":
1025
+ score += 4 if any(token in normalized for token in {"export", "expose", "import", "create_engine", "select"}) else 0
1026
+ if file_path == "sqlmodel/sql/expression.py":
1027
+ score += 5 if "select" in normalized else 0
1028
+ if file_path == "sqlmodel/sql/_expression_select_gen.py":
1029
+ score += 2 if "select" in normalized else 0
1030
+ if file_path == "sqlmodel/sql/_expression_select_cls.py":
1031
+ score += 2 if "select" in normalized else 0
1032
+ if file_path == "readme.md":
1033
+ score += 4 if any(token in normalized for token in {"metadata", "create_all", "workflow", "readme"}) else 0
1034
+ if file_path.startswith("docs_src/"):
1035
+ score += 3 if any(token in normalized for token in {"metadata", "create_all", "table", "workflow"}) else 0
1036
+ if file_path == "sqlmodel/main.py":
1037
+ score += 3 if any(token in normalized for token in {"field", "relationship", "metadata", "table", "sqlmodel"}) else 0
1038
+
1039
+ if "__init__.py" in file_path:
1040
+ score += 2 if any(token in normalized for token in {"export", "expose", "import", "public api"}) else 0
1041
+ if any(token in normalized for token in {"select", "expression"}):
1042
+ if "expression" in file_path or "_expression_select" in file_path:
1043
+ score += 3
1044
+ if normalized == "how is select exposed to users in sqlmodel?":
1045
+ if file_path == "sqlmodel/__init__.py":
1046
+ score += 6
1047
+ if file_path == "sqlmodel/sql/expression.py":
1048
+ score += 6
1049
+ if "session" in normalized:
1050
+ if file_path.endswith("session.py") or "/session.py" in file_path:
1051
+ score += 3
1052
+ if "relationship" in normalized and file_path.endswith("main.py"):
1053
+ score += 2
1054
+ if "field" in normalized and file_path.endswith("main.py"):
1055
+ score += 2
1056
+ if any(token in normalized for token in {"create_engine", "export", "expose"}) and "__init__.py" in file_path:
1057
+ score += 2
1058
+ if any(token in normalized for token in {"metadata", "create_all", "table"}) and (
1059
+ "docs_src/" in file_path or file_path.endswith("main.py") or file_path == "readme.md"
1060
+ ):
1061
+ score += 2
1062
+ if self._is_doc_source(item) and self._question_intent(question) in {
1063
+ "api",
1064
+ "implementation",
1065
+ "cross_file",
1066
+ "error_handling",
1067
+ "setup",
1068
+ }:
1069
+ score -= 1
1070
+
1071
+ return score
1072
+
1073
+ @staticmethod
1074
+ def _is_substantive_assistant_message(content: str) -> bool:
1075
+ normalized = " ".join((content or "").strip().lower().split())
1076
+ if len(normalized) < 24:
1077
+ return False
1078
+ if normalized in {
1079
+ "hey, what question do you have for me today?",
1080
+ "ask a question",
1081
+ }:
1082
+ return False
1083
+ return True
1084
+
1085
+ @staticmethod
1086
+ def _normalize_history(history: List[object]) -> List[dict]:
1087
+ normalized = []
1088
+ for turn in history:
1089
+ if isinstance(turn, dict):
1090
+ role = turn.get("role")
1091
+ content = turn.get("content")
1092
+ else:
1093
+ role = getattr(turn, "role", None)
1094
+ content = getattr(turn, "content", None)
1095
+
1096
+ if not role or not content:
1097
+ continue
1098
+
1099
+ normalized.append(
1100
+ {
1101
+ "role": str(role),
1102
+ "content": str(content),
1103
+ }
1104
+ )
1105
+ return normalized
1106
+
1107
+ @staticmethod
1108
+ def _format_history(history: List[dict]) -> str:
1109
+ if not history:
1110
+ return "None"
1111
+ lines = []
1112
+ for turn in history[-4:]:
1113
+ role = turn.get("role", "user").capitalize()
1114
+ content = " ".join(turn.get("content", "").split())
1115
+ if content:
1116
+ lines.append(f"{role}: {content[:400]}")
1117
+ return "\n".join(lines) if lines else "None"
1118
+
1119
+ @staticmethod
1120
+ def _ensure_repo_still_exists(session, repo_id: int):
1121
+ if session.query(Repository.id).filter_by(id=repo_id).first() is None:
1122
+ raise RuntimeError("Repository was removed before indexing completed.")
1123
+
1124
+ def _session_expiry(self) -> datetime:
1125
+ return datetime.utcnow() + timedelta(minutes=self.session_ttl_minutes)
1126
+
1127
+ @staticmethod
1128
+ def _build_registry_key(session_key: str, github_url: str) -> str:
1129
+ return f"{session_key}::{github_url}"
1130
+
1131
+ @staticmethod
1132
+ def _serialize_chunk(chunk: dict) -> dict:
1133
+ return {
1134
+ "id": chunk["id"],
1135
+ "file_path": chunk["file_path"],
1136
+ "language": chunk["language"],
1137
+ "symbol_name": chunk["symbol_name"],
1138
+ "symbol_type": chunk["symbol_type"],
1139
+ "line_start": chunk["line_start"],
1140
+ "line_end": chunk["line_end"],
1141
+ "signature": chunk["signature"],
1142
+ "content": chunk["content"],
1143
+ "searchable_text": chunk["searchable_text"],
1144
+ "metadata_json": chunk.get("metadata_json") or {},
1145
+ }
src/repo_fetcher.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import shutil
4
+ import subprocess
5
+ import tempfile
6
+ from pathlib import Path
7
+ from urllib.parse import urlparse
8
+
9
+
10
# File extensions considered indexable; iter_source_files skips everything else.
SUPPORTED_EXTENSIONS = {
    ".py",
    ".js",
    ".jsx",
    ".ts",
    ".tsx",
    ".java",
    ".go",
    ".rs",
    ".md",
    ".json",
    ".yml",
    ".yaml",
    ".toml",
    ".sh",
    ".css",
    ".html",
}

# Generated lockfiles: large, noisy, and useless for retrieval.
IGNORED_FILENAMES = {
    "package-lock.json",
    "yarn.lock",
    "pnpm-lock.yaml",
    "bun.lockb",
}

# Directories holding VCS internals, build output, or vendored dependencies.
IGNORED_DIRS = {
    ".git",
    ".next",
    ".turbo",
    "dist",
    "build",
    "coverage",
    "node_modules",
    "vendor",
    ".venv",
    "venv",
    "__pycache__",
}

# Per-file size cap in bytes; larger files are skipped (presumably generated
# or otherwise too big to chunk usefully — see iter_source_files).
MAX_FILE_SIZE_BYTES = 250_000
51
+
52
+
53
+ class RepoFetcher:
54
+ def __init__(self, base_dir: str = None):
55
+ repo_cache_dir = base_dir or os.getenv(
56
+ "REPO_CACHE_DIR",
57
+ str(Path(tempfile.gettempdir()) / "codecompass-repos"),
58
+ )
59
+ self.base_dir = Path(repo_cache_dir)
60
+ self.base_dir.mkdir(parents=True, exist_ok=True)
61
+
62
+ def parse_github_url(self, github_url: str) -> dict:
63
+ parsed = urlparse(github_url)
64
+ path = parsed.path.rstrip("/")
65
+ if parsed.netloc not in {"github.com", "www.github.com"}:
66
+ raise ValueError("Only github.com URLs are supported")
67
+
68
+ parts = [part for part in path.split("/") if part]
69
+ if len(parts) < 2:
70
+ raise ValueError("GitHub URL must include owner and repository name")
71
+
72
+ owner = parts[0]
73
+ repo = parts[1].removesuffix(".git")
74
+ branch = "main"
75
+
76
+ if len(parts) >= 4 and parts[2] in {"tree", "blob"}:
77
+ branch = parts[3]
78
+
79
+ slug = re.sub(r"[^a-zA-Z0-9_.-]+", "-", f"{owner}-{repo}")
80
+ repo_url = f"https://github.com/{owner}/{repo}"
81
+ return {
82
+ "owner": owner,
83
+ "repo": repo,
84
+ "branch": branch,
85
+ "slug": slug,
86
+ "repo_url": repo_url,
87
+ }
88
+
89
+ def clone_repository(self, github_url: str) -> dict:
90
+ info = self.parse_github_url(github_url)
91
+ target_dir = self.base_dir / info["slug"]
92
+
93
+ if target_dir.exists():
94
+ shutil.rmtree(target_dir)
95
+
96
+ clone_cmd = [
97
+ "git",
98
+ "clone",
99
+ "--depth",
100
+ "1",
101
+ "--branch",
102
+ info["branch"],
103
+ github_url,
104
+ str(target_dir),
105
+ ]
106
+
107
+ clone_cmd[6] = info["repo_url"]
108
+
109
+ result = subprocess.run(clone_cmd, capture_output=True, text=True)
110
+ if result.returncode != 0 and info["branch"] != "main":
111
+ info["branch"] = "main"
112
+ clone_cmd[5] = "main"
113
+ result = subprocess.run(clone_cmd, capture_output=True, text=True)
114
+
115
+ if result.returncode != 0:
116
+ default_branch = self._resolve_default_branch(info["repo_url"])
117
+ if default_branch and default_branch != info["branch"]:
118
+ info["branch"] = default_branch
119
+ clone_cmd[5] = default_branch
120
+ result = subprocess.run(clone_cmd, capture_output=True, text=True)
121
+
122
+ if result.returncode != 0:
123
+ raise RuntimeError(result.stderr.strip() or "Failed to clone repository")
124
+
125
+ return {
126
+ **info,
127
+ "local_path": str(target_dir),
128
+ }
129
+
130
+ def _resolve_default_branch(self, github_url: str) -> str | None:
131
+ result = subprocess.run(
132
+ ["git", "ls-remote", "--symref", github_url, "HEAD"],
133
+ capture_output=True,
134
+ text=True,
135
+ )
136
+ if result.returncode != 0:
137
+ return None
138
+
139
+ for line in result.stdout.splitlines():
140
+ if line.startswith("ref: ") and "\tHEAD" in line:
141
+ ref = line.split("\t", 1)[0].removeprefix("ref: ").strip()
142
+ if ref.startswith("refs/heads/"):
143
+ return ref.removeprefix("refs/heads/")
144
+ return None
145
+
146
+ def cleanup_repository(self, repo_path: str):
147
+ target = Path(repo_path)
148
+ if target.exists():
149
+ shutil.rmtree(target)
150
+
151
+ def iter_source_files(self, repo_path: str):
152
+ root = Path(repo_path)
153
+ for file_path in root.rglob("*"):
154
+ if not file_path.is_file():
155
+ continue
156
+ if any(part in IGNORED_DIRS for part in file_path.parts):
157
+ continue
158
+ if file_path.name in IGNORED_FILENAMES:
159
+ continue
160
+ if file_path.suffix.lower() not in SUPPORTED_EXTENSIONS:
161
+ continue
162
+ if file_path.stat().st_size > MAX_FILE_SIZE_BYTES:
163
+ continue
164
+ yield file_path
src/vector_store.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, Optional, Tuple
3
+ from uuid import uuid4
4
+
5
+ import numpy as np
6
+ from qdrant_client import QdrantClient, models
7
+
8
+
9
class QdrantVectorStore:
    """Vector-store adapter backed by Qdrant.

    Connects to a remote server when QDRANT_URL is set, otherwise falls back
    to an in-memory client.
    """

    def __init__(self, embedding_dim: int, index_path: str = None, persist: bool = False):
        """Initialize the client and make sure the collection exists.

        index_path and persist are accepted for adapter-interface parity only;
        Qdrant manages its own persistence.
        """
        self.embedding_dim = embedding_dim
        self.collection_name = os.getenv("QDRANT_COLLECTION", "repo_qa_chunks")
        self.upsert_batch_size = max(1, int(os.getenv("QDRANT_UPSERT_BATCH_SIZE", "64")))
        self.client = self._create_client()
        self._ensure_collection()

    def _create_client(self):
        """Build a remote QdrantClient when QDRANT_URL is set, else in-memory."""
        remote_url = os.getenv("QDRANT_URL")
        api_key = os.getenv("QDRANT_API_KEY")
        timeout_seconds = int(os.getenv("QDRANT_TIMEOUT_SECONDS", "120"))
        if not remote_url:
            return QdrantClient(":memory:")
        return QdrantClient(
            url=remote_url,
            api_key=api_key,
            timeout=timeout_seconds,
            check_compatibility=False,
        )

    def _ensure_collection(self):
        """Create the cosine-distance collection on first use, then the payload indexes."""
        if not self.client.collection_exists(self.collection_name):
            vector_params = models.VectorParams(
                size=self.embedding_dim,
                distance=models.Distance.COSINE,
            )
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=vector_params,
            )
        self._ensure_payload_indexes()

    def _ensure_payload_indexes(self):
        """Index repository_id so per-repo filtering and deletion stay fast."""
        self.client.create_payload_index(
            collection_name=self.collection_name,
            field_name="repository_id",
            field_schema=models.PayloadSchemaType.INTEGER,
            wait=True,
        )

    def add_embeddings(self, embeddings: np.ndarray, metadata: List[dict]) -> List[int]:
        """Upsert one point per metadata row; return the generated point ids.

        Each payload is the metadata dict plus an "id" key mirroring the
        point id. Upserts are sent in batches of upsert_batch_size.
        """
        if embeddings.size == 0:
            return []

        matrix = embeddings.astype("float32")
        if matrix.ndim == 1:
            matrix = matrix.reshape(1, -1)

        generated_ids = [uuid4().hex for _ in metadata]
        points = [
            models.PointStruct(
                id=point_id,
                vector=vector.tolist(),
                payload={**meta, "id": point_id},
            )
            for point_id, meta, vector in zip(generated_ids, metadata, matrix)
        ]

        total_points = len(points)
        batch_size = self.upsert_batch_size
        total_batches = (total_points + batch_size - 1) // batch_size
        for batch_number in range(1, total_batches + 1):
            start = (batch_number - 1) * batch_size
            batch = points[start : start + batch_size]
            print(
                f"[qdrant] Upserting batch {batch_number}/{total_batches} "
                f"points={len(batch)} progress={start}/{total_points}",
                flush=True,
            )
            self.client.upsert(
                collection_name=self.collection_name,
                wait=True,
                points=batch,
            )

        return generated_ids

    def search(
        self,
        query_embedding: np.ndarray,
        k: int = 10,
        repo_filter: Optional[int] = None,
    ) -> List[Tuple[float, dict]]:
        """Return up to k (score, payload) pairs for the query embedding.

        When repo_filter is given, results are restricted to points whose
        payload repository_id matches it.
        """
        vector = query_embedding
        if vector.ndim == 1:
            vector = vector.reshape(1, -1)
        vector = vector.astype("float32")

        conditions = []
        if repo_filter is not None:
            conditions.append(
                models.FieldCondition(
                    key="repository_id",
                    match=models.MatchValue(value=repo_filter),
                )
            )
        query_filter = models.Filter(must=conditions) if conditions else None

        hits = self.client.search(
            collection_name=self.collection_name,
            query_vector=vector[0].tolist(),
            query_filter=query_filter,
            limit=k,
        )

        return [(float(hit.score), dict(hit.payload or {})) for hit in hits]

    def remove_repository(self, repo_id: int):
        """Delete every point whose payload repository_id equals repo_id."""
        repo_condition = models.FieldCondition(
            key="repository_id",
            match=models.MatchValue(value=repo_id),
        )
        self.client.delete(
            collection_name=self.collection_name,
            wait=True,
            points_selector=models.FilterSelector(
                filter=models.Filter(must=[repo_condition])
            ),
        )

    def clear(self):
        """Drop and recreate the collection, leaving it empty."""
        if self.client.collection_exists(self.collection_name):
            self.client.delete_collection(self.collection_name)
        self._ensure_collection()

    def save(self):
        """No-op: Qdrant persists server-side (kept for adapter parity)."""
        return None

    def load(self):
        """Ensure the collection exists (kept for adapter parity)."""
        self._ensure_collection()

    def get_stats(self) -> dict:
        """Return vector count, embedding dimension, and collection name."""
        info = self.client.get_collection(self.collection_name)
        return {
            "total_vectors": info.points_count or 0,
            "embedding_dim": self.embedding_dim,
            "collection_name": self.collection_name,
        }