File size: 10,322 Bytes
db55266
 
 
37f7d1f
f64cacf
db55266
37f7d1f
be25d7c
 
 
 
db55266
37f7d1f
 
 
 
be25d7c
 
 
 
 
37f7d1f
0b63b29
f64cacf
 
 
 
 
 
 
 
 
 
be25d7c
 
37f7d1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db55266
be25d7c
 
37f7d1f
f64cacf
 
37f7d1f
 
 
 
f64cacf
37f7d1f
 
 
f64cacf
37f7d1f
 
 
f64cacf
 
37f7d1f
 
f64cacf
37f7d1f
f64cacf
 
0b63b29
be25d7c
 
 
 
 
f64cacf
 
37f7d1f
 
 
f64cacf
be25d7c
37f7d1f
f64cacf
0b63b29
be25d7c
 
 
 
0b63b29
be25d7c
 
 
37f7d1f
 
f64cacf
37f7d1f
be25d7c
 
 
 
 
 
 
 
37f7d1f
be25d7c
 
37f7d1f
be25d7c
 
37f7d1f
be25d7c
 
 
 
f64cacf
37f7d1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be25d7c
37f7d1f
 
 
 
 
 
be25d7c
 
37f7d1f
be25d7c
 
37f7d1f
 
 
 
 
be25d7c
 
37f7d1f
 
 
be25d7c
37f7d1f
 
 
 
 
 
 
 
 
be25d7c
 
37f7d1f
 
 
be25d7c
37f7d1f
f64cacf
0b63b29
be25d7c
37f7d1f
 
 
 
 
0b63b29
be25d7c
 
 
 
f64cacf
be25d7c
 
 
f64cacf
be25d7c
 
 
 
 
 
 
 
 
 
 
 
 
 
db55266
be25d7c
37f7d1f
f64cacf
 
 
37f7d1f
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
import os
import torch
import gradio as gr
import logging
import subprocess
from pydub import AudioSegment
from pydub.exceptions import CouldntDecodeError
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
from pathlib import Path
from tempfile import NamedTemporaryFile
from datetime import timedelta

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Configuration
MODEL_ID = "KBLab/kb-whisper-large"  # model id passed to from_pretrained()
CHUNK_DURATION_MS = 10_000  # length of each transcription chunk, in milliseconds
SUPPORTED_FORMATS = {".wav", ".mp3", ".m4a"}  # lower-case extensions accepted by convert_to_wav()

# GPU + half precision when CUDA is available, otherwise CPU + full precision.
if torch.cuda.is_available():
    DEVICE = "cuda"
    TORCH_DTYPE = torch.float16
else:
    DEVICE = "cpu"
    TORCH_DTYPE = torch.float32

# Check for ffmpeg availability
def check_ffmpeg():
    """Return True if ``ffmpeg -version`` runs successfully, otherwise False.

    The outcome is logged. This never raises for a missing, non-executable,
    failing or hung binary; all of those yield False.
    """
    try:
        # The timeout keeps a wedged binary from blocking start-up or a request forever.
        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True, timeout=10)
        logger.info("ffmpeg is installed and accessible.")
        return True
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired, OSError):
        # OSError covers FileNotFoundError (not on PATH) and PermissionError
        # (present but not executable); previously the latter escaped.
        logger.error("ffmpeg is not installed or not found in PATH.")
        return False

# Initialize model and pipeline
def initialize_pipeline():
    """Load MODEL_ID and wrap it in an "automatic-speech-recognition" pipeline.

    The model is loaded with TORCH_DTYPE and moved to DEVICE; the tokenizer and
    feature extractor come from the matching AutoProcessor.

    Returns:
        The transformers pipeline object.

    Raises:
        RuntimeError: if anything fails while loading. The original exception
            is chained as ``__cause__`` and its traceback is logged.
    """
    try:
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            MODEL_ID,
            torch_dtype=TORCH_DTYPE,
            low_cpu_mem_usage=True,
        ).to(DEVICE)
        processor = AutoProcessor.from_pretrained(MODEL_ID)
        # NOTE: `model` is already instantiated, so pipeline(model_kwargs=...)
        # would be ignored (it only applies when pipeline() loads the model
        # itself). An attention implementation such as flash attention must be
        # requested via from_pretrained(attn_implementation=...) if desired.
        return pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            device=DEVICE,
            torch_dtype=TORCH_DTYPE,
        )
    except Exception as e:
        logger.exception(f"Failed to initialize pipeline: {str(e)}")
        raise RuntimeError("Unable to load transcription model. Please check your network connection or model ID.") from e

# Convert audio if needed
def convert_to_wav(audio_path: str) -> str:
    """Return a path to a WAV version of *audio_path*, converting when needed.

    ``.wav`` input is returned unchanged. Any other extension in
    SUPPORTED_FORMATS is decoded with pydub and exported next to the input as
    ``<name>.converted.wav``; the caller is responsible for deleting that file.

    Raises:
        ValueError: for an unsupported extension, missing ffmpeg, an
            undecodable file, or any OS/unexpected error during conversion.
            The underlying exception, if any, is chained as ``__cause__``.
    """
    ext = Path(audio_path).suffix.lower()
    # Validation is done outside the try block so the generic `except Exception`
    # below cannot re-wrap these messages as "unexpected" conversion errors.
    if ext not in SUPPORTED_FORMATS:
        raise ValueError(f"Unsupported audio format: {ext}. Supported formats: {', '.join(sorted(SUPPORTED_FORMATS))}")
    if ext == ".wav":
        # Already WAV: nothing to convert, so ffmpeg is not checked here.
        return audio_path
    if not check_ffmpeg():
        raise ValueError("ffmpeg is required to process .m4a files. Please install ffmpeg and ensure it's in your PATH.")

    try:
        logger.info(f"Converting {ext} file to WAV: {audio_path}")
        audio = AudioSegment.from_file(audio_path)
        wav_path = str(Path(audio_path).with_suffix(".converted.wav"))
        # export() returns the open output file object; close it explicitly
        # instead of relying on garbage collection.
        audio.export(wav_path, format="wav").close()
        logger.info(f"Conversion successful: {wav_path}")
        return wav_path
    except CouldntDecodeError as e:
        logger.error(f"Failed to decode .m4a file: {audio_path}")
        raise ValueError("The .m4a file is corrupted or not supported. Ensure it's a valid iPhone recording and ffmpeg is installed.") from e
    except OSError as e:
        logger.error(f"OS error during audio conversion: {str(e)}")
        raise ValueError("Failed to process the .m4a file due to a system error. Check file permissions or disk space.") from e
    except Exception as e:
        logger.error(f"Unexpected error during .m4a conversion: {str(e)}")
        raise ValueError(f"An unexpected error occurred while converting the .m4a file: {str(e)}") from e

# Split audio into chunks
def split_audio(audio_path: str, chunk_duration_ms: int = CHUNK_DURATION_MS) -> list:
    """Decode *audio_path* with pydub and cut it into consecutive segments.

    Args:
        audio_path: Path of the audio file to decode.
        chunk_duration_ms: Length of each segment in milliseconds (default
            CHUNK_DURATION_MS). The final segment holds whatever remains.

    Returns:
        A list of pydub AudioSegment slices, in order.

    Raises:
        ValueError: if *chunk_duration_ms* is not positive, or if the file is
            empty, undecodable or cannot be read.
    """
    if chunk_duration_ms <= 0:
        raise ValueError(f"chunk_duration_ms must be positive, got {chunk_duration_ms}")

    try:
        audio = AudioSegment.from_file(audio_path)
    except CouldntDecodeError as e:
        logger.error(f"Failed to decode audio for splitting: {audio_path}")
        raise ValueError("The .m4a file is corrupted or not supported. Ensure it's a valid iPhone recording.") from e
    except Exception as e:
        logger.error(f"Failed to split audio: {str(e)}")
        raise ValueError(f"Failed to process the .m4a file: {str(e)}") from e

    # Checked outside the try block so this message is not re-wrapped by the
    # generic handler above.
    if len(audio) == 0:
        raise ValueError("The .m4a file is empty or invalid.")
    logger.info(f"Splitting audio into {chunk_duration_ms/1000}-second chunks: {audio_path}")
    return [audio[start:start + chunk_duration_ms] for start in range(0, len(audio), chunk_duration_ms)]

# Helper to compute chunk start time
def get_chunk_time(index: int, chunk_duration_ms: int) -> str:
    """Format the start offset of chunk number *index* (0-based) as ``H:MM:SS``.

    The offset is ``index * chunk_duration_ms`` milliseconds, rendered with
    ``str(timedelta)``, so sub-second offsets gain a ``.ffffff`` suffix.
    """
    return str(timedelta(milliseconds=chunk_duration_ms * index))

# Transcribe audio with progress and timestamps
def _transcribe_chunk(chunk) -> str:
    """Run the global PIPELINE on one pydub segment and return its stripped text.

    The segment is written to a temporary WAV file, which is always removed
    afterwards. Export and inference errors propagate to the caller.
    """
    # Create and close the temp file first so pydub reopens a closed path;
    # reopening a still-open NamedTemporaryFile is platform-dependent.
    with NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
        temp_file_path = temp_file.name
    try:
        # export() returns the open output file object; close it explicitly.
        chunk.export(temp_file_path, format="wav").close()
        result = PIPELINE(temp_file_path,
                          generate_kwargs={"task": "transcribe", "language": "sv"})
        return result["text"].strip()
    finally:
        if os.path.exists(temp_file_path):
            try:
                os.remove(temp_file_path)
            except OSError as e:
                logger.warning(f"Failed to delete temporary file {temp_file_path}: {str(e)}")


def _write_download(content: str):
    """Write *content* to a UTF-8 ``.txt`` temp file and return its path, or None on OSError."""
    try:
        with NamedTemporaryFile(suffix=".txt", delete=False, mode='w', encoding='utf-8') as temp_file:
            temp_file.write(content)
            return temp_file.name
    except OSError as e:
        logger.error(f"Failed to create downloadable transcript: {str(e)}")
        return None


def transcribe(audio_path: str, include_timestamps: bool = False, progress=gr.Progress()):
    """Stream a Swedish transcription of *audio_path* to the Gradio UI.

    This is a generator that yields ``(transcript_text, download_path)`` tuples:

    * after each chunk, the running transcript with ``None`` for the download;
    * finally, the full transcript (prefixed with a warning if any chunk
      failed) and the path of a ``.txt`` download, or ``None`` if that file
      could not be written.

    Errors are reported by yielding ``(message, None)`` rather than raising.

    Args:
        audio_path: Path of the uploaded audio file.
        include_timestamps: If true, the download lists each chunk prefixed by
            its start time instead of the plain joined transcript.
        progress: Gradio progress tracker, updated after every chunk.
    """
    # This function contains `yield`, so any `return value` would be discarded
    # by Gradio; every result, including errors, must be yielded.
    if not audio_path or not os.path.exists(audio_path):
        logger.warning("Invalid or missing audio file path.")
        yield "Please upload a valid .m4a file.", None
        return

    wav_path = None
    try:
        # Convert to WAV if needed, then split and process
        wav_path = convert_to_wav(audio_path)
        chunks = split_audio(wav_path)
        total_chunks = len(chunks)
        transcript = []
        timestamped_transcript = []
        failed_chunks = 0

        for i, chunk in enumerate(chunks):
            try:
                text = _transcribe_chunk(chunk)
            except RuntimeError as e:
                logger.warning(f"Failed to transcribe chunk {i+1}/{total_chunks}: {str(e)}")
                text = None
            except Exception as e:
                logger.error(f"Unexpected error in chunk {i+1}/{total_chunks}: {str(e)}")
                text = None

            if text is None:
                failed_chunks += 1
                transcript.append("[Transcription failed for this segment]")
                if include_timestamps:
                    timestamp = get_chunk_time(i, CHUNK_DURATION_MS)
                    timestamped_transcript.append(f"[{timestamp}] [Transcription failed]")
            elif text:
                # Chunks that produce empty text (e.g. silence) are skipped.
                transcript.append(text)
                if include_timestamps:
                    timestamp = get_chunk_time(i, CHUNK_DURATION_MS)
                    timestamped_transcript.append(f"[{timestamp}] {text}")

            progress((i + 1) / total_chunks)
            yield " ".join(transcript), None

        # Prepare final transcript and downloadable file
        final_transcript = " ".join(transcript)
        if failed_chunks > 0:
            final_transcript = f"Warning: {failed_chunks}/{total_chunks} chunks failed to transcribe.\n{final_transcript}"

        download_content = "\n".join(timestamped_transcript) if include_timestamps else final_transcript
        download_path = _write_download(download_content)
        if download_path is None:
            final_transcript = f"{final_transcript}\nNote: Could not generate downloadable transcript due to a file error."

        yield final_transcript, download_path

    except ValueError as e:
        logger.error(f"Value error during transcription: {str(e)}")
        yield str(e), None
    except Exception as e:
        logger.error(f"Unexpected error during transcription: {str(e)}")
        yield f"An unexpected error occurred while processing the .m4a file: {str(e)}. Please ensure the file is a valid iPhone recording and try again.", None
    finally:
        # Remove the converted WAV on every exit path, including errors and an
        # abandoned stream (generator close); previously it leaked on errors.
        if wav_path and wav_path != audio_path and os.path.exists(wav_path):
            try:
                os.remove(wav_path)
            except OSError as e:
                logger.warning(f"Failed to delete converted WAV file {wav_path}: {str(e)}")

# Initialize pipeline globally
# Built once at import time and shared by all transcribe() calls. A load
# failure is logged as critical and re-raised, so the module import fails.
try:
    PIPELINE = initialize_pipeline()
except RuntimeError as init_error:
    logger.critical(f"Pipeline initialization failed: {init_error!s}")
    raise

# Gradio Interface with Blocks
def create_interface():
    """Build and return the Gradio Blocks UI.

    Left column: audio upload, timestamp checkbox and the Transcribe button.
    Right column: a live transcription textbox and a download file slot.
    Clicking the button streams ``transcribe`` outputs into the right column.
    """
    with gr.Blocks(theme=gr.themes.Soft()) as app:
        gr.Markdown("# Swedish Whisper Transcriber")
        gr.Markdown("Upload an .m4a file from your iPhone for real-time Swedish speech transcription.")

        with gr.Row():
            # Inputs
            with gr.Column():
                audio = gr.Audio(type="filepath", label="Upload .m4a Audio")
                want_timestamps = gr.Checkbox(label="Include Timestamps in Download", value=False)
                run_button = gr.Button("Transcribe")

            # Outputs
            with gr.Column():
                live_text = gr.Textbox(label="Live Transcription", lines=10)
                transcript_file = gr.File(label="Download Transcript")

        run_button.click(
            fn=transcribe,
            inputs=[audio, want_timestamps],
            outputs=[live_text, transcript_file],
        )

    return app

if __name__ == "__main__":
    # Every failure path exits with a non-zero status so callers/supervisors
    # can tell the app did not start. (PIPELINE is already loaded at import time.)
    try:
        if not check_ffmpeg():
            print("Error: ffmpeg is required to process .m4a files. Please install ffmpeg and ensure it's in your PATH.")
            # SystemExit is not an Exception subclass, so the handler below does
            # not intercept it; unlike exit(), it does not depend on the site module.
            raise SystemExit(1)
        create_interface().launch()
    except Exception as e:
        # logger.exception records the traceback the user is pointed to below.
        logger.exception(f"Failed to launch Gradio interface: {str(e)}")
        print("Error: Could not start the application. Please check the logs for details.")
        raise SystemExit(1) from e