Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| # Standard library imports | |
| import asyncio | |
| import json | |
| import logging | |
| import os | |
| import time | |
| import uuid | |
| from io import StringIO | |
| from typing import List, Optional | |
| import ast | |
| import markdown | |
| from bs4 import BeautifulSoup | |
| import pandas as pd | |
| from datetime import datetime, UTC | |
| # Third-party imports | |
| import uvicorn | |
| from dotenv import load_dotenv | |
| from fastapi import ( | |
| Depends, | |
| FastAPI, | |
| File, | |
| Form, | |
| HTTPException, | |
| Request, | |
| UploadFile | |
| ) | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse, StreamingResponse | |
| from fastapi.security import APIKeyHeader | |
| from llama_index.core import Document, VectorStoreIndex | |
| from pydantic import BaseModel | |
| # Local application imports | |
| from scripts.format_response import format_response_to_markdown | |
| from src.agents.agents import * | |
| from src.agents.retrievers.retrievers import * | |
| from src.managers.ai_manager import AI_Manager | |
| from src.managers.session_manager import SessionManager | |
| from src.managers.app_manager import AppState | |
| from src.routes.analytics_routes import router as analytics_router | |
| from src.routes.blog_routes import router as blog_router | |
| from src.routes.chat_routes import router as chat_router | |
| from src.routes.code_routes import router as code_router | |
| from src.routes.feedback_routes import router as feedback_router | |
| from src.routes.session_routes import router as session_router, get_session_id_dependency | |
| from src.routes.deep_analysis_routes import router as deep_analysis_router | |
| from src.routes.templates_routes import router as templates_router | |
| from src.schemas.query_schema import QueryRequest | |
| from src.utils.logger import Logger | |
| # Import deep analysis components directly | |
| # from src.agents.try_deep_agents import deep_analysis_module | |
| from src.agents.deep_agents import deep_analysis_module | |
| from src.utils.generate_report import generate_html_report | |
| from src.utils.model_registry import MODEL_OBJECTS | |
# Module-wide logger; the log_message calls throughout this file go through it.
logger = Logger("app", see_time=True, console_log=True)
# Load variables from a .env file into the process environment so the
# os.getenv lookups further down can see them.
# NOTE(review): this runs after the project imports above. If any of those
# modules read environment variables at import time they will not see .env
# values yet — confirm, and move this call earlier if that is the case.
load_dotenv()
# Request models
# Schema with a single required string field, `goal`.
# NOTE(review): not referenced in the visible part of this file — presumably
# used by the deep-analysis endpoints; confirm against the callers.
# (Documented with comments rather than a class docstring so the generated
# model schema is left untouched.)
class DeepAnalysisRequest(BaseModel):
    goal: str  # required; free-text goal for the analysis
# Response schema mirroring DeepAnalysisRequest.goal plus the generated
# artefacts. All fields are required except `html_report`.
# NOTE(review): the producer of these values is not visible in this excerpt;
# the per-field notes below describe only the declared types.
# (Documented with comments rather than a class docstring so the generated
# model schema is left untouched.)
class DeepAnalysisResponse(BaseModel):
    goal: str
    deep_questions: str
    deep_plan: str
    summaries: List[str]
    code: str
    plotly_figs: List  # element type is not constrained by the schema
    synthesis: List[str]
    final_conclusion: str
    html_report: Optional[str] = None  # optional; omitted reports default to None
def _build_styling_instructions():
    """Return the per-chart-type styling guidance as a list of strings.

    Each element is ``str()`` of a ``{"category", "description", "styling"}``
    dict, i.e. a Python-repr string (single quotes, ``True``), not JSON.
    Key order inside every dict is significant because it determines the
    resulting string.
    """
    # Fragments shared by several chart types.
    title = {"bold_html": True, "include": True}
    default_size = {"height": 1200, "width": 1000}
    axes = {"axes_line_width": 0.2, "grid_width": 1}
    thresholds = {"K": 1000, "M": 100000}
    number_format = {
        "apply_k_m": True,
        "thresholds": thresholds,
        "percentage_decimals": 2,
        "percentage_sign": True,
    }
    # Same as number_format, with two exclusion flags placed before the
    # percentage settings.
    flexible_number_format = {
        "apply_k_m": True,
        "thresholds": thresholds,
        "exclude_if_commas_present": True,
        "exclude_if_not_numeric": True,
        "percentage_decimals": 2,
        "percentage_sign": True,
    }
    numeric_axes_flags = {
        "format_numbers_as_k_m": True,
        "exclude_non_numeric_formatting": True,
    }

    # (category, description, styling) in the order they are exposed.
    charts = [
        (
            "line_charts",
            "Used to visualize trends and changes over time, often with multiple series.",
            {
                "template": "plotly_white",
                **axes,
                "title": title,
                "colors": "use multiple colors if more than one line",
                "annotations": ["min", "max"],
                "number_format": number_format,
                "default_size": default_size,
            },
        ),
        (
            "bar_charts",
            "Useful for comparing discrete categories or groups with bars representing values.",
            {
                "template": "plotly_white",
                **axes,
                "title": title,
                "annotations": ["bar values"],
                "number_format": number_format,
                "default_size": default_size,
            },
        ),
        (
            "histograms",
            "Display the distribution of a data set, useful for returns or frequency distributions.",
            {
                "template": "plotly_white",
                "bin_size": 50,
                **axes,
                "title": title,
                "annotations": ["x values"],
                "number_format": number_format,
                "default_size": default_size,
            },
        ),
        (
            "pie_charts",
            "Show composition or parts of a whole with slices representing categories.",
            {
                "template": "plotly_white",
                "top_categories_to_show": 10,
                "bundle_rest_as": "Others",
                **axes,
                "title": title,
                "annotations": ["x values"],
                "number_format": number_format,
                "default_size": default_size,
            },
        ),
        (
            "tabular_and_generic_charts",
            "Applies to charts where number formatting needs flexibility, including mixed or raw data.",
            {
                "template": "plotly_white",
                **axes,
                "title": title,
                "annotations": ["x values"],
                "number_format": flexible_number_format,
                "default_size": default_size,
            },
        ),
        (
            "heat_maps",
            "Show data density or intensity using color scales on a matrix or grid.",
            {
                "template": "plotly_white",
                "axes_styles": {
                    "line_color": "black",
                    "line_width": 0.2,
                    "grid_width": 1,
                    **numeric_axes_flags,
                },
                "title": title,
                "default_size": default_size,
            },
        ),
        (
            "histogram_distribution",
            "Specialized histogram for return distributions with opacity control.",
            {
                "template": "plotly_white",
                "opacity": 0.75,
                "axes_styles": {"grid_width": 1, **numeric_axes_flags},
                "title": title,
                "default_size": default_size,
            },
        ),
    ]
    return [
        str({"category": category, "description": description, "styling": styling})
        for category, description, styling in charts
    ]


# List of Python-repr strings, one per chart category.
styling_instructions = _build_styling_instructions()
# Fallback model settings, read from the environment once at import time.
# They are used wherever a session carries no model configuration of its own.
DEFAULT_MODEL_CONFIG = dict(
    provider=os.getenv("MODEL_PROVIDER", "openai"),
    model=os.getenv("MODEL_NAME", "gpt-5-mini"),
    api_key=os.getenv("OPENAI_API_KEY"),
    temperature=float(os.getenv("TEMPERATURE", 1.0)),
    max_tokens=int(os.getenv("MAX_TOKENS", 6000)),
    cache=False,
)

# Look up the pre-built LM object for the default model and install it as the
# process-wide DSPy default; per-request overrides are applied separately via
# dspy.context(...) in the request handlers.
default_lm = MODEL_OBJECTS[DEFAULT_MODEL_CONFIG["model"]]
dspy.configure(lm=default_lm, async_max_workers=100)
# Function to get model config from session or use default
def get_session_lm(session_state):
    """Return the LM object to use for a session.

    If ``session_state`` is a dict whose ``"model_config"`` entry is a dict
    containing a ``"model"`` key, the matching entry of ``MODEL_OBJECTS`` is
    tuned with the session's token limit / temperature and returned.
    Otherwise the LM for ``DEFAULT_MODEL_CONFIG["model"]`` is returned
    unmodified.

    Args:
        session_state: Session dict (or any other value, which is treated as
            "no session configuration").

    Returns:
        The ``MODEL_OBJECTS`` entry for the selected model name.

    Raises:
        KeyError: If the selected model name is not in ``MODEL_OBJECTS``.

    NOTE(review): the returned object is the shared ``MODEL_OBJECTS`` entry,
    so the kwargs written here are visible to every session using the same
    model — confirm that this is acceptable.
    """
    default_model = DEFAULT_MODEL_CONFIG["model"]

    model_config = session_state.get("model_config") if isinstance(session_state, dict) else None
    if not (isinstance(model_config, dict) and "model" in model_config):
        # No usable session configuration: fall back to the default model.
        # (Previously this path hit an unbound `model_name`.)
        return MODEL_OBJECTS[default_model]

    # `or "openai"` also covers an explicit None stored under "provider".
    provider = str(model_config.get("provider") or "openai").lower()
    model_name = model_config.get("model", default_model)
    max_tokens = model_config.get("max_tokens", DEFAULT_MODEL_CONFIG["max_tokens"])

    lm = MODEL_OBJECTS[model_name]
    lm_kwargs = lm.kwargs

    # The original test `'gpt-5' or 'o1' not in model_name` was always truthy,
    # which made this branch unreachable; test each marker explicitly.
    is_reasoning_model = any(marker in model_name for marker in ("gpt-5", "o1"))
    if is_reasoning_model and provider == "openai":
        # These models take the limit as max_completion_tokens and are pinned
        # to temperature 1.0 here.
        lm_kwargs["max_completion_tokens"] = max_tokens
        lm_kwargs["temperature"] = 1.0
    else:
        lm_kwargs["max_tokens"] = max_tokens
        lm_kwargs["temperature"] = model_config.get("temperature", DEFAULT_MODEL_CONFIG["temperature"])

    return lm
def clear_console():
    """Clear the terminal: runs ``cls`` when ``os.name`` is ``'nt'``, ``clear`` otherwise."""
    command = 'clear'
    if os.name == 'nt':
        command = 'cls'
    os.system(command)
# Startup guard: Housing.csv (resolved against the current working directory)
# must exist, otherwise the module refuses to import.
housing_csv_path = "Housing.csv"
if os.path.exists(housing_csv_path):
    pass
else:
    logger.log_message(
        f"Housing.csv not found at {os.path.abspath(housing_csv_path)}",
        level=logging.ERROR,
    )
    raise FileNotFoundError(
        f"Housing.csv not found at {os.path.abspath(housing_csv_path)}"
    )
# All agents are now loaded from database - no hardcoded dictionaries needed
# Header scheme for the "X-Session-ID" request header; auto_error=False means
# a missing header does not raise by itself.
X_SESSION_ID = APIKeyHeader(name="X-Session-ID", auto_error=False)
# The AppState class lives in src.managers.app_manager.
# Create the FastAPI application.
app = FastAPI(title="AI Analytics API", version="1.0")
# Replace the app's state object with the project's AppState, built from the
# styling strings, the chat-naming agent and the default model settings.
# NOTE(review): `chat_history_name_agent` is not defined in this file —
# presumably it comes from one of the star imports; confirm.
app.state = AppState(styling_instructions, chat_history_name_agent, DEFAULT_MODEL_CONFIG)
# --- CORS origin selection -------------------------------------------------
# Development mode is the DEFAULT when ENVIRONMENT is unset.
# NOTE(review): that means a deployment that forgets to set ENVIRONMENT gets
# wildcard origins and no origin check — confirm this fail-open default is
# intended.
is_development = os.getenv("ENVIRONMENT", "development").lower() == "development"
allowed_origins = []
frontend_url = os.getenv("FRONTEND_URL", "").strip()
print(f"FRONTEND_URL: {frontend_url}")
if is_development:
    # Any origin is accepted while developing.
    allowed_origins = ["*"]
elif frontend_url:
    # Outside development only the configured frontend is allowed.
    allowed_origins = [frontend_url]
else:
    # Outside development with no FRONTEND_URL: log and allow no origin.
    logger.log_message("CORS misconfigured: FRONTEND_URL not set", level=logging.ERROR)
    allowed_origins = [] # or set a default safe origin
# Origin verification middleware
async def verify_origin_middleware(request: Request, call_next):
    """Reject requests whose Origin header names a different site than FRONTEND_URL.

    Behaviour, in order:
      * development mode: every request is passed through unchecked;
      * an ``origin`` header is present, ``frontend_url`` is non-empty and the
        two differ: respond with 403 ``{"detail": "Not authorized"}``;
      * anything else — including requests with no ``origin`` header and
        deployments where ``frontend_url`` is empty — is passed through.

    NOTE(review): the registration of this function as middleware is not
    visible in this excerpt.
    """
    if not is_development:
        origin = request.headers.get("origin")
        if origin:
            # Log the origin for debugging
            print(f"Request from origin: {origin}")
            if frontend_url and origin != frontend_url:
                print(f"Blocked request from unauthorized origin: {origin}")
                return JSONResponse(
                    status_code=403,
                    content={"detail": "Not authorized"},
                )
    # Allowed (or unchecked) request: hand over to the next handler.
    return await call_next(request)
# CORS middleware (still needed for browser preflight)
# NOTE(review): FastAPI's CORS documentation says wildcard values for origins,
# methods or headers cannot be combined with allow_credentials=True. In
# development `allowed_origins` is ["*"], and methods/headers are always
# ["*"] — confirm the resulting behaviour is what the frontend needs.
app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    allow_origin_regex=None,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
    max_age=600 # Cache preflight requests for 10 minutes (for performance)
)
# Shared constants for the chat endpoints and their helpers.
# Compared against the formatted agent response in chat_with_agent to detect
# an "invalid query" result.
RESPONSE_ERROR_INVALID_QUERY = "Please provide a valid query..."
# Detail message of the 400 returned when the session has no dataset.
RESPONSE_ERROR_NO_DATASET = "No dataset is currently loaded. Please link a dataset before proceeding with your analysis."
# Tokens-per-word factor used by _track_model_usage when tokenization fails.
DEFAULT_TOKEN_RATIO = 1.5
REQUEST_TIMEOUT_SECONDS = 30 # Timeout for LLM requests
# Number of recent chat messages fetched when building query context.
MAX_RECENT_MESSAGES = 5
# Not referenced in the visible part of this file.
DB_BATCH_SIZE = 10 # For future batch DB operations
async def chat_with_agent(
    agent_name: str,
    request: QueryRequest,
    request_obj: Request,
    session_id: str = Depends(get_session_id_dependency)
):
    """Run one or more named agents against the session's dataset.

    ``agent_name`` is either a single agent name or a comma-separated list.
    The query is enriched with recent chat context, executed with the
    session's LM under a ``REQUEST_TIMEOUT_SECONDS`` timeout, formatted to
    markdown and, when the session has a ``user_id`` and the result is not
    the "invalid query" message, recorded via ``_track_model_usage``.

    Args:
        agent_name: Agent name, or several names separated by commas.
        request: Body carrying the user's ``query``.
        request_obj: Raw request; ``chat_id`` / ``user_id`` query parameters
            are copied into the session state.
        session_id: Session identifier resolved by the dependency.

    Returns:
        dict with ``agent_name``, the original ``query``, the formatted
        ``response`` and ``session_id``.

    Raises:
        HTTPException: 400 for a missing dataset, unknown agent or invalid
            ``user_id``; 504 when execution exceeds the timeout; 500 for any
            other failure.

    NOTE(review): the route decorator is not visible in this excerpt.
    """
    session_state = app.state.get_session_state(session_id)
    logger.log_message(f"[DEBUG] chat_with_agent called with agent: '{agent_name}', query: '{request.query[:100]}...'", level=logging.DEBUG)

    try:
        # Extract and validate query parameters
        _update_session_from_query_params(request_obj, session_state)
        logger.log_message(f"[DEBUG] Session state after query params: user_id={session_state.get('user_id')}, chat_id={session_state.get('chat_id')}", level=logging.DEBUG)

        # Validate dataset and agent name
        if session_state["datasets"] is None:
            logger.log_message("[DEBUG] No dataset loaded", level=logging.DEBUG)
            raise HTTPException(status_code=400, detail=RESPONSE_ERROR_NO_DATASET)
        _log_active_dataset(session_state, session_id)
        _validate_agent_name(agent_name, session_state)

        # Record start time for timing
        start_time = time.time()

        # Get chat context and prepare query
        enhanced_query = _prepare_query_with_context(request.query, session_state)
        logger.log_message(f"[DEBUG] Enhanced query length: {len(enhanced_query)}", level=logging.DEBUG)

        response = await _run_requested_agents(agent_name, enhanced_query, session_state)

        formatted_response = format_response_to_markdown(response, agent_name, session_state["datasets"])
        logger.log_message(f"[DEBUG] Formatted response type: {type(formatted_response)}, length: {len(str(formatted_response))}", level=logging.DEBUG)

        if formatted_response == RESPONSE_ERROR_INVALID_QUERY:
            # Invalid-query results are returned as-is and never billed.
            logger.log_message("[DEBUG] Response was invalid query error", level=logging.DEBUG)
        elif session_state.get("user_id"):
            # Track usage statistics
            _track_model_usage(
                session_state=session_state,
                enhanced_query=enhanced_query,
                response=response,
                processing_time_ms=int((time.time() - start_time) * 1000)
            )

        logger.log_message("[DEBUG] chat_with_agent completed successfully", level=logging.DEBUG)
        return {
            "agent_name": agent_name,
            "query": request.query,  # Return original query without context
            "response": formatted_response,
            "session_id": session_id
        }
    except HTTPException:
        # Re-raise HTTP exceptions to preserve status codes
        raise
    except asyncio.TimeoutError:
        logger.log_message("[ERROR] Timeout error in chat_with_agent", level=logging.ERROR)
        raise HTTPException(status_code=504, detail="Request timed out. Please try a simpler query.")
    except Exception as e:
        import traceback
        logger.log_message(f"[ERROR] Unexpected error in chat_with_agent: {type(e).__name__}: {str(e)}", level=logging.ERROR)
        logger.log_message(f"[ERROR] Full traceback: {traceback.format_exc()}", level=logging.ERROR)
        raise HTTPException(status_code=500, detail="An unexpected error occurred. Please try again later.") from e


def _log_active_dataset(session_state: dict, session_id: str) -> None:
    """Log which dataset (the last key of session_state["datasets"]) will be analysed."""
    datasets = session_state["datasets"]
    dataset_names = list(datasets.keys())
    if not dataset_names:
        logger.log_message(f"[ANALYSIS] No datasets available in session {session_id}", level=logging.WARNING)
        return

    current_dataset_name = dataset_names[-1]  # the last (most recent) dataset
    rows, columns = datasets[current_dataset_name].shape[:2]
    session_name = session_state.get("name", "")
    is_default_dataset = (
        (current_dataset_name == "df" and session_name == "Housing.csv")
        or current_dataset_name == "Housing.csv"
    )
    if is_default_dataset:
        logger.log_message(f"[ANALYSIS] Using DEFAULT dataset 'Housing.csv' for analysis (shape: {rows} rows, {columns} columns)", level=logging.INFO)
        logger.log_message("[ANALYSIS] Reason: No custom dataset uploaded yet - using default Housing.csv dataset", level=logging.INFO)
    else:
        logger.log_message(f"[ANALYSIS] Using CUSTOM dataset '{current_dataset_name}' for analysis (shape: {rows} rows, {columns} columns)", level=logging.INFO)
        logger.log_message("[ANALYSIS] This is a user-uploaded dataset, not the default", level=logging.INFO)


async def _run_requested_agents(agent_name: str, enhanced_query: str, session_state: dict):
    """Execute the requested agent(s) with the session LM and return the raw response.

    If any requested agent is neither standard nor template, all requested
    agents run through the session's ``ai_system``; otherwise they run
    through ``auto_analyst_ind`` with a short-lived database session.
    """
    # Names are stripped only in the comma-separated form, as before.
    if "," in agent_name:
        agent_list = [name.strip() for name in agent_name.split(",")]
    else:
        agent_list = [agent_name]

    # Classify each agent once; _is_template_agent performs a DB lookup, so it
    # is only consulted for names that are not standard agents.
    standard_agents, template_agents, custom_agents = [], [], []
    for name in agent_list:
        if _is_standard_agent(name):
            standard_agents.append(name)
        elif _is_template_agent(name):
            template_agents.append(name)
        else:
            custom_agents.append(name)
    logger.log_message(f"[DEBUG] Agent categorization - standard: {standard_agents}, template: {template_agents}, custom: {custom_agents}", level=logging.DEBUG)

    if custom_agents:
        # If any custom agents, use session AI system for all
        ai_system = session_state["ai_system"]
        session_lm = get_session_lm(session_state)
        with dspy.context(lm=session_lm):
            response = await asyncio.wait_for(
                _execute_custom_agents(ai_system, agent_list, enhanced_query),
                timeout=REQUEST_TIMEOUT_SECONDS
            )
        logger.log_message(f"[DEBUG] Custom agents response type: {type(response)}, keys: {list(response.keys()) if isinstance(response, dict) else 'not a dict'}", level=logging.DEBUG)
        return response

    # All standard/template agents - use auto_analyst_ind which loads from DB
    from src.db.init_db import session_factory

    user_id = session_state.get("user_id")
    logger.log_message(f"[DEBUG] Using auto_analyst_ind for agents {agent_list} with user_id: {user_id}", level=logging.DEBUG)
    db_session = session_factory()
    try:
        agent = auto_analyst_ind(agents=[], retrievers=session_state["retrievers"], user_id=user_id, db_session=db_session)
        session_lm = get_session_lm(session_state)
        with dspy.context(lm=session_lm):
            response = await asyncio.wait_for(
                agent.forward(enhanced_query, ",".join(agent_list)),
                timeout=REQUEST_TIMEOUT_SECONDS
            )
        logger.log_message(f"[DEBUG] auto_analyst_ind response type: {type(response)}, content: {str(response)[:200]}...", level=logging.DEBUG)
        return response
    finally:
        # Always release the DB session, including on timeout.
        db_session.close()
async def chat_with_all(
    request: QueryRequest,
    request_obj: Request,
    session_id: str = Depends(get_session_id_dependency)
):
    """Start a server-sent-event stream that answers the query with the session's AI system.

    Args:
        request: Body carrying the user's ``query``.
        request_obj: Raw request; ``chat_id`` / ``user_id`` query parameters
            are copied into the session state.
        session_id: Session identifier resolved by the dependency.

    Returns:
        StreamingResponse (``text/event-stream``) fed by
        ``_generate_streaming_responses``.

    Raises:
        HTTPException: 400 for a missing dataset or invalid ``user_id``;
            500 when the AI system is missing or on any unexpected error.

    NOTE(review): the route decorator is not visible in this excerpt.
    """
    session_state = app.state.get_session_state(session_id)

    try:
        # Extract and validate query parameters
        _update_session_from_query_params(request_obj, session_state)

        # Validate dataset
        if session_state["datasets"] is None:
            raise HTTPException(status_code=400, detail=RESPONSE_ERROR_NO_DATASET)
        if session_state["ai_system"] is None:
            raise HTTPException(status_code=500, detail="AI system not properly initialized.")

        # Get session-specific model
        session_lm = get_session_lm(session_state)

        # Create streaming response
        # NOTE(review): the hard-coded 'Access-Control-Allow-Origin: *' header
        # is independent of the CORS origins configured on the app — confirm
        # whether it should be dropped in favour of the CORS middleware.
        return StreamingResponse(
            _generate_streaming_responses(session_state, request.query, session_lm),
            media_type='text/event-stream',
            headers={
                'Cache-Control': 'no-cache',
                'Connection': 'keep-alive',
                'Content-Type': 'text/event-stream',
                'Access-Control-Allow-Origin': '*',
                'X-Accel-Buffering': 'no'
            }
        )
    except HTTPException:
        # Re-raise HTTP exceptions to preserve status codes
        raise
    except Exception as e:
        # Log before masking the error, mirroring chat_with_agent; previously
        # the cause was discarded without a trace.
        import traceback
        logger.log_message(f"[ERROR] Unexpected error in chat_with_all: {type(e).__name__}: {str(e)}", level=logging.ERROR)
        logger.log_message(f"[ERROR] Full traceback: {traceback.format_exc()}", level=logging.ERROR)
        raise HTTPException(status_code=500, detail="An unexpected error occurred. Please try again later.") from e
# Helper functions to reduce duplication and improve modularity
def _update_session_from_query_params(request_obj: Request, session_state: dict):
    """Copy integer ``chat_id`` / ``user_id`` query parameters into ``session_state``.

    The two parameters are treated differently on bad input: a non-integer
    ``chat_id`` is logged as a warning and ignored, while a non-integer
    ``user_id`` aborts the request.

    Raises:
        HTTPException: 400 when ``user_id`` is present but not an integer.
    """
    params = request_obj.query_params

    if "chat_id" in params:
        try:
            session_state["chat_id"] = int(params.get("chat_id"))
        except (ValueError, TypeError):
            # Keep whatever chat_id the session already had.
            logger.log_message("Invalid chat_id parameter", level=logging.WARNING)

    if "user_id" in params:
        try:
            session_state["user_id"] = int(params["user_id"])
        except (ValueError, TypeError):
            raise HTTPException(
                status_code=400,
                detail="Invalid user_id in query params. Please provide a valid integer."
            )
def _validate_agent_name(agent_name: str, session_state: dict = None):
    """Ensure every requested agent exists, raising a 400 for the first one that does not.

    ``agent_name`` may be a single name or a comma-separated list; in the
    list form each name is stripped of surrounding whitespace before the
    check.

    Raises:
        HTTPException: 400 naming the missing agent and listing the
            available ones.
    """
    logger.log_message(f"[DEBUG] Validating agent name: '{agent_name}'", level=logging.DEBUG)

    is_multi = "," in agent_name
    if is_multi:
        candidates = [part.strip() for part in agent_name.split(",")]
        logger.log_message(f"[DEBUG] Multiple agents detected: {candidates}", level=logging.DEBUG)
        label = "Agent"
    else:
        candidates = [agent_name]
        label = "Single agent"

    for candidate in candidates:
        found = _is_agent_available(candidate, session_state)
        logger.log_message(f"[DEBUG] {label} '{candidate}' availability: {found}", level=logging.DEBUG)
        if found:
            continue
        available_agents = _get_available_agents_list(session_state)
        logger.log_message(f"[DEBUG] Agent '{candidate}' not found. Available: {available_agents}", level=logging.DEBUG)
        raise HTTPException(
            status_code=400,
            detail=f"Agent '{candidate}' not found. Available agents: {available_agents}"
        )

    logger.log_message(f"[DEBUG] Agent validation passed for: '{agent_name}'", level=logging.DEBUG)
def _is_agent_available(agent_name: str, session_state: dict = None) -> bool:
    """Return True if the agent is a standard agent, a template agent, or a
    custom agent registered on the session's ``ai_system``."""
    # Cheap in-memory check first; the template check is only reached for
    # non-standard names.
    if _is_standard_agent(agent_name) or _is_template_agent(agent_name):
        return True

    # Custom agents live on the session's AI system, if there is one.
    if not session_state or "ai_system" not in session_state:
        return False
    ai_system = session_state["ai_system"]
    return hasattr(ai_system, 'agents') and agent_name in ai_system.agents
def _get_available_agents_list(session_state: dict = None) -> list:
    """Return the four core agent names followed by template names loaded from the database.

    ``session_state`` is accepted but not used. If loading templates fails,
    the error is logged and only the names gathered so far are returned.
    """
    from src.db.init_db import session_factory
    from src.agents.agents import load_all_available_templates_from_db

    # Core agents (always available)
    agent_names = ["preprocessing_agent", "statistical_analytics_agent", "sk_learn_agent", "data_viz_agent"]

    db_session = session_factory()
    try:
        templates = load_all_available_templates_from_db(db_session)
        # Keys are template names; skip core names and 'basic_qa_agent'.
        extra_names = []
        for template_name in templates.keys():
            if template_name in agent_names or template_name == 'basic_qa_agent':
                continue
            extra_names.append(template_name)
        agent_names.extend(extra_names)
    except Exception as e:
        logger.log_message(f"Error loading template agents: {str(e)}", level=logging.ERROR)
    finally:
        db_session.close()

    return agent_names
| def _is_standard_agent(agent_name: str) -> bool: | |
| """Check if agent is one of the 4 core standard agents""" | |
| standard_agents = ["preprocessing_agent", "statistical_analytics_agent", "sk_learn_agent", "data_viz_agent"] | |
| return agent_name in standard_agents | |
def _is_template_agent(agent_name: str) -> bool:
    """Return True if an active ``AgentTemplate`` row with this template name exists.

    Opens and closes its own database session. Any failure (import, session
    creation or query) is logged and reported as False.
    """
    try:
        from src.db.init_db import session_factory
        from src.db.schemas.models import AgentTemplate

        db_session = session_factory()
        try:
            query = db_session.query(AgentTemplate).filter(
                AgentTemplate.template_name == agent_name,
                AgentTemplate.is_active == True
            )
            found = query.first() is not None
        finally:
            db_session.close()
        return found
    except Exception as e:
        logger.log_message(f"Error checking if {agent_name} is template: {str(e)}", level=logging.ERROR)
        return False
| async def _execute_custom_agents(ai_system, agent_names: list, query: str): | |
| """Execute custom agents using the session's AI system""" | |
| try: | |
| # For custom agents, we need to use the AI system's execute_agent method | |
| agent_results = [ai_system] | |
| if len(agent_names) == 1: | |
| # Single custom agent | |
| agent_name = agent_names[0] | |
| # Prepare inputs for the custom agent (similar to standard agents like data_viz_agent) | |
| dict_ = {} | |
| dict_['dataset'] = ai_system.dataset.retrieve(query)[0].text | |
| dict_['styling_index'] = ai_system.styling_index.retrieve(query)[0].text | |
| dict_['goal'] = query | |
| dict_['Agent_desc'] = str(ai_system.agent_desc) | |
| # Get input fields for this agent | |
| if agent_name in ai_system.agent_inputs: | |
| inputs = {x: dict_[x] for x in ai_system.agent_inputs[agent_name] if x in dict_} | |
| # Execute the custom agent | |
| agent_name_result, result_dict = await ai_system.agents[agent_name](**inputs) | |
| return {agent_name_result: result_dict} | |
| else: | |
| logger.log_message(f"Agent '{agent_name}' not found in ai_system.agent_inputs", level=logging.ERROR) | |
| return {"error": f"Agent '{agent_name}' input configuration not found"} | |
| else: | |
| # Multiple agents - execute sequentially | |
| results = {} | |
| for agent_name in agent_names: | |
| single_result = await _execute_custom_agents(ai_system, [agent_name], query) | |
| results.update(single_result) | |
| return results | |
| except Exception as e: | |
| logger.log_message(f"Error in _execute_custom_agents: {str(e)}", level=logging.ERROR) | |
| return {"error": f"Error executing custom agents: {str(e)}"} | |
| def _prepare_query_with_context(query: str, session_state: dict) -> str: | |
| """Prepare the query with chat context from previous messages""" | |
| chat_id = session_state.get("chat_id") | |
| if not chat_id: | |
| return query | |
| # Get chat manager from app state | |
| chat_manager = app.state._session_manager.chat_manager | |
| # Get recent messages | |
| recent_messages = chat_manager.get_recent_chat_history(chat_id, limit=MAX_RECENT_MESSAGES) | |
| # Extract response history | |
| chat_context = chat_manager.extract_response_history(recent_messages) | |
| # Append context to the query if available | |
| if chat_context: | |
| return f"### Current Query:\n{query}\n\n{chat_context}" | |
| return query | |
def _track_model_usage(session_state: dict, enhanced_query: str, response, processing_time_ms: int):
    """Record token usage and cost for a completed, non-streaming request.

    Best-effort: any failure is logged and swallowed so that usage tracking
    can never fail the request itself.

    Args:
        session_state: Session dict; ``model_config``, ``user_id`` and
            ``chat_id`` are read from it.
        enhanced_query: The prompt text that was sent (query plus context).
        response: The raw agent response; its ``str()`` form is measured.
        processing_time_ms: Wall-clock duration of the request in ms.
    """
    try:
        ai_manager = app.state.get_ai_manager()

        # Get model configuration. `or` (rather than a .get default) also
        # covers a session whose "model_config" is present but None/empty,
        # which previously raised and silently dropped the usage record.
        model_config = session_state.get("model_config") or DEFAULT_MODEL_CONFIG
        model_name = model_config.get("model", DEFAULT_MODEL_CONFIG["model"])
        provider = ai_manager.get_provider_for_model(model_name)

        response_text = str(response)

        # Calculate token usage
        try:
            # Try exact tokenization
            prompt_tokens = len(ai_manager.tokenizer.encode(enhanced_query))
            completion_tokens = len(ai_manager.tokenizer.encode(response_text))
        except Exception as token_error:
            # Fall back to a words * DEFAULT_TOKEN_RATIO estimate
            logger.log_message(f"Tokenization error: {str(token_error)}", level=logging.WARNING)
            prompt_tokens = int(len(enhanced_query.split()) * DEFAULT_TOKEN_RATIO)
            completion_tokens = int(len(response_text.split()) * DEFAULT_TOKEN_RATIO)
        total_tokens = prompt_tokens + completion_tokens

        # Calculate cost
        cost = ai_manager.calculate_cost(model_name, prompt_tokens, completion_tokens)

        # Save usage to database
        ai_manager.save_usage_to_db(
            user_id=session_state.get("user_id"),
            chat_id=session_state.get("chat_id"),
            model_name=model_name,
            provider=provider,
            prompt_tokens=int(prompt_tokens),
            completion_tokens=int(completion_tokens),
            total_tokens=int(total_tokens),
            query_size=len(enhanced_query),
            response_size=len(response_text),
            cost=round(cost, 7),
            request_time_ms=processing_time_ms,
            is_streaming=False
        )
    except Exception as e:
        # Log but don't fail the request if usage tracking fails
        logger.log_message(f"Failed to track model usage: {str(e)}", level=logging.ERROR)
def _stream_chunk(agent: str, content, status: str) -> str:
    """Serialize one agent message as a newline-terminated JSON line."""
    return json.dumps({"agent": agent, "content": content, "status": status}) + "\n"


async def _generate_streaming_responses(session_state: dict, query: str, session_lm):
    """Stream planner and agent output for the chat_with_all endpoint.

    The query is first enriched with chat context, a plan is requested from
    the session's AI system, and the plan is then executed agent by agent
    under ``session_lm``.

    Args:
        session_state: Per-session state; ``ai_system``, ``datasets``,
            ``user_id`` and ``model_config`` are read from it.
        query: The raw user query.
        session_lm: Language model passed to ``dspy.context`` while the plan
            executes.

    Yields:
        Newline-terminated JSON strings with ``agent``, ``content`` and
        ``status`` (``"success"`` or ``"error"``) keys.

    Once streaming has started an exception can no longer become an HTTP
    error response, so failures are logged and reported to the client as a
    final ``"error"`` chunk instead of silently truncating the stream.
    """
    overall_start_time = time.time()
    # Collected for a future batch DB write; nothing in this function
    # persists these records yet.
    usage_records = []

    try:
        # Add chat context from previous messages
        enhanced_query = _prepare_query_with_context(query, session_state)

        # Get the plan - the planner is async, so it must be awaited
        plan_response = await session_state["ai_system"].get_plan(enhanced_query)
        plan_description = format_response_to_markdown(
            {"analytical_planner": plan_response},
            datasets=session_state["datasets"]
        )

        # An invalid query ends the stream right after the planner message.
        if plan_description == RESPONSE_ERROR_INVALID_QUERY:
            yield _stream_chunk("Analytical Planner", plan_description, "error")
            return

        yield _stream_chunk(
            "Analytical Planner",
            plan_description,
            "success" if plan_description else "error",
        )

        # Track planner usage (only for identified users)
        if session_state.get("user_id"):
            planner_tokens = _estimate_tokens(ai_manager=app.state.ai_manager,
                                              input_text=enhanced_query,
                                              output_text=plan_description)
            usage_records.append(_create_usage_record(
                session_state=session_state,
                model_name=session_state.get("model_config", DEFAULT_MODEL_CONFIG)["model"],
                prompt_tokens=planner_tokens["prompt"],
                completion_tokens=planner_tokens["completion"],
                query_size=len(enhanced_query),
                response_size=len(plan_description),
                processing_time_ms=int((time.time() - overall_start_time) * 1000),
                is_streaming=False
            ))

        logger.log_message(f"Plan response: {plan_response}", level=logging.INFO)
        logger.log_message(f"Plan response type: {type(plan_response)}", level=logging.INFO)

        # Execute the plan under the session-specific language model
        with dspy.context(lm=session_lm):
            async for agent_name, inputs, response in session_state["ai_system"].execute_plan(enhanced_query, plan_response):
                # Sentinel agent names signal that no usable plan exists.
                if agent_name == "plan_not_found":
                    yield _stream_chunk(
                        "Analytical Planner",
                        "**No plan found**\n\nPlease try again with a different query or try using a different model.",
                        "error",
                    )
                    return

                if agent_name == "plan_not_formated_correctly":
                    yield _stream_chunk(
                        "Analytical Planner",
                        "**Something went wrong with formatting, retry the query!**",
                        "error",
                    )
                    return

                formatted_response = format_response_to_markdown(
                    {agent_name: response},
                    datasets=session_state["datasets"]
                )

                # Agent names may carry a "__<provider>" suffix; clients only
                # see the part before it.
                yield _stream_chunk(
                    agent_name.split("__")[0] if "__" in agent_name else agent_name,
                    formatted_response,
                    "success" if response else "error",
                )

                # Handle agent errors: report and move on to the next agent
                if isinstance(response, dict) and "error" in response:
                    yield _stream_chunk(
                        agent_name,
                        f"**Error in {agent_name}**: {response['error']}",
                        "error",
                    )
                    continue

                if formatted_response == RESPONSE_ERROR_INVALID_QUERY:
                    yield _stream_chunk(agent_name, formatted_response, "error")
                    continue

                # Track agent usage for future batch DB write
                if session_state.get("user_id"):
                    agent_tokens = _estimate_tokens(
                        ai_manager=app.state.ai_manager,
                        input_text=str(inputs),
                        output_text=str(response)
                    )

                    # The code combiner encodes its provider in the agent name
                    if "code_combiner_agent" in agent_name and "__" in agent_name:
                        provider = agent_name.split("__")[1]
                        model_name = _get_model_name_for_provider(provider)
                    else:
                        model_name = session_state.get("model_config", DEFAULT_MODEL_CONFIG)["model"]

                    usage_records.append(_create_usage_record(
                        session_state=session_state,
                        model_name=model_name,
                        prompt_tokens=agent_tokens["prompt"],
                        completion_tokens=agent_tokens["completion"],
                        query_size=len(str(inputs)),
                        response_size=len(str(response)),
                        processing_time_ms=int((time.time() - overall_start_time) * 1000),
                        is_streaming=True
                    ))
    except asyncio.TimeoutError:
        logger.log_message("Timed out while generating streaming responses", level=logging.ERROR)
        yield _stream_chunk(
            "planner",
            "The request timed out. Please try a simpler query.",
            "error",
        )
    except Exception as e:
        # Details go to the log only; the client gets a generic message so
        # that no model configuration leaks into the response.
        logger.log_message(f"Error in streaming response: {str(e)}", level=logging.ERROR)
        yield _stream_chunk(
            "planner",
            "An error occurred while generating responses. Please try again!",
            "error",
        )
| def _estimate_tokens(ai_manager, input_text: str, output_text: str) -> dict: | |
| """Estimate token counts, with fallback for tokenization errors""" | |
| try: | |
| # Try exact tokenization | |
| prompt_tokens = len(ai_manager.tokenizer.encode(input_text)) | |
| completion_tokens = len(ai_manager.tokenizer.encode(output_text)) | |
| except Exception: | |
| # Fall back to estimation | |
| prompt_words = len(input_text.split()) | |
| completion_words = len(output_text.split()) | |
| prompt_tokens = int(prompt_words * DEFAULT_TOKEN_RATIO) | |
| completion_tokens = int(completion_words * DEFAULT_TOKEN_RATIO) | |
| return { | |
| "prompt": prompt_tokens, | |
| "completion": completion_tokens, | |
| "total": prompt_tokens + completion_tokens | |
| } | |
def _create_usage_record(session_state: dict, model_name: str, prompt_tokens: int,
                         completion_tokens: int, query_size: int, response_size: int,
                         processing_time_ms: int, is_streaming: bool) -> dict:
    """Assemble one usage record, resolving provider and cost for the model.

    Args:
        session_state: Per-session state; ``user_id`` and ``chat_id`` are read.
        model_name: Model that handled the call.
        prompt_tokens: Prompt token count.
        completion_tokens: Completion token count.
        query_size: Length of the input text.
        response_size: Length of the output text.
        processing_time_ms: Elapsed time in milliseconds.
        is_streaming: Whether the call was part of a streamed response.

    Returns:
        A dict whose keys match the keyword arguments that
        ``_track_model_usage`` passes to ``save_usage_to_db``.
    """
    manager = app.state.get_ai_manager()
    resolved_provider = manager.get_provider_for_model(model_name)
    estimated_cost = manager.calculate_cost(model_name, prompt_tokens, completion_tokens)

    record = dict(
        user_id=session_state.get("user_id"),
        chat_id=session_state.get("chat_id"),
        model_name=model_name,
        provider=resolved_provider,
        prompt_tokens=int(prompt_tokens),
        completion_tokens=int(completion_tokens),
        total_tokens=int(prompt_tokens + completion_tokens),
        query_size=query_size,
        response_size=response_size,
        cost=round(estimated_cost, 7),
        request_time_ms=processing_time_ms,
        is_streaming=is_streaming,
    )
    return record
| def _get_model_name_for_provider(provider: str) -> str: | |
| """Get the model name for a provider""" | |
| provider_model_map = { | |
| "openai": "o3-mini", | |
| "anthropic": "claude-3-7-sonnet-latest", | |
| "gemini": "gemini-2.5-pro-preview-03-25" | |
| } | |
| return provider_model_map.get(provider, "o3-mini") | |
# Endpoint handler that lists the agents available to a session
async def list_agents(request: Request, session_id: str = Depends(get_session_id_dependency)):
    """Return the available agents grouped as standard, template and custom.

    Template agents are loaded from the database; if that fails the error is
    logged and the template list is empty. Custom agents are the available
    agents that are neither standard nor template agents, and are only
    reported when the session's AI system exposes an ``agents`` attribute.

    Raises:
        HTTPException: 500 when building the listing fails.
    """
    session_state = app.state.get_session_state(session_id)

    try:
        # Everything the session can currently use (database + session).
        available_agents_list = _get_available_agents_list(session_state)

        standard_agents = [
            "preprocessing_agent",
            "statistical_analytics_agent",
            "sk_learn_agent",
            "data_viz_agent",
        ]

        # Template agents come from the database.
        from src.db.init_db import session_factory
        from src.agents.agents import load_all_available_templates_from_db

        db_session = session_factory()
        try:
            # The loader returns a dict keyed by template name.
            templates_by_name = load_all_available_templates_from_db(db_session)
            template_agents = []
            for name in templates_by_name.keys():
                if name in standard_agents or name == 'basic_qa_agent':
                    continue
                template_agents.append(name)
        except Exception as e:
            logger.log_message(f"Error loading template agents in /agents endpoint: {str(e)}", level=logging.ERROR)
            template_agents = []
        finally:
            db_session.close()

        # Custom agents are whatever is left once standard and template
        # agents are removed from the available list.
        has_agent_system = (
            session_state
            and "ai_system" in session_state
            and hasattr(session_state["ai_system"], 'agents')
        )
        if has_agent_system:
            custom_agents = [
                name for name in available_agents_list
                if not (name in standard_agents or name in template_agents)
            ]
        else:
            custom_agents = []

        # Ensure template agents are in the available list
        missing_templates = [
            name for name in template_agents if name not in available_agents_list
        ]
        available_agents_list.extend(missing_templates)

        return {
            "available_agents": available_agents_list,
            "standard_agents": standard_agents,
            "template_agents": template_agents,
            "custom_agents": custom_agents
        }
    except Exception as e:
        logger.log_message(f"Error getting agents list: {str(e)}", level=logging.ERROR)
        raise HTTPException(status_code=500, detail=f"Error getting agents list: {str(e)}")
async def health():
    """Report that the API process is up and able to answer requests."""
    status = {"message": "API is healthy and running"}
    return status
async def index():
    """Return the static welcome payload describing the API."""
    palette = {
        "primary": "#007bff",
        "secondary": "#6c757d",
        "success": "#28a745",
        "danger": "#dc3545",
    }
    highlights = [
        "Real-time data processing",
        "Customizable visualizations",
        "Seamless integration with various data sources",
        "User-friendly interface for easy navigation",
        "Custom Analytics",
    ]
    payload = {
        "title": "Welcome to the AI Analytics API",
        "message": "Explore our API for advanced analytics and visualization tools designed to empower your data-driven decisions.",
        "description": "Utilize our powerful agents and models to gain insights from your data effortlessly.",
        "colors": palette,
        "features": highlights,
    }
    return payload
async def chat_history_name(request: dict, session_id: str = Depends(get_session_id_dependency)):
    """Generate a short display name for a chat from its first query.

    Args:
        request: JSON body; its ``query`` entry is the text to name.
        session_id: Resolved by the session dependency (not otherwise used
            in this handler).

    Returns:
        ``{"name": <generated name>}``, or ``{"name": "New Chat"}`` when no
        query was supplied or the agent produced no usable name.
    """
    query = request.get("query")

    # Without a query there is nothing to summarise; stringifying None would
    # send the literal text "None" to the model.
    if query is None or not str(query).strip():
        return {"name": "New Chat"}

    lm = dspy.LM(model="gpt-4o-mini", max_tokens=300, temperature=0.5)
    with dspy.context(lm=lm):
        name = app.state.get_chat_history_name_agent()(query=str(query))

    # Fall back to the default for a missing result or an empty name.
    generated = name.name if name else None
    return {"name": generated or "New Chat"}
async def deep_analysis_streaming(
    request: DeepAnalysisRequest,
    request_obj: Request,
    session_id: str = Depends(get_session_id_dependency)
):
    """Start a deep analysis and stream its progress to the client.

    A pending ``DeepAnalysisReport`` row is created first (best effort) and
    its identifiers are stored in the session state so the stream generator
    can update the same row as the analysis advances.

    Args:
        request: Request body; ``request.goal`` is the analysis goal.
        request_obj: Raw request, used to update the session from query
            parameters.
        session_id: Resolved by the session dependency.

    Returns:
        A ``StreamingResponse`` driven by ``_generate_deep_analysis_stream``.

    Raises:
        HTTPException: 400 when the session has no dataset; 500 for any
            other failure while setting up the stream.
    """
    session_state = app.state.get_session_state(session_id)

    try:
        # Extract and validate query parameters
        _update_session_from_query_params(request_obj, session_state)

        # Validate dataset
        if session_state["datasets"] is None:
            raise HTTPException(status_code=400, detail=RESPONSE_ERROR_NO_DATASET)

        # Get user_id from session state (if available)
        user_id = session_state.get("user_id")

        # Generate a UUID for this report
        import uuid
        report_uuid = str(uuid.uuid4())

        # Forget identifiers left over from an earlier analysis. If creating
        # the new row fails below, later updates must not be written to the
        # previous report.
        session_state.pop("current_deep_analysis_id", None)
        session_state.pop("current_deep_analysis_uuid", None)

        # Create initial pending report in the database
        try:
            from src.db.init_db import session_factory
            from src.db.schemas.models import DeepAnalysisReport

            db_session = session_factory()
            try:
                # Create a pending report entry
                new_report = DeepAnalysisReport(
                    report_uuid=report_uuid,
                    user_id=user_id,
                    goal=request.goal,
                    status="pending",
                    start_time=datetime.now(UTC),
                    progress_percentage=0
                )
                db_session.add(new_report)
                db_session.commit()
                db_session.refresh(new_report)

                # Store the report ID in session state for later updates
                session_state["current_deep_analysis_id"] = new_report.report_id
                session_state["current_deep_analysis_uuid"] = report_uuid
            except Exception as e:
                # Undo the partial transaction before the session is closed.
                db_session.rollback()
                logger.log_message(f"Error creating initial deep analysis report: {str(e)}", level=logging.ERROR)
                # Continue even if DB storage fails
            finally:
                db_session.close()
        except Exception as e:
            logger.log_message(f"Database operation failed: {str(e)}", level=logging.ERROR)
            # Continue even if DB operation fails

        # Deep analysis runs on a fixed model, not on the session's model.
        # session_lm = get_session_lm(session_state)
        session_lm = dspy.LM(model="anthropic/claude-sonnet-4-20250514", max_tokens=7000, temperature=0.5)

        return StreamingResponse(
            _generate_deep_analysis_stream(session_state, request.goal, session_lm, session_id),
            media_type='text/event-stream',
            headers={
                'Cache-Control': 'no-cache',
                'Connection': 'keep-alive',
                'Content-Type': 'text/event-stream',
                'Access-Control-Allow-Origin': '*',
                'X-Accel-Buffering': 'no'
            }
        )
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) pass through unchanged.
        raise
    except Exception as e:
        logger.log_message(f"Streaming deep analysis failed: {str(e)}", level=logging.ERROR)
        raise HTTPException(status_code=500, detail=f"Streaming deep analysis failed: {str(e)}")
def _serialize_plotly_figs(figs):
    """Convert Plotly figures (flat or one level nested) to JSON strings.

    Items exposing ``to_json`` are serialized with ``plotly.io.to_json``;
    anything else is passed through unchanged.
    """
    import plotly.io

    def _to_json(fig):
        return plotly.io.to_json(fig) if hasattr(fig, 'to_json') else fig

    return [
        [_to_json(fig) for fig in entry] if isinstance(entry, list) else _to_json(entry)
        for entry in figs
    ]


async def _generate_deep_analysis_stream(session_state: dict, goal: str, session_lm, session_id: str):
    """Run the deep analysis and stream its progress as JSON lines.

    Progress updates from the deep analyzer are forwarded to the client and
    mirrored into the report row whose id is stored in the session state
    under ``current_deep_analysis_id`` (if there is one).

    Args:
        session_state: Per-session state; ``datasets``, ``description``,
            ``user_id`` and ``current_deep_analysis_id`` are read.
        goal: The analysis goal.
        session_lm: Language model passed to ``dspy.context`` for the run.
        session_id: Session used to look up the deep analyzer.

    Yields:
        Newline-terminated JSON strings describing each step. Failures are
        reported as a ``{"step": "error", ...}`` line, never raised.
    """
    # Captured before the try block so that the failure handler at the
    # bottom can always mark the report as failed.
    report_id = session_state.get("current_deep_analysis_id")
    user_id = session_state.get("user_id")

    async def update_report_in_db(status, progress, step=None, content=None):
        """Best-effort update of the report row; errors are logged only."""
        if not report_id:
            return

        try:
            from src.db.init_db import session_factory
            from src.db.schemas.models import DeepAnalysisReport

            db_session = session_factory()
            try:
                report = db_session.query(DeepAnalysisReport).filter(DeepAnalysisReport.report_id == report_id).first()

                if report:
                    report.status = status
                    report.progress_percentage = progress

                    # Update step-specific fields if provided
                    if step == "questions" and content:
                        report.deep_questions = content
                    elif step == "planning" and content:
                        report.deep_plan = content
                    elif step == "analysis" and content:
                        # The analysis step delivers the full result object
                        if isinstance(content, dict):
                            if content.get("deep_questions"):
                                report.deep_questions = content["deep_questions"]
                            if content.get("deep_plan"):
                                report.deep_plan = content["deep_plan"]
                            if content.get("code"):
                                report.analysis_code = content["code"]
                            if content.get("final_conclusion"):
                                report.final_conclusion = content["final_conclusion"]
                                # The summary is the conclusion without its
                                # heading, cut to 200 characters.
                                conclusion = content["final_conclusion"].replace("**Conclusion**", "")
                                if len(conclusion) > 200:
                                    report.report_summary = conclusion[:200] + "..."
                                else:
                                    report.report_summary = conclusion

                            # Handle JSON fields
                            if content.get("summaries"):
                                report.summaries = json.dumps(content["summaries"])
                            if content.get("plotly_figs"):
                                report.plotly_figures = json.dumps(content["plotly_figs"])
                            if content.get("synthesis"):
                                report.synthesis = json.dumps(content["synthesis"])

                    # For the final step, update the HTML report
                    if step == "completed":
                        if content:
                            report.html_report = content
                        else:
                            logger.log_message("No HTML content provided for completed step", level=logging.WARNING)

                        report.end_time = datetime.now(UTC)
                        # Ensure start_time is timezone-aware before calculating duration
                        if report.start_time.tzinfo is None:
                            start_time_utc = report.start_time.replace(tzinfo=UTC)
                        else:
                            start_time_utc = report.start_time
                        report.duration_seconds = int((report.end_time - start_time_utc).total_seconds())

                    report.updated_at = datetime.now(UTC)
                    db_session.commit()
            except Exception as e:
                db_session.rollback()
                logger.log_message(f"Error updating deep analysis report: {str(e)}", level=logging.ERROR)
            finally:
                db_session.close()
        except Exception as e:
            logger.log_message(f"Database operation failed: {str(e)}", level=logging.ERROR)

    try:
        # Get dataset info
        datasets = session_state["datasets"]
        desc = session_state['description']

        logger.log_message(f"🔍 DEEP ANALYSIS START - datasets type: {type(datasets)}, keys: {list(datasets.keys()) if datasets else 'None'}", level=logging.DEBUG)
        dataset_info = desc
        logger.log_message(f"🔍 DEEP ANALYSIS - dataset_info type: {type(dataset_info)}, length: {len(dataset_info) if isinstance(dataset_info, str) else 'N/A'}", level=logging.DEBUG)
        # str() keeps this debug line from aborting the analysis when the
        # description is not a string.
        logger.log_message(f"🔍 DEEP ANALYSIS - dataset_info content: {str(dataset_info)[:200]}...", level=logging.DEBUG)

        # Use session model for this request
        with dspy.context(lm=session_lm):
            # Send initial status
            yield json.dumps({
                "step": "initialization",
                "status": "starting",
                "message": "Initializing deep analysis...",
                "progress": 5
            }) + "\n"

            # Update DB status to running
            await update_report_in_db("running", 5)

            logger.log_message(f"Getting deep analyzer for session_id: {session_id}, user_id: {user_id}", level=logging.INFO)
            deep_analyzer = app.state.get_deep_analyzer(session_id)

            # Make all datasets available globally for code execution.
            # NOTE(review): this writes user-controlled names into this
            # module's globals and can overwrite names such as ``app`` or
            # ``logger``; it is also shared across sessions. Confirm whether
            # code execution still depends on it.
            for dataset_name, dataset_df in datasets.items():
                globals()[dataset_name] = dataset_df

            final_result = None

            logger.log_message(f"🔍 CALLING DEEP ANALYSIS - goal: {goal[:100]}...", level=logging.DEBUG)
            logger.log_message(f"🔍 CALLING DEEP ANALYSIS - dataset_info type: {type(dataset_info)}, length: {len(dataset_info) if isinstance(dataset_info, str) else 'N/A'}", level=logging.DEBUG)
            logger.log_message(f"🔍 CALLING DEEP ANALYSIS - session_datasets type: {type(datasets)}, keys: {list(datasets.keys()) if datasets else 'None'}", level=logging.DEBUG)

            async for update in deep_analyzer.execute_deep_analysis_streaming(
                goal=goal,
                dataset_info=dataset_info,
                session_datasets=datasets  # Pass all datasets instead of single df
            ):
                step = update.get("step")
                completed = update.get("status") == "completed"

                # NOTE(review): completed "questions"/"planning" updates are
                # stored in the DB but not forwarded to the client - confirm
                # this is intended.
                if step == "questions" and completed:
                    await update_report_in_db("running", update.get("progress", 0), "questions", update.get("content"))
                elif step == "planning" and completed:
                    await update_report_in_db("running", update.get("progress", 0), "planning", update.get("content"))
                elif step == "conclusion" and completed:
                    # Store the final result for later processing
                    final_result = update.get("final_result")

                    if final_result:
                        # Figures are converted to JSON for transmission;
                        # the original objects are kept for the HTML report.
                        serialized_return_dict = final_result.copy()
                        if serialized_return_dict.get('plotly_figs'):
                            serialized_return_dict['plotly_figs'] = _serialize_plotly_figs(
                                serialized_return_dict['plotly_figs']
                            )

                        # Update DB with analysis results
                        await update_report_in_db("running", update.get("progress", 0), "analysis", serialized_return_dict)

                        # Generate HTML report using the original final_result with Figure objects
                        html_report = None
                        try:
                            html_report = generate_html_report(final_result)
                        except Exception as e:
                            logger.log_message(f"Error generating HTML report: {str(e)}", level=logging.ERROR)
                            # Continue even if HTML generation fails

                        # Send the analysis results
                        yield json.dumps({
                            "step": "analysis",
                            "status": "completed",
                            "content": serialized_return_dict,
                            "progress": 90
                        }) + "\n"

                        # Send report generation status
                        yield json.dumps({
                            "step": "report",
                            "status": "processing",
                            "message": "Generating final report...",
                            "progress": 95
                        }) + "\n"

                        # Send final completion
                        yield json.dumps({
                            "step": "completed",
                            "status": "success",
                            "analysis": serialized_return_dict,
                            "html_report": html_report,
                            "progress": 100
                        }) + "\n"

                        # Update DB with completed report (with HTML if generated)
                        if html_report:
                            logger.log_message(f"Saving HTML report to database, length: {len(html_report)}", level=logging.INFO)
                        else:
                            logger.log_message("No HTML report to save to database", level=logging.WARNING)

                        await update_report_in_db("completed", 100, "completed", html_report)
                elif step == "error":
                    # Forward error directly and stop
                    yield json.dumps(update) + "\n"
                    await update_report_in_db("failed", 0)
                    return
                else:
                    # Forward all other progress updates
                    yield json.dumps(update) + "\n"

            # Leaving the loop without a final result is an error
            if not final_result:
                yield json.dumps({
                    "step": "error",
                    "status": "failed",
                    "message": "Deep analysis completed without final result",
                    "progress": 0
                }) + "\n"
                await update_report_in_db("failed", 0)

    except Exception as e:
        logger.log_message(f"Error in deep analysis stream: {str(e)}", level=logging.ERROR)
        yield json.dumps({
            "step": "error",
            "status": "failed",
            "message": f"Deep analysis failed: {str(e)}",
            "progress": 0
        }) + "\n"

        # Update DB with error status (a no-op when no report row exists)
        await update_report_in_db("failed", 0)
def _restore_plotly_figure(item):
    """Return a Plotly figure for ``item``, or ``None`` if it is unusable.

    JSON strings are parsed with ``plotly.io.from_json`` (parse failures are
    logged and yield ``None``); objects exposing ``to_html`` are returned
    as-is; anything else yields ``None``.
    """
    if isinstance(item, str):
        import plotly.io
        try:
            return plotly.io.from_json(item)
        except Exception as e:
            logger.log_message(f"Error parsing Plotly JSON: {str(e)}", level=logging.WARNING)
            return None
    if hasattr(item, 'to_html'):
        return item
    return None


async def download_html_report(
    request: dict,
    session_id: str = Depends(get_session_id_dependency)
):
    """Build an HTML report from earlier deep-analysis data and return it as a download.

    Args:
        request: JSON body with ``analysis_data`` (required) and an optional
            ``report_uuid``; when the UUID is absent the session's
            ``current_deep_analysis_uuid`` is used.
        session_id: Resolved by the session dependency.

    Returns:
        A ``StreamingResponse`` carrying the HTML as an attachment. If a
        report row with the UUID exists, the HTML is also stored on it
        (best effort).

    Raises:
        HTTPException: 400 when ``analysis_data`` is missing; 500 when
            report generation fails.
    """
    try:
        analysis_data = request.get("analysis_data")
        if not analysis_data:
            raise HTTPException(status_code=400, detail="No analysis data provided")

        # Get report UUID from request if available (for saving to DB)
        report_uuid = request.get("report_uuid")
        session_state = app.state.get_session_state(session_id)

        # If no report_uuid in request, try to get it from session state
        if not report_uuid and session_state.get("current_deep_analysis_uuid"):
            report_uuid = session_state.get("current_deep_analysis_uuid")

        # Convert JSON-serialized Plotly figures back to Figure objects for HTML generation
        processed_data = analysis_data.copy()
        if processed_data.get('plotly_figs'):
            figure_objects = []
            for entry in processed_data['plotly_figs']:
                if isinstance(entry, list):
                    # Nested lists keep their grouping; unusable items are dropped.
                    restored = (_restore_plotly_figure(item) for item in entry)
                    figure_objects.append([fig for fig in restored if fig is not None])
                else:
                    # Single figure case
                    fig = _restore_plotly_figure(entry)
                    if fig is not None:
                        figure_objects.append(fig)
            processed_data['plotly_figs'] = figure_objects

        # Generate HTML report
        html_report = generate_html_report(processed_data)

        # Save report to database if we have a UUID
        if report_uuid:
            try:
                from src.db.init_db import session_factory
                from src.db.schemas.models import DeepAnalysisReport

                db_session = session_factory()
                try:
                    # Try to find existing report by UUID
                    report = db_session.query(DeepAnalysisReport).filter(DeepAnalysisReport.report_uuid == report_uuid).first()
                    if report:
                        # Update existing report with HTML content
                        report.html_report = html_report
                        report.updated_at = datetime.now(UTC)
                        db_session.commit()
                except Exception as e:
                    db_session.rollback()
                    logger.log_message(f"Error storing HTML report in database: {str(e)}", level=logging.ERROR)
                finally:
                    db_session.close()
            except Exception as e:
                logger.log_message(f"Database operation failed when storing HTML report: {str(e)}", level=logging.ERROR)
                # Continue even if DB storage fails

        # Create a filename with timestamp
        timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S")
        filename = f"deep_analysis_report_{timestamp}.html"

        # Return as downloadable file
        return StreamingResponse(
            iter([html_report.encode('utf-8')]),
            media_type='text/html',
            headers={
                'Content-Disposition': f'attachment; filename="{filename}"',
                'Content-Type': 'text/html; charset=utf-8'
            }
        )
    except HTTPException:
        # Keep deliberate HTTP errors (the 400 above) from being rewritten
        # as a 500 by the generic handler below.
        raise
    except Exception as e:
        logger.log_message(f"Failed to generate HTML report: {str(e)}", level=logging.ERROR)
        raise HTTPException(status_code=500, detail=f"Failed to generate report: {str(e)}")
# Register every feature router on the application, in this order.
for _feature_router in (
    chat_router,
    analytics_router,
    code_router,
    session_router,
    feedback_router,
    deep_analysis_router,
    templates_router,
    blog_router,
):
    app.include_router(_feature_router)
del _feature_router  # keep the loop variable out of the module namespace

if __name__ == "__main__":
    # The port comes from the PORT environment variable, defaulting to 8000.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 8000)))