import torch
from diffusers.models import AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizerFast, T5EncoderModel, T5Tokenizer


class EncoderMixin:
    """Mixin class for handling various encoders in the MotifDiT model.

    This mixin provides functionality for:
    1. Loading and initializing encoders (VAE, T5, CLIP-L, CLIP-G)
    2. Text tokenization and encoding
    3. Managing encoder parameters and state
    """
    TOKEN_MAX_LENGTH: int = 256

    def prepare_embeddings(
        self,
        images: torch.Tensor,
        raw_text: list[str],
        vae: AutoencoderKL,
        t5: T5EncoderModel,
        clip_l: CLIPTextModel,
        clip_g: CLIPTextModel,
        t5_tokenizer: T5Tokenizer,
        clip_l_tokenizer: CLIPTokenizerFast,
        clip_g_tokenizer: CLIPTokenizerFast,
        is_training: bool,
    ) -> tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]:
        """Prepare image latents and text embeddings for model input.

        Args:
            images (torch.Tensor): Input images tensor with shape [B, C=3, H, W].
            raw_text (list[str]): List of raw text strings with length B.
            vae, t5, clip_l, clip_g: VAE and text encoders used for encoding.
            t5_tokenizer, clip_l_tokenizer, clip_g_tokenizer: Tokenizers matching the text encoders.
            is_training (bool): Whether pooled embeddings may be randomly dropped during encoding.

        Returns:
            tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]: VAE latents with shape
            [B, 16, H//8, W//8], per-encoder text embeddings ordered [T5, CLIP-L, CLIP-G],
            and pooled CLIP text embeddings with shape [B, 2048].
        """
        with torch.no_grad():
            latents: torch.Tensor = (
                vae.encode(images).latent_dist.sample() - vae.config.shift_factor
            ) * vae.config.scaling_factor  # Latents shape: [B, 16, H//8, W//8]

        # Tokenize the input text and move tokens and masks to the same device as latents
        tokenizers = [t5_tokenizer, clip_l_tokenizer, clip_g_tokenizer]
        tokens, masks = self.tokenization(raw_text, tokenizers)
        tokens = [token.to(latents.device) for token in tokens]
        masks = [mask.to(latents.device) for mask in masks]

        # Encode the text and drop unnecessary embeddings
        text_embeddings, pooled_text_embeddings = self.text_encoding(
            tokens,
            masks,
            t5,
            clip_l,
            clip_g,
            t5_tokenizer.pad_token_id,
            clip_l_tokenizer.eos_token_id,
            clip_g_tokenizer.eos_token_id,
            is_training,
        )
        text_embeddings = self.drop_text_emb(text_embeddings)

        # Convert text embeddings to float
        text_embeddings = [text_embedding.float() for text_embedding in text_embeddings]
        # Convert pooled text embeddings to float
        pooled_text_embeddings = pooled_text_embeddings.float()
        return latents, text_embeddings, pooled_text_embeddings

    def get_freezed_encoders_and_tokenizers(
        self, vae_type: str
    ) -> tuple[
        AutoencoderKL, T5EncoderModel, CLIPTextModel, CLIPTextModel, T5Tokenizer, CLIPTokenizerFast, CLIPTokenizerFast
    ]:
        """Initialize the frozen VAE, text encoders, and their tokenizers."""
        if vae_type != "SD3":
            raise ValueError(
                f"VAE type must be `SD3` but self.config.vae_type is {vae_type}."
                " Note that the SDXL VAE type is deprecated."
            )
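
        # The SD3 VAE produces 16-channel latents; its config also stores the shift_factor and
        # scaling_factor applied in `prepare_embeddings`.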
        vae: AutoencoderKL = AutoencoderKL.from_pretrained(
            "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae"
        )

        # Text encoders
        # 1. T5-XXL from Google
        t5 = T5EncoderModel.from_pretrained("google/flan-t5-xxl").to(dtype=torch.bfloat16)
        t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl")

        # 2. CLIP-L from OpenAI
        clip_l = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(dtype=torch.bfloat16)
        clip_l_tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-large-patch14")

        # 3. CLIP-G from LAION
        clip_g = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k").to(dtype=torch.bfloat16)
        clip_g_tokenizer = CLIPTokenizerFast.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")

        # Freeze all encoders
        for encoder_module in [vae, clip_l, clip_g, t5]:
            for param in encoder_module.parameters():
                param.requires_grad = False

        return vae, t5, clip_l, clip_g, t5_tokenizer, clip_l_tokenizer, clip_g_tokenizer

    def tokenization(
        self, raw_text: list[str], tokenizers: list[T5Tokenizer | CLIPTokenizerFast]
    ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        """Tokenize the input text using multiple tokenizers.

        Args:
            raw_text (list[str]): Batch of raw text strings.
            tokenizers (list): Tokenizers to apply, one per text encoder.

        Returns:
            tuple[list[torch.Tensor], list[torch.Tensor]]: Per-tokenizer token-id tensors and attention masks.
        """
        tokens, masks = [], []
        for tokenizer in tokenizers:
            tok = tokenizer(
                raw_text,
                padding="max_length",
                max_length=min(EncoderMixin.TOKEN_MAX_LENGTH, tokenizer.model_max_length),
                return_tensors="pt",
                truncation=True,
            )
            tokens.append(tok.input_ids)
            masks.append(tok.attention_mask)
        return tokens, masks

    @torch.no_grad()
    def text_encoding(
        self,
        tokens: list[torch.Tensor],
        masks: list[torch.Tensor],
        t5: T5EncoderModel,
        clip_l: CLIPTextModel,
        clip_g: CLIPTextModel,
        t5_pad_token_id: int = 0,
        clip_l_tokenizer_eos_token_id: int = 49407,
        clip_g_tokenizer_eos_token_id: int = 49407,
        is_training: bool = False,
    ) -> tuple[list[torch.Tensor], torch.Tensor]:
        """Encode the tokenized text using multiple text encoders.

        Args:
            tokens (list[torch.Tensor]): Tokenized text tensors, ordered [T5, CLIP-L, CLIP-G].
            masks (list[torch.Tensor]): Attention masks in the same order.

        Returns:
            tuple[list[torch.Tensor], torch.Tensor]: Per-encoder text embeddings and pooled CLIP embeddings.
        """
        t5_tokens, clip_l_tokens, clip_g_tokens = tokens
        t5_masks, _, _ = masks

        # T5 encoding
        t5_emb = t5(t5_tokens, attention_mask=t5_masks)[0]
        t5_emb = t5_emb * (t5_tokens != t5_pad_token_id).unsqueeze(-1)

        # CLIP encodings
        clip_l_emb = clip_l(input_ids=clip_l_tokens, output_hidden_states=True)
        clip_g_emb = clip_g(input_ids=clip_g_tokens, output_hidden_states=True)

        # Get pooled outputs
        clip_l_emb_pooled = clip_l_emb.pooler_output  # B x 768
        clip_g_emb_pooled = clip_g_emb.pooler_output  # B x 1280
        if is_training:
            clip_l_emb_pooled = self.drop_text_emb(clip_l_emb_pooled)
            clip_g_emb_pooled = self.drop_text_emb(clip_g_emb_pooled)

        clip_l_emb = clip_l_emb.last_hidden_state  # B x L x 768
        clip_g_emb = clip_g_emb.last_hidden_state  # B x L x 1280

        def masking_wo_first_eos(token: torch.Tensor, eos: int) -> torch.Tensor:
            """Mask out repeated EOS/padding positions while keeping the first EOS token."""
            idx = (token != eos).sum(dim=1)  # per-sample index of the first EOS
            mask = token != eos
            arange = torch.arange(mask.size(0), device=token.device)
            in_range = idx < mask.size(1)  # guard against sequences with no EOS at all
            mask[arange[in_range], idx[in_range]] = True  # keep the first EOS position
            return mask.unsqueeze(-1)  # B x L x 1

        # Apply masking
        clip_l_emb = clip_l_emb * masking_wo_first_eos(clip_l_tokens, clip_l_tokenizer_eos_token_id)
        clip_g_emb = clip_g_emb * masking_wo_first_eos(clip_g_tokens, clip_g_tokenizer_eos_token_id)

        encodings = [t5_emb, clip_l_emb, clip_g_emb]
        pooled_encodings = torch.cat([clip_l_emb_pooled, clip_g_emb_pooled], dim=-1)  # B x 2048
        return encodings, pooled_encodings

    @torch.no_grad()
    def drop_text_emb(
        self, text_embeddings: list[torch.Tensor] | torch.Tensor, drop_prob: float = 0.464
    ) -> list[torch.Tensor] | torch.Tensor:
        """Randomly drop text embeddings with a specified probability.

        Args:
            text_embeddings (Union[List[torch.Tensor], torch.Tensor]): Text embeddings to be dropped.
            drop_prob (float, optional): Probability of dropping text embeddings. Defaults to 0.464.

        Returns:
            Union[List[torch.Tensor], torch.Tensor]: Text embeddings with dropped elements.
        """
        if isinstance(text_embeddings, list):
            # For BxLxC features
            for i, text_embedding in enumerate(text_embeddings):
                probs = torch.ones(text_embedding.shape[0], device=text_embedding.device) * (1 - drop_prob)
                masks = torch.bernoulli(probs)
                while len(masks.shape) < len(text_embedding.shape):
                    masks = masks.unsqueeze(-1)
                # Write the dropped tensor back into the list
                text_embeddings[i] = text_embedding * masks
        else:
            # For a pooled BxC feature
            probs = torch.ones(text_embeddings.shape[0], device=text_embeddings.device) * (1 - drop_prob)
            masks = torch.bernoulli(probs)
            while len(masks.shape) < len(text_embeddings.shape):
                masks = masks.unsqueeze(-1)
            text_embeddings = text_embeddings * masks
        return text_embeddings
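

if __name__ == "__main__":
    # Minimal smoke-test sketch (an assumption, not part of the original training pipeline).
    # It requires a CUDA device with enough memory for the frozen bfloat16 encoders and will
    # download pretrained weights on first run; image sizes and prompts are placeholders.
    mixin = EncoderMixin()
    vae, t5, clip_l, clip_g, t5_tok, clip_l_tok, clip_g_tok = mixin.get_freezed_encoders_and_tokenizers("SD3")

    device = torch.device("cuda")
    vae, t5, clip_l, clip_g = vae.to(device), t5.to(device), clip_l.to(device), clip_g.to(device)

    images = torch.rand(2, 3, 512, 512, device=device) * 2 - 1  # dummy image batch scaled to [-1, 1]
    captions = ["a photo of a cat", "a watercolor landscape at sunset"]

    latents, text_embs, pooled = mixin.prepare_embeddings(
        images, captions, vae, t5, clip_l, clip_g, t5_tok, clip_l_tok, clip_g_tok, is_training=True
    )
    print(latents.shape)                 # expected: [2, 16, 64, 64]
    print([e.shape for e in text_embs])  # T5, CLIP-L, CLIP-G sequence embeddings
    print(pooled.shape)                  # expected: [2, 2048]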