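"""Streamlit demo: Bahasa Rojak transcription with a fine-tuned Whisper model.

Audio can be uploaded or recorded from a microphone; the app transcribes it
with wy0909/Whisper-MixedLanguageModel and can optionally report the word
error rate (WER) against a user-supplied ground-truth transcript.
"""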
import streamlit as st
import torch
import torchaudio
import librosa
import numpy as np
import sounddevice as sd  # microphone capture for the "Record Audio" tab
import tempfile
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from jiwer import wer
import time
import re
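# Note: sounddevice captures audio on the machine running this script, so the
# "Record Audio" tab only works when the app runs locally with a microphone.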
# Title
st.title("🎙️ Bahasa Rojak Transcriber (Whisper Fine-tuned)")
# Session state initialization
if "audio_bytes" not in st.session_state:
    st.session_state.audio_bytes = None
if "audio_path" not in st.session_state:
    st.session_state.audio_path = None
if "ground_truth" not in st.session_state:
    st.session_state.ground_truth = ""
if "predicted_text" not in st.session_state:
    st.session_state.predicted_text = ""
if "wer_value" not in st.session_state:
    st.session_state.wer_value = None
if "selected_tab" not in st.session_state:
    st.session_state.selected_tab = "📁 Upload Audio"
if "previous_tab" not in st.session_state:
    st.session_state.previous_tab = "📁 Upload Audio"
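# Streamlit reruns this script from the top on every interaction, so audio,
# transcript, and WER values are kept in st.session_state between reruns.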
# Tab selection. st.tabs() renders both panels on every rerun and never
# reports which tab is active, so the reset-on-switch logic below would
# never fire; a radio button keyed to session state is used instead.
st.radio(
    "Input method",
    ["📁 Upload Audio", "🎤 Record Audio"],
    horizontal=True,
    key="selected_tab",
)
# Reset state if the tab changed since the last rerun
if st.session_state.selected_tab != st.session_state.previous_tab:
    st.session_state.audio_bytes = None
    st.session_state.audio_path = None
    st.session_state.ground_truth = ""
    st.session_state.predicted_text = ""
    st.session_state.wer_value = None
    st.session_state.previous_tab = st.session_state.selected_tab
# Tab 1: Upload Audio
if st.session_state.selected_tab == "📁 Upload Audio":
    uploaded_file = st.file_uploader(
        "Upload an audio file (.wav, .mp3, .flac, .m4a, .ogg)",
        type=["wav", "mp3", "flac", "m4a", "ogg"],
    )
    if uploaded_file:
        try:
            st.session_state.audio_bytes = uploaded_file.read()
            # Persist the upload to disk so librosa/torchaudio can open it by path
            with tempfile.NamedTemporaryFile(delete=False, suffix=f".{uploaded_file.name.split('.')[-1]}") as tmp:
                tmp.write(st.session_state.audio_bytes)
                st.session_state.audio_path = tmp.name
            librosa.load(st.session_state.audio_path, sr=16000)  # validate that the file decodes
            st.audio(st.session_state.audio_bytes, format="audio/wav")
        except Exception as e:
            st.error(f"❌ Failed to read audio file: {str(e)}")
            st.session_state.audio_bytes = None
# Tab 2: Record Audio
if st.session_state.selected_tab == "🎤 Record Audio":
    duration = st.slider("Recording Duration (seconds)", 1, 60, 12)
    if st.button("Start Recording"):
        try:
            st.info(f"Recording for {duration} seconds... Please speak now.")
            recording = sd.rec(int(duration * 16000), samplerate=16000, channels=1, dtype='float32')
            sd.wait()  # block until the recording is finished
            recording = np.squeeze(recording)
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                torchaudio.save(tmp.name, torch.tensor(recording).unsqueeze(0), 16000)
                st.session_state.audio_path = tmp.name
            librosa.load(st.session_state.audio_path, sr=16000)  # validate
            # Store the encoded bytes so the Transcribe button sees the recording
            with open(st.session_state.audio_path, "rb") as f:
                st.session_state.audio_bytes = f.read()
            st.audio(st.session_state.audio_bytes, format="audio/wav")
        except Exception as e:
            st.error(f"❌ Failed to process recorded audio: {str(e)}")
            st.session_state.audio_bytes = None
# Input ground truth for WER
st.session_state.ground_truth = st.text_input(
    "Enter ground truth for WER calculation (Optional)",
    value=st.session_state.ground_truth,
    key="ground_truth_input",
)
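# The ground truth is only used when the "With Whisper and WER" mode is chosen.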
# Load the model and processor once and cache them across reruns
@st.cache_resource
def load_model():
    model_repo_name = "wy0909/Whisper-MixedLanguageModel"
    model = WhisperForConditionalGeneration.from_pretrained(model_repo_name)
    processor = WhisperProcessor.from_pretrained(model_repo_name)
    # Clear forced decoder ids and token suppression so the model predicts
    # language/task tokens itself (useful for code-switched speech)
    model.config.forced_decoder_ids = None
    model.generation_config.forced_decoder_ids = None
    model.config.suppress_tokens = []
    return model, processor
model, processor = load_model()
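# The first run downloads the checkpoint from the Hugging Face Hub; later runs
# reuse the local cache plus the st.cache_resource copy held in memory.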
def capitalize_sentences(text):
    # Split after sentence-ending punctuation; note that str.capitalize()
    # also lowercases the remainder of each sentence.
    sentences = re.split(r'(?<=[.!?]) +', text)
    capitalized = [s.strip().capitalize() for s in sentences]
    return ' '.join(capitalized)
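# Example: capitalize_sentences("saya nak makan. you want some?")
# returns "Saya nak makan. You want some?"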
# Transcription Mode
mode = st.selectbox(
    "Transcription Mode:",
    options=["With Whisper", "With Whisper and WER"],
    help="Choose whether to also score the transcript against the ground truth text"
)
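# WER = (substitutions + deletions + insertions) / reference word count,
# so 0.00 means a perfect match and values above 1.0 are possible.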
# Transcription Button
if st.button("🔍 Transcribe"):
    if not st.session_state.audio_bytes:
        st.error("Please upload or record an audio file first.")
    else:
        start_time = time.time()
        try:
            waveform, sample_rate = torchaudio.load(st.session_state.audio_path)
            # Downmix to mono and resample to the 16 kHz rate Whisper expects
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)
            if sample_rate != 16000:
                waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
            inputs = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
            with torch.no_grad():
                predicted_ids = model.generate(inputs["input_features"])
            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            st.session_state.predicted_text = capitalize_sentences(transcription)
            st.markdown("### 🎤 Transcription Result")
            st.success(st.session_state.predicted_text)
            if mode == "With Whisper and WER" and st.session_state.ground_truth:
                # Lowercase both sides so casing differences do not inflate the WER
                st.session_state.wer_value = wer(st.session_state.ground_truth.lower(), st.session_state.predicted_text.lower())
                st.markdown("### 🧮 Word Error Rate (WER)")
                st.write(f"WER: `{st.session_state.wer_value:.2f}`")
        except Exception as e:
            st.error(f"❌ Transcription failed: {str(e)}")
        end_time = time.time()
        duration = end_time - start_time
        st.markdown(f"<div style='font-size: 10px; color: gray;'>Processed in {duration:.2f}s</div>", unsafe_allow_html=True)