yeeef committed on
Commit e948884
Parent(s): 4003917

tunable OpenAIChatAtomicFlow

.gitattributes CHANGED
@@ -25,7 +25,6 @@
 *.safetensors filter=lfs diff=lfs merge=lfs -text
 saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
 *.tgz filter=lfs diff=lfs merge=lfs -text
 *.wasm filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,161 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ .idea/
+ .*cache*/
FLAMLOpenAIChatAtomicFlow.py ADDED
@@ -0,0 +1,520 @@
+ import pprint
+ from copy import deepcopy
+
+ import hydra
+ import logging
+
+ import colorama
+ import time
+
+ from typing import List, Dict, Optional, Any, Callable, Tuple
+
+ from flaml import tune, BlendSearch
+
+ from langchain import PromptTemplate
+ import langchain
+ from langchain.schema import HumanMessage, AIMessage, SystemMessage
+
+ from flows.history import FlowHistory
+ from flows.message_annotators.abstract import MessageAnnotator
+ from flows.base_flows.abstract import AtomicFlow
+ from flows.datasets import GenericDemonstrationsDataset
+
+ from flows import utils
+ from flows.messages.chat_message import ChatMessage
+ from flows.utils.caching_utils import flow_run_cache
+
+ log = utils.get_pylogger(__name__)
+ logger = log
+
+
+ class FLAMLOpenAIChatAtomicFlow(AtomicFlow):
+     model_name: str
+     generation_parameters: Dict
+
+     system_message_prompt_template: PromptTemplate
+     human_message_prompt_template: PromptTemplate
+
+     system_name: str = "system"
+     user_name: str = "user"
+     assistant_name: str = "assistant"
+
+     n_api_retries: int = 6
+     wait_time_between_retries: int = 20
+
+     query_message_prompt_template: Optional[PromptTemplate] = None
+     demonstrations: GenericDemonstrationsDataset = None
+     demonstrations_response_template: PromptTemplate = None
+     response_annotators: Optional[Dict[str, MessageAnnotator]] = {}
+
+     default_search_space = {
+         # "model": tune.choice(
+         #     [
+         #         # "text-ada-001",
+         #         # "text-babbage-001",
+         #         # "text-davinci-003",
+         #         "gpt-3.5-turbo",
+         #         # "gpt-4",
+         #     ]
+         # ),
+         "temperature_or_top_p": tune.choice(
+             [
+                 {"temperature": tune.uniform(0, 2)},
+                 {"top_p": tune.uniform(0, 1)},
+             ]
+         ),
+         "max_tokens": tune.lograndint(1000, 4000),
+         # we use the langchain api, https://github.com/hwchase17/langchain/blob/master/langchain/chat_models/base.py#L201
+         # it only takes the first generation as the output, thus n is not relevant
+         # "n": tune.randint(1, 100),
+     }
+
+     def __init__(self, **kwargs):
+         self._validate_parameters(kwargs)
+         super().__init__(**kwargs)
+
+         assert self.flow_config["name"] not in [
+             "system",
+             "user",
+             "assistant",
+         ], f"Flow name '{self.flow_config['name']}' cannot be 'system', 'user' or 'assistant'"
+
+     def set_up_flow_state(self):
+         super().set_up_flow_state()
+         self.flow_state["conversation_initialized"] = False
+
+     @classmethod
+     def _validate_parameters(cls, kwargs):
+         # ToDo: Deal with this in a cleaner way (with less repetition)
+         super()._validate_parameters(kwargs)
+
+         # ~~~ Model generation ~~~
+         if "model_name" not in kwargs["flow_config"]:
+             raise KeyError("model_name not specified in the flow_config.")
+
+         if "generation_parameters" not in kwargs["flow_config"]:
+             raise KeyError("generation_parameters not specified in the flow_config.")
+
+         # ~~~ Prompting ~~~
+         if "system_message_prompt_template" not in kwargs:
+             raise KeyError("system_message_prompt_template not passed to the constructor.")
+
+         if "query_message_prompt_template" not in kwargs:
+             raise KeyError("query_message_prompt_template not passed to the constructor.")
+
+         if "human_message_prompt_template" not in kwargs:
+             raise KeyError("human_message_prompt_template not passed to the constructor.")
+
+     @classmethod
+     def _set_up_prompts(cls, config):
+         kwargs = {}
+
+         kwargs["system_message_prompt_template"] = \
+             hydra.utils.instantiate(config['system_message_prompt_template'], _convert_="partial")
+         kwargs["query_message_prompt_template"] = \
+             hydra.utils.instantiate(config['query_message_prompt_template'], _convert_="partial")
+         kwargs["human_message_prompt_template"] = \
+             hydra.utils.instantiate(config['human_message_prompt_template'], _convert_="partial")
+
+         return kwargs
+
+     @classmethod
+     def _set_up_demonstration_templates(cls, config):
+         kwargs = {}
+
+         if "demonstrations_response_template" in config:
+             kwargs["demonstrations_response_template"] = \
+                 hydra.utils.instantiate(config['demonstrations_response_template'], _convert_="partial")
+
+         return kwargs
+
+     @classmethod
+     def _set_up_response_annotators(cls, config):
+         response_annotators = config.get("response_annotators", {})
+         if len(response_annotators) > 0:
+             for key, config in response_annotators.items():
+                 if isinstance(config, MessageAnnotator):
+                     response_annotators[key] = config
+                 else:
+                     response_annotators[key] = hydra.utils.instantiate(config, _convert_="partial")
+         return {"response_annotators": response_annotators}
+
+     @classmethod
+     def instantiate_from_config(cls, config):
+         flow_config = deepcopy(config)
+
+         kwargs = {"flow_config": flow_config}
+
+         # ~~~ Set up prompts ~~~
+         kwargs.update(cls._set_up_prompts(flow_config))
+
+         # ~~~ Set up demonstration templates ~~~
+         kwargs.update(cls._set_up_demonstration_templates(flow_config))
+
+         # ~~~ Set up response annotators ~~~
+         kwargs.update(cls._set_up_response_annotators(flow_config))
+
+         # ~~~ Instantiate flow ~~~
+         return cls(**kwargs)
+
+     def _is_conversation_initialized(self):
+         return self.flow_state["conversation_initialized"]
+
+     def expected_inputs_given_state(self):
+         if self._is_conversation_initialized():
+             return ["query"]
+         else:
+             return self.flow_config["expected_inputs"]
+
+     @staticmethod
+     def _get_message(prompt_template, input_data: Dict[str, Any]):
+         template_kwargs = {}
+         for input_variable in prompt_template.input_variables:
+             template_kwargs[input_variable] = input_data[input_variable]
+
+         msg_content = prompt_template.format(**template_kwargs)
+         return msg_content
+
+     def _get_demonstration_query_message_content(self, sample_data: Dict):
+         input_variables = self.query_message_prompt_template.input_variables
+         return self.query_message_prompt_template.format(**{k: sample_data[k] for k in input_variables}), []
+
+     def _get_demonstration_response_message_content(self, sample_data: Dict):
+         input_variables = self.demonstrations_response_template.input_variables
+         return self.demonstrations_response_template.format(**{k: sample_data[k] for k in input_variables}), []
+
+     def _get_annotator_with_key(self, key: str):
+         for _, ra in self.response_annotators.items():
+             if ra.key == key:
+                 return ra
+
+     def _response_parsing(self, response: str, expected_outputs: List[str]):
+         target_annotators = [ra for _, ra in self.response_annotators.items() if ra.key in expected_outputs]
+
+         if len(target_annotators) == 0:
+             return {expected_outputs[0]: response}
+
+         parsed_outputs = {}
+         for ra in target_annotators:
+             parsed_out = ra(response)
+             parsed_outputs.update(parsed_out)
+         return parsed_outputs
+
+     def _add_demonstrations(self):
+         if self.demonstrations is not None:
+             for example in self.demonstrations:
+                 query, parents = self._get_demonstration_query_message_content(example)
+                 response, parents = self._get_demonstration_response_message_content(example)
+
+                 self._log_chat_message(content=query,
+                                        message_creator=self.user_name,
+                                        parent_message_ids=parents)
+
+                 self._log_chat_message(content=response,
+                                        message_creator=self.assistant_name,
+                                        parent_message_ids=parents)
+
+     def _log_chat_message(self, message_creator: str, content: str, parent_message_ids: List[str] = None):
+         chat_message = ChatMessage(
+             message_creator=message_creator,
+             parent_message_ids=parent_message_ids,
+             flow_runner=self.flow_config["name"],
+             flow_run_id=self.flow_run_id,
+             content=content
+         )
+         return self._log_message(chat_message)
+
+     def _initialize_conversation(self, input_data: Dict[str, Any]):
+         # ~~~ Add the system message ~~~
+         system_message_content = self._get_message(self.system_message_prompt_template, input_data)
+
+         self._log_chat_message(content=system_message_content,
+                                message_creator=self.system_name)
+
+         # ~~~ Add the demonstration query-response tuples (if any) ~~~
+         self._add_demonstrations()
+         self._update_state(update_data={"conversation_initialized": True})
+
+     def get_conversation_messages(self, message_format: Optional[str] = None):
+         messages = self.flow_state["history"].get_chat_messages()
+
+         if message_format is None:
+             return messages
+
+         elif message_format == "open_ai":
+             processed_messages = []
+
+             for message in messages:
+                 if message.message_creator == self.system_name:
+                     processed_messages.append(SystemMessage(content=message.content))
+                 elif message.message_creator == self.assistant_name:
+                     processed_messages.append(AIMessage(content=message.content))
+                 elif message.message_creator == self.user_name:
+                     processed_messages.append(HumanMessage(content=message.content))
+                 else:
+                     raise ValueError(f"Unknown name: {message.message_creator}")
+             return processed_messages
+         else:
+             raise ValueError(
+                 f"Currently supported conversation message formats: 'open_ai'. '{message_format}' is not supported")
+
+     def _call(self):
+         api_key = self.flow_state["api_key"]
+
+         backend = langchain.chat_models.ChatOpenAI(
+             model_name=self.flow_config["model_name"],
+             openai_api_key=api_key,
+             **self.flow_config["generation_parameters"],
+         )
+
+         messages = self.get_conversation_messages(
+             message_format="open_ai"
+         )
+
+         _success = False
+         attempts = 1
+         error = None
+         response = None
+         while attempts <= self.n_api_retries:
+             try:
+                 response = backend(messages).content
+                 _success = True
+                 break
+             except Exception as e:
+                 log.error(
+                     f"Error {attempts} in calling backend: {e}. Key used: `{api_key}`. "
+                     f"Retrying in {self.wait_time_between_retries} seconds..."
+                 )
+                 log.error(
+                     f"API call raised an exception with the following arguments: "
+                     f"\n{self.flow_state['history'].to_string()}"
+                 )
+                 attempts += 1
+                 time.sleep(self.wait_time_between_retries)
+                 error = e
+
+         if not _success:
+             raise error
+
+         if self.flow_config["verbose"]:
+             messages_str = self.flow_state["history"].to_string()
+             log.info(
+                 f"\n{colorama.Fore.MAGENTA}~~~ History [{self.flow_config['name']}] ~~~\n"
+                 f"{colorama.Style.RESET_ALL}{messages_str}"
+             )
+
+         return response
+
+     def _prepare_conversation(self, input_data: Dict[str, Any]):
+         if self._is_conversation_initialized():
+             # ~~~ Check that the message has a `query` field ~~~
+             user_message_content = self.human_message_prompt_template.format(query=input_data["query"])
+
+         else:
+             self._initialize_conversation(input_data)
+             user_message_content = self._get_message(self.query_message_prompt_template, input_data)
+
+         self._log_chat_message(message_creator=self.user_name,
+                                content=user_message_content)
+
+     @flow_run_cache()
+     def run(self, input_data: Dict[str, Any], expected_outputs: List[str]) -> Dict[str, Any]:
+         # ~~~ Chat-specific preparation ~~~
+         self._prepare_conversation(input_data)
+
+         # ~~~ Call ~~~
+         response = self._call()
+         answer_message = self._log_chat_message(
+             message_creator=self.flow_config["assistant_name"],
+             content=response
+         )
+
+         # ~~~ Response parsing ~~~
+         parsed_outputs = self._response_parsing(
+             response=response,
+             expected_outputs=expected_outputs
+         )
+         self._update_state(update_data=parsed_outputs)
+
+         if self.flow_config["verbose"]:
+             parsed_output_messages_str = pprint.pformat({k: m for k, m in parsed_outputs.items()},
+                                                         indent=4)
+             log.info(
+                 f"\n{colorama.Fore.MAGENTA}~~~ "
+                 f"Response [{answer_message.message_creator} -- "
+                 f"{answer_message.message_id} -- "
+                 f"{answer_message.flow_run_id}] ~~~"
+                 f"\n{colorama.Fore.YELLOW}Content: {answer_message}{colorama.Style.RESET_ALL}"
+                 f"\n{colorama.Fore.YELLOW}Parsed Outputs: {parsed_output_messages_str}{colorama.Style.RESET_ALL}"
+             )
+
+         # ~~~ The final answer should be in self.flow_state, thus allow_class_namespace=False ~~~
+         return self._get_keys_from_state(keys=expected_outputs, allow_class_namespace=False)
+
+     @classmethod
+     def tune(
+         cls,
+         tune_dps: List[Dict],
+         metric: str,
+         mode: str,
+         eval_func: Callable,
+         api_key: str,
+         log_file_name: Optional[str] = None,  # TODO(yeeef)
+         inference_budget: Optional[float] = None,
+         optimization_budget: Optional[float] = None,
+         num_samples: Optional[int] = 1,
+         logging_level: Optional[int] = logging.WARN,  # TODO(yeeef)
+         initial_flow_config: Optional[Dict] = None,  # if not supplied, the class's default flow config (xxx.yaml) is used
+         **config,
+     ) -> Tuple[Dict, Any]:  # tune.ExperimentAnalysis
+         """
+         Args:
+             - tune_dps (list): The list of data points on which to tune the hyperparameters.
+             - metric (str): The metric to optimize.
+             - mode (str): The optimization mode, "min" or "max".
+             - eval_func (Callable): The evaluation function for responses.
+                 The function should take a response and a data point as input,
+                 and return a dict of metrics.
+             - log_file_name (str, optional): The log file.
+             - inference_budget (float, optional): The inference budget, in dollars per instance.
+             - optimization_budget (float, optional): The total optimization budget, in dollars.
+             - num_samples (int, optional): The number of samples to evaluate.
+                 -1 means no hard restriction on the number of trials,
+                 and the actual number is decided by optimization_budget. Defaults to 1.
+             - logging_level (optional): The logging level. Defaults to logging.WARNING.
+             - **config (dict): The search space to update over the default search space.
+                 For prompt, please provide a string/Callable or a list of strings/Callables.
+                     - If prompt is provided for chat models, it will be converted to messages under role "user".
+                     - Do not provide both prompt and messages for chat models; provide either of them.
+                     - A string template will be used to generate a prompt for each data instance
+                       using `prompt.format(**data)`.
+                     - A callable template will be used to generate a prompt for each data instance
+                       using `prompt(data)`.
+                 For stop, please provide a string, a list of strings, or a list of lists of strings.
+                 For messages (chat models only), please provide a list of messages (for a single chat prefix)
+                 or a list of lists of messages (for multiple choices of chat prefix to choose from).
+                 Each message should be a dict with keys "role" and "content". The value of "content" can be a string/Callable template.
+
+         Returns:
+             - dict: The optimized hyperparameter setting.
+             - tune.ExperimentAnalysis: The tuning results.
+         """
+
+         initial_flow_config = initial_flow_config or cls.get_config()
+         space = cls.default_search_space.copy()
+
+         if config is not None:
+             space.update(config)
+             if "messages" in space:
+                 space.pop("prompt", None)
+             temperature = space.pop("temperature", None)
+             top_p = space.pop("top_p", None)
+             if temperature is not None and top_p is None:
+                 space["temperature_or_top_p"] = {"temperature": temperature}
+             elif temperature is None and top_p is not None:
+                 space["temperature_or_top_p"] = {"top_p": top_p}
+             elif temperature is not None and top_p is not None:
+                 space.pop("temperature_or_top_p")
+                 space["temperature"] = temperature
+                 space["top_p"] = top_p
+                 log.warning("temperature and top_p are not recommended to vary together.")
+
+         # Note: currently we fix the model rather than making it tunable
+         search_alg = BlendSearch(
+             cost_attr="cost",
+             cost_budget=optimization_budget,
+             metric=metric,
+             mode=mode,
+             space=space,
+         )
+
+         # Args:
+         #     evaluation_function: A user-defined evaluation function.
+         #     It takes a configuration as input, outputs an evaluation
+         #     result (can be a numerical value or a dictionary of string
+         #     and numerical value pairs) for the input configuration.
+         #     For machine learning tasks, it usually involves training and
+         #     scoring a machine learning model, e.g., through validation loss.
+
+         def updated_flow_config_with_search_config(flow_config: Dict[str, Any], search_config: Dict[str, Any]):
+             """
+             The inputs are not mutated; both are deep-copied.
+             """
+             flow_config = deepcopy(flow_config)
+             search_config = deepcopy(search_config)
+
+             temperature_or_top_p = search_config.pop("temperature_or_top_p", None)
+             if temperature_or_top_p is not None:
+                 search_config.update(temperature_or_top_p)
+
+             flow_config["model_name"] = search_config.get("model", flow_config["model_name"])
+             generation_parameters = flow_config["generation_parameters"]
+             for generation_parameter in generation_parameters:
+                 if generation_parameter == "model_kwargs":
+                     continue
+                 if generation_parameter in search_config:
+                     generation_parameters[generation_parameter] = search_config[generation_parameter]
+
+             model_kwargs = generation_parameters["model_kwargs"]
+             for model_kwarg in model_kwargs:
+                 if model_kwarg in search_config:
+                     model_kwargs[model_kwarg] = search_config[model_kwarg]
+
+             return flow_config
+
+         def tune_run_eval(search_config: Dict[str, Any]) -> Dict[str, float]:
+             """
+             evaluation_function: A user-defined evaluation function.
+             It takes a configuration as input, outputs an evaluation
+             result (can be a numerical value or a dictionary of string
+             and numerical value pairs) for the input configuration.
+             For machine learning tasks, it usually involves training and
+             scoring a machine learning model, e.g., through validation loss.
+             """
+             # extract the flow_construct_kwargs from search_config
+             # An example of the flow config being tuned:
+             """
+             {'expected_inputs': [], 'expected_outputs': [], 'flow_type': 'Flow', 'verbose': True, 'dry_run': False, 'namespace_clearing_after_run': True, 'n_api_retries': 6, 'wait_time_between_retries': 20, 'system_name': 'system', 'user_name': 'user', 'assistant_name': 'assistant', 'response_annotators': {'code_extractor': <flows.message_annotators.regex_extractor_first.RegexFirstOccurrenceExtractor object at 0x7f532121bc70>}, 'query_message_prompt_template': {'_target_': 'langchain.PromptTemplate', 'template': '# Problem statement\n{{problem_description}}\n\n# Input description\n{{input_description}}\n\n# Output description\n{{output_description}}\n\n{{io_examples_and_explanation}}\n\n\nThe input should be read from the standard input and the output should be passed to the standard output.\nReturn Python code that solves the problem. Reply in the following format:\n```python\n{{code_placeholder}}\n```', 'input_variables': ['problem_description', 'input_description', 'output_description', 'io_examples_and_explanation'], 'partial_variables': {'code_placeholder': '{{python_code}}'}, 'template_format': 'jinja2'}, 'demonstrations': None, 'demonstrations_response_template': None, 'name': 'CodeAgent', 'description': 'ToDO: add description', 'model_name': 'gpt-3.5-turbo', 'generation_parameters': {'n': 1, 'max_tokens': 3000, 'temperature': 0.3, 'model_kwargs': {'top_p': 0.2, 'frequency_penalty': 0, 'presence_penalty': 0}}, 'system_message_prompt_template': {'_target_': 'langchain.PromptTemplate', 'template': 'Your goal is to provide executable Python code that solves a competitive programming problem. The code should correctly handle all corner cases in order to pass the hidden test cases, which are used to evaluate the correctness of the solution.\n\nThe user will specify the problem by providing you with:\n - the problem statement\n - input description\n - output description\n - example test cases\n - (optional) explanation of the test cases\n\nThe user will provide you with a task and an output format that you will strictly follow.', 'input_variables': [], 'template_format': 'jinja2'}, 'human_message_prompt_template': {'_target_': 'langchain.PromptTemplate', 'template': '{{query}}', 'input_variables': ['query'], 'template_format': 'jinja2'}}
+             """
+             log.info(f"Tuning with config: {search_config}")
+             # TODO: the code currently only works when there is no subspace, i.e. there is only one model to tune with
+             # align search_config with flow_config
+             updated_flow_config = updated_flow_config_with_search_config(flow_config=initial_flow_config, search_config=search_config)
+             log.info(f"Updated flow_config: {updated_flow_config}")
+             # flow_launcher = FlowAPILauncher(flow, 1, False, 3, 0, ["code"]) TODO: maybe refactor with flow_launcher
+
+             # TODO: limitations: the langchain api call does not give us the cost of the api call, and only gives us
+             # one result no matter the n
+             final_metrics = {}
+             for sample in tune_dps:
+                 sample["api_key"] = api_key
+                 # log.info(f"sample: {sample}")
+                 flow = cls.instantiate_from_config(updated_flow_config)
+                 task_message = flow.package_task_message(recipient_flow=flow,
+                                                          task_name="run_task",
+                                                          task_data=sample,
+                                                          expected_outputs=["code"])
+                 output_message = flow(task_message)
+                 # log.info(f"output_message: {output_message}")
+
+                 metrics = eval_func(output_message.data['code'], sample)
+                 log.info(f"metrics for dp: {metrics}")
+                 if not final_metrics:
+                     final_metrics = metrics
+                 else:
+                     for k, v in metrics.items():
+                         final_metrics[k] += v
+             log.info(f"final metric {final_metrics} for this config {search_config}")
+             return final_metrics
+
+         analysis = tune.run(
+             tune_run_eval,
+             search_alg=search_alg,
+             num_samples=num_samples,
+             log_file_name=log_file_name,
+             verbose=3,
+         )
+         best_search_config = analysis.best_config
+         flow_config = updated_flow_config_with_search_config(initial_flow_config, best_search_config)
+         log.info(f"best search config found: {best_search_config}, analysis: {analysis.best_result}")
+         return flow_config, analysis
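
For orientation, here is a minimal sketch of how the new `tune` entry point might be driven. Everything below is hypothetical and not part of the commit: the sample fields, the `toy_eval` function, the key, and the budget values. The required sample keys depend on the concrete flow's `query_message_prompt_template`, and since `tune_run_eval` hard-codes `expected_outputs=["code"]`, `eval_func` receives the extracted `code` string plus the sample dict.

```python
# Hypothetical driver for FLAMLOpenAIChatAtomicFlow.tune (illustrative sketch).
from FLAMLOpenAIChatAtomicFlow import FLAMLOpenAIChatAtomicFlow


def toy_eval(code: str, sample: dict) -> dict:
    # Toy metric: reward non-empty completions. A "cost" entry is assumed
    # here because BlendSearch is constructed with cost_attr="cost".
    return {"success": float(bool(code.strip())), "cost": 0.0}


# Placeholder data points; the keys must match the concrete flow's
# query_message_prompt_template input_variables.
tune_dps = [
    {
        "problem_description": "Sum two integers.",
        "input_description": "Two integers a and b on one line.",
        "output_description": "Print a + b.",
        "io_examples_and_explanation": "Input: 1 2 -> Output: 3",
    },
]

best_flow_config, analysis = FLAMLOpenAIChatAtomicFlow.tune(
    tune_dps=tune_dps,
    metric="success",
    mode="max",
    eval_func=toy_eval,
    api_key="sk-...",          # supply a real key
    optimization_budget=1.0,   # total dollars
    num_samples=4,
)
print(best_flow_config["generation_parameters"])
```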
FLAMLOpenAIChatAtomicFlow.yaml ADDED
@@ -0,0 +1,14 @@
+ # This is an abstract flow, therefore some required fields are not defined (and must be defined by the concrete flow)
+
+ n_api_retries: 6
+ wait_time_between_retries: 20
+
+ system_name: system
+ user_name: user
+ assistant_name: assistant
+
+ response_annotators: {}
+
+ query_message_prompt_template: null  # ToDo: When will this be null?
+ demonstrations: null
+ demonstrations_response_template: null
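
The YAML above only pins generic defaults; a concrete flow config also has to supply the model, generation parameters, and prompt templates. A hedged sketch of those extra fields follows, loosely modeled on the example config dump embedded in FLAMLOpenAIChatAtomicFlow.py; all values are illustrative placeholders.

```python
# Illustrative concrete flow config layered on top of the abstract defaults
# above; field values are placeholders, not shipped by this commit.
concrete_flow_config = {
    "name": "CodeAgent",
    "description": "Generates Python code for programming problems.",
    "model_name": "gpt-3.5-turbo",
    "generation_parameters": {
        "n": 1,
        "max_tokens": 3000,
        "temperature": 0.3,
        "model_kwargs": {"top_p": 0.2, "frequency_penalty": 0, "presence_penalty": 0},
    },
    # Prompt templates are hydra-instantiated langchain.PromptTemplate specs.
    "system_message_prompt_template": {
        "_target_": "langchain.PromptTemplate",
        "template": "You write correct, executable Python code.",
        "input_variables": [],
        "template_format": "jinja2",
    },
    "human_message_prompt_template": {
        "_target_": "langchain.PromptTemplate",
        "template": "{{query}}",
        "input_variables": ["query"],
        "template_format": "jinja2",
    },
    "query_message_prompt_template": {
        "_target_": "langchain.PromptTemplate",
        "template": "# Problem statement\n{{problem_description}}",
        "input_variables": ["problem_description"],
        "template_format": "jinja2",
    },
}
# flow = FLAMLOpenAIChatAtomicFlow.instantiate_from_config(concrete_flow_config)
```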
README.md CHANGED
@@ -1,3 +1,4 @@
 ---
 license: mit
 ---
+
__init__.py ADDED
@@ -0,0 +1 @@
+ from .FLAMLOpenAIChatAtomicFlow import FLAMLOpenAIChatAtomicFlow
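
Finally, for readers tracing the `temperature_or_top_p` handling in `tune`, a worked example of how `updated_flow_config_with_search_config` folds a sampled point back into a flow config. The concrete numbers are made up; the logic mirrors the helper above.

```python
# Illustrative walk-through of updated_flow_config_with_search_config
# (values invented for the example).
search_config = {"temperature_or_top_p": {"top_p": 0.85}, "max_tokens": 2048}

flow_config = {
    "model_name": "gpt-3.5-turbo",
    "generation_parameters": {
        "temperature": 0.3,
        "max_tokens": 3000,
        "model_kwargs": {"top_p": 0.2, "frequency_penalty": 0, "presence_penalty": 0},
    },
}

# Step 1: the sampled point chose the top_p branch, so the wrapper is
#         flattened into search_config -> {"top_p": 0.85, "max_tokens": 2048}.
# Step 2: "max_tokens" matches a key in generation_parameters -> 2048.
# Step 3: "top_p" matches a key in model_kwargs -> 0.85.
# Result: temperature stays 0.3; each trial varies only one of
#         temperature/top_p, which is the point of the tune.choice space.
```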