File size: 5,572 Bytes
bdadbc2 258e7ae bdadbc2 258e7ae bdadbc2 258e7ae 527ef24 258e7ae bdadbc2 258e7ae bdadbc2 258e7ae bdadbc2 258e7ae bdadbc2 258e7ae bdadbc2 9bd62d1 aeea4f4 cde39d0 dab9fc7 aeea4f4 bdadbc2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import json
from copy import deepcopy
from typing import Any, Dict, List
from flow_modules.aiflows.ChatFlowModule import ChatAtomicFlow
from dataclasses import dataclass
@dataclass(init=True, repr=True, eq=True)
class Command:
    """Description of one command the controller model is allowed to issue.

    Instances are rendered into the system prompt's command manual, one
    numbered line per command.
    """

    name: str  # identifier the model puts in the "command" field of its reply
    description: str  # free-text explanation shown to the model
    input_args: List[str]  # argument names; each becomes a key of the JSON input schema
class Controller_CoderFlow(ChatAtomicFlow):
    """Chat-based controller that asks the model which command to run next.

    The system prompt is pre-filled with a numbered manual of the available
    commands. On every ``run`` it is refreshed with the current plan, code
    library and logs taken from the input. The model's reply is parsed as a
    JSON object and returned. Unparsable replies trigger a corrective re-prompt.
    """

    def __init__(
            self,
            commands: List[Command],
            **kwargs):
        """Create the controller.

        Args:
            commands: Commands the model may issue. They are rendered into the
                system prompt's ``commands`` field.
            **kwargs: Forwarded unchanged to ``ChatAtomicFlow.__init__``.
        """
        super().__init__(**kwargs)
        # The command manual is fixed for the lifetime of the flow. The memory
        # fields start as placeholders that _update_prompts_and_input overwrites.
        self.system_message_prompt_template = self.system_message_prompt_template.partial(
            commands=self._build_commands_manual(commands),
            plan="no plans yet",
            plan_file_location="no location yet",
            code_library="no code library yet",
            code_library_location="no location yet",
            logs="no logs yet",
        )
        # Appended to the 'goal' / 'result' inputs to remind the model of the
        # expected JSON reply format.
        self.hint_for_model = """
Make sure your response is in the following format:
Response Format:
{
"command": "call one of the subordinates",
"command_args": {
"arg name": "value"
}
}
"""

    @staticmethod
    def _build_commands_manual(commands: List[Command]) -> str:
        """Render ``commands`` as a numbered manual, one command per line.

        Each line has the form ``"<i>. <name>: <description> Input arguments
        (given in the JSON schema): <schema>"``. ``<schema>`` is a JSON object
        mapping every input argument to a ``YOUR_<ARG>`` placeholder. Numbering
        starts at 1 and every line, including the last, ends with a newline.
        """
        lines = []
        for i, command in enumerate(commands, start=1):
            command_input_json_schema = json.dumps(
                {input_arg: f"YOUR_{input_arg.upper()}" for input_arg in command.input_args})
            lines.append(
                f"{i}. {command.name}: {command.description} Input arguments (given in the JSON schema): {command_input_json_schema}\n"
            )
        return "".join(lines)

    def _get_content_file_location(self, input_data, content_name):
        """Return ``input_data["memory_files"][content_name]``.

        ``content_name`` is one of the memory contents (plan, logs, code_library).

        Raises:
            AssertionError: if ``memory_files`` or the requested entry is missing.
                These checks are skipped under ``python -O``.
        """
        assert "memory_files" in input_data, "memory_files not passed to Coder/Controller"
        assert content_name in input_data["memory_files"], f"{content_name} not in memory files"
        return input_data["memory_files"][content_name]

    def _get_content(self, input_data, content_name):
        """Return ``input_data[content_name]``, or ``'No <content_name> yet'`` if it is empty.

        ``None`` is treated like empty content instead of raising ``TypeError``.

        Raises:
            AssertionError: if ``content_name`` is not a key of ``input_data``.
        """
        assert content_name in input_data, f"{content_name} not passed to Coder/Controller"
        content = input_data[content_name]
        if not content:
            content = f'No {content_name} yet'
        return content

    @classmethod
    def instantiate_from_config(cls, config):
        """Build the flow from a config mapping.

        Prompts and backend are set up by the inherited ``_set_up_prompts`` and
        ``_set_up_backend`` helpers. ``config["commands"]`` maps each command
        name to a dict with ``description`` and ``input_args`` entries.
        """
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}

        # ~~~ Set up prompts ~~~
        kwargs.update(cls._set_up_prompts(flow_config))

        # ~~~ Set up backend ~~~
        kwargs.update(cls._set_up_backend(flow_config))

        # ~~~ Set up commands ~~~
        commands = [
            Command(name, command_conf["description"], command_conf["input_args"])
            for name, command_conf in flow_config["commands"].items()
        ]
        kwargs.update({"commands": commands})

        # ~~~ Instantiate flow ~~~
        return cls(**kwargs)

    def _update_prompts_and_input(self, input_data: Dict[str, Any]):
        """Prepare ``input_data`` and refresh the system prompt's memory fields.

        Mutates ``input_data`` in place: ``hint_for_model`` is appended to
        ``goal`` and/or ``result`` when those keys are present. The system
        prompt template is then re-partialed with the plan file location and
        content, the code library location and content, and the logs.
        """
        if 'goal' in input_data:
            input_data['goal'] += self.hint_for_model
        if 'result' in input_data:
            input_data['result'] += self.hint_for_model
        plan_file_location = self._get_content_file_location(input_data, "plan")
        plan_content = self._get_content(input_data, "plan")
        code_library_location = self._get_content_file_location(input_data, "code_library")
        code_library_content = self._get_content(input_data, "code_library")
        logs_content = self._get_content(input_data, "logs")
        self.system_message_prompt_template = self.system_message_prompt_template.partial(
            plan_file_location=plan_file_location,
            plan=plan_content,
            code_library_location=code_library_location,
            code_library=code_library_content,
            logs=logs_content
        )

    @staticmethod
    def _parse_api_output(api_output: str) -> Dict[str, Any]:
        """Parse the span from the first ``{`` to the last ``}`` of ``api_output`` as JSON.

        Raises:
            ValueError: if there is no brace pair, or the span is not valid JSON.
                ``json.JSONDecodeError`` is a ``ValueError`` subclass.
        """
        start = api_output.index("{")
        end = api_output.rindex("}") + 1
        return json.loads(api_output[start:end])

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Query the model until its reply contains a parsable JSON object, and return it.

        Args:
            input_data: Must contain ``memory_files`` (with ``plan`` and
                ``code_library`` entries), ``plan``, ``code_library`` and
                ``logs``. It may contain ``goal`` and/or ``result``. The
                caller's dict is not modified.

        Returns:
            The object parsed from the model's reply.

        Raises:
            ValueError: if ``flow_config`` sets ``max_json_parse_retries`` and the
                reply still cannot be parsed after that many re-prompts. Without
                that key the flow retries indefinitely, as before.
        """
        # Work on a shallow copy so the format hint appended to 'goal'/'result'
        # does not leak into, and accumulate in, the caller's dict.
        input_data = input_data.copy()
        self._update_prompts_and_input(input_data)

        # ~~~ when conversation is initialized, append the updated system prompts to the chat history ~~~
        if self._is_conversation_initialized():
            updated_system_message_content = self._get_message(self.system_message_prompt_template, input_data)
            self._state_update_add_chat_message(content=updated_system_message_content,
                                                role=self.flow_config["system_name"])

        # None (key absent) keeps the original unbounded retry behaviour.
        max_retries = self.flow_config.get("max_json_parse_retries")
        failed_attempts = 0
        while True:
            api_output = super().run(input_data)["api_output"].strip()
            try:
                return self._parse_api_output(api_output)
            except ValueError as err:
                failed_attempts += 1
                if max_retries is not None and failed_attempts > max_retries:
                    raise ValueError(
                        f"Controller output could not be parsed as JSON after "
                        f"{failed_attempts} attempts: {api_output!r}"
                    ) from err
                # Re-send the system prompt and replace 'result' with a
                # corrective instruction for the next attempt.
                updated_system_message_content = self._get_message(self.system_message_prompt_template, input_data)
                self._state_update_add_chat_message(content=updated_system_message_content,
                                                    role=self.flow_config["system_name"])
                new_goal = "The previous respond cannot be parsed with json.loads. Next time, do not provide any comments or code blocks. Make sure your next response is purely json parsable."
                new_input_data = input_data.copy()
                new_input_data['result'] = new_goal
                input_data = new_input_data
|