# Scraped Hugging Face page header: author gaunernst; commit "initial commit" (cac3ec7); file size 1.54 kB.
import json
import shlex
import subprocess
import gradio as gr
import numpy as np
import requests
import timm
import torch
import torch.nn.functional as F
from torchaudio.compliance import kaldi
TAG = "gaunernst/vit_base_patch16_1024_128.audiomae_as2m_ft_as20k"
# Fetch the checkpoint from the Hugging Face Hub via timm and switch it to eval mode.
MODEL = timm.create_model(f"hf_hub:{TAG}", pretrained=True).eval()

LABEL_URL = "https://huggingface.co/datasets/huggingface/label-files/raw/main/audioset-id2label.json"


def _fetch_audioset_labels(url: str, timeout: float = 30.0):
    """Download an ``{"<id>": "<label>"}`` JSON mapping and return the labels ordered by id.

    Args:
        url: Location of the id -> label JSON document.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        List of label strings where position ``i`` holds the label for id ``i``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of trying to parse an error page as JSON.
    response.raise_for_status()
    id2label = response.json()
    # Order by the numeric id rather than trusting the key order in the file.
    # Assumes the JSON ids coincide with the model's output indices (predict
    # indexes this list with the class index) -- TODO confirm if the URL changes.
    return [label for _, label in sorted(id2label.items(), key=lambda item: int(item[0]))]


AUDIOSET_LABELS = _fetch_audioset_labels(LABEL_URL)
# Target sample rate in Hz for all audio fed to the model.
SAMPLING_RATE = 16_000
def resample(x: np.ndarray, sr: int) -> np.ndarray:
    """Convert raw PCM samples to mono float32 at ``SAMPLING_RATE`` Hz using ffmpeg.

    Requires the ``ffmpeg`` executable to be on ``PATH``.

    Args:
        x: PCM samples, either 1-D (mono) or 2-D ``(n_samples, n_channels)``
            with one interleaved frame per row. Supported dtypes are int16,
            int32, float32 and float64.
        sr: Sample rate of ``x`` in Hz.

    Returns:
        Writable 1-D float32 array of mono samples at ``SAMPLING_RATE`` Hz.

    Raises:
        ValueError: If ``x`` has an unsupported dtype or is not 1-D/2-D.
        RuntimeError: If ffmpeg exits with a non-zero status.
    """
    # Raw little-endian ffmpeg sample format for each supported input dtype;
    # declaring the wrong format would make ffmpeg misread the bytes.
    ffmpeg_formats = {"int16": "s16le", "int32": "s32le", "float32": "f32le", "float64": "f64le"}
    in_format = ffmpeg_formats.get(x.dtype.name)
    if in_format is None:
        raise ValueError(f"unsupported audio dtype {x.dtype}; expected one of {sorted(ffmpeg_formats)}")
    if x.ndim not in (1, 2):
        raise ValueError(f"expected 1-D or 2-D audio array, got shape {x.shape}")

    channels = 1 if x.ndim == 1 else x.shape[1]
    # C-order bytes of an (n_samples, n_channels) array are already interleaved,
    # which is the layout ffmpeg's raw PCM input expects; force little-endian.
    pcm = x.astype(x.dtype.newbyteorder("<"), copy=False).tobytes()

    cmd = [
        "ffmpeg", "-hide_banner", "-loglevel", "error",
        # Input: raw PCM on stdin with the caller's rate and channel count.
        "-f", in_format, "-ar", str(sr), "-ac", str(channels), "-i", "-",
        # Output: raw mono float32 at the model's sample rate on stdout.
        "-f", "f32le", "-ar", str(SAMPLING_RATE), "-ac", "1", "-",
    ]
    proc = subprocess.run(cmd, capture_output=True, input=pcm)
    if proc.returncode != 0:
        stderr = proc.stderr.decode(errors="replace").strip()
        raise RuntimeError(f"ffmpeg failed with exit code {proc.returncode}: {stderr}")

    # frombuffer yields a read-only view of the bytes; astype copies it into a
    # writable native-endian array so torch.from_numpy does not warn.
    return np.frombuffer(proc.stdout, dtype="<f4").astype(np.float32)
def preprocess(
    x: torch.Tensor,
    *,
    target_frames: int = 1024,
    num_mel_bins: int = 128,
    norm_mean: float = -4.2677393,
    norm_std: float = 4.5689974,
) -> torch.Tensor:
    """Turn a mono waveform into the normalised log-mel input tensor for the model.

    Args:
        x: 1-D waveform. ``kaldi.fbank`` is called with its default sample
            frequency (16 kHz in torchaudio), so ``x`` is assumed to already be
            at ``SAMPLING_RATE`` -- TODO confirm if either value changes.
        target_frames: Number of fbank frames; shorter clips are zero-padded
            at the end, longer ones are truncated. The checkpoint name
            (``..._1024_128``) suggests the model needs exactly 1024.
        num_mel_bins: Mel filterbank bins per frame (128 for this checkpoint).
        norm_mean: Mean used to standardise the fbank.
        norm_std: Standard deviation used to standardise the fbank as
            ``(fbank - norm_mean) / (2 * norm_std)``. The defaults are the
            AudioSet statistics from the AudioMAE reference data pipeline;
            verify against the model card if the checkpoint changes.

    Returns:
        Tensor of shape ``(1, 1, target_frames, num_mel_bins)`` (batch and
        channel dims added).
    """
    # Remove the clip-level DC offset before feature extraction, as the
    # AudioMAE reference pipeline does.
    x = x - x.mean()
    melspec = kaldi.fbank(x.unsqueeze(0), htk_compat=True, window_type="hanning", num_mel_bins=num_mel_bins)

    n_frames = melspec.shape[0]
    if n_frames < target_frames:
        # Pad only the frame (time) axis: no padding on the mel axis,
        # (0, missing) on the frame axis.
        melspec = F.pad(melspec, (0, 0, 0, target_frames - n_frames))
    else:
        melspec = melspec[:target_frames]

    # Standardise after padding; the reference pipeline normalises the padded
    # fbank, so padded frames receive the same affine transform.
    melspec = (melspec - norm_mean) / (norm_std * 2)
    return melspec.view(1, 1, target_frames, num_mel_bins)
def predict(audio):
    """Classify an audio clip and return its five highest-scoring AudioSet labels.

    Args:
        audio: ``(sample_rate, samples)`` tuple from the Gradio audio input,
            or ``None`` when the user submitted nothing.

    Returns:
        Rows of ``[label, score_percent]`` in descending score order.

    Raises:
        gr.Error: If no audio was provided (shown to the user by Gradio).
    """
    if audio is None:
        raise gr.Error("Please upload or record an audio clip first.")
    sr, samples = audio
    waveform = torch.from_numpy(resample(samples, sr))

    with torch.inference_mode():
        logits = MODEL(preprocess(waveform)).squeeze(0)

    # AudioSet is multi-label, so each class gets an independent sigmoid
    # probability. Softmax would force the 527 scores to sum to 1.
    # The top-5 ranking is identical either way.
    # NOTE(review): assumes the head was fine-tuned with a BCE objective, as in
    # the AudioMAE AudioSet recipe -- confirm against the model card.
    top_probs, top_classes = logits.sigmoid().topk(5)
    return [
        [AUDIOSET_LABELS[class_idx.item()], prob.item() * 100]
        for class_idx, prob in zip(top_classes, top_probs)
    ]
# Audio in, a table of (label, score) rows out.
iface = gr.Interface(
    fn=predict,
    inputs="audio",
    outputs=gr.Dataframe(headers=["label", "probability (%)"]),
)

# Only start the server when run as a script (e.g. `python app.py` on Spaces),
# so importing this module does not block on launch().
if __name__ == "__main__":
    iface.launch()