import streamlit as st
import torch
import torchaudio
import librosa
import numpy as np
import sounddevice as sd  # required for sd.rec() in the recording tab; missing from the original imports
import tempfile
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from jiwer import wer
import time
import re
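# Note: sounddevice captures audio from the machine running this script, so the
# "Record Audio" tab only works when the app runs locally; a hosted deployment
# has no server-side microphone to record from.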
# Title
st.title("🎙️ Bahasa Rojak Transcriber (Whisper Fine-tuned)")
# Session state initialization
if "audio_bytes" not in st.session_state:
    st.session_state.audio_bytes = None
if "audio_path" not in st.session_state:
    st.session_state.audio_path = None
if "ground_truth" not in st.session_state:
    st.session_state.ground_truth = ""
if "predicted_text" not in st.session_state:
    st.session_state.predicted_text = ""
if "wer_value" not in st.session_state:
    st.session_state.wer_value = None
if "selected_tab" not in st.session_state:
    st.session_state.selected_tab = "📁 Upload Audio"
if "previous_tab" not in st.session_state:
    st.session_state.previous_tab = "📁 Upload Audio"
# Tab Selection using st.tabs()
tab1, tab2 = st.tabs(["📁 Upload Audio", "🎤 Record Audio"])
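# Note: st.tabs() does not report which tab the user is viewing, so
# selected_tab is never updated by Streamlit itself; the reset branch below
# only fires if selected_tab is changed elsewhere in the app.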
# Reset state if tab is changed
if st.session_state.selected_tab != st.session_state.previous_tab:
    st.session_state.audio_bytes = None
    st.session_state.audio_path = None
    st.session_state.ground_truth = ""
    st.session_state.predicted_text = ""
    st.session_state.wer_value = None
    st.session_state.previous_tab = st.session_state.selected_tab
# Tab 1: Upload Audio
with tab1:
    uploaded_file = st.file_uploader("Upload an audio file (.wav, .mp3, .flac, .m4a, .ogg)", type=["wav", "mp3", "flac", "m4a", "ogg"])
    if uploaded_file:
        try:
            st.session_state.audio_bytes = uploaded_file.read()
            with tempfile.NamedTemporaryFile(delete=False, suffix=f".{uploaded_file.name.split('.')[-1]}") as tmp:
                tmp.write(st.session_state.audio_bytes)
                st.session_state.audio_path = tmp.name
            librosa.load(st.session_state.audio_path, sr=16000)  # validate that the file decodes
            st.audio(st.session_state.audio_bytes, format=uploaded_file.type or "audio/wav")  # use the upload's MIME type rather than hard-coding audio/wav
        except Exception as e:
            st.error(f"❌ Failed to read audio file: {e}")
            st.session_state.audio_bytes = None
# Tab 2: Record Audio
with tab2:
    duration = st.slider("Recording Duration (seconds)", 1, 60, 12)
    if st.button("Start Recording"):
        try:
            st.info(f"Recording for {duration} seconds... Please speak now.")
            recording = sd.rec(int(duration * 16000), samplerate=16000, channels=1, dtype="float32")
            sd.wait()
            recording = np.squeeze(recording)
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                torchaudio.save(tmp.name, torch.tensor(recording).unsqueeze(0), 16000)
                st.session_state.audio_path = tmp.name
            with open(st.session_state.audio_path, "rb") as f:
                st.session_state.audio_bytes = f.read()  # store bytes so the Transcribe button's audio check also passes for recordings
            librosa.load(st.session_state.audio_path, sr=16000)  # validate that the file decodes
            st.audio(st.session_state.audio_bytes, format="audio/wav")  # play back the saved WAV rather than the raw float array
        except Exception as e:
            st.error(f"❌ Failed to process recorded audio: {e}")
            st.session_state.audio_bytes = None
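# Alternative for browser-based recording: recent Streamlit versions provide
# st.audio_input(), which records in the user's browser rather than on the
# server and therefore also works on hosted deployments.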
# Input ground truth for WER
st.session_state.ground_truth = st.text_input(
    "Enter ground truth for WER calculation (Optional)",
    value=st.session_state.ground_truth,
    key="ground_truth_input",
)
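# Note: the WER computation below only lowercases both strings. For stricter
# normalization, jiwer also provides transforms such as RemovePunctuation,
# e.g. jiwer.Compose([jiwer.ToLowerCase(), jiwer.RemovePunctuation()]).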
# Load model and processor
@st.cache_resource  # cache so the model is loaded once, not on every Streamlit rerun
def load_model():
    model_repo_name = "wy0909/Whisper-MixedLanguageModel"
    model = WhisperForConditionalGeneration.from_pretrained(model_repo_name)
    processor = WhisperProcessor.from_pretrained(model_repo_name)
    # Clear the forced decoder ids and the suppression list so generation is
    # not pinned to a single language/task token, letting the fine-tuned model
    # emit mixed-language output freely.
    model.config.forced_decoder_ids = None
    model.generation_config.forced_decoder_ids = None
    model.config.suppress_tokens = []
    return model, processor

model, processor = load_model()
def capitalize_sentences(text):
    # Split on sentence-final punctuation followed by spaces, then capitalize each piece
    sentences = re.split(r'(?<=[.!?]) +', text)
    capitalized = [s.strip().capitalize() for s in sentences]
    return ' '.join(capitalized)
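# Example: capitalize_sentences("saya suka makan. best GILA!") returns
# "Saya suka makan. Best gila!" -- note that str.capitalize() also lowercases
# everything after the first character of each sentence.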
# Transcription Mode
mode = st.selectbox(
    "Transcription Mode:",
    options=["With Whisper", "With Whisper and WER"],
    help="Select transcription mode",
)
# Transcription Button
if st.button("🚀 Transcribe"):
    if not st.session_state.audio_bytes:
        st.error("Please upload or record an audio file first.")
    else:
        start_time = time.time()
        try:
            waveform, sample_rate = torchaudio.load(st.session_state.audio_path)
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)  # mix multi-channel audio down to mono before feature extraction
            waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
            inputs = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
            with torch.no_grad():
                predicted_ids = model.generate(inputs["input_features"])
            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            st.session_state.predicted_text = capitalize_sentences(transcription)
            st.markdown("### 🎤 Transcription Result")
            st.success(st.session_state.predicted_text)
            if mode == "With Whisper and WER" and st.session_state.ground_truth:
                st.session_state.wer_value = wer(st.session_state.ground_truth.lower(), st.session_state.predicted_text.lower())
                st.markdown("### 🧮 Word Error Rate (WER)")
                st.write(f"WER: `{st.session_state.wer_value:.2f}`")
        except Exception as e:
            st.error(f"❌ Transcription failed: {e}")
        end_time = time.time()
        duration = end_time - start_time
        st.markdown(f"<div style='font-size: 10px; color: gray;'>Processed in {duration:.2f}s</div>", unsafe_allow_html=True)
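# To run locally (assumed entry point): streamlit run app.py
# Decoding mp3/m4a/ogg uploads may additionally require ffmpeg to be installed.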