# sdxl-dpo / app.py — Hugging Face Space by fffiloni
# Last commit: 199441c ("Update app.py"); file size 3.13 kB.
# Imports plus the remote safety-checker client used to screen prompts.
import gradio as gr
import os
# Access token forwarded to the safety-checker Space; None when HF_TOKEN is unset.
hf_token = os.environ.get("HF_TOKEN")
from gradio_client import Client
# Client for the remote safety-checker Space. It is constructed once, at import
# time, and used by safety_check() for every prompt.
client = Client("https://fffiloni-safety-checker-bot.hf.space/", hf_token=hf_token)
import re
# NOTE(review): `spaces` is imported before torch/diffusers. ZeroGPU documents
# that it must be imported before CUDA is initialised — confirm before reordering.
import spaces
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
import torch
# NOTE(review): StableDiffusionSafetyChecker is not referenced anywhere in this
# file's visible code; possibly a leftover import.
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
# Build the SDXL base pipeline with the DPO-finetuned UNet swapped in.
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
unet_id = "mhdang/dpo-sdxl-text2image-v1"

# Load the finetuned UNet first and hand it to from_pretrained(). The base
# checkpoint's own UNet is then never loaded just to be replaced afterwards.
unet = UNet2DConditionModel.from_pretrained(unet_id, subfolder="unet", torch_dtype=torch.float16)
pipe = StableDiffusionXLPipeline.from_pretrained(
    model_id,
    unet=unet,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)

# enable_model_cpu_offload() handles device placement itself: sub-models are
# moved to the GPU only while they run. An explicit pipe.to("cuda") beforehand
# only copies everything to the GPU for the offload call to move it back to CPU.
pipe.enable_model_cpu_offload()
# Decode latents through the VAE slice by slice to lower peak memory use.
pipe.enable_vae_slicing()
def safety_check(user_prompt):
    """Ask the remote safety-checker Space about *user_prompt*.

    The prompt is sent positionally to the ``/infer`` endpoint of the
    module-level ``client``. The endpoint's reply is handed back untouched;
    interpreting it is left to the caller.
    """
    return client.predict(user_prompt, api_name="/infer")
@spaces.GPU(enable_queue=True)
def _generate(prompt):
    """Run the DPO-SDXL pipeline on *prompt* and return the first generated image.

    This is the only GPU-bound step. Keeping it in its own decorated function
    means the GPU is requested only after the prompt has passed the safety check.
    """
    results = pipe(prompt, guidance_scale=7.5)
    return results.images[0]


def infer(prompt):
    """Screen *prompt* with the remote safety checker, then generate an image.

    Args:
        prompt: Text prompt entered by the user.

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: If the checker's reply contains the standalone word "Yes"
            (prompt rejected), or if the reply is not text and so cannot be
            interpreted. In both cases nothing is generated (fail closed).
    """
    print(f"\n—\n{prompt}\n")

    # The network round-trip happens before any GPU work is requested.
    is_safe = safety_check(prompt)
    print(is_safe)

    # A reply we cannot read must not be treated as approval.
    if not isinstance(is_safe, str):
        raise gr.Error("Safety check failed, please try again.")

    # The checker answers "Yes" when the prompt should be refused.
    if re.search(r"\bYes\b", is_safe):
        raise gr.Error("Don't ask for such things.")

    return _generate(prompt)
# Custom CSS: centre the main column and cap its width.
css = (
    "\n"
    "#col-container{\n"
    "margin: 0 auto;\n"
    "max-width: 580px;\n"
    "}\n"
)
# Gradio UI: one prompt box, a submit button, the generated image and examples.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Page header: title plus a one-sentence summary of the DPO method.
        gr.HTML("""
        <h2 style="text-align: center;">
            SDXL Using Direct Preference Optimization
        </h2>
        <p style="text-align: center;">
            Direct Preference Optimization (DPO) for text-to-image diffusion models is a method to align diffusion models to human preferences by directly optimizing on human comparison data.
        </p>
        """)

        with gr.Group():
            with gr.Column():
                prompt_in = gr.Textbox(label="Prompt", value="An old man with a bird on his head")
                submit_btn = gr.Button("Submit")
            result = gr.Image(label="DPO SDXL Result")

        # Clicking an example fills the prompt box; fn/outputs are used if examples get run.
        gr.Examples(
            examples=[
                "Dragon, digital art, by Greg Rutkowski",
                "Armored knight holding sword",
                "A flat roof villa near a river with black walls and huge windows",
                "A calm and peaceful office",
                "Pirate guinea pig",
            ],
            fn=infer,
            inputs=[prompt_in],
            outputs=[result],
        )

    # Event listeners are attached inside the Blocks context.
    submit_btn.click(
        fn=infer,
        inputs=[prompt_in],
        outputs=[result],
    )

# Enable the request queue and start the server with the API page hidden.
demo.queue().launch(show_api=False)