nbaldwin commited on
Commit
15f0db2
1 Parent(s): 297c90d

coflows update compatible

Browse files
Files changed (3) hide show
  1. ChatAtomicFlow.py +10 -7
  2. demo.yaml +45 -55
  3. run.py +74 -33
ChatAtomicFlow.py CHANGED
@@ -14,6 +14,8 @@ from aiflows.prompt_template import JinjaPrompt
14
 
15
  from aiflows.backends.llm_lite import LiteLLMBackend
16
 
 
 
17
  log = logging.get_logger(__name__)
18
 
19
 
@@ -356,15 +358,13 @@ class ChatAtomicFlow(AtomicFlow):
356
  self._state_update_add_chat_message(role=self.flow_config["user_name"],
357
  content=user_message_content)
358
 
359
- def run(self,input_data: Dict[str, Any]):
360
  """ This method runs the flow. It processes the input, calls the backend and updates the state of the flow.
361
 
362
- :param input_data: The input data of the flow.
363
- :type input_data: Dict[str, Any]
364
- :return: The LLM's api output.
365
- :rtype: Dict[str, Any]
366
  """
367
-
368
  # ~~~ Process input ~~~
369
  self._process_input(input_data)
370
 
@@ -378,4 +378,7 @@ class ChatAtomicFlow(AtomicFlow):
378
  content=answer
379
  )
380
  response = response if len(response) > 1 or len(response) == 0 else response[0]
381
- return {"api_output": response}
 
 
 
 
14
 
15
  from aiflows.backends.llm_lite import LiteLLMBackend
16
 
17
+ from aiflows.messages import FlowMessage
18
+
19
  log = logging.get_logger(__name__)
20
 
21
 
 
358
  self._state_update_add_chat_message(role=self.flow_config["user_name"],
359
  content=user_message_content)
360
 
361
+ def run(self,input_message: FlowMessage):
362
  """ This method runs the flow. It processes the input, calls the backend and updates the state of the flow.
363
 
364
+ :param input_message: The input data of the flow.
365
+ :type input_message: aiflows.messages.FlowMessage
 
 
366
  """
367
+ input_data = input_message.data
368
  # ~~~ Process input ~~~
369
  self._process_input(input_data)
370
 
 
378
  content=answer
379
  )
380
  response = response if len(response) > 1 or len(response) == 0 else response[0]
381
+
382
+ reply_message = self._package_output_message(input_message, response = {"api_output": response})
383
+
384
+ self.reply_to_message(reply = reply_message, to = input_message)
demo.yaml CHANGED
@@ -1,56 +1,46 @@
1
- input_interface: # Connector between the "input data" and the Flow
2
- _target_: aiflows.interfaces.KeyInterface
3
- additional_transformations:
4
- - _target_: aiflows.data_transformations.KeyMatchInput # Pass the input parameters specified by the flow
5
-
6
- output_interface: # Connector between the Flow's output and the caller
7
- _target_: aiflows.interfaces.KeyInterface
8
- keys_to_rename:
9
- api_output: answer # Rename the api_output to answer
10
-
11
- flow: # Overrides the ChatAtomicFlow config
12
- _target_: flow_modules.aiflows.ChatFlowModule.ChatAtomicFlow.instantiate_from_default_config
13
-
14
- name: "SimpleQA_Flow"
15
- description: "A flow that answers questions."
16
-
17
- # ~~~ Input interface specification ~~~
18
- input_interface_non_initialized:
19
- - "question"
20
-
21
- # ~~~ backend model parameters ~~
22
- backend:
23
- _target_: aiflows.backends.llm_lite.LiteLLMBackend
24
- api_infos: ???
25
- model_name:
26
- openai: "gpt-3.5-turbo"
27
- azure: "azure/gpt-4"
28
-
29
- # ~~~ generation_parameters ~~
30
- n: 1
31
- max_tokens: 3000
32
- temperature: 0.3
33
-
34
- top_p: 0.2
35
- frequency_penalty: 0
36
- presence_penalty: 0
37
-
38
- n_api_retries: 6
39
- wait_time_between_retries: 20
40
-
41
- # ~~~ Prompt specification ~~~
42
- system_message_prompt_template:
43
- _target_: aiflows.prompt_template.JinjaPrompt
44
- template: |2-
45
- You are a helpful chatbot that truthfully answers questions.
46
- input_variables: []
47
- partial_variables: {}
48
-
49
-
50
- init_human_message_prompt_template:
51
- _target_: aiflows.prompt_template.JinjaPrompt
52
- template: |2-
53
- Answer the following question: {{question}}
54
- input_variables: ["question"]
55
- partial_variables: {}
56
 
 
1
+
2
+ _target_: flow_modules.aiflows.ChatFlowModule.ChatAtomicFlow.instantiate_from_default_config
3
+
4
+ name: "SimpleQA_Flow"
5
+ description: "A flow that answers questions."
6
+
7
+ # ~~~ Input interface specification ~~~
8
+ input_interface_non_initialized:
9
+ - "question"
10
+
11
+ # ~~~ backend model parameters ~~
12
+ backend:
13
+ _target_: aiflows.backends.llm_lite.LiteLLMBackend
14
+ api_infos: ???
15
+ model_name:
16
+ openai: "gpt-3.5-turbo"
17
+ azure: "azure/gpt-4"
18
+
19
+ # ~~~ generation_parameters ~~
20
+ n: 1
21
+ max_tokens: 3000
22
+ temperature: 0.3
23
+
24
+ top_p: 0.2
25
+ frequency_penalty: 0
26
+ presence_penalty: 0
27
+
28
+ n_api_retries: 6
29
+ wait_time_between_retries: 20
30
+
31
+ # ~~~ Prompt specification ~~~
32
+ system_message_prompt_template:
33
+ _target_: aiflows.prompt_template.JinjaPrompt
34
+ template: |2-
35
+ You are a helpful chatbot that truthfully answers questions.
36
+ input_variables: []
37
+ partial_variables: {}
38
+
39
+
40
+ init_human_message_prompt_template:
41
+ _target_: aiflows.prompt_template.JinjaPrompt
42
+ template: |2-
43
+ Answer the following question: {{question}}
44
+ input_variables: ["question"]
45
+ partial_variables: {}
 
 
 
 
 
 
 
 
 
 
46
 
run.py CHANGED
@@ -5,11 +5,18 @@ import hydra
5
  import aiflows
6
  from aiflows.flow_launchers import FlowLauncher
7
  from aiflows.backends.api_info import ApiInfo
8
- from aiflows.utils.general_helpers import read_yaml_file
9
 
10
  from aiflows import logging
11
  from aiflows.flow_cache import CACHING_PARAMETERS, clear_cache
12
 
 
 
 
 
 
 
 
13
  CACHING_PARAMETERS.do_caching = False # Set to True in order to disable caching
14
  # clear_cache() # Uncomment this line to clear the cache
15
 
@@ -22,52 +29,86 @@ from aiflows import flow_verse
22
  flow_verse.sync_dependencies(dependencies)
23
 
24
  if __name__ == "__main__":
25
- # ~~~ Set the API information ~~~
26
- # OpenAI backend
 
 
 
 
27
 
 
 
 
 
 
 
 
28
  api_information = [ApiInfo(backend_used="openai",
29
  api_key = os.getenv("OPENAI_API_KEY"))]
30
-
31
  # # Azure backend
32
  # api_information = ApiInfo(backend_used = "azure",
33
  # api_base = os.getenv("AZURE_API_BASE"),
34
  # api_key = os.getenv("AZURE_OPENAI_KEY"),
35
  # api_version = os.getenv("AZURE_API_VERSION") )
 
 
36
 
37
- root_dir = "."
38
- cfg_path = os.path.join(root_dir, "demo.yaml")
39
- cfg = read_yaml_file(cfg_path)
 
 
 
 
 
 
 
 
 
40
 
41
- cfg["flow"]["backend"]["api_infos"] = api_information
42
- # ~~~ Instantiate the Flow ~~~
43
- flow_with_interfaces = {
44
- "flow": hydra.utils.instantiate(cfg['flow'], _recursive_=False, _convert_="partial"),
45
- "input_interface": (
46
- None
47
- if cfg.get( "input_interface", None) is None
48
- else hydra.utils.instantiate(cfg['input_interface'], _recursive_=False)
49
- ),
50
- "output_interface": (
51
- None
52
- if cfg.get( "output_interface", None) is None
53
- else hydra.utils.instantiate(cfg['output_interface'], _recursive_=False)
54
- ),
55
- }
56
 
57
- # ~~~ Get the data ~~~
58
  data = {"id": 0, "question": "What is the capital of France?"} # This can be a list of samples
59
  # data = {"id": 0, "question": "Who was the NBA champion in 2023?"} # This can be a list of samples
60
-
61
- # ~~~ Run inference ~~~
62
- path_to_output_file = None
63
- # path_to_output_file = "output.jsonl" # Uncomment this line to save the output to disk
64
-
65
- _, outputs = FlowLauncher.launch(
66
- flow_with_interfaces=flow_with_interfaces,
67
  data=data,
68
- path_to_output_file=path_to_output_file,
69
  )
70
 
 
 
 
 
 
 
 
 
 
 
71
  # ~~~ Print the output ~~~
72
- flow_output_data = outputs[0]
73
- print(flow_output_data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import aiflows
6
  from aiflows.flow_launchers import FlowLauncher
7
  from aiflows.backends.api_info import ApiInfo
8
+ from aiflows.utils.general_helpers import read_yaml_file, quick_load_api_keys
9
 
10
  from aiflows import logging
11
  from aiflows.flow_cache import CACHING_PARAMETERS, clear_cache
12
 
13
+ from aiflows.utils import serve_utils
14
+ from aiflows.workers import run_dispatch_worker_thread
15
+ from aiflows.messages import FlowMessage
16
+ from aiflows.interfaces import KeyInterface
17
+ from aiflows.utils.colink_utils import start_colink_server
18
+ from aiflows.workers import run_dispatch_worker_thread
19
+
20
  CACHING_PARAMETERS.do_caching = False # Set to True in order to disable caching
21
  # clear_cache() # Uncomment this line to clear the cache
22
 
 
29
  flow_verse.sync_dependencies(dependencies)
30
 
31
  if __name__ == "__main__":
32
+
33
+ #1. ~~~~~ Set up a colink server ~~~~
34
+ FLOW_MODULES_PATH = "./"
35
+
36
+ cl = start_colink_server()
37
+
38
 
39
+ #2. ~~~~~Load flow config~~~~~~
40
+ root_dir = "."
41
+ cfg_path = os.path.join(root_dir, "demo.yaml")
42
+ cfg = read_yaml_file(cfg_path)
43
+
44
+ #2.1 ~~~ Set the API information ~~~
45
+ # OpenAI backend
46
  api_information = [ApiInfo(backend_used="openai",
47
  api_key = os.getenv("OPENAI_API_KEY"))]
 
48
  # # Azure backend
49
  # api_information = ApiInfo(backend_used = "azure",
50
  # api_base = os.getenv("AZURE_API_BASE"),
51
  # api_key = os.getenv("AZURE_OPENAI_KEY"),
52
  # api_version = os.getenv("AZURE_API_VERSION") )
53
+
54
+ quick_load_api_keys(cfg, api_information, key="api_infos")
55
 
56
+
57
+ #3. ~~~~ Serve The Flow ~~~~
58
+ serve_utils.serve_flow(
59
+ cl = cl,
60
+ flow_type="ChatFlowModule",
61
+ default_config=cfg,
62
+ default_state=None,
63
+ default_dispatch_point="coflows_dispatch"
64
+ )
65
+
66
+ #4. ~~~~~Start A Worker Thread~~~~~
67
+ run_dispatch_worker_thread(cl, dispatch_point="coflows_dispatch", flow_modules_base_path=FLOW_MODULES_PATH)
68
 
69
+ #5. ~~~~~Mount the flow and get its proxy~~~~~~
70
+ proxy_flow = serve_utils.recursive_mount(
71
+ cl=cl,
72
+ client_id="local",
73
+ flow_type="ChatFlowModule",
74
+ config_overrides=None,
75
+ initial_state=None,
76
+ dispatch_point_override=None,
77
+ )
78
+
 
 
 
 
 
79
 
80
+ #6. ~~~ Get the data ~~~
81
  data = {"id": 0, "question": "What is the capital of France?"} # This can be a list of samples
82
  # data = {"id": 0, "question": "Who was the NBA champion in 2023?"} # This can be a list of samples
83
+
84
+ #option1: use the FlowMessage class
85
+ input_message = FlowMessage(
 
 
 
 
86
  data=data,
 
87
  )
88
 
89
+ #option2: use the proxy_flow
90
+ #input_message = proxy_flow._package_input_message(data = data)
91
+
92
+ #7. ~~~ Run inference ~~~
93
+ future = proxy_flow.send_message_blocking(input_message)
94
+
95
+ #uncomment this line if you would like to get the full message back
96
+ #reply_message = future.get_message()
97
+ reply_data = future.get_data()
98
+
99
  # ~~~ Print the output ~~~
100
+ print("~~~~~~Reply~~~~~~")
101
+ print(reply_data)
102
+
103
+
104
+ #8. ~~~~ (Optional) apply output interface on reply ~~~~
105
+ # output_interface = KeyInterface(
106
+ # keys_to_rename={"api_output": "answer"},
107
+ # )
108
+ # print("Output: ", output_interface(reply_data))
109
+
110
+
111
+ #9. ~~~~~Optional: Unserve Flow~~~~~~
112
+ # serve_utils.delete_served_flow(cl, "ReverseNumberAtomicFlow_served")
113
+
114
+