"""Gradio app that streams Swedish transcriptions of uploaded audio using KBLab's Whisper model."""

import logging
import os
import subprocess
import sys
from datetime import timedelta
from pathlib import Path
from tempfile import NamedTemporaryFile

import gradio as gr
import torch
from pydub import AudioSegment
from pydub.exceptions import CouldntDecodeError
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

# Setup logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Configuration
MODEL_ID = "KBLab/kb-whisper-large"
CHUNK_DURATION_MS = 10000  # length of each audio window sent to the model (10 s)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only on GPU; CPU inference stays in float32.
TORCH_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
SUPPORTED_FORMATS = {".wav", ".mp3", ".m4a"}


def check_ffmpeg() -> bool:
    """Return True if the ``ffmpeg`` binary can be executed, logging the outcome."""
    try:
        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
        logger.info("ffmpeg is installed and accessible.")
        return True
    except (subprocess.CalledProcessError, FileNotFoundError):
        logger.error("ffmpeg is not installed or not found in PATH.")
        return False


def initialize_pipeline():
    """Load the Whisper model and processor and wrap them in an ASR pipeline.

    Returns:
        A transformers ``automatic-speech-recognition`` pipeline on ``DEVICE``.

    Raises:
        RuntimeError: if the model or processor cannot be loaded.
    """
    try:
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            MODEL_ID, torch_dtype=TORCH_DTYPE, low_cpu_mem_usage=True
        ).to(DEVICE)
        processor = AutoProcessor.from_pretrained(MODEL_ID)
        # The model is passed in already instantiated, so pipeline ``model_kwargs``
        # (forwarded only when the pipeline loads the model itself) would have no
        # effect; the former ``use_flash_attention_2`` flag was therefore dropped.
        return pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            device=DEVICE,
            torch_dtype=TORCH_DTYPE,
        )
    except Exception as e:
        logger.error(f"Failed to initialize pipeline: {str(e)}")
        raise RuntimeError(
            "Unable to load transcription model. Please check your network connection or model ID."
        ) from e


def convert_to_wav(audio_path: str) -> str:
    """Return a path to a WAV version of *audio_path*, converting if necessary.

    ``.wav`` input is returned unchanged. Other supported formats are decoded with
    pydub (which relies on ffmpeg) and written next to the input as
    ``<stem>.converted.wav``; the caller is responsible for deleting that file.

    Raises:
        ValueError: for an unsupported extension, missing ffmpeg, or any
            decode / filesystem failure during conversion.
    """
    # Validate before the try-block so these messages are not re-wrapped as
    # "unexpected" errors by the broad handler below.
    ext = Path(audio_path).suffix.lower()
    if ext not in SUPPORTED_FORMATS:
        raise ValueError(f"Unsupported audio format: {ext}. Supported formats: {', '.join(SUPPORTED_FORMATS)}")
    if ext == ".wav":
        return audio_path

    # Only actual conversions need ffmpeg; don't spawn a subprocess for WAV input.
    if not check_ffmpeg():
        raise ValueError(
            "ffmpeg is required to process .m4a files. Please install ffmpeg and ensure it's in your PATH."
        )

    logger.info(f"Converting {ext} file to WAV: {audio_path}")
    try:
        audio = AudioSegment.from_file(audio_path)
        wav_path = str(Path(audio_path).with_suffix(".converted.wav"))
        # export() returns the file object it wrote; close it right away.
        audio.export(wav_path, format="wav").close()
    except CouldntDecodeError as e:
        logger.error(f"Failed to decode {ext} file: {audio_path}")
        raise ValueError(
            f"The {ext} file is corrupted or not supported. "
            "Ensure it's a valid iPhone recording and ffmpeg is installed."
        ) from e
    except OSError as e:
        logger.error(f"OS error during audio conversion: {str(e)}")
        raise ValueError(
            f"Failed to process the {ext} file due to a system error. Check file permissions or disk space."
        ) from e
    except Exception as e:
        logger.error(f"Unexpected error during {ext} conversion: {str(e)}")
        raise ValueError(f"An unexpected error occurred while converting the {ext} file: {str(e)}") from e

    logger.info(f"Conversion successful: {wav_path}")
    return wav_path


def split_audio(audio_path: str) -> list:
    """Decode *audio_path* and cut it into consecutive ``CHUNK_DURATION_MS`` segments.

    The last segment may be shorter than ``CHUNK_DURATION_MS``.

    Raises:
        ValueError: if the audio cannot be decoded or has zero length.
    """
    try:
        audio = AudioSegment.from_file(audio_path)
    except CouldntDecodeError as e:
        logger.error(f"Failed to decode audio for splitting: {audio_path}")
        raise ValueError("The .m4a file is corrupted or not supported. Ensure it's a valid iPhone recording.") from e
    except Exception as e:
        logger.error(f"Failed to split audio: {str(e)}")
        raise ValueError(f"Failed to process the .m4a file: {str(e)}") from e

    # Checked outside the try-block so the message is not double-wrapped.
    if len(audio) == 0:
        raise ValueError("The .m4a file is empty or invalid.")

    logger.info(f"Splitting audio into {CHUNK_DURATION_MS/1000}-second chunks: {audio_path}")
    return [audio[i:i + CHUNK_DURATION_MS] for i in range(0, len(audio), CHUNK_DURATION_MS)]


def get_chunk_time(index: int, chunk_duration_ms: int) -> str:
    """Return the start offset of chunk *index* formatted by ``str(timedelta)``."""
    start_ms = index * chunk_duration_ms
    return str(timedelta(milliseconds=start_ms))


def _transcribe_chunk(chunk) -> str:
    """Write one audio chunk to a temporary WAV, run ``PIPELINE`` on it, and return the stripped text.

    The temporary file is always removed; a failed removal is only logged.
    Exceptions from exporting or transcription propagate to the caller.
    """
    # Create and immediately close the temp file so pydub and the pipeline can
    # open it by name without a second handle being held open.
    with NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
        temp_file_path = temp_file.name
    try:
        chunk.export(temp_file_path, format="wav").close()
        result = PIPELINE(temp_file_path, generate_kwargs={"task": "transcribe", "language": "sv"})
        return result["text"].strip()
    finally:
        if os.path.exists(temp_file_path):
            try:
                os.remove(temp_file_path)
            except OSError as e:
                logger.warning(f"Failed to delete temporary file {temp_file_path}: {str(e)}")


def transcribe(audio_path: str, include_timestamps: bool = False, progress=gr.Progress()):
    """Stream a Swedish transcription of *audio_path* to the Gradio UI.

    This is a generator: Gradio applies every ``yield`` as an update of the
    ``(transcript_text, download_file_path)`` outputs. The running transcript is
    yielded after each chunk. The final yield carries the full transcript,
    prefixed with a warning if any chunk failed, plus the path of a downloadable
    ``.txt`` file. That file holds the timestamped lines when
    *include_timestamps* is true, otherwise the plain transcript. User-facing
    errors are yielded as ``(message, None)`` rather than raised.

    Args:
        audio_path: Path of the uploaded file (``.wav``, ``.mp3`` or ``.m4a``).
        include_timestamps: Whether the download should contain per-chunk start times.
        progress: Gradio progress tracker, updated once per chunk.
    """
    wav_path = None
    try:
        if not audio_path or not os.path.exists(audio_path):
            logger.warning("Invalid or missing audio file path.")
            # A generator hands values to Gradio only via ``yield``; ``return value``
            # would be silently discarded.
            yield "Please upload a valid .m4a file.", None
            return

        wav_path = convert_to_wav(audio_path)
        chunks = split_audio(wav_path)
        total_chunks = len(chunks)  # > 0: split_audio rejects empty audio
        transcript = []
        timestamped_transcript = []
        failed_chunks = 0

        for i, chunk in enumerate(chunks):
            timestamp = get_chunk_time(i, CHUNK_DURATION_MS)
            text = None  # stays None when this chunk fails
            try:
                text = _transcribe_chunk(chunk)
            except RuntimeError as e:
                logger.warning(f"Failed to transcribe chunk {i+1}/{total_chunks}: {str(e)}")
            except Exception as e:
                logger.error(f"Unexpected error in chunk {i+1}/{total_chunks}: {str(e)}")

            # A failed chunk leaves a visible placeholder; one failure does not abort the rest.
            if text is None:
                failed_chunks += 1
                transcript.append("[Transcription failed for this segment]")
                if include_timestamps:
                    timestamped_transcript.append(f"[{timestamp}] [Transcription failed]")
            elif text:  # silent chunks produce no output line
                transcript.append(text)
                if include_timestamps:
                    timestamped_transcript.append(f"[{timestamp}] {text}")

            progress((i + 1) / total_chunks)
            yield " ".join(transcript), None

        # Prepare final transcript and downloadable file
        final_transcript = " ".join(transcript)
        if failed_chunks > 0:
            final_transcript = f"Warning: {failed_chunks}/{total_chunks} chunks failed to transcribe.\n{final_transcript}"

        download_content = "\n".join(timestamped_transcript) if include_timestamps else final_transcript
        download_path = None
        try:
            # delete=False: the file must outlive this call so Gradio can serve it.
            with NamedTemporaryFile(suffix=".txt", delete=False, mode="w", encoding="utf-8") as temp_file:
                temp_file.write(download_content)
                download_path = temp_file.name
        except OSError as e:
            logger.error(f"Failed to create downloadable transcript: {str(e)}")
            final_transcript = f"{final_transcript}\nNote: Could not generate downloadable transcript due to a file error."

        yield final_transcript, download_path
    except ValueError as e:
        logger.error(f"Value error during transcription: {str(e)}")
        yield str(e), None
    except Exception as e:
        logger.exception(f"Unexpected error during transcription: {str(e)}")
        yield (
            f"An unexpected error occurred while processing the .m4a file: {str(e)}. "
            "Please ensure the file is a valid iPhone recording and try again.",
            None,
        )
    finally:
        # Remove the intermediate WAV on every exit path, not only on success.
        if wav_path is not None and wav_path != audio_path and os.path.exists(wav_path):
            try:
                os.remove(wav_path)
            except OSError as e:
                logger.warning(f"Failed to delete converted WAV file {wav_path}: {str(e)}")


# Load the model once at import time so every request reuses the same pipeline.
try:
    PIPELINE = initialize_pipeline()
except RuntimeError as e:
    logger.critical(f"Pipeline initialization failed: {str(e)}")
    raise


def create_interface():
    """Build the Gradio Blocks UI wiring the upload/checkbox inputs to ``transcribe``."""
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Swedish Whisper Transcriber")
        gr.Markdown("Upload an .m4a file from your iPhone for real-time Swedish speech transcription.")

        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(type="filepath", label="Upload .m4a Audio")
                timestamp_toggle = gr.Checkbox(label="Include Timestamps in Download", value=False)
                transcribe_btn = gr.Button("Transcribe")
            with gr.Column():
                transcript_output = gr.Textbox(label="Live Transcription", lines=10)
                download_output = gr.File(label="Download Transcript")

        transcribe_btn.click(
            fn=transcribe,
            inputs=[audio_input, timestamp_toggle],
            outputs=[transcript_output, download_output],
        )

    return demo


if __name__ == "__main__":
    try:
        if not check_ffmpeg():
            print("Error: ffmpeg is required to process .m4a files. Please install ffmpeg and ensure it's in your PATH.")
            sys.exit(1)
        create_interface().launch()
    except Exception as e:
        logger.critical(f"Failed to launch Gradio interface: {str(e)}")
        print("Error: Could not start the application. Please check the logs for details.")
        # Signal failure to the shell instead of exiting with status 0.
        sys.exit(1)