# Hugging Face Space (Gradio app); Space status when this copy was captured: "Paused".
import gradio as gr
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
import torch

# Base SDXL checkpoint: supplies every pipeline component except the UNet.
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
# UNet fine-tuned with Direct Preference Optimization (DPO); replaces the base UNet.
unet_id = "mhdang/dpo-sdxl-text2image-v1"

# Load the fine-tuned UNet first and hand it to the pipeline constructor. Components
# passed as keyword arguments are used as-is, so the base checkpoint's own UNet is not
# loaded just to be discarded (the original assigned `pipe.unet = unet` afterwards).
unet = UNet2DConditionModel.from_pretrained(
    unet_id, subfolder="unet", torch_dtype=torch.float16
)
pipe = StableDiffusionXLPipeline.from_pretrained(
    model_id,
    unet=unet,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)

# Per the diffusers docs, model CPU offload manages device placement itself and first
# moves the pipeline to CPU. The pipeline is therefore deliberately NOT moved to "cuda"
# beforehand; doing so only copied the whole model to the GPU and back.
pipe.enable_model_cpu_offload()
# Decode latents through the VAE in slices to reduce peak memory.
pipe.enable_vae_slicing()
def infer(prompt, guidance_scale=7.5, size=(512, 512)):
    """Generate one image for *prompt* with the module-level DPO-SDXL ``pipe``.

    Args:
        prompt: Text prompt, as delivered by the Gradio textbox.
        guidance_scale: Classifier-free guidance strength forwarded to the pipeline.
            Defaults to the value the demo has always used.
        size: ``(width, height)`` that the generated image is resized to before it
            is returned.

    Returns:
        The first image produced by the pipeline, resized to *size*.
    """
    image = pipe(prompt, guidance_scale=guidance_scale).images[0]
    # Downscale for display in the UI. Presumably the pipeline's native output is
    # larger than 512x512 at SDXL defaults; TODO confirm if the size default changes.
    return image.resize(size)
# Custom CSS for the Blocks layout: limit the main column's width and center it
# horizontally, so it lines up with its centered title on wide screens.
css = """
#col-container{
    margin: 0 auto;
    max-width: 720px;
}
"""
# Single-column UI: a title, a prompt box, a submit button and the generated image.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML("""
        <h2 style="text-align: center;">
            SDXL Using Direct Preference Optimization
        </h2>
        """)
        # Pre-filled example prompt so the demo can be tried with a single click.
        prompt_in = gr.Textbox(label="Prompt", value="An old man with a bird on his head")
        submit_btn = gr.Button("Submit")
        result = gr.Image(label="DPO SDXL Result")

    # Event listeners must be registered inside the Blocks context.
    submit_btn.click(
        fn=infer,
        inputs=[prompt_in],
        outputs=[result],
    )
# Enable request queuing so long-running generations are processed in order, then
# start the web server. `queue()` returns the same Blocks object, so the two calls
# below are equivalent to chaining them.
demo.queue()
demo.launch()