# NeTI / sd_pipeline_call.py
# Uploaded by neural-ti ("Upload 17 files"), commit 3eb1ce9.
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionPipeline
@torch.no_grad()
def sd_pipeline_call(
        pipeline: StableDiffusionPipeline,
        prompt_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None):
    """Run Stable Diffusion sampling with pre-computed (NeTI) prompt embeddings.

    This is a modification of the standard SD pipeline call. Instead of a text
    prompt, it takes ``prompt_embeds`` directly. The embeddings may be either:

    * a single tensor, used as the UNet's ``encoder_hidden_states`` at every
      denoising step; or
    * a list, where ``prompt_embeds[i]`` is used at the i-th denoising step.
      This is how NeTI's timestep-dependent embeddings are passed in.

    Args:
        pipeline: The Stable Diffusion pipeline whose tokenizer, text encoder,
            UNet, scheduler and VAE are used.
        prompt_embeds: Conditional embeddings, either one tensor or a per-step
            list (see above). A list must have at least as many entries as the
            scheduler has timesteps.
        height, width: Output size in pixels. Defaults to the UNet sample size
            times the VAE scale factor.
        num_inference_steps: Number of scheduler steps to request.
        guidance_scale: Classifier-free guidance weight. Values ``<= 1``
            disable guidance, so only the conditional UNet pass is run.
        negative_prompt: Text for the unconditional branch. Defaults to ``""``.
            Only used when guidance is enabled.
        num_images_per_prompt: Number of images to generate.
        eta: Forwarded to the scheduler via ``prepare_extra_step_kwargs``.
        generator: RNG(s) for latent initialisation and scheduler steps.
        latents: Optional initial latents, forwarded to ``prepare_latents``.
        output_type: ``"latent"`` returns raw latents. ``"pil"`` returns PIL
            images. Any other value returns the decoded array produced by
            ``decode_latents``.
        return_dict: If False, return a ``(images, nsfw_flags)`` tuple.
        callback: Called as ``callback(i, t, latents)`` every ``callback_steps``
            steps, on steps where the progress bar advances.
        callback_steps: Callback frequency.
        cross_attention_kwargs: Forwarded to every UNet call.

    Returns:
        ``StableDiffusionPipelineOutput``, or a tuple if ``return_dict`` is False.

    Raises:
        ValueError: If ``prompt_embeds`` is a list shorter than the number of
            scheduler timesteps.
    """
    # 0. Default height and width to unet
    height = height or pipeline.unet.config.sample_size * pipeline.vae_scale_factor
    width = width or pipeline.unet.config.sample_size * pipeline.vae_scale_factor

    # 1. Define call parameters
    batch_size = 1
    device = pipeline._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 2. Encode the negative prompt. It is only needed for the unconditional
    #    branch of classifier-free guidance.
    uncond_embeds = None
    if do_classifier_free_guidance:
        neg_prompt = get_neg_prompt_input_ids(pipeline, negative_prompt)
        negative_prompt_embeds, _ = pipeline.text_encoder(
            input_ids=neg_prompt.input_ids.to(device),
            attention_mask=None,
        )
        # NOTE(review): [0] presumably selects the single negative prompt's
        # (seq_len, dim) hidden states; repeat() then tiles it to one copy per
        # image. The result is loop-invariant, so it is built once here.
        uncond_embeds = negative_prompt_embeds[0].repeat(num_images_per_prompt, 1, 1)

    # 3. Prepare timesteps
    pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = pipeline.scheduler.timesteps

    # Per-step (NeTI) embeddings are indexed by step number, so fail fast
    # instead of raising IndexError partway through sampling.
    per_step_embeds = isinstance(prompt_embeds, list)
    if per_step_embeds and len(prompt_embeds) < len(timesteps):
        raise ValueError(
            f"prompt_embeds has {len(prompt_embeds)} per-step entries but the "
            f"scheduler produced {len(timesteps)} timesteps."
        )

    # 4. Prepare latent variables
    num_channels_latents = pipeline.unet.config.in_channels
    latents = pipeline.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        pipeline.text_encoder.dtype,
        device,
        generator,
        latents,
    )

    # 5. Prepare extra step kwargs.
    extra_step_kwargs = pipeline.prepare_extra_step_kwargs(generator, eta)

    # 6. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * pipeline.scheduler.order
    with pipeline.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # Scale the latents every step, whether or not guidance is enabled.
            latent_model_input = pipeline.scheduler.scale_model_input(latents, t)

            if do_classifier_free_guidance:
                # predict the unconditional noise residual
                noise_pred_uncond = pipeline.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=uncond_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

            ###############################################################
            # NeTI logic: use the prompt embedding for the current timestep
            ###############################################################
            embed = prompt_embeds[i] if per_step_embeds else prompt_embeds
            noise_pred = pipeline.unet(
                latent_model_input,
                t,
                encoder_hidden_states=embed,
                cross_attention_kwargs=cross_attention_kwargs,
            ).sample

            # perform guidance (without it, the conditional prediction is used as-is)
            if do_classifier_free_guidance:
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # update the progress bar and call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipeline.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    # 7. Post-processing
    if output_type == "latent":
        image = latents
        has_nsfw_concept = None
    else:
        image = pipeline.decode_latents(latents)
        # 8. Run safety checker
        image, has_nsfw_concept = pipeline.run_safety_checker(image, device, pipeline.text_encoder.dtype)
        # 9. Convert to PIL
        if output_type == "pil":
            image = pipeline.numpy_to_pil(image)

    # Offload last model to CPU
    if hasattr(pipeline, "final_offload_hook") and pipeline.final_offload_hook is not None:
        pipeline.final_offload_hook.offload()

    if not return_dict:
        return image, has_nsfw_concept
    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
def get_neg_prompt_input_ids(pipeline: StableDiffusionPipeline,
                             negative_prompt: Optional[Union[str, List[str]]] = None):
    """Tokenize the negative prompt with the pipeline's tokenizer.

    Args:
        pipeline: Pipeline whose ``tokenizer`` is used.
        negative_prompt: A single string, a list of strings, or ``None``.
            ``None`` is treated as the empty string.

    Returns:
        The tokenizer output (PyTorch tensors), padded to and truncated at
        ``tokenizer.model_max_length``.
    """
    tokenizer = pipeline.tokenizer
    if negative_prompt is None:
        texts = [""]
    elif isinstance(negative_prompt, str):
        texts = [negative_prompt]
    else:
        texts = negative_prompt
    return tokenizer(
        texts,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )