import inspect |
|
import os |
|
import shutil |
|
from glob import glob |
|
from typing import Any, Callable, Dict, List, Optional, Union |
|
|
|
import cv2 |
|
import numpy as np |
|
import PIL.Image |
|
import requests |
|
import torch |
|
from detectron2.config import get_cfg |
|
from detectron2.data import MetadataCatalog |
|
from detectron2.engine import DefaultPredictor |
|
from detectron2.projects import point_rend |
|
from detectron2.structures.instances import Instances |
|
from detectron2.utils.visualizer import ColorMode, Visualizer |
|
from packaging import version |
|
from tqdm import tqdm |
|
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer |
|
|
|
from diffusers.configuration_utils import FrozenDict |
|
from diffusers.image_processor import VaeImageProcessor |
|
from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin |
|
from diffusers.models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel |
|
from diffusers.pipelines.pipeline_utils import DiffusionPipeline |
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput |
|
from diffusers.schedulers import KarrasDiffusionSchedulers |
|
from diffusers.utils import ( |
|
deprecate, |
|
is_accelerate_available, |
|
is_accelerate_version, |
|
logging, |
|
randn_tensor, |
|
) |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
AMI_INSTALL_MESSAGE = """ |
|
|
|
Example Demo of Adaptive Mask Inpainting |
|
|
|
Beyond the Contact: Discovering Comprehensive Affordance for 3D Objects from Pre-trained 2D Diffusion Models |
|
Kim et al. |
|
ECCV-2024 (Oral) |
|
|
|
|
|
Please prepare the environment via |
|
|
|
``` |
|
conda create --name ami python=3.9 -y |
|
conda activate ami |
|
|
|
conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge -y |
|
python -m pip install detectron2==0.6 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html |
|
pip install easydict |
|
pip install diffusers==0.20.2 accelerate safetensors transformers |
|
pip install setuptools==59.5.0 |
|
pip install opencv-python |
|
pip install numpy==1.24.1 |
|
``` |
|
|
|
|
|
Put this code inside the root of the diffusers library (e.g., as '/home/username/diffusers/adaptive_mask_inpainting_example.py') and run it with Python.
|
|
|
|
|
|
|
|
|
""" |
|
|
|
|
|
EXAMPLE_DOC_STRING = """ |
|
Examples: |
|
```py |
|
>>> # !pip install transformers accelerate |
|
>>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler |
|
>>> from diffusers.utils import load_image |
|
>>> import numpy as np |
|
>>> import torch |
|
|
|
>>> init_image = load_image( |
|
... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png" |
|
... ) |
|
>>> init_image = init_image.resize((512, 512)) |
|
|
|
>>> generator = torch.Generator(device="cpu").manual_seed(1) |
|
|
|
>>> mask_image = load_image( |
|
... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png" |
|
... ) |
|
>>> mask_image = mask_image.resize((512, 512)) |
|
|
|
|
|
>>> def make_inpaint_condition(image, image_mask): |
|
... image = np.array(image.convert("RGB")).astype(np.float32) / 255.0 |
|
... image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0 |
|
|
|
        ...     assert image.shape[0:2] == image_mask.shape[0:2], "image and image_mask must have the same image size"
|
... image[image_mask > 0.5] = -1.0 # set as masked pixel |
|
... image = np.expand_dims(image, 0).transpose(0, 3, 1, 2) |
|
... image = torch.from_numpy(image) |
|
... return image |
|
|
|
|
|
>>> control_image = make_inpaint_condition(init_image, mask_image) |
|
|
|
>>> controlnet = ControlNetModel.from_pretrained( |
|
... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16 |
|
... ) |
|
>>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( |
|
... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 |
|
... ) |
|
|
|
>>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) |
|
>>> pipe.enable_model_cpu_offload() |
|
|
|
>>> # generate image |
|
>>> image = pipe( |
|
... "a handsome man with ray-ban sunglasses", |
|
... num_inference_steps=20, |
|
... generator=generator, |
|
... eta=1.0, |
|
... image=init_image, |
|
... mask_image=mask_image, |
|
... control_image=control_image, |
|
... ).images[0] |
|
``` |
|
""" |
|
|
|
|
|
def download_file(url, output_file, exist_ok: bool): |
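    # Stream the download to disk in 8 KB chunks; skip the request entirely when the target
    # file already exists and `exist_ok` is set.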
|
if exist_ok and os.path.exists(output_file): |
|
return |
|
|
|
response = requests.get(url, stream=True) |
|
|
|
with open(output_file, "wb") as file: |
|
for chunk in tqdm(response.iter_content(chunk_size=8192), desc=f"Downloading '{output_file}'..."): |
|
if chunk: |
|
file.write(chunk) |
|
|
|
|
|
def generate_video_from_imgs(images_save_directory, fps=15.0, delete_dir=True): |
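    # Collect the .png frames saved in `images_save_directory` into an .mp4: frames are first
    # written with OpenCV's mp4v codec, then re-encoded to H.264 via the `ffmpeg` CLI (assumed
    # to be available on PATH) for wider playback compatibility.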
|
|
|
if os.path.exists(f"{images_save_directory}.mp4"): |
|
os.remove(f"{images_save_directory}.mp4") |
|
if os.path.exists(f"{images_save_directory}_before_process.mp4"): |
|
os.remove(f"{images_save_directory}_before_process.mp4") |
|
|
|
|
|
assert os.path.isdir(images_save_directory) |
|
ImgPaths = sorted(glob(f"{images_save_directory}/*")) |
|
|
|
if len(ImgPaths) == 0: |
|
        print("\tSkipping: at least one image is required to create an mp4\n")
|
else: |
|
|
|
video_path = images_save_directory + "_before_process.mp4" |
|
|
|
|
|
images = sorted([ImgPath.split("/")[-1] for ImgPath in ImgPaths if ImgPath.endswith(".png")]) |
|
frame = cv2.imread(os.path.join(images_save_directory, images[0])) |
|
height, width, channels = frame.shape |
|
|
|
|
|
fourcc = cv2.VideoWriter_fourcc(*"mp4v") |
|
video = cv2.VideoWriter(video_path, fourcc, fps, (width, height)) |
|
for image in images: |
|
video.write(cv2.imread(os.path.join(images_save_directory, image))) |
|
cv2.destroyAllWindows() |
|
video.release() |
|
|
|
|
|
os.system( |
|
f'ffmpeg -i "{images_save_directory}_before_process.mp4" -vcodec libx264 -f mp4 "{images_save_directory}.mp4" ' |
|
) |
|
|
|
|
|
if delete_dir and os.path.exists(images_save_directory): |
|
shutil.rmtree(images_save_directory) |
|
|
|
if os.path.exists(f"{images_save_directory}_before_process.mp4"): |
|
os.remove(f"{images_save_directory}_before_process.mp4") |
|
|
|
|
|
|
|
def prepare_mask_and_masked_image(image, mask, height, width, return_image=False): |
|
""" |
|
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be |
|
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the |
|
``image`` and ``1`` for the ``mask``. |
|
|
|
The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be |
|
binarized (``mask > 0.5``) and cast to ``torch.float32`` too. |
|
|
|
Args: |
|
image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. |
|
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` |
|
``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. |
|
        mask (Union[np.array, PIL.Image, torch.Tensor]): The mask to apply to the image, i.e. regions to inpaint.
|
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` |
|
``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. |
|
|
|
|
|
Raises: |
|
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask |
|
should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. |
|
TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not |
|
            (or the other way around).
|
|
|
Returns: |
|
tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 |
|
dimensions: ``batch x channels x height x width``. |
|
""" |
|
|
|
if image is None: |
|
raise ValueError("`image` input cannot be undefined.") |
|
|
|
if mask is None: |
|
raise ValueError("`mask_image` input cannot be undefined.") |
|
|
|
if isinstance(image, torch.Tensor): |
|
if not isinstance(mask, torch.Tensor): |
|
            raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)}) is not")
|
|
|
|
|
if image.ndim == 3: |
|
assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" |
|
image = image.unsqueeze(0) |
|
|
|
|
|
if mask.ndim == 2: |
|
mask = mask.unsqueeze(0).unsqueeze(0) |
|
|
|
|
|
if mask.ndim == 3: |
|
|
|
if mask.shape[0] == 1: |
|
mask = mask.unsqueeze(0) |
|
|
|
|
|
else: |
|
mask = mask.unsqueeze(1) |
|
|
|
assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" |
|
assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" |
|
assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" |
|
|
|
|
|
if image.min() < -1 or image.max() > 1: |
|
raise ValueError("Image should be in [-1, 1] range") |
|
|
|
|
|
if mask.min() < 0 or mask.max() > 1: |
|
raise ValueError("Mask should be in [0, 1] range") |
|
|
|
|
|
mask[mask < 0.5] = 0 |
|
mask[mask >= 0.5] = 1 |
|
|
|
|
|
image = image.to(dtype=torch.float32) |
|
elif isinstance(mask, torch.Tensor): |
|
        raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)}) is not")
|
else: |
|
|
|
if isinstance(image, (PIL.Image.Image, np.ndarray)): |
|
image = [image] |
|
if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): |
|
|
|
image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image] |
|
image = [np.array(i.convert("RGB"))[None, :] for i in image] |
|
image = np.concatenate(image, axis=0) |
|
elif isinstance(image, list) and isinstance(image[0], np.ndarray): |
|
image = np.concatenate([i[None, :] for i in image], axis=0) |
|
|
|
image = image.transpose(0, 3, 1, 2) |
|
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 |
|
|
|
|
|
if isinstance(mask, (PIL.Image.Image, np.ndarray)): |
|
mask = [mask] |
|
|
|
if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): |
|
mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] |
|
mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) |
|
mask = mask.astype(np.float32) / 255.0 |
|
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): |
|
mask = np.concatenate([m[None, None, :] for m in mask], axis=0) |
|
|
|
mask[mask < 0.5] = 0 |
|
mask[mask >= 0.5] = 1 |
|
mask = torch.from_numpy(mask) |
|
|
|
masked_image = image * (mask < 0.5) |
|
|
|
|
|
if return_image: |
|
return mask, masked_image, image |
|
|
|
return mask, masked_image |
|
|
|
|
|
class AdaptiveMaskInpaintPipeline( |
|
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin |
|
): |
|
r""" |
|
Pipeline for text-guided image inpainting using Stable Diffusion. |
|
|
|
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods |
|
implemented for all pipelines (downloading, saving, running on a particular device, etc.). |
|
|
|
The pipeline also inherits the following loading methods: |
|
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings |
|
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights |
|
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights |
|
|
|
Args: |
|
vae ([`AutoencoderKL`, `AsymmetricAutoencoderKL`]): |
|
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. |
|
text_encoder ([`CLIPTextModel`]): |
|
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). |
|
tokenizer ([`~transformers.CLIPTokenizer`]): |
|
A `CLIPTokenizer` to tokenize text. |
|
unet ([`UNet2DConditionModel`]): |
|
A `UNet2DConditionModel` to denoise the encoded image latents. |
|
scheduler ([`SchedulerMixin`]): |
|
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of |
|
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. |
|
safety_checker ([`StableDiffusionSafetyChecker`]): |
|
Classification module that estimates whether generated images could be considered offensive or harmful. |
|
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details |
|
about a model's potential harms. |
|
feature_extractor ([`~transformers.CLIPImageProcessor`]): |
|
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. |
|
""" |
|
|
|
_optional_components = ["safety_checker", "feature_extractor"] |
|
|
|
def __init__( |
|
self, |
|
vae: Union[AutoencoderKL, AsymmetricAutoencoderKL], |
|
text_encoder: CLIPTextModel, |
|
tokenizer: CLIPTokenizer, |
|
unet: UNet2DConditionModel, |
|
scheduler: KarrasDiffusionSchedulers, |
|
|
|
safety_checker, |
|
feature_extractor: CLIPImageProcessor, |
|
requires_safety_checker: bool = True, |
|
): |
|
super().__init__() |
|
|
|
self.register_adaptive_mask_model() |
|
self.register_adaptive_mask_settings() |
|
|
|
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: |
|
deprecation_message = ( |
|
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" |
|
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " |
|
"to update the config accordingly as leaving `steps_offset` might led to incorrect results" |
|
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," |
|
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" |
|
" file" |
|
) |
|
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) |
|
new_config = dict(scheduler.config) |
|
new_config["steps_offset"] = 1 |
|
scheduler._internal_dict = FrozenDict(new_config) |
|
|
|
if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False: |
|
deprecation_message = ( |
|
f"The configuration file of this scheduler: {scheduler} has not set the configuration" |
|
" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make" |
|
" sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to" |
|
" incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face" |
|
" Hub, it would be very nice if you could open a Pull request for the" |
|
" `scheduler/scheduler_config.json` file" |
|
) |
|
deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False) |
|
new_config = dict(scheduler.config) |
|
new_config["skip_prk_steps"] = True |
|
scheduler._internal_dict = FrozenDict(new_config) |
|
|
|
if safety_checker is None and requires_safety_checker: |
|
logger.warning( |
|
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" |
|
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" |
|
" results in services or applications open to the public. Both the diffusers team and Hugging Face" |
|
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" |
|
" it only for use-cases that involve analyzing network behavior or auditing its results. For more" |
|
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." |
|
) |
|
|
|
if safety_checker is not None and feature_extractor is None: |
|
raise ValueError( |
|
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" |
|
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." |
|
) |
|
|
|
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( |
|
version.parse(unet.config._diffusers_version).base_version |
|
) < version.parse("0.9.0.dev0") |
|
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 |
|
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: |
|
deprecation_message = ( |
|
"The configuration file of the unet has set the default `sample_size` to smaller than" |
|
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" |
|
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" |
|
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" |
|
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" |
|
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" |
|
" in the config might lead to incorrect results in future versions. If you have downloaded this" |
|
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" |
|
" the `unet/config.json` file" |
|
) |
|
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) |
|
new_config = dict(unet.config) |
|
new_config["sample_size"] = 64 |
|
unet._internal_dict = FrozenDict(new_config) |
|
|
|
|
|
if unet.config.in_channels != 9: |
|
logger.info(f"You have loaded a UNet with {unet.config.in_channels} input channels which.") |
|
|
|
self.register_modules( |
|
vae=vae, |
|
text_encoder=text_encoder, |
|
tokenizer=tokenizer, |
|
unet=unet, |
|
scheduler=scheduler, |
|
safety_checker=safety_checker, |
|
feature_extractor=feature_extractor, |
|
) |
|
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) |
|
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) |
|
self.register_to_config(requires_safety_checker=requires_safety_checker) |
|
|
|
""" Preparation for Adaptive Mask inpainting """ |
|
|
|
|
|
def enable_model_cpu_offload(self, gpu_id=0): |
|
r""" |
|
Offload all models to CPU to reduce memory usage with a low impact on performance. Moves one whole model at a |
|
time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. |
|
Memory savings are lower than using `enable_sequential_cpu_offload`, but performance is much better due to the |
|
iterative execution of the `unet`. |
|
""" |
|
if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): |
|
from accelerate import cpu_offload_with_hook |
|
else: |
|
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") |
|
|
|
device = torch.device(f"cuda:{gpu_id}") |
|
|
|
if self.device.type != "cpu": |
|
self.to("cpu", silence_dtype_warnings=True) |
|
torch.cuda.empty_cache() |
|
|
|
hook = None |
|
for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: |
|
_, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) |
|
|
|
if self.safety_checker is not None: |
|
_, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) |
|
|
|
|
|
self.final_offload_hook = hook |
|
|
|
|
|
def _encode_prompt( |
|
self, |
|
prompt, |
|
device, |
|
num_images_per_prompt, |
|
do_classifier_free_guidance, |
|
negative_prompt=None, |
|
prompt_embeds: Optional[torch.FloatTensor] = None, |
|
negative_prompt_embeds: Optional[torch.FloatTensor] = None, |
|
lora_scale: Optional[float] = None, |
|
): |
|
r""" |
|
Encodes the prompt into text encoder hidden states. |
|
|
|
Args: |
|
prompt (`str` or `List[str]`, *optional*): |
|
prompt to be encoded |
|
device: (`torch.device`): |
|
torch device |
|
num_images_per_prompt (`int`): |
|
number of images that should be generated per prompt |
|
do_classifier_free_guidance (`bool`): |
|
whether to use classifier free guidance or not |
|
negative_prompt (`str` or `List[str]`, *optional*): |
|
The prompt or prompts not to guide the image generation. If not defined, one has to pass |
|
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is |
|
less than `1`). |
|
prompt_embeds (`torch.FloatTensor`, *optional*): |
|
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not |
|
provided, text embeddings will be generated from `prompt` input argument. |
|
negative_prompt_embeds (`torch.FloatTensor`, *optional*): |
|
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt |
|
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input |
|
argument. |
|
lora_scale (`float`, *optional*): |
|
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. |
|
""" |
|
|
|
|
|
if lora_scale is not None and isinstance(self, LoraLoaderMixin): |
|
self._lora_scale = lora_scale |
|
|
|
if prompt is not None and isinstance(prompt, str): |
|
batch_size = 1 |
|
elif prompt is not None and isinstance(prompt, list): |
|
batch_size = len(prompt) |
|
else: |
|
batch_size = prompt_embeds.shape[0] |
|
|
|
if prompt_embeds is None: |
|
|
|
if isinstance(self, TextualInversionLoaderMixin): |
|
prompt = self.maybe_convert_prompt(prompt, self.tokenizer) |
|
|
|
text_inputs = self.tokenizer( |
|
prompt, |
|
padding="max_length", |
|
max_length=self.tokenizer.model_max_length, |
|
truncation=True, |
|
return_tensors="pt", |
|
) |
|
text_input_ids = text_inputs.input_ids |
|
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids |
|
|
|
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( |
|
text_input_ids, untruncated_ids |
|
): |
|
removed_text = self.tokenizer.batch_decode( |
|
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] |
|
) |
|
logger.warning( |
|
"The following part of your input was truncated because CLIP can only handle sequences up to" |
|
f" {self.tokenizer.model_max_length} tokens: {removed_text}" |
|
) |
|
|
|
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: |
|
attention_mask = text_inputs.attention_mask.to(device) |
|
else: |
|
attention_mask = None |
|
|
|
prompt_embeds = self.text_encoder( |
|
text_input_ids.to(device), |
|
attention_mask=attention_mask, |
|
) |
|
prompt_embeds = prompt_embeds[0] |
|
|
|
if self.text_encoder is not None: |
|
prompt_embeds_dtype = self.text_encoder.dtype |
|
elif self.unet is not None: |
|
prompt_embeds_dtype = self.unet.dtype |
|
else: |
|
prompt_embeds_dtype = prompt_embeds.dtype |
|
|
|
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) |
|
|
|
bs_embed, seq_len, _ = prompt_embeds.shape |
|
|
|
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) |
|
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) |
|
|
|
|
|
if do_classifier_free_guidance and negative_prompt_embeds is None: |
|
uncond_tokens: List[str] |
|
if negative_prompt is None: |
|
uncond_tokens = [""] * batch_size |
|
elif prompt is not None and type(prompt) is not type(negative_prompt): |
|
raise TypeError( |
|
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" |
|
f" {type(prompt)}." |
|
) |
|
elif isinstance(negative_prompt, str): |
|
uncond_tokens = [negative_prompt] |
|
elif batch_size != len(negative_prompt): |
|
raise ValueError( |
|
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" |
|
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" |
|
" the batch size of `prompt`." |
|
) |
|
else: |
|
uncond_tokens = negative_prompt |
|
|
|
|
|
if isinstance(self, TextualInversionLoaderMixin): |
|
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) |
|
|
|
max_length = prompt_embeds.shape[1] |
|
uncond_input = self.tokenizer( |
|
uncond_tokens, |
|
padding="max_length", |
|
max_length=max_length, |
|
truncation=True, |
|
return_tensors="pt", |
|
) |
|
|
|
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: |
|
attention_mask = uncond_input.attention_mask.to(device) |
|
else: |
|
attention_mask = None |
|
|
|
negative_prompt_embeds = self.text_encoder( |
|
uncond_input.input_ids.to(device), |
|
attention_mask=attention_mask, |
|
) |
|
negative_prompt_embeds = negative_prompt_embeds[0] |
|
|
|
if do_classifier_free_guidance: |
|
|
|
seq_len = negative_prompt_embeds.shape[1] |
|
|
|
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) |
|
|
|
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) |
|
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) |
|
|
|
|
|
|
|
|
|
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) |
|
|
|
return prompt_embeds |
|
|
|
|
|
def run_safety_checker(self, image, device, dtype): |
|
if self.safety_checker is None: |
|
has_nsfw_concept = None |
|
else: |
|
if torch.is_tensor(image): |
|
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") |
|
else: |
|
feature_extractor_input = self.image_processor.numpy_to_pil(image) |
|
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) |
|
image, has_nsfw_concept = self.safety_checker( |
|
images=image, clip_input=safety_checker_input.pixel_values.to(dtype) |
|
) |
|
return image, has_nsfw_concept |
|
|
|
|
|
def prepare_extra_step_kwargs(self, generator, eta): |
|
|
|
|
|
|
|
|
|
|
|
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) |
|
extra_step_kwargs = {} |
|
if accepts_eta: |
|
extra_step_kwargs["eta"] = eta |
|
|
|
|
|
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) |
|
if accepts_generator: |
|
extra_step_kwargs["generator"] = generator |
|
return extra_step_kwargs |
|
|
|
def check_inputs( |
|
self, |
|
prompt, |
|
height, |
|
width, |
|
strength, |
|
callback_steps, |
|
negative_prompt=None, |
|
prompt_embeds=None, |
|
negative_prompt_embeds=None, |
|
): |
|
if strength < 0 or strength > 1: |
|
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") |
|
|
|
if height % 8 != 0 or width % 8 != 0: |
|
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") |
|
|
|
if (callback_steps is None) or ( |
|
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) |
|
): |
|
raise ValueError( |
|
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" |
|
f" {type(callback_steps)}." |
|
) |
|
|
|
if prompt is not None and prompt_embeds is not None: |
|
raise ValueError( |
|
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" |
|
" only forward one of the two." |
|
) |
|
elif prompt is None and prompt_embeds is None: |
|
raise ValueError( |
|
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." |
|
) |
|
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): |
|
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") |
|
|
|
if negative_prompt is not None and negative_prompt_embeds is not None: |
|
raise ValueError( |
|
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" |
|
f" {negative_prompt_embeds}. Please make sure to only forward one of the two." |
|
) |
|
|
|
if prompt_embeds is not None and negative_prompt_embeds is not None: |
|
if prompt_embeds.shape != negative_prompt_embeds.shape: |
|
raise ValueError( |
|
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" |
|
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" |
|
f" {negative_prompt_embeds.shape}." |
|
) |
|
|
|
def prepare_latents( |
|
self, |
|
batch_size, |
|
num_channels_latents, |
|
height, |
|
width, |
|
dtype, |
|
device, |
|
generator, |
|
latents=None, |
|
image=None, |
|
timestep=None, |
|
is_strength_max=True, |
|
return_noise=False, |
|
return_image_latents=False, |
|
): |
|
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) |
|
if isinstance(generator, list) and len(generator) != batch_size: |
|
raise ValueError( |
|
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" |
|
f" size of {batch_size}. Make sure the batch size matches the length of the generators." |
|
) |
|
|
|
if (image is None or timestep is None) and not is_strength_max: |
|
raise ValueError( |
|
"Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." |
|
"However, either the image or the noise timestep has not been provided." |
|
) |
|
|
|
if return_image_latents or (latents is None and not is_strength_max): |
|
image = image.to(device=device, dtype=dtype) |
|
image_latents = self._encode_vae_image(image=image, generator=generator) |
|
|
|
if latents is None: |
|
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) |
|
|
|
latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) |
|
|
|
latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents |
|
else: |
|
noise = latents.to(device) |
|
latents = noise * self.scheduler.init_noise_sigma |
|
|
|
outputs = (latents,) |
|
|
|
if return_noise: |
|
outputs += (noise,) |
|
|
|
if return_image_latents: |
|
outputs += (image_latents,) |
|
|
|
return outputs |
|
|
|
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): |
|
if isinstance(generator, list): |
|
image_latents = [ |
|
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) |
|
for i in range(image.shape[0]) |
|
] |
|
image_latents = torch.cat(image_latents, dim=0) |
|
else: |
|
image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) |
|
|
|
image_latents = self.vae.config.scaling_factor * image_latents |
|
|
|
return image_latents |
|
|
|
def prepare_mask_latents( |
|
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance |
|
): |
|
|
|
|
|
|
|
mask = torch.nn.functional.interpolate( |
|
mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) |
|
) |
|
mask = mask.to(device=device, dtype=dtype) |
|
|
|
masked_image = masked_image.to(device=device, dtype=dtype) |
|
masked_image_latents = self._encode_vae_image(masked_image, generator=generator) |
|
|
|
|
|
if mask.shape[0] < batch_size: |
|
if not batch_size % mask.shape[0] == 0: |
|
raise ValueError( |
|
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" |
|
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" |
|
" of masks that you pass is divisible by the total requested batch size." |
|
) |
|
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) |
|
if masked_image_latents.shape[0] < batch_size: |
|
if not batch_size % masked_image_latents.shape[0] == 0: |
|
raise ValueError( |
|
"The passed images and the required batch size don't match. Images are supposed to be duplicated" |
|
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." |
|
" Make sure the number of images that you pass is divisible by the total requested batch size." |
|
) |
|
masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) |
|
|
|
mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask |
|
masked_image_latents = ( |
|
torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents |
|
) |
|
|
|
|
|
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) |
|
return mask, masked_image_latents |
|
|
|
|
|
def get_timesteps(self, num_inference_steps, strength, device): |
|
|
|
init_timestep = min(int(num_inference_steps * strength), num_inference_steps) |
|
|
|
t_start = max(num_inference_steps - init_timestep, 0) |
|
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] |
|
|
|
return timesteps, num_inference_steps - t_start |
|
|
|
@torch.no_grad() |
|
def __call__( |
|
self, |
|
prompt: Union[str, List[str]] = None, |
|
image: Union[torch.FloatTensor, PIL.Image.Image] = None, |
|
default_mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, |
|
height: Optional[int] = None, |
|
width: Optional[int] = None, |
|
strength: float = 1.0, |
|
num_inference_steps: int = 50, |
|
guidance_scale: float = 7.5, |
|
negative_prompt: Optional[Union[str, List[str]]] = None, |
|
num_images_per_prompt: Optional[int] = 1, |
|
eta: float = 0.0, |
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
|
latents: Optional[torch.FloatTensor] = None, |
|
prompt_embeds: Optional[torch.FloatTensor] = None, |
|
negative_prompt_embeds: Optional[torch.FloatTensor] = None, |
|
output_type: Optional[str] = "pil", |
|
return_dict: bool = True, |
|
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, |
|
callback_steps: int = 1, |
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None, |
|
use_adaptive_mask: bool = True, |
|
enforce_full_mask_ratio: float = 0.5, |
|
human_detection_thres: float = 0.008, |
|
visualization_save_dir: str = None, |
|
): |
|
r""" |
|
The call function to the pipeline for generation. |
|
|
|
Args: |
|
prompt (`str` or `List[str]`, *optional*): |
|
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. |
|
image (`PIL.Image.Image`): |
|
`Image` or tensor representing an image batch to be inpainted (which parts of the image to be masked |
|
out with `default_mask_image` and repainted according to `prompt`). |
|
default_mask_image (`PIL.Image.Image`): |
|
`Image` or tensor representing an image batch to mask `image`. White pixels in the mask are repainted |
|
while black pixels are preserved. If `default_mask_image` is a PIL image, it is converted to a single channel |
|
(luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the |
|
expected shape would be `(B, H, W, 1)`. |
|
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): |
|
The height in pixels of the generated image. |
|
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): |
|
The width in pixels of the generated image. |
|
strength (`float`, *optional*, defaults to 1.0): |
|
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a |
|
starting point and more noise is added the higher the `strength`. The number of denoising steps depends |
|
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising |
|
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 |
|
essentially ignores `image`. |
|
num_inference_steps (`int`, *optional*, defaults to 50): |
|
The number of denoising steps. More denoising steps usually lead to a higher quality image at the |
|
expense of slower inference. This parameter is modulated by `strength`. |
|
guidance_scale (`float`, *optional*, defaults to 7.5): |
|
A higher guidance scale value encourages the model to generate images closely linked to the text |
|
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. |
|
negative_prompt (`str` or `List[str]`, *optional*): |
|
The prompt or prompts to guide what to not include in image generation. If not defined, you need to |
|
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). |
|
num_images_per_prompt (`int`, *optional*, defaults to 1): |
|
The number of images to generate per prompt. |
|
eta (`float`, *optional*, defaults to 0.0): |
|
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies |
|
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. |
|
generator (`torch.Generator` or `List[torch.Generator]`, *optional*): |
|
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make |
|
generation deterministic. |
|
latents (`torch.FloatTensor`, *optional*): |
|
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image |
|
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents |
|
tensor is generated by sampling using the supplied random `generator`. |
|
prompt_embeds (`torch.FloatTensor`, *optional*): |
|
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not |
|
provided, text embeddings are generated from the `prompt` input argument. |
|
negative_prompt_embeds (`torch.FloatTensor`, *optional*): |
|
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If |
|
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. |
|
output_type (`str`, *optional*, defaults to `"pil"`): |
|
The output format of the generated image. Choose between `PIL.Image` or `np.array`. |
|
return_dict (`bool`, *optional*, defaults to `True`): |
|
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a |
|
plain tuple. |
|
callback (`Callable`, *optional*): |
|
A function that calls every `callback_steps` steps during inference. The function is called with the |
|
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. |
|
callback_steps (`int`, *optional*, defaults to 1): |
|
The frequency at which the `callback` function is called. If not specified, the callback is called at |
|
every step. |
|
cross_attention_kwargs (`dict`, *optional*): |
|
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in |
|
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). |
|
|
|
Examples: |
|
|
|
```py |
|
>>> import PIL |
|
>>> import requests |
|
>>> import torch |
|
>>> from io import BytesIO |
|
|
|
>>> from diffusers import AdaptiveMaskInpaintPipeline |
|
|
|
|
|
>>> def download_image(url): |
|
... response = requests.get(url) |
|
... return PIL.Image.open(BytesIO(response.content)).convert("RGB") |
|
|
|
|
|
>>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" |
|
>>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" |
|
|
|
>>> init_image = download_image(img_url).resize((512, 512)) |
|
>>> default_mask_image = download_image(mask_url).resize((512, 512)) |
|
|
|
>>> pipe = AdaptiveMaskInpaintPipeline.from_pretrained( |
|
... "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16 |
|
... ) |
|
>>> pipe = pipe.to("cuda") |
|
|
|
>>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" |
|
>>> image = pipe(prompt=prompt, image=init_image, default_mask_image=default_mask_image).images[0] |
|
``` |
|
|
|
Returns: |
|
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: |
|
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, |
|
otherwise a `tuple` is returned where the first element is a list with the generated images and the |
|
second element is a list of `bool`s indicating whether the corresponding generated image contains |
|
"not-safe-for-work" (nsfw) content. |
|
""" |
|
|
|
width, height = image.size |
|
|
|
|
|
|
|
|
|
self.check_inputs( |
|
prompt, |
|
height, |
|
width, |
|
strength, |
|
callback_steps, |
|
negative_prompt, |
|
prompt_embeds, |
|
negative_prompt_embeds, |
|
) |
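
        # Define call parameters: the effective batch size follows `prompt` (or `prompt_embeds`
        # when only embeddings are provided).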
|
|
|
|
|
if prompt is not None and isinstance(prompt, str): |
|
batch_size = 1 |
|
elif prompt is not None and isinstance(prompt, list): |
|
batch_size = len(prompt) |
|
else: |
|
batch_size = prompt_embeds.shape[0] |
|
|
|
device = self._execution_device |
|
|
|
|
|
|
|
do_classifier_free_guidance = guidance_scale > 1.0 |
|
|
|
|
|
text_encoder_lora_scale = ( |
|
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None |
|
) |
|
prompt_embeds = self._encode_prompt( |
|
prompt, |
|
device, |
|
num_images_per_prompt, |
|
do_classifier_free_guidance, |
|
negative_prompt, |
|
prompt_embeds=prompt_embeds, |
|
negative_prompt_embeds=negative_prompt_embeds, |
|
lora_scale=text_encoder_lora_scale, |
|
) |
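
        # Set the scheduler timesteps and truncate them according to `strength`, as in
        # image-to-image pipelines.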
|
|
|
|
|
self.scheduler.set_timesteps(num_inference_steps, device=device) |
|
timesteps, num_inference_steps = self.get_timesteps( |
|
num_inference_steps=num_inference_steps, strength=strength, device=device |
|
) |
|
|
|
if num_inference_steps < 1: |
|
raise ValueError( |
|
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" |
|
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." |
|
) |
|
|
|
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) |
|
|
|
is_strength_max = strength == 1.0 |
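        # strength == 1.0 means the initial latents are pure noise and `image` only enters through
        # the masked-image conditioning (and the adaptive mask updates).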
|
|
|
|
|
mask, masked_image, init_image = prepare_mask_and_masked_image( |
|
image, default_mask_image, height, width, return_image=True |
|
) |
|
default_mask_image_np = np.array(default_mask_image).astype(np.uint8) / 255 |
|
mask_condition = mask.clone() |
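
        # Prepare the initial latents: pure noise when strength == 1.0, otherwise the encoded
        # image latents noised to the starting timestep.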
|
|
|
|
|
num_channels_latents = self.vae.config.latent_channels |
|
num_channels_unet = self.unet.config.in_channels |
|
return_image_latents = num_channels_unet == 4 |
|
|
|
latents_outputs = self.prepare_latents( |
|
batch_size * num_images_per_prompt, |
|
num_channels_latents, |
|
height, |
|
width, |
|
prompt_embeds.dtype, |
|
device, |
|
generator, |
|
latents, |
|
image=init_image, |
|
timestep=latent_timestep, |
|
is_strength_max=is_strength_max, |
|
return_noise=True, |
|
return_image_latents=return_image_latents, |
|
) |
|
|
|
if return_image_latents: |
|
latents, noise, image_latents = latents_outputs |
|
else: |
|
latents, noise = latents_outputs |
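
        # Downsample the mask to latent resolution and encode the masked image with the VAE.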
|
|
|
|
|
mask, masked_image_latents = self.prepare_mask_latents( |
|
mask, |
|
masked_image, |
|
batch_size * num_images_per_prompt, |
|
height, |
|
width, |
|
prompt_embeds.dtype, |
|
device, |
|
generator, |
|
do_classifier_free_guidance, |
|
) |
|
|
|
|
|
if num_channels_unet == 9: |
|
|
|
num_channels_mask = mask.shape[1] |
|
num_channels_masked_image = masked_image_latents.shape[1] |
|
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: |
|
raise ValueError( |
|
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" |
|
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" |
|
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" |
|
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" |
|
" `pipeline.unet` or your `default_mask_image` or `image` input." |
|
) |
|
elif num_channels_unet != 4: |
|
raise ValueError( |
|
f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." |
|
) |
|
|
|
|
|
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) |
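
        # Denoising loop. At scheduled steps, the inpainting mask is re-estimated ("adapted")
        # from the current prediction of the original sample (x0).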
|
|
|
|
|
mask_image_np = default_mask_image_np |
|
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order |
|
with self.progress_bar(total=num_inference_steps) as progress_bar: |
|
for i, t in enumerate(timesteps): |
|
|
|
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents |
|
|
|
|
|
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) |
|
|
|
if num_channels_unet == 9: |
|
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) |
|
else: |
|
raise NotImplementedError |
|
|
|
|
|
noise_pred = self.unet( |
|
latent_model_input, |
|
t, |
|
encoder_hidden_states=prompt_embeds, |
|
cross_attention_kwargs=cross_attention_kwargs, |
|
return_dict=False, |
|
)[0] |
|
|
|
|
|
if do_classifier_free_guidance: |
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) |
|
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) |
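
                # Scheduler step; keep the predicted original sample (x0), which drives the
                # adaptive mask update below.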
|
|
|
|
|
outputs = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=True) |
|
latents = outputs["prev_sample"] |
|
pred_orig_latents = outputs["pred_original_sample"] |
|
|
|
|
|
if use_adaptive_mask: |
|
if enforce_full_mask_ratio > 0.0: |
|
use_default_mask = t < self.scheduler.config.num_train_timesteps * enforce_full_mask_ratio |
|
elif enforce_full_mask_ratio == 0.0: |
|
use_default_mask = False |
|
else: |
|
raise NotImplementedError |
|
|
|
pred_orig_image = self.decode_to_npuint8_image(pred_orig_latents) |
|
dilate_num = self.adaptive_mask_settings.dilate_scheduler(i) |
|
do_adapt_mask = self.adaptive_mask_settings.provoke_scheduler(i) |
|
if do_adapt_mask: |
|
mask, masked_image_latents, mask_image_np, vis_np = self.adapt_mask( |
|
init_image, |
|
pred_orig_image, |
|
default_mask_image_np, |
|
dilate_num=dilate_num, |
|
use_default_mask=use_default_mask, |
|
height=height, |
|
width=width, |
|
batch_size=batch_size, |
|
num_images_per_prompt=num_images_per_prompt, |
|
prompt_embeds=prompt_embeds, |
|
device=device, |
|
generator=generator, |
|
do_classifier_free_guidance=do_classifier_free_guidance, |
|
i=i, |
|
human_detection_thres=human_detection_thres, |
|
mask_image_np=mask_image_np, |
|
) |
|
|
|
if self.adaptive_mask_model.use_visualizer: |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
|
|
os.makedirs(visualization_save_dir, exist_ok=True) |
|
|
|
|
|
plt.axis("off") |
|
plt.subplot(1, 2, 1) |
|
plt.imshow(mask_image_np) |
|
plt.subplot(1, 2, 2) |
|
plt.imshow(pred_orig_image) |
|
plt.savefig(f"{visualization_save_dir}/{i:05}.png", bbox_inches="tight") |
|
plt.close("all") |
|
|
|
if num_channels_unet == 4: |
|
init_latents_proper = image_latents[:1] |
|
init_mask = mask[:1] |
|
|
|
if i < len(timesteps) - 1: |
|
noise_timestep = timesteps[i + 1] |
|
init_latents_proper = self.scheduler.add_noise( |
|
init_latents_proper, noise, torch.tensor([noise_timestep]) |
|
) |
|
|
|
latents = (1 - init_mask) * init_latents_proper + init_mask * latents |
|
|
|
|
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): |
|
progress_bar.update() |
|
if callback is not None and i % callback_steps == 0: |
|
callback(i, t, latents) |
|
|
|
if not output_type == "latent": |
|
condition_kwargs = {} |
|
if isinstance(self.vae, AsymmetricAutoencoderKL): |
|
init_image = init_image.to(device=device, dtype=masked_image_latents.dtype) |
|
init_image_condition = init_image.clone() |
|
init_image = self._encode_vae_image(init_image, generator=generator) |
|
mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype) |
|
condition_kwargs = {"image": init_image_condition, "mask": mask_condition} |
|
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, **condition_kwargs)[0] |
|
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) |
|
else: |
|
image = latents |
|
has_nsfw_concept = None |
|
|
|
if has_nsfw_concept is None: |
|
do_denormalize = [True] * image.shape[0] |
|
else: |
|
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] |
|
|
|
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) |
|
|
|
|
|
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: |
|
self.final_offload_hook.offload() |
|
|
|
if self.adaptive_mask_model.use_visualizer: |
|
generate_video_from_imgs(images_save_directory=visualization_save_dir, fps=10, delete_dir=True) |
|
|
|
if not return_dict: |
|
return (image, has_nsfw_concept) |
|
|
|
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) |
|
|
|
def decode_to_npuint8_image(self, latents): |
|
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
image = self.image_processor.postprocess(image, output_type="pt", do_denormalize=[True] * image.shape[0]) |
|
image = (image.squeeze().permute(1, 2, 0).detach().cpu().numpy() * 255).astype(np.uint8) |
|
return image |
|
|
|
def register_adaptive_mask_settings(self): |
|
from easydict import EasyDict |
|
|
|
num_steps = 50 |
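
        # NOTE: the dilation / provoke schedules below are tuned for (and assume) 50 inference
        # steps; `step_num` carves them into phases of decreasing dilation (20 -> 10 -> ... -> 0).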
|
|
|
step_num = int(num_steps * 0.1) |
|
final_step_num = num_steps - step_num * 7 |
|
|
|
self.adaptive_mask_settings = EasyDict( |
|
dilate_scheduler=MaskDilateScheduler( |
|
max_dilate_num=20, |
|
num_inference_steps=num_steps, |
|
schedule=[20] * step_num |
|
+ [10] * step_num |
|
+ [5] * step_num |
|
+ [4] * step_num |
|
+ [3] * step_num |
|
+ [2] * step_num |
|
+ [1] * step_num |
|
+ [0] * final_step_num, |
|
), |
|
dilate_kernel=np.ones((3, 3), dtype=np.uint8), |
|
provoke_scheduler=ProvokeScheduler( |
|
num_inference_steps=num_steps, |
|
schedule=list(range(2, 10 + 1, 2)) + list(range(12, 40 + 1, 2)) + [45], |
|
is_zero_indexing=False, |
|
), |
|
) |
|
|
|
def register_adaptive_mask_model(self): |
|
|
|
use_visualizer = True |
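
        # NOTE (assumption): the PointRend config and weights referenced below are expected to be
        # present in the working directory. A minimal sketch for fetching them with the
        # `download_file` helper defined above, assuming the standard detectron2 PointRend
        # R50-FPN 3x COCO release, could look like:
        #
        #   download_file(
        #       "https://dl.fbaipublicfiles.com/detectron2/PointRend/InstanceSegmentation/"
        #       "pointrend_rcnn_R_50_FPN_3x_coco/164955410/model_final_edd263.pkl",
        #       "model_final_edd263.pkl",
        #       exist_ok=True,
        #   )
        #
        # The matching `pointrend_rcnn_R_50_FPN_3x_coco.yaml` can be copied from the detectron2
        # repository (projects/PointRend/configs/InstanceSegmentation/).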
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.adaptive_mask_model = PointRendPredictor( |
|
|
|
pointrend_thres=0.9, |
|
device="cuda" if torch.cuda.is_available() else "cpu", |
|
use_visualizer=use_visualizer, |
|
config_pth="pointrend_rcnn_R_50_FPN_3x_coco.yaml", |
|
weights_pth="model_final_edd263.pkl", |
|
) |
|
|
|
def adapt_mask(self, init_image, pred_orig_image, default_mask_image, dilate_num, use_default_mask, **kwargs): |
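        # Adaptive mask update: segment the person region in the predicted original image, dilate
        # it according to the per-step schedule, and intersect it with the user-provided default
        # mask so the inpainted region can only shrink. Note that the area threshold below assumes
        # a 512x512 canvas.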
|
|
|
adapt_output = self.adaptive_mask_model(pred_orig_image) |
|
mask = adapt_output["mask"] |
|
vis = adapt_output["vis"] |
|
|
|
|
|
if use_default_mask or mask.sum() < 512 * 512 * kwargs["human_detection_thres"]: |
|
|
|
mask = default_mask_image |
|
|
|
else: |
|
|
|
mask = cv2.dilate( |
|
mask, self.adaptive_mask_settings.dilate_kernel, iterations=dilate_num |
|
) |
|
mask = np.logical_and(mask, default_mask_image) |
|
|
|
|
|
mask = torch.tensor(mask, dtype=torch.float32).to(kwargs["device"])[None, None] |
|
mask, masked_image = prepare_mask_and_masked_image( |
|
init_image.to(kwargs["device"]), mask, kwargs["height"], kwargs["width"], return_image=False |
|
) |
|
|
|
mask_image_np = mask.clone().squeeze().detach().cpu().numpy() |
|
|
|
mask, masked_image_latents = self.prepare_mask_latents( |
|
mask, |
|
masked_image, |
|
kwargs["batch_size"] * kwargs["num_images_per_prompt"], |
|
kwargs["height"], |
|
kwargs["width"], |
|
kwargs["prompt_embeds"].dtype, |
|
kwargs["device"], |
|
kwargs["generator"], |
|
kwargs["do_classifier_free_guidance"], |
|
) |
|
|
|
return mask, masked_image_latents, mask_image_np, vis |
|
|
|
|
|
def seg2bbox(seg_mask: np.ndarray): |
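    # Tight bounding box [x_min, y_min, x_max + 1, y_max + 1] around the nonzero region of a
    # binary segmentation mask.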
|
nonzero_i, nonzero_j = seg_mask.nonzero() |
|
min_i, max_i = nonzero_i.min(), nonzero_i.max() |
|
min_j, max_j = nonzero_j.min(), nonzero_j.max() |
|
|
|
return np.array([min_j, min_i, max_j + 1, max_i + 1]) |
|
|
|
|
|
def merge_bbox(bboxes: list): |
|
assert len(bboxes) > 0 |
|
|
|
all_bboxes = np.stack(bboxes, axis=0) |
|
merged_bbox = np.zeros_like(all_bboxes[0]) |
|
|
|
merged_bbox[0] = all_bboxes[:, 0].min() |
|
merged_bbox[1] = all_bboxes[:, 1].min() |
|
merged_bbox[2] = all_bboxes[:, 2].max() |
|
merged_bbox[3] = all_bboxes[:, 3].max() |
|
|
|
return merged_bbox |
|
|
|
|
|
class PointRendPredictor: |
|
def __init__( |
|
self, |
|
cat_id_to_focus=0, |
|
pointrend_thres=0.9, |
|
device="cuda", |
|
use_visualizer=False, |
|
merge_mode="merge", |
|
config_pth=None, |
|
weights_pth=None, |
|
): |
|
super().__init__() |
|
|
|
|
|
self.cat_id_to_focus = cat_id_to_focus |
|
|
|
|
|
self.coco_metadata = MetadataCatalog.get("coco_2017_val") |
|
self.cfg = get_cfg() |
|
|
|
|
|
point_rend.add_pointrend_config(self.cfg) |
|
self.cfg.merge_from_file(config_pth) |
|
self.cfg.MODEL.WEIGHTS = weights_pth |
|
self.cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = pointrend_thres |
|
self.cfg.MODEL.DEVICE = device |
|
|
|
|
|
self.pointrend_seg_model = DefaultPredictor(self.cfg) |
|
|
|
|
|
self.use_visualizer = use_visualizer |
|
|
|
|
|
assert merge_mode in ["merge", "max-confidence"], f"'merge_mode': {merge_mode} not implemented." |
|
self.merge_mode = merge_mode |
|
|
|
def merge_mask(self, masks, scores=None): |
|
if self.merge_mode == "merge": |
|
mask = np.any(masks, axis=0) |
|
elif self.merge_mode == "max-confidence": |
|
mask = masks[np.argmax(scores)] |
|
return mask |
|
|
|
def vis_seg_on_img(self, image, mask): |
|
if type(mask) == np.ndarray: |
|
mask = torch.tensor(mask) |
|
v = Visualizer(image, self.coco_metadata, scale=0.5, instance_mode=ColorMode.IMAGE_BW) |
|
instances = Instances(image_size=image.shape[:2], pred_masks=mask if len(mask.shape) == 3 else mask[None]) |
|
vis = v.draw_instance_predictions(instances.to("cpu")).get_image() |
|
return vis |
|
|
|
def __call__(self, image): |
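        # Run PointRend instance segmentation and keep only detections of the focus category
        # (person, by default), merged into a single binary mask.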
|
|
|
outputs = self.pointrend_seg_model(image) |
|
instances = outputs["instances"] |
|
|
|
|
|
is_class = instances.pred_classes == self.cat_id_to_focus |
|
masks = instances.pred_masks[is_class] |
|
masks = masks.detach().cpu().numpy() |
|
mask = self.merge_mask(masks, scores=instances.scores[is_class]) |
|
|
|
return { |
|
"asset_mask": None, |
|
"mask": mask.astype(np.uint8), |
|
"vis": self.vis_seg_on_img(image, mask) if self.use_visualizer else None, |
|
} |
|
|
|
|
|
class MaskDilateScheduler: |
|
def __init__(self, max_dilate_num=15, num_inference_steps=50, schedule=None): |
|
super().__init__() |
|
self.max_dilate_num = max_dilate_num |
|
self.schedule = [num_inference_steps - i for i in range(num_inference_steps)] if schedule is None else schedule |
|
assert len(self.schedule) == num_inference_steps |
|
|
|
def __call__(self, i): |
|
return min(self.max_dilate_num, self.schedule[i]) |
|
|
|
|
|
class ProvokeScheduler: |
|
def __init__(self, num_inference_steps=50, schedule=None, is_zero_indexing=False): |
|
super().__init__() |
|
        if schedule is not None and len(schedule) > 0:
|
if is_zero_indexing: |
|
assert max(schedule) <= num_inference_steps - 1 |
|
else: |
|
assert max(schedule) <= num_inference_steps |
|
|
|
|
|
self.is_zero_indexing = is_zero_indexing |
|
self.schedule = schedule |
|
|
|
def __call__(self, i): |
|
if self.is_zero_indexing: |
|
return i in self.schedule |
|
else: |
|
return i + 1 in self.schedule |
|
|