import base64
import os
from pathlib import Path
from typing import List, Literal

from injector import inject

from taskweaver.ces.common import ExecutionResult, Manager
from taskweaver.config.config_mgt import AppConfigSource
from taskweaver.memory.plugin import PluginRegistry
from taskweaver.plugin.context import ArtifactType

# Maximum number of characters kept from error / stdout / stderr text when
# formatting an execution result.
TRUNCATE_CHAR_LENGTH = 1000


def get_artifact_uri(execution_id: str, file: str, use_local_uri: bool) -> str:
    """Build the URI under which an artifact file is reported.

    :param execution_id: ID of the execution that produced the artifact.
    :param file: Artifact file name or path.
    :param use_local_uri: If True, return a ``file://`` URI for
        ``workspace/<execution_id>/<file>``; otherwise return the
        placeholder URI ``http://artifact-ref/<file>``.
    :return: The artifact URI string.
    """
    # NOTE(review): os.path.join drops the earlier components when ``file``
    # is absolute, and Path.as_uri() raises ValueError for a relative path,
    # so the local-URI branch only succeeds when ``file`` is absolute.
    return (
        Path(os.path.join("workspace", execution_id, file)).as_uri()
        if use_local_uri
        else f"http://artifact-ref/{file}"
    )


def get_default_artifact_name(artifact_type: ArtifactType, mine_type: str) -> str:
    """Choose a fallback file name for an artifact that has no original name.

    :param artifact_type: Kind of artifact (``"file"``, ``"image"``,
        ``"chart"``, ``"svg"``, ...).
    :param mine_type: MIME type of the artifact content (the misspelled
        parameter name is kept so keyword callers keep working).
    :return: A default file name; ``"file"`` when no more specific name
        applies (including images with an unrecognized MIME type).
    """
    if artifact_type == "file":
        return "artifact"
    if artifact_type == "image":
        if mine_type == "image/png":
            return "image.png"
        if mine_type == "image/jpeg":
            return "image.jpg"
        if mine_type == "image/gif":
            return "image.gif"
        if mine_type == "image/svg+xml":
            return "image.svg"
    if artifact_type == "chart":
        return "chart.json"
    if artifact_type == "svg":
        return "svg.svg"
    return "file"


class CodeExecutor:
    """Run code through a session execution client and format the results.

    The execution client is started, and the registered plugins are loaded
    into it, lazily on the first call to :meth:`execute_code`.
    """

    @inject
    def __init__(
        self,
        session_id: str,
        workspace: str,
        execution_cwd: str,
        config: AppConfigSource,
        exec_mgr: Manager,
        plugin_registry: PluginRegistry,
    ) -> None:
        """Create the executor and its session client (not started yet).

        :param session_id: Session identifier passed to the execution manager.
        :param workspace: Directory passed to the client as ``session_dir``.
        :param execution_cwd: Working directory for execution; generated
            artifact files are also saved here.
        :param config: App config; ``app_base_path`` locates plugin sources.
        :param exec_mgr: Manager that creates the session client.
        :param plugin_registry: Source of the plugins to load into the client.
        """
        self.session_id = session_id
        self.workspace = workspace
        self.execution_cwd = execution_cwd
        self.exec_mgr = exec_mgr
        self.exec_client = exec_mgr.get_session_client(
            session_id,
            session_dir=workspace,
            cwd=execution_cwd,
        )
        # Set once start() / load_plugin() have been run by execute_code().
        self.client_started: bool = False
        self.plugin_registry = plugin_registry
        self.plugin_loaded: bool = False
        self.config = config

    def execute_code(self, exec_id: str, code: str) -> ExecutionResult:
        """Execute ``code`` and persist any unsaved artifacts it produced.

        On first use this starts the client and loads the plugins. When the
        execution succeeds, every artifact whose ``file_name`` is empty is
        written to ``execution_cwd`` as ``<artifact.name>_<original name>``
        (falling back to :func:`get_default_artifact_name`), and its
        ``file_name`` is updated in place.

        :param exec_id: Identifier of this execution.
        :param code: Source code to run.
        :return: The client's execution result (artifacts possibly updated).
        """
        if not self.client_started:
            self.start()
            self.client_started = True

        if not self.plugin_loaded:
            self.load_plugin()
            self.plugin_loaded = True

        result = self.exec_client.execute_code(exec_id, code)

        if result.is_success:
            for artifact in result.artifact:
                if artifact.file_name == "":
                    original_name = (
                        artifact.original_name
                        if artifact.original_name != ""
                        else get_default_artifact_name(
                            artifact.type,
                            artifact.mime_type,
                        )
                    )
                    file_name = f"{artifact.name}_{original_name}"
                    self._save_file(
                        file_name,
                        artifact.file_content,
                        artifact.file_content_encoding,
                    )
                    artifact.file_name = file_name

        return result

    def _save_file(
        self,
        file_name: str,
        content: str,
        content_encoding: Literal["base64", "str"] = "str",
    ) -> None:
        """Write ``content`` to ``execution_cwd/file_name``.

        :param file_name: Target file name, relative to ``execution_cwd``.
        :param content: File content, either plain text or base64 text.
        :param content_encoding: ``"base64"`` to decode and write bytes;
            anything else writes ``content`` as UTF-8 text.
        """
        file_path = os.path.join(self.execution_cwd, file_name)
        if content_encoding == "base64":
            with open(file_path, "wb") as f:
                f.write(base64.b64decode(content))
        else:
            # Explicit UTF-8: the platform default encoding may not be able
            # to represent arbitrary text content (e.g. cp1252 on Windows).
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(content)

    def load_plugin(self):
        """Load every registered plugin's source into the execution client.

        Loading is best-effort: a plugin that fails to load is reported on
        stdout and skipped, so the remaining plugins still load.
        """
        for p in self.plugin_registry.get_list():
            try:
                src_file = f"{self.config.app_base_path}/plugins/{p.impl}.py"
                # Python source files are UTF-8 by default (PEP 3120).
                with open(src_file, "r", encoding="utf-8") as f:
                    plugin_code = f.read()
                self.exec_client.load_plugin(
                    p.name,
                    plugin_code,
                    p.config,
                )
            except Exception as e:
                print(f"Plugin {p.name} failed to load: {str(e)}")

    def start(self):
        """Start the underlying execution client."""
        self.exec_client.start()

    def stop(self):
        """Stop the underlying execution client."""
        self.exec_client.stop()

    def format_code_output(
        self,
        result: ExecutionResult,
        indent: int = 0,
        with_code: bool = True,
        use_local_uri: bool = False,
    ) -> str:
        """Render an execution result as a human-readable text report.

        :param result: The execution result to describe.
        :param indent: Number of spaces prefixed to every output line.
        :param with_code: Whether to include the executed code.
        :param use_local_uri: Whether artifact URIs are local ``file://``
            URIs (see :func:`get_artifact_uri`).
        :return: The report, one entry per line.
        """
        lines: List[str] = []

        # code execution result
        if with_code:
            lines.append(
                f"The following python code has been executed:\n"
                "```python\n"
                f"{result.code}\n"
                "```\n\n",
            )
        lines.append(
            f"The execution of the generated python code above has"
            f" {'succeeded' if result.is_success else 'failed'}\n",
        )

        # code output
        if result.output != "":
            output = result.output
            if isinstance(output, list) and len(output) > 0:
                lines.append(
                    "The values of variables of the above Python code after execution are:\n",
                )
                for o in output:
                    lines.append(f"{str(o)}")
                lines.append("")
            else:
                lines.append(
                    "The result of above Python code after execution is:\n" + str(output),
                )
        elif result.is_success:
            if len(result.stdout) > 0:
                lines.append(
                    "The stdout is:",
                )
                lines.append("\n".join(result.stdout)[:TRUNCATE_CHAR_LENGTH])
            else:
                lines.append(
                    "The execution is successful but no output is generated.",
                )

        # console output when execution failed
        if not result.is_success:
            lines.append(
                "During execution, the following messages were logged:",
            )
            if len(result.log) > 0:
                # NOTE(review): ``(l{1})`` always renders the literal "(l1)";
                # it may have been meant to be ``{ln[1]}`` -- confirm intent.
                lines.extend([f"- [(l{1})]{ln[0]}: {ln[2]}" for ln in result.log])
            if result.error is not None:
                lines.append(result.error[:TRUNCATE_CHAR_LENGTH])
            if len(result.stdout) > 0:
                lines.append("\n".join(result.stdout)[:TRUNCATE_CHAR_LENGTH])
            if len(result.stderr) > 0:
                lines.append("\n".join(result.stderr)[:TRUNCATE_CHAR_LENGTH])
            lines.append("")

        # artifacts
        if len(result.artifact) > 0:
            lines.append("The following artifacts were generated:")
            for a in result.artifact:
                # For local URIs, relative artifact file names are resolved
                # against the execution working directory.
                artifact_file = (
                    a.file_name
                    if os.path.isabs(a.file_name) or not use_local_uri
                    else os.path.join(self.execution_cwd, a.file_name)
                )
                uri = get_artifact_uri(
                    execution_id=result.execution_id,
                    file=artifact_file,
                    use_local_uri=use_local_uri,
                )
                lines.append(
                    f"- type: {a.type} ; uri: " + uri + f" ; description: {a.preview}",
                )
            lines.append("")

        return "\n".join([" " * indent + ln for ln in lines])