Arslan1997 committed on
Commit d3f8066 · 1 Parent(s): 9eece41

added stuff

Files changed (4)
  1. app.py +1504 -4
  2. src/agents/agents.py +8 -8
  3. src/agents/deep_agents.py +40 -3
  4. src/routes/code_routes.py +126 -117
app.py CHANGED
@@ -1,77 +1,151 @@
# Standard library imports
import asyncio
import json
import logging
import os
import time
import uuid
import ast
from io import StringIO
from typing import List, Optional
from datetime import datetime, UTC

# Third-party imports
import dspy  # used directly below (dspy.configure, dspy.context, dspy.LM)
import markdown
import pandas as pd
import uvicorn
from bs4 import BeautifulSoup
from dotenv import load_dotenv
from fastapi import (
    Depends,
    FastAPI,
    File,
    Form,
    HTTPException,
    Request,
    UploadFile
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import APIKeyHeader
from llama_index.core import Document, VectorStoreIndex
from pydantic import BaseModel

# Local application imports
from scripts.format_response import format_response_to_markdown
from src.agents.agents import *
from src.agents.retrievers.retrievers import *
from src.managers.ai_manager import AI_Manager
from src.managers.session_manager import SessionManager
from src.routes.analytics_routes import router as analytics_router
from src.routes.blog_routes import router as blog_router
from src.routes.chat_routes import router as chat_router
from src.routes.code_routes import router as code_router
from src.routes.feedback_routes import router as feedback_router
from src.routes.session_routes import router as session_router, get_session_id_dependency
from src.routes.deep_analysis_routes import router as deep_analysis_router
from src.routes.templates_routes import router as templates_router
from src.schemas.query_schema import QueryRequest
from src.utils.logger import Logger

# Import deep analysis components directly
# from src.agents.try_deep_agents import deep_analysis_module
from src.agents.deep_agents import deep_analysis_module
from src.utils.generate_report import generate_html_report

from src.utils.model_registry import MODEL_OBJECTS

logger = Logger("app", see_time=True, console_log=True)
load_dotenv()
# Request models
class DeepAnalysisRequest(BaseModel):
    goal: str

class DeepAnalysisResponse(BaseModel):
    goal: str
    deep_questions: str
    deep_plan: str
    summaries: List[str]
    code: str
    plotly_figs: List
    synthesis: List[str]
    final_conclusion: str
    html_report: Optional[str] = None
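
# Illustrative payloads (not part of the commit; values are made up):
# a client POSTs a DeepAnalysisRequest such as {"goal": "What drives house prices?"}
# and a non-streaming response would be DeepAnalysisResponse-shaped, e.g.:
#
#   {
#       "goal": "What drives house prices?",
#       "deep_questions": "...", "deep_plan": "...",
#       "summaries": ["..."], "code": "...", "plotly_figs": [],
#       "synthesis": ["..."], "final_conclusion": "...",
#       "html_report": null
#   }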
 
 
 
styling_instructions = [
    {
        "category": "line_charts",
@@ -212,285 +286,566 @@ styling_instructions = [str(chart_dict) for chart_dict in styling_instructions]

# Add near the top of the file, after imports
DEFAULT_MODEL_CONFIG = {
-   "provider": os.getenv("MODEL_PROVIDER", "anthropic"),
-   "model": os.getenv("MODEL_NAME", "claude-3-5-sonnet-latest"),
-   "api_key": os.getenv("ANTHROPIC_API_KEY"),
    "temperature": float(os.getenv("TEMPERATURE", 1.0)),
    "max_tokens": int(os.getenv("MAX_TOKENS", 6000)),
    "cache": False
}

# Create default LM config but don't set it globally
default_lm = MODEL_OBJECTS[DEFAULT_MODEL_CONFIG['model']]

# lm = dspy.LM('openai/gpt-4o-mini', api_key=os.getenv("OPENAI_API_KEY"))
dspy.configure(lm=default_lm, async_max_workers=100)

# Function to get model config from session or use default
def get_session_lm(session_state):
    """Get the appropriate LM instance for a session, or the default if not configured"""
    # First check if we have a valid session-specific model config
    if session_state and isinstance(session_state, dict) and "model_config" in session_state:
        model_config = session_state["model_config"]
        if model_config and isinstance(model_config, dict) and "model" in model_config:
            # Found valid session-specific model config, use it
            provider = model_config.get("provider", "openai").lower()
            model_name = model_config.get("model", DEFAULT_MODEL_CONFIG["model"])
            # NOTE: the original branches tested `'gpt-5' or 'o1' not in model_name`,
            # which is always truthy; rewritten to check both substrings explicitly.
            if ('gpt-5' in model_name or 'o1' in model_name) and provider == 'openai':
                # OpenAI reasoning models take max_completion_tokens and a fixed temperature
                MODEL_OBJECTS[model_name].__dict__['kwargs']['max_completion_tokens'] = model_config.get("max_tokens", DEFAULT_MODEL_CONFIG["max_tokens"])
                MODEL_OBJECTS[model_name].__dict__['kwargs']['temperature'] = 1.0
            else:
                MODEL_OBJECTS[model_name].__dict__['kwargs']['max_tokens'] = model_config.get("max_tokens", DEFAULT_MODEL_CONFIG["max_tokens"])
                MODEL_OBJECTS[model_name].__dict__['kwargs']['temperature'] = model_config.get("temperature", DEFAULT_MODEL_CONFIG["temperature"])
            return MODEL_OBJECTS[model_name]

    # If no valid session config, use default (the original returned
    # MODEL_OBJECTS[model_name] here, which is unbound on this path)
    return default_lm
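
# Usage sketch for get_session_lm (hypothetical session state, not from the commit):
#
#   state = {"model_config": {"provider": "openai", "model": "gpt-4o-mini",
#                             "max_tokens": 4000, "temperature": 0.2}}
#   lm = get_session_lm(state)   # tunes and returns MODEL_OBJECTS["gpt-4o-mini"]
#   lm = get_session_lm({})      # no config -> falls back to default_lm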

# Initialize retrievers with empty data first

# clear console
def clear_console():
    os.system('cls' if os.name == 'nt' else 'clear')
262
 
263
 
 
 
 
264
  # Check for Housing.csv
 
265
  housing_csv_path = "Housing.csv"
 
266
  if not os.path.exists(housing_csv_path):
 
267
  logger.log_message(f"Housing.csv not found at {os.path.abspath(housing_csv_path)}", level=logging.ERROR)
 
268
  raise FileNotFoundError(f"Housing.csv not found at {os.path.abspath(housing_csv_path)}")
269
 
 
 
270
  # All agents are now loaded from database - no hardcoded dictionaries needed
271
 
 
 
272
  # Add session header
 
273
  X_SESSION_ID = APIKeyHeader(name="X-Session-ID", auto_error=False)
274
 
 
 

# Update AppState class to use SessionManager
class AppState:
    def __init__(self):
        self._session_manager = SessionManager(styling_instructions, {})  # Empty dict, agents loaded from DB
        self.model_config = DEFAULT_MODEL_CONFIG.copy()
        # Update the SessionManager with the current model_config
        self._session_manager._app_model_config = self.model_config
        self.ai_manager = AI_Manager()
        self.chat_name_agent = chat_history_name_agent
        # Initialize deep analysis module
        self.deep_analyzer = None

    def get_session_state(self, session_id: str):
        """Get or create session-specific state using the SessionManager"""
        return self._session_manager.get_session_state(session_id)

    def clear_session_state(self, session_id: str):
        """Clear session-specific state using the SessionManager"""
        self._session_manager.clear_session_state(session_id)

    def update_session_dataset(self, session_id: str, datasets, names, desc, pre_generated=False):
        """Update dataset for a specific session using the SessionManager"""
        self._session_manager.update_session_dataset(session_id, datasets, names, desc, pre_generated=pre_generated)

    def reset_session_to_default(self, session_id: str):
        """Reset a session to use the default dataset using the SessionManager"""
        self._session_manager.reset_session_to_default(session_id)

    def set_session_user(self, session_id: str, user_id: int, chat_id: int = None):
        """Associate a user with a session using the SessionManager"""
        return self._session_manager.set_session_user(session_id, user_id, chat_id)

    def get_ai_manager(self):
        """Get the AI Manager instance"""
        return self.ai_manager

    def get_provider_for_model(self, model_name):
        return self.ai_manager.get_provider_for_model(model_name)

    def calculate_cost(self, model_name, input_tokens, output_tokens):
        return self.ai_manager.calculate_cost(model_name, input_tokens, output_tokens)

    def save_usage_to_db(self, user_id, chat_id, model_name, provider, prompt_tokens, completion_tokens, total_tokens, query_size, response_size, cost, request_time_ms, is_streaming=False):
        return self.ai_manager.save_usage_to_db(user_id, chat_id, model_name, provider, prompt_tokens, completion_tokens, total_tokens, query_size, response_size, round(cost, 7), request_time_ms, is_streaming)

    def get_tokenizer(self):
        return self.ai_manager.tokenizer

    def get_chat_history_name_agent(self):
        return dspy.Predict(self.chat_name_agent)

    def get_deep_analyzer(self, session_id: str):
        """Get or create deep analysis module for a session"""
        session_state = self.get_session_state(session_id)
        user_id = session_state.get("user_id")

        # Check if we need to recreate the deep analyzer (user changed or doesn't exist)
        current_analyzer = session_state.get('deep_analyzer')
        analyzer_user_id = session_state.get('deep_analyzer_user_id')

        logger.log_message(f"Deep analyzer check - session: {session_id}, current_user: {user_id}, analyzer_user: {analyzer_user_id}, has_analyzer: {current_analyzer is not None}", level=logging.INFO)

        # NOTE: the original third condition was `not hasattr(session_state, 'deep_analyzer')`,
        # which is always True for a plain dict; rewritten as a key check.
        if (not current_analyzer or
                analyzer_user_id != user_id or
                'deep_analyzer' not in session_state):

            logger.log_message(f"Creating/recreating deep analyzer for session {session_id}, user_id: {user_id} (reason: analyzer_exists={current_analyzer is not None}, user_match={analyzer_user_id == user_id})", level=logging.INFO)

            # Load user-enabled agents from database using preference system
            from src.db.init_db import session_factory
            from src.agents.agents import load_user_enabled_templates_for_planner_from_db

            db_session = session_factory()
            try:
                # Load user-enabled agents for planner (respects preferences)
                if user_id:
                    enabled_agents_dict = load_user_enabled_templates_for_planner_from_db(user_id, db_session)
                    logger.log_message(f"Deep analyzer loaded {len(enabled_agents_dict)} enabled agents for user {user_id}: {list(enabled_agents_dict.keys())}", level=logging.INFO)

                    if not enabled_agents_dict:
                        logger.log_message(f"WARNING: No enabled agents found for user {user_id}, falling back to defaults", level=logging.WARNING)
                        # Fallback to default agents if no enabled agents
                        from src.agents.agents import preprocessing_agent, statistical_analytics_agent, sk_learn_agent, data_viz_agent
                        enabled_agents_dict = {
                            "preprocessing_agent": preprocessing_agent,
                            "statistical_analytics_agent": statistical_analytics_agent,
                            "sk_learn_agent": sk_learn_agent,
                            "data_viz_agent": data_viz_agent
                        }
                else:
                    # Fallback to default agents if no user_id
                    logger.log_message("No user_id in session, loading default agents for deep analysis", level=logging.WARNING)
                    from src.agents.agents import preprocessing_agent, statistical_analytics_agent, sk_learn_agent, data_viz_agent
                    enabled_agents_dict = {
                        "preprocessing_agent": preprocessing_agent,
                        "statistical_analytics_agent": statistical_analytics_agent,
                        "sk_learn_agent": sk_learn_agent,
                        "data_viz_agent": data_viz_agent
                    }

                # Create agents dictionary for deep analysis using enabled agents
                deep_agents = {}
                deep_agents_desc = {}

                for agent_name, signature in enabled_agents_dict.items():
                    deep_agents[agent_name] = dspy.asyncify(dspy.ChainOfThought(signature))
                    # Get agent description from database
                    deep_agents_desc[agent_name] = get_agent_description(agent_name)

                logger.log_message(f"Deep analyzer initialized with {len(deep_agents)} agents: {list(deep_agents.keys())}", level=logging.INFO)

            except Exception as e:
                logger.log_message(f"Error loading agents for deep analysis: {str(e)}", level=logging.ERROR)
                # Fallback to minimal set
                from src.agents.agents import preprocessing_agent, statistical_analytics_agent, sk_learn_agent, data_viz_agent
                deep_agents = {
                    "preprocessing_agent": dspy.asyncify(dspy.Predict(preprocessing_agent)),
                    "statistical_analytics_agent": dspy.asyncify(dspy.Predict(statistical_analytics_agent)),
                    "sk_learn_agent": dspy.asyncify(dspy.Predict(sk_learn_agent)),
                    "data_viz_agent": dspy.asyncify(dspy.Predict(data_viz_agent))
                }
                deep_agents_desc = {name: get_agent_description(name) for name in deep_agents.keys()}
                logger.log_message(f"Using fallback agents: {list(deep_agents.keys())}", level=logging.WARNING)
            finally:
                db_session.close()

-           session_state['deep_analyzer'] = deep_analysis_module(agents=deep_agents, agents_desc=deep_agents_desc)
            session_state['deep_analyzer_user_id'] = user_id  # Track which user this analyzer was created for
        else:
            logger.log_message(f"Using existing deep analyzer for session {session_id}, user_id: {user_id}", level=logging.INFO)

        return session_state['deep_analyzer']
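
# Sketch of the agent wrapping done in get_deep_analyzer above (signature name
# is one of this file's defaults; the await happens inside the deep analyzer):
#
#   agent = dspy.asyncify(dspy.ChainOfThought(preprocessing_agent))
#   result = await agent(**inputs)  # inputs dict as the signature expects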

# Initialize FastAPI app with state
app = FastAPI(title="AI Analytics API", version="1.0")
app.state = AppState()

# Configure middleware
# Use a wildcard for local development or read from environment
is_development = os.getenv("ENVIRONMENT", "development").lower() == "development"

allowed_origins = []
frontend_url = os.getenv("FRONTEND_URL", "").strip()
print(f"FRONTEND_URL: {frontend_url}")
if is_development:
    allowed_origins = ["*"]
elif frontend_url:
    allowed_origins = [frontend_url]
else:
    logger.log_message("CORS misconfigured: FRONTEND_URL not set", level=logging.ERROR)
    allowed_origins = []  # or set a default safe origin

# Add a strict origin verification middleware
@app.middleware("http")
async def verify_origin_middleware(request: Request, call_next):
    # Skip origin check in development mode
    if is_development:
        return await call_next(request)

    # Get the origin from the request headers
    origin = request.headers.get("origin")

    # Log the origin for debugging
    if origin:
        print(f"Request from origin: {origin}")

    # Reject the request if an Origin header is present and doesn't match the
    # configured frontend (requests without an Origin header pass through)
    if origin and frontend_url and origin != frontend_url:
        print(f"Blocked request from unauthorized origin: {origin}")
        return JSONResponse(
            status_code=403,
            content={"detail": "Not authorized"}
        )

    # Continue processing the request if origin is allowed
    return await call_next(request)

# CORS middleware (still needed for browser preflight)
app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    allow_origin_regex=None,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
    max_age=600  # Cache preflight requests for 10 minutes (for performance)
)
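
# Example environment (illustrative values, not from the commit):
#
#   ENVIRONMENT=production
#   FRONTEND_URL=https://app.example.com
#
# In development every origin is allowed; in production only FRONTEND_URL
# passes both the origin middleware and the CORS preflight.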

# Add these constants at the top of the file with other imports/constants
RESPONSE_ERROR_INVALID_QUERY = "Please provide a valid query..."
RESPONSE_ERROR_NO_DATASET = "No dataset is currently loaded. Please link a dataset before proceeding with your analysis."
DEFAULT_TOKEN_RATIO = 1.5
REQUEST_TIMEOUT_SECONDS = 30  # Timeout for LLM requests
MAX_RECENT_MESSAGES = 5
DB_BATCH_SIZE = 10  # For future batch DB operations
 
 
 
473
  @app.post("/chat/{agent_name}", response_model=dict)
 
474
  async def chat_with_agent(
 
475
  agent_name: str,
 
476
  request: QueryRequest,
 
477
  request_obj: Request,
 
478
  session_id: str = Depends(get_session_id_dependency)
 
479
  ):
 
480
  session_state = app.state.get_session_state(session_id)
 
481
  logger.log_message(f"[DEBUG] chat_with_agent called with agent: '{agent_name}', query: '{request.query[:100]}...'", level=logging.DEBUG)
 
482
 
 
483
  try:
 
484
  # Extract and validate query parameters
 
485
  logger.log_message(f"[DEBUG] Updating session from query params", level=logging.DEBUG)
 
486
  _update_session_from_query_params(request_obj, session_state)
 
487
  logger.log_message(f"[DEBUG] Session state after query params: user_id={session_state.get('user_id')}, chat_id={session_state.get('chat_id')}", level=logging.DEBUG)
 
488
 
 
489
  # Validate dataset and agent name
 
490
  if session_state["datasets"] is None:
491
  logger.log_message(f"[DEBUG] No dataset loaded", level=logging.DEBUG)
 
492
  raise HTTPException(status_code=400, detail=RESPONSE_ERROR_NO_DATASET)
493
 
 
 
494
  # Log the dataset being used for analysis with detailed information
495
  datasets = session_state["datasets"]
496
  dataset_names = list(datasets.keys())
@@ -512,1153 +867,2298 @@ async def chat_with_agent(
512
  logger.log_message(f"[ANALYSIS] No datasets available in session {session_id}", level=logging.WARNING)
513
 
514
  logger.log_message(f"[DEBUG] About to validate agent name: '{agent_name}'", level=logging.DEBUG)
 
515
  _validate_agent_name(agent_name, session_state)
 
516
  logger.log_message(f"[DEBUG] Agent validation completed successfully", level=logging.DEBUG)
 
517
 
 
518
  # Record start time for timing
 
519
  start_time = time.time()
 
520
 
 
521
  # Get chat context and prepare query
 
522
  logger.log_message(f"[DEBUG] Preparing query with context", level=logging.DEBUG)
 
523
  enhanced_query = _prepare_query_with_context(request.query, session_state)
 
524
  logger.log_message(f"[DEBUG] Enhanced query length: {len(enhanced_query)}", level=logging.DEBUG)
 
525
 
 
526
  # Initialize agent - handle standard, template, and custom agents
 
527
  if "," in agent_name:
 
528
  logger.log_message(f"[DEBUG] Processing multiple agents: {agent_name}", level=logging.DEBUG)
 
529
  # Multiple agents case
 
530
  agent_list = [agent.strip() for agent in agent_name.split(",")]
 
531
 
 
532
  # Categorize agents
 
533
  standard_agents = [agent for agent in agent_list if _is_standard_agent(agent)]
 
534
  template_agents = [agent for agent in agent_list if _is_template_agent(agent)]
 
535
  custom_agents = [agent for agent in agent_list if not _is_standard_agent(agent) and not _is_template_agent(agent)]
 
536
 
 
537
  logger.log_message(f"[DEBUG] Agent categorization - standard: {standard_agents}, template: {template_agents}, custom: {custom_agents}", level=logging.DEBUG)
 
538
 
 
539
  if custom_agents:
 
540
  # If any custom agents, use session AI system for all
 
541
  ai_system = session_state["ai_system"]
 
542
  session_lm = get_session_lm(session_state)
 
543
  logger.log_message(f"[DEBUG] Using custom agent execution path", level=logging.DEBUG)
 
544
  with dspy.context(lm=session_lm):
 
545
  response = await asyncio.wait_for(
 
546
  _execute_custom_agents(ai_system, agent_list, enhanced_query),
 
547
  timeout=REQUEST_TIMEOUT_SECONDS
 
548
  )
 
549
  logger.log_message(f"[DEBUG] Custom agents response type: {type(response)}, keys: {list(response.keys()) if isinstance(response, dict) else 'not a dict'}", level=logging.DEBUG)
 
550
  else:
 
551
  # All standard/template agents - use auto_analyst_ind which loads from DB
 
552
  user_id = session_state.get("user_id")
 
553
  logger.log_message(f"[DEBUG] Using auto_analyst_ind for multiple standard/template agents with user_id: {user_id}", level=logging.DEBUG)
 
554
 
 
555
  # Create database session for agent loading
 
556
  from src.db.init_db import session_factory
 
557
  db_session = session_factory()
 
558
  try:
 
559
  # auto_analyst_ind will load all agents from database
 
560
  logger.log_message(f"[DEBUG] Creating auto_analyst_ind instance", level=logging.DEBUG)
 
561
  agent = auto_analyst_ind(agents=[], retrievers=session_state["retrievers"], user_id=user_id, db_session=db_session)
 
562
  session_lm = get_session_lm(session_state)
 
563
  logger.log_message(f"[DEBUG] About to call agent.forward with query and agent list", level=logging.DEBUG)
 
564
  with dspy.context(lm=session_lm):
 
565
  response = await asyncio.wait_for(
 
566
  agent.forward(enhanced_query, ",".join(agent_list)),
 
567
  timeout=REQUEST_TIMEOUT_SECONDS
 
568
  )
 
569
  logger.log_message(f"[DEBUG] auto_analyst_ind response type: {type(response)}, content: {str(response)[:200]}...", level=logging.DEBUG)
 
570
  finally:
 
571
  db_session.close()
 
572
  else:
 
573
  logger.log_message(f"[DEBUG] Processing single agent: {agent_name}", level=logging.DEBUG)
 
574
  # Single agent case
 
575
  if _is_standard_agent(agent_name) or _is_template_agent(agent_name):
 
576
  # Standard or template agent - use auto_analyst_ind which loads from DB
 
577
  user_id = session_state.get("user_id")
 
578
  logger.log_message(f"[DEBUG] Using auto_analyst_ind for single standard/template agent '{agent_name}' with user_id: {user_id}", level=logging.DEBUG)
 
579
 
 
580
  # Create database session for agent loading
 
581
  from src.db.init_db import session_factory
 
582
  db_session = session_factory()
 
583
  try:
 
584
  # auto_analyst_ind will load all agents from database
 
585
  logger.log_message(f"[DEBUG] Creating auto_analyst_ind instance for single agent", level=logging.DEBUG)
 
586
  agent = auto_analyst_ind(agents=[], retrievers=session_state["retrievers"], user_id=user_id, db_session=db_session)
 
587
  session_lm = get_session_lm(session_state)
 
588
  logger.log_message(f"[DEBUG] About to call agent.forward for single agent '{agent_name}'", level=logging.DEBUG)
 
589
  with dspy.context(lm=session_lm):
 
590
  response = await asyncio.wait_for(
 
591
  agent.forward(enhanced_query, agent_name),
 
592
  timeout=REQUEST_TIMEOUT_SECONDS
 
593
  )
 
594
  logger.log_message(f"[DEBUG] Single agent response type: {type(response)}, content: {str(response)[:200]}...", level=logging.DEBUG)
 
595
  finally:
 
596
  db_session.close()
 
597
  else:
 
598
  # Custom agent - use session AI system
 
599
  ai_system = session_state["ai_system"]
 
600
  session_lm = get_session_lm(session_state)
 
601
  logger.log_message(f"[DEBUG] Using custom agent execution for '{agent_name}'", level=logging.DEBUG)
 
602
  with dspy.context(lm=session_lm):
 
603
  response = await asyncio.wait_for(
 
604
  _execute_custom_agents(ai_system, [agent_name], enhanced_query),
 
605
  timeout=REQUEST_TIMEOUT_SECONDS
 
606
  )
 
607
  logger.log_message(f"[DEBUG] Custom single agent response type: {type(response)}, content: {str(response)[:200]}...", level=logging.DEBUG)
 
608
 
 
609
  logger.log_message(f"[DEBUG] About to format response to markdown. Response type: {type(response)}", level=logging.DEBUG)
 
610
  formatted_response = format_response_to_markdown(response, agent_name, session_state["datasets"])
611
  logger.log_message(f"[DEBUG] Formatted response type: {type(formatted_response)}, length: {len(str(formatted_response))}", level=logging.DEBUG)
 
612
 
 
613
  if formatted_response == RESPONSE_ERROR_INVALID_QUERY:
 
614
  logger.log_message(f"[DEBUG] Response was invalid query error", level=logging.DEBUG)
 
615
  return {
 
616
  "agent_name": agent_name,
 
617
  "query": request.query,
 
618
  "response": formatted_response,
 
619
  "session_id": session_id
 
620
  }
 
621
 
 
622
  # Track usage statistics
 
623
  if session_state.get("user_id"):
 
624
  logger.log_message(f"[DEBUG] Tracking model usage", level=logging.DEBUG)
 
625
  _track_model_usage(
 
626
  session_state=session_state,
 
627
  enhanced_query=enhanced_query,
 
628
  response=response,
 
629
  processing_time_ms=int((time.time() - start_time) * 1000)
 
630
  )
 
631
 
 
632
  logger.log_message(f"[DEBUG] chat_with_agent completed successfully", level=logging.DEBUG)
 
633
  return {
 
634
  "agent_name": agent_name,
 
635
  "query": request.query, # Return original query without context
 
636
  "response": formatted_response,
 
637
  "session_id": session_id
 
638
  }
 
639
  except HTTPException:
 
640
  # Re-raise HTTP exceptions to preserve status codes
 
641
  logger.log_message(f"[DEBUG] HTTPException caught and re-raised", level=logging.DEBUG)
 
642
  raise
 
643
  except asyncio.TimeoutError:
 
644
  logger.log_message(f"[ERROR] Timeout error in chat_with_agent", level=logging.ERROR)
 
645
  raise HTTPException(status_code=504, detail="Request timed out. Please try a simpler query.")
 
646
  except Exception as e:
 
647
  logger.log_message(f"[ERROR] Unexpected error in chat_with_agent: {str(e)}", level=logging.ERROR)
 
648
  logger.log_message(f"[ERROR] Exception type: {type(e)}, traceback: {str(e)}", level=logging.ERROR)
 
649
  import traceback
 
650
  logger.log_message(f"[ERROR] Full traceback: {traceback.format_exc()}", level=logging.ERROR)
 
651
  raise HTTPException(status_code=500, detail="An unexpected error occurred. Please try again later.")
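
# Usage sketch (hypothetical host and ids; "agent1,agent2" also works for
# multi-agent runs):
#
#   import requests
#   r = requests.post(
#       "http://localhost:8000/chat/data_viz_agent",
#       params={"user_id": 1, "chat_id": 7},
#       headers={"X-Session-ID": "demo-session"},
#       json={"query": "Plot price vs area"},
#   )
#   print(r.json()["response"])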

@app.post("/chat", response_model=dict)
async def chat_with_all(
    request: QueryRequest,
    request_obj: Request,
    session_id: str = Depends(get_session_id_dependency)
):
    session_state = app.state.get_session_state(session_id)

    try:
        # Extract and validate query parameters
        _update_session_from_query_params(request_obj, session_state)

        # Validate dataset
        if session_state["datasets"] is None:
            raise HTTPException(status_code=400, detail=RESPONSE_ERROR_NO_DATASET)

        if session_state["ai_system"] is None:
            raise HTTPException(status_code=500, detail="AI system not properly initialized.")

        # Get session-specific model
        session_lm = get_session_lm(session_state)

        # Create streaming response
        return StreamingResponse(
            _generate_streaming_responses(session_state, request.query, session_lm),
            media_type='text/event-stream',
            headers={
                'Cache-Control': 'no-cache',
                'Connection': 'keep-alive',
                'Content-Type': 'text/event-stream',
                'Access-Control-Allow-Origin': '*',
                'X-Accel-Buffering': 'no'
            }
        )
    except HTTPException:
        # Re-raise HTTP exceptions to preserve status codes
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail="An unexpected error occurred. Please try again later.")
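
# Consumption sketch (hypothetical host; the stream is newline-delimited JSON,
# one {"agent", "content", "status"} object per line):
#
#   import json, requests
#   with requests.post("http://localhost:8000/chat",
#                      headers={"X-Session-ID": "demo-session"},
#                      json={"query": "Summarize the dataset"}, stream=True) as r:
#       for line in r.iter_lines():
#           if line:
#               event = json.loads(line)
#               print(event["agent"], event["status"])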

# Helper functions to reduce duplication and improve modularity
def _update_session_from_query_params(request_obj: Request, session_state: dict):
    """Extract and validate chat_id and user_id from query parameters"""
    # Check for chat_id in query parameters
    if "chat_id" in request_obj.query_params:
        try:
            chat_id_param = int(request_obj.query_params.get("chat_id"))
            # Update session state with this chat ID
            session_state["chat_id"] = chat_id_param
        except (ValueError, TypeError):
            logger.log_message("Invalid chat_id parameter", level=logging.WARNING)
            # Continue without updating chat_id

    # Check for user_id in query parameters
    if "user_id" in request_obj.query_params:
        try:
            user_id = int(request_obj.query_params["user_id"])
            session_state["user_id"] = user_id
        except (ValueError, TypeError):
            raise HTTPException(
                status_code=400,
                detail="Invalid user_id in query params. Please provide a valid integer."
            )

def _validate_agent_name(agent_name: str, session_state: dict = None):
    """Validate that the agent name(s) are available"""
    logger.log_message(f"[DEBUG] Validating agent name: '{agent_name}'", level=logging.DEBUG)

    if "," in agent_name:
        # Multiple agents
        agent_list = [agent.strip() for agent in agent_name.split(",")]
        logger.log_message(f"[DEBUG] Multiple agents detected: {agent_list}", level=logging.DEBUG)
        for agent in agent_list:
            is_available = _is_agent_available(agent, session_state)
            logger.log_message(f"[DEBUG] Agent '{agent}' availability: {is_available}", level=logging.DEBUG)
            if not is_available:
                available_agents = _get_available_agents_list(session_state)
                logger.log_message(f"[DEBUG] Agent '{agent}' not found. Available: {available_agents}", level=logging.DEBUG)
                raise HTTPException(
                    status_code=400,
                    detail=f"Agent '{agent}' not found. Available agents: {available_agents}"
                )
    else:
        # Single agent
        is_available = _is_agent_available(agent_name, session_state)
        logger.log_message(f"[DEBUG] Single agent '{agent_name}' availability: {is_available}", level=logging.DEBUG)
        if not is_available:
            available_agents = _get_available_agents_list(session_state)
            logger.log_message(f"[DEBUG] Agent '{agent_name}' not found. Available: {available_agents}", level=logging.DEBUG)
            raise HTTPException(
                status_code=400,
                detail=f"Agent '{agent_name}' not found. Available agents: {available_agents}"
            )

    logger.log_message(f"[DEBUG] Agent validation passed for: '{agent_name}'", level=logging.DEBUG)

def _is_agent_available(agent_name: str, session_state: dict = None) -> bool:
    """Check if an agent is available (standard, template, or custom)"""
    # Check if it's a standard agent
    if _is_standard_agent(agent_name):
        return True

    # Check if it's a template agent
    if _is_template_agent(agent_name):
        return True

    # Check if it's a custom agent in session
    if session_state and "ai_system" in session_state:
        ai_system = session_state["ai_system"]
        if hasattr(ai_system, 'agents') and agent_name in ai_system.agents:
            return True

    return False

def _get_available_agents_list(session_state: dict = None) -> list:
    """Get list of all available agents from database"""
    from src.db.init_db import session_factory
    from src.agents.agents import load_all_available_templates_from_db

    # Core agents (always available)
    available = ["preprocessing_agent", "statistical_analytics_agent", "sk_learn_agent", "data_viz_agent"]

    # Add template agents from database
    db_session = session_factory()
    try:
        template_agents_dict = load_all_available_templates_from_db(db_session)
        # template_agents_dict is a dict with template_name as keys
        template_names = [template_name for template_name in template_agents_dict.keys()
                          if template_name not in available and template_name != 'basic_qa_agent']
        available.extend(template_names)
    except Exception as e:
        logger.log_message(f"Error loading template agents: {str(e)}", level=logging.ERROR)
    finally:
        db_session.close()

    return available

def _is_standard_agent(agent_name: str) -> bool:
    """Check if agent is one of the 4 core standard agents"""
    standard_agents = ["preprocessing_agent", "statistical_analytics_agent", "sk_learn_agent", "data_viz_agent"]
    return agent_name in standard_agents

def _is_template_agent(agent_name: str) -> bool:
    """Check if agent is a template agent"""
    try:
        from src.db.init_db import session_factory
        from src.db.schemas.models import AgentTemplate

        db_session = session_factory()
        try:
            template = db_session.query(AgentTemplate).filter(
                AgentTemplate.template_name == agent_name,
                AgentTemplate.is_active == True
            ).first()
            return template is not None
        finally:
            db_session.close()
    except Exception as e:
        logger.log_message(f"Error checking if {agent_name} is template: {str(e)}", level=logging.ERROR)
        return False

async def _execute_custom_agents(ai_system, agent_names: list, query: str):
    """Execute custom agents using the session's AI system"""
    try:
        # For custom agents, call the agent callables registered on the AI system
        if len(agent_names) == 1:
            # Single custom agent
            agent_name = agent_names[0]
            # Prepare inputs for the custom agent (similar to standard agents like data_viz_agent)
            dict_ = {}
            dict_['dataset'] = ai_system.dataset.retrieve(query)[0].text
            dict_['styling_index'] = ai_system.styling_index.retrieve(query)[0].text
            dict_['goal'] = query
            dict_['Agent_desc'] = str(ai_system.agent_desc)

            # Get input fields for this agent
            if agent_name in ai_system.agent_inputs:
                inputs = {x: dict_[x] for x in ai_system.agent_inputs[agent_name] if x in dict_}

                # Execute the custom agent
                agent_name_result, result_dict = await ai_system.agents[agent_name](**inputs)
                return {agent_name_result: result_dict}
            else:
                logger.log_message(f"Agent '{agent_name}' not found in ai_system.agent_inputs", level=logging.ERROR)
                return {"error": f"Agent '{agent_name}' input configuration not found"}
        else:
            # Multiple agents - execute sequentially
            results = {}
            for agent_name in agent_names:
                single_result = await _execute_custom_agents(ai_system, [agent_name], query)
                results.update(single_result)
            return results

    except Exception as e:
        logger.log_message(f"Error in _execute_custom_agents: {str(e)}", level=logging.ERROR)
        return {"error": f"Error executing custom agents: {str(e)}"}

def _prepare_query_with_context(query: str, session_state: dict) -> str:
    """Prepare the query with chat context from previous messages"""
    chat_id = session_state.get("chat_id")
    if not chat_id:
        return query

    # Get chat manager from app state
    chat_manager = app.state._session_manager.chat_manager
    # Get recent messages
    recent_messages = chat_manager.get_recent_chat_history(chat_id, limit=MAX_RECENT_MESSAGES)
    # Extract response history
    chat_context = chat_manager.extract_response_history(recent_messages)

    # Append context to the query if available
    if chat_context:
        return f"### Current Query:\n{query}\n\n{chat_context}"
    return query
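
# Shape sketch: when a chat_id with history is present, the agent receives
# (context text depends on the chat manager's extraction):
#
#   ### Current Query:
#   Plot price vs area
#
#   <response history from the last MAX_RECENT_MESSAGES messages>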

def _track_model_usage(session_state: dict, enhanced_query: str, response, processing_time_ms: int):
    """Track model usage statistics in the database"""
    try:
        ai_manager = app.state.get_ai_manager()

        # Get model configuration
        model_config = session_state.get("model_config", DEFAULT_MODEL_CONFIG)
        model_name = model_config.get("model", DEFAULT_MODEL_CONFIG["model"])
        provider = ai_manager.get_provider_for_model(model_name)

        # Calculate token usage
        try:
            # Try exact tokenization
            prompt_tokens = len(ai_manager.tokenizer.encode(enhanced_query))
            completion_tokens = len(ai_manager.tokenizer.encode(str(response)))
            total_tokens = prompt_tokens + completion_tokens
        except Exception as token_error:
            # Fall back to estimation
            logger.log_message(f"Tokenization error: {str(token_error)}", level=logging.WARNING)
            prompt_words = len(enhanced_query.split())
            completion_words = len(str(response).split())
            prompt_tokens = int(prompt_words * DEFAULT_TOKEN_RATIO)
            completion_tokens = int(completion_words * DEFAULT_TOKEN_RATIO)
            total_tokens = prompt_tokens + completion_tokens

        # Calculate cost
        cost = ai_manager.calculate_cost(model_name, prompt_tokens, completion_tokens)

        # Save usage to database
        ai_manager.save_usage_to_db(
            user_id=session_state.get("user_id"),
            chat_id=session_state.get("chat_id"),
            model_name=model_name,
            provider=provider,
            prompt_tokens=int(prompt_tokens),
            completion_tokens=int(completion_tokens),
            total_tokens=int(total_tokens),
            query_size=len(enhanced_query),
            response_size=len(str(response)),
            cost=round(cost, 7),
            request_time_ms=processing_time_ms,
            is_streaming=False
        )
    except Exception as e:
        # Log but don't fail the request if usage tracking fails
        logger.log_message(f"Failed to track model usage: {str(e)}", level=logging.ERROR)

async def _generate_streaming_responses(session_state: dict, query: str, session_lm):
    """Generate streaming responses for chat_with_all endpoint"""
    overall_start_time = time.time()
    total_response = ""
    total_inputs = ""
    usage_records = []

    # Add chat context from previous messages
    enhanced_query = _prepare_query_with_context(query, session_state)

    # try:
    # Get the plan - planner is now async, so we need to await it
    plan_response = await session_state["ai_system"].get_plan(enhanced_query)

    plan_description = format_response_to_markdown(
        {"analytical_planner": plan_response},
        datasets=session_state["datasets"]
    )

    # Check if plan is valid
    if plan_description == RESPONSE_ERROR_INVALID_QUERY:
        yield json.dumps({
            "agent": "Analytical Planner",
            "content": plan_description,
            "status": "error"
        }) + "\n"
        return

    yield json.dumps({
        "agent": "Analytical Planner",
        "content": plan_description,
        "status": "success" if plan_description else "error"
    }) + "\n"

    # Track planner usage
    if session_state.get("user_id"):
        planner_tokens = _estimate_tokens(ai_manager=app.state.ai_manager,
                                          input_text=enhanced_query,
                                          output_text=plan_description)

        usage_records.append(_create_usage_record(
            session_state=session_state,
            model_name=session_state.get("model_config", DEFAULT_MODEL_CONFIG)["model"],
            prompt_tokens=planner_tokens["prompt"],
            completion_tokens=planner_tokens["completion"],
            query_size=len(enhanced_query),
            response_size=len(plan_description),
            processing_time_ms=int((time.time() - overall_start_time) * 1000),
            is_streaming=False
        ))

    logger.log_message(f"Plan response: {plan_response}", level=logging.INFO)
    logger.log_message(f"Plan response type: {type(plan_response)}", level=logging.INFO)

    # Check if plan_response is valid
    # if not plan_response or not isinstance(plan_response, dict):
    #     yield json.dumps({
    #         "agent": "Analytical Planner",
    #         "content": "**Error: Invalid plan response**\n\nResponse: " + str(plan_response),
    #         "status": "error"
    #     }) + "\n"
    #     return

    # Execute the plan with well-managed concurrency
    with dspy.context(lm=session_lm):
        # try:
        async for agent_name, inputs, response in session_state["ai_system"].execute_plan(enhanced_query, plan_response):

            if agent_name == "plan_not_found":
                yield json.dumps({
                    "agent": "Analytical Planner",
                    "content": "**No plan found**\n\nPlease try again with a different query or try using a different model.",
                    "status": "error"
                }) + "\n"
                return

            if agent_name == "plan_not_formated_correctly":
                yield json.dumps({
                    "agent": "Analytical Planner",
                    "content": "**Something went wrong with formatting, retry the query!**",
                    "status": "error"
                }) + "\n"
                return

            formatted_response = format_response_to_markdown(
                {agent_name: response},
                datasets=session_state["datasets"]
            )

            yield json.dumps({
                "agent": agent_name.split("__")[0] if "__" in agent_name else agent_name,
                "content": formatted_response,
                "status": "success" if response else "error"
            }) + "\n"

            # Handle agent errors
            if isinstance(response, dict) and "error" in response:
                yield json.dumps({
                    "agent": agent_name,
                    "content": f"**Error in {agent_name}**: {response['error']}",
                    "status": "error"
                }) + "\n"
                continue  # Continue with next agent instead of returning

            if formatted_response == RESPONSE_ERROR_INVALID_QUERY:
                yield json.dumps({
                    "agent": agent_name,
                    "content": formatted_response,
                    "status": "error"
                }) + "\n"
                continue  # Continue with next agent instead of returning

            # Track agent usage for future batch DB write
            if session_state.get("user_id"):
                agent_tokens = _estimate_tokens(
                    ai_manager=app.state.ai_manager,
                    input_text=str(inputs),
                    output_text=str(response)
                )

                # Get appropriate model name for code combiner
                if "code_combiner_agent" in agent_name and "__" in agent_name:
                    provider = agent_name.split("__")[1]
                    model_name = _get_model_name_for_provider(provider)
                else:
                    model_name = session_state.get("model_config", DEFAULT_MODEL_CONFIG)["model"]

                usage_records.append(_create_usage_record(
                    session_state=session_state,
                    model_name=model_name,
                    prompt_tokens=agent_tokens["prompt"],
                    completion_tokens=agent_tokens["completion"],
                    query_size=len(str(inputs)),
                    response_size=len(str(response)),
                    processing_time_ms=int((time.time() - overall_start_time) * 1000),
                    is_streaming=True
                ))

        # except asyncio.TimeoutError:
        #     yield json.dumps({
        #         "agent": "planner",
        #         "content": "The request timed out. Please try a simpler query.",
        #         "status": "error"
        #     }) + "\n"
        #     return

        # except Exception as e:
        #     logger.log_message(f"Error executing plan: {str(e)}", level=logging.ERROR)
        #     yield json.dumps({
        #         "agent": "planner",
        #         "content": f"An error occurred while executing the plan: {str(e)}",
        #         "status": "error"
        #     }) + "\n"
        #     return

    # except Exception as e:
    #     logger.log_message(f"Error in streaming response: {str(e)}", level=logging.ERROR)
    #     yield json.dumps({
    #         "agent": "planner",
    #         "content": "An error occurred while generating responses. Please try again!" + str(e) + str({k: v for k, v in session_lm.__dict__['kwargs'].items() if k != 'api_key'}),
    #         "status": "error"
    #     }) + "\n"

def _estimate_tokens(ai_manager, input_text: str, output_text: str) -> dict:
    """Estimate token counts, with fallback for tokenization errors"""
    try:
        # Try exact tokenization
        prompt_tokens = len(ai_manager.tokenizer.encode(input_text))
        completion_tokens = len(ai_manager.tokenizer.encode(output_text))
    except Exception:
        # Fall back to estimation
        prompt_words = len(input_text.split())
        completion_words = len(output_text.split())
        prompt_tokens = int(prompt_words * DEFAULT_TOKEN_RATIO)
        completion_tokens = int(completion_words * DEFAULT_TOKEN_RATIO)

    return {
        "prompt": prompt_tokens,
        "completion": completion_tokens,
        "total": prompt_tokens + completion_tokens
    }
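
# Worked example of the fallback path: if the tokenizer raises, a 4-word input
# such as "hello brave new world" is estimated as int(4 * 1.5) = 6 tokens.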

def _create_usage_record(session_state: dict, model_name: str, prompt_tokens: int,
                         completion_tokens: int, query_size: int, response_size: int,
                         processing_time_ms: int, is_streaming: bool) -> dict:
    """Create a usage record for the database"""
    ai_manager = app.state.get_ai_manager()
    provider = ai_manager.get_provider_for_model(model_name)
    cost = ai_manager.calculate_cost(model_name, prompt_tokens, completion_tokens)

    return {
        "user_id": session_state.get("user_id"),
        "chat_id": session_state.get("chat_id"),
        "model_name": model_name,
        "provider": provider,
        "prompt_tokens": int(prompt_tokens),
        "completion_tokens": int(completion_tokens),
        "total_tokens": int(prompt_tokens + completion_tokens),
        "query_size": query_size,
        "response_size": response_size,
        "cost": round(cost, 7),
        "request_time_ms": processing_time_ms,
        "is_streaming": is_streaming
    }

def _get_model_name_for_provider(provider: str) -> str:
    """Get the model name for a provider"""
    provider_model_map = {
        "openai": "o3-mini",
        "anthropic": "claude-3-7-sonnet-latest",
        "gemini": "gemini-2.5-pro-preview-03-25"
    }
    return provider_model_map.get(provider, "o3-mini")

# Add an endpoint to list available agents
@app.get("/agents", response_model=dict)
async def list_agents(request: Request, session_id: str = Depends(get_session_id_dependency)):
    """Get all available agents (standard, template, and custom)"""
    session_state = app.state.get_session_state(session_id)

    try:
        # Get all available agents from database and session
        available_agents_list = _get_available_agents_list(session_state)

        # Categorize agents
        standard_agents = ["preprocessing_agent", "statistical_analytics_agent", "sk_learn_agent", "data_viz_agent"]

        # Get template agents from database
        from src.db.init_db import session_factory
        from src.agents.agents import load_all_available_templates_from_db

        db_session = session_factory()
        try:
            template_agents_dict = load_all_available_templates_from_db(db_session)
            # template_agents_dict is a dict with template_name as keys
            template_agents = [template_name for template_name in template_agents_dict.keys()
                               if template_name not in standard_agents and template_name != 'basic_qa_agent']
        except Exception as e:
            logger.log_message(f"Error loading template agents in /agents endpoint: {str(e)}", level=logging.ERROR)
            template_agents = []
        finally:
            db_session.close()

        # Get custom agents from session
        custom_agents = []
        if session_state and "ai_system" in session_state:
            ai_system = session_state["ai_system"]
            if hasattr(ai_system, 'agents'):
                custom_agents = [agent for agent in available_agents_list
                                 if agent not in standard_agents and agent not in template_agents]

        # Ensure template agents are in the available list
        for template_agent in template_agents:
            if template_agent not in available_agents_list:
                available_agents_list.append(template_agent)

        return {
            "available_agents": available_agents_list,
            "standard_agents": standard_agents,
            "template_agents": template_agents,
            "custom_agents": custom_agents
        }
    except Exception as e:
        logger.log_message(f"Error getting agents list: {str(e)}", level=logging.ERROR)
        raise HTTPException(status_code=500, detail=f"Error getting agents list: {str(e)}")

@app.get("/health", response_model=dict)
async def health():
    return {"message": "API is healthy and running"}

@app.get("/")
async def index():
    return {
        "title": "Welcome to the AI Analytics API",
        "message": "Explore our API for advanced analytics and visualization tools designed to empower your data-driven decisions.",
        "description": "Utilize our powerful agents and models to gain insights from your data effortlessly.",
        "colors": {
            "primary": "#007bff",
            "secondary": "#6c757d",
            "success": "#28a745",
            "danger": "#dc3545",
        },
        "features": [
            "Real-time data processing",
            "Customizable visualizations",
            "Seamless integration with various data sources",
            "User-friendly interface for easy navigation",
            "Custom Analytics",
        ],
    }

@app.post("/chat_history_name")
async def chat_history_name(request: dict, session_id: str = Depends(get_session_id_dependency)):
    query = request.get("query")
    name = None

    lm = dspy.LM(model="gpt-4o-mini", max_tokens=300, temperature=0.5)

    with dspy.context(lm=lm):
        name = app.state.get_chat_history_name_agent()(query=str(query))

    return {"name": name.name if name else "New Chat"}

@app.post("/deep_analysis_streaming")
async def deep_analysis_streaming(
    request: DeepAnalysisRequest,
    request_obj: Request,
    session_id: str = Depends(get_session_id_dependency)
):
    """Perform streaming deep analysis with real-time updates"""
    session_state = app.state.get_session_state(session_id)

    try:
        # Extract and validate query parameters
        _update_session_from_query_params(request_obj, session_state)

        # Validate dataset
        if session_state["datasets"] is None:
            raise HTTPException(status_code=400, detail=RESPONSE_ERROR_NO_DATASET)

        # Get user_id from session state (if available)
        user_id = session_state.get("user_id")

        # Generate a UUID for this report (uuid is already imported at module level)
        report_uuid = str(uuid.uuid4())

        # Create initial pending report in the database
        try:
            from src.db.init_db import session_factory
            from src.db.schemas.models import DeepAnalysisReport

            db_session = session_factory()

            try:
                # Create a pending report entry
                new_report = DeepAnalysisReport(
                    report_uuid=report_uuid,
                    user_id=user_id,
                    goal=request.goal,
                    status="pending",
                    start_time=datetime.now(UTC),
                    progress_percentage=0
                )

                db_session.add(new_report)
                db_session.commit()
                db_session.refresh(new_report)

                # Store the report ID in session state for later updates
                session_state["current_deep_analysis_id"] = new_report.report_id
                session_state["current_deep_analysis_uuid"] = report_uuid

            except Exception as e:
                logger.log_message(f"Error creating initial deep analysis report: {str(e)}", level=logging.ERROR)
                # Continue even if DB storage fails
            finally:
                db_session.close()

        except Exception as e:
            logger.log_message(f"Database operation failed: {str(e)}", level=logging.ERROR)
            # Continue even if DB operation fails

        # Get session-specific model
        # session_lm = get_session_lm(session_state)
        session_lm = dspy.LM(model="anthropic/claude-sonnet-4-20250514", max_tokens=7000, temperature=0.5)

        return StreamingResponse(
            _generate_deep_analysis_stream(session_state, request.goal, session_lm, session_id),
            media_type='text/event-stream',
            headers={
                'Cache-Control': 'no-cache',
                'Connection': 'keep-alive',
                'Content-Type': 'text/event-stream',
                'Access-Control-Allow-Origin': '*',
                'X-Accel-Buffering': 'no'
            }
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.log_message(f"Streaming deep analysis failed: {str(e)}", level=logging.ERROR)
        raise HTTPException(status_code=500, detail=f"Streaming deep analysis failed: {str(e)}")
1318
 
 
 
1319
  async def _generate_deep_analysis_stream(session_state: dict, goal: str, session_lm, session_id: str):
 
1320
  """Generate streaming responses for deep analysis"""
 
1321
  # Track the start time for duration calculation
 
1322
  start_time = datetime.now(UTC)
 
1323
 
 
1324
  try:
 
1325
  # Get dataset info
 
1326
  datasets = session_state["datasets"]
1327
  dtypes_info = pd.DataFrame({
 
1328
  'Column': df.columns,
 
1329
  'Data Type': df.dtypes.astype(str)
 
1330
  }).to_markdown()
 
1331
  dataset_info = f"Sample Data:\n{df.head(2).to_markdown()}\n\nData Types:\n{dtypes_info}"
 
1332
 
 
1333
  # Get report info from session state
 
1334
  report_id = session_state.get("current_deep_analysis_id")
 
1335
  report_uuid = session_state.get("current_deep_analysis_uuid")
 
1336
  user_id = session_state.get("user_id")
 
1337
 
 
1338
  # Helper function to update report in database
 
1339
  async def update_report_in_db(status, progress, step=None, content=None):
 
1340
  if not report_id:
 
1341
  return
 
1342
 
 
1343
  try:
 
1344
  from src.db.init_db import session_factory
 
1345
  from src.db.schemas.models import DeepAnalysisReport
 
1346
 
 
1347
  db_session = session_factory()
 
1348
 
 
1349
  try:
 
1350
  report = db_session.query(DeepAnalysisReport).filter(DeepAnalysisReport.report_id == report_id).first()
 
1351
 
 
1352
  if report:
 
1353
  report.status = status
 
1354
  report.progress_percentage = progress
 
1355
 
 
1356
  # Update step-specific fields if provided
 
1357
  if step == "questions" and content:
 
1358
  report.deep_questions = content
 
1359
  elif step == "planning" and content:
 
1360
  report.deep_plan = content
 
1361
  elif step == "analysis" and content:
 
1362
  # For analysis step, we get the full object with multiple fields
 
1363
  if isinstance(content, dict):
 
1364
  # Update fields from content if they exist
 
1365
  if "deep_questions" in content and content["deep_questions"]:
 
1366
  report.deep_questions = content["deep_questions"]
 
1367
  if "deep_plan" in content and content["deep_plan"]:
 
1368
  report.deep_plan = content["deep_plan"]
 
1369
  if "code" in content and content["code"]:
 
1370
  report.analysis_code = content["code"]
 
1371
  if "final_conclusion" in content and content["final_conclusion"]:
 
1372
  report.final_conclusion = content["final_conclusion"]
 
1373
  # Also update summary from conclusion
 
1374
  conclusion = content["final_conclusion"]
 
1375
  conclusion = conclusion.replace("**Conclusion**", "")
 
1376
  report.report_summary = conclusion[:200] + "..." if len(conclusion) > 200 else conclusion
 
1377
 
 
1378
  # Handle JSON fields
 
1379
  if "summaries" in content and content["summaries"]:
 
1380
  report.summaries = json.dumps(content["summaries"])
 
1381
  if "plotly_figs" in content and content["plotly_figs"]:
 
1382
  report.plotly_figures = json.dumps(content["plotly_figs"])
 
1383
  if "synthesis" in content and content["synthesis"]:
 
1384
  report.synthesis = json.dumps(content["synthesis"])
 
1385
 
 
1386
  # For the final step, update the HTML report
 
1387
  if step == "completed":
 
1388
  if content:
 
1389
  report.html_report = content
 
1390
  else:
 
1391
  logger.log_message("No HTML content provided for completed step", level=logging.WARNING)
 
1392
 
 
1393
  report.end_time = datetime.now(UTC)
 
1394
  # Ensure start_time is timezone-aware before calculating duration
 
1395
  if report.start_time.tzinfo is None:
 
1396
  start_time_utc = report.start_time.replace(tzinfo=UTC)
 
1397
  else:
 
1398
  start_time_utc = report.start_time
 
1399
  report.duration_seconds = int((report.end_time - start_time_utc).total_seconds())
 
1400
 
 
1401
  report.updated_at = datetime.now(UTC)
 
1402
  db_session.commit()
 
1403
 
 
1404
  except Exception as e:
 
1405
  db_session.rollback()
 
1406
  logger.log_message(f"Error updating deep analysis report: {str(e)}", level=logging.ERROR)
 
1407
  finally:
 
1408
  db_session.close()
 
1409
  except Exception as e:
 
1410
  logger.log_message(f"Database operation failed: {str(e)}", level=logging.ERROR)
 
1411
 
 
1412
  # Use session model for this request
 
1413
  with dspy.context(lm=session_lm):
 
1414
  # Send initial status
 
1415
  yield json.dumps({
 
1416
  "step": "initialization",
 
1417
  "status": "starting",
 
1418
  "message": "Initializing deep analysis...",
 
1419
  "progress": 5
 
1420
  }) + "\n"
 
1421
 
 
1422
  # Update DB status to running
 
1423
  await update_report_in_db("running", 5)
 
1424
 
 
1425
  # Get deep analyzer - use the correct session_id from the session_state
 
1426
  logger.log_message(f"Getting deep analyzer for session_id: {session_id}, user_id: {user_id}", level=logging.INFO)
 
1427
  deep_analyzer = app.state.get_deep_analyzer(session_id)
 
1428
 
 
1429
  # Make the dataset available globally for code execution
 
1430
  globals()['df'] = df
 
1431
 
 
1432
  # Use the new streaming method and forward all progress updates
 
1433
  final_result = None
 
1434
  async for update in deep_analyzer.execute_deep_analysis_streaming(
 
1435
  goal=goal,
 
1436
  dataset_info=dataset_info,
 
1437
  session_df=df
 
1438
  ):
 
1439
  # Convert the update to the expected format and yield it
 
1440
  if update.get("step") == "questions" and update.get("status") == "completed":
 
1441
  # Update DB with questions
 
1442
  await update_report_in_db("running", update.get("progress", 0), "questions", update.get("content"))
 
1443
  elif update.get("step") == "planning" and update.get("status") == "completed":
 
1444
  # Update DB with planning
 
1445
  await update_report_in_db("running", update.get("progress", 0), "planning", update.get("content"))
 
1446
  elif update.get("step") == "conclusion" and update.get("status") == "completed":
 
1447
  # Store the final result for later processing
 
1448
  final_result = update.get("final_result")
 
1449
 
 
1450
  # Convert Plotly figures to JSON format for network transmission
 
1451
  if final_result:
 
1452
  import plotly.io
 
1453
  serialized_return_dict = final_result.copy()
 
1454
 
 
1455
  # Convert plotly_figs to JSON format
 
1456
  if 'plotly_figs' in serialized_return_dict and serialized_return_dict['plotly_figs']:
 
1457
  json_figs = []
 
1458
  for fig_list in serialized_return_dict['plotly_figs']:
 
1459
  if isinstance(fig_list, list):
 
1460
  json_fig_list = []
 
1461
  for fig in fig_list:
 
1462
  if hasattr(fig, 'to_json'): # Check if it's a Plotly figure
 
1463
  json_fig_list.append(plotly.io.to_json(fig))
 
1464
  else:
 
1465
  json_fig_list.append(fig) # Already JSON or other format
 
1466
  json_figs.append(json_fig_list)
 
1467
  else:
 
1468
  # Single figure case
 
1469
  if hasattr(fig_list, 'to_json'):
 
1470
  json_figs.append(plotly.io.to_json(fig_list))
 
1471
  else:
 
1472
  json_figs.append(fig_list)
 
1473
  serialized_return_dict['plotly_figs'] = json_figs
 
1474
 
 
1475
  # Update DB with analysis results
 
1476
  await update_report_in_db("running", update.get("progress", 0), "analysis", serialized_return_dict)
 
1477
 
 
1478
  # Generate HTML report using the original final_result with Figure objects
 
1479
  html_report = None
 
1480
  try:
 
1481
  html_report = generate_html_report(final_result)
 
1482
  except Exception as e:
 
1483
  logger.log_message(f"Error generating HTML report: {str(e)}", level=logging.ERROR)
 
1484
  # Continue even if HTML generation fails
 
1485
 
 
1486
  # Send the analysis results
 
1487
  yield json.dumps({
 
1488
  "step": "analysis",
 
1489
  "status": "completed",
 
1490
  "content": serialized_return_dict,
 
1491
  "progress": 90
 
1492
  }) + "\n"
 
1493
 
 
1494
  # Send report generation status
 
1495
  yield json.dumps({
 
1496
  "step": "report",
 
1497
  "status": "processing",
 
1498
  "message": "Generating final report...",
 
1499
  "progress": 95
 
1500
  }) + "\n"
 
1501
 
 
1502
  # Send final completion
 
1503
  yield json.dumps({
 
1504
  "step": "completed",
 
1505
  "status": "success",
 
1506
  "analysis": serialized_return_dict,
 
1507
  "html_report": html_report,
 
1508
  "progress": 100
 
1509
  }) + "\n"
 
1510
 
 
1511
  # Update DB with completed report (with HTML if generated)
 
1512
  if html_report:
 
1513
  logger.log_message(f"Saving HTML report to database, length: {len(html_report)}", level=logging.INFO)
 
1514
  else:
 
1515
  logger.log_message("No HTML report to save to database", level=logging.WARNING)
 
1516
  await update_report_in_db("completed", 100, "completed", html_report)
 
1517
  elif update.get("step") == "error":
 
1518
  # Forward error directly
 
1519
  yield json.dumps(update) + "\n"
 
1520
  await update_report_in_db("failed", 0)
 
1521
  return
 
1522
  else:
 
1523
  # Forward all other progress updates
 
1524
  yield json.dumps(update) + "\n"
 
1525
 
 
1526
  # If we somehow exit the loop without getting a final result, that's an error
 
1527
  if not final_result:
 
1528
  yield json.dumps({
 
1529
  "step": "error",
 
1530
  "status": "failed",
 
1531
  "message": "Deep analysis completed without final result",
 
1532
  "progress": 0
 
1533
  }) + "\n"
 
1534
  await update_report_in_db("failed", 0)
 
1535
 
 
1536
  except Exception as e:
 
1537
  logger.log_message(f"Error in deep analysis stream: {str(e)}", level=logging.ERROR)
 
1538
  yield json.dumps({
 
1539
  "step": "error",
 
1540
  "status": "failed",
 
1541
  "message": f"Deep analysis failed: {str(e)}",
 
1542
  "progress": 0
 
1543
  }) + "\n"
 
1544
 
 
1545
  # Update DB with error status
 
1546
  if 'update_report_in_db' in locals() and session_state.get("current_deep_analysis_id"):
 
1547
  await update_report_in_db("failed", 0)
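
# Illustrative shape of the NDJSON stream emitted above (one JSON object per line;
# the field values are examples, not captured output):
#   {"step": "initialization", "status": "starting", "message": "Initializing deep analysis...", "progress": 5}
#   {"step": "analysis", "status": "completed", "content": {...}, "progress": 90}
#   {"step": "completed", "status": "success", "analysis": {...}, "html_report": "<html>...</html>", "progress": 100}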

@app.post("/deep_analysis/download_report")
async def download_html_report(
    request: dict,
    session_id: str = Depends(get_session_id_dependency)
):
    """Download the HTML report from a previous deep analysis"""
    try:
        analysis_data = request.get("analysis_data")
        if not analysis_data:
            raise HTTPException(status_code=400, detail="No analysis data provided")

        # Get the report UUID from the request if available (for saving to DB)
        report_uuid = request.get("report_uuid")
        session_state = app.state.get_session_state(session_id)

        # If no report_uuid in the request, try to get it from session state
        if not report_uuid and session_state.get("current_deep_analysis_uuid"):
            report_uuid = session_state.get("current_deep_analysis_uuid")

        # Convert JSON-serialized Plotly figures back to Figure objects for HTML generation
        processed_data = analysis_data.copy()

        if 'plotly_figs' in processed_data and processed_data['plotly_figs']:
            import plotly.io
            import plotly.graph_objects as go

            figure_objects = []
            for fig_list in processed_data['plotly_figs']:
                if isinstance(fig_list, list):
                    fig_obj_list = []
                    for fig_json in fig_list:
                        if isinstance(fig_json, str):
                            # Convert the JSON string back to a Figure object
                            try:
                                fig_obj = plotly.io.from_json(fig_json)
                                fig_obj_list.append(fig_obj)
                            except Exception as e:
                                logger.log_message(f"Error parsing Plotly JSON: {str(e)}", level=logging.WARNING)
                                continue
                        elif hasattr(fig_json, 'to_html'):
                            # Already a Figure object
                            fig_obj_list.append(fig_json)
                    figure_objects.append(fig_obj_list)
                else:
                    # Single figure case
                    if isinstance(fig_list, str):
                        try:
                            fig_obj = plotly.io.from_json(fig_list)
                            figure_objects.append(fig_obj)
                        except Exception as e:
                            logger.log_message(f"Error parsing Plotly JSON: {str(e)}", level=logging.WARNING)
                            continue
                    elif hasattr(fig_list, 'to_html'):
                        figure_objects.append(fig_list)

            processed_data['plotly_figs'] = figure_objects

        # Generate the HTML report
        html_report = generate_html_report(processed_data)

        # Save the report to the database if we have a UUID
        if report_uuid:
            try:
                from src.db.init_db import session_factory
                from src.db.schemas.models import DeepAnalysisReport

                db_session = session_factory()
                try:
                    # Try to find an existing report by UUID
                    report = db_session.query(DeepAnalysisReport).filter(DeepAnalysisReport.report_uuid == report_uuid).first()

                    if report:
                        # Update the existing report with the HTML content
                        report.html_report = html_report
                        report.updated_at = datetime.now(UTC)
                        db_session.commit()
                except Exception as e:
                    db_session.rollback()
                finally:
                    db_session.close()
            except Exception as e:
                logger.log_message(f"Database operation failed when storing HTML report: {str(e)}", level=logging.ERROR)
                # Continue even if DB storage fails

        # Create a filename with a timestamp
        timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S")
        filename = f"deep_analysis_report_{timestamp}.html"

        # Return as a downloadable file
        return StreamingResponse(
            iter([html_report.encode('utf-8')]),
            media_type='text/html',
            headers={
                'Content-Disposition': f'attachment; filename="{filename}"',
                'Content-Type': 'text/html; charset=utf-8'
            }
        )

    except Exception as e:
        logger.log_message(f"Failed to generate HTML report: {str(e)}", level=logging.ERROR)
        raise HTTPException(status_code=500, detail=f"Failed to generate report: {str(e)}")
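
# Example client call (illustrative; the endpoint, session header, and the
# "analysis_data"/"report_uuid" payload keys come from the handler above):
#   curl -X POST "http://localhost:8000/deep_analysis/download_report" \
#        -H "X-Session-ID: my-session" -H "Content-Type: application/json" \
#        -d '{"analysis_data": {...}, "report_uuid": "..."}' -o report.html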

# In the section where routers are included, add the session_router
app.include_router(chat_router)
app.include_router(analytics_router)
app.include_router(code_router)
app.include_router(session_router)
app.include_router(feedback_router)
app.include_router(deep_analysis_router)
app.include_router(templates_router)
app.include_router(blog_router)

if __name__ == "__main__":
    port = int(os.environ.get("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=port)
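
# Equivalent CLI launch (illustrative), assuming this module is saved as app.py:
#   uvicorn app:app --host 0.0.0.0 --port 8000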
 
 
# Add near the top of the file, after imports
DEFAULT_MODEL_CONFIG = {
    "provider": os.getenv("MODEL_PROVIDER", "openai"),
    "model": os.getenv("MODEL_NAME", "gpt-5-mini"),
    "api_key": os.getenv("OPENAI_API_KEY"),
    "temperature": float(os.getenv("TEMPERATURE", 1.0)),
    "max_tokens": int(os.getenv("MAX_TOKENS", 6000)),
    "cache": False,
}
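
# Example .env entries read by the config above (values illustrative):
#   MODEL_PROVIDER=openai
#   MODEL_NAME=gpt-5-mini
#   OPENAI_API_KEY=sk-...
#   TEMPERATURE=1.0
#   MAX_TOKENS=6000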

# Create the default LM config but don't set it globally
default_lm = MODEL_OBJECTS[DEFAULT_MODEL_CONFIG['model']]

# lm = dspy.LM('openai/gpt-4o-mini', api_key=os.getenv("OPENAI_API_KEY"))
dspy.configure(lm=default_lm, async_max_workers=100)


# Function to get the model config from the session or use the default
def get_session_lm(session_state):
    """Get the appropriate LM instance for a session, or the default if not configured"""
    # First check if we have a valid session-specific model config
    if session_state and isinstance(session_state, dict) and "model_config" in session_state:
        model_config = session_state["model_config"]
        if model_config and isinstance(model_config, dict) and "model" in model_config:
            # Found a valid session-specific model config, use it
            provider = model_config.get("provider", "openai").lower()
            model_name = model_config.get("model", DEFAULT_MODEL_CONFIG["model"])
            # NOTE: the original conditions used `'gpt-5' or 'o1' not in model_name`,
            # which is always truthy in Python; rewritten to test both substrings.
            if 'gpt-5' not in model_name and 'o1' not in model_name:
                MODEL_OBJECTS[model_name].__dict__['kwargs']['max_tokens'] = model_config.get("max_tokens", DEFAULT_MODEL_CONFIG["max_tokens"])
                MODEL_OBJECTS[model_name].__dict__['kwargs']['temperature'] = model_config.get("temperature", DEFAULT_MODEL_CONFIG["temperature"])
            elif provider == 'openai':
                # OpenAI reasoning models take max_completion_tokens and a fixed temperature of 1.0
                MODEL_OBJECTS[model_name].__dict__['kwargs']['max_completion_tokens'] = model_config.get("max_tokens", DEFAULT_MODEL_CONFIG["max_tokens"])
                MODEL_OBJECTS[model_name].__dict__['kwargs']['temperature'] = 1.0
            else:
                MODEL_OBJECTS[model_name].__dict__['kwargs']['max_tokens'] = model_config.get("max_tokens", DEFAULT_MODEL_CONFIG["max_tokens"])
                MODEL_OBJECTS[model_name].__dict__['kwargs']['temperature'] = model_config.get("temperature", DEFAULT_MODEL_CONFIG["temperature"])

            return MODEL_OBJECTS[model_name]

    # If no valid session config, use the default
    # NOTE: the original returned MODEL_OBJECTS[model_name] here, which raises a
    # NameError when no session config exists; falling back to the default LM.
    return default_lm
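
# Usage sketch (illustrative): scope a single request to the session's model
# without touching the global dspy configuration, as the endpoints below do:
#   session_lm = get_session_lm(session_state)
#   with dspy.context(lm=session_lm):
#       response = await agent.forward(enhanced_query, agent_name)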

# Initialize retrievers with empty data first

# clear console
def clear_console():
    os.system('cls' if os.name == 'nt' else 'clear')


# Check for Housing.csv
housing_csv_path = "Housing.csv"
if not os.path.exists(housing_csv_path):
    logger.log_message(f"Housing.csv not found at {os.path.abspath(housing_csv_path)}", level=logging.ERROR)
    raise FileNotFoundError(f"Housing.csv not found at {os.path.abspath(housing_csv_path)}")

# All agents are now loaded from the database - no hardcoded dictionaries needed

# Add session header
X_SESSION_ID = APIKeyHeader(name="X-Session-ID", auto_error=False)


# Update the AppState class to use SessionManager
class AppState:
    def __init__(self):
        self._session_manager = SessionManager(styling_instructions, {})  # Empty dict, agents loaded from DB
        self.model_config = DEFAULT_MODEL_CONFIG.copy()
        # Update the SessionManager with the current model_config
        self._session_manager._app_model_config = self.model_config
        self.ai_manager = AI_Manager()
        self.chat_name_agent = chat_history_name_agent
        # Initialize the deep analysis module
        self.deep_analyzer = None

    def get_session_state(self, session_id: str):
        """Get or create session-specific state using the SessionManager"""
        return self._session_manager.get_session_state(session_id)

    def clear_session_state(self, session_id: str):
        """Clear session-specific state using the SessionManager"""
        self._session_manager.clear_session_state(session_id)

    def update_session_dataset(self, session_id: str, datasets, names, desc, pre_generated=False):
        """Update the dataset for a specific session using the SessionManager"""
        self._session_manager.update_session_dataset(session_id, datasets, names, desc, pre_generated=pre_generated)

    def reset_session_to_default(self, session_id: str):
        """Reset a session to use the default dataset using the SessionManager"""
        self._session_manager.reset_session_to_default(session_id)

    def set_session_user(self, session_id: str, user_id: int, chat_id: int = None):
        """Associate a user with a session using the SessionManager"""
        return self._session_manager.set_session_user(session_id, user_id, chat_id)

    def get_ai_manager(self):
        """Get the AI Manager instance"""
        return self.ai_manager

    def get_provider_for_model(self, model_name):
        return self.ai_manager.get_provider_for_model(model_name)

    def calculate_cost(self, model_name, input_tokens, output_tokens):
        return self.ai_manager.calculate_cost(model_name, input_tokens, output_tokens)

    def save_usage_to_db(self, user_id, chat_id, model_name, provider, prompt_tokens, completion_tokens, total_tokens, query_size, response_size, cost, request_time_ms, is_streaming=False):
        return self.ai_manager.save_usage_to_db(user_id, chat_id, model_name, provider, prompt_tokens, completion_tokens, total_tokens, query_size, response_size, round(cost, 7), request_time_ms, is_streaming)

    def get_tokenizer(self):
        return self.ai_manager.tokenizer

    def get_chat_history_name_agent(self):
        return dspy.Predict(self.chat_name_agent)

    def get_deep_analyzer(self, session_id: str):
        """Get or create the deep analysis module for a session"""
        session_state = self.get_session_state(session_id)
        user_id = session_state.get("user_id")

        # Check if we need to recreate the deep analyzer (user changed or it doesn't exist)
        current_analyzer = session_state.get('deep_analyzer')
        analyzer_user_id = session_state.get('deep_analyzer_user_id')

        logger.log_message(f"Deep analyzer check - session: {session_id}, current_user: {user_id}, analyzer_user: {analyzer_user_id}, has_analyzer: {current_analyzer is not None}", level=logging.INFO)

        # NOTE: the original also tested `not hasattr(session_state, 'deep_analyzer')`,
        # which is always True for a dict and forced a rebuild on every call.
        if not current_analyzer or analyzer_user_id != user_id:

            logger.log_message(f"Creating/recreating deep analyzer for session {session_id}, user_id: {user_id} (reason: analyzer_exists={current_analyzer is not None}, user_match={analyzer_user_id == user_id})", level=logging.INFO)

            # Load user-enabled agents from the database using the preference system
            from src.db.init_db import session_factory
            from src.agents.agents import load_user_enabled_templates_for_planner_from_db

            db_session = session_factory()
            try:
                # Load user-enabled agents for the planner (respects preferences)
                if user_id:
                    enabled_agents_dict = load_user_enabled_templates_for_planner_from_db(user_id, db_session)
                    logger.log_message(f"Deep analyzer loaded {len(enabled_agents_dict)} enabled agents for user {user_id}: {list(enabled_agents_dict.keys())}", level=logging.INFO)

                    if not enabled_agents_dict:
                        logger.log_message(f"WARNING: No enabled agents found for user {user_id}, falling back to defaults", level=logging.WARNING)
                        # Fallback to default agents if no enabled agents
                        from src.agents.agents import preprocessing_agent, statistical_analytics_agent, sk_learn_agent, data_viz_agent
                        enabled_agents_dict = {
                            "preprocessing_agent": preprocessing_agent,
                            "statistical_analytics_agent": statistical_analytics_agent,
                            "sk_learn_agent": sk_learn_agent,
                            "data_viz_agent": data_viz_agent
                        }
                else:
                    # Fallback to default agents if no user_id
                    logger.log_message("No user_id in session, loading default agents for deep analysis", level=logging.WARNING)
                    from src.agents.agents import preprocessing_agent, statistical_analytics_agent, sk_learn_agent, data_viz_agent
                    enabled_agents_dict = {
                        "preprocessing_agent": preprocessing_agent,
                        "statistical_analytics_agent": statistical_analytics_agent,
                        "sk_learn_agent": sk_learn_agent,
                        "data_viz_agent": data_viz_agent
                    }

                # Create the agents dictionary for deep analysis using the enabled agents
                deep_agents = {}
                deep_agents_desc = {}

                for agent_name, signature in enabled_agents_dict.items():
                    deep_agents[agent_name] = dspy.asyncify(dspy.ChainOfThought(signature))
                    # Get the agent description from the database
                    deep_agents_desc[agent_name] = get_agent_description(agent_name)

                logger.log_message(f"Deep analyzer initialized with {len(deep_agents)} agents: {list(deep_agents.keys())}", level=logging.INFO)

            except Exception as e:
                logger.log_message(f"Error loading agents for deep analysis: {str(e)}", level=logging.ERROR)
                # Fallback to a minimal set
                from src.agents.agents import preprocessing_agent, statistical_analytics_agent, sk_learn_agent, data_viz_agent
                deep_agents = {
                    "preprocessing_agent": dspy.asyncify(dspy.Predict(preprocessing_agent)),
                    "statistical_analytics_agent": dspy.asyncify(dspy.Predict(statistical_analytics_agent)),
                    "sk_learn_agent": dspy.asyncify(dspy.Predict(sk_learn_agent)),
                    "data_viz_agent": dspy.asyncify(dspy.Predict(data_viz_agent))
                }
                deep_agents_desc = {name: get_agent_description(name) for name in deep_agents.keys()}
                logger.log_message(f"Using fallback agents: {list(deep_agents.keys())}", level=logging.WARNING)
            finally:
                db_session.close()

            session_state['deep_analyzer'] = deep_analysis_module(
                agents=deep_agents,
                agents_desc=deep_agents_desc
            )
            # Set datasets separately or pass them when needed
            session_state['deep_analyzer'].datasets = session_state.get("datasets")
            session_state['deep_analyzer_user_id'] = user_id  # Track which user this analyzer was created for
        else:
            logger.log_message(f"Using existing deep analyzer for session {session_id}, user_id: {user_id}", level=logging.INFO)

        return session_state['deep_analyzer']
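
# Usage sketch (illustrative): the deep-analysis stream resolves the per-session
# analyzer through app state, rebuilding it only when the session's user changes:
#   deep_analyzer = app.state.get_deep_analyzer(session_id)
#   async for update in deep_analyzer.execute_deep_analysis_streaming(goal=goal, dataset_info=dataset_info, session_df=df):
#       ...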
 

# Initialize the FastAPI app with state
app = FastAPI(title="AI Analytics API", version="1.0")
app.state = AppState()


# Configure middleware
# Use a wildcard for local development or read from the environment
is_development = os.getenv("ENVIRONMENT", "development").lower() == "development"

allowed_origins = []
frontend_url = os.getenv("FRONTEND_URL", "").strip()
print(f"FRONTEND_URL: {frontend_url}")
if is_development:
    allowed_origins = ["*"]
elif frontend_url:
    allowed_origins = [frontend_url]
else:
    logger.log_message("CORS misconfigured: FRONTEND_URL not set", level=logging.ERROR)
    allowed_origins = []  # or set a default safe origin

# Add a strict origin verification middleware
@app.middleware("http")
async def verify_origin_middleware(request: Request, call_next):
    # Skip the origin check in development mode
    if is_development:
        return await call_next(request)

    # Get the origin from the request headers
    origin = request.headers.get("origin")

    # Log the origin for debugging
    if origin:
        print(f"Request from origin: {origin}")

    # If the origin is present but not the configured frontend, reject the request
    if origin and frontend_url and origin != frontend_url:
        print(f"Blocked request from unauthorized origin: {origin}")
        return JSONResponse(
            status_code=403,
            content={"detail": "Not authorized"}
        )

    # Continue processing the request if the origin is allowed
    return await call_next(request)


# CORS middleware (still needed for browser preflight)
app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    allow_origin_regex=None,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
    max_age=600  # Cache preflight requests for 10 minutes (for performance)
)

# Add these constants at the top of the file with other imports/constants
RESPONSE_ERROR_INVALID_QUERY = "Please provide a valid query..."
RESPONSE_ERROR_NO_DATASET = "No dataset is currently loaded. Please link a dataset before proceeding with your analysis."
DEFAULT_TOKEN_RATIO = 1.5
REQUEST_TIMEOUT_SECONDS = 30  # Timeout for LLM requests
MAX_RECENT_MESSAGES = 5
DB_BATCH_SIZE = 10  # For future batch DB operations

@app.post("/chat/{agent_name}", response_model=dict)
async def chat_with_agent(
    agent_name: str,
    request: QueryRequest,
    request_obj: Request,
    session_id: str = Depends(get_session_id_dependency)
):
    session_state = app.state.get_session_state(session_id)
    logger.log_message(f"[DEBUG] chat_with_agent called with agent: '{agent_name}', query: '{request.query[:100]}...'", level=logging.DEBUG)

    try:
        # Extract and validate query parameters
        logger.log_message("[DEBUG] Updating session from query params", level=logging.DEBUG)
        _update_session_from_query_params(request_obj, session_state)
        logger.log_message(f"[DEBUG] Session state after query params: user_id={session_state.get('user_id')}, chat_id={session_state.get('chat_id')}", level=logging.DEBUG)

        # Validate dataset and agent name
        if session_state["datasets"] is None:
            logger.log_message("[DEBUG] No dataset loaded", level=logging.DEBUG)
            raise HTTPException(status_code=400, detail=RESPONSE_ERROR_NO_DATASET)

        # Log the dataset being used for analysis with detailed information
        datasets = session_state["datasets"]
        dataset_names = list(datasets.keys())
        # ... (dataset logging lines elided in the diff) ...
        logger.log_message(f"[ANALYSIS] No datasets available in session {session_id}", level=logging.WARNING)

        logger.log_message(f"[DEBUG] About to validate agent name: '{agent_name}'", level=logging.DEBUG)
        _validate_agent_name(agent_name, session_state)
        logger.log_message("[DEBUG] Agent validation completed successfully", level=logging.DEBUG)

        # Record the start time for timing
        start_time = time.time()

        # Get chat context and prepare the query
        logger.log_message("[DEBUG] Preparing query with context", level=logging.DEBUG)
        enhanced_query = _prepare_query_with_context(request.query, session_state)
        logger.log_message(f"[DEBUG] Enhanced query length: {len(enhanced_query)}", level=logging.DEBUG)

        # Initialize the agent - handle standard, template, and custom agents
        if "," in agent_name:
            logger.log_message(f"[DEBUG] Processing multiple agents: {agent_name}", level=logging.DEBUG)
            # Multiple agents case
            agent_list = [agent.strip() for agent in agent_name.split(",")]

            # Categorize agents
            standard_agents = [agent for agent in agent_list if _is_standard_agent(agent)]
            template_agents = [agent for agent in agent_list if _is_template_agent(agent)]
            custom_agents = [agent for agent in agent_list if not _is_standard_agent(agent) and not _is_template_agent(agent)]

            logger.log_message(f"[DEBUG] Agent categorization - standard: {standard_agents}, template: {template_agents}, custom: {custom_agents}", level=logging.DEBUG)

            if custom_agents:
                # If there are any custom agents, use the session AI system for all
                ai_system = session_state["ai_system"]
                session_lm = get_session_lm(session_state)
                logger.log_message("[DEBUG] Using custom agent execution path", level=logging.DEBUG)
                with dspy.context(lm=session_lm):
                    response = await asyncio.wait_for(
                        _execute_custom_agents(ai_system, agent_list, enhanced_query),
                        timeout=REQUEST_TIMEOUT_SECONDS
                    )
                logger.log_message(f"[DEBUG] Custom agents response type: {type(response)}, keys: {list(response.keys()) if isinstance(response, dict) else 'not a dict'}", level=logging.DEBUG)
            else:
                # All standard/template agents - use auto_analyst_ind, which loads from the DB
                user_id = session_state.get("user_id")
                logger.log_message(f"[DEBUG] Using auto_analyst_ind for multiple standard/template agents with user_id: {user_id}", level=logging.DEBUG)

                # Create a database session for agent loading
                from src.db.init_db import session_factory
                db_session = session_factory()
                try:
                    # auto_analyst_ind will load all agents from the database
                    logger.log_message("[DEBUG] Creating auto_analyst_ind instance", level=logging.DEBUG)
                    agent = auto_analyst_ind(agents=[], retrievers=session_state["retrievers"], user_id=user_id, db_session=db_session)
                    session_lm = get_session_lm(session_state)
                    logger.log_message("[DEBUG] About to call agent.forward with query and agent list", level=logging.DEBUG)
                    with dspy.context(lm=session_lm):
                        response = await asyncio.wait_for(
                            agent.forward(enhanced_query, ",".join(agent_list)),
                            timeout=REQUEST_TIMEOUT_SECONDS
                        )
                    logger.log_message(f"[DEBUG] auto_analyst_ind response type: {type(response)}, content: {str(response)[:200]}...", level=logging.DEBUG)
                finally:
                    db_session.close()
        else:
            logger.log_message(f"[DEBUG] Processing single agent: {agent_name}", level=logging.DEBUG)
            # Single agent case
            if _is_standard_agent(agent_name) or _is_template_agent(agent_name):
                # Standard or template agent - use auto_analyst_ind, which loads from the DB
                user_id = session_state.get("user_id")
                logger.log_message(f"[DEBUG] Using auto_analyst_ind for single standard/template agent '{agent_name}' with user_id: {user_id}", level=logging.DEBUG)

                # Create a database session for agent loading
                from src.db.init_db import session_factory
                db_session = session_factory()
                try:
                    # auto_analyst_ind will load all agents from the database
                    logger.log_message("[DEBUG] Creating auto_analyst_ind instance for single agent", level=logging.DEBUG)
                    agent = auto_analyst_ind(agents=[], retrievers=session_state["retrievers"], user_id=user_id, db_session=db_session)
                    session_lm = get_session_lm(session_state)
                    logger.log_message(f"[DEBUG] About to call agent.forward for single agent '{agent_name}'", level=logging.DEBUG)
                    with dspy.context(lm=session_lm):
                        response = await asyncio.wait_for(
                            agent.forward(enhanced_query, agent_name),
                            timeout=REQUEST_TIMEOUT_SECONDS
                        )
                    logger.log_message(f"[DEBUG] Single agent response type: {type(response)}, content: {str(response)[:200]}...", level=logging.DEBUG)
                finally:
                    db_session.close()
            else:
                # Custom agent - use the session AI system
                ai_system = session_state["ai_system"]
                session_lm = get_session_lm(session_state)
                logger.log_message(f"[DEBUG] Using custom agent execution for '{agent_name}'", level=logging.DEBUG)
                with dspy.context(lm=session_lm):
                    response = await asyncio.wait_for(
                        _execute_custom_agents(ai_system, [agent_name], enhanced_query),
                        timeout=REQUEST_TIMEOUT_SECONDS
                    )
                logger.log_message(f"[DEBUG] Custom single agent response type: {type(response)}, content: {str(response)[:200]}...", level=logging.DEBUG)

        logger.log_message(f"[DEBUG] About to format response to markdown. Response type: {type(response)}", level=logging.DEBUG)
        formatted_response = format_response_to_markdown(response, agent_name, session_state["datasets"])
        logger.log_message(f"[DEBUG] Formatted response type: {type(formatted_response)}, length: {len(str(formatted_response))}", level=logging.DEBUG)

        if formatted_response == RESPONSE_ERROR_INVALID_QUERY:
            logger.log_message("[DEBUG] Response was invalid query error", level=logging.DEBUG)
            return {
                "agent_name": agent_name,
                "query": request.query,
                "response": formatted_response,
                "session_id": session_id
            }

        # Track usage statistics
        if session_state.get("user_id"):
            logger.log_message("[DEBUG] Tracking model usage", level=logging.DEBUG)
            _track_model_usage(
                session_state=session_state,
                enhanced_query=enhanced_query,
                response=response,
                processing_time_ms=int((time.time() - start_time) * 1000)
            )

        logger.log_message("[DEBUG] chat_with_agent completed successfully", level=logging.DEBUG)
        return {
            "agent_name": agent_name,
            "query": request.query,  # Return the original query without context
            "response": formatted_response,
            "session_id": session_id
        }
    except HTTPException:
        # Re-raise HTTP exceptions to preserve status codes
        logger.log_message("[DEBUG] HTTPException caught and re-raised", level=logging.DEBUG)
        raise
    except asyncio.TimeoutError:
        logger.log_message("[ERROR] Timeout error in chat_with_agent", level=logging.ERROR)
        raise HTTPException(status_code=504, detail="Request timed out. Please try a simpler query.")
    except Exception as e:
        logger.log_message(f"[ERROR] Unexpected error in chat_with_agent: {str(e)}", level=logging.ERROR)
        logger.log_message(f"[ERROR] Exception type: {type(e)}, traceback: {str(e)}", level=logging.ERROR)
        import traceback
        logger.log_message(f"[ERROR] Full traceback: {traceback.format_exc()}", level=logging.ERROR)
        raise HTTPException(status_code=500, detail="An unexpected error occurred. Please try again later.")
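
# Example request (illustrative; the route, X-Session-ID header, and user_id/chat_id
# query params are defined above; the body follows QueryRequest's "query" field):
#   curl -X POST "http://localhost:8000/chat/data_viz_agent?user_id=1&chat_id=42" \
#        -H "X-Session-ID: my-session" -H "Content-Type: application/json" \
#        -d '{"query": "Plot price against area"}'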


@app.post("/chat", response_model=dict)
async def chat_with_all(
    request: QueryRequest,
    request_obj: Request,
    session_id: str = Depends(get_session_id_dependency)
):
    session_state = app.state.get_session_state(session_id)

    try:
        # Extract and validate query parameters
        _update_session_from_query_params(request_obj, session_state)

        # Validate dataset
        if session_state["datasets"] is None:
            raise HTTPException(status_code=400, detail=RESPONSE_ERROR_NO_DATASET)

        if session_state["ai_system"] is None:
            raise HTTPException(status_code=500, detail="AI system not properly initialized.")

        # Get the session-specific model
        session_lm = get_session_lm(session_state)

        # Create the streaming response
        return StreamingResponse(
            _generate_streaming_responses(session_state, request.query, session_lm),
            media_type='text/event-stream',
            headers={
                'Cache-Control': 'no-cache',
                'Connection': 'keep-alive',
                'Content-Type': 'text/event-stream',
                'Access-Control-Allow-Origin': '*',
                'X-Accel-Buffering': 'no'
            }
        )
    except HTTPException:
        # Re-raise HTTP exceptions to preserve status codes
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail="An unexpected error occurred. Please try again later.")
 
1225
 
1226
+
1227
+
1228
+
1229
  # Helper functions to reduce duplication and improve modularity
1230
+
1231
  def _update_session_from_query_params(request_obj: Request, session_state: dict):
1232
+
1233
  """Extract and validate chat_id and user_id from query parameters"""
1234
+
1235
  # Check for chat_id in query parameters
1236
+
1237
  if "chat_id" in request_obj.query_params:
1238
+
1239
  try:
1240
+
1241
  chat_id_param = int(request_obj.query_params.get("chat_id"))
1242
+
1243
  # Update session state with this chat ID
1244
+
1245
  session_state["chat_id"] = chat_id_param
1246
+
1247
  except (ValueError, TypeError):
1248
+
1249
  logger.log_message("Invalid chat_id parameter", level=logging.WARNING)
1250
+
1251
  # Continue without updating chat_id
1252
 
1253
+
1254
+
1255
  # Check for user_id in query parameters
1256
+
1257
  if "user_id" in request_obj.query_params:
1258
+
1259
  try:
1260
+
1261
  user_id = int(request_obj.query_params["user_id"])
1262
+
1263
  session_state["user_id"] = user_id
1264
+
1265
  except (ValueError, TypeError):
1266
+
1267
  raise HTTPException(
1268
+
1269
  status_code=400,
1270
+
1271
  detail="Invalid user_id in query params. Please provide a valid integer."
1272
+
1273
  )
1274
 
1275
 


def _validate_agent_name(agent_name: str, session_state: dict = None):
    """Validate that the agent name(s) are available"""
    logger.log_message(f"[DEBUG] Validating agent name: '{agent_name}'", level=logging.DEBUG)

    if "," in agent_name:
        # Multiple agents
        agent_list = [agent.strip() for agent in agent_name.split(",")]
        logger.log_message(f"[DEBUG] Multiple agents detected: {agent_list}", level=logging.DEBUG)
        for agent in agent_list:
            is_available = _is_agent_available(agent, session_state)
            logger.log_message(f"[DEBUG] Agent '{agent}' availability: {is_available}", level=logging.DEBUG)
            if not is_available:
                available_agents = _get_available_agents_list(session_state)
                logger.log_message(f"[DEBUG] Agent '{agent}' not found. Available: {available_agents}", level=logging.DEBUG)
                raise HTTPException(
                    status_code=400,
                    detail=f"Agent '{agent}' not found. Available agents: {available_agents}"
                )
    else:
        # Single agent
        is_available = _is_agent_available(agent_name, session_state)
        logger.log_message(f"[DEBUG] Single agent '{agent_name}' availability: {is_available}", level=logging.DEBUG)
        if not is_available:
            available_agents = _get_available_agents_list(session_state)
            logger.log_message(f"[DEBUG] Agent '{agent_name}' not found. Available: {available_agents}", level=logging.DEBUG)
            raise HTTPException(
                status_code=400,
                detail=f"Agent '{agent_name}' not found. Available agents: {available_agents}"
            )

    logger.log_message(f"[DEBUG] Agent validation passed for: '{agent_name}'", level=logging.DEBUG)


def _is_agent_available(agent_name: str, session_state: dict = None) -> bool:
    """Check if an agent is available (standard, template, or custom)"""
    # Check if it's a standard agent
    if _is_standard_agent(agent_name):
        return True

    # Check if it's a template agent
    if _is_template_agent(agent_name):
        return True

    # Check if it's a custom agent in the session
    if session_state and "ai_system" in session_state:
        ai_system = session_state["ai_system"]
        if hasattr(ai_system, 'agents') and agent_name in ai_system.agents:
            return True

    return False


def _get_available_agents_list(session_state: dict = None) -> list:
    """Get the list of all available agents from the database"""
    from src.db.init_db import session_factory
    from src.agents.agents import load_all_available_templates_from_db

    # Core agents (always available)
    available = ["preprocessing_agent", "statistical_analytics_agent", "sk_learn_agent", "data_viz_agent"]

    # Add template agents from the database
    db_session = session_factory()
    try:
        template_agents_dict = load_all_available_templates_from_db(db_session)
        # template_agents_dict is a dict with template_name as keys
        template_names = [template_name for template_name in template_agents_dict.keys()
                          if template_name not in available and template_name != 'basic_qa_agent']
        available.extend(template_names)
    except Exception as e:
        logger.log_message(f"Error loading template agents: {str(e)}", level=logging.ERROR)
    finally:
        db_session.close()

    return available


def _is_standard_agent(agent_name: str) -> bool:
    """Check if the agent is one of the 4 core standard agents"""
    standard_agents = ["preprocessing_agent", "statistical_analytics_agent", "sk_learn_agent", "data_viz_agent"]
    return agent_name in standard_agents


def _is_template_agent(agent_name: str) -> bool:
    """Check if the agent is a template agent"""
    try:
        from src.db.init_db import session_factory
        from src.db.schemas.models import AgentTemplate

        db_session = session_factory()
        try:
            template = db_session.query(AgentTemplate).filter(
                AgentTemplate.template_name == agent_name,
                AgentTemplate.is_active == True
            ).first()
            return template is not None
        finally:
            db_session.close()
    except Exception as e:
        logger.log_message(f"Error checking if {agent_name} is template: {str(e)}", level=logging.ERROR)
        return False


async def _execute_custom_agents(ai_system, agent_names: list, query: str):
    """Execute custom agents using the session's AI system"""
    try:
        # For custom agents, we need to use the AI system's agent callables
        if len(agent_names) == 1:
            # Single custom agent
            agent_name = agent_names[0]
            # Prepare inputs for the custom agent (similar to standard agents like data_viz_agent)
            dict_ = {}
            dict_['dataset'] = ai_system.dataset.retrieve(query)[0].text
            dict_['styling_index'] = ai_system.styling_index.retrieve(query)[0].text
            dict_['goal'] = query
            dict_['Agent_desc'] = str(ai_system.agent_desc)

            # Get the input fields for this agent
            if agent_name in ai_system.agent_inputs:
                inputs = {x: dict_[x] for x in ai_system.agent_inputs[agent_name] if x in dict_}

                # Execute the custom agent
                agent_name_result, result_dict = await ai_system.agents[agent_name](**inputs)
                return {agent_name_result: result_dict}
            else:
                logger.log_message(f"Agent '{agent_name}' not found in ai_system.agent_inputs", level=logging.ERROR)
                return {"error": f"Agent '{agent_name}' input configuration not found"}
        else:
            # Multiple agents - execute sequentially
            results = {}
            for agent_name in agent_names:
                single_result = await _execute_custom_agents(ai_system, [agent_name], query)
                results.update(single_result)
            return results

    except Exception as e:
        logger.log_message(f"Error in _execute_custom_agents: {str(e)}", level=logging.ERROR)
        return {"error": f"Error executing custom agents: {str(e)}"}


def _prepare_query_with_context(query: str, session_state: dict) -> str:
    """Prepare the query with chat context from previous messages"""
    chat_id = session_state.get("chat_id")
    if not chat_id:
        return query

    # Get the chat manager from app state
    chat_manager = app.state._session_manager.chat_manager
    # Get recent messages
    recent_messages = chat_manager.get_recent_chat_history(chat_id, limit=MAX_RECENT_MESSAGES)
    # Extract response history
    chat_context = chat_manager.extract_response_history(recent_messages)

    # Append the context to the query if available
    if chat_context:
        return f"### Current Query:\n{query}\n\n{chat_context}"
    return query
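
# Illustrative output of the helper above when chat context exists (layout from
# the f-string; the history text is an example):
#   ### Current Query:
#   Plot price against area
#
#   <response history extracted from the last MAX_RECENT_MESSAGES messages>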


def _track_model_usage(session_state: dict, enhanced_query: str, response, processing_time_ms: int):
    """Track model usage statistics in the database"""
    try:
        ai_manager = app.state.get_ai_manager()

        # Get the model configuration
        model_config = session_state.get("model_config", DEFAULT_MODEL_CONFIG)
        model_name = model_config.get("model", DEFAULT_MODEL_CONFIG["model"])
        provider = ai_manager.get_provider_for_model(model_name)

        # Calculate token usage
        try:
            # Try exact tokenization
            prompt_tokens = len(ai_manager.tokenizer.encode(enhanced_query))
            completion_tokens = len(ai_manager.tokenizer.encode(str(response)))
            total_tokens = prompt_tokens + completion_tokens
        except Exception as token_error:
            # Fall back to estimation
            logger.log_message(f"Tokenization error: {str(token_error)}", level=logging.WARNING)
            prompt_words = len(enhanced_query.split())
            completion_words = len(str(response).split())
            prompt_tokens = int(prompt_words * DEFAULT_TOKEN_RATIO)
            completion_tokens = int(completion_words * DEFAULT_TOKEN_RATIO)
            total_tokens = prompt_tokens + completion_tokens

        # Calculate cost
        cost = ai_manager.calculate_cost(model_name, prompt_tokens, completion_tokens)

        # Save usage to the database
        ai_manager.save_usage_to_db(
            user_id=session_state.get("user_id"),
            chat_id=session_state.get("chat_id"),
            model_name=model_name,
            provider=provider,
            prompt_tokens=int(prompt_tokens),
            completion_tokens=int(completion_tokens),
            total_tokens=int(total_tokens),
            query_size=len(enhanced_query),
            response_size=len(str(response)),
            cost=round(cost, 7),
            request_time_ms=processing_time_ms,
            is_streaming=False
        )
    except Exception as e:
        # Log but don't fail the request if usage tracking fails
        logger.log_message(f"Failed to track model usage: {str(e)}", level=logging.ERROR)


async def _generate_streaming_responses(session_state: dict, query: str, session_lm):
    """Generate streaming responses for the chat_with_all endpoint"""
    overall_start_time = time.time()
    total_response = ""
    total_inputs = ""
    usage_records = []

    # Add chat context from previous messages
    enhanced_query = _prepare_query_with_context(query, session_state)

    # try:
    # Get the plan - the planner is now async, so we need to await it
    plan_response = await session_state["ai_system"].get_plan(enhanced_query)

    plan_description = format_response_to_markdown(
        {"analytical_planner": plan_response},
        datasets=session_state["datasets"]
    )

    # Check if the plan is valid
    if plan_description == RESPONSE_ERROR_INVALID_QUERY:
        yield json.dumps({
            "agent": "Analytical Planner",
            "content": plan_description,
            "status": "error"
        }) + "\n"
        return

    yield json.dumps({
        "agent": "Analytical Planner",
        "content": plan_description,
        "status": "success" if plan_description else "error"
    }) + "\n"

    # Track planner usage
    if session_state.get("user_id"):
        planner_tokens = _estimate_tokens(ai_manager=app.state.ai_manager,
                                          input_text=enhanced_query,
                                          output_text=plan_description)

        usage_records.append(_create_usage_record(
            session_state=session_state,
            model_name=session_state.get("model_config", DEFAULT_MODEL_CONFIG)["model"],
            prompt_tokens=planner_tokens["prompt"],
            completion_tokens=planner_tokens["completion"],
            query_size=len(enhanced_query),
            response_size=len(plan_description),
            processing_time_ms=int((time.time() - overall_start_time) * 1000),
            is_streaming=False
        ))

    logger.log_message(f"Plan response: {plan_response}", level=logging.INFO)
    logger.log_message(f"Plan response type: {type(plan_response)}", level=logging.INFO)

    # Check if plan_response is valid
    # if not plan_response or not isinstance(plan_response, dict):
    #     yield json.dumps({
    #         "agent": "Analytical Planner",
    #         "content": "**Error: Invalid plan response**\n\nResponse: " + str(plan_response),
    #         "status": "error"
    #     }) + "\n"
    #     return

    # Execute the plan with well-managed concurrency
    with dspy.context(lm=session_lm):
        # try:
        async for agent_name, inputs, response in session_state["ai_system"].execute_plan(enhanced_query, plan_response):

            if agent_name == "plan_not_found":
                yield json.dumps({
                    "agent": "Analytical Planner",
                    "content": "**No plan found**\n\nPlease try again with a different query or try using a different model.",
                    "status": "error"
                }) + "\n"
                return

            if agent_name == "plan_not_formated_correctly":
                yield json.dumps({
                    "agent": "Analytical Planner",
                    "content": "**Something went wrong with formatting, retry the query!**",
                    "status": "error"
                }) + "\n"
                return

            formatted_response = format_response_to_markdown(
                {agent_name: response},
                datasets=session_state["datasets"]
            )

            # Send the response chunk
            yield json.dumps({
                "agent": agent_name.split("__")[0] if "__" in agent_name else agent_name,
                "content": formatted_response,
                "status": "success" if response else "error"
            }) + "\n"

            # Handle agent errors
            if isinstance(response, dict) and "error" in response:
                yield json.dumps({
                    "agent": agent_name,
                    "content": f"**Error in {agent_name}**: {response['error']}",
                    "status": "error"
                }) + "\n"
                continue  # Continue with the next agent instead of returning

            if formatted_response == RESPONSE_ERROR_INVALID_QUERY:
                yield json.dumps({
                    "agent": agent_name,
                    "content": formatted_response,
                    "status": "error"
                }) + "\n"
                continue  # Continue with the next agent instead of returning

            # Track agent usage for a future batch DB write
            if session_state.get("user_id"):
                agent_tokens = _estimate_tokens(
                    ai_manager=app.state.ai_manager,
                    input_text=str(inputs),
                    output_text=str(response)
                )

                # Get the appropriate model name for the code combiner
                if "code_combiner_agent" in agent_name and "__" in agent_name:
                    provider = agent_name.split("__")[1]
                    model_name = _get_model_name_for_provider(provider)
                else:
                    model_name = session_state.get("model_config", DEFAULT_MODEL_CONFIG)["model"]

                usage_records.append(_create_usage_record(
                    session_state=session_state,
                    model_name=model_name,
                    prompt_tokens=agent_tokens["prompt"],
                    completion_tokens=agent_tokens["completion"],
                    query_size=len(str(inputs)),
                    response_size=len(str(response)),
                    processing_time_ms=int((time.time() - overall_start_time) * 1000),
                    is_streaming=True
                ))

        # except asyncio.TimeoutError:
        #     yield json.dumps({
        #         "agent": "planner",
        #         "content": "The request timed out. Please try a simpler query.",
        #         "status": "error"
        #     }) + "\n"
        #     return

        # except Exception as e:
        #     logger.log_message(f"Error executing plan: {str(e)}", level=logging.ERROR)
        #     yield json.dumps({
        #         "agent": "planner",
        #         "content": f"An error occurred while executing the plan: {str(e)}",
        #         "status": "error"
        #     }) + "\n"
        #     return

    # except Exception as e:
    #     logger.log_message(f"Error in streaming response: {str(e)}", level=logging.ERROR)
    #     yield json.dumps({
    #         "agent": "planner",
    #         "content": "An error occurred while generating responses. Please try again!" + str(e) + str({k: v for k, v in session_lm.__dict__['kwargs'].items() if k != 'api_key'}),
    #         "status": "error"
    #     }) + "\n"
2020
+
2021
+
2022
+
2023
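Note on the streaming pattern above: each yield emits one complete JSON object terminated by a newline (NDJSON), so the client can render progress line by line without waiting for the whole response. A minimal sketch of the same producer shape, with placeholder agent names:

    import json

    async def demo_stream():
        # One JSON object per line; the consumer splits on "\n"
        for step in ("planner", "preprocessing_agent", "data_viz_agent"):
            yield json.dumps({"agent": step, "status": "success"}) + "\n"
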
+
+def _estimate_tokens(ai_manager, input_text: str, output_text: str) -> dict:
+    """Estimate token counts, with fallback for tokenization errors"""
+    try:
+        # Try exact tokenization
+        prompt_tokens = len(ai_manager.tokenizer.encode(input_text))
+        completion_tokens = len(ai_manager.tokenizer.encode(output_text))
+    except Exception:
+        # Fall back to estimation
+        prompt_words = len(input_text.split())
+        completion_words = len(output_text.split())
+        prompt_tokens = int(prompt_words * DEFAULT_TOKEN_RATIO)
+        completion_tokens = int(completion_words * DEFAULT_TOKEN_RATIO)
+
+    return {
+        "prompt": prompt_tokens,
+        "completion": completion_tokens,
+        "total": prompt_tokens + completion_tokens
+    }
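The fallback above trades accuracy for robustness: when the provider tokenizer fails, a fixed words-to-tokens ratio (DEFAULT_TOKEN_RATIO) still yields a usable estimate. A standalone sketch of the same idea; tiktoken and the 1.3 ratio are illustrative assumptions, not the app's actual tokenizer or constant:

    import tiktoken

    WORDS_TO_TOKENS = 1.3  # assumed rough ratio, not the app's constant

    def estimate_tokens(text: str, model: str = "gpt-4o-mini") -> int:
        try:
            return len(tiktoken.encoding_for_model(model).encode(text))
        except Exception:
            return int(len(text.split()) * WORDS_TO_TOKENS)
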
 
+
+def _create_usage_record(session_state: dict, model_name: str, prompt_tokens: int,
+                         completion_tokens: int, query_size: int, response_size: int,
+                         processing_time_ms: int, is_streaming: bool) -> dict:
+    """Create a usage record for the database"""
+    ai_manager = app.state.get_ai_manager()
+    provider = ai_manager.get_provider_for_model(model_name)
+    cost = ai_manager.calculate_cost(model_name, prompt_tokens, completion_tokens)
+
+    return {
+        "user_id": session_state.get("user_id"),
+        "chat_id": session_state.get("chat_id"),
+        "model_name": model_name,
+        "provider": provider,
+        "prompt_tokens": int(prompt_tokens),
+        "completion_tokens": int(completion_tokens),
+        "total_tokens": int(prompt_tokens + completion_tokens),
+        "query_size": query_size,
+        "response_size": response_size,
+        "cost": round(cost, 7),
+        "request_time_ms": processing_time_ms,
+        "is_streaming": is_streaming
+    }
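calculate_cost presumably maps token counts to per-model prices. A hedged sketch of such a lookup; the table values here are placeholders, not real rates:

    # (input_price, output_price) in USD per 1M tokens -- placeholder values
    PRICES = {"o3-mini": (1.10, 4.40), "claude-3-7-sonnet-latest": (3.00, 15.00)}

    def calculate_cost(model: str, prompt_tokens: int, completion_tokens: int) -> float:
        p_in, p_out = PRICES.get(model, (0.0, 0.0))
        return (prompt_tokens * p_in + completion_tokens * p_out) / 1_000_000
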
 
+
+def _get_model_name_for_provider(provider: str) -> str:
+    """Get the default model name for a provider"""
+    provider_model_map = {
+        "openai": "o3-mini",
+        "anthropic": "claude-3-7-sonnet-latest",
+        "gemini": "gemini-2.5-pro-preview-03-25"
+    }
+    return provider_model_map.get(provider, "o3-mini")
+
 
+
+# Add an endpoint to list available agents
+@app.get("/agents", response_model=dict)
+async def list_agents(request: Request, session_id: str = Depends(get_session_id_dependency)):
+    """Get all available agents (standard, template, and custom)"""
+    session_state = app.state.get_session_state(session_id)
+
+    try:
+        # Get all available agents from database and session
+        available_agents_list = _get_available_agents_list(session_state)
+
+        # Categorize agents
+        standard_agents = ["preprocessing_agent", "statistical_analytics_agent", "sk_learn_agent", "data_viz_agent"]
+
+        # Get template agents from database
+        from src.db.init_db import session_factory
+        from src.agents.agents import load_all_available_templates_from_db
+
+        db_session = session_factory()
+        try:
+            template_agents_dict = load_all_available_templates_from_db(db_session)
+            # template_agents_dict is a dict with template_name as keys
+            template_agents = [template_name for template_name in template_agents_dict.keys()
+                               if template_name not in standard_agents and template_name != 'basic_qa_agent']
+        except Exception as e:
+            logger.log_message(f"Error loading template agents in /agents endpoint: {str(e)}", level=logging.ERROR)
+            template_agents = []
+        finally:
+            db_session.close()
+
+        # Get custom agents from session
+        custom_agents = []
+        if session_state and "ai_system" in session_state:
+            ai_system = session_state["ai_system"]
+            if hasattr(ai_system, 'agents'):
+                custom_agents = [agent for agent in available_agents_list
+                                 if agent not in standard_agents and agent not in template_agents]
+
+        # Ensure template agents are in the available list
+        for template_agent in template_agents:
+            if template_agent not in available_agents_list:
+                available_agents_list.append(template_agent)
+
+        return {
+            "available_agents": available_agents_list,
+            "standard_agents": standard_agents,
+            "template_agents": template_agents,
+            "custom_agents": custom_agents
+        }
+    except Exception as e:
+        logger.log_message(f"Error getting agents list: {str(e)}", level=logging.ERROR)
+        raise HTTPException(status_code=500, detail=f"Error getting agents list: {str(e)}")
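The session_factory()/try/finally/close() dance recurs throughout this commit; contextlib.closing expresses the same guarantee more compactly. A refactoring sketch, not what the commit itself does:

    from contextlib import closing

    with closing(session_factory()) as db_session:
        template_agents_dict = load_all_available_templates_from_db(db_session)
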
 
+
+@app.get("/health", response_model=dict)
+async def health():
+    return {"message": "API is healthy and running"}
 
+
+@app.get("/")
+async def index():
+    return {
+        "title": "Welcome to the AI Analytics API",
+        "message": "Explore our API for advanced analytics and visualization tools designed to empower your data-driven decisions.",
+        "description": "Utilize our powerful agents and models to gain insights from your data effortlessly.",
+        "colors": {
+            "primary": "#007bff",
+            "secondary": "#6c757d",
+            "success": "#28a745",
+            "danger": "#dc3545",
+        },
+        "features": [
+            "Real-time data processing",
+            "Customizable visualizations",
+            "Seamless integration with various data sources",
+            "User-friendly interface for easy navigation",
+            "Custom Analytics",
+        ],
+    }
 
+
+@app.post("/chat_history_name")
+async def chat_history_name(request: dict, session_id: str = Depends(get_session_id_dependency)):
+    query = request.get("query")
+    name = None
+
+    lm = dspy.LM(model="gpt-4o-mini", max_tokens=300, temperature=0.5)
+
+    with dspy.context(lm=lm):
+        name = app.state.get_chat_history_name_agent()(query=str(query))
+
+    return {"name": name.name if name else "New Chat"}
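dspy.context swaps the language model only for the enclosed calls and restores the previous default on exit, which is what lets this endpoint use a cheap naming model without touching the session's configured one. A minimal sketch (naming_module is illustrative):

    import dspy

    cheap_lm = dspy.LM(model="gpt-4o-mini", max_tokens=300, temperature=0.5)
    with dspy.context(lm=cheap_lm):
        result = naming_module(query="plot revenue by region")
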
 
+
+@app.post("/deep_analysis_streaming")
+async def deep_analysis_streaming(
+    request: DeepAnalysisRequest,
+    request_obj: Request,
+    session_id: str = Depends(get_session_id_dependency)
+):
+    """Perform streaming deep analysis with real-time updates"""
+    session_state = app.state.get_session_state(session_id)
+
+    try:
+        # Extract and validate query parameters
+        _update_session_from_query_params(request_obj, session_state)
+
+        # Validate dataset
+        if session_state["datasets"] is None:
+            raise HTTPException(status_code=400, detail=RESPONSE_ERROR_NO_DATASET)
+
+        # Get user_id from session state (if available)
+        user_id = session_state.get("user_id")
+
+        # Generate a UUID for this report (uuid is imported at module level)
+        report_uuid = str(uuid.uuid4())
+
+        # Create initial pending report in the database
+        try:
+            from src.db.init_db import session_factory
+            from src.db.schemas.models import DeepAnalysisReport
+
+            db_session = session_factory()
+            try:
+                # Create a pending report entry
+                new_report = DeepAnalysisReport(
+                    report_uuid=report_uuid,
+                    user_id=user_id,
+                    goal=request.goal,
+                    status="pending",
+                    start_time=datetime.now(UTC),
+                    progress_percentage=0
+                )
+
+                db_session.add(new_report)
+                db_session.commit()
+                db_session.refresh(new_report)
+
+                # Store the report ID in session state for later updates
+                session_state["current_deep_analysis_id"] = new_report.report_id
+                session_state["current_deep_analysis_uuid"] = report_uuid
+            except Exception as e:
+                logger.log_message(f"Error creating initial deep analysis report: {str(e)}", level=logging.ERROR)
+                # Continue even if DB storage fails
+            finally:
+                db_session.close()
+        except Exception as e:
+            logger.log_message(f"Database operation failed: {str(e)}", level=logging.ERROR)
+            # Continue even if DB operation fails
+
+        # Get session-specific model
+        # session_lm = get_session_lm(session_state)
+        session_lm = dspy.LM(model="anthropic/claude-sonnet-4-20250514", max_tokens=7000, temperature=0.5)
+
+        return StreamingResponse(
+            _generate_deep_analysis_stream(session_state, request.goal, session_lm, session_id),
+            media_type='text/event-stream',
+            headers={
+                'Cache-Control': 'no-cache',
+                'Connection': 'keep-alive',
+                'Content-Type': 'text/event-stream',
+                'Access-Control-Allow-Origin': '*',
+                'X-Accel-Buffering': 'no'
+            }
+        )
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.log_message(f"Streaming deep analysis failed: {str(e)}", level=logging.ERROR)
+        raise HTTPException(status_code=500, detail=f"Streaming deep analysis failed: {str(e)}")
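A sketch of how a client might consume this endpoint's newline-delimited stream. httpx is an assumption (any streaming HTTP client works), and the URL and port are placeholders:

    import json
    import httpx

    async def follow_deep_analysis(goal: str):
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream(
                "POST",
                "http://localhost:8000/deep_analysis_streaming",
                json={"goal": goal},
            ) as resp:
                async for line in resp.aiter_lines():
                    if line.strip():
                        update = json.loads(line)
                        print(update.get("step"), update.get("progress"))
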
 
+
+async def _generate_deep_analysis_stream(session_state: dict, goal: str, session_lm, session_id: str):
+    """Generate streaming responses for deep analysis"""
+    # Track the start time for duration calculation
+    start_time = datetime.now(UTC)
+
+    try:
+        # Get dataset info
+        datasets = session_state["datasets"]
+        df = next(iter(datasets.values()))  # assumes the first loaded dataset is the analysis target
+        dtypes_info = pd.DataFrame({
+            'Column': df.columns,
+            'Data Type': df.dtypes.astype(str)
+        }).to_markdown()
+        dataset_info = f"Sample Data:\n{df.head(2).to_markdown()}\n\nData Types:\n{dtypes_info}"
+
+        # Get report info from session state
+        report_id = session_state.get("current_deep_analysis_id")
+        report_uuid = session_state.get("current_deep_analysis_uuid")
+        user_id = session_state.get("user_id")
+
+        # Helper function to update the report in the database
+        async def update_report_in_db(status, progress, step=None, content=None):
+            if not report_id:
+                return
+
+            try:
+                from src.db.init_db import session_factory
+                from src.db.schemas.models import DeepAnalysisReport
+
+                db_session = session_factory()
+                try:
+                    report = db_session.query(DeepAnalysisReport).filter(DeepAnalysisReport.report_id == report_id).first()
+
+                    if report:
+                        report.status = status
+                        report.progress_percentage = progress
+
+                        # Update step-specific fields if provided
+                        if step == "questions" and content:
+                            report.deep_questions = content
+                        elif step == "planning" and content:
+                            report.deep_plan = content
+                        elif step == "analysis" and content:
+                            # For the analysis step, we get the full object with multiple fields
+                            if isinstance(content, dict):
+                                # Update fields from content if they exist
+                                if "deep_questions" in content and content["deep_questions"]:
+                                    report.deep_questions = content["deep_questions"]
+                                if "deep_plan" in content and content["deep_plan"]:
+                                    report.deep_plan = content["deep_plan"]
+                                if "code" in content and content["code"]:
+                                    report.analysis_code = content["code"]
+                                if "final_conclusion" in content and content["final_conclusion"]:
+                                    report.final_conclusion = content["final_conclusion"]
+                                    # Also update the summary from the conclusion
+                                    conclusion = content["final_conclusion"]
+                                    conclusion = conclusion.replace("**Conclusion**", "")
+                                    report.report_summary = conclusion[:200] + "..." if len(conclusion) > 200 else conclusion
+
+                                # Handle JSON fields
+                                if "summaries" in content and content["summaries"]:
+                                    report.summaries = json.dumps(content["summaries"])
+                                if "plotly_figs" in content and content["plotly_figs"]:
+                                    report.plotly_figures = json.dumps(content["plotly_figs"])
+                                if "synthesis" in content and content["synthesis"]:
+                                    report.synthesis = json.dumps(content["synthesis"])
+
+                        # For the final step, update the HTML report
+                        if step == "completed":
+                            if content:
+                                report.html_report = content
+                            else:
+                                logger.log_message("No HTML content provided for completed step", level=logging.WARNING)
+
+                            report.end_time = datetime.now(UTC)
+                            # Ensure start_time is timezone-aware before calculating duration
+                            if report.start_time.tzinfo is None:
+                                start_time_utc = report.start_time.replace(tzinfo=UTC)
+                            else:
+                                start_time_utc = report.start_time
+                            report.duration_seconds = int((report.end_time - start_time_utc).total_seconds())
+
+                        report.updated_at = datetime.now(UTC)
+                        db_session.commit()
+                except Exception as e:
+                    db_session.rollback()
+                    logger.log_message(f"Error updating deep analysis report: {str(e)}", level=logging.ERROR)
+                finally:
+                    db_session.close()
+            except Exception as e:
+                logger.log_message(f"Database operation failed: {str(e)}", level=logging.ERROR)
+
+        # Use the session model for this request
+        with dspy.context(lm=session_lm):
+            # Send initial status
+            yield json.dumps({
+                "step": "initialization",
+                "status": "starting",
+                "message": "Initializing deep analysis...",
+                "progress": 5
+            }) + "\n"
+
+            # Update DB status to running
+            await update_report_in_db("running", 5)
+
+            # Get the deep analyzer - use the correct session_id from the session state
+            logger.log_message(f"Getting deep analyzer for session_id: {session_id}, user_id: {user_id}", level=logging.INFO)
+            deep_analyzer = app.state.get_deep_analyzer(session_id)
+
+            # Make the dataset available globally for code execution
+            globals()['df'] = df
+
+            # Use the new streaming method and forward all progress updates
+            final_result = None
+            async for update in deep_analyzer.execute_deep_analysis_streaming(
+                goal=goal,
+                dataset_info=dataset_info,
+                session_df=df
+            ):
+                # Convert the update to the expected format and yield it
+                if update.get("step") == "questions" and update.get("status") == "completed":
+                    # Update DB with questions
+                    await update_report_in_db("running", update.get("progress", 0), "questions", update.get("content"))
+                elif update.get("step") == "planning" and update.get("status") == "completed":
+                    # Update DB with planning
+                    await update_report_in_db("running", update.get("progress", 0), "planning", update.get("content"))
+                elif update.get("step") == "conclusion" and update.get("status") == "completed":
+                    # Store the final result for later processing
+                    final_result = update.get("final_result")
+
+                    # Convert Plotly figures to JSON format for network transmission
+                    if final_result:
+                        import plotly.io
+                        serialized_return_dict = final_result.copy()
+
+                        # Convert plotly_figs to JSON format
+                        if 'plotly_figs' in serialized_return_dict and serialized_return_dict['plotly_figs']:
+                            json_figs = []
+                            for fig_list in serialized_return_dict['plotly_figs']:
+                                if isinstance(fig_list, list):
+                                    json_fig_list = []
+                                    for fig in fig_list:
+                                        if hasattr(fig, 'to_json'):  # Check if it's a Plotly figure
+                                            json_fig_list.append(plotly.io.to_json(fig))
+                                        else:
+                                            json_fig_list.append(fig)  # Already JSON or another format
+                                    json_figs.append(json_fig_list)
+                                else:
+                                    # Single figure case
+                                    if hasattr(fig_list, 'to_json'):
+                                        json_figs.append(plotly.io.to_json(fig_list))
+                                    else:
+                                        json_figs.append(fig_list)
+                            serialized_return_dict['plotly_figs'] = json_figs
+
+                        # Update DB with analysis results
+                        await update_report_in_db("running", update.get("progress", 0), "analysis", serialized_return_dict)
+
+                        # Generate the HTML report using the original final_result with Figure objects
+                        html_report = None
+                        try:
+                            html_report = generate_html_report(final_result)
+                        except Exception as e:
+                            logger.log_message(f"Error generating HTML report: {str(e)}", level=logging.ERROR)
+                            # Continue even if HTML generation fails
+
+                        # Send the analysis results
+                        yield json.dumps({
+                            "step": "analysis",
+                            "status": "completed",
+                            "content": serialized_return_dict,
+                            "progress": 90
+                        }) + "\n"
+
+                        # Send report generation status
+                        yield json.dumps({
+                            "step": "report",
+                            "status": "processing",
+                            "message": "Generating final report...",
+                            "progress": 95
+                        }) + "\n"
+
+                        # Send final completion
+                        yield json.dumps({
+                            "step": "completed",
+                            "status": "success",
+                            "analysis": serialized_return_dict,
+                            "html_report": html_report,
+                            "progress": 100
+                        }) + "\n"
+
+                        # Update DB with the completed report (with HTML if generated)
+                        if html_report:
+                            logger.log_message(f"Saving HTML report to database, length: {len(html_report)}", level=logging.INFO)
+                        else:
+                            logger.log_message("No HTML report to save to database", level=logging.WARNING)
+                        await update_report_in_db("completed", 100, "completed", html_report)
+                elif update.get("step") == "error":
+                    # Forward the error directly
+                    yield json.dumps(update) + "\n"
+                    await update_report_in_db("failed", 0)
+                    return
+                else:
+                    # Forward all other progress updates
+                    yield json.dumps(update) + "\n"
+
+            # If we somehow exit the loop without a final result, that's an error
+            if not final_result:
+                yield json.dumps({
+                    "step": "error",
+                    "status": "failed",
+                    "message": "Deep analysis completed without final result",
+                    "progress": 0
+                }) + "\n"
+                await update_report_in_db("failed", 0)
+
+    except Exception as e:
+        logger.log_message(f"Error in deep analysis stream: {str(e)}", level=logging.ERROR)
+        yield json.dumps({
+            "step": "error",
+            "status": "failed",
+            "message": f"Deep analysis failed: {str(e)}",
+            "progress": 0
+        }) + "\n"
+
+        # Update DB with error status
+        if 'update_report_in_db' in locals() and session_state.get("current_deep_analysis_id"):
+            await update_report_in_db("failed", 0)
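The serialization above leans on Plotly's lossless JSON round trip: to_json for transport and storage, from_json to rebuild Figure objects later (as the download endpoint below does). The mechanism in isolation:

    import plotly.graph_objects as go
    import plotly.io as pio

    fig = go.Figure(go.Bar(x=["a", "b"], y=[1, 2]))
    payload = pio.to_json(fig)          # JSON string, safe to store or stream
    restored = pio.from_json(payload)   # back to a Figure for HTML export
    assert restored.data[0].type == "bar"
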
 
+
+@app.post("/deep_analysis/download_report")
+async def download_html_report(
+    request: dict,
+    session_id: str = Depends(get_session_id_dependency)
+):
+    """Download the HTML report from a previous deep analysis"""
+    try:
+        analysis_data = request.get("analysis_data")
+        if not analysis_data:
+            raise HTTPException(status_code=400, detail="No analysis data provided")
+
+        # Get the report UUID from the request if available (for saving to DB)
+        report_uuid = request.get("report_uuid")
+        session_state = app.state.get_session_state(session_id)
+
+        # If no report_uuid in the request, try to get it from session state
+        if not report_uuid and session_state.get("current_deep_analysis_uuid"):
+            report_uuid = session_state.get("current_deep_analysis_uuid")
+
+        # Convert JSON-serialized Plotly figures back to Figure objects for HTML generation
+        processed_data = analysis_data.copy()
+
+        if 'plotly_figs' in processed_data and processed_data['plotly_figs']:
+            import plotly.io
+            import plotly.graph_objects as go
+
+            figure_objects = []
+            for fig_list in processed_data['plotly_figs']:
+                if isinstance(fig_list, list):
+                    fig_obj_list = []
+                    for fig_json in fig_list:
+                        if isinstance(fig_json, str):
+                            # Convert the JSON string back to a Figure object
+                            try:
+                                fig_obj = plotly.io.from_json(fig_json)
+                                fig_obj_list.append(fig_obj)
+                            except Exception as e:
+                                logger.log_message(f"Error parsing Plotly JSON: {str(e)}", level=logging.WARNING)
+                                continue
+                        elif hasattr(fig_json, 'to_html'):
+                            # Already a Figure object
+                            fig_obj_list.append(fig_json)
+                    figure_objects.append(fig_obj_list)
+                else:
+                    # Single figure case
+                    if isinstance(fig_list, str):
+                        try:
+                            fig_obj = plotly.io.from_json(fig_list)
+                            figure_objects.append(fig_obj)
+                        except Exception as e:
+                            logger.log_message(f"Error parsing Plotly JSON: {str(e)}", level=logging.WARNING)
+                            continue
+                    elif hasattr(fig_list, 'to_html'):
+                        figure_objects.append(fig_list)
+
+            processed_data['plotly_figs'] = figure_objects
+
+        # Generate the HTML report
+        html_report = generate_html_report(processed_data)
+
+        # Save the report to the database if we have a UUID
+        if report_uuid:
+            try:
+                from src.db.init_db import session_factory
+                from src.db.schemas.models import DeepAnalysisReport
+
+                db_session = session_factory()
+                try:
+                    # Try to find an existing report by UUID
+                    report = db_session.query(DeepAnalysisReport).filter(DeepAnalysisReport.report_uuid == report_uuid).first()
+
+                    if report:
+                        # Update the existing report with the HTML content
+                        report.html_report = html_report
+                        report.updated_at = datetime.now(UTC)
+                        db_session.commit()
+                except Exception as e:
+                    # Ignore DB save failures; the download still proceeds
+                    db_session.rollback()
+                finally:
+                    db_session.close()
+            except Exception as e:
+                logger.log_message(f"Database operation failed when storing HTML report: {str(e)}", level=logging.ERROR)
+                # Continue even if DB storage fails
+
+        # Create a filename with a timestamp
+        timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S")
+        filename = f"deep_analysis_report_{timestamp}.html"
+
+        # Return as a downloadable file
+        return StreamingResponse(
+            iter([html_report.encode('utf-8')]),
+            media_type='text/html',
+            headers={
+                'Content-Disposition': f'attachment; filename="{filename}"',
+                'Content-Type': 'text/html; charset=utf-8'
+            }
+        )
+
+    except Exception as e:
+        logger.log_message(f"Failed to generate HTML report: {str(e)}", level=logging.ERROR)
+        raise HTTPException(status_code=500, detail=f"Failed to generate report: {str(e)}")
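The download trick above generalizes: wrap an in-memory string in a one-element iterator and set Content-Disposition. A reusable sketch of the same pattern:

    from fastapi.responses import StreamingResponse

    def as_download(html: str, filename: str) -> StreamingResponse:
        return StreamingResponse(
            iter([html.encode("utf-8")]),
            media_type="text/html",
            headers={"Content-Disposition": f'attachment; filename="{filename}"'},
        )
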
 
+
+# In the section where routers are included, add the session_router
+app.include_router(chat_router)
+app.include_router(analytics_router)
+app.include_router(code_router)
+app.include_router(session_router)
+app.include_router(feedback_router)
+app.include_router(deep_analysis_router)
+app.include_router(templates_router)
+app.include_router(blog_router)
+
+
+if __name__ == "__main__":
+    port = int(os.environ.get("PORT", 8000))
+    uvicorn.run(app, host="0.0.0.0", port=port)
src/agents/agents.py CHANGED
@@ -791,11 +791,11 @@ class planner_module(dspy.Module):
                 "plan_instructions": {"message": "No agents are currently enabled for analysis. Please enable at least one agent (preprocessing, statistical analysis, machine learning, or visualization) in your template preferences to proceed with data analysis."}
             }
 
-            output = {
-                "complexity": complexity.exact_word_complexity.strip().lower(),
-                "plan": plan.plan,
-                "plan_instructions": plan.plan_instructions
-            }
+        output = {
+            "complexity": complexity.exact_word_complexity.strip().lower(),
+            "plan": plan.plan,
+            "plan_instructions": plan.plan_instructions
+        }
 
         return output
@@ -1670,10 +1670,10 @@ class data_context_gen(dspy.Signature):
             "quantity": {"type": "int", "role": "measure"},
             "unit_price": {"type": "float", "role": "measure"}
         },
-            "metrics": [
+        "metrics": [
             "revenue = quantity * unit_price"
-            ],
-            "use_cases": [
+        ],
+        "use_cases": [
             "Revenue trend analysis",
             "Regional sales comparison"
         ]
src/agents/deep_agents.py CHANGED
@@ -353,7 +353,7 @@ def clean_and_store_code(code, session_df=None):
 
     return output_dict
 
-def score_code(args, code):
+def score_code(args, code, datasets=None):
     """
     Cleans and stores code execution results in a standardized format.
     Safely handles execution errors and returns clean output even if execution fails.
@@ -362,6 +362,7 @@ def score_code(args, code):
     Args:
         args: Arguments (unused but required for dspy.Refine)
         code: Code object with combined_code attribute
+        datasets: Dictionary of datasets from session state (optional)
 
     Returns:
         int: Score (0=error, 1=success, 2=success with plots)
@@ -399,16 +400,34 @@ def score_code(args, code):
     cleaned_code = re.sub(r'\w+_fig\w*\.show\(\s*[^)]*\s*\)', '', cleaned_code)  # *_fig*.show(any_args)
 
     cleaned_code = remove_main_block(cleaned_code)
+
     # Capture stdout using StringIO
     from io import StringIO
     import sys
     import plotly.graph_objects as go
+    import pandas as pd
+    import numpy as np
+    import matplotlib.pyplot as plt
+    import seaborn as sns
+
     stdout_capture = StringIO()
     original_stdout = sys.stdout
     sys.stdout = stdout_capture
 
-    # Execute code in a new namespace to avoid polluting globals
+    # Execute code in a new namespace with datasets available
     local_vars = {}
+
+    # Add datasets to the execution context if provided
+    if datasets:
+        local_vars.update(datasets)
+
+    # Add common imports to the execution context
+    local_vars.update({
+        'pd': pd,
+        'np': np,
+        'go': go,
+        'plt': plt,
+        'sns': sns,
+    })
+
     exec(cleaned_code, globals(), local_vars)
 
     # Capture any plotly figures from local namespace
@@ -902,7 +921,25 @@ class deep_analysis_module(dspy.Module):
             code.append(c.replace('try\n','try:\n'))
 
         # Create deep coder without asyncify to avoid source inspection issues
-        deep_coder = dspy.Refine(module=self.deep_code_synthesizer_sync, N=5, reward_fn=score_code, threshold=1.0, fail_count=10)
+        def create_score_code_with_datasets(datasets):
+            """
+            Creates a reward function that closes over the session datasets.
+
+            Args:
+                datasets: Dictionary of datasets from session_state['datasets']
+
+            Returns:
+                A reward function compatible with dspy.Refine
+            """
+            def score_code_with_datasets(args, pred):
+                return score_code(args, pred, datasets=datasets)
+
+            return score_code_with_datasets
+
+        # Build the score function with the session datasets bound in
+        score_fn = create_score_code_with_datasets(self.datasets)
+        deep_coder = dspy.Refine(module=self.deep_code_synthesizer_sync, N=5, reward_fn=score_fn, threshold=1.0, fail_count=10)
 
         # Check if we have valid API key
         anthropic_key = os.environ.get('ANTHROPIC_API_KEY')
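
The score_code change injects the session's DataFrames and common libraries into the exec namespace so generated code can reference them by name. The core mechanism, reduced to a sketch (sample_df is an illustrative dataset name, not one the app defines):

    import pandas as pd

    def run_generated(code: str, datasets: dict) -> dict:
        local_vars = {"pd": pd, **datasets}   # e.g. {"sample_df": DataFrame}
        exec(code, globals(), local_vars)     # names resolve against local_vars first
        return local_vars

    ns = run_generated("summary = sample_df.describe()",
                       {"sample_df": pd.DataFrame({"x": [1, 2, 3]})})
    print(ns["summary"])
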
src/routes/code_routes.py CHANGED
@@ -104,13 +104,7 @@ def score_code(args, code):
     return 0
 
 
-refine_fixer = dspy.Refine(
-    module=dspy.ChainOfThought(code_fix),
-    N=3,
-    threshold=1.0,
-    reward_fn=score_code,
-    fail_count=3
-)
+# Removed the global refine_fixer declaration; it is now built per request inside fix_code_with_dspy.
 
 
 def format_code(code: str) -> str:
@@ -287,131 +281,145 @@ def extract_relevant_error_section(error_message: str) -> str:
     # If the error is short enough, return as is
     return error_message
 
-async def fix_code_with_dspy(code: str, error: str, dataset_context: str = ""):
-    """
-    Fix code with errors by identifying faulty blocks and fixing them individually using async refine
-
-    Args:
-        code (str): The code containing errors
-        error (str): Error message from execution
-        dataset_context (str): Context about the dataset
-
-    Returns:
-        str: The fixed code
-    """
-    import asyncio
-
-    # Check if we have valid API key
-    anthropic_key = os.environ.get('ANTHROPIC_API_KEY')
-    if not anthropic_key:
-        raise ValueError("ANTHROPIC_API_KEY environment variable is not set")
-
-    # Find the blocks with errors
-    faulty_blocks = identify_error_blocks(code, error)
-
-    if not faulty_blocks:
-        # If no specific errors found, fix the entire code using refine
-        try:
-            # Create the LM instance that will be used
-            # thread_lm = dspy.LM("anthropic/claude-3-5-sonnet-latest", api_key=anthropic_key, max_tokens=2500)
-            thread_lm = MODEL_OBJECTS['claude-3-5-sonnet-latest']
-
-            # Define the blocking function to run in thread
-            def run_refine_fixer():
-                with dspy.context(lm=thread_lm):
-                    return refine_fixer(
-                        dataset_context=str(dataset_context) or "",
-                        faulty_code=str(code) or "",
-                        error=str(error) or "",
-                    )
-
-            # Use asyncio.to_thread for better async integration
-            result = await asyncio.to_thread(run_refine_fixer)
-            return result.fixed_code
-
-        except Exception as e:
-            logger.log_message(f"Error during refine code fixing: {str(e)}", level=logging.ERROR)
-            raise e
-
-    # Start with the original code
-    result_code = code.replace("```python", "").replace("```", "")
-
-    # Fix each faulty block separately using async refine
-    try:
-        thread_lm = MODEL_OBJECTS['claude-3-5-sonnet-latest']
-
-        for agent_name, block_code, specific_error in faulty_blocks:
-            try:
-                # Extract inner code between the markers
-                inner_code_match = re.search(r'#\s+\w+\s+code\s+start\s*\n([\s\S]*?)#\s+\w+\s+code\s+end', block_code)
-                if not inner_code_match:
-                    continue
-
-                inner_code = inner_code_match.group(1).strip()
-
-                # Find markers
-                start_marker_match = re.search(r'(#\s+\w+\s+code\s+start)', block_code)
-                end_marker_match = re.search(r'(#\s+\w+\s+code\s+end)', block_code)
-
-                if not start_marker_match or not end_marker_match:
-                    logger.log_message(f"Could not find start/end markers for {agent_name}", level=logging.WARNING)
-                    continue
-
-                start_marker = start_marker_match.group(1)
-                end_marker = end_marker_match.group(1)
-
-                # Extract the error type and actual error message
-                error_type = ""
-                error_msg = specific_error
-
-                # Look for common error patterns to provide focused context to the LLM
-                error_type_match = re.search(r'(TypeError|ValueError|AttributeError|IndexError|KeyError|NameError):\s*([^\n]+)', specific_error)
-                if error_type_match:
-                    error_type = error_type_match.group(1)
-                    error_msg = f"{error_type}: {error_type_match.group(2)}"
-
-                # Add problem location if available
-                if "Problem at this location:" in specific_error:
-                    problem_section = re.search(r'Problem at this location:([\s\S]*?)(?:\n\n|$)', specific_error)
-                    if problem_section:
-                        error_msg = f"{error_msg}\n\nProblem at: {problem_section.group(1).strip()}"
-
-                # Define the blocking function to run in thread for this specific block
-                def run_block_fixer():
-                    with dspy.context(lm=thread_lm):
-                        return refine_fixer(
-                            dataset_context=str(dataset_context) or "",
-                            faulty_code=str(inner_code) or "",
-                            error=str(error_msg) or "",
-                        )
-
-                # Use asyncio.to_thread for better async integration
-                result = await asyncio.to_thread(run_block_fixer)
-
-                # Ensure the fixed code is properly stripped and doesn't include markers
-                fixed_inner_code = result.fixed_code.strip()
-                if fixed_inner_code.startswith('#') and 'code start' in fixed_inner_code:
-                    # If LLM included markers in response, extract only inner code
-                    inner_match = re.search(r'#\s+\w+\s+code\s+start\s*\n([\s\S]*?)#\s+\w+\s+code\s+end', fixed_inner_code)
-                    if inner_match:
-                        fixed_inner_code = inner_match.group(1).strip()
-
-                # Reconstruct the block with fixed code
-                fixed_block = f"{start_marker}\n\n{fixed_inner_code}\n\n{end_marker}"
-
-                # Replace the original block with the fixed block in the full code
-                result_code = result_code.replace(block_code, fixed_block)
-
-            except Exception as e:
-                # Log the error but continue with other blocks
-                logger.log_message(f"Error fixing {agent_name} block: {str(e)}", level=logging.ERROR)
-                continue
-
-    except Exception as e:
-        logger.log_message(f"Error during async code fixing: {str(e)}", level=logging.ERROR)
-        raise e
-
-    return result_code
+async def fix_code_with_dspy(code: str, error: str, dataset_context: str = "", datasets: dict = None):
+    """
+    Fix code using DSPy with dataset context and the session's actual datasets
+    """
+    import asyncio
+
+    try:
+        # Create a score function bound to the actual datasets
+        def create_score_code_with_datasets(datasets_dict):
+            def score_code_with_datasets(args, pred):
+                return score_code(args, pred, datasets=datasets_dict)
+            return score_code_with_datasets
+
+        # Create refine_fixer with datasets
+        if datasets:
+            score_fn = create_score_code_with_datasets(datasets)
+        else:
+            score_fn = score_code  # Fall back to the dataset-free scorer
+
+        refine_fixer = dspy.Refine(
+            module=dspy.ChainOfThought(code_fix),
+            N=3,
+            threshold=1.0,
+            reward_fn=score_fn,
+            fail_count=3
+        )
+
+        # Check if we have a valid API key
+        anthropic_key = os.environ.get('ANTHROPIC_API_KEY')
+        if not anthropic_key:
+            raise ValueError("ANTHROPIC_API_KEY environment variable is not set")
+
+        # Find the blocks with errors
+        faulty_blocks = identify_error_blocks(code, error)
+
+        if not faulty_blocks:
+            # If no specific errors found, fix the entire code using refine
+            try:
+                # Create the LM instance that will be used
+                # thread_lm = dspy.LM("anthropic/claude-3-5-sonnet-latest", api_key=anthropic_key, max_tokens=2500)
+                thread_lm = MODEL_OBJECTS['claude-3-5-sonnet-latest']
+
+                # Define the blocking function to run in a thread
+                def run_refine_fixer():
+                    with dspy.context(lm=thread_lm):
+                        return refine_fixer(
+                            dataset_context=str(dataset_context) or "",
+                            faulty_code=str(code) or "",
+                            error=str(error) or "",
+                        )
+
+                # Use asyncio.to_thread for better async integration
+                result = await asyncio.to_thread(run_refine_fixer)
+                return result.fixed_code
+
+            except Exception as e:
+                logger.log_message(f"Error during refine code fixing: {str(e)}", level=logging.ERROR)
+                raise e
+
+        # Start with the original code
+        result_code = code.replace("```python", "").replace("```", "")
+
+        # Fix each faulty block separately using async refine
+        try:
+            thread_lm = MODEL_OBJECTS['claude-3-5-sonnet-latest']
+
+            for agent_name, block_code, specific_error in faulty_blocks:
+                try:
+                    # Extract inner code between the markers
+                    inner_code_match = re.search(r'#\s+\w+\s+code\s+start\s*\n([\s\S]*?)#\s+\w+\s+code\s+end', block_code)
+                    if not inner_code_match:
+                        continue
+
+                    inner_code = inner_code_match.group(1).strip()
+
+                    # Find markers
+                    start_marker_match = re.search(r'(#\s+\w+\s+code\s+start)', block_code)
+                    end_marker_match = re.search(r'(#\s+\w+\s+code\s+end)', block_code)
+
+                    if not start_marker_match or not end_marker_match:
+                        logger.log_message(f"Could not find start/end markers for {agent_name}", level=logging.WARNING)
+                        continue
+
+                    start_marker = start_marker_match.group(1)
+                    end_marker = end_marker_match.group(1)
+
+                    # Extract the error type and actual error message
+                    error_type = ""
+                    error_msg = specific_error
+
+                    # Look for common error patterns to provide focused context to the LLM
+                    error_type_match = re.search(r'(TypeError|ValueError|AttributeError|IndexError|KeyError|NameError):\s*([^\n]+)', specific_error)
+                    if error_type_match:
+                        error_type = error_type_match.group(1)
+                        error_msg = f"{error_type}: {error_type_match.group(2)}"
+
+                    # Add problem location if available
+                    if "Problem at this location:" in specific_error:
+                        problem_section = re.search(r'Problem at this location:([\s\S]*?)(?:\n\n|$)', specific_error)
+                        if problem_section:
+                            error_msg = f"{error_msg}\n\nProblem at: {problem_section.group(1).strip()}"
+
+                    # Define the blocking function to run in a thread for this specific block
+                    def run_block_fixer():
+                        with dspy.context(lm=thread_lm):
+                            return refine_fixer(
+                                dataset_context=str(dataset_context) or "",
+                                faulty_code=str(inner_code) or "",
+                                error=str(error_msg) or "",
+                            )
+
+                    # Use asyncio.to_thread for better async integration
+                    result = await asyncio.to_thread(run_block_fixer)
+
+                    # Ensure the fixed code is properly stripped and doesn't include markers
+                    fixed_inner_code = result.fixed_code.strip()
+                    if fixed_inner_code.startswith('#') and 'code start' in fixed_inner_code:
+                        # If the LLM included markers in the response, extract only the inner code
+                        inner_match = re.search(r'#\s+\w+\s+code\s+start\s*\n([\s\S]*?)#\s+\w+\s+code\s+end', fixed_inner_code)
+                        if inner_match:
+                            fixed_inner_code = inner_match.group(1).strip()
+
+                    # Reconstruct the block with the fixed code
+                    fixed_block = f"{start_marker}\n\n{fixed_inner_code}\n\n{end_marker}"
+
+                    # Replace the original block with the fixed block in the full code
+                    result_code = result_code.replace(block_code, fixed_block)
+
+                except Exception as e:
+                    # Log the error but continue with other blocks
+                    logger.log_message(f"Error fixing {agent_name} block: {str(e)}", level=logging.ERROR)
+                    continue
+
+        except Exception as e:
+            logger.log_message(f"Error during async code fixing: {str(e)}", level=logging.ERROR)
+            raise e
+
+        return result_code
+    except Exception as e:
+        logger.log_message(f"Error in fix_code_with_dspy: {str(e)}", level=logging.ERROR)
+        raise e
 
 def get_dataset_context(df):
     """
@@ -756,7 +764,8 @@ async def fix_code(
         fixed_code = await fix_code_with_dspy(
             request_data.code,
            request_data.error,
-            dataset_context
+            dataset_context,
+            session_state["datasets"]  # Pass the actual datasets
        )
 
         fixed_code = format_code_block(fixed_code)
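
Both deep_agents.py and code_routes.py rely on the same workaround: dspy.Refine invokes reward_fn with only (args, pred), so per-session state has to be captured in a closure rather than passed as a parameter. The pattern in isolation:

    def make_reward(datasets):
        # `datasets` is captured by the closure; Refine still sees the
        # two-argument reward signature it expects.
        def reward(args, pred):
            return score_code(args, pred, datasets=datasets)
        return reward

    # refine = dspy.Refine(module=fixer, N=3, threshold=1.0,
    #                      reward_fn=make_reward(session_datasets), fail_count=3)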