Spaces:
Running
Running
import threading | |
from typing import Dict, Any, List, Tuple, Optional, Union | |
import io | |
from contextlib import redirect_stdout | |
from timeout_decorator import timeout | |
import base64 | |
from PIL import Image | |
from vis_python_exe import PythonExecutor, GenericRuntime | |
class SharedRuntimeExecutor(PythonExecutor):
    """
    Python executor that keeps one runtime per session so that variables can
    be shared across successive ``apply`` calls.

    Behaviour:
    1. With ``var_whitelist="RETAIN_ALL_VARS"`` every global variable whose
       name does not start with ``_sys_`` is kept between calls.
    2. Otherwise only ``image_clue_*`` variables, ``_captured_figures`` and
       variables whose names start with a whitelisted prefix are kept.
    3. Creating, looking up and removing per-session runtimes is guarded by
       a lock. Code execution itself is NOT serialized: concurrent ``apply``
       calls on the same session share one runtime, and ``redirect_stdout``
       swaps the process-wide ``sys.stdout``.
    """

    def __init__(
        self,
        runtime_class=None,
        get_answer_symbol: Optional[str] = None,
        get_answer_expr: Optional[str] = None,
        get_answer_from_stdout: bool = True,
        timeout_length: int = 20,
        var_whitelist: Union[List[str], str, None] = None,
    ):
        """
        Args:
            runtime_class: Forwarded to ``PythonExecutor``. New session
                runtimes are created as ``self.runtime_class(messages)``.
            get_answer_symbol, get_answer_expr, get_answer_from_stdout,
            timeout_length: Forwarded unchanged to ``PythonExecutor``.
            var_whitelist:
                - list of str: keep global variables whose names start with
                  any of these prefixes
                - "RETAIN_ALL_VARS": keep all variables
                - any other non-empty str: treated as a single prefix
                - None / empty: keep only the system variables
        """
        super().__init__(
            runtime_class=runtime_class,
            get_answer_symbol=get_answer_symbol,
            get_answer_expr=get_answer_expr,
            get_answer_from_stdout=get_answer_from_stdout,
            timeout_length=timeout_length,
        )
        # Variable retention policy.
        self.retain_all_vars = var_whitelist == "RETAIN_ALL_VARS"
        if self.retain_all_vars or not var_whitelist:
            whitelist: List[str] = []
        elif isinstance(var_whitelist, str):
            # A single prefix string; previously this crashed on `.append`.
            whitelist = [var_whitelist]
        else:
            # Copy so the caller's list is never mutated by the append below.
            whitelist = list(var_whitelist)
        # Always keep the figure buffer that _execute_shared reads.
        if '_captured_figures' not in whitelist:
            whitelist.append('_captured_figures')
        self.var_whitelist = whitelist
        # Session id -> runtime; guarded by _lock.
        self._runtime_pool: Dict[str, GenericRuntime] = {}
        self._lock = threading.Lock()

    def apply(self, code: str, messages: List[Dict], session_id: str = "default") -> Tuple[Any, str]:
        """Execute ``code`` in the session's runtime and keep its state.

        Args:
            code: Python source to run.
            messages: Passed to ``runtime_class`` only when the session's
                runtime is first created; ignored for existing sessions.
            session_id: Key selecting the shared runtime.

        Returns:
            ``(result, report)`` from ``_execute_shared``, or
            ``(None, "Execution failed: ...")`` if an exception escapes
            execution or the variable-retention step.
        """
        runtime = self._get_runtime(session_id, messages)
        try:
            result, report = self._execute_shared(code, runtime)
            # The retention policy is applied after every execution,
            # successful or not.
            self._clean_runtime_vars(runtime)
            return result, report
        except Exception as e:
            return None, f"Execution failed: {str(e)}"

    def _get_runtime(self, session_id: str, messages: List[Dict]) -> GenericRuntime:
        """Return the runtime for ``session_id``, creating it on first use.

        The lock covers only lookup/creation; it is released before the
        runtime is used.
        """
        with self._lock:
            runtime = self._runtime_pool.get(session_id)
            if runtime is None:
                runtime = self.runtime_class(messages)
                self._runtime_pool[session_id] = runtime
            return runtime

    def _execute_shared(self, code: str, runtime: GenericRuntime) -> Tuple[Any, str]:
        """Run ``code`` inside the given shared runtime.

        The result is taken, in priority order, from:
        - captured stdout (``get_answer_from_stdout``);
        - the global named ``answer_symbol``;
        - the evaluated ``answer_expr``;
        - otherwise, the value of the last non-blank line evaluated as an
          expression.

        Returns:
            ``(result, "Success")``. ``result`` is replaced by
            ``{'text': ..., 'images': ...}`` when the runtime's
            ``_captured_figures`` global is non-empty. Returns
            ``(None, "Error: ...")`` if anything raised, including a timeout.
        """
        # Blank lines are dropped before execution.
        # NOTE(review): this also removes blank lines inside multi-line
        # string literals.
        code_lines = [line for line in code.split('\n') if line.strip()]
        source = "\n".join(code_lines)
        # NOTE(review): `timeout` is used with timeout_decorator's default
        # (signal-based) mode, which only works in the main thread.
        try:
            if self.get_answer_from_stdout:
                program_io = io.StringIO()
                with redirect_stdout(program_io):
                    timeout(self.timeout_length)(runtime.exec_code)(source)
                result = program_io.getvalue()
            elif self.answer_symbol:
                # answer_symbol / answer_expr are presumably set by
                # PythonExecutor.__init__ from get_answer_* — confirm.
                timeout(self.timeout_length)(runtime.exec_code)(source)
                result = runtime._global_vars.get(self.answer_symbol, "")
            elif self.answer_expr:
                timeout(self.timeout_length)(runtime.exec_code)(source)
                result = timeout(self.timeout_length)(runtime.eval_code)(self.answer_expr)
            elif len(code_lines) > 1:
                # Execute everything but the last line, then evaluate it.
                timeout(self.timeout_length)(runtime.exec_code)("\n".join(code_lines[:-1]))
                result = timeout(self.timeout_length)(runtime.eval_code)(code_lines[-1])
            else:
                timeout(self.timeout_length)(runtime.exec_code)(source)
                result = ""
            # Attach any figures the runtime captured.
            # NOTE(review): _clean_runtime_vars keeps `_captured_figures`
            # between calls, so unless runtime_class resets it, figures from
            # earlier calls are reported again — verify.
            captured_figures = runtime._global_vars.get("_captured_figures", [])
            if captured_figures:
                result = {
                    'text': str(result).strip(),
                    'images': captured_figures,
                }
            return result, "Success"
        except Exception as e:
            return None, f"Error: {str(e)}"

    def _clean_runtime_vars(self, runtime: GenericRuntime):
        """Apply the variable-retention policy to the runtime's globals.

        Every global not selected by the policy is removed (clear + rebuild).
        Retained values pass through ``_serialize_var`` before being written
        back.
        """
        if self.retain_all_vars:
            # Keep-all mode: drop only names with the `_sys_` prefix.
            persistent_vars = {
                k: self._serialize_var(v)
                for k, v in runtime._global_vars.items()
                if not k.startswith('_sys_')
            }
        else:
            # Whitelist mode: injected images plus user-whitelisted prefixes.
            prefixes = ('image_clue_', *self.var_whitelist)
            persistent_vars = {
                k: self._serialize_var(v)
                for k, v in runtime._global_vars.items()
                if k.startswith(prefixes)
            }
        # Rebuild the variable namespace.
        runtime._global_vars.clear()
        runtime._global_vars.update(persistent_vars)
        # Make sure the figure buffer always exists.
        runtime._global_vars.setdefault('_captured_figures', [])

    def _serialize_var(self, var_value: Any) -> Any:
        """Convert a PIL image to a base64-encoded PNG string; pass other values through.

        NOTE(review): the converted value is written back into the runtime's
        globals, so a retained image variable becomes a ``str`` in later
        calls — confirm this is intended.
        """
        if isinstance(var_value, Image.Image):
            buf = io.BytesIO()
            var_value.save(buf, format='PNG')
            return base64.b64encode(buf.getvalue()).decode('utf-8')
        return var_value

    def cleanup_session(self, session_id: str):
        """Drop the runtime of ``session_id``, if any."""
        with self._lock:
            self._runtime_pool.pop(session_id, None)

    def cleanup_all(self):
        """Drop all session runtimes."""
        with self._lock:
            self._runtime_pool.clear()