# Adaface / adaface_wrapper.py
# zyt334's picture
# Upload folder using huggingface_hub
# 57f11a4 verified
import torch
import torch.nn as nn
from transformers import CLIPTextModel
from diffusers import (
StableDiffusionPipeline,
StableDiffusionImg2ImgPipeline,
UNet2DConditionModel,
DDIMScheduler,
AutoencoderKL,
)
from insightface.app import FaceAnalysis
from adaface.arc2face_models import CLIPTextModelWrapper
from adaface.util import get_arc2face_id_prompt_embs
import re, os
import sys
# Register the already-imported `adaface` package under the additional module name `ldm`.
# NOTE(review): presumably this lets checkpoints that were pickled with `ldm.*` class paths
# be unpickled by torch.load in this file — confirm against the checkpoint's origin.
sys.modules['ldm'] = sys.modules['adaface']
class AdaFaceWrapper(nn.Module):
    """Wrap a Stable Diffusion pipeline together with the AdaFace subject-embedding machinery.

    On construction the wrapper loads the subject basis generator from the AdaFace
    checkpoint, the arc2face text encoder, the diffusers pipeline and the insightface
    face analyser, and then extends the pipeline's tokenizer / text encoder with
    `num_vectors` placeholder tokens named "<subject_string>_<i>".
    """

    def __init__(self, pipeline_name, base_model_path, adaface_ckpt_path, device,
                 subject_string='z', num_vectors=16,
                 num_inference_steps=50, negative_prompt=None,
                 use_840k_vae=False, use_ds_text_encoder=False, is_training=False):
        '''
        pipeline_name: "text2img" or "img2img" or None. If None, the unet and vae are
            removed from the pipeline to release RAM.
        base_model_path: a single checkpoint file (loaded with from_single_file) or
            anything else accepted by from_pretrained.
        adaface_ckpt_path: checkpoint containing "string_to_subj_basis_generator_dict".
        device: device the models are moved to.
        subject_string: key of the subject basis generator in the checkpoint, and the
            prefix of the placeholder tokens.
        num_vectors: number of placeholder tokens added to the tokenizer.
        num_inference_steps: number of denoising steps passed to the pipeline in forward().
        negative_prompt: default negative prompt. If None, a built-in one is used.
        use_840k_vae: replace the pipeline VAE with the 840k-step VAE.
        use_ds_text_encoder: replace the pipeline text encoder with the DreamShaper one.
        is_training: put the subject basis generator in train mode instead of eval mode.
        '''
        super().__init__()
        self.pipeline_name = pipeline_name
        self.base_model_path = base_model_path
        self.adaface_ckpt_path = adaface_ckpt_path
        self.use_840k_vae = use_840k_vae
        self.use_ds_text_encoder = use_ds_text_encoder
        self.subject_string = subject_string
        self.num_vectors = num_vectors
        self.num_inference_steps = num_inference_steps
        self.device = device
        self.is_training = is_training

        self.initialize_pipeline()
        self.extend_tokenizer_and_text_encoder()

        if negative_prompt is None:
            negative_prompt = (
                "flaws in the eyes, flaws in the face, lowres, non-HDRi, low quality, worst quality, artifacts, noise, text, watermark, glitch, "
                "mutated, ugly, disfigured, hands, partially rendered objects, partially rendered eyes, deformed eyeballs, cross-eyed, blurry, "
                "mutation, duplicate, out of frame, cropped, mutilated, bad anatomy, deformed, bad proportions, "
                "nude, naked, nsfw, topless, bare breasts"
            )
        self.negative_prompt = negative_prompt

    def load_subj_basis_generator(self, adaface_ckpt_path):
        """Load the subject basis generator for `self.subject_string` from the checkpoint.

        Raises:
            KeyError: if the checkpoint has no generator for `self.subject_string`.
        """
        # SECURITY: torch.load unpickles arbitrary Python objects. The checkpoint stores
        # whole module objects (not just tensors), so weights_only=False is required
        # (it is no longer the default since PyTorch 2.6). Only load trusted checkpoints.
        ckpt = torch.load(adaface_ckpt_path, map_location='cpu', weights_only=False)
        string_to_subj_basis_generator_dict = ckpt["string_to_subj_basis_generator_dict"]
        if self.subject_string not in string_to_subj_basis_generator_dict:
            # Fail loudly instead of dropping into an interactive debugger.
            raise KeyError(
                f"Subject '{self.subject_string}' not found in the embedding manager. "
                f"Available subjects: {list(string_to_subj_basis_generator_dict)}")

        self.subj_basis_generator = string_to_subj_basis_generator_dict[self.subject_string]
        # In the original ckpt, num_out_layers is 16 for layerwise embeddings.
        # But we don't do layerwise embeddings here, so we set it to 1.
        self.subj_basis_generator.num_out_layers = 1
        print(f"Loaded subject basis generator for '{self.subject_string}'.")
        print(repr(self.subj_basis_generator))
        self.subj_basis_generator.to(self.device)
        if self.is_training:
            self.subj_basis_generator.train()
        else:
            self.subj_basis_generator.eval()

    def initialize_pipeline(self):
        """Load every model component and assemble `self.pipeline`.

        Raises:
            ValueError: if `self.pipeline_name` is not "img2img", "text2img" or None.
        """
        # Validate the pipeline name first, so that a typo fails before any model is loaded.
        remove_unet = False
        if self.pipeline_name == "img2img":
            PipelineClass = StableDiffusionImg2ImgPipeline
        elif self.pipeline_name == "text2img":
            PipelineClass = StableDiffusionPipeline
        # pipeline_name is None means only use this instance to generate adaface embeddings, not to generate images.
        elif self.pipeline_name is None:
            PipelineClass = StableDiffusionPipeline
            remove_unet = True
        else:
            raise ValueError(f"Unknown pipeline name: {self.pipeline_name}")

        self.load_subj_basis_generator(self.adaface_ckpt_path)

        # arc2face_text_encoder maps the face analysis embedding to 16 face embeddings
        # in the UNet image space.
        arc2face_text_encoder = CLIPTextModelWrapper.from_pretrained(
            'models/arc2face', subfolder="encoder", torch_dtype=torch.float16
        )
        self.arc2face_text_encoder = arc2face_text_encoder.to(self.device)

        if self.use_840k_vae:
            # The 840000-step vae model is slightly better in face details than the original vae model.
            # https://huggingface.co/stabilityai/sd-vae-ft-mse-original
            vae = AutoencoderKL.from_single_file(
                "models/diffusers/sd-vae-ft-mse-original/vae-ft-mse-840000-ema-pruned.ckpt",
                torch_dtype=torch.float16)
        else:
            vae = None

        if self.use_ds_text_encoder:
            # The dreamshaper v7 finetuned text encoder follows the prompt slightly better than the original text encoder.
            # https://huggingface.co/Lykon/DreamShaper/tree/main/text_encoder
            text_encoder = CLIPTextModel.from_pretrained("models/ds_text_encoder", torch_dtype=torch.float16)
        else:
            text_encoder = None

        # A path to an existing file is treated as a single-file checkpoint;
        # anything else is handed to from_pretrained.
        if os.path.isfile(self.base_model_path):
            pipeline = PipelineClass.from_single_file(
                self.base_model_path,
                torch_dtype=torch.float16
            )
        else:
            pipeline = PipelineClass.from_pretrained(
                self.base_model_path,
                torch_dtype=torch.float16,
                safety_checker=None
            )
        print(f"Loaded pipeline from {self.base_model_path}.")

        if self.use_840k_vae:
            pipeline.vae = vae
            print("Replaced the VAE with the 840k-step VAE.")
        if self.use_ds_text_encoder:
            pipeline.text_encoder = text_encoder
            print("Replaced the text encoder with the DreamShaper text encoder.")

        if remove_unet:
            # Remove unet and vae to release RAM. Only keep tokenizer and text_encoder.
            pipeline.unet = None
            pipeline.vae = None
            print("Removed UNet and VAE from the pipeline.")

        noise_scheduler = DDIMScheduler(
            num_train_timesteps=1000,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
            steps_offset=1,
        )
        pipeline.scheduler = noise_scheduler
        self.pipeline = pipeline.to(self.device)

        # FaceAnalysis will try to find the ckpt in: models/insightface/models/antelopev2.
        # Note there's a second "model" in the path.
        self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface',
                                     providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        self.face_app.prepare(ctx_id=0, det_size=(512, 512))

        # Patch the missing tokenizer in the subj_basis_generator.
        if not hasattr(self.subj_basis_generator, 'clip_tokenizer'):
            self.subj_basis_generator.clip_tokenizer = self.pipeline.tokenizer
            print("Patched the missing tokenizer in the subj_basis_generator.")

    def extend_tokenizer_and_text_encoder(self):
        """Add the placeholder tokens to the tokenizer and grow the text encoder's embedding table.

        Raises:
            ValueError: if `num_vectors` < 1, or if not all placeholder tokens were newly
                added to the tokenizer.
        """
        if self.num_vectors < 1:
            raise ValueError(f"num_vectors has to be larger or equal to 1, but is {self.num_vectors}")

        tokenizer = self.pipeline.tokenizer

        # Add z0, z1, z2, ..., z15.
        self.placeholder_tokens = [f"{self.subject_string}_{i}" for i in range(self.num_vectors)]
        self.placeholder_tokens_str = " ".join(self.placeholder_tokens)

        # Add the new tokens to the tokenizer.
        num_added_tokens = tokenizer.add_tokens(self.placeholder_tokens)
        if num_added_tokens != self.num_vectors:
            raise ValueError(
                f"The tokenizer already contains the token {self.subject_string}. Please pass a different"
                " `subject_string` that is not already in the tokenizer.")

        print(f"Added {num_added_tokens} tokens ({self.placeholder_tokens_str}) to the tokenizer.")

        # placeholder_token_ids: [49408, ..., 49423].
        self.placeholder_token_ids = tokenizer.convert_tokens_to_ids(self.placeholder_tokens)
        # Resize the token embeddings as we are adding new special tokens to the tokenizer
        old_weight = self.pipeline.text_encoder.get_input_embeddings().weight
        self.pipeline.text_encoder.resize_token_embeddings(len(tokenizer))
        new_weight = self.pipeline.text_encoder.get_input_embeddings().weight
        print(f"Resized text encoder token embeddings from {old_weight.shape} to {new_weight.shape} on {new_weight.device}.")

    # Extend pipeline.text_encoder with the adaface subject embeddings.
    # subj_embs: [16, 768].
    def update_text_encoder_subj_embs(self, subj_embs):
        """Write row i of `subj_embs` into the embedding of the i-th placeholder token."""
        token_embeds = self.pipeline.text_encoder.get_input_embeddings().weight.data
        with torch.no_grad():
            for i, token_id in enumerate(self.placeholder_token_ids):
                token_embeds[token_id] = subj_embs[i]
        print(f"Updated {len(self.placeholder_token_ids)} tokens ({self.placeholder_tokens_str}) in the text encoder.")

    def update_prompt(self, prompt):
        """Return `prompt` with the subject string expanded to the placeholder tokens.

        - If the full placeholder token string is already in the prompt, it is returned as is.
        - If the subject string does not occur as a whole word, the placeholder tokens
          are prepended to the prompt.
        - Otherwise every whole-word occurrence of the subject string is replaced.
        """
        # If the placeholder tokens are already in the prompt, then return the prompt as is.
        if self.placeholder_tokens_str in prompt:
            return prompt

        # Escape the subject string so that regex metacharacters in it are matched literally.
        subj_pattern = re.compile(r'\b' + re.escape(self.subject_string) + r'\b')

        # If the subject string 'z' is not in the prompt, then simply prepend the placeholder tokens to the prompt.
        if subj_pattern.search(prompt) is None:
            print(f"Subject string '{self.subject_string}' not found in the prompt. Adding it.")
            return self.placeholder_tokens_str + " " + prompt

        # Replace the subject string 'z' with the placeholder tokens.
        # A callable replacement keeps backslashes in the placeholder string literal.
        return subj_pattern.sub(lambda _match: self.placeholder_tokens_str, prompt)

    # image_paths: a list of image paths. image_folder: the parent folder name.
    def generate_adaface_embeddings(self, image_paths, image_folder=None,
                                    pre_face_embs=None, gen_rand_face=False,
                                    out_id_embs_scale=1., noise_level=0, update_text_encoder=True):
        """Compute the AdaFace subject embeddings and optionally install them in the text encoder.

        Returns the squeezed subject embeddings, or None if no face image was found
        (face_image_count == 0).
        """
        # faceid_embeds is a batch of extracted face analysis embeddings (BS * 512 = id_batch_size * 512).
        # If extract_faceid_embeds is True, faceid_embeds is *the same* embedding repeated by id_batch_size times.
        # Otherwise, faceid_embeds is a batch of random embeddings, each instance is different.
        # The same applies to id_prompt_emb.
        # faceid_embeds is in the face analysis embeddings. id_prompt_emb is in the image prompt space.
        # Here id_batch_size = 1, so
        # faceid_embeds: [1, 512]. NOT used later.
        # id_prompt_emb: [1, 16, 768].
        # NOTE: Since return_core_id_embs is True, id_prompt_emb is only the 16 core ID embeddings.
        # arc2face prompt template: "photo of a id person"
        # ID embeddings start from "id person ...". So there are 3 template tokens before the 16 ID embeddings.
        face_image_count, faceid_embeds, id_prompt_emb \
            = get_arc2face_id_prompt_embs(self.face_app, self.pipeline.tokenizer, self.arc2face_text_encoder,
                                          extract_faceid_embeds=not gen_rand_face,
                                          pre_face_embs=pre_face_embs,
                                          # image_folder is passed only for logging purpose.
                                          # image_paths contains the paths of the images.
                                          image_folder=image_folder, image_paths=image_paths,
                                          images_np=None,
                                          id_batch_size=1,
                                          device=self.device,
                                          # input_max_length == 22: only keep the first 22 tokens,
                                          # including 3 template tokens and 16 ID tokens, and BOS and EOS tokens.
                                          # The results are indistinguishable from input_max_length=77.
                                          input_max_length=22,
                                          noise_level=noise_level,
                                          return_core_id_embs=True,
                                          gen_neg_prompt=False,
                                          verbose=True)

        if face_image_count == 0:
            return None

        # adaface_subj_embs: [1, 1, 16, 768].
        # adaface_prompt_embs: [1, 77, 768] (not used).
        adaface_subj_embs, adaface_prompt_embs = \
            self.subj_basis_generator(id_prompt_emb, None, None,
                                      out_id_embs_scale=out_id_embs_scale,
                                      is_face=True, is_training=False,
                                      adaface_prompt_embs_inf_type='full_half_pad')
        # adaface_subj_embs: [16, 768]
        # NOTE(review): a bare squeeze() would also drop the embedding-count dimension if the
        # generator ever produced a single embedding — confirm it always yields more than one.
        adaface_subj_embs = adaface_subj_embs.squeeze()
        if update_text_encoder:
            self.update_text_encoder_subj_embs(adaface_subj_embs)
        return adaface_subj_embs

    def encode_prompt(self, prompt, negative_prompt=None, device="cuda", verbose=False):
        """Encode `prompt` (after placeholder expansion) and the negative prompt.

        Returns (prompt_embeds, negative_prompt_embeds) from the pipeline's encode_prompt.
        If `negative_prompt` is None, `self.negative_prompt` is used.
        """
        if negative_prompt is None:
            negative_prompt = self.negative_prompt

        prompt = self.update_prompt(prompt)
        if verbose:
            print(f"Prompt: {prompt}")

        # For some unknown reason, the text_encoder is still on CPU after self.pipeline.to(self.device).
        # So we manually move it to GPU here.
        self.pipeline.text_encoder.to(device)
        # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
        prompt_embeds_, negative_prompt_embeds_ = \
            self.pipeline.encode_prompt(prompt, device=device, num_images_per_prompt=1,
                                        do_classifier_free_guidance=True, negative_prompt=negative_prompt)
        return prompt_embeds_, negative_prompt_embeds_

    # ref_img_strength is used only in the img2img pipeline.
    def forward(self, noise, prompt, negative_prompt=None, guidance_scale=4.0,
                out_image_count=4, ref_img_strength=0.8, generator=None, verbose=False):
        """Generate `out_image_count` images for `prompt`; returns the pipeline's `.images`.

        Raises:
            RuntimeError: if the pipeline has no UNet (the wrapper was created with
                pipeline_name=None, i.e. for embedding generation only).
        """
        if getattr(self.pipeline, "unet", None) is None:
            raise RuntimeError(
                "This AdaFaceWrapper has no UNet (pipeline_name=None), so it cannot generate images.")

        if negative_prompt is None:
            negative_prompt = self.negative_prompt
        # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
        prompt_embeds_, negative_prompt_embeds_ = self.encode_prompt(prompt, negative_prompt,
                                                                     device=self.device, verbose=verbose)

        # Repeat the prompt embeddings for all images in the batch.
        prompt_embeds_ = prompt_embeds_.repeat(out_image_count, 1, 1)
        negative_prompt_embeds_ = negative_prompt_embeds_.repeat(out_image_count, 1, 1)
        noise = noise.to(self.device).to(torch.float16)

        # noise: [BS, 4, 64, 64]
        # When the pipeline is text2img, strength is ignored.
        # NOTE(review): the text2img pipeline presumably ignores `image` as well, in which case
        # the supplied noise is unused there (it would have to be passed as `latents`) — confirm.
        images = self.pipeline(image=noise,
                               prompt_embeds=prompt_embeds_,
                               negative_prompt_embeds=negative_prompt_embeds_,
                               num_inference_steps=self.num_inference_steps,
                               guidance_scale=guidance_scale,
                               num_images_per_prompt=1,
                               strength=ref_img_strength,
                               generator=generator).images
        # images: [BS, 3, 512, 512]
        return images