Upload folder using huggingface_hub
- .gitattributes +35 -35
- README.md +12 -12
- adaface-infer.py +131 -0
- adaface-translate.py +208 -0
- adaface/__pycache__/adaface_wrapper.cpython-312.pyc +0 -0
- adaface/__pycache__/adaface_wrapper.cpython-38.pyc +0 -0
- adaface/__pycache__/arc2face_models.cpython-312.pyc +0 -0
- adaface/__pycache__/arc2face_models.cpython-38.pyc +0 -0
- adaface/__pycache__/subj_basis_generator.cpython-312.pyc +0 -0
- adaface/__pycache__/subj_basis_generator.cpython-38.pyc +0 -0
- adaface/__pycache__/util.cpython-312.pyc +0 -0
- adaface/__pycache__/util.cpython-38.pyc +0 -0
- adaface/adaface-infer.py +131 -0
- adaface/adaface-translate.py +208 -0
- adaface/adaface_wrapper.py +297 -0
- adaface/arc2face_models.py +303 -0
- adaface/subj_basis_generator.py +758 -0
- adaface/util.py +342 -0
- adaface_wrapper.py +297 -0
- app.py +203 -0
- arc2face_models.py +303 -0
- models/adaface/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt +3 -0
- models/arc2face/arc2face/config.json +67 -0
- models/arc2face/arc2face/diffusion_pytorch_model.safetensors +3 -0
- models/arc2face/encoder/config.json +24 -0
- models/arc2face/encoder/pytorch_model.bin +3 -0
- models/insightface/models/antelopev2/1k3d68.onnx +3 -0
- models/insightface/models/antelopev2/2d106det.onnx +3 -0
- models/insightface/models/antelopev2/arcface.onnx +3 -0
- models/insightface/models/antelopev2/genderage.onnx +3 -0
- models/insightface/models/antelopev2/scrfd_10g_bnkps.onnx +3 -0
- models/insightface/models/buffalo_l/1k3d68.onnx +3 -0
- models/insightface/models/buffalo_l/2d106det.onnx +3 -0
- models/insightface/models/buffalo_l/det_10g.onnx +3 -0
- models/insightface/models/buffalo_l/genderage.onnx +3 -0
- models/insightface/models/buffalo_l/w600k_r50.onnx +3 -0
- models/sar/sar.safetensors +3 -0
- requirements.txt +12 -0
- subj_basis_generator.py +758 -0
- util.py +342 -0
.gitattributes
CHANGED
@@ -1,35 +1,35 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -1,12 +1,12 @@
----
-title: Adaface
-emoji:
-colorFrom:
-colorTo:
-sdk: gradio
-sdk_version: 4.
-app_file: app.py
-pinned: false
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+---
+title: Adaface
+emoji: 🌖
+colorFrom: purple
+colorTo: red
+sdk: gradio
+sdk_version: 4.37.2
+app_file: app.py
+pinned: false
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
adaface-infer.py
ADDED
@@ -0,0 +1,131 @@
from adaface.adaface_wrapper import AdaFaceWrapper
import torch
#import torch.nn.functional as F
from PIL import Image
import numpy as np
import os, argparse, glob, re

def save_images(images, num_images_per_row, subject_name, prompt, noise_level, save_dir = "samples-ada"):
    if num_images_per_row > len(images):
        num_images_per_row = len(images)

    os.makedirs(save_dir, exist_ok=True)

    num_columns = int(np.ceil(len(images) / num_images_per_row))
    # Save 4 images as a grid image in save_dir
    grid_image = Image.new('RGB', (512 * num_images_per_row, 512 * num_columns))
    for i, image in enumerate(images):
        image = image.resize((512, 512))
        grid_image.paste(image, (512 * (i % num_images_per_row), 512 * (i // num_images_per_row)))

    prompt_sig = prompt.replace(" ", "_").replace(",", "_")
    grid_filepath = os.path.join(save_dir, f"{subject_name}-{prompt_sig}-noise{noise_level:.02f}.png")
    if os.path.exists(grid_filepath):
        grid_count = 2
        grid_filepath = os.path.join(save_dir, f'{subject_name}-{prompt_sig}-noise{noise_level:.02f}-{grid_count}.jpg')
        while os.path.exists(grid_filepath):
            grid_count += 1
            grid_filepath = os.path.join(save_dir, f'{subject_name}-{prompt_sig}-noise{noise_level:.02f}-{grid_count}.jpg')

    grid_image.save(grid_filepath)
    print(f"Saved to {grid_filepath}")

def seed_everything(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ["PL_GLOBAL_SEED"] = str(seed)

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_model_path", type=str, default='runwayml/stable-diffusion-v1-5',
                        help="Type of checkpoints to use (default: SD 1.5)")
    parser.add_argument("--embman_ckpt", type=str, required=True,
                        help="Path to the checkpoint of the embedding manager")
    parser.add_argument("--subject", type=str, required=True)
    parser.add_argument("--example_image_count", type=int, default=-1, help="Number of example images to use")
    parser.add_argument("--out_image_count", type=int, default=4, help="Number of images to generate")
    parser.add_argument("--prompt", type=str, default="a woman z in superman costume")
    parser.add_argument("--noise", dest='noise_level', type=float, default=0)
    parser.add_argument("--randface", action="store_true")
    parser.add_argument("--scale", dest='guidance_scale', type=float, default=4,
                        help="Guidance scale for the diffusion model")
    parser.add_argument("--id_cfg_scale", type=float, default=1,
                        help="CFG scale when generating the identity embeddings")

    parser.add_argument("--subject_string",
                        type=str, default="z",
                        help="Subject placeholder string used in prompts to denote the concept.")
    parser.add_argument("--num_vectors", type=int, default=16,
                        help="Number of vectors used to represent the subject.")
    parser.add_argument("--num_images_per_row", type=int, default=4,
                        help="Number of images to display in a row in the output grid image.")
    parser.add_argument("--num_inference_steps", type=int, default=50,
                        help="Number of DDIM inference steps")
    parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on")
    parser.add_argument("--seed", type=int, default=42,
                        help="the seed (for reproducible sampling). Set to -1 to disable.")
    args = parser.parse_args()

    return args

if __name__ == "__main__":
    args = parse_args()
    if args.seed != -1:
        seed_everything(args.seed)

    if re.match(r"^\d+$", args.device):
        args.device = f"cuda:{args.device}"
    print(f"Using device {args.device}")

    adaface = AdaFaceWrapper("text2img", args.base_model_path, args.embman_ckpt, args.device,
                             args.subject_string, args.num_vectors, args.num_inference_steps)

    if not args.randface:
        image_folder = args.subject
        if image_folder.endswith("/"):
            image_folder = image_folder[:-1]

        if os.path.isfile(image_folder):
            # Get the second to the last part of the path
            subject_name = os.path.basename(os.path.dirname(image_folder))
            image_paths = [image_folder]

        else:
            subject_name = os.path.basename(image_folder)
            image_types = ["*.jpg", "*.png", "*.jpeg"]
            alltype_image_paths = []
            for image_type in image_types:
                # glob returns the full path.
                image_paths = glob.glob(os.path.join(image_folder, image_type))
                if len(image_paths) > 0:
                    alltype_image_paths.extend(image_paths)

            # Filter out images of "*_mask.png"
            alltype_image_paths = [image_path for image_path in alltype_image_paths if "_mask.png" not in image_path]

            # image_paths contain at most args.example_image_count full image paths.
            if args.example_image_count > 0:
                image_paths = alltype_image_paths[:args.example_image_count]
            else:
                image_paths = alltype_image_paths
    else:
        subject_name = None
        image_paths = None
        image_folder = None

    subject_name = "randface-" + str(torch.seed()) if args.randface else subject_name
    rand_face_embs = torch.randn(1, 512)

    pre_face_embs = rand_face_embs if args.randface else None
    noise = torch.randn(args.out_image_count, 4, 64, 64).cuda()
    # args.noise_level: the *relative* std of the noise added to the face embeddings.
    # A noise level of 0.08 could change gender, but 0.06 is usually safe.
    # adaface_subj_embs is not used. It is generated for the purpose of updating the text encoder (within this function call).
    adaface_subj_embs = adaface.generate_adaface_embeddings(image_paths, image_folder, pre_face_embs, args.randface,
                                                            out_id_embs_scale=args.id_cfg_scale, noise_level=args.noise_level,
                                                            update_text_encoder=True)
    images = adaface(noise, args.prompt, args.guidance_scale, args.out_image_count, verbose=True)
    save_images(images, args.num_images_per_row, subject_name, f"guide{args.guidance_scale}", args.noise_level)
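
For reference, the script above reduces to three calls on AdaFaceWrapper: build the wrapper, fold the subject's face identity into the text encoder, then sample from latent noise. The sketch below is a minimal programmatic version of that flow under stated assumptions: the subject photo paths are hypothetical placeholders, and it assumes the models/ tree from this upload (including the AdaFace checkpoint listed in the file list) sits in the working directory.

import torch
from adaface.adaface_wrapper import AdaFaceWrapper

device = "cuda"
# Checkpoint shipped in this upload; the subject photos below are hypothetical placeholders.
embman_ckpt = "models/adaface/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt"
image_paths = ["subjects/alice/1.jpg", "subjects/alice/2.jpg"]

# Build a text2img pipeline and register the 16 subject placeholder tokens.
adaface = AdaFaceWrapper("text2img", "runwayml/stable-diffusion-v1-5", embman_ckpt, device,
                         subject_string="z", num_vectors=16, num_inference_steps=50)

# Map the subject's face(s) to 16 embeddings and patch them into the text encoder.
adaface.generate_adaface_embeddings(image_paths, image_folder=None, pre_face_embs=None,
                                    gen_rand_face=False, out_id_embs_scale=1,
                                    noise_level=0, update_text_encoder=True)

# Sample 4 images from pure latent noise, conditioned on a prompt containing the subject string "z".
noise = torch.randn(4, 4, 64, 64).cuda()
images = adaface(noise, "a woman z in superman costume", 4, 4, verbose=True)
images[0].save("sample-0.png")
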
adaface-translate.py
ADDED
@@ -0,0 +1,208 @@
from adaface.adaface_wrapper import AdaFaceWrapper
import torch
#import torch.nn.functional as F
from PIL import Image
import numpy as np
import os, argparse, glob, re, shutil

def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")

def seed_everything(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ["PL_GLOBAL_SEED"] = str(seed)

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_model_path", type=str, default='models/realisticvision/realisticVisionV40_v40VAE.safetensors',
                        help="Path to the UNet checkpoint (default: RealisticVision 4.0)")
    parser.add_argument("--embman_ckpt", type=str, required=True,
                        help="Path to the checkpoint of the embedding manager")
    parser.add_argument("--in_folder", type=str, required=True, help="Path to the folder containing input images")
    # If True, the input folder contains images of mixed subjects.
    # If False, the input folder contains multiple subfolders, each of which contains images of the same subject.
    parser.add_argument("--is_mix_subj_folder", type=str2bool, const=True, default=False, nargs="?",
                        help="Whether the input folder contains images of mixed subjects")
    parser.add_argument("--max_images_per_subject", type=int, default=5, help="Number of example images used per subject")
    parser.add_argument("--trans_subject_count", type=int, default=-1, help="Number of example images to be translated")
    parser.add_argument("--out_folder", type=str, required=True, help="Path to the folder saving output images")
    parser.add_argument("--out_count_per_input_image", type=int, default=1, help="Number of output images to generate per input image")
    parser.add_argument("--copy_masks", action="store_true", help="Copy the mask images to the output folder")
    parser.add_argument("--noise", dest='noise_level', type=float, default=0)
    parser.add_argument("--scale", dest='guidance_scale', type=float, default=4,
                        help="Guidance scale for the diffusion model")
    parser.add_argument("--ref_img_strength", type=float, default=0.8,
                        help="Strength of the reference image in the output image.")
    parser.add_argument("--subject_string",
                        type=str, default="z",
                        help="Subject placeholder string used in prompts to denote the concept.")
    parser.add_argument("--num_vectors", type=int, default=16,
                        help="Number of vectors used to represent the subject.")
    parser.add_argument("--prompt", type=str, default="a person z")
    parser.add_argument("--num_images_per_row", type=int, default=4,
                        help="Number of images to display in a row in the output grid image.")
    parser.add_argument("--num_inference_steps", type=int, default=50,
                        help="Number of DDIM inference steps")
    parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use. If num_gpus > 1, use accelerate for distributed execution.")
    parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on")
    parser.add_argument("--seed", type=int, default=42,
                        help="the seed (for reproducible sampling). Set to -1 to disable.")
    args = parser.parse_args()

    return args

if __name__ == "__main__":
    args = parse_args()
    if args.seed != -1:
        seed_everything(args.seed)

    # screen -dm -L -Logfile trans_rv4-2.txt accelerate launch --multi_gpu --num_processes=2 scripts/adaface-translate.py
    # --embman_ckpt logs/subjects-celebrity2024-05-16T17-22-46_zero3-ada/checkpoints/embeddings_gs-30000.pt
    # --base_model_path models/realisticvision/realisticVisionV40_v40VAE.safetensors --in_folder /data/username/VGGface2_HQ_masks/
    # --is_mix_subj_folder 0 --out_folder /data/username/VGGface2_HQ_masks_rv4a --copy_masks --num_gpus 2
    if args.num_gpus > 1:
        from accelerate import PartialState
        distributed_state = PartialState()
        args.device = distributed_state.device
        process_index = distributed_state.process_index
    elif re.match(r"^\d+$", args.device):
        args.device = f"cuda:{args.device}"
        distributed_state = None
        process_index = 0

    adaface = AdaFaceWrapper("img2img", args.base_model_path, args.embman_ckpt, args.device,
                             args.subject_string, args.num_vectors, args.num_inference_steps)

    in_folder = args.in_folder
    if os.path.isfile(in_folder):
        subject_folders = [ os.path.dirname(in_folder) ]
        images_by_subject = [[in_folder]]
    else:
        if not args.is_mix_subj_folder:
            in_folders = [in_folder]
        else:
            in_folders = [ os.path.join(in_folder, subfolder) for subfolder in sorted(os.listdir(in_folder)) ]

        images_by_subject = []
        subject_folders = []
        for in_folder in in_folders:
            image_types = ["*.jpg", "*.png", "*.jpeg"]
            alltype_image_paths = []
            for image_type in image_types:
                # glob returns the full path.
                image_paths = glob.glob(os.path.join(in_folder, image_type))
                if len(image_paths) > 0:
                    alltype_image_paths.extend(image_paths)

            # Filter out images of "*_mask.png"
            alltype_image_paths = [image_path for image_path in alltype_image_paths if "_mask.png" not in image_path]
            alltype_image_paths = sorted(alltype_image_paths)

            if not args.is_mix_subj_folder:
                # image_paths contain at most args.max_images_per_subject full image paths.
                if args.max_images_per_subject > 0:
                    image_paths = alltype_image_paths[:args.max_images_per_subject]
                else:
                    image_paths = alltype_image_paths

                images_by_subject.append(image_paths)
                subject_folders.append(in_folder)
            else:
                # Each image in the folder is treated as an individual subject.
                images_by_subject.extend([[image_path] for image_path in alltype_image_paths])
                subject_folders.extend([in_folder] * len(alltype_image_paths))

            if args.trans_subject_count > 0 and len(subject_folders) >= args.trans_subject_count:
                break

    if args.trans_subject_count > 0:
        images_by_subject = images_by_subject[:args.trans_subject_count]
        subject_folders = subject_folders[:args.trans_subject_count]

    out_image_count = 0
    out_mask_count = 0
    if not args.out_folder.endswith("/"):
        args.out_folder += "/"

    if args.num_gpus > 1:
        # Split the subjects across the GPUs.
        subject_folders = subject_folders[process_index::args.num_gpus]
        images_by_subject = images_by_subject[process_index::args.num_gpus]
        #subject_folders, images_by_subject = distributed_state.split_between_processes(zip(subject_folders, images_by_subject))

    for (subject_folder, image_paths) in zip(subject_folders, images_by_subject):
        # If is_mix_subj_folder, then image_paths only contains 1 image, and we use the file name as the signature of the image.
        # Otherwise, we use the folder name as the signature of the images.
        images_sig = subject_folder if not args.is_mix_subj_folder else os.path.basename(image_paths[0])

        print(f"Translating {images_sig}...")
        with torch.no_grad():
            adaface_subj_embs = adaface.generate_adaface_embeddings(image_paths, subject_folder, None, False,
                                                                    out_id_embs_scale=1, noise_level=args.noise_level,
                                                                    update_text_encoder=True)

        # Replace the first occurrence of "in_folder" with "out_folder" in the path of the subject_folder.
        subject_out_folder = subject_folder.replace(args.in_folder, args.out_folder, 1)
        if not os.path.exists(subject_out_folder):
            os.makedirs(subject_out_folder)
        print(f"Output images will be saved to {subject_out_folder}")

        in_images = []
        for image_path in image_paths:
            image = Image.open(image_path).convert("RGB").resize((512, 512))
            # [512, 512, 3] -> [3, 512, 512].
            image = np.array(image).transpose(2, 0, 1)
            # Convert the image to a tensor of shape (1, 3, 512, 512) and move it to the GPU.
            image = torch.tensor(image).unsqueeze(0).float().cuda()
            in_images.append(image)

        # Put all input images of the subject into a batch. This assumes max_images_per_subject is small.
        # NOTE: For simplicity, we do not check overly large batch sizes.
        in_images = torch.cat(in_images, dim=0)
        # in_images: [5, 3, 512, 512].
        # Normalize the pixel values to [0, 1].
        in_images = in_images / 255.0
        num_out_images = len(in_images) * args.out_count_per_input_image

        with torch.no_grad():
            # args.noise_level: the *relative* std of the noise added to the face embeddings.
            # A noise level of 0.08 could change gender, but 0.06 is usually safe.
            # The returned adaface_subj_embs are already incorporated in the text encoder, and not used explicitly.
            # NOTE: We assume out_count_per_input_image == 1, so that the output images are of the same number as the input images.
            out_images = adaface(in_images, args.prompt, args.guidance_scale, num_out_images, ref_img_strength=args.ref_img_strength)

        for img_i, img in enumerate(out_images):
            # out_images: subj_1, subj_2, ..., subj_n, subj_1, subj_2, ..., subj_n, ...
            subj_i = img_i % len(in_images)
            copy_i = img_i // len(in_images)
            image_filename_stem, image_fileext = os.path.splitext(os.path.basename(image_paths[subj_i]))
            if copy_i == 0:
                img.save(os.path.join(subject_out_folder, f"{image_filename_stem}{image_fileext}"))
            else:
                img.save(os.path.join(subject_out_folder, f"{image_filename_stem}_{copy_i}{image_fileext}"))

            if args.copy_masks:
                mask_path = image_paths[subj_i].replace(image_fileext, "_mask.png")
                if os.path.exists(mask_path):
                    if copy_i == 0:
                        shutil.copy(mask_path, subject_out_folder)
                    else:
                        mask_filename_stem = image_filename_stem
                        shutil.copy(mask_path, os.path.join(subject_out_folder, f"{mask_filename_stem}_{copy_i}_mask.png"))

                    out_mask_count += 1

        out_image_count += len(out_images)

    print(f"{out_image_count} output images and {out_mask_count} masks saved to {args.out_folder}")
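
The multi-GPU branch above shards work by simple round-robin slicing on the process index (accelerate's split_between_processes is left commented out). The following is a minimal, self-contained sketch of that pattern, with a toy subject list standing in for the real subject_folders and images_by_subject; it assumes accelerate is installed and the script is started with accelerate launch for the multi-process case.

# Minimal sketch of the round-robin sharding used in adaface-translate.py.
from accelerate import PartialState

distributed_state = PartialState()          # One instance per launched process.
device = distributed_state.device           # e.g. cuda:0, cuda:1, ...
rank = distributed_state.process_index
world_size = distributed_state.num_processes

# Toy stand-in for subject_folders / images_by_subject in the real script.
all_subjects = [f"subject_{i:03d}" for i in range(10)]
my_subjects = all_subjects[rank::world_size]   # Each process handles a disjoint slice.

for subject in my_subjects:
    print(f"[rank {rank} on {device}] translating {subject}")
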
adaface/__pycache__/adaface_wrapper.cpython-312.pyc
ADDED
Binary file (13.5 kB)
adaface/__pycache__/adaface_wrapper.cpython-38.pyc
ADDED
Binary file (8.03 kB)
adaface/__pycache__/arc2face_models.cpython-312.pyc
ADDED
Binary file (16.1 kB)
adaface/__pycache__/arc2face_models.cpython-38.pyc
ADDED
Binary file (7 kB)
adaface/__pycache__/subj_basis_generator.cpython-312.pyc
ADDED
Binary file (30.1 kB)
adaface/__pycache__/subj_basis_generator.cpython-38.pyc
ADDED
Binary file (17.6 kB)
adaface/__pycache__/util.cpython-312.pyc
ADDED
Binary file (14 kB)
adaface/__pycache__/util.cpython-38.pyc
ADDED
Binary file (8.57 kB)
adaface/adaface-infer.py
ADDED
@@ -0,0 +1,131 @@
(131 added lines, identical to adaface-infer.py at the repository root, shown above.)
adaface/adaface-translate.py
ADDED
@@ -0,0 +1,208 @@
(208 added lines, identical to adaface-translate.py at the repository root, shown above.)
adaface/adaface_wrapper.py
ADDED
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from transformers import CLIPTextModel
|
4 |
+
from diffusers import (
|
5 |
+
StableDiffusionPipeline,
|
6 |
+
StableDiffusionImg2ImgPipeline,
|
7 |
+
UNet2DConditionModel,
|
8 |
+
DDIMScheduler,
|
9 |
+
AutoencoderKL,
|
10 |
+
)
|
11 |
+
from insightface.app import FaceAnalysis
|
12 |
+
from adaface.arc2face_models import CLIPTextModelWrapper
|
13 |
+
from adaface.util import get_arc2face_id_prompt_embs
|
14 |
+
import re, os
|
15 |
+
import sys
|
16 |
+
sys.modules['ldm'] = sys.modules['adaface']
|
17 |
+
|
18 |
+
class AdaFaceWrapper(nn.Module):
|
19 |
+
def __init__(self, pipeline_name, base_model_path, adaface_ckpt_path, device,
|
20 |
+
subject_string='z', num_vectors=16,
|
21 |
+
num_inference_steps=50, negative_prompt=None,
|
22 |
+
use_840k_vae=False, use_ds_text_encoder=False, is_training=False):
|
23 |
+
'''
|
24 |
+
pipeline_name: "text2img" or "img2img" or None. If None, the unet and vae are
|
25 |
+
removed from the pipeline to release RAM.
|
26 |
+
'''
|
27 |
+
super().__init__()
|
28 |
+
self.pipeline_name = pipeline_name
|
29 |
+
self.base_model_path = base_model_path
|
30 |
+
self.adaface_ckpt_path = adaface_ckpt_path
|
31 |
+
self.use_840k_vae = use_840k_vae
|
32 |
+
self.use_ds_text_encoder = use_ds_text_encoder
|
33 |
+
self.subject_string = subject_string
|
34 |
+
self.num_vectors = num_vectors
|
35 |
+
self.num_inference_steps = num_inference_steps
|
36 |
+
self.device = device
|
37 |
+
self.is_training = is_training
|
38 |
+
self.initialize_pipeline()
|
39 |
+
self.extend_tokenizer_and_text_encoder()
|
40 |
+
if negative_prompt is None:
|
41 |
+
self.negative_prompt = \
|
42 |
+
"flaws in the eyes, flaws in the face, lowres, non-HDRi, low quality, worst quality, artifacts, noise, text, watermark, glitch, " \
|
43 |
+
"mutated, ugly, disfigured, hands, partially rendered objects, partially rendered eyes, deformed eyeballs, cross-eyed, blurry, " \
|
44 |
+
"mutation, duplicate, out of frame, cropped, mutilated, bad anatomy, deformed, bad proportions, " \
|
45 |
+
"nude, naked, nsfw, topless, bare breasts"
|
46 |
+
else:
|
47 |
+
self.negative_prompt = negative_prompt
|
48 |
+
|
49 |
+
def load_subj_basis_generator(self, adaface_ckpt_path):
|
50 |
+
ckpt = torch.load(adaface_ckpt_path, map_location='cpu')
|
51 |
+
string_to_subj_basis_generator_dict = ckpt["string_to_subj_basis_generator_dict"]
|
52 |
+
if self.subject_string not in string_to_subj_basis_generator_dict:
|
53 |
+
print(f"Subject '{self.subject_string}' not found in the embedding manager.")
|
54 |
+
breakpoint()
|
55 |
+
|
56 |
+
self.subj_basis_generator = string_to_subj_basis_generator_dict[self.subject_string]
|
57 |
+
# In the original ckpt, num_out_layers is 16 for layerwise embeddings.
|
58 |
+
# But we don't do layerwise embeddings here, so we set it to 1.
|
59 |
+
self.subj_basis_generator.num_out_layers = 1
|
60 |
+
print(f"Loaded subject basis generator for '{self.subject_string}'.")
|
61 |
+
print(repr(self.subj_basis_generator))
|
62 |
+
self.subj_basis_generator.to(self.device)
|
63 |
+
if self.is_training:
|
64 |
+
self.subj_basis_generator.train()
|
65 |
+
else:
|
66 |
+
self.subj_basis_generator.eval()
|
67 |
+
|
68 |
+
def initialize_pipeline(self):
|
69 |
+
self.load_subj_basis_generator(self.adaface_ckpt_path)
|
70 |
+
# arc2face_text_encoder maps the face analysis embedding to 16 face embeddings
|
71 |
+
# in the UNet image space.
|
72 |
+
arc2face_text_encoder = CLIPTextModelWrapper.from_pretrained(
|
73 |
+
'models/arc2face', subfolder="encoder", torch_dtype=torch.float16
|
74 |
+
)
|
75 |
+
self.arc2face_text_encoder = arc2face_text_encoder.to(self.device)
|
76 |
+
|
77 |
+
if self.use_840k_vae:
|
78 |
+
# The 840000-step vae model is slightly better in face details than the original vae model.
|
79 |
+
# https://huggingface.co/stabilityai/sd-vae-ft-mse-original
|
80 |
+
vae = AutoencoderKL.from_single_file("models/diffusers/sd-vae-ft-mse-original/vae-ft-mse-840000-ema-pruned.ckpt", torch_dtype=torch.float16)
|
81 |
+
else:
|
82 |
+
vae = None
|
83 |
+
|
84 |
+
if self.use_ds_text_encoder:
|
85 |
+
# The dreamshaper v7 finetuned text encoder follows the prompt slightly better than the original text encoder.
|
86 |
+
# https://huggingface.co/Lykon/DreamShaper/tree/main/text_encoder
|
87 |
+
text_encoder = CLIPTextModel.from_pretrained("models/ds_text_encoder", torch_dtype=torch.float16)
|
88 |
+
else:
|
89 |
+
text_encoder = None
|
90 |
+
|
91 |
+
remove_unet = False
|
92 |
+
|
93 |
+
if self.pipeline_name == "img2img":
|
94 |
+
PipelineClass = StableDiffusionImg2ImgPipeline
|
95 |
+
elif self.pipeline_name == "text2img":
|
96 |
+
PipelineClass = StableDiffusionPipeline
|
97 |
+
# pipeline_name is None means only use this instance to generate adaface embeddings, not to generate images.
|
98 |
+
elif self.pipeline_name is None:
|
99 |
+
PipelineClass = StableDiffusionPipeline
|
100 |
+
remove_unet = True
|
101 |
+
else:
|
102 |
+
raise ValueError(f"Unknown pipeline name: {self.pipeline_name}")
|
103 |
+
|
104 |
+
if os.path.isfile(self.base_model_path):
|
105 |
+
pipeline = PipelineClass.from_single_file(
|
106 |
+
self.base_model_path,
|
107 |
+
torch_dtype=torch.float16
|
108 |
+
)
|
109 |
+
else:
|
110 |
+
pipeline = PipelineClass.from_pretrained(
|
111 |
+
self.base_model_path,
|
112 |
+
torch_dtype=torch.float16,
|
113 |
+
safety_checker=None
|
114 |
+
)
|
115 |
+
print(f"Loaded pipeline from {self.base_model_path}.")
|
116 |
+
|
117 |
+
if self.use_840k_vae:
|
118 |
+
pipeline.vae = vae
|
119 |
+
print("Replaced the VAE with the 840k-step VAE.")
|
120 |
+
|
121 |
+
if self.use_ds_text_encoder:
|
122 |
+
pipeline.text_encoder = text_encoder
|
123 |
+
print("Replaced the text encoder with the DreamShaper text encoder.")
|
124 |
+
|
125 |
+
if remove_unet:
|
126 |
+
# Remove unet and vae to release RAM. Only keep tokenizer and text_encoder.
|
127 |
+
pipeline.unet = None
|
128 |
+
pipeline.vae = None
|
129 |
+
print("Removed UNet and VAE from the pipeline.")
|
130 |
+
|
131 |
+
noise_scheduler = DDIMScheduler(
|
132 |
+
num_train_timesteps=1000,
|
133 |
+
beta_start=0.00085,
|
134 |
+
beta_end=0.012,
|
135 |
+
beta_schedule="scaled_linear",
|
136 |
+
clip_sample=False,
|
137 |
+
set_alpha_to_one=False,
|
138 |
+
steps_offset=1,
|
139 |
+
)
|
140 |
+
|
141 |
+
pipeline.scheduler = noise_scheduler
|
142 |
+
self.pipeline = pipeline.to(self.device)
|
143 |
+
# FaceAnalysis will try to find the ckpt in: models/insightface/models/antelopev2.
|
144 |
+
# Note there's a second "model" in the path.
|
145 |
+
self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
146 |
+
self.face_app.prepare(ctx_id=0, det_size=(512, 512))
|
147 |
+
# Patch the missing tokenizer in the subj_basis_generator.
|
148 |
+
if not hasattr(self.subj_basis_generator, 'clip_tokenizer'):
|
149 |
+
self.subj_basis_generator.clip_tokenizer = self.pipeline.tokenizer
|
150 |
+
print("Patched the missing tokenizer in the subj_basis_generator.")
|
151 |
+
|
152 |
+
def extend_tokenizer_and_text_encoder(self):
|
153 |
+
if self.num_vectors < 1:
|
154 |
+
raise ValueError(f"num_vectors has to be larger or equal to 1, but is {self.num_vectors}")
|
155 |
+
|
156 |
+
tokenizer = self.pipeline.tokenizer
|
157 |
+
# Add z0, z1, z2, ..., z15.
|
158 |
+
self.placeholder_tokens = []
|
159 |
+
for i in range(0, self.num_vectors):
|
160 |
+
self.placeholder_tokens.append(f"{self.subject_string}_{i}")
|
161 |
+
|
162 |
+
self.placeholder_tokens_str = " ".join(self.placeholder_tokens)
|
163 |
+
|
164 |
+
# Add the new tokens to the tokenizer.
|
165 |
+
num_added_tokens = tokenizer.add_tokens(self.placeholder_tokens)
|
166 |
+
if num_added_tokens != self.num_vectors:
|
167 |
+
raise ValueError(
|
168 |
+
f"The tokenizer already contains the token {self.subject_string}. Please pass a different"
|
169 |
+
" `subject_string` that is not already in the tokenizer.")
|
170 |
+
|
171 |
+
print(f"Added {num_added_tokens} tokens ({self.placeholder_tokens_str}) to the tokenizer.")
|
172 |
+
|
173 |
+
# placeholder_token_ids: [49408, ..., 49423].
|
174 |
+
self.placeholder_token_ids = tokenizer.convert_tokens_to_ids(self.placeholder_tokens)
|
175 |
+
# print(self.placeholder_token_ids)
|
176 |
+
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
177 |
+
old_weight = self.pipeline.text_encoder.get_input_embeddings().weight
|
178 |
+
self.pipeline.text_encoder.resize_token_embeddings(len(tokenizer))
|
179 |
+
new_weight = self.pipeline.text_encoder.get_input_embeddings().weight
|
180 |
+
print(f"Resized text encoder token embeddings from {old_weight.shape} to {new_weight.shape} on {new_weight.device}.")
|
181 |
+
|
    # Extend pipeline.text_encoder with the adaface subject embeddings.
    # subj_embs: [16, 768].
    def update_text_encoder_subj_embs(self, subj_embs):
        # Initialise the newly added placeholder tokens with the adaface subject embeddings.
        token_embeds = self.pipeline.text_encoder.get_input_embeddings().weight.data
        with torch.no_grad():
            for i, token_id in enumerate(self.placeholder_token_ids):
                token_embeds[token_id] = subj_embs[i]
        print(f"Updated {len(self.placeholder_token_ids)} tokens ({self.placeholder_tokens_str}) in the text encoder.")

    def update_prompt(self, prompt):
        # If the placeholder tokens are already in the prompt, then return the prompt as is.
        if self.placeholder_tokens_str in prompt:
            return prompt

        # If the subject string 'z' is not in the prompt, then simply prepend the placeholder tokens to the prompt.
        if re.search(r'\b' + self.subject_string + r'\b', prompt) is None:
            print(f"Subject string '{self.subject_string}' not found in the prompt. Adding it.")
            comp_prompt = self.placeholder_tokens_str + " " + prompt
        else:
            # Replace the subject string 'z' with the placeholder tokens.
            comp_prompt = re.sub(r'\b' + self.subject_string + r'\b', self.placeholder_tokens_str, prompt)
        return comp_prompt

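To make the substitution concrete, here is a tiny standalone sketch (not part of the file; subject string "z" and 16 placeholders assumed, mirroring the defaults above) of what update_prompt() does to a user prompt:

import re

subject_string = "z"
placeholder_tokens_str = " ".join(f"z_{i}" for i in range(16))

prompt = "portrait of a z wearing a suit"
expanded = re.sub(r'\b' + subject_string + r'\b', placeholder_tokens_str, prompt)
# -> "portrait of a z_0 z_1 ... z_15 wearing a suit"
print(expanded)
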
    # image_paths: a list of image paths. image_folder: the parent folder name.
    def generate_adaface_embeddings(self, image_paths, image_folder=None,
                                    pre_face_embs=None, gen_rand_face=False,
                                    out_id_embs_scale=1., noise_level=0, update_text_encoder=True):
        # faceid_embeds is a batch of extracted face analysis embeddings (BS * 512 = id_batch_size * 512).
        # If extract_faceid_embeds is True, faceid_embeds is *the same* embedding repeated id_batch_size times.
        # Otherwise, faceid_embeds is a batch of random embeddings, each instance being different.
        # The same applies to id_prompt_emb.
        # faceid_embeds is in the face analysis embedding space. id_prompt_emb is in the image prompt space.
        # Here id_batch_size = 1, so
        # faceid_embeds: [1, 512]. NOT used later.
        # id_prompt_emb: [1, 16, 768].
        # NOTE: Since return_core_id_embs is True, id_prompt_emb is only the 16 core ID embeddings.
        # arc2face prompt template: "photo of a id person".
        # ID embeddings start from "id person ...". So there are 3 template tokens before the 16 ID embeddings.
        face_image_count, faceid_embeds, id_prompt_emb \
            = get_arc2face_id_prompt_embs(self.face_app, self.pipeline.tokenizer, self.arc2face_text_encoder,
                                          extract_faceid_embeds=not gen_rand_face,
                                          pre_face_embs=pre_face_embs,
                                          # image_folder is passed only for logging purposes.
                                          # image_paths contains the paths of the images.
                                          image_folder=image_folder, image_paths=image_paths,
                                          images_np=None,
                                          id_batch_size=1,
                                          device=self.device,
                                          # input_max_length == 22: only keep the first 22 tokens,
                                          # including 3 template tokens, 16 ID tokens, and the BOS and EOS tokens.
                                          # The results are indistinguishable from input_max_length=77.
                                          input_max_length=22,
                                          noise_level=noise_level,
                                          return_core_id_embs=True,
                                          gen_neg_prompt=False,
                                          verbose=True)

        if face_image_count == 0:
            return None

        # adaface_subj_embs: [1, 1, 16, 768].
        # adaface_prompt_embs: [1, 77, 768] (not used).
        adaface_subj_embs, adaface_prompt_embs = \
            self.subj_basis_generator(id_prompt_emb, None, None,
                                      out_id_embs_scale=out_id_embs_scale,
                                      is_face=True, is_training=False,
                                      adaface_prompt_embs_inf_type='full_half_pad')
        # adaface_subj_embs: [16, 768]
        adaface_subj_embs = adaface_subj_embs.squeeze()
        if update_text_encoder:
            self.update_text_encoder_subj_embs(adaface_subj_embs)
        return adaface_subj_embs

    def encode_prompt(self, prompt, negative_prompt=None, device="cuda", verbose=False):
        if negative_prompt is None:
            negative_prompt = self.negative_prompt

        prompt = self.update_prompt(prompt)
        if verbose:
            print(f"Prompt: {prompt}")

        # For some unknown reason, the text_encoder is still on CPU after self.pipeline.to(self.device).
        # So we manually move it to GPU here.
        self.pipeline.text_encoder.to(device)
        # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
        prompt_embeds_, negative_prompt_embeds_ = \
            self.pipeline.encode_prompt(prompt, device=device, num_images_per_prompt=1,
                                        do_classifier_free_guidance=True, negative_prompt=negative_prompt)
        return prompt_embeds_, negative_prompt_embeds_

    # ref_img_strength is used only in the img2img pipeline.
    def forward(self, noise, prompt, negative_prompt=None, guidance_scale=4.0,
                out_image_count=4, ref_img_strength=0.8, generator=None, verbose=False):
        if negative_prompt is None:
            negative_prompt = self.negative_prompt
        # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
        prompt_embeds_, negative_prompt_embeds_ = self.encode_prompt(prompt, negative_prompt, device=self.device, verbose=verbose)
        # Repeat the prompt embeddings for all images in the batch.
        prompt_embeds_ = prompt_embeds_.repeat(out_image_count, 1, 1)
        negative_prompt_embeds_ = negative_prompt_embeds_.repeat(out_image_count, 1, 1)
        noise = noise.to(self.device).to(torch.float16)

        # noise: [BS, 4, 64, 64]
        # When the pipeline is text2img, strength is ignored.
        images = self.pipeline(image=noise,
                               prompt_embeds=prompt_embeds_,
                               negative_prompt_embeds=negative_prompt_embeds_,
                               num_inference_steps=self.num_inference_steps,
                               guidance_scale=guidance_scale,
                               num_images_per_prompt=1,
                               strength=ref_img_strength,
                               generator=generator).images
        # images: [BS, 3, 512, 512]
        return images

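For orientation, a rough end-to-end usage sketch of this wrapper follows. It is a hedged illustration only: the class name is inferred from the file name, the constructor arguments and image path are hypothetical (see adaface-infer.py for the actual entry point), and only generate_adaface_embeddings() and forward() are taken from the code above.

import torch
from adaface.adaface_wrapper import AdaFaceWrapper   # class name assumed from the file name

# Hypothetical constructor call; consult adaface-infer.py for the real arguments.
adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path=None,
                         adaface_ckpt_path="models/adaface/...", device="cuda")

# Extract the subject's ID embeddings from reference photos and patch the text encoder.
subj_embs = adaface.generate_adaface_embeddings(image_paths=["examples/alice/1.jpg"],
                                                update_text_encoder=True)

# "z" in the prompt is expanded to the 16 placeholder tokens by update_prompt().
noise = torch.randn(4, 4, 64, 64)
images = adaface.forward(noise, "portrait of a z hiking in the mountains", out_image_count=4)
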
adaface/arc2face_models.py
ADDED
@@ -0,0 +1,303 @@
import torch
import torch.nn as nn
from transformers import CLIPTextModel
from transformers.models.clip.modeling_clip import CLIPAttention
from typing import Any, Callable, Dict, Optional, Tuple, Union, List
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
# from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask
_make_causal_mask = AttentionMaskConverter._make_causal_mask
_expand_mask = AttentionMaskConverter._expand_mask

from adaface.util import add_noise_to_tensor

# Extend CLIPAttention by using multiple k_proj and v_proj in each head.
# To avoid too much increase of computation, we don't extend q_proj.
class CLIPAttentionMKV(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config, multiplier=2):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.multiplier = multiplier

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim * self.multiplier)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim * self.multiplier)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    # The (approximately) repeated token features are repeated along the last dim in tensor
    # (multiplier * num_heads * head_dim), and then reshaped to (bsz, -1, num_heads, head_dim).
    # Therefore, the "multiplier" dim is tucked into the seq_len dim, which looks like
    # [token1_emb, token1_emb, token2_emb, token2_emb, ..., tokenN_emb, tokenN_emb].
    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def extend_weights(self, clip_attn_layer, layer_idx, multiplier, noise_std=0.1,
                       noise_std_is_relative=True, keep_norm=False, verbose=False):
        self.multiplier *= multiplier
        # q_proj and out_proj are the same as the original CLIPAttention.
        self.q_proj.weight.data = clip_attn_layer.q_proj.weight.data.clone()
        self.q_proj.bias.data = clip_attn_layer.q_proj.bias.data.clone()
        self.out_proj.weight.data = clip_attn_layer.out_proj.weight.data.clone()
        self.out_proj.bias.data = clip_attn_layer.out_proj.bias.data.clone()

        # The bias doesn't need noise perturbation: after the weights are noised,
        # different copies of the weight/bias will receive different gradients,
        # making the bias terms diverge and become identifiable after training.
        self.v_proj.bias.data = clip_attn_layer.v_proj.bias.data.repeat(multiplier)
        self.k_proj.bias.data = clip_attn_layer.k_proj.bias.data.repeat(multiplier)

        self.v_proj.weight.data = clip_attn_layer.v_proj.weight.data.repeat(multiplier, 1)
        self.k_proj.weight.data = clip_attn_layer.k_proj.weight.data.repeat(multiplier, 1)

        if noise_std > 0:
            ORIG_V_SHAPE = list(clip_attn_layer.v_proj.weight.shape)
            ORIG_V_SHAPE_D0 = ORIG_V_SHAPE[0]
            # Add noise to the extra copies of the weights (keep the first copy unchanged).
            self.v_proj.weight.data[ORIG_V_SHAPE_D0:] = \
                add_noise_to_tensor(self.v_proj.weight.data[ORIG_V_SHAPE_D0:],
                                    noise_std, noise_std_is_relative, keep_norm)
            if verbose:
                NEW_V_SHAPE = list(self.v_proj.weight.shape)
                NOISED_V_SHAPE = list(self.v_proj.weight.data[ORIG_V_SHAPE_D0:].shape)
                print(f"Layer {layer_idx}: {NOISED_V_SHAPE} in {NEW_V_SHAPE} of v_proj is added with {noise_std} noise")

            ORIG_K_SHAPE = list(clip_attn_layer.k_proj.weight.shape)
            ORIG_K_SHAPE_D0 = ORIG_K_SHAPE[0]
            # Add noise to the extra copies of the weights.
            self.k_proj.weight.data[ORIG_K_SHAPE_D0:] = \
                add_noise_to_tensor(self.k_proj.weight.data[ORIG_K_SHAPE_D0:],
                                    noise_std, noise_std_is_relative, keep_norm)
            if verbose:
                NEW_K_SHAPE = list(self.k_proj.weight.shape)
                NOISED_K_SHAPE = list(self.k_proj.weight.data[ORIG_K_SHAPE_D0:].shape)
                print(f"Layer {layer_idx}: {NOISED_K_SHAPE} in {NEW_K_SHAPE} of k_proj is added with {noise_std} noise")

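The interleaved layout described in the comment above _shape() can be checked numerically. Below is a small sketch (not part of the file; toy dimensions assumed) showing that repeating the k_proj weight rows and reshaping with a free sequence length makes every token occupy `multiplier` consecutive positions, i.e. [token0, token0, token1, token1, ...]:

import torch

# Toy dimensions; not the real CLIP sizes.
bsz, seq_len, num_heads, head_dim, multiplier = 1, 3, 2, 4, 2
embed_dim = num_heads * head_dim

x = torch.randn(bsz, seq_len, embed_dim)
w = torch.randn(embed_dim, embed_dim)                 # original k_proj weight
w_ext = w.repeat(multiplier, 1)                       # extended weight, as in extend_weights()

k_ext = (x @ w_ext.t()).view(bsz, -1, num_heads, head_dim).transpose(1, 2)   # [bsz, heads, multiplier*seq_len, head_dim]
k_orig = (x @ w.t()).view(bsz, -1, num_heads, head_dim).transpose(1, 2)      # [bsz, heads, seq_len, head_dim]

# Positions 2*t and 2*t+1 of the extended sequence both equal position t of the original.
assert torch.allclose(k_ext[:, :, 0::2], k_orig, atol=1e-5)
assert torch.allclose(k_ext[:, :, 1::2], k_orig, atol=1e-5)
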
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        bsz, tgt_len, embed_dim = hidden_states.size()

        query_states = self.q_proj(hidden_states) * self.scale
        # For key_states and value_states, the multiplier is absorbed into the seq_len (dim 1, shape specified as -1).
        # [token0_head_emb, token0_head_emb, token1_head_emb, token1_head_emb, ..., tokenN-1_head_emb, tokenN-1_head_emb].
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        # src_len0 is the original src_len without the multiplier.
        src_len0 = src_len // self.multiplier
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        # Apply the causal_attention_mask first.
        if causal_attention_mask is not None:
            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len0):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len0)}, but is"
                    f" {causal_attention_mask.size()}"
                )
            # The last dim of attn_weights corresponds to [token0, token0, token1, token1, ..., tokenN-1, tokenN-1].
            # If reshaping it as (self.multiplier, src_len0), it will become
            # [[token0, token0, token1, token1, ..., tokenN//2], [tokenN//2+1, tokenN//2+1, ..., tokenN-1, tokenN-1]],
            # and the mask will be applied to wrong elements.
            # If reshaping it as (src_len0, self.multiplier), it will become
            # [[token0, token1, ..., tokenN-1], [token0, token1, ..., tokenN-1]], and then
            # the mask at element i will mask all the multiplier elements at i, which is desired.
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len0, self.multiplier) + causal_attention_mask.unsqueeze(4)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len0):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len0)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len0, self.multiplier) + attention_mask.unsqueeze(4)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if output_attentions:
            # This operation is a bit awkward, but it's required in order to
            # make sure that attn_weights keeps its gradient.
            # To do so, attn_weights has to be reshaped
            # twice and reused in the following code.
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped

class CLIPTextModelWrapper(CLIPTextModel):
    # Adapted from https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/clip/modeling_clip.py#L812
    # Modified to accept precomputed token embeddings "input_token_embs" as input, or to calculate them from input_ids and return them.
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        input_token_embs: Optional[torch.Tensor] = None,
        hidden_state_layer_weights: Optional[torch.Tensor] = None,
        return_token_embs: Optional[bool] = False,
    ) -> Union[Tuple, torch.Tensor, BaseModelOutputWithPooling]:

        if return_token_embs:
            return self.text_model.embeddings.token_embedding(input_ids)

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        output_attentions = output_attentions if output_attentions is not None else self.text_model.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.text_model.config.output_hidden_states
        )
        if hidden_state_layer_weights is not None:
            output_hidden_states = True
        return_dict = return_dict if return_dict is not None else self.text_model.config.use_return_dict

        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        hidden_states = self.text_model.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=input_token_embs)

        # CLIP's text model uses causal mask, prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _expand_mask(attention_mask, hidden_states.dtype)

        encoder_outputs = self.text_model.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            # output_hidden_states is False by default, and only True if hidden_state_layer_weights is provided.
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # If output_hidden_states is True, then encoder_outputs[0] is last_hidden_state [1, 22, 768].
        # encoder_outputs[1] is hidden_states, which is a tuple of 13 hidden states, each being [1, 22, 768].
        # encoder_outputs[0] == encoder_outputs[1][12].
        if hidden_state_layer_weights is None:
            last_hidden_state = encoder_outputs[0]
        else:
            num_hidden_state_layers = len(hidden_state_layer_weights)
            last_hidden_states = encoder_outputs[1][-num_hidden_state_layers:]
            hidden_state_layer_weights = hidden_state_layer_weights.to(last_hidden_states[0].dtype)
            # Normalize the weights to sum to 1 across layers.
            # hidden_state_layer_weights: [3, 1] or [3, 768].
            hidden_state_layer_weights = hidden_state_layer_weights / hidden_state_layer_weights.sum(dim=0, keepdim=True)
            # [3, 1/768] -> [3, 1, 1, 1/768]
            hidden_state_layer_weights = hidden_state_layer_weights.unsqueeze(1).unsqueeze(1)
            # A weighted sum of last_hidden_states.
            # [3, 1, 22, 768] * [3, 1, 1, 1/768] -> [3, 1, 22, 768] -> [1, 22, 768]
            last_hidden_state = (torch.stack(last_hidden_states, dim=0) * hidden_state_layer_weights).sum(dim=0)

        last_hidden_state = self.text_model.final_layer_norm(last_hidden_state)

        # self.text_model.eos_token_id == 2 is True.
        if self.text_model.eos_token_id == 2:
            # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
            # A CLIP model with such an `eos_token_id` in the config can't work correctly with extra new tokens added.
            # ------------------------------------------------------------
            # text_embeds.shape = [batch_size, sequence_length, transformer.width]
            # take features from the eot embedding (eot_token is the highest number in each sequence)
            # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
            pooled_output = last_hidden_state[
                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
                input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
            ]
        else:
            # The config gets the updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible).
            pooled_output = last_hidden_state[
                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
                # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
                (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.text_model.eos_token_id)
                .int()
                .argmax(dim=-1),
            ]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

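For illustration, here is a hedged sketch (not part of the file) of calling this wrapper with the per-layer weighting used by the subject basis generator later in this commit. The prompt "photo of a id person", max length 22, and the [1, 2, 4] layer weights mirror values that appear elsewhere in this repo; treat the exact call as an assumption:

import torch
from transformers import CLIPTokenizer
from adaface.arc2face_models import CLIPTextModelWrapper

text_encoder = CLIPTextModelWrapper.from_pretrained("openai/clip-vit-large-patch14")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

ids = tokenizer("photo of a id person", return_tensors="pt", padding="max_length",
                max_length=22, truncation=True).input_ids
# Weight the last 3 hidden layers; the weights are normalized inside forward().
layer_weights = torch.tensor([[1.0], [2.0], [4.0]])
out = text_encoder(input_ids=ids, hidden_state_layer_weights=layer_weights, return_dict=True)
print(out.last_hidden_state.shape)   # torch.Size([1, 22, 768])
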
    # Applied to layers [begin_layer_idx, end_layer_idx) in the encoder.
    # The layer indexed by end_layer_idx is not included.
    # If both layer indices are -1, then apply to all layers (0-11).
    def extend_clip_attention_MKV_multiplier(self, begin_layer_idx=-1, end_layer_idx=-1, multiplier=2, noise_std=0.1):
        num_extended_layers = 0

        for layer_idx, layer in enumerate(self.text_model.encoder.layers):
            if begin_layer_idx >= 0 and layer_idx < begin_layer_idx:
                continue
            if end_layer_idx >= 0 and layer_idx >= end_layer_idx:
                break
            # This shouldn't happen, unless self_attn has already been extended as CLIPAttentionMKV.
            if not isinstance(layer.self_attn, (CLIPAttention, CLIPAttentionMKV)):
                breakpoint()
            old_attn_layer = layer.self_attn
            if not isinstance(old_attn_layer, CLIPAttentionMKV):
                layer.self_attn = CLIPAttentionMKV(old_attn_layer.config, 1)
            layer.self_attn.extend_weights(old_attn_layer, layer_idx, multiplier, noise_std, verbose=True)
            num_extended_layers += 1

        return num_extended_layers

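Continuing the sketch above (same assumed text_encoder instance), doubling the K/V projections of every encoder layer would look roughly like this:

# Extend all 12 CLIP text encoder layers with 2x K/V projections, noised at std 0.1.
n = text_encoder.extend_clip_attention_MKV_multiplier(begin_layer_idx=-1, end_layer_idx=-1,
                                                      multiplier=2, noise_std=0.1)
print(f"Extended {n} attention layers.")
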
adaface/subj_basis_generator.py
ADDED
@@ -0,0 +1,758 @@
# Borrowed from ip-adapter resampler.py.
# https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py

import math

import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
from transformers import CLIPVisionModel, CLIPTokenizer

import numpy as np
from torch import einsum
from dataclasses import dataclass
from typing import Optional, Tuple
from transformers.utils import ModelOutput
from adaface.util import arc2face_inverse_face_prompt_embs, gen_gradient_scaler
from adaface.arc2face_models import CLIPTextModelWrapper
import sys
sys.modules['ldm'] = sys.modules['adaface']

def reshape_tensor(x, num_heads):
    bs, length, width = x.shape
    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
    x = x.view(bs, length, num_heads, -1)
    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
    x = x.transpose(1, 2)
    # (bs, n_heads, length, dim_per_head), reshaped to make the tensor contiguous.
    x = x.reshape(bs, num_heads, length, -1)
    return x

# FFN. Added a Dropout layer at the end, so that it can still load the old ckpt.
def FeedForward(dim, mult=4, p_dropout=0.1):
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
        nn.Dropout(p_dropout),
    )

# IP-Adapter FaceID class. Only used in knn-faces.py.
# From: https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter_faceid_separate.py
class IP_MLPProjModel(nn.Module):
    def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
        super().__init__()

        self.cross_attention_dim = cross_attention_dim
        self.num_tokens = num_tokens

        self.proj = nn.Sequential(
            nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
            nn.GELU(),
            nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
        )
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, id_embeds):
        x = self.proj(id_embeds)
        x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
        x = self.norm(x)
        return x

# group_dim: the tensor dimension that corresponds to the multiple groups.
class LearnedSoftAggregate(nn.Module):
    def __init__(self, num_feat, group_dim, keepdim=False):
        super(LearnedSoftAggregate, self).__init__()
        self.group_dim = group_dim
        # num_feat = 1: element-wise score function & softmax.
        # num_feat > 1: the linear score function is applied to the last dim (features) of the input tensor.
        self.num_feat = num_feat
        self.feat2score = nn.Linear(num_feat, 1, bias=False)
        self.keepdim = keepdim

    def forward(self, x, score_basis=None):
        # If there's only one mode, do nothing.
        if x.shape[self.group_dim] == 1:
            if self.keepdim:
                return x
            else:
                return x.squeeze(self.group_dim)

        # Assume the last dim of x is the feature dim.
        if score_basis is None:
            score_basis = x

        if self.num_feat == 1:
            mode_scores = self.feat2score(score_basis.unsqueeze(-1)).squeeze(-1)
        else:
            mode_scores = self.feat2score(score_basis)
        attn_probs = mode_scores.softmax(dim=self.group_dim)
        x_aggr = (x * attn_probs).sum(dim=self.group_dim, keepdim=self.keepdim)
        return x_aggr

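A quick shape check of the aggregation above (not part of the file; toy tensors, sizes assumed), using the class just defined: with num_feat=768 and group_dim=1, a [BS, num_modes, N, 768] input collapses to [BS, N, 768] via a learned softmax over the mode dimension.

import torch

agg = LearnedSoftAggregate(num_feat=768, group_dim=1, keepdim=False)
x = torch.randn(2, 4, 16, 768)           # [BS, num_modes, N, D]
print(agg(x).shape)                       # torch.Size([2, 16, 768])
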
def LoRA_ExpandEmbs(input_dim, lora_rank, output_dim, num_modes,
                    num_output_vecs, elementwise_affine=True, p_dropout=0.1):
    return nn.Sequential(
        # Project to [BS, lora_rank * output_dim * num_modes].
        # It takes a huge param size. 512 * 32 * 768 * 4 = 6,291,456.
        nn.Linear(input_dim, lora_rank * output_dim * num_modes, bias=False),
        # Reshape to [BS, num_modes, lora_rank, output_dim].
        Rearrange('b (m q d) -> b m q d', q=lora_rank, m=num_modes, d=output_dim),
        nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
        # Aggregate [BS, num_modes, lora_rank, output_dim] -> [BS, lora_rank, output_dim].
        LearnedSoftAggregate(num_feat=output_dim, group_dim=1, keepdim=False) if num_modes > 1 \
            else Rearrange('b () q d -> b q d'),
        nn.Dropout(p_dropout),
        # Permute to [BS, output_dim, lora_rank].
        Rearrange('b q d -> b d q'),
        # Project to [BS, output_dim, num_output_vecs].
        nn.Linear(lora_rank, num_output_vecs, bias=False),
        # Permute to [BS, num_output_vecs, output_dim].
        Rearrange('b d q -> b q d'),
        nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
        nn.Dropout(p_dropout),
    )

def ExpandEmbs(input_dim, output_dim, expansion_ratio, elementwise_affine=True, p_dropout=0.1):
    return nn.Sequential(
        # Project to [BS, expansion_ratio * output_dim].
        nn.Linear(input_dim, expansion_ratio * output_dim, bias=False),
        # Reshape to [BS, expansion_ratio, output_dim].
        Rearrange('b (e d) -> b e d', e=expansion_ratio, d=output_dim),
        nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
        nn.Dropout(p_dropout),
    )

# Input: [BS, N, D].
def MultimodeProjection(input_dim, output_dim=-1, num_modes=4, elementwise_affine=True, p_dropout=0.1):
    if output_dim == -1:
        output_dim = input_dim

    return nn.Sequential(
        nn.Linear(input_dim, output_dim * num_modes, bias=False),
        # Reshape to [BS, N, num_modes, output_dim].
        Rearrange('b n (m d) -> b n m d', m=num_modes, d=output_dim),
        nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
        # If num_modes == 1, then simply remove the mode dim. Otherwise, aggregate the modes.
        LearnedSoftAggregate(num_feat=output_dim, group_dim=2, keepdim=False) if num_modes > 1 \
            else Rearrange('b n () d -> b n d'),
        nn.Dropout(p_dropout),
    )

# Low-rank to high-rank transformation.
def Lora2Hira(lora_rank, hira_rank, output_dim, num_modes, elementwise_affine=True, p_dropout=0.1):
    return nn.Sequential(
        # Permute to [BS, output_dim, lora_rank].
        Rearrange('b q d -> b d q'),
        # Project to [BS, output_dim, hira_rank * num_modes].
        nn.Linear(lora_rank, hira_rank * num_modes, bias=False),
        # Reshape and permute to [BS, num_modes, hira_rank, output_dim].
        Rearrange('b d (m q) -> b m q d', m=num_modes, q=hira_rank),
        nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
        # Aggregate [BS, num_modes, hira_rank, output_dim] -> [BS, hira_rank, output_dim].
        LearnedSoftAggregate(num_feat=output_dim, group_dim=1, keepdim=False) if num_modes > 1 \
            else Rearrange('b () q d -> b q d'),
        nn.Dropout(p_dropout),
    )

class PerceiverAttention(nn.Module):
    def __init__(self, *, dim, dim_head=64, num_heads=8, elementwise_affine=True):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.num_heads = num_heads
        inner_dim = dim_head * num_heads

        self.norm1 = nn.LayerNorm(dim, elementwise_affine=elementwise_affine)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=elementwise_affine)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latent_queries):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latent_queries (torch.Tensor): latent features
                shape (b, n2, D)
        """
        x = self.norm1(x)
        latent_queries = self.norm2(latent_queries)

        b, l, _ = latent_queries.shape

        q = self.to_q(latent_queries)
        kv_input = torch.cat((x, latent_queries), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.num_heads)
        k = reshape_tensor(k, self.num_heads)
        v = reshape_tensor(v, self.num_heads)

        # attention
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
        attn = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = attn @ v

        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)


class CrossAttention(nn.Module):
    # output_dim is always the same as input_dim.
    # num_q only matters when q_aware_to_v is True.
    # If q_aware_to_v is False, query x in forward() is still usable.
    def __init__(self, input_dim, num_heads=6, p_dropout=0.05,
                 identity_to_q=False, identity_to_k=False, identity_to_v=False, v_has_skip=True,
                 q_aware_to_v=True, num_q=416, v_repeat=4, q_aware_to_v_lora_rank=64,
                 identity_to_out=False, out_has_skip=False):
        super().__init__()
        dim_head = input_dim // num_heads
        inner_dim = dim_head * num_heads

        self.num_heads = num_heads
        self.q_aware_to_v = q_aware_to_v
        self.v_has_skip = v_has_skip
        self.to_q = nn.Sequential(
            nn.Linear(input_dim, inner_dim, bias=False),
            nn.LayerNorm(inner_dim, elementwise_affine=True)
        ) if not identity_to_q else nn.Identity()
        self.to_k = nn.Sequential(
            nn.Linear(input_dim, inner_dim, bias=False),
            nn.LayerNorm(inner_dim, elementwise_affine=True)
        ) if not identity_to_k else nn.Identity()

        self.v_repeat = v_repeat
        self.num_q_group = num_q_group = num_q // v_repeat  # 416 / 4 = 104.

        # If q_aware_to_v is True, then self.to_v consists of num_q projections of input_dim to inner_dim.
        # Otherwise, self.to_v consists of a single projection of input_dim to inner_dim.
        if q_aware_to_v:
            # all_q_mid: 104 * 64 = 6656.
            all_q_mid = num_q_group * q_aware_to_v_lora_rank
            self.to_v = nn.Sequential(
                # number of params: 768 * 6656 = 5,111,808.
                # Input: [BS, 16, 768]. Output: [BS, 16, 104*64] = [BS, 16, 6656].
                # Each 768-dim vec is dispersed into 104 64-dim vecs.
                nn.Linear(input_dim, all_q_mid, bias=False),
                nn.LayerNorm(all_q_mid, elementwise_affine=True),
                # Change the dim of the tensor to [BS, 6656, 16], as Conv1d transforms dim 1.
                Rearrange('b n q -> b q n', q=all_q_mid),
                # Each q_aware_to_v projection has its own linear layer.
                # The total number of parameters will be 6656*768 = 5,111,808.
                # Output: [BS, 104*768, 16]. Each 64 dim feature is expanded to 768 dim.
                nn.Conv1d(
                    in_channels=all_q_mid,
                    out_channels=num_q_group * input_dim,
                    kernel_size=1,
                    groups=num_q_group,
                    bias=False,
                ),
                # Output: [BS, 104, 16, 768].
                Rearrange('b (q d) n -> b q n d', q=num_q_group, d=input_dim),
                nn.LayerNorm(input_dim, elementwise_affine=True),
            )
        else:
            self.to_v = nn.Sequential(
                nn.Linear(input_dim, inner_dim, bias=False),
                nn.LayerNorm(inner_dim, elementwise_affine=True)
            ) if not identity_to_v else nn.Identity()

        if identity_to_out:
            assert not out_has_skip, "identity_to_out=True, then out_has_skip has to be False."

        if identity_to_out:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(
                nn.Linear(input_dim, input_dim, bias=False),
                nn.Dropout(p_dropout),
                nn.LayerNorm(inner_dim, elementwise_affine=True)
            )

        self.out_has_skip = out_has_skip
        self.attn_drop = nn.Dropout(p_dropout)

    def forward(self, x, context=None, attn_mat=None, return_attn=False):
        h = self.num_heads

        if context is None:
            context = x

        if attn_mat is None:
            # q: [BS, Q, D] -> [BS, Q, D].
            q = self.to_q(x)
            # k: [BS, L, D] -> [BS, L, D].
            k = self.to_k(context)
            # q: [6, 512, 128], k: [6, 17, 128].
            q, k = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k))

        if self.q_aware_to_v:
            # context: [BS, L, D]. v: [BS, Q, L, D].
            # There are effectively Q to_v projections.
            v = self.to_v(context)
            if self.v_has_skip:
                v = v + context.unsqueeze(1)
        else:
            # v: [BS, L, D].
            v = self.to_v(context)
            if self.v_has_skip:
                v = v + context

        #print(v.shape)

        if self.q_aware_to_v:
            # v: [6, 64, 17, 128].
            # v is query-specific, so there's an extra dim for the query.
            v = rearrange(v, 'b q n (h d) -> (b h) q n d', h=h)
            # Each v is for a query group with 512/64 = 8 queries.
            # So each v is repeated 8 times to match the number of queries.
            # v: [6, 64, 17, 128] -> [6, 512, 17, 128].
            v = v.repeat(1, self.v_repeat, 1, 1)
        else:
            v = rearrange(v, 'b n (h d) -> (b h) n d', h=h)

        if attn_mat is None:
            scale = q.size(-1) ** -0.25
            sim = einsum('b i d, b j d -> b i j', q * scale, k * scale)
            # sim: [6, 64, 17]. 6: bs 1 * h 6.
            # attention, what we cannot get enough of
            # NOTE: the normalization is done across tokens, not across pixels.
            # So for each pixel, the sum of attention scores across tokens is 1.
            attn = sim.softmax(dim=-1)
            attn = self.attn_drop(attn)
            #print(attn.std())
        else:
            attn = attn_mat

        if self.q_aware_to_v:
            # attn: [6, 32, 17]. v: [6, 32, 17, 128]. 128: dim of each head. out: [6, 32, 128].
            # out is combined with different attn weights and v for different queries.
            out = einsum('b i j, b i j d -> b i d', attn, v)
        else:
            # v: [6, 17, 128]. out: [6, 32, 128].
            out = einsum('b i j, b j d -> b i d', attn, v)

        # [6, 32, 128] -> [1, 32, 768].
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)

        if self.out_has_skip:
            out = self.to_out(out) + out
        else:
            out = self.to_out(out)

        if return_attn:
            return out, attn
        else:
            return out

class SubjBasisGenerator(nn.Module):
    def __init__(
        self,
        # number of cross-attention heads. Half of the number of heads (12) of OpenAI clip-vit-large-patch14:
        # https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
        num_heads=6,
        num_id_vecs={ 'subj': 77, 'bg': 257 },  # number of identity vectors. 18: 16 face tokens + 2 extra tokens. 257: 257 CLIP tokens.
        num_out_embs_per_layer=4,               # num_out_embs. subj: 16. bg: 4.
        num_out_layers=16,                      # number of layers of output embeddings.
        image_embedding_dim=768,                # CLIP image feature dimension, as per config.json above.
        # DINO vits16 has 6 attention heads:
        # https://huggingface.co/facebook/dino-vits16/blob/main/config.json
        dino_embedding_dim=384,                 # DINO object feature dimension for objects.
        output_dim=768,                         # CLIP text embedding input dimension.
        placeholder_is_bg: bool = False,            # Whether the placeholder is for the image background.
        prompt2token_proj_grad_scale: float = 0.4,  # Gradient scale for prompt2token_proj.
        zs_extra_words_scale: float = 0.5,          # Scale for extra words in the prompt2token_proj.
        learnable_hidden_state_weights_scheme: str = 'per-layer',  # none, per-layer.
        bg_prompt_translator_has_to_out_proj: bool = False,  # Whether the prompt_trans_layers have a to_out projection.
    ):
        super().__init__()

        self.placeholder_is_bg = placeholder_is_bg
        self.num_out_layers = num_out_layers
        self.num_out_embs_per_layer = num_out_embs_per_layer
        # subj: 64, bg: 32.
        self.num_out_embs = num_out_layers * num_out_embs_per_layer
        self.output_dim = output_dim
        # num_id_vecs should be the number of core ID embs, 16.
        # However, in such case, pos_embs is not used. So it doesn't matter if it's wrongly set.
        self.num_id_vecs = num_id_vecs['bg'] if placeholder_is_bg else num_id_vecs['subj']
        self.pos_embs = nn.Parameter(torch.randn(1, self.num_id_vecs, output_dim))
        self.pos_embs_ln = nn.LayerNorm(output_dim)
        self.zs_extra_words_scale = zs_extra_words_scale
        self.output_scale = output_dim ** -0.5
        self.clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

        if not self.placeholder_is_bg:
            # [1, 384] -> [1, 16, 768].
            # TODO: use CLIPTextModelWrapper as obj_proj_in.
            self.obj_proj_in = ExpandEmbs(dino_embedding_dim, output_dim, expansion_ratio=self.num_id_vecs)

            # self.prompt2token_proj: [1, 16, 768] -> [1, 77, 768] (with paddings).
            # If self.placeholder_is_bg: prompt2token_proj is set to None.
            self.prompt2token_proj = CLIPTextModelWrapper.from_pretrained('openai/clip-vit-large-patch14')
            self.prompt2token_proj_grad_scale = prompt2token_proj_grad_scale
            self.prompt2token_proj_grad_scaler = gen_gradient_scaler(prompt2token_proj_grad_scale)
            print(f"Subj prompt2token_proj initialized with grad scale of {prompt2token_proj_grad_scale}.")
            # Freeze prompt2token_proj if prompt2token_proj_grad_scale is 0.
            # Set requires_grad to False for all parameters in prompt2token_proj, to save memory taken by the optimizer.
            if prompt2token_proj_grad_scale == 0:
                self.freeze_prompt2token_proj()

            self.prompt2token_proj_attention_multiplier = -1
            self.initialize_hidden_state_layer_weights(learnable_hidden_state_weights_scheme, 'cpu')
            self.pad_embeddings = None
            self.bg_proj_in = None
        else:
            # For background placeholders, face and object embeddings are not used as they are foreground.
            self.obj_proj_in = None
            self.prompt2token_proj = None
            print("Bg prompt2token_proj is set to None.")

            self.bg_proj_in = nn.Sequential(
                nn.Linear(image_embedding_dim, output_dim, bias=False),
                nn.LayerNorm(output_dim),
            )

            self.latent_queries = nn.Parameter(torch.randn(1, self.num_out_embs, output_dim))
            self.latent_queries_ln = nn.LayerNorm(output_dim)

            self.bg_prompt_translator_has_to_out_proj = bg_prompt_translator_has_to_out_proj
            identity_to_v = False
            v_has_skip = not identity_to_v            # True
            identity_to_out = not bg_prompt_translator_has_to_out_proj  # True
            out_has_skip = not identity_to_out        # False
            # prompt_translator has a to_v projection with skip connection, and doesn't have a to_out projection.
            # dim=768, num_heads=6.
            self.prompt_translator = \
                CrossAttention(input_dim=output_dim, num_heads=num_heads, p_dropout=0.05,
                               identity_to_q=False, identity_to_k=False, identity_to_v=identity_to_v,
                               q_aware_to_v=False, v_has_skip=v_has_skip,
                               num_q=0,  # When not q_aware_to_v, num_q is not referenced.
                               identity_to_out=identity_to_out,
                               out_has_skip=out_has_skip)
            '''
            prompt_translator: CLIPEncoder
            # https://github.com/huggingface/transformers/blob/1872bde7fc6a5d6796bd742bc2dc38eaf8069c5d/src/transformers/models/clip/modeling_clip.py#L566
            # CLIPEncoder.layers: 12 layers of CLIPEncoderLayer, each being
                (0): CLIPEncoderLayer(
                    (self_attn): CLIPAttention(
                        (k_proj): Linear(in_features=768, out_features=768, bias=True)
                        (v_proj): Linear(in_features=768, out_features=768, bias=True)
                        (q_proj): Linear(in_features=768, out_features=768, bias=True)
                        (out_proj): Linear(in_features=768, out_features=768, bias=True)
                    )
                    (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
                    (mlp): CLIPMLP(
                        (activation_fn): QuickGELUActivation()
                        (fc1): Linear(in_features=768, out_features=3072, bias=True)
                        (fc2): Linear(in_features=3072, out_features=768, bias=True)
                    )
                    (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
                )
            '''

        print(repr(self))

    # raw_id_embs: ArcFace embeddings for faces (not used since we have arc2face_id_embs),
    # or DINO embeddings for objects.
    # arc2face_id_embs: [BS, 16, 768], the core identity embeddings generated by Arc2Face.
    def forward(self, arc2face_id_embs, clip_features=None, raw_id_embs=None, out_id_embs_scale=1.0,
                is_face=True, is_training=False, adaface_prompt_embs_inf_type='full_half_pad'):

        if not self.placeholder_is_bg:
            BS = arc2face_id_embs.shape[0]
        else:
            # If bg, then arc2face_id_embs is set to None, but clip_features is not None.
            BS = clip_features.shape[0]

        adaface_prompt_embs = None
        if not hasattr(self, 'clip_tokenizer'):
            self.clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

        # No need to use raw_id_embs if placeholder_is_bg.
        if not self.placeholder_is_bg:
            if is_face:
                assert arc2face_id_embs is not None
                # arc2face_embs has been projected to the (modified) prompt embedding space
                # by arc2face_forward_face_embs. This prompt embedding space is modified because Arc2Face finetuned
                # the text encoder and the U-Net.
                # in embedding_manager: [BS, 16, 768] -> [BS, 77, 768].
                # arc2face_id_embs is part of arc2face_embs: [BS, 77, 768] -> [BS, 16, 768].
                # adaface_prompt_embs is projected to the prompt embedding spaces. This is the
                # original U-Net prompt embedding space.

                # hidden_state_layer_weights: [[0.9163], [0.9483], [2.0762]]
                hidden_state_layer_weights = self.hidden_state_layer_weights_grad_scaler(self.hidden_state_layer_weights)
                # return_emb_types: a list of strings, each string is among
                # ['full', 'core', 'full_pad', 'full_half_pad', 'full_zeroed_extra', 'b_core_e'].
                # Using b_core_e is more computationally efficient than using full_zeroed_extra.
                # But there is an unknown BUG that causes a crash when using b_core_e.
                if is_training:
                    return_emb_types = ['full_pad', 'core']
                else:
                    # adaface_prompt_embs_inf_type: default is full_half_pad, same as training.
                    return_emb_types = [adaface_prompt_embs_inf_type, 'core']

                if self.pad_embeddings is None:
                    self.generate_pad_embeddings()
                else:
                    self.pad_embeddings = self.pad_embeddings.to(arc2face_id_embs.device)

                with torch.set_grad_enabled(self.training and self.prompt2token_proj_grad_scale != 0):
                    # If list_extra_words is not None, then core_id_embs: [BS, 18, 768], three leading words, the 16 identity tokens
                    # and (at most) two extra words in full_prompt_embs, without BOS and EOS.
                    # If list_extra_words is None, then core_id_embs: [BS, 16, 768], the 16 identity tokens in full_prompt_embs.
                    # hidden_state_layer_weights: [[0.9163], [0.9483], [2.0762]]
                    # zs_extra_words_scale is only effective when list_extra_words is not None.
                    # adaface_prompt_embs: [BS, 77, 768], core_id_embs: [BS, 16, 768].
                    adaface_prompt_embs, core_id_embs = \
                        arc2face_inverse_face_prompt_embs(self.clip_tokenizer,
                                                          self.prompt2token_proj,
                                                          arc2face_id_embs,
                                                          list_extra_words=None,
                                                          return_emb_types=return_emb_types,
                                                          pad_embeddings=self.pad_embeddings,
                                                          hidden_state_layer_weights=hidden_state_layer_weights,
                                                          input_max_length=77, zs_extra_words_scale=self.zs_extra_words_scale)
                # Reduce the update rate to prompt2token_proj.
                adaface_prompt_embs = self.prompt2token_proj_grad_scaler(adaface_prompt_embs)
                core_id_embs = self.prompt2token_proj_grad_scaler(core_id_embs)
            elif raw_id_embs is not None:
                # id_embs: [BS, 384] -> [BS, 18, 768].
                # obj_proj_in is expected to project the DINO object features to
                # the token embedding space. So no need to use prompt2token_proj.
                id_embs = self.obj_proj_in(raw_id_embs)
            else:
                breakpoint()
        else:
            # Otherwise, context is the ad-hoc CLIP image features.
            # id_embs: [BS, 257, 768].
            id_embs = self.bg_proj_in(clip_features)

        if self.placeholder_is_bg:
            id_embs = id_embs + self.pos_embs_ln(self.pos_embs)
            latent_queries = self.latent_queries_ln(self.latent_queries).repeat(BS, 1, 1)
            # If bg, we don't have to use a specific attn layer for each 4-vec set. Instead, one attn layer can generate 257 embs,
            # and we take the first 16*4=64.
            # Output of prompt_translator is exactly num_out_embs == 64 tokens. id_embs_out: [BS, 64, 768].
            # prompt_translator: better named as bg_prompt_translator. It maps the bg features
            # to bg prompt embeddings.
            with torch.set_grad_enabled(self.training):
                id_embs_out = self.prompt_translator(latent_queries, id_embs)
            # [BS, 64, 768] -> [BS, 16, 4, 768]
            id_embs_out = id_embs_out.reshape(BS, self.num_out_layers, -1, self.output_dim)
            adaface_subj_embs = id_embs_out * self.output_scale  # * 0.036
        else:
            # adaface_subj_embs: [BS, 16, 768] -> [BS, 1, 16, 768] -> [BS, 16, 16, 768]
            adaface_subj_embs = core_id_embs.unsqueeze(1).repeat(1, self.num_out_layers, 1, 1)

        # If out_id_embs_scale < 1, adaface_subj_embs is a mix of adaface_subj_embs and pad_embeddings.
        if out_id_embs_scale != 1:
            # pad_embeddings: [77, 768] -> [16, 768] -> [1, 1, 16, 768].
            pad_embeddings = self.pad_embeddings[4:4+self.num_out_embs_per_layer].unsqueeze(0).unsqueeze(0)
            adaface_subj_embs = adaface_subj_embs * out_id_embs_scale \
                                + pad_embeddings * (1 - out_id_embs_scale)

        return adaface_subj_embs, adaface_prompt_embs

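As a sanity-check sketch of the subject branch of this forward pass (not part of the file; dummy inputs, and a freshly initialized generator that downloads the assumed openai/clip-vit-large-patch14 checkpoint instead of the trained adaface weights used by the app):

import torch

subj_gen = SubjBasisGenerator(placeholder_is_bg=False)
subj_gen.eval()

# Dummy stand-in for the 16 core Arc2Face ID embeddings of one face.
arc2face_id_embs = torch.randn(1, 16, 768)
with torch.no_grad():
    adaface_subj_embs, adaface_prompt_embs = subj_gen(arc2face_id_embs, None, None,
                                                      out_id_embs_scale=1.0,
                                                      is_face=True, is_training=False)
print(adaface_subj_embs.shape)   # [BS, num_out_layers, 16, 768]
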
569 |
+
def initialize_hidden_state_layer_weights(self, learnable_hidden_state_weights_scheme, device):
|
570 |
+
if learnable_hidden_state_weights_scheme == 'none':
|
571 |
+
self.hidden_state_layer_weights = None
|
572 |
+
# A grad scaler with alpha =1 is nn.Identity(), which outputs None given None as input.
|
573 |
+
self.hidden_state_layer_weights_grad_scaler = gen_gradient_scaler(1)
|
574 |
+
print("hidden_state_layer_weights is set to None.")
|
575 |
+
|
576 |
+
elif learnable_hidden_state_weights_scheme == 'per-layer':
|
577 |
+
# Learnable weights of the last 3 layers, initialized to putting more focus on the last layer.
|
578 |
+
# 'per-layer': Different weights for different layers, but the same for different channels.
|
579 |
+
# hidden_state_layer_weights: [3, 1].
|
580 |
+
self.hidden_state_layer_weights = nn.Parameter(torch.tensor([[1.0], [2.0], [4.0]], device=device),
|
581 |
+
requires_grad=True)
|
582 |
+
self.hidden_state_layer_weights_grad_scaler = gen_gradient_scaler(5)
|
583 |
+
print("hidden_state_layer_weights initialized as per-layer [1, 2, 4], with grad scaler 5.")
|
584 |
+
else:
|
585 |
+
breakpoint()
|
586 |
+
|
587 |
+
def generate_pad_embeddings(self):
|
588 |
+
# clip_embeddings: CLIPTextEmbeddings instance. pad_embeddings is generated after
|
589 |
+
# prompt2token_proj is loaded from the finetuned weight. It seems such pad embeddings perform
|
590 |
+
# slightly better than the original pad embeddings.
|
591 |
+
clip_embeddings = self.prompt2token_proj.text_model.embeddings
|
592 |
+
# clip_embeddings() and clip_embeddings.token_embedding() differ in that
|
593 |
+
# clip_embeddings() adds positional embeddings, while clip_embeddings.token_embedding() doesn't.
|
594 |
+
# Adding positional embeddings seems to help somewhat.
|
595 |
+
# pad_tokens: pad_token_id 49407 repeated 77 times.
|
596 |
+
# pad_token_id is the EOS token. But BOS is 49406.
|
597 |
+
pad_tokens = torch.tensor([self.clip_tokenizer.pad_token_id]).to(clip_embeddings.token_embedding.weight.device).repeat(77)
|
598 |
+
# pad_embeddings: [77, 768].
|
599 |
+
pad_embeddings = clip_embeddings(pad_tokens)[0]
|
600 |
+
# We don't allow face recon to influence the pad embeddings.
|
601 |
+
# Otherwise, face identity will leak into the pad embeddings.
|
602 |
+
self.pad_embeddings = pad_embeddings.detach()
|
603 |
+
|
604 |
+
def extend_prompt2token_proj_attention(self, begin_layer_idx=-1, end_layer_idx=-1, multiplier=2, noise_std=0.1):
|
605 |
+
if multiplier > 1:
|
606 |
+
num_extended_layers = self.prompt2token_proj.extend_clip_attention_MKV_multiplier(begin_layer_idx, end_layer_idx, multiplier, noise_std)
|
607 |
+
self.prompt2token_proj_attention_multiplier = multiplier
|
608 |
+
print(f"{num_extended_layers} layers in prompt2token_proj_attention are x{multiplier}")
|
609 |
+
|
610 |
+
def freeze_prompt2token_proj(self):
|
611 |
+
# If bg, then prompt2token_proj is set to None. Therefore no need to freeze it.
|
612 |
+
# Then we don't have to check whether it's for subj or bg.
|
613 |
+
if self.prompt2token_proj is not None:
|
614 |
+
frozen_param_names = []
|
615 |
+
for param_name, param in self.prompt2token_proj.named_parameters():
|
616 |
+
if param.requires_grad:
|
617 |
+
param.requires_grad = False
|
618 |
+
frozen_param_names.append(param_name)
|
619 |
+
# If param is already frozen, then no need to freeze it again.
|
620 |
+
print(f"{len(frozen_param_names)} params in Subj prompt2token_proj is frozen.")
|
621 |
+
#print(f"Frozen parameters:\n{frozen_param_names}")
|
622 |
+
|
623 |
+
def __repr__(self):
|
624 |
+
type_sig = 'subj' if not self.placeholder_is_bg else 'bg'
|
625 |
+
# Fix compatibility with the previous version.
|
626 |
+
if not hasattr(self, 'bg_prompt_translator_has_to_out_proj'):
|
627 |
+
self.bg_prompt_translator_has_to_out_proj = False
|
628 |
+
if not hasattr(self, 'num_out_embs'):
|
629 |
+
self.num_out_embs = -1
|
630 |
+
return f"{type_sig} SubjBasisGenerator: num_out_embs={self.num_out_embs}, " \
|
631 |
+
f"bg_prompt_translator_has_to_out_proj={self.bg_prompt_translator_has_to_out_proj}"
|
632 |
+
|
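A minimal, hypothetical invocation sketch (not part of the diff): it mirrors the call in adaface_wrapper.py further below, assumes `subj_basis_generator` is a SubjBasisGenerator instance restored from an AdaFace checkpoint, and uses a random tensor as a stand-in for the 16 core Arc2Face ID embeddings.

import torch

# Assumption: subj_basis_generator is a SubjBasisGenerator loaded from a checkpoint.
id_prompt_emb = torch.randn(1, 16, 768)   # stand-in for the 16 core Arc2Face ID embeddings
subj_basis_generator.num_out_layers = 1   # non-layerwise output, as in AdaFaceWrapper below
adaface_subj_embs, adaface_prompt_embs = subj_basis_generator(
    id_prompt_emb, None, None,
    out_id_embs_scale=1.0, is_face=True, is_training=False,
    adaface_prompt_embs_inf_type='full_half_pad')
print(adaface_subj_embs.shape)            # expected [1, 1, 16, 768]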
633 |
+
@dataclass
|
634 |
+
class BaseModelOutputWithPooling2(ModelOutput):
|
635 |
+
"""
|
636 |
+
Base class for model's outputs that also contains a pooling of the last hidden states.
|
637 |
+
|
638 |
+
Args:
|
639 |
+
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
640 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
641 |
+
pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
|
642 |
+
Last layer hidden-state of the first token of the sequence (classification token) after further processing
|
643 |
+
through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
|
644 |
+
the classification token after processing through a linear layer and a tanh activation function. The linear
|
645 |
+
layer weights are trained from the next sentence prediction (classification) objective during pretraining.
|
646 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
647 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
648 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
649 |
+
|
650 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
651 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
652 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
653 |
+
sequence_length)`.
|
654 |
+
|
655 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
656 |
+
heads.
|
657 |
+
"""
|
658 |
+
|
659 |
+
last_hidden_state: torch.FloatTensor = None
|
660 |
+
pooler_output: torch.FloatTensor = None
|
661 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
662 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
663 |
+
attn_mask: Optional[torch.FloatTensor] = None
|
664 |
+
|
665 |
+
# Revised from CLIPVisionTransformer to support attention mask.
|
666 |
+
# self: a CLIPVisionTransformer instance.
|
667 |
+
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py#L821
|
668 |
+
# pixel_values: preprocessed B*C*H*W images. [BS, 3, 224, 224]
|
669 |
+
# attn_mask: B*H*W attention mask.
|
670 |
+
def CLIPVisionTransformer_forward(self, pixel_values = None, attn_mask=None,
|
671 |
+
output_attentions = None,
|
672 |
+
output_hidden_states = None, return_dict = None):
|
673 |
+
|
674 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
675 |
+
output_hidden_states = (
|
676 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
677 |
+
)
|
678 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
679 |
+
|
680 |
+
if pixel_values is None:
|
681 |
+
raise ValueError("You have to specify pixel_values")
|
682 |
+
|
683 |
+
# Visual tokens are flattened in embeddings().
|
684 |
+
# self.embeddings: CLIPVisionEmbeddings.
|
685 |
+
# hidden_states: [BS, 257, 1280]. 257: 16*16 (patch_embeds) + 1 (class_embeds).
|
686 |
+
# 16*16 is output from Conv2d(3, 1280, kernel_size=(14, 14), stride=(14, 14), bias=False).
|
687 |
+
hidden_states = self.embeddings(pixel_values)
|
688 |
+
hidden_states = self.pre_layrnorm(hidden_states)
|
689 |
+
|
690 |
+
if attn_mask is not None:
|
691 |
+
# feat_edge_size: 16.
|
692 |
+
feat_edge_size = np.sqrt(hidden_states.shape[1] - 1).astype(int)
|
693 |
+
# attn_mask: [BS, 512, 512] -> [BS, 1, 16, 16].
|
694 |
+
attn_mask = F.interpolate(attn_mask.unsqueeze(1), size=(feat_edge_size, feat_edge_size), mode='nearest')
|
695 |
+
# Flatten the mask: [BS, 1, 16, 16] => [BS, 1, 256].
|
696 |
+
attn_mask = attn_mask.flatten(2)
|
697 |
+
# Prepend 1 to the mask: [BS, 1, 256] => [BS, 1, 257].
|
698 |
+
# This 1 corresponds to class_embeds, which is always attended to.
|
699 |
+
attn_mask = torch.cat([torch.ones_like(attn_mask[:, :, :1]), attn_mask], dim=-1)
|
700 |
+
attn_mask_pairs = torch.matmul(attn_mask.transpose(-1, -2), attn_mask).unsqueeze(1)
|
701 |
+
else:
|
702 |
+
attn_mask_pairs = None
|
703 |
+
|
704 |
+
# encoder: CLIPEncoder.
|
705 |
+
encoder_outputs = self.encoder(
|
706 |
+
inputs_embeds=hidden_states,
|
707 |
+
# New feature: (***The official documentation is wrong***)
|
708 |
+
# attention_mask (`torch.Tensor` of shape `(batch_size, 1, sequence_length, sequence_length)`, *optional*):
|
709 |
+
# Mask to avoid performing attention on pairs of tokens. Mask values selected in `[0, 1]`:
|
710 |
+
# - 1 for pairs that are **not masked**,
|
711 |
+
# - 0 for pairs that are **masked**.
|
712 |
+
# attention_mask is eventually used by CLIPEncoderLayer:
|
713 |
+
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py#L370
|
714 |
+
attention_mask=attn_mask_pairs,
|
715 |
+
output_attentions=output_attentions, # False
|
716 |
+
output_hidden_states=output_hidden_states, # True
|
717 |
+
return_dict=return_dict, # True
|
718 |
+
)
|
719 |
+
|
720 |
+
# last_hidden_state: [BS, 257, 1280]
|
721 |
+
last_hidden_state = encoder_outputs[0]
|
722 |
+
pooled_output = last_hidden_state[:, 0, :]
|
723 |
+
pooled_output = self.post_layernorm(pooled_output)
|
724 |
+
|
725 |
+
# return_dict is True.
|
726 |
+
if not return_dict:
|
727 |
+
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
728 |
+
|
729 |
+
return BaseModelOutputWithPooling2(
|
730 |
+
last_hidden_state=last_hidden_state,
|
731 |
+
pooler_output=pooled_output,
|
732 |
+
hidden_states=encoder_outputs.hidden_states,
|
733 |
+
attentions=encoder_outputs.attentions,
|
734 |
+
# Newly added: return resized flattened attention mask.
|
735 |
+
# [BS, 1, 257] -> [BS, 257, 1]
|
736 |
+
attn_mask=attn_mask.permute(0, 2, 1) if attn_mask is not None else None
|
737 |
+
)
|
738 |
+
|
739 |
+
|
740 |
+
class CLIPVisionModelWithMask(CLIPVisionModel):
|
741 |
+
def __init__(self, config):
|
742 |
+
super().__init__(config)
|
743 |
+
# Replace vision_model.forward() with the new one that supports mask.
|
744 |
+
self.vision_model.forward = CLIPVisionTransformer_forward.__get__(self.vision_model)
|
745 |
+
|
746 |
+
def forward(self, pixel_values = None, attn_mask = None, output_attentions = None,
|
747 |
+
output_hidden_states = None, return_dict = None):
|
748 |
+
|
749 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
750 |
+
|
751 |
+
return self.vision_model(
|
752 |
+
pixel_values=pixel_values,
|
753 |
+
attn_mask=attn_mask,
|
754 |
+
output_attentions=output_attentions,
|
755 |
+
output_hidden_states=output_hidden_states,
|
756 |
+
return_dict=return_dict,
|
757 |
+
)
|
758 |
+
|
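A hedged usage sketch for CLIPVisionModelWithMask (not part of the diff). The checkpoint name is an assumption; the 257x1280 shapes in the comments above correspond to a ViT-H/14 backbone, whereas the ViT-L/14 checkpoint used here yields 257x1024 features.

import torch

# Assumed checkpoint; any CLIP checkpoint loadable by CLIPVisionModel should work.
clip_vision = CLIPVisionModelWithMask.from_pretrained("openai/clip-vit-large-patch14")
clip_vision.eval()

pixel_values = torch.randn(2, 3, 224, 224)   # preprocessed B*C*H*W images
fg_mask = torch.ones(2, 512, 512)            # B*H*W mask: 1 = attend, 0 = mask out
with torch.no_grad():
    out = clip_vision(pixel_values=pixel_values, attn_mask=fg_mask, return_dict=True)
print(out.last_hidden_state.shape)           # [2, 257, 1024] for ViT-L/14
print(out.attn_mask.shape)                   # [2, 257, 1]: the resized, flattened mask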
adaface/util.py
ADDED
@@ -0,0 +1,342 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
import cv2
|
7 |
+
|
8 |
+
# add_noise_to_tensor() adds Gaussian noise to the tensor; noise_std is either absolute or relative to the tensor's std.
|
9 |
+
def add_noise_to_tensor(ts, noise_std, noise_std_is_relative=True, keep_norm=False,
|
10 |
+
std_dim=-1, norm_dim=-1):
|
11 |
+
if noise_std_is_relative:
|
12 |
+
ts_std_mean = ts.std(dim=std_dim).mean().detach()
|
13 |
+
noise_std *= ts_std_mean
|
14 |
+
|
15 |
+
noise = torch.randn_like(ts) * noise_std
|
16 |
+
if keep_norm:
|
17 |
+
orig_norm = ts.norm(dim=norm_dim, keepdim=True)
|
18 |
+
ts = ts + noise
|
19 |
+
new_norm = ts.norm(dim=norm_dim, keepdim=True).detach()
|
20 |
+
ts = ts * orig_norm / (new_norm + 1e-8)
|
21 |
+
else:
|
22 |
+
ts = ts + noise
|
23 |
+
|
24 |
+
return ts
|
25 |
+
|
26 |
+
|
27 |
+
# Revised from RevGrad, by removing the grad negation.
|
28 |
+
class ScaleGrad(torch.autograd.Function):
|
29 |
+
@staticmethod
|
30 |
+
def forward(ctx, input_, alpha_, debug=False):
|
31 |
+
ctx.save_for_backward(alpha_, debug)
|
32 |
+
output = input_
|
33 |
+
if debug:
|
34 |
+
print(f"input: {input_.abs().mean().item()}")
|
35 |
+
return output
|
36 |
+
|
37 |
+
@staticmethod
|
38 |
+
def backward(ctx, grad_output): # pragma: no cover
|
39 |
+
# saved_tensors returns a tuple of tensors.
|
40 |
+
alpha_, debug = ctx.saved_tensors
|
41 |
+
if ctx.needs_input_grad[0]:
|
42 |
+
grad_output2 = grad_output * alpha_
|
43 |
+
if debug:
|
44 |
+
print(f"grad_output2: {grad_output2.abs().mean().item()}")
|
45 |
+
else:
|
46 |
+
grad_output2 = None
|
47 |
+
return grad_output2, None, None
|
48 |
+
|
49 |
+
class GradientScaler(nn.Module):
|
50 |
+
def __init__(self, alpha=1., debug=False, *args, **kwargs):
|
51 |
+
"""
|
52 |
+
A gradient scaling layer.
|
53 |
+
This layer has no parameters, and simply scales the gradient in the backward pass.
|
54 |
+
"""
|
55 |
+
super().__init__(*args, **kwargs)
|
56 |
+
|
57 |
+
self._alpha = torch.tensor(alpha, requires_grad=False)
|
58 |
+
self._debug = torch.tensor(debug, requires_grad=False)
|
59 |
+
|
60 |
+
def forward(self, input_):
|
61 |
+
_debug = self._debug if hasattr(self, '_debug') else False
|
62 |
+
return ScaleGrad.apply(input_, self._alpha.to(input_.device), _debug)
|
63 |
+
|
64 |
+
def gen_gradient_scaler(alpha, debug=False):
|
65 |
+
if alpha == 1:
|
66 |
+
return nn.Identity()
|
67 |
+
if alpha > 0:
|
68 |
+
return GradientScaler(alpha, debug=debug)
|
69 |
+
else:
|
70 |
+
assert alpha == 0
|
71 |
+
# Don't use lambda function here, otherwise the object can't be pickled.
|
72 |
+
return torch.detach
|
73 |
+
|
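A small sketch (not from the repo) of how the two utilities above behave; the tensors are arbitrary.

import torch

x = torch.randn(4, 8, requires_grad=True)
scaler = gen_gradient_scaler(0.1)   # identity in the forward pass, scales gradients by 0.1 in backward
y = scaler(x)
y.sum().backward()
print(x.grad.mean().item())         # ~0.1 instead of 1.0

noisy = add_noise_to_tensor(x.detach(), noise_std=0.05, keep_norm=True)
print((noisy.norm(dim=-1) - x.detach().norm(dim=-1)).abs().max().item())  # ~0: norms are preserved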
74 |
+
#@torch.autocast(device_type="cuda")
|
75 |
+
# In AdaFaceWrapper, input_max_length is 22.
|
76 |
+
def arc2face_forward_face_embs(tokenizer, arc2face_text_encoder, face_embs,
|
77 |
+
input_max_length=77, return_full_and_core_embs=True):
|
78 |
+
|
79 |
+
'''
|
80 |
+
arc2face_text_encoder: arc2face_models.py CLIPTextModelWrapper instance.
|
81 |
+
face_embs: (N, 512) normalized ArcFace embeddings.
|
82 |
+
return_full_and_core_embs: Return both the full prompt embeddings and the core embeddings.
|
83 |
+
If False, return only the core embeddings.
|
84 |
+
|
85 |
+
'''
|
86 |
+
|
87 |
+
# arcface_token_id: 1014
|
88 |
+
arcface_token_id = tokenizer.encode("id", add_special_tokens=False)[0]
|
89 |
+
|
90 |
+
# This step should be quite fast, and there's no need to cache the input_ids.
|
91 |
+
input_ids = tokenizer(
|
92 |
+
"photo of a id person",
|
93 |
+
truncation=True,
|
94 |
+
padding="max_length",
|
95 |
+
max_length=input_max_length, #tokenizer.model_max_length,
|
96 |
+
return_tensors="pt",
|
97 |
+
).input_ids.to(face_embs.device)
|
98 |
+
# input_ids: [1, 77] or [3, 77] (during training).
|
99 |
+
input_ids = input_ids.repeat(len(face_embs), 1)
|
100 |
+
face_embs_dtype = face_embs.dtype
|
101 |
+
face_embs = face_embs.to(arc2face_text_encoder.dtype)
|
102 |
+
# face_embs_padded: [1, 512] -> [1, 768].
|
103 |
+
face_embs_padded = F.pad(face_embs, (0, arc2face_text_encoder.config.hidden_size - face_embs.shape[-1]), "constant", 0)
|
104 |
+
# arc2face_text_encoder(input_ids=input_ids, ...) is called twice. The first is only to get the token embeddings (the shallowest mapping).
|
105 |
+
# The second call does the ordinary CLIP text encoding pass.
|
106 |
+
token_embs = arc2face_text_encoder(input_ids=input_ids, return_token_embs=True)
|
107 |
+
token_embs[input_ids==arcface_token_id] = face_embs_padded
|
108 |
+
|
109 |
+
prompt_embeds = arc2face_text_encoder(
|
110 |
+
input_ids=input_ids,
|
111 |
+
input_token_embs=token_embs,
|
112 |
+
return_token_embs=False
|
113 |
+
)[0]
|
114 |
+
|
115 |
+
# Restore the original dtype of prompt_embeds: float16 -> float32.
|
116 |
+
prompt_embeds = prompt_embeds.to(face_embs_dtype)
|
117 |
+
|
118 |
+
if return_full_and_core_embs:
|
119 |
+
# token 4: 'id' in "photo of a id person".
|
120 |
+
# 4:20 are the most important 16 embeddings that contain the subject's identity.
|
121 |
+
# [N, 77, 768] -> [N, 16, 768]
|
122 |
+
return prompt_embeds, prompt_embeds[:, 4:20]
|
123 |
+
else:
|
124 |
+
# [N, 16, 768]
|
125 |
+
return prompt_embeds[:, 4:20]
|
126 |
+
|
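A hypothetical sketch of calling arc2face_forward_face_embs with random, normalized face embeddings. The tokenizer repo name is an assumption; the encoder path mirrors adaface_wrapper.py further below.

import torch
import torch.nn.functional as F
from transformers import CLIPTokenizer
from adaface.arc2face_models import CLIPTextModelWrapper

# Assumed tokenizer source; any SD 1.5-style CLIP tokenizer should do.
tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
text_encoder = CLIPTextModelWrapper.from_pretrained("models/arc2face", subfolder="encoder",
                                                    torch_dtype=torch.float16).cuda()
face_embs = F.normalize(torch.randn(2, 512, device="cuda"), dim=-1)
full_embs, core_embs = arc2face_forward_face_embs(tokenizer, text_encoder, face_embs,
                                                  input_max_length=22)
print(full_embs.shape, core_embs.shape)   # [2, 22, 768], [2, 16, 768]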
127 |
+
def get_b_core_e_embeddings(prompt_embeds, length=22):
|
128 |
+
b_core_e_embs = torch.cat([ prompt_embeds[:, :length], prompt_embeds[:, [-1]] ], dim=1)
|
129 |
+
return b_core_e_embs
|
130 |
+
|
131 |
+
# return_emb_types: a list of strings, each string is among ['full', 'full_half_pad', 'full_pad', 'core', 'full_zeroed_extra', 'b_core_e'].
|
132 |
+
def arc2face_inverse_face_prompt_embs(clip_tokenizer, inverse_text_encoder, face_prompt_embs, list_extra_words,
|
133 |
+
return_emb_types, pad_embeddings, hidden_state_layer_weights=None,
|
134 |
+
input_max_length=77, zs_extra_words_scale=0.5):
|
135 |
+
|
136 |
+
'''
|
137 |
+
inverse_text_encoder: arc2face_models.py CLIPTextModelWrapper instance with **custom weights**.
|
138 |
+
inverse_text_encoder is NOT the original arc2face text encoder, but retrained to do inverse mapping.
|
139 |
+
face_prompt_embs: (BS, 16, 768). Only the core embeddings, no paddings.
|
140 |
+
list_extra_words: [s_1, ..., s_BS], each s_i is a list of extra words to be added to the prompt.
|
141 |
+
return_full_and_core_embs: Return both the full prompt embeddings and the core embeddings.
|
142 |
+
If False, return only the core embeddings.
|
143 |
+
'''
|
144 |
+
|
145 |
+
if list_extra_words is not None:
|
146 |
+
if len(list_extra_words) != len(face_prompt_embs):
|
147 |
+
if len(face_prompt_embs) > 1:
|
148 |
+
print("Warn: list_extra_words has different length as face_prompt_embs.")
|
149 |
+
if len(list_extra_words) == 1:
|
150 |
+
list_extra_words = list_extra_words * len(face_prompt_embs)
|
151 |
+
else:
|
152 |
+
breakpoint()
|
153 |
+
else:
|
154 |
+
# len(face_prompt_embs) == 1, this occurs when same_subject_in_batch == True, e.g. in do_mix_prompt_distillation.
|
155 |
+
# But list_extra_words always corresponds to the actual batch size. So we only take the first element.
|
156 |
+
list_extra_words = list_extra_words[:1]
|
157 |
+
|
158 |
+
for extra_words in list_extra_words:
|
159 |
+
assert len(extra_words.split()) <= 2, "Each extra_words string should consist of at most 2 words."
|
160 |
+
# 16 ", " are placeholders for face_prompt_embs.
|
161 |
+
prompt_templates = [ "photo of a " + ", " * 16 + list_extra_words[i] for i in range(len(list_extra_words)) ]
|
162 |
+
else:
|
163 |
+
# 16 ", " are placeholders for face_prompt_embs.
|
164 |
+
# No extra words are added to the prompt.
|
165 |
+
prompt_templates = [ "photo of a " + ", " * 16 for _ in range(len(face_prompt_embs)) ]
|
166 |
+
|
167 |
+
# This step should be quite fast, and there's no need to cache the input_ids.
|
168 |
+
# input_ids: [BS, 77].
|
169 |
+
input_ids = clip_tokenizer(
|
170 |
+
prompt_templates,
|
171 |
+
truncation=True,
|
172 |
+
padding="max_length",
|
173 |
+
max_length=input_max_length,
|
174 |
+
return_tensors="pt",
|
175 |
+
).input_ids.to(face_prompt_embs.device)
|
176 |
+
|
177 |
+
face_prompt_embs_dtype = face_prompt_embs.dtype
|
178 |
+
face_prompt_embs = face_prompt_embs.to(inverse_text_encoder.dtype)
|
179 |
+
|
180 |
+
# token_embs: [1, 77, 768]. This call is only to get the template token embeddings (the shallowest mapping).
|
181 |
+
token_embs = inverse_text_encoder(input_ids=input_ids, return_token_embs=True)
|
182 |
+
# token 4: first ", " in the template prompt.
|
183 |
+
# Replace embeddings of 16 placeholder ", " with face_prompt_embs.
|
184 |
+
token_embs[:, 4:20] = face_prompt_embs
|
185 |
+
|
186 |
+
# This call does the ordinary CLIP text encoding pass.
|
187 |
+
prompt_embeds = inverse_text_encoder(
|
188 |
+
input_ids=input_ids,
|
189 |
+
input_token_embs=token_embs,
|
190 |
+
hidden_state_layer_weights=hidden_state_layer_weights,
|
191 |
+
return_token_embs=False
|
192 |
+
)[0]
|
193 |
+
|
194 |
+
# Restore the original dtype of prompt_embeds: float16 -> float32.
|
195 |
+
prompt_embeds = prompt_embeds.to(face_prompt_embs_dtype)
|
196 |
+
# token 4: first ", " in the template prompt.
|
197 |
+
# 4:20 are the most important 16 embeddings that contain the subject's identity.
|
198 |
+
# 20:22 are embeddings of the (at most) two extra words.
|
199 |
+
# [N, 77, 768] -> [N, 16, 768]
|
200 |
+
core_prompt_embs = prompt_embeds[:, 4:20]
|
201 |
+
if list_extra_words is not None:
|
202 |
+
# [N, 16, 768] -> [N, 18, 768]
|
203 |
+
extra_words_embs = prompt_embeds[:, 20:22] * zs_extra_words_scale
|
204 |
+
core_prompt_embs = torch.cat([core_prompt_embs, extra_words_embs], dim=1)
|
205 |
+
|
206 |
+
return_prompts = []
|
207 |
+
for emb_type in return_emb_types:
|
208 |
+
if emb_type == 'full':
|
209 |
+
return_prompts.append(prompt_embeds)
|
210 |
+
elif emb_type == 'full_half_pad':
|
211 |
+
prompt_embeds2 = prompt_embeds.clone()
|
212 |
+
PADS = prompt_embeds2.shape[1] - 23
|
213 |
+
if PADS >= 2:
|
214 |
+
# Fill half of the remaining embeddings with pad embeddings.
|
215 |
+
prompt_embeds2[:, 22:22+PADS//2] = pad_embeddings[22:22+PADS//2]
|
216 |
+
return_prompts.append(prompt_embeds2)
|
217 |
+
elif emb_type == 'full_pad':
|
218 |
+
prompt_embeds2 = prompt_embeds.clone()
|
219 |
+
# Fill the 22nd to the second last embeddings with pad embeddings.
|
220 |
+
prompt_embeds2[:, 22:-1] = pad_embeddings[22:-1]
|
221 |
+
return_prompts.append(prompt_embeds2)
|
222 |
+
elif emb_type == 'core':
|
223 |
+
return_prompts.append(core_prompt_embs)
|
224 |
+
elif emb_type == 'full_zeroed_extra':
|
225 |
+
prompt_embeds2 = prompt_embeds.clone()
|
226 |
+
# Only add two pad embeddings. The remaining embeddings are set to 0.
|
227 |
+
# Make the positional embeddings align with the actual positions.
|
228 |
+
prompt_embeds2[:, 22:24] = pad_embeddings[22:24]
|
229 |
+
prompt_embeds2[:, 24:-1] = 0
|
230 |
+
return_prompts.append(prompt_embeds2)
|
231 |
+
elif emb_type == 'b_core_e':
|
232 |
+
# The first 22 embeddings, plus the last EOS embedding.
|
233 |
+
b_core_e_embs = get_b_core_e_embeddings(prompt_embeds, length=22)
|
234 |
+
return_prompts.append(b_core_e_embs)
|
235 |
+
else:
|
236 |
+
breakpoint()
|
237 |
+
|
238 |
+
return return_prompts
|
239 |
+
|
240 |
+
# if pre_face_embs is None, generate random face embeddings [BS, 512].
|
241 |
+
# image_folder is passed only for logging purpose. image_paths contains the paths of the images.
|
242 |
+
def get_arc2face_id_prompt_embs(face_app, clip_tokenizer, arc2face_text_encoder,
|
243 |
+
extract_faceid_embeds, pre_face_embs,
|
244 |
+
image_folder, image_paths, images_np,
|
245 |
+
id_batch_size, device,
|
246 |
+
input_max_length=77, noise_level=0.0,
|
247 |
+
return_core_id_embs=False,
|
248 |
+
gen_neg_prompt=False, verbose=False):
|
249 |
+
face_image_count = 0
|
250 |
+
|
251 |
+
if extract_faceid_embeds:
|
252 |
+
faceid_embeds = []
|
253 |
+
if image_paths is not None:
|
254 |
+
images_np = []
|
255 |
+
for image_path in image_paths:
|
256 |
+
image_np = np.array(Image.open(image_path))
|
257 |
+
images_np.append(image_np)
|
258 |
+
|
259 |
+
for i, image_np in enumerate(images_np):
|
260 |
+
image_obj = Image.fromarray(image_np).resize((512, 512), Image.NEAREST)
|
261 |
+
# Remove alpha channel if it exists.
|
262 |
+
if image_obj.mode == 'RGBA':
|
263 |
+
image_obj = image_obj.convert('RGB')
|
264 |
+
# This seems NOT a bug. The input image should be in BGR format, as per
|
265 |
+
# https://github.com/deepinsight/insightface/issues/524
|
266 |
+
image_np = cv2.cvtColor(np.array(image_obj), cv2.COLOR_RGB2BGR)
|
267 |
+
image_np = np.array(image_obj)
|
268 |
+
|
269 |
+
face_infos = face_app.get(image_np)
|
270 |
+
if verbose and image_paths is not None:
|
271 |
+
print(image_paths[i], len(face_infos))
|
272 |
+
# Assume all images belong to the same subject. Therefore, we can skip the images with no face detected.
|
273 |
+
if len(face_infos) == 0:
|
274 |
+
continue
|
275 |
+
# only use the maximum face
|
276 |
+
face_info = sorted(face_infos, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]
|
277 |
+
# Each faceid_embed: [1, 512]
|
278 |
+
faceid_embeds.append(torch.from_numpy(face_info.normed_embedding).unsqueeze(0))
|
279 |
+
face_image_count += 1
|
280 |
+
|
281 |
+
if verbose:
|
282 |
+
if image_folder is not None:
|
283 |
+
print(f"Extracted ID embeddings from {face_image_count} images in {image_folder}")
|
284 |
+
else:
|
285 |
+
print(f"Extracted ID embeddings from {face_image_count} images")
|
286 |
+
|
287 |
+
if len(faceid_embeds) == 0:
|
288 |
+
print("No face detected. Use a random face instead.")
|
289 |
+
faceid_embeds = torch.randn(id_batch_size, 512).to(device=device, dtype=torch.float16)
|
290 |
+
else:
|
291 |
+
# faceid_embeds: [10, 512]
|
292 |
+
faceid_embeds = torch.cat(faceid_embeds, dim=0)
|
293 |
+
# faceid_embeds: [10, 512] -> [1, 512].
|
294 |
+
# and the resulting prompt embeddings are the same.
|
295 |
+
faceid_embeds = faceid_embeds.mean(dim=0, keepdim=True).to(device=device, dtype=torch.float16)
|
296 |
+
else:
|
297 |
+
# Random face embeddings. faceid_embeds: [BS, 512].
|
298 |
+
if pre_face_embs is None:
|
299 |
+
faceid_embeds = torch.randn(id_batch_size, 512)
|
300 |
+
else:
|
301 |
+
faceid_embeds = pre_face_embs
|
302 |
+
if pre_face_embs.shape[0] == 1:
|
303 |
+
faceid_embeds = faceid_embeds.repeat(id_batch_size, 1)
|
304 |
+
|
305 |
+
faceid_embeds = faceid_embeds.to(device=device, dtype=torch.float16)
|
306 |
+
|
307 |
+
if noise_level > 0:
|
308 |
+
# If id_batch_size > 1, after adding noises, the id_batch_size embeddings will be different.
|
309 |
+
faceid_embeds = add_noise_to_tensor(faceid_embeds, noise_level, noise_std_is_relative=True, keep_norm=True)
|
310 |
+
|
311 |
+
faceid_embeds = F.normalize(faceid_embeds, p=2, dim=-1)
|
312 |
+
|
313 |
+
# arc2face_pos_prompt_emb, arc2face_neg_prompt_emb: [BS, 77, 768]
|
314 |
+
with torch.no_grad():
|
315 |
+
arc2face_pos_prompt_emb, arc2face_pos_core_prompt_emb = \
|
316 |
+
arc2face_forward_face_embs(clip_tokenizer, arc2face_text_encoder,
|
317 |
+
faceid_embeds, input_max_length=input_max_length,
|
318 |
+
return_full_and_core_embs=True)
|
319 |
+
if return_core_id_embs:
|
320 |
+
arc2face_pos_prompt_emb = arc2face_pos_core_prompt_emb
|
321 |
+
# If extract_faceid_embeds, we assume all images are from the same subject, and the batch dim of faceid_embeds is 1.
|
322 |
+
# So we need to repeat faceid_embeds.
|
323 |
+
if extract_faceid_embeds:
|
324 |
+
faceid_embeds = faceid_embeds.repeat(id_batch_size, 1)
|
325 |
+
arc2face_pos_prompt_emb = arc2face_pos_prompt_emb.repeat(id_batch_size, 1, 1)
|
326 |
+
|
327 |
+
if gen_neg_prompt:
|
328 |
+
with torch.no_grad():
|
329 |
+
arc2face_neg_prompt_emb, arc2face_neg_core_prompt_emb = \
|
330 |
+
arc2face_forward_face_embs(clip_tokenizer, arc2face_text_encoder,
|
331 |
+
torch.zeros_like(faceid_embeds),
|
332 |
+
input_max_length=input_max_length,
|
333 |
+
return_full_and_core_embs=True)
|
334 |
+
if return_core_id_embs:
|
335 |
+
arc2face_neg_prompt_emb = arc2face_neg_core_prompt_emb
|
336 |
+
|
337 |
+
#if extract_faceid_embeds:
|
338 |
+
# arc2face_neg_prompt_emb = arc2face_neg_prompt_emb.repeat(id_batch_size, 1, 1)
|
339 |
+
return face_image_count, faceid_embeds, arc2face_pos_prompt_emb, arc2face_neg_prompt_emb
|
340 |
+
else:
|
341 |
+
return face_image_count, faceid_embeds, arc2face_pos_prompt_emb
|
342 |
+
|
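Before moving on to the wrapper, a hedged sketch of how get_arc2face_id_prompt_embs ties the pieces above together. face_app, clip_tokenizer and arc2face_text_encoder are assumed to be initialized as in AdaFaceWrapper.initialize_pipeline() below; the image paths are placeholders.

import torch

face_image_count, faceid_embeds, id_prompt_emb = get_arc2face_id_prompt_embs(
    face_app, clip_tokenizer, arc2face_text_encoder,   # assumed to be initialized elsewhere
    extract_faceid_embeds=True, pre_face_embs=None,
    image_folder=None, image_paths=["subject1.jpg", "subject2.jpg"],  # placeholder paths
    images_np=None, id_batch_size=1, device="cuda",
    input_max_length=22, noise_level=0.0,
    return_core_id_embs=True, gen_neg_prompt=False, verbose=True)
print(faceid_embeds.shape, id_prompt_emb.shape)   # [1, 512], [1, 16, 768]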
adaface_wrapper.py
ADDED
@@ -0,0 +1,297 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from transformers import CLIPTextModel
|
4 |
+
from diffusers import (
|
5 |
+
StableDiffusionPipeline,
|
6 |
+
StableDiffusionImg2ImgPipeline,
|
7 |
+
UNet2DConditionModel,
|
8 |
+
DDIMScheduler,
|
9 |
+
AutoencoderKL,
|
10 |
+
)
|
11 |
+
from insightface.app import FaceAnalysis
|
12 |
+
from adaface.arc2face_models import CLIPTextModelWrapper
|
13 |
+
from adaface.util import get_arc2face_id_prompt_embs
|
14 |
+
import re, os
|
15 |
+
import sys
|
16 |
+
sys.modules['ldm'] = sys.modules['adaface']
|
17 |
+
|
18 |
+
class AdaFaceWrapper(nn.Module):
|
19 |
+
def __init__(self, pipeline_name, base_model_path, adaface_ckpt_path, device,
|
20 |
+
subject_string='z', num_vectors=16,
|
21 |
+
num_inference_steps=50, negative_prompt=None,
|
22 |
+
use_840k_vae=False, use_ds_text_encoder=False, is_training=False):
|
23 |
+
'''
|
24 |
+
pipeline_name: "text2img" or "img2img" or None. If None, the unet and vae are
|
25 |
+
removed from the pipeline to release RAM.
|
26 |
+
'''
|
27 |
+
super().__init__()
|
28 |
+
self.pipeline_name = pipeline_name
|
29 |
+
self.base_model_path = base_model_path
|
30 |
+
self.adaface_ckpt_path = adaface_ckpt_path
|
31 |
+
self.use_840k_vae = use_840k_vae
|
32 |
+
self.use_ds_text_encoder = use_ds_text_encoder
|
33 |
+
self.subject_string = subject_string
|
34 |
+
self.num_vectors = num_vectors
|
35 |
+
self.num_inference_steps = num_inference_steps
|
36 |
+
self.device = device
|
37 |
+
self.is_training = is_training
|
38 |
+
self.initialize_pipeline()
|
39 |
+
self.extend_tokenizer_and_text_encoder()
|
40 |
+
if negative_prompt is None:
|
41 |
+
self.negative_prompt = \
|
42 |
+
"flaws in the eyes, flaws in the face, lowres, non-HDRi, low quality, worst quality, artifacts, noise, text, watermark, glitch, " \
|
43 |
+
"mutated, ugly, disfigured, hands, partially rendered objects, partially rendered eyes, deformed eyeballs, cross-eyed, blurry, " \
|
44 |
+
"mutation, duplicate, out of frame, cropped, mutilated, bad anatomy, deformed, bad proportions, " \
|
45 |
+
"nude, naked, nsfw, topless, bare breasts"
|
46 |
+
else:
|
47 |
+
self.negative_prompt = negative_prompt
|
48 |
+
|
49 |
+
def load_subj_basis_generator(self, adaface_ckpt_path):
|
50 |
+
ckpt = torch.load(adaface_ckpt_path, map_location='cpu')
|
51 |
+
string_to_subj_basis_generator_dict = ckpt["string_to_subj_basis_generator_dict"]
|
52 |
+
if self.subject_string not in string_to_subj_basis_generator_dict:
|
53 |
+
print(f"Subject '{self.subject_string}' not found in the embedding manager.")
|
54 |
+
breakpoint()
|
55 |
+
|
56 |
+
self.subj_basis_generator = string_to_subj_basis_generator_dict[self.subject_string]
|
57 |
+
# In the original ckpt, num_out_layers is 16 for layerwise embeddings.
|
58 |
+
# But we don't do layerwise embeddings here, so we set it to 1.
|
59 |
+
self.subj_basis_generator.num_out_layers = 1
|
60 |
+
print(f"Loaded subject basis generator for '{self.subject_string}'.")
|
61 |
+
print(repr(self.subj_basis_generator))
|
62 |
+
self.subj_basis_generator.to(self.device)
|
63 |
+
if self.is_training:
|
64 |
+
self.subj_basis_generator.train()
|
65 |
+
else:
|
66 |
+
self.subj_basis_generator.eval()
|
67 |
+
|
68 |
+
def initialize_pipeline(self):
|
69 |
+
self.load_subj_basis_generator(self.adaface_ckpt_path)
|
70 |
+
# arc2face_text_encoder maps the face analysis embedding to 16 face embeddings
|
71 |
+
# in the UNet image space.
|
72 |
+
arc2face_text_encoder = CLIPTextModelWrapper.from_pretrained(
|
73 |
+
'models/arc2face', subfolder="encoder", torch_dtype=torch.float16
|
74 |
+
)
|
75 |
+
self.arc2face_text_encoder = arc2face_text_encoder.to(self.device)
|
76 |
+
|
77 |
+
if self.use_840k_vae:
|
78 |
+
# The 840000-step vae model is slightly better in face details than the original vae model.
|
79 |
+
# https://huggingface.co/stabilityai/sd-vae-ft-mse-original
|
80 |
+
vae = AutoencoderKL.from_single_file("models/diffusers/sd-vae-ft-mse-original/vae-ft-mse-840000-ema-pruned.ckpt", torch_dtype=torch.float16)
|
81 |
+
else:
|
82 |
+
vae = None
|
83 |
+
|
84 |
+
if self.use_ds_text_encoder:
|
85 |
+
# The dreamshaper v7 finetuned text encoder follows the prompt slightly better than the original text encoder.
|
86 |
+
# https://huggingface.co/Lykon/DreamShaper/tree/main/text_encoder
|
87 |
+
text_encoder = CLIPTextModel.from_pretrained("models/ds_text_encoder", torch_dtype=torch.float16)
|
88 |
+
else:
|
89 |
+
text_encoder = None
|
90 |
+
|
91 |
+
remove_unet = False
|
92 |
+
|
93 |
+
if self.pipeline_name == "img2img":
|
94 |
+
PipelineClass = StableDiffusionImg2ImgPipeline
|
95 |
+
elif self.pipeline_name == "text2img":
|
96 |
+
PipelineClass = StableDiffusionPipeline
|
97 |
+
# pipeline_name is None means only use this instance to generate adaface embeddings, not to generate images.
|
98 |
+
elif self.pipeline_name is None:
|
99 |
+
PipelineClass = StableDiffusionPipeline
|
100 |
+
remove_unet = True
|
101 |
+
else:
|
102 |
+
raise ValueError(f"Unknown pipeline name: {self.pipeline_name}")
|
103 |
+
|
104 |
+
if os.path.isfile(self.base_model_path):
|
105 |
+
pipeline = PipelineClass.from_single_file(
|
106 |
+
self.base_model_path,
|
107 |
+
torch_dtype=torch.float16
|
108 |
+
)
|
109 |
+
else:
|
110 |
+
pipeline = PipelineClass.from_pretrained(
|
111 |
+
self.base_model_path,
|
112 |
+
torch_dtype=torch.float16,
|
113 |
+
safety_checker=None
|
114 |
+
)
|
115 |
+
print(f"Loaded pipeline from {self.base_model_path}.")
|
116 |
+
|
117 |
+
if self.use_840k_vae:
|
118 |
+
pipeline.vae = vae
|
119 |
+
print("Replaced the VAE with the 840k-step VAE.")
|
120 |
+
|
121 |
+
if self.use_ds_text_encoder:
|
122 |
+
pipeline.text_encoder = text_encoder
|
123 |
+
print("Replaced the text encoder with the DreamShaper text encoder.")
|
124 |
+
|
125 |
+
if remove_unet:
|
126 |
+
# Remove unet and vae to release RAM. Only keep tokenizer and text_encoder.
|
127 |
+
pipeline.unet = None
|
128 |
+
pipeline.vae = None
|
129 |
+
print("Removed UNet and VAE from the pipeline.")
|
130 |
+
|
131 |
+
noise_scheduler = DDIMScheduler(
|
132 |
+
num_train_timesteps=1000,
|
133 |
+
beta_start=0.00085,
|
134 |
+
beta_end=0.012,
|
135 |
+
beta_schedule="scaled_linear",
|
136 |
+
clip_sample=False,
|
137 |
+
set_alpha_to_one=False,
|
138 |
+
steps_offset=1,
|
139 |
+
)
|
140 |
+
|
141 |
+
pipeline.scheduler = noise_scheduler
|
142 |
+
self.pipeline = pipeline.to(self.device)
|
143 |
+
# FaceAnalysis will try to find the ckpt in: models/insightface/models/antelopev2.
|
144 |
+
# Note there's a second "model" in the path.
|
145 |
+
self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
|
146 |
+
self.face_app.prepare(ctx_id=0, det_size=(512, 512))
|
147 |
+
# Patch the missing tokenizer in the subj_basis_generator.
|
148 |
+
if not hasattr(self.subj_basis_generator, 'clip_tokenizer'):
|
149 |
+
self.subj_basis_generator.clip_tokenizer = self.pipeline.tokenizer
|
150 |
+
print("Patched the missing tokenizer in the subj_basis_generator.")
|
151 |
+
|
152 |
+
def extend_tokenizer_and_text_encoder(self):
|
153 |
+
if self.num_vectors < 1:
|
154 |
+
raise ValueError(f"num_vectors has to be larger or equal to 1, but is {self.num_vectors}")
|
155 |
+
|
156 |
+
tokenizer = self.pipeline.tokenizer
|
157 |
+
# Add z0, z1, z2, ..., z15.
|
158 |
+
self.placeholder_tokens = []
|
159 |
+
for i in range(0, self.num_vectors):
|
160 |
+
self.placeholder_tokens.append(f"{self.subject_string}_{i}")
|
161 |
+
|
162 |
+
self.placeholder_tokens_str = " ".join(self.placeholder_tokens)
|
163 |
+
|
164 |
+
# Add the new tokens to the tokenizer.
|
165 |
+
num_added_tokens = tokenizer.add_tokens(self.placeholder_tokens)
|
166 |
+
if num_added_tokens != self.num_vectors:
|
167 |
+
raise ValueError(
|
168 |
+
f"The tokenizer already contains the token {self.subject_string}. Please pass a different"
|
169 |
+
" `subject_string` that is not already in the tokenizer.")
|
170 |
+
|
171 |
+
print(f"Added {num_added_tokens} tokens ({self.placeholder_tokens_str}) to the tokenizer.")
|
172 |
+
|
173 |
+
# placeholder_token_ids: [49408, ..., 49423].
|
174 |
+
self.placeholder_token_ids = tokenizer.convert_tokens_to_ids(self.placeholder_tokens)
|
175 |
+
# print(self.placeholder_token_ids)
|
176 |
+
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
177 |
+
old_weight = self.pipeline.text_encoder.get_input_embeddings().weight
|
178 |
+
self.pipeline.text_encoder.resize_token_embeddings(len(tokenizer))
|
179 |
+
new_weight = self.pipeline.text_encoder.get_input_embeddings().weight
|
180 |
+
print(f"Resized text encoder token embeddings from {old_weight.shape} to {new_weight.shape} on {new_weight.device}.")
|
181 |
+
|
182 |
+
# Extend pipeline.text_encoder with the adaface subject embeddings.
|
183 |
+
# subj_embs: [16, 768].
|
184 |
+
def update_text_encoder_subj_embs(self, subj_embs):
|
185 |
+
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
186 |
+
token_embeds = self.pipeline.text_encoder.get_input_embeddings().weight.data
|
187 |
+
with torch.no_grad():
|
188 |
+
for i, token_id in enumerate(self.placeholder_token_ids):
|
189 |
+
token_embeds[token_id] = subj_embs[i]
|
190 |
+
print(f"Updated {len(self.placeholder_token_ids)} tokens ({self.placeholder_tokens_str}) in the text encoder.")
|
191 |
+
|
192 |
+
def update_prompt(self, prompt):
|
193 |
+
# If the placeholder tokens are already in the prompt, then return the prompt as is.
|
194 |
+
if self.placeholder_tokens_str in prompt:
|
195 |
+
return prompt
|
196 |
+
|
197 |
+
# If the subject string 'z' is not in the prompt, then simply prepend the placeholder tokens to the prompt.
|
198 |
+
if re.search(r'\b' + self.subject_string + r'\b', prompt) is None:
|
199 |
+
print(f"Subject string '{self.subject_string}' not found in the prompt. Adding it.")
|
200 |
+
comp_prompt = self.placeholder_tokens_str + " " + prompt
|
201 |
+
else:
|
202 |
+
# Replace the subject string 'z' with the placeholder tokens.
|
203 |
+
comp_prompt = re.sub(r'\b' + self.subject_string + r'\b', self.placeholder_tokens_str, prompt)
|
204 |
+
return comp_prompt
|
205 |
+
|
206 |
+
# image_paths: a list of image paths. image_folder: the parent folder name.
|
207 |
+
def generate_adaface_embeddings(self, image_paths, image_folder=None,
|
208 |
+
pre_face_embs=None, gen_rand_face=False,
|
209 |
+
out_id_embs_scale=1., noise_level=0, update_text_encoder=True):
|
210 |
+
# faceid_embeds is a batch of extracted face analysis embeddings (BS * 512 = id_batch_size * 512).
|
211 |
+
# If extract_faceid_embeds is True, faceid_embeds is *the same* embedding repeated by id_batch_size times.
|
212 |
+
# Otherwise, faceid_embeds is a batch of random embeddings, each instance is different.
|
213 |
+
# The same applies to id_prompt_emb.
|
214 |
+
# faceid_embeds is in the face analysis embeddings. id_prompt_emb is in the image prompt space.
|
215 |
+
# Here id_batch_size = 1, so
|
216 |
+
# faceid_embeds: [1, 512]. NOT used later.
|
217 |
+
# id_prompt_emb: [1, 16, 768].
|
218 |
+
# NOTE: Since return_core_id_embs is True, id_prompt_emb is only the 16 core ID embeddings.
|
219 |
+
# arc2face prompt template: "photo of a id person"
|
220 |
+
# ID embeddings start from "id person ...". So there are 3 template tokens before the 16 ID embeddings.
|
221 |
+
face_image_count, faceid_embeds, id_prompt_emb \
|
222 |
+
= get_arc2face_id_prompt_embs(self.face_app, self.pipeline.tokenizer, self.arc2face_text_encoder,
|
223 |
+
extract_faceid_embeds=not gen_rand_face,
|
224 |
+
pre_face_embs=pre_face_embs,
|
225 |
+
# image_folder is passed only for logging purpose.
|
226 |
+
# image_paths contains the paths of the images.
|
227 |
+
image_folder=image_folder, image_paths=image_paths,
|
228 |
+
images_np=None,
|
229 |
+
id_batch_size=1,
|
230 |
+
device=self.device,
|
231 |
+
# input_max_length == 22: only keep the first 22 tokens,
|
232 |
+
# including 3 template tokens and 16 ID tokens, and BOS and EOS tokens.
|
233 |
+
# The results are indistinguishable from input_max_length=77.
|
234 |
+
input_max_length=22,
|
235 |
+
noise_level=noise_level,
|
236 |
+
return_core_id_embs=True,
|
237 |
+
gen_neg_prompt=False,
|
238 |
+
verbose=True)
|
239 |
+
|
240 |
+
if face_image_count == 0:
|
241 |
+
return None
|
242 |
+
|
243 |
+
# adaface_subj_embs: [1, 1, 16, 768].
|
244 |
+
# adaface_prompt_embs: [1, 77, 768] (not used).
|
245 |
+
adaface_subj_embs, adaface_prompt_embs = \
|
246 |
+
self.subj_basis_generator(id_prompt_emb, None, None,
|
247 |
+
out_id_embs_scale=out_id_embs_scale,
|
248 |
+
is_face=True, is_training=False,
|
249 |
+
adaface_prompt_embs_inf_type='full_half_pad')
|
250 |
+
# adaface_subj_embs: [16, 768]
|
251 |
+
adaface_subj_embs = adaface_subj_embs.squeeze()
|
252 |
+
if update_text_encoder:
|
253 |
+
self.update_text_encoder_subj_embs(adaface_subj_embs)
|
254 |
+
return adaface_subj_embs
|
255 |
+
|
256 |
+
def encode_prompt(self, prompt, negative_prompt=None, device="cuda", verbose=False):
|
257 |
+
if negative_prompt is None:
|
258 |
+
negative_prompt = self.negative_prompt
|
259 |
+
|
260 |
+
prompt = self.update_prompt(prompt)
|
261 |
+
if verbose:
|
262 |
+
print(f"Prompt: {prompt}")
|
263 |
+
|
264 |
+
# For some unknown reason, the text_encoder is still on CPU after self.pipeline.to(self.device).
|
265 |
+
# So we manually move it to GPU here.
|
266 |
+
self.pipeline.text_encoder.to(device)
|
267 |
+
# prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
|
268 |
+
prompt_embeds_, negative_prompt_embeds_ = \
|
269 |
+
self.pipeline.encode_prompt(prompt, device=device, num_images_per_prompt=1,
|
270 |
+
do_classifier_free_guidance=True, negative_prompt=negative_prompt)
|
271 |
+
return prompt_embeds_, negative_prompt_embeds_
|
272 |
+
|
273 |
+
# ref_img_strength is used only in the img2img pipeline.
|
274 |
+
def forward(self, noise, prompt, negative_prompt=None, guidance_scale=4.0,
|
275 |
+
out_image_count=4, ref_img_strength=0.8, generator=None, verbose=False):
|
276 |
+
if negative_prompt is None:
|
277 |
+
negative_prompt = self.negative_prompt
|
278 |
+
# prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
|
279 |
+
prompt_embeds_, negative_prompt_embeds_ = self.encode_prompt(prompt, negative_prompt, device=self.device, verbose=verbose)
|
280 |
+
# Repeat the prompt embeddings for all images in the batch.
|
281 |
+
prompt_embeds_ = prompt_embeds_.repeat(out_image_count, 1, 1)
|
282 |
+
negative_prompt_embeds_ = negative_prompt_embeds_.repeat(out_image_count, 1, 1)
|
283 |
+
noise = noise.to(self.device).to(torch.float16)
|
284 |
+
|
285 |
+
# noise: [BS, 4, 64, 64]
|
286 |
+
# When the pipeline is text2img, strength is ignored.
|
287 |
+
images = self.pipeline(image=noise,
|
288 |
+
prompt_embeds=prompt_embeds_,
|
289 |
+
negative_prompt_embeds=negative_prompt_embeds_,
|
290 |
+
num_inference_steps=self.num_inference_steps,
|
291 |
+
guidance_scale=guidance_scale,
|
292 |
+
num_images_per_prompt=1,
|
293 |
+
strength=ref_img_strength,
|
294 |
+
generator=generator).images
|
295 |
+
# images: [BS, 3, 512, 512]
|
296 |
+
return images
|
297 |
+
|
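A hedged end-to-end sketch of AdaFaceWrapper (it mirrors app.py below). The model and checkpoint paths follow the repo layout; the reference photo paths are placeholders.

import torch

adaface = AdaFaceWrapper(pipeline_name="text2img",
                         base_model_path="models/sar/sar.safetensors",
                         adaface_ckpt_path="models/adaface/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt",
                         device="cuda")
# Encode the subject's identity from reference photos into the 16 placeholder tokens.
subj_embs = adaface.generate_adaface_embeddings(image_paths=["ref1.jpg", "ref2.jpg"],  # placeholders
                                                out_id_embs_scale=1.0, update_text_encoder=True)
generator = torch.Generator(device="cuda").manual_seed(42)
noise = torch.randn(4, 3, 512, 512, device="cuda", generator=generator)  # as in app.py
images = adaface(noise, "woman walking on the beach, sunset, orange sky",
                 guidance_scale=4.0, out_image_count=4, generator=generator)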
app.py
ADDED
@@ -0,0 +1,203 @@
1 |
+
import sys
|
2 |
+
sys.path.append('./')
|
3 |
+
|
4 |
+
from adaface.adaface_wrapper import AdaFaceWrapper
|
5 |
+
import torch
|
6 |
+
from insightface.app import FaceAnalysis
|
7 |
+
from PIL import Image
|
8 |
+
import numpy as np
|
9 |
+
import random
|
10 |
+
|
11 |
+
import gradio as gr
|
12 |
+
import spaces
|
13 |
+
import argparse
|
14 |
+
parser = argparse.ArgumentParser()
|
15 |
+
parser.add_argument('--adaface_ckpt_path', type=str,
|
16 |
+
default='models/adaface/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt')
|
17 |
+
parser.add_argument('--gpu', type=int, default=None)
|
18 |
+
parser.add_argument('--ip', type=str, default="0.0.0.0")
|
19 |
+
args = parser.parse_args()
|
20 |
+
|
21 |
+
# global variable
|
22 |
+
MAX_SEED = np.iinfo(np.int32).max
|
23 |
+
if torch.cuda.is_available():
|
24 |
+
device = "cuda" if args.gpu is None else f"cuda:{args.gpu}"
|
25 |
+
else:
|
26 |
+
device = "cpu"
|
27 |
+
dtype = torch.float16
|
28 |
+
|
29 |
+
# base_model_path is only used for initialization, not really used in the inference.
|
30 |
+
adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path="models/sar/sar.safetensors",
|
31 |
+
adaface_ckpt_path=args.adaface_ckpt_path, device=device)
|
32 |
+
|
33 |
+
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
34 |
+
if randomize_seed:
|
35 |
+
seed = random.randint(0, MAX_SEED)
|
36 |
+
return seed
|
37 |
+
|
38 |
+
def swap_to_gallery(images):
|
39 |
+
# Update uploaded_files_gallery, show files, hide clear_button_column
|
40 |
+
# Or:
|
41 |
+
# Update uploaded_init_img_gallery, show init_img_files, hide init_clear_button_column
|
42 |
+
return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(value=images, visible=False)
|
43 |
+
|
44 |
+
def remove_back_to_files():
|
45 |
+
# Hide uploaded_files_gallery, show clear_button_column, hide files, reset init_img_selected_idx
|
46 |
+
# Or:
|
47 |
+
# Hide uploaded_init_img_gallery, hide init_clear_button_column, show init_img_files, reset init_img_selected_idx
|
48 |
+
return gr.update(visible=False), gr.update(visible=False), gr.update(value=None, visible=True)
|
49 |
+
|
50 |
+
def update_out_gallery(images):
|
51 |
+
#rows = (len(images) + 1) // 2 # Calculate the number of rows needed
|
52 |
+
return gr.update(height=600)
|
53 |
+
|
54 |
+
@spaces.GPU
|
55 |
+
def generate_image(image_paths, guidance_scale, adaface_id_cfg_scale,
|
56 |
+
num_images, prompt, negative_prompt, seed, progress=gr.Progress(track_tqdm=True)):
|
57 |
+
|
58 |
+
if image_paths is None or len(image_paths) == 0:
|
59 |
+
raise gr.Error(f"Cannot find any input face image! Please upload a face image.")
|
60 |
+
|
61 |
+
if prompt is None:
|
62 |
+
prompt = ""
|
63 |
+
|
64 |
+
adaface_subj_embs = \
|
65 |
+
adaface.generate_adaface_embeddings(image_folder=None, image_paths=image_paths,
|
66 |
+
out_id_embs_scale=adaface_id_cfg_scale, update_text_encoder=True)
|
67 |
+
|
68 |
+
if adaface_subj_embs is None:
|
69 |
+
raise gr.Error(f"Failed to detect any faces! Please try with other images")
|
70 |
+
|
71 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
72 |
+
print(f"Manual seed: {seed}")
|
73 |
+
# Generate num_images images for the user to select from.
|
74 |
+
noise = torch.randn(num_images, 3, 512, 512, device=device, generator=generator)
|
75 |
+
#print(noise.abs().sum())
|
76 |
+
# samples: A list of PIL Image instances.
|
77 |
+
samples = adaface(noise, prompt, negative_prompt, guidance_scale=guidance_scale, out_image_count=num_images, generator=generator, verbose=True)
|
78 |
+
return samples
|
79 |
+
|
80 |
+
### Description
|
81 |
+
title = r"""
|
82 |
+
<h1>AdaFace: A Versatile Face Encoder for Zero-Shot Diffusion Model Personalization</h1>
|
83 |
+
"""
|
84 |
+
|
85 |
+
description = r"""
|
86 |
+
<b>Official demo</b> for our NeurIPS 2024 submission <b>AdaFace: A Versatile Face Encoder for Zero-Shot Diffusion Model Personalization</b>.<br>
|
87 |
+
|
88 |
+
❗️**Tips**❗️
|
89 |
+
1. Upload one or more images of a person. If multiple faces are detected, we use the largest one.
|
90 |
+
2. Increase <b>AdaFace CFG Scale</b> (preferred) and/or <b>Guidance scale</b> to highlight fine facial features.
|
91 |
+
3. AdaFace Text-to-Video: <a href="https://huggingface.co/spaces/adaface-neurips/adaface-animate" style="display: inline-flex; align-items: center;">
|
92 |
+
AdaFace-Animate
|
93 |
+
<img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-yellow" alt="Hugging Face Spaces" style="margin-left: 5px;">
|
94 |
+
</a>
|
95 |
+
|
96 |
+
**TODO**
|
97 |
+
- ControlNet integration.
|
98 |
+
"""
|
99 |
+
|
100 |
+
css = '''
|
101 |
+
.gradio-container {width: 85% !important}
|
102 |
+
'''
|
103 |
+
with gr.Blocks(css=css) as demo:
|
104 |
+
|
105 |
+
# description
|
106 |
+
gr.Markdown(title)
|
107 |
+
gr.Markdown(description)
|
108 |
+
|
109 |
+
with gr.Row():
|
110 |
+
with gr.Column():
|
111 |
+
|
112 |
+
# upload face image
|
113 |
+
# img_file = gr.Image(label="Upload a photo with a face", type="filepath")
|
114 |
+
img_files = gr.File(
|
115 |
+
label="Drag / Select 1 or more photos of a person's face",
|
116 |
+
file_types=["image"],
|
117 |
+
file_count="multiple"
|
118 |
+
)
|
119 |
+
uploaded_files_gallery = gr.Gallery(label="Subject images", visible=False, columns=3, rows=1, height=300)
|
120 |
+
with gr.Column(visible=False) as clear_button_column:
|
121 |
+
remove_and_reupload = gr.ClearButton(value="Remove and upload subject images", components=img_files, size="sm")
|
122 |
+
|
123 |
+
prompt = gr.Dropdown(label="Prompt",
|
124 |
+
info="Try something like 'man/woman walking on the beach'. If the face is not in focus, try adding 'face portrait of' at the beginning.",
|
125 |
+
value=None,
|
126 |
+
allow_custom_value=True,
|
127 |
+
filterable=False,
|
128 |
+
choices=[
|
129 |
+
"woman ((best quality)), ((masterpiece)), ((realistic)), long highlighted hair, futuristic silver armor suit, confident stance, high-resolution, living room, smiling, head tilted, perfect smooth skin",
|
130 |
+
"woman walking on the beach, sunset, orange sky",
|
131 |
+
"woman in a white apron and chef hat, garnishing a gourmet dish, full body view, long shot",
|
132 |
+
"woman dancing pose among folks in a park, waving hands",
|
133 |
+
"woman in iron man costume flying pose, the sky ablaze with hues of orange and purple, full body view, long shot",
|
134 |
+
"woman jedi wielding a lightsaber, star wars, full body view, eye level shot",
|
135 |
+
"woman playing guitar on a boat, ocean waves",
|
136 |
+
"woman with a passion for reading, curled up with a book in a cozy nook near a window",
|
137 |
+
"woman running pose in a park, eye level shot",
|
138 |
+
"woman in superman costume flying pose, the sky ablaze with hues of orange and purple, full body view, long shot"
|
139 |
+
])
|
140 |
+
|
141 |
+
submit = gr.Button("Submit", variant="primary")
|
142 |
+
|
143 |
+
negative_prompt = gr.Textbox(
|
144 |
+
label="Negative Prompt",
|
145 |
+
value="flaws in the eyes, flaws in the face, lowres, non-HDRi, low quality, worst quality, artifacts, noise, text, watermark, glitch, mutated, ugly, disfigured, hands, partially rendered objects, partially rendered eyes, deformed eyeballs, cross-eyed, blurry, mutation, duplicate, out of frame, cropped, mutilated, bad anatomy, deformed, bad proportions, nude, naked, nsfw, topless, bare breasts",
|
146 |
+
)
|
147 |
+
|
148 |
+
adaface_id_cfg_scale = gr.Slider(
|
149 |
+
label="AdaFace CFG Scale",
|
150 |
+
info="The CFG scale of the AdaFace ID embeddings (influencing fine facial features)",
|
151 |
+
minimum=0.5,
|
152 |
+
maximum=8.0,
|
153 |
+
step=0.5,
|
154 |
+
value=4.0,
|
155 |
+
)
|
156 |
+
|
157 |
+
guidance_scale = gr.Slider(
|
158 |
+
label="Guidance scale",
|
159 |
+
minimum=0.5,
|
160 |
+
maximum=8.0,
|
161 |
+
step=0.5,
|
162 |
+
value=4.0,
|
163 |
+
)
|
164 |
+
|
165 |
+
num_images = gr.Slider(
|
166 |
+
label="Number of output images",
|
167 |
+
minimum=1,
|
168 |
+
maximum=6,
|
169 |
+
step=1,
|
170 |
+
value=4,
|
171 |
+
)
|
172 |
+
seed = gr.Slider(
|
173 |
+
label="Seed",
|
174 |
+
minimum=0,
|
175 |
+
maximum=MAX_SEED,
|
176 |
+
step=1,
|
177 |
+
value=0,
|
178 |
+
)
|
179 |
+
randomize_seed = gr.Checkbox(label="Randomize seed", value=True, info="Uncheck for reproducible results")
|
180 |
+
|
181 |
+
with gr.Column():
|
182 |
+
out_gallery = gr.Gallery(label="Generated Images", columns=2, rows=2, height=600)
|
183 |
+
|
184 |
+
img_files.upload(fn=swap_to_gallery, inputs=img_files, outputs=[uploaded_files_gallery, clear_button_column, img_files])
|
185 |
+
remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files_gallery, clear_button_column, img_files])
|
186 |
+
|
187 |
+
submit.click(
|
188 |
+
fn=randomize_seed_fn,
|
189 |
+
inputs=[seed, randomize_seed],
|
190 |
+
outputs=seed,
|
191 |
+
queue=False,
|
192 |
+
api_name=False,
|
193 |
+
).then(
|
194 |
+
fn=generate_image,
|
195 |
+
inputs=[img_files, guidance_scale, adaface_id_cfg_scale, num_images, prompt, negative_prompt, seed],
|
196 |
+
outputs=[out_gallery]
|
197 |
+
).then(
|
198 |
+
fn=update_out_gallery,
|
199 |
+
inputs=[out_gallery],
|
200 |
+
outputs=[out_gallery]
|
201 |
+
)
|
202 |
+
|
203 |
+
demo.launch(share=True, server_name=args.ip, ssl_verify=False)
|
arc2face_models.py
ADDED
@@ -0,0 +1,303 @@
import torch
import torch.nn as nn
from transformers import CLIPTextModel
from transformers.models.clip.modeling_clip import CLIPAttention
from typing import Any, Callable, Dict, Optional, Tuple, Union, List
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
# from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask
_make_causal_mask = AttentionMaskConverter._make_causal_mask
_expand_mask = AttentionMaskConverter._expand_mask

from adaface.util import add_noise_to_tensor

# Extend CLIPAttention by using multiple k_proj and v_proj in each head.
# To avoid too much increase of computation, we don't extend q_proj.
class CLIPAttentionMKV(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config, multiplier=2):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout
        self.multiplier = multiplier

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim * self.multiplier)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim * self.multiplier)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    # The (approximately) repeated token features are repeated along the last dim in tensor
    # (multiplier * num_heads * head_dim), and then reshaped to (bsz, -1, num_heads, head_dim).
    # Therefore, the "multiplier" dim is tucked into the seq_len dim, which looks like
    # [token1_emb, token1_emb, token2_emb, token2_emb, ..., tokenN_emb, tokenN_emb].
    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def extend_weights(self, clip_attn_layer, layer_idx, multiplier, noise_std=0.1,
                       noise_std_is_relative=True, keep_norm=False, verbose=False):
        self.multiplier *= multiplier
        # q_proj and out_proj are the same as the original CLIPAttention.
        self.q_proj.weight.data = clip_attn_layer.q_proj.weight.data.clone()
        self.q_proj.bias.data = clip_attn_layer.q_proj.bias.data.clone()
        self.out_proj.weight.data = clip_attn_layer.out_proj.weight.data.clone()
        self.out_proj.bias.data = clip_attn_layer.out_proj.bias.data.clone()

        # bias doesn't need noise perturbation, as after the weights are noised,
        # different copies of the weight/bias will receive different gradients,
        # making the bias terms diverge and identifiable after training.
        self.v_proj.bias.data = clip_attn_layer.v_proj.bias.data.repeat(multiplier)
        self.k_proj.bias.data = clip_attn_layer.k_proj.bias.data.repeat(multiplier)

        self.v_proj.weight.data = clip_attn_layer.v_proj.weight.data.repeat(multiplier, 1)
        self.k_proj.weight.data = clip_attn_layer.k_proj.weight.data.repeat(multiplier, 1)

        if noise_std > 0:
            ORIG_V_SHAPE = list(clip_attn_layer.v_proj.weight.shape)
            ORIG_V_SHAPE_D0 = ORIG_V_SHAPE[0]
            # Adding noise to the extra copies of the weights (keep the first copy unchanged).
            self.v_proj.weight.data[ORIG_V_SHAPE_D0:] = \
                add_noise_to_tensor(self.v_proj.weight.data[ORIG_V_SHAPE_D0:],
                                    noise_std, noise_std_is_relative, keep_norm)
            if verbose:
                NEW_V_SHAPE = list(self.v_proj.weight.shape)
                NOISED_V_SHAPE = list(self.v_proj.weight.data[ORIG_V_SHAPE_D0:].shape)
                print(f"Layer {layer_idx}: {NOISED_V_SHAPE} in {NEW_V_SHAPE} of v_proj is added with {noise_std} noise")

            ORIG_K_SHAPE = list(clip_attn_layer.k_proj.weight.shape)
            ORIG_K_SHAPE_D0 = ORIG_K_SHAPE[0]
            # Adding noise to the extra copies of the weights.
            self.k_proj.weight.data[ORIG_K_SHAPE_D0:] = \
                add_noise_to_tensor(self.k_proj.weight.data[ORIG_K_SHAPE_D0:],
                                    noise_std, noise_std_is_relative, keep_norm)
            if verbose:
                NEW_K_SHAPE = list(self.k_proj.weight.shape)
                NOISED_K_SHAPE = list(self.k_proj.weight.data[ORIG_K_SHAPE_D0:].shape)
                print(f"Layer {layer_idx}: {NOISED_K_SHAPE} in {NEW_K_SHAPE} of k_proj is added with {noise_std} noise")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        bsz, tgt_len, embed_dim = hidden_states.size()

        query_states = self.q_proj(hidden_states) * self.scale
        # For key_states and value_states, the multiplier is absorbed into the seq_len (dim 1, shape specified as -1).
        # [token0_head_emb, token0_head_emb, token1_head_emb, token1_head_emb, ..., tokenN-1_head_emb, tokenN-1_head_emb].
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        # src_len0 is the original src_len without the multiplier.
        src_len0 = src_len // self.multiplier
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        # apply the causal_attention_mask first
        if causal_attention_mask is not None:
            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len0):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len0)}, but is"
                    f" {causal_attention_mask.size()}"
                )
            # The last dim of attn_weights corresponds to [token0, token0, token1, token1, ..., tokenN-1, tokenN-1].
            # If reshaping it as (self.multiplier, src_len0), it will become
            # [[token0, token0, token1, token1, ..., tokenN//2], [tokenN//2+1, tokenN//2+1, ..., tokenN-1, tokenN-1]],
            # and the mask will be applied to wrong elements.
            # If reshaping it as (src_len0, self.multiplier), it will become
            # [[token0, token1, ..., tokenN-1], [token0, token1, ..., tokenN-1]], and then
            # the mask at element i will mask all the multiplier elements at i, which is desired.
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len0, self.multiplier) + causal_attention_mask.unsqueeze(4)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len0):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len0)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len0, self.multiplier) + attention_mask.unsqueeze(4)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped

class CLIPTextModelWrapper(CLIPTextModel):
    # Adapted from https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/clip/modeling_clip.py#L812
    # Modified to accept precomputed token embeddings "input_token_embs" as input or calculate them from input_ids and return them.
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        input_token_embs: Optional[torch.Tensor] = None,
        hidden_state_layer_weights: Optional[torch.Tensor] = None,
        return_token_embs: Optional[bool] = False,
    ) -> Union[Tuple, torch.Tensor, BaseModelOutputWithPooling]:

        if return_token_embs:
            return self.text_model.embeddings.token_embedding(input_ids)

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        output_attentions = output_attentions if output_attentions is not None else self.text_model.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.text_model.config.output_hidden_states
        )
        if hidden_state_layer_weights is not None:
            output_hidden_states = True
        return_dict = return_dict if return_dict is not None else self.text_model.config.use_return_dict

        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        hidden_states = self.text_model.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=input_token_embs)

        # CLIP's text model uses causal mask, prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _expand_mask(attention_mask, hidden_states.dtype)

        encoder_outputs = self.text_model.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            # output_hidden_states is False by default, and only True if hidden_state_layer_weights is provided.
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # If output_hidden_states is True, then encoder_outputs[0] is last_hidden_state [1, 22, 768].
        # encoder_outputs[1] is hidden_states, which is a tuple of 13 hidden states, each being [1, 22, 768].
        # encoder_outputs[0] == encoder_outputs[1][12].
        if hidden_state_layer_weights is None:
            last_hidden_state = encoder_outputs[0]
        else:
            num_hidden_state_layers = len(hidden_state_layer_weights)
            last_hidden_states = encoder_outputs[1][-num_hidden_state_layers:]
            hidden_state_layer_weights = hidden_state_layer_weights.to(last_hidden_states[0].dtype)
            # Normalize the weights to sum to 1 across layers.
            # hidden_state_layer_weights: [3, 1] or [3, 768].
            hidden_state_layer_weights = hidden_state_layer_weights / hidden_state_layer_weights.sum(dim=0, keepdim=True)
            # [3, 1/768] -> [3, 1, 1, 1/768]
            hidden_state_layer_weights = hidden_state_layer_weights.unsqueeze(1).unsqueeze(1)
            # A weighted sum of last_hidden_states.
            # [3, 1, 22, 768] * [3, 1, 1, 1/768] -> [3, 1, 22, 768] -> [1, 22, 768]
            last_hidden_state = (torch.stack(last_hidden_states, dim=0) * hidden_state_layer_weights).sum(dim=0)

        last_hidden_state = self.text_model.final_layer_norm(last_hidden_state)

        # self.text_model.eos_token_id == 2 is True.
        if self.text_model.eos_token_id == 2:
            # The `eos_token_id` was incorrect before PR #24773: Let's keep what has been done here.
            # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
            # ------------------------------------------------------------
            # text_embeds.shape = [batch_size, sequence_length, transformer.width]
            # take features from the eot embedding (eot_token is the highest number in each sequence)
            # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
            pooled_output = last_hidden_state[
                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
                input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
            ]
        else:
            # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
            pooled_output = last_hidden_state[
                torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
                # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
                (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.text_model.eos_token_id)
                .int()
                .argmax(dim=-1),
            ]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    # Applied to layers [begin_layer_idx, end_layer_idx) in the encoder.
    # The layer indexed by end_layer_idx is not included.
    # If both layer indices are -1, then apply to all layers (0-11).
    def extend_clip_attention_MKV_multiplier(self, begin_layer_idx=-1, end_layer_idx=-1, multiplier=2, noise_std=0.1):
        num_extended_layers = 0

        for layer_idx, layer in enumerate(self.text_model.encoder.layers):
            if begin_layer_idx >= 0 and layer_idx < begin_layer_idx:
                continue
            if end_layer_idx >= 0 and layer_idx >= end_layer_idx:
                break
            # This shouldn't happen, unless self_attn has already been extended as CLIPAttentionMKV.
            if not isinstance(layer.self_attn, (CLIPAttention, CLIPAttentionMKV)):
                breakpoint()
            old_attn_layer = layer.self_attn
            if not isinstance(old_attn_layer, CLIPAttentionMKV):
                layer.self_attn = CLIPAttentionMKV(old_attn_layer.config, 1)
            layer.self_attn.extend_weights(old_attn_layer, layer_idx, multiplier, noise_std, verbose=True)
            num_extended_layers += 1

        return num_extended_layers
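A minimal usage sketch for the wrapper above, assuming the same 'openai/clip-vit-large-patch14' checkpoint that subj_basis_generator.py loads; the multiplier, noise_std and token ids below are illustrative, not values taken from this repository:

import torch
from adaface.arc2face_models import CLIPTextModelWrapper

text_encoder = CLIPTextModelWrapper.from_pretrained("openai/clip-vit-large-patch14")
# Double the k/v projections of every encoder layer, adding mild noise to the extra copies.
num_extended = text_encoder.extend_clip_attention_MKV_multiplier(
    begin_layer_idx=-1, end_layer_idx=-1, multiplier=2, noise_std=0.1
)
print(f"Extended {num_extended} attention layers.")

# The wrapper can also return raw token embeddings instead of encoder hidden states.
token_ids = torch.tensor([[49406, 320, 1125, 49407]])  # BOS, "a", "photo", EOS (illustrative ids)
token_embs = text_encoder(input_ids=token_ids, return_token_embs=True)
print(token_embs.shape)  # torch.Size([1, 4, 768])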
models/adaface/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4aa1eb9ff3e364ea1b9db6dfff0c281ff3b57864d7ccc4c64d5f29ed752484f3
size 821700521

models/arc2face/arc2face/config.json
ADDED
@@ -0,0 +1,67 @@
{
  "_class_name": "UNet2DConditionModel",
  "_diffusers_version": "0.22.0",
  "act_fn": "silu",
  "addition_embed_type": null,
  "addition_embed_type_num_heads": 64,
  "addition_time_embed_dim": null,
  "attention_head_dim": 8,
  "attention_type": "default",
  "block_out_channels": [
    320,
    640,
    1280,
    1280
  ],
  "center_input_sample": false,
  "class_embed_type": null,
  "class_embeddings_concat": false,
  "conv_in_kernel": 3,
  "conv_out_kernel": 3,
  "cross_attention_dim": 768,
  "cross_attention_norm": null,
  "down_block_types": [
    "CrossAttnDownBlock2D",
    "CrossAttnDownBlock2D",
    "CrossAttnDownBlock2D",
    "DownBlock2D"
  ],
  "downsample_padding": 1,
  "dropout": 0.0,
  "dual_cross_attention": false,
  "encoder_hid_dim": null,
  "encoder_hid_dim_type": null,
  "flip_sin_to_cos": true,
  "freq_shift": 0,
  "in_channels": 4,
  "layers_per_block": 2,
  "mid_block_only_cross_attention": null,
  "mid_block_scale_factor": 1,
  "mid_block_type": "UNetMidBlock2DCrossAttn",
  "norm_eps": 1e-05,
  "norm_num_groups": 32,
  "num_attention_heads": null,
  "num_class_embeds": null,
  "only_cross_attention": false,
  "out_channels": 4,
  "projection_class_embeddings_input_dim": null,
  "resnet_out_scale_factor": 1.0,
  "resnet_skip_time_act": false,
  "resnet_time_scale_shift": "default",
  "reverse_transformer_layers_per_block": null,
  "sample_size": 64,
  "time_cond_proj_dim": null,
  "time_embedding_act_fn": null,
  "time_embedding_dim": null,
  "time_embedding_type": "positional",
  "timestep_post_act": null,
  "transformer_layers_per_block": 1,
  "up_block_types": [
    "UpBlock2D",
    "CrossAttnUpBlock2D",
    "CrossAttnUpBlock2D",
    "CrossAttnUpBlock2D"
  ],
  "upcast_attention": false,
  "use_linear_projection": false
}
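This config describes an SD 1.5-sized UNet (cross_attention_dim 768, sample_size 64). A minimal sketch, assuming diffusers from requirements.txt, of loading the bundled Arc2Face UNet from its local folder; the float16 choice is illustrative:

import torch
from diffusers import UNet2DConditionModel

# Reads config.json and diffusion_pytorch_model.safetensors from the local folder below.
unet = UNet2DConditionModel.from_pretrained("models/arc2face/arc2face", torch_dtype=torch.float16)
print(unet.config.cross_attention_dim)  # 768, matching the CLIP text embedding width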
models/arc2face/arc2face/diffusion_pytorch_model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d2377c16b7135650ca375817a4812a999194fba1f081e39117bd54e50dacc784
size 3438167536

models/arc2face/encoder/config.json
ADDED
@@ -0,0 +1,24 @@
{
  "architectures": [
    "CLIPTextModel"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "dropout": 0.0,
  "eos_token_id": 2,
  "hidden_act": "quick_gelu",
  "hidden_size": 768,
  "initializer_factor": 1.0,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 77,
  "model_type": "clip_text_model",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "projection_dim": 768,
  "torch_dtype": "float32",
  "transformers_version": "4.34.1",
  "vocab_size": 49408
}

models/arc2face/encoder/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e2d364df774b7d3975f85de42bda73c0c0cdb952273dd5f138511b6cf65424aa
size 492308829

models/insightface/models/antelopev2/1k3d68.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:df5c06b8a0c12e422b2ed8947b8869faa4105387f199c477af038aa01f9a45cc
size 143607619

models/insightface/models/antelopev2/2d106det.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf
size 5030888

models/insightface/models/antelopev2/arcface.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ec639a0429b4819130d1405a2d3b38beaa4cc4a6c5bd9cf48b94fdf65461de83
size 260694151

models/insightface/models/antelopev2/genderage.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb
size 1322532

models/insightface/models/antelopev2/scrfd_10g_bnkps.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91
size 16923827

models/insightface/models/buffalo_l/1k3d68.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:df5c06b8a0c12e422b2ed8947b8869faa4105387f199c477af038aa01f9a45cc
size 143607619

models/insightface/models/buffalo_l/2d106det.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf
size 5030888

models/insightface/models/buffalo_l/det_10g.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91
size 16923827

models/insightface/models/buffalo_l/genderage.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb
size 1322532

models/insightface/models/buffalo_l/w600k_r50.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4c06341c33c2ca1f86781dab0e829f88ad5b64be9fba56e56bc9ebdefc619e43
size 174383860

models/sar/sar.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:35a5d7615850879ffecce7b1e463ae0317c95fe784dd9b179793b58531a9e3ab
size 2299982596

requirements.txt
ADDED
@@ -0,0 +1,12 @@
torch
torchvision
einops
gradio
transformers
insightface
opencv-python
diffusers
onnx>=1.16.0
onnxruntime
safetensors
spaces
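The insightface and onnxruntime entries above are what the bundled ONNX detectors and recognizers are run with. A minimal sketch, assuming the models/insightface layout shown above and a CPU-only onnxruntime provider; the detection size and dummy image are illustrative:

import numpy as np
from insightface.app import FaceAnalysis

# root is the parent of the "models/<pack_name>" directories shipped in this repo.
face_app = FaceAnalysis(name="buffalo_l", root="models/insightface",
                        providers=["CPUExecutionProvider"])
face_app.prepare(ctx_id=0, det_size=(512, 512))

image = np.zeros((640, 640, 3), dtype=np.uint8)  # placeholder BGR image
faces = face_app.get(image)
if faces:
    print(faces[0].normed_embedding.shape)  # (512,), the ArcFace identity embedding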
subj_basis_generator.py
ADDED
@@ -0,0 +1,758 @@
1 |
+
# Borrowed from ip-adapter resampler.py.
|
2 |
+
# https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py
|
3 |
+
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
|
4 |
+
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
|
5 |
+
|
6 |
+
import math
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from torch import nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from einops import rearrange
|
12 |
+
from einops.layers.torch import Rearrange
|
13 |
+
from transformers import CLIPVisionModel, CLIPTokenizer
|
14 |
+
|
15 |
+
import numpy as np
|
16 |
+
from torch import einsum
|
17 |
+
from dataclasses import dataclass
|
18 |
+
from typing import Optional, Tuple
|
19 |
+
from transformers.utils import ModelOutput
|
20 |
+
from adaface.util import arc2face_inverse_face_prompt_embs, gen_gradient_scaler
|
21 |
+
from adaface.arc2face_models import CLIPTextModelWrapper
|
22 |
+
import sys
|
23 |
+
sys.modules['ldm'] = sys.modules['adaface']
|
24 |
+
|
25 |
+
def reshape_tensor(x, num_heads):
|
26 |
+
bs, length, width = x.shape
|
27 |
+
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
28 |
+
x = x.view(bs, length, num_heads, -1)
|
29 |
+
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
30 |
+
x = x.transpose(1, 2)
|
31 |
+
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
|
32 |
+
x = x.reshape(bs, num_heads, length, -1)
|
33 |
+
return x
|
34 |
+
|
35 |
+
# FFN. Added a Dropout layer at the end, so that it can still load the old ckpt.
|
36 |
+
def FeedForward(dim, mult=4, p_dropout=0.1):
|
37 |
+
inner_dim = int(dim * mult)
|
38 |
+
return nn.Sequential(
|
39 |
+
nn.LayerNorm(dim),
|
40 |
+
nn.Linear(dim, inner_dim, bias=False),
|
41 |
+
nn.GELU(),
|
42 |
+
nn.Linear(inner_dim, dim, bias=False),
|
43 |
+
nn.Dropout(p_dropout),
|
44 |
+
)
|
45 |
+
|
46 |
+
# IP-Adapter FaceID class. Only used in knn-faces.py.
|
47 |
+
# From: https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter_faceid_separate.py
|
48 |
+
class IP_MLPProjModel(nn.Module):
|
49 |
+
def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
|
50 |
+
super().__init__()
|
51 |
+
|
52 |
+
self.cross_attention_dim = cross_attention_dim
|
53 |
+
self.num_tokens = num_tokens
|
54 |
+
|
55 |
+
self.proj = nn.Sequential(
|
56 |
+
nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
|
57 |
+
nn.GELU(),
|
58 |
+
nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
|
59 |
+
)
|
60 |
+
self.norm = nn.LayerNorm(cross_attention_dim)
|
61 |
+
|
62 |
+
def forward(self, id_embeds):
|
63 |
+
x = self.proj(id_embeds)
|
64 |
+
x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
|
65 |
+
x = self.norm(x)
|
66 |
+
return x
|
67 |
+
|
68 |
+
# group_dim: the tensor dimension that corresponds to the multiple groups.
|
69 |
+
class LearnedSoftAggregate(nn.Module):
|
70 |
+
def __init__(self, num_feat, group_dim, keepdim=False):
|
71 |
+
super(LearnedSoftAggregate, self).__init__()
|
72 |
+
self.group_dim = group_dim
|
73 |
+
# num_feat = 1: element-wise score function & softmax.
|
74 |
+
# num_feat > 1: the linear score function is applied to the last dim (features) of the input tensor.
|
75 |
+
self.num_feat = num_feat
|
76 |
+
self.feat2score = nn.Linear(num_feat, 1, bias=False)
|
77 |
+
self.keepdim = keepdim
|
78 |
+
|
79 |
+
def forward(self, x, score_basis=None):
|
80 |
+
# If there's only one mode, do nothing.
|
81 |
+
if x.shape[self.group_dim] == 1:
|
82 |
+
if self.keepdim:
|
83 |
+
return x
|
84 |
+
else:
|
85 |
+
return x.squeeze(self.group_dim)
|
86 |
+
|
87 |
+
# Assume the last dim of x is the feature dim.
|
88 |
+
if score_basis is None:
|
89 |
+
score_basis = x
|
90 |
+
|
91 |
+
if self.num_feat == 1:
|
92 |
+
mode_scores = self.feat2score(score_basis.unsqueeze(-1)).squeeze(-1)
|
93 |
+
else:
|
94 |
+
mode_scores = self.feat2score(score_basis)
|
95 |
+
attn_probs = mode_scores.softmax(dim=self.group_dim)
|
96 |
+
x_aggr = (x * attn_probs).sum(dim=self.group_dim, keepdim=self.keepdim)
|
97 |
+
return x_aggr
|
98 |
+
|
99 |
+
def LoRA_ExpandEmbs(input_dim, lora_rank, output_dim, num_modes,
|
100 |
+
num_output_vecs, elementwise_affine=True, p_dropout=0.1):
|
101 |
+
return nn.Sequential(
|
102 |
+
# Project to [BS, lora_rank * output_dim * num_modes].
|
103 |
+
# It takes a huge param size. 512 * 32 * 768 * 4 = 6,291,456.
|
104 |
+
nn.Linear(input_dim, lora_rank * output_dim * num_modes, bias=False),
|
105 |
+
# Reshape to [BS, lora_rank, output_dim].
|
106 |
+
Rearrange('b (m q d) -> b m q d', q=lora_rank, m=num_modes, d=output_dim),
|
107 |
+
nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
|
108 |
+
# Aggregate [BS, num_modes, loar_rank, output_dim] -> [BS, lora_rank, output_dim].
|
109 |
+
LearnedSoftAggregate(num_feat=output_dim, group_dim=1, keepdim=False) if num_modes > 1 \
|
110 |
+
else Rearrange('b () q d -> b q d'),
|
111 |
+
nn.Dropout(p_dropout),
|
112 |
+
# Permute to [BS, output_dim, lora_rank].
|
113 |
+
Rearrange('b q d -> b d q'),
|
114 |
+
# Project to [BS, output_dim, num_output_vecs].
|
115 |
+
nn.Linear(lora_rank, num_output_vecs, bias=False),
|
116 |
+
# Permute to [BS, num_output_vecs, output_dim].
|
117 |
+
Rearrange('b d q -> b q d'),
|
118 |
+
nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
|
119 |
+
nn.Dropout(p_dropout),
|
120 |
+
)
|
121 |
+
|
122 |
+
def ExpandEmbs(input_dim, output_dim, expansion_ratio, elementwise_affine=True, p_dropout=0.1):
|
123 |
+
return nn.Sequential(
|
124 |
+
# Project to [BS, num_output_vecs * output_dim].
|
125 |
+
nn.Linear(input_dim, expansion_ratio * output_dim, bias=False),
|
126 |
+
# Reshape to [BS, num_output_vecs, output_dim].
|
127 |
+
Rearrange('b (e d) -> b e d', e=expansion_ratio, d=output_dim),
|
128 |
+
nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
|
129 |
+
nn.Dropout(p_dropout),
|
130 |
+
)
|
131 |
+
|
132 |
+
# Input: [BS, N, D].
|
133 |
+
def MultimodeProjection(input_dim, output_dim=-1, num_modes=4, elementwise_affine=True, p_dropout=0.1):
|
134 |
+
if output_dim == -1:
|
135 |
+
output_dim = input_dim
|
136 |
+
|
137 |
+
return nn.Sequential(
|
138 |
+
nn.Linear(input_dim, output_dim * num_modes, bias=False),
|
139 |
+
# Reshape to [BS, num_output_vecs, output_dim].
|
140 |
+
Rearrange('b n (m d) -> b n m d', m=num_modes, d=output_dim),
|
141 |
+
nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
|
142 |
+
# If num_modes == 1, then simply remove the mode dim. Otherwise, aggregate the modes.
|
143 |
+
LearnedSoftAggregate(num_feat=output_dim, group_dim=2, keepdim=False) if num_modes > 1 \
|
144 |
+
else Rearrange('b n () d -> b n d'),
|
145 |
+
nn.Dropout(p_dropout),
|
146 |
+
)
|
147 |
+
|
148 |
+
# Low-rank to high-rank transformation.
|
149 |
+
def Lora2Hira(lora_rank, hira_rank, output_dim, num_modes, elementwise_affine=True, p_dropout=0.1):
|
150 |
+
return nn.Sequential(
|
151 |
+
# Permute to [BS, output_dim, lora_rank].
|
152 |
+
Rearrange('b q d -> b d q'),
|
153 |
+
# Project to [BS, output_dim, hira_rank].
|
154 |
+
nn.Linear(lora_rank, hira_rank * num_modes, bias=False),
|
155 |
+
# Reshape and permute to [BS, num_modes, num_output_vecs, output_dim].
|
156 |
+
Rearrange('b d (m q) -> b m q d', m=num_modes, q=hira_rank),
|
157 |
+
nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
|
158 |
+
# Aggregate [BS, num_modes, hira_rank, output_dim] -> [BS, hira_rank, output_dim].
|
159 |
+
LearnedSoftAggregate(num_feat=output_dim, group_dim=1, keepdim=False) if num_modes > 1 \
|
160 |
+
else Rearrange('b () q d -> b q d'),
|
161 |
+
nn.Dropout(p_dropout),
|
162 |
+
)
|
163 |
+
|
164 |
+
class PerceiverAttention(nn.Module):
|
165 |
+
def __init__(self, *, dim, dim_head=64, num_heads=8, elementwise_affine=True):
|
166 |
+
super().__init__()
|
167 |
+
self.scale = dim_head**-0.5
|
168 |
+
self.dim_head = dim_head
|
169 |
+
self.num_heads = num_heads
|
170 |
+
inner_dim = dim_head * num_heads
|
171 |
+
|
172 |
+
self.norm1 = nn.LayerNorm(dim, elementwise_affine=elementwise_affine)
|
173 |
+
self.norm2 = nn.LayerNorm(dim, elementwise_affine=elementwise_affine)
|
174 |
+
|
175 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
176 |
+
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
177 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
178 |
+
|
179 |
+
def forward(self, x, latent_queries):
|
180 |
+
"""
|
181 |
+
Args:
|
182 |
+
x (torch.Tensor): image features
|
183 |
+
shape (b, n1, D)
|
184 |
+
latent (torch.Tensor): latent features
|
185 |
+
shape (b, n2, D)
|
186 |
+
"""
|
187 |
+
x = self.norm1(x)
|
188 |
+
latent_queries = self.norm2(latent_queries)
|
189 |
+
|
190 |
+
b, l, _ = latent_queries.shape
|
191 |
+
|
192 |
+
q = self.to_q(latent_queries)
|
193 |
+
kv_input = torch.cat((x, latent_queries), dim=-2)
|
194 |
+
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
195 |
+
|
196 |
+
q = reshape_tensor(q, self.num_heads)
|
197 |
+
k = reshape_tensor(k, self.num_heads)
|
198 |
+
v = reshape_tensor(v, self.num_heads)
|
199 |
+
|
200 |
+
# attention
|
201 |
+
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
202 |
+
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
203 |
+
attn = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
204 |
+
out = attn @ v
|
205 |
+
|
206 |
+
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
207 |
+
|
208 |
+
return self.to_out(out)
|
209 |
+
|
210 |
+
|
211 |
+
class CrossAttention(nn.Module):
|
212 |
+
# output_dim is always the same as input_dim.
|
213 |
+
# num_q only matters when q_aware_to_v is True.
|
214 |
+
# If q_aware_to_v is False, query x in forward() is still usable.
|
215 |
+
def __init__(self, input_dim, num_heads=6, p_dropout=0.05,
|
216 |
+
identity_to_q=False, identity_to_k=False, identity_to_v=False, v_has_skip=True,
|
217 |
+
q_aware_to_v=True, num_q=416, v_repeat=4, q_aware_to_v_lora_rank=64,
|
218 |
+
identity_to_out=False, out_has_skip=False):
|
219 |
+
super().__init__()
|
220 |
+
dim_head = input_dim // num_heads
|
221 |
+
inner_dim = dim_head * num_heads
|
222 |
+
|
223 |
+
self.num_heads = num_heads
|
224 |
+
self.q_aware_to_v = q_aware_to_v
|
225 |
+
self.v_has_skip = v_has_skip
|
226 |
+
self.to_q = nn.Sequential(
|
227 |
+
nn.Linear(input_dim, inner_dim, bias=False),
|
228 |
+
nn.LayerNorm(inner_dim, elementwise_affine=True)
|
229 |
+
) if not identity_to_q else nn.Identity()
|
230 |
+
self.to_k = nn.Sequential(
|
231 |
+
nn.Linear(input_dim, inner_dim, bias=False),
|
232 |
+
nn.LayerNorm(inner_dim, elementwise_affine=True)
|
233 |
+
) if not identity_to_k else nn.Identity()
|
234 |
+
|
235 |
+
self.v_repeat = v_repeat
|
236 |
+
self.num_q_group = num_q_group = num_q // v_repeat # 416 / 4 = 104.
|
237 |
+
|
238 |
+
# If q_aware_to_v is True, then self.to_v consists of num_q projections of input_dim to inner_dim.
|
239 |
+
# Otherwise, self.to_v consists of a single projection of input_dim to inner_dim.
|
240 |
+
if q_aware_to_v:
|
241 |
+
# all_q_mid: 104 * 64 = 6656.
|
242 |
+
all_q_mid = num_q_group * q_aware_to_v_lora_rank
|
243 |
+
self.to_v = nn.Sequential(
|
244 |
+
# number of params: 768 * 6656 = 5,111,808.
|
245 |
+
# Input: [BS, 16, 768]. Output: [BS, 16, 104*64] = [BS, 16, 6656].
|
246 |
+
# Each 768-dim vec is dispersed into 104 64-dim vecs.
|
247 |
+
nn.Linear(input_dim, all_q_mid, bias=False),
|
248 |
+
nn.LayerNorm(all_q_mid, elementwise_affine=True),
|
249 |
+
# Change the dim of the tensor to [BS, 6656, 16], as Conv1d transforms dim 1.
|
250 |
+
Rearrange('b n q -> b q n', q=all_q_mid),
|
251 |
+
# Each q_aware_to_v projection has its own linear layer.
|
252 |
+
# The total number of parameters will be 6656*768 = 5,111,808.
|
253 |
+
# Output: [BS, 104*768, 16]. Each 64 dim feature is expanded to 768 dim.
|
254 |
+
nn.Conv1d(
|
255 |
+
in_channels=all_q_mid,
|
256 |
+
out_channels=num_q_group * input_dim,
|
257 |
+
kernel_size=1,
|
258 |
+
groups=num_q_group,
|
259 |
+
bias=False,
|
260 |
+
),
|
261 |
+
# Output: [BS, 104, 16, 768].
|
262 |
+
Rearrange('b (q d) n -> b q n d', q=num_q_group, d=input_dim),
|
263 |
+
nn.LayerNorm(input_dim, elementwise_affine=True),
|
264 |
+
)
|
265 |
+
else:
|
266 |
+
self.to_v = nn.Sequential(
|
267 |
+
nn.Linear(input_dim, inner_dim, bias=False),
|
268 |
+
nn.LayerNorm(inner_dim, elementwise_affine=True)
|
269 |
+
) if not identity_to_v else nn.Identity()
|
270 |
+
|
271 |
+
if identity_to_out:
|
272 |
+
assert not out_has_skip, "identity_to_out=True, then out_has_skip has to be False."
|
273 |
+
|
274 |
+
if identity_to_out:
|
275 |
+
self.to_out = nn.Identity()
|
276 |
+
else:
|
277 |
+
self.to_out = nn.Sequential(
|
278 |
+
nn.Linear(input_dim, input_dim, bias=False),
|
279 |
+
nn.Dropout(p_dropout),
|
280 |
+
nn.LayerNorm(inner_dim, elementwise_affine=True)
|
281 |
+
)
|
282 |
+
|
283 |
+
self.out_has_skip = out_has_skip
|
284 |
+
self.attn_drop = nn.Dropout(p_dropout)
|
285 |
+
|
286 |
+
def forward(self, x, context=None, attn_mat=None, return_attn=False):
|
287 |
+
h = self.num_heads
|
288 |
+
|
289 |
+
if context is None:
|
290 |
+
context = x
|
291 |
+
|
292 |
+
if attn_mat is None:
|
293 |
+
# q: [BS, Q, D] -> [BS, Q, D].
|
294 |
+
q = self.to_q(x)
|
295 |
+
# k: [BS, L, D] -> [BS, L, D].
|
296 |
+
k = self.to_k(context)
|
297 |
+
# q: [6, 512, 128], k: [6, 17, 128].
|
298 |
+
q, k = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k))
|
299 |
+
|
300 |
+
if self.q_aware_to_v:
|
301 |
+
# context: [BS, L, D]. v: [BS, Q, L, D].
|
302 |
+
# There are effectively Q to_v projections.
|
303 |
+
v = self.to_v(context)
|
304 |
+
if self.v_has_skip:
|
305 |
+
v = v + context.unsqueeze(1)
|
306 |
+
else:
|
307 |
+
# v: [BS, L, D].
|
308 |
+
v = self.to_v(context)
|
309 |
+
if self.v_has_skip:
|
310 |
+
v = v + context
|
311 |
+
|
312 |
+
#print(v.shape)
|
313 |
+
|
314 |
+
if self.q_aware_to_v:
|
315 |
+
# v: [6, 64, 17, 128].
|
316 |
+
# v is query-specific, so there's an extra dim for the query.
|
317 |
+
v = rearrange(v, 'b q n (h d) -> (b h) q n d', h=h)
|
318 |
+
# Each v is for a query group with 512/64 = 8 queries.
|
319 |
+
# So each v is repeated 8 times to match the number of queries.
|
320 |
+
# v: [6, 64, 17, 128] -> [6, 512, 17, 128].
|
321 |
+
v = v.repeat(1, self.v_repeat, 1, 1)
|
322 |
+
else:
|
323 |
+
v = rearrange(v, 'b n (h d) -> (b h) n d', h=h)
|
324 |
+
|
325 |
+
if attn_mat is None:
|
326 |
+
scale = q.size(-1) ** -0.25
|
327 |
+
sim = einsum('b i d, b j d -> b i j', q * scale, k * scale)
|
328 |
+
# sim: [6, 64, 17]. 6: bs 1 * h 6.
|
329 |
+
# attention, what we cannot get enough of
|
330 |
+
# NOTE: the normalization is done across tokens, not across pixels.
|
331 |
+
# So for each pixel, the sum of attention scores across tokens is 1.
|
332 |
+
attn = sim.softmax(dim=-1)
|
333 |
+
attn = self.attn_drop(attn)
|
334 |
+
#print(attn.std())
|
335 |
+
else:
|
336 |
+
attn = attn_mat
|
337 |
+
|
338 |
+
if self.q_aware_to_v:
|
339 |
+
# attn: [6, 32, 17]. v: [6, 32, 17, 128]. 128: dim of each head. out: [6, 32, 128].
|
340 |
+
# out is combined with different attn weights and v for different queries.
|
341 |
+
out = einsum('b i j, b i j d -> b i d', attn, v)
|
342 |
+
else:
|
343 |
+
# v: [6, 17, 128]. out: [6, 32, 128].
|
344 |
+
out = einsum('b i j, b j d -> b i d', attn, v)
|
345 |
+
|
346 |
+
# [6, 32, 128] -> [1, 32, 768].
|
347 |
+
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
348 |
+
|
349 |
+
if self.out_has_skip:
|
350 |
+
out = self.to_out(out) + out
|
351 |
+
else:
|
352 |
+
out = self.to_out(out)
|
353 |
+
|
354 |
+
if return_attn:
|
355 |
+
return out, attn
|
356 |
+
else:
|
357 |
+
return out
|
358 |
+
|
359 |
+
class SubjBasisGenerator(nn.Module):
|
360 |
+
def __init__(
|
361 |
+
self,
|
362 |
+
# number of cross-attention heads. Half of the number of heads 12 of OpenAI clip-vit-large-patch14:
|
363 |
+
# https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
|
364 |
+
num_heads=6,
|
365 |
+
num_id_vecs={ 'subj': 77, 'bg': 257 }, # number of identity vectors. 18: 16 face tokens + 2 extra tokens. 257: 257 CLIP tokens.
|
366 |
+
num_out_embs_per_layer=4, # num_out_embs. subj: 16. bg: 4.
|
367 |
+
num_out_layers=16, # number of layers of output embeddings.
|
368 |
+
image_embedding_dim=768, # CLIP image feature dimension, as per config.json above.
|
369 |
+
# DINO vits16 has 6 attention heads:
|
370 |
+
# https://huggingface.co/facebook/dino-vits16/blob/main/config.json
|
371 |
+
dino_embedding_dim=384, # DINO object feature dimension for objects.
|
372 |
+
output_dim=768, # CLIP text embedding input dimension.
|
373 |
+
placeholder_is_bg: bool = False, # Whether the placeholder is for the image background.
|
374 |
+
prompt2token_proj_grad_scale: float = 0.4, # Gradient scale for prompt2token_proj.
|
375 |
+
zs_extra_words_scale: float = 0.5, # Scale for extra words in the prompt2token_proj.
|
376 |
+
learnable_hidden_state_weights_scheme: str = 'per-layer', # none, per-layer.
|
377 |
+
bg_prompt_translator_has_to_out_proj: bool = False, # Whether the prompt_trans_layers have a to_out projection.
|
378 |
+
):
|
379 |
+
super().__init__()
|
380 |
+
|
381 |
+
self.placeholder_is_bg = placeholder_is_bg
|
382 |
+
self.num_out_layers = num_out_layers
|
383 |
+
self.num_out_embs_per_layer = num_out_embs_per_layer
|
384 |
+
# subj: 64, bg: 32.
|
385 |
+
self.num_out_embs = num_out_layers * num_out_embs_per_layer
|
386 |
+
self.output_dim = output_dim
|
387 |
+
# num_id_vecs should be the number of core ID embs, 16.
|
388 |
+
# However, in such case, pos_embs is not used. So it doesn't matter if it's wrongly set.
|
389 |
+
self.num_id_vecs = num_id_vecs['bg'] if placeholder_is_bg else num_id_vecs['subj']
|
390 |
+
self.pos_embs = nn.Parameter(torch.randn(1, self.num_id_vecs, output_dim))
|
391 |
+
self.pos_embs_ln = nn.LayerNorm(output_dim)
|
392 |
+
self.zs_extra_words_scale = zs_extra_words_scale
|
393 |
+
self.output_scale = output_dim ** -0.5
|
394 |
+
self.clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
395 |
+
|
396 |
+
if not self.placeholder_is_bg:
|
397 |
+
# [1, 384] -> [1, 16, 768].
|
398 |
+
# TODO: use CLIPTextModelWrapper as obj_proj_in.
|
399 |
+
self.obj_proj_in = ExpandEmbs(dino_embedding_dim, output_dim, expansion_ratio=self.num_id_vecs)
|
400 |
+
|
401 |
+
# self.prompt2token_proj: [1, 16, 768] -> [1, 77, 768] (with paddings).
|
402 |
+
# If self.placeholder_is_bg: prompt2token_proj is set to None.
|
403 |
+
self.prompt2token_proj = CLIPTextModelWrapper.from_pretrained('openai/clip-vit-large-patch14')
|
404 |
+
self.prompt2token_proj_grad_scale = prompt2token_proj_grad_scale
|
405 |
+
self.prompt2token_proj_grad_scaler = gen_gradient_scaler(prompt2token_proj_grad_scale)
|
406 |
+
print(f"Subj prompt2token_proj initialized with grad scale of {prompt2token_proj_grad_scale}.")
|
407 |
+
# Freeze prompt2token_proj if prompt2token_proj_grad_scale is 0.
|
408 |
+
# Set requires_grad to False for all parameters in prompt2token_proj, to save memory taken by the optimizer.
|
409 |
+
if prompt2token_proj_grad_scale == 0:
|
410 |
+
self.freeze_prompt2token_proj()
|
411 |
+
|
412 |
+
self.prompt2token_proj_attention_multiplier = -1
|
413 |
+
self.initialize_hidden_state_layer_weights(learnable_hidden_state_weights_scheme, 'cpu')
|
414 |
+
self.pad_embeddings = None
|
415 |
+
self.bg_proj_in = None
|
416 |
+
else:
|
417 |
+
# For background placeholders, face and object embeddings are not used as they are foreground.
|
418 |
+
self.obj_proj_in = None
|
419 |
+
self.prompt2token_proj = None
|
420 |
+
print("Bg prompt2token_proj is set to None.")
|
421 |
+
|
422 |
+
self.bg_proj_in = nn.Sequential(
|
423 |
+
nn.Linear(image_embedding_dim, output_dim, bias=False),
|
424 |
+
nn.LayerNorm(output_dim),
|
425 |
+
)
|
426 |
+
|
427 |
+
self.latent_queries = nn.Parameter(torch.randn(1, self.num_out_embs, output_dim))
|
428 |
+
self.latent_queries_ln = nn.LayerNorm(output_dim)
|
429 |
+
|
430 |
+
self.bg_prompt_translator_has_to_out_proj = bg_prompt_translator_has_to_out_proj
|
431 |
+
identity_to_v = False
|
432 |
+
v_has_skip = not identity_to_v # True
|
433 |
+
identity_to_out = not bg_prompt_translator_has_to_out_proj # True
|
434 |
+
out_has_skip = not identity_to_out # False
|
435 |
+
# prompt_translator has a to_v projection with skip connection, and doesn't have a to_out projection.
|
436 |
+
# dim=768, num_heads=6.
|
437 |
+
self.prompt_translator = \
|
438 |
+
CrossAttention(input_dim=output_dim, num_heads=num_heads, p_dropout=0.05,
|
439 |
+
identity_to_q=False, identity_to_k=False, identity_to_v=identity_to_v,
|
440 |
+
q_aware_to_v=False, v_has_skip=v_has_skip,
|
441 |
+
num_q=0, # When not q_aware_to_v, num_q is not referenced.
|
442 |
+
identity_to_out=identity_to_out,
|
443 |
+
out_has_skip=out_has_skip)
|
444 |
+
'''
|
445 |
+
prompt_translator: CLIPEncoder
|
446 |
+
# https://github.com/huggingface/transformers/blob/1872bde7fc6a5d6796bd742bc2dc38eaf8069c5d/src/transformers/models/clip/modeling_clip.py#L566
|
447 |
+
# CLIPEncoder.layers: 12 layers of CLIPEncoderLayer, each being
|
448 |
+
(0): CLIPEncoderLayer(
|
449 |
+
(self_attn): CLIPAttention(
|
450 |
+
(k_proj): Linear(in_features=768, out_features=768, bias=True)
|
451 |
+
(v_proj): Linear(in_features=768, out_features=768, bias=True)
|
452 |
+
(q_proj): Linear(in_features=768, out_features=768, bias=True)
|
453 |
+
(out_proj): Linear(in_features=768, out_features=768, bias=True)
|
454 |
+
)
|
455 |
+
(layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
|
456 |
+
(mlp): CLIPMLP(
|
457 |
+
(activation_fn): QuickGELUActivation()
|
458 |
+
(fc1): Linear(in_features=768, out_features=3072, bias=True)
|
459 |
+
(fc2): Linear(in_features=3072, out_features=768, bias=True)
|
460 |
+
)
|
461 |
+
(layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
|
462 |
+
)
|
463 |
+
'''
|
464 |
+
|
465 |
+
print(repr(self))
|
466 |
+
|
467 |
+
# raw_id_embs: ArcFace embeddings for faces (not used since we have arc2face_id_embs),
|
468 |
+
# or DINO embeddings for objects.
|
469 |
+
# arc2face_id_embs: [BS, 16, 768], the core identity embeddings generated by Arc2Face.
|
470 |
+
def forward(self, arc2face_id_embs, clip_features=None, raw_id_embs=None, out_id_embs_scale=1.0,
|
471 |
+
is_face=True, is_training=False, adaface_prompt_embs_inf_type='full_half_pad'):
|
472 |
+
|
473 |
+
if not self.placeholder_is_bg:
|
474 |
+
BS = arc2face_id_embs.shape[0]
|
475 |
+
else:
|
476 |
+
# If bg, then arc2face_id_embs is set to None, but clip_features is not None.
|
477 |
+
BS = clip_features.shape[0]
|
478 |
+
|
479 |
+
adaface_prompt_embs = None
|
480 |
+
if not hasattr(self, 'clip_tokenizer'):
|
481 |
+
self.clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
482 |
+
|
483 |
+
# No need to use raw_id_embs if placeholder_is_bg.
|
484 |
+
if not self.placeholder_is_bg:
|
485 |
+
if is_face:
|
486 |
+
assert arc2face_id_embs is not None
|
487 |
+
# arc2face_embs has been projected to the (modified) prompt embedding space
|
488 |
+
# by arc2face_forward_face_embs. This prompt embedding space is modified because Arc2Face finetuned
|
489 |
+
# the text encoder and the U-Net.
|
490 |
+
# in embedding_manager: [BS, 16, 768] -> [BS, 77, 768].
|
491 |
+
# arc2face_id_embs is part of arc2face_embs: [BS, 77, 768] -> [BS, 16, 768].
|
492 |
+
# adaface_prompt_embs is projected to the prompt embedding spaces. This is the
|
493 |
+
# original U-Net prompt embedding space.
|
494 |
+
|
495 |
+
# hidden_state_layer_weights: [[0.9163], [0.9483], [2.0762]]
|
496 |
+
hidden_state_layer_weights = self.hidden_state_layer_weights_grad_scaler(self.hidden_state_layer_weights)
|
497 |
+
# return_emb_types: a list of strings, each string is among
|
498 |
+
# ['full', 'core', 'full_pad', 'full_half_pad', 'full_zeroed_extra', 'b_core_e'].
|
499 |
+
# Using b_core_e is more computationally efficient than using full_zeroed_extra.
|
500 |
+
# But there is an unknow BUG that causes crash when using b_core_e.
|
501 |
+
if is_training:
|
502 |
+
return_emb_types = ['full_pad', 'core']
|
503 |
+
else:
|
504 |
+
# adaface_prompt_embs_inf_type: default is full_half_pad, same as training.
|
505 |
+
return_emb_types = [adaface_prompt_embs_inf_type, 'core']
|
506 |
+
|
507 |
+
if self.pad_embeddings is None:
|
508 |
+
self.generate_pad_embeddings()
|
509 |
+
else:
|
510 |
+
self.pad_embeddings = self.pad_embeddings.to(arc2face_id_embs.device)
|
511 |
+
|
512 |
+
with torch.set_grad_enabled(self.training and self.prompt2token_proj_grad_scale != 0):
|
513 |
+
# If list_extra_words is not None, then core_id_embs: [BS, 18, 768], three leading words, the 16 identity tokens
|
514 |
+
# and (at most) two extra words in full_prompt_embs, without BOS and EOS.
|
515 |
+
# If list_extra_words is None, then core_id_embs: [BS, 16, 768], the 16 identity tokens in full_prompt_embs.
|
516 |
+
# hidden_state_layer_weights: [[0.9163], [0.9483], [2.0762]]
|
517 |
+
# zs_extra_words_scale is only effective when list_extra_words is not None.
|
518 |
+
# adaface_prompt_embs: [BS, 77, 768], core_id_embs: [BS, 16, 768].
|
519 |
+
adaface_prompt_embs, core_id_embs = \
|
520 |
+
arc2face_inverse_face_prompt_embs(self.clip_tokenizer,
|
521 |
+
self.prompt2token_proj,
|
522 |
+
arc2face_id_embs,
|
523 |
+
list_extra_words=None,
|
524 |
+
return_emb_types=return_emb_types,
|
525 |
+
pad_embeddings=self.pad_embeddings,
|
526 |
+
hidden_state_layer_weights=hidden_state_layer_weights,
|
527 |
+
input_max_length=77, zs_extra_words_scale=self.zs_extra_words_scale)
|
528 |
+
# Reduce the update rate to prompt2token_proj.
|
529 |
+
adaface_prompt_embs = self.prompt2token_proj_grad_scaler(adaface_prompt_embs)
|
530 |
+
core_id_embs = self.prompt2token_proj_grad_scaler(core_id_embs)
|
531 |
+
elif raw_id_embs is not None:
|
532 |
+
# id_embs: [BS, 384] -> [BS, 18, 768].
|
533 |
+
# obj_proj_in is expected to project the DINO object features to
|
534 |
+
# the token embedding space. So no need to use prompt2token_proj.
|
535 |
+
id_embs = self.obj_proj_in(raw_id_embs)
|
536 |
+
else:
|
537 |
+
breakpoint()
|
538 |
+
else:
|
539 |
+
# Otherwise, context is the ad-hoc CLIP image features.
|
540 |
+
# id_embs: [BS, 257, 768].
|
541 |
+
id_embs = self.bg_proj_in(clip_features)
|
542 |
+
|
543 |
+
if self.placeholder_is_bg:
|
544 |
+
id_embs = id_embs + self.pos_embs_ln(self.pos_embs)
|
545 |
+
latent_queries = self.latent_queries_ln(self.latent_queries).repeat(BS, 1, 1)
|
546 |
+
# If bg, we don't have to use a specific attn layer for each 4-vec set. Instead, one attn layer can generate 257 embs,
|
547 |
+
# and we take the first 16*4=64.
|
548 |
+
# Output of prompt_translator is exactly num_out_embs == 64 tokens. id_embs_out: [BS, 64, 768].
|
549 |
+
# prompt_translator: better named as bg_prompt_translator. It maps the bg features
|
550 |
+
# to bg prompt embeddings.
|
551 |
+
with torch.set_grad_enabled(self.training):
|
552 |
+
id_embs_out = self.prompt_translator(latent_queries, id_embs)
|
553 |
+
# [BS, 64, 768] -> [BS, 16, 4, 768]
|
554 |
+
id_embs_out = id_embs_out.reshape(BS, self.num_out_layers, -1, self.output_dim)
|
555 |
+
adaface_subj_embs = id_embs_out * self.output_scale # * 0.036
|
556 |
+
else:
|
557 |
+
# adaface_subj_embs: [BS, 16, 768] -> [BS, 1, 16, 768] -> [BS, 16, 16, 768]
|
558 |
+
adaface_subj_embs = core_id_embs.unsqueeze(1).repeat(1, self.num_out_layers, 1, 1)
|
559 |
+
|
560 |
+
# If out_id_embs_scale < 1, adaface_subj_embs is a mix of adaface_subj_embs and pad_embeddings.
|
561 |
+
if out_id_embs_scale != 1:
|
562 |
+
# pad_embeddings: [77, 768] -> [16, 768] -> [1, 1, 16, 768].
|
563 |
+
pad_embeddings = self.pad_embeddings[4:4+self.num_out_embs_per_layer].unsqueeze(0).unsqueeze(0)
|
564 |
+
adaface_subj_embs = adaface_subj_embs * out_id_embs_scale \
|
565 |
+
+ pad_embeddings * (1 - out_id_embs_scale)
|
566 |
+
|
567 |
+
return adaface_subj_embs, adaface_prompt_embs
|
568 |
+
|
569 |
+
def initialize_hidden_state_layer_weights(self, learnable_hidden_state_weights_scheme, device):
|
570 |
+
if learnable_hidden_state_weights_scheme == 'none':
|
571 |
+
self.hidden_state_layer_weights = None
|
572 |
+
# A grad scaler with alpha =1 is nn.Identity(), which outputs None given None as input.
|
573 |
+
self.hidden_state_layer_weights_grad_scaler = gen_gradient_scaler(1)
|
574 |
+
print("hidden_state_layer_weights is set to None.")
|
575 |
+
|
576 |
+
elif learnable_hidden_state_weights_scheme == 'per-layer':
|
577 |
+
# Learnable weights of the last 3 layers, initialized to putting more focus on the last layer.
|
578 |
+
# 'per-layer': Different weights for different layers, but the same for different channels.
|
579 |
+
# hidden_state_layer_weights: [3, 1].
|
580 |
+
self.hidden_state_layer_weights = nn.Parameter(torch.tensor([[1.0], [2.0], [4.0]], device=device),
|
581 |
+
requires_grad=True)
|
582 |
+
self.hidden_state_layer_weights_grad_scaler = gen_gradient_scaler(5)
|
583 |
+
print("hidden_state_layer_weights initialized as per-layer [1, 2, 4], with grad scaler 5.")
|
584 |
+
else:
|
585 |
+
breakpoint()
|
586 |
+
|
587 |
+
def generate_pad_embeddings(self):
|
588 |
+
# clip_embeddings: CLIPTextEmbeddings instance. pad_embeddings is generated after
|
589 |
+
# prompt2token_proj is loaded from the finetuned weight. It seems such pad embeddings perform
|
590 |
+
# slightly better than the original pad embeddings.
|
591 |
+
clip_embeddings = self.prompt2token_proj.text_model.embeddings
|
592 |
+
# clip_embeddings() and clip_embeddings.token_embedding() differ in that
|
593 |
+
# clip_embeddings() adds positional embeddings, while clip_embeddings.token_embedding() doesn't.
|
594 |
+
# Adding positional embeddings seems to help somewhat.
|
595 |
+
# pad_tokens: pad_token_id 49407 repeated 77 times.
|
596 |
+
# pad_token_id is the EOS token. But BOS is 49406.
|
597 |
+
pad_tokens = torch.tensor([self.clip_tokenizer.pad_token_id]).to(clip_embeddings.token_embedding.weight.device).repeat(77)
|
598 |
+
# pad_embeddings: [77, 768].
|
599 |
+
pad_embeddings = clip_embeddings(pad_tokens)[0]
|
600 |
+
# We don't allow face recon to influence the pad embeddings.
|
601 |
+
# Otherwise, face identity will leak into the pad embeddings.
|
602 |
+
self.pad_embeddings = pad_embeddings.detach()
|
603 |
+
|
    def extend_prompt2token_proj_attention(self, begin_layer_idx=-1, end_layer_idx=-1, multiplier=2, noise_std=0.1):
        if multiplier > 1:
            num_extended_layers = self.prompt2token_proj.extend_clip_attention_MKV_multiplier(begin_layer_idx, end_layer_idx, multiplier, noise_std)
            self.prompt2token_proj_attention_multiplier = multiplier
            print(f"{num_extended_layers} layers in prompt2token_proj_attention are x{multiplier}")

    def freeze_prompt2token_proj(self):
        # If bg, then prompt2token_proj is set to None, so there's no need to freeze it.
        # Then we don't have to check whether it's for subj or bg.
        if self.prompt2token_proj is not None:
            frozen_param_names = []
            for param_name, param in self.prompt2token_proj.named_parameters():
                if param.requires_grad:
                    param.requires_grad = False
                    frozen_param_names.append(param_name)
                # If param is already frozen, then there's no need to freeze it again.
            print(f"{len(frozen_param_names)} params in Subj prompt2token_proj are frozen.")
            #print(f"Frozen parameters:\n{frozen_param_names}")

    def __repr__(self):
        type_sig = 'subj' if not self.placeholder_is_bg else 'bg'
        # Fix compatibility with the previous version.
        if not hasattr(self, 'bg_prompt_translator_has_to_out_proj'):
            self.bg_prompt_translator_has_to_out_proj = False
        if not hasattr(self, 'num_out_embs'):
            self.num_out_embs = -1
        return f"{type_sig} SubjBasisGenerator: num_out_embs={self.num_out_embs}, " \
               f"bg_prompt_translator_has_to_out_proj={self.bg_prompt_translator_has_to_out_proj}"

@dataclass
class BaseModelOutputWithPooling2(ModelOutput):
    """
    Base class for model's outputs that also contains a pooling of the last hidden states.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Last layer hidden-state of the first token of the sequence (classification token) after further processing
            through the layers used for the auxiliary pretraining task. E.g. for the BERT family of models, this returns
            the classification token after processing through a linear layer and a tanh activation function. The linear
            layer weights are trained from the next sentence prediction (classification) objective during pretraining.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    attn_mask: Optional[torch.FloatTensor] = None
# Revised from CLIPVisionTransformer to support an attention mask.
# self: a CLIPVisionTransformer instance.
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py#L821
# pixel_values: preprocessed B*C*H*W images. [BS, 3, 224, 224]
# attn_mask: B*H*W attention mask.
def CLIPVisionTransformer_forward(self, pixel_values = None, attn_mask=None,
                                  output_attentions = None,
                                  output_hidden_states = None, return_dict = None):

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if pixel_values is None:
        raise ValueError("You have to specify pixel_values")

    # Visual tokens are flattened in embeddings().
    # self.embeddings: CLIPVisionEmbeddings.
    # hidden_states: [BS, 257, 1280]. 257: 16*16 (patch_embeds) + 1 (class_embeds).
    # 16*16 is output from Conv2d(3, 1280, kernel_size=(14, 14), stride=(14, 14), bias=False).
    hidden_states = self.embeddings(pixel_values)
    hidden_states = self.pre_layrnorm(hidden_states)

    if attn_mask is not None:
        # feat_edge_size: 16.
        feat_edge_size = np.sqrt(hidden_states.shape[1] - 1).astype(int)
        # attn_mask: [BS, 512, 512] -> [BS, 1, 16, 16].
        attn_mask = F.interpolate(attn_mask.unsqueeze(1), size=(feat_edge_size, feat_edge_size), mode='nearest')
        # Flatten the mask: [BS, 1, 16, 16] => [BS, 1, 256].
        attn_mask = attn_mask.flatten(2)
        # Prepend 1 to the mask: [BS, 1, 256] => [BS, 1, 257].
        # This 1 corresponds to class_embeds, which is always attended to.
        attn_mask = torch.cat([torch.ones_like(attn_mask[:, :, :1]), attn_mask], dim=-1)
        attn_mask_pairs = torch.matmul(attn_mask.transpose(-1, -2), attn_mask).unsqueeze(1)
    else:
        attn_mask_pairs = None

    # encoder: CLIPEncoder.
    encoder_outputs = self.encoder(
        inputs_embeds=hidden_states,
        # New feature: (***The official documentation is wrong***)
        # attention_mask (`torch.Tensor` of shape `(batch_size, 1, sequence_length, sequence_length)`, *optional*):
        # Mask to avoid performing attention on pairs of tokens. Mask values selected in `[0, 1]`:
        # - 1 for pairs that are **not masked**,
        # - 0 for pairs that are **masked**.
        # attention_mask is eventually used by CLIPEncoderLayer:
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py#L370
        attention_mask=attn_mask_pairs,
        output_attentions=output_attentions,        # False
        output_hidden_states=output_hidden_states,  # True
        return_dict=return_dict,                    # True
    )

    # last_hidden_state: [BS, 257, 1280]
    last_hidden_state = encoder_outputs[0]
    pooled_output = last_hidden_state[:, 0, :]
    pooled_output = self.post_layernorm(pooled_output)

    # return_dict is True.
    if not return_dict:
        return (last_hidden_state, pooled_output) + encoder_outputs[1:]

    return BaseModelOutputWithPooling2(
        last_hidden_state=last_hidden_state,
        pooler_output=pooled_output,
        hidden_states=encoder_outputs.hidden_states,
        attentions=encoder_outputs.attentions,
        # Newly added: return the resized, flattened attention mask.
        # [BS, 1, 257] -> [BS, 257, 1]
        attn_mask=attn_mask.permute(0, 2, 1) if attn_mask is not None else None
    )
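The pairwise mask built above is simply the outer product of the flattened token mask with itself, so a pair (i, j) is attended only when both token i and token j fall inside the mask. A tiny illustration with made-up numbers:

# Illustrative only: outer product of a per-token mask gives a per-pair mask.
import torch
m = torch.tensor([[1., 1., 0.]]).unsqueeze(1)          # [BS=1, 1, 3] token mask
pairs = torch.matmul(m.transpose(-1, -2), m)            # [1, 3, 3]
# pairs[0] == [[1, 1, 0],
#              [1, 1, 0],
#              [0, 0, 0]]  -> token 2 neither attends nor is attended to.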
class CLIPVisionModelWithMask(CLIPVisionModel):
    def __init__(self, config):
        super().__init__(config)
        # Replace vision_model.forward() with the new one that supports a mask.
        self.vision_model.forward = CLIPVisionTransformer_forward.__get__(self.vision_model)

    def forward(self, pixel_values = None, attn_mask = None, output_attentions = None,
                output_hidden_states = None, return_dict = None):

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        return self.vision_model(
            pixel_values=pixel_values,
            attn_mask=attn_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
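A minimal usage sketch for the masked vision encoder. The checkpoint name is an assumption for illustration (any CLIP vision checkpoint with 224x224 inputs should behave similarly); the mask uses 1 = attend, 0 = ignore:

# Hedged sketch only; checkpoint and shapes are illustrative assumptions.
import torch

model = CLIPVisionModelWithMask.from_pretrained("openai/clip-vit-large-patch14")
pixel_values = torch.randn(1, 3, 224, 224)
attn_mask = torch.ones(1, 512, 512)            # per-pixel mask, downsampled internally to 16x16
out = model(pixel_values=pixel_values, attn_mask=attn_mask, output_hidden_states=True)
# out.last_hidden_state: [1, 257, hidden_size]; out.attn_mask: [1, 257, 1]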
util.py
ADDED
@@ -0,0 +1,342 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
import cv2

# add_noise_to_tensor() adds gaussian noise of a given (absolute or relative) std to the tensor.
def add_noise_to_tensor(ts, noise_std, noise_std_is_relative=True, keep_norm=False,
                        std_dim=-1, norm_dim=-1):
    if noise_std_is_relative:
        ts_std_mean = ts.std(dim=std_dim).mean().detach()
        noise_std *= ts_std_mean

    noise = torch.randn_like(ts) * noise_std
    if keep_norm:
        orig_norm = ts.norm(dim=norm_dim, keepdim=True)
        ts = ts + noise
        new_norm = ts.norm(dim=norm_dim, keepdim=True).detach()
        ts = ts * orig_norm / (new_norm + 1e-8)
    else:
        ts = ts + noise

    return ts
+
# Revised from RevGrad, by removing the grad negation.
|
28 |
+
class ScaleGrad(torch.autograd.Function):
|
29 |
+
@staticmethod
|
30 |
+
def forward(ctx, input_, alpha_, debug=False):
|
31 |
+
ctx.save_for_backward(alpha_, debug)
|
32 |
+
output = input_
|
33 |
+
if debug:
|
34 |
+
print(f"input: {input_.abs().mean().item()}")
|
35 |
+
return output
|
36 |
+
|
37 |
+
@staticmethod
|
38 |
+
def backward(ctx, grad_output): # pragma: no cover
|
39 |
+
# saved_tensors returns a tuple of tensors.
|
40 |
+
alpha_, debug = ctx.saved_tensors
|
41 |
+
if ctx.needs_input_grad[0]:
|
42 |
+
grad_output2 = grad_output * alpha_
|
43 |
+
if debug:
|
44 |
+
print(f"grad_output2: {grad_output2.abs().mean().item()}")
|
45 |
+
else:
|
46 |
+
grad_output2 = None
|
47 |
+
return grad_output2, None, None
|
48 |
+
|
49 |
+
class GradientScaler(nn.Module):
|
50 |
+
def __init__(self, alpha=1., debug=False, *args, **kwargs):
|
51 |
+
"""
|
52 |
+
A gradient scaling layer.
|
53 |
+
This layer has no parameters, and simply scales the gradient in the backward pass.
|
54 |
+
"""
|
55 |
+
super().__init__(*args, **kwargs)
|
56 |
+
|
57 |
+
self._alpha = torch.tensor(alpha, requires_grad=False)
|
58 |
+
self._debug = torch.tensor(debug, requires_grad=False)
|
59 |
+
|
60 |
+
def forward(self, input_):
|
61 |
+
_debug = self._debug if hasattr(self, '_debug') else False
|
62 |
+
return ScaleGrad.apply(input_, self._alpha.to(input_.device), _debug)
|
63 |
+
|
64 |
+
def gen_gradient_scaler(alpha, debug=False):
|
65 |
+
if alpha == 1:
|
66 |
+
return nn.Identity()
|
67 |
+
if alpha > 0:
|
68 |
+
return GradientScaler(alpha, debug=debug)
|
69 |
+
else:
|
70 |
+
assert alpha == 0
|
71 |
+
# Don't use lambda function here, otherwise the object can't be pickled.
|
72 |
+
return torch.detach
|
73 |
+
|
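A minimal usage sketch (it assumes this file is importable as util): the forward pass is an identity mapping; only the gradients are scaled in the backward pass.

import torch
from util import gen_gradient_scaler

x = torch.randn(4, 8, requires_grad=True)
scaler = gen_gradient_scaler(0.1)      # a GradientScaler instance
y = scaler(x)                          # y equals x in the forward pass
y.sum().backward()
print(x.grad.mean().item())            # ~0.1 instead of 1.0, i.e. gradients scaled by 0.1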
#@torch.autocast(device_type="cuda")
# In AdaFaceWrapper, input_max_length is 22.
def arc2face_forward_face_embs(tokenizer, arc2face_text_encoder, face_embs,
                               input_max_length=77, return_full_and_core_embs=True):

    '''
    arc2face_text_encoder: arc2face_models.py CLIPTextModelWrapper instance.
    face_embs: (N, 512) normalized ArcFace embeddings.
    return_full_and_core_embs: Return both the full prompt embeddings and the core embeddings.
                               If False, return only the core embeddings.
    '''

    # arcface_token_id: 1014
    arcface_token_id = tokenizer.encode("id", add_special_tokens=False)[0]

    # This step should be quite fast, and there's no need to cache the input_ids.
    input_ids = tokenizer(
            "photo of a id person",
            truncation=True,
            padding="max_length",
            max_length=input_max_length,  #tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids.to(face_embs.device)
    # input_ids: [1, 77] or [3, 77] (during training).
    input_ids = input_ids.repeat(len(face_embs), 1)
    face_embs_dtype = face_embs.dtype
    face_embs = face_embs.to(arc2face_text_encoder.dtype)
    # face_embs_padded: [1, 512] -> [1, 768].
    face_embs_padded = F.pad(face_embs, (0, arc2face_text_encoder.config.hidden_size - face_embs.shape[-1]), "constant", 0)
    # arc2face_text_encoder(input_ids=input_ids, ...) is called twice. The first is only to get the token embeddings (the shallowest mapping).
    # The second call does the ordinary CLIP text encoding pass.
    token_embs = arc2face_text_encoder(input_ids=input_ids, return_token_embs=True)
    token_embs[input_ids == arcface_token_id] = face_embs_padded

    prompt_embeds = arc2face_text_encoder(
        input_ids=input_ids,
        input_token_embs=token_embs,
        return_token_embs=False
    )[0]

    # Restore the original dtype of prompt_embeds: float16 -> float32.
    prompt_embeds = prompt_embeds.to(face_embs_dtype)

    if return_full_and_core_embs:
        # token 4: 'id' in "photo of a id person".
        # 4:20 are the most important 16 embeddings that contain the subject's identity.
        # [N, 77, 768] -> [N, 16, 768]
        return prompt_embeds, prompt_embeds[:, 4:20]
    else:
        # [N, 16, 768]
        return prompt_embeds[:, 4:20]
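A hedged usage sketch: it assumes the Arc2Face CLIP tokenizer and the CLIPTextModelWrapper text encoder have already been loaded as done elsewhere in this repo; the random embeddings stand in for real ArcFace embeddings.

# tokenizer and arc2face_text_encoder are assumed to be initialized elsewhere.
import torch
import torch.nn.functional as F

face_embs = F.normalize(torch.randn(2, 512), p=2, dim=-1)
full_embs, core_embs = arc2face_forward_face_embs(
    tokenizer, arc2face_text_encoder, face_embs,
    input_max_length=22, return_full_and_core_embs=True)
# full_embs: [2, 22, 768]; core_embs: [2, 16, 768] (tokens 4:20 carry the identity).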
def get_b_core_e_embeddings(prompt_embeds, length=22):
    b_core_e_embs = torch.cat([ prompt_embeds[:, :length], prompt_embeds[:, [-1]] ], dim=1)
    return b_core_e_embs

# return_emb_types: a list of strings, each being one of
# ['full', 'full_half_pad', 'full_pad', 'core', 'full_zeroed_extra', 'b_core_e'].
def arc2face_inverse_face_prompt_embs(clip_tokenizer, inverse_text_encoder, face_prompt_embs, list_extra_words,
                                      return_emb_types, pad_embeddings, hidden_state_layer_weights=None,
                                      input_max_length=77, zs_extra_words_scale=0.5):

    '''
    inverse_text_encoder: arc2face_models.py CLIPTextModelWrapper instance with **custom weights**.
    inverse_text_encoder is NOT the original arc2face text encoder, but is retrained to do the inverse mapping.
    face_prompt_embs: (BS, 16, 768). Only the core embeddings, no paddings.
    list_extra_words: [s_1, ..., s_BS], each s_i is a string of at most 2 extra words to be added to the prompt.
    return_emb_types: which embedding variants to return; one tensor is returned per requested type.
    '''

    if list_extra_words is not None:
        if len(list_extra_words) != len(face_prompt_embs):
            if len(face_prompt_embs) > 1:
                print("Warning: list_extra_words has a different length than face_prompt_embs.")
                if len(list_extra_words) == 1:
                    list_extra_words = list_extra_words * len(face_prompt_embs)
                else:
                    breakpoint()
            else:
                # len(face_prompt_embs) == 1, which occurs when same_subject_in_batch == True, e.g. in do_mix_prompt_distillation.
                # But list_extra_words always corresponds to the actual batch size. So we only take the first element.
                list_extra_words = list_extra_words[:1]

        for extra_words in list_extra_words:
            assert len(extra_words.split()) <= 2, "Each extra_words string should consist of at most 2 words."
        # 16 ", " are placeholders for face_prompt_embs.
        prompt_templates = [ "photo of a " + ", " * 16 + list_extra_words[i] for i in range(len(list_extra_words)) ]
    else:
        # 16 ", " are placeholders for face_prompt_embs.
        # No extra words are added to the prompt.
        prompt_templates = [ "photo of a " + ", " * 16 for _ in range(len(face_prompt_embs)) ]

    # This step should be quite fast, and there's no need to cache the input_ids.
    # input_ids: [BS, 77].
    input_ids = clip_tokenizer(
            prompt_templates,
            truncation=True,
            padding="max_length",
            max_length=input_max_length,
            return_tensors="pt",
        ).input_ids.to(face_prompt_embs.device)

    face_prompt_embs_dtype = face_prompt_embs.dtype
    face_prompt_embs = face_prompt_embs.to(inverse_text_encoder.dtype)

    # token_embs: [1, 77, 768]. This call is only to get the template token embeddings (the shallowest mapping).
    token_embs = inverse_text_encoder(input_ids=input_ids, return_token_embs=True)
    # token 4: first ", " in the template prompt.
    # Replace the embeddings of the 16 placeholder ", " tokens with face_prompt_embs.
    token_embs[:, 4:20] = face_prompt_embs

    # This call does the ordinary CLIP text encoding pass.
    prompt_embeds = inverse_text_encoder(
        input_ids=input_ids,
        input_token_embs=token_embs,
        hidden_state_layer_weights=hidden_state_layer_weights,
        return_token_embs=False
    )[0]

    # Restore the original dtype of prompt_embeds: float16 -> float32.
    prompt_embeds = prompt_embeds.to(face_prompt_embs_dtype)
    # token 4: first ", " in the template prompt.
    # 4:20 are the most important 16 embeddings that contain the subject's identity.
    # 20:22 are embeddings of the (at most) two extra words.
    # [N, 77, 768] -> [N, 16, 768]
    core_prompt_embs = prompt_embeds[:, 4:20]
    if list_extra_words is not None:
        # [N, 16, 768] -> [N, 18, 768]
        extra_words_embs = prompt_embeds[:, 20:22] * zs_extra_words_scale
        core_prompt_embs = torch.cat([core_prompt_embs, extra_words_embs], dim=1)

    return_prompts = []
    for emb_type in return_emb_types:
        if emb_type == 'full':
            return_prompts.append(prompt_embeds)
        elif emb_type == 'full_half_pad':
            prompt_embeds2 = prompt_embeds.clone()
            PADS = prompt_embeds2.shape[1] - 23
            if PADS >= 2:
                # Fill half of the remaining embeddings with pad embeddings.
                prompt_embeds2[:, 22:22+PADS//2] = pad_embeddings[22:22+PADS//2]
            return_prompts.append(prompt_embeds2)
        elif emb_type == 'full_pad':
            prompt_embeds2 = prompt_embeds.clone()
            # Fill the 22nd to the second-to-last embeddings with pad embeddings.
            prompt_embeds2[:, 22:-1] = pad_embeddings[22:-1]
            return_prompts.append(prompt_embeds2)
        elif emb_type == 'core':
            return_prompts.append(core_prompt_embs)
        elif emb_type == 'full_zeroed_extra':
            prompt_embeds2 = prompt_embeds.clone()
            # Only add two pad embeddings. The remaining embeddings are set to 0.
            # Make the positional embeddings align with the actual positions.
            prompt_embeds2[:, 22:24] = pad_embeddings[22:24]
            prompt_embeds2[:, 24:-1] = 0
            return_prompts.append(prompt_embeds2)
        elif emb_type == 'b_core_e':
            # The first 22 embeddings, plus the last EOS embedding.
            b_core_e_embs = get_b_core_e_embeddings(prompt_embeds, length=22)
            return_prompts.append(b_core_e_embs)
        else:
            breakpoint()

    return return_prompts
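A hedged sketch of the calling convention, with all model objects assumed to be initialized elsewhere in this repo: one tensor is returned per entry of return_emb_types, in the same order.

# clip_tokenizer, inverse_text_encoder, face_prompt_embs ([BS, 16, 768]) and
# pad_embeddings ([77, 768]) are assumed to be prepared elsewhere.
full_pad_embs, core_embs = arc2face_inverse_face_prompt_embs(
    clip_tokenizer, inverse_text_encoder, face_prompt_embs,
    list_extra_words=None,
    return_emb_types=['full_pad', 'core'],
    pad_embeddings=pad_embeddings)
# full_pad_embs: [BS, 77, 768] with positions 22:-1 filled by pad embeddings;
# core_embs:     [BS, 16, 768].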
# If pre_face_embs is None, generate random face embeddings [BS, 512].
# image_folder is passed only for logging purposes. image_paths contains the paths of the images.
def get_arc2face_id_prompt_embs(face_app, clip_tokenizer, arc2face_text_encoder,
                                extract_faceid_embeds, pre_face_embs,
                                image_folder, image_paths, images_np,
                                id_batch_size, device,
                                input_max_length=77, noise_level=0.0,
                                return_core_id_embs=False,
                                gen_neg_prompt=False, verbose=False):
    face_image_count = 0

    if extract_faceid_embeds:
        faceid_embeds = []
        if image_paths is not None:
            images_np = []
            for image_path in image_paths:
                image_np = np.array(Image.open(image_path))
                images_np.append(image_np)

        for i, image_np in enumerate(images_np):
            image_obj = Image.fromarray(image_np).resize((512, 512), Image.NEAREST)
            # Remove the alpha channel if it exists.
            if image_obj.mode == 'RGBA':
                image_obj = image_obj.convert('RGB')
            # This seems NOT a bug. The input image should be in BGR format, as per
            # https://github.com/deepinsight/insightface/issues/524
            image_np = cv2.cvtColor(np.array(image_obj), cv2.COLOR_RGB2BGR)
            # NOTE: the next line overwrites the BGR conversion above, so the RGB array is what is actually passed to face_app.
            image_np = np.array(image_obj)

            face_infos = face_app.get(image_np)
            if verbose and image_paths is not None:
                print(image_paths[i], len(face_infos))
            # Assume all images belong to the same subject. Therefore, we can skip the images with no face detected.
            if len(face_infos) == 0:
                continue
            # Only use the largest face.
            face_info = sorted(face_infos, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]
            # Each faceid_embed: [1, 512]
            faceid_embeds.append(torch.from_numpy(face_info.normed_embedding).unsqueeze(0))
            face_image_count += 1

        if verbose:
            if image_folder is not None:
                print(f"Extracted ID embeddings from {face_image_count} images in {image_folder}")
            else:
                print(f"Extracted ID embeddings from {face_image_count} images")

        if len(faceid_embeds) == 0:
            print("No face detected. Use a random face instead.")
            faceid_embeds = torch.randn(id_batch_size, 512).to(device=device, dtype=torch.float16)
        else:
            # faceid_embeds: [10, 512]
            faceid_embeds = torch.cat(faceid_embeds, dim=0)
            # faceid_embeds: [10, 512] -> [1, 512].
            # The ID embeddings are averaged, so the resulting prompt embeddings are shared across the batch.
            faceid_embeds = faceid_embeds.mean(dim=0, keepdim=True).to(device=device, dtype=torch.float16)
    else:
        # Random face embeddings. faceid_embeds: [BS, 512].
        if pre_face_embs is None:
            faceid_embeds = torch.randn(id_batch_size, 512)
        else:
            faceid_embeds = pre_face_embs
            if pre_face_embs.shape[0] == 1:
                faceid_embeds = faceid_embeds.repeat(id_batch_size, 1)

        faceid_embeds = faceid_embeds.to(device=device, dtype=torch.float16)

    if noise_level > 0:
        # If id_batch_size > 1, the id_batch_size embeddings will differ after adding noise.
        faceid_embeds = add_noise_to_tensor(faceid_embeds, noise_level, noise_std_is_relative=True, keep_norm=True)

    faceid_embeds = F.normalize(faceid_embeds, p=2, dim=-1)

    # arc2face_pos_prompt_emb, arc2face_neg_prompt_emb: [BS, 77, 768]
    with torch.no_grad():
        arc2face_pos_prompt_emb, arc2face_pos_core_prompt_emb = \
            arc2face_forward_face_embs(clip_tokenizer, arc2face_text_encoder,
                                       faceid_embeds, input_max_length=input_max_length,
                                       return_full_and_core_embs=True)
        if return_core_id_embs:
            arc2face_pos_prompt_emb = arc2face_pos_core_prompt_emb
    # If extract_faceid_embeds, we assume all images are from the same subject, and the batch dim of faceid_embeds is 1.
    # So we need to repeat faceid_embeds.
    if extract_faceid_embeds:
        faceid_embeds = faceid_embeds.repeat(id_batch_size, 1)
        arc2face_pos_prompt_emb = arc2face_pos_prompt_emb.repeat(id_batch_size, 1, 1)

    if gen_neg_prompt:
        with torch.no_grad():
            arc2face_neg_prompt_emb, arc2face_neg_core_prompt_emb = \
                arc2face_forward_face_embs(clip_tokenizer, arc2face_text_encoder,
                                           torch.zeros_like(faceid_embeds),
                                           input_max_length=input_max_length,
                                           return_full_and_core_embs=True)
            if return_core_id_embs:
                arc2face_neg_prompt_emb = arc2face_neg_core_prompt_emb

        #if extract_faceid_embeds:
        #    arc2face_neg_prompt_emb = arc2face_neg_prompt_emb.repeat(id_batch_size, 1, 1)
        return face_image_count, faceid_embeds, arc2face_pos_prompt_emb, arc2face_neg_prompt_emb
    else:
        return face_image_count, faceid_embeds, arc2face_pos_prompt_emb
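A hedged end-to-end sketch: face_app (an insightface FaceAnalysis instance), clip_tokenizer and arc2face_text_encoder are assumed to be initialized elsewhere in this repo, and the image folder path is purely illustrative.

import glob

image_paths = glob.glob("subject_images/*.jpg")      # hypothetical folder
face_image_count, faceid_embeds, pos_prompt_emb = get_arc2face_id_prompt_embs(
    face_app, clip_tokenizer, arc2face_text_encoder,
    extract_faceid_embeds=True, pre_face_embs=None,
    image_folder="subject_images", image_paths=image_paths, images_np=None,
    id_batch_size=4, device="cuda",
    input_max_length=22, noise_level=0.0,
    return_core_id_embs=True, gen_neg_prompt=False, verbose=True)
# pos_prompt_emb: [4, 16, 768] core identity embeddings, shared by the 4 samples.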