pronounciationevaluation / wav2vecasr /MispronounciationDetector.py
bel32123's picture
Introduce uncertainty to word error with PER threshold
2676061
from pandas.core.construction import T
import torch
import jiwer
import re
class MispronounciationDetector:
    """Compare the phonemes an L2 speaker produced against the expected ones.

    The detector aligns the phoneme sequence recognised from audio with the
    phoneme sequence produced by a grapheme-to-phoneme (G2P) converter for the
    prompt text, and reports phoneme-level and word-level errors.
    """

    # Marker separating the phonemes of consecutive words in the output lists.
    _WORD_DELIMITER = "|"
    # Per-phoneme error markers: substitution, deletion, addition (insertion).
    _ERROR_MARKERS = frozenset({"s", "d", "a"})

    def __init__(self, l2_phoneme_recogniser, g2p, device):
        """
        :param l2_phoneme_recogniser: phoneme recogniser (PhonemeASRModel class);
            must provide ``get_l2_phoneme_sequence(audio)`` and
            ``standardise_g2p_phoneme_sequence(phones)``
        :param g2p: callable mapping a text string to its phoneme sequence
        :param device: stored on the instance; not used by the methods here
        """
        self.phoneme_asr_model = l2_phoneme_recogniser  # PhonemeASRModel class
        self.g2p = g2p
        self.device = device

    def detect(self, audio, text, phoneme_error_threshold=0.25):
        """
        Run the full pipeline for one utterance.

        :param audio: audio input handed to the phoneme recogniser as-is
        :param text: original words read by the user
        :type text: string
        :param phoneme_error_threshold: a word counts as mispronounced when
            its share of erroneous phonemes is strictly above this value
        :return: dictionary described in ``get_mispronounciation_output``
        :rtype: dictionary
        """
        l2_phones = self.phoneme_asr_model.get_l2_phoneme_sequence(audio)
        # Drop digits from the recognised phones: per the original author's
        # note, the g2p reference carries no lexical stress.
        l2_phones = [re.sub(r'\d', "", phone_str) for phone_str in l2_phones]
        native_speaker_phones = self.get_native_speaker_phoneme_sequence(text)
        standardised_native_speaker_phones = (
            self.phoneme_asr_model.standardise_g2p_phoneme_sequence(native_speaker_phones)
        )
        return self.get_mispronounciation_output(
            text,
            l2_phones,
            standardised_native_speaker_phones,
            phoneme_error_threshold,
        )

    def get_native_speaker_phoneme_sequence(self, text):
        """Return the phoneme sequence the G2P converter produces for ``text``."""
        return self.g2p(text)

    def get_mispronounciation_output(self, text, pred_phones, org_label_phones, phoneme_error_threshold):
        """
        Aligns the predicted phones from the L2 speaker and the expected native speaker phone to get the errors

        :param text: original words read by the user
        :type text: string
        :param pred_phones: predicted phonemes by L2 speaker from ASR Model
        :type pred_phones: array
        :param org_label_phones: correct, native speaker phonemes from G2P where phonemes of each word is segregated by " "
        :type org_label_phones: array
        :param phoneme_error_threshold: a word counts as mispronounced when
            its share of erroneous phonemes is strictly above this value
        :return: dictionary containing various mispronounciation information like PER, WER and error boolean arrays at phoneme/word level
        :rtype: dictionary
        """
        # Phoneme error rate: each phoneme is treated as one "word" by jiwer.
        label_phones = [phone for phone in org_label_phones if phone != " "]
        reference = " ".join(label_phones)
        hypothesis = " ".join(pred_phones)
        res = jiwer.process_words(reference, hypothesis)
        per = res.wer

        # Only one reference/hypothesis pair was passed, so use the first
        # (and only) list of alignment chunks.
        ref, hyp, error_bool = self._align_phonemes(label_phones, pred_phones, res.alignments[0])

        # Insert word delimiters to show user phoneme sections by word, then
        # close the last word in all three parallel lists.
        self._insert_word_delimiters(org_label_phones, ref, hyp, error_bool)
        for aligned in (ref, hyp, error_bool):
            aligned.append(self._WORD_DELIMITER)

        # A word is mispronounced based on the phoneme errors inside it.
        words = text.split(" ")
        word_error_bool = self.get_mispronounced_words(error_bool, phoneme_error_threshold)
        wer = sum(word_error_bool) / len(words)

        return {
            "ref": ref,
            "hyp": hyp,
            "per": per,
            "phoneme_errors": error_bool,
            "wer": wer,
            "words": words,
            "word_errors": word_error_bool,
        }

    @staticmethod
    def _align_phonemes(label_phones, pred_phones, alignment_chunks):
        """
        Render alignment chunks as three parallel, column-aligned lists.

        :param label_phones: expected phonemes without word separators
        :param pred_phones: predicted phonemes
        :param alignment_chunks: objects exposing ``type``, ``ref_start_idx``,
            ``ref_end_idx``, ``hyp_start_idx`` and ``hyp_end_idx``
        :return: ``(ref, hyp, error_bool)``; for every aligned position the
            three entries have the same width. ``error_bool`` holds dashes for
            a match, or a right-aligned "s"/"d"/"a" for an error.
        :rtype: tuple of lists
        """
        ref, hyp, error_bool = [], [], []
        for chunk in alignment_chunks:
            kind = chunk.type
            ref_span = range(chunk.ref_start_idx, chunk.ref_end_idx)
            hyp_span = range(chunk.hyp_start_idx, chunk.hyp_end_idx)

            if kind == "equal":
                for i in ref_span:
                    phone = label_phones[i]
                    ref.append(phone)
                    error_bool.append("-" * len(phone))
                hyp.extend(pred_phones[j] for j in hyp_span)
            elif kind == "insert":
                # Extra phone from the speaker: no reference counterpart.
                # Marked "a" (addition).
                for j in hyp_span:
                    phone = pred_phones[j]
                    ref.append("*" * len(phone))
                    hyp.append(phone)
                    error_bool.append("a".rjust(len(phone)))
            elif kind == "delete":
                # Expected phone the speaker did not produce.
                for i in ref_span:
                    phone = label_phones[i]
                    ref.append(phone)
                    hyp.append("*" * len(phone))
                    error_bool.append(kind[0].rjust(len(phone)))
            else:
                # Substitution: pair phones one-to-one and left-pad the
                # shorter one so both occupy the same width.
                for offset in range(len(ref_span)):
                    correct_phone = label_phones[chunk.ref_start_idx + offset]
                    pred_phone = pred_phones[chunk.hyp_start_idx + offset]
                    width = max(len(correct_phone), len(pred_phone))
                    ref.append(correct_phone.rjust(width))
                    hyp.append(pred_phone.rjust(width))
                    error_bool.append(kind[0].rjust(width))
        return ref, hyp, error_bool

    @classmethod
    def _insert_word_delimiters(cls, org_label_phones, ref, hyp, error_bool):
        """
        Insert "|" into ``ref``, ``hyp`` and ``error_bool`` (in place) at every
        word boundary of ``org_label_phones``.

        The delimiter goes directly after the last reference phone of the
        word, so speaker insertions that follow it fall into the next word.
        The cursor always moves past a matched phone; otherwise a phone equal
        to the previous one would re-match the old position and put the
        delimiter in the wrong place.
        """
        cursor = 0  # index of the first entry of ``ref`` not yet consumed
        for phone in org_label_phones:
            if phone == " ":
                for aligned in (ref, hyp, error_bool):
                    aligned.insert(cursor, cls._WORD_DELIMITER)
                cursor += 1  # step over the delimiter just inserted
                continue
            # Skip insertion placeholders until this reference phone is found;
            # strip() undoes the padding added for substitutions.
            while cursor < len(ref) and ref[cursor].strip() != phone:
                cursor += 1
            cursor += 1

    def get_mispronounced_words(self, phoneme_error_bool, phoneme_error_threshold):
        """
        Map phoneme error markers back to words to decide which were mispronounced.

        :param phoneme_error_bool: per-phoneme markers with "|" between words;
            a trailing "|" is optional. The list is not modified.
        :type phoneme_error_bool: array
        :param phoneme_error_threshold: a word is wrong only if the share of
            its phonemes marked "s", "d" or "a" is strictly above this value
        :return: one boolean per word, True when the word is mispronounced;
            a word with no phonemes is reported as False
        :rtype: array
        """
        # Work on a copy so the caller's list is left untouched.
        markers = list(phoneme_error_bool)
        if not markers or markers[-1] != self._WORD_DELIMITER:
            markers.append(self._WORD_DELIMITER)

        word_error_bool = []
        for phones in self.split_lst_by_delim(markers, self._WORD_DELIMITER):
            if not phones:
                # Nothing to judge; also avoids dividing by zero.
                word_error_bool.append(False)
                continue
            error_count = sum(1 for phone in phones if phone in self._ERROR_MARKERS)
            word_error_bool.append(error_count / len(phones) > phoneme_error_threshold)
        return word_error_bool

    def split_lst_by_delim(self, lst, delimiter):
        """
        Split ``lst`` into sub-lists at every ``delimiter`` item.

        Items are whitespace-stripped. Items after the last delimiter are
        dropped, so the input must end with ``delimiter`` to keep its last group.

        :return: list of groups, one per delimiter encountered
        :rtype: array
        """
        groups = []
        current = []
        for item in lst:
            if item != delimiter:
                current.append(item.strip())
            else:
                groups.append(current)
                current = []
        return groups