PsalmsJava commited on
Commit
fe1e779
·
verified ·
1 Parent(s): f7c744d

Update app/model.py

Browse files
Files changed (1) hide show
  1. app/model.py +20 -11
app/model.py CHANGED
@@ -1,13 +1,15 @@
1
  import torch
2
- from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification
 
 
 
 
3
 
4
  device = "cpu"
5
 
6
  model = None
7
  feature_extractor = None
8
 
9
- EMOTIONS = ["angry", "happy", "sad", "neutral"]
10
-
11
  def load_models():
12
  global model, feature_extractor
13
 
@@ -16,12 +18,16 @@ def load_models():
16
  "superb/wav2vec2-base-superb-er"
17
  )
18
 
19
- model = Wav2Vec2ForSequenceClassification.from_pretrained(
20
  "superb/wav2vec2-base-superb-er"
21
  ).to(device)
22
 
23
 
24
  def predict(audio):
 
 
 
 
25
  inputs = feature_extractor(
26
  audio,
27
  sampling_rate=16000,
@@ -30,15 +36,18 @@ def predict(audio):
30
  )
31
 
32
  with torch.no_grad():
33
- logits = model(**inputs).logits
 
 
34
 
35
- probs = torch.nn.functional.softmax(logits, dim=1).numpy()[0]
 
36
 
37
- idx = int(probs.argmax())
 
38
 
39
  return {
40
- "primary_emotion": idx,
41
- "emotion_label": EMOTIONS[idx],
42
- "confidence": float(probs[idx]),
43
- "scores": probs.tolist()
44
  }
 
1
  import torch
2
+ from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
3
+ import numpy as np
4
+
5
+ from app.features import extract_features
6
+ from app.classifier import simple_rule_classifier
7
 
8
  device = "cpu"
9
 
10
  model = None
11
  feature_extractor = None
12
 
 
 
13
  def load_models():
14
  global model, feature_extractor
15
 
 
18
  "superb/wav2vec2-base-superb-er"
19
  )
20
 
21
+ model = Wav2Vec2Model.from_pretrained(
22
  "superb/wav2vec2-base-superb-er"
23
  ).to(device)
24
 
25
 
26
  def predict(audio):
27
+ # ---- Tone features ----
28
+ tone_features = extract_features(audio)
29
+
30
+ # ---- Deep embeddings ----
31
  inputs = feature_extractor(
32
  audio,
33
  sampling_rate=16000,
 
36
  )
37
 
38
  with torch.no_grad():
39
+ outputs = model(**inputs)
40
+
41
+ embeddings = outputs.last_hidden_state.mean(dim=1).numpy()[0]
42
 
43
+ # ---- Combine ----
44
+ combined = np.hstack([tone_features, embeddings])
45
 
46
+ # ---- Classify ----
47
+ emotion, confidence = simple_rule_classifier(tone_features)
48
 
49
  return {
50
+ "emotion_label": emotion,
51
+ "confidence": confidence,
52
+ "note": "Tone-based prediction (less text bias)"
 
53
  }