# PartEdit / stable_diffusion_xl_partedit.py
# Based on stable_diffusion_reference.py
# Based on https://github.com/RoyiRa/prompt-to-prompt-with-sdxl
from __future__ import annotations
import abc
import typing
from collections.abc import Iterable
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import einops
import numpy as np
import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL, StableDiffusionXLPipeline, UNet2DConditionModel
from diffusers import __version__ as diffusers_version
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.models.attention_processor import AttnProcessor2_0
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
rescale_noise_cfg,
)
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import (
StableDiffusionXLPipelineOutput,
)
from diffusers.utils import (
USE_PEFT_BACKEND,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from diffusers.utils.import_utils import is_invisible_watermark_available
from packaging import version
from PIL import Image
from safetensors.torch import load_file
from torchvision.transforms import ToPILImage, ToTensor
from torchvision.utils import make_grid
from transformers import CLIPImageProcessor
if is_invisible_watermark_available():
from diffusers.pipelines.stable_diffusion_xl.watermark import (
StableDiffusionXLWatermarker,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
try:
from diffusers import LEditsPPPipelineStableDiffusionXL, EulerDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler
except ImportError as e:
logger.error("DPMSolverMultistepScheduler or LEditsPPPipelineStableDiffusionXL not found. Verified on >= 0.29.1")
from diffusers import DDIMScheduler, EulerDiscreteScheduler
if typing.TYPE_CHECKING:
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from transformers import (
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionModelWithProjection,
)
from diffusers.models.attention import Attention
from diffusers.schedulers import KarrasDiffusionSchedulers
# Original implementation from
# Updated to reflect
class PartEditPipeline(StableDiffusionXLPipeline):
r"""
PartEditPipeline for text-to-image generation using Stable Diffusion XL with the SD1.5 NSFW checker.
This model inherits from [`StableDiffusionXLPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion XL uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
text_encoder_2 ([`CLIPTextModelWithProjection`]):
Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
specifically the
[laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
tokenizer_2 (`CLIPTokenizer`):
Second Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):
Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
`stabilityai/stable-diffusion-xl-base-1.0`.
add_watermarker (`bool`, *optional*):
Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
watermark output images. If not defined, it will default to True if the package is installed, otherwise no
watermarker will be used.
"""
_optional_components = ["feature_extractor", "add_watermarker, safety_checker"]
# Added back from stable_diffusion_reference.py with safety_check to instantiate the NSFW checker from SD1.5
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
text_encoder_2: CLIPTextModelWithProjection,
tokenizer: CLIPTokenizer,
tokenizer_2: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
image_encoder: Optional[CLIPVisionModelWithProjection] = None,
feature_extractor: Optional[CLIPImageProcessor] = None,
force_zeros_for_empty_prompt: bool = True,
add_watermarker: Optional[bool] = None,
safety_checker: Optional[StableDiffusionSafetyChecker] = None,
):
if safety_checker is not None:
assert isinstance(safety_checker, StableDiffusionSafetyChecker), f"Expected safety_checker to be of type StableDiffusionSafetyChecker, got {type(safety_checker)}"
assert feature_extractor is not None, "Feature Extractor must be present to use the NSFW checker"
super().__init__(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
unet=unet,
scheduler=scheduler,
image_encoder=image_encoder,
feature_extractor=feature_extractor,
force_zeros_for_empty_prompt=force_zeros_for_empty_prompt,
add_watermarker=add_watermarker,
)
self.register_modules(
safety_checker=safety_checker,
)
# self.warn_once_callback = True
@staticmethod
def default_pipeline(device, precision=torch.float16, scheduler_type: str = "euler", load_safety: bool = False) -> Tuple[StableDiffusionXLPipeline, PartEditPipeline]:
if scheduler_type.strip().lower() in ["ddim", "editfriendly"]:
scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler", torch_dtype=precision) # Edit Friendly DDPM
elif scheduler_type.strip().lower() in ["leditspp"]:
scheduler = DPMSolverMultistepScheduler.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler", algorithm_type="sde-dpmsolver++", solver_order=2
) # LEdits
else:
scheduler = EulerDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler", torch_dtype=precision)
vae = AutoencoderKL.from_pretrained(
"madebyollin/sdxl-vae-fp16-fix",
torch_dtype=precision,
use_safetensors=True,
resume_download=None,
)
default_pipe = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
device=device,
vae=vae,
resume_download=None,
scheduler=DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler", torch_dtype=precision),
torch_dtype=precision,
)
safety_checker = (
StableDiffusionSafetyChecker.from_pretrained(
"benjamin-paine/stable-diffusion-v1-5", # runwayml/stable-diffusion-v1-5",
device_map=device,
torch_dtype=precision,
subfolder="safety_checker",
)
if load_safety
else None
)
feature_extractor = (
CLIPImageProcessor.from_pretrained(
"benjamin-paine/stable-diffusion-v1-5", # "runwayml/stable-diffusion-v1-5",
subfolder="feature_extractor",
device_map=device,
)
if load_safety
else None
)
pipeline: PartEditPipeline = PartEditPipeline(
vae=vae,
tokenizer=default_pipe.tokenizer,
tokenizer_2=default_pipe.tokenizer_2,
text_encoder=default_pipe.text_encoder,
text_encoder_2=default_pipe.text_encoder_2,
unet=default_pipe.unet,
scheduler=scheduler,
image_encoder=default_pipe.image_encoder,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
return default_pipe.to(device), pipeline.to(device)
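# Illustrative usage sketch (not part of the published API; the device string and
# scheduler choice are placeholders). `default_pipeline` returns a plain SDXL
# pipeline and the PartEdit pipeline, both already moved to `device`:
#
#   sdxl_pipe, partedit_pipe = PartEditPipeline.default_pipeline(
#       device="cuda", precision=torch.float16, scheduler_type="euler", load_safety=False
#   )
#   # partedit_pipe is then called like a regular SDXL pipeline, optionally with
#   # `embedding_opt` / `extra_kwargs` as documented in `__call__` below.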
def check_inputs(
self,
prompt,
prompt_2,
height,
width,
callback_steps,
negative_prompt=None,
negative_prompt_2=None,
prompt_embeds=None,
negative_prompt_embeds=None,
pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None,
ip_adapter_image=None,
ip_adapter_image_embeds=None,
callback_on_step_end_tensor_inputs=None,
# PartEdit stuff
embedding_opt: Optional[torch.FloatTensor] = None,
):
# Check version of diffusers
extra_params = (
{
"ip_adapter_image": ip_adapter_image,
"ip_adapter_image_embeds": ip_adapter_image_embeds,
}
if version.parse(diffusers_version) >= version.parse("0.27.0")
else {}
)
# Use super to check the inputs from the parent class
super(PartEditPipeline, self).check_inputs(
prompt,
prompt_2,
height,
width,
callback_steps,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
**extra_params,
)
# PartEdit checks
if embedding_opt is not None:
assert embedding_opt.ndim == 2, f"Embedding should be of shape (2, features), got {embedding_opt.shape}"
assert embedding_opt.shape[-1] == 2048, f"SDXL Embedding should have 2048 features, got {embedding_opt.shape[1]}"
assert embedding_opt.dtype in [
torch.float32,
torch.float16,
], f"Embedding should be of type fp32/fp16, got {embedding_opt.dtype}"
assert hasattr(self, "controller"), "Controller should be present"
assert hasattr(self.controller, "extra_kwargs"), "Controller should have extra_kwargs"
extra_kwargs: DotDictExtra = self.controller.extra_kwargs
strategy: Binarization = extra_kwargs.th_strategy
assert isinstance(strategy, Binarization), f"Expected strategy to be of type Binarization, got {type(strategy)}"
assert hasattr(extra_kwargs, "pad_strategy"), "Controller should have pad_strategy"
assert isinstance(extra_kwargs.pad_strategy, PaddingStrategy), f"Expected pad_strategy to be of type PaddingStrategy, got {type(self.controller.extra_kwargs.pad_strategy)}"
if strategy in [Binarization.PROVIDED_MASK]:
assert hasattr(extra_kwargs, "mask_edit"), "Mask should be present in extra_kwargs"
def _aggregate_and_get_attention_maps_per_token(self, with_softmax, select: int = 0, res: int = 32):
attention_maps = self.controller.aggregate_attention(
res=res,
from_where=("up", "down", "mid"),
batch_size=self.controller.batch_size,
is_cross=True,
select=select,
)
attention_maps_list = self._get_attention_maps_list(attention_maps=attention_maps, with_softmax=with_softmax)
return attention_maps_list
@staticmethod
def _get_attention_maps_list(attention_maps: torch.Tensor, with_softmax) -> List[torch.Tensor]:
attention_maps *= 100
if with_softmax:
attention_maps = torch.nn.functional.softmax(attention_maps, dim=-1)
attention_maps_list = [attention_maps[:, :, i] for i in range(attention_maps.shape[2])]
return attention_maps_list
@torch.inference_mode() # if this gives problems change back to @torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
prompt_2: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
denoising_end: Optional[float] = None,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
original_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Optional[Tuple[int, int]] = None,
negative_original_size: Optional[Tuple[int, int]] = None,
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
negative_target_size: Optional[Tuple[int, int]] = None,
attn_res=None,
callback_on_step_end: Optional[Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
# PartEdit
embedding_opt: Optional[Union[torch.FloatTensor, str]] = None,
extra_kwargs: Optional[Union[dict, DotDictExtra]] = None, # All params, check DotDictExtra
uncond_embeds: Optional[torch.FloatTensor] = None, # Unconditional embeddings from Null text inversion
latents_list=None,
zs=None,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
The keyword arguments to configure the edit are:
- edit_type (`str`). The edit type to apply. Can be either of `replace`, `refine`, `reweight`.
- n_cross_replace (`int`): Number of diffusion steps in which cross attention should be replaced
- n_self_replace (`int`): Number of diffusion steps in which self attention should be replaced
- local_blend_words (`List[str]`, *optional*, defaults to `None`): Determines which area should be
changed. If None, then the whole image can be changed.
- equalizer_words (`List[str]`, *optional*, defaults to `None`): Required for edit type `reweight`.
Determines which words should be enhanced.
- equalizer_strengths (`List[float]`, *optional*, defaults to `None`): Required for edit type `reweight`.
Determines how much the words in `equalizer_words` should be enhanced.
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
using zero terminal SNR.
PartEdit Parameters:
embedding_opt (`Union[torch.FloatTensor, str]`, *optional*): The embedding to be inserted in the prompt. The embedding
will be inserted as third batch dimension.
extra_kwargs (`dict`, *optional*): A dictionary with extra parameters to be passed to the pipeline.
- Check `pipe.part_edit_available_params()` for the available parameters.
Returns:
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
original_size = original_size or (height, width)
target_size = target_size or (height, width)
# PartEdit setup
extra_kwargs = DotDictExtra() if extra_kwargs is None else DotDictExtra(extra_kwargs)
prompt = prompt + [prompt[0]] if prompt[0] != prompt[-1] else prompt # Add required extra batch if not present
extra_kwargs.batch_indx = len(prompt) - 1 if extra_kwargs.batch_indx == -1 else extra_kwargs.batch_indx
add_extra_step = extra_kwargs.add_extra_step
if attn_res is None:
attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32))
self.attn_res = attn_res
# _prompts = prompt if embedding_opt is None else prompt + [prompt[-1]]
if hasattr(self, "controller"):
self.controller.reset()
self.controller = create_controller(
prompt,
cross_attention_kwargs,
num_inference_steps,
tokenizer=self.tokenizer,
device=self.device,
attn_res=self.attn_res,
extra_kwargs=extra_kwargs,
)
assert self.controller is not None
assert issubclass(type(self.controller), AttentionControl)
self.register_attention_control(
self.controller,
) # add attention controller
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
height,
width,
callback_steps,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# batch_size = batch_size + 1 if embedding_opt is not None else batch_size
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
device=device,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
lora_scale=text_encoder_lora_scale,
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
latents[1] = latents[0]
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Prepare added time ids & embeddings
add_text_embeds = pooled_prompt_embeds
add_time_ids = self._get_add_time_ids(
original_size,
crops_coords_top_left,
target_size,
dtype=prompt_embeds.dtype,
text_encoder_projection_dim=self.text_encoder_2.config.projection_dim, # if none should be changed to enc1
)
if negative_original_size is not None and negative_target_size is not None:
negative_add_time_ids = self._get_add_time_ids(
negative_original_size,
negative_crops_coords_top_left,
negative_target_size,
dtype=prompt_embeds.dtype,
)
else:
negative_add_time_ids = add_time_ids
# PartEdit:
prompt_embeds = self.process_embeddings(embedding_opt, prompt_embeds, self.controller.pad_strategy)
self.prompt_embeds = prompt_embeds
if do_classifier_free_guidance:
_og_prompt_embeds = prompt_embeds.clone()
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
prompt_embeds = prompt_embeds.to(device)
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
# 7.1 Apply denoising_end
if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
discrete_timestep_cutoff = int(round(self.scheduler.config.num_train_timesteps - (denoising_end * self.scheduler.config.num_train_timesteps)))
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
timesteps = timesteps[:num_inference_steps]
# PartEdit
if hasattr(self, "debug_list"): # if its disabled and there was a list
del self.debug_list
if extra_kwargs.debug_vis:
self.debug_list = []
if add_extra_step:
num_inference_steps += 1
timesteps = torch.cat([timesteps[[0]], timesteps], dim=-1)
_latents = latents.clone()
self._num_timesteps = len(timesteps) # Same as in SDXL
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# if i in range(50):
# latents[0] = latents_list[i]
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# NOTE(Alex): Null text inversion usage
if uncond_embeds is not None:
# if callback_on_step_end is not None and self.warn_once_callback:
# self.warn_once_callback = False
# logger.warning("Callback on step end is not supported with Null text inversion - Know what you are doing!")
_indx_to_use = i if i < len(uncond_embeds) else len(uncond_embeds) - 1 # use last if we have extra steps
# _og_prompt_embeds
curr = uncond_embeds[_indx_to_use].to(dtype=prompt_embeds.dtype).to(device).repeat(_og_prompt_embeds.shape[0], 1, 1)
prompt_embeds = torch.cat([curr, _og_prompt_embeds], dim=0) # For now not changing the pooled prompt embeds
# if prompt_embeds.shape != (2, 77, 2048):
# print(f"Prompt Embeds should be of shape (2, 77, 2048), got {prompt_embeds.shape}")
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
added_cond_kwargs=added_cond_kwargs,
).sample
if add_extra_step: # PartEdit
latents = _latents.clone()
add_extra_step = False
progress_bar.update()
self.scheduler._init_step_index(t)
continue # we just wanted the unet, not to do the step
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
# gs = torch.tensor([guidance_scale] * len(noise_pred_uncond),
# device=noise_pred.device, dtype= noise_pred.dtype).view(-1, 1, 1, 1)
# gs[0] = 7.5
# our_gs = torch.FloatTensor([1.0, guidance_scale, 1.0]).view(-1, 1, 1, 1).to(latents.device, dtype=latents.dtype)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1 # synth
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs
)
# inv
# latents = self.scheduler.step(noise_pred, t, latents, variance_noise=zs[i], **extra_step_kwargs)
if extra_kwargs.debug_vis: # PartEdit
# Could be removed, with .prev_sample above
self.debug_list.append(latents.pred_original_sample.cpu())
latents = latents.prev_sample # Needed here because of logging above
# step callback
latents = self.controller.step_callback(latents)
# Note(Alex): Copied from SDXL
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
negative_pooled_prompt_embeds = callback_outputs.pop("negative_pooled_prompt_embeds", negative_pooled_prompt_embeds)
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
if embedding_opt is not None: # PartEdit
us_dx = 0
if i == 0 and us_dx != 0:
print(f'Using latents[{us_dx}] instead of latents[0]')
latents[-1:] = latents[us_dx] # always tie the diff process
# if embedding_opt is not None and callback_on_step_end is not None and \
# callback_on_step_end.reversed_latents is not None:
# latents[-1:] = callback_on_step_end.reversed_latents[i]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
# 8. Post-processing
if output_type == "latent":
image = latents
else:
self.final_map = self.controller.visualize_final_map(False)
# Added to support lower VRAM gpus
self.controller.offload_stores(torch.device("cpu"))
image = self.latent2image(latents, device, output_type, force_upcast=False)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return image
self.grid = self.visualize_maps()
# Disable editing in case of
self.unregister_attention_control()
# Did not add NSFW output as it is not part of XLPipelineOuput
return StableDiffusionXLPipelineOutput(images=image)
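# Minimal call sketch for the PartEdit path (hedged: the prompt, the embedding path
# and the parameter values are placeholders, and `pipe` is assumed to be a
# PartEditPipeline, e.g. from `default_pipeline` above). Whether additional
# cross_attention_kwargs are needed depends on `create_controller`, which is
# defined further down in this file:
#
#   out = pipe(
#       prompt=["a photo of a robot", "a photo of a robot"],  # last batch entry receives the PartEdit token
#       embedding_opt="part_token.safetensors",               # hypothetical trained token embedding
#       extra_kwargs={"th_strategy": "partedit", "pad_strategy": "BG", "omega": 1.5},
#       num_inference_steps=50,
#       guidance_scale=7.5,
#   )
#   out.images[0]  # edited image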
@torch.no_grad()
def latent2image(
self: PartEditPipeline,
latents: torch.Tensor,
device: torch.device,
output_type: str = "pil", # ['latent', 'pt', 'np', 'pil']
force_upcast: bool = False,
) -> Union[torch.Tensor, np.ndarray, Image.Image]:
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = (self.vae.dtype == torch.float16 and self.vae.config.force_upcast) or force_upcast
latents = latents.to(device)
if needs_upcasting:
self.upcast_vae()
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
# cast back to fp16 if needed
if needs_upcasting and not force_upcast:
self.vae.to(dtype=torch.float16)
image, has_nsfw_concept = self.run_safety_checker(image, device, latents.dtype)
if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
if not all(do_denormalize):
logger.warning(
"NSFW detected in the following images: %s",
", ".join([f"image {i + 1}" for i, has_nsfw in enumerate(has_nsfw_concept) if has_nsfw]),
)
if self.watermark is not None:
image = self.watermark.apply_watermark(image)
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
if output_type in ["pt", "latent"]:
image = image.cpu()
latents = latents.cpu()
return image
def run_safety_checker(self, image: Union[np.ndarray, torch.Tensor], device: torch.device, dtype: torch.dtype):
if self.safety_checker is None:
has_nsfw_concept = None
else:
if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
else:
feature_extractor_input = self.image_processor.numpy_to_pil(image)
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))
return image, has_nsfw_concept
def register_attention_control(self, controller):
attn_procs = {}
cross_att_count = 0
self.attn_names = {} # Name => Idx
for name in self.unet.attn_processors:
# Only the placement in the UNet is needed here; the hidden-size / cross-attention-dim
# lookups carried over from the original prompt-to-prompt code were no-op expressions
# and have been dropped.
if name.startswith("mid_block"):
place_in_unet = "mid"
elif name.startswith("up_blocks"):
place_in_unet = "up"
elif name.startswith("down_blocks"):
place_in_unet = "down"
else:
continue
attn_procs[name] = PartEditCrossAttnProcessor(controller=controller, place_in_unet=place_in_unet)
# print(f'{cross_att_count}=>{name}')
cross_att_count += 1
self.unet.set_attn_processor(attn_procs)
controller.num_att_layers = cross_att_count
def unregister_attention_control(self):
# if pytorch >= 2.0
self.unet.set_attn_processor(AttnProcessor2_0())
if hasattr(self, "controller") and self.controller is not None:
if hasattr(self.controller, "last_otsu"):
self.last_otsu_value = self.controller.last_otsu[-1]
del self.controller
# self.controller.allow_edit_control = False
def available_params(self) -> str:
pipeline_params = """
Pipeline Parameters:
embedding_opt (`Union[torch.FloatTensor, str]`, *optional*): The embedding to be inserted in the prompt. The embedding
will be inserted as third batch dimension.
extra_kwargs (`dict`, *optional*): A dictionary with extra parameters to be passed to the pipeline.
- Check `pipe.part_edit_available_params()` for the available parameters.
"""
return pipeline_params + "\n" + self.part_edit_available_params()
def process_embeddings(
self,
embedding_opt: Optional[Union[torch.FloatTensor, str]],
prompt_embeds: torch.FloatTensor,
padd_strategy: PaddingStrategy,
) -> torch.Tensor:
return process_embeddings(embedding_opt, prompt_embeds, padd_strategy)
def part_edit_available_params(self) -> str:
return DotDictExtra().explain()
# def run_sa
def visualize_maps(self, make_grid_kwargs: dict = None):
"""Wrapper function to select correct storage location"""
if not hasattr(self, "controller") or self.controller is None:
return self.grid if hasattr(self, "grid") else None
return self.controller.visualize_maps_agg(
self.controller.use_agg_store,
make_grid_kwargs=make_grid_kwargs,
)
def visualize_map_across_time(self):
"""Wrapper function to visualize the same as above, but as one mask"""
if hasattr(self, "final_map") and self.final_map is not None:
return self.final_map
return self.controller.visualize_final_map(self.controller.use_agg_store)
def process_embeddings(
embedding_opt: Optional[Union[torch.Tensor, str]],
prompt_embeds: torch.Tensor,
padd_strategy: PaddingStrategy,
) -> torch.Tensor:
if embedding_opt is None:
return prompt_embeds
assert isinstance(padd_strategy, PaddingStrategy), f"padd_strategy must be of type PaddingStrategy, got {type(padd_strategy)}"
if isinstance(embedding_opt, str):
embedding_opt = load_file(embedding_opt)["embedding"] if "safetensors" in embedding_opt else torch.load(embedding_opt)
elif isinstance(embedding_opt, list):
e = [load_file(i)["embedding"] if "safetensors" in i else torch.load(i) for i in embedding_opt]
embedding_opt = torch.cat(e, dim=0)
print(f'Embedding Opt shape: {embedding_opt.shape=}')
embedding_opt = embedding_opt.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
if embedding_opt.ndim == 2:
embedding_opt = embedding_opt[None]
num_embeds = embedding_opt.shape[1] # BG + Num of classes
prompt_embeds[-1:, :num_embeds, :] = embedding_opt[:, :num_embeds, :]
if PaddingStrategy.context == padd_strategy:
return prompt_embeds
if not (hasattr(padd_strategy, "norm") and hasattr(padd_strategy, "scale")):
raise ValueError(f"PaddingStrategy with {padd_strategy} not recognized")
_norm, _scale = padd_strategy.norm, padd_strategy.scale
if padd_strategy == PaddingStrategy.BG:
prompt_embeds[-1:, num_embeds:, :] = embedding_opt[:, :1, :]
elif padd_strategy == PaddingStrategy.EOS:
prompt_embeds[-1:, num_embeds:, :] = prompt_embeds[-1:, -1:, :]
elif padd_strategy == PaddingStrategy.ZERO:
prompt_embeds[-1:, num_embeds:, :] = 0.0
elif padd_strategy == PaddingStrategy.SOT_E:
prompt_embeds[-1:, num_embeds:, :] = prompt_embeds[-1:, :1, :]
else:
raise ValueError(f"{padd_strategy} not recognized")
# Not recommended
if _norm:
prompt_embeds[-1:, :, :] = F.normalize(prompt_embeds[-1:, :, :], p=2, dim=-1)
if _scale:
_eps = 1e-8
_min, _max = prompt_embeds[:1].min(), prompt_embeds[:1].max()
if _norm:
prompt_embeds = (prompt_embeds - _min) / (_max - _min + _eps)
else:
_new_min, _new_max = (
prompt_embeds[-1:, num_embeds:, :].min(),
prompt_embeds[-1:, num_embeds:, :].max(),
)
prompt_embeds[-1:, num_embeds:, :] = (prompt_embeds[-1:, num_embeds:, :] - _new_min) / (_new_max - _new_min + _eps)
prompt_embeds[-1:, num_embeds:, :] = prompt_embeds[-1:, num_embeds:, :] * (_max - _min + _eps) + _min
return prompt_embeds
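# Minimal shape sketch for the padding behaviour above (random placeholder tensors,
# only the shapes matter; SDXL prompt embeddings are 77 tokens x 2048 features):
#
#   dummy_prompt = torch.randn(3, 77, 2048)   # last batch entry receives the edit tokens
#   dummy_opt = torch.randn(2, 2048)          # BG token + one part token
#   padded = process_embeddings(dummy_opt, dummy_prompt, PaddingStrategy.BG)
#   # tokens [0:2] of the last batch entry are the PartEdit tokens; with
#   # PaddingStrategy.BG the remaining tokens [2:] are filled with the BG token.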
# Depends on layers used to train with
LAYERS_TO_USE = [
24,
25,
26,
27,
28,
29,
30,
31,
32,
33,
34,
35,
36,
37,
38,
39,
40,
41,
42,
43,
44,
45,
46,
47,
48,
49,
50,
51,
52,
53,
54,
55,
56,
57,
58,
59,
0,
1,
2,
3,
] # noqa: E501
class Binarization(Enum):
"""Controls the binarization of attn maps
in case of use_otsu, lower_binarize and upper_binarize are multipliers of the Otsu threshold
args:
strategy: str: name of the strategy
enabled: bool: if binarization is enabled
lower_binarize: float: lower threshold for binarization
upper_binarize: float: upper threshold for binarization
use_otsu: bool: if otsu is used for binarization
"""
P2P = "p2p", False, 0.5, 0.5, False # Baseline
PROVIDED_MASK = "mask", True, 0.5, 0.5, False
BINARY_0_5 = "binary_0.5", True, 0.5, 0.5, False
BINARY_OTSU = "binary_otsu", True, 1.0, 1.0, True
PARTEDIT = "partedit", True, 0.5, 1.5, True
DISABLED = "disabled", False, 0.5, 0.5, False
def __new__(
cls,
strategy: str,
enabled: bool,
lower_binarize: float,
upper_binarize: float,
use_otsu: bool,
) -> "Binarization":
obj = object.__new__(cls)
obj._value_ = strategy
obj.enabled = enabled
obj.lower_binarize = lower_binarize
obj.upper_binarize = upper_binarize
obj.use_otsu = use_otsu
assert isinstance(obj.enabled, bool), "enabled should be of type bool"
assert isinstance(obj.lower_binarize, float), "lower_binarize should be of type float"
assert isinstance(obj.upper_binarize, float), "upper_binarize should be of type float"
assert isinstance(obj.use_otsu, bool), "use_otsu should be of type bool"
return obj
def __eq__(self, other: Optional[Union[Binarization, str]] = None) -> bool:
if not other:
return False
if isinstance(other, Binarization):
return self.value.lower() == other.value.lower()
if isinstance(other, str):
return self.value.lower() == other.lower()
return False
@staticmethod
def available_strategies() -> List[str]:
return [strategy.name for strategy in Binarization]
def __str__(self) -> str:
return f"Binarization: {self.name} (Enabled: {self.enabled} Lower: {self.lower_binarize} Upper: {self.upper_binarize} Otsu: {self.use_otsu})"
@staticmethod
def from_string(
strategy: str,
enabled: Optional[bool] = None,
lower_binarize: Optional[float] = None,
upper_binarize: Optional[float] = None,
use_otsu: Optional[bool] = None,
) -> Binarization:
strategy = strategy.strip().lower()
for _strategy in Binarization:
if _strategy.name.lower() == strategy:
if enabled is not None:
_strategy.enabled = enabled
if lower_binarize is not None:
_strategy.lower_binarize = lower_binarize
if upper_binarize is not None:
_strategy.upper_binarize = upper_binarize
if use_otsu is not None:
_strategy.use_otsu = use_otsu
return _strategy
raise ValueError(f"binarization_strategy={strategy} not recognized")
class PaddingStrategy(Enum):
# Default
BG = "BG", False, False
# Others added just for experimentation reasons
context = "context", False, False
EOS = "EoS", False, False
ZERO = "zero", False, False
SOT_E = "SoT_E", False, False
def __new__(cls, strategy: str, norm: bool, scale: bool) -> "PaddingStrategy":
obj = object.__new__(cls)
obj._value_ = strategy
obj.norm = norm
obj.scale = scale
return obj
# compare based on value
def __eq__(self, other: Optional[Union[PaddingStrategy, str]] = None) -> bool:
if not other:
return False
if isinstance(other, PaddingStrategy):
return self.value.lower() == other.value.lower()
if isinstance(other, str):
return self.value.lower() == other.lower()
return False
@staticmethod
def available_strategies() -> List[str]:
return [strategy.name for strategy in PaddingStrategy]
def __str__(self) -> str:
return f"PaddStrategy: {self.name} Norm: {self.norm} Scale: {self.scale}"
@staticmethod
def from_string(strategy_str, norm: Optional[bool] = False, scale: Optional[bool] = False) -> "PaddingStrategy":
for strategy in PaddingStrategy:
if strategy.name.lower() == strategy_str.lower():
if norm is not None:
strategy.norm = norm
if scale is not None:
strategy.scale = scale
return strategy
raise ValueError(f"padd_strategy={strategy} not recognized")
class DotDictExtra(dict):
"""
dot.notation access to dictionary attributes
Holds default values for the extra_kwargs
"""
__getattr__ = dict.get
__setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__
_layers_to_use = LAYERS_TO_USE # Training parameter, not exposed directly
_enable_non_agg_storing = False # Useful for visualization but very VRAM heavy! ~35GB without offload 14GB with offload
_cpu_offload = False # Lowers VRAM but Slows down drastically, hidden
_default = {
"th_strategy": Binarization.PARTEDIT,
"pad_strategy": PaddingStrategy.BG,
"omega": 1.5, # values should be between 0.25 and 2.0
"use_agg_store": False,
"edit_mask": None,
"edit_steps": 50, # End at this step
"start_editing_at": 0, # Recommended, but exposed in case of wanting to change
"use_layer_subset_idx": None, # In case we want to use specific layers, NOTE: order not aligned with UNet lаyers
"add_extra_step": False,
"batch_indx": -1, # assume last batch
"blend_layers": None,
"force_cross_attn": False, # Force cross attention to maps
# Optimization stuff
"VRAM_low": True, # Leave on by default, except if causing erros
"grounding": None,
}
_default_explanations = {
"th_strategy": "Binarization strategy for attention maps",
"pad_strategy": "Padding strategy for the added tokens",
"omega": "Omega value for the PartEdit",
"use_agg_store": "If the attention maps should be aggregated",
"add_extra_step": "If extra 0 step should be added to the diffusion process",
"edit_mask": "Mask for the edit when using ProvidedMask strategy",
"edit_steps": "Number of edit steps",
"start_editing_at": "Step at which the edit should start",
"use_layer_subset_idx": "Sublayers to use, recommended 0-8 if really needed to use some",
"VRAM_low": "Recommended to not change",
"force_cross_attn": "Force cross attention to use OPT token maps",
}
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
for key, value in self._default.items():
if key not in self:
self[key] = value
# Extra changes to Binarization, PaddingStrategy
if isinstance(self["th_strategy"], str):
self["th_strategy"] = Binarization.from_string(self["th_strategy"])
if isinstance(self["pad_strategy"], str):
self["pad_strategy"] = PaddingStrategy.from_string(self["pad_strategy"])
self["edit_steps"] = self["edit_steps"] + self["add_extra_step"]
if self.edit_mask is not None:
if isinstance(self.edit_mask, str):
# load with PIL or torch/safetensors
if self.edit_mask.endswith(".safetensors"):
self.edit_mask = load_file(self.edit_mask)["edit_mask"]
elif self.edit_mask.endswith(".pt"):
self.edit_mask = torch.load(self.edit_mask)["edit_mask"]
else:
self.edit_mask = Image.open(self.edit_mask)
if isinstance(self.edit_mask, Image.Image):
self.edit_mask = ToTensor()(self.edit_mask.convert("L"))
elif isinstance(self.edit_mask, np.ndarray):
self.edit_mask = torch.from_numpy(self.edit_mask).unsqueeze(0)
if self.edit_mask.ndim == 2:
self.edit_mask = self.edit_mask[None, None, ...]
elif self.edit_mask.ndim == 3:
self.edit_mask = self.edit_mask[None, ...]
if self.edit_mask.max() > 1.0:
self.edit_mask = self.edit_mask / self.edit_mask.max()
if self.grounding is not None: # same as above, but slightly different function
if isinstance(self.grounding, Image.Image):
self.grounding = ToTensor()(self.grounding.convert("L"))
elif isinstance(self.grounding, np.ndarray):
self.grounding = torch.from_numpy(self.grounding).unsqueeze(0)
if self.grounding.ndim == 2:
self.grounding = self.grounding[None, None, ...]
elif self.grounding.ndim == 3:
self.grounding = self.grounding[None, ...]
if self.grounding.max() > 1.0:
self.grounding = self.grounding / self.grounding.max()
assert isinstance(self.th_strategy, Binarization), "th_strategy should be of type Binarization"
assert isinstance(self.pad_strategy, PaddingStrategy), "pad_strategy should be of type PaddingStrategy"
def th_from_str(self, strategy: str):
return Binarization.from_string(strategy)
@staticmethod
def explain() -> str:
"""Returns a string with all the explanations of the parameters"""
return "\n".join(
[
f"{key}: {DotDictExtra._default_explanations[key]}"
for key in DotDictExtra._default
if DotDictExtra._default_explanations.get(key, "Recommended to not change") != "Recommended to not change"
]
)
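# Minimal sketch: string values are promoted to the corresponding enums at
# construction time, and any key not supplied falls back to the defaults above.
#
#   cfg = DotDictExtra({"th_strategy": "partedit", "pad_strategy": "BG", "omega": 1.0})
#   assert cfg.th_strategy is Binarization.PARTEDIT
#   assert cfg.pad_strategy is PaddingStrategy.BG
#   print(DotDictExtra.explain())  # lists the user-facing parameters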
def pack_interpolate_unpack(att, size, interpolation_mode, unwrap_last_dim=True, rewrap=False):
has_last_dim = att.shape[-1] in [77, 1]
_last_dim = att.shape[-1]
if unwrap_last_dim:
if has_last_dim:
sq = int(att.shape[-2] ** 0.5)
att = att.reshape(att.shape[0], sq, sq, -1).permute(0, 3, 1, 2) # B x H x W x D => B x D x H x W
else:
sq = int(att.shape[-1] ** 0.5)
att = att.reshape(*att.shape[:-1], sq, sq) # B x H x W
att = att.unsqueeze(-3) # add a channel dimension
if att.shape[-2:] != size:
att, ps = einops.pack(att, "* c h w")
att = F.interpolate(
att,
size=size,
mode=interpolation_mode,
)
att = torch.stack(einops.unpack(att, ps, "* c h w"))
if rewrap:
if has_last_dim:
att = att.reshape(att.shape[0], -1, att.shape[-1] * att.shape[-1], _last_dim)
else:
att = att.reshape(att.shape[0], -1, att.shape[-1] * att.shape[-1])
# Returned shapes:
#   rewrap=True:  B x heads x D (B x heads x D x N if has_last_dim)
#   rewrap=False: B x heads x H x W (B x N x heads x H x W if has_last_dim)
return att
@torch.no_grad()
def threshold_otsu(image: torch.Tensor = None, nbins=256, hist=None):
"""Return threshold value based on Otsu's method using PyTorch.
This is a reimplementation from scikit-image
https://github.com/scikit-image/scikit-image/blob/b76ff13478a5123e4d8b422586aaa54c791f2604/skimage/filters/thresholding.py#L336
Args:
image: torch.Tensor
Grayscale input image.
nbins: int
Number of bins used to calculate histogram.
hist: torch.Tensor or tuple
Histogram of the input image. If None, it will be calculated using the input image.
Returns
-------
threshold : float
Upper threshold value. All pixels with an intensity higher than
this value are assumed to be foreground.
"""
if image is not None and image.dim() > 2 and image.shape[-1] in (3, 4):
raise ValueError(f"threshold_otsu is expected to work correctly only for " f"grayscale images; image shape {image.shape} looks like " f"that of an RGB image.")
# Convert nbins to a tensor, on device
nbins = torch.tensor(nbins, device=image.device)
# Check if the image has more than one intensity value; if not, return that value
if image is not None:
first_pixel = image.view(-1)[0]
if torch.all(image == first_pixel):
return first_pixel.item()
counts, bin_centers = _validate_image_histogram(image, hist, nbins)
# class probabilities for all possible thresholds
weight1 = torch.cumsum(counts, dim=0)
weight2 = torch.cumsum(counts.flip(dims=[0]), dim=0).flip(dims=[0])
# class means for all possible thresholds
mean1 = torch.cumsum(counts * bin_centers, dim=0) / weight1
mean2 = (torch.cumsum((counts * bin_centers).flip(dims=[0]), dim=0).flip(dims=[0])) / weight2
# Clip ends to align class 1 and class 2 variables:
# The last value of ``weight1``/``mean1`` should pair with zero values in
# ``weight2``/``mean2``, which do not exist.
variance12 = weight1[:-1] * weight2[1:] * (mean1[:-1] - mean2[1:]) ** 2
idx = torch.argmax(variance12)
threshold = bin_centers[idx]
return threshold.item()
def _validate_image_histogram(image: torch.Tensor, hist, nbins):
"""Helper function to validate and compute histogram if necessary."""
if hist is not None:
if isinstance(hist, tuple) and len(hist) == 2:
counts, bin_centers = hist
if not (isinstance(counts, torch.Tensor) and isinstance(bin_centers, torch.Tensor)):
counts = torch.tensor(counts)
bin_centers = torch.tensor(bin_centers)
else:
counts = torch.tensor(hist)
bin_centers = torch.linspace(0, 1, len(counts))
else:
if image is None:
raise ValueError("Either image or hist must be provided.")
image = image.to(torch.float32)
counts, bin_edges = histogram(image, nbins)
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
return counts, bin_centers
def histogram(xs: torch.Tensor, bins):
# Like torch.histogram, but works with cuda
# https://github.com/pytorch/pytorch/issues/69519#issuecomment-1183866843
min, max = xs.min(), xs.max()
counts = torch.histc(xs, bins, min=min, max=max).to(xs.device)
boundaries = torch.linspace(min, max, bins + 1, device=xs.device)
return counts, boundaries
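# Minimal sketch of the Otsu helper above (synthetic bimodal data, values are
# placeholders): the returned threshold separates the low cluster from the high
# one; `calculate_mask_t_res` later scales it by `omega` for the upper threshold.
#
#   vals = torch.cat([torch.full((512,), 0.1), torch.full((512,), 0.9)]).view(32, 32)
#   th = threshold_otsu(vals)
#   low, high = vals < th, vals >= th  # recovers the two clusters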
# Modification of the original from
# https://github.com/google/prompt-to-prompt/blob/9c472e44aa1b607da59fea94820f7be9480ec545/prompt-to-prompt_stable.ipynb
def aggregate_attention(
attention_store: AttentionStore,
res: int,
batch_size: int,
from_where: List[str],
is_cross: bool,
upsample_everything: Optional[bool] = None,
return_all_layers: bool = False,
use_same_layers_as_train: bool = False,
train_layers: Optional[list[int]] = None,
use_layer_subset_idx: Optional[list[int]] = None,
use_step_store: bool = False,
):
out = []
attention_maps = attention_store.get_average_attention(use_step_store)
num_pixels = res**2
for location in from_where:
for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
if upsample_everything or (use_same_layers_as_train and is_cross):
item = pack_interpolate_unpack(item, (res, res), "bilinear", rewrap=True)
if item.shape[-2] == num_pixels:
cross_maps = item.reshape(batch_size, -1, res, res, item.shape[-1])[None]
out.append(cross_maps)
_dim = 0
if is_cross and use_same_layers_as_train and train_layers is not None:
out = [out[i] for i in train_layers]
if use_layer_subset_idx is not None: # after correct ordering
out = [out[i] for i in use_layer_subset_idx]
out = torch.cat(out, dim=_dim)
if return_all_layers:
return out
else:
out = out.sum(_dim) / out.shape[_dim]
return out
def min_max_norm(a, _min=None, _max=None, eps=1e-6):
_max = a.max() if _max is None else _max
_min = a.min() if _min is None else _min
return (a - _min) / (_max - _min + eps)
# Copied from https://github.com/RoyiRa/prompt-to-prompt-with-sdxl/blob/e579861f06962b697b37f3c6dd4813c2acdd55bd/processors.py#L209
class LocalBlend:
def __call__(self, x_t, attention_store):
# note that this code works on the latent level!
k = 1
# maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3]
# These are the numbers because we want to take layers that are 256 x 256, I think this can be changed to something smarter...
# like, get all attentions where the second dim is self.attn_res[0] * self.attn_res[1] in up and down cross.
# NOTE(Alex): This would require activating saving of the attention maps (change in DotDictExtra _enable_non_agg_storing)
# NOTE(Alex): Alternative is to use aggregate masks like in other examples
maps = [m for m in attention_store["down_cross"] + attention_store["mid_cross"] + attention_store["up_cross"] if m.shape[1] == self.attn_res[0] * self.attn_res[1]]
maps = [
item.reshape(
self.alpha_layers.shape[0],
-1,
1,
self.attn_res[0],
self.attn_res[1],
self.max_num_words,
)
for item in maps
]
maps = torch.cat(maps, dim=1)
maps = (maps * self.alpha_layers).sum(-1).mean(1)
# since alpha_layers is all 0s except where we edit, the product zeroes out all but what we change.
# Then, the sum adds the values of the original and what we edit.
# Then, we average across dim=1, which is the number of layers.
mask = F.max_pool2d(maps, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
mask = F.interpolate(mask, size=(x_t.shape[2:]))
mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
mask = mask.gt(self.threshold)
mask = mask[:1] + mask[1:]
mask = mask.to(torch.float16)
if mask.shape[0] < x_t.shape[0]: # PartEdit
# concat last mask again
mask = torch.cat([mask, mask[-1:]], dim=0)
# ## NOTE(Alex): this is local blending with the mask
# assert isinstance(attention_store, AttentionStore), "AttentionStore expected"
# cur_res = x_t.shape[-1]
# if attention_store.th_strategy == Binarization.PROVIDED_MASK:
# mask = attention_store.edit_mask.to(x_t.device)
# # resize to res
# mask = F.interpolate(
# mask, (cur_res, cur_res), mode="bilinear"
# ) # ).reshape(1, -1, 1)
# else:
# mask = attention_store.get_maps_agg(
# res=cur_res,
# device=x_t.device,
# use_agg_store=attention_store.use_agg_store, # Agg is across time, Step is last step without time agg
# keepshape=True
# ) # provide in cross_attention_kwargs in pipeline
# x_t[1:] = mask * x_t[1:] + (1 - mask) * x_t[0]
# ## END NOTE(Alex): this is local blending with the mask
x_t = x_t[:1] + mask * (x_t - x_t[:1])
# The code applies a mask to the image difference between the original and each generated image, effectively retaining only the desired cells.
return x_t
# NOTE(Alex): Copied over for LocalBlend
def __init__(
self,
prompts: List[str],
words: List[List[str]],
tokenizer,
device,
threshold=0.3,
attn_res=None,
):
self.max_num_words = 77
self.attn_res = attn_res
alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, self.max_num_words)
for i, (prompt, words_) in enumerate(zip(prompts, words)):
if isinstance(words_, str):
words_ = [words_]
for word in words_:
ind = get_word_inds(prompt, word, tokenizer)
alpha_layers[i, :, :, :, :, ind] = 1
self.alpha_layers = alpha_layers.to(device) # a one-hot vector where the 1s are the words we modify (source and target)
self.threshold = threshold
# Copied from https://github.com/RoyiRa/prompt-to-prompt-with-sdxl/blob/e579861f06962b697b37f3c6dd4813c2acdd55bd/processors.py#L129
class AttentionControl(abc.ABC):
def step_callback(self, x_t):
return x_t
def between_steps(self):
return
@property
def num_uncond_att_layers(self):
return 0
@abc.abstractmethod
def forward(self, attn, is_cross: bool, place_in_unet: str, store: bool = True):
raise NotImplementedError
def __call__(self, attn, is_cross: bool, place_in_unet: str, store: bool = True):
if self.cur_att_layer >= self.num_uncond_att_layers:
h = attn.shape[0]
attn[h // 2 :] = self.forward(attn[h // 2 :], is_cross, place_in_unet, store)
self.cur_att_layer += 1
if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
self.cur_att_layer = 0
self.cur_step += 1
self.between_steps()
return attn
def reset(self):
self.cur_step = 0
self.cur_att_layer = 0
self.allow_edit_control = True
def __init__(self, attn_res=None, extra_kwargs: DotDictExtra = None):
# PartEdit
self.extra_kwargs = extra_kwargs
self.index_inside_batch = extra_kwargs.get("index_inside_batch", 1) # Default is one in our prior setting!
if not isinstance(self.index_inside_batch, list):
self.index_inside_batch = [self.index_inside_batch]
self.layers_to_use = extra_kwargs.get("_layers_to_use", LAYERS_TO_USE) # Training parameter, not exposed directly
# Params
self.th_strategy: Binarization = extra_kwargs.get("th_strategy", Binarization.P2P)
self.pad_strategy: PaddingStrategy = extra_kwargs.get("pad_strategy", PaddingStrategy.BG)
self.omega: float = extra_kwargs.get("omega", 1.0)
self.use_agg_store: bool = extra_kwargs.get("use_agg_store", False)
self.edit_mask: Optional[torch.Tensor] = extra_kwargs.get("edit_mask", None) # edit_mask_t
self.edit_steps: int = extra_kwargs.get("edit_steps", 50) # NOTE(Alex): This is the end step, IMPORTANT
self.blend_layers: Optional[List] = None
self.start_editing_at: int = extra_kwargs.get("start_editing_at", 0)
self.use_layer_subset_idx: Optional[list[int]] = extra_kwargs.get("use_layer_subset_idx", None)
self.batch_indx: int = extra_kwargs.get("batch_indx", 0)
self.VRAM_low: bool = extra_kwargs.get("VRAM_low", False)
self.allow_edit_control = True
# Old
self.cur_step: int = 0
self.num_att_layers: int = -1
self.cur_att_layer: int = 0
self.attn_res: int = attn_res
def get_maps_agg(self, resized_res, device):
return None
def _editing_allowed(self):
return self.allow_edit_control # TODO(Alex): Maybe make this only param, instead of unregister attn control?
# Copied from https://github.com/RoyiRa/prompt-to-prompt-with-sdxl/blob/e579861f06962b697b37f3c6dd4813c2acdd55bd/processors.py#L166
class EmptyControl(AttentionControl):
def forward(self, attn, is_cross: bool, place_in_unet: str, store: bool = True):
return attn
# Modified from https://github.com/RoyiRa/prompt-to-prompt-with-sdxl/blob/e579861f06962b697b37f3c6dd4813c2acdd55bd/processors.py#L171
class AttentionStore(AttentionControl):
@staticmethod
def get_empty_store():
return {
"down_cross": [],
"mid_cross": [],
"up_cross": [],
"down_self": [],
"mid_self": [],
"up_self": [],
"opt_cross": [],
"opt_bg_cross": [],
}
def maybe_offload(self, attn_device, attn_dtype):
if self.extra_kwargs.get("_cpu_offload", False):
attn_device, attn_dtype = torch.device("cpu"), torch.float32
return attn_device, attn_dtype
def forward(self, attn, is_cross: bool, place_in_unet: str, store: bool = True):
key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
_device, _dtype = self.maybe_offload(attn.device, attn.dtype)
if store and self.batch_indx is not None and is_cross:
# We always store for our method
_dim = attn.shape[0] // self.num_prompt
_val = attn[_dim * self.batch_indx : _dim * (self.batch_indx + 1), ..., self.index_inside_batch].sum(0, keepdim=True).to(_device, _dtype)
if _val.shape[-1] != 1:
# min_max each -1 separately
_max = _val.max()
for i in range(_val.shape[-1]):
_val[..., i] = min_max_norm(_val[..., i], _max=_max)
_val = _val.sum(-1, keepdim=True)
self.step_store["opt_cross"].append(_val)
if self.extra_kwargs.get("_enable_non_agg_storing", False) and store:
_attn = attn.clone().detach().to(_device, _dtype, non_blocking=True)
if attn.shape[1] <= 32**2: # avoid memory overhead
self.step_store[key].append(_attn)
return attn
def offload_stores(self, device):
"""Created for low VRAM usage, where we want to do this before Decoder"""
for key in self.step_store:
self.step_store[key] = [a.to(device) for a in self.step_store[key]]
for key in self.attention_store:
self.attention_store[key] = [a.to(device) for a in self.attention_store[key]]
torch.cuda.empty_cache()
@torch.no_grad()
def calculate_mask_t_res(self, use_step_store: bool = False):
mask_t_res = aggregate_attention(
self,
res=1024,
from_where=["opt"],
batch_size=1,
is_cross=True,
upsample_everything=False,
return_all_layers=False, # Removed sum in this function
use_same_layers_as_train=True,
train_layers=self.layers_to_use,
use_step_store=use_step_store,
use_layer_subset_idx=self.use_layer_subset_idx,
)[..., 0]
strategy: Binarization = self.th_strategy
mask_t_res = min_max_norm(mask_t_res)
upper_threshold = strategy.upper_binarize
lower_threshold = strategy.lower_binarize
use_otsu = strategy.use_otsu
tt = threshold_otsu(mask_t_res) # NOTE(Alex): Computed outside the branch so the inversion low-confidence-region copy can reuse it
if not hasattr(self, "last_otsu") or self.last_otsu == []:
self.last_otsu = [tt]
else:
self.last_otsu.append(tt)
if use_otsu:
upper_threshold, lower_threshold = (
tt * upper_threshold,
tt * lower_threshold,
)
if strategy == Binarization.PARTEDIT:
upper_threshold = self.omega * tt # Assuming we are not changing the upper threshold in PartEdit
if strategy in [Binarization.P2P, Binarization.PROVIDED_MASK]:
return mask_t_res
mask_t_res[mask_t_res < lower_threshold] = 0
mask_t_res[mask_t_res >= upper_threshold] = 1.0
return mask_t_res
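# Comment-only sketch of the thresholding above (values are hypothetical): for a map min-max
# normalized to [0, 1] with an Otsu threshold tt = 0.4 and omega = 1.0:
#   Binarization.P2P / PROVIDED_MASK : return the soft, normalized map unchanged.
#   Binarization.PARTEDIT            : values >= omega * tt (= 0.4) become 1, values below the
#                                      lower threshold become 0, and the band in between stays
#                                      soft (the low-confidence region).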
def has_maps(self) -> bool:
return len(self.mask_storage_step) > 0 or len(self.mask_storage_agg) > 0
def _store_agg_map(self) -> None:
if self.use_agg_store:
self.mask_storage_agg[self.cur_step] = self.calculate_mask_t_res().cpu()
else:
self.mask_storage_step[self.cur_step] = self.calculate_mask_t_res(True).cpu()
def between_steps(self):
no_items = len(self.attention_store) == 0
if no_items:
self.attention_store = self.step_store
else:
for key in self.attention_store:
for i in range(len(self.attention_store[key])):
self.attention_store[key][i] += self.step_store[key][i]
self._store_agg_map()
if not no_items:
# Only clear when the store was accumulated; if no_items, attention_store aliases step_store, so clearing would wipe it too
for key in self.step_store:
# Clear the list while maintaining the dictionary structure
del self.step_store[key][:]
self.step_store = self.get_empty_store()
def get_maps_agg(self, res, device, use_agg_store: bool = None, keepshape: bool = False):
if use_agg_store is None:
use_agg_store = self.use_agg_store
_store = self.mask_storage_agg if use_agg_store else self.mask_storage_step
last_idx = sorted(_store.keys())[-1]
mask_t_res = _store[last_idx].to(device) # Should be 1 1 H W
mask_t_res = F.interpolate(mask_t_res, (res, res), mode="bilinear")
if not keepshape:
mask_t_res = mask_t_res.reshape(1, -1, 1)
return mask_t_res
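# Shape sketch for get_maps_agg (assuming a square feature map of side `res`): the stored map is
# 1 x 1 x H x W, is bilinearly resized to 1 x 1 x res x res, and unless keepshape=True is
# flattened to (1, res*res, 1) so it broadcasts against hidden_states of shape (B, res*res, C).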
def visualize_maps_agg(self, use_agg_store: bool, make_grid_kwargs: dict = None):
_store = self.mask_storage_agg if use_agg_store else self.mask_storage_step
if make_grid_kwargs is None:
make_grid_kwargs = {"nrow": 10}
return ToPILImage()(make_grid(torch.cat(list(_store.values())), **make_grid_kwargs))
def visualize_one_map(self, use_agg_store: bool, idx: int):
_store = self.mask_storage_agg if use_agg_store else self.mask_storage_step
return ToPILImage()(_store[idx])
def visualize_final_map(self, use_agg_store: bool):
"""This method returns the agg non-binarized attn map of the whole process
Args:
use_agg_store (bool): If True, it will return the agg store, otherwise the step store
Returns:
[PIL.Image]: The non-binarized attention map
"""
_store = self.mask_storage_agg if use_agg_store else self.mask_storage_step
return ToPILImage()(torch.cat(list(_store.values())).mean(0))
def get_average_attention(self, step: bool = False):
_store = self.attention_store if not step else self.step_store
average_attention = {key: [item / self.cur_step for item in _store[key]] for key in _store}
return average_attention
def reset(self):
super(AttentionStore, self).reset()
for key in self.step_store:
del self.step_store[key][:]
for key in self.attention_store:
del self.attention_store[key][:]
self.step_store = self.get_empty_store()
self.attention_store = {}
self.last_otsu = []
def __init__(
self,
num_prompt: int,
attn_res=None,
extra_kwargs: DotDictExtra = None,
):
super(AttentionStore, self).__init__(attn_res, extra_kwargs)
self.num_prompt = num_prompt
self.mask_storage_step = {}
self.mask_storage_agg = {}
if self.batch_indx is not None:
assert num_prompt > 0, "num_prompt must be greater than 0 if batch_indx is not None"
self.step_store = self.get_empty_store()
self.attention_store = {}
self.last_otsu = []
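# Storage layout (descriptive note): step_store collects per-layer maps for the current denoising
# step, attention_store accumulates them across steps (summed in between_steps, divided by
# cur_step in get_average_attention), and mask_storage_step / mask_storage_agg cache the derived
# PartEdit masks keyed by step index.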
# Copied from https://github.com/RoyiRa/prompt-to-prompt-with-sdxl/blob/e579861f06962b697b37f3c6dd4813c2acdd55bd/processors.py#L246
class AttentionControlEdit(AttentionStore, abc.ABC):
def step_callback(self, x_t):
if self.local_blend is not None:
# x_t = self.local_blend(x_t, self.attention_store) # TODO: Check if there is a more memory-efficient way
x_t = self.local_blend(x_t, self)
return x_t
def replace_self_attention(self, attn_base, att_replace):
if att_replace.shape[2] <= self.attn_res[0] ** 2:
return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
else:
return att_replace
@abc.abstractmethod
def replace_cross_attention(self, attn_base, att_replace):
raise NotImplementedError
def forward(self, attn, is_cross: bool, place_in_unet: str, store: bool = True):
super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet, store)
if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
h = attn.shape[0] // (self.batch_size)
try:
attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
except RuntimeError as e:
logger.error(f"Batch size: {self.batch_size}, h: {h}, attn.shape: {attn.shape}")
raise e
attn_base, attn_replace = attn[0], attn[1:]
if is_cross:
alpha_words = self.cross_replace_alpha[self.cur_step].to(attn_base.device)
attn_replace_new = self.replace_cross_attention(attn_base, attn_replace) * alpha_words + (1 - alpha_words) * attn_replace
attn[1:] = attn_replace_new
if self.has_maps() and self.extra_kwargs.get("force_cross_attn", False): # and self.cur_step <= 51:
mask_t_res = self.get_maps_agg(
res=int(attn_base.shape[1] ** 0.5),
device=attn_base.device,
use_agg_store=self.use_agg_store, # Agg is across time, Step is last step without time agg
keepshape=False,
).repeat(h, 1, 1)
zero_index = torch.argmax(torch.eq(self.cross_replace_alpha[0], 0).to(mask_t_res.dtype)).item()
# zero_index = torch.eq(self.cross_replace_alpha[0].flatten(), 0)
mean_curr = attn[1:2, ..., zero_index].mean()
ratio_to_mean = mean_curr / mask_t_res[..., 0].mean()
# print(f'{ratio_to_mean=}')
extra_mask = torch.where(mask_t_res[..., 0] > self.last_otsu[-1], ratio_to_mean * 2, 0.5)
attn[1:2, ..., zero_index : zero_index + 1] += mask_t_res[None] * extra_mask[None, ..., None] # * ratio_to_mean # * 2
# attn[1:2, ..., zero_index] = (mask_t_res[..., 0][None] > self.last_otsu[-1] * 1.5).to(mask_t_res.dtype) * mean_curr
else:
attn[1:] = self.replace_self_attention(attn_base, attn_replace)
attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
return attn
def __init__(
self,
prompts: list[str],
num_steps: int,
cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
self_replace_steps: Union[float, Tuple[float, float]],
local_blend: Optional[LocalBlend],
tokenizer,
device: torch.device,
attn_res=None,
extra_kwargs: DotDictExtra = None,
):
super(AttentionControlEdit, self).__init__(
attn_res=attn_res,
num_prompt=len(prompts),
extra_kwargs=extra_kwargs,
)
# add tokenizer and device here
self.tokenizer = tokenizer
self.device = device
self.batch_size = len(prompts)
self.cross_replace_alpha = get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps, self.tokenizer).to(self.device)
if isinstance(self_replace_steps, float):
self_replace_steps = 0, self_replace_steps
self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
self.local_blend = local_blend
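# Example of the step bookkeeping above (hypothetical values): with num_steps = 50 and
# self_replace_steps = 0.4, num_self_replace becomes (0, 20), so the edited branch's
# self-attention is overwritten by the base branch only for the first 20 steps; a tuple such as
# (0.1, 0.5) maps to (5, 25) instead.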
# Copied from https://github.com/RoyiRa/prompt-to-prompt-with-sdxl/blob/e579861f06962b697b37f3c6dd4813c2acdd55bd/processors.py#L307
class AttentionReplace(AttentionControlEdit):
def replace_cross_attention(self, attn_base, att_replace):
return torch.einsum("hpw,bwn->bhpn", attn_base, self.mapper.to(attn_base.device))
def __init__(
self,
prompts,
num_steps: int,
cross_replace_steps: float,
self_replace_steps: float,
local_blend: Optional[LocalBlend] = None,
tokenizer=None,
device=None,
attn_res=None,
extra_kwargs: DotDictExtra = None,
):
super(AttentionReplace, self).__init__(
prompts,
num_steps,
cross_replace_steps,
self_replace_steps,
local_blend,
tokenizer,
device,
attn_res,
extra_kwargs,
)
self.mapper = get_replacement_mapper(prompts, self.tokenizer).to(self.device)
# Copied from https://github.com/RoyiRa/prompt-to-prompt-with-sdxl/blob/e579861f06962b697b37f3c6dd4813c2acdd55bd/processors.py#L328
class AttentionRefine(AttentionControlEdit):
def replace_cross_attention(self, attn_base, att_replace):
attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3)
attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas)
return attn_replace
def __init__(
self,
prompts,
num_steps: int,
cross_replace_steps: float,
self_replace_steps: float,
local_blend: Optional[LocalBlend] = None,
tokenizer=None,
device=None,
attn_res=None,
extra_kwargs: DotDictExtra = None,
):
super(AttentionRefine, self).__init__(
prompts,
num_steps,
cross_replace_steps,
self_replace_steps,
local_blend,
tokenizer,
device,
attn_res,
extra_kwargs,
)
self.mapper, alphas = get_refinement_mapper(prompts, self.tokenizer)
self.mapper, alphas = self.mapper.to(self.device), alphas.to(self.device)
self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])
# Copied from https://github.com/RoyiRa/prompt-to-prompt-with-sdxl/blob/e579861f06962b697b37f3c6dd4813c2acdd55bd/processors.py#L353
class AttentionReweight(AttentionControlEdit):
def replace_cross_attention(self, attn_base: torch.Tensor, att_replace: torch.Tensor):
if self.prev_controller is not None:
attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace)
attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :]
return attn_replace
def __init__(
self,
prompts: list[str],
num_steps: int,
cross_replace_steps: float,
self_replace_steps: float,
equalizer,
local_blend: Optional[LocalBlend] = None,
controller: Optional[AttentionControlEdit] = None,
tokenizer=None,
device=None,
attn_res=None,
extra_kwargs: DotDictExtra = None,
):
super(AttentionReweight, self).__init__(
prompts,
num_steps,
cross_replace_steps,
self_replace_steps,
local_blend,
tokenizer,
device,
attn_res,
extra_kwargs,
)
self.equalizer = equalizer.to(self.device)
self.prev_controller = controller
class PartEditCrossAttnProcessor:
# Modified from https://github.com/RoyiRa/prompt-to-prompt-with-sdxl/blob/e579861f06962b697b37f3c6dd4813c2acdd55bd/processors.py#L11
def __init__(
self,
controller: AttentionStore,
place_in_unet,
store_this_layer: bool = True,
):
super().__init__()
self.controller = controller
assert issubclass(type(controller), AttentionControl), f"{controller} isn't subclass of AttentionControl"
self.place_in_unet = place_in_unet
self.store_this_layer = store_this_layer
def has_maps(self) -> bool:
return len(self.controller.mask_storage_step) > 0 or len(self.controller.mask_storage_agg) > 0 or self.controller.edit_mask is not None
def condition_for_editing(self) -> bool:
# Editing is enabled when the thresholding strategy is active (e.g. a provided mask or PartEdit binarization)
return self.controller.th_strategy.enabled
def __call__(
self,
attn: Attention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
query = attn.to_q(hidden_states)
is_cross = encoder_hidden_states is not None
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
# initial_condition = hasattr(self, "controller") and hasattr(self.controller, "batch_indx") and batch_size > self.controller.batch_size
if hasattr(self, "controller") and self.controller._editing_allowed() and self.controller.batch_indx > 0:
# Copy the base image's (index 0) queries into the edit branch, for both the unconditional and conditional halves
batch_indx = self.controller.batch_indx
_bs = self.controller.batch_size
query[[batch_indx, batch_indx + _bs]] = query[[0, _bs]]
# value[[batch_indx, batch_indx+_bs]] = value[[0, _bs]]
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
self.controller(attention_probs, is_cross, self.place_in_unet, self.store_this_layer)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
res = int(np.sqrt(hidden_states.shape[1]))
should_edit = (
hasattr(self, "controller")
and self.controller._editing_allowed() # allow_edit_control
and self.has_maps()
and self.condition_for_editing()
and self.controller.cur_step > self.controller.start_editing_at
and self.controller.cur_step < self.controller.edit_steps
)
if should_edit:
if self.controller.th_strategy == Binarization.PROVIDED_MASK:
mask_t_res = self.controller.edit_mask.to(hidden_states.device)
# resize to res
mask_t_res = F.interpolate(mask_t_res, (res, res), mode="bilinear").reshape(1, -1, 1)
else:
mask_t_res = self.controller.get_maps_agg(
res=res,
device=hidden_states.device,
use_agg_store=self.controller.use_agg_store, # Agg is across time, Step is last step without time agg
) # provide in cross_attention_kwargs in pipeline
# Note: Additional blending with grounding
_extra_grounding = self.controller.extra_kwargs.get("grounding", None)
if _extra_grounding is not None:
mask_t_res = mask_t_res * F.interpolate(_extra_grounding, (res, res), mode="bilinear").reshape(1, -1, 1).to(hidden_states.device)
# hidden_states_orig = rearrange(hidden_states, "b (h w) c -> b h w c", w=res, h=res)
b1_u = 0
b1_c = self.controller.batch_size
b2_u = 1
b2_c = self.controller.batch_size + 1
hidden_states[b2_u] = (1 - mask_t_res) * hidden_states[b1_u] + mask_t_res * hidden_states[b2_u]
hidden_states[b2_c] = (1 - mask_t_res) * hidden_states[b1_c] + mask_t_res * hidden_states[b2_c]
# hidden_states_after = rearrange(hidden_states, "b (h w) c -> b h w c", w=res, h=res)
return hidden_states
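# Comment-only sketch of the blend above, per spatial location p:
#   h_edit[p] = (1 - M[p]) * h_base[p] + M[p] * h_edit[p]
# applied to the unconditional pair (b1_u -> b2_u) and the conditional pair (b1_c -> b2_c), so
# only locations inside the soft part mask M receive the edited features.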
# Adapted from https://github.com/RoyiRa/prompt-to-prompt-with-sdxl/blob/e579861f06962b697b37f3c6dd4813c2acdd55bd/processors.py#L48
def create_controller(
prompts: List[str],
cross_attention_kwargs: Dict,
num_inference_steps: int,
tokenizer,
device: torch.device,
attn_res: Tuple[int, int],
extra_kwargs: dict,
) -> AttentionControl:
edit_type = cross_attention_kwargs.get("edit_type", "replace")
local_blend_words = cross_attention_kwargs.get("local_blend_words")
equalizer_words = cross_attention_kwargs.get("equalizer_words")
equalizer_strengths = cross_attention_kwargs.get("equalizer_strengths")
n_cross_replace = cross_attention_kwargs.get("n_cross_replace", 0.4)
n_self_replace = cross_attention_kwargs.get("n_self_replace", 0.4)
# only replace
if edit_type == "replace" and local_blend_words is None:
return AttentionReplace(
prompts,
num_inference_steps,
n_cross_replace,
n_self_replace,
tokenizer=tokenizer,
device=device,
attn_res=attn_res,
extra_kwargs=extra_kwargs,
)
# replace + localblend
if edit_type == "replace" and local_blend_words is not None:
lb = LocalBlend(
prompts,
local_blend_words,
tokenizer=tokenizer,
device=device,
attn_res=attn_res,
)
return AttentionReplace(
prompts,
num_inference_steps,
n_cross_replace,
n_self_replace,
lb,
tokenizer=tokenizer,
device=device,
attn_res=attn_res,
extra_kwargs=extra_kwargs,
)
# only refine
if edit_type == "refine" and local_blend_words is None:
return AttentionRefine(
prompts,
num_inference_steps,
n_cross_replace,
n_self_replace,
tokenizer=tokenizer,
device=device,
attn_res=attn_res,
extra_kwargs=extra_kwargs,
)
# refine + localblend
if edit_type == "refine" and local_blend_words is not None:
lb = LocalBlend(
prompts,
local_blend_words,
tokenizer=tokenizer,
device=device,
attn_res=attn_res,
)
return AttentionRefine(
prompts,
num_inference_steps,
n_cross_replace,
n_self_replace,
lb,
tokenizer=tokenizer,
device=device,
attn_res=attn_res,
extra_kwargs=extra_kwargs,
)
# only reweight
if edit_type == "reweight" and local_blend_words is None:
assert equalizer_words is not None and equalizer_strengths is not None, "To use reweight edit, please specify equalizer_words and equalizer_strengths."
assert len(equalizer_words) == len(equalizer_strengths), "equalizer_words and equalizer_strengths must be of same length."
equalizer = get_equalizer(prompts[1], equalizer_words, equalizer_strengths, tokenizer=tokenizer)
return AttentionReweight(
prompts,
num_inference_steps,
n_cross_replace,
n_self_replace,
tokenizer=tokenizer,
device=device,
equalizer=equalizer,
attn_res=attn_res,
extra_kwargs=extra_kwargs,
)
# reweight and localblend
if edit_type == "reweight" and local_blend_words:
assert equalizer_words is not None and equalizer_strengths is not None, "To use reweight edit, please specify equalizer_words and equalizer_strengths."
assert len(equalizer_words) == len(equalizer_strengths), "equalizer_words and equalizer_strengths must be of same length."
equalizer = get_equalizer(prompts[1], equalizer_words, equalizer_strengths, tokenizer=tokenizer)
lb = LocalBlend(
prompts,
local_blend_words,
tokenizer=tokenizer,
device=device,
attn_res=attn_res,
)
return AttentionReweight(
prompts,
num_inference_steps,
n_cross_replace,
n_self_replace,
tokenizer=tokenizer,
device=device,
equalizer=equalizer,
attn_res=attn_res,
local_blend=lb,
extra_kwargs=extra_kwargs,
)
raise ValueError(f"Edit type {edit_type} not recognized. Use one of: replace, refine, reweight.")
# Copied from https://github.com/RoyiRa/prompt-to-prompt-with-sdxl/blob/e579861f06962b697b37f3c6dd4813c2acdd55bd/processors.py#L380-L596
### util functions for all Edits
def update_alpha_time_word(
alpha,
bounds: Union[float, Tuple[float, float]],
prompt_ind: int,
word_inds: Optional[torch.Tensor] = None,
):
if isinstance(bounds, float):
bounds = 0, bounds
start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
if word_inds is None:
word_inds = torch.arange(alpha.shape[2])
alpha[:start, prompt_ind, word_inds] = 0
alpha[start:end, prompt_ind, word_inds] = 1
alpha[end:, prompt_ind, word_inds] = 0
return alpha
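# Worked example (hypothetical): with alpha of shape (51, 1, 77) and bounds = 0.4, bounds becomes
# (0, 0.4), so rows 0..19 are set to 1 and rows 20..50 to 0 for the selected words, i.e. the
# injection is active only during the first ~40% of the schedule.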
def get_time_words_attention_alpha(
prompts,
num_steps,
cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
tokenizer,
max_num_words=77,
):
if not isinstance(cross_replace_steps, dict):
cross_replace_steps = {"default_": cross_replace_steps}
if "default_" not in cross_replace_steps:
cross_replace_steps["default_"] = (0.0, 1.0)
alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
for i in range(len(prompts) - 1):
alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], i)
for key, item in cross_replace_steps.items():
if key != "default_":
inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
for i, ind in enumerate(inds):
if len(ind) > 0:
alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)
return alpha_time_words
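# Example (hypothetical): cross_replace_steps = {"default_": 0.8, "helmet": (0.0, 0.4)} keeps
# injection for most words during the first 80% of steps but restricts the word "helmet" to the
# first 40%; the returned alpha tensor has shape (num_steps + 1, len(prompts) - 1, 1, 1, 77).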
### util functions for LocalBlend and ReplacementEdit
def get_word_inds(text: str, word_place: Union[int, str], tokenizer):
split_text = text.split(" ")
if isinstance(word_place, str):
word_place = [i for i, word in enumerate(split_text) if word_place == word]
elif isinstance(word_place, int):
word_place = [word_place]
out = []
if len(word_place) > 0:
words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
cur_len, ptr = 0, 0
for i in range(len(words_encode)):
cur_len += len(words_encode[i])
if ptr in word_place:
out.append(i + 1)
if cur_len >= len(split_text[ptr]):
ptr += 1
cur_len = 0
return np.array(out)
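# Example (assuming each word tokenizes to a single piece): get_word_inds("a cat sitting", "cat",
# tokenizer) returns np.array([2]) because the BOS token occupies index 0; words split into
# several BPE pieces contribute several indices, which is why an array is returned.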
### util functions for ReplacementEdit
def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77):
words_x = x.split(" ")
words_y = y.split(" ")
if len(words_x) != len(words_y):
raise ValueError(
f"attention replacement edit can only be applied on prompts with the same length" f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words."
)
inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]]
inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace]
inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace]
mapper = np.zeros((max_len, max_len))
i = j = 0
cur_inds = 0
while i < max_len and j < max_len:
if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i:
inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds]
if len(inds_source_) == len(inds_target_):
mapper[inds_source_, inds_target_] = 1
else:
ratio = 1 / len(inds_target_)
for i_t in inds_target_:
mapper[inds_source_, i_t] = ratio
cur_inds += 1
i += len(inds_source_)
j += len(inds_target_)
elif cur_inds < len(inds_source):
mapper[i, j] = 1
i += 1
j += 1
else:
mapper[j, j] = 1
i += 1
j += 1
# return torch.from_numpy(mapper).float()
return torch.from_numpy(mapper).to(torch.float16)
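# Note on the mapper above: when a source word and its replacement tokenize to different numbers
# of pieces, each target column receives weight 1 / len(inds_target_), so the source word's
# attention mass is spread evenly over the replacement's pieces in the (77, 77) matrix.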
def get_replacement_mapper(prompts, tokenizer, max_len=77):
x_seq = prompts[0]
mappers = []
for i in range(1, len(prompts)):
mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len)
mappers.append(mapper)
return torch.stack(mappers)
### util functions for ReweightEdit
def get_equalizer(
text: str,
word_select: Union[int, Tuple[int, ...]],
values: Union[List[float], Tuple[float, ...]],
tokenizer,
):
if isinstance(word_select, (int, str)):
word_select = (word_select,)
equalizer = torch.ones(len(values), 77)
values = torch.tensor(values, dtype=torch.float32)
for i, word in enumerate(word_select):
inds = get_word_inds(text, word, tokenizer)
equalizer[:, inds] = torch.FloatTensor(values[i])
return equalizer
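# Example (hypothetical): get_equalizer("a smiling face", ("smiling",), (2.0,), tokenizer=tok)
# returns a (1, 77) tensor that is 1 everywhere except at the token positions of "smiling",
# where it is 2.0; AttentionReweight multiplies cross-attention columns by this equalizer.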
### util functions for RefinementEdit
class ScoreParams:
def __init__(self, gap, match, mismatch):
self.gap = gap
self.match = match
self.mismatch = mismatch
def mis_match_char(self, x, y):
if x != y:
return self.mismatch
else:
return self.match
def get_matrix(size_x, size_y, gap):
matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
matrix[0, 1:] = (np.arange(size_y) + 1) * gap
matrix[1:, 0] = (np.arange(size_x) + 1) * gap
return matrix
def get_traceback_matrix(size_x, size_y):
matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
matrix[0, 1:] = 1
matrix[1:, 0] = 2
matrix[0, 0] = 4
return matrix
def global_align(x, y, score):
matrix = get_matrix(len(x), len(y), score.gap)
trace_back = get_traceback_matrix(len(x), len(y))
for i in range(1, len(x) + 1):
for j in range(1, len(y) + 1):
left = matrix[i, j - 1] + score.gap
up = matrix[i - 1, j] + score.gap
diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1])
matrix[i, j] = max(left, up, diag)
if matrix[i, j] == left:
trace_back[i, j] = 1
elif matrix[i, j] == up:
trace_back[i, j] = 2
else:
trace_back[i, j] = 3
return matrix, trace_back
def get_aligned_sequences(x, y, trace_back):
x_seq = []
y_seq = []
i = len(x)
j = len(y)
mapper_y_to_x = []
while i > 0 or j > 0:
if trace_back[i, j] == 3:
x_seq.append(x[i - 1])
y_seq.append(y[j - 1])
i = i - 1
j = j - 1
mapper_y_to_x.append((j, i))
elif trace_back[i][j] == 1:
x_seq.append("-")
y_seq.append(y[j - 1])
j = j - 1
mapper_y_to_x.append((j, -1))
elif trace_back[i][j] == 2:
x_seq.append(x[i - 1])
y_seq.append("-")
i = i - 1
elif trace_back[i][j] == 4:
break
mapper_y_to_x.reverse()
return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64)
def get_mapper(x: str, y: str, tokenizer, max_len=77):
x_seq = tokenizer.encode(x)
y_seq = tokenizer.encode(y)
score = ScoreParams(0, 1, -1)
matrix, trace_back = global_align(x_seq, y_seq, score)
mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1]
alphas = torch.ones(max_len)
alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float()
mapper = torch.zeros(max_len, dtype=torch.int64)
mapper[: mapper_base.shape[0]] = mapper_base[:, 1]
mapper[mapper_base.shape[0] :] = len(y_seq) + torch.arange(max_len - len(y_seq))
return mapper, alphas
def get_refinement_mapper(prompts, tokenizer, max_len=77):
x_seq = prompts[0]
mappers, alphas = [], []
for i in range(1, len(prompts)):
mapper, alpha = get_mapper(x_seq, prompts[i], tokenizer, max_len)
mappers.append(mapper)
alphas.append(alpha)
return torch.stack(mappers), torch.stack(alphas)
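# Sketch of the refinement mapping (hypothetical prompts): aligning "a cat" with "a fluffy cat"
# via global_align yields a mapper that points each target token back to its source token (-1
# where a token was inserted) and alphas that are 0 exactly at the inserted positions, so
# AttentionRefine keeps the base attention for shared words and the new prompt's attention for
# inserted words.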