|
import os |
|
import shutil |
|
from typing import Dict |
|
|
|
from injector import Injector, inject |
|
|
|
from taskweaver.code_interpreter import CodeInterpreter, CodeInterpreterPluginOnly |
|
from taskweaver.code_interpreter.code_executor import CodeExecutor |
|
from taskweaver.config.module_config import ModuleConfig |
|
from taskweaver.logging import TelemetryLogger |
|
from taskweaver.memory import Memory, Post, Round |
|
from taskweaver.planner.planner import Planner |
|
from taskweaver.workspace.workspace import Workspace |
|
|
|
|
|
class AppSessionConfig(ModuleConfig):
    """Session-level options, registered under the ``session`` config name."""

    def _configure(self) -> None:
        """Name the config section and load each session option with its default."""
        self._set_name("session")

        # When enabled, user messages are sent straight to the code
        # interpreter instead of going through the planner (default: off).
        self.code_interpreter_only = self._get_bool("code_interpreter_only", False)

        # Upper bound on the internal post exchanges allowed while handling
        # one user message (default: 10).
        self.max_internal_chat_round_num = self._get_int("max_internal_chat_round_num", 10)

        # When enabled, the plugin-only code interpreter variant is used
        # (default: off).
        self.plugin_only_mode = self._get_bool("plugin_only_mode", False)
|
|
|
|
|
class Session:
    """A conversation session tying together planner, code interpreter and memory.

    Each session owns a directory inside the workspace, a ``cwd`` sub-directory
    used as the execution working directory, a child injector for its own
    objects, and the ``Memory`` holding its chat rounds.
    """

    @inject
    def __init__(
        self,
        session_id: str,
        workspace: Workspace,
        app_injector: Injector,
        logger: TelemetryLogger,
        config: AppSessionConfig,
    ) -> None:
        """Build the session objects, create its directories and dump a session log.

        Args:
            session_id: Identifier of the session; must not be ``None``.
            workspace: Workspace used to resolve this session's directory.
            app_injector: Application injector; a child injector is derived
                from it for the objects created here.
            logger: Logger used for messages and for dumping log files.
            config: Session options (planner bypass, plugin-only mode,
                internal chat limit).
        """
        assert session_id is not None, "session_id must be provided"
        self.logger = logger
        self.session_injector = app_injector.create_child_injector()

        self.config = config

        self.session_id: str = session_id

        # NOTE(review): get_session_dir is assumed to return a path string,
        # since the result is passed to os.path.join below -- confirm.
        self.workspace = workspace.get_session_dir(self.session_id)
        self.execution_cwd = os.path.join(self.workspace, "cwd")

        self.round_index = 0
        self.memory = Memory(session_id=self.session_id)

        self.session_var: Dict[str, str] = {}

        self.planner = self.session_injector.create_object(
            Planner,
            {
                "plugin_only": self.config.plugin_only_mode,
            },
        )
        self.code_executor = self.session_injector.create_object(
            CodeExecutor,
            {
                "session_id": self.session_id,
                "workspace": self.workspace,
                "execution_cwd": self.execution_cwd,
            },
        )
        # Bind the executor before resolving the interpreter so the
        # interpreter obtained from this injector shares this instance.
        self.session_injector.binder.bind(CodeExecutor, self.code_executor)
        if self.config.plugin_only_mode:
            self.code_interpreter = self.session_injector.get(CodeInterpreterPluginOnly)
        else:
            self.code_interpreter = self.session_injector.get(CodeInterpreter)

        self.max_internal_chat_round_num = self.config.max_internal_chat_round_num
        self.internal_chat_num = 0

        self.init()

        self.logger.dump_log_file(
            self,
            file_path=os.path.join(self.workspace, f"{self.session_id}.json"),
        )

    def init(self):
        """Create the session directory and the execution working directory."""
        # exist_ok avoids the check-then-create race of a separate
        # os.path.exists() test.
        os.makedirs(self.workspace, exist_ok=True)
        os.makedirs(self.execution_cwd, exist_ok=True)

        self.logger.info(f"Session {self.session_id} is initialized")

    def update_session_var(self, variables: Dict[str, str]):
        """Merge ``variables`` into the session variables, overwriting existing keys."""
        self.session_var.update(variables)

    def send_message(self, message: str, event_handler: callable = None) -> Round:
        """Process one user message and return the chat round created for it.

        Unless ``code_interpreter_only`` is configured, the message is posted
        to the Planner and each reply is relayed to the role it addresses
        until a reply is addressed to "User". If the number of relayed
        replies reaches ``max_internal_chat_round_num`` first, the round
        fails. In ``code_interpreter_only`` mode the message goes to the
        code interpreter once and its reply is reported through
        ``event_handler`` as ``("final_reply_message", <text>)``.

        Exceptions raised while processing are not propagated: the stack
        trace is logged, the round is marked "failed" and ``event_handler``
        is called with ``("error", <description>)``. The round is dumped to
        a log file in the session directory in every case.

        Args:
            message: The user's message.
            event_handler: Optional callback for progress and error events;
                a no-op is used when omitted.

        Returns:
            The round created for this message, in "finished" or "failed" state.
        """
        event_handler = event_handler or (lambda *args: None)
        chat_round = self.memory.create_round(user_query=message)

        def _send_message(recipient: str, post: Post):
            """Record ``post`` in the round and return the recipient's reply."""
            chat_round.add_post(post)

            # A post addressed to its own sender switches to the back-up engine.
            use_back_up_engine = recipient == post.send_from
            self.logger.info(f"Use back up engine: {use_back_up_engine}")

            if recipient == "Planner":
                responder, log_prefix = self.planner, "planner_prompt_log"
            elif recipient == "CodeInterpreter":
                responder, log_prefix = self.code_interpreter, "code_generator_prompt_log"
            else:
                raise Exception(f"Unknown recipient {recipient}")

            return responder.reply(
                self.memory,
                prompt_log_path=os.path.join(
                    self.workspace,
                    f"{log_prefix}_{chat_round.id}_{post.id}.json",
                ),
                event_handler=event_handler,
                use_back_up_engine=use_back_up_engine,
            )

        try:
            if not self.config.code_interpreter_only:
                post = Post.create(message=message, send_from="User", send_to="Planner")
                while True:
                    post = _send_message(post.send_to, post)
                    self.logger.info(
                        f"{post.send_from} talk to {post.send_to}: {post.message}",
                    )
                    self.internal_chat_num += 1
                    if post.send_to == "User":
                        # The final reply is not relayed further, so it is
                        # recorded in the round here.
                        chat_round.add_post(post)
                        self.internal_chat_num = 0
                        break
                    if self.internal_chat_num >= self.max_internal_chat_round_num:
                        raise Exception(
                            f"Internal chat round number exceeds the limit of {self.max_internal_chat_round_num}",
                        )
            else:
                post = Post.create(
                    message=message,
                    send_from="Planner",
                    send_to="CodeInterpreter",
                )
                post = _send_message("CodeInterpreter", post)
                event_handler("final_reply_message", post.message)

            self.round_index += 1
            chat_round.change_round_state("finished")

        except Exception as e:
            import traceback

            stack_trace_str = traceback.format_exc()
            self.logger.error(stack_trace_str)
            chat_round.change_round_state("failed")
            err_message = f"Cannot process your request due to Exception: {str(e)} \n {stack_trace_str}"
            event_handler("error", err_message)

        finally:
            self.internal_chat_num = 0
            self.logger.dump_log_file(
                chat_round,
                file_path=os.path.join(
                    self.workspace,
                    f"{self.session_id}_{chat_round.id}.json",
                ),
            )

        # Returned after (not inside) the ``finally`` clause: a ``return`` in
        # ``finally`` would silently discard KeyboardInterrupt/SystemExit and
        # any error raised by the error handler itself.
        return chat_round

    def send_file(
        self,
        file_name: str,
        file_path: str,
        event_handler: callable = None,
    ) -> Round:
        """Copy a file into the execution directory and notify the conversation.

        Args:
            file_name: Name the file gets inside the execution working directory.
            file_path: Location of the file to copy.
            event_handler: Optional callback forwarded to ``send_message``.

        Returns:
            The round produced by sending the ``load file`` / ``reload file``
            message.
        """
        # Resolve against the execution directory directly. Passing
        # ``self.execution_cwd`` as a path component would join it onto the
        # workspace again, duplicating the prefix when the workspace path is
        # relative.
        file_full_path = self.get_full_path(file_name, in_execution_cwd=True)

        target_exists = os.path.exists(file_full_path)
        if target_exists:
            message = f'reload file "{file_name}"'
        else:
            message = f'load file "{file_name}"'

        # If the source already is the target, removing it first would
        # destroy the only copy and make the copy fail.
        source_is_target = target_exists and os.path.realpath(file_path) == file_full_path
        if not source_is_target:
            if target_exists:
                os.remove(file_full_path)
            shutil.copyfile(file_path, file_full_path)

        return self.send_message(message, event_handler=event_handler)

    def get_full_path(self, *file_path: str, in_execution_cwd: bool = False) -> str:
        """Return the resolved path of ``file_path`` inside the session.

        Args:
            *file_path: Path components joined below the base directory.
            in_execution_cwd: Use the execution working directory as the base
                instead of the session directory.

        Returns:
            The joined path passed through ``os.path.realpath``.
        """
        base_dir = self.execution_cwd if in_execution_cwd else self.workspace
        return str(os.path.realpath(os.path.join(base_dir, *file_path)))

    def to_dict(self) -> Dict[str, str]:
        """Return the session's id and directory paths as a plain dictionary."""
        return {
            "session_id": self.session_id,
            "workspace": self.workspace,
            "execution_cwd": self.execution_cwd,
        }
|
|