import collections
import json
import os
import sys
import time

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import ConcatDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from models.base.base_sampler import BatchSampler
from utils.util import (
    Logger,
    remove_older_ckpt,
    save_config,
    set_all_random_seed,
    ValueWindow,
)
|
|
class BaseTrainer(object): |
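    """Skeleton trainer that owns the training loop, checkpointing, and logging.

    Subclasses provide the concrete pieces via the ``build_*`` factories
    (``build_dataset``, ``build_model``, ``build_optimizer``, ``build_scheduler``,
    ``build_criterion``), the ``train_step``/``eval_step`` methods, the summary
    writers, and the checkpoint hooks ``get_state_dict``/``load_model``.

    A minimal usage sketch (``MyTrainer`` is a hypothetical subclass)::

        trainer = MyTrainer(args, cfg)
        # trainer.restore()  # optionally resume from the latest checkpoint
        trainer.train()
    """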
|
    def __init__(self, args, cfg):
        self.args = args
        self.log_dir = args.log_dir
        self.cfg = cfg

        self.checkpoint_dir = os.path.join(args.log_dir, "checkpoints")
        os.makedirs(self.checkpoint_dir, exist_ok=True)
        if not cfg.train.ddp or args.local_rank == 0:
            self.sw = SummaryWriter(os.path.join(args.log_dir, "events"))
            self.logger = self.build_logger()
        self.time_window = ValueWindow(50)

        self.step = 0
        self.epoch = -1
        self.max_epochs = self.cfg.train.epochs
        self.max_steps = self.cfg.train.max_steps

        set_all_random_seed(self.cfg.train.random_seed)
        if cfg.train.ddp:
            dist.init_process_group(backend="nccl")

        if cfg.model_type not in ["AutoencoderKL", "AudioLDM"]:
            self.singers = self.build_singers_lut()

        self.data_loader = self.build_data_loader()

        self.model = self.build_model()
        print(self.model)

        if isinstance(self.model, dict):
            for key, value in self.model.items():
                value.cuda(self.args.local_rank)
                if key == "PQMF":
                    continue
                if cfg.train.ddp:
                    self.model[key] = DistributedDataParallel(
                        value, device_ids=[self.args.local_rank]
                    )
        else:
            self.model.cuda(self.args.local_rank)
            if cfg.train.ddp:
                self.model = DistributedDataParallel(
                    self.model, device_ids=[self.args.local_rank]
                )

        self.criterion = self.build_criterion()
        if isinstance(self.criterion, dict):
            for key, value in self.criterion.items():
                self.criterion[key].cuda(args.local_rank)
        else:
            self.criterion.cuda(self.args.local_rank)

        self.optimizer = self.build_optimizer()
        self.scheduler = self.build_scheduler()

        self.config_save_path = os.path.join(self.checkpoint_dir, "args.json")
|
def build_logger(self): |
|
        log_file = os.path.join(self.checkpoint_dir, "train.log")
        logger = Logger(log_file, level=self.args.log_level).logger

        return logger
|
def build_dataset(self): |
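        """Return the ``(Dataset, Collator)`` pair used by ``build_data_loader``.

        ``Dataset`` is instantiated as ``Dataset(cfg, dataset_name, is_valid=...)``
        for every entry in ``cfg.dataset``; ``Collator`` is instantiated as
        ``Collator(cfg)``.
        """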
|
raise NotImplementedError |
|
|
|
def build_data_loader(self): |
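        """Build the training and validation data loaders.

        Every dataset listed in ``cfg.dataset`` is instantiated and merged with
        ``ConcatDataset``; batches are drawn through ``BatchSampler``. Returns a
        dict ``{"train": train_loader, "valid": valid_loader}``. DDP data
        loading is not implemented yet.
        """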
|
        Dataset, Collator = self.build_dataset()

        datasets_list = []
        for dataset in self.cfg.dataset:
            subdataset = Dataset(self.cfg, dataset, is_valid=False)
            datasets_list.append(subdataset)
        train_dataset = ConcatDataset(datasets_list)

        train_collate = Collator(self.cfg)

        if self.cfg.train.ddp:
            raise NotImplementedError("DDP is not supported yet.")

        batch_sampler = BatchSampler(
            cfg=self.cfg, concat_dataset=train_dataset, dataset_list=datasets_list
        )

        train_loader = DataLoader(
            train_dataset,
            collate_fn=train_collate,
            num_workers=self.args.num_workers,
            batch_sampler=batch_sampler,
            pin_memory=False,
        )
        if not self.cfg.train.ddp or self.args.local_rank == 0:
            datasets_list = []
            for dataset in self.cfg.dataset:
                subdataset = Dataset(self.cfg, dataset, is_valid=True)
                datasets_list.append(subdataset)
            valid_dataset = ConcatDataset(datasets_list)
            valid_collate = Collator(self.cfg)
            batch_sampler = BatchSampler(
                cfg=self.cfg, concat_dataset=valid_dataset, dataset_list=datasets_list
            )
            valid_loader = DataLoader(
                valid_dataset,
                collate_fn=valid_collate,
                num_workers=1,
                batch_sampler=batch_sampler,
            )
        else:
            raise NotImplementedError("DDP is not supported yet.")

        data_loader = {"train": train_loader, "valid": valid_loader}
        return data_loader
|
def build_singers_lut(self): |
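        """Merge the per-dataset ``spk2id`` files into one singer lookup table.

        Unseen singers get the next free integer id. The merged table is dumped
        to ``<log_dir>/<spk2id>`` and returned.
        """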
|
|
|
        if not os.path.exists(os.path.join(self.log_dir, self.cfg.preprocess.spk2id)):
            singers = collections.OrderedDict()
        else:
            with open(
                os.path.join(self.log_dir, self.cfg.preprocess.spk2id), "r"
            ) as singer_file:
                singers = json.load(singer_file)
        singer_count = len(singers)
        for dataset in self.cfg.dataset:
            singer_lut_path = os.path.join(
                self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id
            )
            with open(singer_lut_path, "r") as singer_lut_file:
                singer_lut = json.load(singer_lut_file)
            for singer in singer_lut.keys():
                if singer not in singers:
                    singers[singer] = singer_count
                    singer_count += 1
        with open(
            os.path.join(self.log_dir, self.cfg.preprocess.spk2id), "w"
        ) as singer_file:
            json.dump(singers, singer_file, indent=4, ensure_ascii=False)
        print(
            "singers have been dumped to {}".format(
                os.path.join(self.log_dir, self.cfg.preprocess.spk2id)
            )
        )
        return singers
|
def build_model(self): |
|
raise NotImplementedError() |
|
|
|
def build_optimizer(self): |
|
raise NotImplementedError |
|
|
|
def build_scheduler(self): |
|
raise NotImplementedError() |
|
|
|
def build_criterion(self): |
|
raise NotImplementedError |
|
|
|
def get_state_dict(self): |
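        """Return the dict that ``save_checkpoint`` serializes.

        The exact contents are left to the subclass; typically this holds the
        model/optimizer/scheduler states plus ``step`` and ``epoch`` (an
        assumption, not enforced here).
        """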
|
raise NotImplementedError |
|
|
|
def save_config_file(self): |
|
save_config(self.config_save_path, self.cfg) |
|
|
|
|
|
def save_checkpoint(self, state_dict, saved_model_path): |
|
torch.save(state_dict, saved_model_path) |
|
|
|
def load_checkpoint(self): |
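        """Locate and load the newest checkpoint.

        Reads the ``checkpoint`` index file in ``checkpoint_dir``, takes its
        last line as the checkpoint filename, and loads that file onto CPU.
        """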
|
        checkpoint_path = os.path.join(self.checkpoint_dir, "checkpoint")
        assert os.path.exists(checkpoint_path)
        with open(checkpoint_path, "r") as checkpoint_file:
            checkpoint_filename = checkpoint_file.readlines()[-1].strip()
        model_path = os.path.join(self.checkpoint_dir, checkpoint_filename)
        assert os.path.exists(model_path)
        if not self.cfg.train.ddp or self.args.local_rank == 0:
            self.logger.info(f"Restore from {model_path}")
        checkpoint = torch.load(model_path, map_location="cpu")
        return checkpoint
|
def load_model(self, checkpoint): |
|
raise NotImplementedError |
|
|
|
def restore(self): |
|
checkpoint = self.load_checkpoint() |
|
self.load_model(checkpoint) |
|
|
|
def train_step(self, data): |
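        """Run one training step on a batch.

        Expected by ``train_epoch`` to return ``(losses, stats, total_loss)``,
        where ``losses`` is a (possibly nested) dict of scalar values.
        """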
|
        raise NotImplementedError(
            f"Need to implement function {sys._getframe().f_code.co_name} in "
            f"your sub-class of {self.__class__.__name__}. "
        )
|
    @torch.no_grad()
    def eval_step(self, batch_data, index):
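        """Run one validation step on a batch.

        Expected by ``eval_epoch`` to return
        ``(valid_loss, valid_stats, total_valid_loss)``.
        """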
|
        raise NotImplementedError(
            f"Need to implement function {sys._getframe().f_code.co_name} in "
            f"your sub-class of {self.__class__.__name__}. "
        )
|
def write_summary(self, losses, stats): |
|
        raise NotImplementedError(
            f"Need to implement function {sys._getframe().f_code.co_name} in "
            f"your sub-class of {self.__class__.__name__}. "
        )
|
def write_valid_summary(self, losses, stats): |
|
        raise NotImplementedError(
            f"Need to implement function {sys._getframe().f_code.co_name} in "
            f"your sub-class of {self.__class__.__name__}. "
        )
|
def echo_log(self, losses, mode="Training"): |
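        """Log one line: epoch/step timing followed by the flattened loss values."""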
|
        message = [
            "{} - Epoch {} Step {}: [{:.3f} s/step]".format(
                mode, self.epoch + 1, self.step, self.time_window.average
            )
        ]

        for key in sorted(losses.keys()):
            if isinstance(losses[key], dict):
                for k, v in losses[key].items():
                    message.append(
                        str(k).split("/")[-1] + "=" + str(round(float(v), 5))
                    )
            else:
                message.append(
                    str(key).split("/")[-1] + "=" + str(round(float(losses[key]), 5))
                )
        self.logger.info(", ".join(message))
|
def eval_epoch(self): |
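        """Evaluate over the whole validation loader and return the averaged losses."""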
|
        self.logger.info("Validation...")
        valid_losses = {}
        for i, batch_data in enumerate(self.data_loader["valid"]):
            for k, v in batch_data.items():
                if isinstance(v, torch.Tensor):
                    batch_data[k] = v.cuda()
            valid_loss, valid_stats, total_valid_loss = self.eval_step(batch_data, i)
            for key in valid_loss:
                if key not in valid_losses:
                    valid_losses[key] = 0
                valid_losses[key] += valid_loss[key]

        for key in valid_losses:
            valid_losses[key] /= i + 1
        self.echo_log(valid_losses, "Valid")
        return valid_losses, valid_stats
|
def train_epoch(self): |
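        """Train for one epoch: step, log, write summaries, checkpoint, and validate."""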
|
        for i, batch_data in enumerate(self.data_loader["train"]):
            start_time = time.time()

            for k, v in batch_data.items():
                if isinstance(v, torch.Tensor):
                    batch_data[k] = v.cuda(self.args.local_rank)

            train_losses, train_stats, total_loss = self.train_step(batch_data)
            self.time_window.append(time.time() - start_time)

            if self.args.local_rank == 0 or not self.cfg.train.ddp:
                if self.step % self.args.stdout_interval == 0:
                    self.echo_log(train_losses, "Training")

                if self.step % self.cfg.train.save_summary_steps == 0:
                    self.logger.info(f"Save summary as step {self.step}")
                    self.write_summary(train_losses, train_stats)

                if (
                    self.step % self.cfg.train.save_checkpoints_steps == 0
                    and self.step != 0
                ):
                    saved_model_name = "step-{:07d}_loss-{:.4f}.pt".format(
                        self.step, total_loss
                    )
                    saved_model_path = os.path.join(
                        self.checkpoint_dir, saved_model_name
                    )
                    saved_state_dict = self.get_state_dict()
                    self.save_checkpoint(saved_state_dict, saved_model_path)
                    self.save_config_file()

                    remove_older_ckpt(
                        saved_model_name,
                        self.checkpoint_dir,
                        max_to_keep=self.cfg.train.keep_checkpoint_max,
                    )

                if self.step != 0 and self.step % self.cfg.train.valid_interval == 0:
                    if isinstance(self.model, dict):
                        for key in self.model.keys():
                            self.model[key].eval()
                    else:
                        self.model.eval()

                    valid_losses, valid_stats = self.eval_epoch()
                    if isinstance(self.model, dict):
                        for key in self.model.keys():
                            self.model[key].train()
                    else:
                        self.model.train()

                    self.write_valid_summary(valid_losses, valid_stats)
            self.step += 1
|
def train(self): |
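        """Run ``train_epoch`` until ``max_epochs`` or until ``step`` exceeds ``max_steps``."""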
|
        for epoch in range(max(0, self.epoch), self.max_epochs):
            self.train_epoch()
            self.epoch += 1
            if self.step > self.max_steps:
                self.logger.info("Training finished!")
                break