"""Gradio demo: classify an audio clip into ESC-50 classes with HTS-AT."""
import os
import warnings

import gradio as gr
import librosa
import numpy as np
import torch
from scipy.io.wavfile import write

import config
from model.htsat import HTSAT_Swin_Transformer
from sed_model import SEDWrapper

example_path = './examples_audio'

# Sample rate that audio is resampled to before being fed to the model.
TARGET_SR = 32000

# ESC-50 label names in class-index order: output score i belongs to class_mapping[i].
class_mapping = [
    'dog', 'rooster', 'pig', 'cow', 'frog',
    'cat', 'hen', 'insects', 'sheep', 'crow',
    'rain', 'sea_waves', 'crackling_fire', 'crickets', 'chirping_birds',
    'water_drops', 'wind', 'pouring_water', 'toilet_flush', 'thunderstorm',
    'crying_baby', 'sneezing', 'clapping', 'breathing', 'coughing',
    'footsteps', 'laughing', 'brushing_teeth', 'snoring', 'drinking_sipping',
    'door_wood_knock', 'mouse_click', 'keyboard_typing', 'door_wood_creaks', 'can_opening',
    'washing_machine', 'vacuum_cleaner', 'clock_alarm', 'clock_tick', 'glass_breaking',
    'helicopter', 'chainsaw', 'siren', 'car_horn', 'engine',
    'train', 'church_bells', 'airplane', 'fireworks', 'hand_saw',
]

sed_model = HTSAT_Swin_Transformer(
    spec_size=config.htsat_spec_size,
    patch_size=config.htsat_patch_size,
    in_chans=1,
    num_classes=config.classes_num,
    window_size=config.htsat_window_size,
    config=config,
    depths=config.htsat_depth,
    embed_dim=config.htsat_dim,
    patch_stride=config.htsat_stride,
    num_heads=config.htsat_num_head,
)
model = SEDWrapper(sed_model=sed_model, config=config, dataset=None)

# NOTE(review): torch.load unpickles the file; only point resume_checkpoint
# at checkpoints from a trusted source.
ckpt = torch.load(config.resume_checkpoint, map_location="cpu")
# strict=False tolerates key mismatches, but report them so a wrong or only
# partially matching checkpoint does not silently leave weights uninitialised.
_load_result = model.load_state_dict(ckpt["state_dict"], strict=False)
_missing = getattr(_load_result, "missing_keys", [])
_unexpected = getattr(_load_result, "unexpected_keys", [])
if _missing or _unexpected:
    warnings.warn(
        f"Checkpoint {config.resume_checkpoint!r} did not match the model "
        f"exactly: {len(_missing)} missing and {len(_unexpected)} unexpected keys."
    )
# Inference only: disable training-mode layers such as dropout.
model.eval()


def inference(audio):
    """Classify one audio clip and return a score for every ESC-50 class.

    Args:
        audio: ``(sample_rate, samples)`` tuple as delivered by Gradio's
            ``Audio(type="numpy")`` input. ``samples`` is either 1-D (mono)
            or 2-D with channels on the second axis.

    Returns:
        dict mapping each name in ``class_mapping`` to its clipwise score
        (a Python float), in the format expected by ``gr.outputs.Label``.

    Raises:
        ValueError: if no audio was supplied, or if the model's output size
            does not match ``class_mapping``.
    """
    if audio is None:
        raise ValueError("No audio provided.")
    sr, y = audio
    y = np.asarray(y)

    # Gradio/scipy hand back integer PCM (typically int16), while librosa and
    # the model work on floats; scale by the dtype's full-scale value
    # (32767 for int16). Float input is assumed to be already normalised.
    if np.issubdtype(y.dtype, np.integer):
        y = y / float(np.iinfo(y.dtype).max)
    else:
        y = y.astype(np.float64)

    # Down-mix multi-channel audio to mono by averaging across channels.
    if y.ndim > 1:
        y = y.mean(axis=1)

    y = librosa.resample(y, orig_sr=sr, target_sr=TARGET_SR)

    with torch.no_grad():
        # Add a leading batch dimension of size 1.
        result = model.inference(y[np.newaxis, :])
    pred = result['clipwise_output'][0]

    if len(pred) != len(class_mapping):
        raise ValueError(
            f"Model produced {len(pred)} scores but {len(class_mapping)} "
            "class names are defined."
        )
    return {name: float(score) for name, score in zip(class_mapping, pred)}


title = "HTS-Audio-Transformer"
description = "Audio classification with ESC-50."

# Offer every visible file in example_path as a clickable example, sorted for
# a stable order; show no examples if the directory is missing or empty.
if os.path.isdir(example_path):
    examples = [
        [os.path.join(example_path, f)]
        for f in sorted(os.listdir(example_path))
        if not f.startswith('.') and os.path.isfile(os.path.join(example_path, f))
    ]
else:
    examples = []

# NOTE(review): gr.inputs / gr.outputs and launch(enable_queue=...) are the
# legacy Gradio API (deprecated in 3.x, removed in 4.x); migrate to
# gr.Audio / gr.Label and demo.queue().launch() when upgrading Gradio.
demo = gr.Interface(
    inference,
    gr.inputs.Audio(type="numpy", label="Input"),
    gr.outputs.Label(),
    title=title,
    description=description,
    examples=examples or None,
)
demo.launch(enable_queue=True)