"""Streamlit app that flags mispronounced phonemes and words in a recording.

Sidebar pages:
* Mispronunciation Detection -- upload a .wav recording plus the text being
  read; shows phoneme- and word-level error analysis, a TTS reference
  recording and follow-up practice prompts.
* Helpful Videos for Problem Phonemes -- look up an articulation video for an
  ARPAbet phoneme.
"""
import json
import logging
import os
import random
from io import BytesIO

import openai
import streamlit as st
import torchaudio
from gtts import gTTS
from speechbrain.pretrained import GraphemeToPhoneme

from wav2vecasr.MispronounciationDetector import MispronounciationDetector
from wav2vecasr.PhonemeASRModel import MultitaskPhonemeASRModel

logger = logging.getLogger(__name__)

openai.api_key = os.getenv("OPENAI_KEY")

# Uploaded recordings are resampled to this rate (Hz) before detection.
TARGET_SAMPLE_RATE = 16000
# Total number of follow-up practice prompts shown after a prediction.
NUM_PRACTICE_PROMPTS = 3

PAGE_MAIN = 'Main Page'
PAGE_DETECTION = 'Mispronunciation Detection'
PAGE_VIDEOS = 'Helpful Videos for Problem Phonemes'


def _wav2vecasr_path(*parts):
    """Return a path under ``<cwd>/wav2vecasr``."""
    return os.path.join(os.getcwd(), "wav2vecasr", *parts)


def _audio_folder_path():
    """Return the folder uploaded recordings are saved to (``<cwd>/audio_files``)."""
    return os.path.join(os.getcwd(), "audio_files")


# https://gtts.readthedocs.io/en/latest/
def tts_gtts(text):
    """Synthesize English speech for ``text`` with gTTS.

    Returns:
        A ``BytesIO`` holding the MP3 data, rewound to the start.
    """
    mp3_fp = BytesIO()
    gTTS(text, lang="en").write_to_fp(mp3_fp)
    # write_to_fp leaves the cursor at the end; rewind so readers get the data.
    mp3_fp.seek(0)
    return mp3_fp


def pronounce(text):
    """Return a TTS recording of ``text``, or an empty list when ``text`` is empty."""
    if text:
        return tts_gtts(text)
    return []


@st.cache_resource
def load_model():
    """Build the mispronunciation detector (CPU).

    ``st.cache_resource`` keeps the built instance so the model and the G2P
    model are not reloaded on every rerun.
    """
    device = "cpu"
    model_dir = _wav2vecasr_path("model")
    asr_model = MultitaskPhonemeASRModel(
        os.path.join(model_dir, "multitask_best_ctc.pt"),
        os.path.join(model_dir, "vocab"),
        device,
    )
    g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")
    return MispronounciationDetector(asr_model, g2p, device)


def save_file(sound_file):
    """Write an uploaded file into the audio folder.

    Args:
        sound_file: Streamlit uploaded file (needs ``.name`` and ``.getbuffer()``).

    Returns:
        The file name it was saved under (directory components stripped).
    """
    # Keep only the base name so a crafted upload name cannot escape the folder.
    filename = os.path.basename(sound_file.name)
    folder = _audio_folder_path()
    os.makedirs(folder, exist_ok=True)
    with open(os.path.join(folder, filename), 'wb') as f:
        f.write(sound_file.getbuffer())
    return filename


# Deliberately NOT cached: save_file overwrites files of the same name, so a
# filename-keyed cache would return stale audio for a new upload.
def get_audio(saved_sound_filename):
    """Load a saved recording as a 1-D mono waveform at TARGET_SAMPLE_RATE."""
    audio, org_sr = torchaudio.load(
        os.path.join(_audio_folder_path(), saved_sound_filename)
    )
    # torchaudio.load returns (channels, frames); averaging the channels yields
    # a 1-D tensor for mono (unchanged values) and also works for stereo files.
    audio = audio.mean(dim=0)
    return torchaudio.functional.resample(
        audio, orig_freq=org_sr, new_freq=TARGET_SAMPLE_RATE
    )


@st.cache_data
def get_prompts():
    """Return the ``"prompts"`` list from ``wav2vecasr/data/prompts.json``."""
    with open(_wav2vecasr_path("data", "prompts.json"), encoding="utf-8") as f:
        return json.load(f)["prompts"]


@st.cache_data
def get_articulation_videos():
    """Return the ARPAbet -> video-info mapping from ``wav2vecasr/data/videos.json``.

    Note: not all ARPAbet phonemes could be mapped to a video that visualises
    their articulation.
    """
    with open(_wav2vecasr_path("data", "videos.json"), encoding="utf-8") as f:
        return json.load(f)


def get_prompts_from_l2_arctic(prompts, current_prompt, num_to_get):
    """Randomly pick up to ``num_to_get`` distinct prompts other than ``current_prompt``.

    Returns fewer prompts (possibly none) when not enough distinct candidates
    exist, instead of looping forever.
    """
    # Prompts are assumed hashable (rendered as text by the caller).
    candidates = list(dict.fromkeys(p for p in prompts if p != current_prompt))
    return random.sample(candidates, min(max(num_to_get, 0), len(candidates)))


def get_prompt_from_openai(words_with_error_list):
    """Ask the chat model for a short practice sentence using the given words.

    Uses the pre-1.0 ``openai.ChatCompletion`` interface. Best-effort: returns
    an empty string when there are no words or when the request fails.
    """
    if not words_with_error_list:
        return ""
    words_with_errors = ", ".join(words_with_error_list)
    try:
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": "You are writing practice reading prompts for learners of English to practice pronunciation. These prompts should be short, easy to understand and useful."},
                {"role": "user", "content": f"Write a short sentence of less than 10 words and include the following words in the sentence: {words_with_errors}. No numbers."},
            ],
        )
        return response['choices'][0]['message']['content'].strip()
    except Exception:
        # Network/auth/quota errors must not break the page; the caller falls
        # back to L2-ARCTIC prompts when this returns "".
        logger.warning("OpenAI prompt generation failed", exc_info=True)
        return ""


def _show_phoneme_analysis(raw_info):
    """Render the phoneme error rate and the ref/hyp/error phoneme rows."""
    st.write('#### Phoneme Level Analysis')
    st.write(f"Phoneme Error Rate: {round(raw_info['per'], 2)}")
    st.markdown(
        "```\n"
        f"{raw_info['ref']}\n"
        f"{raw_info['hyp']}\n"
        f"{raw_info['phoneme_errors']}\n"
        "```"
    )


def _show_word_analysis(raw_info):
    """Render the word error rate with erroneous words in bold; return those words."""
    md = []
    words_with_errors = []
    for word, has_error in zip(raw_info["words"], raw_info["word_errors"]):
        if has_error:
            words_with_errors.append(word)
            md.append(f"**{word}**")
        else:
            md.append(word)
    st.write('#### Word Level Analysis')
    st.write(f"Word Error Rate: {round(raw_info['wer'], 2)} and the following words in bold have errors:")
    st.markdown(" ".join(md))
    return words_with_errors


def _show_practice_prompts(current_text, words_with_errors):
    """Show up to NUM_PRACTICE_PROMPTS prompts: 1 from ChatGPT (if any), rest from L2-ARCTIC."""
    st.write('Here are some more prompts for you to practice:')
    selected_prompts = []
    # Sorted for a stable word order in the request.
    prompt_for_mistakes_made = get_prompt_from_openai(sorted(set(words_with_errors)))
    if prompt_for_mistakes_made:
        selected_prompts.append(prompt_for_mistakes_made)
    selected_prompts.extend(
        get_prompts_from_l2_arctic(
            get_prompts(), current_text, NUM_PRACTICE_PROMPTS - len(selected_prompts)
        )
    )
    for prompt in selected_prompts:
        st.code(str(prompt), language="python")


def mispronounciation_detection_section():
    """Render the upload-and-predict page and, on 'Predict', the detection results."""
    st.write('# Prediction')
    st.write('1. Upload a recording of you saying the text in .wav format')
    uploaded_file = st.file_uploader(' ', type='wav')
    st.write('2. Input the text you are saying in your recording')
    text = st.text_input(
        "Enter the text you want to read 👇", label_visibility='collapsed'
    )

    if not st.button('Predict'):
        return
    if uploaded_file is None or not text.strip():
        st.error('The audio or text has not been properly input', icon="🚨")
        return

    saved_filename = save_file(uploaded_file)
    audio = get_audio(saved_filename)
    mispronunciation_detector = load_model()

    st.write('# Detection Results')
    with st.spinner('Predicting...'):
        raw_info = mispronunciation_detector.detect(
            audio, text, phoneme_error_threshold=0.25
        )

    _show_phoneme_analysis(raw_info)
    st.divider()
    words_with_errors = _show_word_analysis(raw_info)
    st.divider()

    st.write('#### What is next?')
    st.write("Compare your pronunciation to pronounced sample")
    st.audio(
        os.path.join(_audio_folder_path(), saved_filename),
        format="audio/wav",
        start_time=0,
    )
    # gTTS produces MP3, not WAV.
    st.audio(pronounce(text), format="audio/mpeg", start_time=0)

    _show_practice_prompts(text, words_with_errors)


def video_section():
    """Render the page that looks up an articulation video for a phoneme."""
    st.write('# Get helpful videos on phoneme articulation')
    problem_phoneme = st.text_input(
        "Enter the phoneme you had problems with 👇"
    ).strip()
    arpabet_to_video_map = get_articulation_videos()

    if not st.button('Look up'):
        return
    if not problem_phoneme:
        st.error('Please enter a phoneme to look up', icon="🚨")
        return

    # Exact match first, then an upper-cased retry for lower-case input.
    entry = (
        arpabet_to_video_map.get(problem_phoneme)
        or arpabet_to_video_map.get(problem_phoneme.upper())
        or {}
    )
    video_link = entry.get("link")
    if video_link:
        st.video(video_link)
    else:
        st.write("Sorry, we couldn't find a good enough video yet :( we are working on it!")


def main():
    """Render the sidebar page selector and the selected page."""
    st.write('___')
    st.sidebar.title('Pronunciation Evaluation')
    select = st.sidebar.selectbox(
        'Select a page',
        [PAGE_MAIN, PAGE_DETECTION, PAGE_VIDEOS],
        key='1',
        label_visibility='collapsed',
    )
    st.sidebar.write(select)
    if select == PAGE_DETECTION:
        mispronounciation_detection_section()
    elif select == PAGE_VIDEOS:
        video_section()
    else:
        st.write('# Pronunciation Evaluation')
        st.write('This app is designed to detect mispronunciation of English words for English learners who speak Asian languages such as Korean, Mandarin and Vietnamese.')
        st.write('Wav2Vec 2.0 is used to detect the phonemes spoken by the learner, and this output is compared with the correct phoneme sequence generated from the input text.')


if __name__ == '__main__':
    main()