trip-fontaine committed on
Commit 7c9a0e8
1 Parent(s): ba21603
Files changed (2)
  1. app.py +153 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,153 @@
+ import spaces
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
+ from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
+ from transformers.pipelines.audio_utils import ffmpeg_read
+ import torch
+ import gradio as gr
+ import time
+
+ BATCH_SIZE = 16
+ MAX_AUDIO_MINS = 30  # maximum audio input in minutes
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
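+ # pick the fastest available attention backend: Flash Attention 2, then PyTorch SDPA, else eager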
+ attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "sdpa" if is_torch_sdpa_available() else "eager"
+
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
+     "openai/whisper-large-v3", torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, attn_implementation=attn_implementation
+ )
+ distilled_model = AutoModelForSpeechSeq2Seq.from_pretrained(
+     "eustlb/distil-large-v3-fr", torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, attn_implementation=attn_implementation
+ )
+
+ processor = AutoProcessor.from_pretrained("openai/whisper-large-v3")
+
+ model.to(device)
+ distilled_model.to(device)
+
+ pipe = pipeline(
+     "automatic-speech-recognition",
+     model=model,
+     tokenizer=processor.tokenizer,
+     feature_extractor=processor.feature_extractor,
+     max_new_tokens=128,
+     chunk_length_s=30,
+     torch_dtype=torch_dtype,
+     device=device,
+     generate_kwargs={"language": "fr", "task": "transcribe"},
+     return_timestamps=True,
+ )
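+ # keep a reference to the original _forward so it can be wrapped with a per-request timer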
+ pipe_forward = pipe._forward
+
+ distil_pipe = pipeline(
+     "automatic-speech-recognition",
+     model=distilled_model,
+     tokenizer=processor.tokenizer,
+     feature_extractor=processor.feature_extractor,
+     max_new_tokens=128,
+     chunk_length_s=25,
+     torch_dtype=torch_dtype,
+     device=device,
+     generate_kwargs={"language": "fr", "task": "transcribe"},
+ )
+ distil_pipe_forward = distil_pipe._forward
+
+
+ @spaces.GPU
+ def transcribe(inputs):
+     if inputs is None:
+         raise gr.Error("No audio file submitted! Please record or upload an audio file before submitting your request.")
+
+     with open(inputs, "rb") as f:
+         inputs = f.read()
+
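+     # decode the raw audio bytes into a float waveform at the pipeline's expected sampling rate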
+     inputs = ffmpeg_read(inputs, pipe.feature_extractor.sampling_rate)
+     audio_length_mins = len(inputs) / pipe.feature_extractor.sampling_rate / 60
+
+     if audio_length_mins > MAX_AUDIO_MINS:
+         raise gr.Error(
+             f"To ensure fair usage of the Space, the maximum audio length permitted is {MAX_AUDIO_MINS} minutes. "
+             f"Got an audio of length {round(audio_length_mins, 3)} minutes."
+         )
+
+     inputs = {"array": inputs, "sampling_rate": pipe.feature_extractor.sampling_rate}
+
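+     # monkey-patch _forward so that only the model's forward passes are timed, excluding pre/post-processing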
+     def _forward_distil_time(*args, **kwargs):
+         global distil_runtime
+         start_time = time.time()
+         result = distil_pipe_forward(*args, **kwargs)
+         distil_runtime = time.time() - start_time
+         distil_runtime = round(distil_runtime, 2)
+         return result
+
+     distil_pipe._forward = _forward_distil_time
+     distil_text = distil_pipe(inputs.copy(), batch_size=BATCH_SIZE)["text"]
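+     # first yield: stream the Distil-Whisper result as soon as it is ready, before Whisper runs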
+     yield distil_text, distil_runtime, None, None
+
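+     # identical timing wrapper for the full Whisper pipeline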
+     def _forward_time(*args, **kwargs):
+         global runtime
+         start_time = time.time()
+         result = pipe_forward(*args, **kwargs)
+         runtime = time.time() - start_time
+         runtime = round(runtime, 2)
+         return result
+
+     pipe._forward = _forward_time
+     text = pipe(inputs, batch_size=BATCH_SIZE)["text"]
+
+     yield distil_text, distil_runtime, text, runtime
+
+
+ if __name__ == "__main__":
+     with gr.Blocks() as demo:
+         gr.HTML(
+             """
+             <div style="text-align: center; max-width: 700px; margin: 0 auto;">
+               <div
+                 style="
+                   display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;
+                 "
+               >
+                 <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
+                   Whisper vs Distil-Whisper: Speed Comparison
+                 </h1>
+               </div>
+             </div>
+             """
+         )
+         gr.HTML(
+             """
+             <p><a href="https://huggingface.co/distil-whisper/distil-large-v3">Distil-Whisper</a> is a distilled variant
+             of the <a href="https://huggingface.co/openai/whisper-large-v3">Whisper</a> model by OpenAI. Compared to Whisper,
+             Distil-Whisper runs 6x faster with 50% fewer parameters, while performing to within 1% word error rate (WER) on
+             out-of-distribution evaluation data.</p>
+
+             <p>This demo puts that claim to the test with a head-to-head speed comparison, pitting the French distilled
+             checkpoint <a href="https://huggingface.co/eustlb/distil-large-v3-fr">distil-large-v3-fr</a> against Whisper
+             large-v3 on French audio. Both models use the
+             <a href="https://huggingface.co/distil-whisper/distil-large-v3#chunked-long-form">chunked long-form transcription algorithm</a>
+             in 🤗 Transformers. To use Distil-Whisper yourself, check the code examples in the
+             <a href="https://github.com/huggingface/distil-whisper#1-usage">Distil-Whisper repository</a>. To ensure fair
+             usage of the Space, we ask that audio inputs are kept under 30 minutes.</p>
+             """
+         )
+         audio = gr.components.Audio(type="filepath", label="Audio input")
+         button = gr.Button("Transcribe")
+         with gr.Row():
+             distil_runtime = gr.components.Textbox(label="Distil-Whisper Transcription Time (s)")
+             runtime = gr.components.Textbox(label="Whisper Transcription Time (s)")
+         with gr.Row():
+             distil_transcription = gr.components.Textbox(label="Distil-Whisper Transcription", show_copy_button=True)
+             transcription = gr.components.Textbox(label="Whisper Transcription", show_copy_button=True)
+         button.click(
+             fn=transcribe,
+             inputs=audio,
+             outputs=[distil_transcription, distil_runtime, transcription, runtime],
+         )
+         gr.Markdown("## Examples")
+         gr.Examples(
+             [["./assets/example_1.wav"], ["./assets/example_2.wav"]],
+             audio,
+             outputs=[distil_transcription, distil_runtime, transcription, runtime],
+             fn=transcribe,
+             cache_examples=False,
+         )
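+     # queue requests (at most 10 waiting) before launching the demo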
+     demo.queue(max_size=10).launch()
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ transformers
+ accelerate