Arslan1997 committed on
Commit
463a6a7
·
1 Parent(s): f04f084
Files changed (1) hide show
  1. src/routes/code_routes.py +29 -118
src/routes/code_routes.py CHANGED
@@ -283,26 +283,17 @@ def extract_relevant_error_section(error_message: str) -> str:
283
 
284
  async def fix_code_with_dspy(code: str, error: str, dataset_context: str = "", datasets: dict = None):
285
  """
286
- Fix code using DSPy with dataset context and actual datasets
287
  """
288
  try:
289
- # Create score function with actual datasets
290
- def create_score_code_with_datasets(datasets_dict):
291
- def score_code_with_datasets(args, pred):
292
- return score_code(args, pred, datasets=datasets_dict) # Fixed: use datasets= instead of session_state_datasets=
293
- return score_code_with_datasets
294
-
295
- # Create refine_fixer with datasets
296
- if datasets:
297
- score_fn = create_score_code_with_datasets(datasets)
298
- else:
299
- score_fn = score_code # Fallback to original function
300
-
301
  refine_fixer = dspy.Refine(
302
- module=dspy.Predict(code_fix),
303
  N=3,
304
  threshold=1.0,
305
- reward_fn=score_fn,
306
  fail_count=3
307
  )
308
 
@@ -311,115 +302,35 @@ async def fix_code_with_dspy(code: str, error: str, dataset_context: str = "", d
311
  if not anthropic_key:
312
  raise ValueError("ANTHROPIC_API_KEY environment variable is not set")
313
 
314
- # Find the blocks with errors
315
- faulty_blocks = identify_error_blocks(code, error)
316
-
317
- if not faulty_blocks:
318
- # If no specific errors found, fix the entire code using refine
319
- try:
320
- # Create the LM instance that will be used
321
- # thread_lm = dspy.LM("anthropic/claude-3-5-sonnet-latest", api_key=anthropic_key, max_tokens=2500)
322
- thread_lm = MODEL_OBJECTS['claude-3-5-sonnet-latest']
323
-
324
- # Define the blocking function to run in thread
325
- def run_refine_fixer():
326
- with dspy.context(lm=thread_lm):
327
- return refine_fixer(
328
- dataset_context=str(dataset_context) or "",
329
- faulty_code=str(code) or "",
330
- error=str(error) or "",
331
- )
332
-
333
- # Use asyncio.to_thread for better async integration
334
- result = await asyncio.to_thread(run_refine_fixer)
335
- return result.fixed_code
336
-
337
- except Exception as e:
338
- logger.log_message(f"Error during refine code fixing: {str(e)}", level=logging.ERROR)
339
- raise e
340
-
341
- # Start with the original code
342
- result_code = code.replace("```python", "").replace("```", "")
343
-
344
- # Fix each faulty block separately using async refine
345
  try:
 
346
  thread_lm = MODEL_OBJECTS['claude-3-5-sonnet-latest']
347
 
348
- for agent_name, block_code, specific_error in faulty_blocks:
349
- try:
350
- # Extract inner code between the markers
351
- inner_code_match = re.search(r'#\s+\w+\s+code\s+start\s*\n([\s\S]*?)#\s+\w+\s+code\s+end', block_code)
352
- if not inner_code_match:
353
- continue
354
-
355
- inner_code = inner_code_match.group(1).strip()
356
-
357
- # Find markers
358
- start_marker_match = re.search(r'(#\s+\w+\s+code\s+start)', block_code)
359
- end_marker_match = re.search(r'(#\s+\w+\s+code\s+end)', block_code)
360
-
361
- if not start_marker_match or not end_marker_match:
362
- logger.log_message(f"Could not find start/end markers for {agent_name}", level=logging.WARNING)
363
- continue
364
-
365
- start_marker = start_marker_match.group(1)
366
- end_marker = end_marker_match.group(1)
367
-
368
- # Extract the error type and actual error message
369
- error_type = ""
370
- error_msg = specific_error
371
-
372
- # Look for common error patterns to provide focused context to the LLM
373
- error_type_match = re.search(r'(TypeError|ValueError|AttributeError|IndexError|KeyError|NameError):\s*([^\n]+)', specific_error)
374
- if error_type_match:
375
- error_type = error_type_match.group(1)
376
- error_msg = f"{error_type}: {error_type_match.group(2)}"
377
-
378
- # Add problem location if available
379
- if "Problem at this location:" in specific_error:
380
- problem_section = re.search(r'Problem at this location:([\s\S]*?)(?:\n\n|$)', specific_error)
381
- if problem_section:
382
- error_msg = f"{error_msg}\n\nProblem at: {problem_section.group(1).strip()}"
383
-
384
- # Define the blocking function to run in thread for this specific block
385
- def run_block_fixer():
386
- with dspy.context(lm=thread_lm):
387
- return refine_fixer(
388
- dataset_context=str(dataset_context) or "",
389
- faulty_code=str(inner_code) or "",
390
- error=str(error_msg) or "",
391
- )
392
-
393
- # Use asyncio.to_thread for better async integration
394
- result = await asyncio.to_thread(run_block_fixer)
395
-
396
- # Ensure the fixed code is properly stripped and doesn't include markers
397
- fixed_inner_code = result.fixed_code.strip()
398
- if fixed_inner_code.startswith('#') and 'code start' in fixed_inner_code:
399
- # If LLM included markers in response, extract only inner code
400
- inner_match = re.search(r'#\s+\w+\s+code\s+start\s*\n([\s\S]*?)#\s+\w+\s+code\s+end', fixed_inner_code)
401
- if inner_match:
402
- fixed_inner_code = inner_match.group(1).strip()
403
-
404
- # Reconstruct the block with fixed code
405
- fixed_block = f"{start_marker}\n\n{fixed_inner_code}\n\n{end_marker}"
406
-
407
- # Replace the original block with the fixed block in the full code
408
- result_code = result_code.replace(block_code, fixed_block)
409
-
410
- except Exception as e:
411
- # Log the error but continue with other blocks
412
- logger.log_message(f"Error fixing {agent_name} block: {str(e)}", level=logging.ERROR)
413
- continue
414
-
415
  except Exception as e:
416
- logger.log_message(f"Error during async code fixing: {str(e)}", level=logging.ERROR)
417
- raise e
418
-
419
- return result_code
420
  except Exception as e:
421
  logger.log_message(f"Error in fix_code_with_dspy: {str(e)}", level=logging.ERROR)
422
- raise e
423
 
424
  def get_dataset_context(df):
425
  """
 
283
 
284
  async def fix_code_with_dspy(code: str, error: str, dataset_context: str = "", datasets: dict = None):
285
  """
286
+ Fix code using DSPy Refine with datasets-aware reward function
287
  """
288
  try:
289
+ # Wrap score_code to fix datasets argument
290
+ reward_fn_with_datasets = lambda args, pred: score_code(args, pred, datasets=datasets)
291
+
 
 
 
 
 
 
 
 
 
292
  refine_fixer = dspy.Refine(
293
+ module=dspy.Predict(code_fix),
294
  N=3,
295
  threshold=1.0,
296
+ reward_fn=reward_fn_with_datasets,
297
  fail_count=3
298
  )
299
 
 
302
  if not anthropic_key:
303
  raise ValueError("ANTHROPIC_API_KEY environment variable is not set")
304
 
305
+ # Fix the entire code using refine
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
  try:
307
+ # Create the LM instance that will be used
308
  thread_lm = MODEL_OBJECTS['claude-3-5-sonnet-latest']
309
 
310
+ # Define the blocking function to run in thread
311
+ def run_refine_fixer():
312
+ with dspy.context(lm=thread_lm):
313
+ return refine_fixer(
314
+ dataset_context=str(dataset_context) or "",
315
+ faulty_code=str(code) or "",
316
+ error=str(error) or "",
317
+ )
318
+
319
+ # Use asyncio.to_thread for better async integration
320
+ result = await asyncio.to_thread(run_refine_fixer)
321
+
322
+ if not hasattr(result, 'fixed_code'):
323
+ raise ValueError("DSPy Refine did not return a result with 'fixed_code' attribute")
324
+
325
+ return result.fixed_code
326
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
327
  except Exception as e:
328
+ logger.log_message(f"Error during refine code fixing: {str(e)}", level=logging.ERROR)
329
+ raise RuntimeError(f"Code fixing failed: {str(e)}") from e
330
+
 
331
  except Exception as e:
332
  logger.log_message(f"Error in fix_code_with_dspy: {str(e)}", level=logging.ERROR)
333
+ raise RuntimeError(f"Fix code setup failed: {str(e)}") from e
334
 
335
  def get_dataset_context(df):
336
  """