martinjosifoski committed on
Commit 4f4d036 · 1 Parent(s): 8076214

First commit.

Files changed (7)
  1. OpenAIChatAtomicFlow.py +274 -0
  2. OpenAIChatAtomicFlow.yaml +51 -0
  3. README.md +23 -0
  4. __init__.py +1 -0
  5. pip_requirements.py +1 -0
  6. run.py +66 -0
  7. simpleQA.yaml +51 -0
OpenAIChatAtomicFlow.py ADDED
@@ -0,0 +1,274 @@
+ import time
+ from copy import deepcopy
+ from typing import Any, Dict, Optional
+
+ import hydra
+ from langchain import PromptTemplate
+ from langchain.schema import AIMessage, HumanMessage, SystemMessage
+
+ from flows.base_flows import AtomicFlow
+ from flows.datasets import GenericDemonstrationsDataset
+ from flows.messages.flow_message import UpdateMessage_ChatMessage
+ from flows.utils import logging
+
+ log = logging.get_logger(__name__)
+
+
+ class OpenAIChatAtomicFlow(AtomicFlow):
+     REQUIRED_KEYS_CONFIG = ["model_name", "generation_parameters"]
+
+     SUPPORTS_CACHING: bool = True
+
+     system_message_prompt_template: PromptTemplate
+     human_message_prompt_template: PromptTemplate
+
+     init_human_message_prompt_template: Optional[PromptTemplate] = None
+     demonstrations: Optional[GenericDemonstrationsDataset] = None
+     demonstrations_k: Optional[int] = None
+     demonstrations_response_prompt_template: Optional[PromptTemplate] = None
+
+     def __init__(self,
+                  system_message_prompt_template,
+                  human_message_prompt_template,
+                  init_human_message_prompt_template,
+                  demonstrations_response_prompt_template=None,
+                  demonstrations=None,
+                  **kwargs):
+         super().__init__(**kwargs)
+         self.system_message_prompt_template = system_message_prompt_template
+         self.human_message_prompt_template = human_message_prompt_template
+         self.init_human_message_prompt_template = init_human_message_prompt_template
+         self.demonstrations_response_prompt_template = demonstrations_response_prompt_template
+         self.demonstrations = demonstrations
+         self.demonstrations_k = self.flow_config.get("demonstrations_k", None)
+
+         # The flow name doubles as the `created_by` identity in chat messages, so it must
+         # not collide with the reserved chat roles
+         assert self.flow_config["name"] not in [
+             "system",
+             "user",
+             "assistant",
+         ], f"Flow name '{self.flow_config['name']}' cannot be 'system', 'user' or 'assistant'"
+
+     def set_up_flow_state(self):
+         super().set_up_flow_state()
+         self.flow_state["previous_messages"] = []
+
+     @classmethod
+     def _set_up_prompts(cls, config):
+         kwargs = {}
+
+         kwargs["system_message_prompt_template"] = \
+             hydra.utils.instantiate(config['system_message_prompt_template'], _convert_="partial")
+         kwargs["init_human_message_prompt_template"] = \
+             hydra.utils.instantiate(config['init_human_message_prompt_template'], _convert_="partial")
+         kwargs["human_message_prompt_template"] = \
+             hydra.utils.instantiate(config['human_message_prompt_template'], _convert_="partial")
+
+         # Only set up demonstrations when a response template is actually provided
+         # (the default config defines the key with a null value)
+         if config.get("demonstrations_response_prompt_template", None) is not None:
+             kwargs["demonstrations_response_prompt_template"] = \
+                 hydra.utils.instantiate(config['demonstrations_response_prompt_template'], _convert_="partial")
+             kwargs["demonstrations"] = GenericDemonstrationsDataset(**config['demonstrations'])
+
+         return kwargs
+
+     @classmethod
+     def instantiate_from_config(cls, config):
+         flow_config = deepcopy(config)
+
+         kwargs = {"flow_config": flow_config}
+
+         # ~~~ Set up prompts ~~~
+         kwargs.update(cls._set_up_prompts(flow_config))
+
+         # ~~~ Instantiate flow ~~~
+         return cls(**kwargs)
+
+     def _is_conversation_initialized(self):
+         return len(self.flow_state["previous_messages"]) > 0
+
+     def get_interface_description(self):
+         if self._is_conversation_initialized():
+             return {"input": self.flow_config["input_interface_initialized"],
+                     "output": self.flow_config["output_interface"]}
+         else:
+             return {"input": self.flow_config["input_interface_non_initialized"],
+                     "output": self.flow_config["output_interface"]}
+
+     @staticmethod
+     def _get_message(prompt_template, input_data: Dict[str, Any]):
+         template_kwargs = {}
+         for input_variable in prompt_template.input_variables:
+             template_kwargs[input_variable] = input_data[input_variable]
+
+         msg_content = prompt_template.format(**template_kwargs)
+         return msg_content
+
+     def _get_demonstration_query_message_content(self, sample_data: Dict):
+         input_variables = self.init_human_message_prompt_template.input_variables
+         return self.init_human_message_prompt_template.format(**{k: sample_data[k] for k in input_variables})
+
+     def _get_demonstration_response_message_content(self, sample_data: Dict):
+         input_variables = self.demonstrations_response_prompt_template.input_variables
+         return self.demonstrations_response_prompt_template.format(**{k: sample_data[k] for k in input_variables})
+
+     def _add_demonstrations(self):
+         if self.demonstrations is not None:
+             c = 0
+             for example in self.demonstrations:
+                 # Stop once demonstrations_k examples have been added (if a limit is set)
+                 if self.demonstrations_k is not None and c >= self.demonstrations_k:
+                     break
+                 c += 1
+                 query = self._get_demonstration_query_message_content(example)
+                 response = self._get_demonstration_response_message_content(example)
+
+                 self._state_update_add_chat_message(content=query,
+                                                     role=self.flow_config["user_name"])
+
+                 self._state_update_add_chat_message(content=response,
+                                                     role=self.flow_config["assistant_name"])
+
+     def _state_update_add_chat_message(self,
+                                        role: str,
+                                        content: str) -> None:
+
+         # Add the message to the previous messages list
+         if role == self.flow_config["system_name"]:
+             self.flow_state["previous_messages"].append(SystemMessage(content=content))
+         elif role == self.flow_config["user_name"]:
+             self.flow_state["previous_messages"].append(HumanMessage(content=content))
+         elif role == self.flow_config["assistant_name"]:
+             self.flow_state["previous_messages"].append(AIMessage(content=content))
+         else:
+             raise ValueError(f"Invalid role: `{role}`.\n"
+                              f"Role should be one of: "
+                              f"`{self.flow_config['system_name']}`, "
+                              f"`{self.flow_config['user_name']}`, "
+                              f"`{self.flow_config['assistant_name']}`")
+
+         # Log the update to the flow messages list
+         chat_message = UpdateMessage_ChatMessage(
+             created_by=self.flow_config["name"],
+             updated_flow=self.flow_config["name"],
+             role=role,
+             content=content,
+         )
+         self._log_message(chat_message)
+
+     def _get_previous_messages(self):
+         all_messages = self.flow_state["previous_messages"]
+         first_k = self.flow_config["previous_messages"]["first_k"]
+         last_k = self.flow_config["previous_messages"]["last_k"]
+
+         if not first_k and not last_k:
+             return all_messages
+         elif first_k and last_k:
+             return all_messages[:first_k] + all_messages[-last_k:]
+         elif first_k:
+             return all_messages[:first_k]
+
+         return all_messages[-last_k:]
+
+     def _call(self):
+         api_information = self._get_from_state("api_information")
+         api_key = api_information.api_key
+
+         if api_information.backend_used == 'azure':
+             from backends.azure_openai import SafeAzureChatOpenAI
+             endpoint = api_information.endpoint
+             backend = SafeAzureChatOpenAI(
+                 openai_api_type='azure',
+                 openai_api_key=api_key,
+                 openai_api_base=endpoint,
+                 openai_api_version='2023-05-15',
+                 deployment_name=self.flow_config["model_name"],
+                 **self.flow_config["generation_parameters"],
+             )
+         elif api_information.backend_used == 'openai':
+             from backends.openai import SafeChatOpenAI
+             backend = SafeChatOpenAI(
+                 model_name=self.flow_config["model_name"],
+                 openai_api_key=api_key,
+                 openai_api_type="open_ai",
+                 **self.flow_config["generation_parameters"],
+             )
+         else:
+             raise ValueError(f"Unsupported backend: {api_information.backend_used}")
+
+         messages = self._get_previous_messages()
+
+         # Retry the API call up to n_api_retries times, waiting between attempts
+         _success = False
+         attempts = 1
+         error = None
+         response = None
+         while attempts <= self.flow_config['n_api_retries']:
+             try:
+                 response = backend(messages).content
+                 _success = True
+                 break
+             except Exception as e:
+                 log.error(
+                     f"Error {attempts} in calling backend: {e}. Key used: `{api_key}`. "
+                     f"Retrying in {self.flow_config['wait_time_between_retries']} seconds..."
+                 )
+                 # log.error(
+                 #     f"The API call raised an exception with the following arguments: "
+                 #     f"\n{self.flow_state['history'].to_string()}"
+                 # )  # ToDo: Make this message more user-friendly
+                 attempts += 1
+                 time.sleep(self.flow_config['wait_time_between_retries'])
+                 error = e
+
+         if not _success:
+             raise error
+
+         return response
+
+     def _initialize_conversation(self, input_data: Dict[str, Any]):
+         # ~~~ Add the system message ~~~
+         system_message_content = self._get_message(self.system_message_prompt_template, input_data)
+
+         self._state_update_add_chat_message(content=system_message_content,
+                                             role=self.flow_config["system_name"])
+
+         # ~~~ Add the demonstration query-response tuples (if any) ~~~
+         self._add_demonstrations()
+
+     def _process_input(self, input_data: Dict[str, Any]):
+         if self._is_conversation_initialized():
+             # Construct the message using the human message prompt template
+             user_message_content = self._get_message(self.human_message_prompt_template, input_data)
+         else:
+             # Initialize the conversation (add the system message, and potentially the demonstrations)
+             self._initialize_conversation(input_data)
+             if getattr(self, "init_human_message_prompt_template", None) is not None:
+                 # Construct the message using the query message prompt template
+                 user_message_content = self._get_message(self.init_human_message_prompt_template, input_data)
+             else:
+                 user_message_content = self._get_message(self.human_message_prompt_template, input_data)
+
+         self._state_update_add_chat_message(role=self.flow_config["user_name"],
+                                             content=user_message_content)
+
+     def run(self,
+             input_data: Dict[str, Any]) -> Dict[str, Any]:
+         # ~~~ Process input ~~~
+         self._process_input(input_data)
+
+         # ~~~ Call ~~~
+         response = self._call()
+         self._state_update_add_chat_message(
+             role=self.flow_config["assistant_name"],
+             content=response
+         )
+
+         return {"api_output": response}
OpenAIChatAtomicFlow.yaml ADDED
@@ -0,0 +1,51 @@
+ # This is an abstract flow, therefore some required fields are not defined (and must be defined by the concrete flow)
+ enable_cache: True
+
+ model_name: "gpt-4"
+ generation_parameters:
+   n: 1
+   max_tokens: 2000
+   temperature: 0.3
+
+   model_kwargs:
+     top_p: 0.2
+     frequency_penalty: 0
+     presence_penalty: 0
+
+ n_api_retries: 6
+ wait_time_between_retries: 20
+
+ system_name: system
+ user_name: user
+ assistant_name: assistant
+
+ system_message_prompt_template:
+   _target_: langchain.PromptTemplate
+   template_format: jinja2
+
+ init_human_message_prompt_template:
+   _target_: langchain.PromptTemplate
+   template_format: jinja2
+
+ human_message_prompt_template:
+   _target_: langchain.PromptTemplate
+   template: "{{query}}"
+   input_variables:
+     - "query"
+   template_format: jinja2
+ input_interface_initialized:
+   - "query"
+
+ query_message_prompt_template:
+   _target_: langchain.PromptTemplate
+   template_format: jinja2
+
+ previous_messages:
+   first_k: null  # Note that the first message is the system prompt
+   last_k: null
+
+ demonstrations: null
+ demonstrations_response_prompt_template: null
+
+ output_interface:
+   - "api_output"
README.md CHANGED
@@ -1,3 +1,26 @@
  ---
  license: mit
  ---
+ (TODO)
+
+ ## Description
+
+ < Flow description >
+
+ ## Configuration parameters
+
+ < Name 1 > (< Type 1 >): < Description 1 >. Required parameter.
+
+ < Name 2 > (< Type 2 >): < Description 2 >. Default value is: < value 2 >
+
+ ## Input interface
+
+ < Name 1 > (< Type 1 >): < Description 1 >.
+
+ (Note that the interface might depend on the state of the Flow.)
+
+ ## Output interface
+
+ < Name 1 > (< Type 1 >): < Description 1 >.
+
+ (Note that the interface might depend on the state of the Flow.)
__init__.py ADDED
@@ -0,0 +1 @@
+ from .OpenAIChatAtomicFlow import OpenAIChatAtomicFlow
pip_requirements.py ADDED
@@ -0,0 +1 @@
+ # ToDo
run.py ADDED
@@ -0,0 +1,66 @@
+ import os
+
+ import hydra
+
+ import flows
+ from flows.flow_launchers import FlowLauncher, ApiInfo
+ from flows.utils.general_helpers import read_yaml_file
+
+ from flows import logging
+ from flows.flow_cache import CACHING_PARAMETERS, clear_cache
+
+ CACHING_PARAMETERS.do_caching = False  # Set to False in order to disable caching
+ # clear_cache()  # Uncomment this line to clear the cache
+
+ logging.set_verbosity_debug()
+
+ dependencies = [
+     {"url": "aiflows/OpenAIChatAtomicFlowModule", "revision": os.getcwd()},
+ ]
+ from flows import flow_verse
+ flow_verse.sync_dependencies(dependencies)
+
+ if __name__ == "__main__":
+     # ~~~ Set the API information ~~~
+     # OpenAI backend
+     # api_information = ApiInfo("openai", os.getenv("OPENAI_API_KEY"))
+     # Azure backend
+     api_information = ApiInfo("azure", os.getenv("AZURE_OPENAI_KEY"), os.getenv("AZURE_OPENAI_ENDPOINT"))
+
+     root_dir = "."
+     cfg_path = os.path.join(root_dir, "simpleQA.yaml")
+     cfg = read_yaml_file(cfg_path)
+
+     # ~~~ Instantiate the Flow ~~~
+     flow_with_interfaces = {
+         "flow": hydra.utils.instantiate(cfg['flow'], _recursive_=False, _convert_="partial"),
+         "input_interface": (
+             None
+             if cfg.get("input_interface", None) is None
+             else hydra.utils.instantiate(cfg['input_interface'], _recursive_=False)
+         ),
+         "output_interface": (
+             None
+             if cfg.get("output_interface", None) is None
+             else hydra.utils.instantiate(cfg['output_interface'], _recursive_=False)
+         ),
+     }
+
+     # ~~~ Get the data ~~~
+     data = {"id": 0, "question": "What is the capital of France?"}  # This can be a list of samples
+     # data = {"id": 0, "question": "Who was the NBA champion in 2023?"}  # This can be a list of samples
+
+     # ~~~ Run inference ~~~
+     path_to_output_file = None
+     # path_to_output_file = "output.jsonl"  # Uncomment this line to save the output to disk
+
+     _, outputs = FlowLauncher.launch(
+         flow_with_interfaces=flow_with_interfaces,
+         data=data,
+         path_to_output_file=path_to_output_file,
+         api_information=api_information,
+     )
+
+     # ~~~ Print the output ~~~
+     flow_output_data = outputs[0]
+     print(flow_output_data)
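
The script reads the Azure credentials from environment variables; a minimal, hypothetical pre-flight check (variable names taken from `run.py` above) that one might run before launching:

```python
import os

# Fail fast if the Azure backend credentials are missing (names from run.py).
for var in ("AZURE_OPENAI_KEY", "AZURE_OPENAI_ENDPOINT"):
    if not os.getenv(var):
        raise EnvironmentError(f"Set {var} before running run.py with the Azure backend.")
```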
simpleQA.yaml ADDED
@@ -0,0 +1,51 @@
+ input_interface:  # Connector between the "input data" and the Flow
+   _target_: flows.interfaces.KeyInterface
+   additional_transformations:
+     - _target_: flows.data_transformations.KeyMatchInput  # Pass the input parameters specified by the flow
+
+ output_interface:  # Connector between the Flow's output and the caller
+   _target_: flows.interfaces.KeyInterface
+   keys_to_rename:
+     api_output: answer  # Rename the api_output key to answer
+
+ flow:  # Overrides the OpenAIChatAtomicFlow config
+   _target_: aiflows.OpenAIChatAtomicFlowModule.OpenAIChatAtomicFlow.instantiate_from_default_config
+
+   name: "SimpleQA_Flow"
+   description: "A flow that answers questions."
+
+   # ~~~ Input interface specification ~~~
+   input_interface_non_initialized:
+     - "question"
+
+   # ~~~ OpenAI model parameters ~~~
+   model_name: "gpt-3.5-turbo"
+   generation_parameters:
+     n: 1
+     max_tokens: 3000
+     temperature: 0.3
+
+     model_kwargs:
+       top_p: 0.2
+       frequency_penalty: 0
+       presence_penalty: 0
+
+   n_api_retries: 6
+   wait_time_between_retries: 20
+
+   # ~~~ Prompt specification ~~~
+   system_message_prompt_template:
+     _target_: langchain.PromptTemplate
+     template: |2-
+       You are a helpful chatbot that truthfully answers questions.
+     input_variables: []
+     partial_variables: {}
+     template_format: jinja2
+
+   init_human_message_prompt_template:
+     _target_: langchain.PromptTemplate
+     template: |2-
+       Answer the following question: {{question}}
+     input_variables: ["question"]
+     partial_variables: {}
+     template_format: jinja2
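
The `output_interface` above renames the flow's `api_output` key to `answer` before the result reaches the caller. A sketch of the renaming semantics with a hypothetical flow output (an illustration of the config's effect, not the actual `flows.interfaces.KeyInterface` implementation):

```python
# Hypothetical flow output and the keys_to_rename mapping from simpleQA.yaml.
flow_output = {"api_output": "The capital of France is Paris."}
keys_to_rename = {"api_output": "answer"}

renamed = {keys_to_rename.get(key, key): value for key, value in flow_output.items()}
print(renamed)  # {'answer': 'The capital of France is Paris.'}
```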