File size: 3,685 Bytes
aa547ad
86a1f13
 
550cf61
 
b86a6f7
550cf61
aa547ad
9a9ac31
550cf61
 
1e923d6
 
 
4036c8e
1e923d6
550cf61
 
 
ee53092
550cf61
1e923d6
 
 
 
 
 
550cf61
 
 
1e923d6
 
 
 
 
 
 
550cf61
 
1e923d6
 
 
 
 
 
 
550cf61
 
1e923d6
 
550cf61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e923d6
 
 
 
 
b86a6f7
550cf61
86a1f13
b86a6f7
 
550cf61
 
 
 
 
 
 
1e923d6
550cf61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e923d6
550cf61
 
 
 
1e923d6
550cf61
 
1e923d6
550cf61
 
 
 
 
 
 
b86a6f7
1e923d6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import spaces
import torch
import gradio as gr
import whisperx
from transformers.pipelines.audio_utils import ffmpeg_read
import tempfile
import gc
import os

# Constants
# Run on GPU when available; every model below is loaded onto this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 4  # reduce if low on GPU mem
COMPUTE_TYPE = "float32"  # change to "int8" if low on GPU mem
# NOTE(review): FILE_LIMIT_MB is not referenced anywhere in this file —
# confirm it is enforced elsewhere or wire it into the upload handling.
FILE_LIMIT_MB = 1000

@spaces.GPU(duration=200)
def transcribe_audio(inputs, task):
    """Transcribe (or translate) audio with WhisperX and label speakers.

    Args:
        inputs: Path to an audio file. Both the upload and microphone
            Gradio components deliver a filepath (``type="filepath"``).
        task: "transcribe" or "translate", forwarded to the Whisper model.

    Returns:
        A string with one ``"SPEAKER: text"`` line per aligned segment.

    Raises:
        gr.Error: If no audio was submitted or any processing step fails.
    """
    if inputs is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
    
    try:
        # Both microphone and upload inputs arrive as file paths, so one
        # load call suffices (the original if/else had identical branches).
        audio = whisperx.load_audio(inputs)

        # 1. Transcribe with base Whisper model.
        # BUG FIX: `task` was previously ignored, so selecting "translate"
        # in the UI behaved exactly like "transcribe". Forward it here.
        model = whisperx.load_model("large-v3", device=DEVICE, compute_type=COMPUTE_TYPE)
        result = model.transcribe(audio, batch_size=BATCH_SIZE, task=task)
        
        # Free the ASR model before loading the alignment model to keep
        # peak GPU memory down.
        del model
        gc.collect()
        torch.cuda.empty_cache()
        
        # 2. Align whisper output for accurate word-level timestamps.
        model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=DEVICE)
        result = whisperx.align(result["segments"], model_a, metadata, audio, DEVICE, return_char_alignments=False)
        
        # Free the alignment model before diarization.
        del model_a
        gc.collect()
        torch.cuda.empty_cache()
        
        # 3. Diarize audio (pyannote pipeline; needs an HF token).
        # A missing HF_TOKEN raises KeyError, surfaced via gr.Error below.
        diarize_model = whisperx.DiarizationPipeline(use_auth_token=os.environ["HF_TOKEN"], device=DEVICE)
        diarize_segments = diarize_model(audio)
        
        # 4. Assign speaker labels to the aligned segments.
        result = whisperx.assign_word_speakers(diarize_segments, result)
        
        # Format one "SPEAKER: text" line per segment; segments without a
        # confident assignment fall back to "Unknown Speaker".
        lines = []
        for segment in result['segments']:
            speaker = segment.get('speaker', 'Unknown Speaker')
            lines.append(f"{speaker}: {segment['text']}\n")
        return "".join(lines)
        
    except gr.Error:
        # Don't re-wrap errors we raised deliberately.
        raise
    except Exception as e:
        raise gr.Error(f"Error processing audio: {str(e)}")
    
    finally:
        # Final cleanup regardless of success or failure.
        gc.collect()
        torch.cuda.empty_cache()

# Create Gradio interface
# Blocks layout: input column (audio + task choice + button) on the left,
# transcription output on the right.
demo = gr.Blocks(theme=gr.themes.Ocean())

with demo:
    gr.Markdown("# WhisperX: Advanced Speech Recognition with Speaker Diarization")
    
    with gr.Row():
        with gr.Column():
            # filepath type: both microphone recordings and uploads are
            # handed to transcribe_audio as a path string.
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="Audio Input (Microphone or File Upload)"
            )
            # Forwarded to Whisper as the decoding task.
            task = gr.Radio(
                ["transcribe", "translate"],
                label="Task",
                value="transcribe"
            )
            submit_button = gr.Button("Process Audio")
        
        with gr.Column():
            output_text = gr.Textbox(
                label="Transcription with Speaker Diarization",
                lines=10,
                placeholder="Transcribed text will appear here..."
            )
    
    gr.Markdown("""
    ### Features:
    - High-accuracy transcription using WhisperX
    - Automatic speaker diarization
    - Support for both microphone recording and file upload
    - GPU-accelerated processing
    
    ### Note:
    Processing may take a few moments depending on the audio length and system resources.
    """)
    
    # Wire the button to the processing function.
    submit_button.click(
        fn=transcribe_audio,
        inputs=[audio_input, task],
        outputs=output_text
    )

# queue() serializes GPU requests; ssr_mode=False avoids server-side
# rendering issues on Spaces.
demo.queue().launch(ssr_mode=False)