# Spaces: Running on Zero
# Copyright (c) 2024 Jaerin Lee
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import inspect | |
from typing import Any, Callable, Dict, List, Literal, Tuple, Optional, Union | |
from tqdm import tqdm | |
from PIL import Image | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torchvision.transforms as T | |
from einops import rearrange | |
from transformers import ( | |
CLIPTextModelWithProjection, | |
CLIPTokenizer, | |
T5EncoderModel, | |
T5TokenizerFast, | |
) | |
from transformers import Blip2Processor, Blip2ForConditionalGeneration | |
from diffusers.image_processor import VaeImageProcessor | |
from diffusers.loaders import FromSingleFileMixin, SD3LoraLoaderMixin | |
from diffusers.models.attention_processor import ( | |
AttnProcessor2_0, | |
FusedAttnProcessor2_0, | |
LoRAAttnProcessor2_0, | |
LoRAXFormersAttnProcessor, | |
XFormersAttnProcessor, | |
) | |
from diffusers.models.autoencoders import AutoencoderKL | |
from diffusers.models.transformers import SD3Transformer2DModel | |
from diffusers.pipelines.stable_diffusion_3 import StableDiffusion3PipelineOutput | |
from diffusers.schedulers import ( | |
FlowMatchEulerDiscreteScheduler, | |
FlashFlowMatchEulerDiscreteScheduler, | |
) | |
from diffusers.utils import ( | |
is_torch_xla_available, | |
logging, | |
replace_example_docstring, | |
) | |
from diffusers.utils.torch_utils import randn_tensor | |
from diffusers import ( | |
DiffusionPipeline, | |
StableDiffusion3Pipeline, | |
) | |
from peft import PeftModel | |
from util import load_model, gaussian_lowpass, blend, get_panorama_views, shift_to_mask_bbox_center | |
# Optional XLA support: the XLA model helper is imported only when the
# runtime reports that torch_xla is available; the flag records the outcome.
XLA_AVAILABLE = False
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm
    XLA_AVAILABLE = True

# Module-level logger obtained through the diffusers logging utilities.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
# Usage snippet for the stock diffusers `StableDiffusion3Pipeline`.
# NOTE(review): nothing in the visible part of this file references this
# constant -- confirm whether it is still needed.
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import StableDiffusion3Pipeline
>>> pipe = StableDiffusion3Pipeline.from_pretrained(
... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
... )
>>> pipe.to("cuda")
>>> prompt = "A cat holding a sign that says hello world"
>>> image = pipe(prompt).images[0]
>>> image.save("sd3.png")
```
"""
class StableMultiDiffusion3Pipeline(nn.Module): | |
def __init__(
    self,
    device: torch.device,
    dtype: torch.dtype = torch.float16,
    hf_key: Optional[str] = None,
    lora_key: Optional[str] = None,
    load_from_local: bool = False,  # Turn on if you have already downloaded LoRA & Hugging Face hub is down.
    default_mask_std: float = 1.0,  # 8.0
    default_mask_strength: float = 1.0,
    default_prompt_strength: float = 1.0,  # 8.0
    default_bootstrap_steps: int = 1,
    default_boostrap_mix_steps: float = 1.0,
    default_bootstrap_leak_sensitivity: float = 0.2,
    default_preprocess_mask_cover_alpha: float = 0.3,
    t_index_list: Optional[List[int]] = [0, 4, 12, 25, 37],  # noqa: B006 -- copied below, never mutated.
    mask_type: Literal['discrete', 'semi-continuous', 'continuous'] = 'discrete',
    has_i2t: bool = True,
    lora_weight: float = 1.0,
) -> None:
    r"""Stabilized MultiDiffusion for fast sampling with Stable Diffusion 3.

    Accelerated region-based text-to-image synthesis that preserves mask
    fidelity and quality. The constructor loads the SD3 transformer, wraps
    it with the "jasperai/flash-sd3" PEFT adapter, builds a
    `StableDiffusion3Pipeline` around it (without the third text encoder /
    tokenizer) and installs a `FlashFlowMatchEulerDiscreteScheduler`.

    Args:
        device (torch.device): Device every model is moved to.
        dtype (torch.dtype): Precision used to load the transformer and the
            pipeline, and used by the helper methods of this class for the
            tensors they create.
        hf_key (Optional[str]): Custom checkpoint for the transformer
            weights. If None, "stabilityai/stable-diffusion-3-medium-
            diffusers" is used. NOTE(review): the remaining pipeline
            components and the scheduler config are always loaded from the
            stock SD3 medium repository, not from `hf_key` -- confirm this
            is intended.
        lora_key (Optional[str]): Custom acceleration LoRA. Currently not
            read by this constructor; the adapter is fixed to
            "jasperai/flash-sd3".
        load_from_local (bool): Turn on if you have already downloaded LoRA
            & Hugging Face hub is down. Currently not read by this
            constructor.
        default_mask_std (float): Preprocess mask with Gaussian blur with
            specified standard deviation.
        default_mask_strength (float): Preprocess mask by multiplying it
            globally with the specified variable. Caution: extremely
            sensitive. Recommended range: 0.98-1.
        default_prompt_strength (float): Preprocess foreground prompts
            globally by linearly interpolating its embedding with the
            background prompt embedding with specified mix ratio. Useful
            control handle for foreground blending. Recommended range:
            0.5-1.
        default_bootstrap_steps (int): Bootstrapping stage steps to
            encourage region separation. Recommended range: 1-3.
        default_boostrap_mix_steps (float): Bootstrapping background is a
            linear interpolation between background latent and the white
            image latent. This handle controls the mix ratio. Available
            range: 0-(number of bootstrapping inference steps). For
            example, 2.3 means that for the first two steps, white image
            is used as a bootstrapping background and in the third step,
            mixture of white (0.3) and registered background (0.7) is used
            as a bootstrapping background.
        default_bootstrap_leak_sensitivity (float): Postprocessing at each
            inference step by masking away the remaining bootstrap
            backgrounds. Recommended range: 0-1.
        default_preprocess_mask_cover_alpha (float): Optional preprocessing
            where each mask covered by other masks is reduced in its alpha
            value by this specified factor.
        t_index_list (Optional[List[int]]): The default scheduling for the
            scheduler, expressed as indices into a 50-step schedule. If
            None, all `default_num_inference_steps` (4) steps of a 4-step
            schedule are used instead.
        mask_type (Literal['discrete', 'semi-continuous', 'continuous']):
            defines the mask quantization modes. Details in the codes of
            `self.process_mask`. Basically, this (subtly) controls the
            smoothness of foreground-background blending. More continuous
            means more blending, but smaller generated patch depending on
            the mask standard deviation.
        has_i2t (bool): Automatic background image to text prompt con-
            version with BLIP-2 model. May not be necessary for the non-
            streaming application.
        lora_weight (float): Intended to adjust the weight of the
            acceleration LoRA. Currently not read by this constructor.
    """
    super().__init__()

    self.device = device
    self.dtype = dtype

    self.default_mask_std = default_mask_std
    self.default_mask_strength = default_mask_strength
    self.default_prompt_strength = default_prompt_strength
    self.default_bootstrap_steps = default_bootstrap_steps
    self.default_boostrap_mix_steps = default_boostrap_mix_steps
    self.default_bootstrap_leak_sensitivity = default_bootstrap_leak_sensitivity
    self.default_preprocess_mask_cover_alpha = default_preprocess_mask_cover_alpha
    self.mask_type = mask_type

    # Create model.
    print('[INFO] Loading Stable Diffusion...')
    if hf_key is not None:
        print(f'[INFO] Using Hugging Face custom model key: {hf_key}')
    else:
        hf_key = "stabilityai/stable-diffusion-3-medium-diffusers"

    # Load in the requested precision: the rest of this class creates its
    # tensors in `self.dtype`, so the weights have to match it.
    transformer = SD3Transformer2DModel.from_pretrained(
        hf_key,
        subfolder="transformer",
        torch_dtype=self.dtype,
    ).to(self.device)
    transformer = PeftModel.from_pretrained(transformer, "jasperai/flash-sd3").to(self.device)

    # The third text encoder / tokenizer are explicitly disabled.
    self.pipe = StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers",
        transformer=transformer,
        torch_dtype=self.dtype,
        text_encoder_3=None,
        tokenizer_3=None,
    ).to(self.device)

    # Optional BLIP-2 captioner used by `get_text_prompts`. It is left on
    # the device / precision `from_pretrained` gives it by default.
    if has_i2t:
        self.i2t_processor = Blip2Processor.from_pretrained('Salesforce/blip2-opt-2.7b')
        self.i2t_model = Blip2ForConditionalGeneration.from_pretrained('Salesforce/blip2-opt-2.7b')

    # Use the flash flow-matching Euler scheduler with the stock SD3 config.
    self.pipe.scheduler = FlashFlowMatchEulerDiscreteScheduler.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="scheduler")
    self.pipe = self.pipe.to(self.device)

    self.scheduler = self.pipe.scheduler
    self.default_num_inference_steps = 4
    self.default_guidance_scale = 0.0

    if t_index_list is None:
        # Keep a usable default so that a later argument-less call to
        # `prepare_flashflowmatch_schedule` does not receive None.
        self.default_t_list = list(range(self.default_num_inference_steps))
        self.prepare_flashflowmatch_schedule(
            self.default_t_list,
            self.default_num_inference_steps,
        )
    else:
        # Copy so the shared default list (or the caller's list) can never
        # be mutated through this instance.
        self.default_t_list = list(t_index_list)
        # The index list is relative to a 50-step schedule.
        self.prepare_flashflowmatch_schedule(self.default_t_list, 50)

    # Shortcuts to the pipeline components.
    self.vae = self.pipe.vae
    self.tokenizer = self.pipe.tokenizer
    self.tokenizer_2 = self.pipe.tokenizer_2
    self.tokenizer_3 = self.pipe.tokenizer_3
    self.text_encoder = self.pipe.text_encoder
    self.text_encoder_2 = self.pipe.text_encoder_2
    self.text_encoder_3 = self.pipe.text_encoder_3
    self.transformer = self.pipe.transformer
    self.vae_scale_factor = self.pipe.vae_scale_factor

    # Prepare white background for bootstrapping.
    self.get_white_background(1024, 1024)

    print('[INFO] Model is loaded!')
def prepare_flashflowmatch_schedule(
    self,
    t_index_list: Optional[List[int]] = None,
    num_inference_steps: Optional[int] = None,
) -> None:
    r"""Configure the sub-sampled inference schedule of the sampler.

    Calling this explicitly is only needed when a schedule other than the
    one chosen at construction time is wanted; in that case call it between
    module initialization and the main call.

    Note:
        - Recommended `t_index_list` values, as documented for the LCM
          variants of this pipeline:
            - [0, 12, 25, 37]: Default schedule for 4 steps. Best for
              panorama. Not recommended with bootstrapping, because the
              bootstrapping stage shapes the initial structure of the
              image and here it only happens at the first step, so the
              structure may be distorted.
            - [0, 4, 12, 25, 37]: Recommended with 1-step bootstrapping.
              Default initialization in this implementation.
            - [0, 5, 16, 18, 20, 37]: Recommended with 2-step boot-
              strapping.
        - Due to the characteristic of SD1.5 LCM LoRA, setting
          `num_inference_steps` larger than 20 may result in overly blurry
          and unrealistic images. Beware!

    Args:
        t_index_list (Optional[List[int]]): Indices into the scheduler's
            full list of `num_inference_steps` timesteps. For instance,
            `[0, 12, 25, 37]` are relative indices on a full scale of 50.
            If None, `self.default_t_list` is used.
        num_inference_steps (Optional[int]): Length of the full schedule
            that `t_index_list` indexes into. If None,
            `self.default_num_inference_steps` is used.
    """
    indices = self.default_t_list if t_index_list is None else t_index_list
    total_steps = (
        self.default_num_inference_steps
        if num_inference_steps is None
        else num_inference_steps
    )

    self.scheduler.set_timesteps(total_steps)
    picked = torch.tensor(indices)
    self.timesteps = self.scheduler.timesteps[picked].to(self.device)

    # FlashFlowMatchEulerDiscreteScheduler
    # https://github.com/initml/diffusers/blob/clement/feature/flash_sd3/src/diffusers/schedulers/scheduling_flash_flow_match_euler_discrete.py
    sigmas = self.scheduler.sigmas[picked].to(self.device)
    self.sigmas = sigmas
    # Sigma of the following step; the last step is followed by zero.
    self.sigmas_next = torch.cat([sigmas[1:], sigmas.new_zeros(1)])

    # Noise levels, shaped (1, T, 1, 1, 1) to broadcast against masks.
    levels = sigmas * (sigmas ** 2 + 1) ** (-0.5)
    shifted_levels = torch.cat([levels[1:], levels.new_zeros(1)])
    self.noise_lvs = levels[None, :, None, None, None]
    self.next_noise_lvs = shifted_levels[None, :, None, None, None]
def get_text_prompts(self, image: Image.Image) -> str:
    r"""Caption an image so the caption can serve as a text prompt.

    Used when only a background image is supplied without a background
    prompt. The caption comes from the BLIP-2 model loaded at construction
    time (when `has_i2t` was enabled).

    Args:
        image (Image.Image): A PIL image.

    Returns:
        A single string of text prompt, or an empty string when no image-
        to-text model has been loaded.
    """
    # No captioner available: fall back to an empty prompt.
    if not hasattr(self, 'i2t_model'):
        return ''

    query = 'Question: What are in the image? Answer:'
    model_inputs = self.i2t_processor(image, query, return_tensors='pt')
    generated = self.i2t_model.generate(**model_inputs, max_new_tokens=77)
    caption = self.i2t_processor.decode(generated[0], skip_special_tokens=True)
    return caption.strip()
def encode_imgs(
    self,
    imgs: torch.Tensor,
    generator: Optional[torch.Generator] = None,
    vae: Optional[nn.Module] = None,
) -> torch.Tensor:
    r"""A wrapper function for VAE encoder of the latent diffusion model.

    The image is rescaled from [0, 1] to [-1, 1], encoded, and the latent
    is normalized as `(latent - shift_factor) * scaling_factor`. The shift
    is applied only when the VAE config defines a non-None `shift_factor`
    (as the Stable Diffusion 3 VAE does); otherwise only the scaling is
    applied.

    Args:
        imgs (torch.Tensor): An image to get StableDiffusion latents.
            Expected shape: (B, 3, H, W). Expected pixel scale: [0, 1].
        generator (Optional[torch.Generator]): Seed for KL-Autoencoder.
        vae (Optional[nn.Module]): Explicitly specify VAE (used for
            the demo application with TinyVAE). Defaults to `self.vae`.

    Returns:
        An image latent embedding of reduced spatial size. The channel
        count and downsampling ratio depend on the autoencoder (presumably
        (B, 16, H//8, W//8) for the SD3 VAE -- TODO confirm).

    Raises:
        AttributeError: If the encoder output exposes neither
            `latent_dist` nor `latents`.
    """
    def _retrieve_latents(
        encoder_output: torch.Tensor,
        generator: Optional[torch.Generator] = None,
        sample_mode: str = 'sample',
    ):
        # Mirrors the helper used by the diffusers pipelines: sample from
        # (or take the mode of) the posterior, or use ready-made latents.
        if hasattr(encoder_output, 'latent_dist') and sample_mode == 'sample':
            return encoder_output.latent_dist.sample(generator)
        elif hasattr(encoder_output, 'latent_dist') and sample_mode == 'argmax':
            return encoder_output.latent_dist.mode()
        elif hasattr(encoder_output, 'latents'):
            return encoder_output.latents
        else:
            raise AttributeError('Could not access latents of provided encoder_output')

    vae = self.vae if vae is None else vae
    imgs = 2 * imgs - 1
    latents = _retrieve_latents(vae.encode(imgs), generator=generator)

    # SD3-style VAEs center their latents with `shift_factor` before
    # scaling; autoencoders without that config entry are left as before.
    shift_factor = getattr(vae.config, 'shift_factor', None)
    if shift_factor is not None:
        latents = latents - shift_factor
    return vae.config.scaling_factor * latents
def decode_latents(self, latents: torch.Tensor, vae: Optional[nn.Module] = None) -> torch.Tensor:
    r"""A wrapper function for VAE decoder of the latent diffusion model.

    The latent is de-normalized as `latent / scaling_factor + shift_factor`
    before decoding. The shift is applied only when the VAE config defines
    a non-None `shift_factor` (as the Stable Diffusion 3 VAE does). The
    decoded image is mapped from [-1, 1] to [0, 1] and clipped.

    Args:
        latents (torch.Tensor): An image latent to get associated images.
            The channel count depends on the autoencoder; spatial size is
            the image size divided by the VAE downsampling ratio.
        vae (Optional[nn.Module]): Explicitly specify VAE (used for
            the demo application with TinyVAE). Defaults to `self.vae`.

    Returns:
        Decoded images with pixel values in [0, 1]. Shape: (B, 3, H, W).
    """
    vae = self.vae if vae is None else vae
    latents = 1 / vae.config.scaling_factor * latents

    # Undo the latent centering used by SD3-style VAEs; autoencoders
    # without a `shift_factor` config entry are left as before.
    shift_factor = getattr(vae.config, 'shift_factor', None)
    if shift_factor is not None:
        latents = latents + shift_factor

    imgs = vae.decode(latents).sample
    imgs = (imgs / 2 + 0.5).clip_(0, 1)
    return imgs
def get_white_background(self, height: int, width: int) -> torch.Tensor:
    r"""White background image latent for bootstrapping or in case of
    absent background.

    The encoded white latent is cached in `self.white`. The cache is
    reused (cropped) whenever it is at least as large as the requested
    latent size, and re-encoded at the requested image size otherwise.
    By default this is first called with a 1024x1024 white image, so the
    encoder is rarely visited twice.

    Args:
        height (int): The height of the white *image*, not its latent.
        width (int): The width of the white *image*, not its latent.

    Returns:
        A white image latent whose spatial size is
        (height // vae_scale_factor, width // vae_scale_factor). A cropped
        version of the stored white latent is returned if the requested
        size is not larger than what we already have created.
    """
    latent_h = height // self.vae_scale_factor
    latent_w = width // self.vae_scale_factor

    cached = getattr(self, 'white', None)
    # Compare latent size against latent size: the cached tensor lives in
    # latent space whereas `height` / `width` are given in pixels.
    if cached is None or cached.shape[-2] < latent_h or cached.shape[-1] < latent_w:
        white = torch.ones(1, 3, height, width, dtype=self.dtype, device=self.device)
        self.white = self.encode_imgs(white)
        cached = self.white
    return cached[..., :latent_h, :latent_w]
def process_mask(
    self,
    masks: Union[torch.Tensor, Image.Image, List[Image.Image]],
    strength: Optional[Union[torch.Tensor, float]] = None,
    std: Optional[Union[torch.Tensor, float]] = None,
    height: int = 1024,
    width: int = 1024,
    use_boolean_mask: bool = True,
    timesteps: Optional[torch.Tensor] = None,
    preprocess_mask_cover_alpha: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""Fast preprocess of masks for region-based generation with fine-
    grained controls.

    Mask preprocessing is done in four steps:
     1. Resizing: Resize the masks into the specified width and height by
        bilinear interpolation. (The final masks are later downsampled to
        the latent resolution by nearest neighbor interpolation.)
     2. (Optional) Ordering: Masks with higher indices are considered to
        cover the masks with smaller indices. Covered masks are decayed
        in its alpha value by the specified factor of
        `preprocess_mask_cover_alpha`.
     3. Blurring: Gaussian blur is applied to the mask with the specified
        standard deviation (isotropic). This results in gradual increase of
        masked region as the timesteps evolve, naturally blending fore-
        ground and the predesignated background. Not strictly required if
        you want to produce images from scratch without background.
     4. Quantization: Split the real-numbered masks of value between [0, 1]
        into predefined noise levels for each quantized scheduling step of
        the diffusion sampler. For example, if the diffusion model sampler
        has noise level of [0.9977, 0.9912, 0.9735, 0.8499, 0.5840], the
        masks are split into binary masks whose values are greater than
        these levels. This results in gradual increase of mask region as
        the timesteps increase. Details are described in our paper at
        https://arxiv.org/pdf/2403.09055.pdf.

    On the Three Modes of `mask_type`:
        `self.mask_type` is predefined at the initialization stage of this
        pipeline. Three possible modes are available: 'discrete', 'semi-
        continuous', and 'continuous'. These define the mask quantization
        modes we use. Basically, this (subtly) controls the smoothness of
        foreground-background blending. Continuous modes produces nonbinary
        masks to further blend foreground and background latents by linear-
        ly interpolating between them. Semi-continuous masks only applies
        continuous mask at the last step of the LCM sampler. Due to the
        large step size of the LCM scheduler, we find that our continuous
        blending helps generating seamless inpainting and editing results.
        Any other value of `self.mask_type` skips quantization entirely.

    Args:
        masks (Union[torch.Tensor, Image.Image, List[Image.Image]]): Masks.
            PIL images are converted using their last channel.
        strength (Optional[Union[torch.Tensor, float]]): Mask strength that
            overrides the default value. A globally multiplied factor to
            the mask at the initial stage of processing. Can be applied
            separately for each mask.
        std (Optional[Union[torch.Tensor, float]]): Mask blurring Gaussian
            kernel's standard deviation. Overrides the default value. Can
            be applied separately for each mask.
        height (int): The height of the expected generation. The returned
            quantized masks are resized to
            (height // vae_scale_factor, width // vae_scale_factor) with
            nearest neighbor interpolation.
        width (int): The width of the expected generation. See `height`.
        use_boolean_mask (bool): Specify this to treat the mask image as
            a boolean tensor. The region with dark part darker than 0.5 of
            the maximal pixel value (that is, 127.5) is considered as the
            designated mask.
        timesteps (Optional[torch.Tensor]): Meant to define the scheduler
            noise levels that act as bins of mask quantization.
            NOTE(review): only checked against None; its values are never
            read, and the levels are derived from `self.sigmas` either way.
        preprocess_mask_cover_alpha (Optional[float]): Optional pre-
            processing where each mask covered by other masks is reduced in
            its alpha value by this specified factor. Overrides the default
            value.

    Returns: A tuple of tensors.
      - masks: Preprocessed (ordered, blurred, and quantized) binary/non-
          binary masks (see the explanation on `mask_type` above) for
          region-based image synthesis.
      - masks_blurred: Gaussian blurred masks. Used for optionally
          specified foreground-background blending after image
          generation.
      - std: Mask blur standard deviation. Used for optionally specified
          foreground-background blending after image generation.
    """
    if isinstance(masks, Image.Image):
        masks = [masks]
    if isinstance(masks, (tuple, list)):
        # Assumes white background for Image.Image;
        # inverted boolean masks with shape (1, 1, H, W) for torch.Tensor.
        # Only the last channel of each image is used.
        if use_boolean_mask:
            proc = lambda m: T.ToTensor()(m)[None, -1:] < 0.5
        else:
            proc = lambda m: 1.0 - T.ToTensor()(m)[None, -1:]
        masks = torch.cat([proc(mask) for mask in masks], dim=0).float().clip_(0, 1)
    # Bring every mask to the full image resolution before any processing.
    masks = F.interpolate(masks.float(), size=(height, width), mode='bilinear', align_corners=False)
    masks = masks.to(self.device)

    # Background mask alpha is decayed by the specified factor where foreground masks covers it.
    if preprocess_mask_cover_alpha is None:
        preprocess_mask_cover_alpha = self.default_preprocess_mask_cover_alpha
    if preprocess_mask_cover_alpha > 0:
        # A mask is attenuated wherever any later (higher-index) mask is
        # non-zero; the last mask is never attenuated.
        masks = torch.stack([
            torch.where(
                masks[i + 1:].sum(dim=0) > 0,
                mask * preprocess_mask_cover_alpha,
                mask,
            ) if i < len(masks) - 1 else mask
            for i, mask in enumerate(masks)
        ], dim=0)

    # Scheduler noise levels for mask quantization.
    if timesteps is None:
        noise_lvs = self.noise_lvs
        next_noise_lvs = self.next_noise_lvs
    else:
        # NOTE(review): `timesteps` itself is not used here; the levels are
        # recomputed from `self.sigmas` -- confirm this is intended.
        noise_lvs_ = self.sigmas * (self.sigmas**2 + 1)**(-0.5)
        # noise_lvs_ = (1 - self.scheduler.alphas_cumprod[timesteps].to(self.device)) ** 0.5
        noise_lvs = noise_lvs_[None, :, None, None, None].to(masks.device)
        next_noise_lvs = torch.cat([noise_lvs_[1:], noise_lvs_.new_zeros(1)])[None, :, None, None, None]

    # Mask preprocessing parameters are fetched from the default settings.
    # Scalars are broadcast to one value per mask and converted to tensors.
    if std is None:
        std = self.default_mask_std
    if isinstance(std, (int, float)):
        std = [std] * len(masks)
    if isinstance(std, (list, tuple)):
        std = torch.as_tensor(std, dtype=torch.float, device=self.device)

    if strength is None:
        strength = self.default_mask_strength
    if isinstance(strength, (int, float)):
        strength = [strength] * len(masks)
    if isinstance(strength, (list, tuple)):
        strength = torch.as_tensor(strength, dtype=torch.float, device=self.device)

    if (std > 0).any():
        # Non-positive deviations are replaced by a tiny positive value
        # before the low-pass filter is applied.
        std = torch.where(std > 0, std, 1e-5)
        masks = gaussian_lowpass(masks, std)
    # Kept before strength scaling and quantization; returned to the caller.
    masks_blurred = masks

    # NOTE: This `strength` aligns with `denoising strength`. However, with LCM, using strength < 0.96
    #       gives unpleasant results.
    masks = masks * strength[:, None, None, None]
    # Add a per-timestep axis: one copy of each mask per noise level.
    masks = masks.unsqueeze(1).repeat(1, noise_lvs.shape[1], 1, 1, 1)

    # Mask is quantized according to the current noise levels specified by the scheduler.
    if self.mask_type == 'discrete':
        # Discrete mode.
        masks = masks > noise_lvs
    elif self.mask_type == 'semi-continuous':
        # Semi-continuous mode (continuous at the last step only).
        masks = torch.cat((
            masks[:, :-1] > noise_lvs[:, :-1],
            (
                (masks[:, -1:] - next_noise_lvs[:, -1:]) / (noise_lvs[:, -1:] - next_noise_lvs[:, -1:])
            ).clip_(0, 1),
        ), dim=1)
    elif self.mask_type == 'continuous':
        # Continuous mode: Have the exact same `1` coverage with discrete mode, but the mask gradually
        #                  decreases continuously after the discrete mode boundary to become `0` at the
        #                  next lower threshold.
        masks = ((masks - next_noise_lvs) / (noise_lvs - next_noise_lvs)).clip_(0, 1)

    # NOTE: Post processing mask strength does not align with conventional 'denoising_strength'. However,
    #       fine-grained mask alpha channel tuning is available with this form.
    # masks = masks * strength[None, :, None, None, None]

    # Downsample to the latent resolution; the prompt and timestep axes are
    # merged for the interpolation call and split again afterwards.
    h = height // self.vae_scale_factor
    w = width // self.vae_scale_factor
    masks = rearrange(masks.float(), 'p t () h w -> (p t) () h w')
    masks = F.interpolate(masks, size=(h, w), mode='nearest')
    masks = rearrange(masks.to(self.dtype), '(p t) () h w -> p t () h w', p=len(std))
    return masks, masks_blurred, std
def scheduler_step(
    self,
    noise_pred: torch.Tensor,
    idx: int,
    latent: torch.Tensor,
) -> torch.Tensor:
    r"""Denoise-only step for the reverse sampler.

    Together with `scheduler_add_noise` this replaces a single call to the
    original `pipe.scheduler.step`.

    Args:
        noise_pred (torch.Tensor): Prediction produced by the denoising
            network for the current step.
        idx (int): Index into this module's timestep / sigma tensors
            (ranged in [0, len(timesteps)-1]), not a timestep on the
            [0, 1000] scale.
        latent (torch.Tensor): Noisy latent.

    Returns:
        A denoised tensor with the same size as latent, cast to
        `self.dtype`.
    """
    sigma = self.sigmas[idx]
    # The subtraction runs in float32 to avoid precision issues; the
    # result is cast back to the working dtype afterwards.
    denoised = latent.to(torch.float32) - noise_pred * sigma
    return denoised.to(self.dtype)
def scheduler_add_noise(
    self,
    latent: torch.Tensor,
    noise: Optional[torch.Tensor],
    idx: int,
) -> torch.Tensor:
    r"""Separated noise-add step for the reverse sampler.

    Counterpart of `scheduler_step`, designed to match the interface of
    the original `pipe.scheduler.add_noise`. The latent and the noise are
    mixed as `(1 - sigma) * latent + sigma * noise`.

    Args:
        latent (torch.Tensor): Denoised latent.
        noise (torch.Tensor): Added noise. Can be None. If None, a random
            noise is newly sampled for addition.
        idx (int): Index into this module's sigma tensor (ranged in
            [0, len(timesteps)-1]), not a timestep on the [0, 1000] scale.

    Returns:
        A noisy tensor with the same size as latent. When `idx` lies
        outside the valid index range, `latent` is returned unchanged.
    """
    # Out-of-range step (e.g. past the final step): nothing to add.
    if not (0 <= idx < len(self.sigmas)):
        return latent

    sigma = self.sigmas[idx]
    if noise is None:
        noise = torch.randn_like(latent)
    return (1.0 - sigma) * latent + sigma * noise
def __call__(
    self,
    prompts: Optional[Union[str, List[str]]] = None,
    negative_prompts: Union[str, List[str]] = '',
    suffix: Optional[str] = None, #', background is ',
    background: Optional[Union[torch.Tensor, Image.Image]] = None,
    background_prompt: Optional[str] = None,
    background_negative_prompt: str = '',
    height: int = 1024,
    width: int = 1024,
    num_inference_steps: Optional[int] = None,
    guidance_scale: Optional[float] = None,
    prompt_strengths: Optional[Union[torch.Tensor, float, List[float]]] = None,
    masks: Optional[Union[Image.Image, List[Image.Image]]] = None,
    mask_strengths: Optional[Union[torch.Tensor, float, List[float]]] = None,
    mask_stds: Optional[Union[torch.Tensor, float, List[float]]] = None,
    use_boolean_mask: bool = True,
    do_blend: bool = True,
    tile_size: int = 1024,
    bootstrap_steps: Optional[int] = None,
    boostrap_mix_steps: Optional[float] = None,
    bootstrap_leak_sensitivity: Optional[float] = None,
    preprocess_mask_cover_alpha: Optional[float] = None,
    # SDXL Pipeline setting.
    guidance_rescale: float = 0.7,
    output_type = 'pil',
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    clip_skip: Optional[int] = None,
) -> Image.Image:
    r"""Arbitrary-size image generation from multiple (regional) text
    prompt-mask pairs.

    This is the main routine of this pipeline.

    Example:
        >>> device = torch.device('cuda:0')
        >>> smd = StableMultiDiffusionPipeline(device)
        >>> prompts = {... specify prompts}
        >>> masks = {... specify mask tensors}
        >>> height, width = masks.shape[-2:]
        >>> image = smd(
        >>>     prompts, masks=masks.float(), height=height, width=width)
        >>> image.save('my_beautiful_creation.png')

    Args:
        prompts (Union[str, List[str]]): A text prompt, or one per mask.
        negative_prompts (Union[str, List[str]]): A negative text prompt,
            or one per prompt.
        suffix (Optional[str]): One option for blending foreground prompts
            with the background prompt by appending the background prompt
            to the end of each foreground prompt with this `middle word`
            in between. For example, if set to `, background is`, each
            foreground prompt becomes `(fg), background is (bg)` before
            conditional generation.
        background (Optional[Union[torch.Tensor, Image.Image]]): A
            background image, if the user wants to draw in front of a
            given image. If no `background_prompt` is given, one is
            obtained from `self.get_text_prompts(background)`.
        background_prompt (Optional[str]): The background prompt, used
            for preprocessing foreground prompt embeddings to blend
            foreground and background.
        background_negative_prompt (Optional[str]): The negative
            background prompt.
        height (int): Height of the generated image. It is tiled if larger
            than `tile_size`.
        width (int): Width of the generated image. It is tiled if larger
            than `tile_size`.
        num_inference_steps (Optional[int]): Number of inference steps.
            The currently prepared schedule is used if none is specified.
        guidance_scale (Optional[float]): Classifier-free guidance scale.
            `self.default_guidance_scale` is used if none is specified.
        prompt_strengths (Optional[Union[torch.Tensor, float, List[float]]]):
            Overrides the default value. Preprocess foreground prompts by
            linearly interpolating their embeddings with the background
            prompt embedding at the specified mix ratio. Only applied when
            a background exists. Recommended range: 0.5-1.
        masks (Optional[Union[Image.Image, List[Image.Image]]]): A list of
            mask images. Each mask is associated with one text prompt and
            one negative prompt. If specified as an image, it is regarded
            as a boolean mask. Also accepts torch.Tensor masks, which can
            have nonbinary values for fine-grained control in mixing
            regional generations.
        mask_strengths (Optional[Union[torch.Tensor, float, List[float]]]):
            Overrides the default value. Can be assigned for each mask
            separately. Preprocess each mask by multiplying it globally
            with the specified value. Caution: extremely sensitive.
            Recommended range: 0.98-1.
        mask_stds (Optional[Union[torch.Tensor, float, List[float]]]):
            Overrides the default value. Can be assigned for each mask
            separately. Preprocess each mask with a Gaussian blur of the
            specified standard deviation. Recommended range: 0-64.
        use_boolean_mask (bool): Turn this off to treat the mask image as
            a nonbinary one. The module will use the last channel of the
            given image in `masks` as the mask value.
        do_blend (bool): Blend the generated foreground and the optionally
            predefined background using the smooth boundary obtained from
            Gaussian blurs of the foreground `masks` with the given
            `mask_stds`.
        tile_size (Optional[int]): Tile size of the panorama generation.
            Presumably works best at the training resolution of the
            underlying model; the default here is 1024.
        bootstrap_steps (int): Overrides the default value. Number of
            bootstrapping steps that encourage region separation.
            Recommended range: 1-3.
        boostrap_mix_steps (float): Overrides the default value. The
            bootstrapping background is a linear interpolation between
            the background latent and the white image latent. This handle
            controls the mix ratio. Available range: 0-(number of
            bootstrapping inference steps). For example, 2.3 means that
            for the first two steps the white image is used as the
            bootstrapping background, and in the third step a mixture of
            white (0.3) and the registered background (0.7) is used.
        bootstrap_leak_sensitivity (float): Overrides the default value.
            Postprocessing at each bootstrapping step that masks away the
            remaining bootstrap backgrounds. Recommended range: 0-1.
        preprocess_mask_cover_alpha (float): Overrides the default value.
            Optional preprocessing where each mask covered by other masks
            is reduced in its alpha value by this factor.
        guidance_rescale (float): Rescale factor passed to
            `rescale_noise_cfg`; skipped when not positive.
        output_type: Any value other than `'latent'` decodes the final
            latent with the VAE. `'latent'` returns the raw latent.
        joint_attention_kwargs (Optional[Dict[str, Any]]): Forwarded to
            the transformer call.
        clip_skip (Optional[int]): Forwarded to `pipe.encode_prompt`.

    Returns:
        A PIL.Image of the (possibly panoramic) generated image. When
        `output_type == 'latent'` the final latent tensor is returned
        instead. When `prompts` is empty, `background` (or the result of
        `sample` on the background prompt) is returned.
    """
    ### Simplest cases

    # prompts is None: return background.
    # masks is None but prompts is not None: return prompts
    # masks is not None and prompts is not None: Do StableMultiDiffusion.

    # NOTE(review): `sample` is referenced below as a bare name and is not
    # defined inside this method -- confirm it resolves at module scope
    # (or whether `self.sample` was intended).
    if prompts is None or (isinstance(prompts, (list, tuple, str)) and len(prompts) == 0):
        if background is None and background_prompt is not None:
            return sample(background_prompt, background_negative_prompt, height, width, num_inference_steps, guidance_scale)
        return background
    elif masks is None or (isinstance(masks, (list, tuple)) and len(masks) == 0):
        return sample(prompts, negative_prompts, height, width, num_inference_steps, guidance_scale)

    ### Prepare generation

    if num_inference_steps is not None:
        self.prepare_flashflowmatch_schedule(list(range(num_inference_steps)), num_inference_steps)

    if guidance_scale is None:
        guidance_scale = self.default_guidance_scale
    self.pipe._guidance_scale = guidance_scale
    self.pipe._clip_skip = clip_skip
    self.pipe._joint_attention_kwargs = joint_attention_kwargs
    self.pipe._interrupt = False
    do_classifier_free_guidance = guidance_scale > 1.0

    ### Prompts & Masks

    # asserts #m > 0 and #p > 0.
    # #m == #p == #n > 0: We happily generate according to the prompts & masks.
    # #m != #p: #p should be 1 and we will broadcast text embeds of p through m masks.
    # #p != #n: #n should be 1 and we will broadcast negative embeds n through p prompts.

    if isinstance(masks, Image.Image):
        masks = [masks]
    if isinstance(prompts, str):
        prompts = [prompts]
    if isinstance(negative_prompts, str):
        negative_prompts = [negative_prompts]
    num_masks = len(masks)
    num_prompts = len(prompts)
    num_nprompts = len(negative_prompts)
    assert num_prompts in (num_masks, 1), \
        f'The number of prompts {num_prompts} should match the number of masks {num_masks}!'
    assert num_nprompts in (num_prompts, 1), \
        f'The number of negative prompts {num_nprompts} should match the number of prompts {num_prompts}!'

    fg_masks, masks_g, std = self.process_mask(
        masks,
        mask_strengths,
        mask_stds,
        height=height,
        width=width,
        use_boolean_mask=use_boolean_mask,
        timesteps=self.timesteps,
        preprocess_mask_cover_alpha=preprocess_mask_cover_alpha,
    )  # (p, t, 1, H, W)
    # Whatever no foreground mask covers is treated as background.
    bg_masks = (1 - fg_masks.sum(dim=0)).clip_(0, 1)  # (T, 1, h, w)
    has_background = bg_masks.sum() > 0

    # Latent-space size (ceil division by the VAE scale factor).
    h = (height + self.vae_scale_factor - 1) // self.vae_scale_factor
    w = (width + self.vae_scale_factor - 1) // self.vae_scale_factor

    ### Background

    # background == None && background_prompt == None: Initialize with white background.
    # background == None && background_prompt != None: Generate background *along with other prompts*.
    # background != None && background_prompt == None: Retrieve text prompt using BLIP.
    # background != None && background_prompt != None: Use the given arguments.

    # not has_background: no effect of prompt_strength (the mix ratio between fg prompt & bg prompt)
    # has_background && prompt_strength != 1: mix only for this case.

    bg_latent = None
    if has_background:
        if background is None and background_prompt is not None:
            # The background becomes one more regional prompt-mask pair.
            # NOTE(review): `num_masks` / `num_prompts` are not updated
            # after this prepend -- confirm the later
            # `repeat(num_masks, ...)` calls are intended to use the
            # original counts.
            fg_masks = torch.cat((bg_masks[None], fg_masks), dim=0)
            if suffix is not None:
                prompts = [p + suffix + background_prompt for p in prompts]
            prompts = [background_prompt] + prompts
            negative_prompts = [background_negative_prompt] + negative_prompts
            has_background = False  # Regard that background does not exist.
        else:
            if background is None and background_prompt is None:
                background = torch.ones(1, 3, height, width, dtype=self.dtype, device=self.device)
                background_prompt = 'simple white background image'
            elif background is not None and background_prompt is None:
                background_prompt = self.get_text_prompts(background)
            if suffix is not None:
                prompts = [p + suffix + background_prompt for p in prompts]
            prompts = [background_prompt] + prompts
            negative_prompts = [background_negative_prompt] + negative_prompts
            if isinstance(background, Image.Image):
                background = T.ToTensor()(background).to(dtype=self.dtype, device=self.device)[None]
                background = F.interpolate(background, size=(height, width), mode='bicubic', align_corners=False)
            bg_latent = self.encode_imgs(background)

    # Bootstrapping stage preparation.

    if bootstrap_steps is None:
        bootstrap_steps = self.default_bootstrap_steps
    if boostrap_mix_steps is None:
        boostrap_mix_steps = self.default_boostrap_mix_steps
    if bootstrap_leak_sensitivity is None:
        bootstrap_leak_sensitivity = self.default_bootstrap_leak_sensitivity
    if bootstrap_steps > 0:
        # Only referenced inside the loop while `i < bootstrap_steps`.
        white = self.get_white_background(height, width)

    ### Prepare text embeddings (optimized for the minimal encoder batch size)

    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.pipe.encode_prompt(
        prompt=prompts,
        prompt_2=None,
        prompt_3=None,
        negative_prompt=negative_prompts,
        negative_prompt_2=None,
        negative_prompt_3=None,
        do_classifier_free_guidance=do_classifier_free_guidance,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        pooled_prompt_embeds=None,
        negative_pooled_prompt_embeds=None,
        device=self.device,
        clip_skip=self.pipe.clip_skip,
        num_images_per_prompt=1,
    )

    if has_background:
        # First channel is background prompt text embeds. Background prompt itself is not used for generation.
        s = prompt_strengths
        if prompt_strengths is None:
            s = self.default_prompt_strength
        if isinstance(s, (int, float)):
            s = [s] * num_prompts
        if isinstance(s, (list, tuple)):
            assert len(s) == num_prompts, \
                f'The number of prompt strengths {len(s)} should match the number of prompts {num_prompts}!'
        s = torch.as_tensor(s, dtype=self.dtype, device=self.device)
        s = s[:, None, None]

        be = prompt_embeds[:1]
        fe = prompt_embeds[1:]
        prompt_embeds = torch.lerp(be, fe, s)  # (p, seq, dim)

        if negative_prompt_embeds is not None:
            bu = negative_prompt_embeds[:1]
            fu = negative_prompt_embeds[1:]
            if num_prompts > num_nprompts:
                # # negative prompts = 1; # prompts > 1.
                # Compare batch sizes: a bare `fe.shape == num_prompts`
                # is a torch.Size-vs-int comparison and is never true.
                assert fu.shape[0] == 1 and fe.shape[0] == num_prompts
                fu = fu.repeat(num_prompts, 1, 1)
            negative_prompt_embeds = torch.lerp(bu, fu, s)  # (n, seq, dim)

        be = pooled_prompt_embeds[:1]
        fe = pooled_prompt_embeds[1:]
        pooled_prompt_embeds = torch.lerp(be, fe, s[..., 0])  # (p, dim)

        if negative_pooled_prompt_embeds is not None:
            bu = negative_pooled_prompt_embeds[:1]
            fu = negative_pooled_prompt_embeds[1:]
            if num_prompts > num_nprompts:
                # # negative prompts = 1; # prompts > 1.
                assert fu.shape[0] == 1 and fe.shape[0] == num_prompts
                fu = fu.repeat(num_prompts, 1)
            negative_pooled_prompt_embeds = torch.lerp(bu, fu, s[..., 0])  # (n, dim)
    elif negative_prompt_embeds is not None and num_prompts > num_nprompts:
        # # negative prompts = 1; # prompts > 1.
        assert negative_prompt_embeds.shape[0] == 1 and prompt_embeds.shape[0] == num_prompts
        negative_prompt_embeds = negative_prompt_embeds.repeat(num_prompts, 1, 1)

        assert negative_pooled_prompt_embeds.shape[0] == 1 and pooled_prompt_embeds.shape[0] == num_prompts
        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(num_prompts, 1)
    # assert negative_prompt_embeds.shape[0] == prompt_embeds.shape[0] == num_prompts
    if num_masks > num_prompts:
        # `masks` may be a list of PIL images here, so use `len` rather
        # than `.shape`.
        assert len(masks) == num_masks and num_prompts == 1
        prompt_embeds = prompt_embeds.repeat(num_masks, 1, 1)
        if negative_prompt_embeds is not None:
            negative_prompt_embeds = negative_prompt_embeds.repeat(num_masks, 1, 1)

        pooled_prompt_embeds = pooled_prompt_embeds.repeat(num_masks, 1)
        if negative_pooled_prompt_embeds is not None:
            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(num_masks, 1)

    # SD3 pipeline settings.
    if do_classifier_free_guidance:
        # Unconditional embeddings go first; `chunk(2)` in the loop relies
        # on this order.
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
    del negative_prompt_embeds, negative_pooled_prompt_embeds

    prompt_embeds = prompt_embeds.to(self.device)
    pooled_prompt_embeds = pooled_prompt_embeds.to(self.device)

    ### Run

    # Latent initialization.
    num_channels_latents = self.transformer.config.in_channels
    noise = torch.randn((1, num_channels_latents, h, w), dtype=self.dtype, device=self.device)
    if self.timesteps[0] < 999 and has_background:
        latent = self.scheduler_add_noise(bg_latent, noise, 0)
    else:
        # Noise is drawn a second time here; kept as-is so the random
        # stream consumed by this method does not change.
        noise = torch.randn((1, num_channels_latents, h, w), dtype=self.dtype, device=self.device)
        latent = noise

    if has_background:
        # Background latent noised to every step's level, plus the clean
        # background latent as the final entry.
        noise_bg_latents = [
            self.scheduler_add_noise(bg_latent, noise, i) for i in range(len(self.timesteps))
        ] + [bg_latent]

    # Tiling (if needed).
    if height > tile_size or width > tile_size:
        tile = (tile_size + self.vae_scale_factor - 1) // self.vae_scale_factor
        views, tile_masks = get_panorama_views(h, w, tile)
        tile_masks = tile_masks.to(self.device)
    else:
        views = [(0, h, 0, w)]
        tile_masks = latent.new_ones((1, 1, h, w))
    value = torch.zeros_like(latent)
    count_all = torch.zeros_like(latent)

    with torch.autocast('cuda'):
        for i, t in enumerate(tqdm(self.timesteps)):
            if self.pipe.interrupt:
                continue

            fg_mask = fg_masks[:, i]

            value.zero_()
            count_all.zero_()
            for j, (h_start, h_end, w_start, w_end) in enumerate(views):
                fg_mask_ = fg_mask[..., h_start:h_end, w_start:w_end]
                latent_ = latent[..., h_start:h_end, w_start:w_end].repeat(num_masks, 1, 1, 1)

                # Bootstrap for tight background.
                if i < bootstrap_steps:
                    mix_ratio = min(1, max(0, boostrap_mix_steps - i))
                    # Treat the first foreground latent as the background latent if one does not exist.
                    bg_latent_ = noise_bg_latents[i][..., h_start:h_end, w_start:w_end] if has_background else latent_[:1]
                    white_ = white[..., h_start:h_end, w_start:w_end]
                    white_ = self.scheduler_add_noise(white_, noise[..., h_start:h_end, w_start:w_end], i)
                    bg_latent_ = mix_ratio * white_ + (1.0 - mix_ratio) * bg_latent_
                    latent_ = (1.0 - fg_mask_) * bg_latent_ + fg_mask_ * latent_

                    # Centering.
                    latent_ = shift_to_mask_bbox_center(latent_, fg_mask_, reverse=True)

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latent_] * 2) if do_classifier_free_guidance else latent_
                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                timestep = t.expand(latent_model_input.shape[0])

                # Perform one step of the reverse diffusion.
                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    pooled_projections=pooled_prompt_embeds,
                    joint_attention_kwargs=joint_attention_kwargs,
                    return_dict=False,
                )[0]

                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

                if do_classifier_free_guidance and guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_cond, guidance_rescale=guidance_rescale)

                latent_ = self.scheduler_step(noise_pred, i, latent_)

                if i < bootstrap_steps:
                    # Uncentering.
                    latent_ = shift_to_mask_bbox_center(latent_, fg_mask_)

                    # Remove leakage (optional).
                    leak = (latent_ - bg_latent_).pow(2).mean(dim=1, keepdim=True)
                    leak_sigmoid = torch.sigmoid(leak / bootstrap_leak_sensitivity) * 2 - 1
                    fg_mask_ = fg_mask_ * leak_sigmoid

                # Mix the latents.
                fg_mask_ = fg_mask_ * tile_masks[:, j:j + 1, h_start:h_end, w_start:w_end]
                value[..., h_start:h_end, w_start:w_end] += (fg_mask_ * latent_).sum(dim=0, keepdim=True)
                count_all[..., h_start:h_end, w_start:w_end] += fg_mask_.sum(dim=0, keepdim=True)

            # Mask-weighted average over all regions and tiles.
            latent = torch.where(count_all > 0, value / count_all, value)
            bg_mask = (1 - count_all).clip_(0, 1)  # (T, 1, h, w)
            if has_background:
                latent = (1 - bg_mask) * latent + bg_mask * noise_bg_latents[i + 1]  # bg_latent

            # Noise is added after mixing.
            if i < len(self.timesteps) - 1:
                latent = self.scheduler_add_noise(latent, None, i + 1)

    if output_type == "latent":
        # Hand back the raw latent; the pixel-space postprocessing below
        # (clipping, blending, PIL conversion) only applies to decoded
        # images.
        return latent

    latent = (latent / self.vae.config.scaling_factor) + self.vae.config.shift_factor
    image = self.vae.decode(latent, return_dict=False)[0]

    # Return PIL Image.
    image = image[0].clip_(-1, 1) * 0.5 + 0.5
    if has_background and do_blend:
        fg_mask = torch.sum(masks_g, dim=0).clip_(0, 1)
        image = blend(image, background[0], fg_mask)
    else:
        image = T.ToPILImage()(image)
    return image