import torch import torch.nn.functional as F @torch.no_grad() def project_face_embs(pipeline, face_embs): ''' face_embs: (N, 512) normalized ArcFace embeddings ''' arcface_token_id = pipeline.tokenizer.encode("id", add_special_tokens=False)[0] input_ids = pipeline.tokenizer( "photo of a id person", truncation=True, padding="max_length", max_length=pipeline.tokenizer.model_max_length, return_tensors="pt", ).input_ids.to(pipeline.device) face_embs_padded = F.pad(face_embs, (0, pipeline.text_encoder.config.hidden_size-512), "constant", 0) token_embs = pipeline.text_encoder(input_ids=input_ids.repeat(len(face_embs), 1), return_token_embs=True) token_embs[input_ids==arcface_token_id] = face_embs_padded prompt_embeds = pipeline.text_encoder( input_ids=input_ids, input_token_embs=token_embs )[0] return prompt_embeds