martinjosifoski commited on
Commit
6542487
1 Parent(s): 259f9b9

Update instantiation of flow.

Browse files
Files changed (2) hide show
  1. OpenAIChatAtomicFlow.py +103 -61
  2. OpenAIChatAtomicFlow.yaml +1 -1
OpenAIChatAtomicFlow.py CHANGED
@@ -1,4 +1,6 @@
1
  import pprint
 
 
2
  import hydra
3
 
4
  import colorama
@@ -7,15 +9,17 @@ import time
7
  from typing import List, Dict, Optional, Any
8
 
9
  from langchain import PromptTemplate
10
- from langchain.chat_models import ChatOpenAI
11
  from langchain.schema import HumanMessage, AIMessage, SystemMessage
12
 
 
13
  from flows.message_annotators.abstract import MessageAnnotator
14
  from flows.base_flows.abstract import AtomicFlow
15
  from flows.datasets import GenericDemonstrationsDataset
16
 
17
  from flows import utils
18
  from flows.messages.chat_message import ChatMessage
 
19
 
20
  log = utils.get_pylogger(__name__)
21
 
@@ -40,55 +44,98 @@ class OpenAIChatAtomicFlow(AtomicFlow):
40
  response_annotators: Optional[Dict[str, MessageAnnotator]] = {}
41
 
42
  def __init__(self, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  # ~~~ Model generation ~~~
44
- if "model_name" not in kwargs:
45
- raise KeyError
46
 
47
- if "generation_parameters" not in kwargs:
48
- raise KeyError
49
 
50
  # ~~~ Prompting ~~~
51
  if "system_message_prompt_template" not in kwargs:
52
- raise KeyError
 
 
 
53
 
54
  if "human_message_prompt_template" not in kwargs:
55
- raise KeyError
56
 
57
- super().__init__(**kwargs)
58
- self._instantiate()
 
59
 
60
- assert self.name not in [
61
- "system",
62
- "user",
63
- "assistant",
64
- ], f"Flow name '{self.name}' cannot be 'system', 'user' or 'assistant'"
65
-
66
- def _instantiate(self):
67
- # ~~~ Instantiate prompts ~~~
68
- self.system_message_prompt_template = \
69
- hydra.utils.instantiate(self.flow_config['system_message_prompt_template'], _convert_="partial")
70
- self.query_message_prompt_template = \
71
- hydra.utils.instantiate(self.flow_config['query_message_prompt_template'], _convert_="partial")
72
- if self.flow_config["human_message_prompt_template"] is not None:
73
- self.human_message_prompt_template = \
74
- hydra.utils.instantiate(self.flow_config['human_message_prompt_template'], _convert_="partial")
75
-
76
- # ~~~ Instantiate response annotators ~~~
77
- if self.flow_config["response_annotators"] and len(self.flow_config["response_annotators"]) > 0:
78
- for key, config in self.flow_config["response_annotators"].items():
79
- self.response_annotators[key] = hydra.utils.instantiate(config, _convert_="partial")
80
-
81
- def is_initialized(self):
82
- conv_init = False
83
- if "conversation_initialized" in self.flow_state:
84
- conv_init = self.flow_state["conversation_initialized"]
85
- return conv_init
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  def expected_inputs_given_state(self):
88
- if self.is_initialized():
89
  return ["query"]
90
  else:
91
- return self.expected_inputs
92
 
93
  @staticmethod
94
  def _get_message(prompt_template, input_data: Dict[str, Any]):
@@ -100,10 +147,12 @@ class OpenAIChatAtomicFlow(AtomicFlow):
100
  return msg_content
101
 
102
  def _get_demonstration_query_message_content(self, sample_data: Dict):
103
- return self.query_message_prompt_template.format(**sample_data), []
 
104
 
105
  def _get_demonstration_response_message_content(self, sample_data: Dict):
106
- return self.demonstrations_response_template.format(**sample_data), []
 
107
 
108
  def _get_annotator_with_key(self, key: str):
109
  for _, ra in self.response_annotators.items():
@@ -113,6 +162,9 @@ class OpenAIChatAtomicFlow(AtomicFlow):
113
  def _response_parsing(self, response: str, expected_outputs: List[str]):
114
  target_annotators = [ra for _, ra in self.response_annotators.items() if ra.key in expected_outputs]
115
 
 
 
 
116
  parsed_outputs = {}
117
  for ra in target_annotators:
118
  parsed_out = ra(response)
@@ -137,7 +189,7 @@ class OpenAIChatAtomicFlow(AtomicFlow):
137
  chat_message = ChatMessage(
138
  message_creator=message_creator,
139
  parent_message_ids=parent_message_ids,
140
- flow_runner=self.name,
141
  flow_run_id=self.flow_run_id,
142
  content=content
143
  )
@@ -155,10 +207,6 @@ class OpenAIChatAtomicFlow(AtomicFlow):
155
  self._update_state(update_data={"conversation_initialized": True})
156
 
157
  def get_conversation_messages(self, message_format: Optional[str] = None):
158
- assert message_format is None or message_format in [
159
- "open_ai"
160
- ], f"Currently supported conversation message formats: 'open_ai'. '{message_format}' is not supported"
161
-
162
  messages = self.flow_state["history"].get_chat_messages()
163
 
164
  if message_format is None:
@@ -178,15 +226,16 @@ class OpenAIChatAtomicFlow(AtomicFlow):
178
  raise ValueError(f"Unknown name: {message.message_creator}")
179
  return processed_messages
180
  else:
181
- raise ValueError(f"Unknown message format: {message_format}")
 
182
 
183
  def _call(self):
184
  api_key = self.flow_state["api_key"]
185
 
186
- backend = ChatOpenAI(
187
- model_name=self.model_name,
188
  openai_api_key=api_key,
189
- **self.generation_parameters,
190
  )
191
 
192
  messages = self.get_conversation_messages(
@@ -218,17 +267,17 @@ class OpenAIChatAtomicFlow(AtomicFlow):
218
  if not _success:
219
  raise error
220
 
221
- if self.verbose:
222
  messages_str = self.flow_state["history"].to_string()
223
  log.info(
224
- f"\n{colorama.Fore.MAGENTA}~~~ History [{self.name}] ~~~\n"
225
  f"{colorama.Style.RESET_ALL}{messages_str}"
226
  )
227
 
228
  return response
229
 
230
  def _prepare_conversation(self, input_data: Dict[str, Any]):
231
- if self.is_initialized():
232
  # ~~~ Check that the message has a `query` field ~~~
233
  user_message_content = self.human_message_prompt_template.format(query=input_data["query"])
234
 
@@ -239,14 +288,7 @@ class OpenAIChatAtomicFlow(AtomicFlow):
239
  self._log_chat_message(message_creator=self.user_name,
240
  content=user_message_content)
241
 
242
- # if self.flow_state["dry_run"]:
243
- # messages_str = self.flow_state["history"].to_string()
244
- # log.info(
245
- # f"\n{colorama.Fore.MAGENTA}~~~ Messages [{self.name} -- {self.flow_run_id}] ~~~\n"
246
- # f"{colorama.Style.RESET_ALL}{messages_str}"
247
- # )
248
- # exit(0)
249
-
250
  def run(self, input_data: Dict[str, Any], expected_outputs: List[str]) -> Dict[str, Any]:
251
  # ~~~ Chat-specific preparation ~~~
252
  self._prepare_conversation(input_data)
@@ -254,7 +296,7 @@ class OpenAIChatAtomicFlow(AtomicFlow):
254
  # ~~~ Call ~~~
255
  response = self._call()
256
  answer_message = self._log_chat_message(
257
- message_creator=self.assistant_name,
258
  content=response
259
  )
260
 
@@ -265,7 +307,7 @@ class OpenAIChatAtomicFlow(AtomicFlow):
265
  )
266
  self._update_state(update_data=parsed_outputs)
267
 
268
- if self.verbose:
269
  parsed_output_messages_str = pprint.pformat({k: m for k, m in parsed_outputs.items()},
270
  indent=4)
271
  log.info(
 
1
  import pprint
2
+ from copy import deepcopy
3
+
4
  import hydra
5
 
6
  import colorama
 
9
  from typing import List, Dict, Optional, Any
10
 
11
  from langchain import PromptTemplate
12
+ import langchain
13
  from langchain.schema import HumanMessage, AIMessage, SystemMessage
14
 
15
+ from flows.history import FlowHistory
16
  from flows.message_annotators.abstract import MessageAnnotator
17
  from flows.base_flows.abstract import AtomicFlow
18
  from flows.datasets import GenericDemonstrationsDataset
19
 
20
  from flows import utils
21
  from flows.messages.chat_message import ChatMessage
22
+ from flows.utils.caching_utils import flow_run_cache
23
 
24
  log = utils.get_pylogger(__name__)
25
 
 
44
  response_annotators: Optional[Dict[str, MessageAnnotator]] = {}
45
 
46
  def __init__(self, **kwargs):
47
+ self._validate_parameters(kwargs)
48
+ super().__init__(**kwargs)
49
+
50
+ assert self.flow_config["name"] not in [
51
+ "system",
52
+ "user",
53
+ "assistant",
54
+ ], f"Flow name '{self.flow_config['name']}' cannot be 'system', 'user' or 'assistant'"
55
+
56
+ def set_up_flow_state(self):
57
+ super().set_up_flow_state()
58
+ self.flow_state["conversation_initialized"] = False
59
+
60
+ @classmethod
61
+ def _validate_parameters(cls, kwargs):
62
+ # ToDo: Deal with this in a cleaner way (with less repetition)
63
+ super()._validate_parameters(kwargs)
64
+
65
  # ~~~ Model generation ~~~
66
+ if "model_name" not in kwargs["flow_config"]:
67
+ raise KeyError("model_name not specified in the flow_config.")
68
 
69
+ if "generation_parameters" not in kwargs["flow_config"]:
70
+ raise KeyError("generation_parameters not specified in the flow_config.")
71
 
72
  # ~~~ Prompting ~~~
73
  if "system_message_prompt_template" not in kwargs:
74
+ raise KeyError("system_message_prompt_template not passed to the constructor.")
75
+
76
+ if "query_message_prompt_template" not in kwargs:
77
+ raise KeyError("query_message_prompt_template not passed to the constructor.")
78
 
79
  if "human_message_prompt_template" not in kwargs:
80
+ raise KeyError("human_message_prompt_template not passed to the constructor.")
81
 
82
+ @classmethod
83
+ def _set_up_prompts(cls, config):
84
+ kwargs = {}
85
 
86
+ kwargs["system_message_prompt_template"] = \
87
+ hydra.utils.instantiate(config['system_message_prompt_template'], _convert_="partial")
88
+ kwargs["query_message_prompt_template"] = \
89
+ hydra.utils.instantiate(config['query_message_prompt_template'], _convert_="partial")
90
+ kwargs["human_message_prompt_template"] = \
91
+ hydra.utils.instantiate(config['human_message_prompt_template'], _convert_="partial")
92
+
93
+ return kwargs
94
+
95
+ @classmethod
96
+ def _set_up_demonstration_templates(cls, config):
97
+ kwargs = {}
98
+
99
+ if "demonstrations_response_template" in config:
100
+ kwargs["demonstrations_response_template"] = \
101
+ hydra.utils.instantiate(config['demonstrations_response_template'], _convert_="partial")
102
+
103
+ return kwargs
104
+
105
+ @classmethod
106
+ def _set_up_response_annotators(cls, config):
107
+ response_annotators = config.get("response_annotators", {})
108
+ if len(response_annotators) > 0:
109
+ for key, config in response_annotators.items():
110
+ response_annotators[key] = hydra.utils.instantiate(config, _convert_="partial")
111
+ return {"response_annotators": response_annotators}
112
+
113
+ @classmethod
114
+ def instantiate_from_config(cls, config):
115
+ flow_config = deepcopy(config)
116
+
117
+ kwargs = {"flow_config": flow_config}
118
+
119
+ # ~~~ Set up prompts ~~~
120
+ kwargs.update(cls._set_up_prompts(flow_config))
121
+
122
+ # ~~~ Set up demonstration templates ~~~
123
+ kwargs.update(cls._set_up_demonstration_templates(flow_config))
124
+
125
+ # ~~~ Set up response annotators ~~~
126
+ kwargs.update(cls._set_up_response_annotators(flow_config))
127
+
128
+ # ~~~ Instantiate flow ~~~
129
+ return cls(**kwargs)
130
+
131
+ def _is_conversation_initialized(self):
132
+ return self.flow_state["conversation_initialized"]
133
 
134
  def expected_inputs_given_state(self):
135
+ if self._is_conversation_initialized():
136
  return ["query"]
137
  else:
138
+ return self.flow_config["expected_inputs"]
139
 
140
  @staticmethod
141
  def _get_message(prompt_template, input_data: Dict[str, Any]):
 
147
  return msg_content
148
 
149
  def _get_demonstration_query_message_content(self, sample_data: Dict):
150
+ input_variables = self.query_message_prompt_template.input_variables
151
+ return self.query_message_prompt_template.format(**{k: sample_data[k] for k in input_variables}), []
152
 
153
  def _get_demonstration_response_message_content(self, sample_data: Dict):
154
+ input_variables = self.demonstrations_response_template.input_variables
155
+ return self.demonstrations_response_template.format(**{k: sample_data[k] for k in input_variables}), []
156
 
157
  def _get_annotator_with_key(self, key: str):
158
  for _, ra in self.response_annotators.items():
 
162
  def _response_parsing(self, response: str, expected_outputs: List[str]):
163
  target_annotators = [ra for _, ra in self.response_annotators.items() if ra.key in expected_outputs]
164
 
165
+ if len(target_annotators) == 0:
166
+ return {expected_outputs[0]: response}
167
+
168
  parsed_outputs = {}
169
  for ra in target_annotators:
170
  parsed_out = ra(response)
 
189
  chat_message = ChatMessage(
190
  message_creator=message_creator,
191
  parent_message_ids=parent_message_ids,
192
+ flow_runner=self.flow_config["name"],
193
  flow_run_id=self.flow_run_id,
194
  content=content
195
  )
 
207
  self._update_state(update_data={"conversation_initialized": True})
208
 
209
  def get_conversation_messages(self, message_format: Optional[str] = None):
 
 
 
 
210
  messages = self.flow_state["history"].get_chat_messages()
211
 
212
  if message_format is None:
 
226
  raise ValueError(f"Unknown name: {message.message_creator}")
227
  return processed_messages
228
  else:
229
+ raise ValueError(
230
+ f"Currently supported conversation message formats: 'open_ai'. '{message_format}' is not supported")
231
 
232
  def _call(self):
233
  api_key = self.flow_state["api_key"]
234
 
235
+ backend = langchain.chat_models.ChatOpenAI(
236
+ model_name=self.flow_config["model_name"],
237
  openai_api_key=api_key,
238
+ **self.flow_config["generation_parameters"],
239
  )
240
 
241
  messages = self.get_conversation_messages(
 
267
  if not _success:
268
  raise error
269
 
270
+ if self.flow_config["verbose"]:
271
  messages_str = self.flow_state["history"].to_string()
272
  log.info(
273
+ f"\n{colorama.Fore.MAGENTA}~~~ History [{self.flow_config['name']}] ~~~\n"
274
  f"{colorama.Style.RESET_ALL}{messages_str}"
275
  )
276
 
277
  return response
278
 
279
  def _prepare_conversation(self, input_data: Dict[str, Any]):
280
+ if self._is_conversation_initialized():
281
  # ~~~ Check that the message has a `query` field ~~~
282
  user_message_content = self.human_message_prompt_template.format(query=input_data["query"])
283
 
 
288
  self._log_chat_message(message_creator=self.user_name,
289
  content=user_message_content)
290
 
291
+ @flow_run_cache()
 
 
 
 
 
 
 
292
  def run(self, input_data: Dict[str, Any], expected_outputs: List[str]) -> Dict[str, Any]:
293
  # ~~~ Chat-specific preparation ~~~
294
  self._prepare_conversation(input_data)
 
296
  # ~~~ Call ~~~
297
  response = self._call()
298
  answer_message = self._log_chat_message(
299
+ message_creator=self.flow_config["assistant_name"],
300
  content=response
301
  )
302
 
 
307
  )
308
  self._update_state(update_data=parsed_outputs)
309
 
310
+ if self.flow_config["verbose"]:
311
  parsed_output_messages_str = pprint.pformat({k: m for k, m in parsed_outputs.items()},
312
  indent=4)
313
  log.info(
OpenAIChatAtomicFlow.yaml CHANGED
@@ -1,4 +1,4 @@
1
- # This is an abstract flow, therefore some required fields are missing (not defined)
2
 
3
  n_api_retries: 6
4
  wait_time_between_retries: 20
 
1
+ # This is an abstract flow, therefore some required fields are not defined (and must be defined by the concrete flow)
2
 
3
  n_api_retries: 6
4
  wait_time_between_retries: 20