# Source: app.py from a Hugging Face Space by fffiloni
# (commit 8d5c6a5, "Update app.py"; 1.64 kB).
import gradio as gr
import spaces
from diffusers import StableDiffusionXLPipeline, DDIMScheduler
import torch
import sa_handler
# --- Model initialisation (runs once, at import time) ---------------------

# DDIM sampler handed to the SDXL pipeline below.
scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
)

# SDXL base weights in half precision, loaded from the fp16 safetensors variant.
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
    scheduler=scheduler,
)
pipeline = pipeline.to("cuda")
# NOTE(review): diffusers advises against moving a pipeline to CUDA before
# enabling model CPU offload; offload moves the modules back to CPU anyway.
# Confirm which of the two is actually wanted for this (ZeroGPU) Space.
pipeline.enable_model_cpu_offload()
pipeline.enable_vae_slicing()

# Hook the StyleAligned handler into the pipeline: attention sharing plus
# AdaIN on queries and keys, with no group/layer-norm sharing and no AdaIN
# on values.
handler = sa_handler.Handler(pipeline)
sa_args = sa_handler.StyleAlignedArgs(
    share_group_norm=False,
    share_layer_norm=False,
    share_attention=True,
    adain_queries=True,
    adain_keys=True,
    adain_values=False,
)
handler.register(sa_args)
# Run StyleAligned on a fixed batch of prompts.
@spaces.GPU
def infer(prompts):
    """Generate images for a fixed batch of "toy" prompts.

    Args:
        prompts: Text coming from the Gradio textbox. It is not used; the
            batch of prompts below is hard-coded.

    Returns:
        The ``images`` attribute of the pipeline output, which is displayed
        in the Gradio gallery.
    """
    # NOTE(review): ``prompts`` is never read, so user input has no effect on
    # the output. This looks intentional for the demo, since the textbox
    # default is "Hit submit button to test". Confirm before wiring it in.
    toy_prompts = [
        "a toy train. macro photo. 3d game asset",
        "a toy airplane. macro photo. 3d game asset",
        "a toy bicycle. macro photo. 3d game asset",
        "a toy car. macro photo. 3d game asset",
        "a toy boat. macro photo. 3d game asset",
    ]
    # Keep this a list: it is passed to the pipeline as one batch.
    return pipeline(toy_prompts).images
# --- Gradio UI -------------------------------------------------------------
# One textbox in (its value is passed to ``infer``) and a gallery out. The
# interface is bound to a name before launching so it can be inspected or
# reused; it is still launched at import time, exactly as before.
demo = gr.Interface(
    fn=infer,
    inputs=[gr.Textbox(value="Hit submit button to test")],
    outputs=[gr.Gallery()],
    title="Style Aligned Image Generation",
)
demo.launch()