Spaces:
Running
on
Zero
Running
on
Zero
import inspect | |
from typing import List, Optional, Union | |
import numpy as np | |
import PIL.Image | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
from torchvision import transforms | |
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer | |
from diffusers import ( | |
AutoencoderKL, | |
DDIMScheduler, | |
DPMSolverMultistepScheduler, | |
LMSDiscreteScheduler, | |
PNDMScheduler, | |
UNet2DConditionModel, | |
) | |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin | |
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput | |
from diffusers.utils import PIL_INTERPOLATION, deprecate | |
from diffusers.utils.torch_utils import randn_tensor | |
# Usage example for `CLIPGuidedStableDiffusion` (image-to-image with CLIP guidance).
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        from io import BytesIO

        import requests
        import torch
        from diffusers import DiffusionPipeline
        from PIL import Image
        from transformers import CLIPFeatureExtractor, CLIPModel

        feature_extractor = CLIPFeatureExtractor.from_pretrained(
            "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
        )
        clip_model = CLIPModel.from_pretrained(
            "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.float16
        )

        guided_pipeline = DiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            custom_pipeline="clip_guided_stable_diffusion_img2img",
            clip_model=clip_model,
            feature_extractor=feature_extractor,
            torch_dtype=torch.float16,
        )
        guided_pipeline.enable_attention_slicing()
        guided_pipeline = guided_pipeline.to("cuda")

        prompt = "fantasy book cover, full moon, fantasy forest landscape, golden vector elements, fantasy magic, dark light night, intricate, elegant, sharp focus, illustration, highly detailed, digital painting, concept art, matte, art by WLOP and Artgerm and Albert Bierstadt, masterpiece"

        url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        init_image = Image.open(BytesIO(response.content)).convert("RGB")

        image = guided_pipeline(
            prompt=prompt,
            num_inference_steps=30,
            image=init_image,
            strength=0.75,
            guidance_scale=7.5,
            clip_guidance_scale=100,
            num_cutouts=4,
            use_cutouts=False,
        ).images[0]
        image.save("clip_guided_img2img.png")
        ```
"""
def preprocess(image, w, h):
    """Convert the pipeline's `image` argument into a batched image tensor.

    Args:
        image: a `torch.Tensor` (returned unchanged), a single `PIL.Image.Image`,
            or a non-empty list of PIL images or of tensors.
        w: target width in pixels; only applied to PIL inputs.
        h: target height in pixels; only applied to PIL inputs.

    Returns:
        For PIL input, a float32 tensor of shape (batch, channels, h, w) with
        values in [-1, 1]. Tensor inputs are returned as-is; a list of tensors is
        concatenated along dim 0 without resizing or rescaling. Any other list
        is returned unchanged.

    Raises:
        ValueError: if `image` is None or an empty list, instead of the opaque
            TypeError/IndexError that indexing `image[0]` would produce.
    """
    if image is None:
        raise ValueError(
            "`image` must be provided: pass a `torch.Tensor`, a `PIL.Image.Image` or a list of them."
        )
    if isinstance(image, torch.Tensor):
        return image
    if isinstance(image, PIL.Image.Image):
        image = [image]
    if len(image) == 0:
        raise ValueError("`image` must not be an empty list.")

    if isinstance(image[0], PIL.Image.Image):
        # PIL's resize takes (width, height).
        arrays = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
        batch = np.concatenate(arrays, axis=0).astype(np.float32) / 255.0
        # NHWC -> NCHW, then map [0, 1] -> [-1, 1]; the pipeline inverts this with
        # `(x / 2 + 0.5)` after VAE decoding.
        batch = 2.0 * batch.transpose(0, 3, 1, 2) - 1.0
        return torch.from_numpy(batch)
    if isinstance(image[0], torch.Tensor):
        return torch.cat(image, dim=0)
    return image
class MakeCutouts(nn.Module):
    """Take random square crops of an image batch and pool each to `cut_size`.

    Each crop's side length lies between min(height, width, cut_size) and
    min(height, width); `cut_power` skews the distribution of that side length
    (values > 1 favour smaller crops). The output stacks the crops
    cutout-major: all batch items of crop 0, then all of crop 1, and so on.
    """

    def __init__(self, cut_size, cut_power=1.0):
        super().__init__()
        self.cut_size = cut_size
        self.cut_power = cut_power

    def forward(self, pixel_values, num_cutouts):
        height, width = pixel_values.shape[2:4]
        largest = min(width, height)
        smallest = min(width, height, self.cut_size)
        span = largest - smallest

        def one_cutout():
            # The random draws happen in the order side, x offset, y offset.
            fraction = torch.rand([]) ** self.cut_power
            side = int(fraction * span + smallest)
            left = torch.randint(0, width - side + 1, ())
            top = torch.randint(0, height - side + 1, ())
            crop = pixel_values[:, :, top : top + side, left : left + side]
            return F.adaptive_avg_pool2d(crop, self.cut_size)

        return torch.cat([one_cutout() for _ in range(num_cutouts)])
def spherical_dist_loss(x, y):
    """Return theta**2 / 2, where theta is the angle between `x` and `y`.

    Both inputs are L2-normalized along their last dimension and broadcast
    against each other. The half chord length ||x_hat - y_hat|| / 2 equals
    sin(theta / 2), so 2 * arcsin(half_chord) ** 2 == theta**2 / 2. The last
    dimension is reduced away.
    """
    x_hat = F.normalize(x, dim=-1)
    y_hat = F.normalize(y, dim=-1)
    half_chord = (x_hat - y_hat).norm(dim=-1) / 2
    return 2 * torch.arcsin(half_chord) ** 2
def set_requires_grad(model, value):
    """Assign `value` to `requires_grad` on every parameter of `model`.

    Pass False to freeze a sub-model and True to unfreeze it.
    """
    weights = model.parameters()
    for weight in weights:
        weight.requires_grad = value
class CLIPGuidedStableDiffusion(DiffusionPipeline, StableDiffusionMixin):
    """CLIP guided stable diffusion based on the amazing repo by @crowsonkb and @Jack000
    - https://github.com/Jack000/glid-3-xl
    - https://github.com/crowsonkb/k-diffusion

    Image-to-image variant:
    - The input image is VAE-encoded.
    - It is noised to the timestep selected by `strength`.
    - It is then denoised with classifier-free guidance plus a CLIP-similarity
      gradient applied at every step (see `cond_fn`).
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        clip_model: CLIPModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler],
        feature_extractor: CLIPFeatureExtractor,
    ):
        super().__init__()
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            clip_model=clip_model,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
        )

        # Images fed to CLIP are normalized with the feature extractor's statistics.
        self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
        # `feature_extractor.size` may be an int or a dict with a "shortest_edge" key.
        self.cut_out_size = (
            feature_extractor.size
            if isinstance(feature_extractor.size, int)
            else feature_extractor.size["shortest_edge"]
        )
        self.make_cutouts = MakeCutouts(self.cut_out_size)

        # The text encoder and CLIP model are only used for inference here.
        set_requires_grad(self.text_encoder, False)
        set_requires_grad(self.clip_model, False)

    def freeze_vae(self):
        """Disable gradients for all VAE parameters."""
        set_requires_grad(self.vae, False)

    def unfreeze_vae(self):
        """Enable gradients for all VAE parameters."""
        set_requires_grad(self.vae, True)

    def freeze_unet(self):
        """Disable gradients for all UNet parameters."""
        set_requires_grad(self.unet, False)

    def unfreeze_unet(self):
        """Enable gradients for all UNet parameters."""
        set_requires_grad(self.unet, True)

    def get_timesteps(self, num_inference_steps, strength, device):
        """Keep only the last `strength` fraction of the scheduler's timesteps.

        Returns:
            (timesteps, remaining_steps): the tail of `self.scheduler.timesteps`
            starting at index `num_inference_steps - int(num_inference_steps * strength)`
            (clamped), and `num_inference_steps` minus that start index.

        `device` is unused; it is kept for interface compatibility.
        """
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start:]
        return timesteps, num_inference_steps - t_start

    def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
        """Encode `image` with the VAE and noise the result to `timestep`.

        Args:
            image: preprocessed image tensor; moved to `device`/`dtype` here.
            timestep: timesteps handed to `self.scheduler.add_noise`.
            batch_size: number of prompts.
            num_images_per_prompt: generations per prompt.
            dtype: dtype of the returned latents.
            device: device of the returned latents.
            generator: a `torch.Generator`, or a list with one generator per
                latent (length `batch_size * num_images_per_prompt`).

        Returns:
            Noisy latents with one row per generated image. When fewer images
            than generations are given, the image latents are duplicated with a
            deprecation warning.

        Raises:
            ValueError: for an unsupported `image` type, a generator list of the
                wrong length, or an image batch that cannot be evenly duplicated.
        """
        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

        image = image.to(device=device, dtype=dtype)

        batch_size = batch_size * num_images_per_prompt
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if isinstance(generator, list):
            init_latents = [
                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
            ]
            init_latents = torch.cat(init_latents, dim=0)
        else:
            init_latents = self.vae.encode(image).latent_dist.sample(generator)

        init_latents = self.vae.config.scaling_factor * init_latents

        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
            # expand init_latents for batch_size
            deprecation_message = (
                f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
                " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
                " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
                " your script to pass as many initial images as text prompts to suppress this warning."
            )
            deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
            additional_image_per_prompt = batch_size // init_latents.shape[0]
            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )

        noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype)
        return self.scheduler.add_noise(init_latents, noise, timestep)

    @torch.enable_grad()
    def cond_fn(
        self,
        latents,
        timestep,
        index,
        text_embeddings,
        noise_pred_original,
        text_embeddings_clip,
        clip_guidance_scale,
        num_cutouts,
        use_cutouts=True,
    ):
        """Nudge one denoising step toward the CLIP text embedding.

        The method proceeds as follows:
        1. Predict x0 from `latents` and decode it with the VAE.
        2. Embed the decoded image (or random cutouts of it) with CLIP.
        3. Differentiate the spherical distance to `text_embeddings_clip`
           with respect to `latents`.

        Gradients are enabled explicitly because the caller runs under
        `torch.no_grad()`.

        Args:
            latents: current latents (one row per generated image).
            timestep: current scheduler timestep.
            index: absolute position of `timestep` in `self.scheduler.timesteps`;
                only used to read `self.scheduler.sigmas` for LMS.
            text_embeddings: conditional UNet text embeddings (no CFG half).
            noise_pred_original: noise prediction after classifier-free guidance.
            text_embeddings_clip: L2-normalized CLIP text features, one row per
                image or a single row shared by all images.
            clip_guidance_scale: multiplier on the CLIP loss.
            num_cutouts: number of random crops when `use_cutouts` is True.
            use_cutouts: use random crops instead of a single resized image.

        Returns:
            (noise_pred, latents), both detached from the autograd graph.
            - LMS: the guidance is added to `latents`.
            - Other schedulers: the guidance is folded into `noise_pred`.

        Raises:
            ValueError: if the scheduler type is not supported.
        """
        latents = latents.detach().requires_grad_()

        latent_model_input = self.scheduler.scale_model_input(latents, timestep)

        # predict the noise residual
        noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample

        if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler, DPMSolverMultistepScheduler)):
            alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
            beta_prod_t = 1 - alpha_prod_t
            # compute predicted original sample from predicted noise also called
            # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
            pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)

            fac = torch.sqrt(beta_prod_t)
            sample = pred_original_sample * (fac) + latents * (1 - fac)
        elif isinstance(self.scheduler, LMSDiscreteScheduler):
            sigma = self.scheduler.sigmas[index]
            sample = latents - sigma * noise_pred
        else:
            raise ValueError(f"scheduler type {type(self.scheduler)} not supported")

        sample = 1 / self.vae.config.scaling_factor * sample
        image = self.vae.decode(sample).sample
        image = (image / 2 + 0.5).clamp(0, 1)

        if use_cutouts:
            image = self.make_cutouts(image, num_cutouts)
        else:
            # A (size, size) target keeps CLIP's input square even for
            # non-square generations; square images resize exactly as before.
            image = transforms.Resize((self.cut_out_size, self.cut_out_size))(image)
        image = self.normalize(image).to(latents.dtype)

        image_embeddings_clip = self.clip_model.get_image_features(image)
        image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True)

        if use_cutouts:
            # make_cutouts stacks crops cutout-major, so this reshape pairs every
            # crop of sample b with text row b (or with a single shared text row).
            image_embeddings_clip = image_embeddings_clip.reshape(num_cutouts, sample.shape[0], -1)
            dists = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip)
            # Mean over cutouts, sum over the batch.
            loss = dists.mean(0).sum() * clip_guidance_scale
        else:
            loss = spherical_dist_loss(image_embeddings_clip, text_embeddings_clip).mean() * clip_guidance_scale

        grads = -torch.autograd.grad(loss, latents)[0]

        if isinstance(self.scheduler, LMSDiscreteScheduler):
            latents = latents.detach() + grads * (sigma**2)
            noise_pred = noise_pred_original
        else:
            noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads
            latents = latents.detach()
        return noise_pred, latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        image: Optional[Union[torch.Tensor, PIL.Image.Image, List[PIL.Image.Image]]] = None,
        strength: float = 0.8,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        clip_guidance_scale: Optional[float] = 100,
        clip_prompt: Optional[Union[str, List[str]]] = None,
        num_cutouts: Optional[int] = 4,
        use_cutouts: Optional[bool] = True,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        """Run CLIP-guided image-to-image generation.

        Args:
            prompt: text prompt, or a list of prompts (one batch row each).
            height, width: output size in pixels; both must be divisible by 8.
                PIL inputs are resized to this size.
            image: the initial image. Required unless `latents` is given.
            strength: fraction of the denoising schedule to run starting from
                the noised image (see `get_timesteps`).
            num_inference_steps: number of scheduler timesteps to configure.
            guidance_scale: classifier-free guidance weight; values <= 1
                disable it.
            num_images_per_prompt: generations per prompt.
            eta: forwarded to `scheduler.step` only if it accepts `eta`.
            clip_guidance_scale: weight of the CLIP loss; 0 disables CLIP
                guidance.
            clip_prompt: optional separate text for CLIP guidance. Defaults to
                `prompt`.
            num_cutouts: number of random crops scored by CLIP when
                `use_cutouts` is True.
            use_cutouts: score random crops instead of one resized image.
            generator: RNG(s) used for VAE sampling, noise and scheduler steps.
            latents: optional starting latents. When given, `image` is ignored
                and the latents are scaled by `scheduler.init_noise_sigma`.
            output_type: "pil" for PIL images. Any other value returns a numpy
                array of shape (N, H, W, C) with values in [0, 1].
            return_dict: return a `StableDiffusionPipelineOutput` instead of a
                tuple.

        Returns:
            `StableDiffusionPipelineOutput(images=..., nsfw_content_detected=None)`,
            or `(images, None)` when `return_dict` is False.

        Raises:
            ValueError: for an invalid `prompt` type, a size not divisible by 8,
                or latents of an unexpected shape.
        """
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # get prompt text embeddings
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
        # duplicate text embeddings for each generation per prompt
        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

        # set timesteps
        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {}
        if accepts_offset:
            extra_set_kwargs["offset"] = 1
        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, self.device)
        # Number of schedule entries skipped because of `strength`. cond_fn needs
        # the absolute schedule position to read the right LMS sigma.
        step_offset = len(self.scheduler.timesteps) - len(timesteps)
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        latents_from_image = latents is None
        if latents_from_image:
            image = preprocess(image, width, height)
            latents = self.prepare_latents(
                image,
                latent_timestep,
                batch_size,
                num_images_per_prompt,
                text_embeddings.dtype,
                self.device,
                generator,
            )

        if clip_guidance_scale > 0:
            if clip_prompt is not None:
                clip_text_input = self.tokenizer(
                    clip_prompt,
                    padding="max_length",
                    max_length=self.tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                ).input_ids.to(self.device)
            else:
                clip_text_input = text_input.input_ids.to(self.device)
            text_embeddings_clip = self.clip_model.get_text_features(clip_text_input)
            text_embeddings_clip = text_embeddings_clip / text_embeddings_clip.norm(p=2, dim=-1, keepdim=True)
            # duplicate text embeddings clip for each generation per prompt
            text_embeddings_clip = text_embeddings_clip.repeat_interleave(num_images_per_prompt, dim=0)

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            max_length = text_input.input_ids.shape[-1]
            # One empty prompt per real prompt, so that the unconditional half of
            # the batch matches the conditional half.
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
            # duplicate unconditional embeddings for each generation per prompt
            uncond_embeddings = uncond_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        latents_shape = (batch_size * num_images_per_prompt, self.unet.config.in_channels, height // 8, width // 8)
        if latents.shape != latents_shape:
            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
        latents = latents.to(self.device)
        if not latents_from_image:
            # Scale user-supplied initial noise by the std the scheduler expects.
            # Image-derived latents are already noised by `scheduler.add_noise`;
            # scaling them again would double-apply sigma (e.g. for LMS).
            latents = latents * self.scheduler.init_noise_sigma

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        step_params = set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if "eta" in step_params:
            extra_step_kwargs["eta"] = eta
        if "generator" in step_params:
            extra_step_kwargs["generator"] = generator

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

                # perform classifier free guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # perform clip guidance
                if clip_guidance_scale > 0:
                    text_embeddings_for_guidance = (
                        text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings
                    )
                    noise_pred, latents = self.cond_fn(
                        latents,
                        t,
                        i + step_offset,
                        text_embeddings_for_guidance,
                        noise_pred,
                        text_embeddings_clip,
                        clip_guidance_scale,
                        num_cutouts,
                        use_cutouts,
                    )

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
                progress_bar.update()

        # scale and decode the image latents with vae
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents).sample

        image = (image / 2 + 0.5).clamp(0, 1)
        # Convert to float32 before numpy; numpy has no bfloat16 dtype.
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, None)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)