# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Callable, List, Optional, Union

import torch
import torch.utils.checkpoint

from transformers import CLIPFeatureExtractor, CLIPTextModelWithProjection, CLIPTokenizer

from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention import Transformer2DModel
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...utils import is_accelerate_available, logging
from .modeling_text_unet import UNetFlatConditionModel


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Versatile Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Parameters:
        tokenizer ([`~transformers.CLIPTokenizer`]):
            Tokenizer used to turn the text prompt(s) into input ids for `text_encoder`.
        text_encoder ([`~transformers.CLIPTextModelWithProjection`]):
            Text encoder whose projected and normalized hidden states condition `image_unet`.
        image_unet ([`UNet2DConditionModel`]):
            Conditional U-Net architecture to denoise the encoded image latents.
        text_unet ([`UNetFlatConditionModel`], *optional*):
            Only used at construction time: its `Transformer2DModel` blocks are swapped with those of `image_unet`.
            May be `None`, and can be dropped after construction with `remove_unused_weights`.
        vae ([`AutoencoderKL`]):
            Variational auto-encoder used to decode the denoised latents into images.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `image_unet` to denoise the encoded image latents. Can be one
            of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
    """
    tokenizer: CLIPTokenizer
    image_feature_extractor: CLIPFeatureExtractor
    text_encoder: CLIPTextModelWithProjection
    image_unet: UNet2DConditionModel
    text_unet: UNetFlatConditionModel
    vae: AutoencoderKL
    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]

    _optional_components = ["text_unet"]

    def __init__(
        self,
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModelWithProjection,
        image_unet: UNet2DConditionModel,
        text_unet: UNetFlatConditionModel,
        vae: AutoencoderKL,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
    ):
        super().__init__()
        self.register_modules(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            image_unet=image_unet,
            text_unet=text_unet,
            vae=vae,
            scheduler=scheduler,
        )
        # Ratio between pixel resolution and latent resolution, derived from the number of VAE blocks.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

        # `text_unet` is optional; the swap is only possible when it has been provided.
        if self.text_unet is not None:
            self._swap_unet_attention_blocks()

    def _swap_unet_attention_blocks(self):
        """
        Swap the `Transformer2DModel` blocks between the image and text UNets
        """
        for name, module in self.image_unet.named_modules():
            if isinstance(module, Transformer2DModel):
                # `name` ends in the block's index inside its parent container, e.g. "<parent>.<index>".
                parent_name, index = name.rsplit(".", 1)
                index = int(index)
                image_parent = self.image_unet.get_submodule(parent_name)
                text_parent = self.text_unet.get_submodule(parent_name)
                image_parent[index], text_parent[index] = text_parent[index], image_parent[index]

    def remove_unused_weights(self):
        """
        Drop the `text_unet` component by re-registering it as `None`.
        """
        self.register_modules(text_unet=None)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet
    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`.
        """
        if slice_size == "auto":
            if isinstance(self.image_unet.config.attention_head_dim, int):
                # half the attention head size is usually a good trade-off between
                # speed and memory
                slice_size = self.image_unet.config.attention_head_dim // 2
            else:
                # if `attention_head_dim` is a list, take the smallest head size
                slice_size = min(self.image_unet.config.attention_head_dim)

        self.image_unet.set_attention_slice(slice_size)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
    def disable_attention_slicing(self):
        r"""
        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
        back to computing attention in one step.
        """
        # set slice_size = `None` to disable `attention slicing`
        self.enable_attention_slicing(None)

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, `image_unet`,
        `text_unet` (if it is still registered), `text_encoder` and `vae` are handed to accelerate's `cpu_offload`
        with `cuda:{gpu_id}` as the execution device, so that they are loaded to GPU only when their specific
        submodule has its `forward` method called.

        Args:
            gpu_id (`int`, *optional*, defaults to 0):
                Index of the CUDA device the offloaded models are executed on.

        Raises:
            ImportError: If `accelerate` is not installed.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        for cpu_offloaded_model in [self.image_unet, self.text_unet, self.text_encoder, self.vae]:
            # `text_unet` may have been removed via `remove_unused_weights`.
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device with unet->image_unet
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if self.device != torch.device("meta") or not hasattr(self.image_unet, "_hf_hook"):
            return self.device
        for module in self.image_unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`):
                The prompt or prompts not to guide the image generation. Only used when `do_classifier_free_guidance`
                is `True`; `None` is replaced by one empty prompt per input prompt.

        Returns:
            `torch.Tensor`: Normalized text embeddings with one row per generated image. With classifier free
            guidance, the unconditional embeddings are concatenated in front of the text embeddings along the batch
            dimension.

        Raises:
            TypeError: If `negative_prompt` is given and its type differs from the type of `prompt`.
            ValueError: If `negative_prompt` is a list whose length differs from the number of prompts.
        """

        def normalize_embeddings(encoder_output):
            # Project every token state, then scale by the norm of the pooled (projected) text embedding.
            embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state)
            embeds_pooled = encoder_output.text_embeds
            embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True)
            return embeds

        def embed_tokens(tokenized):
            # Run the text encoder on tokenizer output, passing the attention mask only if the encoder's config
            # asks for it, and return the normalized embeddings.
            if getattr(self.text_encoder.config, "use_attention_mask", False):
                attention_mask = tokenized.attention_mask.to(device)
            else:
                attention_mask = None

            encoder_output = self.text_encoder(
                tokenized.input_ids.to(device),
                attention_mask=attention_mask,
            )
            return normalize_embeddings(encoder_output)

        batch_size = len(prompt) if isinstance(prompt, list) else 1

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        # Tokenize a second time without truncation to detect dropped text. Padding to the longest prompt keeps the
        # batch stackable into a single tensor even when some prompts exceed the tokenizer's maximum length.
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        # Only an untruncated sequence at least as long as the truncated one can have lost tokens; a shorter one
        # merely differs by padding.
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        text_embeddings = embed_tokens(text_inputs)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # Match the sequence length of the (truncated) prompt ids so both embeddings can be concatenated.
            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            uncond_embeddings = embed_tokens(uncond_input)

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        return text_embeddings

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
    def check_inputs(self, prompt, height, width, callback_steps):
        if not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if latents is None:
            if device.type == "mps":
                # randn does not work reproducibly on mps
                latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
            else:
                latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`):
                The prompt or prompts to guide the image generation.
            height (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.image_unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e.,
                ignored if `guidance_scale` is not greater than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return an [`ImagePipelineOutput`] instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. Must be a positive integer.
            kwargs:
                Accepted but not used by this method.

        Examples:

        ```py
        >>> from diffusers import VersatileDiffusionTextToImagePipeline
        >>> import torch

        >>> pipe = VersatileDiffusionTextToImagePipeline.from_pretrained(
        ...     "shi-labs/versatile-diffusion", torch_dtype=torch.float16
        ... )
        >>> pipe.remove_unused_weights()
        >>> pipe = pipe.to("cuda")

        >>> generator = torch.Generator(device="cuda").manual_seed(0)
        >>> image = pipe("an astronaut riding on a horse on mars", generator=generator).images[0]
        >>> image.save("./astronaut.png")
        ```

        Returns:
            [`ImagePipelineOutput`] or `tuple`: [`ImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`
            whose only element holds the generated images (PIL images when `output_type` is `"pil"`, a numpy array
            otherwise).
        """
        # 0. Default height and width to unet
        height = height or self.image_unet.config.sample_size * self.vae_scale_factor
        width = width or self.image_unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # 2. Define call parameters
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_embeddings = self._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.image_unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        for i, t in enumerate(self.progress_bar(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.image_unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

        # 8. Post-processing
        image = self.decode_latents(latents)

        # 9. Convert to PIL
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)