batdetect2 / app.py
Oisin Mac Aodha
added bat code
9ace58a
raw
history blame
2.39 kB
import gradio as gr
import os
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import bat_detect.utils.detector_utils as du
import bat_detect.utils.audio_utils as au
import bat_detect.utils.plot_utils as viz
# setup the arguments
args = {}
args = du.get_default_bd_args()
args['detection_threshold'] = 0.3
args['time_expansion_factor'] = 1
args['model_path'] = 'models/Net2DFast_UK_same.pth.tar'
# load the model
model, params = du.load_model(args['model_path'])
df = gr.Dataframe(
headers=["species", "time_in_file", "species_prob"],
datatype=["str", "str", "str"],
row_count=1,
col_count=(3, "fixed"),
)
examples = [['example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav', 0.3],
['example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav', 0.3],
['example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav', 0.3]]
def make_prediction(file_name=None, detection_threshold=0.3):
if file_name is not None:
audio_file = file_name
else:
return "You must provide an input audio file."
if detection_threshold != '':
args['detection_threshold'] = float(detection_threshold)
results = du.process_file(audio_file, model, params, args, max_duration=5.0)
clss = [aa['class'] for aa in results['pred_dict']['annotation']]
st_time = [aa['start_time'] for aa in results['pred_dict']['annotation']]
cls_prob = [aa['class_prob'] for aa in results['pred_dict']['annotation']]
data = {'species': clss, 'time_in_file': st_time, 'species_prob': cls_prob}
df = pd.DataFrame(data=data)
return df
descr_txt = "Demo of BatDetect2 deep learning-based bat echolocation call detection. " \
"<br>This model is only trained on bat species from the UK. If the input " \
"file is longer than 5 seconds, only the first 5 seconds will be processed." \
"<br>Check out the paper [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)."
gr.Interface(
fn = make_prediction,
inputs = [gr.Audio(source="upload", type="filepath", optional=True),
gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])],
outputs = df,
theme = "huggingface",
title = "BatDetect2 Demo",
description = descr_txt,
examples = examples,
allow_flagging = 'never',
).launch()