Delete wav2vec_aligen.py

wav2vec_aligen.py  DELETED  (+0 -285)
@@ -1,285 +0,0 @@
from dataclasses import dataclass
import torch
import librosa
import numpy as np
import os
import scipy.stats as stats

# os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
os.environ['MODEL_IS_LOADED'] = '0'
# os.environ['PHONEMIZER_ESPEAK_LIBRARY'] = r"C:\Program Files\eSpeak NG\libespeak-ng.dll"
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'
# os.environ['TRANSFORMERS_VERBOSITY'] = 'error'
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from optimum.bettertransformer import BetterTransformer

torch.random.manual_seed(0)

# Load the phoneme-level CTC model once at import time (CPU inference).
model_name = "facebook/wav2vec2-lv-60-espeak-cv-ft"
processor = Wav2Vec2Processor.from_pretrained(model_name, phone_delimiter_token=' ', word_delimiter_token=' ')
model = Wav2Vec2ForCTC.from_pretrained(model_name).to('cpu').eval()
model = BetterTransformer.transform(model)

@dataclass
class Point:
    token_index: int
    time_index: int
    score: float

# A merged run of identical labels, with its frame span and average score.
@dataclass
class Segment:
    label: str
    start: int
    end: int
    score: float

    def __repr__(self):
        return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d}]"

    def __len__(self):
        return self.end - self.start

def get_trellis(emission, tokens, blank_id=0):
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    trellis = torch.zeros((num_frame, num_tokens))
    trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
    trellis[0, 1:] = -float("inf")
    trellis[-num_tokens + 1:, 0] = float("inf")

    for t in range(num_frame - 1):
        trellis[t + 1, 1:] = torch.maximum(
            trellis[t, 1:] + emission[t, blank_id],     # Score for staying at the same token
            trellis[t, :-1] + emission[t, tokens[1:]],  # Score for changing to the next token
        )
    return trellis
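
# Shape note (illustrative, not from the original file): for an emission of
# 100 frames and a 12-token phoneme sequence, the trellis is a (100, 12)
# matrix whose cell (t, j) holds the best cumulative log-probability of
# having consumed j tokens by frame t; the -inf/inf borders keep the
# backtracking pass inside feasible paths.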


def backtrack(trellis, emission, tokens, blank_id=0):
    t, j = trellis.size(0) - 1, trellis.size(1) - 1

    alignment_path = [Point(j, t, emission[t, blank_id].exp().item())]
    while j > 0:
        # Should not happen, but guard against running out of frames.
        assert t > 0

        # 1. Figure out if the current position was stay or change.
        # Frame-wise score of stay vs change
        p_stay = emission[t - 1, blank_id]
        p_change = emission[t - 1, tokens[j]]

        # Context-aware score for stay vs change
        stayed = trellis[t - 1, j] + p_stay
        changed = trellis[t - 1, j - 1] + p_change

        # Update position
        t -= 1
        if changed > stayed:
            j -= 1

        # Store the alignment path with frame-wise probability.
        prob = (p_change if changed > stayed else p_stay).exp().item()
        alignment_path.append(Point(j, t, prob))

    # Now j == 0, which means it reached the start of the token sequence (SoS).
    # Fill up the rest for the sake of visualization.
    while t > 0:
        prob = emission[t - 1, blank_id].exp().item()
        alignment_path.append(Point(j, t - 1, prob))
        t -= 1

    return alignment_path[::-1]

def merge_repeats(alignment_path, ph):
    i1, i2 = 0, 0
    segments = []
    while i1 < len(alignment_path):
        while i2 < len(alignment_path) and alignment_path[i1].token_index == alignment_path[i2].token_index:
            i2 += 1
        score = sum(alignment_path[k].score for k in range(i1, i2)) / (i2 - i1)
        segments.append(
            Segment(
                ph[alignment_path[i1].token_index],
                alignment_path[i1].time_index,
                alignment_path[i2 - 1].time_index + 1,
                score,
            )
        )
        i1 = i2
    return segments
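
# A minimal, hedged sanity check of the three alignment primitives above on
# toy inputs. The frame count, class count, and token ids are illustrative
# only (not from the original file); index 0 plays the role of the CTC blank.
if __name__ == '__main__':
    toy_emission = torch.log_softmax(torch.randn(20, 6), dim=-1)  # 20 frames, 6 classes
    toy_tokens = [0, 1, 2, 3, 4]  # blank followed by four toy phoneme ids
    toy_trellis = get_trellis(toy_emission, toy_tokens)
    toy_path = backtrack(toy_trellis, toy_emission, toy_tokens)
    print(merge_repeats(toy_path, ['<blank>', 'a', 'b', 'c', 'd']))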


def load_model(device='cpu'):
    model_name = "facebook/wav2vec2-lv-60-espeak-cv-ft"
    processor = Wav2Vec2Processor.from_pretrained(model_name, phone_delimiter_token=' ', word_delimiter_token=' ')
    model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device).eval()
    model = BetterTransformer.transform(model)
    return processor, model

def load_audio(audio_path, processor):
    # Resample to the 16 kHz rate the wav2vec2 model expects.
    audio, sr = librosa.load(audio_path, sr=16000)
    input_values = processor(audio, sampling_rate=16000, return_tensors="pt").input_values
    return input_values


@torch.inference_mode()
def get_emissions(input_values, model):
    emissions = model(input_values).logits
    emissions = torch.log_softmax(emissions, dim=-1)
    emission = emissions[0].cpu().detach()
    return emission

def get_chnocial_phonemes(transcript, processor):
    # Insert a word boundary so the phonemizer does not merge "from the".
    transcript = transcript.replace('from the', 'from | the')
    phoneme_ids = processor.tokenizer(transcript).input_ids
    ph = processor.tokenizer.phonemize(transcript)
    # The phonemizer separates words with a wider run of spaces than it uses
    # between phonemes; the exact widths were lost in this page's rendering,
    # so the triple-space delimiter below is a reconstruction.
    phoneme_list = ph.replace('   ', ' ').split()
    transcript = transcript.replace('from | the', 'from the')
    words = transcript.split()
    words_phonemes = ph.split('   ')
    words_phoneme_mapping = [(w, p) for w, p in zip(words, words_phonemes)]

    return phoneme_list, phoneme_ids, words_phoneme_mapping
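
# Hedged illustration of the data layout returned above; the phoneme strings
# are approximate espeak output, shown only to document the structure:
#   get_chnocial_phonemes('go home', processor)
#   -> (['ɡ', 'oʊ', 'h', 'oʊ', 'm'],           # flat phoneme list
#       [31, 17, ...],                         # tokenizer ids (values illustrative)
#       [('go', 'ɡ oʊ'), ('home', 'h oʊ m')])  # per-word phoneme strings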


def word_level_scoring(words_phoneme_mapping, segments):
    word_scores = []
    start = 0
    for word, ph_seq in words_phoneme_mapping:
        n_ph = len(ph_seq.split())
        cum_score = 0
        wrong = 0
        for i in range(start, start + n_ph):
            s = segments[i]
            cum_score += s.score
            if s.score < 0.50:
                wrong += 1

        start += n_ph
        word_scores.append((word, np.round(cum_score / n_ph, 5), np.round(wrong / n_ph, 5)))
    return word_scores

def map_word2_class(word_scores):
    word_levels = []
    for w, sc, wrong_ratio in word_scores:
        if wrong_ratio > 0.5 or sc < 0.60:
            word_levels.append((w, '/'))
        elif sc < 0.70:
            word_levels.append((w, 'Wrong'))
        elif sc < 0.85:
            word_levels.append((w, 'Understandable'))
        else:
            word_levels.append((w, 'Correct'))
    return word_levels
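
# Worked example of the thresholds above: a word with score 0.82 and
# wrong_ratio 0.2 maps to 'Understandable'; a score of 0.55 maps to '/'
# regardless of wrong_ratio; and a score of 0.90 with wrong_ratio 0.6 also
# maps to '/' because the wrong_ratio test fires first.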

def calculate_content_scores(word_levels):
    # Start from full credit and subtract per misread word:
    # '/' costs a whole word, 'Wrong' costs half, other classes cost nothing.
    content_scores = len(word_levels)
    for w, c in word_levels:
        if c == '/':
            content_scores -= 1
        elif c == 'Wrong':
            content_scores -= 0.5
    content_scores = (content_scores / len(word_levels)) * 100
    return content_scores

def calculate_sentence_pronunciation_accuracy(word_scores):
    w_scores = 0
    error_scores = 0
    for w, sc, wrong_ratio in word_scores:
        sc = sc * 100
        # Piecewise-linear remap of raw scores above 60 onto a stricter scale:
        # (60, 70) -> (0, 20), [70, 88) -> [20, 70), [88, 100] -> [70, 100].
        if sc > 60:
            if sc < 70:
                sc = ((sc - 60) / (70 - 60)) * (20 - 0) + 0
            elif sc < 88:
                sc = ((sc - 70) / (88 - 70)) * (70 - 20) + 20
            else:
                sc = ((sc - 88) / (100 - 88)) * (100 - 70) + 70
        w_scores += sc
        error_scores += wrong_ratio
    w_scores = w_scores / len(word_scores)
    # w_scores = ((w_scores - 50) / (100 - 50)) * 100
    error_scores = (error_scores / len(word_scores)) * 40
    pronunciation_accuracy = min(w_scores, w_scores - error_scores)
    return pronunciation_accuracy
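
# Worked example of the remap: a word score of 0.79 becomes 79, falls in the
# middle band, and maps to ((79 - 70) / 18) * 50 + 20 = 45.0; a score of 0.94
# maps to ((94 - 88) / 12) * 30 + 70 = 85.0. Scores of 60 or below pass
# through unchanged.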

def get_hard_aligenment_with_scores(input_values, transcript):
    # Uses the module-level processor/model; the per-call load is kept for reference.
    # processor, model = load_model(device='cpu')
    emission = get_emissions(input_values, model)
    phoneme_list, phoneme_ids, words_phoneme_mapping = get_chnocial_phonemes(transcript, processor)
    trellis = get_trellis(emission, phoneme_ids)
    alignment_path = backtrack(trellis, emission, phoneme_ids)
    segments = merge_repeats(alignment_path, phoneme_list)
    return segments, words_phoneme_mapping

def normalize_aspect(value, mean, std):
    """Normalize an aspect of speech using the normal CDF."""
    return stats.norm(mean, std).cdf(value)
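
# Worked example: normalize_aspect(170, 170, 50) is exactly 0.5 (the mean of
# the fitted normal), and normalize_aspect(220, 170, 50) = Phi(1) ≈ 0.8413.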

def calculate_fluency_scores(audio, total_words, content_score, pron_score):
    # Constants
    content_score, pron_score = content_score / 100, pron_score / 100
    sample_rate = 16000  # Assuming a sample rate of 16 kHz
    # Means and standard deviations assumed for fluent speech
    speech_rate_mean, speech_rate_std = 170, 50
    phonation_time_mean, phonation_time_std = 50, 4

    # Calculate speaking and total duration
    non_silent_intervals = librosa.effects.split(audio, top_db=20)
    speaking_time = sum(intv[1] - intv[0] for intv in non_silent_intervals) / sample_rate
    total_duration = len(audio) / sample_rate

    # Phonation time ratio, scaled to a per-minute figure
    phonation_time_ratio = speaking_time / total_duration * 60
    phonation_time_ratio = normalize_aspect(phonation_time_ratio, phonation_time_mean, phonation_time_std)
    # Fold the CDF around 0.5 so both too little and too much phonation are
    # penalized, then rescale to [0, 1].
    if phonation_time_ratio > 0.5:
        phonation_time_ratio = 0.5 - (phonation_time_ratio - 0.5)
    phonation_time_ratio = (phonation_time_ratio / 0.5) * 1

    # Words per minute, discounted by the content score, then folded the same way
    speech_rate = total_words / (total_duration / 60)
    speech_rate = speech_rate * content_score
    speech_rate_score = normalize_aspect(speech_rate, speech_rate_mean, speech_rate_std)
    if speech_rate_score > 0.5:
        speech_rate_score = 0.5 - (speech_rate_score - 0.5)
    speech_rate_score = (speech_rate_score / 0.5) * 1

    # Weighted combination (weights sum to 1)
    w_rate_score = 0.4
    w_pho_ratio = 0.35
    w_pro = 0.25
    scaled_fluency_score = speech_rate_score * w_rate_score + phonation_time_ratio * w_pho_ratio + pron_score * w_pro
    scaled_fluency_score = scaled_fluency_score * 100
    return scaled_fluency_score, speech_rate
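
# Worked example of the weighting above: component scores of 0.6 (rate),
# 0.8 (phonation), and 0.7 (pronunciation) combine to
# (0.6 * 0.4 + 0.8 * 0.35 + 0.7 * 0.25) * 100 = 69.5.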


def speaker_pronunciation_assesment(audio_path, transcript):
    input_values = load_audio(audio_path, processor)
    segments, words_phoneme_mapping = get_hard_aligenment_with_scores(input_values, transcript)
    word_scores = word_level_scoring(words_phoneme_mapping, segments)
    word_levels = map_word2_class(word_scores)
    content_scores = calculate_content_scores(word_levels)
    pronunciation_accuracy = calculate_sentence_pronunciation_accuracy(word_scores)
    # librosa expects a NumPy array, so unwrap the first (only) batch item.
    fluency_accuracy, wpm = calculate_fluency_scores(input_values[0].numpy(), len(word_scores), content_scores, pronunciation_accuracy)

    result = {'pronunciation_accuracy': pronunciation_accuracy,
              'word_levels': word_levels,
              'content_scores': content_scores,
              'wpm': wpm,
              'stress': None,
              'fluency_score': fluency_accuracy}
    return result

# The flag is the same whether the module is run as a script or imported.
MODEL_IS_LOADED = False
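
# Hedged usage sketch: 'sample.wav' is a placeholder path (any audio file
# librosa can read works), and the transcript should be the text the speaker
# was asked to read aloud.
if __name__ == '__main__':
    result = speaker_pronunciation_assesment('sample.wav', 'i would like to go home')
    print(result['pronunciation_accuracy'], result['fluency_score'], result['word_levels'])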