Spaces:
Running
Running
import numpy as np | |
import matplotlib.pyplot as plt | |
import os | |
import torch | |
import torch.nn.functional as F | |
from torch.optim.lr_scheduler import CosineAnnealingLR | |
import json | |
import argparse | |
import sys | |
sys.path.append(os.path.join('..', '..')) | |
import bat_detect.detector.parameters as parameters | |
import bat_detect.detector.models as models | |
import bat_detect.detector.post_process as pp | |
import bat_detect.utils.plot_utils as pu | |
import bat_detect.train.audio_dataloader as adl | |
import bat_detect.train.evaluate as evl | |
import bat_detect.train.train_utils as tu | |
import bat_detect.train.train_split as ts | |
import bat_detect.train.losses as losses | |
import warnings | |
warnings.filterwarnings("ignore", category=UserWarning) | |
def save_images_batch(model, data_loader, params): | |
print('\nsaving images ...') | |
is_train_state = data_loader.dataset.is_train | |
data_loader.dataset.is_train = False | |
data_loader.dataset.return_spec_for_viz = True | |
model.eval() | |
ind = 0 # first image in each batch | |
with torch.no_grad(): | |
for batch_idx, inputs in enumerate(data_loader): | |
data = inputs['spec'].to(params['device']) | |
outputs = model(data) | |
spec_viz = inputs['spec_for_viz'].data.cpu().numpy() | |
orig_index = inputs['file_id'][ind] | |
plot_title = data_loader.dataset.data_anns[orig_index]['id'] | |
op_file_name = params['op_im_dir_test'] + data_loader.dataset.data_anns[orig_index]['id'] + '.jpg' | |
save_image(spec_viz, outputs, ind, inputs, params, op_file_name, plot_title) | |
data_loader.dataset.is_train = is_train_state | |
data_loader.dataset.return_spec_for_viz = False | |
def save_image(spec_viz, outputs, ind, inputs, params, op_file_name, plot_title): | |
pred_nms, _ = pp.run_nms(outputs, params, inputs['sampling_rate'].float()) | |
pred_hm = outputs['pred_det'][ind, 0, :].data.cpu().numpy() | |
spec_viz = spec_viz[ind, 0, :] | |
gt = parse_gt_data(inputs)[ind] | |
sampling_rate = inputs['sampling_rate'][ind].item() | |
duration = inputs['duration'][ind].item() | |
pu.plot_spec(spec_viz, sampling_rate, duration, gt, pred_nms[ind], | |
params, plot_title, op_file_name, pred_hm, plot_boxes=True, fixed_aspect=False) | |
def loss_fun(outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq): | |
# detection loss | |
loss = params['det_loss_weight']*det_criterion(outputs['pred_det'], gt_det) | |
# bounding box size loss | |
loss += params['size_loss_weight']*losses.bbox_size_loss(outputs['pred_size'], gt_size) | |
# classification loss | |
valid_mask = (gt_class[:, :-1, :, :].sum(1) > 0).float().unsqueeze(1) | |
p_class = outputs['pred_class'][:, :-1, :] | |
loss += params['class_loss_weight']*det_criterion(p_class, gt_class[:, :-1, :], valid_mask=valid_mask) | |
return loss | |
def train(model, epoch, data_loader, det_criterion, optimizer, scheduler, params): | |
model.train() | |
train_loss = tu.AverageMeter() | |
class_inv_freq = torch.from_numpy(np.array(params['class_inv_freq'], dtype=np.float32)).to(params['device']) | |
class_inv_freq = class_inv_freq.unsqueeze(0).unsqueeze(2).unsqueeze(2) | |
print('\nEpoch', epoch) | |
for batch_idx, inputs in enumerate(data_loader): | |
data = inputs['spec'].to(params['device']) | |
gt_det = inputs['y_2d_det'].to(params['device']) | |
gt_size = inputs['y_2d_size'].to(params['device']) | |
gt_class = inputs['y_2d_classes'].to(params['device']) | |
optimizer.zero_grad() | |
outputs = model(data) | |
loss = loss_fun(outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq) | |
train_loss.update(loss.item(), data.shape[0]) | |
loss.backward() | |
optimizer.step() | |
scheduler.step() | |
if batch_idx % 50 == 0 and batch_idx != 0: | |
print('[{}/{}]\tLoss: {:.4f}'.format( | |
batch_idx * len(data), len(data_loader.dataset), train_loss.avg)) | |
print('Train loss : {:.4f}'.format(train_loss.avg)) | |
res = {} | |
res['train_loss'] = float(train_loss.avg) | |
return res | |
def test(model, epoch, data_loader, det_criterion, params): | |
model.eval() | |
predictions = [] | |
ground_truths = [] | |
test_loss = tu.AverageMeter() | |
class_inv_freq = torch.from_numpy(np.array(params['class_inv_freq'], dtype=np.float32)).to(params['device']) | |
class_inv_freq = class_inv_freq.unsqueeze(0).unsqueeze(2).unsqueeze(2) | |
with torch.no_grad(): | |
for batch_idx, inputs in enumerate(data_loader): | |
data = inputs['spec'].to(params['device']) | |
gt_det = inputs['y_2d_det'].to(params['device']) | |
gt_size = inputs['y_2d_size'].to(params['device']) | |
gt_class = inputs['y_2d_classes'].to(params['device']) | |
outputs = model(data) | |
# if the model needs a fixed sized intput run this | |
# data = torch.cat(torch.split(data, int(params['spec_train_width']*params['resize_factor']), 3), 0) | |
# outputs = model(data) | |
# for kk in ['pred_det', 'pred_size', 'pred_class']: | |
# outputs[kk] = torch.cat([oo for oo in outputs[kk]], 2).unsqueeze(0) | |
if params['save_test_image_during_train'] and batch_idx == 0: | |
# for visualization - save the first prediction | |
ind = 0 | |
orig_index = inputs['file_id'][ind] | |
plot_title = data_loader.dataset.data_anns[orig_index]['id'] | |
op_file_name = params['op_im_dir'] + str(orig_index.item()).zfill(4) + '_' + str(epoch).zfill(4) + '_pred.jpg' | |
save_image(data, outputs, ind, inputs, params, op_file_name, plot_title) | |
loss = loss_fun(outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq) | |
test_loss.update(loss.item(), data.shape[0]) | |
# do NMS | |
pred_nms, _ = pp.run_nms(outputs, params, inputs['sampling_rate'].float()) | |
predictions.extend(pred_nms) | |
ground_truths.extend(parse_gt_data(inputs)) | |
res_det = evl.evaluate_predictions(ground_truths, predictions, params['class_names'], | |
params['detection_overlap'], params['ignore_start_end']) | |
print('\nTest loss : {:.4f}'.format(test_loss.avg)) | |
print('Rec at 0.95 (det) : {:.4f}'.format(res_det['rec_at_x'])) | |
print('Avg prec (cls) : {:.4f}'.format(res_det['avg_prec'])) | |
print('File acc (cls) : {:.2f} - for {} out of {}'.format(res_det['file_acc'], | |
res_det['num_valid_files'], res_det['num_total_files'])) | |
print('Cls Avg prec (cls) : {:.4f}'.format(res_det['avg_prec_class'])) | |
print('\nPer class average precision') | |
str_len = np.max([len(rs['name']) for rs in res_det['class_pr']]) + 5 | |
for cc, rs in enumerate(res_det['class_pr']): | |
if rs['num_gt'] > 0: | |
print(str(cc).ljust(5) + rs['name'].ljust(str_len) + '{:.4f}'.format(rs['avg_prec'])) | |
res = {} | |
res['test_loss'] = float(test_loss.avg) | |
return res_det, res | |
def parse_gt_data(inputs): | |
# reads the torch arrays into a dictionary of numpy arrays, taking care to | |
# remove padding data i.e. not valid ones | |
keys = ['start_times', 'end_times', 'low_freqs', 'high_freqs', 'class_ids', 'individual_ids'] | |
batch_data = [] | |
for ind in range(inputs['start_times'].shape[0]): | |
is_valid = inputs['is_valid'][ind]==1 | |
gt = {} | |
for kk in keys: | |
gt[kk] = inputs[kk][ind][is_valid].numpy().astype(np.float32) | |
gt['duration'] = inputs['duration'][ind].item() | |
gt['file_id'] = inputs['file_id'][ind].item() | |
gt['class_id_file'] = inputs['class_id_file'][ind].item() | |
batch_data.append(gt) | |
return batch_data | |
def select_model(params): | |
num_classes = len(params['class_names']) | |
if params['model_name'] == 'Net2DFast': | |
model = models.Net2DFast(params['num_filters'], num_classes=num_classes, | |
emb_dim=params['emb_dim'], ip_height=params['ip_height'], | |
resize_factor=params['resize_factor']) | |
elif params['model_name'] == 'Net2DFastNoAttn': | |
model = models.Net2DFastNoAttn(params['num_filters'], num_classes=num_classes, | |
emb_dim=params['emb_dim'], ip_height=params['ip_height'], | |
resize_factor=params['resize_factor']) | |
elif params['model_name'] == 'Net2DFastNoCoordConv': | |
model = models.Net2DFastNoCoordConv(params['num_filters'], num_classes=num_classes, | |
emb_dim=params['emb_dim'], ip_height=params['ip_height'], | |
resize_factor=params['resize_factor']) | |
else: | |
print('No valid network specified') | |
return model | |
if __name__ == "__main__": | |
plt.close('all') | |
params = parameters.get_params(True) | |
if torch.cuda.is_available(): | |
params['device'] = 'cuda' | |
else: | |
params['device'] = 'cpu' | |
# setup arg parser and populate it with exiting parameters - will not work with lists | |
parser = argparse.ArgumentParser() | |
parser.add_argument('data_dir', type=str, | |
help='Path to root of datasets') | |
parser.add_argument('ann_dir', type=str, | |
help='Path to extracted annotations') | |
parser.add_argument('--train_split', type=str, default='diff', # diff, same | |
help='Which train split to use') | |
parser.add_argument('--notes', type=str, default='', | |
help='Notes to save in text file') | |
parser.add_argument('--do_not_save_images', action='store_false', | |
help='Do not save images at the end of training') | |
parser.add_argument('--standardize_classs_names_ip', type=str, | |
default='Rhinolophus ferrumequinum;Rhinolophus hipposideros', | |
help='Will set low and high frequency the same for these classes. Separate names with ";"') | |
for key, val in params.items(): | |
parser.add_argument('--'+key, type=type(val), default=val) | |
params = vars(parser.parse_args()) | |
# save notes file | |
if params['notes'] != '': | |
tu.write_notes_file(params['experiment'] + 'notes.txt', params['notes']) | |
# load the training and test meta data - there are different splits defined | |
train_sets, test_sets = ts.get_train_test_data(params['ann_dir'], params['data_dir'], params['train_split']) | |
train_sets_no_path, test_sets_no_path = ts.get_train_test_data('', '', params['train_split']) | |
# keep track of what we have trained on | |
params['train_sets'] = train_sets_no_path | |
params['test_sets'] = test_sets_no_path | |
# load train annotations - merge them all together | |
print('\nTraining on:') | |
for tt in train_sets: | |
print(tt['ann_path']) | |
classes_to_ignore = params['classes_to_ignore']+params['generic_class'] | |
data_train, params['class_names'], params['class_inv_freq'] = \ | |
tu.load_set_of_anns(train_sets, classes_to_ignore, params['events_of_interest'], params['convert_to_genus']) | |
params['genus_names'], params['genus_mapping'] = tu.get_genus_mapping(params['class_names']) | |
params['class_names_short'] = tu.get_short_class_names(params['class_names']) | |
# standardize the low and high frequency value for specified classes | |
params['standardize_classs_names'] = params['standardize_classs_names_ip'].split(';') | |
for cc in params['standardize_classs_names']: | |
if cc in params['class_names']: | |
data_train = tu.standardize_low_freq(data_train, cc) | |
else: | |
print(cc, 'not found') | |
# train loader | |
train_dataset = adl.AudioLoader(data_train, params, is_train=True) | |
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=params['batch_size'], | |
shuffle=True, num_workers=params['num_workers'], pin_memory=True) | |
# test set | |
print('\nTesting on:') | |
for tt in test_sets: | |
print(tt['ann_path']) | |
data_test, _, _ = tu.load_set_of_anns(test_sets, classes_to_ignore, params['events_of_interest'], params['convert_to_genus']) | |
data_train = tu.remove_dupes(data_train, data_test) | |
test_dataset = adl.AudioLoader(data_test, params, is_train=False) | |
# batch size of 1 because of variable file length | |
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, | |
shuffle=False, num_workers=params['num_workers'], pin_memory=True) | |
inputs_train = next(iter(train_loader)) | |
# TODO remove params['ip_height'], this is just legacy | |
params['ip_height'] = int(params['spec_height']*params['resize_factor']) | |
print('\ntrain batch spec size :', inputs_train['spec'].shape) | |
print('class target size :', inputs_train['y_2d_classes'].shape) | |
# select network | |
model = select_model(params) | |
model = model.to(params['device']) | |
optimizer = torch.optim.Adam(model.parameters(), lr=params['lr']) | |
#optimizer = torch.optim.SGD(model.parameters(), lr=params['lr'], momentum=0.9) | |
scheduler = CosineAnnealingLR(optimizer, params['num_epochs'] * len(train_loader)) | |
if params['train_loss'] == 'mse': | |
det_criterion = losses.mse_loss | |
elif params['train_loss'] == 'focal': | |
det_criterion = losses.focal_loss | |
# save parameters to file | |
with open(params['experiment'] + 'params.json', 'w') as da: | |
json.dump(params, da, indent=2, sort_keys=True) | |
# plotting | |
train_plt_ls = pu.LossPlotter(params['experiment'] + 'train_loss.png', params['num_epochs']+1, | |
['train_loss'], None, None, ['epoch', 'train_loss'], logy=True) | |
test_plt_ls = pu.LossPlotter(params['experiment'] + 'test_loss.png', params['num_epochs']+1, | |
['test_loss'], None, None, ['epoch', 'test_loss'], logy=True) | |
test_plt = pu.LossPlotter(params['experiment'] + 'test.png', params['num_epochs']+1, | |
['avg_prec', 'rec_at_x', 'avg_prec_class', 'file_acc', 'top_class'], [0,1], None, ['epoch', '']) | |
test_plt_class = pu.LossPlotter(params['experiment'] + 'test_avg_prec.png', params['num_epochs']+1, | |
params['class_names_short'], [0,1], params['class_names_short'], ['epoch', 'avg_prec']) | |
# | |
# main train loop | |
for epoch in range(0, params['num_epochs']+1): | |
train_loss = train(model, epoch, train_loader, det_criterion, optimizer, scheduler, params) | |
train_plt_ls.update_and_save(epoch, [train_loss['train_loss']]) | |
if epoch % params['num_eval_epochs'] == 0: | |
# detection accuracy on test set | |
test_res, test_loss = test(model, epoch, test_loader, det_criterion, params) | |
test_plt_ls.update_and_save(epoch, [test_loss['test_loss']]) | |
test_plt.update_and_save(epoch, [test_res['avg_prec'], test_res['rec_at_x'], | |
test_res['avg_prec_class'], test_res['file_acc'], test_res['top_class']['avg_prec']]) | |
test_plt_class.update_and_save(epoch, [rs['avg_prec'] for rs in test_res['class_pr']]) | |
pu.plot_pr_curve_class(params['experiment'] , 'test_pr', 'test_pr', test_res) | |
# save trained model | |
print('saving model to: ' + params['model_file_name']) | |
op_state = {'epoch': epoch + 1, | |
'state_dict': model.state_dict(), | |
#'optimizer' : optimizer.state_dict(), | |
'params' : params} | |
torch.save(op_state, params['model_file_name']) | |
# save an image with associated prediction for each batch in the test set | |
if not args['do_not_save_images']: | |
save_images_batch(model, test_loader, params) | |