seba3y committed on
Commit 76171df
1 Parent(s): a5e16c1

Delete wav2vec_aligen.py

Files changed (1)
  1. wav2vec_aligen.py +0 -285
wav2vec_aligen.py DELETED
@@ -1,285 +0,0 @@
- from dataclasses import dataclass
- import torch
- import librosa
- import numpy as np
- import os
- import scipy.stats as stats
-
- # os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
- os.environ['MODEL_IS_LOADED'] = '0'
- # os.environ['PHONEMIZER_ESPEAK_LIBRARY'] = "C:\Program Files\eSpeak NG\libespeak-ng.dll"
- os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'
- # os.environ['TRANSFORMERS_VERBOSITY'] = 'error'
- from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
- from optimum.bettertransformer import BetterTransformer
- torch.random.manual_seed(0)
-
- model_name = "facebook/wav2vec2-lv-60-espeak-cv-ft"
- processor = Wav2Vec2Processor.from_pretrained(model_name, phone_delimiter_token=' ', word_delimiter_token=' ')
- model = Wav2Vec2ForCTC.from_pretrained(model_name).to('cpu').eval()
- model = BetterTransformer.transform(model)
-
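- # The checkpoint loaded above is a phoneme-level CTC model: it was fine-tuned
- # to emit espeak phoneme tokens, so its logits can be force-aligned against a
- # phonemized transcript. BetterTransformer only accelerates inference and does
- # not change the model's outputs.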
- @dataclass
- class Point:
-     token_index: int
-     time_index: int
-     score: float
-
- # Merge the labels
- @dataclass
- class Segment:
-     label: str
-     start: int
-     end: int
-     score: float
-
-     def __repr__(self):
-         return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d}]"
-
-     def __len__(self):
-         return self.end - self.start
-
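- # Point is one (token, frame) step on the alignment path; Segment is a run of
- # identical tokens merged into a single phoneme span, with start/end frame
- # indices and the average emission probability as its confidence score.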
- def get_trellis(emission, tokens, blank_id=0):
-     num_frame = emission.size(0)
-     num_tokens = len(tokens)
-
-     trellis = torch.zeros((num_frame, num_tokens))
-     trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
-     trellis[0, 1:] = -float("inf")
-     trellis[-num_tokens + 1:, 0] = float("inf")
-
-     for t in range(num_frame - 1):
-         trellis[t + 1, 1:] = torch.maximum(
-             trellis[t, 1:] + emission[t, blank_id],     # Score for staying at the same token
-             trellis[t, :-1] + emission[t, tokens[1:]],  # Score for changing to the next token
-         )
-     return trellis
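- # trellis[t, j] holds the best cumulative log-probability of having consumed
- # the first j transcript tokens after t frames; each frame either stays on the
- # current token (emits blank) or advances to the next one. This is the
- # standard CTC forced-alignment dynamic program.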
-
-
- def backtrack(trellis, emission, tokens, blank_id=0):
-     t, j = trellis.size(0) - 1, trellis.size(1) - 1
-
-     aligenment_path = [Point(j, t, emission[t, blank_id].exp().item())]
-     while j > 0:
-         # Should not happen, but just in case
-         assert t > 0
-
-         # 1. Figure out whether the current position was a stay or a change
-         # Frame-wise score of stay vs. change
-         p_stay = emission[t - 1, blank_id]
-         p_change = emission[t - 1, tokens[j]]
-
-         # Context-aware score for stay vs. change
-         stayed = trellis[t - 1, j] + p_stay
-         changed = trellis[t - 1, j - 1] + p_change
-
-         # Update position
-         t -= 1
-         if changed > stayed:
-             j -= 1
-
-         # Store the alignment path with the frame-wise probability.
-         prob = (p_change if changed > stayed else p_stay).exp().item()
-         aligenment_path.append(Point(j, t, prob))
-
-     # Now j == 0, which means it reached the start of sentence (SoS).
-     # Fill up the rest for the sake of visualization.
-     while t > 0:
-         prob = emission[t - 1, blank_id].exp().item()
-         aligenment_path.append(Point(j, t - 1, prob))
-         t -= 1
-
-     return aligenment_path[::-1]
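- # Walks the trellis backwards from the final frame/token, deciding at each
- # step whether the path stayed (emitted blank) or advanced (emitted tokens[j]),
- # and records the per-frame emission probability for later scoring.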
-
- def merge_repeats(aligenment_path, ph):
-     i1, i2 = 0, 0
-     segments = []
-     while i1 < len(aligenment_path):
-         while i2 < len(aligenment_path) and aligenment_path[i1].token_index == aligenment_path[i2].token_index:
-             i2 += 1
-         score = sum(aligenment_path[k].score for k in range(i1, i2)) / (i2 - i1)
-         segments.append(
-             Segment(
-                 ph[aligenment_path[i1].token_index],
-                 aligenment_path[i1].time_index,
-                 aligenment_path[i2 - 1].time_index + 1,
-                 score,
-             )
-         )
-         i1 = i2
-     return segments
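- # Collapses consecutive path points sharing a token index into one Segment per
- # phoneme, scored with the mean frame probability over the span.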
-
-
-
- def load_model(device='cpu'):
-     model_name = "facebook/wav2vec2-lv-60-espeak-cv-ft"
-     processor = Wav2Vec2Processor.from_pretrained(model_name, phone_delimiter_token=' ', word_delimiter_token=' ')
-     model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device).eval()
-     model = BetterTransformer.transform(model)
-     return processor, model
-
- def load_audio(audio_path, processor):
-     audio, sr = librosa.load(audio_path, sr=16000)
-
-     input_values = processor(audio, sampling_rate=16000, return_tensors="pt").input_values
-     return input_values
-
-
- @torch.inference_mode()
- def get_emissions(input_values, model):
-     emissions = model(input_values).logits
-     emissions = torch.log_softmax(emissions, dim=-1)
-     emission = emissions[0].cpu().detach()
-     return emission
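- # Returns per-frame log-probabilities over the phoneme vocabulary (one frame
- # roughly every 20 ms of audio for wav2vec 2.0), which get_trellis consumes.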
-
- def get_chnocial_phonemes(transcript, processor):
-     transcript = transcript.replace('from the', 'from | the')
-     phoneme_ids = processor.tokenizer(transcript).input_ids
-     ph = processor.tokenizer.phonemize(transcript)
-     phoneme_list = ph.replace(' ', ' ').split()
-     transcript = transcript.replace('from | the', 'from the')
-     words = transcript.split()
-     words_phonemes = ph.split(' ')
-     words_phoneme_mapping = [(w, p) for w, p in zip(words, words_phonemes)]
-
-     return phoneme_list, phoneme_ids, words_phoneme_mapping
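- # Phonemizes the canonical transcript and pairs each word with its phoneme
- # sequence. The 'from the' -> 'from | the' round-trip appears intended to keep
- # the phonemizer from merging the two words into one phoneme group.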
-
-
- def word_level_scoring(words_phoneme_mapping, segments):
-     word_scores = []
-     start = 0
-     for word, ph_seq in words_phoneme_mapping:
-         n_ph = len(ph_seq.split())
-         cum_score = 0
-         wrong = 0
-         for i in range(start, start + n_ph):
-             s = segments[i]
-             cum_score += s.score
-             if s.score < 0.50:
-                 wrong += 1
-
-         start += n_ph
-         word_scores.append((word, np.round(cum_score / n_ph, 5), np.round(wrong / n_ph, 5)))
-     return word_scores
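- # For each word: the mean segment score of its phonemes, plus the fraction of
- # phonemes whose score fell below 0.50 (the "wrong" ratio used downstream).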
-
- def map_word2_class(word_scores):
-     word_levels = []
-     for w, sc, wrong_ratio in word_scores:
-         if wrong_ratio > 0.5 or sc < 0.60:
-             word_levels.append((w, '/'))
-         elif sc < 0.70:
-             word_levels.append((w, 'Wrong'))
-         elif sc < 0.85:
-             word_levels.append((w, 'Understandable'))
-         else:
-             word_levels.append((w, 'Correct'))
-     return word_levels
-
- def calculate_content_scores(word_levels):
-     content_scores = len(word_levels)
-     for w, c in word_levels:
-         if c == '/':
-             content_scores -= 1
-         elif c == 'Wrong':
-             content_scores -= 0.5
-     content_scores = (content_scores / len(word_levels)) * 100
-     return content_scores
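- # Each unintelligible word ('/') costs a full point and each 'Wrong' word half
- # a point. E.g. 10 words with two '/' and one 'Wrong' score
- # (10 - 2 - 0.5) / 10 * 100 = 75.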
-
- def calculate_sentence_pronunciation_accuracy(word_scores):
-     w_scores = 0
-     error_scores = 0
-     for w, sc, wrong_ratio in word_scores:
-         sc = sc * 100
-         if sc > 60:
-             if sc < 70:
-                 sc = ((sc - 60) / (70 - 60)) * (20 - 0) + 0
-             elif sc < 88:
-                 sc = ((sc - 70) / (88 - 70)) * (70 - 20) + 20
-             else:
-                 sc = ((sc - 88) / (100 - 88)) * (100 - 70) + 70
-         w_scores += sc
-         error_scores += wrong_ratio
-     w_scores = (w_scores / len(word_scores))
-     # w_scores = ((w_scores - 50) / (100 - 50)) * 100
-     error_scores = (error_scores / len(word_scores)) * 40
-     pronunciation_accuracy = min(w_scores, w_scores - error_scores)
-     return pronunciation_accuracy
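- # Piecewise-linear rescaling of word scores above 60: 60-70 maps to 0-20,
- # 70-88 to 20-70, and 88-100 to 70-100, stretching the mid range.
- # E.g. sc = 79 -> ((79 - 70) / 18) * 50 + 20 = 45. The averaged wrong-phoneme
- # ratio (scaled by 40) is then subtracted as a penalty.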
-
- def get_hard_aligenment_with_scores(input_values, transcript):
-     # processor, model = load_model(device='cpu')
-
-     emission = get_emissions(input_values, model)
-     phoneme_list, phoneme_ids, words_phoneme_mapping = get_chnocial_phonemes(transcript, processor)
-     trellis = get_trellis(emission, phoneme_ids)
-     aligenment_path = backtrack(trellis, emission, phoneme_ids)
-     segments = merge_repeats(aligenment_path, phoneme_list)
-     return segments, words_phoneme_mapping
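- # End-to-end forced alignment: emissions -> trellis -> backtracked path ->
- # merged phoneme segments, using the module-level processor and model.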
-
- def normalize_aspect(value, mean, std):
-     """Normalize an aspect of speech using the normal distribution."""
-     return stats.norm(mean, std).cdf(value)
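- # The Gaussian CDF maps a raw measurement into (0, 1); a value exactly at the
- # mean returns 0.5, which the fluency code below treats as the ideal point.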
-
- def calculate_fluency_scores(audio, total_words, content_score, pron_score):
-     # Constants
-     content_score, pron_score = content_score / 100, pron_score / 100
-     sample_rate = 16000  # Assuming a sample rate of 16 kHz
-     # Define means and standard deviations for fluent speech
-     speech_rate_mean, speech_rate_std = 170, 50
-     phonation_time_mean, phonation_time_std = 50, 4
-
-     # Calculate speaking and total duration
-     non_silent_intervals = librosa.effects.split(audio, top_db=20)
-     speaking_time = sum(intv[1] - intv[0] for intv in non_silent_intervals) / sample_rate
-     total_duration = len(audio) / sample_rate
-
-     # Phonation time ratio
-     phonation_time_ratio = speaking_time / total_duration * 60
-
-     phonation_time_ratio = normalize_aspect(phonation_time_ratio, phonation_time_mean, phonation_time_std)
-     if phonation_time_ratio > 0.5:
-         phonation_time_ratio = 0.5 - (phonation_time_ratio - 0.5)
-     phonation_time_ratio = (phonation_time_ratio / 0.5) * 1
-
-     speech_rate = total_words / (total_duration / 60)
-     speech_rate = speech_rate * content_score
-     speech_rate_score = normalize_aspect(speech_rate, speech_rate_mean, speech_rate_std)
-     if speech_rate_score > 0.5:
-         speech_rate_score = 0.5 - (speech_rate_score - 0.5)
-     speech_rate_score = (speech_rate_score / 0.5) * 1
-
-     w_rate_score = 0.4
-     w_pho_ratio = 0.35
-     w_pro = 0.25
-     scaled_fluency_score = speech_rate_score * w_rate_score + phonation_time_ratio * w_pho_ratio + pron_score * w_pro
-     scaled_fluency_score = scaled_fluency_score * 100
-     return scaled_fluency_score, speech_rate
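- # Both rate scores are folded around the CDF midpoint: a measurement at the
- # assumed fluent mean gives cdf = 0.5 -> score 1.0, and deviating in either
- # direction (too slow or too fast) decays toward 0 before the 0.4/0.35/0.25
- # weighted blend with the pronunciation score.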
-
-
-
- def speaker_pronunciation_assesment(audio_path, transcript):
-     input_values = load_audio(audio_path, processor)
-     segments, words_phoneme_mapping = get_hard_aligenment_with_scores(input_values, transcript)
-     word_scores = word_level_scoring(words_phoneme_mapping, segments)
-     word_levels = map_word2_class(word_scores)
-     content_scores = calculate_content_scores(word_levels)
-     pronunciation_accuracy = calculate_sentence_pronunciation_accuracy(word_scores)
-     fluency_accuracy, wpm = calculate_fluency_scores(input_values[0], len(word_scores), content_scores, pronunciation_accuracy)
-
-     result = {'pronunciation_accuracy': pronunciation_accuracy,
-               'word_levels': word_levels,
-               'content_scores': content_scores,
-               'wpm': wpm,
-               'stress': None,
-               'fluency_score': fluency_accuracy}
-     return result
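- # Example usage (hypothetical file name; requires espeak-ng for phonemization):
- #     result = speaker_pronunciation_assesment('sample.wav', 'hello world')
- #     print(result['pronunciation_accuracy'], result['fluency_score'])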
-
- if __name__ == '__main__':
-     MODEL_IS_LOADED = False
- else:
-     MODEL_IS_LOADED = False
-
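- # Both branches set the same flag, so this block is effectively a no-op; it
- # looks like a leftover of the MODEL_IS_LOADED environment toggle set at import.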