# Repository page header captured with the file (not code):
#   Akatuki25's picture
#   Add seed-vc Python files without binary examples
#   1b8b9eb
import os
import sys
os.environ['HF_HUB_CACHE'] = './checkpoints/hf_cache'
import torch
import torch.multiprocessing as mp
import random
import librosa
import yaml
import argparse
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import glob
import time
from tqdm import tqdm
import shutil
import accelerate
from optimizers import build_optimizer
from data.ft_dataset import build_ft_dataloader
import hydra
from omegaconf import DictConfig
from accelerate import Accelerator
from accelerate import DistributedDataParallelKwargs
from accelerate.logging import get_logger
class Trainer:
    """Fine-tune the CFM and/or AR parts of the model described by a YAML config.

    Construction loads the config, creates ``runs/<run_name>``, sets up an
    ``Accelerator``, builds the dataloader, instantiates the model through
    hydra, builds the optimizer/scheduler and loads checkpoints. ``train()``
    then runs the loop until ``steps`` iterations or ``max_epochs`` epochs
    have been completed, whichever comes first.
    """

    def __init__(
        self,
        config_path,
        pretrained_cfm_ckpt_path,
        pretrained_ar_ckpt_path,
        data_dir,
        run_name,
        batch_size=0,
        num_workers=0,
        steps=1000,
        save_interval=500,
        max_epochs=1000,
        train_cfm=True,
        train_ar=False,
        mixed_precision=None,
    ):
        """Set up everything needed before ``train()`` can be called.

        Args:
            config_path: Path to the YAML config. It must provide the
                ``mel_fn`` and ``sr`` keys and be instantiable by hydra.
            pretrained_cfm_ckpt_path: CFM checkpoint to load. When falsy, the
                newest ``CFM_epoch_*_step_*.pth`` in the run directory is used.
            pretrained_ar_ckpt_path: AR checkpoint to load. When falsy, the
                newest ``AR_epoch_*_step_*.pth`` in the run directory is used.
            data_dir: Dataset location forwarded to ``build_ft_dataloader``.
            run_name: Name of the run; outputs go to ``runs/<run_name>``.
            batch_size: Forwarded to ``build_ft_dataloader``.
            num_workers: Forwarded to ``build_ft_dataloader``.
            steps: Training stops once this many iterations have run.
            save_interval: A checkpoint is written every this many iterations.
            max_epochs: Training stops once this many epochs have run.
            train_cfm: Unfreeze and train ``cfm`` / ``cfm_length_regulator``.
            train_ar: Unfreeze and train ``ar`` / ``ar_length_regulator``.
            mixed_precision: Forwarded to ``Accelerator(mixed_precision=...)``.
        """
        self.config_path = config_path
        self.mixed_precision = mixed_precision

        # Load configuration; the context manager closes the file handle.
        with open(config_path) as config_file:
            self.config = yaml.safe_load(config_file)

        # Setup logging directory and keep a copy of the config next to the checkpoints.
        self.log_dir = os.path.join("runs", run_name)
        os.makedirs(self.log_dir, exist_ok=True)
        shutil.copy(config_path, os.path.join(self.log_dir, os.path.basename(config_path)))

        # Setup accelerator
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True, broadcast_buffers=False)
        self.accelerator = Accelerator(
            project_dir=self.log_dir,
            split_batches=True,
            kwargs_handlers=[ddp_kwargs],
            mixed_precision=mixed_precision,
        )
        self.device = self.accelerator.device

        # Initialize the dataloader
        self._init_dataloader(
            data_dir=data_dir,
            batch_size=batch_size,
            num_workers=num_workers,
            spect_params=self.config['mel_fn'],
            sr=self.config['sr'],
        )

        # Initialize models and optimizers
        self._init_models(train_cfm=train_cfm, train_ar=train_ar)

        # Load checkpoint if available
        self._load_checkpoint(pretrained_cfm_ckpt_path, pretrained_ar_ckpt_path)

        # Initialize training parameters
        self.iters = 0
        self.start_epoch = 0
        self.log_interval = 10
        self.max_steps = steps
        self.save_interval = save_interval
        self.max_epochs = max_epochs

    def _init_dataloader(self, data_dir, batch_size, num_workers, spect_params, sr):
        """Store the audio parameters and build ``self.train_dataloader``."""
        self.spect_params = spect_params
        self.sr = sr
        self.train_dataloader = build_ft_dataloader(
            data_dir,
            spect_params,
            self.sr,
            batch_size=batch_size,
            num_workers=num_workers,
        )

    def _init_models(self, train_cfm=True, train_ar=False):
        """Initialize the model and its optimizer; at least one part must be trained."""
        assert train_cfm or train_ar, "At least one model should be trained"
        self.train_cfm = train_cfm
        self.train_ar = train_ar
        self._init_main_model(train_cfm=train_cfm, train_ar=train_ar)
        self._init_optimizers()

    def _init_main_model(self, train_cfm=True, train_ar=False):
        """Instantiate the model from the config and unfreeze only the trained parts."""
        with self.accelerator.main_process_first():
            cfg = DictConfig(self.config)
            self.model = hydra.utils.instantiate(cfg).to(self.device)

        # Freeze everything first, then re-enable gradients for the selected sub-modules.
        for p in self.model.parameters():
            p.requires_grad = False

        trainable_modules = []
        if train_cfm:
            trainable_modules += [self.model.cfm, self.model.cfm_length_regulator]
        if train_ar:
            trainable_modules += [self.model.ar, self.model.ar_length_regulator]
        for module in trainable_modules:
            for p in module.parameters():
                p.requires_grad = True

    def _init_optimizers(self):
        """Build the optimizer and scheduler and hand them to the accelerator."""
        from optimizers import build_single_optimizer

        self.optimizer, self.scheduler = build_single_optimizer(
            self.model,
            lr=2e-5,
        )
        self.optimizer = self.accelerator.prepare(self.optimizer)
        self.scheduler = self.accelerator.prepare(self.scheduler)

    @staticmethod
    def _checkpoint_step(path):
        """Return the step number encoded after the last ``_`` of a checkpoint path."""
        return int(path.split("_")[-1].split(".")[0])

    def _find_checkpoint(self, name_pattern, max_keep=1):
        """Return the highest-step checkpoint matching ``name_pattern`` in the run dir.

        As a side effect, when more than ``max_keep`` checkpoints exist the
        lowest-step one is deleted (main process only).

        Returns:
            The path of the newest checkpoint, or ``None`` when the directory
            holds no match or fewer than ``max_keep`` matches.
        """
        available_checkpoints = glob.glob(os.path.join(self.log_dir, name_pattern))
        # The explicit emptiness check avoids max() on an empty list when max_keep <= 0.
        if not available_checkpoints or len(available_checkpoints) < max_keep:
            return None

        latest_checkpoint = max(available_checkpoints, key=self._checkpoint_step)
        earliest_checkpoint = min(available_checkpoints, key=self._checkpoint_step)
        # delete the earliest checkpoint
        if (
            earliest_checkpoint != latest_checkpoint
            and self.accelerator.is_main_process
            and len(available_checkpoints) > max_keep
        ):
            os.remove(earliest_checkpoint)
            print(f"Removed {earliest_checkpoint}")
        return latest_checkpoint

    def _load_checkpoint(self, pretrained_cfm_ckpt_path, pretrained_ar_ckpt_path):
        """Load CFM/AR weights into the model, then wrap the model with the accelerator.

        Explicit paths win; otherwise the newest checkpoint of each kind found
        in the run directory is used (possibly ``None``).
        """
        cfm_checkpoint_path = pretrained_cfm_ckpt_path or self._find_checkpoint("CFM_epoch_*_step_*.pth", max_keep=1)
        ar_checkpoint_path = pretrained_ar_ckpt_path or self._find_checkpoint("AR_epoch_*_step_*.pth", max_keep=1)

        with self.accelerator.main_process_first():
            if cfm_checkpoint_path:
                print(f"Loading CFM checkpoint from {cfm_checkpoint_path}")
            if ar_checkpoint_path:
                print(f"Loading AR checkpoint from {ar_checkpoint_path}")
            self.model.load_checkpoints(cfm_checkpoint_path=cfm_checkpoint_path, ar_checkpoint_path=ar_checkpoint_path)

        self.model = self.accelerator.prepare(self.model)

    def filter_state_dict_shapes(self, params, model):
        """Keep only entries of ``params`` whose key and shape match ``model``.

        Returns:
            ``(filtered_state_dict, skipped_keys)`` where ``skipped_keys`` is
            the set of keys of ``params`` that were dropped (absent from the
            model or with a different shape).
        """
        model_state_dict = model.state_dict()
        filtered_state_dict = {
            k: v
            for k, v in params.items()
            if k in model_state_dict and v.shape == model_state_dict[k].shape
        }
        skipped_keys = set(params.keys()) - set(filtered_state_dict.keys())
        if skipped_keys:
            print(
                f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
            )
        return filtered_state_dict, skipped_keys

    def train(self):
        """Run the training loop until ``max_steps`` or ``max_epochs`` is reached.

        Every process leaves the loop at the same point; the main process
        writes a final checkpoint before the method returns.
        """
        epoch = self.start_epoch
        # No fixed epoch cap here: the loop is bounded by max_epochs / max_steps below.
        while True:
            epoch_start_time = time.time()

            # Samplers that support it are told the epoch; others are left alone.
            sampler = getattr(self.train_dataloader, "sampler", None)
            set_epoch = getattr(sampler, "set_epoch", None)
            if set_epoch is not None:
                set_epoch(epoch)

            self.model.train()
            for i, batch in enumerate(tqdm(self.train_dataloader)):
                self._process_batch(epoch, i, batch)
                if self.iters >= self.max_steps:
                    self._finish_training(epoch, "Reached max steps, stopping training")
                    return

            # Log epoch completion
            if self.accelerator.is_main_process:
                print(f"Epoch {epoch} completed in {time.time() - epoch_start_time:.2f} seconds")

            if epoch + 1 >= self.max_epochs:
                self._finish_training(epoch, "Reached max epochs, stopping training")
                return
            epoch += 1

    def _finish_training(self, epoch, message):
        """Announce the stop reason and save a final checkpoint on the main process."""
        if self.accelerator.is_main_process:
            print(message)
            self._save_checkpoint(epoch)
        # Keep the other ranks alive until the main process has finished writing.
        self.accelerator.wait_for_everyone()

    def _process_batch(self, epoch, i, batch):
        """Run one optimization step on ``batch`` and advance ``self.iters``.

        ``batch`` is unpacked as ``(waves, mels, wave_lens, mel_lens)``.
        """
        waves, mels, wave_lens, mel_lens = batch

        # Resample to 16kHz for ASR models
        waves_16k = torchaudio.functional.resample(waves, self.sr, 16000)
        wave_lengths_16k = (wave_lens.float() * 16000 / self.sr).long()

        # Forward pass and loss calculation (only this part runs under autocast)
        with self.accelerator.autocast():
            loss_ar, loss_cfm = self.model(
                waves_16k.to(self.device),
                mels.to(self.device),
                wave_lengths_16k.to(self.device),
                mel_lens.to(self.device),
                forward_ar=self.train_ar,
                forward_cfm=self.train_cfm,
            )
            loss = loss_ar + loss_cfm

        self.accelerator.backward(loss)
        grad_norm_g = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), 1000.0
        )
        self.optimizer.step()
        self.scheduler.step(self.iters)
        self.optimizer.zero_grad()

        # Log training progress
        self._log_training_progress(epoch, i, loss, loss_ar, loss_cfm, grad_norm_g)

        # Save checkpoint (never at iteration 0, and only from the main process)
        if self.iters != 0 and self.iters % self.save_interval == 0 and self.accelerator.is_main_process:
            self._save_checkpoint(epoch)

        # Increment iteration counter
        self.iters += 1

    def _log_training_progress(self, epoch, i, loss, loss_ar, loss_cfm, grad_norm_g):
        """Print losses, grad norm and LR every ``log_interval`` iterations (main process only)."""
        if self.iters % self.log_interval != 0 or not self.accelerator.is_main_process:
            return
        # The LR is reported as 0 for the first batch of every epoch (i == 0).
        cur_lr = self.scheduler.get_last_lr()[0] if i != 0 else 0
        # Log to console
        print("Epoch %d, Iteration %d, Loss: %.4f, Loss AR: %.4f, Loss CFM: %.4f, Grad Norm: %.4f, LR: %.6f"
              % (epoch, i, loss.item(), loss_ar.item(), loss_cfm.item(), grad_norm_g, cur_lr))

    def _save_checkpoint(self, epoch):
        """Save a checkpoint for each trained part and prune older ones."""
        print('Saving checkpoint...')
        model = self.accelerator.unwrap_model(self.model)
        if self.train_ar:
            self._save_component("AR", "ar", model.ar, model.ar_length_regulator, epoch)
        if self.train_cfm:
            self._save_component("CFM", "cfm", model.cfm, model.cfm_length_regulator, epoch)

    def _save_component(self, tag, net_key, network, length_regulator, epoch):
        """Write ``<tag>_epoch_XXXXX_step_XXXXX.pth`` and keep only the newest such file."""
        state = {
            'net': {
                net_key: network.state_dict(),
                'length_regulator': length_regulator.state_dict(),
            },
            'iters': self.iters,
            'epoch': epoch,
        }
        save_path = os.path.join(self.log_dir, '%s_epoch_%05d_step_%05d.pth' % (tag, epoch, self.iters))
        torch.save(state, save_path)
        print(f"Saved {tag} checkpoint to {save_path}")
        # Find all checkpoints and remove old ones
        self._remove_old_checkpoints(f"{tag}_epoch_*_step_*.pth", max_keep=1)

    def _remove_old_checkpoints(self, name_pattern, max_keep=1):
        """Delete all but the ``max_keep`` highest-step checkpoints matching ``name_pattern``."""
        checkpoints = sorted(
            glob.glob(os.path.join(self.log_dir, name_pattern)),
            key=self._checkpoint_step,
        )
        surplus = len(checkpoints) - max_keep
        if surplus <= 0:
            return
        # Slice by count rather than [:-max_keep], which is empty when max_keep == 0.
        for cp in checkpoints[:surplus]:
            os.remove(cp)
def main(args):
    """Build a ``Trainer`` from the parsed command-line arguments and run it.

    Args:
        args: Namespace with the attributes ``config``, ``pretrained_cfm_ckpt``,
            ``pretrained_ar_ckpt``, ``dataset_dir``, ``run_name``,
            ``batch_size``, ``max_steps``, ``max_epochs``, ``save_every``,
            ``num_workers``, ``train_cfm`` and ``train_ar``.
    """
    # Translate CLI attribute names into the Trainer's keyword arguments.
    trainer_kwargs = {
        'config_path': args.config,
        'pretrained_cfm_ckpt_path': args.pretrained_cfm_ckpt,
        'pretrained_ar_ckpt_path': args.pretrained_ar_ckpt,
        'data_dir': args.dataset_dir,
        'run_name': args.run_name,
        'batch_size': args.batch_size,
        'steps': args.max_steps,
        'max_epochs': args.max_epochs,
        'save_interval': args.save_every,
        'num_workers': args.num_workers,
        'train_cfm': args.train_cfm,
        'train_ar': args.train_ar,
    }
    Trainer(**trainer_kwargs).train()
if __name__ == '__main__':
    # Command-line entry point: parse the options, validate them, start training.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='configs/v2/vc_wrapper.yaml')
    parser.add_argument('--pretrained-cfm-ckpt', type=str, default=None)
    parser.add_argument('--pretrained-ar-ckpt', type=str, default=None)
    parser.add_argument('--dataset-dir', type=str, default='/path/to/dataset')
    parser.add_argument('--run-name', type=str, default='my_run')
    parser.add_argument('--batch-size', type=int, default=2)
    parser.add_argument('--max-steps', type=int, default=1000)
    parser.add_argument('--max-epochs', type=int, default=1000)
    parser.add_argument('--save-every', type=int, default=500)
    parser.add_argument('--num-workers', type=int, default=0)
    parser.add_argument('--train-cfm', action='store_true', help='Train CFM model')
    parser.add_argument('--train-ar', action='store_true', help='Train AR model')
    args = parser.parse_args()
    # Both flags default to False; fail fast with a usage message instead of
    # building the dataloader first and then hitting the trainer's assertion.
    if not (args.train_cfm or args.train_ar):
        parser.error('at least one of --train-cfm / --train-ar must be given')
    main(args)