Spaces: adaface-neurips (Runtime error)

Commit 57aa583
adaface-neurips committed
1 Parent(s): 7e25f02

Separate generate_id_prompt_embeds, Remove nsfw and other useless code

Files changed:
- app.py (+10, -16)
- pipline_StableDiffusion_ConsistentID.py → pipline_ConsistentID.py (renamed, +104, -120)
app.py CHANGED
@@ -9,7 +9,7 @@ from datetime import datetime
 from PIL import Image
 from diffusers.utils import load_image
 from diffusers import EulerDiscreteScheduler
-from pipline_StableDiffusion_ConsistentID import ConsistentIDStableDiffusionPipeline
+from pipline_ConsistentID import ConsistentIDPipeline
 from huggingface_hub import hf_hub_download
 ### Model can be imported from https://github.com/zllrunning/face-parsing.PyTorch?tab=readme-ov-file
 ### We use the ckpt of 79999_iter.pth: https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812
@@ -31,7 +31,7 @@ consistentID_path = hf_hub_download(repo_id="JackAILab/ConsistentID",
                                     local_dir="./models")

 ### Load base model
-pipe = ConsistentIDStableDiffusionPipeline.from_pretrained(
+pipe = ConsistentIDPipeline.from_pretrained(
     base_model_path,
     torch_dtype=torch.float16,
     safety_checker=None, # use_safetensors=True,
@@ -54,7 +54,7 @@ pipe.load_ConsistentID_model(
     weight_name=os.path.basename(consistentID_path),
     trigger_word="img",
 )
-
+pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)

 ### Load to cuda
 pipe.to(device)
@@ -65,20 +65,18 @@ pipe.FacialEncoder.to(device)

 @spaces.GPU
 def process(selected_template_images, custom_image, prompt,
-            negative_prompt, prompt_selected,
+            negative_prompt, prompt_selected, model_selected_tab,
             prompt_selected_tab, width, height, merge_steps, seed_set):

+    # The gradio UI only supports one image at a time.
     if model_selected_tab==0:
-
+        subj_images = load_image(Image.open(selected_template_images))
     else:
-
+        subj_images = load_image(Image.fromarray(custom_image))

     if prompt_selected_tab==0:
         prompt = prompt_selected
         negative_prompt = ""
-        need_safetycheck = False
-    else:
-        need_safetycheck = True

     # hyper-parameter
     num_steps = 50
@@ -110,17 +108,14 @@ def process(selected_template_images, custom_image, prompt,
         prompt=prompt,
         width=width,
         height=height,
-
+        input_subj_image_objs=subj_images,
         negative_prompt=negative_prompt,
         num_images_per_prompt=1,
         num_inference_steps=num_steps,
         start_merge_step=merge_steps,
         generator=generator,
-        retouching=retouching,
-        need_safetycheck=need_safetycheck,
     ).images[0]

-    current_date = datetime.today()
     return np.array(images)

 # Gets the templates
@@ -174,7 +169,6 @@ with gr.Blocks(title="ConsistentID Demo") as demo:
     for i, tab in enumerate(prompt_selected_tabs):
         tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[prompt_selected_tab])

-    retouching = gr.Checkbox(label="face retouching",value=False,visible=False)
     width = gr.Slider(label="image width",minimum=256,maximum=768,value=512,step=8)
     height = gr.Slider(label="image height",minimum=256,maximum=768,value=768,step=8)
     width.release(lambda x,y: min(1280-x,y), inputs=[width,height], outputs=[height])
@@ -191,7 +185,7 @@ with gr.Blocks(title="ConsistentID Demo") as demo:
     - At the same time, use prompt with \"man\" or \"woman\" instead of \"person\" as much as possible, as that may cause the model to be confused whether the protagonist is male or female.
     - Due to insufficient graphics memory on the demo server, there is an upper limit on the resolution for generating samples. We will support the generation of SDXL as soon as possible<br/><br/>
     ''')
-    btn.click(fn=process, inputs=[selected_template_images,custom_image,prompt,nagetive_prompt,prompt_selected,
-
+    btn.click(fn=process, inputs=[selected_template_images, custom_image,prompt, nagetive_prompt, prompt_selected,
+                                  model_selected_tab, prompt_selected_tab, width, height, merge_steps, seed_set], outputs=out)

 demo.launch(server_name='0.0.0.0', ssl_verify=False)

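For orientation, the updated app.py boils down to the call pattern sketched below. This is a minimal sketch reassembled from the added lines above, not the exact demo code: base_model_path, the ConsistentID weight path, and face.jpg are placeholders, and load_ConsistentID_model() is shown only with the keyword arguments visible in this diff (any preceding arguments are elided).

import os
import torch
from PIL import Image
from diffusers import EulerDiscreteScheduler
from diffusers.utils import load_image
from pipline_ConsistentID import ConsistentIDPipeline

base_model_path = "./models/sd15-base"               # placeholder SD 1.5 checkpoint directory
consistentID_path = "./models/ConsistentID-v1.bin"   # placeholder; app.py obtains this via hf_hub_download

# Build the renamed pipeline and swap in the Euler scheduler, which this commit now does explicitly.
pipe = ConsistentIDPipeline.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
    safety_checker=None,
)
pipe.load_ConsistentID_model(
    # ...arguments not shown in this diff are omitted here...
    weight_name=os.path.basename(consistentID_path),
    trigger_word="img",
)
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")

# One subject image per call, passed through the new input_subj_image_objs argument.
subj_image = load_image(Image.open("face.jpg"))
image = pipe(
    prompt="A man img, police officer, half body shot",   # placeholder prompt containing the trigger word
    negative_prompt="",
    width=512,
    height=768,
    input_subj_image_objs=subj_image,
    num_images_per_prompt=1,
    num_inference_steps=50,
    start_merge_step=30,                                   # placeholder; the demo exposes this as "merge_steps"
    generator=torch.Generator(device="cuda").manual_seed(42),
).images[0]
image.save("output.png")
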
pipline_StableDiffusion_ConsistentID.py → pipline_ConsistentID.py RENAMED
@@ -5,7 +5,6 @@ import numpy as np
 from PIL import Image
 import torch
 from torchvision import transforms
-from torchvision.utils import save_image
 from insightface.app import FaceAnalysis
 ### insight-face installation can be found at https://github.com/deepinsight/insightface
 from safetensors import safe_open
@@ -27,7 +26,7 @@ PipelineImageInput = Union[
 ]

 ### Download the pretrained model from huggingface and put it locally, then place the model in a local directory and specify the directory location.
-class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
+class ConsistentIDPipeline(StableDiffusionPipeline):

     def cuda(self, dtype=torch.float16, use_xformers=False):
         self.to('cuda', dtype)
@@ -191,7 +190,8 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
     # parsed_image_parts2 is a batched tensor of parsed_image_parts with bs=1. It only contains the facial areas of one input image.
     # clip_encoder maps image parts to image-space diffusion prompts.
     # Then the facial class token embeddings are replaced with the fused (multi_facial_embeds, prompt_embeds[class_tokens_mask]).
-    def get_local_facial_embeds(self, prompt_embeds,
+    def get_local_facial_embeds(self, prompt_embeds, uncond_prompt_embeds, parsed_image_parts2,
+                                facial_token_masks, valid_facial_token_idx_mask, calc_uncond=True):

         hidden_states = []
         uncond_hidden_states = []
@@ -209,8 +209,10 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
         # multi_facial_embeds: [1, 5, 257, 1280].
         facial_prompt_embeds = self.FacialEncoder(prompt_embeds, multi_facial_embeds, facial_token_masks, valid_facial_token_idx_mask)

+        if not calc_uncond:
+            return facial_prompt_embeds, None
         # unconditional prompt.
-        uncond_facial_prompt_embeds = self.FacialEncoder(
+        uncond_facial_prompt_embeds = self.FacialEncoder(uncond_prompt_embeds, uncond_multi_facial_embeds, facial_token_masks, valid_facial_token_idx_mask)

         return facial_prompt_embeds, uncond_facial_prompt_embeds

@@ -285,9 +287,8 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
         return vis_parsing_anno_color, vis_parsing_anno

     @torch.inference_mode()
-    def extract_facemask(self,
-
-        vis_parsing_anno_color, vis_parsing_anno = self.parse_face_mask(input_image_file)
+    def extract_facemask(self, input_image_obj):
+        vis_parsing_anno_color, vis_parsing_anno = self.parse_face_mask(input_image_obj)
         parsing_mask_list = masks_for_unique_values(vis_parsing_anno)

         key_parsing_mask_dict = {}
@@ -307,12 +308,11 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):

         return key_parsing_mask_dict, vis_parsing_anno_color

-    def encode_prompt_with_trigger_word(
+    def augment_prompt_with_trigger_word(
         self,
         prompt: str,
         face_caption: str,
         key_parsing_mask_dict = None,
-        image_token = "<|image|>",
         facial_token = "<|facial|>",
         max_num_facials = 5,
         num_id_images: int = 1,
@@ -335,9 +335,9 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
            face_caption_align = ""

         # Remove "<|facial|>" from prompt_face.
-        #
+        # augmented_prompt: 'A person, police officer, half body shot Detail:
         # The person has one nose , two ears , two eyes , and a mouth , '
-
+        augmented_prompt = prompt_face.replace("<|facial|>", "").replace("<|image|>", "")
         tokenizer = self.tokenizer
         facial_token_id = tokenizer.convert_tokens_to_ids(facial_token)
         image_token_id = None
@@ -350,10 +350,10 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
         image_token_idx, image_token_idx_mask, facial_token_idx, facial_token_idx_mask = \
             prepare_image_token_idx(image_token_mask, facial_token_mask, num_id_images, max_num_facials)

-        return
+        return augmented_prompt, clean_input_id, key_parsing_mask_dict_align, facial_token_mask, facial_token_idx, facial_token_idx_mask

     @torch.inference_mode()
-    def extract_parsed_image_parts(self,
+    def extract_parsed_image_parts(self, input_image_obj, key_parsing_mask_dict, image_size=512, max_num_facials=5):
         facial_masks = []
         parsed_image_parts = []
         key_masked_raw_images_dict = {}
@@ -365,7 +365,7 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
         for key in key_parsing_mask_dict:
             key_mask=key_parsing_mask_dict[key]
             facial_masks.append(transform_mask(key_mask))
-            key_masked_raw_image = apply_mask_to_raw_image(
+            key_masked_raw_image = apply_mask_to_raw_image(input_image_obj, key_mask)
             key_masked_raw_images_dict[key] = key_masked_raw_image
             # clip_preprocessor normalizes key_masked_raw_image, so that (masked) zero pixels become non-zero.
             # It also resizes the image to 224x224.
@@ -384,14 +384,88 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):

         return parsed_image_parts, facial_masks, key_masked_raw_images_dict

-
-
-
+    # Release the unet and vae models to save memory.
+    def release_unet_vae(self):
+        unet = edict()
         # Only keep the config and in_channels attributes that are used in the pipeline.
-
-
-
+        unet.config = self.unet.config
+        unet.in_channels = self.unet.in_channels
+        self.unet = unet
+        self.vae = None
+
+    # input_subj_image_obj: an Image object.
+    def generate_id_prompt_embeds(self, prompt, negative_prompt, input_subj_image_obj, device, calc_uncond=True):
+        faceid_embeds = self.extract_faceid(face_image=input_subj_image_obj)
+        face_caption = "The person has one nose, two eyes, two ears, and a mouth."
+        key_parsing_mask_dict, vis_parsing_anno_color = self.extract_facemask(input_subj_image_obj)
+
+        augmented_prompt, clean_input_id, key_parsing_mask_dict_align, \
+        facial_token_mask, facial_token_idx, facial_token_idx_mask \
+            = self.augment_prompt_with_trigger_word(
+                prompt = prompt,
+                face_caption = face_caption,
+                key_parsing_mask_dict=key_parsing_mask_dict,
+                device=device,
+                max_num_facials = 5,
+                num_id_images = 1
+            )
+
+        text_embeds, uncond_text_embeds = self.encode_prompt(
+            augmented_prompt,
+            device=device,
+            num_images_per_prompt=1,
+            do_classifier_free_guidance=calc_uncond,
+            negative_prompt=negative_prompt,
+        )
+
+        # 5. Prepare the input ID images
+        # global_id_embeds: [1, 4, 768]
+        # get_global_id_embeds() extrats OpenCLIP embeddings from the input image and map them to global face prompt embeddings.
+        global_id_embeds, uncond_global_id_embeds = \
+            self.get_global_id_embeds(faceid_embeds, face_image=input_subj_image_obj, s_scale=1.0, shortcut=False)
+
+        # parsed_image_parts: [5, 3, 224, 224]. 5 parts, each part is a 3-channel 224x224 image (resized by CLIP Preprocessor).
+        parsed_image_parts, facial_masks, key_masked_raw_images_dict = \
+            self.extract_parsed_image_parts(input_subj_image_obj, key_parsing_mask_dict_align, image_size=512, max_num_facials=5)
+        parsed_image_parts2 = parsed_image_parts.unsqueeze(0).to(device, dtype=self.torch_dtype)
+        facial_token_mask = facial_token_mask.to(device)
+        facial_token_idx_mask = facial_token_idx_mask.to(device)

+        # key_masked_raw_images_dict: ['Right_Eye', 'Right_Ear', 'Nose', 'Upper_Lip']
+        # for key in key_masked_raw_images_dict:
+        #     key_masked_raw_images_dict[key].save(f"{key}.png")
+
+        # 6. Get the update text embedding
+        # parsed_image_parts2: the facial areas of the input image
+        # text_local_id_embeds: [1, 77, 768]
+        # text_local_id_embeds only differs with text_global_id_embeds on 4 tokens, and is identical
+        # to text_global_id_embeds on the rest 73 tokens.
+        # get_local_facial_embeds() maps parsed_image_parts2 to multi_facial_embeds, and then replaces the class tokens in prompt_embeds
+        # with the fused (id_embeds, prompt_embeds[class_tokens_mask]) whose indices are specified by class_tokens_mask.
+        # parsed_image_parts2: [1, 5, 3, 224, 224]
+        text_local_id_embeds, uncond_text_local_id_embeds = \
+            self.get_local_facial_embeds(text_embeds, uncond_text_embeds, \
+                                         parsed_image_parts2, facial_token_mask, facial_token_idx_mask,
+                                         calc_uncond=calc_uncond)
+
+        # text_global_id_embeds, text_local_global_id_embeds: [1, 81, 768]
+        text_global_id_embeds = torch.cat([text_embeds, global_id_embeds], dim=1)
+        text_local_global_id_embeds = torch.cat([text_local_id_embeds, global_id_embeds], dim=1)
+
+        if calc_uncond:
+            uncond_text_global_id_embeds = torch.cat([uncond_text_local_id_embeds, uncond_global_id_embeds], dim=1)
+            coarse_prompt_embeds = torch.cat([uncond_text_global_id_embeds, text_global_id_embeds], dim=0)
+            fine_prompt_embeds = torch.cat([uncond_text_global_id_embeds, text_local_global_id_embeds], dim=0)
+        else:
+            coarse_prompt_embeds = text_global_id_embeds
+            fine_prompt_embeds = text_local_global_id_embeds
+
+        # fine_prompt_embeds: the conditional part is
+        # (text_global_id_embeds + text_local_global_id_embeds) / 2.
+        fine_prompt_embeds = (coarse_prompt_embeds + fine_prompt_embeds) / 2
+
+        return coarse_prompt_embeds, fine_prompt_embeds
+
     @torch.no_grad()
     def __call__(
         self,
@@ -414,12 +488,8 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
         target_size: Optional[Tuple[int, int]] = None,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
-
+        input_subj_image_objs: PipelineImageInput = None,
         start_merge_step: int = 0,
-        class_tokens_mask: Optional[torch.LongTensor] = None,
-        text_embeds: Optional[torch.FloatTensor] = None,
-        retouching: bool=False,
-        need_safetycheck: bool=True,
     ):
         # 0. Default height and width to unet
         height = height or self.unet.config.sample_size * self.vae_scale_factor
@@ -438,8 +508,8 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
             prompt_embeds,
             negative_prompt_embeds,
         )
-        if not isinstance(
-
+        if not isinstance(input_subj_image_objs, list):
+            input_subj_image_objs = [input_subj_image_objs]

         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -451,81 +521,11 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):

         device = self._execution_device
         do_classifier_free_guidance = guidance_scale >= 1.0
-
-        input_image_file = input_id_images[0]
-
-        faceid_embeds = self.extract_faceid(face_image=input_image_file)
-        face_caption = "The person has one nose, two eyes, two ears, and a mouth."
-        key_parsing_mask_dict, vis_parsing_anno_color = self.extract_facemask(input_image_file)
-
         assert do_classifier_free_guidance

         # 3. Encode input prompt
-
-
-        prompt_text_only, clean_input_id, key_parsing_mask_dict_align, \
-        facial_token_mask, facial_token_idx, facial_token_idx_mask \
-            = self.encode_prompt_with_trigger_word(
-                prompt = prompt,
-                face_caption = face_caption,
-                key_parsing_mask_dict=key_parsing_mask_dict,
-                device=device,
-                max_num_facials = 5,
-                num_id_images= num_id_images
-            )
-
-        # 4. Encode input prompt without the trigger word for delayed conditioning
-        # encoder_hidden_states = self.text_encoder(clean_input_id.to(device))[0]
-
-        prompt_embeds = self._encode_prompt(
-            prompt_text_only,
-            device=device,
-            num_images_per_prompt=num_images_per_prompt,
-            do_classifier_free_guidance=True,
-            negative_prompt=negative_prompt,
-        )
-        uncond_text_embeds = prompt_embeds[0:num_images_per_prompt]
-        text_embeds = prompt_embeds[num_images_per_prompt:]
-
-        # 5. Prepare the input ID images
-        # global_id_embeds: [1, 4, 768]
-        # get_global_id_embeds() extrats OpenCLIP embeddings from the input image and map them to global face prompt embeddings.
-        global_id_embeds, uncond_global_id_embeds = \
-            self.get_global_id_embeds(faceid_embeds, face_image=input_image_file, s_scale=1.0, shortcut=False)
-
-        # parsed_image_parts: [5, 3, 224, 224]. 5 parts, each part is a 3-channel 224x224 image (resized by CLIP Preprocessor).
-        parsed_image_parts, facial_masks, key_masked_raw_images_dict = \
-            self.extract_parsed_image_parts(input_image_file, key_parsing_mask_dict_align, image_size=512, max_num_facials=5)
-        parsed_image_parts2 = parsed_image_parts.unsqueeze(0).to(device, dtype=self.torch_dtype)
-        facial_token_mask = facial_token_mask.to(device)
-        facial_token_idx_mask = facial_token_idx_mask.to(device)
-
-        # key_masked_raw_images_dict: ['Right_Eye', 'Right_Ear', 'Nose', 'Upper_Lip']
-        # for key in key_masked_raw_images_dict:
-        #     key_masked_raw_images_dict[key].save(f"{key}.png")
-
-        # 6. Get the update text embedding
-        # parsed_image_parts2: the facial areas of the input image
-        # text_local_id_embeds: [1, 77, 768]
-        # text_local_id_embeds only differs with text_global_id_embeds on 4 tokens, and is identical
-        # to text_global_id_embeds on the rest 73 tokens.
-        # get_local_facial_embeds() maps parsed_image_parts2 to multi_facial_embeds, and then replaces the class tokens in prompt_embeds
-        # with the fused (id_embeds, prompt_embeds[class_tokens_mask]) whose indices are specified by class_tokens_mask.
-        # parsed_image_parts2: [1, 5, 3, 224, 224]
-        text_local_id_embeds, uncond_text_local_id_embeds = \
-            self.get_local_facial_embeds(text_embeds, uncond_text_embeds, \
-                                         parsed_image_parts2, facial_token_mask, facial_token_idx_mask)
-
-        uncond_text_global_id_embeds = torch.cat([uncond_text_local_id_embeds, uncond_global_id_embeds], dim=1)
-        # text_global_id_embeds, text_local_global_id_embeds: [1, 81, 768]
-        text_global_id_embeds = torch.cat([text_embeds, global_id_embeds], dim=1)
-        text_local_global_id_embeds = torch.cat([text_local_id_embeds, global_id_embeds], dim=1)
-
-        coarse_prompt_embeds = torch.cat([uncond_text_global_id_embeds, text_global_id_embeds], dim=0)
-        fine_prompt_embeds = torch.cat([uncond_text_global_id_embeds, text_local_global_id_embeds], dim=0)
-        # fine_prompt_embeds: the conditional part is
-        # (text_global_id_embeds + text_local_global_id_embeds) / 2.
-        fine_prompt_embeds = (coarse_prompt_embeds + fine_prompt_embeds) / 2
+        coarse_prompt_embeds, fine_prompt_embeds = \
+            self.generate_id_prompt_embeds(prompt, negative_prompt, input_subj_image_objs[0], device)

         # 7. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -538,7 +538,7 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
             num_channels_latents,
             height,
             width,
-
+            coarse_prompt_embeds.dtype,
             device,
             generator,
             latents,
@@ -584,48 +584,32 @@ class ConsistentIDStableDiffusionPipeline(StableDiffusionPipeline):
                 ).prev_sample

                 # call the callback, if provided
-                if i == len(timesteps) - 1 or
-
-                ):
+                if i == len(timesteps) - 1 or \
+                    ( (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 ):
                     progress_bar.update()
                     if callback is not None and i % callback_steps == 0:
                         callback(i, t, latents)

         if output_type == "latent":
             image = latents
-            has_nsfw_concept = None
         elif output_type == "pil":
             # 9.1 Post-processing
             image = self.decode_latents(latents)
-
-            # 9.2 Run safety checker
-            if need_safetycheck:
-                image, has_nsfw_concept = self.run_safety_checker(
-                    image, device, prompt_embeds.dtype
-                )
-            else:
-                has_nsfw_concept = None
-
             # 9.3 Convert to PIL
             image = self.numpy_to_pil(image)
         else:
             # 9.1 Post-processing
             image = self.decode_latents(latents)

-            # 9.2 Run safety checker
-            image, has_nsfw_concept = self.run_safety_checker(
-                image, device, prompt_embeds.dtype
-            )
-
         # Offload last model to CPU
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.final_offload_hook.offload()

         if not return_dict:
-            return (image,
+            return (image, None)

         return StableDiffusionPipelineOutput(
-            images=image, nsfw_content_detected=
+            images=image, nsfw_content_detected=None
         )

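The other half of the commit, visible in the large added block above, is that all ID-conditioning work now lives in generate_id_prompt_embeds(), and __call__() reduces to a single call into it. That also makes the ID prompt embeddings usable on their own. The sketch below is a rough illustration of that standalone use, assuming pipe has already been built and loaded with ConsistentID weights as in app.py; face.jpg and the prompts are placeholders.

import torch
from PIL import Image

device = "cuda"
subj_image = Image.open("face.jpg").convert("RGB")        # placeholder subject photo

# coarse_prompt_embeds / fine_prompt_embeds: [2, 81, 768] when calc_uncond=True
# (unconditional half stacked before the conditional half), [1, 81, 768] otherwise.
coarse_prompt_embeds, fine_prompt_embeds = pipe.generate_id_prompt_embeds(
    "A man img, portrait photo",                           # placeholder prompt containing the trigger word
    "blurry, low quality",                                 # placeholder negative prompt
    subj_image,
    device,
    calc_uncond=True,
)

# If only the embeddings are needed (e.g. to condition another generator), the UNet and VAE
# can be dropped afterwards; release_unet_vae() keeps just unet.config and unet.in_channels.
pipe.release_unet_vae()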