import gradio as gr
import os
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

import bat_detect.utils.detector_utils as du
import bat_detect.utils.audio_utils as au
import bat_detect.utils.plot_utils as viz

# setup the arguments
args = du.get_default_bd_args()
args['detection_threshold'] = 0.3
args['time_expansion_factor'] = 1
args['model_path'] = 'models/Net2DFast_UK_same.pth.tar'
max_duration = 2.0

# load the model
model, params = du.load_model(args['model_path'])

df = gr.Dataframe(
    headers=["species", "time", "detection_prob", "species_prob"],
    datatype=["str", "str", "str", "str"],
    row_count=1,
    col_count=(4, "fixed"),
    label='Predictions',
)

examples = [['example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav', 0.3],
            ['example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav', 0.3],
            ['example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav', 0.3]]


def make_prediction(file_name=None, detection_threshold=0.3):
    if file_name is not None:
        audio_file = file_name
    else:
        return "You must provide an input audio file."

    if detection_threshold is not None and detection_threshold != '':
        args['detection_threshold'] = float(detection_threshold)

    # process the file to generate predictions
    results = du.process_file(audio_file, model, params, args, max_duration=max_duration)
    anns = [ann for ann in results['pred_dict']['annotation']]
    clss = [aa['class'] for aa in anns]
    st_time = [aa['start_time'] for aa in anns]
    cls_prob = [aa['class_prob'] for aa in anns]
    det_prob = [aa['det_prob'] for aa in anns]

    data = {'species': clss,
            'time': st_time,
            'detection_prob': det_prob,
            'species_prob': cls_prob}
    df = pd.DataFrame(data=data)
    im = generate_results_image(audio_file, anns)
    return [df, im]


def generate_results_image(audio_file, anns):
    # load audio
    sampling_rate, audio = au.load_audio_file(audio_file, args['time_expansion_factor'],
                                              params['target_samp_rate'], params['scale_raw_audio'],
                                              max_duration=max_duration)
    duration = audio.shape[0] / sampling_rate

    # generate spec
    spec, spec_viz = au.generate_spectrogram(audio, sampling_rate, params, True, False)

    # create fig
    plt.close('all')
    fig = plt.figure(1, figsize=(spec.shape[1] / 100, spec.shape[0] / 100),
                     dpi=100, frameon=False)
    spec_duration = au.x_coords_to_time(spec.shape[1], sampling_rate,
                                        params['fft_win_length'], params['fft_overlap'])
    viz.create_box_image(spec, fig, anns, 0, spec_duration, spec_duration, params,
                         spec.max() * 1.1, False, True)
    plt.ylabel('Freq - kHz')
    plt.xlabel('Time - secs')
    plt.tight_layout()

    # convert fig to image
    fig.canvas.draw()
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    w, h = fig.canvas.get_width_height()
    im = data.reshape((int(h), int(w), -1))
    return im
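
# A minimal sketch of driving the pipeline directly from Python, outside the
# Gradio UI (e.g. for batch processing). It assumes the bundled example audio
# files are present on disk and only uses the functions defined above; kept
# commented out so importing this module still just launches the demo.
#
#   preds, vis = make_prediction(
#       'example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav',
#       detection_threshold=0.3)
#   # preds: pandas DataFrame with one row per detected call
#   # vis: annotated spectrogram, a uint8 NumPy array of shape (height, width, 3)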

descr_txt = "Demo of BatDetect2 deep learning-based bat echolocation call detection. " \
            "<br>This model is only trained on bat species from the UK. If the input " \
            "file is longer than 2 seconds, only the first 2 seconds will be processed." \
            "<br>Check out the paper [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)."

gr.Interface(
    fn = make_prediction,
    inputs = [gr.Audio(source="upload", type="filepath", optional=True),
              gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])],
    outputs = [df, gr.Image(label="Visualisation")],
    theme = "huggingface",
    title = "BatDetect2 Demo",
    description = descr_txt,
    examples = examples,
    allow_flagging = 'never',
).launch()
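
# Note: launch() blocks until the server is stopped. To expose the demo via a
# temporary public URL (standard Gradio behaviour, not specific to this app),
# share=True could be passed to launch().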