adaface-neurips commited on
Commit
3736ac5
·
1 Parent(s): b80f000
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. .gitignore +10 -0
  3. ConsistentID/.gitattributes +38 -0
  4. ConsistentID/.gitignore +5 -0
  5. ConsistentID/LICENSE +21 -0
  6. ConsistentID/README.md +13 -0
  7. ConsistentID/__init__.py +0 -0
  8. ConsistentID/app.py +168 -0
  9. ConsistentID/images/templates/3f8d901770014c1b8f7f261971f0e92.png +3 -0
  10. ConsistentID/images/templates/6577b962b6346df03fea83211daaf48.png +0 -0
  11. ConsistentID/images/templates/75583964a834abe33b72f52b1a98e84.png +3 -0
  12. ConsistentID/images/templates/c9fe4c2d5ddbc5670dde47fc465c48b.jpg +0 -0
  13. ConsistentID/lib/BiSeNet/6.jpg +0 -0
  14. ConsistentID/lib/BiSeNet/__init__.py +2 -0
  15. ConsistentID/lib/BiSeNet/evaluate.py +95 -0
  16. ConsistentID/lib/BiSeNet/face_dataset.py +106 -0
  17. ConsistentID/lib/BiSeNet/hair.png +0 -0
  18. ConsistentID/lib/BiSeNet/logger.py +23 -0
  19. ConsistentID/lib/BiSeNet/loss.py +72 -0
  20. ConsistentID/lib/BiSeNet/makeup.py +129 -0
  21. ConsistentID/lib/BiSeNet/makeup/116_1.png +0 -0
  22. ConsistentID/lib/BiSeNet/makeup/116_3.png +0 -0
  23. ConsistentID/lib/BiSeNet/makeup/116_lip_ori.png +0 -0
  24. ConsistentID/lib/BiSeNet/makeup/116_ori.png +0 -0
  25. ConsistentID/lib/BiSeNet/model.py +282 -0
  26. ConsistentID/lib/BiSeNet/modules/__init__.py +5 -0
  27. ConsistentID/lib/BiSeNet/modules/bn.py +130 -0
  28. ConsistentID/lib/BiSeNet/modules/deeplab.py +84 -0
  29. ConsistentID/lib/BiSeNet/modules/dense.py +42 -0
  30. ConsistentID/lib/BiSeNet/modules/functions.py +234 -0
  31. ConsistentID/lib/BiSeNet/modules/misc.py +21 -0
  32. ConsistentID/lib/BiSeNet/modules/residual.py +88 -0
  33. ConsistentID/lib/BiSeNet/modules/src/checks.h +15 -0
  34. ConsistentID/lib/BiSeNet/modules/src/inplace_abn.cpp +95 -0
  35. ConsistentID/lib/BiSeNet/modules/src/inplace_abn.h +88 -0
  36. ConsistentID/lib/BiSeNet/modules/src/inplace_abn_cpu.cpp +119 -0
  37. ConsistentID/lib/BiSeNet/modules/src/inplace_abn_cuda.cu +333 -0
  38. ConsistentID/lib/BiSeNet/modules/src/inplace_abn_cuda_half.cu +275 -0
  39. ConsistentID/lib/BiSeNet/modules/src/utils/checks.h +15 -0
  40. ConsistentID/lib/BiSeNet/modules/src/utils/common.h +49 -0
  41. ConsistentID/lib/BiSeNet/modules/src/utils/cuda.cuh +71 -0
  42. ConsistentID/lib/BiSeNet/optimizer.py +69 -0
  43. ConsistentID/lib/BiSeNet/prepropess_data.py +38 -0
  44. ConsistentID/lib/BiSeNet/resnet.py +109 -0
  45. ConsistentID/lib/BiSeNet/test.py +90 -0
  46. ConsistentID/lib/BiSeNet/train.py +179 -0
  47. ConsistentID/lib/BiSeNet/transform.py +129 -0
  48. ConsistentID/lib/attention.py +287 -0
  49. ConsistentID/lib/functions.py +606 -0
  50. ConsistentID/lib/pipeline_ConsistentID.py +605 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ models/ensemble/ar18-unet/diffusion_pytorch_model.safetensors filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ models/awportrait/*
2
+ models/awportrait
3
+ __pycache__/*
4
+ __pycache__
5
+ samples-ada/*
6
+ samples-ada
7
+ models/ensemble/awp14-unet/*
8
+ models/ensemble/awp14-unet
9
+ .gradio/certificate.pem
10
+
ConsistentID/.gitattributes ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ images/templates/3f8d901770014c1b8f7f261971f0e92.png filter=lfs diff=lfs merge=lfs -text
37
+ images/templates/75583964a834abe33b72f52b1a98e84.png filter=lfs diff=lfs merge=lfs -text
38
+ models/LLaVA/images/demo_cli.gif filter=lfs diff=lfs merge=lfs -text
ConsistentID/.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ __pycache__/*
2
+ __pycache__
3
+ /*.png
4
+ models/insightface
5
+ models/Realistic_Vision*
ConsistentID/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Jiehui Huang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
ConsistentID/README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: ConsistentID
3
+ emoji: 🔥
4
+ colorFrom: yellow
5
+ colorTo: yellow
6
+ sdk: gradio
7
+ sdk_version: 4.37.2
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
ConsistentID/__init__.py ADDED
File without changes
ConsistentID/app.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import os
4
+ import glob
5
+ import spaces
6
+ import numpy as np
7
+
8
+ from PIL import Image
9
+ from diffusers.utils import load_image
10
+ from diffusers import EulerDiscreteScheduler
11
+ from ConsistentID.lib.pipeline_ConsistentID import ConsistentIDPipeline
12
+ import argparse
13
+ parser = argparse.ArgumentParser()
14
+ parser.add_argument('--base_model_path', type=str,
15
+ default="models/Realistic_Vision_V4.0_noVAE")
16
+ parser.add_argument('--gpu', type=int, default=0)
17
+ args = parser.parse_args()
18
+
19
+ device = f"cuda:{args.gpu}"
20
+
21
+ ### Load base model
22
+ pipe = ConsistentIDPipeline.from_pretrained(
23
+ args.base_model_path,
24
+ torch_dtype=torch.float16,
25
+ )
26
+
27
+ ### Load consistentID_model checkpoint
28
+ pipe.load_ConsistentID_model(
29
+ consistentID_weight_path="./models/ConsistentID-v1.bin",
30
+ bise_net_weight_path="./models/BiSeNet_pretrained_for_ConsistentID.pth",
31
+ )
32
+ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
33
+ pipe = pipe.to(device, torch.float16)
34
+
35
+ @spaces.GPU
36
+ def process(selected_template_images, custom_image, prompt,
37
+ negative_prompt, prompt_selected, model_selected_tab,
38
+ prompt_selected_tab, guidance_scale, width, height, merge_steps, seed_set):
39
+
40
+ # The gradio UI only supports one image at a time.
41
+ if model_selected_tab==0:
42
+ subj_images = load_image(Image.open(selected_template_images))
43
+ else:
44
+ subj_images = load_image(Image.fromarray(custom_image))
45
+
46
+ if prompt_selected_tab==0:
47
+ prompt = prompt_selected
48
+ negative_prompt = ""
49
+
50
+ # hyper-parameter
51
+ num_steps = 50
52
+ seed_set = torch.randint(0, 1000, (1,)).item()
53
+ # merge_steps = 30
54
+
55
+ if prompt == "":
56
+ prompt = "A man, in a forest"
57
+ prompt = "A man, with backpack, in a raining tropical forest, adventuring, holding a flashlight, in mist, seeking animals"
58
+ prompt = "A person, in a sowm, wearing santa hat and a scarf, with a cottage behind"
59
+ else:
60
+ #prompt=Enhance_prompt(prompt, Image.new('RGB', (200, 200), color = 'white'))
61
+ print(prompt)
62
+
63
+ if negative_prompt == "":
64
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality, blurry"
65
+
66
+ #Extend Prompt
67
+ #prompt = "cinematic photo," + prompt + ", 50mm photograph, half-length portrait, film, bokeh, professional, 4k, highly detailed"
68
+ #print(prompt)
69
+
70
+ negtive_prompt_group="((cross-eye)),((cross-eyed)),(((NFSW))),(nipple),((((ugly)))), (((duplicate))), ((morbid)), ((mutilated)), [out of frame], extra fingers, mutated hands, ((poorly drawn hands)), ((poorly drawn face)), (((mutation))), (((deformed))), ((ugly)), blurry, ((bad anatomy)), (((bad proportions))), ((extra limbs)), cloned face, (((disfigured))). out of frame, ugly, extra limbs, (bad anatomy), gross proportions, (malformed limbs), ((missing arms)), ((missing legs)), (((extra arms))), (((extra legs))), mutated hands, (fused fingers), (too many fingers), (((long neck)))"
71
+ negative_prompt = negative_prompt + negtive_prompt_group
72
+
73
+ # seed = torch.randint(0, 1000, (1,)).item()
74
+ generator = torch.Generator(device=device).manual_seed(seed_set)
75
+
76
+ images = pipe(
77
+ prompt=prompt,
78
+ width=width,
79
+ height=height,
80
+ input_subj_image_objs=subj_images,
81
+ negative_prompt=negative_prompt,
82
+ num_images_per_prompt=1,
83
+ num_inference_steps=num_steps,
84
+ guidance_scale=guidance_scale,
85
+ start_merge_step=merge_steps,
86
+ generator=generator,
87
+ ).images[0]
88
+
89
+ return np.array(images)
90
+
91
+ # Gets the templates
92
+ preset_template = glob.glob("./images/templates/*.png")
93
+ preset_template = preset_template + glob.glob("./images/templates/*.jpg")
94
+
95
+ with gr.Blocks(title="ConsistentID Demo") as demo:
96
+ gr.Markdown("# ConsistentID Demo")
97
+ gr.Markdown("\
98
+ Put the reference figure to be redrawn into the box below (There is a small probability of referensing failure. You can submit it repeatedly)")
99
+ gr.Markdown("\
100
+ If you find our work interesting, please leave a star in GitHub for us!<br>\
101
+ https://github.com/JackAILab/ConsistentID")
102
+ with gr.Row():
103
+ with gr.Column():
104
+ model_selected_tab = gr.State(0)
105
+ with gr.TabItem("template images") as template_images_tab:
106
+ template_gallery_list = [(i, i) for i in preset_template]
107
+ gallery = gr.Gallery(template_gallery_list,columns=[4], rows=[2], object_fit="contain", height="auto",show_label=False)
108
+
109
+ def select_function(evt: gr.SelectData):
110
+ return preset_template[evt.index]
111
+
112
+ selected_template_images = gr.Text(show_label=False, visible=False, placeholder="Selected")
113
+ gallery.select(select_function, None, selected_template_images)
114
+ with gr.TabItem("Upload Image") as upload_image_tab:
115
+ custom_image = gr.Image(label="Upload Image")
116
+
117
+ model_selected_tabs = [template_images_tab, upload_image_tab]
118
+ for i, tab in enumerate(model_selected_tabs):
119
+ tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[model_selected_tab])
120
+
121
+ with gr.Column():
122
+ prompt_selected_tab = gr.State(0)
123
+ with gr.TabItem("template prompts") as template_prompts_tab:
124
+ prompt_selected = gr.Dropdown(value="A person, police officer, half body shot", elem_id='dropdown', choices=[
125
+ "A woman in a wedding dress",
126
+ "A woman, queen, in a gorgeous palace",
127
+ "A man sitting at the beach with sunset",
128
+ "A person, police officer, half body shot",
129
+ "A man, sailor, in a boat above ocean",
130
+ "A women wearing headphone, listening music",
131
+ "A man, firefighter, half body shot"], label=f"prepared prompts")
132
+
133
+ with gr.TabItem("custom prompt") as custom_prompt_tab:
134
+ prompt = gr.Textbox(label="prompt",placeholder="A man/woman wearing a santa hat")
135
+ nagetive_prompt = gr.Textbox(label="negative prompt",placeholder="monochrome, lowres, bad anatomy, worst quality, low quality, blurry")
136
+
137
+ prompt_selected_tabs = [template_prompts_tab, custom_prompt_tab]
138
+ for i, tab in enumerate(prompt_selected_tabs):
139
+ tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[prompt_selected_tab])
140
+
141
+ guidance_scale = gr.Slider(
142
+ label="Guidance scale",
143
+ minimum=1.0,
144
+ maximum=10.0,
145
+ step=1.0,
146
+ value=5.0,
147
+ )
148
+
149
+ width = gr.Slider(label="image width",minimum=256,maximum=768,value=512,step=8)
150
+ height = gr.Slider(label="image height",minimum=256,maximum=768,value=512,step=8)
151
+ width.release(lambda x,y: min(1280-x,y), inputs=[width,height], outputs=[height])
152
+ height.release(lambda x,y: min(1280-y,x), inputs=[width,height], outputs=[width])
153
+ merge_steps = gr.Slider(label="step starting to merge facial details(30 is recommended)",minimum=10,maximum=50,value=30,step=1)
154
+ seed_set = gr.Slider(label="set the random seed for different results",minimum=1,maximum=2147483647,value=2024,step=1)
155
+
156
+ btn = gr.Button("Run")
157
+ with gr.Column():
158
+ out = gr.Image(label="Output")
159
+ gr.Markdown('''
160
+ N.B.:<br/>
161
+ - If the proportion of face in the image is too small, the probability of an error will be slightly higher, and the similarity will also significantly decrease.)
162
+ - At the same time, use prompt with \"man\" or \"woman\" instead of \"person\" as much as possible, as that may cause the model to be confused whether the protagonist is male or female.
163
+ - Due to insufficient graphics memory on the demo server, there is an upper limit on the resolution for generating samples. We will support the generation of SDXL as soon as possible<br/><br/>
164
+ ''')
165
+ btn.click(fn=process, inputs=[selected_template_images, custom_image,prompt, nagetive_prompt, prompt_selected,
166
+ model_selected_tab, prompt_selected_tab, guidance_scale, width, height, merge_steps, seed_set], outputs=out)
167
+
168
+ demo.launch(server_name='0.0.0.0', ssl_verify=False)
ConsistentID/images/templates/3f8d901770014c1b8f7f261971f0e92.png ADDED

Git LFS Details

  • SHA256: 4fa9319750b9927075934c40a180766e75ff539711293581dae6bac5963b9d05
  • Pointer size: 132 Bytes
  • Size of remote file: 2.06 MB
ConsistentID/images/templates/6577b962b6346df03fea83211daaf48.png ADDED
ConsistentID/images/templates/75583964a834abe33b72f52b1a98e84.png ADDED

Git LFS Details

  • SHA256: 318c942eb3cc8a1f9320b2ea84a88cd95067785c07f8ae1dd18fe6c4cf8e8282
  • Pointer size: 132 Bytes
  • Size of remote file: 7.54 MB
ConsistentID/images/templates/c9fe4c2d5ddbc5670dde47fc465c48b.jpg ADDED
ConsistentID/lib/BiSeNet/6.jpg ADDED
ConsistentID/lib/BiSeNet/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ #__init__.py
2
+ # from BiSeNet.model import *
ConsistentID/lib/BiSeNet/evaluate.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ from logger import setup_logger
5
+ import BiSeNet
6
+ from face_dataset import FaceMask
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.utils.data import DataLoader
11
+ import torch.nn.functional as F
12
+ import torch.distributed as dist
13
+
14
+ import os
15
+ import os.path as osp
16
+ import logging
17
+ import time
18
+ import numpy as np
19
+ from tqdm import tqdm
20
+ import math
21
+ from PIL import Image
22
+ import torchvision.transforms as transforms
23
+ import cv2
24
+
25
+ def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg'):
26
+ # Colors for all 20 parts
27
+ part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
28
+ [255, 0, 85], [255, 0, 170],
29
+ [0, 255, 0], [85, 255, 0], [170, 255, 0],
30
+ [0, 255, 85], [0, 255, 170],
31
+ [0, 0, 255], [85, 0, 255], [170, 0, 255],
32
+ [0, 85, 255], [0, 170, 255],
33
+ [255, 255, 0], [255, 255, 85], [255, 255, 170],
34
+ [255, 0, 255], [255, 85, 255], [255, 170, 255],
35
+ [0, 255, 255], [85, 255, 255], [170, 255, 255]]
36
+
37
+ im = np.array(im)
38
+ vis_im = im.copy().astype(np.uint8)
39
+ vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
40
+ vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
41
+ vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255
42
+
43
+ num_of_class = np.max(vis_parsing_anno)
44
+
45
+ for pi in range(1, num_of_class + 1):
46
+ index = np.where(vis_parsing_anno == pi)
47
+ vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi]
48
+
49
+ vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
50
+ # print(vis_parsing_anno_color.shape, vis_im.shape)
51
+ vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0)
52
+
53
+ # Save result or not
54
+ if save_im:
55
+ cv2.imwrite(save_path, vis_im, [int(cv2.IMWRITE_JPEG_QUALITY), 100])
56
+
57
+ # return vis_im
58
+
59
+ def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'):
60
+
61
+ if not os.path.exists(respth):
62
+ os.makedirs(respth)
63
+
64
+ n_classes = 19
65
+ net = BiSeNet(n_classes=n_classes)
66
+ net.cuda()
67
+ save_pth = osp.join('res/cp', cp)
68
+ net.load_state_dict(torch.load(save_pth))
69
+ net.eval()
70
+
71
+ to_tensor = transforms.Compose([
72
+ transforms.ToTensor(),
73
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
74
+ ])
75
+ with torch.no_grad():
76
+ for image_path in os.listdir(dspth):
77
+ img = Image.open(osp.join(dspth, image_path))
78
+ image = img.resize((512, 512), Image.BILINEAR)
79
+ img = to_tensor(image)
80
+ img = torch.unsqueeze(img, 0)
81
+ img = img.cuda()
82
+ out = net(img)[0]
83
+ parsing = out.squeeze(0).cpu().numpy().argmax(0)
84
+
85
+ vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path))
86
+
87
+
88
+
89
+
90
+
91
+
92
+
93
+ if __name__ == "__main__":
94
+ setup_logger('./res')
95
+ evaluate()
ConsistentID/lib/BiSeNet/face_dataset.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ import torchvision.transforms as transforms
7
+
8
+ import os.path as osp
9
+ import os
10
+ from PIL import Image
11
+ import numpy as np
12
+ import json
13
+ import cv2
14
+
15
+ from transform import *
16
+
17
+
18
+
19
+ class FaceMask(Dataset):
20
+ def __init__(self, rootpth, cropsize=(640, 480), mode='train', *args, **kwargs):
21
+ super(FaceMask, self).__init__(*args, **kwargs)
22
+ assert mode in ('train', 'val', 'test')
23
+ self.mode = mode
24
+ self.ignore_lb = 255
25
+ self.rootpth = rootpth
26
+
27
+ self.imgs = os.listdir(os.path.join(self.rootpth, 'CelebA-HQ-img'))
28
+
29
+ # pre-processing
30
+ self.to_tensor = transforms.Compose([
31
+ transforms.ToTensor(),
32
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
33
+ ])
34
+ self.trans_train = Compose([
35
+ ColorJitter(
36
+ brightness=0.5,
37
+ contrast=0.5,
38
+ saturation=0.5),
39
+ HorizontalFlip(),
40
+ RandomScale((0.75, 1.0, 1.25, 1.5, 1.75, 2.0)),
41
+ RandomCrop(cropsize)
42
+ ])
43
+
44
+ def __getitem__(self, idx):
45
+ impth = self.imgs[idx]
46
+ img = Image.open(osp.join(self.rootpth, 'CelebA-HQ-img', impth))
47
+ img = img.resize((512, 512), Image.BILINEAR)
48
+ label = Image.open(osp.join(self.rootpth, 'mask', impth[:-3]+'png')).convert('P')
49
+ # print(np.unique(np.array(label)))
50
+ if self.mode == 'train':
51
+ im_lb = dict(im=img, lb=label)
52
+ im_lb = self.trans_train(im_lb)
53
+ img, label = im_lb['im'], im_lb['lb']
54
+ img = self.to_tensor(img)
55
+ label = np.array(label).astype(np.int64)[np.newaxis, :]
56
+ return img, label
57
+
58
+ def __len__(self):
59
+ return len(self.imgs)
60
+
61
+
62
+ if __name__ == "__main__":
63
+ face_data = '/home/zll/data/CelebAMask-HQ/CelebA-HQ-img'
64
+ face_sep_mask = '/home/zll/data/CelebAMask-HQ/CelebAMask-HQ-mask-anno'
65
+ mask_path = '/home/zll/data/CelebAMask-HQ/mask'
66
+ counter = 0
67
+ total = 0
68
+ for i in range(15):
69
+ # files = os.listdir(osp.join(face_sep_mask, str(i)))
70
+
71
+ atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r',
72
+ 'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat']
73
+
74
+ for j in range(i*2000, (i+1)*2000):
75
+
76
+ mask = np.zeros((512, 512))
77
+
78
+ for l, att in enumerate(atts, 1):
79
+ total += 1
80
+ file_name = ''.join([str(j).rjust(5, '0'), '_', att, '.png'])
81
+ path = osp.join(face_sep_mask, str(i), file_name)
82
+
83
+ if os.path.exists(path):
84
+ counter += 1
85
+ sep_mask = np.array(Image.open(path).convert('P'))
86
+ # print(np.unique(sep_mask))
87
+
88
+ mask[sep_mask == 225] = l
89
+ cv2.imwrite('{}/{}.png'.format(mask_path, j), mask)
90
+ print(j)
91
+
92
+ print(counter, total)
93
+
94
+
95
+
96
+
97
+
98
+
99
+
100
+
101
+
102
+
103
+
104
+
105
+
106
+
ConsistentID/lib/BiSeNet/hair.png ADDED
ConsistentID/lib/BiSeNet/logger.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+
5
+ import os.path as osp
6
+ import time
7
+ import sys
8
+ import logging
9
+
10
+ import torch.distributed as dist
11
+
12
+
13
+ def setup_logger(logpth):
14
+ logfile = 'BiSeNet-{}.log'.format(time.strftime('%Y-%m-%d-%H-%M-%S'))
15
+ logfile = osp.join(logpth, logfile)
16
+ FORMAT = '%(levelname)s %(filename)s(%(lineno)d): %(message)s'
17
+ log_level = logging.INFO
18
+ if dist.is_initialized() and not dist.get_rank()==0:
19
+ log_level = logging.ERROR
20
+ logging.basicConfig(level=log_level, format=FORMAT, filename=logfile)
21
+ logging.root.addHandler(logging.StreamHandler())
22
+
23
+
ConsistentID/lib/BiSeNet/loss.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ class OhemCELoss(nn.Module):
10
+ def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs):
11
+ super(OhemCELoss, self).__init__()
12
+ self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).cuda()
13
+ self.n_min = n_min
14
+ self.ignore_lb = ignore_lb
15
+ self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')
16
+
17
+ def forward(self, logits, labels):
18
+ N, C, H, W = logits.size()
19
+ loss = self.criteria(logits, labels).view(-1)
20
+ loss, _ = torch.sort(loss, descending=True)
21
+ if loss[self.n_min] > self.thresh:
22
+ loss = loss[loss>self.thresh]
23
+ else:
24
+ loss = loss[:self.n_min]
25
+ return torch.mean(loss)
26
+
27
+
28
+ class SoftmaxFocalLoss(nn.Module):
29
+ def __init__(self, gamma, ignore_lb=255, *args, **kwargs):
30
+ super(SoftmaxFocalLoss, self).__init__()
31
+ self.gamma = gamma
32
+ self.nll = nn.NLLLoss(ignore_index=ignore_lb)
33
+
34
+ def forward(self, logits, labels):
35
+ scores = F.softmax(logits, dim=1)
36
+ factor = torch.pow(1.-scores, self.gamma)
37
+ log_score = F.log_softmax(logits, dim=1)
38
+ log_score = factor * log_score
39
+ loss = self.nll(log_score, labels)
40
+ return loss
41
+
42
+
43
+ if __name__ == '__main__':
44
+ torch.manual_seed(15)
45
+ criteria1 = OhemCELoss(thresh=0.7, n_min=16*20*20//16).cuda()
46
+ criteria2 = OhemCELoss(thresh=0.7, n_min=16*20*20//16).cuda()
47
+ net1 = nn.Sequential(
48
+ nn.Conv2d(3, 19, kernel_size=3, stride=2, padding=1),
49
+ )
50
+ net1.cuda()
51
+ net1.train()
52
+ net2 = nn.Sequential(
53
+ nn.Conv2d(3, 19, kernel_size=3, stride=2, padding=1),
54
+ )
55
+ net2.cuda()
56
+ net2.train()
57
+
58
+ with torch.no_grad():
59
+ inten = torch.randn(16, 3, 20, 20).cuda()
60
+ lbs = torch.randint(0, 19, [16, 20, 20]).cuda()
61
+ lbs[1, :, :] = 255
62
+
63
+ logits1 = net1(inten)
64
+ logits1 = F.interpolate(logits1, inten.size()[2:], mode='bilinear')
65
+ logits2 = net2(inten)
66
+ logits2 = F.interpolate(logits2, inten.size()[2:], mode='bilinear')
67
+
68
+ loss1 = criteria1(logits1, lbs)
69
+ loss2 = criteria2(logits2, lbs)
70
+ loss = loss1 + loss2
71
+ print(loss.detach().cpu())
72
+ loss.backward()
ConsistentID/lib/BiSeNet/makeup.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ from skimage.filters import gaussian
4
+
5
+
6
+ def sharpen(img):
7
+ img = img * 1.0
8
+ gauss_out = gaussian(img, sigma=5, multichannel=True)
9
+
10
+ alpha = 1.5
11
+ img_out = (img - gauss_out) * alpha + img
12
+
13
+ img_out = img_out / 255.0
14
+
15
+ mask_1 = img_out < 0
16
+ mask_2 = img_out > 1
17
+
18
+ img_out = img_out * (1 - mask_1)
19
+ img_out = img_out * (1 - mask_2) + mask_2
20
+ img_out = np.clip(img_out, 0, 1)
21
+ img_out = img_out * 255
22
+ return np.array(img_out, dtype=np.uint8)
23
+
24
+
25
+ def hair(image, parsing, part=17, color=[230, 50, 20]):
26
+ b, g, r = color #[10, 50, 250] # [10, 250, 10]
27
+ tar_color = np.zeros_like(image)
28
+ tar_color[:, :, 0] = b
29
+ tar_color[:, :, 1] = g
30
+ tar_color[:, :, 2] = r
31
+
32
+ image_hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
33
+ tar_hsv = cv2.cvtColor(tar_color, cv2.COLOR_BGR2HSV)
34
+
35
+ if part == 12 or part == 13:
36
+ image_hsv[:, :, 0:2] = tar_hsv[:, :, 0:2]
37
+ else:
38
+ image_hsv[:, :, 0:1] = tar_hsv[:, :, 0:1]
39
+
40
+ changed = cv2.cvtColor(image_hsv, cv2.COLOR_HSV2BGR)
41
+
42
+ if part == 17:
43
+ changed = sharpen(changed)
44
+
45
+ changed[parsing != part] = image[parsing != part]
46
+ # changed = cv2.resize(changed, (512, 512))
47
+ return changed
48
+
49
+ #
50
+ # def lip(image, parsing, part=17, color=[230, 50, 20]):
51
+ # b, g, r = color #[10, 50, 250] # [10, 250, 10]
52
+ # tar_color = np.zeros_like(image)
53
+ # tar_color[:, :, 0] = b
54
+ # tar_color[:, :, 1] = g
55
+ # tar_color[:, :, 2] = r
56
+ #
57
+ # image_lab = cv2.cvtColor(image, cv2.COLOR_BGR2Lab)
58
+ # il, ia, ib = cv2.split(image_lab)
59
+ #
60
+ # tar_lab = cv2.cvtColor(tar_color, cv2.COLOR_BGR2Lab)
61
+ # tl, ta, tb = cv2.split(tar_lab)
62
+ #
63
+ # image_lab[:, :, 0] = np.clip(il - np.mean(il) + tl, 0, 100)
64
+ # image_lab[:, :, 1] = np.clip(ia - np.mean(ia) + ta, -127, 128)
65
+ # image_lab[:, :, 2] = np.clip(ib - np.mean(ib) + tb, -127, 128)
66
+ #
67
+ #
68
+ # changed = cv2.cvtColor(image_lab, cv2.COLOR_Lab2BGR)
69
+ #
70
+ # if part == 17:
71
+ # changed = sharpen(changed)
72
+ #
73
+ # changed[parsing != part] = image[parsing != part]
74
+ # # changed = cv2.resize(changed, (512, 512))
75
+ # return changed
76
+
77
+
78
+ if __name__ == '__main__':
79
+ # 1 face
80
+ # 10 nose
81
+ # 11 teeth
82
+ # 12 upper lip
83
+ # 13 lower lip
84
+ # 17 hair
85
+ num = 116
86
+ table = {
87
+ 'hair': 17,
88
+ 'upper_lip': 12,
89
+ 'lower_lip': 13
90
+ }
91
+ image_path = '/home/zll/data/CelebAMask-HQ/test-img/{}.jpg'.format(num)
92
+ parsing_path = 'res/test_res/{}.png'.format(num)
93
+
94
+ image = cv2.imread(image_path)
95
+ ori = image.copy()
96
+ parsing = np.array(cv2.imread(parsing_path, 0))
97
+ parsing = cv2.resize(parsing, image.shape[0:2], interpolation=cv2.INTER_NEAREST)
98
+
99
+ parts = [table['hair'], table['upper_lip'], table['lower_lip']]
100
+ # colors = [[20, 20, 200], [100, 100, 230], [100, 100, 230]]
101
+ colors = [[100, 200, 100]]
102
+ for part, color in zip(parts, colors):
103
+ image = hair(image, parsing, part, color)
104
+ cv2.imwrite('res/makeup/116_ori.png', cv2.resize(ori, (512, 512)))
105
+ cv2.imwrite('res/makeup/116_2.png', cv2.resize(image, (512, 512)))
106
+
107
+ cv2.imshow('image', cv2.resize(ori, (512, 512)))
108
+ cv2.imshow('color', cv2.resize(image, (512, 512)))
109
+
110
+ # cv2.imshow('image', ori)
111
+ # cv2.imshow('color', image)
112
+
113
+ cv2.waitKey(0)
114
+ cv2.destroyAllWindows()
115
+
116
+
117
+
118
+
119
+
120
+
121
+
122
+
123
+
124
+
125
+
126
+
127
+
128
+
129
+
ConsistentID/lib/BiSeNet/makeup/116_1.png ADDED
ConsistentID/lib/BiSeNet/makeup/116_3.png ADDED
ConsistentID/lib/BiSeNet/makeup/116_lip_ori.png ADDED
ConsistentID/lib/BiSeNet/makeup/116_ori.png ADDED
ConsistentID/lib/BiSeNet/model.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from .resnet import Resnet18
10
+ # from modules.bn import InPlaceABNSync as BatchNorm2d
11
+
12
+
13
+ class ConvBNReLU(nn.Module):
14
+ def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
15
+ super(ConvBNReLU, self).__init__()
16
+ self.conv = nn.Conv2d(in_chan,
17
+ out_chan,
18
+ kernel_size = ks,
19
+ stride = stride,
20
+ padding = padding,
21
+ bias = False)
22
+ self.bn = nn.BatchNorm2d(out_chan)
23
+ self.init_weight()
24
+
25
+ def forward(self, x):
26
+ x = self.conv(x)
27
+ x = F.relu(self.bn(x))
28
+ return x
29
+
30
+ def init_weight(self):
31
+ for ly in self.children():
32
+ if isinstance(ly, nn.Conv2d):
33
+ nn.init.kaiming_normal_(ly.weight, a=1)
34
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
35
+
36
+ class BiSeNetOutput(nn.Module):
37
+ def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
38
+ super(BiSeNetOutput, self).__init__()
39
+ self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
40
+ self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
41
+ self.init_weight()
42
+
43
+ def forward(self, x):
44
+ x = self.conv(x)
45
+ x = self.conv_out(x)
46
+ return x
47
+
48
+ def init_weight(self):
49
+ for ly in self.children():
50
+ if isinstance(ly, nn.Conv2d):
51
+ nn.init.kaiming_normal_(ly.weight, a=1)
52
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
53
+
54
+ def get_params(self):
55
+ wd_params, nowd_params = [], []
56
+ for name, module in self.named_modules():
57
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
58
+ wd_params.append(module.weight)
59
+ if not module.bias is None:
60
+ nowd_params.append(module.bias)
61
+ elif isinstance(module, nn.BatchNorm2d):
62
+ nowd_params += list(module.parameters())
63
+ return wd_params, nowd_params
64
+
65
+
66
+ class AttentionRefinementModule(nn.Module):
67
+ def __init__(self, in_chan, out_chan, *args, **kwargs):
68
+ super(AttentionRefinementModule, self).__init__()
69
+ self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
70
+ self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
71
+ self.bn_atten = nn.BatchNorm2d(out_chan)
72
+ self.sigmoid_atten = nn.Sigmoid()
73
+ self.init_weight()
74
+
75
+ def forward(self, x):
76
+ feat = self.conv(x)
77
+ atten = F.avg_pool2d(feat, feat.size()[2:])
78
+ atten = self.conv_atten(atten)
79
+ atten = self.bn_atten(atten)
80
+ atten = self.sigmoid_atten(atten)
81
+ out = torch.mul(feat, atten)
82
+ return out
83
+
84
+ def init_weight(self):
85
+ for ly in self.children():
86
+ if isinstance(ly, nn.Conv2d):
87
+ nn.init.kaiming_normal_(ly.weight, a=1)
88
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
89
+
90
+
91
+ class ContextPath(nn.Module):
92
+ def __init__(self, *args, **kwargs):
93
+ super(ContextPath, self).__init__()
94
+ self.resnet = Resnet18()
95
+ self.arm16 = AttentionRefinementModule(256, 128)
96
+ self.arm32 = AttentionRefinementModule(512, 128)
97
+ self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
98
+ self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
99
+ self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
100
+
101
+ self.init_weight()
102
+
103
+ def forward(self, x):
104
+ H0, W0 = x.size()[2:]
105
+ feat8, feat16, feat32 = self.resnet(x)
106
+ H8, W8 = feat8.size()[2:]
107
+ H16, W16 = feat16.size()[2:]
108
+ H32, W32 = feat32.size()[2:]
109
+
110
+ avg = F.avg_pool2d(feat32, feat32.size()[2:])
111
+ avg = self.conv_avg(avg)
112
+ avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
113
+
114
+ feat32_arm = self.arm32(feat32)
115
+ feat32_sum = feat32_arm + avg_up
116
+ feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
117
+ feat32_up = self.conv_head32(feat32_up)
118
+
119
+ feat16_arm = self.arm16(feat16)
120
+ feat16_sum = feat16_arm + feat32_up
121
+ feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
122
+ feat16_up = self.conv_head16(feat16_up)
123
+
124
+ return feat8, feat16_up, feat32_up # x8, x8, x16
125
+
126
+ def init_weight(self):
127
+ for ly in self.children():
128
+ if isinstance(ly, nn.Conv2d):
129
+ nn.init.kaiming_normal_(ly.weight, a=1)
130
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
131
+
132
+ def get_params(self):
133
+ wd_params, nowd_params = [], []
134
+ for name, module in self.named_modules():
135
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
136
+ wd_params.append(module.weight)
137
+ if not module.bias is None:
138
+ nowd_params.append(module.bias)
139
+ elif isinstance(module, nn.BatchNorm2d):
140
+ nowd_params += list(module.parameters())
141
+ return wd_params, nowd_params
142
+
143
+
144
+ ### This is not used, since I replace this with the resnet feature with the same size
145
+ class SpatialPath(nn.Module):
146
+ def __init__(self, *args, **kwargs):
147
+ super(SpatialPath, self).__init__()
148
+ self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
149
+ self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
150
+ self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
151
+ self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
152
+ self.init_weight()
153
+
154
+ def forward(self, x):
155
+ feat = self.conv1(x)
156
+ feat = self.conv2(feat)
157
+ feat = self.conv3(feat)
158
+ feat = self.conv_out(feat)
159
+ return feat
160
+
161
+ def init_weight(self):
162
+ for ly in self.children():
163
+ if isinstance(ly, nn.Conv2d):
164
+ nn.init.kaiming_normal_(ly.weight, a=1)
165
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
166
+
167
+ def get_params(self):
168
+ wd_params, nowd_params = [], []
169
+ for name, module in self.named_modules():
170
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
171
+ wd_params.append(module.weight)
172
+ if not module.bias is None:
173
+ nowd_params.append(module.bias)
174
+ elif isinstance(module, nn.BatchNorm2d):
175
+ nowd_params += list(module.parameters())
176
+ return wd_params, nowd_params
177
+
178
+
179
+ class FeatureFusionModule(nn.Module):
180
+ def __init__(self, in_chan, out_chan, *args, **kwargs):
181
+ super(FeatureFusionModule, self).__init__()
182
+ self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
183
+ self.conv1 = nn.Conv2d(out_chan,
184
+ out_chan//4,
185
+ kernel_size = 1,
186
+ stride = 1,
187
+ padding = 0,
188
+ bias = False)
189
+ self.conv2 = nn.Conv2d(out_chan//4,
190
+ out_chan,
191
+ kernel_size = 1,
192
+ stride = 1,
193
+ padding = 0,
194
+ bias = False)
195
+ self.relu = nn.ReLU(inplace=True)
196
+ self.sigmoid = nn.Sigmoid()
197
+ self.init_weight()
198
+
199
+ def forward(self, fsp, fcp):
200
+ fcat = torch.cat([fsp, fcp], dim=1)
201
+ feat = self.convblk(fcat)
202
+ atten = F.avg_pool2d(feat, feat.size()[2:])
203
+ atten = self.conv1(atten)
204
+ atten = self.relu(atten)
205
+ atten = self.conv2(atten)
206
+ atten = self.sigmoid(atten)
207
+ feat_atten = torch.mul(feat, atten)
208
+ feat_out = feat_atten + feat
209
+ return feat_out
210
+
211
+ def init_weight(self):
212
+ for ly in self.children():
213
+ if isinstance(ly, nn.Conv2d):
214
+ nn.init.kaiming_normal_(ly.weight, a=1)
215
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
216
+
217
+ def get_params(self):
218
+ wd_params, nowd_params = [], []
219
+ for name, module in self.named_modules():
220
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
221
+ wd_params.append(module.weight)
222
+ if not module.bias is None:
223
+ nowd_params.append(module.bias)
224
+ elif isinstance(module, nn.BatchNorm2d):
225
+ nowd_params += list(module.parameters())
226
+ return wd_params, nowd_params
227
+
228
+
229
+ class BiSeNet(nn.Module):
230
+ def __init__(self, n_classes, *args, **kwargs):
231
+ super(BiSeNet, self).__init__()
232
+ self.cp = ContextPath()
233
+ ## here self.sp is deleted
234
+ self.ffm = FeatureFusionModule(256, 256)
235
+ self.conv_out = BiSeNetOutput(256, 256, n_classes)
236
+ self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
237
+ self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
238
+ self.init_weight()
239
+
240
+ def forward(self, x):
241
+ H, W = x.size()[2:]
242
+ feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
243
+ feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
244
+ feat_fuse = self.ffm(feat_sp, feat_cp8)
245
+
246
+ feat_out = self.conv_out(feat_fuse)
247
+ feat_out16 = self.conv_out16(feat_cp8)
248
+ feat_out32 = self.conv_out32(feat_cp16)
249
+
250
+ feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
251
+ feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
252
+ feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
253
+ return feat_out, feat_out16, feat_out32
254
+
255
+ def init_weight(self):
256
+ for ly in self.children():
257
+ if isinstance(ly, nn.Conv2d):
258
+ nn.init.kaiming_normal_(ly.weight, a=1)
259
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
260
+
261
+ def get_params(self):
262
+ wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
263
+ for name, child in self.named_children():
264
+ child_wd_params, child_nowd_params = child.get_params()
265
+ if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
266
+ lr_mul_wd_params += child_wd_params
267
+ lr_mul_nowd_params += child_nowd_params
268
+ else:
269
+ wd_params += child_wd_params
270
+ nowd_params += child_nowd_params
271
+ return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
272
+
273
+
274
+ if __name__ == "__main__":
275
+ net = BiSeNet(19)
276
+ net.cuda()
277
+ net.eval()
278
+ in_ten = torch.randn(16, 3, 640, 480).cuda()
279
+ out, out16, out32 = net(in_ten)
280
+ print(out.shape)
281
+
282
+ net.get_params()
ConsistentID/lib/BiSeNet/modules/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .bn import ABN, InPlaceABN, InPlaceABNSync
2
+ from .functions import ACT_RELU, ACT_LEAKY_RELU, ACT_ELU, ACT_NONE
3
+ from .misc import GlobalAvgPool2d, SingleGPU
4
+ from .residual import IdentityResidualBlock
5
+ from .dense import DenseModule
ConsistentID/lib/BiSeNet/modules/bn.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as functional
4
+
5
+ try:
6
+ from queue import Queue
7
+ except ImportError:
8
+ from Queue import Queue
9
+
10
+ from .functions import *
11
+
12
+
13
+ class ABN(nn.Module):
14
+ """Activated Batch Normalization
15
+
16
+ This gathers a `BatchNorm2d` and an activation function in a single module
17
+ """
18
+
19
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
20
+ """Creates an Activated Batch Normalization module
21
+
22
+ Parameters
23
+ ----------
24
+ num_features : int
25
+ Number of feature channels in the input and output.
26
+ eps : float
27
+ Small constant to prevent numerical issues.
28
+ momentum : float
29
+ Momentum factor applied to compute running statistics as.
30
+ affine : bool
31
+ If `True` apply learned scale and shift transformation after normalization.
32
+ activation : str
33
+ Name of the activation functions, one of: `leaky_relu`, `elu` or `none`.
34
+ slope : float
35
+ Negative slope for the `leaky_relu` activation.
36
+ """
37
+ super(ABN, self).__init__()
38
+ self.num_features = num_features
39
+ self.affine = affine
40
+ self.eps = eps
41
+ self.momentum = momentum
42
+ self.activation = activation
43
+ self.slope = slope
44
+ if self.affine:
45
+ self.weight = nn.Parameter(torch.ones(num_features))
46
+ self.bias = nn.Parameter(torch.zeros(num_features))
47
+ else:
48
+ self.register_parameter('weight', None)
49
+ self.register_parameter('bias', None)
50
+ self.register_buffer('running_mean', torch.zeros(num_features))
51
+ self.register_buffer('running_var', torch.ones(num_features))
52
+ self.reset_parameters()
53
+
54
+ def reset_parameters(self):
55
+ nn.init.constant_(self.running_mean, 0)
56
+ nn.init.constant_(self.running_var, 1)
57
+ if self.affine:
58
+ nn.init.constant_(self.weight, 1)
59
+ nn.init.constant_(self.bias, 0)
60
+
61
+ def forward(self, x):
62
+ x = functional.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias,
63
+ self.training, self.momentum, self.eps)
64
+
65
+ if self.activation == ACT_RELU:
66
+ return functional.relu(x, inplace=True)
67
+ elif self.activation == ACT_LEAKY_RELU:
68
+ return functional.leaky_relu(x, negative_slope=self.slope, inplace=True)
69
+ elif self.activation == ACT_ELU:
70
+ return functional.elu(x, inplace=True)
71
+ else:
72
+ return x
73
+
74
+ def __repr__(self):
75
+ rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
76
+ ' affine={affine}, activation={activation}'
77
+ if self.activation == "leaky_relu":
78
+ rep += ', slope={slope})'
79
+ else:
80
+ rep += ')'
81
+ return rep.format(name=self.__class__.__name__, **self.__dict__)
82
+
83
+
84
+ class InPlaceABN(ABN):
85
+ """InPlace Activated Batch Normalization"""
86
+
87
+ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
88
+ """Creates an InPlace Activated Batch Normalization module
89
+
90
+ Parameters
91
+ ----------
92
+ num_features : int
93
+ Number of feature channels in the input and output.
94
+ eps : float
95
+ Small constant to prevent numerical issues.
96
+ momentum : float
97
+ Momentum factor applied to compute running statistics as.
98
+ affine : bool
99
+ If `True` apply learned scale and shift transformation after normalization.
100
+ activation : str
101
+ Name of the activation functions, one of: `leaky_relu`, `elu` or `none`.
102
+ slope : float
103
+ Negative slope for the `leaky_relu` activation.
104
+ """
105
+ super(InPlaceABN, self).__init__(num_features, eps, momentum, affine, activation, slope)
106
+
107
+ def forward(self, x):
108
+ return inplace_abn(x, self.weight, self.bias, self.running_mean, self.running_var,
109
+ self.training, self.momentum, self.eps, self.activation, self.slope)
110
+
111
+
112
+ class InPlaceABNSync(ABN):
113
+ """InPlace Activated Batch Normalization with cross-GPU synchronization
114
+ This assumes that it will be replicated across GPUs using the same mechanism as in `nn.DistributedDataParallel`.
115
+ """
116
+
117
+ def forward(self, x):
118
+ return inplace_abn_sync(x, self.weight, self.bias, self.running_mean, self.running_var,
119
+ self.training, self.momentum, self.eps, self.activation, self.slope)
120
+
121
+ def __repr__(self):
122
+ rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
123
+ ' affine={affine}, activation={activation}'
124
+ if self.activation == "leaky_relu":
125
+ rep += ', slope={slope})'
126
+ else:
127
+ rep += ')'
128
+ return rep.format(name=self.__class__.__name__, **self.__dict__)
129
+
130
+
ConsistentID/lib/BiSeNet/modules/deeplab.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as functional
4
+
5
+ from models._util import try_index
6
+ from .bn import ABN
7
+
8
+
9
+ class DeeplabV3(nn.Module):
10
+ def __init__(self,
11
+ in_channels,
12
+ out_channels,
13
+ hidden_channels=256,
14
+ dilations=(12, 24, 36),
15
+ norm_act=ABN,
16
+ pooling_size=None):
17
+ super(DeeplabV3, self).__init__()
18
+ self.pooling_size = pooling_size
19
+
20
+ self.map_convs = nn.ModuleList([
21
+ nn.Conv2d(in_channels, hidden_channels, 1, bias=False),
22
+ nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[0], padding=dilations[0]),
23
+ nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[1], padding=dilations[1]),
24
+ nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[2], padding=dilations[2])
25
+ ])
26
+ self.map_bn = norm_act(hidden_channels * 4)
27
+
28
+ self.global_pooling_conv = nn.Conv2d(in_channels, hidden_channels, 1, bias=False)
29
+ self.global_pooling_bn = norm_act(hidden_channels)
30
+
31
+ self.red_conv = nn.Conv2d(hidden_channels * 4, out_channels, 1, bias=False)
32
+ self.pool_red_conv = nn.Conv2d(hidden_channels, out_channels, 1, bias=False)
33
+ self.red_bn = norm_act(out_channels)
34
+
35
+ self.reset_parameters(self.map_bn.activation, self.map_bn.slope)
36
+
37
+ def reset_parameters(self, activation, slope):
38
+ gain = nn.init.calculate_gain(activation, slope)
39
+ for m in self.modules():
40
+ if isinstance(m, nn.Conv2d):
41
+ nn.init.xavier_normal_(m.weight.data, gain)
42
+ if hasattr(m, "bias") and m.bias is not None:
43
+ nn.init.constant_(m.bias, 0)
44
+ elif isinstance(m, ABN):
45
+ if hasattr(m, "weight") and m.weight is not None:
46
+ nn.init.constant_(m.weight, 1)
47
+ if hasattr(m, "bias") and m.bias is not None:
48
+ nn.init.constant_(m.bias, 0)
49
+
50
+ def forward(self, x):
51
+ # Map convolutions
52
+ out = torch.cat([m(x) for m in self.map_convs], dim=1)
53
+ out = self.map_bn(out)
54
+ out = self.red_conv(out)
55
+
56
+ # Global pooling
57
+ pool = self._global_pooling(x)
58
+ pool = self.global_pooling_conv(pool)
59
+ pool = self.global_pooling_bn(pool)
60
+ pool = self.pool_red_conv(pool)
61
+ if self.training or self.pooling_size is None:
62
+ pool = pool.repeat(1, 1, x.size(2), x.size(3))
63
+
64
+ out += pool
65
+ out = self.red_bn(out)
66
+ return out
67
+
68
+ def _global_pooling(self, x):
69
+ if self.training or self.pooling_size is None:
70
+ pool = x.view(x.size(0), x.size(1), -1).mean(dim=-1)
71
+ pool = pool.view(x.size(0), x.size(1), 1, 1)
72
+ else:
73
+ pooling_size = (min(try_index(self.pooling_size, 0), x.shape[2]),
74
+ min(try_index(self.pooling_size, 1), x.shape[3]))
75
+ padding = (
76
+ (pooling_size[1] - 1) // 2,
77
+ (pooling_size[1] - 1) // 2 if pooling_size[1] % 2 == 1 else (pooling_size[1] - 1) // 2 + 1,
78
+ (pooling_size[0] - 1) // 2,
79
+ (pooling_size[0] - 1) // 2 if pooling_size[0] % 2 == 1 else (pooling_size[0] - 1) // 2 + 1
80
+ )
81
+
82
+ pool = functional.avg_pool2d(x, pooling_size, stride=1)
83
+ pool = functional.pad(pool, pad=padding, mode="replicate")
84
+ return pool
ConsistentID/lib/BiSeNet/modules/dense.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from .bn import ABN
7
+
8
+
9
+ class DenseModule(nn.Module):
10
+ def __init__(self, in_channels, growth, layers, bottleneck_factor=4, norm_act=ABN, dilation=1):
11
+ super(DenseModule, self).__init__()
12
+ self.in_channels = in_channels
13
+ self.growth = growth
14
+ self.layers = layers
15
+
16
+ self.convs1 = nn.ModuleList()
17
+ self.convs3 = nn.ModuleList()
18
+ for i in range(self.layers):
19
+ self.convs1.append(nn.Sequential(OrderedDict([
20
+ ("bn", norm_act(in_channels)),
21
+ ("conv", nn.Conv2d(in_channels, self.growth * bottleneck_factor, 1, bias=False))
22
+ ])))
23
+ self.convs3.append(nn.Sequential(OrderedDict([
24
+ ("bn", norm_act(self.growth * bottleneck_factor)),
25
+ ("conv", nn.Conv2d(self.growth * bottleneck_factor, self.growth, 3, padding=dilation, bias=False,
26
+ dilation=dilation))
27
+ ])))
28
+ in_channels += self.growth
29
+
30
+ @property
31
+ def out_channels(self):
32
+ return self.in_channels + self.growth * self.layers
33
+
34
+ def forward(self, x):
35
+ inputs = [x]
36
+ for i in range(self.layers):
37
+ x = torch.cat(inputs, dim=1)
38
+ x = self.convs1[i](x)
39
+ x = self.convs3[i](x)
40
+ inputs += [x]
41
+
42
+ return torch.cat(inputs, dim=1)
ConsistentID/lib/BiSeNet/modules/functions.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os import path
2
+ import torch
3
+ import torch.distributed as dist
4
+ import torch.autograd as autograd
5
+ import torch.cuda.comm as comm
6
+ from torch.autograd.function import once_differentiable
7
+ from torch.utils.cpp_extension import load
8
+
9
+ _src_path = path.join(path.dirname(path.abspath(__file__)), "src")
10
+ _backend = load(name="inplace_abn",
11
+ extra_cflags=["-O3"],
12
+ sources=[path.join(_src_path, f) for f in [
13
+ "inplace_abn.cpp",
14
+ "inplace_abn_cpu.cpp",
15
+ "inplace_abn_cuda.cu",
16
+ "inplace_abn_cuda_half.cu"
17
+ ]],
18
+ extra_cuda_cflags=["--expt-extended-lambda"])
19
+
20
+ # Activation names
21
+ ACT_RELU = "relu"
22
+ ACT_LEAKY_RELU = "leaky_relu"
23
+ ACT_ELU = "elu"
24
+ ACT_NONE = "none"
25
+
26
+
27
+ def _check(fn, *args, **kwargs):
28
+ success = fn(*args, **kwargs)
29
+ if not success:
30
+ raise RuntimeError("CUDA Error encountered in {}".format(fn))
31
+
32
+
33
+ def _broadcast_shape(x):
34
+ out_size = []
35
+ for i, s in enumerate(x.size()):
36
+ if i != 1:
37
+ out_size.append(1)
38
+ else:
39
+ out_size.append(s)
40
+ return out_size
41
+
42
+
43
+ def _reduce(x):
44
+ if len(x.size()) == 2:
45
+ return x.sum(dim=0)
46
+ else:
47
+ n, c = x.size()[0:2]
48
+ return x.contiguous().view((n, c, -1)).sum(2).sum(0)
49
+
50
+
51
+ def _count_samples(x):
52
+ count = 1
53
+ for i, s in enumerate(x.size()):
54
+ if i != 1:
55
+ count *= s
56
+ return count
57
+
58
+
59
+ def _act_forward(ctx, x):
60
+ if ctx.activation == ACT_LEAKY_RELU:
61
+ _backend.leaky_relu_forward(x, ctx.slope)
62
+ elif ctx.activation == ACT_ELU:
63
+ _backend.elu_forward(x)
64
+ elif ctx.activation == ACT_NONE:
65
+ pass
66
+
67
+
68
+ def _act_backward(ctx, x, dx):
69
+ if ctx.activation == ACT_LEAKY_RELU:
70
+ _backend.leaky_relu_backward(x, dx, ctx.slope)
71
+ elif ctx.activation == ACT_ELU:
72
+ _backend.elu_backward(x, dx)
73
+ elif ctx.activation == ACT_NONE:
74
+ pass
75
+
76
+
77
+ class InPlaceABN(autograd.Function):
78
+ @staticmethod
79
+ def forward(ctx, x, weight, bias, running_mean, running_var,
80
+ training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01):
81
+ # Save context
82
+ ctx.training = training
83
+ ctx.momentum = momentum
84
+ ctx.eps = eps
85
+ ctx.activation = activation
86
+ ctx.slope = slope
87
+ ctx.affine = weight is not None and bias is not None
88
+
89
+ # Prepare inputs
90
+ count = _count_samples(x)
91
+ x = x.contiguous()
92
+ weight = weight.contiguous() if ctx.affine else x.new_empty(0)
93
+ bias = bias.contiguous() if ctx.affine else x.new_empty(0)
94
+
95
+ if ctx.training:
96
+ mean, var = _backend.mean_var(x)
97
+
98
+ # Update running stats
99
+ running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
100
+ running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1))
101
+
102
+ # Mark in-place modified tensors
103
+ ctx.mark_dirty(x, running_mean, running_var)
104
+ else:
105
+ mean, var = running_mean.contiguous(), running_var.contiguous()
106
+ ctx.mark_dirty(x)
107
+
108
+ # BN forward + activation
109
+ _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
110
+ _act_forward(ctx, x)
111
+
112
+ # Output
113
+ ctx.var = var
114
+ ctx.save_for_backward(x, var, weight, bias)
115
+ return x
116
+
117
+ @staticmethod
118
+ @once_differentiable
119
+ def backward(ctx, dz):
120
+ z, var, weight, bias = ctx.saved_tensors
121
+ dz = dz.contiguous()
122
+
123
+ # Undo activation
124
+ _act_backward(ctx, z, dz)
125
+
126
+ if ctx.training:
127
+ edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)
128
+ else:
129
+ # TODO: implement simplified CUDA backward for inference mode
130
+ edz = dz.new_zeros(dz.size(1))
131
+ eydz = dz.new_zeros(dz.size(1))
132
+
133
+ dx = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)
134
+ dweight = eydz * weight.sign() if ctx.affine else None
135
+ dbias = edz if ctx.affine else None
136
+
137
+ return dx, dweight, dbias, None, None, None, None, None, None, None
138
+
139
+ class InPlaceABNSync(autograd.Function):
140
+ @classmethod
141
+ def forward(cls, ctx, x, weight, bias, running_mean, running_var,
142
+ training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01, equal_batches=True):
143
+ # Save context
144
+ ctx.training = training
145
+ ctx.momentum = momentum
146
+ ctx.eps = eps
147
+ ctx.activation = activation
148
+ ctx.slope = slope
149
+ ctx.affine = weight is not None and bias is not None
150
+
151
+ # Prepare inputs
152
+ ctx.world_size = dist.get_world_size() if dist.is_initialized() else 1
153
+
154
+ #count = _count_samples(x)
155
+ batch_size = x.new_tensor([x.shape[0]],dtype=torch.long)
156
+
157
+ x = x.contiguous()
158
+ weight = weight.contiguous() if ctx.affine else x.new_empty(0)
159
+ bias = bias.contiguous() if ctx.affine else x.new_empty(0)
160
+
161
+ if ctx.training:
162
+ mean, var = _backend.mean_var(x)
163
+ if ctx.world_size>1:
164
+ # get global batch size
165
+ if equal_batches:
166
+ batch_size *= ctx.world_size
167
+ else:
168
+ dist.all_reduce(batch_size, dist.ReduceOp.SUM)
169
+
170
+ ctx.factor = x.shape[0]/float(batch_size.item())
171
+
172
+ mean_all = mean.clone() * ctx.factor
173
+ dist.all_reduce(mean_all, dist.ReduceOp.SUM)
174
+
175
+ var_all = (var + (mean - mean_all) ** 2) * ctx.factor
176
+ dist.all_reduce(var_all, dist.ReduceOp.SUM)
177
+
178
+ mean = mean_all
179
+ var = var_all
180
+
181
+ # Update running stats
182
+ running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
183
+ count = batch_size.item() * x.view(x.shape[0],x.shape[1],-1).shape[-1]
184
+ running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * (float(count) / (count - 1)))
185
+
186
+ # Mark in-place modified tensors
187
+ ctx.mark_dirty(x, running_mean, running_var)
188
+ else:
189
+ mean, var = running_mean.contiguous(), running_var.contiguous()
190
+ ctx.mark_dirty(x)
191
+
192
+ # BN forward + activation
193
+ _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
194
+ _act_forward(ctx, x)
195
+
196
+ # Output
197
+ ctx.var = var
198
+ ctx.save_for_backward(x, var, weight, bias)
199
+ return x
200
+
201
+ @staticmethod
202
+ @once_differentiable
203
+ def backward(ctx, dz):
204
+ z, var, weight, bias = ctx.saved_tensors
205
+ dz = dz.contiguous()
206
+
207
+ # Undo activation
208
+ _act_backward(ctx, z, dz)
209
+
210
+ if ctx.training:
211
+ edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)
212
+ edz_local = edz.clone()
213
+ eydz_local = eydz.clone()
214
+
215
+ if ctx.world_size>1:
216
+ edz *= ctx.factor
217
+ dist.all_reduce(edz, dist.ReduceOp.SUM)
218
+
219
+ eydz *= ctx.factor
220
+ dist.all_reduce(eydz, dist.ReduceOp.SUM)
221
+ else:
222
+ edz_local = edz = dz.new_zeros(dz.size(1))
223
+ eydz_local = eydz = dz.new_zeros(dz.size(1))
224
+
225
+ dx = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)
226
+ dweight = eydz_local * weight.sign() if ctx.affine else None
227
+ dbias = edz_local if ctx.affine else None
228
+
229
+ return dx, dweight, dbias, None, None, None, None, None, None, None
230
+
231
+ inplace_abn = InPlaceABN.apply
232
+ inplace_abn_sync = InPlaceABNSync.apply
233
+
234
+ __all__ = ["inplace_abn", "inplace_abn_sync", "ACT_RELU", "ACT_LEAKY_RELU", "ACT_ELU", "ACT_NONE"]
ConsistentID/lib/BiSeNet/modules/misc.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+ import torch.distributed as dist
4
+
5
+ class GlobalAvgPool2d(nn.Module):
6
+ def __init__(self):
7
+ """Global average pooling over the input's spatial dimensions"""
8
+ super(GlobalAvgPool2d, self).__init__()
9
+
10
+ def forward(self, inputs):
11
+ in_size = inputs.size()
12
+ return inputs.view((in_size[0], in_size[1], -1)).mean(dim=2)
13
+
14
+ class SingleGPU(nn.Module):
15
+ def __init__(self, module):
16
+ super(SingleGPU, self).__init__()
17
+ self.module=module
18
+
19
+ def forward(self, input):
20
+ return self.module(input.cuda(non_blocking=True))
21
+
ConsistentID/lib/BiSeNet/modules/residual.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+
3
+ import torch.nn as nn
4
+
5
+ from .bn import ABN
6
+
7
+
8
+ class IdentityResidualBlock(nn.Module):
9
+ def __init__(self,
10
+ in_channels,
11
+ channels,
12
+ stride=1,
13
+ dilation=1,
14
+ groups=1,
15
+ norm_act=ABN,
16
+ dropout=None):
17
+ """Configurable identity-mapping residual block
18
+
19
+ Parameters
20
+ ----------
21
+ in_channels : int
22
+ Number of input channels.
23
+ channels : list of int
24
+ Number of channels in the internal feature maps. Can either have two or three elements: if three construct
25
+ a residual block with two `3 x 3` convolutions, otherwise construct a bottleneck block with `1 x 1`, then
26
+ `3 x 3` then `1 x 1` convolutions.
27
+ stride : int
28
+ Stride of the first `3 x 3` convolution
29
+ dilation : int
30
+ Dilation to apply to the `3 x 3` convolutions.
31
+ groups : int
32
+ Number of convolution groups. This is used to create ResNeXt-style blocks and is only compatible with
33
+ bottleneck blocks.
34
+ norm_act : callable
35
+ Function to create normalization / activation Module.
36
+ dropout: callable
37
+ Function to create Dropout Module.
38
+ """
39
+ super(IdentityResidualBlock, self).__init__()
40
+
41
+ # Check parameters for inconsistencies
42
+ if len(channels) != 2 and len(channels) != 3:
43
+ raise ValueError("channels must contain either two or three values")
44
+ if len(channels) == 2 and groups != 1:
45
+ raise ValueError("groups > 1 are only valid if len(channels) == 3")
46
+
47
+ is_bottleneck = len(channels) == 3
48
+ need_proj_conv = stride != 1 or in_channels != channels[-1]
49
+
50
+ self.bn1 = norm_act(in_channels)
51
+ if not is_bottleneck:
52
+ layers = [
53
+ ("conv1", nn.Conv2d(in_channels, channels[0], 3, stride=stride, padding=dilation, bias=False,
54
+ dilation=dilation)),
55
+ ("bn2", norm_act(channels[0])),
56
+ ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
57
+ dilation=dilation))
58
+ ]
59
+ if dropout is not None:
60
+ layers = layers[0:2] + [("dropout", dropout())] + layers[2:]
61
+ else:
62
+ layers = [
63
+ ("conv1", nn.Conv2d(in_channels, channels[0], 1, stride=stride, padding=0, bias=False)),
64
+ ("bn2", norm_act(channels[0])),
65
+ ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
66
+ groups=groups, dilation=dilation)),
67
+ ("bn3", norm_act(channels[1])),
68
+ ("conv3", nn.Conv2d(channels[1], channels[2], 1, stride=1, padding=0, bias=False))
69
+ ]
70
+ if dropout is not None:
71
+ layers = layers[0:4] + [("dropout", dropout())] + layers[4:]
72
+ self.convs = nn.Sequential(OrderedDict(layers))
73
+
74
+ if need_proj_conv:
75
+ self.proj_conv = nn.Conv2d(in_channels, channels[-1], 1, stride=stride, padding=0, bias=False)
76
+
77
+ def forward(self, x):
78
+ if hasattr(self, "proj_conv"):
79
+ bn1 = self.bn1(x)
80
+ shortcut = self.proj_conv(bn1)
81
+ else:
82
+ shortcut = x.clone()
83
+ bn1 = self.bn1(x)
84
+
85
+ out = self.convs(bn1)
86
+ out.add_(shortcut)
87
+
88
+ return out
ConsistentID/lib/BiSeNet/modules/src/checks.h ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/ATen.h>
4
+
5
+ // Define AT_CHECK for old version of ATen where the same function was called AT_ASSERT
6
+ #ifndef AT_CHECK
7
+ #define AT_CHECK AT_ASSERT
8
+ #endif
9
+
10
+ #define CHECK_CUDA(x) AT_CHECK((x).type().is_cuda(), #x " must be a CUDA tensor")
11
+ #define CHECK_CPU(x) AT_CHECK(!(x).type().is_cuda(), #x " must be a CPU tensor")
12
+ #define CHECK_CONTIGUOUS(x) AT_CHECK((x).is_contiguous(), #x " must be contiguous")
13
+
14
+ #define CHECK_CUDA_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
15
+ #define CHECK_CPU_INPUT(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x)
ConsistentID/lib/BiSeNet/modules/src/inplace_abn.cpp ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+
3
+ #include <vector>
4
+
5
+ #include "inplace_abn.h"
6
+
7
+ std::vector<at::Tensor> mean_var(at::Tensor x) {
8
+ if (x.is_cuda()) {
9
+ if (x.type().scalarType() == at::ScalarType::Half) {
10
+ return mean_var_cuda_h(x);
11
+ } else {
12
+ return mean_var_cuda(x);
13
+ }
14
+ } else {
15
+ return mean_var_cpu(x);
16
+ }
17
+ }
18
+
19
+ at::Tensor forward(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
20
+ bool affine, float eps) {
21
+ if (x.is_cuda()) {
22
+ if (x.type().scalarType() == at::ScalarType::Half) {
23
+ return forward_cuda_h(x, mean, var, weight, bias, affine, eps);
24
+ } else {
25
+ return forward_cuda(x, mean, var, weight, bias, affine, eps);
26
+ }
27
+ } else {
28
+ return forward_cpu(x, mean, var, weight, bias, affine, eps);
29
+ }
30
+ }
31
+
32
+ std::vector<at::Tensor> edz_eydz(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
33
+ bool affine, float eps) {
34
+ if (z.is_cuda()) {
35
+ if (z.type().scalarType() == at::ScalarType::Half) {
36
+ return edz_eydz_cuda_h(z, dz, weight, bias, affine, eps);
37
+ } else {
38
+ return edz_eydz_cuda(z, dz, weight, bias, affine, eps);
39
+ }
40
+ } else {
41
+ return edz_eydz_cpu(z, dz, weight, bias, affine, eps);
42
+ }
43
+ }
44
+
45
+ at::Tensor backward(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
46
+ at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
47
+ if (z.is_cuda()) {
48
+ if (z.type().scalarType() == at::ScalarType::Half) {
49
+ return backward_cuda_h(z, dz, var, weight, bias, edz, eydz, affine, eps);
50
+ } else {
51
+ return backward_cuda(z, dz, var, weight, bias, edz, eydz, affine, eps);
52
+ }
53
+ } else {
54
+ return backward_cpu(z, dz, var, weight, bias, edz, eydz, affine, eps);
55
+ }
56
+ }
57
+
58
+ void leaky_relu_forward(at::Tensor z, float slope) {
59
+ at::leaky_relu_(z, slope);
60
+ }
61
+
62
+ void leaky_relu_backward(at::Tensor z, at::Tensor dz, float slope) {
63
+ if (z.is_cuda()) {
64
+ if (z.type().scalarType() == at::ScalarType::Half) {
65
+ return leaky_relu_backward_cuda_h(z, dz, slope);
66
+ } else {
67
+ return leaky_relu_backward_cuda(z, dz, slope);
68
+ }
69
+ } else {
70
+ return leaky_relu_backward_cpu(z, dz, slope);
71
+ }
72
+ }
73
+
74
+ void elu_forward(at::Tensor z) {
75
+ at::elu_(z);
76
+ }
77
+
78
+ void elu_backward(at::Tensor z, at::Tensor dz) {
79
+ if (z.is_cuda()) {
80
+ return elu_backward_cuda(z, dz);
81
+ } else {
82
+ return elu_backward_cpu(z, dz);
83
+ }
84
+ }
85
+
86
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
87
+ m.def("mean_var", &mean_var, "Mean and variance computation");
88
+ m.def("forward", &forward, "In-place forward computation");
89
+ m.def("edz_eydz", &edz_eydz, "First part of backward computation");
90
+ m.def("backward", &backward, "Second part of backward computation");
91
+ m.def("leaky_relu_forward", &leaky_relu_forward, "Leaky relu forward computation");
92
+ m.def("leaky_relu_backward", &leaky_relu_backward, "Leaky relu backward computation and inversion");
93
+ m.def("elu_forward", &elu_forward, "Elu forward computation");
94
+ m.def("elu_backward", &elu_backward, "Elu backward computation and inversion");
95
+ }
ConsistentID/lib/BiSeNet/modules/src/inplace_abn.h ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/ATen.h>
4
+
5
+ #include <vector>
6
+
7
+ std::vector<at::Tensor> mean_var_cpu(at::Tensor x);
8
+ std::vector<at::Tensor> mean_var_cuda(at::Tensor x);
9
+ std::vector<at::Tensor> mean_var_cuda_h(at::Tensor x);
10
+
11
+ at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
12
+ bool affine, float eps);
13
+ at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
14
+ bool affine, float eps);
15
+ at::Tensor forward_cuda_h(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
16
+ bool affine, float eps);
17
+
18
+ std::vector<at::Tensor> edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
19
+ bool affine, float eps);
20
+ std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
21
+ bool affine, float eps);
22
+ std::vector<at::Tensor> edz_eydz_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
23
+ bool affine, float eps);
24
+
25
+ at::Tensor backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
26
+ at::Tensor edz, at::Tensor eydz, bool affine, float eps);
27
+ at::Tensor backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
28
+ at::Tensor edz, at::Tensor eydz, bool affine, float eps);
29
+ at::Tensor backward_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
30
+ at::Tensor edz, at::Tensor eydz, bool affine, float eps);
31
+
32
+ void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope);
33
+ void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope);
34
+ void leaky_relu_backward_cuda_h(at::Tensor z, at::Tensor dz, float slope);
35
+
36
+ void elu_backward_cpu(at::Tensor z, at::Tensor dz);
37
+ void elu_backward_cuda(at::Tensor z, at::Tensor dz);
38
+
39
+ static void get_dims(at::Tensor x, int64_t& num, int64_t& chn, int64_t& sp) {
40
+ num = x.size(0);
41
+ chn = x.size(1);
42
+ sp = 1;
43
+ for (int64_t i = 2; i < x.ndimension(); ++i)
44
+ sp *= x.size(i);
45
+ }
46
+
47
+ /*
48
+ * Specialized CUDA reduction functions for BN
49
+ */
50
+ #ifdef __CUDACC__
51
+
52
+ #include "utils/cuda.cuh"
53
+
54
+ template <typename T, typename Op>
55
+ __device__ T reduce(Op op, int plane, int N, int S) {
56
+ T sum = (T)0;
57
+ for (int batch = 0; batch < N; ++batch) {
58
+ for (int x = threadIdx.x; x < S; x += blockDim.x) {
59
+ sum += op(batch, plane, x);
60
+ }
61
+ }
62
+
63
+ // sum over NumThreads within a warp
64
+ sum = warpSum(sum);
65
+
66
+ // 'transpose', and reduce within warp again
67
+ __shared__ T shared[32];
68
+ __syncthreads();
69
+ if (threadIdx.x % WARP_SIZE == 0) {
70
+ shared[threadIdx.x / WARP_SIZE] = sum;
71
+ }
72
+ if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
73
+ // zero out the other entries in shared
74
+ shared[threadIdx.x] = (T)0;
75
+ }
76
+ __syncthreads();
77
+ if (threadIdx.x / WARP_SIZE == 0) {
78
+ sum = warpSum(shared[threadIdx.x]);
79
+ if (threadIdx.x == 0) {
80
+ shared[0] = sum;
81
+ }
82
+ }
83
+ __syncthreads();
84
+
85
+ // Everyone picks it up, should be broadcast into the whole gradInput
86
+ return shared[0];
87
+ }
88
+ #endif
ConsistentID/lib/BiSeNet/modules/src/inplace_abn_cpu.cpp ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/ATen.h>
2
+
3
+ #include <vector>
4
+
5
+ #include "utils/checks.h"
6
+ #include "inplace_abn.h"
7
+
8
+ at::Tensor reduce_sum(at::Tensor x) {
9
+ if (x.ndimension() == 2) {
10
+ return x.sum(0);
11
+ } else {
12
+ auto x_view = x.view({x.size(0), x.size(1), -1});
13
+ return x_view.sum(-1).sum(0);
14
+ }
15
+ }
16
+
17
+ at::Tensor broadcast_to(at::Tensor v, at::Tensor x) {
18
+ if (x.ndimension() == 2) {
19
+ return v;
20
+ } else {
21
+ std::vector<int64_t> broadcast_size = {1, -1};
22
+ for (int64_t i = 2; i < x.ndimension(); ++i)
23
+ broadcast_size.push_back(1);
24
+
25
+ return v.view(broadcast_size);
26
+ }
27
+ }
28
+
29
+ int64_t count(at::Tensor x) {
30
+ int64_t count = x.size(0);
31
+ for (int64_t i = 2; i < x.ndimension(); ++i)
32
+ count *= x.size(i);
33
+
34
+ return count;
35
+ }
36
+
37
+ at::Tensor invert_affine(at::Tensor z, at::Tensor weight, at::Tensor bias, bool affine, float eps) {
38
+ if (affine) {
39
+ return (z - broadcast_to(bias, z)) / broadcast_to(at::abs(weight) + eps, z);
40
+ } else {
41
+ return z;
42
+ }
43
+ }
44
+
45
+ std::vector<at::Tensor> mean_var_cpu(at::Tensor x) {
46
+ auto num = count(x);
47
+ auto mean = reduce_sum(x) / num;
48
+ auto diff = x - broadcast_to(mean, x);
49
+ auto var = reduce_sum(diff.pow(2)) / num;
50
+
51
+ return {mean, var};
52
+ }
53
+
54
+ at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
55
+ bool affine, float eps) {
56
+ auto gamma = affine ? at::abs(weight) + eps : at::ones_like(var);
57
+ auto mul = at::rsqrt(var + eps) * gamma;
58
+
59
+ x.sub_(broadcast_to(mean, x));
60
+ x.mul_(broadcast_to(mul, x));
61
+ if (affine) x.add_(broadcast_to(bias, x));
62
+
63
+ return x;
64
+ }
65
+
66
+ std::vector<at::Tensor> edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
67
+ bool affine, float eps) {
68
+ auto edz = reduce_sum(dz);
69
+ auto y = invert_affine(z, weight, bias, affine, eps);
70
+ auto eydz = reduce_sum(y * dz);
71
+
72
+ return {edz, eydz};
73
+ }
74
+
75
+ at::Tensor backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
76
+ at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
77
+ auto y = invert_affine(z, weight, bias, affine, eps);
78
+ auto mul = affine ? at::rsqrt(var + eps) * (at::abs(weight) + eps) : at::rsqrt(var + eps);
79
+
80
+ auto num = count(z);
81
+ auto dx = (dz - broadcast_to(edz / num, dz) - y * broadcast_to(eydz / num, dz)) * broadcast_to(mul, dz);
82
+ return dx;
83
+ }
84
+
85
+ void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope) {
86
+ CHECK_CPU_INPUT(z);
87
+ CHECK_CPU_INPUT(dz);
88
+
89
+ AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cpu", ([&] {
90
+ int64_t count = z.numel();
91
+ auto *_z = z.data<scalar_t>();
92
+ auto *_dz = dz.data<scalar_t>();
93
+
94
+ for (int64_t i = 0; i < count; ++i) {
95
+ if (_z[i] < 0) {
96
+ _z[i] *= 1 / slope;
97
+ _dz[i] *= slope;
98
+ }
99
+ }
100
+ }));
101
+ }
102
+
103
+ void elu_backward_cpu(at::Tensor z, at::Tensor dz) {
104
+ CHECK_CPU_INPUT(z);
105
+ CHECK_CPU_INPUT(dz);
106
+
107
+ AT_DISPATCH_FLOATING_TYPES(z.type(), "elu_backward_cpu", ([&] {
108
+ int64_t count = z.numel();
109
+ auto *_z = z.data<scalar_t>();
110
+ auto *_dz = dz.data<scalar_t>();
111
+
112
+ for (int64_t i = 0; i < count; ++i) {
113
+ if (_z[i] < 0) {
114
+ _z[i] = log1p(_z[i]);
115
+ _dz[i] *= (_z[i] + 1.f);
116
+ }
117
+ }
118
+ }));
119
+ }
ConsistentID/lib/BiSeNet/modules/src/inplace_abn_cuda.cu ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/ATen.h>
2
+
3
+ #include <thrust/device_ptr.h>
4
+ #include <thrust/transform.h>
5
+
6
+ #include <vector>
7
+
8
+ #include "utils/checks.h"
9
+ #include "utils/cuda.cuh"
10
+ #include "inplace_abn.h"
11
+
12
+ #include <ATen/cuda/CUDAContext.h>
13
+
14
+ // Operations for reduce
15
+ template<typename T>
16
+ struct SumOp {
17
+ __device__ SumOp(const T *t, int c, int s)
18
+ : tensor(t), chn(c), sp(s) {}
19
+ __device__ __forceinline__ T operator()(int batch, int plane, int n) {
20
+ return tensor[(batch * chn + plane) * sp + n];
21
+ }
22
+ const T *tensor;
23
+ const int chn;
24
+ const int sp;
25
+ };
26
+
27
+ template<typename T>
28
+ struct VarOp {
29
+ __device__ VarOp(T m, const T *t, int c, int s)
30
+ : mean(m), tensor(t), chn(c), sp(s) {}
31
+ __device__ __forceinline__ T operator()(int batch, int plane, int n) {
32
+ T val = tensor[(batch * chn + plane) * sp + n];
33
+ return (val - mean) * (val - mean);
34
+ }
35
+ const T mean;
36
+ const T *tensor;
37
+ const int chn;
38
+ const int sp;
39
+ };
40
+
41
+ template<typename T>
42
+ struct GradOp {
43
+ __device__ GradOp(T _weight, T _bias, const T *_z, const T *_dz, int c, int s)
44
+ : weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {}
45
+ __device__ __forceinline__ Pair<T> operator()(int batch, int plane, int n) {
46
+ T _y = (z[(batch * chn + plane) * sp + n] - bias) / weight;
47
+ T _dz = dz[(batch * chn + plane) * sp + n];
48
+ return Pair<T>(_dz, _y * _dz);
49
+ }
50
+ const T weight;
51
+ const T bias;
52
+ const T *z;
53
+ const T *dz;
54
+ const int chn;
55
+ const int sp;
56
+ };
57
+
58
+ /***********
59
+ * mean_var
60
+ ***********/
61
+
62
+ template<typename T>
63
+ __global__ void mean_var_kernel(const T *x, T *mean, T *var, int num, int chn, int sp) {
64
+ int plane = blockIdx.x;
65
+ T norm = T(1) / T(num * sp);
66
+
67
+ T _mean = reduce<T, SumOp<T>>(SumOp<T>(x, chn, sp), plane, num, sp) * norm;
68
+ __syncthreads();
69
+ T _var = reduce<T, VarOp<T>>(VarOp<T>(_mean, x, chn, sp), plane, num, sp) * norm;
70
+
71
+ if (threadIdx.x == 0) {
72
+ mean[plane] = _mean;
73
+ var[plane] = _var;
74
+ }
75
+ }
76
+
77
+ std::vector<at::Tensor> mean_var_cuda(at::Tensor x) {
78
+ CHECK_CUDA_INPUT(x);
79
+
80
+ // Extract dimensions
81
+ int64_t num, chn, sp;
82
+ get_dims(x, num, chn, sp);
83
+
84
+ // Prepare output tensors
85
+ auto mean = at::empty({chn}, x.options());
86
+ auto var = at::empty({chn}, x.options());
87
+
88
+ // Run kernel
89
+ dim3 blocks(chn);
90
+ dim3 threads(getNumThreads(sp));
91
+ auto stream = at::cuda::getCurrentCUDAStream();
92
+ AT_DISPATCH_FLOATING_TYPES(x.type(), "mean_var_cuda", ([&] {
93
+ mean_var_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
94
+ x.data<scalar_t>(),
95
+ mean.data<scalar_t>(),
96
+ var.data<scalar_t>(),
97
+ num, chn, sp);
98
+ }));
99
+
100
+ return {mean, var};
101
+ }
102
+
103
+ /**********
104
+ * forward
105
+ **********/
106
+
107
+ template<typename T>
108
+ __global__ void forward_kernel(T *x, const T *mean, const T *var, const T *weight, const T *bias,
109
+ bool affine, float eps, int num, int chn, int sp) {
110
+ int plane = blockIdx.x;
111
+
112
+ T _mean = mean[plane];
113
+ T _var = var[plane];
114
+ T _weight = affine ? abs(weight[plane]) + eps : T(1);
115
+ T _bias = affine ? bias[plane] : T(0);
116
+
117
+ T mul = rsqrt(_var + eps) * _weight;
118
+
119
+ for (int batch = 0; batch < num; ++batch) {
120
+ for (int n = threadIdx.x; n < sp; n += blockDim.x) {
121
+ T _x = x[(batch * chn + plane) * sp + n];
122
+ T _y = (_x - _mean) * mul + _bias;
123
+
124
+ x[(batch * chn + plane) * sp + n] = _y;
125
+ }
126
+ }
127
+ }
128
+
129
+ at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
130
+ bool affine, float eps) {
131
+ CHECK_CUDA_INPUT(x);
132
+ CHECK_CUDA_INPUT(mean);
133
+ CHECK_CUDA_INPUT(var);
134
+ CHECK_CUDA_INPUT(weight);
135
+ CHECK_CUDA_INPUT(bias);
136
+
137
+ // Extract dimensions
138
+ int64_t num, chn, sp;
139
+ get_dims(x, num, chn, sp);
140
+
141
+ // Run kernel
142
+ dim3 blocks(chn);
143
+ dim3 threads(getNumThreads(sp));
144
+ auto stream = at::cuda::getCurrentCUDAStream();
145
+ AT_DISPATCH_FLOATING_TYPES(x.type(), "forward_cuda", ([&] {
146
+ forward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
147
+ x.data<scalar_t>(),
148
+ mean.data<scalar_t>(),
149
+ var.data<scalar_t>(),
150
+ weight.data<scalar_t>(),
151
+ bias.data<scalar_t>(),
152
+ affine, eps, num, chn, sp);
153
+ }));
154
+
155
+ return x;
156
+ }
157
+
158
+ /***********
159
+ * edz_eydz
160
+ ***********/
161
+
162
+ template<typename T>
163
+ __global__ void edz_eydz_kernel(const T *z, const T *dz, const T *weight, const T *bias,
164
+ T *edz, T *eydz, bool affine, float eps, int num, int chn, int sp) {
165
+ int plane = blockIdx.x;
166
+
167
+ T _weight = affine ? abs(weight[plane]) + eps : 1.f;
168
+ T _bias = affine ? bias[plane] : 0.f;
169
+
170
+ Pair<T> res = reduce<Pair<T>, GradOp<T>>(GradOp<T>(_weight, _bias, z, dz, chn, sp), plane, num, sp);
171
+ __syncthreads();
172
+
173
+ if (threadIdx.x == 0) {
174
+ edz[plane] = res.v1;
175
+ eydz[plane] = res.v2;
176
+ }
177
+ }
178
+
179
+ std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
180
+ bool affine, float eps) {
181
+ CHECK_CUDA_INPUT(z);
182
+ CHECK_CUDA_INPUT(dz);
183
+ CHECK_CUDA_INPUT(weight);
184
+ CHECK_CUDA_INPUT(bias);
185
+
186
+ // Extract dimensions
187
+ int64_t num, chn, sp;
188
+ get_dims(z, num, chn, sp);
189
+
190
+ auto edz = at::empty({chn}, z.options());
191
+ auto eydz = at::empty({chn}, z.options());
192
+
193
+ // Run kernel
194
+ dim3 blocks(chn);
195
+ dim3 threads(getNumThreads(sp));
196
+ auto stream = at::cuda::getCurrentCUDAStream();
197
+ AT_DISPATCH_FLOATING_TYPES(z.type(), "edz_eydz_cuda", ([&] {
198
+ edz_eydz_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
199
+ z.data<scalar_t>(),
200
+ dz.data<scalar_t>(),
201
+ weight.data<scalar_t>(),
202
+ bias.data<scalar_t>(),
203
+ edz.data<scalar_t>(),
204
+ eydz.data<scalar_t>(),
205
+ affine, eps, num, chn, sp);
206
+ }));
207
+
208
+ return {edz, eydz};
209
+ }
210
+
211
+ /***********
212
+ * backward
213
+ ***********/
214
+
215
+ template<typename T>
216
+ __global__ void backward_kernel(const T *z, const T *dz, const T *var, const T *weight, const T *bias, const T *edz,
217
+ const T *eydz, T *dx, bool affine, float eps, int num, int chn, int sp) {
218
+ int plane = blockIdx.x;
219
+
220
+ T _weight = affine ? abs(weight[plane]) + eps : 1.f;
221
+ T _bias = affine ? bias[plane] : 0.f;
222
+ T _var = var[plane];
223
+ T _edz = edz[plane];
224
+ T _eydz = eydz[plane];
225
+
226
+ T _mul = _weight * rsqrt(_var + eps);
227
+ T count = T(num * sp);
228
+
229
+ for (int batch = 0; batch < num; ++batch) {
230
+ for (int n = threadIdx.x; n < sp; n += blockDim.x) {
231
+ T _dz = dz[(batch * chn + plane) * sp + n];
232
+ T _y = (z[(batch * chn + plane) * sp + n] - _bias) / _weight;
233
+
234
+ dx[(batch * chn + plane) * sp + n] = (_dz - _edz / count - _y * _eydz / count) * _mul;
235
+ }
236
+ }
237
+ }
238
+
239
+ at::Tensor backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
240
+ at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
241
+ CHECK_CUDA_INPUT(z);
242
+ CHECK_CUDA_INPUT(dz);
243
+ CHECK_CUDA_INPUT(var);
244
+ CHECK_CUDA_INPUT(weight);
245
+ CHECK_CUDA_INPUT(bias);
246
+ CHECK_CUDA_INPUT(edz);
247
+ CHECK_CUDA_INPUT(eydz);
248
+
249
+ // Extract dimensions
250
+ int64_t num, chn, sp;
251
+ get_dims(z, num, chn, sp);
252
+
253
+ auto dx = at::zeros_like(z);
254
+
255
+ // Run kernel
256
+ dim3 blocks(chn);
257
+ dim3 threads(getNumThreads(sp));
258
+ auto stream = at::cuda::getCurrentCUDAStream();
259
+ AT_DISPATCH_FLOATING_TYPES(z.type(), "backward_cuda", ([&] {
260
+ backward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
261
+ z.data<scalar_t>(),
262
+ dz.data<scalar_t>(),
263
+ var.data<scalar_t>(),
264
+ weight.data<scalar_t>(),
265
+ bias.data<scalar_t>(),
266
+ edz.data<scalar_t>(),
267
+ eydz.data<scalar_t>(),
268
+ dx.data<scalar_t>(),
269
+ affine, eps, num, chn, sp);
270
+ }));
271
+
272
+ return dx;
273
+ }
274
+
275
+ /**************
276
+ * activations
277
+ **************/
278
+
279
+ template<typename T>
280
+ inline void leaky_relu_backward_impl(T *z, T *dz, float slope, int64_t count) {
281
+ // Create thrust pointers
282
+ thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
283
+ thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);
284
+
285
+ auto stream = at::cuda::getCurrentCUDAStream();
286
+ thrust::transform_if(thrust::cuda::par.on(stream),
287
+ th_dz, th_dz + count, th_z, th_dz,
288
+ [slope] __device__ (const T& dz) { return dz * slope; },
289
+ [] __device__ (const T& z) { return z < 0; });
290
+ thrust::transform_if(thrust::cuda::par.on(stream),
291
+ th_z, th_z + count, th_z,
292
+ [slope] __device__ (const T& z) { return z / slope; },
293
+ [] __device__ (const T& z) { return z < 0; });
294
+ }
295
+
296
+ void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope) {
297
+ CHECK_CUDA_INPUT(z);
298
+ CHECK_CUDA_INPUT(dz);
299
+
300
+ int64_t count = z.numel();
301
+
302
+ AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cuda", ([&] {
303
+ leaky_relu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), slope, count);
304
+ }));
305
+ }
306
+
307
+ template<typename T>
308
+ inline void elu_backward_impl(T *z, T *dz, int64_t count) {
309
+ // Create thrust pointers
310
+ thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
311
+ thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);
312
+
313
+ auto stream = at::cuda::getCurrentCUDAStream();
314
+ thrust::transform_if(thrust::cuda::par.on(stream),
315
+ th_dz, th_dz + count, th_z, th_z, th_dz,
316
+ [] __device__ (const T& dz, const T& z) { return dz * (z + 1.); },
317
+ [] __device__ (const T& z) { return z < 0; });
318
+ thrust::transform_if(thrust::cuda::par.on(stream),
319
+ th_z, th_z + count, th_z,
320
+ [] __device__ (const T& z) { return log1p(z); },
321
+ [] __device__ (const T& z) { return z < 0; });
322
+ }
323
+
324
+ void elu_backward_cuda(at::Tensor z, at::Tensor dz) {
325
+ CHECK_CUDA_INPUT(z);
326
+ CHECK_CUDA_INPUT(dz);
327
+
328
+ int64_t count = z.numel();
329
+
330
+ AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cuda", ([&] {
331
+ elu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), count);
332
+ }));
333
+ }
ConsistentID/lib/BiSeNet/modules/src/inplace_abn_cuda_half.cu ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <ATen/ATen.h>
2
+
3
+ #include <cuda_fp16.h>
4
+
5
+ #include <vector>
6
+
7
+ #include "utils/checks.h"
8
+ #include "utils/cuda.cuh"
9
+ #include "inplace_abn.h"
10
+
11
+ #include <ATen/cuda/CUDAContext.h>
12
+
13
+ // Operations for reduce
14
+ struct SumOpH {
15
+ __device__ SumOpH(const half *t, int c, int s)
16
+ : tensor(t), chn(c), sp(s) {}
17
+ __device__ __forceinline__ float operator()(int batch, int plane, int n) {
18
+ return __half2float(tensor[(batch * chn + plane) * sp + n]);
19
+ }
20
+ const half *tensor;
21
+ const int chn;
22
+ const int sp;
23
+ };
24
+
25
+ struct VarOpH {
26
+ __device__ VarOpH(float m, const half *t, int c, int s)
27
+ : mean(m), tensor(t), chn(c), sp(s) {}
28
+ __device__ __forceinline__ float operator()(int batch, int plane, int n) {
29
+ const auto t = __half2float(tensor[(batch * chn + plane) * sp + n]);
30
+ return (t - mean) * (t - mean);
31
+ }
32
+ const float mean;
33
+ const half *tensor;
34
+ const int chn;
35
+ const int sp;
36
+ };
37
+
38
+ struct GradOpH {
39
+ __device__ GradOpH(float _weight, float _bias, const half *_z, const half *_dz, int c, int s)
40
+ : weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {}
41
+ __device__ __forceinline__ Pair<float> operator()(int batch, int plane, int n) {
42
+ float _y = (__half2float(z[(batch * chn + plane) * sp + n]) - bias) / weight;
43
+ float _dz = __half2float(dz[(batch * chn + plane) * sp + n]);
44
+ return Pair<float>(_dz, _y * _dz);
45
+ }
46
+ const float weight;
47
+ const float bias;
48
+ const half *z;
49
+ const half *dz;
50
+ const int chn;
51
+ const int sp;
52
+ };
53
+
54
+ /***********
55
+ * mean_var
56
+ ***********/
57
+
58
+ __global__ void mean_var_kernel_h(const half *x, float *mean, float *var, int num, int chn, int sp) {
59
+ int plane = blockIdx.x;
60
+ float norm = 1.f / static_cast<float>(num * sp);
61
+
62
+ float _mean = reduce<float, SumOpH>(SumOpH(x, chn, sp), plane, num, sp) * norm;
63
+ __syncthreads();
64
+ float _var = reduce<float, VarOpH>(VarOpH(_mean, x, chn, sp), plane, num, sp) * norm;
65
+
66
+ if (threadIdx.x == 0) {
67
+ mean[plane] = _mean;
68
+ var[plane] = _var;
69
+ }
70
+ }
71
+
72
+ std::vector<at::Tensor> mean_var_cuda_h(at::Tensor x) {
73
+ CHECK_CUDA_INPUT(x);
74
+
75
+ // Extract dimensions
76
+ int64_t num, chn, sp;
77
+ get_dims(x, num, chn, sp);
78
+
79
+ // Prepare output tensors
80
+ auto mean = at::empty({chn},x.options().dtype(at::kFloat));
81
+ auto var = at::empty({chn},x.options().dtype(at::kFloat));
82
+
83
+ // Run kernel
84
+ dim3 blocks(chn);
85
+ dim3 threads(getNumThreads(sp));
86
+ auto stream = at::cuda::getCurrentCUDAStream();
87
+ mean_var_kernel_h<<<blocks, threads, 0, stream>>>(
88
+ reinterpret_cast<half*>(x.data<at::Half>()),
89
+ mean.data<float>(),
90
+ var.data<float>(),
91
+ num, chn, sp);
92
+
93
+ return {mean, var};
94
+ }
95
+
96
+ /**********
97
+ * forward
98
+ **********/
99
+
100
+ __global__ void forward_kernel_h(half *x, const float *mean, const float *var, const float *weight, const float *bias,
101
+ bool affine, float eps, int num, int chn, int sp) {
102
+ int plane = blockIdx.x;
103
+
104
+ const float _mean = mean[plane];
105
+ const float _var = var[plane];
106
+ const float _weight = affine ? abs(weight[plane]) + eps : 1.f;
107
+ const float _bias = affine ? bias[plane] : 0.f;
108
+
109
+ const float mul = rsqrt(_var + eps) * _weight;
110
+
111
+ for (int batch = 0; batch < num; ++batch) {
112
+ for (int n = threadIdx.x; n < sp; n += blockDim.x) {
113
+ half *x_ptr = x + (batch * chn + plane) * sp + n;
114
+ float _x = __half2float(*x_ptr);
115
+ float _y = (_x - _mean) * mul + _bias;
116
+
117
+ *x_ptr = __float2half(_y);
118
+ }
119
+ }
120
+ }
121
+
122
+ at::Tensor forward_cuda_h(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
123
+ bool affine, float eps) {
124
+ CHECK_CUDA_INPUT(x);
125
+ CHECK_CUDA_INPUT(mean);
126
+ CHECK_CUDA_INPUT(var);
127
+ CHECK_CUDA_INPUT(weight);
128
+ CHECK_CUDA_INPUT(bias);
129
+
130
+ // Extract dimensions
131
+ int64_t num, chn, sp;
132
+ get_dims(x, num, chn, sp);
133
+
134
+ // Run kernel
135
+ dim3 blocks(chn);
136
+ dim3 threads(getNumThreads(sp));
137
+ auto stream = at::cuda::getCurrentCUDAStream();
138
+ forward_kernel_h<<<blocks, threads, 0, stream>>>(
139
+ reinterpret_cast<half*>(x.data<at::Half>()),
140
+ mean.data<float>(),
141
+ var.data<float>(),
142
+ weight.data<float>(),
143
+ bias.data<float>(),
144
+ affine, eps, num, chn, sp);
145
+
146
+ return x;
147
+ }
148
+
149
+ __global__ void edz_eydz_kernel_h(const half *z, const half *dz, const float *weight, const float *bias,
150
+ float *edz, float *eydz, bool affine, float eps, int num, int chn, int sp) {
151
+ int plane = blockIdx.x;
152
+
153
+ float _weight = affine ? abs(weight[plane]) + eps : 1.f;
154
+ float _bias = affine ? bias[plane] : 0.f;
155
+
156
+ Pair<float> res = reduce<Pair<float>, GradOpH>(GradOpH(_weight, _bias, z, dz, chn, sp), plane, num, sp);
157
+ __syncthreads();
158
+
159
+ if (threadIdx.x == 0) {
160
+ edz[plane] = res.v1;
161
+ eydz[plane] = res.v2;
162
+ }
163
+ }
164
+
165
+ std::vector<at::Tensor> edz_eydz_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
166
+ bool affine, float eps) {
167
+ CHECK_CUDA_INPUT(z);
168
+ CHECK_CUDA_INPUT(dz);
169
+ CHECK_CUDA_INPUT(weight);
170
+ CHECK_CUDA_INPUT(bias);
171
+
172
+ // Extract dimensions
173
+ int64_t num, chn, sp;
174
+ get_dims(z, num, chn, sp);
175
+
176
+ auto edz = at::empty({chn},z.options().dtype(at::kFloat));
177
+ auto eydz = at::empty({chn},z.options().dtype(at::kFloat));
178
+
179
+ // Run kernel
180
+ dim3 blocks(chn);
181
+ dim3 threads(getNumThreads(sp));
182
+ auto stream = at::cuda::getCurrentCUDAStream();
183
+ edz_eydz_kernel_h<<<blocks, threads, 0, stream>>>(
184
+ reinterpret_cast<half*>(z.data<at::Half>()),
185
+ reinterpret_cast<half*>(dz.data<at::Half>()),
186
+ weight.data<float>(),
187
+ bias.data<float>(),
188
+ edz.data<float>(),
189
+ eydz.data<float>(),
190
+ affine, eps, num, chn, sp);
191
+
192
+ return {edz, eydz};
193
+ }
194
+
195
+ __global__ void backward_kernel_h(const half *z, const half *dz, const float *var, const float *weight, const float *bias, const float *edz,
196
+ const float *eydz, half *dx, bool affine, float eps, int num, int chn, int sp) {
197
+ int plane = blockIdx.x;
198
+
199
+ float _weight = affine ? abs(weight[plane]) + eps : 1.f;
200
+ float _bias = affine ? bias[plane] : 0.f;
201
+ float _var = var[plane];
202
+ float _edz = edz[plane];
203
+ float _eydz = eydz[plane];
204
+
205
+ float _mul = _weight * rsqrt(_var + eps);
206
+ float count = float(num * sp);
207
+
208
+ for (int batch = 0; batch < num; ++batch) {
209
+ for (int n = threadIdx.x; n < sp; n += blockDim.x) {
210
+ float _dz = __half2float(dz[(batch * chn + plane) * sp + n]);
211
+ float _y = (__half2float(z[(batch * chn + plane) * sp + n]) - _bias) / _weight;
212
+
213
+ dx[(batch * chn + plane) * sp + n] = __float2half((_dz - _edz / count - _y * _eydz / count) * _mul);
214
+ }
215
+ }
216
+ }
217
+
218
+ at::Tensor backward_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
219
+ at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
220
+ CHECK_CUDA_INPUT(z);
221
+ CHECK_CUDA_INPUT(dz);
222
+ CHECK_CUDA_INPUT(var);
223
+ CHECK_CUDA_INPUT(weight);
224
+ CHECK_CUDA_INPUT(bias);
225
+ CHECK_CUDA_INPUT(edz);
226
+ CHECK_CUDA_INPUT(eydz);
227
+
228
+ // Extract dimensions
229
+ int64_t num, chn, sp;
230
+ get_dims(z, num, chn, sp);
231
+
232
+ auto dx = at::zeros_like(z);
233
+
234
+ // Run kernel
235
+ dim3 blocks(chn);
236
+ dim3 threads(getNumThreads(sp));
237
+ auto stream = at::cuda::getCurrentCUDAStream();
238
+ backward_kernel_h<<<blocks, threads, 0, stream>>>(
239
+ reinterpret_cast<half*>(z.data<at::Half>()),
240
+ reinterpret_cast<half*>(dz.data<at::Half>()),
241
+ var.data<float>(),
242
+ weight.data<float>(),
243
+ bias.data<float>(),
244
+ edz.data<float>(),
245
+ eydz.data<float>(),
246
+ reinterpret_cast<half*>(dx.data<at::Half>()),
247
+ affine, eps, num, chn, sp);
248
+
249
+ return dx;
250
+ }
251
+
252
+ __global__ void leaky_relu_backward_impl_h(half *z, half *dz, float slope, int64_t count) {
253
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x){
254
+ float _z = __half2float(z[i]);
255
+ if (_z < 0) {
256
+ dz[i] = __float2half(__half2float(dz[i]) * slope);
257
+ z[i] = __float2half(_z / slope);
258
+ }
259
+ }
260
+ }
261
+
262
+ void leaky_relu_backward_cuda_h(at::Tensor z, at::Tensor dz, float slope) {
263
+ CHECK_CUDA_INPUT(z);
264
+ CHECK_CUDA_INPUT(dz);
265
+
266
+ int64_t count = z.numel();
267
+ dim3 threads(getNumThreads(count));
268
+ dim3 blocks = (count + threads.x - 1) / threads.x;
269
+ auto stream = at::cuda::getCurrentCUDAStream();
270
+ leaky_relu_backward_impl_h<<<blocks, threads, 0, stream>>>(
271
+ reinterpret_cast<half*>(z.data<at::Half>()),
272
+ reinterpret_cast<half*>(dz.data<at::Half>()),
273
+ slope, count);
274
+ }
275
+
ConsistentID/lib/BiSeNet/modules/src/utils/checks.h ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/ATen.h>
4
+
5
+ // Define AT_CHECK for old version of ATen where the same function was called AT_ASSERT
6
+ #ifndef AT_CHECK
7
+ #define AT_CHECK AT_ASSERT
8
+ #endif
9
+
10
+ #define CHECK_CUDA(x) AT_CHECK((x).type().is_cuda(), #x " must be a CUDA tensor")
11
+ #define CHECK_CPU(x) AT_CHECK(!(x).type().is_cuda(), #x " must be a CPU tensor")
12
+ #define CHECK_CONTIGUOUS(x) AT_CHECK((x).is_contiguous(), #x " must be contiguous")
13
+
14
+ #define CHECK_CUDA_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
15
+ #define CHECK_CPU_INPUT(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x)
ConsistentID/lib/BiSeNet/modules/src/utils/common.h ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ #include <ATen/ATen.h>
4
+
5
+ /*
6
+ * Functions to share code between CPU and GPU
7
+ */
8
+
9
+ #ifdef __CUDACC__
10
+ // CUDA versions
11
+
12
+ #define HOST_DEVICE __host__ __device__
13
+ #define INLINE_HOST_DEVICE __host__ __device__ inline
14
+ #define FLOOR(x) floor(x)
15
+
16
+ #if __CUDA_ARCH__ >= 600
17
+ // Recent compute capabilities have block-level atomicAdd for all data types, so we use that
18
+ #define ACCUM(x,y) atomicAdd_block(&(x),(y))
19
+ #else
20
+ // Older architectures don't have block-level atomicAdd, nor atomicAdd for doubles, so we defer to atomicAdd for float
21
+ // and use the known atomicCAS-based implementation for double
22
+ template<typename data_t>
23
+ __device__ inline data_t atomic_add(data_t *address, data_t val) {
24
+ return atomicAdd(address, val);
25
+ }
26
+
27
+ template<>
28
+ __device__ inline double atomic_add(double *address, double val) {
29
+ unsigned long long int* address_as_ull = (unsigned long long int*)address;
30
+ unsigned long long int old = *address_as_ull, assumed;
31
+ do {
32
+ assumed = old;
33
+ old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed)));
34
+ } while (assumed != old);
35
+ return __longlong_as_double(old);
36
+ }
37
+
38
+ #define ACCUM(x,y) atomic_add(&(x),(y))
39
+ #endif // #if __CUDA_ARCH__ >= 600
40
+
41
+ #else
42
+ // CPU versions
43
+
44
+ #define HOST_DEVICE
45
+ #define INLINE_HOST_DEVICE inline
46
+ #define FLOOR(x) std::floor(x)
47
+ #define ACCUM(x,y) (x) += (y)
48
+
49
+ #endif // #ifdef __CUDACC__
ConsistentID/lib/BiSeNet/modules/src/utils/cuda.cuh ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma once
2
+
3
+ /*
4
+ * General settings and functions
5
+ */
6
+ const int WARP_SIZE = 32;
7
+ const int MAX_BLOCK_SIZE = 1024;
8
+
9
+ static int getNumThreads(int nElem) {
10
+ int threadSizes[6] = {32, 64, 128, 256, 512, MAX_BLOCK_SIZE};
11
+ for (int i = 0; i < 6; ++i) {
12
+ if (nElem <= threadSizes[i]) {
13
+ return threadSizes[i];
14
+ }
15
+ }
16
+ return MAX_BLOCK_SIZE;
17
+ }
18
+
19
+ /*
20
+ * Reduction utilities
21
+ */
22
+ template <typename T>
23
+ __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize,
24
+ unsigned int mask = 0xffffffff) {
25
+ #if CUDART_VERSION >= 9000
26
+ return __shfl_xor_sync(mask, value, laneMask, width);
27
+ #else
28
+ return __shfl_xor(value, laneMask, width);
29
+ #endif
30
+ }
31
+
32
+ __device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); }
33
+
34
+ template<typename T>
35
+ struct Pair {
36
+ T v1, v2;
37
+ __device__ Pair() {}
38
+ __device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {}
39
+ __device__ Pair(T v) : v1(v), v2(v) {}
40
+ __device__ Pair(int v) : v1(v), v2(v) {}
41
+ __device__ Pair &operator+=(const Pair<T> &a) {
42
+ v1 += a.v1;
43
+ v2 += a.v2;
44
+ return *this;
45
+ }
46
+ };
47
+
48
+ template<typename T>
49
+ static __device__ __forceinline__ T warpSum(T val) {
50
+ #if __CUDA_ARCH__ >= 300
51
+ for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
52
+ val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE);
53
+ }
54
+ #else
55
+ __shared__ T values[MAX_BLOCK_SIZE];
56
+ values[threadIdx.x] = val;
57
+ __threadfence_block();
58
+ const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE;
59
+ for (int i = 1; i < WARP_SIZE; i++) {
60
+ val += values[base + ((i + threadIdx.x) % WARP_SIZE)];
61
+ }
62
+ #endif
63
+ return val;
64
+ }
65
+
66
+ template<typename T>
67
+ static __device__ __forceinline__ Pair<T> warpSum(Pair<T> value) {
68
+ value.v1 = warpSum(value.v1);
69
+ value.v2 = warpSum(value.v2);
70
+ return value;
71
+ }
ConsistentID/lib/BiSeNet/optimizer.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+
5
+ import torch
6
+ import logging
7
+
8
+ logger = logging.getLogger()
9
+
10
+ class Optimizer(object):
11
+ def __init__(self,
12
+ model,
13
+ lr0,
14
+ momentum,
15
+ wd,
16
+ warmup_steps,
17
+ warmup_start_lr,
18
+ max_iter,
19
+ power,
20
+ *args, **kwargs):
21
+ self.warmup_steps = warmup_steps
22
+ self.warmup_start_lr = warmup_start_lr
23
+ self.lr0 = lr0
24
+ self.lr = self.lr0
25
+ self.max_iter = float(max_iter)
26
+ self.power = power
27
+ self.it = 0
28
+ wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = model.get_params()
29
+ param_list = [
30
+ {'params': wd_params},
31
+ {'params': nowd_params, 'weight_decay': 0},
32
+ {'params': lr_mul_wd_params, 'lr_mul': True},
33
+ {'params': lr_mul_nowd_params, 'weight_decay': 0, 'lr_mul': True}]
34
+ self.optim = torch.optim.SGD(
35
+ param_list,
36
+ lr = lr0,
37
+ momentum = momentum,
38
+ weight_decay = wd)
39
+ self.warmup_factor = (self.lr0/self.warmup_start_lr)**(1./self.warmup_steps)
40
+
41
+
42
+ def get_lr(self):
43
+ if self.it <= self.warmup_steps:
44
+ lr = self.warmup_start_lr*(self.warmup_factor**self.it)
45
+ else:
46
+ factor = (1-(self.it-self.warmup_steps)/(self.max_iter-self.warmup_steps))**self.power
47
+ lr = self.lr0 * factor
48
+ return lr
49
+
50
+
51
+ def step(self):
52
+ self.lr = self.get_lr()
53
+ for pg in self.optim.param_groups:
54
+ if pg.get('lr_mul', False):
55
+ pg['lr'] = self.lr * 10
56
+ else:
57
+ pg['lr'] = self.lr
58
+ if self.optim.defaults.get('lr_mul', False):
59
+ self.optim.defaults['lr'] = self.lr * 10
60
+ else:
61
+ self.optim.defaults['lr'] = self.lr
62
+ self.it += 1
63
+ self.optim.step()
64
+ if self.it == self.warmup_steps+2:
65
+ logger.info('==> warmup done, start to implement poly lr strategy')
66
+
67
+ def zero_grad(self):
68
+ self.optim.zero_grad()
69
+
ConsistentID/lib/BiSeNet/prepropess_data.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import os.path as osp
5
+ import os
6
+ import cv2
7
+ from transform import *
8
+ from PIL import Image
9
+
10
+ face_data = '/home/zll/data/CelebAMask-HQ/CelebA-HQ-img'
11
+ face_sep_mask = '/home/zll/data/CelebAMask-HQ/CelebAMask-HQ-mask-anno'
12
+ mask_path = '/home/zll/data/CelebAMask-HQ/mask'
13
+ counter = 0
14
+ total = 0
15
+ for i in range(15):
16
+
17
+ atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r',
18
+ 'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat']
19
+
20
+ for j in range(i * 2000, (i + 1) * 2000):
21
+
22
+ mask = np.zeros((512, 512))
23
+
24
+ for l, att in enumerate(atts, 1):
25
+ total += 1
26
+ file_name = ''.join([str(j).rjust(5, '0'), '_', att, '.png'])
27
+ path = osp.join(face_sep_mask, str(i), file_name)
28
+
29
+ if os.path.exists(path):
30
+ counter += 1
31
+ sep_mask = np.array(Image.open(path).convert('P'))
32
+ # print(np.unique(sep_mask))
33
+
34
+ mask[sep_mask == 225] = l
35
+ cv2.imwrite('{}/{}.png'.format(mask_path, j), mask)
36
+ print(j)
37
+
38
+ print(counter, total)
ConsistentID/lib/BiSeNet/resnet.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torch.utils.model_zoo as modelzoo
8
+
9
+ # from modules.bn import InPlaceABNSync as BatchNorm2d
10
+
11
+ resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
12
+
13
+
14
+ def conv3x3(in_planes, out_planes, stride=1):
15
+ """3x3 convolution with padding"""
16
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
17
+ padding=1, bias=False)
18
+
19
+
20
+ class BasicBlock(nn.Module):
21
+ def __init__(self, in_chan, out_chan, stride=1):
22
+ super(BasicBlock, self).__init__()
23
+ self.conv1 = conv3x3(in_chan, out_chan, stride)
24
+ self.bn1 = nn.BatchNorm2d(out_chan)
25
+ self.conv2 = conv3x3(out_chan, out_chan)
26
+ self.bn2 = nn.BatchNorm2d(out_chan)
27
+ self.relu = nn.ReLU(inplace=True)
28
+ self.downsample = None
29
+ if in_chan != out_chan or stride != 1:
30
+ self.downsample = nn.Sequential(
31
+ nn.Conv2d(in_chan, out_chan,
32
+ kernel_size=1, stride=stride, bias=False),
33
+ nn.BatchNorm2d(out_chan),
34
+ )
35
+
36
+ def forward(self, x):
37
+ residual = self.conv1(x)
38
+ residual = F.relu(self.bn1(residual))
39
+ residual = self.conv2(residual)
40
+ residual = self.bn2(residual)
41
+
42
+ shortcut = x
43
+ if self.downsample is not None:
44
+ shortcut = self.downsample(x)
45
+
46
+ out = shortcut + residual
47
+ out = self.relu(out)
48
+ return out
49
+
50
+
51
+ def create_layer_basic(in_chan, out_chan, bnum, stride=1):
52
+ layers = [BasicBlock(in_chan, out_chan, stride=stride)]
53
+ for i in range(bnum-1):
54
+ layers.append(BasicBlock(out_chan, out_chan, stride=1))
55
+ return nn.Sequential(*layers)
56
+
57
+
58
+ class Resnet18(nn.Module):
59
+ def __init__(self):
60
+ super(Resnet18, self).__init__()
61
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
62
+ bias=False)
63
+ self.bn1 = nn.BatchNorm2d(64)
64
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
65
+ self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
66
+ self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
67
+ self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
68
+ self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
69
+ self.init_weight()
70
+
71
+ def forward(self, x):
72
+ x = self.conv1(x)
73
+ x = F.relu(self.bn1(x))
74
+ x = self.maxpool(x)
75
+
76
+ x = self.layer1(x)
77
+ feat8 = self.layer2(x) # 1/8
78
+ feat16 = self.layer3(feat8) # 1/16
79
+ feat32 = self.layer4(feat16) # 1/32
80
+ return feat8, feat16, feat32
81
+
82
+ def init_weight(self):
83
+ state_dict = modelzoo.load_url(resnet18_url)
84
+ self_state_dict = self.state_dict()
85
+ for k, v in state_dict.items():
86
+ if 'fc' in k: continue
87
+ self_state_dict.update({k: v})
88
+ self.load_state_dict(self_state_dict)
89
+
90
+ def get_params(self):
91
+ wd_params, nowd_params = [], []
92
+ for name, module in self.named_modules():
93
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
94
+ wd_params.append(module.weight)
95
+ if not module.bias is None:
96
+ nowd_params.append(module.bias)
97
+ elif isinstance(module, nn.BatchNorm2d):
98
+ nowd_params += list(module.parameters())
99
+ return wd_params, nowd_params
100
+
101
+
102
+ if __name__ == "__main__":
103
+ net = Resnet18()
104
+ x = torch.randn(16, 3, 224, 224)
105
+ out = net(x)
106
+ print(out[0].size())
107
+ print(out[1].size())
108
+ print(out[2].size())
109
+ net.get_params()
ConsistentID/lib/BiSeNet/test.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ from logger import setup_logger
5
+ import BiSeNet
6
+
7
+ import torch
8
+
9
+ import os
10
+ import os.path as osp
11
+ import numpy as np
12
+ from PIL import Image
13
+ import torchvision.transforms as transforms
14
+ import cv2
15
+
16
+ def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg'):
17
+ # Colors for all 20 parts
18
+ part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
19
+ [255, 0, 85], [255, 0, 170],
20
+ [0, 255, 0], [85, 255, 0], [170, 255, 0],
21
+ [0, 255, 85], [0, 255, 170],
22
+ [0, 0, 255], [85, 0, 255], [170, 0, 255],
23
+ [0, 85, 255], [0, 170, 255],
24
+ [255, 255, 0], [255, 255, 85], [255, 255, 170],
25
+ [255, 0, 255], [255, 85, 255], [255, 170, 255],
26
+ [0, 255, 255], [85, 255, 255], [170, 255, 255]]
27
+
28
+ im = np.array(im)
29
+ vis_im = im.copy().astype(np.uint8)
30
+ vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
31
+ vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
32
+ vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255
33
+
34
+ num_of_class = np.max(vis_parsing_anno)
35
+
36
+ for pi in range(1, num_of_class + 1):
37
+ index = np.where(vis_parsing_anno == pi)
38
+ vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi]
39
+
40
+ vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
41
+ # print(vis_parsing_anno_color.shape, vis_im.shape)
42
+ vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0)
43
+
44
+ # Save result or not
45
+ if save_im:
46
+ cv2.imwrite(save_path[:-4] +'.png', vis_parsing_anno)
47
+ cv2.imwrite(save_path, vis_im, [int(cv2.IMWRITE_JPEG_QUALITY), 100])
48
+
49
+ # return vis_im
50
+
51
+ def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'):
52
+
53
+ if not os.path.exists(respth):
54
+ os.makedirs(respth)
55
+
56
+ n_classes = 19
57
+ net = BiSeNet(n_classes=n_classes)
58
+ net.cuda()
59
+ save_pth = osp.join('res/cp', cp)
60
+ net.load_state_dict(torch.load(save_pth))
61
+ net.eval()
62
+
63
+ to_tensor = transforms.Compose([
64
+ transforms.ToTensor(),
65
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
66
+ ])
67
+ with torch.no_grad():
68
+ for image_path in os.listdir(dspth):
69
+ img = Image.open(osp.join(dspth, image_path))
70
+ image = img.resize((512, 512), Image.BILINEAR)
71
+ img = to_tensor(image)
72
+ img = torch.unsqueeze(img, 0)
73
+ img = img.cuda()
74
+ out = net(img)[0]
75
+ parsing = out.squeeze(0).cpu().numpy().argmax(0)
76
+ # print(parsing)
77
+ print(np.unique(parsing))
78
+
79
+ vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path))
80
+
81
+
82
+
83
+
84
+
85
+
86
+
87
+ if __name__ == "__main__":
88
+ evaluate(dspth='/home/zll/data/CelebAMask-HQ/test-img', cp='79999_iter.pth')
89
+
90
+
ConsistentID/lib/BiSeNet/train.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ from logger import setup_logger
5
+ import BiSeNet
6
+ from face_dataset import FaceMask
7
+ from loss import OhemCELoss
8
+ from evaluate import evaluate
9
+ from optimizer import Optimizer
10
+ import cv2
11
+ import numpy as np
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ from torch.utils.data import DataLoader
16
+ import torch.nn.functional as F
17
+ import torch.distributed as dist
18
+
19
+ import os
20
+ import os.path as osp
21
+ import logging
22
+ import time
23
+ import datetime
24
+ import argparse
25
+
26
+
27
+ respth = './res'
28
+ if not osp.exists(respth):
29
+ os.makedirs(respth)
30
+ logger = logging.getLogger()
31
+
32
+
33
+ def parse_args():
34
+ parse = argparse.ArgumentParser()
35
+ parse.add_argument(
36
+ '--local_rank',
37
+ dest = 'local_rank',
38
+ type = int,
39
+ default = -1,
40
+ )
41
+ return parse.parse_args()
42
+
43
+
44
+ def train():
45
+ args = parse_args()
46
+ torch.cuda.set_device(args.local_rank)
47
+ dist.init_process_group(
48
+ backend = 'nccl',
49
+ init_method = 'tcp://127.0.0.1:33241',
50
+ world_size = torch.cuda.device_count(),
51
+ rank=args.local_rank
52
+ )
53
+ setup_logger(respth)
54
+
55
+ # dataset
56
+ n_classes = 19
57
+ n_img_per_gpu = 16
58
+ n_workers = 8
59
+ cropsize = [448, 448]
60
+ data_root = '/home/zll/data/CelebAMask-HQ/'
61
+
62
+ ds = FaceMask(data_root, cropsize=cropsize, mode='train')
63
+ sampler = torch.utils.data.distributed.DistributedSampler(ds)
64
+ dl = DataLoader(ds,
65
+ batch_size = n_img_per_gpu,
66
+ shuffle = False,
67
+ sampler = sampler,
68
+ num_workers = n_workers,
69
+ pin_memory = True,
70
+ drop_last = True)
71
+
72
+ # model
73
+ ignore_idx = -100
74
+ net = BiSeNet(n_classes=n_classes)
75
+ net.cuda()
76
+ net.train()
77
+ net = nn.parallel.DistributedDataParallel(net,
78
+ device_ids = [args.local_rank, ],
79
+ output_device = args.local_rank
80
+ )
81
+ score_thres = 0.7
82
+ n_min = n_img_per_gpu * cropsize[0] * cropsize[1]//16
83
+ LossP = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
84
+ Loss2 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
85
+ Loss3 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
86
+
87
+ ## optimizer
88
+ momentum = 0.9
89
+ weight_decay = 5e-4
90
+ lr_start = 1e-2
91
+ max_iter = 80000
92
+ power = 0.9
93
+ warmup_steps = 1000
94
+ warmup_start_lr = 1e-5
95
+ optim = Optimizer(
96
+ model = net.module,
97
+ lr0 = lr_start,
98
+ momentum = momentum,
99
+ wd = weight_decay,
100
+ warmup_steps = warmup_steps,
101
+ warmup_start_lr = warmup_start_lr,
102
+ max_iter = max_iter,
103
+ power = power)
104
+
105
+ ## train loop
106
+ msg_iter = 50
107
+ loss_avg = []
108
+ st = glob_st = time.time()
109
+ diter = iter(dl)
110
+ epoch = 0
111
+ for it in range(max_iter):
112
+ try:
113
+ im, lb = next(diter)
114
+ if not im.size()[0] == n_img_per_gpu:
115
+ raise StopIteration
116
+ except StopIteration:
117
+ epoch += 1
118
+ sampler.set_epoch(epoch)
119
+ diter = iter(dl)
120
+ im, lb = next(diter)
121
+ im = im.cuda()
122
+ lb = lb.cuda()
123
+ H, W = im.size()[2:]
124
+ lb = torch.squeeze(lb, 1)
125
+
126
+ optim.zero_grad()
127
+ out, out16, out32 = net(im)
128
+ lossp = LossP(out, lb)
129
+ loss2 = Loss2(out16, lb)
130
+ loss3 = Loss3(out32, lb)
131
+ loss = lossp + loss2 + loss3
132
+ loss.backward()
133
+ optim.step()
134
+
135
+ loss_avg.append(loss.item())
136
+
137
+ # print training log message
138
+ if (it+1) % msg_iter == 0:
139
+ loss_avg = sum(loss_avg) / len(loss_avg)
140
+ lr = optim.lr
141
+ ed = time.time()
142
+ t_intv, glob_t_intv = ed - st, ed - glob_st
143
+ eta = int((max_iter - it) * (glob_t_intv / it))
144
+ eta = str(datetime.timedelta(seconds=eta))
145
+ msg = ', '.join([
146
+ 'it: {it}/{max_it}',
147
+ 'lr: {lr:4f}',
148
+ 'loss: {loss:.4f}',
149
+ 'eta: {eta}',
150
+ 'time: {time:.4f}',
151
+ ]).format(
152
+ it = it+1,
153
+ max_it = max_iter,
154
+ lr = lr,
155
+ loss = loss_avg,
156
+ time = t_intv,
157
+ eta = eta
158
+ )
159
+ logger.info(msg)
160
+ loss_avg = []
161
+ st = ed
162
+ if dist.get_rank() == 0:
163
+ if (it+1) % 5000 == 0:
164
+ state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
165
+ if dist.get_rank() == 0:
166
+ torch.save(state, './res/cp/{}_iter.pth'.format(it))
167
+ evaluate(dspth='/home/zll/data/CelebAMask-HQ/test-img', cp='{}_iter.pth'.format(it))
168
+
169
+ # dump the final model
170
+ save_pth = osp.join(respth, 'model_final_diss.pth')
171
+ # net.cpu()
172
+ state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
173
+ if dist.get_rank() == 0:
174
+ torch.save(state, save_pth)
175
+ logger.info('training done, model saved to: {}'.format(save_pth))
176
+
177
+
178
+ if __name__ == "__main__":
179
+ train()
ConsistentID/lib/BiSeNet/transform.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+
5
+ from PIL import Image
6
+ import PIL.ImageEnhance as ImageEnhance
7
+ import random
8
+ import numpy as np
9
+
10
+ class RandomCrop(object):
11
+ def __init__(self, size, *args, **kwargs):
12
+ self.size = size
13
+
14
+ def __call__(self, im_lb):
15
+ im = im_lb['im']
16
+ lb = im_lb['lb']
17
+ assert im.size == lb.size
18
+ W, H = self.size
19
+ w, h = im.size
20
+
21
+ if (W, H) == (w, h): return dict(im=im, lb=lb)
22
+ if w < W or h < H:
23
+ scale = float(W) / w if w < h else float(H) / h
24
+ w, h = int(scale * w + 1), int(scale * h + 1)
25
+ im = im.resize((w, h), Image.BILINEAR)
26
+ lb = lb.resize((w, h), Image.NEAREST)
27
+ sw, sh = random.random() * (w - W), random.random() * (h - H)
28
+ crop = int(sw), int(sh), int(sw) + W, int(sh) + H
29
+ return dict(
30
+ im = im.crop(crop),
31
+ lb = lb.crop(crop)
32
+ )
33
+
34
+
35
+ class HorizontalFlip(object):
36
+ def __init__(self, p=0.5, *args, **kwargs):
37
+ self.p = p
38
+
39
+ def __call__(self, im_lb):
40
+ if random.random() > self.p:
41
+ return im_lb
42
+ else:
43
+ im = im_lb['im']
44
+ lb = im_lb['lb']
45
+
46
+ # atts = [1 'skin', 2 'l_brow', 3 'r_brow', 4 'l_eye', 5 'r_eye', 6 'eye_g', 7 'l_ear', 8 'r_ear', 9 'ear_r',
47
+ # 10 'nose', 11 'mouth', 12 'u_lip', 13 'l_lip', 14 'neck', 15 'neck_l', 16 'cloth', 17 'hair', 18 'hat']
48
+
49
+ flip_lb = np.array(lb)
50
+ flip_lb[lb == 2] = 3
51
+ flip_lb[lb == 3] = 2
52
+ flip_lb[lb == 4] = 5
53
+ flip_lb[lb == 5] = 4
54
+ flip_lb[lb == 7] = 8
55
+ flip_lb[lb == 8] = 7
56
+ flip_lb = Image.fromarray(flip_lb)
57
+ return dict(im = im.transpose(Image.FLIP_LEFT_RIGHT),
58
+ lb = flip_lb.transpose(Image.FLIP_LEFT_RIGHT),
59
+ )
60
+
61
+
62
+ class RandomScale(object):
63
+ def __init__(self, scales=(1, ), *args, **kwargs):
64
+ self.scales = scales
65
+
66
+ def __call__(self, im_lb):
67
+ im = im_lb['im']
68
+ lb = im_lb['lb']
69
+ W, H = im.size
70
+ scale = random.choice(self.scales)
71
+ w, h = int(W * scale), int(H * scale)
72
+ return dict(im = im.resize((w, h), Image.BILINEAR),
73
+ lb = lb.resize((w, h), Image.NEAREST),
74
+ )
75
+
76
+
77
+ class ColorJitter(object):
78
+ def __init__(self, brightness=None, contrast=None, saturation=None, *args, **kwargs):
79
+ if not brightness is None and brightness>0:
80
+ self.brightness = [max(1-brightness, 0), 1+brightness]
81
+ if not contrast is None and contrast>0:
82
+ self.contrast = [max(1-contrast, 0), 1+contrast]
83
+ if not saturation is None and saturation>0:
84
+ self.saturation = [max(1-saturation, 0), 1+saturation]
85
+
86
+ def __call__(self, im_lb):
87
+ im = im_lb['im']
88
+ lb = im_lb['lb']
89
+ r_brightness = random.uniform(self.brightness[0], self.brightness[1])
90
+ r_contrast = random.uniform(self.contrast[0], self.contrast[1])
91
+ r_saturation = random.uniform(self.saturation[0], self.saturation[1])
92
+ im = ImageEnhance.Brightness(im).enhance(r_brightness)
93
+ im = ImageEnhance.Contrast(im).enhance(r_contrast)
94
+ im = ImageEnhance.Color(im).enhance(r_saturation)
95
+ return dict(im = im,
96
+ lb = lb,
97
+ )
98
+
99
+
100
+ class MultiScale(object):
101
+ def __init__(self, scales):
102
+ self.scales = scales
103
+
104
+ def __call__(self, img):
105
+ W, H = img.size
106
+ sizes = [(int(W*ratio), int(H*ratio)) for ratio in self.scales]
107
+ imgs = []
108
+ [imgs.append(img.resize(size, Image.BILINEAR)) for size in sizes]
109
+ return imgs
110
+
111
+
112
+ class Compose(object):
113
+ def __init__(self, do_list):
114
+ self.do_list = do_list
115
+
116
+ def __call__(self, im_lb):
117
+ for comp in self.do_list:
118
+ im_lb = comp(im_lb)
119
+ return im_lb
120
+
121
+
122
+
123
+
124
+ if __name__ == '__main__':
125
+ flip = HorizontalFlip(p = 1)
126
+ crop = RandomCrop((321, 321))
127
+ rscales = RandomScale((0.75, 1.0, 1.5, 1.75, 2.0))
128
+ img = Image.open('data/img.jpg')
129
+ lb = Image.open('data/label.png')
ConsistentID/lib/attention.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from diffusers.models.lora import LoRALinearLayer
5
+ from .functions import AttentionMLP
6
+
7
+ class FuseModule(nn.Module):
8
+ def __init__(self, embed_dim):
9
+ super().__init__()
10
+ self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False)
11
+ self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True)
12
+ self.layer_norm = nn.LayerNorm(embed_dim)
13
+
14
+ def fuse_fn(self, prompt_embeds, id_embeds):
15
+ stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1)
16
+ stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds
17
+ stacked_id_embeds = self.mlp2(stacked_id_embeds)
18
+ stacked_id_embeds = self.layer_norm(stacked_id_embeds)
19
+ return stacked_id_embeds
20
+
21
+ def forward(
22
+ self,
23
+ prompt_embeds,
24
+ id_embeds,
25
+ class_tokens_mask,
26
+ valid_id_mask,
27
+ ) -> torch.Tensor:
28
+ id_embeds = id_embeds.to(prompt_embeds.dtype)
29
+ batch_size, max_num_inputs = id_embeds.shape[:2] # 1,5
30
+ seq_length = prompt_embeds.shape[1] # 77
31
+ flat_id_embeds = id_embeds.view(-1, id_embeds.shape[-2], id_embeds.shape[-1])
32
+ # flat_id_embeds torch.Size([5, 1, 768])
33
+ valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()]
34
+ # valid_id_embeds torch.Size([4, 1, 768])
35
+ prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1]) # torch.Size([77, 768])
36
+ class_tokens_mask = class_tokens_mask.view(-1) # torch.Size([77])
37
+ valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1]) # torch.Size([4, 768])
38
+ image_token_embeds = prompt_embeds[class_tokens_mask] # torch.Size([4, 768])
39
+ stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds) # torch.Size([4, 768])
40
+ assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}"
41
+ prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype))
42
+ updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1)
43
+
44
+ return updated_prompt_embeds
45
+
46
+ class MLP(nn.Module):
47
+ def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True):
48
+ super().__init__()
49
+ if use_residual:
50
+ assert in_dim == out_dim
51
+ self.layernorm = nn.LayerNorm(in_dim)
52
+ self.fc1 = nn.Linear(in_dim, hidden_dim)
53
+ self.fc2 = nn.Linear(hidden_dim, out_dim)
54
+ self.use_residual = use_residual
55
+ self.act_fn = nn.GELU()
56
+
57
+ def forward(self, x):
58
+
59
+ residual = x
60
+ x = self.layernorm(x)
61
+ x = self.fc1(x)
62
+ x = self.act_fn(x)
63
+ x = self.fc2(x)
64
+ if self.use_residual:
65
+ x = x + residual
66
+ return x
67
+
68
+ class FacialEncoder(nn.Module):
69
+ def __init__(self):
70
+ super().__init__()
71
+ self.visual_projection = AttentionMLP()
72
+ self.fuse_module = FuseModule(768)
73
+
74
+ def forward(self, prompt_embeds, multi_image_embeds, class_tokens_mask, valid_id_mask):
75
+ bs, num_inputs, token_length, image_dim = multi_image_embeds.shape
76
+ multi_image_embeds_view = multi_image_embeds.view(bs * num_inputs, token_length, image_dim)
77
+ id_embeds = self.visual_projection(multi_image_embeds_view) # torch.Size([5, 1, 768])
78
+ id_embeds = id_embeds.view(bs, num_inputs, 1, -1)
79
+ # fuse_module replaces the class tokens in prompt_embeds with the fused (id_embeds, prompt_embeds[class_tokens_mask])
80
+ # whose indices are specified by class_tokens_mask.
81
+ updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask, valid_id_mask)
82
+ return updated_prompt_embeds
83
+
84
+ class Consistent_AttProcessor(nn.Module):
85
+
86
+ def __init__(
87
+ self,
88
+ hidden_size=None,
89
+ cross_attention_dim=None,
90
+ rank=4,
91
+ network_alpha=None,
92
+ lora_scale=1.0,
93
+ ):
94
+ super().__init__()
95
+
96
+ self.rank = rank
97
+ self.lora_scale = lora_scale
98
+
99
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
100
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
101
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
102
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
103
+
104
+ def __call__(
105
+ self,
106
+ attn,
107
+ hidden_states,
108
+ encoder_hidden_states=None,
109
+ attention_mask=None,
110
+ temb=None,
111
+ ):
112
+ residual = hidden_states
113
+
114
+ if attn.spatial_norm is not None:
115
+ hidden_states = attn.spatial_norm(hidden_states, temb)
116
+
117
+ input_ndim = hidden_states.ndim
118
+
119
+ if input_ndim == 4:
120
+ batch_size, channel, height, width = hidden_states.shape
121
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
122
+
123
+ batch_size, sequence_length, _ = (
124
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
125
+ )
126
+
127
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
128
+
129
+ if attn.group_norm is not None:
130
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
131
+
132
+ query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
133
+
134
+ if encoder_hidden_states is None:
135
+ encoder_hidden_states = hidden_states
136
+ elif attn.norm_cross:
137
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
138
+
139
+ key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
140
+ value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
141
+
142
+ query = attn.head_to_batch_dim(query)
143
+ key = attn.head_to_batch_dim(key)
144
+ value = attn.head_to_batch_dim(value)
145
+
146
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
147
+ hidden_states = torch.bmm(attention_probs, value)
148
+ hidden_states = attn.batch_to_head_dim(hidden_states)
149
+
150
+ # linear proj
151
+ hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
152
+ # dropout
153
+ hidden_states = attn.to_out[1](hidden_states)
154
+
155
+ if input_ndim == 4:
156
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
157
+
158
+ if attn.residual_connection:
159
+ hidden_states = hidden_states + residual
160
+
161
+ hidden_states = hidden_states / attn.rescale_output_factor
162
+
163
+ return hidden_states
164
+
165
+
166
+ class Consistent_IPAttProcessor(nn.Module):
167
+
168
+ def __init__(
169
+ self,
170
+ hidden_size,
171
+ cross_attention_dim=None,
172
+ rank=4,
173
+ network_alpha=None,
174
+ lora_scale=1.0,
175
+ scale=1.0,
176
+ num_tokens=4):
177
+ super().__init__()
178
+
179
+ self.rank = rank
180
+ self.lora_scale = lora_scale
181
+ self.num_tokens = num_tokens
182
+
183
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
184
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
185
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
186
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
187
+
188
+
189
+ self.hidden_size = hidden_size
190
+ self.cross_attention_dim = cross_attention_dim
191
+ self.scale = scale
192
+
193
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
194
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
195
+
196
+ for module in [self.to_q_lora, self.to_k_lora, self.to_v_lora, self.to_out_lora, self.to_k_ip, self.to_v_ip]:
197
+ for param in module.parameters():
198
+ param.requires_grad = False
199
+
200
+ def __call__(
201
+ self,
202
+ attn,
203
+ hidden_states,
204
+ encoder_hidden_states=None,
205
+ attention_mask=None,
206
+ scale=1.0,
207
+ temb=None,
208
+ ):
209
+ residual = hidden_states
210
+
211
+ if attn.spatial_norm is not None:
212
+ hidden_states = attn.spatial_norm(hidden_states, temb)
213
+
214
+ input_ndim = hidden_states.ndim
215
+
216
+ if input_ndim == 4:
217
+ batch_size, channel, height, width = hidden_states.shape
218
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
219
+
220
+ batch_size, sequence_length, _ = (
221
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
222
+ )
223
+
224
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
225
+
226
+ if attn.group_norm is not None:
227
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
228
+
229
+ query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
230
+
231
+ if encoder_hidden_states is None:
232
+ encoder_hidden_states = hidden_states
233
+ else:
234
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
235
+ encoder_hidden_states, ip_hidden_states = (
236
+ encoder_hidden_states[:, :end_pos, :],
237
+ encoder_hidden_states[:, end_pos:, :],
238
+ )
239
+ if attn.norm_cross:
240
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
241
+
242
+ key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
243
+ value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
244
+
245
+ inner_dim = key.shape[-1]
246
+ head_dim = inner_dim // attn.heads
247
+
248
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
249
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
250
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
251
+
252
+ hidden_states = F.scaled_dot_product_attention(
253
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
254
+ )
255
+
256
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
257
+ hidden_states = hidden_states.to(query.dtype)
258
+
259
+ ip_key = self.to_k_ip(ip_hidden_states)
260
+ ip_value = self.to_v_ip(ip_hidden_states)
261
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
262
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
263
+
264
+
265
+ ip_hidden_states = F.scaled_dot_product_attention(
266
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
267
+ )
268
+
269
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
270
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
271
+
272
+ hidden_states = hidden_states + self.scale * ip_hidden_states
273
+
274
+ # linear proj
275
+ hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
276
+ # dropout
277
+ hidden_states = attn.to_out[1](hidden_states)
278
+
279
+ if input_ndim == 4:
280
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
281
+
282
+ if attn.residual_connection:
283
+ hidden_states = hidden_states + residual
284
+
285
+ hidden_states = hidden_states / attn.rescale_output_factor
286
+
287
+ return hidden_states
ConsistentID/lib/functions.py ADDED
@@ -0,0 +1,606 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import math
3
+ import types
4
+ import torch
5
+ import torch.nn as nn
6
+ import numpy as np
7
+ import cv2
8
+ import re
9
+ import torch.nn.functional as F
10
+ from einops import rearrange
11
+ from einops.layers.torch import Rearrange
12
+ from PIL import Image
13
+
14
+ def extract_first_sentence(text):
15
+ end_index = text.find('.')
16
+ if end_index != -1:
17
+ first_sentence = text[:end_index + 1]
18
+ return first_sentence.strip()
19
+ else:
20
+ return text.strip()
21
+
22
+ import re
23
+ def remove_duplicate_keywords(text, keywords):
24
+ keyword_counts = {}
25
+
26
+ words = re.findall(r'\b\w+\b|[.,;!?]', text)
27
+
28
+ for keyword in keywords:
29
+ keyword_counts[keyword] = 0
30
+ for i, word in enumerate(words):
31
+ if word.lower() == keyword.lower():
32
+ keyword_counts[keyword] += 1
33
+ if keyword_counts[keyword] > 1:
34
+ words[i] = ""
35
+ processed_text = " ".join(words)
36
+
37
+ return processed_text
38
+
39
+ # text: 'The person has one nose , two eyes , two ears , and a mouth .'
40
+ def insert_markers_to_prompt(text, parsing_mask_dict):
41
+ keywords = ["face", "ears", "eyes", "nose", "mouth"]
42
+ text = remove_duplicate_keywords(text, keywords)
43
+ key_parsing_mask_markers = ["Nose", "Face", "Left_Ear", "Right_Ear", "Left_Eye", "Right_Eye", "Upper_Lip", "Lower_Lip"]
44
+ mapping = {
45
+ "Face": "face",
46
+ "Left_Ear": "ears",
47
+ "Right_Ear": "ears",
48
+ "Left_Eye": "eyes",
49
+ "Right_Eye": "eyes",
50
+ "Nose": "nose",
51
+ "Upper_Lip": "mouth",
52
+ "Lower_Lip": "mouth",
53
+ }
54
+ facial_features_align = []
55
+ markers_align = []
56
+ for key in key_parsing_mask_markers:
57
+ if key in parsing_mask_dict:
58
+ mapped_key = mapping.get(key, key.lower())
59
+ if mapped_key not in facial_features_align:
60
+ facial_features_align.append(mapped_key)
61
+ markers_align.append("<|" + mapped_key + "|>")
62
+
63
+ text_marked = text
64
+ align_parsing_mask_dict = parsing_mask_dict
65
+ for feature, marker in zip(facial_features_align[::-1], markers_align[::-1]):
66
+ pattern = rf'\b{feature}\b'
67
+ text_marked_new = re.sub(pattern, f'{feature} {marker}', text_marked, count=1)
68
+ if text_marked == text_marked_new:
69
+ for key, value in mapping.items():
70
+ if value == feature:
71
+ if key in align_parsing_mask_dict:
72
+ del align_parsing_mask_dict[key]
73
+
74
+ text_marked = text_marked_new
75
+
76
+ text_marked = text_marked.replace('\n', '')
77
+
78
+ ordered_text = []
79
+ text_none_makers = []
80
+ facial_marked_count = 0
81
+ skip_count = 0
82
+ for marker in markers_align:
83
+ start_idx = text_marked.find(marker)
84
+ end_idx = start_idx + len(marker)
85
+
86
+ while start_idx > 0 and text_marked[start_idx - 1] not in [",", ".", ";"]:
87
+ start_idx -= 1
88
+
89
+ while end_idx < len(text_marked) and text_marked[end_idx] not in [",", ".", ";"]:
90
+ end_idx += 1
91
+
92
+ context = text_marked[start_idx:end_idx].strip()
93
+ if context == "":
94
+ text_none_makers.append(text_marked[:end_idx])
95
+ else:
96
+ if skip_count!=0:
97
+ skip_count -= 1
98
+ continue
99
+ else:
100
+ ordered_text.append(context + ", ")
101
+ text_delete_makers = text_marked[:start_idx] + text_marked[end_idx:]
102
+ text_marked = text_delete_makers
103
+ facial_marked_count += 1
104
+
105
+ # ordered_text: ['The person has one nose <|nose|>, ', 'two ears <|ears|>, ',
106
+ # 'two eyes <|eyes|>, ', 'and a mouth <|mouth|>, ']
107
+ # align_parsing_mask_dict.keys(): ['Right_Eye', 'Right_Ear', 'Nose', 'Upper_Lip']
108
+ align_marked_text = "".join(ordered_text)
109
+ replace_list = ["<|face|>", "<|ears|>", "<|nose|>", "<|eyes|>", "<|mouth|>"]
110
+ for item in replace_list:
111
+ align_marked_text = align_marked_text.replace(item, "<|facial|>")
112
+
113
+ # align_marked_text: 'The person has one nose <|facial|>, two ears <|facial|>, two eyes <|facial|>, and a mouth <|facial|>, '
114
+ return align_marked_text, align_parsing_mask_dict
115
+
116
+ def tokenize_and_mask_noun_phrases_ends(text, image_token_id, facial_token_id, tokenizer):
117
+ input_ids = tokenizer.encode(text)
118
+ image_noun_phrase_end_mask = [False for _ in input_ids]
119
+ facial_noun_phrase_end_mask = [False for _ in input_ids]
120
+ clean_input_ids = []
121
+ clean_index = 0
122
+ image_num = 0
123
+
124
+ for i, id in enumerate(input_ids):
125
+ if id == image_token_id:
126
+ image_noun_phrase_end_mask[clean_index + image_num - 1] = True
127
+ image_num += 1
128
+ elif id == facial_token_id:
129
+ facial_noun_phrase_end_mask[clean_index - 1] = True
130
+ else:
131
+ clean_input_ids.append(id)
132
+ clean_index += 1
133
+
134
+ max_len = tokenizer.model_max_length
135
+
136
+ if len(clean_input_ids) > max_len:
137
+ clean_input_ids = clean_input_ids[:max_len]
138
+ else:
139
+ clean_input_ids = clean_input_ids + [tokenizer.pad_token_id] * (
140
+ max_len - len(clean_input_ids)
141
+ )
142
+
143
+ if len(image_noun_phrase_end_mask) > max_len:
144
+ image_noun_phrase_end_mask = image_noun_phrase_end_mask[:max_len]
145
+ else:
146
+ image_noun_phrase_end_mask = image_noun_phrase_end_mask + [False] * (
147
+ max_len - len(image_noun_phrase_end_mask)
148
+ )
149
+
150
+ if len(facial_noun_phrase_end_mask) > max_len:
151
+ facial_noun_phrase_end_mask = facial_noun_phrase_end_mask[:max_len]
152
+ else:
153
+ facial_noun_phrase_end_mask = facial_noun_phrase_end_mask + [False] * (
154
+ max_len - len(facial_noun_phrase_end_mask)
155
+ )
156
+ clean_input_ids = torch.tensor(clean_input_ids, dtype=torch.long)
157
+ image_noun_phrase_end_mask = torch.tensor(image_noun_phrase_end_mask, dtype=torch.bool)
158
+ facial_noun_phrase_end_mask = torch.tensor(facial_noun_phrase_end_mask, dtype=torch.bool)
159
+
160
+ return clean_input_ids.unsqueeze(0), image_noun_phrase_end_mask.unsqueeze(0), facial_noun_phrase_end_mask.unsqueeze(0)
161
+
162
+ def prepare_image_token_idx(image_token_mask, facial_token_mask, max_num_objects=2, max_num_facials=5):
163
+ image_token_idx = torch.nonzero(image_token_mask, as_tuple=True)[1]
164
+ image_token_idx_mask = torch.ones_like(image_token_idx, dtype=torch.bool)
165
+ if len(image_token_idx) < max_num_objects:
166
+ image_token_idx = torch.cat(
167
+ [
168
+ image_token_idx,
169
+ torch.zeros(max_num_objects - len(image_token_idx), dtype=torch.long),
170
+ ]
171
+ )
172
+ image_token_idx_mask = torch.cat(
173
+ [
174
+ image_token_idx_mask,
175
+ torch.zeros(
176
+ max_num_objects - len(image_token_idx_mask),
177
+ dtype=torch.bool,
178
+ ),
179
+ ]
180
+ )
181
+ facial_token_idx = torch.nonzero(facial_token_mask, as_tuple=True)[1]
182
+ facial_token_idx_mask = torch.ones_like(facial_token_idx, dtype=torch.bool)
183
+ if len(facial_token_idx) < max_num_facials:
184
+ facial_token_idx = torch.cat(
185
+ [
186
+ facial_token_idx,
187
+ torch.zeros(max_num_facials - len(facial_token_idx), dtype=torch.long),
188
+ ]
189
+ )
190
+ facial_token_idx_mask = torch.cat(
191
+ [
192
+ facial_token_idx_mask,
193
+ torch.zeros(
194
+ max_num_facials - len(facial_token_idx_mask),
195
+ dtype=torch.bool,
196
+ ),
197
+ ]
198
+ )
199
+ image_token_idx = image_token_idx.unsqueeze(0)
200
+ image_token_idx_mask = image_token_idx_mask.unsqueeze(0)
201
+
202
+ facial_token_idx = facial_token_idx.unsqueeze(0)
203
+ facial_token_idx_mask = facial_token_idx_mask.unsqueeze(0)
204
+
205
+ return image_token_idx, image_token_idx_mask, facial_token_idx, facial_token_idx_mask
206
+
207
+ def get_object_localization_loss_for_one_layer(
208
+ cross_attention_scores,
209
+ object_segmaps,
210
+ object_token_idx,
211
+ object_token_idx_mask,
212
+ loss_fn,
213
+ ):
214
+ bxh, num_noise_latents, num_text_tokens = cross_attention_scores.shape
215
+ b, max_num_objects, _, _ = object_segmaps.shape
216
+ size = int(num_noise_latents**0.5)
217
+
218
+ object_segmaps = F.interpolate(object_segmaps, size=(size, size), mode="bilinear", antialias=True)
219
+
220
+ object_segmaps = object_segmaps.view(
221
+ b, max_num_objects, -1
222
+ )
223
+
224
+ num_heads = bxh // b
225
+ cross_attention_scores = cross_attention_scores.view(b, num_heads, num_noise_latents, num_text_tokens)
226
+
227
+
228
+ object_token_attn_prob = torch.gather(
229
+ cross_attention_scores,
230
+ dim=3,
231
+ index=object_token_idx.view(b, 1, 1, max_num_objects).expand(
232
+ b, num_heads, num_noise_latents, max_num_objects
233
+ ),
234
+ )
235
+ object_segmaps = (
236
+ object_segmaps.permute(0, 2, 1)
237
+ .unsqueeze(1)
238
+ .expand(b, num_heads, num_noise_latents, max_num_objects)
239
+ )
240
+ loss = loss_fn(object_token_attn_prob, object_segmaps)
241
+
242
+ loss = loss * object_token_idx_mask.view(b, 1, max_num_objects)
243
+ object_token_cnt = object_token_idx_mask.sum(dim=1).view(b, 1) + 1e-5
244
+ loss = (loss.sum(dim=2) / object_token_cnt).mean()
245
+
246
+ return loss
247
+
248
+
249
+ def get_object_localization_loss(
250
+ cross_attention_scores,
251
+ object_segmaps,
252
+ image_token_idx,
253
+ image_token_idx_mask,
254
+ loss_fn,
255
+ ):
256
+ num_layers = len(cross_attention_scores)
257
+ loss = 0
258
+ for k, v in cross_attention_scores.items():
259
+ layer_loss = get_object_localization_loss_for_one_layer(
260
+ v, object_segmaps, image_token_idx, image_token_idx_mask, loss_fn
261
+ )
262
+ loss += layer_loss
263
+ return loss / num_layers
264
+
265
+ def unet_store_cross_attention_scores(unet, attention_scores, layers=5):
266
+ from diffusers.models.attention_processor import Attention
267
+
268
+ UNET_LAYER_NAMES = [
269
+ "down_blocks.0",
270
+ "down_blocks.1",
271
+ "down_blocks.2",
272
+ "mid_block",
273
+ "up_blocks.1",
274
+ "up_blocks.2",
275
+ "up_blocks.3",
276
+ ]
277
+
278
+ start_layer = (len(UNET_LAYER_NAMES) - layers) // 2
279
+ end_layer = start_layer + layers
280
+ applicable_layers = UNET_LAYER_NAMES[start_layer:end_layer]
281
+
282
+ def make_new_get_attention_scores_fn(name):
283
+ def new_get_attention_scores(module, query, key, attention_mask=None):
284
+ attention_probs = module.old_get_attention_scores(
285
+ query, key, attention_mask
286
+ )
287
+ attention_scores[name] = attention_probs
288
+ return attention_probs
289
+
290
+ return new_get_attention_scores
291
+
292
+ for name, module in unet.named_modules():
293
+ if isinstance(module, Attention) and "attn1" in name:
294
+ if not any(layer in name for layer in applicable_layers):
295
+ continue
296
+
297
+ module.old_get_attention_scores = module.get_attention_scores
298
+ module.get_attention_scores = types.MethodType(
299
+ make_new_get_attention_scores_fn(name), module
300
+ )
301
+ return unet
302
+
303
+ class BalancedL1Loss(nn.Module):
304
+ def __init__(self, threshold=1.0, normalize=False):
305
+ super().__init__()
306
+ self.threshold = threshold
307
+ self.normalize = normalize
308
+
309
+ def forward(self, object_token_attn_prob, object_segmaps):
310
+ if self.normalize:
311
+ object_token_attn_prob = object_token_attn_prob / (
312
+ object_token_attn_prob.max(dim=2, keepdim=True)[0] + 1e-5
313
+ )
314
+ background_segmaps = 1 - object_segmaps
315
+ background_segmaps_sum = background_segmaps.sum(dim=2) + 1e-5
316
+ object_segmaps_sum = object_segmaps.sum(dim=2) + 1e-5
317
+
318
+ background_loss = (object_token_attn_prob * background_segmaps).sum(
319
+ dim=2
320
+ ) / background_segmaps_sum
321
+
322
+ object_loss = (object_token_attn_prob * object_segmaps).sum(
323
+ dim=2
324
+ ) / object_segmaps_sum
325
+
326
+ return background_loss - object_loss
327
+
328
+ def apply_mask_to_raw_image(raw_image, mask_image):
329
+ mask_image = mask_image.resize(raw_image.size)
330
+ mask_raw_image = Image.composite(raw_image, Image.new('RGB', raw_image.size, (0, 0, 0)), mask_image)
331
+ return mask_raw_image
332
+
333
+ mapping_table = [
334
+ {"Mask Value": 0, "Body Part": "Background", "RGB Color": [0, 0, 0]},
335
+ {"Mask Value": 1, "Body Part": "Face", "RGB Color": [255, 0, 0]},
336
+ {"Mask Value": 2, "Body Part": "Left_Eyebrow", "RGB Color": [255, 85, 0]},
337
+ {"Mask Value": 3, "Body Part": "Right_Eyebrow", "RGB Color": [255, 170, 0]},
338
+ {"Mask Value": 4, "Body Part": "Left_Eye", "RGB Color": [255, 0, 85]},
339
+ {"Mask Value": 5, "Body Part": "Right_Eye", "RGB Color": [255, 0, 170]},
340
+ {"Mask Value": 6, "Body Part": "Hair", "RGB Color": [0, 0, 255]},
341
+ {"Mask Value": 7, "Body Part": "Left_Ear", "RGB Color": [85, 0, 255]},
342
+ {"Mask Value": 8, "Body Part": "Right_Ear", "RGB Color": [170, 0, 255]},
343
+ {"Mask Value": 9, "Body Part": "Mouth_External Contour", "RGB Color": [0, 255, 85]},
344
+ {"Mask Value": 10, "Body Part": "Nose", "RGB Color": [0, 255, 0]},
345
+ {"Mask Value": 11, "Body Part": "Mouth_Inner_Contour", "RGB Color": [0, 255, 170]},
346
+ {"Mask Value": 12, "Body Part": "Upper_Lip", "RGB Color": [85, 255, 0]},
347
+ {"Mask Value": 13, "Body Part": "Lower_Lip", "RGB Color": [170, 255, 0]},
348
+ {"Mask Value": 14, "Body Part": "Neck", "RGB Color": [0, 85, 255]},
349
+ {"Mask Value": 15, "Body Part": "Neck_Inner Contour", "RGB Color": [0, 170, 255]},
350
+ {"Mask Value": 16, "Body Part": "Cloth", "RGB Color": [255, 255, 0]},
351
+ {"Mask Value": 17, "Body Part": "Hat", "RGB Color": [255, 0, 255]},
352
+ {"Mask Value": 18, "Body Part": "Earring", "RGB Color": [255, 85, 255]},
353
+ {"Mask Value": 19, "Body Part": "Necklace", "RGB Color": [255, 255, 85]},
354
+ {"Mask Value": 20, "Body Part": "Glasses", "RGB Color": [255, 170, 255]},
355
+ {"Mask Value": 21, "Body Part": "Hand", "RGB Color": [255, 0, 255]},
356
+ {"Mask Value": 22, "Body Part": "Wristband", "RGB Color": [0, 255, 255]},
357
+ {"Mask Value": 23, "Body Part": "Clothes_Upper", "RGB Color": [85, 255, 255]},
358
+ {"Mask Value": 24, "Body Part": "Clothes_Lower", "RGB Color": [170, 255, 255]}
359
+ ]
360
+
361
+
362
+ def masks_for_unique_values(image_raw_mask):
363
+
364
+ image_array = np.array(image_raw_mask)
365
+ unique_values, counts = np.unique(image_array, return_counts=True)
366
+ masks_dict = {}
367
+ for value in unique_values:
368
+ binary_image = np.uint8(image_array == value) * 255
369
+ contours, _ = cv2.findContours(binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
370
+
371
+ mask = np.zeros_like(image_array)
372
+ for contour in contours:
373
+ cv2.drawContours(mask, [contour], -1, (255), thickness=cv2.FILLED)
374
+
375
+ if value == 0:
376
+ body_part="WithoutBackground"
377
+ mask2 = np.where(mask == 255, 0, 255).astype(mask.dtype)
378
+ masks_dict[body_part] = Image.fromarray(mask2)
379
+
380
+ body_part = next((entry["Body Part"] for entry in mapping_table if entry["Mask Value"] == value), f"Unknown_{value}")
381
+ if body_part.startswith("Unknown_"):
382
+ continue
383
+
384
+ masks_dict[body_part] = Image.fromarray(mask)
385
+
386
+ return masks_dict
387
+ # FFN
388
+ def FeedForward(dim, mult=4):
389
+ inner_dim = int(dim * mult)
390
+ return nn.Sequential(
391
+ nn.LayerNorm(dim),
392
+ nn.Linear(dim, inner_dim, bias=False),
393
+ nn.GELU(),
394
+ nn.Linear(inner_dim, dim, bias=False),
395
+ )
396
+
397
+
398
+ def reshape_tensor(x, heads):
399
+ bs, length, width = x.shape
400
+ x = x.view(bs, length, heads, -1)
401
+ x = x.transpose(1, 2)
402
+ x = x.reshape(bs, heads, length, -1)
403
+ return x
404
+
405
+ class PerceiverAttention(nn.Module):
406
+ def __init__(self, *, dim, dim_head=64, heads=8):
407
+ super().__init__()
408
+ self.scale = dim_head**-0.5
409
+ self.dim_head = dim_head
410
+ self.heads = heads
411
+ inner_dim = dim_head * heads
412
+
413
+ self.norm1 = nn.LayerNorm(dim)
414
+ self.norm2 = nn.LayerNorm(dim)
415
+
416
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
417
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
418
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
419
+
420
+ # x -> kv, latents -> q
421
+ def forward(self, x, latents):
422
+ """
423
+ Args:
424
+ x (torch.Tensor): image features
425
+ shape (b, n1, D)
426
+ latent (torch.Tensor): latent features
427
+ shape (b, n2, D)
428
+ """
429
+
430
+ x = self.norm1(x)
431
+ latents = self.norm2(latents)
432
+
433
+ b, l, _ = latents.shape
434
+
435
+ q = self.to_q(latents)
436
+ kv_input = torch.cat((x, latents), dim=-2)
437
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
438
+
439
+ q = reshape_tensor(q, self.heads)
440
+ k = reshape_tensor(k, self.heads)
441
+ v = reshape_tensor(v, self.heads)
442
+
443
+ # attention
444
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
445
+ weight = (q * scale) @ (k * scale).transpose(-2, -1)
446
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
447
+ out = weight @ v
448
+
449
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
450
+
451
+ return self.to_out(out)
452
+
453
+ class FacePerceiverResampler(torch.nn.Module):
454
+ def __init__(
455
+ self,
456
+ *,
457
+ dim=768,
458
+ depth=4,
459
+ dim_head=64,
460
+ heads=16,
461
+ embedding_dim=1280,
462
+ output_dim=768,
463
+ ff_mult=4,
464
+ ):
465
+ super().__init__()
466
+
467
+ self.proj_in = torch.nn.Linear(embedding_dim, dim)
468
+ self.proj_out = torch.nn.Linear(dim, output_dim)
469
+ self.norm_out = torch.nn.LayerNorm(output_dim)
470
+ self.layers = torch.nn.ModuleList([])
471
+ for _ in range(depth):
472
+ self.layers.append(
473
+ torch.nn.ModuleList(
474
+ [
475
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
476
+ FeedForward(dim=dim, mult=ff_mult),
477
+ ]
478
+ )
479
+ )
480
+ # x -> kv, latents -> q
481
+ def forward(self, latents, x): # latents.torch.Size([2, 4, 768]) x.torch.Size([2, 257, 1280])
482
+ x = self.proj_in(x) # x.torch.Size([2, 257, 768])
483
+ for attn, ff in self.layers:
484
+ # x -> kv, latents -> q
485
+ latents = attn(x, latents) + latents # latents.torch.Size([2, 4, 768])
486
+ latents = ff(latents) + latents # latents.torch.Size([2, 4, 768])
487
+ latents = self.proj_out(latents)
488
+ return self.norm_out(latents)
489
+
490
+ class ProjPlusModel(torch.nn.Module):
491
+ def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4):
492
+ super().__init__()
493
+
494
+ self.cross_attention_dim = cross_attention_dim
495
+ self.num_tokens = num_tokens
496
+
497
+ self.proj = torch.nn.Sequential(
498
+ torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
499
+ torch.nn.GELU(),
500
+ torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
501
+ )
502
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
503
+
504
+ self.perceiver_resampler = FacePerceiverResampler(
505
+ dim=cross_attention_dim,
506
+ depth=4,
507
+ dim_head=64,
508
+ heads=cross_attention_dim // 64,
509
+ embedding_dim=clip_embeddings_dim,
510
+ output_dim=cross_attention_dim,
511
+ ff_mult=4,
512
+ )
513
+
514
+ def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0):
515
+
516
+ x = self.proj(id_embeds)
517
+ x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
518
+ x = self.norm(x)
519
+ # id_embeds -> x -> kv, clip_embeds -> q
520
+ out = self.perceiver_resampler(x, clip_embeds)
521
+ if shortcut:
522
+ out = scale * x + out
523
+ return out
524
+
525
+ class AttentionMLP(nn.Module):
526
+ def __init__(
527
+ self,
528
+ dtype=torch.float16,
529
+ dim=1024,
530
+ depth=8,
531
+ dim_head=64,
532
+ heads=16,
533
+ single_num_tokens=1,
534
+ embedding_dim=1280,
535
+ output_dim=768,
536
+ ff_mult=4,
537
+ max_seq_len: int = 257*2,
538
+ apply_pos_emb: bool = False,
539
+ num_latents_mean_pooled: int = 0,
540
+ ):
541
+ super().__init__()
542
+ self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
543
+
544
+ self.single_num_tokens = single_num_tokens
545
+ self.latents = nn.Parameter(torch.randn(1, self.single_num_tokens, dim) / dim**0.5)
546
+
547
+ self.proj_in = nn.Linear(embedding_dim, dim)
548
+
549
+ self.proj_out = nn.Linear(dim, output_dim)
550
+ self.norm_out = nn.LayerNorm(output_dim)
551
+
552
+ self.to_latents_from_mean_pooled_seq = (
553
+ nn.Sequential(
554
+ nn.LayerNorm(dim),
555
+ nn.Linear(dim, dim * num_latents_mean_pooled),
556
+ Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
557
+ )
558
+ if num_latents_mean_pooled > 0
559
+ else None
560
+ )
561
+
562
+ self.layers = nn.ModuleList([])
563
+ for _ in range(depth):
564
+ self.layers.append(
565
+ nn.ModuleList(
566
+ [
567
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
568
+ FeedForward(dim=dim, mult=ff_mult),
569
+ ]
570
+ )
571
+ )
572
+
573
+ def forward(self, x):
574
+ if self.pos_emb is not None:
575
+ n, device = x.shape[1], x.device
576
+ pos_emb = self.pos_emb(torch.arange(n, device=device))
577
+ x = x + pos_emb
578
+ # x torch.Size([5, 257, 1280])
579
+ latents = self.latents.repeat(x.size(0), 1, 1)
580
+
581
+ x = self.proj_in(x) # torch.Size([5, 257, 1024])
582
+
583
+ if self.to_latents_from_mean_pooled_seq:
584
+ meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
585
+ meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
586
+ latents = torch.cat((meanpooled_latents, latents), dim=-2)
587
+
588
+ for attn, ff in self.layers:
589
+ latents = attn(x, latents) + latents
590
+ latents = ff(latents) + latents
591
+
592
+ latents = self.proj_out(latents)
593
+ return self.norm_out(latents)
594
+
595
+
596
+ def masked_mean(t, *, dim, mask=None):
597
+ if mask is None:
598
+ return t.mean(dim=dim)
599
+
600
+ denom = mask.sum(dim=dim, keepdim=True)
601
+ mask = rearrange(mask, "b n -> b n 1")
602
+ masked_t = t.masked_fill(~mask, 0.0)
603
+
604
+ return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
605
+
606
+
ConsistentID/lib/pipeline_ConsistentID.py ADDED
@@ -0,0 +1,605 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
2
+ import cv2
3
+ import PIL
4
+ import numpy as np
5
+ from PIL import Image
6
+ import torch
7
+ from torchvision import transforms
8
+ from insightface.app import FaceAnalysis
9
+ ### insight-face installation can be found at https://github.com/deepinsight/insightface
10
+ from safetensors import safe_open
11
+ from huggingface_hub.utils import validate_hf_hub_args
12
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
13
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
14
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
15
+ from .functions import insert_markers_to_prompt, masks_for_unique_values, apply_mask_to_raw_image, tokenize_and_mask_noun_phrases_ends, prepare_image_token_idx
16
+ from .functions import ProjPlusModel, masks_for_unique_values
17
+ from .attention import Consistent_IPAttProcessor, Consistent_AttProcessor, FacialEncoder
18
+ from easydict import EasyDict as edict
19
+ from huggingface_hub import hf_hub_download
20
+ ### Model can be imported from https://github.com/zllrunning/face-parsing.PyTorch?tab=readme-ov-file
21
+ ### We use the ckpt of 79999_iter.pth: https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812
22
+ ### Thanks for the open source of face-parsing model.
23
+ from .BiSeNet.model import BiSeNet
24
+ import os
25
+
26
+ PipelineImageInput = Union[
27
+ PIL.Image.Image,
28
+ torch.FloatTensor,
29
+ List[PIL.Image.Image],
30
+ List[torch.FloatTensor],
31
+ ]
32
+
33
+ ### Download the pretrained model from huggingface and put it locally, then place the model in a local directory and specify the directory location.
34
+ class ConsistentIDPipeline(StableDiffusionPipeline):
35
+ # to() should be only called after all modules are loaded.
36
+ def to(
37
+ self,
38
+ torch_device: Optional[Union[str, torch.device]] = None,
39
+ dtype: Optional[torch.dtype] = None,
40
+ ):
41
+ super().to(torch_device, dtype=dtype)
42
+ self.bise_net.to(torch_device, dtype=dtype)
43
+ self.clip_encoder.to(torch_device, dtype=dtype)
44
+ self.image_proj_model.to(torch_device, dtype=dtype)
45
+ self.FacialEncoder.to(torch_device, dtype=dtype)
46
+ # If the unet is not released, the ip_layers should be moved to the specified device and dtype.
47
+ if not isinstance(self.unet, edict):
48
+ self.ip_layers.to(torch_device, dtype=dtype)
49
+ return self
50
+
51
+ @validate_hf_hub_args
52
+ def load_ConsistentID_model(
53
+ self,
54
+ consistentID_weight_path: str,
55
+ bise_net_weight_path: str,
56
+ trigger_word_facial: str = '<|facial|>',
57
+ # A CLIP ViT-H/14 model trained with the LAION-2B English subset of LAION-5B using OpenCLIP.
58
+ # output dim: 1280.
59
+ image_encoder_path: str = 'laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
60
+ torch_dtype = torch.float16,
61
+ num_tokens = 4,
62
+ lora_rank= 128,
63
+ **kwargs,
64
+ ):
65
+ self.lora_rank = lora_rank
66
+ self.torch_dtype = torch_dtype
67
+ self.num_tokens = num_tokens
68
+ self.set_ip_adapter()
69
+ self.image_encoder_path = image_encoder_path
70
+ self.clip_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path)
71
+ self.clip_preprocessor = CLIPImageProcessor()
72
+ self.id_image_processor = CLIPImageProcessor()
73
+ self.crop_size = 512
74
+
75
+ # face_app: FaceAnalysis object
76
+ self.face_app = FaceAnalysis(name="buffalo_l", root='models/insightface', providers=['CPUExecutionProvider'])
77
+ # The original det_size=(640, 640) is too large and face_app often fails to detect faces.
78
+ self.face_app.prepare(ctx_id=0, det_size=(512, 512))
79
+
80
+ if not os.path.exists(consistentID_weight_path):
81
+ ### Download pretrained models
82
+ hf_hub_download(repo_id="JackAILab/ConsistentID", repo_type="model",
83
+ filename=os.path.basename(consistentID_weight_path),
84
+ local_dir=os.path.dirname(consistentID_weight_path))
85
+ if not os.path.exists(bise_net_weight_path):
86
+ hf_hub_download(repo_id="JackAILab/ConsistentID",
87
+ filename=os.path.basename(bise_net_weight_path),
88
+ local_dir=os.path.dirname(bise_net_weight_path))
89
+
90
+ bise_net = BiSeNet(n_classes = 19)
91
+ bise_net.load_state_dict(torch.load(bise_net_weight_path, map_location="cpu"))
92
+ bise_net.eval()
93
+ self.bise_net = bise_net
94
+
95
+ # Colors for all 20 parts
96
+ self.part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
97
+ [255, 0, 85], [255, 0, 170],
98
+ [0, 255, 0], [85, 255, 0], [170, 255, 0],
99
+ [0, 255, 85], [0, 255, 170],
100
+ [0, 0, 255], [85, 0, 255], [170, 0, 255],
101
+ [0, 85, 255], [0, 170, 255],
102
+ [255, 255, 0], [255, 255, 85], [255, 255, 170],
103
+ [255, 0, 255], [255, 85, 255], [255, 170, 255],
104
+ [0, 255, 255], [85, 255, 255], [170, 255, 255]]
105
+
106
+ # image_proj_model maps 1280-dim OpenCLIP embeddings to 768-dim face prompt embeddings.
107
+ self.image_proj_model = ProjPlusModel(
108
+ cross_attention_dim=self.unet.config.cross_attention_dim,
109
+ id_embeddings_dim=512,
110
+ clip_embeddings_dim=self.clip_encoder.config.hidden_size,
111
+ num_tokens=self.num_tokens, # 4 - inspirsed by IPAdapter and Midjourney
112
+ )
113
+ self.FacialEncoder = FacialEncoder()
114
+
115
+ if consistentID_weight_path.endswith(".safetensors"):
116
+ state_dict = {"id_encoder": {}, "lora_weights": {}}
117
+ with safe_open(consistentID_weight_path, framework="pt", device="cpu") as f:
118
+ ### TODO safetensors add
119
+ for key in f.keys():
120
+ if key.startswith("FacialEncoder."):
121
+ state_dict["FacialEncoder"][key.replace("FacialEncoder.", "")] = f.get_tensor(key)
122
+ elif key.startswith("image_proj."):
123
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
124
+ else:
125
+ state_dict = torch.load(consistentID_weight_path, map_location="cpu")
126
+
127
+ self.trigger_word_facial = trigger_word_facial
128
+
129
+ self.FacialEncoder.load_state_dict(state_dict["FacialEncoder"], strict=True)
130
+ self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
131
+ self.ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values())
132
+ self.ip_layers.load_state_dict(state_dict["adapter_modules"], strict=True)
133
+ print(f"Successfully loaded weights from checkpoint")
134
+
135
+ # Add trigger word token
136
+ if self.tokenizer is not None:
137
+ self.tokenizer.add_tokens([self.trigger_word_facial], special_tokens=True)
138
+
139
+ def set_ip_adapter(self):
140
+ unet = self.unet
141
+ attn_procs = {}
142
+ for name in unet.attn_processors.keys():
143
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
144
+ if name.startswith("mid_block"):
145
+ hidden_size = unet.config.block_out_channels[-1]
146
+ elif name.startswith("up_blocks"):
147
+ block_id = int(name[len("up_blocks.")])
148
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
149
+ elif name.startswith("down_blocks"):
150
+ block_id = int(name[len("down_blocks.")])
151
+ hidden_size = unet.config.block_out_channels[block_id]
152
+ if cross_attention_dim is None:
153
+ attn_procs[name] = Consistent_AttProcessor(
154
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank,
155
+ )
156
+ else:
157
+ attn_procs[name] = Consistent_IPAttProcessor(
158
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens,
159
+ )
160
+
161
+ unet.set_attn_processor(attn_procs)
162
+
163
+ @torch.inference_mode()
164
+ # parsed_image_parts2 is a batched tensor of parsed_image_parts with bs=1. It only contains the facial areas of one input image.
165
+ # clip_encoder maps image parts to image-space diffusion prompts.
166
+ # Then the facial class token embeddings are replaced with the fused (multi_facial_embeds, prompt_embeds[class_tokens_mask]).
167
+ def extract_local_facial_embeds(self, prompt_embeds, uncond_prompt_embeds, parsed_image_parts2,
168
+ facial_token_masks, valid_facial_token_idx_mask, calc_uncond=True):
169
+
170
+ hidden_states = []
171
+ uncond_hidden_states = []
172
+ for parsed_image_parts in parsed_image_parts2:
173
+ hidden_state = self.clip_encoder(parsed_image_parts.to(self.device, dtype=self.torch_dtype), output_hidden_states=True).hidden_states[-2]
174
+ uncond_hidden_state = self.clip_encoder(torch.zeros_like(parsed_image_parts, dtype=self.torch_dtype).to(self.device), output_hidden_states=True).hidden_states[-2]
175
+ hidden_states.append(hidden_state)
176
+ uncond_hidden_states.append(uncond_hidden_state)
177
+ multi_facial_embeds = torch.stack(hidden_states)
178
+ uncond_multi_facial_embeds = torch.stack(uncond_hidden_states)
179
+
180
+ # conditional prompt.
181
+ # FacialEncoder maps multi_facial_embeds to facial ID embeddings, and replaces the class tokens in prompt_embeds
182
+ # with the fused (facial ID embeddings, prompt_embeds[class_tokens_mask]).
183
+ # multi_facial_embeds: [1, 5, 257, 1280].
184
+ facial_prompt_embeds = self.FacialEncoder(prompt_embeds, multi_facial_embeds, facial_token_masks, valid_facial_token_idx_mask)
185
+
186
+ if not calc_uncond:
187
+ return facial_prompt_embeds, None
188
+ # unconditional prompt.
189
+ uncond_facial_prompt_embeds = self.FacialEncoder(uncond_prompt_embeds, uncond_multi_facial_embeds, facial_token_masks, valid_facial_token_idx_mask)
190
+
191
+ return facial_prompt_embeds, uncond_facial_prompt_embeds
192
+
193
+ @torch.inference_mode()
194
+ # Extrat OpenCLIP embeddings from the input image and map them to face prompt embeddings.
195
+ def extract_global_id_embeds(self, face_image_obj, s_scale=1.0, shortcut=False):
196
+ clip_image_ts = self.clip_preprocessor(images=face_image_obj, return_tensors="pt").pixel_values
197
+ clip_image_ts = clip_image_ts.to(self.device, dtype=self.torch_dtype)
198
+ clip_image_embeds = self.clip_encoder(clip_image_ts, output_hidden_states=True).hidden_states[-2]
199
+ uncond_clip_image_embeds = self.clip_encoder(torch.zeros_like(clip_image_ts), output_hidden_states=True).hidden_states[-2]
200
+
201
+ faceid_embeds = self.extract_faceid(face_image_obj)
202
+ faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
203
+ # image_proj_model maps 1280-dim OpenCLIP embeddings to 768-dim face prompt embeddings.
204
+ # clip_image_embeds are used as queries to transform faceid_embeds.
205
+ # faceid_embeds -> kv, clip_image_embeds -> q
206
+ global_id_embeds = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale)
207
+ uncond_global_id_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, shortcut=shortcut, scale=s_scale)
208
+
209
+ return global_id_embeds, uncond_global_id_embeds
210
+
211
+ def set_scale(self, scale):
212
+ for attn_processor in self.pipe.unet.attn_processors.values():
213
+ if isinstance(attn_processor, Consistent_IPAttProcessor):
214
+ attn_processor.scale = scale
215
+
216
+ @torch.inference_mode()
217
+ def extract_faceid(self, face_image_obj):
218
+ faceid_image = np.array(face_image_obj)
219
+ faces = self.face_app.get(faceid_image)
220
+ if faces==[]:
221
+ faceid_embeds = torch.zeros_like(torch.empty((1, 512)))
222
+ else:
223
+ faceid_embeds = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
224
+
225
+ return faceid_embeds
226
+
227
+ @torch.inference_mode()
228
+ def parse_face_mask(self, raw_image_refer):
229
+
230
+ to_tensor = transforms.Compose([
231
+ transforms.ToTensor(),
232
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
233
+ ])
234
+ to_pil = transforms.ToPILImage()
235
+
236
+ with torch.no_grad():
237
+ image = raw_image_refer.resize((512, 512), Image.BILINEAR)
238
+ image_resize_PIL = image
239
+ img = to_tensor(image)
240
+ img = torch.unsqueeze(img, 0)
241
+ img = img.to(self.device, dtype=self.torch_dtype)
242
+ out = self.bise_net(img)[0]
243
+ parsing_anno = out.squeeze(0).cpu().numpy().argmax(0)
244
+
245
+ im = np.array(image_resize_PIL)
246
+ vis_im = im.copy().astype(np.uint8)
247
+ stride=1
248
+ vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
249
+ vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
250
+ vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255
251
+
252
+ num_of_class = np.max(vis_parsing_anno)
253
+
254
+ for pi in range(1, num_of_class + 1): # num_of_class=17 pi=1~16
255
+ index = np.where(vis_parsing_anno == pi)
256
+ vis_parsing_anno_color[index[0], index[1], :] = self.part_colors[pi]
257
+
258
+ vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
259
+ vis_parsing_anno_color = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0)
260
+
261
+ return vis_parsing_anno_color, vis_parsing_anno
262
+
263
+ @torch.inference_mode()
264
+ def extract_facemask(self, input_image_obj):
265
+ vis_parsing_anno_color, vis_parsing_anno = self.parse_face_mask(input_image_obj)
266
+ parsing_mask_list = masks_for_unique_values(vis_parsing_anno)
267
+
268
+ key_parsing_mask_dict = {}
269
+ key_list = ["Face", "Left_Ear", "Right_Ear", "Left_Eye", "Right_Eye", "Nose", "Upper_Lip", "Lower_Lip"]
270
+ processed_keys = set()
271
+ for key, mask_image in parsing_mask_list.items():
272
+ if key in key_list:
273
+ if "_" in key:
274
+ prefix = key.split("_")[1]
275
+ if prefix in processed_keys:
276
+ continue
277
+ else:
278
+ key_parsing_mask_dict[key] = mask_image
279
+ processed_keys.add(prefix)
280
+
281
+ key_parsing_mask_dict[key] = mask_image
282
+
283
+ return key_parsing_mask_dict, vis_parsing_anno_color
284
+
285
+ def augment_prompt_with_trigger_word(
286
+ self,
287
+ prompt: str,
288
+ face_caption: str,
289
+ key_parsing_mask_dict = None,
290
+ facial_token = "<|facial|>",
291
+ max_num_facials = 5,
292
+ num_id_images: int = 1,
293
+ device: Optional[torch.device] = None,
294
+ ):
295
+ device = device or self._execution_device
296
+
297
+ # face_caption_align: 'The person has one nose <|facial|>, two ears <|facial|>, two eyes <|facial|>, and a mouth <|facial|>, '
298
+ face_caption_align, key_parsing_mask_dict_align = insert_markers_to_prompt(face_caption, key_parsing_mask_dict)
299
+
300
+ prompt_face = prompt + " Detail: " + face_caption_align
301
+
302
+ max_text_length=330
303
+ if len(self.tokenizer(prompt_face, max_length=self.tokenizer.model_max_length,
304
+ padding="max_length", truncation=False, return_tensors="pt").input_ids[0]) != 77:
305
+ # Put face_caption_align at the beginning of the prompt, so that the original prompt is truncated,
306
+ # but the face_caption_align is well kept.
307
+ prompt_face = "Detail: " + face_caption_align + " Caption:" + prompt
308
+
309
+ # Remove "<|facial|>" from prompt_face.
310
+ # augmented_prompt: 'A person, police officer, half body shot Detail:
311
+ # The person has one nose , two ears , two eyes , and a mouth , '
312
+ augmented_prompt = prompt_face.replace("<|facial|>", "")
313
+ tokenizer = self.tokenizer
314
+ facial_token_id = tokenizer.convert_tokens_to_ids(facial_token)
315
+ image_token_id = None
316
+
317
+ # image_token_id: the token id of "<|image|>". Disabled, as it's set to None.
318
+ # facial_token_id: the token id of "<|facial|>".
319
+ clean_input_id, image_token_mask, facial_token_mask = \
320
+ tokenize_and_mask_noun_phrases_ends(prompt_face, image_token_id, facial_token_id, tokenizer)
321
+
322
+ image_token_idx, image_token_idx_mask, facial_token_idx, facial_token_idx_mask = \
323
+ prepare_image_token_idx(image_token_mask, facial_token_mask, num_id_images, max_num_facials)
324
+
325
+ return augmented_prompt, clean_input_id, key_parsing_mask_dict_align, facial_token_mask, facial_token_idx, facial_token_idx_mask
326
+
327
+ @torch.inference_mode()
328
+ def extract_parsed_image_parts(self, input_image_obj, key_parsing_mask_dict, image_size=512, max_num_facials=5):
329
+ facial_masks = []
330
+ parsed_image_parts = []
331
+ key_masked_raw_images_dict = {}
332
+ transform_mask = transforms.Compose([transforms.CenterCrop(size=image_size), transforms.ToTensor(),])
333
+ clip_preprocessor = CLIPImageProcessor()
334
+
335
+ num_facial_part = len(key_parsing_mask_dict)
336
+
337
+ for key in key_parsing_mask_dict:
338
+ key_mask=key_parsing_mask_dict[key]
339
+ facial_masks.append(transform_mask(key_mask))
340
+ key_masked_raw_image = apply_mask_to_raw_image(input_image_obj, key_mask)
341
+ key_masked_raw_images_dict[key] = key_masked_raw_image
342
+ # clip_preprocessor normalizes key_masked_raw_image, so that (masked) zero pixels become non-zero.
343
+ # It also resizes the image to 224x224.
344
+ parsed_image_part = clip_preprocessor(images=key_masked_raw_image, return_tensors="pt").pixel_values
345
+ parsed_image_parts.append(parsed_image_part)
346
+
347
+ padding_ficial_clip_image = torch.zeros_like(torch.zeros([1, 3, 224, 224]))
348
+ padding_ficial_mask = torch.zeros_like(torch.zeros([1, image_size, image_size]))
349
+
350
+ if num_facial_part < max_num_facials:
351
+ parsed_image_parts += [ torch.zeros_like(padding_ficial_clip_image) for _ in range(max_num_facials - num_facial_part) ]
352
+ facial_masks += [ torch.zeros_like(padding_ficial_mask) for _ in range(max_num_facials - num_facial_part) ]
353
+
354
+ parsed_image_parts = torch.stack(parsed_image_parts, dim=1).squeeze(0)
355
+ facial_masks = torch.stack(facial_masks, dim=0).squeeze(dim=1)
356
+
357
+ return parsed_image_parts, facial_masks, key_masked_raw_images_dict
358
+
359
+ # Release the unet/vae/text_encoder to save memory.
360
+ def release_components(self, released_components=["unet", "vae", "text_encoder"]):
361
+ if "unet" in released_components:
362
+ unet = edict()
363
+ # Only keep the config and in_channels attributes that are used in the pipeline.
364
+ unet.config = self.unet.config
365
+ self.unet = unet
366
+
367
+ if "vae" in released_components:
368
+ self.vae = None
369
+ if "text_encoder" in released_components:
370
+ self.text_encoder = None
371
+
372
+ # input_subj_image_obj: an Image object.
373
+ def extract_double_id_prompt_embeds(self, prompt, negative_prompt, input_subj_image_obj, device, calc_uncond=True):
374
+ face_caption = "The person has one nose, two eyes, two ears, and a mouth."
375
+ key_parsing_mask_dict, vis_parsing_anno_color = self.extract_facemask(input_subj_image_obj)
376
+
377
+ augmented_prompt, clean_input_id, key_parsing_mask_dict_align, \
378
+ facial_token_mask, facial_token_idx, facial_token_idx_mask \
379
+ = self.augment_prompt_with_trigger_word(
380
+ prompt = prompt,
381
+ face_caption = face_caption,
382
+ key_parsing_mask_dict=key_parsing_mask_dict,
383
+ device=device,
384
+ max_num_facials = 5,
385
+ num_id_images = 1
386
+ )
387
+
388
+ text_embeds, uncond_text_embeds = self.encode_prompt(
389
+ augmented_prompt,
390
+ device=device,
391
+ num_images_per_prompt=1,
392
+ do_classifier_free_guidance=calc_uncond,
393
+ negative_prompt=negative_prompt,
394
+ )
395
+
396
+ # 5. Prepare the input ID images
397
+ # global_id_embeds: [1, 4, 768]
398
+ # extract_global_id_embeds() extrats OpenCLIP embeddings from the input image and map them to global face prompt embeddings.
399
+ global_id_embeds, uncond_global_id_embeds = \
400
+ self.extract_global_id_embeds(face_image_obj=input_subj_image_obj, s_scale=1.0, shortcut=False)
401
+
402
+ # parsed_image_parts: [5, 3, 224, 224]. 5 parts, each part is a 3-channel 224x224 image (resized by CLIP Preprocessor).
403
+ parsed_image_parts, facial_masks, key_masked_raw_images_dict = \
404
+ self.extract_parsed_image_parts(input_subj_image_obj, key_parsing_mask_dict_align, image_size=512, max_num_facials=5)
405
+ parsed_image_parts2 = parsed_image_parts.unsqueeze(0).to(device, dtype=self.torch_dtype)
406
+ facial_token_mask = facial_token_mask.to(device)
407
+ facial_token_idx_mask = facial_token_idx_mask.to(device)
408
+
409
+ # key_masked_raw_images_dict: ['Right_Eye', 'Right_Ear', 'Nose', 'Upper_Lip']
410
+ # for key in key_masked_raw_images_dict:
411
+ # key_masked_raw_images_dict[key].save(f"{key}.png")
412
+
413
+ # 6. Get the update text embedding
414
+ # parsed_image_parts2: the facial areas of the input image
415
+ # extract_local_facial_embeds() maps parsed_image_parts2 to multi_facial_embeds, and then replaces the class tokens in prompt_embeds
416
+ # with the fused (id_embeds, prompt_embeds[class_tokens_mask]) whose indices are specified by class_tokens_mask.
417
+ # parsed_image_parts2: [1, 5, 3, 224, 224]
418
+ text_local_id_embeds, uncond_text_local_id_embeds = \
419
+ self.extract_local_facial_embeds(text_embeds, uncond_text_embeds, \
420
+ parsed_image_parts2, facial_token_mask, facial_token_idx_mask,
421
+ calc_uncond=calc_uncond)
422
+
423
+ # text_global_id_embeds, text_local_global_id_embeds: [1, 81, 768]
424
+ # text_local_id_embeds: [1, 77, 768], only differs with text_embeds on 4 ID embeddings, and is identical
425
+ # to text_embeds on the rest 73 tokens.
426
+ text_global_id_embeds = torch.cat([text_embeds, global_id_embeds], dim=1)
427
+ text_local_global_id_embeds = torch.cat([text_local_id_embeds, global_id_embeds], dim=1)
428
+
429
+ if calc_uncond:
430
+ uncond_text_global_id_embeds = torch.cat([uncond_text_local_id_embeds, uncond_global_id_embeds], dim=1)
431
+ coarse_prompt_embeds = torch.cat([uncond_text_global_id_embeds, text_global_id_embeds], dim=0)
432
+ fine_prompt_embeds = torch.cat([uncond_text_global_id_embeds, text_local_global_id_embeds], dim=0)
433
+ else:
434
+ coarse_prompt_embeds = text_global_id_embeds
435
+ fine_prompt_embeds = text_local_global_id_embeds
436
+
437
+ # fine_prompt_embeds: the conditional part is
438
+ # (text_global_id_embeds + text_local_global_id_embeds) / 2.
439
+ fine_prompt_embeds = (coarse_prompt_embeds + fine_prompt_embeds) / 2
440
+
441
+ return coarse_prompt_embeds, fine_prompt_embeds
442
+
443
+ @torch.no_grad()
444
+ def __call__(
445
+ self,
446
+ prompt: Union[str, List[str]] = None,
447
+ height: Optional[int] = None,
448
+ width: Optional[int] = None,
449
+ num_inference_steps: int = 50,
450
+ guidance_scale: float = 5.0,
451
+ negative_prompt: Optional[Union[str, List[str]]] = None,
452
+ num_images_per_prompt: Optional[int] = 1,
453
+ eta: float = 0.0,
454
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
455
+ latents: Optional[torch.FloatTensor] = None,
456
+ prompt_embeds: Optional[torch.FloatTensor] = None,
457
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
458
+ output_type: Optional[str] = "pil",
459
+ return_dict: bool = True,
460
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
461
+ original_size: Optional[Tuple[int, int]] = None,
462
+ target_size: Optional[Tuple[int, int]] = None,
463
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
464
+ callback_steps: int = 1,
465
+ input_subj_image_objs: PipelineImageInput = None,
466
+ start_merge_step: int = 0,
467
+ ):
468
+ # 0. Default height and width to unet
469
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
470
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
471
+
472
+ original_size = original_size or (height, width)
473
+ target_size = target_size or (height, width)
474
+
475
+ # 1. Check inputs. Raise error if not correct
476
+ self.check_inputs(
477
+ prompt,
478
+ height,
479
+ width,
480
+ callback_steps,
481
+ negative_prompt,
482
+ prompt_embeds,
483
+ negative_prompt_embeds,
484
+ )
485
+
486
+ # 2. Define call parameters
487
+ if prompt is not None and isinstance(prompt, str):
488
+ batch_size = 1
489
+ elif prompt is not None and isinstance(prompt, list):
490
+ batch_size = len(prompt)
491
+ else:
492
+ batch_size = prompt_embeds.shape[0]
493
+
494
+ device = self._execution_device
495
+ do_classifier_free_guidance = guidance_scale >= 1.0
496
+ assert do_classifier_free_guidance
497
+
498
+ if input_subj_image_objs is not None:
499
+ if not isinstance(input_subj_image_objs, list):
500
+ input_subj_image_objs = [input_subj_image_objs]
501
+
502
+ # 3. Encode input prompt
503
+ coarse_prompt_embeds, fine_prompt_embeds = \
504
+ self.extract_double_id_prompt_embeds(prompt, negative_prompt, input_subj_image_objs[0], device)
505
+ else:
506
+ # Replace the coarse_prompt_embeds and fine_prompt_embeds with the input prompt_embeds.
507
+ # This is used when prompt_embeds are computed in advance.
508
+ cfg_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
509
+ coarse_prompt_embeds = cfg_prompt_embeds
510
+ fine_prompt_embeds = cfg_prompt_embeds
511
+
512
+ # 7. Prepare timesteps
513
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
514
+ timesteps = self.scheduler.timesteps
515
+
516
+ # 8. Prepare latent variables
517
+ num_channels_latents = self.unet.config.in_channels
518
+ latents = self.prepare_latents(
519
+ batch_size * num_images_per_prompt,
520
+ num_channels_latents,
521
+ height,
522
+ width,
523
+ self.dtype,
524
+ device,
525
+ generator,
526
+ latents,
527
+ )
528
+
529
+ # {'eta': 0.0, 'generator': None}. eta is 0 for DDIM.
530
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
531
+ cross_attention_kwargs = {}
532
+
533
+ # 9. Denoising loop
534
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
535
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
536
+ for i, t in enumerate(timesteps):
537
+ latent_model_input = (
538
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
539
+ )
540
+ # DDIM doesn't scale latent_model_input.
541
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
542
+
543
+ if i <= start_merge_step:
544
+ current_prompt_embeds = coarse_prompt_embeds
545
+ else:
546
+ current_prompt_embeds = fine_prompt_embeds
547
+
548
+ # predict the noise residual
549
+ noise_pred = self.unet(
550
+ latent_model_input,
551
+ t,
552
+ encoder_hidden_states=current_prompt_embeds,
553
+ cross_attention_kwargs=cross_attention_kwargs,
554
+ ).sample
555
+
556
+ # perform guidance
557
+ if do_classifier_free_guidance:
558
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
559
+ noise_pred = noise_pred_uncond + guidance_scale * (
560
+ noise_pred_text - noise_pred_uncond
561
+ )
562
+ else:
563
+ assert 0, "Not Implemented"
564
+
565
+ # compute the previous noisy sample x_t -> x_t-1
566
+ latents = self.scheduler.step(
567
+ noise_pred, t, latents, **extra_step_kwargs
568
+ ).prev_sample
569
+
570
+ # call the callback, if provided
571
+ if i == len(timesteps) - 1 or \
572
+ ( (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 ):
573
+ progress_bar.update()
574
+ if callback is not None and i % callback_steps == 0:
575
+ callback(i, t, latents)
576
+
577
+ if output_type == "latent":
578
+ image = latents
579
+ elif output_type == "pil":
580
+ # 9.1 Post-processing
581
+ image = self.decode_latents(latents)
582
+ # 9.3 Convert to PIL
583
+ image = self.numpy_to_pil(image)
584
+ else:
585
+ # 9.1 Post-processing
586
+ image = self.decode_latents(latents)
587
+
588
+ # Offload last model to CPU
589
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
590
+ self.final_offload_hook.offload()
591
+
592
+ if not return_dict:
593
+ return (image, None)
594
+
595
+ return StableDiffusionPipelineOutput(
596
+ images=image, nsfw_content_detected=None
597
+ )
598
+
599
+
600
+
601
+
602
+
603
+
604
+
605
+