JohanWork commited on
Commit
b8e5603
1 Parent(s): 782b6a4

Add mlflow callback for pushing config to mlflow artifacts (#1125)

Browse files

* Update callbacks.py

adding callback for mlflow

* Update trainer_builder.py

* clean up

src/axolotl/core/trainer_builder.py CHANGED
@@ -28,6 +28,7 @@ from axolotl.utils.callbacks import (
28
  EvalFirstStepCallback,
29
  GPUStatsCallback,
30
  LossWatchDogCallback,
 
31
  SaveAxolotlConfigtoWandBCallback,
32
  SaveBetterTransformerModelCallback,
33
  bench_eval_callback_factory,
@@ -543,6 +544,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
543
  callbacks.append(
544
  SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
545
  )
 
 
 
 
546
 
547
  if self.cfg.loss_watchdog_threshold is not None:
548
  callbacks.append(LossWatchDogCallback(self.cfg))
 
28
  EvalFirstStepCallback,
29
  GPUStatsCallback,
30
  LossWatchDogCallback,
31
+ SaveAxolotlConfigtoMlflowCallback,
32
  SaveAxolotlConfigtoWandBCallback,
33
  SaveBetterTransformerModelCallback,
34
  bench_eval_callback_factory,
 
544
  callbacks.append(
545
  SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
546
  )
547
+ if self.cfg.use_mlflow:
548
+ callbacks.append(
549
+ SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
550
+ )
551
 
552
  if self.cfg.loss_watchdog_threshold is not None:
553
  callbacks.append(LossWatchDogCallback(self.cfg))
src/axolotl/utils/callbacks.py CHANGED
@@ -9,6 +9,7 @@ from tempfile import NamedTemporaryFile
9
  from typing import TYPE_CHECKING, Dict, List
10
 
11
  import evaluate
 
12
  import numpy as np
13
  import pandas as pd
14
  import torch
@@ -575,3 +576,31 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
575
  except (FileNotFoundError, ConnectionError) as err:
576
  LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
577
  return control
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  from typing import TYPE_CHECKING, Dict, List
10
 
11
  import evaluate
12
+ import mlflow
13
  import numpy as np
14
  import pandas as pd
15
  import torch
 
576
  except (FileNotFoundError, ConnectionError) as err:
577
  LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
578
  return control
579
+
580
+
581
+ class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
582
+ """Callback to save axolotl config to mlflow"""
583
+
584
+ def __init__(self, axolotl_config_path):
585
+ self.axolotl_config_path = axolotl_config_path
586
+
587
+ def on_train_begin(
588
+ self,
589
+ args: AxolotlTrainingArguments, # pylint: disable=unused-argument
590
+ state: TrainerState, # pylint: disable=unused-argument
591
+ control: TrainerControl,
592
+ **kwargs, # pylint: disable=unused-argument
593
+ ):
594
+ if is_main_process():
595
+ try:
596
+ with NamedTemporaryFile(
597
+ mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
598
+ ) as temp_file:
599
+ copyfile(self.axolotl_config_path, temp_file.name)
600
+ mlflow.log_artifact(temp_file.name, artifact_path="")
601
+ LOG.info(
602
+ "The Axolotl config has been saved to the MLflow artifacts."
603
+ )
604
+ except (FileNotFoundError, ConnectionError) as err:
605
+ LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")
606
+ return control