sanchit-gandhi HF staff committed on
Commit
9cedb26
1 Parent(s): 04d64f8

Create app.py

Files changed (1)
  1. app.py +309 -0
app.py ADDED
@@ -0,0 +1,309 @@
import os
import tempfile
import time

import gradio as gr
import numpy as np
import torch
import yt_dlp as youtube_dl
from gradio_client import Client
from pyannote.audio import Pipeline
from transformers.pipelines.audio_utils import ffmpeg_read


YT_LENGTH_LIMIT_S = 36000  # limit YouTube files to 10 hours
SAMPLING_RATE = 16000

API_URL = "https://sanchit-gandhi-whisper-jax.hf.space/"

# set up the Gradio client
client = Client(API_URL)

# set up the diarization pipeline
diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=True)
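# note: "pyannote/speaker-diarization" is a gated checkpoint on the Hub, so use_auth_token=True assumes
# its user conditions have been accepted and a valid Hugging Face token is available locally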


def format_string(timestamp):
    """
    Reformat a timestamp string from (HH:)MM:SS to float seconds. Note that the hours column
    is optional, and is prepended within the function if it is not provided in the input.

    Args:
        timestamp (str):
            Timestamp in string format, either MM:SS or HH:MM:SS.
    Returns:
        seconds (float):
            Total seconds corresponding to the input timestamp.
    """
    split_time = timestamp.split(":")
    split_time = [float(sub_time) for sub_time in split_time]

    if len(split_time) == 2:
        split_time.insert(0, 0)

    seconds = split_time[0] * 3600 + split_time[1] * 60 + split_time[2]
    return seconds
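# example: format_string("01:02:03") returns 1 * 3600 + 2 * 60 + 3 = 3723.0; format_string("02:03") returns 123.0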


# Adapted from https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/utils.py#L50
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = "."):
    """
    Reformat a timestamp from a float of seconds to a string in the format (HH:)MM:SS.mmm. Note that the
    hours column is optional, and is prepended in the function if the number of hours > 0.

    Args:
        seconds (float):
            Total seconds corresponding to the input timestamp.
    Returns:
        timestamp (str):
            Timestamp in string format, either MM:SS.mmm or HH:MM:SS.mmm.
    """
    if seconds is not None:
        milliseconds = round(seconds * 1000.0)

        hours = milliseconds // 3_600_000
        milliseconds -= hours * 3_600_000

        minutes = milliseconds // 60_000
        milliseconds -= minutes * 60_000

        seconds = milliseconds // 1_000
        milliseconds -= seconds * 1_000

        hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
        return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
    else:
        # we have a malformed timestamp so just return it as is
        return seconds
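# example: format_timestamp(3723.0) returns "01:02:03.000"; format_timestamp(63.5) returns "01:03.500"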


def format_as_transcription(raw_segments):
    return "\n".join(
        [
            f"{chunk['speaker']} [{format_timestamp(chunk['timestamp'][0])} -> {format_timestamp(chunk['timestamp'][1])}] {chunk['text']}"
            for chunk in raw_segments
        ]
    )


def _return_yt_html_embed(yt_url):
    video_id = yt_url.split("?v=")[-1]
    HTML_str = (
        f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
        " </center>"
    )
    return HTML_str


def download_yt_audio(yt_url, filename):
    info_loader = youtube_dl.YoutubeDL()
    try:
        info = info_loader.extract_info(yt_url, download=False)
    except youtube_dl.utils.DownloadError as err:
        raise gr.Error(str(err))

    file_length = info["duration_string"]
    file_length_s = format_string(file_length)

    if file_length_s > YT_LENGTH_LIMIT_S:
        yt_length_limit_hms = time.strftime("%H:%M:%S", time.gmtime(YT_LENGTH_LIMIT_S))
        file_length_hms = time.strftime("%H:%M:%S", time.gmtime(file_length_s))
        raise gr.Error(f"Maximum YouTube length is {yt_length_limit_hms}, but got a video of length {file_length_hms}.")

    ydl_opts = {"outtmpl": filename, "format": "worstvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best"}
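    # format selector: take the smallest mp4 video stream plus the best m4a audio to keep the download
    # light while preserving audio quality, falling back to the best single mp4 / best available format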
    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        try:
            ydl.download([yt_url])
        except youtube_dl.utils.DownloadError as err:
            raise gr.Error(str(err))


def align(transcription, segments, group_by_speaker=True):
    transcription_split = transcription.split("\n")
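    # each line is expected to look like "[<start> -> <end>] <text>", with <start>/<end> in (HH:)MM:SS(.mmm)
    # format, which is what the Whisper JAX endpoint returns when timestamps are enabled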

    # re-format transcription from string to List[Dict]
    transcript = []
    for chunk in transcription_split:
        start_end, text = chunk[1:].split("] ")
        start, end = start_end.split("->")

        transcript.append({"timestamp": (format_string(start), format_string(end)), "text": text})

    # diarizer output may contain consecutive segments from the same speaker (e.g. {(0 -> 1, speaker_1), (1 -> 1.5, speaker_1), ...})
    # we combine these segments to give overall timestamps for each speaker's turn (e.g. {(0 -> 1.5, speaker_1), ...})
    new_segments = []
    prev_segment = cur_segment = segments[0]

    for i in range(1, len(segments)):
        cur_segment = segments[i]

        # check if we have changed speaker ("label")
        if cur_segment["label"] != prev_segment["label"]:
            # add the start/end times for the super-segment to the new list
            new_segments.append(
                {
                    "segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["start"]},
                    "speaker": prev_segment["label"],
                }
            )
            prev_segment = segments[i]

    # add the final segment for the last speaker's turn
    new_segments.append(
        {
            "segment": {"start": prev_segment["segment"]["start"], "end": cur_segment["segment"]["end"]},
            "speaker": prev_segment["label"],
        }
    )

    # get the end timestamps for each chunk from the ASR output
    end_timestamps = np.array([chunk["timestamp"][-1] for chunk in transcript])
    segmented_preds = []

    # align the diarizer timestamps and the ASR timestamps
    for segment in new_segments:
        # get the diarizer end timestamp
        end_time = segment["segment"]["end"]
        # find the ASR end timestamp that is closest to the diarizer's end timestamp and cut the transcript to here
        upto_idx = np.argmin(np.abs(end_timestamps - end_time))

        if group_by_speaker:
            segmented_preds.append(
                {
                    "speaker": segment["speaker"],
                    "text": "".join([chunk["text"] for chunk in transcript[: upto_idx + 1]]),
                    "timestamp": (transcript[0]["timestamp"][0], transcript[upto_idx]["timestamp"][1]),
                }
            )
        else:
            for i in range(upto_idx + 1):
                segmented_preds.append({"speaker": segment["speaker"], **transcript[i]})

        # crop the transcripts and timestamp lists according to the latest timestamp (for faster argmin)
        transcript = transcript[upto_idx + 1 :]
        end_timestamps = end_timestamps[upto_idx + 1 :]

    # final post-processing
    transcription = format_as_transcription(segmented_preds)
    return transcription


def transcribe(audio_path, group_by_speaker=True):
    # run Whisper JAX asynchronously using Gradio client (endpoint)
    job = client.submit(
        audio_path,
        "transcribe",
        True,
        api_name="/predict_1",
    )

    # run diarization while we wait for Whisper JAX
    diarization = diarization_pipeline(audio_path)
    segments = diarization.for_json()["content"]

    # only fetch the transcription result after performing diarization
    transcription, _ = job.result()

    # align the ASR transcriptions and diarization timestamps
    transcription = align(transcription, segments, group_by_speaker=group_by_speaker)

    return transcription


def transcribe_yt(yt_url, group_by_speaker=True):
    # run Whisper JAX asynchronously using Gradio client (endpoint)
    job = client.submit(
        yt_url,
        "transcribe",
        True,
        api_name="/predict_2",
    )

    html_embed_str = _return_yt_html_embed(yt_url)
    with tempfile.TemporaryDirectory() as tmpdirname:
        filepath = os.path.join(tmpdirname, "video.mp4")
        download_yt_audio(yt_url, filepath)

        with open(filepath, "rb") as f:
            inputs = f.read()

    inputs = ffmpeg_read(inputs, SAMPLING_RATE)
    inputs = torch.from_numpy(inputs).float()
    inputs = inputs.unsqueeze(0)
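    # pyannote accepts in-memory audio as a dict with a (channel, time) waveform tensor and its sample rate,
    # hence the unsqueeze above to add the channel dimension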

    diarization = diarization_pipeline(
        {"waveform": inputs, "sample_rate": SAMPLING_RATE},
    )
    segments = diarization.for_json()["content"]

    # only fetch the transcription result after performing diarization
    transcription, _ = job.result()

    # align the ASR transcriptions and diarization timestamps
    transcription = align(transcription, segments, group_by_speaker=group_by_speaker)

    return html_embed_str, transcription


title = "Whisper JAX + Speaker Diarization ⚡️"

description = """Combine the speed of Whisper JAX with pyannote speaker diarization to transcribe meetings in super-fast time.
"""

article = "Whisper large-v2 model by OpenAI. Speaker diarization model by pyannote. Whisper JAX backend running JAX on a TPU v4-8 through the generous support of the [TRC](https://sites.research.google/trc/about/) programme. Whisper JAX [code](https://github.com/sanchit-gandhi/whisper-jax) and Gradio demo by 🤗 Hugging Face."

microphone = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.inputs.Audio(source="microphone", optional=True, type="filepath"),
        gr.inputs.Checkbox(default=True, label="Group by speaker"),
    ],
    outputs=[
        gr.outputs.Textbox(label="Transcription").style(show_copy_button=True),
    ],
    allow_flagging="never",
    title=title,
    description=description,
    article=article,
)

audio_file = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.inputs.Audio(source="upload", optional=True, label="Audio file", type="filepath"),
        gr.inputs.Checkbox(default=True, label="Group by speaker"),
    ],
    outputs=[
        gr.outputs.Textbox(label="Transcription").style(show_copy_button=True),
    ],
    allow_flagging="never",
    title=title,
    description=description,
    article=article,
)

youtube = gr.Interface(
    fn=transcribe_yt,
    inputs=[
        gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"),
        gr.inputs.Checkbox(default=True, label="Group by speaker"),
    ],
    outputs=[
        gr.outputs.HTML(label="Video"),
        gr.outputs.Textbox(label="Transcription").style(show_copy_button=True),
    ],
    allow_flagging="never",
    title=title,
    examples=[["https://www.youtube.com/watch?v=m8u-18Q0s7I", True]],
    cache_examples=False,
    description=description,
    article=article,
)

demo = gr.Blocks()

with demo:
    gr.TabbedInterface([microphone, audio_file, youtube], ["Microphone", "Audio File", "YouTube"])

demo.queue(concurrency_count=1, max_size=5)
demo.launch()