sanchit-gandhi HF staff committed on
Commit
172ec24
1 Parent(s): 9e35e59
Files changed (1) hide show
  1. app.py +26 -6
app.py CHANGED
@@ -1,11 +1,12 @@
1
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
2
  from transformers.utils import is_flash_attn_2_available
 
3
  import torch
4
  import gradio as gr
5
  import time
6
- import os
7
 
8
  BATCH_SIZE = 16
 
9
 
10
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
11
  torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
@@ -15,10 +16,11 @@ model = AutoModelForSpeechSeq2Seq.from_pretrained(
15
  "openai/whisper-large-v2", torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, use_flash_attention_2=use_flash_attention_2
16
  )
17
  distilled_model = AutoModelForSpeechSeq2Seq.from_pretrained(
18
- "distil-whisper/distil-large-v2", torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, use_flash_attention_2=use_flash_attention_2, token=TOKEN
19
  )
20
 
21
  if not use_flash_attention_2:
 
22
  model = model.to_bettertransformer()
23
  distilled_model = distilled_model.to_bettertransformer()
24
 
@@ -49,6 +51,7 @@ distil_pipe = pipeline(
49
  chunk_length_s=15,
50
  torch_dtype=torch_dtype,
51
  device=device,
 
52
  )
53
  distil_pipe_forward = distil_pipe._forward
54
 
@@ -56,6 +59,20 @@ def transcribe(inputs):
56
  if inputs is None:
57
  raise gr.Error("No audio file submitted! Please record or upload an audio file before submitting your request.")
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  def _forward_distil_time(*args, **kwargs):
60
  global distil_runtime
61
  start_time = time.time()
@@ -92,7 +109,7 @@ if __name__ == "__main__":
92
  "
93
  >
94
  <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
95
- Distil-Whisper VS Whisper
96
  </h1>
97
  </div>
98
  </div>
@@ -100,8 +117,11 @@ if __name__ == "__main__":
100
  )
101
  gr.HTML(
102
  f"""
103
- This demo evaluates the <a href="https://huggingface.co/distil-whisper/distil-large-v2"> Distil-Whisper </a> model
104
- against the <a href="https://huggingface.co/openai/whisper-large-v2"> Whisper </a> model.
 
 
 
105
  """
106
  )
107
  audio = gr.components.Audio(type="filepath", label="Audio input")
@@ -117,4 +137,4 @@ if __name__ == "__main__":
117
  inputs=audio,
118
  outputs=[distil_transcription, distil_runtime, transcription, runtime],
119
  )
120
- demo.queue().launch()
 
1
  from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
2
  from transformers.utils import is_flash_attn_2_available
3
+ from transformers.pipelines.audio_utils import ffmpeg_read
4
  import torch
5
  import gradio as gr
6
  import time
 
7
 
8
  BATCH_SIZE = 16
9
+ MAX_AUDIO_MINS = 30 # maximum audio input in minutes
10
 
11
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
12
  torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
16
  "openai/whisper-large-v2", torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, use_flash_attention_2=use_flash_attention_2
17
  )
18
  distilled_model = AutoModelForSpeechSeq2Seq.from_pretrained(
19
+ "distil-whisper/distil-large-v2", torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, use_flash_attention_2=use_flash_attention_2
20
  )
21
 
22
  if not use_flash_attention_2:
23
+ # use flash attention from pytorch sdpa
24
  model = model.to_bettertransformer()
25
  distilled_model = distilled_model.to_bettertransformer()
26
 
 
51
  chunk_length_s=15,
52
  torch_dtype=torch_dtype,
53
  device=device,
54
+ generate_kwargs={"language": "en", "task": "transcribe"},
55
  )
56
  distil_pipe_forward = distil_pipe._forward
57
 
 
59
  if inputs is None:
60
  raise gr.Error("No audio file submitted! Please record or upload an audio file before submitting your request.")
61
 
62
+ with open(inputs, "rb") as f:
63
+ inputs = f.read()
64
+
65
+ inputs = ffmpeg_read(inputs, pipe.feature_extractor.sampling_rate)
66
+ audio_length_mins = len(inputs) / pipe.feature_extractor.sampling_rate / 60
67
+
68
+ if audio_length_mins > MAX_AUDIO_MINS:
69
+ raise gr.Error(
70
+ f"To ensure fair usage of the Space, the maximum audio length permitted is {MAX_AUDIO_MINS} minutes."
71
+ f"Got an audio of length {round(audio_length_mins, 3)} minutes."
72
+ )
73
+
74
+ inputs = {"array": inputs, "sampling_rate": pipe.feature_extractor.sampling_rate}
75
+
76
  def _forward_distil_time(*args, **kwargs):
77
  global distil_runtime
78
  start_time = time.time()
 
109
  "
110
  >
111
  <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
112
+ Whisper vs Distil-Whisper
113
  </h1>
114
  </div>
115
  </div>
 
117
  )
118
  gr.HTML(
119
  f"""
120
+ This demo shows a speed comparison between <a href="https://huggingface.co/openai/whisper-large-v2"> Whisper </a>
121
+ and <a href="https://huggingface.co/distil-whisper/distil-large-v2"> Distil-Whisper </a> for the same audio
122
+ file input. Both models use the <a href="https://huggingface.co/distil-whisper/distil-large-v2#long-form-transcription"> chunked long-form transcription algorithm </a>
123
+ in 🤗 Transformers with Flash Attention support. To ensure fair usage of the Space, we ask that audio
124
+ file inputs are kept to < 30 mins.
125
  """
126
  )
127
  audio = gr.components.Audio(type="filepath", label="Audio input")
 
137
  inputs=audio,
138
  outputs=[distil_transcription, distil_runtime, transcription, runtime],
139
  )
140
+ demo.queue(max_size=10).launch()