# HTS-Audio-Transformer demo: Gradio Space for ESC-50 audio classification.
import os

import gradio as gr
import librosa
import numpy as np
import torch
from scipy.io.wavfile import write

import config
from model.htsat import HTSAT_Swin_Transformer
from sed_model import SEDWrapper
# Directory holding the bundled example audio clips shown in the demo UI.
example_path = 'examples_audio'

# ESC-50 label set: index i is the class name for the model's i-th output.
class_mapping = [
    'dog', 'rooster', 'pig', 'cow', 'frog', 'cat', 'hen', 'insects',
    'sheep', 'crow', 'rain', 'sea_waves', 'crackling_fire', 'crickets',
    'chirping_birds', 'water_drops', 'wind', 'pouring_water', 'toilet_flush',
    'thunderstorm', 'crying_baby', 'sneezing', 'clapping', 'breathing',
    'coughing', 'footsteps', 'laughing', 'brushing_teeth', 'snoring',
    'drinking_sipping', 'door_wood_knock', 'mouse_click', 'keyboard_typing',
    'door_wood_creaks', 'can_opening', 'washing_machine', 'vacuum_cleaner',
    'clock_alarm', 'clock_tick', 'glass_breaking', 'helicopter', 'chainsaw',
    'siren', 'car_horn', 'engine', 'train', 'church_bells', 'airplane',
    'fireworks', 'hand_saw',
]
# Build the HTS-AT (Swin-Transformer) backbone; all hyper-parameters come
# from the project-level ``config`` module so the demo matches training.
sed_model = HTSAT_Swin_Transformer(
    spec_size=config.htsat_spec_size,
    patch_size=config.htsat_patch_size,
    in_chans=1,
    num_classes=config.classes_num,
    window_size=config.htsat_window_size,
    config=config,
    depths=config.htsat_depth,
    embed_dim=config.htsat_dim,
    patch_stride=config.htsat_stride,
    num_heads=config.htsat_num_head,
)

# Lightning-style wrapper providing ``model.inference``; no dataset is
# needed for pure inference.
model = SEDWrapper(
    sed_model=sed_model,
    config=config,
    dataset=None,
)

# NOTE(review): ``torch.load`` unpickles arbitrary objects — only load
# checkpoints from a trusted source (this one ships with the Space).
ckpt = torch.load(config.resume_checkpoint, map_location="cpu")
# strict=False tolerates keys present in the checkpoint but absent from the
# inference-only model (e.g. optimizer/loss buffers).
model.load_state_dict(ckpt["state_dict"], strict=False)
def inference(audio):
    """Classify one audio clip with the HTS-AT model.

    Args:
        audio: ``(sample_rate, samples)`` tuple as produced by
            ``gr.Audio(type="numpy")``; ``samples`` may be integer PCM
            (typically int16) or float, mono or multi-channel.

    Returns:
        Tuple ``(summary, scores)`` where ``summary`` is a string showing
        the top class and its score, and ``scores`` maps every ESC-50
        class name to its clip-wise score.
    """
    sr, y = audio
    # Rescale integer PCM to [-1, 1] using the actual dtype's range rather
    # than hard-coding int16's 32767; float input is assumed already scaled.
    if np.issubdtype(y.dtype, np.integer):
        y = y / float(np.iinfo(y.dtype).max)
    # librosa.resample expects floating-point audio.
    y = y.astype(np.float32)
    if y.ndim != 1:
        # Down-mix multi-channel audio by keeping the first channel.
        y = y[:, 0]
    # The checkpoint was trained on 32 kHz audio — TODO confirm against config.
    y = librosa.resample(y, orig_sr=sr, target_sr=32000)
    batch = np.array([y])  # add a batch dimension of 1

    result = model.inference(batch)
    clipwise = result['clipwise_output'][0]
    scores = {class_mapping[i]: float(p) for i, p in enumerate(clipwise)}

    top_idx = int(np.argmax(clipwise))
    top_name = class_mapping[top_idx]
    return str({top_name: clipwise[top_idx]}), scores
# UI copy for the Gradio interface (typo fixed: "classificatio").
title = "HTS-Audio-Transformer"
description = "Audio classification with ESC-50."
# NOTE(review): this value is unused — the Interface below builds its own
# example list from ``example_path``; kept for backward compatibility.
examples = [['test.mp3']]
# Build and launch the demo.  The legacy ``gr.inputs``/``gr.outputs``
# namespaces and the ``enable_queue`` launch kwarg were removed in
# Gradio 3.x/4.x (a common cause of Space build errors), so the modern
# top-level components and ``.queue()`` are used instead.
gr.Interface(
    fn=inference,
    inputs=gr.Audio(type="numpy", label="Input"),
    outputs=[gr.Textbox(), gr.JSON()],
    title=title,
    description=description,
    # One example row per bundled clip; sorted for a deterministic order.
    examples=[[os.path.join(example_path, f)]
              for f in sorted(os.listdir(example_path))],
).queue().launch()