Spaces:
Running
Running
import numpy as np | |
import matplotlib.pyplot as plt | |
import os | |
import torch | |
import torch.nn.functional as F | |
from torch.optim.lr_scheduler import CosineAnnealingLR | |
import json | |
import argparse | |
import glob | |
import sys | |
sys.path.append(os.path.join('..', '..')) | |
import bat_detect.train.train_model as tm | |
import bat_detect.train.audio_dataloader as adl | |
import bat_detect.train.evaluate as evl | |
import bat_detect.train.train_utils as tu | |
import bat_detect.train.losses as losses | |
import bat_detect.detector.parameters as parameters | |
import bat_detect.detector.models as models | |
import bat_detect.detector.post_process as pp | |
import bat_detect.utils.plot_utils as pu | |
import bat_detect.utils.detector_utils as du | |
def build_arg_parser():
    """Build the command line parser for the finetuning script.

    Positional arguments are the audio directory, the train and test
    annotation files, and the path to the pretrained model. Optional flags
    control the output model name, the number of epochs, which layers are
    trained, whether pretrained weights are used, whether images are saved
    at the end of training, and free-text notes.

    Returns:
        argparse.ArgumentParser: the configured parser (nothing is parsed here).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('audio_path', type=str, help='Input directory for audio')
    parser.add_argument('train_ann_path', type=str,
                        help='Path to where train annotation file is stored')
    parser.add_argument('test_ann_path', type=str,
                        help='Path to where test annotation file is stored')
    parser.add_argument('model_path', type=str,
                        help='Path to pretrained model')
    parser.add_argument('--op_model_name', type=str, default='',
                        help='Path and name for finetuned model')
    parser.add_argument('--num_epochs', type=int, default=200, dest='num_epochs',
                        help='Number of finetuning epochs')
    parser.add_argument('--finetune_only_last_layer', action='store_true',
                        help='Only train final layers')
    parser.add_argument('--train_from_scratch', action='store_true',
                        help='Do not use pretrained weights')
    # store_true (not store_false): the value must be False by default and
    # True only when the user asks to skip saving images, because the script
    # tests `if not args['do_not_save_images']` before saving.
    parser.add_argument('--do_not_save_images', action='store_true',
                        help='Do not save images at the end of training')
    parser.add_argument('--notes', type=str, default='',
                        help='Notes to save in text file')
    return parser


if __name__ == "__main__":
    info_str = '\nBatDetect - Finetune Model\n'
    print(info_str)

    args = vars(build_arg_parser().parse_args())

    params = parameters.get_params(True, '../../experiments/')

    # pick the compute device; training still works on CPU, just slower
    if torch.cuda.is_available():
        params['device'] = 'cuda'
    else:
        params['device'] = 'cpu'
        print('\nNote, this will be a lot faster if you use computer with a GPU.\n')

    print('\nAudio directory:      ' + args['audio_path'])
    print('Train file:           ' + args['train_ann_path'])
    print('Test file:            ' + args['test_ann_path'])
    print('Loading model:        ' + args['model_path'])

    # dataset name is the train annotation file name without '.json' / '_TRAIN'
    dataset_name = os.path.basename(args['train_ann_path']).replace('.json', '').replace('_TRAIN', '')

    if args['train_from_scratch']:
        print('\nTraining model from scratch i.e. not using pretrained weights')
    # second argument is False when training from scratch, True otherwise
    model, params_train = du.load_model(args['model_path'], not args['train_from_scratch'])
    model.to(params['device'])

    params['num_epochs'] = args['num_epochs']
    if args['op_model_name'] != '':
        params['model_file_name'] = args['op_model_name']
    classes_to_ignore = params['classes_to_ignore'] + params['generic_class']

    # save notes file
    params['notes'] = args['notes']
    if args['notes'] != '':
        tu.write_notes_file(params['experiment'] + 'notes.txt', args['notes'])

    # load train annotations
    # train_sets (used for loading) keeps the annotation path as given, while
    # the copy stored in params only records the annotation file's base name
    train_sets = [tu.get_blank_dataset_dict(dataset_name, False, args['train_ann_path'], args['audio_path'])]
    params['train_sets'] = [tu.get_blank_dataset_dict(dataset_name, False,
                                                      os.path.basename(args['train_ann_path']),
                                                      args['audio_path'])]

    print('\nTrain set:')
    data_train, params['class_names'], params['class_inv_freq'] = \
        tu.load_set_of_anns(train_sets, classes_to_ignore, params['events_of_interest'])
    print('Number of files', len(data_train))

    params['genus_names'], params['genus_mapping'] = tu.get_genus_mapping(params['class_names'])
    params['class_names_short'] = tu.get_short_class_names(params['class_names'])

    # load test annotations
    test_sets = [tu.get_blank_dataset_dict(dataset_name, True, args['test_ann_path'], args['audio_path'])]
    params['test_sets'] = [tu.get_blank_dataset_dict(dataset_name, True,
                                                     os.path.basename(args['test_ann_path']),
                                                     args['audio_path'])]

    print('\nTest set:')
    data_test, _, _ = tu.load_set_of_anns(test_sets, classes_to_ignore, params['events_of_interest'])
    print('Number of files', len(data_test))

    # train loader
    train_dataset = adl.AudioLoader(data_train, params, is_train=True)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=params['batch_size'],
                                               shuffle=True, num_workers=params['num_workers'],
                                               pin_memory=True)

    # test loader - batch size of one because of variable file length
    test_dataset = adl.AudioLoader(data_test, params, is_train=False)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1,
                                              shuffle=False, num_workers=params['num_workers'],
                                              pin_memory=True)

    # pull one batch to record the input height (dim 2 of the 'spec' tensor)
    inputs_train = next(iter(train_loader))
    params['ip_height'] = inputs_train['spec'].shape[2]
    print('\ntrain batch size :', inputs_train['spec'].shape)

    # explicit check rather than `assert`, which is stripped under `python -O`
    if params_train['model_name'] != 'Net2DFast':
        raise ValueError('Only Net2DFast models can be finetuned, got: '
                         + str(params_train['model_name']))
    print('\n\nSOME hyperparams need to be the same as the loaded model (e.g. FFT) - currently they are getting overwritten.\n\n')

    # set the number of output classes
    # the replacement layer reuses the old layer's input channels, kernel size
    # and padding; it outputs one channel per class name plus one extra
    # (presumably a background / "not a call" channel - confirm against the model)
    num_filts = model.conv_classes_op.in_channels
    k_size = model.conv_classes_op.kernel_size
    pad = model.conv_classes_op.padding
    model.conv_classes_op = torch.nn.Conv2d(num_filts, len(params['class_names']) + 1,
                                            kernel_size=k_size, padding=pad)
    model.conv_classes_op.to(params['device'])

    if args['finetune_only_last_layer']:
        print('\nOnly finetuning the final layers.\n')
        train_layers_i = ['conv_classes', 'conv_classes_op', 'conv_size', 'conv_size_op']
        train_layers = set([tt + '.weight' for tt in train_layers_i] +
                           [tt + '.bias' for tt in train_layers_i])
        # freeze every parameter except the weights / biases of the layers above
        for name, param in model.named_parameters():
            param.requires_grad = name in train_layers

    optimizer = torch.optim.Adam(model.parameters(), lr=params['lr'])
    # T_max is the epoch count times the number of batches per epoch
    scheduler = CosineAnnealingLR(optimizer, params['num_epochs'] * len(train_loader))

    if params['train_loss'] == 'mse':
        det_criterion = losses.mse_loss
    elif params['train_loss'] == 'focal':
        det_criterion = losses.focal_loss
    else:
        # fail early instead of hitting a NameError inside the training loop
        raise ValueError('Unsupported train_loss: ' + str(params['train_loss']))

    # plotting
    train_plt_ls = pu.LossPlotter(params['experiment'] + 'train_loss.png', params['num_epochs'] + 1,
                                  ['train_loss'], None, None, ['epoch', 'train_loss'], logy=True)
    test_plt_ls = pu.LossPlotter(params['experiment'] + 'test_loss.png', params['num_epochs'] + 1,
                                 ['test_loss'], None, None, ['epoch', 'test_loss'], logy=True)
    test_plt = pu.LossPlotter(params['experiment'] + 'test.png', params['num_epochs'] + 1,
                              ['avg_prec', 'rec_at_x', 'avg_prec_class', 'file_acc', 'top_class'],
                              [0, 1], None, ['epoch', ''])
    test_plt_class = pu.LossPlotter(params['experiment'] + 'test_avg_prec.png', params['num_epochs'] + 1,
                                    params['class_names_short'], [0, 1],
                                    params['class_names_short'], ['epoch', 'avg_prec'])

    # main train loop
    # note: runs num_epochs + 1 iterations (epochs 0 .. num_epochs inclusive)
    for epoch in range(0, params['num_epochs'] + 1):
        train_loss = tm.train(model, epoch, train_loader, det_criterion, optimizer, scheduler, params)
        train_plt_ls.update_and_save(epoch, [train_loss['train_loss']])

        if epoch % params['num_eval_epochs'] == 0:
            # detection accuracy on test set
            test_res, test_loss = tm.test(model, epoch, test_loader, det_criterion, params)
            test_plt_ls.update_and_save(epoch, [test_loss['test_loss']])
            test_plt.update_and_save(epoch, [test_res['avg_prec'], test_res['rec_at_x'],
                                             test_res['avg_prec_class'], test_res['file_acc'],
                                             test_res['top_class']['avg_prec']])
            test_plt_class.update_and_save(epoch, [rs['avg_prec'] for rs in test_res['class_pr']])
            pu.plot_pr_curve_class(params['experiment'], 'test_pr', 'test_pr', test_res)

    # save finetuned model
    print('saving model to: ' + params['model_file_name'])
    op_state = {'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'params': params}
    torch.save(op_state, params['model_file_name'])

    # save an image with associated prediction for each batch in the test set
    if not args['do_not_save_images']:
        tm.save_images_batch(model, test_loader, params)