import re
import json
import sys
import contextlib
from io import StringIO
import time
import logging
from src.utils.logger import Logger
import textwrap

logger = Logger(__name__, level="INFO", see_time=False, console_log=False)

@contextlib.contextmanager
def stdoutIO(stdout=None):
    """Temporarily replace sys.stdout so output from exec()'d code can be captured."""
    old = sys.stdout
    if stdout is None:
        stdout = StringIO()
    sys.stdout = stdout
    yield stdout
    sys.stdout = old
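
# Usage sketch: stdoutIO() redirects sys.stdout for the duration of the block so printed
# output can be collected as a string.
#
#     with stdoutIO() as s:
#         print("captured")
#     s.getvalue()  # -> "captured\n"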

# Precompile regex patterns for better performance
SENSITIVE_MODULES = re.compile(r"(os|sys|subprocess|dotenv|requests|http|socket|smtplib|ftplib|telnetlib|paramiko)")
IMPORT_PATTERN = re.compile(r"^\s*import\s+(" + SENSITIVE_MODULES.pattern + r").*?(\n|$)", re.MULTILINE)
FROM_IMPORT_PATTERN = re.compile(r"^\s*from\s+(" + SENSITIVE_MODULES.pattern + r").*?(\n|$)", re.MULTILINE)
DYNAMIC_IMPORT_PATTERN = re.compile(r"__import__\s*\(\s*['\"](" + SENSITIVE_MODULES.pattern + r")['\"].*?\)")
ENV_ACCESS_PATTERN = re.compile(r"(os\.getenv|os\.environ|load_dotenv|\.__import__\s*\(\s*['\"]os['\"].*?\.environ)")
FILE_ACCESS_PATTERN = re.compile(r"(open\(|read\(|write\(|file\(|with\s+open)")

# Enhanced API key detection patterns
API_KEY_PATTERNS = [
    # Direct key assignments
    re.compile(r"(?i)(api_?key|access_?token|secret_?key|auth_?token|password|credential|secret)s?\s*=\s*[\"\'][\w\-\+\/\=]{8,}[\"\']"),
    # Function calls with keys
    re.compile(r"(?i)\.set_api_key\(\s*[\"\'][\w\-\+\/\=]{8,}[\"\']"),
    # Dictionary assignments
    re.compile(r"(?i)['\"](?:api_?key|access_?token|secret_?key|auth_?token|password|credential|secret)['\"](?:\s*:\s*)[\"\'][\w\-\+\/\=]{8,}[\"\']"),
    # Common key formats (base64-like, hex)
    re.compile(r"[\"\'](?:[A-Za-z0-9\+\/\=]{32,}|[0-9a-fA-F]{32,})[\"\']"),
    # Bearer token pattern
    re.compile(r"[\"\'](Bearer\s+[\w\-\+\/\=]{8,})[\"\']"),
    # Inline URL with auth
    re.compile(r"https?:\/\/[\w\-\+\/\=]{8,}@")
]

# Network request patterns
NETWORK_REQUEST_PATTERNS = re.compile(r"(requests\.|urllib\.|http\.|\.post\(|\.get\(|\.connect\()")

# DataFrame creation with hardcoded data - block only this specific pattern
def check_security_concerns(code_str, dataset_names):
    """Check code for security concerns and return info about what was found."""
    security_concerns = {
        "has_concern": False,
        "blocked_imports": False,
        "blocked_dynamic_imports": False,
        "blocked_env_access": False,
        "blocked_file_access": False,
        "blocked_api_keys": False,
        "blocked_network": False,
        "blocked_dataframe_invention": False,
        "messages": []
    }

    dataset_names_pattern = "|".join(re.escape(name) for name in dataset_names)
    DATAFRAME_INVENTION_PATTERN = re.compile(
        rf"({dataset_names_pattern})\s*=\s*pd\.DataFrame\s*\(\s*\{{\s*[^}}]*\}}",
        re.MULTILINE
    )

    # Check for re-creation of a provided dataset from hardcoded literals
    if DATAFRAME_INVENTION_PATTERN.search(code_str):
        security_concerns["has_concern"] = True
        security_concerns["blocked_dataframe_invention"] = True
        security_concerns["messages"].append(f"DataFrame creation blocked for dataset variables: {', '.join(dataset_names)}")

    # Check for sensitive imports
    if IMPORT_PATTERN.search(code_str) or FROM_IMPORT_PATTERN.search(code_str):
        security_concerns["has_concern"] = True
        security_concerns["blocked_imports"] = True
        security_concerns["messages"].append("Sensitive module imports blocked")

    # Check for __import__ bypass technique
    if DYNAMIC_IMPORT_PATTERN.search(code_str):
        security_concerns["has_concern"] = True
        security_concerns["blocked_dynamic_imports"] = True
        security_concerns["messages"].append("Dynamic import of sensitive modules blocked")

    # Check for environment variables access
    if ENV_ACCESS_PATTERN.search(code_str):
        security_concerns["has_concern"] = True
        security_concerns["blocked_env_access"] = True
        security_concerns["messages"].append("Environment variables access blocked")

    # Check for file operations
    if FILE_ACCESS_PATTERN.search(code_str):
        security_concerns["has_concern"] = True
        security_concerns["blocked_file_access"] = True
        security_concerns["messages"].append("File operations blocked")

    # Check for API key patterns
    for pattern in API_KEY_PATTERNS:
        if pattern.search(code_str):
            security_concerns["has_concern"] = True
            security_concerns["blocked_api_keys"] = True
            security_concerns["messages"].append("API key/token usage blocked")
            break

    # Check for network requests
    if NETWORK_REQUEST_PATTERNS.search(code_str):
        security_concerns["has_concern"] = True
        security_concerns["blocked_network"] = True
        security_concerns["messages"].append("Network requests blocked")

    return security_concerns
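
# Usage sketch (hypothetical snippet and dataset name, not taken from the app itself):
#
#     concerns = check_security_concerns("import os\nprint(df.head())", ["df"])
#     concerns["blocked_imports"]   # True - the `import os` line matches IMPORT_PATTERN
#     concerns["messages"]          # ["Sensitive module imports blocked"]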

def clean_code_for_security(code_str, security_concerns, dataset_names):
    """Apply security modifications to the code based on detected concerns."""
    modified_code = code_str

    dataset_names_pattern = "|".join(re.escape(name) for name in dataset_names)
    DATAFRAME_INVENTION_PATTERN = re.compile(
        rf"({dataset_names_pattern})\s*=\s*pd\.DataFrame\s*\(\s*\{{\s*[^}}]*\}}",
        re.MULTILINE
    )

    # Block sensitive imports if needed
    if security_concerns["blocked_imports"]:
        modified_code = IMPORT_PATTERN.sub(r'# BLOCKED: import \1\n', modified_code)
        modified_code = FROM_IMPORT_PATTERN.sub(r'# BLOCKED: from \1\n', modified_code)

    # Block dynamic imports if needed
    if security_concerns["blocked_dynamic_imports"]:
        modified_code = DYNAMIC_IMPORT_PATTERN.sub(r'"BLOCKED_DYNAMIC_IMPORT"', modified_code)

    # Block environment access if needed
    if security_concerns["blocked_env_access"]:
        modified_code = ENV_ACCESS_PATTERN.sub(r'"BLOCKED_ENV_ACCESS"', modified_code)

    # Block file operations if needed
    if security_concerns["blocked_file_access"]:
        modified_code = FILE_ACCESS_PATTERN.sub(r'"BLOCKED_FILE_ACCESS"', modified_code)

    # Block API keys if needed
    if security_concerns["blocked_api_keys"]:
        for pattern in API_KEY_PATTERNS:
            modified_code = pattern.sub(r'"BLOCKED_API_KEY"', modified_code)

    # Block network requests if needed
    if security_concerns["blocked_network"]:
        modified_code = NETWORK_REQUEST_PATTERNS.sub(r'"BLOCKED_NETWORK_REQUEST"', modified_code)

    # Block DataFrame creation with hardcoded data if detected
    if security_concerns["blocked_dataframe_invention"]:
        modified_code = DATAFRAME_INVENTION_PATTERN.sub(
            r"# BLOCKED_DATAFRAME_INVENTION: \g<0>",
            modified_code
        )

    # Add warning banner if needed
    if security_concerns["has_concern"]:
        security_message = "⚠️ SECURITY WARNING: " + ". ".join(security_concerns["messages"]) + "."
        modified_code = f"print('{security_message}')\n\n" + modified_code

    return modified_code
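
# Usage sketch (hypothetical user code): the concerns returned by check_security_concerns
# drive which substitutions run, so the two functions are meant to be called back to back.
#
#     user_code = "api_key = 'abcdefgh12345678'\nprint(df.shape)"
#     concerns = check_security_concerns(user_code, ["df"])
#     safe_code = clean_code_for_security(user_code, concerns, ["df"])
#     # safe_code now starts with a printed security warning and the key assignment
#     # is replaced with "BLOCKED_API_KEY".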

def format_correlation_output(text):
    """Format correlation matrix output for better readability."""
    lines = text.split('\n')
    formatted_lines = []

    for line in lines:
        # Skip empty lines at the beginning
        if not line.strip() and not formatted_lines:
            continue
        if not line.strip():
            formatted_lines.append(line)
            continue

        # Check if this line contains correlation values or variable names
        stripped_line = line.strip()
        parts = stripped_line.split()

        if len(parts) > 1:
            # Check if this is a header line with variable names
            if all(part.replace('_', '').replace('-', '').isalpha() for part in parts):
                # This is a header row with variable names
                formatted_header = f"{'':12}"  # Empty first column for row labels
                for part in parts:
                    formatted_header += f"{part:>12}"
                formatted_lines.append(formatted_header)
            elif any(char.isdigit() for char in stripped_line) and ('.' in stripped_line or '-' in stripped_line):
                # This looks like a correlation line with numbers
                row_name = parts[0] if parts else ""
                values = parts[1:] if len(parts) > 1 else []
                formatted_row = f"{row_name:<12}"
                for value in values:
                    try:
                        val = float(value)
                        formatted_row += f"{val:>12.3f}"
                    except ValueError:
                        formatted_row += f"{value:>12}"
                formatted_lines.append(formatted_row)
            else:
                # Other lines (like titles)
                formatted_lines.append(line)
        else:
            formatted_lines.append(line)

    return '\n'.join(formatted_lines)
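
# Usage sketch (hypothetical df.corr() text): header rows of variable names and numeric rows
# are padded into fixed 12-character columns so the matrix stays aligned.
#
#     format_correlation_output("       price  qty\nprice  1.000  0.42\nqty    0.420  1.00")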

def format_summary_stats(text):
    """Format summary statistics for better readability."""
    lines = text.split('\n')
    formatted_lines = []

    for line in lines:
        if not line.strip():
            formatted_lines.append(line)
            continue

        # Check if this is a header line with statistical terms only (missing first column)
        stripped_line = line.strip()
        if any(stat in stripped_line.lower() for stat in ['count', 'mean', 'median', 'std', 'min', 'max', '25%', '50%', '75%']):
            parts = stripped_line.split()
            # Check if this is a header row (starts with statistical terms)
            if parts and parts[0].lower() in ['count', 'mean', 'median', 'std', 'min', 'max', '25%', '50%', '75%']:
                # This is a header row - add proper spacing
                formatted_header = f"{'':12}"  # Empty first column for row labels
                for part in parts:
                    formatted_header += f"{part:>15}"
                formatted_lines.append(formatted_header)
            else:
                # This is a data row - format normally
                row_name = parts[0] if parts else ""
                values = parts[1:] if len(parts) > 1 else []
                formatted_row = f"{row_name:<12}"
                for value in values:
                    try:
                        if '.' in value or 'e' in value.lower():
                            val = float(value)
                            if abs(val) >= 1000000:
                                formatted_row += f"{val:>15.2e}"
                            elif abs(val) >= 1:
                                formatted_row += f"{val:>15.2f}"
                            else:
                                formatted_row += f"{val:>15.6f}"
                        else:
                            val = int(value)
                            formatted_row += f"{val:>15}"
                    except ValueError:
                        formatted_row += f"{value:>15}"
                formatted_lines.append(formatted_row)
        else:
            # Other lines (titles, etc.) - keep as is
            formatted_lines.append(line)

    return '\n'.join(formatted_lines)

def clean_print_statements(code_block):
    """
    Clean up `print()` statements that contain an unwanted literal `\\n`,
    so print statements are properly formatted without unnecessary newlines.
    """
    # This regex targets print statements, even if they have newlines inside
    return re.sub(r'print\((.*?)(\\n.*?)(.*?)\)', r'print(\1\3)', code_block, flags=re.DOTALL)


def remove_code_block_from_summary(summary):
    # Use regex to strip fenced ```python blocks from the summary, then return it as a list of lines
    summary = re.sub(r'```python\n(.*?)\n```', '', summary)
    return summary.split("\n")

def remove_main_block(code):
    # Match the __main__ block
    pattern = r'(?m)^if\s+__name__\s*==\s*["\']__main__["\']\s*:\s*\n((?:\s+.*\n?)*)'
    match = re.search(pattern, code)
    if match:
        main_block = match.group(1)
        # Dedent the code block inside __main__
        dedented_block = textwrap.dedent(main_block)
        # Remove \n from any print statements in the block (also handling multiline print cases)
        dedented_block = clean_print_statements(dedented_block)
        # Replace the block in the code
        cleaned_code = re.sub(pattern, dedented_block, code)
        # Optional: remove leading newlines if any
        cleaned_code = cleaned_code.strip()
        return cleaned_code
    return code
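
# Usage sketch (hypothetical input string): a guarded __main__ block gets inlined at module
# level so the code can later be run under exec().
#
#     remove_main_block('if __name__ == "__main__":\n    print("hello")')
#     # -> 'print("hello")'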

def format_code_block(code_str):
    code_clean = re.sub(r'^```python\n?', '', code_str, flags=re.MULTILINE)
    code_clean = re.sub(r'\n```$', '', code_clean)
    return f'\n{code_clean}\n'

def format_code_backticked_block(code_str):
    # Add None check at the beginning
    if code_str is None:
        return "```python\n# No code available\n```"  # Return an empty code block instead of None
    # Add type check to ensure it's a string
    if not isinstance(code_str, str):
        return f"```python\n# Invalid code type: {type(code_str)}\n```"

    code_clean = re.sub(r'^```python\n?', '', code_str, flags=re.MULTILINE)
    code_clean = re.sub(r'\n```$', '', code_clean)

    # Remove reading the csv file if the data is already in the context
    modified_code = re.sub(r"df\s*=\s*pd\.read_csv\([\"\'].*?[\"\']\).*?(\n|$)", '', code_clean)
    # Comment out remaining 'df = ...' assignments
    modified_code = re.sub(r'^(\s*)(df\s*=.*)$', r'\1# \2', modified_code, flags=re.MULTILINE)

    # Only match assignments at top level (not indented):
    # remove 'df = pd.DataFrame()' if it's at the top level
    modified_code = re.sub(
        r"^df\s*=\s*pd\.DataFrame\(\s*\)\s*(#.*)?$",
        '',
        modified_code,
        flags=re.MULTILINE
    )

    # Remove sample dataframe lines with multiple array values
    modified_code = re.sub(r"^# Sample DataFrames?.*?(\n|$)", '', modified_code, flags=re.MULTILINE | re.IGNORECASE)

    # Remove plt.show() statements
    modified_code = re.sub(r"plt\.show\(\).*?(\n|$)", '', modified_code)

    # Remove any __main__ guard so the code reads as a flat script
    code_clean = remove_main_block(modified_code)

    return f'```python\n{code_clean}\n```'
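
# Usage sketch (hypothetical LLM output): the outer fence is stripped, read_csv and plt.show()
# lines are dropped, and the cleaned code is re-wrapped in a ```python fence for display.
#
#     format_code_backticked_block("```python\nprint(df.shape)\nplt.show()\n```")
#     # -> "```python\nprint(df.shape)\n\n```"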

def execute_code_from_markdown(code_str, datasets=None):
    import pandas as pd
    import plotly.express as px
    import plotly
    import plotly.graph_objects as go
    import matplotlib.pyplot as plt
    import seaborn as sns
    import numpy as np
    import re
    import traceback
    import sys
    from io import StringIO, BytesIO
    import base64

    context_names = list(datasets.keys())

    # Check for security concerns in the code
    security_concerns = check_security_concerns(code_str, context_names)

    # Apply security modifications to the code
    modified_code = clean_code_for_security(code_str, security_concerns, context_names)

    # Enhanced print function that detects and formats tabular data
    captured_outputs = []
    original_print = print

    # Set pandas display options for full table display
    pd.set_option('display.max_columns', None)
    pd.set_option('display.max_rows', 20)  # Limit to 20 rows instead of unlimited
    pd.set_option('display.width', None)
    pd.set_option('display.max_colwidth', 50)
    pd.set_option('display.expand_frame_repr', False)
    def enhanced_print(*args, **kwargs):
        # Convert all args to strings
        str_args = [str(arg) for arg in args]
        output_text = kwargs.get('sep', ' ').join(str_args)

        # Special case for DataFrames - use pipe delimiter and clean format
        if len(args) == 1 and isinstance(args[0], pd.DataFrame):
            # Format DataFrame with pipe delimiter using to_csv for reliable column separation
            df = args[0]
            # Use StringIO to capture CSV output with pipe delimiter
            from io import StringIO
            csv_buffer = StringIO()
            # Export to CSV with pipe delimiter, preserving index
            df.to_csv(csv_buffer, sep='|', index=True, float_format='%.6g')
            csv_output = csv_buffer.getvalue()

            # Clean up the CSV output - remove quotes and extra formatting
            lines = csv_output.strip().split('\n')
            cleaned_lines = []
            for line in lines:
                # Remove any quotes that might have been added by to_csv
                clean_line = line.replace('"', '')
                # Split by pipe, strip whitespace from each part, then rejoin
                parts = [part.strip() for part in clean_line.split('|')]
                cleaned_lines.append(' | '.join(parts))

            output_text = '\n'.join(cleaned_lines)
            captured_outputs.append(f"<TABLE_START>\n{output_text}\n<TABLE_END>")
            original_print(output_text)
            return

        # Detect if this looks like tabular data (generic approach)
        is_table = False

        # Check for table patterns:
        # 1. Multiple lines with consistent spacing
        lines = output_text.split('\n')
        if len(lines) > 2:
            # Count lines that look like they have multiple columns (2+ spaces between words)
            multi_column_lines = sum(1 for line in lines if len(line.split()) > 1 and '  ' in line)
            if multi_column_lines >= 2:  # At least 2 lines with multiple columns
                is_table = True

            # Check for pandas DataFrame patterns like index with column names:
            # look for lines starting with an index number followed by spaces
            if any(re.search(r'^\s*\d+\s+', line) for line in lines):
                is_table = True

            # Look for table-like structured output with multiple rows of similar format
            if len(lines) >= 3:
                # Sample a few lines to check for consistent structure
                sample_lines = [lines[i] for i in range(min(len(lines), 5)) if i < len(lines) and lines[i].strip()]
                # Check for consistent whitespace patterns
                if len(sample_lines) >= 2:
                    # Get positions of whitespace groups in first line
                    whitespace_positions = []
                    for i, line in enumerate(sample_lines):
                        if not line.strip():
                            continue
                        positions = [m.start() for m in re.finditer(r'\s{2,}', line)]
                        if i == 0:
                            whitespace_positions = positions
                        elif len(positions) == len(whitespace_positions):
                            # Check if whitespace positions are roughly the same
                            is_similar = all(abs(pos - whitespace_positions[j]) <= 3
                                             for j, pos in enumerate(positions)
                                             if j < len(whitespace_positions))
                            if is_similar:
                                is_table = True

        # 2. Contains common table indicators
        if any(indicator in output_text.lower() for indicator in [
            'count', 'mean', 'std', 'min', 'max', '25%', '50%', '75%',  # Summary stats
            'correlation', 'corr',  # Correlation tables
            'coefficient', 'r-squared', 'p-value',  # Regression tables
        ]):
            is_table = True

        # 3. Has many decimal numbers (likely a data table)
        if output_text.count('.') > 5 and len(lines) > 2:
            is_table = True

        # If we have detected a table, convert space-delimited to pipe-delimited format
        if is_table:
            # Convert the table to pipe-delimited format for better parsing in the frontend
            formatted_lines = []
            for line in lines:
                if not line.strip():
                    formatted_lines.append(line)  # Keep empty lines
                    continue
                # Split by multiple spaces and join with pipe delimiter
                parts = re.split(r'\s{2,}', line.strip())
                if parts:
                    formatted_lines.append(" | ".join(parts))
                else:
                    formatted_lines.append(line)

            # Use the pipe-delimited format
            output_text = "\n".join(formatted_lines)
            # Format and mark the output for table processing in the UI
            captured_outputs.append(f"<TABLE_START>\n{output_text}\n<TABLE_END>")
        else:
            captured_outputs.append(output_text)

        # Also use original print for stdout capture
        original_print(*args, **kwargs)

    # Custom matplotlib capture function
    def capture_matplotlib_chart():
        """Capture current matplotlib figure as base64 encoded image."""
        try:
            fig = plt.gcf()  # Get current figure
            if fig.get_axes():  # Check if figure has any plots
                buffer = BytesIO()
                fig.savefig(buffer, format='png', dpi=150, bbox_inches='tight',
                            facecolor='white', edgecolor='none')
                buffer.seek(0)
                img_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
                buffer.close()
                plt.close(fig)  # Close the figure to free memory
                return img_base64
            return None
        except Exception:
            return None

    # Store original plt.show function
    original_plt_show = plt.show

    def custom_plt_show(*args, **kwargs):
        """Custom plt.show that captures the chart instead of displaying it."""
        img_base64 = capture_matplotlib_chart()
        if img_base64:
            matplotlib_outputs.append(img_base64)
        # Don't call original show to prevent display

    context = {
        'pd': pd,
        'px': px,
        'go': go,
        'plt': plt,
        'plotly': plotly,
        '__builtins__': __builtins__,
        '__import__': __import__,
        'sns': sns,
        'np': np,
        'json_outputs': [],  # List to store multiple Plotly JSON outputs
        'matplotlib_outputs': [],  # List to store matplotlib chart images as base64
        'print': enhanced_print  # Replace print with our enhanced version
    }

    # Add matplotlib_outputs to local scope for the custom show function
    matplotlib_outputs = context['matplotlib_outputs']

    # Replace plt.show with our custom function
    plt.show = custom_plt_show

    # Modify code to store multiple JSON outputs
    modified_code = re.sub(
        r'(\w*_?)fig(\w*)\.show\(\)',
        r'json_outputs.append(plotly.io.to_json(\1fig\2, pretty=True))',
        modified_code
    )
    modified_code = re.sub(
        r'(\w*_?)fig(\w*)\.to_html\(.*?\)',
        r'json_outputs.append(plotly.io.to_json(\1fig\2, pretty=True))',
        modified_code
    )

    # Remove reading the csv file if the data is already in the context
    modified_code = re.sub(r"df\s*=\s*pd\.read_csv\([\"\'].*?[\"\']\).*?(\n|$)", '', modified_code)

    # Only match assignments at top level (not indented):
    # remove 'df = pd.DataFrame()' if it's at the top level
    modified_code = re.sub(
        r"^df\s*=\s*pd\.DataFrame\(\s*\)\s*(#.*)?$",
        '',
        modified_code,
        flags=re.MULTILINE
    )

    # Custom display function for DataFrames to show head + tail for large datasets
    original_repr = pd.DataFrame.__repr__

    def custom_df_repr(self):
        if len(self) > 15:
            # For large DataFrames, show first 10 and last 5 rows
            head_part = self.head(10)
            tail_part = self.tail(5)
            head_str = head_part.__repr__()
            tail_str = tail_part.__repr__()
            # Extract just the data rows (skip the header from tail)
            tail_lines = tail_str.split('\n')
            tail_data = '\n'.join(tail_lines[1:])  # Skip header line
            return f"{head_str}\n...\n{tail_data}"
        else:
            return original_repr(self)

    # Apply custom representation temporarily
    pd.DataFrame.__repr__ = custom_df_repr

    # If dataframes are provided, add them to the context
    for dataset_name, dataset_df in datasets.items():
        if dataset_df is not None:
            context[dataset_name] = dataset_df
            logger.log_message(f"Added dataset '{dataset_name}' to execution context", level=logging.DEBUG)

    # Remove pd.read_csv() calls since the data is already in the context
    modified_code = re.sub(r"pd\.read_csv\(\s*[\"\'].*?[\"\']\s*\)", '', modified_code)

    # Remove sample dataframe lines with multiple array values
    modified_code = re.sub(r"^# Sample DataFrames?.*?(\n|$)", '', modified_code, flags=re.MULTILINE | re.IGNORECASE)

    # Replace plt.savefig() calls with plt.show() to ensure plots are displayed;
    # plt.show() calls are kept - they'll be handled by our custom function
    modified_code = re.sub(r'plt\.savefig\([^)]*\)', 'plt.show()', modified_code)

    # Also handle seaborn plots that might not have an explicit plt.show():
    # add plt.show() after seaborn plot functions if not already present
    seaborn_plot_functions = [
        'sns.scatterplot', 'sns.lineplot', 'sns.barplot', 'sns.boxplot', 'sns.violinplot',
        'sns.stripplot', 'sns.swarmplot', 'sns.pointplot', 'sns.catplot', 'sns.relplot',
        'sns.displot', 'sns.histplot', 'sns.kdeplot', 'sns.ecdfplot', 'sns.rugplot',
        'sns.distplot', 'sns.jointplot', 'sns.pairplot', 'sns.FacetGrid', 'sns.PairGrid',
        'sns.heatmap', 'sns.clustermap', 'sns.regplot', 'sns.lmplot', 'sns.residplot'
    ]

    for func in seaborn_plot_functions:
        pattern = rf'({re.escape(func)}\([^)]*\)(?:\.[^(]*\([^)]*\))*)'

        def add_show(match):
            # Append plt.show() immediately after the matched plot call
            plot_call = match.group(1)
            return f'{plot_call}\nplt.show()'

        modified_code = re.sub(pattern, add_show, modified_code)

    # Identify code blocks by their "# <name>_agent code start" comments
    code_blocks = []
    current_block = []
    current_block_name = "unknown"

    for line in modified_code.splitlines():
        # Check if line contains a block identifier comment
        block_match = re.match(r'^# ([a-zA-Z_]+)_agent code start', line)
        if block_match:
            # If we had a previous block, save it
            if current_block:
                code_blocks.append((current_block_name, '\n'.join(current_block)))
            # Start a new block
            current_block_name = block_match.group(1)
            current_block = []
        else:
            current_block.append(line)

    # Add the last block if it exists
    if current_block:
        code_blocks.append((current_block_name, '\n'.join(current_block)))

    # Execute each code block separately
    all_outputs = []

    for block_name, block_code in code_blocks:
        try:
            # Clear captured outputs for each block
            captured_outputs.clear()

            # Fix indentation issues before execution
            try:
                block_code = textwrap.dedent(block_code)
            except Exception as dedent_error:
                logger.log_message(f"Failed to dedent code block '{block_name}': {str(dedent_error)}", level=logging.WARNING)

            with stdoutIO() as s:
                exec(block_code, context)  # Execute the block

            # Get both stdout and our enhanced captured outputs
            stdout_output = s.getvalue()

            # Combine outputs, preferring our enhanced format when available
            if captured_outputs:
                combined_output = '\n'.join(captured_outputs)
            else:
                combined_output = stdout_output

            all_outputs.append((block_name, combined_output, None))  # None means no error
        except Exception as e:
            # Reset pandas options in case of error
            pd.reset_option('display.max_columns')
            pd.reset_option('display.max_rows')
            pd.reset_option('display.width')
            pd.reset_option('display.max_colwidth')
            pd.reset_option('display.expand_frame_repr')

            # Restore original DataFrame representation in case of error
            pd.DataFrame.__repr__ = original_repr

            # Restore original plt.show
            plt.show = original_plt_show

            error_traceback = traceback.format_exc()

            # Extract error message and error type
            error_message = str(e)
            error_type = type(e).__name__
            error_lines = error_traceback.splitlines()

            # Format error with context of the actual code
            formatted_error = f"Error in {block_name}_agent: {error_message}\n"

            # Add first few lines of traceback
            first_lines = error_lines[:3]
            formatted_error += "\n".join(first_lines) + "\n"

            # Parse problem variables/values from the error message
            problem_vars = []

            # Look for common error patterns
            if "not in index" in error_message:
                # Extract column names for 'not in index' errors
                column_match = re.search(r"\['([^']+)'(?:, '([^']+)')*\] not in index", error_message)
                if column_match:
                    problem_vars = [g for g in column_match.groups() if g is not None]

                # Look for DataFrame accessing operations and list/variable definitions
                potential_lines = []
                code_lines = block_code.splitlines()

                # First, find all DataFrame column access patterns
                df_access_patterns = []
                for i, line in enumerate(code_lines):
                    # Find DataFrame variables from patterns like "df_name[...]" or "df_name.loc[...]"
                    df_matches = re.findall(r'(\w+)(?:\[|\.)(?:loc|iloc|columns|at|iat|\.select)', line)
                    for df_var in df_matches:
                        df_access_patterns.append((i, df_var))

                    # Find variables that might contain column lists
                    for var in problem_vars:
                        if re.search(r'\b(numeric_columns|categorical_columns|columns|features|cols)\b', line):
                            potential_lines.append(i)

                # Identify the most likely problematic lines
                if df_access_patterns:
                    for i, df_var in df_access_patterns:
                        if any(re.search(rf'{df_var}\[.*?\]', line) for line in code_lines):
                            potential_lines.append(i)

                # If no specific lines found yet, look for any DataFrame operations
                if not potential_lines:
                    for i, line in enumerate(code_lines):
                        if re.search(r'(?:corr|drop|groupby|pivot|merge|join|concat|apply|map|filter|loc|iloc)\(', line):
                            potential_lines.append(i)

                # Sort and deduplicate
                potential_lines = sorted(set(potential_lines))
| elif "name" in error_message and "is not defined" in error_message: | |
| # Extract variable name for NameError | |
| var_match = re.search(r"name '([^']+)' is not defined", error_message) | |
| if var_match: | |
| problem_vars = [var_match.group(1)] | |
| elif "object has no attribute" in error_message: | |
| # Extract attribute name for AttributeError | |
| attr_match = re.search(r"'([^']+)' object has no attribute '([^']+)'", error_message) | |
| if attr_match: | |
| problem_vars = [f"{attr_match.group(1)}.{attr_match.group(2)}"] | |
| # Scan code for lines containing the problem variables | |
| if problem_vars: | |
| formatted_error += "\nProblem likely in these lines:\n" | |
| code_lines = block_code.splitlines() | |
| problem_lines = [] | |
| # First try direct variable references | |
| direct_matches = False | |
| for i, line in enumerate(code_lines): | |
| if any(var in line for var in problem_vars): | |
| direct_matches = True | |
| # Get line and its context (1 line before and after) | |
| start_idx = max(0, i-1) | |
| end_idx = min(len(code_lines), i+2) | |
| for j in range(start_idx, end_idx): | |
| line_prefix = f"{j+1}: " | |
| if j == i: # The line with the problem variable | |
| problem_lines.append(f"{line_prefix}>>> {code_lines[j]} <<<") | |
| else: | |
| problem_lines.append(f"{line_prefix}{code_lines[j]}") | |
| problem_lines.append("") # Empty line between sections | |
| # If no direct matches found but we identified potential problematic lines for DataFrame issues | |
| if not direct_matches and "not in index" in error_message and 'potential_lines' in locals(): | |
| for i in potential_lines: | |
| start_idx = max(0, i-1) | |
| end_idx = min(len(code_lines), i+2) | |
| for j in range(start_idx, end_idx): | |
| line_prefix = f"{j+1}: " | |
| if j == i: | |
| problem_lines.append(f"{line_prefix}>>> {code_lines[j]} <<<") | |
| else: | |
| problem_lines.append(f"{line_prefix}{code_lines[j]}") | |
| problem_lines.append("") # Empty line between sections | |
                if problem_lines:
                    formatted_error += "\n".join(problem_lines)
                else:
                    # Special message for column errors when we can't find the exact reference
                    if "not in index" in error_message:
                        formatted_error += (f"Unable to locate direct reference to columns: {', '.join(problem_vars)}\n"
                                            f"Check for variables that might contain these column names (like numeric_columns, "
                                            f"categorical_columns, etc.)\n")
                    else:
                        formatted_error += f"Unable to locate lines containing: {', '.join(problem_vars)}\n"
            else:
                # If we couldn't identify specific variables, check for line numbers in the traceback
                for line in reversed(error_lines):  # Search from the end of the traceback
                    # Look for user code references in the traceback
                    if ', line ' in line and '<module>' in line:
                        try:
                            line_num = int(re.search(r', line (\d+)', line).group(1))
                            code_lines = block_code.splitlines()
                            if 0 < line_num <= len(code_lines):
                                line_idx = line_num - 1
                                start_idx = max(0, line_idx - 2)
                                end_idx = min(len(code_lines), line_idx + 3)
                                formatted_error += "\nProblem at this location:\n"
                                for i in range(start_idx, end_idx):
                                    line_prefix = f"{i+1}: "
                                    if i == line_idx:
                                        formatted_error += f"{line_prefix}>>> {code_lines[i]} <<<\n"
                                    else:
                                        formatted_error += f"{line_prefix}{code_lines[i]}\n"
                            break
                        except (ValueError, AttributeError, IndexError):
                            pass

            # Add the last few lines of the traceback
            formatted_error += "\nFull error details:\n"
            last_lines = error_lines[-3:]
            formatted_error += "\n".join(last_lines)

            all_outputs.append((block_name, None, formatted_error))

    # Reset pandas options after execution
    pd.reset_option('display.max_columns')
    pd.reset_option('display.max_rows')
    pd.reset_option('display.width')
    pd.reset_option('display.max_colwidth')
    pd.reset_option('display.expand_frame_repr')

    # Restore original DataFrame representation
    pd.DataFrame.__repr__ = original_repr

    # Restore original plt.show
    plt.show = original_plt_show

    # Compile all outputs and errors
    output_text = ""
    json_outputs = context.get('json_outputs', [])
    matplotlib_outputs = context.get('matplotlib_outputs', [])
    error_found = False

    for block_name, output, error in all_outputs:
        if error:
            output_text += f"\n\n=== ERROR IN {block_name.upper()}_AGENT ===\n{error}\n"
            error_found = True
        elif output:
            output_text += f"\n\n=== OUTPUT FROM {block_name.upper()}_AGENT ===\n{output}\n"

    if error_found:
        return output_text, [], []
    else:
        return output_text, json_outputs, matplotlib_outputs
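
# Usage sketch (hypothetical dataset): callers pass the generated code plus a dict of named
# DataFrames; the function returns (captured text output, plotly JSON strings, matplotlib
# images encoded as base64).
#
#     import pandas as pd
#     df = pd.DataFrame({"a": [1, 2, 3]})
#     text, plotly_json, charts = execute_code_from_markdown("print(df.describe())", {"df": df})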

def format_plan_instructions(plan_instructions):
    """
    Format any plan instructions (JSON string or dict) into markdown sections per agent.
    """
    # Short-circuit non-data requests routed to the basic QA agent
    if "basic_qa_agent" in str(plan_instructions):
        return "**Non-Data Request**: Please ask a data related query, don't waste credits!"

    # Parse input into a dict
    try:
        if isinstance(plan_instructions, str):
            try:
                instructions = json.loads(plan_instructions)
            except json.JSONDecodeError as e:
                # Try to clean the string if it's not valid JSON
                cleaned_str = plan_instructions.strip()
                if cleaned_str.startswith("'") and cleaned_str.endswith("'"):
                    cleaned_str = cleaned_str[1:-1]
                try:
                    instructions = json.loads(cleaned_str)
                except json.JSONDecodeError:
                    raise ValueError(f"Invalid JSON format in plan instructions: {str(e)}")
        elif isinstance(plan_instructions, dict):
            instructions = plan_instructions
        else:
            raise TypeError(f"Unsupported plan instructions type: {type(plan_instructions)}")
    except Exception as e:
        raise ValueError(f"Error processing plan instructions: {str(e)}")
| # logger.log_message(f"Plan instructions: {instructions}", level=logging.INFO) | |
| markdown_lines = [] | |
| for agent, content in instructions.items(): | |
| if agent != 'basic_qa_agent': | |
| agent_title = agent.replace('_', ' ').title() | |
| markdown_lines.append(f"#### {agent_title}") | |
| if isinstance(content, dict): | |
| # Handle 'create' key | |
| create_vals = content.get('create', []) | |
| if create_vals: | |
| markdown_lines.append(f"- **Create**:") | |
| for item in create_vals: | |
| markdown_lines.append(f" - {item}") | |
| else: | |
| markdown_lines.append(f"- **Create**: None") | |
| # Handle 'use' key | |
| use_vals = content.get('use', []) | |
| if use_vals: | |
| markdown_lines.append(f"- **Use**:") | |
| for item in use_vals: | |
| markdown_lines.append(f" - {item}") | |
| else: | |
| markdown_lines.append(f"- **Use**: None") | |
| # Handle 'instruction' key | |
| instr = content.get('instruction') | |
| if isinstance(instr, str) and instr: | |
| markdown_lines.append(f"- **Instruction**: {instr}") | |
| else: | |
| markdown_lines.append(f"- **Instruction**: None") | |
| else: | |
| # Fallback for non-dict content | |
| markdown_lines.append(f"- {content}") | |
| markdown_lines.append("") # blank line between agents | |
| else: | |
| markdown_lines.append(f"**Non-Data Request**: {content.get('instruction')}") | |
| return "\n".join(markdown_lines).strip() | |

def format_complexity(instructions):
    markdown_lines = []
    complexity = None

    # Extract complexity from various possible locations in the structure
    if isinstance(instructions, dict):
        # Case 1: Direct complexity field
        if 'complexity' in instructions:
            complexity = instructions['complexity']
        # Case 2: Complexity in 'plan' object
        elif 'plan' in instructions and isinstance(instructions['plan'], dict):
            if 'complexity' in instructions['plan']:
                complexity = instructions['plan']['complexity']
            else:
                complexity = "unrelated"

        if 'plan' in instructions and isinstance(instructions['plan'], str) and "basic_qa_agent" in instructions['plan']:
            complexity = "unrelated"

    if complexity:
        # Pink color scheme variations
        color_map = {
            "unrelated": "#FFB6B6",     # Light pink
            "basic": "#FF9E9E",         # Medium pink
            "intermediate": "#FF7F7F",  # Main pink
            "advanced": "#FF5F5F"       # Dark pink
        }
        indicator_map = {
            "unrelated": "○",
            "basic": "●",
            "intermediate": "●●",
            "advanced": "●●●"
        }

        color = color_map.get(complexity.lower(), "#FFB6B6")  # Default to light pink
        indicator = indicator_map.get(complexity.lower(), "○")

        # Slightly larger display with pink styling
        markdown_lines.append(f"<div style='color: {color}; border: 2px solid {color}; padding: 2px 8px; border-radius: 12px; display: inline-block; font-size: 14.4px;'>{indicator} {complexity}</div>\n")

    return "\n".join(markdown_lines).strip()

def format_response_to_markdown(api_response, agent_name=None, datasets=None):
    try:
        markdown = []
        # logger.log_message(f"API response for {agent_name} at {time.strftime('%Y-%m-%d %H:%M:%S')}: {api_response}", level=logging.INFO)

        if isinstance(api_response, dict):
            for key in api_response:
                if "error" in api_response[key] and "litellm.RateLimitError" in api_response[key]['error'].lower():
                    return f"**Error**: Rate limit exceeded. Please try switching models from the settings."
                # More checks can be added here for other keys if needed

        # Handle error responses
        if isinstance(api_response, dict) and "error" in api_response:
            return f"**Error**: {api_response['error']}"
        if "response" in api_response and isinstance(api_response['response'], str):
            if any(err in api_response['response'].lower() for err in ["auth", "api", "lm"]):
                return "**Error**: Authentication failed. Please check your API key in settings and try again."
            if "model" in api_response['response'].lower():
                return "**Error**: Model configuration error. Please verify your model selection in settings."

        for agent, content in api_response.items():
            agent = agent.split("__")[0] if "__" in agent else agent
            if "memory" in agent or not content:
                continue

            if "complexity" in content:
                markdown.append(f"{format_complexity(content)}\n")

            markdown.append(f"\n## {agent.replace('_', ' ').title()}\n")

            if agent == "analytical_planner":
                logger.log_message(f"Analytical planner content: {content}", level=logging.INFO)
                if 'plan_desc' in content:
                    markdown.append(f"### Reasoning\n{content['plan_desc']}\n")
                    if 'plan_instructions' in content:
                        markdown.append(f"{format_plan_instructions(content['plan_instructions'])}\n")
                else:
                    markdown.append(f"### Reasoning\n{content['rationale']}\n")
            else:
                if "rationale" in content:
                    markdown.append(f"### Reasoning\n{content['rationale']}\n")

            if 'code' in content and content['code'] is not None:
                formatted_code = format_code_backticked_block(content['code'])
                if formatted_code:  # Check that formatted_code is not None
                    markdown.append(f"### Code Implementation\n{formatted_code}\n")

            if 'answer' in content:
                markdown.append(f"### Answer\n{content['answer']}\n Please ask a query about the data")
            if 'summary' in content:
                import re
                summary_text = content['summary']
                summary_text = re.sub(r'```python\n(.*?)\n```', '', summary_text, flags=re.DOTALL)
                markdown.append("### Summary\n")

                # Extract pre-list intro, bullet points, and post-list text
                intro_match = re.split(r'\(\d+\)', summary_text, maxsplit=1)
                if len(intro_match) > 1:
                    intro_text = intro_match[0].strip()
                    rest_text = "(1)" + intro_match[1]  # reattach for bullet parsing
                else:
                    intro_text = summary_text.strip()
                    rest_text = ""

                if intro_text:
                    markdown.append(f"{intro_text}\n")

                # Split bullets at numbered items like (1)...(8)
                bullets = re.split(r'\(\d+\)', rest_text)
                bullets = [b.strip(" ,.\n") for b in bullets if b.strip()]

                # Render each numbered item as a bullet point
                for i, bullet in enumerate(bullets):
                    markdown.append(f"* {bullet}\n")
            if 'refined_complete_code' in content and 'summary' in content:
                markdown_code = None
                output = None
                json_outputs = []
                matplotlib_outputs = []
                try:
                    if content['refined_complete_code'] is not None and content['refined_complete_code'] != "":
                        clean_code = format_code_block(content['refined_complete_code'])
                        markdown_code = format_code_backticked_block(content['refined_complete_code'])
                        output, json_outputs, matplotlib_outputs = execute_code_from_markdown(clean_code, datasets)
                    elif "```python" in content['summary']:
                        clean_code = format_code_block(content['summary'])
                        markdown_code = format_code_backticked_block(content['summary'])
                        output, json_outputs, matplotlib_outputs = execute_code_from_markdown(clean_code, datasets)
                except Exception as e:
                    logger.log_message(f"Error in execute_code_from_markdown: {str(e)}", level=logging.ERROR)
                    markdown_code = f"**Error**: {str(e)}"
                    output = None
                    json_outputs = []
                    matplotlib_outputs = []

                if markdown_code is not None:
                    markdown.append(f"### Refined Complete Code\n{markdown_code}\n")

                if output:
                    markdown.append("### Execution Output\n")
                    markdown.append(f"```output\n{output}\n```\n")

                if json_outputs:
                    markdown.append("### Plotly JSON Outputs\n")
                    for idx, json_output in enumerate(json_outputs):
                        markdown.append(f"```plotly\n{json_output}\n```\n")

                if matplotlib_outputs:
                    markdown.append("### Matplotlib/Seaborn Charts\n")
                    for idx, img_base64 in enumerate(matplotlib_outputs):
                        markdown.append(f"```matplotlib\n{img_base64}\n```\n")
            # if agent_name is not None:
            #     if f"memory_{agent_name}" in api_response:
            #         markdown.append(f"### Memory\n{api_response[f'memory_{agent_name}']}\n")

    except Exception as e:
        logger.log_message(f"Error in format_response_to_markdown: {str(e)}", level=logging.ERROR)
        return f"error formatting markdown {str(e)}"

    # logger.log_message(f"Generated markdown content for agent '{agent_name}' at {time.strftime('%Y-%m-%d %H:%M:%S')}: {markdown}, length: {len(markdown)}", level=logging.INFO)

    if not markdown or len(markdown) <= 1:
        logger.log_message(
            f"Invalid markdown content for agent '{agent_name}' at {time.strftime('%Y-%m-%d %H:%M:%S')}: "
            f"Content: '{markdown}', Type: {type(markdown)}, Length: {len(markdown) if markdown else 0}, "
            f"API Response: {api_response}",
            level=logging.ERROR
        )
        return ""

    return '\n'.join(markdown)
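
# Usage sketch (hypothetical agent response and dataset): the top-level entry point that other
# modules are expected to call with the raw agent output and the loaded datasets.
#
#     md = format_response_to_markdown(
#         {"preprocessing_agent": {"rationale": "Inspect the data first.",
#                                  "code": "```python\nprint(df.shape)\n```"}},
#         datasets={"df": df},
#     )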