# GPT-SoVITS-v3 / GPT_SoVITS / s2_train_v3.py
# kevinwang676's picture
# Upload folder using huggingface_hub
# 2c3577a verified
# Module-level setup for the stage-2 (v3) training script: load the
# hyper-parameters, restrict the visible GPUs, import torch and the project
# modules, and configure backend precision flags.
import warnings
# Suppress every Python warning for the whole process.
warnings.filterwarnings("ignore")
import utils, os
# Hyper-parameters are read at import time, before torch is imported.
hps = utils.get_hparams(stage=2)
# Dashes in hps.train.gpu_numbers are turned into commas for CUDA_VISIBLE_DEVICES.
# NOTE(review): presumably set before `import torch` so it takes effect when CUDA initializes — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = hps.train.gpu_numbers.replace("-", ",")
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp
import torch.distributed as dist, traceback
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import logging, traceback
# Set the level of these third-party loggers to INFO.
logging.getLogger("matplotlib").setLevel(logging.INFO)
logging.getLogger("h5py").setLevel(logging.INFO)
logging.getLogger("numba").setLevel(logging.INFO)
from random import randint
from module import commons
from module.data_utils import (
TextAudioSpeakerLoaderV3 as TextAudioSpeakerLoader,
TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
DistributedBucketSampler,
)
from module.models import (
SynthesizerTrnV3 as SynthesizerTrn,
MultiPeriodDiscriminator,
)
from module.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from process_ckpt import savee
# cuDNN autotuning and deterministic mode are both switched off.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = False
### fp32 is faster on the A100 anyway, so give tf32 a try
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("medium") # lowest precision but fastest (only marginally faster); does not affect the results
# from config import pretrained_s2G,pretrained_s2D
# Step counter shared by run() and train_and_evaluate() through `global`.
global_step = 0
device = "cpu" # device used when CUDA is unavailable; mps to be added once it is optimized
def main():
    """Launch the training workers.

    One process is spawned per CUDA device reported by torch, or a single
    process when CUDA is unavailable. The rendezvous address is ``localhost``
    and the port is drawn at random from 20000..55555 (inclusive); both are
    exported through ``MASTER_ADDR`` / ``MASTER_PORT`` before spawning.
    """
    n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 1

    # The spawned processes read these variables when initializing the
    # process group in run().
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(randint(20000, 55555))

    # mp.spawn passes the process index as the first argument of run().
    mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps))
def run(rank, n_gpus, hps):
    """Per-process training entry point started by ``mp.spawn``.

    Sets up rank-0 logging, the process group, the data loader, the generator
    and its optimizer, then either resumes from the newest ``G_*.pth``
    checkpoint or falls back to the configured pretrained weights, and finally
    runs the epoch loop.

    Args:
        rank: Index of this process; also used as the CUDA device index.
        n_gpus: Number of spawned processes (the world size).
        hps: Hyper-parameter object produced by ``utils.get_hparams``.
    """
    global global_step
    use_cuda = torch.cuda.is_available()

    # Only rank 0 owns a logger and TensorBoard writers; every other rank
    # passes None for both to train_and_evaluate().
    logger = None
    writers = None
    if rank == 0:
        logger = utils.get_logger(hps.data.exp_dir)
        logger.info(hps)
        writer = SummaryWriter(log_dir=hps.s2_ckpt_dir)
        writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
        writers = [writer, writer_eval]

    # gloo is used on Windows and on CPU-only machines, nccl otherwise.
    dist.init_process_group(
        backend="gloo" if os.name == "nt" or not use_cuda else "nccl",
        init_method="env://?use_libuv=False",
        world_size=n_gpus,
        rank=rank,
    )
    torch.manual_seed(hps.train.seed)
    if use_cuda:
        torch.cuda.set_device(rank)

    train_dataset = TextAudioSpeakerLoader(hps.data)
    train_sampler = DistributedBucketSampler(
        train_dataset,
        hps.train.batch_size,
        # NOTE(review): presumably length-bucket boundaries — confirm against
        # DistributedBucketSampler.
        [32, 300, 400, 500, 600, 700, 800, 900, 1000],
        num_replicas=n_gpus,
        rank=rank,
        shuffle=True,
    )
    collate_fn = TextAudioSpeakerCollate()
    train_loader = DataLoader(
        train_dataset,
        num_workers=6,
        shuffle=False,
        pin_memory=True,
        collate_fn=collate_fn,
        batch_sampler=train_sampler,
        persistent_workers=True,
        prefetch_factor=4,
    )

    # Build the generator once, then place it on this rank's GPU or on the
    # fallback device.
    net_g = SynthesizerTrn(
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    )
    net_g = net_g.cuda(rank) if use_cuda else net_g.to(device)

    # Only trainable parameters are optimized; all of them share one
    # learning rate.
    optim_g = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, net_g.parameters()),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )
    if use_cuda:
        net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)

    try:
        # Resume automatically when a generator checkpoint can be loaded.
        # NOTE(review): checkpoints are saved with the epoch number that just
        # finished; if load_checkpoint returns that number unchanged, the last
        # finished epoch is trained again after a resume — confirm.
        _, _, _, epoch_str = utils.load_checkpoint(
            utils.latest_checkpoint_path(
                "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version), "G_*.pth"
            ),
            net_g,
            optim_g,
        )
        global_step = (epoch_str - 1) * len(train_loader)
    except Exception as exc:
        # Best-effort resume: any failure means "start from scratch".
        # KeyboardInterrupt / SystemExit are deliberately not caught here.
        epoch_str = 1
        global_step = 0
        if rank == 0:
            logger.info(
                "could not resume from a checkpoint (%r); starting at epoch 1" % (exc,)
            )
        pretrained_path = hps.train.pretrained_s2G
        if (
            pretrained_path != ""
            and pretrained_path is not None
            and os.path.exists(pretrained_path)
        ):
            if rank == 0:
                logger.info("loaded pretrained %s" % pretrained_path)
            # Under DDP the real model lives in ``.module``.
            bare_model = net_g.module if use_cuda else net_g
            # NOTE: torch.load unpickles the file; only point
            # pretrained_s2G at trusted checkpoints.
            # Only the weights are loaded here, not any optimizer state.
            print(
                bare_model.load_state_dict(
                    torch.load(pretrained_path, map_location="cpu")["weight"],
                    strict=False,
                )
            )

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
        optim_g, gamma=hps.train.lr_decay, last_epoch=-1
    )
    # Advance the schedule by epoch_str steps before training starts.
    for _ in range(epoch_str):
        scheduler_g.step()

    scaler = GradScaler(enabled=hps.train.fp16_run)
    # No discriminator is trained in this script; the slots stay None so the
    # train_and_evaluate() signature is unchanged.
    net_d = optim_d = scheduler_d = None

    for epoch in range(epoch_str, hps.train.epochs + 1):
        train_and_evaluate(
            rank,
            epoch,
            hps,
            [net_g, net_d],
            [optim_g, optim_d],
            [scheduler_g, scheduler_d],
            scaler,
            [train_loader, None],  # no evaluation loader is used
            logger,
            writers,
        )
        scheduler_g.step()
def train_and_evaluate(
    rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers
):
    """Train the generator for one epoch; rank 0 also logs and saves.

    Args:
        rank: Process index, also used as the CUDA device index.
        epoch: Current epoch number (passed to the sampler and to checkpoints).
        hps: Hyper-parameter object.
        nets: ``[net_g, net_d]``; only ``net_g`` is used here.
        optims: ``[optim_g, optim_d]``; only ``optim_g`` is used here.
        schedulers: Accepted for signature compatibility; not used here.
        scaler: ``GradScaler`` used for the backward pass and optimizer step.
        loaders: ``[train_loader, eval_loader]``; only ``train_loader`` is used.
        logger: Logger used on rank 0.
        writers: ``[writer, writer_eval]`` or ``None``.
    """
    global global_step

    net_g, _net_d = nets
    optim_g, _optim_d = optims
    train_loader, _eval_loader = loaders
    if writers is not None:
        writer, _writer_eval = writers

    train_loader.batch_sampler.set_epoch(epoch)
    net_g.train()

    on_gpu = torch.cuda.is_available()

    def _place(tensor):
        # Move a batch tensor to this rank's GPU, or to the fallback device.
        if on_gpu:
            return tensor.cuda(rank, non_blocking=True)
        return tensor.to(device)

    n_batches = len(train_loader)
    for batch_idx, batch in enumerate(tqdm(train_loader)):
        (
            ssl,
            spec,
            mel,
            ssl_lengths,
            spec_lengths,
            text,
            text_lengths,
            mel_lengths,
        ) = batch

        # ssl_lengths is passed on as produced by the loader; it is not moved.
        spec, spec_lengths = _place(spec), _place(spec_lengths)
        mel, mel_lengths = _place(mel), _place(mel_lengths)
        ssl = _place(ssl)
        ssl.requires_grad = False
        text, text_lengths = _place(text), _place(text_lengths)

        with autocast(enabled=hps.train.fp16_run):
            cfm_loss = net_g(
                ssl,
                spec,
                mel,
                ssl_lengths,
                spec_lengths,
                text,
                text_lengths,
                mel_lengths,
                use_grad_ckpt=hps.train.grad_ckpt,
            )
            # The value returned by net_g is the only loss term.
            loss_gen_all = cfm_loss

        optim_g.zero_grad()
        scaler.scale(loss_gen_all).backward()
        scaler.unscale_(optim_g)
        # NOTE(review): with a clip value of None this presumably only
        # measures the gradient norm — confirm in commons.clip_grad_value_.
        grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
        scaler.step(optim_g)
        scaler.update()

        if rank == 0 and global_step % hps.train.log_interval == 0:
            lr = optim_g.param_groups[0]['lr']
            logger.info(
                'Train Epoch: {} [{:.0f}%]'.format(epoch, 100. * batch_idx / n_batches)
            )
            logger.info([cfm_loss.item(), global_step, lr])
            utils.summarize(
                writer=writer,
                global_step=global_step,
                scalars={
                    "loss/g/total": loss_gen_all,
                    "learning_rate": lr,
                    "grad_norm_g": grad_norm_g,
                },
            )
        global_step += 1

    if epoch % hps.train.save_every_epoch == 0 and rank == 0:
        # if_save_latest == 0 names the file after the step count; any other
        # value reuses one fixed file name.
        if hps.train.if_save_latest == 0:
            ckpt_tag = global_step
        else:
            ckpt_tag = 233333333333
        utils.save_checkpoint(
            net_g,
            optim_g,
            hps.train.learning_rate,
            epoch,
            os.path.join(
                "%s/logs_s2_%s" % (hps.data.exp_dir, hps.model.version),
                "G_{}.pth".format(ckpt_tag),
            ),
        )
        if hps.train.if_save_every_weights == True:  # noqa: E712 - original comparison kept
            # A wrapped model exposes the real network as ``.module``.
            if hasattr(net_g, "module"):
                weights = net_g.module.state_dict()
            else:
                weights = net_g.state_dict()
            save_result = savee(
                weights,
                hps.name + "_e%s_s%s" % (epoch, global_step),
                epoch,
                global_step,
                hps,
            )
            logger.info("saving ckpt %s_e%s:%s" % (hps.name, epoch, save_result))

    if rank == 0:
        logger.info("====> Epoch: {}".format(epoch))
if __name__ == "__main__":
main()