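# Page residue from the Hugging Face file viewer, kept as comments so the
# file is valid Python:
#   author: ysharma (HF staff)
#   commit: c3f2272 "Update app.py"
#   viewer links: raw | history | blame
#   scan status: No virus; file size: 3.81 kB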
import gradio as gr
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
import torch
import mediapy
import sa_handler
# --- Model setup -------------------------------------------------------------
# DDIM scheduler with an explicit scaled-linear beta schedule; sample clipping
# and set_alpha_to_one are both disabled.
scheduler = DDIMScheduler(
    beta_schedule="scaled_linear",
    beta_start=0.00085,
    beta_end=0.012,
    clip_sample=False,
    set_alpha_to_one=False,
)

# SDXL base 1.0 loaded from fp16 safetensors weights, driven by the scheduler
# above, then moved to the CUDA device.
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    scheduler=scheduler,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
pipeline = pipeline.to("cuda")

# StyleAligned hooks: attention sharing on, AdaIN on queries and keys (not
# values), group-norm and layer-norm sharing off.
handler = sa_handler.Handler(pipeline)
sa_args = sa_handler.StyleAlignedArgs(
    share_attention=True,
    share_group_norm=False,
    share_layer_norm=False,
    adain_queries=True,
    adain_keys=True,
    adain_values=False,
)
handler.register(sa_args)
# Example prompt set: five toy subjects that all share the same style suffix.
# (Not read elsewhere in this file; the UI builds its own prompt list.)
sets_of_prompts = [
    f"a toy {subject}. macro photo. 3d game asset"
    for subject in ("train", "airplane", "bicycle", "car", "boat")
]
# Run StyleAligned generation for one set of prompts.
def style_aligned_sdxl(initial_prompt1, initial_prompt2, initial_prompt3,
                       initial_prompt4, initial_prompt5, style_prompt):
    """Generate one SDXL image per non-blank initial prompt, all in one batch.

    Each non-blank initial prompt is combined with ``style_prompt`` as
    ``"<prompt>. <style>"``. The whole list is passed to the module-level
    ``pipeline`` in a single call, which has the StyleAligned handler
    registered on it.

    Args:
        initial_prompt1 .. initial_prompt5: Subject prompts from the UI (or the
            API). Blank, whitespace-only, or ``None`` entries are skipped, so
            users may fill in fewer than five.
        style_prompt: Style description appended to every subject prompt.
            If blank or ``None``, the subject prompts are used unchanged.

    Returns:
        The ``.images`` list returned by the pipeline, in prompt order.

    Raises:
        gr.Error: If every initial prompt is blank, so the UI shows a message
            instead of generating images from the style text alone.
    """
    style = (style_prompt or "").strip()
    subjects = [
        prompt.strip()
        for prompt in (initial_prompt1, initial_prompt2, initial_prompt3,
                       initial_prompt4, initial_prompt5)
        if prompt and prompt.strip()
    ]
    if not subjects:
        # An empty batch is not meaningful; surface a user-facing error.
        raise gr.Error("Please enter at least one initial prompt.")

    # Keep the original "<subject>. <style>" format, but drop the separator
    # when no style is given so prompts don't end in a dangling ". ".
    prompts = [f"{subject}. {style}" if style else subject for subject in subjects]

    # NOTE(review): StyleAligned presumably uses the first prompt of the batch
    # as the style reference; confirm sa_handler is fine with a batch of one.
    return pipeline(prompts).images
# --- Gradio UI ---------------------------------------------------------------
# Up to five subject prompts plus one shared style prompt feed
# style_aligned_sdxl; the resulting images are shown in a gallery.
with gr.Blocks() as demo:
    with gr.Group():
        with gr.Column():
            with gr.Accordion(label='Enter up to 5 different initial prompts', open=True):
                with gr.Row():
                    initial_prompt1 = gr.Textbox(value='', show_label=False, container=False, placeholder='a toy train')
                    initial_prompt2 = gr.Textbox(value='', show_label=False, container=False, placeholder='a toy airplane')
                    initial_prompt3 = gr.Textbox(value='', show_label=False, container=False, placeholder='a toy bicycle')
                    initial_prompt4 = gr.Textbox(value='', show_label=False, container=False, placeholder='a toy car')
                    initial_prompt5 = gr.Textbox(value='', show_label=False, container=False, placeholder='a toy boat')
            with gr.Row():
                style_prompt = gr.Textbox(label="Enter a style prompt", placeholder='macro photo, 3d game asset')
            btn = gr.Button("Generate a set of Style-aligned SDXL images")
    output = gr.Gallery(label="Style-Aligned SDXL")

    # The button and the examples use the same six inputs, in the order the
    # parameters of style_aligned_sdxl expect.
    prompt_inputs = [initial_prompt1, initial_prompt2, initial_prompt3,
                     initial_prompt4, initial_prompt5, style_prompt]

    btn.click(fn=style_aligned_sdxl,
              inputs=prompt_inputs,
              outputs=[output],
              api_name="style_aligned_sdxl")

    gr.Examples(examples=[
        ["a toy train", "a toy airplane", "a toy bicycle", "a toy car", "a toy boat", "macro photo. 3d game asset."],
        ["a toy train", "a toy airplane", "a toy bicycle", "a toy car", "a toy boat", "BW logo. high contrast."],
        ["a cat", "a dog", "a bear", "a man on a bicycle", "a girl working on laptop", "minimal origami."],
        ["Firewoman", "Gardener", "Scientist", "Policewoman", "Saxophone player", "made of claymation, stop motion animation."],
        ["Firewoman", "Gardener", "Scientist", "Policewoman", "Saxophone player", "sketch, character sheet."],
    ],
        inputs=prompt_inputs,
        outputs=[output],
        fn=style_aligned_sdxl)

# Launch only when run as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    demo.launch()