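"""Swedish speech-to-text Gradio app.

Streams a live transcription of an uploaded audio file with the
KBLab/kb-whisper-large model and offers the full transcript
(optionally with per-chunk timestamps) as a downloadable .txt file.
"""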
import os
import torch
import gradio as gr
import logging
import subprocess
from pydub import AudioSegment
from pydub.exceptions import CouldntDecodeError
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
from pathlib import Path
from tempfile import NamedTemporaryFile
from datetime import timedelta

# Setup logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Configuration
MODEL_ID = "KBLab/kb-whisper-large"
CHUNK_DURATION_MS = 10000  # 10-second chunks
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32  # fp16 only on GPU
SUPPORTED_FORMATS = {".wav", ".mp3", ".m4a"}

# Check for ffmpeg availability
def check_ffmpeg():
    try:
        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
        logger.info("ffmpeg is installed and accessible.")
        return True
    except (subprocess.CalledProcessError, FileNotFoundError):
        logger.error("ffmpeg is not installed or not found in PATH.")
        return False
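
# Design note: shutil.which("ffmpeg") would be a lighter existence check, but
# running `ffmpeg -version` also confirms the binary actually executes.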

# Initialize model and pipeline
def initialize_pipeline():
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        MODEL_ID,
        torch_dtype=TORCH_DTYPE,
        low_cpu_mem_usage=True
    ).to(DEVICE)
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    return pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        device=DEVICE,
        torch_dtype=TORCH_DTYPE
    )
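
# Note: the transformers ASR pipeline can do its own long-form chunking via
# chunk_length_s (e.g. pipeline(..., chunk_length_s=30)); manual pydub chunking
# is used below instead so partial text can be streamed to the UI per chunk.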

# Convert audio to WAV if needed
def convert_to_wav(audio_path: str) -> str:
    if not check_ffmpeg():
        raise RuntimeError("ffmpeg is required to decode audio files")
    ext = Path(audio_path).suffix.lower()
    if ext not in SUPPORTED_FORMATS:
        raise ValueError(f"Unsupported format: {ext}")
    if ext != ".wav":
        audio = AudioSegment.from_file(audio_path)  # may raise CouldntDecodeError
        wav_path = str(Path(audio_path).with_suffix(".converted.wav"))
        audio.export(wav_path, format="wav")
        return wav_path
    return audio_path

# Split audio into chunks
def split_audio(audio_path: str) -> list:
    audio = AudioSegment.from_file(audio_path)
    return [audio[i:i + CHUNK_DURATION_MS] for i in range(0, len(audio), CHUNK_DURATION_MS)]

# Helper to compute chunk start time
def get_chunk_time(index: int, chunk_duration_ms: int) -> str:
    start_ms = index * chunk_duration_ms
    return str(timedelta(milliseconds=start_ms))
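# Example: get_chunk_time(3, CHUNK_DURATION_MS) -> "0:00:30"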

# Transcribe audio with streaming output and a downloadable transcript
def transcribe(audio_path: str, include_timestamps: bool, progress=gr.Progress()):
    if not audio_path or not os.path.exists(audio_path):
        yield "Please upload a valid audio file.", None
        return
    try:
        wav_path = convert_to_wav(audio_path)
        chunks = split_audio(wav_path)
    except (RuntimeError, ValueError, CouldntDecodeError) as e:
        yield f"Error: {e}", None
        return
    transcript = []
    timestamped_transcript = []
    for i, chunk in enumerate(chunks):
        temp_file_path = None
        try:
            # Write the chunk to a temporary WAV file for the pipeline
            with NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
                temp_file_path = temp_file.name
                chunk.export(temp_file_path, format="wav")
            result = PIPELINE(
                temp_file_path,
                generate_kwargs={"task": "transcribe", "language": "sv"}
            )
            text = result["text"].strip()
            if text:
                transcript.append(text)
                if include_timestamps:
                    timestamp = get_chunk_time(i, CHUNK_DURATION_MS)
                    timestamped_transcript.append(f"[{timestamp}] {text}")
        finally:
            if temp_file_path and os.path.exists(temp_file_path):
                os.remove(temp_file_path)
        progress((i + 1) / len(chunks))
        yield " ".join(transcript), None  # stream text only while chunks finish

    # Create the downloadable file only once, after all chunks are processed
    content = (
        "\n".join(timestamped_transcript)
        if include_timestamps
        else " ".join(transcript)
    )
    with NamedTemporaryFile(
        suffix=".txt",
        delete=False,
        mode="w",
        encoding="utf-8"
    ) as f:
        f.write(content)
        download_path = f.name
    yield " ".join(transcript), download_path  # final yield attaches the download

# Initialize the pipeline once at import time so the model is loaded only once;
# transcribe() resolves this global at call time.
PIPELINE = initialize_pipeline()

# Gradio Interface
def create_interface():
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Swedish Whisper Transcriber")
        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(type="filepath", label="Upload Audio (.wav, .mp3, .m4a)")
                timestamp_toggle = gr.Checkbox(label="Include Timestamps in Download")
                transcribe_btn = gr.Button("Transcribe")
            with gr.Column():
                transcript_output = gr.Textbox(label="Live Transcription", lines=10)
                download_output = gr.File(label="Download Transcript")
        transcribe_btn.click(
            fn=transcribe,
            inputs=[audio_input, timestamp_toggle],
            outputs=[transcript_output, download_output]
        )
    return demo

if __name__ == "__main__":
    create_interface().launch()
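
# Local run notes (assumed setup; filenames and commands are illustrative, not
# taken from the original Space):
#   pip install torch transformers gradio pydub
#   ffmpeg must be installed and on PATH (on Spaces, typically via packages.txt)
#   python app.py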