Johan Hansson winglian committed on
Commit 090c24d
1 Parent(s): 651b7a3

Add: mlflow for experiment tracking (#1059) [skip ci]


* Update requirements.txt

Add mlflow

* Update __init__.py

Imports for mlflow

* Update README.md

* Create mlflow_.py (#1)

* Update README.md

* Fix pre-commit checks

* Update README.md

Update mlflow_tracking_uri

* Update trainer_builder.py

Update trainer builder

* chore: lint

* make ternary a bit more readable

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>

README.md CHANGED
@@ -10,7 +10,7 @@ Features:
 - Integrated with xformer, flash attention, rope scaling, and multipacking
 - Works with single GPU or multiple GPUs via FSDP or Deepspeed
 - Easily run with Docker locally or on the cloud
-- Log results and optionally checkpoints to wandb
+- Log results and optionally checkpoints to wandb or mlflow
 - And more!


@@ -695,6 +695,10 @@ wandb_name: # Set the name of your wandb run
 wandb_run_id: # Set the ID of your wandb run
 wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training

+# mlflow configuration if you're using it
+mlflow_tracking_uri: # URI to mlflow
+mlflow_experiment_name: # Your experiment name
+
 # Where to save the full-finetuned model to
 output_dir: ./completed-model

requirements.txt CHANGED
@@ -22,6 +22,7 @@ hf_transfer
 colorama
 numba
 numpy>=1.24.4
+mlflow
 # qlora things
 bert-score==0.3.13
 evaluate==0.4.0
src/axolotl/cli/__init__.py CHANGED
@@ -29,6 +29,7 @@ from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.data import prepare_dataset
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process
+from axolotl.utils.mlflow_ import setup_mlflow_env_vars
 from axolotl.utils.models import load_tokenizer
 from axolotl.utils.tokenization import check_dataset_labels
 from axolotl.utils.trainer import prepare_optim_env
@@ -289,6 +290,9 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
     normalize_config(cfg)

     setup_wandb_env_vars(cfg)
+
+    setup_mlflow_env_vars(cfg)
+
     return cfg


src/axolotl/core/trainer_builder.py CHANGED
@@ -747,7 +747,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             False if self.cfg.ddp else None
         )
         training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
-        training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
+        report_to = None
+        if self.cfg.use_wandb:
+            report_to = "wandb"
+        if self.cfg.use_mlflow:
+            report_to = "mlflow"
+        training_arguments_kwargs["report_to"] = report_to
         training_arguments_kwargs["run_name"] = (
             self.cfg.wandb_name if self.cfg.use_wandb else None
         )
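The hunk above replaces the single ternary with explicit assignments. A minimal standalone sketch of that selection logic, for illustration only (`resolve_report_to` is a hypothetical helper, not part of this commit); because the mlflow check runs last, mlflow takes precedence when both trackers are enabled:

```python
# Hypothetical helper mirroring the report_to selection in the hunk above.
def resolve_report_to(use_wandb: bool, use_mlflow: bool):
    report_to = None
    if use_wandb:
        report_to = "wandb"
    if use_mlflow:
        report_to = "mlflow"  # mlflow wins when both trackers are enabled
    return report_to


assert resolve_report_to(True, False) == "wandb"
assert resolve_report_to(False, True) == "mlflow"
assert resolve_report_to(True, True) == "mlflow"
assert resolve_report_to(False, False) is None
```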
src/axolotl/utils/mlflow_.py ADDED
@@ -0,0 +1,18 @@
+"""Module for mlflow utilities"""
+
+import os
+
+from axolotl.utils.dict import DictDefault
+
+
+def setup_mlflow_env_vars(cfg: DictDefault):
+    for key in cfg.keys():
+        if key.startswith("mlflow_"):
+            value = cfg.get(key, "")
+
+            if value and isinstance(value, str) and len(value) > 0:
+                os.environ[key.upper()] = value
+
+    # Enable mlflow if experiment name is present
+    if cfg.mlflow_experiment_name and len(cfg.mlflow_experiment_name) > 0:
+        cfg.use_mlflow = True
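A minimal usage sketch of the new helper, assuming a config loaded into a `DictDefault` as elsewhere in the repo; the tracking URI and experiment name below are placeholder values, not taken from this commit:

```python
import os

from axolotl.utils.dict import DictDefault
from axolotl.utils.mlflow_ import setup_mlflow_env_vars

# Placeholder config values for illustration only.
cfg = DictDefault(
    {
        "mlflow_tracking_uri": "http://127.0.0.1:5000",
        "mlflow_experiment_name": "axolotl-finetune",
    }
)

setup_mlflow_env_vars(cfg)

# Each non-empty mlflow_* string is exported as an upper-cased environment
# variable, and use_mlflow is enabled because an experiment name is set.
assert os.environ["MLFLOW_TRACKING_URI"] == "http://127.0.0.1:5000"
assert os.environ["MLFLOW_EXPERIMENT_NAME"] == "axolotl-finetune"
assert cfg.use_mlflow is True
```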