zyt334 committed on
Commit
57f11a4
1 Parent(s): bc971c6

Upload folder using huggingface_hub

Files changed (40)
  1. .gitattributes +35 -35
  2. README.md +12 -12
  3. adaface-infer.py +131 -0
  4. adaface-translate.py +208 -0
  5. adaface/__pycache__/adaface_wrapper.cpython-312.pyc +0 -0
  6. adaface/__pycache__/adaface_wrapper.cpython-38.pyc +0 -0
  7. adaface/__pycache__/arc2face_models.cpython-312.pyc +0 -0
  8. adaface/__pycache__/arc2face_models.cpython-38.pyc +0 -0
  9. adaface/__pycache__/subj_basis_generator.cpython-312.pyc +0 -0
  10. adaface/__pycache__/subj_basis_generator.cpython-38.pyc +0 -0
  11. adaface/__pycache__/util.cpython-312.pyc +0 -0
  12. adaface/__pycache__/util.cpython-38.pyc +0 -0
  13. adaface/adaface-infer.py +131 -0
  14. adaface/adaface-translate.py +208 -0
  15. adaface/adaface_wrapper.py +297 -0
  16. adaface/arc2face_models.py +303 -0
  17. adaface/subj_basis_generator.py +758 -0
  18. adaface/util.py +342 -0
  19. adaface_wrapper.py +297 -0
  20. app.py +203 -0
  21. arc2face_models.py +303 -0
  22. models/adaface/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt +3 -0
  23. models/arc2face/arc2face/config.json +67 -0
  24. models/arc2face/arc2face/diffusion_pytorch_model.safetensors +3 -0
  25. models/arc2face/encoder/config.json +24 -0
  26. models/arc2face/encoder/pytorch_model.bin +3 -0
  27. models/insightface/models/antelopev2/1k3d68.onnx +3 -0
  28. models/insightface/models/antelopev2/2d106det.onnx +3 -0
  29. models/insightface/models/antelopev2/arcface.onnx +3 -0
  30. models/insightface/models/antelopev2/genderage.onnx +3 -0
  31. models/insightface/models/antelopev2/scrfd_10g_bnkps.onnx +3 -0
  32. models/insightface/models/buffalo_l/1k3d68.onnx +3 -0
  33. models/insightface/models/buffalo_l/2d106det.onnx +3 -0
  34. models/insightface/models/buffalo_l/det_10g.onnx +3 -0
  35. models/insightface/models/buffalo_l/genderage.onnx +3 -0
  36. models/insightface/models/buffalo_l/w600k_r50.onnx +3 -0
  37. models/sar/sar.safetensors +3 -0
  38. requirements.txt +12 -0
  39. subj_basis_generator.py +758 -0
  40. util.py +342 -0
.gitattributes CHANGED
@@ -1,35 +1,35 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,12 @@
- ---
- title: Adaface
- emoji: 🚀
- colorFrom: gray
- colorTo: indigo
- sdk: gradio
- sdk_version: 4.39.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: Adaface
+ emoji: 🌖
+ colorFrom: purple
+ colorTo: red
+ sdk: gradio
+ sdk_version: 4.37.2
+ app_file: app.py
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
adaface-infer.py ADDED
@@ -0,0 +1,131 @@
+ from adaface.adaface_wrapper import AdaFaceWrapper
+ import torch
+ #import torch.nn.functional as F
+ from PIL import Image
+ import numpy as np
+ import os, argparse, glob, re
+
+ def save_images(images, num_images_per_row, subject_name, prompt, noise_level, save_dir = "samples-ada"):
+     if num_images_per_row > len(images):
+         num_images_per_row = len(images)
+
+     os.makedirs(save_dir, exist_ok=True)
+
+     num_columns = int(np.ceil(len(images) / num_images_per_row))
+     # Save 4 images as a grid image in save_dir
+     grid_image = Image.new('RGB', (512 * num_images_per_row, 512 * num_columns))
+     for i, image in enumerate(images):
+         image = image.resize((512, 512))
+         grid_image.paste(image, (512 * (i % num_images_per_row), 512 * (i // num_images_per_row)))
+
+     prompt_sig = prompt.replace(" ", "_").replace(",", "_")
+     grid_filepath = os.path.join(save_dir, f"{subject_name}-{prompt_sig}-noise{noise_level:.02f}.png")
+     if os.path.exists(grid_filepath):
+         grid_count = 2
+         grid_filepath = os.path.join(save_dir, f'{subject_name}-{prompt_sig}-noise{noise_level:.02f}-{grid_count}.jpg')
+         while os.path.exists(grid_filepath):
+             grid_count += 1
+             grid_filepath = os.path.join(save_dir, f'{subject_name}-{prompt_sig}-noise{noise_level:.02f}-{grid_count}.jpg')
+
+     grid_image.save(grid_filepath)
+     print(f"Saved to {grid_filepath}")
+
+ def seed_everything(seed):
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+     os.environ["PL_GLOBAL_SEED"] = str(seed)
+
+ def parse_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--base_model_path", type=str, default='runwayml/stable-diffusion-v1-5',
+                         help="Type of checkpoints to use (default: SD 1.5)")
+     parser.add_argument("--embman_ckpt", type=str, required=True,
+                         help="Path to the checkpoint of the embedding manager")
+     parser.add_argument("--subject", type=str, required=True)
+     parser.add_argument("--example_image_count", type=int, default=-1, help="Number of example images to use")
+     parser.add_argument("--out_image_count", type=int, default=4, help="Number of images to generate")
+     parser.add_argument("--prompt", type=str, default="a woman z in superman costume")
+     parser.add_argument("--noise", dest='noise_level', type=float, default=0)
+     parser.add_argument("--randface", action="store_true")
+     parser.add_argument("--scale", dest='guidance_scale', type=float, default=4,
+                         help="Guidance scale for the diffusion model")
+     parser.add_argument("--id_cfg_scale", type=float, default=1,
+                         help="CFG scale when generating the identity embeddings")
+
+     parser.add_argument("--subject_string",
+                         type=str, default="z",
+                         help="Subject placeholder string used in prompts to denote the concept.")
+     parser.add_argument("--num_vectors", type=int, default=16,
+                         help="Number of vectors used to represent the subject.")
+     parser.add_argument("--num_images_per_row", type=int, default=4,
+                         help="Number of images to display in a row in the output grid image.")
+     parser.add_argument("--num_inference_steps", type=int, default=50,
+                         help="Number of DDIM inference steps")
+     parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on")
+     parser.add_argument("--seed", type=int, default=42,
+                         help="the seed (for reproducible sampling). Set to -1 to disable.")
+     args = parser.parse_args()
+
+     return args
+
+ if __name__ == "__main__":
+     args = parse_args()
+     if args.seed != -1:
+         seed_everything(args.seed)
+
+     if re.match(r"^\d+$", args.device):
+         args.device = f"cuda:{args.device}"
+     print(f"Using device {args.device}")
+
+     adaface = AdaFaceWrapper("text2img", args.base_model_path, args.embman_ckpt, args.device,
+                              args.subject_string, args.num_vectors, args.num_inference_steps)
+
+     if not args.randface:
+         image_folder = args.subject
+         if image_folder.endswith("/"):
+             image_folder = image_folder[:-1]
+
+         if os.path.isfile(image_folder):
+             # Get the second to the last part of the path
+             subject_name = os.path.basename(os.path.dirname(image_folder))
+             image_paths = [image_folder]
+
+         else:
+             subject_name = os.path.basename(image_folder)
+             image_types = ["*.jpg", "*.png", "*.jpeg"]
+             alltype_image_paths = []
+             for image_type in image_types:
+                 # glob returns the full path.
+                 image_paths = glob.glob(os.path.join(image_folder, image_type))
+                 if len(image_paths) > 0:
+                     alltype_image_paths.extend(image_paths)
+
+             # Filter out images of "*_mask.png"
+             alltype_image_paths = [image_path for image_path in alltype_image_paths if "_mask.png" not in image_path]
+
+             # image_paths contain at most args.example_image_count full image paths.
+             if args.example_image_count > 0:
+                 image_paths = alltype_image_paths[:args.example_image_count]
+             else:
+                 image_paths = alltype_image_paths
+     else:
+         subject_name = None
+         image_paths = None
+         image_folder = None
+
+     subject_name = "randface-" + str(torch.seed()) if args.randface else subject_name
+     rand_face_embs = torch.randn(1, 512)
+
+     pre_face_embs = rand_face_embs if args.randface else None
+     noise = torch.randn(args.out_image_count, 4, 64, 64).cuda()
+     # args.noise_level: the *relative* std of the noise added to the face embeddings.
+     # A noise level of 0.08 could change gender, but 0.06 is usually safe.
+     # adaface_subj_embs is not used. It is generated for the purpose of updating the text encoder (within this function call).
+     adaface_subj_embs = adaface.generate_adaface_embeddings(image_paths, image_folder, pre_face_embs, args.randface,
+                                                             out_id_embs_scale=args.id_cfg_scale, noise_level=args.noise_level,
+                                                             update_text_encoder=True)
+     images = adaface(noise, args.prompt, args.guidance_scale, args.out_image_count, verbose=True)
+     save_images(images, args.num_images_per_row, subject_name, f"guide{args.guidance_scale}", args.noise_level)
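The script above doubles as a reference for driving AdaFaceWrapper programmatically. A minimal sketch of the same text2img flow follows; the checkpoint and image paths are placeholders, not files guaranteed by this commit, and keyword arguments are used to match the AdaFaceWrapper signatures shown further down in this commit.

import torch
from adaface.adaface_wrapper import AdaFaceWrapper

# Placeholder paths: substitute your own base model, AdaFace checkpoint, and subject photos.
adaface = AdaFaceWrapper("text2img", "runwayml/stable-diffusion-v1-5",
                         "models/adaface/your-adaface-ckpt.pt", "cuda")
# Encode the subject photos and patch the resulting subject embeddings into the text encoder.
adaface.generate_adaface_embeddings(["subject/photo1.jpg"], "subject",
                                    None, False, out_id_embs_scale=1,
                                    noise_level=0, update_text_encoder=True)
# Sample 4 images from pure latent noise, conditioned on the patched prompt.
noise = torch.randn(4, 4, 64, 64).cuda()
images = adaface(noise, "a woman z in superman costume",
                 guidance_scale=4.0, out_image_count=4)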
adaface-translate.py ADDED
@@ -0,0 +1,208 @@
+ from adaface.adaface_wrapper import AdaFaceWrapper
+ import torch
+ #import torch.nn.functional as F
+ from PIL import Image
+ import numpy as np
+ import os, argparse, glob, re, shutil
+
+ def str2bool(v):
+     if isinstance(v, bool):
+         return v
+     if v.lower() in ("yes", "true", "t", "y", "1"):
+         return True
+     elif v.lower() in ("no", "false", "f", "n", "0"):
+         return False
+     else:
+         raise argparse.ArgumentTypeError("Boolean value expected.")
+
+ def seed_everything(seed):
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+     os.environ["PL_GLOBAL_SEED"] = str(seed)
+
+ def parse_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--base_model_path", type=str, default='models/realisticvision/realisticVisionV40_v40VAE.safetensors',
+                         help="Path to the UNet checkpoint (default: RealisticVision 4.0)")
+     parser.add_argument("--embman_ckpt", type=str, required=True,
+                         help="Path to the checkpoint of the embedding manager")
+     parser.add_argument("--in_folder", type=str, required=True, help="Path to the folder containing input images")
+     # If True, the input folder contains images of mixed subjects.
+     # If False, the input folder contains multiple subfolders, each of which contains images of the same subject.
+     parser.add_argument("--is_mix_subj_folder", type=str2bool, const=True, default=False, nargs="?",
+                         help="Whether the input folder contains images of mixed subjects")
+     parser.add_argument("--max_images_per_subject", type=int, default=5, help="Number of example images used per subject")
+     parser.add_argument("--trans_subject_count", type=int, default=-1, help="Number of example images to be translated")
+     parser.add_argument("--out_folder", type=str, required=True, help="Path to the folder saving output images")
+     parser.add_argument("--out_count_per_input_image", type=int, default=1, help="Number of output images to generate per input image")
+     parser.add_argument("--copy_masks", action="store_true", help="Copy the mask images to the output folder")
+     parser.add_argument("--noise", dest='noise_level', type=float, default=0)
+     parser.add_argument("--scale", dest='guidance_scale', type=float, default=4,
+                         help="Guidance scale for the diffusion model")
+     parser.add_argument("--ref_img_strength", type=float, default=0.8,
+                         help="Strength of the reference image in the output image.")
+     parser.add_argument("--subject_string",
+                         type=str, default="z",
+                         help="Subject placeholder string used in prompts to denote the concept.")
+     parser.add_argument("--num_vectors", type=int, default=16,
+                         help="Number of vectors used to represent the subject.")
+     parser.add_argument("--prompt", type=str, default="a person z")
+     parser.add_argument("--num_images_per_row", type=int, default=4,
+                         help="Number of images to display in a row in the output grid image.")
+     parser.add_argument("--num_inference_steps", type=int, default=50,
+                         help="Number of DDIM inference steps")
+     parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use. If num_gpus > 1, use accelerate for distributed execution.")
+     parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on")
+     parser.add_argument("--seed", type=int, default=42,
+                         help="the seed (for reproducible sampling). Set to -1 to disable.")
+     args = parser.parse_args()
+
+     return args
+
+ if __name__ == "__main__":
+     args = parse_args()
+     if args.seed != -1:
+         seed_everything(args.seed)
+
+     # screen -dm -L -Logfile trans_rv4-2.txt accelerate launch --multi_gpu --num_processes=2 scripts/adaface-translate.py
+     # --embman_ckpt logs/subjects-celebrity2024-05-16T17-22-46_zero3-ada/checkpoints/embeddings_gs-30000.pt
+     # --base_model_path models/realisticvision/realisticVisionV40_v40VAE.safetensors --in_folder /data/username/VGGface2_HQ_masks/
+     # --is_mix_subj_folder 0 --out_folder /data/username/VGGface2_HQ_masks_rv4a --copy_masks --num_gpus 2
+     if args.num_gpus > 1:
+         from accelerate import PartialState
+         distributed_state = PartialState()
+         args.device = distributed_state.device
+         process_index = distributed_state.process_index
+     elif re.match(r"^\d+$", args.device):
+         args.device = f"cuda:{args.device}"
+         distributed_state = None
+         process_index = 0
+
+     adaface = AdaFaceWrapper("img2img", args.base_model_path, args.embman_ckpt, args.device,
+                              args.subject_string, args.num_vectors, args.num_inference_steps)
+
+     in_folder = args.in_folder
+     if os.path.isfile(in_folder):
+         subject_folders = [ os.path.dirname(in_folder) ]
+         images_by_subject = [[in_folder]]
+     else:
+         if not args.is_mix_subj_folder:
+             in_folders = [in_folder]
+         else:
+             in_folders = [ os.path.join(in_folder, subfolder) for subfolder in sorted(os.listdir(in_folder)) ]
+
+         images_by_subject = []
+         subject_folders = []
+         for in_folder in in_folders:
+             image_types = ["*.jpg", "*.png", "*.jpeg"]
+             alltype_image_paths = []
+             for image_type in image_types:
+                 # glob returns the full path.
+                 image_paths = glob.glob(os.path.join(in_folder, image_type))
+                 if len(image_paths) > 0:
+                     alltype_image_paths.extend(image_paths)
+
+             # Filter out images of "*_mask.png"
+             alltype_image_paths = [image_path for image_path in alltype_image_paths if "_mask.png" not in image_path]
+             alltype_image_paths = sorted(alltype_image_paths)
+
+             if not args.is_mix_subj_folder:
+                 # image_paths contain at most args.max_images_per_subject full image paths.
+                 if args.max_images_per_subject > 0:
+                     image_paths = alltype_image_paths[:args.max_images_per_subject]
+                 else:
+                     image_paths = alltype_image_paths
+
+                 images_by_subject.append(image_paths)
+                 subject_folders.append(in_folder)
+             else:
+                 # Each image in the folder is treated as an individual subject.
+                 images_by_subject.extend([[image_path] for image_path in alltype_image_paths])
+                 subject_folders.extend([in_folder] * len(alltype_image_paths))
+
+             if args.trans_subject_count > 0 and len(subject_folders) >= args.trans_subject_count:
+                 break
+
+     if args.trans_subject_count > 0:
+         images_by_subject = images_by_subject[:args.trans_subject_count]
+         subject_folders = subject_folders[:args.trans_subject_count]
+
+     out_image_count = 0
+     out_mask_count = 0
+     if not args.out_folder.endswith("/"):
+         args.out_folder += "/"
+
+     if args.num_gpus > 1:
+         # Split the subjects across the GPUs.
+         subject_folders = subject_folders[process_index::args.num_gpus]
+         images_by_subject = images_by_subject[process_index::args.num_gpus]
+         #subject_folders, images_by_subject = distributed_state.split_between_processes(zip(subject_folders, images_by_subject))
+
+     for (subject_folder, image_paths) in zip(subject_folders, images_by_subject):
+         # If is_mix_subj_folder, then image_paths only contains 1 image, and we use the file name as the signature of the image.
+         # Otherwise, we use the folder name as the signature of the images.
+         images_sig = subject_folder if not args.is_mix_subj_folder else os.path.basename(image_paths[0])
+
+         print(f"Translating {images_sig}...")
+         with torch.no_grad():
+             adaface_subj_embs = adaface.generate_adaface_embeddings(image_paths, subject_folder, None, False,
+                                                                     out_id_embs_scale=1, noise_level=args.noise_level,
+                                                                     update_text_encoder=True)
+
+         # Replace the first occurrence of "in_folder" with "out_folder" in the path of the subject_folder.
+         subject_out_folder = subject_folder.replace(args.in_folder, args.out_folder, 1)
+         if not os.path.exists(subject_out_folder):
+             os.makedirs(subject_out_folder)
+         print(f"Output images will be saved to {subject_out_folder}")
+
+         in_images = []
+         for image_path in image_paths:
+             image = Image.open(image_path).convert("RGB").resize((512, 512))
+             # [512, 512, 3] -> [3, 512, 512].
+             image = np.array(image).transpose(2, 0, 1)
+             # Convert the image to a tensor of shape (1, 3, 512, 512) and move it to the GPU.
+             image = torch.tensor(image).unsqueeze(0).float().cuda()
+             in_images.append(image)
+
+         # Put all input images of the subject into a batch. This assumes max_images_per_subject is small.
+         # NOTE: For simplicity, we do not check overly large batch sizes.
+         in_images = torch.cat(in_images, dim=0)
+         # in_images: [5, 3, 512, 512].
+         # Normalize the pixel values to [0, 1].
+         in_images = in_images / 255.0
+         num_out_images = len(in_images) * args.out_count_per_input_image
+
+         with torch.no_grad():
+             # args.noise_level: the *relative* std of the noise added to the face embeddings.
+             # A noise level of 0.08 could change gender, but 0.06 is usually safe.
+             # The returned adaface_subj_embs are already incorporated in the text encoder, and not used explicitly.
+             # NOTE: We assume out_count_per_input_image == 1, so that the output images are of the same number as the input images.
+             out_images = adaface(in_images, args.prompt, args.guidance_scale, num_out_images, ref_img_strength=args.ref_img_strength)
+
+         for img_i, img in enumerate(out_images):
+             # out_images: subj_1, subj_2, ..., subj_n, subj_1, subj_2, ..., subj_n, ...
+             subj_i = img_i % len(in_images)
+             copy_i = img_i // len(in_images)
+             image_filename_stem, image_fileext = os.path.splitext(os.path.basename(image_paths[subj_i]))
+             if copy_i == 0:
+                 img.save(os.path.join(subject_out_folder, f"{image_filename_stem}{image_fileext}"))
+             else:
+                 img.save(os.path.join(subject_out_folder, f"{image_filename_stem}_{copy_i}{image_fileext}"))
+
+             if args.copy_masks:
+                 mask_path = image_paths[subj_i].replace(image_fileext, "_mask.png")
+                 if os.path.exists(mask_path):
+                     if copy_i == 0:
+                         shutil.copy(mask_path, subject_out_folder)
+                     else:
+                         mask_filename_stem = image_filename_stem
+                         shutil.copy(mask_path, os.path.join(subject_out_folder, f"{mask_filename_stem}_{copy_i}_mask.png"))
+
+                     out_mask_count += 1
+
+         out_image_count += len(out_images)
+
+     print(f"{out_image_count} output images and {out_mask_count} masks saved to {args.out_folder}")
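The multi-GPU path above shards subjects across processes with plain stride slicing. A tiny standalone illustration (subject names and GPU count are made up for the example):

# Stride-based sharding as used above: process k takes every num_gpus-th subject.
subjects = ["s0", "s1", "s2", "s3", "s4"]
num_gpus = 2
shards = [subjects[process_index::num_gpus] for process_index in range(num_gpus)]
print(shards)  # [['s0', 's2', 's4'], ['s1', 's3']]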
adaface/__pycache__/adaface_wrapper.cpython-312.pyc ADDED
Binary file (13.5 kB).

adaface/__pycache__/adaface_wrapper.cpython-38.pyc ADDED
Binary file (8.03 kB).

adaface/__pycache__/arc2face_models.cpython-312.pyc ADDED
Binary file (16.1 kB).

adaface/__pycache__/arc2face_models.cpython-38.pyc ADDED
Binary file (7 kB).

adaface/__pycache__/subj_basis_generator.cpython-312.pyc ADDED
Binary file (30.1 kB).

adaface/__pycache__/subj_basis_generator.cpython-38.pyc ADDED
Binary file (17.6 kB).

adaface/__pycache__/util.cpython-312.pyc ADDED
Binary file (14 kB).

adaface/__pycache__/util.cpython-38.pyc ADDED
Binary file (8.57 kB).
 
adaface/adaface-infer.py ADDED
@@ -0,0 +1,131 @@
+ from adaface.adaface_wrapper import AdaFaceWrapper
+ import torch
+ #import torch.nn.functional as F
+ from PIL import Image
+ import numpy as np
+ import os, argparse, glob, re
+
+ def save_images(images, num_images_per_row, subject_name, prompt, noise_level, save_dir = "samples-ada"):
+     if num_images_per_row > len(images):
+         num_images_per_row = len(images)
+
+     os.makedirs(save_dir, exist_ok=True)
+
+     num_columns = int(np.ceil(len(images) / num_images_per_row))
+     # Save 4 images as a grid image in save_dir
+     grid_image = Image.new('RGB', (512 * num_images_per_row, 512 * num_columns))
+     for i, image in enumerate(images):
+         image = image.resize((512, 512))
+         grid_image.paste(image, (512 * (i % num_images_per_row), 512 * (i // num_images_per_row)))
+
+     prompt_sig = prompt.replace(" ", "_").replace(",", "_")
+     grid_filepath = os.path.join(save_dir, f"{subject_name}-{prompt_sig}-noise{noise_level:.02f}.png")
+     if os.path.exists(grid_filepath):
+         grid_count = 2
+         grid_filepath = os.path.join(save_dir, f'{subject_name}-{prompt_sig}-noise{noise_level:.02f}-{grid_count}.jpg')
+         while os.path.exists(grid_filepath):
+             grid_count += 1
+             grid_filepath = os.path.join(save_dir, f'{subject_name}-{prompt_sig}-noise{noise_level:.02f}-{grid_count}.jpg')
+
+     grid_image.save(grid_filepath)
+     print(f"Saved to {grid_filepath}")
+
+ def seed_everything(seed):
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+     os.environ["PL_GLOBAL_SEED"] = str(seed)
+
+ def parse_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--base_model_path", type=str, default='runwayml/stable-diffusion-v1-5',
+                         help="Type of checkpoints to use (default: SD 1.5)")
+     parser.add_argument("--embman_ckpt", type=str, required=True,
+                         help="Path to the checkpoint of the embedding manager")
+     parser.add_argument("--subject", type=str, required=True)
+     parser.add_argument("--example_image_count", type=int, default=-1, help="Number of example images to use")
+     parser.add_argument("--out_image_count", type=int, default=4, help="Number of images to generate")
+     parser.add_argument("--prompt", type=str, default="a woman z in superman costume")
+     parser.add_argument("--noise", dest='noise_level', type=float, default=0)
+     parser.add_argument("--randface", action="store_true")
+     parser.add_argument("--scale", dest='guidance_scale', type=float, default=4,
+                         help="Guidance scale for the diffusion model")
+     parser.add_argument("--id_cfg_scale", type=float, default=1,
+                         help="CFG scale when generating the identity embeddings")
+
+     parser.add_argument("--subject_string",
+                         type=str, default="z",
+                         help="Subject placeholder string used in prompts to denote the concept.")
+     parser.add_argument("--num_vectors", type=int, default=16,
+                         help="Number of vectors used to represent the subject.")
+     parser.add_argument("--num_images_per_row", type=int, default=4,
+                         help="Number of images to display in a row in the output grid image.")
+     parser.add_argument("--num_inference_steps", type=int, default=50,
+                         help="Number of DDIM inference steps")
+     parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on")
+     parser.add_argument("--seed", type=int, default=42,
+                         help="the seed (for reproducible sampling). Set to -1 to disable.")
+     args = parser.parse_args()
+
+     return args
+
+ if __name__ == "__main__":
+     args = parse_args()
+     if args.seed != -1:
+         seed_everything(args.seed)
+
+     if re.match(r"^\d+$", args.device):
+         args.device = f"cuda:{args.device}"
+     print(f"Using device {args.device}")
+
+     adaface = AdaFaceWrapper("text2img", args.base_model_path, args.embman_ckpt, args.device,
+                              args.subject_string, args.num_vectors, args.num_inference_steps)
+
+     if not args.randface:
+         image_folder = args.subject
+         if image_folder.endswith("/"):
+             image_folder = image_folder[:-1]
+
+         if os.path.isfile(image_folder):
+             # Get the second to the last part of the path
+             subject_name = os.path.basename(os.path.dirname(image_folder))
+             image_paths = [image_folder]
+
+         else:
+             subject_name = os.path.basename(image_folder)
+             image_types = ["*.jpg", "*.png", "*.jpeg"]
+             alltype_image_paths = []
+             for image_type in image_types:
+                 # glob returns the full path.
+                 image_paths = glob.glob(os.path.join(image_folder, image_type))
+                 if len(image_paths) > 0:
+                     alltype_image_paths.extend(image_paths)
+
+             # Filter out images of "*_mask.png"
+             alltype_image_paths = [image_path for image_path in alltype_image_paths if "_mask.png" not in image_path]
+
+             # image_paths contain at most args.example_image_count full image paths.
+             if args.example_image_count > 0:
+                 image_paths = alltype_image_paths[:args.example_image_count]
+             else:
+                 image_paths = alltype_image_paths
+     else:
+         subject_name = None
+         image_paths = None
+         image_folder = None
+
+     subject_name = "randface-" + str(torch.seed()) if args.randface else subject_name
+     rand_face_embs = torch.randn(1, 512)
+
+     pre_face_embs = rand_face_embs if args.randface else None
+     noise = torch.randn(args.out_image_count, 4, 64, 64).cuda()
+     # args.noise_level: the *relative* std of the noise added to the face embeddings.
+     # A noise level of 0.08 could change gender, but 0.06 is usually safe.
+     # adaface_subj_embs is not used. It is generated for the purpose of updating the text encoder (within this function call).
+     adaface_subj_embs = adaface.generate_adaface_embeddings(image_paths, image_folder, pre_face_embs, args.randface,
+                                                             out_id_embs_scale=args.id_cfg_scale, noise_level=args.noise_level,
+                                                             update_text_encoder=True)
+     images = adaface(noise, args.prompt, args.guidance_scale, args.out_image_count, verbose=True)
+     save_images(images, args.num_images_per_row, subject_name, f"guide{args.guidance_scale}", args.noise_level)
adaface/adaface-translate.py ADDED
@@ -0,0 +1,208 @@
+ from adaface.adaface_wrapper import AdaFaceWrapper
+ import torch
+ #import torch.nn.functional as F
+ from PIL import Image
+ import numpy as np
+ import os, argparse, glob, re, shutil
+
+ def str2bool(v):
+     if isinstance(v, bool):
+         return v
+     if v.lower() in ("yes", "true", "t", "y", "1"):
+         return True
+     elif v.lower() in ("no", "false", "f", "n", "0"):
+         return False
+     else:
+         raise argparse.ArgumentTypeError("Boolean value expected.")
+
+ def seed_everything(seed):
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+     os.environ["PL_GLOBAL_SEED"] = str(seed)
+
+ def parse_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--base_model_path", type=str, default='models/realisticvision/realisticVisionV40_v40VAE.safetensors',
+                         help="Path to the UNet checkpoint (default: RealisticVision 4.0)")
+     parser.add_argument("--embman_ckpt", type=str, required=True,
+                         help="Path to the checkpoint of the embedding manager")
+     parser.add_argument("--in_folder", type=str, required=True, help="Path to the folder containing input images")
+     # If True, the input folder contains images of mixed subjects.
+     # If False, the input folder contains multiple subfolders, each of which contains images of the same subject.
+     parser.add_argument("--is_mix_subj_folder", type=str2bool, const=True, default=False, nargs="?",
+                         help="Whether the input folder contains images of mixed subjects")
+     parser.add_argument("--max_images_per_subject", type=int, default=5, help="Number of example images used per subject")
+     parser.add_argument("--trans_subject_count", type=int, default=-1, help="Number of example images to be translated")
+     parser.add_argument("--out_folder", type=str, required=True, help="Path to the folder saving output images")
+     parser.add_argument("--out_count_per_input_image", type=int, default=1, help="Number of output images to generate per input image")
+     parser.add_argument("--copy_masks", action="store_true", help="Copy the mask images to the output folder")
+     parser.add_argument("--noise", dest='noise_level', type=float, default=0)
+     parser.add_argument("--scale", dest='guidance_scale', type=float, default=4,
+                         help="Guidance scale for the diffusion model")
+     parser.add_argument("--ref_img_strength", type=float, default=0.8,
+                         help="Strength of the reference image in the output image.")
+     parser.add_argument("--subject_string",
+                         type=str, default="z",
+                         help="Subject placeholder string used in prompts to denote the concept.")
+     parser.add_argument("--num_vectors", type=int, default=16,
+                         help="Number of vectors used to represent the subject.")
+     parser.add_argument("--prompt", type=str, default="a person z")
+     parser.add_argument("--num_images_per_row", type=int, default=4,
+                         help="Number of images to display in a row in the output grid image.")
+     parser.add_argument("--num_inference_steps", type=int, default=50,
+                         help="Number of DDIM inference steps")
+     parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use. If num_gpus > 1, use accelerate for distributed execution.")
+     parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on")
+     parser.add_argument("--seed", type=int, default=42,
+                         help="the seed (for reproducible sampling). Set to -1 to disable.")
+     args = parser.parse_args()
+
+     return args
+
+ if __name__ == "__main__":
+     args = parse_args()
+     if args.seed != -1:
+         seed_everything(args.seed)
+
+     # screen -dm -L -Logfile trans_rv4-2.txt accelerate launch --multi_gpu --num_processes=2 scripts/adaface-translate.py
+     # --embman_ckpt logs/subjects-celebrity2024-05-16T17-22-46_zero3-ada/checkpoints/embeddings_gs-30000.pt
+     # --base_model_path models/realisticvision/realisticVisionV40_v40VAE.safetensors --in_folder /data/username/VGGface2_HQ_masks/
+     # --is_mix_subj_folder 0 --out_folder /data/username/VGGface2_HQ_masks_rv4a --copy_masks --num_gpus 2
+     if args.num_gpus > 1:
+         from accelerate import PartialState
+         distributed_state = PartialState()
+         args.device = distributed_state.device
+         process_index = distributed_state.process_index
+     elif re.match(r"^\d+$", args.device):
+         args.device = f"cuda:{args.device}"
+         distributed_state = None
+         process_index = 0
+
+     adaface = AdaFaceWrapper("img2img", args.base_model_path, args.embman_ckpt, args.device,
+                              args.subject_string, args.num_vectors, args.num_inference_steps)
+
+     in_folder = args.in_folder
+     if os.path.isfile(in_folder):
+         subject_folders = [ os.path.dirname(in_folder) ]
+         images_by_subject = [[in_folder]]
+     else:
+         if not args.is_mix_subj_folder:
+             in_folders = [in_folder]
+         else:
+             in_folders = [ os.path.join(in_folder, subfolder) for subfolder in sorted(os.listdir(in_folder)) ]
+
+         images_by_subject = []
+         subject_folders = []
+         for in_folder in in_folders:
+             image_types = ["*.jpg", "*.png", "*.jpeg"]
+             alltype_image_paths = []
+             for image_type in image_types:
+                 # glob returns the full path.
+                 image_paths = glob.glob(os.path.join(in_folder, image_type))
+                 if len(image_paths) > 0:
+                     alltype_image_paths.extend(image_paths)
+
+             # Filter out images of "*_mask.png"
+             alltype_image_paths = [image_path for image_path in alltype_image_paths if "_mask.png" not in image_path]
+             alltype_image_paths = sorted(alltype_image_paths)
+
+             if not args.is_mix_subj_folder:
+                 # image_paths contain at most args.max_images_per_subject full image paths.
+                 if args.max_images_per_subject > 0:
+                     image_paths = alltype_image_paths[:args.max_images_per_subject]
+                 else:
+                     image_paths = alltype_image_paths
+
+                 images_by_subject.append(image_paths)
+                 subject_folders.append(in_folder)
+             else:
+                 # Each image in the folder is treated as an individual subject.
+                 images_by_subject.extend([[image_path] for image_path in alltype_image_paths])
+                 subject_folders.extend([in_folder] * len(alltype_image_paths))
+
+             if args.trans_subject_count > 0 and len(subject_folders) >= args.trans_subject_count:
+                 break
+
+     if args.trans_subject_count > 0:
+         images_by_subject = images_by_subject[:args.trans_subject_count]
+         subject_folders = subject_folders[:args.trans_subject_count]
+
+     out_image_count = 0
+     out_mask_count = 0
+     if not args.out_folder.endswith("/"):
+         args.out_folder += "/"
+
+     if args.num_gpus > 1:
+         # Split the subjects across the GPUs.
+         subject_folders = subject_folders[process_index::args.num_gpus]
+         images_by_subject = images_by_subject[process_index::args.num_gpus]
+         #subject_folders, images_by_subject = distributed_state.split_between_processes(zip(subject_folders, images_by_subject))
+
+     for (subject_folder, image_paths) in zip(subject_folders, images_by_subject):
+         # If is_mix_subj_folder, then image_paths only contains 1 image, and we use the file name as the signature of the image.
+         # Otherwise, we use the folder name as the signature of the images.
+         images_sig = subject_folder if not args.is_mix_subj_folder else os.path.basename(image_paths[0])
+
+         print(f"Translating {images_sig}...")
+         with torch.no_grad():
+             adaface_subj_embs = adaface.generate_adaface_embeddings(image_paths, subject_folder, None, False,
+                                                                     out_id_embs_scale=1, noise_level=args.noise_level,
+                                                                     update_text_encoder=True)
+
+         # Replace the first occurrence of "in_folder" with "out_folder" in the path of the subject_folder.
+         subject_out_folder = subject_folder.replace(args.in_folder, args.out_folder, 1)
+         if not os.path.exists(subject_out_folder):
+             os.makedirs(subject_out_folder)
+         print(f"Output images will be saved to {subject_out_folder}")
+
+         in_images = []
+         for image_path in image_paths:
+             image = Image.open(image_path).convert("RGB").resize((512, 512))
+             # [512, 512, 3] -> [3, 512, 512].
+             image = np.array(image).transpose(2, 0, 1)
+             # Convert the image to a tensor of shape (1, 3, 512, 512) and move it to the GPU.
+             image = torch.tensor(image).unsqueeze(0).float().cuda()
+             in_images.append(image)
+
+         # Put all input images of the subject into a batch. This assumes max_images_per_subject is small.
+         # NOTE: For simplicity, we do not check overly large batch sizes.
+         in_images = torch.cat(in_images, dim=0)
+         # in_images: [5, 3, 512, 512].
+         # Normalize the pixel values to [0, 1].
+         in_images = in_images / 255.0
+         num_out_images = len(in_images) * args.out_count_per_input_image
+
+         with torch.no_grad():
+             # args.noise_level: the *relative* std of the noise added to the face embeddings.
+             # A noise level of 0.08 could change gender, but 0.06 is usually safe.
+             # The returned adaface_subj_embs are already incorporated in the text encoder, and not used explicitly.
+             # NOTE: We assume out_count_per_input_image == 1, so that the output images are of the same number as the input images.
+             out_images = adaface(in_images, args.prompt, args.guidance_scale, num_out_images, ref_img_strength=args.ref_img_strength)
+
+         for img_i, img in enumerate(out_images):
+             # out_images: subj_1, subj_2, ..., subj_n, subj_1, subj_2, ..., subj_n, ...
+             subj_i = img_i % len(in_images)
+             copy_i = img_i // len(in_images)
+             image_filename_stem, image_fileext = os.path.splitext(os.path.basename(image_paths[subj_i]))
+             if copy_i == 0:
+                 img.save(os.path.join(subject_out_folder, f"{image_filename_stem}{image_fileext}"))
+             else:
+                 img.save(os.path.join(subject_out_folder, f"{image_filename_stem}_{copy_i}{image_fileext}"))
+
+             if args.copy_masks:
+                 mask_path = image_paths[subj_i].replace(image_fileext, "_mask.png")
+                 if os.path.exists(mask_path):
+                     if copy_i == 0:
+                         shutil.copy(mask_path, subject_out_folder)
+                     else:
+                         mask_filename_stem = image_filename_stem
+                         shutil.copy(mask_path, os.path.join(subject_out_folder, f"{mask_filename_stem}_{copy_i}_mask.png"))
+
+                     out_mask_count += 1
+
+         out_image_count += len(out_images)
+
+     print(f"{out_image_count} output images and {out_mask_count} masks saved to {args.out_folder}")
adaface/adaface_wrapper.py ADDED
@@ -0,0 +1,297 @@
+ import torch
+ import torch.nn as nn
+ from transformers import CLIPTextModel
+ from diffusers import (
+     StableDiffusionPipeline,
+     StableDiffusionImg2ImgPipeline,
+     UNet2DConditionModel,
+     DDIMScheduler,
+     AutoencoderKL,
+ )
+ from insightface.app import FaceAnalysis
+ from adaface.arc2face_models import CLIPTextModelWrapper
+ from adaface.util import get_arc2face_id_prompt_embs
+ import re, os
+ import sys
+ sys.modules['ldm'] = sys.modules['adaface']
+
+ class AdaFaceWrapper(nn.Module):
+     def __init__(self, pipeline_name, base_model_path, adaface_ckpt_path, device,
+                  subject_string='z', num_vectors=16,
+                  num_inference_steps=50, negative_prompt=None,
+                  use_840k_vae=False, use_ds_text_encoder=False, is_training=False):
+         '''
+         pipeline_name: "text2img" or "img2img" or None. If None, the unet and vae are
+         removed from the pipeline to release RAM.
+         '''
+         super().__init__()
+         self.pipeline_name = pipeline_name
+         self.base_model_path = base_model_path
+         self.adaface_ckpt_path = adaface_ckpt_path
+         self.use_840k_vae = use_840k_vae
+         self.use_ds_text_encoder = use_ds_text_encoder
+         self.subject_string = subject_string
+         self.num_vectors = num_vectors
+         self.num_inference_steps = num_inference_steps
+         self.device = device
+         self.is_training = is_training
+         self.initialize_pipeline()
+         self.extend_tokenizer_and_text_encoder()
+         if negative_prompt is None:
+             self.negative_prompt = \
+                 "flaws in the eyes, flaws in the face, lowres, non-HDRi, low quality, worst quality, artifacts, noise, text, watermark, glitch, " \
+                 "mutated, ugly, disfigured, hands, partially rendered objects, partially rendered eyes, deformed eyeballs, cross-eyed, blurry, " \
+                 "mutation, duplicate, out of frame, cropped, mutilated, bad anatomy, deformed, bad proportions, " \
+                 "nude, naked, nsfw, topless, bare breasts"
+         else:
+             self.negative_prompt = negative_prompt
+
+     def load_subj_basis_generator(self, adaface_ckpt_path):
+         ckpt = torch.load(adaface_ckpt_path, map_location='cpu')
+         string_to_subj_basis_generator_dict = ckpt["string_to_subj_basis_generator_dict"]
+         if self.subject_string not in string_to_subj_basis_generator_dict:
+             print(f"Subject '{self.subject_string}' not found in the embedding manager.")
+             breakpoint()
+
+         self.subj_basis_generator = string_to_subj_basis_generator_dict[self.subject_string]
+         # In the original ckpt, num_out_layers is 16 for layerwise embeddings.
+         # But we don't do layerwise embeddings here, so we set it to 1.
+         self.subj_basis_generator.num_out_layers = 1
+         print(f"Loaded subject basis generator for '{self.subject_string}'.")
+         print(repr(self.subj_basis_generator))
+         self.subj_basis_generator.to(self.device)
+         if self.is_training:
+             self.subj_basis_generator.train()
+         else:
+             self.subj_basis_generator.eval()
+
+     def initialize_pipeline(self):
+         self.load_subj_basis_generator(self.adaface_ckpt_path)
+         # arc2face_text_encoder maps the face analysis embedding to 16 face embeddings
+         # in the UNet image space.
+         arc2face_text_encoder = CLIPTextModelWrapper.from_pretrained(
+             'models/arc2face', subfolder="encoder", torch_dtype=torch.float16
+         )
+         self.arc2face_text_encoder = arc2face_text_encoder.to(self.device)
+
+         if self.use_840k_vae:
+             # The 840000-step vae model is slightly better in face details than the original vae model.
+             # https://huggingface.co/stabilityai/sd-vae-ft-mse-original
+             vae = AutoencoderKL.from_single_file("models/diffusers/sd-vae-ft-mse-original/vae-ft-mse-840000-ema-pruned.ckpt", torch_dtype=torch.float16)
+         else:
+             vae = None
+
+         if self.use_ds_text_encoder:
+             # The dreamshaper v7 finetuned text encoder follows the prompt slightly better than the original text encoder.
+             # https://huggingface.co/Lykon/DreamShaper/tree/main/text_encoder
+             text_encoder = CLIPTextModel.from_pretrained("models/ds_text_encoder", torch_dtype=torch.float16)
+         else:
+             text_encoder = None
+
+         remove_unet = False
+
+         if self.pipeline_name == "img2img":
+             PipelineClass = StableDiffusionImg2ImgPipeline
+         elif self.pipeline_name == "text2img":
+             PipelineClass = StableDiffusionPipeline
+         # pipeline_name is None means only use this instance to generate adaface embeddings, not to generate images.
+         elif self.pipeline_name is None:
+             PipelineClass = StableDiffusionPipeline
+             remove_unet = True
+         else:
+             raise ValueError(f"Unknown pipeline name: {self.pipeline_name}")
+
+         if os.path.isfile(self.base_model_path):
+             pipeline = PipelineClass.from_single_file(
+                 self.base_model_path,
+                 torch_dtype=torch.float16
+             )
+         else:
+             pipeline = PipelineClass.from_pretrained(
+                 self.base_model_path,
+                 torch_dtype=torch.float16,
+                 safety_checker=None
+             )
+         print(f"Loaded pipeline from {self.base_model_path}.")
+
+         if self.use_840k_vae:
+             pipeline.vae = vae
+             print("Replaced the VAE with the 840k-step VAE.")
+
+         if self.use_ds_text_encoder:
+             pipeline.text_encoder = text_encoder
+             print("Replaced the text encoder with the DreamShaper text encoder.")
+
+         if remove_unet:
+             # Remove unet and vae to release RAM. Only keep tokenizer and text_encoder.
+             pipeline.unet = None
+             pipeline.vae = None
+             print("Removed UNet and VAE from the pipeline.")
+
+         noise_scheduler = DDIMScheduler(
+             num_train_timesteps=1000,
+             beta_start=0.00085,
+             beta_end=0.012,
+             beta_schedule="scaled_linear",
+             clip_sample=False,
+             set_alpha_to_one=False,
+             steps_offset=1,
+         )
+
+         pipeline.scheduler = noise_scheduler
+         self.pipeline = pipeline.to(self.device)
+         # FaceAnalysis will try to find the ckpt in: models/insightface/models/antelopev2.
+         # Note there's a second "model" in the path.
+         self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+         self.face_app.prepare(ctx_id=0, det_size=(512, 512))
+         # Patch the missing tokenizer in the subj_basis_generator.
+         if not hasattr(self.subj_basis_generator, 'clip_tokenizer'):
+             self.subj_basis_generator.clip_tokenizer = self.pipeline.tokenizer
+             print("Patched the missing tokenizer in the subj_basis_generator.")
+
+     def extend_tokenizer_and_text_encoder(self):
+         if self.num_vectors < 1:
+             raise ValueError(f"num_vectors has to be larger or equal to 1, but is {self.num_vectors}")
+
+         tokenizer = self.pipeline.tokenizer
+         # Add z0, z1, z2, ..., z15.
+         self.placeholder_tokens = []
+         for i in range(0, self.num_vectors):
+             self.placeholder_tokens.append(f"{self.subject_string}_{i}")
+
+         self.placeholder_tokens_str = " ".join(self.placeholder_tokens)
+
+         # Add the new tokens to the tokenizer.
+         num_added_tokens = tokenizer.add_tokens(self.placeholder_tokens)
+         if num_added_tokens != self.num_vectors:
+             raise ValueError(
+                 f"The tokenizer already contains the token {self.subject_string}. Please pass a different"
+                 " `subject_string` that is not already in the tokenizer.")
+
+         print(f"Added {num_added_tokens} tokens ({self.placeholder_tokens_str}) to the tokenizer.")
+
+         # placeholder_token_ids: [49408, ..., 49423].
+         self.placeholder_token_ids = tokenizer.convert_tokens_to_ids(self.placeholder_tokens)
+         # print(self.placeholder_token_ids)
+         # Resize the token embeddings as we are adding new special tokens to the tokenizer
+         old_weight = self.pipeline.text_encoder.get_input_embeddings().weight
+         self.pipeline.text_encoder.resize_token_embeddings(len(tokenizer))
+         new_weight = self.pipeline.text_encoder.get_input_embeddings().weight
+         print(f"Resized text encoder token embeddings from {old_weight.shape} to {new_weight.shape} on {new_weight.device}.")
+
+     # Extend pipeline.text_encoder with the adaface subject emeddings.
+     # subj_embs: [16, 768].
+     def update_text_encoder_subj_embs(self, subj_embs):
+         # Initialise the newly added placeholder token with the embeddings of the initializer token
+         token_embeds = self.pipeline.text_encoder.get_input_embeddings().weight.data
+         with torch.no_grad():
+             for i, token_id in enumerate(self.placeholder_token_ids):
+                 token_embeds[token_id] = subj_embs[i]
+         print(f"Updated {len(self.placeholder_token_ids)} tokens ({self.placeholder_tokens_str}) in the text encoder.")
+
+     def update_prompt(self, prompt):
+         # If the placeholder tokens are already in the prompt, then return the prompt as is.
+         if self.placeholder_tokens_str in prompt:
+             return prompt
+
+         # If the subject string 'z' is not in the prompt, then simply prepend the placeholder tokens to the prompt.
+         if re.search(r'\b' + self.subject_string + r'\b', prompt) is None:
+             print(f"Subject string '{self.subject_string}' not found in the prompt. Adding it.")
+             comp_prompt = self.placeholder_tokens_str + " " + prompt
+         else:
+             # Replace the subject string 'z' with the placeholder tokens.
+             comp_prompt = re.sub(r'\b' + self.subject_string + r'\b', self.placeholder_tokens_str, prompt)
+         return comp_prompt
+
+     # image_paths: a list of image paths. image_folder: the parent folder name.
+     def generate_adaface_embeddings(self, image_paths, image_folder=None,
+                                     pre_face_embs=None, gen_rand_face=False,
+                                     out_id_embs_scale=1., noise_level=0, update_text_encoder=True):
+         # faceid_embeds is a batch of extracted face analysis embeddings (BS * 512 = id_batch_size * 512).
+         # If extract_faceid_embeds is True, faceid_embeds is *the same* embedding repeated by id_batch_size times.
+         # Otherwise, faceid_embeds is a batch of random embeddings, each instance is different.
+         # The same applies to id_prompt_emb.
+         # faceid_embeds is in the face analysis embeddings. id_prompt_emb is in the image prompt space.
+         # Here id_batch_size = 1, so
+         # faceid_embeds: [1, 512]. NOT used later.
+         # id_prompt_emb: [1, 16, 768].
+         # NOTE: Since return_core_id_embs is True, id_prompt_emb is only the 16 core ID embeddings.
+         # arc2face prompt template: "photo of a id person"
+         # ID embeddings start from "id person ...". So there are 3 template tokens before the 16 ID embeddings.
+         face_image_count, faceid_embeds, id_prompt_emb \
+             = get_arc2face_id_prompt_embs(self.face_app, self.pipeline.tokenizer, self.arc2face_text_encoder,
+                                           extract_faceid_embeds=not gen_rand_face,
+                                           pre_face_embs=pre_face_embs,
+                                           # image_folder is passed only for logging purpose.
+                                           # image_paths contains the paths of the images.
+                                           image_folder=image_folder, image_paths=image_paths,
+                                           images_np=None,
+                                           id_batch_size=1,
+                                           device=self.device,
+                                           # input_max_length == 22: only keep the first 22 tokens,
+                                           # including 3 template tokens and 16 ID tokens, and BOS and EOS tokens.
+                                           # The results are indistinguishable from input_max_length=77.
+                                           input_max_length=22,
+                                           noise_level=noise_level,
+                                           return_core_id_embs=True,
+                                           gen_neg_prompt=False,
+                                           verbose=True)
+
+         if face_image_count == 0:
+             return None
+
+         # adaface_subj_embs: [1, 1, 16, 768].
+         # adaface_prompt_embs: [1, 77, 768] (not used).
+         adaface_subj_embs, adaface_prompt_embs = \
+             self.subj_basis_generator(id_prompt_emb, None, None,
+                                       out_id_embs_scale=out_id_embs_scale,
+                                       is_face=True, is_training=False,
+                                       adaface_prompt_embs_inf_type='full_half_pad')
+         # adaface_subj_embs: [16, 768]
+         adaface_subj_embs = adaface_subj_embs.squeeze()
+         if update_text_encoder:
+             self.update_text_encoder_subj_embs(adaface_subj_embs)
+         return adaface_subj_embs
+
+     def encode_prompt(self, prompt, negative_prompt=None, device="cuda", verbose=False):
+         if negative_prompt is None:
+             negative_prompt = self.negative_prompt
+
+         prompt = self.update_prompt(prompt)
+         if verbose:
+             print(f"Prompt: {prompt}")
+
+         # For some unknown reason, the text_encoder is still on CPU after self.pipeline.to(self.device).
+         # So we manually move it to GPU here.
+         self.pipeline.text_encoder.to(device)
+         # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
+         prompt_embeds_, negative_prompt_embeds_ = \
+             self.pipeline.encode_prompt(prompt, device=device, num_images_per_prompt=1,
+                                         do_classifier_free_guidance=True, negative_prompt=negative_prompt)
+         return prompt_embeds_, negative_prompt_embeds_
+
+     # ref_img_strength is used only in the img2img pipeline.
+     def forward(self, noise, prompt, negative_prompt=None, guidance_scale=4.0,
+                 out_image_count=4, ref_img_strength=0.8, generator=None, verbose=False):
+         if negative_prompt is None:
+             negative_prompt = self.negative_prompt
+         # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
+         prompt_embeds_, negative_prompt_embeds_ = self.encode_prompt(prompt, negative_prompt, device=self.device, verbose=verbose)
+         # Repeat the prompt embeddings for all images in the batch.
+         prompt_embeds_ = prompt_embeds_.repeat(out_image_count, 1, 1)
+         negative_prompt_embeds_ = negative_prompt_embeds_.repeat(out_image_count, 1, 1)
+         noise = noise.to(self.device).to(torch.float16)
+
+         # noise: [BS, 4, 64, 64]
+         # When the pipeline is text2img, strength is ignored.
+         images = self.pipeline(image=noise,
+                                prompt_embeds=prompt_embeds_,
+                                negative_prompt_embeds=negative_prompt_embeds_,
+                                num_inference_steps=self.num_inference_steps,
+                                guidance_scale=guidance_scale,
+                                num_images_per_prompt=1,
+                                strength=ref_img_strength,
+                                generator=generator).images
+         # images: [BS, 3, 512, 512]
+         return images
+
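A concrete look at what update_prompt() and the placeholder tokens do to a user prompt: the substitution is just the word-boundary regex used above. A small sketch with a made-up prompt:

import re
subject_string = "z"
placeholder_tokens_str = " ".join(f"{subject_string}_{i}" for i in range(16))
prompt = "a woman z in superman costume"
# "z" is replaced by the 16 placeholder tokens whose embeddings were set to the AdaFace subject embeddings.
comp_prompt = re.sub(r'\b' + subject_string + r'\b', placeholder_tokens_str, prompt)
print(comp_prompt)  # a woman z_0 z_1 ... z_15 in superman costume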
adaface/arc2face_models.py ADDED
@@ -0,0 +1,303 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import CLIPTextModel
4
+ from transformers.models.clip.modeling_clip import CLIPAttention
5
+ from typing import Any, Callable, Dict, Optional, Tuple, Union, List
6
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
7
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
8
+ # from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask
9
+ _make_causal_mask = AttentionMaskConverter._make_causal_mask
10
+ _expand_mask = AttentionMaskConverter._expand_mask
11
+
12
+ from adaface.util import add_noise_to_tensor
13
+
14
+ # Extend CLIPAttention by using multiple k_proj and v_proj in each head.
15
+ # To avoid increasing the computation too much, we don't extend q_proj.
16
+ class CLIPAttentionMKV(nn.Module):
17
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
18
+
19
+ def __init__(self, config, multiplier=2):
20
+ super().__init__()
21
+ self.config = config
22
+ self.embed_dim = config.hidden_size
23
+ self.num_heads = config.num_attention_heads
24
+ self.head_dim = self.embed_dim // self.num_heads
25
+ if self.head_dim * self.num_heads != self.embed_dim:
26
+ raise ValueError(
27
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
28
+ f" {self.num_heads})."
29
+ )
30
+ self.scale = self.head_dim**-0.5
31
+ self.dropout = config.attention_dropout
32
+ self.multiplier = multiplier
33
+
34
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim * self.multiplier)
35
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim * self.multiplier)
36
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
37
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
38
+
39
+ # The token features are (approximately) repeated along the last dim of the tensor
40
+ # (multiplier * num_heads * head_dim), and then reshaped to (bsz, -1, num_heads, head_dim).
41
+ # Therefore, the "multiplier" dim is tucked into the seq_len dim, which looks like
42
+ # [token1_emb, token1_emb, token2_emb, token2_emb, ..., tokenN_emb, tokenN_emb].
43
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
44
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
45
+
46
+ def extend_weights(self, clip_attn_layer, layer_idx, multiplier, noise_std=0.1,
47
+ noise_std_is_relative=True, keep_norm=False, verbose=False):
48
+ self.multiplier *= multiplier
49
+ # q_proj and out_proj are the same as the original CLIPAttention.
50
+ self.q_proj.weight.data = clip_attn_layer.q_proj.weight.data.clone()
51
+ self.q_proj.bias.data = clip_attn_layer.q_proj.bias.data.clone()
52
+ self.out_proj.weight.data = clip_attn_layer.out_proj.weight.data.clone()
53
+ self.out_proj.bias.data = clip_attn_layer.out_proj.bias.data.clone()
54
+
55
+ # bias doesn't need noise perturbation, as after the weights are noised,
56
+ # different copies of the weight/bias will receive different gradients,
57
+ # making the bias terms diverge and become identifiable after training.
58
+ self.v_proj.bias.data = clip_attn_layer.v_proj.bias.data.repeat(multiplier)
59
+ self.k_proj.bias.data = clip_attn_layer.k_proj.bias.data.repeat(multiplier)
60
+
61
+ self.v_proj.weight.data = clip_attn_layer.v_proj.weight.data.repeat(multiplier, 1)
62
+ self.k_proj.weight.data = clip_attn_layer.k_proj.weight.data.repeat(multiplier, 1)
63
+
64
+ if noise_std > 0:
65
+ ORIG_V_SHAPE = list(clip_attn_layer.v_proj.weight.shape)
66
+ ORIG_V_SHAPE_D0 = ORIG_V_SHAPE[0]
67
+ # Adding noise to the extra copies of the weights (keep the first copy unchanged).
68
+ self.v_proj.weight.data[ORIG_V_SHAPE_D0:] = \
69
+ add_noise_to_tensor(self.v_proj.weight.data[ORIG_V_SHAPE_D0:],
70
+ noise_std, noise_std_is_relative, keep_norm)
71
+ if verbose:
72
+ NEW_V_SHAPE = list(self.v_proj.weight.shape)
73
+ NOISED_V_SHAPE = list(self.v_proj.weight.data[ORIG_V_SHAPE_D0:].shape)
74
+ print(f"Layer {layer_idx}: {NOISED_V_SHAPE} out of {NEW_V_SHAPE} in v_proj perturbed with {noise_std} noise")
75
+
76
+ ORIG_K_SHAPE = list(clip_attn_layer.k_proj.weight.shape)
77
+ ORIG_K_SHAPE_D0 = ORIG_K_SHAPE[0]
78
+ # Adding noise to the extra copies of the weights.
79
+ self.k_proj.weight.data[ORIG_K_SHAPE_D0:] = \
80
+ add_noise_to_tensor(self.k_proj.weight.data[ORIG_K_SHAPE_D0:],
81
+ noise_std, noise_std_is_relative, keep_norm)
82
+ if verbose:
83
+ NEW_K_SHAPE = list(self.k_proj.weight.shape)
84
+ NOISED_K_SHAPE = list(self.k_proj.weight.data[ORIG_K_SHAPE_D0:].shape)
85
+ print(f"Layer {layer_idx}: {NOISED_K_SHAPE} out of {NEW_K_SHAPE} in k_proj perturbed with {noise_std} noise")
86
+
87
+ def forward(
88
+ self,
89
+ hidden_states: torch.Tensor,
90
+ attention_mask: Optional[torch.Tensor] = None,
91
+ causal_attention_mask: Optional[torch.Tensor] = None,
92
+ output_attentions: Optional[bool] = False,
93
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
94
+ """Input shape: Batch x Time x Channel"""
95
+
96
+ bsz, tgt_len, embed_dim = hidden_states.size()
97
+
98
+ query_states = self.q_proj(hidden_states) * self.scale
99
+ # For key_states and value_states, the multiplier is absorbed into the seq_len (dim 1, shape specified as -1).
100
+ # [token0_head_emb, token0_head_emb, token1_head_emb, token1_head_emb, ..., tokenN-1_head_emb, tokenN-1_head_emb].
101
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
102
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
103
+
104
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
105
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
106
+ key_states = key_states.view(*proj_shape)
107
+ value_states = value_states.view(*proj_shape)
108
+
109
+ src_len = key_states.size(1)
110
+ # src_len0 is the original src_len without the multiplier.
111
+ src_len0 = src_len // self.multiplier
112
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
113
+
114
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
115
+ raise ValueError(
116
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
117
+ f" {attn_weights.size()}"
118
+ )
119
+
120
+ # apply the causal_attention_mask first
121
+ if causal_attention_mask is not None:
122
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len0):
123
+ raise ValueError(
124
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len0)}, but is"
125
+ f" {causal_attention_mask.size()}"
126
+ )
127
+ # The last dim of attn_weights corresponds to [token0, token0, token1, token1, ..., tokenN-1, tokenN-1].
128
+ # If reshaping it as (self.multiplier, src_len0), it will become
129
+ # [[token0, token0, token1, token1, ..., tokenN//2], [tokenN//2+1, tokenN//2+1, ..., tokenN-1, tokenN-1]],
130
+ # and the mask will be applied to wrong elements.
131
+ # If reshaping it as (src_len0, self.multiplier), it will become
132
+ # [[token0, token1, ..., tokenN-1], [token0, token1, ..., tokenN-1]], and then
133
+ # the mask at element i will mask all the multiplier elements at i, which is desired.
134
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len0, self.multiplier) + causal_attention_mask.unsqueeze(4)
135
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
136
+
137
+ if attention_mask is not None:
138
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len0):
139
+ raise ValueError(
140
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len0)}, but is {attention_mask.size()}"
141
+ )
142
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len0, self.multiplier) + attention_mask.unsqueeze(4)
143
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
144
+
145
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
146
+
147
+ if output_attentions:
148
+ # this operation is a bit awkward, but it's required to
149
+ # make sure that attn_weights keeps its gradient.
150
+ # In order to do so, attn_weights have to reshaped
151
+ # twice and have to be reused in the following
152
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
153
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
154
+ else:
155
+ attn_weights_reshaped = None
156
+
157
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
158
+
159
+ attn_output = torch.bmm(attn_probs, value_states)
160
+
161
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
162
+ raise ValueError(
163
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
164
+ f" {attn_output.size()}"
165
+ )
166
+
167
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
168
+ attn_output = attn_output.transpose(1, 2)
169
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
170
+
171
+ attn_output = self.out_proj(attn_output)
172
+
173
+ return attn_output, attn_weights_reshaped
174
+
175
+ class CLIPTextModelWrapper(CLIPTextModel):
176
+ # Adapted from https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/clip/modeling_clip.py#L812
177
+ # Modified to accept precomputed token embeddings "input_token_embs" as input or calculate them from input_ids and return them.
178
+ def forward(
179
+ self,
180
+ input_ids: Optional[torch.Tensor] = None,
181
+ attention_mask: Optional[torch.Tensor] = None,
182
+ position_ids: Optional[torch.Tensor] = None,
183
+ output_attentions: Optional[bool] = None,
184
+ output_hidden_states: Optional[bool] = None,
185
+ return_dict: Optional[bool] = None,
186
+ input_token_embs: Optional[torch.Tensor] = None,
187
+ hidden_state_layer_weights: Optional[torch.Tensor] = None,
188
+ return_token_embs: Optional[bool] = False,
189
+ ) -> Union[Tuple, torch.Tensor, BaseModelOutputWithPooling]:
190
+
191
+ if return_token_embs:
192
+ return self.text_model.embeddings.token_embedding(input_ids)
193
+
194
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
195
+
196
+ output_attentions = output_attentions if output_attentions is not None else self.text_model.config.output_attentions
197
+ output_hidden_states = (
198
+ output_hidden_states if output_hidden_states is not None else self.text_model.config.output_hidden_states
199
+ )
200
+ if hidden_state_layer_weights is not None:
201
+ output_hidden_states = True
202
+ return_dict = return_dict if return_dict is not None else self.text_model.config.use_return_dict
203
+
204
+ if input_ids is None:
205
+ raise ValueError("You have to specify input_ids")
206
+
207
+ input_shape = input_ids.size()
208
+ input_ids = input_ids.view(-1, input_shape[-1])
209
+
210
+ hidden_states = self.text_model.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=input_token_embs)
211
+
212
+ # CLIP's text model uses causal mask, prepare it here.
213
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
214
+ causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
215
+ # expand attention_mask
216
+ if attention_mask is not None:
217
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
218
+ attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
219
+
220
+ encoder_outputs = self.text_model.encoder(
221
+ inputs_embeds=hidden_states,
222
+ attention_mask=attention_mask,
223
+ causal_attention_mask=causal_attention_mask,
224
+ output_attentions=output_attentions,
225
+ # output_hidden_states is False by default, and only True if hidden_state_layer_weights is provided.
226
+ output_hidden_states=output_hidden_states,
227
+ return_dict=return_dict,
228
+ )
229
+
230
+ # If output_hidden_states is True, then encoder_outputs[0] is last_hidden_state [1, 22, 768].
231
+ # encoder_outputs[1] is hidden_states, which is a tuple of 13 hidden states, each being [1, 22, 768].
232
+ # encoder_outputs[0] == encoder_outputs[1][12].
233
+ if hidden_state_layer_weights is None:
234
+ last_hidden_state = encoder_outputs[0]
235
+ else:
236
+ num_hidden_state_layers = len(hidden_state_layer_weights)
237
+ last_hidden_states = encoder_outputs[1][-num_hidden_state_layers:]
238
+ hidden_state_layer_weights = hidden_state_layer_weights.to(last_hidden_states[0].dtype)
239
+ # Normalize the weights to sum to 1 across layers.
240
+ # hidden_state_layer_weights: [3, 1] or [3, 768].
241
+ hidden_state_layer_weights = hidden_state_layer_weights / hidden_state_layer_weights.sum(dim=0, keepdim=True)
242
+ # [3, 1/768] -> [3, 1, 1, 1/768]
243
+ hidden_state_layer_weights = hidden_state_layer_weights.unsqueeze(1).unsqueeze(1)
244
+ # A weighted sum of last_hidden_states.
245
+ # [3, 1, 22, 768] * [3, 1, 1, 1/768] -> [3, 1, 22, 768] -> [1, 22, 768]
246
+ last_hidden_state = (torch.stack(last_hidden_states, dim=0) * hidden_state_layer_weights).sum(dim=0)
247
+
248
+ last_hidden_state = self.text_model.final_layer_norm(last_hidden_state)
249
+
250
+ # self.text_model.eos_token_id == 2 is True.
251
+ if self.text_model.eos_token_id == 2:
252
+ # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
253
+ # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
254
+ # ------------------------------------------------------------
255
+ # text_embeds.shape = [batch_size, sequence_length, transformer.width]
256
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
257
+ # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
258
+ pooled_output = last_hidden_state[
259
+ torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
260
+ input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
261
+ ]
262
+ else:
263
+ # The config gets the updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
264
+ pooled_output = last_hidden_state[
265
+ torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
266
+ # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
267
+ (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.text_model.eos_token_id)
268
+ .int()
269
+ .argmax(dim=-1),
270
+ ]
271
+
272
+ if not return_dict:
273
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
274
+
275
+ return BaseModelOutputWithPooling(
276
+ last_hidden_state=last_hidden_state,
277
+ pooler_output=pooled_output,
278
+ hidden_states=encoder_outputs.hidden_states,
279
+ attentions=encoder_outputs.attentions,
280
+ )
281
+
282
+ # Applied to layers [begin_layer_idx, end_layer_idx) in the encoder.
283
+ # The layer indexed by end_layer_idx is not included.
284
+ # If both layer indices are -1, then apply to all layers (0-11).
285
+ def extend_clip_attention_MKV_multiplier(self, begin_layer_idx=-1, end_layer_idx=-1, multiplier=2, noise_std=0.1):
286
+ num_extended_layers = 0
287
+
288
+ for layer_idx, layer in enumerate(self.text_model.encoder.layers):
289
+ if begin_layer_idx >= 0 and layer_idx < begin_layer_idx:
290
+ continue
291
+ if end_layer_idx >= 0 and layer_idx >= end_layer_idx:
292
+ break
293
+ # This shouldn't happen: self_attn should be either a CLIPAttention or an already-extended CLIPAttentionMKV.
294
+ if not isinstance(layer.self_attn, (CLIPAttention, CLIPAttentionMKV)):
295
+ breakpoint()
296
+ old_attn_layer = layer.self_attn
297
+ if not isinstance(old_attn_layer, CLIPAttentionMKV):
298
+ layer.self_attn = CLIPAttentionMKV(old_attn_layer.config, 1)
299
+ layer.self_attn.extend_weights(old_attn_layer, layer_idx, multiplier, noise_std, verbose=True)
300
+ num_extended_layers += 1
301
+
302
+ return num_extended_layers
303
+
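The CLIPTextModelWrapper above exposes extend_clip_attention_MKV_multiplier(), which swaps each CLIPAttention for a CLIPAttentionMKV whose extra K/V copies are lightly noised so they can diverge during training. The sketch below calls it on a stock openai/clip-vit-large-patch14 checkpoint; the stock weights are used purely for illustration (the repo normally pairs this wrapper with Arc2Face-finetuned weights), and the default noise_std=0.1 is spelled out explicitly.

```python
# Sketch: replace every CLIPAttention in the text encoder with CLIPAttentionMKV,
# doubling the K/V projections and noising only the extra copies (relative std 0.1).
# Loading the stock OpenAI checkpoint is for illustration only.
from adaface.arc2face_models import CLIPTextModelWrapper

text_encoder = CLIPTextModelWrapper.from_pretrained("openai/clip-vit-large-patch14")
num_extended = text_encoder.extend_clip_attention_MKV_multiplier(multiplier=2, noise_std=0.1)
print(f"{num_extended} self-attention layers extended")   # 12 for clip-vit-large-patch14
```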
adaface/subj_basis_generator.py ADDED
@@ -0,0 +1,758 @@
1
+ # Borrowed from ip-adapter resampler.py.
2
+ # https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py
3
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
4
+ # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
5
+
6
+ import math
7
+
8
+ import torch
9
+ from torch import nn
10
+ import torch.nn.functional as F
11
+ from einops import rearrange
12
+ from einops.layers.torch import Rearrange
13
+ from transformers import CLIPVisionModel, CLIPTokenizer
14
+
15
+ import numpy as np
16
+ from torch import einsum
17
+ from dataclasses import dataclass
18
+ from typing import Optional, Tuple
19
+ from transformers.utils import ModelOutput
20
+ from adaface.util import arc2face_inverse_face_prompt_embs, gen_gradient_scaler
21
+ from adaface.arc2face_models import CLIPTextModelWrapper
22
+ import sys
23
+ sys.modules['ldm'] = sys.modules['adaface']
24
+
25
+ def reshape_tensor(x, num_heads):
26
+ bs, length, width = x.shape
27
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
28
+ x = x.view(bs, length, num_heads, -1)
29
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
30
+ x = x.transpose(1, 2)
31
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
32
+ x = x.reshape(bs, num_heads, length, -1)
33
+ return x
34
+
35
+ # FFN. Added a Dropout layer at the end, so that it can still load the old ckpt.
36
+ def FeedForward(dim, mult=4, p_dropout=0.1):
37
+ inner_dim = int(dim * mult)
38
+ return nn.Sequential(
39
+ nn.LayerNorm(dim),
40
+ nn.Linear(dim, inner_dim, bias=False),
41
+ nn.GELU(),
42
+ nn.Linear(inner_dim, dim, bias=False),
43
+ nn.Dropout(p_dropout),
44
+ )
45
+
46
+ # IP-Adapter FaceID class. Only used in knn-faces.py.
47
+ # From: https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter_faceid_separate.py
48
+ class IP_MLPProjModel(nn.Module):
49
+ def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
50
+ super().__init__()
51
+
52
+ self.cross_attention_dim = cross_attention_dim
53
+ self.num_tokens = num_tokens
54
+
55
+ self.proj = nn.Sequential(
56
+ nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
57
+ nn.GELU(),
58
+ nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
59
+ )
60
+ self.norm = nn.LayerNorm(cross_attention_dim)
61
+
62
+ def forward(self, id_embeds):
63
+ x = self.proj(id_embeds)
64
+ x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
65
+ x = self.norm(x)
66
+ return x
67
+
68
+ # group_dim: the tensor dimension that corresponds to the multiple groups.
69
+ class LearnedSoftAggregate(nn.Module):
70
+ def __init__(self, num_feat, group_dim, keepdim=False):
71
+ super(LearnedSoftAggregate, self).__init__()
72
+ self.group_dim = group_dim
73
+ # num_feat = 1: element-wise score function & softmax.
74
+ # num_feat > 1: the linear score function is applied to the last dim (features) of the input tensor.
75
+ self.num_feat = num_feat
76
+ self.feat2score = nn.Linear(num_feat, 1, bias=False)
77
+ self.keepdim = keepdim
78
+
79
+ def forward(self, x, score_basis=None):
80
+ # If there's only one mode, do nothing.
81
+ if x.shape[self.group_dim] == 1:
82
+ if self.keepdim:
83
+ return x
84
+ else:
85
+ return x.squeeze(self.group_dim)
86
+
87
+ # Assume the last dim of x is the feature dim.
88
+ if score_basis is None:
89
+ score_basis = x
90
+
91
+ if self.num_feat == 1:
92
+ mode_scores = self.feat2score(score_basis.unsqueeze(-1)).squeeze(-1)
93
+ else:
94
+ mode_scores = self.feat2score(score_basis)
95
+ attn_probs = mode_scores.softmax(dim=self.group_dim)
96
+ x_aggr = (x * attn_probs).sum(dim=self.group_dim, keepdim=self.keepdim)
97
+ return x_aggr
98
+
99
+ def LoRA_ExpandEmbs(input_dim, lora_rank, output_dim, num_modes,
100
+ num_output_vecs, elementwise_affine=True, p_dropout=0.1):
101
+ return nn.Sequential(
102
+ # Project to [BS, lora_rank * output_dim * num_modes].
103
+ # It takes a huge param size. 512 * 32 * 768 * 4 = 6,291,456.
104
+ nn.Linear(input_dim, lora_rank * output_dim * num_modes, bias=False),
105
+ # Reshape to [BS, lora_rank, output_dim].
106
+ Rearrange('b (m q d) -> b m q d', q=lora_rank, m=num_modes, d=output_dim),
107
+ nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
108
+ # Aggregate [BS, num_modes, lora_rank, output_dim] -> [BS, lora_rank, output_dim].
109
+ LearnedSoftAggregate(num_feat=output_dim, group_dim=1, keepdim=False) if num_modes > 1 \
110
+ else Rearrange('b () q d -> b q d'),
111
+ nn.Dropout(p_dropout),
112
+ # Permute to [BS, output_dim, lora_rank].
113
+ Rearrange('b q d -> b d q'),
114
+ # Project to [BS, output_dim, num_output_vecs].
115
+ nn.Linear(lora_rank, num_output_vecs, bias=False),
116
+ # Permute to [BS, num_output_vecs, output_dim].
117
+ Rearrange('b d q -> b q d'),
118
+ nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
119
+ nn.Dropout(p_dropout),
120
+ )
121
+
122
+ def ExpandEmbs(input_dim, output_dim, expansion_ratio, elementwise_affine=True, p_dropout=0.1):
123
+ return nn.Sequential(
124
+ # Project to [BS, num_output_vecs * output_dim].
125
+ nn.Linear(input_dim, expansion_ratio * output_dim, bias=False),
126
+ # Reshape to [BS, num_output_vecs, output_dim].
127
+ Rearrange('b (e d) -> b e d', e=expansion_ratio, d=output_dim),
128
+ nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
129
+ nn.Dropout(p_dropout),
130
+ )
131
+
132
+ # Input: [BS, N, D].
133
+ def MultimodeProjection(input_dim, output_dim=-1, num_modes=4, elementwise_affine=True, p_dropout=0.1):
134
+ if output_dim == -1:
135
+ output_dim = input_dim
136
+
137
+ return nn.Sequential(
138
+ nn.Linear(input_dim, output_dim * num_modes, bias=False),
139
+ # Reshape to [BS, num_output_vecs, output_dim].
140
+ Rearrange('b n (m d) -> b n m d', m=num_modes, d=output_dim),
141
+ nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
142
+ # If num_modes == 1, then simply remove the mode dim. Otherwise, aggregate the modes.
143
+ LearnedSoftAggregate(num_feat=output_dim, group_dim=2, keepdim=False) if num_modes > 1 \
144
+ else Rearrange('b n () d -> b n d'),
145
+ nn.Dropout(p_dropout),
146
+ )
147
+
148
+ # Low-rank to high-rank transformation.
149
+ def Lora2Hira(lora_rank, hira_rank, output_dim, num_modes, elementwise_affine=True, p_dropout=0.1):
150
+ return nn.Sequential(
151
+ # Permute to [BS, output_dim, lora_rank].
152
+ Rearrange('b q d -> b d q'),
153
+ # Project to [BS, output_dim, hira_rank].
154
+ nn.Linear(lora_rank, hira_rank * num_modes, bias=False),
155
+ # Reshape and permute to [BS, num_modes, num_output_vecs, output_dim].
156
+ Rearrange('b d (m q) -> b m q d', m=num_modes, q=hira_rank),
157
+ nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
158
+ # Aggregate [BS, num_modes, hira_rank, output_dim] -> [BS, hira_rank, output_dim].
159
+ LearnedSoftAggregate(num_feat=output_dim, group_dim=1, keepdim=False) if num_modes > 1 \
160
+ else Rearrange('b () q d -> b q d'),
161
+ nn.Dropout(p_dropout),
162
+ )
163
+
164
+ class PerceiverAttention(nn.Module):
165
+ def __init__(self, *, dim, dim_head=64, num_heads=8, elementwise_affine=True):
166
+ super().__init__()
167
+ self.scale = dim_head**-0.5
168
+ self.dim_head = dim_head
169
+ self.num_heads = num_heads
170
+ inner_dim = dim_head * num_heads
171
+
172
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=elementwise_affine)
173
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=elementwise_affine)
174
+
175
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
176
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
177
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
178
+
179
+ def forward(self, x, latent_queries):
180
+ """
181
+ Args:
182
+ x (torch.Tensor): image features
183
+ shape (b, n1, D)
184
+ latent (torch.Tensor): latent features
185
+ shape (b, n2, D)
186
+ """
187
+ x = self.norm1(x)
188
+ latent_queries = self.norm2(latent_queries)
189
+
190
+ b, l, _ = latent_queries.shape
191
+
192
+ q = self.to_q(latent_queries)
193
+ kv_input = torch.cat((x, latent_queries), dim=-2)
194
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
195
+
196
+ q = reshape_tensor(q, self.num_heads)
197
+ k = reshape_tensor(k, self.num_heads)
198
+ v = reshape_tensor(v, self.num_heads)
199
+
200
+ # attention
201
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
202
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
203
+ attn = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
204
+ out = attn @ v
205
+
206
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
207
+
208
+ return self.to_out(out)
209
+
210
+
211
+ class CrossAttention(nn.Module):
212
+ # output_dim is always the same as input_dim.
213
+ # num_q only matters when q_aware_to_v is True.
214
+ # If q_aware_to_v is False, query x in forward() is still usable.
215
+ def __init__(self, input_dim, num_heads=6, p_dropout=0.05,
216
+ identity_to_q=False, identity_to_k=False, identity_to_v=False, v_has_skip=True,
217
+ q_aware_to_v=True, num_q=416, v_repeat=4, q_aware_to_v_lora_rank=64,
218
+ identity_to_out=False, out_has_skip=False):
219
+ super().__init__()
220
+ dim_head = input_dim // num_heads
221
+ inner_dim = dim_head * num_heads
222
+
223
+ self.num_heads = num_heads
224
+ self.q_aware_to_v = q_aware_to_v
225
+ self.v_has_skip = v_has_skip
226
+ self.to_q = nn.Sequential(
227
+ nn.Linear(input_dim, inner_dim, bias=False),
228
+ nn.LayerNorm(inner_dim, elementwise_affine=True)
229
+ ) if not identity_to_q else nn.Identity()
230
+ self.to_k = nn.Sequential(
231
+ nn.Linear(input_dim, inner_dim, bias=False),
232
+ nn.LayerNorm(inner_dim, elementwise_affine=True)
233
+ ) if not identity_to_k else nn.Identity()
234
+
235
+ self.v_repeat = v_repeat
236
+ self.num_q_group = num_q_group = num_q // v_repeat # 416 / 4 = 104.
237
+
238
+ # If q_aware_to_v is True, then self.to_v consists of num_q projections of input_dim to inner_dim.
239
+ # Otherwise, self.to_v consists of a single projection of input_dim to inner_dim.
240
+ if q_aware_to_v:
241
+ # all_q_mid: 104 * 64 = 6656.
242
+ all_q_mid = num_q_group * q_aware_to_v_lora_rank
243
+ self.to_v = nn.Sequential(
244
+ # number of params: 768 * 6656 = 5,111,808.
245
+ # Input: [BS, 16, 768]. Output: [BS, 16, 104*64] = [BS, 16, 6656].
246
+ # Each 768-dim vec is dispersed into 104 64-dim vecs.
247
+ nn.Linear(input_dim, all_q_mid, bias=False),
248
+ nn.LayerNorm(all_q_mid, elementwise_affine=True),
249
+ # Change the dim of the tensor to [BS, 6656, 16], as Conv1d transforms dim 1.
250
+ Rearrange('b n q -> b q n', q=all_q_mid),
251
+ # Each q_aware_to_v projection has its own linear layer.
252
+ # The total number of parameters will be 6656*768 = 5,111,808.
253
+ # Output: [BS, 104*768, 16]. Each 64 dim feature is expanded to 768 dim.
254
+ nn.Conv1d(
255
+ in_channels=all_q_mid,
256
+ out_channels=num_q_group * input_dim,
257
+ kernel_size=1,
258
+ groups=num_q_group,
259
+ bias=False,
260
+ ),
261
+ # Output: [BS, 104, 16, 768].
262
+ Rearrange('b (q d) n -> b q n d', q=num_q_group, d=input_dim),
263
+ nn.LayerNorm(input_dim, elementwise_affine=True),
264
+ )
265
+ else:
266
+ self.to_v = nn.Sequential(
267
+ nn.Linear(input_dim, inner_dim, bias=False),
268
+ nn.LayerNorm(inner_dim, elementwise_affine=True)
269
+ ) if not identity_to_v else nn.Identity()
270
+
271
+ if identity_to_out:
272
+ assert not out_has_skip, "identity_to_out=True, then out_has_skip has to be False."
273
+
274
+ if identity_to_out:
275
+ self.to_out = nn.Identity()
276
+ else:
277
+ self.to_out = nn.Sequential(
278
+ nn.Linear(input_dim, input_dim, bias=False),
279
+ nn.Dropout(p_dropout),
280
+ nn.LayerNorm(inner_dim, elementwise_affine=True)
281
+ )
282
+
283
+ self.out_has_skip = out_has_skip
284
+ self.attn_drop = nn.Dropout(p_dropout)
285
+
286
+ def forward(self, x, context=None, attn_mat=None, return_attn=False):
287
+ h = self.num_heads
288
+
289
+ if context is None:
290
+ context = x
291
+
292
+ if attn_mat is None:
293
+ # q: [BS, Q, D] -> [BS, Q, D].
294
+ q = self.to_q(x)
295
+ # k: [BS, L, D] -> [BS, L, D].
296
+ k = self.to_k(context)
297
+ # q: [6, 512, 128], k: [6, 17, 128].
298
+ q, k = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k))
299
+
300
+ if self.q_aware_to_v:
301
+ # context: [BS, L, D]. v: [BS, Q, L, D].
302
+ # There are effectively Q to_v projections.
303
+ v = self.to_v(context)
304
+ if self.v_has_skip:
305
+ v = v + context.unsqueeze(1)
306
+ else:
307
+ # v: [BS, L, D].
308
+ v = self.to_v(context)
309
+ if self.v_has_skip:
310
+ v = v + context
311
+
312
+ #print(v.shape)
313
+
314
+ if self.q_aware_to_v:
315
+ # v: [6, 64, 17, 128].
316
+ # v is query-specific, so there's an extra dim for the query.
317
+ v = rearrange(v, 'b q n (h d) -> (b h) q n d', h=h)
318
+ # Each v is for a query group with 512/64 = 8 queries.
319
+ # So each v is repeated 8 times to match the number of queries.
320
+ # v: [6, 64, 17, 128] -> [6, 512, 17, 128].
321
+ v = v.repeat(1, self.v_repeat, 1, 1)
322
+ else:
323
+ v = rearrange(v, 'b n (h d) -> (b h) n d', h=h)
324
+
325
+ if attn_mat is None:
326
+ scale = q.size(-1) ** -0.25
327
+ sim = einsum('b i d, b j d -> b i j', q * scale, k * scale)
328
+ # sim: [6, 64, 17]. 6: bs 1 * h 6.
329
+ # attention, what we cannot get enough of
330
+ # NOTE: the normalization is done across tokens, not across pixels.
331
+ # So for each pixel, the sum of attention scores across tokens is 1.
332
+ attn = sim.softmax(dim=-1)
333
+ attn = self.attn_drop(attn)
334
+ #print(attn.std())
335
+ else:
336
+ attn = attn_mat
337
+
338
+ if self.q_aware_to_v:
339
+ # attn: [6, 32, 17]. v: [6, 32, 17, 128]. 128: dim of each head. out: [6, 32, 128].
340
+ # out is combined with different attn weights and v for different queries.
341
+ out = einsum('b i j, b i j d -> b i d', attn, v)
342
+ else:
343
+ # v: [6, 17, 128]. out: [6, 32, 128].
344
+ out = einsum('b i j, b j d -> b i d', attn, v)
345
+
346
+ # [6, 32, 128] -> [1, 32, 768].
347
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
348
+
349
+ if self.out_has_skip:
350
+ out = self.to_out(out) + out
351
+ else:
352
+ out = self.to_out(out)
353
+
354
+ if return_attn:
355
+ return out, attn
356
+ else:
357
+ return out
358
+
359
+ class SubjBasisGenerator(nn.Module):
360
+ def __init__(
361
+ self,
362
+ # Number of cross-attention heads: half of the 12 heads of OpenAI clip-vit-large-patch14:
363
+ # https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
364
+ num_heads=6,
365
+ num_id_vecs={ 'subj': 77, 'bg': 257 }, # number of identity vectors. 18: 16 face tokens + 2 extra tokens. 257: 257 CLIP tokens.
366
+ num_out_embs_per_layer=4, # num_out_embs. subj: 16. bg: 4.
367
+ num_out_layers=16, # number of layers of output embeddings.
368
+ image_embedding_dim=768, # CLIP image feature dimension, as per config.json above.
369
+ # DINO vits16 has 6 attention heads:
370
+ # https://huggingface.co/facebook/dino-vits16/blob/main/config.json
371
+ dino_embedding_dim=384, # DINO object feature dimension for objects.
372
+ output_dim=768, # CLIP text embedding input dimension.
373
+ placeholder_is_bg: bool = False, # Whether the placeholder is for the image background.
374
+ prompt2token_proj_grad_scale: float = 0.4, # Gradient scale for prompt2token_proj.
375
+ zs_extra_words_scale: float = 0.5, # Scale for extra words in the prompt2token_proj.
376
+ learnable_hidden_state_weights_scheme: str = 'per-layer', # none, per-layer.
377
+ bg_prompt_translator_has_to_out_proj: bool = False, # Whether the prompt_trans_layers have a to_out projection.
378
+ ):
379
+ super().__init__()
380
+
381
+ self.placeholder_is_bg = placeholder_is_bg
382
+ self.num_out_layers = num_out_layers
383
+ self.num_out_embs_per_layer = num_out_embs_per_layer
384
+ # subj: 64, bg: 32.
385
+ self.num_out_embs = num_out_layers * num_out_embs_per_layer
386
+ self.output_dim = output_dim
387
+ # num_id_vecs should be the number of core ID embs, 16.
388
+ # However, in such case, pos_embs is not used. So it doesn't matter if it's wrongly set.
389
+ self.num_id_vecs = num_id_vecs['bg'] if placeholder_is_bg else num_id_vecs['subj']
390
+ self.pos_embs = nn.Parameter(torch.randn(1, self.num_id_vecs, output_dim))
391
+ self.pos_embs_ln = nn.LayerNorm(output_dim)
392
+ self.zs_extra_words_scale = zs_extra_words_scale
393
+ self.output_scale = output_dim ** -0.5
394
+ self.clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
395
+
396
+ if not self.placeholder_is_bg:
397
+ # [1, 384] -> [1, 16, 768].
398
+ # TODO: use CLIPTextModelWrapper as obj_proj_in.
399
+ self.obj_proj_in = ExpandEmbs(dino_embedding_dim, output_dim, expansion_ratio=self.num_id_vecs)
400
+
401
+ # self.prompt2token_proj: [1, 16, 768] -> [1, 77, 768] (with paddings).
402
+ # If self.placeholder_is_bg: prompt2token_proj is set to None.
403
+ self.prompt2token_proj = CLIPTextModelWrapper.from_pretrained('openai/clip-vit-large-patch14')
404
+ self.prompt2token_proj_grad_scale = prompt2token_proj_grad_scale
405
+ self.prompt2token_proj_grad_scaler = gen_gradient_scaler(prompt2token_proj_grad_scale)
406
+ print(f"Subj prompt2token_proj initialized with grad scale of {prompt2token_proj_grad_scale}.")
407
+ # Freeze prompt2token_proj if prompt2token_proj_grad_scale is 0.
408
+ # Set requires_grad to False for all parameters in prompt2token_proj, to save memory taken by the optimizer.
409
+ if prompt2token_proj_grad_scale == 0:
410
+ self.freeze_prompt2token_proj()
411
+
412
+ self.prompt2token_proj_attention_multiplier = -1
413
+ self.initialize_hidden_state_layer_weights(learnable_hidden_state_weights_scheme, 'cpu')
414
+ self.pad_embeddings = None
415
+ self.bg_proj_in = None
416
+ else:
417
+ # For background placeholders, face and object embeddings are not used as they are foreground.
418
+ self.obj_proj_in = None
419
+ self.prompt2token_proj = None
420
+ print("Bg prompt2token_proj is set to None.")
421
+
422
+ self.bg_proj_in = nn.Sequential(
423
+ nn.Linear(image_embedding_dim, output_dim, bias=False),
424
+ nn.LayerNorm(output_dim),
425
+ )
426
+
427
+ self.latent_queries = nn.Parameter(torch.randn(1, self.num_out_embs, output_dim))
428
+ self.latent_queries_ln = nn.LayerNorm(output_dim)
429
+
430
+ self.bg_prompt_translator_has_to_out_proj = bg_prompt_translator_has_to_out_proj
431
+ identity_to_v = False
432
+ v_has_skip = not identity_to_v # True
433
+ identity_to_out = not bg_prompt_translator_has_to_out_proj # True
434
+ out_has_skip = not identity_to_out # False
435
+ # prompt_translator has a to_v projection with skip connection, and doesn't have a to_out projection.
436
+ # dim=768, num_heads=6.
437
+ self.prompt_translator = \
438
+ CrossAttention(input_dim=output_dim, num_heads=num_heads, p_dropout=0.05,
439
+ identity_to_q=False, identity_to_k=False, identity_to_v=identity_to_v,
440
+ q_aware_to_v=False, v_has_skip=v_has_skip,
441
+ num_q=0, # When not q_aware_to_v, num_q is not referenced.
442
+ identity_to_out=identity_to_out,
443
+ out_has_skip=out_has_skip)
444
+ '''
445
+ prompt_translator: CLIPEncoder
446
+ # https://github.com/huggingface/transformers/blob/1872bde7fc6a5d6796bd742bc2dc38eaf8069c5d/src/transformers/models/clip/modeling_clip.py#L566
447
+ # CLIPEncoder.layers: 12 layers of CLIPEncoderLayer, each being
448
+ (0): CLIPEncoderLayer(
449
+ (self_attn): CLIPAttention(
450
+ (k_proj): Linear(in_features=768, out_features=768, bias=True)
451
+ (v_proj): Linear(in_features=768, out_features=768, bias=True)
452
+ (q_proj): Linear(in_features=768, out_features=768, bias=True)
453
+ (out_proj): Linear(in_features=768, out_features=768, bias=True)
454
+ )
455
+ (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
456
+ (mlp): CLIPMLP(
457
+ (activation_fn): QuickGELUActivation()
458
+ (fc1): Linear(in_features=768, out_features=3072, bias=True)
459
+ (fc2): Linear(in_features=3072, out_features=768, bias=True)
460
+ )
461
+ (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
462
+ )
463
+ '''
464
+
465
+ print(repr(self))
466
+
467
+ # raw_id_embs: ArcFace embeddings for faces (not used since we have arc2face_id_embs),
468
+ # or DINO embeddings for objects.
469
+ # arc2face_id_embs: [BS, 16, 768], the core identity embeddings generated by Arc2Face.
470
+ def forward(self, arc2face_id_embs, clip_features=None, raw_id_embs=None, out_id_embs_scale=1.0,
471
+ is_face=True, is_training=False, adaface_prompt_embs_inf_type='full_half_pad'):
472
+
473
+ if not self.placeholder_is_bg:
474
+ BS = arc2face_id_embs.shape[0]
475
+ else:
476
+ # If bg, then arc2face_id_embs is set to None, but clip_features is not None.
477
+ BS = clip_features.shape[0]
478
+
479
+ adaface_prompt_embs = None
480
+ if not hasattr(self, 'clip_tokenizer'):
481
+ self.clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
482
+
483
+ # No need to use raw_id_embs if placeholder_is_bg.
484
+ if not self.placeholder_is_bg:
485
+ if is_face:
486
+ assert arc2face_id_embs is not None
487
+ # arc2face_embs has been projected to the (modified) prompt embedding space
488
+ # by arc2face_forward_face_embs. This prompt embedding space is modified because Arc2Face finetuned
489
+ # the text encoder and the U-Net.
490
+ # in embedding_manager: [BS, 16, 768] -> [BS, 77, 768].
491
+ # arc2face_id_embs is part of arc2face_embs: [BS, 77, 768] -> [BS, 16, 768].
492
+ # adaface_prompt_embs is projected to the prompt embedding spaces. This is the
493
+ # original U-Net prompt embedding space.
494
+
495
+ # hidden_state_layer_weights: [[0.9163], [0.9483], [2.0762]]
496
+ hidden_state_layer_weights = self.hidden_state_layer_weights_grad_scaler(self.hidden_state_layer_weights)
497
+ # return_emb_types: a list of strings, each string is among
498
+ # ['full', 'core', 'full_pad', 'full_half_pad', 'full_zeroed_extra', 'b_core_e'].
499
+ # Using b_core_e is more computationally efficient than using full_zeroed_extra.
500
+ # But there is an unknown bug that causes a crash when using b_core_e.
501
+ if is_training:
502
+ return_emb_types = ['full_pad', 'core']
503
+ else:
504
+ # adaface_prompt_embs_inf_type: default is full_half_pad, same as training.
505
+ return_emb_types = [adaface_prompt_embs_inf_type, 'core']
506
+
507
+ if self.pad_embeddings is None:
508
+ self.generate_pad_embeddings()
509
+ else:
510
+ self.pad_embeddings = self.pad_embeddings.to(arc2face_id_embs.device)
511
+
512
+ with torch.set_grad_enabled(self.training and self.prompt2token_proj_grad_scale != 0):
513
+ # If list_extra_words is not None, then core_id_embs: [BS, 18, 768], three leading words, the 16 identity tokens
514
+ # and (at most) two extra words in full_prompt_embs, without BOS and EOS.
515
+ # If list_extra_words is None, then core_id_embs: [BS, 16, 768], the 16 identity tokens in full_prompt_embs.
516
+ # hidden_state_layer_weights: [[0.9163], [0.9483], [2.0762]]
517
+ # zs_extra_words_scale is only effective when list_extra_words is not None.
518
+ # adaface_prompt_embs: [BS, 77, 768], core_id_embs: [BS, 16, 768].
519
+ adaface_prompt_embs, core_id_embs = \
520
+ arc2face_inverse_face_prompt_embs(self.clip_tokenizer,
521
+ self.prompt2token_proj,
522
+ arc2face_id_embs,
523
+ list_extra_words=None,
524
+ return_emb_types=return_emb_types,
525
+ pad_embeddings=self.pad_embeddings,
526
+ hidden_state_layer_weights=hidden_state_layer_weights,
527
+ input_max_length=77, zs_extra_words_scale=self.zs_extra_words_scale)
528
+ # Reduce the update rate to prompt2token_proj.
529
+ adaface_prompt_embs = self.prompt2token_proj_grad_scaler(adaface_prompt_embs)
530
+ core_id_embs = self.prompt2token_proj_grad_scaler(core_id_embs)
531
+ elif raw_id_embs is not None:
532
+ # id_embs: [BS, 384] -> [BS, 18, 768].
533
+ # obj_proj_in is expected to project the DINO object features to
534
+ # the token embedding space. So no need to use prompt2token_proj.
535
+ id_embs = self.obj_proj_in(raw_id_embs)
536
+ else:
537
+ breakpoint()
538
+ else:
539
+ # Otherwise, context is the ad-hoc CLIP image features.
540
+ # id_embs: [BS, 257, 768].
541
+ id_embs = self.bg_proj_in(clip_features)
542
+
543
+ if self.placeholder_is_bg:
544
+ id_embs = id_embs + self.pos_embs_ln(self.pos_embs)
545
+ latent_queries = self.latent_queries_ln(self.latent_queries).repeat(BS, 1, 1)
546
+ # If bg, we don't have to use a specific attn layer for each 4-vec set. Instead, one attn layer can generate 257 embs,
547
+ # and we take the first 16*4=64.
548
+ # Output of prompt_translator is exactly num_out_embs == 64 tokens. id_embs_out: [BS, 64, 768].
549
+ # prompt_translator: better named as bg_prompt_translator. It maps the bg features
550
+ # to bg prompt embeddings.
551
+ with torch.set_grad_enabled(self.training):
552
+ id_embs_out = self.prompt_translator(latent_queries, id_embs)
553
+ # [BS, 64, 768] -> [BS, 16, 4, 768]
554
+ id_embs_out = id_embs_out.reshape(BS, self.num_out_layers, -1, self.output_dim)
555
+ adaface_subj_embs = id_embs_out * self.output_scale # * 0.036
556
+ else:
557
+ # adaface_subj_embs: [BS, 16, 768] -> [BS, 1, 16, 768] -> [BS, 16, 16, 768]
558
+ adaface_subj_embs = core_id_embs.unsqueeze(1).repeat(1, self.num_out_layers, 1, 1)
559
+
560
+ # If out_id_embs_scale < 1, adaface_subj_embs is a mix of adaface_subj_embs and pad_embeddings.
561
+ if out_id_embs_scale != 1:
562
+ # pad_embeddings: [77, 768] -> [16, 768] -> [1, 1, 16, 768].
563
+ pad_embeddings = self.pad_embeddings[4:4+self.num_out_embs_per_layer].unsqueeze(0).unsqueeze(0)
564
+ adaface_subj_embs = adaface_subj_embs * out_id_embs_scale \
565
+ + pad_embeddings * (1 - out_id_embs_scale)
566
+
567
+ return adaface_subj_embs, adaface_prompt_embs
568
+
569
+ def initialize_hidden_state_layer_weights(self, learnable_hidden_state_weights_scheme, device):
570
+ if learnable_hidden_state_weights_scheme == 'none':
571
+ self.hidden_state_layer_weights = None
572
+ # A grad scaler with alpha =1 is nn.Identity(), which outputs None given None as input.
573
+ self.hidden_state_layer_weights_grad_scaler = gen_gradient_scaler(1)
574
+ print("hidden_state_layer_weights is set to None.")
575
+
576
+ elif learnable_hidden_state_weights_scheme == 'per-layer':
577
+ # Learnable weights of the last 3 layers, initialized to putting more focus on the last layer.
578
+ # 'per-layer': Different weights for different layers, but the same for different channels.
579
+ # hidden_state_layer_weights: [3, 1].
580
+ self.hidden_state_layer_weights = nn.Parameter(torch.tensor([[1.0], [2.0], [4.0]], device=device),
581
+ requires_grad=True)
582
+ self.hidden_state_layer_weights_grad_scaler = gen_gradient_scaler(5)
583
+ print("hidden_state_layer_weights initialized as per-layer [1, 2, 4], with grad scaler 5.")
584
+ else:
585
+ breakpoint()
586
+
587
+ def generate_pad_embeddings(self):
588
+ # clip_embeddings: CLIPTextEmbeddings instance. pad_embeddings is generated after
589
+ # prompt2token_proj is loaded from the finetuned weight. It seems such pad embeddings perform
590
+ # slightly better than the original pad embeddings.
591
+ clip_embeddings = self.prompt2token_proj.text_model.embeddings
592
+ # clip_embeddings() and clip_embeddings.token_embedding() differ in that
593
+ # clip_embeddings() adds positional embeddings, while clip_embeddings.token_embedding() doesn't.
594
+ # Adding positional embeddings seems to help somewhat.
595
+ # pad_tokens: pad_token_id 49407 repeated 77 times.
596
+ # pad_token_id is the EOS token. But BOS is 49406.
597
+ pad_tokens = torch.tensor([self.clip_tokenizer.pad_token_id]).to(clip_embeddings.token_embedding.weight.device).repeat(77)
598
+ # pad_embeddings: [77, 768].
599
+ pad_embeddings = clip_embeddings(pad_tokens)[0]
600
+ # We don't allow face recon to influence the pad embeddings.
601
+ # Otherwise, face identity will leak into the pad embeddings.
602
+ self.pad_embeddings = pad_embeddings.detach()
603
+
604
+ def extend_prompt2token_proj_attention(self, begin_layer_idx=-1, end_layer_idx=-1, multiplier=2, noise_std=0.1):
605
+ if multiplier > 1:
606
+ num_extended_layers = self.prompt2token_proj.extend_clip_attention_MKV_multiplier(begin_layer_idx, end_layer_idx, multiplier, noise_std)
607
+ self.prompt2token_proj_attention_multiplier = multiplier
608
+ print(f"{num_extended_layers} layers in prompt2token_proj_attention are x{multiplier}")
609
+
610
+ def freeze_prompt2token_proj(self):
611
+ # If bg, then prompt2token_proj is set to None. Therefore no need to freeze it.
612
+ # Then we don't have to check whether it's for subj or bg.
613
+ if self.prompt2token_proj is not None:
614
+ frozen_param_names = []
615
+ for param_name, param in self.prompt2token_proj.named_parameters():
616
+ if param.requires_grad:
617
+ param.requires_grad = False
618
+ frozen_param_names.append(param_name)
619
+ # If param is already frozen, then no need to freeze it again.
620
+ print(f"{len(frozen_param_names)} params in Subj prompt2token_proj is frozen.")
621
+ #print(f"Frozen parameters:\n{frozen_param_names}")
622
+
623
+ def __repr__(self):
624
+ type_sig = 'subj' if not self.placeholder_is_bg else 'bg'
625
+ # Fix compatibility with the previous version.
626
+ if not hasattr(self, 'bg_prompt_translator_has_to_out_proj'):
627
+ self.bg_prompt_translator_has_to_out_proj = False
628
+ if not hasattr(self, 'num_out_embs'):
629
+ self.num_out_embs = -1
630
+ return f"{type_sig} SubjBasisGenerator: num_out_embs={self.num_out_embs}, " \
631
+ f"bg_prompt_translator_has_to_out_proj={self.bg_prompt_translator_has_to_out_proj}"
632
+
633
+ @dataclass
634
+ class BaseModelOutputWithPooling2(ModelOutput):
635
+ """
636
+ Base class for model's outputs that also contains a pooling of the last hidden states.
637
+
638
+ Args:
639
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
640
+ Sequence of hidden-states at the output of the last layer of the model.
641
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
642
+ Last layer hidden-state of the first token of the sequence (classification token) after further processing
643
+ through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
644
+ the classification token after processing through a linear layer and a tanh activation function. The linear
645
+ layer weights are trained from the next sentence prediction (classification) objective during pretraining.
646
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
647
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
648
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
649
+
650
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
651
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
652
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
653
+ sequence_length)`.
654
+
655
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
656
+ heads.
657
+ """
658
+
659
+ last_hidden_state: torch.FloatTensor = None
660
+ pooler_output: torch.FloatTensor = None
661
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
662
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
663
+ attn_mask: Optional[torch.FloatTensor] = None
664
+
665
+ # Revised from CLIPVisionTransformer to support attention mask.
666
+ # self: a CLIPVisionTransformer instance.
667
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py#L821
668
+ # pixel_values: preprocessed B*C*H*W images. [BS, 3, 224, 224]
669
+ # attn_mask: B*H*W attention mask.
670
+ def CLIPVisionTransformer_forward(self, pixel_values = None, attn_mask=None,
671
+ output_attentions = None,
672
+ output_hidden_states = None, return_dict = None):
673
+
674
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
675
+ output_hidden_states = (
676
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
677
+ )
678
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
679
+
680
+ if pixel_values is None:
681
+ raise ValueError("You have to specify pixel_values")
682
+
683
+ # Visual tokens are flattened in embeddings().
684
+ # self.embeddings: CLIPVisionEmbeddings.
685
+ # hidden_states: [BS, 257, 1280]. 257: 16*16 (patch_embeds) + 1 (class_embeds).
686
+ # 16*16 is output from Conv2d(3, 1280, kernel_size=(14, 14), stride=(14, 14), bias=False).
687
+ hidden_states = self.embeddings(pixel_values)
688
+ hidden_states = self.pre_layrnorm(hidden_states)
689
+
690
+ if attn_mask is not None:
691
+ # feat_edge_size: 16.
692
+ feat_edge_size = np.sqrt(hidden_states.shape[1] - 1).astype(int)
693
+ # attn_mask: [BS, 512, 512] -> [BS, 1, 16, 16].
694
+ attn_mask = F.interpolate(attn_mask.unsqueeze(1), size=(feat_edge_size, feat_edge_size), mode='nearest')
695
+ # Flatten the mask: [BS, 1, 16, 16] => [BS, 1, 256].
696
+ attn_mask = attn_mask.flatten(2)
697
+ # Prepend 1 to the mask: [BS, 1, 256] => [BS, 1, 257].
698
+ # This 1 corresponds to class_embeds, which is always attended to.
699
+ attn_mask = torch.cat([torch.ones_like(attn_mask[:, :, :1]), attn_mask], dim=-1)
700
+ attn_mask_pairs = torch.matmul(attn_mask.transpose(-1, -2), attn_mask).unsqueeze(1)
701
+ else:
702
+ attn_mask_pairs = None
703
+
704
+ # encoder: CLIPEncoder.
705
+ encoder_outputs = self.encoder(
706
+ inputs_embeds=hidden_states,
707
+ # New feature: (***The official documentation is wrong***)
708
+ # attention_mask (`torch.Tensor` of shape `(batch_size, 1, sequence_length, sequence_length)`, *optional*):
709
+ # Mask to avoid performing attention on pairs of token. Mask values selected in `[0, 1]`:
710
+ # - 1 for pairs that are **not masked**,
711
+ # - 0 for pairs that are **masked**.
712
+ # attention_mask is eventually used by CLIPEncoderLayer:
713
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py#L370
714
+ attention_mask=attn_mask_pairs,
715
+ output_attentions=output_attentions, # False
716
+ output_hidden_states=output_hidden_states, # True
717
+ return_dict=return_dict, # True
718
+ )
719
+
720
+ # last_hidden_state: [BS, 257, 1280]
721
+ last_hidden_state = encoder_outputs[0]
722
+ pooled_output = last_hidden_state[:, 0, :]
723
+ pooled_output = self.post_layernorm(pooled_output)
724
+
725
+ # return_dict is True.
726
+ if not return_dict:
727
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
728
+
729
+ return BaseModelOutputWithPooling2(
730
+ last_hidden_state=last_hidden_state,
731
+ pooler_output=pooled_output,
732
+ hidden_states=encoder_outputs.hidden_states,
733
+ attentions=encoder_outputs.attentions,
734
+ # Newly added: return resized flattened attention mask.
735
+ # [BS, 1, 257] -> [BS, 257, 1]
736
+ attn_mask=attn_mask.permute(0, 2, 1) if attn_mask is not None else None
737
+ )
738
+
739
+
740
+ class CLIPVisionModelWithMask(CLIPVisionModel):
741
+ def __init__(self, config):
742
+ super().__init__(config)
743
+ # Replace vision_model.forward() with the new one that supports mask.
744
+ self.vision_model.forward = CLIPVisionTransformer_forward.__get__(self.vision_model)
745
+
746
+ def forward(self, pixel_values = None, attn_mask = None, output_attentions = None,
747
+ output_hidden_states = None, return_dict = None):
748
+
749
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
750
+
751
+ return self.vision_model(
752
+ pixel_values=pixel_values,
753
+ attn_mask=attn_mask,
754
+ output_attentions=output_attentions,
755
+ output_hidden_states=output_hidden_states,
756
+ return_dict=return_dict,
757
+ )
758
+
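subj_basis_generator.py above leans on a small LearnedSoftAggregate module to collapse the num_modes dimension produced by LoRA_ExpandEmbs, MultimodeProjection and Lora2Hira. The toy sketch below shows its input/output contract in isolation; the shapes are chosen arbitrarily for illustration, and group_dim=1 matches its use in LoRA_ExpandEmbs / Lora2Hira.

```python
# Toy sketch of LearnedSoftAggregate: a learned softmax over the "modes" dimension.
import torch
from adaface.subj_basis_generator import LearnedSoftAggregate

x = torch.randn(2, 4, 16, 768)                     # [BS, num_modes, N, D]
aggr = LearnedSoftAggregate(num_feat=768, group_dim=1, keepdim=False)
y = aggr(x)                                        # weighted sum over the 4 modes
print(y.shape)                                     # torch.Size([2, 16, 768])
```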
adaface/util.py ADDED
@@ -0,0 +1,342 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from PIL import Image
6
+ import cv2
7
+
8
+ # add_noise_to_tensor() adds Gaussian noise of std noise_std (optionally relative to the tensor's own std) to the tensor.
9
+ def add_noise_to_tensor(ts, noise_std, noise_std_is_relative=True, keep_norm=False,
10
+ std_dim=-1, norm_dim=-1):
11
+ if noise_std_is_relative:
12
+ ts_std_mean = ts.std(dim=std_dim).mean().detach()
13
+ noise_std *= ts_std_mean
14
+
15
+ noise = torch.randn_like(ts) * noise_std
16
+ if keep_norm:
17
+ orig_norm = ts.norm(dim=norm_dim, keepdim=True)
18
+ ts = ts + noise
19
+ new_norm = ts.norm(dim=norm_dim, keepdim=True).detach()
20
+ ts = ts * orig_norm / (new_norm + 1e-8)
21
+ else:
22
+ ts = ts + noise
23
+
24
+ return ts
25
+
26
+
27
+ # Revised from RevGrad, by removing the grad negation.
28
+ class ScaleGrad(torch.autograd.Function):
29
+ @staticmethod
30
+ def forward(ctx, input_, alpha_, debug=False):
31
+ ctx.save_for_backward(alpha_, debug)
32
+ output = input_
33
+ if debug:
34
+ print(f"input: {input_.abs().mean().item()}")
35
+ return output
36
+
37
+ @staticmethod
38
+ def backward(ctx, grad_output): # pragma: no cover
39
+ # saved_tensors returns a tuple of tensors.
40
+ alpha_, debug = ctx.saved_tensors
41
+ if ctx.needs_input_grad[0]:
42
+ grad_output2 = grad_output * alpha_
43
+ if debug:
44
+ print(f"grad_output2: {grad_output2.abs().mean().item()}")
45
+ else:
46
+ grad_output2 = None
47
+ return grad_output2, None, None
48
+
49
+ class GradientScaler(nn.Module):
50
+ def __init__(self, alpha=1., debug=False, *args, **kwargs):
51
+ """
52
+ A gradient scaling layer.
53
+ This layer has no parameters, and simply scales the gradient in the backward pass.
54
+ """
55
+ super().__init__(*args, **kwargs)
56
+
57
+ self._alpha = torch.tensor(alpha, requires_grad=False)
58
+ self._debug = torch.tensor(debug, requires_grad=False)
59
+
60
+ def forward(self, input_):
61
+ _debug = self._debug if hasattr(self, '_debug') else False
62
+ return ScaleGrad.apply(input_, self._alpha.to(input_.device), _debug)
63
+
64
+ def gen_gradient_scaler(alpha, debug=False):
65
+ if alpha == 1:
66
+ return nn.Identity()
67
+ if alpha > 0:
68
+ return GradientScaler(alpha, debug=debug)
69
+ else:
70
+ assert alpha == 0
71
+ # Don't use lambda function here, otherwise the object can't be pickled.
72
+ return torch.detach
73
+
74
+ #@torch.autocast(device_type="cuda")
75
+ # In AdaFaceWrapper, input_max_length is 22.
76
+ def arc2face_forward_face_embs(tokenizer, arc2face_text_encoder, face_embs,
77
+ input_max_length=77, return_full_and_core_embs=True):
78
+
79
+ '''
80
+ arc2face_text_encoder: arc2face_models.py CLIPTextModelWrapper instance.
81
+ face_embs: (N, 512) normalized ArcFace embeddings.
82
+ return_full_and_core_embs: Return both the full prompt embeddings and the core embeddings.
83
+ If False, return only the core embeddings.
84
+
85
+ '''
86
+
87
+ # arcface_token_id: 1014
88
+ arcface_token_id = tokenizer.encode("id", add_special_tokens=False)[0]
89
+
90
+ # This step should be quite fast, and there's no need to cache the input_ids.
91
+ input_ids = tokenizer(
92
+ "photo of a id person",
93
+ truncation=True,
94
+ padding="max_length",
95
+ max_length=input_max_length, #tokenizer.model_max_length,
96
+ return_tensors="pt",
97
+ ).input_ids.to(face_embs.device)
98
+ # input_ids: [1, 77] or [3, 77] (during training).
99
+ input_ids = input_ids.repeat(len(face_embs), 1)
100
+ face_embs_dtype = face_embs.dtype
101
+ face_embs = face_embs.to(arc2face_text_encoder.dtype)
102
+ # face_embs_padded: [1, 512] -> [1, 768].
103
+ face_embs_padded = F.pad(face_embs, (0, arc2face_text_encoder.config.hidden_size - face_embs.shape[-1]), "constant", 0)
104
+ # arc2face_text_encoder(input_ids=input_ids, ...) is called twice. The first is only to get the token embeddings (the shallowest mapping).
105
+ # The second call does the ordinary CLIP text encoding pass.
106
+ token_embs = arc2face_text_encoder(input_ids=input_ids, return_token_embs=True)
107
+ token_embs[input_ids==arcface_token_id] = face_embs_padded
108
+
109
+ prompt_embeds = arc2face_text_encoder(
110
+ input_ids=input_ids,
111
+ input_token_embs=token_embs,
112
+ return_token_embs=False
113
+ )[0]
114
+
115
+ # Restore the original dtype of prompt_embeds: float16 -> float32.
116
+ prompt_embeds = prompt_embeds.to(face_embs_dtype)
117
+
118
+ if return_full_and_core_embs:
119
+ # token 4: 'id' in "photo of a id person".
120
+ # 4:20 are the most important 16 embeddings that contain the subject's identity.
121
+ # [N, 77, 768] -> [N, 16, 768]
122
+ return prompt_embeds, prompt_embeds[:, 4:20]
123
+ else:
124
+ # [N, 16, 768]
125
+ return prompt_embeds[:, 4:20]
126
+
127
+ def get_b_core_e_embeddings(prompt_embeds, length=22):
128
+ b_core_e_embs = torch.cat([ prompt_embeds[:, :length], prompt_embeds[:, [-1]] ], dim=1)
129
+ return b_core_e_embs
130
+
131
+ # return_emb_types: a list of strings, each among ['full', 'full_half_pad', 'full_pad', 'core', 'full_zeroed_extra', 'b_core_e'].
132
+ def arc2face_inverse_face_prompt_embs(clip_tokenizer, inverse_text_encoder, face_prompt_embs, list_extra_words,
133
+ return_emb_types, pad_embeddings, hidden_state_layer_weights=None,
134
+ input_max_length=77, zs_extra_words_scale=0.5):
135
+
136
+ '''
137
+ inverse_text_encoder: arc2face_models.py CLIPTextModelWrapper instance with **custom weights**.
138
+ inverse_text_encoder is NOT the original arc2face text encoder, but retrained to do inverse mapping.
139
+ face_prompt_embs: (BS, 16, 768). Only the core embeddings, no paddings.
140
+ list_extra_words: [s_1, ..., s_BS], each s_i is a list of extra words to be added to the prompt.
141
+ return_emb_types: a list of strings among ['full', 'full_half_pad', 'full_pad', 'core', 'full_zeroed_extra', 'b_core_e'],
142
+ specifying which variants of the prompt embeddings to return.
143
+ '''
144
+
145
+ if list_extra_words is not None:
146
+ if len(list_extra_words) != len(face_prompt_embs):
147
+ if len(face_prompt_embs) > 1:
148
+ print("Warn: list_extra_words has a different length from face_prompt_embs.")
149
+ if len(list_extra_words) == 1:
150
+ list_extra_words = list_extra_words * len(face_prompt_embs)
151
+ else:
152
+ breakpoint()
153
+ else:
154
+ # len(face_prompt_embs) == 1, this occurs when same_subject_in_batch == True, e.g. in do_mix_prompt_distillation.
155
+ # But list_extra_words always corresponds to the actual batch size. So we only take the first element.
156
+ list_extra_words = list_extra_words[:1]
157
+
158
+ for extra_words in list_extra_words:
159
+ assert len(extra_words.split()) <= 2, "Each extra_words string should consist of at most 2 words."
160
+ # 16 ", " are placeholders for face_prompt_embs.
161
+ prompt_templates = [ "photo of a " + ", " * 16 + list_extra_words[i] for i in range(len(list_extra_words)) ]
162
+ else:
163
+ # 16 ", " are placeholders for face_prompt_embs.
164
+ # No extra words are added to the prompt.
165
+ prompt_templates = [ "photo of a " + ", " * 16 for _ in range(len(face_prompt_embs)) ]
166
+
167
+ # This step should be quite fast, and there's no need to cache the input_ids.
168
+ # input_ids: [BS, 77].
169
+ input_ids = clip_tokenizer(
170
+ prompt_templates,
171
+ truncation=True,
172
+ padding="max_length",
173
+ max_length=input_max_length,
174
+ return_tensors="pt",
175
+ ).input_ids.to(face_prompt_embs.device)
176
+
177
+ face_prompt_embs_dtype = face_prompt_embs.dtype
178
+ face_prompt_embs = face_prompt_embs.to(inverse_text_encoder.dtype)
179
+
180
+ # token_embs: [1, 77, 768]. This call is only to get the template token embeddings (the shallowest mapping).
181
+ token_embs = inverse_text_encoder(input_ids=input_ids, return_token_embs=True)
182
+ # token 4: first ", " in the template prompt.
183
+ # Replace embeddings of 16 placeholder ", " with face_prompt_embs.
184
+ token_embs[:, 4:20] = face_prompt_embs
185
+
186
+ # This call does the ordinary CLIP text encoding pass.
187
+ prompt_embeds = inverse_text_encoder(
188
+ input_ids=input_ids,
189
+ input_token_embs=token_embs,
190
+ hidden_state_layer_weights=hidden_state_layer_weights,
191
+ return_token_embs=False
192
+ )[0]
193
+
194
+ # Restore the original dtype of prompt_embeds: float16 -> float32.
195
+ prompt_embeds = prompt_embeds.to(face_prompt_embs_dtype)
196
+ # token 4: first ", " in the template prompt.
197
+ # 4:20 are the most important 16 embeddings that contain the subject's identity.
198
+ # 20:22 are embeddings of the (at most) two extra words.
199
+ # [N, 77, 768] -> [N, 16, 768]
200
+ core_prompt_embs = prompt_embeds[:, 4:20]
201
+ if list_extra_words is not None:
202
+ # [N, 16, 768] -> [N, 18, 768]
203
+ extra_words_embs = prompt_embeds[:, 20:22] * zs_extra_words_scale
204
+ core_prompt_embs = torch.cat([core_prompt_embs, extra_words_embs], dim=1)
205
+
206
+ return_prompts = []
207
+ for emb_type in return_emb_types:
208
+ if emb_type == 'full':
209
+ return_prompts.append(prompt_embeds)
210
+ elif emb_type == 'full_half_pad':
211
+ prompt_embeds2 = prompt_embeds.clone()
212
+ PADS = prompt_embeds2.shape[1] - 23
213
+ if PADS >= 2:
214
+ # Fill half of the remaining embeddings with pad embeddings.
215
+ prompt_embeds2[:, 22:22+PADS//2] = pad_embeddings[22:22+PADS//2]
216
+ return_prompts.append(prompt_embeds2)
217
+ elif emb_type == 'full_pad':
218
+ prompt_embeds2 = prompt_embeds.clone()
219
+ # Fill the 22nd to the second last embeddings with pad embeddings.
220
+ prompt_embeds2[:, 22:-1] = pad_embeddings[22:-1]
221
+ return_prompts.append(prompt_embeds2)
222
+ elif emb_type == 'core':
223
+ return_prompts.append(core_prompt_embs)
224
+ elif emb_type == 'full_zeroed_extra':
225
+ prompt_embeds2 = prompt_embeds.clone()
226
+ # Only add two pad embeddings. The remaining embeddings are set to 0.
227
+ # Make the positional embeddings align with the actual positions.
228
+ prompt_embeds2[:, 22:24] = pad_embeddings[22:24]
229
+ prompt_embeds2[:, 24:-1] = 0
230
+ return_prompts.append(prompt_embeds2)
231
+ elif emb_type == 'b_core_e':
232
+ # The first 22 embeddings, plus the last EOS embedding.
233
+ b_core_e_embs = get_b_core_e_embeddings(prompt_embeds, length=22)
234
+ return_prompts.append(b_core_e_embs)
235
+ else:
236
+ breakpoint()
237
+
238
+ return return_prompts
239
+
240
+ # if pre_face_embs is None, generate random face embeddings [BS, 512].
241
+ # image_folder is passed only for logging purpose. image_paths contains the paths of the images.
242
+ def get_arc2face_id_prompt_embs(face_app, clip_tokenizer, arc2face_text_encoder,
243
+ extract_faceid_embeds, pre_face_embs,
244
+ image_folder, image_paths, images_np,
245
+ id_batch_size, device,
246
+ input_max_length=77, noise_level=0.0,
247
+ return_core_id_embs=False,
248
+ gen_neg_prompt=False, verbose=False):
249
+ face_image_count = 0
250
+
251
+ if extract_faceid_embeds:
252
+ faceid_embeds = []
253
+ if image_paths is not None:
254
+ images_np = []
255
+ for image_path in image_paths:
256
+ image_np = np.array(Image.open(image_path))
257
+ images_np.append(image_np)
258
+
259
+ for i, image_np in enumerate(images_np):
260
+ image_obj = Image.fromarray(image_np).resize((512, 512), Image.NEAREST)
261
+ # Remove alpha channel if it exists.
262
+ if image_obj.mode == 'RGBA':
263
+ image_obj = image_obj.convert('RGB')
264
+ # This seems NOT a bug. The input image should be in BGR format, as per
265
+ # https://github.com/deepinsight/insightface/issues/524
266
+ image_np = cv2.cvtColor(np.array(image_obj), cv2.COLOR_RGB2BGR)
267
268
+
269
+ face_infos = face_app.get(image_np)
270
+ if verbose and image_paths is not None:
271
+ print(image_paths[i], len(face_infos))
272
+ # Assume all images belong to the same subject. Therefore, we can skip the images with no face detected.
273
+ if len(face_infos) == 0:
274
+ continue
275
+ # Only use the largest detected face (by bounding-box area).
276
+ face_info = sorted(face_infos, key=lambda x: (x['bbox'][2]-x['bbox'][0]) * (x['bbox'][3]-x['bbox'][1]))[-1]
277
+ # Each faceid_embed: [1, 512]
278
+ faceid_embeds.append(torch.from_numpy(face_info.normed_embedding).unsqueeze(0))
279
+ face_image_count += 1
280
+
281
+ if verbose:
282
+ if image_folder is not None:
283
+ print(f"Extracted ID embeddings from {face_image_count} images in {image_folder}")
284
+ else:
285
+ print(f"Extracted ID embeddings from {face_image_count} images")
286
+
287
+ if len(faceid_embeds) == 0:
288
+ print("No face detected. Use a random face instead.")
289
+ faceid_embeds = torch.randn(id_batch_size, 512).to(device=device, dtype=torch.float16)
290
+ else:
291
+ # faceid_embeds: [10, 512]
292
+ faceid_embeds = torch.cat(faceid_embeds, dim=0)
293
+ # faceid_embeds: [10, 512] -> [1, 512].
294
+ # Average across images, so the resulting prompt embeddings are shared by all input images.
295
+ faceid_embeds = faceid_embeds.mean(dim=0, keepdim=True).to(device=device, dtype=torch.float16)
296
+ else:
297
+ # Random face embeddings. faceid_embeds: [BS, 512].
298
+ if pre_face_embs is None:
299
+ faceid_embeds = torch.randn(id_batch_size, 512)
300
+ else:
301
+ faceid_embeds = pre_face_embs
302
+ if pre_face_embs.shape[0] == 1:
303
+ faceid_embeds = faceid_embeds.repeat(id_batch_size, 1)
304
+
305
+ faceid_embeds = faceid_embeds.to(device=device, dtype=torch.float16)
306
+
307
+ if noise_level > 0:
308
+ # If id_batch_size > 1, after adding noises, the id_batch_size embeddings will be different.
309
+ faceid_embeds = add_noise_to_tensor(faceid_embeds, noise_level, noise_std_is_relative=True, keep_norm=True)
310
+
311
+ faceid_embeds = F.normalize(faceid_embeds, p=2, dim=-1)
312
+
313
+ # arc2face_pos_prompt_emb, arc2face_neg_prompt_emb: [BS, 77, 768]
314
+ with torch.no_grad():
315
+ arc2face_pos_prompt_emb, arc2face_pos_core_prompt_emb = \
316
+ arc2face_forward_face_embs(clip_tokenizer, arc2face_text_encoder,
317
+ faceid_embeds, input_max_length=input_max_length,
318
+ return_full_and_core_embs=True)
319
+ if return_core_id_embs:
320
+ arc2face_pos_prompt_emb = arc2face_pos_core_prompt_emb
321
+ # If extract_faceid_embeds, we assume all images are from the same subject, and the batch dim of faceid_embeds is 1.
322
+ # So we need to repeat faceid_embeds.
323
+ if extract_faceid_embeds:
324
+ faceid_embeds = faceid_embeds.repeat(id_batch_size, 1)
325
+ arc2face_pos_prompt_emb = arc2face_pos_prompt_emb.repeat(id_batch_size, 1, 1)
326
+
327
+ if gen_neg_prompt:
328
+ with torch.no_grad():
329
+ arc2face_neg_prompt_emb, arc2face_neg_core_prompt_emb = \
330
+ arc2face_forward_face_embs(clip_tokenizer, arc2face_text_encoder,
331
+ torch.zeros_like(faceid_embeds),
332
+ input_max_length=input_max_length,
333
+ return_full_and_core_embs=True)
334
+ if return_core_id_embs:
335
+ arc2face_neg_prompt_emb = arc2face_neg_core_prompt_emb
336
+
337
+ #if extract_faceid_embeds:
338
+ # arc2face_neg_prompt_emb = arc2face_neg_prompt_emb.repeat(id_batch_size, 1, 1)
339
+ return face_image_count, faceid_embeds, arc2face_pos_prompt_emb, arc2face_neg_prompt_emb
340
+ else:
341
+ return face_image_count, faceid_embeds, arc2face_pos_prompt_emb
342
+
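A small sketch (illustrative, not from the repo) of the two generic helpers at the top of adaface/util.py: add_noise_to_tensor perturbs a tensor with Gaussian noise whose std is, by default, relative to the tensor's own std and can preserve the per-row norm; gen_gradient_scaler builds a module that is the identity in the forward pass and scales gradients by alpha in the backward pass.

import torch
from adaface.util import add_noise_to_tensor, gen_gradient_scaler

x = torch.randn(4, 512)
x_noised = add_noise_to_tensor(x, noise_std=0.1, noise_std_is_relative=True, keep_norm=True)
# keep_norm=True rescales the noised tensor back to the original per-row norm.
print(torch.allclose(x.norm(dim=-1), x_noised.norm(dim=-1), atol=1e-4))   # True

scaler = gen_gradient_scaler(0.5)          # forward: identity; backward: grad * 0.5
y = torch.randn(4, 512, requires_grad=True)
scaler(y).sum().backward()
print(y.grad.unique())                     # tensor([0.5000])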
adaface_wrapper.py ADDED
@@ -0,0 +1,297 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import CLIPTextModel
4
+ from diffusers import (
5
+ StableDiffusionPipeline,
6
+ StableDiffusionImg2ImgPipeline,
7
+ UNet2DConditionModel,
8
+ DDIMScheduler,
9
+ AutoencoderKL,
10
+ )
11
+ from insightface.app import FaceAnalysis
12
+ from adaface.arc2face_models import CLIPTextModelWrapper
13
+ from adaface.util import get_arc2face_id_prompt_embs
14
+ import re, os
15
+ import sys
16
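+ # (Assumed intent) Alias the legacy 'ldm' module name to 'adaface', so that checkpoints pickled with 'ldm.*' class paths can be unpickled below.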
+ sys.modules['ldm'] = sys.modules['adaface']
17
+
18
+ class AdaFaceWrapper(nn.Module):
19
+ def __init__(self, pipeline_name, base_model_path, adaface_ckpt_path, device,
20
+ subject_string='z', num_vectors=16,
21
+ num_inference_steps=50, negative_prompt=None,
22
+ use_840k_vae=False, use_ds_text_encoder=False, is_training=False):
23
+ '''
24
+ pipeline_name: "text2img" or "img2img" or None. If None, the unet and vae are
25
+ removed from the pipeline to release RAM.
26
+ '''
27
+ super().__init__()
28
+ self.pipeline_name = pipeline_name
29
+ self.base_model_path = base_model_path
30
+ self.adaface_ckpt_path = adaface_ckpt_path
31
+ self.use_840k_vae = use_840k_vae
32
+ self.use_ds_text_encoder = use_ds_text_encoder
33
+ self.subject_string = subject_string
34
+ self.num_vectors = num_vectors
35
+ self.num_inference_steps = num_inference_steps
36
+ self.device = device
37
+ self.is_training = is_training
38
+ self.initialize_pipeline()
39
+ self.extend_tokenizer_and_text_encoder()
40
+ if negative_prompt is None:
41
+ self.negative_prompt = \
42
+ "flaws in the eyes, flaws in the face, lowres, non-HDRi, low quality, worst quality, artifacts, noise, text, watermark, glitch, " \
43
+ "mutated, ugly, disfigured, hands, partially rendered objects, partially rendered eyes, deformed eyeballs, cross-eyed, blurry, " \
44
+ "mutation, duplicate, out of frame, cropped, mutilated, bad anatomy, deformed, bad proportions, " \
45
+ "nude, naked, nsfw, topless, bare breasts"
46
+ else:
47
+ self.negative_prompt = negative_prompt
48
+
49
+ def load_subj_basis_generator(self, adaface_ckpt_path):
50
+ ckpt = torch.load(adaface_ckpt_path, map_location='cpu')
51
+ string_to_subj_basis_generator_dict = ckpt["string_to_subj_basis_generator_dict"]
52
+ if self.subject_string not in string_to_subj_basis_generator_dict:
53
+ print(f"Subject '{self.subject_string}' not found in the embedding manager.")
54
+ breakpoint()
55
+
56
+ self.subj_basis_generator = string_to_subj_basis_generator_dict[self.subject_string]
57
+ # In the original ckpt, num_out_layers is 16 for layerwise embeddings.
58
+ # But we don't do layerwise embeddings here, so we set it to 1.
59
+ self.subj_basis_generator.num_out_layers = 1
60
+ print(f"Loaded subject basis generator for '{self.subject_string}'.")
61
+ print(repr(self.subj_basis_generator))
62
+ self.subj_basis_generator.to(self.device)
63
+ if self.is_training:
64
+ self.subj_basis_generator.train()
65
+ else:
66
+ self.subj_basis_generator.eval()
67
+
68
+ def initialize_pipeline(self):
69
+ self.load_subj_basis_generator(self.adaface_ckpt_path)
70
+ # arc2face_text_encoder maps the face analysis embedding to 16 face embeddings
71
+ # in the UNet image space.
72
+ arc2face_text_encoder = CLIPTextModelWrapper.from_pretrained(
73
+ 'models/arc2face', subfolder="encoder", torch_dtype=torch.float16
74
+ )
75
+ self.arc2face_text_encoder = arc2face_text_encoder.to(self.device)
76
+
77
+ if self.use_840k_vae:
78
+ # The 840000-step vae model is slightly better in face details than the original vae model.
79
+ # https://huggingface.co/stabilityai/sd-vae-ft-mse-original
80
+ vae = AutoencoderKL.from_single_file("models/diffusers/sd-vae-ft-mse-original/vae-ft-mse-840000-ema-pruned.ckpt", torch_dtype=torch.float16)
81
+ else:
82
+ vae = None
83
+
84
+ if self.use_ds_text_encoder:
85
+ # The dreamshaper v7 finetuned text encoder follows the prompt slightly better than the original text encoder.
86
+ # https://huggingface.co/Lykon/DreamShaper/tree/main/text_encoder
87
+ text_encoder = CLIPTextModel.from_pretrained("models/ds_text_encoder", torch_dtype=torch.float16)
88
+ else:
89
+ text_encoder = None
90
+
91
+ remove_unet = False
92
+
93
+ if self.pipeline_name == "img2img":
94
+ PipelineClass = StableDiffusionImg2ImgPipeline
95
+ elif self.pipeline_name == "text2img":
96
+ PipelineClass = StableDiffusionPipeline
97
+ # pipeline_name is None means only use this instance to generate adaface embeddings, not to generate images.
98
+ elif self.pipeline_name is None:
99
+ PipelineClass = StableDiffusionPipeline
100
+ remove_unet = True
101
+ else:
102
+ raise ValueError(f"Unknown pipeline name: {self.pipeline_name}")
103
+
104
+ if os.path.isfile(self.base_model_path):
105
+ pipeline = PipelineClass.from_single_file(
106
+ self.base_model_path,
107
+ torch_dtype=torch.float16
108
+ )
109
+ else:
110
+ pipeline = PipelineClass.from_pretrained(
111
+ self.base_model_path,
112
+ torch_dtype=torch.float16,
113
+ safety_checker=None
114
+ )
115
+ print(f"Loaded pipeline from {self.base_model_path}.")
116
+
117
+ if self.use_840k_vae:
118
+ pipeline.vae = vae
119
+ print("Replaced the VAE with the 840k-step VAE.")
120
+
121
+ if self.use_ds_text_encoder:
122
+ pipeline.text_encoder = text_encoder
123
+ print("Replaced the text encoder with the DreamShaper text encoder.")
124
+
125
+ if remove_unet:
126
+ # Remove unet and vae to release RAM. Only keep tokenizer and text_encoder.
127
+ pipeline.unet = None
128
+ pipeline.vae = None
129
+ print("Removed UNet and VAE from the pipeline.")
130
+
131
+ noise_scheduler = DDIMScheduler(
132
+ num_train_timesteps=1000,
133
+ beta_start=0.00085,
134
+ beta_end=0.012,
135
+ beta_schedule="scaled_linear",
136
+ clip_sample=False,
137
+ set_alpha_to_one=False,
138
+ steps_offset=1,
139
+ )
140
+
141
+ pipeline.scheduler = noise_scheduler
142
+ self.pipeline = pipeline.to(self.device)
143
+ # FaceAnalysis will try to find the ckpt in: models/insightface/models/antelopev2.
144
+ # Note there's a second "model" in the path.
145
+ self.face_app = FaceAnalysis(name='antelopev2', root='models/insightface', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
146
+ self.face_app.prepare(ctx_id=0, det_size=(512, 512))
147
+ # Patch the missing tokenizer in the subj_basis_generator.
148
+ if not hasattr(self.subj_basis_generator, 'clip_tokenizer'):
149
+ self.subj_basis_generator.clip_tokenizer = self.pipeline.tokenizer
150
+ print("Patched the missing tokenizer in the subj_basis_generator.")
151
+
152
+ def extend_tokenizer_and_text_encoder(self):
153
+ if self.num_vectors < 1:
154
+ raise ValueError(f"num_vectors has to be larger or equal to 1, but is {self.num_vectors}")
155
+
156
+ tokenizer = self.pipeline.tokenizer
157
+ # Add z0, z1, z2, ..., z15.
158
+ self.placeholder_tokens = []
159
+ for i in range(0, self.num_vectors):
160
+ self.placeholder_tokens.append(f"{self.subject_string}_{i}")
161
+
162
+ self.placeholder_tokens_str = " ".join(self.placeholder_tokens)
163
+
164
+ # Add the new tokens to the tokenizer.
165
+ num_added_tokens = tokenizer.add_tokens(self.placeholder_tokens)
166
+ if num_added_tokens != self.num_vectors:
167
+ raise ValueError(
168
+ f"The tokenizer already contains the token {self.subject_string}. Please pass a different"
169
+ " `subject_string` that is not already in the tokenizer.")
170
+
171
+ print(f"Added {num_added_tokens} tokens ({self.placeholder_tokens_str}) to the tokenizer.")
172
+
173
+ # placeholder_token_ids: [49408, ..., 49423].
174
+ self.placeholder_token_ids = tokenizer.convert_tokens_to_ids(self.placeholder_tokens)
175
+ # print(self.placeholder_token_ids)
176
+ # Resize the token embeddings as we are adding new special tokens to the tokenizer
177
+ old_weight = self.pipeline.text_encoder.get_input_embeddings().weight
178
+ self.pipeline.text_encoder.resize_token_embeddings(len(tokenizer))
179
+ new_weight = self.pipeline.text_encoder.get_input_embeddings().weight
180
+ print(f"Resized text encoder token embeddings from {old_weight.shape} to {new_weight.shape} on {new_weight.device}.")
181
+
182
+ # Extend pipeline.text_encoder with the adaface subject embeddings.
183
+ # subj_embs: [16, 768].
184
+ def update_text_encoder_subj_embs(self, subj_embs):
185
+ # Set the embeddings of the newly added placeholder tokens to the computed subject embeddings.
186
+ token_embeds = self.pipeline.text_encoder.get_input_embeddings().weight.data
187
+ with torch.no_grad():
188
+ for i, token_id in enumerate(self.placeholder_token_ids):
189
+ token_embeds[token_id] = subj_embs[i]
190
+ print(f"Updated {len(self.placeholder_token_ids)} tokens ({self.placeholder_tokens_str}) in the text encoder.")
191
+
192
+ def update_prompt(self, prompt):
193
+ # If the placeholder tokens are already in the prompt, then return the prompt as is.
194
+ if self.placeholder_tokens_str in prompt:
195
+ return prompt
196
+
197
+ # If the subject string 'z' is not in the prompt, then simply prepend the placeholder tokens to the prompt.
198
+ if re.search(r'\b' + self.subject_string + r'\b', prompt) is None:
199
+ print(f"Subject string '{self.subject_string}' not found in the prompt. Adding it.")
200
+ comp_prompt = self.placeholder_tokens_str + " " + prompt
201
+ else:
202
+ # Replace the subject string 'z' with the placeholder tokens.
203
+ comp_prompt = re.sub(r'\b' + self.subject_string + r'\b', self.placeholder_tokens_str, prompt)
204
+ return comp_prompt
205
+
206
+ # image_paths: a list of image paths. image_folder: the parent folder name.
207
+ def generate_adaface_embeddings(self, image_paths, image_folder=None,
208
+ pre_face_embs=None, gen_rand_face=False,
209
+ out_id_embs_scale=1., noise_level=0, update_text_encoder=True):
210
+ # faceid_embeds is a batch of extracted face analysis embeddings (BS * 512 = id_batch_size * 512).
211
+ # If extract_faceid_embeds is True, faceid_embeds is *the same* embedding repeated id_batch_size times.
212
+ # Otherwise, faceid_embeds is a batch of random embeddings, each instance being different.
213
+ # The same applies to id_prompt_emb.
214
+ # faceid_embeds lives in the face analysis embedding space; id_prompt_emb lives in the image prompt space.
215
+ # Here id_batch_size = 1, so
216
+ # faceid_embeds: [1, 512]. NOT used later.
217
+ # id_prompt_emb: [1, 16, 768].
218
+ # NOTE: Since return_core_id_embs is True, id_prompt_emb is only the 16 core ID embeddings.
219
+ # arc2face prompt template: "photo of a id person"
220
+ # ID embeddings start from "id person ...". So there are 3 template tokens before the 16 ID embeddings.
221
+ face_image_count, faceid_embeds, id_prompt_emb \
222
+ = get_arc2face_id_prompt_embs(self.face_app, self.pipeline.tokenizer, self.arc2face_text_encoder,
223
+ extract_faceid_embeds=not gen_rand_face,
224
+ pre_face_embs=pre_face_embs,
225
+ # image_folder is passed only for logging purpose.
226
+ # image_paths contains the paths of the images.
227
+ image_folder=image_folder, image_paths=image_paths,
228
+ images_np=None,
229
+ id_batch_size=1,
230
+ device=self.device,
231
+ # input_max_length == 22: only keep the first 22 tokens,
232
+ # including 3 template tokens and 16 ID tokens, and BOS and EOS tokens.
233
+ # The results are indistinguishable from input_max_length=77.
234
+ input_max_length=22,
235
+ noise_level=noise_level,
236
+ return_core_id_embs=True,
237
+ gen_neg_prompt=False,
238
+ verbose=True)
239
+
240
+ if face_image_count == 0:
241
+ return None
242
+
243
+ # adaface_subj_embs: [1, 1, 16, 768].
244
+ # adaface_prompt_embs: [1, 77, 768] (not used).
245
+ adaface_subj_embs, adaface_prompt_embs = \
246
+ self.subj_basis_generator(id_prompt_emb, None, None,
247
+ out_id_embs_scale=out_id_embs_scale,
248
+ is_face=True, is_training=False,
249
+ adaface_prompt_embs_inf_type='full_half_pad')
250
+ # adaface_subj_embs: [16, 768]
251
+ adaface_subj_embs = adaface_subj_embs.squeeze()
252
+ if update_text_encoder:
253
+ self.update_text_encoder_subj_embs(adaface_subj_embs)
254
+ return adaface_subj_embs
255
+
256
+ def encode_prompt(self, prompt, negative_prompt=None, device="cuda", verbose=False):
257
+ if negative_prompt is None:
258
+ negative_prompt = self.negative_prompt
259
+
260
+ prompt = self.update_prompt(prompt)
261
+ if verbose:
262
+ print(f"Prompt: {prompt}")
263
+
264
+ # For some unknown reason, the text_encoder is still on CPU after self.pipeline.to(self.device).
265
+ # So we manually move it to GPU here.
266
+ self.pipeline.text_encoder.to(device)
267
+ # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
268
+ prompt_embeds_, negative_prompt_embeds_ = \
269
+ self.pipeline.encode_prompt(prompt, device=device, num_images_per_prompt=1,
270
+ do_classifier_free_guidance=True, negative_prompt=negative_prompt)
271
+ return prompt_embeds_, negative_prompt_embeds_
272
+
273
+ # ref_img_strength is used only in the img2img pipeline.
274
+ def forward(self, noise, prompt, negative_prompt=None, guidance_scale=4.0,
275
+ out_image_count=4, ref_img_strength=0.8, generator=None, verbose=False):
276
+ if negative_prompt is None:
277
+ negative_prompt = self.negative_prompt
278
+ # prompt_embeds_, negative_prompt_embeds_: [1, 77, 768]
279
+ prompt_embeds_, negative_prompt_embeds_ = self.encode_prompt(prompt, negative_prompt, device=self.device, verbose=verbose)
280
+ # Repeat the prompt embeddings for all images in the batch.
281
+ prompt_embeds_ = prompt_embeds_.repeat(out_image_count, 1, 1)
282
+ negative_prompt_embeds_ = negative_prompt_embeds_.repeat(out_image_count, 1, 1)
283
+ noise = noise.to(self.device).to(torch.float16)
284
+
285
+ # noise: [BS, 4, 64, 64]
286
+ # When the pipeline is text2img, strength is ignored.
287
+ images = self.pipeline(image=noise,
288
+ prompt_embeds=prompt_embeds_,
289
+ negative_prompt_embeds=negative_prompt_embeds_,
290
+ num_inference_steps=self.num_inference_steps,
291
+ guidance_scale=guidance_scale,
292
+ num_images_per_prompt=1,
293
+ strength=ref_img_strength,
294
+ generator=generator).images
295
+ # images: [BS, 3, 512, 512]
296
+ return images
297
+
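A minimal end-to-end sketch of how AdaFaceWrapper is typically driven, mirroring the calls made in app.py below; the model paths are the ones shipped in this repo, while "subject.jpg" is a placeholder for the user's reference photo.

import torch
from adaface.adaface_wrapper import AdaFaceWrapper

device = "cuda" if torch.cuda.is_available() else "cpu"
adaface = AdaFaceWrapper(pipeline_name="text2img",
                         base_model_path="models/sar/sar.safetensors",
                         adaface_ckpt_path="models/adaface/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt",
                         device=device)

# Extract the subject's ID from reference photos and patch the 16 placeholder tokens into the text encoder.
subj_embs = adaface.generate_adaface_embeddings(image_paths=["subject.jpg"],   # placeholder path
                                                out_id_embs_scale=4.0, update_text_encoder=True)
assert subj_embs is not None, "No face detected in the reference images."

# Generate 2 images; the subject string 'z' in the prompt is replaced by the placeholder tokens.
generator = torch.Generator(device=device).manual_seed(0)
noise = torch.randn(2, 3, 512, 512, device=device, generator=generator)
images = adaface(noise, "portrait of a z person walking on the beach",
                 guidance_scale=4.0, out_image_count=2, generator=generator)
images[0].save("out_0.png")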
app.py ADDED
@@ -0,0 +1,203 @@
1
+ import sys
2
+ sys.path.append('./')
3
+
4
+ from adaface.adaface_wrapper import AdaFaceWrapper
5
+ import torch
6
+ from insightface.app import FaceAnalysis
7
+ from PIL import Image
8
+ import numpy as np
9
+ import random
10
+
11
+ import gradio as gr
12
+ import spaces
13
+ import argparse
14
+ parser = argparse.ArgumentParser()
15
+ parser.add_argument('--adaface_ckpt_path', type=str,
16
+ default='models/adaface/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt')
17
+ parser.add_argument('--gpu', type=int, default=None)
18
+ parser.add_argument('--ip', type=str, default="0.0.0.0")
19
+ args = parser.parse_args()
20
+
21
+ # global variable
22
+ MAX_SEED = np.iinfo(np.int32).max
23
+ if torch.cuda.is_available():
24
+ device = "cuda" if args.gpu is None else f"cuda:{args.gpu}"
25
+ else:
26
+ device = "cpu"
27
+ dtype = torch.float16
28
+
29
+ # base_model_path provides the base Stable Diffusion weights; the text2img pipeline built from it generates the images below.
30
+ adaface = AdaFaceWrapper(pipeline_name="text2img", base_model_path="models/sar/sar.safetensors",
31
+ adaface_ckpt_path=args.adaface_ckpt_path, device=device)
32
+
33
+ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
34
+ if randomize_seed:
35
+ seed = random.randint(0, MAX_SEED)
36
+ return seed
37
+
38
+ def swap_to_gallery(images):
39
+ # Update uploaded_files_gallery, show files, hide clear_button_column
40
+ # Or:
41
+ # Update uploaded_init_img_gallery, show init_img_files, hide init_clear_button_column
42
+ return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(value=images, visible=False)
43
+
44
+ def remove_back_to_files():
45
+ # Hide uploaded_files_gallery, show clear_button_column, hide files, reset init_img_selected_idx
46
+ # Or:
47
+ # Hide uploaded_init_img_gallery, hide init_clear_button_column, show init_img_files, reset init_img_selected_idx
48
+ return gr.update(visible=False), gr.update(visible=False), gr.update(value=None, visible=True)
49
+
50
+ def update_out_gallery(images):
51
+ #rows = (len(images) + 1) // 2 # Calculate the number of rows needed
52
+ return gr.update(height=600)
53
+
54
+ @spaces.GPU
55
+ def generate_image(image_paths, guidance_scale, adaface_id_cfg_scale,
56
+ num_images, prompt, negative_prompt, seed, progress=gr.Progress(track_tqdm=True)):
57
+
58
+ if image_paths is None or len(image_paths) == 0:
59
+ raise gr.Error(f"Cannot find any input face image! Please upload a face image.")
60
+
61
+ if prompt is None:
62
+ prompt = ""
63
+
64
+ adaface_subj_embs = \
65
+ adaface.generate_adaface_embeddings(image_folder=None, image_paths=image_paths,
66
+ out_id_embs_scale=adaface_id_cfg_scale, update_text_encoder=True)
67
+
68
+ if adaface_subj_embs is None:
69
+ raise gr.Error(f"Failed to detect any faces! Please try with other images")
70
+
71
+ generator = torch.Generator(device=device).manual_seed(seed)
72
+ print(f"Manual seed: {seed}")
73
+ # Generate num_images images for the user to select from.
74
+ noise = torch.randn(num_images, 3, 512, 512, device=device, generator=generator)
75
+ #print(noise.abs().sum())
76
+ # samples: A list of PIL Image instances.
77
+ samples = adaface(noise, prompt, negative_prompt, guidance_scale=guidance_scale, out_image_count=num_images, generator=generator, verbose=True)
78
+ return samples
79
+
80
+ ### Description
81
+ title = r"""
82
+ <h1>AdaFace: A Versatile Face Encoder for Zero-Shot Diffusion Model Personalization</h1>
83
+ """
84
+
85
+ description = r"""
86
+ <b>Official demo</b> for our NeurIPS 2024 submission <b>AdaFace: A Versatile Face Encoder for Zero-Shot Diffusion Model Personalization</b>.<br>
87
+
88
+ ❗️**Tips**❗️
89
+ 1. Upload one or more images of a person. If multiple faces are detected, we use the largest one.
90
+ 2. Increase <b>AdaFace CFG Scale</b> (preferred) and/or <b>Guidance scale</b> to highlight fine facial features.
91
+ 3. AdaFace Text-to-Video: <a href="https://huggingface.co/spaces/adaface-neurips/adaface-animate" style="display: inline-flex; align-items: center;">
92
+ AdaFace-Animate
93
+ <img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-yellow" alt="Hugging Face Spaces" style="margin-left: 5px;">
94
+ </a>
95
+
96
+ **TODO**
97
+ - ControlNet integration.
98
+ """
99
+
100
+ css = '''
101
+ .gradio-container {width: 85% !important}
102
+ '''
103
+ with gr.Blocks(css=css) as demo:
104
+
105
+ # description
106
+ gr.Markdown(title)
107
+ gr.Markdown(description)
108
+
109
+ with gr.Row():
110
+ with gr.Column():
111
+
112
+ # upload face image
113
+ # img_file = gr.Image(label="Upload a photo with a face", type="filepath")
114
+ img_files = gr.File(
115
+ label="Drag / Select 1 or more photos of a person's face",
116
+ file_types=["image"],
117
+ file_count="multiple"
118
+ )
119
+ uploaded_files_gallery = gr.Gallery(label="Subject images", visible=False, columns=3, rows=1, height=300)
120
+ with gr.Column(visible=False) as clear_button_column:
121
+ remove_and_reupload = gr.ClearButton(value="Remove and upload subject images", components=img_files, size="sm")
122
+
123
+ prompt = gr.Dropdown(label="Prompt",
124
+ info="Try something like 'man/woman walking on the beach'. If the face is not in focus, try adding 'face portrait of' at the beginning.",
125
+ value=None,
126
+ allow_custom_value=True,
127
+ filterable=False,
128
+ choices=[
129
+ "woman ((best quality)), ((masterpiece)), ((realistic)), long highlighted hair, futuristic silver armor suit, confident stance, high-resolution, living room, smiling, head tilted, perfect smooth skin",
130
+ "woman walking on the beach, sunset, orange sky",
131
+ "woman in a white apron and chef hat, garnishing a gourmet dish, full body view, long shot",
132
+ "woman dancing pose among folks in a park, waving hands",
133
+ "woman in iron man costume flying pose, the sky ablaze with hues of orange and purple, full body view, long shot",
134
+ "woman jedi wielding a lightsaber, star wars, full body view, eye level shot",
135
+ "woman playing guitar on a boat, ocean waves",
136
+ "woman with a passion for reading, curled up with a book in a cozy nook near a window",
137
+ "woman running pose in a park, eye level shot",
138
+ "woman in superman costume flying pose, the sky ablaze with hues of orange and purple, full body view, long shot"
139
+ ])
140
+
141
+ submit = gr.Button("Submit", variant="primary")
142
+
143
+ negative_prompt = gr.Textbox(
144
+ label="Negative Prompt",
145
+ value="flaws in the eyes, flaws in the face, lowres, non-HDRi, low quality, worst quality, artifacts, noise, text, watermark, glitch, mutated, ugly, disfigured, hands, partially rendered objects, partially rendered eyes, deformed eyeballs, cross-eyed, blurry, mutation, duplicate, out of frame, cropped, mutilated, bad anatomy, deformed, bad proportions, nude, naked, nsfw, topless, bare breasts",
146
+ )
147
+
148
+ adaface_id_cfg_scale = gr.Slider(
149
+ label="AdaFace CFG Scale",
150
+ info="The CFG scale of the AdaFace ID embeddings (influencing fine facial features)",
151
+ minimum=0.5,
152
+ maximum=8.0,
153
+ step=0.5,
154
+ value=4.0,
155
+ )
156
+
157
+ guidance_scale = gr.Slider(
158
+ label="Guidance scale",
159
+ minimum=0.5,
160
+ maximum=8.0,
161
+ step=0.5,
162
+ value=4.0,
163
+ )
164
+
165
+ num_images = gr.Slider(
166
+ label="Number of output images",
167
+ minimum=1,
168
+ maximum=6,
169
+ step=1,
170
+ value=4,
171
+ )
172
+ seed = gr.Slider(
173
+ label="Seed",
174
+ minimum=0,
175
+ maximum=MAX_SEED,
176
+ step=1,
177
+ value=0,
178
+ )
179
+ randomize_seed = gr.Checkbox(label="Randomize seed", value=True, info="Uncheck for reproducible results")
180
+
181
+ with gr.Column():
182
+ out_gallery = gr.Gallery(label="Generated Images", columns=2, rows=2, height=600)
183
+
184
+ img_files.upload(fn=swap_to_gallery, inputs=img_files, outputs=[uploaded_files_gallery, clear_button_column, img_files])
185
+ remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files_gallery, clear_button_column, img_files])
186
+
187
+ submit.click(
188
+ fn=randomize_seed_fn,
189
+ inputs=[seed, randomize_seed],
190
+ outputs=seed,
191
+ queue=False,
192
+ api_name=False,
193
+ ).then(
194
+ fn=generate_image,
195
+ inputs=[img_files, guidance_scale, adaface_id_cfg_scale, num_images, prompt, negative_prompt, seed],
196
+ outputs=[out_gallery]
197
+ ).then(
198
+ fn=update_out_gallery,
199
+ inputs=[out_gallery],
200
+ outputs=[out_gallery]
201
+ )
202
+
203
+ demo.launch(share=True, server_name=args.ip, ssl_verify=False)
arc2face_models.py ADDED
@@ -0,0 +1,303 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import CLIPTextModel
4
+ from transformers.models.clip.modeling_clip import CLIPAttention
5
+ from typing import Any, Callable, Dict, Optional, Tuple, Union, List
6
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
7
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
8
+ # from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask
9
+ _make_causal_mask = AttentionMaskConverter._make_causal_mask
10
+ _expand_mask = AttentionMaskConverter._expand_mask
11
+
12
+ from adaface.util import add_noise_to_tensor
13
+
14
+ # Extend CLIPAttention by using multiple k_proj and v_proj in each head.
15
+ # To avoid too much increase of computation, we don't extend q_proj.
16
+ class CLIPAttentionMKV(nn.Module):
17
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
18
+
19
+ def __init__(self, config, multiplier=2):
20
+ super().__init__()
21
+ self.config = config
22
+ self.embed_dim = config.hidden_size
23
+ self.num_heads = config.num_attention_heads
24
+ self.head_dim = self.embed_dim // self.num_heads
25
+ if self.head_dim * self.num_heads != self.embed_dim:
26
+ raise ValueError(
27
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
28
+ f" {self.num_heads})."
29
+ )
30
+ self.scale = self.head_dim**-0.5
31
+ self.dropout = config.attention_dropout
32
+ self.multiplier = multiplier
33
+
34
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim * self.multiplier)
35
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim * self.multiplier)
36
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
37
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
38
+
39
+ # The (approximately) repeated token features are repeated along the last dim in tensor
40
+ # (multiplier * num_heads * head_dim), and then reshaped to (bsz, -1, num_heads, head_dim).
41
+ # Therefore, the "multiplier" dim is tucked into the seq_len dim, which looks like
42
+ # [token1_emb, token1_emb, token2_emb, token2_emb, ..., tokenN_emb, tokenN_emb].
43
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
44
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
45
+
46
+ def extend_weights(self, clip_attn_layer, layer_idx, multiplier, noise_std=0.1,
47
+ noise_std_is_relative=True, keep_norm=False, verbose=False):
48
+ self.multiplier *= multiplier
49
+ # q_proj and out_proj are the same as the original CLIPAttention.
50
+ self.q_proj.weight.data = clip_attn_layer.q_proj.weight.data.clone()
51
+ self.q_proj.bias.data = clip_attn_layer.q_proj.bias.data.clone()
52
+ self.out_proj.weight.data = clip_attn_layer.out_proj.weight.data.clone()
53
+ self.out_proj.bias.data = clip_attn_layer.out_proj.bias.data.clone()
54
+
55
+ # bias doesn't need noise perturbation, as after the weights are noised,
56
+ # different copies of the weight/bias will receive different gradients,
57
+ # making the bias terms diverge and become identifiable after training.
58
+ self.v_proj.bias.data = clip_attn_layer.v_proj.bias.data.repeat(multiplier)
59
+ self.k_proj.bias.data = clip_attn_layer.k_proj.bias.data.repeat(multiplier)
60
+
61
+ self.v_proj.weight.data = clip_attn_layer.v_proj.weight.data.repeat(multiplier, 1)
62
+ self.k_proj.weight.data = clip_attn_layer.k_proj.weight.data.repeat(multiplier, 1)
63
+
64
+ if noise_std > 0:
65
+ ORIG_V_SHAPE = list(clip_attn_layer.v_proj.weight.shape)
66
+ ORIG_V_SHAPE_D0 = ORIG_V_SHAPE[0]
67
+ # Adding noise to the extra copies of the weights (keep the first copy unchanged).
68
+ self.v_proj.weight.data[ORIG_V_SHAPE_D0:] = \
69
+ add_noise_to_tensor(self.v_proj.weight.data[ORIG_V_SHAPE_D0:],
70
+ noise_std, noise_std_is_relative, keep_norm)
71
+ if verbose:
72
+ NEW_V_SHAPE = list(self.v_proj.weight.shape)
73
+ NOISED_V_SHAPE = list(self.v_proj.weight.data[ORIG_V_SHAPE_D0:].shape)
74
+ print(f"Layer {layer_idx}: {noise_std} noise added to {NOISED_V_SHAPE} (within {NEW_V_SHAPE}) of v_proj")
75
+
76
+ ORIG_K_SHAPE = list(clip_attn_layer.k_proj.weight.shape)
77
+ ORIG_K_SHAPE_D0 = ORIG_K_SHAPE[0]
78
+ # Adding noise to the extra copies of the weights.
79
+ self.k_proj.weight.data[ORIG_K_SHAPE_D0:] = \
80
+ add_noise_to_tensor(self.k_proj.weight.data[ORIG_K_SHAPE_D0:],
81
+ noise_std, noise_std_is_relative, keep_norm)
82
+ if verbose:
83
+ NEW_K_SHAPE = list(self.k_proj.weight.shape)
84
+ NOISED_K_SHAPE = list(self.k_proj.weight.data[ORIG_K_SHAPE_D0:].shape)
85
+ print(f"Layer {layer_idx}: {noise_std} noise added to {NOISED_K_SHAPE} (within {NEW_K_SHAPE}) of k_proj")
86
+
87
+ def forward(
88
+ self,
89
+ hidden_states: torch.Tensor,
90
+ attention_mask: Optional[torch.Tensor] = None,
91
+ causal_attention_mask: Optional[torch.Tensor] = None,
92
+ output_attentions: Optional[bool] = False,
93
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
94
+ """Input shape: Batch x Time x Channel"""
95
+
96
+ bsz, tgt_len, embed_dim = hidden_states.size()
97
+
98
+ query_states = self.q_proj(hidden_states) * self.scale
99
+ # For key_states and value_states, the multiplier is absorbed into the seq_len (dim 1, shape specified as -1).
100
+ # [token0_head_emb, token0_head_emb, token1_head_emb, token1_head_emb, ..., tokenN-1_head_emb, tokenN-1_head_emb].
101
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
102
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
103
+
104
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
105
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
106
+ key_states = key_states.view(*proj_shape)
107
+ value_states = value_states.view(*proj_shape)
108
+
109
+ src_len = key_states.size(1)
110
+ # src_len0 is the original src_len without the multiplier.
111
+ src_len0 = src_len // self.multiplier
112
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
113
+
114
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
115
+ raise ValueError(
116
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
117
+ f" {attn_weights.size()}"
118
+ )
119
+
120
+ # apply the causal_attention_mask first
121
+ if causal_attention_mask is not None:
122
+ if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len0):
123
+ raise ValueError(
124
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len0)}, but is"
125
+ f" {causal_attention_mask.size()}"
126
+ )
127
+ # The last dim of attn_weights corresponds to [token0, token0, token1, token1, ..., tokenN-1, tokenN-1].
128
+ # If reshaping it as (self.multiplier, src_len0), it will become
129
+ # [[token0, token0, token1, token1, ..., tokenN//2], [tokenN//2+1, tokenN//2+1, ..., tokenN-1, tokenN-1]],
130
+ # and the mask will be applied to wrong elements.
131
+ # If reshaping it as (src_len0, self.multiplier), it will become
132
+ # [[token0, token1, ..., tokenN-1], [token0, token1, ..., tokenN-1]], and then
133
+ # the mask at element i will mask all the multiplier elements at i, which is desired.
134
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len0, self.multiplier) + causal_attention_mask.unsqueeze(4)
135
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
136
+
137
+ if attention_mask is not None:
138
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len0):
139
+ raise ValueError(
140
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len0)}, but is {attention_mask.size()}"
141
+ )
142
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len0, self.multiplier) + attention_mask.unsqueeze(4)
143
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
144
+
145
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
146
+
147
+ if output_attentions:
148
+ # this operation is a bit awkward, but it's required to
149
+ # make sure that attn_weights keeps its gradient.
150
+ # In order to do so, attn_weights have to reshaped
151
+ # twice and have to be reused in the following
152
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
153
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
154
+ else:
155
+ attn_weights_reshaped = None
156
+
157
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
158
+
159
+ attn_output = torch.bmm(attn_probs, value_states)
160
+
161
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
162
+ raise ValueError(
163
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
164
+ f" {attn_output.size()}"
165
+ )
166
+
167
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
168
+ attn_output = attn_output.transpose(1, 2)
169
+ attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
170
+
171
+ attn_output = self.out_proj(attn_output)
172
+
173
+ return attn_output, attn_weights_reshaped
174
+
175
+ class CLIPTextModelWrapper(CLIPTextModel):
176
+ # Adapted from https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/clip/modeling_clip.py#L812
177
+ # Modified to accept precomputed token embeddings "input_token_embs" as input or calculate them from input_ids and return them.
178
+ def forward(
179
+ self,
180
+ input_ids: Optional[torch.Tensor] = None,
181
+ attention_mask: Optional[torch.Tensor] = None,
182
+ position_ids: Optional[torch.Tensor] = None,
183
+ output_attentions: Optional[bool] = None,
184
+ output_hidden_states: Optional[bool] = None,
185
+ return_dict: Optional[bool] = None,
186
+ input_token_embs: Optional[torch.Tensor] = None,
187
+ hidden_state_layer_weights: Optional[torch.Tensor] = None,
188
+ return_token_embs: Optional[bool] = False,
189
+ ) -> Union[Tuple, torch.Tensor, BaseModelOutputWithPooling]:
190
+
191
+ if return_token_embs:
192
+ return self.text_model.embeddings.token_embedding(input_ids)
193
+
194
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
195
+
196
+ output_attentions = output_attentions if output_attentions is not None else self.text_model.config.output_attentions
197
+ output_hidden_states = (
198
+ output_hidden_states if output_hidden_states is not None else self.text_model.config.output_hidden_states
199
+ )
200
+ if hidden_state_layer_weights is not None:
201
+ output_hidden_states = True
202
+ return_dict = return_dict if return_dict is not None else self.text_model.config.use_return_dict
203
+
204
+ if input_ids is None:
205
+ raise ValueError("You have to specify input_ids")
206
+
207
+ input_shape = input_ids.size()
208
+ input_ids = input_ids.view(-1, input_shape[-1])
209
+
210
+ hidden_states = self.text_model.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=input_token_embs)
211
+
212
+ # CLIP's text model uses causal mask, prepare it here.
213
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
214
+ causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
215
+ # expand attention_mask
216
+ if attention_mask is not None:
217
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
218
+ attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
219
+
220
+ encoder_outputs = self.text_model.encoder(
221
+ inputs_embeds=hidden_states,
222
+ attention_mask=attention_mask,
223
+ causal_attention_mask=causal_attention_mask,
224
+ output_attentions=output_attentions,
225
+ # output_hidden_states is False by default, and only True if hidden_state_layer_weights is provided.
226
+ output_hidden_states=output_hidden_states,
227
+ return_dict=return_dict,
228
+ )
229
+
230
+ # If output_hidden_states is True, then encoder_outputs[0] is last_hidden_state [1, 22, 768].
231
+ # encoder_outputs[1] is hidden_states, which is a tuple of 13 hidden states, each being [1, 22, 768].
232
+ # encoder_outputs[0] == encoder_outputs[1][12].
233
+ if hidden_state_layer_weights is None:
234
+ last_hidden_state = encoder_outputs[0]
235
+ else:
236
+ num_hidden_state_layers = len(hidden_state_layer_weights)
237
+ last_hidden_states = encoder_outputs[1][-num_hidden_state_layers:]
238
+ hidden_state_layer_weights = hidden_state_layer_weights.to(last_hidden_states[0].dtype)
239
+ # Normalize the weights to sum to 1 across layers.
240
+ # hidden_state_layer_weights: [3, 1] or [3, 768].
241
+ hidden_state_layer_weights = hidden_state_layer_weights / hidden_state_layer_weights.sum(dim=0, keepdim=True)
242
+ # [3, 1/768] -> [3, 1, 1, 1/768]
243
+ hidden_state_layer_weights = hidden_state_layer_weights.unsqueeze(1).unsqueeze(1)
244
+ # A weighted sum of last_hidden_states.
245
+ # [3, 1, 22, 768] * [3, 1, 1, 1/768] -> [3, 1, 22, 768] -> [1, 22, 768]
246
+ last_hidden_state = (torch.stack(last_hidden_states, dim=0) * hidden_state_layer_weights).sum(dim=0)
247
+
248
+ last_hidden_state = self.text_model.final_layer_norm(last_hidden_state)
249
+
250
+ # self.text_model.eos_token_id == 2 is True.
251
+ if self.text_model.eos_token_id == 2:
252
+ # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
253
+ # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
254
+ # ------------------------------------------------------------
255
+ # text_embeds.shape = [batch_size, sequence_length, transformer.width]
256
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
257
+ # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
258
+ pooled_output = last_hidden_state[
259
+ torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
260
+ input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
261
+ ]
262
+ else:
263
+ # The config gets the updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
264
+ pooled_output = last_hidden_state[
265
+ torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
266
+ # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
267
+ (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.text_model.eos_token_id)
268
+ .int()
269
+ .argmax(dim=-1),
270
+ ]
271
+
272
+ if not return_dict:
273
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
274
+
275
+ return BaseModelOutputWithPooling(
276
+ last_hidden_state=last_hidden_state,
277
+ pooler_output=pooled_output,
278
+ hidden_states=encoder_outputs.hidden_states,
279
+ attentions=encoder_outputs.attentions,
280
+ )
281
+
282
+ # Applied to layers [begin_layer_idx, end_layer_idx) in the encoder.
283
+ # The layer indexed by end_layer_idx is not included.
284
+ # If both layer indices are -1, then apply to all layers (0-11).
285
+ def extend_clip_attention_MKV_multiplier(self, begin_layer_idx=-1, end_layer_idx=-1, multiplier=2, noise_std=0.1):
286
+ num_extended_layers = 0
287
+
288
+ for layer_idx, layer in enumerate(self.text_model.encoder.layers):
289
+ if begin_layer_idx >= 0 and layer_idx < begin_layer_idx:
290
+ continue
291
+ if end_layer_idx >= 0 and layer_idx >= end_layer_idx:
292
+ break
293
+ # This shouldn't happen, unless self_attn has already been extended as CLIPAttentionMKV.
294
+ if not isinstance(layer.self_attn, (CLIPAttention, CLIPAttentionMKV)):
295
+ breakpoint()
296
+ old_attn_layer = layer.self_attn
297
+ if not isinstance(old_attn_layer, CLIPAttentionMKV):
298
+ layer.self_attn = CLIPAttentionMKV(old_attn_layer.config, 1)
299
+ layer.self_attn.extend_weights(old_attn_layer, layer_idx, multiplier, noise_std, verbose=True)
300
+ num_extended_layers += 1
301
+
302
+ return num_extended_layers
303
+
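A brief sketch (not something the repo does at inference time) of how the K/V multiplier extension above could be applied to the arc2face text encoder shipped in this repo; the layer range and noise level are illustrative choices.

import torch
from adaface.arc2face_models import CLIPTextModelWrapper

text_encoder = CLIPTextModelWrapper.from_pretrained(
    "models/arc2face", subfolder="encoder", torch_dtype=torch.float16)

# Double the K/V projections of encoder layers 8-11; the extra copies are perturbed with
# 10% relative noise so that they can diverge during subsequent fine-tuning.
n = text_encoder.extend_clip_attention_MKV_multiplier(begin_layer_idx=8, end_layer_idx=12,
                                                      multiplier=2, noise_std=0.1)
print(f"Extended {n} self-attention layers.")   # each one now attends over 2 copies of every key/value token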
models/adaface/subjects-celebrity2024-05-16T17-22-46_zero3-ada-30000.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4aa1eb9ff3e364ea1b9db6dfff0c281ff3b57864d7ccc4c64d5f29ed752484f3
3
+ size 821700521
models/arc2face/arc2face/config.json ADDED
@@ -0,0 +1,67 @@
1
+ {
2
+ "_class_name": "UNet2DConditionModel",
3
+ "_diffusers_version": "0.22.0",
4
+ "act_fn": "silu",
5
+ "addition_embed_type": null,
6
+ "addition_embed_type_num_heads": 64,
7
+ "addition_time_embed_dim": null,
8
+ "attention_head_dim": 8,
9
+ "attention_type": "default",
10
+ "block_out_channels": [
11
+ 320,
12
+ 640,
13
+ 1280,
14
+ 1280
15
+ ],
16
+ "center_input_sample": false,
17
+ "class_embed_type": null,
18
+ "class_embeddings_concat": false,
19
+ "conv_in_kernel": 3,
20
+ "conv_out_kernel": 3,
21
+ "cross_attention_dim": 768,
22
+ "cross_attention_norm": null,
23
+ "down_block_types": [
24
+ "CrossAttnDownBlock2D",
25
+ "CrossAttnDownBlock2D",
26
+ "CrossAttnDownBlock2D",
27
+ "DownBlock2D"
28
+ ],
29
+ "downsample_padding": 1,
30
+ "dropout": 0.0,
31
+ "dual_cross_attention": false,
32
+ "encoder_hid_dim": null,
33
+ "encoder_hid_dim_type": null,
34
+ "flip_sin_to_cos": true,
35
+ "freq_shift": 0,
36
+ "in_channels": 4,
37
+ "layers_per_block": 2,
38
+ "mid_block_only_cross_attention": null,
39
+ "mid_block_scale_factor": 1,
40
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
41
+ "norm_eps": 1e-05,
42
+ "norm_num_groups": 32,
43
+ "num_attention_heads": null,
44
+ "num_class_embeds": null,
45
+ "only_cross_attention": false,
46
+ "out_channels": 4,
47
+ "projection_class_embeddings_input_dim": null,
48
+ "resnet_out_scale_factor": 1.0,
49
+ "resnet_skip_time_act": false,
50
+ "resnet_time_scale_shift": "default",
51
+ "reverse_transformer_layers_per_block": null,
52
+ "sample_size": 64,
53
+ "time_cond_proj_dim": null,
54
+ "time_embedding_act_fn": null,
55
+ "time_embedding_dim": null,
56
+ "time_embedding_type": "positional",
57
+ "timestep_post_act": null,
58
+ "transformer_layers_per_block": 1,
59
+ "up_block_types": [
60
+ "UpBlock2D",
61
+ "CrossAttnUpBlock2D",
62
+ "CrossAttnUpBlock2D",
63
+ "CrossAttnUpBlock2D"
64
+ ],
65
+ "upcast_attention": false,
66
+ "use_linear_projection": false
67
+ }
models/arc2face/arc2face/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d2377c16b7135650ca375817a4812a999194fba1f081e39117bd54e50dacc784
3
+ size 3438167536
models/arc2face/encoder/config.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "architectures": [
3
+ "CLIPTextModel"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "bos_token_id": 0,
7
+ "dropout": 0.0,
8
+ "eos_token_id": 2,
9
+ "hidden_act": "quick_gelu",
10
+ "hidden_size": 768,
11
+ "initializer_factor": 1.0,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 3072,
14
+ "layer_norm_eps": 1e-05,
15
+ "max_position_embeddings": 77,
16
+ "model_type": "clip_text_model",
17
+ "num_attention_heads": 12,
18
+ "num_hidden_layers": 12,
19
+ "pad_token_id": 1,
20
+ "projection_dim": 768,
21
+ "torch_dtype": "float32",
22
+ "transformers_version": "4.34.1",
23
+ "vocab_size": 49408
24
+ }
models/arc2face/encoder/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e2d364df774b7d3975f85de42bda73c0c0cdb952273dd5f138511b6cf65424aa
3
+ size 492308829
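This encoder folder is a CLIP text model checkpoint (see the architectures field in the config above); the code later in this upload wraps it with CLIPTextModelWrapper from arc2face_models.py. A minimal sketch of loading it with plain transformers, assuming only the local path shown here:

from transformers import CLIPTextModel

# The Arc2Face-finetuned CLIP text encoder; the project itself loads it
# through its CLIPTextModelWrapper subclass instead of the base class.
text_encoder = CLIPTextModel.from_pretrained("models/arc2face/encoder")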
models/insightface/models/antelopev2/1k3d68.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:df5c06b8a0c12e422b2ed8947b8869faa4105387f199c477af038aa01f9a45cc
+ size 143607619
models/insightface/models/antelopev2/2d106det.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf
+ size 5030888
models/insightface/models/antelopev2/arcface.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ec639a0429b4819130d1405a2d3b38beaa4cc4a6c5bd9cf48b94fdf65461de83
+ size 260694151
models/insightface/models/antelopev2/genderage.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb
+ size 1322532
models/insightface/models/antelopev2/scrfd_10g_bnkps.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91
+ size 16923827
models/insightface/models/buffalo_l/1k3d68.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:df5c06b8a0c12e422b2ed8947b8869faa4105387f199c477af038aa01f9a45cc
+ size 143607619
models/insightface/models/buffalo_l/2d106det.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf
+ size 5030888
models/insightface/models/buffalo_l/det_10g.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91
+ size 16923827
models/insightface/models/buffalo_l/genderage.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb
+ size 1322532
models/insightface/models/buffalo_l/w600k_r50.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4c06341c33c2ca1f86781dab0e829f88ad5b64be9fba56e56bc9ebdefc619e43
+ size 174383860
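The ONNX files above are the standard insightface detection/recognition packs (antelopev2 and buffalo_l). A rough initialization sketch, assuming insightface's FaceAnalysis API and the models/insightface root used in this upload; a face_app object like this is what util.py's get_arc2face_id_prompt_embs() expects:

from insightface.app import FaceAnalysis

# Point insightface at the local model root; it looks under <root>/models/<name>.
face_app = FaceAnalysis(name="antelopev2", root="models/insightface",
                        providers=["CPUExecutionProvider"])
face_app.prepare(ctx_id=0, det_size=(512, 512))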
models/sar/sar.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:35a5d7615850879ffecce7b1e463ae0317c95fe784dd9b179793b58531a9e3ab
+ size 2299982596
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ torch
+ torchvision
+ einops
+ gradio
+ transformers
+ insightface
+ opencv-python
+ diffusers
+ onnx>=1.16.0
+ onnxruntime
+ safetensors
+ spaces
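These dependencies can be installed with pip install -r requirements.txt; the spaces entry is presumably the Hugging Face Spaces helper package used when the Gradio demo runs as a Space.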
subj_basis_generator.py ADDED
@@ -0,0 +1,758 @@
1
+ # Borrowed from ip-adapter resampler.py.
2
+ # https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py
3
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
4
+ # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
5
+
6
+ import math
7
+
8
+ import torch
9
+ from torch import nn
10
+ import torch.nn.functional as F
11
+ from einops import rearrange
12
+ from einops.layers.torch import Rearrange
13
+ from transformers import CLIPVisionModel, CLIPTokenizer
14
+
15
+ import numpy as np
16
+ from torch import einsum
17
+ from dataclasses import dataclass
18
+ from typing import Optional, Tuple
19
+ from transformers.utils import ModelOutput
20
+ from adaface.util import arc2face_inverse_face_prompt_embs, gen_gradient_scaler
21
+ from adaface.arc2face_models import CLIPTextModelWrapper
22
+ import sys
23
+ sys.modules['ldm'] = sys.modules['adaface']
24
+
25
+ def reshape_tensor(x, num_heads):
26
+ bs, length, width = x.shape
27
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
28
+ x = x.view(bs, length, num_heads, -1)
29
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
30
+ x = x.transpose(1, 2)
31
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head); the reshape below keeps this shape.
32
+ x = x.reshape(bs, num_heads, length, -1)
33
+ return x
34
+
35
+ # FFN. Added a Dropout layer at the end, so that it can still load the old ckpt.
36
+ def FeedForward(dim, mult=4, p_dropout=0.1):
37
+ inner_dim = int(dim * mult)
38
+ return nn.Sequential(
39
+ nn.LayerNorm(dim),
40
+ nn.Linear(dim, inner_dim, bias=False),
41
+ nn.GELU(),
42
+ nn.Linear(inner_dim, dim, bias=False),
43
+ nn.Dropout(p_dropout),
44
+ )
45
+
46
+ # IP-Adapter FaceID class. Only used in knn-faces.py.
47
+ # From: https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter_faceid_separate.py
48
+ class IP_MLPProjModel(nn.Module):
49
+ def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
50
+ super().__init__()
51
+
52
+ self.cross_attention_dim = cross_attention_dim
53
+ self.num_tokens = num_tokens
54
+
55
+ self.proj = nn.Sequential(
56
+ nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
57
+ nn.GELU(),
58
+ nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
59
+ )
60
+ self.norm = nn.LayerNorm(cross_attention_dim)
61
+
62
+ def forward(self, id_embeds):
63
+ x = self.proj(id_embeds)
64
+ x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
65
+ x = self.norm(x)
66
+ return x
67
+
68
+ # group_dim: the tensor dimension that corresponds to the multiple groups.
69
+ class LearnedSoftAggregate(nn.Module):
70
+ def __init__(self, num_feat, group_dim, keepdim=False):
71
+ super(LearnedSoftAggregate, self).__init__()
72
+ self.group_dim = group_dim
73
+ # num_feat = 1: element-wise score function & softmax.
74
+ # num_feat > 1: the linear score function is applied to the last dim (features) of the input tensor.
75
+ self.num_feat = num_feat
76
+ self.feat2score = nn.Linear(num_feat, 1, bias=False)
77
+ self.keepdim = keepdim
78
+
79
+ def forward(self, x, score_basis=None):
80
+ # If there's only one mode, do nothing.
81
+ if x.shape[self.group_dim] == 1:
82
+ if self.keepdim:
83
+ return x
84
+ else:
85
+ return x.squeeze(self.group_dim)
86
+
87
+ # Assume the last dim of x is the feature dim.
88
+ if score_basis is None:
89
+ score_basis = x
90
+
91
+ if self.num_feat == 1:
92
+ mode_scores = self.feat2score(score_basis.unsqueeze(-1)).squeeze(-1)
93
+ else:
94
+ mode_scores = self.feat2score(score_basis)
95
+ attn_probs = mode_scores.softmax(dim=self.group_dim)
96
+ x_aggr = (x * attn_probs).sum(dim=self.group_dim, keepdim=self.keepdim)
97
+ return x_aggr
98
+
99
+ def LoRA_ExpandEmbs(input_dim, lora_rank, output_dim, num_modes,
100
+ num_output_vecs, elementwise_affine=True, p_dropout=0.1):
101
+ return nn.Sequential(
102
+ # Project to [BS, lora_rank * output_dim * num_modes].
103
+ # It takes a huge param size: 512 * 32 * 768 * 4 = 50,331,648.
104
+ nn.Linear(input_dim, lora_rank * output_dim * num_modes, bias=False),
105
+ # Reshape to [BS, lora_rank, output_dim].
106
+ Rearrange('b (m q d) -> b m q d', q=lora_rank, m=num_modes, d=output_dim),
107
+ nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
108
+ # Aggregate [BS, num_modes, loar_rank, output_dim] -> [BS, lora_rank, output_dim].
109
+ LearnedSoftAggregate(num_feat=output_dim, group_dim=1, keepdim=False) if num_modes > 1 \
110
+ else Rearrange('b () q d -> b q d'),
111
+ nn.Dropout(p_dropout),
112
+ # Permute to [BS, output_dim, lora_rank].
113
+ Rearrange('b q d -> b d q'),
114
+ # Project to [BS, output_dim, num_output_vecs].
115
+ nn.Linear(lora_rank, num_output_vecs, bias=False),
116
+ # Permute to [BS, num_output_vecs, output_dim].
117
+ Rearrange('b d q -> b q d'),
118
+ nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
119
+ nn.Dropout(p_dropout),
120
+ )
121
+
122
+ def ExpandEmbs(input_dim, output_dim, expansion_ratio, elementwise_affine=True, p_dropout=0.1):
123
+ return nn.Sequential(
124
+ # Project to [BS, num_output_vecs * output_dim].
125
+ nn.Linear(input_dim, expansion_ratio * output_dim, bias=False),
126
+ # Reshape to [BS, num_output_vecs, output_dim].
127
+ Rearrange('b (e d) -> b e d', e=expansion_ratio, d=output_dim),
128
+ nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
129
+ nn.Dropout(p_dropout),
130
+ )
131
+
132
+ # Input: [BS, N, D].
133
+ def MultimodeProjection(input_dim, output_dim=-1, num_modes=4, elementwise_affine=True, p_dropout=0.1):
134
+ if output_dim == -1:
135
+ output_dim = input_dim
136
+
137
+ return nn.Sequential(
138
+ nn.Linear(input_dim, output_dim * num_modes, bias=False),
139
+ # Reshape to [BS, num_output_vecs, output_dim].
140
+ Rearrange('b n (m d) -> b n m d', m=num_modes, d=output_dim),
141
+ nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
142
+ # If num_modes == 1, then simply remove the mode dim. Otherwise, aggregate the modes.
143
+ LearnedSoftAggregate(num_feat=output_dim, group_dim=2, keepdim=False) if num_modes > 1 \
144
+ else Rearrange('b n () d -> b n d'),
145
+ nn.Dropout(p_dropout),
146
+ )
147
+
148
+ # Low-rank to high-rank transformation.
149
+ def Lora2Hira(lora_rank, hira_rank, output_dim, num_modes, elementwise_affine=True, p_dropout=0.1):
150
+ return nn.Sequential(
151
+ # Permute to [BS, output_dim, lora_rank].
152
+ Rearrange('b q d -> b d q'),
153
+ # Project to [BS, output_dim, hira_rank].
154
+ nn.Linear(lora_rank, hira_rank * num_modes, bias=False),
155
+ # Reshape and permute to [BS, num_modes, num_output_vecs, output_dim].
156
+ Rearrange('b d (m q) -> b m q d', m=num_modes, q=hira_rank),
157
+ nn.LayerNorm(output_dim, elementwise_affine=elementwise_affine),
158
+ # Aggregate [BS, num_modes, hira_rank, output_dim] -> [BS, hira_rank, output_dim].
159
+ LearnedSoftAggregate(num_feat=output_dim, group_dim=1, keepdim=False) if num_modes > 1 \
160
+ else Rearrange('b () q d -> b q d'),
161
+ nn.Dropout(p_dropout),
162
+ )
163
+
164
+ class PerceiverAttention(nn.Module):
165
+ def __init__(self, *, dim, dim_head=64, num_heads=8, elementwise_affine=True):
166
+ super().__init__()
167
+ self.scale = dim_head**-0.5
168
+ self.dim_head = dim_head
169
+ self.num_heads = num_heads
170
+ inner_dim = dim_head * num_heads
171
+
172
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=elementwise_affine)
173
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=elementwise_affine)
174
+
175
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
176
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
177
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
178
+
179
+ def forward(self, x, latent_queries):
180
+ """
181
+ Args:
182
+ x (torch.Tensor): image features
183
+ shape (b, n1, D)
184
+ latent (torch.Tensor): latent features
185
+ shape (b, n2, D)
186
+ """
187
+ x = self.norm1(x)
188
+ latent_queries = self.norm2(latent_queries)
189
+
190
+ b, l, _ = latent_queries.shape
191
+
192
+ q = self.to_q(latent_queries)
193
+ kv_input = torch.cat((x, latent_queries), dim=-2)
194
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
195
+
196
+ q = reshape_tensor(q, self.num_heads)
197
+ k = reshape_tensor(k, self.num_heads)
198
+ v = reshape_tensor(v, self.num_heads)
199
+
200
+ # attention
201
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
202
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
203
+ attn = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
204
+ out = attn @ v
205
+
206
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
207
+
208
+ return self.to_out(out)
209
+
210
+
211
+ class CrossAttention(nn.Module):
212
+ # output_dim is always the same as input_dim.
213
+ # num_q only matters when q_aware_to_v is True.
214
+ # If q_aware_to_v is False, query x in forward() is still usable.
215
+ def __init__(self, input_dim, num_heads=6, p_dropout=0.05,
216
+ identity_to_q=False, identity_to_k=False, identity_to_v=False, v_has_skip=True,
217
+ q_aware_to_v=True, num_q=416, v_repeat=4, q_aware_to_v_lora_rank=64,
218
+ identity_to_out=False, out_has_skip=False):
219
+ super().__init__()
220
+ dim_head = input_dim // num_heads
221
+ inner_dim = dim_head * num_heads
222
+
223
+ self.num_heads = num_heads
224
+ self.q_aware_to_v = q_aware_to_v
225
+ self.v_has_skip = v_has_skip
226
+ self.to_q = nn.Sequential(
227
+ nn.Linear(input_dim, inner_dim, bias=False),
228
+ nn.LayerNorm(inner_dim, elementwise_affine=True)
229
+ ) if not identity_to_q else nn.Identity()
230
+ self.to_k = nn.Sequential(
231
+ nn.Linear(input_dim, inner_dim, bias=False),
232
+ nn.LayerNorm(inner_dim, elementwise_affine=True)
233
+ ) if not identity_to_k else nn.Identity()
234
+
235
+ self.v_repeat = v_repeat
236
+ self.num_q_group = num_q_group = num_q // v_repeat # 416 / 4 = 104.
237
+
238
+ # If q_aware_to_v is True, then self.to_v consists of num_q projections of input_dim to inner_dim.
239
+ # Otherwise, self.to_v consists of a single projection of input_dim to inner_dim.
240
+ if q_aware_to_v:
241
+ # all_q_mid: 104 * 64 = 6656.
242
+ all_q_mid = num_q_group * q_aware_to_v_lora_rank
243
+ self.to_v = nn.Sequential(
244
+ # number of params: 768 * 6656 = 5,111,808.
245
+ # Input: [BS, 16, 768]. Output: [BS, 16, 104*64] = [BS, 16, 6656].
246
+ # Each 768-dim vec is dispersed into 104 64-dim vecs.
247
+ nn.Linear(input_dim, all_q_mid, bias=False),
248
+ nn.LayerNorm(all_q_mid, elementwise_affine=True),
249
+ # Change the dim of the tensor to [BS, 6656, 16], as Conv1d transforms dim 1.
250
+ Rearrange('b n q -> b q n', q=all_q_mid),
251
+ # Each q_aware_to_v projection has its own linear layer.
252
+ # The total number of parameters will be 6656*768 = 5,111,808.
253
+ # Output: [BS, 104*768, 16]. Each 64 dim feature is expanded to 768 dim.
254
+ nn.Conv1d(
255
+ in_channels=all_q_mid,
256
+ out_channels=num_q_group * input_dim,
257
+ kernel_size=1,
258
+ groups=num_q_group,
259
+ bias=False,
260
+ ),
261
+ # Output: [BS, 104, 16, 768].
262
+ Rearrange('b (q d) n -> b q n d', q=num_q_group, d=input_dim),
263
+ nn.LayerNorm(input_dim, elementwise_affine=True),
264
+ )
265
+ else:
266
+ self.to_v = nn.Sequential(
267
+ nn.Linear(input_dim, inner_dim, bias=False),
268
+ nn.LayerNorm(inner_dim, elementwise_affine=True)
269
+ ) if not identity_to_v else nn.Identity()
270
+
271
+ if identity_to_out:
272
+ assert not out_has_skip, "identity_to_out=True, then out_has_skip has to be False."
273
+
274
+ if identity_to_out:
275
+ self.to_out = nn.Identity()
276
+ else:
277
+ self.to_out = nn.Sequential(
278
+ nn.Linear(input_dim, input_dim, bias=False),
279
+ nn.Dropout(p_dropout),
280
+ nn.LayerNorm(inner_dim, elementwise_affine=True)
281
+ )
282
+
283
+ self.out_has_skip = out_has_skip
284
+ self.attn_drop = nn.Dropout(p_dropout)
285
+
286
+ def forward(self, x, context=None, attn_mat=None, return_attn=False):
287
+ h = self.num_heads
288
+
289
+ if context is None:
290
+ context = x
291
+
292
+ if attn_mat is None:
293
+ # q: [BS, Q, D] -> [BS, Q, D].
294
+ q = self.to_q(x)
295
+ # k: [BS, L, D] -> [BS, L, D].
296
+ k = self.to_k(context)
297
+ # q: [6, 512, 128], k: [6, 17, 128].
298
+ q, k = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k))
299
+
300
+ if self.q_aware_to_v:
301
+ # context: [BS, L, D]. v: [BS, Q, L, D].
302
+ # There are effectively Q to_v projections.
303
+ v = self.to_v(context)
304
+ if self.v_has_skip:
305
+ v = v + context.unsqueeze(1)
306
+ else:
307
+ # v: [BS, L, D].
308
+ v = self.to_v(context)
309
+ if self.v_has_skip:
310
+ v = v + context
311
+
312
+ #print(v.shape)
313
+
314
+ if self.q_aware_to_v:
315
+ # v: [6, 64, 17, 128].
316
+ # v is query-specific, so there's an extra dim for the query.
317
+ v = rearrange(v, 'b q n (h d) -> (b h) q n d', h=h)
318
+ # Each v is for a query group with 512/64 = 8 queries.
319
+ # So each v is repeated 8 times to match the number of queries.
320
+ # v: [6, 64, 17, 128] -> [6, 512, 17, 128].
321
+ v = v.repeat(1, self.v_repeat, 1, 1)
322
+ else:
323
+ v = rearrange(v, 'b n (h d) -> (b h) n d', h=h)
324
+
325
+ if attn_mat is None:
326
+ scale = q.size(-1) ** -0.25
327
+ sim = einsum('b i d, b j d -> b i j', q * scale, k * scale)
328
+ # sim: [6, 64, 17]. 6: bs 1 * h 6.
329
+ # attention, what we cannot get enough of
330
+ # NOTE: the normalization is done across tokens, not across pixels.
331
+ # So for each pixel, the sum of attention scores across tokens is 1.
332
+ attn = sim.softmax(dim=-1)
333
+ attn = self.attn_drop(attn)
334
+ #print(attn.std())
335
+ else:
336
+ attn = attn_mat
337
+
338
+ if self.q_aware_to_v:
339
+ # attn: [6, 32, 17]. v: [6, 32, 17, 128]. 128: dim of each head. out: [6, 32, 128].
340
+ # out is combined with different attn weights and v for different queries.
341
+ out = einsum('b i j, b i j d -> b i d', attn, v)
342
+ else:
343
+ # v: [6, 17, 128]. out: [6, 32, 128].
344
+ out = einsum('b i j, b j d -> b i d', attn, v)
345
+
346
+ # [6, 32, 128] -> [1, 32, 768].
347
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
348
+
349
+ if self.out_has_skip:
350
+ out = self.to_out(out) + out
351
+ else:
352
+ out = self.to_out(out)
353
+
354
+ if return_attn:
355
+ return out, attn
356
+ else:
357
+ return out
358
+
359
+ class SubjBasisGenerator(nn.Module):
360
+ def __init__(
361
+ self,
362
+ # number of cross-attention heads. Half of the number of heads 12 of OpenAI clip-vit-large-patch14:
363
+ # https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json
364
+ num_heads=6,
365
+ num_id_vecs={ 'subj': 77, 'bg': 257 }, # number of identity vectors. 18: 16 face tokens + 2 extra tokens. 257: 257 CLIP tokens.
366
+ num_out_embs_per_layer=4, # num_out_embs. subj: 16. bg: 4.
367
+ num_out_layers=16, # number of layers of output embeddings.
368
+ image_embedding_dim=768, # CLIP image feature dimension, as per config.json above.
369
+ # DINO vits16 has 6 attention heads:
370
+ # https://huggingface.co/facebook/dino-vits16/blob/main/config.json
371
+ dino_embedding_dim=384, # DINO object feature dimension for objects.
372
+ output_dim=768, # CLIP text embedding input dimension.
373
+ placeholder_is_bg: bool = False, # Whether the placeholder is for the image background.
374
+ prompt2token_proj_grad_scale: float = 0.4, # Gradient scale for prompt2token_proj.
375
+ zs_extra_words_scale: float = 0.5, # Scale for extra words in the prompt2token_proj.
376
+ learnable_hidden_state_weights_scheme: str = 'per-layer', # none, per-layer.
377
+ bg_prompt_translator_has_to_out_proj: bool = False, # Whether the prompt_trans_layers have a to_out projection.
378
+ ):
379
+ super().__init__()
380
+
381
+ self.placeholder_is_bg = placeholder_is_bg
382
+ self.num_out_layers = num_out_layers
383
+ self.num_out_embs_per_layer = num_out_embs_per_layer
384
+ # subj: 64, bg: 32.
385
+ self.num_out_embs = num_out_layers * num_out_embs_per_layer
386
+ self.output_dim = output_dim
387
+ # num_id_vecs should be the number of core ID embs, 16.
388
+ # However, in such case, pos_embs is not used. So it doesn't matter if it's wrongly set.
389
+ self.num_id_vecs = num_id_vecs['bg'] if placeholder_is_bg else num_id_vecs['subj']
390
+ self.pos_embs = nn.Parameter(torch.randn(1, self.num_id_vecs, output_dim))
391
+ self.pos_embs_ln = nn.LayerNorm(output_dim)
392
+ self.zs_extra_words_scale = zs_extra_words_scale
393
+ self.output_scale = output_dim ** -0.5
394
+ self.clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
395
+
396
+ if not self.placeholder_is_bg:
397
+ # [1, 384] -> [1, 16, 768].
398
+ # TODO: use CLIPTextModelWrapper as obj_proj_in.
399
+ self.obj_proj_in = ExpandEmbs(dino_embedding_dim, output_dim, expansion_ratio=self.num_id_vecs)
400
+
401
+ # self.prompt2token_proj: [1, 16, 768] -> [1, 77, 768] (with paddings).
402
+ # If self.placeholder_is_bg: prompt2token_proj is set to None.
403
+ self.prompt2token_proj = CLIPTextModelWrapper.from_pretrained('openai/clip-vit-large-patch14')
404
+ self.prompt2token_proj_grad_scale = prompt2token_proj_grad_scale
405
+ self.prompt2token_proj_grad_scaler = gen_gradient_scaler(prompt2token_proj_grad_scale)
406
+ print(f"Subj prompt2token_proj initialized with grad scale of {prompt2token_proj_grad_scale}.")
407
+ # Freeze prompt2token_proj if prompt2token_proj_grad_scale is 0.
408
+ # Set requires_grad to False for all parameters in prompt2token_proj, to save memory taken by the optimizer.
409
+ if prompt2token_proj_grad_scale == 0:
410
+ self.freeze_prompt2token_proj()
411
+
412
+ self.prompt2token_proj_attention_multiplier = -1
413
+ self.initialize_hidden_state_layer_weights(learnable_hidden_state_weights_scheme, 'cpu')
414
+ self.pad_embeddings = None
415
+ self.bg_proj_in = None
416
+ else:
417
+ # For background placeholders, face and object embeddings are not used as they are foreground.
418
+ self.obj_proj_in = None
419
+ self.prompt2token_proj = None
420
+ print("Bg prompt2token_proj is set to None.")
421
+
422
+ self.bg_proj_in = nn.Sequential(
423
+ nn.Linear(image_embedding_dim, output_dim, bias=False),
424
+ nn.LayerNorm(output_dim),
425
+ )
426
+
427
+ self.latent_queries = nn.Parameter(torch.randn(1, self.num_out_embs, output_dim))
428
+ self.latent_queries_ln = nn.LayerNorm(output_dim)
429
+
430
+ self.bg_prompt_translator_has_to_out_proj = bg_prompt_translator_has_to_out_proj
431
+ identity_to_v = False
432
+ v_has_skip = not identity_to_v # True
433
+ identity_to_out = not bg_prompt_translator_has_to_out_proj # True
434
+ out_has_skip = not identity_to_out # False
435
+ # prompt_translator has a to_v projection with skip connection, and doesn't have a to_out projection.
436
+ # dim=768, num_heads=6.
437
+ self.prompt_translator = \
438
+ CrossAttention(input_dim=output_dim, num_heads=num_heads, p_dropout=0.05,
439
+ identity_to_q=False, identity_to_k=False, identity_to_v=identity_to_v,
440
+ q_aware_to_v=False, v_has_skip=v_has_skip,
441
+ num_q=0, # When not q_aware_to_v, num_q is not referenced.
442
+ identity_to_out=identity_to_out,
443
+ out_has_skip=out_has_skip)
444
+ '''
445
+ prompt_translator: CLIPEncoder
446
+ # https://github.com/huggingface/transformers/blob/1872bde7fc6a5d6796bd742bc2dc38eaf8069c5d/src/transformers/models/clip/modeling_clip.py#L566
447
+ # CLIPEncoder.layers: 12 layers of CLIPEncoderLayer, each being
448
+ (0): CLIPEncoderLayer(
449
+ (self_attn): CLIPAttention(
450
+ (k_proj): Linear(in_features=768, out_features=768, bias=True)
451
+ (v_proj): Linear(in_features=768, out_features=768, bias=True)
452
+ (q_proj): Linear(in_features=768, out_features=768, bias=True)
453
+ (out_proj): Linear(in_features=768, out_features=768, bias=True)
454
+ )
455
+ (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
456
+ (mlp): CLIPMLP(
457
+ (activation_fn): QuickGELUActivation()
458
+ (fc1): Linear(in_features=768, out_features=3072, bias=True)
459
+ (fc2): Linear(in_features=3072, out_features=768, bias=True)
460
+ )
461
+ (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
462
+ )
463
+ '''
464
+
465
+ print(repr(self))
466
+
467
+ # raw_id_embs: ArcFace embeddings for faces (not used since we have arc2face_id_embs),
468
+ # or DINO embeddings for objects.
469
+ # arc2face_id_embs: [BS, 16, 768], the core identity embeddings generated by Arc2Face.
470
+ def forward(self, arc2face_id_embs, clip_features=None, raw_id_embs=None, out_id_embs_scale=1.0,
471
+ is_face=True, is_training=False, adaface_prompt_embs_inf_type='full_half_pad'):
472
+
473
+ if not self.placeholder_is_bg:
474
+ BS = arc2face_id_embs.shape[0]
475
+ else:
476
+ # If bg, then arc2face_id_embs is set to None, but clip_features is not None.
477
+ BS = clip_features.shape[0]
478
+
479
+ adaface_prompt_embs = None
480
+ if not hasattr(self, 'clip_tokenizer'):
481
+ self.clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
482
+
483
+ # No need to use raw_id_embs if placeholder_is_bg.
484
+ if not self.placeholder_is_bg:
485
+ if is_face:
486
+ assert arc2face_id_embs is not None
487
+ # arc2face_embs has been projected to the (modified) prompt embedding space
488
+ # by arc2face_forward_face_embs. This prompt embedding space is modified because Arc2Face finetuned
489
+ # the text encoder and the U-Net.
490
+ # in embedding_manager: [BS, 16, 768] -> [BS, 77, 768].
491
+ # arc2face_id_embs is part of arc2face_embs: [BS, 77, 768] -> [BS, 16, 768].
492
+ # adaface_prompt_embs is projected to the prompt embedding spaces. This is the
493
+ # original U-Net prompt embedding space.
494
+
495
+ # hidden_state_layer_weights: [[0.9163], [0.9483], [2.0762]]
496
+ hidden_state_layer_weights = self.hidden_state_layer_weights_grad_scaler(self.hidden_state_layer_weights)
497
+ # return_emb_types: a list of strings, each string is among
498
+ # ['full', 'core', 'full_pad', 'full_half_pad', 'full_zeroed_extra', 'b_core_e'].
499
+ # Using b_core_e is more computationally efficient than using full_zeroed_extra.
500
+ # But there is an unknown BUG that causes a crash when using b_core_e.
501
+ if is_training:
502
+ return_emb_types = ['full_pad', 'core']
503
+ else:
504
+ # adaface_prompt_embs_inf_type: default is full_half_pad, same as training.
505
+ return_emb_types = [adaface_prompt_embs_inf_type, 'core']
506
+
507
+ if self.pad_embeddings is None:
508
+ self.generate_pad_embeddings()
509
+ else:
510
+ self.pad_embeddings = self.pad_embeddings.to(arc2face_id_embs.device)
511
+
512
+ with torch.set_grad_enabled(self.training and self.prompt2token_proj_grad_scale != 0):
513
+ # If list_extra_words is not None, then core_id_embs: [BS, 18, 768], three leading words, the 16 identity tokens
514
+ # and (at most) two extra words in full_prompt_embs, without BOS and EOS.
515
+ # If list_extra_words is None, then core_id_embs: [BS, 16, 768], the 16 identity tokens in full_prompt_embs.
516
+ # hidden_state_layer_weights: [[0.9163], [0.9483], [2.0762]]
517
+ # zs_extra_words_scale is only effective when list_extra_words is not None.
518
+ # adaface_prompt_embs: [BS, 77, 768], core_id_embs: [BS, 16, 768].
519
+ adaface_prompt_embs, core_id_embs = \
520
+ arc2face_inverse_face_prompt_embs(self.clip_tokenizer,
521
+ self.prompt2token_proj,
522
+ arc2face_id_embs,
523
+ list_extra_words=None,
524
+ return_emb_types=return_emb_types,
525
+ pad_embeddings=self.pad_embeddings,
526
+ hidden_state_layer_weights=hidden_state_layer_weights,
527
+ input_max_length=77, zs_extra_words_scale=self.zs_extra_words_scale)
528
+ # Reduce the update rate to prompt2token_proj.
529
+ adaface_prompt_embs = self.prompt2token_proj_grad_scaler(adaface_prompt_embs)
530
+ core_id_embs = self.prompt2token_proj_grad_scaler(core_id_embs)
531
+ elif raw_id_embs is not None:
532
+ # id_embs: [BS, 384] -> [BS, 18, 768].
533
+ # obj_proj_in is expected to project the DINO object features to
534
+ # the token embedding space. So no need to use prompt2token_proj.
535
+ id_embs = self.obj_proj_in(raw_id_embs)
536
+ else:
537
+ breakpoint()
538
+ else:
539
+ # Otherwise, context is the ad-hoc CLIP image features.
540
+ # id_embs: [BS, 257, 768].
541
+ id_embs = self.bg_proj_in(clip_features)
542
+
543
+ if self.placeholder_is_bg:
544
+ id_embs = id_embs + self.pos_embs_ln(self.pos_embs)
545
+ latent_queries = self.latent_queries_ln(self.latent_queries).repeat(BS, 1, 1)
546
+ # If bg, we don't have to use a specific attn layer for each 4-vec set. Instead, one attn layer can generate 257 embs,
547
+ # and we take the first 16*4=64.
548
+ # Output of prompt_translator is exactly num_out_embs == 64 tokens. id_embs_out: [BS, 64, 768].
549
+ # prompt_translator: better named as bg_prompt_translator. It maps the bg features
550
+ # to bg prompt embeddings.
551
+ with torch.set_grad_enabled(self.training):
552
+ id_embs_out = self.prompt_translator(latent_queries, id_embs)
553
+ # [BS, 64, 768] -> [BS, 16, 4, 768]
554
+ id_embs_out = id_embs_out.reshape(BS, self.num_out_layers, -1, self.output_dim)
555
+ adaface_subj_embs = id_embs_out * self.output_scale # * 0.036
556
+ else:
557
+ # adaface_subj_embs: [BS, 16, 768] -> [BS, 1, 16, 768] -> [BS, 16, 16, 768]
558
+ adaface_subj_embs = core_id_embs.unsqueeze(1).repeat(1, self.num_out_layers, 1, 1)
559
+
560
+ # If out_id_embs_scale < 1, adaface_subj_embs is a mix of adaface_subj_embs and pad_embeddings.
561
+ if out_id_embs_scale != 1:
562
+ # pad_embeddings: [77, 768] -> [16, 768] -> [1, 1, 16, 768].
563
+ pad_embeddings = self.pad_embeddings[4:4+self.num_out_embs_per_layer].unsqueeze(0).unsqueeze(0)
564
+ adaface_subj_embs = adaface_subj_embs * out_id_embs_scale \
565
+ + pad_embeddings * (1 - out_id_embs_scale)
566
+
567
+ return adaface_subj_embs, adaface_prompt_embs
568
+
569
+ def initialize_hidden_state_layer_weights(self, learnable_hidden_state_weights_scheme, device):
570
+ if learnable_hidden_state_weights_scheme == 'none':
571
+ self.hidden_state_layer_weights = None
572
+ # A grad scaler with alpha =1 is nn.Identity(), which outputs None given None as input.
573
+ self.hidden_state_layer_weights_grad_scaler = gen_gradient_scaler(1)
574
+ print("hidden_state_layer_weights is set to None.")
575
+
576
+ elif learnable_hidden_state_weights_scheme == 'per-layer':
577
+ # Learnable weights of the last 3 layers, initialized to putting more focus on the last layer.
578
+ # 'per-layer': Different weights for different layers, but the same for different channels.
579
+ # hidden_state_layer_weights: [3, 1].
580
+ self.hidden_state_layer_weights = nn.Parameter(torch.tensor([[1.0], [2.0], [4.0]], device=device),
581
+ requires_grad=True)
582
+ self.hidden_state_layer_weights_grad_scaler = gen_gradient_scaler(5)
583
+ print("hidden_state_layer_weights initialized as per-layer [1, 2, 4], with grad scaler 5.")
584
+ else:
585
+ breakpoint()
586
+
587
+ def generate_pad_embeddings(self):
588
+ # clip_embeddings: CLIPTextEmbeddings instance. pad_embeddings is generated after
589
+ # prompt2token_proj is loaded from the finetuned weight. It seems such pad embeddings perform
590
+ # slightly better than the original pad embeddings.
591
+ clip_embeddings = self.prompt2token_proj.text_model.embeddings
592
+ # clip_embeddings() and clip_embeddings.token_embedding() differ in that
593
+ # clip_embeddings() adds positional embeddings, while clip_embeddings.token_embedding() doesn't.
594
+ # Adding positional embeddings seems to help somewhat.
595
+ # pad_tokens: pad_token_id 49407 repeated 77 times.
596
+ # pad_token_id is the EOS token. But BOS is 49406.
597
+ pad_tokens = torch.tensor([self.clip_tokenizer.pad_token_id]).to(clip_embeddings.token_embedding.weight.device).repeat(77)
598
+ # pad_embeddings: [77, 768].
599
+ pad_embeddings = clip_embeddings(pad_tokens)[0]
600
+ # We don't allow face recon to influence the pad embeddings.
601
+ # Otherwise, face identity will leak into the pad embeddings.
602
+ self.pad_embeddings = pad_embeddings.detach()
603
+
604
+ def extend_prompt2token_proj_attention(self, begin_layer_idx=-1, end_layer_idx=-1, multiplier=2, noise_std=0.1):
605
+ if multiplier > 1:
606
+ num_extended_layers = self.prompt2token_proj.extend_clip_attention_MKV_multiplier(begin_layer_idx, end_layer_idx, multiplier, noise_std)
607
+ self.prompt2token_proj_attention_multiplier = multiplier
608
+ print(f"{num_extended_layers} layers in prompt2token_proj_attention are x{multiplier}")
609
+
610
+ def freeze_prompt2token_proj(self):
611
+ # If bg, then prompt2token_proj is set to None. Therefore no need to freeze it.
612
+ # Then we don't have to check whether it's for subj or bg.
613
+ if self.prompt2token_proj is not None:
614
+ frozen_param_names = []
615
+ for param_name, param in self.prompt2token_proj.named_parameters():
616
+ if param.requires_grad:
617
+ param.requires_grad = False
618
+ frozen_param_names.append(param_name)
619
+ # If param is already frozen, then no need to freeze it again.
620
+ print(f"{len(frozen_param_names)} params in Subj prompt2token_proj is frozen.")
621
+ #print(f"Frozen parameters:\n{frozen_param_names}")
622
+
623
+ def __repr__(self):
624
+ type_sig = 'subj' if not self.placeholder_is_bg else 'bg'
625
+ # Fix compatibility with the previous version.
626
+ if not hasattr(self, 'bg_prompt_translator_has_to_out_proj'):
627
+ self.bg_prompt_translator_has_to_out_proj = False
628
+ if not hasattr(self, 'num_out_embs'):
629
+ self.num_out_embs = -1
630
+ return f"{type_sig} SubjBasisGenerator: num_out_embs={self.num_out_embs}, " \
631
+ f"bg_prompt_translator_has_to_out_proj={self.bg_prompt_translator_has_to_out_proj}"
632
+
633
+ @dataclass
634
+ class BaseModelOutputWithPooling2(ModelOutput):
635
+ """
636
+ Base class for model's outputs that also contains a pooling of the last hidden states.
637
+
638
+ Args:
639
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
640
+ Sequence of hidden-states at the output of the last layer of the model.
641
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
642
+ Last layer hidden-state of the first token of the sequence (classification token) after further processing
643
+ through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns
644
+ the classification token after processing through a linear layer and a tanh activation function. The linear
645
+ layer weights are trained from the next sentence prediction (classification) objective during pretraining.
646
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
647
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
648
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
649
+
650
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
651
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
652
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
653
+ sequence_length)`.
654
+
655
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
656
+ heads.
657
+ """
658
+
659
+ last_hidden_state: torch.FloatTensor = None
660
+ pooler_output: torch.FloatTensor = None
661
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
662
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
663
+ attn_mask: Optional[torch.FloatTensor] = None
664
+
665
+ # Revised from CLIPVisionTransformer to support attention mask.
666
+ # self: a CLIPVisionTransformer instance.
667
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py#L821
668
+ # pixel_values: preprocessed B*C*H*W images. [BS, 3, 224, 224]
669
+ # attn_mask: B*H*W attention mask.
670
+ def CLIPVisionTransformer_forward(self, pixel_values = None, attn_mask=None,
671
+ output_attentions = None,
672
+ output_hidden_states = None, return_dict = None):
673
+
674
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
675
+ output_hidden_states = (
676
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
677
+ )
678
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
679
+
680
+ if pixel_values is None:
681
+ raise ValueError("You have to specify pixel_values")
682
+
683
+ # Visual tokens are flattened in embeddings().
684
+ # self.embeddings: CLIPVisionEmbeddings.
685
+ # hidden_states: [BS, 257, 1280]. 257: 16*16 (patch_embeds) + 1 (class_embeds).
686
+ # 16*16 is output from Conv2d(3, 1280, kernel_size=(14, 14), stride=(14, 14), bias=False).
687
+ hidden_states = self.embeddings(pixel_values)
688
+ hidden_states = self.pre_layrnorm(hidden_states)
689
+
690
+ if attn_mask is not None:
691
+ # feat_edge_size: 16.
692
+ feat_edge_size = np.sqrt(hidden_states.shape[1] - 1).astype(int)
693
+ # attn_mask: [BS, 512, 512] -> [BS, 1, 16, 16].
694
+ attn_mask = F.interpolate(attn_mask.unsqueeze(1), size=(feat_edge_size, feat_edge_size), mode='nearest')
695
+ # Flatten the mask: [BS, 1, 16, 16] => [BS, 1, 256].
696
+ attn_mask = attn_mask.flatten(2)
697
+ # Prepend 1 to the mask: [BS, 1, 256] => [BS, 1, 257].
698
+ # This 1 corresponds to class_embeds, which is always attended to.
699
+ attn_mask = torch.cat([torch.ones_like(attn_mask[:, :, :1]), attn_mask], dim=-1)
700
+ attn_mask_pairs = torch.matmul(attn_mask.transpose(-1, -2), attn_mask).unsqueeze(1)
701
+ else:
702
+ attn_mask_pairs = None
703
+
704
+ # encoder: CLIPEncoder.
705
+ encoder_outputs = self.encoder(
706
+ inputs_embeds=hidden_states,
707
+ # New feature: (***The official documentation is wrong***)
708
+ # attention_mask (`torch.Tensor` of shape `(batch_size, 1, sequence_length, sequence_length)`, *optional*):
709
+ # Mask to avoid performing attention on pairs of tokens. Mask values selected in `[0, 1]`:
710
+ # - 1 for pairs that are **not masked**,
711
+ # - 0 for pairs that are **masked**.
712
+ # attention_mask is eventually used by CLIPEncoderLayer:
713
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py#L370
714
+ attention_mask=attn_mask_pairs,
715
+ output_attentions=output_attentions, # False
716
+ output_hidden_states=output_hidden_states, # True
717
+ return_dict=return_dict, # True
718
+ )
719
+
720
+ # last_hidden_state: [BS, 257, 1280]
721
+ last_hidden_state = encoder_outputs[0]
722
+ pooled_output = last_hidden_state[:, 0, :]
723
+ pooled_output = self.post_layernorm(pooled_output)
724
+
725
+ # return_dict is True.
726
+ if not return_dict:
727
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
728
+
729
+ return BaseModelOutputWithPooling2(
730
+ last_hidden_state=last_hidden_state,
731
+ pooler_output=pooled_output,
732
+ hidden_states=encoder_outputs.hidden_states,
733
+ attentions=encoder_outputs.attentions,
734
+ # Newly added: return resized flattened attention mask.
735
+ # [BS, 1, 257] -> [BS, 257, 1]
736
+ attn_mask=attn_mask.permute(0, 2, 1) if attn_mask is not None else None
737
+ )
738
+
739
+
740
+ class CLIPVisionModelWithMask(CLIPVisionModel):
741
+ def __init__(self, config):
742
+ super().__init__(config)
743
+ # Replace vision_model.forward() with the new one that supports mask.
744
+ self.vision_model.forward = CLIPVisionTransformer_forward.__get__(self.vision_model)
745
+
746
+ def forward(self, pixel_values = None, attn_mask = None, output_attentions = None,
747
+ output_hidden_states = None, return_dict = None):
748
+
749
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
750
+
751
+ return self.vision_model(
752
+ pixel_values=pixel_values,
753
+ attn_mask=attn_mask,
754
+ output_attentions=output_attentions,
755
+ output_hidden_states=output_hidden_states,
756
+ return_dict=return_dict,
757
+ )
758
+
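For orientation, a rough usage sketch of the subject branch defined above, assuming the adaface package layout in this upload and the tensor shapes given in the forward() comments ([BS, 16, 768] Arc2Face core ID embeddings in, per-layer subject embeddings out). This is illustrative only, not the project's actual entry point, and it downloads the CLIP backbone on first run:

import torch
from adaface.subj_basis_generator import SubjBasisGenerator

# Subject (foreground) generator with default settings.
subj_basis_generator = SubjBasisGenerator(placeholder_is_bg=False)
# Stand-in for real Arc2Face core ID embeddings of one subject.
arc2face_id_embs = torch.randn(1, 16, 768)
adaface_subj_embs, adaface_prompt_embs = subj_basis_generator(
    arc2face_id_embs, out_id_embs_scale=1.0, is_face=True, is_training=False
)
# adaface_subj_embs: [1, num_out_layers, 16, 768]; adaface_prompt_embs: [1, 77, 768]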
util.py ADDED
@@ -0,0 +1,342 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from PIL import Image
6
+ import cv2
7
+
8
+ # add_noise_to_tensor() adds a fixed amount of noise to the tensor.
9
+ def add_noise_to_tensor(ts, noise_std, noise_std_is_relative=True, keep_norm=False,
10
+ std_dim=-1, norm_dim=-1):
11
+ if noise_std_is_relative:
12
+ ts_std_mean = ts.std(dim=std_dim).mean().detach()
13
+ noise_std *= ts_std_mean
14
+
15
+ noise = torch.randn_like(ts) * noise_std
16
+ if keep_norm:
17
+ orig_norm = ts.norm(dim=norm_dim, keepdim=True)
18
+ ts = ts + noise
19
+ new_norm = ts.norm(dim=norm_dim, keepdim=True).detach()
20
+ ts = ts * orig_norm / (new_norm + 1e-8)
21
+ else:
22
+ ts = ts + noise
23
+
24
+ return ts
25
+
26
+
27
+ # Revised from RevGrad, by removing the grad negation.
28
+ class ScaleGrad(torch.autograd.Function):
29
+ @staticmethod
30
+ def forward(ctx, input_, alpha_, debug=False):
31
+ ctx.save_for_backward(alpha_, debug)
32
+ output = input_
33
+ if debug:
34
+ print(f"input: {input_.abs().mean().item()}")
35
+ return output
36
+
37
+ @staticmethod
38
+ def backward(ctx, grad_output): # pragma: no cover
39
+ # saved_tensors returns a tuple of tensors.
40
+ alpha_, debug = ctx.saved_tensors
41
+ if ctx.needs_input_grad[0]:
42
+ grad_output2 = grad_output * alpha_
43
+ if debug:
44
+ print(f"grad_output2: {grad_output2.abs().mean().item()}")
45
+ else:
46
+ grad_output2 = None
47
+ return grad_output2, None, None
48
+
49
+ class GradientScaler(nn.Module):
50
+ def __init__(self, alpha=1., debug=False, *args, **kwargs):
51
+ """
52
+ A gradient scaling layer.
53
+ This layer has no parameters, and simply scales the gradient in the backward pass.
54
+ """
55
+ super().__init__(*args, **kwargs)
56
+
57
+ self._alpha = torch.tensor(alpha, requires_grad=False)
58
+ self._debug = torch.tensor(debug, requires_grad=False)
59
+
60
+ def forward(self, input_):
61
+ _debug = self._debug if hasattr(self, '_debug') else False
62
+ return ScaleGrad.apply(input_, self._alpha.to(input_.device), _debug)
63
+
64
+ def gen_gradient_scaler(alpha, debug=False):
65
+ if alpha == 1:
66
+ return nn.Identity()
67
+ if alpha > 0:
68
+ return GradientScaler(alpha, debug=debug)
69
+ else:
70
+ assert alpha == 0
71
+ # Don't use lambda function here, otherwise the object can't be pickled.
72
+ return torch.detach
73
+
74
+ #@torch.autocast(device_type="cuda")
75
+ # In AdaFaceWrapper, input_max_length is 22.
76
+ def arc2face_forward_face_embs(tokenizer, arc2face_text_encoder, face_embs,
77
+ input_max_length=77, return_full_and_core_embs=True):
78
+
79
+ '''
80
+ arc2face_text_encoder: arc2face_models.py CLIPTextModelWrapper instance.
81
+ face_embs: (N, 512) normalized ArcFace embeddings.
82
+ return_full_and_core_embs: Return both the full prompt embeddings and the core embeddings.
83
+ If False, return only the core embeddings.
84
+
85
+ '''
86
+
87
+ # arcface_token_id: 1014
88
+ arcface_token_id = tokenizer.encode("id", add_special_tokens=False)[0]
89
+
90
+ # This step should be quite fast, and there's no need to cache the input_ids.
91
+ input_ids = tokenizer(
92
+ "photo of a id person",
93
+ truncation=True,
94
+ padding="max_length",
95
+ max_length=input_max_length, #tokenizer.model_max_length,
96
+ return_tensors="pt",
97
+ ).input_ids.to(face_embs.device)
98
+ # input_ids: [1, 77] or [3, 77] (during training).
99
+ input_ids = input_ids.repeat(len(face_embs), 1)
100
+ face_embs_dtype = face_embs.dtype
101
+ face_embs = face_embs.to(arc2face_text_encoder.dtype)
102
+ # face_embs_padded: [1, 512] -> [1, 768].
103
+ face_embs_padded = F.pad(face_embs, (0, arc2face_text_encoder.config.hidden_size - face_embs.shape[-1]), "constant", 0)
104
+ # arc2face_text_encoder(input_ids=input_ids, ...) is called twice. The first is only to get the token embeddings (the shallowest mapping).
105
+ # The second call does the ordinary CLIP text encoding pass.
106
+ token_embs = arc2face_text_encoder(input_ids=input_ids, return_token_embs=True)
107
+ token_embs[input_ids==arcface_token_id] = face_embs_padded
108
+
109
+ prompt_embeds = arc2face_text_encoder(
110
+ input_ids=input_ids,
111
+ input_token_embs=token_embs,
112
+ return_token_embs=False
113
+ )[0]
114
+
115
+ # Restore the original dtype of prompt_embeds: float16 -> float32.
116
+ prompt_embeds = prompt_embeds.to(face_embs_dtype)
117
+
118
+ if return_full_and_core_embs:
119
+ # token 4: 'id' in "photo of a id person".
120
+ # 4:20 are the most important 16 embeddings that contain the subject's identity.
121
+ # [N, 77, 768] -> [N, 16, 768]
122
+ return prompt_embeds, prompt_embeds[:, 4:20]
123
+ else:
124
+ # [N, 16, 768]
125
+ return prompt_embeds[:, 4:20]
126
+
127
+ def get_b_core_e_embeddings(prompt_embeds, length=22):
128
+ b_core_e_embs = torch.cat([ prompt_embeds[:, :length], prompt_embeds[:, [-1]] ], dim=1)
129
+ return b_core_e_embs
130
+
131
+ # return_emb_types: a list of strings, each string is among ['full', 'core', 'full_zeroed_extra', 'b_core_e'].
132
+ def arc2face_inverse_face_prompt_embs(clip_tokenizer, inverse_text_encoder, face_prompt_embs, list_extra_words,
133
+ return_emb_types, pad_embeddings, hidden_state_layer_weights=None,
134
+ input_max_length=77, zs_extra_words_scale=0.5):
135
+
136
+ '''
137
+ inverse_text_encoder: arc2face_models.py CLIPTextModelWrapper instance with **custom weights**.
138
+ inverse_text_encoder is NOT the original arc2face text encoder, but retrained to do inverse mapping.
139
+ face_prompt_embs: (BS, 16, 768). Only the core embeddings, no paddings.
140
+ list_extra_words: [s_1, ..., s_BS], each s_i is a list of extra words to be added to the prompt.
141
+ return_full_and_core_embs: Return both the full prompt embeddings and the core embeddings.
142
+ If False, return only the core embeddings.
143
+ '''
144
+
145
+ if list_extra_words is not None:
146
+ if len(list_extra_words) != len(face_prompt_embs):
147
+ if len(face_prompt_embs) > 1:
148
+ print("Warn: list_extra_words has different length as face_prompt_embs.")
149
+ if len(list_extra_words) == 1:
150
+ list_extra_words = list_extra_words * len(face_prompt_embs)
151
+ else:
152
+ breakpoint()
153
+ else:
154
+ # len(face_prompt_embs) == 1, this occurs when same_subject_in_batch == True, e.g. in do_mix_prompt_distillation.
155
+ # But list_extra_words always corresponds to the actual batch size. So we only take the first element.
156
+ list_extra_words = list_extra_words[:1]
157
+
158
+ for extra_words in list_extra_words:
159
+ assert len(extra_words.split()) <= 2, "Each extra_words string should consist of at most 2 words."
160
+ # 16 ", " are placeholders for face_prompt_embs.
161
+ prompt_templates = [ "photo of a " + ", " * 16 + list_extra_words[i] for i in range(len(list_extra_words)) ]
162
+ else:
163
+ # 16 ", " are placeholders for face_prompt_embs.
164
+ # No extra words are added to the prompt.
165
+ prompt_templates = [ "photo of a " + ", " * 16 for _ in range(len(face_prompt_embs)) ]
166
+
167
+ # This step should be quite fast, and there's no need to cache the input_ids.
168
+ # input_ids: [BS, 77].
169
+ input_ids = clip_tokenizer(
170
+ prompt_templates,
171
+ truncation=True,
172
+ padding="max_length",
173
+ max_length=input_max_length,
174
+ return_tensors="pt",
175
+ ).input_ids.to(face_prompt_embs.device)
176
+
177
+ face_prompt_embs_dtype = face_prompt_embs.dtype
178
+ face_prompt_embs = face_prompt_embs.to(inverse_text_encoder.dtype)
179
+
180
+ # token_embs: [1, 77, 768]. This call is only to get the template token embeddings (the shallowest mapping).
181
+ token_embs = inverse_text_encoder(input_ids=input_ids, return_token_embs=True)
182
+ # token 4: first ", " in the template prompt.
183
+ # Replace embeddings of 16 placeholder ", " with face_prompt_embs.
184
+ token_embs[:, 4:20] = face_prompt_embs
185
+
186
+ # This call does the ordinary CLIP text encoding pass.
187
+ prompt_embeds = inverse_text_encoder(
188
+ input_ids=input_ids,
189
+ input_token_embs=token_embs,
190
+ hidden_state_layer_weights=hidden_state_layer_weights,
191
+ return_token_embs=False
192
+ )[0]
193
+
194
+ # Restore the original dtype of prompt_embeds: float16 -> float32.
195
+ prompt_embeds = prompt_embeds.to(face_prompt_embs_dtype)
196
+ # token 4: first ", " in the template prompt.
197
+ # 4:20 are the most important 16 embeddings that contain the subject's identity.
198
+ # 20:22 are embeddings of the (at most) two extra words.
199
+ # [N, 77, 768] -> [N, 16, 768]
200
+ core_prompt_embs = prompt_embeds[:, 4:20]
201
+ if list_extra_words is not None:
202
+ # [N, 16, 768] -> [N, 18, 768]
203
+ extra_words_embs = prompt_embeds[:, 20:22] * zs_extra_words_scale
204
+ core_prompt_embs = torch.cat([core_prompt_embs, extra_words_embs], dim=1)
205
+
206
+ return_prompts = []
207
+ for emb_type in return_emb_types:
208
+ if emb_type == 'full':
209
+ return_prompts.append(prompt_embeds)
210
+ elif emb_type == 'full_half_pad':
211
+ prompt_embeds2 = prompt_embeds.clone()
212
+ PADS = prompt_embeds2.shape[1] - 23
213
+ if PADS >= 2:
214
+ # Fill half of the remaining embeddings with pad embeddings.
215
+ prompt_embeds2[:, 22:22+PADS//2] = pad_embeddings[22:22+PADS//2]
216
+ return_prompts.append(prompt_embeds2)
217
+ elif emb_type == 'full_pad':
218
+ prompt_embeds2 = prompt_embeds.clone()
219
+ # Fill the 22nd to the second last embeddings with pad embeddings.
220
+ prompt_embeds2[:, 22:-1] = pad_embeddings[22:-1]
221
+ return_prompts.append(prompt_embeds2)
222
+ elif emb_type == 'core':
223
+ return_prompts.append(core_prompt_embs)
224
+ elif emb_type == 'full_zeroed_extra':
225
+ prompt_embeds2 = prompt_embeds.clone()
226
+ # Only add two pad embeddings. The remaining embeddings are set to 0.
227
+ # Make the positional embeddings align with the actual positions.
228
+ prompt_embeds2[:, 22:24] = pad_embeddings[22:24]
229
+ prompt_embeds2[:, 24:-1] = 0
230
+ return_prompts.append(prompt_embeds2)
231
+ elif emb_type == 'b_core_e':
232
+ # The first 22 embeddings, plus the last EOS embedding.
233
+ b_core_e_embs = get_b_core_e_embeddings(prompt_embeds, length=22)
234
+ return_prompts.append(b_core_e_embs)
235
+ else:
236
+ breakpoint()
237
+
238
+ return return_prompts
239
+
240
+ # if pre_face_embs is None, generate random face embeddings [BS, 512].
241
+ # image_folder is passed only for logging purpose. image_paths contains the paths of the images.
242
+ def get_arc2face_id_prompt_embs(face_app, clip_tokenizer, arc2face_text_encoder,
243
+ extract_faceid_embeds, pre_face_embs,
244
+ image_folder, image_paths, images_np,
245
+ id_batch_size, device,
246
+ input_max_length=77, noise_level=0.0,
247
+ return_core_id_embs=False,
248
+ gen_neg_prompt=False, verbose=False):
249
+ face_image_count = 0
250
+
251
+ if extract_faceid_embeds:
252
+ faceid_embeds = []
253
+ if image_paths is not None:
254
+ images_np = []
255
+ for image_path in image_paths:
256
+ image_np = np.array(Image.open(image_path))
257
+ images_np.append(image_np)
258
+
259
+ for i, image_np in enumerate(images_np):
260
+ image_obj = Image.fromarray(image_np).resize((512, 512), Image.NEAREST)
261
+ # Remove alpha channel if it exists.
262
+ if image_obj.mode == 'RGBA':
263
+ image_obj = image_obj.convert('RGB')
264
+ # This seems NOT a bug. The input image should be in BGR format, as per
265
+ # https://github.com/deepinsight/insightface/issues/524
266
+ image_np = cv2.cvtColor(np.array(image_obj), cv2.COLOR_RGB2BGR)
267
+ image_np = np.array(image_obj)
268
+
269
+ face_infos = face_app.get(image_np)
270
+ if verbose and image_paths is not None:
271
+ print(image_paths[i], len(face_infos))
272
+ # Assume all images belong to the same subject. Therefore, we can skip the images with no face detected.
273
+ if len(face_infos) == 0:
274
+ continue
275
+ # only use the maximum face
276
+ face_info = sorted(face_infos, key=lambda x:(x['bbox'][2]-x['bbox'][0])*x['bbox'][3]-x['bbox'][1])[-1]
277
+ # Each faceid_embed: [1, 512]
278
+ faceid_embeds.append(torch.from_numpy(face_info.normed_embedding).unsqueeze(0))
279
+ face_image_count += 1
280
+
281
+ if verbose:
282
+ if image_folder is not None:
283
+ print(f"Extracted ID embeddings from {face_image_count} images in {image_folder}")
284
+ else:
285
+ print(f"Extracted ID embeddings from {face_image_count} images")
286
+
287
+ if len(faceid_embeds) == 0:
288
+ print("No face detected. Use a random face instead.")
289
+ faceid_embeds = torch.randn(id_batch_size, 512).to(device=device, dtype=torch.float16)
290
+ else:
291
+ # faceid_embeds: [10, 512]
292
+ faceid_embeds = torch.cat(faceid_embeds, dim=0)
293
+ # faceid_embeds: [10, 512] -> [1, 512].
294
+ # and the resulting prompt embeddings are the same.
295
+ faceid_embeds = faceid_embeds.mean(dim=0, keepdim=True).to(device=device, dtype=torch.float16)
296
+ else:
297
+ # Random face embeddings. faceid_embeds: [BS, 512].
298
+ if pre_face_embs is None:
299
+ faceid_embeds = torch.randn(id_batch_size, 512)
300
+ else:
301
+ faceid_embeds = pre_face_embs
302
+ if pre_face_embs.shape[0] == 1:
303
+ faceid_embeds = faceid_embeds.repeat(id_batch_size, 1)
304
+
305
+ faceid_embeds = faceid_embeds.to(device=device, dtype=torch.float16)
306
+
307
+ if noise_level > 0:
308
+ # If id_batch_size > 1, after adding noises, the id_batch_size embeddings will be different.
309
+ faceid_embeds = add_noise_to_tensor(faceid_embeds, noise_level, noise_std_is_relative=True, keep_norm=True)
310
+
311
+ faceid_embeds = F.normalize(faceid_embeds, p=2, dim=-1)
312
+
313
+ # arc2face_pos_prompt_emb, arc2face_neg_prompt_emb: [BS, 77, 768]
314
+ with torch.no_grad():
315
+ arc2face_pos_prompt_emb, arc2face_pos_core_prompt_emb = \
316
+ arc2face_forward_face_embs(clip_tokenizer, arc2face_text_encoder,
317
+ faceid_embeds, input_max_length=input_max_length,
318
+ return_full_and_core_embs=True)
319
+ if return_core_id_embs:
320
+ arc2face_pos_prompt_emb = arc2face_pos_core_prompt_emb
321
+ # If extract_faceid_embeds, we assume all images are from the same subject, and the batch dim of faceid_embeds is 1.
322
+ # So we need to repeat faceid_embeds.
323
+ if extract_faceid_embeds:
324
+ faceid_embeds = faceid_embeds.repeat(id_batch_size, 1)
325
+ arc2face_pos_prompt_emb = arc2face_pos_prompt_emb.repeat(id_batch_size, 1, 1)
326
+
327
+ if gen_neg_prompt:
328
+ with torch.no_grad():
329
+ arc2face_neg_prompt_emb, arc2face_neg_core_prompt_emb = \
330
+ arc2face_forward_face_embs(clip_tokenizer, arc2face_text_encoder,
331
+ torch.zeros_like(faceid_embeds),
332
+ input_max_length=input_max_length,
333
+ return_full_and_core_embs=True)
334
+ if return_core_id_embs:
335
+ arc2face_neg_prompt_emb = arc2face_neg_core_prompt_emb
336
+
337
+ #if extract_faceid_embeds:
338
+ # arc2face_neg_prompt_emb = arc2face_neg_prompt_emb.repeat(id_batch_size, 1, 1)
339
+ return face_image_count, faceid_embeds, arc2face_pos_prompt_emb, arc2face_neg_prompt_emb
340
+ else:
341
+ return face_image_count, faceid_embeds, arc2face_pos_prompt_emb
342
+
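As a quick check of the gradient-scaling utility defined above, a minimal sketch (names match this util.py; the example tensor is made up):

import torch
from util import gen_gradient_scaler

x = torch.randn(4, 8, requires_grad=True)
scaler = gen_gradient_scaler(0.4)   # identity in the forward pass, scales gradients by 0.4 in backward
y = scaler(x)
y.sum().backward()
print(x.grad.unique())               # every entry is 0.4 instead of 1.0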