File size: 3,880 Bytes
fd1c028
 
 
 
 
 
 
 
 
 
 
 
 
ce70a4b
 
fd1c028
 
 
 
 
 
 
 
 
 
 
 
c3f2272
fd1c028
 
 
 
 
 
 
 
c3f2272
 
 
 
fd1c028
 
c3f2272
fd1c028
 
 
c3f2272
 
 
 
 
 
 
 
 
 
 
 
fd1c028
c3f2272
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd1c028
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import gradio as gr
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
import torch
import mediapy
import sa_handler

# init models
# DDIM sampler with a scaled-linear beta schedule (0.00085 -> 0.012), no sample
# clipping and set_alpha_to_one=False; passed to the SDXL pipeline below.
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False,
                          set_alpha_to_one=False)
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True,
    scheduler=scheduler
)
# Sequential CPU offload keeps weights on the CPU and moves submodules to the
# GPU only while they run. The pipeline is deliberately NOT moved with
# `.to("cuda")` first: offloading moves it back to the CPU anyway, so an
# eager `.to("cuda")` only loads the whole fp16 model onto the GPU for nothing
# (and can OOM on exactly the small GPUs offloading is meant to support).
pipeline.enable_sequential_cpu_offload()
# Decode the batch's latents in slices to lower the VAE's peak memory.
pipeline.enable_vae_slicing()

# Register StyleAligned on the pipeline (implementation lives in sa_handler).
# Configuration: share self-attention across the batch with AdaIN applied to
# queries and keys (not values); group-norm and layer-norm sharing are off.
handler = sa_handler.Handler(pipeline)
sa_args = sa_handler.StyleAlignedArgs(share_group_norm=False,
                                      share_layer_norm=False,
                                      share_attention=True,
                                      adain_queries=True,
                                      adain_keys=True,
                                      adain_values=False,
                                      )

handler.register(sa_args)

# Example subject prompts that all share the same style suffix. Kept for
# reference; `style_aligned_sdxl` builds its own prompt list from the UI inputs.
sets_of_prompts = [
    f"a toy {item}. macro photo. 3d game asset"
    for item in ("train", "airplane", "bicycle", "car", "boat")
]

# run StyleAligned
def style_aligned_sdxl(initial_prompt1, initial_prompt2, initial_prompt3, initial_prompt4, initial_prompt5, style_prompt):
    """Generate one StyleAligned SDXL image per non-empty subject prompt.

    Each non-blank subject prompt is combined with the shared style prompt as
    ``"<subject>. <style>"``. All prompts are sent to the module-level
    ``pipeline`` in a single batched call, so the StyleAligned hooks
    registered on it see every image together.

    Args:
        initial_prompt1..initial_prompt5: Subject prompts (e.g. "a toy train").
            Blank or ``None`` entries are skipped, because the UI allows
            "up to 5" prompts.
        style_prompt: Style text appended to every subject. If it is blank,
            the subjects are used as-is.

    Returns:
        The list of images produced by the pipeline, one per non-blank subject,
        in input order.

    Raises:
        gr.Error: If every subject prompt is blank; Gradio shows the message
            in the UI.
    """
    raw_subjects = (initial_prompt1, initial_prompt2, initial_prompt3, initial_prompt4, initial_prompt5)
    # Drop unfilled textboxes instead of rendering prompts like ". <style>".
    subjects = [s for s in ((p or "").strip() for p in raw_subjects) if s]
    if not subjects:
        raise gr.Error("Please enter at least one initial prompt.")

    style = (style_prompt or "").strip()
    prompts = [f"{subject}. {style}" if style else subject for subject in subjects]
    return pipeline(prompts).images

with gr.Blocks() as demo:
    with gr.Group():
        with gr.Column():
            # Up to five subject prompts; the placeholders mirror the first example row.
            with gr.Accordion(label='Enter up to 5 different initial prompts', open=True):
                with gr.Row():
                    initial_prompt1 = gr.Textbox(value='', show_label=False, container=False, placeholder='a toy train')
                    initial_prompt2 = gr.Textbox(value='', show_label=False, container=False, placeholder='a toy airplane')
                    initial_prompt3 = gr.Textbox(value='', show_label=False, container=False, placeholder='a toy bicycle')
                    initial_prompt4 = gr.Textbox(value='', show_label=False, container=False, placeholder='a toy car')
                    initial_prompt5 = gr.Textbox(value='', show_label=False, container=False, placeholder='a toy boat')
            # One style prompt shared by every subject.
            with gr.Row():
                style_prompt = gr.Textbox(label="Enter a style prompt", placeholder='macro photo, 3d game asset')
            btn = gr.Button("Generate a set of Style-aligned SDXL images",)
    output = gr.Gallery(label="Style-Aligned SDXL", )

    # Order matches the positional parameters of style_aligned_sdxl.
    prompt_inputs = [initial_prompt1, initial_prompt2, initial_prompt3, initial_prompt4, initial_prompt5, style_prompt]

    btn.click(fn=style_aligned_sdxl,
              inputs=prompt_inputs,
              outputs=[output],
              api_name="style_aligned_sdxl")

    # Each row: five subject prompts followed by the style prompt.
    gr.Examples(examples=[
                    ["a toy train", "a toy airplane", "a toy bicycle", "a toy car", "a toy boat", "macro photo. 3d game asset."],
                    ["a toy train", "a toy airplane", "a toy bicycle", "a toy car", "a toy boat", "BW logo. high contrast."],
                    ["a cat", "a dog", "a bear", "a man on a bicycle", "a girl working on laptop", "minimal origami."],
                    ["Firewoman", "Gardener", "Scientist", "Policewoman", "Saxophone player", "made of claymation, stop motion animation."],
                    ["Firewoman", "Gardener", "Scientist", "Policewoman", "Saxophone player", "sketch, character sheet."],
                ],
                inputs=prompt_inputs,
                outputs=[output],
                fn=style_aligned_sdxl)

demo.launch()