Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import numpy as np | |
# Import-time side effect: ignore numpy divide-by-zero and invalid-value
# floating-point errors (e.g. 0/0 producing NaN). This changes numpy's global
# error state, not just the behaviour of this module.
np.seterr(invalid='ignore', divide='ignore')
def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap):
    """Map spectrogram column coordinate(s) to time.

    The FFT window is ``int(fft_win_length * sampling_rate)`` samples long and
    consecutive windows share ``int(fft_overlap * window)`` samples, so the hop
    between columns is ``window - overlap`` samples. Column ``x_pos`` is placed
    at ``(x_pos * hop + overlap) / sampling_rate``.

    Args:
        x_pos: Column coordinate; a scalar, numpy array or torch tensor
            (fractional values are fine).
        sampling_rate: Samples per unit time (presumably Hz).
        fft_win_length: Window length in the time unit implied by
            ``sampling_rate`` (presumably seconds).
        fft_overlap: Fraction of the window shared by consecutive windows.

    Returns:
        Time value(s): a float for scalar input, otherwise an array/tensor
        shaped like ``x_pos``.
    """
    win_samples = int(fft_win_length * sampling_rate)
    overlap_samples = int(fft_overlap * win_samples)
    hop_samples = win_samples - overlap_samples
    return (x_pos * hop_samples + overlap_samples) / sampling_rate
def overall_class_pred(det_prob, class_prob):
    """Combine per-detection class scores into one normalised class distribution.

    Each detection's class scores are weighted by its detection probability,
    summed over axis 1, and the resulting per-class totals are divided by
    their grand total.

    Args:
        det_prob: Detection probabilities; must broadcast against
            ``class_prob`` (presumably shape ``(num_dets,)`` with ``class_prob``
            shaped ``(num_classes, num_dets)`` -- confirm against callers).
        class_prob: Per-class scores for each detection.

    Returns:
        Per-class weights summing to 1. If every weight is zero (for example
        when there are no detections) the division is 0/0 and yields NaN.
    """
    per_class = (det_prob * class_prob).sum(1)
    total = per_class.sum()
    return per_class / total
def run_nms(outputs, params, sampling_rate):
    """Turn dense detector outputs into per-item detection dictionaries.

    Steps: suppress non-maximal detection scores, take the ``top_k`` highest
    remaining scores per batch item, keep those above
    ``params['detection_threshold']`` (ordered by x position), and convert
    their positions/sizes into start/end times and low/high frequencies.

    Args:
        outputs: Dict of tensors. Reads ``'pred_det'`` and ``'pred_size'`` and,
            when present, ``'pred_class'`` and ``'features'``. Indexing below
            presumably expects ``pred_det`` as ``(batch, 1, H, W)`` and
            ``pred_size`` as ``(batch, 2, H, W)`` with channel 0 used as box
            width and channel 1 as box height -- confirm.
        params: Dict; reads ``'nms_kernel_size'``, ``'max_freq'``,
            ``'min_freq'``, ``'fft_win_length'``, ``'fft_overlap'``,
            ``'nms_top_k_per_sec'``, ``'detection_threshold'`` and
            ``'resize_factor'``.
        sampling_rate: Tensor with one sampling rate per batch item.

    Returns:
        ``(preds, feats)``. ``preds`` has one dict per batch item whose values
        are float32 numpy arrays (keys ``'det_probs'``, ``'x_pos'``,
        ``'y_pos'``, ``'bb_width'``, ``'bb_height'``, ``'start_times'``,
        ``'end_times'``, ``'low_freqs'``, ``'high_freqs'`` and, if
        ``'pred_class'`` is in ``outputs``, ``'class_probs'``). ``feats`` has
        one float32 array per item (presumably ``(num_dets, num_features)``)
        when ``'features'`` is in ``outputs``; otherwise it is empty.
    """
    pred_det = outputs['pred_det']  # probability of box
    pred_size = outputs['pred_size']  # box size
    pred_det_nms = non_max_suppression(pred_det, params['nms_kernel_size'])
    freq_rescale = (params['max_freq'] - params['min_freq']) / pred_det.shape[-2]

    # NOTE there will be small differences depending on which sampling rate is
    # chosen, as the first item's sampling rate is used for the entire batch.
    # NOTE(review): duration is computed from the raw output width, whereas the
    # per-detection times below divide x by params['resize_factor']; the two
    # only agree when resize_factor == 1. Kept as-is because changing it alters
    # top_k -- confirm intent.
    duration = x_coords_to_time(pred_det.shape[-1], sampling_rate[0].item(),
                                params['fft_win_length'], params['fft_overlap'])
    top_k = int(duration * params['nms_top_k_per_sec'])
    # torch.topk raises if k exceeds the number of candidates per item.
    num_candidates = pred_det_nms.shape[1] * pred_det_nms.shape[2] * pred_det_nms.shape[3]
    top_k = min(top_k, num_candidates)
    scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k)

    # loop over batch to save outputs
    preds = []
    feats = []
    for ii in range(pred_det_nms.shape[0]):
        # Order this item's candidates by x position, then keep only those
        # whose score exceeds the detection threshold.
        inds_ord = torch.argsort(x_pos[ii, :])
        valid_inds = inds_ord[scores[ii, inds_ord] > params['detection_threshold']]
        y_valid = y_pos[ii, valid_inds]
        x_valid = x_pos[ii, valid_inds]

        pred = {}
        pred['det_probs'] = scores[ii, valid_inds]
        pred['x_pos'] = x_valid
        pred['y_pos'] = y_valid
        pred['bb_width'] = pred_size[ii, 0, y_valid, x_valid]
        pred['bb_height'] = pred_size[ii, 1, y_valid, x_valid]
        pred['start_times'] = x_coords_to_time(
            x_valid.float() / params['resize_factor'],
            sampling_rate[ii].item(), params['fft_win_length'], params['fft_overlap'])
        pred['end_times'] = x_coords_to_time(
            (x_valid.float() + pred['bb_width']) / params['resize_factor'],
            sampling_rate[ii].item(), params['fft_win_length'], params['fft_overlap'])
        # Frequency increases as y decreases: y == 0 maps to max_freq.
        pred['low_freqs'] = (pred_size[ii].shape[1] - y_valid.float()) * freq_rescale + params['min_freq']
        pred['high_freqs'] = pred['low_freqs'] + pred['bb_height'] * freq_rescale

        # extract the per class votes at each kept detection
        if 'pred_class' in outputs:
            pred['class_probs'] = outputs['pred_class'][ii, :, y_valid, x_valid]
        # extract the model features at each kept detection (one row per detection)
        if 'features' in outputs:
            feat = outputs['features'][ii, :, y_valid, x_valid].transpose(0, 1)
            feats.append(feat.cpu().numpy().astype(np.float32))

        # convert every output (positions included) to float32 numpy arrays
        preds.append({key: val.cpu().numpy().astype(np.float32) for key, val in pred.items()})
    return preds, feats
def non_max_suppression(heat, kernel_size):
    """Keep only local maxima of a heatmap, zeroing every other value.

    A value survives when it equals the maximum of the ``kernel_size`` window
    centred on it, computed with stride-1 max pooling.

    Args:
        heat: Tensor accepted by ``max_pool2d`` (presumably
            ``(batch, 1, height, width)`` -- confirm against callers).
        kernel_size: Either an int (square window) or a ``(height, width)``
            pair. Odd sizes are expected: padding is ``(k - 1) // 2``, which
            only preserves the spatial size for odd ``k``.

    Returns:
        Tensor shaped like ``heat`` with non-maxima set to zero.
    """
    # An int (including numpy integer scalars) means a square window;
    # anything else is unpacked as (height, width).
    if isinstance(kernel_size, (int, np.integer)):
        kernel_size_h = kernel_size_w = int(kernel_size)
    else:
        kernel_size_h, kernel_size_w = kernel_size
    pad_h = (kernel_size_h - 1) // 2
    pad_w = (kernel_size_w - 1) // 2
    hmax = F.max_pool2d(heat, (kernel_size_h, kernel_size_w), stride=1, padding=(pad_h, pad_w))
    keep = (hmax == heat).float()
    return heat * keep
def get_topk_scores(scores, K):
    """Return the ``K`` highest scores per batch item and their (y, x) positions.

    Args:
        scores: Tensor of shape ``(batch, 1, height, width)``.
        K: Number of candidates requested per batch item. It is clamped to the
            number of elements per item, because ``torch.topk`` raises when
            k exceeds that number.

    Returns:
        ``(topk_scores, topk_ys, topk_xs)``, each of shape ``(batch, K')``
        with ``K' = min(K, channels * height * width)``, ordered by descending
        score. ``topk_ys``/``topk_xs`` are int64 row/column indices.
    """
    batch, _, height, width = scores.size()
    # reshape (unlike view) also accepts non-contiguous inputs
    flat = scores.reshape(batch, -1)
    K = min(K, flat.shape[1])
    topk_scores, topk_inds = torch.topk(flat, K)
    # fold any channel offset back onto a spatial index (no-op for 1 channel)
    topk_inds = topk_inds % (height * width)
    topk_ys = torch.div(topk_inds, width, rounding_mode='floor').long()
    topk_xs = (topk_inds % width).long()
    return topk_scores, topk_ys, topk_xs