martinjosifoski commited on
Commit
6a9c224
1 Parent(s): 0949dd4

Add support for default_human_input_key and outputs transformations.

Browse files
OpenAIChatAtomicFlow.py CHANGED
@@ -98,6 +98,7 @@ class OpenAIChatAtomicFlow(AtomicFlow):
98
  flow_config = deepcopy(config)
99
 
100
  kwargs = {"flow_config": flow_config}
 
101
 
102
  # ~~~ Set up prompts ~~~
103
  kwargs.update(cls._set_up_prompts(flow_config))
@@ -120,7 +121,7 @@ class OpenAIChatAtomicFlow(AtomicFlow):
120
  def get_input_keys(self, data: Optional[Dict[str, Any]] = None):
121
  """Returns the expected inputs for the flow given the current state and, optionally, the input data"""
122
  if self._is_conversation_initialized():
123
- return ["query"]
124
  else:
125
  return self.flow_config["input_keys"]
126
 
@@ -294,4 +295,4 @@ class OpenAIChatAtomicFlow(AtomicFlow):
294
  )
295
  # self._state_update_dict(update_data=output_data) # ToDo: Is this necessary? When?
296
 
297
- return output_data
 
98
  flow_config = deepcopy(config)
99
 
100
  kwargs = {"flow_config": flow_config}
101
+ kwargs["outputs_transformations"] = cls._set_up_outputs_transformations(flow_config)
102
 
103
  # ~~~ Set up prompts ~~~
104
  kwargs.update(cls._set_up_prompts(flow_config))
 
121
  def get_input_keys(self, data: Optional[Dict[str, Any]] = None):
122
  """Returns the expected inputs for the flow given the current state and, optionally, the input data"""
123
  if self._is_conversation_initialized():
124
+ return [self.flow_config["default_human_input_key"]]
125
  else:
126
  return self.flow_config["input_keys"]
127
 
 
295
  )
296
  # self._state_update_dict(update_data=output_data) # ToDo: Is this necessary? When?
297
 
298
+ return output_data
OpenAIChatAtomicFlow.yaml CHANGED
@@ -30,13 +30,14 @@ user_message_prompt_template:
30
 
31
  human_message_prompt_template:
32
  _target_: langchain.PromptTemplate
 
 
 
33
  template_format: jinja2
 
34
 
35
  query_message_prompt_template:
36
  _target_: langchain.PromptTemplate
37
- template: "{{query}}"
38
- input_variables:
39
- - "query"
40
  template_format: jinja2
41
 
42
  demonstrations: null
 
30
 
31
  human_message_prompt_template:
32
  _target_: langchain.PromptTemplate
33
+ template: "{{query}}"
34
+ input_variables:
35
+ - "query"
36
  template_format: jinja2
37
+ default_human_input_key: "query"
38
 
39
  query_message_prompt_template:
40
  _target_: langchain.PromptTemplate
 
 
 
41
  template_format: jinja2
42
 
43
  demonstrations: null