Spaces:
Running
Running
import datetime | |
import os | |
from TTS.utils.io import save_fsspec | |
def save_checkpoint(model, optimizer, model_loss, out_path, current_step):
    """Write a step-numbered training checkpoint under ``out_path``.

    The checkpoint file is named ``checkpoint_<current_step>.pth`` and is
    handed to ``save_fsspec`` as a dict with the keys ``model``,
    ``optimizer``, ``step``, ``loss`` and ``date``.

    Args:
        model: Object exposing ``state_dict()``; its result is stored
            under ``"model"``.
        optimizer: Object exposing ``state_dict()``, or ``None``. When it
            is ``None`` the ``"optimizer"`` entry is stored as ``None``.
        model_loss: Loss value recorded under ``"loss"``.
        out_path: Directory the checkpoint file name is joined onto.
        current_step: Step number used in the file name and stored under
            ``"step"``.
    """
    target = os.path.join(out_path, "checkpoint_{}.pth".format(current_step))
    print(" | | > Checkpoint saving : {}".format(target))

    model_state = model.state_dict()
    # The optimizer is optional; keep the key present so the layout of the
    # saved dict does not depend on whether one was supplied.
    if optimizer is not None:
        optimizer_state = optimizer.state_dict()
    else:
        optimizer_state = None

    payload = {
        "model": model_state,
        "optimizer": optimizer_state,
        "step": current_step,
        "loss": model_loss,
        # Human-readable local date, e.g. "January 05, 2024".
        "date": datetime.date.today().strftime("%B %d, %Y"),
    }
    save_fsspec(payload, target)
def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step):
    """Save the model as ``best_model.pth`` when its loss beats ``best_loss``.

    Nothing is written unless ``model_loss < best_loss``. On improvement a
    dict with the keys ``model``, ``optimizer``, ``step``, ``loss`` and
    ``date`` is handed to ``save_fsspec`` together with the path
    ``<out_path>/best_model.pth``.

    Args:
        model: Object exposing ``state_dict()``; its result is stored
            under ``"model"``.
        optimizer: Object exposing ``state_dict()``, or ``None``. When it
            is ``None`` the ``"optimizer"`` entry is stored as ``None``,
            matching what ``save_checkpoint`` does.
        model_loss: Loss of the current model.
        best_loss: Best (lowest) loss seen so far.
        out_path: Directory the best-model file name is joined onto.
        current_step: Step number stored under ``"step"``.

    Returns:
        The updated best loss: ``model_loss`` if it improved on
        ``best_loss``, otherwise ``best_loss`` unchanged.
    """
    # Written as ``not (a < b)`` rather than ``a >= b`` so that values which
    # compare False both ways (e.g. NaN) are still treated as "no improvement".
    if not model_loss < best_loss:
        return best_loss

    state = {
        "model": model.state_dict(),
        # An optimizer is optional; guard against None instead of raising
        # AttributeError on ``None.state_dict()``.
        "optimizer": optimizer.state_dict() if optimizer is not None else None,
        "step": current_step,
        "loss": model_loss,
        # Human-readable local date, e.g. "January 05, 2024".
        "date": datetime.date.today().strftime("%B %d, %Y"),
    }
    bestmodel_path = os.path.join(out_path, "best_model.pth")
    print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
    save_fsspec(state, bestmodel_path)
    return model_loss