import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '../utils'))
import numpy as np
import argparse
import time
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data

from utilities import (create_folder, get_filename, create_logging, Mixup,
    StatisticsContainer)
from models import (PVT, PVT2, PVT_lr, PVT_nopretrain, PVT_2layer, Cnn14,
    Cnn14_no_specaug, Cnn14_no_dropout, Cnn6, Cnn10, ResNet22, ResNet38,
    ResNet54, Cnn14_emb512, Cnn14_emb128, Cnn14_emb32, MobileNetV1,
    MobileNetV2, LeeNet11, LeeNet24, DaiNet19, Res1dNet31, Res1dNet51,
    Wavegram_Cnn14, Wavegram_Logmel_Cnn14, Wavegram_Logmel128_Cnn14,
    Cnn14_16k, Cnn14_8k, Cnn14_mel32, Cnn14_mel128,
    Cnn14_mixup_time_domain, Cnn14_DecisionLevelMax, Cnn14_DecisionLevelAtt,
    Cnn6_Transformer, GLAM, GLAM2, GLAM3, Cnn4, EAT)
# Optional model families. A class is only selectable through --model_type
# when it is imported into this module's namespace, so re-enable as needed.
#from models_test import (PVT_test)
#from models1 import (PVT1)
#from models_vig import (VIG, VIG2)
#from models_vvt import (VVT)
#from models2 import (MPVIT, MPVIT2)
#from models_reshape import (PVT_reshape, PVT_tscam)
#from models_swin import (Swin, Swin_nopretrain)
#from models_swin2 import (Swin2)
#from models_van import (Van, Van_tiny)
#from models_focal import (Focal)
#from models_cross import (Cross)
#from models_cov import (Cov)
#from models_cnn import (Cnn_light)
#from models_twins import (Twins)
#from models_cmt import (Cmt, Cmt1)
#from models_shunted import (Shunted)
#from models_quadtree import (Quadtree, Quadtree2, Quadtree_nopretrain)
#from models_davit import (Davit_tscam, Davit, Davit_nopretrain)
from pytorch_utils import (move_data_to_device, count_parameters, count_flops,
    do_mixup)
from data_generator import (AudioSetDataset, TrainSampler,
    BalancedTrainSampler, AlternateTrainSampler, EvaluateSampler, collate_fn)
from evaluate import Evaluator
import config
from losses import get_loss_func


def _experiment_subpath(args):
    """Return the relative directory that identifies one training setup.

    The same sub-path is appended to the workspace's ``checkpoints``,
    ``statistics`` and ``logs`` roots, so it is built in one place.
    """
    return os.path.join(
        args.filename,
        'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'.format(
            args.sample_rate, args.window_size, args.hop_size, args.mel_bins,
            args.fmin, args.fmax),
        'data_type={}'.format(args.data_type),
        args.model_type,
        'loss_type={}'.format(args.loss_type),
        'balanced={}'.format(args.balanced),
        'augmentation={}'.format(args.augmentation),
        'batch_size={}'.format(args.batch_size))


def train(args):
    """Train AudioSet tagging model.

    Every ``eval_interval`` (2000) iterations the model is evaluated on the
    balanced-train and eval index files, the statistics are dumped, a
    checkpoint is written and the plateau scheduler is stepped with the test
    mAP. Training stops when the loader is exhausted or ``early_stop`` is hit.

    Args:
      args: namespace with the following attributes
        workspace: str, root directory holding 'hdf5s/indexes' and receiving
            'checkpoints', 'statistics' and 'logs'
        data_type: 'balanced_train' | 'full_train'
        sample_rate: int
        window_size: int
        hop_size: int
        mel_bins: int
        fmin: int
        fmax: int
        model_type: str, name of a model class imported in this module
        loss_type: 'clip_bce'
        balanced: 'none' | 'balanced' | 'alternate'
        augmentation: 'none' | 'mixup'
        batch_size: int
        learning_rate: float
        resume_iteration: int, resume from this checkpoint when > 0
        early_stop: int, iteration at which training stops
        cuda: bool
        filename: str, first path component of all output directories

    Raises:
      ValueError: if ``args.balanced`` is not one of the supported modes.
    """

    # Arguments & parameters
    workspace = args.workspace
    data_type = args.data_type
    sample_rate = args.sample_rate
    window_size = args.window_size
    hop_size = args.hop_size
    mel_bins = args.mel_bins
    fmin = args.fmin
    fmax = args.fmax
    model_type = args.model_type
    loss_type = args.loss_type
    balanced = args.balanced
    augmentation = args.augmentation
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    resume_iteration = args.resume_iteration
    early_stop = args.early_stop
    use_cuda = bool(args.cuda and torch.cuda.is_available())
    device = 'cuda' if use_cuda else 'cpu'

    num_workers = 8
    eval_interval = 2000    # evaluate / checkpoint / scheduler period
    print_interval = 10     # loss print period
    classes_num = config.classes_num
    loss_func = get_loss_func(loss_type)
    use_mixup = 'mixup' in augmentation
    # Single source of truth for "are we resuming", used for both the
    # model/sampler restore and the optimizer/scheduler restore below.
    resuming = resume_iteration > 0

    # Paths
    black_list_csv = None
    indexes_dir = os.path.join(workspace, 'hdf5s', 'indexes')
    train_indexes_hdf5_path = os.path.join(
        indexes_dir, '{}.h5'.format(data_type))
    eval_bal_indexes_hdf5_path = os.path.join(indexes_dir, 'balanced_train.h5')
    eval_test_indexes_hdf5_path = os.path.join(indexes_dir, 'eval.h5')

    experiment_subpath = _experiment_subpath(args)

    checkpoints_dir = os.path.join(workspace, 'checkpoints', experiment_subpath)
    create_folder(checkpoints_dir)

    statistics_path = os.path.join(
        workspace, 'statistics', experiment_subpath, 'statistics.pkl')
    create_folder(os.path.dirname(statistics_path))

    logs_dir = os.path.join(workspace, 'logs', experiment_subpath)
    create_logging(logs_dir, filemode='w')
    logging.info(args)

    if use_cuda:
        logging.info('Using GPU.')
    else:
        logging.info('Using CPU. Set --cuda flag to use GPU.')

    # Model
    # NOTE(security): eval() resolves the class name in this module's
    # namespace and would execute arbitrary expressions. Only pass trusted
    # values for --model_type.
    Model = eval(model_type)
    model = Model(sample_rate=sample_rate, window_size=window_size,
        hop_size=hop_size, mel_bins=mel_bins, fmin=fmin, fmax=fmax,
        classes_num=classes_num)

    total = sum(p.numel() for p in model.parameters())
    print("Total params: %.2fM" % (total/1e6))
    logging.info("Total params: %.2fM" % (total/1e6))

    # Dataset will be used by DataLoader later.
    dataset = AudioSetDataset(sample_rate=sample_rate)

    # Train sampler
    sampler_classes = {
        'none': TrainSampler,
        'balanced': BalancedTrainSampler,
        'alternate': AlternateTrainSampler}
    if balanced not in sampler_classes:
        raise ValueError(
            "Unsupported balanced mode {!r}; expected one of {}".format(
                balanced, sorted(sampler_classes)))
    Sampler = sampler_classes[balanced]

    # With mixup the sampler is asked for twice the batch size.
    train_sampler = Sampler(
        indexes_hdf5_path=train_indexes_hdf5_path,
        batch_size=batch_size * 2 if use_mixup else batch_size,
        black_list_csv=black_list_csv)

    # Evaluate sampler
    eval_bal_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_bal_indexes_hdf5_path, batch_size=batch_size)

    eval_test_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_test_indexes_hdf5_path, batch_size=batch_size)

    # Data loader
    loader_kwargs = dict(dataset=dataset, collate_fn=collate_fn,
        num_workers=num_workers, pin_memory=True)

    train_loader = torch.utils.data.DataLoader(
        batch_sampler=train_sampler, **loader_kwargs)

    eval_bal_loader = torch.utils.data.DataLoader(
        batch_sampler=eval_bal_sampler, **loader_kwargs)

    eval_test_loader = torch.utils.data.DataLoader(
        batch_sampler=eval_test_sampler, **loader_kwargs)

    mix = 0.5
    if use_mixup:
        mixup_augmenter = Mixup(mixup_alpha=mix)
        print(mix)
        logging.info(mix)

    # Evaluator (holds the unwrapped model, as before)
    evaluator = Evaluator(model=model)

    # Statistics
    statistics_container = StatisticsContainer(statistics_path)

    # Optimizer
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate,
        betas=(0.9, 0.999), eps=1e-08, weight_decay=0.05, amsgrad=True)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max',
        factor=0.5, patience=4, min_lr=1e-06, verbose=True)

    train_bgn_time = time.time()

    # Resume training
    resume_state = None
    if resuming:
        resume_checkpoint_path = os.path.join(
            checkpoints_dir, '{}_iterations.pth'.format(resume_iteration))

        logging.info('Loading checkpoint {}'.format(resume_checkpoint_path))
        # Load onto CPU first so a checkpoint written on a GPU machine can be
        # resumed anywhere; tensors follow the model when it is moved below.
        resume_state = torch.load(resume_checkpoint_path, map_location='cpu')
        model.load_state_dict(resume_state['model'])
        train_sampler.load_state_dict(resume_state['sampler'])
        statistics_container.load_state_dict(resume_iteration)
        iteration = resume_state['iteration']
    else:
        iteration = 0

    # Parallel
    print('GPU number: {}'.format(torch.cuda.device_count()))
    # Keep a handle on the bare module so checkpoints never contain the
    # DataParallel 'module.' prefix, whether or not the model gets wrapped.
    base_model = model
    if use_cuda:
        # Only wrap when training on GPU: DataParallel over visible GPUs
        # with CPU-resident parameters fails at the first forward pass.
        model = torch.nn.DataParallel(model)
        model.to(device)

    if resuming:
        # Restored after the device move so the optimizer state ends up on
        # the same device as the parameters.
        optimizer.load_state_dict(resume_state['optimizer'])
        scheduler.load_state_dict(resume_state['scheduler'])
        print(optimizer.param_groups[0]['lr'])

    # Most recent test mAP; None until the first evaluation has run.
    test_map = None

    for batch_data_dict in train_loader:
        """batch_data_dict: {
            'audio_name': (batch_size [*2 if mixup],),
            'waveform': (batch_size [*2 if mixup], clip_samples),
            'target': (batch_size [*2 if mixup], classes_num),
            (ifexist) 'mixup_lambda': (batch_size * 2,)}
        """
        at_interval = iteration % eval_interval == 0

        # Evaluate
        if (at_interval and iteration >= resume_iteration) or (iteration == 0):
            train_fin_time = time.time()

            bal_statistics = evaluator.evaluate(eval_bal_loader)
            test_statistics = evaluator.evaluate(eval_test_loader)
            test_map = np.mean(test_statistics['average_precision'])

            logging.info('Validate bal mAP: {:.3f}'.format(
                np.mean(bal_statistics['average_precision'])))

            logging.info('Validate test mAP: {:.3f}'.format(test_map))

            statistics_container.append(iteration, bal_statistics, data_type='bal')
            statistics_container.append(iteration, test_statistics, data_type='test')
            statistics_container.dump()

            train_time = train_fin_time - train_bgn_time
            validate_time = time.time() - train_fin_time

            logging.info(
                'iteration: {}, train time: {:.3f} s, validate time: {:.3f} s'
                ''.format(iteration, train_time, validate_time))

            logging.info('------------------------------------')

            train_bgn_time = time.time()

        # Save model
        if at_interval:
            checkpoint = {
                'iteration': iteration,
                'model': base_model.state_dict(),
                'sampler': train_sampler.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()}

            checkpoint_path = os.path.join(
                checkpoints_dir, '{}_iterations.pth'.format(iteration))

            torch.save(checkpoint, checkpoint_path)
            logging.info('Model saved to {}'.format(checkpoint_path))

        # Mixup lambda
        if use_mixup:
            batch_data_dict['mixup_lambda'] = mixup_augmenter.get_lambda(
                batch_size=len(batch_data_dict['waveform']))

        # Move data to device
        for key in batch_data_dict.keys():
            batch_data_dict[key] = move_data_to_device(batch_data_dict[key], device)

        # Forward
        model.train()

        if use_mixup:
            batch_output_dict = model(batch_data_dict['waveform'],
                batch_data_dict['mixup_lambda'])
            """{'clipwise_output': (batch_size, classes_num), ...}"""

            batch_target_dict = {'target': do_mixup(batch_data_dict['target'],
                batch_data_dict['mixup_lambda'])}
            """{'target': (batch_size, classes_num)}"""
        else:
            batch_output_dict = model(batch_data_dict['waveform'], None)
            """{'clipwise_output': (batch_size, classes_num), ...}"""

            batch_target_dict = {'target': batch_data_dict['target']}
            """{'target': (batch_size, classes_num)}"""

        # Loss
        loss = loss_func(batch_output_dict, batch_target_dict)

        # Backward
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if iteration % print_interval == 0:
            print(iteration, loss)

        # Step the plateau scheduler on the latest test mAP. Skipped when no
        # evaluation has produced a metric yet, instead of failing on an
        # unbound variable.
        if at_interval and test_map is not None:
            scheduler.step(test_map)
            current_lr = optimizer.param_groups[0]['lr']
            print(current_lr)
            logging.info(current_lr)

        # Stop learning
        if iteration == early_stop:
            break

        iteration += 1


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Example of parser. ')
    subparsers = parser.add_subparsers(dest='mode')

    parser_train = subparsers.add_parser('train')
    parser_train.add_argument('--workspace', type=str, required=True)
    parser_train.add_argument('--data_type', type=str, default='full_train', choices=['balanced_train', 'full_train'])
    parser_train.add_argument('--sample_rate', type=int, default=32000)
    parser_train.add_argument('--window_size', type=int, default=1024)
    parser_train.add_argument('--hop_size', type=int, default=320)
    parser_train.add_argument('--mel_bins', type=int, default=64)
    parser_train.add_argument('--fmin', type=int, default=50)
    parser_train.add_argument('--fmax', type=int, default=14000)
    parser_train.add_argument('--model_type', type=str, required=True)
    parser_train.add_argument('--loss_type', type=str, default='clip_bce', choices=['clip_bce'])
    parser_train.add_argument('--balanced', type=str, default='balanced', choices=['none', 'balanced', 'alternate'])
    parser_train.add_argument('--augmentation', type=str, default='mixup', choices=['none', 'mixup'])
    parser_train.add_argument('--batch_size', type=int, default=32)
    parser_train.add_argument('--learning_rate', type=float, default=1e-3)
    parser_train.add_argument('--resume_iteration', type=int, default=0)
    parser_train.add_argument('--early_stop', type=int, default=1000000)
    parser_train.add_argument('--cuda', action='store_true', default=False)

    args = parser.parse_args()
    # The script's own name becomes the first component of every output path.
    args.filename = get_filename(__file__)

    if args.mode == 'train':
        train(args)

    else:
        raise Exception('Error argument!')