batdetect2 / bat_detect /utils /detector_utils.py
Oisin Mac Aodha
added bat code
9ace58a
raw
history blame
No virus
11.1 kB
import torch
import torch.nn.functional as F
import os
import numpy as np
import pandas as pd
import json
import sys
from bat_detect.detector import models
import bat_detect.detector.compute_features as feats
import bat_detect.detector.post_process as pp
import bat_detect.utils.audio_utils as au
def get_default_bd_args():
    """Build a fresh dictionary with the default detector arguments.

    Returns:
        dict: a new dict on every call holding the default detection
        threshold, time expansion factor, audio / annotation directories,
        chunk size, the feature and slice export flags, the ``quiet`` flag
        and ``save_preds_if_empty``.
    """
    defaults = {
        'detection_threshold': 0.001,
        'time_expansion_factor': 1,
        'audio_dir': '',
        'ann_dir': '',
        'spec_slices': False,
        'chunk_size': 3,
        'spec_features': False,
        'cnn_features': False,
        'quiet': True,
        'save_preds_if_empty': True,
    }

    # joining with '' normalises the directory string; the empty default stays ''
    defaults['ann_dir'] = os.path.join(defaults['ann_dir'], '')
    return defaults
def get_audio_files(ip_dir):
    """Recursively collect the paths of all ``.wav`` files below ``ip_dir``.

    Args:
        ip_dir: root directory to search; sub-directories are walked too.

    Returns:
        list: paths (directory joined with file name) of every file whose
        name ends in ``.wav``, compared case-insensitively, in the order
        ``os.walk`` yields them.
    """
    return [
        os.path.join(folder, name)
        for folder, _subdirs, names in os.walk(ip_dir)
        for name in names
        if name.lower().endswith('.wav')
    ]
def load_model(model_path, load_weights=True):
    """Load a detection model and its parameters from a saved checkpoint.

    Args:
        model_path: path to a checkpoint readable by ``torch.load``. It must
            contain a ``'params'`` dict and, when ``load_weights`` is true,
            a ``'state_dict'`` entry.
        load_weights: if True, copy the stored weights into the new model.

    Returns:
        tuple: ``(model, params)`` - the network moved to the selected device
        and switched to eval mode, and the parameter dict with
        ``params['device']`` set to that device.

    Exits:
        Prints an error and calls ``sys.exit(1)`` when the checkpoint file
        does not exist or when its ``model_name`` is not a supported
        architecture.
    """
    # use the GPU when one is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if not os.path.isfile(model_path):
        print('Error: model not found.')
        sys.exit(1)

    # NOTE: torch.load relies on pickle - only open checkpoints from trusted sources
    net_params = torch.load(model_path, map_location=device)

    params = net_params['params']
    params['device'] = device

    # all supported architectures share the same constructor signature
    supported_models = ('Net2DFast', 'Net2DFastNoAttn', 'Net2DFastNoCoordConv')
    if params['model_name'] not in supported_models:
        # stop here; continuing would fail later on an undefined model
        print('Error: unknown model.')
        sys.exit(1)

    model_class = getattr(models, params['model_name'])
    model = model_class(params['num_filters'], num_classes=len(params['class_names']),
                        emb_dim=params['emb_dim'], ip_height=params['ip_height'],
                        resize_factor=params['resize_factor'])

    if load_weights:
        model.load_state_dict(net_params['state_dict'])

    model = model.to(params['device'])
    model.eval()

    return model, params
def merge_results(predictions, spec_feats, cnn_feats, spec_slices):
    """Combine the per-chunk outputs of one file into single containers.

    Args:
        predictions: list of per-chunk prediction dicts; every dict has a
            ``'det_probs'`` entry and all dicts share the same keys.
        spec_feats: list of per-chunk spectrogram feature arrays (may be empty).
        cnn_feats: list of per-chunk CNN feature arrays (may be empty).
        spec_slices: list of spectrogram slices; returned unchanged.

    Returns:
        tuple: ``(predictions_m, spec_feats, cnn_feats, spec_slices)``.
        ``predictions_m`` holds, for every key, the ``np.hstack`` of the
        values from all chunks with at least one detection. If no chunk has
        a detection, the first chunk's dict is returned as is. Non-empty
        feature lists are stacked row-wise with ``np.vstack``.
    """
    has_detections = any(len(chunk['det_probs']) > 0 for chunk in predictions)

    if has_detections:
        # chunks without detections are left out of the concatenation
        non_empty = [chunk for chunk in predictions if chunk['det_probs'].shape[0] > 0]
        predictions_m = {
            key: np.hstack([chunk[key] for chunk in non_empty])
            for key in predictions[0].keys()
        }
    else:
        # hack in case where no detected calls as we need some of the key names in dict
        predictions_m = predictions[0]

    if len(spec_feats) > 0:
        spec_feats = np.vstack(spec_feats)
    if len(cnn_feats) > 0:
        cnn_feats = np.vstack(cnn_feats)

    return predictions_m, spec_feats, cnn_feats, spec_slices
def convert_results(file_id, time_exp, duration, params, predictions, spec_feats, cnn_feats, spec_slices):
    """Pack merged predictions and optional features into the output dict.

    Args:
        file_id: identifier stored under ``'id'``.
        time_exp: time expansion factor stored under ``'time_exp'``.
        duration: duration value; stored rounded to 4 decimals.
        params: model parameters; ``params['class_names']`` is used to map
            class indices to names.
        predictions: merged prediction dict with the keys ``'det_probs'``,
            ``'class_probs'``, ``'start_times'``, ``'end_times'``,
            ``'low_freqs'`` and ``'high_freqs'``.
        spec_feats: spectrogram features; added to the result when non-empty.
        cnn_feats: CNN features; added to the result when non-empty.
        spec_slices: spectrogram slices; added to the result when non-empty.

    Returns:
        dict: ``'pred_dict'`` (file-level fields plus one ``'annotation'``
        entry per detection) and, when present, the feature arrays with
        their column names and the spectrogram slices.
    """
    # file-level fields - this is the format used by the annotation tool
    pred_dict = {
        'id': file_id,
        'annotated': False,
        'issues': False,
        'notes': 'Automatically generated.',
        'time_exp': time_exp,
        'duration': round(duration, 4),
        'annotation': [],
    }

    class_names = params['class_names']
    det_probs = predictions['det_probs']
    class_probs = predictions['class_probs']

    # best class per detection (max / argmax over the first axis)
    best_probs = class_probs.max(0)
    best_inds = class_probs.argmax(0)

    # single class label for the whole file
    file_class_scores = pp.overall_class_pred(det_probs, class_probs)
    pred_dict['class_name'] = class_names[np.argmax(file_class_scores)]

    annotation = pred_dict['annotation']
    for idx in range(det_probs.shape[0]):
        annotation.append({
            'start_time': round(float(predictions['start_times'][idx]), 4),
            'end_time': round(float(predictions['end_times'][idx]), 4),
            'low_freq': int(predictions['low_freqs'][idx]),
            'high_freq': int(predictions['high_freqs'][idx]),
            'class': str(class_names[int(best_inds[idx])]),
            'class_prob': round(float(best_probs[idx]), 3),
            'det_prob': round(float(det_probs[idx]), 3),
            'individual': '-1',
            'event': 'Echolocation',
        })

    # combine into final results dictionary
    results = {'pred_dict': pred_dict}

    if len(spec_feats) > 0:
        results['spec_feats'] = spec_feats
        results['spec_feat_names'] = feats.get_feature_names()

    if len(cnn_feats) > 0:
        results['cnn_feats'] = cnn_feats
        # CNN feature columns are simply named by their index
        results['cnn_feat_names'] = list(map(str, range(cnn_feats.shape[1])))

    if len(spec_slices) > 0:
        results['spec_slices'] = spec_slices

    return results
def save_results_to_file(results, op_path):
    """Write the results of one file to disk as CSV and JSON.

    Files written (``op_path`` is used as a prefix, without extension):
        * ``op_path + '.csv'`` - one row per entry in
          ``results['pred_dict']['annotation']``
        * ``op_path + '_spec_features.csv'`` - only if ``'spec_feats'`` is in
          ``results``
        * ``op_path + '_cnn_features.csv'`` - only if ``'cnn_feats'`` is in
          ``results``
        * ``op_path + '.json'`` - the complete ``results['pred_dict']``

    Args:
        results: dict containing ``'pred_dict'`` and optionally the feature
            arrays together with their ``*_feat_names`` column names.
        op_path: output path prefix. Its parent directory is created when
            missing; a prefix without a directory part writes to the
            current working directory.
    """
    # make directory if it does not exist; dirname is '' for a bare file
    # name and os.makedirs('') would raise, so skip it in that case
    op_dir = os.path.dirname(op_path)
    if op_dir:
        os.makedirs(op_dir, exist_ok=True)

    # save csv file - written even when there are no predictions
    result_list = list(results['pred_dict']['annotation'])
    df = pd.DataFrame(result_list)
    df['file_name'] = [results['pred_dict']['id']] * len(result_list)
    df.index.name = 'id'
    if 'class_prob' in df.columns:
        # keep only the per-call columns, in a fixed order
        df = df[['det_prob', 'start_time', 'end_time', 'high_freq',
                 'low_freq', 'class', 'class_prob']]
    df.to_csv(op_path + '.csv', sep=',')

    # save features
    if 'spec_feats' in results:
        df = pd.DataFrame(results['spec_feats'], columns=results['spec_feat_names'])
        df.to_csv(op_path + '_spec_features.csv', sep=',', index=False, float_format='%.5f')

    if 'cnn_feats' in results:
        df = pd.DataFrame(results['cnn_feats'], columns=results['cnn_feat_names'])
        df.to_csv(op_path + '_cnn_features.csv', sep=',', index=False, float_format='%.5f')

    # save json file
    with open(op_path + '.json', 'w') as da:
        json.dump(results['pred_dict'], da, indent=2, sort_keys=True)
def compute_spectrogram(audio, sampling_rate, params, return_np=False):
    """Turn an audio clip into the resized spectrogram tensor fed to the model.

    Args:
        audio: audio samples; the first axis is the sample axis.
        sampling_rate: sampling rate used to convert samples to seconds.
        params: parameter dict; reads ``'fft_win_length'``, ``'fft_overlap'``,
            ``'resize_factor'``, ``'spec_divide_factor'``, ``'spec_height'``
            and ``'device'`` (and is passed on to the spectrogram helper).
        return_np: if True, also return the resized spectrogram as a NumPy
            array.

    Returns:
        tuple: ``(duration, spec, spec_np)`` - the duration of the unpadded
        audio in seconds, the spectrogram tensor on ``params['device']`` with
        two leading singleton axes added, and the NumPy copy of ``spec[0, 0]``
        or ``None`` when ``return_np`` is false.
    """
    # duration is measured before padding
    duration = audio.shape[0] / float(sampling_rate)

    # pad audio so it is evenly divisible by downsampling factors
    padded_audio = au.pad_audio(audio, sampling_rate, params['fft_win_length'],
                                params['fft_overlap'], params['resize_factor'],
                                params['spec_divide_factor'])

    # generate spectrogram and move it to the model's device
    spec_array, _ = au.generate_spectrogram(padded_audio, sampling_rate, params)
    spec = torch.from_numpy(spec_array).to(params['device'])
    spec = spec.unsqueeze(0).unsqueeze(0)

    # resize the spec
    scale = params['resize_factor']
    target_size = (int(params['spec_height'] * scale), int(spec.shape[-1] * scale))
    spec = F.interpolate(spec, size=target_size, mode='bilinear', align_corners=False)

    spec_np = spec[0, 0, :].cpu().data.numpy() if return_np else None

    return duration, spec, spec_np
def process_file(audio_file, model, params, args, time_exp=None, top_n=5, return_raw_preds=False, max_duration=False):
    """Run the detector over one audio file, chunk by chunk.

    Args:
        audio_file: path of the audio file to process.
        model: network, called as ``model(spec, return_feats=...)``.
        params: model parameter dict. NOTE: ``params['detection_threshold']``
            is overwritten in place with ``args['detection_threshold']``.
        args: run-time options; the keys read here are
            ``'time_expansion_factor'``, ``'detection_threshold'``,
            ``'chunk_size'`` (seconds per chunk), ``'spec_features'``,
            ``'cnn_features'``, ``'spec_slices'`` and ``'quiet'``.
        time_exp: time expansion factor; falls back to
            ``args['time_expansion_factor']`` when None.
        top_n: number of classes listed in the printed summary.
        return_raw_preds: if True, return the merged raw predictions instead
            of the results dictionary.
        max_duration: if not False (or None), only the first
            ``max_duration`` seconds of audio are processed.

    Returns:
        The merged prediction dict when ``return_raw_preds`` is true,
        otherwise the results dict built by ``convert_results``.

    Raises:
        ValueError: if ``args['chunk_size']`` is shorter than one sample.
    """
    # store temporary results here
    predictions = []
    spec_feats = []
    cnn_feats = []
    spec_slices = []

    # get time expansion factor
    if time_exp is None:
        time_exp = args['time_expansion_factor']

    params['detection_threshold'] = args['detection_threshold']

    # load audio file
    sampling_rate, audio_full = au.load_audio_file(audio_file, time_exp,
                                                   params['target_samp_rate'],
                                                   params['scale_raw_audio'])

    # clipping maximum duration (None is accepted as "no limit" as well)
    if max_duration is not False and max_duration is not None:
        max_samples = min(int(sampling_rate * max_duration), audio_full.shape[0])
        audio_full = audio_full[:max_samples]

    num_samples = audio_full.shape[0]
    duration_full = num_samples / float(sampling_rate)

    return_np_spec = args['spec_features'] or args['spec_slices']

    # chunk length in samples is the same for every chunk
    chunk_length = int(sampling_rate * args['chunk_size'])
    if chunk_length < 1:
        raise ValueError('chunk_size must span at least one audio sample.')

    # loop through larger file and split into chunks
    # TODO fix so that it overlaps correctly and takes care of duplicate detections at borders
    # stepping over sample offsets covers every sample, including a shorter final chunk
    for start_sample in range(0, num_samples, chunk_length):
        # offset of this chunk in seconds, taken from the real start sample so
        # it stays exact even if sampling_rate * chunk_size is not an integer
        chunk_time = start_sample / float(sampling_rate)
        audio = audio_full[start_sample:start_sample + chunk_length]

        # compute spectrogram for this chunk
        duration, spec, spec_np = compute_spectrogram(audio, sampling_rate, params, return_np_spec)

        # evaluate model
        with torch.no_grad():
            outputs = model(spec, return_feats=args['cnn_features'])

        # run non-max suppression
        pred_nms, features = pp.run_nms(outputs, params, np.array([float(sampling_rate)]))
        pred_nms = pred_nms[0]

        # shift chunk-relative times to file-relative times
        pred_nms['start_times'] += chunk_time
        pred_nms['end_times'] += chunk_time

        # if we have a background class, drop its (last) row
        if pred_nms['class_probs'].shape[0] > len(params['class_names']):
            pred_nms['class_probs'] = pred_nms['class_probs'][:-1, :]

        predictions.append(pred_nms)

        # extract features - if there are any calls detected
        if pred_nms['det_probs'].shape[0] > 0:
            if args['spec_features']:
                spec_feats.append(feats.get_feats(spec_np, pred_nms, params))

            if args['cnn_features']:
                cnn_feats.append(features[0])

            if args['spec_slices']:
                spec_slices.extend(feats.extract_spec_slices(spec_np, pred_nms, params))

    # convert the predictions into output dictionary
    file_id = os.path.basename(audio_file)
    predictions, spec_feats, cnn_feats, spec_slices = \
        merge_results(predictions, spec_feats, cnn_feats, spec_slices)
    results = convert_results(file_id, time_exp, duration_full, params,
                              predictions, spec_feats, cnn_feats, spec_slices)

    # summarize results
    if not args['quiet']:
        num_detections = len(results['pred_dict']['annotation'])
        print('{}'.format(num_detections) + ' call(s) detected above the threshold.')

        # print results for top n classes
        if num_detections > 0:
            class_overall = pp.overall_class_pred(predictions['det_probs'], predictions['class_probs'])
            print('species name'.ljust(30) + 'probablity present')
            for cc in np.argsort(class_overall)[::-1][:top_n]:
                print(params['class_names'][cc].ljust(30) + str(round(class_overall[cc], 3)))

    if return_raw_preds:
        return predictions
    return results