lmzjms's picture
Upload 35 files
15ac91d
raw
history blame
No virus
15.1 kB
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '../utils'))
import numpy as np
import argparse
import time
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from utilities import (create_folder, get_filename, create_logging, Mixup,
StatisticsContainer)
from models import (PVT, PVT2, PVT_lr, PVT_nopretrain, PVT_2layer, Cnn14, Cnn14_no_specaug, Cnn14_no_dropout,
Cnn6, Cnn10, ResNet22, ResNet38, ResNet54, Cnn14_emb512, Cnn14_emb128,
Cnn14_emb32, MobileNetV1, MobileNetV2, LeeNet11, LeeNet24, DaiNet19,
Res1dNet31, Res1dNet51, Wavegram_Cnn14, Wavegram_Logmel_Cnn14,
Wavegram_Logmel128_Cnn14, Cnn14_16k, Cnn14_8k, Cnn14_mel32, Cnn14_mel128,
Cnn14_mixup_time_domain, Cnn14_DecisionLevelMax, Cnn14_DecisionLevelAtt, Cnn6_Transformer, GLAM, GLAM2, GLAM3, Cnn4, EAT)
#from models_test import (PVT_test)
#from models1 import (PVT1)
#from models_vig import (VIG, VIG2)
#from models_vvt import (VVT)
#from models2 import (MPVIT, MPVIT2)
#from models_reshape import (PVT_reshape, PVT_tscam)
#from models_swin import (Swin, Swin_nopretrain)
#from models_swin2 import (Swin2)
#from models_van import (Van, Van_tiny)
#from models_focal import (Focal)
#from models_cross import (Cross)
#from models_cov import (Cov)
#from models_cnn import (Cnn_light)
#from models_twins import (Twins)
#from models_cmt import (Cmt, Cmt1)
#from models_shunted import (Shunted)
#from models_quadtree import (Quadtree, Quadtree2, Quadtree_nopretrain)
#from models_davit import (Davit_tscam, Davit, Davit_nopretrain)
from pytorch_utils import (move_data_to_device, count_parameters, count_flops,
do_mixup)
from data_generator import (AudioSetDataset, TrainSampler, BalancedTrainSampler,
AlternateTrainSampler, EvaluateSampler, collate_fn)
from evaluate import Evaluator
import config
from losses import get_loss_func
def _experiment_subdir(sample_rate, window_size, hop_size, mel_bins, fmin, fmax,
                       data_type, model_type, loss_type, balanced, augmentation,
                       batch_size):
    """Return the relative sub-path that encodes one run's configuration.

    ``train`` nests checkpoints, statistics and logs under this same sub-path,
    so every artefact of a run lands in a matching directory.
    """
    return os.path.join(
        'sample_rate={},window_size={},hop_size={},mel_bins={},fmin={},fmax={}'.format(
            sample_rate, window_size, hop_size, mel_bins, fmin, fmax),
        'data_type={}'.format(data_type),
        model_type,
        'loss_type={}'.format(loss_type),
        'balanced={}'.format(balanced),
        'augmentation={}'.format(augmentation),
        'batch_size={}'.format(batch_size))


def _make_loader(dataset, batch_sampler, num_workers):
    """Wrap *dataset* in a DataLoader that draws whole batches from *batch_sampler*."""
    return torch.utils.data.DataLoader(dataset=dataset,
        batch_sampler=batch_sampler, collate_fn=collate_fn,
        num_workers=num_workers, pin_memory=True)


def _current_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group."""
    return optimizer.param_groups[0]['lr']


def train(args):
    """Train an AudioSet tagging model.

    Every 2000 iterations (and at iteration 0) the model is evaluated on the
    ``balanced_train`` and ``eval`` index sets, and a checkpoint is saved.
    On those same iterations, ReduceLROnPlateau is stepped on the test mAP.

    Args:
      args: argparse.Namespace with the attributes
        workspace: str, root that holds ``hdf5s/indexes`` and receives
          ``checkpoints/``, ``statistics/`` and ``logs/``
        data_type: 'balanced_train' | 'full_train', selects the training
          index file ``hdf5s/indexes/<data_type>.h5``
        sample_rate, window_size, hop_size, mel_bins, fmin, fmax: int,
          forwarded to the model constructor
        model_type: str, name of a model class imported by this module
        loss_type: 'clip_bce'
        balanced: 'none' | 'balanced' | 'alternate', training sampler
        augmentation: 'none' | 'mixup' (mixup makes the sampler draw
          2 * batch_size clips per step)
        batch_size: int
        learning_rate: float
        resume_iteration: int, if > 0 resume from
          ``<resume_iteration>_iterations.pth`` of this run
        early_stop: int, training stops after the step at this iteration
        cuda: bool, use the GPU when one is available
        filename: str, run name used as a path component

    Raises:
      ValueError: unknown ``model_type`` or ``balanced`` value.
    """
    # Arguments & parameters
    workspace = args.workspace
    data_type = args.data_type
    sample_rate = args.sample_rate
    window_size = args.window_size
    hop_size = args.hop_size
    mel_bins = args.mel_bins
    fmin = args.fmin
    fmax = args.fmax
    model_type = args.model_type
    loss_type = args.loss_type
    balanced = args.balanced
    augmentation = args.augmentation
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    resume_iteration = args.resume_iteration
    early_stop = args.early_stop
    filename = args.filename

    use_mixup = 'mixup' in augmentation
    # A single definition of "resuming" keeps the checkpoint-loading and
    # optimizer-restoring branches below in agreement.
    resuming = resume_iteration > 0
    eval_interval = 2000  # period for evaluation, checkpointing and LR scheduling
    num_workers = 8
    classes_num = config.classes_num
    loss_func = get_loss_func(loss_type)

    # Paths
    black_list_csv = None
    indexes_dir = os.path.join(workspace, 'hdf5s', 'indexes')
    train_indexes_hdf5_path = os.path.join(indexes_dir, '{}.h5'.format(data_type))
    eval_bal_indexes_hdf5_path = os.path.join(indexes_dir, 'balanced_train.h5')
    eval_test_indexes_hdf5_path = os.path.join(indexes_dir, 'eval.h5')

    run_subdir = _experiment_subdir(
        sample_rate, window_size, hop_size, mel_bins, fmin, fmax,
        data_type, model_type, loss_type, balanced, augmentation, batch_size)

    checkpoints_dir = os.path.join(workspace, 'checkpoints', filename, run_subdir)
    create_folder(checkpoints_dir)

    statistics_path = os.path.join(workspace, 'statistics', filename, run_subdir,
                                   'statistics.pkl')
    create_folder(os.path.dirname(statistics_path))

    logs_dir = os.path.join(workspace, 'logs', filename, run_subdir)
    create_logging(logs_dir, filemode='w')
    logging.info(args)

    if args.cuda and torch.cuda.is_available():
        logging.info('Using GPU.')
        device = 'cuda'
    else:
        logging.info('Using CPU. Set --cuda flag to use GPU.')
        device = 'cpu'

    # Model. This replaces the former ``eval(model_type)``: a plain lookup
    # among this module's globals accepts the same class names but cannot
    # execute an arbitrary expression supplied on the command line.
    Model = globals().get(model_type)
    if not callable(Model):
        raise ValueError('Unknown model_type: {}'.format(model_type))
    model = Model(sample_rate=sample_rate, window_size=window_size,
        hop_size=hop_size, mel_bins=mel_bins, fmin=fmin, fmax=fmax,
        classes_num=classes_num)

    total = sum(p.numel() for p in model.parameters())
    print("Total params: %.2fM" % (total/1e6))
    logging.info("Total params: %.2fM" % (total/1e6))

    # The dataset is indexed by the samplers below. It maps a meta entry to a
    # waveform and a target.
    dataset = AudioSetDataset(sample_rate=sample_rate)

    # Train sampler
    samplers = {
        'none': TrainSampler,
        'balanced': BalancedTrainSampler,
        'alternate': AlternateTrainSampler,
    }
    try:
        Sampler = samplers[balanced]
    except KeyError:
        raise ValueError('Unknown balanced mode: {}'.format(balanced)) from None

    # With mixup, each step draws twice as many clips so they can be mixed in pairs.
    train_sampler = Sampler(
        indexes_hdf5_path=train_indexes_hdf5_path,
        batch_size=batch_size * 2 if use_mixup else batch_size,
        black_list_csv=black_list_csv)

    # Evaluate samplers
    eval_bal_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_bal_indexes_hdf5_path, batch_size=batch_size)
    eval_test_sampler = EvaluateSampler(
        indexes_hdf5_path=eval_test_indexes_hdf5_path, batch_size=batch_size)

    # Data loaders
    train_loader = _make_loader(dataset, train_sampler, num_workers)
    eval_bal_loader = _make_loader(dataset, eval_bal_sampler, num_workers)
    eval_test_loader = _make_loader(dataset, eval_test_sampler, num_workers)

    mixup_alpha = 0.5
    if use_mixup:
        mixup_augmenter = Mixup(mixup_alpha=mixup_alpha)
    print(mixup_alpha)
    logging.info(mixup_alpha)

    evaluator = Evaluator(model=model)
    statistics_container = StatisticsContainer(statistics_path)

    # Optimizer and LR schedule. The LR is halved when test mAP stops improving.
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.05, amsgrad=True)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=4, min_lr=1e-06, verbose=True)

    train_bgn_time = time.time()

    # Resume training
    checkpoint = None
    if resuming:
        resume_checkpoint_path = os.path.join(
            checkpoints_dir, '{}_iterations.pth'.format(resume_iteration))
        logging.info('Loading checkpoint {}'.format(resume_checkpoint_path))
        # map_location='cpu' lets checkpoints saved from GPU runs load on
        # CPU-only hosts. The model is still on the CPU at this point and is
        # moved to `device` below. optimizer.load_state_dict later casts its
        # state to the parameters' device.
        # NOTE(review): torch>=2.6 defaults to weights_only=True, which may
        # reject the sampler state. Confirm against the installed torch.
        checkpoint = torch.load(resume_checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        train_sampler.load_state_dict(checkpoint['sampler'])
        statistics_container.load_state_dict(resume_iteration)
        iteration = checkpoint['iteration']
    else:
        iteration = 0

    # Parallel. The model is always wrapped, so checkpoints save model.module.
    print('GPU number: {}'.format(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model)
    if device == 'cuda':
        model.to(device)

    if resuming:
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        print(_current_lr(optimizer))

    # Most recent test mAP. It feeds the plateau scheduler and stays None
    # until the first evaluation has run.
    latest_test_map = None

    for batch_data_dict in train_loader:
        # batch_data_dict: {
        #   'audio_name': (batch_size [*2 if mixup],),
        #   'waveform': (batch_size [*2 if mixup], clip_samples),
        #   'target': (batch_size [*2 if mixup], classes_num),
        #   (if exists) 'mixup_lambda': (batch_size * 2,)}
        on_interval = iteration % eval_interval == 0

        # Evaluate
        if (on_interval and iteration >= resume_iteration) or (iteration == 0):
            train_fin_time = time.time()

            bal_statistics = evaluator.evaluate(eval_bal_loader)
            test_statistics = evaluator.evaluate(eval_test_loader)
            latest_test_map = np.mean(test_statistics['average_precision'])

            logging.info('Validate bal mAP: {:.3f}'.format(
                np.mean(bal_statistics['average_precision'])))
            logging.info('Validate test mAP: {:.3f}'.format(latest_test_map))

            statistics_container.append(iteration, bal_statistics, data_type='bal')
            statistics_container.append(iteration, test_statistics, data_type='test')
            statistics_container.dump()

            train_time = train_fin_time - train_bgn_time
            validate_time = time.time() - train_fin_time
            logging.info(
                'iteration: {}, train time: {:.3f} s, validate time: {:.3f} s'
                ''.format(iteration, train_time, validate_time))
            logging.info('------------------------------------')

            train_bgn_time = time.time()

        # Save model
        if on_interval:
            checkpoint = {
                'iteration': iteration,
                'model': model.module.state_dict(),
                'sampler': train_sampler.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()}
            checkpoint_path = os.path.join(
                checkpoints_dir, '{}_iterations.pth'.format(iteration))
            torch.save(checkpoint, checkpoint_path)
            logging.info('Model saved to {}'.format(checkpoint_path))

        # Mixup lambda
        if use_mixup:
            batch_data_dict['mixup_lambda'] = mixup_augmenter.get_lambda(
                batch_size=len(batch_data_dict['waveform']))

        # Move data to device
        for key in batch_data_dict:
            batch_data_dict[key] = move_data_to_device(batch_data_dict[key], device)

        # Forward. Outputs are {'clipwise_output': (batch_size, classes_num), ...}.
        # Targets are {'target': (batch_size, classes_num)}. With mixup, do_mixup
        # presumably combines the doubled batch pairwise; see pytorch_utils.
        model.train()
        if use_mixup:
            batch_output_dict = model(batch_data_dict['waveform'],
                batch_data_dict['mixup_lambda'])
            batch_target_dict = {'target': do_mixup(batch_data_dict['target'],
                batch_data_dict['mixup_lambda'])}
        else:
            batch_output_dict = model(batch_data_dict['waveform'], None)
            batch_target_dict = {'target': batch_data_dict['target']}

        # Loss, backward, update
        loss = loss_func(batch_output_dict, batch_target_dict)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if iteration % 10 == 0:
            # .item() prints the scalar instead of the tensor repr (device, grad_fn).
            print(iteration, loss.item())

        # Step the plateau scheduler on the most recent test mAP. The None
        # guard avoids a NameError when no evaluation has run yet.
        if on_interval and latest_test_map is not None:
            scheduler.step(latest_test_map)
            lr = _current_lr(optimizer)
            print(lr)
            logging.info(lr)

        # Stop learning
        if iteration == early_stop:
            break

        iteration += 1
if __name__ == '__main__':
    # Command-line entry point. Only the `train` sub-command is defined.
    parser = argparse.ArgumentParser(description='Example of parser. ')
    subparsers = parser.add_subparsers(dest='mode')

    parser_train = subparsers.add_parser('train')
    parser_train.add_argument('--workspace', type=str, required=True)
    parser_train.add_argument('--data_type', type=str, default='full_train', choices=['balanced_train', 'full_train'])
    parser_train.add_argument('--sample_rate', type=int, default=32000)
    parser_train.add_argument('--window_size', type=int, default=1024)
    parser_train.add_argument('--hop_size', type=int, default=320)
    parser_train.add_argument('--mel_bins', type=int, default=64)
    parser_train.add_argument('--fmin', type=int, default=50)
    parser_train.add_argument('--fmax', type=int, default=14000)
    parser_train.add_argument('--model_type', type=str, required=True)
    parser_train.add_argument('--loss_type', type=str, default='clip_bce', choices=['clip_bce'])
    parser_train.add_argument('--balanced', type=str, default='balanced', choices=['none', 'balanced', 'alternate'])
    parser_train.add_argument('--augmentation', type=str, default='mixup', choices=['none', 'mixup'])
    parser_train.add_argument('--batch_size', type=int, default=32)
    parser_train.add_argument('--learning_rate', type=float, default=1e-3)
    parser_train.add_argument('--resume_iteration', type=int, default=0)
    parser_train.add_argument('--early_stop', type=int, default=1000000)
    parser_train.add_argument('--cuda', action='store_true', default=False)

    args = parser.parse_args()
    # The run name becomes a path component for checkpoints, statistics and logs.
    args.filename = get_filename(__file__)

    if args.mode == 'train':
        train(args)
    else:
        # argparse already rejects unknown sub-commands, so this branch is only
        # reached when none was given. parser.error prints the usage and exits
        # with status 2 instead of raising a bare Exception with a traceback.
        parser.error('Error argument! Expected a sub-command, e.g. "train".')