# sdxl-dpo / app.py — Hugging Face Space by fffiloni (commit ff896ad, "Update app.py", 1.31 kB)
import gradio as gr
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
import torch
# Build the SDXL base pipeline around the DPO-finetuned UNet.
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
unet_id = "mhdang/dpo-sdxl-text2image-v1"

# Load the finetuned UNet first, then hand it to `from_pretrained` via `unet=`.
# That overrides the base repo's UNet component, so the base UNet weights are
# never loaded (and moved to the GPU) only to be replaced immediately.
unet = UNet2DConditionModel.from_pretrained(unet_id, subfolder="unet", torch_dtype=torch.float16)
pipe = StableDiffusionXLPipeline.from_pretrained(
    model_id,
    unet=unet,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)

# No explicit `pipe.to("cuda")`: model CPU offload manages device placement
# itself, moving each sub-model to the GPU only while it runs. In diffusers,
# `enable_model_cpu_offload` first moves a GPU-resident pipeline back to the
# CPU, so an earlier `.to("cuda")` would only be wasted transfer work.
pipe.enable_model_cpu_offload()
# Decode VAE latents in slices to reduce peak memory during decoding.
pipe.enable_vae_slicing()
def infer(prompt, guidance_scale=7.5, size=(512, 512)):
    """Generate one image for ``prompt`` with the DPO-finetuned SDXL pipeline.

    The extra parameters have defaults equal to the values that were
    previously hard-coded. Gradio, which passes only the prompt, therefore
    gets the same result as before. Later UI controls can wire these
    parameters as additional inputs.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        guidance_scale: Classifier-free guidance scale passed to the pipeline.
        size: ``(width, height)`` that the generated image is resized to
            before it is returned.

    Returns:
        The first image of the pipeline output, resized to ``size``. Under
        diffusers' default ``output_type`` this is a PIL image.
    """
    output = pipe(prompt, guidance_scale=guidance_scale)
    # Only the first generated image is shown; it is rescaled for display.
    return output.images[0].resize(size)
# Cap the main column's width and center it horizontally (auto side margins)
# so the layout lines up with the centered page heading.
css = """
#col-container{
    margin: 0 auto;
    max-width: 720px;
}
"""
# Single-column UI: heading, prompt box, submit button, and result image.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML("""
        <h2 style="text-align: center;">
            SDXL Using Direct Preference Optimization
        </h2>
        """)
        # Pre-filled example prompt so the demo works with a single click.
        prompt_in = gr.Textbox(label="Prompt", value="An old man with a bird on his head")
        submit_btn = gr.Button("Submit")
        result = gr.Image(label="DPO SDXL Result")

    # Event listeners must be registered inside the Blocks context:
    # run `infer` on the prompt text and display the returned image.
    submit_btn.click(
        fn=infer,
        inputs=[prompt_in],
        outputs=[result],
    )

# Enable Gradio's request queue (suited to long-running GPU jobs) and start
# the server.
demo.queue().launch()