import pprint
from copy import deepcopy
import hydra
import logging
import colorama
import time
from typing import List, Dict, Optional, Any, Callable, Tuple
from flaml import tune, BlendSearch
from langchain import PromptTemplate
import langchain
from langchain.schema import HumanMessage, AIMessage, SystemMessage
from flows.history import FlowHistory
from flows.message_annotators.abstract import MessageAnnotator
from flows.base_flows.abstract import AtomicFlow
from flows.datasets import GenericDemonstrationsDataset
from flows import utils
from flows.messages.chat_message import ChatMessage
from flows.utils.caching_utils import flow_run_cache
log = utils.get_pylogger(__name__)
logger = log
class FLAMLOpenAIChatAtomicFlow(AtomicFlow):
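"""
An AtomicFlow that wraps an OpenAI chat model (via LangChain) behind prompt templates,
optional few-shot demonstrations and response annotators, and exposes a `tune` classmethod
that searches over generation parameters with FLAML's BlendSearch.
"""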
model_name: str
generation_parameters: Dict
system_message_prompt_template: PromptTemplate
human_message_prompt_template: PromptTemplate
system_name: str = "system"
user_name: str = "user"
assistant_name: str = "assistant"
n_api_retries: int = 6
wait_time_between_retries: int = 20
query_message_prompt_template: Optional[PromptTemplate] = None
demonstrations: GenericDemonstrationsDataset = None
demonstrations_response_template: PromptTemplate = None
response_annotators: Optional[Dict[str, MessageAnnotator]] = {}
default_search_space = {
# "model": tune.choice(
# [
# # "text-ada-001",
# # "text-babbage-001",
# # "text-davinci-003",
# "gpt-3.5-turbo",
# # "gpt-4",
# ]
# ),
"temperature_or_top_p": tune.choice(
[
{"temperature": tune.uniform(0, 2)},
{"top_p": tune.uniform(0, 1)},
]
),
"max_tokens": tune.lograndint(1000, 4000),
# we use langchain api, https://github.com/hwchase17/langchain/blob/master/langchain/chat_models/base.py#L201
# it only takes the first generation as the output, so n is not relevant
# "n": tune.randint(1, 100),
}
def __init__(self, **kwargs):
self._validate_parameters(kwargs)
super().__init__(**kwargs)
assert self.flow_config["name"] not in [
"system",
"user",
"assistant",
], f"Flow name '{self.flow_config['name']}' cannot be 'system', 'user' or 'assistant'"
def set_up_flow_state(self):
super().set_up_flow_state()
self.flow_state["conversation_initialized"] = False
@classmethod
def _validate_parameters(cls, kwargs):
# ToDo: Deal with this in a cleaner way (with less repetition)
super()._validate_parameters(kwargs)
# ~~~ Model generation ~~~
if "model_name" not in kwargs["flow_config"]:
raise KeyError("model_name not specified in the flow_config.")
if "generation_parameters" not in kwargs["flow_config"]:
raise KeyError("generation_parameters not specified in the flow_config.")
# ~~~ Prompting ~~~
if "system_message_prompt_template" not in kwargs:
raise KeyError("system_message_prompt_template not passed to the constructor.")
if "query_message_prompt_template" not in kwargs:
raise KeyError("query_message_prompt_template not passed to the constructor.")
if "human_message_prompt_template" not in kwargs:
raise KeyError("human_message_prompt_template not passed to the constructor.")
@classmethod
def _set_up_prompts(cls, config):
kwargs = {}
kwargs["system_message_prompt_template"] = \
hydra.utils.instantiate(config['system_message_prompt_template'], _convert_="partial")
kwargs["query_message_prompt_template"] = \
hydra.utils.instantiate(config['query_message_prompt_template'], _convert_="partial")
kwargs["human_message_prompt_template"] = \
hydra.utils.instantiate(config['human_message_prompt_template'], _convert_="partial")
return kwargs
@classmethod
def _set_up_demonstration_templates(cls, config):
kwargs = {}
if "demonstrations_response_template" in config:
kwargs["demonstrations_response_template"] = \
hydra.utils.instantiate(config['demonstrations_response_template'], _convert_="partial")
return kwargs
@classmethod
def _set_up_response_annotators(cls, config):
response_annotators = config.get("response_annotators", {})
if len(response_annotators) > 0:
for key, annotator_config in response_annotators.items():
if isinstance(annotator_config, MessageAnnotator):
response_annotators[key] = annotator_config
else:
response_annotators[key] = hydra.utils.instantiate(annotator_config, _convert_="partial")
return {"response_annotators": response_annotators}
@classmethod
def instantiate_from_config(cls, config):
flow_config = deepcopy(config)
kwargs = {"flow_config": flow_config}
# ~~~ Set up prompts ~~~
kwargs.update(cls._set_up_prompts(flow_config))
# ~~~ Set up demonstration templates ~~~
kwargs.update(cls._set_up_demonstration_templates(flow_config))
# ~~~ Set up response annotators ~~~
kwargs.update(cls._set_up_response_annotators(flow_config))
# ~~~ Instantiate flow ~~~
return cls(**kwargs)
def _is_conversation_initialized(self):
return self.flow_state["conversation_initialized"]
def expected_inputs_given_state(self):
if self._is_conversation_initialized():
return ["query"]
else:
return self.flow_config["expected_inputs"]
@staticmethod
def _get_message(prompt_template, input_data: Dict[str, Any]):
template_kwargs = {}
for input_variable in prompt_template.input_variables:
template_kwargs[input_variable] = input_data[input_variable]
msg_content = prompt_template.format(**template_kwargs)
return msg_content
def _get_demonstration_query_message_content(self, sample_data: Dict):
input_variables = self.query_message_prompt_template.input_variables
return self.query_message_prompt_template.format(**{k: sample_data[k] for k in input_variables}), []
def _get_demonstration_response_message_content(self, sample_data: Dict):
input_variables = self.demonstrations_response_template.input_variables
return self.demonstrations_response_template.format(**{k: sample_data[k] for k in input_variables}), []
def _get_annotator_with_key(self, key: str):
for _, ra in self.response_annotators.items():
if ra.key == key:
return ra
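# Parse the raw model response with the annotators whose keys appear in expected_outputs;
# if no annotator matches, the raw response is returned under the first expected output key.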
def _response_parsing(self, response: str, expected_outputs: List[str]):
target_annotators = [ra for _, ra in self.response_annotators.items() if ra.key in expected_outputs]
if len(target_annotators) == 0:
return {expected_outputs[0]: response}
parsed_outputs = {}
for ra in target_annotators:
parsed_out = ra(response)
parsed_outputs.update(parsed_out)
return parsed_outputs
def _add_demonstrations(self):
if self.demonstrations is not None:
for example in self.demonstrations:
query, parents = self._get_demonstration_query_message_content(example)
response, parents = self._get_demonstration_response_message_content(example)
self._log_chat_message(content=query,
message_creator=self.user_name,
parent_message_ids=parents)
self._log_chat_message(content=response,
message_creator=self.assistant_name,
parent_message_ids=parents)
def _log_chat_message(self, message_creator: str, content: str, parent_message_ids: List[str] = None):
chat_message = ChatMessage(
message_creator=message_creator,
parent_message_ids=parent_message_ids,
flow_runner=self.flow_config["name"],
flow_run_id=self.flow_run_id,
content=content
)
return self._log_message(chat_message)
def _initialize_conversation(self, input_data: Dict[str, Any]):
# ~~~ Add the system message ~~~
system_message_content = self._get_message(self.system_message_prompt_template, input_data)
self._log_chat_message(content=system_message_content,
message_creator=self.system_name)
# ~~~ Add the demonstration query-response tuples (if any) ~~~
self._add_demonstrations()
self._update_state(update_data={"conversation_initialized": True})
def get_conversation_messages(self, message_format: Optional[str] = None):
messages = self.flow_state["history"].get_chat_messages()
if message_format is None:
return messages
elif message_format == "open_ai":
processed_messages = []
for message in messages:
if message.message_creator == self.system_name:
processed_messages.append(SystemMessage(content=message.content))
elif message.message_creator == self.assistant_name:
processed_messages.append(AIMessage(content=message.content))
elif message.message_creator == self.user_name:
processed_messages.append(HumanMessage(content=message.content))
else:
raise ValueError(f"Unknown name: {message.message_creator}")
return processed_messages
else:
raise ValueError(
f"Currently supported conversation message formats: 'open_ai'. '{message_format}' is not supported")
def _call(self):
api_key = self.flow_state["api_key"]
backend = langchain.chat_models.ChatOpenAI(
model_name=self.flow_config["model_name"],
openai_api_key=api_key,
**self.flow_config["generation_parameters"],
)
messages = self.get_conversation_messages(
message_format="open_ai"
)
_success = False
attempts = 1
error = None
response = None
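# Retry the API call up to n_api_retries times, sleeping wait_time_between_retries seconds
# between attempts; if all attempts fail, the last error is re-raised below.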
while attempts <= self.n_api_retries:
try:
response = backend(messages).content
_success = True
break
except Exception as e:
log.error(
f"Error {attempts} in calling backend: {e}. Key used: `{api_key}`. "
f"Retrying in {self.wait_time_between_retries} seconds..."
)
log.error(
f"API call raised Exception with the following arguments arguments: "
f"\n{self.flow_state['history'].to_string()}"
)
attempts += 1
time.sleep(self.wait_time_between_retries)
error = e
if not _success:
raise error
if self.flow_config["verbose"]:
messages_str = self.flow_state["history"].to_string()
log.info(
f"\n{colorama.Fore.MAGENTA}~~~ History [{self.flow_config['name']}] ~~~\n"
f"{colorama.Style.RESET_ALL}{messages_str}"
)
return response
def _prepare_conversation(self, input_data: Dict[str, Any]):
if self._is_conversation_initialized():
# ~~~ Format the user query with the human message prompt template ~~~
user_message_content = self.human_message_prompt_template.format(query=input_data["query"])
else:
self._initialize_conversation(input_data)
user_message_content = self._get_message(self.query_message_prompt_template, input_data)
self._log_chat_message(message_creator=self.user_name,
content=user_message_content)
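# Main entry point: prepare the conversation, query the model, log the assistant reply,
# parse it with the response annotators, and return the expected outputs from the flow state.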
@flow_run_cache()
def run(self, input_data: Dict[str, Any], expected_outputs: List[str]) -> Dict[str, Any]:
# ~~~ Chat-specific preparation ~~~
self._prepare_conversation(input_data)
# ~~~ Call ~~~
response = self._call()
answer_message = self._log_chat_message(
message_creator=self.flow_config["assistant_name"],
content=response
)
# ~~~ Response parsing ~~~
parsed_outputs = self._response_parsing(
response=response,
expected_outputs=expected_outputs
)
self._update_state(update_data=parsed_outputs)
if self.flow_config["verbose"]:
parsed_output_messages_str = pprint.pformat(parsed_outputs, indent=4)
log.info(
f"\n{colorama.Fore.MAGENTA}~~~ "
f"Response [{answer_message.message_creator} -- "
f"{answer_message.message_id} -- "
f"{answer_message.flow_run_id}] ~~~"
f"\n{colorama.Fore.YELLOW}Content: {answer_message}{colorama.Style.RESET_ALL}"
f"\n{colorama.Fore.YELLOW}Parsed Outputs: {parsed_output_messages_str}{colorama.Style.RESET_ALL}"
)
# ~~~ The final answer should be in self.flow_state, thus allow_class_namespace=False ~~~
return self._get_keys_from_state(keys=expected_outputs, allow_class_namespace=False)
@classmethod
def tune(
cls,
tune_dps: List[Dict],
metric: str,
mode: str,
eval_func: Callable,
api_key: str,
log_file_name: Optional[str] = None, # TODO(yeeef)
inference_budget: Optional[float] = None,
optimization_budget: Optional[float] = None,
num_samples: Optional[int] = 1,
logging_level: Optional[int] = logging.WARN, # TODO(yeeef)
initial_flow_config: Optional[Dict] = None,  # if not supplied, the class's default flow config (xxx.yaml) is used
**config,
) -> Tuple[Dict, Any]: # tune.ExperimentAnalysis
"""
Args:
- tune_dps (list): The list of data points to tune the hyperparameters.
- metric (str): The metric to optimize.
- mode (str): The optimization mode, "min" or "max".
- eval_func (Callable): The evaluation function for responses.
The function should take a response and a data point as input,
and return a dict of metrics.
- log_file_name (str, optional): The log file.
- inference_budget (float, optional): The inference budget, dollar per instance.
- optimization_budget (float, optional): The optimization budget, dollar in total.
- num_samples (int, optional): The number of samples to evaluate.
-1 means no hard restriction on the number of trials
and the actual number is decided by optimization_budget. Defaults to 1.
- logging_level (optional): logging level. Defaults to logging.WARNING.
- **config (dict): The search space to update over the default search space.
For prompt, please provide a string/Callable or a list of strings/Callables.
- If prompt is provided for chat models, it will be converted to messages under role "user".
- For chat models, provide either prompt or messages, not both.
- A string template will be used to generate a prompt for each data instance
using `prompt.format(**data)`.
- A callable template will be used to generate a prompt for each data instance
using `prompt(data)`.
For stop, please provide a string, a list of strings, or a list of lists of strings.
For messages (chat models only), please provide a list of messages (for a single chat prefix)
or a list of lists of messages (for multiple choices of chat prefix to choose from).
Each message should be a dict with keys "role" and "content". The value of "content" can be a string/Callable template.
Returns:
- dict: The optimized hyperparameter setting.
- tune.ExperimentAnalysis: The tuning results.
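Example (illustrative sketch, not part of the API; the datapoints, metric name, and
eval_func below are assumptions):

import os

def success_metric(code, sample):
# toy metric: 1.0 if a non-empty response was produced, else 0.0
return {"success": float(code is not None and len(code) > 0)}

best_flow_config, analysis = FLAMLOpenAIChatAtomicFlow.tune(
tune_dps=my_tune_datapoints,
metric="success",
mode="max",
eval_func=success_metric,
api_key=os.environ["OPENAI_API_KEY"],
optimization_budget=1.0,
num_samples=4,
)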
"""
initial_flow_config = initial_flow_config or cls.get_config()
space = cls.default_search_space.copy()
if config:
space.update(config)
if "messages" in space:
space.pop("prompt", None)
temperature = space.pop("temperature", None)
top_p = space.pop("top_p", None)
if temperature is not None and top_p is None:
space["temperature_or_top_p"] = {"temperature": temperature}
elif temperature is None and top_p is not None:
space["temperature_or_top_p"] = {"top_p": top_p}
elif temperature is not None and top_p is not None:
space.pop("temperature_or_top_p")
space["temperature"] = temperature
space["top_p"] = top_p
log.warning("temperature and top_p are not recommended to vary together.")
# Note: currently we fix the model rather than make it tunable
search_alg = BlendSearch(
cost_attr="cost",
cost_budget=optimization_budget,
metric=metric,
mode=mode,
space=space,
)
# Args:
# evaluation_function: A user-defined evaluation function.
# It takes a configuration as input and outputs an evaluation
# result (can be a numerical value or a dictionary of string
# and numerical value pairs) for the input configuration.
# For machine learning tasks, it usually involves training and
# scoring a machine learning model, e.g., through validation loss.
def updated_flow_config_with_search_config(flow_config: Dict[str, Any], search_config: Dict[str, Any]):
"""
Return a copy of flow_config updated with the sampled search_config; the inputs are not mutated.
"""
flow_config = deepcopy(flow_config)
search_config = deepcopy(search_config)
temperature_or_top_p = search_config.pop("temperature_or_top_p", None)
if temperature_or_top_p is not None:
search_config.update(temperature_or_top_p)
flow_config["model_name"] = search_config.get("model", flow_config["model_name"])
generation_parameters = flow_config["generation_parameters"]
for generation_parameter in generation_parameters:
if generation_parameter == "model_kwargs":
continue
if generation_parameter in search_config:
generation_parameters[generation_parameter] = search_config[generation_parameter]
model_kwargs = generation_parameters["model_kwargs"]
for model_kwarg in model_kwargs:
if model_kwarg in search_config:
model_kwargs[model_kwarg] = search_config[model_kwarg]
return flow_config
def tune_run_eval(search_config: Dict[str, Any]) -> Dict[str, float]:
"""
evaluation_function: A user-defined evaluation function.
It takes a configuration as input and outputs an evaluation
result (can be a numerical value or a dictionary of string
and numerical value pairs) for the input configuration.
For machine learning tasks, it usually involves training and
scoring a machine learning model, e.g., through validation loss.
"""
# extract the flow_construct_kwargs from search_config
"""
{'expected_inputs': [], 'expected_outputs': [], 'flow_type': 'Flow', 'verbose': True, 'dry_run': False, 'namespace_clearing_after_run': True, 'n_api_retries': 6, 'wait_time_between_retries': 20, 'system_name': 'system', 'user_name': 'user', 'assistant_name': 'assistant', 'response_annotators': {'code_extractor': <flows.message_annotators.regex_extractor_first.RegexFirstOccurrenceExtractor object at 0x7f532121bc70>}, 'query_message_prompt_template': {'_target_': 'langchain.PromptTemplate', 'template': '# Problem statement\n{{problem_description}}\n\n# Input description\n{{input_description}}\n\n# Output description\n{{output_description}}\n\n{{io_examples_and_explanation}}\n\n\nThe input should be read from the standard input and the output should be passed to the standard output.\nReturn Python code that solves the problem. Reply in the following format:\n```python\n{{code_placeholder}}\n```', 'input_variables': ['problem_description', 'input_description', 'output_description', 'io_examples_and_explanation'], 'partial_variables': {'code_placeholder': '{{python_code}}'}, 'template_format': 'jinja2'}, 'demonstrations': None, 'demonstrations_response_template': None, 'name': 'CodeAgent', 'description': 'ToDO: add description', 'model_name': 'gpt-3.5-turbo', 'generation_parameters': {'n': 1, 'max_tokens': 3000, 'temperature': 0.3, 'model_kwargs': {'top_p': 0.2, 'frequency_penalty': 0, 'presence_penalty': 0}}, 'system_message_prompt_template': {'_target_': 'langchain.PromptTemplate', 'template': 'Your goal is to provide executable Python code that solves a competitive programming problem. The code should correctly handle all corner cases in order to pass the hidden test cases, which are used to evaluate the correctness of the solution.\n\nThe user will specify the problem by providing you with:\n - the problem statement\n - input description\n - output description\n - example test cases\n - (optional) explanation of the test cases\n\nThe user will provide you with a task and an output format that you will strictly follow.', 'input_variables': [], 'template_format': 'jinja2'}, 'human_message_prompt_template': {'_target_': 'langchain.PromptTemplate', 'template': '{{query}}', 'input_variables': ['query'], 'template_format': 'jinja2'}}
"""
log.info(f"Tunning with config: {search_config}")
# TODO: the code currently only works when there is no subspace, i.e. there is only one model to tune with
# align search_config with flow_config
updated_flow_config = updated_flow_config_with_search_config(flow_config=initial_flow_config, search_config=search_config)
log.info(f"Updated flow_config: {updated_flow_config}")
# flow_launcher = FlowAPILauncher(flow, 1, False, 3, 0, ["code"]) TODO: maybe refactor with flow_launcher
# TODO: limitations: the langchain API call does not give us the cost of the call, and only gives us
# one result regardless of n
final_metrics = {}
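# Sum the per-datapoint metrics over all tuning datapoints; the summed dict is what FLAML optimizes.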
for sample in tune_dps:
sample["api_key"] = api_key
# log.info(f"sample: {sample}")
flow = cls.instantiate_from_config(updated_flow_config)
task_message = flow.package_task_message(recipient_flow=flow,
task_name="run_task",
task_data=sample,
expected_outputs=["code"])
output_message = flow(task_message)
# log.info(f"output_message: {output_message}")
metrics = eval_func(output_message.data['code'], sample)
log.info(f"metrics for dp: {metrics}")
if not final_metrics:
final_metrics = metrics
else:
for k, v in metrics.items():
final_metrics[k] += v
log.info(f"final metric {final_metrics} for this config {search_config}")
return final_metrics
analysis = tune.run(
tune_run_eval,
search_alg=search_alg,
num_samples=num_samples,
log_file_name=log_file_name,
verbose=3,
)
best_search_config = analysis.best_config
flow_config = updated_flow_config_with_search_config(initial_flow_config, best_search_config)
log.info(f"best search config found: {best_search_config}, analysis: {analysis.best_result}")
return flow_config, analysis