# makeavid-sd-jax: gradio app
# Hosting-page metadata (not code): author avatar "lopho", history/blame links, file size 9.23 kB.
import os
from io import BytesIO
import base64
from functools import partial
from PIL import Image, ImageOps
import gradio as gr
from makeavid_sd.inference import InferenceUNetPseudo3D, FlaxDPMSolverMultistepScheduler, jnp
# When True, two warm-up generations are run at import time (see the
# `if _preheat:` section further down).
_preheat: bool = False

# Parameter combinations that have already been passed to the model in this
# process. Each entry is the tuple
# (hint image is None, inference_steps, height, width, num_frames).
_seen_compilations = set()

# Shared inference model; the Hugging Face token is optional and read from the
# environment.
_model = InferenceUNetPseudo3D(
        model_path = 'TempoFunk/makeavid-sd-jax',
        scheduler_cls = FlaxDPMSolverMultistepScheduler,
        dtype = jnp.float16,
        hf_auth_token = os.environ.get('HUGGING_FACE_HUB_TOKEN', None)
)
def _conform_image(img, mode, size):
    """Return *img* converted to *mode* and fitted to *size*; pass None through."""
    if img is None:
        return None
    if img.mode != mode:
        img = img.convert(mode)
    if img.size != size:
        img = ImageOps.fit(img, size, method = Image.Resampling.LANCZOS)
    return img


# NOTE: per the original author's remark, gradio does not cope with annotated
# parameters, so the parameters below deliberately carry no type hints.
def generate(
        prompt = 'An elderly man having a great time in the park.',
        neg_prompt = '',
        image = { 'image': None, 'mask': None },
        inference_steps = 20,
        cfg = 12.0,
        seed = 0,
        fps = 24,
        num_frames = 24,
        height = 512,
        width = 512
) -> str:
    """Generate an animation and return it as a base64 WebP data URI.

    Args:
        prompt: Text prompt; replicated once per `_model.device_count`.
        neg_prompt: Negative prompt, forwarded unchanged.
        image: None, or a dict with the keys 'image' (hint image) and 'mask'.
            Either value may be None. The default dict is only read, never
            mutated.
        inference_steps: Number of steps; coerced to int.
        cfg: Guidance scale, forwarded unchanged.
        seed: Random seed; coerced to int, negative values are negated.
        fps: Playback rate; each frame lasts round(1000 / fps) milliseconds.
        num_frames: Number of frames; coerced to int.
        height: Output height in pixels; coerced to int.
        width: Output width in pixels; coerced to int.

    Returns:
        A 'data:image/webp;base64,...' string holding the animated WebP.
    """
    height = int(height)
    width = int(width)
    num_frames = int(num_frames)
    seed = abs(int(seed))
    inference_steps = int(inference_steps)

    # The hint and mask are only cleared when no image payload was supplied.
    if image is None:
        hint_image = None
        mask_image = None
    else:
        hint_image = image['image']
        mask_image = image['mask']

    size = (width, height)
    hint_image = _conform_image(hint_image, 'RGB', size)
    mask_image = _conform_image(mask_image, 'L', size)

    images = _model.generate(
            prompt = [prompt] * _model.device_count,
            neg_prompt = neg_prompt,
            hint_image = hint_image,
            mask_image = mask_image,
            inference_steps = inference_steps,
            cfg = cfg,
            height = height,
            width = width,
            num_frames = num_frames,
            seed = seed
    )
    # Remember this parameter combination so the UI can tell whether a later
    # request with the same values was already seen.
    _seen_compilations.add((hint_image is None, inference_steps, height, width, num_frames))

    # Presumably `images` is a sequence of PIL images: the first frame is saved
    # and the remaining frames are appended to form the animation.
    buffer = BytesIO()
    images[0].save(
            buffer,
            format = 'webp',
            save_all = True,
            append_images = images[1:],
            loop = 0,
            duration = round(1000 / fps),
            allow_mixed = True
    )
    data = base64.b64encode(buffer.getvalue()).decode()
    return 'data:image/webp;base64,' + data
def check_if_compiled(image, inference_steps, height, width, num_frames, message):
    """Return *message* unless this parameter combination was already seen.

    The lookup key is (hint image is None, inference_steps, int(height),
    int(width), num_frames), checked against `_seen_compilations`. An empty
    string is returned when the key is present.
    """
    px_height = int(height)
    px_width = int(width)
    no_hint = True if image is None else image['image'] is None
    key = (no_hint, inference_steps, px_height, px_width, num_frames)
    if key not in _seen_compilations:
        return f"""{message}"""
    return ''
# Optional warm-up: run generate() once without and once with a hint image so
# both parameter combinations are recorded before the UI is built.
if _preheat:
    print('\npreheating the oven')
    generate(
            prompt = 'preheating the oven',
            neg_prompt = '',
            image = { 'image': None, 'mask': None },
            inference_steps = 20,
            cfg = 12.0,
            seed = 0
    )
    print('Entertaining the guests with sailor songs played on an old piano.')
    # A plain black 512x512 image stands in for a user-supplied hint image.
    dada = generate(
            prompt = 'Entertaining the guests with sailor songs played on an old harmonium.',
            neg_prompt = '',
            image = { 'image': Image.new('RGB', size = (512, 512), color = (0, 0, 0)), 'mask': None },
            inference_steps = 20,
            cfg = 12.0,
            seed = 0
    )
    print('dinner is ready\n')
# Gradio UI: the left column holds all inputs, the right column holds the
# status messages and the output image.
with gr.Blocks(title = 'Make-A-Video Stable Diffusion JAX', analytics_enabled = False) as demo:
    variant = 'panel'
    with gr.Row():
        with gr.Column():
            # Built from explicit lines so the rendered Markdown does not
            # depend on source indentation.
            intro1 = gr.Markdown('\n'.join([
                '# Make-A-Video Stable Diffusion JAX',
                '',
                '**Please be patient. The model might have to compile with current parameters.**',
                '',
                'This can take up to 5 minutes on the first run, and 2-3 minutes on later runs.',
                'The compilation will be cached and consecutive runs with the same parameters',
                'will be much faster.',
            ]))
        with gr.Column():
            intro2 = gr.Markdown('\n'.join([
                'The following parameters require the model to compile',
                '',
                '- Number of frames',
                '- Width & Height',
                '- Steps',
                '- Input image vs. no input image',
            ]))
    with gr.Row(variant = variant):
        with gr.Column(variant = variant):
            with gr.Row():
                cancel_button = gr.Button(value = 'Cancel')
                submit_button = gr.Button(value = 'Make A Video', variant = 'primary')
            prompt_input = gr.Textbox(
                    label = 'Prompt',
                    value = 'They are dancing in the club while sweat drips from the ceiling.',
                    interactive = True
            )
            neg_prompt_input = gr.Textbox(
                    label = 'Negative prompt (optional)',
                    value = '',
                    interactive = True
            )
            inference_steps_input = gr.Slider(
                    label = 'Steps',
                    minimum = 1,
                    maximum = 100,
                    value = 20,
                    step = 1
            )
            cfg_input = gr.Slider(
                    label = 'Guidance scale',
                    minimum = 1.0,
                    maximum = 20.0,
                    step = 0.1,
                    value = 15.0,
                    interactive = True
            )
            seed_input = gr.Number(
                    label = 'Random seed',
                    value = 0,
                    interactive = True,
                    precision = 0
            )
            image_input = gr.Image(
                    label = 'Input image (optional)',
                    interactive = True,
                    image_mode = 'RGB',
                    type = 'pil',
                    optional = True,
                    source = 'upload',
                    tool = 'sketch'
            )
            num_frames_input = gr.Slider(
                    label = 'Number of frames to generate',
                    minimum = 1,
                    maximum = 24,
                    step = 1,
                    value = 24
            )
            width_input = gr.Slider(
                    label = 'Width',
                    minimum = 64,
                    maximum = 512,
                    step = 1,
                    value = 448
            )
            height_input = gr.Slider(
                    label = 'Height',
                    minimum = 64,
                    maximum = 512,
                    step = 1,
                    value = 448
            )
            fps_input = gr.Slider(
                    label = 'Output FPS',
                    minimum = 1,
                    maximum = 1000,
                    step = 1,
                    value = 12
            )
        with gr.Column(variant = variant):
            will_trigger = gr.Markdown('')
            patience = gr.Markdown('')
            image_output = gr.Image(
                    label = 'Output',
                    value = 'example.webp',
                    interactive = False
            )

    # Inputs that form the compilation key, in the positional order expected
    # by check_if_compiled().
    trigger_inputs = [ image_input, inference_steps_input, height_input, width_input, num_frames_input ]
    trigger_check_fun = partial(check_if_compiled, message = 'Current parameters will trigger compilation.')
    height_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
    width_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
    num_frames_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
    inference_steps_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
    # Initial notice for the default widget values.
    will_trigger.value = trigger_check_fun(image_input.value, inference_steps_input.value, height_input.value, width_input.value, num_frames_input.value)

    # Submit chain: show the patience notice, run the generation, then refresh
    # the "will trigger compilation" notice.
    ev = submit_button.click(
            fn = partial(
                check_if_compiled,
                message = 'Please be patient. The model has to be compiled with current parameters.'
            ),
            inputs = trigger_inputs,
            outputs = patience
    ).then(
            fn = generate,
            # Order must match the positional parameters of generate().
            inputs = [
                prompt_input,
                neg_prompt_input,
                image_input,
                inference_steps_input,
                cfg_input,
                seed_input,
                fps_input,
                num_frames_input,
                height_input,
                width_input
            ],
            outputs = image_output,
            # generate() already returns a data URI, so gradio's output
            # post-processing is skipped.
            postprocess = False
    ).then(
            fn = trigger_check_fun,
            inputs = trigger_inputs,
            outputs = will_trigger
    )
    # The cancel button runs no function of its own; it only cancels the
    # submit chain.
    cancel_button.click(fn = None, cancels = ev)

demo.queue(concurrency_count = 1, max_size = 16)