gaunernst commited on
Commit
5b04966
1 Parent(s): bcc0935

fix stereo audio

Browse files
Files changed (2) hide show
  1. app.py +9 -13
  2. packages.txt +0 -1
app.py CHANGED
@@ -1,14 +1,12 @@
1
  import json
2
- import shlex
3
- import subprocess
4
 
5
  import gradio as gr
6
- import numpy as np
7
  import requests
8
  import timm
9
  import torch
10
  import torch.nn.functional as F
11
  from torchaudio.compliance import kaldi
 
12
 
13
  TAG = "gaunernst/vit_base_patch16_1024_128.audiomae_as2m_ft_as20k"
14
  MODEL = timm.create_model(f"hf_hub:{TAG}", pretrained=True).eval()
@@ -21,12 +19,6 @@ MEAN = -4.2677393
21
  STD = 4.5689974
22
 
23
 
24
- def resample(x: np.ndarray, sr: int):
25
- cmd = f"ffmpeg -ar {sr} -f s16le -i - -ar {SAMPLING_RATE} -f f32le -"
26
- proc = subprocess.run(shlex.split(cmd), capture_output=True, input=x.tobytes())
27
- return np.frombuffer(proc.stdout, dtype=np.float32)
28
-
29
-
30
  def preprocess(x: torch.Tensor):
31
  x = x - x.mean()
32
  melspec = kaldi.fbank(x.unsqueeze(0), htk_compat=True, window_type="hanning", num_mel_bins=128)
@@ -35,7 +27,7 @@ def preprocess(x: torch.Tensor):
35
  else:
36
  melspec = melspec[:1024]
37
  melspec = (melspec - MEAN) / (STD * 2)
38
- return melspec.view(1, 1, 1024, 128)
39
 
40
 
41
  def predict(audio, start):
@@ -43,11 +35,15 @@ def predict(audio, start):
43
  if x.shape[0] < start * sr:
44
  raise gr.Error(f"`start` ({start}) must be smaller than audio duration ({x.shape[0] / sr:.0f}s)")
45
 
46
- x = resample(x[int(start * sr) :], sr)
47
- x = torch.from_numpy(x)
 
 
 
 
48
 
49
  with torch.inference_mode():
50
- logits = MODEL(preprocess(x)).squeeze(0)
51
 
52
  topk_probs, topk_classes = logits.sigmoid().topk(10)
53
  return [[AUDIOSET_LABELS[cls], prob.item() * 100] for cls, prob in zip(topk_classes, topk_probs)]
 
1
  import json
 
 
2
 
3
  import gradio as gr
 
4
  import requests
5
  import timm
6
  import torch
7
  import torch.nn.functional as F
8
  from torchaudio.compliance import kaldi
9
+ from torchaudio.functional import resample
10
 
11
  TAG = "gaunernst/vit_base_patch16_1024_128.audiomae_as2m_ft_as20k"
12
  MODEL = timm.create_model(f"hf_hub:{TAG}", pretrained=True).eval()
 
19
  STD = 4.5689974
20
 
21
 
 
 
 
 
 
 
22
  def preprocess(x: torch.Tensor):
23
  x = x - x.mean()
24
  melspec = kaldi.fbank(x.unsqueeze(0), htk_compat=True, window_type="hanning", num_mel_bins=128)
 
27
  else:
28
  melspec = melspec[:1024]
29
  melspec = (melspec - MEAN) / (STD * 2)
30
+ return melspec.view(1, 1024, 128)
31
 
32
 
33
  def predict(audio, start):
 
35
  if x.shape[0] < start * sr:
36
  raise gr.Error(f"`start` ({start}) must be smaller than audio duration ({x.shape[0] / sr:.0f}s)")
37
 
38
+ x = torch.from_numpy(x) / (1 << 15)
39
+ if x.ndim > 1:
40
+ x = x.mean(-1)
41
+ assert x.ndim == 1
42
+ x = resample(x[int(start * sr) :], sr, SAMPLING_RATE)
43
+ x = preprocess(x)
44
 
45
  with torch.inference_mode():
46
+ logits = MODEL(x.unsqueeze(0)).squeeze(0)
47
 
48
  topk_probs, topk_classes = logits.sigmoid().topk(10)
49
  return [[AUDIOSET_LABELS[cls], prob.item() * 100] for cls, prob in zip(topk_classes, topk_probs)]
packages.txt DELETED
@@ -1 +0,0 @@
1
- ffmpeg