File size: 464 Bytes
f582ec6
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from accelerate import Accelerator


class AcceleratorSaveTrainableParams(Accelerator):
    """``Accelerator`` that can restrict saved state dicts to a subset of entries.

    If the model exposes a ``param_names_to_save`` attribute, only state-dict
    entries whose key appears in that collection are returned by
    ``get_state_dict``; otherwise the parent's state dict is returned as-is.
    """

    def get_state_dict(self, model, unwrap=True):
        """Return ``model``'s state dict, filtered by ``param_names_to_save``.

        Args:
            model: The model whose state dict is requested; may be a model
                wrapped by ``Accelerator.prepare``.
            unwrap: Forwarded unchanged to ``Accelerator.get_state_dict``.

        Returns:
            The parent's state dict, restricted to keys listed in the model's
            ``param_names_to_save`` when that attribute is present and not
            ``None``; otherwise the unfiltered state dict.
        """
        state_dict = super().get_state_dict(model, unwrap)

        # When unwrapping, the state dict describes the inner model, and a
        # wrapper (e.g. DistributedDataParallel) does not necessarily expose
        # attributes of the module it wraps. Look on the unwrapped model
        # first, then fall back to the object we were given so an attribute
        # set directly on it is still honoured.
        candidates = (self.unwrap_model(model), model) if unwrap else (model,)
        param_names_to_save = None
        for candidate in candidates:
            param_names_to_save = getattr(candidate, "param_names_to_save", None)
            if param_names_to_save is not None:
                break

        if param_names_to_save is None:
            return state_dict

        # Build the lookup set once so each membership test is O(1) rather
        # than a linear scan of a list per state-dict key.
        keep = set(param_names_to_save)
        return {k: v for k, v in state_dict.items() if k in keep}