JackAILab commited on
Commit
9aa2c12
·
verified ·
1 Parent(s): 6e291eb

Update pipline_StableDiffusion_ConsistentID.py

Browse files
pipline_StableDiffusion_ConsistentID.py CHANGED
@@ -5,7 +5,8 @@ import numpy as np
5
  from PIL import Image
6
  import torch
7
  from torchvision import transforms
8
- from insightface.app import FaceAnalysis
 
9
  from safetensors import safe_open
10
  from huggingface_hub.utils import validate_hf_hub_args
11
  from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
@@ -15,15 +16,11 @@ from diffusers.utils import _get_model_file
15
  from functions import process_text_with_markers, masks_for_unique_values, fetch_mask_raw_image, tokenize_and_mask_noun_phrases_ends, prepare_image_token_idx
16
  from functions import ProjPlusModel, masks_for_unique_values
17
  from attention import Consistent_IPAttProcessor, Consistent_AttProcessor, FacialEncoder
18
- # from modelscope.outputs import OutputKeys
19
- # from modelscope.pipelines import pipeline
20
-
21
- #TODO
22
- import sys
23
- sys.path.append("./models/BiSeNet")
24
- from model import BiSeNet
25
-
26
 
 
 
 
 
27
 
28
  PipelineImageInput = Union[
29
  PIL.Image.Image,
@@ -32,7 +29,7 @@ PipelineImageInput = Union[
32
  List[torch.FloatTensor],
33
  ]
34
 
35
-
36
  class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
37
 
38
  @validate_hf_hub_args
@@ -43,13 +40,13 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
43
  subfolder: str = '',
44
  trigger_word_ID: str = '<|image|>',
45
  trigger_word_facial: str = '<|facial|>',
46
- image_encoder_path: str = 'laion/CLIP-ViT-H-14-laion2B-s32B-b79K', # TODO
47
  torch_dtype = torch.float16,
48
  num_tokens = 4,
49
  lora_rank= 128,
50
  **kwargs,
51
  ):
52
- self.lora_rank = lora_rank
53
  self.torch_dtype = torch_dtype
54
  self.num_tokens = num_tokens
55
  self.set_ip_adapter()
@@ -68,7 +65,7 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
68
  ### BiSeNet
69
  self.bise_net = BiSeNet(n_classes = 19)
70
  self.bise_net.cuda()
71
- self.bise_net_cp='./models/BiSeNet_pretrained_for_ConsistentID.pth' #TODO
72
  self.bise_net.load_state_dict(torch.load(self.bise_net_cp))
73
  self.bise_net.eval()
74
  # Colors for all 20 parts
@@ -82,8 +79,9 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
82
  [255, 0, 255], [255, 85, 255], [255, 170, 255],
83
  [0, 255, 255], [85, 255, 255], [170, 255, 255]]
84
 
85
- ### LLVA Optional
86
- self.llva_model_path = "llava-hf/llava-1.5-7b-hf" #TODO
 
87
  self.llva_prompt = "Describe this person's facial features for me, including face, ears, eyes, nose, and mouth."
88
  self.llva_tokenizer, self.llva_model, self.llva_image_processor, self.llva_context_len = None,None,None,None #load_pretrained_model(self.llva_model_path)
89
 
@@ -91,12 +89,10 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
91
  cross_attention_dim=self.unet.config.cross_attention_dim,
92
  id_embeddings_dim=512,
93
  clip_embeddings_dim=self.image_encoder.config.hidden_size,
94
- num_tokens=self.num_tokens, # 4
95
  ).to(self.device, dtype=self.torch_dtype)
96
  self.FacialEncoder = FacialEncoder(self.image_encoder).to(self.device, dtype=self.torch_dtype)
97
 
98
- # self.skin_retouching = pipeline('skin-retouching-torch', model='damo/cv_unet_skin_retouching_torch', model_revision='v1.0.2')
99
-
100
  # Load the main state dict first.
101
  cache_dir = kwargs.pop("cache_dir", None)
102
  force_download = kwargs.pop("force_download", False)
@@ -189,8 +185,10 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
189
  multi_facial_embeds = torch.stack(hidden_states)
190
  uncond_multi_facial_embeds = torch.stack(uncond_hidden_states)
191
 
 
192
  facial_prompt_embeds = self.FacialEncoder(prompt_embeds, multi_facial_embeds, facial_token_masks, valid_facial_token_idx_mask)
193
 
 
194
  uncond_facial_prompt_embeds = self.FacialEncoder(negative_prompt_embeds, uncond_multi_facial_embeds, facial_token_masks, valid_facial_token_idx_mask)
195
 
196
  return facial_prompt_embeds, uncond_facial_prompt_embeds
@@ -202,9 +200,11 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
202
  clip_image = clip_image.to(self.device, dtype=self.torch_dtype)
203
  clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
204
  uncond_clip_image_embeds = self.image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[-2]
 
205
  faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
206
  image_prompt_tokens = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale)
207
  uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, shortcut=shortcut, scale=s_scale)
 
208
  return image_prompt_tokens, uncond_image_prompt_embeds
209
 
210
  def set_scale(self, scale):
@@ -220,6 +220,7 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
220
  faceid_embeds = torch.zeros_like(torch.empty((1, 512)))
221
  else:
222
  faceid_embeds = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
 
223
  return faceid_embeds
224
 
225
  @torch.inference_mode()
@@ -237,13 +238,13 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
237
  img = to_tensor(image)
238
  img = torch.unsqueeze(img, 0)
239
  img = img.float().cuda()
240
- out = self.bise_net(img)[0] #1,19,512,512
241
- parsing_anno = out.squeeze(0).cpu().numpy().argmax(0)
242
 
243
  im = np.array(image_resize_PIL)
244
  vis_im = im.copy().astype(np.uint8)
245
  stride=1
246
- vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
247
  vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
248
  vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255
249
 
@@ -253,7 +254,7 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
253
  index = np.where(vis_parsing_anno == pi)
254
  vis_parsing_anno_color[index[0], index[1], :] = self.part_colors[pi]
255
 
256
- vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
257
  vis_parsing_anno_color = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0)
258
 
259
  return vis_parsing_anno_color, vis_parsing_anno
@@ -282,23 +283,20 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
282
 
283
  return face_caption
284
 
285
-
286
-
287
  @torch.inference_mode()
288
  def get_prepare_facemask(self, input_image_file):
 
289
  vis_parsing_anno_color, vis_parsing_anno = self.parsing_face_mask(input_image_file)
290
  parsing_mask_list = masks_for_unique_values(vis_parsing_anno)
291
 
292
  key_parsing_mask_list = {}
293
  key_list = ["Face", "Left_Ear", "Right_Ear", "Left_Eye", "Right_Eye", "Nose", "Upper_Lip", "Lower_Lip"]
294
-
295
-
296
  processed_keys = set()
297
  for key, mask_image in parsing_mask_list.items():
298
  if key in key_list:
299
  if "_" in key:
300
  prefix = key.split("_")[1]
301
- if prefix in processed_keys:
302
  continue
303
  else:
304
  key_parsing_mask_list[key] = mask_image
@@ -320,6 +318,7 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
320
  device: Optional[torch.device] = None,
321
  ):
322
  device = device or self._execution_device
 
323
  face_caption_align, key_parsing_mask_list_align = process_text_with_markers(face_caption, key_parsing_mask_list)
324
 
325
  prompt_face = prompt + "Detail:" + face_caption_align
@@ -335,9 +334,11 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
335
  prompt_text_only = prompt_face.replace("<|facial|>", "").replace("<|image|>", "")
336
  tokenizer = self.tokenizer
337
  facial_token_id = tokenizer.convert_tokens_to_ids(facial_token)
338
- image_token_id = None
 
339
  clean_input_id, image_token_mask, facial_token_mask = tokenize_and_mask_noun_phrases_ends(
340
  prompt_face, image_token_id, facial_token_id, tokenizer)
 
341
  image_token_idx, image_token_idx_mask, facial_token_idx, facial_token_idx_mask = prepare_image_token_idx(
342
  image_token_mask, facial_token_mask, num_id_images, max_num_facials )
343
 
@@ -352,6 +353,7 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
352
  clip_image_processor = CLIPImageProcessor()
353
 
354
  num_facial_part = len(key_parsing_mask_list)
 
355
  for key in key_parsing_mask_list:
356
  key_mask=key_parsing_mask_list[key]
357
  facial_mask.append(transform_mask(key_mask))
@@ -361,6 +363,7 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
361
 
362
  padding_ficial_clip_image = torch.zeros_like(torch.zeros([1, 3, 224, 224]))
363
  padding_ficial_mask = torch.zeros_like(torch.zeros([1, image_size, image_size]))
 
364
  if num_facial_part < max_num_facials:
365
  facial_clip_image += [torch.zeros_like(padding_ficial_clip_image) for _ in range(max_num_facials - num_facial_part) ]
366
  facial_mask += [ torch.zeros_like(padding_ficial_mask) for _ in range(max_num_facials - num_facial_part)]
@@ -368,7 +371,7 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
368
  facial_clip_image = torch.stack(facial_clip_image, dim=1).squeeze(0)
369
  facial_mask = torch.stack(facial_mask, dim=0).squeeze(dim=1)
370
 
371
- return facial_clip_image, facial_mask
372
 
373
  @torch.no_grad()
374
  def __call__(
@@ -393,12 +396,9 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
393
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
394
  callback_steps: int = 1,
395
  input_id_images: PipelineImageInput = None,
396
- reference_id_images: PipelineImageInput =None,
397
  start_merge_step: int = 0,
398
  class_tokens_mask: Optional[torch.LongTensor] = None,
399
  prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
400
- retouching: bool=False,
401
- need_safetycheck: bool=True,
402
  ):
403
  # 0. Default height and width to unet
404
  height = height or self.unet.config.sample_size * self.vae_scale_factor
@@ -424,7 +424,7 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
424
  if prompt is not None and isinstance(prompt, str):
425
  batch_size = 1
426
  elif prompt is not None and isinstance(prompt, list):
427
- batch_size = len(prompt) #TODO
428
  else:
429
  batch_size = prompt_embeds.shape[0]
430
 
@@ -432,8 +432,7 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
432
  do_classifier_free_guidance = guidance_scale >= 1.0
433
  input_image_file = input_id_images[0]
434
 
435
-
436
- faceid_embeds = self.get_prepare_faceid(face_image=input_image_file)
437
  face_caption = self.get_prepare_llva_caption(input_image_file)
438
  key_parsing_mask_list, vis_parsing_anno_color = self.get_prepare_facemask(input_image_file)
439
 
@@ -445,13 +444,14 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
445
  (
446
  prompt_text_only,
447
  clean_input_id,
448
- key_parsing_mask_list_align,
449
- facial_token_mask,
450
- facial_token_idx,
451
  facial_token_idx_mask,
452
  ) = self.encode_prompt_with_trigger_word(
453
  prompt = prompt,
454
  face_caption = face_caption,
 
455
  key_parsing_mask_list=key_parsing_mask_list,
456
  device=device,
457
  max_num_facials = 5,
@@ -463,18 +463,20 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
463
 
464
  # 4. Encode input prompt without the trigger word for delayed conditioning
465
  encoder_hidden_states = self.text_encoder(clean_input_id.to(device))[0]
 
466
  prompt_embeds = self._encode_prompt(
467
  prompt_text_only,
468
  device=device,
469
  num_images_per_prompt=num_images_per_prompt,
470
  do_classifier_free_guidance=True,
471
  negative_prompt=negative_prompt,
472
- )
473
  negative_encoder_hidden_states_text_only = prompt_embeds[0:num_images_per_prompt]
474
  encoder_hidden_states_text_only = prompt_embeds[num_images_per_prompt:]
475
 
476
  # 5. Prepare the input ID images
477
- prompt_tokens_faceid, uncond_prompt_tokens_faceid = self.get_image_embeds(faceid_embeds, face_image=input_image_file, s_scale=0.0, shortcut=True)
 
478
  facial_clip_image, facial_mask = self.get_prepare_clip_image(input_image_file, key_parsing_mask_list_align, image_size=512, max_num_facials=5)
479
  facial_clip_images = facial_clip_image.unsqueeze(0).to(device, dtype=self.torch_dtype)
480
  facial_token_mask = facial_token_mask.to(device)
@@ -483,11 +485,13 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
483
 
484
  cross_attention_kwargs = {}
485
 
486
- # 6. Get the update text embeddingx
487
  prompt_embeds_facial, uncond_prompt_embeds_facial = self.get_facial_embeds(encoder_hidden_states, negative_encoder_hidden_states, \
488
  facial_clip_images, facial_token_mask, facial_token_idx_mask)
 
489
  prompt_embeds = torch.cat([prompt_embeds_facial, prompt_tokens_faceid], dim=1)
490
  negative_prompt_embeds = torch.cat([uncond_prompt_embeds_facial, uncond_prompt_tokens_faceid], dim=1)
 
491
  prompt_embeds = self._encode_prompt(
492
  prompt,
493
  device,
@@ -519,8 +523,8 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
519
 
520
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
521
  (
522
- null_prompt_embeds,
523
- augmented_prompt_embeds,
524
  text_prompt_embeds,
525
  ) = prompt_embeds.chunk(3)
526
 
@@ -542,6 +546,7 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
542
  [null_prompt_embeds, augmented_prompt_embeds], dim=0
543
  )
544
 
 
545
  noise_pred = self.unet(
546
  latent_model_input,
547
  t,
@@ -574,25 +579,17 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
574
  if output_type == "latent":
575
  image = latents
576
  has_nsfw_concept = None
577
- elif output_type == "pil":
578
  # 9.1 Post-processing
579
  image = self.decode_latents(latents)
580
 
581
  # 9.2 Run safety checker
582
- if need_safetycheck:
583
- image, has_nsfw_concept = self.run_safety_checker(
584
- image, device, prompt_embeds.dtype
585
- )
586
- else:
587
- has_nsfw_concept = None
588
 
589
- # 9.3 Convert to PIL list
590
- image = self.numpy_to_pil(image)
591
-
592
- # if retouching:
593
- # after_retouching = self.skin_retouching(image[0])
594
- # if OutputKeys.OUTPUT_IMG in after_retouching:
595
- # image = [Image.fromarray(cv2.cvtColor(after_retouching[OutputKeys.OUTPUT_IMG], cv2.COLOR_BGR2RGB))]
596
  else:
597
  # 9.1 Post-processing
598
  image = self.decode_latents(latents)
@@ -602,7 +599,6 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
602
  image, device, prompt_embeds.dtype
603
  )
604
 
605
-
606
  # Offload last model to CPU
607
  if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
608
  self.final_offload_hook.offload()
@@ -614,3 +610,10 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
614
  images=image, nsfw_content_detected=has_nsfw_concept
615
  )
616
 
 
 
 
 
 
 
 
 
5
  from PIL import Image
6
  import torch
7
  from torchvision import transforms
8
+ from insightface.app import FaceAnalysis
9
+ ### insight-face installation can be found at https://github.com/deepinsight/insightface
10
  from safetensors import safe_open
11
  from huggingface_hub.utils import validate_hf_hub_args
12
  from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 
16
  from functions import process_text_with_markers, masks_for_unique_values, fetch_mask_raw_image, tokenize_and_mask_noun_phrases_ends, prepare_image_token_idx
17
  from functions import ProjPlusModel, masks_for_unique_values
18
  from attention import Consistent_IPAttProcessor, Consistent_AttProcessor, FacialEncoder
 
 
 
 
 
 
 
 
19
 
20
+ ### Model can be imported from https://github.com/zllrunning/face-parsing.PyTorch?tab=readme-ov-file
21
+ ### We use the ckpt of 79999_iter.pth: https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812
22
+ ### Thanks for the open source of face-parsing model.
23
+ from models.BiSeNet.model import BiSeNet
24
 
25
  PipelineImageInput = Union[
26
  PIL.Image.Image,
 
29
  List[torch.FloatTensor],
30
  ]
31
 
32
+ ### Download the pretrained model from huggingface and put it locally, then place the model in a local directory and specify the directory location.
33
  class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
34
 
35
  @validate_hf_hub_args
 
40
  subfolder: str = '',
41
  trigger_word_ID: str = '<|image|>',
42
  trigger_word_facial: str = '<|facial|>',
43
+ image_encoder_path: str = 'laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
44
  torch_dtype = torch.float16,
45
  num_tokens = 4,
46
  lora_rank= 128,
47
  **kwargs,
48
  ):
49
+ self.lora_rank = lora_rank
50
  self.torch_dtype = torch_dtype
51
  self.num_tokens = num_tokens
52
  self.set_ip_adapter()
 
65
  ### BiSeNet
66
  self.bise_net = BiSeNet(n_classes = 19)
67
  self.bise_net.cuda()
68
+ self.bise_net_cp='JackAILab/ConsistentID/face_parsing.pth'
69
  self.bise_net.load_state_dict(torch.load(self.bise_net_cp))
70
  self.bise_net.eval()
71
  # Colors for all 20 parts
 
79
  [255, 0, 255], [255, 85, 255], [255, 170, 255],
80
  [0, 255, 255], [85, 255, 255], [170, 255, 255]]
81
 
82
+ ### LLVA (Optional)
83
+ self.llva_model_path = "liuhaotian/llava-v1.5-13b" # TODO
84
+ # IMPORTANT! Download the openai/clip-vit-large-patch14-336 model and specify the model path in config.json ("mm_vision_tower": "openai/clip-vit-large-patch14-336").
85
  self.llva_prompt = "Describe this person's facial features for me, including face, ears, eyes, nose, and mouth."
86
  self.llva_tokenizer, self.llva_model, self.llva_image_processor, self.llva_context_len = None,None,None,None #load_pretrained_model(self.llva_model_path)
87
 
 
89
  cross_attention_dim=self.unet.config.cross_attention_dim,
90
  id_embeddings_dim=512,
91
  clip_embeddings_dim=self.image_encoder.config.hidden_size,
92
+ num_tokens=self.num_tokens, # 4 - inspirsed by IPAdapter and Midjourney
93
  ).to(self.device, dtype=self.torch_dtype)
94
  self.FacialEncoder = FacialEncoder(self.image_encoder).to(self.device, dtype=self.torch_dtype)
95
 
 
 
96
  # Load the main state dict first.
97
  cache_dir = kwargs.pop("cache_dir", None)
98
  force_download = kwargs.pop("force_download", False)
 
185
  multi_facial_embeds = torch.stack(hidden_states)
186
  uncond_multi_facial_embeds = torch.stack(uncond_hidden_states)
187
 
188
+ # condition
189
  facial_prompt_embeds = self.FacialEncoder(prompt_embeds, multi_facial_embeds, facial_token_masks, valid_facial_token_idx_mask)
190
 
191
+ # uncondition
192
  uncond_facial_prompt_embeds = self.FacialEncoder(negative_prompt_embeds, uncond_multi_facial_embeds, facial_token_masks, valid_facial_token_idx_mask)
193
 
194
  return facial_prompt_embeds, uncond_facial_prompt_embeds
 
200
  clip_image = clip_image.to(self.device, dtype=self.torch_dtype)
201
  clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
202
  uncond_clip_image_embeds = self.image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[-2]
203
+
204
  faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
205
  image_prompt_tokens = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale)
206
  uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, shortcut=shortcut, scale=s_scale)
207
+
208
  return image_prompt_tokens, uncond_image_prompt_embeds
209
 
210
  def set_scale(self, scale):
 
220
  faceid_embeds = torch.zeros_like(torch.empty((1, 512)))
221
  else:
222
  faceid_embeds = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
223
+
224
  return faceid_embeds
225
 
226
  @torch.inference_mode()
 
238
  img = to_tensor(image)
239
  img = torch.unsqueeze(img, 0)
240
  img = img.float().cuda()
241
+ out = self.bise_net(img)[0]
242
+ parsing_anno = out.squeeze(0).cpu().numpy().argmax(0)
243
 
244
  im = np.array(image_resize_PIL)
245
  vis_im = im.copy().astype(np.uint8)
246
  stride=1
247
+ vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
248
  vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
249
  vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255
250
 
 
254
  index = np.where(vis_parsing_anno == pi)
255
  vis_parsing_anno_color[index[0], index[1], :] = self.part_colors[pi]
256
 
257
+ vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
258
  vis_parsing_anno_color = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0)
259
 
260
  return vis_parsing_anno_color, vis_parsing_anno
 
283
 
284
  return face_caption
285
 
 
 
286
  @torch.inference_mode()
287
  def get_prepare_facemask(self, input_image_file):
288
+
289
  vis_parsing_anno_color, vis_parsing_anno = self.parsing_face_mask(input_image_file)
290
  parsing_mask_list = masks_for_unique_values(vis_parsing_anno)
291
 
292
  key_parsing_mask_list = {}
293
  key_list = ["Face", "Left_Ear", "Right_Ear", "Left_Eye", "Right_Eye", "Nose", "Upper_Lip", "Lower_Lip"]
 
 
294
  processed_keys = set()
295
  for key, mask_image in parsing_mask_list.items():
296
  if key in key_list:
297
  if "_" in key:
298
  prefix = key.split("_")[1]
299
+ if prefix in processed_keys:
300
  continue
301
  else:
302
  key_parsing_mask_list[key] = mask_image
 
318
  device: Optional[torch.device] = None,
319
  ):
320
  device = device or self._execution_device
321
+
322
  face_caption_align, key_parsing_mask_list_align = process_text_with_markers(face_caption, key_parsing_mask_list)
323
 
324
  prompt_face = prompt + "Detail:" + face_caption_align
 
334
  prompt_text_only = prompt_face.replace("<|facial|>", "").replace("<|image|>", "")
335
  tokenizer = self.tokenizer
336
  facial_token_id = tokenizer.convert_tokens_to_ids(facial_token)
337
+ image_token_id = None
338
+
339
  clean_input_id, image_token_mask, facial_token_mask = tokenize_and_mask_noun_phrases_ends(
340
  prompt_face, image_token_id, facial_token_id, tokenizer)
341
+
342
  image_token_idx, image_token_idx_mask, facial_token_idx, facial_token_idx_mask = prepare_image_token_idx(
343
  image_token_mask, facial_token_mask, num_id_images, max_num_facials )
344
 
 
353
  clip_image_processor = CLIPImageProcessor()
354
 
355
  num_facial_part = len(key_parsing_mask_list)
356
+
357
  for key in key_parsing_mask_list:
358
  key_mask=key_parsing_mask_list[key]
359
  facial_mask.append(transform_mask(key_mask))
 
363
 
364
  padding_ficial_clip_image = torch.zeros_like(torch.zeros([1, 3, 224, 224]))
365
  padding_ficial_mask = torch.zeros_like(torch.zeros([1, image_size, image_size]))
366
+
367
  if num_facial_part < max_num_facials:
368
  facial_clip_image += [torch.zeros_like(padding_ficial_clip_image) for _ in range(max_num_facials - num_facial_part) ]
369
  facial_mask += [ torch.zeros_like(padding_ficial_mask) for _ in range(max_num_facials - num_facial_part)]
 
371
  facial_clip_image = torch.stack(facial_clip_image, dim=1).squeeze(0)
372
  facial_mask = torch.stack(facial_mask, dim=0).squeeze(dim=1)
373
 
374
+ return facial_clip_image, facial_mask
375
 
376
  @torch.no_grad()
377
  def __call__(
 
396
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
397
  callback_steps: int = 1,
398
  input_id_images: PipelineImageInput = None,
 
399
  start_merge_step: int = 0,
400
  class_tokens_mask: Optional[torch.LongTensor] = None,
401
  prompt_embeds_text_only: Optional[torch.FloatTensor] = None,
 
 
402
  ):
403
  # 0. Default height and width to unet
404
  height = height or self.unet.config.sample_size * self.vae_scale_factor
 
424
  if prompt is not None and isinstance(prompt, str):
425
  batch_size = 1
426
  elif prompt is not None and isinstance(prompt, list):
427
+ batch_size = len(prompt)
428
  else:
429
  batch_size = prompt_embeds.shape[0]
430
 
 
432
  do_classifier_free_guidance = guidance_scale >= 1.0
433
  input_image_file = input_id_images[0]
434
 
435
+ faceid_embeds = self.get_prepare_faceid(face_image=input_image_file)
 
436
  face_caption = self.get_prepare_llva_caption(input_image_file)
437
  key_parsing_mask_list, vis_parsing_anno_color = self.get_prepare_facemask(input_image_file)
438
 
 
444
  (
445
  prompt_text_only,
446
  clean_input_id,
447
+ key_parsing_mask_list_align,
448
+ facial_token_mask,
449
+ facial_token_idx,
450
  facial_token_idx_mask,
451
  ) = self.encode_prompt_with_trigger_word(
452
  prompt = prompt,
453
  face_caption = face_caption,
454
+ # prompt_2=None,
455
  key_parsing_mask_list=key_parsing_mask_list,
456
  device=device,
457
  max_num_facials = 5,
 
463
 
464
  # 4. Encode input prompt without the trigger word for delayed conditioning
465
  encoder_hidden_states = self.text_encoder(clean_input_id.to(device))[0]
466
+
467
  prompt_embeds = self._encode_prompt(
468
  prompt_text_only,
469
  device=device,
470
  num_images_per_prompt=num_images_per_prompt,
471
  do_classifier_free_guidance=True,
472
  negative_prompt=negative_prompt,
473
+ )
474
  negative_encoder_hidden_states_text_only = prompt_embeds[0:num_images_per_prompt]
475
  encoder_hidden_states_text_only = prompt_embeds[num_images_per_prompt:]
476
 
477
  # 5. Prepare the input ID images
478
+ prompt_tokens_faceid, uncond_prompt_tokens_faceid = self.get_image_embeds(faceid_embeds, face_image=input_image_file, s_scale=1.0, shortcut=False)
479
+
480
  facial_clip_image, facial_mask = self.get_prepare_clip_image(input_image_file, key_parsing_mask_list_align, image_size=512, max_num_facials=5)
481
  facial_clip_images = facial_clip_image.unsqueeze(0).to(device, dtype=self.torch_dtype)
482
  facial_token_mask = facial_token_mask.to(device)
 
485
 
486
  cross_attention_kwargs = {}
487
 
488
+ # 6. Get the update text embedding
489
  prompt_embeds_facial, uncond_prompt_embeds_facial = self.get_facial_embeds(encoder_hidden_states, negative_encoder_hidden_states, \
490
  facial_clip_images, facial_token_mask, facial_token_idx_mask)
491
+
492
  prompt_embeds = torch.cat([prompt_embeds_facial, prompt_tokens_faceid], dim=1)
493
  negative_prompt_embeds = torch.cat([uncond_prompt_embeds_facial, uncond_prompt_tokens_faceid], dim=1)
494
+
495
  prompt_embeds = self._encode_prompt(
496
  prompt,
497
  device,
 
523
 
524
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
525
  (
526
+ null_prompt_embeds,
527
+ augmented_prompt_embeds,
528
  text_prompt_embeds,
529
  ) = prompt_embeds.chunk(3)
530
 
 
546
  [null_prompt_embeds, augmented_prompt_embeds], dim=0
547
  )
548
 
549
+ # predict the noise residual
550
  noise_pred = self.unet(
551
  latent_model_input,
552
  t,
 
579
  if output_type == "latent":
580
  image = latents
581
  has_nsfw_concept = None
582
+ elif output_type == "pil":
583
  # 9.1 Post-processing
584
  image = self.decode_latents(latents)
585
 
586
  # 9.2 Run safety checker
587
+ image, has_nsfw_concept = self.run_safety_checker(
588
+ image, device, prompt_embeds.dtype
589
+ )
 
 
 
590
 
591
+ # 9.3 Convert to PIL
592
+ image = self.numpy_to_pil(image)
 
 
 
 
 
593
  else:
594
  # 9.1 Post-processing
595
  image = self.decode_latents(latents)
 
599
  image, device, prompt_embeds.dtype
600
  )
601
 
 
602
  # Offload last model to CPU
603
  if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
604
  self.final_offload_hook.offload()
 
610
  images=image, nsfw_content_detected=has_nsfw_concept
611
  )
612
 
613
+
614
+
615
+
616
+
617
+
618
+
619
+