"""Strip the optimizer state from a pickled training checkpoint.

Loads a pickled checkpoint dict, removes its optimizer-state entry, passes the
remaining tree through ``jax.device_get`` and pickles the result back out
(by default over the original file).

Usage::

    python strip_ckpt.py [CHECKPOINT_PATH]
"""

import os
import pickle
import sys

import jax

CHECKPOINT_PATH = "./mono_tts_cbhg_small_0700000.ckpt"
OPTIM_STATE_KEY = "optim_state_dict"


def strip_optimizer_state(src, dst=None, key=OPTIM_STATE_KEY):
    """Remove ``key`` from the checkpoint at ``src`` and write it to ``dst``.

    Args:
        src: Path of the pickled checkpoint to read. It must unpickle to a
            dict-like object supporting ``pop``.
        dst: Path to write the stripped checkpoint to. Defaults to ``src``
            (in-place rewrite).
        key: Entry to remove. Defaults to ``"optim_state_dict"``.

    Returns:
        True if ``key`` was present and removed, and False if it was already
        absent. The checkpoint is rewritten in both cases, so re-running on a
        stripped file is harmless.

    Raises:
        OSError: If ``src`` cannot be read or ``dst`` cannot be written.
        pickle.UnpicklingError: If ``src`` is not a valid pickle.
    """
    if dst is None:
        dst = src

    # SECURITY: pickle.load can execute arbitrary code. Only run this on
    # checkpoints from a trusted source.
    with open(src, "rb") as f:
        ckpt = pickle.load(f)

    removed = key in ckpt
    ckpt.pop(key, None)

    # Presumably done so that the saved pickle holds host (NumPy) values
    # rather than device arrays. Leaves without __array__ pass through as-is.
    ckpt = jax.device_get(ckpt)

    # Write to a sibling temp file, then atomically swap it in. This way a
    # failure mid-dump cannot corrupt the original when dst == src.
    tmp = dst + ".tmp"
    try:
        with open(tmp, "wb") as f:
            pickle.dump(ckpt, f)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp, dst)
    finally:
        # Only still present if something failed before os.replace.
        if os.path.exists(tmp):
            os.remove(tmp)

    return removed


if __name__ == "__main__":
    path = sys.argv[1] if len(sys.argv) > 1 else CHECKPOINT_PATH
    if not strip_optimizer_state(path):
        print(f"note: {OPTIM_STATE_KEY!r} was not present in {path}")