# Page-scrape residue (not part of the program): "Spaces: Running / Running"
import os | |
import torch | |
import gradio as gr | |
import logging | |
import subprocess | |
from pydub import AudioSegment | |
from pydub.exceptions import CouldntDecodeError | |
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor | |
from pathlib import Path | |
from tempfile import NamedTemporaryFile | |
from datetime import timedelta | |
# Logging: timestamped, level-tagged messages at INFO and above.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Configuration
MODEL_ID = "KBLab/kb-whisper-large"  # identifier handed to from_pretrained()
CHUNK_DURATION_MS = 10000  # length of each transcription chunk, in milliseconds
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only selected when running on CUDA.
TORCH_DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
SUPPORTED_FORMATS = {".wav", ".mp3", ".m4a"}  # lower-case file suffixes
# Check for ffmpeg availability | |
def check_ffmpeg():
    """Return True if ``ffmpeg -version`` runs successfully, else False.

    The outcome is logged either way. A missing executable
    (``FileNotFoundError``) and a non-zero exit status
    (``CalledProcessError``) are both reported as "not available".
    """
    probe_command = ["ffmpeg", "-version"]
    try:
        subprocess.run(probe_command, capture_output=True, check=True)
    except (subprocess.CalledProcessError, FileNotFoundError):
        logger.error("ffmpeg is not installed or not found in PATH.")
        return False
    logger.info("ffmpeg is installed and accessible.")
    return True
# Initialize model and pipeline | |
def initialize_pipeline():
    """Load the speech model and processor and wrap them in an ASR pipeline.

    Returns:
        The object returned by ``transformers.pipeline`` for the
        "automatic-speech-recognition" task.

    Raises:
        RuntimeError: If any step of loading the model, the processor or the
            pipeline fails; the underlying error is logged first.
    """
    try:
        asr_model = AutoModelForSpeechSeq2Seq.from_pretrained(
            MODEL_ID,
            torch_dtype=TORCH_DTYPE,
            low_cpu_mem_usage=True,
        )
        asr_model = asr_model.to(DEVICE)
        asr_processor = AutoProcessor.from_pretrained(MODEL_ID)

        pipeline_options = {
            "model": asr_model,
            "tokenizer": asr_processor.tokenizer,
            "feature_extractor": asr_processor.feature_extractor,
            "device": DEVICE,
            "torch_dtype": TORCH_DTYPE,
            "model_kwargs": {"use_flash_attention_2": torch.cuda.is_available()},
        }
        return pipeline("automatic-speech-recognition", **pipeline_options)
    except Exception as load_error:
        logger.error(f"Failed to initialize pipeline: {load_error}")
        raise RuntimeError("Unable to load transcription model. Please check your network connection or model ID.")
# Convert audio if needed | |
def convert_to_wav(audio_path: str) -> str:
    """Return a path to a WAV version of *audio_path*, converting if needed.

    Args:
        audio_path: Path to an audio file whose suffix is one of
            ``SUPPORTED_FORMATS`` (compared case-insensitively).

    Returns:
        ``audio_path`` unchanged when it is already a ``.wav`` file;
        otherwise the path of a newly written ``<stem>.converted.wav`` file
        next to the input. The caller is responsible for deleting it.

    Raises:
        ValueError: If the format is unsupported, ffmpeg is unavailable, the
            file cannot be decoded, or the conversion fails for any other
            reason. Only ``ValueError`` is raised, so callers need a single
            handler.
    """
    # Input validation happens outside the try block so these deliberate
    # errors reach the caller with their own message instead of being
    # re-wrapped as "An unexpected error occurred ...".
    ext = Path(audio_path).suffix.lower()
    if ext not in SUPPORTED_FORMATS:
        # Sorted so the message is stable; set iteration order is not.
        supported = ", ".join(sorted(SUPPORTED_FORMATS))
        logger.error(f"Unsupported audio format: {ext}")
        raise ValueError(f"Unsupported audio format: {ext}. Supported formats: {supported}")

    if not check_ffmpeg():
        raise ValueError("ffmpeg is required to process .m4a files. Please install ffmpeg and ensure it's in your PATH.")

    if ext == ".wav":
        return audio_path

    try:
        logger.info(f"Converting {ext} file to WAV: {audio_path}")
        audio = AudioSegment.from_file(audio_path)
        wav_path = str(Path(audio_path).with_suffix(".converted.wav"))
        audio.export(wav_path, format="wav")
    except CouldntDecodeError as e:
        logger.error(f"Failed to decode .m4a file: {audio_path}")
        raise ValueError("The .m4a file is corrupted or not supported. Ensure it's a valid iPhone recording and ffmpeg is installed.") from e
    except OSError as e:
        logger.error(f"OS error during audio conversion: {str(e)}")
        raise ValueError("Failed to process the .m4a file due to a system error. Check file permissions or disk space.") from e
    except Exception as e:
        logger.error(f"Unexpected error during .m4a conversion: {str(e)}")
        raise ValueError(f"An unexpected error occurred while converting the .m4a file: {str(e)}") from e

    logger.info(f"Conversion successful: {wav_path}")
    return wav_path
# Split audio into chunks | |
def split_audio(audio_path: str) -> list:
    """Load an audio file and cut it into consecutive fixed-length chunks.

    Args:
        audio_path: Path to an audio file readable by
            ``AudioSegment.from_file``.

    Returns:
        A list of audio segments, each ``CHUNK_DURATION_MS`` long except
        possibly the last, covering the file from start to end.

    Raises:
        ValueError: If the file cannot be decoded, cannot be loaded for any
            other reason, or contains no audio.
    """
    # Only the decode step is guarded, so the "empty file" error raised
    # below reaches the caller with its own message rather than being
    # re-wrapped by the generic handler.
    try:
        audio = AudioSegment.from_file(audio_path)
    except CouldntDecodeError as e:
        logger.error(f"Failed to decode audio for splitting: {audio_path}")
        raise ValueError("The .m4a file is corrupted or not supported. Ensure it's a valid iPhone recording.") from e
    except Exception as e:
        logger.error(f"Failed to split audio: {str(e)}")
        raise ValueError(f"Failed to process the .m4a file: {str(e)}") from e

    total_ms = len(audio)  # len() of a pydub segment is its duration in ms
    if total_ms == 0:
        logger.error(f"Audio file contains no audio: {audio_path}")
        raise ValueError("The .m4a file is empty or invalid.")

    logger.info(f"Splitting audio into {CHUNK_DURATION_MS/1000}-second chunks: {audio_path}")
    return [
        audio[start:start + CHUNK_DURATION_MS]
        for start in range(0, total_ms, CHUNK_DURATION_MS)
    ]
# Helper to compute chunk start time | |
def get_chunk_time(index: int, chunk_duration_ms: int) -> str:
    """Return the start offset of chunk *index* as ``str(timedelta)`` text.

    The offset is ``index * chunk_duration_ms`` milliseconds, rendered by
    ``timedelta`` as ``H:MM:SS`` (with a fractional part when the offset is
    not a whole number of seconds).
    """
    offset = timedelta(milliseconds=index * chunk_duration_ms)
    return str(offset)
# Transcribe audio with progress and timestamps | |
def transcribe(audio_path: str, include_timestamps: bool = False, progress=gr.Progress()):
    """Transcribe an audio file chunk by chunk, streaming partial results.

    This is a generator function, so every result -- partial transcripts,
    error messages and the final transcript -- is delivered with ``yield``.
    A value given to ``return`` inside a generator only ends up on
    ``StopIteration`` and is never produced as an item, so it would not
    reach whoever iterates the generator.

    Args:
        audio_path: Path to the uploaded audio file.
        include_timestamps: When True, the downloadable file holds one line
            per chunk prefixed with that chunk's start time; otherwise it
            holds the same text as the final transcript.
        progress: Callable invoked with the completed fraction (0..1) after
            each chunk.

    Yields:
        ``(text, download_path)`` tuples. While chunks are being processed
        ``text`` is the transcript so far and ``download_path`` is None. The
        last item carries the final transcript (prefixed with a warning if
        any chunk failed) and the path to the downloadable ``.txt`` file, or
        an error message and None if processing failed.
    """
    wav_path = None
    try:
        if not audio_path or not os.path.exists(audio_path):
            logger.warning("Invalid or missing audio file path.")
            yield "Please upload a valid .m4a file.", None
            return

        # Convert to WAV if needed, then split into fixed-length chunks.
        wav_path = convert_to_wav(audio_path)
        chunks = split_audio(wav_path)
        total_chunks = len(chunks)

        transcript = []
        timestamped_transcript = []
        failed_chunks = 0

        for i, chunk in enumerate(chunks):
            timestamp = get_chunk_time(i, CHUNK_DURATION_MS)
            temp_file_path = None
            text = None  # stays None if this chunk fails; "" means no speech
            try:
                # Reserve a temp file name, then close the handle before
                # writing to it so the export does not compete with an open
                # descriptor on the same file.
                with NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
                    temp_file_path = temp_file.name
                chunk.export(temp_file_path, format="wav")
                result = PIPELINE(
                    temp_file_path,
                    generate_kwargs={"task": "transcribe", "language": "sv"},
                )
                text = result["text"].strip()
            except RuntimeError as e:
                logger.warning(f"Failed to transcribe chunk {i+1}/{total_chunks}: {str(e)}")
            except Exception as e:
                # One bad chunk must not abort the whole transcription.
                logger.error(f"Unexpected error in chunk {i+1}/{total_chunks}: {str(e)}")
            finally:
                if temp_file_path and os.path.exists(temp_file_path):
                    try:
                        os.remove(temp_file_path)
                    except OSError as e:
                        logger.warning(f"Failed to delete temporary file {temp_file_path}: {str(e)}")

            if text is None:
                failed_chunks += 1
                transcript.append("[Transcription failed for this segment]")
                timestamped_transcript.append(f"[{timestamp}] [Transcription failed]")
            elif text:
                transcript.append(text)
                timestamped_transcript.append(f"[{timestamp}] {text}")

            progress((i + 1) / total_chunks)
            yield " ".join(transcript), None

        # Prepare final transcript and downloadable file.
        final_transcript = " ".join(transcript)
        if failed_chunks > 0:
            final_transcript = f"Warning: {failed_chunks}/{total_chunks} chunks failed to transcribe.\n{final_transcript}"
        download_content = "\n".join(timestamped_transcript) if include_timestamps else final_transcript

        download_path = None
        try:
            with NamedTemporaryFile(suffix=".txt", delete=False, mode="w", encoding="utf-8") as temp_file:
                temp_file.write(download_content)
                download_path = temp_file.name
        except OSError as e:
            logger.error(f"Failed to create downloadable transcript: {str(e)}")
            final_transcript = f"{final_transcript}\nNote: Could not generate downloadable transcript due to a file error."

        yield final_transcript, download_path
    except ValueError as e:
        logger.error(f"Value error during transcription: {str(e)}")
        yield str(e), None
    except Exception as e:
        logger.error(f"Unexpected error during transcription: {str(e)}")
        yield f"An unexpected error occurred while processing the .m4a file: {str(e)}. Please ensure the file is a valid iPhone recording and try again.", None
    finally:
        # Remove the converted WAV on every exit path (success, error, or
        # the consumer abandoning the generator), not only on success.
        if wav_path and wav_path != audio_path and os.path.exists(wav_path):
            try:
                os.remove(wav_path)
            except OSError as e:
                logger.warning(f"Failed to delete converted WAV file {wav_path}: {str(e)}")
# Initialize pipeline globally | |
# Build the shared pipeline once at import time; a failure here is fatal,
# so it is logged at CRITICAL level and re-raised unchanged.
try:
    PIPELINE = initialize_pipeline()
except RuntimeError as init_error:
    logger.critical(f"Pipeline initialization failed: {init_error}")
    raise
# Gradio Interface with Blocks | |
def create_interface():
    """Build and return the Gradio Blocks UI for the transcriber.

    Layout: a left column with the audio upload, the timestamp checkbox and
    the "Transcribe" button; a right column with the transcript textbox and
    the download file component. Clicking the button calls ``transcribe``
    with the audio path and checkbox value and routes its two outputs to
    the textbox and file component.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as app:
        gr.Markdown("# Swedish Whisper Transcriber")
        gr.Markdown("Upload an .m4a file from your iPhone for real-time Swedish speech transcription.")

        with gr.Row():
            with gr.Column():
                audio_box = gr.Audio(type="filepath", label="Upload .m4a Audio")
                timestamps_box = gr.Checkbox(label="Include Timestamps in Download", value=False)
                run_button = gr.Button("Transcribe")
            with gr.Column():
                transcript_box = gr.Textbox(label="Live Transcription", lines=10)
                download_box = gr.File(label="Download Transcript")

        run_button.click(
            fn=transcribe,
            inputs=[audio_box, timestamps_box],
            outputs=[transcript_box, download_box],
        )
    return app
if __name__ == "__main__":
    # Script entry point: verify ffmpeg, then launch the UI. Both failure
    # paths exit with status 1 so shells and supervisors can detect them.
    # SystemExit is raised directly instead of calling exit(), which is
    # injected by the site module and is not always available.
    try:
        if not check_ffmpeg():
            print("Error: ffmpeg is required to process .m4a files. Please install ffmpeg and ensure it's in your PATH.")
            raise SystemExit(1)
        create_interface().launch()
    except Exception as e:  # SystemExit is not an Exception, so it passes through
        # exc_info=True records the traceback, since the message below sends
        # the user to the logs for details.
        logger.critical(f"Failed to launch Gradio interface: {str(e)}", exc_info=True)
        print("Error: Could not start the application. Please check the logs for details.")
        raise SystemExit(1) from e