import io

import numpy as np
import soundfile as sf
import streamlit as st
import torch
import whisper


# pre-process: file-like object input case
def trans_byte2arr(file_obj) -> np.ndarray:
    """Decode an uploaded file-like object into a mono float32 signal."""
    arr_data, _ = sf.read(file=io.BytesIO(file_obj.read()), dtype="float32")
    return merge_sig(arr_data)


def merge_sig(arr_data: np.ndarray) -> np.ndarray:
    """Fold a multi-channel signal down to mono."""
    if arr_data.ndim == 1:
        return arr_data
    if arr_data.ndim == 2:
        # stereo file: average the left and right channels
        # (averaging instead of summing avoids clipping)
        return arr_data.mean(axis=1)
    raise ValueError("this file is not a supported audio file")


# pre-process: resample to the 16 kHz rate Whisper expects
def audio_speed_reduce(sig_data: np.ndarray, sample_rate: int) -> np.ndarray:
    sig_data = merge_sig(sig_data)
    if sample_rate == 16000:
        return sig_data
    if sample_rate > 16000 and sample_rate % 16000 == 0:
        # integer downsampling factor: average each block of `factor` samples,
        # trimming any trailing samples so the signal reshapes evenly
        factor = sample_rate // 16000
        remainder = len(sig_data) % factor
        if remainder:
            sig_data = sig_data[:-remainder]
        return sig_data.reshape(-1, factor).mean(axis=1)
    # non-integer ratio (e.g. 44.1 kHz) or low-rate input: fall back to
    # linear-interpolation resampling; crude (no anti-alias filter) but dependency-free
    target_len = int(len(sig_data) * 16000 / sample_rate)
    old_idx = np.arange(len(sig_data))
    new_idx = np.linspace(0, len(sig_data) - 1, num=target_len)
    return np.interp(new_idx, old_idx, sig_data).astype(np.float32)


def convert_byte_audio(byte_data: bytes) -> np.ndarray:
    # decode the raw bytes into a float32 array plus its sample rate
    arr_data, sr = sf.read(file=io.BytesIO(byte_data), dtype="float32")
    # resample to the 16 kHz mono signal Whisper expects
    return audio_speed_reduce(arr_data, sr)


def get_language_cls(audio_arr: np.ndarray, model: torch.nn.Module) -> dict:
    # pad or trim the clip to the 30-second window Whisper operates on
    audio = whisper.pad_or_trim(audio_arr)
    # make a log-Mel spectrogram and move it to the same device as the model
    mel = whisper.log_mel_spectrogram(audio).to(model.device)
    # detect the spoken language
    _, probs = model.detect_language(mel)
    return probs


def transcribe(audio: np.ndarray, model: torch.nn.Module, task: str = "transcribe") -> str:
    # anything other than "translate" falls back to plain transcription
    task = task if task == "translate" else "transcribe"
    result = model.transcribe(audio, task=task, beam_size=5, best_of=5)
    return result["text"]


@st.cache_resource  # cache the model so Streamlit reruns don't reload the weights
def load_model(model_name: str) -> torch.nn.Module:
    return whisper.load_model(model_name)


file_data = st.file_uploader("Upload your audio (.wav) file")
if file_data is not None and file_data.name.endswith(".wav"):
    # read the upload as raw bytes
    bytes_data = file_data.getvalue()
    audio_arr = convert_byte_audio(bytes_data)

    # audio waveform plotting (disabled; requires `import matplotlib.pyplot as plt`):
    # fig, ax = plt.subplots()
    # ax.plot(audio_arr)
    # st.pyplot(fig)
    st.audio(bytes_data)

    model_option = ["tiny", "base", "small", "medium", "large"]
    selected_model_size = st.selectbox(
        "Which model size do you want?", ["None"] + model_option
    )
    if selected_model_size in model_option:
        model = load_model(selected_model_size)

        if st.button("Detect language"):
            with st.spinner("Detecting language..."):
                probs = get_language_cls(audio_arr=audio_arr, model=model)
            st.write(f"Detected language: {max(probs, key=probs.get)}")

        task_option = ["transcribe", "translate"]
        translate_task = st.selectbox("Choose a task", ["None"] + task_option)
        if translate_task != "None":
            with st.spinner("In progress..."):
                result = transcribe(audio=audio_arr, model=model, task=translate_task)
            st.write(result)
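
# Usage sketch: with the dependencies installed, run this script with Streamlit.
# The filename "app.py" below is an assumption; save the script under any name.
#   pip install streamlit soundfile numpy torch openai-whisper
#   streamlit run app.py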