|
import atexit |
|
import json |
|
import logging |
|
import os |
|
import sys |
|
from ast import literal_eval |
|
from dataclasses import dataclass, field |
|
from typing import Any, Dict, List, Literal, Optional, Union |
|
|
|
from jupyter_client.kernelspec import KernelSpec, KernelSpecManager |
|
from jupyter_client.manager import KernelManager |
|
from jupyter_client.multikernelmanager import MultiKernelManager |
|
|
|
from taskweaver.ces.common import EnvPlugin, ExecutionArtifact, ExecutionResult, get_id |
|
|
|
logger = logging.getLogger(__name__)

# Echo this module's log records to stdout with a timestamped format.
# The handler itself lets DEBUG and above through; which records actually
# reach it still depends on the logger's effective level.
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(formatter)
handler.setLevel(logging.DEBUG)
logger.addHandler(handler)
|
|
|
# Kind of code sent to a kernel: user-submitted code, or the internal
# "%_taskweaver_*" control magics.
ExecType = Literal["user", "control"]

# MIME-type key for kernel output. The common text types are spelled out for
# readers and type checkers, but any MIME string is accepted.
ResultMimeType = Union[Literal["text/plain", "text/html", "text/markdown", "text/latex"], str]
|
|
|
|
|
@dataclass
class DisplayData:
    """One rich-display payload published by the kernel.

    Filled from the ``data``, ``metadata`` and ``transient`` fields of a
    ``display_data`` or ``update_display_data`` iopub message.
    """

    # MIME type -> payload rendered for that type.
    data: Dict[ResultMimeType, Any] = field(default_factory=dict)
    # Metadata sent alongside the payload.
    metadata: Dict[str, Any] = field(default_factory=dict)
    # The message's "transient" section.
    transient: Dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
@dataclass
class EnvExecution:
    """Everything collected while running one piece of code on a kernel."""

    # What was run, and under which id / kind.
    exec_id: str
    code: str
    exec_type: ExecType = "user"

    # Output published while the code was running.
    stdout: List[str] = field(default_factory=list)
    stderr: List[str] = field(default_factory=list)
    displays: List[DisplayData] = field(default_factory=list)

    # Data of the execute_result message (MIME type -> payload), plus the
    # joined traceback if the code raised; "" means no error was reported.
    result: Dict[ResultMimeType, str] = field(default_factory=dict)
    error: str = ""
|
|
|
|
|
@dataclass
class EnvSession:
    """Bookkeeping for one execution session and the kernel that backs it."""

    session_id: str
    # Kernel lifecycle state; a fresh session starts as "pending" (no kernel yet).
    kernel_status: Literal["pending", "ready", "running", "stopped", "error"] = "pending"
    # Id of the session's kernel in the multi-kernel manager; "" until started.
    kernel_id: str = ""
    # Running counter of executions issued in this session.
    execution_count: int = 0
    # Raw executions keyed by their exec_id.
    execution_dict: Dict[str, EnvExecution] = field(default_factory=dict)
    # Directory holding this session's files.
    session_dir: str = ""
    # Session variables mirrored to the kernel.
    session_var: Dict[str, str] = field(default_factory=dict)
    # Plugins registered in this session, keyed by plugin name.
    plugins: Dict[str, EnvPlugin] = field(default_factory=dict)
|
|
|
|
|
class KernelSpecProvider(KernelSpecManager):
    """Kernel-spec manager that also knows the built-in ``taskweaver`` kernel.

    Every other kernel name is resolved by the stock ``KernelSpecManager``.
    """

    def get_kernel_spec(self, kernel_name: str) -> KernelSpec:
        """Return the spec for ``kernel_name``, synthesizing one for ``"taskweaver"``."""
        if kernel_name != "taskweaver":
            return super().get_kernel_spec(kernel_name)

        # Run the TaskWeaver kernel launcher module. "{connection_file}" is a
        # template placeholder, presumably filled in by jupyter_client at launch.
        launch_argv = ["python", "-m", "taskweaver.ces.kernel.launcher", "-f", "{connection_file}"]
        return KernelSpec(
            argv=launch_argv,
            display_name="TaskWeaver",
            language="python",
            metadata={"debugger": True},
        )
|
|
|
|
|
class TaskWeaverMultiKernelManager(MultiKernelManager):
    """MultiKernelManager that lets the caller pin a kernel's connection file.

    When the ``env`` keyword argument used to start a kernel contains a
    ``CONNECTION_FILE`` entry, that path becomes the new kernel manager's
    ``connection_file``.
    """

    def pre_start_kernel(
        self,
        kernel_name: Optional[str],
        kwargs: Any,
    ) -> tuple[KernelManager, str, str]:
        """Prepare a kernel manager, honouring ``env["CONNECTION_FILE"]`` if set.

        Args:
            kernel_name: Kernel spec name, or None for the default kernel.
            kwargs: Keyword arguments of the kernel start request; only the
                optional ``"env"`` mapping is inspected here.

        Returns:
            The ``(kernel_manager, kernel_name, kernel_id)`` triple produced by
            the base class.
        """
        # Optional[...] rather than ``str | None``: this annotation is evaluated
        # when the class is defined, and PEP 604 unions fail there before 3.10.
        env: Optional[Dict[str, str]] = kwargs.get("env")
        km, kernel_name, kernel_id = super().pre_start_kernel(kernel_name, kwargs)
        if env is not None and "CONNECTION_FILE" in env:
            km.connection_file = env["CONNECTION_FILE"]
        return km, kernel_name, kernel_id
|
|
|
|
|
class Environment:
    """Own a set of code-execution sessions, each backed by its own kernel.

    Sessions are created lazily by id. By default a session's files live under
    ``<env_dir>/sessions/<session_id>``; kernel-side files (connection file,
    kernel log) go in that directory's ``ces`` sub-directory.
    """

    def __init__(
        self,
        env_id: Optional[str] = None,
        env_dir: Optional[str] = None,
    ) -> None:
        """Create the environment and its kernel manager.

        Args:
            env_id: Environment id; a new ``env``-prefixed id is generated when None.
            env_dir: Root directory for session data; defaults to ``os.getcwd()``.
        """
        self.session_dict: Dict[str, EnvSession] = {}
        self.id = get_id(prefix="env") if env_id is None else env_id
        self.env_dir = env_dir if env_dir is not None else os.getcwd()

        self.multi_kernel_manager = self.init_kernel_manager()

    def clean_up(self) -> None:
        """Stop every session, logging failures so one bad session cannot block the rest."""
        # Iterate a snapshot so the loop never walks a dict that is being modified.
        for session in list(self.session_dict.values()):
            try:
                self.stop_session(session.session_id)
            except Exception as e:
                logger.error(e)

    def init_kernel_manager(self) -> MultiKernelManager:
        """Create the kernel manager and register ``clean_up`` to run at interpreter exit."""
        atexit.register(self.clean_up)
        return TaskWeaverMultiKernelManager(
            default_kernel_name="taskweaver",
            kernel_spec_manager=KernelSpecProvider(),
        )

    def start_session(
        self,
        session_id: str,
        session_dir: Optional[str] = None,
        cwd: Optional[str] = None,
    ) -> None:
        """Start a kernel for a session and run the session-init command on it.

        Args:
            session_id: Session to start (created on first use).
            session_dir: Session directory; only used when the session is new.
            cwd: Kernel working directory; defaults to ``<session_dir>/cwd``.

        Raises:
            Exception: If the ``%_taskweaver_session_init`` control command fails.
        """
        session = self._get_session(session_id, session_dir=session_dir)
        ces_session_dir = os.path.join(session.session_dir, "ces")
        kernel_id = get_id(prefix="knl")

        os.makedirs(ces_session_dir, exist_ok=True)
        connection_file = os.path.join(
            ces_session_dir,
            f"conn-{session.session_id}-{kernel_id}.json",
        )

        cwd = cwd if cwd is not None else os.path.join(session.session_dir, "cwd")
        os.makedirs(cwd, exist_ok=True)

        # Derive a python home by dropping the last two path components of the
        # running interpreter (e.g. ".../bin/python" -> "...").
        # NOTE(review): "Lib/site-packages" is the Windows layout; on POSIX the
        # interpreter's site-packages presumably reaches the kernel via sys.path.
        python_home = os.path.sep.join(sys.executable.split(os.path.sep)[:-2])
        python_path = os.pathsep.join(
            [
                os.path.realpath(os.path.join(os.path.dirname(__file__), "..", "..")),
                os.path.join(python_home, "Lib", "site-packages"),
            ]
            + sys.path,
        )

        # The kernel inherits this process's environment (PATH included, when
        # set) plus the TaskWeaver-specific variables below. CONNECTION_FILE is
        # what TaskWeaverMultiKernelManager.pre_start_kernel looks for.
        kernel_env = os.environ.copy()
        kernel_env.update(
            {
                "TASKWEAVER_ENV_ID": self.id,
                "TASKWEAVER_SESSION_ID": session.session_id,
                "TASKWEAVER_SESSION_DIR": session.session_dir,
                "TASKWEAVER_LOGGING_FILE_PATH": os.path.join(
                    ces_session_dir,
                    "kernel_logging.log",
                ),
                "CONNECTION_FILE": connection_file,
                "PYTHONPATH": python_path,
                "PYTHONHOME": python_home,
            },
        )
        session.kernel_id = self.multi_kernel_manager.start_kernel(
            kernel_id=kernel_id,
            cwd=cwd,
            env=kernel_env,
        )
        self._cmd_session_init(session)
        session.kernel_status = "ready"

    def execute_code(
        self,
        session_id: str,
        code: str,
        exec_id: Optional[str] = None,
    ) -> ExecutionResult:
        """Run user code in a session and return the parsed result.

        The session's kernel is started first if it is still ``"pending"``. The
        code is bracketed by the ``%_taskweaver_exec_pre_check`` and
        ``%_taskweaver_exec_post_check`` control magics; the post-check's
        ``data`` supplies the log and artifact entries of the result.

        Args:
            session_id: Target session (created on first use).
            code: Source to execute.
            exec_id: Execution id; a new ``exec``-prefixed id is generated when None.

        Raises:
            Exception: If either control magic reports a failure.
        """
        exec_id = get_id(prefix="exec") if exec_id is None else exec_id
        session = self._get_session(session_id)
        if session.kernel_status == "pending":
            self.start_session(session_id)

        session.execution_count += 1
        execution_index = session.execution_count
        self._execute_control_code_on_kernel(
            session.kernel_id,
            f"%_taskweaver_exec_pre_check {execution_index} {exec_id}",
        )
        exec_result = self._execute_code_on_kernel(
            session.kernel_id,
            exec_id=exec_id,
            code=code,
        )
        exec_extra_result = self._execute_control_code_on_kernel(
            session.kernel_id,
            f"%_taskweaver_exec_post_check {execution_index} {exec_id}",
        )
        session.execution_dict[exec_id] = exec_result

        return self._parse_exec_result(exec_result, exec_extra_result["data"])

    def load_plugin(
        self,
        session_id: str,
        plugin_name: str,
        plugin_impl: str,
        plugin_config: Optional[Dict[str, str]] = None,
    ) -> None:
        """Register and load a plugin in a session, replacing any same-named plugin.

        A previously loaded plugin with the same name is unloaded on the kernel
        first. The new plugin is recorded only after it loads successfully.
        """
        session = self._get_session(session_id)
        if plugin_name in session.plugins:
            prev_plugin = session.plugins[plugin_name]
            if prev_plugin.loaded:
                self._cmd_plugin_unload(session, prev_plugin)
            del session.plugins[plugin_name]

        plugin = EnvPlugin(
            name=plugin_name,
            impl=plugin_impl,
            config=plugin_config,
            loaded=False,
        )
        self._cmd_plugin_load(session, plugin)
        plugin.loaded = True
        session.plugins[plugin_name] = plugin

    def test_plugin(
        self,
        session_id: str,
        plugin_name: str,
    ) -> None:
        """Run the kernel-side test command for a registered plugin.

        Raises:
            KeyError: If ``plugin_name`` is not registered in the session.
        """
        session = self._get_session(session_id)
        plugin = session.plugins[plugin_name]
        self._cmd_plugin_test(session, plugin)

    def unload_plugin(
        self,
        session_id: str,
        plugin_name: str,
    ) -> None:
        """Unload (if loaded) and forget a plugin; unknown names are ignored."""
        session = self._get_session(session_id)
        if plugin_name in session.plugins:
            plugin = session.plugins[plugin_name]
            if plugin.loaded:
                self._cmd_plugin_unload(session, plugin)
            del session.plugins[plugin_name]

    def update_session_var(
        self,
        session_id: str,
        session_var: Dict[str, str],
    ) -> None:
        """Merge ``session_var`` into the session's variables and push them all to the kernel."""
        session = self._get_session(session_id)
        session.session_var.update(session_var)
        self._update_session_var(session)

    def stop_session(self, session_id: str) -> None:
        """Shut down a session's kernel and mark the session ``"stopped"``.

        Already-stopped sessions are left alone and pending sessions (no kernel
        yet) are simply marked stopped. Errors during shutdown are logged, and
        the session is marked stopped regardless.
        """
        session = self._get_session(session_id)
        if session.kernel_status == "stopped":
            return
        if session.kernel_status == "pending":
            session.kernel_status = "stopped"
            return
        try:
            if session.kernel_id != "":
                kernel = self.multi_kernel_manager.get_kernel(session.kernel_id)
                is_alive = kernel.is_alive()
                if is_alive:
                    kernel.shutdown_kernel(now=True)
                kernel.cleanup_resources()
        except Exception as e:
            logger.error(e)
        session.kernel_status = "stopped"

    def download_file(self, session_id: str, file_path: str) -> str:
        """Resolve ``file_path`` on the kernel via ``%%_taskweaver_convert_path``.

        Returns:
            The ``text/plain`` result of the magic.
        """
        session = self._get_session(session_id)
        full_path = self._execute_code_on_kernel(
            session.kernel_id,
            get_id(prefix="exec"),
            f"%%_taskweaver_convert_path\n{file_path}",
            silent=True,
        )
        return full_path.result["text/plain"]

    def _get_session(
        self,
        session_id: str,
        session_dir: Optional[str] = None,
    ) -> EnvSession:
        """Return the session for ``session_id``, creating it and its directory on first use.

        ``session_dir`` is only consulted when the session is created.
        """
        if session_id not in self.session_dict:
            new_session = EnvSession(session_id)
            new_session.session_dir = (
                session_dir if session_dir is not None else self._get_default_session_dir(session_id)
            )
            os.makedirs(new_session.session_dir, exist_ok=True)
            self.session_dict[session_id] = new_session
        return self.session_dict[session_id]

    def _get_default_session_dir(self, session_id: str) -> str:
        """Return ``<env_dir>/sessions/<session_id>``, ensuring the ``sessions`` parent exists."""
        os.makedirs(os.path.join(self.env_dir, "sessions"), exist_ok=True)
        return os.path.join(self.env_dir, "sessions", session_id)

    def _execute_control_code_on_kernel(
        self,
        kernel_id: str,
        code: str,
        silent: bool = False,
        store_history: bool = False,
    ) -> Dict[Literal["is_success", "message", "data"], Union[bool, str, Any]]:
        """Run a control command and return its parsed status dict.

        The command's ``text/plain`` result is parsed with ``literal_eval``.

        Raises:
            Exception: If the kernel reported an error, returned no
                ``text/plain`` result, or the parsed ``is_success`` is falsy
                (the exception message is then the result's ``message``).
        """
        exec_result = self._execute_code_on_kernel(
            kernel_id,
            get_id(prefix="exec"),
            code=code,
            silent=silent,
            store_history=store_history,
            exec_type="control",
        )
        if exec_result.error != "":
            raise Exception(exec_result.error)
        if "text/plain" not in exec_result.result:
            raise Exception("No text returned.")
        result = literal_eval(exec_result.result["text/plain"])
        if not result["is_success"]:
            raise Exception(result["message"])
        return result

    def _execute_code_on_kernel(
        self,
        kernel_id: str,
        exec_id: str,
        code: str,
        silent: bool = False,
        store_history: bool = True,
        exec_type: ExecType = "user",
    ) -> EnvExecution:
        """Execute ``code`` on a kernel and collect what it publishes on iopub.

        A fresh client is opened for the call, and its channels are stopped on
        every exit path. Collection ends when this request's ``status`` message
        reports ``execution_state == "idle"``.

        Returns:
            An ``EnvExecution`` holding stdout/stderr chunks, display payloads,
            the ``execute_result`` data and, if the code raised, the joined
            traceback in ``error``.

        Raises:
            Exception: On a stream name or message type this method does not handle.
        """
        exec_result = EnvExecution(exec_id=exec_id, code=code, exec_type=exec_type)
        km = self.multi_kernel_manager.get_kernel(kernel_id)
        kc = km.client()
        kc.start_channels()
        try:
            # Inside the try so a readiness timeout still stops the channels.
            kc.wait_for_ready(10)
            result_msg_id = kc.execute(
                code=code,
                silent=silent,
                store_history=store_history,
                allow_stdin=False,
                stop_on_error=True,
            )

            while True:
                message = kc.get_iopub_msg(timeout=180)

                # Only pay for serializing the message when it will be logged.
                if logger.isEnabledFor(logging.INFO):
                    logger.info(json.dumps(message, indent=2, default=str))

                # iopub is a broadcast channel: skip anything that was not
                # produced by this request (other requests, or messages with an
                # empty parent header).
                parent_header = message.get("parent_header") or {}
                if parent_header.get("msg_id") != result_msg_id:
                    continue

                msg_type = message["msg_type"]
                content = message["content"]
                if msg_type == "status":
                    if content["execution_state"] == "idle":
                        break
                elif msg_type == "stream":
                    stream_name = content["name"]
                    if stream_name == "stdout":
                        exec_result.stdout.append(content["text"])
                    elif stream_name == "stderr":
                        exec_result.stderr.append(content["text"])
                    else:
                        raise Exception(f"Unsupported stream name: {stream_name}")
                elif msg_type == "execute_result":
                    exec_result.result = content["data"]
                elif msg_type == "error":
                    traceback_lines = content.get("traceback")
                    # Fall back to "name: value" when no traceback is provided,
                    # so a failure can never leave ``error`` empty (which would
                    # be reported as success).
                    if not traceback_lines:
                        traceback_lines = [f"{content['ename']}: {content['evalue']}"]
                    exec_result.error = "\n".join(traceback_lines)
                elif msg_type == "execute_input":
                    pass
                elif msg_type in ("display_data", "update_display_data"):
                    exec_result.displays.append(
                        DisplayData(
                            data=content["data"],
                            metadata=content.get("metadata", {}),
                            transient=content.get("transient", {}),
                        ),
                    )
                else:
                    raise Exception(
                        f"Unsupported message from kernel: {msg_type}, the jupyter_client might be outdated.",
                    )
        finally:
            kc.stop_channels()
        return exec_result

    def _update_session_var(self, session: EnvSession) -> None:
        """Push all of the session's variables to its kernel as JSON."""
        self._execute_control_code_on_kernel(
            session.kernel_id,
            f"%%_taskweaver_update_session_var\n{json.dumps(session.session_var)}",
        )

    def _cmd_session_init(self, session: EnvSession) -> None:
        """Run the ``%_taskweaver_session_init`` control command for the session."""
        self._execute_control_code_on_kernel(
            session.kernel_id,
            f"%_taskweaver_session_init {session.session_id}",
        )

    def _cmd_plugin_load(self, session: EnvSession, plugin: EnvPlugin) -> None:
        """Register the plugin's implementation on the kernel, then load it with its config."""
        self._execute_control_code_on_kernel(
            session.kernel_id,
            f"%%_taskweaver_plugin_register {plugin.name}\n{plugin.impl}",
        )
        self._execute_control_code_on_kernel(
            session.kernel_id,
            f"%%_taskweaver_plugin_load {plugin.name}\n{json.dumps(plugin.config or {})}",
        )

    def _cmd_plugin_test(self, session: EnvSession, plugin: EnvPlugin) -> None:
        """Run the ``%_taskweaver_plugin_test`` control command for the plugin."""
        self._execute_control_code_on_kernel(
            session.kernel_id,
            f"%_taskweaver_plugin_test {plugin.name}",
        )

    def _cmd_plugin_unload(self, session: EnvSession, plugin: EnvPlugin) -> None:
        """Run the ``%_taskweaver_plugin_unload`` control command for the plugin."""
        self._execute_control_code_on_kernel(
            session.kernel_id,
            f"%_taskweaver_plugin_unload {plugin.name}",
        )

    def _parse_exec_result(
        self,
        exec_result: EnvExecution,
        extra_result: Optional[Dict[str, Any]] = None,
    ) -> ExecutionResult:
        """Convert a raw ``EnvExecution`` into an ``ExecutionResult``.

        - ``output`` is taken from the ``text/*`` entries of the execute_result
          (the last one iterated wins), ``literal_eval``-ed when possible and
          otherwise kept as the raw text.
        - Each display carrying an image becomes an artifact named
          ``<exec_id>-display-<n>``; SVG is preferred over raster images, and a
          ``text/*`` entry of the display becomes the preview.
        - ``extra_result["log"]`` replaces the log list, and every entry of
          ``extra_result["artifact"]`` is appended as a file artifact.
        """
        result = ExecutionResult(
            execution_id=exec_result.exec_id,
            code=exec_result.code,
            is_success=exec_result.error == "",
            error=exec_result.error,
            output="",
            stdout=exec_result.stdout,
            stderr=exec_result.stderr,
            log=[],
            artifact=[],
        )

        for mime_type, text_result in exec_result.result.items():
            if not mime_type.startswith("text/"):
                continue
            try:
                result.output = literal_eval(text_result)
            except Exception:
                result.output = text_result

        for display_index, display in enumerate(exec_result.displays, start=1):
            artifact = ExecutionArtifact()
            artifact.name = f"{exec_result.exec_id}-display-{display_index}"
            has_svg = False
            has_pic = False
            for mime_type, payload in display.data.items():
                if mime_type.startswith("image/"):
                    if mime_type == "image/svg+xml":
                        # An SVG replaces an earlier raster image; keep only the first SVG.
                        if has_svg:
                            continue
                        has_svg = True
                        artifact.type = "svg"
                        artifact.file_content_encoding = "str"
                    else:
                        # Raster images are only taken when no picture was chosen yet.
                        if has_pic:
                            continue
                        artifact.type = "image"
                        artifact.file_content_encoding = "base64"
                    has_pic = True
                    artifact.mime_type = mime_type
                    artifact.file_content = payload
                elif mime_type.startswith("text/"):
                    artifact.preview = payload

            if has_pic:
                result.artifact.append(artifact)

        if isinstance(extra_result, dict):
            if "log" in extra_result:
                result.log = extra_result["log"]
            if "artifact" in extra_result:
                for artifact_dict in extra_result["artifact"]:
                    result.artifact.append(
                        ExecutionArtifact(
                            name=artifact_dict["name"],
                            type=artifact_dict["type"],
                            original_name=artifact_dict["original_name"],
                            file_name=artifact_dict["file"],
                            preview=artifact_dict["preview"],
                        ),
                    )

        return result
|
|