import re
import json
import sys
import contextlib
from io import StringIO
import time
import logging
from src.utils.logger import Logger
import textwrap
logger = Logger(__name__, level="INFO", see_time=False, console_log=False)
@contextlib.contextmanager
def stdoutIO(stdout=None):
    """Temporarily redirect sys.stdout so output from exec'd code can be captured."""
    old = sys.stdout
    if stdout is None:
        stdout = StringIO()
    sys.stdout = stdout
    try:
        yield stdout
    finally:
        # Restore stdout even if the executed code raises, so later prints
        # don't keep writing into the dead buffer.
        sys.stdout = old
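# Illustrative usage (a minimal sketch, not executed on import): everything
# printed inside the block lands in the StringIO buffer instead of real stdout.
#
#   with stdoutIO() as s:
#       exec("print(2 + 2)")
#   s.getvalue()  # -> "4\n"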
# Precompile regex patterns for better performance
SENSITIVE_MODULES = re.compile(r"(os|sys|subprocess|dotenv|requests|http|socket|smtplib|ftplib|telnetlib|paramiko)")
IMPORT_PATTERN = re.compile(r"^\s*import\s+(" + SENSITIVE_MODULES.pattern + r").*?(\n|$)", re.MULTILINE)
FROM_IMPORT_PATTERN = re.compile(r"^\s*from\s+(" + SENSITIVE_MODULES.pattern + r").*?(\n|$)", re.MULTILINE)
DYNAMIC_IMPORT_PATTERN = re.compile(r"__import__\s*\(\s*['\"](" + SENSITIVE_MODULES.pattern + r")['\"].*?\)")
ENV_ACCESS_PATTERN = re.compile(r"(os\.getenv|os\.environ|load_dotenv|\.__import__\s*\(\s*['\"]os['\"].*?\.environ)")
FILE_ACCESS_PATTERN = re.compile(r"(open\(|read\(|write\(|file\(|with\s+open)")
# Enhanced API key detection patterns
API_KEY_PATTERNS = [
    # Direct key assignments
    re.compile(r"(?i)(api_?key|access_?token|secret_?key|auth_?token|password|credential|secret)s?\s*=\s*[\"\'][\w\-\+\/\=]{8,}[\"\']"),
    # Function calls with keys
    re.compile(r"(?i)\.set_api_key\(\s*[\"\'][\w\-\+\/\=]{8,}[\"\']"),
    # Dictionary assignments
    re.compile(r"(?i)['\"](?:api_?key|access_?token|secret_?key|auth_?token|password|credential|secret)['\"](?:\s*:\s*)[\"\'][\w\-\+\/\=]{8,}[\"\']"),
    # Common key formats (base64-like, hex)
    re.compile(r"[\"\'](?:[A-Za-z0-9\+\/\=]{32,}|[0-9a-fA-F]{32,})[\"\']"),
    # Bearer token pattern
    re.compile(r"[\"\'](Bearer\s+[\w\-\+\/\=]{8,})[\"\']"),
    # Inline URL with auth
    re.compile(r"https?:\/\/[\w\-\+\/\=]{8,}@")
]
# Network request patterns
NETWORK_REQUEST_PATTERNS = re.compile(r"(requests\.|urllib\.|http\.|\.post\(|\.get\(|\.connect\()")
def check_security_concerns(code_str):
    """Check code for security concerns and return info about what was found"""
    security_concerns = {
        "has_concern": False,
        "messages": [],
        "blocked_imports": False,
        "blocked_dynamic_imports": False,
        "blocked_env_access": False,
        "blocked_file_access": False,
        "blocked_api_keys": False,
        "blocked_network": False
    }

    # Check for sensitive imports
    if IMPORT_PATTERN.search(code_str) or FROM_IMPORT_PATTERN.search(code_str):
        security_concerns["has_concern"] = True
        security_concerns["blocked_imports"] = True
        security_concerns["messages"].append("Sensitive module imports blocked")

    # Check for __import__ bypass technique
    if DYNAMIC_IMPORT_PATTERN.search(code_str):
        security_concerns["has_concern"] = True
        security_concerns["blocked_dynamic_imports"] = True
        security_concerns["messages"].append("Dynamic import of sensitive modules blocked")

    # Check for environment variables access
    if ENV_ACCESS_PATTERN.search(code_str):
        security_concerns["has_concern"] = True
        security_concerns["blocked_env_access"] = True
        security_concerns["messages"].append("Environment variables access blocked")

    # Check for file operations
    if FILE_ACCESS_PATTERN.search(code_str):
        security_concerns["has_concern"] = True
        security_concerns["blocked_file_access"] = True
        security_concerns["messages"].append("File operations blocked")

    # Check for API key patterns
    for pattern in API_KEY_PATTERNS:
        if pattern.search(code_str):
            security_concerns["has_concern"] = True
            security_concerns["blocked_api_keys"] = True
            security_concerns["messages"].append("API key/token usage blocked")
            break

    # Check for network requests
    if NETWORK_REQUEST_PATTERNS.search(code_str):
        security_concerns["has_concern"] = True
        security_concerns["blocked_network"] = True
        security_concerns["messages"].append("Network requests blocked")

    return security_concerns
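# Illustrative usage (not executed on import):
#
#   concerns = check_security_concerns("import os\nprint(os.environ['KEY'])")
#   concerns["blocked_imports"]     # -> True
#   concerns["blocked_env_access"]  # -> True
#   concerns["messages"]            # -> ["Sensitive module imports blocked", ...]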
def clean_code_for_security(code_str, security_concerns):
    """Apply security modifications to the code based on detected concerns"""
    modified_code = code_str

    # Block sensitive imports if needed
    if security_concerns["blocked_imports"]:
        modified_code = IMPORT_PATTERN.sub(r'# BLOCKED: import \1\n', modified_code)
        modified_code = FROM_IMPORT_PATTERN.sub(r'# BLOCKED: from \1\n', modified_code)

    # Block dynamic imports if needed
    if security_concerns["blocked_dynamic_imports"]:
        modified_code = DYNAMIC_IMPORT_PATTERN.sub(r'"BLOCKED_DYNAMIC_IMPORT"', modified_code)

    # Block environment access if needed
    if security_concerns["blocked_env_access"]:
        modified_code = ENV_ACCESS_PATTERN.sub(r'"BLOCKED_ENV_ACCESS"', modified_code)

    # Block file operations if needed
    if security_concerns["blocked_file_access"]:
        modified_code = FILE_ACCESS_PATTERN.sub(r'"BLOCKED_FILE_ACCESS"', modified_code)

    # Block API keys if needed
    if security_concerns["blocked_api_keys"]:
        for pattern in API_KEY_PATTERNS:
            modified_code = pattern.sub(r'"BLOCKED_API_KEY"', modified_code)

    # Block network requests if needed
    if security_concerns["blocked_network"]:
        modified_code = NETWORK_REQUEST_PATTERNS.sub(r'"BLOCKED_NETWORK_REQUEST"', modified_code)

    # Add warning banner if needed
    if security_concerns["has_concern"]:
        security_message = "⚠️ SECURITY WARNING: " + ". ".join(security_concerns["messages"]) + "."
        modified_code = f"print('{security_message}')\n\n" + modified_code

    return modified_code
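# Illustrative usage (not executed on import): the two functions are meant to
# be used together before handing code to exec().
#
#   code = "import requests\nrequests.get('https://example.com')"
#   concerns = check_security_concerns(code)
#   safe_code = clean_code_for_security(code, concerns)
#   # safe_code now starts with a printed security warning, and the import and
#   # the network call are replaced with inert placeholder strings.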
def format_correlation_output(text):
    """Format correlation matrix output for better readability"""
    lines = text.split('\n')
    formatted_lines = []

    for line in lines:
        # Skip empty lines at the beginning
        if not line.strip() and not formatted_lines:
            continue
        if not line.strip():
            formatted_lines.append(line)
            continue

        # Check if this line contains correlation values or variable names
        stripped_line = line.strip()
        parts = stripped_line.split()

        if len(parts) > 1:
            # Check if this is a header line with variable names
            if all(part.replace('_', '').replace('-', '').isalpha() for part in parts):
                # This is a header row with variable names
                formatted_header = f"{'':12}"  # Empty first column for row labels
                for part in parts:
                    formatted_header += f"{part:>12}"
                formatted_lines.append(formatted_header)
            elif any(char.isdigit() for char in stripped_line) and ('.' in stripped_line or '-' in stripped_line):
                # This looks like a correlation line with numbers
                row_name = parts[0] if parts else ""
                values = parts[1:] if len(parts) > 1 else []
                formatted_row = f"{row_name:<12}"
                for value in values:
                    try:
                        val = float(value)
                        formatted_row += f"{val:>12.3f}"
                    except ValueError:
                        formatted_row += f"{value:>12}"
                formatted_lines.append(formatted_row)
            else:
                # Other lines (like titles)
                formatted_lines.append(line)
        else:
            formatted_lines.append(line)

    return '\n'.join(formatted_lines)
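# Illustrative usage (not executed on import): raw pandas corr() text is
# realigned into fixed-width columns, numbers rendered to three decimals.
#
#   raw = "       price  area\nprice  1.0    0.54\narea   0.54   1.0"
#   print(format_correlation_output(raw))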
def format_summary_stats(text):
    """Format summary statistics for better readability"""
    lines = text.split('\n')
    formatted_lines = []

    for line in lines:
        if not line.strip():
            formatted_lines.append(line)
            continue

        # Check if this is a header line with statistical terms only (missing first column)
        stripped_line = line.strip()
        if any(stat in stripped_line.lower() for stat in ['count', 'mean', 'median', 'std', 'min', 'max', '25%', '50%', '75%']):
            parts = stripped_line.split()
            # Check if this is a header row (starts with statistical terms)
            if parts and parts[0].lower() in ['count', 'mean', 'median', 'std', 'min', 'max', '25%', '50%', '75%']:
                # This is a header row - add proper spacing
                formatted_header = f"{'':12}"  # Empty first column for row labels
                for part in parts:
                    formatted_header += f"{part:>15}"
                formatted_lines.append(formatted_header)
            else:
                # This is a data row - format normally
                row_name = parts[0] if parts else ""
                values = parts[1:] if len(parts) > 1 else []
                formatted_row = f"{row_name:<12}"
                for value in values:
                    try:
                        if '.' in value or 'e' in value.lower():
                            val = float(value)
                            if abs(val) >= 1000000:
                                formatted_row += f"{val:>15.2e}"
                            elif abs(val) >= 1:
                                formatted_row += f"{val:>15.2f}"
                            else:
                                formatted_row += f"{val:>15.6f}"
                        else:
                            val = int(value)
                            formatted_row += f"{val:>15}"
                    except ValueError:
                        formatted_row += f"{value:>15}"
                formatted_lines.append(formatted_row)
        else:
            # Other lines (titles, etc.) - keep as is
            formatted_lines.append(line)

    return '\n'.join(formatted_lines)
def clean_print_statements(code_block):
    r"""
    Clean up print() statements that contain unwanted literal '\n' sequences,
    collapsing them into properly formatted calls without the stray newlines.
    """
    # This regex targets print statements, even if they have newlines inside
    return re.sub(r'print\((.*?)(\\n.*?)(.*?)\)', r'print(\1\3)', code_block, flags=re.DOTALL)
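# Illustrative usage (not executed on import): the literal backslash-n inside
# the print call is dropped, the rest of the arguments are preserved.
#
#   clean_print_statements('print("total:\\n", total)')
#   # -> 'print("total:", total)'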
def remove_code_block_from_summary(summary):
    # Use regex to remove fenced code blocks from the summary, then split into lines.
    # DOTALL is required so the pattern can span multi-line code blocks.
    summary = re.sub(r'```python\n(.*?)\n```', '', summary, flags=re.DOTALL)
    return summary.split("\n")
def remove_main_block(code):
    # Match the __main__ block
    pattern = r'(?m)^if\s+__name__\s*==\s*["\']__main__["\']\s*:\s*\n((?:\s+.*\n?)*)'
    match = re.search(pattern, code)
    if match:
        main_block = match.group(1)

        # Dedent the code block inside __main__
        dedented_block = textwrap.dedent(main_block)

        # Remove \n from any print statements in the block (also handling multiline print cases)
        dedented_block = clean_print_statements(dedented_block)

        # Replace the block in the code; use a function so backslashes in the
        # dedented code are not interpreted as regex escape sequences
        cleaned_code = re.sub(pattern, lambda m: dedented_block, code)

        # Optional: Remove leading newlines if any
        cleaned_code = cleaned_code.strip()

        return cleaned_code
    return code
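# Illustrative usage (not executed on import): the guard is stripped and its
# body promoted to module level so the code can run under exec().
#
#   src = 'x = 1\nif __name__ == "__main__":\n    print(x)\n'
#   remove_main_block(src)  # -> 'x = 1\nprint(x)'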
def format_code_block(code_str):
    code_clean = re.sub(r'^```python\n?', '', code_str, flags=re.MULTILINE)
    code_clean = re.sub(r'\n```$', '', code_clean)
    return f'\n{code_clean}\n'
def format_code_backticked_block(code_str):
    code_clean = re.sub(r'^```python\n?', '', code_str, flags=re.MULTILINE)
    code_clean = re.sub(r'\n```$', '', code_clean)

    # Remove reading the csv file if it's already in the context
    modified_code = re.sub(r"df\s*=\s*pd\.read_csv\([\"\'].*?[\"\']\).*?(\n|$)", '', code_clean)

    # Only match assignments at top level (not indented):
    # remove 'df = pd.DataFrame()' if it's at the top level
    modified_code = re.sub(
        r"^df\s*=\s*pd\.DataFrame\(\s*\)\s*(#.*)?$",
        '',
        modified_code,
        flags=re.MULTILINE
    )

    # Remove sample dataframe lines with multiple array values
    modified_code = re.sub(r"^# Sample DataFrames?.*?(\n|$)", '', modified_code, flags=re.MULTILINE | re.IGNORECASE)

    # Remove plt.show() statements
    modified_code = re.sub(r"plt\.show\(\).*?(\n|$)", '', modified_code)

    # Remove the __main__ guard so the code can run at top level
    code_clean = remove_main_block(modified_code)

    return f'```python\n{code_clean}\n```'
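# Illustrative usage (not executed on import): strips the fences, drops the
# redundant df setup line, and re-wraps the result for display.
#
#   block = '```python\ndf = pd.read_csv("data.csv")\nprint(df.shape)\n```'
#   format_code_backticked_block(block)
#   # -> '```python\nprint(df.shape)\n```'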
def execute_code_from_markdown(code_str, dataframe=None):
    import pandas as pd
    import plotly.express as px
    import plotly
    import plotly.graph_objects as go
    import matplotlib.pyplot as plt
    import seaborn as sns
    import numpy as np
    import re
    import traceback
    import sys
    from io import StringIO, BytesIO
    import base64

    # Check for security concerns in the code
    security_concerns = check_security_concerns(code_str)

    # Apply security modifications to the code
    modified_code = clean_code_for_security(code_str, security_concerns)

    # Enhanced print function that detects and formats tabular data
    captured_outputs = []
    original_print = print

    # Set pandas display options for full table display
    pd.set_option('display.max_columns', None)
    pd.set_option('display.max_rows', 20)  # Limit to 20 rows instead of unlimited
    pd.set_option('display.width', None)
    pd.set_option('display.max_colwidth', 50)
    pd.set_option('display.expand_frame_repr', False)
    def enhanced_print(*args, **kwargs):
        # Convert all args to strings
        str_args = [str(arg) for arg in args]
        output_text = kwargs.get('sep', ' ').join(str_args)

        # Special case for DataFrames - use pipe delimiter and clean format
        # (check len(args) first so an empty print() doesn't raise IndexError)
        if len(args) == 1 and isinstance(args[0], pd.DataFrame):
            # Format DataFrame with pipe delimiter using to_csv for reliable column separation
            df = args[0]
            csv_buffer = StringIO()
            # Export to CSV with pipe delimiter, preserving index
            df.to_csv(csv_buffer, sep='|', index=True, float_format='%.6g')
            csv_output = csv_buffer.getvalue()

            # Clean up the CSV output - remove quotes and extra formatting
            lines = csv_output.strip().split('\n')
            cleaned_lines = []
            for line in lines:
                # Remove any quotes that might have been added by to_csv
                clean_line = line.replace('"', '')
                # Split by pipe, strip whitespace from each part, then rejoin
                parts = [part.strip() for part in clean_line.split('|')]
                cleaned_lines.append(' | '.join(parts))

            output_text = '\n'.join(cleaned_lines)
            captured_outputs.append(f"<TABLE_START>\n{output_text}\n<TABLE_END>")
            original_print(output_text)
            return
        # Detect if this looks like tabular data (generic approach)
        is_table = False
        lines = output_text.split('\n')

        # 1. Multiple lines with consistent spacing
        if len(lines) > 2:
            # Count lines that look like they have multiple columns (2+ spaces between words)
            multi_column_lines = sum(1 for line in lines if len(line.split()) > 1 and '  ' in line)
            if multi_column_lines >= 2:  # At least 2 lines with multiple columns
                is_table = True

        # Check for pandas DataFrame patterns: lines starting with an index
        # number followed by spaces
        if any(re.search(r'^\s*\d+\s+', line) for line in lines):
            is_table = True

        # Look for table-like structured output with multiple rows of similar format
        if len(lines) >= 3:
            # Sample a few lines to check for consistent structure
            sample_lines = [lines[i] for i in range(min(len(lines), 5)) if lines[i].strip()]
            # Check for consistent whitespace patterns
            if len(sample_lines) >= 2:
                # Get positions of whitespace groups in first line
                whitespace_positions = []
                for i, line in enumerate(sample_lines):
                    if not line.strip():
                        continue
                    positions = [m.start() for m in re.finditer(r'\s{2,}', line)]
                    if i == 0:
                        whitespace_positions = positions
                    elif len(positions) == len(whitespace_positions):
                        # Check if whitespace positions are roughly the same
                        is_similar = all(abs(pos - whitespace_positions[j]) <= 3
                                         for j, pos in enumerate(positions)
                                         if j < len(whitespace_positions))
                        if is_similar:
                            is_table = True

        # 2. Contains common table indicators
        if any(indicator in output_text.lower() for indicator in [
            'count', 'mean', 'std', 'min', 'max', '25%', '50%', '75%',  # Summary stats
            'correlation', 'corr',  # Correlation tables
            'coefficient', 'r-squared', 'p-value',  # Regression tables
        ]):
            is_table = True

        # 3. Has many decimal numbers (likely a data table)
        if output_text.count('.') > 5 and len(lines) > 2:
            is_table = True

        # If we have detected a table, convert space-delimited to pipe-delimited format
        if is_table:
            # Convert the table to pipe-delimited format for better parsing in frontend
            formatted_lines = []
            for line in lines:
                if not line.strip():
                    formatted_lines.append(line)  # Keep empty lines
                    continue
                # Split by multiple spaces and join with pipe delimiter
                parts = re.split(r'\s{2,}', line.strip())
                if parts:
                    formatted_lines.append(" | ".join(parts))
                else:
                    formatted_lines.append(line)

            # Use the pipe-delimited format and mark it for table processing in the UI
            output_text = "\n".join(formatted_lines)
            captured_outputs.append(f"<TABLE_START>\n{output_text}\n<TABLE_END>")
        else:
            captured_outputs.append(output_text)

        # Also use original print for stdout capture
        original_print(*args, **kwargs)
    # Custom matplotlib capture function
    def capture_matplotlib_chart():
        """Capture current matplotlib figure as base64 encoded image"""
        try:
            fig = plt.gcf()  # Get current figure
            if fig.get_axes():  # Check if figure has any plots
                buffer = BytesIO()
                fig.savefig(buffer, format='png', dpi=150, bbox_inches='tight',
                            facecolor='white', edgecolor='none')
                buffer.seek(0)
                img_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
                buffer.close()
                plt.close(fig)  # Close the figure to free memory
                return img_base64
            return None
        except Exception:
            return None

    # Store original plt.show function
    original_plt_show = plt.show

    def custom_plt_show(*args, **kwargs):
        """Custom plt.show that captures the chart instead of displaying it"""
        img_base64 = capture_matplotlib_chart()
        if img_base64:
            matplotlib_outputs.append(img_base64)
        # Don't call original show to prevent display
    context = {
        'pd': pd,
        'px': px,
        'go': go,
        'plt': plt,
        'plotly': plotly,
        '__builtins__': __builtins__,
        '__import__': __import__,
        'sns': sns,
        'np': np,
        'json_outputs': [],  # List to store multiple Plotly JSON outputs
        'matplotlib_outputs': [],  # List to store matplotlib chart images as base64
        'print': enhanced_print  # Replace print with our enhanced version
    }

    # Alias matplotlib_outputs so the custom show function can append to it
    matplotlib_outputs = context['matplotlib_outputs']

    # Replace plt.show with our custom function
    plt.show = custom_plt_show

    # Rewrite fig.show() / fig.to_html() calls so Plotly figures are serialized
    # to JSON instead of being displayed
    modified_code = re.sub(
        r'(\w*_?)fig(\w*)\.show\(\)',
        r'json_outputs.append(plotly.io.to_json(\1fig\2, pretty=True))',
        modified_code
    )
    modified_code = re.sub(
        r'(\w*_?)fig(\w*)\.to_html\(.*?\)',
        r'json_outputs.append(plotly.io.to_json(\1fig\2, pretty=True))',
        modified_code
    )

    # Remove reading the csv file if it's already in the context
    modified_code = re.sub(r"df\s*=\s*pd\.read_csv\([\"\'].*?[\"\']\).*?(\n|$)", '', modified_code)

    # Only match assignments at top level (not indented):
    # remove 'df = pd.DataFrame()' if it's at the top level
    modified_code = re.sub(
        r"^df\s*=\s*pd\.DataFrame\(\s*\)\s*(#.*)?$",
        '',
        modified_code,
        flags=re.MULTILINE
    )
    # Custom display function for DataFrames to show head + tail for large datasets
    original_repr = pd.DataFrame.__repr__

    def custom_df_repr(self):
        if len(self) > 15:
            # For large DataFrames, show first 10 and last 5 rows
            head_part = self.head(10)
            tail_part = self.tail(5)
            head_str = head_part.__repr__()
            tail_str = tail_part.__repr__()
            # Extract just the data rows (skip the header from tail)
            tail_lines = tail_str.split('\n')
            tail_data = '\n'.join(tail_lines[1:])  # Skip header line
            return f"{head_str}\n...\n{tail_data}"
        else:
            return original_repr(self)

    # Apply custom representation temporarily
    pd.DataFrame.__repr__ = custom_df_repr
    # If a dataframe is provided, add it to the context
    if dataframe is not None:
        context['df'] = dataframe
        # Remove pd.read_csv() calls since the dataframe is already in the context
        modified_code = re.sub(r"pd\.read_csv\(\s*[\"\'].*?[\"\']\s*\)", '', modified_code)

    # Remove sample dataframe lines with multiple array values
    modified_code = re.sub(r"^# Sample DataFrames?.*?(\n|$)", '', modified_code, flags=re.MULTILINE | re.IGNORECASE)

    # Replace plt.savefig() calls with plt.show() to ensure plots are displayed;
    # plt.show() calls are kept and handled by our custom capture function
    modified_code = re.sub(r'plt\.savefig\([^)]*\)', 'plt.show()', modified_code)

    # Seaborn plots might not have an explicit plt.show(), so add one after
    # each seaborn plotting call
    seaborn_plot_functions = [
        'sns.scatterplot', 'sns.lineplot', 'sns.barplot', 'sns.boxplot', 'sns.violinplot',
        'sns.stripplot', 'sns.swarmplot', 'sns.pointplot', 'sns.catplot', 'sns.relplot',
        'sns.displot', 'sns.histplot', 'sns.kdeplot', 'sns.ecdfplot', 'sns.rugplot',
        'sns.distplot', 'sns.jointplot', 'sns.pairplot', 'sns.FacetGrid', 'sns.PairGrid',
        'sns.heatmap', 'sns.clustermap', 'sns.regplot', 'sns.lmplot', 'sns.residplot'
    ]
    # Add automatic plt.show() after seaborn plots if not already present
    for func in seaborn_plot_functions:
        pattern = rf'({re.escape(func)}\([^)]*\)(?:\.[^(]*\([^)]*\))*)'

        def add_show(match):
            plot_call = match.group(1)
            # Unconditionally append plt.show(); a duplicate call is harmless
            # because the custom show function only captures non-empty figures
            return f'{plot_call}\nplt.show()'

        modified_code = re.sub(pattern, add_show, modified_code)

    # Only add df = pd.read_csv() if no dataframe was provided and the code
    # does not already read a CSV itself
    if dataframe is None and 'pd.read_csv' not in modified_code:
        modified_code = re.sub(
            r'import pandas as pd',
            r'import pandas as pd\n\n# Read Housing.csv\ndf = pd.read_csv("Housing.csv")',
            modified_code
        )
    # Identify code blocks by agent marker comments
    code_blocks = []
    current_block = []
    current_block_name = "unknown"

    for line in modified_code.splitlines():
        # Check if line contains a block identifier comment
        block_match = re.match(r'^# ([a-zA-Z_]+)_agent code start', line)
        if block_match:
            # If we had a previous block, save it
            if current_block:
                code_blocks.append((current_block_name, '\n'.join(current_block)))
            # Start a new block
            current_block_name = block_match.group(1)
            current_block = []
        else:
            current_block.append(line)

    # Add the last block if it exists
    if current_block:
        code_blocks.append((current_block_name, '\n'.join(current_block)))

    # Execute each code block separately
    all_outputs = []
    for block_name, block_code in code_blocks:
        try:
            # Clear captured outputs for each block
            captured_outputs.clear()
            with stdoutIO() as s:
                exec(block_code, context)  # Execute the block

            # Get both stdout and our enhanced captured outputs
            stdout_output = s.getvalue()

            # Combine outputs, preferring our enhanced format when available
            if captured_outputs:
                combined_output = '\n'.join(captured_outputs)
            else:
                combined_output = stdout_output

            all_outputs.append((block_name, combined_output, None))  # None means no error
        except Exception as e:
            # Reset pandas options in case of error
            pd.reset_option('display.max_columns')
            pd.reset_option('display.max_rows')
            pd.reset_option('display.width')
            pd.reset_option('display.max_colwidth')
            pd.reset_option('display.expand_frame_repr')

            # Restore original DataFrame representation in case of error
            pd.DataFrame.__repr__ = original_repr
            # Restore original plt.show
            plt.show = original_plt_show

            error_traceback = traceback.format_exc()
            # Extract error message and error type
            error_message = str(e)
            error_type = type(e).__name__
            error_lines = error_traceback.splitlines()

            # Format error with context of the actual code
            formatted_error = f"Error in {block_name}_agent: {error_message}\n"
            # Add first few lines of traceback
            first_lines = error_lines[:3]
            formatted_error += "\n".join(first_lines) + "\n"

            # Parse problem variables/values from the error message
            problem_vars = []

            # Look for common error patterns
            if "not in index" in error_message:
                # Extract column names for 'not in index' errors
                column_match = re.search(r"\['([^']+)'(?:, '([^']+)')*\] not in index", error_message)
                if column_match:
                    problem_vars = [g for g in column_match.groups() if g is not None]

                # Look for DataFrame accessing operations and list/variable definitions
                potential_lines = []
                code_lines = block_code.splitlines()

                # First, find all DataFrame column access patterns
                df_access_patterns = []
                for i, line in enumerate(code_lines):
                    # Find DataFrame variables from patterns like "df_name[...]" or "df_name.loc[...]"
                    df_matches = re.findall(r'(\w+)(?:\[|\.)(?:loc|iloc|columns|at|iat|\.select)', line)
                    for df_var in df_matches:
                        df_access_patterns.append((i, df_var))
                    # Flag lines referencing variables that might contain column lists
                    if problem_vars and re.search(r'\b(numeric_columns|categorical_columns|columns|features|cols)\b', line):
                        potential_lines.append(i)

                # Identify the most likely problematic lines
                if df_access_patterns:
                    for i, df_var in df_access_patterns:
                        if any(re.search(rf'{df_var}\[.*?\]', line) for line in code_lines):
                            potential_lines.append(i)

                # If no specific lines found yet, look for any DataFrame operations
                if not potential_lines:
                    for i, line in enumerate(code_lines):
                        if re.search(r'(?:corr|drop|groupby|pivot|merge|join|concat|apply|map|filter|loc|iloc)\(', line):
                            potential_lines.append(i)

                # Sort and deduplicate
                potential_lines = sorted(set(potential_lines))
elif "name" in error_message and "is not defined" in error_message:
# Extract variable name for NameError
var_match = re.search(r"name '([^']+)' is not defined", error_message)
if var_match:
problem_vars = [var_match.group(1)]
elif "object has no attribute" in error_message:
# Extract attribute name for AttributeError
attr_match = re.search(r"'([^']+)' object has no attribute '([^']+)'", error_message)
if attr_match:
problem_vars = [f"{attr_match.group(1)}.{attr_match.group(2)}"]
# Scan code for lines containing the problem variables
if problem_vars:
formatted_error += "\nProblem likely in these lines:\n"
code_lines = block_code.splitlines()
problem_lines = []
# First try direct variable references
direct_matches = False
for i, line in enumerate(code_lines):
if any(var in line for var in problem_vars):
direct_matches = True
# Get line and its context (1 line before and after)
start_idx = max(0, i-1)
end_idx = min(len(code_lines), i+2)
for j in range(start_idx, end_idx):
line_prefix = f"{j+1}: "
if j == i: # The line with the problem variable
problem_lines.append(f"{line_prefix}>>> {code_lines[j]} <<<")
else:
problem_lines.append(f"{line_prefix}{code_lines[j]}")
problem_lines.append("") # Empty line between sections
# If no direct matches found but we identified potential problematic lines for DataFrame issues
if not direct_matches and "not in index" in error_message and 'potential_lines' in locals():
for i in potential_lines:
start_idx = max(0, i-1)
end_idx = min(len(code_lines), i+2)
for j in range(start_idx, end_idx):
line_prefix = f"{j+1}: "
if j == i:
problem_lines.append(f"{line_prefix}>>> {code_lines[j]} <<<")
else:
problem_lines.append(f"{line_prefix}{code_lines[j]}")
problem_lines.append("") # Empty line between sections
if problem_lines:
formatted_error += "\n".join(problem_lines)
else:
# Special message for column errors when we can't find the exact reference
if "not in index" in error_message:
formatted_error += (f"Unable to locate direct reference to columns: {', '.join(problem_vars)}\n"
f"Check for variables that might contain these column names (like numeric_columns, "
f"categorical_columns, etc.)\n")
else:
formatted_error += f"Unable to locate lines containing: {', '.join(problem_vars)}\n"
else:
# If we couldn't identify specific variables, check for line numbers in traceback
for line in reversed(error_lines): # Search from the end of traceback
# Look for user code references in the traceback
if ', line ' in line and '<module>' in line:
try:
line_num = int(re.search(r', line (\d+)', line).group(1))
code_lines = block_code.splitlines()
if 0 < line_num <= len(code_lines):
line_idx = line_num - 1
start_idx = max(0, line_idx-2)
end_idx = min(len(code_lines), line_idx+3)
formatted_error += "\nProblem at this location:\n"
for i in range(start_idx, end_idx):
line_prefix = f"{i+1}: "
if i == line_idx:
formatted_error += f"{line_prefix}>>> {code_lines[i]} <<<\n"
else:
formatted_error += f"{line_prefix}{code_lines[i]}\n"
break
except (ValueError, AttributeError, IndexError):
pass
# Add the last few lines of the traceback
formatted_error += "\nFull error details:\n"
last_lines = error_lines[-3:]
formatted_error += "\n".join(last_lines)
all_outputs.append((block_name, None, formatted_error))
    # Reset pandas options after execution
    pd.reset_option('display.max_columns')
    pd.reset_option('display.max_rows')
    pd.reset_option('display.width')
    pd.reset_option('display.max_colwidth')
    pd.reset_option('display.expand_frame_repr')

    # Restore original DataFrame representation
    pd.DataFrame.__repr__ = original_repr
    # Restore original plt.show
    plt.show = original_plt_show

    # Compile all outputs and errors
    output_text = ""
    json_outputs = context.get('json_outputs', [])
    matplotlib_outputs = context.get('matplotlib_outputs', [])
    error_found = False

    for block_name, output, error in all_outputs:
        if error:
            output_text += f"\n\n=== ERROR IN {block_name.upper()}_AGENT ===\n{error}\n"
            error_found = True
        elif output:
            output_text += f"\n\n=== OUTPUT FROM {block_name.upper()}_AGENT ===\n{output}\n"

    if error_found:
        return output_text, [], []
    else:
        return output_text, json_outputs, matplotlib_outputs
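# Illustrative usage (a sketch, not executed on import): run a small snippet
# against a provided dataframe and collect text output plus chart payloads.
#
#   import pandas as pd
#   df = pd.DataFrame({'Values': [10, 20, 30]})
#   text, plotly_jsons, charts = execute_code_from_markdown(
#       "# eda_agent code start\nprint(df['Values'].mean())", dataframe=df)
#   # text contains "=== OUTPUT FROM EDA_AGENT ===" followed by "20.0"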
def format_plan_instructions(plan_instructions):
    """
    Format any plan instructions (JSON string or dict) into markdown sections per agent.
    """
    if "basic_qa_agent" in str(plan_instructions):
        return "**Non-Data Request**: Please ask a data related query, don't waste credits!"

    # Parse input into a dict
    try:
        if isinstance(plan_instructions, str):
            try:
                instructions = json.loads(plan_instructions)
            except json.JSONDecodeError as e:
                # Try to clean the string if it's not valid JSON
                cleaned_str = plan_instructions.strip()
                if cleaned_str.startswith("'") and cleaned_str.endswith("'"):
                    cleaned_str = cleaned_str[1:-1]
                try:
                    instructions = json.loads(cleaned_str)
                except json.JSONDecodeError:
                    raise ValueError(f"Invalid JSON format in plan instructions: {str(e)}")
        elif isinstance(plan_instructions, dict):
            instructions = plan_instructions
        else:
            raise TypeError(f"Unsupported plan instructions type: {type(plan_instructions)}")
    except Exception as e:
        raise ValueError(f"Error processing plan instructions: {str(e)}")

    # logger.log_message(f"Plan instructions: {instructions}", level=logging.INFO)
    markdown_lines = []
    for agent, content in instructions.items():
        if agent != 'basic_qa_agent':
            agent_title = agent.replace('_', ' ').title()
            markdown_lines.append(f"#### {agent_title}")

            if isinstance(content, dict):
                # Handle 'create' key
                create_vals = content.get('create', [])
                if create_vals:
                    markdown_lines.append("- **Create**:")
                    for item in create_vals:
                        markdown_lines.append(f"  - {item}")
                else:
                    markdown_lines.append("- **Create**: None")

                # Handle 'use' key
                use_vals = content.get('use', [])
                if use_vals:
                    markdown_lines.append("- **Use**:")
                    for item in use_vals:
                        markdown_lines.append(f"  - {item}")
                else:
                    markdown_lines.append("- **Use**: None")

                # Handle 'instruction' key
                instr = content.get('instruction')
                if isinstance(instr, str) and instr:
                    markdown_lines.append(f"- **Instruction**: {instr}")
                else:
                    markdown_lines.append("- **Instruction**: None")
            else:
                # Fallback for non-dict content
                markdown_lines.append(f"- {content}")

            markdown_lines.append("")  # blank line between agents
        else:
            markdown_lines.append(f"**Non-Data Request**: {content.get('instruction')}")

    return "\n".join(markdown_lines).strip()
def format_complexity(instructions):
    markdown_lines = []
    complexity = None  # Initialize so non-dict input can't raise NameError below

    # Extract complexity from various possible locations in the structure
    if isinstance(instructions, dict):
        # Case 1: Direct complexity field
        if 'complexity' in instructions:
            complexity = instructions['complexity']
        # Case 2: Complexity in 'plan' object
        elif 'plan' in instructions and isinstance(instructions['plan'], dict):
            if 'complexity' in instructions['plan']:
                complexity = instructions['plan']['complexity']
            else:
                complexity = "unrelated"
        if 'plan' in instructions and isinstance(instructions['plan'], str) and "basic_qa_agent" in instructions['plan']:
            complexity = "unrelated"

    if complexity:
        # Pink color scheme variations
        color_map = {
            "unrelated": "#FFB6B6",     # Light pink
            "basic": "#FF9E9E",         # Medium pink
            "intermediate": "#FF7F7F",  # Main pink
            "advanced": "#FF5F5F"       # Dark pink
        }
        indicator_map = {
            "unrelated": "●",
            "basic": "●",
            "intermediate": "●●",
            "advanced": "●●●"
        }
        color = color_map.get(complexity.lower(), "#FFB6B6")  # Default to light pink
        indicator = indicator_map.get(complexity.lower(), "●")

        # Slightly larger display with pink styling
        markdown_lines.append(f"<div style='color: {color}; border: 2px solid {color}; padding: 2px 8px; border-radius: 12px; display: inline-block; font-size: 14.4px;'>{indicator} {complexity}</div>\n")

    return "\n".join(markdown_lines).strip()
def format_response_to_markdown(api_response, agent_name=None, dataframe=None):
    try:
        markdown = []
        # logger.log_message(f"API response for {agent_name} at {time.strftime('%Y-%m-%d %H:%M:%S')}: {api_response}", level=logging.INFO)

        if isinstance(api_response, dict):
            for key in api_response:
                # Check each agent payload for a rate-limit error; compare in
                # lowercase (the original mixed-case literal could never match
                # a lowercased string)
                if isinstance(api_response[key], dict) and "error" in api_response[key] \
                        and "litellm.ratelimiterror" in str(api_response[key]['error']).lower():
                    return "**Error**: Rate limit exceeded. Please try switching models from the settings."
            # You can add more checks here if needed for other keys

        # Handle error responses
        if isinstance(api_response, dict) and "error" in api_response:
            return f"**Error**: {api_response['error']}"

        if "response" in api_response and isinstance(api_response['response'], str):
            if any(err in api_response['response'].lower() for err in ["auth", "api", "lm"]):
                return "**Error**: Authentication failed. Please check your API key in settings and try again."
            if "model" in api_response['response'].lower():
                return "**Error**: Model configuration error. Please verify your model selection in settings."
        for agent, content in api_response.items():
            agent = agent.split("__")[0] if "__" in agent else agent
            if "memory" in agent or not content:
                continue

            if "complexity" in content:
                markdown.append(f"{format_complexity(content)}\n")

            markdown.append(f"\n## {agent.replace('_', ' ').title()}\n")

            if agent == "analytical_planner":
                logger.log_message(f"Analytical planner content: {content}", level=logging.INFO)
                if 'plan_desc' in content:
                    markdown.append(f"### Reasoning\n{content['plan_desc']}\n")
                    if 'plan_instructions' in content:
                        markdown.append(f"{format_plan_instructions(content['plan_instructions'])}\n")
                else:
                    markdown.append(f"### Reasoning\n{content['rationale']}\n")
            else:
                if "rationale" in content:
                    markdown.append(f"### Reasoning\n{content['rationale']}\n")

            if 'code' in content:
                markdown.append(f"### Code Implementation\n{format_code_backticked_block(content['code'])}\n")

            if 'answer' in content:
                markdown.append(f"### Answer\n{content['answer']}\nPlease ask a query about the data")
            if 'summary' in content:
                summary_text = content['summary']
                summary_text = re.sub(r'```python\n(.*?)\n```', '', summary_text, flags=re.DOTALL)
                markdown.append("### Summary\n")

                # Extract the pre-list intro, then the numbered bullet points
                intro_match = re.split(r'\(\d+\)', summary_text, maxsplit=1)
                if len(intro_match) > 1:
                    intro_text = intro_match[0].strip()
                    rest_text = "(1)" + intro_match[1]  # reattach for bullet parsing
                else:
                    intro_text = summary_text.strip()
                    rest_text = ""

                if intro_text:
                    markdown.append(f"{intro_text}\n")

                # Split bullets at numbered items like (1)...(8)
                bullets = re.split(r'\(\d+\)', rest_text)
                bullets = [b.strip(" ,.\n") for b in bullets if b.strip()]
                for bullet in bullets:
                    markdown.append(f"* {bullet}\n")
            if 'refined_complete_code' in content and 'summary' in content:
                # Initialize so a path that sets neither branch can't raise NameError
                markdown_code = None
                output = None
                json_outputs = []
                matplotlib_outputs = []
                try:
                    if content['refined_complete_code'] is not None and content['refined_complete_code'] != "":
                        clean_code = format_code_block(content['refined_complete_code'])
                        markdown_code = format_code_backticked_block(content['refined_complete_code'])
                        output, json_outputs, matplotlib_outputs = execute_code_from_markdown(clean_code, dataframe)
                    elif "```python" in content['summary']:
                        clean_code = format_code_block(content['summary'])
                        markdown_code = format_code_backticked_block(content['summary'])
                        output, json_outputs, matplotlib_outputs = execute_code_from_markdown(clean_code, dataframe)
                except Exception as e:
                    logger.log_message(f"Error in execute_code_from_markdown: {str(e)}", level=logging.ERROR)
                    markdown_code = f"**Error**: {str(e)}"
                    output = None
                    json_outputs = []
                    matplotlib_outputs = []

                if markdown_code is not None:
                    markdown.append(f"### Refined Complete Code\n{markdown_code}\n")

                if output:
                    markdown.append("### Execution Output\n")
                    markdown.append(f"```output\n{output}\n```\n")

                if json_outputs:
                    markdown.append("### Plotly JSON Outputs\n")
                    for json_output in json_outputs:
                        markdown.append(f"```plotly\n{json_output}\n```\n")

                if matplotlib_outputs:
                    markdown.append("### Matplotlib/Seaborn Charts\n")
                    for img_base64 in matplotlib_outputs:
                        markdown.append(f"```matplotlib\n{img_base64}\n```\n")
        # if agent_name is not None:
        #     if f"memory_{agent_name}" in api_response:
        #         markdown.append(f"### Memory\n{api_response[f'memory_{agent_name}']}\n")

    except Exception as e:
        logger.log_message(f"Error in format_response_to_markdown: {str(e)}", level=logging.ERROR)
        return f"{str(e)}"

    # logger.log_message(f"Generated markdown content for agent '{agent_name}' at {time.strftime('%Y-%m-%d %H:%M:%S')}: {markdown}, length: {len(markdown)}", level=logging.INFO)
    if not markdown or len(markdown) <= 1:
        logger.log_message(
            f"Invalid markdown content for agent '{agent_name}' at {time.strftime('%Y-%m-%d %H:%M:%S')}: "
            f"Content: '{markdown}', Type: {type(markdown)}, Length: {len(markdown) if markdown else 0}, "
            f"API Response: {api_response}",
            level=logging.ERROR
        )
        return " "

    return '\n'.join(markdown)
# Example usage with dummy data
if __name__ == "__main__":
    sample_response = {
        "code_combiner_agent": {
            "reasoning": "Sample reasoning for multiple charts.",
            "refined_complete_code": """
```python
import plotly.express as px
import pandas as pd

# Sample Data
df = pd.DataFrame({'Category': ['A', 'B', 'C'], 'Values': [10, 20, 30]})

# First Chart
fig = px.bar(df, x='Category', y='Values', title='Bar Chart')
fig.show()

# Second Chart
fig2 = px.pie(df, values='Values', names='Category', title='Pie Chart')
fig2.show()
```
"""
        }
    }

    formatted_md = format_response_to_markdown(sample_response)
    print(formatted_md)