import random

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    EulerDiscreteScheduler,
    StableDiffusionXLControlNetPipeline,
)
from diffusers.utils import load_image
from gradio_imageslider import ImageSlider
from PIL import Image
from transformers import DPTForDepthEstimation, DPTImageProcessor

device = "cuda"

base_model_id = "SG161222/RealVisXL_V4.0"
controlnet_model_id = "diffusers/controlnet-depth-sdxl-1.0"
vae_model_id = "madebyollin/sdxl-vae-fp16-fix"

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
USE_TORCH_COMPILE = 0
ENABLE_CPU_OFFLOAD = 0

# Load the depth ControlNet, the fp16-fixed SDXL VAE and the base SDXL pipeline.
controlnet = ControlNetModel.from_pretrained(
    controlnet_model_id, variant="fp16", use_safetensors=True, torch_dtype=torch.float16
)
vae = AutoencoderKL.from_pretrained(vae_model_id, torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    base_model_id,
    controlnet=controlnet,
    vae=vae,
    variant="fp16",
    use_safetensors=True,
    torch_dtype=torch.float16,
)
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.enable_xformers_memory_efficient_attention()  # requires xformers to be installed
if ENABLE_CPU_OFFLOAD:
    # CPU offload manages device placement itself, so the pipeline is not moved manually.
    pipe.enable_model_cpu_offload()
else:
    pipe.to(device)

# MiDaS depth estimator used to build the ControlNet conditioning image.
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed


def get_depth_map(image):
    image = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad(), torch.autocast("cuda"):
        depth_map = depth_estimator(image).predicted_depth
    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=(1024, 1024),
        mode="bicubic",
        align_corners=False,
    )
    # Normalize to [0, 1] and replicate to three channels for the ControlNet input.
    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_map = (depth_map - depth_min) / (depth_max - depth_min)
    image = torch.cat([depth_map] * 3, dim=1)
    image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
    image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
    return image


@spaces.GPU(enable_queue=True)
def process(original_image, image_url, prompt, a_prompt, n_prompt, num_steps, guidance_scale, control_strength, seed):
    if image_url:
        original_image = load_image(image_url)
    if a_prompt:
        # Append the additional (quality) prompt to the main prompt.
        prompt = f"{prompt}, {a_prompt}"
    width = 1024
    height = 1024
    depth_image = get_depth_map(original_image.resize((1024, 1024)))
    generator = torch.Generator().manual_seed(int(seed))
    generated_image = pipe(
        prompt=prompt,
        negative_prompt=n_prompt,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_steps,
        controlnet_conditioning_scale=control_strength,  # strength of the depth conditioning
        generator=generator,
        image=depth_image,
    ).images[0]
    return [[depth_image, generated_image], "ok"]


with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image = gr.Image(type="pil")  # PIL image so .resize() works in process()
            image_url = gr.Textbox(label="Image URL", placeholder="Enter image URL here (optional)")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button("Run")
            with gr.Accordion("Advanced options", open=True):
                num_steps = gr.Slider(label="Number of steps", minimum=1, maximum=100, value=30, step=1)
                guidance_scale = gr.Slider(label="Guidance scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
                control_strength = gr.Slider(label="Control strength", minimum=0.1, maximum=4.0, value=0.8, step=0.1)
                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                a_prompt = gr.Textbox(label="Additional prompt", value="high-quality, extremely detailed, 4K")
                n_prompt = gr.Textbox(
                    label="Negative prompt",
                    value="longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality",
                )
        with gr.Column():
            # Before/after slider showing the depth map next to the generated image.
            result = ImageSlider(label="Generated image", type="pil", slider_color="pink")
            logs = gr.Textbox(label="Logs")

    inputs = [
        image, image_url, prompt, a_prompt, n_prompt, num_steps, guidance_scale, control_strength, seed
    ]
    run_button.click(
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=process,
        inputs=inputs,
        outputs=[result, logs],
        api_name=False,
    )

demo.queue().launch()