|
|
|
|
|
|
|
|
|
import sys |
|
|
|
sys.path.append("src") |
|
|
|
import os |
|
import wandb |
|
|
|
import argparse |
|
import yaml |
|
import torch |
|
from pytorch_lightning.strategies.ddp import DDPStrategy |
|
from qa_mdt.audioldm_train.utilities.data.dataset import AudioDataset |
|
from torch.utils.data import DataLoader |
|
from pytorch_lightning.loggers import WandbLogger |
|
from pytorch_lightning import Trainer |
|
from qa_mdt.audioldm_train.modules.latent_encoder.autoencoder import AutoencoderKL |
|
from pytorch_lightning.callbacks import ModelCheckpoint |
|
from qa_mdt.audioldm_train.utilities.tools import get_restore_step |
|
|
|
|
|
def listdir_nohidden(path): |
|
for f in os.listdir(path): |
|
if not f.startswith("."): |
|
yield f |
|
|
|
|
|
def main(configs, exp_group_name, exp_name): |
|
if "precision" in configs.keys(): |
|
torch.set_float32_matmul_precision(configs["precision"]) |
|
batch_size = config_yaml["model"]["params"]["batchsize"] |
|
log_path = config_yaml["log_directory"] |
|
|
|
if "dataloader_add_ons" in configs["data"].keys(): |
|
dataloader_add_ons = configs["data"]["dataloader_add_ons"] |
|
else: |
|
dataloader_add_ons = [] |
|
|
|
dataset = AudioDataset(config_yaml, split="train", add_ons=dataloader_add_ons) |
|
|
|
loader = DataLoader( |
|
dataset, batch_size=batch_size, num_workers=8, pin_memory=True, shuffle=True |
|
) |
|
|
|
print( |
|
"The length of the dataset is %s, the length of the dataloader is %s, the batchsize is %s" |
|
% (len(dataset), len(loader), batch_size) |
|
) |
|
|
|
val_dataset = AudioDataset(config_yaml, split="val", add_ons=dataloader_add_ons) |
|
|
|
val_loader = DataLoader( |
|
val_dataset, |
|
batch_size=batch_size, |
|
num_workers=8, |
|
shuffle=True, |
|
) |
|
|
|
model = AutoencoderKL( |
|
ddconfig=config_yaml["model"]["params"]["ddconfig"], |
|
lossconfig=config_yaml["model"]["params"]["lossconfig"], |
|
embed_dim=config_yaml["model"]["params"]["embed_dim"], |
|
image_key=config_yaml["model"]["params"]["image_key"], |
|
base_learning_rate=config_yaml["model"]["base_learning_rate"], |
|
subband=config_yaml["model"]["params"]["subband"], |
|
sampling_rate=config_yaml["preprocessing"]["audio"]["sampling_rate"], |
|
) |
|
|
|
try: |
|
config_reload_from_ckpt = configs["reload_from_ckpt"] |
|
except: |
|
config_reload_from_ckpt = None |
|
|
|
checkpoint_path = os.path.join(log_path, exp_group_name, exp_name, "checkpoints") |
|
|
|
checkpoint_callback = ModelCheckpoint( |
|
dirpath=checkpoint_path, |
|
monitor="global_step", |
|
mode="max", |
|
filename="checkpoint-{global_step:.0f}", |
|
every_n_train_steps=5000, |
|
save_top_k=config_yaml["step"]["save_top_k"], |
|
auto_insert_metric_name=False, |
|
save_last=True, |
|
) |
|
|
|
wandb_path = os.path.join(log_path, exp_group_name, exp_name) |
|
|
|
model.set_log_dir(log_path, exp_group_name, exp_name) |
|
|
|
os.makedirs(checkpoint_path, exist_ok=True) |
|
|
|
if len(os.listdir(checkpoint_path)) > 0: |
|
print("Load checkpoint from path: %s" % checkpoint_path) |
|
restore_step, n_step = get_restore_step(checkpoint_path) |
|
resume_from_checkpoint = os.path.join(checkpoint_path, restore_step) |
|
print("Resume from checkpoint", resume_from_checkpoint) |
|
elif config_reload_from_ckpt is not None: |
|
resume_from_checkpoint = config_reload_from_ckpt |
|
print("Reload ckpt specified in the config file %s" % resume_from_checkpoint) |
|
else: |
|
print("Train from scratch") |
|
resume_from_checkpoint = None |
|
|
|
devices = torch.cuda.device_count() |
|
|
|
wandb_logger = WandbLogger( |
|
save_dir=wandb_path, |
|
project=config_yaml["project"], |
|
config=config_yaml, |
|
name="%s/%s" % (exp_group_name, exp_name), |
|
) |
|
|
|
trainer = Trainer( |
|
accelerator="gpu", |
|
devices=devices, |
|
logger=wandb_logger, |
|
limit_val_batches=100, |
|
callbacks=[checkpoint_callback], |
|
strategy=DDPStrategy(find_unused_parameters=True), |
|
val_check_interval=2000, |
|
) |
|
|
|
|
|
trainer.fit(model, loader, val_loader, ckpt_path=resume_from_checkpoint) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument( |
|
"-c", |
|
"--autoencoder_config", |
|
type=str, |
|
required=True, |
|
help="path to autoencoder config .yam", |
|
) |
|
|
|
args = parser.parse_args() |
|
|
|
config_yaml = args.autoencoder_config |
|
exp_name = os.path.basename(config_yaml.split(".")[0]) |
|
exp_group_name = os.path.basename(os.path.dirname(config_yaml)) |
|
|
|
config_yaml = os.path.join(config_yaml) |
|
|
|
config_yaml = yaml.load(open(config_yaml, "r"), Loader=yaml.FullLoader) |
|
|
|
main(config_yaml, exp_group_name, exp_name) |
|
|