macaodha committed on
Commit
73fd754
1 Parent(s): 7475d7b

update app

Browse files
Files changed (1) hide show
  1. app.py +45 -15
app.py CHANGED
@@ -15,19 +15,19 @@ args = du.get_default_bd_args()
15
  args['detection_threshold'] = 0.3
16
  args['time_expansion_factor'] = 1
17
  args['model_path'] = 'models/Net2DFast_UK_same.pth.tar'
 
18
 
19
  # load the model
20
  model, params = du.load_model(args['model_path'])
21
 
22
 
23
  df = gr.Dataframe(
24
- headers=["species", "time_in_file", "species_prob"],
25
- datatype=["str", "str", "str"],
26
  row_count=1,
27
- col_count=(3, "fixed"),
28
  )
29
 
30
-
31
  examples = [['example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav', 0.3],
32
  ['example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav', 0.3],
33
  ['example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav', 0.3]]
@@ -43,33 +43,63 @@ def make_prediction(file_name=None, detection_threshold=0.3):
43
  if detection_threshold is not None and detection_threshold != '':
44
  args['detection_threshold'] = float(detection_threshold)
45
 
46
- results = du.process_file(audio_file, model, params, args, max_duration=5.0)
 
 
 
 
 
 
 
 
47
 
48
- clss = [aa['class'] for aa in results['pred_dict']['annotation']]
49
- st_time = [aa['start_time'] for aa in results['pred_dict']['annotation']]
50
- cls_prob = [aa['class_prob'] for aa in results['pred_dict']['annotation']]
51
-
52
- data = {'species': clss, 'time_in_file': st_time, 'species_prob': cls_prob}
53
  df = pd.DataFrame(data=data)
 
 
 
54
 
55
- return df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
 
58
  descr_txt = "Demo of BatDetect2 deep learning-based bat echolocation call detection. " \
59
  "<br>This model is only trained on bat species from the UK. If the input " \
60
- "file is longer than 5 seconds, only the first 5 seconds will be processed." \
61
  "<br>Check out the paper [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)."
62
 
63
  gr.Interface(
64
  fn = make_prediction,
65
  inputs = [gr.Audio(source="upload", type="filepath", optional=True),
66
  gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])],
67
- outputs = df,
68
  theme = "huggingface",
69
  title = "BatDetect2 Demo",
70
  description = descr_txt,
71
  examples = examples,
72
  allow_flagging = 'never',
73
  ).launch()
74
-
75
-
 
15
  args['detection_threshold'] = 0.3
16
  args['time_expansion_factor'] = 1
17
  args['model_path'] = 'models/Net2DFast_UK_same.pth.tar'
18
+ max_duration = 2.0
19
 
20
  # load the model
21
  model, params = du.load_model(args['model_path'])
22
 
23
 
24
  df = gr.Dataframe(
25
+ headers=["species", "time", "detection_prob", "species_prob"],
26
+ datatype=["str", "str", "str", "str"],
27
  row_count=1,
28
+ col_count=(4, "fixed"),
29
  )
30
 
 
31
  examples = [['example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav', 0.3],
32
  ['example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav', 0.3],
33
  ['example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav', 0.3]]
 
43
  if detection_threshold is not None and detection_threshold != '':
44
  args['detection_threshold'] = float(detection_threshold)
45
 
46
+ # process the file to generate predictions
47
+ results = du.process_file(audio_file, model, params, args, max_duration=max_duration)
48
+
49
+ anns = [ann for ann in results['pred_dict']['annotation']]
50
+ clss = [aa['class'] for aa in anns]
51
+ st_time = [aa['start_time'] for aa in anns]
52
+ cls_prob = [aa['class_prob'] for aa in anns]
53
+ det_prob = [aa['det_prob'] for aa in anns]
54
+ data = {'species': clss, 'time': st_time, 'detection_prob': det_prob, 'species_prob': cls_prob}
55
 
 
 
 
 
 
56
  df = pd.DataFrame(data=data)
57
+ im = generate_results_image(audio_file, anns)
58
+
59
+ return [df, im]
60
 
61
+
62
+ def generate_results_image(audio_file, anns):
63
+
64
+ # load audio
65
+ sampling_rate, audio = au.load_audio_file(audio_file, args['time_expansion_factor'],
66
+ params['target_samp_rate'], params['scale_raw_audio'], max_duration=max_duration)
67
+ duration = audio.shape[0] / sampling_rate
68
+
69
+ # generate spec
70
+ spec, spec_viz = au.generate_spectrogram(audio, sampling_rate, params, True, False)
71
+
72
+ # create fig
73
+ plt.close('all')
74
+ fig = plt.figure(1, figsize=(spec.shape[1]/100, spec.shape[0]/100), dpi=100, frameon=False)
75
+ spec_duration = au.x_coords_to_time(spec.shape[1], sampling_rate, params['fft_win_length'], params['fft_overlap'])
76
+ viz.create_box_image(spec, fig, anns, 0, spec_duration, spec_duration, params, spec.max()*1.1, False, True)
77
+ plt.ylabel('Freq - kHz')
78
+ plt.xlabel('Time - secs')
79
+ plt.tight_layout()
80
+
81
+ # convert fig to image
82
+ fig.canvas.draw()
83
+ data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
84
+ w, h = fig.canvas.get_width_height()
85
+ im = data.reshape((int(h), int(w), -1))
86
+
87
+ return im
88
 
89
 
90
  descr_txt = "Demo of BatDetect2 deep learning-based bat echolocation call detection. " \
91
  "<br>This model is only trained on bat species from the UK. If the input " \
92
+ "file is longer than 2 seconds, only the first 2 seconds will be processed." \
93
  "<br>Check out the paper [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)."
94
 
95
  gr.Interface(
96
  fn = make_prediction,
97
  inputs = [gr.Audio(source="upload", type="filepath", optional=True),
98
  gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])],
99
+ outputs = [df, "image"],
100
  theme = "huggingface",
101
  title = "BatDetect2 Demo",
102
  description = descr_txt,
103
  examples = examples,
104
  allow_flagging = 'never',
105
  ).launch()