import gradio as gr
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
from diffusers.utils import load_image
from transformers import DPTImageProcessor, DPTForDepthEstimation

import sa_handler
import pipeline_calls


# init models

# depth estimator used to build the ControlNet conditioning map
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
feature_processor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")

# SDXL base + depth ControlNet in fp16, with the fp16-safe VAE
controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-depth-sdxl-1.0",
    variant="fp16",
    use_safetensors=True,
    torch_dtype=torch.float16,
).to("cuda")
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
).to("cuda")
pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    vae=vae,
    variant="fp16",
    use_safetensors=True,
    torch_dtype=torch.float16,
).to("cuda")
pipeline.enable_model_cpu_offload()
pipeline.enable_vae_slicing()

# StyleAligned: share self-attention across the batch and apply AdaIN to
# queries and keys so every image adopts the reference image's style
sa_args = sa_handler.StyleAlignedArgs(
    share_group_norm=False,
    share_layer_norm=False,
    share_attention=True,
    adain_queries=True,
    adain_keys=True,
    adain_values=False,
)
handler = sa_handler.Handler(pipeline)
handler.register(sa_args)


# run ControlNet depth with StyleAligned
def style_aligned_controlnet(ref_style_prompt, depth_map, ref_image, img_generation_prompt):
    # either estimate a depth map from the reference image, or treat the
    # uploaded image as a ready-made depth map and just resize it
    if depth_map:
        image = load_image(ref_image)
        depth_image = pipeline_calls.get_depth_map(image, feature_processor, depth_estimator)
    else:
        depth_image = load_image(ref_image).resize((1024, 1024))

    controlnet_conditioning_scale = 0.8
    num_images_per_prompt = 3  # adjust according to VRAM size

    # latents[0] seeds the reference style image; the remaining latents
    # seed the style-aligned results for the target prompt
    latents = torch.randn(1 + num_images_per_prompt, 4, 128, 128).to(pipeline.unet.dtype)
    latents[1:] = torch.randn(num_images_per_prompt, 4, 128, 128).to(pipeline.unet.dtype)

    images = pipeline_calls.controlnet_call(
        pipeline,
        [ref_style_prompt, img_generation_prompt],
        image=depth_image,
        num_inference_steps=50,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_images_per_prompt=num_images_per_prompt,
        latents=latents,
    )
    # gallery order: [style reference, depth map, result 1, result 2, ...],
    # plus the reference image echoed into its own visible component
    return [images[0], depth_image] + images[1:], gr.Image(value=images[0], visible=True)


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(variant='panel'):
            ref_style_prompt = gr.Textbox(
                label='Reference style prompt',
                info="Enter a prompt to generate the reference image",
                placeholder='a poster in