# nbm_v1 / app.py
# paulpeyret-biophonia
# fix example file
# 50d2530
# %%
import os
import csv
import soundfile as sf
import numpy as np
import json
import time
import gradio
import torchaudio
from torchvision.utils import save_image
from pathlib import Path
import sys
import pandas as pd
from PIL import Image
from ia_model_utils import *
# %%
# Paths are resolved from the directory the app was launched in.
DATADIR = os.getcwd()
model_path = os.path.join(DATADIR, "ia_data/")

# The custom scripts used to be loaded after moving into the model folder:
# os.chdir(os.path.join(model_path, "code"))
from ia_model_utils import *

# Put the working directory back to the launch directory.
os.chdir(DATADIR)

# Load the detection model, on the first GPU when CUDA is available and on
# the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
bird_call_detection = load_model(
    model_path,
    post_nms_topN_eval=50,
    device=device,
)

# Example wav files are expected at the root of the example folder.
EXAMPLES_PATH = Path(".")

# List of example files found at the root of the folder (currently empty).
example_files = list()
fname_examples = example_files
# Notebook variant that stored the file listing in a variable:
# fname_examples=!ls *.wav
print("Voici la liste des fichiers disponibles dans l'environement")
print(fname_examples)
# %%
# Keyword arguments forwarded to ``gradio.Interface`` when the demo is built.
interface_options = dict(
    title="NBM Classification",
    description="Online NBM classifier: Please upload a wavfile",
    article="Online NBM classifier: Please upload a wavfile",
    # Audio from validation file
    # (alternative: [str(EXAMPLES_PATH/f) for f in fname_examples])
    examples=fname_examples,
    allow_flagging="never",
)
def modeloutput_to_formated_data(
    class_bbox, temp_res, out_fname="labels-test.txt", dict_dir="./ia_data"
):
    """Write the model detections to a label file in Audacity text format.

    Detections labelled "Non bird sound" or "Other" are left out. The
    remaining ones are ordered by the x1 coordinate of their bounding box
    and written as two lines each::

        <x1 * temp_res>\t<x2 * temp_res>\t<label> <score, 2 decimals>
        \\\t<y1 * 33.3 + 500>\t<y2 * 33.3 + 500>

    Args:
        class_bbox: Mapping from class index (as a string) to a dict with the
            keys ``"bbox_coord"`` and ``"scores"``. Both values must support
            ``len()`` and ``.cpu().numpy().tolist()`` (presumably torch
            tensors -- confirm against ``process_wav``). Every index from 0
            to the number of known classes must be present.
        temp_res: Factor applied to the x coordinates of every box
            (presumably the duration of one spectrogram column -- confirm).
        out_fname: Path of the label file to write.
        dict_dir: Directory that contains ``bird_dict.json``.

    Returns:
        A tuple ``(out_fname, data)`` where ``data`` is the text that was
        written to ``out_fname``.
    """
    # Load the species-name -> class-id dictionary.
    with open(
        os.path.join(dict_dir, "bird_dict.json"), "r", encoding="utf-8"
    ) as f:
        birds_dict = json.load(f)
    # Two extra classes: id 0 and one id past the current dictionary size
    # (``len`` is evaluated before the update takes place).
    birds_dict.update({"Non bird sound": 0, "Other": len(birds_dict) + 1})
    reverse_dict = {
        class_id: bird_name for bird_name, class_id in birds_dict.items()
    }
    ignored_labels = ("Non bird sound", "Other")

    # Flatten the per-class detections into [label, score, x1, y1, x2, y2].
    rows = []
    for idx in range(len(reverse_dict)):
        detections = class_bbox[str(idx)]
        if len(detections["bbox_coord"]) == 0:
            continue
        label = reverse_dict[idx]
        if label in ignored_labels:
            continue
        bboxes = detections["bbox_coord"].cpu().numpy().tolist()
        scores = detections["scores"].cpu().numpy().tolist()
        for bbox, score in zip(bboxes, scores):
            rows.append([label, score] + bbox)

    # Sort by bbox start position. This must sort the list itself: sorting a
    # slice such as ``table[1:]`` only sorts a temporary copy.
    rows.sort(key=lambda row: float(row[2]))

    # Convert to the Audacity txt format.
    entries = []
    for label, score, x1, y1, x2, y2 in rows:
        # NOTE(review): 33.3 and 500 presumably map the model's y axis to a
        # frequency in Hz -- confirm against the spectrogram settings.
        y1 = y1 * 33.3 + 500
        y2 = y2 * 33.3 + 500
        x1 = x1 * temp_res
        x2 = x2 * temp_res
        entries.append(f"{x1}\t{x2}\t{label} {score:.2f}\n\\\t{y1}\t{y2}\n")
    data = "".join(entries)

    with open(out_fname, "w", encoding="utf-8") as f:
        f.write(data)
    return out_fname, data
def end2endpipeline(wav_path):
    """Run the detector on one wav file and export its labels.

    The file is processed by ``bird_call_detection.process_wav`` with a
    minimum score of 0.5. The detections are then written, through
    ``modeloutput_to_formated_data``, to a text file named after the wav
    file's stem (``<stem>.txt``, relative to the working directory).

    Args:
        wav_path: Path of the wav file to analyse.

    Returns:
        A tuple ``(label_file_path, label_text)``.
    """
    # The spectrogram returned by the model is not used here.
    detections, _spectrogram, x_scale = bird_call_detection.process_wav(
        wav_path, min_score=0.5
    )
    label_file = f"{Path(wav_path).stem}.txt"
    written_path, label_text = modeloutput_to_formated_data(
        detections, x_scale, out_fname=label_file
    )
    return written_path, label_text
# Build the web UI: one uploaded audio file in, the label file plus a
# read-only text preview of its content out.
audio_input = gradio.Audio(sources="upload", type="filepath")
result_outputs = [gradio.File(), gradio.Textbox(interactive=False)]
demo = gradio.Interface(
    fn=end2endpipeline,
    inputs=audio_input,
    outputs=result_outputs,
    **interface_options,
)

# Options passed to ``demo.launch``; disabled ones are kept for reference.
launch_options = dict(
    # enable_queue=True,
    # share=True,
    inline=True,
    # cache_examples=True,
)
demo.launch(**launch_options)