Spaces:
Running
Running
Oisin Mac Aodha
commited on
Commit
•
9ace58a
1
Parent(s):
2502b35
added bat code
Browse files- README.md +75 -12
- app.py +71 -4
- bat_detect/__init__.py +0 -0
- bat_detect/detector/__init__.py +0 -0
- bat_detect/detector/compute_features.py +84 -0
- bat_detect/detector/model_helpers.py +93 -0
- bat_detect/detector/models.py +218 -0
- bat_detect/detector/parameters.py +108 -0
- bat_detect/detector/post_process.py +100 -0
- bat_detect/evaluate/evaluate_models.py +565 -0
- bat_detect/evaluate/readme.md +33 -0
- bat_detect/finetune/finetune_model.py +183 -0
- bat_detect/finetune/prep_data_finetune.py +151 -0
- bat_detect/finetune/readme.md +40 -0
- bat_detect/train/__init__.py +0 -0
- bat_detect/train/audio_dataloader.py +407 -0
- bat_detect/train/evaluate.py +333 -0
- bat_detect/train/losses.py +56 -0
- bat_detect/train/readme.md +18 -0
- bat_detect/train/train_model.py +356 -0
- bat_detect/train/train_split.py +231 -0
- bat_detect/train/train_utils.py +185 -0
- bat_detect/utils/__init__.py +0 -0
- bat_detect/utils/audio_utils.py +164 -0
- bat_detect/utils/detector_utils.py +291 -0
- bat_detect/utils/plot_utils.py +371 -0
- bat_detect/utils/visualize.py +158 -0
- bat_detect/utils/wavfile.py +291 -0
- example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav +0 -0
- example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav +0 -0
- example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav +0 -0
- models/readme.md +1 -0
- requirements.txt +9 -0
README.md
CHANGED
@@ -1,12 +1,75 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# BatDetect2
|
2 |
+
<img align="left" width="64" height="64" src="ims/bat_icon.png">
|
3 |
+
|
4 |
+
Code for detecting and classifying bat echolocation calls in high frequency audio recordings.
|
5 |
+
|
6 |
+
|
7 |
+
### Getting started
|
8 |
+
1) Install the Anaconda Python 3.10 distribution for your operating system from [here](https://www.continuum.io/downloads).
|
9 |
+
2) Download this code from the repository (by clicking on the green button on top right) and unzip it.
|
10 |
+
3) Create a new environment and install the required packages:
|
11 |
+
`conda create -y --name batdetect2 python==3.10`
|
12 |
+
`conda activate batdetect2`
|
13 |
+
`conda install --file requirements.txt`
|
14 |
+
|
15 |
+
|
16 |
+
### Try the model
|
17 |
+
Click [here](https://colab.research.google.com/github/macaodha/batdetect2/blob/master/batdetect2_notebook.ipynb) to run the model using Google Colab. You can also run this notebook locally.
|
18 |
+
|
19 |
+
|
20 |
+
### Running the model on your own data
|
21 |
+
After following the above steps to install the code you can run the model on your own data by opening the command line where the code is located and typing:
|
22 |
+
`python run_batdetect.py AUDIO_DIR ANN_DIR DETECTION_THRESHOLD`
|
23 |
+
e.g.
|
24 |
+
`python run_batdetect.py example_data/audio/ example_data/anns/ 0.3`
|
25 |
+
|
26 |
+
|
27 |
+
`AUDIO_DIR` is the path on your computer to the audio wav files of interest.
|
28 |
+
`ANN_DIR` is the path on your computer where the model predictions will be saved. The model will output both `.csv` and `.json` results for each audio file.
|
29 |
+
`DETECTION_THRESHOLD` is a number between 0 and 1 specifying the cut-off threshold applied to the calls. A smaller number will result in more calls detected, but with the chance of introducing more mistakes.
|
30 |
+
|
31 |
+
There are also optional arguments, e.g. you can request that the model outputs features (i.e. estimated call parameters) such as duration, max_frequency, etc. by setting the flag `--spec_features`. These will be saved as `*_spec_features.csv` files:
|
32 |
+
`python run_batdetect.py example_data/audio/ example_data/anns/ 0.3 --spec_features`
|
33 |
+
|
34 |
+
You can also specify which model to use by setting the `--model_path` argument. If not specified, it will default to using a model trained on UK data e.g.
|
35 |
+
`python run_batdetect.py example_data/audio/ example_data/anns/ 0.3 --model_path models/Net2DFast_UK_same.pth.tar`
|
36 |
+
|
37 |
+
|
38 |
+
### Training the model on your own data
|
39 |
+
Take a look at the steps outlined in fintuning readme [here](bat_detect/finetune/readme.md) for a description of how to train your own model.
|
40 |
+
|
41 |
+
|
42 |
+
### Data and annotations
|
43 |
+
The raw audio data and annotations used to train the models in the paper will be added soon.
|
44 |
+
The audio interface used to annotate audio data for training and evaluation is available [here](https://github.com/macaodha/batdetect2_GUI).
|
45 |
+
|
46 |
+
|
47 |
+
### Warning
|
48 |
+
The models developed and shared as part of this repository should be used with caution.
|
49 |
+
While they have been evaluated on held out audio data, great care should be taken when using the model outputs for any form of biodiversity assessment.
|
50 |
+
Your data may differ, and as a result it is very strongly recommended that you validate the model first using data with known species to ensure that the outputs can be trusted.
|
51 |
+
|
52 |
+
|
53 |
+
### FAQ
|
54 |
+
For more information please consult our [FAQ](faq.md).
|
55 |
+
|
56 |
+
|
57 |
+
### Reference
|
58 |
+
If you find our work useful in your research please consider citing our paper which you can find [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1):
|
59 |
+
```
|
60 |
+
@article{batdetect2_2022,
|
61 |
+
title = {Towards a General Approach for Bat Echolocation Detection and Classification},
|
62 |
+
author = {Mac Aodha, Oisin and Mart\'{i}nez Balvanera, Santiago and Damstra, Elise and Cooke, Martyn and Eichinski, Philip and Browning, Ella and Barataudm, Michel and Boughey, Katherine and Coles, Roger and Giacomini, Giada and MacSwiney G., M. Cristina and K. Obrist, Martin and Parsons, Stuart and Sattler, Thomas and Jones, Kate E.},
|
63 |
+
journal = {bioRxiv},
|
64 |
+
year = {2022}
|
65 |
+
}
|
66 |
+
```
|
67 |
+
|
68 |
+
### Acknowledgements
|
69 |
+
Thanks to all the contributors who spent time collecting and annotating audio data.
|
70 |
+
|
71 |
+
|
72 |
+
### TODOs
|
73 |
+
- [x] Release the code and pretrained model
|
74 |
+
- [ ] Release the datasets and annotations used the experiments in the paper
|
75 |
+
- [ ] Add the scripts used to generate the tables and figures from the paper
|
app.py
CHANGED
@@ -1,8 +1,75 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
|
|
|
5 |
|
6 |
-
iface = gr.Interface(fn=greet, inputs="text", outputs="text")
|
7 |
-
iface.launch()
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
import os
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
import pandas as pd
|
5 |
+
import numpy as np
|
6 |
|
7 |
+
import bat_detect.utils.detector_utils as du
|
8 |
+
import bat_detect.utils.audio_utils as au
|
9 |
+
import bat_detect.utils.plot_utils as viz
|
10 |
|
|
|
|
|
11 |
|
12 |
+
# setup the arguments
|
13 |
+
args = {}
|
14 |
+
args = du.get_default_bd_args()
|
15 |
+
args['detection_threshold'] = 0.3
|
16 |
+
args['time_expansion_factor'] = 1
|
17 |
+
args['model_path'] = 'models/Net2DFast_UK_same.pth.tar'
|
18 |
+
|
19 |
+
# load the model
|
20 |
+
model, params = du.load_model(args['model_path'])
|
21 |
+
|
22 |
+
|
23 |
+
df = gr.Dataframe(
|
24 |
+
headers=["species", "time_in_file", "species_prob"],
|
25 |
+
datatype=["str", "str", "str"],
|
26 |
+
row_count=1,
|
27 |
+
col_count=(3, "fixed"),
|
28 |
+
)
|
29 |
+
|
30 |
+
|
31 |
+
examples = [['example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav', 0.3],
|
32 |
+
['example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav', 0.3],
|
33 |
+
['example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav', 0.3]]
|
34 |
+
|
35 |
+
|
36 |
+
def make_prediction(file_name=None, detection_threshold=0.3):
|
37 |
+
|
38 |
+
if file_name is not None:
|
39 |
+
audio_file = file_name
|
40 |
+
else:
|
41 |
+
return "You must provide an input audio file."
|
42 |
+
|
43 |
+
if detection_threshold != '':
|
44 |
+
args['detection_threshold'] = float(detection_threshold)
|
45 |
+
|
46 |
+
results = du.process_file(audio_file, model, params, args, max_duration=5.0)
|
47 |
+
|
48 |
+
clss = [aa['class'] for aa in results['pred_dict']['annotation']]
|
49 |
+
st_time = [aa['start_time'] for aa in results['pred_dict']['annotation']]
|
50 |
+
cls_prob = [aa['class_prob'] for aa in results['pred_dict']['annotation']]
|
51 |
+
|
52 |
+
data = {'species': clss, 'time_in_file': st_time, 'species_prob': cls_prob}
|
53 |
+
df = pd.DataFrame(data=data)
|
54 |
+
|
55 |
+
return df
|
56 |
+
|
57 |
+
|
58 |
+
descr_txt = "Demo of BatDetect2 deep learning-based bat echolocation call detection. " \
|
59 |
+
"<br>This model is only trained on bat species from the UK. If the input " \
|
60 |
+
"file is longer than 5 seconds, only the first 5 seconds will be processed." \
|
61 |
+
"<br>Check out the paper [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)."
|
62 |
+
|
63 |
+
gr.Interface(
|
64 |
+
fn = make_prediction,
|
65 |
+
inputs = [gr.Audio(source="upload", type="filepath", optional=True),
|
66 |
+
gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])],
|
67 |
+
outputs = df,
|
68 |
+
theme = "huggingface",
|
69 |
+
title = "BatDetect2 Demo",
|
70 |
+
description = descr_txt,
|
71 |
+
examples = examples,
|
72 |
+
allow_flagging = 'never',
|
73 |
+
).launch()
|
74 |
+
|
75 |
+
|
bat_detect/__init__.py
ADDED
File without changes
|
bat_detect/detector/__init__.py
ADDED
File without changes
|
bat_detect/detector/compute_features.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
def convert_int_to_freq(spec_ind, spec_height, min_freq, max_freq):
|
5 |
+
spec_ind = spec_height-spec_ind
|
6 |
+
return round((spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq, 2)
|
7 |
+
|
8 |
+
|
9 |
+
def extract_spec_slices(spec, pred_nms, params):
|
10 |
+
"""
|
11 |
+
Extracts spectrogram slices from spectrogram based on detected call locations.
|
12 |
+
"""
|
13 |
+
|
14 |
+
x_pos = pred_nms['x_pos']
|
15 |
+
y_pos = pred_nms['y_pos']
|
16 |
+
bb_width = pred_nms['bb_width']
|
17 |
+
bb_height = pred_nms['bb_height']
|
18 |
+
slices = []
|
19 |
+
|
20 |
+
# add 20% padding either side of call
|
21 |
+
pad = bb_width*0.2
|
22 |
+
x_pos_pad = x_pos - pad
|
23 |
+
bb_width_pad = bb_width + 2*pad
|
24 |
+
|
25 |
+
for ff in range(len(pred_nms['det_probs'])):
|
26 |
+
x_start = int(np.maximum(0, x_pos_pad[ff]))
|
27 |
+
x_end = int(np.minimum(spec.shape[1]-1, np.round(x_pos_pad[ff] + bb_width_pad[ff])))
|
28 |
+
slices.append(spec[:, x_start:x_end].astype(np.float16))
|
29 |
+
return slices
|
30 |
+
|
31 |
+
|
32 |
+
def get_feature_names():
|
33 |
+
feature_names = ['duration', 'low_freq_bb', 'high_freq_bb', 'bandwidth',
|
34 |
+
'max_power_bb', 'max_power', 'max_power_first',
|
35 |
+
'max_power_second', 'call_interval']
|
36 |
+
return feature_names
|
37 |
+
|
38 |
+
|
39 |
+
def get_feats(spec, pred_nms, params):
|
40 |
+
"""
|
41 |
+
Extracts features from spectrogram based on detected call locations.
|
42 |
+
Condsider re-extracting spectrogram for this to get better temporal resolution.
|
43 |
+
|
44 |
+
For more possible features check out:
|
45 |
+
https://github.com/YvesBas/Tadarida-D/blob/master/Manual_Tadarida-D.odt
|
46 |
+
"""
|
47 |
+
|
48 |
+
x_pos = pred_nms['x_pos']
|
49 |
+
y_pos = pred_nms['y_pos']
|
50 |
+
bb_width = pred_nms['bb_width']
|
51 |
+
bb_height = pred_nms['bb_height']
|
52 |
+
|
53 |
+
feature_names = get_feature_names()
|
54 |
+
num_detections = len(pred_nms['det_probs'])
|
55 |
+
features = np.ones((num_detections, len(feature_names)), dtype=np.float32)*-1
|
56 |
+
|
57 |
+
for ff in range(num_detections):
|
58 |
+
x_start = int(np.maximum(0, x_pos[ff]))
|
59 |
+
x_end = int(np.minimum(spec.shape[1]-1, np.round(x_pos[ff] + bb_width[ff])))
|
60 |
+
# y low is the lowest freq but it will have a higher value due to array starting at 0 at top
|
61 |
+
y_low = int(np.minimum(spec.shape[0]-1, y_pos[ff]))
|
62 |
+
y_high = int(np.maximum(0, np.round(y_pos[ff] - bb_height[ff])))
|
63 |
+
spec_slice = spec[:, x_start:x_end]
|
64 |
+
|
65 |
+
if spec_slice.shape[1] > 1:
|
66 |
+
features[ff, 0] = round(pred_nms['end_times'][ff] - pred_nms['start_times'][ff], 5)
|
67 |
+
features[ff, 1] = int(pred_nms['low_freqs'][ff])
|
68 |
+
features[ff, 2] = int(pred_nms['high_freqs'][ff])
|
69 |
+
features[ff, 3] = int(pred_nms['high_freqs'][ff] - pred_nms['low_freqs'][ff])
|
70 |
+
features[ff, 4] = int(convert_int_to_freq(y_high+spec_slice[y_high:y_low, :].sum(1).argmax(),
|
71 |
+
spec.shape[0], params['min_freq'], params['max_freq']))
|
72 |
+
features[ff, 5] = int(convert_int_to_freq(spec_slice.sum(1).argmax(),
|
73 |
+
spec.shape[0], params['min_freq'], params['max_freq']))
|
74 |
+
hlf_val = spec_slice.shape[1]//2
|
75 |
+
|
76 |
+
features[ff, 6] = int(convert_int_to_freq(spec_slice[:, :hlf_val].sum(1).argmax(),
|
77 |
+
spec.shape[0], params['min_freq'], params['max_freq']))
|
78 |
+
features[ff, 7] = int(convert_int_to_freq(spec_slice[:, hlf_val:].sum(1).argmax(),
|
79 |
+
spec.shape[0], params['min_freq'], params['max_freq']))
|
80 |
+
|
81 |
+
if ff > 0:
|
82 |
+
features[ff, 8] = round(pred_nms['start_times'][ff] - pred_nms['start_times'][ff-1], 5)
|
83 |
+
|
84 |
+
return features
|
bat_detect/detector/model_helpers.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
import math
|
6 |
+
|
7 |
+
|
8 |
+
class SelfAttention(nn.Module):
|
9 |
+
def __init__(self, ip_dim, att_dim):
|
10 |
+
super(SelfAttention, self).__init__()
|
11 |
+
# Note, does not encode position information (absolute or realtive)
|
12 |
+
self.temperature = 1.0
|
13 |
+
self.att_dim = att_dim
|
14 |
+
self.key_fun = nn.Linear(ip_dim, att_dim)
|
15 |
+
self.val_fun = nn.Linear(ip_dim, att_dim)
|
16 |
+
self.que_fun = nn.Linear(ip_dim, att_dim)
|
17 |
+
self.pro_fun = nn.Linear(att_dim, ip_dim)
|
18 |
+
|
19 |
+
def forward(self, x):
|
20 |
+
x = x.squeeze(2).permute(0,2,1)
|
21 |
+
|
22 |
+
kk = torch.matmul(x, self.key_fun.weight.T) + self.key_fun.bias.unsqueeze(0).unsqueeze(0)
|
23 |
+
qq = torch.matmul(x, self.que_fun.weight.T) + self.que_fun.bias.unsqueeze(0).unsqueeze(0)
|
24 |
+
vv = torch.matmul(x, self.val_fun.weight.T) + self.val_fun.bias.unsqueeze(0).unsqueeze(0)
|
25 |
+
|
26 |
+
kk_qq = torch.bmm(kk, qq.permute(0,2,1)) / (self.temperature*self.att_dim)
|
27 |
+
att_weights = F.softmax(kk_qq, 1) # each col of each attention matrix sums to 1
|
28 |
+
att = torch.bmm(vv.permute(0,2,1), att_weights)
|
29 |
+
|
30 |
+
op = torch.matmul(att.permute(0,2,1), self.pro_fun.weight.T) + self.pro_fun.bias.unsqueeze(0).unsqueeze(0)
|
31 |
+
op = op.permute(0,2,1).unsqueeze(2)
|
32 |
+
|
33 |
+
return op
|
34 |
+
|
35 |
+
|
36 |
+
class ConvBlockDownCoordF(nn.Module):
|
37 |
+
def __init__(self, in_chn, out_chn, ip_height, k_size=3, pad_size=1, stride=1):
|
38 |
+
super(ConvBlockDownCoordF, self).__init__()
|
39 |
+
self.coords = nn.Parameter(torch.linspace(-1, 1, ip_height)[None, None, ..., None], requires_grad=False)
|
40 |
+
self.conv = nn.Conv2d(in_chn+1, out_chn, kernel_size=k_size, padding=pad_size, stride=stride)
|
41 |
+
self.conv_bn = nn.BatchNorm2d(out_chn)
|
42 |
+
|
43 |
+
def forward(self, x):
|
44 |
+
freq_info = self.coords.repeat(x.shape[0],1,1,x.shape[3])
|
45 |
+
x = torch.cat((x, freq_info), 1)
|
46 |
+
x = F.max_pool2d(self.conv(x), 2, 2)
|
47 |
+
x = F.relu(self.conv_bn(x), inplace=True)
|
48 |
+
return x
|
49 |
+
|
50 |
+
|
51 |
+
class ConvBlockDownStandard(nn.Module):
|
52 |
+
def __init__(self, in_chn, out_chn, ip_height=None, k_size=3, pad_size=1, stride=1):
|
53 |
+
super(ConvBlockDownStandard, self).__init__()
|
54 |
+
self.conv = nn.Conv2d(in_chn, out_chn, kernel_size=k_size, padding=pad_size, stride=stride)
|
55 |
+
self.conv_bn = nn.BatchNorm2d(out_chn)
|
56 |
+
|
57 |
+
def forward(self, x):
|
58 |
+
x = F.max_pool2d(self.conv(x), 2, 2)
|
59 |
+
x = F.relu(self.conv_bn(x), inplace=True)
|
60 |
+
return x
|
61 |
+
|
62 |
+
|
63 |
+
class ConvBlockUpF(nn.Module):
|
64 |
+
def __init__(self, in_chn, out_chn, ip_height, k_size=3, pad_size=1, up_mode='bilinear', up_scale=(2,2)):
|
65 |
+
super(ConvBlockUpF, self).__init__()
|
66 |
+
self.up_scale = up_scale
|
67 |
+
self.up_mode = up_mode
|
68 |
+
self.coords = nn.Parameter(torch.linspace(-1, 1, ip_height*up_scale[0])[None, None, ..., None], requires_grad=False)
|
69 |
+
self.conv = nn.Conv2d(in_chn+1, out_chn, kernel_size=k_size, padding=pad_size)
|
70 |
+
self.conv_bn = nn.BatchNorm2d(out_chn)
|
71 |
+
|
72 |
+
def forward(self, x):
|
73 |
+
op = F.interpolate(x, size=(x.shape[-2]*self.up_scale[0], x.shape[-1]*self.up_scale[1]), mode=self.up_mode, align_corners=False)
|
74 |
+
freq_info = self.coords.repeat(op.shape[0],1,1,op.shape[3])
|
75 |
+
op = torch.cat((op, freq_info), 1)
|
76 |
+
op = self.conv(op)
|
77 |
+
op = F.relu(self.conv_bn(op), inplace=True)
|
78 |
+
return op
|
79 |
+
|
80 |
+
|
81 |
+
class ConvBlockUpStandard(nn.Module):
|
82 |
+
def __init__(self, in_chn, out_chn, ip_height=None, k_size=3, pad_size=1, up_mode='bilinear', up_scale=(2,2)):
|
83 |
+
super(ConvBlockUpStandard, self).__init__()
|
84 |
+
self.up_scale = up_scale
|
85 |
+
self.up_mode = up_mode
|
86 |
+
self.conv = nn.Conv2d(in_chn, out_chn, kernel_size=k_size, padding=pad_size)
|
87 |
+
self.conv_bn = nn.BatchNorm2d(out_chn)
|
88 |
+
|
89 |
+
def forward(self, x):
|
90 |
+
op = F.interpolate(x, size=(x.shape[-2]*self.up_scale[0], x.shape[-1]*self.up_scale[1]), mode=self.up_mode, align_corners=False)
|
91 |
+
op = self.conv(op)
|
92 |
+
op = F.relu(self.conv_bn(op), inplace=True)
|
93 |
+
return op
|
bat_detect/detector/models.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
from .model_helpers import *
|
6 |
+
|
7 |
+
import torchvision
|
8 |
+
|
9 |
+
import torch.fft
|
10 |
+
from torch import nn
|
11 |
+
|
12 |
+
|
13 |
+
class Net2DFast(nn.Module):
|
14 |
+
def __init__(self, num_filts, num_classes=0, emb_dim=0, ip_height=128, resize_factor=0.5):
|
15 |
+
super(Net2DFast, self).__init__()
|
16 |
+
self.num_classes = num_classes
|
17 |
+
self.emb_dim = emb_dim
|
18 |
+
self.num_filts = num_filts
|
19 |
+
self.resize_factor = resize_factor
|
20 |
+
self.ip_height_rs = ip_height
|
21 |
+
self.bneck_height = self.ip_height_rs//32
|
22 |
+
|
23 |
+
# encoder
|
24 |
+
self.conv_dn_0 = ConvBlockDownCoordF(1, num_filts//4, self.ip_height_rs, k_size=3, pad_size=1, stride=1)
|
25 |
+
self.conv_dn_1 = ConvBlockDownCoordF(num_filts//4, num_filts//2, self.ip_height_rs//2, k_size=3, pad_size=1, stride=1)
|
26 |
+
self.conv_dn_2 = ConvBlockDownCoordF(num_filts//2, num_filts, self.ip_height_rs//4, k_size=3, pad_size=1, stride=1)
|
27 |
+
self.conv_dn_3 = nn.Conv2d(num_filts, num_filts*2, 3, padding=1)
|
28 |
+
self.conv_dn_3_bn = nn.BatchNorm2d(num_filts*2)
|
29 |
+
|
30 |
+
# bottleneck
|
31 |
+
self.conv_1d = nn.Conv2d(num_filts*2, num_filts*2, (self.ip_height_rs//8,1), padding=0)
|
32 |
+
self.conv_1d_bn = nn.BatchNorm2d(num_filts*2)
|
33 |
+
self.att = SelfAttention(num_filts*2, num_filts*2)
|
34 |
+
|
35 |
+
# decoder
|
36 |
+
self.conv_up_2 = ConvBlockUpF(num_filts*2, num_filts//2, self.ip_height_rs//8)
|
37 |
+
self.conv_up_3 = ConvBlockUpF(num_filts//2, num_filts//4, self.ip_height_rs//4)
|
38 |
+
self.conv_up_4 = ConvBlockUpF(num_filts//4, num_filts//4, self.ip_height_rs//2)
|
39 |
+
|
40 |
+
# output
|
41 |
+
# +1 to include background class for class output
|
42 |
+
self.conv_op = nn.Conv2d(num_filts//4, num_filts//4, kernel_size=3, padding=1)
|
43 |
+
self.conv_op_bn = nn.BatchNorm2d(num_filts//4)
|
44 |
+
self.conv_size_op = nn.Conv2d(num_filts//4, 2, kernel_size=1, padding=0)
|
45 |
+
self.conv_classes_op = nn.Conv2d(num_filts//4, self.num_classes+1, kernel_size=1, padding=0)
|
46 |
+
|
47 |
+
if self.emb_dim > 0:
|
48 |
+
self.conv_emb = nn.Conv2d(num_filts, self.emb_dim, kernel_size=1, padding=0)
|
49 |
+
|
50 |
+
|
51 |
+
def forward(self, ip, return_feats=False):
|
52 |
+
|
53 |
+
# encoder
|
54 |
+
x1 = self.conv_dn_0(ip)
|
55 |
+
x2 = self.conv_dn_1(x1)
|
56 |
+
x3 = self.conv_dn_2(x2)
|
57 |
+
x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True)
|
58 |
+
|
59 |
+
# bottleneck
|
60 |
+
x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
|
61 |
+
x = self.att(x)
|
62 |
+
x = x.repeat([1,1,self.bneck_height*4,1])
|
63 |
+
|
64 |
+
# decoder
|
65 |
+
x = self.conv_up_2(x+x3)
|
66 |
+
x = self.conv_up_3(x+x2)
|
67 |
+
x = self.conv_up_4(x+x1)
|
68 |
+
|
69 |
+
# output
|
70 |
+
x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
|
71 |
+
cls = self.conv_classes_op(x)
|
72 |
+
comb = torch.softmax(cls, 1)
|
73 |
+
|
74 |
+
op = {}
|
75 |
+
op['pred_det'] = comb[:,:-1, :, :].sum(1).unsqueeze(1)
|
76 |
+
op['pred_size'] = F.relu(self.conv_size_op(x), inplace=True)
|
77 |
+
op['pred_class'] = comb
|
78 |
+
op['pred_class_un_norm'] = cls
|
79 |
+
if self.emb_dim > 0:
|
80 |
+
op['pred_emb'] = self.conv_emb(x)
|
81 |
+
if return_feats:
|
82 |
+
op['features'] = x
|
83 |
+
|
84 |
+
return op
|
85 |
+
|
86 |
+
|
87 |
+
class Net2DFastNoAttn(nn.Module):
|
88 |
+
def __init__(self, num_filts, num_classes=0, emb_dim=0, ip_height=128, resize_factor=0.5):
|
89 |
+
super(Net2DFastNoAttn, self).__init__()
|
90 |
+
|
91 |
+
self.num_classes = num_classes
|
92 |
+
self.emb_dim = emb_dim
|
93 |
+
self.num_filts = num_filts
|
94 |
+
self.resize_factor = resize_factor
|
95 |
+
self.ip_height_rs = ip_height
|
96 |
+
self.bneck_height = self.ip_height_rs//32
|
97 |
+
|
98 |
+
self.conv_dn_0 = ConvBlockDownCoordF(1, num_filts//4, self.ip_height_rs, k_size=3, pad_size=1, stride=1)
|
99 |
+
self.conv_dn_1 = ConvBlockDownCoordF(num_filts//4, num_filts//2, self.ip_height_rs//2, k_size=3, pad_size=1, stride=1)
|
100 |
+
self.conv_dn_2 = ConvBlockDownCoordF(num_filts//2, num_filts, self.ip_height_rs//4, k_size=3, pad_size=1, stride=1)
|
101 |
+
self.conv_dn_3 = nn.Conv2d(num_filts, num_filts*2, 3, padding=1)
|
102 |
+
self.conv_dn_3_bn = nn.BatchNorm2d(num_filts*2)
|
103 |
+
|
104 |
+
self.conv_1d = nn.Conv2d(num_filts*2, num_filts*2, (self.ip_height_rs//8,1), padding=0)
|
105 |
+
self.conv_1d_bn = nn.BatchNorm2d(num_filts*2)
|
106 |
+
|
107 |
+
|
108 |
+
self.conv_up_2 = ConvBlockUpF(num_filts*2, num_filts//2, self.ip_height_rs//8)
|
109 |
+
self.conv_up_3 = ConvBlockUpF(num_filts//2, num_filts//4, self.ip_height_rs//4)
|
110 |
+
self.conv_up_4 = ConvBlockUpF(num_filts//4, num_filts//4, self.ip_height_rs//2)
|
111 |
+
|
112 |
+
# output
|
113 |
+
# +1 to include background class for class output
|
114 |
+
self.conv_op = nn.Conv2d(num_filts//4, num_filts//4, kernel_size=3, padding=1)
|
115 |
+
self.conv_op_bn = nn.BatchNorm2d(num_filts//4)
|
116 |
+
self.conv_size_op = nn.Conv2d(num_filts//4, 2, kernel_size=1, padding=0)
|
117 |
+
self.conv_classes_op = nn.Conv2d(num_filts//4, self.num_classes+1, kernel_size=1, padding=0)
|
118 |
+
|
119 |
+
if self.emb_dim > 0:
|
120 |
+
self.conv_emb = nn.Conv2d(num_filts, self.emb_dim, kernel_size=1, padding=0)
|
121 |
+
|
122 |
+
def forward(self, ip, return_feats=False):
|
123 |
+
|
124 |
+
x1 = self.conv_dn_0(ip)
|
125 |
+
x2 = self.conv_dn_1(x1)
|
126 |
+
x3 = self.conv_dn_2(x2)
|
127 |
+
x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True)
|
128 |
+
|
129 |
+
x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
|
130 |
+
x = x.repeat([1,1,self.bneck_height*4,1])
|
131 |
+
|
132 |
+
x = self.conv_up_2(x+x3)
|
133 |
+
x = self.conv_up_3(x+x2)
|
134 |
+
x = self.conv_up_4(x+x1)
|
135 |
+
|
136 |
+
x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
|
137 |
+
cls = self.conv_classes_op(x)
|
138 |
+
comb = torch.softmax(cls, 1)
|
139 |
+
|
140 |
+
op = {}
|
141 |
+
op['pred_det'] = comb[:,:-1, :, :].sum(1).unsqueeze(1)
|
142 |
+
op['pred_size'] = F.relu(self.conv_size_op(x), inplace=True)
|
143 |
+
op['pred_class'] = comb
|
144 |
+
op['pred_class_un_norm'] = cls
|
145 |
+
if self.emb_dim > 0:
|
146 |
+
op['pred_emb'] = self.conv_emb(x)
|
147 |
+
if return_feats:
|
148 |
+
op['features'] = x
|
149 |
+
|
150 |
+
return op
|
151 |
+
|
152 |
+
|
153 |
+
class Net2DFastNoCoordConv(nn.Module):
|
154 |
+
def __init__(self, num_filts, num_classes=0, emb_dim=0, ip_height=128, resize_factor=0.5):
|
155 |
+
super(Net2DFastNoCoordConv, self).__init__()
|
156 |
+
|
157 |
+
self.num_classes = num_classes
|
158 |
+
self.emb_dim = emb_dim
|
159 |
+
self.num_filts = num_filts
|
160 |
+
self.resize_factor = resize_factor
|
161 |
+
self.ip_height_rs = ip_height
|
162 |
+
self.bneck_height = self.ip_height_rs//32
|
163 |
+
|
164 |
+
self.conv_dn_0 = ConvBlockDownStandard(1, num_filts//4, self.ip_height_rs, k_size=3, pad_size=1, stride=1)
|
165 |
+
self.conv_dn_1 = ConvBlockDownStandard(num_filts//4, num_filts//2, self.ip_height_rs//2, k_size=3, pad_size=1, stride=1)
|
166 |
+
self.conv_dn_2 = ConvBlockDownStandard(num_filts//2, num_filts, self.ip_height_rs//4, k_size=3, pad_size=1, stride=1)
|
167 |
+
self.conv_dn_3 = nn.Conv2d(num_filts, num_filts*2, 3, padding=1)
|
168 |
+
self.conv_dn_3_bn = nn.BatchNorm2d(num_filts*2)
|
169 |
+
|
170 |
+
self.conv_1d = nn.Conv2d(num_filts*2, num_filts*2, (self.ip_height_rs//8,1), padding=0)
|
171 |
+
self.conv_1d_bn = nn.BatchNorm2d(num_filts*2)
|
172 |
+
|
173 |
+
self.att = SelfAttention(num_filts*2, num_filts*2)
|
174 |
+
|
175 |
+
self.conv_up_2 = ConvBlockUpStandard(num_filts*2, num_filts//2, self.ip_height_rs//8)
|
176 |
+
self.conv_up_3 = ConvBlockUpStandard(num_filts//2, num_filts//4, self.ip_height_rs//4)
|
177 |
+
self.conv_up_4 = ConvBlockUpStandard(num_filts//4, num_filts//4, self.ip_height_rs//2)
|
178 |
+
|
179 |
+
# output
|
180 |
+
# +1 to include background class for class output
|
181 |
+
self.conv_op = nn.Conv2d(num_filts//4, num_filts//4, kernel_size=3, padding=1)
|
182 |
+
self.conv_op_bn = nn.BatchNorm2d(num_filts//4)
|
183 |
+
self.conv_size_op = nn.Conv2d(num_filts//4, 2, kernel_size=1, padding=0)
|
184 |
+
self.conv_classes_op = nn.Conv2d(num_filts//4, self.num_classes+1, kernel_size=1, padding=0)
|
185 |
+
|
186 |
+
if self.emb_dim > 0:
|
187 |
+
self.conv_emb = nn.Conv2d(num_filts, self.emb_dim, kernel_size=1, padding=0)
|
188 |
+
|
189 |
+
def forward(self, ip, return_feats=False):
|
190 |
+
|
191 |
+
x1 = self.conv_dn_0(ip)
|
192 |
+
x2 = self.conv_dn_1(x1)
|
193 |
+
x3 = self.conv_dn_2(x2)
|
194 |
+
x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True)
|
195 |
+
|
196 |
+
x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
|
197 |
+
x = self.att(x)
|
198 |
+
x = x.repeat([1,1,self.bneck_height*4,1])
|
199 |
+
|
200 |
+
x = self.conv_up_2(x+x3)
|
201 |
+
x = self.conv_up_3(x+x2)
|
202 |
+
x = self.conv_up_4(x+x1)
|
203 |
+
|
204 |
+
x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
|
205 |
+
cls = self.conv_classes_op(x)
|
206 |
+
comb = torch.softmax(cls, 1)
|
207 |
+
|
208 |
+
op = {}
|
209 |
+
op['pred_det'] = comb[:,:-1, :, :].sum(1).unsqueeze(1)
|
210 |
+
op['pred_size'] = F.relu(self.conv_size_op(x), inplace=True)
|
211 |
+
op['pred_class'] = comb
|
212 |
+
op['pred_class_un_norm'] = cls
|
213 |
+
if self.emb_dim > 0:
|
214 |
+
op['pred_emb'] = self.conv_emb(x)
|
215 |
+
if return_feats:
|
216 |
+
op['features'] = x
|
217 |
+
|
218 |
+
return op
|
bat_detect/detector/parameters.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import os
|
3 |
+
import datetime
|
4 |
+
|
5 |
+
|
6 |
+
def mk_dir(path):
|
7 |
+
if not os.path.isdir(path):
|
8 |
+
os.makedirs(path)
|
9 |
+
|
10 |
+
|
11 |
+
def get_params(make_dirs=False, exps_dir='../../experiments/'):
|
12 |
+
params = {}
|
13 |
+
|
14 |
+
params['model_name'] = 'Net2DFast' # Net2DFast, Net2DSkip, Net2DSimple, Net2DSkipDS, Net2DRN
|
15 |
+
params['num_filters'] = 128
|
16 |
+
|
17 |
+
now_str = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S")
|
18 |
+
model_name = now_str + '.pth.tar'
|
19 |
+
params['experiment'] = os.path.join(exps_dir, now_str, '')
|
20 |
+
params['model_file_name'] = os.path.join(params['experiment'], model_name)
|
21 |
+
params['op_im_dir'] = os.path.join(params['experiment'], 'op_ims', '')
|
22 |
+
params['op_im_dir_test'] = os.path.join(params['experiment'], 'op_ims_test', '')
|
23 |
+
#params['notes'] = '' # can save notes about an experiment here
|
24 |
+
|
25 |
+
|
26 |
+
# spec parameters
|
27 |
+
params['target_samp_rate'] = 256000 # resamples all audio so that it is at this rate
|
28 |
+
params['fft_win_length'] = 512 / 256000.0 # in milliseconds, amount of time per stft time step
|
29 |
+
params['fft_overlap'] = 0.75 # stft window overlap
|
30 |
+
|
31 |
+
params['max_freq'] = 120000 # in Hz, everything above this will be discarded
|
32 |
+
params['min_freq'] = 10000 # in Hz, everything below this will be discarded
|
33 |
+
|
34 |
+
params['resize_factor'] = 0.5 # resize so the spectrogram at the input of the network
|
35 |
+
params['spec_height'] = 256 # units are number of frequency bins (before resizing is performed)
|
36 |
+
params['spec_train_width'] = 512 # units are number of time steps (before resizing is performed)
|
37 |
+
params['spec_divide_factor'] = 32 # spectrogram should be divisible by this amount in width and height
|
38 |
+
|
39 |
+
# spec processing params
|
40 |
+
params['denoise_spec_avg'] = True # removes the mean for each frequency band
|
41 |
+
params['scale_raw_audio'] = False # scales the raw audio to [-1, 1]
|
42 |
+
params['max_scale_spec'] = False # scales the spectrogram so that it is max 1
|
43 |
+
params['spec_scale'] = 'pcen' # 'log', 'pcen', 'none'
|
44 |
+
|
45 |
+
# detection params
|
46 |
+
params['detection_overlap'] = 0.01 # has to be within this number of ms to count as detection
|
47 |
+
params['ignore_start_end'] = 0.01 # if start of GT calls are within this time from the start/end of file ignore
|
48 |
+
params['detection_threshold'] = 0.01 # the smaller this is the better the recall will be
|
49 |
+
params['nms_kernel_size'] = 9
|
50 |
+
params['nms_top_k_per_sec'] = 200 # keep top K highest predictions per second of audio
|
51 |
+
params['target_sigma'] = 2.0
|
52 |
+
|
53 |
+
# augmentation params
|
54 |
+
params['aug_prob'] = 0.20 # augmentations will be performed with this probability
|
55 |
+
params['augment_at_train'] = True
|
56 |
+
params['augment_at_train_combine'] = True
|
57 |
+
params['echo_max_delay'] = 0.005 # simulate echo by adding copy of raw audio
|
58 |
+
params['stretch_squeeze_delta'] = 0.04 # stretch or squeeze spec
|
59 |
+
params['mask_max_time_perc'] = 0.05 # max mask size - here percentage, not ideal
|
60 |
+
params['mask_max_freq_perc'] = 0.10 # max mask size - here percentage, not ideal
|
61 |
+
params['spec_amp_scaling'] = 2.0 # multiply the "volume" by 0:X times current amount
|
62 |
+
params['aug_sampling_rates'] = [220500, 256000, 300000, 312500, 384000, 441000, 500000]
|
63 |
+
|
64 |
+
# loss params
|
65 |
+
params['train_loss'] = 'focal' # mse or focal
|
66 |
+
params['det_loss_weight'] = 1.0 # weight for the detection part of the loss
|
67 |
+
params['size_loss_weight'] = 0.1 # weight for the bbox size loss
|
68 |
+
params['class_loss_weight'] = 2.0 # weight for the classification loss
|
69 |
+
params['individual_loss_weight'] = 0.0 # not used
|
70 |
+
if params['individual_loss_weight'] == 0.0:
|
71 |
+
params['emb_dim'] = 0 # number of dimensions used for individual id embedding
|
72 |
+
else:
|
73 |
+
params['emb_dim'] = 3
|
74 |
+
|
75 |
+
# train params
|
76 |
+
params['lr'] = 0.001
|
77 |
+
params['batch_size'] = 8
|
78 |
+
params['num_workers'] = 4
|
79 |
+
params['num_epochs'] = 200
|
80 |
+
params['num_eval_epochs'] = 5 # run evaluation every X epochs
|
81 |
+
params['device'] = 'cuda'
|
82 |
+
params['save_test_image_during_train'] = False
|
83 |
+
params['save_test_image_after_train'] = True
|
84 |
+
|
85 |
+
params['convert_to_genus'] = False
|
86 |
+
params['genus_mapping'] = []
|
87 |
+
params['class_names'] = []
|
88 |
+
params['classes_to_ignore'] = ['', ' ', 'Unknown', 'Not Bat']
|
89 |
+
params['generic_class'] = ['Bat']
|
90 |
+
params['events_of_interest'] = ['Echolocation'] # will ignore all other types of events e.g. social calls
|
91 |
+
|
92 |
+
# the classes in this list are standardized during training so that the same low and high freq are used
|
93 |
+
params['standardize_classs_names'] = []
|
94 |
+
|
95 |
+
# create directories
|
96 |
+
if make_dirs:
|
97 |
+
print('Model name : ' + params['model_name'])
|
98 |
+
print('Model file : ' + params['model_file_name'])
|
99 |
+
print('Experiment : ' + params['experiment'])
|
100 |
+
|
101 |
+
mk_dir(params['experiment'])
|
102 |
+
if params['save_test_image_during_train']:
|
103 |
+
mk_dir(params['op_im_dir'])
|
104 |
+
if params['save_test_image_after_train']:
|
105 |
+
mk_dir(params['op_im_dir_test'])
|
106 |
+
mk_dir(os.path.dirname(params['model_file_name']))
|
107 |
+
|
108 |
+
return params
|
bat_detect/detector/post_process.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
np.seterr(divide='ignore', invalid='ignore')
|
6 |
+
|
7 |
+
|
8 |
+
def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap):
|
9 |
+
nfft = int(fft_win_length*sampling_rate)
|
10 |
+
noverlap = int(fft_overlap*nfft)
|
11 |
+
return ((x_pos*(nfft - noverlap)) + noverlap) / sampling_rate
|
12 |
+
#return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window
|
13 |
+
|
14 |
+
|
15 |
+
def overall_class_pred(det_prob, class_prob):
|
16 |
+
weighted_pred = (class_prob*det_prob).sum(1)
|
17 |
+
return weighted_pred / weighted_pred.sum()
|
18 |
+
|
19 |
+
|
20 |
+
def run_nms(outputs, params, sampling_rate):
|
21 |
+
|
22 |
+
pred_det = outputs['pred_det'] # probability of box
|
23 |
+
pred_size = outputs['pred_size'] # box size
|
24 |
+
|
25 |
+
pred_det_nms = non_max_suppression(pred_det, params['nms_kernel_size'])
|
26 |
+
freq_rescale = (params['max_freq'] - params['min_freq']) /pred_det.shape[-2]
|
27 |
+
|
28 |
+
# NOTE there will be small differences depending on which sampling rate is chosen
|
29 |
+
# as we are choosing the same sampling rate for the entire batch
|
30 |
+
duration = x_coords_to_time(pred_det.shape[-1], sampling_rate[0].item(),
|
31 |
+
params['fft_win_length'], params['fft_overlap'])
|
32 |
+
top_k = int(duration * params['nms_top_k_per_sec'])
|
33 |
+
scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k)
|
34 |
+
|
35 |
+
# loop over batch to save outputs
|
36 |
+
preds = []
|
37 |
+
feats = []
|
38 |
+
for ii in range(pred_det_nms.shape[0]):
|
39 |
+
# get valid indices
|
40 |
+
inds_ord = torch.argsort(x_pos[ii, :])
|
41 |
+
valid_inds = scores[ii, inds_ord] > params['detection_threshold']
|
42 |
+
valid_inds = inds_ord[valid_inds]
|
43 |
+
|
44 |
+
# create result dictionary
|
45 |
+
pred = {}
|
46 |
+
pred['det_probs'] = scores[ii, valid_inds]
|
47 |
+
pred['x_pos'] = x_pos[ii, valid_inds]
|
48 |
+
pred['y_pos'] = y_pos[ii, valid_inds]
|
49 |
+
pred['bb_width'] = pred_size[ii, 0, pred['y_pos'], pred['x_pos']]
|
50 |
+
pred['bb_height'] = pred_size[ii, 1, pred['y_pos'], pred['x_pos']]
|
51 |
+
pred['start_times'] = x_coords_to_time(pred['x_pos'].float() / params['resize_factor'],
|
52 |
+
sampling_rate[ii].item(), params['fft_win_length'], params['fft_overlap'])
|
53 |
+
pred['end_times'] = x_coords_to_time((pred['x_pos'].float()+pred['bb_width']) / params['resize_factor'],
|
54 |
+
sampling_rate[ii].item(), params['fft_win_length'], params['fft_overlap'])
|
55 |
+
pred['low_freqs'] = (pred_size[ii].shape[1] - pred['y_pos'].float())*freq_rescale + params['min_freq']
|
56 |
+
pred['high_freqs'] = pred['low_freqs'] + pred['bb_height']*freq_rescale
|
57 |
+
|
58 |
+
# extract the per class votes
|
59 |
+
if 'pred_class' in outputs:
|
60 |
+
pred['class_probs'] = outputs['pred_class'][ii, :, y_pos[ii, valid_inds], x_pos[ii, valid_inds]]
|
61 |
+
|
62 |
+
# extract the model features
|
63 |
+
if 'features' in outputs:
|
64 |
+
feat = outputs['features'][ii, :, y_pos[ii, valid_inds], x_pos[ii, valid_inds]].transpose(0, 1)
|
65 |
+
feat = feat.cpu().numpy().astype(np.float32)
|
66 |
+
feats.append(feat)
|
67 |
+
|
68 |
+
# convert to numpy
|
69 |
+
for kk in pred.keys():
|
70 |
+
pred[kk] = pred[kk].cpu().numpy().astype(np.float32)
|
71 |
+
preds.append(pred)
|
72 |
+
|
73 |
+
return preds, feats
|
74 |
+
|
75 |
+
|
76 |
+
def non_max_suppression(heat, kernel_size):
|
77 |
+
# kernel can be an int or list/tuple
|
78 |
+
if type(kernel_size) is int:
|
79 |
+
kernel_size_h = kernel_size
|
80 |
+
kernel_size_w = kernel_size
|
81 |
+
|
82 |
+
pad_h = (kernel_size_h - 1) // 2
|
83 |
+
pad_w = (kernel_size_w - 1) // 2
|
84 |
+
|
85 |
+
hmax = nn.functional.max_pool2d(heat, (kernel_size_h, kernel_size_w), stride=1, padding=(pad_h, pad_w))
|
86 |
+
keep = (hmax == heat).float()
|
87 |
+
|
88 |
+
return heat * keep
|
89 |
+
|
90 |
+
|
91 |
+
def get_topk_scores(scores, K):
|
92 |
+
# expects input of size: batch x 1 x height x width
|
93 |
+
batch, _, height, width = scores.size()
|
94 |
+
|
95 |
+
topk_scores, topk_inds = torch.topk(scores.view(batch, -1), K)
|
96 |
+
topk_inds = topk_inds % (height * width)
|
97 |
+
topk_ys = torch.div(topk_inds, width, rounding_mode='floor').long()
|
98 |
+
topk_xs = (topk_inds % width).long()
|
99 |
+
|
100 |
+
return topk_scores, topk_ys, topk_xs
|
bat_detect/evaluate/evaluate_models.py
ADDED
@@ -0,0 +1,565 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Evaluates trained model on test set and generates plots.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import sys
|
7 |
+
import os
|
8 |
+
import copy
|
9 |
+
import json
|
10 |
+
import pandas as pd
|
11 |
+
from sklearn.ensemble import RandomForestClassifier
|
12 |
+
import argparse
|
13 |
+
|
14 |
+
sys.path.append('../../')
|
15 |
+
import bat_detect.utils.detector_utils as du
|
16 |
+
import bat_detect.train.train_utils as tu
|
17 |
+
import bat_detect.detector.parameters as parameters
|
18 |
+
import bat_detect.train.evaluate as evl
|
19 |
+
import bat_detect.utils.plot_utils as pu
|
20 |
+
|
21 |
+
|
22 |
+
def get_blank_annotation(ip_str):
|
23 |
+
|
24 |
+
res = {}
|
25 |
+
res['class_name'] = ''
|
26 |
+
res['duration'] = -1
|
27 |
+
res['id'] = ''# fileName
|
28 |
+
res['issues'] = False
|
29 |
+
res['notes'] = ip_str
|
30 |
+
res['time_exp'] = 1
|
31 |
+
res['annotated'] = False
|
32 |
+
res['annotation'] = []
|
33 |
+
|
34 |
+
ann = {}
|
35 |
+
ann['class'] = ''
|
36 |
+
ann['event'] = 'Echolocation'
|
37 |
+
ann['individual'] = -1
|
38 |
+
ann['start_time'] = -1
|
39 |
+
ann['end_time'] = -1
|
40 |
+
ann['low_freq'] = -1
|
41 |
+
ann['high_freq'] = -1
|
42 |
+
ann['confidence'] = -1
|
43 |
+
|
44 |
+
return copy.deepcopy(res), copy.deepcopy(ann)
|
45 |
+
|
46 |
+
|
47 |
+
def create_genus_mapping(gt_test, preds, class_names):
|
48 |
+
# rolls the per class predictions and ground truth back up to genus level
|
49 |
+
class_names_genus, cls_to_genus = np.unique([cc.split(' ')[0] for cc in class_names], return_inverse=True)
|
50 |
+
genus_to_cls_map = [np.where(np.array(cls_to_genus) == cc)[0] for cc in range(len(class_names_genus))]
|
51 |
+
|
52 |
+
gt_test_g = []
|
53 |
+
for gg in gt_test:
|
54 |
+
gg_g = copy.deepcopy(gg)
|
55 |
+
inds = np.where(gg_g['class_ids']!=-1)[0]
|
56 |
+
gg_g['class_ids'][inds] = cls_to_genus[gg_g['class_ids'][inds]]
|
57 |
+
gt_test_g.append(gg_g)
|
58 |
+
|
59 |
+
# note, will have entries geater than one as we are summing across the respective classes
|
60 |
+
preds_g = []
|
61 |
+
for pp in preds:
|
62 |
+
pp_g = copy.deepcopy(pp)
|
63 |
+
pp_g['class_probs'] = np.zeros((len(class_names_genus), pp_g['class_probs'].shape[1]), dtype=np.float32)
|
64 |
+
for cc, inds in enumerate(genus_to_cls_map):
|
65 |
+
pp_g['class_probs'][cc, :] = pp['class_probs'][inds, :].sum(0)
|
66 |
+
preds_g.append(pp_g)
|
67 |
+
|
68 |
+
return class_names_genus, preds_g, gt_test_g
|
69 |
+
|
70 |
+
|
71 |
+
def load_tadarida_pred(ip_dir, dataset, file_of_interest):
|
72 |
+
|
73 |
+
res, ann = get_blank_annotation('Generated by Tadarida')
|
74 |
+
|
75 |
+
# create the annotations in the correct format
|
76 |
+
da_c = pd.read_csv(ip_dir + dataset + '/' + file_of_interest.replace('.wav', '.ta').replace('.WAV', '.ta'), sep='\t')
|
77 |
+
|
78 |
+
res_c = copy.deepcopy(res)
|
79 |
+
res_c['id'] = file_of_interest
|
80 |
+
res_c['dataset'] = dataset
|
81 |
+
res_c['feats'] = da_c.iloc[:, 6:].values.astype(np.float32)
|
82 |
+
|
83 |
+
if da_c.shape[0] > 0:
|
84 |
+
res_c['class_name'] = ''
|
85 |
+
res_c['class_prob'] = 0.0
|
86 |
+
|
87 |
+
for aa in range(da_c.shape[0]):
|
88 |
+
ann_c = copy.deepcopy(ann)
|
89 |
+
ann_c['class'] = 'Not Bat' # will assign to class later
|
90 |
+
ann_c['start_time'] = np.round(da_c.iloc[aa]['StTime']/1000.0 ,5)
|
91 |
+
ann_c['end_time'] = np.round((da_c.iloc[aa]['StTime'] + da_c.iloc[aa]['Dur'])/1000.0, 5)
|
92 |
+
ann_c['low_freq'] = np.round(da_c.iloc[aa]['Fmin'] * 1000.0, 2)
|
93 |
+
ann_c['high_freq'] = np.round(da_c.iloc[aa]['Fmax'] * 1000.0, 2)
|
94 |
+
ann_c['det_prob'] = 0.0
|
95 |
+
res_c['annotation'].append(ann_c)
|
96 |
+
|
97 |
+
return res_c
|
98 |
+
|
99 |
+
|
100 |
+
def load_sonobat_meta(ip_dir, datasets, region_classifier, class_names, only_accepted_species=True):
|
101 |
+
|
102 |
+
sp_dict = {}
|
103 |
+
for ss in class_names:
|
104 |
+
sp_key = ss.split(' ')[0][:3] + ss.split(' ')[1][:3]
|
105 |
+
sp_dict[sp_key] = ss
|
106 |
+
|
107 |
+
sp_dict['x'] = '' # not bat
|
108 |
+
sp_dict['Bat'] = 'Bat'
|
109 |
+
|
110 |
+
sonobat_meta = {}
|
111 |
+
for tt in datasets:
|
112 |
+
dataset = tt['dataset_name']
|
113 |
+
sb_ip_dir = ip_dir + dataset + '/' + region_classifier + '/'
|
114 |
+
|
115 |
+
# load the call level predictions
|
116 |
+
ip_file_p = sb_ip_dir + dataset + '_Parameters_v4.5.0.txt'
|
117 |
+
#ip_file_p = sb_ip_dir + 'audio_SonoBatch_v30.0 beta.txt'
|
118 |
+
da = pd.read_csv(ip_file_p, sep='\t')
|
119 |
+
|
120 |
+
# load the file level predictions
|
121 |
+
ip_file_b = sb_ip_dir + dataset + '_SonoBatch_v4.5.0.txt'
|
122 |
+
#ip_file_b = sb_ip_dir + 'audio_CumulativeParameters_v30.0 beta.txt'
|
123 |
+
|
124 |
+
with open(ip_file_b) as f:
|
125 |
+
lines = f.readlines()
|
126 |
+
lines = [x.strip() for x in lines]
|
127 |
+
del lines[0]
|
128 |
+
|
129 |
+
file_res = {}
|
130 |
+
for ll in lines:
|
131 |
+
# note this does not seem to parse the file very well
|
132 |
+
ll_data = ll.split('\t')
|
133 |
+
|
134 |
+
# there are sometimes many different species names per file
|
135 |
+
if only_accepted_species:
|
136 |
+
# only choosing "SppAccp"
|
137 |
+
ind = 4
|
138 |
+
else:
|
139 |
+
# choosing ""~Spp" if "SppAccp" does not exist
|
140 |
+
if ll_data[4] != 'x':
|
141 |
+
ind = 4 # choosing "SppAccp", along with "Prob" here
|
142 |
+
else:
|
143 |
+
ind = 8 # choosing "~Spp", along with "~Prob" here
|
144 |
+
|
145 |
+
sp_name_1 = sp_dict[ll_data[ind]]
|
146 |
+
prob_1 = ll_data[ind+1]
|
147 |
+
if prob_1 == 'x':
|
148 |
+
prob_1 = 0.0
|
149 |
+
file_res[ll_data[1]] = {'id':ll_data[1], 'species_1':sp_name_1, 'prob_1':prob_1}
|
150 |
+
|
151 |
+
sonobat_meta[dataset] = {}
|
152 |
+
sonobat_meta[dataset]['file_res'] = file_res
|
153 |
+
sonobat_meta[dataset]['call_info'] = da
|
154 |
+
|
155 |
+
return sonobat_meta
|
156 |
+
|
157 |
+
|
158 |
+
def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None):
|
159 |
+
|
160 |
+
# create the annotations in the correct format
|
161 |
+
res, ann = get_blank_annotation('Generated by Sonobat')
|
162 |
+
res_c = copy.deepcopy(res)
|
163 |
+
res_c['id'] = id
|
164 |
+
res_c['dataset'] = dataset
|
165 |
+
|
166 |
+
da = sb_meta[dataset]['call_info']
|
167 |
+
da_c = da[da['Filename'] == id]
|
168 |
+
|
169 |
+
file_res = sb_meta[dataset]['file_res']
|
170 |
+
res_c['feats'] = np.zeros((0,0))
|
171 |
+
|
172 |
+
if da_c.shape[0] > 0:
|
173 |
+
res_c['class_name'] = file_res[id]['species_1']
|
174 |
+
res_c['class_prob'] = file_res[id]['prob_1']
|
175 |
+
res_c['feats'] = da_c.iloc[:, 3:105].values.astype(np.float32)
|
176 |
+
|
177 |
+
for aa in range(da_c.shape[0]):
|
178 |
+
ann_c = copy.deepcopy(ann)
|
179 |
+
if set_class_name is None:
|
180 |
+
ann_c['class'] = file_res[id]['species_1']
|
181 |
+
else:
|
182 |
+
ann_c['class'] = set_class_name
|
183 |
+
ann_c['start_time'] = np.round(da_c.iloc[aa]['TimeInFile'] / 1000.0 ,5)
|
184 |
+
ann_c['end_time'] = np.round(ann_c['start_time'] + da_c.iloc[aa]['CallDuration']/1000.0, 5)
|
185 |
+
ann_c['low_freq'] = np.round(da_c.iloc[aa]['LowFreq'] * 1000.0, 2)
|
186 |
+
ann_c['high_freq'] = np.round(da_c.iloc[aa]['HiFreq'] * 1000.0, 2)
|
187 |
+
ann_c['det_prob'] = np.round(da_c.iloc[aa]['Quality'], 3)
|
188 |
+
res_c['annotation'].append(ann_c)
|
189 |
+
|
190 |
+
return res_c
|
191 |
+
|
192 |
+
|
193 |
+
def bb_overlap(bb_g_in, bb_p_in):
|
194 |
+
|
195 |
+
freq_scale = 10000000.0 # ensure that both axis are roughly the same range
|
196 |
+
bb_g = [bb_g_in['start_time'], bb_g_in['low_freq']/freq_scale, bb_g_in['end_time'], bb_g_in['high_freq']/freq_scale]
|
197 |
+
bb_p = [bb_p_in['start_time'], bb_p_in['low_freq']/freq_scale, bb_p_in['end_time'], bb_p_in['high_freq']/freq_scale]
|
198 |
+
|
199 |
+
xA = max(bb_g[0], bb_p[0])
|
200 |
+
yA = max(bb_g[1], bb_p[1])
|
201 |
+
xB = min(bb_g[2], bb_p[2])
|
202 |
+
yB = min(bb_g[3], bb_p[3])
|
203 |
+
|
204 |
+
# compute the area of intersection rectangle
|
205 |
+
inter_area = abs(max((xB - xA, 0.0)) * max((yB - yA), 0.0))
|
206 |
+
|
207 |
+
if inter_area == 0:
|
208 |
+
iou = 0.0
|
209 |
+
|
210 |
+
else:
|
211 |
+
# compute the area of both
|
212 |
+
bb_area_g = abs((bb_g[2] - bb_g[0]) * (bb_g[3] - bb_g[1]))
|
213 |
+
bb_area_p = abs((bb_p[2] - bb_p[0]) * (bb_p[3] - bb_p[1]))
|
214 |
+
|
215 |
+
iou = inter_area / float(bb_area_g + bb_area_p - inter_area)
|
216 |
+
|
217 |
+
return iou
|
218 |
+
|
219 |
+
|
220 |
+
def assign_to_gt(gt, pred, iou_thresh):
|
221 |
+
# this will edit pred in place
|
222 |
+
|
223 |
+
num_preds = len(pred['annotation'])
|
224 |
+
num_gts = len(gt['annotation'])
|
225 |
+
if num_preds > 0 and num_gts > 0:
|
226 |
+
iou_m = np.zeros((num_preds, num_gts))
|
227 |
+
for ii in range(num_preds):
|
228 |
+
for jj in range(num_gts):
|
229 |
+
iou_m[ii, jj] = bb_overlap(gt['annotation'][jj], pred['annotation'][ii])
|
230 |
+
|
231 |
+
# greedily assign detections to ground truths
|
232 |
+
# needs to be greater than some threshold and we cannot assign GT
|
233 |
+
# to more than one detection
|
234 |
+
# TODO could try to do an optimal assignment
|
235 |
+
for jj in range(num_gts):
|
236 |
+
max_iou = np.argmax(iou_m[:, jj])
|
237 |
+
if iou_m[max_iou, jj] > iou_thresh:
|
238 |
+
pred['annotation'][max_iou]['class'] = gt['annotation'][jj]['class']
|
239 |
+
iou_m[max_iou, :] = -1.0
|
240 |
+
|
241 |
+
return pred
|
242 |
+
|
243 |
+
|
244 |
+
def parse_data(data, class_names, non_event_classes, is_pred=False):
|
245 |
+
class_names_all = class_names + non_event_classes
|
246 |
+
|
247 |
+
data['class_names'] = np.array([aa['class'] for aa in data['annotation']])
|
248 |
+
data['start_times'] = np.array([aa['start_time'] for aa in data['annotation']])
|
249 |
+
data['end_times'] = np.array([aa['end_time'] for aa in data['annotation']])
|
250 |
+
data['high_freqs'] = np.array([float(aa['high_freq']) for aa in data['annotation']])
|
251 |
+
data['low_freqs'] = np.array([float(aa['low_freq']) for aa in data['annotation']])
|
252 |
+
|
253 |
+
if is_pred:
|
254 |
+
# when loading predictions
|
255 |
+
data['det_probs'] = np.array([float(aa['det_prob']) for aa in data['annotation']])
|
256 |
+
data['class_probs'] = np.zeros((len(class_names)+1, len(data['annotation'])))
|
257 |
+
data['class_ids'] = np.array([class_names_all.index(aa['class']) for aa in data['annotation']]).astype(np.int32)
|
258 |
+
else:
|
259 |
+
# when loading ground truth
|
260 |
+
# if the class label is not in the set of interest then set to -1
|
261 |
+
labels = []
|
262 |
+
for aa in data['annotation']:
|
263 |
+
if aa['class'] in class_names:
|
264 |
+
labels.append(class_names_all.index(aa['class']))
|
265 |
+
else:
|
266 |
+
labels.append(-1)
|
267 |
+
data['class_ids'] = np.array(labels).astype(np.int32)
|
268 |
+
|
269 |
+
return data
|
270 |
+
|
271 |
+
|
272 |
+
def load_gt_data(datasets, events_of_interest, class_names, classes_to_ignore):
|
273 |
+
gt_data = []
|
274 |
+
for dd in datasets:
|
275 |
+
print('\n' + dd['dataset_name'])
|
276 |
+
gt_dataset = tu.load_set_of_anns([dd], events_of_interest=events_of_interest, verbose=True)
|
277 |
+
gt_dataset = [parse_data(gg, class_names, classes_to_ignore, False) for gg in gt_dataset]
|
278 |
+
|
279 |
+
for gt in gt_dataset:
|
280 |
+
gt['dataset_name'] = dd['dataset_name']
|
281 |
+
|
282 |
+
gt_data.extend(gt_dataset)
|
283 |
+
|
284 |
+
return gt_data
|
285 |
+
|
286 |
+
|
287 |
+
def train_rf_model(x_train, y_train, num_classes, seed=2001):
|
288 |
+
# TODO search for the best hyper parameters on val set
|
289 |
+
# Currently only training on the species and 'not bat' - exclude 'generic_class' which is last
|
290 |
+
# alternative would be to first have a "bat" vs "not bat" classifier, and then a species classifier?
|
291 |
+
|
292 |
+
x_train = np.vstack(x_train)
|
293 |
+
y_train = np.hstack(y_train)
|
294 |
+
|
295 |
+
inds = np.where(y_train < num_classes)[0]
|
296 |
+
x_train = x_train[inds, :]
|
297 |
+
y_train = y_train[inds]
|
298 |
+
un_train_class = np.unique(y_train)
|
299 |
+
|
300 |
+
clf = RandomForestClassifier(random_state=seed, n_jobs=-1)
|
301 |
+
clf.fit(x_train, y_train)
|
302 |
+
y_pred = clf.predict(x_train)
|
303 |
+
tr_acc = (y_pred==y_train).mean()
|
304 |
+
#print('Train acc', round(tr_acc*100, 2))
|
305 |
+
return clf, un_train_class
|
306 |
+
|
307 |
+
|
308 |
+
def eval_rf_model(clf, pred, un_train_class, num_classes):
|
309 |
+
# stores the prediction in place
|
310 |
+
if pred['feats'].shape[0] > 0:
|
311 |
+
pred['class_probs'] = np.zeros((num_classes, pred['feats'].shape[0]))
|
312 |
+
pred['class_probs'][un_train_class, :] = clf.predict_proba(pred['feats']).T
|
313 |
+
pred['det_probs'] = pred['class_probs'][:-1, :].sum(0)
|
314 |
+
else:
|
315 |
+
pred['class_probs'] = np.zeros((num_classes, 0))
|
316 |
+
pred['det_probs'] = np.zeros(0)
|
317 |
+
return pred
|
318 |
+
|
319 |
+
|
320 |
+
def save_summary_to_json(op_dir, mod_name, results):
|
321 |
+
op = {}
|
322 |
+
op['avg_prec'] = round(results['avg_prec'], 3)
|
323 |
+
op['avg_prec_class'] = round(results['avg_prec_class'], 3)
|
324 |
+
op['top_class'] = round(results['top_class']['avg_prec'], 3)
|
325 |
+
op['file_acc'] = round(results['file_acc'], 3)
|
326 |
+
op['model'] = mod_name
|
327 |
+
|
328 |
+
op['per_class'] = {}
|
329 |
+
for cc in results['class_pr']:
|
330 |
+
op['per_class'][cc['name']] = cc['avg_prec']
|
331 |
+
|
332 |
+
op_file_name = os.path.join(op_dir, mod_name + '_results.json')
|
333 |
+
with open(op_file_name, 'w') as da:
|
334 |
+
json.dump(op, da, indent=2)
|
335 |
+
|
336 |
+
|
337 |
+
def print_results(model_name, mod_str, results, op_dir, class_names, file_type, title_text=''):
|
338 |
+
print('\nResults - ' + model_name)
|
339 |
+
print('avg_prec ', round(results['avg_prec'], 3))
|
340 |
+
print('avg_prec_class', round(results['avg_prec_class'], 3))
|
341 |
+
print('top_class ', round(results['top_class']['avg_prec'], 3))
|
342 |
+
print('file_acc ', round(results['file_acc'], 3))
|
343 |
+
|
344 |
+
print('\nSaving ' + model_name + ' results to: ' + op_dir)
|
345 |
+
save_summary_to_json(op_dir, mod_str, results)
|
346 |
+
|
347 |
+
pu.plot_pr_curve(op_dir, mod_str+'_test_all_det', mod_str+'_test_all_det', results, file_type, title_text + 'Detection PR')
|
348 |
+
pu.plot_pr_curve(op_dir, mod_str+'_test_all_top_class', mod_str+'_test_all_top_class', results['top_class'], file_type, title_text + 'Top Class')
|
349 |
+
pu.plot_pr_curve_class(op_dir, mod_str+'_test_all_class', mod_str+'_test_all_class', results, file_type, title_text + 'Per-Class PR')
|
350 |
+
pu.plot_confusion_matrix(op_dir, mod_str+'_confusion', results['gt_valid_file'], results['pred_valid_file'],
|
351 |
+
results['file_acc'], class_names, True, file_type, title_text + 'Confusion Matrix')
|
352 |
+
|
353 |
+
|
354 |
+
def add_root_path_back(data_sets, ann_path, wav_path):
|
355 |
+
for dd in data_sets:
|
356 |
+
dd['ann_path'] = os.path.join(ann_path, dd['ann_path'])
|
357 |
+
dd['wav_path'] = os.path.join(wav_path, dd['wav_path'])
|
358 |
+
return data_sets
|
359 |
+
|
360 |
+
|
361 |
+
def check_classes_in_train(gt_list, class_names):
|
362 |
+
num_gt_total = np.sum([gg['start_times'].shape[0] for gg in gt_list])
|
363 |
+
num_with_no_class = 0
|
364 |
+
for gt in gt_list:
|
365 |
+
for cc in gt['class_names']:
|
366 |
+
if cc not in class_names:
|
367 |
+
num_with_no_class += 1
|
368 |
+
return num_with_no_class
|
369 |
+
|
370 |
+
|
371 |
+
if __name__ == "__main__":
|
372 |
+
|
373 |
+
parser = argparse.ArgumentParser()
|
374 |
+
parser.add_argument('op_dir', type=str, default='plots/results_compare/',
|
375 |
+
help='Output directory for plots')
|
376 |
+
parser.add_argument('data_dir', type=str,
|
377 |
+
help='Path to root of datasets')
|
378 |
+
parser.add_argument('ann_dir', type=str,
|
379 |
+
help='Path to extracted annotations')
|
380 |
+
parser.add_argument('bd_model_path', type=str,
|
381 |
+
help='Path to BatDetect model')
|
382 |
+
parser.add_argument('--test_file', type=str, default='',
|
383 |
+
help='Path to json file used for evaluation.')
|
384 |
+
parser.add_argument('--sb_ip_dir', type=str, default='',
|
385 |
+
help='Path to sonobat predictions')
|
386 |
+
parser.add_argument('--sb_region_classifier', type=str, default='south',
|
387 |
+
help='Path to sonobat predictions')
|
388 |
+
parser.add_argument('--td_ip_dir', type=str, default='',
|
389 |
+
help='Path to tadarida_D predictions')
|
390 |
+
parser.add_argument('--iou_thresh', type=float, default=0.01,
|
391 |
+
help='IOU threshold for assigning predictions to ground truth')
|
392 |
+
parser.add_argument('--file_type', type=str, default='png',
|
393 |
+
help='Type of image to save - png or pdf')
|
394 |
+
parser.add_argument('--title_text', type=str, default='',
|
395 |
+
help='Text to add as title of plots')
|
396 |
+
parser.add_argument('--rand_seed', type=int, default=2001,
|
397 |
+
help='Random seed')
|
398 |
+
args = vars(parser.parse_args())
|
399 |
+
|
400 |
+
np.random.seed(args['rand_seed'])
|
401 |
+
|
402 |
+
if not os.path.isdir(args['op_dir']):
|
403 |
+
os.makedirs(args['op_dir'])
|
404 |
+
|
405 |
+
|
406 |
+
# load the model
|
407 |
+
params_eval = parameters.get_params(False)
|
408 |
+
_, params_bd = du.load_model(args['bd_model_path'])
|
409 |
+
|
410 |
+
class_names = params_bd['class_names']
|
411 |
+
num_classes = len(class_names) + 1 # num classes plus background class
|
412 |
+
|
413 |
+
classes_to_ignore = ['Not Bat', 'Bat', 'Unknown']
|
414 |
+
events_of_interest = ['Echolocation']
|
415 |
+
|
416 |
+
# load test data
|
417 |
+
if args['test_file'] == '':
|
418 |
+
# load the test files of interest from the trained model
|
419 |
+
test_sets = add_root_path_back(params_bd['test_sets'], args['ann_dir'], args['data_dir'])
|
420 |
+
test_sets = [dd for dd in test_sets if not dd['is_binary']] # exclude bat/not datasets
|
421 |
+
else:
|
422 |
+
# user specified annotation file to evaluate
|
423 |
+
test_dict = {}
|
424 |
+
test_dict['dataset_name'] = args['test_file'].replace('.json', '')
|
425 |
+
test_dict['is_test'] = True
|
426 |
+
test_dict['is_binary'] = True
|
427 |
+
test_dict['ann_path'] = os.path.join(args['ann_dir'], args['test_file'])
|
428 |
+
test_dict['wav_path'] = args['data_dir']
|
429 |
+
test_sets = [test_dict]
|
430 |
+
|
431 |
+
# load the gt for the test set
|
432 |
+
gt_test = load_gt_data(test_sets, events_of_interest, class_names, classes_to_ignore)
|
433 |
+
total_num_calls = np.sum([gg['start_times'].shape[0] for gg in gt_test])
|
434 |
+
print('\nTotal number of test files:', len(gt_test))
|
435 |
+
print('Total number of test calls:', np.sum([gg['start_times'].shape[0] for gg in gt_test]))
|
436 |
+
|
437 |
+
# check if test contains classes not in the train set
|
438 |
+
num_with_no_class = check_classes_in_train(gt_test, class_names)
|
439 |
+
if total_num_calls == num_with_no_class:
|
440 |
+
print('Classes from the test set are not in the train set.')
|
441 |
+
assert False
|
442 |
+
|
443 |
+
# only need the train data if evaluating Sonobat or Tadarida
|
444 |
+
if args['sb_ip_dir'] != '' or args['td_ip_dir'] != '':
|
445 |
+
train_sets = add_root_path_back(params_bd['train_sets'], args['ann_dir'], args['data_dir'])
|
446 |
+
train_sets = [dd for dd in train_sets if not dd['is_binary']] # exclude bat/not datasets
|
447 |
+
gt_train = load_gt_data(train_sets, events_of_interest, class_names, classes_to_ignore)
|
448 |
+
|
449 |
+
|
450 |
+
#
|
451 |
+
# evaluate Sonobat by training random forest classifier
|
452 |
+
#
|
453 |
+
# NOTE: Sonobat may only make predictions for a subset of the files
|
454 |
+
#
|
455 |
+
if args['sb_ip_dir'] != '':
|
456 |
+
sb_meta = load_sonobat_meta(args['sb_ip_dir'], train_sets + test_sets, args['sb_region_classifier'], class_names)
|
457 |
+
|
458 |
+
preds_sb = []
|
459 |
+
keep_inds_sb = []
|
460 |
+
for ii, gt in enumerate(gt_test):
|
461 |
+
sb_pred = load_sonobat_preds(gt['dataset_name'], gt['id'], sb_meta)
|
462 |
+
if sb_pred['class_name'] != '':
|
463 |
+
sb_pred = parse_data(sb_pred, class_names, classes_to_ignore, True)
|
464 |
+
sb_pred['class_probs'][sb_pred['class_ids'], np.arange(sb_pred['class_probs'].shape[1])] = sb_pred['det_probs']
|
465 |
+
preds_sb.append(sb_pred)
|
466 |
+
keep_inds_sb.append(ii)
|
467 |
+
|
468 |
+
results_sb = evl.evaluate_predictions([gt_test[ii] for ii in keep_inds_sb], preds_sb, class_names,
|
469 |
+
params_eval['detection_overlap'], params_eval['ignore_start_end'])
|
470 |
+
print_results('Sonobat', 'sb', results_sb, args['op_dir'], class_names,
|
471 |
+
args['file_type'], args['title_text'] + ' - Species - ')
|
472 |
+
print('Only reporting results for', len(keep_inds_sb), 'files, out of', len(gt_test))
|
473 |
+
|
474 |
+
|
475 |
+
# train our own random forest on sonobat features
|
476 |
+
x_train = []
|
477 |
+
y_train = []
|
478 |
+
for gt in gt_train:
|
479 |
+
pred = load_sonobat_preds(gt['dataset_name'], gt['id'], sb_meta, 'Not Bat')
|
480 |
+
|
481 |
+
if len(pred['annotation']) > 0:
|
482 |
+
# compute detection overlap with ground truth to determine which are the TP detections
|
483 |
+
assign_to_gt(gt, pred, args['iou_thresh'])
|
484 |
+
pred = parse_data(pred, class_names, classes_to_ignore, True)
|
485 |
+
x_train.append(pred['feats'])
|
486 |
+
y_train.append(pred['class_ids'])
|
487 |
+
|
488 |
+
# train random forest on tadarida predictions
|
489 |
+
clf_sb, un_train_class = train_rf_model(x_train, y_train, num_classes, args['rand_seed'])
|
490 |
+
|
491 |
+
# run the model on the test set
|
492 |
+
preds_sb_rf = []
|
493 |
+
for gt in gt_test:
|
494 |
+
pred = load_sonobat_preds(gt['dataset_name'], gt['id'], sb_meta, 'Not Bat')
|
495 |
+
pred = parse_data(pred, class_names, classes_to_ignore, True)
|
496 |
+
pred = eval_rf_model(clf_sb, pred, un_train_class, num_classes)
|
497 |
+
preds_sb_rf.append(pred)
|
498 |
+
|
499 |
+
results_sb_rf = evl.evaluate_predictions(gt_test, preds_sb_rf, class_names,
|
500 |
+
params_eval['detection_overlap'], params_eval['ignore_start_end'])
|
501 |
+
print_results('Sonobat RF', 'sb_rf', results_sb_rf, args['op_dir'], class_names,
|
502 |
+
args['file_type'], args['title_text'] + ' - Species - ')
|
503 |
+
print('\n\nWARNING\nThis is evaluating on the full test set, but there is only dections for a subset of files\n\n')
|
504 |
+
|
505 |
+
|
506 |
+
#
|
507 |
+
# evaluate Tadarida-D by training random forest classifier
|
508 |
+
#
|
509 |
+
if args['td_ip_dir'] != '':
|
510 |
+
x_train = []
|
511 |
+
y_train = []
|
512 |
+
for gt in gt_train:
|
513 |
+
pred = load_tadarida_pred(args['td_ip_dir'], gt['dataset_name'], gt['id'])
|
514 |
+
# compute detection overlap with ground truth to determine which are the TP detections
|
515 |
+
assign_to_gt(gt, pred, args['iou_thresh'])
|
516 |
+
pred = parse_data(pred, class_names, classes_to_ignore, True)
|
517 |
+
x_train.append(pred['feats'])
|
518 |
+
y_train.append(pred['class_ids'])
|
519 |
+
|
520 |
+
# train random forest on Tadarida-D predictions
|
521 |
+
clf_td, un_train_class = train_rf_model(x_train, y_train, num_classes, args['rand_seed'])
|
522 |
+
|
523 |
+
# run the model on the test set
|
524 |
+
preds_td = []
|
525 |
+
for gt in gt_test:
|
526 |
+
pred = load_tadarida_pred(args['td_ip_dir'], gt['dataset_name'], gt['id'])
|
527 |
+
pred = parse_data(pred, class_names, classes_to_ignore, True)
|
528 |
+
pred = eval_rf_model(clf_td, pred, un_train_class, num_classes)
|
529 |
+
preds_td.append(pred)
|
530 |
+
|
531 |
+
results_td = evl.evaluate_predictions(gt_test, preds_td, class_names,
|
532 |
+
params_eval['detection_overlap'], params_eval['ignore_start_end'])
|
533 |
+
print_results('Tadarida', 'td_rf', results_td, args['op_dir'], class_names,
|
534 |
+
args['file_type'], args['title_text'] + ' - Species - ')
|
535 |
+
|
536 |
+
|
537 |
+
#
|
538 |
+
# evaluate BatDetect
|
539 |
+
#
|
540 |
+
if args['bd_model_path'] != '':
|
541 |
+
# load model
|
542 |
+
bd_args = du.get_default_bd_args()
|
543 |
+
model, params_bd = du.load_model(args['bd_model_path'])
|
544 |
+
|
545 |
+
# check if the class names are the same
|
546 |
+
if params_bd['class_names'] != class_names:
|
547 |
+
print('Warning: Class names are not the same as the trained model')
|
548 |
+
assert False
|
549 |
+
|
550 |
+
preds_bd = []
|
551 |
+
for ii, gg in enumerate(gt_test):
|
552 |
+
pred = du.process_file(gg['file_path'], model, params_bd, bd_args, return_raw_preds=True)
|
553 |
+
preds_bd.append(pred)
|
554 |
+
|
555 |
+
results_bd = evl.evaluate_predictions(gt_test, preds_bd, class_names,
|
556 |
+
params_eval['detection_overlap'], params_eval['ignore_start_end'])
|
557 |
+
print_results('BatDetect', 'bd', results_bd, args['op_dir'],
|
558 |
+
class_names, args['file_type'], args['title_text'] + ' - Species - ')
|
559 |
+
|
560 |
+
# evaluate genus level
|
561 |
+
class_names_genus, preds_bd_g, gt_test_g = create_genus_mapping(gt_test, preds_bd, class_names)
|
562 |
+
results_bd_genus = evl.evaluate_predictions(gt_test_g, preds_bd_g, class_names_genus,
|
563 |
+
params_eval['detection_overlap'], params_eval['ignore_start_end'])
|
564 |
+
print_results('BatDetect Genus', 'bd_genus', results_bd_genus, args['op_dir'],
|
565 |
+
class_names_genus, args['file_type'], args['title_text'] + ' - Genus - ')
|
bat_detect/evaluate/readme.md
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Evaluating BatDetect2
|
2 |
+
This script evaluates a trained model and outputs several plots summarizing the performance. It is used as follows:
|
3 |
+
`python path_to_store_images/ path_to_audio_files/ path_to_annotation_file/ path_to_trained_model/`
|
4 |
+
|
5 |
+
e.g.
|
6 |
+
`python evaluate_models.py ../../plots/results_compare_yuc/ /data1/bat_data/data/yucatan/audio/ /data1/bat_data/annotations/anns_finetune/ ../../experiments/2021_12_17__15_58_43/2021_12_17__15_58_43.pth.tar`
|
7 |
+
|
8 |
+
By default this will just evaluate the set of test files that are already specified in the model at training time. However, you can also specify a single set of annotations to evaluate using the `--test_file` flag. These must be stored in one annotation file, containing a list of the individual files.
|
9 |
+
|
10 |
+
e.g.
|
11 |
+
`python evaluate_models.py ../../plots/results_compare_yuc/ /data1/bat_data/data/yucatan/audio/ /data1/bat_data/annotations/anns_finetune/ ../../experiments/2021_12_17__15_58_43/2021_12_17__15_58_43.pth.tar --test_file yucatan_TEST.json`
|
12 |
+
|
13 |
+
You can also specify if the plots are saved as a .png or .pdf using `--file_type` and you can set title text for a plot using `--title_text`, e.g. `--file_type pdf --title_text "My Dataset Name"`
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
### Comparing to Tadarida-D
|
19 |
+
It is also possible to compare to Tadarida-D. For Tadarida-D the following steps are performed:
|
20 |
+
- Matches Tadarida's detections to manually annotated calls
|
21 |
+
- Trains a RandomForest classifier using Tadarida call features
|
22 |
+
- Evaluate the classifier on a held out set
|
23 |
+
|
24 |
+
Uses precomputed binaries for Tadarida-D from:
|
25 |
+
`https://github.com/YvesBas/Tadarida-D/archive/master.zip`
|
26 |
+
|
27 |
+
Needs to be run using the following arguments:
|
28 |
+
`./TadaridaD -t 4 -x 1 ip_dir/`
|
29 |
+
-t 4 means 4 threads
|
30 |
+
-x 1 means time expansions of 1
|
31 |
+
|
32 |
+
This will generate a folder called `txt` which contains a corresponding `.ta` file for each input audio file. Example usage is as follows:
|
33 |
+
`python evaluate_models.py ../../plots/results_compare_yuc/ /data1/bat_data/data/yucatan/audio/ /data1/bat_data/annotations/anns_finetune/ ../../experiments/2021_12_17__15_58_43/2021_12_17__15_58_43.pth.tar --td_ip_dir /data1/bat_data/baselines/tadarida_D/`
|
bat_detect/finetune/finetune_model.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import os
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
7 |
+
import json
|
8 |
+
import argparse
|
9 |
+
import glob
|
10 |
+
|
11 |
+
import sys
|
12 |
+
sys.path.append(os.path.join('..', '..'))
|
13 |
+
import bat_detect.train.train_model as tm
|
14 |
+
import bat_detect.train.audio_dataloader as adl
|
15 |
+
import bat_detect.train.evaluate as evl
|
16 |
+
import bat_detect.train.train_utils as tu
|
17 |
+
import bat_detect.train.losses as losses
|
18 |
+
|
19 |
+
import bat_detect.detector.parameters as parameters
|
20 |
+
import bat_detect.detector.models as models
|
21 |
+
import bat_detect.detector.post_process as pp
|
22 |
+
import bat_detect.utils.plot_utils as pu
|
23 |
+
import bat_detect.utils.detector_utils as du
|
24 |
+
|
25 |
+
|
26 |
+
if __name__ == "__main__":
|
27 |
+
|
28 |
+
info_str = '\nBatDetect - Finetune Model\n'
|
29 |
+
|
30 |
+
print(info_str)
|
31 |
+
parser = argparse.ArgumentParser()
|
32 |
+
parser.add_argument('audio_path', type=str, help='Input directory for audio')
|
33 |
+
parser.add_argument('train_ann_path', type=str,
|
34 |
+
help='Path to where train annotation file is stored')
|
35 |
+
parser.add_argument('test_ann_path', type=str,
|
36 |
+
help='Path to where test annotation file is stored')
|
37 |
+
parser.add_argument('model_path', type=str,
|
38 |
+
help='Path to pretrained model')
|
39 |
+
parser.add_argument('--op_model_name', type=str, default='',
|
40 |
+
help='Path and name for finetuned model')
|
41 |
+
parser.add_argument('--num_epochs', type=int, default=200, dest='num_epochs',
|
42 |
+
help='Number of finetuning epochs')
|
43 |
+
parser.add_argument('--finetune_only_last_layer', action='store_true',
|
44 |
+
help='Only train final layers')
|
45 |
+
parser.add_argument('--train_from_scratch', action='store_true',
|
46 |
+
help='Do not use pretrained weights')
|
47 |
+
parser.add_argument('--do_not_save_images', action='store_false',
|
48 |
+
help='Do not save images at the end of training')
|
49 |
+
parser.add_argument('--notes', type=str, default='',
|
50 |
+
help='Notes to save in text file')
|
51 |
+
args = vars(parser.parse_args())
|
52 |
+
|
53 |
+
params = parameters.get_params(True, '../../experiments/')
|
54 |
+
if torch.cuda.is_available():
|
55 |
+
params['device'] = 'cuda'
|
56 |
+
else:
|
57 |
+
params['device'] = 'cpu'
|
58 |
+
print('\nNote, this will be a lot faster if you use computer with a GPU.\n')
|
59 |
+
|
60 |
+
print('\nAudio directory: ' + args['audio_path'])
|
61 |
+
print('Train file: ' + args['train_ann_path'])
|
62 |
+
print('Test file: ' + args['test_ann_path'])
|
63 |
+
print('Loading model: ' + args['model_path'])
|
64 |
+
|
65 |
+
dataset_name = os.path.basename(args['train_ann_path']).replace('.json', '').replace('_TRAIN', '')
|
66 |
+
|
67 |
+
if args['train_from_scratch']:
|
68 |
+
print('\nTraining model from scratch i.e. not using pretrained weights')
|
69 |
+
model, params_train = du.load_model(args['model_path'], False)
|
70 |
+
else:
|
71 |
+
model, params_train = du.load_model(args['model_path'], True)
|
72 |
+
model.to(params['device'])
|
73 |
+
|
74 |
+
params['num_epochs'] = args['num_epochs']
|
75 |
+
if args['op_model_name'] != '':
|
76 |
+
params['model_file_name'] = args['op_model_name']
|
77 |
+
classes_to_ignore = params['classes_to_ignore']+params['generic_class']
|
78 |
+
|
79 |
+
# save notes file
|
80 |
+
params['notes'] = args['notes']
|
81 |
+
if args['notes'] != '':
|
82 |
+
tu.write_notes_file(params['experiment'] + 'notes.txt', args['notes'])
|
83 |
+
|
84 |
+
|
85 |
+
# load train annotations
|
86 |
+
train_sets = []
|
87 |
+
train_sets.append(tu.get_blank_dataset_dict(dataset_name, False, args['train_ann_path'], args['audio_path']))
|
88 |
+
params['train_sets'] = [tu.get_blank_dataset_dict(dataset_name, False, os.path.basename(args['train_ann_path']), args['audio_path'])]
|
89 |
+
|
90 |
+
print('\nTrain set:')
|
91 |
+
data_train, params['class_names'], params['class_inv_freq'] = \
|
92 |
+
tu.load_set_of_anns(train_sets, classes_to_ignore, params['events_of_interest'])
|
93 |
+
print('Number of files', len(data_train))
|
94 |
+
|
95 |
+
params['genus_names'], params['genus_mapping'] = tu.get_genus_mapping(params['class_names'])
|
96 |
+
params['class_names_short'] = tu.get_short_class_names(params['class_names'])
|
97 |
+
|
98 |
+
# load test annotations
|
99 |
+
test_sets = []
|
100 |
+
test_sets.append(tu.get_blank_dataset_dict(dataset_name, True, args['test_ann_path'], args['audio_path']))
|
101 |
+
params['test_sets'] = [tu.get_blank_dataset_dict(dataset_name, True, os.path.basename(args['test_ann_path']), args['audio_path'])]
|
102 |
+
|
103 |
+
print('\nTest set:')
|
104 |
+
data_test, _, _ = tu.load_set_of_anns(test_sets, classes_to_ignore, params['events_of_interest'])
|
105 |
+
print('Number of files', len(data_test))
|
106 |
+
|
107 |
+
# train loader
|
108 |
+
train_dataset = adl.AudioLoader(data_train, params, is_train=True)
|
109 |
+
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=params['batch_size'],
|
110 |
+
shuffle=True, num_workers=params['num_workers'], pin_memory=True)
|
111 |
+
|
112 |
+
# test loader - batch size of one because of variable file length
|
113 |
+
test_dataset = adl.AudioLoader(data_test, params, is_train=False)
|
114 |
+
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1,
|
115 |
+
shuffle=False, num_workers=params['num_workers'], pin_memory=True)
|
116 |
+
|
117 |
+
inputs_train = next(iter(train_loader))
|
118 |
+
params['ip_height'] = inputs_train['spec'].shape[2]
|
119 |
+
print('\ntrain batch size :', inputs_train['spec'].shape)
|
120 |
+
|
121 |
+
assert(params_train['model_name'] == 'Net2DFast')
|
122 |
+
print('\n\nSOME hyperparams need to be the same as the loaded model (e.g. FFT) - currently they are getting overwritten.\n\n')
|
123 |
+
|
124 |
+
# set the number of output classes
|
125 |
+
num_filts = model.conv_classes_op.in_channels
|
126 |
+
k_size = model.conv_classes_op.kernel_size
|
127 |
+
pad = model.conv_classes_op.padding
|
128 |
+
model.conv_classes_op = torch.nn.Conv2d(num_filts, len(params['class_names'])+1, kernel_size=k_size, padding=pad)
|
129 |
+
model.conv_classes_op.to(params['device'])
|
130 |
+
|
131 |
+
if args['finetune_only_last_layer']:
|
132 |
+
print('\nOnly finetuning the final layers.\n')
|
133 |
+
train_layers_i = ['conv_classes', 'conv_classes_op', 'conv_size', 'conv_size_op']
|
134 |
+
train_layers = [tt + '.weight' for tt in train_layers_i] + [tt + '.bias' for tt in train_layers_i]
|
135 |
+
for name, param in model.named_parameters():
|
136 |
+
if name in train_layers:
|
137 |
+
param.requires_grad = True
|
138 |
+
else:
|
139 |
+
param.requires_grad = False
|
140 |
+
|
141 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=params['lr'])
|
142 |
+
scheduler = CosineAnnealingLR(optimizer, params['num_epochs'] * len(train_loader))
|
143 |
+
if params['train_loss'] == 'mse':
|
144 |
+
det_criterion = losses.mse_loss
|
145 |
+
elif params['train_loss'] == 'focal':
|
146 |
+
det_criterion = losses.focal_loss
|
147 |
+
|
148 |
+
# plotting
|
149 |
+
train_plt_ls = pu.LossPlotter(params['experiment'] + 'train_loss.png', params['num_epochs']+1,
|
150 |
+
['train_loss'], None, None, ['epoch', 'train_loss'], logy=True)
|
151 |
+
test_plt_ls = pu.LossPlotter(params['experiment'] + 'test_loss.png', params['num_epochs']+1,
|
152 |
+
['test_loss'], None, None, ['epoch', 'test_loss'], logy=True)
|
153 |
+
test_plt = pu.LossPlotter(params['experiment'] + 'test.png', params['num_epochs']+1,
|
154 |
+
['avg_prec', 'rec_at_x', 'avg_prec_class', 'file_acc', 'top_class'], [0,1], None, ['epoch', ''])
|
155 |
+
test_plt_class = pu.LossPlotter(params['experiment'] + 'test_avg_prec.png', params['num_epochs']+1,
|
156 |
+
params['class_names_short'], [0,1], params['class_names_short'], ['epoch', 'avg_prec'])
|
157 |
+
|
158 |
+
# main train loop
|
159 |
+
for epoch in range(0, params['num_epochs']+1):
|
160 |
+
|
161 |
+
train_loss = tm.train(model, epoch, train_loader, det_criterion, optimizer, scheduler, params)
|
162 |
+
train_plt_ls.update_and_save(epoch, [train_loss['train_loss']])
|
163 |
+
|
164 |
+
if epoch % params['num_eval_epochs'] == 0:
|
165 |
+
# detection accuracy on test set
|
166 |
+
test_res, test_loss = tm.test(model, epoch, test_loader, det_criterion, params)
|
167 |
+
test_plt_ls.update_and_save(epoch, [test_loss['test_loss']])
|
168 |
+
test_plt.update_and_save(epoch, [test_res['avg_prec'], test_res['rec_at_x'],
|
169 |
+
test_res['avg_prec_class'], test_res['file_acc'], test_res['top_class']['avg_prec']])
|
170 |
+
test_plt_class.update_and_save(epoch, [rs['avg_prec'] for rs in test_res['class_pr']])
|
171 |
+
pu.plot_pr_curve_class(params['experiment'] , 'test_pr', 'test_pr', test_res)
|
172 |
+
|
173 |
+
# save finetuned model
|
174 |
+
print('saving model to: ' + params['model_file_name'])
|
175 |
+
op_state = {'epoch': epoch + 1,
|
176 |
+
'state_dict': model.state_dict(),
|
177 |
+
'params' : params}
|
178 |
+
torch.save(op_state, params['model_file_name'])
|
179 |
+
|
180 |
+
|
181 |
+
# save an image with associated prediction for each batch in the test set
|
182 |
+
if not args['do_not_save_images']:
|
183 |
+
tm.save_images_batch(model, test_loader, params)
|
bat_detect/finetune/prep_data_finetune.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import argparse
|
3 |
+
import os
|
4 |
+
import json
|
5 |
+
|
6 |
+
import sys
|
7 |
+
sys.path.append(os.path.join('..', '..'))
|
8 |
+
import bat_detect.train.train_utils as tu
|
9 |
+
|
10 |
+
|
11 |
+
def print_dataset_stats(data, split_name, classes_to_ignore):
|
12 |
+
|
13 |
+
print('\nSplit:', split_name)
|
14 |
+
print('Num files:', len(data))
|
15 |
+
|
16 |
+
class_cnts = {}
|
17 |
+
for dd in data:
|
18 |
+
for aa in dd['annotation']:
|
19 |
+
if aa['class'] not in classes_to_ignore:
|
20 |
+
if aa['class'] in class_cnts:
|
21 |
+
class_cnts[aa['class']] += 1
|
22 |
+
else:
|
23 |
+
class_cnts[aa['class']] = 1
|
24 |
+
|
25 |
+
if len(class_cnts) == 0:
|
26 |
+
class_names = []
|
27 |
+
else:
|
28 |
+
class_names = np.sort([*class_cnts]).tolist()
|
29 |
+
print('Class count:')
|
30 |
+
str_len = np.max([len(cc) for cc in class_names]) + 5
|
31 |
+
|
32 |
+
for ii, cc in enumerate(class_names):
|
33 |
+
print(str(ii).ljust(5) + cc.ljust(str_len) + str(class_cnts[cc]))
|
34 |
+
|
35 |
+
return class_names
|
36 |
+
|
37 |
+
|
38 |
+
def load_file_names(file_name):
|
39 |
+
|
40 |
+
if os.path.isfile(file_name):
|
41 |
+
with open(file_name) as da:
|
42 |
+
files = [line.rstrip() for line in da.readlines()]
|
43 |
+
for ff in files:
|
44 |
+
if ff.lower()[-3:] != 'wav':
|
45 |
+
print('Error: Filenames need to end in .wav - ', ff)
|
46 |
+
assert(False)
|
47 |
+
else:
|
48 |
+
print('Error: Input file not found - ', file_name)
|
49 |
+
assert(False)
|
50 |
+
|
51 |
+
return files
|
52 |
+
|
53 |
+
|
54 |
+
if __name__ == "__main__":
|
55 |
+
|
56 |
+
info_str = '\nBatDetect - Prepare Data for Finetuning\n'
|
57 |
+
|
58 |
+
print(info_str)
|
59 |
+
parser = argparse.ArgumentParser()
|
60 |
+
parser.add_argument('dataset_name', type=str, help='Name to call your dataset')
|
61 |
+
parser.add_argument('audio_dir', type=str, help='Input directory for audio')
|
62 |
+
parser.add_argument('ann_dir', type=str, help='Input directory for where the audio annotations are stored')
|
63 |
+
parser.add_argument('op_dir', type=str, help='Path where the train and test splits will be stored')
|
64 |
+
parser.add_argument('--percent_val', type=float, default=0.20,
|
65 |
+
help='Hold out this much data for validation. Should be number between 0 and 1')
|
66 |
+
parser.add_argument('--rand_seed', type=int, default=2001,
|
67 |
+
help='Random seed used for creating the validation split')
|
68 |
+
parser.add_argument('--train_file', type=str, default='',
|
69 |
+
help='Text file where each line is a wav file in train split')
|
70 |
+
parser.add_argument('--test_file', type=str, default='',
|
71 |
+
help='Text file where each line is a wav file in test split')
|
72 |
+
parser.add_argument('--input_class_names', type=str, default='',
|
73 |
+
help='Specify names of classes that you want to change. Separate with ";"')
|
74 |
+
parser.add_argument('--output_class_names', type=str, default='',
|
75 |
+
help='New class names to use instead. One to one mapping with "--input_class_names". \
|
76 |
+
Separate with ";"')
|
77 |
+
args = vars(parser.parse_args())
|
78 |
+
|
79 |
+
|
80 |
+
np.random.seed(args['rand_seed'])
|
81 |
+
|
82 |
+
classes_to_ignore = ['', ' ', 'Unknown', 'Not Bat']
|
83 |
+
generic_class = ['Bat']
|
84 |
+
events_of_interest = ['Echolocation']
|
85 |
+
|
86 |
+
if args['input_class_names'] != '' and args['output_class_names'] != '':
|
87 |
+
# change the names of the classes
|
88 |
+
ip_names = args['input_class_names'].split(';')
|
89 |
+
op_names = args['output_class_names'].split(';')
|
90 |
+
name_dict = dict(zip(ip_names, op_names))
|
91 |
+
else:
|
92 |
+
name_dict = False
|
93 |
+
|
94 |
+
# load annotations
|
95 |
+
data_all, _, _ = tu.load_set_of_anns({'ann_path': args['ann_dir'], 'wav_path': args['audio_dir']},
|
96 |
+
classes_to_ignore, events_of_interest, False, False,
|
97 |
+
list_of_anns=True, filter_issues=True, name_replace=name_dict)
|
98 |
+
|
99 |
+
print('Dataset name: ' + args['dataset_name'])
|
100 |
+
print('Audio directory: ' + args['audio_dir'])
|
101 |
+
print('Annotation directory: ' + args['ann_dir'])
|
102 |
+
print('Ouput directory: ' + args['op_dir'])
|
103 |
+
print('Num annotated files: ' + str(len(data_all)))
|
104 |
+
|
105 |
+
if args['train_file'] != '' and args['test_file'] != '':
|
106 |
+
# user has specifed the train / test split
|
107 |
+
train_files = load_file_names(args['train_file'])
|
108 |
+
test_files = load_file_names(args['test_file'])
|
109 |
+
file_names_all = [dd['id'] for dd in data_all]
|
110 |
+
train_inds = [file_names_all.index(ff) for ff in train_files if ff in file_names_all]
|
111 |
+
test_inds = [file_names_all.index(ff) for ff in test_files if ff in file_names_all]
|
112 |
+
|
113 |
+
else:
|
114 |
+
# split the data into train and test at the file level
|
115 |
+
num_exs = len(data_all)
|
116 |
+
test_inds = np.random.choice(np.arange(num_exs), int(num_exs*args['percent_val']), replace=False)
|
117 |
+
test_inds = np.sort(test_inds)
|
118 |
+
train_inds = np.setdiff1d(np.arange(num_exs), test_inds)
|
119 |
+
|
120 |
+
data_train = [data_all[ii] for ii in train_inds]
|
121 |
+
data_test = [data_all[ii] for ii in test_inds]
|
122 |
+
|
123 |
+
if not os.path.isdir(args['op_dir']):
|
124 |
+
os.makedirs(args['op_dir'])
|
125 |
+
op_name = os.path.join(args['op_dir'], args['dataset_name'])
|
126 |
+
op_name_train = op_name + '_TRAIN.json'
|
127 |
+
op_name_test = op_name + '_TEST.json'
|
128 |
+
|
129 |
+
class_un_train = print_dataset_stats(data_train, 'Train', classes_to_ignore)
|
130 |
+
class_un_test = print_dataset_stats(data_test, 'Test', classes_to_ignore)
|
131 |
+
|
132 |
+
if len(data_train) > 0 and len(data_test) > 0:
|
133 |
+
if class_un_train != class_un_test:
|
134 |
+
print('\nError: some classes are not in both the training and test sets.\
|
135 |
+
\nTry a different random seed "--rand_seed".')
|
136 |
+
assert False
|
137 |
+
|
138 |
+
print('\n')
|
139 |
+
if len(data_train) == 0:
|
140 |
+
print('No train annotations to save')
|
141 |
+
else:
|
142 |
+
print('Saving: ', op_name_train)
|
143 |
+
with open(op_name_train, 'w') as da:
|
144 |
+
json.dump(data_train, da, indent=2)
|
145 |
+
|
146 |
+
if len(data_test) == 0:
|
147 |
+
print('No test annotations to save')
|
148 |
+
else:
|
149 |
+
print('Saving: ', op_name_test)
|
150 |
+
with open(op_name_test, 'w') as da:
|
151 |
+
json.dump(data_test, da, indent=2)
|
bat_detect/finetune/readme.md
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# Finetuning the BatDetet2 model on your own data
|
3 |
+
Main steps:
|
4 |
+
1. Annotate your data using the annotation GUI.
|
5 |
+
2. Run `prep_data_finetune.py` to create a training and validation split for your data.
|
6 |
+
3. Run `finetune_model.py` to finetune a model on your data.
|
7 |
+
|
8 |
+
|
9 |
+
## 1. Annotate calls of interest in audio data
|
10 |
+
Use the annotation tools provided [here](https://github.com/macaodha/batdetect2_GUI) to manually identify where the events of interest (e.g. bat echolocation calls) are in your files.
|
11 |
+
This will result in a directory of audio files and a directory of annotation files, where each audio file will have a corresponding `.json` annotation file.
|
12 |
+
Make sure to annotation all instances of a bat call.
|
13 |
+
If unsure of the species, just label the call as `Bat`.
|
14 |
+
|
15 |
+
|
16 |
+
## 2. Split data into train and test sets
|
17 |
+
After performing the previous step you should have a directory of annotations files saved as jsons, one for each audio file you have annotated.
|
18 |
+
* The next step is to split these into training and testing subsets.
|
19 |
+
Run `prep_data_finetune.py` to split the data into train and test sets. This will result in two separate files, a train and a test one, i.e.
|
20 |
+
`python prep_data_finetune.py dataset_name path_to_audio/ path_to_annotations/ path_to_output_anns/`
|
21 |
+
This may result an error if it does not generate output files containing the same set of species in the train and test splits. You can try different random seeds if this is an issue e.g. `--rand_seed 123456`.
|
22 |
+
|
23 |
+
* You can also load the train and test split using text files, where each line of the text file is the name of a `wav` file (without the file path) e.g.
|
24 |
+
`python prep_data_finetune.py dataset_name path_to_audio/ path_to_annotations/ path_to_output/ --train_file path_to_file/list_of_train_files.txt --test_file path_to_file/list_of_test_files.txt`
|
25 |
+
|
26 |
+
|
27 |
+
* Can also replace class names. This can be helpful if you don't think you have enough calls/files for a given species. Use semi-colons to separate, without spaces between them e.g.
|
28 |
+
`python prep_data_finetune.py dataset_name path_to_audio/audio/ path_to_annotations/anns/ path_to_output/ --input_class_names "Histiotus;Molossidae;Lasiurus;Myotis;Rhogeesa;Vespertilionidae" --output_class_names "Group One;Group One;Group One;Group Two;Group Two;Group Three"`
|
29 |
+
|
30 |
+
|
31 |
+
## 3. Finetuning the model
|
32 |
+
Finally, you can finetune the model using your data i.e.
|
33 |
+
`python finetune_model.py path_to_audio/ path_to_train/TRAIN.json path_to_train/TEST.json ../../models/Net2DFast_UK_same.pth.tar`
|
34 |
+
Here, `TRAIN.json` and `TEST.json` are the splits created in the previous steps.
|
35 |
+
|
36 |
+
|
37 |
+
#### Additional notes
|
38 |
+
* For the first step it is better to cut the files into less than 5 second audio clips and make sure to annotate them exhaustively (i.e. all bat calls should be annotated).
|
39 |
+
* You can train the model for longer, by setting the `--num_epochs` flag to a larger number e.g. `--num_epochs 400`. The default is `200`.
|
40 |
+
* If you do not want to finetune the model, but instead want to train it from scratch, you can set the `--train_from_scratch` flag.
|
bat_detect/train/__init__.py
ADDED
File without changes
|
bat_detect/train/audio_dataloader.py
ADDED
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import random
|
3 |
+
import numpy as np
|
4 |
+
import copy
|
5 |
+
import librosa
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torchaudio
|
8 |
+
import os
|
9 |
+
|
10 |
+
import sys
|
11 |
+
sys.path.append(os.path.join('..', '..'))
|
12 |
+
import bat_detect.utils.audio_utils as au
|
13 |
+
|
14 |
+
|
15 |
+
def generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, params):
|
16 |
+
# spec may be resized on input into the network
|
17 |
+
num_classes = len(params['class_names'])
|
18 |
+
op_height = spec_op_shape[0]
|
19 |
+
op_width = spec_op_shape[1]
|
20 |
+
freq_per_bin = (params['max_freq'] - params['min_freq']) / op_height
|
21 |
+
|
22 |
+
# start and end times
|
23 |
+
x_pos_start = au.time_to_x_coords(ann['start_times'], sampling_rate,
|
24 |
+
params['fft_win_length'], params['fft_overlap'])
|
25 |
+
x_pos_start = (params['resize_factor']*x_pos_start).astype(np.int)
|
26 |
+
x_pos_end = au.time_to_x_coords(ann['end_times'], sampling_rate,
|
27 |
+
params['fft_win_length'], params['fft_overlap'])
|
28 |
+
x_pos_end = (params['resize_factor']*x_pos_end).astype(np.int)
|
29 |
+
|
30 |
+
# location on y axis i.e. frequency
|
31 |
+
y_pos_low = (ann['low_freqs'] - params['min_freq']) / freq_per_bin
|
32 |
+
y_pos_low = (op_height - y_pos_low).astype(np.int)
|
33 |
+
y_pos_high = (ann['high_freqs'] - params['min_freq']) / freq_per_bin
|
34 |
+
y_pos_high = (op_height - y_pos_high).astype(np.int)
|
35 |
+
bb_widths = x_pos_end - x_pos_start
|
36 |
+
bb_heights = (y_pos_low - y_pos_high)
|
37 |
+
|
38 |
+
valid_inds = np.where((x_pos_start >= 0) & (x_pos_start < op_width) &
|
39 |
+
(y_pos_low >= 0) & (y_pos_low < (op_height-1)))[0]
|
40 |
+
|
41 |
+
ann_aug = {}
|
42 |
+
ann_aug['x_inds'] = x_pos_start[valid_inds]
|
43 |
+
ann_aug['y_inds'] = y_pos_low[valid_inds]
|
44 |
+
keys = ['start_times', 'end_times', 'high_freqs', 'low_freqs', 'class_ids', 'individual_ids']
|
45 |
+
for kk in keys:
|
46 |
+
ann_aug[kk] = ann[kk][valid_inds]
|
47 |
+
|
48 |
+
# if the number of calls is only 1, then it is unique
|
49 |
+
# TODO would be better if we found these unique calls at the merging stage
|
50 |
+
if len(ann_aug['individual_ids']) == 1:
|
51 |
+
ann_aug['individual_ids'][0] = 0
|
52 |
+
|
53 |
+
y_2d_det = np.zeros((1, op_height, op_width), dtype=np.float32)
|
54 |
+
y_2d_size = np.zeros((2, op_height, op_width), dtype=np.float32)
|
55 |
+
# num classes and "background" class
|
56 |
+
y_2d_classes = np.zeros((num_classes+1, op_height, op_width), dtype=np.float32)
|
57 |
+
|
58 |
+
# create 2D ground truth heatmaps
|
59 |
+
for ii in valid_inds:
|
60 |
+
draw_gaussian(y_2d_det[0,:], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma'])
|
61 |
+
#draw_gaussian(y_2d_det[0,:], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma'], params['target_sigma']*2)
|
62 |
+
y_2d_size[0, y_pos_low[ii], x_pos_start[ii]] = bb_widths[ii]
|
63 |
+
y_2d_size[1, y_pos_low[ii], x_pos_start[ii]] = bb_heights[ii]
|
64 |
+
|
65 |
+
cls_id = ann['class_ids'][ii]
|
66 |
+
if cls_id > -1:
|
67 |
+
draw_gaussian(y_2d_classes[cls_id, :], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma'])
|
68 |
+
#draw_gaussian(y_2d_classes[cls_id, :], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma'], params['target_sigma']*2)
|
69 |
+
|
70 |
+
# be careful as this will have a 1.0 places where we have event but dont know gt class
|
71 |
+
# this will be masked in training anyway
|
72 |
+
y_2d_classes[num_classes, :] = 1.0 - y_2d_classes.sum(0)
|
73 |
+
y_2d_classes = y_2d_classes / y_2d_classes.sum(0)[np.newaxis, ...]
|
74 |
+
y_2d_classes[np.isnan(y_2d_classes)] = 0.0
|
75 |
+
|
76 |
+
return y_2d_det, y_2d_size, y_2d_classes, ann_aug
|
77 |
+
|
78 |
+
|
79 |
+
def draw_gaussian(heatmap, center, sigmax, sigmay=None):
|
80 |
+
# center is (x, y)
|
81 |
+
# this edits the heatmap inplace
|
82 |
+
|
83 |
+
if sigmay is None:
|
84 |
+
sigmay = sigmax
|
85 |
+
tmp_size = np.maximum(sigmax, sigmay) * 3
|
86 |
+
mu_x = int(center[0] + 0.5)
|
87 |
+
mu_y = int(center[1] + 0.5)
|
88 |
+
w, h = heatmap.shape[0], heatmap.shape[1]
|
89 |
+
ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
|
90 |
+
br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
|
91 |
+
|
92 |
+
if ul[0] >= h or ul[1] >= w or br[0] < 0 or br[1] < 0:
|
93 |
+
return False
|
94 |
+
|
95 |
+
size = 2 * tmp_size + 1
|
96 |
+
x = np.arange(0, size, 1, np.float32)
|
97 |
+
y = x[:, np.newaxis]
|
98 |
+
x0 = y0 = size // 2
|
99 |
+
#g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
|
100 |
+
g = np.exp(- ((x - x0) ** 2)/(2 * sigmax ** 2) - ((y - y0) ** 2)/(2 * sigmay ** 2))
|
101 |
+
g_x = max(0, -ul[0]), min(br[0], h) - ul[0]
|
102 |
+
g_y = max(0, -ul[1]), min(br[1], w) - ul[1]
|
103 |
+
img_x = max(0, ul[0]), min(br[0], h)
|
104 |
+
img_y = max(0, ul[1]), min(br[1], w)
|
105 |
+
heatmap[img_y[0]:img_y[1], img_x[0]:img_x[1]] = np.maximum(
|
106 |
+
heatmap[img_y[0]:img_y[1], img_x[0]:img_x[1]],
|
107 |
+
g[g_y[0]:g_y[1], g_x[0]:g_x[1]])
|
108 |
+
return True
|
109 |
+
|
110 |
+
|
111 |
+
def pad_aray(ip_array, pad_size):
|
112 |
+
return np.hstack((ip_array, np.ones(pad_size, dtype=np.int)*-1))
|
113 |
+
|
114 |
+
|
115 |
+
def warp_spec_aug(spec, ann, return_spec_for_viz, params):
|
116 |
+
# This is messy
|
117 |
+
# Augment spectrogram by randomly stretch and squeezing
|
118 |
+
# NOTE this also changes the start and stop time in place
|
119 |
+
|
120 |
+
# not taking care of spec for viz
|
121 |
+
if return_spec_for_viz:
|
122 |
+
assert False
|
123 |
+
|
124 |
+
delta = params['stretch_squeeze_delta']
|
125 |
+
op_size = (spec.shape[1], spec.shape[2])
|
126 |
+
resize_fract_r = np.random.rand()*delta*2 - delta + 1.0
|
127 |
+
resize_amt = int(spec.shape[2]*resize_fract_r)
|
128 |
+
if resize_amt >= spec.shape[2]:
|
129 |
+
spec_r = torch.cat((spec, torch.zeros((1, spec.shape[1], resize_amt-spec.shape[2]), dtype=spec.dtype)), 2)
|
130 |
+
else:
|
131 |
+
spec_r = spec[:, :, :resize_amt]
|
132 |
+
spec = F.interpolate(spec_r.unsqueeze(0), size=op_size, mode='bilinear', align_corners=False).squeeze(0)
|
133 |
+
ann['start_times'] *= (1.0/resize_fract_r)
|
134 |
+
ann['end_times'] *= (1.0/resize_fract_r)
|
135 |
+
return spec
|
136 |
+
|
137 |
+
|
138 |
+
def mask_time_aug(spec, params):
|
139 |
+
# Mask out a random block of time - repeat up to 3 times
|
140 |
+
# SpecAugment: A Simple Data Augmentation Methodfor Automatic Speech Recognition
|
141 |
+
fm = torchaudio.transforms.TimeMasking(int(spec.shape[1]*params['mask_max_time_perc']))
|
142 |
+
for ii in range(np.random.randint(1, 4)):
|
143 |
+
spec = fm(spec)
|
144 |
+
return spec
|
145 |
+
|
146 |
+
|
147 |
+
def mask_freq_aug(spec, params):
|
148 |
+
# Mask out a random frequncy range - repeat up to 3 times
|
149 |
+
# SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
|
150 |
+
fm = torchaudio.transforms.FrequencyMasking(int(spec.shape[1]*params['mask_max_freq_perc']))
|
151 |
+
for ii in range(np.random.randint(1, 4)):
|
152 |
+
spec = fm(spec)
|
153 |
+
return spec
|
154 |
+
|
155 |
+
|
156 |
+
def scale_vol_aug(spec, params):
|
157 |
+
return spec * np.random.random()*params['spec_amp_scaling']
|
158 |
+
|
159 |
+
|
160 |
+
def echo_aug(audio, sampling_rate, params):
|
161 |
+
sample_offset = int(params['echo_max_delay']*np.random.random()*sampling_rate) + 1
|
162 |
+
audio[:-sample_offset] += np.random.random()*audio[sample_offset:]
|
163 |
+
return audio
|
164 |
+
|
165 |
+
|
166 |
+
def resample_aug(audio, sampling_rate, params):
|
167 |
+
sampling_rate_old = sampling_rate
|
168 |
+
sampling_rate = np.random.choice(params['aug_sampling_rates'])
|
169 |
+
audio = librosa.resample(audio, sampling_rate_old, sampling_rate, res_type='polyphase')
|
170 |
+
|
171 |
+
audio = au.pad_audio(audio, sampling_rate, params['fft_win_length'],
|
172 |
+
params['fft_overlap'], params['resize_factor'],
|
173 |
+
params['spec_divide_factor'], params['spec_train_width'])
|
174 |
+
duration = audio.shape[0] / float(sampling_rate)
|
175 |
+
return audio, sampling_rate, duration
|
176 |
+
|
177 |
+
|
178 |
+
def resample_audio(num_samples, sampling_rate, audio2, sampling_rate2):
|
179 |
+
if sampling_rate != sampling_rate2:
|
180 |
+
audio2 = librosa.resample(audio2, sampling_rate2, sampling_rate, res_type='polyphase')
|
181 |
+
sampling_rate2 = sampling_rate
|
182 |
+
if audio2.shape[0] < num_samples:
|
183 |
+
audio2 = np.hstack((audio2, np.zeros((num_samples-audio2.shape[0]), dtype=audio2.dtype)))
|
184 |
+
elif audio2.shape[0] > num_samples:
|
185 |
+
audio2 = audio2[:num_samples]
|
186 |
+
return audio2, sampling_rate2
|
187 |
+
|
188 |
+
|
189 |
+
def combine_audio_aug(audio, sampling_rate, ann, audio2, sampling_rate2, ann2):
|
190 |
+
|
191 |
+
# resample so they are the same
|
192 |
+
audio2, sampling_rate2 = resample_audio(audio.shape[0], sampling_rate, audio2, sampling_rate2)
|
193 |
+
|
194 |
+
# # set mean and std to be the same
|
195 |
+
# audio2 = (audio2 - audio2.mean())
|
196 |
+
# audio2 = (audio2/audio2.std())*audio.std()
|
197 |
+
# audio2 = audio2 + audio.mean()
|
198 |
+
|
199 |
+
if ann['annotated'] and (ann2['annotated']) and \
|
200 |
+
(sampling_rate2 == sampling_rate) and (audio.shape[0] == audio2.shape[0]):
|
201 |
+
comb_weight = 0.3 + np.random.random()*0.4
|
202 |
+
audio = comb_weight*audio + (1-comb_weight)*audio2
|
203 |
+
inds = np.argsort(np.hstack((ann['start_times'], ann2['start_times'])))
|
204 |
+
for kk in ann.keys():
|
205 |
+
|
206 |
+
# when combining calls from different files, assume they come from different individuals
|
207 |
+
if kk == 'individual_ids':
|
208 |
+
if (ann[kk]>-1).sum() > 0:
|
209 |
+
ann2[kk][ann2[kk]>-1] += np.max(ann[kk][ann[kk]>-1]) + 1
|
210 |
+
|
211 |
+
if (kk != 'class_id_file') and (kk != 'annotated'):
|
212 |
+
ann[kk] = np.hstack((ann[kk], ann2[kk]))[inds]
|
213 |
+
|
214 |
+
return audio, ann
|
215 |
+
|
216 |
+
|
217 |
+
class AudioLoader(torch.utils.data.Dataset):
|
218 |
+
def __init__(self, data_anns_ip, params, dataset_name=None, is_train=False):
|
219 |
+
|
220 |
+
self.data_anns = []
|
221 |
+
self.is_train = is_train
|
222 |
+
self.params = params
|
223 |
+
self.return_spec_for_viz = False
|
224 |
+
|
225 |
+
for ii in range(len(data_anns_ip)):
|
226 |
+
dd = copy.deepcopy(data_anns_ip[ii])
|
227 |
+
|
228 |
+
# filter out unused annotation here
|
229 |
+
filtered_annotations = []
|
230 |
+
for ii, aa in enumerate(dd['annotation']):
|
231 |
+
|
232 |
+
if 'individual' in aa.keys():
|
233 |
+
aa['individual'] = int(aa['individual'])
|
234 |
+
|
235 |
+
# if only one call labeled it has to be from the same individual
|
236 |
+
if len(dd['annotation']) == 1:
|
237 |
+
aa['individual'] = 0
|
238 |
+
|
239 |
+
# convert class name into class label
|
240 |
+
if aa['class'] in self.params['class_names']:
|
241 |
+
aa['class_id'] = self.params['class_names'].index(aa['class'])
|
242 |
+
else:
|
243 |
+
aa['class_id'] = -1
|
244 |
+
|
245 |
+
if aa['class'] not in self.params['classes_to_ignore']:
|
246 |
+
filtered_annotations.append(aa)
|
247 |
+
|
248 |
+
dd['annotation'] = filtered_annotations
|
249 |
+
dd['start_times'] = np.array([aa['start_time'] for aa in dd['annotation']])
|
250 |
+
dd['end_times'] = np.array([aa['end_time'] for aa in dd['annotation']])
|
251 |
+
dd['high_freqs'] = np.array([float(aa['high_freq']) for aa in dd['annotation']])
|
252 |
+
dd['low_freqs'] = np.array([float(aa['low_freq']) for aa in dd['annotation']])
|
253 |
+
dd['class_ids'] = np.array([aa['class_id'] for aa in dd['annotation']]).astype(np.int)
|
254 |
+
dd['individual_ids'] = np.array([aa['individual'] for aa in dd['annotation']]).astype(np.int)
|
255 |
+
|
256 |
+
# file level class name
|
257 |
+
dd['class_id_file'] = -1
|
258 |
+
if 'class_name' in dd.keys():
|
259 |
+
if dd['class_name'] in self.params['class_names']:
|
260 |
+
dd['class_id_file'] = self.params['class_names'].index(dd['class_name'])
|
261 |
+
|
262 |
+
self.data_anns.append(dd)
|
263 |
+
|
264 |
+
ann_cnt = [len(aa['annotation']) for aa in self.data_anns]
|
265 |
+
self.max_num_anns = 2*np.max(ann_cnt) # x2 because we may be combining files during training
|
266 |
+
|
267 |
+
print('\n')
|
268 |
+
if dataset_name is not None:
|
269 |
+
print('Dataset : ' + dataset_name)
|
270 |
+
if self.is_train:
|
271 |
+
print('Split type : train')
|
272 |
+
else:
|
273 |
+
print('Split type : test')
|
274 |
+
print('Num files : ' + str(len(self.data_anns)))
|
275 |
+
print('Num calls : ' + str(np.sum(ann_cnt)))
|
276 |
+
|
277 |
+
|
278 |
+
def get_file_and_anns(self, index=None):
|
279 |
+
|
280 |
+
# if no file specified, choose random one
|
281 |
+
if index == None:
|
282 |
+
index = np.random.randint(0, len(self.data_anns))
|
283 |
+
|
284 |
+
audio_file = self.data_anns[index]['file_path']
|
285 |
+
sampling_rate, audio_raw = au.load_audio_file(audio_file, self.data_anns[index]['time_exp'],
|
286 |
+
self.params['target_samp_rate'], self.params['scale_raw_audio'])
|
287 |
+
|
288 |
+
# copy annotation
|
289 |
+
ann = {}
|
290 |
+
ann['annotated'] = self.data_anns[index]['annotated']
|
291 |
+
ann['class_id_file'] = self.data_anns[index]['class_id_file']
|
292 |
+
keys = ['start_times', 'end_times', 'high_freqs', 'low_freqs', 'class_ids', 'individual_ids']
|
293 |
+
for kk in keys:
|
294 |
+
ann[kk] = self.data_anns[index][kk].copy()
|
295 |
+
|
296 |
+
# if train then grab a random crop
|
297 |
+
if self.is_train:
|
298 |
+
nfft = int(self.params['fft_win_length']*sampling_rate)
|
299 |
+
noverlap = int(self.params['fft_overlap']*nfft)
|
300 |
+
length_samples = self.params['spec_train_width']*(nfft - noverlap) + noverlap
|
301 |
+
|
302 |
+
if audio_raw.shape[0] - length_samples > 0:
|
303 |
+
sample_crop = np.random.randint(audio_raw.shape[0] - length_samples)
|
304 |
+
else:
|
305 |
+
sample_crop = 0
|
306 |
+
audio_raw = audio_raw[sample_crop:sample_crop+length_samples]
|
307 |
+
ann['start_times'] = ann['start_times'] - sample_crop/float(sampling_rate)
|
308 |
+
ann['end_times'] = ann['end_times'] - sample_crop/float(sampling_rate)
|
309 |
+
|
310 |
+
# pad audio
|
311 |
+
if self.is_train:
|
312 |
+
op_spec_target_size = self.params['spec_train_width']
|
313 |
+
else:
|
314 |
+
op_spec_target_size = None
|
315 |
+
audio_raw = au.pad_audio(audio_raw, sampling_rate, self.params['fft_win_length'],
|
316 |
+
self.params['fft_overlap'], self.params['resize_factor'],
|
317 |
+
self.params['spec_divide_factor'], op_spec_target_size)
|
318 |
+
duration = audio_raw.shape[0] / float(sampling_rate)
|
319 |
+
|
320 |
+
# sort based on time
|
321 |
+
inds = np.argsort(ann['start_times'])
|
322 |
+
for kk in ann.keys():
|
323 |
+
if (kk != 'class_id_file') and (kk != 'annotated'):
|
324 |
+
ann[kk] = ann[kk][inds]
|
325 |
+
|
326 |
+
return audio_raw, sampling_rate, duration, ann
|
327 |
+
|
328 |
+
|
329 |
+
def __getitem__(self, index):
|
330 |
+
|
331 |
+
# load audio file
|
332 |
+
audio, sampling_rate, duration, ann = self.get_file_and_anns(index)
|
333 |
+
|
334 |
+
# augment on raw audio
|
335 |
+
if self.is_train and self.params['augment_at_train']:
|
336 |
+
# augment - combine with random audio file
|
337 |
+
if self.params['augment_at_train_combine'] and np.random.random() < self.params['aug_prob']:
|
338 |
+
audio2, sampling_rate2, duration2, ann2 = self.get_file_and_anns()
|
339 |
+
audio, ann = combine_audio_aug(audio, sampling_rate, ann, audio2, sampling_rate2, ann2)
|
340 |
+
|
341 |
+
# simulate echo by adding delayed copy of the file
|
342 |
+
if np.random.random() < self.params['aug_prob']:
|
343 |
+
audio = echo_aug(audio, sampling_rate, self.params)
|
344 |
+
|
345 |
+
# resample the audio
|
346 |
+
#if np.random.random() < self.params['aug_prob']:
|
347 |
+
# audio, sampling_rate, duration = resample_aug(audio, sampling_rate, self.params)
|
348 |
+
|
349 |
+
# create spectrogram
|
350 |
+
spec, spec_for_viz = au.generate_spectrogram(audio, sampling_rate, self.params, self.return_spec_for_viz)
|
351 |
+
rsf = self.params['resize_factor']
|
352 |
+
spec_op_shape = (int(self.params['spec_height']*rsf), int(spec.shape[1]*rsf))
|
353 |
+
|
354 |
+
# resize the spec
|
355 |
+
spec = torch.from_numpy(spec).unsqueeze(0).unsqueeze(0)
|
356 |
+
spec = F.interpolate(spec, size=spec_op_shape, mode='bilinear', align_corners=False).squeeze(0)
|
357 |
+
|
358 |
+
# augment spectrogram
|
359 |
+
if self.is_train and self.params['augment_at_train']:
|
360 |
+
|
361 |
+
if np.random.random() < self.params['aug_prob']:
|
362 |
+
spec = scale_vol_aug(spec, self.params)
|
363 |
+
|
364 |
+
if np.random.random() < self.params['aug_prob']:
|
365 |
+
spec = warp_spec_aug(spec, ann, self.return_spec_for_viz, self.params)
|
366 |
+
|
367 |
+
if np.random.random() < self.params['aug_prob']:
|
368 |
+
spec = mask_time_aug(spec, self.params)
|
369 |
+
|
370 |
+
if np.random.random() < self.params['aug_prob']:
|
371 |
+
spec = mask_freq_aug(spec, self.params)
|
372 |
+
|
373 |
+
outputs = {}
|
374 |
+
outputs['spec'] = spec
|
375 |
+
if self.return_spec_for_viz:
|
376 |
+
outputs['spec_for_viz'] = torch.from_numpy(spec_for_viz).unsqueeze(0)
|
377 |
+
|
378 |
+
# create ground truth heatmaps
|
379 |
+
outputs['y_2d_det'], outputs['y_2d_size'], outputs['y_2d_classes'], ann_aug =\
|
380 |
+
generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, self.params)
|
381 |
+
|
382 |
+
# hack to get around requirement that all vectors are the same length in
|
383 |
+
# the output batch
|
384 |
+
pad_size = self.max_num_anns-len(ann_aug['individual_ids'])
|
385 |
+
outputs['is_valid'] = pad_aray(np.ones(len(ann_aug['individual_ids'])), pad_size)
|
386 |
+
keys = ['class_ids', 'individual_ids', 'x_inds', 'y_inds',
|
387 |
+
'start_times', 'end_times', 'low_freqs', 'high_freqs']
|
388 |
+
for kk in keys:
|
389 |
+
outputs[kk] = pad_aray(ann_aug[kk], pad_size)
|
390 |
+
|
391 |
+
# convert to pytorch
|
392 |
+
for kk in outputs.keys():
|
393 |
+
if type(outputs[kk]) != torch.Tensor:
|
394 |
+
outputs[kk] = torch.from_numpy(outputs[kk])
|
395 |
+
|
396 |
+
# scalars
|
397 |
+
outputs['class_id_file'] = ann['class_id_file']
|
398 |
+
outputs['annotated'] = ann['annotated']
|
399 |
+
outputs['duration'] = duration
|
400 |
+
outputs['sampling_rate'] = sampling_rate
|
401 |
+
outputs['file_id'] = index
|
402 |
+
|
403 |
+
return outputs
|
404 |
+
|
405 |
+
|
406 |
+
def __len__(self):
|
407 |
+
return len(self.data_anns)
|
bat_detect/train/evaluate.py
ADDED
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from sklearn.metrics import roc_curve, auc
|
3 |
+
from sklearn.metrics import accuracy_score, balanced_accuracy_score
|
4 |
+
|
5 |
+
|
6 |
+
def compute_error_auc(op_str, gt, pred, prob):
|
7 |
+
|
8 |
+
# classification error
|
9 |
+
pred_int = (pred > prob).astype(np.int)
|
10 |
+
class_acc = (pred_int == gt).mean() * 100.0
|
11 |
+
|
12 |
+
# ROC - area under curve
|
13 |
+
fpr, tpr, thresholds = roc_curve(gt, pred)
|
14 |
+
roc_auc = auc(fpr, tpr)
|
15 |
+
|
16 |
+
print(op_str + ", class acc = {:.3f}, ROC AUC = {:.3f}".format(class_acc, roc_auc))
|
17 |
+
#return class_acc, roc_auc
|
18 |
+
|
19 |
+
|
20 |
+
def calc_average_precision(recall, precision):
|
21 |
+
|
22 |
+
precision[np.isnan(precision)] = 0
|
23 |
+
recall[np.isnan(recall)] = 0
|
24 |
+
|
25 |
+
# pascal 12 way
|
26 |
+
mprec = np.hstack((0, precision, 0))
|
27 |
+
mrec = np.hstack((0, recall, 1))
|
28 |
+
for ii in range(mprec.shape[0]-2, -1,-1):
|
29 |
+
mprec[ii] = np.maximum(mprec[ii], mprec[ii+1])
|
30 |
+
inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0]+1
|
31 |
+
ave_prec = ((mrec[inds] - mrec[inds-1])*mprec[inds]).sum()
|
32 |
+
|
33 |
+
return float(ave_prec)
|
34 |
+
|
35 |
+
|
36 |
+
def calc_recall_at_x(recall, precision, x=0.95):
|
37 |
+
precision[np.isnan(precision)] = 0
|
38 |
+
recall[np.isnan(recall)] = 0
|
39 |
+
|
40 |
+
inds = np.where(precision[::-1]>x)[0]
|
41 |
+
if len(inds) > 0:
|
42 |
+
return float(recall[::-1][inds[0]])
|
43 |
+
else:
|
44 |
+
return 0.0
|
45 |
+
|
46 |
+
|
47 |
+
def compute_affinity_1d(pred_box, gt_boxes, threshold):
|
48 |
+
# first entry is start time
|
49 |
+
score = np.abs(pred_box[0] - gt_boxes[:, 0])
|
50 |
+
valid_detection = np.min(score) <= threshold
|
51 |
+
return valid_detection, np.argmin(score)
|
52 |
+
|
53 |
+
|
54 |
+
def compute_pre_rec(gts, preds, eval_mode, class_of_interest, num_classes, threshold, ignore_start_end):
|
55 |
+
"""
|
56 |
+
Computes precision and recall. Assumes that each file has been exhaustively
|
57 |
+
annotated. Will not count predicted detection with a start time that is within
|
58 |
+
ignore_start_end miliseconds of the start or end of the file.
|
59 |
+
|
60 |
+
eval_mode == 'detection'
|
61 |
+
Returns overall detection results (not per class)
|
62 |
+
|
63 |
+
eval_mode == 'per_class'
|
64 |
+
Filters ground truth based on class of interest. This will ignore predictions
|
65 |
+
assigned to gt with unknown class.
|
66 |
+
|
67 |
+
eval_mode = 'top_class'
|
68 |
+
Turns the problem into a binary one and selects the top predicted class
|
69 |
+
for each predicted detection
|
70 |
+
|
71 |
+
"""
|
72 |
+
|
73 |
+
# get predictions and put in array
|
74 |
+
pred_boxes = []
|
75 |
+
confidence = []
|
76 |
+
pred_class = []
|
77 |
+
file_ids = []
|
78 |
+
for pid, pp in enumerate(preds):
|
79 |
+
|
80 |
+
# filter predicted calls that are too near the start or end of the file
|
81 |
+
file_dur = gts[pid]['duration']
|
82 |
+
valid_inds = (pp['start_times'] >= ignore_start_end) & (pp['start_times'] <= (file_dur - ignore_start_end))
|
83 |
+
|
84 |
+
pred_boxes.append(np.vstack((pp['start_times'][valid_inds], pp['end_times'][valid_inds],
|
85 |
+
pp['low_freqs'][valid_inds], pp['high_freqs'][valid_inds])).T)
|
86 |
+
|
87 |
+
if eval_mode == 'detection':
|
88 |
+
# overall detection
|
89 |
+
confidence.append(pp['det_probs'][valid_inds])
|
90 |
+
elif eval_mode == 'per_class':
|
91 |
+
# per class
|
92 |
+
confidence.append(pp['class_probs'].T[valid_inds, class_of_interest])
|
93 |
+
elif eval_mode == 'top_class':
|
94 |
+
# per class - note that sometimes 'class_probs' can be num_classes+1 in size
|
95 |
+
top_class = np.argmax(pp['class_probs'].T[valid_inds, :num_classes], 1)
|
96 |
+
confidence.append(pp['class_probs'].T[valid_inds, top_class])
|
97 |
+
pred_class.append(top_class)
|
98 |
+
|
99 |
+
# be careful, assuming the order in the list is same as GT
|
100 |
+
file_ids.append([pid]*valid_inds.sum())
|
101 |
+
|
102 |
+
confidence = np.hstack(confidence)
|
103 |
+
file_ids = np.hstack(file_ids).astype(np.int)
|
104 |
+
pred_boxes = np.vstack(pred_boxes)
|
105 |
+
if len(pred_class) > 0:
|
106 |
+
pred_class = np.hstack(pred_class)
|
107 |
+
|
108 |
+
|
109 |
+
# extract relevant ground truth boxes
|
110 |
+
gt_boxes = []
|
111 |
+
gt_assigned = []
|
112 |
+
gt_class = []
|
113 |
+
gt_generic_class = []
|
114 |
+
num_positives = 0
|
115 |
+
for gg in gts:
|
116 |
+
|
117 |
+
# filter ground truth calls that are too near the start or end of the file
|
118 |
+
file_dur = gg['duration']
|
119 |
+
valid_inds = (gg['start_times'] >= ignore_start_end) & (gg['start_times'] <= (file_dur - ignore_start_end))
|
120 |
+
|
121 |
+
# note, files with the incorrect duration will cause a problem
|
122 |
+
if (gg['start_times'] > file_dur).sum() > 0:
|
123 |
+
print('Error: file duration incorrect for', gg['id'])
|
124 |
+
assert(False)
|
125 |
+
|
126 |
+
boxes = np.vstack((gg['start_times'][valid_inds], gg['end_times'][valid_inds],
|
127 |
+
gg['low_freqs'][valid_inds], gg['high_freqs'][valid_inds])).T
|
128 |
+
gen_class = gg['class_ids'][valid_inds] == -1
|
129 |
+
class_ids = gg['class_ids'][valid_inds]
|
130 |
+
|
131 |
+
# keep track of the number of relevant ground truth calls
|
132 |
+
if eval_mode == 'detection':
|
133 |
+
# all valid ones
|
134 |
+
num_positives += len(gg['start_times'][valid_inds])
|
135 |
+
elif eval_mode == 'per_class':
|
136 |
+
# all valid ones with class of interest
|
137 |
+
num_positives += (gg['class_ids'][valid_inds] == class_of_interest).sum()
|
138 |
+
elif eval_mode == 'top_class':
|
139 |
+
# all valid ones with non generic class
|
140 |
+
num_positives += (gg['class_ids'][valid_inds] > -1).sum()
|
141 |
+
|
142 |
+
# find relevant classes (i.e. class_of_interest) and events without known class (i.e. generic class, -1)
|
143 |
+
if eval_mode == 'per_class':
|
144 |
+
class_inds = (class_ids == class_of_interest) | (class_ids == -1)
|
145 |
+
boxes = boxes[class_inds, :]
|
146 |
+
gen_class = gen_class[class_inds]
|
147 |
+
class_ids = class_ids[class_inds]
|
148 |
+
|
149 |
+
gt_assigned.append(np.zeros(boxes.shape[0]))
|
150 |
+
gt_boxes.append(boxes)
|
151 |
+
gt_generic_class.append(gen_class)
|
152 |
+
gt_class.append(class_ids)
|
153 |
+
|
154 |
+
|
155 |
+
# loop through detections and keep track of those that have been assigned
|
156 |
+
true_pos = np.zeros(confidence.shape[0])
|
157 |
+
valid_inds = np.ones(confidence.shape[0]) == 1 # intialize to True
|
158 |
+
sorted_inds = np.argsort(confidence)[::-1] # sort high to low
|
159 |
+
for ii, ind in enumerate(sorted_inds):
|
160 |
+
gt_id = file_ids[ind]
|
161 |
+
|
162 |
+
valid_det = False
|
163 |
+
if gt_boxes[gt_id].shape[0] > 0:
|
164 |
+
# compute overlap
|
165 |
+
valid_det, det_ind = compute_affinity_1d(pred_boxes[ind], gt_boxes[gt_id],
|
166 |
+
threshold)
|
167 |
+
|
168 |
+
# valid detection that has not already been assigned
|
169 |
+
if valid_det and (gt_assigned[gt_id][det_ind] == 0):
|
170 |
+
|
171 |
+
count_as_true_pos = True
|
172 |
+
if eval_mode == 'top_class' and (gt_class[gt_id][det_ind] != pred_class[ind]):
|
173 |
+
# needs to be the same class
|
174 |
+
count_as_true_pos = False
|
175 |
+
|
176 |
+
if count_as_true_pos:
|
177 |
+
true_pos[ii] = 1
|
178 |
+
|
179 |
+
gt_assigned[gt_id][det_ind] = 1
|
180 |
+
|
181 |
+
# if event is generic class (i.e. gt_generic_class[gt_id][det_ind] is True)
|
182 |
+
# and eval_mode != 'detection', then ignore it
|
183 |
+
if gt_generic_class[gt_id][det_ind]:
|
184 |
+
if eval_mode == 'per_class' or eval_mode == 'top_class':
|
185 |
+
valid_inds[ii] = False
|
186 |
+
|
187 |
+
|
188 |
+
# store threshold values - used for plotting
|
189 |
+
conf_sorted = np.sort(confidence)[::-1][valid_inds]
|
190 |
+
thresholds = np.linspace(0.1, 0.9, 9)
|
191 |
+
thresholds_inds = np.zeros(len(thresholds), dtype=np.int)
|
192 |
+
for ii, tt in enumerate(thresholds):
|
193 |
+
thresholds_inds[ii] = np.argmin(conf_sorted > tt)
|
194 |
+
thresholds_inds[thresholds_inds==0] = -1
|
195 |
+
|
196 |
+
# compute precision and recall
|
197 |
+
true_pos = true_pos[valid_inds]
|
198 |
+
false_pos_c = np.cumsum(1-true_pos)
|
199 |
+
true_pos_c = np.cumsum(true_pos)
|
200 |
+
|
201 |
+
recall = true_pos_c / num_positives
|
202 |
+
precision = true_pos_c / np.maximum(true_pos_c + false_pos_c, np.finfo(np.float64).eps)
|
203 |
+
|
204 |
+
results = {}
|
205 |
+
results['recall'] = recall
|
206 |
+
results['precision'] = precision
|
207 |
+
results['num_gt'] = num_positives
|
208 |
+
|
209 |
+
results['thresholds'] = thresholds
|
210 |
+
results['thresholds_inds'] = thresholds_inds
|
211 |
+
|
212 |
+
if num_positives == 0:
|
213 |
+
results['avg_prec'] = np.nan
|
214 |
+
results['rec_at_x'] = np.nan
|
215 |
+
else:
|
216 |
+
results['avg_prec'] = np.round(calc_average_precision(recall, precision), 5)
|
217 |
+
results['rec_at_x'] = np.round(calc_recall_at_x(recall, precision), 5)
|
218 |
+
|
219 |
+
return results
|
220 |
+
|
221 |
+
|
222 |
+
def compute_file_accuracy_simple(gts, preds, num_classes):
|
223 |
+
"""
|
224 |
+
Evaluates the prediction accuracy at a file level.
|
225 |
+
Does not include files that have more than one class (or the generic class).
|
226 |
+
|
227 |
+
Simply chooses the class per file that has the highest probability overall.
|
228 |
+
"""
|
229 |
+
|
230 |
+
gt_valid = []
|
231 |
+
pred_valid = []
|
232 |
+
for ii in range(len(gts)):
|
233 |
+
gt_class = np.unique(gts[ii]['class_ids'])
|
234 |
+
if len(gt_class) == 1 and gt_class[0] != -1:
|
235 |
+
gt_valid.append(gt_class[0])
|
236 |
+
pred = preds[ii]['class_probs'][:num_classes, :].T
|
237 |
+
pred_valid.append(np.argmax(pred.mean(0)))
|
238 |
+
acc = (np.array(gt_valid) == np.array(pred_valid)).mean()
|
239 |
+
|
240 |
+
res = {}
|
241 |
+
res['num_valid_files'] = len(gt_valid)
|
242 |
+
res['num_total_files'] = len(gts)
|
243 |
+
res['gt_valid_file'] = gt_valid
|
244 |
+
res['pred_valid_file'] = pred_valid
|
245 |
+
res['file_acc'] = np.round(acc, 5)
|
246 |
+
return res
|
247 |
+
|
248 |
+
|
249 |
+
def compute_file_accuracy(gts, preds, num_classes):
|
250 |
+
"""
|
251 |
+
Evaluates the prediction accuracy at a file level.
|
252 |
+
Does not include files that have more than one class (or the unknown class).
|
253 |
+
|
254 |
+
Tries several different detection thresholds and picks the best one.
|
255 |
+
"""
|
256 |
+
|
257 |
+
# compute min and max scoring range - then threshold
|
258 |
+
min_val = 0
|
259 |
+
mins = [pp['class_probs'].min() for pp in preds if pp['class_probs'].shape[1] > 0]
|
260 |
+
if len(mins) > 0:
|
261 |
+
min_val = np.min(mins)
|
262 |
+
|
263 |
+
max_val = 1.0
|
264 |
+
maxes = [pp['class_probs'].max() for pp in preds if pp['class_probs'].shape[1] > 0]
|
265 |
+
if len(maxes) > 0:
|
266 |
+
max_val = np.max(maxes)
|
267 |
+
|
268 |
+
thresh = np.linspace(min_val, max_val, 11)[:10]
|
269 |
+
|
270 |
+
# loop over the files and store the accuracy at different prediction thresholds
|
271 |
+
# only include gt files that have one valid species
|
272 |
+
gt_valid = []
|
273 |
+
pred_valid_all = []
|
274 |
+
for ii in range(len(gts)):
|
275 |
+
gt_class = np.unique(gts[ii]['class_ids'])
|
276 |
+
if len(gt_class) == 1 and gt_class[0] != -1:
|
277 |
+
gt_valid.append(gt_class[0])
|
278 |
+
pred = preds[ii]['class_probs'][:num_classes, :].T
|
279 |
+
p_class = np.zeros(len(thresh))
|
280 |
+
for tt in range(len(thresh)):
|
281 |
+
p_class[tt] = (pred*(pred>=thresh[tt])).sum(0).argmax()
|
282 |
+
pred_valid_all.append(p_class)
|
283 |
+
|
284 |
+
# pick the result corresponding to the overall best threshold
|
285 |
+
pred_valid_all = np.vstack(pred_valid_all)
|
286 |
+
acc_per_thresh = (np.array(gt_valid)[..., np.newaxis] == pred_valid_all).mean(0)
|
287 |
+
best_thresh = np.argmax(acc_per_thresh)
|
288 |
+
best_acc = acc_per_thresh[best_thresh]
|
289 |
+
pred_valid = pred_valid_all[:, best_thresh].astype(np.int).tolist()
|
290 |
+
|
291 |
+
res = {}
|
292 |
+
res['num_valid_files'] = len(gt_valid)
|
293 |
+
res['num_total_files'] = len(gts)
|
294 |
+
res['gt_valid_file'] = gt_valid
|
295 |
+
res['pred_valid_file'] = pred_valid
|
296 |
+
res['file_acc'] = np.round(best_acc, 5)
|
297 |
+
|
298 |
+
return res
|
299 |
+
|
300 |
+
|
301 |
+
def evaluate_predictions(gts, preds, class_names, detection_overlap, ignore_start_end=0.0):
|
302 |
+
"""
|
303 |
+
Computes metrics derived from the precision and recall.
|
304 |
+
Assumes that gts and preds are both lists of the same lengths, with ground
|
305 |
+
truth and predictions contained within.
|
306 |
+
|
307 |
+
Returns the overall detection results, and per class results
|
308 |
+
"""
|
309 |
+
|
310 |
+
assert(len(gts) == len(preds))
|
311 |
+
num_classes = len(class_names)
|
312 |
+
|
313 |
+
# evaluate detection on its own i.e. ignoring class
|
314 |
+
det_results = compute_pre_rec(gts, preds, 'detection', None, num_classes, detection_overlap, ignore_start_end)
|
315 |
+
top_class = compute_pre_rec(gts, preds, 'top_class', None, num_classes, detection_overlap, ignore_start_end)
|
316 |
+
det_results['top_class'] = top_class
|
317 |
+
|
318 |
+
# per class evaluation
|
319 |
+
det_results['class_pr'] = []
|
320 |
+
for cc in range(num_classes):
|
321 |
+
res = compute_pre_rec(gts, preds, 'per_class', cc, num_classes, detection_overlap, ignore_start_end)
|
322 |
+
res['name'] = class_names[cc]
|
323 |
+
det_results['class_pr'].append(res)
|
324 |
+
|
325 |
+
# ignores classes that are not present in the test set
|
326 |
+
det_results['avg_prec_class'] = np.mean([rs['avg_prec'] for rs in det_results['class_pr'] if rs['num_gt'] > 0])
|
327 |
+
det_results['avg_prec_class'] = np.round(det_results['avg_prec_class'], 5)
|
328 |
+
|
329 |
+
# file level evaluation
|
330 |
+
res_file = compute_file_accuracy(gts, preds, num_classes)
|
331 |
+
det_results.update(res_file)
|
332 |
+
|
333 |
+
return det_results
|
bat_detect/train/losses.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
|
4 |
+
|
5 |
+
def bbox_size_loss(pred_size, gt_size):
|
6 |
+
"""
|
7 |
+
Bounding box size loss. Only compute loss where there is a bounding box.
|
8 |
+
"""
|
9 |
+
gt_size_mask = (gt_size > 0).float()
|
10 |
+
return (F.l1_loss(pred_size*gt_size_mask, gt_size, reduction='sum') / (gt_size_mask.sum() + 1e-5))
|
11 |
+
|
12 |
+
|
13 |
+
def focal_loss(pred, gt, weights=None, valid_mask=None):
|
14 |
+
"""
|
15 |
+
Focal loss adapted from CornerNet: Detecting Objects as Paired Keypoints
|
16 |
+
pred (batch x c x h x w)
|
17 |
+
gt (batch x c x h x w)
|
18 |
+
"""
|
19 |
+
eps = 1e-5
|
20 |
+
beta = 4
|
21 |
+
alpha = 2
|
22 |
+
|
23 |
+
pos_inds = gt.eq(1).float()
|
24 |
+
neg_inds = gt.lt(1).float()
|
25 |
+
|
26 |
+
pos_loss = torch.log(pred + eps) * torch.pow(1 - pred, alpha) * pos_inds
|
27 |
+
neg_loss = torch.log(1 - pred + eps) * torch.pow(pred, alpha) * torch.pow(1 - gt, beta) * neg_inds
|
28 |
+
|
29 |
+
if weights is not None:
|
30 |
+
pos_loss = pos_loss*weights
|
31 |
+
#neg_loss = neg_loss*weights
|
32 |
+
|
33 |
+
if valid_mask is not None:
|
34 |
+
pos_loss = pos_loss*valid_mask
|
35 |
+
neg_loss = neg_loss*valid_mask
|
36 |
+
|
37 |
+
pos_loss = pos_loss.sum()
|
38 |
+
neg_loss = neg_loss.sum()
|
39 |
+
|
40 |
+
num_pos = pos_inds.float().sum()
|
41 |
+
if num_pos == 0:
|
42 |
+
loss = -neg_loss
|
43 |
+
else:
|
44 |
+
loss = -(pos_loss + neg_loss) / num_pos
|
45 |
+
return loss
|
46 |
+
|
47 |
+
|
48 |
+
def mse_loss(pred, gt, weights=None, valid_mask=None):
|
49 |
+
"""
|
50 |
+
Mean squared error loss.
|
51 |
+
"""
|
52 |
+
if valid_mask is None:
|
53 |
+
op = ((gt-pred)**2).mean()
|
54 |
+
else:
|
55 |
+
op = (valid_mask*((gt-pred)**2)).sum() / valid_mask.sum()
|
56 |
+
return op
|
bat_detect/train/readme.md
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## How to train a model from scratch
|
2 |
+
`python train_model.py data_dir annotation_dir` e.g.
|
3 |
+
`python train_model.py /data1/bat_data/data/ /data1/bat_data/annotations/anns/`
|
4 |
+
|
5 |
+
More comprehensive instructions are provided in the finetune directory.
|
6 |
+
|
7 |
+
|
8 |
+
## Training on your own data
|
9 |
+
You can either use the finetuning scripts to finetune from an existing training dataset. Follow the instructions in the `../finetune/` directory.
|
10 |
+
|
11 |
+
Alternatively, you can train from scratch. First, you will need to create your own annotation file (like in the finetune example), and then you will need to edit `train_split.py` to add your new dataset and specify which combination of files you want to train on.
|
12 |
+
|
13 |
+
Note, if training from scratch and you want to include the existing data, you may need to set all the class names to the generic class name ('Bat') so that the existing species are not added to your model, but instead just used to help perform the bat/not bat task.
|
14 |
+
|
15 |
+
## Additional notes
|
16 |
+
Having blank files with no bats in them is also useful, just make sure that the annotation files lists them as not being annotated (i.e. `is_annotated=True`).
|
17 |
+
|
18 |
+
Training will be slow without a GPU.
|
bat_detect/train/train_model.py
ADDED
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import os
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
7 |
+
import json
|
8 |
+
import argparse
|
9 |
+
|
10 |
+
import sys
|
11 |
+
sys.path.append(os.path.join('..', '..'))
|
12 |
+
|
13 |
+
import bat_detect.detector.parameters as parameters
|
14 |
+
import bat_detect.detector.models as models
|
15 |
+
import bat_detect.detector.post_process as pp
|
16 |
+
import bat_detect.utils.plot_utils as pu
|
17 |
+
|
18 |
+
import bat_detect.train.audio_dataloader as adl
|
19 |
+
import bat_detect.train.evaluate as evl
|
20 |
+
import bat_detect.train.train_utils as tu
|
21 |
+
import bat_detect.train.train_split as ts
|
22 |
+
import bat_detect.train.losses as losses
|
23 |
+
|
24 |
+
import warnings
|
25 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
26 |
+
|
27 |
+
|
28 |
+
def save_images_batch(model, data_loader, params):
|
29 |
+
print('\nsaving images ...')
|
30 |
+
|
31 |
+
is_train_state = data_loader.dataset.is_train
|
32 |
+
data_loader.dataset.is_train = False
|
33 |
+
data_loader.dataset.return_spec_for_viz = True
|
34 |
+
model.eval()
|
35 |
+
|
36 |
+
ind = 0 # first image in each batch
|
37 |
+
with torch.no_grad():
|
38 |
+
for batch_idx, inputs in enumerate(data_loader):
|
39 |
+
data = inputs['spec'].to(params['device'])
|
40 |
+
outputs = model(data)
|
41 |
+
|
42 |
+
spec_viz = inputs['spec_for_viz'].data.cpu().numpy()
|
43 |
+
orig_index = inputs['file_id'][ind]
|
44 |
+
plot_title = data_loader.dataset.data_anns[orig_index]['id']
|
45 |
+
op_file_name = params['op_im_dir_test'] + data_loader.dataset.data_anns[orig_index]['id'] + '.jpg'
|
46 |
+
save_image(spec_viz, outputs, ind, inputs, params, op_file_name, plot_title)
|
47 |
+
|
48 |
+
data_loader.dataset.is_train = is_train_state
|
49 |
+
data_loader.dataset.return_spec_for_viz = False
|
50 |
+
|
51 |
+
|
52 |
+
def save_image(spec_viz, outputs, ind, inputs, params, op_file_name, plot_title):
|
53 |
+
pred_nms, _ = pp.run_nms(outputs, params, inputs['sampling_rate'].float())
|
54 |
+
pred_hm = outputs['pred_det'][ind, 0, :].data.cpu().numpy()
|
55 |
+
spec_viz = spec_viz[ind, 0, :]
|
56 |
+
gt = parse_gt_data(inputs)[ind]
|
57 |
+
sampling_rate = inputs['sampling_rate'][ind].item()
|
58 |
+
duration = inputs['duration'][ind].item()
|
59 |
+
|
60 |
+
pu.plot_spec(spec_viz, sampling_rate, duration, gt, pred_nms[ind],
|
61 |
+
params, plot_title, op_file_name, pred_hm, plot_boxes=True, fixed_aspect=False)
|
62 |
+
|
63 |
+
|
64 |
+
def loss_fun(outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq):
|
65 |
+
|
66 |
+
# detection loss
|
67 |
+
loss = params['det_loss_weight']*det_criterion(outputs['pred_det'], gt_det)
|
68 |
+
|
69 |
+
# bounding box size loss
|
70 |
+
loss += params['size_loss_weight']*losses.bbox_size_loss(outputs['pred_size'], gt_size)
|
71 |
+
|
72 |
+
# classification loss
|
73 |
+
valid_mask = (gt_class[:, :-1, :, :].sum(1) > 0).float().unsqueeze(1)
|
74 |
+
p_class = outputs['pred_class'][:, :-1, :]
|
75 |
+
loss += params['class_loss_weight']*det_criterion(p_class, gt_class[:, :-1, :], valid_mask=valid_mask)
|
76 |
+
|
77 |
+
return loss
|
78 |
+
|
79 |
+
|
80 |
+
def train(model, epoch, data_loader, det_criterion, optimizer, scheduler, params):
|
81 |
+
|
82 |
+
model.train()
|
83 |
+
|
84 |
+
train_loss = tu.AverageMeter()
|
85 |
+
class_inv_freq = torch.from_numpy(np.array(params['class_inv_freq'], dtype=np.float32)).to(params['device'])
|
86 |
+
class_inv_freq = class_inv_freq.unsqueeze(0).unsqueeze(2).unsqueeze(2)
|
87 |
+
|
88 |
+
print('\nEpoch', epoch)
|
89 |
+
for batch_idx, inputs in enumerate(data_loader):
|
90 |
+
|
91 |
+
data = inputs['spec'].to(params['device'])
|
92 |
+
gt_det = inputs['y_2d_det'].to(params['device'])
|
93 |
+
gt_size = inputs['y_2d_size'].to(params['device'])
|
94 |
+
gt_class = inputs['y_2d_classes'].to(params['device'])
|
95 |
+
|
96 |
+
optimizer.zero_grad()
|
97 |
+
outputs = model(data)
|
98 |
+
|
99 |
+
loss = loss_fun(outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq)
|
100 |
+
|
101 |
+
train_loss.update(loss.item(), data.shape[0])
|
102 |
+
loss.backward()
|
103 |
+
optimizer.step()
|
104 |
+
scheduler.step()
|
105 |
+
|
106 |
+
if batch_idx % 50 == 0 and batch_idx != 0:
|
107 |
+
print('[{}/{}]\tLoss: {:.4f}'.format(
|
108 |
+
batch_idx * len(data), len(data_loader.dataset), train_loss.avg))
|
109 |
+
|
110 |
+
print('Train loss : {:.4f}'.format(train_loss.avg))
|
111 |
+
|
112 |
+
res = {}
|
113 |
+
res['train_loss'] = float(train_loss.avg)
|
114 |
+
return res
|
115 |
+
|
116 |
+
|
117 |
+
def test(model, epoch, data_loader, det_criterion, params):
|
118 |
+
model.eval()
|
119 |
+
predictions = []
|
120 |
+
ground_truths = []
|
121 |
+
test_loss = tu.AverageMeter()
|
122 |
+
|
123 |
+
class_inv_freq = torch.from_numpy(np.array(params['class_inv_freq'], dtype=np.float32)).to(params['device'])
|
124 |
+
class_inv_freq = class_inv_freq.unsqueeze(0).unsqueeze(2).unsqueeze(2)
|
125 |
+
|
126 |
+
with torch.no_grad():
|
127 |
+
for batch_idx, inputs in enumerate(data_loader):
|
128 |
+
|
129 |
+
data = inputs['spec'].to(params['device'])
|
130 |
+
gt_det = inputs['y_2d_det'].to(params['device'])
|
131 |
+
gt_size = inputs['y_2d_size'].to(params['device'])
|
132 |
+
gt_class = inputs['y_2d_classes'].to(params['device'])
|
133 |
+
|
134 |
+
outputs = model(data)
|
135 |
+
|
136 |
+
# if the model needs a fixed sized intput run this
|
137 |
+
# data = torch.cat(torch.split(data, int(params['spec_train_width']*params['resize_factor']), 3), 0)
|
138 |
+
# outputs = model(data)
|
139 |
+
# for kk in ['pred_det', 'pred_size', 'pred_class']:
|
140 |
+
# outputs[kk] = torch.cat([oo for oo in outputs[kk]], 2).unsqueeze(0)
|
141 |
+
|
142 |
+
if params['save_test_image_during_train'] and batch_idx == 0:
|
143 |
+
# for visualization - save the first prediction
|
144 |
+
ind = 0
|
145 |
+
orig_index = inputs['file_id'][ind]
|
146 |
+
plot_title = data_loader.dataset.data_anns[orig_index]['id']
|
147 |
+
op_file_name = params['op_im_dir'] + str(orig_index.item()).zfill(4) + '_' + str(epoch).zfill(4) + '_pred.jpg'
|
148 |
+
save_image(data, outputs, ind, inputs, params, op_file_name, plot_title)
|
149 |
+
|
150 |
+
loss = loss_fun(outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq)
|
151 |
+
test_loss.update(loss.item(), data.shape[0])
|
152 |
+
|
153 |
+
# do NMS
|
154 |
+
pred_nms, _ = pp.run_nms(outputs, params, inputs['sampling_rate'].float())
|
155 |
+
predictions.extend(pred_nms)
|
156 |
+
|
157 |
+
ground_truths.extend(parse_gt_data(inputs))
|
158 |
+
|
159 |
+
res_det = evl.evaluate_predictions(ground_truths, predictions, params['class_names'],
|
160 |
+
params['detection_overlap'], params['ignore_start_end'])
|
161 |
+
|
162 |
+
print('\nTest loss : {:.4f}'.format(test_loss.avg))
|
163 |
+
print('Rec at 0.95 (det) : {:.4f}'.format(res_det['rec_at_x']))
|
164 |
+
print('Avg prec (cls) : {:.4f}'.format(res_det['avg_prec']))
|
165 |
+
print('File acc (cls) : {:.2f} - for {} out of {}'.format(res_det['file_acc'],
|
166 |
+
res_det['num_valid_files'], res_det['num_total_files']))
|
167 |
+
print('Cls Avg prec (cls) : {:.4f}'.format(res_det['avg_prec_class']))
|
168 |
+
|
169 |
+
print('\nPer class average precision')
|
170 |
+
str_len = np.max([len(rs['name']) for rs in res_det['class_pr']]) + 5
|
171 |
+
for cc, rs in enumerate(res_det['class_pr']):
|
172 |
+
if rs['num_gt'] > 0:
|
173 |
+
print(str(cc).ljust(5) + rs['name'].ljust(str_len) + '{:.4f}'.format(rs['avg_prec']))
|
174 |
+
|
175 |
+
res = {}
|
176 |
+
res['test_loss'] = float(test_loss.avg)
|
177 |
+
|
178 |
+
return res_det, res
|
179 |
+
|
180 |
+
|
181 |
+
def parse_gt_data(inputs):
|
182 |
+
# reads the torch arrays into a dictionary of numpy arrays, taking care to
|
183 |
+
# remove padding data i.e. not valid ones
|
184 |
+
keys = ['start_times', 'end_times', 'low_freqs', 'high_freqs', 'class_ids', 'individual_ids']
|
185 |
+
batch_data = []
|
186 |
+
for ind in range(inputs['start_times'].shape[0]):
|
187 |
+
is_valid = inputs['is_valid'][ind]==1
|
188 |
+
gt = {}
|
189 |
+
for kk in keys:
|
190 |
+
gt[kk] = inputs[kk][ind][is_valid].numpy().astype(np.float32)
|
191 |
+
gt['duration'] = inputs['duration'][ind].item()
|
192 |
+
gt['file_id'] = inputs['file_id'][ind].item()
|
193 |
+
gt['class_id_file'] = inputs['class_id_file'][ind].item()
|
194 |
+
batch_data.append(gt)
|
195 |
+
return batch_data
|
196 |
+
|
197 |
+
|
198 |
+
def select_model(params):
|
199 |
+
num_classes = len(params['class_names'])
|
200 |
+
if params['model_name'] == 'Net2DFast':
|
201 |
+
model = models.Net2DFast(params['num_filters'], num_classes=num_classes,
|
202 |
+
emb_dim=params['emb_dim'], ip_height=params['ip_height'],
|
203 |
+
resize_factor=params['resize_factor'])
|
204 |
+
elif params['model_name'] == 'Net2DFastNoAttn':
|
205 |
+
model = models.Net2DFastNoAttn(params['num_filters'], num_classes=num_classes,
|
206 |
+
emb_dim=params['emb_dim'], ip_height=params['ip_height'],
|
207 |
+
resize_factor=params['resize_factor'])
|
208 |
+
elif params['model_name'] == 'Net2DFastNoCoordConv':
|
209 |
+
model = models.Net2DFastNoCoordConv(params['num_filters'], num_classes=num_classes,
|
210 |
+
emb_dim=params['emb_dim'], ip_height=params['ip_height'],
|
211 |
+
resize_factor=params['resize_factor'])
|
212 |
+
else:
|
213 |
+
print('No valid network specified')
|
214 |
+
return model
|
215 |
+
|
216 |
+
|
217 |
+
if __name__ == "__main__":
|
218 |
+
|
219 |
+
plt.close('all')
|
220 |
+
|
221 |
+
params = parameters.get_params(True)
|
222 |
+
|
223 |
+
if torch.cuda.is_available():
|
224 |
+
params['device'] = 'cuda'
|
225 |
+
else:
|
226 |
+
params['device'] = 'cpu'
|
227 |
+
|
228 |
+
# setup arg parser and populate it with exiting parameters - will not work with lists
|
229 |
+
parser = argparse.ArgumentParser()
|
230 |
+
parser.add_argument('data_dir', type=str,
|
231 |
+
help='Path to root of datasets')
|
232 |
+
parser.add_argument('ann_dir', type=str,
|
233 |
+
help='Path to extracted annotations')
|
234 |
+
parser.add_argument('--train_split', type=str, default='diff', # diff, same
|
235 |
+
help='Which train split to use')
|
236 |
+
parser.add_argument('--notes', type=str, default='',
|
237 |
+
help='Notes to save in text file')
|
238 |
+
parser.add_argument('--do_not_save_images', action='store_false',
|
239 |
+
help='Do not save images at the end of training')
|
240 |
+
parser.add_argument('--standardize_classs_names_ip', type=str,
|
241 |
+
default='Rhinolophus ferrumequinum;Rhinolophus hipposideros',
|
242 |
+
help='Will set low and high frequency the same for these classes. Separate names with ";"')
|
243 |
+
for key, val in params.items():
|
244 |
+
parser.add_argument('--'+key, type=type(val), default=val)
|
245 |
+
params = vars(parser.parse_args())
|
246 |
+
|
247 |
+
# save notes file
|
248 |
+
if params['notes'] != '':
|
249 |
+
tu.write_notes_file(params['experiment'] + 'notes.txt', params['notes'])
|
250 |
+
|
251 |
+
# load the training and test meta data - there are different splits defined
|
252 |
+
train_sets, test_sets = ts.get_train_test_data(params['ann_dir'], params['data_dir'], params['train_split'])
|
253 |
+
train_sets_no_path, test_sets_no_path = ts.get_train_test_data('', '', params['train_split'])
|
254 |
+
|
255 |
+
# keep track of what we have trained on
|
256 |
+
params['train_sets'] = train_sets_no_path
|
257 |
+
params['test_sets'] = test_sets_no_path
|
258 |
+
|
259 |
+
# load train annotations - merge them all together
|
260 |
+
print('\nTraining on:')
|
261 |
+
for tt in train_sets:
|
262 |
+
print(tt['ann_path'])
|
263 |
+
classes_to_ignore = params['classes_to_ignore']+params['generic_class']
|
264 |
+
data_train, params['class_names'], params['class_inv_freq'] = \
|
265 |
+
tu.load_set_of_anns(train_sets, classes_to_ignore, params['events_of_interest'], params['convert_to_genus'])
|
266 |
+
params['genus_names'], params['genus_mapping'] = tu.get_genus_mapping(params['class_names'])
|
267 |
+
params['class_names_short'] = tu.get_short_class_names(params['class_names'])
|
268 |
+
|
269 |
+
# standardize the low and high frequency value for specified classes
|
270 |
+
params['standardize_classs_names'] = params['standardize_classs_names_ip'].split(';')
|
271 |
+
for cc in params['standardize_classs_names']:
|
272 |
+
if cc in params['class_names']:
|
273 |
+
data_train = tu.standardize_low_freq(data_train, cc)
|
274 |
+
else:
|
275 |
+
print(cc, 'not found')
|
276 |
+
|
277 |
+
# train loader
|
278 |
+
train_dataset = adl.AudioLoader(data_train, params, is_train=True)
|
279 |
+
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=params['batch_size'],
|
280 |
+
shuffle=True, num_workers=params['num_workers'], pin_memory=True)
|
281 |
+
|
282 |
+
|
283 |
+
# test set
|
284 |
+
print('\nTesting on:')
|
285 |
+
for tt in test_sets:
|
286 |
+
print(tt['ann_path'])
|
287 |
+
data_test, _, _ = tu.load_set_of_anns(test_sets, classes_to_ignore, params['events_of_interest'], params['convert_to_genus'])
|
288 |
+
data_train = tu.remove_dupes(data_train, data_test)
|
289 |
+
test_dataset = adl.AudioLoader(data_test, params, is_train=False)
|
290 |
+
# batch size of 1 because of variable file length
|
291 |
+
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1,
|
292 |
+
shuffle=False, num_workers=params['num_workers'], pin_memory=True)
|
293 |
+
|
294 |
+
|
295 |
+
inputs_train = next(iter(train_loader))
|
296 |
+
# TODO remove params['ip_height'], this is just legacy
|
297 |
+
params['ip_height'] = int(params['spec_height']*params['resize_factor'])
|
298 |
+
print('\ntrain batch spec size :', inputs_train['spec'].shape)
|
299 |
+
print('class target size :', inputs_train['y_2d_classes'].shape)
|
300 |
+
|
301 |
+
# select network
|
302 |
+
model = select_model(params)
|
303 |
+
model = model.to(params['device'])
|
304 |
+
|
305 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=params['lr'])
|
306 |
+
#optimizer = torch.optim.SGD(model.parameters(), lr=params['lr'], momentum=0.9)
|
307 |
+
scheduler = CosineAnnealingLR(optimizer, params['num_epochs'] * len(train_loader))
|
308 |
+
if params['train_loss'] == 'mse':
|
309 |
+
det_criterion = losses.mse_loss
|
310 |
+
elif params['train_loss'] == 'focal':
|
311 |
+
det_criterion = losses.focal_loss
|
312 |
+
|
313 |
+
# save parameters to file
|
314 |
+
with open(params['experiment'] + 'params.json', 'w') as da:
|
315 |
+
json.dump(params, da, indent=2, sort_keys=True)
|
316 |
+
|
317 |
+
# plotting
|
318 |
+
train_plt_ls = pu.LossPlotter(params['experiment'] + 'train_loss.png', params['num_epochs']+1,
|
319 |
+
['train_loss'], None, None, ['epoch', 'train_loss'], logy=True)
|
320 |
+
test_plt_ls = pu.LossPlotter(params['experiment'] + 'test_loss.png', params['num_epochs']+1,
|
321 |
+
['test_loss'], None, None, ['epoch', 'test_loss'], logy=True)
|
322 |
+
test_plt = pu.LossPlotter(params['experiment'] + 'test.png', params['num_epochs']+1,
|
323 |
+
['avg_prec', 'rec_at_x', 'avg_prec_class', 'file_acc', 'top_class'], [0,1], None, ['epoch', ''])
|
324 |
+
test_plt_class = pu.LossPlotter(params['experiment'] + 'test_avg_prec.png', params['num_epochs']+1,
|
325 |
+
params['class_names_short'], [0,1], params['class_names_short'], ['epoch', 'avg_prec'])
|
326 |
+
|
327 |
+
|
328 |
+
#
|
329 |
+
# main train loop
|
330 |
+
for epoch in range(0, params['num_epochs']+1):
|
331 |
+
|
332 |
+
train_loss = train(model, epoch, train_loader, det_criterion, optimizer, scheduler, params)
|
333 |
+
train_plt_ls.update_and_save(epoch, [train_loss['train_loss']])
|
334 |
+
|
335 |
+
if epoch % params['num_eval_epochs'] == 0:
|
336 |
+
# detection accuracy on test set
|
337 |
+
test_res, test_loss = test(model, epoch, test_loader, det_criterion, params)
|
338 |
+
test_plt_ls.update_and_save(epoch, [test_loss['test_loss']])
|
339 |
+
test_plt.update_and_save(epoch, [test_res['avg_prec'], test_res['rec_at_x'],
|
340 |
+
test_res['avg_prec_class'], test_res['file_acc'], test_res['top_class']['avg_prec']])
|
341 |
+
test_plt_class.update_and_save(epoch, [rs['avg_prec'] for rs in test_res['class_pr']])
|
342 |
+
pu.plot_pr_curve_class(params['experiment'] , 'test_pr', 'test_pr', test_res)
|
343 |
+
|
344 |
+
|
345 |
+
# save trained model
|
346 |
+
print('saving model to: ' + params['model_file_name'])
|
347 |
+
op_state = {'epoch': epoch + 1,
|
348 |
+
'state_dict': model.state_dict(),
|
349 |
+
#'optimizer' : optimizer.state_dict(),
|
350 |
+
'params' : params}
|
351 |
+
torch.save(op_state, params['model_file_name'])
|
352 |
+
|
353 |
+
|
354 |
+
# save an image with associated prediction for each batch in the test set
|
355 |
+
if not args['do_not_save_images']:
|
356 |
+
save_images_batch(model, test_loader, params)
|
bat_detect/train/train_split.py
ADDED
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Run scripts/extract_anns.py to generate these json files.
|
3 |
+
"""
|
4 |
+
|
5 |
+
def get_train_test_data(ann_dir, wav_dir, split_name, load_extra=True):
|
6 |
+
if split_name == 'diff':
|
7 |
+
train_sets, test_sets = split_diff(ann_dir, wav_dir, load_extra)
|
8 |
+
elif split_name == 'same':
|
9 |
+
train_sets, test_sets = split_same(ann_dir, wav_dir, load_extra)
|
10 |
+
else:
|
11 |
+
print('Split not defined')
|
12 |
+
assert False
|
13 |
+
|
14 |
+
return train_sets, test_sets
|
15 |
+
|
16 |
+
|
17 |
+
def split_diff(ann_dir, wav_dir, load_extra=True):
|
18 |
+
|
19 |
+
train_sets = []
|
20 |
+
if load_extra:
|
21 |
+
train_sets.append({'dataset_name': 'BatDetective',
|
22 |
+
'is_test': False,
|
23 |
+
'is_binary': True, # just a bat / not bat dataset ie no classes
|
24 |
+
'ann_path': ann_dir + 'train_set_bulgaria_batdetective_with_bbs.json',
|
25 |
+
'wav_path': wav_dir + 'bat_detective/audio/'})
|
26 |
+
train_sets.append({'dataset_name': 'bat_logger_qeop_empty',
|
27 |
+
'is_test': False,
|
28 |
+
'is_binary': True,
|
29 |
+
'ann_path': ann_dir + 'bat_logger_qeop_empty.json',
|
30 |
+
'wav_path': wav_dir + 'bat_logger_qeop_empty/audio/'})
|
31 |
+
train_sets.append({'dataset_name': 'bat_logger_2016_empty',
|
32 |
+
'is_test': False,
|
33 |
+
'is_binary': True,
|
34 |
+
'ann_path': ann_dir + 'train_set_bat_logger_2016_empty.json',
|
35 |
+
'wav_path': wav_dir + 'bat_logger_2016/audio/'})
|
36 |
+
# train_sets.append({'dataset_name': 'brazil_data_binary',
|
37 |
+
# 'is_test': False,
|
38 |
+
# 'ann_path': ann_dir + 'brazil_data_binary.json',
|
39 |
+
# 'wav_path': wav_dir + 'brazil_data/audio/'})
|
40 |
+
|
41 |
+
train_sets.append({'dataset_name': 'echobank',
|
42 |
+
'is_test': False,
|
43 |
+
'is_binary': False,
|
44 |
+
'ann_path': ann_dir + 'Echobank_train_expert.json',
|
45 |
+
'wav_path': wav_dir + 'echobank/audio/'})
|
46 |
+
train_sets.append({'dataset_name': 'sn_scot_nor',
|
47 |
+
'is_test': False,
|
48 |
+
'is_binary': False,
|
49 |
+
'ann_path': ann_dir + 'sn_scot_nor_0.5_expert.json',
|
50 |
+
'wav_path': wav_dir + 'sn_scot_nor/audio/'})
|
51 |
+
train_sets.append({'dataset_name': 'BCT_1_sec',
|
52 |
+
'is_test': False,
|
53 |
+
'is_binary': False,
|
54 |
+
'ann_path': ann_dir + 'BCT_1_sec_train_expert.json',
|
55 |
+
'wav_path': wav_dir + 'BCT_1_sec/audio/'})
|
56 |
+
train_sets.append({'dataset_name': 'bcireland',
|
57 |
+
'is_test': False,
|
58 |
+
'is_binary': False,
|
59 |
+
'ann_path': ann_dir + 'bcireland_expert.json',
|
60 |
+
'wav_path': wav_dir + 'bcireland/audio/'})
|
61 |
+
train_sets.append({'dataset_name': 'rhinolophus_steve_BCT',
|
62 |
+
'is_test': False,
|
63 |
+
'is_binary': False,
|
64 |
+
'ann_path': ann_dir + 'rhinolophus_steve_BCT_expert.json',
|
65 |
+
'wav_path': wav_dir + 'rhinolophus_steve_BCT/audio/'})
|
66 |
+
|
67 |
+
test_sets = []
|
68 |
+
test_sets.append({'dataset_name': 'bat_data_martyn_2018',
|
69 |
+
'is_test': True,
|
70 |
+
'is_binary': False,
|
71 |
+
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2018_1_sec_train_expert.json',
|
72 |
+
'wav_path': wav_dir + 'bat_data_martyn_2018/audio/'})
|
73 |
+
test_sets.append({'dataset_name': 'bat_data_martyn_2018_test',
|
74 |
+
'is_test': True,
|
75 |
+
'is_binary': False,
|
76 |
+
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2018_1_sec_test_expert.json',
|
77 |
+
'wav_path': wav_dir + 'bat_data_martyn_2018_test/audio/'})
|
78 |
+
test_sets.append({'dataset_name': 'bat_data_martyn_2019',
|
79 |
+
'is_test': True,
|
80 |
+
'is_binary': False,
|
81 |
+
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2019_1_sec_train_expert.json',
|
82 |
+
'wav_path': wav_dir + 'bat_data_martyn_2019/audio/'})
|
83 |
+
test_sets.append({'dataset_name': 'bat_data_martyn_2019_test',
|
84 |
+
'is_test': True,
|
85 |
+
'is_binary': False,
|
86 |
+
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2019_1_sec_test_expert.json',
|
87 |
+
'wav_path': wav_dir + 'bat_data_martyn_2019_test/audio/'})
|
88 |
+
|
89 |
+
return train_sets, test_sets
|
90 |
+
|
91 |
+
|
92 |
+
def split_same(ann_dir, wav_dir, load_extra=True):
|
93 |
+
|
94 |
+
train_sets = []
|
95 |
+
if load_extra:
|
96 |
+
train_sets.append({'dataset_name': 'BatDetective',
|
97 |
+
'is_test': False,
|
98 |
+
'is_binary': True,
|
99 |
+
'ann_path': ann_dir + 'train_set_bulgaria_batdetective_with_bbs.json',
|
100 |
+
'wav_path': wav_dir + 'bat_detective/audio/'})
|
101 |
+
train_sets.append({'dataset_name': 'bat_logger_qeop_empty',
|
102 |
+
'is_test': False,
|
103 |
+
'is_binary': True,
|
104 |
+
'ann_path': ann_dir + 'bat_logger_qeop_empty.json',
|
105 |
+
'wav_path': wav_dir + 'bat_logger_qeop_empty/audio/'})
|
106 |
+
train_sets.append({'dataset_name': 'bat_logger_2016_empty',
|
107 |
+
'is_test': False,
|
108 |
+
'is_binary': True,
|
109 |
+
'ann_path': ann_dir + 'train_set_bat_logger_2016_empty.json',
|
110 |
+
'wav_path': wav_dir + 'bat_logger_2016/audio/'})
|
111 |
+
# train_sets.append({'dataset_name': 'brazil_data_binary',
|
112 |
+
# 'is_test': False,
|
113 |
+
# 'ann_path': ann_dir + 'brazil_data_binary.json',
|
114 |
+
# 'wav_path': wav_dir + 'brazil_data/audio/'})
|
115 |
+
|
116 |
+
train_sets.append({'dataset_name': 'echobank',
|
117 |
+
'is_test': False,
|
118 |
+
'is_binary': False,
|
119 |
+
'ann_path': ann_dir + 'Echobank_train_expert_TRAIN.json',
|
120 |
+
'wav_path': wav_dir + 'echobank/audio/'})
|
121 |
+
train_sets.append({'dataset_name': 'sn_scot_nor',
|
122 |
+
'is_test': False,
|
123 |
+
'is_binary': False,
|
124 |
+
'ann_path': ann_dir + 'sn_scot_nor_0.5_expert_TRAIN.json',
|
125 |
+
'wav_path': wav_dir + 'sn_scot_nor/audio/'})
|
126 |
+
train_sets.append({'dataset_name': 'BCT_1_sec',
|
127 |
+
'is_test': False,
|
128 |
+
'is_binary': False,
|
129 |
+
'ann_path': ann_dir + 'BCT_1_sec_train_expert_TRAIN.json',
|
130 |
+
'wav_path': wav_dir + 'BCT_1_sec/audio/'})
|
131 |
+
train_sets.append({'dataset_name': 'bcireland',
|
132 |
+
'is_test': False,
|
133 |
+
'is_binary': False,
|
134 |
+
'ann_path': ann_dir + 'bcireland_expert_TRAIN.json',
|
135 |
+
'wav_path': wav_dir + 'bcireland/audio/'})
|
136 |
+
train_sets.append({'dataset_name': 'rhinolophus_steve_BCT',
|
137 |
+
'is_test': False,
|
138 |
+
'is_binary': False,
|
139 |
+
'ann_path': ann_dir + 'rhinolophus_steve_BCT_expert_TRAIN.json',
|
140 |
+
'wav_path': wav_dir + 'rhinolophus_steve_BCT/audio/'})
|
141 |
+
train_sets.append({'dataset_name': 'bat_data_martyn_2018',
|
142 |
+
'is_test': False,
|
143 |
+
'is_binary': False,
|
144 |
+
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2018_1_sec_train_expert_TRAIN.json',
|
145 |
+
'wav_path': wav_dir + 'bat_data_martyn_2018/audio/'})
|
146 |
+
train_sets.append({'dataset_name': 'bat_data_martyn_2018_test',
|
147 |
+
'is_test': False,
|
148 |
+
'is_binary': False,
|
149 |
+
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2018_1_sec_test_expert_TRAIN.json',
|
150 |
+
'wav_path': wav_dir + 'bat_data_martyn_2018_test/audio/'})
|
151 |
+
train_sets.append({'dataset_name': 'bat_data_martyn_2019',
|
152 |
+
'is_test': False,
|
153 |
+
'is_binary': False,
|
154 |
+
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2019_1_sec_train_expert_TRAIN.json',
|
155 |
+
'wav_path': wav_dir + 'bat_data_martyn_2019/audio/'})
|
156 |
+
train_sets.append({'dataset_name': 'bat_data_martyn_2019_test',
|
157 |
+
'is_test': False,
|
158 |
+
'is_binary': False,
|
159 |
+
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2019_1_sec_test_expert_TRAIN.json',
|
160 |
+
'wav_path': wav_dir + 'bat_data_martyn_2019_test/audio/'})
|
161 |
+
|
162 |
+
# train_sets.append({'dataset_name': 'bat_data_martyn_2021_train',
|
163 |
+
# 'is_test': False,
|
164 |
+
# 'is_binary': False,
|
165 |
+
# 'ann_path': ann_dir + 'bat_data_martyn_2021_TRAIN.json',
|
166 |
+
# 'wav_path': wav_dir + 'bat_data_martyn_2021/audio/'})
|
167 |
+
# train_sets.append({'dataset_name': 'volunteers_2021_train',
|
168 |
+
# 'is_test': False,
|
169 |
+
# 'is_binary': False,
|
170 |
+
# 'ann_path': ann_dir + 'volunteers_2021_TRAIN.json',
|
171 |
+
# 'wav_path': wav_dir + 'volunteers_2021/audio/'})
|
172 |
+
|
173 |
+
test_sets = []
|
174 |
+
test_sets.append({'dataset_name': 'echobank',
|
175 |
+
'is_test': True,
|
176 |
+
'is_binary': False,
|
177 |
+
'ann_path': ann_dir + 'Echobank_train_expert_TEST.json',
|
178 |
+
'wav_path': wav_dir + 'echobank/audio/'})
|
179 |
+
test_sets.append({'dataset_name': 'sn_scot_nor',
|
180 |
+
'is_test': True,
|
181 |
+
'is_binary': False,
|
182 |
+
'ann_path': ann_dir + 'sn_scot_nor_0.5_expert_TEST.json',
|
183 |
+
'wav_path': wav_dir + 'sn_scot_nor/audio/'})
|
184 |
+
test_sets.append({'dataset_name': 'BCT_1_sec',
|
185 |
+
'is_test': True,
|
186 |
+
'is_binary': False,
|
187 |
+
'ann_path': ann_dir + 'BCT_1_sec_train_expert_TEST.json',
|
188 |
+
'wav_path': wav_dir + 'BCT_1_sec/audio/'})
|
189 |
+
test_sets.append({'dataset_name': 'bcireland',
|
190 |
+
'is_test': True,
|
191 |
+
'is_binary': False,
|
192 |
+
'ann_path': ann_dir + 'bcireland_expert_TEST.json',
|
193 |
+
'wav_path': wav_dir + 'bcireland/audio/'})
|
194 |
+
test_sets.append({'dataset_name': 'rhinolophus_steve_BCT',
|
195 |
+
'is_test': True,
|
196 |
+
'is_binary': False,
|
197 |
+
'ann_path': ann_dir + 'rhinolophus_steve_BCT_expert_TEST.json',
|
198 |
+
'wav_path': wav_dir + 'rhinolophus_steve_BCT/audio/'})
|
199 |
+
test_sets.append({'dataset_name': 'bat_data_martyn_2018',
|
200 |
+
'is_test': True,
|
201 |
+
'is_binary': False,
|
202 |
+
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2018_1_sec_train_expert_TEST.json',
|
203 |
+
'wav_path': wav_dir + 'bat_data_martyn_2018/audio/'})
|
204 |
+
test_sets.append({'dataset_name': 'bat_data_martyn_2018_test',
|
205 |
+
'is_test': True,
|
206 |
+
'is_binary': False,
|
207 |
+
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2018_1_sec_test_expert_TEST.json',
|
208 |
+
'wav_path': wav_dir + 'bat_data_martyn_2018_test/audio/'})
|
209 |
+
test_sets.append({'dataset_name': 'bat_data_martyn_2019',
|
210 |
+
'is_test': True,
|
211 |
+
'is_binary': False,
|
212 |
+
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2019_1_sec_train_expert_TEST.json',
|
213 |
+
'wav_path': wav_dir + 'bat_data_martyn_2019/audio/'})
|
214 |
+
test_sets.append({'dataset_name': 'bat_data_martyn_2019_test',
|
215 |
+
'is_test': True,
|
216 |
+
'is_binary': False,
|
217 |
+
'ann_path': ann_dir + 'BritishBatCalls_MartynCooke_2019_1_sec_test_expert_TEST.json',
|
218 |
+
'wav_path': wav_dir + 'bat_data_martyn_2019_test/audio/'})
|
219 |
+
|
220 |
+
# test_sets.append({'dataset_name': 'bat_data_martyn_2021_test',
|
221 |
+
# 'is_test': True,
|
222 |
+
# 'is_binary': False,
|
223 |
+
# 'ann_path': ann_dir + 'bat_data_martyn_2021_TEST.json',
|
224 |
+
# 'wav_path': wav_dir + 'bat_data_martyn_2021/audio/'})
|
225 |
+
# test_sets.append({'dataset_name': 'volunteers_2021_test',
|
226 |
+
# 'is_test': True,
|
227 |
+
# 'is_binary': False,
|
228 |
+
# 'ann_path': ann_dir + 'volunteers_2021_TEST.json',
|
229 |
+
# 'wav_path': wav_dir + 'volunteers_2021/audio/'})
|
230 |
+
|
231 |
+
return train_sets, test_sets
|
bat_detect/train/train_utils.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import random
|
3 |
+
import os
|
4 |
+
import glob
|
5 |
+
import json
|
6 |
+
|
7 |
+
|
8 |
+
def write_notes_file(file_name, text):
|
9 |
+
with open(file_name, 'a') as da:
|
10 |
+
da.write(text + '\n')
|
11 |
+
|
12 |
+
|
13 |
+
def get_blank_dataset_dict(dataset_name, is_test, ann_path, wav_path):
|
14 |
+
ddict = {'dataset_name': dataset_name, 'is_test': is_test, 'is_binary': False,
|
15 |
+
'ann_path': ann_path, 'wav_path': wav_path}
|
16 |
+
return ddict
|
17 |
+
|
18 |
+
|
19 |
+
def get_short_class_names(class_names, str_len=3):
|
20 |
+
class_names_short = []
|
21 |
+
for cc in class_names:
|
22 |
+
class_names_short.append(' '.join([sp[:str_len] for sp in cc.split(' ')]))
|
23 |
+
return class_names_short
|
24 |
+
|
25 |
+
|
26 |
+
def remove_dupes(data_train, data_test):
|
27 |
+
test_ids = [dd['id'] for dd in data_test]
|
28 |
+
data_train_prune = []
|
29 |
+
for aa in data_train:
|
30 |
+
if aa['id'] not in test_ids:
|
31 |
+
data_train_prune.append(aa)
|
32 |
+
diff = len(data_train) - len(data_train_prune)
|
33 |
+
if diff != 0:
|
34 |
+
print(diff, 'items removed from train set')
|
35 |
+
return data_train_prune
|
36 |
+
|
37 |
+
|
38 |
+
def get_genus_mapping(class_names):
|
39 |
+
genus_names, genus_mapping = np.unique([cc.split(' ')[0] for cc in class_names], return_inverse=True)
|
40 |
+
return genus_names.tolist(), genus_mapping.tolist()
|
41 |
+
|
42 |
+
|
43 |
+
def standardize_low_freq(data, class_of_interest):
|
44 |
+
# address the issue of highly variable low frequency annotations
|
45 |
+
# this often happens for contstant frequency calls
|
46 |
+
# for the class of interest sets the low and high freq to be the dataset mean
|
47 |
+
low_freqs = []
|
48 |
+
high_freqs = []
|
49 |
+
for dd in data:
|
50 |
+
for aa in dd['annotation']:
|
51 |
+
if aa['class'] == class_of_interest:
|
52 |
+
low_freqs.append(aa['low_freq'])
|
53 |
+
high_freqs.append(aa['high_freq'])
|
54 |
+
|
55 |
+
low_mean = np.mean(low_freqs)
|
56 |
+
high_mean = np.mean(high_freqs)
|
57 |
+
assert(low_mean < high_mean)
|
58 |
+
|
59 |
+
print('\nStandardizing low and high frequency for:')
|
60 |
+
print(class_of_interest)
|
61 |
+
print('low: ', round(low_mean, 2))
|
62 |
+
print('high: ', round(high_mean, 2))
|
63 |
+
|
64 |
+
# only set the low freq, high stays the same
|
65 |
+
# assumes that low_mean < high_mean
|
66 |
+
for dd in data:
|
67 |
+
for aa in dd['annotation']:
|
68 |
+
if aa['class'] == class_of_interest:
|
69 |
+
aa['low_freq'] = low_mean
|
70 |
+
if aa['high_freq'] < low_mean:
|
71 |
+
aa['high_freq'] = high_mean
|
72 |
+
|
73 |
+
return data
|
74 |
+
|
75 |
+
|
76 |
+
def load_set_of_anns(data, classes_to_ignore=[], events_of_interest=None,
|
77 |
+
convert_to_genus=False, verbose=True, list_of_anns=False,
|
78 |
+
filter_issues=False, name_replace=False):
|
79 |
+
|
80 |
+
# load the annotations
|
81 |
+
anns = []
|
82 |
+
if list_of_anns:
|
83 |
+
# path to list of individual json files
|
84 |
+
anns.extend(load_anns_from_path(data['ann_path'], data['wav_path']))
|
85 |
+
else:
|
86 |
+
# dictionary of datasets
|
87 |
+
for dd in data:
|
88 |
+
anns.extend(load_anns(dd['ann_path'], dd['wav_path']))
|
89 |
+
|
90 |
+
# discarding unannoated files
|
91 |
+
anns = [aa for aa in anns if aa['annotated'] is True]
|
92 |
+
|
93 |
+
# filter files that have annotation issues - is the input is a dictionary of
|
94 |
+
# datasets, this will lilely have already been done
|
95 |
+
if filter_issues:
|
96 |
+
anns = [aa for aa in anns if aa['issues'] is False]
|
97 |
+
|
98 |
+
# check for some basic formatting errors with class names
|
99 |
+
for ann in anns:
|
100 |
+
for aa in ann['annotation']:
|
101 |
+
aa['class'] = aa['class'].strip()
|
102 |
+
|
103 |
+
# only load specified events - i.e. types of calls
|
104 |
+
if events_of_interest is not None:
|
105 |
+
for ann in anns:
|
106 |
+
filtered_events = []
|
107 |
+
for aa in ann['annotation']:
|
108 |
+
if aa['event'] in events_of_interest:
|
109 |
+
filtered_events.append(aa)
|
110 |
+
ann['annotation'] = filtered_events
|
111 |
+
|
112 |
+
# change class names
|
113 |
+
# replace_names will be a dictionary mapping input name to output
|
114 |
+
if type(name_replace) is dict:
|
115 |
+
for ann in anns:
|
116 |
+
for aa in ann['annotation']:
|
117 |
+
if aa['class'] in name_replace:
|
118 |
+
aa['class'] = name_replace[aa['class']]
|
119 |
+
|
120 |
+
# convert everything to genus name
|
121 |
+
if convert_to_genus:
|
122 |
+
for ann in anns:
|
123 |
+
for aa in ann['annotation']:
|
124 |
+
aa['class'] = aa['class'].split(' ')[0]
|
125 |
+
|
126 |
+
# get unique class names
|
127 |
+
class_names_all = []
|
128 |
+
for ann in anns:
|
129 |
+
for aa in ann['annotation']:
|
130 |
+
if aa['class'] not in classes_to_ignore:
|
131 |
+
class_names_all.append(aa['class'])
|
132 |
+
|
133 |
+
class_names, class_cnts = np.unique(class_names_all, return_counts=True)
|
134 |
+
class_inv_freq = (class_cnts.sum() / (len(class_names) * class_cnts.astype(np.float32)))
|
135 |
+
|
136 |
+
if verbose:
|
137 |
+
print('Class count:')
|
138 |
+
str_len = np.max([len(cc) for cc in class_names]) + 5
|
139 |
+
for cc in range(len(class_names)):
|
140 |
+
print(str(cc).ljust(5) + class_names[cc].ljust(str_len) + str(class_cnts[cc]))
|
141 |
+
|
142 |
+
if len(classes_to_ignore) == 0:
|
143 |
+
return anns
|
144 |
+
else:
|
145 |
+
return anns, class_names.tolist(), class_inv_freq.tolist()
|
146 |
+
|
147 |
+
|
148 |
+
def load_anns(ann_file_name, raw_audio_dir):
|
149 |
+
with open(ann_file_name) as da:
|
150 |
+
anns = json.load(da)
|
151 |
+
|
152 |
+
for aa in anns:
|
153 |
+
aa['file_path'] = raw_audio_dir + aa['id']
|
154 |
+
|
155 |
+
return anns
|
156 |
+
|
157 |
+
|
158 |
+
def load_anns_from_path(ann_file_dir, raw_audio_dir):
|
159 |
+
files = glob.glob(ann_file_dir + '*.json')
|
160 |
+
anns = []
|
161 |
+
for ff in files:
|
162 |
+
with open(ff) as da:
|
163 |
+
ann = json.load(da)
|
164 |
+
ann['file_path'] = raw_audio_dir + ann['id']
|
165 |
+
anns.append(ann)
|
166 |
+
|
167 |
+
return anns
|
168 |
+
|
169 |
+
|
170 |
+
class AverageMeter(object):
|
171 |
+
"""Computes and stores the average and current value"""
|
172 |
+
def __init__(self):
|
173 |
+
self.reset()
|
174 |
+
|
175 |
+
def reset(self):
|
176 |
+
self.val = 0
|
177 |
+
self.avg = 0
|
178 |
+
self.sum = 0
|
179 |
+
self.count = 0
|
180 |
+
|
181 |
+
def update(self, val, n=1):
|
182 |
+
self.val = val
|
183 |
+
self.sum += val * n
|
184 |
+
self.count += n
|
185 |
+
self.avg = self.sum / self.count
|
bat_detect/utils/__init__.py
ADDED
File without changes
|
bat_detect/utils/audio_utils.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from . import wavfile
|
3 |
+
import warnings
|
4 |
+
import torch
|
5 |
+
import librosa
|
6 |
+
|
7 |
+
|
8 |
+
def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap):
|
9 |
+
nfft = np.floor(fft_win_length*sampling_rate) # int() uses floor
|
10 |
+
noverlap = np.floor(fft_overlap*nfft)
|
11 |
+
return (time_in_file*sampling_rate-noverlap) / (nfft - noverlap)
|
12 |
+
|
13 |
+
|
14 |
+
# NOTE this is also defined in post_process
|
15 |
+
def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap):
|
16 |
+
nfft = np.floor(fft_win_length*sampling_rate)
|
17 |
+
noverlap = np.floor(fft_overlap*nfft)
|
18 |
+
return ((x_pos*(nfft - noverlap)) + noverlap) / sampling_rate
|
19 |
+
#return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window
|
20 |
+
|
21 |
+
|
22 |
+
def generate_spectrogram(audio, sampling_rate, params, return_spec_for_viz=False, check_spec_size=True):
|
23 |
+
|
24 |
+
# generate spectrogram
|
25 |
+
spec = gen_mag_spectrogram(audio, sampling_rate, params['fft_win_length'], params['fft_overlap'])
|
26 |
+
|
27 |
+
# crop to min/max freq
|
28 |
+
max_freq = round(params['max_freq']*params['fft_win_length'])
|
29 |
+
min_freq = round(params['min_freq']*params['fft_win_length'])
|
30 |
+
if spec.shape[0] < max_freq:
|
31 |
+
freq_pad = max_freq - spec.shape[0]
|
32 |
+
spec = np.vstack((np.zeros((freq_pad, spec.shape[1]), dtype=spec.dtype), spec))
|
33 |
+
spec_cropped = spec[-max_freq:spec.shape[0]-min_freq, :]
|
34 |
+
|
35 |
+
if params['spec_scale'] == 'log':
|
36 |
+
log_scaling = 2.0 * (1.0 / sampling_rate) * (1.0/(np.abs(np.hanning(int(params['fft_win_length']*sampling_rate)))**2).sum())
|
37 |
+
#log_scaling = (1.0 / sampling_rate)*0.1
|
38 |
+
#log_scaling = (1.0 / sampling_rate)*10e4
|
39 |
+
spec = np.log1p(log_scaling*spec_cropped)
|
40 |
+
elif params['spec_scale'] == 'pcen':
|
41 |
+
spec = pcen(spec_cropped, sampling_rate)
|
42 |
+
elif params['spec_scale'] == 'none':
|
43 |
+
pass
|
44 |
+
|
45 |
+
if params['denoise_spec_avg']:
|
46 |
+
spec = spec - np.mean(spec, 1)[:, np.newaxis]
|
47 |
+
spec.clip(min=0, out=spec)
|
48 |
+
|
49 |
+
if params['max_scale_spec']:
|
50 |
+
spec = spec / (spec.max() + 10e-6)
|
51 |
+
|
52 |
+
# needs to be divisible by specific factor - if not it should have been padded
|
53 |
+
#if check_spec_size:
|
54 |
+
#assert((int(spec.shape[0]*params['resize_factor']) % params['spec_divide_factor']) == 0)
|
55 |
+
#assert((int(spec.shape[1]*params['resize_factor']) % params['spec_divide_factor']) == 0)
|
56 |
+
|
57 |
+
# for visualization purposes - use log scaled spectrogram
|
58 |
+
if return_spec_for_viz:
|
59 |
+
log_scaling = 2.0 * (1.0 / sampling_rate) * (1.0/(np.abs(np.hanning(int(params['fft_win_length']*sampling_rate)))**2).sum())
|
60 |
+
spec_for_viz = np.log1p(log_scaling*spec_cropped).astype(np.float32)
|
61 |
+
else:
|
62 |
+
spec_for_viz = None
|
63 |
+
|
64 |
+
return spec, spec_for_viz
|
65 |
+
|
66 |
+
|
67 |
+
def load_audio_file(audio_file, time_exp_fact, target_samp_rate, scale=False):
|
68 |
+
with warnings.catch_warnings():
|
69 |
+
warnings.filterwarnings('ignore', category=wavfile.WavFileWarning)
|
70 |
+
#sampling_rate, audio_raw = wavfile.read(audio_file)
|
71 |
+
audio_raw, sampling_rate = librosa.load(audio_file, sr=None)
|
72 |
+
|
73 |
+
if len(audio_raw.shape) > 1:
|
74 |
+
raise Exception('Currently does not handle stereo files')
|
75 |
+
sampling_rate = sampling_rate * time_exp_fact
|
76 |
+
|
77 |
+
# resample - need to do this after correcting for time expansion
|
78 |
+
sampling_rate_old = sampling_rate
|
79 |
+
sampling_rate = target_samp_rate
|
80 |
+
audio_raw = librosa.resample(audio_raw, orig_sr=sampling_rate_old, target_sr=sampling_rate, res_type='polyphase')
|
81 |
+
|
82 |
+
# convert to float32 and scale
|
83 |
+
audio_raw = audio_raw.astype(np.float32)
|
84 |
+
if scale:
|
85 |
+
audio_raw = audio_raw - audio_raw.mean()
|
86 |
+
audio_raw = audio_raw / (np.abs(audio_raw).max() + 10e-6)
|
87 |
+
|
88 |
+
return sampling_rate, audio_raw
|
89 |
+
|
90 |
+
|
91 |
+
def pad_audio(audio_raw, fs, ms, overlap_perc, resize_factor, divide_factor, fixed_width=None):
|
92 |
+
# Adds zeros to the end of the raw data so that the generated sepctrogram
|
93 |
+
# will be evenly divisible by `divide_factor`
|
94 |
+
# Also deals with very short audio clips and fixed_width during training
|
95 |
+
|
96 |
+
# This code could be clearer, clean up
|
97 |
+
nfft = int(ms*fs)
|
98 |
+
noverlap = int(overlap_perc*nfft)
|
99 |
+
step = nfft - noverlap
|
100 |
+
min_size = int(divide_factor*(1.0/resize_factor))
|
101 |
+
spec_width = ((audio_raw.shape[0]-noverlap)//step)
|
102 |
+
spec_width_rs = spec_width * resize_factor
|
103 |
+
|
104 |
+
if fixed_width is not None and spec_width < fixed_width:
|
105 |
+
# too small
|
106 |
+
# used during training to ensure all the batches are the same size
|
107 |
+
diff = fixed_width*step + noverlap - audio_raw.shape[0]
|
108 |
+
audio_raw = np.hstack((audio_raw, np.zeros(diff, dtype=audio_raw.dtype)))
|
109 |
+
|
110 |
+
elif fixed_width is not None and spec_width > fixed_width:
|
111 |
+
# too big
|
112 |
+
# used during training to ensure all the batches are the same size
|
113 |
+
diff = fixed_width*step + noverlap - audio_raw.shape[0]
|
114 |
+
audio_raw = audio_raw[:diff]
|
115 |
+
|
116 |
+
elif spec_width_rs < min_size or (np.floor(spec_width_rs) % divide_factor) != 0:
|
117 |
+
# need to be at least min_size
|
118 |
+
div_amt = np.ceil(spec_width_rs / float(divide_factor))
|
119 |
+
div_amt = np.maximum(1, div_amt)
|
120 |
+
target_size = int(div_amt*divide_factor*(1.0/resize_factor))
|
121 |
+
diff = target_size*step + noverlap - audio_raw.shape[0]
|
122 |
+
audio_raw = np.hstack((audio_raw, np.zeros(diff, dtype=audio_raw.dtype)))
|
123 |
+
|
124 |
+
return audio_raw
|
125 |
+
|
126 |
+
|
127 |
+
def gen_mag_spectrogram(x, fs, ms, overlap_perc):
|
128 |
+
# Computes magnitude spectrogram by specifying time.
|
129 |
+
|
130 |
+
x = x.astype(np.float32)
|
131 |
+
nfft = int(ms*fs)
|
132 |
+
noverlap = int(overlap_perc*nfft)
|
133 |
+
|
134 |
+
# window data
|
135 |
+
step = nfft - noverlap
|
136 |
+
|
137 |
+
# compute spec
|
138 |
+
spec, _ = librosa.core.spectrum._spectrogram(y=x, power=1, n_fft=nfft, hop_length=step, center=False)
|
139 |
+
|
140 |
+
# remove DC component and flip vertical orientation
|
141 |
+
spec = np.flipud(spec[1:, :])
|
142 |
+
|
143 |
+
return spec.astype(np.float32)
|
144 |
+
|
145 |
+
|
146 |
+
def gen_mag_spectrogram_pt(x, fs, ms, overlap_perc):
|
147 |
+
nfft = int(ms*fs)
|
148 |
+
nstep = round((1.0-overlap_perc)*nfft)
|
149 |
+
|
150 |
+
han_win = torch.hann_window(nfft, periodic=False).to(x.device)
|
151 |
+
|
152 |
+
complex_spec = torch.stft(x, nfft, nstep, window=han_win, center=False)
|
153 |
+
spec = complex_spec.pow(2.0).sum(-1)
|
154 |
+
|
155 |
+
# remove DC component and flip vertically
|
156 |
+
spec = torch.flipud(spec[0, 1:,:])
|
157 |
+
|
158 |
+
return spec
|
159 |
+
|
160 |
+
|
161 |
+
def pcen(spec_cropped, sampling_rate):
|
162 |
+
# TODO should be passing hop_length too i.e. step
|
163 |
+
spec = librosa.pcen(spec_cropped * (2**31), sr=sampling_rate/10).astype(np.float32)
|
164 |
+
return spec
|
bat_detect/utils/detector_utils.py
ADDED
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import os
|
4 |
+
import numpy as np
|
5 |
+
import pandas as pd
|
6 |
+
import json
|
7 |
+
import sys
|
8 |
+
|
9 |
+
from bat_detect.detector import models
|
10 |
+
import bat_detect.detector.compute_features as feats
|
11 |
+
import bat_detect.detector.post_process as pp
|
12 |
+
import bat_detect.utils.audio_utils as au
|
13 |
+
|
14 |
+
|
15 |
+
def get_default_bd_args():
|
16 |
+
args = {}
|
17 |
+
args['detection_threshold'] = 0.001
|
18 |
+
args['time_expansion_factor'] = 1
|
19 |
+
args['audio_dir'] = ''
|
20 |
+
args['ann_dir'] = ''
|
21 |
+
args['spec_slices'] = False
|
22 |
+
args['chunk_size'] = 3
|
23 |
+
args['spec_features'] = False
|
24 |
+
args['cnn_features'] = False
|
25 |
+
args['quiet'] = True
|
26 |
+
args['save_preds_if_empty'] = True
|
27 |
+
args['ann_dir'] = os.path.join(args['ann_dir'], '')
|
28 |
+
return args
|
29 |
+
|
30 |
+
|
31 |
+
def get_audio_files(ip_dir):
|
32 |
+
|
33 |
+
matches = []
|
34 |
+
for root, dirnames, filenames in os.walk(ip_dir):
|
35 |
+
for filename in filenames:
|
36 |
+
if filename.lower().endswith('.wav'):
|
37 |
+
matches.append(os.path.join(root, filename))
|
38 |
+
return matches
|
39 |
+
|
40 |
+
|
41 |
+
def load_model(model_path, load_weights=True):
|
42 |
+
|
43 |
+
# load model
|
44 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
45 |
+
if os.path.isfile(model_path):
|
46 |
+
net_params = torch.load(model_path, map_location=device)
|
47 |
+
else:
|
48 |
+
print('Error: model not found.')
|
49 |
+
sys.exit(1)
|
50 |
+
|
51 |
+
params = net_params['params']
|
52 |
+
params['device'] = device
|
53 |
+
|
54 |
+
if params['model_name'] == 'Net2DFast':
|
55 |
+
model = models.Net2DFast(params['num_filters'], num_classes=len(params['class_names']),
|
56 |
+
emb_dim=params['emb_dim'], ip_height=params['ip_height'],
|
57 |
+
resize_factor=params['resize_factor'])
|
58 |
+
elif params['model_name'] == 'Net2DFastNoAttn':
|
59 |
+
model = models.Net2DFastNoAttn(params['num_filters'], num_classes=len(params['class_names']),
|
60 |
+
emb_dim=params['emb_dim'], ip_height=params['ip_height'],
|
61 |
+
resize_factor=params['resize_factor'])
|
62 |
+
elif params['model_name'] == 'Net2DFastNoCoordConv':
|
63 |
+
model = models.Net2DFastNoCoordConv(params['num_filters'], num_classes=len(params['class_names']),
|
64 |
+
emb_dim=params['emb_dim'], ip_height=params['ip_height'],
|
65 |
+
resize_factor=params['resize_factor'])
|
66 |
+
else:
|
67 |
+
print('Error: unknown model.')
|
68 |
+
|
69 |
+
if load_weights:
|
70 |
+
model.load_state_dict(net_params['state_dict'])
|
71 |
+
|
72 |
+
model = model.to(params['device'])
|
73 |
+
model.eval()
|
74 |
+
|
75 |
+
return model, params
|
76 |
+
|
77 |
+
|
78 |
+
def merge_results(predictions, spec_feats, cnn_feats, spec_slices):
|
79 |
+
|
80 |
+
predictions_m = {}
|
81 |
+
num_preds = np.sum([len(pp['det_probs']) for pp in predictions])
|
82 |
+
|
83 |
+
if num_preds > 0:
|
84 |
+
for kk in predictions[0].keys():
|
85 |
+
predictions_m[kk] = np.hstack([pp[kk] for pp in predictions if pp['det_probs'].shape[0] > 0])
|
86 |
+
else:
|
87 |
+
# hack in case where no detected calls as we need some of the key names in dict
|
88 |
+
predictions_m = predictions[0]
|
89 |
+
|
90 |
+
if len(spec_feats) > 0:
|
91 |
+
spec_feats = np.vstack(spec_feats)
|
92 |
+
if len(cnn_feats) > 0:
|
93 |
+
cnn_feats = np.vstack(cnn_feats)
|
94 |
+
return predictions_m, spec_feats, cnn_feats, spec_slices
|
95 |
+
|
96 |
+
|
97 |
+
def convert_results(file_id, time_exp, duration, params, predictions, spec_feats, cnn_feats, spec_slices):
|
98 |
+
|
99 |
+
# create a single dictionary - this is the format used by the annotation tool
|
100 |
+
pred_dict = {}
|
101 |
+
pred_dict['id'] = file_id
|
102 |
+
pred_dict['annotated'] = False
|
103 |
+
pred_dict['issues'] = False
|
104 |
+
pred_dict['notes'] = 'Automatically generated.'
|
105 |
+
pred_dict['time_exp'] = time_exp
|
106 |
+
pred_dict['duration'] = round(duration, 4)
|
107 |
+
pred_dict['annotation'] = []
|
108 |
+
|
109 |
+
class_prob_best = predictions['class_probs'].max(0)
|
110 |
+
class_ind_best = predictions['class_probs'].argmax(0)
|
111 |
+
class_overall = pp.overall_class_pred(predictions['det_probs'], predictions['class_probs'])
|
112 |
+
pred_dict['class_name'] = params['class_names'][np.argmax(class_overall)]
|
113 |
+
|
114 |
+
for ii in range(predictions['det_probs'].shape[0]):
|
115 |
+
res = {}
|
116 |
+
res['start_time'] = round(float(predictions['start_times'][ii]), 4)
|
117 |
+
res['end_time'] = round(float(predictions['end_times'][ii]), 4)
|
118 |
+
res['low_freq'] = int(predictions['low_freqs'][ii])
|
119 |
+
res['high_freq'] = int(predictions['high_freqs'][ii])
|
120 |
+
res['class'] = str(params['class_names'][int(class_ind_best[ii])])
|
121 |
+
res['class_prob'] = round(float(class_prob_best[ii]), 3)
|
122 |
+
res['det_prob'] = round(float(predictions['det_probs'][ii]), 3)
|
123 |
+
res['individual'] = '-1'
|
124 |
+
res['event'] = 'Echolocation'
|
125 |
+
pred_dict['annotation'].append(res)
|
126 |
+
|
127 |
+
# combine into final results dictionary
|
128 |
+
results = {}
|
129 |
+
results['pred_dict'] = pred_dict
|
130 |
+
if len(spec_feats) > 0:
|
131 |
+
results['spec_feats'] = spec_feats
|
132 |
+
results['spec_feat_names'] = feats.get_feature_names()
|
133 |
+
if len(cnn_feats) > 0:
|
134 |
+
results['cnn_feats'] = cnn_feats
|
135 |
+
results['cnn_feat_names'] = [str(ii) for ii in range(cnn_feats.shape[1])]
|
136 |
+
if len(spec_slices) > 0:
|
137 |
+
results['spec_slices'] = spec_slices
|
138 |
+
|
139 |
+
return results
|
140 |
+
|
141 |
+
|
142 |
+
def save_results_to_file(results, op_path):
|
143 |
+
|
144 |
+
# make directory if it does not exist
|
145 |
+
if not os.path.isdir(os.path.dirname(op_path)):
|
146 |
+
os.makedirs(os.path.dirname(op_path))
|
147 |
+
|
148 |
+
# save csv file - if there are predictions
|
149 |
+
result_list = [res for res in results['pred_dict']['annotation']]
|
150 |
+
df = pd.DataFrame(result_list)
|
151 |
+
df['file_name'] = [results['pred_dict']['id']]*len(result_list)
|
152 |
+
df.index.name = 'id'
|
153 |
+
if 'class_prob' in df.columns:
|
154 |
+
df = df[['det_prob', 'start_time', 'end_time', 'high_freq',
|
155 |
+
'low_freq', 'class', 'class_prob']]
|
156 |
+
df.to_csv(op_path + '.csv', sep=',')
|
157 |
+
|
158 |
+
# save features
|
159 |
+
if 'spec_feats' in results.keys():
|
160 |
+
df = pd.DataFrame(results['spec_feats'], columns=results['spec_feat_names'])
|
161 |
+
df.to_csv(op_path + '_spec_features.csv', sep=',', index=False, float_format='%.5f')
|
162 |
+
|
163 |
+
if 'cnn_feats' in results.keys():
|
164 |
+
df = pd.DataFrame(results['cnn_feats'], columns=results['cnn_feat_names'])
|
165 |
+
df.to_csv(op_path + '_cnn_features.csv', sep=',', index=False, float_format='%.5f')
|
166 |
+
|
167 |
+
# save json file
|
168 |
+
with open(op_path + '.json', 'w') as da:
|
169 |
+
json.dump(results['pred_dict'], da, indent=2, sort_keys=True)
|
170 |
+
|
171 |
+
|
172 |
+
def compute_spectrogram(audio, sampling_rate, params, return_np=False):
|
173 |
+
|
174 |
+
# pad audio so it is evenly divisible by downsampling factors
|
175 |
+
duration = audio.shape[0] / float(sampling_rate)
|
176 |
+
audio = au.pad_audio(audio, sampling_rate, params['fft_win_length'],
|
177 |
+
params['fft_overlap'], params['resize_factor'],
|
178 |
+
params['spec_divide_factor'])
|
179 |
+
|
180 |
+
# generate spectrogram
|
181 |
+
spec, _ = au.generate_spectrogram(audio, sampling_rate, params)
|
182 |
+
|
183 |
+
# convert to pytorch
|
184 |
+
spec = torch.from_numpy(spec).to(params['device'])
|
185 |
+
spec = spec.unsqueeze(0).unsqueeze(0)
|
186 |
+
|
187 |
+
# resize the spec
|
188 |
+
rs = params['resize_factor']
|
189 |
+
spec_op_shape = (int(params['spec_height']*rs), int(spec.shape[-1]*rs))
|
190 |
+
spec = F.interpolate(spec, size=spec_op_shape, mode='bilinear', align_corners=False)
|
191 |
+
|
192 |
+
if return_np:
|
193 |
+
spec_np = spec[0,0,:].cpu().data.numpy()
|
194 |
+
else:
|
195 |
+
spec_np = None
|
196 |
+
|
197 |
+
return duration, spec, spec_np
|
198 |
+
|
199 |
+
|
200 |
+
def process_file(audio_file, model, params, args, time_exp=None, top_n=5, return_raw_preds=False, max_duration=False):
|
201 |
+
|
202 |
+
# store temporary results here
|
203 |
+
predictions = []
|
204 |
+
spec_feats = []
|
205 |
+
cnn_feats = []
|
206 |
+
spec_slices = []
|
207 |
+
|
208 |
+
# get time expansion factor
|
209 |
+
if time_exp is None:
|
210 |
+
time_exp = args['time_expansion_factor']
|
211 |
+
|
212 |
+
params['detection_threshold'] = args['detection_threshold']
|
213 |
+
|
214 |
+
# load audio file
|
215 |
+
sampling_rate, audio_full = au.load_audio_file(audio_file, time_exp,
|
216 |
+
params['target_samp_rate'], params['scale_raw_audio'])
|
217 |
+
|
218 |
+
# clipping maximum duration
|
219 |
+
if max_duration is not False:
|
220 |
+
max_duration = np.minimum(int(sampling_rate*max_duration), audio_full.shape[0])
|
221 |
+
audio_full = audio_full[:max_duration]
|
222 |
+
|
223 |
+
duration_full = audio_full.shape[0] / float(sampling_rate)
|
224 |
+
|
225 |
+
return_np_spec = args['spec_features'] or args['spec_slices']
|
226 |
+
|
227 |
+
# loop through larger file and split into chunks
|
228 |
+
# TODO fix so that it overlaps correctly and takes care of duplicate detections at borders
|
229 |
+
num_chunks = int(np.ceil(duration_full/args['chunk_size']))
|
230 |
+
for chunk_id in range(num_chunks):
|
231 |
+
|
232 |
+
# chunk
|
233 |
+
chunk_time = args['chunk_size']*chunk_id
|
234 |
+
chunk_length = int(sampling_rate*args['chunk_size'])
|
235 |
+
start_sample = chunk_id*chunk_length
|
236 |
+
end_sample = np.minimum((chunk_id+1)*chunk_length, audio_full.shape[0])
|
237 |
+
audio = audio_full[start_sample:end_sample]
|
238 |
+
|
239 |
+
# load audio file and compute spectrogram
|
240 |
+
duration, spec, spec_np = compute_spectrogram(audio, sampling_rate, params, return_np_spec)
|
241 |
+
|
242 |
+
# evaluate model
|
243 |
+
with torch.no_grad():
|
244 |
+
outputs = model(spec, return_feats=args['cnn_features'])
|
245 |
+
|
246 |
+
# run non-max suppression
|
247 |
+
pred_nms, features = pp.run_nms(outputs, params, np.array([float(sampling_rate)]))
|
248 |
+
pred_nms = pred_nms[0]
|
249 |
+
pred_nms['start_times'] += chunk_time
|
250 |
+
pred_nms['end_times'] += chunk_time
|
251 |
+
|
252 |
+
# if we have a background class
|
253 |
+
if pred_nms['class_probs'].shape[0] > len(params['class_names']):
|
254 |
+
pred_nms['class_probs'] = pred_nms['class_probs'][:-1, :]
|
255 |
+
|
256 |
+
predictions.append(pred_nms)
|
257 |
+
|
258 |
+
# extract features - if there are any calls detected
|
259 |
+
if (pred_nms['det_probs'].shape[0] > 0):
|
260 |
+
if args['spec_features']:
|
261 |
+
spec_feats.append(feats.get_feats(spec_np, pred_nms, params))
|
262 |
+
|
263 |
+
if args['cnn_features']:
|
264 |
+
cnn_feats.append(features[0])
|
265 |
+
|
266 |
+
if args['spec_slices']:
|
267 |
+
spec_slices.extend(feats.extract_spec_slices(spec_np, pred_nms, params))
|
268 |
+
|
269 |
+
# convert the predictions into output dictionary
|
270 |
+
file_id = os.path.basename(audio_file)
|
271 |
+
predictions, spec_feats, cnn_feats, spec_slices =\
|
272 |
+
merge_results(predictions, spec_feats, cnn_feats, spec_slices)
|
273 |
+
results = convert_results(file_id, time_exp, duration_full, params,
|
274 |
+
predictions, spec_feats, cnn_feats, spec_slices)
|
275 |
+
|
276 |
+
# summarize results
|
277 |
+
if not args['quiet']:
|
278 |
+
num_detections = len(results['pred_dict']['annotation'])
|
279 |
+
print('{}'.format(num_detections) + ' call(s) detected above the threshold.')
|
280 |
+
|
281 |
+
# print results for top n classes
|
282 |
+
if not args['quiet'] and (num_detections > 0):
|
283 |
+
class_overall = pp.overall_class_pred(predictions['det_probs'], predictions['class_probs'])
|
284 |
+
print('species name'.ljust(30) + 'probablity present')
|
285 |
+
for cc in np.argsort(class_overall)[::-1][:top_n]:
|
286 |
+
print(params['class_names'][cc].ljust(30) + str(round(class_overall[cc], 3)))
|
287 |
+
|
288 |
+
if return_raw_preds:
|
289 |
+
return predictions
|
290 |
+
else:
|
291 |
+
return results
|
bat_detect/utils/plot_utils.py
ADDED
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import json
|
4 |
+
from sklearn.metrics import confusion_matrix
|
5 |
+
from matplotlib import patches
|
6 |
+
from matplotlib.collections import PatchCollection
|
7 |
+
|
8 |
+
from . import audio_utils as au
|
9 |
+
|
10 |
+
|
11 |
+
def create_box_image(spec, fig, detections_ip, start_time, end_time, duration, params, max_val, hide_axis=True, plot_class_names=False):
|
12 |
+
# filter detections
|
13 |
+
stop_time = start_time + duration
|
14 |
+
detections = []
|
15 |
+
for bb in detections_ip:
|
16 |
+
if (bb['start_time'] >= start_time) and (bb['start_time'] < stop_time-0.02): #(bb['end_time'] < end_time):
|
17 |
+
detections.append(bb)
|
18 |
+
|
19 |
+
# create figure
|
20 |
+
freq_scale = 1000 # turn Hz to kHz
|
21 |
+
min_freq = params['min_freq']//freq_scale
|
22 |
+
max_freq = params['max_freq']//freq_scale
|
23 |
+
y_extent = [0, duration, min_freq, max_freq]
|
24 |
+
|
25 |
+
if hide_axis:
|
26 |
+
ax = plt.Axes(fig, [0., 0., 1., 1.])
|
27 |
+
ax.set_axis_off()
|
28 |
+
fig.add_axes(ax)
|
29 |
+
else:
|
30 |
+
ax = plt.gca()
|
31 |
+
|
32 |
+
plt.imshow(spec, aspect='auto', cmap='plasma', extent=y_extent, vmin=0, vmax=max_val)
|
33 |
+
boxes = plot_bounding_box_patch_ann(detections, freq_scale, start_time)
|
34 |
+
ax.add_collection(PatchCollection(boxes, match_original=True))
|
35 |
+
plt.grid(False)
|
36 |
+
|
37 |
+
if plot_class_names:
|
38 |
+
for ii, bb in enumerate(boxes):
|
39 |
+
txt = ' '.join([sp[:3] for sp in detections_ip[ii]['class'].split(' ')])
|
40 |
+
font_info = {'color': 'white', 'size': 10, 'weight': 'bold', 'alpha': bb.get_alpha()}
|
41 |
+
y_pos = bb.get_xy()[1] + bb.get_height()
|
42 |
+
if y_pos > (max_freq - 10):
|
43 |
+
y_pos = max_freq - 10
|
44 |
+
plt.gca().text(bb.get_xy()[0], y_pos, txt, fontdict=font_info)
|
45 |
+
|
46 |
+
|
47 |
+
def save_ann_spec(op_path, spec, min_freq, max_freq, duration, start_time, title_text='', anns=None):
|
48 |
+
# create figure and plot boxes
|
49 |
+
freq_scale = 1000 # turn Hz to kHz
|
50 |
+
min_freq = min_freq//freq_scale
|
51 |
+
max_freq = max_freq//freq_scale
|
52 |
+
y_extent = [0, duration, min_freq, max_freq]
|
53 |
+
|
54 |
+
plt.close('all')
|
55 |
+
fig = plt.figure(0, figsize=(spec.shape[1]/100, spec.shape[0]/100), dpi=100)
|
56 |
+
plt.imshow(spec, aspect='auto', cmap='plasma', extent=y_extent, vmin=0, vmax=spec.max()*1.1)
|
57 |
+
|
58 |
+
plt.ylabel('Freq - kHz')
|
59 |
+
plt.xlabel('Time - secs')
|
60 |
+
if title_text != '':
|
61 |
+
plt.title(title_text)
|
62 |
+
plt.tight_layout()
|
63 |
+
|
64 |
+
if anns is not None:
|
65 |
+
# drawing bounding boxes and class names
|
66 |
+
boxes = plot_bounding_box_patch_ann(anns, freq_scale, start_time)
|
67 |
+
plt.gca().add_collection(PatchCollection(boxes, match_original=True))
|
68 |
+
for ii, bb in enumerate(boxes):
|
69 |
+
txt = ' '.join([sp[:3] for sp in anns[ii]['class'].split(' ')])
|
70 |
+
font_info = {'color': 'white', 'size': 10, 'weight': 'bold', 'alpha': bb.get_alpha()}
|
71 |
+
y_pos = bb.get_xy()[1] + bb.get_height()
|
72 |
+
if y_pos > (max_freq - 10):
|
73 |
+
y_pos = max_freq - 10
|
74 |
+
plt.gca().text(bb.get_xy()[0], y_pos, txt, fontdict=font_info)
|
75 |
+
|
76 |
+
print('Saving figure to:', op_path)
|
77 |
+
plt.savefig(op_path)
|
78 |
+
|
79 |
+
|
80 |
+
def plot_pts(fig_id, feats, class_names, colors, marker_size=4.0, plot_legend=False):
|
81 |
+
plt.figure(fig_id)
|
82 |
+
un_class, labels = np.unique(class_names, return_inverse=True)
|
83 |
+
un_labels = np.unique(labels)
|
84 |
+
if un_labels.shape[0] > len(colors):
|
85 |
+
colors = [plt.cm.jet(float(ii)/un_labels.shape[0]) for ii in un_labels]
|
86 |
+
|
87 |
+
for ii, u in enumerate(un_labels):
|
88 |
+
inds = np.where(labels==u)[0]
|
89 |
+
plt.scatter(feats[inds, 0], feats[inds, 1], c=colors[ii], label=str(un_class[ii]), s=marker_size)
|
90 |
+
if plot_legend:
|
91 |
+
plt.legend()
|
92 |
+
plt.xticks([])
|
93 |
+
plt.yticks([])
|
94 |
+
plt.title('downsampled features')
|
95 |
+
|
96 |
+
|
97 |
+
def plot_bounding_box_patch(pred, freq_scale, ecolor='w'):
|
98 |
+
patch_collect = []
|
99 |
+
for bb in range(len(pred['start_times'])):
|
100 |
+
xx = pred['start_times'][bb]
|
101 |
+
ww = pred['end_times'][bb] - pred['start_times'][bb]
|
102 |
+
yy = pred['low_freqs'][bb] / freq_scale
|
103 |
+
hh = (pred['high_freqs'][bb] - pred['low_freqs'][bb]) / freq_scale
|
104 |
+
|
105 |
+
if 'det_probs' in pred.keys():
|
106 |
+
alpha_val = pred['det_probs'][bb]
|
107 |
+
else:
|
108 |
+
alpha_val = 1.0
|
109 |
+
patch_collect.append(patches.Rectangle((xx, yy), ww, hh, linewidth=1,
|
110 |
+
edgecolor=ecolor, facecolor='none', alpha=alpha_val))
|
111 |
+
return patch_collect
|
112 |
+
|
113 |
+
|
114 |
+
def plot_bounding_box_patch_ann(anns, freq_scale, start_time):
|
115 |
+
patch_collect = []
|
116 |
+
for aa in range(len(anns)):
|
117 |
+
xx = anns[aa]['start_time'] - start_time
|
118 |
+
ww = anns[aa]['end_time'] - anns[aa]['start_time']
|
119 |
+
yy = anns[aa]['low_freq'] / freq_scale
|
120 |
+
hh = (anns[aa]['high_freq'] - anns[aa]['low_freq']) / freq_scale
|
121 |
+
if 'det_prob' in anns[aa]:
|
122 |
+
alpha = anns[aa]['det_prob']
|
123 |
+
else:
|
124 |
+
alpha = 1.0
|
125 |
+
patch_collect.append(patches.Rectangle((xx,yy), ww, hh, linewidth=1,
|
126 |
+
edgecolor='w', facecolor='none', alpha=alpha))
|
127 |
+
return patch_collect
|
128 |
+
|
129 |
+
|
130 |
+
def plot_spec(spec, sampling_rate, duration, gt, pred, params, plot_title,
|
131 |
+
op_file_name, pred_2d_hm, plot_boxes=True, fixed_aspect=True):
|
132 |
+
|
133 |
+
if fixed_aspect:
|
134 |
+
# ouptut image will be this width irrespective of the duration of the audio file
|
135 |
+
width = 12
|
136 |
+
else:
|
137 |
+
width = 12*duration
|
138 |
+
|
139 |
+
fig = plt.figure(1, figsize=(width, 8))
|
140 |
+
ax0 = plt.axes([0.05, 0.65, 0.9, 0.30]) # l b w h
|
141 |
+
ax1 = plt.axes([0.05, 0.33, 0.9, 0.30])
|
142 |
+
ax2 = plt.axes([0.05, 0.01, 0.9, 0.30])
|
143 |
+
|
144 |
+
freq_scale = 1000 # turn Hz in kHz
|
145 |
+
#duration = au.x_coords_to_time(spec.shape[1], sampling_rate, params['fft_win_length'], params['fft_overlap'])
|
146 |
+
y_extent = [0, duration, params['min_freq']//freq_scale, params['max_freq']//freq_scale]
|
147 |
+
|
148 |
+
# plot gt boxes
|
149 |
+
ax0.imshow(spec, aspect='auto', cmap='plasma', extent=y_extent)
|
150 |
+
ax0.xaxis.set_ticklabels([])
|
151 |
+
font_info = {'color': 'white', 'size': 12, 'weight': 'bold'}
|
152 |
+
ax0.text(0, params['min_freq']//freq_scale, 'Ground Truth', fontdict=font_info)
|
153 |
+
|
154 |
+
plt.grid(False)
|
155 |
+
if plot_boxes:
|
156 |
+
boxes = plot_bounding_box_patch(gt, freq_scale)
|
157 |
+
ax0.add_collection(PatchCollection(boxes, match_original=True))
|
158 |
+
for ii, bb in enumerate(boxes):
|
159 |
+
class_id = int(gt['class_ids'][ii])
|
160 |
+
if class_id < 0:
|
161 |
+
txt = params['generic_class'][0]
|
162 |
+
else:
|
163 |
+
txt = params['class_names_short'][class_id]
|
164 |
+
font_info = {'color': 'white', 'size': 10, 'weight': 'bold', 'alpha': bb.get_alpha()}
|
165 |
+
y_pos = bb.get_xy()[1] + bb.get_height()
|
166 |
+
ax0.text(bb.get_xy()[0], y_pos, txt, fontdict=font_info)
|
167 |
+
|
168 |
+
# plot predicted boxes
|
169 |
+
ax1.imshow(spec, aspect='auto', cmap='plasma', extent=y_extent)
|
170 |
+
ax1.xaxis.set_ticklabels([])
|
171 |
+
font_info = {'color': 'white', 'size': 12, 'weight': 'bold'}
|
172 |
+
ax1.text(0, params['min_freq']//freq_scale, 'Prediction', fontdict=font_info)
|
173 |
+
|
174 |
+
plt.grid(False)
|
175 |
+
if plot_boxes:
|
176 |
+
boxes = plot_bounding_box_patch(pred, freq_scale)
|
177 |
+
ax1.add_collection(PatchCollection(boxes, match_original=True))
|
178 |
+
for ii, bb in enumerate(boxes):
|
179 |
+
if pred['class_probs'].shape[0] > len(params['class_names_short']):
|
180 |
+
class_id = pred['class_probs'][:-1, ii].argmax()
|
181 |
+
else:
|
182 |
+
class_id = pred['class_probs'][:, ii].argmax()
|
183 |
+
txt = params['class_names_short'][class_id]
|
184 |
+
font_info = {'color': 'white', 'size': 10, 'weight': 'bold', 'alpha': bb.get_alpha()}
|
185 |
+
y_pos = bb.get_xy()[1] + bb.get_height()
|
186 |
+
ax1.text(bb.get_xy()[0], y_pos, txt, fontdict=font_info)
|
187 |
+
|
188 |
+
# plot 2D heatmap
|
189 |
+
if pred_2d_hm is not None:
|
190 |
+
min_val = 0.0 if pred_2d_hm.min() > 0.0 else pred_2d_hm.min()
|
191 |
+
max_val = 1.0 if pred_2d_hm.max() < 1.0 else pred_2d_hm.max()
|
192 |
+
|
193 |
+
ax2.imshow(pred_2d_hm, aspect='auto', cmap='plasma', extent=y_extent, clim=[min_val, max_val])
|
194 |
+
#ax2.xaxis.set_ticklabels([])
|
195 |
+
font_info = {'color': 'white', 'size': 12, 'weight': 'bold'}
|
196 |
+
ax2.text(0, params['min_freq']//freq_scale, 'Heatmap', fontdict=font_info)
|
197 |
+
|
198 |
+
plt.grid(False)
|
199 |
+
|
200 |
+
plt.suptitle(plot_title)
|
201 |
+
if op_file_name is not None:
|
202 |
+
fig.savefig(op_file_name)
|
203 |
+
|
204 |
+
plt.close(1)
|
205 |
+
|
206 |
+
|
207 |
+
def plot_pr_curve(op_dir, plt_title, file_name, results, file_type='png', title_text=''):
|
208 |
+
precision = results['precision']
|
209 |
+
recall = results['recall']
|
210 |
+
avg_prec = results['avg_prec']
|
211 |
+
|
212 |
+
plt.figure(0, figsize=(10,8))
|
213 |
+
plt.plot(recall, precision)
|
214 |
+
plt.ylabel('Precision', fontsize=20)
|
215 |
+
plt.xlabel('Recall', fontsize=20)
|
216 |
+
if title_text != '':
|
217 |
+
plt.title(title_text, fontdict={'fontsize': 28})
|
218 |
+
else:
|
219 |
+
plt.title(plt_title + ' {:.3f}\n'.format(avg_prec))
|
220 |
+
plt.xlim(0,1.02)
|
221 |
+
plt.ylim(0,1.02)
|
222 |
+
plt.grid(True)
|
223 |
+
plt.tight_layout()
|
224 |
+
plt.savefig(op_dir + file_name + '.' + file_type)
|
225 |
+
plt.close(0)
|
226 |
+
|
227 |
+
|
228 |
+
def plot_pr_curve_class(op_dir, plt_title, file_name, results, file_type='png', title_text=''):
|
229 |
+
plt.figure(0, figsize=(10,8))
|
230 |
+
plt.ylabel('Precision', fontsize=20)
|
231 |
+
plt.xlabel('Recall', fontsize=20)
|
232 |
+
plt.xlim(0,1.02)
|
233 |
+
plt.ylim(0,1.02)
|
234 |
+
plt.grid(True)
|
235 |
+
linestyles = ['-', ':', '--']
|
236 |
+
markers = ['o', 'v', '>', '^', '<', 's', 'P', 'X', '*']
|
237 |
+
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
|
238 |
+
|
239 |
+
# plot the PR curves
|
240 |
+
for ii, rr in enumerate(results['class_pr']):
|
241 |
+
class_name = ' '.join([sp[:3] for sp in rr['name'].split(' ')])
|
242 |
+
cur_color = colors[int(ii%10)]
|
243 |
+
plt.plot(rr['recall'], rr['precision'], label=class_name, color=cur_color,
|
244 |
+
linestyle=linestyles[int(ii//10)], lw=2.5)
|
245 |
+
|
246 |
+
#print(class_name)
|
247 |
+
# plot the location of the confidence threshold values
|
248 |
+
for jj, tt in enumerate(rr['thresholds']):
|
249 |
+
ind = rr['thresholds_inds'][jj]
|
250 |
+
if ind > -1:
|
251 |
+
plt.plot(rr['recall'][ind], rr['precision'][ind], markers[jj],
|
252 |
+
color=cur_color, ms=10)
|
253 |
+
#print(np.round(tt,2), np.round(rr['recall'][ind],3), np.round(rr['precision'][ind],3))
|
254 |
+
|
255 |
+
if title_text != '':
|
256 |
+
plt.title(title_text, fontdict={'fontsize': 28})
|
257 |
+
else:
|
258 |
+
plt.title(plt_title + ' {:.3f}\n'.format(results['avg_prec_class']))
|
259 |
+
plt.legend(loc='lower left', prop={'size': 14})
|
260 |
+
plt.tight_layout()
|
261 |
+
plt.savefig(op_dir + file_name + '.' + file_type)
|
262 |
+
plt.close(0)
|
263 |
+
|
264 |
+
|
265 |
+
def plot_confusion_matrix(op_dir, op_file, gt, pred, file_acc, class_names_long, verbose=False, file_type='png', title_text=''):
|
266 |
+
# shorten the class names for plotting
|
267 |
+
class_names = []
|
268 |
+
for cc in class_names_long:
|
269 |
+
class_name_sm = ''.join([cc_sm[:3] + ' ' for cc_sm in cc.split(' ')])[:-1]
|
270 |
+
class_names.append(class_name_sm)
|
271 |
+
|
272 |
+
num_classes = len(class_names)
|
273 |
+
cm = confusion_matrix(gt, pred, labels=np.arange(num_classes)).astype(np.float32)
|
274 |
+
cm_norm = cm.sum(1)
|
275 |
+
|
276 |
+
valid_inds = np.where(cm_norm > 0)[0]
|
277 |
+
cm[valid_inds, :] = cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
|
278 |
+
cm[np.where(cm_norm ==- 0)[0], :] = np.nan
|
279 |
+
|
280 |
+
if verbose:
|
281 |
+
print('Per class accuracy:')
|
282 |
+
str_len = np.max([len(cc) for cc in class_names_long]) + 5
|
283 |
+
accs = np.diag(cm)
|
284 |
+
for ii, cc in enumerate(class_names_long):
|
285 |
+
if np.isnan(accs[ii]):
|
286 |
+
print(str(ii).ljust(5) + cc.ljust(str_len))
|
287 |
+
else:
|
288 |
+
print(str(ii).ljust(5) + cc.ljust(str_len) + '{:.2f}'.format(accs[ii]*100))
|
289 |
+
|
290 |
+
plt.figure(0, figsize=(10,8))
|
291 |
+
plt.imshow(cm, vmin=0, vmax=1, cmap='plasma')
|
292 |
+
plt.colorbar()
|
293 |
+
plt.xticks(np.arange(cm.shape[1]), class_names, rotation='vertical')
|
294 |
+
plt.yticks(np.arange(cm.shape[0]), class_names)
|
295 |
+
plt.xlabel('Predicted', fontsize=20)
|
296 |
+
plt.ylabel('Ground Truth', fontsize=20)
|
297 |
+
if title_text != '':
|
298 |
+
plt.title(title_text, fontdict={'fontsize': 28})
|
299 |
+
else:
|
300 |
+
plt.title(op_file + ' {:.3f}\n'.format(file_acc))
|
301 |
+
plt.tight_layout()
|
302 |
+
plt.savefig(op_dir + op_file + '.' + file_type)
|
303 |
+
plt.close('all')
|
304 |
+
|
305 |
+
|
306 |
+
class LossPlotter(object):
|
307 |
+
def __init__(self, op_file_name, duration, labels, ylim, class_names, axis_labels=None, logy=False):
|
308 |
+
self.reset()
|
309 |
+
self.op_file_name = op_file_name
|
310 |
+
self.duration = duration # length of x axis
|
311 |
+
self.labels = labels
|
312 |
+
self.ylim = ylim
|
313 |
+
self.class_names = class_names
|
314 |
+
self.axis_labels = axis_labels
|
315 |
+
self.logy = logy
|
316 |
+
|
317 |
+
def reset(self):
|
318 |
+
self.epochs = []
|
319 |
+
self.vals = []
|
320 |
+
|
321 |
+
def update_and_save(self, epoch, val, gt=None, pred=None):
|
322 |
+
self.epochs.append(epoch)
|
323 |
+
self.vals.append(val)
|
324 |
+
self.save_plot()
|
325 |
+
self.save_json()
|
326 |
+
if gt is not None:
|
327 |
+
self.save_confusion_matrix(gt, pred)
|
328 |
+
|
329 |
+
def save_plot(self):
|
330 |
+
linestyles = ['-', ':', '--']
|
331 |
+
plt.figure(0, figsize=(8,5))
|
332 |
+
for ii in range(len(self.vals[0])):
|
333 |
+
l_vals = [vv[ii] for vv in self.vals]
|
334 |
+
plt.plot(self.epochs, l_vals, label=self.labels[ii], linestyle=linestyles[int(ii//10)])
|
335 |
+
plt.xlim(0, np.maximum(self.duration, len(self.vals)))
|
336 |
+
if self.ylim is not None:
|
337 |
+
plt.ylim(self.ylim[0], self.ylim[1])
|
338 |
+
if self.axis_labels is not None:
|
339 |
+
plt.xlabel(self.axis_labels[0])
|
340 |
+
plt.ylabel(self.axis_labels[1])
|
341 |
+
if self.logy:
|
342 |
+
plt.gca().set_yscale('log')
|
343 |
+
plt.grid(True)
|
344 |
+
plt.legend(bbox_to_anchor=(1.01, 1), loc='upper left', borderaxespad=0.0)
|
345 |
+
plt.tight_layout()
|
346 |
+
plt.savefig(self.op_file_name)
|
347 |
+
plt.close(0)
|
348 |
+
|
349 |
+
def save_json(self):
|
350 |
+
data = {}
|
351 |
+
data['epochs'] = self.epochs
|
352 |
+
for ii in range(len(self.vals[0])):
|
353 |
+
data[self.labels[ii]] = [round(vv[ii],4) for vv in self.vals]
|
354 |
+
with open(self.op_file_name[:-4] + '.json', 'w') as da:
|
355 |
+
json.dump(data, da, indent=2)
|
356 |
+
|
357 |
+
def save_confusion_matrix(self, gt, pred):
|
358 |
+
plt.figure(0)
|
359 |
+
cm = confusion_matrix(gt, pred, np.arange(len(self.class_names))).astype(np.float32)
|
360 |
+
cm_norm = cm.sum(1)
|
361 |
+
valid_inds = np.where(cm_norm > 0)[0]
|
362 |
+
cm[valid_inds, :] = cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
|
363 |
+
plt.imshow(cm, vmin=0, vmax=1, cmap='plasma')
|
364 |
+
plt.colorbar()
|
365 |
+
plt.xticks(np.arange(cm.shape[1]), self.class_names, rotation='vertical')
|
366 |
+
plt.yticks(np.arange(cm.shape[0]), self.class_names)
|
367 |
+
plt.xlabel('Predicted')
|
368 |
+
plt.ylabel('Ground Truth')
|
369 |
+
plt.tight_layout()
|
370 |
+
plt.savefig(self.op_file_name[:-4] + '_cm.png')
|
371 |
+
plt.close(0)
|
bat_detect/utils/visualize.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
from matplotlib import patches
|
4 |
+
from sklearn.svm import LinearSVC
|
5 |
+
from matplotlib.axes._axes import _log as matplotlib_axes_logger
|
6 |
+
matplotlib_axes_logger.setLevel('ERROR')
|
7 |
+
|
8 |
+
|
9 |
+
colors = ['#e6194B', '#3cb44b', '#ffe119', '#4363d8', '#f58231', '#911eb4',
|
10 |
+
'#42d4f4', '#f032e6', '#bfef45', '#fabebe', '#469990', '#e6beff',
|
11 |
+
'#9A6324', '#fffac8', '#800000', '#aaffc3', '#808000', '#ffd8b1',
|
12 |
+
'#000075', '#a9a9a9']
|
13 |
+
|
14 |
+
|
15 |
+
class InteractivePlotter:
|
16 |
+
def __init__(self, feats_ds, feats, spec_slices, call_info, freq_lims, allow_training):
|
17 |
+
"""
|
18 |
+
Plots 2D low dimensional features on left and corresponding spectgrams on
|
19 |
+
the right.
|
20 |
+
"""
|
21 |
+
self.feats_ds = feats_ds
|
22 |
+
self.feats = feats
|
23 |
+
self.clf = None
|
24 |
+
|
25 |
+
self.spec_slices = spec_slices
|
26 |
+
self.call_info = call_info
|
27 |
+
#_, self.labels = np.unique([cc['class'] for cc in call_info], return_inverse=True)
|
28 |
+
self.labels = np.zeros(len(call_info), dtype=np.int)
|
29 |
+
self.annotated = np.zeros(self.labels.shape[0], dtype=np.int) # can populate this with 1's where we have labels
|
30 |
+
self.labels_cols = [colors[self.labels[ii]] for ii in range(len(self.labels))]
|
31 |
+
self.freq_lims = freq_lims
|
32 |
+
|
33 |
+
self.allow_training = allow_training
|
34 |
+
self.pt_size = 5.0
|
35 |
+
self.spec_pad = 0.2 # this much padding has been applied to the spec slices
|
36 |
+
self.fig_width = 12
|
37 |
+
self.fig_height = 8
|
38 |
+
|
39 |
+
self.current_id = 0
|
40 |
+
max_ind = np.argmax([ss.shape[1] for ss in self.spec_slices])
|
41 |
+
self.max_width = self.spec_slices[max_ind].shape[1]
|
42 |
+
self.blank_spec = np.zeros((self.spec_slices[0].shape[0], self.max_width))
|
43 |
+
|
44 |
+
|
45 |
+
def plot(self, fig_id):
|
46 |
+
self.fig, self.ax = plt.subplots(nrows=1, ncols=2, num=fig_id, figsize=(self.fig_width, self.fig_height),
|
47 |
+
gridspec_kw={'width_ratios': [2, 1]})
|
48 |
+
plt.tight_layout()
|
49 |
+
|
50 |
+
# plot 2D TNSE features
|
51 |
+
self.low_dim_plt = self.ax[0].scatter(self.feats_ds[:, 0], self.feats_ds[:, 1],
|
52 |
+
c=self.labels_cols, s=self.pt_size, picker=5)
|
53 |
+
self.ax[0].set_title('TSNE of Call Features')
|
54 |
+
self.ax[0].set_xticks([])
|
55 |
+
self.ax[0].set_yticks([])
|
56 |
+
|
57 |
+
# plot clip from spectrogram
|
58 |
+
spec_min_max = (0, self.blank_spec.shape[1], self.freq_lims[0], self.freq_lims[1])
|
59 |
+
self.ax[1].imshow(self.blank_spec, extent=spec_min_max, cmap='plasma', aspect='auto')
|
60 |
+
self.spec_im = self.ax[1].get_images()[0]
|
61 |
+
self.ax[1].set_title('Spectrogram')
|
62 |
+
self.ax[1].grid(color='w', linewidth=0.5)
|
63 |
+
self.ax[1].set_xticks([])
|
64 |
+
self.ax[1].set_ylabel('kHz')
|
65 |
+
|
66 |
+
bbox_orig = patches.Rectangle((0,0),0,0, edgecolor='w', linewidth=0, fill=False)
|
67 |
+
self.ax[1].add_patch(bbox_orig)
|
68 |
+
|
69 |
+
self.annot = self.ax[0].annotate('', xy=(0,0), xytext=(20,20),textcoords='offset points',
|
70 |
+
bbox=dict(boxstyle='round', fc='w'), arrowprops=dict(arrowstyle='->'))
|
71 |
+
self.annot.set_visible(False)
|
72 |
+
|
73 |
+
self.fig.canvas.mpl_connect('motion_notify_event', self.mouse_hover)
|
74 |
+
self.fig.canvas.mpl_connect('key_press_event', self.key_press)
|
75 |
+
|
76 |
+
|
77 |
+
def mouse_hover(self, event):
|
78 |
+
vis = self.annot.get_visible()
|
79 |
+
if event.inaxes == self.ax[0]:
|
80 |
+
cont, ind = self.low_dim_plt.contains(event)
|
81 |
+
if cont:
|
82 |
+
self.current_id = ind['ind'][0]
|
83 |
+
|
84 |
+
# copy spec into full window - probably a better way of doing this
|
85 |
+
new_spec = self.blank_spec.copy()
|
86 |
+
w_diff = (self.blank_spec.shape[1] - self.spec_slices[self.current_id].shape[1])//2
|
87 |
+
new_spec[:, w_diff:self.spec_slices[self.current_id].shape[1]+w_diff] = self.spec_slices[self.current_id]
|
88 |
+
self.spec_im.set_data(new_spec)
|
89 |
+
self.spec_im.set_clim(vmin=0, vmax=new_spec.max())
|
90 |
+
|
91 |
+
# draw bounding box around call
|
92 |
+
self.ax[1].patches[0].remove()
|
93 |
+
spec_width_orig = self.spec_slices[self.current_id].shape[1]/(1.0+2.0*self.spec_pad)
|
94 |
+
xx = w_diff + self.spec_pad*spec_width_orig
|
95 |
+
ww = spec_width_orig
|
96 |
+
yy = self.call_info[self.current_id]['low_freq']/1000
|
97 |
+
hh = (self.call_info[self.current_id]['high_freq']-self.call_info[self.current_id]['low_freq'])/1000
|
98 |
+
bbox = patches.Rectangle((xx,yy),ww,hh, edgecolor='r', linewidth=0.5, fill=False)
|
99 |
+
self.ax[1].add_patch(bbox)
|
100 |
+
|
101 |
+
# update annotation arrow
|
102 |
+
pos = self.low_dim_plt.get_offsets()[self.current_id]
|
103 |
+
self.annot.xy = pos
|
104 |
+
self.annot.set_visible(True)
|
105 |
+
|
106 |
+
# write call info
|
107 |
+
info_str = self.call_info[self.current_id]['file_name'] + ', time=' \
|
108 |
+
+ str(round(self.call_info[self.current_id]['start_time'],3)) \
|
109 |
+
+ ', prob=' + str(round(self.call_info[self.current_id]['det_prob'],3))
|
110 |
+
self.ax[0].set_xlabel(info_str)
|
111 |
+
|
112 |
+
# redraw
|
113 |
+
self.fig.canvas.draw_idle()
|
114 |
+
|
115 |
+
|
116 |
+
def key_press(self, event):
|
117 |
+
if event.key.isdigit():
|
118 |
+
self.labels_cols[self.current_id] = colors[int(event.key)]
|
119 |
+
self.labels[self.current_id] = int(event.key)
|
120 |
+
self.annotated[self.current_id] = 1
|
121 |
+
elif event.key == 'enter' and self.allow_training:
|
122 |
+
self.train_classifier()
|
123 |
+
elif event.key == 'x' and self.allow_training:
|
124 |
+
self.get_classifier_params()
|
125 |
+
|
126 |
+
self.ax[0].scatter(self.feats_ds[:, 0], self.feats_ds[:, 1],
|
127 |
+
c=self.labels_cols, s=self.pt_size)
|
128 |
+
self.fig.canvas.draw_idle()
|
129 |
+
|
130 |
+
|
131 |
+
def train_classifier(self):
|
132 |
+
# TODO maybe it's better to classify in 2D space - but then can't be linear ...
|
133 |
+
inds = np.where(self.annotated == 1)[0]
|
134 |
+
labs_un, labs_inds = np.unique(self.labels[inds], return_inverse=True)
|
135 |
+
|
136 |
+
if labs_un.shape[0] > 1: # needs at least 2 classes
|
137 |
+
self.clf = LinearSVC(C=1.0, penalty='l2', loss='squared_hinge', tol=0.0001,
|
138 |
+
intercept_scaling=1.0, max_iter=2000)
|
139 |
+
|
140 |
+
self.clf.fit(self.feats[inds, :], self.labels[inds])
|
141 |
+
|
142 |
+
# update labels
|
143 |
+
inds_unlab = np.where(self.annotated == 0)[0]
|
144 |
+
self.labels[inds_unlab] = self.clf.predict(self.feats[inds_unlab])
|
145 |
+
for ii in inds_unlab:
|
146 |
+
self.labels_cols[ii] = colors[self.labels[ii]]
|
147 |
+
else:
|
148 |
+
print('Not enough data - please label more classes.')
|
149 |
+
|
150 |
+
|
151 |
+
def get_classifier_params(self):
|
152 |
+
res = {}
|
153 |
+
if self.clf is None:
|
154 |
+
print('Model not trained!')
|
155 |
+
else:
|
156 |
+
res['weights'] = self.clf.coef_.astype(np.float32)
|
157 |
+
res['biases'] = self.clf.intercept_.astype(np.float32)
|
158 |
+
return res
|
bat_detect/utils/wavfile.py
ADDED
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Module to read / write wav files using numpy arrays
|
3 |
+
|
4 |
+
Functions
|
5 |
+
---------
|
6 |
+
`read`: Return the sample rate (in samples/sec) and data from a WAV file.
|
7 |
+
|
8 |
+
`write`: Write a numpy array as a WAV file.
|
9 |
+
|
10 |
+
"""
|
11 |
+
from __future__ import division, print_function, absolute_import
|
12 |
+
|
13 |
+
import sys
|
14 |
+
import numpy
|
15 |
+
import struct
|
16 |
+
import warnings
|
17 |
+
import os
|
18 |
+
|
19 |
+
|
20 |
+
class WavFileWarning(UserWarning):
|
21 |
+
pass
|
22 |
+
|
23 |
+
_big_endian = False
|
24 |
+
|
25 |
+
WAVE_FORMAT_PCM = 0x0001
|
26 |
+
WAVE_FORMAT_IEEE_FLOAT = 0x0003
|
27 |
+
WAVE_FORMAT_EXTENSIBLE = 0xfffe
|
28 |
+
KNOWN_WAVE_FORMATS = (WAVE_FORMAT_PCM, WAVE_FORMAT_IEEE_FLOAT)
|
29 |
+
|
30 |
+
# assumes file pointer is immediately
|
31 |
+
# after the 'fmt ' id
|
32 |
+
|
33 |
+
|
34 |
+
def _read_fmt_chunk(fid):
|
35 |
+
if _big_endian:
|
36 |
+
fmt = '>'
|
37 |
+
else:
|
38 |
+
fmt = '<'
|
39 |
+
res = struct.unpack(fmt+'iHHIIHH',fid.read(20))
|
40 |
+
size, comp, noc, rate, sbytes, ba, bits = res
|
41 |
+
if comp not in KNOWN_WAVE_FORMATS or size > 16:
|
42 |
+
comp = WAVE_FORMAT_PCM
|
43 |
+
warnings.warn("Unknown wave file format", WavFileWarning)
|
44 |
+
if size > 16:
|
45 |
+
fid.read(size - 16)
|
46 |
+
|
47 |
+
return size, comp, noc, rate, sbytes, ba, bits
|
48 |
+
|
49 |
+
|
50 |
+
# assumes file pointer is immediately
|
51 |
+
# after the 'data' id
|
52 |
+
def _read_data_chunk(fid, comp, noc, bits, mmap=False):
|
53 |
+
if _big_endian:
|
54 |
+
fmt = '>i'
|
55 |
+
else:
|
56 |
+
fmt = '<i'
|
57 |
+
size = struct.unpack(fmt,fid.read(4))[0]
|
58 |
+
|
59 |
+
bytes = bits//8
|
60 |
+
if bits == 8:
|
61 |
+
dtype = 'u1'
|
62 |
+
else:
|
63 |
+
if _big_endian:
|
64 |
+
dtype = '>'
|
65 |
+
else:
|
66 |
+
dtype = '<'
|
67 |
+
if comp == 1:
|
68 |
+
dtype += 'i%d' % bytes
|
69 |
+
else:
|
70 |
+
dtype += 'f%d' % bytes
|
71 |
+
if not mmap:
|
72 |
+
data = numpy.fromstring(fid.read(size), dtype=dtype)
|
73 |
+
else:
|
74 |
+
start = fid.tell()
|
75 |
+
data = numpy.memmap(fid, dtype=dtype, mode='c', offset=start,
|
76 |
+
shape=(size//bytes,))
|
77 |
+
fid.seek(start + size)
|
78 |
+
|
79 |
+
if noc > 1:
|
80 |
+
data = data.reshape(-1,noc)
|
81 |
+
return data
|
82 |
+
|
83 |
+
|
84 |
+
def _skip_unknown_chunk(fid):
|
85 |
+
if _big_endian:
|
86 |
+
fmt = '>i'
|
87 |
+
else:
|
88 |
+
fmt = '<i'
|
89 |
+
|
90 |
+
data = fid.read(4)
|
91 |
+
size = struct.unpack(fmt, data)[0]
|
92 |
+
fid.seek(size, 1)
|
93 |
+
|
94 |
+
|
95 |
+
def _read_riff_chunk(fid):
|
96 |
+
global _big_endian
|
97 |
+
str1 = fid.read(4)
|
98 |
+
if str1 == b'RIFX':
|
99 |
+
_big_endian = True
|
100 |
+
elif str1 != b'RIFF':
|
101 |
+
raise ValueError("Not a WAV file.")
|
102 |
+
if _big_endian:
|
103 |
+
fmt = '>I'
|
104 |
+
else:
|
105 |
+
fmt = '<I'
|
106 |
+
fsize = struct.unpack(fmt, fid.read(4))[0] + 8
|
107 |
+
str2 = fid.read(4)
|
108 |
+
if (str2 != b'WAVE'):
|
109 |
+
raise ValueError("Not a WAV file.")
|
110 |
+
if str1 == b'RIFX':
|
111 |
+
_big_endian = True
|
112 |
+
return fsize
|
113 |
+
|
114 |
+
# open a wave-file
|
115 |
+
|
116 |
+
|
117 |
+
def read(filename, mmap=False):
|
118 |
+
"""
|
119 |
+
Return the sample rate (in samples/sec) and data from a WAV file
|
120 |
+
|
121 |
+
Parameters
|
122 |
+
----------
|
123 |
+
filename : string or open file handle
|
124 |
+
Input wav file.
|
125 |
+
mmap : bool, optional
|
126 |
+
Whether to read data as memory mapped.
|
127 |
+
Only to be used on real files (Default: False)
|
128 |
+
|
129 |
+
.. versionadded:: 0.12.0
|
130 |
+
|
131 |
+
Returns
|
132 |
+
-------
|
133 |
+
rate : int
|
134 |
+
Sample rate of wav file
|
135 |
+
data : numpy array
|
136 |
+
Data read from wav file
|
137 |
+
|
138 |
+
Notes
|
139 |
+
-----
|
140 |
+
|
141 |
+
* The file can be an open file or a filename.
|
142 |
+
|
143 |
+
* The returned sample rate is a Python integer
|
144 |
+
* The data is returned as a numpy array with a
|
145 |
+
data-type determined from the file.
|
146 |
+
|
147 |
+
"""
|
148 |
+
if hasattr(filename,'read'):
|
149 |
+
fid = filename
|
150 |
+
mmap = False
|
151 |
+
else:
|
152 |
+
fid = open(filename, 'rb')
|
153 |
+
|
154 |
+
try:
|
155 |
+
|
156 |
+
# some files seem to have the size recorded in the header greater than
|
157 |
+
# the actual file size.
|
158 |
+
fid.seek(0, os.SEEK_END)
|
159 |
+
actual_size = fid.tell()
|
160 |
+
fid.seek(0)
|
161 |
+
|
162 |
+
fsize = _read_riff_chunk(fid)
|
163 |
+
|
164 |
+
# the fsize should be identical to the actual size, if not
|
165 |
+
# the header information is wrong and we need to correct it.
|
166 |
+
if fsize != actual_size:
|
167 |
+
fsize = actual_size
|
168 |
+
|
169 |
+
noc = 1
|
170 |
+
bits = 8
|
171 |
+
comp = WAVE_FORMAT_PCM
|
172 |
+
while (fid.tell() < fsize):
|
173 |
+
# read the next chunk
|
174 |
+
chunk_id = fid.read(4)
|
175 |
+
if chunk_id == b'fmt ':
|
176 |
+
size, comp, noc, rate, sbytes, ba, bits = _read_fmt_chunk(fid)
|
177 |
+
elif chunk_id == b'fact':
|
178 |
+
_skip_unknown_chunk(fid)
|
179 |
+
elif chunk_id == b'data':
|
180 |
+
data = _read_data_chunk(fid, comp, noc, bits, mmap=mmap)
|
181 |
+
elif chunk_id == b'LIST':
|
182 |
+
# Someday this could be handled properly but for now skip it
|
183 |
+
_skip_unknown_chunk(fid)
|
184 |
+
|
185 |
+
# OMA warning - I've commented out the following lines
|
186 |
+
# else:
|
187 |
+
# warnings.warn("Chunk (non-data) not understood, skipping it.", WavFileWarning)
|
188 |
+
# _skip_unknown_chunk(fid)
|
189 |
+
finally:
|
190 |
+
if not hasattr(filename,'read'):
|
191 |
+
fid.close()
|
192 |
+
else:
|
193 |
+
fid.seek(0)
|
194 |
+
|
195 |
+
return rate, data
|
196 |
+
|
197 |
+
# Write a wave-file
|
198 |
+
# sample rate, data
|
199 |
+
|
200 |
+
|
201 |
+
def write(filename, rate, data):
|
202 |
+
"""
|
203 |
+
Write a numpy array as a WAV file
|
204 |
+
|
205 |
+
Parameters
|
206 |
+
----------
|
207 |
+
filename : string or open file handle
|
208 |
+
Output wav file
|
209 |
+
rate : int
|
210 |
+
The sample rate (in samples/sec).
|
211 |
+
data : ndarray
|
212 |
+
A 1-D or 2-D numpy array of either integer or float data-type.
|
213 |
+
|
214 |
+
Notes
|
215 |
+
-----
|
216 |
+
* The file can be an open file or a filename.
|
217 |
+
|
218 |
+
* Writes a simple uncompressed WAV file.
|
219 |
+
* The bits-per-sample will be determined by the data-type.
|
220 |
+
* To write multiple-channels, use a 2-D array of shape
|
221 |
+
(Nsamples, Nchannels).
|
222 |
+
|
223 |
+
"""
|
224 |
+
if hasattr(filename, 'write'):
|
225 |
+
fid = filename
|
226 |
+
else:
|
227 |
+
fid = open(filename, 'wb')
|
228 |
+
|
229 |
+
try:
|
230 |
+
# kind of numeric data in the numpy array
|
231 |
+
dkind = data.dtype.kind
|
232 |
+
if not (dkind == 'i' or dkind == 'f' or (dkind == 'u' and data.dtype.itemsize == 1)):
|
233 |
+
raise ValueError("Unsupported data type '%s'" % data.dtype)
|
234 |
+
|
235 |
+
# wav header stuff
|
236 |
+
# http://soundfile.sapp.org/doc/WaveFormat/
|
237 |
+
fid.write(b'RIFF')
|
238 |
+
# placeholder for chunk size (updated later)
|
239 |
+
fid.write(b'\x00\x00\x00\x00')
|
240 |
+
fid.write(b'WAVE')
|
241 |
+
# fmt chunk
|
242 |
+
fid.write(b'fmt ')
|
243 |
+
if dkind == 'f':
|
244 |
+
# comp stands for compression. PCM = 1
|
245 |
+
comp = 3
|
246 |
+
else:
|
247 |
+
comp = 1
|
248 |
+
# determine number of channels
|
249 |
+
if data.ndim == 1:
|
250 |
+
noc = 1
|
251 |
+
else:
|
252 |
+
noc = data.shape[1]
|
253 |
+
bits = data.dtype.itemsize * 8
|
254 |
+
# number of bytes per second, at the specified sampling rate rate,
|
255 |
+
# bits per sample and number of channels (just needed for wav header)
|
256 |
+
sbytes = rate*(bits // 8)*noc
|
257 |
+
# number of bytes per sample
|
258 |
+
ba = noc * (bits // 8)
|
259 |
+
|
260 |
+
# https://docs.python.org/3/library/struct.html#struct-format-strings
|
261 |
+
# Write the data (16, comp, noc, etc) in the correct binary format
|
262 |
+
# for the wav header. the string format (first arg) specifies how many bytes for each
|
263 |
+
# value.
|
264 |
+
fid.write(struct.pack('<ihHIIHH', 16, comp, noc, rate, sbytes, ba, bits))
|
265 |
+
# data chunk: the word 'data' followed by the size followed by the actual data
|
266 |
+
fid.write(b'data')
|
267 |
+
fid.write(struct.pack('<i', data.nbytes))
|
268 |
+
if data.dtype.byteorder == '>' or (data.dtype.byteorder == '=' and sys.byteorder == 'big'):
|
269 |
+
data = data.byteswap()
|
270 |
+
_array_tofile(fid, data)
|
271 |
+
|
272 |
+
# Determine file size and place it in correct
|
273 |
+
# position at start of the file (replacing the 4 bytes of zeros)
|
274 |
+
size = fid.tell()
|
275 |
+
fid.seek(4)
|
276 |
+
fid.write(struct.pack('<i', size-8))
|
277 |
+
|
278 |
+
finally:
|
279 |
+
if not hasattr(filename,'write'):
|
280 |
+
fid.close()
|
281 |
+
else:
|
282 |
+
fid.seek(0)
|
283 |
+
|
284 |
+
|
285 |
+
if sys.version_info[0] >= 3:
|
286 |
+
def _array_tofile(fid, data):
|
287 |
+
# ravel gives a c-contiguous buffer
|
288 |
+
fid.write(data.ravel().view('b').data)
|
289 |
+
else:
|
290 |
+
def _array_tofile(fid, data):
|
291 |
+
fid.write(data.tostring())
|
example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav
ADDED
Binary file (500 kB). View file
|
example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav
ADDED
Binary file (384 kB). View file
|
example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav
ADDED
Binary file (384 kB). View file
|
models/readme.md
ADDED
@@ -0,0 +1 @@
|
|
|
1 |
+
Trained models go here.
|
requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
librosa==0.9.2
|
2 |
+
matplotlib==3.6.2
|
3 |
+
numpy==1.23.4
|
4 |
+
pandas==1.5.2
|
5 |
+
scikit_learn==1.2.0
|
6 |
+
scipy==1.9.3
|
7 |
+
torch==1.13.0
|
8 |
+
torchaudio==0.13.0
|
9 |
+
torchvision==0.14.0
|