import streamlit as st
import torch
import torchaudio
import librosa
import numpy as np
import sounddevice as sd  # microphone capture for the "Record Audio" tab (sd.rec / sd.wait)
import tempfile
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from jiwer import wer
import time
import re

st.title("🎙️ Bahasa Rojak Transcriber (Whisper Fine-tuned)")

# Initialise session state so values survive Streamlit reruns.
if "audio_bytes" not in st.session_state:
    st.session_state.audio_bytes = None
if "audio_path" not in st.session_state:
    st.session_state.audio_path = None
if "ground_truth" not in st.session_state:
    st.session_state.ground_truth = ""
if "predicted_text" not in st.session_state:
    st.session_state.predicted_text = ""
if "wer_value" not in st.session_state:
    st.session_state.wer_value = None
if "selected_tab" not in st.session_state:
    st.session_state.selected_tab = "📁 Upload Audio"
if "previous_tab" not in st.session_state:
    st.session_state.previous_tab = "📁 Upload Audio"

tab1, tab2 = st.tabs(["📁 Upload Audio", "🎤 Record Audio"])

# Reset stale audio and results whenever the active tab changes.
# Note: st.tabs does not expose which tab is active, so selected_tab only
# changes if it is updated elsewhere in the app.
if st.session_state.selected_tab != st.session_state.previous_tab:
    st.session_state.audio_bytes = None
    st.session_state.audio_path = None
    st.session_state.ground_truth = ""
    st.session_state.predicted_text = ""
    st.session_state.wer_value = None
    st.session_state.previous_tab = st.session_state.selected_tab

with tab1:
    uploaded_file = st.file_uploader(
        "Upload an audio file (.wav, .mp3, .flac, .m4a, .ogg)",
        type=["wav", "mp3", "flac", "m4a", "ogg"],
    )
    if uploaded_file:
        try:
            st.session_state.audio_bytes = uploaded_file.read()
            with tempfile.NamedTemporaryFile(delete=False, suffix=f".{uploaded_file.name.split('.')[-1]}") as tmp:
                tmp.write(st.session_state.audio_bytes)
                st.session_state.audio_path = tmp.name

            # Decode once with librosa just to validate the file; the result is discarded.
            librosa.load(st.session_state.audio_path, sr=16000)
            st.audio(st.session_state.audio_bytes, format="audio/wav")
        except Exception as e:
            st.error(f"❌ Failed to read audio file: {e}")
            st.session_state.audio_bytes = None
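
# NamedTemporaryFile(delete=False) keeps the file on disk after the handle
# closes, so the path stays valid across Streamlit reruns; cleanup is left
# to the OS temp directory (or manual deletion).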

with tab2:
    duration = st.slider("Recording Duration (seconds)", 1, 60, 12)
    if st.button("Start Recording"):
        try:
            st.info(f"Recording for {duration} seconds... Please speak now.")
            recording = sd.rec(int(duration * 16000), samplerate=16000, channels=1, dtype="float32")
            sd.wait()  # block until recording finishes
            recording = np.squeeze(recording)

            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                torchaudio.save(tmp.name, torch.tensor(recording).unsqueeze(0), 16000)
                st.session_state.audio_path = tmp.name

            # Also keep the encoded bytes so the Transcribe button's
            # audio_bytes check passes for recordings, not just uploads.
            with open(st.session_state.audio_path, "rb") as f:
                st.session_state.audio_bytes = f.read()

            librosa.load(st.session_state.audio_path, sr=16000)  # validate the saved file
            st.audio(st.session_state.audio_bytes, format="audio/wav")
        except Exception as e:
            st.error(f"❌ Failed to process recorded audio: {e}")
            st.session_state.audio_bytes = None
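
# Note: sd.rec records from the default input device of the machine running
# Streamlit, so the Record tab only works when the server has a microphone
# (e.g. when the app runs locally).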

st.session_state.ground_truth = st.text_input(
    "Enter ground truth for WER calculation (Optional)",
    value=st.session_state.ground_truth,
    key="ground_truth_input"
)


@st.cache_resource
def load_model():
    # Cached by Streamlit so the model and processor load only once per process.
    model_repo_name = "wy0909/Whisper-MixedLanguageModel"
    model = WhisperForConditionalGeneration.from_pretrained(model_repo_name)
    processor = WhisperProcessor.from_pretrained(model_repo_name)
    # Clear forced decoder ids and suppressed tokens so the fine-tuned model
    # decodes freely instead of being pushed toward a fixed language/task.
    model.config.forced_decoder_ids = None
    model.generation_config.forced_decoder_ids = None
    model.config.suppress_tokens = []
    return model, processor


model, processor = load_model()
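
# A possible alternative (a sketch, not what this app does): instead of
# clearing forced_decoder_ids, pin the decoder to a language/task explicitly.
# "ms" (Malay) is an assumption here; use whichever tag fits the fine-tune:
#   forced_ids = processor.get_decoder_prompt_ids(language="ms", task="transcribe")
#   model.generation_config.forced_decoder_ids = forced_ids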


def capitalize_sentences(text):
    # Split on sentence-final punctuation followed by a space, then
    # capitalize each sentence (note: .capitalize() lowercases the rest).
    sentences = re.split(r"(?<=[.!?]) +", text)
    capitalized = [s.strip().capitalize() for s in sentences]
    return " ".join(capitalized)
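
# Example: capitalize_sentences("saya makan. then we go!")
# returns "Saya makan. Then we go!"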

mode = st.selectbox(
    "Transcription Mode:",
    options=["With Whisper", "With Whisper and WER"],
    help="Select transcription mode"
)
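
# WER = (substitutions + deletions + insertions) / words in the reference.
# e.g. jiwer.wer("makan nasi lemak", "makan nasi ayam") == 1/3 (one substitution).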

if st.button("📝 Transcribe"):
    if not st.session_state.audio_bytes:
        st.error("Please upload or record an audio file first.")
    else:
        start_time = time.time()
        try:
            waveform, sample_rate = torchaudio.load(st.session_state.audio_path)
            # Downmix to mono and resample to the 16 kHz input rate Whisper expects.
            waveform = waveform.mean(dim=0, keepdim=True)
            if sample_rate != 16000:
                waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
            inputs = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
            with torch.no_grad():
                predicted_ids = model.generate(inputs["input_features"])
            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            st.session_state.predicted_text = capitalize_sentences(transcription)
            st.markdown("### 🎤 Transcription Result")
            st.success(st.session_state.predicted_text)

            if mode == "With Whisper and WER" and st.session_state.ground_truth:
                # Lowercase both strings so the score ignores case differences.
                st.session_state.wer_value = wer(
                    st.session_state.ground_truth.lower(),
                    st.session_state.predicted_text.lower(),
                )
                st.markdown("### 🧮 Word Error Rate (WER)")
                st.write(f"WER: `{st.session_state.wer_value:.2f}`")

        except Exception as e:
            st.error(f"❌ Transcription failed: {e}")

        end_time = time.time()
        elapsed = end_time - start_time
        st.markdown(
            f"<div style='font-size: 10px; color: gray;'>Processed in {elapsed:.2f}s</div>",
            unsafe_allow_html=True,
        )
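
# To run locally (assuming this file is saved as app.py, a hypothetical name):
#   streamlit run app.py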