|
import inspect |
|
from typing import List, Optional, Union |
|
|
|
import numpy as np |
|
|
|
from transformers import CLIPFeatureExtractor, CLIPTokenizer |
|
|
|
from ...onnx_utils import OnnxRuntimeModel |
|
from ...pipeline_utils import DiffusionPipeline |
|
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler |
|
from . import StableDiffusionPipelineOutput |
|
|
|
|
|
class StableDiffusionOnnxPipeline(DiffusionPipeline):
    r"""
    Text-to-image Stable Diffusion pipeline whose networks are ONNX Runtime models.

    All data exchanged with the components is NumPy; the scheduler is switched to its
    NumPy format (``set_format("np")``) at construction time.

    Components (registered via ``register_modules``):
        vae_decoder: ONNX model called as ``vae_decoder(latent_sample=...)`` to turn latents into images.
        text_encoder: ONNX model called as ``text_encoder(input_ids=...)`` to embed tokenized prompts.
        tokenizer: CLIP tokenizer producing the ``input_ids`` for the text encoder.
        unet: ONNX model called as ``unet(sample=..., timestep=..., encoder_hidden_states=...)``.
        scheduler: DDIM, PNDM or LMS discrete scheduler driving the denoising loop.
        safety_checker: ONNX model called as ``safety_checker(clip_input=..., images=...)``.
        feature_extractor: CLIP feature extractor that prepares the safety checker's ``clip_input``.
    """

    vae_decoder: OnnxRuntimeModel
    text_encoder: OnnxRuntimeModel
    tokenizer: CLIPTokenizer
    unet: OnnxRuntimeModel
    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
    safety_checker: OnnxRuntimeModel
    feature_extractor: CLIPFeatureExtractor

    def __init__(
        self,
        vae_decoder: OnnxRuntimeModel,
        text_encoder: OnnxRuntimeModel,
        tokenizer: CLIPTokenizer,
        unet: OnnxRuntimeModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: OnnxRuntimeModel,
        feature_extractor: CLIPFeatureExtractor,
    ):
        """Register all pipeline components, with the scheduler converted to NumPy format."""
        super().__init__()
        # The ONNX models consume NumPy arrays, so the scheduler must operate on NumPy too.
        scheduler = scheduler.set_format("np")
        self.register_modules(
            vae_decoder=vae_decoder,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        eta: Optional[float] = 0.0,
        latents: Optional[np.ndarray] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        **kwargs,
    ):
        r"""
        Generate images for ``prompt``.

        Args:
            prompt: A single prompt or a list of prompts; the batch size is the number of prompts.
            height: Output height in pixels; must be divisible by 8.
            width: Output width in pixels; must be divisible by 8.
            num_inference_steps: Number of scheduler steps passed to ``scheduler.set_timesteps``.
            guidance_scale: Classifier-free guidance weight. An extra unconditional pass is run
                only when this is greater than 1.0.
            eta: Forwarded to ``scheduler.step`` only if that method accepts an ``eta`` argument.
            latents: Optional starting latents of shape ``(batch_size, 4, height // 8, width // 8)``.
                They are converted to ``float32``, the same dtype as internally sampled latents.
                When omitted, they are drawn from ``np.random.randn`` (the global NumPy RNG).
            output_type: ``"pil"`` converts the images to PIL; any other value returns the NumPy
                array produced by the safety checker.
            return_dict: When False, return a plain ``(images, nsfw_content_detected)`` tuple.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            ``StableDiffusionPipelineOutput(images=..., nsfw_content_detected=...)``, or the
            tuple ``(images, nsfw_content_detected)`` when ``return_dict`` is False.

        Raises:
            ValueError: If ``prompt`` is neither ``str`` nor ``list``, if ``height``/``width``
                are not divisible by 8, or if ``latents`` has an unexpected shape.
        """
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        do_classifier_free_guidance = guidance_scale > 1.0
        text_embeddings = self._encode_prompt(prompt, batch_size, do_classifier_free_guidance)
        latents = self._prepare_latents(batch_size, height, width, latents)

        # Some schedulers take an `offset` in `set_timesteps`; only pass it to those that do.
        extra_set_kwargs = {}
        if "offset" in inspect.signature(self.scheduler.set_timesteps).parameters:
            extra_set_kwargs["offset"] = 1
        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        # The LMS scheduler needs sigma-based scaling and is stepped by index rather than timestep.
        is_lms = isinstance(self.scheduler, LMSDiscreteScheduler)
        if is_lms:
            latents = latents * self.scheduler.sigmas[0]

        # Only forward `eta` to schedulers whose `step` accepts it.
        extra_step_kwargs = {}
        if "eta" in inspect.signature(self.scheduler.step).parameters:
            extra_step_kwargs["eta"] = eta

        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
            # With guidance, run the unconditional and conditional branches as one doubled batch.
            latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
            if is_lms:
                sigma = self.scheduler.sigmas[i]
                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

            # NOTE(review): np.array([t]) inherits the dtype of the scheduler's timesteps, which may
            # differ between schedulers; confirm it matches the exported UNet's timestep input type.
            noise_pred = self.unet(
                sample=latent_model_input, timestep=np.array([t]), encoder_hidden_states=text_embeddings
            )[0]

            if do_classifier_free_guidance:
                # The first half of the batch is unconditional, matching the order in _encode_prompt.
                noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            step_t = i if is_lms else t
            latents = self.scheduler.step(noise_pred, step_t, latents, **extra_step_kwargs).prev_sample

        image = self._decode_latents(latents)
        image, has_nsfw_concept = self._run_safety_checker(image)

        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

    def _encode_prompt(self, prompt, batch_size, do_classifier_free_guidance):
        """Return prompt embeddings, prefixed by empty-prompt embeddings when guidance is enabled."""
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )
        # NOTE(review): ids are cast to int32, presumably the exported text encoder's input type.
        text_embeddings = self.text_encoder(input_ids=text_input.input_ids.astype(np.int32))[0]
        if not do_classifier_free_guidance:
            return text_embeddings

        # Empty prompts padded to the same token length as the real prompts.
        max_length = text_input.input_ids.shape[-1]
        uncond_input = self.tokenizer(
            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
        )
        uncond_embeddings = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
        # Concatenate unconditional first, then conditional, so one UNet call covers both.
        return np.concatenate([uncond_embeddings, text_embeddings])

    def _prepare_latents(self, batch_size, height, width, latents):
        """Return float32 starting latents: sampled when ``latents`` is None, otherwise validated."""
        latents_shape = (batch_size, 4, height // 8, width // 8)
        if latents is None:
            return np.random.randn(*latents_shape).astype(np.float32)
        # Match the dtype of internally sampled latents. Callers commonly pass float64 arrays
        # (e.g. from np.random.randn), which would otherwise reach the ONNX UNet unconverted.
        latents = np.asarray(latents, dtype=np.float32)
        if latents.shape != latents_shape:
            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
        return latents

    def _decode_latents(self, latents):
        """Decode latents with the VAE decoder into a channels-last image array clipped to [0, 1]."""
        # NOTE(review): 0.18215 is presumably the VAE latent scaling factor; confirm against training.
        latents = 1 / 0.18215 * latents
        image = self.vae_decoder(latent_sample=latents)[0]
        # Map x -> x / 2 + 0.5 and clip to [0, 1] (assumes decoder output roughly in [-1, 1]).
        image = np.clip(image / 2 + 0.5, 0, 1)
        # Axis order (0, 2, 3, 1): channels-first to channels-last, assuming NCHW decoder output.
        return image.transpose((0, 2, 3, 1))

    def _run_safety_checker(self, image):
        """Run the safety checker on ``image`` and return ``(image, has_nsfw_concept)`` as it reports them."""
        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
        return self.safety_checker(clip_input=safety_checker_input.pixel_values, images=image)
|
|