RMSnow's picture
init and interface
df2accb
raw
history blame
No virus
3.72 kB
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
import torch
import torch.nn as nn
from models.base.new_trainer import BaseTrainer
from models.svc.base.svc_dataset import SVCCollator, SVCDataset
class SVCTrainer(BaseTrainer):
    r"""The base trainer for all SVC models.

    It inherits from BaseTrainer and implements ``_build_criterion``,
    ``_build_dataset`` and ``_build_singer_lut``. You can inherit from this
    class and implement ``_build_model`` and ``_forward_step``.
    """

    def __init__(self, args=None, cfg=None):
        """Build the singer look-up table, then run the BaseTrainer setup.

        Args:
            args: Run arguments; ``resume_from_ckpt_path`` is read when the
                singer look-up table is built.
            cfg: Experiment configuration; ``preprocess.spk2id``,
                ``preprocess.processed_dir`` and ``dataset`` are read when
                the singer look-up table is built.
        """
        # Both are read by _build_singer_lut, which runs before the base
        # initializer below.
        self.args = args
        self.cfg = cfg

        # NOTE(review): defined outside this class; presumably sets
        # self.accelerator and self.exp_dir, which are used below -- confirm.
        self._init_accelerator()

        # Only for SVC tasks
        with self.accelerator.main_process_first():
            self.singers = self._build_singer_lut()

        # Super init
        BaseTrainer.__init__(self, args, cfg)

        # Only for SVC tasks
        self.task_type = "SVC"
        self.logger.info("Task type: {}".format(self.task_type))

    ### Following are methods only for SVC tasks ###
    # TODO: LEGACY CODE, NEED TO BE REFACTORED
    def _build_dataset(self):
        """Return the dataset class and the collator class used for SVC."""
        return SVCDataset, SVCCollator

    @staticmethod
    def _build_criterion():
        """Return an element-wise MSE loss (``reduction="none"``).

        The reduction is left to ``_compute_loss`` so that a mask can be
        applied before averaging.
        """
        return nn.MSELoss(reduction="none")

    @staticmethod
    def _compute_loss(criterion, y_pred, y_gt, loss_mask):
        """Compute the masked mean of an element-wise loss.

        Args:
            criterion: MSELoss(reduction='none')
            y_pred, y_gt: (bs, seq_len, D)
            loss_mask: (bs, seq_len, 1)

        Returns:
            loss: Tensor of shape []
        """
        # (bs, seq_len, D)
        loss = criterion(y_pred, y_gt)

        # View loss_mask as (bs, seq_len, D). ``expand_as`` yields the same
        # values as ``repeat(1, 1, D)`` without materialising a copy.
        mask = loss_mask.expand_as(loss)

        # Sum over the masked elements divided by the number of them.
        return torch.sum(loss * mask) / torch.sum(mask)

    def _save_auxiliary_states(self):
        """
        To save the singer's look-up table in the checkpoint saving path
        """
        lut_path = os.path.join(
            self.tmp_checkpoint_save_path, self.cfg.preprocess.spk2id
        )
        # ensure_ascii=False emits raw non-ASCII characters, so the file
        # encoding must be explicit rather than the platform default.
        with open(lut_path, "w", encoding="utf-8") as f:
            json.dump(self.singers, f, indent=4, ensure_ascii=False)

    def _build_singer_lut(self):
        """Build the singer-name -> integer-id table and dump it to ``exp_dir``.

        An existing table is loaded first when one is available: the copy in
        ``self.exp_dir`` takes precedence over the one under
        ``args.resume_from_ckpt_path``. Singers found in each dataset's own
        table are then appended with the next free id; names already present
        keep their id.

        Returns:
            dict: The mapping from singer name to id.
        """
        spk2id = self.cfg.preprocess.spk2id
        exp_lut_path = os.path.join(self.exp_dir, spk2id)

        resumed_singer_path = None
        if self.args.resume_from_ckpt_path:
            resumed_singer_path = os.path.join(
                self.args.resume_from_ckpt_path, spk2id
            )
        # A table already present in the experiment directory wins.
        if os.path.exists(exp_lut_path):
            resumed_singer_path = exp_lut_path

        if resumed_singer_path:
            with open(resumed_singer_path, "r", encoding="utf-8") as f:
                singers = json.load(f)
        else:
            singers = dict()

        # NOTE(review): placed after the if/else so that a resumed table is
        # extended with singers from the configured datasets -- confirm
        # against the upstream file, whose indentation was lost here.
        for dataset in self.cfg.dataset:
            singer_lut_path = os.path.join(
                self.cfg.preprocess.processed_dir, dataset, spk2id
            )
            with open(singer_lut_path, "r", encoding="utf-8") as singer_lut_file:
                singer_lut = json.load(singer_lut_file)
            for singer in singer_lut:
                # The next free id is the current table size; existing
                # entries are left untouched.
                singers.setdefault(singer, len(singers))

        with open(exp_lut_path, "w", encoding="utf-8") as singer_file:
            json.dump(singers, singer_file, indent=4, ensure_ascii=False)
        print("singers have been dumped to {}".format(exp_lut_path))

        return singers