# Nemov2 / app.py
# Author: AkshatJain1402 — commit d2870ee ("Update app.py", verified).
# Header reconstructed from a Hugging Face file-page capture (6.74 kB; page reported "No virus").
import streamlit as st
from st_audiorec import st_audiorec #pip install streamlit-audiorec
import nemo.collections.asr as nemo_asr
from pydub import AudioSegment
import subprocess
import io
import os
import uuid
# from func_timeout import func_timeout, FunctionTimedOut
import streamlit.components.v1 as components
import wave
import shutil
import contextlib
def download_model(model_path, model_link, lang):
    """Download an acoustic-model checkpoint with ``wget`` unless it is already present.

    The only cache check is whether ``model_path`` exists: if it does, nothing
    is downloaded.

    Args:
        model_path: Directory passed to ``wget -P``; its existence also marks
            the model as already downloaded.
        model_link: URL of the checkpoint to fetch.
        lang: Human-readable language name used in log and error messages.

    Raises:
        Exception: if ``wget`` exits with a non-zero status; the message
            includes wget's stderr.
    """
    if os.path.exists(model_path):
        return
    print("Downloading AM: ", lang)
    download = subprocess.run(
        ["wget", "-P", model_path, model_link], capture_output=True, text=True
    )
    if download.returncode != 0:
        # wget may have created model_path (and a partial file) before failing.
        # Remove it so the existence check above does not treat a broken
        # download as complete on the next start.
        shutil.rmtree(model_path, ignore_errors=True)
        raise Exception(f"{lang} Model Download Failed: {download.stderr}")
    print("Downloaded AM: " + lang)
# Scratch directory for recorded/uploaded audio. It is wiped whenever get_model()
# actually executes (st.cache_resource normally makes that once per server process).
AUDIO_CACHE_DIR = "audio_cache"
# Uploaded files must be longer than 0 s and at most this many seconds ("up to 1 minute").
MAX_UPLOAD_SECONDS = 60
# Mic recordings must be strictly longer than this many seconds to be transcribed.
MIN_MIC_SECONDS = 2

# Footer HTML. Tags stay at column 0 so Markdown does not render them as an indented code block.
FOOTER_HTML = """
<style>
.footer {
width: 100%;
text-align: center;
padding: 10px 0;
margin-top: 240px; /* Add margin-top for some gap */
}
</style>
<div class="footer">
<p>Powered by: <a href="https://zinglebytes.com">zinglebytes.com</a> | <a href="mailto:info@zinglebytes.com">Email us</a></p>
</div>
"""


@st.cache_resource
def get_model():
    """Reset the audio cache, download checkpoints if needed, and load both ASR models.

    Returns:
        Tuple ``(en_asr_model, hi_asr_model)`` restored via
        ``nemo_asr.models.EncDecCTCModelBPE.restore_from``.

    Raises:
        SystemExit: if either checkpoint fails to load (the error is printed first).
        Exception: propagated from ``download_model`` when a download fails.
    """
    # Start from an empty scratch directory; leftovers from earlier runs are discarded.
    shutil.rmtree(AUDIO_CACHE_DIR, ignore_errors=True)
    os.makedirs(AUDIO_CACHE_DIR)
    download_model("./hi_am_model", "https://storage.googleapis.com/vakyansh-open-models/conformer_models/hindi/filtered_v1_ssl_2022-07-08_19-43-25/Conformer-CTC-BPE-Large.nemo", "hindi")
    download_model("./en_am_model", "https://storage.googleapis.com/vakyansh-open-models/conformer_models/english/2022-09-13_15-50-48/Conformer-CTC-BPE-Large.nemo", "english")
    try:
        en_asr_model = nemo_asr.models.EncDecCTCModelBPE.restore_from("./en_am_model/Conformer-CTC-BPE-Large.nemo")
        hi_asr_model = nemo_asr.models.EncDecCTCModelBPE.restore_from("./hi_am_model/Conformer-CTC-BPE-Large.nemo")
    except Exception as e:
        print("ERROR Loading Model... ", e)
        # Same effect as the builtin exit(1), without relying on the `site` module.
        raise SystemExit(1) from e
    return en_asr_model, hi_asr_model


def get_audio_length_wave(filename):
    """Return the duration in seconds of the WAV file at ``filename`` (frames / frame rate)."""
    with contextlib.closing(wave.open(filename, "r")) as f:
        return f.getnframes() / float(f.getframerate())


def save_audio_file_using_mic(audio_location, wav_audio_data):
    """Convert raw recorded audio bytes to 16-bit, mono, 16 kHz WAV at ``audio_location``."""
    audio = AudioSegment.from_file(io.BytesIO(wav_audio_data))
    audio = audio.set_sample_width(2)
    audio = audio.set_channels(1)
    audio = audio.set_frame_rate(16000)
    audio.export(audio_location, format="wav")


def delete_audio_file(audio_location):
    """Best-effort removal of ``audio_location``; a falsy or missing path is a no-op."""
    if not audio_location or not os.path.exists(audio_location):
        return
    try:
        os.remove(audio_location)
        print(f"File deleted: {audio_location}")
    except OSError as e:
        print(f"Error deleting file: {e}")


def save_audio_file_using_upload(audio_location, uploaded_file):
    """Write an uploaded WAV to ``audio_location`` and validate its duration.

    Returns:
        1 if the file has a ``.wav`` extension (case-insensitive) and lasts
        more than 0 and at most ``MAX_UPLOAD_SECONDS`` seconds, else 0.
        Error messages are shown in the page.
    """
    if not uploaded_file.name.lower().endswith(".wav"):
        st.write("ERROR! File extension should be wav")
        return 0
    with open(audio_location, "wb") as f:
        f.write(uploaded_file.getvalue())
    audio_duration = get_audio_length_wave(audio_location)
    st.write("Audio Duration: ", audio_duration)
    if audio_duration <= 0:
        st.write("ERROR! Uploaded file contains no audio")
        return 0
    if audio_duration > MAX_UPLOAD_SECONDS:
        st.write('ERROR! File is more than 1 minute')
        st.write('Uploaded File duration is restricted upto 1 minute')
        return 0
    return 1


def _new_audio_location():
    """Return a fresh, unique WAV path inside the audio cache directory."""
    return os.path.join(AUDIO_CACHE_DIR, f"{uuid.uuid4()}.wav")


def _transcribe(asr_model, audio_location):
    """Transcribe one audio file with ``asr_model`` and return its first result."""
    return asr_model.transcribe([audio_location], logprobs=False)[0]


def _transcribe_mic_recording(asr_model, wav_audio_data):
    """Save a mic recording, transcribe it if long enough, and always delete the temp file."""
    audio_location = _new_audio_location()
    try:
        save_audio_file_using_mic(audio_location, wav_audio_data)
        if get_audio_length_wave(audio_location) <= MIN_MIC_SECONDS:
            # Too short: report and stop. There is no transcription to show.
            st.write("ERROR mic recording should be more than 2 seconds")
            return
        with st.spinner():
            text = _transcribe(asr_model, audio_location)
        print(text)
        st.write("Transcription:")
        st.write(text)
    finally:
        delete_audio_file(audio_location)


def _transcribe_uploaded_file(asr_model, uploaded_file):
    """Save and validate an upload, transcribe it, and always delete the temp file."""
    audio_location = _new_audio_location()
    try:
        if save_audio_file_using_upload(audio_location, uploaded_file) != 1:
            return
        with st.spinner():
            text = _transcribe(asr_model, audio_location)
        print(text)
        st.write(text)
    except Exception as e:
        # Top-level UI boundary: show a friendly message and log the details.
        st.write("ERROR! Something is wrong with the uploaded file.")
        st.write("The file extension should be .wav with a sample rate of 16,000 Hz and a mono channel")
        print(str(e))
    finally:
        delete_audio_file(audio_location)


def main():
    """Render the Vocalize page: language picker, mic transcription, file transcription, footer."""
    en_asr_model, hi_asr_model = get_model()

    ########################################################################################################################################
    # APP FUNCTIONALITY STARTS #
    ########################################################################################################################################
    st.title("💬 Vocalize: Empower Your Voice ")
    st.write("You can either try to record your own voice using a microphone or upload a small file, up to 1 minute in length, to transcribe.")
    st.write('')
    col1, _, _ = st.columns(3)
    with col1:
        language = st.selectbox('Select Your Preferred Language.', ('English', 'Hindi'))
    asr_model = hi_asr_model if language == "Hindi" else en_asr_model

    st.header("Transcribe Your Voice Using Mic")
    wav_audio_data = st_audiorec()
    if wav_audio_data:
        _transcribe_mic_recording(asr_model, wav_audio_data)

    st.header("Transcribe Files")
    st.write("Ensure that the file extension is .wav with a sample rate of 16,000 Hz and a single channel (mono)")
    uploaded_file = st.file_uploader("Upload Your Recording", disabled=False, type='')
    if uploaded_file is not None:
        _transcribe_uploaded_file(asr_model, uploaded_file)

    # Footer mentioning the website
    st.markdown(FOOTER_HTML, unsafe_allow_html=True)


# Streamlit executes the script as module "__main__", so the app still renders under `streamlit run`.
if __name__ == "__main__":
    main()