import re
import tempfile
import time

import librosa
import numpy as np
import sounddevice as sd  # microphone capture used by the Record Audio tab
import streamlit as st
import torch
import torchaudio
from jiwer import wer
from transformers import WhisperForConditionalGeneration, WhisperProcessor
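
# NOTE: sounddevice records on the machine running the Streamlit process,
# so the Record Audio tab only works when the app runs locally; a hosted
# server has no access to the visitor's microphone.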
# Title
st.title("🎙️ Bahasa Rojak Transcriber (Whisper Fine-tuned)")

# Session state initialization
if "audio_bytes" not in st.session_state:
    st.session_state.audio_bytes = None
if "audio_path" not in st.session_state:
    st.session_state.audio_path = None
if "ground_truth" not in st.session_state:
    st.session_state.ground_truth = ""
if "predicted_text" not in st.session_state:
    st.session_state.predicted_text = ""
if "wer_value" not in st.session_state:
    st.session_state.wer_value = None
if "selected_tab" not in st.session_state:
    st.session_state.selected_tab = "📁 Upload Audio"
if "previous_tab" not in st.session_state:
    st.session_state.previous_tab = "📁 Upload Audio"
# Tab selection using st.tabs()
tab1, tab2 = st.tabs(["📁 Upload Audio", "🎤 Record Audio"])

# Reset state when the tab changes. Note: st.tabs() does not report which
# tab is active, so selected_tab must be updated elsewhere for this to fire.
if st.session_state.selected_tab != st.session_state.previous_tab:
    st.session_state.audio_bytes = None
    st.session_state.audio_path = None
    st.session_state.ground_truth = ""
    st.session_state.predicted_text = ""
    st.session_state.wer_value = None
    st.session_state.previous_tab = st.session_state.selected_tab
# Tab 1: Upload Audio
with tab1:
    uploaded_file = st.file_uploader(
        "Upload an audio file (.wav, .mp3, .flac, .m4a, .ogg)",
        type=["wav", "mp3", "flac", "m4a", "ogg"],
    )
    if uploaded_file:
        try:
            st.session_state.audio_bytes = uploaded_file.read()
            with tempfile.NamedTemporaryFile(delete=False, suffix=f".{uploaded_file.name.split('.')[-1]}") as tmp:
                tmp.write(st.session_state.audio_bytes)
                st.session_state.audio_path = tmp.name
            librosa.load(st.session_state.audio_path, sr=16000)  # validate the file decodes as audio
            st.audio(st.session_state.audio_bytes)
        except Exception as e:
            st.error(f"❌ Failed to read audio file: {e}")
            st.session_state.audio_bytes = None
# Tab 2: Record Audio
with tab2:
    duration = st.slider("Recording Duration (seconds)", 1, 60, 12)
    if st.button("Start Recording"):
        try:
            st.info(f"Recording for {duration} seconds... Please speak now.")
            recording = sd.rec(int(duration * 16000), samplerate=16000, channels=1, dtype="float32")
            sd.wait()
            recording = np.squeeze(recording)
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                torchaudio.save(tmp.name, torch.tensor(recording).unsqueeze(0), 16000)
                st.session_state.audio_path = tmp.name
            # Keep the WAV bytes in session state so the Transcribe button
            # accepts recorded audio, not just uploads.
            with open(st.session_state.audio_path, "rb") as f:
                st.session_state.audio_bytes = f.read()
            librosa.load(st.session_state.audio_path, sr=16000)  # validate
            st.audio(st.session_state.audio_bytes, format="audio/wav")
        except Exception as e:
            st.error(f"❌ Failed to process recorded audio: {e}")
            st.session_state.audio_bytes = None
# Input ground truth for WER
st.session_state.ground_truth = st.text_input(
    "Enter ground truth for WER calculation (Optional)",
    value=st.session_state.ground_truth,
    key="ground_truth_input",
)
# Load model and processor (cached across Streamlit reruns)
@st.cache_resource
def load_model():
    model_repo_name = "wy0909/Whisper-MixedLanguageModel"
    model = WhisperForConditionalGeneration.from_pretrained(model_repo_name)
    processor = WhisperProcessor.from_pretrained(model_repo_name)
    # Clear forced decoder ids and token suppression so generation is unconstrained
    model.config.forced_decoder_ids = None
    model.generation_config.forced_decoder_ids = None
    model.config.suppress_tokens = []
    return model, processor

model, processor = load_model()
def capitalize_sentences(text):
    # Uppercase only the first character of each sentence; str.capitalize()
    # is avoided because it would also lowercase the rest of the sentence.
    sentences = re.split(r"(?<=[.!?]) +", text)
    capitalized = [s.strip() for s in sentences]
    return " ".join(s[:1].upper() + s[1:] for s in capitalized)
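
# Example: capitalize_sentences("saya pergi ke KLCC. best sangat!")
#   -> "Saya pergi ke KLCC. Best sangat!"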
# Transcription Mode
mode = st.selectbox(
    "Transcription Mode:",
    options=["With Whisper", "With Whisper and WER"],
    help="Select transcription mode",
)
# Transcription Button
if st.button("📝 Transcribe"):
    if not st.session_state.audio_bytes:
        st.error("Please upload or record an audio file first.")
    else:
        start_time = time.time()
        try:
            waveform, sample_rate = torchaudio.load(st.session_state.audio_path)
            waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
            if waveform.shape[0] > 1:  # mix multi-channel audio down to mono
                waveform = waveform.mean(dim=0, keepdim=True)
            inputs = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
            with torch.no_grad():
                predicted_ids = model.generate(inputs["input_features"])
            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            st.session_state.predicted_text = capitalize_sentences(transcription)

            st.markdown("### 🎤 Transcription Result")
            st.success(st.session_state.predicted_text)

            if mode == "With Whisper and WER" and st.session_state.ground_truth:
                st.session_state.wer_value = wer(
                    st.session_state.ground_truth.lower(),
                    st.session_state.predicted_text.lower(),
                )
                st.markdown("### 🧮 Word Error Rate (WER)")
                st.write(f"WER: `{st.session_state.wer_value:.2f}`")
        except Exception as e:
            st.error(f"❌ Transcription failed: {e}")
        duration = time.time() - start_time
        st.markdown(
            f"<div style='font-size: 10px; color: gray;'>Processed in {duration:.2f}s</div>",
            unsafe_allow_html=True,
        )
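
# Run locally with: streamlit run app.py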