tmm1 commited on
Commit
267b7b2
1 Parent(s): 98bf76e

simplify linear layer locator

Browse files
Files changed (1) hide show
  1. src/axolotl/utils/models.py +3 -13
src/axolotl/utils/models.py CHANGED
@@ -464,12 +464,8 @@ def load_llama_adapter(model, cfg):
464
  return model, peft_config
465
 
466
 
467
- def find_all_linear_names(bits, model):
468
- cls = (
469
- bnb.nn.Linear4bit
470
- if bits == 4
471
- else (bnb.nn.Linear8bitLt if bits == 8 else torch.nn.Linear)
472
- )
473
  lora_module_names = set()
474
  for name, module in model.named_modules():
475
  if isinstance(module, cls):
@@ -490,13 +486,7 @@ def load_lora(model, cfg):
490
  lora_target_modules = list(cfg.lora_target_modules or [])
491
 
492
  if cfg.lora_target_linear:
493
- bits = None
494
- if cfg.load_in_4bit:
495
- bits = 4
496
- elif cfg.load_in_8bit:
497
- bits = 8
498
-
499
- linear_names = find_all_linear_names(bits, model)
500
  LOG.info(f"found linear modules: {repr(linear_names)}")
501
  lora_target_modules = list(set(lora_target_modules + linear_names))
502
 
 
464
  return model, peft_config
465
 
466
 
467
+ def find_all_linear_names(model):
468
+ cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
 
 
 
 
469
  lora_module_names = set()
470
  for name, module in model.named_modules():
471
  if isinstance(module, cls):
 
486
  lora_target_modules = list(cfg.lora_target_modules or [])
487
 
488
  if cfg.lora_target_linear:
489
+ linear_names = find_all_linear_names(model)
 
 
 
 
 
 
490
  LOG.info(f"found linear modules: {repr(linear_names)}")
491
  lora_target_modules = list(set(lora_target_modules + linear_names))
492