import torch
from tqdm import tqdm
import utils
from PIL import Image
import gc
import numpy as np
from .attention import GatedSelfAttentionDense
from .models import torch_device
def encode(model_dict, image, generator):
    """Encode an image into scaled VAE latents.

    Args:
        model_dict: Container exposing ``vae`` and ``dtype`` attributes.
        image: A PIL image, a ``uint8`` numpy array with values in 0..255
            laid out as (H, W, C), or a ``torch.Tensor``. For PIL input, width
            and height must be multiples of 8 (the check only runs for PIL).
            A tensor is passed to the VAE as-is; it is assumed to already be
            (B, C, H, W) in [-1, 1] -- TODO confirm with callers.
        generator: Forwarded to ``latent_dist.sample`` for reproducible sampling.

    Returns:
        Latents sampled from the VAE posterior, multiplied by
        ``vae.config.scaling_factor``.
    """
    vae, dtype = model_dict.vae, model_dict.dtype

    if isinstance(image, Image.Image):
        w, h = image.size
        assert w % 8 == 0 and h % 8 == 0, f"h ({h}) and w ({w}) should be a multiple of 8"
        image = np.array(image)

    if isinstance(image, np.ndarray):
        assert image.dtype == np.uint8, f"Should have dtype uint8 (dtype: {image.dtype})"
        # (H, W, C) uint8 in [0, 255] -> (1, C, H, W) float32 in [-1, 1]
        image = image.astype(np.float32) / 255.0
        image = image[None, ...].transpose(0, 3, 1, 2)
        image = 2.0 * image - 1.0
        image = torch.from_numpy(image)

    assert isinstance(image, torch.Tensor), f"type of image: {type(image)}"

    image = image.to(device=torch_device, dtype=dtype)
    # Sampling latents never needs gradients; don't keep the encoder graph alive
    # (consistent with `decode`, which also runs the VAE under no_grad).
    with torch.no_grad():
        latents = vae.encode(image).latent_dist.sample(generator)

    return vae.config.scaling_factor * latents
def decode(vae, latents):
    """Decode scaled VAE latents into uint8 images.

    Args:
        vae: VAE exposing ``decode(...).sample``. Its ``config.scaling_factor``
            undoes the scaling applied by ``encode``; 0.18215 is used when the
            VAE has no such config entry.
        latents: 4-D latent tensor, assumed (B, C, h, w) -- TODO confirm.

    Returns:
        numpy ``uint8`` array of shape (B, H, W, C) with values in [0, 255].
    """
    # Use the same factor `encode` multiplies by, instead of a hard-coded constant.
    scaling_factor = getattr(getattr(vae, "config", None), "scaling_factor", 0.18215)
    with torch.no_grad():
        image = vae.decode(latents / scaling_factor).sample

    # [-1, 1] -> [0, 1]; NCHW -> NHWC on the CPU. Cast to float32 because
    # bfloat16 tensors cannot be converted to numpy (and float16 rounds coarsely).
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
    return (image * 255).round().astype("uint8")
def generate(model_dict, latents, input_embeddings, num_inference_steps, guidance_scale = 7.5, no_set_timesteps=False, scheduler_key='dpm_scheduler'):
    """Run classifier-free-guided denoising with the plain UNet, then decode.

    Args:
        model_dict: Container exposing ``vae`` and ``unet`` and indexable by
            ``scheduler_key`` to obtain the scheduler.
        latents: Initial noisy latents.
        input_embeddings: ``(text_embeddings, uncond_embeddings, cond_embeddings)``.
            Only ``text_embeddings`` is used; it is fed to the UNet together
            with the doubled latent batch, so it presumably stacks the
            unconditional then conditional embeddings along the batch axis
            (matching the ``chunk(2)`` order below) -- confirm with caller.
        num_inference_steps: Passed to ``scheduler.set_timesteps``.
        guidance_scale: Classifier-free guidance weight.
        no_set_timesteps: If True, keep the scheduler's current timesteps.
        scheduler_key: Key of the scheduler inside ``model_dict``.

    Returns:
        ``(latents, images)``: the final latents and the uint8 images from ``decode``.
    """
    vae, unet, scheduler = model_dict.vae, model_dict.unet, model_dict[scheduler_key]
    text_embeddings, _uncond_embeddings, _cond_embeddings = input_embeddings

    if not no_set_timesteps:
        scheduler.set_timesteps(num_inference_steps)

    for t in tqdm(scheduler.timesteps):
        # Duplicate the latents so a single UNet pass yields both the
        # unconditional and the conditional noise prediction.
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)

        # predict the noise residual
        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # perform guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        latents = scheduler.step(noise_pred, t, latents).prev_sample

    images = decode(vae, latents)
    return latents, images
def gligen_enable_fuser(unet, enabled=True):
    """Switch every ``GatedSelfAttentionDense`` layer inside ``unet`` on or off.

    Sets the ``enabled`` attribute of each such submodule to ``enabled``.
    """
    fuser_layers = (m for m in unet.modules() if isinstance(m, GatedSelfAttentionDense))
    for layer in fuser_layers:
        layer.enabled = enabled
def _prepare_gligen_inputs(tokenizer, text_encoder, bboxes, phrases, repeat_batch, dtype, max_objs=30):
    """Build padded GLIGEN (boxes, phrase_embeddings, masks) tensors for CFG.

    Each tensor has a leading dimension of ``repeat_batch``. The first half of
    ``masks`` (the unconditional branch) is zeroed, so grounding only affects
    the conditional half. At most ``max_objs`` boxes/phrases are used.
    """
    assert len(phrases) == len(bboxes)
    n_objs = min(len(bboxes), max_objs)

    boxes = torch.zeros(max_objs, 4, device=torch_device, dtype=dtype)
    # 768 must match the width of the text encoder's pooled output -- presumably
    # the SD 1.x CLIP text encoder; confirm if a different encoder is used.
    phrase_embeddings = torch.zeros(max_objs, 768, device=torch_device, dtype=dtype)
    masks = torch.zeros(max_objs, device=torch_device, dtype=dtype)

    if n_objs > 0:
        boxes[:n_objs] = torch.tensor(bboxes[:n_objs])
        # Only the phrases that fit into `max_objs` slots need to be encoded.
        tokenizer_inputs = tokenizer(phrases[:n_objs], padding=True, return_tensors="pt").to(torch_device)
        with torch.no_grad():
            pooled = text_encoder(**tokenizer_inputs).pooler_output
        phrase_embeddings[:n_objs] = pooled
        masks[:n_objs] = 1

    # Classifier-free guidance: replicate for [uncond, cond]; disable grounding on uncond.
    boxes = boxes.unsqueeze(0).expand(repeat_batch, -1, -1).clone()
    phrase_embeddings = phrase_embeddings.unsqueeze(0).expand(repeat_batch, -1, -1).clone()
    masks = masks.unsqueeze(0).expand(repeat_batch, -1).clone()
    masks[:repeat_batch // 2] = 0
    return boxes, phrase_embeddings, masks


def generate_gligen(model_dict, latents, input_embeddings, num_inference_steps, bboxes, phrases, num_images_per_prompt=1, gligen_scheduled_sampling_beta: float = 0.3, guidance_scale=7.5,
                    frozen_steps=20, frozen_mask=None,
                    return_saved_cross_attn=False, saved_cross_attn_keys=None, return_cond_ca_only=False, return_token_ca_only=None,
                    offload_cross_attn_to_cpu=False, offload_latents_to_cpu=True,
                    return_box_vis=False, show_progress=True, save_all_latents=False, scheduler_key='dpm_scheduler'):
    """Run GLIGEN box-grounded, classifier-free-guided sampling, then decode.

    The `bboxes` should be a list, rather than a list of lists (one box per
    phrase, we can have multiple duplicated phrases).

    Args:
        model_dict: Container exposing ``vae``, ``tokenizer``, ``text_encoder``,
            ``unet`` and ``dtype``, and indexable by ``scheduler_key``.
        latents: Initial latents. A 5-D tensor is treated as a stack of
            per-step latents from an earlier run; sampling starts from its
            first entry, and later entries are used for ``frozen_mask`` blending.
        input_embeddings: ``(text_embeddings, uncond_embeddings, cond_embeddings)``;
            only ``text_embeddings`` is passed to the UNet.
        num_inference_steps: Passed to ``scheduler.set_timesteps``.
        bboxes: One 4-number box per phrase (presumably normalized
            [x0, y0, x1, y1] -- TODO confirm). Only the first 30 are used.
        phrases: Grounding phrases; must have the same length as ``bboxes``.
        num_images_per_prompt: The GLIGEN tensors are built with batch size
            ``2 * num_images_per_prompt``; latents and embeddings presumably
            need to match -- confirm with caller.
        gligen_scheduled_sampling_beta: Fraction of the timesteps during which
            the GLIGEN fuser layers are enabled.
        guidance_scale: Classifier-free guidance weight.
        frozen_steps: Number of initial steps during which ``frozen_mask`` is applied.
        frozen_mask: Optional mask, cast to ``dtype`` and clamped to [0, 1].
            Where it is 1, the new latents are replaced by the reference
            latents for the next step (``latents[index + 1]`` of the 5-D input).
        return_saved_cross_attn, saved_cross_attn_keys, return_cond_ca_only,
        return_token_ca_only, offload_cross_attn_to_cpu: Forwarded to the UNet
            via ``cross_attention_kwargs``; their effect is defined by the
            attention processors, which are not visible here.
        offload_latents_to_cpu: Store saved per-step latents on the CPU.
        return_box_vis: Also return PIL images with the boxes drawn on them.
        show_progress: Show a tqdm progress bar.
        save_all_latents: Also return the stacked latents of every step,
            including the initial one.
        scheduler_key: Key of the scheduler inside ``model_dict``.

    Returns:
        ``(latents, images[, saved_attns][, pil_images][, latents_all])`` --
        optional entries appear in this order when requested.

    Raises:
        AssertionError: if ``phrases`` and ``bboxes`` differ in length.
        ValueError: if ``frozen_mask`` is used without 5-D reference latents.
    """
    vae, tokenizer, text_encoder, unet, scheduler, dtype = (
        model_dict.vae, model_dict.tokenizer, model_dict.text_encoder,
        model_dict.unet, model_dict[scheduler_key], model_dict.dtype,
    )
    text_embeddings, _uncond_embeddings, _cond_embeddings = input_embeddings

    if latents.dim() == 5:
        # latents_all from the input side, different from the latents_all to be saved
        latents_all_input = latents
        latents = latents[0]
    else:
        latents_all_input = None

    # Just in case that we have in-place ops
    latents = latents.clone()

    if save_all_latents:
        # offload to cpu to save space
        latents_all = [latents.cpu() if offload_latents_to_cpu else latents]

    scheduler.set_timesteps(num_inference_steps)
    timesteps = scheduler.timesteps

    if frozen_mask is not None:
        if latents_all_input is None and frozen_steps > 0:
            raise ValueError("frozen_mask requires `latents` to be a 5-D stack of reference latents")
        frozen_mask = frozen_mask.to(device=latents.device, dtype=dtype).clamp(0., 1.)

    # 5.1 Prepare GLIGEN variables (only a single prompt is supported here)
    batch_size = 1
    repeat_batch = batch_size * num_images_per_prompt * 2
    boxes, phrase_embeddings, masks = _prepare_gligen_inputs(
        tokenizer, text_encoder, bboxes, phrases, repeat_batch, dtype
    )

    if return_saved_cross_attn:
        saved_attns = []

    main_cross_attention_kwargs = {
        'offload_cross_attn_to_cpu': offload_cross_attn_to_cpu,
        'return_cond_ca_only': return_cond_ca_only,
        'return_token_ca_only': return_token_ca_only,
        'save_keys': saved_cross_attn_keys,
        'gligen': {
            'boxes': boxes,
            'positive_embeddings': phrase_embeddings,
            'masks': masks,
        },
    }

    # Scheduled sampling: the fuser is active only for the first steps.
    num_grounding_steps = int(gligen_scheduled_sampling_beta * len(timesteps))
    gligen_enable_fuser(unet, True)

    try:
        for index, t in enumerate(tqdm(timesteps, disable=not show_progress)):
            if index == num_grounding_steps:
                gligen_enable_fuser(unet, False)

            # Duplicate the latents so one UNet pass covers both CFG branches.
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)

            # Fresh per-step dict, handed to the attention processors through
            # cross_attention_kwargs (presumably filled with attention maps).
            main_cross_attention_kwargs['save_attn_to_dict'] = {}

            # predict the noise residual
            with torch.no_grad():
                noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings,
                                  cross_attention_kwargs=main_cross_attention_kwargs).sample

            if return_saved_cross_attn:
                saved_attns.append(main_cross_attention_kwargs['save_attn_to_dict'])
                del main_cross_attention_kwargs['save_attn_to_dict']

            # perform guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = scheduler.step(noise_pred, t, latents).prev_sample

            if frozen_mask is not None and index < frozen_steps:
                # The reference stack may live on the CPU; move it next to `latents`.
                reference = latents_all_input[index + 1].to(latents.device)
                latents = reference * frozen_mask + latents * (1. - frozen_mask)

            if save_all_latents:
                latents_all.append(latents.cpu() if offload_latents_to_cpu else latents)
    finally:
        # Turn off fuser for typical SD -- also when sampling is interrupted,
        # since the UNet is shared with plain (non-GLIGEN) generation.
        gligen_enable_fuser(unet, False)

    images = decode(vae, latents)

    ret = [latents, images]
    if return_saved_cross_attn:
        ret.append(saved_attns)
    if return_box_vis:
        pil_images = [utils.draw_box(Image.fromarray(image), bboxes, phrases) for image in images]
        ret.append(pil_images)
    if save_all_latents:
        ret.append(torch.stack(latents_all, dim=0))
    return tuple(ret)