scope flash-attn+qlora fix correctly, scope to llama, add comment
src/axolotl/utils/models.py CHANGED
@@ -333,13 +333,15 @@ def load_model(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )
 
-                if hasattr(module, "weight"):
+        # LlamaRMSNorm layers are in fp32 after kbit call, so we need to
+        # convert them back to fp16/bf16 for flash-attn compatibility.
+        if cfg.flash_attention and cfg.is_llama_derived_model:
+            for name, module in model.named_modules():
+                if "norm" in name:
                     module.to(torch_dtype)
+                if "lm_head" in name or "embed_tokens" in name:
+                    if hasattr(module, "weight"):
+                        module.to(torch_dtype)
 
     model, lora_config = load_adapter(model, cfg, adapter)
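
Read on its own, the added hunk is a dtype cleanup pass that runs after k-bit preparation. The sketch below shows the same pattern outside axolotl's load path, purely for illustration; the checkpoint name, `device_map`, and bf16 compute dtype are assumptions, not taken from this commit.

```python
import torch
from peft import prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

torch_dtype = torch.bfloat16  # assumed compute dtype for this sketch

# Hypothetical Llama-derived checkpoint; the module names the cast matches
# ("norm", "embed_tokens", "lm_head") come from the Llama architecture.
model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",
    device_map="auto",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch_dtype,
    ),
)

# prepare_model_for_kbit_training upcasts norm weights to fp32 for training
# stability, while flash-attn kernels expect fp16/bf16 tensors.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

# Same cast the diff adds inside load_model: bring the norm and embedding
# modules back down to the compute dtype.
for name, module in model.named_modules():
    if "norm" in name:
        module.to(torch_dtype)
    if "lm_head" in name or "embed_tokens" in name:
        if hasattr(module, "weight"):
            module.to(torch_dtype)
```

Scoping the cast with `cfg.is_llama_derived_model` fits the matched names: `"norm"`, `"embed_tokens"`, and `"lm_head"` are Llama module names (e.g. `LlamaRMSNorm`), so other architectures are left untouched. Guarding on `cfg.flash_attention` keeps the downcast from running when flash-attn is not in use and the fp32 norms are harmless.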