# BirdNet / app.py
# BilalSardar's picture
# Update app.py
# 9d7f3c9
import concurrent.futures
import os
import sys
from multiprocessing import freeze_support
from pathlib import Path
import gradio as gr
import librosa
#import webview
import analyze
import config as cfg
import segments
import species
import utils
from train import trainModel
#_WINDOW: webview.Window
# Maps result-format labels to the identifiers that runAnalysis stores in
# cfg.RESULT_TYPE.
OUTPUT_TYPE_MAP = {
    "Raven selection table": "table",
    "Audacity": "audacity",
    "R": "r",
    "CSV": "csv",
}

# Snapshot of the config values taken at import time, so they stay available
# after the cfg module has been modified (runAnalysis overwrites cfg.LABELS_FILE
# when a custom classifier is used).
ORIGINAL_MODEL_PATH = cfg.MODEL_PATH
ORIGINAL_MDATA_MODEL_PATH = cfg.MDATA_MODEL_PATH
ORIGINAL_LABELS_FILE = cfg.LABELS_FILE
ORIGINAL_TRANSLATED_LABELS_PATH = cfg.TRANSLATED_LABELS_PATH
def analyzeFile_wrapper(entry):
    """Analyzes one work item and pairs the result with the item's first element.

    Args:
        entry: Work item that is passed unchanged to analyze.analyzeFile.

    Returns:
        A tuple of (entry[0], result of analyze.analyzeFile(entry)).
    """
    key = entry[0]
    outcome = analyze.analyzeFile(entry)

    return key, outcome
def extractSegments_wrapper(entry):
    """Extracts the segments of one work item and pairs the result with its key.

    Args:
        entry: Work item that is passed unchanged to segments.extractSegments.

    Returns:
        A tuple of (entry[0][0], result of segments.extractSegments(entry)).
    """
    key = entry[0][0]
    outcome = segments.extractSegments(entry)

    return key, outcome
def validate(value, msg):
    """Ensures that the value is not falsy.

    Args:
        value: Value to be tested.
        msg: Message in case of an error.

    Raises:
        gr.Error: If the value is falsy.
    """
    if value:
        return

    raise gr.Error(msg)
def runSingleFileAnalysis(
    input_path,
    confidence,
    sensitivity,
    overlap,
    species_list_choice,
    species_list_file,
    lat,
    lon,
    week,
    use_yearlong,
    sf_thresh,
    custom_classifier_file,
    locale,
):
    """Runs the analysis for a single audio file.

    Checks that a file was selected and forwards the settings to runAnalysis
    with fixed values for everything that only matters in directory mode
    (no output path, "csv" output, batch size 1, 4 threads, no input
    directory, no progress bar).

    Args:
        input_path: Path of the audio file.
        confidence: The selected minimum confidence.
        sensitivity: The selected sensitivity.
        overlap: The selected segment overlap.
        species_list_choice: The choice for the species list.
        species_list_file: The selected custom species list file.
        lat: The selected latitude.
        lon: The selected longitude.
        week: The selected week of the year.
        use_yearlong: Use yearlong instead of week.
        sf_thresh: The threshold for the predicted species list.
        custom_classifier_file: Custom classifier to be used.
        locale: The translation to be used; falls back to "en" if falsy.

    Returns:
        The return value of runAnalysis, which is cfg.OUTPUT_PATH because no
        input directory is passed.

    Raises:
        gr.Error: If input_path is falsy.
    """
    validate(input_path, "Please select a file.")

    return runAnalysis(
        input_path=input_path,
        output_path=None,
        confidence=confidence,
        sensitivity=sensitivity,
        overlap=overlap,
        species_list_choice=species_list_choice,
        species_list_file=species_list_file,
        lat=lat,
        lon=lon,
        week=week,
        use_yearlong=use_yearlong,
        sf_thresh=sf_thresh,
        custom_classifier_file=custom_classifier_file,
        output_type="csv",
        locale=locale or "en",
        batch_size=1,
        threads=4,
        input_dir=None,
        progress=None,
    )
def runAnalysis(
    input_path: str,
    output_path: str | None,
    confidence: float,
    sensitivity: float,
    overlap: float,
    species_list_choice: str,
    species_list_file,
    lat: float,
    lon: float,
    week: int,
    use_yearlong: bool,
    sf_thresh: float,
    custom_classifier_file,
    output_type: str,
    locale: str,
    batch_size: int,
    threads: int,
    input_dir: str,
    progress: gr.Progress | None,
):
    """Starts the analysis.

    Writes all settings into the global cfg module, collects the input files
    and analyzes them, either sequentially or in a process pool.

    Args:
        input_path: Either a file or directory.
        output_path: The output path for the result, if None the input_path is used
        confidence: The selected minimum confidence.
        sensitivity: The selected sensitivity.
        overlap: The selected segment overlap.
        species_list_choice: The choice for the species list.
        species_list_file: The selected custom species list file.
        lat: The selected latitude.
        lon: The selected longitude.
        week: The selected week of the year.
        use_yearlong: Use yearlong instead of week.
        sf_thresh: The threshold for the predicted species list.
        custom_classifier_file: Custom classifier to be used.
        output_type: The type of result to be generated.
        locale: The translation to be used.
        batch_size: The number of samples in a batch.
        threads: The number of threads to be used. With an input directory this is
            the number of worker processes, otherwise the number of TFLite threads.
        input_dir: The input directory.
        progress: The gradio progress bar.

    Returns:
        If input_dir is set, a list of [file path relative to input_dir, analysis result]
        entries, otherwise cfg.OUTPUT_PATH.

    Raises:
        gr.Error: If the custom classifier choice is active but no classifier was selected,
            or if no audio files were found.
    """
    if progress is not None:
        progress(0, desc="Preparing ...")

    locale = locale.lower()

    # Load eBird codes, labels
    cfg.CODES = analyze.loadCodes()
    # A previous run with a custom classifier leaves cfg.LABELS_FILE pointing at the
    # classifier's labels; restore the original so the translated labels lookup below
    # is based on the same file as cfg.LABELS.
    cfg.LABELS_FILE = ORIGINAL_LABELS_FILE
    cfg.LABELS = utils.readLines(ORIGINAL_LABELS_FILE)
    cfg.LATITUDE, cfg.LONGITUDE, cfg.WEEK = lat, lon, -1 if use_yearlong else week
    cfg.LOCATION_FILTER_THRESHOLD = sf_thresh

    if species_list_choice == _CUSTOM_SPECIES:
        if not species_list_file or not species_list_file.name:
            cfg.SPECIES_LIST_FILE = None
        else:
            script_dir = os.path.dirname(os.path.abspath(sys.argv[0]))
            cfg.SPECIES_LIST_FILE = os.path.join(script_dir, species_list_file.name)

            if os.path.isdir(cfg.SPECIES_LIST_FILE):
                cfg.SPECIES_LIST_FILE = os.path.join(cfg.SPECIES_LIST_FILE, "species_list.txt")

        cfg.SPECIES_LIST = utils.readLines(cfg.SPECIES_LIST_FILE)
        cfg.CUSTOM_CLASSIFIER = None
    elif species_list_choice == _PREDICT_SPECIES:
        cfg.SPECIES_LIST_FILE = None
        cfg.CUSTOM_CLASSIFIER = None
        cfg.SPECIES_LIST = species.getSpeciesList(cfg.LATITUDE, cfg.LONGITUDE, cfg.WEEK, cfg.LOCATION_FILTER_THRESHOLD)
    elif species_list_choice == _CUSTOM_CLASSIFIER:
        if custom_classifier_file is None:
            raise gr.Error("No custom classifier selected.")

        # Set custom classifier?
        cfg.CUSTOM_CLASSIFIER = custom_classifier_file  # we treat this as absolute path, so no need to join with dirname
        # Only swap the extension; a plain str.replace would also rewrite any
        # ".tflite" occurring earlier in the path.
        cfg.LABELS_FILE = os.path.splitext(custom_classifier_file)[0] + "_Labels.txt"  # same for labels file
        cfg.LABELS = utils.readLines(cfg.LABELS_FILE)
        cfg.LATITUDE = -1
        cfg.LONGITUDE = -1
        cfg.SPECIES_LIST_FILE = None
        cfg.SPECIES_LIST = []
        locale = "en"
    else:
        cfg.SPECIES_LIST_FILE = None
        cfg.SPECIES_LIST = []
        cfg.CUSTOM_CLASSIFIER = None

    # Load translated labels
    lfile = os.path.join(cfg.TRANSLATED_LABELS_PATH, os.path.basename(cfg.LABELS_FILE).replace(".txt", f"_{locale}.txt"))

    if locale != "en" and os.path.isfile(lfile):
        cfg.TRANSLATED_LABELS = utils.readLines(lfile)
    else:
        cfg.TRANSLATED_LABELS = cfg.LABELS

    # An empty species list means that all labels are used.
    species_count = len(cfg.SPECIES_LIST) or len(cfg.LABELS)
    print(f"Species list contains {species_count} species")

    # Set input and output path
    cfg.INPUT_PATH = input_path

    if input_dir:
        cfg.OUTPUT_PATH = output_path or input_dir
    else:
        # splitext only strips the file extension; splitting on the first "." of the
        # whole path would cut it at a dotted directory name or a leading "./".
        cfg.OUTPUT_PATH = output_path or os.path.splitext(input_path)[0] + ".csv"

    # Parse input files
    if input_dir:
        cfg.FILE_LIST = utils.collect_audio_files(input_dir)
        cfg.INPUT_PATH = input_dir
    elif os.path.isdir(cfg.INPUT_PATH):
        cfg.FILE_LIST = utils.collect_audio_files(cfg.INPUT_PATH)
    else:
        cfg.FILE_LIST = [cfg.INPUT_PATH]

    validate(cfg.FILE_LIST, "No audio files found.")

    # Set confidence threshold
    cfg.MIN_CONFIDENCE = confidence

    # Set sensitivity
    cfg.SIGMOID_SENSITIVITY = sensitivity

    # Set overlap
    cfg.SIG_OVERLAP = overlap

    # Set result type
    cfg.RESULT_TYPE = OUTPUT_TYPE_MAP.get(output_type, output_type.lower())

    if cfg.RESULT_TYPE not in ["table", "audacity", "r", "csv"]:
        cfg.RESULT_TYPE = "table"

    # Set number of threads
    if input_dir:
        cfg.CPU_THREADS = max(1, int(threads))
        cfg.TFLITE_THREADS = 1
    else:
        cfg.CPU_THREADS = 1
        cfg.TFLITE_THREADS = max(1, int(threads))

    # Set batch size
    cfg.BATCH_SIZE = max(1, int(batch_size))

    # Every work item carries its own copy of the current config.
    flist = [(f, cfg.getConfig()) for f in cfg.FILE_LIST]

    result_list = []

    if progress is not None:
        progress(0, desc="Starting ...")

    # Analyze files
    if cfg.CPU_THREADS < 2:
        result_list = [analyzeFile_wrapper(entry) for entry in flist]
    else:
        with concurrent.futures.ProcessPoolExecutor(max_workers=cfg.CPU_THREADS) as executor:
            futures = [executor.submit(analyzeFile_wrapper, arg) for arg in flist]

            for i, f in enumerate(concurrent.futures.as_completed(futures), start=1):
                if progress is not None:
                    progress((i, len(flist)), total=len(flist), unit="files")

                result_list.append(f.result())

    if input_dir:
        return [[os.path.relpath(r[0], input_dir), r[1]] for r in result_list]

    return cfg.OUTPUT_PATH
# Labels of the species list choices. They are the options of the radio element
# created in species_lists and are compared against the selected choice in
# runAnalysis and show_species_choice.
_CUSTOM_SPECIES = "Custom species list"
_PREDICT_SPECIES = "Species by location"
_CUSTOM_CLASSIFIER = "Custom classifier"
_ALL_SPECIES = "all species"
def show_species_choice(choice: str):
    """Sets the visibility of the species list choices.

    Exactly one of the four returned updates is visible: the one that belongs
    to the given choice, or the last one if the choice matches none of the
    three special choices.

    Args:
        choice: The label of the currently active choice.

    Returns:
        A list of [
            Row update,
            File update,
            Column update,
            Column update,
        ]
    """
    show_position = choice == _PREDICT_SPECIES
    show_species_file = choice == _CUSTOM_SPECIES
    show_classifier = choice == _CUSTOM_CLASSIFIER
    show_fallback = not (show_position or show_species_file or show_classifier)

    return [
        gr.Row.update(visible=show_position),
        gr.File.update(visible=show_species_file),
        gr.Column.update(visible=show_classifier),
        gr.Column.update(visible=show_fallback),
    ]
def select_subdirectories():
    """Creates a directory selection dialog.

    Returns:
        A tuple of (directory, list of single-element lists with the subdirectories)
        or (None, None) if the dialog was canceled.
    """
    # NOTE(review): `_WINDOW` and `webview` only appear in commented-out lines of
    # the visible source; confirm they are defined before this is called.
    selection = _WINDOW.create_file_dialog(webview.FOLDER_DIALOG)

    if not selection:
        return None, None

    root = selection[0]

    return root, [[sub] for sub in utils.list_subdirectories(root)]
def select_file(filetypes=()):
    """Creates a file selection dialog.

    Args:
        filetypes: List of filetypes to be filtered in the dialog.

    Returns:
        The selected file or None if the dialog was canceled.
    """
    # NOTE(review): `_WINDOW` and `webview` only appear in commented-out lines of
    # the visible source; confirm they are defined before this is called.
    chosen = _WINDOW.create_file_dialog(webview.OPEN_DIALOG, file_types=filetypes)

    if not chosen:
        return None

    return chosen[0]
def format_seconds(secs: float):
    """Formats a number of seconds into a string.

    Formats the seconds into the format "h:mm:ss.ms". The hours are padded
    with a space to a width of two characters.

    Args:
        secs: Number of seconds.

    Returns:
        A string with the formatted seconds.
    """
    # Round to milliseconds before splitting. Otherwise the rounding happens only
    # while formatting and a value such as 59.9996 is rendered as "0:00:60.000"
    # instead of carrying over into the minutes.
    secs = round(secs, 3)
    hours, secs = divmod(secs, 3600)
    minutes, secs = divmod(secs, 60)

    return f"{hours:2.0f}:{minutes:02.0f}:{secs:06.3f}"
def select_directory(collect_files=True):
    """Shows a directory selection system dialog.

    Uses the pywebview to create a system dialog.

    Args:
        collect_files: If True, also lists the audio files inside the directory.

    Returns:
        If collect_files==True, returns (directory path, list of [relative file path, formatted audio length])
        else just the directory path.
        All values will be None if the dialog is cancelled.
    """
    # NOTE(review): `_WINDOW` and `webview` only appear in commented-out lines of
    # the visible source; confirm they are defined before this is called.
    selection = _WINDOW.create_file_dialog(webview.FOLDER_DIALOG)

    if not collect_files:
        return selection[0] if selection else None

    if not selection:
        return None, None

    root = selection[0]
    listing = []

    for audio_file in utils.collect_audio_files(root):
        relative_path = os.path.relpath(audio_file, root)
        length = format_seconds(librosa.get_duration(filename=audio_file))
        listing.append([relative_path, length])

    return root, listing
def sample_sliders(opened=True):
    """Creates the gradio accordion for the inference settings.

    The accordion contains one row with the three sliders.

    Args:
        opened: If True the accordion is open on init.

    Returns:
        A tuple with the created elements:
        (Slider (min confidence), Slider (sensitivity), Slider (overlap))
    """
    with gr.Accordion("Inference settings", open=opened), gr.Row():
        min_confidence = gr.Slider(
            label="Minimum Confidence",
            info="Minimum confidence threshold.",
            minimum=0,
            maximum=1,
            value=0.5,
            step=0.01,
        )
        sensitivity = gr.Slider(
            label="Sensitivity",
            info="Detection sensitivity; Higher values result in higher sensitivity.",
            minimum=0.5,
            maximum=1.5,
            value=1,
            step=0.01,
        )
        segment_overlap = gr.Slider(
            label="Overlap",
            info="Overlap of prediction segments.",
            minimum=0,
            maximum=2.99,
            value=0,
            step=0.01,
        )

    return min_confidence, sensitivity, segment_overlap
def locale():
    """Creates the gradio elements for locale selection

    Lists the files inside the translated labels directory and derives one
    option per file from the part of the file name between the last "_" and
    the following ".", in upper case. "EN" is always the first option.

    Returns:
        The dropdown element.
    """
    labels_dir = os.path.join(os.path.dirname(sys.argv[0]), ORIGINAL_TRANSLATED_LABELS_PATH)
    options = ["EN"]

    for file_name in os.listdir(labels_dir):
        suffix = file_name.rsplit("_", 1)[-1]
        options.append(suffix.split(".")[0].upper())

    return gr.Dropdown(options, value="EN", label="Locale", info="Locale for the translated species common names.")
def species_lists(opened=True):
    """Creates the gradio accordion for species selection.

    Builds the radio for the species list choice together with the elements
    that belong to each choice, and wires the events that toggle their
    visibility.

    Args:
        opened: If True the accordion is open on init.

    Returns:
        A tuple with the created elements:
        (Radio (choice), File (custom species list), Slider (lat), Slider (lon), Slider (week), Slider (threshold), Checkbox (yearlong?), State (custom classifier))
    """
    with gr.Accordion("Species selection", open=opened):
        with gr.Row():
            # Choice between the four species list modes; defaults to all species.
            species_list_radio = gr.Radio(
                [_CUSTOM_SPECIES, _PREDICT_SPECIES, _CUSTOM_CLASSIFIER, _ALL_SPECIES],
                value=_ALL_SPECIES,
                label="Species list",
                info="List of all possible species",
                elem_classes="d-block",
            )

            # Location settings, initially hidden.
            with gr.Column(visible=False) as position_row:
                lat_number = gr.Slider(
                    minimum=-90, maximum=90, value=0, step=1, label="Latitude", info="Recording location latitude."
                )
                lon_number = gr.Slider(
                    minimum=-180, maximum=180, value=0, step=1, label="Longitude", info="Recording location longitude."
                )

                with gr.Row():
                    yearlong_checkbox = gr.Checkbox(True, label="Year-round")
                    # Starts non-interactive because the year-round checkbox is checked.
                    week_number = gr.Slider(
                        minimum=1,
                        maximum=48,
                        value=1,
                        step=1,
                        interactive=False,
                        label="Week",
                        info="Week of the year when the recording was made. Values in [1, 48] (4 weeks per month).",
                    )

                    def onChange(use_yearlong):
                        # The week slider is only editable while year-round is unchecked.
                        return gr.Slider.update(interactive=(not use_yearlong))

                    yearlong_checkbox.change(onChange, inputs=yearlong_checkbox, outputs=week_number, show_progress=False)

                sf_thresh_number = gr.Slider(
                    minimum=0.01,
                    maximum=0.99,
                    value=0.03,
                    step=0.01,
                    label="Location filter threshold",
                    info="Minimum species occurrence frequency threshold for location filter.",
                )

            # Custom species list upload, initially hidden.
            species_file_input = gr.File(file_types=[".txt"], info="Path to species list file or folder.", visible=False)
            # Empty column; show_species_choice makes it visible when none of the
            # other three choices is active.
            empty_col = gr.Column()

            # Custom classifier selection, initially hidden.
            with gr.Column(visible=False) as custom_classifier_selector:
                classifier_selection_button = gr.Button("Select classifier")
                classifier_file_input = gr.Files(
                    file_types=[".tflite"], info="Path to the custom classifier.", visible=False, interactive=False
                )
                # Holds the path of the selected classifier.
                selected_classifier_state = gr.State()

                def on_custom_classifier_selection_click():
                    file = select_file(("TFLite classifier (*.tflite)",))

                    if file:
                        # The labels file is expected next to the classifier.
                        labels = os.path.splitext(file)[0] + "_Labels.txt"

                        return file, gr.File.update(value=[file, labels], visible=True)

                    # NOTE(review): a single None is returned here although the click
                    # event below declares two outputs; confirm gradio accepts that.
                    return None

                classifier_selection_button.click(
                    on_custom_classifier_selection_click,
                    outputs=[selected_classifier_state, classifier_file_input],
                    show_progress=False,
                )

            # Show only the elements that belong to the selected choice.
            species_list_radio.change(
                show_species_choice,
                inputs=[species_list_radio],
                outputs=[position_row, species_file_input, custom_classifier_selector, empty_col],
                show_progress=False,
            )

            return (
                species_list_radio,
                species_file_input,
                lat_number,
                lon_number,
                week_number,
                sf_thresh_number,
                yearlong_checkbox,
                selected_classifier_state,
            )
# Script entry point: builds the gradio interface with the single file tab
# and launches it.
if __name__ == "__main__":
    freeze_support()

    def build_single_analysis_tab():
        """Creates the "Single file" tab and wires the analyze button."""
        with gr.Tab("Single file"):
            audio_input = gr.Audio(type="filepath", label="file", elem_id="single_file_audio")

            confidence_slider, sensitivity_slider, overlap_slider = sample_sliders(False)
            (
                species_list_radio,
                species_file_input,
                lat_number,
                lon_number,
                week_number,
                sf_thresh_number,
                yearlong_checkbox,
                selected_classifier_state,
            ) = species_lists(False)
            locale_radio = locale()

            # The order has to match the parameter order of runSingleFileAnalysis.
            inputs = [
                audio_input,
                confidence_slider,
                sensitivity_slider,
                overlap_slider,
                species_list_radio,
                species_file_input,
                lat_number,
                lon_number,
                week_number,
                yearlong_checkbox,
                sf_thresh_number,
                selected_classifier_state,
                locale_radio,
            ]

            output_dataframe = gr.Dataframe(
                type="pandas",
                headers=["Start (s)", "End (s)", "Scientific name", "Common name", "Confidence"],
                elem_classes="mh-200",
            )

            single_file_analyze = gr.Button("Analyze")

            # The value returned by runSingleFileAnalysis is passed to the dataframe.
            single_file_analyze.click(runSingleFileAnalysis, inputs=inputs, outputs=output_dataframe)

    # The css defines the "d-block" and "mh-200" classes used by the elements,
    # hides the footer and limits the height of the single file audio element.
    with gr.Blocks(
        css=r".d-block .wrap {display: block !important;} .mh-200 {max-height: 300px; overflow-y: auto !important;} footer {display: none !important;} #single_file_audio, #single_file_audio * {max-height: 81.6px; min-height: 0;}",
        theme=gr.themes.Default(),
        analytics_enabled=False,
    ) as demo:
        build_single_analysis_tab()

    demo.launch(show_api=True)

    # Disabled pywebview desktop window start-up.
    #url = demo.queue(api_open=False).launch(prevent_thread_lock=True, quiet=True)[1]
    #_WINDOW = webview.create_window("BirdNET-Analyzer", url.rstrip("/") + "?__theme=light", min_size=(1024, 768))
    #webview.start(private_mode=False)