File size: 16,013 Bytes
57f11a4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 |
import torch
import torch.nn as nn
from transformers import CLIPTextModel
from diffusers import (
StableDiffusionPipeline,
StableDiffusionImg2ImgPipeline,
UNet2DConditionModel,
DDIMScheduler,
AutoencoderKL,
)
from insightface.app import FaceAnalysis
from adaface.arc2face_models import CLIPTextModelWrapper
from adaface.util import get_arc2face_id_prompt_embs
import re, os
import sys
# Register the already-imported `adaface` package under the module name `ldm`,
# so that `import ldm...` / lookups of `ldm.*` resolve to `adaface.*`.
# (`sys.modules['adaface']` exists because the `from adaface... import` lines
# above import the parent package.)
# NOTE(review): presumably this exists so torch.load() can unpickle AdaFace
# checkpoints whose pickled classes were saved under `ldm.*` module paths — confirm.
sys.modules['ldm'] = sys.modules['adaface']
class AdaFaceWrapper(nn.Module):
    """Stable Diffusion pipeline wrapper that injects AdaFace subject embeddings.

    On construction it loads the AdaFace subject basis generator, the Arc2Face ID
    text encoder, a diffusers pipeline and an insightface face analyser, then adds
    ``num_vectors`` placeholder tokens (``f"{subject_string}_{i}"``) to the pipeline
    tokenizer/text encoder. ``generate_adaface_embeddings`` writes subject
    embeddings into those token slots, and ``update_prompt`` rewrites prompts so
    the subject string refers to the placeholder tokens.
    """

    def __init__(self, pipeline_name, base_model_path, adaface_ckpt_path, device,
                 subject_string='z', num_vectors=16,
                 num_inference_steps=50, negative_prompt=None,
                 use_840k_vae=False, use_ds_text_encoder=False, is_training=False):
        '''
        pipeline_name: "text2img" or "img2img" or None. If None, the unet and vae are
        removed from the pipeline to release RAM.
        base_model_path: a single checkpoint file (loaded with from_single_file) or a
            diffusers model directory / repo id (loaded with from_pretrained).
        adaface_ckpt_path: checkpoint containing "string_to_subj_basis_generator_dict".
        subject_string: the word in prompts that is replaced by the placeholder tokens.
        num_vectors: number of placeholder tokens added to the tokenizer (>= 1).
        negative_prompt: default negative prompt; a built-in one is used if None.
        '''
        super().__init__()
        self.pipeline_name = pipeline_name
        self.base_model_path = base_model_path
        self.adaface_ckpt_path = adaface_ckpt_path
        self.use_840k_vae = use_840k_vae
        self.use_ds_text_encoder = use_ds_text_encoder
        self.subject_string = subject_string
        self.num_vectors = num_vectors
        self.num_inference_steps = num_inference_steps
        self.device = device
        self.is_training = is_training
        self.initialize_pipeline()
        self.extend_tokenizer_and_text_encoder()
        if negative_prompt is None:
            self.negative_prompt = (
                "flaws in the eyes, flaws in the face, lowres, non-HDRi, low quality, worst quality, artifacts, noise, text, watermark, glitch, "
                "mutated, ugly, disfigured, hands, partially rendered objects, partially rendered eyes, deformed eyeballs, cross-eyed, blurry, "
                "mutation, duplicate, out of frame, cropped, mutilated, bad anatomy, deformed, bad proportions, "
                "nude, naked, nsfw, topless, bare breasts"
            )
        else:
            self.negative_prompt = negative_prompt

    def load_subj_basis_generator(self, adaface_ckpt_path):
        """Load the subject basis generator for ``self.subject_string`` from a checkpoint.

        The generator is moved to ``self.device`` and put in train/eval mode
        according to ``self.is_training``.

        Raises:
            KeyError: if the checkpoint has no generator for ``self.subject_string``.
        """
        # The generator is stored as a pickled object (it is used below as a module
        # with .to()/.train()/.eval()), not as a plain state dict, so full unpickling
        # is required. torch >= 2.6 defaults to weights_only=True, which would refuse it.
        # SECURITY: unpickling can execute arbitrary code — only load trusted checkpoints.
        ckpt = torch.load(adaface_ckpt_path, map_location='cpu', weights_only=False)
        string_to_subj_basis_generator_dict = ckpt["string_to_subj_basis_generator_dict"]
        if self.subject_string not in string_to_subj_basis_generator_dict:
            # Fail loudly instead of dropping into an interactive debugger.
            raise KeyError(
                f"Subject '{self.subject_string}' not found in the embedding manager.")
        self.subj_basis_generator = string_to_subj_basis_generator_dict[self.subject_string]
        # In the original ckpt, num_out_layers is 16 for layerwise embeddings.
        # But we don't do layerwise embeddings here, so we set it to 1.
        self.subj_basis_generator.num_out_layers = 1
        print(f"Loaded subject basis generator for '{self.subject_string}'.")
        print(repr(self.subj_basis_generator))
        self.subj_basis_generator.to(self.device)
        if self.is_training:
            self.subj_basis_generator.train()
        else:
            self.subj_basis_generator.eval()

    def initialize_pipeline(self):
        """Build the subject generator, Arc2Face encoder, SD pipeline and face analyser.

        Raises:
            ValueError: if ``self.pipeline_name`` is not "img2img", "text2img" or None.
        """
        self.load_subj_basis_generator(self.adaface_ckpt_path)
        # arc2face_text_encoder maps the face analysis embedding to 16 face embeddings
        # in the UNet image space.
        arc2face_text_encoder = CLIPTextModelWrapper.from_pretrained(
            'models/arc2face', subfolder="encoder", torch_dtype=torch.float16
        )
        self.arc2face_text_encoder = arc2face_text_encoder.to(self.device)

        # Resolve the pipeline class first, so an invalid name fails before any
        # further heavy model loading.
        remove_unet = False
        if self.pipeline_name == "img2img":
            PipelineClass = StableDiffusionImg2ImgPipeline
        elif self.pipeline_name == "text2img":
            PipelineClass = StableDiffusionPipeline
        # pipeline_name is None means only use this instance to generate adaface embeddings, not to generate images.
        elif self.pipeline_name is None:
            PipelineClass = StableDiffusionPipeline
            remove_unet = True
        else:
            raise ValueError(f"Unknown pipeline name: {self.pipeline_name}")

        # The VAE is discarded when remove_unet is set, so don't load the
        # replacement VAE from disk in that case.
        replace_vae = self.use_840k_vae and not remove_unet
        if replace_vae:
            # The 840000-step vae model is slightly better in face details than the original vae model.
            # https://huggingface.co/stabilityai/sd-vae-ft-mse-original
            vae = AutoencoderKL.from_single_file(
                "models/diffusers/sd-vae-ft-mse-original/vae-ft-mse-840000-ema-pruned.ckpt",
                torch_dtype=torch.float16)
        if self.use_ds_text_encoder:
            # The dreamshaper v7 finetuned text encoder follows the prompt slightly better than the original text encoder.
            # https://huggingface.co/Lykon/DreamShaper/tree/main/text_encoder
            text_encoder = CLIPTextModel.from_pretrained("models/ds_text_encoder", torch_dtype=torch.float16)

        if os.path.isfile(self.base_model_path):
            pipeline = PipelineClass.from_single_file(
                self.base_model_path,
                torch_dtype=torch.float16
            )
        else:
            pipeline = PipelineClass.from_pretrained(
                self.base_model_path,
                torch_dtype=torch.float16,
                safety_checker=None
            )
        print(f"Loaded pipeline from {self.base_model_path}.")

        if replace_vae:
            pipeline.vae = vae
            print("Replaced the VAE with the 840k-step VAE.")
        if self.use_ds_text_encoder:
            pipeline.text_encoder = text_encoder
            print("Replaced the text encoder with the DreamShaper text encoder.")
        if remove_unet:
            # Remove unet and vae to release RAM. Only keep tokenizer and text_encoder.
            pipeline.unet = None
            pipeline.vae = None
            print("Removed UNet and VAE from the pipeline.")

        noise_scheduler = DDIMScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
            steps_offset=1,
        )
        pipeline.scheduler = noise_scheduler
        self.pipeline = pipeline.to(self.device)
        # FaceAnalysis will try to find the ckpt in: models/insightface/models/antelopev2.
        # Note there's a second "model" in the path.
        self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
                                     providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        self.face_app.prepare(ctx_id=0, det_size=(512, 512))
        # Patch the missing tokenizer in the subj_basis_generator.
        if not hasattr(self.subj_basis_generator, 'clip_tokenizer'):
            self.subj_basis_generator.clip_tokenizer = self.pipeline.tokenizer
            print("Patched the missing tokenizer in the subj_basis_generator.")

    def extend_tokenizer_and_text_encoder(self):
        """Add ``num_vectors`` placeholder tokens and grow the text-encoder embedding table.

        Raises:
            ValueError: if ``num_vectors < 1`` or some placeholder token already exists.
        """
        if self.num_vectors < 1:
            raise ValueError(f"num_vectors has to be larger or equal to 1, but is {self.num_vectors}")
        tokenizer = self.pipeline.tokenizer
        # Placeholder tokens: z_0, z_1, ..., z_{num_vectors-1}.
        self.placeholder_tokens = [f"{self.subject_string}_{i}" for i in range(self.num_vectors)]
        self.placeholder_tokens_str = " ".join(self.placeholder_tokens)
        # add_tokens() returns how many tokens were actually new; fewer than
        # num_vectors means some placeholder already existed in the vocabulary.
        num_added_tokens = tokenizer.add_tokens(self.placeholder_tokens)
        if num_added_tokens != self.num_vectors:
            raise ValueError(
                f"The tokenizer already contains some of the tokens ({self.placeholder_tokens_str})"
                f" derived from subject string {self.subject_string}. Please pass a different"
                " `subject_string` that is not already in the tokenizer.")
        print(f"Added {num_added_tokens} tokens ({self.placeholder_tokens_str}) to the tokenizer.")
        # placeholder_token_ids: [49408, ..., 49423].
        self.placeholder_token_ids = tokenizer.convert_tokens_to_ids(self.placeholder_tokens)
        # Resize the token embeddings as we are adding new special tokens to the tokenizer.
        old_shape = self.pipeline.text_encoder.get_input_embeddings().weight.shape
        self.pipeline.text_encoder.resize_token_embeddings(len(tokenizer))
        new_weight = self.pipeline.text_encoder.get_input_embeddings().weight
        print(f"Resized text encoder token embeddings from {old_shape} to {new_weight.shape} on {new_weight.device}.")

    def update_text_encoder_subj_embs(self, subj_embs):
        """Write ``subj_embs[i]`` into the embedding row of the i-th placeholder token.

        subj_embs: per the shape notes in this file, [num_vectors, 768].
        """
        token_embeds = self.pipeline.text_encoder.get_input_embeddings().weight.data
        with torch.no_grad():
            # Row-by-row assignment copies with implicit dtype/device conversion.
            for i, token_id in enumerate(self.placeholder_token_ids):
                token_embeds[token_id] = subj_embs[i]
        print(f"Updated {len(self.placeholder_token_ids)} tokens ({self.placeholder_tokens_str}) in the text encoder.")

    def update_prompt(self, prompt):
        """Return ``prompt`` with the subject string mapped to the placeholder tokens.

        - If the placeholder token string is already present, the prompt is returned as is.
        - If the subject string does not occur as a whole word, the placeholder
          tokens are prepended.
        - Otherwise every whole-word occurrence is replaced by the placeholder tokens.
        """
        if self.placeholder_tokens_str in prompt:
            return prompt
        # subject_string is literal text, not a regex: escape it.
        subj_pattern = re.compile(r'\b' + re.escape(self.subject_string) + r'\b')
        if subj_pattern.search(prompt) is None:
            print(f"Subject string '{self.subject_string}' not found in the prompt. Adding it.")
            return self.placeholder_tokens_str + " " + prompt
        # A function replacement inserts the placeholder text verbatim, without
        # interpreting backslashes / group references in it.
        return subj_pattern.sub(lambda _match: self.placeholder_tokens_str, prompt)

    # image_paths: a list of image paths. image_folder: the parent folder name.
    def generate_adaface_embeddings(self, image_paths, image_folder=None,
                                    pre_face_embs=None, gen_rand_face=False,
                                    out_id_embs_scale=1., noise_level=0, update_text_encoder=True):
        """Compute AdaFace subject embeddings from face images (or a random face).

        Returns the subject embeddings (expected [num_vectors, 768]), or None when
        no face was found. If ``update_text_encoder`` is True, the embeddings are
        also written into the placeholder-token rows of the text encoder.
        """
        # faceid_embeds is a batch of extracted face analysis embeddings (BS * 512 = id_batch_size * 512).
        # If extract_faceid_embeds is True, faceid_embeds is *the same* embedding repeated by id_batch_size times.
        # Otherwise, faceid_embeds is a batch of random embeddings, each instance is different.
        # The same applies to id_prompt_emb.
        # faceid_embeds is in the face analysis embeddings. id_prompt_emb is in the image prompt space.
        # Here id_batch_size = 1, so
        # faceid_embeds: [1, 512]. NOT used later.
        # id_prompt_emb: [1, 16, 768].
        # NOTE: Since return_core_id_embs is True, id_prompt_emb is only the 16 core ID embeddings.
        # arc2face prompt template: "photo of a id person"
        # ID embeddings start from "id person ...". So there are 3 template tokens before the 16 ID embeddings.
        face_image_count, faceid_embeds, id_prompt_emb \
            = get_arc2face_id_prompt_embs(self.face_app, self.pipeline.tokenizer, self.arc2face_text_encoder,
                                          extract_faceid_embeds=not gen_rand_face,
                                          pre_face_embs=pre_face_embs,
                                          # image_folder is passed only for logging purpose.
                                          # image_paths contains the paths of the images.
                                          image_folder=image_folder, image_paths=image_paths,
                                          images_np=None,
                                          id_batch_size=1,
                                          device=self.device,
                                          # input_max_length == 22: only keep the first 22 tokens,
                                          # including 3 template tokens and 16 ID tokens, and BOS and EOS tokens.
                                          # The results are indistinguishable from input_max_length=77.
                                          input_max_length=22,
                                          noise_level=noise_level,
                                          return_core_id_embs=True,
                                          gen_neg_prompt=False,
                                          verbose=True)
        if face_image_count == 0:
            return None
        # adaface_subj_embs: [1, 1, 16, 768]. The second output
        # (adaface_prompt_embs, [1, 77, 768]) is not used.
        adaface_subj_embs, _ = \
            self.subj_basis_generator(id_prompt_emb, None, None,
                                      out_id_embs_scale=out_id_embs_scale,
                                      is_face=True, is_training=False,
                                      adaface_prompt_embs_inf_type='full_half_pad')
        # Collapse leading dims -> [16, 768]. Unlike squeeze(), this keeps the
        # token axis even when only a single subject vector is produced.
        adaface_subj_embs = adaface_subj_embs.reshape(-1, adaface_subj_embs.shape[-1])
        if update_text_encoder:
            self.update_text_encoder_subj_embs(adaface_subj_embs)
        return adaface_subj_embs

    def encode_prompt(self, prompt, negative_prompt=None, device="cuda", verbose=False):
        """Encode ``prompt`` (after ``update_prompt``) and the negative prompt.

        Returns (prompt_embeds, negative_prompt_embeds), each [1, 77, 768].
        """
        if negative_prompt is None:
            negative_prompt = self.negative_prompt
        prompt = self.update_prompt(prompt)
        if verbose:
            print(f"Prompt: {prompt}")
        # For some unknown reason, the text_encoder is still on CPU after self.pipeline.to(self.device).
        # So we manually move it to GPU here.
        self.pipeline.text_encoder.to(device)
        # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
        prompt_embeds_, negative_prompt_embeds_ = \
            self.pipeline.encode_prompt(prompt, device=device, num_images_per_prompt=1,
                                        do_classifier_free_guidance=True, negative_prompt=negative_prompt)
        return prompt_embeds_, negative_prompt_embeds_

    # ref_img_strength is used only in the img2img pipeline.
    def forward(self, noise, prompt, negative_prompt=None, guidance_scale=4.0,
                out_image_count=4, ref_img_strength=0.8, generator=None, verbose=False):
        """Generate ``out_image_count`` images for ``prompt`` with the wrapped pipeline.

        noise: per the note below, [BS, 4, 64, 64]; passed to the pipeline as ``image``.
        Returns the pipeline's ``.images`` output.
        """
        if negative_prompt is None:
            negative_prompt = self.negative_prompt
        # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
        prompt_embeds_, negative_prompt_embeds_ = self.encode_prompt(
            prompt, negative_prompt, device=self.device, verbose=verbose)

        # Repeat the prompt embeddings for all images in the batch.
        prompt_embeds_ = prompt_embeds_.repeat(out_image_count, 1, 1)
        negative_prompt_embeds_ = negative_prompt_embeds_.repeat(out_image_count, 1, 1)
        noise = noise.to(device=self.device, dtype=torch.float16)

        # noise: [BS, 4, 64, 64]
        # When the pipeline is text2img, strength is ignored.
        # NOTE(review): StableDiffusionPipeline (text2img) has no `image`/`strength`
        # parameters; they probably fall into its **kwargs, meaning `noise` may be
        # silently unused in text2img mode (`latents=` would be its input) — confirm.
        images = self.pipeline(image=noise,
                               prompt_embeds=prompt_embeds_,
                               negative_prompt_embeds=negative_prompt_embeds_,
                               num_inference_steps=self.num_inference_steps,
                               guidance_scale=guidance_scale,
                               num_images_per_prompt=1,
                               strength=ref_img_strength,
                               generator=generator).images
        # images: [BS, 3, 512, 512]
        return images
|