import re
import json
import sys
import contextlib
from io import StringIO
import time
import logging
from src.utils.logger import Logger
import textwrap

logger = Logger(__name__, level="INFO", see_time=False, console_log=False)

@contextlib.contextmanager
def stdoutIO(stdout=None):
    """Temporarily redirect sys.stdout so output from exec'd code can be captured."""
    old = sys.stdout
    if stdout is None:
        stdout = StringIO()
    sys.stdout = stdout
    try:
        yield stdout
    finally:
        # Always restore stdout, even if the executed block raises
        sys.stdout = old

# Precompile regex patterns for better performance
# Word boundaries so e.g. "import ossuary" is not flagged as importing "os"
SENSITIVE_MODULES = re.compile(r"\b(os|sys|subprocess|dotenv|requests|http|socket|smtplib|ftplib|telnetlib|paramiko)\b")
IMPORT_PATTERN = re.compile(r"^\s*import\s+(" + SENSITIVE_MODULES.pattern + r").*?(\n|$)", re.MULTILINE)
FROM_IMPORT_PATTERN = re.compile(r"^\s*from\s+(" + SENSITIVE_MODULES.pattern + r").*?(\n|$)", re.MULTILINE)
DYNAMIC_IMPORT_PATTERN = re.compile(r"__import__\s*\(\s*['\"](" + SENSITIVE_MODULES.pattern + r")['\"].*?\)")
ENV_ACCESS_PATTERN = re.compile(r"(os\.getenv|os\.environ|load_dotenv|\.__import__\s*\(\s*['\"]os['\"].*?\.environ)")
FILE_ACCESS_PATTERN = re.compile(r"(open\(|read\(|write\(|file\(|with\s+open)")

# Enhanced API key detection patterns
API_KEY_PATTERNS = [
    # Direct key assignments
    re.compile(r"(?i)(api_?key|access_?token|secret_?key|auth_?token|password|credential|secret)s?\s*=\s*[\"\'][\w\-\+\/\=]{8,}[\"\']"),
    # Function calls with keys
    re.compile(r"(?i)\.set_api_key\(\s*[\"\'][\w\-\+\/\=]{8,}[\"\']"),
    # Dictionary assignments
    re.compile(r"(?i)['\"](?:api_?key|access_?token|secret_?key|auth_?token|password|credential|secret)['\"](?:\s*:\s*)[\"\'][\w\-\+\/\=]{8,}[\"\']"),
    # Common key formats (base64-like, hex)
    re.compile(r"[\"\'](?:[A-Za-z0-9\+\/\=]{32,}|[0-9a-fA-F]{32,})[\"\']"),
    # Bearer token pattern
    re.compile(r"[\"\'](Bearer\s+[\w\-\+\/\=]{8,})[\"\']"),
    # Inline URL with auth
    re.compile(r"https?:\/\/[\w\-\+\/\=]{8,}@")
]

# Network request patterns
# (intentionally broad: also matches unrelated calls such as dict .get())
NETWORK_REQUEST_PATTERNS = re.compile(r"(requests\.|urllib\.|http\.|\.post\(|\.get\(|\.connect\()")

def check_security_concerns(code_str):
    """Check code for security concerns and return info about what was found"""
    security_concerns = {
        "has_concern": False,
        "messages": [],
        "blocked_imports": False,
        "blocked_dynamic_imports": False,
        "blocked_env_access": False,
        "blocked_file_access": False,
        "blocked_api_keys": False,
        "blocked_network": False
    }
    # Check for sensitive imports
    if IMPORT_PATTERN.search(code_str) or FROM_IMPORT_PATTERN.search(code_str):
        security_concerns["has_concern"] = True
        security_concerns["blocked_imports"] = True
        security_concerns["messages"].append("Sensitive module imports blocked")
    # Check for __import__ bypass technique
    if DYNAMIC_IMPORT_PATTERN.search(code_str):
        security_concerns["has_concern"] = True
        security_concerns["blocked_dynamic_imports"] = True
        security_concerns["messages"].append("Dynamic import of sensitive modules blocked")
    # Check for environment variable access
    if ENV_ACCESS_PATTERN.search(code_str):
        security_concerns["has_concern"] = True
        security_concerns["blocked_env_access"] = True
        security_concerns["messages"].append("Environment variables access blocked")
    # Check for file operations
    if FILE_ACCESS_PATTERN.search(code_str):
        security_concerns["has_concern"] = True
        security_concerns["blocked_file_access"] = True
        security_concerns["messages"].append("File operations blocked")
    # Check for API key patterns
    for pattern in API_KEY_PATTERNS:
        if pattern.search(code_str):
            security_concerns["has_concern"] = True
            security_concerns["blocked_api_keys"] = True
            security_concerns["messages"].append("API key/token usage blocked")
            break
    # Check for network requests
    if NETWORK_REQUEST_PATTERNS.search(code_str):
        security_concerns["has_concern"] = True
        security_concerns["blocked_network"] = True
        security_concerns["messages"].append("Network requests blocked")
    return security_concerns
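
# Illustrative sketch (not called by the app): what the concern dict looks like
# for a snippet that imports a sensitive module and touches the filesystem.
def _example_check_security_concerns():
    snippet = "import os\nprint(open('data.txt').read())"
    concerns = check_security_concerns(snippet)
    # Expected: has_concern is True, with the import and file-access flags set
    assert concerns["blocked_imports"] and concerns["blocked_file_access"]
    return concerns["messages"]  # ["Sensitive module imports blocked", "File operations blocked"]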

def clean_code_for_security(code_str, security_concerns):
    """Apply security modifications to the code based on detected concerns"""
    modified_code = code_str
    # Block sensitive imports if needed
    if security_concerns["blocked_imports"]:
        modified_code = IMPORT_PATTERN.sub(r'# BLOCKED: import \1\n', modified_code)
        modified_code = FROM_IMPORT_PATTERN.sub(r'# BLOCKED: from \1\n', modified_code)
    # Block dynamic imports if needed
    if security_concerns["blocked_dynamic_imports"]:
        modified_code = DYNAMIC_IMPORT_PATTERN.sub(r'"BLOCKED_DYNAMIC_IMPORT"', modified_code)
    # Block environment access if needed
    if security_concerns["blocked_env_access"]:
        modified_code = ENV_ACCESS_PATTERN.sub(r'"BLOCKED_ENV_ACCESS"', modified_code)
    # Block file operations if needed
    if security_concerns["blocked_file_access"]:
        modified_code = FILE_ACCESS_PATTERN.sub(r'"BLOCKED_FILE_ACCESS"', modified_code)
    # Block API keys if needed
    if security_concerns["blocked_api_keys"]:
        for pattern in API_KEY_PATTERNS:
            modified_code = pattern.sub(r'"BLOCKED_API_KEY"', modified_code)
    # Block network requests if needed
    if security_concerns["blocked_network"]:
        modified_code = NETWORK_REQUEST_PATTERNS.sub(r'"BLOCKED_NETWORK_REQUEST"', modified_code)
    # Add warning banner if needed
    if security_concerns["has_concern"]:
        security_message = "⚠️ SECURITY WARNING: " + ". ".join(security_concerns["messages"]) + "."
        modified_code = f"print('{security_message}')\n\n" + modified_code
    return modified_code
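
# Illustrative sketch: the end-to-end blocking pass. The flagged import line is
# commented out and a warning banner is prepended to the user's code.
def _example_clean_code_for_security():
    snippet = "import subprocess\nprint('hello')"
    concerns = check_security_concerns(snippet)
    cleaned = clean_code_for_security(snippet, concerns)
    # cleaned starts with print('⚠️ SECURITY WARNING: ...') and the import line
    # has been rewritten to '# BLOCKED: import subprocess'
    return cleaned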

def format_correlation_output(text):
    """Format correlation matrix output for better readability"""
    lines = text.split('\n')
    formatted_lines = []
    for line in lines:
        # Skip empty lines at the beginning
        if not line.strip() and not formatted_lines:
            continue
        if not line.strip():
            formatted_lines.append(line)
            continue
        # Check if this line contains correlation values or variable names
        stripped_line = line.strip()
        parts = stripped_line.split()
        if len(parts) > 1:
            # Check if this is a header line with variable names
            if all(part.replace('_', '').replace('-', '').isalpha() for part in parts):
                # This is a header row with variable names
                formatted_header = f"{'':12}"  # Empty first column for row labels
                for part in parts:
                    formatted_header += f"{part:>12}"
                formatted_lines.append(formatted_header)
            elif any(char.isdigit() for char in stripped_line) and ('.' in stripped_line or '-' in stripped_line):
                # This looks like a correlation line with numbers
                row_name = parts[0] if parts else ""
                values = parts[1:] if len(parts) > 1 else []
                formatted_row = f"{row_name:<12}"
                for value in values:
                    try:
                        val = float(value)
                        formatted_row += f"{val:>12.3f}"
                    except ValueError:
                        formatted_row += f"{value:>12}"
                formatted_lines.append(formatted_row)
            else:
                # Other lines (like titles)
                formatted_lines.append(line)
        else:
            formatted_lines.append(line)
    return '\n'.join(formatted_lines)
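
# Illustrative sketch: a small pandas-style correlation printout goes in, and
# the header is re-aligned while the values are rounded to three decimals.
def _example_format_correlation_output():
    raw = (
        "        price      area\n"
        "price   1.000000   0.535997\n"
        "area    0.535997   1.000000"
    )
    return format_correlation_output(raw)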

def format_summary_stats(text):
    """Format summary statistics for better readability"""
    lines = text.split('\n')
    formatted_lines = []
    for line in lines:
        if not line.strip():
            formatted_lines.append(line)
            continue
        # Check if this line mentions statistical terms
        stripped_line = line.strip()
        if any(stat in stripped_line.lower() for stat in ['count', 'mean', 'median', 'std', 'min', 'max', '25%', '50%', '75%']):
            parts = stripped_line.split()
            # Check if this is a header row (starts with a statistical term)
            if parts and parts[0].lower() in ['count', 'mean', 'median', 'std', 'min', 'max', '25%', '50%', '75%']:
                # This is a header row - add proper spacing
                formatted_header = f"{'':12}"  # Empty first column for row labels
                for part in parts:
                    formatted_header += f"{part:>15}"
                formatted_lines.append(formatted_header)
            else:
                # This is a data row - format normally
                row_name = parts[0] if parts else ""
                values = parts[1:] if len(parts) > 1 else []
                formatted_row = f"{row_name:<12}"
                for value in values:
                    try:
                        if '.' in value or 'e' in value.lower():
                            val = float(value)
                            if abs(val) >= 1000000:
                                formatted_row += f"{val:>15.2e}"
                            elif abs(val) >= 1:
                                formatted_row += f"{val:>15.2f}"
                            else:
                                formatted_row += f"{val:>15.6f}"
                        else:
                            val = int(value)
                            formatted_row += f"{val:>15}"
                    except ValueError:
                        formatted_row += f"{value:>15}"
                formatted_lines.append(formatted_row)
        else:
            # Other lines (titles, etc.) - keep as is
            formatted_lines.append(line)
    return '\n'.join(formatted_lines)
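
# Illustrative sketch: the header row of a transposed df.describe() printout is
# re-aligned into 15-character columns; lines that mention no statistic pass
# through unchanged.
def _example_format_summary_stats():
    raw = (
        "        count       mean        std\n"
        "price   545.0   4766729.2  1870440.0\n"
        "area    545.0      5150.5     2170.1"
    )
    return format_summary_stats(raw)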

def clean_print_statements(code_block):
    """
    Clean up print() statements that contain unwanted literal '\\n' sequences,
    so print output is not broken across lines unnecessarily.
    """
    # This regex targets print statements, even if they span newlines
    return re.sub(r'print\((.*?)(\\n.*?)(.*?)\)', r'print(\1\3)', code_block, flags=re.DOTALL)

def remove_code_block_from_summary(summary):
    # Use regex to strip fenced python code blocks from the summary text
    # (DOTALL so blocks spanning multiple lines are matched)
    summary = re.sub(r'```python\n(.*?)\n```', '', summary, flags=re.DOTALL)
    return summary.split("\n")

def remove_main_block(code):
    # Match the __main__ block
    pattern = r'(?m)^if\s+__name__\s*==\s*["\']__main__["\']\s*:\s*\n((?:\s+.*\n?)*)'
    match = re.search(pattern, code)
    if match:
        main_block = match.group(1)
        # Dedent the code block inside __main__
        dedented_block = textwrap.dedent(main_block)
        # Remove \n from any print statements in the block (also handling multiline print cases)
        dedented_block = clean_print_statements(dedented_block)
        # Replace the block in the code; use a lambda so backslashes in the
        # dedented code are not interpreted as regex escape sequences
        cleaned_code = re.sub(pattern, lambda _match: dedented_block, code)
        # Optional: remove leading/trailing newlines if any
        cleaned_code = cleaned_code.strip()
        return cleaned_code
    return code
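
# Illustrative sketch: the __main__ guard is stripped and its body is dedented
# to module level, so the combined script runs top to bottom under exec().
def _example_remove_main_block():
    src = (
        "def main():\n"
        "    print('hi')\n"
        "\n"
        "if __name__ == '__main__':\n"
        "    main()\n"
    )
    return remove_main_block(src)  # -> "def main():\n    print('hi')\n\nmain()"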

def format_code_block(code_str):
    code_clean = re.sub(r'^```python\n?', '', code_str, flags=re.MULTILINE)
    code_clean = re.sub(r'\n```$', '', code_clean)
    return f'\n{code_clean}\n'

def format_code_backticked_block(code_str):
    code_clean = re.sub(r'^```python\n?', '', code_str, flags=re.MULTILINE)
    code_clean = re.sub(r'\n```$', '', code_clean)
    # Remove reading the csv file since the dataframe is already in the context
    modified_code = re.sub(r"df\s*=\s*pd\.read_csv\([\"\'].*?[\"\']\).*?(\n|$)", '', code_clean)
    # Only match assignments at top level (not indented):
    # remove 'df = pd.DataFrame()' if it's at the top level
    modified_code = re.sub(
        r"^df\s*=\s*pd\.DataFrame\(\s*\)\s*(#.*)?$",
        '',
        modified_code,
        flags=re.MULTILINE
    )
    # Remove sample dataframe lines with multiple array values
    modified_code = re.sub(r"^# Sample DataFrames?.*?(\n|$)", '', modified_code, flags=re.MULTILINE | re.IGNORECASE)
    # Remove plt.show() statements
    modified_code = re.sub(r"plt\.show\(\).*?(\n|$)", '', modified_code)
    # Remove the __main__ block
    code_clean = remove_main_block(modified_code)
    return f'```python\n{code_clean}\n```'
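
# Illustrative sketch: fenced code from the LLM is normalized for display; the
# fence markers are kept but lines the executor injects on its own (the
# df = pd.read_csv(...) line, plt.show()) are stripped out.
def _example_format_code_backticked_block():
    raw = "```python\ndf = pd.read_csv('Housing.csv')\nprint(df.shape)\nplt.show()\n```"
    return format_code_backticked_block(raw)  # fenced block containing just print(df.shape)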

def execute_code_from_markdown(code_str, dataframe=None):
    import pandas as pd
    import plotly.express as px
    import plotly
    import plotly.graph_objects as go
    import matplotlib.pyplot as plt
    import seaborn as sns
    import numpy as np
    import re
    import traceback
    import sys
    from io import StringIO, BytesIO
    import base64
    # Check for security concerns in the code
    security_concerns = check_security_concerns(code_str)
    # Apply security modifications to the code
    modified_code = clean_code_for_security(code_str, security_concerns)
    # Enhanced print function that detects and formats tabular data
    captured_outputs = []
    original_print = print
    # Set pandas display options for full table display
    pd.set_option('display.max_columns', None)
    pd.set_option('display.max_rows', 20)  # Limit to 20 rows instead of unlimited
    pd.set_option('display.width', None)
    pd.set_option('display.max_colwidth', 50)
    pd.set_option('display.expand_frame_repr', False)

    def enhanced_print(*args, **kwargs):
        # Convert all args to strings
        str_args = [str(arg) for arg in args]
        output_text = kwargs.get('sep', ' ').join(str_args)
        # Special case for a single DataFrame argument - use pipe delimiter and clean format
        # (check the length first so print() with no args cannot raise an IndexError)
        if len(args) == 1 and isinstance(args[0], pd.DataFrame):
            # Format DataFrame with pipe delimiter using to_csv for reliable column separation
            df = args[0]
            # Use StringIO to capture CSV output with pipe delimiter
            csv_buffer = StringIO()
            # Export to CSV with pipe delimiter, preserving index
            df.to_csv(csv_buffer, sep='|', index=True, float_format='%.6g')
            csv_output = csv_buffer.getvalue()
            # Clean up the CSV output - remove quotes and extra formatting
            lines = csv_output.strip().split('\n')
            cleaned_lines = []
            for line in lines:
                # Remove any quotes that might have been added by to_csv
                clean_line = line.replace('"', '')
                # Split by pipe, strip whitespace from each part, then rejoin
                parts = [part.strip() for part in clean_line.split('|')]
                cleaned_lines.append(' | '.join(parts))
            output_text = '\n'.join(cleaned_lines)
            captured_outputs.append(f"<TABLE_START>\n{output_text}\n<TABLE_END>")
            original_print(output_text)
            return
        # Detect if this looks like tabular data (generic approach)
        is_table = False
        # Check for table patterns:
        # 1. Multiple lines with consistent spacing
        lines = output_text.split('\n')
        if len(lines) > 2:
            # Count lines that look like they have multiple columns (2+ spaces between words)
            multi_column_lines = sum(1 for line in lines if len(line.split()) > 1 and '  ' in line)
            if multi_column_lines >= 2:  # At least 2 lines with multiple columns
                is_table = True
        # Check for pandas DataFrame patterns: lines starting with an index number followed by spaces
        if any(re.search(r'^\s*\d+\s+', line) for line in lines):
            is_table = True
        # Look for table-like structured output with multiple rows of similar format
        if len(lines) >= 3:
            # Sample a few lines to check for consistent structure
            sample_lines = [lines[i] for i in range(min(len(lines), 5)) if lines[i].strip()]
            # Check for consistent whitespace patterns
            if len(sample_lines) >= 2:
                # Get positions of whitespace groups in the first line
                whitespace_positions = []
                for i, line in enumerate(sample_lines):
                    if not line.strip():
                        continue
                    positions = [m.start() for m in re.finditer(r'\s{2,}', line)]
                    if i == 0:
                        whitespace_positions = positions
                    elif len(positions) == len(whitespace_positions):
                        # Check if whitespace positions are roughly the same
                        is_similar = all(abs(pos - whitespace_positions[j]) <= 3
                                         for j, pos in enumerate(positions)
                                         if j < len(whitespace_positions))
                        if is_similar:
                            is_table = True
        # 2. Contains common table indicators
        if any(indicator in output_text.lower() for indicator in [
            'count', 'mean', 'std', 'min', 'max', '25%', '50%', '75%',  # Summary stats
            'correlation', 'corr',  # Correlation tables
            'coefficient', 'r-squared', 'p-value',  # Regression tables
        ]):
            is_table = True
        # 3. Has many decimal numbers (likely a data table)
        if output_text.count('.') > 5 and len(lines) > 2:
            is_table = True
        # If we have detected a table, convert space-delimited to pipe-delimited format
        if is_table:
            # Convert the table to pipe-delimited format for better parsing in the frontend
            formatted_lines = []
            for line in lines:
                if not line.strip():
                    formatted_lines.append(line)  # Keep empty lines
                    continue
                # Split by multiple spaces and join with pipe delimiter
                parts = re.split(r'\s{2,}', line.strip())
                if parts:
                    formatted_lines.append(" | ".join(parts))
                else:
                    formatted_lines.append(line)
            # Use the pipe-delimited format
            output_text = "\n".join(formatted_lines)
            # Format and mark the output for table processing in the UI
            captured_outputs.append(f"<TABLE_START>\n{output_text}\n<TABLE_END>")
        else:
            captured_outputs.append(output_text)
        # Also use original print for stdout capture
        original_print(*args, **kwargs)

    # Custom matplotlib capture function
    def capture_matplotlib_chart():
        """Capture the current matplotlib figure as a base64-encoded PNG"""
        try:
            fig = plt.gcf()  # Get current figure
            if fig.get_axes():  # Check if figure has any plots
                buffer = BytesIO()
                fig.savefig(buffer, format='png', dpi=150, bbox_inches='tight',
                            facecolor='white', edgecolor='none')
                buffer.seek(0)
                img_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
                buffer.close()
                plt.close(fig)  # Close the figure to free memory
                return img_base64
            return None
        except Exception:
            return None

    # Store original plt.show function
    original_plt_show = plt.show

    def custom_plt_show(*args, **kwargs):
        """Custom plt.show that captures the chart instead of displaying it"""
        img_base64 = capture_matplotlib_chart()
        if img_base64:
            matplotlib_outputs.append(img_base64)
        # Don't call original show to prevent display

    context = {
        'pd': pd,
        'px': px,
        'go': go,
        'plt': plt,
        'plotly': plotly,
        '__builtins__': __builtins__,
        '__import__': __import__,
        'sns': sns,
        'np': np,
        'json_outputs': [],  # List to store multiple Plotly JSON outputs
        'matplotlib_outputs': [],  # List to store matplotlib chart images as base64
        'print': enhanced_print  # Replace print with our enhanced version
    }
    # Add matplotlib_outputs to local scope for the custom show function
    matplotlib_outputs = context['matplotlib_outputs']
    # Replace plt.show with our custom function
    plt.show = custom_plt_show
    # Rewrite fig.show() / fig.to_html() calls to store Plotly JSON outputs instead
    modified_code = re.sub(
        r'(\w*_?)fig(\w*)\.show\(\)',
        r'json_outputs.append(plotly.io.to_json(\1fig\2, pretty=True))',
        modified_code
    )
    modified_code = re.sub(
        r'(\w*_?)fig(\w*)\.to_html\(.*?\)',
        r'json_outputs.append(plotly.io.to_json(\1fig\2, pretty=True))',
        modified_code
    )
    # Remove reading the csv file if it's already in the context
    modified_code = re.sub(r"df\s*=\s*pd\.read_csv\([\"\'].*?[\"\']\).*?(\n|$)", '', modified_code)
    # Only match assignments at top level (not indented):
    # remove 'df = pd.DataFrame()' if it's at the top level
    modified_code = re.sub(
        r"^df\s*=\s*pd\.DataFrame\(\s*\)\s*(#.*)?$",
        '',
        modified_code,
        flags=re.MULTILINE
    )
    # Custom display function for DataFrames to show head + tail for large datasets
    original_repr = pd.DataFrame.__repr__

    def custom_df_repr(self):
        if len(self) > 15:
            # For large DataFrames, show first 10 and last 5 rows
            head_part = self.head(10)
            tail_part = self.tail(5)
            head_str = head_part.__repr__()
            tail_str = tail_part.__repr__()
            # Extract just the data rows (skip the header from tail)
            tail_lines = tail_str.split('\n')
            tail_data = '\n'.join(tail_lines[1:])  # Skip header line
            return f"{head_str}\n...\n{tail_data}"
        else:
            return original_repr(self)

    # Apply custom representation temporarily
    pd.DataFrame.__repr__ = custom_df_repr
    # If a dataframe is provided, add it to the context
    if dataframe is not None:
        context['df'] = dataframe
        # Point any remaining pd.read_csv() calls at the in-context dataframe so
        # assignments like `data = pd.read_csv(...)` stay syntactically valid
        modified_code = re.sub(r"pd\.read_csv\(\s*[\"\'].*?[\"\']\s*\)", 'df', modified_code)
    # Remove sample dataframe lines with multiple array values
    modified_code = re.sub(r"^# Sample DataFrames?.*?(\n|$)", '', modified_code, flags=re.MULTILINE | re.IGNORECASE)
    # Replace plt.savefig() calls with plt.show() so plots are captured
    modified_code = re.sub(r'plt\.savefig\([^)]*\)', 'plt.show()', modified_code)
    # Keep plt.show() calls - they'll be handled by our custom capture function.
    # Also handle seaborn plots that might not have an explicit plt.show():
    # add plt.show() after seaborn plot functions
    seaborn_plot_functions = [
        'sns.scatterplot', 'sns.lineplot', 'sns.barplot', 'sns.boxplot', 'sns.violinplot',
        'sns.stripplot', 'sns.swarmplot', 'sns.pointplot', 'sns.catplot', 'sns.relplot',
        'sns.displot', 'sns.histplot', 'sns.kdeplot', 'sns.ecdfplot', 'sns.rugplot',
        'sns.distplot', 'sns.jointplot', 'sns.pairplot', 'sns.FacetGrid', 'sns.PairGrid',
        'sns.heatmap', 'sns.clustermap', 'sns.regplot', 'sns.lmplot', 'sns.residplot'
    ]
    for func in seaborn_plot_functions:
        pattern = rf'({re.escape(func)}\([^)]*\)(?:\.[^(]*\([^)]*\))*)'

        def add_show(match):
            plot_call = match.group(1)
            # Append plt.show() after the plot call; a duplicate show is harmless
            # because the second capture finds no open figure
            return f'{plot_call}\nplt.show()'

        modified_code = re.sub(pattern, add_show, modified_code)
    # Only add df = pd.read_csv() if no dataframe was provided and the code contains pd.read_csv
    if dataframe is None and 'pd.read_csv' not in modified_code:
        modified_code = re.sub(
            r'import pandas as pd',
            r'import pandas as pd\n\n# Read Housing.csv\ndf = pd.read_csv("Housing.csv")',
            modified_code
        )
    # Identify code blocks by comments
    code_blocks = []
    current_block = []
    current_block_name = "unknown"
    for line in modified_code.splitlines():
        # Check if line contains a block identifier comment
        block_match = re.match(r'^# ([a-zA-Z_]+)_agent code start', line)
        if block_match:
            # If we had a previous block, save it
            if current_block:
                code_blocks.append((current_block_name, '\n'.join(current_block)))
            # Start a new block
            current_block_name = block_match.group(1)
            current_block = []
        else:
            current_block.append(line)
    # Add the last block if it exists
    if current_block:
        code_blocks.append((current_block_name, '\n'.join(current_block)))
    # Execute each code block separately
    all_outputs = []
    for block_name, block_code in code_blocks:
        try:
            # Clear captured outputs for each block
            captured_outputs.clear()
            with stdoutIO() as s:
                exec(block_code, context)  # Execute the block
            # Get both stdout and our enhanced captured outputs
            stdout_output = s.getvalue()
            # Combine outputs, preferring our enhanced format when available
            if captured_outputs:
                combined_output = '\n'.join(captured_outputs)
            else:
                combined_output = stdout_output
            all_outputs.append((block_name, combined_output, None))  # None means no error
        except Exception as e:
            # Reset pandas options in case of error
            pd.reset_option('display.max_columns')
            pd.reset_option('display.max_rows')
            pd.reset_option('display.width')
            pd.reset_option('display.max_colwidth')
            pd.reset_option('display.expand_frame_repr')
            # Restore original DataFrame representation in case of error
            pd.DataFrame.__repr__ = original_repr
            # Restore original plt.show
            plt.show = original_plt_show
            error_traceback = traceback.format_exc()
            # Extract error message and error type
            error_message = str(e)
            error_type = type(e).__name__
            error_lines = error_traceback.splitlines()
            # Format error with context of the actual code
            formatted_error = f"Error in {block_name}_agent: {error_message}\n"
            # Add first few lines of traceback
            first_lines = error_lines[:3]
            formatted_error += "\n".join(first_lines) + "\n"
            # Parse problem variables/values from the error message
            problem_vars = []
            # Look for common error patterns
            if "not in index" in error_message:
                # Extract column names for 'not in index' errors
                column_match = re.search(r"\['([^']+)'(?:, '([^']+)')*\] not in index", error_message)
                if column_match:
                    problem_vars = [g for g in column_match.groups() if g is not None]
                # Look for DataFrame accessing operations and list/variable definitions
                potential_lines = []
                code_lines = block_code.splitlines()
                # First, find all DataFrame column access patterns
                df_access_patterns = []
                for i, line in enumerate(code_lines):
                    # Find DataFrame variables from patterns like "df_name[...]" or "df_name.loc[...]"
                    df_matches = re.findall(r'(\w+)(?:\[|\.)(?:loc|iloc|columns|at|iat|\.select)', line)
                    for df_var in df_matches:
                        df_access_patterns.append((i, df_var))
                    # Flag lines that define or use variables likely to hold column lists
                    if problem_vars and re.search(r'\b(numeric_columns|categorical_columns|columns|features|cols)\b', line):
                        potential_lines.append(i)
                # Identify the most likely problematic lines
                if df_access_patterns:
                    for i, df_var in df_access_patterns:
                        if any(re.search(rf'{df_var}\[.*?\]', line) for line in code_lines):
                            potential_lines.append(i)
                # If no specific lines found yet, look for any DataFrame operations
                if not potential_lines:
                    for i, line in enumerate(code_lines):
                        if re.search(r'(?:corr|drop|groupby|pivot|merge|join|concat|apply|map|filter|loc|iloc)\(', line):
                            potential_lines.append(i)
                # Sort and deduplicate
                potential_lines = sorted(set(potential_lines))
            elif "name" in error_message and "is not defined" in error_message:
                # Extract variable name for NameError
                var_match = re.search(r"name '([^']+)' is not defined", error_message)
                if var_match:
                    problem_vars = [var_match.group(1)]
            elif "object has no attribute" in error_message:
                # Extract attribute name for AttributeError
                attr_match = re.search(r"'([^']+)' object has no attribute '([^']+)'", error_message)
                if attr_match:
                    problem_vars = [f"{attr_match.group(1)}.{attr_match.group(2)}"]
            # Scan code for lines containing the problem variables
            if problem_vars:
                formatted_error += "\nProblem likely in these lines:\n"
                code_lines = block_code.splitlines()
                problem_lines = []
                # First try direct variable references
                direct_matches = False
                for i, line in enumerate(code_lines):
                    if any(var in line for var in problem_vars):
                        direct_matches = True
                        # Get line and its context (1 line before and after)
                        start_idx = max(0, i - 1)
                        end_idx = min(len(code_lines), i + 2)
                        for j in range(start_idx, end_idx):
                            line_prefix = f"{j+1}: "
                            if j == i:  # The line with the problem variable
                                problem_lines.append(f"{line_prefix}>>> {code_lines[j]} <<<")
                            else:
                                problem_lines.append(f"{line_prefix}{code_lines[j]}")
                        problem_lines.append("")  # Empty line between sections
                # If no direct matches found but we identified potential problematic lines for DataFrame issues
                if not direct_matches and "not in index" in error_message and 'potential_lines' in locals():
                    for i in potential_lines:
                        start_idx = max(0, i - 1)
                        end_idx = min(len(code_lines), i + 2)
                        for j in range(start_idx, end_idx):
                            line_prefix = f"{j+1}: "
                            if j == i:
                                problem_lines.append(f"{line_prefix}>>> {code_lines[j]} <<<")
                            else:
                                problem_lines.append(f"{line_prefix}{code_lines[j]}")
                        problem_lines.append("")  # Empty line between sections
                if problem_lines:
                    formatted_error += "\n".join(problem_lines)
                else:
                    # Special message for column errors when we can't find the exact reference
                    if "not in index" in error_message:
                        formatted_error += (f"Unable to locate direct reference to columns: {', '.join(problem_vars)}\n"
                                            f"Check for variables that might contain these column names (like numeric_columns, "
                                            f"categorical_columns, etc.)\n")
                    else:
                        formatted_error += f"Unable to locate lines containing: {', '.join(problem_vars)}\n"
            else:
                # If we couldn't identify specific variables, check for line numbers in the traceback
                for line in reversed(error_lines):  # Search from the end of the traceback
                    # Look for user code references in the traceback
                    if ', line ' in line and '<module>' in line:
                        try:
                            line_num = int(re.search(r', line (\d+)', line).group(1))
                            code_lines = block_code.splitlines()
                            if 0 < line_num <= len(code_lines):
                                line_idx = line_num - 1
                                start_idx = max(0, line_idx - 2)
                                end_idx = min(len(code_lines), line_idx + 3)
                                formatted_error += "\nProblem at this location:\n"
                                for i in range(start_idx, end_idx):
                                    line_prefix = f"{i+1}: "
                                    if i == line_idx:
                                        formatted_error += f"{line_prefix}>>> {code_lines[i]} <<<\n"
                                    else:
                                        formatted_error += f"{line_prefix}{code_lines[i]}\n"
                                break
                        except (ValueError, AttributeError, IndexError):
                            pass
            # Add the last few lines of the traceback
            formatted_error += "\nFull error details:\n"
            last_lines = error_lines[-3:]
            formatted_error += "\n".join(last_lines)
            all_outputs.append((block_name, None, formatted_error))
    # Reset pandas options after execution
    pd.reset_option('display.max_columns')
    pd.reset_option('display.max_rows')
    pd.reset_option('display.width')
    pd.reset_option('display.max_colwidth')
    pd.reset_option('display.expand_frame_repr')
    # Restore original DataFrame representation
    pd.DataFrame.__repr__ = original_repr
    # Restore original plt.show
    plt.show = original_plt_show
    # Compile all outputs and errors
    output_text = ""
    json_outputs = context.get('json_outputs', [])
    matplotlib_outputs = context.get('matplotlib_outputs', [])
    error_found = False
    for block_name, output, error in all_outputs:
        if error:
            output_text += f"\n\n=== ERROR IN {block_name.upper()}_AGENT ===\n{error}\n"
            error_found = True
        elif output:
            output_text += f"\n\n=== OUTPUT FROM {block_name.upper()}_AGENT ===\n{output}\n"
    if error_found:
        return output_text, [], []
    else:
        return output_text, json_outputs, matplotlib_outputs
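
# Illustrative sketch (assumes a small in-memory DataFrame): run one agent block
# and collect the three channels the UI consumes.
def _example_execute_code_from_markdown():
    import pandas as pd
    df = pd.DataFrame({'Category': ['A', 'B'], 'Values': [1, 2]})
    code = (
        "# preprocessing_agent code start\n"
        "print(df.describe())\n"
    )
    text, plotly_jsons, matplotlib_pngs = execute_code_from_markdown(code, dataframe=df)
    # text holds captured stdout (tables wrapped in <TABLE_START>/<TABLE_END>),
    # plotly_jsons one JSON string per fig.show(), matplotlib_pngs base64 PNGs
    return text, plotly_jsons, matplotlib_pngs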

def format_plan_instructions(plan_instructions):
    """
    Format plan instructions (JSON string or dict) into markdown sections per agent.
    """
    if "basic_qa_agent" in str(plan_instructions):
        return "**Non-Data Request**: Please ask a data related query, don't waste credits!"
    # Parse input into a dict
    try:
        if isinstance(plan_instructions, str):
            try:
                instructions = json.loads(plan_instructions)
            except json.JSONDecodeError as e:
                # Try to clean the string if it's not valid JSON
                cleaned_str = plan_instructions.strip()
                if cleaned_str.startswith("'") and cleaned_str.endswith("'"):
                    cleaned_str = cleaned_str[1:-1]
                try:
                    instructions = json.loads(cleaned_str)
                except json.JSONDecodeError:
                    raise ValueError(f"Invalid JSON format in plan instructions: {str(e)}")
        elif isinstance(plan_instructions, dict):
            instructions = plan_instructions
        else:
            raise TypeError(f"Unsupported plan instructions type: {type(plan_instructions)}")
    except Exception as e:
        raise ValueError(f"Error processing plan instructions: {str(e)}")
    # logger.log_message(f"Plan instructions: {instructions}", level=logging.INFO)
    markdown_lines = []
    for agent, content in instructions.items():
        if agent != 'basic_qa_agent':
            agent_title = agent.replace('_', ' ').title()
            markdown_lines.append(f"#### {agent_title}")
            if isinstance(content, dict):
                # Handle 'create' key
                create_vals = content.get('create', [])
                if create_vals:
                    markdown_lines.append("- **Create**:")
                    for item in create_vals:
                        markdown_lines.append(f"  - {item}")
                else:
                    markdown_lines.append("- **Create**: None")
                # Handle 'use' key
                use_vals = content.get('use', [])
                if use_vals:
                    markdown_lines.append("- **Use**:")
                    for item in use_vals:
                        markdown_lines.append(f"  - {item}")
                else:
                    markdown_lines.append("- **Use**: None")
                # Handle 'instruction' key
                instr = content.get('instruction')
                if isinstance(instr, str) and instr:
                    markdown_lines.append(f"- **Instruction**: {instr}")
                else:
                    markdown_lines.append("- **Instruction**: None")
            else:
                # Fallback for non-dict content
                markdown_lines.append(f"- {content}")
            markdown_lines.append("")  # blank line between agents
        else:
            markdown_lines.append(f"**Non-Data Request**: {content.get('instruction')}")
    return "\n".join(markdown_lines).strip()

def format_complexity(instructions):
    markdown_lines = []
    complexity = None  # Default so the check below is safe for non-dict input
    # Extract complexity from various possible locations in the structure
    if isinstance(instructions, dict):
        # Case 1: Direct complexity field
        if 'complexity' in instructions:
            complexity = instructions['complexity']
        # Case 2: Complexity in 'plan' object
        elif 'plan' in instructions and isinstance(instructions['plan'], dict):
            if 'complexity' in instructions['plan']:
                complexity = instructions['plan']['complexity']
            else:
                complexity = "unrelated"
        if 'plan' in instructions and isinstance(instructions['plan'], str) and "basic_qa_agent" in instructions['plan']:
            complexity = "unrelated"
    if complexity:
        # Pink color scheme variations
        color_map = {
            "unrelated": "#FFB6B6",     # Light pink
            "basic": "#FF9E9E",         # Medium pink
            "intermediate": "#FF7F7F",  # Main pink
            "advanced": "#FF5F5F"       # Dark pink
        }
        indicator_map = {
            "unrelated": "○",
            "basic": "●",
            "intermediate": "●●",
            "advanced": "●●●"
        }
        color = color_map.get(complexity.lower(), "#FFB6B6")  # Default to light pink
        indicator = indicator_map.get(complexity.lower(), "○")
        # Slightly larger display with pink styling
        markdown_lines.append(f"<div style='color: {color}; border: 2px solid {color}; padding: 2px 8px; border-radius: 12px; display: inline-block; font-size: 14.4px;'>{indicator} {complexity}</div>\n")
    return "\n".join(markdown_lines).strip()

def format_response_to_markdown(api_response, agent_name=None, dataframe=None):
    try:
        markdown = []
        # logger.log_message(f"API response for {agent_name} at {time.strftime('%Y-%m-%d %H:%M:%S')}: {api_response}", level=logging.INFO)
        if isinstance(api_response, dict):
            for key in api_response:
                # Compare lowercase against lowercase; the error value may not be a string
                if isinstance(api_response[key], dict) and "error" in api_response[key] \
                        and "litellm.ratelimiterror" in str(api_response[key]['error']).lower():
                    return "**Error**: Rate limit exceeded. Please try switching models from the settings."
                # You can add more checks here if needed for other keys
        # Handle error responses
        if isinstance(api_response, dict) and "error" in api_response:
            return f"**Error**: {api_response['error']}"
        if "response" in api_response and isinstance(api_response['response'], str):
            if any(err in api_response['response'].lower() for err in ["auth", "api", "lm"]):
                return "**Error**: Authentication failed. Please check your API key in settings and try again."
            if "model" in api_response['response'].lower():
                return "**Error**: Model configuration error. Please verify your model selection in settings."
        for agent, content in api_response.items():
            agent = agent.split("__")[0] if "__" in agent else agent
            if "memory" in agent or not content:
                continue
            if "complexity" in content:
                markdown.append(f"{format_complexity(content)}\n")
            markdown.append(f"\n## {agent.replace('_', ' ').title()}\n")
            if agent == "analytical_planner":
                logger.log_message(f"Analytical planner content: {content}", level=logging.INFO)
                if 'plan_desc' in content:
                    markdown.append(f"### Reasoning\n{content['plan_desc']}\n")
                if 'plan_instructions' in content:
                    markdown.append(f"{format_plan_instructions(content['plan_instructions'])}\n")
                else:
                    markdown.append(f"### Reasoning\n{content.get('rationale', '')}\n")
            else:
                if "rationale" in content:
                    markdown.append(f"### Reasoning\n{content['rationale']}\n")
            if 'code' in content:
                markdown.append(f"### Code Implementation\n{format_code_backticked_block(content['code'])}\n")
            if 'answer' in content:
                markdown.append(f"### Answer\n{content['answer']}\nPlease ask a query about the data.")
            if 'summary' in content:
                summary_text = content['summary']
                summary_text = re.sub(r'```python\n(.*?)\n```', '', summary_text, flags=re.DOTALL)
                markdown.append("### Summary\n")
                # Extract the pre-list intro and the numbered items
                intro_match = re.split(r'\(\d+\)', summary_text, maxsplit=1)
                if len(intro_match) > 1:
                    intro_text = intro_match[0].strip()
                    rest_text = "(1)" + intro_match[1]  # reattach for bullet parsing
                else:
                    intro_text = summary_text.strip()
                    rest_text = ""
                if intro_text:
                    markdown.append(f"{intro_text}\n")
                # Split bullets at numbered items like (1)...(8)
                bullets = re.split(r'\(\d+\)', rest_text)
                bullets = [b.strip(" ,.\n") for b in bullets if b.strip()]
                # Emit each numbered item as a markdown bullet
                for bullet in bullets:
                    markdown.append(f"* {bullet}\n")
            if 'refined_complete_code' in content and 'summary' in content:
                # Default all outputs so the references below are safe even when
                # neither branch of the try assigns them
                markdown_code = None
                output = None
                json_outputs = []
                matplotlib_outputs = []
                try:
                    if content['refined_complete_code'] is not None and content['refined_complete_code'] != "":
                        clean_code = format_code_block(content['refined_complete_code'])
                        markdown_code = format_code_backticked_block(content['refined_complete_code'])
                        output, json_outputs, matplotlib_outputs = execute_code_from_markdown(clean_code, dataframe)
                    elif "```python" in content['summary']:
                        clean_code = format_code_block(content['summary'])
                        markdown_code = format_code_backticked_block(content['summary'])
                        output, json_outputs, matplotlib_outputs = execute_code_from_markdown(clean_code, dataframe)
                except Exception as e:
                    logger.log_message(f"Error in execute_code_from_markdown: {str(e)}", level=logging.ERROR)
                    markdown_code = f"**Error**: {str(e)}"
                    output = None
                    json_outputs = []
                    matplotlib_outputs = []
                if markdown_code is not None:
                    markdown.append(f"### Refined Complete Code\n{markdown_code}\n")
                if output:
                    markdown.append("### Execution Output\n")
                    markdown.append(f"```output\n{output}\n```\n")
                if json_outputs:
                    markdown.append("### Plotly JSON Outputs\n")
                    for json_output in json_outputs:
                        markdown.append(f"```plotly\n{json_output}\n```\n")
                if matplotlib_outputs:
                    markdown.append("### Matplotlib/Seaborn Charts\n")
                    for img_base64 in matplotlib_outputs:
                        markdown.append(f"```matplotlib\n{img_base64}\n```\n")
        # if agent_name is not None:
        #     if f"memory_{agent_name}" in api_response:
        #         markdown.append(f"### Memory\n{api_response[f'memory_{agent_name}']}\n")
    except Exception as e:
        logger.log_message(f"Error in format_response_to_markdown: {str(e)}", level=logging.ERROR)
        return f"{str(e)}"
    # logger.log_message(f"Generated markdown content for agent '{agent_name}' at {time.strftime('%Y-%m-%d %H:%M:%S')}: {markdown}, length: {len(markdown)}", level=logging.INFO)
    if not markdown or len(markdown) <= 1:
        logger.log_message(
            f"Invalid markdown content for agent '{agent_name}' at {time.strftime('%Y-%m-%d %H:%M:%S')}: "
            f"Content: '{markdown}', Type: {type(markdown)}, Length: {len(markdown) if markdown else 0}, "
            f"API Response: {api_response}",
            level=logging.ERROR
        )
        return " "
    return '\n'.join(markdown)

# Example usage with dummy data
if __name__ == "__main__":
    sample_response = {
        "code_combiner_agent": {
            "reasoning": "Sample reasoning for multiple charts.",
            "refined_complete_code": """
```python
import plotly.express as px
import pandas as pd

# Sample Data
df = pd.DataFrame({'Category': ['A', 'B', 'C'], 'Values': [10, 20, 30]})

# First Chart
fig = px.bar(df, x='Category', y='Values', title='Bar Chart')
fig.show()

# Second Chart
fig2 = px.pie(df, values='Values', names='Category', title='Pie Chart')
fig2.show()
```
"""
        }
    }
    formatted_md = format_response_to_markdown(sample_response)
    print(formatted_md)