shideqin committed on
Commit
3bfbf1b
1 Parent(s): 139df23

modify app

Browse files
Files changed (1) hide show
  1. app.py +53 -68
app.py CHANGED
@@ -1,17 +1,27 @@
1
- import base64
2
  import math
3
  import os
4
  import time
5
  from multiprocessing import Pool
6
 
7
  import gradio as gr
 
8
  import numpy as np
9
  import pytube
10
- import requests
11
- from processing_whisper import WhisperPrePostProcessor
12
  from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
13
  from transformers.pipelines.audio_utils import ffmpeg_read
14
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  title = "Whisper JAX: The Fastest Whisper API ⚡️"
17
 
@@ -24,56 +34,15 @@ To skip the queue, you may wish to create your own inference endpoint, details f
24
 
25
  article = "Whisper large-v2 model by OpenAI. Backend running JAX on a TPU v4-8 through the generous support of the [TRC](https://sites.research.google/trc/about/) programme. Whisper JAX [code](https://github.com/sanchit-gandhi/whisper-jax) and Gradio demo by 🤗 Hugging Face."
26
 
27
- API_URL = os.getenv("API_URL")
28
- API_URL_FROM_FEATURES = os.getenv("API_URL_FROM_FEATURES")
29
  language_names = sorted(TO_LANGUAGE_CODE.keys())
30
- CHUNK_LENGTH_S = 30
31
- BATCH_SIZE = 16
32
- NUM_PROC = 16
33
- FILE_LIMIT_MB = 1000
34
-
35
-
36
- def query(payload):
37
- response = requests.post(API_URL, json=payload)
38
- return response.json(), response.status_code
39
-
40
 
41
- def inference(inputs, task=None, return_timestamps=False):
42
- payload = {"inputs": inputs, "task": task, "return_timestamps": return_timestamps}
43
-
44
- data, status_code = query(payload)
45
-
46
- if status_code != 200:
47
- # error with our request - return the details to the user
48
- raise gr.Error(data["detail"])
49
-
50
- text = data["detail"]
51
- timestamps = data.get("chunks")
52
- if timestamps is not None:
53
- timestamps = [
54
- f"[{format_timestamp(chunk['timestamp'][0])} -> {format_timestamp(chunk['timestamp'][1])}] {chunk['text']}"
55
- for chunk in timestamps
56
- ]
57
- text = "\n".join(str(feature) for feature in timestamps)
58
- return text
59
-
60
-
61
- def chunked_query(payload):
62
- response = requests.post(API_URL_FROM_FEATURES, json=payload)
63
- return response.json(), response.status_code
64
-
65
-
66
- def forward(batch, task=None, return_timestamps=False):
67
- feature_shape = batch["input_features"].shape
68
- batch["input_features"] = base64.b64encode(batch["input_features"].tobytes()).decode()
69
- outputs, status_code = chunked_query(
70
- {"batch": batch, "task": task, "return_timestamps": return_timestamps, "feature_shape": feature_shape}
71
- )
72
- if status_code != 200:
73
- # error with our request - return the details to the user
74
- raise gr.Error(outputs["detail"])
75
- outputs["tokens"] = np.asarray(outputs["tokens"])
76
- return outputs
77
 
78
 
79
  def identity(batch):
@@ -102,10 +71,10 @@ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal
102
 
103
 
104
  if __name__ == "__main__":
105
- processor = WhisperPrePostProcessor.from_pretrained("openai/whisper-large-v2")
106
  stride_length_s = CHUNK_LENGTH_S / 6
107
- chunk_len = round(CHUNK_LENGTH_S * processor.feature_extractor.sampling_rate)
108
- stride_left = stride_right = round(stride_length_s * processor.feature_extractor.sampling_rate)
109
  step = chunk_len - stride_left - stride_right
110
  pool = Pool(NUM_PROC)
111
 
@@ -118,18 +87,21 @@ if __name__ == "__main__":
118
  range(num_batches)
119
  ) # Gradio progress bar not compatible with generator, see https://github.com/gradio-app/gradio/issues/3841
120
 
121
- dataloader = processor.preprocess_batch(inputs, chunk_length_s=CHUNK_LENGTH_S, batch_size=BATCH_SIZE)
122
  progress(0, desc="Pre-processing audio file...")
 
123
  dataloader = pool.map(identity, dataloader)
124
 
125
  model_outputs = []
126
  start_time = time.time()
127
  # iterate over our chunked audio samples
128
  for batch, _ in zip(dataloader, progress.tqdm(dummy_batches, desc="Transcribing...")):
129
- model_outputs.append(forward(batch, task=task, return_timestamps=return_timestamps))
 
 
130
  runtime = time.time() - start_time
131
 
132
- post_processed = processor.postprocess(model_outputs, return_timestamps=return_timestamps)
133
  text = post_processed["text"]
134
  timestamps = post_processed.get("chunks")
135
  if timestamps is not None:
@@ -138,14 +110,18 @@ if __name__ == "__main__":
138
  for chunk in timestamps
139
  ]
140
  text = "\n".join(str(feature) for feature in timestamps)
 
141
  return text, runtime
142
 
143
  def transcribe_chunked_audio(inputs, task, return_timestamps, progress=gr.Progress()):
144
  progress(0, desc="Loading audio file...")
 
145
  if inputs is None:
 
146
  raise gr.Error("No audio file submitted! Please upload an audio file before submitting your request.")
147
  file_size_mb = os.stat(inputs).st_size / (1024 * 1024)
148
  if file_size_mb > FILE_LIMIT_MB:
 
149
  raise gr.Error(
150
  f"File size exceeds file size limit. Got file of size {file_size_mb:.2f}MB for a limit of {FILE_LIMIT_MB}MB."
151
  )
@@ -153,9 +129,10 @@ if __name__ == "__main__":
153
  with open(inputs, "rb") as f:
154
  inputs = f.read()
155
 
156
- inputs = ffmpeg_read(inputs, processor.feature_extractor.sampling_rate)
157
- inputs = {"array": inputs, "sampling_rate": processor.feature_extractor.sampling_rate}
158
  text, runtime = tqdm_generate(inputs, task=task, return_timestamps=return_timestamps, progress=progress)
 
159
  return text, runtime
160
 
161
  def _return_yt_html_embed(yt_url):
@@ -168,14 +145,21 @@ if __name__ == "__main__":
168
 
169
  def transcribe_youtube(yt_url, task, return_timestamps, progress=gr.Progress(), max_filesize=75.0):
170
  progress(0, desc="Loading audio file...")
 
171
  html_embed_str = _return_yt_html_embed(yt_url)
172
- try:
173
- yt = pytube.YouTube(yt_url)
174
- stream = yt.streams.filter(only_audio=True)[0]
175
- except KeyError:
176
- raise gr.Error("An error occurred while loading the YouTube video. Please try again.")
 
 
 
 
 
177
 
178
  if stream.filesize_mb > max_filesize:
 
179
  raise gr.Error(f"Maximum YouTube file size is {max_filesize}MB, got {stream.filesize_mb:.2f}MB.")
180
 
181
  stream.download(filename="audio.mp3")
@@ -183,9 +167,10 @@ if __name__ == "__main__":
183
  with open("audio.mp3", "rb") as f:
184
  inputs = f.read()
185
 
186
- inputs = ffmpeg_read(inputs, processor.feature_extractor.sampling_rate)
187
- inputs = {"array": inputs, "sampling_rate": processor.feature_extractor.sampling_rate}
188
  text, runtime = tqdm_generate(inputs, task=task, return_timestamps=return_timestamps, progress=progress)
 
189
  return html_embed_str, text, runtime
190
 
191
  microphone_chunked = gr.Interface(
@@ -247,5 +232,5 @@ if __name__ == "__main__":
247
  with demo:
248
  gr.TabbedInterface([microphone_chunked, audio_chunked, youtube], ["Microphone", "Audio File", "YouTube"])
249
 
250
- demo.queue(concurrency_count=3, max_size=5)
251
- demo.launch(show_api=False, max_threads=10)
 
1
+ import logging
2
  import math
3
  import os
4
  import time
5
  from multiprocessing import Pool
6
 
7
  import gradio as gr
8
+ import jax.numpy as jnp
9
  import numpy as np
10
  import pytube
11
+ from jax.experimental.compilation_cache import compilation_cache as cc
 
12
  from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
13
  from transformers.pipelines.audio_utils import ffmpeg_read
14
 
15
+ from whisper_jax import FlaxWhisperPipline
16
+
17
+
18
+ cc.initialize_cache("./jax_cache")
19
+ checkpoint = "openai/whisper-large-v2"
20
+ BATCH_SIZE = 16
21
+ CHUNK_LENGTH_S = 30
22
+ NUM_PROC = 8
23
+ FILE_LIMIT_MB = 1000
24
+ YT_ATTEMPT_LIMIT = 3
25
 
26
  title = "Whisper JAX: The Fastest Whisper API ⚡️"
27
 
 
34
 
35
  article = "Whisper large-v2 model by OpenAI. Backend running JAX on a TPU v4-8 through the generous support of the [TRC](https://sites.research.google/trc/about/) programme. Whisper JAX [code](https://github.com/sanchit-gandhi/whisper-jax) and Gradio demo by 🤗 Hugging Face."
36
 
 
 
37
  language_names = sorted(TO_LANGUAGE_CODE.keys())
 
 
 
 
 
 
 
 
 
 
38
 
39
+ logger = logging.getLogger("whisper-jax-app")
40
+ logger.setLevel(logging.INFO)
41
+ ch = logging.StreamHandler()
42
+ ch.setLevel(logging.INFO)
43
+ formatter = logging.Formatter("%(asctime)s;%(levelname)s;%(message)s", "%Y-%m-%d %H:%M:%S")
44
+ ch.setFormatter(formatter)
45
+ logger.addHandler(ch)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
 
48
  def identity(batch):
 
71
 
72
 
73
  if __name__ == "__main__":
74
+ pipeline = FlaxWhisperPipline(checkpoint, dtype=jnp.bfloat16, batch_size=BATCH_SIZE)
75
  stride_length_s = CHUNK_LENGTH_S / 6
76
+ chunk_len = round(CHUNK_LENGTH_S * pipeline.feature_extractor.sampling_rate)
77
+ stride_left = stride_right = round(stride_length_s * pipeline.feature_extractor.sampling_rate)
78
  step = chunk_len - stride_left - stride_right
79
  pool = Pool(NUM_PROC)
80
 
 
87
  range(num_batches)
88
  ) # Gradio progress bar not compatible with generator, see https://github.com/gradio-app/gradio/issues/3841
89
 
90
+ dataloader = pipeline.preprocess_batch(inputs, chunk_length_s=CHUNK_LENGTH_S, batch_size=BATCH_SIZE)
91
  progress(0, desc="Pre-processing audio file...")
92
+ logger.info("Pre-processing audio file...")
93
  dataloader = pool.map(identity, dataloader)
94
 
95
  model_outputs = []
96
  start_time = time.time()
97
  # iterate over our chunked audio samples
98
  for batch, _ in zip(dataloader, progress.tqdm(dummy_batches, desc="Transcribing...")):
99
+ model_outputs.append(
100
+ pipeline.forward(batch, batch_size=BATCH_SIZE, task=task, return_timestamps=return_timestamps)
101
+ )
102
  runtime = time.time() - start_time
103
 
104
+ post_processed = pipeline.postprocess(model_outputs, return_timestamps=return_timestamps)
105
  text = post_processed["text"]
106
  timestamps = post_processed.get("chunks")
107
  if timestamps is not None:
 
110
  for chunk in timestamps
111
  ]
112
  text = "\n".join(str(feature) for feature in timestamps)
113
+ logger.info("done pre-processing")
114
  return text, runtime
115
 
116
  def transcribe_chunked_audio(inputs, task, return_timestamps, progress=gr.Progress()):
117
  progress(0, desc="Loading audio file...")
118
+ logger.info("Loading audio file...")
119
  if inputs is None:
120
+ logger.warning("No audio file")
121
  raise gr.Error("No audio file submitted! Please upload an audio file before submitting your request.")
122
  file_size_mb = os.stat(inputs).st_size / (1024 * 1024)
123
  if file_size_mb > FILE_LIMIT_MB:
124
+ logger.warning("Max file size exceeded")
125
  raise gr.Error(
126
  f"File size exceeds file size limit. Got file of size {file_size_mb:.2f}MB for a limit of {FILE_LIMIT_MB}MB."
127
  )
 
129
  with open(inputs, "rb") as f:
130
  inputs = f.read()
131
 
132
+ inputs = ffmpeg_read(inputs, pipeline.feature_extractor.sampling_rate)
133
+ inputs = {"array": inputs, "sampling_rate": pipeline.feature_extractor.sampling_rate}
134
  text, runtime = tqdm_generate(inputs, task=task, return_timestamps=return_timestamps, progress=progress)
135
+ logger.info("done loading")
136
  return text, runtime
137
 
138
  def _return_yt_html_embed(yt_url):
 
145
 
146
  def transcribe_youtube(yt_url, task, return_timestamps, progress=gr.Progress(), max_filesize=75.0):
147
  progress(0, desc="Loading audio file...")
148
+ logger.info("Loading youtube file...")
149
  html_embed_str = _return_yt_html_embed(yt_url)
150
+
151
+ for attempt in range(YT_ATTEMPT_LIMIT):
152
+ try:
153
+ yt = pytube.YouTube(yt_url)
154
+ stream = yt.streams.filter(only_audio=True)[0]
155
+ break
156
+ except KeyError:
157
+ if attempt + 1 == YT_ATTEMPT_LIMIT:
158
+ logger.warning("YouTube error")
159
+ raise gr.Error("An error occurred while loading the YouTube video. Please try again.")
160
 
161
  if stream.filesize_mb > max_filesize:
162
+ logger.warning("Max YouTube size exceeded")
163
  raise gr.Error(f"Maximum YouTube file size is {max_filesize}MB, got {stream.filesize_mb:.2f}MB.")
164
 
165
  stream.download(filename="audio.mp3")
 
167
  with open("audio.mp3", "rb") as f:
168
  inputs = f.read()
169
 
170
+ inputs = ffmpeg_read(inputs, pipeline.feature_extractor.sampling_rate)
171
+ inputs = {"array": inputs, "sampling_rate": pipeline.feature_extractor.sampling_rate}
172
  text, runtime = tqdm_generate(inputs, task=task, return_timestamps=return_timestamps, progress=progress)
173
+ logger.info("done youtube")
174
  return html_embed_str, text, runtime
175
 
176
  microphone_chunked = gr.Interface(
 
232
  with demo:
233
  gr.TabbedInterface([microphone_chunked, audio_chunked, youtube], ["Microphone", "Audio File", "YouTube"])
234
 
235
+ demo.queue(concurrency_count=1, max_size=5)
236
+ demo.launch(server_name="0.0.0.0", show_api=False)