Jan Philipp Harries commited on
Commit
490923f
1 Parent(s): 5855dde

Save Axolotl config as WandB artifact (#716)

Browse files
src/axolotl/cli/__init__.py CHANGED
@@ -194,6 +194,7 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
194
  # load the config from the yaml file
195
  with open(config, encoding="utf-8") as file:
196
  cfg: DictDefault = DictDefault(yaml.safe_load(file))
 
197
  # if there are any options passed in the cli, if it is something that seems valid from the yaml,
198
  # then overwrite the value
199
  cfg_keys = cfg.keys()
 
194
  # load the config from the yaml file
195
  with open(config, encoding="utf-8") as file:
196
  cfg: DictDefault = DictDefault(yaml.safe_load(file))
197
+ cfg.axolotl_config_path = config
198
  # if there are any options passed in the cli, if it is something that seems valid from the yaml,
199
  # then overwrite the value
200
  cfg_keys = cfg.keys()
src/axolotl/utils/callbacks.py CHANGED
@@ -514,3 +514,27 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
514
  return control
515
 
516
  return LogPredictionCallback
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
514
  return control
515
 
516
  return LogPredictionCallback
517
+
518
+
519
+ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
520
+ """Callback to save axolotl config to wandb"""
521
+
522
+ def __init__(self, axolotl_config_path):
523
+ self.axolotl_config_path = axolotl_config_path
524
+
525
+ def on_train_begin(
526
+ self,
527
+ args: AxolotlTrainingArguments, # pylint: disable=unused-argument
528
+ state: TrainerState, # pylint: disable=unused-argument
529
+ control: TrainerControl,
530
+ **kwargs, # pylint: disable=unused-argument
531
+ ):
532
+ if is_main_process():
533
+ try:
534
+ artifact = wandb.Artifact(name="axolotl-config", type="config")
535
+ artifact.add_file(local_path=self.axolotl_config_path)
536
+ wandb.run.log_artifact(artifact)
537
+ LOG.info("Axolotl config has been saved to WandB as an artifact.")
538
+ except (FileNotFoundError, ConnectionError) as err:
539
+ LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
540
+ return control
src/axolotl/utils/trainer.py CHANGED
@@ -30,6 +30,7 @@ from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
30
  from axolotl.utils.callbacks import (
31
  EvalFirstStepCallback,
32
  GPUStatsCallback,
 
33
  SaveBetterTransformerModelCallback,
34
  bench_eval_callback_factory,
35
  log_prediction_callback_factory,
@@ -775,6 +776,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
775
  LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer)
776
  trainer.add_callback(LogPredictionCallback(cfg))
777
 
 
 
 
778
  if cfg.do_bench_eval:
779
  trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
780
 
 
30
  from axolotl.utils.callbacks import (
31
  EvalFirstStepCallback,
32
  GPUStatsCallback,
33
+ SaveAxolotlConfigtoWandBCallback,
34
  SaveBetterTransformerModelCallback,
35
  bench_eval_callback_factory,
36
  log_prediction_callback_factory,
 
776
  LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer)
777
  trainer.add_callback(LogPredictionCallback(cfg))
778
 
779
+ if cfg.use_wandb:
780
+ trainer.add_callback(SaveAxolotlConfigtoWandBCallback(cfg.axolotl_config_path))
781
+
782
  if cfg.do_bench_eval:
783
  trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
784