fix(config): passing gradient_checkpoint_kwargs (#1412)
* fix(config): change default use_reentrant to true
* Update trainer_builder.py
* fix: make sure to pass kwargs to enable checkpoint
* chore: lint
- README.md +1 -1
- src/axolotl/core/trainer_builder.py +0 -4
- src/axolotl/utils/models.py +3 -1
 
    	
README.md CHANGED

@@ -859,7 +859,7 @@ group_by_length: false
 gradient_checkpointing: false
 # additional kwargs to pass to the trainer for gradient checkpointing
 # gradient_checkpointing_kwargs:
-#   use_reentrant:
+#   use_reentrant: true

 # Stop training after this many evaluation losses have increased in a row
 # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
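
The README example now shows use_reentrant: true instead of leaving the value blank. As a rough illustration (not part of this PR), the flag maps down to PyTorch's activation-checkpointing API, where it selects between the original reentrant implementation and the newer non-reentrant one:

    # Sketch only: what use_reentrant ultimately controls in PyTorch.
    import torch
    from torch.utils.checkpoint import checkpoint

    layer = torch.nn.Linear(16, 16)
    x = torch.randn(4, 16, requires_grad=True)

    # Reentrant variant, matching the README example (use_reentrant: true).
    y = checkpoint(layer, x, use_reentrant=True)

    # Non-reentrant variant, the value the removed trainer_builder fallback forced.
    y2 = checkpoint(layer, x, use_reentrant=False)

    (y.sum() + y2.sum()).backward()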
    	
src/axolotl/core/trainer_builder.py CHANGED

@@ -970,10 +970,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 training_arguments_kwargs[
                     "gradient_checkpointing_kwargs"
                 ] = self.cfg.gradient_checkpointing_kwargs
-            else:
-                training_arguments_kwargs["gradient_checkpointing_kwargs"] = {
-                    "use_reentrant": False
-                }
         if self.cfg.fsdp:
             training_arguments_kwargs["fsdp"] = self.cfg.fsdp
             if self.cfg.fsdp_config:
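
With the else branch removed, the builder only sets gradient_checkpointing_kwargs on the TrainingArguments when the config actually provides it; otherwise transformers falls back to its own default instead of axolotl forcing use_reentrant=False. A minimal sketch of the resulting flow, using stand-in values for the cfg fields and assuming a transformers release whose TrainingArguments accepts gradient_checkpointing_kwargs (4.35+):

    # Sketch with stand-in cfg values, not the builder's actual code path.
    from transformers import TrainingArguments

    cfg_gradient_checkpointing = True
    cfg_gradient_checkpointing_kwargs = {"use_reentrant": True}  # e.g. from the YAML config

    training_arguments_kwargs = {
        "gradient_checkpointing": cfg_gradient_checkpointing,
    }
    # Only forward the kwargs when the user set them; no forced
    # {"use_reentrant": False} fallback anymore.
    if cfg_gradient_checkpointing_kwargs is not None:
        training_arguments_kwargs[
            "gradient_checkpointing_kwargs"
        ] = cfg_gradient_checkpointing_kwargs

    args = TrainingArguments(output_dir="outputs", **training_arguments_kwargs)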
    	
src/axolotl/utils/models.py CHANGED

@@ -888,7 +888,9 @@ def load_model(

     if cfg.adapter in ["lora", "qlora"]:
         if cfg.gradient_checkpointing:
-            model.gradient_checkpointing_enable()
+            model.gradient_checkpointing_enable(
+                gradient_checkpointing_kwargs=cfg.gradient_checkpointing_kwargs
+            )
         if (
             cfg.load_in_8bit or cfg.load_in_4bit
         ) and not skip_prepare_model_for_kbit_training:
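
The model-side change forwards the same config dict into transformers when gradient checkpointing is enabled for LoRA/QLoRA adapters. A minimal sketch, using a small illustrative checkpoint and assuming a transformers version whose gradient_checkpointing_enable accepts the gradient_checkpointing_kwargs argument (4.35+):

    # Sketch only: the tiny model name is just an illustrative stand-in.
    from transformers import AutoModelForCausalLM

    gradient_checkpointing_kwargs = {"use_reentrant": True}  # cfg.gradient_checkpointing_kwargs

    model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
    model.gradient_checkpointing_enable(
        gradient_checkpointing_kwargs=gradient_checkpointing_kwargs
    )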