File size: 15,907 Bytes
02cc20b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b402b3c
 
 
 
 
 
 
 
02cc20b
 
 
 
 
 
 
 
 
b402b3c
02cc20b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
import cv2

# add_noise_to_tensor() adds a fixed amount of noise to the tensor.
def add_noise_to_tensor(ts, noise_std, noise_std_is_relative=True, keep_norm=False,
                        std_dim=-1, norm_dim=-1):
    '''
    Return a copy of `ts` perturbed with Gaussian noise.

    ts:                     the tensor to perturb. It is not modified.
    noise_std:              standard deviation of the noise (a number or a tensor). 
                            It is never modified in place.
    noise_std_is_relative:  if True, noise_std is multiplied by the mean of the 
                            standard deviations of ts taken along std_dim.
    keep_norm:              if True, the noisy tensor is rescaled so that its norm 
                            along norm_dim matches the norm of the original ts.
    std_dim:                the dim along which the std of ts is computed.
    norm_dim:               the dim along which the norm is computed when keep_norm is True.
    '''
    if noise_std_is_relative:
        ts_std_mean = ts.std(dim=std_dim).mean().detach()
        # Out-of-place multiplication on purpose: `noise_std *= ts_std_mean` would 
        # silently overwrite the caller's tensor when noise_std is a tensor.
        noise_std = noise_std * ts_std_mean

    noise = torch.randn_like(ts) * noise_std
    noisy_ts = ts + noise

    if keep_norm:
        orig_norm = ts.norm(dim=norm_dim, keepdim=True)
        # The new norm is detached, so it acts as a constant scaling factor in the backward pass.
        new_norm  = noisy_ts.norm(dim=norm_dim, keepdim=True).detach()
        # 1e-8 avoids division by zero.
        noisy_ts  = noisy_ts * orig_norm / (new_norm + 1e-8)

    return noisy_ts


# Revised from RevGrad, by removing the grad negation.
class ScaleGrad(torch.autograd.Function):
    """
    Autograd function that returns its input unchanged in the forward pass,
    and multiplies the incoming gradient by alpha_ in the backward pass.
    """
    @staticmethod
    def forward(ctx, input_, alpha_, debug=False):
        # input_: the tensor to pass through.
        # alpha_: the gradient scale, either a tensor or a plain number.
        # debug:  a bool or a single-element bool tensor. If true, print the mean 
        #         magnitudes of the input and of the scaled gradient.
        if not torch.is_tensor(alpha_):
            alpha_ = torch.as_tensor(alpha_, device=input_.device)
        # save_for_backward() only accepts tensors. The debug flag may be a plain bool,
        # so it's kept as an ordinary attribute of ctx instead.
        ctx.save_for_backward(alpha_)
        ctx.debug = bool(debug)
        output = input_
        if ctx.debug:
            print(f"input: {input_.abs().mean().item()}")
        return output

    @staticmethod
    def backward(ctx, grad_output):  # pragma: no cover
        # One return value per forward() argument. alpha_ and debug receive no gradients.
        if not ctx.needs_input_grad[0]:
            return None, None, None

        # saved_tensors returns a tuple of tensors.
        (alpha_,) = ctx.saved_tensors
        grad_output2 = grad_output * alpha_
        if ctx.debug:
            print(f"grad_output2: {grad_output2.abs().mean().item()}")
        return grad_output2, None, None

class GradientScaler(nn.Module):
    """
    A parameter-free layer wrapping ScaleGrad: the input is passed to ScaleGrad.apply() 
    along with the stored alpha and debug flag.
    """
    def __init__(self, alpha=1., debug=False, *args, **kwargs):
        # Extra positional / keyword arguments are forwarded to nn.Module.__init__().
        super().__init__(*args, **kwargs)

        # Both are plain tensor attributes (not parameters or buffers).
        self._alpha = torch.tensor(alpha, requires_grad=False)
        self._debug = torch.tensor(debug, requires_grad=False)

    def forward(self, input_):
        # Instances lacking the _debug attribute fall back to False.
        debug_flag = getattr(self, '_debug', False)
        # _alpha is moved to the device of the input on each call.
        alpha_on_device = self._alpha.to(input_.device)
        return ScaleGrad.apply(input_, alpha_on_device, debug_flag)

def gen_gradient_scaler(alpha, debug=False):
    '''
    Build a callable matching the gradient scale alpha:
      alpha == 1: nn.Identity().
      alpha >  0: GradientScaler(alpha, debug=debug).
      alpha == 0: torch.detach.
    Any other alpha fails the assertion.
    '''
    if alpha == 1:
        return nn.Identity()

    if alpha > 0:
        return GradientScaler(alpha, debug=debug)

    assert alpha == 0
    # torch.detach is returned directly rather than wrapped in a lambda,
    # since an object holding a lambda can't be pickled.
    return torch.detach

#@torch.autocast(device_type="cuda")
# In AdaFaceWrapper, input_max_length is 22.
def arc2face_forward_face_embs(tokenizer, arc2face_text_encoder, face_embs, 
                               input_max_length=77, return_full_and_core_embs=True):

    '''
    Encode face embeddings into prompt embeddings, by substituting them for the
    token embedding of "id" in the prompt "photo of a id person".

    arc2face_text_encoder: arc2face_models.py CLIPTextModelWrapper instance.
    face_embs: (N, 512) normalized ArcFace embeddings.
    input_max_length: the length the tokenized prompt is padded / truncated to.
    return_full_and_core_embs: Return both the full prompt embeddings and the core embeddings. 
                               If False, return only the core embeddings.

    The returned embeddings are cast back to the original dtype of face_embs.
    '''

    # The token whose embedding will be replaced with the face embedding.
    placeholder_token_id = tokenizer.encode("id", add_special_tokens=False)[0]

    # Tokenizing is cheap, so the input_ids are not cached.
    tokenized = tokenizer(
            "photo of a id person",
            truncation=True,
            padding="max_length",
            max_length=input_max_length, #tokenizer.model_max_length,
            return_tensors="pt",
        )
    # One copy of the tokenized prompt for each face embedding.
    batch_size = len(face_embs)
    input_ids  = tokenized.input_ids.to(face_embs.device).repeat(batch_size, 1)

    # Do the computation in the dtype of the text encoder, and remember the dtype to restore.
    orig_dtype        = face_embs.dtype
    encoder_face_embs = face_embs.to(arc2face_text_encoder.dtype)
    # Zero-pad the last dim of the face embeddings up to the hidden size of the text encoder.
    pad_width         = arc2face_text_encoder.config.hidden_size - encoder_face_embs.shape[-1]
    padded_face_embs  = F.pad(encoder_face_embs, (0, pad_width), "constant", 0)

    # First call to the text encoder: only fetch the token embeddings (the shallowest mapping).
    token_embs = arc2face_text_encoder(input_ids=input_ids, return_token_embs=True)
    placeholder_mask = input_ids == placeholder_token_id
    token_embs[placeholder_mask] = padded_face_embs

    # Second call: the ordinary CLIP text encoding pass on the patched token embeddings.
    encoder_outputs = arc2face_text_encoder(
        input_ids=input_ids,
        input_token_embs=token_embs,
        return_token_embs=False
    )
    prompt_embeds = encoder_outputs[0].to(orig_dtype)

    # token 4: 'id' in "photo of a id person". 
    # 4:20 are the most important 16 embeddings that contain the subject's identity.
    core_embs = prompt_embeds[:, 4:20]
    if not return_full_and_core_embs:
        return core_embs

    return prompt_embeds, core_embs

def get_b_core_e_embeddings(prompt_embeds, length=22):
    # Keep the first `length` embeddings along dim 1, and append the last embedding along dim 1.
    leading_embs = prompt_embeds[:, :length]
    last_emb     = prompt_embeds[:, -1].unsqueeze(1)
    return torch.cat((leading_embs, last_emb), dim=1)

# return_emb_types: a list of strings, each string is among 
# ['full', 'full_half_pad', 'full_pad', 'core', 'full_zeroed_extra', 'b_core_e'].
def arc2face_inverse_face_prompt_embs(clip_tokenizer, inverse_text_encoder, face_prompt_embs, list_extra_words,
                                      return_emb_types, pad_embeddings, hidden_state_layer_weights=None, 
                                      input_max_length=77, zs_extra_words_scale=0.5):

    '''
    Map core face prompt embeddings through inverse_text_encoder, and return 
    the requested views of the resulting prompt embeddings.

    inverse_text_encoder: arc2face_models.py CLIPTextModelWrapper instance with **custom weights**.
    inverse_text_encoder is NOT the original arc2face text encoder, but retrained to do inverse mapping.
    face_prompt_embs: (BS, 16, 768). Only the core embeddings, no paddings.
    list_extra_words: None, or [s_1, ..., s_BS], where each s_i is a string of at most two words 
                      appended to the prompt. A single-element list is broadcast to the whole batch.
    return_emb_types: the types of embeddings to return, each among 
                      'full', 'full_half_pad', 'full_pad', 'core', 'full_zeroed_extra', 'b_core_e'.
    pad_embeddings:   pad embeddings indexed by token position, used to overwrite 
                      the positions after the extra words in the 'full_*' variants.
    hidden_state_layer_weights: passed through to inverse_text_encoder.
    input_max_length: the length the tokenized prompts are padded / truncated to.
    zs_extra_words_scale: the scale applied to the embeddings of the extra words in 'core'.

    Returns a list of tensors, one for each element of return_emb_types, in the same order.
    Raises ValueError if list_extra_words cannot be matched to the batch, 
    or if an unknown embedding type is requested.
    '''

    if list_extra_words is not None:
        if len(list_extra_words) != len(face_prompt_embs):
            if len(face_prompt_embs) > 1:
                if len(list_extra_words) != 1:
                    # Neither a one-to-one correspondence nor a broadcast is possible.
                    raise ValueError(
                        f"list_extra_words has {len(list_extra_words)} elements, "
                        f"but face_prompt_embs has {len(face_prompt_embs)}.")
                print("Warn: list_extra_words has different length as face_prompt_embs.")
                list_extra_words = list_extra_words * len(face_prompt_embs)
            else:
                # len(face_prompt_embs) == 1, this occurs when same_subject_in_batch == True, e.g. in do_mix_prompt_distillation.
                # But list_extra_words always corresponds to the actual batch size. So we only take the first element.
                list_extra_words = list_extra_words[:1]
                
        for extra_words in list_extra_words:
            assert len(extra_words.split()) <= 2, "Each extra_words string should consist of at most 2 words."
        # 16 ", " are placeholders for face_prompt_embs.
        prompt_templates = [ "photo of a " + ", " * 16 + extra_words for extra_words in list_extra_words ]
    else:
        # 16 ", " are placeholders for face_prompt_embs.
        # No extra words are added to the prompt.
        prompt_templates = [ "photo of a " + ", " * 16 for _ in range(len(face_prompt_embs)) ]

    # This step should be quite fast, and there's no need to cache the input_ids.
    # input_ids: [BS, 77].
    input_ids = clip_tokenizer(
            prompt_templates,
            truncation=True,
            padding="max_length",
            max_length=input_max_length,
            return_tensors="pt",
        ).input_ids.to(face_prompt_embs.device)

    # Do the computation in the dtype of the text encoder, and restore the original dtype afterwards.
    face_prompt_embs_dtype  = face_prompt_embs.dtype
    face_prompt_embs        = face_prompt_embs.to(inverse_text_encoder.dtype)

    # token_embs: [BS, 77, 768]. This call is only to get the template token embeddings (the shallowest mapping).
    token_embs = inverse_text_encoder(input_ids=input_ids, return_token_embs=True)
    # token 4: first ", " in the template prompt.
    # Replace embeddings of 16 placeholder ", " with face_prompt_embs.
    token_embs[:, 4:20] = face_prompt_embs

    # This call does the ordinary CLIP text encoding pass.
    prompt_embeds = inverse_text_encoder(
        input_ids=input_ids,
        input_token_embs=token_embs,
        hidden_state_layer_weights=hidden_state_layer_weights,
        return_token_embs=False
    )[0]

    # Restore the original dtype of prompt_embeds: float16 -> float32.
    prompt_embeds = prompt_embeds.to(face_prompt_embs_dtype)
    # token 4: first ", " in the template prompt.
    # 4:20 are the most important 16 embeddings that contain the subject's identity.
    # 20:22 are embeddings of the (at most) two extra words.
    # [N, 77, 768] -> [N, 16, 768]
    core_prompt_embs = prompt_embeds[:, 4:20]
    if list_extra_words is not None:
        # [N, 16, 768] -> [N, 18, 768]
        extra_words_embs = prompt_embeds[:, 20:22] * zs_extra_words_scale
        core_prompt_embs = torch.cat([core_prompt_embs, extra_words_embs], dim=1)

    return_prompts = []
    for emb_type in return_emb_types:
        if emb_type == 'full':
            return_prompts.append(prompt_embeds)
        elif emb_type == 'full_half_pad':
            prompt_embeds2 = prompt_embeds.clone()
            # The number of positions between the first 22 embeddings and the last embedding.
            num_pads = prompt_embeds2.shape[1] - 23
            if num_pads >= 2:
                # Fill half of the remaining embeddings with pad embeddings.
                prompt_embeds2[:, 22:22+num_pads//2] = pad_embeddings[22:22+num_pads//2]
            return_prompts.append(prompt_embeds2)
        elif emb_type == 'full_pad':
            prompt_embeds2 = prompt_embeds.clone()
            # Fill the 22nd to the second last embeddings with pad embeddings.
            prompt_embeds2[:, 22:-1] = pad_embeddings[22:-1]
            return_prompts.append(prompt_embeds2)
        elif emb_type == 'core':
            return_prompts.append(core_prompt_embs)
        elif emb_type == 'full_zeroed_extra':
            prompt_embeds2 = prompt_embeds.clone()
            # Only add two pad embeddings. The remaining embeddings are set to 0.
            # Make the positional embeddings align with the actual positions.
            prompt_embeds2[:, 22:24] = pad_embeddings[22:24]
            prompt_embeds2[:, 24:-1] = 0
            return_prompts.append(prompt_embeds2)
        elif emb_type == 'b_core_e':
            # The first 22 embeddings, plus the last EOS embedding.
            b_core_e_embs = get_b_core_e_embeddings(prompt_embeds, length=22)
            return_prompts.append(b_core_e_embs)
        else:
            # Fail loudly instead of dropping into the debugger and 
            # returning fewer embeddings than requested.
            raise ValueError(f"Unknown embedding type in return_emb_types: {emb_type!r}")

    return return_prompts

# if pre_face_embs is None, generate random face embeddings [BS, 512].
# image_folder is passed only for logging purpose. image_paths contains the paths of the images.
def get_arc2face_id_prompt_embs(face_app, clip_tokenizer, arc2face_text_encoder, 
                                extract_faceid_embeds, pre_face_embs, 
                                image_folder, image_paths, images_np,
                                id_batch_size, device, 
                                input_max_length=77, noise_level=0.0, 
                                return_core_id_embs=False,
                                gen_neg_prompt=False, verbose=False):
    '''
    Obtain face ID embeddings and map them to arc2face prompt embeddings.

    face_app:               face analyzer. face_app.get(image) returns a list of detected faces, 
                            each with a 'bbox' and a normed_embedding.
    extract_faceid_embeds:  if True, extract the face embeddings from the images and average them.
                            Otherwise, use pre_face_embs, or random embeddings if pre_face_embs is None.
    pre_face_embs:          precomputed face embeddings. If the batch dim is 1, 
                            they are repeated id_batch_size times.
    image_folder:           only used in the log message.
    image_paths:            paths of the images. If not None, the images are loaded 
                            from the paths and images_np is ignored.
    images_np:              a list of image arrays, used when image_paths is None.
    id_batch_size:          the batch size of the returned embeddings.
    noise_level:            if > 0, relative noise of this level is added to the face embeddings.
    return_core_id_embs:    if True, return the core prompt embeddings instead of the full ones.
    gen_neg_prompt:         if True, also return the prompt embeddings of all-zero face embeddings.

    Returns (faceid_embeds, arc2face_pos_prompt_emb), 
    followed by arc2face_neg_prompt_emb if gen_neg_prompt.
    '''
    # Whether faceid_embeds has a batch dim of 1 that has to be expanded to id_batch_size.
    # This is only the case when the embeddings are averaged over the input images.
    repeat_to_batch = False

    if extract_faceid_embeds:
        image_count = 0
        faceid_embeds = []
        if image_paths is not None:
            images_np = []
            for image_path in image_paths:
                # Close the image file as soon as the pixels have been copied to the array.
                with Image.open(image_path) as image_file:
                    images_np.append(np.array(image_file))

        for i, image_np in enumerate(images_np):
            image_obj = Image.fromarray(image_np).resize((512, 512), Image.NEAREST)
            # Remove alpha channel if it exists.
            if image_obj.mode == 'RGBA':
                image_obj = image_obj.convert('RGB')
            # NOTE(review): the original code converted the image to BGR, as suggested in
            # https://github.com/deepinsight/insightface/issues/524
            # but immediately overwrote the result with the unconverted array below. 
            # So the image effectively passed to face_app has always been the unconverted one.
            # This behavior is kept, and only the discarded conversion is removed. 
            # Please confirm which channel order is intended.
            image_np = np.array(image_obj)

            face_infos = face_app.get(image_np)
            if verbose and image_paths is not None:
                print(image_paths[i], len(face_infos))
            # Assume all images belong to the same subject. Therefore, we can skip the images with no face detected.
            if len(face_infos) == 0:
                continue
            # Only use the face with the largest bounding box area, i.e., width * height. 
            # bbox is indexed as (x1, y1, x2, y2).
            face_info = max(face_infos, 
                            key=lambda face: (face['bbox'][2] - face['bbox'][0]) * (face['bbox'][3] - face['bbox'][1]))
            # Each faceid_embed: [1, 512]
            faceid_embeds.append(torch.from_numpy(face_info.normed_embedding).unsqueeze(0))
            image_count += 1

        if verbose:
            if image_folder is not None:
                print(f"Extracted ID embeddings from {image_count} images in {image_folder}")
            else:
                print(f"Extracted ID embeddings from {image_count} images")

        if len(faceid_embeds) == 0:
            print("No face detected. Use a random face instead.")
            # The random embeddings already have a batch dim of id_batch_size, 
            # so they must not be repeated later.
            faceid_embeds = torch.randn(id_batch_size, 512).to(device=device, dtype=torch.float16)
        else:
            # faceid_embeds: [10, 512]
            faceid_embeds = torch.cat(faceid_embeds, dim=0)
            # Average the embeddings over the images. faceid_embeds: [10, 512] -> [1, 512].
            faceid_embeds = faceid_embeds.mean(dim=0, keepdim=True).to(device=device, dtype=torch.float16)
            repeat_to_batch = True
    else:
        # Random face embeddings. faceid_embeds: [BS, 512].
        if pre_face_embs is None:
            faceid_embeds = torch.randn(id_batch_size, 512)
        else:
            faceid_embeds = pre_face_embs
            if pre_face_embs.shape[0] == 1:
                faceid_embeds = faceid_embeds.repeat(id_batch_size, 1)

        faceid_embeds = faceid_embeds.to(device=device, dtype=torch.float16)

    if noise_level > 0:
        # If id_batch_size > 1, after adding noises, the id_batch_size embeddings will be different.
        faceid_embeds = add_noise_to_tensor(faceid_embeds, noise_level, noise_std_is_relative=True, keep_norm=True)

    faceid_embeds = F.normalize(faceid_embeds, p=2, dim=-1)

    # arc2face_pos_prompt_emb, arc2face_neg_prompt_emb: [BS, 77, 768]
    with torch.no_grad():
        arc2face_pos_prompt_emb, arc2face_pos_core_prompt_emb  = \
             arc2face_forward_face_embs(clip_tokenizer, arc2face_text_encoder, 
                                        faceid_embeds, input_max_length=input_max_length,
                                        return_full_and_core_embs=True)
        if return_core_id_embs:
            arc2face_pos_prompt_emb = arc2face_pos_core_prompt_emb

    # The averaged embedding has a batch dim of 1, as we assume all images are from the same subject. 
    # So we need to repeat faceid_embeds and the prompt embeddings.
    if repeat_to_batch:
        faceid_embeds = faceid_embeds.repeat(id_batch_size, 1)
        arc2face_pos_prompt_emb = arc2face_pos_prompt_emb.repeat(id_batch_size, 1, 1)

    if not gen_neg_prompt:
        return faceid_embeds, arc2face_pos_prompt_emb

    # The negative prompt embeddings are generated from all-zero face embeddings.
    # faceid_embeds already has the final batch size here, so no repeat is needed.
    with torch.no_grad():
        arc2face_neg_prompt_emb, arc2face_neg_core_prompt_emb = \
            arc2face_forward_face_embs(clip_tokenizer, arc2face_text_encoder, 
                                       torch.zeros_like(faceid_embeds),
                                       input_max_length=input_max_length,
                                       return_full_and_core_embs=True)
        if return_core_id_embs:
            arc2face_neg_prompt_emb = arc2face_neg_core_prompt_emb

    return faceid_embeds, arc2face_pos_prompt_emb, arc2face_neg_prompt_emb