oms_sdxl_lcm / app.py
h1t's picture
init demo
d0a7bb3
raw
history blame
4.45 kB
import torch
import gradio as gr
from functools import partial
from diffusers_patch import OMSPipeline
def create_sdxl_lcm_lora_pipe(sd_pipe_name_or_path, oms_name_or_path, lora_name_or_path):
    """Assemble an OMS pipeline on top of an SDXL pipeline patched with LCM-LoRA.

    Args:
        sd_pipe_name_or_path: Hub id or local path of the SDXL base pipeline,
            loaded with ``StableDiffusionXLPipeline.from_pretrained`` in fp16.
        oms_name_or_path: Hub id or local path passed to
            ``OMSPipeline.from_pretrained``.
        lora_name_or_path: Hub id or local path of the LoRA weights loaded into
            the SDXL pipeline via ``load_lora_weights``.

    Returns:
        The ``OMSPipeline`` instance after ``.to('cuda')`` has been called on it.
    """
    from diffusers import LCMScheduler, StableDiffusionXLPipeline

    # Both from_pretrained calls below load half-precision fp16 variants.
    half_precision = {"torch_dtype": torch.float16, "variant": "fp16"}

    base = StableDiffusionXLPipeline.from_pretrained(
        sd_pipe_name_or_path,
        add_watermarker=False,
        **half_precision,
    ).to('cuda')
    print('successfully load pipe')

    # The LCM scheduler is derived from the base pipeline's scheduler config and
    # handed to OMSPipeline below, rather than being installed on `base`.
    lcm_scheduler = LCMScheduler.from_config(base.scheduler.config)
    base.load_lora_weights(lora_name_or_path, variant="fp16")

    oms_pipe = OMSPipeline.from_pretrained(
        oms_name_or_path,
        sd_pipeline=base,
        sd_scheduler=lcm_scheduler,
        trust_remote_code=True,
        **half_precision,
    )
    oms_pipe.to('cuda')
    return oms_pipe
class GradioDemo:
    """Gradio UI comparing SDXL + LCM-LoRA generations with and without OMS.

    Each request renders the same prompt with the same seed up to three times:
    the plain pipeline (``oms_flag=False``), OMS with its guidance scale pinned
    to 1.0, and — only when the requested OMS guidance scale differs from 1.0 —
    OMS with that guidance scale.
    """

    def __init__(
        self,
        sd_pipe_name_or_path="stabilityai/stable-diffusion-xl-base-1.0",
        oms_name_or_path='h1t/oms_b_openclip_xl',
        lora_name_or_path='latent-consistency/lcm-lora-sdxl',
    ):
        """Build the combined pipeline once; every request reuses ``self.pipe``.

        Args:
            sd_pipe_name_or_path: Forwarded to ``create_sdxl_lcm_lora_pipe``.
            oms_name_or_path: Forwarded to ``create_sdxl_lcm_lora_pipe``.
            lora_name_or_path: Forwarded to ``create_sdxl_lcm_lora_pipe``.
        """
        self.pipe = create_sdxl_lcm_lora_pipe(sd_pipe_name_or_path, oms_name_or_path, lora_name_or_path)

    def _generate(self, pipe_kwargs, seed):
        """Run ``self.pipe`` once with a freshly seeded generator; return the first image."""
        # A new generator per call gives every variant an identical starting RNG
        # state, so the outputs differ only by the kwargs.
        generator = torch.Generator(device=self.pipe.device).manual_seed(seed)
        return self.pipe(**pipe_kwargs, generator=generator)['images'][0]

    def _inference(
        self,
        prompt=None,
        oms_prompt=None,
        oms_guidance_scale=1.0,
        num_inference_steps=4,
        sd_pipe_guidance_scale=1.0,
        seed=1024,
    ):
        """Generate the comparison images for one UI request.

        Args:
            prompt: Prompt for the SD pipeline.
            oms_prompt: Prompt for the OMS step (used only in the OMS runs).
            oms_guidance_scale: OMS guidance scale for the third image; a value
                of 1.0 skips that image.
            num_inference_steps: Number of denoising steps for every run.
            sd_pipe_guidance_scale: Passed to the pipeline as ``guidance_scale``.
            seed: Seed for the generator; every run uses the same seed.

        Returns:
            ``(image_raw, image_oms, image_oms_cfg, visibility_update)``.
            ``image_oms_cfg`` is ``None`` when ``oms_guidance_scale == 1.0``.
            ``visibility_update`` is a ``gr.update`` that shows or hides the
            OMS-CFG output column accordingly.
        """
        # Values arrive from gr.Slider components and may be floats:
        # torch.Generator.manual_seed only accepts ints, and the step count is a
        # count, so coerce both before use.
        # NOTE(review): the seed slider allows -1, which is passed straight to
        # manual_seed (there is no "-1 means random" handling) — confirm intended.
        seed = int(seed)
        num_inference_steps = int(num_inference_steps)

        pipe_kwargs = dict(
            prompt=prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=sd_pipe_guidance_scale,
        )

        # 1) Baseline: OMS disabled.
        pipe_kwargs.update(oms_flag=False)
        print(f'raw kwargs: {pipe_kwargs}')
        image_raw = self._generate(pipe_kwargs, seed)

        # 2) OMS enabled, with its guidance scale pinned to 1.0.
        pipe_kwargs.update(oms_flag=True, oms_prompt=oms_prompt, oms_guidance_scale=1.0)
        print(f'w/ oms wo/ cfg kwargs: {pipe_kwargs}')
        image_oms = self._generate(pipe_kwargs, seed)

        # 3) OMS with the requested guidance scale, only when it is not 1.0.
        oms_guidance_flag = oms_guidance_scale != 1.0
        image_oms_cfg = None
        if oms_guidance_flag:
            pipe_kwargs.update(oms_guidance_scale=oms_guidance_scale)
            print(f'w/ oms +cfg kwargs: {pipe_kwargs}')
            image_oms_cfg = self._generate(pipe_kwargs, seed)

        return image_raw, image_oms, image_oms_cfg, gr.update(visible=oms_guidance_flag)

    def mainloop(self):
        """Build the Blocks UI, wire the button to ``_inference``, and launch it."""
        with gr.Blocks() as demo:
            gr.Markdown("# One More Step Demo")
            with gr.Row():
                with gr.Column():
                    prompt = gr.Textbox(label="Prompt", value="a cat")
                    oms_prompt = gr.Textbox(label="OMS Prompt", value="orange car")
                    oms_guidance_scale = gr.Slider(label="OMS Guidance Scale", minimum=1.0, maximum=5.0, value=1.5, step=0.1)
                    run_button = gr.Button(value="Generate images")
                    with gr.Accordion("Advanced options", open=False):
                        num_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=4, step=1)
                        sd_guidance_scale = gr.Slider(label="SD Pipe Guidance Scale", minimum=0.1, maximum=30.0, value=1.0, step=0.1)
                        seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=False, value=1024)
                with gr.Column():
                    output_raw = gr.Image(label="SDXL w/ LCM-LoRA w/o OMS ")
                    output_oms = gr.Image(label="w/ OMS w/o OMS CFG")
                # Hidden until _inference reports that an OMS-CFG image was produced.
                with gr.Column(visible=False) as oms_cfg_wd:
                    output_oms_cfg = gr.Image(label="w/ OMS w/ OMS CFG")

            # Order must match _inference's positional parameters:
            # prompt, oms_prompt, oms_guidance_scale, num_inference_steps,
            # sd_pipe_guidance_scale, seed.
            ips = [prompt, oms_prompt, oms_guidance_scale, num_steps, sd_guidance_scale, seed]
            # Order must match _inference's return tuple.
            run_button.click(fn=self._inference, inputs=ips, outputs=[output_raw, output_oms, output_oms_cfg, oms_cfg_wd])

        # Cap the request queue at 20 pending events.
        demo.queue(max_size=20)
        demo.launch()
if __name__ == "__main__":
    # Script entry point: constructing GradioDemo loads the model pipelines
    # (via create_sdxl_lcm_lora_pipe in __init__); mainloop then builds and
    # launches the Gradio UI.
    gradio_demo = GradioDemo()
    gradio_demo.mainloop()