ai-pronunciation-trainer / pronunciationTrainer.py
thiagohgl's picture
First repository code commit
28d0c5f
raw
history blame
8.36 kB
import torch
import numpy as np
import models as mo
import WordMetrics
import WordMatching as wm
import epitran
import ModelInterfaces as mi
import AIModels
import RuleBasedModels
from string import punctuation
import time
def getTrainer(language: str):
device = torch.device('cpu')
model, decoder = mo.getASRModel(language)
model = model.to(device)
model.eval()
asr_model = AIModels.NeuralASR(model, decoder)
if language == 'de':
phonem_converter = RuleBasedModels.EpitranPhonemConverter(
epitran.Epitran('deu-Latn'))
elif language == 'en':
phonem_converter = RuleBasedModels.EngPhonemConverter()
else:
raise ValueError('Language not implemented')
trainer = PronunciationTrainer(
asr_model, phonem_converter)
return trainer
class PronunciationTrainer:
current_transcript: str
current_ipa: str
current_recorded_audio: torch.Tensor
current_recorded_transcript: str
current_recorded_word_locations: list
current_recorded_intonations: torch.tensor
current_words_pronunciation_accuracy = []
categories_thresholds = np.array([80, 60, 59])
sampling_rate = 16000
def __init__(self, asr_model: mi.IASRModel, word_to_ipa_coverter: mi.ITextToPhonemModel) -> None:
self.asr_model = asr_model
self.ipa_converter = word_to_ipa_coverter
def getTranscriptAndWordsLocations(self, audio_length_in_samples: int):
audio_transcript = self.asr_model.getTranscript()
word_locations_in_samples = self.asr_model.getWordLocations()
fade_duration_in_samples = 0.05*self.sampling_rate
word_locations_in_samples = [(int(np.maximum(0, word['start_ts']-fade_duration_in_samples)), int(np.minimum(
audio_length_in_samples-1, word['end_ts']+fade_duration_in_samples))) for word in word_locations_in_samples]
return audio_transcript, word_locations_in_samples
def getWordsRelativeIntonation(self, Audio: torch.tensor, word_locations: list):
intonations = torch.zeros((len(word_locations), 1))
intonation_fade_samples = 0.3*self.sampling_rate
print(intonations.shape)
for word in range(len(word_locations)):
intonation_start = int(np.maximum(
0, word_locations[word][0]-intonation_fade_samples))
intonation_end = int(np.minimum(
Audio.shape[1]-1, word_locations[word][1]+intonation_fade_samples))
intonations[word] = torch.sqrt(torch.mean(
Audio[0][intonation_start:intonation_end]**2))
intonations = intonations/torch.mean(intonations)
return intonations
##################### ASR Functions ###########################
def processAudioForGivenText(self, recordedAudio: torch.Tensor = None, real_text=None):
start = time.time()
recording_transcript, recording_ipa, word_locations = self.getAudioTranscript(
recordedAudio)
print('Time for NN to transcript audio: ', str(time.time()-start))
start = time.time()
real_and_transcribed_words, real_and_transcribed_words_ipa, mapped_words_indices = self.matchSampleAndRecordedWords(
real_text, recording_transcript)
print('Time for matching transcripts: ', str(time.time()-start))
start_time, end_time = self.getWordLocationsFromRecordInSeconds(
word_locations, mapped_words_indices)
pronunciation_accuracy, current_words_pronunciation_accuracy = self.getPronunciationAccuracy(
real_and_transcribed_words) # _ipa
pronunciation_categories = self.getWordsPronunciationCategory(
current_words_pronunciation_accuracy)
result = {'recording_transcript': recording_transcript,
'real_and_transcribed_words': real_and_transcribed_words,
'recording_ipa': recording_ipa, 'start_time': start_time, 'end_time': end_time,
'real_and_transcribed_words_ipa': real_and_transcribed_words_ipa, 'pronunciation_accuracy': pronunciation_accuracy,
'pronunciation_categories': pronunciation_categories}
return result
def getAudioTranscript(self, recordedAudio: torch.Tensor = None):
current_recorded_audio = recordedAudio
current_recorded_audio = self.preprocessAudio(
current_recorded_audio)
self.asr_model.processAudio(current_recorded_audio)
current_recorded_transcript, current_recorded_word_locations = self.getTranscriptAndWordsLocations(
current_recorded_audio.shape[1])
current_recorded_ipa = self.ipa_converter.convertToPhonem(
current_recorded_transcript)
return current_recorded_transcript, current_recorded_ipa, current_recorded_word_locations
def getWordLocationsFromRecordInSeconds(self, word_locations, mapped_words_indices) -> list:
start_time = []
end_time = []
for word_idx in range(len(mapped_words_indices)):
start_time.append(float(word_locations[mapped_words_indices[word_idx]]
[0])/self.sampling_rate)
end_time.append(float(word_locations[mapped_words_indices[word_idx]]
[1])/self.sampling_rate)
return ' '.join([str(time) for time in start_time]), ' '.join([str(time) for time in end_time])
##################### END ASR Functions ###########################
##################### Evaluation Functions ###########################
def matchSampleAndRecordedWords(self, real_text, recorded_transcript):
words_estimated = recorded_transcript.split()
if real_text is None:
words_real = self.current_transcript[0].split()
else:
words_real = real_text.split()
mapped_words, mapped_words_indices = wm.get_best_mapped_words(
words_estimated, words_real)
real_and_transcribed_words = []
real_and_transcribed_words_ipa = []
for word_idx in range(len(words_real)):
if word_idx >= len(mapped_words)-1:
mapped_words.append('-')
real_and_transcribed_words.append(
(words_real[word_idx], mapped_words[word_idx]))
real_and_transcribed_words_ipa.append((self.ipa_converter.convertToPhonem(words_real[word_idx]),
self.ipa_converter.convertToPhonem(mapped_words[word_idx])))
return real_and_transcribed_words, real_and_transcribed_words_ipa, mapped_words_indices
def getPronunciationAccuracy(self, real_and_transcribed_words_ipa) -> float:
total_mismatches = 0.
number_of_phonemes = 0.
current_words_pronunciation_accuracy = []
for pair in real_and_transcribed_words_ipa:
real_without_punctuation = self.removePunctuation(pair[0]).lower()
number_of_word_mismatches = WordMetrics.edit_distance_python(
real_without_punctuation, self.removePunctuation(pair[1]).lower())
total_mismatches += number_of_word_mismatches
number_of_phonemes_in_word = len(real_without_punctuation)
number_of_phonemes += number_of_phonemes_in_word
current_words_pronunciation_accuracy.append(float(
number_of_phonemes_in_word-number_of_word_mismatches)/number_of_phonemes_in_word*100)
percentage_of_correct_pronunciations = (
number_of_phonemes-total_mismatches)/number_of_phonemes*100
return np.round(percentage_of_correct_pronunciations), current_words_pronunciation_accuracy
def removePunctuation(self, word: str) -> str:
return ''.join([char for char in word if char not in punctuation])
def getWordsPronunciationCategory(self, accuracies) -> list:
categories = []
for accuracy in accuracies:
categories.append(
self.getPronunciationCategoryFromAccuracy(accuracy))
return categories
def getPronunciationCategoryFromAccuracy(self, accuracy) -> int:
return np.argmin(abs(self.categories_thresholds-accuracy))
def preprocessAudio(self, audio: torch.tensor) -> torch.tensor:
audio = audio-torch.mean(audio)
audio = audio/torch.max(torch.abs(audio))
return audio