amroa committed on
Commit
79fcc82
·
1 Parent(s): 4b48e6e

add audio MAE

Browse files
__pycache__/app.cpython-311.pyc CHANGED
Binary files a/__pycache__/app.cpython-311.pyc and b/__pycache__/app.cpython-311.pyc differ
 
__pycache__/classpred.cpython-311.pyc ADDED
Binary file (3.47 kB). View file
 
app.py CHANGED
@@ -9,7 +9,9 @@ from model import BirdAST
9
  import torch
10
  import librosa
11
  import noisereduce as nr
 
12
  import pandas as pd
 
13
  import torch.nn.functional as F
14
  import random
15
  from torchaudio.compliance import kaldi
@@ -56,7 +58,7 @@ def predict(audio, start, end):
56
  sr, x = audio
57
 
58
  x = np.array(x, dtype=np.float32)/32768.0
59
- x = x[start*sr : end*sr]
60
  res = preprocess_for_inference(x, sr)
61
 
62
  if start >= end:
@@ -72,7 +74,7 @@ def predict(audio, start, end):
72
  fig2 = plot_wave(sr, x)
73
 
74
 
75
- return res, res, fig1, fig2
76
 
77
  def download_model(url, model_path):
78
  if not os.path.exists(model_path):
 
9
  import torch
10
  import librosa
11
  import noisereduce as nr
12
+ import timm
13
  import pandas as pd
14
+ from classpred import predict_class
15
  import torch.nn.functional as F
16
  import random
17
  from torchaudio.compliance import kaldi
 
58
  sr, x = audio
59
 
60
  x = np.array(x, dtype=np.float32)/32768.0
61
+ x = x[int(start*sr) : int(end*sr)]
62
  res = preprocess_for_inference(x, sr)
63
 
64
  if start >= end:
 
74
  fig2 = plot_wave(sr, x)
75
 
76
 
77
+ return predict_class(x, sr, start, end), res, fig1, fig2
78
 
79
  def download_model(url, model_path):
80
  if not os.path.exists(model_path):
classpred.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import timm
2
+ import json
3
+ import torch
4
+ from torchaudio.functional import resample
5
+ import numpy as np
6
+ from torchaudio.compliance import kaldi
7
+ import torch.nn.functional as F
8
+ import requests
9
+
10
# timm checkpoint identifier on the Hugging Face Hub (loaded via the
# "hf_hub:" prefix below).
TAG = "gaunernst/vit_base_patch16_1024_128.audiomae_as2m_ft_as20k"
# The model is created once at import time and put in eval mode; it is shared
# by every call to predict_class.
MODEL = timm.create_model(f"hf_hub:{TAG}", pretrained=True).eval()

LABEL_URL = "https://huggingface.co/datasets/huggingface/label-files/raw/main/audioset-id2label.json"
# Fetch the id -> label mapping once at import time. A timeout keeps a stalled
# connection from hanging the import forever, and raise_for_status() surfaces
# HTTP errors directly instead of as a JSON decoding failure on an error page.
_label_response = requests.get(LABEL_URL, timeout=30)
_label_response.raise_for_status()
# Labels are looked up by class index, so this relies on the JSON object
# listing its entries in index order.
AUDIOSET_LABELS = list(json.loads(_label_response.content).values())

# Sampling rate (Hz) that audio is resampled to before feature extraction.
SAMPLING_RATE = 16_000
# Offset and scale applied to the filterbank features in preprocess();
# presumably dataset statistics the checkpoint was trained with -- confirm
# against the model card before changing.
MEAN = -4.2677393
STD = 4.5689974
19
+
20
def preprocess(x: torch.Tensor):
    """Turn a 1-D waveform into a normalised 1024 x 128 filterbank matrix.

    The waveform is mean-centred, converted to 128-bin Kaldi filterbank
    features, zero-padded or truncated along the frame axis to exactly 1024
    frames, and finally normalised with the module-level MEAN and STD.
    """
    target_frames = 1024

    centred = x - x.mean()
    features = kaldi.fbank(
        centred.unsqueeze(0),
        htk_compat=True,
        window_type="hanning",
        num_mel_bins=128,
    )

    frame_count = features.shape[0]
    if frame_count >= target_frames:
        # Long clip: keep only the leading frames.
        features = features[:target_frames]
    else:
        # Short clip: append zero frames at the end of the frame axis.
        features = F.pad(features, (0, 0, 0, target_frames - frame_count))

    return (features - MEAN) / (STD * 2)
29
+
30
def predict_class(x, sr, start, end):
    """Classify a segment of an audio recording with the AudioSet model.

    Args:
        x: NumPy array of audio samples, either 1-D (mono) or 2-D with the
            channels on the last axis. Integer PCM is scaled by 2**15
            (int16 full scale); floating-point input is taken to be already
            normalised and is NOT rescaled, so callers that have divided by
            32768 themselves are not scaled twice.
        sr: Sampling rate of ``x`` in Hz.
        start: Segment start in seconds.
        end: Segment end in seconds.

    Returns:
        A list of ``[label, probability_in_percent]`` pairs for the 10
        highest-scoring classes, in the order returned by ``topk``.

    Raises:
        ValueError: If ``x`` is not 1-D after channel averaging, or if the
            requested segment contains no samples.
    """
    samples = np.asarray(x)
    # Always work in float32 so the model never receives float64/int input.
    waveform = torch.from_numpy(samples).to(torch.float32)
    if np.issubdtype(samples.dtype, np.integer):
        # Raw integer PCM: bring it into [-1, 1).
        waveform = waveform / (1 << 15)

    if waveform.ndim > 1:
        # Down-mix to mono by averaging over the last (channel) axis.
        waveform = waveform.mean(-1)
    if waveform.ndim != 1:
        raise ValueError(
            f"expected 1-D audio after channel averaging, got {waveform.ndim} dimensions"
        )

    # NOTE(review): this slices with absolute offsets, so ``x`` must be the
    # full, untrimmed recording. The caller visible in app.py trims ``x`` to
    # [start, end) before calling, which makes this second slice wrong (and
    # often empty) whenever start > 0 -- the caller should pass the untrimmed
    # array.
    segment = waveform[int(start * sr) : int(end * sr)]
    if segment.numel() == 0:
        raise ValueError(
            f"segment [{start}, {end}) s selects no samples from audio of "
            f"{waveform.numel()} samples at {sr} Hz"
        )

    segment = resample(segment, sr, SAMPLING_RATE)
    features = preprocess(segment)

    with torch.inference_mode():
        # Model input layout: (batch, channel, frames, mel bins).
        logits = MODEL(features.view(1, 1, 1024, 128)).squeeze(0)

    # Sigmoid rather than softmax: each class gets an independent score.
    topk_probs, topk_classes = logits.sigmoid().topk(10)
    return [
        [AUDIOSET_LABELS[cls], prob * 100]
        for cls, prob in zip(topk_classes.tolist(), topk_probs.tolist())
    ]