Spaces:
Running
on
Zero
Running
on
Zero
File size: 464 Bytes
f582ec6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
from accelerate import Accelerator
class AcceleratorSaveTrainableParams(Accelerator):
    """``Accelerator`` whose state dict can be restricted to a subset of parameters.

    If the model (or, when ``unwrap`` is true, the model underneath a
    wrapper) exposes a ``param_names_to_save`` attribute, ``get_state_dict``
    returns only the entries whose keys appear in it. Anything that saves
    through ``get_state_dict`` then writes just those entries.
    Presumably ``param_names_to_save`` lists the trainable parameters of a
    partially frozen model (per the class name) -- confirm against the
    model that defines it.
    """

    def get_state_dict(self, model, unwrap=True):
        """Return the model's state dict, filtered to ``param_names_to_save`` if present.

        Args:
            model: The model whose state dict is requested; may be a
                prepared/wrapped model.
            unwrap: Forwarded unchanged to ``Accelerator.get_state_dict``.

        Returns:
            The parent's state dict unchanged when no ``param_names_to_save``
            attribute is found (or it is ``None``); otherwise a new dict
            holding only the entries whose keys are listed in it. Listed
            names that are absent from the state dict are ignored.
        """
        state_dict = super().get_state_dict(model, unwrap=unwrap)

        names = getattr(model, "param_names_to_save", None)
        if names is None and unwrap:
            # Wrappers such as torch's DistributedDataParallel do not forward
            # arbitrary attribute access to the wrapped module, so the
            # attribute would be missed on a prepared multi-GPU model and the
            # full state dict saved. Only fall back to the unwrapped model
            # when unwrap=True, where the parent builds the state dict from
            # that unwrapped model, so the keys should match its names.
            names = getattr(self.unwrap_model(model), "param_names_to_save", None)
        if names is None:
            return state_dict

        # Materialize once: O(1) membership per key (the names may be stored
        # as a list), and safe even if they are a one-shot iterable.
        keep = set(names)
        return {k: v for k, v in state_dict.items() if k in keep}
|