quaja committed
Commit
87f3bb4
1 Parent(s): bb4c459

Update README.md

Files changed (1)
  1. README.md +71 -1
README.md CHANGED
@@ -9,7 +9,77 @@ tags:
 pipeline_tag: audio-classification
 ---
 
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchaudio
+from transformers import AutoConfig, Wav2Vec2FeatureExtractor
+import librosa
+import IPython.display as ipd
+import numpy as np
+import pandas as pd
+
+# Run on GPU when available; predict() below moves inputs to this device.
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 model_name_or_path = "quaja/hubert-base-amharic-speech-emotion-recognition"
 config = AutoConfig.from_pretrained(model_name_or_path)
 feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name_or_path)
-model = HubertForSpeechClassification.from_pretrained(model_name_or_path)
+sampling_rate = feature_extractor.sampling_rate
+model = HubertForSpeechClassification.from_pretrained(model_name_or_path).to(device)
+
+def speech_file_to_array_fn(path, sampling_rate):
+    # Load the clip and resample it from its native rate to the model's rate.
+    speech_array, _sampling_rate = torchaudio.load(path)
+    resampler = torchaudio.transforms.Resample(_sampling_rate, sampling_rate)
+    speech = resampler(speech_array).squeeze().numpy()
+    return speech
+
+
+def predict(path, sampling_rate):
+    speech = speech_file_to_array_fn(path, sampling_rate)
+    features = feature_extractor(speech, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
+
+    input_values = features.input_values.to(device)
+
+    with torch.no_grad():
+        logits = model(input_values).logits
+
+    # Softmax over the emotion classes; report each label with its probability.
+    scores = F.softmax(logits, dim=1).detach().cpu().numpy()[0]
+    outputs = [{"Label": config.id2label[i], "Score": f"{score * 100:.1f}%"} for i, score in enumerate(scores)]
+    return outputs
+
+
+STYLES = """
+<style>
+div.display_data {
+    margin: 0 auto;
+    max-width: 500px;
+}
+table.xxx {
+    margin: 50px !important;
+    float: right !important;
+    clear: both !important;
+}
+table.xxx td {
+    min-width: 300px !important;
+    text-align: center !important;
+}
+</style>
+""".strip()
+
+def prediction(df_row):
+    path, label = df_row["path"], df_row["emotion"]
+    df = pd.DataFrame([{"Emotion": label, "Sentence": " "}])
+    setup = {
+        'border': 2,
+        'show_dimensions': True,
+        'justify': 'center',
+        'classes': 'xxx',
+        'escape': False,
+    }
+    ipd.display(ipd.HTML(STYLES + df.to_html(**setup) + "<br />"))
+    speech, sr = torchaudio.load(path)
+    resampler = torchaudio.transforms.Resample(sr, sampling_rate)
+    speech = resampler(speech[0]).squeeze().numpy()
+    ipd.display(ipd.Audio(data=np.asarray(speech), autoplay=True, rate=sampling_rate))
+
+    outputs = predict(path, sampling_rate)
+    r = pd.DataFrame(outputs)
+    ipd.display(ipd.HTML(STYLES + r.to_html(**setup) + "<br />"))
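
The snippet above relies on `HubertForSpeechClassification`, which is not a stock `transformers` class; model cards in this family typically define it by hand as a classification head on top of `HubertModel`. The following is a minimal sketch under that assumption (the head layout and the mean-pooling strategy are assumptions, not taken from this commit):

```python
import torch
import torch.nn as nn
from transformers import HubertModel
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.hubert.modeling_hubert import HubertPreTrainedModel


class HubertClassificationHead(nn.Module):
    # Hypothetical head: dense -> tanh -> dropout -> projection to num_labels.
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features):
        x = self.dropout(features)
        x = torch.tanh(self.dense(x))
        x = self.dropout(x)
        return self.out_proj(x)


class HubertForSpeechClassification(HubertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.hubert = HubertModel(config)
        self.classifier = HubertClassificationHead(config)
        self.init_weights()

    def forward(self, input_values, attention_mask=None):
        outputs = self.hubert(input_values, attention_mask=attention_mask)
        # Mean-pool frame-level hidden states into one utterance embedding.
        pooled = outputs.last_hidden_state.mean(dim=1)
        logits = self.classifier(pooled)
        return SequenceClassifierOutput(logits=logits)
```

With that class in scope, a single call runs end to end (the file path is a placeholder):

```python
outputs = predict("/path/to/amharic_clip.wav", sampling_rate)
# -> one {"Label", "Score"} dict per emotion in config.id2label
```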