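"""Gradio demo for One More Step (OMS) with SDXL and LCM-LoRA.

Renders side-by-side comparisons of four variants: the raw SDXL + LCM-LoRA
output, OMS with a consistent prompt, OMS with a separate (inconsistent)
prompt, and OMS with classifier-free guidance on the OMS module.
"""
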
import torch
import gradio as gr

from diffusers_patch import OMSPipeline


def create_sdxl_lcm_lora_pipe(sd_pipe_name_or_path, oms_name_or_path, lora_name_or_path):
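    """Load SDXL in fp16, switch to an LCM scheduler, attach LCM-LoRA weights,
    and wrap the result in an OMSPipeline.
    """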
    from diffusers import StableDiffusionXLPipeline, LCMScheduler
    sd_pipe = StableDiffusionXLPipeline.from_pretrained(sd_pipe_name_or_path, torch_dtype=torch.float16, variant="fp16", add_watermarker=False).to('cuda')
    print('successfully loaded SDXL pipeline')
    sd_scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
    sd_pipe.load_lora_weights(lora_name_or_path, variant="fp16")

    pipe = OMSPipeline.from_pretrained(oms_name_or_path, sd_pipeline=sd_pipe, torch_dtype=torch.float16, variant="fp16", trust_remote_code=True, sd_scheduler=sd_scheduler)
    pipe.to('cuda')
    
    return pipe
        

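# Minimal standalone usage sketch (assumes a CUDA device and access to the
# default checkpoints wired up in GradioDemo below):
#
#   pipe = create_sdxl_lcm_lora_pipe(
#       "stabilityai/stable-diffusion-xl-base-1.0",
#       "h1t/oms_b_openclip_xl",
#       "latent-consistency/lcm-lora-sdxl",
#   )
#   image = pipe(
#       prompt="a cat against orange ground, studio",
#       num_inference_steps=4,
#       guidance_scale=1.0,
#       oms_flag=True,
#       oms_prompt="a black cat",
#       oms_guidance_scale=1.5,
#       generator=torch.Generator(device=pipe.device).manual_seed(1024),
#   )['images'][0]
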
class GradioDemo:
    def __init__(
        self, 
        sd_pipe_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0", 
        oms_name_or_path = 'h1t/oms_b_openclip_xl', 
        lora_name_or_path = 'latent-consistency/lcm-lora-sdxl'
    ):
        self.pipe = create_sdxl_lcm_lora_pipe(sd_pipe_name_or_path, oms_name_or_path, lora_name_or_path)
        
    def _inference(
            self,
            prompt = None,
            oms_prompt = None,
            oms_guidance_scale = 1.0,
            num_inference_steps = 4,
            sd_pipe_guidance_scale = 1.0,
            seed = 1024,
            oms_prompt_flag=True,
        ):
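        """Run up to four generations from the same seed: the SDXL + LCM-LoRA
        baseline, OMS with the consistent prompt, optionally OMS with a
        separate (inconsistent) prompt, and optionally OMS with classifier-free
        guidance. The optional images are wrapped in gr.update so the UI can
        hide their panes when they are not produced.
        """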
        seed = int(seed)  # sliders may deliver floats; manual_seed expects an int
        pipe_kwargs = dict(
            prompt = prompt,
            num_inference_steps = num_inference_steps,
            guidance_scale = sd_pipe_guidance_scale,
        )

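        # 1) Baseline: OMS disabled, plain SDXL + LCM-LoRA sampling.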
        generator = torch.Generator(device=self.pipe.device).manual_seed(seed)
        pipe_kwargs.update(oms_flag=False)
        print(f'raw kwargs: {pipe_kwargs}')
        image_raw = self.pipe(
            **pipe_kwargs,
            generator=generator
        )['images'][0]
        
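        # 2) OMS enabled with the consistent prompt (oms_prompt == prompt) and
        #    no OMS guidance (scale 1.0).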
        generator = torch.Generator(device=self.pipe.device).manual_seed(seed)
        pipe_kwargs.update(oms_flag=True, oms_prompt=prompt, oms_guidance_scale=1.0)
        print(f'w/ oms w/o cfg (consistent prompt) kwargs: {pipe_kwargs}')
        image_oms_cp = self.pipe(
            **pipe_kwargs,
            generator=generator
        )['images'][0]
        
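        # 3) Optionally rerun OMS with the user-supplied (inconsistent) prompt
        #    for separate control over low-frequency content.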
        if oms_prompt_flag:
            generator = torch.Generator(device=self.pipe.device).manual_seed(seed)
            pipe_kwargs.update(oms_prompt=oms_prompt)
            print(f'w/ oms w/o cfg (inconsistent prompt) kwargs: {pipe_kwargs}')
            image_oms_icp = self.pipe(
                **pipe_kwargs,
                generator=generator
            )['images'][0]
        else:
            image_oms_icp = None
            
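        # 4) Optionally rerun with OMS classifier-free guidance; a scale of 1.0
        #    is a no-op, so that case is skipped.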
        oms_guidance_flag = oms_guidance_scale != 1.0
        if oms_guidance_flag:
            generator = torch.Generator(device=self.pipe.device).manual_seed(seed)
            pipe_kwargs.update(oms_guidance_scale=oms_guidance_scale)
            print(f'w/ oms w/ cfg kwargs: {pipe_kwargs}')
            image_oms_cfg = self.pipe(
                **pipe_kwargs,
                generator=generator
            )['images'][0]
        else:
            image_oms_cfg = None
        
        return (
            image_raw,
            image_oms_cp,
            gr.update(value=image_oms_icp, visible=oms_prompt_flag),
            gr.update(value=image_oms_cfg, visible=oms_guidance_flag),
        )
    
    def mainloop(self):
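        """Build the Gradio Blocks UI and launch the demo."""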
        with gr.Blocks() as demo:
            gr.Markdown("# One More Step for SDXL w/ LCM-LoRA")
            
            with gr.Group() as inputs:
                prompt = gr.Textbox(label="Prompt", value="a cat against orange ground, studio")
                with gr.Accordion('OMS Prompt'):
                    oms_prompt_checkbox = gr.Checkbox(info="An inconsistent OMS prompt allows additional control over low-frequency information; by default it matches Prompt.", label="Adding OMS Prompt", value=True)
                    oms_prompt = gr.Textbox(label="OMS Prompt", value="a black cat", info='try "a black cat" and "a black room" for diverse control.')
                with gr.Accordion('OMS Guidance'):
                    oms_cfg_scale_checkbox = gr.Checkbox(info="OMS guidance strengthens the OMS prompt, especially its effect on color and brightness.", label="Adding OMS Guidance", value=True)
                    oms_guidance_scale = gr.Slider(label="OMS Guidance Scale", minimum=1.0, maximum=5.0, value=2., step=0.1)
                run_button = gr.Button(value="Generate images")
                with gr.Accordion("Advanced options", open=False):
                    num_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=4, step=1)
                    sd_guidance_scale = gr.Slider(label="SD Pipe Guidance Scale", minimum=1, maximum=3, value=1.0, step=0.1)
                    seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=False, value=1024)
            with gr.Row():
                output_raw = gr.Image(label="SDXL w/ LCM-LoRA")
                output_oms_cp = gr.Image(label="w/ OMS (consistent) w/o OMS CFG")
                output_oms_icp = gr.Image(label="w/ OMS (inconsistent) w/o OMS CFG")
                output_oms_cfg = gr.Image(label="w/ OMS w/ OMS CFG")

            oms_prompt_checkbox.input(
                fn=lambda oms_prompt_flag, prompt, oms_prompt: gr.update(
                    value=oms_prompt if oms_prompt_flag else prompt,
                    interactive=oms_prompt_flag,
                ),
                inputs=[oms_prompt_checkbox, prompt, oms_prompt],
                outputs=[oms_prompt],
            )
            oms_cfg_scale_checkbox.input(
                fn=lambda oms_cfg_scale_flag: gr.update(
                    value=1.5 if oms_cfg_scale_flag else 1.0,
                    interactive=oms_cfg_scale_flag,
                ),
                inputs=[oms_cfg_scale_checkbox],
                outputs=[oms_guidance_scale],
            )
            
            ips = [prompt, oms_prompt, oms_guidance_scale, num_steps, sd_guidance_scale, seed, oms_prompt_checkbox]
            run_button.click(fn=self._inference, inputs=ips, outputs=[output_raw, output_oms_cp, output_oms_icp, output_oms_cfg])

            demo.queue(max_size=20)
            demo.launch()


if __name__ == "__main__":
    gradio_demo = GradioDemo()
    gradio_demo.mainloop()
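
# To point the demo at different checkpoints, construct it explicitly
# (paths below are hypothetical placeholders):
#
#   GradioDemo(
#       sd_pipe_name_or_path="/path/to/sdxl_base",    # hypothetical
#       oms_name_or_path="/path/to/oms_module",       # hypothetical
#       lora_name_or_path="/path/to/lcm_lora",        # hypothetical
#   ).mainloop()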