JohanWork
commited on
Commit
•
b8e5603
1
Parent(s):
782b6a4
Add mlflow callback for pushing config to mlflow artifacts (#1125)
Browse files* Update callbacks.py
adding callback for mlflow
* Update trainer_builder.py
* clean up
src/axolotl/core/trainer_builder.py
CHANGED
@@ -28,6 +28,7 @@ from axolotl.utils.callbacks import (
|
|
28 |
EvalFirstStepCallback,
|
29 |
GPUStatsCallback,
|
30 |
LossWatchDogCallback,
|
|
|
31 |
SaveAxolotlConfigtoWandBCallback,
|
32 |
SaveBetterTransformerModelCallback,
|
33 |
bench_eval_callback_factory,
|
@@ -543,6 +544,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
543 |
callbacks.append(
|
544 |
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
545 |
)
|
|
|
|
|
|
|
|
|
546 |
|
547 |
if self.cfg.loss_watchdog_threshold is not None:
|
548 |
callbacks.append(LossWatchDogCallback(self.cfg))
|
|
|
28 |
EvalFirstStepCallback,
|
29 |
GPUStatsCallback,
|
30 |
LossWatchDogCallback,
|
31 |
+
SaveAxolotlConfigtoMlflowCallback,
|
32 |
SaveAxolotlConfigtoWandBCallback,
|
33 |
SaveBetterTransformerModelCallback,
|
34 |
bench_eval_callback_factory,
|
|
|
544 |
callbacks.append(
|
545 |
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
546 |
)
|
547 |
+
if self.cfg.use_mlflow:
|
548 |
+
callbacks.append(
|
549 |
+
SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
|
550 |
+
)
|
551 |
|
552 |
if self.cfg.loss_watchdog_threshold is not None:
|
553 |
callbacks.append(LossWatchDogCallback(self.cfg))
|
src/axolotl/utils/callbacks.py
CHANGED
@@ -9,6 +9,7 @@ from tempfile import NamedTemporaryFile
|
|
9 |
from typing import TYPE_CHECKING, Dict, List
|
10 |
|
11 |
import evaluate
|
|
|
12 |
import numpy as np
|
13 |
import pandas as pd
|
14 |
import torch
|
@@ -575,3 +576,31 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
|
575 |
except (FileNotFoundError, ConnectionError) as err:
|
576 |
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
577 |
return control
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
from typing import TYPE_CHECKING, Dict, List
|
10 |
|
11 |
import evaluate
|
12 |
+
import mlflow
|
13 |
import numpy as np
|
14 |
import pandas as pd
|
15 |
import torch
|
|
|
576 |
except (FileNotFoundError, ConnectionError) as err:
|
577 |
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
578 |
return control
|
579 |
+
|
580 |
+
|
581 |
+
class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
|
582 |
+
"""Callback to save axolotl config to mlflow"""
|
583 |
+
|
584 |
+
def __init__(self, axolotl_config_path):
|
585 |
+
self.axolotl_config_path = axolotl_config_path
|
586 |
+
|
587 |
+
def on_train_begin(
|
588 |
+
self,
|
589 |
+
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
|
590 |
+
state: TrainerState, # pylint: disable=unused-argument
|
591 |
+
control: TrainerControl,
|
592 |
+
**kwargs, # pylint: disable=unused-argument
|
593 |
+
):
|
594 |
+
if is_main_process():
|
595 |
+
try:
|
596 |
+
with NamedTemporaryFile(
|
597 |
+
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
|
598 |
+
) as temp_file:
|
599 |
+
copyfile(self.axolotl_config_path, temp_file.name)
|
600 |
+
mlflow.log_artifact(temp_file.name, artifact_path="")
|
601 |
+
LOG.info(
|
602 |
+
"The Axolotl config has been saved to the MLflow artifacts."
|
603 |
+
)
|
604 |
+
except (FileNotFoundError, ConnectionError) as err:
|
605 |
+
LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")
|
606 |
+
return control
|