import json
import os
import sys

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

from bat_detect.detector import models
import bat_detect.detector.compute_features as feats
import bat_detect.detector.post_process as pp
import bat_detect.utils.audio_utils as au


def get_default_bd_args():
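    """Return a dictionary of default arguments for running the detector programmatically."""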
    args = {}
    args['detection_threshold'] = 0.001
    args['time_expansion_factor'] = 1
    args['audio_dir'] = ''
    args['ann_dir'] = ''
    args['spec_slices'] = False
    args['chunk_size']  = 3
    args['spec_features'] = False
    args['cnn_features'] = False
    args['quiet'] = True
    args['save_preds_if_empty'] = True
    # ensure ann_dir ends with a trailing path separator when it is non-empty
    args['ann_dir'] = os.path.join(args['ann_dir'], '')
    return args
    
    
def get_audio_files(ip_dir):
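    """Recursively collect the paths of all .wav files under ip_dir."""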

    matches = []
    for root, dirnames, filenames in os.walk(ip_dir):
        for filename in filenames:
            if filename.lower().endswith('.wav'):
                matches.append(os.path.join(root, filename))
    return matches


def load_model(model_path, load_weights=True):
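    """Load a saved checkpoint from model_path and return the network (in eval mode) together with its parameter dictionary."""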

    # load model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if os.path.isfile(model_path):
        net_params = torch.load(model_path, map_location=device)
    else:
        print('Error: model not found.')
        sys.exit(1)

    params = net_params['params']
    params['device'] = device

    # all three architectures share the same constructor signature
    model_classes = {
        'Net2DFast': models.Net2DFast,
        'Net2DFastNoAttn': models.Net2DFastNoAttn,
        'Net2DFastNoCoordConv': models.Net2DFastNoCoordConv,
    }
    if params['model_name'] in model_classes:
        model = model_classes[params['model_name']](
            params['num_filters'], num_classes=len(params['class_names']),
            emb_dim=params['emb_dim'], ip_height=params['ip_height'],
            resize_factor=params['resize_factor'])
    else:
        print('Error: unknown model.')
        sys.exit(1)

    if load_weights:
        model.load_state_dict(net_params['state_dict'])

    model = model.to(params['device'])
    model.eval()

    return model, params


def merge_results(predictions, spec_feats, cnn_feats, spec_slices):
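    """Concatenate the per-chunk prediction dictionaries and feature arrays into single arrays covering the whole file."""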

    predictions_m = {}
    num_preds = np.sum([len(pred['det_probs']) for pred in predictions])

    if num_preds > 0:
        for kk in predictions[0].keys():
            predictions_m[kk] = np.hstack([pred[kk] for pred in predictions if pred['det_probs'].shape[0] > 0])
    else:
        # no calls detected - keep the first (empty) prediction dict so that
        # the expected key names are still present downstream
        predictions_m = predictions[0]

    if len(spec_feats) > 0:
        spec_feats = np.vstack(spec_feats)
    if len(cnn_feats) > 0:
        cnn_feats = np.vstack(cnn_feats)
    return predictions_m, spec_feats, cnn_feats, spec_slices


def convert_results(file_id, time_exp, duration, params, predictions, spec_feats, cnn_feats, spec_slices):
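    """Package the merged predictions and features into a single results dictionary in the format used by the annotation tool."""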

    # create a single dictionary - this is the format used by the annotation tool
    pred_dict = {}
    pred_dict['id'] = file_id
    pred_dict['annotated'] = False
    pred_dict['issues'] = False
    pred_dict['notes'] = 'Automatically generated.'
    pred_dict['time_exp'] = time_exp
    pred_dict['duration'] = round(duration, 4)
    pred_dict['annotation'] = []

    class_prob_best = predictions['class_probs'].max(0)
    class_ind_best  = predictions['class_probs'].argmax(0)
    if predictions['det_probs'].shape[0] > 0:
        # the file-level class is the one with the highest overall probability
        class_overall = pp.overall_class_pred(predictions['det_probs'], predictions['class_probs'])
        pred_dict['class_name'] = params['class_names'][np.argmax(class_overall)]
    else:
        # no detections - np.argmax would fail on an empty array
        pred_dict['class_name'] = None

    for ii in range(predictions['det_probs'].shape[0]):
        res = {}
        res['start_time'] = round(float(predictions['start_times'][ii]), 4)
        res['end_time']   = round(float(predictions['end_times'][ii]), 4)
        res['low_freq']   = int(predictions['low_freqs'][ii])
        res['high_freq']  = int(predictions['high_freqs'][ii])
        res['class']      = str(params['class_names'][int(class_ind_best[ii])])
        res['class_prob'] = round(float(class_prob_best[ii]), 3)
        res['det_prob']   = round(float(predictions['det_probs'][ii]), 3)
        res['individual'] = '-1'
        res['event']      = 'Echolocation'
        pred_dict['annotation'].append(res)

    # combine into final results dictionary
    results = {}
    results['pred_dict'] = pred_dict
    if len(spec_feats) > 0:
        results['spec_feats'] = spec_feats
        results['spec_feat_names'] = feats.get_feature_names()
    if len(cnn_feats) > 0:
        results['cnn_feats'] = cnn_feats
        results['cnn_feat_names'] = [str(ii) for ii in range(cnn_feats.shape[1])]
    if len(spec_slices) > 0:
        results['spec_slices'] = spec_slices

    return results


def save_results_to_file(results, op_path):
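    """Write the detections to .csv (when predictions exist) and .json files rooted at op_path, plus optional feature .csv files."""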

    # make the output directory if it does not exist
    op_dir = os.path.dirname(op_path)
    if op_dir != '' and not os.path.isdir(op_dir):
        os.makedirs(op_dir)

    # save csv file - if there are predictions
    result_list = results['pred_dict']['annotation']
    df = pd.DataFrame(result_list)
    df['file_name'] = [results['pred_dict']['id']]*len(result_list)
    df.index.name = 'id'
    if 'class_prob' in df.columns:
        df = df[['det_prob', 'start_time', 'end_time', 'high_freq',
                 'low_freq', 'class', 'class_prob']]
        df.to_csv(op_path + '.csv', sep=',')

    # save features
    if 'spec_feats' in results.keys():
        df = pd.DataFrame(results['spec_feats'], columns=results['spec_feat_names'])
        df.to_csv(op_path + '_spec_features.csv', sep=',', index=False, float_format='%.5f')

    if 'cnn_feats' in results.keys():
        df = pd.DataFrame(results['cnn_feats'], columns=results['cnn_feat_names'])
        df.to_csv(op_path + '_cnn_features.csv', sep=',', index=False, float_format='%.5f')

    # save json file
    with open(op_path + '.json', 'w') as da:
        json.dump(results['pred_dict'], da, indent=2, sort_keys=True)


def compute_spectrogram(audio, sampling_rate, params, return_np=False):
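    """Pad the audio, compute its spectrogram, and resize it to the input shape expected by the model; returns the clip duration, the spectrogram tensor, and (optionally) a numpy copy."""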

    # pad audio so it is evenly divisible by downsampling factors
    duration = audio.shape[0] / float(sampling_rate)
    audio = au.pad_audio(audio, sampling_rate, params['fft_win_length'],
                         params['fft_overlap'], params['resize_factor'],
                         params['spec_divide_factor'])

    # generate spectrogram
    spec, _ = au.generate_spectrogram(audio, sampling_rate, params)

    # convert to pytorch
    spec = torch.from_numpy(spec).to(params['device'])
    spec = spec.unsqueeze(0).unsqueeze(0)

    # resize the spec
    rs = params['resize_factor']
    spec_op_shape = (int(params['spec_height']*rs), int(spec.shape[-1]*rs))
    spec = F.interpolate(spec, size=spec_op_shape, mode='bilinear', align_corners=False)

    if return_np:
        spec_np = spec[0,0,:].cpu().data.numpy()
    else:
        spec_np = None

    return duration, spec, spec_np


def process_file(audio_file, model, params, args, time_exp=None, top_n=5, return_raw_preds=False, max_duration=False):
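    """Run the detector over a single audio file in fixed-size chunks and
    return the merged results (or raw predictions if return_raw_preds is True)."""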

    # store temporary results here
    predictions = []
    spec_feats  = []
    cnn_feats   = []
    spec_slices = []

    # get time expansion factor
    if time_exp is None:
        time_exp = args['time_expansion_factor']

    params['detection_threshold'] = args['detection_threshold']

    # load audio file
    sampling_rate, audio_full = au.load_audio_file(audio_file, time_exp,
                                   params['target_samp_rate'], params['scale_raw_audio'])

    # clip to the maximum duration, if one is specified
    if max_duration is not False:
        max_samples = np.minimum(int(sampling_rate*max_duration), audio_full.shape[0])
        audio_full = audio_full[:max_samples]
    
    duration_full = audio_full.shape[0] / float(sampling_rate)

    return_np_spec = args['spec_features'] or args['spec_slices']

    # loop through larger file and split into chunks
    # TODO fix so that it overlaps correctly and takes care of duplicate detections at borders
    num_chunks = int(np.ceil(duration_full/args['chunk_size']))
    for chunk_id in range(num_chunks):

        # chunk
        chunk_time   = args['chunk_size']*chunk_id
        chunk_length = int(sampling_rate*args['chunk_size'])
        start_sample = chunk_id*chunk_length
        end_sample   = np.minimum((chunk_id+1)*chunk_length, audio_full.shape[0])
        audio = audio_full[start_sample:end_sample]

        # compute the spectrogram for this chunk
        duration, spec, spec_np = compute_spectrogram(audio, sampling_rate, params, return_np_spec)

        # evaluate model
        with torch.no_grad():
            outputs = model(spec, return_feats=args['cnn_features'])

        # run non-max suppression
        pred_nms, features = pp.run_nms(outputs, params, np.array([float(sampling_rate)]))
        pred_nms = pred_nms[0]
        pred_nms['start_times'] += chunk_time
        pred_nms['end_times'] += chunk_time

        # strip the implicit background class if the model outputs one more
        # row of class probabilities than there are named classes
        if pred_nms['class_probs'].shape[0] > len(params['class_names']):
            pred_nms['class_probs'] = pred_nms['class_probs'][:-1, :]

        predictions.append(pred_nms)

        # extract features - if there are any calls detected
        if (pred_nms['det_probs'].shape[0] > 0):
            if args['spec_features']:
                spec_feats.append(feats.get_feats(spec_np, pred_nms, params))

            if args['cnn_features']:
                cnn_feats.append(features[0])

            if args['spec_slices']:
                spec_slices.extend(feats.extract_spec_slices(spec_np, pred_nms, params))

    # convert the predictions into output dictionary
    file_id = os.path.basename(audio_file)
    predictions, spec_feats, cnn_feats, spec_slices =\
              merge_results(predictions, spec_feats, cnn_feats, spec_slices)
    results = convert_results(file_id, time_exp, duration_full, params,
                              predictions, spec_feats, cnn_feats, spec_slices)

    # summarize results
    if not args['quiet']:
        num_detections = len(results['pred_dict']['annotation'])
        print('{} call(s) detected above the threshold.'.format(num_detections))

        # print results for the top n classes
        if num_detections > 0:
            class_overall = pp.overall_class_pred(predictions['det_probs'], predictions['class_probs'])
            print('species name'.ljust(30) + 'probability present')
            for cc in np.argsort(class_overall)[::-1][:top_n]:
                print(params['class_names'][cc].ljust(30) + str(round(class_overall[cc], 3)))

    if return_raw_preds:
        return predictions
    else:
        return results
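

# Minimal usage sketch (illustrative only): the checkpoint and directory paths
# below are placeholders and do not ship with this module - substitute paths
# that exist on your system before running.
if __name__ == '__main__':
    args = get_default_bd_args()
    model, params = load_model('path/to/model_checkpoint.pth.tar')
    for audio_file in get_audio_files('path/to/audio_dir'):
        results = process_file(audio_file, model, params, args)
        op_path = os.path.join('path/to/output_dir', results['pred_dict']['id'])
        save_results_to_file(results, op_path)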