ChatDev / chatdev /chat_chain.py
ennet's picture
Duplicate from sp12138sp/ChatDev
594c559
import importlib
import json
import os
import shutil
from datetime import datetime
import logging
import time
from camel.agents import RolePlaying
from camel.configs import ChatGPTConfig
from camel.typing import TaskType, ModelType
from chatdev.chat_env import ChatEnv, ChatEnvConfig
from chatdev.statistics import get_info
from chatdev.utils import log_and_print_online, now
def check_bool(s):
return s.lower() == "true"
class ChatChain:
def __init__(self,
config_path: str = None,
config_phase_path: str = None,
config_role_path: str = None,
task_prompt: str = None,
project_name: str = None,
org_name: str = None,
model_type: ModelType = ModelType.GPT_3_5_TURBO) -> None:
"""
Args:
config_path: path to the ChatChainConfig.json
config_phase_path: path to the PhaseConfig.json
config_role_path: path to the RoleConfig.json
task_prompt: the user input prompt for software
project_name: the user input name for software
org_name: the organization name of the human user
"""
# load config file
self.config_path = config_path
self.config_phase_path = config_phase_path
self.config_role_path = config_role_path
self.project_name = project_name
self.org_name = org_name
self.model_type = model_type
with open(self.config_path, 'r', encoding="utf8") as file:
self.config = json.load(file)
with open(self.config_phase_path, 'r', encoding="utf8") as file:
self.config_phase = json.load(file)
with open(self.config_role_path, 'r', encoding="utf8") as file:
self.config_role = json.load(file)
# init chatchain config and recruitments
self.chain = self.config["chain"]
self.recruitments = self.config["recruitments"]
# init default max chat turn
self.chat_turn_limit_default = 10
# init ChatEnv
self.chat_env_config = ChatEnvConfig(clear_structure=check_bool(self.config["clear_structure"]),
brainstorming=check_bool(self.config["brainstorming"]),
gui_design=check_bool(self.config["gui_design"]),
git_management=check_bool(self.config["git_management"]))
self.chat_env = ChatEnv(self.chat_env_config)
# the user input prompt will be self-improved (if set "self_improve": "True" in ChatChainConfig.json)
# the self-improvement is done in self.preprocess
self.task_prompt_raw = task_prompt
self.task_prompt = ""
# init role prompts
self.role_prompts = dict()
for role in self.config_role:
self.role_prompts[role] = "\n".join(self.config_role[role])
# init log
self.start_time, self.log_filepath = self.get_logfilepath()
# init SimplePhase instances
# import all used phases in PhaseConfig.json from chatdev.phase
# note that in PhaseConfig.json there only exist SimplePhases
# ComposedPhases are defined in ChatChainConfig.json and will be imported in self.execute_step
self.compose_phase_module = importlib.import_module("chatdev.composed_phase")
self.phase_module = importlib.import_module("chatdev.phase")
self.phases = dict()
for phase in self.config_phase:
assistant_role_name = self.config_phase[phase]['assistant_role_name']
user_role_name = self.config_phase[phase]['user_role_name']
phase_prompt = "\n\n".join(self.config_phase[phase]['phase_prompt'])
phase_class = getattr(self.phase_module, phase)
phase_instance = phase_class(assistant_role_name=assistant_role_name,
user_role_name=user_role_name,
phase_prompt=phase_prompt,
role_prompts=self.role_prompts,
phase_name=phase,
model_type=self.model_type,
log_filepath=self.log_filepath)
self.phases[phase] = phase_instance
def make_recruitment(self):
"""
recruit all employees
Returns: None
"""
for employee in self.recruitments:
self.chat_env.recruit(agent_name=employee)
def execute_step(self, phase_item: dict):
"""
execute single phase in the chain
Args:
phase_item: single phase configuration in the ChatChainConfig.json
Returns:
"""
phase = phase_item['phase']
phase_type = phase_item['phaseType']
# For SimplePhase, just look it up from self.phases and conduct the "Phase.execute" method
if phase_type == "SimplePhase":
max_turn_step = phase_item['max_turn_step']
need_reflect = check_bool(phase_item['need_reflect'])
if phase in self.phases:
self.chat_env = self.phases[phase].execute(self.chat_env,
self.chat_turn_limit_default if max_turn_step <= 0 else max_turn_step,
need_reflect)
else:
raise RuntimeError(f"Phase '{phase}' is not yet implemented in chatdev.phase")
# For ComposedPhase, we create instance here then conduct the "ComposedPhase.execute" method
elif phase_type == "ComposedPhase":
cycle_num = phase_item['cycleNum']
composition = phase_item['Composition']
compose_phase_class = getattr(self.compose_phase_module, phase)
if not compose_phase_class:
raise RuntimeError(f"Phase '{phase}' is not yet implemented in chatdev.compose_phase")
compose_phase_instance = compose_phase_class(phase_name=phase,
cycle_num=cycle_num,
composition=composition,
config_phase=self.config_phase,
config_role=self.config_role,
model_type=self.model_type,
log_filepath=self.log_filepath)
self.chat_env = compose_phase_instance.execute(self.chat_env)
else:
raise RuntimeError(f"PhaseType '{phase_type}' is not yet implemented.")
def execute_chain(self):
"""
execute the whole chain based on ChatChainConfig.json
Returns: None
"""
for phase_item in self.chain:
self.execute_step(phase_item)
def get_logfilepath(self):
"""
get the log path (under the software path)
Returns:
start_time: time for starting making the software
log_filepath: path to the log
"""
start_time = now()
filepath = os.path.dirname(__file__)
# root = "/".join(filepath.split("/")[:-1])
root = os.path.dirname(filepath)
# directory = root + "/WareHouse/"
directory = os.path.join(root, "WareHouse")
log_filepath = os.path.join(directory, "{}.log".format("_".join([self.project_name, self.org_name,start_time])))
return start_time, log_filepath
def pre_processing(self):
"""
remove useless files and log some global config settings
Returns: None
"""
if self.chat_env.config.clear_structure:
filepath = os.path.dirname(__file__)
# root = "/".join(filepath.split("/")[:-1])
root = os.path.dirname(filepath)
# directory = root + "/WareHouse"
directory = os.path.join(root, "WareHouse")
for filename in os.listdir(directory):
file_path = os.path.join(directory, filename)
# logs with error trials are left in WareHouse/
if os.path.isfile(file_path) and not filename.endswith(".py") and not filename.endswith(".log"):
os.remove(file_path)
print("{} Removed.".format(file_path))
software_path = os.path.join(directory, "_".join([self.project_name, self.org_name, self.start_time]))
self.chat_env.set_directory(software_path)
# copy config files to software path
shutil.copy(self.config_path, software_path)
shutil.copy(self.config_phase_path, software_path)
shutil.copy(self.config_role_path, software_path)
# write task prompt to software path
with open(os.path.join(software_path, self.project_name + ".prompt"), "w") as f:
f.write(self.task_prompt_raw)
preprocess_msg = "**[Preprocessing]**\n\n"
chat_gpt_config = ChatGPTConfig()
preprocess_msg += "**ChatDev Starts** ({})\n\n".format(self.start_time)
preprocess_msg += "**Timestamp**: {}\n\n".format(self.start_time)
preprocess_msg += "**config_path**: {}\n\n".format(self.config_path)
preprocess_msg += "**config_phase_path**: {}\n\n".format(self.config_phase_path)
preprocess_msg += "**config_role_path**: {}\n\n".format(self.config_role_path)
preprocess_msg += "**task_prompt**: {}\n\n".format(self.task_prompt_raw)
preprocess_msg += "**project_name**: {}\n\n".format(self.project_name)
preprocess_msg += "**Log File**: {}\n\n".format(self.log_filepath)
preprocess_msg += "**ChatDevConfig**:\n {}\n\n".format(self.chat_env.config.__str__())
preprocess_msg += "**ChatGPTConfig**:\n {}\n\n".format(chat_gpt_config)
log_and_print_online(preprocess_msg)
# init task prompt
if check_bool(self.config['self_improve']):
self.chat_env.env_dict['task_prompt'] = self.self_task_improve(self.task_prompt_raw)
else:
self.chat_env.env_dict['task_prompt'] = self.task_prompt_raw
def post_processing(self):
"""
summarize the production and move log files to the software directory
Returns: None
"""
self.chat_env.write_meta()
filepath = os.path.dirname(__file__)
# root = "/".join(filepath.split("/")[:-1])
root = os.path.dirname(filepath)
post_info = "**[Post Info]**\n\n"
now_time = now()
time_format = "%Y%m%d%H%M%S"
datetime1 = datetime.strptime(self.start_time, time_format)
datetime2 = datetime.strptime(now_time, time_format)
duration = (datetime2 - datetime1).total_seconds()
post_info += "Software Info: {}".format(
get_info(self.chat_env.env_dict['directory'], self.log_filepath) + "\n\n🕑**duration**={:.2f}s\n\n".format(duration))
post_info += "ChatDev Starts ({})".format(self.start_time) + "\n\n"
post_info += "ChatDev Ends ({})".format(now_time) + "\n\n"
if self.chat_env.config.clear_structure:
directory = self.chat_env.env_dict['directory']
for filename in os.listdir(directory):
file_path = os.path.join(directory, filename)
if os.path.isdir(file_path) and file_path.endswith("__pycache__"):
shutil.rmtree(file_path, ignore_errors=True)
post_info += "{} Removed.".format(file_path) + "\n\n"
log_and_print_online(post_info)
logging.shutdown()
time.sleep(1)
shutil.move(self.log_filepath,
os.path.join(root + "/WareHouse", "_".join([self.project_name, self.org_name, self.start_time]),
os.path.basename(self.log_filepath)))
# @staticmethod
def self_task_improve(self, task_prompt):
"""
ask agent to improve the user query prompt
Args:
task_prompt: original user query prompt
Returns:
revised_task_prompt: revised prompt from the prompt engineer agent
"""
self_task_improve_prompt = """I will give you a short description of a software design requirement,
please rewrite it into a detailed prompt that can make large language model know how to make this software better based this prompt,
the prompt should ensure LLMs build a software that can be run correctly, which is the most import part you need to consider.
remember that the revised prompt should not contain more than 200 words,
here is the short description:\"{}\".
If the revised prompt is revised_version_of_the_description,
then you should return a message in a format like \"<INFO> revised_version_of_the_description\", do not return messages in other formats.""".format(
task_prompt)
role_play_session = RolePlaying(
assistant_role_name="Prompt Engineer",
assistant_role_prompt="You are an professional prompt engineer that can improve user input prompt to make LLM better understand these prompts.",
user_role_prompt="You are an user that want to use LLM to build software.",
user_role_name="User",
task_type=TaskType.CHATDEV,
task_prompt="Do prompt engineering on user query",
with_task_specify=False,
model_type=self.model_type,
)
# log_and_print_online("System", role_play_session.assistant_sys_msg)
# log_and_print_online("System", role_play_session.user_sys_msg)
_, input_user_msg = role_play_session.init_chat(None, None, self_task_improve_prompt)
assistant_response, user_response = role_play_session.step(input_user_msg, True)
revised_task_prompt = assistant_response.msg.content.split("<INFO>")[-1].lower().strip()
log_and_print_online(role_play_session.assistant_agent.role_name, assistant_response.msg.content)
log_and_print_online(
"**[Task Prompt Self Improvement]**\n**Original Task Prompt**: {}\n**Improved Task Prompt**: {}".format(
task_prompt, revised_task_prompt))
return revised_task_prompt