nbaldwin committed on
Commit
bdc9b47
1 Parent(s): 6a1e351

new backend

Files changed (3)
  1. OpenAIChatAtomicFlow.py +43 -57
  2. OpenAIChatAtomicFlow.yaml +21 -19
  3. run.py +15 -7
OpenAIChatAtomicFlow.py CHANGED
@@ -1,40 +1,42 @@
 from copy import deepcopy
-
 import hydra

 import time

 from typing import Dict, Optional, Any

-from langchain import PromptTemplate
-from langchain.schema import HumanMessage, AIMessage, SystemMessage
-
 from flows.base_flows import AtomicFlow
 from flows.datasets import GenericDemonstrationsDataset

 from flows.utils import logging
 from flows.messages.flow_message import UpdateMessage_ChatMessage

+from flows.prompt_template import JinjaPrompt
+
+from backends.llm_lite import LiteLLMBackend
+
 log = logging.get_logger(__name__)


 class OpenAIChatAtomicFlow(AtomicFlow):
-    REQUIRED_KEYS_CONFIG = ["model_name", "generation_parameters"]
+    REQUIRED_KEYS_CONFIG = ["backend"]

     SUPPORTS_CACHING: bool = True

-    system_message_prompt_template: PromptTemplate
-    human_message_prompt_template: PromptTemplate
+    system_message_prompt_template: JinjaPrompt
+    human_message_prompt_template: JinjaPrompt

-    init_human_message_prompt_template: Optional[PromptTemplate] = None
+    backend: LiteLLMBackend
+    init_human_message_prompt_template: Optional[JinjaPrompt] = None
     demonstrations: GenericDemonstrationsDataset = None
     demonstrations_k: Optional[int] = None
-    demonstrations_response_prompt_template: PromptTemplate = None
+    demonstrations_response_prompt_template: str = None

     def __init__(self,
                  system_message_prompt_template,
                  human_message_prompt_template,
                  init_human_message_prompt_template,
+                 backend,
                  demonstrations_response_prompt_template=None,
                  demonstrations=None,
                  **kwargs):
@@ -45,7 +47,7 @@ class OpenAIChatAtomicFlow(AtomicFlow):
         self.demonstrations_response_prompt_template = demonstrations_response_prompt_template
         self.demonstrations = demonstrations
         self.demonstrations_k = self.flow_config.get("demonstrations_k", None)
-
+        self.backend = backend
         assert self.flow_config["name"] not in [
             "system",
             "user",
@@ -59,20 +61,29 @@ class OpenAIChatAtomicFlow(AtomicFlow):
     @classmethod
     def _set_up_prompts(cls, config):
         kwargs = {}

         kwargs["system_message_prompt_template"] = \
             hydra.utils.instantiate(config['system_message_prompt_template'], _convert_="partial")
         kwargs["init_human_message_prompt_template"] = \
             hydra.utils.instantiate(config['init_human_message_prompt_template'], _convert_="partial")
         kwargs["human_message_prompt_template"] = \
             hydra.utils.instantiate(config['human_message_prompt_template'], _convert_="partial")

         if "demonstrations_response_prompt_template" in config:
             kwargs["demonstrations_response_prompt_template"] = \
                 hydra.utils.instantiate(config['demonstrations_response_prompt_template'], _convert_="partial")
             kwargs["demonstrations"] = GenericDemonstrationsDataset(**config['demonstrations'])

         return kwargs
+
+    @classmethod
+    def _set_up_backend(cls, config):
+        kwargs = {}
+
+        kwargs["backend"] = \
+            hydra.utils.instantiate(config['backend'], _convert_="partial")
+
+        return kwargs

     @classmethod
     def instantiate_from_config(cls, config):
@@ -82,6 +93,7 @@ class OpenAIChatAtomicFlow(AtomicFlow):

         # ~~~ Set up prompts ~~~
         kwargs.update(cls._set_up_prompts(flow_config))
+        kwargs.update(cls._set_up_backend(flow_config))

         # ~~~ Instantiate flow ~~~
         return cls(**kwargs)
@@ -106,7 +118,6 @@ class OpenAIChatAtomicFlow(AtomicFlow):
         template_kwargs = {}
         for input_variable in prompt_template.input_variables:
             template_kwargs[input_variable] = input_data[input_variable]
-
         msg_content = prompt_template.format(**template_kwargs)
         return msg_content

@@ -140,19 +151,16 @@ class OpenAIChatAtomicFlow(AtomicFlow):
                                          role: str,
                                          content: str) -> None:

-        # Add the message to the previous messages list
-        if role == self.flow_config["system_name"]:
-            self.flow_state["previous_messages"].append(SystemMessage(content=content))
-        elif role == self.flow_config["user_name"]:
-            self.flow_state["previous_messages"].append(HumanMessage(content=content))
-        elif role == self.flow_config["assistant_name"]:
-            self.flow_state["previous_messages"].append(AIMessage(content=content))
+
+        acceptable_roles = [self.flow_config["system_name"], self.flow_config["user_name"], self.flow_config["assistant_name"]]
+        if role in acceptable_roles:
+            self.flow_state["previous_messages"].append({"role": role, "content": content})
+
         else:
             raise Exception(f"Invalid role: `{role}`.\n"
                             f"Role should be one of: "
-                            f"`{self.flow_config['system_name']}`, "
-                            f"`{self.flow_config['user_name']}`, "
-                            f"`{self.flow_config['assistant_name']}`")
+                            f"`{acceptable_roles}`")
+

         # Log the update to the flow messages list
         chat_message = UpdateMessage_ChatMessage(
@@ -174,49 +182,24 @@ class OpenAIChatAtomicFlow(AtomicFlow):
             return all_messages[:first_k] + all_messages[-last_k:]
         elif first_k:
             return all_messages[:first_k]
-
         return all_messages[-last_k:]

     def _call(self):
-        api_information = self._get_from_state("api_information")
-        api_key = api_information.api_key
-
-        if api_information.backend_used == 'azure':
-            from backends.azure_openai import SafeAzureChatOpenAI
-            endpoint = api_information.endpoint
-            backend = SafeAzureChatOpenAI(
-                openai_api_type='azure',
-                openai_api_key=api_key,
-                openai_api_base=endpoint,
-                openai_api_version='2023-05-15',
-                deployment_name=self.flow_config["model_name"],
-                **self.flow_config["generation_parameters"],
-            )
-        elif api_information.backend_used == 'openai':
-            from backends.openai import SafeChatOpenAI
-            backend = SafeChatOpenAI(
-                model_name=self.flow_config["model_name"],
-                openai_api_key=api_key,
-                openai_api_type="open_ai",
-                **self.flow_config["generation_parameters"],
-            )
-        else:
-            raise ValueError(f"Unsupported backend: {api_information.backend_used}")
-
+
         messages = self._get_previous_messages()
-
         _success = False
         attempts = 1
         error = None
         response = None
         while attempts <= self.flow_config['n_api_retries']:
             try:
-                response = backend(messages).content
+                response = self.backend(messages=messages, mock_response=False)  # set mock_response to True when debugging (fake API request)
+                response = [answer["content"] for answer in response]  # because n in the generation parameters can be > 1
                 _success = True
                 break
             except Exception as e:
                 log.error(
-                    f"Error {attempts} in calling backend: {e}. Key used: `{api_key}`. "
+                    f"Error {attempts} in calling backend: {e}. "
                     f"Retrying in {self.flow_config['wait_time_between_retries']} seconds..."
                 )
                 # log.error(
@@ -226,7 +209,7 @@ class OpenAIChatAtomicFlow(AtomicFlow):
                 attempts += 1
                 time.sleep(self.flow_config['wait_time_between_retries'])
                 error = e

         if not _success:
             raise error

@@ -266,9 +249,12 @@ class OpenAIChatAtomicFlow(AtomicFlow):

         # ~~~ Call ~~~
         response = self._call()
-        self._state_update_add_chat_message(
-            role=self.flow_config["assistant_name"],
-            content=response
-        )
+
+        # loop in case there was more than one answer (n > 1 in generation parameters)
+        for answer in response:
+            self._state_update_add_chat_message(
+                role=self.flow_config["assistant_name"],
+                content=answer
+            )

         return {"api_output": response}
OpenAIChatAtomicFlow.yaml CHANGED
@@ -1,17 +1,6 @@
 # This is an abstract flow, therefore some required fields are not defined (and must be defined by the concrete flow)
 enable_cache: True

-model_name: "gpt-4"
-generation_parameters:
-  n: 1
-  max_tokens: 2000
-  temperature: 0.3
-
-  model_kwargs:
-    top_p: 0.2
-    frequency_penalty: 0
-    presence_penalty: 0
-
 n_api_retries: 6
 wait_time_between_retries: 20

@@ -19,26 +8,39 @@ system_name: system
 user_name: user
 assistant_name: assistant

+backend:
+  _target_: backends.llm_lite.LiteLLMBackend
+  api_infos: ???
+  model_name: "gpt-3.5-turbo"
+  n: 1
+  max_tokens: 2000
+  temperature: 0.3
+
+
+  top_p: 0.2
+  frequency_penalty: 0
+  presence_penalty: 0
+  stream: True
+
+
 system_message_prompt_template:
-  _target_: langchain.PromptTemplate
-  template_format: jinja2
+  _target_: flows.prompt_template.JinjaPrompt
+

 init_human_message_prompt_template:
-  _target_: langchain.PromptTemplate
-  template_format: jinja2
+  _target_: flows.prompt_template.JinjaPrompt

 human_message_prompt_template:
-  _target_: langchain.PromptTemplate
+  _target_: flows.prompt_template.JinjaPrompt
   template: "{{query}}"
   input_variables:
     - "query"
-  template_format: jinja2
   input_interface_initialized:
     - "query"

 query_message_prompt_template:
-  _target_: langchain.PromptTemplate
-  template_format: jinja2
+  _target_: flows.prompt_template.JinjaPrompt
+

 previous_messages:
   first_k: null # Note that the first message is the system prompt
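
`_set_up_backend` hands this `backend:` node to `hydra.utils.instantiate`, which imports the `_target_` class and passes every remaining key as a constructor argument; `api_infos: ???` is OmegaConf's marker for a mandatory missing value, so the flow cannot be instantiated until run.py fills it in. A small sketch of that mechanism, using `builtins.dict` as a stand-in target so the snippet runs without the `backends` package:

# Sketch of the Hydra instantiation used by _set_up_backend; builtins.dict is a
# stand-in _target_ (the real config uses backends.llm_lite.LiteLLMBackend).
import hydra
from omegaconf import OmegaConf

config = OmegaConf.create("""
backend:
  _target_: builtins.dict
  model_name: gpt-3.5-turbo
  n: 1
  max_tokens: 2000
""")

# Every key except _target_ becomes a keyword argument of the target's constructor.
backend = hydra.utils.instantiate(config["backend"], _convert_="partial")
print(backend)  # {'model_name': 'gpt-3.5-turbo', 'n': 1, 'max_tokens': 2000}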
run.py CHANGED
@@ -3,7 +3,8 @@ import os
 import hydra

 import flows
-from flows.flow_launchers import FlowLauncher, ApiInfo
+from flows.flow_launchers import FlowLauncher
+from backends.api_info import ApiInfo
 from flows.utils.general_helpers import read_yaml_file

 from flows import logging
@@ -23,25 +24,33 @@ flow_verse.sync_dependencies(dependencies)
 if __name__ == "__main__":
     # ~~~ Set the API information ~~~
     # OpenAI backend
-    # api_information = ApiInfo("openai", os.getenv("OPENAI_API_KEY"))
-    # Azure backend
-    api_information = ApiInfo("azure", os.getenv("AZURE_OPENAI_KEY"), os.getenv("AZURE_OPENAI_ENDPOINT"))
+
+    api_information = [ApiInfo(backend_used="openai",
+                               api_key=os.getenv("OPENAI_API_KEY"))]
+
+    # # Azure backend
+    # api_information = ApiInfo(backend_used="azure",
+    #                           api_base=os.getenv("AZURE_API_BASE"),
+    #                           api_key=os.getenv("AZURE_OPENAI_KEY"),
+    #                           api_version=os.getenv("AZURE_API_VERSION"))

     root_dir = "."
     cfg_path = os.path.join(root_dir, "SimpleQA.yaml")
     cfg = read_yaml_file(cfg_path)

+    cfg["flow"]["backend"]["api_infos"] = api_information
+    # ~~~ Instantiate the Flow ~~~
     # ~~~ Instantiate the Flow ~~~
     flow_with_interfaces = {
         "flow": hydra.utils.instantiate(cfg['flow'], _recursive_=False, _convert_="partial"),
         "input_interface": (
             None
-            if getattr(cfg, "input_interface", None) is None
+            if cfg.get("input_interface", None) is None
             else hydra.utils.instantiate(cfg['input_interface'], _recursive_=False)
         ),
         "output_interface": (
             None
-            if getattr(cfg, "output_interface", None) is None
+            if cfg.get("output_interface", None) is None
             else hydra.utils.instantiate(cfg['output_interface'], _recursive_=False)
         ),
     }
@@ -58,7 +67,6 @@ if __name__ == "__main__":
         flow_with_interfaces=flow_with_interfaces,
         data=data,
         path_to_output_file=path_to_output_file,
-        api_information=api_information,
    )

     # ~~~ Print the output ~~~
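
Since `api_infos` is left as `???` in the YAML, run.py must inject the real API credentials into the loaded config before `hydra.utils.instantiate` builds the flow, instead of passing them to `FlowLauncher.launch` as before. A self-contained sketch of that injection step; the `ApiInfo` dataclass below is a hypothetical stand-in for `backends.api_info.ApiInfo`, with field names taken from the diff above:

# Stand-in ApiInfo (hypothetical); field names follow the new run.py above.
import os
from dataclasses import dataclass
from typing import Optional

@dataclass
class ApiInfo:
    backend_used: str
    api_key: Optional[str] = None
    api_base: Optional[str] = None
    api_version: Optional[str] = None

# Toy stand-in for the SimpleQA.yaml config tree loaded by read_yaml_file.
cfg = {"flow": {"backend": {"api_infos": "???"}}}

# A list, presumably so the backend can rotate across several keys or providers.
api_information = [ApiInfo(backend_used="openai", api_key=os.getenv("OPENAI_API_KEY"))]
cfg["flow"]["backend"]["api_infos"] = api_information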