GPT-SoVITS-experiment / AR /exps /train_librilight_6k.py
Ricecake123's picture
first commit
e79b770
raw
history blame contribute delete
No virus
5.44 kB
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py
import argparse
import logging
import os
from pathlib import Path
import torch
from pytorch_lightning import seed_everything
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPStrategy
from AR.data.data_module_librilight_6k import Text2SemanticDataModule
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from soundstorm.utils import get_newest_ckpt
from soundstorm.utils.io import load_yaml_config
logging.getLogger('numba').setLevel(logging.WARNING)
logging.getLogger('matplotlib').setLevel(logging.WARNING)
torch.set_float32_matmul_precision('high')
def main(args):
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
ckpt_dir = output_dir / 'ckpt'
ckpt_dir.mkdir(parents=True, exist_ok=True)
config = load_yaml_config(args.config_file)
seed_everything(config["train"]["seed"], workers=True)
ckpt_callback: ModelCheckpoint = ModelCheckpoint(
save_top_k=-1,
save_on_train_epoch_end=False,
every_n_train_steps=config["train"]["every_n_train_steps"],
dirpath=ckpt_dir)
logger = WandbLogger(
project="AR_S1_LibriLight",
name=output_dir.stem,
save_dir=output_dir,
# resume the loss curve
resume=True,
# id='k19kvsq8'
)
trainer: Trainer = Trainer(
max_epochs=config["train"]["epochs"],
accelerator='gpu',
devices=-1,
benchmark=False,
fast_dev_run=False,
strategy=DDPStrategy(find_unused_parameters=True),
precision=config["train"]["precision"],
logger=logger,
callbacks=[ckpt_callback])
model: Text2SemanticLightningModule = Text2SemanticLightningModule(
config, output_dir)
data_module: Text2SemanticDataModule = Text2SemanticDataModule(
config,
train_semantic_dirs=args.train_semantic_dirs,
train_phoneme_dirs=args.train_phoneme_dirs,
dev_semantic_dirs=args.dev_semantic_dirs,
dev_phoneme_dirs=args.dev_phoneme_dirs,
train_non_speech_dirs=args.train_non_speech_dirs,
dev_non_speech_dirs=args.dev_non_speech_dirs)
try:
newest_ckpt_name = get_newest_ckpt(os.listdir(ckpt_dir))
ckpt_path = ckpt_dir / newest_ckpt_name
except Exception:
ckpt_path = None
print("ckpt_path:", ckpt_path)
trainer.fit(model, data_module, ckpt_path=ckpt_path)
# srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--config_file',
type=str,
default='conf/default.yaml',
help='path of config file')
# args for dataset
parser.add_argument(
'--train_semantic_dirs',
type=list,
nargs='+',
default=["dump/small/train/"],
help='dirs of train semantic')
parser.add_argument(
'--train_phoneme_dirs',
type=list,
nargs='+',
default=["dump/small/train/"],
help='dirs of train phoneme')
parser.add_argument(
'--dev_semantic_dirs',
type=list,
nargs='+',
default=["dump/small/dev/"],
help='dirs of dev semantic')
parser.add_argument(
'--dev_phoneme_dirs',
type=list,
nargs='+',
default=["dump/small/dev/"],
help='dirs of dev phoneme')
parser.add_argument(
'--output_dir',
type=str,
default='exp/default',
help='directory to save the results')
parser.add_argument(
'--train_non_speech_dirs',
type=list,
nargs='+',
default=None,
help='dirs of train non_speech data')
parser.add_argument(
'--dev_non_speech_dirs',
type=list,
nargs='+',
default=None,
help='dirs of dev non_speech data')
args = parser.parse_args()
new_train_semantic_dirs = []
new_train_phoneme_dirs = []
new_dev_semantic_dirs = []
new_dev_phoneme_dirs = []
new_train_non_speech_dirs = []
new_dev_non_speech_dirs = []
# format dataset dirs
for item in args.train_semantic_dirs:
new_train_semantic_dirs.append(''.join(item))
args.train_semantic_dirs = new_train_semantic_dirs
for item in args.train_phoneme_dirs:
new_train_phoneme_dirs.append(''.join(item))
args.train_phoneme_dirs = new_train_phoneme_dirs
for item in args.dev_semantic_dirs:
new_dev_semantic_dirs.append(''.join(item))
args.dev_semantic_dirs = new_dev_semantic_dirs
for item in args.dev_phoneme_dirs:
new_dev_phoneme_dirs.append(''.join(item))
args.dev_phoneme_dirs = new_dev_phoneme_dirs
if args.train_non_speech_dirs is not None:
for item in args.train_non_speech_dirs:
new_train_non_speech_dirs.append(''.join(item))
args.train_non_speech_dirs = new_train_non_speech_dirs
if args.dev_non_speech_dirs is not None:
for item in args.dev_non_speech_dirs:
new_dev_non_speech_dirs.append(''.join(item))
args.dev_non_speech_dirs = new_dev_non_speech_dirs
logging.info(str(args))
main(args)