"""Gradio demo: generate an SDXL image and show its per-token cross-attention maps.

Attention hooks from the local ``utils`` module are installed on the UNet.
After each generation, the collected maps are resized to the image size and
shown in a gallery, with the tokenizer's start/end markers filtered out.
"""
import torch
import gradio as gr
from diffusers import StableDiffusionXLPipeline

from utils import (
    cross_attn_init,
    register_cross_attention_hook,
    attn_maps,
    get_net_attn_map,
    resize_net_attn_map,
    return_net_attn_map,
)

# Alternative checkpoint that also works with this demo: "stabilityai/sdxl-turbo".
MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"

# Token labels (the part after the last "_" in each gallery caption produced
# by return_net_attn_map) that should not be shown in the gallery.
SPECIAL_TOKEN_LABELS = frozenset({"<<|startoftext|>>", "<<|endoftext|>>"})

DEFAULT_PROMPT = (
    "A portrait photo of a kangaroo wearing an orange hoodie and blue sunglasses "
    "standing on the grass in front of the Sydney Opera House holding a sign on "
    "the chest that says 'SDXL'!."
)

cross_attn_init()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# float16 is only a sensible default on CUDA; on CPU many half-precision ops are
# unsupported or extremely slow, so fall back to float32 there.
dtype = torch.float16 if device.type == "cuda" else torch.float32

pipe = StableDiffusionXLPipeline.from_pretrained(MODEL_ID, torch_dtype=dtype)
pipe.unet = register_cross_attention_hook(pipe.unet)
pipe = pipe.to(device)


def inference(prompt, num_inference_steps=15):
    """Generate one image for ``prompt`` and return it with its attention maps.

    Args:
        prompt: Text prompt passed to the SDXL pipeline.
        num_inference_steps: Number of denoising steps (default 15).

    Returns:
        A ``(image, net_attn_maps)`` tuple. ``image`` is the first generated
        image. ``net_attn_maps`` is the list built by ``return_net_attn_map``,
        with the start/end-token entries removed. Each entry's second element
        is a caption string whose last "_"-separated part is the token label.
        Presumably these are (image, caption) pairs for ``gr.Gallery``; confirm
        against utils.
    """
    # NOTE(review): `attn_maps` (imported from utils) is never reset here. If the
    # hooks accumulate into it across calls, maps from earlier prompts could
    # leak into later results. Confirm against utils.
    image = pipe(prompt, num_inference_steps=num_inference_steps).images[0]

    net_attn_maps = get_net_attn_map(image.size)
    net_attn_maps = resize_net_attn_map(net_attn_maps, image.size)
    net_attn_maps = return_net_attn_map(net_attn_maps, pipe.tokenizer, prompt)

    # Drop the start-of-text / end-of-text entries in a single pass.
    net_attn_maps = [
        attn_map
        for attn_map in net_attn_maps
        if attn_map[1].split("_")[-1] not in SPECIAL_TOKEN_LABELS
    ]
    return image, net_attn_maps


with gr.Blocks() as demo:
    gr.Markdown("# 🚀 Text-to-Image Cross Attention Map for 🧨 Diffusers ⚡")
    # Shorter alternative default: "A photo of a black puppy, christmas atmosphere"
    prompt = gr.Textbox(value=DEFAULT_PROMPT, label="Prompt", lines=2)
    btn = gr.Button("Generate images", scale=0)

    with gr.Row():
        image = gr.Image(height=512, width=512, type="pil")
        gallery = gr.Gallery(
            value=None,
            label="Generated images",
            show_label=False,
            elem_id="gallery",
            object_fit="contain",
            height="auto",
        )

    btn.click(inference, prompt, [image, gallery])


if __name__ == "__main__":
    demo.launch(share=True)