winglian committed on
Commit
68601ec
1 Parent(s): 60f5ce0

make sure everything stays in the same dtype when using dpo + FSDP (#1559)

Browse files
src/axolotl/core/trainer_builder.py CHANGED
@@ -54,6 +54,7 @@ from axolotl.utils.collators import (
54
  MambaDataCollator,
55
  V2BatchSamplerDataCollatorForSeq2Seq,
56
  )
 
57
  from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
58
  from axolotl.utils.schedulers import (
59
  get_cosine_schedule_with_min_lr,
@@ -1569,6 +1570,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
1569
  callbacks=self.get_callbacks(),
1570
  **dpo_trainer_kwargs,
1571
  )
 
 
 
1572
  dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
1573
  for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
1574
  dpo_trainer.add_callback(callback)
 
54
  MambaDataCollator,
55
  V2BatchSamplerDataCollatorForSeq2Seq,
56
  )
57
+ from axolotl.utils.models import ensure_dtype
58
  from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
59
  from axolotl.utils.schedulers import (
60
  get_cosine_schedule_with_min_lr,
 
1570
  callbacks=self.get_callbacks(),
1571
  **dpo_trainer_kwargs,
1572
  )
1573
+ if self.cfg.fsdp:
1574
+ ensure_dtype(dpo_trainer.model, dtype=self.cfg.torch_dtype)
1575
+
1576
  dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
1577
  for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
1578
  dpo_trainer.add_callback(callback)
src/axolotl/utils/models.py CHANGED
@@ -993,3 +993,13 @@ def load_lora(model, cfg, inference=False, config_only=False):
993
  setup_quantized_peft_meta_for_training(model)
994
 
995
  return model, lora_config
 
 
 
 
 
 
 
 
 
 
 
993
  setup_quantized_peft_meta_for_training(model)
994
 
995
  return model, lora_config
996
+
997
+
998
def ensure_dtype(model, dtype=torch.bfloat16):
    """Cast every submodule whose weight or bias dtype differs from ``dtype``.

    Walks ``model.named_modules()`` and, for each module that exposes a
    ``weight`` and/or ``bias`` with a ``dtype`` attribute different from
    ``dtype``, calls ``module.to(dtype)`` in place. Modules without such
    attributes (or whose attribute is ``None``) are left untouched.

    Args:
        model: Object exposing ``named_modules()`` (e.g. a ``torch.nn.Module``).
        dtype: Target dtype; defaults to ``torch.bfloat16``.

    Returns:
        None. The model is modified in place.
    """
    for name, module in model.named_modules():
        # getattr with a default replaces the former try/except AttributeError
        # so that errors raised by ``module.to`` itself are no longer hidden.
        weight_dtype = getattr(getattr(module, "weight", None), "dtype", None)
        bias_dtype = getattr(getattr(module, "bias", None), "dtype", None)

        weight_mismatch = weight_dtype is not None and weight_dtype != dtype
        # A bias can be in a different dtype even when the weight already
        # matches; checking only the weight would leave the module mixed.
        bias_mismatch = bias_dtype is not None and bias_dtype != dtype

        if weight_mismatch:
            print(f"Converting module {name}: {weight_dtype} -> {dtype}")
        if bias_mismatch:
            print(f"Converting module {name}.bias: {bias_dtype} -> {dtype}")
        if weight_mismatch or bias_mismatch:
            module.to(dtype)