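"""Gradio demo for HTS-Audio-Transformer (HTS-AT).

Loads an HTSAT_Swin_Transformer checkpoint through SEDWrapper and classifies an
uploaded audio clip into the 50 ESC-50 sound classes.
"""
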
import os
import gradio as gr
import config

import torch
from model.htsat import HTSAT_Swin_Transformer
from sed_model import SEDWrapper
import librosa
import numpy as np

example_path = './examples_audio'  # folder of example clips listed in the Gradio UI

# The 50 ESC-50 class labels, in the index order of the model's output.
class_mapping = ['dog', 'rooster', 'pig', 'cow', 'frog', 'cat', 'hen', 'insects', 'sheep', 'crow',
                 'rain', 'sea_waves', 'crackling_fire', 'crickets', 'chirping_birds', 'water_drops', 'wind', 'pouring_water', 'toilet_flush', 'thunderstorm',
                 'crying_baby', 'sneezing', 'clapping', 'breathing', 'coughing', 'footsteps', 'laughing', 'brushing_teeth', 'snoring', 'drinking_sipping',
                 'door_wood_knock', 'mouse_click', 'keyboard_typing', 'door_wood_creaks', 'can_opening', 'washing_machine', 'vacuum_cleaner', 'clock_alarm', 'clock_tick', 'glass_breaking',
                 'helicopter', 'chainsaw', 'siren', 'car_horn', 'engine', 'train', 'church_bells', 'airplane', 'fireworks', 'hand_saw']

# Build the HTS-AT backbone with the hyperparameters defined in the config module.
sed_model = HTSAT_Swin_Transformer(
    spec_size=config.htsat_spec_size,
    patch_size=config.htsat_patch_size,
    in_chans=1,
    num_classes=config.classes_num,
    window_size=config.htsat_window_size,
    config=config,
    depths=config.htsat_depth,
    embed_dim=config.htsat_dim,
    patch_stride=config.htsat_stride,
    num_heads=config.htsat_num_head
)

# SEDWrapper wraps the backbone and exposes the inference() method used below.
model = SEDWrapper(
    sed_model=sed_model,
    config=config,
    dataset=None
)

# Load the trained weights on CPU; strict=False skips checkpoint keys that do not
# match this inference-only setup.
ckpt = torch.load(config.resume_checkpoint, map_location="cpu")
model.load_state_dict(ckpt["state_dict"], strict=False)

def inference(audio):
    # Gradio passes audio as a (sample_rate, samples) tuple of int16 PCM data.
    sr, y = audio
    y = y / 32767.0  # scale int16 samples to [-1.0, 1.0] floats, as librosa expects
    if len(y.shape) != 1:  # stereo or multi-channel input: keep the first channel
        y = y[:, 0]
    y = librosa.resample(y, orig_sr=sr, target_sr=32000)  # resample to the 32 kHz rate the model expects
    in_val = np.array([y])  # add a batch dimension
    result = model.inference(in_val)
    pred = result['clipwise_output'][0]
    # pred = np.exp(pred)/np.sum(np.exp(pred)) # softmax
    return {class_mapping[i]: float(p) for i, p in enumerate(pred)}
    # win_classes = np.argmax(result['clipwise_output'], axis=1)
    # win_class_index = win_classes[0]
    # win_class_name = class_mapping[win_class_index]
    # return str({win_class_name: result['clipwise_output'][0][win_class_index]})
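# Quick local sanity check without the UI (hypothetical file path; assumes an
# int16 PCM WAV, i.e. the same format Gradio would hand to inference()):
#
#   from scipy.io.wavfile import read
#   sr, data = read("examples_audio/example.wav")
#   scores = inference((sr, data))
#   print(max(scores, key=scores.get))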


title = "HTS-Audio-Transformer"
description = "Audio classification with the ESC-50 sound classes."
# article = "<p style='text-align: center'><a href='https://arxiv.org/abs/1911.13254' target='_blank'>Music Source Separation in the Waveform Domain</a> | <a href='https://github.com/facebookresearch/demucs' target='_blank'>Github Repo</a></p>"

# examples = [['test.mp3']]  # superseded by the files gathered from example_path below
# Note: gr.inputs / gr.outputs and launch(enable_queue=...) follow the older Gradio 2.x API.
gr.Interface(
    inference,
    gr.inputs.Audio(type="numpy", label="Input"),
    # gr.outputs.Textbox(),
    gr.outputs.Label(),  # renders per-class confidence scores
    title=title,
    description=description,
    # article=article,
    examples=[[os.path.join(example_path, f)]
              for f in os.listdir(example_path)]
).launch(enable_queue=True)
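# Running this script starts the Gradio server; by default Gradio serves the demo
# at http://127.0.0.1:7860, and enable_queue=True queues concurrent requests.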