# ASR/app.py: Streamlit audio player with Whisper transcription
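# This is a Streamlit app; launch it with the standard Streamlit CLI:
#   streamlit run app.py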
import torch
import streamlit as st
import torchaudio as ta
from io import BytesIO
from transformers import WhisperProcessor, WhisperForConditionalGeneration
# Select the device and a matching dtype: fp16 on GPU, fp32 on CPU.
if torch.cuda.is_available():
    device = "cuda:0"
    torch_dtype = torch.float16
else:
    device = "cpu"
    torch_dtype = torch.float32
SAMPLING_RATE = 16000  # Whisper models expect 16 kHz audio
task = "transcribe"  # "transcribe" keeps the source language; "translate" targets English

print(f"{device} Active!")
# Load Whisper model and processor; move the model to the selected device
# and dtype so `device` and `torch_dtype` above are actually used.
processor = WhisperProcessor.from_pretrained("openai/whisper-small")
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-small", torch_dtype=torch_dtype
).to(device)
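# Optional: Streamlit reruns this script on every interaction, so wrapping the
# model loading in a cached factory avoids reloading the weights each time.
# A minimal sketch (assumes Streamlit >= 1.18, which provides st.cache_resource;
# the helper name is illustrative, not part of the original app):
#
# @st.cache_resource
# def load_whisper(model_id: str = "openai/whisper-small"):
#     proc = WhisperProcessor.from_pretrained(model_id)
#     mdl = WhisperForConditionalGeneration.from_pretrained(
#         model_id, torch_dtype=torch_dtype
#     ).to(device)
#     return proc, mdl
#
# processor, model = load_whisper()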
# Title of the app
st.title("Audio Player with Live Transcription")
# Sidebar for file uploader and submit button
st.sidebar.header("Upload Audio Files")
uploaded_files = st.sidebar.file_uploader("Choose audio files", type=["mp3", "wav"], accept_multiple_files=True)
submit_button = st.sidebar.button("Submit")
# With accept_multiple_files=True, file_uploader returns a (possibly empty)
# list, so a truthiness check is more reliable than `is not None`.
if submit_button and uploaded_files:
    st.write("Files uploaded successfully!")
    for uploaded_file in uploaded_files:
        # Display file name and audio player
        st.write(f"**File name**: {uploaded_file.name}")
        st.audio(uploaded_file, format=uploaded_file.type)

        # Transcription section
        st.write("**Transcription**:")

        # torchaudio.load expects a path or file-like object, so wrap the
        # uploaded bytes in a BytesIO buffer, then resample to 16 kHz.
        waveform, sampling_rate = ta.load(BytesIO(uploaded_file.getvalue()))
        resampled_inp = ta.functional.resample(waveform, orig_freq=sampling_rate, new_freq=SAMPLING_RATE)

        # Whisper's feature extractor expects mono audio; use the first channel.
        input_features = processor(
            resampled_inp[0].numpy(), sampling_rate=SAMPLING_RATE, return_tensors="pt"
        ).input_features.to(device, torch_dtype)

        # Force French decoding for now; `task` selects between transcription
        # and English translation.
        forced_decoder_ids = processor.get_decoder_prompt_ids(language="french", task=task)
        predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)

        # Decode token ids to text (batch_decode returns a list; show the first item)
        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
        st.write(transcription[0])
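# Note: Whisper's feature extractor pads or truncates its input to 30 s, so
# only the first 30 s of each upload is transcribed above. A minimal sketch
# for longer files, chunking the waveform before feature extraction (the
# helper name and fixed window size are assumptions, not part of this app):
#
# def transcribe_long(audio: torch.Tensor, chunk_s: int = 30) -> str:
#     prompt_ids = processor.get_decoder_prompt_ids(language="french", task=task)
#     step = SAMPLING_RATE * chunk_s
#     texts = []
#     for start in range(0, audio.shape[-1], step):
#         feats = processor(
#             audio[start:start + step].numpy(),
#             sampling_rate=SAMPLING_RATE, return_tensors="pt",
#         ).input_features.to(device, torch_dtype)
#         ids = model.generate(feats, forced_decoder_ids=prompt_ids)
#         texts.append(processor.batch_decode(ids, skip_special_tokens=True)[0])
#     return " ".join(texts)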