#!/usr/bin/env python

from __future__ import annotations

import os
import random

import gradio as gr
import numpy as np
import PIL.Image
import spaces
import torch
from diffusers import UniDiffuserPipeline

DESCRIPTION = "# [UniDiffuser](https://github.com/thu-ml/unidiffuser)"
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶</p>"

MAX_SEED = np.iinfo(np.int32).max


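# Return a random seed when "Randomize seed" is checked in the UI; otherwise keep the given seed.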
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed


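# Load the UniDiffuser pipeline in half precision when a GPU is available;
# on CPU-only hosts the pipeline is not instantiated and the header notes the CPU fallback.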
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1", torch_dtype=torch.float16)
    pipe.to(device)


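# Run one UniDiffuser generation mode. The pipeline is switched into the requested
# mode before sampling; "i2t2i" and "t2i2t" chain two modes to produce image and
# text variations. The @spaces.GPU decorator allocates a GPU for this call on ZeroGPU Spaces.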
@spaces.GPU
def run(
    mode: str,
    prompt: str,
    image: PIL.Image.Image | None,
    seed: int = 0,
    num_steps: int = 20,
    guidance_scale: float = 8.0,
) -> tuple[PIL.Image.Image | None, str]:
    generator = torch.Generator(device=device).manual_seed(seed)
    if mode == "t2i":
        pipe.set_text_to_image_mode()
        sample = pipe(prompt=prompt, num_inference_steps=num_steps, guidance_scale=guidance_scale, generator=generator)
        return sample.images[0], ""
    elif mode == "i2t":
        pipe.set_image_to_text_mode()
        sample = pipe(image=image, num_inference_steps=num_steps, guidance_scale=guidance_scale, generator=generator)
        return None, sample.text[0]
    elif mode == "joint":
        pipe.set_joint_mode()
        sample = pipe(num_inference_steps=num_steps, guidance_scale=guidance_scale, generator=generator)
        return sample.images[0], sample.text[0]
    elif mode == "i":
        pipe.set_image_mode()
        sample = pipe(num_inference_steps=num_steps, guidance_scale=guidance_scale, generator=generator)
        return sample.images[0], ""
    elif mode == "t":
        pipe.set_text_mode()
        sample = pipe(num_inference_steps=num_steps, guidance_scale=guidance_scale, generator=generator)
        return None, sample.text[0]
    elif mode == "i2t2i":
        pipe.set_image_to_text_mode()
        sample = pipe(image=image, num_inference_steps=num_steps, guidance_scale=guidance_scale, generator=generator)
        pipe.set_text_to_image_mode()
        sample = pipe(
            prompt=sample.text[0],
            num_inference_steps=num_steps,
            guidance_scale=guidance_scale,
            generator=generator,
        )
        return sample.images[0], ""
    elif mode == "t2i2t":
        pipe.set_text_to_image_mode()
        sample = pipe(prompt=prompt, num_inference_steps=num_steps, guidance_scale=guidance_scale, generator=generator)
        pipe.set_image_to_text_mode()
        sample = pipe(
            image=sample.images[0],
            num_inference_steps=num_steps,
            guidance_scale=guidance_scale,
            generator=generator,
        )
        return None, sample.text[0]
    else:
        raise ValueError(f"Unknown mode: {mode}")


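# Build one tab's UI. All tabs share the same layout; the hidden "mode" dropdown
# tells `run` which generation mode the tab corresponds to, and the prompt/image
# inputs and image/text outputs are only shown when relevant for that mode.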
def create_demo(mode_name: str) -> gr.Blocks:
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                mode = gr.Dropdown(
                    label="Mode",
                    choices=[
                        "t2i",
                        "i2t",
                        "joint",
                        "i",
                        "t",
                        "i2t2i",
                        "t2i2t",
                    ],
                    value=mode_name,
                    visible=False,
                )
                prompt = gr.Text(label="Prompt", max_lines=1, visible=mode_name in ["t2i", "t2i2t"])
                image = gr.Image(label="Input image", type="pil", visible=mode_name in ["i2t", "i2t2i"])
                run_button = gr.Button("Run")
                with gr.Accordion("Advanced options", open=False):
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    num_steps = gr.Slider(
                        label="Steps",
                        minimum=1,
                        maximum=100,
                        value=20,
                        step=1,
                    )
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=0.1,
                        maximum=30.0,
                        value=8.0,
                        step=0.1,
                    )
            with gr.Column():
                result_image = gr.Image(label="Generated image", visible=mode_name in ["t2i", "i", "joint", "i2t2i"])
                result_text = gr.Text(label="Generated text", visible=mode_name in ["i2t", "t", "joint", "t2i2t"])

        gr.on(
            triggers=[prompt.submit, run_button.click],
            fn=randomize_seed_fn,
            inputs=[seed, randomize_seed],
            outputs=seed,
            queue=False,
        ).then(
            fn=run,
            inputs=[
                mode,
                prompt,
                image,
                seed,
                num_steps,
                guidance_scale,
            ],
            outputs=[
                result_image,
                result_text,
            ],
            api_name=f"run_{mode_name}",
        )
    return demo


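# Top-level app: one tab per UniDiffuser mode, all wired to the same `run` function.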
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Tabs():
        with gr.TabItem("text2image"):
            create_demo("t2i")
        with gr.TabItem("image2text"):
            create_demo("i2t")
        with gr.TabItem("image variation"):
            create_demo("i2t2i")
        with gr.TabItem("joint generation"):
            create_demo("joint")
        with gr.TabItem("image generation"):
            create_demo("i")
        with gr.TabItem("text generation"):
            create_demo("t")
        with gr.TabItem("text variation"):
            create_demo("t2i2t")

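# Queue requests (up to 20 waiting) and launch the Gradio server.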
if __name__ == "__main__":
    demo.queue(max_size=20).launch()