# Hugging Face Space (status: Sleeping)
import json | |
import shlex | |
import subprocess | |
import gradio as gr | |
import numpy as np | |
import requests | |
import timm | |
import torch | |
import torch.nn.functional as F | |
from torchaudio.compliance import kaldi | |
# Hugging Face Hub id of the AudioMAE ViT checkpoint served by this app.
TAG = "gaunernst/vit_base_patch16_1024_128.audiomae_as2m_ft_as20k"
# Created once at import time and switched to eval mode for inference.
MODEL = timm.create_model(f"hf_hub:{TAG}", pretrained=True).eval()

# JSON object mapping class index (as a string key) to AudioSet label name.
LABEL_URL = "https://huggingface.co/datasets/huggingface/label-files/raw/main/audioset-id2label.json"


def _load_labels(url: str, timeout: float = 30.0) -> list:
    """Download an ``{"<id>": "<label>"}`` mapping and return labels ordered by id.

    Args:
        url: URL of the id-to-label JSON document.
        timeout: seconds to wait for the HTTP response before giving up.

    Returns:
        List whose element ``i`` is the label for class id ``i``.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server does not answer within ``timeout``.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly at startup instead of trying to parse an error page as JSON.
    response.raise_for_status()
    id2label = response.json()
    # Order by numeric id rather than trusting the JSON key order, so that
    # the list index always matches the model's output index.
    return [id2label[key] for key in sorted(id2label, key=int)]


AUDIOSET_LABELS = _load_labels(LABEL_URL)
# Target sample rate (Hz) that `resample` converts incoming audio to.
SAMPLING_RATE = 16_000
# ffmpeg raw-PCM input format for each supported sample dtype. Keys are
# native-byte-order dtypes, so this assumes a little-endian host.
_PCM_FORMATS = {
    np.dtype(np.uint8): "u8",
    np.dtype(np.int16): "s16le",
    np.dtype(np.int32): "s32le",
    np.dtype(np.float32): "f32le",
    np.dtype(np.float64): "f64le",
}


def resample(x: np.ndarray, sr: int):
    """Convert raw PCM samples to a mono float32 waveform at SAMPLING_RATE via ffmpeg.

    Args:
        x: audio samples, shaped ``(n_samples,)`` for mono or
            ``(n_samples, n_channels)`` for multi-channel audio (the layout
            Gradio's numpy audio input uses). The dtype must be one of the keys
            of ``_PCM_FORMATS``.
        sr: sample rate of ``x`` in Hz.

    Returns:
        1-D float32 array of mono samples at SAMPLING_RATE.

    Raises:
        ValueError: if ``x`` has an unsupported dtype.
        RuntimeError: if ffmpeg exits with a non-zero status.
    """
    try:
        pcm_format = _PCM_FORMATS[x.dtype]
    except KeyError:
        raise ValueError(f"unsupported audio sample dtype: {x.dtype}") from None
    # Row-major (n_samples, n_channels) data is already interleaved; ffmpeg
    # must be told the channel count or it decodes it as one long mono stream.
    channels = 1 if x.ndim == 1 else x.shape[1]
    cmd = [
        "ffmpeg", "-hide_banner", "-loglevel", "error",
        # Input: raw interleaved PCM read from stdin.
        "-f", pcm_format, "-ar", str(sr), "-ac", str(channels), "-i", "-",
        # Output: mono float32 at the model's sample rate written to stdout.
        "-f", "f32le", "-ar", str(SAMPLING_RATE), "-ac", "1", "-",
    ]
    proc = subprocess.run(cmd, input=x.tobytes(), capture_output=True)
    if proc.returncode != 0:
        message = proc.stderr.decode(errors="replace").strip()
        raise RuntimeError(f"ffmpeg failed: {message}")
    return np.frombuffer(proc.stdout, dtype=np.float32)
# AudioSet fbank normalisation statistics from the AudioMAE reference
# preprocessing; inputs are scaled as (fbank - mean) / (2 * std).
AUDIOSET_MEAN = -4.2677393
AUDIOSET_STD = 4.5689974
# The model accepts a fixed-size spectrogram of TARGET_FRAMES x NUM_MEL_BINS.
TARGET_FRAMES = 1024
NUM_MEL_BINS = 128


def preprocess(x: torch.Tensor, mean: float = AUDIOSET_MEAN, std: float = AUDIOSET_STD):
    """Turn a mono waveform into the normalised log-mel input AudioMAE expects.

    The Kaldi-compatible fbank is cropped or zero-padded to exactly
    TARGET_FRAMES frames (about 10 s with kaldi.fbank's default 10 ms frame
    shift), then normalised with ``(fbank - mean) / (2 * std)``.

    Args:
        x: 1-D waveform tensor, assumed to be sampled at 16 kHz
            (kaldi.fbank's default ``sample_frequency``).
        mean: normalisation mean; defaults to the AudioSet statistic.
        std: normalisation standard deviation; defaults to the AudioSet statistic.

    Returns:
        Tensor of shape ``(1, 1, TARGET_FRAMES, NUM_MEL_BINS)``, i.e. a batch
        of one single-channel spectrogram.
    """
    melspec = kaldi.fbank(x.unsqueeze(0), htk_compat=True, window_type="hanning", num_mel_bins=NUM_MEL_BINS)
    n_frames = melspec.shape[0]
    if n_frames < TARGET_FRAMES:
        # Pad along the time axis only (last dim is mel bins, left untouched).
        melspec = F.pad(melspec, (0, 0, 0, TARGET_FRAMES - n_frames))
    else:
        melspec = melspec[:TARGET_FRAMES]
    # Pad first, then normalise, matching the reference pipeline's order.
    melspec = (melspec - mean) / (std * 2)
    return melspec.view(1, 1, TARGET_FRAMES, NUM_MEL_BINS)
def predict(audio, k: int = 5):
    """Classify an audio clip and return its top-``k`` AudioSet labels.

    Args:
        audio: ``(sample_rate, samples)`` tuple from Gradio's ``"audio"``
            input, or ``None`` when the user submitted nothing.
        k: number of labels to return.

    Returns:
        List of ``[label, probability_percent]`` rows, highest score first.

    Raises:
        gr.Error: if no audio was provided or it decodes to zero samples.
    """
    if audio is None:
        raise gr.Error("Please record or upload an audio clip.")
    sr, x = audio
    x = resample(x, sr)
    if x.size == 0:
        raise gr.Error("The audio clip is empty.")
    # resample's output may be a read-only buffer view; copy it so torch gets
    # a writable tensor (torch.from_numpy warns on non-writable arrays).
    waveform = torch.from_numpy(x.copy())
    with torch.inference_mode():
        logits = MODEL(preprocess(waveform)).squeeze(0)
    # AudioSet tagging is multi-label: score each class independently with a
    # sigmoid rather than a softmax that forces classes to compete.
    probs = logits.sigmoid()
    top_probs, top_classes = probs.topk(min(k, probs.numel()))
    return [
        [AUDIOSET_LABELS[cls.item()], prob.item() * 100]
        for cls, prob in zip(top_classes, top_probs)
    ]
# Web UI: one audio input, a two-column table of the top predictions.
iface = gr.Interface(
    fn=predict,
    inputs="audio",
    outputs=gr.Dataframe(headers=["label", "probability (%)"]),
)

# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()