import gradio as gr from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL from diffusers.utils import load_image from transformers import DPTImageProcessor, DPTForDepthEstimation import torch import mediapy import sa_handler import pipeline_calls # init models depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda") feature_processor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas") controlnet = ControlNetModel.from_pretrained( "diffusers/controlnet-depth-sdxl-1.0", variant="fp16", use_safetensors=True, torch_dtype=torch.float16, ).to("cuda") vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to("cuda") pipeline = StableDiffusionXLControlNetPipeline.from_pretrained( "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, vae=vae, variant="fp16", use_safetensors=True, torch_dtype=torch.float16, ).to("cuda") pipeline.enable_model_cpu_offload() sa_args = sa_handler.StyleAlignedArgs(share_group_norm=False, share_layer_norm=False, share_attention=True, adain_queries=True, adain_keys=True, adain_values=False, ) handler = sa_handler.Handler(pipeline) handler.register(sa_args, ) # get depth maps def get_depth_maps(image): image = load_image(image) #("./example_image/train.png") depth_image1 = pipeline_calls.get_depth_map(image, feature_processor, depth_estimator) #depth_image2 = load_image("./example_image/sun.png").resize((1024, 1024)) #mediapy.show_images([depth_image1, depth_image2]) return depth_image1 #[depth_image1, depth_image2] # run ControlNet depth with StyleAligned def style_aligned_controlnet(reference_prompt, target_prompt, image) #reference_prompt = "a poster in flat design style" #target_prompts = [target_prompts] #["a train in flat design style", "the sun in flat design style"] controlnet_conditioning_scale = 0.8 num_images_per_prompt = 1 # adjust according to VRAM size depth_map = get_depth_maps(image) latents = torch.randn(1 + num_images_per_prompt, 4, 128, 128).to(pipeline.unet.dtype) #for deph_map, target_prompt in zip((depth_image1, depth_image2), target_prompts): latents[1:] = torch.randn(num_images_per_prompt, 4, 128, 128).to(pipeline.unet.dtype) images = pipeline_calls.controlnet_call(pipeline, [reference_prompt, target_prompt], image=deph_map, num_inference_steps=50, controlnet_conditioning_scale=controlnet_conditioning_scale, num_images_per_prompt=num_images_per_prompt, latents=latents) print(f"images -{images}") return images[0] #mediapy.show_images([images[0], deph_map] + images[1:], titles=["reference", "depth"] + [f'result {i}' for i in range(1, len(images))]) # run StyleAligned sets_of_prompts = [ "a toy train. macro photo. 3d game asset", "a toy airplane. macro photo. 3d game asset", "a toy bicycle. macro photo. 3d game asset", "a toy car. macro photo. 3d game asset", "a toy boat. macro photo. 3d game asset", ] with gr.Blocks() as demo: with gr.Row(variant='panel'): with gr.Group(): gr.Markdown("###
Reference Prompt and Image
") ref_prompt = gr.Textbox(label="Enter a Prompt describing the reference image", placeholder='a photo of in