"""Training script for the bat echolocation call detector/classifier.

Trains a Net2DFast-family model on spectrogram inputs with a combined
detection + bounding-box-size + classification loss, periodically evaluating
on a held-out test set and saving loss/PR plots and the model checkpoint.

NOTE(review): formatting reconstructed from a whitespace-mangled source;
behaviour is unchanged except for the bug fixes marked ``BUG FIX`` below.
"""

import numpy as np
import matplotlib.pyplot as plt
import os
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
import json
import argparse
import sys

sys.path.append(os.path.join('..', '..'))
import bat_detect.detector.parameters as parameters
import bat_detect.detector.models as models
import bat_detect.detector.post_process as pp
import bat_detect.utils.plot_utils as pu
import bat_detect.train.audio_dataloader as adl
import bat_detect.train.evaluate as evl
import bat_detect.train.train_utils as tu
import bat_detect.train.train_split as ts
import bat_detect.train.losses as losses

import warnings
warnings.filterwarnings("ignore", category=UserWarning)


def save_images_batch(model, data_loader, params):
    """Save one prediction visualization per batch of ``data_loader``.

    Temporarily puts the dataset into eval mode (no augmentation) and asks it
    to also return the raw spectrogram for visualization; the dataset's
    original state is restored before returning.
    """
    print('\nsaving images ...')

    # remember dataset state so it can be restored afterwards
    is_train_state = data_loader.dataset.is_train
    data_loader.dataset.is_train = False
    data_loader.dataset.return_spec_for_viz = True
    model.eval()

    ind = 0  # first image in each batch
    with torch.no_grad():
        for batch_idx, inputs in enumerate(data_loader):
            data = inputs['spec'].to(params['device'])
            outputs = model(data)

            spec_viz = inputs['spec_for_viz'].data.cpu().numpy()
            orig_index = inputs['file_id'][ind]
            plot_title = data_loader.dataset.data_anns[orig_index]['id']
            op_file_name = params['op_im_dir_test'] + \
                data_loader.dataset.data_anns[orig_index]['id'] + '.jpg'
            save_image(spec_viz, outputs, ind, inputs, params,
                       op_file_name, plot_title)

    # restore the state the caller had set up
    data_loader.dataset.is_train = is_train_state
    data_loader.dataset.return_spec_for_viz = False


def save_image(spec_viz, outputs, ind, inputs, params, op_file_name, plot_title):
    """Plot and save a single spectrogram with GT boxes and NMS'd predictions.

    ``ind`` selects which item of the batch to plot.
    NOTE(review): ``test()`` passes the raw (possibly GPU) input tensor as
    ``spec_viz`` while ``save_images_batch`` passes a numpy array — confirm
    ``pu.plot_spec`` accepts both.
    """
    pred_nms, _ = pp.run_nms(outputs, params, inputs['sampling_rate'].float())
    pred_hm = outputs['pred_det'][ind, 0, :].data.cpu().numpy()
    spec_viz = spec_viz[ind, 0, :]
    gt = parse_gt_data(inputs)[ind]
    sampling_rate = inputs['sampling_rate'][ind].item()
    duration = inputs['duration'][ind].item()

    pu.plot_spec(spec_viz, sampling_rate, duration, gt, pred_nms[ind],
                 params, plot_title, op_file_name, pred_hm,
                 plot_boxes=True, fixed_aspect=False)


def loss_fun(outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq):
    """Combined training loss: detection + box size + per-class detection.

    ``class_inv_freq`` is currently unused; it is kept so existing call sites
    do not break.
    """
    # detection loss
    loss = params['det_loss_weight'] * det_criterion(outputs['pred_det'], gt_det)

    # bounding box size loss
    loss += params['size_loss_weight'] * \
        losses.bbox_size_loss(outputs['pred_size'], gt_size)

    # classification loss - only evaluated where at least one non-generic
    # class is present (last channel is the generic/background class)
    valid_mask = (gt_class[:, :-1, :, :].sum(1) > 0).float().unsqueeze(1)
    p_class = outputs['pred_class'][:, :-1, :]
    loss += params['class_loss_weight'] * \
        det_criterion(p_class, gt_class[:, :-1, :], valid_mask=valid_mask)

    return loss


def train(model, epoch, data_loader, det_criterion, optimizer, scheduler, params):
    """Run one training epoch; returns ``{'train_loss': avg_loss}``."""
    model.train()

    train_loss = tu.AverageMeter()
    class_inv_freq = torch.from_numpy(
        np.array(params['class_inv_freq'], dtype=np.float32)).to(params['device'])
    class_inv_freq = class_inv_freq.unsqueeze(0).unsqueeze(2).unsqueeze(2)

    print('\nEpoch', epoch)
    for batch_idx, inputs in enumerate(data_loader):
        data = inputs['spec'].to(params['device'])
        gt_det = inputs['y_2d_det'].to(params['device'])
        gt_size = inputs['y_2d_size'].to(params['device'])
        gt_class = inputs['y_2d_classes'].to(params['device'])

        optimizer.zero_grad()
        outputs = model(data)
        loss = loss_fun(outputs, gt_det, gt_size, gt_class,
                        det_criterion, params, class_inv_freq)

        train_loss.update(loss.item(), data.shape[0])
        loss.backward()
        optimizer.step()
        scheduler.step()  # scheduler is stepped per batch (see T_max in main)

        if batch_idx % 50 == 0 and batch_idx != 0:
            print('[{}/{}]\tLoss: {:.4f}'.format(
                batch_idx * len(data), len(data_loader.dataset), train_loss.avg))

    print('Train loss          : {:.4f}'.format(train_loss.avg))

    res = {}
    res['train_loss'] = float(train_loss.avg)
    return res


def test(model, epoch, data_loader, det_criterion, params):
    """Evaluate on the test set; returns ``(res_det, {'test_loss': ...})``."""
    model.eval()
    predictions = []
    ground_truths = []

    test_loss = tu.AverageMeter()
    class_inv_freq = torch.from_numpy(
        np.array(params['class_inv_freq'], dtype=np.float32)).to(params['device'])
    class_inv_freq = class_inv_freq.unsqueeze(0).unsqueeze(2).unsqueeze(2)

    with torch.no_grad():
        for batch_idx, inputs in enumerate(data_loader):
            data = inputs['spec'].to(params['device'])
            gt_det = inputs['y_2d_det'].to(params['device'])
            gt_size = inputs['y_2d_size'].to(params['device'])
            gt_class = inputs['y_2d_classes'].to(params['device'])
            outputs = model(data)

            # if the model needs a fixed sized intput run this
            # data = torch.cat(torch.split(data, int(params['spec_train_width']*params['resize_factor']), 3), 0)
            # outputs = model(data)
            # for kk in ['pred_det', 'pred_size', 'pred_class']:
            #     outputs[kk] = torch.cat([oo for oo in outputs[kk]], 2).unsqueeze(0)

            if params['save_test_image_during_train'] and batch_idx == 0:
                # for visualization - save the first prediction
                ind = 0
                orig_index = inputs['file_id'][ind]
                plot_title = data_loader.dataset.data_anns[orig_index]['id']
                op_file_name = params['op_im_dir'] + str(orig_index.item()).zfill(4) + \
                    '_' + str(epoch).zfill(4) + '_pred.jpg'
                save_image(data, outputs, ind, inputs, params,
                           op_file_name, plot_title)

            loss = loss_fun(outputs, gt_det, gt_size, gt_class,
                            det_criterion, params, class_inv_freq)
            test_loss.update(loss.item(), data.shape[0])

            # do NMS
            pred_nms, _ = pp.run_nms(outputs, params, inputs['sampling_rate'].float())
            predictions.extend(pred_nms)
            ground_truths.extend(parse_gt_data(inputs))

    res_det = evl.evaluate_predictions(ground_truths, predictions,
                                       params['class_names'],
                                       params['detection_overlap'],
                                       params['ignore_start_end'])

    print('\nTest loss           : {:.4f}'.format(test_loss.avg))
    print('Rec at 0.95 (det)   : {:.4f}'.format(res_det['rec_at_x']))
    print('Avg prec (cls)      : {:.4f}'.format(res_det['avg_prec']))
    print('File acc (cls)      : {:.2f} - for {} out of {}'.format(
        res_det['file_acc'], res_det['num_valid_files'], res_det['num_total_files']))
    print('Cls Avg prec (cls)  : {:.4f}'.format(res_det['avg_prec_class']))

    print('\nPer class average precision')
    str_len = np.max([len(rs['name']) for rs in res_det['class_pr']]) + 5
    for cc, rs in enumerate(res_det['class_pr']):
        if rs['num_gt'] > 0:
            print(str(cc).ljust(5) + rs['name'].ljust(str_len) +
                  '{:.4f}'.format(rs['avg_prec']))

    res = {}
    res['test_loss'] = float(test_loss.avg)
    return res_det, res


def parse_gt_data(inputs):
    # reads the torch arrays into a dictionary of numpy arrays, taking care to
    # remove padding data i.e. not valid ones
    keys = ['start_times', 'end_times', 'low_freqs', 'high_freqs',
            'class_ids', 'individual_ids']
    batch_data = []
    for ind in range(inputs['start_times'].shape[0]):
        is_valid = inputs['is_valid'][ind] == 1
        gt = {}
        for kk in keys:
            gt[kk] = inputs[kk][ind][is_valid].numpy().astype(np.float32)
        gt['duration'] = inputs['duration'][ind].item()
        gt['file_id'] = inputs['file_id'][ind].item()
        gt['class_id_file'] = inputs['class_id_file'][ind].item()
        batch_data.append(gt)
    return batch_data


def select_model(params):
    """Instantiate the network named by ``params['model_name']``.

    Raises:
        ValueError: if the model name is not recognised.
        (BUG FIX: previously printed a message and returned an unbound local,
        which raised ``UnboundLocalError`` instead.)
    """
    num_classes = len(params['class_names'])
    if params['model_name'] == 'Net2DFast':
        model = models.Net2DFast(params['num_filters'], num_classes=num_classes,
                                 emb_dim=params['emb_dim'],
                                 ip_height=params['ip_height'],
                                 resize_factor=params['resize_factor'])
    elif params['model_name'] == 'Net2DFastNoAttn':
        model = models.Net2DFastNoAttn(params['num_filters'], num_classes=num_classes,
                                       emb_dim=params['emb_dim'],
                                       ip_height=params['ip_height'],
                                       resize_factor=params['resize_factor'])
    elif params['model_name'] == 'Net2DFastNoCoordConv':
        model = models.Net2DFastNoCoordConv(params['num_filters'], num_classes=num_classes,
                                            emb_dim=params['emb_dim'],
                                            ip_height=params['ip_height'],
                                            resize_factor=params['resize_factor'])
    else:
        raise ValueError('No valid network specified: ' + str(params['model_name']))
    return model


if __name__ == "__main__":

    plt.close('all')

    params = parameters.get_params(True)
    if torch.cuda.is_available():
        params['device'] = 'cuda'
    else:
        params['device'] = 'cpu'

    # setup arg parser and populate it with existing parameters - will not work with lists
    parser = argparse.ArgumentParser()
    parser.add_argument('data_dir', type=str,
                        help='Path to root of datasets')
    parser.add_argument('ann_dir', type=str,
                        help='Path to extracted annotations')
    parser.add_argument('--train_split', type=str, default='diff',  # diff, same
                        help='Which train split to use')
    parser.add_argument('--notes', type=str, default='',
                        help='Notes to save in text file')
    # NOTE: store_false means the parsed value is True by default and becomes
    # False when the flag is passed
    parser.add_argument('--do_not_save_images', action='store_false',
                        help='Do not save images at the end of training')
    parser.add_argument('--standardize_classs_names_ip', type=str,
                        default='Rhinolophus ferrumequinum;Rhinolophus hipposideros',
                        help='Will set low and high frequency the same for these classes. Separate names with ";"')
    for key, val in params.items():
        parser.add_argument('--' + key, type=type(val), default=val)
    params = vars(parser.parse_args())

    # save notes file
    if params['notes'] != '':
        tu.write_notes_file(params['experiment'] + 'notes.txt', params['notes'])

    # load the training and test meta data - there are different splits defined
    train_sets, test_sets = ts.get_train_test_data(
        params['ann_dir'], params['data_dir'], params['train_split'])
    train_sets_no_path, test_sets_no_path = ts.get_train_test_data(
        '', '', params['train_split'])

    # keep track of what we have trained on
    params['train_sets'] = train_sets_no_path
    params['test_sets'] = test_sets_no_path

    # load train annotations - merge them all together
    print('\nTraining on:')
    for tt in train_sets:
        print(tt['ann_path'])
    classes_to_ignore = params['classes_to_ignore'] + params['generic_class']
    data_train, params['class_names'], params['class_inv_freq'] = \
        tu.load_set_of_anns(train_sets, classes_to_ignore,
                            params['events_of_interest'], params['convert_to_genus'])
    params['genus_names'], params['genus_mapping'] = \
        tu.get_genus_mapping(params['class_names'])
    params['class_names_short'] = tu.get_short_class_names(params['class_names'])

    # standardize the low and high frequency value for specified classes
    params['standardize_classs_names'] = params['standardize_classs_names_ip'].split(';')
    for cc in params['standardize_classs_names']:
        if cc in params['class_names']:
            data_train = tu.standardize_low_freq(data_train, cc)
        else:
            print(cc, 'not found')

    # train loader
    train_dataset = adl.AudioLoader(data_train, params, is_train=True)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=params['batch_size'],
                                               shuffle=True,
                                               num_workers=params['num_workers'],
                                               pin_memory=True)

    # test set
    print('\nTesting on:')
    for tt in test_sets:
        print(tt['ann_path'])
    data_test, _, _ = tu.load_set_of_anns(test_sets, classes_to_ignore,
                                          params['events_of_interest'],
                                          params['convert_to_genus'])
    data_train = tu.remove_dupes(data_train, data_test)
    test_dataset = adl.AudioLoader(data_test, params, is_train=False)
    # batch size of 1 because of variable file length
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1,
                                              shuffle=False,
                                              num_workers=params['num_workers'],
                                              pin_memory=True)

    inputs_train = next(iter(train_loader))
    # TODO remove params['ip_height'], this is just legacy
    params['ip_height'] = int(params['spec_height'] * params['resize_factor'])
    print('\ntrain batch spec size :', inputs_train['spec'].shape)
    print('class target size     :', inputs_train['y_2d_classes'].shape)

    # select network
    model = select_model(params)
    model = model.to(params['device'])

    optimizer = torch.optim.Adam(model.parameters(), lr=params['lr'])
    # optimizer = torch.optim.SGD(model.parameters(), lr=params['lr'], momentum=0.9)
    # scheduler.step() is called per batch, so T_max covers every batch of
    # every epoch
    scheduler = CosineAnnealingLR(optimizer, params['num_epochs'] * len(train_loader))

    if params['train_loss'] == 'mse':
        det_criterion = losses.mse_loss
    elif params['train_loss'] == 'focal':
        det_criterion = losses.focal_loss
    else:
        # BUG FIX: previously an unrecognised loss name left det_criterion
        # unbound and crashed later with a confusing NameError
        raise ValueError('Unknown train_loss: ' + str(params['train_loss']))

    # save parameters to file
    with open(params['experiment'] + 'params.json', 'w') as da:
        json.dump(params, da, indent=2, sort_keys=True)

    # plotting
    train_plt_ls = pu.LossPlotter(params['experiment'] + 'train_loss.png',
                                  params['num_epochs'] + 1,
                                  ['train_loss'], None, None,
                                  ['epoch', 'train_loss'], logy=True)
    test_plt_ls = pu.LossPlotter(params['experiment'] + 'test_loss.png',
                                 params['num_epochs'] + 1,
                                 ['test_loss'], None, None,
                                 ['epoch', 'test_loss'], logy=True)
    test_plt = pu.LossPlotter(params['experiment'] + 'test.png',
                              params['num_epochs'] + 1,
                              ['avg_prec', 'rec_at_x', 'avg_prec_class',
                               'file_acc', 'top_class'],
                              [0, 1], None, ['epoch', ''])
    test_plt_class = pu.LossPlotter(params['experiment'] + 'test_avg_prec.png',
                                    params['num_epochs'] + 1,
                                    params['class_names_short'], [0, 1],
                                    params['class_names_short'],
                                    ['epoch', 'avg_prec'])

    # main train loop
    for epoch in range(0, params['num_epochs'] + 1):

        train_loss = train(model, epoch, train_loader, det_criterion,
                           optimizer, scheduler, params)
        train_plt_ls.update_and_save(epoch, [train_loss['train_loss']])

        if epoch % params['num_eval_epochs'] == 0:
            # detection accuracy on test set
            test_res, test_loss = test(model, epoch, test_loader,
                                       det_criterion, params)
            test_plt_ls.update_and_save(epoch, [test_loss['test_loss']])
            test_plt.update_and_save(epoch, [test_res['avg_prec'],
                                             test_res['rec_at_x'],
                                             test_res['avg_prec_class'],
                                             test_res['file_acc'],
                                             test_res['top_class']['avg_prec']])
            test_plt_class.update_and_save(epoch,
                                           [rs['avg_prec'] for rs in test_res['class_pr']])
            pu.plot_pr_curve_class(params['experiment'], 'test_pr', 'test_pr', test_res)

    # save trained model
    print('saving model to: ' + params['model_file_name'])
    op_state = {'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                # 'optimizer' : optimizer.state_dict(),
                'params': params}
    torch.save(op_state, params['model_file_name'])

    # save an image with associated prediction for each batch in the test set
    # BUG FIX: the original read ``args['do_not_save_images']`` - ``args`` was
    # never defined (NameError) and the ``not`` inverted the flag so images
    # were saved only when --do_not_save_images was passed. With store_false
    # the value is True unless the flag is given, so save exactly when True.
    if params['do_not_save_images']:
        save_images_batch(model, test_loader, params)