File size: 1,536 Bytes
cac3ec7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import json
import shlex
import subprocess

import gradio as gr
import numpy as np
import requests
import timm
import torch
import torch.nn.functional as F
from torchaudio.compliance import kaldi

TAG = "gaunernst/vit_base_patch16_1024_128.audiomae_as2m_ft_as20k"
# AudioMAE checkpoint loaded from the Hugging Face Hub through timm and put in
# eval mode for inference.
MODEL = timm.create_model(f"hf_hub:{TAG}", pretrained=True).eval()

LABEL_URL = "https://huggingface.co/datasets/huggingface/label-files/raw/main/audioset-id2label.json"


def _fetch_audioset_labels(url: str = LABEL_URL, timeout: float = 30.0):
    """Download a ``{"<id>": "<label>"}`` JSON mapping and return labels ordered by id.

    Args:
        url: Location of the id -> label JSON file.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        List of label strings where list index == numeric class id.

    Raises:
        requests.RequestException: on network errors, timeouts, or a non-2xx response.
    """
    # Without a timeout, a stalled server would hang app start-up forever.
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    id2label = response.json()
    # Sort on the numeric id so that index i is the label for model output i,
    # instead of trusting the key order of the served JSON.
    return [label for _, label in sorted(id2label.items(), key=lambda item: int(item[0]))]


AUDIOSET_LABELS = _fetch_audioset_labels()

SAMPLING_RATE = 16_000


def resample(x: np.ndarray, sr: int):
    """Resample raw PCM audio to mono float32 at ``SAMPLING_RATE`` using ffmpeg.

    Args:
        x: Audio samples, either 1-D ``(n_samples,)`` or 2-D
            ``(n_samples, n_channels)``. Supported dtypes are int8, int16, int32,
            float32 and float64.
        sr: Sample rate of ``x`` in Hz.

    Returns:
        Writable 1-D float32 array sampled at ``SAMPLING_RATE``. Multi-channel
        input is down-mixed to mono by ffmpeg.

    Raises:
        ValueError: if ``x`` has an unsupported rank or dtype.
        RuntimeError: if ffmpeg exits with a non-zero status.
    """
    if x.ndim == 1:
        channels = 1
    elif x.ndim == 2:
        channels = x.shape[1]
    else:
        raise ValueError(f"Expected 1-D or 2-D audio array, got shape {x.shape}.")

    # ffmpeg raw PCM format that matches the array's sample type.
    raw_formats = {
        ("i", 1): "s8",
        ("i", 2): "s16le",
        ("i", 4): "s32le",
        ("f", 4): "f32le",
        ("f", 8): "f64le",
    }
    fmt = raw_formats.get((x.dtype.kind, x.dtype.itemsize))
    if fmt is None:
        raise ValueError(f"Unsupported audio dtype {x.dtype}.")
    # Little-endian, C-contiguous bytes: a (n, ch) array becomes interleaved
    # frames, which is the layout ffmpeg's raw demuxer expects.
    payload = np.ascontiguousarray(x, dtype=x.dtype.newbyteorder("<")).tobytes()

    cmd = [
        "ffmpeg", "-hide_banner", "-loglevel", "error",
        # Input description: raw samples on stdin.
        "-f", fmt, "-ar", str(sr), "-ac", str(channels), "-i", "-",
        # Output: mono float32 at the model's sample rate on stdout.
        "-ac", "1", "-ar", str(SAMPLING_RATE), "-f", "f32le", "-",
    ]
    proc = subprocess.run(cmd, capture_output=True, input=payload)
    if proc.returncode != 0:
        stderr = proc.stderr.decode(errors="replace").strip()
        raise RuntimeError(f"ffmpeg resampling failed: {stderr}")
    # np.frombuffer gives a read-only view; astype copies into a native-endian,
    # writable array so torch.from_numpy does not warn about it.
    return np.frombuffer(proc.stdout, dtype="<f4").astype(np.float32)


# Log-mel normalisation statistics for this AudioMAE checkpoint (AudioSet
# mean / std, as given in the checkpoint's published usage example).
_FBANK_MEAN = -4.2677393
_FBANK_STD = 4.5689974


def preprocess(x: torch.Tensor):
    """Convert a mono waveform into the AudioMAE input spectrogram.

    Args:
        x: 1-D float waveform. It is assumed to be sampled at 16 kHz, which is
            ``kaldi.fbank``'s default ``sample_frequency`` (not passed here).

    Returns:
        Tensor of shape ``(1, 1, 1024, 128)``: 1024 frames of 128 log-mel bins,
        zero-padded or truncated to 1024 frames and then normalised.
    """
    # kaldi.fbank cannot frame a signal shorter than one window (25 ms, i.e.
    # 400 samples at its default 16 kHz), so pad very short clips with silence.
    min_samples = 400
    if x.shape[-1] < min_samples:
        x = F.pad(x, (0, min_samples - x.shape[-1]))

    melspec = kaldi.fbank(x.unsqueeze(0), htk_compat=True, window_type="hanning", num_mel_bins=128)
    n_frames = melspec.shape[0]
    if n_frames < 1024:
        # Pad along the time (frame) axis only; the mel axis stays at 128.
        melspec = F.pad(melspec, (0, 0, 0, 1024 - n_frames))
    else:
        melspec = melspec[:1024]

    # Normalise after padding/truncation, matching the checkpoint's usage
    # example; without this the input distribution differs from training.
    melspec = (melspec - _FBANK_MEAN) / (_FBANK_STD * 2)
    return melspec.view(1, 1, 1024, 128)


def predict(audio):
    """Return the top-5 AudioSet labels for a Gradio audio input.

    Args:
        audio: ``(sample_rate, samples)`` tuple from Gradio's audio component,
            or ``None`` when the user submitted without any audio.

    Returns:
        List of ``[label, confidence_percent]`` rows, highest score first.

    Raises:
        ValueError: if no audio was provided.
    """
    if audio is None:
        raise ValueError("No audio provided.")
    sr, x = audio
    x = torch.from_numpy(resample(x, sr))

    with torch.inference_mode():
        logits = MODEL(preprocess(x)).squeeze(0)

    # AudioSet is multi-label (a clip can contain several sound classes), so
    # each class is scored independently with a sigmoid rather than a softmax
    # that forces all class scores to sum to 1. The top-5 ranking is unchanged.
    topk_probs, topk_classes = logits.sigmoid().topk(5)
    return [
        [AUDIOSET_LABELS[cls], prob * 100]
        for cls, prob in zip(topk_classes.tolist(), topk_probs.tolist())
    ]


# Minimal web UI: one audio input, and a table of [label, confidence %] rows
# produced by predict(). The server is started as soon as this module runs.
iface = gr.Interface(
    fn=predict,
    inputs="audio",
    outputs="dataframe",
)
iface.launch()