yeeef committed
Commit 7361bee
1 Parent(s): 6542487

add tune method, prepare to refactor it as self method

Files changed (1)
  1. OpenAIChatAtomicFlow.py +276 -1
OpenAIChatAtomicFlow.py CHANGED
@@ -2,11 +2,14 @@ import pprint
 from copy import deepcopy
 
 import hydra
+import logging
 
 import colorama
 import time
 
-from typing import List, Dict, Optional, Any
+from typing import List, Dict, Optional, Any, Callable, Tuple
+
+from flaml import tune, BlendSearch
 
 from langchain import PromptTemplate
 import langchain
@@ -22,6 +25,7 @@ from flows.messages.chat_message import ChatMessage
 from flows.utils.caching_utils import flow_run_cache
 
 log = utils.get_pylogger(__name__)
+logger = log
 
 
 class OpenAIChatAtomicFlow(AtomicFlow):
@@ -43,6 +47,28 @@ class OpenAIChatAtomicFlow(AtomicFlow):
     demonstrations_response_template: PromptTemplate = None
     response_annotators: Optional[Dict[str, MessageAnnotator]] = {}
 
+    default_search_space = {
+        "model": tune.choice(
+            [
+                # "text-ada-001",
+                # "text-babbage-001",
+                # "text-davinci-003",
+                "gpt-3.5-turbo",
+                # "gpt-4",
+            ]
+        ),
+        "temperature_or_top_p": tune.choice(
+            [
+                {"temperature": tune.uniform(0, 2)},
+                {"top_p": tune.uniform(0, 1)},
+            ]
+        ),
+        "max_tokens": tune.lograndint(1000, 4000),
+        # we use langchain api, https://github.com/hwchase17/langchain/blob/master/langchain/chat_models/base.py#L201
+        # it only take the first generation as the output, thus n is not relevant
+        # "n": tune.randint(1, 100),
+    }
+
     def __init__(self, **kwargs):
         self._validate_parameters(kwargs)
         super().__init__(**kwargs)
@@ -321,3 +347,252 @@ class OpenAIChatAtomicFlow(AtomicFlow):
 
         # ~~~ The final answer should be in self.flow_state, thus allow_class_namespace=False ~~~
         return self._get_keys_from_state(keys=expected_outputs, allow_class_namespace=False)
+
+    @classmethod
+    def tune(
+        cls,
+        tune_dps: List[Dict],
+        metric: str,
+        mode: str,
+        eval_func: Callable,
+        api_key: str,
+        log_file_name: Optional[str] = None,  # TODO(yeeef)
+        inference_budget: Optional[float] = None,
+        optimization_budget: Optional[float] = None,
+        num_samples: Optional[int] = 1,
+        logging_level: Optional[int] = logging.WARN,  # TODO(yeeef)
+        **config,
+    ) -> Tuple[Dict, Any]:  # tune.ExperimentAnalysis
+        """
+        Args:
+            - tune_dps (list): The list of data points to tune the hyperparameters.
+            - metric (str): The metric to optimize.
+            - mode (str): The optimization mode, "min" or "max.
+            - eval_func (Callable): The evaluation function for responses.
+                The function should take a response and a data point as input,
+                and return a dict of metrics.
+            - log_file_name (str, optional): The log file.
+            - inference_budget (float, optional): The inference budget, dollar per instance.
+            - optimization_budget (float, optional): The optimization budget, dollar in total.
+            - num_samples (int, optional): The number of samples to evaluate.
+                -1 means no hard restriction in the number of trials
+                and the actual number is decided by optimization_budget. Defaults to 1.
+            - logging_level (optional): logging level. Defaults to logging.WARNING.
+            - **config (dict): The search space to update over the default search.
+                For prompt, please provide a string/Callable or a list of strings/Callables.
+                    - If prompt is provided for chat models, it will be converted to messages under role "user".
+                    - Do not provide both prompt and messages for chat models, but provide either of them.
+                    - A string template will be used to generate a prompt for each data instance
+                      using `prompt.format(**data)`.
+                    - A callable template will be used to generate a prompt for each data instance
+                      using `prompt(data)`.
+                For stop, please provide a string, a list of strings, or a list of lists of strings.
+                For messages (chat models only), please provide a list of messages (for a single chat prefix)
+                or a list of lists of messages (for multiple choices of chat prefix to choose from).
+                Each message should be a dict with keys "role" and "content". The value of "content" can be a string/Callable template.
+
+        Returns:
+            - dict: The optimized hyperparameter setting.
+            - tune.ExperimentAnalysis: The tuning results.
+        """
+
+        space = cls.default_search_space.copy()
+
+        if config is not None:
+            space.update(config)
+            if "messages" in space:
+                space.pop("prompt", None)
+            temperature = space.pop("temperature", None)
+            top_p = space.pop("top_p", None)
+            if temperature is not None and top_p is None:
+                space["temperature_or_top_p"] = {"temperature": temperature}
+            elif temperature is None and top_p is not None:
+                space["temperature_or_top_p"] = {"top_p": top_p}
+            elif temperature is not None and top_p is not None:
+                space.pop("temperature_or_top_p")
+                space["temperature"] = temperature
+                space["top_p"] = top_p
+                log.warning("temperature and top_p are not recommended to vary together.")
+
+        # TODO: shall we use cls method?
+        cls._max_valid_n_per_max_tokens, cls._min_invalid_n_per_max_tokens = {}, {}
+        cls.optimization_budget = optimization_budget
+        cls.inference_budget = inference_budget
+        cls._prune_hp = "best_of" if space.get("best_of", 1) != 1 else "n"
+        cls._prompts = space.get("prompt")
+
+        # if cls._prompts is None:
+        #     cls._messages = space.get("messages")
+        #     assert isinstance(cls._messages, list) and isinstance(
+        #         cls._messages[0], (dict, list)
+        #     ), "messages must be a list of dicts or a list of lists."
+        #     if isinstance(cls._messages[0], dict):
+        #         cls._messages = [cls._messages]
+        #     space["messages"] = tune.choice(list(range(len(cls._messages))))
+        # else:
+        #     assert space.get("messages") is None, "messages and prompt cannot be provided at the same time."
+        #     assert isinstance(cls._prompts, (str, list)), "prompt must be a string or a list of strings."
+        #     if isinstance(cls._prompts, str):
+        #         cls._prompts = [cls._prompts]
+        #     space["prompt"] = tune.choice(list(range(len(cls._prompts))))
+        # cls._stops = space.get("stop")
+        # if cls._stops:
+        #     assert isinstance(
+        #         cls._stops, (str, list)
+        #     ), "stop must be a string, a list of strings, or a list of lists of strings."
+        #     if not (isinstance(cls._stops, list) and isinstance(cls._stops[0], list)):
+        #         cls._stops = [cls._stops]
+        #     space["stop"] = tune.choice(list(range(len(cls._stops))))
+
+        # cls._config_list = space.get("config_list")
+        # if cls._config_list is not None:
+        #     is_const = is_constant(cls._config_list)
+        #     if is_const:
+        #         space.pop("config_list")
+        # cls._metric, cls._mode = metric, mode
+        # cls._total_cost = 0  # total optimization cost
+        # cls._eval_func = eval_func
+        # cls.data = data
+        # cls.avg_input_tokens = None
+
+        space_model = space["model"]
+
+        if not isinstance(space_model, str) and len(space_model) > 1:
+            # make a hierarchical search space
+            subspace = {}
+            if "max_tokens" in space:
+                subspace["max_tokens"] = space.pop("max_tokens")
+            if "temperature_or_top_p" in space:
+                subspace["temperature_or_top_p"] = space.pop("temperature_or_top_p")
+            if "best_of" in space:
+                subspace["best_of"] = space.pop("best_of")
+            if "n" in space:
+                subspace["n"] = space.pop("n")
+            choices = []
+            for model in space["model"]:
+                choices.append({"model": model, **subspace})
+            space["subspace"] = tune.choice(choices)
+            space.pop("model")
+            # start all the models with the same hp config
+            search_alg = BlendSearch(
+                cost_attr="cost",
+                cost_budget=optimization_budget,
+                metric=metric,
+                mode=mode,
+                space=space,
+            )
+            config0 = search_alg.suggest("t0")
+            points_to_evaluate = [config0]
+            for model in space_model:
+                if model != config0["subspace"]["model"]:
+                    point = config0.copy()
+                    point["subspace"] = point["subspace"].copy()
+                    point["subspace"]["model"] = model
+                    points_to_evaluate.append(point)
+            search_alg = BlendSearch(
+                cost_attr="cost",
+                cost_budget=optimization_budget,
+                metric=metric,
+                mode=mode,
+                space=space,
+                points_to_evaluate=points_to_evaluate,
+            )
+        else:
+            # TODO: currently we always falls in this branch
+            search_alg = BlendSearch(
+                cost_attr="cost",
+                cost_budget=optimization_budget,
+                metric=metric,
+                mode=mode,
+                space=space,
+            )
+
+        # Args:
+        #     evaluation_function: A user-defined evaluation function.
+        #         It takes a configuration as input, outputs a evaluation
+        #         result (can be a numerical value or a dictionary of string
+        #         and numerical value pairs) for the input configuration.
+        #         For machine learning tasks, it usually involves training and
+        #         scoring a machine learning model, e.g., through validation loss.
+
+
+        def updated_flow_config_with_search_config(flow_config: Dict[str, Any], search_config: Dict[str, Any]):
+            """
+            inputs are immutable
+            """
+            flow_config = deepcopy(flow_config)
+            search_config = deepcopy(search_config)
+
+            temperature_or_top_p = search_config.pop("temperature_or_top_p", None)
+            if temperature_or_top_p is not None:
+                search_config.update(temperature_or_top_p)
+
+            flow_config["model_name"] = search_config["model"]
+            generation_parameters = flow_config["generation_parameters"]
+            for generation_parameter in generation_parameters:
+                if generation_parameter == "model_kwargs":
+                    continue
+                if generation_parameter in search_config:
+                    generation_parameters[generation_parameter] = search_config[generation_parameter]
+
+            model_kwargs = generation_parameters["model_kwargs"]
+            for model_kwarg in model_kwargs:
+                if model_kwarg in search_config:
+                    model_kwargs[model_kwarg] = search_config[model_kwarg]
+
+            return flow_config
+
+        def tune_run_eval(search_config: Dict[str, Any]) -> Dict[str, float]:
+            """
+            evaluation_function: A user-defined evaluation function.
+                It takes a configuration as input, outputs a evaluation
+                result (can be a numerical value or a dictionary of string
+                and numerical value pairs) for the input configuration.
+                For machine learning tasks, it usually involves training and
+                scoring a machine learning model, e.g., through validation loss.
+            """
+            # extract the flow_construct_kwargs from search_config
+            """
+            {'expected_inputs': [], 'expected_outputs': [], 'flow_type': 'Flow', 'verbose': True, 'dry_run': False, 'namespace_clearing_after_run': True, 'n_api_retries': 6, 'wait_time_between_retries': 20, 'system_name': 'system', 'user_name': 'user', 'assistant_name': 'assistant', 'response_annotators': {'code_extractor': <flows.message_annotators.regex_extractor_first.RegexFirstOccurrenceExtractor object at 0x7f532121bc70>}, 'query_message_prompt_template': {'_target_': 'langchain.PromptTemplate', 'template': '# Problem statement\n{{problem_description}}\n\n# Input description\n{{input_description}}\n\n# Output description\n{{output_description}}\n\n{{io_examples_and_explanation}}\n\n\nThe input should be read from the standard input and the output should be passed to the standard output.\nReturn Python code that solves the problem. Reply in the following format:\n```python\n{{code_placeholder}}\n```', 'input_variables': ['problem_description', 'input_description', 'output_description', 'io_examples_and_explanation'], 'partial_variables': {'code_placeholder': '{{python_code}}'}, 'template_format': 'jinja2'}, 'demonstrations': None, 'demonstrations_response_template': None, 'name': 'CodeAgent', 'description': 'ToDO: add description', 'model_name': 'gpt-3.5-turbo', 'generation_parameters': {'n': 1, 'max_tokens': 3000, 'temperature': 0.3, 'model_kwargs': {'top_p': 0.2, 'frequency_penalty': 0, 'presence_penalty': 0}}, 'system_message_prompt_template': {'_target_': 'langchain.PromptTemplate', 'template': 'Your goal is to provide executable Python code that solves a competitive programming problem. The code should correctly handle all corner cases in order to pass the hidden test cases, which are used to evaluate the correctness of the solution.\n\nThe user will specify the problem by providing you with:\n - the problem statement\n - input description\n - output description\n - example test cases\n - (optional) explanation of the test cases\n\nThe user will provide you with a task and an output format that you will strictly follow.', 'input_variables': [], 'template_format': 'jinja2'}, 'human_message_prompt_template': {'_target_': 'langchain.PromptTemplate', 'template': '{{query}}', 'input_variables': ['query'], 'template_format': 'jinja2'}}
+            """
+            log.info(f"Tunning with config: {search_config}")
+            # TODO: the code currently only works when there is no subspace, i.e. there is only one model to tune with
+            # align search_config with flow_config
+            updated_flow_config = updated_flow_config_with_search_config(flow_config=cls.get_config(), search_config=search_config)
+            log.info(f"Updated flow_config: {updated_flow_config}")
+            # flow_launcher = FlowAPILauncher(flow, 1, False, 3, 0, ["code"]) TODO: maybe refactor with flow_launcher
+
+            # TODO: limitations: langchain api call does not give us the cost of the api call
+            final_metrics = {}
+            for sample in tune_dps:
+                sample["api_key"] = api_key
+                # log.info(f"sample: {sample}")
+                flow = cls.instantiate_from_config(updated_flow_config)
+                task_message = flow.package_task_message(recipient_flow=flow,
+                                                         task_name="run_task",
+                                                         task_data=sample,
+                                                         expected_outputs=["code"])
+                output_message = flow(task_message)
+                # log.info(f"output_message: {output_message}")
+
+                metrics = eval_func(output_message.data['code'], sample)
+                log.info(f"metrics for dp: {metrics}")
+                if not final_metrics:
+                    final_metrics = metrics
+                else:
+                    for k, v in metrics.items():
+                        final_metrics[k] += v
+            log.info(f"final metric {final_metrics} for this config {search_config}")
+            return final_metrics
+
+        analysis = tune.run(
+            tune_run_eval,
+            search_alg=search_alg,
+            num_samples=num_samples,
+            log_file_name=log_file_name,
+            verbose=3,
+        )
+        best_search_config = analysis.best_config
+        flow_config = updated_flow_config_with_search_config(cls.get_config(), best_search_config)
+        log.info(f"best search config found: {best_search_config}, analysis: {analysis.best_result}")
+        return flow_config, analysis
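
For context, the sketch below shows one way the new `tune` classmethod added in this commit might be invoked. It is a minimal, hypothetical example and not part of the commit itself: the data-point fields, the compile-based `eval_func`, the `solved` metric name, the API key placeholder, the budget values, and the `max_tokens` override are all illustrative assumptions.

# Assumes OpenAIChatAtomicFlow has already been imported from wherever this module lives in your project.
from flaml import tune  # only needed to override entries of default_search_space via **config

def eval_func(code: str, sample: dict) -> dict:
    # Hypothetical metric: 1.0 if the generated code at least compiles, 0.0 otherwise.
    try:
        compile(code, "<generated>", "exec")
        return {"solved": 1.0}
    except SyntaxError:
        return {"solved": 0.0}

# Illustrative data points; the fields should match the flow's prompt templates.
tune_dps = [
    {
        "problem_description": "...",
        "input_description": "...",
        "output_description": "...",
        "io_examples_and_explanation": "...",
    },
]

best_flow_config, analysis = OpenAIChatAtomicFlow.tune(
    tune_dps=tune_dps,
    metric="solved",        # a key of the dict returned by eval_func
    mode="max",
    eval_func=eval_func,
    api_key="sk-...",
    inference_budget=0.01,   # dollars per instance (illustrative)
    optimization_budget=1.0,  # dollars in total (illustrative)
    num_samples=-1,          # per the docstring, -1 lets optimization_budget decide the number of trials
    max_tokens=tune.lograndint(500, 3000),  # example of overriding the default search space via **config
)

Note that with the default search space only "gpt-3.5-turbo" is active, so the single-model BlendSearch branch is taken, which matches the "currently we always falls in this branch" TODO in the commit.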