# batdetect2: bat_detect/finetune/finetune_model.py
# Last commit: 9ace58a by Oisin Mac Aodha ("added bat code")
import numpy as np
import matplotlib.pyplot as plt
import os
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
import json
import argparse
import glob
import sys
sys.path.append(os.path.join('..', '..'))
import bat_detect.train.train_model as tm
import bat_detect.train.audio_dataloader as adl
import bat_detect.train.evaluate as evl
import bat_detect.train.train_utils as tu
import bat_detect.train.losses as losses
import bat_detect.detector.parameters as parameters
import bat_detect.detector.models as models
import bat_detect.detector.post_process as pp
import bat_detect.utils.plot_utils as pu
import bat_detect.utils.detector_utils as du
# Parameter-name prefixes that stay trainable with --finetune_only_last_layer.
_FINAL_LAYER_NAMES = ('conv_classes', 'conv_classes_op', 'conv_size', 'conv_size_op')


def _parse_args():
    """Parse and validate the command-line arguments.

    Returns:
        dict: argument name -> parsed value (``vars`` of the argparse namespace).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('audio_path', type=str, help='Input directory for audio')
    parser.add_argument('train_ann_path', type=str,
                        help='Path to where train annotation file is stored')
    parser.add_argument('test_ann_path', type=str,
                        help='Path to where test annotation file is stored')
    parser.add_argument('model_path', type=str,
                        help='Path to pretrained model')
    parser.add_argument('--op_model_name', type=str, default='',
                        help='Path and name for finetuned model')
    parser.add_argument('--num_epochs', type=int, default=200, dest='num_epochs',
                        help='Number of finetuning epochs')
    parser.add_argument('--finetune_only_last_layer', action='store_true',
                        help='Only train final layers')
    parser.add_argument('--train_from_scratch', action='store_true',
                        help='Do not use pretrained weights')
    # store_true so the flag's value matches its name: False by default
    # (images are saved), True when the user passes it (images are skipped).
    parser.add_argument('--do_not_save_images', action='store_true',
                        help='Do not save images at the end of training')
    parser.add_argument('--notes', type=str, default='',
                        help='Notes to save in text file')
    args = parser.parse_args()
    # The cosine LR period below is num_epochs * batches-per-epoch and must be
    # positive; at least one epoch also guarantees `epoch` is bound when saving.
    if args.num_epochs < 1:
        parser.error('--num_epochs must be at least 1')
    return vars(args)


def _get_det_criterion(loss_name):
    """Return the detection loss function selected by ``loss_name``.

    Args:
        loss_name: value of ``params['train_loss']``.

    Returns:
        ``losses.mse_loss`` for 'mse' or ``losses.focal_loss`` for 'focal'.

    Raises:
        ValueError: for any other loss name, instead of leaving the criterion
            undefined until the training loop fails.
    """
    if loss_name == 'mse':
        return losses.mse_loss
    if loss_name == 'focal':
        return losses.focal_loss
    raise ValueError("Unsupported train_loss '{}': expected 'mse' or 'focal'".format(loss_name))


def _replace_class_output_layer(model, num_classes, device):
    """Swap ``model.conv_classes_op`` for a freshly initialised Conv2d.

    The new layer keeps the old layer's input channels, kernel size and
    padding, but outputs ``num_classes + 1`` channels so the head matches the
    finetuning dataset's class list. The extra channel mirrors the original
    head's layout (presumably a background / non-call channel; confirm against
    the model definition). The new layer is moved to ``device``.
    """
    old_layer = model.conv_classes_op
    model.conv_classes_op = torch.nn.Conv2d(old_layer.in_channels, num_classes + 1,
                                            kernel_size=old_layer.kernel_size,
                                            padding=old_layer.padding)
    model.conv_classes_op.to(device)


def _freeze_all_but_final_layers(model):
    """Enable gradients only for the weights/biases in ``_FINAL_LAYER_NAMES``."""
    trainable = {layer + suffix for layer in _FINAL_LAYER_NAMES
                 for suffix in ('.weight', '.bias')}
    for name, param in model.named_parameters():
        param.requires_grad = name in trainable


def main():
    """Finetune a pretrained BatDetect model on a new annotated dataset.

    The steps are:
    1. Parse the arguments.
    2. Load the pretrained model and the train/test annotations.
    3. Replace the class output layer so it matches the new class list.
    4. Optionally freeze all but the final layers.
    5. Train with Adam and a cosine LR schedule, evaluating and plotting
       every ``num_eval_epochs``.
    6. Save the finetuned model, and optionally per-batch prediction images
       for the test set.
    """
    info_str = '\nBatDetect - Finetune Model\n'
    print(info_str)
    args = _parse_args()

    params = parameters.get_params(True, '../../experiments/')
    if torch.cuda.is_available():
        params['device'] = 'cuda'
    else:
        params['device'] = 'cpu'
        print('\nNote, this will be a lot faster if you use computer with a GPU.\n')

    print('\nAudio directory: ' + args['audio_path'])
    print('Train file: ' + args['train_ann_path'])
    print('Test file: ' + args['test_ann_path'])
    print('Loading model: ' + args['model_path'])

    dataset_name = os.path.basename(args['train_ann_path']).replace('.json', '').replace('_TRAIN', '')

    if args['train_from_scratch']:
        print('\nTraining model from scratch i.e. not using pretrained weights')
    # The second argument is False when training from scratch, True otherwise.
    model, params_train = du.load_model(args['model_path'], not args['train_from_scratch'])
    # Fail fast, before loading any data: the layer surgery below
    # (conv_classes_op etc.) assumes this architecture.
    if params_train['model_name'] != 'Net2DFast':
        raise ValueError('Finetuning only supports Net2DFast models, got: '
                         + str(params_train['model_name']))
    model.to(params['device'])

    params['num_epochs'] = args['num_epochs']
    if args['op_model_name'] != '':
        params['model_file_name'] = args['op_model_name']
    classes_to_ignore = params['classes_to_ignore'] + params['generic_class']

    # save notes file
    params['notes'] = args['notes']
    if args['notes'] != '':
        tu.write_notes_file(params['experiment'] + 'notes.txt', args['notes'])

    # load train annotations; the copy stored in params (saved with the model)
    # records only the annotation file's basename
    train_sets = [tu.get_blank_dataset_dict(dataset_name, False, args['train_ann_path'], args['audio_path'])]
    params['train_sets'] = [tu.get_blank_dataset_dict(dataset_name, False, os.path.basename(args['train_ann_path']), args['audio_path'])]
    print('\nTrain set:')
    data_train, params['class_names'], params['class_inv_freq'] = \
        tu.load_set_of_anns(train_sets, classes_to_ignore, params['events_of_interest'])
    print('Number of files', len(data_train))
    if len(data_train) == 0:
        raise ValueError('No training files loaded from ' + args['train_ann_path'])
    params['genus_names'], params['genus_mapping'] = tu.get_genus_mapping(params['class_names'])
    params['class_names_short'] = tu.get_short_class_names(params['class_names'])

    # load test annotations
    test_sets = [tu.get_blank_dataset_dict(dataset_name, True, args['test_ann_path'], args['audio_path'])]
    params['test_sets'] = [tu.get_blank_dataset_dict(dataset_name, True, os.path.basename(args['test_ann_path']), args['audio_path'])]
    print('\nTest set:')
    data_test, _, _ = tu.load_set_of_anns(test_sets, classes_to_ignore, params['events_of_interest'])
    print('Number of files', len(data_test))

    # Pinned memory only helps host->GPU copies; on CPU-only runs it is
    # unused and PyTorch just warns about it.
    pin_memory = params['device'] == 'cuda'

    # train loader
    train_dataset = adl.AudioLoader(data_train, params, is_train=True)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=params['batch_size'],
                                               shuffle=True, num_workers=params['num_workers'],
                                               pin_memory=pin_memory)

    # test loader - batch size of one because of variable file length
    test_dataset = adl.AudioLoader(data_test, params, is_train=False)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1,
                                              shuffle=False, num_workers=params['num_workers'],
                                              pin_memory=pin_memory)

    # Peek at one training batch to record the spectrogram height in params.
    inputs_train = next(iter(train_loader))
    params['ip_height'] = inputs_train['spec'].shape[2]
    print('\ntrain batch size :', inputs_train['spec'].shape)

    print('\n\nSOME hyperparams need to be the same as the loaded model (e.g. FFT) - currently they are getting overwritten.\n\n')

    # set the number of output classes
    _replace_class_output_layer(model, len(params['class_names']), params['device'])

    if args['finetune_only_last_layer']:
        print('\nOnly finetuning the final layers.\n')
        _freeze_all_but_final_layers(model)

    # Hand the optimizer only the parameters that will actually be updated.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=params['lr'])
    # LR period measured in batches over num_epochs epochs.
    scheduler = CosineAnnealingLR(optimizer, params['num_epochs'] * len(train_loader))
    det_criterion = _get_det_criterion(params['train_loss'])

    # plotting
    train_plt_ls = pu.LossPlotter(params['experiment'] + 'train_loss.png', params['num_epochs'] + 1,
                                  ['train_loss'], None, None, ['epoch', 'train_loss'], logy=True)
    test_plt_ls = pu.LossPlotter(params['experiment'] + 'test_loss.png', params['num_epochs'] + 1,
                                 ['test_loss'], None, None, ['epoch', 'test_loss'], logy=True)
    test_plt = pu.LossPlotter(params['experiment'] + 'test.png', params['num_epochs'] + 1,
                              ['avg_prec', 'rec_at_x', 'avg_prec_class', 'file_acc', 'top_class'],
                              [0, 1], None, ['epoch', ''])
    test_plt_class = pu.LossPlotter(params['experiment'] + 'test_avg_prec.png', params['num_epochs'] + 1,
                                    params['class_names_short'], [0, 1], params['class_names_short'],
                                    ['epoch', 'avg_prec'])

    # main train loop (epochs 0..num_epochs inclusive, matching the plotters' size)
    for epoch in range(0, params['num_epochs'] + 1):
        train_loss = tm.train(model, epoch, train_loader, det_criterion, optimizer, scheduler, params)
        train_plt_ls.update_and_save(epoch, [train_loss['train_loss']])

        if epoch % params['num_eval_epochs'] == 0:
            # detection accuracy on test set
            test_res, test_loss = tm.test(model, epoch, test_loader, det_criterion, params)
            test_plt_ls.update_and_save(epoch, [test_loss['test_loss']])
            test_plt.update_and_save(epoch, [test_res['avg_prec'], test_res['rec_at_x'],
                                             test_res['avg_prec_class'], test_res['file_acc'],
                                             test_res['top_class']['avg_prec']])
            test_plt_class.update_and_save(epoch, [rs['avg_prec'] for rs in test_res['class_pr']])
            pu.plot_pr_curve_class(params['experiment'], 'test_pr', 'test_pr', test_res)

    # save finetuned model
    print('saving model to: ' + params['model_file_name'])
    op_state = {'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'params': params}
    torch.save(op_state, params['model_file_name'])

    # save an image with associated prediction for each batch in the test set
    if not args['do_not_save_images']:
        tm.save_images_batch(model, test_loader, params)


if __name__ == "__main__":
    main()