Spaces:
Running
Running
update app
Browse files
app.py
CHANGED
@@ -15,19 +15,19 @@ args = du.get_default_bd_args()
|
|
15 |
args['detection_threshold'] = 0.3
|
16 |
args['time_expansion_factor'] = 1
|
17 |
args['model_path'] = 'models/Net2DFast_UK_same.pth.tar'
|
|
|
18 |
|
19 |
# load the model
|
20 |
model, params = du.load_model(args['model_path'])
|
21 |
|
22 |
|
23 |
df = gr.Dataframe(
|
24 |
-
headers=["species", "
|
25 |
-
datatype=["str", "str", "str"],
|
26 |
row_count=1,
|
27 |
-
col_count=(
|
28 |
)
|
29 |
|
30 |
-
|
31 |
examples = [['example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav', 0.3],
|
32 |
['example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav', 0.3],
|
33 |
['example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav', 0.3]]
|
@@ -43,33 +43,63 @@ def make_prediction(file_name=None, detection_threshold=0.3):
|
|
43 |
if detection_threshold is not None and detection_threshold != '':
|
44 |
args['detection_threshold'] = float(detection_threshold)
|
45 |
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
-
clss = [aa['class'] for aa in results['pred_dict']['annotation']]
|
49 |
-
st_time = [aa['start_time'] for aa in results['pred_dict']['annotation']]
|
50 |
-
cls_prob = [aa['class_prob'] for aa in results['pred_dict']['annotation']]
|
51 |
-
|
52 |
-
data = {'species': clss, 'time_in_file': st_time, 'species_prob': cls_prob}
|
53 |
df = pd.DataFrame(data=data)
|
|
|
|
|
|
|
54 |
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
|
58 |
descr_txt = "Demo of BatDetect2 deep learning-based bat echolocation call detection. " \
|
59 |
"<br>This model is only trained on bat species from the UK. If the input " \
|
60 |
-
"file is longer than
|
61 |
"<br>Check out the paper [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)."
|
62 |
|
63 |
gr.Interface(
|
64 |
fn = make_prediction,
|
65 |
inputs = [gr.Audio(source="upload", type="filepath", optional=True),
|
66 |
gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])],
|
67 |
-
outputs = df,
|
68 |
theme = "huggingface",
|
69 |
title = "BatDetect2 Demo",
|
70 |
description = descr_txt,
|
71 |
examples = examples,
|
72 |
allow_flagging = 'never',
|
73 |
).launch()
|
74 |
-
|
75 |
-
|
|
|
15 |
# --- Demo configuration -----------------------------------------------------
# Detector defaults; the threshold can be overridden per-request from the UI
# dropdown (see make_prediction).
args['detection_threshold'] = 0.3
args['time_expansion_factor'] = 1
args['model_path'] = 'models/Net2DFast_UK_same.pth.tar'
max_duration = 2.0  # only the first 2 seconds of an upload are processed

# load the model
model, params = du.load_model(args['model_path'])

# Output table: one row per detected call, four fixed string columns.
df = gr.Dataframe(
    headers=["species", "time", "detection_prob", "species_prob"],
    datatype=["str", "str", "str", "str"],
    row_count=1,
    col_count=(4, "fixed"),
)

# Bundled example clips shown under the interface: [audio path, threshold].
examples = [['example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav', 0.3],
            ['example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav', 0.3],
            ['example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav', 0.3]]
|
|
|
43 |
if detection_threshold is not None and detection_threshold != '':
|
44 |
args['detection_threshold'] = float(detection_threshold)
|
45 |
|
46 |
+
# process the file to generate predictions
|
47 |
+
results = du.process_file(audio_file, model, params, args, max_duration=max_duration)
|
48 |
+
|
49 |
+
anns = [ann for ann in results['pred_dict']['annotation']]
|
50 |
+
clss = [aa['class'] for aa in anns]
|
51 |
+
st_time = [aa['start_time'] for aa in anns]
|
52 |
+
cls_prob = [aa['class_prob'] for aa in anns]
|
53 |
+
det_prob = [aa['det_prob'] for aa in anns]
|
54 |
+
data = {'species': clss, 'time': st_time, 'detection_prob': det_prob, 'species_prob': cls_prob}
|
55 |
|
|
|
|
|
|
|
|
|
|
|
56 |
df = pd.DataFrame(data=data)
|
57 |
+
im = generate_results_image(audio_file, anns)
|
58 |
+
|
59 |
+
return [df, im]
|
60 |
|
61 |
+
|
62 |
+
def generate_results_image(audio_file, anns):
    """Render the clip's spectrogram with detection boxes overlaid.

    Args:
        audio_file: path to the uploaded audio clip.
        anns: list of annotation dicts as produced by ``du.process_file``
            (each presumably carries start/end time and frequency bounds
            for one detected call — confirm against bat_detect utils).

    Returns:
        A (height, width, channels) uint8 numpy array of the rendered figure,
        suitable for a Gradio "image" output.
    """
    # load audio the same way the detector saw it (time expansion, resample,
    # scaling, truncation to max_duration)
    sampling_rate, audio = au.load_audio_file(audio_file, args['time_expansion_factor'],
                           params['target_samp_rate'], params['scale_raw_audio'], max_duration=max_duration)

    # generate spec (viz variant is unused here but returned by the helper)
    spec, spec_viz = au.generate_spectrogram(audio, sampling_rate, params, True, False)

    # create fig sized 1:1 to the spectrogram pixels (dpi=100, so /100)
    plt.close('all')
    fig = plt.figure(1, figsize=(spec.shape[1]/100, spec.shape[0]/100), dpi=100, frameon=False)
    spec_duration = au.x_coords_to_time(spec.shape[1], sampling_rate, params['fft_win_length'], params['fft_overlap'])
    viz.create_box_image(spec, fig, anns, 0, spec_duration, spec_duration, params, spec.max()*1.1, False, True)
    plt.ylabel('Freq - kHz')
    plt.xlabel('Time - secs')
    plt.tight_layout()

    # convert fig to image: rasterize, then wrap the RGB byte buffer as an
    # (h, w, -1) numpy array
    fig.canvas.draw()
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    w, h = fig.canvas.get_width_height()
    im = data.reshape((int(h), int(w), -1))

    return im
|
88 |
|
89 |
|
90 |
# Description shown above the demo; Gradio renders the <br> breaks and the
# markdown link.
descr_txt = "Demo of BatDetect2 deep learning-based bat echolocation call detection. " \
            "<br>This model is only trained on bat species from the UK. If the input " \
            "file is longer than 2 seconds, only the first 2 seconds will be processed." \
            "<br>Check out the paper [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)."

# Build and launch the app: audio upload + threshold dropdown in,
# results dataframe + annotated spectrogram image out.
gr.Interface(
    fn = make_prediction,
    inputs = [gr.Audio(source="upload", type="filepath", optional=True),
              gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])],
    outputs = [df, "image"],
    theme = "huggingface",
    title = "BatDetect2 Demo",
    description = descr_txt,
    examples = examples,
    allow_flagging = 'never',
).launch()
|
|
|
|