CodeMode Agent committed on
Commit
48ca3cd
·
1 Parent(s): fb9394b

Deploy CodeMode via Agent

Browse files
Files changed (1) hide show
  1. app.py +415 -240
app.py CHANGED
@@ -10,114 +10,171 @@ from pathlib import Path
10
  import chromadb
11
  from chromadb.config import Settings
12
  import uuid
 
13
 
14
- # --- Add scripts to path so we can import ingestion modules ---
15
- # --- Add scripts to path so we can import ingestion modules ---
16
- sys.path.append(os.path.dirname(__file__))
17
  from scripts.core.ingestion.ingest import GitCrawler
18
  from scripts.core.ingestion.chunk import RepoChunker
19
 
20
  # --- Configuration ---
21
- MODEL_NAME = "shubharuidas/codebert-base-code-embed-mrl-langchain-langgraph"
 
22
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
23
- DB_DIR = Path("data/chroma_db")
24
  DB_DIR.mkdir(parents=True, exist_ok=True)
25
 
26
- print(f"Loading model: {MODEL_NAME} on {DEVICE}...")
27
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
28
- model = AutoModel.from_pretrained(MODEL_NAME)
29
- model.to(DEVICE)
30
- model.eval()
31
- print("Model loaded!")
32
 
33
- # --- Vector Database Setup ---
34
- # Initialize ChromaDB Client (Persistent)
 
 
 
 
 
 
35
  chroma_client = chromadb.PersistentClient(path=str(DB_DIR))
 
 
36
 
37
- # Create or Get Collection
38
- # We use cosine similarity space
39
- collection = chroma_client.get_or_create_collection(name="codemode_rag", metadata={"hnsw:space": "cosine"})
 
 
 
 
 
40
 
41
- # --- Helper Functions ---
42
- def compute_embeddings(text_list):
43
- """Batch compute embeddings"""
44
  if not text_list: return None
45
- # Truncate to 512 tokens to avoid errors
46
- inputs = tokenizer(text_list, return_tensors="pt", padding=True, truncation=True, max_length=512).to(DEVICE)
47
  with torch.no_grad():
48
- out = model(**inputs)
49
  emb = out.last_hidden_state.mean(dim=1)
50
  return F.normalize(emb, p=2, dim=1)
51
 
52
- def reset_db():
53
- """Clear database"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  try:
55
- chroma_client.delete_collection("codemode_rag")
56
- chroma_client.get_or_create_collection(name="codemode_rag", metadata={"hnsw:space": "cosine"})
57
- return "Database reset (All embeddings deleted)."
 
 
 
 
 
 
 
 
58
  except Exception as e:
59
- return f"Error resetting DB: {e}"
60
 
61
- def search_codebase(query, top_k=5):
62
- """Semantic Search using ChromaDB"""
63
- if collection.count() == 0: return []
 
64
 
65
- query_emb = compute_embeddings([query])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  if query_emb is None: return []
67
-
68
- # Convert tensor to list for Chroma
69
  query_vec = query_emb.cpu().numpy().tolist()[0]
70
-
71
- results = collection.query(
72
- query_embeddings=[query_vec],
73
- n_results=min(top_k, collection.count()),
74
- include=["metadatas", "documents", "distances"]
75
- )
76
-
77
- # Parse items
78
  output = []
79
  if results['ids']:
80
  for i in range(len(results['ids'][0])):
81
  meta = results['metadatas'][0][i]
82
  code = results['documents'][0][i]
83
  dist = results['distances'][0][i]
84
- score = 1 - dist # Cosine distance to similarity
85
-
86
- link_icon = "[Link]" if score > 0.7 else ""
87
- output.append([meta.get("file_name", "unknown"), f"{score:.4f} {link_icon}", code[:300] + "..."])
88
-
89
  return output
90
 
91
- def fn_ingest(repo_url):
92
- """
93
- 1. Clone Repo
94
- 2. Chunk Files
95
- 3. Compute Embeddings (Batched)
96
- 4. Store in ChromaDB
97
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  if not repo_url.startswith("http"):
99
- return "Invalid URL"
 
100
 
101
  DATA_DIR = Path(os.path.abspath("data/raw_ingest"))
102
  import stat
103
  def remove_readonly(func, path, _):
104
  os.chmod(path, stat.S_IWRITE)
105
  func(path)
106
-
107
  try:
108
- # Clean up old raw data
109
  if DATA_DIR.exists():
110
  shutil.rmtree(DATA_DIR, onerror=remove_readonly)
111
 
112
- # 1. Clone
113
  yield f"Cloning {repo_url}..."
114
  crawler = GitCrawler(cache_dir=DATA_DIR)
115
  repo_path = crawler.clone_repository(repo_url)
116
-
117
  if not repo_path:
118
- return "Failed to clone repository."
119
-
120
- # 2. Chunk
121
  yield "Listing files..."
122
  files = crawler.list_files(repo_path, extensions={'.py', '.md', '.json', '.js', '.ts', '.java', '.cpp'})
123
  if isinstance(files, tuple): files = [f.path for f in files[0]]
@@ -136,65 +193,124 @@ def fn_ingest(repo_url):
136
  all_chunks.extend(file_chunks)
137
  except Exception as e:
138
  print(f"Skipping {file_path}: {e}")
139
-
140
  if not all_chunks:
141
- return "No valid chunks found."
142
-
143
- # 3. Indexing Loop (Batched)
144
  total_chunks = len(all_chunks)
145
- yield f"Generated {total_chunks} chunks. Embedding & Indexing into ChromaDB..."
146
 
147
  batch_size = 64
 
148
  for i in range(0, total_chunks, batch_size):
149
  batch = all_chunks[i:i+batch_size]
150
-
151
- # Prepare data
152
  texts = [c.code for c in batch]
153
  ids = [str(uuid.uuid4()) for _ in batch]
154
  metadatas = [{"file_name": Path(c.file_path).name, "url": repo_url} for c in batch]
155
 
156
- # Compute Embeddings
157
- embeddings = compute_embeddings(texts)
158
  if embeddings is not None:
159
- # Add to Chroma
160
- collection.add(
161
- ids=ids,
162
- embeddings=embeddings.cpu().numpy().tolist(),
163
- metadatas=metadatas,
164
- documents=texts
165
- )
 
 
 
166
 
167
- progress = int((i / total_chunks) * 100)
168
- yield f"Indexed {min(i+batch_size, total_chunks)}/{total_chunks} ({progress}%)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
- count = collection.count()
171
- yield f"Success! Database now has {count} code chunks. Ready for search."
172
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  except Exception as e:
174
  import traceback
175
  traceback.print_exc()
176
  yield f"Error: {str(e)}"
177
 
178
- # --- Analysis Functions ---
179
- def fn_analyze_embeddings():
180
- count = collection.count()
181
  if count < 5:
182
  return "Not enough data (Need > 5 chunks).", None
183
 
184
  try:
185
- # Fetch all embeddings (Limit to 2000 for visualization speed)
186
  limit = min(count, 2000)
187
- data = collection.get(limit=limit, include=["embeddings", "metadatas"])
188
 
189
  X = torch.tensor(data['embeddings'])
190
-
191
- # PCA
192
  X_mean = torch.mean(X, 0)
193
  X_centered = X - X_mean
194
  U, S, V = torch.pca_lowrank(X_centered, q=2)
195
  projected = torch.matmul(X_centered, V[:, :2]).numpy()
196
 
197
- # Diversity
198
  indices = torch.randint(0, len(X), (min(100, len(X)),))
199
  sample = X[indices]
200
  sim_matrix = torch.mm(sample, sample.t())
@@ -203,10 +319,11 @@ def fn_analyze_embeddings():
203
  diversity_score = 1.0 - avg_sim
204
 
205
  metrics = (
 
206
  f"Total Chunks: {count}\n"
207
- f"Analyzed: {len(X)} (Sampled)\n"
208
  f"Diversity Score: {diversity_score:.4f}\n"
209
- f"Est. Avg Similarity: {avg_sim:.4f}"
210
  )
211
 
212
  plot_df = pd.DataFrame({
@@ -215,22 +332,61 @@ def fn_analyze_embeddings():
215
  "topic": [m.get("file_name", "unknown") for m in data['metadatas']]
216
  })
217
 
218
- return metrics, gr.ScatterPlot(value=plot_df, x="x", y="y", color="topic", title="Semantic Space", tooltip="topic")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
 
 
220
  except Exception as e:
221
  import traceback
222
  traceback.print_exc()
223
- return f"Analysis Error: {e}", None
224
 
225
- def fn_evaluate_retrieval(sample_limit):
226
- count = collection.count()
227
  if count < 10: return "Not enough data for evaluation (Need > 10 chunks)."
228
 
229
  try:
230
- # Sample random chunks
231
- # Chroma doesn't support random sample easily, so we get a larger batch and pick random
232
- fetch_limit = min(count, 2000) # Fetch up to 2k to sample from
233
- data = collection.get(limit=fetch_limit, include=["documents"])
234
 
235
  import random
236
  actual_sample_size = min(sample_limit, len(data['ids']))
@@ -240,191 +396,210 @@ def fn_evaluate_retrieval(sample_limit):
240
  hits_at_5 = 0
241
  mrr_sum = 0
242
 
243
- # Generator for progress updates
244
- yield f"Running evaluation on {actual_sample_size} chunks..."
245
 
246
  for i, idx in enumerate(sample_indices):
247
  target_id = data['ids'][idx]
248
  code = data['documents'][idx]
249
-
250
- # Synthetic Query
251
  query = "\n".join(code.split("\n")[:3])
252
- query_emb = compute_embeddings([query]).cpu().numpy().tolist()[0]
253
-
254
- # Query DB
255
- results = collection.query(query_embeddings=[query_emb], n_results=10)
256
-
257
- # Check results
258
  found_ids = results['ids'][0]
259
  if target_id in found_ids:
260
  rank = found_ids.index(target_id) + 1
261
  mrr_sum += 1.0 / rank
262
  if rank == 1: hits_at_1 += 1
263
  if rank <= 5: hits_at_5 += 1
264
-
265
  if i % 10 == 0:
266
- yield f"Evaluated {i}/{actual_sample_size}..."
267
 
268
  recall_1 = hits_at_1 / actual_sample_size
269
  recall_5 = hits_at_5 / actual_sample_size
270
  mrr = mrr_sum / actual_sample_size
271
 
272
  report = (
273
- f"Evaluation on {actual_sample_size} random chunks:\n"
274
- f"--------------------------------------------\n"
275
  f"Recall@1: {recall_1:.4f}\n"
276
  f"Recall@5: {recall_5:.4f}\n"
277
- f"MRR: {mrr:.4f}\n"
278
- f"\n(Note: Using ChromaDB for retrieval)"
279
  )
280
  yield report
281
  except Exception as e:
282
  import traceback
283
  traceback.print_exc()
284
- yield f"Eval Error: {e}"
285
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
 
287
- # --- UI Layout ---
288
- theme = gr.themes.Soft(
289
- primary_hue="slate",
290
- neutral_hue="slate",
291
- spacing_size="sm",
292
- radius_size="md"
293
- ).set(
294
- body_background_fill="*neutral_50",
295
- block_background_fill="white",
296
- block_border_width="1px",
297
- block_title_text_weight="600"
298
- )
299
 
300
  css = """
301
- h1 {
302
- text-align: center;
303
- font-family: 'Inter', sans-serif;
304
- margin-bottom: 1rem;
305
- color: #1e293b;
306
- }
307
- .gradio-container {
308
- max-width: 1200px !important;
309
- margin: auto;
310
- }
311
  """
312
 
313
- with gr.Blocks(theme=theme, css=css, title="CodeMode") as demo:
314
- gr.Markdown("# CodeMode")
 
315
 
316
  with gr.Tabs():
317
- # --- TAB 1: INGEST ---
318
- with gr.Tab("1. Ingest GitHub Repo"):
319
- gr.Markdown("### Connect a Repository")
320
- with gr.Row():
321
- repo_input = gr.Textbox(label="GitHub URL", placeholder="https://github.com/fastapi/fastapi", value="https://github.com/langchain-ai/langgraph")
322
- ingest_btn = gr.Button("Ingest & Index", variant="primary")
 
 
 
 
 
 
 
 
323
 
324
  with gr.Row():
325
- reset_btn = gr.Button("Reset Database", variant="stop")
326
- ingest_status = gr.Textbox(label="Status")
 
 
 
 
 
 
 
 
327
 
328
- with gr.Accordion("Database Inspector", open=False):
329
- list_files_btn = gr.Button("Refresh File List")
330
- files_df = gr.Dataframe(
331
- headers=["File Name", "Chunks", "Source URL"],
332
- datatype=["str", "number", "str"],
333
- interactive=False
334
- )
335
-
336
- def fn_list_files():
337
- count = collection.count()
338
- if count == 0: return [["Database Empty", 0, "-"]]
339
-
340
- try:
341
- # Fetch all metadata (limit to 10k to prevent UI freeze)
342
- limit = min(count, 10000)
343
- data = collection.get(limit=limit, include=["metadatas"])
344
-
345
- if not data or 'metadatas' not in data or data['metadatas'] is None:
346
- return [["Error: No metadata found", 0, "-"]]
347
-
348
- # Aggregate stats
349
- file_counts = {} # filename -> count
350
- file_urls = {} # filename -> url
351
-
352
- for meta in data['metadatas']:
353
- if meta is None: continue # Skip None entries
354
- fname = meta.get("file_name", "unknown")
355
- url = meta.get("url", "-")
356
- file_counts[fname] = file_counts.get(fname, 0) + 1
357
- file_urls[fname] = url
358
-
359
- # Convert to list
360
- output = []
361
- for fname, count in file_counts.items():
362
- output.append([fname, count, file_urls[fname]])
363
-
364
- if not output:
365
- return [["No files found in metadata", 0, "-"]]
366
-
367
- # Sort by chunk count (descending)
368
- output.sort(key=lambda x: x[1], reverse=True)
369
- return output
370
- except Exception as e:
371
- import traceback
372
- traceback.print_exc()
373
- return [[f"Error: {str(e)}", 0, "-"]]
374
-
375
- ingest_btn.click(fn_ingest, inputs=repo_input, outputs=[ingest_status])
376
- reset_btn.click(fn=reset_db, inputs=[], outputs=[ingest_status])
377
- list_files_btn.click(fn_list_files, inputs=[], outputs=[files_df])
378
-
379
- # --- TAB 2: SEARCH ---
380
- with gr.Tab("2. Semantic Search"):
381
- gr.Markdown("### Search the Ingested Code")
382
  with gr.Row():
383
- search_box = gr.Textbox(label="Search Query", placeholder="e.g., 'how to create a state graph'")
384
- search_btn = gr.Button("Search", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
 
386
- results_df = gr.Dataframe(
387
- headers=["File Name", "Score", "Code Snippet"],
388
- datatype=["str", "str", "str"],
389
- interactive=False,
390
- wrap=True
391
- )
392
- search_btn.click(fn=search_codebase, inputs=search_box, outputs=results_df)
393
-
394
- # --- TAB 3: CODE SEARCH ---
395
- with gr.Tab("3. Find Similar Code"):
396
- gr.Markdown("### Code-to-Code Retrieval")
397
  with gr.Row():
398
- code_input = gr.Code(label="Reference Code", language="python")
399
- code_search_btn = gr.Button("Find Matches", variant="primary")
 
400
 
401
- code_results_df = gr.Dataframe(
402
- headers=["File Name", "Score", "Matched Code"],
403
- datatype=["str", "str", "str"],
404
- interactive=False,
405
- wrap=True
406
- )
407
- code_search_btn.click(fn=search_codebase, inputs=code_input, outputs=code_results_df)
408
-
409
- # --- TAB 4: MLOps MONITORING ---
410
- with gr.Tab("4. Deployment Monitoring"):
411
  gr.Markdown("### Embedding Quality Analysis")
412
- analyze_btn = gr.Button("Analyze Embeddings", variant="secondary")
413
 
414
  with gr.Row():
415
- quality_metrics = gr.Textbox(label="Quality Metrics")
416
- plot_output = gr.ScatterPlot(label="Semantic Space (PCA)")
 
 
 
 
 
 
 
 
 
 
 
417
 
418
- analyze_btn.click(fn_analyze_embeddings, inputs=[], outputs=[quality_metrics, plot_output])
419
-
420
- gr.Markdown("### Extrinsic Evaluation (Retrieval Performance)")
421
- with gr.Row():
422
- eval_size = gr.Slider(minimum=10, maximum=1000, value=50, step=10, label="Sample Size (Chunks)")
423
- eval_btn = gr.Button("Run Retrieval Evaluation", variant="primary")
424
 
425
- eval_output = gr.Textbox(label="Evaluation Report")
426
 
427
- eval_btn.click(fn_evaluate_retrieval, inputs=[eval_size], outputs=eval_output)
 
 
 
 
 
 
 
 
 
 
 
428
 
429
  if __name__ == "__main__":
430
- demo.queue().launch()
 
10
  import chromadb
11
  from chromadb.config import Settings
12
  import uuid
13
+ import tempfile
14
 
15
+ # --- Add scripts to path ---
16
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
 
17
  from scripts.core.ingestion.ingest import GitCrawler
18
  from scripts.core.ingestion.chunk import RepoChunker
19
 
20
  # --- Configuration ---
21
+ BASELINE_MODEL = "microsoft/codebert-base"
22
+ FINETUNED_MODEL = "shubharuidas/codebert-base-code-embed-mrl-langchain-langgraph"
23
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
24
+ DB_DIR = Path(os.path.abspath("data/chroma_db_comparison"))
25
  DB_DIR.mkdir(parents=True, exist_ok=True)
26
 
27
+ print(f"Loading models on {DEVICE}...")
28
+ print("1. Loading baseline model...")
29
+ baseline_tokenizer = AutoTokenizer.from_pretrained(BASELINE_MODEL)
30
+ baseline_model = AutoModel.from_pretrained(BASELINE_MODEL)
31
+ baseline_model.to(DEVICE)
32
+ baseline_model.eval()
33
 
34
+ print("2. Loading fine-tuned model...")
35
+ finetuned_tokenizer = AutoTokenizer.from_pretrained(FINETUNED_MODEL)
36
+ finetuned_model = AutoModel.from_pretrained(FINETUNED_MODEL)
37
+ finetuned_model.to(DEVICE)
38
+ finetuned_model.eval()
39
+ print("Both models loaded!")
40
+
41
+ # --- ChromaDB Setup ---
42
  chroma_client = chromadb.PersistentClient(path=str(DB_DIR))
43
+ baseline_collection = chroma_client.get_or_create_collection(name="baseline_rag", metadata={"hnsw:space": "cosine"})
44
+ finetuned_collection = chroma_client.get_or_create_collection(name="finetuned_rag", metadata={"hnsw:space": "cosine"})
45
 
46
+ # --- Embedding Functions ---
47
def compute_baseline_embeddings(text_list):
    """Embed a batch of texts with the baseline model.

    Returns an L2-normalized (batch, hidden) tensor, or None for empty input.
    """
    if not text_list:
        return None
    # Truncate to 512 tokens so long code chunks don't overflow the encoder.
    encoded = baseline_tokenizer(
        text_list, return_tensors="pt", padding=True, truncation=True, max_length=512
    ).to(DEVICE)
    with torch.no_grad():
        hidden = baseline_model(**encoded).last_hidden_state
    # Mean-pool over the sequence dimension, then normalize for cosine search.
    pooled = hidden.mean(dim=1)
    return F.normalize(pooled, p=2, dim=1)
54
 
55
def compute_finetuned_embeddings(text_list):
    """Embed a batch of texts with the fine-tuned model.

    Returns an L2-normalized (batch, hidden) tensor, or None for empty input.
    """
    if not text_list:
        return None
    # Cap at 512 tokens to stay within the model's context window.
    tokens = finetuned_tokenizer(
        text_list, return_tensors="pt", padding=True, truncation=True, max_length=512
    ).to(DEVICE)
    with torch.no_grad():
        outputs = finetuned_model(**tokens)
    # Sequence-mean pooling followed by L2 normalization (cosine space).
    vectors = outputs.last_hidden_state.mean(dim=1)
    return F.normalize(vectors, p=2, dim=1)
62
 
63
+ # --- Reset Functions ---
64
def reset_baseline():
    """Drop and recreate the baseline collection.

    Returns a status string for the UI. Deletion failures (e.g. the
    collection does not exist yet) are tolerated so the handler never
    crashes — the old single-DB `reset_db` wrapped this in try/except too.
    """
    global baseline_collection
    try:
        chroma_client.delete_collection("baseline_rag")
    except Exception as e:
        # Missing collection: treat as already reset and just recreate.
        print(f"reset_baseline: delete failed ({e}); recreating collection")
    baseline_collection = chroma_client.get_or_create_collection(
        name="baseline_rag", metadata={"hnsw:space": "cosine"}
    )
    return "Baseline database reset."
69
+
70
def reset_finetuned():
    """Drop and recreate the fine-tuned collection.

    Returns a status string for the UI. Deletion failures (e.g. the
    collection does not exist yet) are tolerated so the handler never
    crashes on an empty database.
    """
    global finetuned_collection
    try:
        chroma_client.delete_collection("finetuned_rag")
    except Exception as e:
        # Missing collection: treat as already reset and just recreate.
        print(f"reset_finetuned: delete failed ({e}); recreating collection")
    finetuned_collection = chroma_client.get_or_create_collection(
        name="finetuned_rag", metadata={"hnsw:space": "cosine"}
    )
    return "Fine-tuned database reset."
75
+
76
+ # --- Database Inspector Functions ---
77
def list_baseline_files():
    """Summarize the baseline collection as [file_name, chunk_count, source_url] rows.

    Rows are sorted by chunk count, descending. The URL recorded is the one
    from the first chunk seen for each file.
    """
    total = baseline_collection.count()
    if total == 0:
        return [["No data indexed yet", "-", "-"]]

    try:
        # Fetch at most 1000 records to keep the inspector responsive.
        fetched = baseline_collection.get(limit=min(total, 1000), include=["metadatas"])
        counts = {}
        urls = {}
        for meta in fetched['metadatas']:
            name = meta.get("file_name", "unknown")
            counts[name] = counts.get(name, 0) + 1
            # keep the first URL observed for this file
            urls.setdefault(name, meta.get("url", "unknown"))
        rows = [[name, n, urls[name]] for name, n in counts.items()]
        rows.sort(key=lambda row: row[1], reverse=True)
        return rows
    except Exception as e:
        return [[f"Error: {str(e)}", "-", "-"]]
96
 
97
def list_finetuned_files():
    """Summarize the fine-tuned collection as [file_name, chunk_count, source_url] rows.

    Rows are sorted by chunk count, descending. The URL recorded is the one
    from the first chunk seen for each file.
    """
    n_chunks = finetuned_collection.count()
    if n_chunks == 0:
        return [["No data indexed yet", "-", "-"]]

    try:
        # Bounded fetch (<=1000) so the UI table stays fast on big indexes.
        records = finetuned_collection.get(limit=min(n_chunks, 1000), include=["metadatas"])
        stats = {}
        for meta in records['metadatas']:
            key = meta.get("file_name", "unknown")
            if key in stats:
                stats[key][0] += 1
            else:
                # first occurrence: remember count=1 plus the source URL
                stats[key] = [1, meta.get("url", "unknown")]
        rows = [[key, cnt, url] for key, (cnt, url) in stats.items()]
        return sorted(rows, key=lambda r: r[1], reverse=True)
    except Exception as e:
        return [[f"Error: {str(e)}", "-", "-"]]
116
+
117
+ # --- Search Functions ---
118
def search_baseline(query, top_k=5):
    """Semantic search against the baseline collection.

    Returns rows of [file_name, similarity_score, snippet] (first 300 chars),
    or [] when the index is empty or the query cannot be embedded.
    """
    if baseline_collection.count() == 0:
        return []
    embedded = compute_baseline_embeddings([query])
    if embedded is None:
        return []
    vector = embedded.cpu().numpy().tolist()[0]
    hits = baseline_collection.query(
        query_embeddings=[vector],
        n_results=min(top_k, baseline_collection.count()),
        include=["metadatas", "documents", "distances"],
    )
    rows = []
    if hits['ids']:
        for meta, doc, dist in zip(
            hits['metadatas'][0], hits['documents'][0], hits['distances'][0]
        ):
            similarity = 1 - dist  # cosine distance -> similarity
            rows.append([meta.get("file_name", "unknown"), f"{similarity:.4f}", doc[:300] + "..."])
    return rows
133
 
134
def search_finetuned(query, top_k=5):
    """Semantic search against the fine-tuned collection.

    Returns rows of [file_name, similarity_score, snippet] (first 300 chars),
    or [] when the index is empty or the query cannot be embedded.
    """
    if finetuned_collection.count() == 0:
        return []
    embedded = compute_finetuned_embeddings([query])
    if embedded is None:
        return []
    vector = embedded.cpu().numpy().tolist()[0]
    hits = finetuned_collection.query(
        query_embeddings=[vector],
        n_results=min(top_k, finetuned_collection.count()),
        include=["metadatas", "documents", "distances"],
    )
    rows = []
    if hits['ids']:
        for meta, doc, dist in zip(
            hits['metadatas'][0], hits['documents'][0], hits['distances'][0]
        ):
            similarity = 1 - dist  # cosine distance -> similarity
            rows.append([meta.get("file_name", "unknown"), f"{similarity:.4f}", doc[:300] + "..."])
    return rows
149
+
150
def search_comparison(query, top_k=5):
    """Run the same query against both indexes.

    Returns a (baseline_rows, finetuned_rows) pair for side-by-side display.
    """
    return search_baseline(query, top_k), search_finetuned(query, top_k)
154
+
155
+ # --- Ingestion Functions ---
156
+ def ingest_from_url(repo_url):
157
  if not repo_url.startswith("http"):
158
+ yield "Invalid URL"
159
+ return
160
 
161
  DATA_DIR = Path(os.path.abspath("data/raw_ingest"))
162
  import stat
163
  def remove_readonly(func, path, _):
164
  os.chmod(path, stat.S_IWRITE)
165
  func(path)
166
+
167
  try:
 
168
  if DATA_DIR.exists():
169
  shutil.rmtree(DATA_DIR, onerror=remove_readonly)
170
 
 
171
  yield f"Cloning {repo_url}..."
172
  crawler = GitCrawler(cache_dir=DATA_DIR)
173
  repo_path = crawler.clone_repository(repo_url)
 
174
  if not repo_path:
175
+ yield "Failed to clone repository."
176
+ return
177
+
178
  yield "Listing files..."
179
  files = crawler.list_files(repo_path, extensions={'.py', '.md', '.json', '.js', '.ts', '.java', '.cpp'})
180
  if isinstance(files, tuple): files = [f.path for f in files[0]]
 
193
  all_chunks.extend(file_chunks)
194
  except Exception as e:
195
  print(f"Skipping {file_path}: {e}")
196
+
197
  if not all_chunks:
198
+ yield "No valid chunks found."
199
+ return
200
+
201
  total_chunks = len(all_chunks)
202
+ yield f"Generated {total_chunks} chunks. Embedding (BASELINE)..."
203
 
204
  batch_size = 64
205
+ # Index with baseline
206
  for i in range(0, total_chunks, batch_size):
207
  batch = all_chunks[i:i+batch_size]
 
 
208
  texts = [c.code for c in batch]
209
  ids = [str(uuid.uuid4()) for _ in batch]
210
  metadatas = [{"file_name": Path(c.file_path).name, "url": repo_url} for c in batch]
211
 
212
+ embeddings = compute_baseline_embeddings(texts)
 
213
  if embeddings is not None:
214
+ baseline_collection.add(ids=ids, embeddings=embeddings.cpu().numpy().tolist(), metadatas=metadatas, documents=texts)
215
+ yield f"Baseline: {min(i+batch_size, total_chunks)}/{total_chunks}"
216
+
217
+ yield f"Embedding (FINE-TUNED)..."
218
+ # Index with fine-tuned
219
+ for i in range(0, total_chunks, batch_size):
220
+ batch = all_chunks[i:i+batch_size]
221
+ texts = [c.code for c in batch]
222
+ ids = [str(uuid.uuid4()) for _ in batch]
223
+ metadatas = [{"file_name": Path(c.file_path).name, "url": repo_url} for c in batch]
224
 
225
+ embeddings = compute_finetuned_embeddings(texts)
226
+ if embeddings is not None:
227
+ finetuned_collection.add(ids=ids, embeddings=embeddings.cpu().numpy().tolist(), metadatas=metadatas, documents=texts)
228
+ yield f"Fine-tuned: {min(i+batch_size, total_chunks)}/{total_chunks}"
229
+
230
+ yield f"SUCCESS! Indexed {total_chunks} chunks in both databases."
231
+ except Exception as e:
232
+ import traceback
233
+ traceback.print_exc()
234
+ yield f"Error: {str(e)}"
235
+
236
def _index_upload_batches(collection, embed_fn, chunks, label, source_url="uploaded"):
    """Embed `chunks` in batches of 64 and add them to `collection`.

    Generator: yields one "<label>: done/total" progress string per batch.
    Shared by the baseline and fine-tuned passes, which were previously
    duplicated line-for-line.
    """
    total = len(chunks)
    batch_size = 64
    for start in range(0, total, batch_size):
        batch = chunks[start:start + batch_size]
        texts = [c.code for c in batch]
        ids = [str(uuid.uuid4()) for _ in batch]
        metadatas = [{"file_name": Path(c.file_path).name, "url": source_url} for c in batch]

        embeddings = embed_fn(texts)
        if embeddings is not None:
            collection.add(
                ids=ids,
                embeddings=embeddings.cpu().numpy().tolist(),
                metadatas=metadatas,
                documents=texts,
            )
        yield f"{label}: {min(start + batch_size, total)}/{total}"


def ingest_from_files(files):
    """Chunk uploaded files and index them into BOTH collections.

    Generator of status strings for the Gradio status textbox. Chunking
    errors on individual files are reported but do not abort the run.
    """
    if not files or len(files) == 0:
        yield "No files uploaded."
        return

    try:
        yield f"Processing {len(files)} file(s)..."

        chunker = RepoChunker()
        all_chunks = []

        for i, file in enumerate(files):
            yield f"Chunking file {i+1}/{len(files)}: {Path(file.name).name}"
            try:
                # Gradio file upload: file.name contains the temp path
                file_path = Path(file.name)
                meta = {"file_name": file_path.name, "url": "uploaded"}
                all_chunks.extend(chunker.chunk_file(file_path, repo_metadata=meta))
            except Exception as e:
                yield f"Error chunking {Path(file.name).name}: {str(e)}"
                import traceback
                traceback.print_exc()

        if not all_chunks:
            yield "No valid chunks found."
            return

        total_chunks = len(all_chunks)
        yield f"Generated {total_chunks} chunks. Embedding (BASELINE)..."
        yield from _index_upload_batches(
            baseline_collection, compute_baseline_embeddings, all_chunks, "Baseline"
        )

        yield f"Embedding (FINE-TUNED)..."
        yield from _index_upload_batches(
            finetuned_collection, compute_finetuned_embeddings, all_chunks, "Fine-tuned"
        )

        yield f"SUCCESS! Indexed {total_chunks} chunks from uploaded files."
    except Exception as e:
        import traceback
        traceback.print_exc()
        yield f"Error: {str(e)}"
297
 
298
+ # --- Analysis & Evaluation Functions ---
299
+ def analyze_embeddings_baseline():
300
+ count = baseline_collection.count()
301
  if count < 5:
302
  return "Not enough data (Need > 5 chunks).", None
303
 
304
  try:
 
305
  limit = min(count, 2000)
306
+ data = baseline_collection.get(limit=limit, include=["embeddings", "metadatas"])
307
 
308
  X = torch.tensor(data['embeddings'])
 
 
309
  X_mean = torch.mean(X, 0)
310
  X_centered = X - X_mean
311
  U, S, V = torch.pca_lowrank(X_centered, q=2)
312
  projected = torch.matmul(X_centered, V[:, :2]).numpy()
313
 
 
314
  indices = torch.randint(0, len(X), (min(100, len(X)),))
315
  sample = X[indices]
316
  sim_matrix = torch.mm(sample, sample.t())
 
319
  diversity_score = 1.0 - avg_sim
320
 
321
  metrics = (
322
+ f"BASELINE MODEL\n"
323
  f"Total Chunks: {count}\n"
324
+ f"Analyzed: {len(X)}\n"
325
  f"Diversity Score: {diversity_score:.4f}\n"
326
+ f"Avg Similarity: {avg_sim:.4f}"
327
  )
328
 
329
  plot_df = pd.DataFrame({
 
332
  "topic": [m.get("file_name", "unknown") for m in data['metadatas']]
333
  })
334
 
335
+ return metrics, gr.ScatterPlot(value=plot_df, x="x", y="y", color="topic", title="Baseline Semantic Space", tooltip="topic")
336
+ except Exception as e:
337
+ import traceback
338
+ traceback.print_exc()
339
+ return f"Error: {e}", None
340
+
341
def analyze_embeddings_finetuned():
    """Intrinsic quality analysis of the fine-tuned embedding index.

    Projects up to 2000 stored embeddings to 2-D with PCA for a scatter plot
    and estimates embedding diversity from pairwise cosine similarity over a
    random sample of up to 100 rows.

    Returns (metrics_text, gr.ScatterPlot) on success, or (message, None).
    """
    count = finetuned_collection.count()
    if count < 5:
        return "Not enough data (Need > 5 chunks).", None

    try:
        limit = min(count, 2000)  # cap the fetch for visualization speed
        data = finetuned_collection.get(limit=limit, include=["embeddings", "metadatas"])

        X = torch.tensor(data['embeddings'])
        # 2-D PCA projection of the (already L2-normalized) embeddings.
        X_centered = X - torch.mean(X, 0)
        U, S, V = torch.pca_lowrank(X_centered, q=2)
        projected = torch.matmul(X_centered, V[:, :2]).numpy()

        # Diversity: mean off-diagonal cosine similarity over a random sample.
        # Sample WITHOUT replacement (randperm) — randint could pick the same
        # row twice, adding artificial similarity-1.0 pairs that bias the
        # estimate upward.
        indices = torch.randperm(len(X))[:min(100, len(X))]
        sample = X[indices]
        sim_matrix = torch.mm(sample, sample.t())
        mask = ~torch.eye(len(sample), dtype=bool)
        avg_sim = sim_matrix[mask].mean().item()
        diversity_score = 1.0 - avg_sim

        metrics = (
            f"FINE-TUNED MODEL\n"
            f"Total Chunks: {count}\n"
            f"Analyzed: {len(X)}\n"
            f"Diversity Score: {diversity_score:.4f}\n"
            f"Avg Similarity: {avg_sim:.4f}"
        )

        plot_df = pd.DataFrame({
            "x": projected[:, 0],
            "y": projected[:, 1],
            "topic": [m.get("file_name", "unknown") for m in data['metadatas']]
        })

        return metrics, gr.ScatterPlot(value=plot_df, x="x", y="y", color="topic", title="Fine-tuned Semantic Space", tooltip="topic")
    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"Error: {e}", None
382
 
383
+ def evaluate_retrieval_baseline(sample_limit):
384
+ count = baseline_collection.count()
385
  if count < 10: return "Not enough data for evaluation (Need > 10 chunks)."
386
 
387
  try:
388
+ fetch_limit = min(count, 2000)
389
+ data = baseline_collection.get(limit=fetch_limit, include=["documents"])
 
 
390
 
391
  import random
392
  actual_sample_size = min(sample_limit, len(data['ids']))
 
396
  hits_at_5 = 0
397
  mrr_sum = 0
398
 
399
+ yield f"BASELINE: Evaluating {actual_sample_size} chunks..."
 
400
 
401
  for i, idx in enumerate(sample_indices):
402
  target_id = data['ids'][idx]
403
  code = data['documents'][idx]
 
 
404
  query = "\n".join(code.split("\n")[:3])
405
+ query_emb = compute_baseline_embeddings([query]).cpu().numpy().tolist()[0]
406
+ results = baseline_collection.query(query_embeddings=[query_emb], n_results=10)
 
 
 
 
407
  found_ids = results['ids'][0]
408
  if target_id in found_ids:
409
  rank = found_ids.index(target_id) + 1
410
  mrr_sum += 1.0 / rank
411
  if rank == 1: hits_at_1 += 1
412
  if rank <= 5: hits_at_5 += 1
 
413
  if i % 10 == 0:
414
+ yield f"Baseline: {i}/{actual_sample_size}..."
415
 
416
  recall_1 = hits_at_1 / actual_sample_size
417
  recall_5 = hits_at_5 / actual_sample_size
418
  mrr = mrr_sum / actual_sample_size
419
 
420
  report = (
421
+ f"BASELINE EVALUATION ({actual_sample_size} chunks)\n"
422
+ f"{'='*40}\n"
423
  f"Recall@1: {recall_1:.4f}\n"
424
  f"Recall@5: {recall_5:.4f}\n"
425
+ f"MRR: {mrr:.4f}"
 
426
  )
427
  yield report
428
  except Exception as e:
429
  import traceback
430
  traceback.print_exc()
431
+ yield f"Error: {e}"
432
 
433
def evaluate_retrieval_finetuned(sample_limit):
    """Self-retrieval evaluation of the fine-tuned index.

    Generator of status strings for the UI. For each sampled chunk, its first
    three lines form a synthetic query; we then check at which rank the chunk
    itself comes back. Reports Recall@1, Recall@5 and MRR.
    """
    count = finetuned_collection.count()
    if count < 10:
        # BUG FIX: this is a generator function, so a plain
        # `return "..."` is swallowed as StopIteration.value and the UI
        # would display nothing. Yield the message, then stop.
        yield "Not enough data for evaluation (Need > 10 chunks)."
        return

    try:
        # Chroma has no random-sample API: fetch a pool (<=2000) and sample it.
        fetch_limit = min(count, 2000)
        data = finetuned_collection.get(limit=fetch_limit, include=["documents"])

        import random
        actual_sample_size = min(sample_limit, len(data['ids']))
        sample_indices = random.sample(range(len(data['ids'])), actual_sample_size)

        hits_at_1 = 0
        hits_at_5 = 0
        mrr_sum = 0

        yield f"FINE-TUNED: Evaluating {actual_sample_size} chunks..."

        for i, idx in enumerate(sample_indices):
            target_id = data['ids'][idx]
            code = data['documents'][idx]
            # Synthetic query: the chunk's own first three lines.
            query = "\n".join(code.split("\n")[:3])
            query_emb = compute_finetuned_embeddings([query]).cpu().numpy().tolist()[0]
            results = finetuned_collection.query(query_embeddings=[query_emb], n_results=10)
            found_ids = results['ids'][0]
            if target_id in found_ids:
                rank = found_ids.index(target_id) + 1
                mrr_sum += 1.0 / rank
                if rank == 1: hits_at_1 += 1
                if rank <= 5: hits_at_5 += 1
            if i % 10 == 0:
                yield f"Fine-tuned: {i}/{actual_sample_size}..."

        recall_1 = hits_at_1 / actual_sample_size
        recall_5 = hits_at_5 / actual_sample_size
        mrr = mrr_sum / actual_sample_size

        report = (
            f"FINE-TUNED EVALUATION ({actual_sample_size} chunks)\n"
            f"{'='*40}\n"
            f"Recall@1: {recall_1:.4f}\n"
            f"Recall@5: {recall_5:.4f}\n"
            f"MRR: {mrr:.4f}"
        )
        yield report
    except Exception as e:
        import traceback
        traceback.print_exc()
        yield f"Error: {e}"
482
 
483
+ # --- UI ---
484
+ theme = gr.themes.Soft(primary_hue="slate", neutral_hue="slate", spacing_size="sm", radius_size="md").set(body_background_fill="*neutral_50", block_background_fill="white", block_border_width="1px", block_title_text_weight="600")
 
 
 
 
 
 
 
 
 
 
485
 
486
  css = """
487
+ h1 { text-align: center; font-family: 'Inter', sans-serif; margin-bottom: 1rem; color: #1e293b; }
488
+ .gradio-container { max-width: 1400px !important; margin: auto; }
489
+ .comparison-header { font-size: 1.1rem; font-weight: 600; color: #334155; text-align: center; padding: 0.5rem; }
 
 
 
 
 
 
 
490
  """
491
 
492
+ with gr.Blocks(theme=theme, css=css, title="CodeMode - Baseline vs Fine-tuned") as demo:
493
+ gr.Markdown("# CodeMode: Baseline vs Fine-tuned Model Comparison")
494
+ gr.Markdown("Compare retrieval performance between **microsoft/codebert-base** (baseline) and **MRL-enhanced fine-tuned** model")
495
 
496
  with gr.Tabs():
497
+ # TAB 1: INGEST
498
+ with gr.Tab("1. Ingest Code"):
499
+ with gr.Tabs():
500
+ with gr.Tab("GitHub Repository"):
501
+ repo_input = gr.Textbox(label="GitHub URL", placeholder="https://github.com/pallets/flask")
502
+ ingest_url_btn = gr.Button("Ingest from URL", variant="primary")
503
+ url_status = gr.Textbox(label="Status")
504
+ ingest_url_btn.click(ingest_from_url, inputs=repo_input, outputs=url_status)
505
+
506
+ with gr.Tab("Upload Python Files"):
507
+ file_upload = gr.File(label="Upload .py files", file_types=[".py"], file_count="multiple")
508
+ ingest_files_btn = gr.Button("Ingest Uploaded Files", variant="primary")
509
+ upload_status = gr.Textbox(label="Status")
510
+ ingest_files_btn.click(ingest_from_files, inputs=file_upload, outputs=upload_status)
511
 
512
  with gr.Row():
513
+ reset_baseline_btn = gr.Button("Reset Baseline DB", variant="stop")
514
+ reset_finetuned_btn = gr.Button("Reset Fine-tuned DB", variant="stop")
515
+ reset_status = gr.Textbox(label="Reset Status")
516
+
517
+ reset_baseline_btn.click(reset_baseline, inputs=[], outputs=reset_status)
518
+ reset_finetuned_btn.click(reset_finetuned, inputs=[], outputs=reset_status)
519
+
520
+ gr.Markdown("---")
521
+ gr.Markdown("### Database Inspector")
522
+ gr.Markdown("View indexed files in each collection")
523
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
524
  with gr.Row():
525
+ with gr.Column():
526
+ gr.Markdown("#### Baseline Collection")
527
+ inspect_baseline_btn = gr.Button("Inspect Baseline DB", variant="secondary")
528
+ baseline_files_df = gr.Dataframe(
529
+ headers=["File Name", "Chunks", "Source URL"],
530
+ datatype=["str", "number", "str"],
531
+ interactive=False,
532
+ value=[["No data yet", "-", "-"]]
533
+ )
534
+ inspect_baseline_btn.click(list_baseline_files, inputs=[], outputs=baseline_files_df)
535
+
536
+ with gr.Column():
537
+ gr.Markdown("#### Fine-tuned Collection")
538
+ inspect_finetuned_btn = gr.Button("Inspect Fine-tuned DB", variant="secondary")
539
+ finetuned_files_df = gr.Dataframe(
540
+ headers=["File Name", "Chunks", "Source URL"],
541
+ datatype=["str", "number", "str"],
542
+ interactive=False,
543
+ value=[["No data yet", "-", "-"]]
544
+ )
545
+ inspect_finetuned_btn.click(list_finetuned_files, inputs=[], outputs=finetuned_files_df)
546
+
547
+ # TAB 2: COMPARISON SEARCH
548
+ with gr.Tab("2. Comparison Search (Note: Semantic search is sensitive to query phrasing)"):
549
+ gr.Markdown("### Side-by-Side Retrieval Comparison")
550
+ search_query = gr.Textbox(label="Search Query", placeholder="e.g., 'Flask route decorator'")
551
+ compare_btn = gr.Button("Compare Models", variant="primary")
552
 
 
 
 
 
 
 
 
 
 
 
 
553
  with gr.Row():
554
+ with gr.Column():
555
+ gr.Markdown("<div class='comparison-header'>BASELINE (CodeBERT)</div>", elem_classes="comparison-header")
556
+ baseline_results = gr.Dataframe(headers=["File", "Score", "Code Snippet"], datatype=["str", "str", "str"], interactive=False, wrap=True)
557
 
558
+ with gr.Column():
559
+ gr.Markdown("<div class='comparison-header'>FINE-TUNED (MRL-Enhanced)</div>", elem_classes="comparison-header")
560
+ finetuned_results = gr.Dataframe(headers=["File", "Score", "Code Snippet"], datatype=["str", "str", "str"], interactive=False, wrap=True)
561
+
562
+ compare_btn.click(search_comparison, inputs=search_query, outputs=[baseline_results, finetuned_results])
563
+
564
+
565
+ # TAB 3: DEPLOYMENT MONITORING
566
+ with gr.Tab("3. Deployment Monitoring"):
 
567
  gr.Markdown("### Embedding Quality Analysis")
568
+ gr.Markdown("Analyze the semantic space distribution and diversity of embeddings")
569
 
570
  with gr.Row():
571
+ with gr.Column():
572
+ gr.Markdown("#### Baseline Model")
573
+ analyze_baseline_btn = gr.Button("Analyze Baseline Embeddings", variant="secondary")
574
+ baseline_metrics = gr.Textbox(label="Baseline Metrics")
575
+ baseline_plot = gr.ScatterPlot(label="Baseline Semantic Space (PCA)")
576
+ analyze_baseline_btn.click(analyze_embeddings_baseline, inputs=[], outputs=[baseline_metrics, baseline_plot])
577
+
578
+ with gr.Column():
579
+ gr.Markdown("#### Fine-tuned Model")
580
+ analyze_finetuned_btn = gr.Button("Analyze Fine-tuned Embeddings", variant="secondary")
581
+ finetuned_metrics = gr.Textbox(label="Fine-tuned Metrics")
582
+ finetuned_plot = gr.ScatterPlot(label="Fine-tuned Semantic Space (PCA)")
583
+ analyze_finetuned_btn.click(analyze_embeddings_finetuned, inputs=[], outputs=[finetuned_metrics, finetuned_plot])
584
 
585
+ gr.Markdown("---")
586
+ gr.Markdown("### Retrieval Performance Evaluation")
587
+ gr.Markdown("Evaluate retrieval accuracy using synthetic queries (query = first 3 lines of code)")
 
 
 
588
 
589
+ eval_size = gr.Slider(minimum=10, maximum=500, value=50, step=10, label="Sample Size (Chunks to Evaluate)")
590
 
591
+ with gr.Row():
592
+ with gr.Column():
593
+ gr.Markdown("#### Baseline Evaluation")
594
+ eval_baseline_btn = gr.Button("Run Baseline Evaluation", variant="primary")
595
+ baseline_eval_output = gr.Textbox(label="Baseline Results")
596
+ eval_baseline_btn.click(evaluate_retrieval_baseline, inputs=[eval_size], outputs=baseline_eval_output)
597
+
598
+ with gr.Column():
599
+ gr.Markdown("#### Fine-tuned Evaluation")
600
+ eval_finetuned_btn = gr.Button("Run Fine-tuned Evaluation", variant="primary")
601
+ finetuned_eval_output = gr.Textbox(label="Fine-tuned Results")
602
+ eval_finetuned_btn.click(evaluate_retrieval_finetuned, inputs=[eval_size], outputs=finetuned_eval_output)
603
 
604
if __name__ == "__main__":
    # Enable the request queue first so generator callbacks can stream
    # incremental progress updates, then bind on all interfaces.
    app = demo.queue()
    app.launch(server_name="0.0.0.0", server_port=7860, share=False)