shideqin committed on
Commit
63f0318
1 Parent(s): f988cc5
Files changed (3)
  1. LICENSE +203 -0
  2. README.md +468 -1
  3. app.py +251 -0
LICENSE ADDED
@@ -0,0 +1,203 @@
1
+ Copyright 2023- The HuggingFace Inc. team and The OpenAI Authors. All rights reserved.
2
+
3
+ Apache License
4
+ Version 2.0, January 2004
5
+ http://www.apache.org/licenses/
6
+
7
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
8
+
9
+ 1. Definitions.
10
+
11
+ "License" shall mean the terms and conditions for use, reproduction,
12
+ and distribution as defined by Sections 1 through 9 of this document.
13
+
14
+ "Licensor" shall mean the copyright owner or entity authorized by
15
+ the copyright owner that is granting the License.
16
+
17
+ "Legal Entity" shall mean the union of the acting entity and all
18
+ other entities that control, are controlled by, or are under common
19
+ control with that entity. For the purposes of this definition,
20
+ "control" means (i) the power, direct or indirect, to cause the
21
+ direction or management of such entity, whether by contract or
22
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
23
+ outstanding shares, or (iii) beneficial ownership of such entity.
24
+
25
+ "You" (or "Your") shall mean an individual or Legal Entity
26
+ exercising permissions granted by this License.
27
+
28
+ "Source" form shall mean the preferred form for making modifications,
29
+ including but not limited to software source code, documentation
30
+ source, and configuration files.
31
+
32
+ "Object" form shall mean any form resulting from mechanical
33
+ transformation or translation of a Source form, including but
34
+ not limited to compiled object code, generated documentation,
35
+ and conversions to other media types.
36
+
37
+ "Work" shall mean the work of authorship, whether in Source or
38
+ Object form, made available under the License, as indicated by a
39
+ copyright notice that is included in or attached to the work
40
+ (an example is provided in the Appendix below).
41
+
42
+ "Derivative Works" shall mean any work, whether in Source or Object
43
+ form, that is based on (or derived from) the Work and for which the
44
+ editorial revisions, annotations, elaborations, or other modifications
45
+ represent, as a whole, an original work of authorship. For the purposes
46
+ of this License, Derivative Works shall not include works that remain
47
+ separable from, or merely link (or bind by name) to the interfaces of,
48
+ the Work and Derivative Works thereof.
49
+
50
+ "Contribution" shall mean any work of authorship, including
51
+ the original version of the Work and any modifications or additions
52
+ to that Work or Derivative Works thereof, that is intentionally
53
+ submitted to Licensor for inclusion in the Work by the copyright owner
54
+ or by an individual or Legal Entity authorized to submit on behalf of
55
+ the copyright owner. For the purposes of this definition, "submitted"
56
+ means any form of electronic, verbal, or written communication sent
57
+ to the Licensor or its representatives, including but not limited to
58
+ communication on electronic mailing lists, source code control systems,
59
+ and issue tracking systems that are managed by, or on behalf of, the
60
+ Licensor for the purpose of discussing and improving the Work, but
61
+ excluding communication that is conspicuously marked or otherwise
62
+ designated in writing by the copyright owner as "Not a Contribution."
63
+
64
+ "Contributor" shall mean Licensor and any individual or Legal Entity
65
+ on behalf of whom a Contribution has been received by Licensor and
66
+ subsequently incorporated within the Work.
67
+
68
+ 2. Grant of Copyright License. Subject to the terms and conditions of
69
+ this License, each Contributor hereby grants to You a perpetual,
70
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
71
+ copyright license to reproduce, prepare Derivative Works of,
72
+ publicly display, publicly perform, sublicense, and distribute the
73
+ Work and such Derivative Works in Source or Object form.
74
+
75
+ 3. Grant of Patent License. Subject to the terms and conditions of
76
+ this License, each Contributor hereby grants to You a perpetual,
77
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
78
+ (except as stated in this section) patent license to make, have made,
79
+ use, offer to sell, sell, import, and otherwise transfer the Work,
80
+ where such license applies only to those patent claims licensable
81
+ by such Contributor that are necessarily infringed by their
82
+ Contribution(s) alone or by combination of their Contribution(s)
83
+ with the Work to which such Contribution(s) was submitted. If You
84
+ institute patent litigation against any entity (including a
85
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
86
+ or a Contribution incorporated within the Work constitutes direct
87
+ or contributory patent infringement, then any patent licenses
88
+ granted to You under this License for that Work shall terminate
89
+ as of the date such litigation is filed.
90
+
91
+ 4. Redistribution. You may reproduce and distribute copies of the
92
+ Work or Derivative Works thereof in any medium, with or without
93
+ modifications, and in Source or Object form, provided that You
94
+ meet the following conditions:
95
+
96
+ (a) You must give any other recipients of the Work or
97
+ Derivative Works a copy of this License; and
98
+
99
+ (b) You must cause any modified files to carry prominent notices
100
+ stating that You changed the files; and
101
+
102
+ (c) You must retain, in the Source form of any Derivative Works
103
+ that You distribute, all copyright, patent, trademark, and
104
+ attribution notices from the Source form of the Work,
105
+ excluding those notices that do not pertain to any part of
106
+ the Derivative Works; and
107
+
108
+ (d) If the Work includes a "NOTICE" text file as part of its
109
+ distribution, then any Derivative Works that You distribute must
110
+ include a readable copy of the attribution notices contained
111
+ within such NOTICE file, excluding those notices that do not
112
+ pertain to any part of the Derivative Works, in at least one
113
+ of the following places: within a NOTICE text file distributed
114
+ as part of the Derivative Works; within the Source form or
115
+ documentation, if provided along with the Derivative Works; or,
116
+ within a display generated by the Derivative Works, if and
117
+ wherever such third-party notices normally appear. The contents
118
+ of the NOTICE file are for informational purposes only and
119
+ do not modify the License. You may add Your own attribution
120
+ notices within Derivative Works that You distribute, alongside
121
+ or as an addendum to the NOTICE text from the Work, provided
122
+ that such additional attribution notices cannot be construed
123
+ as modifying the License.
124
+
125
+ You may add Your own copyright statement to Your modifications and
126
+ may provide additional or different license terms and conditions
127
+ for use, reproduction, or distribution of Your modifications, or
128
+ for any such Derivative Works as a whole, provided Your use,
129
+ reproduction, and distribution of the Work otherwise complies with
130
+ the conditions stated in this License.
131
+
132
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
133
+ any Contribution intentionally submitted for inclusion in the Work
134
+ by You to the Licensor shall be under the terms and conditions of
135
+ this License, without any additional terms or conditions.
136
+ Notwithstanding the above, nothing herein shall supersede or modify
137
+ the terms of any separate license agreement you may have executed
138
+ with Licensor regarding such Contributions.
139
+
140
+ 6. Trademarks. This License does not grant permission to use the trade
141
+ names, trademarks, service marks, or product names of the Licensor,
142
+ except as required for reasonable and customary use in describing the
143
+ origin of the Work and reproducing the content of the NOTICE file.
144
+
145
+ 7. Disclaimer of Warranty. Unless required by applicable law or
146
+ agreed to in writing, Licensor provides the Work (and each
147
+ Contributor provides its Contributions) on an "AS IS" BASIS,
148
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
149
+ implied, including, without limitation, any warranties or conditions
150
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
151
+ PARTICULAR PURPOSE. You are solely responsible for determining the
152
+ appropriateness of using or redistributing the Work and assume any
153
+ risks associated with Your exercise of permissions under this License.
154
+
155
+ 8. Limitation of Liability. In no event and under no legal theory,
156
+ whether in tort (including negligence), contract, or otherwise,
157
+ unless required by applicable law (such as deliberate and grossly
158
+ negligent acts) or agreed to in writing, shall any Contributor be
159
+ liable to You for damages, including any direct, indirect, special,
160
+ incidental, or consequential damages of any character arising as a
161
+ result of this License or out of the use or inability to use the
162
+ Work (including but not limited to damages for loss of goodwill,
163
+ work stoppage, computer failure or malfunction, or any and all
164
+ other commercial damages or losses), even if such Contributor
165
+ has been advised of the possibility of such damages.
166
+
167
+ 9. Accepting Warranty or Additional Liability. While redistributing
168
+ the Work or Derivative Works thereof, You may choose to offer,
169
+ and charge a fee for, acceptance of support, warranty, indemnity,
170
+ or other liability obligations and/or rights consistent with this
171
+ License. However, in accepting such obligations, You may act only
172
+ on Your own behalf and on Your sole responsibility, not on behalf
173
+ of any other Contributor, and only if You agree to indemnify,
174
+ defend, and hold each Contributor harmless for any liability
175
+ incurred by, or claims asserted against, such Contributor by reason
176
+ of your accepting any such warranty or additional liability.
177
+
178
+ END OF TERMS AND CONDITIONS
179
+
180
+ APPENDIX: How to apply the Apache License to your work.
181
+
182
+ To apply the Apache License to your work, attach the following
183
+ boilerplate notice, with the fields enclosed by brackets "[]"
184
+ replaced with your own identifying information. (Don't include
185
+ the brackets!) The text should be enclosed in the appropriate
186
+ comment syntax for the file format. We also recommend that a
187
+ file or class name and description of purpose be included on the
188
+ same "printed page" as the copyright notice for easier
189
+ identification within third-party archives.
190
+
191
+ Copyright [yyyy] [name of copyright owner]
192
+
193
+ Licensed under the Apache License, Version 2.0 (the "License");
194
+ you may not use this file except in compliance with the License.
195
+ You may obtain a copy of the License at
196
+
197
+ http://www.apache.org/licenses/LICENSE-2.0
198
+
199
+ Unless required by applicable law or agreed to in writing, software
200
+ distributed under the License is distributed on an "AS IS" BASIS,
201
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
202
+ See the License for the specific language governing permissions and
203
+ limitations under the License.
README.md CHANGED
@@ -9,4 +9,471 @@ app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
12
+ # Whisper JAX
13
+
14
+ This repository contains optimised JAX code for OpenAI's [Whisper Model](https://arxiv.org/abs/2212.04356), largely built
15
+ on the 🤗 Hugging Face Transformers Whisper implementation. Compared to OpenAI's PyTorch code, Whisper JAX runs over **70x**
16
+ faster, making it the fastest Whisper implementation available.
17
+
18
+ The JAX code is compatible with CPU, GPU and TPU, and can be run standalone (see [Pipeline Usage](#pipeline-usage)) or
19
+ as an inference endpoint (see [Creating an Endpoint](#creating-an-endpoint)).
20
+
21
+ For a quick-start guide to running Whisper JAX on a Cloud TPU, refer to the following Kaggle notebook, where we transcribe 30 mins of audio in approx 30 sec:
22
+
23
+ [![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://www.kaggle.com/code/sgandhi99/whisper-jax-tpu)
24
+
25
+ The Whisper JAX model is also running as a demo on the Hugging Face Hub:
26
+
27
+ [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/sanchit-gandhi/whisper-jax)
28
+
29
+ ## Installation
30
+
31
+ Whisper JAX was tested using Python 3.9 and JAX version 0.4.5. Installation assumes that you already have the latest
32
+ version of the JAX package installed on your device. You can do so using the official JAX installation guide: https://github.com/google/jax#installation
33
+
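+ As a quick sanity check once JAX is installed, you can confirm which accelerator devices JAX can see before running Whisper JAX (a minimal snippet, independent of this package):
+ ```python
+ import jax
+
+ # lists the available devices, e.g. [CpuDevice(id=0)], a GPU, or eight TPU cores
+ print(jax.devices())
+ print(jax.default_backend())  # "cpu", "gpu" or "tpu"
+ ```
+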
34
+ Once the appropriate version of JAX has been installed, Whisper JAX can be installed through pip:
35
+ ```
36
+ pip install git+https://github.com/sanchit-gandhi/whisper-jax.git
37
+ ```
38
+
39
+ To update the Whisper JAX package to the latest version, simply run:
40
+ ```
41
+ pip install --upgrade --no-deps --force-reinstall git+https://github.com/sanchit-gandhi/whisper-jax.git
42
+ ```
43
+
44
+ ## Pipeline Usage
45
+
46
+ The recommended way of running Whisper JAX is through the [`FlaxWhisperPipline`](https://github.com/sanchit-gandhi/whisper-jax/blob/main/whisper_jax/pipeline.py#L57) abstraction class. This class handles all
47
+ the necessary pre- and post-processing, as well as wrapping the generate method for data parallelism across accelerator devices.
48
+
49
+ Whisper JAX makes use of JAX's [`pmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html) function for data parallelism across GPU/TPU devices. This function is _Just In Time (JIT)_
50
+ compiled the first time it is called. Thereafter, the compiled function is _cached_, so subsequent calls run extremely fast:
51
+
52
+ ```python
53
+ from whisper_jax import FlaxWhisperPipline
54
+
55
+ # instantiate pipeline
56
+ pipeline = FlaxWhisperPipline("openai/whisper-large-v2")
57
+
58
+ # JIT compile the forward call - slow, but we only do it once
59
+ text = pipeline("audio.mp3")
60
+
61
+ # use the cached function thereafter - super fast!!
62
+ text = pipeline("audio.mp3")
63
+ ```
64
+
65
+ ### Half-Precision
66
+
67
+ The model computation can be run in half-precision by passing the dtype argument when instantiating the pipeline. This will
68
+ speed up the computation quite considerably by storing intermediate tensors in half-precision. There is no change to the precision
69
+ of the model weights.
70
+
71
+ For most GPUs, the dtype should be set to `jnp.float16`. For A100 GPUs or TPUs, the dtype should be set to `jnp.bfloat16`:
72
+ ```python
73
+ from whisper_jax import FlaxWhisperPipline
74
+ import jax.numpy as jnp
75
+
76
+ # instantiate pipeline in bfloat16
77
+ pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16)
78
+ ```
79
+
80
+ ### Batching
81
+ Whisper JAX also provides the option of _batching_ a single audio input across accelerator devices. The audio is first
82
+ chunked into 30 second segments, and then chunks dispatched to the model to be transcribed in parallel. The resulting
83
+ transcriptions are stitched back together at the boundaries to give a single, uniform transcription. In practice, batching
84
+ provides a 10x speed-up compared to transcribing the audio samples sequentially, with a WER penalty of less than 1%[^1], provided the batch size is large enough.
85
+
86
+ To enable batching, pass the `batch_size` parameter when you instantiate the pipeline:
87
+
88
+ ```python
89
+ from whisper_jax import FlaxWhisperPipline
90
+
91
+ # instantiate pipeline with batching
92
+ pipeline = FlaxWhisperPipline("openai/whisper-large-v2", batch_size=16)
93
+ ```
94
+
95
+ ### Task
96
+
97
+ By default, the pipeline transcribes the audio file in the language it was spoken in. For speech translation, set the
98
+ `task` argument to `"translate"`:
99
+
100
+ ```python
101
+ # translate
102
+ text = pipeline("audio.mp3", task="translate")
103
+ ```
104
+
105
+ ### Timestamps
106
+
107
+ The [`FlaxWhisperPipline`](https://github.com/sanchit-gandhi/whisper-jax/blob/main/whisper_jax/pipeline.py#L57) also supports timestamp prediction. Note that enabling timestamps will require a second JIT compilation of the
108
+ forward call, this time including the timestamp outputs:
109
+
110
+ ```python
111
+ # transcribe and return timestamps
112
+ outputs = pipeline("audio.mp3", task="transcribe", return_timestamps=True)
113
+ text = outputs["text"] # transcription
114
+ chunks = outputs["chunks"] # transcription + timestamps
115
+ ```
116
+
117
+ ### Putting it all together
118
+ In the following code snippet, we instantiate the model in bfloat16 precision with batching enabled, and transcribe the audio file
119
+ returning the timestamps:
120
+
121
+ ```python
122
+ from whisper_jax import FlaxWhisperPipline
123
+ import jax.numpy as jnp
124
+
125
+ # instantiate pipeline with bfloat16 and enable batching
126
+ pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16, batch_size=16)
127
+
128
+ # transcribe and return timestamps
129
+ outputs = pipeline("audio.mp3", task="transcribe", return_timestamps=True)
130
+ ```
131
+
132
+ ## Model Usage
133
+
134
+ The Whisper JAX model can be used at a more granular level in much the same way as the original Hugging Face
135
+ Transformers implementation. This requires the Whisper processor to be loaded separately from the model to handle the
136
+ pre- and post-processing, and the generate function to be wrapped using `pmap` by hand:
137
+
138
+ ```python
139
+ import jax.numpy as jnp
140
+ from datasets import load_dataset
141
+ from flax.jax_utils import replicate
142
+ from flax.training.common_utils import shard
143
+ from jax import device_get, pmap
144
+ from transformers import WhisperProcessor
145
+
146
+ from whisper_jax import FlaxWhisperForConditionalGeneration
147
+
148
+ # load the processor and model
149
+ processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
150
+ model, params = FlaxWhisperForConditionalGeneration.from_pretrained(
151
+ "openai/whisper-large-v2", dtype=jnp.bfloat16, _do_init=False,
152
+ )
153
+
154
+ def generate_fn(input_features):
155
+ pred_ids = model.generate(
156
+ input_features, task="transcribe", return_timestamps=False, max_length=model.config.max_length, params=params,
157
+ )
158
+ return pred_ids.sequences
159
+
160
+ # pmap the generate function for data parallelism
161
+ p_generate = pmap(generate_fn, "input_features")
162
+ # replicate the parameters across devices
163
+ params = replicate(params)
164
+
165
+ # load a dummy sample from the LibriSpeech dataset
166
+ ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
167
+ sample = ds[0]["audio"]
168
+
169
+ # pre-process: convert the audio array to log-mel input features
170
+ input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="np").input_features
171
+ # replicate the input features across devices for DP
172
+ input_features = shard(input_features)
173
+
174
+ # run the forward pass (JIT compiled the first time it is called)
175
+ pred_ids = p_generate(input_features)
176
+ output_ids = device_get(pred_ids.reshape(-1, model.config.max_length))
177
+
178
+ # post-process: convert token ids to text string
179
+ transcription = processor.batch_decode(output_ids, skip_special_tokens=True)
180
+ ```
181
+
182
+ ## Available Models and Languages
183
+ All Whisper models on the Hugging Face Hub with Flax weights are compatible with Whisper JAX. This includes, but is not limited to,
184
+ the official OpenAI Whisper checkpoints:
185
+
186
+ | Size | Parameters | English-only | Multilingual |
187
+ |----------|------------|------------------------------------------------------|-----------------------------------------------------|
188
+ | tiny | 39 M | [✓](https://huggingface.co/openai/whisper-tiny.en) | [✓](https://huggingface.co/openai/whisper-tiny) |
189
+ | base | 74 M | [✓](https://huggingface.co/openai/whisper-base.en) | [✓](https://huggingface.co/openai/whisper-base) |
190
+ | small | 244 M | [✓](https://huggingface.co/openai/whisper-small.en) | [✓](https://huggingface.co/openai/whisper-small) |
191
+ | medium | 769 M | [✓](https://huggingface.co/openai/whisper-medium.en) | [✓](https://huggingface.co/openai/whisper-medium) |
192
+ | large | 1550 M | x | [✓](https://huggingface.co/openai/whisper-large) |
193
+ | large-v2 | 1550 M | x | [✓](https://huggingface.co/openai/whisper-large-v2) |
194
+
195
+ To use a fine-tuned Whisper checkpoint in Whisper JAX, first convert the PyTorch weights to Flax.
196
+ This is straightforward through use of the `from_pt` argument, which will convert the PyTorch state dict to a frozen Flax
197
+ parameter dictionary on the fly. You can then push the converted Flax weights to the Hub to be used directly in Flax
198
+ the next time they are required. Note that converting weights from PyTorch to Flax requires both PyTorch and Flax to be installed.
199
+
200
+ For example, to convert the fine-tuned checkpoint [`sanchit-gandhi/whisper-small-hi`](https://huggingface.co/sanchit-gandhi/whisper-small-hi) from the blog post [Fine-Tuning Whisper](https://huggingface.co/blog/fine-tune-whisper):
201
+ ```python
202
+ from whisper_jax import FlaxWhisperForConditionalGeneration, FlaxWhisperPipline
203
+ import jax.numpy as jnp
204
+
205
+ checkpoint_id = "sanchit-gandhi/whisper-small-hi"
206
+ # convert PyTorch weights to Flax
207
+ model = FlaxWhisperForConditionalGeneration.from_pretrained(checkpoint_id, from_pt=True)
208
+ # push converted weights to the Hub
209
+ model.push_to_hub(checkpoint_id)
210
+
211
+ # now we can load the Flax weights directly as required
212
+ pipeline = FlaxWhisperPipline(checkpoint_id, dtype=jnp.bfloat16, batch_size=16)
213
+ ```
214
+
215
+ ## Advanced Usage
216
+ More advanced users may wish to explore different parallelisation techniques. The Whisper JAX code is
217
+ built on top of the [T5x codebase](https://github.com/google-research/t5x), meaning it can be run with model, activation, and data parallelism using the T5x
218
+ partitioning convention. To use T5x partitioning, the logical axis rules and number of model partitions must be defined.
219
+ For more details, the user is referred to the official T5x partitioning guide: https://github.com/google-research/t5x/blob/main/docs/usage/partitioning.md
220
+
221
+ ### Pipeline
222
+ The following code snippet demonstrates how data parallelism can be achieved using the pipeline `shard_params` method in
223
+ an entirely equivalent way to `pmap`:
224
+
225
+ ```python
226
+ from whisper_jax import FlaxWhisperPipline
227
+ import jax.numpy as jnp
228
+
229
+ # 2D parameter and activation partitioning for DP
230
+ logical_axis_rules_dp = (
231
+ ("batch", "data"),
232
+ ("mlp", None),
233
+ ("heads", None),
234
+ ("vocab", None),
235
+ ("embed", None),
236
+ ("embed", None),
237
+ ("joined_kv", None),
238
+ ("kv", None),
239
+ ("length", None),
240
+ ("num_mel", None),
241
+ ("channels", None),
242
+ )
243
+
244
+ pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16, batch_size=16)
245
+ pipeline.shard_params(num_mp_partitions=1, logical_axis_rules=logical_axis_rules_dp)
246
+ ```
247
+
248
+ ### Model
249
+ It is also possible to use the Whisper JAX model with T5x partitioning by defining a T5x inference state and T5x partitioner:
250
+
251
+ ```python
252
+ import jax
253
+ import jax.numpy as jnp
254
+ from flax.core.frozen_dict import freeze
255
+ from jax.sharding import PartitionSpec as P
256
+
257
+ from whisper_jax import FlaxWhisperForConditionalGeneration, InferenceState, PjitPartitioner
258
+
259
+
260
+ # 2D parameter and activation partitioning for DP
261
+ logical_axis_rules_dp = [
262
+ ("batch", "data"),
263
+ ("mlp", None),
264
+ ("heads", None),
265
+ ("vocab", None),
266
+ ("embed", None),
267
+ ("embed", None),
268
+ ("joined_kv", None),
269
+ ("kv", None),
270
+ ("length", None),
271
+ ("num_mel", None),
272
+ ("channels", None),
273
+ ]
274
+
275
+ model, params = FlaxWhisperForConditionalGeneration.from_pretrained(
276
+ "openai/whisper-large-v2",
277
+ _do_init=False,
278
+ dtype=jnp.bfloat16,
279
+ )
280
+
281
+
282
+ def init_fn():
283
+ input_shape = (1, 80, 3000)
284
+
285
+ input_features = jnp.zeros(input_shape, dtype="f4")
286
+ input_features = input_features.at[(..., -1)].set(model.config.eos_token_id)
287
+
288
+ decoder_input_ids = jnp.zeros((input_shape[0], 1), dtype="i4")
289
+ decoder_attention_mask = jnp.ones_like(decoder_input_ids)
290
+
291
+ batch_size, sequence_length = decoder_input_ids.shape
292
+ decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
293
+
294
+ rng = jax.random.PRNGKey(0)
295
+ init_params = model.module.init(
296
+ rng,
297
+ input_features=input_features,
298
+ decoder_input_ids=decoder_input_ids,
299
+ decoder_attention_mask=decoder_attention_mask,
300
+ decoder_position_ids=decoder_position_ids,
301
+ return_dict=False,
302
+ )
303
+ return init_params
304
+
305
+
306
+ # Axis names metadata
307
+ param_axes = jax.eval_shape(init_fn)["params_axes"]
308
+
309
+ # Create InferenceState, since the partitioner expects it
310
+ state = InferenceState(
311
+ step=jnp.array(0),
312
+ params=freeze(model.params_shape_tree),
313
+ params_axes=freeze(param_axes),
314
+ flax_mutables=None,
315
+ flax_mutables_axes=param_axes,
316
+ )
317
+
318
+ # Define the pjit partitioner with 1 model partition
319
+ partitioner = PjitPartitioner(
320
+ num_partitions=1,
321
+ logical_axis_rules=logical_axis_rules_dp,
322
+ )
323
+
324
+ mesh_axes = partitioner.get_mesh_axes(state)
325
+ params_spec = mesh_axes.params
326
+
327
+ p_shard_params = partitioner.partition(model.to_bf16, (params_spec,), params_spec)
328
+
329
+
330
+ def generate(params, input_features):
331
+ output_ids = model.generate(input_features, params=params, max_length=model.config.max_length).sequences
332
+ return output_ids
333
+
334
+
335
+ p_generate = partitioner.partition(
336
+ generate,
337
+ in_axis_resources=(params_spec, P("data")),
338
+ out_axis_resources=P("data"),
339
+ )
340
+
341
+ # This will auto-magically run in mesh context
342
+ params = p_shard_params(freeze(params))
343
+
344
+ # you can now run the forward pass with:
345
+ # pred_ids = p_generate(input_features)
346
+ ```
347
+
348
+ ## Benchmarks
349
+
350
+ We compare Whisper JAX to the official [OpenAI implementation](https://github.com/openai/whisper) and the [🤗 Transformers
351
+ implementation](https://huggingface.co/docs/transformers/model_doc/whisper). We benchmark the models on audio samples of
352
+ increasing length and report the average inference time in seconds over 10 repeat runs. For all three systems, we pass a
353
+ pre-loaded audio file to the model and measure the time for the forward pass. Leaving the task of loading the audio file
354
+ to the systems adds an equal offset to all the benchmark times, so the actual time for loading **and** transcribing an
355
+ audio file will be higher than the reported numbers.
356
+
357
+ OpenAI and Transformers both run in PyTorch on GPU. Whisper JAX runs in JAX on GPU and TPU. OpenAI transcribes the audio
358
+ sequentially in the order it is spoken. Both Transformers and Whisper JAX use a batching algorithm, where chunks of audio
359
+ are batched together and transcribed in parallel (see section [Batching](#batching)).
360
+
361
+ **Table 1:** Average inference time in seconds for audio files of increasing length. GPU device is a single A100 40GB GPU.
362
+ TPU device is a single TPU v4-8.
363
+
364
+ <div align="center">
365
+
366
+ | | OpenAI | Transformers | Whisper JAX | Whisper JAX |
367
+ |-----------|---------|--------------|-------------|-------------|
368
+ | | | | | |
369
+ | Framework | PyTorch | PyTorch | JAX | JAX |
370
+ | Backend | GPU | GPU | GPU | TPU |
371
+ | | | | | |
372
+ | 1 min | 13.8 | 4.54 | 1.72 | 0.45 |
373
+ | 10 min | 108.3 | 20.2 | 9.38 | 2.01 |
374
+ | 1 hour | 1001.0 | 126.1 | 75.3 | 13.8 |
375
+ | | | | | |
376
+
377
+ </div>
378
+
379
+ ## Creating an Endpoint
380
+
381
+ The Whisper JAX model is running as a demo on the Hugging Face Hub:
382
+
383
+ [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/sanchit-gandhi/whisper-jax)
384
+
385
+ However, at peak times there may be a queue of users that limits how quickly your audio input is transcribed. In this case,
386
+ you may benefit from running the model yourself, such that you have unrestricted access to the Whisper JAX model.
387
+
388
+ If you are just interested in running the model in a standalone Python script, refer to the Kaggle notebook Whisper JAX TPU:
389
+
390
+ [![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://www.kaggle.com/code/sgandhi99/whisper-jax-tpu)
391
+
392
+ Otherwise, we provide all the necessary code for creating an inference endpoint. To obtain this code, first clone the
393
+ repository on the GPU/TPU on which you want to host the endpoint:
394
+ ```
395
+ git clone https://github.com/sanchit-gandhi/whisper-jax
396
+ ```
397
+
398
+ And then install Whisper JAX from source, with the required additional endpoint dependencies:
399
+ ```
400
+ cd whisper-jax
401
+ pip install -e .["endpoint"]
402
+ ```
403
+
404
+ We recommend that you set up an endpoint in the same zone/region as the one you are based in. This reduces the communication
405
+ time between your local machine and the remote one, which can significantly reduce the overall request time.
406
+
407
+ The Python script [`fastapi_app.py`](app/fastapi_app.py) contains the code to launch a FastAPI app with the Whisper large-v2 model.
408
+ By default, it uses a batch size of 16 and bfloat16 half-precision. You should update these parameters depending on your
409
+ GPU/TPU device (as explained in the sections on [Half-precision](#half-precision) and [Batching](#batching)).
410
+
411
+ You can launch the FastAPI app through Uvicorn using the bash script [`launch_app.sh`](app/launch_app.sh):
412
+ ```
413
+ bash launch_app.sh
414
+ ```
415
+
416
+ This will open port 8000 for the FastAPI app. To direct network requests to the FastAPI app, we use ngrok to launch a
417
+ server on the corresponding port:
418
+ ```
419
+ ngrok http --subdomain=whisper-jax 8000
420
+ ```
421
+
422
+ We can now send JSON requests to our endpoint through ngrok. The function `transcribe_audio` loads an audio file, encodes it
423
+ in bytes, sends it to our endpoint, and returns the transcription:
424
+
425
+ ```python
426
+ import base64
427
+ from transformers.pipelines.audio_utils import ffmpeg_read
428
+ import requests
429
+
430
+ API_URL = "https://whisper-jax.ngrok.io/generate/" # make sure this URL matches your ngrok subdomain
431
+
432
+
433
+ def query(payload):
434
+ """Send json payload to ngrok API URL and return response."""
435
+ response = requests.post(API_URL, json=payload)
436
+ return response.json(), response.status_code
437
+
438
+
439
+ def transcribe_audio(audio_file, task="transcribe", return_timestamps=False):
440
+ with open(audio_file, "rb") as f:
441
+ inputs = f.read()
442
+ inputs = ffmpeg_read(inputs, sampling_rate=16000)
443
+ # encode to bytes to make json compatible
444
+ inputs = {"array": base64.b64encode(inputs.tobytes()).decode(), "sampling_rate": 16000}
445
+ # format as a json payload and send query
446
+ payload = {"inputs": inputs, "task": task, "return_timestamps": return_timestamps}
447
+ data, status_code = query(payload)
448
+
449
+ if status_code == 200:
450
+ output = {"text": data["text"], "chunks": data.get("chunks", None)}
451
+ else:
452
+ output = data["detail"]
453
+
454
+ return output
455
+
456
+ # transcribe an audio file using our endpoint
457
+ output = transcribe_audio("audio.mp3")
458
+ ```
459
+
460
+ Note that this code snippet sends a base64 byte encoding of the audio file to the remote machine over [`requests`](https://requests.readthedocs.io).
461
+ In some cases, transferring the audio request from the local machine to the remote one can take longer than actually
462
+ transcribing it. Therefore, you may wish to explore more efficient methods of sending requests, such as parallel
463
+ requests/transcription (see the function `transcribe_chunked_audio` in [app.py](app/app.py)), as sketched below.
464
+
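+ As a rough illustration of the idea (a sketch, not the repository's implementation, which lives in `transcribe_chunked_audio`), the audio could be split into fixed 30 second chunks on the client and the requests issued concurrently with a thread pool. Note that this naive version transcribes each chunk independently, without the overlap-and-stitch logic used by the app:
+
+ ```python
+ import base64
+ from concurrent.futures import ThreadPoolExecutor
+
+ import requests
+ from transformers.pipelines.audio_utils import ffmpeg_read
+
+ API_URL = "https://whisper-jax.ngrok.io/generate/"  # make sure this URL matches your ngrok subdomain
+ SAMPLING_RATE = 16000
+ CHUNK_LENGTH_S = 30  # length of audio sent per request
+
+
+ def query_chunk(chunk, task="transcribe"):
+     # base64-encode the raw audio chunk to make it json compatible
+     inputs = {"array": base64.b64encode(chunk.tobytes()).decode(), "sampling_rate": SAMPLING_RATE}
+     payload = {"inputs": inputs, "task": task, "return_timestamps": False}
+     response = requests.post(API_URL, json=payload)
+     return response.json()["text"]
+
+
+ def transcribe_parallel(audio_file, task="transcribe", max_workers=8):
+     with open(audio_file, "rb") as f:
+         audio = ffmpeg_read(f.read(), sampling_rate=SAMPLING_RATE)
+     # naive split into non-overlapping 30 second chunks
+     chunk_len = CHUNK_LENGTH_S * SAMPLING_RATE
+     chunks = [audio[i : i + chunk_len] for i in range(0, len(audio), chunk_len)]
+     # send the chunk requests in parallel and join the transcriptions in order
+     with ThreadPoolExecutor(max_workers=max_workers) as executor:
+         texts = list(executor.map(lambda chunk: query_chunk(chunk, task), chunks))
+     return " ".join(texts)
+
+
+ # transcription = transcribe_parallel("audio.mp3")
+ ```
+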
465
+ Finally, we can create a Gradio demo for the frontend, the code for which resides in [`app.py`](app/app.py). You can launch this
466
+ application by providing the ngrok subdomain:
467
+ ```
468
+ API_URL=https://whisper-jax.ngrok.io/generate/ API_URL_FROM_FEATURES=https://whisper-jax.ngrok.io/generate_from_features/ python app.py
469
+ ```
470
+
471
+ This will launch a Gradio demo with the same interface as the official Whisper JAX demo.
472
+
473
+ ## Acknowledgements
474
+
475
+ * 🤗 Hugging Face Transformers for the base Whisper implementation, particularly to [andyehrenberg](https://github.com/andyehrenberg) for the [Flax Whisper PR](https://github.com/huggingface/transformers/pull/20479) and [ArthurZucker](https://github.com/ArthurZucker) for the batching algorithm
476
+ * Gradio for their easy-to-use package for building ML demos, and [pcuenca](https://github.com/pcuenca) for the help in hooking the demo up to the TPU
477
+ * Google's [TPU Research Cloud (TRC)](https://sites.research.google/trc/about/) programme for Cloud TPUs
478
+
479
+ [^1]: See WER results from Colab: https://colab.research.google.com/drive/1rS1L4YSJqKUH_3YxIQHBI982zso23wor?usp=sharing
app.py ADDED
@@ -0,0 +1,251 @@
1
+ import base64
2
+ import math
3
+ import os
4
+ import time
5
+ from multiprocessing import Pool
6
+
7
+ import gradio as gr
8
+ import numpy as np
9
+ import pytube
10
+ import requests
11
+ from processing_whisper import WhisperPrePostProcessor
12
+ from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
13
+ from transformers.pipelines.audio_utils import ffmpeg_read
14
+
15
+
16
+ title = "Whisper JAX: The Fastest Whisper API ⚡️"
17
+
18
+ description = """Whisper JAX is an optimised implementation of the [Whisper model](https://huggingface.co/openai/whisper-large-v2) by OpenAI. It runs on JAX with a TPU v4-8 in the backend. Compared to PyTorch on an A100 GPU, it is over [**70x faster**](https://github.com/sanchit-gandhi/whisper-jax#benchmarks), making it the fastest Whisper API available.
19
+
20
+ Note that at peak times, you may find yourself in the queue for this demo. When you submit a request, your queue position will be shown in the top right-hand side of the demo pane. Once you reach the front of the queue, your audio file will be transcribed, with the progress displayed through a progress bar.
21
+
22
+ To skip the queue, you may wish to create your own inference endpoint, details for which can be found in the [Whisper JAX repository](https://github.com/sanchit-gandhi/whisper-jax#creating-an-endpoint).
23
+ """
24
+
25
+ article = "Whisper large-v2 model by OpenAI. Backend running JAX on a TPU v4-8 through the generous support of the [TRC](https://sites.research.google/trc/about/) programme. Whisper JAX [code](https://github.com/sanchit-gandhi/whisper-jax) and Gradio demo by 🤗 Hugging Face."
26
+
27
+ API_URL = os.getenv("API_URL")
28
+ API_URL_FROM_FEATURES = os.getenv("API_URL_FROM_FEATURES")
29
+ language_names = sorted(TO_LANGUAGE_CODE.keys())
30
+ CHUNK_LENGTH_S = 30
31
+ BATCH_SIZE = 16
32
+ NUM_PROC = 16
33
+ FILE_LIMIT_MB = 1000
34
+
35
+
36
+ def query(payload):
37
+ response = requests.post(API_URL, json=payload)
38
+ return response.json(), response.status_code
39
+
40
+
41
+ def inference(inputs, task=None, return_timestamps=False):
42
+ payload = {"inputs": inputs, "task": task, "return_timestamps": return_timestamps}
43
+
44
+ data, status_code = query(payload)
45
+
46
+ if status_code != 200:
47
+ # error with our request - return the details to the user
48
+ raise gr.Error(data["detail"])
49
+
50
+ text = data["text"]
51
+ timestamps = data.get("chunks")
52
+ if timestamps is not None:
53
+ timestamps = [
54
+ f"[{format_timestamp(chunk['timestamp'][0])} -> {format_timestamp(chunk['timestamp'][1])}] {chunk['text']}"
55
+ for chunk in timestamps
56
+ ]
57
+ text = "\n".join(str(feature) for feature in timestamps)
58
+ return text
59
+
60
+
61
+ def chunked_query(payload):
62
+ response = requests.post(API_URL_FROM_FEATURES, json=payload)
63
+ return response.json(), response.status_code
64
+
65
+
66
+ def forward(batch, task=None, return_timestamps=False):
67
+ feature_shape = batch["input_features"].shape
68
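+ # base64-encode the input features so the batch is json serialisable; the original array shape is passed alongside as feature_shape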
+ batch["input_features"] = base64.b64encode(batch["input_features"].tobytes()).decode()
69
+ outputs, status_code = chunked_query(
70
+ {"batch": batch, "task": task, "return_timestamps": return_timestamps, "feature_shape": feature_shape}
71
+ )
72
+ if status_code != 200:
73
+ # error with our request - return the details to the user
74
+ raise gr.Error(outputs["detail"])
75
+ outputs["tokens"] = np.asarray(outputs["tokens"])
76
+ return outputs
77
+
78
+
79
+ def identity(batch):
80
+ return batch
81
+
82
+
83
+ # Copied from https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/utils.py#L50
84
+ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = "."):
85
+ if seconds is not None:
86
+ milliseconds = round(seconds * 1000.0)
87
+
88
+ hours = milliseconds // 3_600_000
89
+ milliseconds -= hours * 3_600_000
90
+
91
+ minutes = milliseconds // 60_000
92
+ milliseconds -= minutes * 60_000
93
+
94
+ seconds = milliseconds // 1_000
95
+ milliseconds -= seconds * 1_000
96
+
97
+ hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
98
+ return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
99
+ else:
100
+ # we have a malformed timestamp so just return it as is
101
+ return seconds
102
+
103
+
104
+ if __name__ == "__main__":
105
+ processor = WhisperPrePostProcessor.from_pretrained("openai/whisper-large-v2")
106
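+ # chunking parameters: each 30s chunk shares a stride of chunk_length/6 with its neighbours on either side, so consecutive chunks start `step` samples apart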
+ stride_length_s = CHUNK_LENGTH_S / 6
107
+ chunk_len = round(CHUNK_LENGTH_S * processor.feature_extractor.sampling_rate)
108
+ stride_left = stride_right = round(stride_length_s * processor.feature_extractor.sampling_rate)
109
+ step = chunk_len - stride_left - stride_right
110
+ pool = Pool(NUM_PROC)
111
+
112
+ def tqdm_generate(inputs: dict, task: str, return_timestamps: bool, progress: gr.Progress):
113
+ inputs_len = inputs["array"].shape[0]
114
+ all_chunk_start_idx = np.arange(0, inputs_len, step)
115
+ num_samples = len(all_chunk_start_idx)
116
+ num_batches = math.ceil(num_samples / BATCH_SIZE)
117
+ dummy_batches = list(
118
+ range(num_batches)
119
+ ) # Gradio progress bar not compatible with generator, see https://github.com/gradio-app/gradio/issues/3841
120
+
121
+ dataloader = processor.preprocess_batch(inputs, chunk_length_s=CHUNK_LENGTH_S, batch_size=BATCH_SIZE)
122
+ progress(0, desc="Pre-processing audio file...")
123
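+ # pool.map consumes the whole generator up-front, returning a list of pre-processed batches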
+ dataloader = pool.map(identity, dataloader)
124
+
125
+ model_outputs = []
126
+ start_time = time.time()
127
+ # iterate over our chunked audio samples
128
+ for batch, _ in zip(dataloader, progress.tqdm(dummy_batches, desc="Transcribing...")):
129
+ model_outputs.append(forward(batch, task=task, return_timestamps=return_timestamps))
130
+ runtime = time.time() - start_time
131
+
132
+ post_processed = processor.postprocess(model_outputs, return_timestamps=return_timestamps)
133
+ text = post_processed["text"]
134
+ timestamps = post_processed.get("chunks")
135
+ if timestamps is not None:
136
+ timestamps = [
137
+ f"[{format_timestamp(chunk['timestamp'][0])} -> {format_timestamp(chunk['timestamp'][1])}] {chunk['text']}"
138
+ for chunk in timestamps
139
+ ]
140
+ text = "\n".join(str(feature) for feature in timestamps)
141
+ return text, runtime
142
+
143
+ def transcribe_chunked_audio(inputs, task, return_timestamps, progress=gr.Progress()):
144
+ progress(0, desc="Loading audio file...")
145
+ if inputs is None:
146
+ raise gr.Error("No audio file submitted! Please upload an audio file before submitting your request.")
147
+ file_size_mb = os.stat(inputs).st_size / (1024 * 1024)
148
+ if file_size_mb > FILE_LIMIT_MB:
149
+ raise gr.Error(
150
+ f"File size exceeds file size limit. Got file of size {file_size_mb:.2f}MB for a limit of {FILE_LIMIT_MB}MB."
151
+ )
152
+
153
+ with open(inputs, "rb") as f:
154
+ inputs = f.read()
155
+
156
+ inputs = ffmpeg_read(inputs, processor.feature_extractor.sampling_rate)
157
+ inputs = {"array": inputs, "sampling_rate": processor.feature_extractor.sampling_rate}
158
+ text, runtime = tqdm_generate(inputs, task=task, return_timestamps=return_timestamps, progress=progress)
159
+ return text, runtime
160
+
161
+ def _return_yt_html_embed(yt_url):
162
+ video_id = yt_url.split("?v=")[-1]
163
+ HTML_str = (
164
+ f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
165
+ " </center>"
166
+ )
167
+ return HTML_str
168
+
169
+ def transcribe_youtube(yt_url, task, return_timestamps, progress=gr.Progress(), max_filesize=75.0):
170
+ progress(0, desc="Loading audio file...")
171
+ html_embed_str = _return_yt_html_embed(yt_url)
172
+ try:
173
+ yt = pytube.YouTube(yt_url)
174
+ stream = yt.streams.filter(only_audio=True)[0]
175
+ except KeyError:
176
+ raise gr.Error("An error occurred while loading the YouTube video. Please try again.")
177
+
178
+ if stream.filesize_mb > max_filesize:
179
+ raise gr.Error(f"Maximum YouTube file size is {max_filesize}MB, got {stream.filesize_mb:.2f}MB.")
180
+
181
+ stream.download(filename="audio.mp3")
182
+
183
+ with open("audio.mp3", "rb") as f:
184
+ inputs = f.read()
185
+
186
+ inputs = ffmpeg_read(inputs, processor.feature_extractor.sampling_rate)
187
+ inputs = {"array": inputs, "sampling_rate": processor.feature_extractor.sampling_rate}
188
+ text, runtime = tqdm_generate(inputs, task=task, return_timestamps=return_timestamps, progress=progress)
189
+ return html_embed_str, text, runtime
190
+
191
+ microphone_chunked = gr.Interface(
192
+ fn=transcribe_chunked_audio,
193
+ inputs=[
194
+ gr.inputs.Audio(source="microphone", optional=True, type="filepath"),
195
+ gr.inputs.Radio(["transcribe", "translate"], label="Task", default="transcribe"),
196
+ gr.inputs.Checkbox(default=False, label="Return timestamps"),
197
+ ],
198
+ outputs=[
199
+ gr.outputs.Textbox(label="Transcription").style(show_copy_button=True),
200
+ gr.outputs.Textbox(label="Transcription Time (s)"),
201
+ ],
202
+ allow_flagging="never",
203
+ title=title,
204
+ description=description,
205
+ article=article,
206
+ )
207
+
208
+ audio_chunked = gr.Interface(
209
+ fn=transcribe_chunked_audio,
210
+ inputs=[
211
+ gr.inputs.Audio(source="upload", optional=True, label="Audio file", type="filepath"),
212
+ gr.inputs.Radio(["transcribe", "translate"], label="Task", default="transcribe"),
213
+ gr.inputs.Checkbox(default=False, label="Return timestamps"),
214
+ ],
215
+ outputs=[
216
+ gr.outputs.Textbox(label="Transcription").style(show_copy_button=True),
217
+ gr.outputs.Textbox(label="Transcription Time (s)"),
218
+ ],
219
+ allow_flagging="never",
220
+ title=title,
221
+ description=description,
222
+ article=article,
223
+ )
224
+
225
+ youtube = gr.Interface(
226
+ fn=transcribe_youtube,
227
+ inputs=[
228
+ gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"),
229
+ gr.inputs.Radio(["transcribe", "translate"], label="Task", default="transcribe"),
230
+ gr.inputs.Checkbox(default=False, label="Return timestamps"),
231
+ ],
232
+ outputs=[
233
+ gr.outputs.HTML(label="Video"),
234
+ gr.outputs.Textbox(label="Transcription").style(show_copy_button=True),
235
+ gr.outputs.Textbox(label="Transcription Time (s)"),
236
+ ],
237
+ allow_flagging="never",
238
+ title=title,
239
+ examples=[["https://www.youtube.com/watch?v=m8u-18Q0s7I", "transcribe", False]],
240
+ cache_examples=False,
241
+ description=description,
242
+ article=article,
243
+ )
244
+
245
+ demo = gr.Blocks()
246
+
247
+ with demo:
248
+ gr.TabbedInterface([microphone_chunked, audio_chunked, youtube], ["Microphone", "Audio File", "YouTube"])
249
+
250
+ demo.queue(concurrency_count=3, max_size=5)
251
+ demo.launch(show_api=False, max_threads=10)