add logging and make sure model unloads to float16

- scripts/finetune.py +1 -0
- src/axolotl/utils/validation.py +6 -0

scripts/finetune.py CHANGED

@@ -176,6 +176,7 @@ def train(
     if "merge_lora" in kwargs and cfg.adapter is not None:
         logging.info("running merge of LoRA with base model")
         model = model.merge_and_unload()
+        model.to(dtype=torch.float16)

         if cfg.local_rank == 0:
             logging.info("saving merged model")
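As a rough illustration of the flow this hunk touches, here is a minimal sketch of merging a LoRA adapter into its base model and casting the result to float16 before saving. The model name and directories are hypothetical, and this is not the script's actual code path:

```python
# Sketch only: merge LoRA weights into the base model, force float16, save.
# Model name and paths below are hypothetical placeholders.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

base_model_name = "huggyllama/llama-7b"  # hypothetical base checkpoint
adapter_dir = "./lora-out"               # hypothetical LoRA adapter directory
merged_dir = "./merged"                  # hypothetical output directory

base = AutoModelForCausalLM.from_pretrained(base_model_name, torch_dtype=torch.float16)
model = PeftModel.from_pretrained(base, adapter_dir)

model = model.merge_and_unload()   # fold LoRA deltas into the base weights
model.to(dtype=torch.float16)      # ensure the merged weights are fp16
model.save_pretrained(merged_dir)
```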
    	
src/axolotl/utils/validation.py CHANGED

@@ -1,3 +1,6 @@
+import logging
+
+
 def validate_config(cfg):
     if cfg.adapter == "qlora":
         if cfg.merge_lora:
@@ -9,6 +12,9 @@ def validate_config(cfg):
             assert cfg.load_in_8bit is False
             assert cfg.load_4bit is False
             assert cfg.load_in_4bit is True
+    if not cfg.load_in_8bit and cfg.adapter == "lora":
+        logging.warning("we recommend setting `load_in_8bit: true`")
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
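And a small sketch of how the new warning path behaves. It restates only the added check, using a hypothetical attribute-style stand-in for axolotl's dict-like config object:

```python
# Sketch only: restates just the check added in the hunk above.
# SimpleNamespace is a hypothetical stand-in for axolotl's cfg object.
import logging
from types import SimpleNamespace

logging.basicConfig(level=logging.WARNING)


def check_lora_8bit(cfg):
    # LoRA without 8-bit loading gets the recommendation logged
    if not cfg.load_in_8bit and cfg.adapter == "lora":
        logging.warning("we recommend setting `load_in_8bit: true`")


check_lora_8bit(SimpleNamespace(adapter="lora", load_in_8bit=False))  # warns
check_lora_8bit(SimpleNamespace(adapter="lora", load_in_8bit=True))   # silent
```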