|
import gradio as gr |
|
|
|
import torch |
|
|
|
from diffusers import UniDiffuserPipeline |
|
|
|
|
|
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Hugging Face Hub checkpoint to load.
model_id = "thu-ml/unidiffuser-v1"

# Load the pipeline once at import time and move it to the selected device.
# DiffusionPipeline.to() returns the pipeline itself, so the call can be chained.
# This single module-level instance is shared by every request.
pipeline = UniDiffuserPipeline.from_pretrained(model_id).to(device)
|
|
|
|
|
def convert_to_none(s):
    """Normalise an optional Gradio input value.

    Returns *s* unchanged when it is truthy. Any falsy value (``""``,
    ``None``, ``0``, an empty container) is mapped to ``None`` so it can be
    forwarded as an absent argument.
    """
    return s or None
|
|
|
|
|
def set_mode(mode):
    """Switch the shared UniDiffuser pipeline to the requested generation task.

    Args:
        mode: One of ``"joint"``, ``"text2img"``, ``"img2text"``, ``"text"``
            or ``"img"``. Surrounding whitespace is ignored. Any other value
            (including ``""`` or ``None``) clears a previously set mode via
            ``reset_mode()``, so the pipeline falls back to inferring the task
            from the inputs it is called with.
    """
    setters = {
        "joint": pipeline.set_joint_mode,
        "text2img": pipeline.set_text_to_image_mode,
        # diffusers names this method set_image_to_text_mode.
        "img2text": pipeline.set_image_to_text_mode,
        "text": pipeline.set_text_mode,
        "img": pipeline.set_image_mode,
    }
    setter = setters.get((mode or "").strip())
    if setter is None:
        # `pipeline` is a module-level singleton: without an explicit reset,
        # an empty or unrecognised mode would silently reuse whatever mode
        # the previous request left set.
        pipeline.reset_mode()
    else:
        setter()
|
|
|
|
|
def sample(mode, prompt, image, num_inference_steps, guidance_scale, seed):
    """Run one UniDiffuser generation for the Gradio interface.

    Args:
        mode: Generation task name, forwarded to ``set_mode``.
        prompt: Conditioning text; a falsy value (e.g. ``""``) means "no prompt".
        image: Conditioning image (the UI supplies a PIL image); a falsy value
            means "no image".
        num_inference_steps: Number of denoising steps; coerced to ``int``.
        guidance_scale: Guidance scale passed to the pipeline; coerced to ``float``.
        seed: Seed for the ``torch.Generator``; coerced to ``int``.

    Returns:
        A ``(sample_image, sample_text)`` tuple. ``sample_image`` is the first
        generated image, or ``None`` if the pipeline returned no images.
        ``sample_text`` is the first generated text, or ``""`` if it returned
        no text.
    """
    set_mode(mode)
    prompt = convert_to_none(prompt)
    image = convert_to_none(image)

    # gr.Number may deliver floats (depending on its `precision` setting).
    # torch.Generator.manual_seed rejects floats and the step count must be
    # an integer, so coerce here instead of relying on the widget config.
    num_inference_steps = int(num_inference_steps)
    seed = int(seed)
    guidance_scale = float(guidance_scale)

    generator = torch.Generator(device=device).manual_seed(seed)
    output_sample = pipeline(
        prompt=prompt,
        image=image,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    )

    # Either output field may be None (presumably when the active mode does
    # not produce that modality); fall back to empty UI values in that case.
    sample_image = output_sample.images[0] if output_sample.images is not None else None
    sample_text = output_sample.text[0] if output_sample.text is not None else ""
    return sample_image, sample_text
|
|
|
|
|
# Gradio UI: the six inputs map positionally onto sample()'s parameters, and
# the two outputs receive its (image, text) return tuple.
iface = gr.Interface(
    fn=sample,
    inputs=[
        # Free-text task selector. The placeholder lists the values set_mode
        # recognises, so users do not have to guess them.
        gr.Textbox(
            value="",
            label="Generation Task",
            placeholder="joint, text2img, img2text, text or img",
        ),
        gr.Textbox(value="", label="Conditioning prompt"),
        gr.Image(value=None, label="Conditioning image", type="pil"),
        gr.Number(value=20, label="Num Inference Steps", precision=0),
        gr.Number(value=8.0, label="Guidance Scale"),
        gr.Number(value=0, label="Seed", precision=0),
    ],
    outputs=[
        gr.Image(label="Sample image"),
        gr.Textbox(label="Sample text"),
    ],
)

# Start the web server; this runs as soon as the module is executed.
iface.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|