CodeMode Agent commited on
Commit ·
48ca3cd
1
Parent(s): fb9394b
Deploy CodeMode via Agent
Browse files
app.py
CHANGED
|
@@ -10,114 +10,171 @@ from pathlib import Path
|
|
| 10 |
import chromadb
|
| 11 |
from chromadb.config import Settings
|
| 12 |
import uuid
|
|
|
|
| 13 |
|
| 14 |
-
# --- Add scripts to path
|
| 15 |
-
|
| 16 |
-
sys.path.append(os.path.dirname(__file__))
|
| 17 |
from scripts.core.ingestion.ingest import GitCrawler
|
| 18 |
from scripts.core.ingestion.chunk import RepoChunker
|
| 19 |
|
| 20 |
# --- Configuration ---
|
| 21 |
-
|
|
|
|
| 22 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 23 |
-
DB_DIR = Path("data/
|
| 24 |
DB_DIR.mkdir(parents=True, exist_ok=True)
|
| 25 |
|
| 26 |
-
print(f"Loading
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
chroma_client = chromadb.PersistentClient(path=str(DB_DIR))
|
|
|
|
|
|
|
| 36 |
|
| 37 |
-
#
|
| 38 |
-
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
-
|
| 42 |
-
def compute_embeddings(text_list):
|
| 43 |
-
"""Batch compute embeddings"""
|
| 44 |
if not text_list: return None
|
| 45 |
-
|
| 46 |
-
inputs = tokenizer(text_list, return_tensors="pt", padding=True, truncation=True, max_length=512).to(DEVICE)
|
| 47 |
with torch.no_grad():
|
| 48 |
-
out =
|
| 49 |
emb = out.last_hidden_state.mean(dim=1)
|
| 50 |
return F.normalize(emb, p=2, dim=1)
|
| 51 |
|
| 52 |
-
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
try:
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
except Exception as e:
|
| 59 |
-
return f"Error
|
| 60 |
|
| 61 |
-
def
|
| 62 |
-
|
| 63 |
-
if
|
|
|
|
| 64 |
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
if query_emb is None: return []
|
| 67 |
-
|
| 68 |
-
# Convert tensor to list for Chroma
|
| 69 |
query_vec = query_emb.cpu().numpy().tolist()[0]
|
| 70 |
-
|
| 71 |
-
results = collection.query(
|
| 72 |
-
query_embeddings=[query_vec],
|
| 73 |
-
n_results=min(top_k, collection.count()),
|
| 74 |
-
include=["metadatas", "documents", "distances"]
|
| 75 |
-
)
|
| 76 |
-
|
| 77 |
-
# Parse items
|
| 78 |
output = []
|
| 79 |
if results['ids']:
|
| 80 |
for i in range(len(results['ids'][0])):
|
| 81 |
meta = results['metadatas'][0][i]
|
| 82 |
code = results['documents'][0][i]
|
| 83 |
dist = results['distances'][0][i]
|
| 84 |
-
score = 1 - dist
|
| 85 |
-
|
| 86 |
-
link_icon = "[Link]" if score > 0.7 else ""
|
| 87 |
-
output.append([meta.get("file_name", "unknown"), f"{score:.4f} {link_icon}", code[:300] + "..."])
|
| 88 |
-
|
| 89 |
return output
|
| 90 |
|
| 91 |
-
def
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
if not repo_url.startswith("http"):
|
| 99 |
-
|
|
|
|
| 100 |
|
| 101 |
DATA_DIR = Path(os.path.abspath("data/raw_ingest"))
|
| 102 |
import stat
|
| 103 |
def remove_readonly(func, path, _):
|
| 104 |
os.chmod(path, stat.S_IWRITE)
|
| 105 |
func(path)
|
| 106 |
-
|
| 107 |
try:
|
| 108 |
-
# Clean up old raw data
|
| 109 |
if DATA_DIR.exists():
|
| 110 |
shutil.rmtree(DATA_DIR, onerror=remove_readonly)
|
| 111 |
|
| 112 |
-
# 1. Clone
|
| 113 |
yield f"Cloning {repo_url}..."
|
| 114 |
crawler = GitCrawler(cache_dir=DATA_DIR)
|
| 115 |
repo_path = crawler.clone_repository(repo_url)
|
| 116 |
-
|
| 117 |
if not repo_path:
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
yield "Listing files..."
|
| 122 |
files = crawler.list_files(repo_path, extensions={'.py', '.md', '.json', '.js', '.ts', '.java', '.cpp'})
|
| 123 |
if isinstance(files, tuple): files = [f.path for f in files[0]]
|
|
@@ -136,65 +193,124 @@ def fn_ingest(repo_url):
|
|
| 136 |
all_chunks.extend(file_chunks)
|
| 137 |
except Exception as e:
|
| 138 |
print(f"Skipping {file_path}: {e}")
|
| 139 |
-
|
| 140 |
if not all_chunks:
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
total_chunks = len(all_chunks)
|
| 145 |
-
yield f"Generated {total_chunks} chunks. Embedding
|
| 146 |
|
| 147 |
batch_size = 64
|
|
|
|
| 148 |
for i in range(0, total_chunks, batch_size):
|
| 149 |
batch = all_chunks[i:i+batch_size]
|
| 150 |
-
|
| 151 |
-
# Prepare data
|
| 152 |
texts = [c.code for c in batch]
|
| 153 |
ids = [str(uuid.uuid4()) for _ in batch]
|
| 154 |
metadatas = [{"file_name": Path(c.file_path).name, "url": repo_url} for c in batch]
|
| 155 |
|
| 156 |
-
|
| 157 |
-
embeddings = compute_embeddings(texts)
|
| 158 |
if embeddings is not None:
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
|
|
|
|
|
|
|
|
|
| 166 |
|
| 167 |
-
|
| 168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
-
count = collection.count()
|
| 171 |
-
yield f"Success! Database now has {count} code chunks. Ready for search."
|
| 172 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
except Exception as e:
|
| 174 |
import traceback
|
| 175 |
traceback.print_exc()
|
| 176 |
yield f"Error: {str(e)}"
|
| 177 |
|
| 178 |
-
# --- Analysis Functions ---
|
| 179 |
-
def
|
| 180 |
-
count =
|
| 181 |
if count < 5:
|
| 182 |
return "Not enough data (Need > 5 chunks).", None
|
| 183 |
|
| 184 |
try:
|
| 185 |
-
# Fetch all embeddings (Limit to 2000 for visualization speed)
|
| 186 |
limit = min(count, 2000)
|
| 187 |
-
data =
|
| 188 |
|
| 189 |
X = torch.tensor(data['embeddings'])
|
| 190 |
-
|
| 191 |
-
# PCA
|
| 192 |
X_mean = torch.mean(X, 0)
|
| 193 |
X_centered = X - X_mean
|
| 194 |
U, S, V = torch.pca_lowrank(X_centered, q=2)
|
| 195 |
projected = torch.matmul(X_centered, V[:, :2]).numpy()
|
| 196 |
|
| 197 |
-
# Diversity
|
| 198 |
indices = torch.randint(0, len(X), (min(100, len(X)),))
|
| 199 |
sample = X[indices]
|
| 200 |
sim_matrix = torch.mm(sample, sample.t())
|
|
@@ -203,10 +319,11 @@ def fn_analyze_embeddings():
|
|
| 203 |
diversity_score = 1.0 - avg_sim
|
| 204 |
|
| 205 |
metrics = (
|
|
|
|
| 206 |
f"Total Chunks: {count}\n"
|
| 207 |
-
f"Analyzed: {len(X)}
|
| 208 |
f"Diversity Score: {diversity_score:.4f}\n"
|
| 209 |
-
f"
|
| 210 |
)
|
| 211 |
|
| 212 |
plot_df = pd.DataFrame({
|
|
@@ -215,22 +332,61 @@ def fn_analyze_embeddings():
|
|
| 215 |
"topic": [m.get("file_name", "unknown") for m in data['metadatas']]
|
| 216 |
})
|
| 217 |
|
| 218 |
-
return metrics, gr.ScatterPlot(value=plot_df, x="x", y="y", color="topic", title="Semantic Space", tooltip="topic")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
|
|
|
|
| 220 |
except Exception as e:
|
| 221 |
import traceback
|
| 222 |
traceback.print_exc()
|
| 223 |
-
return f"
|
| 224 |
|
| 225 |
-
def
|
| 226 |
-
count =
|
| 227 |
if count < 10: return "Not enough data for evaluation (Need > 10 chunks)."
|
| 228 |
|
| 229 |
try:
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
fetch_limit = min(count, 2000) # Fetch up to 2k to sample from
|
| 233 |
-
data = collection.get(limit=fetch_limit, include=["documents"])
|
| 234 |
|
| 235 |
import random
|
| 236 |
actual_sample_size = min(sample_limit, len(data['ids']))
|
|
@@ -240,191 +396,210 @@ def fn_evaluate_retrieval(sample_limit):
|
|
| 240 |
hits_at_5 = 0
|
| 241 |
mrr_sum = 0
|
| 242 |
|
| 243 |
-
|
| 244 |
-
yield f"Running evaluation on {actual_sample_size} chunks..."
|
| 245 |
|
| 246 |
for i, idx in enumerate(sample_indices):
|
| 247 |
target_id = data['ids'][idx]
|
| 248 |
code = data['documents'][idx]
|
| 249 |
-
|
| 250 |
-
# Synthetic Query
|
| 251 |
query = "\n".join(code.split("\n")[:3])
|
| 252 |
-
query_emb =
|
| 253 |
-
|
| 254 |
-
# Query DB
|
| 255 |
-
results = collection.query(query_embeddings=[query_emb], n_results=10)
|
| 256 |
-
|
| 257 |
-
# Check results
|
| 258 |
found_ids = results['ids'][0]
|
| 259 |
if target_id in found_ids:
|
| 260 |
rank = found_ids.index(target_id) + 1
|
| 261 |
mrr_sum += 1.0 / rank
|
| 262 |
if rank == 1: hits_at_1 += 1
|
| 263 |
if rank <= 5: hits_at_5 += 1
|
| 264 |
-
|
| 265 |
if i % 10 == 0:
|
| 266 |
-
yield f"
|
| 267 |
|
| 268 |
recall_1 = hits_at_1 / actual_sample_size
|
| 269 |
recall_5 = hits_at_5 / actual_sample_size
|
| 270 |
mrr = mrr_sum / actual_sample_size
|
| 271 |
|
| 272 |
report = (
|
| 273 |
-
f"
|
| 274 |
-
f"
|
| 275 |
f"Recall@1: {recall_1:.4f}\n"
|
| 276 |
f"Recall@5: {recall_5:.4f}\n"
|
| 277 |
-
f"MRR: {mrr:.4f}
|
| 278 |
-
f"\n(Note: Using ChromaDB for retrieval)"
|
| 279 |
)
|
| 280 |
yield report
|
| 281 |
except Exception as e:
|
| 282 |
import traceback
|
| 283 |
traceback.print_exc()
|
| 284 |
-
yield f"
|
| 285 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
|
| 287 |
-
# --- UI
|
| 288 |
-
theme = gr.themes.Soft(
|
| 289 |
-
primary_hue="slate",
|
| 290 |
-
neutral_hue="slate",
|
| 291 |
-
spacing_size="sm",
|
| 292 |
-
radius_size="md"
|
| 293 |
-
).set(
|
| 294 |
-
body_background_fill="*neutral_50",
|
| 295 |
-
block_background_fill="white",
|
| 296 |
-
block_border_width="1px",
|
| 297 |
-
block_title_text_weight="600"
|
| 298 |
-
)
|
| 299 |
|
| 300 |
css = """
|
| 301 |
-
h1 {
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
margin-bottom: 1rem;
|
| 305 |
-
color: #1e293b;
|
| 306 |
-
}
|
| 307 |
-
.gradio-container {
|
| 308 |
-
max-width: 1200px !important;
|
| 309 |
-
margin: auto;
|
| 310 |
-
}
|
| 311 |
"""
|
| 312 |
|
| 313 |
-
with gr.Blocks(theme=theme, css=css, title="CodeMode") as demo:
|
| 314 |
-
gr.Markdown("# CodeMode")
|
|
|
|
| 315 |
|
| 316 |
with gr.Tabs():
|
| 317 |
-
#
|
| 318 |
-
with gr.Tab("1. Ingest
|
| 319 |
-
gr.
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 323 |
|
| 324 |
with gr.Row():
|
| 325 |
-
|
| 326 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 327 |
|
| 328 |
-
with gr.Accordion("Database Inspector", open=False):
|
| 329 |
-
list_files_btn = gr.Button("Refresh File List")
|
| 330 |
-
files_df = gr.Dataframe(
|
| 331 |
-
headers=["File Name", "Chunks", "Source URL"],
|
| 332 |
-
datatype=["str", "number", "str"],
|
| 333 |
-
interactive=False
|
| 334 |
-
)
|
| 335 |
-
|
| 336 |
-
def fn_list_files():
|
| 337 |
-
count = collection.count()
|
| 338 |
-
if count == 0: return [["Database Empty", 0, "-"]]
|
| 339 |
-
|
| 340 |
-
try:
|
| 341 |
-
# Fetch all metadata (limit to 10k to prevent UI freeze)
|
| 342 |
-
limit = min(count, 10000)
|
| 343 |
-
data = collection.get(limit=limit, include=["metadatas"])
|
| 344 |
-
|
| 345 |
-
if not data or 'metadatas' not in data or data['metadatas'] is None:
|
| 346 |
-
return [["Error: No metadata found", 0, "-"]]
|
| 347 |
-
|
| 348 |
-
# Aggregate stats
|
| 349 |
-
file_counts = {} # filename -> count
|
| 350 |
-
file_urls = {} # filename -> url
|
| 351 |
-
|
| 352 |
-
for meta in data['metadatas']:
|
| 353 |
-
if meta is None: continue # Skip None entries
|
| 354 |
-
fname = meta.get("file_name", "unknown")
|
| 355 |
-
url = meta.get("url", "-")
|
| 356 |
-
file_counts[fname] = file_counts.get(fname, 0) + 1
|
| 357 |
-
file_urls[fname] = url
|
| 358 |
-
|
| 359 |
-
# Convert to list
|
| 360 |
-
output = []
|
| 361 |
-
for fname, count in file_counts.items():
|
| 362 |
-
output.append([fname, count, file_urls[fname]])
|
| 363 |
-
|
| 364 |
-
if not output:
|
| 365 |
-
return [["No files found in metadata", 0, "-"]]
|
| 366 |
-
|
| 367 |
-
# Sort by chunk count (descending)
|
| 368 |
-
output.sort(key=lambda x: x[1], reverse=True)
|
| 369 |
-
return output
|
| 370 |
-
except Exception as e:
|
| 371 |
-
import traceback
|
| 372 |
-
traceback.print_exc()
|
| 373 |
-
return [[f"Error: {str(e)}", 0, "-"]]
|
| 374 |
-
|
| 375 |
-
ingest_btn.click(fn_ingest, inputs=repo_input, outputs=[ingest_status])
|
| 376 |
-
reset_btn.click(fn=reset_db, inputs=[], outputs=[ingest_status])
|
| 377 |
-
list_files_btn.click(fn_list_files, inputs=[], outputs=[files_df])
|
| 378 |
-
|
| 379 |
-
# --- TAB 2: SEARCH ---
|
| 380 |
-
with gr.Tab("2. Semantic Search"):
|
| 381 |
-
gr.Markdown("### Search the Ingested Code")
|
| 382 |
with gr.Row():
|
| 383 |
-
|
| 384 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 385 |
|
| 386 |
-
results_df = gr.Dataframe(
|
| 387 |
-
headers=["File Name", "Score", "Code Snippet"],
|
| 388 |
-
datatype=["str", "str", "str"],
|
| 389 |
-
interactive=False,
|
| 390 |
-
wrap=True
|
| 391 |
-
)
|
| 392 |
-
search_btn.click(fn=search_codebase, inputs=search_box, outputs=results_df)
|
| 393 |
-
|
| 394 |
-
# --- TAB 3: CODE SEARCH ---
|
| 395 |
-
with gr.Tab("3. Find Similar Code"):
|
| 396 |
-
gr.Markdown("### Code-to-Code Retrieval")
|
| 397 |
with gr.Row():
|
| 398 |
-
|
| 399 |
-
|
|
|
|
| 400 |
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
with gr.Tab("4. Deployment Monitoring"):
|
| 411 |
gr.Markdown("### Embedding Quality Analysis")
|
| 412 |
-
|
| 413 |
|
| 414 |
with gr.Row():
|
| 415 |
-
|
| 416 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 417 |
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
gr.Markdown("
|
| 421 |
-
with gr.Row():
|
| 422 |
-
eval_size = gr.Slider(minimum=10, maximum=1000, value=50, step=10, label="Sample Size (Chunks)")
|
| 423 |
-
eval_btn = gr.Button("Run Retrieval Evaluation", variant="primary")
|
| 424 |
|
| 425 |
-
|
| 426 |
|
| 427 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 428 |
|
| 429 |
if __name__ == "__main__":
|
| 430 |
-
demo.queue().launch()
|
|
|
|
| 10 |
import chromadb
|
| 11 |
from chromadb.config import Settings
|
| 12 |
import uuid
|
| 13 |
+
import tempfile
|
| 14 |
|
| 15 |
+
# --- Add scripts to path ---
|
| 16 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
|
|
|
|
| 17 |
from scripts.core.ingestion.ingest import GitCrawler
|
| 18 |
from scripts.core.ingestion.chunk import RepoChunker
|
| 19 |
|
| 20 |
# --- Configuration ---
|
| 21 |
+
BASELINE_MODEL = "microsoft/codebert-base"
|
| 22 |
+
FINETUNED_MODEL = "shubharuidas/codebert-base-code-embed-mrl-langchain-langgraph"
|
| 23 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 24 |
+
DB_DIR = Path(os.path.abspath("data/chroma_db_comparison"))
|
| 25 |
DB_DIR.mkdir(parents=True, exist_ok=True)
|
| 26 |
|
| 27 |
+
print(f"Loading models on {DEVICE}...")
|
| 28 |
+
print("1. Loading baseline model...")
|
| 29 |
+
baseline_tokenizer = AutoTokenizer.from_pretrained(BASELINE_MODEL)
|
| 30 |
+
baseline_model = AutoModel.from_pretrained(BASELINE_MODEL)
|
| 31 |
+
baseline_model.to(DEVICE)
|
| 32 |
+
baseline_model.eval()
|
| 33 |
|
| 34 |
+
print("2. Loading fine-tuned model...")
|
| 35 |
+
finetuned_tokenizer = AutoTokenizer.from_pretrained(FINETUNED_MODEL)
|
| 36 |
+
finetuned_model = AutoModel.from_pretrained(FINETUNED_MODEL)
|
| 37 |
+
finetuned_model.to(DEVICE)
|
| 38 |
+
finetuned_model.eval()
|
| 39 |
+
print("Both models loaded!")
|
| 40 |
+
|
| 41 |
+
# --- ChromaDB Setup ---
|
| 42 |
chroma_client = chromadb.PersistentClient(path=str(DB_DIR))
|
| 43 |
+
baseline_collection = chroma_client.get_or_create_collection(name="baseline_rag", metadata={"hnsw:space": "cosine"})
|
| 44 |
+
finetuned_collection = chroma_client.get_or_create_collection(name="finetuned_rag", metadata={"hnsw:space": "cosine"})
|
| 45 |
|
| 46 |
+
# --- Embedding Functions ---
|
| 47 |
+
def compute_baseline_embeddings(text_list):
|
| 48 |
+
if not text_list: return None
|
| 49 |
+
inputs = baseline_tokenizer(text_list, return_tensors="pt", padding=True, truncation=True, max_length=512).to(DEVICE)
|
| 50 |
+
with torch.no_grad():
|
| 51 |
+
out = baseline_model(**inputs)
|
| 52 |
+
emb = out.last_hidden_state.mean(dim=1)
|
| 53 |
+
return F.normalize(emb, p=2, dim=1)
|
| 54 |
|
| 55 |
+
def compute_finetuned_embeddings(text_list):
|
|
|
|
|
|
|
| 56 |
if not text_list: return None
|
| 57 |
+
inputs = finetuned_tokenizer(text_list, return_tensors="pt", padding=True, truncation=True, max_length=512).to(DEVICE)
|
|
|
|
| 58 |
with torch.no_grad():
|
| 59 |
+
out = finetuned_model(**inputs)
|
| 60 |
emb = out.last_hidden_state.mean(dim=1)
|
| 61 |
return F.normalize(emb, p=2, dim=1)
|
| 62 |
|
| 63 |
+
# --- Reset Functions ---
|
| 64 |
+
def reset_baseline():
|
| 65 |
+
chroma_client.delete_collection("baseline_rag")
|
| 66 |
+
global baseline_collection
|
| 67 |
+
baseline_collection = chroma_client.get_or_create_collection(name="baseline_rag", metadata={"hnsw:space": "cosine"})
|
| 68 |
+
return "Baseline database reset."
|
| 69 |
+
|
| 70 |
+
def reset_finetuned():
|
| 71 |
+
chroma_client.delete_collection("finetuned_rag")
|
| 72 |
+
global finetuned_collection
|
| 73 |
+
finetuned_collection = chroma_client.get_or_create_collection(name="finetuned_rag", metadata={"hnsw:space": "cosine"})
|
| 74 |
+
return "Fine-tuned database reset."
|
| 75 |
+
|
| 76 |
+
# --- Database Inspector Functions ---
|
| 77 |
+
def list_baseline_files():
|
| 78 |
+
count = baseline_collection.count()
|
| 79 |
+
if count == 0:
|
| 80 |
+
return [["No data indexed yet", "-", "-"]]
|
| 81 |
+
|
| 82 |
try:
|
| 83 |
+
data = baseline_collection.get(limit=min(count, 1000), include=["metadatas"])
|
| 84 |
+
file_stats = {}
|
| 85 |
+
for meta in data['metadatas']:
|
| 86 |
+
fname = meta.get("file_name", "unknown")
|
| 87 |
+
url = meta.get("url", "unknown")
|
| 88 |
+
if fname not in file_stats:
|
| 89 |
+
file_stats[fname] = {"count": 0, "url": url}
|
| 90 |
+
file_stats[fname]["count"] += 1
|
| 91 |
+
|
| 92 |
+
results = [[fname, stats["count"], stats["url"]] for fname, stats in file_stats.items()]
|
| 93 |
+
return sorted(results, key=lambda x: x[1], reverse=True)
|
| 94 |
except Exception as e:
|
| 95 |
+
return [[f"Error: {str(e)}", "-", "-"]]
|
| 96 |
|
| 97 |
+
def list_finetuned_files():
|
| 98 |
+
count = finetuned_collection.count()
|
| 99 |
+
if count == 0:
|
| 100 |
+
return [["No data indexed yet", "-", "-"]]
|
| 101 |
|
| 102 |
+
try:
|
| 103 |
+
data = finetuned_collection.get(limit=min(count, 1000), include=["metadatas"])
|
| 104 |
+
file_stats = {}
|
| 105 |
+
for meta in data['metadatas']:
|
| 106 |
+
fname = meta.get("file_name", "unknown")
|
| 107 |
+
url = meta.get("url", "unknown")
|
| 108 |
+
if fname not in file_stats:
|
| 109 |
+
file_stats[fname] = {"count": 0, "url": url}
|
| 110 |
+
file_stats[fname]["count"] += 1
|
| 111 |
+
|
| 112 |
+
results = [[fname, stats["count"], stats["url"]] for fname, stats in file_stats.items()]
|
| 113 |
+
return sorted(results, key=lambda x: x[1], reverse=True)
|
| 114 |
+
except Exception as e:
|
| 115 |
+
return [[f"Error: {str(e)}", "-", "-"]]
|
| 116 |
+
|
| 117 |
+
# --- Search Functions ---
|
| 118 |
+
def search_baseline(query, top_k=5):
|
| 119 |
+
if baseline_collection.count() == 0: return []
|
| 120 |
+
query_emb = compute_baseline_embeddings([query])
|
| 121 |
if query_emb is None: return []
|
|
|
|
|
|
|
| 122 |
query_vec = query_emb.cpu().numpy().tolist()[0]
|
| 123 |
+
results = baseline_collection.query(query_embeddings=[query_vec], n_results=min(top_k, baseline_collection.count()), include=["metadatas", "documents", "distances"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
output = []
|
| 125 |
if results['ids']:
|
| 126 |
for i in range(len(results['ids'][0])):
|
| 127 |
meta = results['metadatas'][0][i]
|
| 128 |
code = results['documents'][0][i]
|
| 129 |
dist = results['distances'][0][i]
|
| 130 |
+
score = 1 - dist
|
| 131 |
+
output.append([meta.get("file_name", "unknown"), f"{score:.4f}", code[:300] + "..."])
|
|
|
|
|
|
|
|
|
|
| 132 |
return output
|
| 133 |
|
| 134 |
+
def search_finetuned(query, top_k=5):
|
| 135 |
+
if finetuned_collection.count() == 0: return []
|
| 136 |
+
query_emb = compute_finetuned_embeddings([query])
|
| 137 |
+
if query_emb is None: return []
|
| 138 |
+
query_vec = query_emb.cpu().numpy().tolist()[0]
|
| 139 |
+
results = finetuned_collection.query(query_embeddings=[query_vec], n_results=min(top_k, finetuned_collection.count()), include=["metadatas", "documents", "distances"])
|
| 140 |
+
output = []
|
| 141 |
+
if results['ids']:
|
| 142 |
+
for i in range(len(results['ids'][0])):
|
| 143 |
+
meta = results['metadatas'][0][i]
|
| 144 |
+
code = results['documents'][0][i]
|
| 145 |
+
dist = results['distances'][0][i]
|
| 146 |
+
score = 1 - dist
|
| 147 |
+
output.append([meta.get("file_name", "unknown"), f"{score:.4f}", code[:300] + "..."])
|
| 148 |
+
return output
|
| 149 |
+
|
| 150 |
+
def search_comparison(query, top_k=5):
|
| 151 |
+
baseline_results = search_baseline(query, top_k)
|
| 152 |
+
finetuned_results = search_finetuned(query, top_k)
|
| 153 |
+
return baseline_results, finetuned_results
|
| 154 |
+
|
| 155 |
+
# --- Ingestion Functions ---
|
| 156 |
+
def ingest_from_url(repo_url):
|
| 157 |
if not repo_url.startswith("http"):
|
| 158 |
+
yield "Invalid URL"
|
| 159 |
+
return
|
| 160 |
|
| 161 |
DATA_DIR = Path(os.path.abspath("data/raw_ingest"))
|
| 162 |
import stat
|
| 163 |
def remove_readonly(func, path, _):
|
| 164 |
os.chmod(path, stat.S_IWRITE)
|
| 165 |
func(path)
|
| 166 |
+
|
| 167 |
try:
|
|
|
|
| 168 |
if DATA_DIR.exists():
|
| 169 |
shutil.rmtree(DATA_DIR, onerror=remove_readonly)
|
| 170 |
|
|
|
|
| 171 |
yield f"Cloning {repo_url}..."
|
| 172 |
crawler = GitCrawler(cache_dir=DATA_DIR)
|
| 173 |
repo_path = crawler.clone_repository(repo_url)
|
|
|
|
| 174 |
if not repo_path:
|
| 175 |
+
yield "Failed to clone repository."
|
| 176 |
+
return
|
| 177 |
+
|
| 178 |
yield "Listing files..."
|
| 179 |
files = crawler.list_files(repo_path, extensions={'.py', '.md', '.json', '.js', '.ts', '.java', '.cpp'})
|
| 180 |
if isinstance(files, tuple): files = [f.path for f in files[0]]
|
|
|
|
| 193 |
all_chunks.extend(file_chunks)
|
| 194 |
except Exception as e:
|
| 195 |
print(f"Skipping {file_path}: {e}")
|
| 196 |
+
|
| 197 |
if not all_chunks:
|
| 198 |
+
yield "No valid chunks found."
|
| 199 |
+
return
|
| 200 |
+
|
| 201 |
total_chunks = len(all_chunks)
|
| 202 |
+
yield f"Generated {total_chunks} chunks. Embedding (BASELINE)..."
|
| 203 |
|
| 204 |
batch_size = 64
|
| 205 |
+
# Index with baseline
|
| 206 |
for i in range(0, total_chunks, batch_size):
|
| 207 |
batch = all_chunks[i:i+batch_size]
|
|
|
|
|
|
|
| 208 |
texts = [c.code for c in batch]
|
| 209 |
ids = [str(uuid.uuid4()) for _ in batch]
|
| 210 |
metadatas = [{"file_name": Path(c.file_path).name, "url": repo_url} for c in batch]
|
| 211 |
|
| 212 |
+
embeddings = compute_baseline_embeddings(texts)
|
|
|
|
| 213 |
if embeddings is not None:
|
| 214 |
+
baseline_collection.add(ids=ids, embeddings=embeddings.cpu().numpy().tolist(), metadatas=metadatas, documents=texts)
|
| 215 |
+
yield f"Baseline: {min(i+batch_size, total_chunks)}/{total_chunks}"
|
| 216 |
+
|
| 217 |
+
yield f"Embedding (FINE-TUNED)..."
|
| 218 |
+
# Index with fine-tuned
|
| 219 |
+
for i in range(0, total_chunks, batch_size):
|
| 220 |
+
batch = all_chunks[i:i+batch_size]
|
| 221 |
+
texts = [c.code for c in batch]
|
| 222 |
+
ids = [str(uuid.uuid4()) for _ in batch]
|
| 223 |
+
metadatas = [{"file_name": Path(c.file_path).name, "url": repo_url} for c in batch]
|
| 224 |
|
| 225 |
+
embeddings = compute_finetuned_embeddings(texts)
|
| 226 |
+
if embeddings is not None:
|
| 227 |
+
finetuned_collection.add(ids=ids, embeddings=embeddings.cpu().numpy().tolist(), metadatas=metadatas, documents=texts)
|
| 228 |
+
yield f"Fine-tuned: {min(i+batch_size, total_chunks)}/{total_chunks}"
|
| 229 |
+
|
| 230 |
+
yield f"SUCCESS! Indexed {total_chunks} chunks in both databases."
|
| 231 |
+
except Exception as e:
|
| 232 |
+
import traceback
|
| 233 |
+
traceback.print_exc()
|
| 234 |
+
yield f"Error: {str(e)}"
|
| 235 |
+
|
| 236 |
+
def ingest_from_files(files):
|
| 237 |
+
if not files or len(files) == 0:
|
| 238 |
+
yield "No files uploaded."
|
| 239 |
+
return
|
| 240 |
+
|
| 241 |
+
try:
|
| 242 |
+
yield f"Processing {len(files)} file(s)..."
|
| 243 |
+
|
| 244 |
+
chunker = RepoChunker()
|
| 245 |
+
all_chunks = []
|
| 246 |
+
|
| 247 |
+
for i, file in enumerate(files):
|
| 248 |
+
yield f"Chunking file {i+1}/{len(files)}: {Path(file.name).name}"
|
| 249 |
+
try:
|
| 250 |
+
# Gradio file upload: file.name contains the temp path
|
| 251 |
+
file_path = Path(file.name)
|
| 252 |
+
meta = {"file_name": file_path.name, "url": "uploaded"}
|
| 253 |
+
file_chunks = chunker.chunk_file(file_path, repo_metadata=meta)
|
| 254 |
+
all_chunks.extend(file_chunks)
|
| 255 |
+
except Exception as e:
|
| 256 |
+
yield f"Error chunking {Path(file.name).name}: {str(e)}"
|
| 257 |
+
import traceback
|
| 258 |
+
traceback.print_exc()
|
| 259 |
|
|
|
|
|
|
|
| 260 |
|
| 261 |
+
if not all_chunks:
|
| 262 |
+
yield "No valid chunks found."
|
| 263 |
+
return
|
| 264 |
+
|
| 265 |
+
total_chunks = len(all_chunks)
|
| 266 |
+
yield f"Generated {total_chunks} chunks. Embedding (BASELINE)..."
|
| 267 |
+
|
| 268 |
+
batch_size = 64
|
| 269 |
+
for i in range(0, total_chunks, batch_size):
|
| 270 |
+
batch = all_chunks[i:i+batch_size]
|
| 271 |
+
texts = [c.code for c in batch]
|
| 272 |
+
ids = [str(uuid.uuid4()) for _ in batch]
|
| 273 |
+
metadatas = [{"file_name": Path(c.file_path).name, "url": "uploaded"} for c in batch]
|
| 274 |
+
|
| 275 |
+
embeddings = compute_baseline_embeddings(texts)
|
| 276 |
+
if embeddings is not None:
|
| 277 |
+
baseline_collection.add(ids=ids, embeddings=embeddings.cpu().numpy().tolist(), metadatas=metadatas, documents=texts)
|
| 278 |
+
yield f"Baseline: {min(i+batch_size, total_chunks)}/{total_chunks}"
|
| 279 |
+
|
| 280 |
+
yield f"Embedding (FINE-TUNED)..."
|
| 281 |
+
for i in range(0, total_chunks, batch_size):
|
| 282 |
+
batch = all_chunks[i:i+batch_size]
|
| 283 |
+
texts = [c.code for c in batch]
|
| 284 |
+
ids = [str(uuid.uuid4()) for _ in batch]
|
| 285 |
+
metadatas = [{"file_name": Path(c.file_path).name, "url": "uploaded"} for c in batch]
|
| 286 |
+
|
| 287 |
+
embeddings = compute_finetuned_embeddings(texts)
|
| 288 |
+
if embeddings is not None:
|
| 289 |
+
finetuned_collection.add(ids=ids, embeddings=embeddings.cpu().numpy().tolist(), metadatas=metadatas, documents=texts)
|
| 290 |
+
yield f"Fine-tuned: {min(i+batch_size, total_chunks)}/{total_chunks}"
|
| 291 |
+
|
| 292 |
+
yield f"SUCCESS! Indexed {total_chunks} chunks from uploaded files."
|
| 293 |
except Exception as e:
|
| 294 |
import traceback
|
| 295 |
traceback.print_exc()
|
| 296 |
yield f"Error: {str(e)}"
|
| 297 |
|
| 298 |
+
# --- Analysis & Evaluation Functions ---
|
| 299 |
+
def analyze_embeddings_baseline():
|
| 300 |
+
count = baseline_collection.count()
|
| 301 |
if count < 5:
|
| 302 |
return "Not enough data (Need > 5 chunks).", None
|
| 303 |
|
| 304 |
try:
|
|
|
|
| 305 |
limit = min(count, 2000)
|
| 306 |
+
data = baseline_collection.get(limit=limit, include=["embeddings", "metadatas"])
|
| 307 |
|
| 308 |
X = torch.tensor(data['embeddings'])
|
|
|
|
|
|
|
| 309 |
X_mean = torch.mean(X, 0)
|
| 310 |
X_centered = X - X_mean
|
| 311 |
U, S, V = torch.pca_lowrank(X_centered, q=2)
|
| 312 |
projected = torch.matmul(X_centered, V[:, :2]).numpy()
|
| 313 |
|
|
|
|
| 314 |
indices = torch.randint(0, len(X), (min(100, len(X)),))
|
| 315 |
sample = X[indices]
|
| 316 |
sim_matrix = torch.mm(sample, sample.t())
|
|
|
|
| 319 |
diversity_score = 1.0 - avg_sim
|
| 320 |
|
| 321 |
metrics = (
|
| 322 |
+
f"BASELINE MODEL\n"
|
| 323 |
f"Total Chunks: {count}\n"
|
| 324 |
+
f"Analyzed: {len(X)}\n"
|
| 325 |
f"Diversity Score: {diversity_score:.4f}\n"
|
| 326 |
+
f"Avg Similarity: {avg_sim:.4f}"
|
| 327 |
)
|
| 328 |
|
| 329 |
plot_df = pd.DataFrame({
|
|
|
|
| 332 |
"topic": [m.get("file_name", "unknown") for m in data['metadatas']]
|
| 333 |
})
|
| 334 |
|
| 335 |
+
return metrics, gr.ScatterPlot(value=plot_df, x="x", y="y", color="topic", title="Baseline Semantic Space", tooltip="topic")
|
| 336 |
+
except Exception as e:
|
| 337 |
+
import traceback
|
| 338 |
+
traceback.print_exc()
|
| 339 |
+
return f"Error: {e}", None
|
| 340 |
+
|
| 341 |
+
def analyze_embeddings_finetuned():
|
| 342 |
+
count = finetuned_collection.count()
|
| 343 |
+
if count < 5:
|
| 344 |
+
return "Not enough data (Need > 5 chunks).", None
|
| 345 |
+
|
| 346 |
+
try:
|
| 347 |
+
limit = min(count, 2000)
|
| 348 |
+
data = finetuned_collection.get(limit=limit, include=["embeddings", "metadatas"])
|
| 349 |
+
|
| 350 |
+
X = torch.tensor(data['embeddings'])
|
| 351 |
+
X_mean = torch.mean(X, 0)
|
| 352 |
+
X_centered = X - X_mean
|
| 353 |
+
U, S, V = torch.pca_lowrank(X_centered, q=2)
|
| 354 |
+
projected = torch.matmul(X_centered, V[:, :2]).numpy()
|
| 355 |
+
|
| 356 |
+
indices = torch.randint(0, len(X), (min(100, len(X)),))
|
| 357 |
+
sample = X[indices]
|
| 358 |
+
sim_matrix = torch.mm(sample, sample.t())
|
| 359 |
+
mask = ~torch.eye(len(sample), dtype=bool)
|
| 360 |
+
avg_sim = sim_matrix[mask].mean().item()
|
| 361 |
+
diversity_score = 1.0 - avg_sim
|
| 362 |
+
|
| 363 |
+
metrics = (
|
| 364 |
+
f"FINE-TUNED MODEL\n"
|
| 365 |
+
f"Total Chunks: {count}\n"
|
| 366 |
+
f"Analyzed: {len(X)}\n"
|
| 367 |
+
f"Diversity Score: {diversity_score:.4f}\n"
|
| 368 |
+
f"Avg Similarity: {avg_sim:.4f}"
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
plot_df = pd.DataFrame({
|
| 372 |
+
"x": projected[:, 0],
|
| 373 |
+
"y": projected[:, 1],
|
| 374 |
+
"topic": [m.get("file_name", "unknown") for m in data['metadatas']]
|
| 375 |
+
})
|
| 376 |
|
| 377 |
+
return metrics, gr.ScatterPlot(value=plot_df, x="x", y="y", color="topic", title="Fine-tuned Semantic Space", tooltip="topic")
|
| 378 |
except Exception as e:
|
| 379 |
import traceback
|
| 380 |
traceback.print_exc()
|
| 381 |
+
return f"Error: {e}", None
|
| 382 |
|
| 383 |
+
def evaluate_retrieval_baseline(sample_limit):
|
| 384 |
+
count = baseline_collection.count()
|
| 385 |
if count < 10: return "Not enough data for evaluation (Need > 10 chunks)."
|
| 386 |
|
| 387 |
try:
|
| 388 |
+
fetch_limit = min(count, 2000)
|
| 389 |
+
data = baseline_collection.get(limit=fetch_limit, include=["documents"])
|
|
|
|
|
|
|
| 390 |
|
| 391 |
import random
|
| 392 |
actual_sample_size = min(sample_limit, len(data['ids']))
|
|
|
|
| 396 |
hits_at_5 = 0
|
| 397 |
mrr_sum = 0
|
| 398 |
|
| 399 |
+
yield f"BASELINE: Evaluating {actual_sample_size} chunks..."
|
|
|
|
| 400 |
|
| 401 |
for i, idx in enumerate(sample_indices):
|
| 402 |
target_id = data['ids'][idx]
|
| 403 |
code = data['documents'][idx]
|
|
|
|
|
|
|
| 404 |
query = "\n".join(code.split("\n")[:3])
|
| 405 |
+
query_emb = compute_baseline_embeddings([query]).cpu().numpy().tolist()[0]
|
| 406 |
+
results = baseline_collection.query(query_embeddings=[query_emb], n_results=10)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 407 |
found_ids = results['ids'][0]
|
| 408 |
if target_id in found_ids:
|
| 409 |
rank = found_ids.index(target_id) + 1
|
| 410 |
mrr_sum += 1.0 / rank
|
| 411 |
if rank == 1: hits_at_1 += 1
|
| 412 |
if rank <= 5: hits_at_5 += 1
|
|
|
|
| 413 |
if i % 10 == 0:
|
| 414 |
+
yield f"Baseline: {i}/{actual_sample_size}..."
|
| 415 |
|
| 416 |
recall_1 = hits_at_1 / actual_sample_size
|
| 417 |
recall_5 = hits_at_5 / actual_sample_size
|
| 418 |
mrr = mrr_sum / actual_sample_size
|
| 419 |
|
| 420 |
report = (
|
| 421 |
+
f"BASELINE EVALUATION ({actual_sample_size} chunks)\n"
|
| 422 |
+
f"{'='*40}\n"
|
| 423 |
f"Recall@1: {recall_1:.4f}\n"
|
| 424 |
f"Recall@5: {recall_5:.4f}\n"
|
| 425 |
+
f"MRR: {mrr:.4f}"
|
|
|
|
| 426 |
)
|
| 427 |
yield report
|
| 428 |
except Exception as e:
|
| 429 |
import traceback
|
| 430 |
traceback.print_exc()
|
| 431 |
+
yield f"Error: {e}"
|
| 432 |
|
| 433 |
+
def evaluate_retrieval_finetuned(sample_limit):
|
| 434 |
+
count = finetuned_collection.count()
|
| 435 |
+
if count < 10: return "Not enough data for evaluation (Need > 10 chunks)."
|
| 436 |
+
|
| 437 |
+
try:
|
| 438 |
+
fetch_limit = min(count, 2000)
|
| 439 |
+
data = finetuned_collection.get(limit=fetch_limit, include=["documents"])
|
| 440 |
+
|
| 441 |
+
import random
|
| 442 |
+
actual_sample_size = min(sample_limit, len(data['ids']))
|
| 443 |
+
sample_indices = random.sample(range(len(data['ids'])), actual_sample_size)
|
| 444 |
+
|
| 445 |
+
hits_at_1 = 0
|
| 446 |
+
hits_at_5 = 0
|
| 447 |
+
mrr_sum = 0
|
| 448 |
+
|
| 449 |
+
yield f"FINE-TUNED: Evaluating {actual_sample_size} chunks..."
|
| 450 |
+
|
| 451 |
+
for i, idx in enumerate(sample_indices):
|
| 452 |
+
target_id = data['ids'][idx]
|
| 453 |
+
code = data['documents'][idx]
|
| 454 |
+
query = "\n".join(code.split("\n")[:3])
|
| 455 |
+
query_emb = compute_finetuned_embeddings([query]).cpu().numpy().tolist()[0]
|
| 456 |
+
results = finetuned_collection.query(query_embeddings=[query_emb], n_results=10)
|
| 457 |
+
found_ids = results['ids'][0]
|
| 458 |
+
if target_id in found_ids:
|
| 459 |
+
rank = found_ids.index(target_id) + 1
|
| 460 |
+
mrr_sum += 1.0 / rank
|
| 461 |
+
if rank == 1: hits_at_1 += 1
|
| 462 |
+
if rank <= 5: hits_at_5 += 1
|
| 463 |
+
if i % 10 == 0:
|
| 464 |
+
yield f"Fine-tuned: {i}/{actual_sample_size}..."
|
| 465 |
+
|
| 466 |
+
recall_1 = hits_at_1 / actual_sample_size
|
| 467 |
+
recall_5 = hits_at_5 / actual_sample_size
|
| 468 |
+
mrr = mrr_sum / actual_sample_size
|
| 469 |
+
|
| 470 |
+
report = (
|
| 471 |
+
f"FINE-TUNED EVALUATION ({actual_sample_size} chunks)\n"
|
| 472 |
+
f"{'='*40}\n"
|
| 473 |
+
f"Recall@1: {recall_1:.4f}\n"
|
| 474 |
+
f"Recall@5: {recall_5:.4f}\n"
|
| 475 |
+
f"MRR: {mrr:.4f}"
|
| 476 |
+
)
|
| 477 |
+
yield report
|
| 478 |
+
except Exception as e:
|
| 479 |
+
import traceback
|
| 480 |
+
traceback.print_exc()
|
| 481 |
+
yield f"Error: {e}"
|
| 482 |
|
| 483 |
+
# --- UI ---
|
| 484 |
+
theme = gr.themes.Soft(primary_hue="slate", neutral_hue="slate", spacing_size="sm", radius_size="md").set(body_background_fill="*neutral_50", block_background_fill="white", block_border_width="1px", block_title_text_weight="600")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 485 |
|
| 486 |
css = """
|
| 487 |
+
h1 { text-align: center; font-family: 'Inter', sans-serif; margin-bottom: 1rem; color: #1e293b; }
|
| 488 |
+
.gradio-container { max-width: 1400px !important; margin: auto; }
|
| 489 |
+
.comparison-header { font-size: 1.1rem; font-weight: 600; color: #334155; text-align: center; padding: 0.5rem; }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 490 |
"""
|
| 491 |
|
| 492 |
+
with gr.Blocks(theme=theme, css=css, title="CodeMode - Baseline vs Fine-tuned") as demo:
|
| 493 |
+
gr.Markdown("# CodeMode: Baseline vs Fine-tuned Model Comparison")
|
| 494 |
+
gr.Markdown("Compare retrieval performance between **microsoft/codebert-base** (baseline) and **MRL-enhanced fine-tuned** model")
|
| 495 |
|
| 496 |
with gr.Tabs():
|
| 497 |
+
# TAB 1: INGEST
|
| 498 |
+
with gr.Tab("1. Ingest Code"):
|
| 499 |
+
with gr.Tabs():
|
| 500 |
+
with gr.Tab("GitHub Repository"):
|
| 501 |
+
repo_input = gr.Textbox(label="GitHub URL", placeholder="https://github.com/pallets/flask")
|
| 502 |
+
ingest_url_btn = gr.Button("Ingest from URL", variant="primary")
|
| 503 |
+
url_status = gr.Textbox(label="Status")
|
| 504 |
+
ingest_url_btn.click(ingest_from_url, inputs=repo_input, outputs=url_status)
|
| 505 |
+
|
| 506 |
+
with gr.Tab("Upload Python Files"):
|
| 507 |
+
file_upload = gr.File(label="Upload .py files", file_types=[".py"], file_count="multiple")
|
| 508 |
+
ingest_files_btn = gr.Button("Ingest Uploaded Files", variant="primary")
|
| 509 |
+
upload_status = gr.Textbox(label="Status")
|
| 510 |
+
ingest_files_btn.click(ingest_from_files, inputs=file_upload, outputs=upload_status)
|
| 511 |
|
| 512 |
with gr.Row():
|
| 513 |
+
reset_baseline_btn = gr.Button("Reset Baseline DB", variant="stop")
|
| 514 |
+
reset_finetuned_btn = gr.Button("Reset Fine-tuned DB", variant="stop")
|
| 515 |
+
reset_status = gr.Textbox(label="Reset Status")
|
| 516 |
+
|
| 517 |
+
reset_baseline_btn.click(reset_baseline, inputs=[], outputs=reset_status)
|
| 518 |
+
reset_finetuned_btn.click(reset_finetuned, inputs=[], outputs=reset_status)
|
| 519 |
+
|
| 520 |
+
gr.Markdown("---")
|
| 521 |
+
gr.Markdown("### Database Inspector")
|
| 522 |
+
gr.Markdown("View indexed files in each collection")
|
| 523 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 524 |
with gr.Row():
|
| 525 |
+
with gr.Column():
|
| 526 |
+
gr.Markdown("#### Baseline Collection")
|
| 527 |
+
inspect_baseline_btn = gr.Button("Inspect Baseline DB", variant="secondary")
|
| 528 |
+
baseline_files_df = gr.Dataframe(
|
| 529 |
+
headers=["File Name", "Chunks", "Source URL"],
|
| 530 |
+
datatype=["str", "number", "str"],
|
| 531 |
+
interactive=False,
|
| 532 |
+
value=[["No data yet", "-", "-"]]
|
| 533 |
+
)
|
| 534 |
+
inspect_baseline_btn.click(list_baseline_files, inputs=[], outputs=baseline_files_df)
|
| 535 |
+
|
| 536 |
+
with gr.Column():
|
| 537 |
+
gr.Markdown("#### Fine-tuned Collection")
|
| 538 |
+
inspect_finetuned_btn = gr.Button("Inspect Fine-tuned DB", variant="secondary")
|
| 539 |
+
finetuned_files_df = gr.Dataframe(
|
| 540 |
+
headers=["File Name", "Chunks", "Source URL"],
|
| 541 |
+
datatype=["str", "number", "str"],
|
| 542 |
+
interactive=False,
|
| 543 |
+
value=[["No data yet", "-", "-"]]
|
| 544 |
+
)
|
| 545 |
+
inspect_finetuned_btn.click(list_finetuned_files, inputs=[], outputs=finetuned_files_df)
|
| 546 |
+
|
| 547 |
+
# TAB 2: COMPARISON SEARCH
|
| 548 |
+
with gr.Tab("2. Comparison Search (Note: Semantic search is sensitive to query phrasing)"):
|
| 549 |
+
gr.Markdown("### Side-by-Side Retrieval Comparison")
|
| 550 |
+
search_query = gr.Textbox(label="Search Query", placeholder="e.g., 'Flask route decorator'")
|
| 551 |
+
compare_btn = gr.Button("Compare Models", variant="primary")
|
| 552 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 553 |
with gr.Row():
|
| 554 |
+
with gr.Column():
|
| 555 |
+
gr.Markdown("<div class='comparison-header'>BASELINE (CodeBERT)</div>", elem_classes="comparison-header")
|
| 556 |
+
baseline_results = gr.Dataframe(headers=["File", "Score", "Code Snippet"], datatype=["str", "str", "str"], interactive=False, wrap=True)
|
| 557 |
|
| 558 |
+
with gr.Column():
|
| 559 |
+
gr.Markdown("<div class='comparison-header'>FINE-TUNED (MRL-Enhanced)</div>", elem_classes="comparison-header")
|
| 560 |
+
finetuned_results = gr.Dataframe(headers=["File", "Score", "Code Snippet"], datatype=["str", "str", "str"], interactive=False, wrap=True)
|
| 561 |
+
|
| 562 |
+
compare_btn.click(search_comparison, inputs=search_query, outputs=[baseline_results, finetuned_results])
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
# TAB 3: DEPLOYMENT MONITORING
|
| 566 |
+
with gr.Tab("3. Deployment Monitoring"):
|
|
|
|
| 567 |
gr.Markdown("### Embedding Quality Analysis")
|
| 568 |
+
gr.Markdown("Analyze the semantic space distribution and diversity of embeddings")
|
| 569 |
|
| 570 |
with gr.Row():
|
| 571 |
+
with gr.Column():
|
| 572 |
+
gr.Markdown("#### Baseline Model")
|
| 573 |
+
analyze_baseline_btn = gr.Button("Analyze Baseline Embeddings", variant="secondary")
|
| 574 |
+
baseline_metrics = gr.Textbox(label="Baseline Metrics")
|
| 575 |
+
baseline_plot = gr.ScatterPlot(label="Baseline Semantic Space (PCA)")
|
| 576 |
+
analyze_baseline_btn.click(analyze_embeddings_baseline, inputs=[], outputs=[baseline_metrics, baseline_plot])
|
| 577 |
+
|
| 578 |
+
with gr.Column():
|
| 579 |
+
gr.Markdown("#### Fine-tuned Model")
|
| 580 |
+
analyze_finetuned_btn = gr.Button("Analyze Fine-tuned Embeddings", variant="secondary")
|
| 581 |
+
finetuned_metrics = gr.Textbox(label="Fine-tuned Metrics")
|
| 582 |
+
finetuned_plot = gr.ScatterPlot(label="Fine-tuned Semantic Space (PCA)")
|
| 583 |
+
analyze_finetuned_btn.click(analyze_embeddings_finetuned, inputs=[], outputs=[finetuned_metrics, finetuned_plot])
|
| 584 |
|
| 585 |
+
gr.Markdown("---")
|
| 586 |
+
gr.Markdown("### Retrieval Performance Evaluation")
|
| 587 |
+
gr.Markdown("Evaluate retrieval accuracy using synthetic queries (query = first 3 lines of code)")
|
|
|
|
|
|
|
|
|
|
| 588 |
|
| 589 |
+
eval_size = gr.Slider(minimum=10, maximum=500, value=50, step=10, label="Sample Size (Chunks to Evaluate)")
|
| 590 |
|
| 591 |
+
with gr.Row():
|
| 592 |
+
with gr.Column():
|
| 593 |
+
gr.Markdown("#### Baseline Evaluation")
|
| 594 |
+
eval_baseline_btn = gr.Button("Run Baseline Evaluation", variant="primary")
|
| 595 |
+
baseline_eval_output = gr.Textbox(label="Baseline Results")
|
| 596 |
+
eval_baseline_btn.click(evaluate_retrieval_baseline, inputs=[eval_size], outputs=baseline_eval_output)
|
| 597 |
+
|
| 598 |
+
with gr.Column():
|
| 599 |
+
gr.Markdown("#### Fine-tuned Evaluation")
|
| 600 |
+
eval_finetuned_btn = gr.Button("Run Fine-tuned Evaluation", variant="primary")
|
| 601 |
+
finetuned_eval_output = gr.Textbox(label="Fine-tuned Results")
|
| 602 |
+
eval_finetuned_btn.click(evaluate_retrieval_finetuned, inputs=[eval_size], outputs=finetuned_eval_output)
|
| 603 |
|
| 604 |
if __name__ == "__main__":
|
| 605 |
+
demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=False)
|