Matthijs Hollemans
let's go!
b1828a3
raw
history blame
4.66 kB
import gradio as gr
import librosa
import numpy as np
import moviepy.editor as mpy
from PIL import Image, ImageDraw, ImageFont
from transformers import pipeline
# --- Video rendering configuration -------------------------------------------
# Frame rate passed to write_videofile() when the output video is encoded.
fps = 25
# Upper bound on how much audio is processed. Also used by make_frame() as the
# end time of a chunk whose end timestamp is None.
max_duration = 60 # seconds
# video_width is the reference width for word wrapping in make_frame();
# video_height is not referenced by the code visible in this file.
# NOTE(review): frames are rendered from a copy of background_image, so these
# values presumably have to match the size of background.png -- confirm.
video_width = 640
video_height = 480
# Text layout: the first word starts at (margin_left, margin_top); a word that
# would reach video_width - margin_right wraps to the next line.
margin_left = 20
margin_right = 20
margin_top = 20
# Vertical advance per wrapped line; also the offset below the text origin at
# which the highlight underline is drawn.
line_height = 44
# Assets loaded once at import time (paths are relative to the working directory).
background_image = Image.open("background.png")
font = ImageFont.truetype("Lato-Regular.ttf", 40)
# Fill colors for regular words and for the word currently being highlighted.
text_color = (255, 200, 200)
highlight_color = (255, 255, 255)
# --- Speech recognition model ------------------------------------------------
# Checkpoint to load; the alternatives are kept commented out for quick switching.
# checkpoint = "openai/whisper-tiny"
# checkpoint = "openai/whisper-base"
checkpoint = "openai/whisper-small"
# Built once at import time and reused by every predict() call.
pipe = pipeline(model=checkpoint)
# TODO: no longer need to set these manually once the models have been updated on the Hub
# The active assignment below has to correspond to the checkpoint selected above.
# NOTE(review): the entries look like [layer, head] pairs selecting the
# attention heads used for word alignment -- confirm against the Transformers
# Whisper documentation.
# whisper-base
# pipe.model.config.alignment_heads = [[3, 1], [4, 2], [4, 3], [4, 7], [5, 1], [5, 2], [5, 4], [5, 6]]
# whisper-small
pipe.model.config.alignment_heads = [[5, 3], [5, 9], [8, 0], [8, 4], [8, 7], [8, 8], [9, 0], [9, 7], [9, 9], [10, 5]]
# Chunks from the most recent transcription: written by predict(), read by
# make_frame(). make_frame() only receives the frame time, so the data is
# shared through this module-level variable.
chunks = []
def _active_word_index(word_times, t, default_end):
    """Return the index of the last word whose time span contains ``t``.

    Args:
        word_times: Sequence of ``(start, end)`` pairs, one per word.
        t: Time of the frame being drawn, in seconds.
        default_end: End time to assume for a word whose end is ``None``.

    Returns:
        Index of the latest matching word, or ``None`` if no word is active.
    """
    active = None
    for index, (start, end) in enumerate(word_times):
        if start is None:
            # Without a start time the word can never be considered active.
            continue
        if end is None:
            end = default_end
        if start <= t <= end:
            active = index
    return active


def make_frame(t):
    """Draw the subtitle frame for time ``t`` and return it as a NumPy array.

    The first entry of the module-level ``chunks`` list whose time span
    contains ``t`` is drawn word by word onto a copy of ``background_image``,
    wrapping to a new line when a word would reach the right margin. The word
    being spoken is drawn in ``highlight_color`` and underlined; every other
    word uses ``text_color``. If no chunk covers ``t``, the plain background
    is returned.

    Args:
        t: Time of the requested frame, in seconds.

    Returns:
        ``np.array`` of the rendered image.
    """
    # TODO speed optimization: could cache the last image returned and if the
    # active chunk and active word didn't change, use that last image instead
    # of drawing the exact same thing again

    image = background_image.copy()
    draw = ImageDraw.Draw(image)

    # for debugging: draw frame time
    #draw.text((20, 20), str(t), fill=text_color, font=font)

    space_length = draw.textlength(" ", font)
    x = margin_left
    y = margin_top

    for chunk in chunks:
        chunk_start = chunk["timestamp"][0]
        chunk_end = chunk["timestamp"][1]
        if chunk_end is None:
            chunk_end = max_duration
        if not chunk_start <= t <= chunk_end:
            continue

        entries = chunk["words"]

        # A word's ending timestamp can lie too far in the future (seen with
        # "desires" in the Henry V example), which would keep it highlighted
        # together with the words that follow. Highlight only the latest
        # active word. A word without an end time is treated as lasting until
        # the end of its chunk instead of raising on the comparison.
        active = _active_word_index(
            [entry["timestamp"] for entry in entries], t, chunk_end
        )

        for index, entry in enumerate(entries):
            word = entry["text"]
            # Width of the word alone, measured with a trailing space so the
            # spacing matches what is added after each word below.
            word_length = draw.textlength(word + " ", font) - space_length
            if x + word_length >= video_width - margin_right:
                x = margin_left
                y += line_height

            if index == active:
                color = highlight_color
                # Underline the active word just below its line.
                draw.rectangle(
                    [x, y + line_height, x + word_length, y + line_height + 4],
                    fill=color,
                )
            else:
                color = text_color

            draw.text((x, y), word, fill=color, font=font)
            x += word_length + space_length

        # Only the first chunk that covers ``t`` is drawn.
        break

    return np.array(image)
def predict(audio_path):
    """Transcribe an audio file and render a video with highlighted subtitles.

    At most ``max_duration`` seconds of audio are used. The word-level
    timestamps returned by the Whisper pipeline are stored in the module-level
    ``chunks`` list, which ``make_frame`` reads while the video is rendered.

    Args:
        audio_path: Path of the audio file to process.

    Returns:
        Path of the written video file (``"my_video.mp4"``).

    Raises:
        gr.Error: If ``audio_path`` is ``None`` (nothing was uploaded).
    """
    global chunks

    if audio_path is None:
        raise gr.Error("Please upload an audio file first.")

    # Decode directly at the sampling rate the model expects and stop reading
    # after max_duration seconds, instead of decoding the whole file at
    # librosa's default rate, truncating it and resampling a second time.
    target_sr = pipe.feature_extractor.sampling_rate
    audio_inputs, sr = librosa.load(
        audio_path, sr=target_sr, mono=True, duration=max_duration
    )
    duration = min(max_duration, librosa.get_duration(y=audio_inputs, sr=sr))

    # Run Whisper to get word-level timestamps.
    output = pipe(audio_inputs, chunk_length_s=30, stride_length_s=[4, 2], return_timestamps="word")
    chunks = output["chunks"]
    print(chunks)

    # Create the video.
    audio_clip = mpy.AudioFileClip(audio_path).set_duration(duration)
    try:
        clip = mpy.VideoClip(make_frame, duration=duration).set_audio(audio_clip)
        clip.write_videofile("my_video.mp4", fps=fps, codec="libx264", audio_codec="aac")
    finally:
        # Release the audio reader held by the clip, even if encoding failed.
        audio_clip.close()

    return "my_video.mp4"
# Texts shown on the demo page. `description` and `article` contain HTML.
title = "Word-level timestamps with Whisper"
description = """
This demo shows Whisper <b>word-level timestamps</b> in action using Hugging Face Transformers. It creates a video showing subtitled audio with the current word highlighted.
This demo uses the <b>openai/whisper-small</b> checkpoint. Since it's only a demo, the output is limited to the first 60 seconds of audio.
"""
# Credits block. Every <p> and <li> element is explicitly closed so the markup
# is well-formed.
article = """
<div style='margin:20px auto;'>
<p>Credits:</p>
<ul>
<li>Shakespeare's "Henry V" speech from <a href="https://freesound.org/people/acclivity/sounds/24096/">acclivity</a> (CC BY-NC 4.0 license)</li>
<li>Lato font by Łukasz Dziedzic (licensed under Open Font License)</li>
<li>Whisper model by OpenAI</li>
</ul>
</div>
"""
# Sample inputs offered by the interface; paths are relative to the working directory.
examples = [
    "examples/henry5.wav",
]
# Build the demo UI: one uploaded audio file in, one rendered video out.
audio_input = gr.Audio(label="Upload Audio", source="upload", type="filepath")
video_output = gr.Video(label="Output Video")

demo = gr.Interface(
    fn=predict,
    inputs=[audio_input],
    outputs=[video_output],
    title=title,
    description=description,
    article=article,
    examples=examples,
)

# Start serving as soon as the module is run.
demo.launch()