from typing import Any, Callable, Dict, List, Optional, Union

import PIL
import torch
from PIL import Image, ImageFilter, ImageOps

from diffusers.pipelines.stable_diffusion import (
    StableDiffusionInpaintPipeline,
    StableDiffusionPipelineOutput,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image


def fill_images_masks(images: Union[list, PIL.Image.Image], masks: Union[list, PIL.Image.Image]):
    """Apply `fill` to each image/mask pair, promoting single images to one-element lists."""
    new_images = []

    if isinstance(images, PIL.Image.Image):
        if not isinstance(masks, PIL.Image.Image):
            raise TypeError(f"`image` is a PIL.Image.Image but `mask` (type: {type(masks)}) is not")
        images = [images]
        masks = [masks]

    if isinstance(images, list):
        if not isinstance(masks, list):
            raise TypeError(f"`image` is a list but `mask` (type: {type(masks)}) is not")

        for image, mask in zip(images, masks):
            filled_image = fill(image, mask)
            new_images.append(filled_image)
    else:
        raise TypeError(f"`image` is neither a PIL.Image.Image nor a list but {type(images)}")

    return new_images, masks


def fill(image: PIL.Image.Image, mask: PIL.Image.Image):
    """Fill masked regions with colors drawn from `image` using a cascade of Gaussian blurs.

    A cheap approximation that gives the sampler plausible starting colors; not extremely
    effective on its own.
    """
    image_mod = Image.new("RGBA", (image.width, image.height))

    # Keep only the unmasked pixels: the mask is inverted, so white (repaint) areas become
    # transparent. Premultiplied alpha ("RGBa") keeps transparent pixels from darkening the blurs.
    image_masked = Image.new("RGBa", (image.width, image.height))
    image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert("L")))

    # Composite blurs from coarse to fine: large radii drag distant colors into the hole,
    # small radii refine the result near the mask boundary.
    for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
        blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert("RGBA")
        for _ in range(repeats):
            image_mod.alpha_composite(blurred)

    return image_mod.convert("RGB")
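
# Illustrative sketch of `fill` on a synthetic image (the names and sizes below are
# invented for the example, not taken from the pipeline):
#
#   from PIL import Image
#   img = Image.new("RGB", (512, 512), "red")
#   msk = Image.new("L", (512, 512), 0)
#   msk.paste(255, (128, 128, 384, 384))  # white square = region to repaint
#   filled = fill(img, msk)               # the square is filled with blurred red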


class StableDiffusionFillInpaintPipeline(StableDiffusionInpaintPipeline):
    """Inpainting pipeline that pre-fills the masked region with blurred colors from the
    source image before denoising, then composites the untouched original pixels back
    outside the mask afterwards."""

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
        mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        strength: float = 1.0,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            image (`PIL.Image.Image`):
                `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
                be masked out with `mask_image` and repainted according to `prompt`.
            mask_image (`PIL.Image.Image`):
                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
                repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
                to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
                instead of 3, so the expected shape would be `(B, H, W, 1)`.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            strength (`float`, *optional*, defaults to 1.0):
                Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
                between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
                `strength`. The number of denoising steps depends on the amount of noise initially added. When
                `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
                iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
                portion of the reference `image`.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
                text `prompt`, usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt`
                input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).

        Examples:

        ```py
        >>> import PIL
        >>> import requests
        >>> import torch
        >>> from io import BytesIO


        >>> def download_image(url):
        ...     response = requests.get(url)
        ...     return PIL.Image.open(BytesIO(response.content)).convert("RGB")


        >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
        >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

        >>> init_image = download_image(img_url).resize((512, 512))
        >>> mask_image = download_image(mask_url).resize((512, 512))

        >>> pipe = StableDiffusionFillInpaintPipeline.from_pretrained(
        ...     "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
        ... )
        >>> pipe = pipe.to("cuda")

        >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
        >>> image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
        ```

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """

        # 0. Default height and width to the unet's sample size
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs
        self.check_inputs(
            prompt,
            height,
            width,
            strength,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # Classifier-free guidance is enabled when guidance_scale > 1
        # (corresponds to `w` of equation 2 of the Imagen paper).
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )

        # 4. Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps, num_inference_steps = self.get_timesteps(
            num_inference_steps=num_inference_steps, strength=strength, device=device
        )

        if num_inference_steps < 1:
            raise ValueError(
                f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
                f" steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
            )

        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # At strength == 1.0 the latents are initialised with pure noise
        is_strength_max = strength == 1.0

        # Keep the unmodified inputs around so the original pixels can be composited back
        # outside the mask after generation.
        if isinstance(image, PIL.Image.Image):
            original_image, original_mask = [image], [mask_image]
        else:
            original_image, original_mask = list(image), list(mask_image)

        # Pre-fill the masked regions with blurred colors from the source image.
        image, mask_image = fill_images_masks(image, mask_image)

        # 5. Prepare mask, masked image, and the initial latents
        mask, masked_image, init_image = prepare_mask_and_masked_image(
            image, mask_image, height, width, return_image=True
        )

        num_channels_latents = self.vae.config.latent_channels
        num_channels_unet = self.unet.config.in_channels
        return_image_latents = num_channels_unet == 4

        latents_outputs = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
            image=init_image,
            timestep=latent_timestep,
            is_strength_max=is_strength_max,
            return_noise=True,
            return_image_latents=return_image_latents,
        )

        if return_image_latents:
            latents, noise, image_latents = latents_outputs
        else:
            latents, noise = latents_outputs

        # 6. Prepare mask latents
        mask, masked_image_latents = self.prepare_mask_latents(
            mask,
            masked_image,
            batch_size * num_images_per_prompt,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            do_classifier_free_guidance,
        )
        init_image = init_image.to(device=device, dtype=masked_image_latents.dtype)
        init_image = self._encode_vae_image(init_image, generator=generator)

        # 7. Check that the sizes of mask, masked image and latents match the unet
        if num_channels_unet == 9:
            num_channels_mask = mask.shape[1]
            num_channels_masked_image = masked_image_latents.shape[1]
            if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
                raise ValueError(
                    f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
                    f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                    f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
                    f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
                    " `pipeline.unet` or your `mask_image` or `image` input."
                )
        elif num_channels_unet != 4:
            raise ValueError(
                f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
            )

        # 8. Prepare extra step kwargs
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 9. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier-free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                if num_channels_unet == 9:
                    latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if num_channels_unet == 4:
                    # for 4-channel unets, re-noise the unmasked region from the original
                    # image latents so only the masked region is denoised freely
                    init_latents_proper = image_latents[:1]
                    init_mask = mask[:1]

                    if i < len(timesteps) - 1:
                        noise_timestep = timesteps[i + 1]
                        init_latents_proper = self.scheduler.add_noise(
                            init_latents_proper, noise, torch.tensor([noise_timestep])
                        )

                    latents = (1 - init_mask) * init_latents_proper + init_mask * latents

                # update the progress bar and call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # 10. Post-processing
        if output_type != "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        if output_type == "pil":
            # Composite the untouched original pixels back outside the mask: white mask
            # pixels keep the generated content, black pixels restore the original.
            image = [
                Image.composite(generated, original, mask_img.convert("L"))
                for generated, original, mask_img in zip(image, original_image, original_mask)
            ]

        # Offload the last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
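

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the pipeline itself. It assumes a CUDA
    # device, network access to the Hugging Face Hub, and the standard inpainting
    # checkpoint from the docstring example; swap in your own model or device as needed.
    from diffusers.utils import load_image

    pipe = StableDiffusionFillInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
    ).to("cuda")

    init_image = load_image(
        "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
    ).resize((512, 512))
    mask_image = load_image(
        "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
    ).resize((512, 512))

    result = pipe(
        prompt="Face of a yellow cat, high resolution, sitting on a park bench",
        image=init_image,
        mask_image=mask_image,
    ).images[0]
    result.save("fill_inpaint_result.png")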