import streamlit as st
import torch
import torchaudio
import librosa
import numpy as np
import sounddevice as sd  # was missing, but sd.rec()/sd.wait() are used below
import tempfile
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from jiwer import wer
import time
import re

# Title
st.title("🎙️ Bahasa Rojak Transcriber (Whisper Fine-tuned)")

# Session state initialization
if "audio_bytes" not in st.session_state:
    st.session_state.audio_bytes = None
if "audio_path" not in st.session_state:
    st.session_state.audio_path = None
if "ground_truth" not in st.session_state:
    st.session_state.ground_truth = ""
if "predicted_text" not in st.session_state:
    st.session_state.predicted_text = ""
if "wer_value" not in st.session_state:
    st.session_state.wer_value = None
if "selected_tab" not in st.session_state:
    st.session_state.selected_tab = "📁 Upload Audio"
if "previous_tab" not in st.session_state:
    st.session_state.previous_tab = "📁 Upload Audio"

# Tab selection using st.tabs()
tab1, tab2 = st.tabs(["📁 Upload Audio", "🎤 Record Audio"])

# Reset cached audio and results if the tracked tab changed.
# Note: st.tabs() does not report which tab is active, so selected_tab is never
# updated automatically; this reset only fires if selected_tab is set elsewhere.
if st.session_state.selected_tab != st.session_state.previous_tab:
    st.session_state.audio_bytes = None
    st.session_state.audio_path = None
    st.session_state.ground_truth = ""
    st.session_state.predicted_text = ""
    st.session_state.wer_value = None
    st.session_state.previous_tab = st.session_state.selected_tab

# Tab 1: Upload Audio
with tab1:
    uploaded_file = st.file_uploader(
        "Upload an audio file (.wav, .mp3, .flac, .m4a, .ogg)",
        type=["wav", "mp3", "flac", "m4a", "ogg"],
    )
    if uploaded_file:
        try:
            st.session_state.audio_bytes = uploaded_file.read()
            with tempfile.NamedTemporaryFile(
                delete=False, suffix=f".{uploaded_file.name.split('.')[-1]}"
            ) as tmp:
                tmp.write(st.session_state.audio_bytes)
                st.session_state.audio_path = tmp.name
            librosa.load(st.session_state.audio_path, sr=16000)  # validate the file is decodable audio
            st.audio(st.session_state.audio_bytes, format="audio/wav")
        except Exception as e:
            st.error(f"❌ Failed to read audio file: {str(e)}")
            st.session_state.audio_bytes = None

# Tab 2: Record Audio
with tab2:
    duration = st.slider("Recording Duration (seconds)", 1, 60, 12)
    if st.button("Start Recording"):
        try:
            st.info(f"Recording for {duration} seconds... Please speak now.")
            recording = sd.rec(int(duration * 16000), samplerate=16000, channels=1, dtype='float32')
            sd.wait()
            recording = np.squeeze(recording)
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                torchaudio.save(tmp.name, torch.tensor(recording).unsqueeze(0), 16000)
                st.session_state.audio_path = tmp.name
            librosa.load(st.session_state.audio_path, sr=16000)  # validate the recording is decodable
            # Keep the encoded bytes too: the Transcribe button checks audio_bytes,
            # which was previously only set on upload, so recordings never passed the check.
            with open(st.session_state.audio_path, "rb") as f:
                st.session_state.audio_bytes = f.read()
            st.audio(st.session_state.audio_bytes, format="audio/wav")
        except Exception as e:
            st.error(f"❌ Failed to process recorded audio: {str(e)}")
            st.session_state.audio_bytes = None

# Input ground truth for WER
st.session_state.ground_truth = st.text_input(
    "Enter ground truth for WER calculation (Optional)",
    value=st.session_state.ground_truth,
    key="ground_truth_input",
)

# Load model and processor (cached across Streamlit reruns)
@st.cache_resource
def load_model():
    model_repo_name = "wy0909/Whisper-MixedLanguageModel"
    model = WhisperForConditionalGeneration.from_pretrained(model_repo_name)
    processor = WhisperProcessor.from_pretrained(model_repo_name)
    # Clear forced decoder ids and suppressed tokens so generation is unconstrained.
    model.config.forced_decoder_ids = None
    model.generation_config.forced_decoder_ids = None
    model.config.suppress_tokens = []
    return model, processor

model, processor = load_model()

def capitalize_sentences(text):
    """Capitalize the first letter of each sentence in the transcription."""
    sentences = re.split(r'(?<=[.!?]) +', text)
    capitalized = [s.strip().capitalize() for s in sentences]
    return ' '.join(capitalized)

# Transcription Mode
mode = st.selectbox(
    "Transcription Mode:",
    options=["With Whisper", "With Whisper and WER"],
    help="Select transcription mode",
)

# Transcription Button
if st.button("📝 Transcribe"):
    if not st.session_state.audio_bytes:
        st.error("Please upload or record an audio file first.")
    else:
        start_time = time.time()
        try:
            waveform, sample_rate = torchaudio.load(st.session_state.audio_path)
            waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
            # Downmix multi-channel uploads to mono; Whisper expects a single channel.
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)
            inputs = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
            with torch.no_grad():
                predicted_ids = model.generate(inputs["input_features"])
            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            st.session_state.predicted_text = capitalize_sentences(transcription)

            st.markdown("### 🎤 Transcription Result")
            st.success(st.session_state.predicted_text)

            if mode == "With Whisper and WER" and st.session_state.ground_truth:
                st.session_state.wer_value = wer(
                    st.session_state.ground_truth.lower(),
                    st.session_state.predicted_text.lower(),
                )
                st.markdown("### 🧮 Word Error Rate (WER)")
                st.write(f"WER: `{st.session_state.wer_value:.2f}`")
        except Exception as e:
            st.error(f"❌ Transcription failed: {str(e)}")

        elapsed = time.time() - start_time  # renamed from `duration` to avoid shadowing the slider value
        st.markdown(f"Processed in {elapsed:.2f}s")