# Standard library imports
import ast
import asyncio
import json
import logging
import os
import time
import uuid
from datetime import datetime, UTC
from io import StringIO
from typing import List, Optional

# Third-party imports
import dspy
import markdown
import pandas as pd
import uvicorn
from bs4 import BeautifulSoup
from dotenv import load_dotenv
from fastapi import (
    Depends,
    FastAPI,
    File,
    Form,
    HTTPException,
    Request,
    UploadFile
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import APIKeyHeader
from llama_index.core import Document, VectorStoreIndex
from pydantic import BaseModel
# Local application imports
from scripts.format_response import format_response_to_markdown
from src.agents.agents import *
from src.agents.retrievers.retrievers import *
from src.managers.ai_manager import AI_Manager
from src.managers.session_manager import SessionManager
from src.managers.app_manager import AppState
from src.routes.analytics_routes import router as analytics_router
from src.routes.blog_routes import router as blog_router
from src.routes.chat_routes import router as chat_router
from src.routes.code_routes import router as code_router
from src.routes.feedback_routes import router as feedback_router
from src.routes.session_routes import router as session_router, get_session_id_dependency
from src.routes.deep_analysis_routes import router as deep_analysis_router
from src.routes.templates_routes import router as templates_router
from src.schemas.query_schema import QueryRequest
from src.utils.logger import Logger

# Import deep analysis components directly
from src.agents.deep_agents import deep_analysis_module
from src.utils.generate_report import generate_html_report
from src.utils.model_registry import MODEL_OBJECTS

logger = Logger("app", see_time=True, console_log=True)
load_dotenv()
# Request models
class DeepAnalysisRequest(BaseModel):
    goal: str

class DeepAnalysisResponse(BaseModel):
    goal: str
    deep_questions: str
    deep_plan: str
    summaries: List[str]
    code: str
    plotly_figs: List
    synthesis: List[str]
    final_conclusion: str
    html_report: Optional[str] = None
styling_instructions = [
    {
        "category": "line_charts",
        "description": "Used to visualize trends and changes over time, often with multiple series.",
        "styling": {
            "template": "plotly_white",
            "axes_line_width": 0.2,
            "grid_width": 1,
            "title": {
                "bold_html": True,
                "include": True
            },
            "colors": "use multiple colors if more than one line",
            "annotations": ["min", "max"],
            "number_format": {
                "apply_k_m": True,
                "thresholds": {"K": 1000, "M": 100000},
                "percentage_decimals": 2,
                "percentage_sign": True
            },
            "default_size": {"height": 1200, "width": 1000}
        }
    },
    {
        "category": "bar_charts",
        "description": "Useful for comparing discrete categories or groups with bars representing values.",
        "styling": {
            "template": "plotly_white",
            "axes_line_width": 0.2,
            "grid_width": 1,
            "title": {"bold_html": True, "include": True},
            "annotations": ["bar values"],
            "number_format": {
                "apply_k_m": True,
                "thresholds": {"K": 1000, "M": 100000},
                "percentage_decimals": 2,
                "percentage_sign": True
            },
            "default_size": {"height": 1200, "width": 1000}
        }
    },
    {
        "category": "histograms",
        "description": "Display the distribution of a data set, useful for returns or frequency distributions.",
        "styling": {
            "template": "plotly_white",
            "bin_size": 50,
            "axes_line_width": 0.2,
            "grid_width": 1,
            "title": {"bold_html": True, "include": True},
            "annotations": ["x values"],
            "number_format": {
                "apply_k_m": True,
                "thresholds": {"K": 1000, "M": 100000},
                "percentage_decimals": 2,
                "percentage_sign": True
            },
            "default_size": {"height": 1200, "width": 1000}
        }
    },
    {
        "category": "pie_charts",
        "description": "Show composition or parts of a whole with slices representing categories.",
        "styling": {
            "template": "plotly_white",
            "top_categories_to_show": 10,
            "bundle_rest_as": "Others",
            "axes_line_width": 0.2,
            "grid_width": 1,
            "title": {"bold_html": True, "include": True},
            "annotations": ["x values"],
            "number_format": {
                "apply_k_m": True,
                "thresholds": {"K": 1000, "M": 100000},
                "percentage_decimals": 2,
                "percentage_sign": True
            },
            "default_size": {"height": 1200, "width": 1000}
        }
    },
    {
        "category": "tabular_and_generic_charts",
        "description": "Applies to charts where number formatting needs flexibility, including mixed or raw data.",
        "styling": {
            "template": "plotly_white",
            "axes_line_width": 0.2,
            "grid_width": 1,
            "title": {"bold_html": True, "include": True},
            "annotations": ["x values"],
            "number_format": {
                "apply_k_m": True,
                "thresholds": {"K": 1000, "M": 100000},
                "exclude_if_commas_present": True,
                "exclude_if_not_numeric": True,
                "percentage_decimals": 2,
                "percentage_sign": True
            },
            "default_size": {"height": 1200, "width": 1000}
        }
    },
    {
        "category": "heat_maps",
        "description": "Show data density or intensity using color scales on a matrix or grid.",
        "styling": {
            "template": "plotly_white",
            "axes_styles": {
                "line_color": "black",
                "line_width": 0.2,
                "grid_width": 1,
                "format_numbers_as_k_m": True,
                "exclude_non_numeric_formatting": True
            },
            "title": {"bold_html": True, "include": True},
            "default_size": {"height": 1200, "width": 1000}
        }
    },
    {
        "category": "histogram_distribution",
        "description": "Specialized histogram for return distributions with opacity control.",
        "styling": {
            "template": "plotly_white",
            "opacity": 0.75,
            "axes_styles": {
                "grid_width": 1,
                "format_numbers_as_k_m": True,
                "exclude_non_numeric_formatting": True
            },
            "title": {"bold_html": True, "include": True},
            "default_size": {"height": 1200, "width": 1000}
        }
    }
]
# Convert each styling dict to a plain string so it can be embedded and
# retrieved as text (note: str(), not json.dumps, so these are Python-repr
# strings rather than strict JSON)
styling_instructions = [str(chart_dict) for chart_dict in styling_instructions]
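# A minimal sketch (not executed here) of how these styling strings could be
# indexed for retrieval with the llama_index imports above. This assumes an
# embedding model is configured elsewhere; the actual retrievers live in
# src.agents.retrievers.
#
#   style_docs = [Document(text=s) for s in styling_instructions]
#   style_index = VectorStoreIndex.from_documents(style_docs)
#   style_retriever = style_index.as_retriever(similarity_top_k=1)
#   best_match = style_retriever.retrieve("line chart of monthly revenue")[0].text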
DEFAULT_MODEL_CONFIG = {
    "provider": os.getenv("MODEL_PROVIDER", "openai"),
    "model": os.getenv("MODEL_NAME", "gpt-5-mini"),
    "api_key": os.getenv("OPENAI_API_KEY"),
    "temperature": float(os.getenv("TEMPERATURE", "1.0")),
    "max_tokens": int(os.getenv("MAX_TOKENS", "6000")),
    "cache": False
}
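# The environment variables read above would live in a .env file loaded by
# load_dotenv(). An illustrative example (placeholder values, not the
# project's actual configuration):
#
#   MODEL_PROVIDER=openai
#   MODEL_NAME=gpt-5-mini
#   OPENAI_API_KEY=sk-...
#   TEMPERATURE=1.0
#   MAX_TOKENS=6000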
# Create the default LM instance but don't set it per-session; sessions can
# override it via their own model_config
default_lm = MODEL_OBJECTS[DEFAULT_MODEL_CONFIG['model']]
dspy.configure(lm=default_lm, async_max_workers=1000)
# Function to get the model config from the session, or fall back to the default
def get_session_lm(session_state):
    """Get the appropriate LM instance for a session, or the default if not configured"""
    model_name = DEFAULT_MODEL_CONFIG["model"]

    # First check if we have a valid session-specific model config
    if session_state and isinstance(session_state, dict) and "model_config" in session_state:
        model_config = session_state["model_config"]
        if model_config and isinstance(model_config, dict) and "model" in model_config:
            # Found a valid session-specific model config, use it
            provider = model_config.get("provider", "openai").lower()
            model_name = model_config.get("model", DEFAULT_MODEL_CONFIG["model"])

            # Get the temperature and clamp it to Anthropic's valid range (0..1)
            temp = model_config.get("temperature", DEFAULT_MODEL_CONFIG["temperature"])
            if provider == "anthropic":
                temp = min(1.0, max(0.0, float(temp)))

            # Handle special OpenAI models (gpt-5 and o1 series)
            if ('gpt-5' in model_name or 'o1' in model_name) and provider == 'openai':
                if 'gpt-5' in model_name:
                    MODEL_OBJECTS[model_name].__dict__['kwargs']['max_tokens'] = 16_000
                if 'o1' in model_name:
                    MODEL_OBJECTS[model_name].__dict__['kwargs']['max_tokens'] = 20_000
                MODEL_OBJECTS[model_name].__dict__['kwargs']['temperature'] = 1.0
            else:
                # All other models
                MODEL_OBJECTS[model_name].__dict__['kwargs']['max_tokens'] = model_config.get("max_tokens", DEFAULT_MODEL_CONFIG["max_tokens"])
                MODEL_OBJECTS[model_name].__dict__['kwargs']['temperature'] = temp

    # If no valid session config was found, model_name still holds the default
    return MODEL_OBJECTS[model_name]
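# Illustrative usage (session_id would come from the X-Session-ID header):
#
#   session_state = app.state.get_session_state(session_id)
#   session_lm = get_session_lm(session_state)
#   with dspy.context(lm=session_lm):
#       ...  # run agents with the session's model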
# Clear the console (cross-platform)
def clear_console():
    os.system('cls' if os.name == 'nt' else 'clear')

# Check for the default Housing.csv dataset
housing_csv_path = "Housing.csv"
if not os.path.exists(housing_csv_path):
    logger.log_message(f"Housing.csv not found at {os.path.abspath(housing_csv_path)}", level=logging.ERROR)
    raise FileNotFoundError(f"Housing.csv not found at {os.path.abspath(housing_csv_path)}")
# All agents are now loaded from the database - no hardcoded dictionaries needed

# Session header
X_SESSION_ID = APIKeyHeader(name="X-Session-ID", auto_error=False)

# The AppState class (built on SessionManager) lives in src.managers.app_manager

# Initialize the FastAPI app with state
app = FastAPI(title="AI Analytics API", version="1.0")

# Pass the required parameters to AppState
app.state = AppState(styling_instructions, chat_history_name_agent, DEFAULT_MODEL_CONFIG)
# Configure middleware
# Use a wildcard for local development, or read the allowed origin from the environment
is_development = os.getenv("ENVIRONMENT", "development").lower() == "development"
allowed_origins = []
frontend_url = os.getenv("FRONTEND_URL", "").strip()
logger.log_message(f"FRONTEND_URL: {frontend_url}", level=logging.INFO)

if is_development:
    allowed_origins = ["*"]
elif frontend_url:
    allowed_origins = [frontend_url]
else:
    logger.log_message("CORS misconfigured: FRONTEND_URL not set", level=logging.ERROR)
    allowed_origins = []  # or set a default safe origin
# Strict origin verification middleware. Registering it via @app.middleware
# ensures the check actually runs on every request.
@app.middleware("http")
async def verify_origin_middleware(request: Request, call_next):
    # Skip the origin check in development mode
    if is_development:
        return await call_next(request)

    # Get the origin from the request headers
    origin = request.headers.get("origin")

    # Log the origin for debugging
    if origin:
        logger.log_message(f"Request from origin: {origin}", level=logging.DEBUG)

    # Reject only when an Origin header is present and does not match the
    # configured frontend; requests without an Origin header (curl,
    # server-to-server) pass through
    if origin and frontend_url and origin != frontend_url:
        logger.log_message(f"Blocked request from unauthorized origin: {origin}", level=logging.WARNING)
        return JSONResponse(
            status_code=403,
            content={"detail": "Not authorized"}
        )

    # Continue processing the request if the origin is allowed
    return await call_next(request)
# CORS middleware (still needed for browser preflight)
app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    allow_origin_regex=None,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
    max_age=600  # Cache preflight responses for 10 minutes (for performance)
)
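# For reference, the CORS middleware above answers browser preflights along
# these lines (illustrative exchange, not captured output):
#
#   OPTIONS /chat HTTP/1.1
#   Origin: https://frontend.example.com
#   Access-Control-Request-Method: POST
#
#   HTTP/1.1 200 OK
#   Access-Control-Allow-Origin: https://frontend.example.com
#   Access-Control-Allow-Methods: *
#   Access-Control-Max-Age: 600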
# Response messages and request-handling constants
RESPONSE_ERROR_INVALID_QUERY = "Please provide a valid query..."
RESPONSE_ERROR_NO_DATASET = "No dataset is currently loaded. Please link a dataset before proceeding with your analysis."
DEFAULT_TOKEN_RATIO = 1.5  # Rough words-to-tokens ratio used when exact tokenization fails
REQUEST_TIMEOUT_SECONDS = 30  # Timeout for LLM requests
MAX_RECENT_MESSAGES = 5
DB_BATCH_SIZE = 10  # For future batch DB operations
# Route path assumed from the signature; adjust if registration differs
@app.post("/chat/{agent_name}")
async def chat_with_agent(
    agent_name: str,
    request: QueryRequest,
    request_obj: Request,
    session_id: str = Depends(get_session_id_dependency)
):
    session_state = app.state.get_session_state(session_id)
    logger.log_message(f"[DEBUG] chat_with_agent called with agent: '{agent_name}', query: '{request.query[:100]}...'", level=logging.DEBUG)

    try:
        # Extract and validate query parameters
        logger.log_message("[DEBUG] Updating session from query params", level=logging.DEBUG)
        _update_session_from_query_params(request_obj, session_state)
        logger.log_message(f"[DEBUG] Session state after query params: user_id={session_state.get('user_id')}, chat_id={session_state.get('chat_id')}", level=logging.DEBUG)

        # Validate dataset and agent name
        if session_state["datasets"] is None:
            logger.log_message("[DEBUG] No dataset loaded", level=logging.DEBUG)
            raise HTTPException(status_code=400, detail=RESPONSE_ERROR_NO_DATASET)

        # Log the dataset being used for analysis with detailed information
        datasets = session_state["datasets"]
        dataset_names = list(datasets.keys())
        if dataset_names:
            current_dataset_name = dataset_names[-1]  # Get the last (most recent) dataset
            dataset_shape = datasets[current_dataset_name].shape

            # Check whether this is the default dataset and explain why
            session_name = session_state.get("name", "")
            is_default_dataset = (current_dataset_name == "df" and session_name == "Housing.csv") or current_dataset_name == "Housing.csv"

            if is_default_dataset:
                logger.log_message(f"[ANALYSIS] Using DEFAULT dataset 'Housing.csv' for analysis (shape: {dataset_shape[0]} rows, {dataset_shape[1]} columns)", level=logging.INFO)
                logger.log_message("[ANALYSIS] Reason: No custom dataset uploaded yet - using default Housing.csv dataset", level=logging.INFO)
            else:
                logger.log_message(f"[ANALYSIS] Using CUSTOM dataset '{current_dataset_name}' for analysis (shape: {dataset_shape[0]} rows, {dataset_shape[1]} columns)", level=logging.INFO)
                logger.log_message("[ANALYSIS] This is a user-uploaded dataset, not the default", level=logging.INFO)
        else:
            logger.log_message(f"[ANALYSIS] No datasets available in session {session_id}", level=logging.WARNING)

        logger.log_message(f"[DEBUG] About to validate agent name: '{agent_name}'", level=logging.DEBUG)
        _validate_agent_name(agent_name, session_state)
        logger.log_message("[DEBUG] Agent validation completed successfully", level=logging.DEBUG)

        # Record the start time for timing
        start_time = time.time()

        # Get the chat context and prepare the query
        logger.log_message("[DEBUG] Preparing query with context", level=logging.DEBUG)
        enhanced_query = _prepare_query_with_context(request.query, session_state)
        logger.log_message(f"[DEBUG] Enhanced query length: {len(enhanced_query)}", level=logging.DEBUG)

        # Initialize the agent - handle standard, template, and custom agents
        if "," in agent_name:
            logger.log_message(f"[DEBUG] Processing multiple agents: {agent_name}", level=logging.DEBUG)
            # Multiple agents case
            agent_list = [agent.strip() for agent in agent_name.split(",")]

            # Categorize agents
            standard_agents = [agent for agent in agent_list if _is_standard_agent(agent)]
            template_agents = [agent for agent in agent_list if _is_template_agent(agent)]
            custom_agents = [agent for agent in agent_list if not _is_standard_agent(agent) and not _is_template_agent(agent)]
            logger.log_message(f"[DEBUG] Agent categorization - standard: {standard_agents}, template: {template_agents}, custom: {custom_agents}", level=logging.DEBUG)

            if custom_agents:
                # If any custom agents are involved, use the session AI system for all
                ai_system = session_state["ai_system"]
                session_lm = get_session_lm(session_state)
                logger.log_message("[DEBUG] Using custom agent execution path", level=logging.DEBUG)
                with dspy.context(lm=session_lm):
                    response = await asyncio.wait_for(
                        _execute_custom_agents(ai_system, agent_list, enhanced_query),
                        timeout=REQUEST_TIMEOUT_SECONDS
                    )
                logger.log_message(f"[DEBUG] Custom agents response type: {type(response)}, keys: {list(response.keys()) if isinstance(response, dict) else 'not a dict'}", level=logging.DEBUG)
            else:
                # All standard/template agents - use auto_analyst_ind, which loads from the DB
                user_id = session_state.get("user_id")
                logger.log_message(f"[DEBUG] Using auto_analyst_ind for multiple standard/template agents with user_id: {user_id}", level=logging.DEBUG)

                # Create a database session for agent loading
                from src.db.init_db import session_factory
                db_session = session_factory()
                try:
                    # auto_analyst_ind will load all agents from the database
                    logger.log_message("[DEBUG] Creating auto_analyst_ind instance", level=logging.DEBUG)
                    agent = auto_analyst_ind(agents=[], retrievers=session_state["retrievers"], user_id=user_id, db_session=db_session)
                    session_lm = get_session_lm(session_state)
                    logger.log_message("[DEBUG] About to call agent.forward with query and agent list", level=logging.DEBUG)
                    with dspy.context(lm=session_lm):
                        response = await asyncio.wait_for(
                            agent(enhanced_query, ",".join(agent_list)),
                            timeout=REQUEST_TIMEOUT_SECONDS
                        )
                    logger.log_message(f"[DEBUG] auto_analyst_ind response type: {type(response)}, content: {str(response)[:200]}...", level=logging.DEBUG)
                finally:
                    db_session.close()
        else:
            logger.log_message(f"[DEBUG] Processing single agent: {agent_name}", level=logging.DEBUG)
            # Single agent case
            if _is_standard_agent(agent_name) or _is_template_agent(agent_name):
                # Standard or template agent - use auto_analyst_ind, which loads from the DB
                user_id = session_state.get("user_id")
                logger.log_message(f"[DEBUG] Using auto_analyst_ind for single standard/template agent '{agent_name}' with user_id: {user_id}", level=logging.DEBUG)

                # Create a database session for agent loading
                from src.db.init_db import session_factory
                db_session = session_factory()
                try:
                    # auto_analyst_ind will load all agents from the database
                    logger.log_message("[DEBUG] Creating auto_analyst_ind instance for single agent", level=logging.DEBUG)
                    agent = auto_analyst_ind(agents=[], retrievers=session_state["retrievers"], user_id=user_id, db_session=db_session)
                    session_lm = get_session_lm(session_state)
                    logger.log_message(f"[DEBUG] About to call agent.forward for single agent '{agent_name}'", level=logging.DEBUG)
                    with dspy.context(lm=session_lm):
                        response = await asyncio.wait_for(
                            agent(enhanced_query, agent_name),
                            timeout=REQUEST_TIMEOUT_SECONDS
                        )
                    logger.log_message(f"[DEBUG] Single agent response type: {type(response)}, content: {str(response)[:200]}...", level=logging.DEBUG)
                finally:
                    db_session.close()
            else:
                # Custom agent - use the session AI system
                ai_system = session_state["ai_system"]
                session_lm = get_session_lm(session_state)
                logger.log_message(f"[DEBUG] Using custom agent execution for '{agent_name}'", level=logging.DEBUG)
                with dspy.context(lm=session_lm):
                    response = await asyncio.wait_for(
                        _execute_custom_agents(ai_system, [agent_name], enhanced_query),
                        timeout=REQUEST_TIMEOUT_SECONDS
                    )
                logger.log_message(f"[DEBUG] Custom single agent response type: {type(response)}, content: {str(response)[:200]}...", level=logging.DEBUG)

        logger.log_message(f"[DEBUG] About to format response to markdown. Response type: {type(response)}", level=logging.DEBUG)
        formatted_response = format_response_to_markdown(response, agent_name, session_state["datasets"])
        logger.log_message(f"[DEBUG] Formatted response type: {type(formatted_response)}, length: {len(str(formatted_response))}", level=logging.DEBUG)

        if formatted_response == RESPONSE_ERROR_INVALID_QUERY:
            logger.log_message("[DEBUG] Response was invalid query error", level=logging.DEBUG)
            return {
                "agent_name": agent_name,
                "query": request.query,
                "response": formatted_response,
                "session_id": session_id
            }

        # Track usage statistics
        if session_state.get("user_id"):
            logger.log_message("[DEBUG] Tracking model usage", level=logging.DEBUG)
            _track_model_usage(
                session_state=session_state,
                enhanced_query=enhanced_query,
                response=response,
                processing_time_ms=int((time.time() - start_time) * 1000)
            )

        logger.log_message("[DEBUG] chat_with_agent completed successfully", level=logging.DEBUG)
        return {
            "agent_name": agent_name,
            "query": request.query,  # Return the original query without context
            "response": formatted_response,
            "session_id": session_id
        }
    except HTTPException:
        # Re-raise HTTP exceptions to preserve status codes
        logger.log_message("[DEBUG] HTTPException caught and re-raised", level=logging.DEBUG)
        raise
    except asyncio.TimeoutError:
        logger.log_message("[ERROR] Timeout error in chat_with_agent", level=logging.ERROR)
        raise HTTPException(status_code=504, detail="Request timed out. Please try a simpler query.")
    except Exception as e:
        import traceback
        logger.log_message(f"[ERROR] Unexpected error in chat_with_agent ({type(e).__name__}): {str(e)}", level=logging.ERROR)
        logger.log_message(f"[ERROR] Full traceback: {traceback.format_exc()}", level=logging.ERROR)
        raise HTTPException(status_code=500, detail="An unexpected error occurred. Please try again later.")
# Streams the full multi-agent pipeline (route path assumed)
@app.post("/chat")
async def chat_with_all(
    request: QueryRequest,
    request_obj: Request,
    session_id: str = Depends(get_session_id_dependency)
):
    session_state = app.state.get_session_state(session_id)
    try:
        # Extract and validate query parameters
        _update_session_from_query_params(request_obj, session_state)

        # Validate dataset
        if session_state["datasets"] is None:
            raise HTTPException(status_code=400, detail=RESPONSE_ERROR_NO_DATASET)

        if session_state["ai_system"] is None:
            raise HTTPException(status_code=500, detail="AI system not properly initialized.")

        # Get the session-specific model
        session_lm = get_session_lm(session_state)

        # Create the streaming response
        return StreamingResponse(
            _generate_streaming_responses(session_state, request.query, session_lm),
            media_type='text/event-stream',
            headers={
                'Cache-Control': 'no-cache',
                'Connection': 'keep-alive',
                'Content-Type': 'text/event-stream',
                'Access-Control-Allow-Origin': '*',
                'X-Accel-Buffering': 'no'
            }
        )
    except HTTPException:
        # Re-raise HTTP exceptions to preserve status codes
        raise
    except Exception as e:
        logger.log_message(f"Unexpected error in chat_with_all: {str(e)}", level=logging.ERROR)
        raise HTTPException(status_code=500, detail="An unexpected error occurred. Please try again later.")
# Helper functions to reduce duplication and improve modularity
def _update_session_from_query_params(request_obj: Request, session_state: dict):
    """Extract and validate chat_id and user_id from query parameters"""
    # Check for chat_id in query parameters
    if "chat_id" in request_obj.query_params:
        try:
            chat_id_param = int(request_obj.query_params.get("chat_id"))
            # Update session state with this chat ID
            session_state["chat_id"] = chat_id_param
        except (ValueError, TypeError):
            logger.log_message("Invalid chat_id parameter", level=logging.WARNING)
            # Continue without updating chat_id

    # Check for user_id in query parameters
    if "user_id" in request_obj.query_params:
        try:
            user_id = int(request_obj.query_params["user_id"])
            session_state["user_id"] = user_id
        except (ValueError, TypeError):
            raise HTTPException(
                status_code=400,
                detail="Invalid user_id in query params. Please provide a valid integer."
            )
def _validate_agent_name(agent_name: str, session_state: dict = None):
    """Validate that the agent name(s) are available"""
    logger.log_message(f"[DEBUG] Validating agent name: '{agent_name}'", level=logging.DEBUG)

    # A comma-separated value means multiple agents; a single name is just a
    # one-element list, so both cases share the same loop
    agent_list = [agent.strip() for agent in agent_name.split(",")]
    for agent in agent_list:
        is_available = _is_agent_available(agent, session_state)
        logger.log_message(f"[DEBUG] Agent '{agent}' availability: {is_available}", level=logging.DEBUG)
        if not is_available:
            available_agents = _get_available_agents_list(session_state)
            logger.log_message(f"[DEBUG] Agent '{agent}' not found. Available: {available_agents}", level=logging.DEBUG)
            raise HTTPException(
                status_code=400,
                detail=f"Agent '{agent}' not found. Available agents: {available_agents}"
            )

    logger.log_message(f"[DEBUG] Agent validation passed for: '{agent_name}'", level=logging.DEBUG)
def _is_agent_available(agent_name: str, session_state: dict = None) -> bool:
    """Check if an agent is available (standard, template, or custom)"""
    # Check if it's a standard agent
    if _is_standard_agent(agent_name):
        return True

    # Check if it's a template agent
    if _is_template_agent(agent_name):
        return True

    # Check if it's a custom agent in the session
    if session_state and "ai_system" in session_state:
        ai_system = session_state["ai_system"]
        if hasattr(ai_system, 'agents') and agent_name in ai_system.agents:
            return True

    return False
def _get_available_agents_list(session_state: dict = None) -> list:
    """Get the list of all available agents from the database"""
    from src.db.init_db import session_factory
    from src.agents.agents import load_all_available_templates_from_db

    # Core agents (always available)
    available = ["preprocessing_agent", "statistical_analytics_agent", "sk_learn_agent", "data_viz_agent"]

    # Add template agents from the database
    db_session = session_factory()
    try:
        template_agents_dict = load_all_available_templates_from_db(db_session)
        # template_agents_dict is a dict keyed by template_name
        template_names = [template_name for template_name in template_agents_dict.keys()
                          if template_name not in available and template_name != 'basic_qa_agent']
        available.extend(template_names)
    except Exception as e:
        logger.log_message(f"Error loading template agents: {str(e)}", level=logging.ERROR)
    finally:
        db_session.close()

    return available
def _is_standard_agent(agent_name: str) -> bool:
    """Check if the agent is one of the 4 core standard agents"""
    standard_agents = ["preprocessing_agent", "statistical_analytics_agent", "sk_learn_agent", "data_viz_agent"]
    return agent_name in standard_agents

def _is_template_agent(agent_name: str) -> bool:
    """Check if the agent is a template agent"""
    try:
        from src.db.init_db import session_factory
        from src.db.schemas.models import AgentTemplate

        db_session = session_factory()
        try:
            template = db_session.query(AgentTemplate).filter(
                AgentTemplate.template_name == agent_name,
                AgentTemplate.is_active == True
            ).first()
            return template is not None
        finally:
            db_session.close()
    except Exception as e:
        logger.log_message(f"Error checking if {agent_name} is a template: {str(e)}", level=logging.ERROR)
        return False
async def _execute_custom_agents(ai_system, agent_names: list, query: str):
    """Execute custom agents using the session's AI system"""
    try:
        if len(agent_names) == 1:
            # Single custom agent
            agent_name = agent_names[0]

            # Prepare inputs for the custom agent (similar to standard agents like data_viz_agent)
            dict_ = {}
            dict_['dataset'] = ai_system.dataset.retrieve(query)[0].text
            dict_['styling_index'] = ai_system.styling_index.retrieve(query)[0].text
            dict_['goal'] = query
            dict_['Agent_desc'] = str(ai_system.agent_desc)

            # Get the input fields for this agent
            if agent_name in ai_system.agent_inputs:
                inputs = {x: dict_[x] for x in ai_system.agent_inputs[agent_name] if x in dict_}
                # Execute the custom agent
                agent_name_result, result_dict = await ai_system.agents[agent_name](**inputs)
                return {agent_name_result: result_dict}
            else:
                logger.log_message(f"Agent '{agent_name}' not found in ai_system.agent_inputs", level=logging.ERROR)
                return {"error": f"Agent '{agent_name}' input configuration not found"}
        else:
            # Multiple agents - execute sequentially
            results = {}
            for agent_name in agent_names:
                single_result = await _execute_custom_agents(ai_system, [agent_name], query)
                results.update(single_result)
            return results
    except Exception as e:
        logger.log_message(f"Error in _execute_custom_agents: {str(e)}", level=logging.ERROR)
        return {"error": f"Error executing custom agents: {str(e)}"}
def _prepare_query_with_context(query: str, session_state: dict) -> str:
    """Prepare the query with chat context from previous messages"""
    chat_id = session_state.get("chat_id")
    if not chat_id:
        return query

    # Get the chat manager from app state
    chat_manager = app.state._session_manager.chat_manager

    # Get recent messages
    recent_messages = chat_manager.get_recent_chat_history(chat_id, limit=MAX_RECENT_MESSAGES)

    # Extract the response history
    chat_context = chat_manager.extract_response_history(recent_messages)

    # Append the context to the query if available
    if chat_context:
        return f"### Current Query:\n{query}\n\n{chat_context}"
    return query
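# For a query "plot sales by region" with prior chat context, the enhanced
# query produced above looks like:
#
#   ### Current Query:
#   plot sales by region
#
#   <response history extracted from the last 5 messages>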
def _track_model_usage(session_state: dict, enhanced_query: str, response, processing_time_ms: int):
    """Track model usage statistics in the database"""
    try:
        ai_manager = app.state.get_ai_manager()

        # Get the model configuration
        model_config = session_state.get("model_config", DEFAULT_MODEL_CONFIG)
        model_name = model_config.get("model", DEFAULT_MODEL_CONFIG["model"])
        provider = ai_manager.get_provider_for_model(model_name)

        # Calculate token usage
        try:
            # Try exact tokenization
            prompt_tokens = len(ai_manager.tokenizer.encode(enhanced_query))
            completion_tokens = len(ai_manager.tokenizer.encode(str(response)))
            total_tokens = prompt_tokens + completion_tokens
        except Exception as token_error:
            # Fall back to estimation
            logger.log_message(f"Tokenization error: {str(token_error)}", level=logging.WARNING)
            prompt_words = len(enhanced_query.split())
            completion_words = len(str(response).split())
            prompt_tokens = int(prompt_words * DEFAULT_TOKEN_RATIO)
            completion_tokens = int(completion_words * DEFAULT_TOKEN_RATIO)
            total_tokens = prompt_tokens + completion_tokens

        # Calculate the cost
        cost = ai_manager.calculate_cost(model_name, prompt_tokens, completion_tokens)

        # Save usage to the database
        ai_manager.save_usage_to_db(
            user_id=session_state.get("user_id"),
            chat_id=session_state.get("chat_id"),
            model_name=model_name,
            provider=provider,
            prompt_tokens=int(prompt_tokens),
            completion_tokens=int(completion_tokens),
            total_tokens=int(total_tokens),
            query_size=len(enhanced_query),
            response_size=len(str(response)),
            cost=round(cost, 7),
            request_time_ms=processing_time_ms,
            is_streaming=False
        )
    except Exception as e:
        # Log but don't fail the request if usage tracking fails
        logger.log_message(f"Failed to track model usage: {str(e)}", level=logging.ERROR)
async def _generate_streaming_responses(session_state: dict, query: str, session_lm):
    """Generate streaming responses for the chat_with_all endpoint"""
    overall_start_time = time.time()
    usage_records = []

    # Add chat context from previous messages
    enhanced_query = _prepare_query_with_context(query, session_state)

    try:
        # Get the plan - the planner is async, so await it
        plan_response = await session_state["ai_system"].get_plan(enhanced_query)
        plan_description = format_response_to_markdown(
            {"analytical_planner": plan_response},
            datasets=session_state["datasets"]
        )

        # Check if the plan is valid
        if plan_description == RESPONSE_ERROR_INVALID_QUERY:
            yield json.dumps({
                "agent": "Analytical Planner",
                "content": plan_description,
                "status": "error"
            }) + "\n"
            return

        yield json.dumps({
            "agent": "Analytical Planner",
            "content": plan_description,
            "status": "success" if plan_description else "error"
        }) + "\n"

        # Track planner usage
        if session_state.get("user_id"):
            planner_tokens = _estimate_tokens(ai_manager=app.state.get_ai_manager(),
                                              input_text=enhanced_query,
                                              output_text=plan_description)
            usage_records.append(_create_usage_record(
                session_state=session_state,
                model_name=session_state.get("model_config", DEFAULT_MODEL_CONFIG)["model"],
                prompt_tokens=planner_tokens["prompt"],
                completion_tokens=planner_tokens["completion"],
                query_size=len(enhanced_query),
                response_size=len(plan_description),
                processing_time_ms=int((time.time() - overall_start_time) * 1000),
                is_streaming=False
            ))

        logger.log_message(f"Plan response: {plan_response}", level=logging.INFO)
        logger.log_message(f"Plan response type: {type(plan_response)}", level=logging.INFO)

        # Execute the plan with managed concurrency
        with dspy.context(lm=session_lm):
            async for agent_name, inputs, response in session_state["ai_system"].execute_plan(enhanced_query, plan_response):
                if agent_name == "plan_not_found":
                    yield json.dumps({
                        "agent": "Analytical Planner",
                        "content": "**No plan found**\n\nPlease try again with a different query or try using a different model.",
                        "status": "error"
                    }) + "\n"
                    return

                if agent_name == "plan_not_formated_correctly":
                    yield json.dumps({
                        "agent": "Analytical Planner",
                        "content": "**Something went wrong with formatting, retry the query!**",
                        "status": "error"
                    }) + "\n"
                    return

                formatted_response = format_response_to_markdown(
                    {agent_name: response},
                    datasets=session_state["datasets"]
                )

                # Send the response chunk
                yield json.dumps({
                    "agent": agent_name.split("__")[0] if "__" in agent_name else agent_name,
                    "content": formatted_response,
                    "status": "success" if response else "error"
                }) + "\n"

                # Handle agent errors
                if isinstance(response, dict) and "error" in response:
                    yield json.dumps({
                        "agent": agent_name,
                        "content": f"**Error in {agent_name}**: {response['error']}",
                        "status": "error"
                    }) + "\n"
                    continue  # Continue with the next agent instead of returning

                if formatted_response == RESPONSE_ERROR_INVALID_QUERY:
                    yield json.dumps({
                        "agent": agent_name,
                        "content": formatted_response,
                        "status": "error"
                    }) + "\n"
                    continue  # Continue with the next agent instead of returning

                # Track agent usage for a future batch DB write
                if session_state.get("user_id"):
                    agent_tokens = _estimate_tokens(
                        ai_manager=app.state.get_ai_manager(),
                        input_text=str(inputs),
                        output_text=str(response)
                    )

                    # Get the appropriate model name for the code combiner
                    if "code_combiner_agent" in agent_name and "__" in agent_name:
                        provider = agent_name.split("__")[1]
                        model_name = _get_model_name_for_provider(provider)
                    else:
                        model_name = session_state.get("model_config", DEFAULT_MODEL_CONFIG)["model"]

                    usage_records.append(_create_usage_record(
                        session_state=session_state,
                        model_name=model_name,
                        prompt_tokens=agent_tokens["prompt"],
                        completion_tokens=agent_tokens["completion"],
                        query_size=len(str(inputs)),
                        response_size=len(str(response)),
                        processing_time_ms=int((time.time() - overall_start_time) * 1000),
                        is_streaming=True
                    ))
    except asyncio.TimeoutError:
        yield json.dumps({
            "agent": "Analytical Planner",
            "content": "The request timed out. Please try a simpler query.",
            "status": "error"
        }) + "\n"
    except Exception as e:
        logger.log_message(f"Error in streaming response: {str(e)}", level=logging.ERROR)
        yield json.dumps({
            "agent": "Analytical Planner",
            "content": "An error occurred while generating responses. Please try again!",
            "status": "error"
        }) + "\n"
def _estimate_tokens(ai_manager, input_text: str, output_text: str) -> dict:
    """Estimate token counts, with a fallback for tokenization errors"""
    try:
        # Try exact tokenization
        prompt_tokens = len(ai_manager.tokenizer.encode(input_text))
        completion_tokens = len(ai_manager.tokenizer.encode(output_text))
    except Exception:
        # Fall back to a words-to-tokens estimate
        prompt_words = len(input_text.split())
        completion_words = len(output_text.split())
        prompt_tokens = int(prompt_words * DEFAULT_TOKEN_RATIO)
        completion_tokens = int(completion_words * DEFAULT_TOKEN_RATIO)

    return {
        "prompt": prompt_tokens,
        "completion": completion_tokens,
        "total": prompt_tokens + completion_tokens
    }
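# Illustrative fallback arithmetic: a 12-word prompt and a 40-word response
# estimate to int(12 * 1.5) = 18 prompt tokens and int(40 * 1.5) = 60
# completion tokens, i.e. {"prompt": 18, "completion": 60, "total": 78}.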
def _create_usage_record(session_state: dict, model_name: str, prompt_tokens: int,
                         completion_tokens: int, query_size: int, response_size: int,
                         processing_time_ms: int, is_streaming: bool) -> dict:
    """Create a usage record for the database"""
    ai_manager = app.state.get_ai_manager()
    provider = ai_manager.get_provider_for_model(model_name)
    cost = ai_manager.calculate_cost(model_name, prompt_tokens, completion_tokens)

    return {
        "user_id": session_state.get("user_id"),
        "chat_id": session_state.get("chat_id"),
        "model_name": model_name,
        "provider": provider,
        "prompt_tokens": int(prompt_tokens),
        "completion_tokens": int(completion_tokens),
        "total_tokens": int(prompt_tokens + completion_tokens),
        "query_size": query_size,
        "response_size": response_size,
        "cost": round(cost, 7),
        "request_time_ms": processing_time_ms,
        "is_streaming": is_streaming
    }
def _get_model_name_for_provider(provider: str) -> str:
    """Get the default model name for a provider"""
    provider_model_map = {
        "openai": "o3-mini",
        "anthropic": "claude-3-7-sonnet-latest",
        "gemini": "gemini-2.5-pro-preview-03-25"
    }
    return provider_model_map.get(provider, "o3-mini")
# Endpoint to list available agents (route path assumed)
@app.get("/agents")
async def list_agents(request: Request, session_id: str = Depends(get_session_id_dependency)):
    """Get all available agents (standard, template, and custom)"""
    session_state = app.state.get_session_state(session_id)
    try:
        # Get all available agents from the database and session
        available_agents_list = _get_available_agents_list(session_state)

        # Categorize agents
        standard_agents = ["preprocessing_agent", "statistical_analytics_agent", "sk_learn_agent", "data_viz_agent"]

        # Get template agents from the database
        from src.db.init_db import session_factory
        from src.agents.agents import load_all_available_templates_from_db

        db_session = session_factory()
        try:
            template_agents_dict = load_all_available_templates_from_db(db_session)
            # template_agents_dict is a dict keyed by template_name
            template_agents = [template_name for template_name in template_agents_dict.keys()
                               if template_name not in standard_agents and template_name != 'basic_qa_agent']
        except Exception as e:
            logger.log_message(f"Error loading template agents in /agents endpoint: {str(e)}", level=logging.ERROR)
            template_agents = []
        finally:
            db_session.close()

        # Get custom agents from the session (anything available that is
        # neither standard nor a template)
        custom_agents = []
        if session_state and "ai_system" in session_state:
            ai_system = session_state["ai_system"]
            if hasattr(ai_system, 'agents'):
                custom_agents = [agent for agent in available_agents_list
                                 if agent not in standard_agents and agent not in template_agents]

        # Ensure template agents are in the available list
        for template_agent in template_agents:
            if template_agent not in available_agents_list:
                available_agents_list.append(template_agent)

        return {
            "available_agents": available_agents_list,
            "standard_agents": standard_agents,
            "template_agents": template_agents,
            "custom_agents": custom_agents
        }
    except Exception as e:
        logger.log_message(f"Error getting agents list: {str(e)}", level=logging.ERROR)
        raise HTTPException(status_code=500, detail=f"Error getting agents list: {str(e)}")
# Health check and landing endpoints (route paths assumed)
@app.get("/health")
async def health():
    return {"message": "API is healthy and running"}

@app.get("/")
async def index():
    return {
        "title": "Welcome to the AI Analytics API",
        "message": "Explore our API for advanced analytics and visualization tools designed to empower your data-driven decisions.",
        "description": "Utilize our powerful agents and models to gain insights from your data effortlessly.",
        "colors": {
            "primary": "#007bff",
            "secondary": "#6c757d",
            "success": "#28a745",
            "danger": "#dc3545",
        },
        "features": [
            "Real-time data processing",
            "Customizable visualizations",
            "Seamless integration with various data sources",
            "User-friendly interface for easy navigation",
            "Custom Analytics",
        ],
    }
# Generates a short display name for a chat (route path assumed)
@app.post("/chat_history_name")
async def chat_history_name(request: dict, session_id: str = Depends(get_session_id_dependency)):
    query = request.get("query")
    name = None
    lm = dspy.LM(model="gpt-4o-mini", max_tokens=300, temperature=0.5)
    with dspy.context(lm=lm):
        name = app.state.get_chat_history_name_agent()(query=str(query))
    return {"name": name.name if name else "New Chat"}
# Streaming deep analysis endpoint (route path assumed)
@app.post("/deep_analysis_streaming")
async def deep_analysis_streaming(
    request: DeepAnalysisRequest,
    request_obj: Request,
    session_id: str = Depends(get_session_id_dependency)
):
    """Perform streaming deep analysis with real-time updates"""
    session_state = app.state.get_session_state(session_id)
    try:
        # Extract and validate query parameters
        _update_session_from_query_params(request_obj, session_state)

        # Validate dataset
        if session_state["datasets"] is None:
            raise HTTPException(status_code=400, detail=RESPONSE_ERROR_NO_DATASET)

        # Get user_id from session state (if available)
        user_id = session_state.get("user_id")

        # Generate a UUID for this report (uuid is imported at module level)
        report_uuid = str(uuid.uuid4())
        # Create an initial pending report in the database
        try:
            from src.db.init_db import session_factory
            from src.db.schemas.models import DeepAnalysisReport

            db_session = session_factory()
            try:
                # Create a pending report entry
                new_report = DeepAnalysisReport(
                    report_uuid=report_uuid,
                    user_id=user_id,
                    goal=request.goal,
                    status="pending",
                    start_time=datetime.now(UTC),
                    progress_percentage=0
                )
                db_session.add(new_report)
                db_session.commit()
                db_session.refresh(new_report)

                # Store the report ID in session state for later updates
                session_state["current_deep_analysis_id"] = new_report.report_id
                session_state["current_deep_analysis_uuid"] = report_uuid
            except Exception as e:
                logger.log_message(f"Error creating initial deep analysis report: {str(e)}", level=logging.ERROR)
                # Continue even if DB storage fails
            finally:
                db_session.close()
        except Exception as e:
            logger.log_message(f"Database operation failed: {str(e)}", level=logging.ERROR)
            # Continue even if the DB operation fails

        # Get the session-specific model
        # session_lm = get_session_lm(session_state)
        session_lm = dspy.LM(model="anthropic/claude-sonnet-4-20250514", max_tokens=7000, temperature=0.5)

        return StreamingResponse(
            _generate_deep_analysis_stream(session_state, request.goal, session_lm, session_id),
            media_type='text/event-stream',
            headers={
                'Cache-Control': 'no-cache',
                'Connection': 'keep-alive',
                'Content-Type': 'text/event-stream',
                'Access-Control-Allow-Origin': '*',
                'X-Accel-Buffering': 'no'
            }
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.log_message(f"Streaming deep analysis failed: {str(e)}", level=logging.ERROR)
        raise HTTPException(status_code=500, detail=f"Streaming deep analysis failed: {str(e)}")
async def _generate_deep_analysis_stream(session_state: dict, goal: str, session_lm, session_id: str):
    """Generate streaming responses for deep analysis"""
    # Track the start time for duration calculation
    start_time = datetime.now(UTC)

    try:
        # Get dataset info
        datasets = session_state["datasets"]
        desc = session_state['description']

        # Generate dataset info for all datasets
        logger.log_message(f"🔍 DEEP ANALYSIS START - datasets type: {type(datasets)}, keys: {list(datasets.keys()) if datasets else 'None'}", level=logging.DEBUG)
        dataset_info = desc
        logger.log_message(f"🔍 DEEP ANALYSIS - dataset_info type: {type(dataset_info)}, length: {len(dataset_info) if isinstance(dataset_info, str) else 'N/A'}", level=logging.DEBUG)
        logger.log_message(f"🔍 DEEP ANALYSIS - dataset_info content: {dataset_info[:200]}...", level=logging.DEBUG)

        # Get report info from session state
        report_id = session_state.get("current_deep_analysis_id")
        report_uuid = session_state.get("current_deep_analysis_uuid")
        user_id = session_state.get("user_id")

        # Helper function to update the report in the database
        async def update_report_in_db(status, progress, step=None, content=None):
            if not report_id:
                return
            try:
                from src.db.init_db import session_factory
                from src.db.schemas.models import DeepAnalysisReport

                db_session = session_factory()
                try:
                    report = db_session.query(DeepAnalysisReport).filter(DeepAnalysisReport.report_id == report_id).first()
                    if report:
                        report.status = status
                        report.progress_percentage = progress

                        # Update step-specific fields if provided
                        if step == "questions" and content:
                            report.deep_questions = content
                        elif step == "planning" and content:
                            report.deep_plan = content
                        elif step == "analysis" and content:
                            # For the analysis step, we get the full object with multiple fields
                            if isinstance(content, dict):
                                # Update fields from content if they exist
                                if "deep_questions" in content and content["deep_questions"]:
                                    report.deep_questions = content["deep_questions"]
                                if "deep_plan" in content and content["deep_plan"]:
                                    report.deep_plan = content["deep_plan"]
                                if "code" in content and content["code"]:
                                    report.analysis_code = content["code"]
                                if "final_conclusion" in content and content["final_conclusion"]:
                                    report.final_conclusion = content["final_conclusion"]
                                    # Also derive the summary from the conclusion
                                    conclusion = content["final_conclusion"]
                                    conclusion = conclusion.replace("**Conclusion**", "")
                                    report.report_summary = conclusion[:200] + "..." if len(conclusion) > 200 else conclusion

                                # Handle JSON fields
                                if "summaries" in content and content["summaries"]:
                                    report.summaries = json.dumps(content["summaries"])
                                if "plotly_figs" in content and content["plotly_figs"]:
                                    report.plotly_figures = json.dumps(content["plotly_figs"])
                                if "synthesis" in content and content["synthesis"]:
                                    report.synthesis = json.dumps(content["synthesis"])

                        # For the final step, update the HTML report and timing
                        if step == "completed":
                            if content:
                                report.html_report = content
                            else:
                                logger.log_message("No HTML content provided for completed step", level=logging.WARNING)
                            report.end_time = datetime.now(UTC)
                            # Ensure start_time is timezone-aware before calculating duration
                            if report.start_time.tzinfo is None:
                                start_time_utc = report.start_time.replace(tzinfo=UTC)
                            else:
                                start_time_utc = report.start_time
                            report.duration_seconds = int((report.end_time - start_time_utc).total_seconds())

                        report.updated_at = datetime.now(UTC)
                        db_session.commit()
                except Exception as e:
                    db_session.rollback()
                    logger.log_message(f"Error updating deep analysis report: {str(e)}", level=logging.ERROR)
                finally:
                    db_session.close()
            except Exception as e:
                logger.log_message(f"Database operation failed: {str(e)}", level=logging.ERROR)
        # Use the session model for this request
        with dspy.context(lm=session_lm):
            # Send the initial status
            yield json.dumps({
                "step": "initialization",
                "status": "starting",
                "message": "Initializing deep analysis...",
                "progress": 5
            }) + "\n"

            # Update the DB status to running
            await update_report_in_db("running", 5)

            # Get the deep analyzer - use the correct session_id from session state
            logger.log_message(f"Getting deep analyzer for session_id: {session_id}, user_id: {user_id}", level=logging.INFO)
            deep_analyzer = app.state.get_deep_analyzer(session_id)

            # Make all datasets available globally for code execution
            for dataset_name, dataset_df in datasets.items():
                globals()[dataset_name] = dataset_df

            # Use the streaming method and forward all progress updates
            final_result = None
            logger.log_message(f"🔍 CALLING DEEP ANALYSIS - goal: {goal[:100]}...", level=logging.DEBUG)
            logger.log_message(f"🔍 CALLING DEEP ANALYSIS - dataset_info type: {type(dataset_info)}, length: {len(dataset_info) if isinstance(dataset_info, str) else 'N/A'}", level=logging.DEBUG)
            logger.log_message(f"🔍 CALLING DEEP ANALYSIS - session_datasets type: {type(datasets)}, keys: {list(datasets.keys()) if datasets else 'None'}", level=logging.DEBUG)

            async for update in deep_analyzer.execute_deep_analysis_streaming(
                goal=goal,
                dataset_info=dataset_info,
                session_datasets=datasets  # Pass all datasets instead of a single df
            ):
                # Convert each update to the expected format and yield it
                if update.get("step") == "questions" and update.get("status") == "completed":
                    # Update the DB with the questions
                    await update_report_in_db("running", update.get("progress", 0), "questions", update.get("content"))
                elif update.get("step") == "planning" and update.get("status") == "completed":
                    # Update the DB with the plan
                    await update_report_in_db("running", update.get("progress", 0), "planning", update.get("content"))
                elif update.get("step") == "conclusion" and update.get("status") == "completed":
                    # Store the final result for later processing
                    final_result = update.get("final_result")

                    # Convert Plotly figures to JSON format for network transmission
                    if final_result:
                        import plotly.io
                        serialized_return_dict = final_result.copy()

                        # Convert plotly_figs to JSON format
                        if 'plotly_figs' in serialized_return_dict and serialized_return_dict['plotly_figs']:
                            json_figs = []
                            for fig_list in serialized_return_dict['plotly_figs']:
                                if isinstance(fig_list, list):
                                    json_fig_list = []
                                    for fig in fig_list:
                                        if hasattr(fig, 'to_json'):  # Check if it's a Plotly figure
                                            json_fig_list.append(plotly.io.to_json(fig))
                                        else:
                                            json_fig_list.append(fig)  # Already JSON or another format
                                    json_figs.append(json_fig_list)
                                else:
                                    # Single figure case
                                    if hasattr(fig_list, 'to_json'):
                                        json_figs.append(plotly.io.to_json(fig_list))
                                    else:
                                        json_figs.append(fig_list)
                            serialized_return_dict['plotly_figs'] = json_figs

                        # Update the DB with the analysis results
                        await update_report_in_db("running", update.get("progress", 0), "analysis", serialized_return_dict)

                        # Generate the HTML report using the original final_result with Figure objects
                        html_report = None
                        try:
                            html_report = generate_html_report(final_result)
                        except Exception as e:
                            logger.log_message(f"Error generating HTML report: {str(e)}", level=logging.ERROR)
                            # Continue even if HTML generation fails

                        # Send the analysis results
                        yield json.dumps({
                            "step": "analysis",
                            "status": "completed",
                            "content": serialized_return_dict,
                            "progress": 90
                        }) + "\n"

                        # Send the report generation status
                        yield json.dumps({
                            "step": "report",
                            "status": "processing",
                            "message": "Generating final report...",
                            "progress": 95
                        }) + "\n"

                        # Send the final completion
                        yield json.dumps({
                            "step": "completed",
                            "status": "success",
                            "analysis": serialized_return_dict,
                            "html_report": html_report,
                            "progress": 100
                        }) + "\n"

                        # Update the DB with the completed report (with HTML if generated)
                        if html_report:
                            logger.log_message(f"Saving HTML report to database, length: {len(html_report)}", level=logging.INFO)
                        else:
                            logger.log_message("No HTML report to save to database", level=logging.WARNING)
                        await update_report_in_db("completed", 100, "completed", html_report)
                elif update.get("step") == "error":
                    # Forward the error directly
                    yield json.dumps(update) + "\n"
                    await update_report_in_db("failed", 0)
                    return
                else:
                    # Forward all other progress updates
                    yield json.dumps(update) + "\n"

            # If we somehow exit the loop without a final result, that's an error
            if not final_result:
                yield json.dumps({
                    "step": "error",
                    "status": "failed",
                    "message": "Deep analysis completed without final result",
                    "progress": 0
                }) + "\n"
                await update_report_in_db("failed", 0)
    except Exception as e:
        logger.log_message(f"Error in deep analysis stream: {str(e)}", level=logging.ERROR)
        yield json.dumps({
            "step": "error",
            "status": "failed",
            "message": f"Deep analysis failed: {str(e)}",
            "progress": 0
        }) + "\n"
        # Update the DB with the error status
        if 'update_report_in_db' in locals() and session_state.get("current_deep_analysis_id"):
            await update_report_in_db("failed", 0)
# Download endpoint for a previously generated report (route path assumed)
@app.post("/deep_analysis/download_report")
async def download_html_report(
    request: dict,
    session_id: str = Depends(get_session_id_dependency)
):
"""Download HTML report from previous deep analysis""" | |
try: | |
analysis_data = request.get("analysis_data") | |
if not analysis_data: | |
raise HTTPException(status_code=400, detail="No analysis data provided") | |
# Get report UUID from request if available (for saving to DB) | |
report_uuid = request.get("report_uuid") | |
session_state = app.state.get_session_state(session_id) | |
# If no report_uuid in request, try to get it from session state | |
if not report_uuid and session_state.get("current_deep_analysis_uuid"): | |
report_uuid = session_state.get("current_deep_analysis_uuid") | |
# Convert JSON-serialized Plotly figures back to Figure objects for HTML generation | |
processed_data = analysis_data.copy() | |
if 'plotly_figs' in processed_data and processed_data['plotly_figs']: | |
import plotly.io | |
import plotly.graph_objects as go | |
figure_objects = [] | |
for fig_list in processed_data['plotly_figs']: | |
if isinstance(fig_list, list): | |
fig_obj_list = [] | |
for fig_json in fig_list: | |
if isinstance(fig_json, str): | |
# Convert JSON string back to Figure object | |
try: | |
fig_obj = plotly.io.from_json(fig_json) | |
fig_obj_list.append(fig_obj) | |
except Exception as e: | |
logger.log_message(f"Error parsing Plotly JSON: {str(e)}", level=logging.WARNING) | |
continue | |
elif hasattr(fig_json, 'to_html'): | |
# Already a Figure object | |
fig_obj_list.append(fig_json) | |
figure_objects.append(fig_obj_list) | |
else: | |
# Single figure case | |
if isinstance(fig_list, str): | |
try: | |
fig_obj = plotly.io.from_json(fig_list) | |
figure_objects.append(fig_obj) | |
except Exception as e: | |
logger.log_message(f"Error parsing Plotly JSON: {str(e)}", level=logging.WARNING) | |
continue | |
elif hasattr(fig_list, 'to_html'): | |
figure_objects.append(fig_list) | |
processed_data['plotly_figs'] = figure_objects | |
# Generate HTML report | |
html_report = generate_html_report(processed_data) | |
# Save report to database if we have a UUID | |
if report_uuid: | |
try: | |
from src.db.init_db import session_factory | |
from src.db.schemas.models import DeepAnalysisReport | |
db_session = session_factory() | |
try: | |
# Try to find existing report by UUID | |
report = db_session.query(DeepAnalysisReport).filter(DeepAnalysisReport.report_uuid == report_uuid).first() | |
if report: | |
# Update existing report with HTML content | |
report.html_report = html_report | |
report.updated_at = datetime.now(UTC) | |
db_session.commit() | |
except Exception as e: | |
db_session.rollback() | |
finally: | |
db_session.close() | |
except Exception as e: | |
logger.log_message(f"Database operation failed when storing HTML report: {str(e)}", level=logging.ERROR) | |
# Continue even if DB storage fails | |
# Create a filename with timestamp | |
timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S") | |
filename = f"deep_analysis_report_{timestamp}.html" | |
# Return as downloadable file | |
return StreamingResponse( | |
iter([html_report.encode('utf-8')]), | |
media_type='text/html', | |
headers={ | |
'Content-Disposition': f'attachment; filename="{filename}"', | |
'Content-Type': 'text/html; charset=utf-8' | |
} | |
) | |
except Exception as e: | |
logger.log_message(f"Failed to generate HTML report: {str(e)}", level=logging.ERROR) | |
raise HTTPException(status_code=500, detail=f"Failed to generate report: {str(e)}") | |
# Register the routers, including the session router
app.include_router(chat_router)
app.include_router(analytics_router)
app.include_router(code_router)
app.include_router(session_router)
app.include_router(feedback_router)
app.include_router(deep_analysis_router)
app.include_router(templates_router)
app.include_router(blog_router)
if __name__ == "__main__": | |
port = int(os.environ.get("PORT", 8000)) | |
uvicorn.run(app, host="0.0.0.0", port=port) | |
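# To run locally (assuming this file is app.py, Housing.csv is present, and a
# .env with the variables shown above exists):
#
#   python app.py                # serves on 0.0.0.0:8000 by default
#   PORT=9000 python app.py      # or pick another port
#
# A quick smoke test against the assumed health route:
#
#   curl http://localhost:8000/health
#
# Chat endpoints additionally expect an X-Session-ID header (see X_SESSION_ID).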