import gradio as gr
import torch
import numpy as np
import os
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
from transformers.utils import is_flash_attn_2_available
from subtitle_manager import Subtitle
import logging
logging.basicConfig(level=logging.INFO)
logging.info("Gradio version: %s", gr.__version__)

# Global state: cache the loaded pipeline and its model id so repeated runs
# reuse the model instead of reloading it.
pipe = None
last_model = None

def get_language_names():
    return [
        "af", "am", "ar", "as", "az", "ba", "be", "bg", "bn", "bo", "br", "bs",
        "ca", "cs", "cy", "da", "de", "el", "en", "es", "et", "eu", "fa", "fi",
        "fo", "fr", "gl", "gu", "ha", "haw", "he", "hi", "hr", "ht", "hu", "hy",
        "id", "is", "it", "ja", "jw", "ka", "kk", "km", "kn", "ko", "la", "lb",
        "ln", "lo", "lt", "lv", "mg", "mi", "mk", "ml", "mn", "mr", "ms", "mt",
        "my", "ne", "nl", "nn", "no", "oc", "pa", "pl", "ps", "pt", "ro", "ru",
        "sa", "sd", "si", "sk", "sl", "sn", "so", "sq", "sr", "su", "sv", "sw",
        "ta", "te", "tg", "th", "tk", "tl", "tr", "tt", "uk", "ur", "uz", "vi",
        "yi", "yo", "zh"
    ]

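# Build a transformers ASR pipeline for the given Whisper checkpoint, using
# fp16 on GPU and fp32 on CPU.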
def create_pipe(model_id, flash):
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

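    # Prefer Flash Attention 2 when requested and installed; otherwise fall
    # back to PyTorch's SDPA attention.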
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        attn_implementation="flash_attention_2" if flash and is_flash_attn_2_available() else "sdpa",
    ).to(device)

    processor = AutoProcessor.from_pretrained(model_id)

    return pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        device=device,
        torch_dtype=torch_dtype,
    )

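# Batch transcription handler: reloads the pipeline when the model selection
# changes, runs ASR over every input source, and writes SRT/VTT/TXT files.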
def transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task, flash,
                                    chunk_length_s, batch_size, progress=gr.Progress()):
    global last_model, pipe

    progress(0, desc="Loading Audio...")

    if last_model != modelName or pipe is None:
        pipe = None  # drop the old pipeline so its GPU memory can actually be freed
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        progress(0.1, desc="Loading Model...")
        pipe = create_pipe(modelName, flash)
        last_model = modelName

    files = []
    if multipleFiles:
        files += multipleFiles
    if urlData:
        files.append(urlData)
    if microphoneData:
        files.append(microphoneData)

    srt_sub = Subtitle("srt")
    vtt_sub = Subtitle("vtt")
    txt_sub = Subtitle("txt")

    files_out = []
    vtt = txt = ""  # keep the text outputs defined even when no input sources were provided
    for file in progress.tqdm(files, desc="Working..."):
        outputs = pipe(
            file,
            chunk_length_s=chunk_length_s,
            batch_size=int(batch_size),  # gr.Number yields a float; the pipeline expects an int
            generate_kwargs={
                "language": languageName if languageName != "Automatic Detection" else None,
                "task": task
            },
            return_timestamps=True,
        )
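        # outputs["chunks"] holds the timestamped segments that the Subtitle
        # helpers render into each format.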
        file_out = os.path.basename(file)
        srt = srt_sub.get_subtitle(outputs["chunks"])
        vtt = vtt_sub.get_subtitle(outputs["chunks"])
        txt = txt_sub.get_subtitle(outputs["chunks"])

        with open(file_out+".srt", 'w', encoding='utf-8') as f:
            f.write(srt)
        with open(file_out+".vtt", 'w', encoding='utf-8') as f:
            f.write(vtt)
        with open(file_out+".txt", 'w', encoding='utf-8') as f:
            f.write(txt)

        files_out += [file_out+".srt", file_out+".vtt", file_out+".txt"]

    progress(1, desc="Completed!")

    return files_out, txt, vtt  # plain text to "Transcription", timestamped VTT to "Segments"

# Realtime STT
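# Gradio streams microphone audio as (sample_rate, numpy_chunk) tuples; each
# chunk is mono-mixed, peak-normalized, and appended to a rolling buffer that
# is re-transcribed in full so the text refines as more audio arrives.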
def transcribe_stream(buffer, new_chunk):
    if pipe is None:  # the model may still be preloading on first page load
        return buffer, ""
    sr, chunk = new_chunk
    if chunk.ndim > 1:
        chunk = chunk.mean(axis=1)
    chunk = chunk.astype(np.float32)
    peak = np.max(np.abs(chunk))
    if peak > 0:
        chunk /= peak

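    # Re-transcribing the whole buffer trades latency for accuracy; a long
    # session may want a sliding window instead.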
    buffer = chunk if buffer is None else np.concatenate([buffer, chunk])
    text = pipe({"sampling_rate": sr, "raw": buffer})["text"]
    return buffer, text

# Gradio UI
with gr.Blocks(title="Insanely Fast Whisper") as demo:
    gr.Markdown("## 🎙️ Insanely Fast Whisper + Real-time STT")
    
    whisper_models = [
        "openai/whisper-tiny", "openai/whisper-tiny.en",
        "openai/whisper-base", "openai/whisper-base.en",
        "openai/whisper-small", "openai/whisper-small.en",
        "distil-whisper/distil-small.en",
        "openai/whisper-medium", "openai/whisper-medium.en",
        "distil-whisper/distil-medium.en",
        "openai/whisper-large", "openai/whisper-large-v1",
        "openai/whisper-large-v2", "distil-whisper/distil-large-v2",
        "openai/whisper-large-v3", "distil-whisper/distil-large-v3",
    ]
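    # Note: the ".en" checkpoints are English-only and do not support the
    # "translate" task.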

    with gr.Tab("File Transcription"):
        with gr.Row():
            with gr.Column():
                model_dropdown = gr.Dropdown(
                    whisper_models, 
                    value="distil-whisper/distil-large-v2", 
                    label="Model"
                )
                language_dropdown = gr.Dropdown(
                    ["Automatic Detection"] + sorted(get_language_names()), 
                    value="Automatic Detection", 
                    label="Language"
                )
                url_input = gr.Text(label="URL (YouTube, etc.)")
                file_input = gr.File(label="Upload Files", file_count="multiple")
                audio_input = gr.Audio(
                    sources=["upload", "microphone"], 
                    type="filepath", 
                    label="Audio Input"
                )
                task_dropdown = gr.Dropdown(
                    ["transcribe", "translate"], 
                    label="Task", 
                    value="transcribe"
                )
                flash_checkbox = gr.Checkbox(label='Flash', info='Use Flash Attention 2')
                chunk_length = gr.Number(label='chunk_length_s', value=30)
                batch_size = gr.Number(label='batch_size', value=24)
                
                transcribe_button = gr.Button("Transcribe")
                
            with gr.Column():
                output_files = gr.File(label="Download")
                output_text = gr.Text(label="Transcription")
                output_segments = gr.Text(label="Segments")
                
        transcribe_button.click(
            fn=transcribe_webui_simple_progress,
            inputs=[
                model_dropdown, language_dropdown, url_input, 
                file_input, audio_input, task_dropdown, 
                flash_checkbox, chunk_length, batch_size
            ],
            outputs=[output_files, output_text, output_segments]
        )

    with gr.Tab("Real-time Transcription"):
        st_buffer = gr.State()
        mic_rt = gr.Audio(
            sources=["microphone"], type="numpy", streaming=True,
            label="🎤 Speak Now (Live Transcription)"
        )
        txt_rt = gr.Textbox(label="Real-time Transcription")
        mic_rt.stream(
            fn=transcribe_stream,
            inputs=[st_buffer, mic_rt],
            outputs=[st_buffer, txt_rt]
        )

    # Preload the default model so the first transcription is fast.
    # demo.load fires on every page load, so skip if already loaded.
    def load_model():
        global pipe, last_model
        if pipe is None:
            last_model = "distil-whisper/distil-large-v2"
            pipe = create_pipe(last_model, flash=False)

    demo.load(load_model)

# Launch the app
if __name__ == "__main__":
    demo.launch()