Tachi67 commited on
Commit
5ba61c7
·
1 Parent(s): ca341b8

Update PlanGeneratorAtomicFlow.py

Browse files
Files changed (1) hide show
  1. PlanGeneratorAtomicFlow.py +13 -11
PlanGeneratorAtomicFlow.py CHANGED
@@ -10,6 +10,13 @@ class PlanGeneratorAtomicFlow(OpenAIChatAtomicFlow):
10
  code_file_location: str,
11
  **kwargs):
12
  super().__init__(**kwargs)
 
 
 
 
 
 
 
13
  self.code_file_location = code_file_location
14
  self.original_system_template = self.system_message_prompt_template.template
15
 
@@ -44,23 +51,18 @@ class PlanGeneratorAtomicFlow(OpenAIChatAtomicFlow):
44
  return ret
45
  except FileNotFoundError:
46
  return 'There is no function available yet.'
47
-
48
- def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
49
- hint_for_model = """
50
- Make sure your response is in the following format:
51
- Response Format:
52
- {
53
- "plan": "Python printable string of the plan corresponding to the goal",
54
- }
55
- """
56
  if 'goal' in input_data:
57
- input_data['goal'] += hint_for_model
58
 
59
  function_signatures_to_append = self._get_library_function_signatures()
60
  self.system_message_prompt_template.template = \
61
  self.original_system_template + "\n\n" + f"Here are the available functions at {self.code_file_location}\n" \
62
- + function_signatures_to_append + "\n" \
63
 
 
 
64
  api_output = super().run(input_data)["api_output"].strip()
65
  try:
66
  response = json.loads(api_output)
 
10
  code_file_location: str,
11
  **kwargs):
12
  super().__init__(**kwargs)
13
+ self.hint_for_model = """
14
+ Make sure your response is in the following format:
15
+ Response Format:
16
+ {
17
+ "plan": "Python printable string of the plan corresponding to the goal",
18
+ }
19
+ """
20
  self.code_file_location = code_file_location
21
  self.original_system_template = self.system_message_prompt_template.template
22
 
 
51
  return ret
52
  except FileNotFoundError:
53
  return 'There is no function available yet.'
54
+
55
+ def _update_prompt_and_input(self, input_data: Dict[str, Any]):
 
 
 
 
 
 
 
56
  if 'goal' in input_data:
57
+ input_data['goal'] += self.hint_for_model
58
 
59
  function_signatures_to_append = self._get_library_function_signatures()
60
  self.system_message_prompt_template.template = \
61
  self.original_system_template + "\n\n" + f"Here are the available functions at {self.code_file_location}\n" \
62
+ + function_signatures_to_append + "\n"
63
 
64
+ def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
65
+ self._update_prompt_and_input(input_data)
66
  api_output = super().run(input_data)["api_output"].strip()
67
  try:
68
  response = json.loads(api_output)