Tachi67 commited on
Commit
bdadbc2
·
1 Parent(s): 5240345

Update Controller_CoderFlow.py

Browse files
Files changed (1) hide show
  1. Controller_CoderFlow.py +65 -16
Controller_CoderFlow.py CHANGED
@@ -1,5 +1,8 @@
 
 
1
  from typing import Any, Dict, List
2
- from flow_modules.Tachi67.ExtendLibraryFlowModule import ControllerFlow_ExtLib
 
3
 
4
 
5
  from dataclasses import dataclass
@@ -11,7 +14,7 @@ class Command:
11
  description: str
12
  input_args: List[str]
13
 
14
- class Controller_CoderFlow(ControllerFlow_ExtLib):
15
  def __init__(
16
  self,
17
  commands: List[Command],
@@ -36,28 +39,62 @@ class Controller_CoderFlow(ControllerFlow_ExtLib):
36
  }
37
  """
38
 
39
- def _get_code_library_location(self, input_data):
 
 
 
 
 
 
 
 
 
 
40
  assert "memory_files" in input_data, "memory_files not passed to Coder/Controller"
41
- assert "code_library" in input_data["memory_files"], "code_library not in memory files"
42
- return input_data["memory_files"]["code_library"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
- def _get_code_library_content(self, input_data):
45
- assert "code_library" in input_data, "code_library not passed to Coder/Controller"
46
- code_library_content = input_data["code_library"]
47
- if len(code_library_content) == 0:
48
- code_library_content = 'No plan yet'
49
- return code_library_content
 
 
 
 
 
 
 
 
 
 
50
 
51
  def _update_prompts_and_input(self, input_data: Dict[str, Any]):
52
  if 'goal' in input_data:
53
  input_data['goal'] += self.hint_for_model
54
  if 'result' in input_data:
55
  input_data['result'] += self.hint_for_model
56
- plan_file_location = self._get_plan_file_location(input_data)
57
- plan_content = self._get_plan_content(input_data)
58
- code_library_location = self._get_code_library_location(input_data)
59
- code_library_content = self._get_code_library_content(input_data)
60
- logs_content = self._get_logs_content(input_data)
61
  self.system_message_prompt_template = self.system_message_prompt_template.partial(
62
  plan_file_location=plan_file_location,
63
  plan=plan_content,
@@ -66,3 +103,15 @@ class Controller_CoderFlow(ControllerFlow_ExtLib):
66
  logs=logs_content
67
  )
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from copy import deepcopy
3
  from typing import Any, Dict, List
4
+
5
+ from flow_modules.aiflows.ChatFlowModule import ChatAtomicFlow
6
 
7
 
8
  from dataclasses import dataclass
 
14
  description: str
15
  input_args: List[str]
16
 
17
+ class Controller_CoderFlow(ChatAtomicFlow):
18
  def __init__(
19
  self,
20
  commands: List[Command],
 
39
  }
40
  """
41
 
42
+ @staticmethod
43
+ def _build_commands_manual(commands: List[Command]) -> str:
44
+ ret = ""
45
+ for i, command in enumerate(commands):
46
+ command_input_json_schema = json.dumps(
47
+ {input_arg: f"YOUR_{input_arg.upper()}" for input_arg in command.input_args})
48
+ ret += f"{i + 1}. {command.name}: {command.description} Input arguments (given in the JSON schema): {command_input_json_schema}\n"
49
+ return ret
50
+
51
+ def _get_content_file_location(self, input_data, content_name):
52
+ # get the location of the file that contains the content: plan, logs, code_library
53
  assert "memory_files" in input_data, "memory_files not passed to Coder/Controller"
54
+ assert content_name in input_data["memory_files"], f"{content_name} not in memory files"
55
+ return input_data["memory_files"][content_name]
56
+
57
+ def _get_content(self, input_data, content_name):
58
+ # get the content of the file that contains the content: plan, logs, code_library
59
+ assert content_name in input_data, f"{content_name} not passed to Coder/Controller"
60
+ content = input_data[content_name]
61
+ if len(content) == 0:
62
+ content = f'No {content_name} yet'
63
+ return content
64
+
65
+ @classmethod
66
+ def instantiate_from_config(cls, config):
67
+ flow_config = deepcopy(config)
68
+
69
+ kwargs = {"flow_config": flow_config}
70
 
71
+ # ~~~ Set up prompts ~~~
72
+ kwargs.update(cls._set_up_prompts(flow_config))
73
+
74
+ # ~~~Set up backend ~~~
75
+ kwargs.update(cls._set_up_backend(flow_config))
76
+
77
+ # ~~~ Set up commands ~~~
78
+ commands = flow_config["commands"]
79
+ commands = [
80
+ Command(name, command_conf["description"], command_conf["input_args"]) for name, command_conf in
81
+ commands.items()
82
+ ]
83
+ kwargs.update({"commands": commands})
84
+
85
+ # ~~~ Instantiate flow ~~~
86
+ return cls(**kwargs)
87
 
88
  def _update_prompts_and_input(self, input_data: Dict[str, Any]):
89
  if 'goal' in input_data:
90
  input_data['goal'] += self.hint_for_model
91
  if 'result' in input_data:
92
  input_data['result'] += self.hint_for_model
93
+ plan_file_location = self._get_content_file_location(input_data, "plan")
94
+ plan_content = self._get_content(input_data, "plan")
95
+ code_library_location = self._get_content_file_location(input_data, "code_library")
96
+ code_library_content = self._get_content(input_data, "code_library")
97
+ logs_content = self._get_content(input_data, "logs")
98
  self.system_message_prompt_template = self.system_message_prompt_template.partial(
99
  plan_file_location=plan_file_location,
100
  plan=plan_content,
 
103
  logs=logs_content
104
  )
105
 
106
+ def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
107
+ self._update_prompts_and_input(input_data)
108
+ api_output = super().run(input_data)["api_output"].strip()
109
+ try:
110
+ response = json.loads(api_output)
111
+ return response
112
+ except json.decoder.JSONDecodeError:
113
+ new_input_data = input_data.copy()
114
+ new_input_data['result'] = "The previous respond cannot be parsed with json.loads. Make sure your next response is in JSON format."
115
+ new_api_output = super().run(new_input_data)["api_output"].strip()
116
+ return json.loads(new_api_output)
117
+