from __future__ import annotations

import os
import re
from logging import getLogger
from pathlib import Path

import lightning.pytorch as pl
import torch
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.strategies.ddp import DDPStrategy
from lightning.pytorch.tuner import Tuner

from so_vits_svc_fork import utils
from so_vits_svc_fork.logger import is_notebook
from so_vits_svc_fork.train import VCDataModule, VitsLightning

LOG = getLogger(__name__)

# Allow TF32 matmul kernels on Ampere+ GPUs: faster training with negligible
# precision loss for this workload.
torch.set_float32_matmul_precision("high")

from huggingface_hub import create_repo, delete_file, list_repo_files, login, upload_folder

if os.environ.get("HF_TOKEN"):
    login(os.environ.get("HF_TOKEN"))
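# HF_TOKEN is assumed to be a Hugging Face user access token with write scope
# (e.g. `export HF_TOKEN=hf_...`); without it, the Hub pushes below rely on an
# existing `huggingface-cli login` session.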


class HuggingFacePushCallback(pl.Callback):
    def __init__(self, repo_id, private=False, every=100):
        self.repo_id = repo_id
        self.private = private
        # `every` is currently unused: pushes happen on every validation pass,
        # whose frequency is set by val_check_interval in train() below.
        self.every = every

    def on_validation_epoch_end(self, trainer, pl_module):
        # Skip the sanity-check validation that runs before any training step.
        if pl_module.global_step == 0:
            return
        self.repo_url = create_repo(
            repo_id=self.repo_id,
            exist_ok=True,
            private=self.private,
        )
        self.repo_id = self.repo_url.repo_id
        print(f"\n🤗 Pushing to Hugging Face Hub: {self.repo_url}...")
        model_dir = pl_module.hparams.model_dir
        upload_folder(
            repo_id=self.repo_id,
            folder_path=model_dir,
            path_in_repo=".",
            commit_message="🍻 cheers",
            ignore_patterns=["*.git*", "*README.md*", "*__pycache__*"],
        )
        # Mirror local checkpoint rotation: delete rotated-away G_*/D_* files
        # from the repo, but always keep the pretrained G_0.pth and D_0.pth.
        ckpt_pattern = r"^(D_|G_)\d+\.pth$"
        repo_ckpts = [
            x
            for x in list_repo_files(self.repo_id)
            if re.match(ckpt_pattern, x) and x not in ["G_0.pth", "D_0.pth"]
        ]
        local_ckpts = [
            x.name for x in Path(model_dir).glob("*.pth") if re.match(ckpt_pattern, x.name)
        ]
        to_delete = set(repo_ckpts) - set(local_ckpts)

        for fname in to_delete:
            print(f"🗑️ Deleting {fname} from repo")
            delete_file(fname, self.repo_id)
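

# The callback is driven by extra keys under "train" in the config file, e.g.
# (illustrative values; the key names match how train() reads hparams below):
#
#   "push_to_hub": true,
#   "repo_id": "username/my-voice-model",
#   "private": true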


def train(
    config_path: Path | str, model_path: Path | str, reset_optimizer: bool = False
):
    config_path = Path(config_path)
    model_path = Path(model_path)

    hparams = utils.get_backup_hparams(config_path, model_path)
    utils.ensure_pretrained_model(model_path, hparams.model.get("type_", "hifi-gan"))

    datamodule = VCDataModule(hparams)
    # Multi-GPU: DDP with unused-parameter detection; Windows has no NCCL, so
    # fall back to the gloo process-group backend there.
    strategy = (
        (
            "ddp_find_unused_parameters_true"
            if os.name != "nt"
            else DDPStrategy(find_unused_parameters=True, process_group_backend="gloo")
        )
        if torch.cuda.device_count() > 1
        else "auto"
    )
    LOG.info(f"Using strategy: {strategy}")

    callbacks = []
    if hparams.train.push_to_hub:
        callbacks.append(
            HuggingFacePushCallback(hparams.train.repo_id, hparams.train.private)
        )
    if not is_notebook():
        callbacks.append(pl.callbacks.RichProgressBar())
    if not callbacks:
        callbacks = None

    trainer = pl.Trainer(
        logger=TensorBoardLogger(
            model_path, "lightning_logs", hparams.train.get("log_version", 0)
        ),
        # Validate every eval_interval training batches instead of once per
        # epoch; validation is also what triggers the Hub push callback.
        val_check_interval=hparams.train.eval_interval,
        max_epochs=hparams.train.epochs,
        check_val_every_n_epoch=None,
        precision=(
            "16-mixed"
            if hparams.train.fp16_run
            else "bf16-mixed"
            if hparams.train.get("bf16_run", False)
            else 32
        ),
        strategy=strategy,
        callbacks=callbacks,
        benchmark=True,
        enable_checkpointing=False,
    )
    tuner = Tuner(trainer)
    model = VitsLightning(reset_optimizer=reset_optimizer, **hparams)

    # hparams.train.batch_size is either an int or a tuner spec of the form
    # "auto" / "power" / "binsearch", optionally "<mode>-<init_val>-<max_trials>".
    batch_size = hparams.train.batch_size
    batch_split = str(batch_size).split("-")
    batch_size = batch_split[0]
    init_val = 2 if len(batch_split) <= 1 else int(batch_split[1])
    max_trials = 25 if len(batch_split) <= 2 else int(batch_split[2])
    if batch_size == "auto":
        batch_size = "binsearch"
    if batch_size in ["power", "binsearch"]:
        model.tuning = True
        tuner.scale_batch_size(
            model,
            mode=batch_size,
            datamodule=datamodule,
            steps_per_trial=1,
            init_val=init_val,
            max_trials=max_trials,
        )
        model.tuning = False
    else:
        batch_size = int(batch_size)
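    # Accepted batch_size values (illustrative, derived from the parsing above):
    #   16                -> fixed batch size of 16
    #   "auto"            -> binary-search tuning starting at 2, up to 25 trials
    #   "binsearch-4-10"  -> binary-search tuning starting at 4, up to 10 trials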

    # lr_find is currently disabled; hparams.train.learning_rate is used as-is.
    # if hparams.train.learning_rate == "auto":
    #     lr_finder = tuner.lr_find(model)
    #     LOG.info(lr_finder.results)
    #     fig = lr_finder.plot(suggest=True)
    #     fig.savefig(model_path / "lr_finder.png")

    trainer.fit(model, datamodule=datamodule)


if __name__ == "__main__":
    train("configs/44k/config.json", "logs/44k")
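    # The paths above are the so-vits-svc-fork 44.1 kHz defaults; point them
    # at your own config.json and log directory if they differ.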