"""
    modeled after the textual_inversion.py / train_dreambooth.py and the work
    of justinpinkney here: https://github.com/justinpinkney/stable-diffusion/blob/main/notebooks/imagic.ipynb
"""

import inspect
import warnings
from typing import List, Optional, Union

import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from accelerate import Accelerator

from packaging import version
from tqdm.auto import tqdm
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from diffusers import DiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils import logging


# Pillow >= 9.1.0 moved the resampling filters to the `PIL.Image.Resampling` enum,
# so pick the right constants for the installed version.
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
    PIL_INTERPOLATION = {
        "linear": PIL.Image.Resampling.BILINEAR,
        "bilinear": PIL.Image.Resampling.BILINEAR,
        "bicubic": PIL.Image.Resampling.BICUBIC,
        "lanczos": PIL.Image.Resampling.LANCZOS,
        "nearest": PIL.Image.Resampling.NEAREST,
    }
else:
    PIL_INTERPOLATION = {
        "linear": PIL.Image.LINEAR,
        "bilinear": PIL.Image.BILINEAR,
        "bicubic": PIL.Image.BICUBIC,
        "lanczos": PIL.Image.LANCZOS,
        "nearest": PIL.Image.NEAREST,
    }


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def preprocess(image):
    # Round each side down to a multiple of 32, then convert the [0, 255] HWC
    # PIL image into a [-1, 1] float NCHW tensor.
    w, h = image.size
    w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
    image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0
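

# Worked example of `preprocess` (illustrative): a 513x769 RGB PIL image is resized
# down to 512x768 and returned as a float tensor of shape (1, 3, 768, 512) with
# values in [-1, 1].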
					
					
						
class ImagicStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
    r"""
    Pipeline for Imagic image editing.
    See paper here: https://arxiv.org/pdf/2210.09276.pdf

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
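
    Example:
        A minimal usage sketch (illustrative; it assumes this module is loadable as the `diffusers`
        community pipeline `"imagic_stable_diffusion"`, and the image URL is a placeholder):

        ```py
        import PIL.Image
        import requests
        from diffusers import DiffusionPipeline

        pipe = DiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            custom_pipeline="imagic_stable_diffusion",
        ).to("cuda")

        url = "https://example.com/bird.png"  # hypothetical input image
        init_image = PIL.Image.open(requests.get(url, stream=True).raw).convert("RGB")

        # Optimize the text embedding and fine-tune the model on the input image ...
        pipe.train("A photo of a bird spreading wings.", image=init_image)

        # ... then generate the edit; `alpha` blends the original and optimized embeddings.
        image = pipe(alpha=1.2, guidance_scale=7.5, num_inference_steps=50).images[0]
        ```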
					
					
						
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
    ):
        super().__init__()
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    def train(
        self,
        prompt: Union[str, List[str]],
        image: Union[torch.Tensor, PIL.Image.Image],
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        generator: Optional[torch.Generator] = None,
        embedding_learning_rate: float = 0.001,
        diffusion_model_learning_rate: float = 2e-6,
        text_embedding_optimization_steps: int = 500,
        model_fine_tuning_optimization_steps: int = 1000,
        **kwargs,
    ):
        r"""
        Run the Imagic optimization: first optimize the text embedding of `prompt` to reconstruct
        `image`, then fine-tune the diffusion model with that embedding. Must be called before
        `__call__`.

        Args:
            prompt (`str` or `List[str]`):
                The target prompt describing the desired edit.
            image (`torch.Tensor` or `PIL.Image.Image`):
                The image to be edited.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            embedding_learning_rate (`float`, *optional*, defaults to 0.001):
                Learning rate for the text-embedding optimization stage.
            diffusion_model_learning_rate (`float`, *optional*, defaults to 2e-6):
                Learning rate for the model fine-tuning stage.
            text_embedding_optimization_steps (`int`, *optional*, defaults to 500):
                Number of optimization steps for the text embedding.
            model_fine_tuning_optimization_steps (`int`, *optional*, defaults to 1000):
                Number of optimization steps for fine-tuning the diffusion model.

        Returns:
            `None`. The original and optimized text embeddings are stored on the pipeline for use in
            `__call__`.
        """
					
					
						
        accelerator = Accelerator(
            gradient_accumulation_steps=1,
            mixed_precision="fp16",
        )

        if "torch_device" in kwargs:
            device = kwargs.pop("torch_device")
            warnings.warn(
                "`torch_device` is deprecated as an input argument to `train` and will be removed in v0.3.0."
                " Consider using `pipe.to(torch_device)` instead."
            )

            if device is None:
                device = "cuda" if torch.cuda.is_available() else "cpu"
            self.to(device)

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # Freeze vae, unet, and the text encoder; stage one optimizes only the text embedding.
        self.vae.requires_grad_(False)
        self.unet.requires_grad_(False)
        self.text_encoder.requires_grad_(False)
        self.unet.eval()
        self.vae.eval()
        self.text_encoder.eval()

        if accelerator.is_main_process:
            accelerator.init_trackers(
                "imagic",
                config={
                    "embedding_learning_rate": embedding_learning_rate,
                    "text_embedding_optimization_steps": text_embedding_optimization_steps,
                },
            )

        # Get text embeddings for the prompt.
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = torch.nn.Parameter(
            self.text_encoder(text_input.input_ids.to(self.device))[0], requires_grad=True
        )
        text_embeddings = text_embeddings.detach()
        text_embeddings.requires_grad_()
        text_embeddings_orig = text_embeddings.clone()

        # Initialize the optimizer.
        optimizer = torch.optim.Adam(
            [text_embeddings],  # only optimize the embeddings
            lr=embedding_learning_rate,
        )

        if isinstance(image, PIL.Image.Image):
            image = preprocess(image)

        latents_dtype = text_embeddings.dtype
        image = image.to(device=self.device, dtype=latents_dtype)
        init_latent_image_dist = self.vae.encode(image).latent_dist
        image_latents = init_latent_image_dist.sample(generator=generator)
        image_latents = 0.18215 * image_latents  # Stable Diffusion's VAE latent scaling factor

        progress_bar = tqdm(range(text_embedding_optimization_steps), disable=not accelerator.is_local_main_process)
        progress_bar.set_description("Steps")

        global_step = 0

        logger.info("First optimizing the text embedding to better reconstruct the init image")
        for _ in range(text_embedding_optimization_steps):
            with accelerator.accumulate(text_embeddings):
                # Sample noise that we'll add to the latents
                noise = torch.randn(image_latents.shape).to(image_latents.device)
                timesteps = torch.randint(1000, (1,), device=image_latents.device)

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = self.scheduler.add_noise(image_latents, noise, timesteps)

                # Predict the noise residual
                noise_pred = self.unet(noisy_latents, timesteps, text_embeddings).sample

                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
                accelerator.backward(loss)

                optimizer.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

            logs = {"loss": loss.detach().item()}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

        accelerator.wait_for_everyone()

        text_embeddings.requires_grad_(False)

        # Now fine-tune the unet to better reconstruct the image
        self.unet.requires_grad_(True)
        self.unet.train()
        optimizer = torch.optim.Adam(
            self.unet.parameters(),  # only optimize unet
            lr=diffusion_model_learning_rate,
        )
        progress_bar = tqdm(range(model_fine_tuning_optimization_steps), disable=not accelerator.is_local_main_process)

        logger.info("Next fine tuning the entire model to better reconstruct the init image")
        for _ in range(model_fine_tuning_optimization_steps):
            with accelerator.accumulate(self.unet.parameters()):
                # Sample noise that we'll add to the latents
                noise = torch.randn(image_latents.shape).to(image_latents.device)
                timesteps = torch.randint(1000, (1,), device=image_latents.device)

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = self.scheduler.add_noise(image_latents, noise, timesteps)

                # Predict the noise residual
                noise_pred = self.unet(noisy_latents, timesteps, text_embeddings).sample

                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
                accelerator.backward(loss)

                optimizer.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

            logs = {"loss": loss.detach().item()}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

        accelerator.wait_for_everyone()
        self.text_embeddings_orig = text_embeddings_orig
        self.text_embeddings = text_embeddings

    @torch.no_grad()
    def __call__(
        self,
        alpha: float = 1.2,
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        num_inference_steps: Optional[int] = 50,
        generator: Optional[torch.Generator] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        guidance_scale: float = 7.5,
        eta: float = 0.0,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            alpha (`float`, *optional*, defaults to 1.2):
                The interpolation factor between the original and optimized text embeddings. A value closer to 0
                will resemble the original input image.
            height (`int`, *optional*, defaults to 512):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to 512):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """
					
					
						
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
        if self.text_embeddings is None:
            raise ValueError("Please run the pipe.train() before trying to generate an image.")
        if self.text_embeddings_orig is None:
            raise ValueError("Please run the pipe.train() before trying to generate an image.")

        # Interpolate (or, for alpha > 1, extrapolate) between the original target-prompt
        # embedding and the optimized embedding that reconstructs the input image.
        text_embeddings = alpha * self.text_embeddings_orig + (1 - alpha) * self.text_embeddings

        # Here `guidance_scale` is defined analog to the classifier free guidance of Imagen
        # https://arxiv.org/abs/2205.11487 . `guidance_scale = 1` corresponds to doing
        # no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens = [""]
            max_length = self.tokenizer.model_max_length
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

            # duplicate unconditional embeddings for each generation per prompt
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.view(1, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes.
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # get the initial random noise

        # Unlike in other pipelines, latents need to be generated in the target device
        # for 1-to-1 results reproducibility with the CompVis implementation.
        # However this currently doesn't work in `mps`.
        latents_shape = (1, self.unet.config.in_channels, height // 8, width // 8)
        latents_dtype = text_embeddings.dtype
        if self.device.type == "mps":
            # draw the noise on CPU and move it over, since generators are not usable on mps
            latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                self.device
            )
        else:
            latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)

        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)

        # Some schedulers like PNDM have timesteps as arrays.
        # It's more optimized to move all timesteps to the correct device beforehand.
        timesteps_tensor = self.scheduler.timesteps.to(self.device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature.
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in the DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

        # undo the VAE latent scaling and decode back to pixel space
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        # run safety checker
        if self.safety_checker is not None:
            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
                self.device
            )
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
            )
        else:
            has_nsfw_concept = None

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)