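# Training script for a disentangled voice-conversion model: a VQ content encoder
# trained with a CPC loss, a speaker encoder, an lf0 (pitch) encoder, CLUB-based
# mutual-information estimators between the three representations, and a mel
# decoder. Mixed-precision training via apex.amp is optional (cfg.use_amp).
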
import hydra
from hydra import utils
from itertools import chain
from pathlib import Path
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

from dataset import CPCDataset_sameSeq as CPCDataset
from scheduler import WarmupScheduler
from model_encoder import Encoder, CPCLoss_sameSeq, Encoder_lf0
from model_decoder import Decoder_ac
from model_encoder import SpeakerEncoder as Encoder_spk
from mi_estimators import CLUBSample_group, CLUBSample_reshape

import apex.amp as amp
import os
import time

torch.manual_seed(137)
np.random.seed(137)

def save_checkpoint(encoder, encoder_lf0, cpc, encoder_spk,
                    cs_mi_net, ps_mi_net, cp_mi_net, decoder,
                    optimizer, optimizer_cs_mi_net, optimizer_ps_mi_net, optimizer_cp_mi_net,
                    scheduler, amp, epoch, checkpoint_dir, cfg):
    if cfg.use_amp:
        amp_state_dict = amp.state_dict()
    else:
        amp_state_dict = None
    checkpoint_state = {
        "encoder": encoder.state_dict(),
        "encoder_lf0": encoder_lf0.state_dict(),
        "cpc": cpc.state_dict(),
        "encoder_spk": encoder_spk.state_dict(),
        "ps_mi_net": ps_mi_net.state_dict(),
        "cp_mi_net": cp_mi_net.state_dict(),
        "cs_mi_net": cs_mi_net.state_dict(),
        "decoder": decoder.state_dict(),
        "optimizer": optimizer.state_dict(),
        "optimizer_cs_mi_net": optimizer_cs_mi_net.state_dict(),
        "optimizer_ps_mi_net": optimizer_ps_mi_net.state_dict(),
        "optimizer_cp_mi_net": optimizer_cp_mi_net.state_dict(),
        "scheduler": scheduler.state_dict(),
        "amp": amp_state_dict,
        "epoch": epoch
    }
    checkpoint_dir.mkdir(exist_ok=True, parents=True)
    checkpoint_path = checkpoint_dir / "model.ckpt-{}.pt".format(epoch)
    torch.save(checkpoint_state, checkpoint_path)
    print("Saved checkpoint: {}".format(checkpoint_path.stem))

def mi_first_forward(mels, lf0, encoder, encoder_lf0, encoder_spk, cs_mi_net, optimizer_cs_mi_net,
                     ps_mi_net, optimizer_ps_mi_net, cp_mi_net, optimizer_cp_mi_net, cfg):
    optimizer_cs_mi_net.zero_grad()
    optimizer_ps_mi_net.zero_grad()
    optimizer_cp_mi_net.zero_grad()
    z, _, _, _, _ = encoder(mels)
    z = z.detach()
    lf0_embs = encoder_lf0(lf0).detach()
    spk_embs = encoder_spk(mels).detach()

    if cfg.use_CSMI:
        lld_cs_loss = -cs_mi_net.loglikeli(spk_embs, z)
        if cfg.use_amp:
            with amp.scale_loss(lld_cs_loss, optimizer_cs_mi_net) as scaled_loss:
                scaled_loss.backward()
        else:
            lld_cs_loss.backward()
        optimizer_cs_mi_net.step()
    else:
        lld_cs_loss = torch.tensor(0.)

    if cfg.use_CPMI:
        lld_cp_loss = -cp_mi_net.loglikeli(
            lf0_embs.unsqueeze(1).reshape(lf0_embs.shape[0], -1, 2, lf0_embs.shape[-1]).mean(2), z)
        if cfg.use_amp:
            with amp.scale_loss(lld_cp_loss, optimizer_cp_mi_net) as scaled_loss:
                scaled_loss.backward()
        else:
            lld_cp_loss.backward()
        torch.nn.utils.clip_grad_norm_(cp_mi_net.parameters(), 1)
        optimizer_cp_mi_net.step()
    else:
        lld_cp_loss = torch.tensor(0.)

    if cfg.use_PSMI:
        lld_ps_loss = -ps_mi_net.loglikeli(spk_embs, lf0_embs)
        if cfg.use_amp:
            with amp.scale_loss(lld_ps_loss, optimizer_ps_mi_net) as scaled_loss:
                scaled_loss.backward()
        else:
            lld_ps_loss.backward()
        optimizer_ps_mi_net.step()
    else:
        lld_ps_loss = torch.tensor(0.)

    return optimizer_cs_mi_net, lld_cs_loss, optimizer_ps_mi_net, lld_ps_loss, optimizer_cp_mi_net, lld_cp_loss

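# Step 2 of each training iteration: update the encoders and decoder by minimising
# reconstruction + CPC + VQ losses plus the weighted CLUB MI estimates (upper bounds
# on mutual information) between speaker/content, pitch/content and speaker/pitch.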
def mi_second_forward(mels, lf0, encoder, encoder_lf0, cpc, encoder_spk, cs_mi_net, ps_mi_net,
                      cp_mi_net, decoder, cfg, optimizer, scheduler):
    optimizer.zero_grad()
    z, c, _, vq_loss, perplexity = encoder(mels)
    cpc_loss, accuracy = cpc(z, c)
    spk_embs = encoder_spk(mels)
    lf0_embs = encoder_lf0(lf0)
    recon_loss, pred_mels = decoder(z, lf0_embs, spk_embs, mels.transpose(1, 2))
    loss = recon_loss + cpc_loss + vq_loss

    if cfg.use_CSMI:
        mi_cs_loss = cfg.mi_weight * cs_mi_net.mi_est(spk_embs, z)
    else:
        mi_cs_loss = torch.tensor(0.).to(loss.device)
    if cfg.use_CPMI:
        mi_cp_loss = cfg.mi_weight * cp_mi_net.mi_est(
            lf0_embs.unsqueeze(1).reshape(lf0_embs.shape[0], -1, 2, lf0_embs.shape[-1]).mean(2), z)
    else:
        mi_cp_loss = torch.tensor(0.).to(loss.device)
    if cfg.use_PSMI:
        mi_ps_loss = cfg.mi_weight * ps_mi_net.mi_est(spk_embs, lf0_embs)
    else:
        mi_ps_loss = torch.tensor(0.).to(loss.device)

    loss = loss + mi_cs_loss + mi_ps_loss + mi_cp_loss
    if cfg.use_amp:
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()
    optimizer.step()
    return optimizer, recon_loss, vq_loss, cpc_loss, accuracy, perplexity, mi_cs_loss, mi_ps_loss, mi_cp_loss

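# Validation-time counterpart of the two training steps above: computes the same
# losses under torch.no_grad() without any parameter updates.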
def calculate_eval_loss(mels, lf0,
                        encoder, encoder_lf0, cpc,
                        encoder_spk, cs_mi_net, ps_mi_net,
                        cp_mi_net, decoder, cfg):
    with torch.no_grad():
        z, c, z_beforeVQ, vq_loss, perplexity = encoder(mels)
        lf0_embs = encoder_lf0(lf0)
        spk_embs = encoder_spk(mels)

        if cfg.use_CSMI:
            lld_cs_loss = -cs_mi_net.loglikeli(spk_embs, z)
            mi_cs_loss = cfg.mi_weight * cs_mi_net.mi_est(spk_embs, z)
        else:
            lld_cs_loss = torch.tensor(0.)
            mi_cs_loss = torch.tensor(0.)

        cpc_loss, accuracy = cpc(z, c)
        recon_loss, pred_mels = decoder(z, lf0_embs, spk_embs, mels.transpose(1, 2))

        if cfg.use_CPMI:
            mi_cp_loss = cfg.mi_weight * cp_mi_net.mi_est(
                lf0_embs.unsqueeze(1).reshape(lf0_embs.shape[0], -1, 2, lf0_embs.shape[-1]).mean(2), z)
            lld_cp_loss = -cp_mi_net.loglikeli(
                lf0_embs.unsqueeze(1).reshape(lf0_embs.shape[0], -1, 2, lf0_embs.shape[-1]).mean(2), z)
        else:
            mi_cp_loss = torch.tensor(0.)
            lld_cp_loss = torch.tensor(0.)

        if cfg.use_PSMI:
            mi_ps_loss = cfg.mi_weight * ps_mi_net.mi_est(spk_embs, lf0_embs)
            lld_ps_loss = -ps_mi_net.loglikeli(spk_embs, lf0_embs)
        else:
            mi_ps_loss = torch.tensor(0.)
            lld_ps_loss = torch.tensor(0.)

    return recon_loss, vq_loss, cpc_loss, accuracy, perplexity, mi_cs_loss, lld_cs_loss, mi_ps_loss, lld_ps_loss, mi_cp_loss, lld_cp_loss

def to_eval(all_models):
    for m in all_models:
        m.eval()


def to_train(all_models):
    for m in all_models:
        m.train()

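# Runs the validation set, accumulating running averages of all losses and CPC
# accuracies, prints a summary and appends it to results.txt in checkpoint_dir.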
def eval_model(epoch, checkpoint_dir, device, valid_dataloader, encoder, encoder_lf0, cpc, encoder_spk,
               cs_mi_net, ps_mi_net, cp_mi_net, decoder, cfg):
    stime = time.time()
    average_cpc_loss = average_vq_loss = average_perplexity = average_recon_loss = 0
    average_accuracies = np.zeros(cfg.training.n_prediction_steps)
    average_lld_cs_loss = average_mi_cs_loss = average_lld_ps_loss = average_mi_ps_loss = average_lld_cp_loss = average_mi_cp_loss = 0
    all_models = [encoder, encoder_lf0, cpc, encoder_spk, cs_mi_net, ps_mi_net, cp_mi_net, decoder]
    to_eval(all_models)

    for i, (mels, lf0, speakers) in enumerate(valid_dataloader, 1):
        lf0 = lf0.to(device)
        mels = mels.to(device)  # (bs, 80, 128)
        recon_loss, vq_loss, cpc_loss, accuracy, perplexity, mi_cs_loss, lld_cs_loss, mi_ps_loss, lld_ps_loss, mi_cp_loss, lld_cp_loss = \
            calculate_eval_loss(mels, lf0,
                                encoder, encoder_lf0, cpc,
                                encoder_spk, cs_mi_net, ps_mi_net,
                                cp_mi_net, decoder, cfg)

        # Running averages over the validation batches.
        average_recon_loss += (recon_loss.item() - average_recon_loss) / i
        average_cpc_loss += (cpc_loss.item() - average_cpc_loss) / i
        average_vq_loss += (vq_loss.item() - average_vq_loss) / i
        average_perplexity += (perplexity.item() - average_perplexity) / i
        average_accuracies += (np.array(accuracy) - average_accuracies) / i
        average_lld_cs_loss += (lld_cs_loss.item() - average_lld_cs_loss) / i
        average_mi_cs_loss += (mi_cs_loss.item() - average_mi_cs_loss) / i
        average_lld_ps_loss += (lld_ps_loss.item() - average_lld_ps_loss) / i
        average_mi_ps_loss += (mi_ps_loss.item() - average_mi_ps_loss) / i
        average_lld_cp_loss += (lld_cp_loss.item() - average_lld_cp_loss) / i
        average_mi_cp_loss += (mi_cp_loss.item() - average_mi_cp_loss) / i

    ctime = time.time()
    print("Eval | epoch:{}, recon loss:{:.3f}, cpc loss:{:.3f}, vq loss:{:.3f}, perplexity:{:.3f}, "
          "lld cs loss:{:.3f}, mi cs loss:{:.3E}, lld ps loss:{:.3f}, mi ps loss:{:.3f}, "
          "lld cp loss:{:.3f}, mi cp loss:{:.3f}, used time:{:.3f}s"
          .format(epoch, average_recon_loss, average_cpc_loss, average_vq_loss, average_perplexity,
                  average_lld_cs_loss, average_mi_cs_loss, average_lld_ps_loss, average_mi_ps_loss,
                  average_lld_cp_loss, average_mi_cp_loss, ctime - stime))
    print(100 * average_accuracies)

    with open(f'{str(checkpoint_dir)}/results.txt', 'a') as results_txt:
        results_txt.write("Eval | epoch:{}, recon loss:{:.3f}, cpc loss:{:.3f}, vq loss:{:.3f}, perplexity:{:.3f}, "
                          "lld cs loss:{:.3f}, mi cs loss:{:.3E}, lld ps loss:{:.3f}, mi ps loss:{:.3f}, "
                          "lld cp loss:{:.3f}, mi cp loss:{:.3f}"
                          .format(epoch, average_recon_loss, average_cpc_loss, average_vq_loss, average_perplexity,
                                  average_lld_cs_loss, average_mi_cs_loss, average_lld_ps_loss, average_mi_ps_loss,
                                  average_lld_cp_loss, average_mi_cp_loss) + '\n')
        results_txt.write(' '.join([str(cpc_acc) for cpc_acc in average_accuracies]) + '\n')

    to_train(all_models)

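# Full training entry point: builds the models and optimizers, optionally wraps
# them with apex.amp (O1), sets up the warmup LR scheduler and dataloaders,
# optionally resumes from a checkpoint, then alternates the MI-estimator update
# (mi_first_forward) and the main model update (mi_second_forward) every batch.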
# NOTE: the Hydra decorator is restored here on the assumption that cfg is supplied
# by Hydra (train_model() is called without arguments below); the config path is a
# guess - adjust it to the project's actual config (newer Hydra versions take
# config_path plus config_name instead).
@hydra.main(config_path="config/train.yaml")
def train_model(cfg):
    cfg.checkpoint_dir = f'{cfg.checkpoint_dir}/useCSMI{cfg.use_CSMI}_useCPMI{cfg.use_CPMI}_usePSMI{cfg.use_PSMI}_useAmp{cfg.use_amp}'
    if cfg.encoder_lf0_type == 'no_emb':  # default
        dim_lf0 = 1
    else:
        dim_lf0 = 64

    checkpoint_dir = Path(utils.to_absolute_path(cfg.checkpoint_dir))
    checkpoint_dir.mkdir(exist_ok=True, parents=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Define models.
    encoder = Encoder(**cfg.model.encoder)
    encoder_lf0 = Encoder_lf0(cfg.encoder_lf0_type)
    cpc = CPCLoss_sameSeq(**cfg.model.cpc)
    encoder_spk = Encoder_spk()
    cs_mi_net = CLUBSample_group(256, cfg.model.encoder.z_dim, 512)
    ps_mi_net = CLUBSample_group(256, dim_lf0, 512)
    cp_mi_net = CLUBSample_reshape(dim_lf0, cfg.model.encoder.z_dim, 512)
    decoder = Decoder_ac(dim_neck=cfg.model.encoder.z_dim, dim_lf0=dim_lf0, use_l1_loss=True)
    encoder.to(device)
    cpc.to(device)
    encoder_lf0.to(device)
    encoder_spk.to(device)
    cs_mi_net.to(device)
    ps_mi_net.to(device)
    cp_mi_net.to(device)
    decoder.to(device)

    optimizer = optim.Adam(
        chain(encoder.parameters(), encoder_lf0.parameters(), cpc.parameters(),
              encoder_spk.parameters(), decoder.parameters()),
        lr=cfg.training.scheduler.initial_lr)
    optimizer_cs_mi_net = optim.Adam(cs_mi_net.parameters(), lr=cfg.mi_lr)
    optimizer_ps_mi_net = optim.Adam(ps_mi_net.parameters(), lr=cfg.mi_lr)
    optimizer_cp_mi_net = optim.Adam(cp_mi_net.parameters(), lr=cfg.mi_lr)

    # TODO: use_amp defaults to True to speed up training; training without amp may be
    # more stable => needs to be verified.
    if cfg.use_amp:
        [encoder, encoder_lf0, cpc, encoder_spk, decoder], optimizer = amp.initialize(
            [encoder, encoder_lf0, cpc, encoder_spk, decoder], optimizer, opt_level='O1')
        [cs_mi_net], optimizer_cs_mi_net = amp.initialize([cs_mi_net], optimizer_cs_mi_net, opt_level='O1')
        [ps_mi_net], optimizer_ps_mi_net = amp.initialize([ps_mi_net], optimizer_ps_mi_net, opt_level='O1')
        [cp_mi_net], optimizer_cp_mi_net = amp.initialize([cp_mi_net], optimizer_cp_mi_net, opt_level='O1')

    root_path = Path(utils.to_absolute_path("data"))
    dataset = CPCDataset(
        root=root_path,
        n_sample_frames=cfg.training.sample_frames,  # 128
        mode='train')
    valid_dataset = CPCDataset(
        root=root_path,
        n_sample_frames=cfg.training.sample_frames,  # 128
        mode='valid')

    warmup_epochs = 2000 // (len(dataset) // cfg.training.batch_size)
    print('warmup_epochs:', warmup_epochs)
    scheduler = WarmupScheduler(
        optimizer,
        warmup_epochs=warmup_epochs,
        initial_lr=cfg.training.scheduler.initial_lr,
        max_lr=cfg.training.scheduler.max_lr,
        milestones=cfg.training.scheduler.milestones,
        gamma=cfg.training.scheduler.gamma)

    dataloader = DataLoader(
        dataset,
        batch_size=cfg.training.batch_size,  # 256
        shuffle=True,
        num_workers=cfg.training.n_workers,
        pin_memory=True,
        drop_last=False)
    valid_dataloader = DataLoader(
        valid_dataset,
        batch_size=cfg.training.batch_size,  # 256
        shuffle=False,
        num_workers=cfg.training.n_workers,
        pin_memory=True,
        drop_last=False)

    if cfg.resume:
        print("Resume checkpoint from: {}:".format(cfg.resume))
        resume_path = utils.to_absolute_path(cfg.resume)
        checkpoint = torch.load(resume_path, map_location=lambda storage, loc: storage)
        encoder.load_state_dict(checkpoint["encoder"])
        encoder_lf0.load_state_dict(checkpoint["encoder_lf0"])
        cpc.load_state_dict(checkpoint["cpc"])
        encoder_spk.load_state_dict(checkpoint["encoder_spk"])
        cs_mi_net.load_state_dict(checkpoint["cs_mi_net"])
        ps_mi_net.load_state_dict(checkpoint["ps_mi_net"])
        if cfg.use_CPMI:
            cp_mi_net.load_state_dict(checkpoint["cp_mi_net"])
        decoder.load_state_dict(checkpoint["decoder"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        optimizer_cs_mi_net.load_state_dict(checkpoint["optimizer_cs_mi_net"])
        optimizer_ps_mi_net.load_state_dict(checkpoint["optimizer_ps_mi_net"])
        optimizer_cp_mi_net.load_state_dict(checkpoint["optimizer_cp_mi_net"])
        if cfg.use_amp:
            amp.load_state_dict(checkpoint["amp"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        start_epoch = checkpoint["epoch"]
    else:
        start_epoch = 1

    if os.path.exists(f'{str(checkpoint_dir)}/results.txt'):
        wmode = 'a'
    else:
        wmode = 'w'
    with open(f'{str(checkpoint_dir)}/results.txt', wmode) as results_txt:
        results_txt.write('save training info...\n')

    global_step = 0
    stime = time.time()

    for epoch in range(start_epoch, cfg.training.n_epochs + 1):
        average_cpc_loss = average_vq_loss = average_perplexity = average_recon_loss = 0
        average_accuracies = np.zeros(cfg.training.n_prediction_steps)
        average_lld_cs_loss = average_mi_cs_loss = average_lld_ps_loss = average_mi_ps_loss = average_lld_cp_loss = average_mi_cp_loss = 0

        for i, (mels, lf0, speakers) in enumerate(dataloader, 1):
            lf0 = lf0.to(device)
            mels = mels.to(device)  # (bs, 80, 128)

            if cfg.use_CSMI or cfg.use_CPMI or cfg.use_PSMI:
                for j in range(cfg.mi_iters):
                    optimizer_cs_mi_net, lld_cs_loss, optimizer_ps_mi_net, lld_ps_loss, optimizer_cp_mi_net, lld_cp_loss = \
                        mi_first_forward(mels, lf0, encoder, encoder_lf0, encoder_spk, cs_mi_net, optimizer_cs_mi_net,
                                         ps_mi_net, optimizer_ps_mi_net, cp_mi_net, optimizer_cp_mi_net, cfg)
            else:
                lld_cs_loss = torch.tensor(0.)
                lld_ps_loss = torch.tensor(0.)
                lld_cp_loss = torch.tensor(0.)

            optimizer, recon_loss, vq_loss, cpc_loss, accuracy, perplexity, mi_cs_loss, mi_ps_loss, mi_cp_loss = \
                mi_second_forward(mels, lf0,
                                  encoder, encoder_lf0, cpc,
                                  encoder_spk, cs_mi_net, ps_mi_net,
                                  cp_mi_net, decoder, cfg,
                                  optimizer, scheduler)

            average_recon_loss += (recon_loss.item() - average_recon_loss) / i
            average_cpc_loss += (cpc_loss.item() - average_cpc_loss) / i
            average_vq_loss += (vq_loss.item() - average_vq_loss) / i
            average_perplexity += (perplexity.item() - average_perplexity) / i
            average_accuracies += (np.array(accuracy) - average_accuracies) / i
            average_lld_cs_loss += (lld_cs_loss.item() - average_lld_cs_loss) / i
            average_mi_cs_loss += (mi_cs_loss.item() - average_mi_cs_loss) / i
            average_lld_ps_loss += (lld_ps_loss.item() - average_lld_ps_loss) / i
            average_mi_ps_loss += (mi_ps_loss.item() - average_mi_ps_loss) / i
            average_lld_cp_loss += (lld_cp_loss.item() - average_lld_cp_loss) / i
            average_mi_cp_loss += (mi_cp_loss.item() - average_mi_cp_loss) / i

            ctime = time.time()
            print("epoch:{}, global step:{}, recon loss:{:.3f}, cpc loss:{:.3f}, vq loss:{:.3f}, perplexity:{:.3f}, "
                  "lld cs loss:{:.3f}, mi cs loss:{:.3E}, lld ps loss:{:.3f}, mi ps loss:{:.3f}, "
                  "lld cp loss:{:.3f}, mi cp loss:{:.3f}, used time:{:.3f}s"
                  .format(epoch, global_step, average_recon_loss, average_cpc_loss, average_vq_loss, average_perplexity,
                          average_lld_cs_loss, average_mi_cs_loss, average_lld_ps_loss, average_mi_ps_loss,
                          average_lld_cp_loss, average_mi_cp_loss, ctime - stime))
            print(100 * average_accuracies)
            stime = time.time()
            global_step += 1
            # scheduler.step()

        with open(f'{str(checkpoint_dir)}/results.txt', 'a') as results_txt:
            results_txt.write("epoch:{}, global step:{}, recon loss:{:.3f}, cpc loss:{:.3f}, vq loss:{:.3f}, perplexity:{:.3f}, "
                              "lld cs loss:{:.3f}, mi cs loss:{:.3E}, lld ps loss:{:.3f}, mi ps loss:{:.3f}, "
                              "lld cp loss:{:.3f}, mi cp loss:{:.3f}"
                              .format(epoch, global_step, average_recon_loss, average_cpc_loss, average_vq_loss, average_perplexity,
                                      average_lld_cs_loss, average_mi_cs_loss, average_lld_ps_loss, average_mi_ps_loss,
                                      average_lld_cp_loss, average_mi_cp_loss) + '\n')
            results_txt.write(' '.join([str(cpc_acc) for cpc_acc in average_accuracies]) + '\n')

        scheduler.step()

        if epoch % cfg.training.log_interval == 0 and epoch != start_epoch:
            eval_model(epoch, checkpoint_dir, device, valid_dataloader, encoder, encoder_lf0, cpc, encoder_spk,
                       cs_mi_net, ps_mi_net, cp_mi_net, decoder, cfg)
            ctime = time.time()
            print("epoch:{}, global step:{}, recon loss:{:.3f}, cpc loss:{:.3f}, vq loss:{:.3f}, perplexity:{:.3f}, "
                  "lld cs loss:{:.3f}, mi cs loss:{:.3E}, lld ps loss:{:.3f}, mi ps loss:{:.3f}, "
                  "lld cp loss:{:.3f}, mi cp loss:{:.3f}, used time:{:.3f}s"
                  .format(epoch, global_step, average_recon_loss, average_cpc_loss, average_vq_loss, average_perplexity,
                          average_lld_cs_loss, average_mi_cs_loss, average_lld_ps_loss, average_mi_ps_loss,
                          average_lld_cp_loss, average_mi_cp_loss, ctime - stime))
            print(100 * average_accuracies)
            stime = time.time()

        if epoch % cfg.training.checkpoint_interval == 0 and epoch != start_epoch:
            save_checkpoint(encoder, encoder_lf0, cpc, encoder_spk,
                            cs_mi_net, ps_mi_net, cp_mi_net, decoder,
                            optimizer, optimizer_cs_mi_net, optimizer_ps_mi_net, optimizer_cp_mi_net,
                            scheduler, amp, epoch, checkpoint_dir, cfg)

if __name__ == "__main__":
    train_model()
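
# Typical launch, assuming the Hydra config exposes the flags read above as
# cfg.use_CSMI / cfg.use_CPMI / cfg.use_PSMI / cfg.use_amp (adjust override names
# to the actual config schema):
#   python train.py use_CSMI=true use_CPMI=true use_PSMI=true use_amp=true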