import argparse import inspect import os import time import warnings from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL.Image import torch from PIL import Image from transformers import CLIPTokenizer from diffusers import OnnxRuntimeModel, StableDiffusionImg2ImgPipeline, UniPCMultistepScheduler from diffusers.image_processor import VaeImageProcessor from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.utils import ( deprecate, logging, randn_tensor, replace_example_docstring, ) logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> # !pip install opencv-python transformers accelerate >>> from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, UniPCMultistepScheduler >>> from diffusers.utils import load_image >>> import numpy as np >>> import torch >>> import cv2 >>> from PIL import Image >>> # download an image >>> image = load_image( ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" ... ) >>> np_image = np.array(image) >>> # get canny image >>> np_image = cv2.Canny(np_image, 100, 200) >>> np_image = np_image[:, :, None] >>> np_image = np.concatenate([np_image, np_image, np_image], axis=2) >>> canny_image = Image.fromarray(np_image) >>> # load control net and stable diffusion v1-5 >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) >>> pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained( ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 ... ) >>> # speed up diffusion process with faster scheduler and memory optimization >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) >>> pipe.enable_model_cpu_offload() >>> # generate image >>> generator = torch.manual_seed(0) >>> image = pipe( ... "futuristic-looking woman", ... num_inference_steps=20, ... generator=generator, ... image=image, ... control_image=canny_image, ... ).images[0] ``` """ def prepare_image(image): if isinstance(image, torch.Tensor): # Batch single image if image.ndim == 3: image = image.unsqueeze(0) image = image.to(dtype=torch.float32) else: # preprocess image if isinstance(image, (PIL.Image.Image, np.ndarray)): image = [image] if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): image = [np.array(i.convert("RGB"))[None, :] for i in image] image = np.concatenate(image, axis=0) elif isinstance(image, list) and isinstance(image[0], np.ndarray): image = np.concatenate([i[None, :] for i in image], axis=0) image = image.transpose(0, 3, 1, 2) image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 return image class OnnxStableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline): vae_encoder: OnnxRuntimeModel vae_decoder: OnnxRuntimeModel text_encoder: OnnxRuntimeModel tokenizer: CLIPTokenizer unet: OnnxRuntimeModel scheduler: KarrasDiffusionSchedulers def __init__( self, vae_encoder: OnnxRuntimeModel, vae_decoder: OnnxRuntimeModel, text_encoder: OnnxRuntimeModel, tokenizer: CLIPTokenizer, unet: OnnxRuntimeModel, scheduler: KarrasDiffusionSchedulers, ): super().__init__() self.register_modules( vae_encoder=vae_encoder, vae_decoder=vae_decoder, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, ) self.vae_scale_factor = 2 ** (4 - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.control_image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False ) def _encode_prompt( self, prompt: Union[str, List[str]], num_images_per_prompt: Optional[int], do_classifier_free_guidance: bool, negative_prompt: Optional[str], prompt_embeds: Optional[np.ndarray] = None, negative_prompt_embeds: Optional[np.ndarray] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`): prompt to be encoded num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`np.ndarray`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`np.ndarray`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. """ if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # get prompt text embeddings text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="np", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids if not np.array_equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0] prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] * batch_size elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="np", ) negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0] if do_classifier_free_guidance: negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = np.concatenate([negative_prompt_embeds, prompt_embeds]) return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): warnings.warn( "The decode_latents method is deprecated and will be removed in a future version. Please" " use VaeImageProcessor instead", FutureWarning, ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs( self, num_controlnet, prompt, image, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, controlnet_conditioning_scale=1.0, control_guidance_start=0.0, control_guidance_end=1.0, ): if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) # Check `image` if num_controlnet == 1: self.check_image(image, prompt, prompt_embeds) elif num_controlnet > 1: if not isinstance(image, list): raise TypeError("For multiple controlnets: `image` must be type `list`") # When `image` is a nested list: # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) elif any(isinstance(i, list) for i in image): raise ValueError("A single batch of multiple conditionings are supported at the moment.") elif len(image) != num_controlnet: raise ValueError( f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {num_controlnet} ControlNets." ) for image_ in image: self.check_image(image_, prompt, prompt_embeds) else: assert False # Check `controlnet_conditioning_scale` if num_controlnet == 1: if not isinstance(controlnet_conditioning_scale, float): raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") elif num_controlnet > 1: if isinstance(controlnet_conditioning_scale, list): if any(isinstance(i, list) for i in controlnet_conditioning_scale): raise ValueError("A single batch of multiple conditionings are supported at the moment.") elif ( isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != num_controlnet ): raise ValueError( "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" " the same length as the number of controlnets" ) else: assert False if len(control_guidance_start) != len(control_guidance_end): raise ValueError( f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." ) if num_controlnet > 1: if len(control_guidance_start) != num_controlnet: raise ValueError( f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {num_controlnet} controlnets available. Make sure to provide {num_controlnet}." ) for start, end in zip(control_guidance_start, control_guidance_end): if start >= end: raise ValueError( f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." ) if start < 0.0: raise ValueError(f"control guidance start: {start} can't be smaller than 0.") if end > 1.0: raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image def check_image(self, image, prompt, prompt_embeds): image_is_pil = isinstance(image, PIL.Image.Image) image_is_tensor = isinstance(image, torch.Tensor) image_is_np = isinstance(image, np.ndarray) image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) if ( not image_is_pil and not image_is_tensor and not image_is_np and not image_is_pil_list and not image_is_tensor_list and not image_is_np_list ): raise TypeError( f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" ) if image_is_pil: image_batch_size = 1 else: image_batch_size = len(image) if prompt is not None and isinstance(prompt, str): prompt_batch_size = 1 elif prompt is not None and isinstance(prompt, list): prompt_batch_size = len(prompt) elif prompt_embeds is not None: prompt_batch_size = prompt_embeds.shape[0] if image_batch_size != 1 and image_batch_size != prompt_batch_size: raise ValueError( f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" ) # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image def prepare_control_image( self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance=False, guess_mode=False, ): image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) image_batch_size = image.shape[0] if image_batch_size == 1: repeat_by = batch_size else: # image batch size is the same as prompt batch size repeat_by = num_images_per_prompt image = image.repeat_interleave(repeat_by, dim=0) image = image.to(device=device, dtype=dtype) if do_classifier_free_guidance and not guess_mode: image = torch.cat([image] * 2) return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] return timesteps, num_inference_steps - t_start def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): raise ValueError( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) image = image.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt if image.shape[1] == 4: init_latents = image else: _image = image.cpu().detach().numpy() init_latents = self.vae_encoder(sample=_image)[0] init_latents = torch.from_numpy(init_latents).to(device=device, dtype=dtype) init_latents = 0.18215 * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size deprecation_message = ( f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" " your script to pass as many initial images as text prompts to suppress this warning." ) deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) additional_image_per_prompt = batch_size // init_latents.shape[0] init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." ) else: init_latents = torch.cat([init_latents], dim=0) shape = init_latents.shape noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # get latents init_latents = self.scheduler.add_noise(init_latents, noise, timestep) latents = init_latents return latents @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, num_controlnet: int, fp16: bool = True, prompt: Union[str, List[str]] = None, image: Union[ torch.FloatTensor, PIL.Image.Image, np.ndarray, List[torch.FloatTensor], List[PIL.Image.Image], List[np.ndarray], ] = None, control_image: Union[ torch.FloatTensor, PIL.Image.Image, np.ndarray, List[torch.FloatTensor], List[PIL.Image.Image], List[np.ndarray], ] = None, height: Optional[int] = None, width: Optional[int] = None, strength: float = 0.8, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 0.8, guess_mode: bool = False, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The initial image will be used as the starting point for the image generation process. Can also accpet image latents as `image`, if passing latents directly, it will not be encoded again. control_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in init, images must be passed as a list such that each element of the list can be correctly batched for input to a single controlnet. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. guidance_scale (`float`, *optional*, defaults to 7.5): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. callback (`Callable`, *optional*): A function that will be called every `callback_steps` steps during inference. The function will be called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added to the residual in the original unet. If multiple ControlNets are specified in init, you can set the corresponding scale as a list. Note that by default, we use a smaller conditioning scale for inpainting than for [`~StableDiffusionControlNetPipeline.__call__`]. guess_mode (`bool`, *optional*, defaults to `False`): In this mode, the ControlNet encoder will try best to recognize the content of the input image even if you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): The percentage of total steps at which the controlnet starts applying. control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): The percentage of total steps at which the controlnet stops applying. Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ if fp16: torch_dtype = torch.float16 np_dtype = np.float16 else: torch_dtype = torch.float32 np_dtype = np.float32 # align format for control guidance if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): control_guidance_start = len(control_guidance_end) * [control_guidance_start] elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): control_guidance_end = len(control_guidance_start) * [control_guidance_end] elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): mult = num_controlnet control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [ control_guidance_end ] # 1. Check inputs. Raise error if not correct self.check_inputs( num_controlnet, prompt, control_image, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, controlnet_conditioning_scale, control_guidance_start, control_guidance_end, ) # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 if num_controlnet > 1 and isinstance(controlnet_conditioning_scale, float): controlnet_conditioning_scale = [controlnet_conditioning_scale] * num_controlnet # 3. Encode input prompt prompt_embeds = self._encode_prompt( prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) # 4. Prepare image image = self.image_processor.preprocess(image).to(dtype=torch.float32) # 5. Prepare controlnet_conditioning_image if num_controlnet == 1: control_image = self.prepare_control_image( image=control_image, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=torch_dtype, do_classifier_free_guidance=do_classifier_free_guidance, guess_mode=guess_mode, ) elif num_controlnet > 1: control_images = [] for control_image_ in control_image: control_image_ = self.prepare_control_image( image=control_image_, width=width, height=height, batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, dtype=torch_dtype, do_classifier_free_guidance=do_classifier_free_guidance, guess_mode=guess_mode, ) control_images.append(control_image_) control_image = control_images else: assert False # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. Prepare latent variables latents = self.prepare_latents( image, latent_timestep, batch_size, num_images_per_prompt, torch_dtype, device, generator, ) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7.1 Create tensor stating which controlnets to keep controlnet_keep = [] for i in range(len(timesteps)): keeps = [ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) for s, e in zip(control_guidance_start, control_guidance_end) ] controlnet_keep.append(keeps[0] if num_controlnet == 1 else keeps) # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) if isinstance(controlnet_keep[i], list): cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] else: controlnet_cond_scale = controlnet_conditioning_scale if isinstance(controlnet_cond_scale, list): controlnet_cond_scale = controlnet_cond_scale[0] cond_scale = controlnet_cond_scale * controlnet_keep[i] # predict the noise residual _latent_model_input = latent_model_input.cpu().detach().numpy() _prompt_embeds = np.array(prompt_embeds, dtype=np_dtype) _t = np.array([t.cpu().detach().numpy()], dtype=np_dtype) if num_controlnet == 1: control_images = np.array([control_image], dtype=np_dtype) else: control_images = [] for _control_img in control_image: _control_img = _control_img.cpu().detach().numpy() control_images.append(_control_img) control_images = np.array(control_images, dtype=np_dtype) control_scales = np.array(cond_scale, dtype=np_dtype) control_scales = np.resize(control_scales, (num_controlnet, 1)) noise_pred = self.unet( sample=_latent_model_input, timestep=_t, encoder_hidden_states=_prompt_embeds, controlnet_conds=control_images, conditioning_scales=control_scales, )[0] noise_pred = torch.from_numpy(noise_pred).to(device) # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) if not output_type == "latent": _latents = latents.cpu().detach().numpy() / 0.18215 _latents = np.array(_latents, dtype=np_dtype) image = self.vae_decoder(latent_sample=_latents)[0] image = torch.from_numpy(image).to(device, dtype=torch.float32) has_nsfw_concept = None else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) if not return_dict: return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--sd_model", type=str, required=True, help="Path to the `diffusers` checkpoint to convert (either a local directory or on the Hub).", ) parser.add_argument( "--onnx_model_dir", type=str, required=True, help="Path to the ONNX directory", ) parser.add_argument("--qr_img_path", type=str, required=True, help="Path to the qr code image") args = parser.parse_args() qr_image = Image.open(args.qr_img_path) qr_image = qr_image.resize((512, 512)) # init stable diffusion pipeline pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(args.sd_model) pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) provider = ["CUDAExecutionProvider", "CPUExecutionProvider"] onnx_pipeline = OnnxStableDiffusionControlNetImg2ImgPipeline( vae_encoder=OnnxRuntimeModel.from_pretrained( os.path.join(args.onnx_model_dir, "vae_encoder"), provider=provider ), vae_decoder=OnnxRuntimeModel.from_pretrained( os.path.join(args.onnx_model_dir, "vae_decoder"), provider=provider ), text_encoder=OnnxRuntimeModel.from_pretrained( os.path.join(args.onnx_model_dir, "text_encoder"), provider=provider ), tokenizer=pipeline.tokenizer, unet=OnnxRuntimeModel.from_pretrained(os.path.join(args.onnx_model_dir, "unet"), provider=provider), scheduler=pipeline.scheduler, ) onnx_pipeline = onnx_pipeline.to("cuda") prompt = "a cute cat fly to the moon" negative_prompt = "paintings, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, age spot, glans, nsfw, nipples, necklace, worst quality, low quality, watermark, username, signature, multiple breasts, lowres, bad anatomy, bad hands, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, bad feet, single color, ugly, duplicate, morbid, mutilated, tranny, trans, trannsexual, hermaphrodite, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, disfigured, bad anatomy, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, mutated hands, fused fingers, too many fingers, long neck, bad body perspect" for i in range(10): start_time = time.time() image = onnx_pipeline( num_controlnet=2, prompt=prompt, negative_prompt=negative_prompt, image=qr_image, control_image=[qr_image, qr_image], width=512, height=512, strength=0.75, num_inference_steps=20, num_images_per_prompt=1, controlnet_conditioning_scale=[0.8, 0.8], control_guidance_start=[0.3, 0.3], control_guidance_end=[0.9, 0.9], ).images[0] print(time.time() - start_time) image.save("output_qr_code.png")