3loi commited on
Commit
76fe58b
1 Parent(s): ad02703

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -6
app.py CHANGED
@@ -9,6 +9,8 @@ import numpy as np
9
 
10
 
11
  mean, std = -8.278621631819787e-05, 0.08485510250851999
 
 
12
  id2label = {0: 'arousal', 1: 'dominance', 2: 'valence'}
13
  description_text = "Multi-label (arousal, dominance, valence) Odyssey 2024 Emotion Recognition competition baseline model.<br> \
14
  The model is trained on MSP-Podcast. \
@@ -28,19 +30,22 @@ def classify_audio(audio_file):
28
 
29
  y = raw_wav.astype(np.float32, order='C') / np.iinfo(raw_wav.dtype).max
30
 
 
 
 
 
 
 
31
 
32
- norm_wav = (y - mean) / (std+0.000001)
33
 
 
 
34
  mask = torch.ones(1, len(norm_wav))
35
  wavs = torch.tensor(norm_wav).unsqueeze(0)
36
 
37
  pred = model(wavs, mask).detach().numpy()
38
 
39
- output = ''
40
- if sr != 16000:
41
- output += "{} sampling rate is uncompatible. The model was trained on {} sampleing rate\n".format(sr, 16000)
42
- # for i, audio_pred in enumerate(pred):
43
- # output[i] = {}
44
  for att_i, att_val in enumerate(pred[0]):
45
  output += "{}: \t{:0.4f}\n".format(id2label[att_i], att_val)
46
 
 
9
 
10
 
11
  mean, std = -8.278621631819787e-05, 0.08485510250851999
12
+ model_sr=model.config.sampling_rate
13
+
14
  id2label = {0: 'arousal', 1: 'dominance', 2: 'valence'}
15
  description_text = "Multi-label (arousal, dominance, valence) Odyssey 2024 Emotion Recognition competition baseline model.<br> \
16
  The model is trained on MSP-Podcast. \
 
30
 
31
  y = raw_wav.astype(np.float32, order='C') / np.iinfo(raw_wav.dtype).max
32
 
33
+
34
+
35
+ output = ''
36
+ if sr != 16000:
37
+ y = librosa.resample(y, orig_sr=sr, target_sr=model_sr)
38
+ output += "{} sampling rate is uncompatible, converted to {} as the model was trained on {} sampling rate<br>".format(sr, model_sr, model_sr)
39
 
 
40
 
41
+
42
+ norm_wav = (y - mean) / (std+0.000001)
43
  mask = torch.ones(1, len(norm_wav))
44
  wavs = torch.tensor(norm_wav).unsqueeze(0)
45
 
46
  pred = model(wavs, mask).detach().numpy()
47
 
48
+
 
 
 
 
49
  for att_i, att_val in enumerate(pred[0]):
50
  output += "{}: \t{:0.4f}\n".format(id2label[att_i], att_val)
51