adaface-neurips committed
Commit 57aa583 · 1 Parent(s): 7e25f02

Separate generate_id_prompt_embeds, Remove nsfw and other useless code

app.py CHANGED
@@ -9,7 +9,7 @@ from datetime import datetime
 from PIL import Image
 from diffusers.utils import load_image
 from diffusers import EulerDiscreteScheduler
-from pipline_StableDiffusion_ConsistentID import ConsistentIDStableDiffusionPipeline
+from pipline_ConsistentID import ConsistentIDPipeline
 from huggingface_hub import hf_hub_download
 ### Model can be imported from https://github.com/zllrunning/face-parsing.PyTorch?tab=readme-ov-file
 ### We use the ckpt of 79999_iter.pth: https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812
@@ -31,7 +31,7 @@ consistentID_path = hf_hub_download(repo_id="JackAILab/ConsistentID",
                                     local_dir="./models")
 
 ### Load base model
-pipe = ConsistentIDStableDiffusionPipeline.from_pretrained(
+pipe = ConsistentIDPipeline.from_pretrained(
     base_model_path,
     torch_dtype=torch.float16,
     safety_checker=None, # use_safetensors=True,
@@ -54,7 +54,7 @@ pipe.load_ConsistentID_model(
     weight_name=os.path.basename(consistentID_path),
     trigger_word="img",
 )
-#pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
+pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
 
 ### Load to cuda
 pipe.to(device)
@@ -65,20 +65,18 @@ pipe.FacialEncoder.to(device)
 
 @spaces.GPU
 def process(selected_template_images, custom_image, prompt,
-            negative_prompt, prompt_selected, retouching, model_selected_tab,
+            negative_prompt, prompt_selected, model_selected_tab,
             prompt_selected_tab, width, height, merge_steps, seed_set):
 
+    # The gradio UI only supports one image at a time.
     if model_selected_tab==0:
-        select_images = load_image(Image.open(selected_template_images))
+        subj_images = load_image(Image.open(selected_template_images))
     else:
-        select_images = load_image(Image.fromarray(custom_image))
+        subj_images = load_image(Image.fromarray(custom_image))
 
     if prompt_selected_tab==0:
         prompt = prompt_selected
         negative_prompt = ""
-        need_safetycheck = False
-    else:
-        need_safetycheck = True
 
     # hyper-parameter
     num_steps = 50
@@ -110,17 +108,14 @@ def process(selected_template_images, custom_image, prompt,
         prompt=prompt,
         width=width,
         height=height,
-        input_id_images=select_images,
+        input_subj_image_objs=subj_images,
         negative_prompt=negative_prompt,
         num_images_per_prompt=1,
         num_inference_steps=num_steps,
         start_merge_step=merge_steps,
         generator=generator,
-        retouching=retouching,
-        need_safetycheck=need_safetycheck,
     ).images[0]
 
-    current_date = datetime.today()
     return np.array(images)
 
 # Gets the templates
@@ -174,7 +169,6 @@ with gr.Blocks(title="ConsistentID Demo") as demo:
             for i, tab in enumerate(prompt_selected_tabs):
                 tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[prompt_selected_tab])
 
-            retouching = gr.Checkbox(label="face retouching",value=False,visible=False)
             width = gr.Slider(label="image width",minimum=256,maximum=768,value=512,step=8)
             height = gr.Slider(label="image height",minimum=256,maximum=768,value=768,step=8)
             width.release(lambda x,y: min(1280-x,y), inputs=[width,height], outputs=[height])
@@ -191,7 +185,7 @@ with gr.Blocks(title="ConsistentID Demo") as demo:
             - At the same time, use prompt with \"man\" or \"woman\" instead of \"person\" as much as possible, as that may cause the model to be confused whether the protagonist is male or female.
            - Due to insufficient graphics memory on the demo server, there is an upper limit on the resolution for generating samples. We will support the generation of SDXL as soon as possible<br/><br/>
            ''')
-    btn.click(fn=process, inputs=[selected_template_images,custom_image,prompt,nagetive_prompt,prompt_selected,retouching
-                                  ,model_selected_tab,prompt_selected_tab,width,height,merge_steps,seed_set], outputs=out)
+    btn.click(fn=process, inputs=[selected_template_images, custom_image,prompt, nagetive_prompt, prompt_selected,
+                                  model_selected_tab, prompt_selected_tab, width, height, merge_steps, seed_set], outputs=out)
 
 demo.launch(server_name='0.0.0.0', ssl_verify=False)
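
For reference, here is a minimal sketch of how the refactored demo invokes the pipeline, using only the module, class, and keyword-argument names visible in the app.py diff above. The helper function, prompt, and the width/height/seed values are illustrative, and pipe is assumed to be set up (base model plus ConsistentID weights, moved to CUDA) exactly as app.py does.

import torch
from diffusers.utils import load_image
# Module and class as renamed in this commit.
from pipline_ConsistentID import ConsistentIDPipeline

# Hypothetical helper around the demo's process(): pipe must already have the base model
# and the ConsistentID weights loaded, as in app.py above.
def generate_one(pipe: ConsistentIDPipeline, subj_image_path: str, prompt: str, seed: int = 2024):
    subj_image = load_image(subj_image_path)     # the demo passes a single PIL image
    generator = torch.Generator(device="cuda").manual_seed(seed)
    return pipe(
        prompt=prompt,
        negative_prompt="",
        width=512,
        height=768,
        input_subj_image_objs=subj_image,        # renamed from input_id_images in this commit
        num_images_per_prompt=1,
        num_inference_steps=50,
        start_merge_step=30,
        generator=generator,
        # retouching= and need_safetycheck= were removed in this commit
    ).images[0]
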
pipline_StableDiffusion_ConsistentID.py → pipline_ConsistentID.py RENAMED
@@ -5,7 +5,6 @@ import numpy as np
 from PIL import Image
 import torch
 from torchvision import transforms
-from torchvision.utils import save_image
 from insightface.app import FaceAnalysis
 ### insight-face installation can be found at https://github.com/deepinsight/insightface
 from safetensors import safe_open
@@ -27,7 +26,7 @@ PipelineImageInput = Union[
 ]
 
 ### Download the pretrained model from huggingface and put it locally, then place the model in a local directory and specify the directory location.
-class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
+class ConsistentIDPipeline(StableDiffusionPipeline):
 
     def cuda(self, dtype=torch.float16, use_xformers=False):
         self.to('cuda', dtype)
@@ -191,7 +190,8 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
     # parsed_image_parts2 is a batched tensor of parsed_image_parts with bs=1. It only contains the facial areas of one input image.
     # clip_encoder maps image parts to image-space diffusion prompts.
     # Then the facial class token embeddings are replaced with the fused (multi_facial_embeds, prompt_embeds[class_tokens_mask]).
-    def get_local_facial_embeds(self, prompt_embeds, negative_prompt_embeds, parsed_image_parts2, facial_token_masks, valid_facial_token_idx_mask):
+    def get_local_facial_embeds(self, prompt_embeds, uncond_prompt_embeds, parsed_image_parts2,
+                                facial_token_masks, valid_facial_token_idx_mask, calc_uncond=True):
 
         hidden_states = []
         uncond_hidden_states = []
@@ -209,8 +209,10 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
         # multi_facial_embeds: [1, 5, 257, 1280].
         facial_prompt_embeds = self.FacialEncoder(prompt_embeds, multi_facial_embeds, facial_token_masks, valid_facial_token_idx_mask)
 
+        if not calc_uncond:
+            return facial_prompt_embeds, None
         # unconditional prompt.
-        uncond_facial_prompt_embeds = self.FacialEncoder(negative_prompt_embeds, uncond_multi_facial_embeds, facial_token_masks, valid_facial_token_idx_mask)
+        uncond_facial_prompt_embeds = self.FacialEncoder(uncond_prompt_embeds, uncond_multi_facial_embeds, facial_token_masks, valid_facial_token_idx_mask)
 
         return facial_prompt_embeds, uncond_facial_prompt_embeds
 
@@ -285,9 +287,8 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
         return vis_parsing_anno_color, vis_parsing_anno
 
     @torch.inference_mode()
-    def extract_facemask(self, input_image_file):
-
-        vis_parsing_anno_color, vis_parsing_anno = self.parse_face_mask(input_image_file)
+    def extract_facemask(self, input_image_obj):
+        vis_parsing_anno_color, vis_parsing_anno = self.parse_face_mask(input_image_obj)
         parsing_mask_list = masks_for_unique_values(vis_parsing_anno)
 
         key_parsing_mask_dict = {}
@@ -307,12 +308,11 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
 
         return key_parsing_mask_dict, vis_parsing_anno_color
 
-    def encode_prompt_with_trigger_word(
+    def augment_prompt_with_trigger_word(
         self,
         prompt: str,
         face_caption: str,
         key_parsing_mask_dict = None,
-        image_token = "<|image|>",
         facial_token = "<|facial|>",
         max_num_facials = 5,
         num_id_images: int = 1,
@@ -335,9 +335,9 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
         face_caption_align = ""
 
         # Remove "<|facial|>" from prompt_face.
-        # prompt_text_only: 'A person, police officer, half body shot Detail:
+        # augmented_prompt: 'A person, police officer, half body shot Detail:
         # The person has one nose , two ears , two eyes , and a mouth , '
-        prompt_text_only = prompt_face.replace("<|facial|>", "").replace("<|image|>", "")
+        augmented_prompt = prompt_face.replace("<|facial|>", "").replace("<|image|>", "")
         tokenizer = self.tokenizer
         facial_token_id = tokenizer.convert_tokens_to_ids(facial_token)
         image_token_id = None
@@ -350,10 +350,10 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
         image_token_idx, image_token_idx_mask, facial_token_idx, facial_token_idx_mask = \
             prepare_image_token_idx(image_token_mask, facial_token_mask, num_id_images, max_num_facials)
 
-        return prompt_text_only, clean_input_id, key_parsing_mask_dict_align, facial_token_mask, facial_token_idx, facial_token_idx_mask
+        return augmented_prompt, clean_input_id, key_parsing_mask_dict_align, facial_token_mask, facial_token_idx, facial_token_idx_mask
 
     @torch.inference_mode()
-    def extract_parsed_image_parts(self, input_image_file, key_parsing_mask_dict, image_size=512, max_num_facials=5):
+    def extract_parsed_image_parts(self, input_image_obj, key_parsing_mask_dict, image_size=512, max_num_facials=5):
         facial_masks = []
         parsed_image_parts = []
         key_masked_raw_images_dict = {}
@@ -365,7 +365,7 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
         for key in key_parsing_mask_dict:
             key_mask=key_parsing_mask_dict[key]
             facial_masks.append(transform_mask(key_mask))
-            key_masked_raw_image = apply_mask_to_raw_image(input_image_file, key_mask)
+            key_masked_raw_image = apply_mask_to_raw_image(input_image_obj, key_mask)
             key_masked_raw_images_dict[key] = key_masked_raw_image
             # clip_preprocessor normalizes key_masked_raw_image, so that (masked) zero pixels become non-zero.
             # It also resizes the image to 224x224.
@@ -384,14 +384,88 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
 
         return parsed_image_parts, facial_masks, key_masked_raw_images_dict
 
-    def remove_unet(self):
-        unet = self.unet
-        self.unet = edict()
+    # Release the unet and vae models to save memory.
+    def release_unet_vae(self):
+        unet = edict()
         # Only keep the config and in_channels attributes that are used in the pipeline.
-        self.unet.config = unet.config
-        self.unet.in_channels = unet.in_channels
-        del unet
+        unet.config = self.unet.config
+        unet.in_channels = self.unet.in_channels
+        self.unet = unet
+        self.vae = None
+
+    # input_subj_image_obj: an Image object.
+    def generate_id_prompt_embeds(self, prompt, negative_prompt, input_subj_image_obj, device, calc_uncond=True):
+        faceid_embeds = self.extract_faceid(face_image=input_subj_image_obj)
+        face_caption = "The person has one nose, two eyes, two ears, and a mouth."
+        key_parsing_mask_dict, vis_parsing_anno_color = self.extract_facemask(input_subj_image_obj)
+
+        augmented_prompt, clean_input_id, key_parsing_mask_dict_align, \
+        facial_token_mask, facial_token_idx, facial_token_idx_mask \
+            = self.augment_prompt_with_trigger_word(
+                prompt = prompt,
+                face_caption = face_caption,
+                key_parsing_mask_dict=key_parsing_mask_dict,
+                device=device,
+                max_num_facials = 5,
+                num_id_images = 1
+            )
+
+        text_embeds, uncond_text_embeds = self.encode_prompt(
+            augmented_prompt,
+            device=device,
+            num_images_per_prompt=1,
+            do_classifier_free_guidance=calc_uncond,
+            negative_prompt=negative_prompt,
+        )
+
+        # 5. Prepare the input ID images
+        # global_id_embeds: [1, 4, 768]
+        # get_global_id_embeds() extrats OpenCLIP embeddings from the input image and map them to global face prompt embeddings.
+        global_id_embeds, uncond_global_id_embeds = \
+            self.get_global_id_embeds(faceid_embeds, face_image=input_subj_image_obj, s_scale=1.0, shortcut=False)
+
+        # parsed_image_parts: [5, 3, 224, 224]. 5 parts, each part is a 3-channel 224x224 image (resized by CLIP Preprocessor).
+        parsed_image_parts, facial_masks, key_masked_raw_images_dict = \
+            self.extract_parsed_image_parts(input_subj_image_obj, key_parsing_mask_dict_align, image_size=512, max_num_facials=5)
+        parsed_image_parts2 = parsed_image_parts.unsqueeze(0).to(device, dtype=self.torch_dtype)
+        facial_token_mask = facial_token_mask.to(device)
+        facial_token_idx_mask = facial_token_idx_mask.to(device)
 
+        # key_masked_raw_images_dict: ['Right_Eye', 'Right_Ear', 'Nose', 'Upper_Lip']
+        # for key in key_masked_raw_images_dict:
+        #     key_masked_raw_images_dict[key].save(f"{key}.png")
+
+        # 6. Get the update text embedding
+        # parsed_image_parts2: the facial areas of the input image
+        # text_local_id_embeds: [1, 77, 768]
+        # text_local_id_embeds only differs with text_global_id_embeds on 4 tokens, and is identical
+        # to text_global_id_embeds on the rest 73 tokens.
+        # get_local_facial_embeds() maps parsed_image_parts2 to multi_facial_embeds, and then replaces the class tokens in prompt_embeds
+        # with the fused (id_embeds, prompt_embeds[class_tokens_mask]) whose indices are specified by class_tokens_mask.
+        # parsed_image_parts2: [1, 5, 3, 224, 224]
+        text_local_id_embeds, uncond_text_local_id_embeds = \
+            self.get_local_facial_embeds(text_embeds, uncond_text_embeds, \
+                                         parsed_image_parts2, facial_token_mask, facial_token_idx_mask,
+                                         calc_uncond=calc_uncond)
+
+        # text_global_id_embeds, text_local_global_id_embeds: [1, 81, 768]
+        text_global_id_embeds = torch.cat([text_embeds, global_id_embeds], dim=1)
+        text_local_global_id_embeds = torch.cat([text_local_id_embeds, global_id_embeds], dim=1)
+
+        if calc_uncond:
+            uncond_text_global_id_embeds = torch.cat([uncond_text_local_id_embeds, uncond_global_id_embeds], dim=1)
+            coarse_prompt_embeds = torch.cat([uncond_text_global_id_embeds, text_global_id_embeds], dim=0)
+            fine_prompt_embeds = torch.cat([uncond_text_global_id_embeds, text_local_global_id_embeds], dim=0)
+        else:
+            coarse_prompt_embeds = text_global_id_embeds
+            fine_prompt_embeds = text_local_global_id_embeds
+
+        # fine_prompt_embeds: the conditional part is
+        # (text_global_id_embeds + text_local_global_id_embeds) / 2.
+        fine_prompt_embeds = (coarse_prompt_embeds + fine_prompt_embeds) / 2
+
+        return coarse_prompt_embeds, fine_prompt_embeds
+
     @torch.no_grad()
     def __call__(
         self,
@@ -414,12 +488,8 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
         target_size: Optional[Tuple[int, int]] = None,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
-        input_id_images: PipelineImageInput = None,
+        input_subj_image_objs: PipelineImageInput = None,
         start_merge_step: int = 0,
-        class_tokens_mask: Optional[torch.LongTensor] = None,
-        text_embeds: Optional[torch.FloatTensor] = None,
-        retouching: bool=False,
-        need_safetycheck: bool=True,
     ):
         # 0. Default height and width to unet
         height = height or self.unet.config.sample_size * self.vae_scale_factor
@@ -438,8 +508,8 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
             prompt_embeds,
            negative_prompt_embeds,
         )
-        if not isinstance(input_id_images, list):
-            input_id_images = [input_id_images]
+        if not isinstance(input_subj_image_objs, list):
+            input_subj_image_objs = [input_subj_image_objs]
 
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -451,81 +521,11 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
 
         device = self._execution_device
         do_classifier_free_guidance = guidance_scale >= 1.0
-
-        input_image_file = input_id_images[0]
-
-        faceid_embeds = self.extract_faceid(face_image=input_image_file)
-        face_caption = "The person has one nose, two eyes, two ears, and a mouth."
-        key_parsing_mask_dict, vis_parsing_anno_color = self.extract_facemask(input_image_file)
-
         assert do_classifier_free_guidance
 
         # 3. Encode input prompt
-        num_id_images = len(input_id_images)
-
-        prompt_text_only, clean_input_id, key_parsing_mask_dict_align, \
-        facial_token_mask, facial_token_idx, facial_token_idx_mask \
-        = self.encode_prompt_with_trigger_word(
-            prompt = prompt,
-            face_caption = face_caption,
-            key_parsing_mask_dict=key_parsing_mask_dict,
-            device=device,
-            max_num_facials = 5,
-            num_id_images= num_id_images
-        )
-
-        # 4. Encode input prompt without the trigger word for delayed conditioning
-        # encoder_hidden_states = self.text_encoder(clean_input_id.to(device))[0]
-
-        prompt_embeds = self._encode_prompt(
-            prompt_text_only,
-            device=device,
-            num_images_per_prompt=num_images_per_prompt,
-            do_classifier_free_guidance=True,
-            negative_prompt=negative_prompt,
-        )
-        uncond_text_embeds = prompt_embeds[0:num_images_per_prompt]
-        text_embeds = prompt_embeds[num_images_per_prompt:]
-
-        # 5. Prepare the input ID images
-        # global_id_embeds: [1, 4, 768]
-        # get_global_id_embeds() extrats OpenCLIP embeddings from the input image and map them to global face prompt embeddings.
-        global_id_embeds, uncond_global_id_embeds = \
-            self.get_global_id_embeds(faceid_embeds, face_image=input_image_file, s_scale=1.0, shortcut=False)
-
-        # parsed_image_parts: [5, 3, 224, 224]. 5 parts, each part is a 3-channel 224x224 image (resized by CLIP Preprocessor).
-        parsed_image_parts, facial_masks, key_masked_raw_images_dict = \
-            self.extract_parsed_image_parts(input_image_file, key_parsing_mask_dict_align, image_size=512, max_num_facials=5)
-        parsed_image_parts2 = parsed_image_parts.unsqueeze(0).to(device, dtype=self.torch_dtype)
-        facial_token_mask = facial_token_mask.to(device)
-        facial_token_idx_mask = facial_token_idx_mask.to(device)
-
-        # key_masked_raw_images_dict: ['Right_Eye', 'Right_Ear', 'Nose', 'Upper_Lip']
-        # for key in key_masked_raw_images_dict:
-        #     key_masked_raw_images_dict[key].save(f"{key}.png")
-
-        # 6. Get the update text embedding
-        # parsed_image_parts2: the facial areas of the input image
-        # text_local_id_embeds: [1, 77, 768]
-        # text_local_id_embeds only differs with text_global_id_embeds on 4 tokens, and is identical
-        # to text_global_id_embeds on the rest 73 tokens.
-        # get_local_facial_embeds() maps parsed_image_parts2 to multi_facial_embeds, and then replaces the class tokens in prompt_embeds
-        # with the fused (id_embeds, prompt_embeds[class_tokens_mask]) whose indices are specified by class_tokens_mask.
-        # parsed_image_parts2: [1, 5, 3, 224, 224]
-        text_local_id_embeds, uncond_text_local_id_embeds = \
-            self.get_local_facial_embeds(text_embeds, uncond_text_embeds, \
-                parsed_image_parts2, facial_token_mask, facial_token_idx_mask)
-
-        uncond_text_global_id_embeds = torch.cat([uncond_text_local_id_embeds, uncond_global_id_embeds], dim=1)
-        # text_global_id_embeds, text_local_global_id_embeds: [1, 81, 768]
-        text_global_id_embeds = torch.cat([text_embeds, global_id_embeds], dim=1)
-        text_local_global_id_embeds = torch.cat([text_local_id_embeds, global_id_embeds], dim=1)
-
-        coarse_prompt_embeds = torch.cat([uncond_text_global_id_embeds, text_global_id_embeds], dim=0)
-        fine_prompt_embeds = torch.cat([uncond_text_global_id_embeds, text_local_global_id_embeds], dim=0)
-        # fine_prompt_embeds: the conditional part is
-        # (text_global_id_embeds + text_local_global_id_embeds) / 2.
-        fine_prompt_embeds = (coarse_prompt_embeds + fine_prompt_embeds) / 2
+        coarse_prompt_embeds, fine_prompt_embeds = \
+            self.generate_id_prompt_embeds(prompt, negative_prompt, input_subj_image_objs[0], device)
 
         # 7. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -538,7 +538,7 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
             num_channels_latents,
             height,
             width,
-            prompt_embeds.dtype,
+            coarse_prompt_embeds.dtype,
             device,
             generator,
             latents,
@@ -584,48 +584,32 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
                 ).prev_sample
 
                 # call the callback, if provided
-                if i == len(timesteps) - 1 or (
-                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
-                ):
+                if i == len(timesteps) - 1 or \
+                   ( (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 ):
                     progress_bar.update()
                     if callback is not None and i % callback_steps == 0:
                         callback(i, t, latents)
 
         if output_type == "latent":
             image = latents
-            has_nsfw_concept = None
         elif output_type == "pil":
             # 9.1 Post-processing
             image = self.decode_latents(latents)
-
-            # 9.2 Run safety checker
-            if need_safetycheck:
-                image, has_nsfw_concept = self.run_safety_checker(
-                    image, device, prompt_embeds.dtype
-                )
-            else:
-                has_nsfw_concept = None
-
             # 9.3 Convert to PIL
             image = self.numpy_to_pil(image)
         else:
             # 9.1 Post-processing
             image = self.decode_latents(latents)
 
-            # 9.2 Run safety checker
-            image, has_nsfw_concept = self.run_safety_checker(
-                image, device, prompt_embeds.dtype
-            )
-
         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.final_offload_hook.offload()
 
         if not return_dict:
-            return (image, has_nsfw_concept)
+            return (image, None)
 
         return StableDiffusionPipelineOutput(
-            images=image, nsfw_content_detected=has_nsfw_concept
+            images=image, nsfw_content_detected=None
         )
 
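
The pipeline diff above factors all ID conditioning into the standalone generate_id_prompt_embeds() method, with release_unet_vae() freeing the UNet and VAE when the pipeline is used only as an ID-prompt encoder. Below is a hedged sketch of that separated usage; it assumes the ConsistentID weights have already been loaded as in app.py, and the checkpoint path, image path, and prompt are placeholders.

import torch
from PIL import Image
from pipline_ConsistentID import ConsistentIDPipeline

device = "cuda"
# Placeholder base checkpoint path; the demo uses its own base_model_path.
pipe = ConsistentIDPipeline.from_pretrained("./models/sd15_base", torch_dtype=torch.float16)
# Assumed: pipe.load_ConsistentID_model(...) has been called as in app.py.
pipe.to(device)

subj_image = Image.open("face.jpg").convert("RGB")

# Returns the coarse and fine ID-conditioned prompt embeddings; with calc_uncond=True each
# tensor stacks [unconditional; conditional] rows along the batch dim for classifier-free guidance.
coarse_prompt_embeds, fine_prompt_embeds = pipe.generate_id_prompt_embeds(
    prompt="A man img, police officer, half body shot",
    negative_prompt="",
    input_subj_image_obj=subj_image,
    device=device,
    calc_uncond=True,
)

# When only the embeddings are needed, release the heavy submodules:
# release_unet_vae() keeps just unet.config / unet.in_channels and sets vae to None.
pipe.release_unet_vae()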