import os
import shutil
import traceback
from typing import Any, Callable, Dict, Optional

from injector import Injector, inject

from taskweaver.code_interpreter import CodeInterpreter, CodeInterpreterPluginOnly
from taskweaver.code_interpreter.code_executor import CodeExecutor
from taskweaver.config.module_config import ModuleConfig
from taskweaver.logging import TelemetryLogger
from taskweaver.memory import Memory, Post, Round
from taskweaver.planner.planner import Planner
from taskweaver.workspace.workspace import Workspace


class AppSessionConfig(ModuleConfig):
    """Session-level configuration, registered under the name ``session``."""

    def _configure(self) -> None:
        self._set_name("session")

        # When true, Session.send_message skips the Planner and sends the
        # user message directly to the code interpreter.
        self.code_interpreter_only = self._get_bool("code_interpreter_only", False)
        # Upper bound on role-to-role replies within a single round.
        self.max_internal_chat_round_num = self._get_int("max_internal_chat_round_num", 10)
        # Selects CodeInterpreterPluginOnly and is forwarded to the Planner.
        self.plugin_only_mode = self._get_bool("plugin_only_mode", False)


class Session:
    """A chat session that routes posts between the Planner and the code interpreter.

    Each session owns a child injector, a session directory inside the
    workspace, an execution directory (``<session dir>/cwd``) and a ``Memory``
    that records the rounds created by :meth:`send_message`.
    """

    @inject
    def __init__(
        self,
        session_id: str,
        workspace: Workspace,
        app_injector: Injector,
        logger: TelemetryLogger,
        config: AppSessionConfig,  # TODO: change to SessionConfig
    ) -> None:
        """Build the session's roles, create its directories and dump an initial log.

        Args:
            session_id: Identifier of the session; must not be ``None``.
            workspace: Workspace used to resolve this session's directory.
            app_injector: Application injector; a child injector is created
                from it for session-scoped objects.
            logger: Logger used for messages and for dumping log files.
            config: Session configuration.
        """
        assert session_id is not None, "session_id must be provided"
        self.logger = logger
        self.session_injector = app_injector.create_child_injector()
        self.config = config

        self.session_id: str = session_id

        self.workspace = workspace.get_session_dir(self.session_id)
        self.execution_cwd = os.path.join(self.workspace, "cwd")

        self.round_index = 0
        self.memory = Memory(session_id=self.session_id)

        self.session_var: Dict[str, str] = {}

        self.planner = self.session_injector.create_object(
            Planner,
            {
                "plugin_only": self.config.plugin_only_mode,
            },
        )
        self.code_executor = self.session_injector.create_object(
            CodeExecutor,
            {
                "session_id": self.session_id,
                "workspace": self.workspace,
                "execution_cwd": self.execution_cwd,
            },
        )
        # Bind the executor before resolving the interpreter so that the
        # interpreter obtained below is wired to this session's executor.
        self.session_injector.binder.bind(CodeExecutor, self.code_executor)
        if self.config.plugin_only_mode:
            self.code_interpreter = self.session_injector.get(CodeInterpreterPluginOnly)
        else:
            self.code_interpreter = self.session_injector.get(CodeInterpreter)

        self.max_internal_chat_round_num = self.config.max_internal_chat_round_num
        self.internal_chat_num = 0

        self.init()

        self.logger.dump_log_file(
            self,
            file_path=os.path.join(self.workspace, f"{self.session_id}.json"),
        )

    def init(self) -> None:
        """Create the session directory and the execution directory if missing."""
        # exist_ok avoids the check-then-create race of a separate
        # os.path.exists() test when two sessions start concurrently.
        os.makedirs(self.workspace, exist_ok=True)
        os.makedirs(self.execution_cwd, exist_ok=True)

        self.logger.info(f"Session {self.session_id} is initialized")

    def update_session_var(self, variables: Dict[str, str]) -> None:
        """Merge ``variables`` into the session variables, overwriting existing keys."""
        self.session_var.update(variables)

    def send_message(
        self,
        message: str,
        event_handler: Optional[Callable[..., Any]] = None,
    ) -> Round:
        """Process one user message and return the round that records it.

        Unless ``code_interpreter_only`` is set, the message is posted to the
        Planner and replies are routed between roles until one is addressed to
        ``"User"``. Otherwise a single post is sent to the code interpreter.

        Exceptions raised while processing are not propagated: the stack trace
        is logged, the round state is set to ``"failed"`` and
        ``event_handler("error", ...)`` is invoked. The round is always dumped
        to a log file in the session directory.

        Args:
            message: The user's query.
            event_handler: Optional callback invoked as
                ``event_handler(event_type, message)``; it is also forwarded
                to the roles. Defaults to a no-op.

        Returns:
            The round created for this message.
        """
        event_handler = event_handler or (lambda *args: None)
        chat_round = self.memory.create_round(user_query=message)

        def _send_message(recipient: str, post: Post) -> Post:
            """Record ``post`` in the round and return ``recipient``'s reply."""
            chat_round.add_post(post)

            # A post whose sender is also its recipient switches the role to
            # its back-up engine.
            use_back_up_engine = recipient == post.send_from
            self.logger.info(f"Use back up engine: {use_back_up_engine}")

            if recipient == "Planner":
                reply_post = self.planner.reply(
                    self.memory,
                    prompt_log_path=os.path.join(
                        self.workspace,
                        f"planner_prompt_log_{chat_round.id}_{post.id}.json",
                    ),
                    event_handler=event_handler,
                    use_back_up_engine=use_back_up_engine,
                )
            elif recipient == "CodeInterpreter":
                reply_post = self.code_interpreter.reply(
                    self.memory,
                    event_handler=event_handler,
                    prompt_log_path=os.path.join(
                        self.workspace,
                        f"code_generator_prompt_log_{chat_round.id}_{post.id}.json",
                    ),
                    use_back_up_engine=use_back_up_engine,
                )
            else:
                raise Exception(f"Unknown recipient {recipient}")

            return reply_post

        try:
            if not self.config.code_interpreter_only:
                post = Post.create(message=message, send_from="User", send_to="Planner")
                while True:
                    post = _send_message(post.send_to, post)
                    self.logger.info(
                        f"{post.send_from} talk to {post.send_to}: {post.message}",
                    )
                    self.internal_chat_num += 1
                    if post.send_to == "User":
                        # The final reply is never passed to _send_message,
                        # so it has to be recorded here.
                        chat_round.add_post(post)
                        self.internal_chat_num = 0
                        break
                    if self.internal_chat_num >= self.max_internal_chat_round_num:
                        raise Exception(
                            f"Internal chat round number exceeds the limit of {self.max_internal_chat_round_num}",
                        )
            else:
                post = Post.create(
                    message=message,
                    send_from="Planner",
                    send_to="CodeInterpreter",
                )
                # NOTE(review): unlike the Planner path, the reply post is not
                # added to chat_round here -- confirm this is intended.
                post = _send_message("CodeInterpreter", post)
                event_handler("final_reply_message", post.message)

            self.round_index += 1
            chat_round.change_round_state("finished")

        except Exception as e:
            # Top-level boundary for a round: report the failure instead of
            # propagating it to the caller.
            stack_trace_str = traceback.format_exc()
            self.logger.error(stack_trace_str)
            chat_round.change_round_state("failed")
            err_message = f"Cannot process your request due to Exception: {str(e)} \n {stack_trace_str}"
            event_handler("error", err_message)

        finally:
            self.internal_chat_num = 0
            self.logger.dump_log_file(
                chat_round,
                file_path=os.path.join(
                    self.workspace,
                    f"{self.session_id}_{chat_round.id}.json",
                ),
            )
            return chat_round

    def send_file(
        self,
        file_name: str,
        file_path: str,
        event_handler: Optional[Callable[..., Any]] = None,
    ) -> Round:
        """Copy a file into the execution directory and announce it via a message.

        Args:
            file_name: Name the file gets inside the execution directory.
            file_path: Path of the file to copy.
            event_handler: Optional callback forwarded to :meth:`send_message`.

        Returns:
            The round produced by sending the ``load file`` / ``reload file``
            message.

        Raises:
            OSError: If ``file_path`` cannot be read or the copy fails.
        """
        # NOTE(review): file_name is joined as-is, so a value containing
        # ".." can resolve outside the execution directory -- confirm that
        # callers sanitise it.
        file_full_path = self.get_full_path(file_name, in_execution_cwd=True)

        if os.path.exists(file_full_path):
            message = f'reload file "{file_name}"'
            # If the source already is the target, removing the target first
            # would delete the only copy; in that case there is nothing to do.
            if not os.path.samefile(file_path, file_full_path):
                os.remove(file_full_path)
                shutil.copyfile(file_path, file_full_path)
        else:
            message = f'load file "{file_name}"'
            shutil.copyfile(file_path, file_full_path)

        return self.send_message(message, event_handler=event_handler)

    def get_full_path(self, *file_path: str, in_execution_cwd: bool = False) -> str:
        """Join ``file_path`` onto a session directory and resolve it.

        Args:
            *file_path: Path components appended to the base directory.
            in_execution_cwd: Use the execution directory as the base instead
                of the session directory.

        Returns:
            The joined path after ``os.path.realpath`` resolution.
        """
        base_dir = self.execution_cwd if in_execution_cwd else self.workspace
        return str(os.path.realpath(os.path.join(base_dir, *file_path)))

    def to_dict(self) -> Dict[str, str]:
        """Return the session's id and directories as a plain dictionary."""
        return {
            "session_id": self.session_id,
            "workspace": self.workspace,
            "execution_cwd": self.execution_cwd,
        }