tango / diffusers /examples /community /clip_guided_stable_diffusion.py
deepanway's picture
add required files
6b448ad
raw
history blame contribute delete
No virus
14.6 kB
import inspect
from typing import List, Optional, Union
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from transformers import CLIPImageProcessor, CLIPModel, CLIPTextModel, CLIPTokenizer
from diffusers import (
AutoencoderKL,
DDIMScheduler,
DiffusionPipeline,
DPMSolverMultistepScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
UNet2DConditionModel,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
class MakeCutouts(nn.Module):
    """Cut random square crops out of an image batch and pool each to a fixed size.

    `cut_size` is the output edge length handed to `F.adaptive_avg_pool2d`;
    `cut_power` is the exponent applied to the uniform random draw that
    picks each crop's edge length.
    """

    def __init__(self, cut_size, cut_power=1.0):
        super().__init__()
        self.cut_size = cut_size
        self.cut_power = cut_power

    def _random_cutout(self, pixel_values, height, width, smallest, largest):
        """Return one randomly sized and positioned square crop, pooled to `cut_size`."""
        # Draw order (size, then x offset, then y offset) determines how the
        # global torch RNG is consumed, so it must stay exactly like this.
        fraction = torch.rand([]) ** self.cut_power
        edge = int(fraction * (largest - smallest) + smallest)
        left = torch.randint(0, width - edge + 1, ())
        top = torch.randint(0, height - edge + 1, ())
        crop = pixel_values[:, :, top : top + edge, left : left + edge]
        return F.adaptive_avg_pool2d(crop, self.cut_size)

    def forward(self, pixel_values, num_cutouts):
        """Return `num_cutouts` pooled crops concatenated along dim 0.

        `pixel_values` is indexed as a 4-d tensor whose dims 2 and 3 are the
        spatial height and width. The result is ordered cutout-major: all
        samples of the first cutout, then all samples of the second, and so on.
        """
        height, width = pixel_values.shape[2:4]
        largest = min(width, height)
        # Never ask for a crop larger than the image, even if cut_size is.
        smallest = min(largest, self.cut_size)
        pooled = [
            self._random_cutout(pixel_values, height, width, smallest, largest)
            for _ in range(num_cutouts)
        ]
        return torch.cat(pooled)
def spherical_dist_loss(x, y):
    """Return ``2 * arcsin(||x_hat - y_hat|| / 2) ** 2`` along the last dim.

    Both inputs are L2-normalised over their last dimension first, so only
    their directions matter. Leading dimensions follow normal broadcasting.
    """
    unit_x = F.normalize(x, dim=-1)
    unit_y = F.normalize(y, dim=-1)
    half_chord = (unit_x - unit_y).norm(dim=-1).div(2)
    squared_angle = half_chord.arcsin().pow(2)
    return squared_angle.mul(2)
def set_requires_grad(model, value):
    """Assign `value` to `requires_grad` on every parameter yielded by `model.parameters()`."""
    parameters = model.parameters()
    for tensor in parameters:
        tensor.requires_grad = value
class CLIPGuidedStableDiffusion(DiffusionPipeline):
    """CLIP guided stable diffusion based on the amazing repo by @crowsonkb and @Jack000
    - https://github.com/Jack000/glid-3-xl
    - https://github.dev/crowsonkb/k-diffusion

    At each denoising step the pipeline decodes an estimate of the image,
    embeds it with `clip_model`, and uses the gradient of the spherical
    distance to the CLIP text embedding to steer the sample (see `cond_fn`).
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        clip_model: CLIPModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler],
        feature_extractor: CLIPImageProcessor,
    ):
        """Register the sub-models and build the CLIP preprocessing helpers.

        The text encoder and the CLIP model are frozen here. The VAE and UNet
        keep whatever `requires_grad` state they arrive with; use
        `freeze_vae` / `freeze_unet` to change it.
        """
        super().__init__()
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            clip_model=clip_model,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
        )

        # Normalisation applied to decoded images before they go into CLIP.
        self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
        # `feature_extractor.size` is either a plain int or a dict keyed by "shortest_edge".
        self.cut_out_size = (
            feature_extractor.size
            if isinstance(feature_extractor.size, int)
            else feature_extractor.size["shortest_edge"]
        )
        self.make_cutouts = MakeCutouts(self.cut_out_size)

        set_requires_grad(self.text_encoder, False)
        set_requires_grad(self.clip_model, False)

    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
        """Configure sliced attention on the UNet.

        Args:
            slice_size: `"auto"` selects half of `unet.config.attention_head_dim`;
                any other value is forwarded to `unet.set_attention_slice` as is.
        """
        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = self.unet.config.attention_head_dim // 2
        self.unet.set_attention_slice(slice_size)

    def disable_attention_slicing(self):
        """Turn sliced attention off by passing `None` as the slice size."""
        self.enable_attention_slicing(None)

    def freeze_vae(self):
        """Set `requires_grad=False` on every VAE parameter."""
        set_requires_grad(self.vae, False)

    def unfreeze_vae(self):
        """Set `requires_grad=True` on every VAE parameter."""
        set_requires_grad(self.vae, True)

    def freeze_unet(self):
        """Set `requires_grad=False` on every UNet parameter."""
        set_requires_grad(self.unet, False)

    def unfreeze_unet(self):
        """Set `requires_grad=True` on every UNet parameter."""
        set_requires_grad(self.unet, True)

    @torch.enable_grad()
    def cond_fn(
        self,
        latents,
        timestep,
        index,
        text_embeddings,
        noise_pred_original,
        text_embeddings_clip,
        clip_guidance_scale,
        num_cutouts,
        use_cutouts=True,
    ):
        """Apply one step of CLIP guidance.

        Runs with gradients enabled (the caller runs under `torch.no_grad`),
        re-predicts the noise for `latents`, decodes an image estimate, and
        differentiates the CLIP spherical-distance loss w.r.t. the latents.

        Args:
            latents: Current latents; detached and re-marked as requiring grad.
            timestep: Current scheduler timestep.
            index: Position of `timestep` in the schedule; only used to look
                up `scheduler.sigmas` for `LMSDiscreteScheduler`.
            text_embeddings: Conditional text-encoder states for the UNet.
            noise_pred_original: Noise prediction computed by the caller.
            text_embeddings_clip: Normalised CLIP text features.
            clip_guidance_scale: Multiplier applied to the loss.
            num_cutouts: Number of random crops when `use_cutouts` is true.
            use_cutouts: Use random cutouts instead of a single resize.

        Returns:
            `(noise_pred, latents)`. For `LMSDiscreteScheduler` the gradient is
            added to the latents and the noise prediction is returned unchanged;
            for the other supported schedulers the gradient is folded into the
            noise prediction instead.

        Raises:
            ValueError: If the scheduler is not one of the supported types.
        """
        latents = latents.detach().requires_grad_()

        latent_model_input = self.scheduler.scale_model_input(latents, timestep)

        # predict the noise residual
        noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample

        if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler, DPMSolverMultistepScheduler)):
            alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
            beta_prod_t = 1 - alpha_prod_t
            # compute predicted original sample from predicted noise also called
            # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
            pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)

            # Blend the x_0 estimate with the current latents, weighted by sqrt(1 - alpha_bar_t).
            fac = torch.sqrt(beta_prod_t)
            sample = pred_original_sample * (fac) + latents * (1 - fac)
        elif isinstance(self.scheduler, LMSDiscreteScheduler):
            sigma = self.scheduler.sigmas[index]
            sample = latents - sigma * noise_pred
        else:
            raise ValueError(f"scheduler type {type(self.scheduler)} not supported")

        sample = 1 / self.vae.config.scaling_factor * sample
        image = self.vae.decode(sample).sample
        image = (image / 2 + 0.5).clamp(0, 1)

        if use_cutouts:
            image = self.make_cutouts(image, num_cutouts)
        else:
            image = transforms.Resize(self.cut_out_size)(image)
        image = self.normalize(image).to(latents.dtype)

        image_embeddings_clip = self.clip_model.get_image_features(image)
        image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)

        if use_cutouts:
            # `make_cutouts` concatenates cutout-major (every sample of cutout 0,
            # then every sample of cutout 1, ...). With more than one sample the
            # per-sample text embeddings must be tiled in that same order, or
            # (num_cutouts * B, D) cannot broadcast against (B, D). A single text
            # row still broadcasts on its own, so it is left untouched.
            text_targets = text_embeddings_clip
            if sample.shape[0] > 1 and text_targets.shape[0] == sample.shape[0]:
                text_targets = text_targets.repeat(num_cutouts, 1)
            dists = spherical_dist_loss(image_embeddings_clip, text_targets)
            dists = dists.view([num_cutouts, sample.shape[0], -1])
            # Average over cutouts, then sum over the batch.
            loss = dists.sum(2).mean(0).sum() * clip_guidance_scale
        else:
            loss = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip).mean() * clip_guidance_scale

        # Negative gradient: the direction that decreases the CLIP distance.
        grads = -torch.autograd.grad(loss, latents)[0]

        if isinstance(self.scheduler, LMSDiscreteScheduler):
            latents = latents.detach() + grads * (sigma**2)
            noise_pred = noise_pred_original
        else:
            noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads
        return noise_pred, latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        clip_guidance_scale: Optional[float] = 100,
        clip_prompt: Optional[Union[str, List[str]]] = None,
        num_cutouts: Optional[int] = 4,
        use_cutouts: Optional[bool] = True,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        """Generate images from `prompt` with optional CLIP guidance.

        Args:
            prompt: One prompt or a list of prompts.
            height: Output height in pixels; must be divisible by 8.
            width: Output width in pixels; must be divisible by 8.
            num_inference_steps: Number of scheduler steps.
            guidance_scale: Classifier-free guidance weight; guidance is only
                performed when this is greater than 1.0.
            num_images_per_prompt: Images generated per prompt.
            eta: Passed to `scheduler.step` only if it accepts an `eta` argument.
            clip_guidance_scale: Weight of the CLIP loss; CLIP guidance is
                skipped when this is not greater than 0.
            clip_prompt: Text used for the CLIP loss; defaults to `prompt`.
            num_cutouts: Number of random cutouts per step.
            use_cutouts: Whether to use cutouts or a single resized image.
            generator: Random generator for the initial latents (and for the
                scheduler step if it accepts one).
            latents: Optional pre-made initial latents of shape
                `(batch * num_images_per_prompt, unet in_channels, height // 8, width // 8)`.
            output_type: `"pil"` converts the result with `numpy_to_pil`;
                anything else leaves it as a NumPy array.
            return_dict: Return a `StableDiffusionPipelineOutput` instead of a tuple.

        Returns:
            `StableDiffusionPipelineOutput(images=..., nsfw_content_detected=None)`,
            or the tuple `(images, None)` when `return_dict` is false.

        Raises:
            ValueError: If `prompt` is neither `str` nor `list`, if `height` or
                `width` is not divisible by 8, or if supplied `latents` have
                an unexpected shape.
        """
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # get prompt text embeddings
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
        # duplicate text embeddings for each generation per prompt
        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

        if clip_guidance_scale > 0:
            if clip_prompt is not None:
                clip_text_input = self.tokenizer(
                    clip_prompt,
                    padding="max_length",
                    max_length=self.tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                ).input_ids.to(self.device)
            else:
                clip_text_input = text_input.input_ids.to(self.device)
            text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
            text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
            # duplicate text embeddings clip for each generation per prompt
            text_embeddings_clip = text_embeddings_clip.repeat_interleave(num_images_per_prompt, dim=0)

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            max_length = text_input.input_ids.shape[-1]
            # One empty prompt per input prompt, so the unconditional half has
            # the same batch size as the conditional half after duplication.
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
            # duplicate unconditional embeddings for each generation per prompt
            uncond_embeddings = uncond_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # get the initial random noise unless the user supplied it

        # Unlike in other pipelines, latents need to be generated in the target device
        # for 1-to-1 results reproducibility with the CompVis implementation.
        # However this currently doesn't work in `mps`.
        # The channel count is read from the UNet config, like the other config
        # lookups in this class, rather than through the module attribute.
        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
        latents_dtype = text_embeddings.dtype
        if latents is None:
            if self.device.type == "mps":
                # randn does not work reproducibly on mps
                latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                    self.device
                )
            else:
                latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
            latents = latents.to(self.device)

        # set timesteps
        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {}
        if accepts_offset:
            extra_set_kwargs["offset"] = 1

        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        # Some schedulers like PNDM have timesteps as arrays
        # It's more optimized to move all timesteps to correct device beforehand
        timesteps_tensor = self.scheduler.timesteps.to(self.device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator

        for i, t in enumerate(self.progress_bar(timesteps_tensor)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform classifier free guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # perform clip guidance
            if clip_guidance_scale > 0:
                # Only the conditional half of the embeddings is used for the guidance pass.
                text_embeddings_for_guidance = (
                    text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings
                )
                noise_pred, latents = self.cond_fn(
                    latents,
                    t,
                    i,
                    text_embeddings_for_guidance,
                    noise_pred,
                    text_embeddings_clip,
                    clip_guidance_scale,
                    num_cutouts,
                    use_cutouts,
                )

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

        # scale and decode the image latents with vae
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents).sample

        # Map from [-1, 1] to [0, 1] and move channels last for NumPy / PIL.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, None)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)