|
import logging |
|
import math |
|
from dataclasses import dataclass |
|
from typing import Any, Dict, List, Optional, Tuple, Union, cast
|
|
|
import numpy as np |
|
import torch |
|
from diffusers import AutoencoderKL, DiffusionPipeline |
|
from diffusers.utils import BaseOutput |
|
from diffusers.utils.torch_utils import randn_tensor |
|
from PIL import Image |
|
from torch import FloatTensor |
|
from tqdm.auto import tqdm |
|
from transformers import T5EncoderModel, T5TokenizerFast |
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
@dataclass |
|
class APGConfig: |
|
"""APG (Augmented Parallel Guidance) configuration""" |
|
|
|
enabled: bool = True |
|
orthogonal_threshold: float = 0.03 |
|
|
|
|
|
@dataclass |
|
class FLitePipelineOutput(BaseOutput): |
|
""" |
|
    Output class for FLitePipeline.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`):
            List of denoised PIL images of length `batch_size` or a NumPy array of shape `(batch_size, height,
            width, num_channels)` representing the denoised images of the diffusion pipeline.
|
""" |
|
|
|
images: Union[List[Image.Image], np.ndarray] |
|
|
|
|
|
class FLitePipeline(DiffusionPipeline): |
|
r""" |
|
    Pipeline for text-to-image generation using the F-Lite model.

    This pipeline inherits from [`DiffusionPipeline`].
|
""" |
|
|
|
model_cpu_offload_seq = "text_encoder->dit_model->vae" |
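    # Offload order consulted by `enable_model_cpu_offload()`: text encoder first, then the DiT, then the VAE.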
|
|
|
dit_model: torch.nn.Module |
|
vae: AutoencoderKL |
|
text_encoder: T5EncoderModel |
|
tokenizer: T5TokenizerFast |
|
_progress_bar_config: Dict[str, Any] |
|
|
|
def __init__( |
|
self, dit_model: torch.nn.Module, vae: AutoencoderKL, text_encoder: T5EncoderModel, tokenizer: T5TokenizerFast |
|
): |
|
super().__init__() |
|
|
|
|
|
        self.register_modules(
            dit_model=dit_model, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer
        )
|
|
|
|
|
|
|
if hasattr(self.vae, "to"): |
|
self.vae.to(memory_format=torch.channels_last) |
|
if hasattr(self.vae, "requires_grad_"): |
|
self.vae.requires_grad_(False) |
|
if hasattr(self.text_encoder, "requires_grad_"): |
|
self.text_encoder.requires_grad_(False) |
|
|
|
|
|
self.vae_scale_factor = 8 |
|
self.return_index = -8 |
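        # The VAE downsamples images by a factor of 8 per spatial dimension; return_index picks which
        # T5 encoder hidden state feeds the DiT (-8 = the eighth hidden state from the end).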
|
|
|
def enable_vae_slicing(self): |
|
"""Enable VAE slicing for memory efficiency.""" |
|
if hasattr(self.vae, "enable_slicing"): |
|
self.vae.enable_slicing() |
|
|
|
def enable_vae_tiling(self): |
|
"""Enable VAE tiling for memory efficiency.""" |
|
if hasattr(self.vae, "enable_tiling"): |
|
self.vae.enable_tiling() |
|
|
|
def set_progress_bar_config(self, **kwargs): |
|
"""Set progress bar configuration.""" |
|
self._progress_bar_config = kwargs |
|
|
|
def progress_bar(self, iterable=None, **kwargs): |
|
"""Create progress bar for iterations.""" |
|
self._progress_bar_config = getattr(self, "_progress_bar_config", None) or {} |
|
config = {**self._progress_bar_config, **kwargs} |
|
return tqdm(iterable, **config) |
|
|
|
def encode_prompt( |
|
self, |
|
prompt: Union[str, List[str]], |
|
negative_prompt: Optional[Union[str, List[str]]] = None, |
|
device: Optional[torch.device] = None, |
|
dtype: Optional[torch.dtype] = None, |
|
max_sequence_length: int = 512, |
|
return_index: int = -8, |
|
) -> Tuple[FloatTensor, FloatTensor]: |
|
"""Encodes the prompt and negative prompt.""" |
|
if isinstance(prompt, str): |
|
prompt = [prompt] |
|
device = device or self.text_encoder.device |
|
|
|
text_inputs = self.tokenizer( |
|
prompt, |
|
padding="max_length", |
|
max_length=max_sequence_length, |
|
truncation=True, |
|
return_tensors="pt", |
|
) |
|
text_input_ids = text_inputs.input_ids.to(device) |
|
prompt_embeds = self.text_encoder(text_input_ids, return_dict=True, output_hidden_states=True) |
|
prompt_embeds_tensor = prompt_embeds.hidden_states[return_index] |
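        # If an intermediate hidden state is selected (return_index != -1), re-apply the encoder's final
        # layer norm (and dropout, a no-op in eval mode) so the features are normalized like the final output.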
|
if return_index != -1: |
|
prompt_embeds_tensor = self.text_encoder.encoder.final_layer_norm(prompt_embeds_tensor) |
|
prompt_embeds_tensor = self.text_encoder.encoder.dropout(prompt_embeds_tensor) |
|
|
|
dtype = dtype or next(self.text_encoder.parameters()).dtype |
|
prompt_embeds_tensor = prompt_embeds_tensor.to(dtype=dtype, device=device) |
|
|
|
|
|
if negative_prompt is None: |
|
negative_embeds = torch.zeros_like(prompt_embeds_tensor) |
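            # No negative prompt: use all-zero embeddings as the unconditional context.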
|
else: |
|
if isinstance(negative_prompt, str): |
|
negative_prompt = [negative_prompt] |
|
            negative_result = self.encode_prompt(
                prompt=negative_prompt, device=device, dtype=dtype,
                max_sequence_length=max_sequence_length, return_index=return_index,
            )
|
negative_embeds = negative_result[0] |
|
|
|
|
|
|
|
|
prompt_tensor = cast(FloatTensor, prompt_embeds_tensor.to(dtype=dtype)) |
|
negative_tensor = cast(FloatTensor, negative_embeds.to(dtype=dtype)) |
|
return (prompt_tensor, negative_tensor) |
|
|
|
def to(self, torch_device=None, torch_dtype=None, silence_dtype_warnings=False): |
|
"""Move pipeline components to specified device and dtype.""" |
|
if hasattr(self, "vae"): |
|
self.vae.to(device=torch_device, dtype=torch_dtype) |
|
if hasattr(self, "text_encoder"): |
|
self.text_encoder.to(device=torch_device, dtype=torch_dtype) |
|
if hasattr(self, "dit_model"): |
|
self.dit_model.to(device=torch_device, dtype=torch_dtype) |
|
return self |
|
|
|
@torch.no_grad() |
|
def __call__( |
|
self, |
|
prompt: Union[str, List[str]], |
|
height: Optional[int] = 1024, |
|
width: Optional[int] = 1024, |
|
num_inference_steps: int = 30, |
|
guidance_scale: float = 3.0, |
|
negative_prompt: Optional[Union[str, List[str]]] = None, |
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
|
dtype: Optional[torch.dtype] = None, |
|
alpha: Optional[float] = None, |
|
apg_config: Optional[APGConfig] = None, |
|
**kwargs, |
|
): |
|
"""Generate images from text prompt.""" |
|
        batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
|
|
if height is None: |
|
height = 1024 |
|
if width is None: |
|
width = 1024 |
|
|
|
dtype = dtype or next(self.dit_model.parameters()).dtype |
|
apg_config = apg_config or APGConfig() |
|
|
|
device = self._execution_device |
|
|
|
|
|
prompt_embeds, negative_embeds = self.encode_prompt( |
|
prompt=prompt, negative_prompt=negative_prompt, device=self.text_encoder.device, dtype=dtype |
|
) |
|
|
|
|
|
latent_height = height // self.vae_scale_factor |
|
latent_width = width // self.vae_scale_factor |
|
|
|
if isinstance(generator, list): |
|
if len(generator) != batch_size: |
|
raise ValueError(f"Got {len(generator)} generators for {batch_size} samples") |
|
|
|
        latents = randn_tensor(
            (batch_size, 16, latent_height, latent_width), generator=generator, device=device, dtype=dtype
        )
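        # 16-channel latent space; spatial dims are the requested pixel size divided by the VAE scale factor.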
|
acc_latents = latents.clone() |
|
|
|
|
|
if alpha is None: |
|
image_token_size = latent_height * latent_width |
|
alpha = 2 * math.sqrt(image_token_size / (64 * 64)) |
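        # Resolution-dependent shift: alpha equals 2 for a 64x64-token latent and grows with the square root
        # of the latent token count; it warps the timesteps in the loop below toward the high-noise end.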
|
|
|
|
|
self.dit_model.eval() |
|
|
|
|
|
do_classifier_free_guidance = guidance_scale >= 1.0 |
|
|
|
for i in self.progress_bar(range(num_inference_steps, 0, -1)): |
|
|
|
t = i / num_inference_steps |
|
t_next = (i - 1) / num_inference_steps |
|
|
|
t = t * alpha / (1 + (alpha - 1) * t) |
|
t_next = t_next * alpha / (1 + (alpha - 1) * t_next) |
|
dt = t - t_next |
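            # dt > 0 is the step size, measured in the warped time, for the Euler update below.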
|
|
|
|
|
t_tensor = torch.tensor([t] * batch_size, device=device, dtype=dtype) |
|
|
|
if do_classifier_free_guidance: |
|
|
|
latents_input = torch.cat([latents] * 2) |
|
|
|
context_input = torch.cat([negative_embeds, prompt_embeds]) |
|
|
|
t_input = torch.cat([t_tensor] * 2) |
|
|
|
|
|
model_outputs = self.dit_model(latents_input, context_input, t_input) |
|
|
|
|
|
uncond_output, cond_output = model_outputs.chunk(2) |
|
|
|
if apg_config.enabled: |
|
|
|
dy = cond_output |
|
dd = cond_output - uncond_output |
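                    # APG: project the guidance delta dd onto the conditional prediction dy, drop the
                    # parallel part, and add back only the orthogonal part after capping its std.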
|
|
|
parallel_direction = (dy * dd).sum() / (dy * dy).sum() * dy |
|
orthogonal_direction = dd - parallel_direction |
|
|
|
orthogonal_std = orthogonal_direction.std() |
|
                    orthogonal_scale = torch.clamp(apg_config.orthogonal_threshold / orthogonal_std, max=1.0)
|
orthogonal_direction = orthogonal_direction * orthogonal_scale |
|
model_output = dy + (guidance_scale - 1) * orthogonal_direction |
|
else: |
|
|
|
model_output = uncond_output + guidance_scale * (cond_output - uncond_output) |
|
else: |
|
|
|
model_output = self.dit_model(latents, prompt_embeds, t_tensor) |
|
|
|
|
|
acc_latents = acc_latents + dt * model_output.to(device) |
|
latents = acc_latents.clone() |
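            # Explicit Euler step: integrate the predicted velocity over dt and carry the result forward.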
|
|
|
|
|
|
|
scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215) if hasattr(self.vae, "config") else 0.18215 |
|
        shift_factor = (getattr(self.vae.config, "shift_factor", 0) or 0) if hasattr(self.vae, "config") else 0
|
|
|
latents = latents / scaling_factor + shift_factor |
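        # Undo the VAE latent normalization (z / scaling_factor + shift_factor) before decoding.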
|
|
|
vae_dtype = self.vae.dtype if hasattr(self.vae, "dtype") else dtype |
|
decoded_images = self.vae.decode(latents.to(vae_dtype)).sample if hasattr(self.vae, "decode") else latents |
|
|
|
|
|
try: |
|
self.maybe_free_model_hooks() |
|
except AttributeError as e: |
|
if "OptimizedModule" in str(e): |
|
import warnings |
|
warnings.warn( |
|
"Encountered 'OptimizedModule' error when offloading models. " |
|
"This issue might be fixed in the future by: " |
|
"https://github.com/huggingface/diffusers/pull/10730" |
|
) |
|
else: |
|
raise |
|
|
|
|
|
images = (decoded_images / 2 + 0.5).clamp(0, 1) |
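        # Decoded images are in [-1, 1]; map to [0, 1], then to 8-bit RGB, and convert CHW tensors to PIL.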
|
|
|
images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu() |
|
pil_images = [Image.fromarray(img.permute(1, 2, 0).numpy()) for img in images] |
|
|
|
return FLitePipelineOutput( |
|
images=pil_images, |
|
) |
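

# Minimal usage sketch (only runs when this module is executed directly). The checkpoint id below is a
# placeholder: substitute whatever repository actually hosts F-Lite weights saved with `save_pretrained`
# from this pipeline class, and adjust device/dtype to your hardware.
if __name__ == "__main__":
    pipe = FLitePipeline.from_pretrained("path/or/repo-id-of-f-lite", torch_dtype=torch.bfloat16)  # hypothetical id
    pipe.to(torch_device="cuda")
    pipe.enable_vae_tiling()

    output = pipe(
        prompt="A watercolor fox in a snowy forest",
        height=1024,
        width=1024,
        num_inference_steps=30,
        guidance_scale=3.0,
        apg_config=APGConfig(enabled=True, orthogonal_threshold=0.03),
        generator=torch.Generator("cuda").manual_seed(0),
    )
    output.images[0].save("f_lite_sample.png")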
|
|
|
|
|
|