"""Jupyter-kernel-backed code execution environment.

Defines the bookkeeping dataclasses for sessions and executions, a kernel
spec provider / multi-kernel manager pair used to launch the ``taskweaver``
kernel, and the ``Environment`` class that drives kernels through magic
commands and converts their replies into ``ExecutionResult`` objects.
"""

import atexit
import json
import logging
import os
import sys
from ast import literal_eval
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

from jupyter_client.kernelspec import KernelSpec, KernelSpecManager
from jupyter_client.manager import KernelManager
from jupyter_client.multikernelmanager import MultiKernelManager

from taskweaver.ces.common import EnvPlugin, ExecutionArtifact, ExecutionResult, get_id

logger = logging.getLogger(__name__)
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(logging.DEBUG)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)

ExecType = Literal["user", "control"]
ResultMimeType = Union[
    Literal["text/plain", "text/html", "text/markdown", "text/latex"],
    str,
]


@dataclass
class DisplayData:
    """Payload of one ``display_data`` / ``update_display_data`` kernel message."""

    data: Dict[ResultMimeType, Any] = field(default_factory=dict)
    metadata: Dict[str, Any] = field(default_factory=dict)
    transient: Dict[str, Any] = field(default_factory=dict)


@dataclass
class EnvExecution:
    """Everything collected from the kernel for a single code execution."""

    exec_id: str
    code: str
    exec_type: ExecType = "user"

    # streaming output
    stdout: List[str] = field(default_factory=list)
    stderr: List[str] = field(default_factory=list)
    displays: List[DisplayData] = field(default_factory=list)

    # final output
    result: Dict[ResultMimeType, str] = field(default_factory=dict)
    error: str = ""


@dataclass
class EnvSession:
    """State tracked by ``Environment`` for one session and its kernel."""

    session_id: str
    kernel_status: Literal[
        "pending",
        "ready",
        "running",
        "stopped",
        "error",
    ] = "pending"
    kernel_id: str = ""
    execution_count: int = 0
    execution_dict: Dict[str, EnvExecution] = field(default_factory=dict)
    session_dir: str = ""
    session_var: Dict[str, str] = field(default_factory=dict)
    plugins: Dict[str, EnvPlugin] = field(default_factory=dict)


class KernelSpecProvider(KernelSpecManager):
    """Kernel spec manager that serves a built-in ``taskweaver`` kernel spec."""

    def get_kernel_spec(self, kernel_name: str) -> KernelSpec:
        """Return the spec for ``kernel_name``.

        The name ``"taskweaver"`` is answered with an in-memory spec that runs
        ``python -m taskweaver.ces.kernel.launcher``; any other name is
        delegated to the parent ``KernelSpecManager``.
        """
        if kernel_name == "taskweaver":
            return KernelSpec(
                argv=[
                    "python",
                    "-m",
                    "taskweaver.ces.kernel.launcher",
                    "-f",
                    "{connection_file}",
                ],
                display_name="TaskWeaver",
                language="python",
                metadata={"debugger": True},
            )
        return super().get_kernel_spec(kernel_name)


class TaskWeaverMultiKernelManager(MultiKernelManager):
    """Multi-kernel manager that honours a caller-chosen connection file."""

    def pre_start_kernel(
        self,
        kernel_name: Optional[str],
        kwargs: Any,
    ) -> Tuple[KernelManager, str, str]:
        """Run the parent hook, then apply ``env["CONNECTION_FILE"]`` if given.

        ``typing.Optional`` / ``typing.Tuple`` are used instead of ``X | None``
        and ``tuple[...]`` because annotations are evaluated when the method is
        defined, and the newer spellings fail on older interpreters.
        """
        env: Optional[Dict[str, str]] = kwargs.get("env")
        km, kernel_name, kernel_id = super().pre_start_kernel(kernel_name, kwargs)
        if env is not None and "CONNECTION_FILE" in env:
            km.connection_file = env["CONNECTION_FILE"]
        return km, kernel_name, kernel_id


class Environment:
    """Owns a set of sessions, each backed by one Jupyter kernel."""

    def __init__(
        self,
        env_id: Optional[str] = None,
        env_dir: Optional[str] = None,
    ) -> None:
        """Create the environment.

        Args:
            env_id: Identifier for this environment; generated with
                ``get_id(prefix="env")`` when omitted.
            env_dir: Base directory for session data; defaults to the current
                working directory.
        """
        self.session_dict: Dict[str, EnvSession] = {}
        self.id = get_id(prefix="env") if env_id is None else env_id
        self.env_dir = env_dir if env_dir is not None else os.getcwd()
        self.multi_kernel_manager = self.init_kernel_manager()

    def clean_up(self) -> None:
        """Stop every known session, logging (not raising) any failure."""
        for session in self.session_dict.values():
            try:
                self.stop_session(session.session_id)
            except Exception as e:
                logger.error(e)

    def init_kernel_manager(self) -> MultiKernelManager:
        """Register ``clean_up`` with ``atexit`` and build the kernel manager."""
        atexit.register(self.clean_up)
        return TaskWeaverMultiKernelManager(
            default_kernel_name="taskweaver",
            kernel_spec_manager=KernelSpecProvider(),
        )

    def start_session(
        self,
        session_id: str,
        session_dir: Optional[str] = None,
        cwd: Optional[str] = None,
    ) -> None:
        """Start a kernel for the session and run the session-init command.

        Args:
            session_id: Session to start; created if not yet known.
            session_dir: Directory for a newly created session (ignored when
                the session already exists).
            cwd: Working directory for the kernel; defaults to
                ``<session_dir>/cwd``.
        """
        session = self._get_session(session_id, session_dir=session_dir)
        ces_session_dir = os.path.join(session.session_dir, "ces")
        kernel_id = get_id(prefix="knl")

        os.makedirs(ces_session_dir, exist_ok=True)

        connection_file = os.path.join(
            ces_session_dir,
            f"conn-{session.session_id}-{kernel_id}.json",
        )

        cwd = cwd if cwd is not None else os.path.join(session.session_dir, "cwd")
        os.makedirs(cwd, exist_ok=True)

        # set python home from current python environment
        # (the interpreter path with its last two components removed)
        python_home = os.path.sep.join(sys.executable.split(os.path.sep)[:-2])
        python_path = os.pathsep.join(
            [
                os.path.realpath(os.path.join(os.path.dirname(__file__), "..", "..")),
                os.path.join(python_home, "Lib", "site-packages"),
            ]
            + sys.path,
        )

        # inherit current environment variables
        # TODO: filter out sensitive environment information
        kernel_env = os.environ.copy()
        kernel_env.update(
            {
                "TASKWEAVER_ENV_ID": self.id,
                "TASKWEAVER_SESSION_ID": session.session_id,
                "TASKWEAVER_SESSION_DIR": session.session_dir,
                "TASKWEAVER_LOGGING_FILE_PATH": os.path.join(
                    ces_session_dir,
                    "kernel_logging.log",
                ),
                "CONNECTION_FILE": connection_file,
                "PATH": os.environ["PATH"],
                "PYTHONPATH": python_path,
                "PYTHONHOME": python_home,
            },
        )
        session.kernel_id = self.multi_kernel_manager.start_kernel(
            kernel_id=kernel_id,
            cwd=cwd,
            env=kernel_env,
        )
        self._cmd_session_init(session)
        session.kernel_status = "ready"

    def execute_code(
        self,
        session_id: str,
        code: str,
        exec_id: Optional[str] = None,
    ) -> ExecutionResult:
        """Run ``code`` in the session's kernel and return the parsed result.

        A session still in the ``"pending"`` state is started first. The user
        code is bracketed by the pre-check and post-check control magics; the
        ``"data"`` entry of the post-check reply is passed to
        ``_parse_exec_result`` as the extra result.
        """
        exec_id = get_id(prefix="exec") if exec_id is None else exec_id
        session = self._get_session(session_id)
        if session.kernel_status == "pending":
            self.start_session(session_id)

        session.execution_count += 1
        execution_index = session.execution_count
        self._execute_control_code_on_kernel(
            session.kernel_id,
            f"%_taskweaver_exec_pre_check {execution_index} {exec_id}",
        )
        exec_result = self._execute_code_on_kernel(
            session.kernel_id,
            exec_id=exec_id,
            code=code,
        )
        exec_extra_result = self._execute_control_code_on_kernel(
            session.kernel_id,
            f"%_taskweaver_exec_post_check {execution_index} {exec_id}",
        )
        session.execution_dict[exec_id] = exec_result

        # TODO: handle session id, round id, post id, etc.
        return self._parse_exec_result(exec_result, exec_extra_result["data"])

    def load_plugin(
        self,
        session_id: str,
        plugin_name: str,
        plugin_impl: str,
        plugin_config: Optional[Dict[str, str]] = None,
    ) -> None:
        """Register and load a plugin, replacing any same-named plugin."""
        session = self._get_session(session_id)
        if plugin_name in session.plugins:
            prev_plugin = session.plugins[plugin_name]
            if prev_plugin.loaded:
                self._cmd_plugin_unload(session, prev_plugin)
            del session.plugins[plugin_name]

        plugin = EnvPlugin(
            name=plugin_name,
            impl=plugin_impl,
            config=plugin_config,
            loaded=False,
        )
        self._cmd_plugin_load(session, plugin)
        plugin.loaded = True
        session.plugins[plugin_name] = plugin

    def test_plugin(
        self,
        session_id: str,
        plugin_name: str,
    ) -> None:
        """Run the plugin-test command for a plugin recorded on the session.

        Raises:
            KeyError: If ``plugin_name`` is not in the session's plugins.
        """
        session = self._get_session(session_id)
        plugin = session.plugins[plugin_name]
        self._cmd_plugin_test(session, plugin)

    def unload_plugin(
        self,
        session_id: str,
        plugin_name: str,
    ) -> None:
        """Unload (if loaded) and forget a plugin; unknown names are ignored."""
        session = self._get_session(session_id)
        if plugin_name in session.plugins:
            plugin = session.plugins[plugin_name]
            if plugin.loaded:
                self._cmd_plugin_unload(session, plugin)
            del session.plugins[plugin_name]

    def update_session_var(
        self,
        session_id: str,
        session_var: Dict[str, str],
    ) -> None:
        """Merge ``session_var`` into the session and push it to the kernel."""
        session = self._get_session(session_id)
        session.session_var.update(session_var)
        self._update_session_var(session)

    def stop_session(self, session_id: str) -> None:
        """Shut the session's kernel down and mark the session ``"stopped"``.

        Sessions already stopped are left alone; pending sessions are only
        marked stopped. Errors during shutdown are logged, and the session is
        marked stopped regardless.
        """
        session = self._get_session(session_id)
        if session.kernel_status == "stopped":
            return
        if session.kernel_status == "pending":
            session.kernel_status = "stopped"
            return
        try:
            if session.kernel_id != "":
                kernel = self.multi_kernel_manager.get_kernel(session.kernel_id)
                is_alive = kernel.is_alive()
                if is_alive:
                    kernel.shutdown_kernel(now=True)
                kernel.cleanup_resources()
        except Exception as e:
            logger.error(e)
        session.kernel_status = "stopped"

    def download_file(self, session_id: str, file_path: str) -> str:
        """Return the ``text/plain`` result of the path-conversion magic."""
        session = self._get_session(session_id)
        full_path = self._execute_code_on_kernel(
            session.kernel_id,
            get_id(prefix="exec"),
            f"%%_taskweaver_convert_path\n{file_path}",
            silent=True,
        )
        return full_path.result["text/plain"]

    def _get_session(
        self,
        session_id: str,
        session_dir: Optional[str] = None,
    ) -> EnvSession:
        """Return the session for ``session_id``, creating it on first use.

        ``session_dir`` is only consulted when the session is created; the
        directory is created on disk at that point.
        """
        if session_id not in self.session_dict:
            new_session = EnvSession(session_id)
            new_session.session_dir = (
                session_dir if session_dir is not None else self._get_default_session_dir(session_id)
            )
            os.makedirs(new_session.session_dir, exist_ok=True)
            self.session_dict[session_id] = new_session

        return self.session_dict[session_id]

    def _get_default_session_dir(self, session_id: str) -> str:
        """Return ``<env_dir>/sessions/<session_id>``, creating ``sessions``."""
        os.makedirs(os.path.join(self.env_dir, "sessions"), exist_ok=True)
        return os.path.join(self.env_dir, "sessions", session_id)

    def _execute_control_code_on_kernel(
        self,
        kernel_id: str,
        code: str,
        silent: bool = False,
        store_history: bool = False,
    ) -> Dict[Literal["is_success", "message", "data"], Union[bool, str, Any]]:
        """Run a control command and return its literal-evaluated reply.

        Raises:
            Exception: If the execution reported an error, produced no
                ``text/plain`` result, or the reply has a falsy
                ``"is_success"`` entry.
        """
        exec_result = self._execute_code_on_kernel(
            kernel_id,
            get_id(prefix="exec"),
            code=code,
            silent=silent,
            store_history=store_history,
            exec_type="control",
        )
        if exec_result.error != "":
            raise Exception(exec_result.error)
        if "text/plain" not in exec_result.result:
            raise Exception("No text returned.")
        result = literal_eval(exec_result.result["text/plain"])
        if not result["is_success"]:
            raise Exception(result["message"])
        return result

    def _execute_code_on_kernel(
        self,
        kernel_id: str,
        exec_id: str,
        code: str,
        silent: bool = False,
        store_history: bool = True,
        exec_type: ExecType = "user",
    ) -> EnvExecution:
        """Execute ``code`` on a kernel and collect its iopub output.

        Messages are read until a ``status`` message reports the ``idle``
        execution state.

        Raises:
            AssertionError: If a message belongs to a different request, or
                carries an unsupported stream name or message type.
        """
        exec_result = EnvExecution(exec_id=exec_id, code=code, exec_type=exec_type)
        km = self.multi_kernel_manager.get_kernel(kernel_id)
        kc = km.client()
        kc.start_channels()
        # Everything after start_channels() is inside the try so the channels
        # are stopped even when the readiness wait or the request itself fails.
        try:
            kc.wait_for_ready(10)
            result_msg_id = kc.execute(
                code=code,
                silent=silent,
                store_history=store_history,
                allow_stdin=False,
                stop_on_error=True,
            )
            # TODO: interrupt kernel if it takes too long
            while True:
                message = kc.get_iopub_msg(timeout=180)

                # Only pay for serialising the message when it will be emitted.
                if logger.isEnabledFor(logging.INFO):
                    logger.info(json.dumps(message, indent=2, default=str))

                # Explicit raises (not `assert`) so the checks survive
                # `python -O`; AssertionError is kept as the exception type.
                if message["parent_header"]["msg_id"] != result_msg_id:
                    raise AssertionError(
                        "Received a kernel message that does not belong to the current execution request.",
                    )
                msg_type = message["msg_type"]
                if msg_type == "status":
                    if message["content"]["execution_state"] == "idle":
                        break
                elif msg_type == "stream":
                    stream_name = message["content"]["name"]
                    stream_text = message["content"]["text"]

                    if stream_name == "stdout":
                        exec_result.stdout.append(stream_text)
                    elif stream_name == "stderr":
                        exec_result.stderr.append(stream_text)
                    else:
                        raise AssertionError(f"Unsupported stream name: {stream_name}")
                elif msg_type == "execute_result":
                    execute_result = message["content"]["data"]
                    exec_result.result = execute_result
                elif msg_type == "error":
                    error_name = message["content"]["ename"]
                    error_value = message["content"]["evalue"]
                    error_traceback_lines = message["content"]["traceback"]
                    if error_traceback_lines is None:
                        error_traceback_lines = [f"{error_name}: {error_value}"]
                    error_traceback = "\n".join(error_traceback_lines)
                    exec_result.error = error_traceback
                elif msg_type == "execute_input":
                    pass
                elif msg_type in ("display_data", "update_display_data"):
                    content = message["content"]
                    exec_result.displays.append(
                        DisplayData(
                            data=content["data"],
                            metadata=content["metadata"],
                            # Tolerate kernels that do not send `transient`.
                            transient=content.get("transient", {}),
                        ),
                    )
                else:
                    raise AssertionError(
                        f"Unsupported message from kernel: {msg_type}, the jupyter_client might be outdated.",
                    )
        finally:
            kc.stop_channels()
        return exec_result

    def _update_session_var(self, session: EnvSession) -> None:
        """Send the session variables to the kernel as JSON."""
        self._execute_control_code_on_kernel(
            session.kernel_id,
            f"%%_taskweaver_update_session_var\n{json.dumps(session.session_var)}",
        )

    def _cmd_session_init(self, session: EnvSession) -> None:
        """Run the session-init magic with the session id."""
        self._execute_control_code_on_kernel(
            session.kernel_id,
            f"%_taskweaver_session_init {session.session_id}",
        )

    def _cmd_plugin_load(self, session: EnvSession, plugin: EnvPlugin) -> None:
        """Register the plugin implementation, then load it with its config."""
        self._execute_control_code_on_kernel(
            session.kernel_id,
            f"%%_taskweaver_plugin_register {plugin.name}\n{plugin.impl}",
        )
        self._execute_control_code_on_kernel(
            session.kernel_id,
            f"%%_taskweaver_plugin_load {plugin.name}\n{json.dumps(plugin.config or {})}",
        )

    def _cmd_plugin_test(self, session: EnvSession, plugin: EnvPlugin) -> None:
        """Run the plugin-test magic for ``plugin``."""
        self._execute_control_code_on_kernel(
            session.kernel_id,
            f"%_taskweaver_plugin_test {plugin.name}",
        )

    def _cmd_plugin_unload(self, session: EnvSession, plugin: EnvPlugin) -> None:
        """Run the plugin-unload magic for ``plugin``."""
        self._execute_control_code_on_kernel(
            session.kernel_id,
            f"%_taskweaver_plugin_unload {plugin.name}",
        )

    def _parse_exec_result(
        self,
        exec_result: EnvExecution,
        extra_result: Optional[Dict[str, Any]] = None,
    ) -> ExecutionResult:
        """Convert raw kernel output into an ``ExecutionResult``.

        Args:
            exec_result: Output collected by ``_execute_code_on_kernel``.
            extra_result: Optional mapping; its ``"log"`` entry becomes the
                result log and each ``"artifact"`` entry is appended as an
                ``ExecutionArtifact``. Other keys are ignored.
        """
        result = ExecutionResult(
            execution_id=exec_result.exec_id,
            code=exec_result.code,
            is_success=exec_result.error == "",
            error=exec_result.error,
            output="",
            stdout=exec_result.stdout,
            stderr=exec_result.stderr,
            log=[],
            artifact=[],
        )

        # Each text/* result overwrites `output`; the value is parsed as a
        # Python literal when possible and kept as raw text otherwise.
        for mime_type, text_result in exec_result.result.items():
            if mime_type.startswith("text/"):
                try:
                    result.output = literal_eval(text_result)
                except Exception:
                    result.output = text_result

        for display_index, display in enumerate(exec_result.displays, start=1):
            artifact = ExecutionArtifact()
            artifact.name = f"{exec_result.exec_id}-display-{display_index}"
            has_svg = False
            has_pic = False
            for mime_type, content in display.data.items():
                if mime_type.startswith("image/"):
                    if mime_type == "image/svg+xml":
                        if has_pic and has_svg:
                            continue
                        # An SVG replaces a previously seen non-SVG image.
                        has_svg = True
                        has_pic = True
                        artifact.type = "svg"
                        artifact.file_content_encoding = "str"
                    else:
                        # Only the first image is kept when no SVG is involved.
                        if has_pic:
                            continue
                        has_pic = True
                        artifact.type = "image"
                        artifact.file_content_encoding = "base64"
                    artifact.mime_type = mime_type
                    artifact.file_content = content
                elif mime_type.startswith("text/"):
                    artifact.preview = content

            # Displays without any image payload produce no artifact.
            if has_pic:
                result.artifact.append(artifact)

        if isinstance(extra_result, dict):
            for key, value in extra_result.items():
                if key == "log":
                    result.log = value
                elif key == "artifact":
                    for artifact_dict in value:
                        artifact_item = ExecutionArtifact(
                            name=artifact_dict["name"],
                            type=artifact_dict["type"],
                            original_name=artifact_dict["original_name"],
                            file_name=artifact_dict["file"],
                            preview=artifact_dict["preview"],
                        )
                        result.artifact.append(artifact_item)

        return result