PicoAudio2 / utils /accelerate_utilities.py
rookie9's picture
Upload 77 files
f582ec6 verified
raw
history blame contribute delete
464 Bytes
from accelerate import Accelerator
class AcceleratorSaveTrainableParams(Accelerator):
    """Accelerator whose ``get_state_dict`` can be limited to selected keys.

    If the model passed to ``get_state_dict`` has a ``param_names_to_save``
    attribute, only the state-dict entries whose key is contained in that
    collection are returned. Without the attribute, the parent class's
    result is returned unchanged.
    """

    def get_state_dict(self, model, unwrap=True):
        """Return the model's state dict, filtered by ``param_names_to_save``.

        Args:
            model: Model forwarded to ``Accelerator.get_state_dict``. When it
                has a ``param_names_to_save`` attribute, that collection
                selects which keys are kept.
            unwrap: Forwarded unchanged to ``Accelerator.get_state_dict``.

        Returns:
            The parent's state dict as-is when no filtering applies, or a new
            plain ``dict`` holding only the selected entries. A ``None``
            result from the parent is passed through unchanged.
        """
        state_dict = super().get_state_dict(model, unwrap)

        # The parent may hand back no state dict at all (presumably on
        # non-main ranks with sharded backends such as DeepSpeed ZeRO-3 --
        # TODO confirm); pass that through instead of failing on `.items()`.
        if state_dict is None:
            return state_dict

        # NOTE(review): the attribute is looked up on `model` exactly as it
        # was passed in; a wrapper that does not forward attribute access
        # would hide it and disable filtering -- verify against callers.
        if not hasattr(model, "param_names_to_save"):
            return state_dict

        param_names_to_save = model.param_names_to_save
        # Hash sequences once so each membership test below is O(1) rather
        # than a linear scan per state-dict key.
        if isinstance(param_names_to_save, (list, tuple)):
            param_names_to_save = frozenset(param_names_to_save)

        return {
            key: value
            for key, value in state_dict.items()
            if key in param_names_to_save
        }