Spaces:
Runtime error
Runtime error
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import json | |
import os | |
import torch | |
import torch.nn as nn | |
from models.base.new_trainer import BaseTrainer | |
from models.svc.base.svc_dataset import SVCCollator, SVCDataset | |
class SVCTrainer(BaseTrainer):
    r"""The base trainer for all SVC models.

    It inherits from ``BaseTrainer`` and implements ``_build_criterion``,
    ``_build_dataset`` and ``_build_singer_lut``. You can inherit from this
    class and implement ``_build_model`` and ``_forward_step``.
    """

    def __init__(self, args=None, cfg=None):
        """Build the singer look-up table, then run the ``BaseTrainer`` setup.

        Args:
            args: Run arguments; ``args.resume_from_ckpt_path`` is read when
                the singer look-up table is built.
            cfg: Experiment configuration; ``cfg.preprocess`` and
                ``cfg.dataset`` are read when the singer look-up table is
                built.
        """
        self.args = args
        self.cfg = cfg

        self._init_accelerator()

        # Only for SVC tasks.
        # NOTE(review): presumably ``main_process_first`` lets the main process
        # write the table before the other processes read it -- confirm
        # against the accelerator set up by ``_init_accelerator``.
        with self.accelerator.main_process_first():
            self.singers = self._build_singer_lut()

        # Super init (runs after the singer table has been built).
        BaseTrainer.__init__(self, args, cfg)

        # Only for SVC tasks
        self.task_type = "SVC"
        self.logger.info("Task type: {}".format(self.task_type))

    ### Following are methods only for SVC tasks ###
    # TODO: LEGACY CODE, NEED TO BE REFACTORED
    def _build_dataset(self):
        """Return the dataset class and the collator class (not instances)."""
        return SVCDataset, SVCCollator

    @staticmethod
    def _build_criterion() -> nn.MSELoss:
        """Return an element-wise MSE criterion.

        ``reduction="none"`` keeps the per-element losses so that
        ``_compute_loss`` can apply a mask before averaging.

        Declared as a static method because it takes no ``self``; this makes
        both ``SVCTrainer._build_criterion()`` and
        ``trainer._build_criterion()`` valid calls.
        """
        return nn.MSELoss(reduction="none")

    @staticmethod
    def _compute_loss(criterion, y_pred, y_gt, loss_mask):
        """Compute the masked mean of an element-wise loss.

        Declared as a static method because it takes no ``self``; without the
        decorator an instance call would bind the trainer to ``criterion``.

        Args:
            criterion: MSELoss(reduction='none')
            y_pred, y_gt: (bs, seq_len, D)
            loss_mask: (bs, seq_len, 1)

        Returns:
            loss: Tensor of shape []
        """
        # (bs, seq_len, D)
        loss = criterion(y_pred, y_gt)

        # expand loss_mask to (bs, seq_len, D)
        loss_mask = loss_mask.repeat(1, 1, loss.shape[-1])

        # Average over the unmasked elements only. The division is not
        # guarded: an all-zero mask yields 0 / 0.
        loss = torch.sum(loss * loss_mask) / torch.sum(loss_mask)
        return loss

    def _save_auxiliary_states(self):
        """
        To save the singer's look-up table in the checkpoint saving path
        """
        lut_path = os.path.join(
            self.tmp_checkpoint_save_path, self.cfg.preprocess.spk2id
        )
        # Explicit UTF-8: ``ensure_ascii=False`` writes non-ASCII singer names
        # verbatim, which must not depend on the platform's default encoding.
        with open(lut_path, "w", encoding="utf-8") as f:
            json.dump(self.singers, f, indent=4, ensure_ascii=False)

    def _build_singer_lut(self):
        """Build the singer look-up table and dump it to the experiment dir.

        An existing table is loaded first, if there is one: the file
        ``<exp_dir>/<spk2id>`` takes precedence when it exists, otherwise
        ``<resume_from_ckpt_path>/<spk2id>`` is used when a resume path is
        given. With neither, the table starts empty. Every singer listed in
        the per-dataset ``spk2id`` files that is not yet in the table is then
        added with ``len(singers)`` as its id, and the merged table is written
        to ``<exp_dir>/<spk2id>``.

        Returns:
            dict: the singer look-up table (singer name -> id).
        """
        spk2id_name = self.cfg.preprocess.spk2id
        exp_singer_path = os.path.join(self.exp_dir, spk2id_name)

        resumed_singer_path = None
        if self.args.resume_from_ckpt_path:
            resumed_singer_path = os.path.join(
                self.args.resume_from_ckpt_path, spk2id_name
            )
        # A table already present in the experiment dir overrides the one from
        # the resumed checkpoint.
        if os.path.exists(exp_singer_path):
            resumed_singer_path = exp_singer_path

        if resumed_singer_path:
            with open(resumed_singer_path, "r", encoding="utf-8") as f:
                singers = json.load(f)
        else:
            singers = {}

        # Merge in the singers of every configured dataset; existing entries
        # keep their ids, unseen singers get the next free index.
        for dataset in self.cfg.dataset:
            singer_lut_path = os.path.join(
                self.cfg.preprocess.processed_dir, dataset, spk2id_name
            )
            # Separate handle name so the path variable is not shadowed.
            with open(singer_lut_path, "r", encoding="utf-8") as lut_file:
                singer_lut = json.load(lut_file)
            for singer in singer_lut:
                if singer not in singers:
                    singers[singer] = len(singers)

        with open(exp_singer_path, "w", encoding="utf-8") as singer_file:
            json.dump(singers, singer_file, indent=4, ensure_ascii=False)
        print("singers have been dumped to {}".format(exp_singer_path))
        return singers