"""Strip the optimizer state from a WaveGRU vocoder checkpoint, in place.

The checkpoint is a pickled dict. Its ``"optim_state_dict"`` entry is removed,
the remaining leaves are fetched to host memory with ``jax.device_get``, and the
dict is pickled back to the same path.
"""

import os
import pickle
import shutil
import sys
import tempfile

import jax

CHECKPOINT_PATH = "./wavegru_vocoder_tpu_gta_preemphasis_pruning_0800000.ckpt"


def strip_optimizer_state(path=CHECKPOINT_PATH):
    """Remove ``"optim_state_dict"`` from the pickled checkpoint dict at *path*.

    The file is rewritten atomically. The new pickle is written to a temporary
    file in the same directory, which is then moved over *path* with
    ``os.replace``. If serialization fails partway, the original checkpoint is
    left intact rather than truncated.

    Args:
        path: Path to a pickled checkpoint dict.

    Returns:
        True if an ``"optim_state_dict"`` entry was removed. False if the
        checkpoint had none; the file is still rewritten in that case.
    """
    # NOTE: pickle.load can execute arbitrary code embedded in the file; only
    # run this on checkpoints you produced or otherwise trust.
    with open(path, "rb") as f:
        dic = pickle.load(f)

    # Drop the optimizer state *before* device_get so it is never copied to host.
    removed = "optim_state_dict" in dic
    if removed:
        del dic["optim_state_dict"]
    dic = jax.device_get(dic)

    # Write beside the target so os.replace stays on one filesystem (atomic).
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(dic, f)
        # mkstemp creates the file with mode 0600; keep the checkpoint's mode.
        shutil.copymode(path, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Remove the partial temp file; the original checkpoint is untouched.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
    return removed


def main(argv=None):
    """Strip the checkpoint named by the first CLI argument (default: CHECKPOINT_PATH)."""
    args = sys.argv[1:] if argv is None else argv
    path = args[0] if args else CHECKPOINT_PATH
    if not strip_optimizer_state(path):
        print(f"{path}: no 'optim_state_dict' entry found; nothing removed")


if __name__ == "__main__":
    main()