Update ControllerAtomicFlow.py
Browse files- ControllerAtomicFlow.py +5 -6
ControllerAtomicFlow.py
CHANGED
@@ -96,21 +96,19 @@ class ControllerAtomicFlow(OpenAIChatAtomicFlow):
|
|
96 |
return ret
|
97 |
except FileNotFoundError:
|
98 |
return 'There is no function available yet.'
|
99 |
-
|
100 |
def _get_plan(self):
|
101 |
try:
|
102 |
with open(self.plan_file_location, 'r') as file:
|
103 |
return file.read()
|
104 |
except FileNotFoundError:
|
105 |
return "There is no plan yet"
|
106 |
-
|
107 |
-
def
|
108 |
if 'goal' in input_data:
|
109 |
input_data['goal'] += self.hint_for_model
|
110 |
if 'human_feedback' in input_data:
|
111 |
input_data['human_feedback'] += self.hint_for_model
|
112 |
-
|
113 |
-
# self.system_message_prompt_template.template
|
114 |
plan_to_append = self._get_plan()
|
115 |
function_signatures_to_append = self._get_library_function_signatures()
|
116 |
self.system_message_prompt_template.template = \
|
@@ -120,8 +118,9 @@ class ControllerAtomicFlow(OpenAIChatAtomicFlow):
|
|
120 |
+ plan_to_append + "\n\n" + f"Make sure the plan your write is at {self.plan_file_location}\n" \
|
121 |
+ f"Make sure the code you call the code writer to write is at {self.code_file_location}"
|
122 |
|
|
|
|
|
123 |
api_output = super().run(input_data)["api_output"].strip()
|
124 |
-
|
125 |
try:
|
126 |
response = json.loads(api_output)
|
127 |
return response
|
|
|
96 |
return ret
|
97 |
except FileNotFoundError:
|
98 |
return 'There is no function available yet.'
|
99 |
+
|
100 |
def _get_plan(self):
|
101 |
try:
|
102 |
with open(self.plan_file_location, 'r') as file:
|
103 |
return file.read()
|
104 |
except FileNotFoundError:
|
105 |
return "There is no plan yet"
|
106 |
+
|
107 |
+
def _update_prompts_and_input(self, input_data: Dict[str, Any]):
|
108 |
if 'goal' in input_data:
|
109 |
input_data['goal'] += self.hint_for_model
|
110 |
if 'human_feedback' in input_data:
|
111 |
input_data['human_feedback'] += self.hint_for_model
|
|
|
|
|
112 |
plan_to_append = self._get_plan()
|
113 |
function_signatures_to_append = self._get_library_function_signatures()
|
114 |
self.system_message_prompt_template.template = \
|
|
|
118 |
+ plan_to_append + "\n\n" + f"Make sure the plan your write is at {self.plan_file_location}\n" \
|
119 |
+ f"Make sure the code you call the code writer to write is at {self.code_file_location}"
|
120 |
|
121 |
+
def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
|
122 |
+
self._update_prompts_and_input(input_data)
|
123 |
api_output = super().run(input_data)["api_output"].strip()
|
|
|
124 |
try:
|
125 |
response = json.loads(api_output)
|
126 |
return response
|