Nanobit committed on
Commit
cfbce02
·
unverified ·
1 Parent(s): 67b9888

Fix: Fail bf16 check when running on cpu during merge (#631)

Browse files
src/axolotl/utils/config.py CHANGED
@@ -94,7 +94,7 @@ def validate_config(cfg):
94
  if not cfg.bf16 and not cfg.bfloat16:
95
  LOG.info("bf16 support detected, but not enabled for this configuration.")
96
  else:
97
- if cfg.bf16 or cfg.bfloat16:
98
  raise ValueError(
99
  "bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
100
  )
 
94
  if not cfg.bf16 and not cfg.bfloat16:
95
  LOG.info("bf16 support detected, but not enabled for this configuration.")
96
  else:
97
+ if not cfg.merge_lora and (cfg.bf16 or cfg.bfloat16):
98
  raise ValueError(
99
  "bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
100
  )
tests/test_validation.py CHANGED
@@ -351,3 +351,26 @@ class ValidationTest(unittest.TestCase):
351
  regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*"
352
  with pytest.raises(ValueError, match=regex_exp):
353
  validate_config(cfg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
351
  regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*"
352
  with pytest.raises(ValueError, match=regex_exp):
353
  validate_config(cfg)
354
+
355
+ def test_merge_lora_no_bf16_fail(self):
356
+ """
357
+ This is assumed to be run on a CPU machine, so bf16 is not supported.
358
+ """
359
+
360
+ cfg = DictDefault(
361
+ {
362
+ "bf16": True,
363
+ }
364
+ )
365
+
366
+ with pytest.raises(ValueError, match=r".*AMP is not supported on this GPU*"):
367
+ validate_config(cfg)
368
+
369
+ cfg = DictDefault(
370
+ {
371
+ "bf16": True,
372
+ "merge_lora": True,
373
+ }
374
+ )
375
+
376
+ validate_config(cfg)