martinjosifoski committed on
Commit
259f9b9
1 Parent(s): 34ee4a5

Clean up and add code for OpenAIChatAtomicFlow.

Browse files
OpenAIChatAtomicFlow.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pprint
2
+ import hydra
3
+
4
+ import colorama
5
+ import time
6
+
7
+ from typing import List, Dict, Optional, Any
8
+
9
+ from langchain import PromptTemplate
10
+ from langchain.chat_models import ChatOpenAI
11
+ from langchain.schema import HumanMessage, AIMessage, SystemMessage
12
+
13
+ from flows.message_annotators.abstract import MessageAnnotator
14
+ from flows.base_flows.abstract import AtomicFlow
15
+ from flows.datasets import GenericDemonstrationsDataset
16
+
17
+ from flows import utils
18
+ from flows.messages.chat_message import ChatMessage
19
+
20
+ log = utils.get_pylogger(__name__)
21
+
22
+
23
class OpenAIChatAtomicFlow(AtomicFlow):
    """Atomic flow that runs a chat conversation against an OpenAI chat model.

    On the first call to `run` the conversation is initialized: the system
    message is rendered and logged, optional demonstrations are logged as
    user/assistant message pairs, and the query message is rendered from
    `query_message_prompt_template`. On later calls the user message is
    rendered from `human_message_prompt_template` using `input_data["query"]`.

    The logged chat history is then sent to `ChatOpenAI`, the reply is logged
    as an assistant message, and the reply is parsed by the response
    annotators whose `key` is among the expected outputs.

    NOTE(review): `name`, `verbose`, `flow_config`, `flow_state`,
    `flow_run_id`, `expected_inputs`, `_log_message`, `_update_state` and
    `_get_keys_from_state` are not defined in this class; they are presumably
    provided by `AtomicFlow` -- confirm against the base class.
    """

    # ~~~ Model generation ~~~
    model_name: str
    generation_parameters: Dict

    # ~~~ Prompting ~~~
    system_message_prompt_template: PromptTemplate
    human_message_prompt_template: PromptTemplate

    # Message-creator labels used to tag logged chat messages.
    system_name: str = "system"
    user_name: str = "user"
    assistant_name: str = "assistant"

    # Total number of attempts made by `_call`, and the pause between them (seconds).
    n_api_retries: int = 6
    wait_time_between_retries: int = 20

    query_message_prompt_template: Optional[PromptTemplate] = None
    demonstrations: GenericDemonstrationsDataset = None
    demonstrations_response_template: PromptTemplate = None
    # Class-level default kept for backward compatibility. It is shared by all
    # instances and must never be mutated; `_instantiate` assigns a fresh dict
    # to every instance instead.
    response_annotators: Optional[Dict[str, MessageAnnotator]] = {}

    def __init__(self, **kwargs):
        """Validate the required keyword arguments and build the flow.

        Raises:
            KeyError: If `model_name`, `generation_parameters`,
                `system_message_prompt_template` or
                `human_message_prompt_template` is missing from `kwargs`.
            AssertionError: If the flow's name is 'system', 'user' or 'assistant'.
        """
        required_keys = (
            # ~~~ Model generation ~~~
            "model_name",
            "generation_parameters",
            # ~~~ Prompting ~~~
            "system_message_prompt_template",
            "human_message_prompt_template",
        )
        for key in required_keys:
            if key not in kwargs:
                raise KeyError(f"Missing required argument '{key}' for {type(self).__name__}")

        super().__init__(**kwargs)
        self._instantiate()

        # The flow's own name must not collide with the default chat role names.
        assert self.name not in [
            "system",
            "user",
            "assistant",
        ], f"Flow name '{self.name}' cannot be 'system', 'user' or 'assistant'"

    def _instantiate(self):
        """Instantiate the prompt templates and response annotators from `self.flow_config`."""
        config = self.flow_config

        # ~~~ Instantiate prompts ~~~
        self.system_message_prompt_template = hydra.utils.instantiate(
            config["system_message_prompt_template"], _convert_="partial"
        )
        self.query_message_prompt_template = hydra.utils.instantiate(
            config["query_message_prompt_template"], _convert_="partial"
        )
        if config["human_message_prompt_template"] is not None:
            self.human_message_prompt_template = hydra.utils.instantiate(
                config["human_message_prompt_template"], _convert_="partial"
            )

        # ~~~ Instantiate response annotators ~~~
        # Work on a per-instance copy: writing into `self.response_annotators`
        # directly could mutate the class-level default dict, leaking
        # annotators from one flow instance into every other one.
        annotators = dict(self.response_annotators or {})
        annotator_configs = config["response_annotators"]
        if annotator_configs:
            for key, annotator_config in annotator_configs.items():
                annotators[key] = hydra.utils.instantiate(annotator_config, _convert_="partial")
        self.response_annotators = annotators

    def is_initialized(self):
        """Return the `conversation_initialized` entry of the flow state, or False if it is absent."""
        if "conversation_initialized" in self.flow_state:
            return self.flow_state["conversation_initialized"]
        return False

    def expected_inputs_given_state(self):
        """Return `["query"]` once the conversation is initialized, else `self.expected_inputs`."""
        if self.is_initialized():
            return ["query"]
        return self.expected_inputs

    @staticmethod
    def _get_message(prompt_template, input_data: Dict[str, Any]):
        """Render `prompt_template` using only the input variables it declares.

        Raises:
            KeyError: If `input_data` lacks one of the template's input variables.
        """
        template_kwargs = {
            input_variable: input_data[input_variable]
            for input_variable in prompt_template.input_variables
        }
        return prompt_template.format(**template_kwargs)

    def _get_demonstration_query_message_content(self, sample_data: Dict):
        """Return the rendered demonstration query and an (empty) list of parent message ids."""
        return self.query_message_prompt_template.format(**sample_data), []

    def _get_demonstration_response_message_content(self, sample_data: Dict):
        """Return the rendered demonstration response and an (empty) list of parent message ids."""
        return self.demonstrations_response_template.format(**sample_data), []

    def _get_annotator_with_key(self, key: str):
        """Return the first response annotator whose `key` equals `key`, or None if there is none."""
        for annotator in (self.response_annotators or {}).values():
            if annotator.key == key:
                return annotator
        return None

    def _response_parsing(self, response: str, expected_outputs: List[str]):
        """Apply every annotator whose `key` is in `expected_outputs` to `response`.

        Returns:
            A dict merging the annotators' outputs; later annotators overwrite
            duplicate keys produced by earlier ones.
        """
        parsed_outputs = {}
        for annotator in (self.response_annotators or {}).values():
            if annotator.key in expected_outputs:
                parsed_outputs.update(annotator(response))
        return parsed_outputs

    def _add_demonstrations(self):
        """Log each demonstration as a user message followed by an assistant message."""
        if self.demonstrations is None:
            return

        for example in self.demonstrations:
            query, query_parents = self._get_demonstration_query_message_content(example)
            response, response_parents = self._get_demonstration_response_message_content(example)

            # Each message is logged with its own parents; previously the
            # query's parents were overwritten by the response's.
            self._log_chat_message(content=query,
                                   message_creator=self.user_name,
                                   parent_message_ids=query_parents)

            self._log_chat_message(content=response,
                                   message_creator=self.assistant_name,
                                   parent_message_ids=response_parents)

    def _log_chat_message(self, message_creator: str, content: str, parent_message_ids: List[str] = None):
        """Wrap `content` in a `ChatMessage` and pass it to `self._log_message`.

        Returns:
            Whatever `self._log_message` returns (`run` reads `message_creator`,
            `message_id` and `flow_run_id` from it).
        """
        chat_message = ChatMessage(
            message_creator=message_creator,
            parent_message_ids=parent_message_ids,
            flow_runner=self.name,
            flow_run_id=self.flow_run_id,
            content=content
        )
        return self._log_message(chat_message)

    def _initialize_conversation(self, input_data: Dict[str, Any]):
        """Log the system message and the demonstrations, then mark the conversation as initialized."""
        # ~~~ Add the system message ~~~
        system_message_content = self._get_message(self.system_message_prompt_template, input_data)

        self._log_chat_message(content=system_message_content,
                               message_creator=self.system_name)

        # ~~~ Add the demonstration query-response tuples (if any) ~~~
        self._add_demonstrations()
        self._update_state(update_data={"conversation_initialized": True})

    def get_conversation_messages(self, message_format: Optional[str] = None):
        """Return the chat messages stored in the flow state's history.

        Args:
            message_format: None to get the messages as stored, or "open_ai"
                to get them converted to langchain `SystemMessage` /
                `AIMessage` / `HumanMessage` objects.

        Raises:
            AssertionError: If `message_format` is neither None nor "open_ai".
            ValueError: If a message's creator is none of `system_name`,
                `assistant_name` or `user_name`.
        """
        assert message_format is None or message_format in [
            "open_ai"
        ], f"Currently supported conversation message formats: 'open_ai'. '{message_format}' is not supported"

        messages = self.flow_state["history"].get_chat_messages()

        if message_format is None:
            return messages

        elif message_format == "open_ai":
            processed_messages = []

            for message in messages:
                if message.message_creator == self.system_name:
                    processed_messages.append(SystemMessage(content=message.content))
                elif message.message_creator == self.assistant_name:
                    processed_messages.append(AIMessage(content=message.content))
                elif message.message_creator == self.user_name:
                    processed_messages.append(HumanMessage(content=message.content))
                else:
                    raise ValueError(f"Unknown name: {message.message_creator}")
            return processed_messages
        else:
            # Only reachable when assertions are disabled (python -O).
            raise ValueError(f"Unknown message format: {message_format}")

    @staticmethod
    def _mask_api_key(api_key, visible_chars: int = 4) -> str:
        """Return a log-safe form of `api_key` that shows at most its last `visible_chars` characters."""
        key = "" if api_key is None else str(api_key)
        # Short keys are hidden entirely so the suffix cannot reveal most of the secret.
        if len(key) <= 2 * visible_chars:
            return "***"
        return f"***{key[-visible_chars:]}"

    def _call(self):
        """Send the conversation to the OpenAI chat backend and return the reply text.

        The request is attempted up to `n_api_retries` times (at least once),
        waiting `wait_time_between_retries` seconds between attempts.

        Raises:
            Exception: The exception of the last failed attempt, if all attempts fail.
        """
        api_key = self.flow_state["api_key"]

        backend = ChatOpenAI(
            model_name=self.model_name,
            openai_api_key=api_key,
            **self.generation_parameters,
        )

        messages = self.get_conversation_messages(
            message_format="open_ai"
        )

        # Always try at least once; otherwise there would be neither a
        # response to return nor an error to raise.
        max_attempts = max(1, self.n_api_retries)
        # Never write the full secret into the logs.
        masked_key = self._mask_api_key(api_key)

        response = None
        for attempt in range(1, max_attempts + 1):
            try:
                response = backend(messages).content
                break
            except Exception as e:
                is_last_attempt = attempt == max_attempts
                if is_last_attempt:
                    next_step = "No attempts left."
                else:
                    next_step = f"Retrying in {self.wait_time_between_retries} seconds..."

                log.error(
                    f"Error {attempt} in calling backend: {e}. Key used: `{masked_key}`. "
                    f"{next_step}"
                )
                log.error(
                    f"API call raised Exception with the following arguments: "
                    f"\n{self.flow_state['history'].to_string()}"
                )

                if is_last_attempt:
                    # Give up right away; sleeping here would only delay the failure.
                    raise
                time.sleep(self.wait_time_between_retries)

        if self.verbose:
            messages_str = self.flow_state["history"].to_string()
            log.info(
                f"\n{colorama.Fore.MAGENTA}~~~ History [{self.name}] ~~~\n"
                f"{colorama.Style.RESET_ALL}{messages_str}"
            )

        return response

    def _prepare_conversation(self, input_data: Dict[str, Any]):
        """Log the user message for this turn, initializing the conversation first if needed.

        Raises:
            KeyError: If the conversation is already initialized and
                `input_data` has no "query" entry.
        """
        if self.is_initialized():
            # ~~~ Follow-up turn: `input_data` must contain a `query` field ~~~
            user_message_content = self.human_message_prompt_template.format(query=input_data["query"])

        else:
            # ~~~ First turn: system message, demonstrations, then the query ~~~
            self._initialize_conversation(input_data)
            user_message_content = self._get_message(self.query_message_prompt_template, input_data)

        self._log_chat_message(message_creator=self.user_name,
                               content=user_message_content)

        # TODO: dry-run support (log the assembled messages and stop before calling the API).

    def run(self, input_data: Dict[str, Any], expected_outputs: List[str]) -> Dict[str, Any]:
        """Run one conversation turn and return the requested outputs.

        Args:
            input_data: Values used to render the prompt templates.
            expected_outputs: Keys to parse from the model's reply and to
                read back from the flow state.

        Returns:
            The result of `self._get_keys_from_state` for `expected_outputs`.
        """
        # ~~~ Chat-specific preparation ~~~
        self._prepare_conversation(input_data)

        # ~~~ Call ~~~
        response = self._call()
        answer_message = self._log_chat_message(
            message_creator=self.assistant_name,
            content=response
        )

        # ~~~ Response parsing ~~~
        parsed_outputs = self._response_parsing(
            response=response,
            expected_outputs=expected_outputs
        )
        self._update_state(update_data=parsed_outputs)

        if self.verbose:
            parsed_output_messages_str = pprint.pformat(dict(parsed_outputs), indent=4)
            log.info(
                f"\n{colorama.Fore.MAGENTA}~~~ "
                f"Response [{answer_message.message_creator} -- "
                f"{answer_message.message_id} -- "
                f"{answer_message.flow_run_id}] ~~~"
                f"\n{colorama.Fore.YELLOW}Content: {answer_message}{colorama.Style.RESET_ALL}"
                f"\n{colorama.Fore.YELLOW}Parsed Outputs: {parsed_output_messages_str}{colorama.Style.RESET_ALL}"
            )

        # ~~~ The final answer should be in self.flow_state, thus allow_class_namespace=False ~~~
        return self._get_keys_from_state(keys=expected_outputs, allow_class_namespace=False)
OpenAIChatAtomicFlow.yaml CHANGED
@@ -1 +1,14 @@
1
- # This is an abstract flow, therefore the default config is empty.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This is an abstract flow, therefore some required fields are missing (not defined)
2
+
3
+ n_api_retries: 6
4
+ wait_time_between_retries: 20
5
+
6
+ system_name: system
7
+ user_name: user
8
+ assistant_name: assistant
9
+
10
+ response_annotators: {}
11
+
12
+ query_message_prompt_template: null # ToDo: When will this be null?
13
+ demonstrations: null
14
+ demonstrations_response_template: null
__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .OpenAIChatAtomicFlow import OpenAIChatAtomicFlow
test_folder/my_file.yaml DELETED
@@ -1 +0,0 @@
1
- # test file