Spaces:
Sleeping
Sleeping
adaface-neurips
commited on
Commit
·
3736ac5
1
Parent(s):
b80f000
Reboot
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- .gitignore +10 -0
- ConsistentID/.gitattributes +38 -0
- ConsistentID/.gitignore +5 -0
- ConsistentID/LICENSE +21 -0
- ConsistentID/README.md +13 -0
- ConsistentID/__init__.py +0 -0
- ConsistentID/app.py +168 -0
- ConsistentID/images/templates/3f8d901770014c1b8f7f261971f0e92.png +3 -0
- ConsistentID/images/templates/6577b962b6346df03fea83211daaf48.png +0 -0
- ConsistentID/images/templates/75583964a834abe33b72f52b1a98e84.png +3 -0
- ConsistentID/images/templates/c9fe4c2d5ddbc5670dde47fc465c48b.jpg +0 -0
- ConsistentID/lib/BiSeNet/6.jpg +0 -0
- ConsistentID/lib/BiSeNet/__init__.py +2 -0
- ConsistentID/lib/BiSeNet/evaluate.py +95 -0
- ConsistentID/lib/BiSeNet/face_dataset.py +106 -0
- ConsistentID/lib/BiSeNet/hair.png +0 -0
- ConsistentID/lib/BiSeNet/logger.py +23 -0
- ConsistentID/lib/BiSeNet/loss.py +72 -0
- ConsistentID/lib/BiSeNet/makeup.py +129 -0
- ConsistentID/lib/BiSeNet/makeup/116_1.png +0 -0
- ConsistentID/lib/BiSeNet/makeup/116_3.png +0 -0
- ConsistentID/lib/BiSeNet/makeup/116_lip_ori.png +0 -0
- ConsistentID/lib/BiSeNet/makeup/116_ori.png +0 -0
- ConsistentID/lib/BiSeNet/model.py +282 -0
- ConsistentID/lib/BiSeNet/modules/__init__.py +5 -0
- ConsistentID/lib/BiSeNet/modules/bn.py +130 -0
- ConsistentID/lib/BiSeNet/modules/deeplab.py +84 -0
- ConsistentID/lib/BiSeNet/modules/dense.py +42 -0
- ConsistentID/lib/BiSeNet/modules/functions.py +234 -0
- ConsistentID/lib/BiSeNet/modules/misc.py +21 -0
- ConsistentID/lib/BiSeNet/modules/residual.py +88 -0
- ConsistentID/lib/BiSeNet/modules/src/checks.h +15 -0
- ConsistentID/lib/BiSeNet/modules/src/inplace_abn.cpp +95 -0
- ConsistentID/lib/BiSeNet/modules/src/inplace_abn.h +88 -0
- ConsistentID/lib/BiSeNet/modules/src/inplace_abn_cpu.cpp +119 -0
- ConsistentID/lib/BiSeNet/modules/src/inplace_abn_cuda.cu +333 -0
- ConsistentID/lib/BiSeNet/modules/src/inplace_abn_cuda_half.cu +275 -0
- ConsistentID/lib/BiSeNet/modules/src/utils/checks.h +15 -0
- ConsistentID/lib/BiSeNet/modules/src/utils/common.h +49 -0
- ConsistentID/lib/BiSeNet/modules/src/utils/cuda.cuh +71 -0
- ConsistentID/lib/BiSeNet/optimizer.py +69 -0
- ConsistentID/lib/BiSeNet/prepropess_data.py +38 -0
- ConsistentID/lib/BiSeNet/resnet.py +109 -0
- ConsistentID/lib/BiSeNet/test.py +90 -0
- ConsistentID/lib/BiSeNet/train.py +179 -0
- ConsistentID/lib/BiSeNet/transform.py +129 -0
- ConsistentID/lib/attention.py +287 -0
- ConsistentID/lib/functions.py +606 -0
- ConsistentID/lib/pipeline_ConsistentID.py +605 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
models/ensemble/ar18-unet/diffusion_pytorch_model.safetensors filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
models/awportrait/*
|
2 |
+
models/awportrait
|
3 |
+
__pycache__/*
|
4 |
+
__pycache__
|
5 |
+
samples-ada/*
|
6 |
+
samples-ada
|
7 |
+
models/ensemble/awp14-unet/*
|
8 |
+
models/ensemble/awp14-unet
|
9 |
+
.gradio/certificate.pem
|
10 |
+
|
ConsistentID/.gitattributes
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
images/templates/3f8d901770014c1b8f7f261971f0e92.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
images/templates/75583964a834abe33b72f52b1a98e84.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
models/LLaVA/images/demo_cli.gif filter=lfs diff=lfs merge=lfs -text
|
ConsistentID/.gitignore
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__/*
|
2 |
+
__pycache__
|
3 |
+
/*.png
|
4 |
+
models/insightface
|
5 |
+
models/Realistic_Vision*
|
ConsistentID/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2024 Jiehui Huang
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
ConsistentID/README.md
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: ConsistentID
|
3 |
+
emoji: 🔥
|
4 |
+
colorFrom: yellow
|
5 |
+
colorTo: yellow
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 4.37.2
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: apache-2.0
|
11 |
+
---
|
12 |
+
|
13 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
ConsistentID/__init__.py
ADDED
File without changes
|
ConsistentID/app.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
import glob
|
5 |
+
import spaces
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
from PIL import Image
|
9 |
+
from diffusers.utils import load_image
|
10 |
+
from diffusers import EulerDiscreteScheduler
|
11 |
+
from ConsistentID.lib.pipeline_ConsistentID import ConsistentIDPipeline
|
12 |
+
import argparse
|
13 |
+
parser = argparse.ArgumentParser()
|
14 |
+
parser.add_argument('--base_model_path', type=str,
|
15 |
+
default="models/Realistic_Vision_V4.0_noVAE")
|
16 |
+
parser.add_argument('--gpu', type=int, default=0)
|
17 |
+
args = parser.parse_args()
|
18 |
+
|
19 |
+
device = f"cuda:{args.gpu}"
|
20 |
+
|
21 |
+
### Load base model
|
22 |
+
pipe = ConsistentIDPipeline.from_pretrained(
|
23 |
+
args.base_model_path,
|
24 |
+
torch_dtype=torch.float16,
|
25 |
+
)
|
26 |
+
|
27 |
+
### Load consistentID_model checkpoint
|
28 |
+
pipe.load_ConsistentID_model(
|
29 |
+
consistentID_weight_path="./models/ConsistentID-v1.bin",
|
30 |
+
bise_net_weight_path="./models/BiSeNet_pretrained_for_ConsistentID.pth",
|
31 |
+
)
|
32 |
+
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
|
33 |
+
pipe = pipe.to(device, torch.float16)
|
34 |
+
|
35 |
+
@spaces.GPU
|
36 |
+
def process(selected_template_images, custom_image, prompt,
|
37 |
+
negative_prompt, prompt_selected, model_selected_tab,
|
38 |
+
prompt_selected_tab, guidance_scale, width, height, merge_steps, seed_set):
|
39 |
+
|
40 |
+
# The gradio UI only supports one image at a time.
|
41 |
+
if model_selected_tab==0:
|
42 |
+
subj_images = load_image(Image.open(selected_template_images))
|
43 |
+
else:
|
44 |
+
subj_images = load_image(Image.fromarray(custom_image))
|
45 |
+
|
46 |
+
if prompt_selected_tab==0:
|
47 |
+
prompt = prompt_selected
|
48 |
+
negative_prompt = ""
|
49 |
+
|
50 |
+
# hyper-parameter
|
51 |
+
num_steps = 50
|
52 |
+
seed_set = torch.randint(0, 1000, (1,)).item()
|
53 |
+
# merge_steps = 30
|
54 |
+
|
55 |
+
if prompt == "":
|
56 |
+
prompt = "A man, in a forest"
|
57 |
+
prompt = "A man, with backpack, in a raining tropical forest, adventuring, holding a flashlight, in mist, seeking animals"
|
58 |
+
prompt = "A person, in a sowm, wearing santa hat and a scarf, with a cottage behind"
|
59 |
+
else:
|
60 |
+
#prompt=Enhance_prompt(prompt, Image.new('RGB', (200, 200), color = 'white'))
|
61 |
+
print(prompt)
|
62 |
+
|
63 |
+
if negative_prompt == "":
|
64 |
+
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality, blurry"
|
65 |
+
|
66 |
+
#Extend Prompt
|
67 |
+
#prompt = "cinematic photo," + prompt + ", 50mm photograph, half-length portrait, film, bokeh, professional, 4k, highly detailed"
|
68 |
+
#print(prompt)
|
69 |
+
|
70 |
+
negtive_prompt_group="((cross-eye)),((cross-eyed)),(((NFSW))),(nipple),((((ugly)))), (((duplicate))), ((morbid)), ((mutilated)), [out of frame], extra fingers, mutated hands, ((poorly drawn hands)), ((poorly drawn face)), (((mutation))), (((deformed))), ((ugly)), blurry, ((bad anatomy)), (((bad proportions))), ((extra limbs)), cloned face, (((disfigured))). out of frame, ugly, extra limbs, (bad anatomy), gross proportions, (malformed limbs), ((missing arms)), ((missing legs)), (((extra arms))), (((extra legs))), mutated hands, (fused fingers), (too many fingers), (((long neck)))"
|
71 |
+
negative_prompt = negative_prompt + negtive_prompt_group
|
72 |
+
|
73 |
+
# seed = torch.randint(0, 1000, (1,)).item()
|
74 |
+
generator = torch.Generator(device=device).manual_seed(seed_set)
|
75 |
+
|
76 |
+
images = pipe(
|
77 |
+
prompt=prompt,
|
78 |
+
width=width,
|
79 |
+
height=height,
|
80 |
+
input_subj_image_objs=subj_images,
|
81 |
+
negative_prompt=negative_prompt,
|
82 |
+
num_images_per_prompt=1,
|
83 |
+
num_inference_steps=num_steps,
|
84 |
+
guidance_scale=guidance_scale,
|
85 |
+
start_merge_step=merge_steps,
|
86 |
+
generator=generator,
|
87 |
+
).images[0]
|
88 |
+
|
89 |
+
return np.array(images)
|
90 |
+
|
91 |
+
# Gets the templates
|
92 |
+
preset_template = glob.glob("./images/templates/*.png")
|
93 |
+
preset_template = preset_template + glob.glob("./images/templates/*.jpg")
|
94 |
+
|
95 |
+
with gr.Blocks(title="ConsistentID Demo") as demo:
|
96 |
+
gr.Markdown("# ConsistentID Demo")
|
97 |
+
gr.Markdown("\
|
98 |
+
Put the reference figure to be redrawn into the box below (There is a small probability of referensing failure. You can submit it repeatedly)")
|
99 |
+
gr.Markdown("\
|
100 |
+
If you find our work interesting, please leave a star in GitHub for us!<br>\
|
101 |
+
https://github.com/JackAILab/ConsistentID")
|
102 |
+
with gr.Row():
|
103 |
+
with gr.Column():
|
104 |
+
model_selected_tab = gr.State(0)
|
105 |
+
with gr.TabItem("template images") as template_images_tab:
|
106 |
+
template_gallery_list = [(i, i) for i in preset_template]
|
107 |
+
gallery = gr.Gallery(template_gallery_list,columns=[4], rows=[2], object_fit="contain", height="auto",show_label=False)
|
108 |
+
|
109 |
+
def select_function(evt: gr.SelectData):
|
110 |
+
return preset_template[evt.index]
|
111 |
+
|
112 |
+
selected_template_images = gr.Text(show_label=False, visible=False, placeholder="Selected")
|
113 |
+
gallery.select(select_function, None, selected_template_images)
|
114 |
+
with gr.TabItem("Upload Image") as upload_image_tab:
|
115 |
+
custom_image = gr.Image(label="Upload Image")
|
116 |
+
|
117 |
+
model_selected_tabs = [template_images_tab, upload_image_tab]
|
118 |
+
for i, tab in enumerate(model_selected_tabs):
|
119 |
+
tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[model_selected_tab])
|
120 |
+
|
121 |
+
with gr.Column():
|
122 |
+
prompt_selected_tab = gr.State(0)
|
123 |
+
with gr.TabItem("template prompts") as template_prompts_tab:
|
124 |
+
prompt_selected = gr.Dropdown(value="A person, police officer, half body shot", elem_id='dropdown', choices=[
|
125 |
+
"A woman in a wedding dress",
|
126 |
+
"A woman, queen, in a gorgeous palace",
|
127 |
+
"A man sitting at the beach with sunset",
|
128 |
+
"A person, police officer, half body shot",
|
129 |
+
"A man, sailor, in a boat above ocean",
|
130 |
+
"A women wearing headphone, listening music",
|
131 |
+
"A man, firefighter, half body shot"], label=f"prepared prompts")
|
132 |
+
|
133 |
+
with gr.TabItem("custom prompt") as custom_prompt_tab:
|
134 |
+
prompt = gr.Textbox(label="prompt",placeholder="A man/woman wearing a santa hat")
|
135 |
+
nagetive_prompt = gr.Textbox(label="negative prompt",placeholder="monochrome, lowres, bad anatomy, worst quality, low quality, blurry")
|
136 |
+
|
137 |
+
prompt_selected_tabs = [template_prompts_tab, custom_prompt_tab]
|
138 |
+
for i, tab in enumerate(prompt_selected_tabs):
|
139 |
+
tab.select(fn=lambda tabnum=i: tabnum, inputs=[], outputs=[prompt_selected_tab])
|
140 |
+
|
141 |
+
guidance_scale = gr.Slider(
|
142 |
+
label="Guidance scale",
|
143 |
+
minimum=1.0,
|
144 |
+
maximum=10.0,
|
145 |
+
step=1.0,
|
146 |
+
value=5.0,
|
147 |
+
)
|
148 |
+
|
149 |
+
width = gr.Slider(label="image width",minimum=256,maximum=768,value=512,step=8)
|
150 |
+
height = gr.Slider(label="image height",minimum=256,maximum=768,value=512,step=8)
|
151 |
+
width.release(lambda x,y: min(1280-x,y), inputs=[width,height], outputs=[height])
|
152 |
+
height.release(lambda x,y: min(1280-y,x), inputs=[width,height], outputs=[width])
|
153 |
+
merge_steps = gr.Slider(label="step starting to merge facial details(30 is recommended)",minimum=10,maximum=50,value=30,step=1)
|
154 |
+
seed_set = gr.Slider(label="set the random seed for different results",minimum=1,maximum=2147483647,value=2024,step=1)
|
155 |
+
|
156 |
+
btn = gr.Button("Run")
|
157 |
+
with gr.Column():
|
158 |
+
out = gr.Image(label="Output")
|
159 |
+
gr.Markdown('''
|
160 |
+
N.B.:<br/>
|
161 |
+
- If the proportion of face in the image is too small, the probability of an error will be slightly higher, and the similarity will also significantly decrease.)
|
162 |
+
- At the same time, use prompt with \"man\" or \"woman\" instead of \"person\" as much as possible, as that may cause the model to be confused whether the protagonist is male or female.
|
163 |
+
- Due to insufficient graphics memory on the demo server, there is an upper limit on the resolution for generating samples. We will support the generation of SDXL as soon as possible<br/><br/>
|
164 |
+
''')
|
165 |
+
btn.click(fn=process, inputs=[selected_template_images, custom_image,prompt, nagetive_prompt, prompt_selected,
|
166 |
+
model_selected_tab, prompt_selected_tab, guidance_scale, width, height, merge_steps, seed_set], outputs=out)
|
167 |
+
|
168 |
+
demo.launch(server_name='0.0.0.0', ssl_verify=False)
|
ConsistentID/images/templates/3f8d901770014c1b8f7f261971f0e92.png
ADDED
Git LFS Details
|
ConsistentID/images/templates/6577b962b6346df03fea83211daaf48.png
ADDED
ConsistentID/images/templates/75583964a834abe33b72f52b1a98e84.png
ADDED
Git LFS Details
|
ConsistentID/images/templates/c9fe4c2d5ddbc5670dde47fc465c48b.jpg
ADDED
ConsistentID/lib/BiSeNet/6.jpg
ADDED
ConsistentID/lib/BiSeNet/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
#__init__.py
|
2 |
+
# from BiSeNet.model import *
|
ConsistentID/lib/BiSeNet/evaluate.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
from logger import setup_logger
|
5 |
+
import BiSeNet
|
6 |
+
from face_dataset import FaceMask
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from torch.utils.data import DataLoader
|
11 |
+
import torch.nn.functional as F
|
12 |
+
import torch.distributed as dist
|
13 |
+
|
14 |
+
import os
|
15 |
+
import os.path as osp
|
16 |
+
import logging
|
17 |
+
import time
|
18 |
+
import numpy as np
|
19 |
+
from tqdm import tqdm
|
20 |
+
import math
|
21 |
+
from PIL import Image
|
22 |
+
import torchvision.transforms as transforms
|
23 |
+
import cv2
|
24 |
+
|
25 |
+
def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg'):
|
26 |
+
# Colors for all 20 parts
|
27 |
+
part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
|
28 |
+
[255, 0, 85], [255, 0, 170],
|
29 |
+
[0, 255, 0], [85, 255, 0], [170, 255, 0],
|
30 |
+
[0, 255, 85], [0, 255, 170],
|
31 |
+
[0, 0, 255], [85, 0, 255], [170, 0, 255],
|
32 |
+
[0, 85, 255], [0, 170, 255],
|
33 |
+
[255, 255, 0], [255, 255, 85], [255, 255, 170],
|
34 |
+
[255, 0, 255], [255, 85, 255], [255, 170, 255],
|
35 |
+
[0, 255, 255], [85, 255, 255], [170, 255, 255]]
|
36 |
+
|
37 |
+
im = np.array(im)
|
38 |
+
vis_im = im.copy().astype(np.uint8)
|
39 |
+
vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
|
40 |
+
vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
|
41 |
+
vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255
|
42 |
+
|
43 |
+
num_of_class = np.max(vis_parsing_anno)
|
44 |
+
|
45 |
+
for pi in range(1, num_of_class + 1):
|
46 |
+
index = np.where(vis_parsing_anno == pi)
|
47 |
+
vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi]
|
48 |
+
|
49 |
+
vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
|
50 |
+
# print(vis_parsing_anno_color.shape, vis_im.shape)
|
51 |
+
vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0)
|
52 |
+
|
53 |
+
# Save result or not
|
54 |
+
if save_im:
|
55 |
+
cv2.imwrite(save_path, vis_im, [int(cv2.IMWRITE_JPEG_QUALITY), 100])
|
56 |
+
|
57 |
+
# return vis_im
|
58 |
+
|
59 |
+
def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'):
|
60 |
+
|
61 |
+
if not os.path.exists(respth):
|
62 |
+
os.makedirs(respth)
|
63 |
+
|
64 |
+
n_classes = 19
|
65 |
+
net = BiSeNet(n_classes=n_classes)
|
66 |
+
net.cuda()
|
67 |
+
save_pth = osp.join('res/cp', cp)
|
68 |
+
net.load_state_dict(torch.load(save_pth))
|
69 |
+
net.eval()
|
70 |
+
|
71 |
+
to_tensor = transforms.Compose([
|
72 |
+
transforms.ToTensor(),
|
73 |
+
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
74 |
+
])
|
75 |
+
with torch.no_grad():
|
76 |
+
for image_path in os.listdir(dspth):
|
77 |
+
img = Image.open(osp.join(dspth, image_path))
|
78 |
+
image = img.resize((512, 512), Image.BILINEAR)
|
79 |
+
img = to_tensor(image)
|
80 |
+
img = torch.unsqueeze(img, 0)
|
81 |
+
img = img.cuda()
|
82 |
+
out = net(img)[0]
|
83 |
+
parsing = out.squeeze(0).cpu().numpy().argmax(0)
|
84 |
+
|
85 |
+
vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path))
|
86 |
+
|
87 |
+
|
88 |
+
|
89 |
+
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
+
if __name__ == "__main__":
|
94 |
+
setup_logger('./res')
|
95 |
+
evaluate()
|
ConsistentID/lib/BiSeNet/face_dataset.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
import torchvision.transforms as transforms
|
7 |
+
|
8 |
+
import os.path as osp
|
9 |
+
import os
|
10 |
+
from PIL import Image
|
11 |
+
import numpy as np
|
12 |
+
import json
|
13 |
+
import cv2
|
14 |
+
|
15 |
+
from transform import *
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
class FaceMask(Dataset):
|
20 |
+
def __init__(self, rootpth, cropsize=(640, 480), mode='train', *args, **kwargs):
|
21 |
+
super(FaceMask, self).__init__(*args, **kwargs)
|
22 |
+
assert mode in ('train', 'val', 'test')
|
23 |
+
self.mode = mode
|
24 |
+
self.ignore_lb = 255
|
25 |
+
self.rootpth = rootpth
|
26 |
+
|
27 |
+
self.imgs = os.listdir(os.path.join(self.rootpth, 'CelebA-HQ-img'))
|
28 |
+
|
29 |
+
# pre-processing
|
30 |
+
self.to_tensor = transforms.Compose([
|
31 |
+
transforms.ToTensor(),
|
32 |
+
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
33 |
+
])
|
34 |
+
self.trans_train = Compose([
|
35 |
+
ColorJitter(
|
36 |
+
brightness=0.5,
|
37 |
+
contrast=0.5,
|
38 |
+
saturation=0.5),
|
39 |
+
HorizontalFlip(),
|
40 |
+
RandomScale((0.75, 1.0, 1.25, 1.5, 1.75, 2.0)),
|
41 |
+
RandomCrop(cropsize)
|
42 |
+
])
|
43 |
+
|
44 |
+
def __getitem__(self, idx):
|
45 |
+
impth = self.imgs[idx]
|
46 |
+
img = Image.open(osp.join(self.rootpth, 'CelebA-HQ-img', impth))
|
47 |
+
img = img.resize((512, 512), Image.BILINEAR)
|
48 |
+
label = Image.open(osp.join(self.rootpth, 'mask', impth[:-3]+'png')).convert('P')
|
49 |
+
# print(np.unique(np.array(label)))
|
50 |
+
if self.mode == 'train':
|
51 |
+
im_lb = dict(im=img, lb=label)
|
52 |
+
im_lb = self.trans_train(im_lb)
|
53 |
+
img, label = im_lb['im'], im_lb['lb']
|
54 |
+
img = self.to_tensor(img)
|
55 |
+
label = np.array(label).astype(np.int64)[np.newaxis, :]
|
56 |
+
return img, label
|
57 |
+
|
58 |
+
def __len__(self):
|
59 |
+
return len(self.imgs)
|
60 |
+
|
61 |
+
|
62 |
+
if __name__ == "__main__":
|
63 |
+
face_data = '/home/zll/data/CelebAMask-HQ/CelebA-HQ-img'
|
64 |
+
face_sep_mask = '/home/zll/data/CelebAMask-HQ/CelebAMask-HQ-mask-anno'
|
65 |
+
mask_path = '/home/zll/data/CelebAMask-HQ/mask'
|
66 |
+
counter = 0
|
67 |
+
total = 0
|
68 |
+
for i in range(15):
|
69 |
+
# files = os.listdir(osp.join(face_sep_mask, str(i)))
|
70 |
+
|
71 |
+
atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r',
|
72 |
+
'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat']
|
73 |
+
|
74 |
+
for j in range(i*2000, (i+1)*2000):
|
75 |
+
|
76 |
+
mask = np.zeros((512, 512))
|
77 |
+
|
78 |
+
for l, att in enumerate(atts, 1):
|
79 |
+
total += 1
|
80 |
+
file_name = ''.join([str(j).rjust(5, '0'), '_', att, '.png'])
|
81 |
+
path = osp.join(face_sep_mask, str(i), file_name)
|
82 |
+
|
83 |
+
if os.path.exists(path):
|
84 |
+
counter += 1
|
85 |
+
sep_mask = np.array(Image.open(path).convert('P'))
|
86 |
+
# print(np.unique(sep_mask))
|
87 |
+
|
88 |
+
mask[sep_mask == 225] = l
|
89 |
+
cv2.imwrite('{}/{}.png'.format(mask_path, j), mask)
|
90 |
+
print(j)
|
91 |
+
|
92 |
+
print(counter, total)
|
93 |
+
|
94 |
+
|
95 |
+
|
96 |
+
|
97 |
+
|
98 |
+
|
99 |
+
|
100 |
+
|
101 |
+
|
102 |
+
|
103 |
+
|
104 |
+
|
105 |
+
|
106 |
+
|
ConsistentID/lib/BiSeNet/hair.png
ADDED
ConsistentID/lib/BiSeNet/logger.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
import os.path as osp
|
6 |
+
import time
|
7 |
+
import sys
|
8 |
+
import logging
|
9 |
+
|
10 |
+
import torch.distributed as dist
|
11 |
+
|
12 |
+
|
13 |
+
def setup_logger(logpth):
|
14 |
+
logfile = 'BiSeNet-{}.log'.format(time.strftime('%Y-%m-%d-%H-%M-%S'))
|
15 |
+
logfile = osp.join(logpth, logfile)
|
16 |
+
FORMAT = '%(levelname)s %(filename)s(%(lineno)d): %(message)s'
|
17 |
+
log_level = logging.INFO
|
18 |
+
if dist.is_initialized() and not dist.get_rank()==0:
|
19 |
+
log_level = logging.ERROR
|
20 |
+
logging.basicConfig(level=log_level, format=FORMAT, filename=logfile)
|
21 |
+
logging.root.addHandler(logging.StreamHandler())
|
22 |
+
|
23 |
+
|
ConsistentID/lib/BiSeNet/loss.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
class OhemCELoss(nn.Module):
|
10 |
+
def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs):
|
11 |
+
super(OhemCELoss, self).__init__()
|
12 |
+
self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).cuda()
|
13 |
+
self.n_min = n_min
|
14 |
+
self.ignore_lb = ignore_lb
|
15 |
+
self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')
|
16 |
+
|
17 |
+
def forward(self, logits, labels):
|
18 |
+
N, C, H, W = logits.size()
|
19 |
+
loss = self.criteria(logits, labels).view(-1)
|
20 |
+
loss, _ = torch.sort(loss, descending=True)
|
21 |
+
if loss[self.n_min] > self.thresh:
|
22 |
+
loss = loss[loss>self.thresh]
|
23 |
+
else:
|
24 |
+
loss = loss[:self.n_min]
|
25 |
+
return torch.mean(loss)
|
26 |
+
|
27 |
+
|
28 |
+
class SoftmaxFocalLoss(nn.Module):
|
29 |
+
def __init__(self, gamma, ignore_lb=255, *args, **kwargs):
|
30 |
+
super(SoftmaxFocalLoss, self).__init__()
|
31 |
+
self.gamma = gamma
|
32 |
+
self.nll = nn.NLLLoss(ignore_index=ignore_lb)
|
33 |
+
|
34 |
+
def forward(self, logits, labels):
|
35 |
+
scores = F.softmax(logits, dim=1)
|
36 |
+
factor = torch.pow(1.-scores, self.gamma)
|
37 |
+
log_score = F.log_softmax(logits, dim=1)
|
38 |
+
log_score = factor * log_score
|
39 |
+
loss = self.nll(log_score, labels)
|
40 |
+
return loss
|
41 |
+
|
42 |
+
|
43 |
+
if __name__ == '__main__':
|
44 |
+
torch.manual_seed(15)
|
45 |
+
criteria1 = OhemCELoss(thresh=0.7, n_min=16*20*20//16).cuda()
|
46 |
+
criteria2 = OhemCELoss(thresh=0.7, n_min=16*20*20//16).cuda()
|
47 |
+
net1 = nn.Sequential(
|
48 |
+
nn.Conv2d(3, 19, kernel_size=3, stride=2, padding=1),
|
49 |
+
)
|
50 |
+
net1.cuda()
|
51 |
+
net1.train()
|
52 |
+
net2 = nn.Sequential(
|
53 |
+
nn.Conv2d(3, 19, kernel_size=3, stride=2, padding=1),
|
54 |
+
)
|
55 |
+
net2.cuda()
|
56 |
+
net2.train()
|
57 |
+
|
58 |
+
with torch.no_grad():
|
59 |
+
inten = torch.randn(16, 3, 20, 20).cuda()
|
60 |
+
lbs = torch.randint(0, 19, [16, 20, 20]).cuda()
|
61 |
+
lbs[1, :, :] = 255
|
62 |
+
|
63 |
+
logits1 = net1(inten)
|
64 |
+
logits1 = F.interpolate(logits1, inten.size()[2:], mode='bilinear')
|
65 |
+
logits2 = net2(inten)
|
66 |
+
logits2 = F.interpolate(logits2, inten.size()[2:], mode='bilinear')
|
67 |
+
|
68 |
+
loss1 = criteria1(logits1, lbs)
|
69 |
+
loss2 = criteria2(logits2, lbs)
|
70 |
+
loss = loss1 + loss2
|
71 |
+
print(loss.detach().cpu())
|
72 |
+
loss.backward()
|
ConsistentID/lib/BiSeNet/makeup.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
from skimage.filters import gaussian
|
4 |
+
|
5 |
+
|
6 |
+
def sharpen(img):
|
7 |
+
img = img * 1.0
|
8 |
+
gauss_out = gaussian(img, sigma=5, multichannel=True)
|
9 |
+
|
10 |
+
alpha = 1.5
|
11 |
+
img_out = (img - gauss_out) * alpha + img
|
12 |
+
|
13 |
+
img_out = img_out / 255.0
|
14 |
+
|
15 |
+
mask_1 = img_out < 0
|
16 |
+
mask_2 = img_out > 1
|
17 |
+
|
18 |
+
img_out = img_out * (1 - mask_1)
|
19 |
+
img_out = img_out * (1 - mask_2) + mask_2
|
20 |
+
img_out = np.clip(img_out, 0, 1)
|
21 |
+
img_out = img_out * 255
|
22 |
+
return np.array(img_out, dtype=np.uint8)
|
23 |
+
|
24 |
+
|
25 |
+
def hair(image, parsing, part=17, color=[230, 50, 20]):
|
26 |
+
b, g, r = color #[10, 50, 250] # [10, 250, 10]
|
27 |
+
tar_color = np.zeros_like(image)
|
28 |
+
tar_color[:, :, 0] = b
|
29 |
+
tar_color[:, :, 1] = g
|
30 |
+
tar_color[:, :, 2] = r
|
31 |
+
|
32 |
+
image_hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
|
33 |
+
tar_hsv = cv2.cvtColor(tar_color, cv2.COLOR_BGR2HSV)
|
34 |
+
|
35 |
+
if part == 12 or part == 13:
|
36 |
+
image_hsv[:, :, 0:2] = tar_hsv[:, :, 0:2]
|
37 |
+
else:
|
38 |
+
image_hsv[:, :, 0:1] = tar_hsv[:, :, 0:1]
|
39 |
+
|
40 |
+
changed = cv2.cvtColor(image_hsv, cv2.COLOR_HSV2BGR)
|
41 |
+
|
42 |
+
if part == 17:
|
43 |
+
changed = sharpen(changed)
|
44 |
+
|
45 |
+
changed[parsing != part] = image[parsing != part]
|
46 |
+
# changed = cv2.resize(changed, (512, 512))
|
47 |
+
return changed
|
48 |
+
|
49 |
+
#
|
50 |
+
# def lip(image, parsing, part=17, color=[230, 50, 20]):
|
51 |
+
# b, g, r = color #[10, 50, 250] # [10, 250, 10]
|
52 |
+
# tar_color = np.zeros_like(image)
|
53 |
+
# tar_color[:, :, 0] = b
|
54 |
+
# tar_color[:, :, 1] = g
|
55 |
+
# tar_color[:, :, 2] = r
|
56 |
+
#
|
57 |
+
# image_lab = cv2.cvtColor(image, cv2.COLOR_BGR2Lab)
|
58 |
+
# il, ia, ib = cv2.split(image_lab)
|
59 |
+
#
|
60 |
+
# tar_lab = cv2.cvtColor(tar_color, cv2.COLOR_BGR2Lab)
|
61 |
+
# tl, ta, tb = cv2.split(tar_lab)
|
62 |
+
#
|
63 |
+
# image_lab[:, :, 0] = np.clip(il - np.mean(il) + tl, 0, 100)
|
64 |
+
# image_lab[:, :, 1] = np.clip(ia - np.mean(ia) + ta, -127, 128)
|
65 |
+
# image_lab[:, :, 2] = np.clip(ib - np.mean(ib) + tb, -127, 128)
|
66 |
+
#
|
67 |
+
#
|
68 |
+
# changed = cv2.cvtColor(image_lab, cv2.COLOR_Lab2BGR)
|
69 |
+
#
|
70 |
+
# if part == 17:
|
71 |
+
# changed = sharpen(changed)
|
72 |
+
#
|
73 |
+
# changed[parsing != part] = image[parsing != part]
|
74 |
+
# # changed = cv2.resize(changed, (512, 512))
|
75 |
+
# return changed
|
76 |
+
|
77 |
+
|
78 |
+
if __name__ == '__main__':
|
79 |
+
# 1 face
|
80 |
+
# 10 nose
|
81 |
+
# 11 teeth
|
82 |
+
# 12 upper lip
|
83 |
+
# 13 lower lip
|
84 |
+
# 17 hair
|
85 |
+
num = 116
|
86 |
+
table = {
|
87 |
+
'hair': 17,
|
88 |
+
'upper_lip': 12,
|
89 |
+
'lower_lip': 13
|
90 |
+
}
|
91 |
+
image_path = '/home/zll/data/CelebAMask-HQ/test-img/{}.jpg'.format(num)
|
92 |
+
parsing_path = 'res/test_res/{}.png'.format(num)
|
93 |
+
|
94 |
+
image = cv2.imread(image_path)
|
95 |
+
ori = image.copy()
|
96 |
+
parsing = np.array(cv2.imread(parsing_path, 0))
|
97 |
+
parsing = cv2.resize(parsing, image.shape[0:2], interpolation=cv2.INTER_NEAREST)
|
98 |
+
|
99 |
+
parts = [table['hair'], table['upper_lip'], table['lower_lip']]
|
100 |
+
# colors = [[20, 20, 200], [100, 100, 230], [100, 100, 230]]
|
101 |
+
colors = [[100, 200, 100]]
|
102 |
+
for part, color in zip(parts, colors):
|
103 |
+
image = hair(image, parsing, part, color)
|
104 |
+
cv2.imwrite('res/makeup/116_ori.png', cv2.resize(ori, (512, 512)))
|
105 |
+
cv2.imwrite('res/makeup/116_2.png', cv2.resize(image, (512, 512)))
|
106 |
+
|
107 |
+
cv2.imshow('image', cv2.resize(ori, (512, 512)))
|
108 |
+
cv2.imshow('color', cv2.resize(image, (512, 512)))
|
109 |
+
|
110 |
+
# cv2.imshow('image', ori)
|
111 |
+
# cv2.imshow('color', image)
|
112 |
+
|
113 |
+
cv2.waitKey(0)
|
114 |
+
cv2.destroyAllWindows()
|
115 |
+
|
116 |
+
|
117 |
+
|
118 |
+
|
119 |
+
|
120 |
+
|
121 |
+
|
122 |
+
|
123 |
+
|
124 |
+
|
125 |
+
|
126 |
+
|
127 |
+
|
128 |
+
|
129 |
+
|
ConsistentID/lib/BiSeNet/makeup/116_1.png
ADDED
ConsistentID/lib/BiSeNet/makeup/116_3.png
ADDED
ConsistentID/lib/BiSeNet/makeup/116_lip_ori.png
ADDED
ConsistentID/lib/BiSeNet/makeup/116_ori.png
ADDED
ConsistentID/lib/BiSeNet/model.py
ADDED
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
from .resnet import Resnet18
|
10 |
+
# from modules.bn import InPlaceABNSync as BatchNorm2d
|
11 |
+
|
12 |
+
|
13 |
+
class ConvBNReLU(nn.Module):
|
14 |
+
def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
|
15 |
+
super(ConvBNReLU, self).__init__()
|
16 |
+
self.conv = nn.Conv2d(in_chan,
|
17 |
+
out_chan,
|
18 |
+
kernel_size = ks,
|
19 |
+
stride = stride,
|
20 |
+
padding = padding,
|
21 |
+
bias = False)
|
22 |
+
self.bn = nn.BatchNorm2d(out_chan)
|
23 |
+
self.init_weight()
|
24 |
+
|
25 |
+
def forward(self, x):
|
26 |
+
x = self.conv(x)
|
27 |
+
x = F.relu(self.bn(x))
|
28 |
+
return x
|
29 |
+
|
30 |
+
def init_weight(self):
|
31 |
+
for ly in self.children():
|
32 |
+
if isinstance(ly, nn.Conv2d):
|
33 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
34 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
35 |
+
|
36 |
+
class BiSeNetOutput(nn.Module):
|
37 |
+
def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
|
38 |
+
super(BiSeNetOutput, self).__init__()
|
39 |
+
self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
|
40 |
+
self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
|
41 |
+
self.init_weight()
|
42 |
+
|
43 |
+
def forward(self, x):
|
44 |
+
x = self.conv(x)
|
45 |
+
x = self.conv_out(x)
|
46 |
+
return x
|
47 |
+
|
48 |
+
def init_weight(self):
|
49 |
+
for ly in self.children():
|
50 |
+
if isinstance(ly, nn.Conv2d):
|
51 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
52 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
53 |
+
|
54 |
+
def get_params(self):
|
55 |
+
wd_params, nowd_params = [], []
|
56 |
+
for name, module in self.named_modules():
|
57 |
+
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
58 |
+
wd_params.append(module.weight)
|
59 |
+
if not module.bias is None:
|
60 |
+
nowd_params.append(module.bias)
|
61 |
+
elif isinstance(module, nn.BatchNorm2d):
|
62 |
+
nowd_params += list(module.parameters())
|
63 |
+
return wd_params, nowd_params
|
64 |
+
|
65 |
+
|
66 |
+
class AttentionRefinementModule(nn.Module):
|
67 |
+
def __init__(self, in_chan, out_chan, *args, **kwargs):
|
68 |
+
super(AttentionRefinementModule, self).__init__()
|
69 |
+
self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
|
70 |
+
self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
|
71 |
+
self.bn_atten = nn.BatchNorm2d(out_chan)
|
72 |
+
self.sigmoid_atten = nn.Sigmoid()
|
73 |
+
self.init_weight()
|
74 |
+
|
75 |
+
def forward(self, x):
|
76 |
+
feat = self.conv(x)
|
77 |
+
atten = F.avg_pool2d(feat, feat.size()[2:])
|
78 |
+
atten = self.conv_atten(atten)
|
79 |
+
atten = self.bn_atten(atten)
|
80 |
+
atten = self.sigmoid_atten(atten)
|
81 |
+
out = torch.mul(feat, atten)
|
82 |
+
return out
|
83 |
+
|
84 |
+
def init_weight(self):
|
85 |
+
for ly in self.children():
|
86 |
+
if isinstance(ly, nn.Conv2d):
|
87 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
88 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
89 |
+
|
90 |
+
|
91 |
+
class ContextPath(nn.Module):
|
92 |
+
def __init__(self, *args, **kwargs):
|
93 |
+
super(ContextPath, self).__init__()
|
94 |
+
self.resnet = Resnet18()
|
95 |
+
self.arm16 = AttentionRefinementModule(256, 128)
|
96 |
+
self.arm32 = AttentionRefinementModule(512, 128)
|
97 |
+
self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
|
98 |
+
self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
|
99 |
+
self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
|
100 |
+
|
101 |
+
self.init_weight()
|
102 |
+
|
103 |
+
def forward(self, x):
|
104 |
+
H0, W0 = x.size()[2:]
|
105 |
+
feat8, feat16, feat32 = self.resnet(x)
|
106 |
+
H8, W8 = feat8.size()[2:]
|
107 |
+
H16, W16 = feat16.size()[2:]
|
108 |
+
H32, W32 = feat32.size()[2:]
|
109 |
+
|
110 |
+
avg = F.avg_pool2d(feat32, feat32.size()[2:])
|
111 |
+
avg = self.conv_avg(avg)
|
112 |
+
avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
|
113 |
+
|
114 |
+
feat32_arm = self.arm32(feat32)
|
115 |
+
feat32_sum = feat32_arm + avg_up
|
116 |
+
feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
|
117 |
+
feat32_up = self.conv_head32(feat32_up)
|
118 |
+
|
119 |
+
feat16_arm = self.arm16(feat16)
|
120 |
+
feat16_sum = feat16_arm + feat32_up
|
121 |
+
feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
|
122 |
+
feat16_up = self.conv_head16(feat16_up)
|
123 |
+
|
124 |
+
return feat8, feat16_up, feat32_up # x8, x8, x16
|
125 |
+
|
126 |
+
def init_weight(self):
|
127 |
+
for ly in self.children():
|
128 |
+
if isinstance(ly, nn.Conv2d):
|
129 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
130 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
131 |
+
|
132 |
+
def get_params(self):
|
133 |
+
wd_params, nowd_params = [], []
|
134 |
+
for name, module in self.named_modules():
|
135 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
136 |
+
wd_params.append(module.weight)
|
137 |
+
if not module.bias is None:
|
138 |
+
nowd_params.append(module.bias)
|
139 |
+
elif isinstance(module, nn.BatchNorm2d):
|
140 |
+
nowd_params += list(module.parameters())
|
141 |
+
return wd_params, nowd_params
|
142 |
+
|
143 |
+
|
144 |
+
### This is not used, since I replace this with the resnet feature with the same size
|
145 |
+
class SpatialPath(nn.Module):
|
146 |
+
def __init__(self, *args, **kwargs):
|
147 |
+
super(SpatialPath, self).__init__()
|
148 |
+
self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
|
149 |
+
self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
|
150 |
+
self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
|
151 |
+
self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
|
152 |
+
self.init_weight()
|
153 |
+
|
154 |
+
def forward(self, x):
|
155 |
+
feat = self.conv1(x)
|
156 |
+
feat = self.conv2(feat)
|
157 |
+
feat = self.conv3(feat)
|
158 |
+
feat = self.conv_out(feat)
|
159 |
+
return feat
|
160 |
+
|
161 |
+
def init_weight(self):
|
162 |
+
for ly in self.children():
|
163 |
+
if isinstance(ly, nn.Conv2d):
|
164 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
165 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
166 |
+
|
167 |
+
def get_params(self):
|
168 |
+
wd_params, nowd_params = [], []
|
169 |
+
for name, module in self.named_modules():
|
170 |
+
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
171 |
+
wd_params.append(module.weight)
|
172 |
+
if not module.bias is None:
|
173 |
+
nowd_params.append(module.bias)
|
174 |
+
elif isinstance(module, nn.BatchNorm2d):
|
175 |
+
nowd_params += list(module.parameters())
|
176 |
+
return wd_params, nowd_params
|
177 |
+
|
178 |
+
|
179 |
+
class FeatureFusionModule(nn.Module):
|
180 |
+
def __init__(self, in_chan, out_chan, *args, **kwargs):
|
181 |
+
super(FeatureFusionModule, self).__init__()
|
182 |
+
self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
|
183 |
+
self.conv1 = nn.Conv2d(out_chan,
|
184 |
+
out_chan//4,
|
185 |
+
kernel_size = 1,
|
186 |
+
stride = 1,
|
187 |
+
padding = 0,
|
188 |
+
bias = False)
|
189 |
+
self.conv2 = nn.Conv2d(out_chan//4,
|
190 |
+
out_chan,
|
191 |
+
kernel_size = 1,
|
192 |
+
stride = 1,
|
193 |
+
padding = 0,
|
194 |
+
bias = False)
|
195 |
+
self.relu = nn.ReLU(inplace=True)
|
196 |
+
self.sigmoid = nn.Sigmoid()
|
197 |
+
self.init_weight()
|
198 |
+
|
199 |
+
def forward(self, fsp, fcp):
|
200 |
+
fcat = torch.cat([fsp, fcp], dim=1)
|
201 |
+
feat = self.convblk(fcat)
|
202 |
+
atten = F.avg_pool2d(feat, feat.size()[2:])
|
203 |
+
atten = self.conv1(atten)
|
204 |
+
atten = self.relu(atten)
|
205 |
+
atten = self.conv2(atten)
|
206 |
+
atten = self.sigmoid(atten)
|
207 |
+
feat_atten = torch.mul(feat, atten)
|
208 |
+
feat_out = feat_atten + feat
|
209 |
+
return feat_out
|
210 |
+
|
211 |
+
def init_weight(self):
|
212 |
+
for ly in self.children():
|
213 |
+
if isinstance(ly, nn.Conv2d):
|
214 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
215 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
216 |
+
|
217 |
+
def get_params(self):
|
218 |
+
wd_params, nowd_params = [], []
|
219 |
+
for name, module in self.named_modules():
|
220 |
+
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
221 |
+
wd_params.append(module.weight)
|
222 |
+
if not module.bias is None:
|
223 |
+
nowd_params.append(module.bias)
|
224 |
+
elif isinstance(module, nn.BatchNorm2d):
|
225 |
+
nowd_params += list(module.parameters())
|
226 |
+
return wd_params, nowd_params
|
227 |
+
|
228 |
+
|
229 |
+
class BiSeNet(nn.Module):
|
230 |
+
def __init__(self, n_classes, *args, **kwargs):
|
231 |
+
super(BiSeNet, self).__init__()
|
232 |
+
self.cp = ContextPath()
|
233 |
+
## here self.sp is deleted
|
234 |
+
self.ffm = FeatureFusionModule(256, 256)
|
235 |
+
self.conv_out = BiSeNetOutput(256, 256, n_classes)
|
236 |
+
self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
|
237 |
+
self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
|
238 |
+
self.init_weight()
|
239 |
+
|
240 |
+
def forward(self, x):
|
241 |
+
H, W = x.size()[2:]
|
242 |
+
feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
|
243 |
+
feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
|
244 |
+
feat_fuse = self.ffm(feat_sp, feat_cp8)
|
245 |
+
|
246 |
+
feat_out = self.conv_out(feat_fuse)
|
247 |
+
feat_out16 = self.conv_out16(feat_cp8)
|
248 |
+
feat_out32 = self.conv_out32(feat_cp16)
|
249 |
+
|
250 |
+
feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
|
251 |
+
feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
|
252 |
+
feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
|
253 |
+
return feat_out, feat_out16, feat_out32
|
254 |
+
|
255 |
+
def init_weight(self):
|
256 |
+
for ly in self.children():
|
257 |
+
if isinstance(ly, nn.Conv2d):
|
258 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
259 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
260 |
+
|
261 |
+
def get_params(self):
|
262 |
+
wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
|
263 |
+
for name, child in self.named_children():
|
264 |
+
child_wd_params, child_nowd_params = child.get_params()
|
265 |
+
if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
|
266 |
+
lr_mul_wd_params += child_wd_params
|
267 |
+
lr_mul_nowd_params += child_nowd_params
|
268 |
+
else:
|
269 |
+
wd_params += child_wd_params
|
270 |
+
nowd_params += child_nowd_params
|
271 |
+
return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
|
272 |
+
|
273 |
+
|
274 |
+
if __name__ == "__main__":
|
275 |
+
net = BiSeNet(19)
|
276 |
+
net.cuda()
|
277 |
+
net.eval()
|
278 |
+
in_ten = torch.randn(16, 3, 640, 480).cuda()
|
279 |
+
out, out16, out32 = net(in_ten)
|
280 |
+
print(out.shape)
|
281 |
+
|
282 |
+
net.get_params()
|
ConsistentID/lib/BiSeNet/modules/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .bn import ABN, InPlaceABN, InPlaceABNSync
|
2 |
+
from .functions import ACT_RELU, ACT_LEAKY_RELU, ACT_ELU, ACT_NONE
|
3 |
+
from .misc import GlobalAvgPool2d, SingleGPU
|
4 |
+
from .residual import IdentityResidualBlock
|
5 |
+
from .dense import DenseModule
|
ConsistentID/lib/BiSeNet/modules/bn.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as functional
|
4 |
+
|
5 |
+
try:
|
6 |
+
from queue import Queue
|
7 |
+
except ImportError:
|
8 |
+
from Queue import Queue
|
9 |
+
|
10 |
+
from .functions import *
|
11 |
+
|
12 |
+
|
13 |
+
class ABN(nn.Module):
|
14 |
+
"""Activated Batch Normalization
|
15 |
+
|
16 |
+
This gathers a `BatchNorm2d` and an activation function in a single module
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
|
20 |
+
"""Creates an Activated Batch Normalization module
|
21 |
+
|
22 |
+
Parameters
|
23 |
+
----------
|
24 |
+
num_features : int
|
25 |
+
Number of feature channels in the input and output.
|
26 |
+
eps : float
|
27 |
+
Small constant to prevent numerical issues.
|
28 |
+
momentum : float
|
29 |
+
Momentum factor applied to compute running statistics as.
|
30 |
+
affine : bool
|
31 |
+
If `True` apply learned scale and shift transformation after normalization.
|
32 |
+
activation : str
|
33 |
+
Name of the activation functions, one of: `leaky_relu`, `elu` or `none`.
|
34 |
+
slope : float
|
35 |
+
Negative slope for the `leaky_relu` activation.
|
36 |
+
"""
|
37 |
+
super(ABN, self).__init__()
|
38 |
+
self.num_features = num_features
|
39 |
+
self.affine = affine
|
40 |
+
self.eps = eps
|
41 |
+
self.momentum = momentum
|
42 |
+
self.activation = activation
|
43 |
+
self.slope = slope
|
44 |
+
if self.affine:
|
45 |
+
self.weight = nn.Parameter(torch.ones(num_features))
|
46 |
+
self.bias = nn.Parameter(torch.zeros(num_features))
|
47 |
+
else:
|
48 |
+
self.register_parameter('weight', None)
|
49 |
+
self.register_parameter('bias', None)
|
50 |
+
self.register_buffer('running_mean', torch.zeros(num_features))
|
51 |
+
self.register_buffer('running_var', torch.ones(num_features))
|
52 |
+
self.reset_parameters()
|
53 |
+
|
54 |
+
def reset_parameters(self):
|
55 |
+
nn.init.constant_(self.running_mean, 0)
|
56 |
+
nn.init.constant_(self.running_var, 1)
|
57 |
+
if self.affine:
|
58 |
+
nn.init.constant_(self.weight, 1)
|
59 |
+
nn.init.constant_(self.bias, 0)
|
60 |
+
|
61 |
+
def forward(self, x):
|
62 |
+
x = functional.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias,
|
63 |
+
self.training, self.momentum, self.eps)
|
64 |
+
|
65 |
+
if self.activation == ACT_RELU:
|
66 |
+
return functional.relu(x, inplace=True)
|
67 |
+
elif self.activation == ACT_LEAKY_RELU:
|
68 |
+
return functional.leaky_relu(x, negative_slope=self.slope, inplace=True)
|
69 |
+
elif self.activation == ACT_ELU:
|
70 |
+
return functional.elu(x, inplace=True)
|
71 |
+
else:
|
72 |
+
return x
|
73 |
+
|
74 |
+
def __repr__(self):
|
75 |
+
rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
|
76 |
+
' affine={affine}, activation={activation}'
|
77 |
+
if self.activation == "leaky_relu":
|
78 |
+
rep += ', slope={slope})'
|
79 |
+
else:
|
80 |
+
rep += ')'
|
81 |
+
return rep.format(name=self.__class__.__name__, **self.__dict__)
|
82 |
+
|
83 |
+
|
84 |
+
class InPlaceABN(ABN):
|
85 |
+
"""InPlace Activated Batch Normalization"""
|
86 |
+
|
87 |
+
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
|
88 |
+
"""Creates an InPlace Activated Batch Normalization module
|
89 |
+
|
90 |
+
Parameters
|
91 |
+
----------
|
92 |
+
num_features : int
|
93 |
+
Number of feature channels in the input and output.
|
94 |
+
eps : float
|
95 |
+
Small constant to prevent numerical issues.
|
96 |
+
momentum : float
|
97 |
+
Momentum factor applied to compute running statistics as.
|
98 |
+
affine : bool
|
99 |
+
If `True` apply learned scale and shift transformation after normalization.
|
100 |
+
activation : str
|
101 |
+
Name of the activation functions, one of: `leaky_relu`, `elu` or `none`.
|
102 |
+
slope : float
|
103 |
+
Negative slope for the `leaky_relu` activation.
|
104 |
+
"""
|
105 |
+
super(InPlaceABN, self).__init__(num_features, eps, momentum, affine, activation, slope)
|
106 |
+
|
107 |
+
def forward(self, x):
|
108 |
+
return inplace_abn(x, self.weight, self.bias, self.running_mean, self.running_var,
|
109 |
+
self.training, self.momentum, self.eps, self.activation, self.slope)
|
110 |
+
|
111 |
+
|
112 |
+
class InPlaceABNSync(ABN):
|
113 |
+
"""InPlace Activated Batch Normalization with cross-GPU synchronization
|
114 |
+
This assumes that it will be replicated across GPUs using the same mechanism as in `nn.DistributedDataParallel`.
|
115 |
+
"""
|
116 |
+
|
117 |
+
def forward(self, x):
|
118 |
+
return inplace_abn_sync(x, self.weight, self.bias, self.running_mean, self.running_var,
|
119 |
+
self.training, self.momentum, self.eps, self.activation, self.slope)
|
120 |
+
|
121 |
+
def __repr__(self):
|
122 |
+
rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
|
123 |
+
' affine={affine}, activation={activation}'
|
124 |
+
if self.activation == "leaky_relu":
|
125 |
+
rep += ', slope={slope})'
|
126 |
+
else:
|
127 |
+
rep += ')'
|
128 |
+
return rep.format(name=self.__class__.__name__, **self.__dict__)
|
129 |
+
|
130 |
+
|
ConsistentID/lib/BiSeNet/modules/deeplab.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as functional
|
4 |
+
|
5 |
+
from models._util import try_index
|
6 |
+
from .bn import ABN
|
7 |
+
|
8 |
+
|
9 |
+
class DeeplabV3(nn.Module):
|
10 |
+
def __init__(self,
|
11 |
+
in_channels,
|
12 |
+
out_channels,
|
13 |
+
hidden_channels=256,
|
14 |
+
dilations=(12, 24, 36),
|
15 |
+
norm_act=ABN,
|
16 |
+
pooling_size=None):
|
17 |
+
super(DeeplabV3, self).__init__()
|
18 |
+
self.pooling_size = pooling_size
|
19 |
+
|
20 |
+
self.map_convs = nn.ModuleList([
|
21 |
+
nn.Conv2d(in_channels, hidden_channels, 1, bias=False),
|
22 |
+
nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[0], padding=dilations[0]),
|
23 |
+
nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[1], padding=dilations[1]),
|
24 |
+
nn.Conv2d(in_channels, hidden_channels, 3, bias=False, dilation=dilations[2], padding=dilations[2])
|
25 |
+
])
|
26 |
+
self.map_bn = norm_act(hidden_channels * 4)
|
27 |
+
|
28 |
+
self.global_pooling_conv = nn.Conv2d(in_channels, hidden_channels, 1, bias=False)
|
29 |
+
self.global_pooling_bn = norm_act(hidden_channels)
|
30 |
+
|
31 |
+
self.red_conv = nn.Conv2d(hidden_channels * 4, out_channels, 1, bias=False)
|
32 |
+
self.pool_red_conv = nn.Conv2d(hidden_channels, out_channels, 1, bias=False)
|
33 |
+
self.red_bn = norm_act(out_channels)
|
34 |
+
|
35 |
+
self.reset_parameters(self.map_bn.activation, self.map_bn.slope)
|
36 |
+
|
37 |
+
def reset_parameters(self, activation, slope):
|
38 |
+
gain = nn.init.calculate_gain(activation, slope)
|
39 |
+
for m in self.modules():
|
40 |
+
if isinstance(m, nn.Conv2d):
|
41 |
+
nn.init.xavier_normal_(m.weight.data, gain)
|
42 |
+
if hasattr(m, "bias") and m.bias is not None:
|
43 |
+
nn.init.constant_(m.bias, 0)
|
44 |
+
elif isinstance(m, ABN):
|
45 |
+
if hasattr(m, "weight") and m.weight is not None:
|
46 |
+
nn.init.constant_(m.weight, 1)
|
47 |
+
if hasattr(m, "bias") and m.bias is not None:
|
48 |
+
nn.init.constant_(m.bias, 0)
|
49 |
+
|
50 |
+
def forward(self, x):
|
51 |
+
# Map convolutions
|
52 |
+
out = torch.cat([m(x) for m in self.map_convs], dim=1)
|
53 |
+
out = self.map_bn(out)
|
54 |
+
out = self.red_conv(out)
|
55 |
+
|
56 |
+
# Global pooling
|
57 |
+
pool = self._global_pooling(x)
|
58 |
+
pool = self.global_pooling_conv(pool)
|
59 |
+
pool = self.global_pooling_bn(pool)
|
60 |
+
pool = self.pool_red_conv(pool)
|
61 |
+
if self.training or self.pooling_size is None:
|
62 |
+
pool = pool.repeat(1, 1, x.size(2), x.size(3))
|
63 |
+
|
64 |
+
out += pool
|
65 |
+
out = self.red_bn(out)
|
66 |
+
return out
|
67 |
+
|
68 |
+
def _global_pooling(self, x):
|
69 |
+
if self.training or self.pooling_size is None:
|
70 |
+
pool = x.view(x.size(0), x.size(1), -1).mean(dim=-1)
|
71 |
+
pool = pool.view(x.size(0), x.size(1), 1, 1)
|
72 |
+
else:
|
73 |
+
pooling_size = (min(try_index(self.pooling_size, 0), x.shape[2]),
|
74 |
+
min(try_index(self.pooling_size, 1), x.shape[3]))
|
75 |
+
padding = (
|
76 |
+
(pooling_size[1] - 1) // 2,
|
77 |
+
(pooling_size[1] - 1) // 2 if pooling_size[1] % 2 == 1 else (pooling_size[1] - 1) // 2 + 1,
|
78 |
+
(pooling_size[0] - 1) // 2,
|
79 |
+
(pooling_size[0] - 1) // 2 if pooling_size[0] % 2 == 1 else (pooling_size[0] - 1) // 2 + 1
|
80 |
+
)
|
81 |
+
|
82 |
+
pool = functional.avg_pool2d(x, pooling_size, stride=1)
|
83 |
+
pool = functional.pad(pool, pad=padding, mode="replicate")
|
84 |
+
return pool
|
ConsistentID/lib/BiSeNet/modules/dense.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from .bn import ABN
|
7 |
+
|
8 |
+
|
9 |
+
class DenseModule(nn.Module):
|
10 |
+
def __init__(self, in_channels, growth, layers, bottleneck_factor=4, norm_act=ABN, dilation=1):
|
11 |
+
super(DenseModule, self).__init__()
|
12 |
+
self.in_channels = in_channels
|
13 |
+
self.growth = growth
|
14 |
+
self.layers = layers
|
15 |
+
|
16 |
+
self.convs1 = nn.ModuleList()
|
17 |
+
self.convs3 = nn.ModuleList()
|
18 |
+
for i in range(self.layers):
|
19 |
+
self.convs1.append(nn.Sequential(OrderedDict([
|
20 |
+
("bn", norm_act(in_channels)),
|
21 |
+
("conv", nn.Conv2d(in_channels, self.growth * bottleneck_factor, 1, bias=False))
|
22 |
+
])))
|
23 |
+
self.convs3.append(nn.Sequential(OrderedDict([
|
24 |
+
("bn", norm_act(self.growth * bottleneck_factor)),
|
25 |
+
("conv", nn.Conv2d(self.growth * bottleneck_factor, self.growth, 3, padding=dilation, bias=False,
|
26 |
+
dilation=dilation))
|
27 |
+
])))
|
28 |
+
in_channels += self.growth
|
29 |
+
|
30 |
+
@property
|
31 |
+
def out_channels(self):
|
32 |
+
return self.in_channels + self.growth * self.layers
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
inputs = [x]
|
36 |
+
for i in range(self.layers):
|
37 |
+
x = torch.cat(inputs, dim=1)
|
38 |
+
x = self.convs1[i](x)
|
39 |
+
x = self.convs3[i](x)
|
40 |
+
inputs += [x]
|
41 |
+
|
42 |
+
return torch.cat(inputs, dim=1)
|
ConsistentID/lib/BiSeNet/modules/functions.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from os import path
|
2 |
+
import torch
|
3 |
+
import torch.distributed as dist
|
4 |
+
import torch.autograd as autograd
|
5 |
+
import torch.cuda.comm as comm
|
6 |
+
from torch.autograd.function import once_differentiable
|
7 |
+
from torch.utils.cpp_extension import load
|
8 |
+
|
9 |
+
_src_path = path.join(path.dirname(path.abspath(__file__)), "src")
|
10 |
+
_backend = load(name="inplace_abn",
|
11 |
+
extra_cflags=["-O3"],
|
12 |
+
sources=[path.join(_src_path, f) for f in [
|
13 |
+
"inplace_abn.cpp",
|
14 |
+
"inplace_abn_cpu.cpp",
|
15 |
+
"inplace_abn_cuda.cu",
|
16 |
+
"inplace_abn_cuda_half.cu"
|
17 |
+
]],
|
18 |
+
extra_cuda_cflags=["--expt-extended-lambda"])
|
19 |
+
|
20 |
+
# Activation names
|
21 |
+
ACT_RELU = "relu"
|
22 |
+
ACT_LEAKY_RELU = "leaky_relu"
|
23 |
+
ACT_ELU = "elu"
|
24 |
+
ACT_NONE = "none"
|
25 |
+
|
26 |
+
|
27 |
+
def _check(fn, *args, **kwargs):
|
28 |
+
success = fn(*args, **kwargs)
|
29 |
+
if not success:
|
30 |
+
raise RuntimeError("CUDA Error encountered in {}".format(fn))
|
31 |
+
|
32 |
+
|
33 |
+
def _broadcast_shape(x):
|
34 |
+
out_size = []
|
35 |
+
for i, s in enumerate(x.size()):
|
36 |
+
if i != 1:
|
37 |
+
out_size.append(1)
|
38 |
+
else:
|
39 |
+
out_size.append(s)
|
40 |
+
return out_size
|
41 |
+
|
42 |
+
|
43 |
+
def _reduce(x):
|
44 |
+
if len(x.size()) == 2:
|
45 |
+
return x.sum(dim=0)
|
46 |
+
else:
|
47 |
+
n, c = x.size()[0:2]
|
48 |
+
return x.contiguous().view((n, c, -1)).sum(2).sum(0)
|
49 |
+
|
50 |
+
|
51 |
+
def _count_samples(x):
|
52 |
+
count = 1
|
53 |
+
for i, s in enumerate(x.size()):
|
54 |
+
if i != 1:
|
55 |
+
count *= s
|
56 |
+
return count
|
57 |
+
|
58 |
+
|
59 |
+
def _act_forward(ctx, x):
|
60 |
+
if ctx.activation == ACT_LEAKY_RELU:
|
61 |
+
_backend.leaky_relu_forward(x, ctx.slope)
|
62 |
+
elif ctx.activation == ACT_ELU:
|
63 |
+
_backend.elu_forward(x)
|
64 |
+
elif ctx.activation == ACT_NONE:
|
65 |
+
pass
|
66 |
+
|
67 |
+
|
68 |
+
def _act_backward(ctx, x, dx):
|
69 |
+
if ctx.activation == ACT_LEAKY_RELU:
|
70 |
+
_backend.leaky_relu_backward(x, dx, ctx.slope)
|
71 |
+
elif ctx.activation == ACT_ELU:
|
72 |
+
_backend.elu_backward(x, dx)
|
73 |
+
elif ctx.activation == ACT_NONE:
|
74 |
+
pass
|
75 |
+
|
76 |
+
|
77 |
+
class InPlaceABN(autograd.Function):
|
78 |
+
@staticmethod
|
79 |
+
def forward(ctx, x, weight, bias, running_mean, running_var,
|
80 |
+
training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01):
|
81 |
+
# Save context
|
82 |
+
ctx.training = training
|
83 |
+
ctx.momentum = momentum
|
84 |
+
ctx.eps = eps
|
85 |
+
ctx.activation = activation
|
86 |
+
ctx.slope = slope
|
87 |
+
ctx.affine = weight is not None and bias is not None
|
88 |
+
|
89 |
+
# Prepare inputs
|
90 |
+
count = _count_samples(x)
|
91 |
+
x = x.contiguous()
|
92 |
+
weight = weight.contiguous() if ctx.affine else x.new_empty(0)
|
93 |
+
bias = bias.contiguous() if ctx.affine else x.new_empty(0)
|
94 |
+
|
95 |
+
if ctx.training:
|
96 |
+
mean, var = _backend.mean_var(x)
|
97 |
+
|
98 |
+
# Update running stats
|
99 |
+
running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
|
100 |
+
running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1))
|
101 |
+
|
102 |
+
# Mark in-place modified tensors
|
103 |
+
ctx.mark_dirty(x, running_mean, running_var)
|
104 |
+
else:
|
105 |
+
mean, var = running_mean.contiguous(), running_var.contiguous()
|
106 |
+
ctx.mark_dirty(x)
|
107 |
+
|
108 |
+
# BN forward + activation
|
109 |
+
_backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
|
110 |
+
_act_forward(ctx, x)
|
111 |
+
|
112 |
+
# Output
|
113 |
+
ctx.var = var
|
114 |
+
ctx.save_for_backward(x, var, weight, bias)
|
115 |
+
return x
|
116 |
+
|
117 |
+
@staticmethod
|
118 |
+
@once_differentiable
|
119 |
+
def backward(ctx, dz):
|
120 |
+
z, var, weight, bias = ctx.saved_tensors
|
121 |
+
dz = dz.contiguous()
|
122 |
+
|
123 |
+
# Undo activation
|
124 |
+
_act_backward(ctx, z, dz)
|
125 |
+
|
126 |
+
if ctx.training:
|
127 |
+
edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)
|
128 |
+
else:
|
129 |
+
# TODO: implement simplified CUDA backward for inference mode
|
130 |
+
edz = dz.new_zeros(dz.size(1))
|
131 |
+
eydz = dz.new_zeros(dz.size(1))
|
132 |
+
|
133 |
+
dx = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)
|
134 |
+
dweight = eydz * weight.sign() if ctx.affine else None
|
135 |
+
dbias = edz if ctx.affine else None
|
136 |
+
|
137 |
+
return dx, dweight, dbias, None, None, None, None, None, None, None
|
138 |
+
|
139 |
+
class InPlaceABNSync(autograd.Function):
|
140 |
+
@classmethod
|
141 |
+
def forward(cls, ctx, x, weight, bias, running_mean, running_var,
|
142 |
+
training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01, equal_batches=True):
|
143 |
+
# Save context
|
144 |
+
ctx.training = training
|
145 |
+
ctx.momentum = momentum
|
146 |
+
ctx.eps = eps
|
147 |
+
ctx.activation = activation
|
148 |
+
ctx.slope = slope
|
149 |
+
ctx.affine = weight is not None and bias is not None
|
150 |
+
|
151 |
+
# Prepare inputs
|
152 |
+
ctx.world_size = dist.get_world_size() if dist.is_initialized() else 1
|
153 |
+
|
154 |
+
#count = _count_samples(x)
|
155 |
+
batch_size = x.new_tensor([x.shape[0]],dtype=torch.long)
|
156 |
+
|
157 |
+
x = x.contiguous()
|
158 |
+
weight = weight.contiguous() if ctx.affine else x.new_empty(0)
|
159 |
+
bias = bias.contiguous() if ctx.affine else x.new_empty(0)
|
160 |
+
|
161 |
+
if ctx.training:
|
162 |
+
mean, var = _backend.mean_var(x)
|
163 |
+
if ctx.world_size>1:
|
164 |
+
# get global batch size
|
165 |
+
if equal_batches:
|
166 |
+
batch_size *= ctx.world_size
|
167 |
+
else:
|
168 |
+
dist.all_reduce(batch_size, dist.ReduceOp.SUM)
|
169 |
+
|
170 |
+
ctx.factor = x.shape[0]/float(batch_size.item())
|
171 |
+
|
172 |
+
mean_all = mean.clone() * ctx.factor
|
173 |
+
dist.all_reduce(mean_all, dist.ReduceOp.SUM)
|
174 |
+
|
175 |
+
var_all = (var + (mean - mean_all) ** 2) * ctx.factor
|
176 |
+
dist.all_reduce(var_all, dist.ReduceOp.SUM)
|
177 |
+
|
178 |
+
mean = mean_all
|
179 |
+
var = var_all
|
180 |
+
|
181 |
+
# Update running stats
|
182 |
+
running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
|
183 |
+
count = batch_size.item() * x.view(x.shape[0],x.shape[1],-1).shape[-1]
|
184 |
+
running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * (float(count) / (count - 1)))
|
185 |
+
|
186 |
+
# Mark in-place modified tensors
|
187 |
+
ctx.mark_dirty(x, running_mean, running_var)
|
188 |
+
else:
|
189 |
+
mean, var = running_mean.contiguous(), running_var.contiguous()
|
190 |
+
ctx.mark_dirty(x)
|
191 |
+
|
192 |
+
# BN forward + activation
|
193 |
+
_backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
|
194 |
+
_act_forward(ctx, x)
|
195 |
+
|
196 |
+
# Output
|
197 |
+
ctx.var = var
|
198 |
+
ctx.save_for_backward(x, var, weight, bias)
|
199 |
+
return x
|
200 |
+
|
201 |
+
@staticmethod
|
202 |
+
@once_differentiable
|
203 |
+
def backward(ctx, dz):
|
204 |
+
z, var, weight, bias = ctx.saved_tensors
|
205 |
+
dz = dz.contiguous()
|
206 |
+
|
207 |
+
# Undo activation
|
208 |
+
_act_backward(ctx, z, dz)
|
209 |
+
|
210 |
+
if ctx.training:
|
211 |
+
edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)
|
212 |
+
edz_local = edz.clone()
|
213 |
+
eydz_local = eydz.clone()
|
214 |
+
|
215 |
+
if ctx.world_size>1:
|
216 |
+
edz *= ctx.factor
|
217 |
+
dist.all_reduce(edz, dist.ReduceOp.SUM)
|
218 |
+
|
219 |
+
eydz *= ctx.factor
|
220 |
+
dist.all_reduce(eydz, dist.ReduceOp.SUM)
|
221 |
+
else:
|
222 |
+
edz_local = edz = dz.new_zeros(dz.size(1))
|
223 |
+
eydz_local = eydz = dz.new_zeros(dz.size(1))
|
224 |
+
|
225 |
+
dx = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)
|
226 |
+
dweight = eydz_local * weight.sign() if ctx.affine else None
|
227 |
+
dbias = edz_local if ctx.affine else None
|
228 |
+
|
229 |
+
return dx, dweight, dbias, None, None, None, None, None, None, None
|
230 |
+
|
231 |
+
inplace_abn = InPlaceABN.apply
|
232 |
+
inplace_abn_sync = InPlaceABNSync.apply
|
233 |
+
|
234 |
+
__all__ = ["inplace_abn", "inplace_abn_sync", "ACT_RELU", "ACT_LEAKY_RELU", "ACT_ELU", "ACT_NONE"]
|
ConsistentID/lib/BiSeNet/modules/misc.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch
|
3 |
+
import torch.distributed as dist
|
4 |
+
|
5 |
+
class GlobalAvgPool2d(nn.Module):
|
6 |
+
def __init__(self):
|
7 |
+
"""Global average pooling over the input's spatial dimensions"""
|
8 |
+
super(GlobalAvgPool2d, self).__init__()
|
9 |
+
|
10 |
+
def forward(self, inputs):
|
11 |
+
in_size = inputs.size()
|
12 |
+
return inputs.view((in_size[0], in_size[1], -1)).mean(dim=2)
|
13 |
+
|
14 |
+
class SingleGPU(nn.Module):
|
15 |
+
def __init__(self, module):
|
16 |
+
super(SingleGPU, self).__init__()
|
17 |
+
self.module=module
|
18 |
+
|
19 |
+
def forward(self, input):
|
20 |
+
return self.module(input.cuda(non_blocking=True))
|
21 |
+
|
ConsistentID/lib/BiSeNet/modules/residual.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
from .bn import ABN
|
6 |
+
|
7 |
+
|
8 |
+
class IdentityResidualBlock(nn.Module):
|
9 |
+
def __init__(self,
|
10 |
+
in_channels,
|
11 |
+
channels,
|
12 |
+
stride=1,
|
13 |
+
dilation=1,
|
14 |
+
groups=1,
|
15 |
+
norm_act=ABN,
|
16 |
+
dropout=None):
|
17 |
+
"""Configurable identity-mapping residual block
|
18 |
+
|
19 |
+
Parameters
|
20 |
+
----------
|
21 |
+
in_channels : int
|
22 |
+
Number of input channels.
|
23 |
+
channels : list of int
|
24 |
+
Number of channels in the internal feature maps. Can either have two or three elements: if three construct
|
25 |
+
a residual block with two `3 x 3` convolutions, otherwise construct a bottleneck block with `1 x 1`, then
|
26 |
+
`3 x 3` then `1 x 1` convolutions.
|
27 |
+
stride : int
|
28 |
+
Stride of the first `3 x 3` convolution
|
29 |
+
dilation : int
|
30 |
+
Dilation to apply to the `3 x 3` convolutions.
|
31 |
+
groups : int
|
32 |
+
Number of convolution groups. This is used to create ResNeXt-style blocks and is only compatible with
|
33 |
+
bottleneck blocks.
|
34 |
+
norm_act : callable
|
35 |
+
Function to create normalization / activation Module.
|
36 |
+
dropout: callable
|
37 |
+
Function to create Dropout Module.
|
38 |
+
"""
|
39 |
+
super(IdentityResidualBlock, self).__init__()
|
40 |
+
|
41 |
+
# Check parameters for inconsistencies
|
42 |
+
if len(channels) != 2 and len(channels) != 3:
|
43 |
+
raise ValueError("channels must contain either two or three values")
|
44 |
+
if len(channels) == 2 and groups != 1:
|
45 |
+
raise ValueError("groups > 1 are only valid if len(channels) == 3")
|
46 |
+
|
47 |
+
is_bottleneck = len(channels) == 3
|
48 |
+
need_proj_conv = stride != 1 or in_channels != channels[-1]
|
49 |
+
|
50 |
+
self.bn1 = norm_act(in_channels)
|
51 |
+
if not is_bottleneck:
|
52 |
+
layers = [
|
53 |
+
("conv1", nn.Conv2d(in_channels, channels[0], 3, stride=stride, padding=dilation, bias=False,
|
54 |
+
dilation=dilation)),
|
55 |
+
("bn2", norm_act(channels[0])),
|
56 |
+
("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
|
57 |
+
dilation=dilation))
|
58 |
+
]
|
59 |
+
if dropout is not None:
|
60 |
+
layers = layers[0:2] + [("dropout", dropout())] + layers[2:]
|
61 |
+
else:
|
62 |
+
layers = [
|
63 |
+
("conv1", nn.Conv2d(in_channels, channels[0], 1, stride=stride, padding=0, bias=False)),
|
64 |
+
("bn2", norm_act(channels[0])),
|
65 |
+
("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
|
66 |
+
groups=groups, dilation=dilation)),
|
67 |
+
("bn3", norm_act(channels[1])),
|
68 |
+
("conv3", nn.Conv2d(channels[1], channels[2], 1, stride=1, padding=0, bias=False))
|
69 |
+
]
|
70 |
+
if dropout is not None:
|
71 |
+
layers = layers[0:4] + [("dropout", dropout())] + layers[4:]
|
72 |
+
self.convs = nn.Sequential(OrderedDict(layers))
|
73 |
+
|
74 |
+
if need_proj_conv:
|
75 |
+
self.proj_conv = nn.Conv2d(in_channels, channels[-1], 1, stride=stride, padding=0, bias=False)
|
76 |
+
|
77 |
+
def forward(self, x):
|
78 |
+
if hasattr(self, "proj_conv"):
|
79 |
+
bn1 = self.bn1(x)
|
80 |
+
shortcut = self.proj_conv(bn1)
|
81 |
+
else:
|
82 |
+
shortcut = x.clone()
|
83 |
+
bn1 = self.bn1(x)
|
84 |
+
|
85 |
+
out = self.convs(bn1)
|
86 |
+
out.add_(shortcut)
|
87 |
+
|
88 |
+
return out
|
ConsistentID/lib/BiSeNet/modules/src/checks.h
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include <ATen/ATen.h>
|
4 |
+
|
5 |
+
// Define AT_CHECK for old version of ATen where the same function was called AT_ASSERT
|
6 |
+
#ifndef AT_CHECK
|
7 |
+
#define AT_CHECK AT_ASSERT
|
8 |
+
#endif
|
9 |
+
|
10 |
+
#define CHECK_CUDA(x) AT_CHECK((x).type().is_cuda(), #x " must be a CUDA tensor")
|
11 |
+
#define CHECK_CPU(x) AT_CHECK(!(x).type().is_cuda(), #x " must be a CPU tensor")
|
12 |
+
#define CHECK_CONTIGUOUS(x) AT_CHECK((x).is_contiguous(), #x " must be contiguous")
|
13 |
+
|
14 |
+
#define CHECK_CUDA_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
15 |
+
#define CHECK_CPU_INPUT(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x)
|
ConsistentID/lib/BiSeNet/modules/src/inplace_abn.cpp
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <torch/extension.h>
|
2 |
+
|
3 |
+
#include <vector>
|
4 |
+
|
5 |
+
#include "inplace_abn.h"
|
6 |
+
|
7 |
+
std::vector<at::Tensor> mean_var(at::Tensor x) {
|
8 |
+
if (x.is_cuda()) {
|
9 |
+
if (x.type().scalarType() == at::ScalarType::Half) {
|
10 |
+
return mean_var_cuda_h(x);
|
11 |
+
} else {
|
12 |
+
return mean_var_cuda(x);
|
13 |
+
}
|
14 |
+
} else {
|
15 |
+
return mean_var_cpu(x);
|
16 |
+
}
|
17 |
+
}
|
18 |
+
|
19 |
+
at::Tensor forward(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
|
20 |
+
bool affine, float eps) {
|
21 |
+
if (x.is_cuda()) {
|
22 |
+
if (x.type().scalarType() == at::ScalarType::Half) {
|
23 |
+
return forward_cuda_h(x, mean, var, weight, bias, affine, eps);
|
24 |
+
} else {
|
25 |
+
return forward_cuda(x, mean, var, weight, bias, affine, eps);
|
26 |
+
}
|
27 |
+
} else {
|
28 |
+
return forward_cpu(x, mean, var, weight, bias, affine, eps);
|
29 |
+
}
|
30 |
+
}
|
31 |
+
|
32 |
+
std::vector<at::Tensor> edz_eydz(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
|
33 |
+
bool affine, float eps) {
|
34 |
+
if (z.is_cuda()) {
|
35 |
+
if (z.type().scalarType() == at::ScalarType::Half) {
|
36 |
+
return edz_eydz_cuda_h(z, dz, weight, bias, affine, eps);
|
37 |
+
} else {
|
38 |
+
return edz_eydz_cuda(z, dz, weight, bias, affine, eps);
|
39 |
+
}
|
40 |
+
} else {
|
41 |
+
return edz_eydz_cpu(z, dz, weight, bias, affine, eps);
|
42 |
+
}
|
43 |
+
}
|
44 |
+
|
45 |
+
at::Tensor backward(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
|
46 |
+
at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
|
47 |
+
if (z.is_cuda()) {
|
48 |
+
if (z.type().scalarType() == at::ScalarType::Half) {
|
49 |
+
return backward_cuda_h(z, dz, var, weight, bias, edz, eydz, affine, eps);
|
50 |
+
} else {
|
51 |
+
return backward_cuda(z, dz, var, weight, bias, edz, eydz, affine, eps);
|
52 |
+
}
|
53 |
+
} else {
|
54 |
+
return backward_cpu(z, dz, var, weight, bias, edz, eydz, affine, eps);
|
55 |
+
}
|
56 |
+
}
|
57 |
+
|
58 |
+
void leaky_relu_forward(at::Tensor z, float slope) {
|
59 |
+
at::leaky_relu_(z, slope);
|
60 |
+
}
|
61 |
+
|
62 |
+
void leaky_relu_backward(at::Tensor z, at::Tensor dz, float slope) {
|
63 |
+
if (z.is_cuda()) {
|
64 |
+
if (z.type().scalarType() == at::ScalarType::Half) {
|
65 |
+
return leaky_relu_backward_cuda_h(z, dz, slope);
|
66 |
+
} else {
|
67 |
+
return leaky_relu_backward_cuda(z, dz, slope);
|
68 |
+
}
|
69 |
+
} else {
|
70 |
+
return leaky_relu_backward_cpu(z, dz, slope);
|
71 |
+
}
|
72 |
+
}
|
73 |
+
|
74 |
+
void elu_forward(at::Tensor z) {
|
75 |
+
at::elu_(z);
|
76 |
+
}
|
77 |
+
|
78 |
+
void elu_backward(at::Tensor z, at::Tensor dz) {
|
79 |
+
if (z.is_cuda()) {
|
80 |
+
return elu_backward_cuda(z, dz);
|
81 |
+
} else {
|
82 |
+
return elu_backward_cpu(z, dz);
|
83 |
+
}
|
84 |
+
}
|
85 |
+
|
86 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
87 |
+
m.def("mean_var", &mean_var, "Mean and variance computation");
|
88 |
+
m.def("forward", &forward, "In-place forward computation");
|
89 |
+
m.def("edz_eydz", &edz_eydz, "First part of backward computation");
|
90 |
+
m.def("backward", &backward, "Second part of backward computation");
|
91 |
+
m.def("leaky_relu_forward", &leaky_relu_forward, "Leaky relu forward computation");
|
92 |
+
m.def("leaky_relu_backward", &leaky_relu_backward, "Leaky relu backward computation and inversion");
|
93 |
+
m.def("elu_forward", &elu_forward, "Elu forward computation");
|
94 |
+
m.def("elu_backward", &elu_backward, "Elu backward computation and inversion");
|
95 |
+
}
|
ConsistentID/lib/BiSeNet/modules/src/inplace_abn.h
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include <ATen/ATen.h>
|
4 |
+
|
5 |
+
#include <vector>
|
6 |
+
|
7 |
+
std::vector<at::Tensor> mean_var_cpu(at::Tensor x);
|
8 |
+
std::vector<at::Tensor> mean_var_cuda(at::Tensor x);
|
9 |
+
std::vector<at::Tensor> mean_var_cuda_h(at::Tensor x);
|
10 |
+
|
11 |
+
at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
|
12 |
+
bool affine, float eps);
|
13 |
+
at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
|
14 |
+
bool affine, float eps);
|
15 |
+
at::Tensor forward_cuda_h(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
|
16 |
+
bool affine, float eps);
|
17 |
+
|
18 |
+
std::vector<at::Tensor> edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
|
19 |
+
bool affine, float eps);
|
20 |
+
std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
|
21 |
+
bool affine, float eps);
|
22 |
+
std::vector<at::Tensor> edz_eydz_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
|
23 |
+
bool affine, float eps);
|
24 |
+
|
25 |
+
at::Tensor backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
|
26 |
+
at::Tensor edz, at::Tensor eydz, bool affine, float eps);
|
27 |
+
at::Tensor backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
|
28 |
+
at::Tensor edz, at::Tensor eydz, bool affine, float eps);
|
29 |
+
at::Tensor backward_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
|
30 |
+
at::Tensor edz, at::Tensor eydz, bool affine, float eps);
|
31 |
+
|
32 |
+
void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope);
|
33 |
+
void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope);
|
34 |
+
void leaky_relu_backward_cuda_h(at::Tensor z, at::Tensor dz, float slope);
|
35 |
+
|
36 |
+
void elu_backward_cpu(at::Tensor z, at::Tensor dz);
|
37 |
+
void elu_backward_cuda(at::Tensor z, at::Tensor dz);
|
38 |
+
|
39 |
+
static void get_dims(at::Tensor x, int64_t& num, int64_t& chn, int64_t& sp) {
|
40 |
+
num = x.size(0);
|
41 |
+
chn = x.size(1);
|
42 |
+
sp = 1;
|
43 |
+
for (int64_t i = 2; i < x.ndimension(); ++i)
|
44 |
+
sp *= x.size(i);
|
45 |
+
}
|
46 |
+
|
47 |
+
/*
|
48 |
+
* Specialized CUDA reduction functions for BN
|
49 |
+
*/
|
50 |
+
#ifdef __CUDACC__
|
51 |
+
|
52 |
+
#include "utils/cuda.cuh"
|
53 |
+
|
54 |
+
template <typename T, typename Op>
|
55 |
+
__device__ T reduce(Op op, int plane, int N, int S) {
|
56 |
+
T sum = (T)0;
|
57 |
+
for (int batch = 0; batch < N; ++batch) {
|
58 |
+
for (int x = threadIdx.x; x < S; x += blockDim.x) {
|
59 |
+
sum += op(batch, plane, x);
|
60 |
+
}
|
61 |
+
}
|
62 |
+
|
63 |
+
// sum over NumThreads within a warp
|
64 |
+
sum = warpSum(sum);
|
65 |
+
|
66 |
+
// 'transpose', and reduce within warp again
|
67 |
+
__shared__ T shared[32];
|
68 |
+
__syncthreads();
|
69 |
+
if (threadIdx.x % WARP_SIZE == 0) {
|
70 |
+
shared[threadIdx.x / WARP_SIZE] = sum;
|
71 |
+
}
|
72 |
+
if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
|
73 |
+
// zero out the other entries in shared
|
74 |
+
shared[threadIdx.x] = (T)0;
|
75 |
+
}
|
76 |
+
__syncthreads();
|
77 |
+
if (threadIdx.x / WARP_SIZE == 0) {
|
78 |
+
sum = warpSum(shared[threadIdx.x]);
|
79 |
+
if (threadIdx.x == 0) {
|
80 |
+
shared[0] = sum;
|
81 |
+
}
|
82 |
+
}
|
83 |
+
__syncthreads();
|
84 |
+
|
85 |
+
// Everyone picks it up, should be broadcast into the whole gradInput
|
86 |
+
return shared[0];
|
87 |
+
}
|
88 |
+
#endif
|
ConsistentID/lib/BiSeNet/modules/src/inplace_abn_cpu.cpp
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <ATen/ATen.h>
|
2 |
+
|
3 |
+
#include <vector>
|
4 |
+
|
5 |
+
#include "utils/checks.h"
|
6 |
+
#include "inplace_abn.h"
|
7 |
+
|
8 |
+
at::Tensor reduce_sum(at::Tensor x) {
|
9 |
+
if (x.ndimension() == 2) {
|
10 |
+
return x.sum(0);
|
11 |
+
} else {
|
12 |
+
auto x_view = x.view({x.size(0), x.size(1), -1});
|
13 |
+
return x_view.sum(-1).sum(0);
|
14 |
+
}
|
15 |
+
}
|
16 |
+
|
17 |
+
at::Tensor broadcast_to(at::Tensor v, at::Tensor x) {
|
18 |
+
if (x.ndimension() == 2) {
|
19 |
+
return v;
|
20 |
+
} else {
|
21 |
+
std::vector<int64_t> broadcast_size = {1, -1};
|
22 |
+
for (int64_t i = 2; i < x.ndimension(); ++i)
|
23 |
+
broadcast_size.push_back(1);
|
24 |
+
|
25 |
+
return v.view(broadcast_size);
|
26 |
+
}
|
27 |
+
}
|
28 |
+
|
29 |
+
int64_t count(at::Tensor x) {
|
30 |
+
int64_t count = x.size(0);
|
31 |
+
for (int64_t i = 2; i < x.ndimension(); ++i)
|
32 |
+
count *= x.size(i);
|
33 |
+
|
34 |
+
return count;
|
35 |
+
}
|
36 |
+
|
37 |
+
at::Tensor invert_affine(at::Tensor z, at::Tensor weight, at::Tensor bias, bool affine, float eps) {
|
38 |
+
if (affine) {
|
39 |
+
return (z - broadcast_to(bias, z)) / broadcast_to(at::abs(weight) + eps, z);
|
40 |
+
} else {
|
41 |
+
return z;
|
42 |
+
}
|
43 |
+
}
|
44 |
+
|
45 |
+
std::vector<at::Tensor> mean_var_cpu(at::Tensor x) {
|
46 |
+
auto num = count(x);
|
47 |
+
auto mean = reduce_sum(x) / num;
|
48 |
+
auto diff = x - broadcast_to(mean, x);
|
49 |
+
auto var = reduce_sum(diff.pow(2)) / num;
|
50 |
+
|
51 |
+
return {mean, var};
|
52 |
+
}
|
53 |
+
|
54 |
+
at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
|
55 |
+
bool affine, float eps) {
|
56 |
+
auto gamma = affine ? at::abs(weight) + eps : at::ones_like(var);
|
57 |
+
auto mul = at::rsqrt(var + eps) * gamma;
|
58 |
+
|
59 |
+
x.sub_(broadcast_to(mean, x));
|
60 |
+
x.mul_(broadcast_to(mul, x));
|
61 |
+
if (affine) x.add_(broadcast_to(bias, x));
|
62 |
+
|
63 |
+
return x;
|
64 |
+
}
|
65 |
+
|
66 |
+
std::vector<at::Tensor> edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
|
67 |
+
bool affine, float eps) {
|
68 |
+
auto edz = reduce_sum(dz);
|
69 |
+
auto y = invert_affine(z, weight, bias, affine, eps);
|
70 |
+
auto eydz = reduce_sum(y * dz);
|
71 |
+
|
72 |
+
return {edz, eydz};
|
73 |
+
}
|
74 |
+
|
75 |
+
at::Tensor backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
|
76 |
+
at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
|
77 |
+
auto y = invert_affine(z, weight, bias, affine, eps);
|
78 |
+
auto mul = affine ? at::rsqrt(var + eps) * (at::abs(weight) + eps) : at::rsqrt(var + eps);
|
79 |
+
|
80 |
+
auto num = count(z);
|
81 |
+
auto dx = (dz - broadcast_to(edz / num, dz) - y * broadcast_to(eydz / num, dz)) * broadcast_to(mul, dz);
|
82 |
+
return dx;
|
83 |
+
}
|
84 |
+
|
85 |
+
void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope) {
|
86 |
+
CHECK_CPU_INPUT(z);
|
87 |
+
CHECK_CPU_INPUT(dz);
|
88 |
+
|
89 |
+
AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cpu", ([&] {
|
90 |
+
int64_t count = z.numel();
|
91 |
+
auto *_z = z.data<scalar_t>();
|
92 |
+
auto *_dz = dz.data<scalar_t>();
|
93 |
+
|
94 |
+
for (int64_t i = 0; i < count; ++i) {
|
95 |
+
if (_z[i] < 0) {
|
96 |
+
_z[i] *= 1 / slope;
|
97 |
+
_dz[i] *= slope;
|
98 |
+
}
|
99 |
+
}
|
100 |
+
}));
|
101 |
+
}
|
102 |
+
|
103 |
+
void elu_backward_cpu(at::Tensor z, at::Tensor dz) {
|
104 |
+
CHECK_CPU_INPUT(z);
|
105 |
+
CHECK_CPU_INPUT(dz);
|
106 |
+
|
107 |
+
AT_DISPATCH_FLOATING_TYPES(z.type(), "elu_backward_cpu", ([&] {
|
108 |
+
int64_t count = z.numel();
|
109 |
+
auto *_z = z.data<scalar_t>();
|
110 |
+
auto *_dz = dz.data<scalar_t>();
|
111 |
+
|
112 |
+
for (int64_t i = 0; i < count; ++i) {
|
113 |
+
if (_z[i] < 0) {
|
114 |
+
_z[i] = log1p(_z[i]);
|
115 |
+
_dz[i] *= (_z[i] + 1.f);
|
116 |
+
}
|
117 |
+
}
|
118 |
+
}));
|
119 |
+
}
|
ConsistentID/lib/BiSeNet/modules/src/inplace_abn_cuda.cu
ADDED
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <ATen/ATen.h>
|
2 |
+
|
3 |
+
#include <thrust/device_ptr.h>
|
4 |
+
#include <thrust/transform.h>
|
5 |
+
|
6 |
+
#include <vector>
|
7 |
+
|
8 |
+
#include "utils/checks.h"
|
9 |
+
#include "utils/cuda.cuh"
|
10 |
+
#include "inplace_abn.h"
|
11 |
+
|
12 |
+
#include <ATen/cuda/CUDAContext.h>
|
13 |
+
|
14 |
+
// Operations for reduce
|
15 |
+
template<typename T>
|
16 |
+
struct SumOp {
|
17 |
+
__device__ SumOp(const T *t, int c, int s)
|
18 |
+
: tensor(t), chn(c), sp(s) {}
|
19 |
+
__device__ __forceinline__ T operator()(int batch, int plane, int n) {
|
20 |
+
return tensor[(batch * chn + plane) * sp + n];
|
21 |
+
}
|
22 |
+
const T *tensor;
|
23 |
+
const int chn;
|
24 |
+
const int sp;
|
25 |
+
};
|
26 |
+
|
27 |
+
template<typename T>
|
28 |
+
struct VarOp {
|
29 |
+
__device__ VarOp(T m, const T *t, int c, int s)
|
30 |
+
: mean(m), tensor(t), chn(c), sp(s) {}
|
31 |
+
__device__ __forceinline__ T operator()(int batch, int plane, int n) {
|
32 |
+
T val = tensor[(batch * chn + plane) * sp + n];
|
33 |
+
return (val - mean) * (val - mean);
|
34 |
+
}
|
35 |
+
const T mean;
|
36 |
+
const T *tensor;
|
37 |
+
const int chn;
|
38 |
+
const int sp;
|
39 |
+
};
|
40 |
+
|
41 |
+
template<typename T>
|
42 |
+
struct GradOp {
|
43 |
+
__device__ GradOp(T _weight, T _bias, const T *_z, const T *_dz, int c, int s)
|
44 |
+
: weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {}
|
45 |
+
__device__ __forceinline__ Pair<T> operator()(int batch, int plane, int n) {
|
46 |
+
T _y = (z[(batch * chn + plane) * sp + n] - bias) / weight;
|
47 |
+
T _dz = dz[(batch * chn + plane) * sp + n];
|
48 |
+
return Pair<T>(_dz, _y * _dz);
|
49 |
+
}
|
50 |
+
const T weight;
|
51 |
+
const T bias;
|
52 |
+
const T *z;
|
53 |
+
const T *dz;
|
54 |
+
const int chn;
|
55 |
+
const int sp;
|
56 |
+
};
|
57 |
+
|
58 |
+
/***********
|
59 |
+
* mean_var
|
60 |
+
***********/
|
61 |
+
|
62 |
+
template<typename T>
|
63 |
+
__global__ void mean_var_kernel(const T *x, T *mean, T *var, int num, int chn, int sp) {
|
64 |
+
int plane = blockIdx.x;
|
65 |
+
T norm = T(1) / T(num * sp);
|
66 |
+
|
67 |
+
T _mean = reduce<T, SumOp<T>>(SumOp<T>(x, chn, sp), plane, num, sp) * norm;
|
68 |
+
__syncthreads();
|
69 |
+
T _var = reduce<T, VarOp<T>>(VarOp<T>(_mean, x, chn, sp), plane, num, sp) * norm;
|
70 |
+
|
71 |
+
if (threadIdx.x == 0) {
|
72 |
+
mean[plane] = _mean;
|
73 |
+
var[plane] = _var;
|
74 |
+
}
|
75 |
+
}
|
76 |
+
|
77 |
+
std::vector<at::Tensor> mean_var_cuda(at::Tensor x) {
|
78 |
+
CHECK_CUDA_INPUT(x);
|
79 |
+
|
80 |
+
// Extract dimensions
|
81 |
+
int64_t num, chn, sp;
|
82 |
+
get_dims(x, num, chn, sp);
|
83 |
+
|
84 |
+
// Prepare output tensors
|
85 |
+
auto mean = at::empty({chn}, x.options());
|
86 |
+
auto var = at::empty({chn}, x.options());
|
87 |
+
|
88 |
+
// Run kernel
|
89 |
+
dim3 blocks(chn);
|
90 |
+
dim3 threads(getNumThreads(sp));
|
91 |
+
auto stream = at::cuda::getCurrentCUDAStream();
|
92 |
+
AT_DISPATCH_FLOATING_TYPES(x.type(), "mean_var_cuda", ([&] {
|
93 |
+
mean_var_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
|
94 |
+
x.data<scalar_t>(),
|
95 |
+
mean.data<scalar_t>(),
|
96 |
+
var.data<scalar_t>(),
|
97 |
+
num, chn, sp);
|
98 |
+
}));
|
99 |
+
|
100 |
+
return {mean, var};
|
101 |
+
}
|
102 |
+
|
103 |
+
/**********
|
104 |
+
* forward
|
105 |
+
**********/
|
106 |
+
|
107 |
+
template<typename T>
|
108 |
+
__global__ void forward_kernel(T *x, const T *mean, const T *var, const T *weight, const T *bias,
|
109 |
+
bool affine, float eps, int num, int chn, int sp) {
|
110 |
+
int plane = blockIdx.x;
|
111 |
+
|
112 |
+
T _mean = mean[plane];
|
113 |
+
T _var = var[plane];
|
114 |
+
T _weight = affine ? abs(weight[plane]) + eps : T(1);
|
115 |
+
T _bias = affine ? bias[plane] : T(0);
|
116 |
+
|
117 |
+
T mul = rsqrt(_var + eps) * _weight;
|
118 |
+
|
119 |
+
for (int batch = 0; batch < num; ++batch) {
|
120 |
+
for (int n = threadIdx.x; n < sp; n += blockDim.x) {
|
121 |
+
T _x = x[(batch * chn + plane) * sp + n];
|
122 |
+
T _y = (_x - _mean) * mul + _bias;
|
123 |
+
|
124 |
+
x[(batch * chn + plane) * sp + n] = _y;
|
125 |
+
}
|
126 |
+
}
|
127 |
+
}
|
128 |
+
|
129 |
+
at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
|
130 |
+
bool affine, float eps) {
|
131 |
+
CHECK_CUDA_INPUT(x);
|
132 |
+
CHECK_CUDA_INPUT(mean);
|
133 |
+
CHECK_CUDA_INPUT(var);
|
134 |
+
CHECK_CUDA_INPUT(weight);
|
135 |
+
CHECK_CUDA_INPUT(bias);
|
136 |
+
|
137 |
+
// Extract dimensions
|
138 |
+
int64_t num, chn, sp;
|
139 |
+
get_dims(x, num, chn, sp);
|
140 |
+
|
141 |
+
// Run kernel
|
142 |
+
dim3 blocks(chn);
|
143 |
+
dim3 threads(getNumThreads(sp));
|
144 |
+
auto stream = at::cuda::getCurrentCUDAStream();
|
145 |
+
AT_DISPATCH_FLOATING_TYPES(x.type(), "forward_cuda", ([&] {
|
146 |
+
forward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
|
147 |
+
x.data<scalar_t>(),
|
148 |
+
mean.data<scalar_t>(),
|
149 |
+
var.data<scalar_t>(),
|
150 |
+
weight.data<scalar_t>(),
|
151 |
+
bias.data<scalar_t>(),
|
152 |
+
affine, eps, num, chn, sp);
|
153 |
+
}));
|
154 |
+
|
155 |
+
return x;
|
156 |
+
}
|
157 |
+
|
158 |
+
/***********
|
159 |
+
* edz_eydz
|
160 |
+
***********/
|
161 |
+
|
162 |
+
template<typename T>
|
163 |
+
__global__ void edz_eydz_kernel(const T *z, const T *dz, const T *weight, const T *bias,
|
164 |
+
T *edz, T *eydz, bool affine, float eps, int num, int chn, int sp) {
|
165 |
+
int plane = blockIdx.x;
|
166 |
+
|
167 |
+
T _weight = affine ? abs(weight[plane]) + eps : 1.f;
|
168 |
+
T _bias = affine ? bias[plane] : 0.f;
|
169 |
+
|
170 |
+
Pair<T> res = reduce<Pair<T>, GradOp<T>>(GradOp<T>(_weight, _bias, z, dz, chn, sp), plane, num, sp);
|
171 |
+
__syncthreads();
|
172 |
+
|
173 |
+
if (threadIdx.x == 0) {
|
174 |
+
edz[plane] = res.v1;
|
175 |
+
eydz[plane] = res.v2;
|
176 |
+
}
|
177 |
+
}
|
178 |
+
|
179 |
+
std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
|
180 |
+
bool affine, float eps) {
|
181 |
+
CHECK_CUDA_INPUT(z);
|
182 |
+
CHECK_CUDA_INPUT(dz);
|
183 |
+
CHECK_CUDA_INPUT(weight);
|
184 |
+
CHECK_CUDA_INPUT(bias);
|
185 |
+
|
186 |
+
// Extract dimensions
|
187 |
+
int64_t num, chn, sp;
|
188 |
+
get_dims(z, num, chn, sp);
|
189 |
+
|
190 |
+
auto edz = at::empty({chn}, z.options());
|
191 |
+
auto eydz = at::empty({chn}, z.options());
|
192 |
+
|
193 |
+
// Run kernel
|
194 |
+
dim3 blocks(chn);
|
195 |
+
dim3 threads(getNumThreads(sp));
|
196 |
+
auto stream = at::cuda::getCurrentCUDAStream();
|
197 |
+
AT_DISPATCH_FLOATING_TYPES(z.type(), "edz_eydz_cuda", ([&] {
|
198 |
+
edz_eydz_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
|
199 |
+
z.data<scalar_t>(),
|
200 |
+
dz.data<scalar_t>(),
|
201 |
+
weight.data<scalar_t>(),
|
202 |
+
bias.data<scalar_t>(),
|
203 |
+
edz.data<scalar_t>(),
|
204 |
+
eydz.data<scalar_t>(),
|
205 |
+
affine, eps, num, chn, sp);
|
206 |
+
}));
|
207 |
+
|
208 |
+
return {edz, eydz};
|
209 |
+
}
|
210 |
+
|
211 |
+
/***********
|
212 |
+
* backward
|
213 |
+
***********/
|
214 |
+
|
215 |
+
template<typename T>
|
216 |
+
__global__ void backward_kernel(const T *z, const T *dz, const T *var, const T *weight, const T *bias, const T *edz,
|
217 |
+
const T *eydz, T *dx, bool affine, float eps, int num, int chn, int sp) {
|
218 |
+
int plane = blockIdx.x;
|
219 |
+
|
220 |
+
T _weight = affine ? abs(weight[plane]) + eps : 1.f;
|
221 |
+
T _bias = affine ? bias[plane] : 0.f;
|
222 |
+
T _var = var[plane];
|
223 |
+
T _edz = edz[plane];
|
224 |
+
T _eydz = eydz[plane];
|
225 |
+
|
226 |
+
T _mul = _weight * rsqrt(_var + eps);
|
227 |
+
T count = T(num * sp);
|
228 |
+
|
229 |
+
for (int batch = 0; batch < num; ++batch) {
|
230 |
+
for (int n = threadIdx.x; n < sp; n += blockDim.x) {
|
231 |
+
T _dz = dz[(batch * chn + plane) * sp + n];
|
232 |
+
T _y = (z[(batch * chn + plane) * sp + n] - _bias) / _weight;
|
233 |
+
|
234 |
+
dx[(batch * chn + plane) * sp + n] = (_dz - _edz / count - _y * _eydz / count) * _mul;
|
235 |
+
}
|
236 |
+
}
|
237 |
+
}
|
238 |
+
|
239 |
+
at::Tensor backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
|
240 |
+
at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
|
241 |
+
CHECK_CUDA_INPUT(z);
|
242 |
+
CHECK_CUDA_INPUT(dz);
|
243 |
+
CHECK_CUDA_INPUT(var);
|
244 |
+
CHECK_CUDA_INPUT(weight);
|
245 |
+
CHECK_CUDA_INPUT(bias);
|
246 |
+
CHECK_CUDA_INPUT(edz);
|
247 |
+
CHECK_CUDA_INPUT(eydz);
|
248 |
+
|
249 |
+
// Extract dimensions
|
250 |
+
int64_t num, chn, sp;
|
251 |
+
get_dims(z, num, chn, sp);
|
252 |
+
|
253 |
+
auto dx = at::zeros_like(z);
|
254 |
+
|
255 |
+
// Run kernel
|
256 |
+
dim3 blocks(chn);
|
257 |
+
dim3 threads(getNumThreads(sp));
|
258 |
+
auto stream = at::cuda::getCurrentCUDAStream();
|
259 |
+
AT_DISPATCH_FLOATING_TYPES(z.type(), "backward_cuda", ([&] {
|
260 |
+
backward_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
|
261 |
+
z.data<scalar_t>(),
|
262 |
+
dz.data<scalar_t>(),
|
263 |
+
var.data<scalar_t>(),
|
264 |
+
weight.data<scalar_t>(),
|
265 |
+
bias.data<scalar_t>(),
|
266 |
+
edz.data<scalar_t>(),
|
267 |
+
eydz.data<scalar_t>(),
|
268 |
+
dx.data<scalar_t>(),
|
269 |
+
affine, eps, num, chn, sp);
|
270 |
+
}));
|
271 |
+
|
272 |
+
return dx;
|
273 |
+
}
|
274 |
+
|
275 |
+
/**************
|
276 |
+
* activations
|
277 |
+
**************/
|
278 |
+
|
279 |
+
template<typename T>
|
280 |
+
inline void leaky_relu_backward_impl(T *z, T *dz, float slope, int64_t count) {
|
281 |
+
// Create thrust pointers
|
282 |
+
thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
|
283 |
+
thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);
|
284 |
+
|
285 |
+
auto stream = at::cuda::getCurrentCUDAStream();
|
286 |
+
thrust::transform_if(thrust::cuda::par.on(stream),
|
287 |
+
th_dz, th_dz + count, th_z, th_dz,
|
288 |
+
[slope] __device__ (const T& dz) { return dz * slope; },
|
289 |
+
[] __device__ (const T& z) { return z < 0; });
|
290 |
+
thrust::transform_if(thrust::cuda::par.on(stream),
|
291 |
+
th_z, th_z + count, th_z,
|
292 |
+
[slope] __device__ (const T& z) { return z / slope; },
|
293 |
+
[] __device__ (const T& z) { return z < 0; });
|
294 |
+
}
|
295 |
+
|
296 |
+
void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope) {
|
297 |
+
CHECK_CUDA_INPUT(z);
|
298 |
+
CHECK_CUDA_INPUT(dz);
|
299 |
+
|
300 |
+
int64_t count = z.numel();
|
301 |
+
|
302 |
+
AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cuda", ([&] {
|
303 |
+
leaky_relu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), slope, count);
|
304 |
+
}));
|
305 |
+
}
|
306 |
+
|
307 |
+
template<typename T>
|
308 |
+
inline void elu_backward_impl(T *z, T *dz, int64_t count) {
|
309 |
+
// Create thrust pointers
|
310 |
+
thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
|
311 |
+
thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);
|
312 |
+
|
313 |
+
auto stream = at::cuda::getCurrentCUDAStream();
|
314 |
+
thrust::transform_if(thrust::cuda::par.on(stream),
|
315 |
+
th_dz, th_dz + count, th_z, th_z, th_dz,
|
316 |
+
[] __device__ (const T& dz, const T& z) { return dz * (z + 1.); },
|
317 |
+
[] __device__ (const T& z) { return z < 0; });
|
318 |
+
thrust::transform_if(thrust::cuda::par.on(stream),
|
319 |
+
th_z, th_z + count, th_z,
|
320 |
+
[] __device__ (const T& z) { return log1p(z); },
|
321 |
+
[] __device__ (const T& z) { return z < 0; });
|
322 |
+
}
|
323 |
+
|
324 |
+
void elu_backward_cuda(at::Tensor z, at::Tensor dz) {
|
325 |
+
CHECK_CUDA_INPUT(z);
|
326 |
+
CHECK_CUDA_INPUT(dz);
|
327 |
+
|
328 |
+
int64_t count = z.numel();
|
329 |
+
|
330 |
+
AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cuda", ([&] {
|
331 |
+
elu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), count);
|
332 |
+
}));
|
333 |
+
}
|
ConsistentID/lib/BiSeNet/modules/src/inplace_abn_cuda_half.cu
ADDED
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <ATen/ATen.h>
|
2 |
+
|
3 |
+
#include <cuda_fp16.h>
|
4 |
+
|
5 |
+
#include <vector>
|
6 |
+
|
7 |
+
#include "utils/checks.h"
|
8 |
+
#include "utils/cuda.cuh"
|
9 |
+
#include "inplace_abn.h"
|
10 |
+
|
11 |
+
#include <ATen/cuda/CUDAContext.h>
|
12 |
+
|
13 |
+
// Operations for reduce
|
14 |
+
struct SumOpH {
|
15 |
+
__device__ SumOpH(const half *t, int c, int s)
|
16 |
+
: tensor(t), chn(c), sp(s) {}
|
17 |
+
__device__ __forceinline__ float operator()(int batch, int plane, int n) {
|
18 |
+
return __half2float(tensor[(batch * chn + plane) * sp + n]);
|
19 |
+
}
|
20 |
+
const half *tensor;
|
21 |
+
const int chn;
|
22 |
+
const int sp;
|
23 |
+
};
|
24 |
+
|
25 |
+
struct VarOpH {
|
26 |
+
__device__ VarOpH(float m, const half *t, int c, int s)
|
27 |
+
: mean(m), tensor(t), chn(c), sp(s) {}
|
28 |
+
__device__ __forceinline__ float operator()(int batch, int plane, int n) {
|
29 |
+
const auto t = __half2float(tensor[(batch * chn + plane) * sp + n]);
|
30 |
+
return (t - mean) * (t - mean);
|
31 |
+
}
|
32 |
+
const float mean;
|
33 |
+
const half *tensor;
|
34 |
+
const int chn;
|
35 |
+
const int sp;
|
36 |
+
};
|
37 |
+
|
38 |
+
struct GradOpH {
|
39 |
+
__device__ GradOpH(float _weight, float _bias, const half *_z, const half *_dz, int c, int s)
|
40 |
+
: weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {}
|
41 |
+
__device__ __forceinline__ Pair<float> operator()(int batch, int plane, int n) {
|
42 |
+
float _y = (__half2float(z[(batch * chn + plane) * sp + n]) - bias) / weight;
|
43 |
+
float _dz = __half2float(dz[(batch * chn + plane) * sp + n]);
|
44 |
+
return Pair<float>(_dz, _y * _dz);
|
45 |
+
}
|
46 |
+
const float weight;
|
47 |
+
const float bias;
|
48 |
+
const half *z;
|
49 |
+
const half *dz;
|
50 |
+
const int chn;
|
51 |
+
const int sp;
|
52 |
+
};
|
53 |
+
|
54 |
+
/***********
|
55 |
+
* mean_var
|
56 |
+
***********/
|
57 |
+
|
58 |
+
__global__ void mean_var_kernel_h(const half *x, float *mean, float *var, int num, int chn, int sp) {
|
59 |
+
int plane = blockIdx.x;
|
60 |
+
float norm = 1.f / static_cast<float>(num * sp);
|
61 |
+
|
62 |
+
float _mean = reduce<float, SumOpH>(SumOpH(x, chn, sp), plane, num, sp) * norm;
|
63 |
+
__syncthreads();
|
64 |
+
float _var = reduce<float, VarOpH>(VarOpH(_mean, x, chn, sp), plane, num, sp) * norm;
|
65 |
+
|
66 |
+
if (threadIdx.x == 0) {
|
67 |
+
mean[plane] = _mean;
|
68 |
+
var[plane] = _var;
|
69 |
+
}
|
70 |
+
}
|
71 |
+
|
72 |
+
std::vector<at::Tensor> mean_var_cuda_h(at::Tensor x) {
|
73 |
+
CHECK_CUDA_INPUT(x);
|
74 |
+
|
75 |
+
// Extract dimensions
|
76 |
+
int64_t num, chn, sp;
|
77 |
+
get_dims(x, num, chn, sp);
|
78 |
+
|
79 |
+
// Prepare output tensors
|
80 |
+
auto mean = at::empty({chn},x.options().dtype(at::kFloat));
|
81 |
+
auto var = at::empty({chn},x.options().dtype(at::kFloat));
|
82 |
+
|
83 |
+
// Run kernel
|
84 |
+
dim3 blocks(chn);
|
85 |
+
dim3 threads(getNumThreads(sp));
|
86 |
+
auto stream = at::cuda::getCurrentCUDAStream();
|
87 |
+
mean_var_kernel_h<<<blocks, threads, 0, stream>>>(
|
88 |
+
reinterpret_cast<half*>(x.data<at::Half>()),
|
89 |
+
mean.data<float>(),
|
90 |
+
var.data<float>(),
|
91 |
+
num, chn, sp);
|
92 |
+
|
93 |
+
return {mean, var};
|
94 |
+
}
|
95 |
+
|
96 |
+
/**********
|
97 |
+
* forward
|
98 |
+
**********/
|
99 |
+
|
100 |
+
__global__ void forward_kernel_h(half *x, const float *mean, const float *var, const float *weight, const float *bias,
|
101 |
+
bool affine, float eps, int num, int chn, int sp) {
|
102 |
+
int plane = blockIdx.x;
|
103 |
+
|
104 |
+
const float _mean = mean[plane];
|
105 |
+
const float _var = var[plane];
|
106 |
+
const float _weight = affine ? abs(weight[plane]) + eps : 1.f;
|
107 |
+
const float _bias = affine ? bias[plane] : 0.f;
|
108 |
+
|
109 |
+
const float mul = rsqrt(_var + eps) * _weight;
|
110 |
+
|
111 |
+
for (int batch = 0; batch < num; ++batch) {
|
112 |
+
for (int n = threadIdx.x; n < sp; n += blockDim.x) {
|
113 |
+
half *x_ptr = x + (batch * chn + plane) * sp + n;
|
114 |
+
float _x = __half2float(*x_ptr);
|
115 |
+
float _y = (_x - _mean) * mul + _bias;
|
116 |
+
|
117 |
+
*x_ptr = __float2half(_y);
|
118 |
+
}
|
119 |
+
}
|
120 |
+
}
|
121 |
+
|
122 |
+
at::Tensor forward_cuda_h(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
|
123 |
+
bool affine, float eps) {
|
124 |
+
CHECK_CUDA_INPUT(x);
|
125 |
+
CHECK_CUDA_INPUT(mean);
|
126 |
+
CHECK_CUDA_INPUT(var);
|
127 |
+
CHECK_CUDA_INPUT(weight);
|
128 |
+
CHECK_CUDA_INPUT(bias);
|
129 |
+
|
130 |
+
// Extract dimensions
|
131 |
+
int64_t num, chn, sp;
|
132 |
+
get_dims(x, num, chn, sp);
|
133 |
+
|
134 |
+
// Run kernel
|
135 |
+
dim3 blocks(chn);
|
136 |
+
dim3 threads(getNumThreads(sp));
|
137 |
+
auto stream = at::cuda::getCurrentCUDAStream();
|
138 |
+
forward_kernel_h<<<blocks, threads, 0, stream>>>(
|
139 |
+
reinterpret_cast<half*>(x.data<at::Half>()),
|
140 |
+
mean.data<float>(),
|
141 |
+
var.data<float>(),
|
142 |
+
weight.data<float>(),
|
143 |
+
bias.data<float>(),
|
144 |
+
affine, eps, num, chn, sp);
|
145 |
+
|
146 |
+
return x;
|
147 |
+
}
|
148 |
+
|
149 |
+
__global__ void edz_eydz_kernel_h(const half *z, const half *dz, const float *weight, const float *bias,
|
150 |
+
float *edz, float *eydz, bool affine, float eps, int num, int chn, int sp) {
|
151 |
+
int plane = blockIdx.x;
|
152 |
+
|
153 |
+
float _weight = affine ? abs(weight[plane]) + eps : 1.f;
|
154 |
+
float _bias = affine ? bias[plane] : 0.f;
|
155 |
+
|
156 |
+
Pair<float> res = reduce<Pair<float>, GradOpH>(GradOpH(_weight, _bias, z, dz, chn, sp), plane, num, sp);
|
157 |
+
__syncthreads();
|
158 |
+
|
159 |
+
if (threadIdx.x == 0) {
|
160 |
+
edz[plane] = res.v1;
|
161 |
+
eydz[plane] = res.v2;
|
162 |
+
}
|
163 |
+
}
|
164 |
+
|
165 |
+
std::vector<at::Tensor> edz_eydz_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
|
166 |
+
bool affine, float eps) {
|
167 |
+
CHECK_CUDA_INPUT(z);
|
168 |
+
CHECK_CUDA_INPUT(dz);
|
169 |
+
CHECK_CUDA_INPUT(weight);
|
170 |
+
CHECK_CUDA_INPUT(bias);
|
171 |
+
|
172 |
+
// Extract dimensions
|
173 |
+
int64_t num, chn, sp;
|
174 |
+
get_dims(z, num, chn, sp);
|
175 |
+
|
176 |
+
auto edz = at::empty({chn},z.options().dtype(at::kFloat));
|
177 |
+
auto eydz = at::empty({chn},z.options().dtype(at::kFloat));
|
178 |
+
|
179 |
+
// Run kernel
|
180 |
+
dim3 blocks(chn);
|
181 |
+
dim3 threads(getNumThreads(sp));
|
182 |
+
auto stream = at::cuda::getCurrentCUDAStream();
|
183 |
+
edz_eydz_kernel_h<<<blocks, threads, 0, stream>>>(
|
184 |
+
reinterpret_cast<half*>(z.data<at::Half>()),
|
185 |
+
reinterpret_cast<half*>(dz.data<at::Half>()),
|
186 |
+
weight.data<float>(),
|
187 |
+
bias.data<float>(),
|
188 |
+
edz.data<float>(),
|
189 |
+
eydz.data<float>(),
|
190 |
+
affine, eps, num, chn, sp);
|
191 |
+
|
192 |
+
return {edz, eydz};
|
193 |
+
}
|
194 |
+
|
195 |
+
__global__ void backward_kernel_h(const half *z, const half *dz, const float *var, const float *weight, const float *bias, const float *edz,
|
196 |
+
const float *eydz, half *dx, bool affine, float eps, int num, int chn, int sp) {
|
197 |
+
int plane = blockIdx.x;
|
198 |
+
|
199 |
+
float _weight = affine ? abs(weight[plane]) + eps : 1.f;
|
200 |
+
float _bias = affine ? bias[plane] : 0.f;
|
201 |
+
float _var = var[plane];
|
202 |
+
float _edz = edz[plane];
|
203 |
+
float _eydz = eydz[plane];
|
204 |
+
|
205 |
+
float _mul = _weight * rsqrt(_var + eps);
|
206 |
+
float count = float(num * sp);
|
207 |
+
|
208 |
+
for (int batch = 0; batch < num; ++batch) {
|
209 |
+
for (int n = threadIdx.x; n < sp; n += blockDim.x) {
|
210 |
+
float _dz = __half2float(dz[(batch * chn + plane) * sp + n]);
|
211 |
+
float _y = (__half2float(z[(batch * chn + plane) * sp + n]) - _bias) / _weight;
|
212 |
+
|
213 |
+
dx[(batch * chn + plane) * sp + n] = __float2half((_dz - _edz / count - _y * _eydz / count) * _mul);
|
214 |
+
}
|
215 |
+
}
|
216 |
+
}
|
217 |
+
|
218 |
+
at::Tensor backward_cuda_h(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
|
219 |
+
at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
|
220 |
+
CHECK_CUDA_INPUT(z);
|
221 |
+
CHECK_CUDA_INPUT(dz);
|
222 |
+
CHECK_CUDA_INPUT(var);
|
223 |
+
CHECK_CUDA_INPUT(weight);
|
224 |
+
CHECK_CUDA_INPUT(bias);
|
225 |
+
CHECK_CUDA_INPUT(edz);
|
226 |
+
CHECK_CUDA_INPUT(eydz);
|
227 |
+
|
228 |
+
// Extract dimensions
|
229 |
+
int64_t num, chn, sp;
|
230 |
+
get_dims(z, num, chn, sp);
|
231 |
+
|
232 |
+
auto dx = at::zeros_like(z);
|
233 |
+
|
234 |
+
// Run kernel
|
235 |
+
dim3 blocks(chn);
|
236 |
+
dim3 threads(getNumThreads(sp));
|
237 |
+
auto stream = at::cuda::getCurrentCUDAStream();
|
238 |
+
backward_kernel_h<<<blocks, threads, 0, stream>>>(
|
239 |
+
reinterpret_cast<half*>(z.data<at::Half>()),
|
240 |
+
reinterpret_cast<half*>(dz.data<at::Half>()),
|
241 |
+
var.data<float>(),
|
242 |
+
weight.data<float>(),
|
243 |
+
bias.data<float>(),
|
244 |
+
edz.data<float>(),
|
245 |
+
eydz.data<float>(),
|
246 |
+
reinterpret_cast<half*>(dx.data<at::Half>()),
|
247 |
+
affine, eps, num, chn, sp);
|
248 |
+
|
249 |
+
return dx;
|
250 |
+
}
|
251 |
+
|
252 |
+
__global__ void leaky_relu_backward_impl_h(half *z, half *dz, float slope, int64_t count) {
|
253 |
+
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x){
|
254 |
+
float _z = __half2float(z[i]);
|
255 |
+
if (_z < 0) {
|
256 |
+
dz[i] = __float2half(__half2float(dz[i]) * slope);
|
257 |
+
z[i] = __float2half(_z / slope);
|
258 |
+
}
|
259 |
+
}
|
260 |
+
}
|
261 |
+
|
262 |
+
void leaky_relu_backward_cuda_h(at::Tensor z, at::Tensor dz, float slope) {
|
263 |
+
CHECK_CUDA_INPUT(z);
|
264 |
+
CHECK_CUDA_INPUT(dz);
|
265 |
+
|
266 |
+
int64_t count = z.numel();
|
267 |
+
dim3 threads(getNumThreads(count));
|
268 |
+
dim3 blocks = (count + threads.x - 1) / threads.x;
|
269 |
+
auto stream = at::cuda::getCurrentCUDAStream();
|
270 |
+
leaky_relu_backward_impl_h<<<blocks, threads, 0, stream>>>(
|
271 |
+
reinterpret_cast<half*>(z.data<at::Half>()),
|
272 |
+
reinterpret_cast<half*>(dz.data<at::Half>()),
|
273 |
+
slope, count);
|
274 |
+
}
|
275 |
+
|
ConsistentID/lib/BiSeNet/modules/src/utils/checks.h
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include <ATen/ATen.h>
|
4 |
+
|
5 |
+
// Define AT_CHECK for old version of ATen where the same function was called AT_ASSERT
|
6 |
+
#ifndef AT_CHECK
|
7 |
+
#define AT_CHECK AT_ASSERT
|
8 |
+
#endif
|
9 |
+
|
10 |
+
#define CHECK_CUDA(x) AT_CHECK((x).type().is_cuda(), #x " must be a CUDA tensor")
|
11 |
+
#define CHECK_CPU(x) AT_CHECK(!(x).type().is_cuda(), #x " must be a CPU tensor")
|
12 |
+
#define CHECK_CONTIGUOUS(x) AT_CHECK((x).is_contiguous(), #x " must be contiguous")
|
13 |
+
|
14 |
+
#define CHECK_CUDA_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
15 |
+
#define CHECK_CPU_INPUT(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x)
|
ConsistentID/lib/BiSeNet/modules/src/utils/common.h
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
#include <ATen/ATen.h>
|
4 |
+
|
5 |
+
/*
|
6 |
+
* Functions to share code between CPU and GPU
|
7 |
+
*/
|
8 |
+
|
9 |
+
#ifdef __CUDACC__
|
10 |
+
// CUDA versions
|
11 |
+
|
12 |
+
#define HOST_DEVICE __host__ __device__
|
13 |
+
#define INLINE_HOST_DEVICE __host__ __device__ inline
|
14 |
+
#define FLOOR(x) floor(x)
|
15 |
+
|
16 |
+
#if __CUDA_ARCH__ >= 600
|
17 |
+
// Recent compute capabilities have block-level atomicAdd for all data types, so we use that
|
18 |
+
#define ACCUM(x,y) atomicAdd_block(&(x),(y))
|
19 |
+
#else
|
20 |
+
// Older architectures don't have block-level atomicAdd, nor atomicAdd for doubles, so we defer to atomicAdd for float
|
21 |
+
// and use the known atomicCAS-based implementation for double
|
22 |
+
template<typename data_t>
|
23 |
+
__device__ inline data_t atomic_add(data_t *address, data_t val) {
|
24 |
+
return atomicAdd(address, val);
|
25 |
+
}
|
26 |
+
|
27 |
+
template<>
|
28 |
+
__device__ inline double atomic_add(double *address, double val) {
|
29 |
+
unsigned long long int* address_as_ull = (unsigned long long int*)address;
|
30 |
+
unsigned long long int old = *address_as_ull, assumed;
|
31 |
+
do {
|
32 |
+
assumed = old;
|
33 |
+
old = atomicCAS(address_as_ull, assumed, __double_as_longlong(val + __longlong_as_double(assumed)));
|
34 |
+
} while (assumed != old);
|
35 |
+
return __longlong_as_double(old);
|
36 |
+
}
|
37 |
+
|
38 |
+
#define ACCUM(x,y) atomic_add(&(x),(y))
|
39 |
+
#endif // #if __CUDA_ARCH__ >= 600
|
40 |
+
|
41 |
+
#else
|
42 |
+
// CPU versions
|
43 |
+
|
44 |
+
#define HOST_DEVICE
|
45 |
+
#define INLINE_HOST_DEVICE inline
|
46 |
+
#define FLOOR(x) std::floor(x)
|
47 |
+
#define ACCUM(x,y) (x) += (y)
|
48 |
+
|
49 |
+
#endif // #ifdef __CUDACC__
|
ConsistentID/lib/BiSeNet/modules/src/utils/cuda.cuh
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once
|
2 |
+
|
3 |
+
/*
|
4 |
+
* General settings and functions
|
5 |
+
*/
|
6 |
+
const int WARP_SIZE = 32;
|
7 |
+
const int MAX_BLOCK_SIZE = 1024;
|
8 |
+
|
9 |
+
static int getNumThreads(int nElem) {
|
10 |
+
int threadSizes[6] = {32, 64, 128, 256, 512, MAX_BLOCK_SIZE};
|
11 |
+
for (int i = 0; i < 6; ++i) {
|
12 |
+
if (nElem <= threadSizes[i]) {
|
13 |
+
return threadSizes[i];
|
14 |
+
}
|
15 |
+
}
|
16 |
+
return MAX_BLOCK_SIZE;
|
17 |
+
}
|
18 |
+
|
19 |
+
/*
|
20 |
+
* Reduction utilities
|
21 |
+
*/
|
22 |
+
template <typename T>
|
23 |
+
__device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize,
|
24 |
+
unsigned int mask = 0xffffffff) {
|
25 |
+
#if CUDART_VERSION >= 9000
|
26 |
+
return __shfl_xor_sync(mask, value, laneMask, width);
|
27 |
+
#else
|
28 |
+
return __shfl_xor(value, laneMask, width);
|
29 |
+
#endif
|
30 |
+
}
|
31 |
+
|
32 |
+
__device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); }
|
33 |
+
|
34 |
+
template<typename T>
|
35 |
+
struct Pair {
|
36 |
+
T v1, v2;
|
37 |
+
__device__ Pair() {}
|
38 |
+
__device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {}
|
39 |
+
__device__ Pair(T v) : v1(v), v2(v) {}
|
40 |
+
__device__ Pair(int v) : v1(v), v2(v) {}
|
41 |
+
__device__ Pair &operator+=(const Pair<T> &a) {
|
42 |
+
v1 += a.v1;
|
43 |
+
v2 += a.v2;
|
44 |
+
return *this;
|
45 |
+
}
|
46 |
+
};
|
47 |
+
|
48 |
+
template<typename T>
|
49 |
+
static __device__ __forceinline__ T warpSum(T val) {
|
50 |
+
#if __CUDA_ARCH__ >= 300
|
51 |
+
for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
|
52 |
+
val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE);
|
53 |
+
}
|
54 |
+
#else
|
55 |
+
__shared__ T values[MAX_BLOCK_SIZE];
|
56 |
+
values[threadIdx.x] = val;
|
57 |
+
__threadfence_block();
|
58 |
+
const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE;
|
59 |
+
for (int i = 1; i < WARP_SIZE; i++) {
|
60 |
+
val += values[base + ((i + threadIdx.x) % WARP_SIZE)];
|
61 |
+
}
|
62 |
+
#endif
|
63 |
+
return val;
|
64 |
+
}
|
65 |
+
|
66 |
+
template<typename T>
|
67 |
+
static __device__ __forceinline__ Pair<T> warpSum(Pair<T> value) {
|
68 |
+
value.v1 = warpSum(value.v1);
|
69 |
+
value.v2 = warpSum(value.v2);
|
70 |
+
return value;
|
71 |
+
}
|
ConsistentID/lib/BiSeNet/optimizer.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import logging
|
7 |
+
|
8 |
+
logger = logging.getLogger()
|
9 |
+
|
10 |
+
class Optimizer(object):
|
11 |
+
def __init__(self,
|
12 |
+
model,
|
13 |
+
lr0,
|
14 |
+
momentum,
|
15 |
+
wd,
|
16 |
+
warmup_steps,
|
17 |
+
warmup_start_lr,
|
18 |
+
max_iter,
|
19 |
+
power,
|
20 |
+
*args, **kwargs):
|
21 |
+
self.warmup_steps = warmup_steps
|
22 |
+
self.warmup_start_lr = warmup_start_lr
|
23 |
+
self.lr0 = lr0
|
24 |
+
self.lr = self.lr0
|
25 |
+
self.max_iter = float(max_iter)
|
26 |
+
self.power = power
|
27 |
+
self.it = 0
|
28 |
+
wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = model.get_params()
|
29 |
+
param_list = [
|
30 |
+
{'params': wd_params},
|
31 |
+
{'params': nowd_params, 'weight_decay': 0},
|
32 |
+
{'params': lr_mul_wd_params, 'lr_mul': True},
|
33 |
+
{'params': lr_mul_nowd_params, 'weight_decay': 0, 'lr_mul': True}]
|
34 |
+
self.optim = torch.optim.SGD(
|
35 |
+
param_list,
|
36 |
+
lr = lr0,
|
37 |
+
momentum = momentum,
|
38 |
+
weight_decay = wd)
|
39 |
+
self.warmup_factor = (self.lr0/self.warmup_start_lr)**(1./self.warmup_steps)
|
40 |
+
|
41 |
+
|
42 |
+
def get_lr(self):
|
43 |
+
if self.it <= self.warmup_steps:
|
44 |
+
lr = self.warmup_start_lr*(self.warmup_factor**self.it)
|
45 |
+
else:
|
46 |
+
factor = (1-(self.it-self.warmup_steps)/(self.max_iter-self.warmup_steps))**self.power
|
47 |
+
lr = self.lr0 * factor
|
48 |
+
return lr
|
49 |
+
|
50 |
+
|
51 |
+
def step(self):
|
52 |
+
self.lr = self.get_lr()
|
53 |
+
for pg in self.optim.param_groups:
|
54 |
+
if pg.get('lr_mul', False):
|
55 |
+
pg['lr'] = self.lr * 10
|
56 |
+
else:
|
57 |
+
pg['lr'] = self.lr
|
58 |
+
if self.optim.defaults.get('lr_mul', False):
|
59 |
+
self.optim.defaults['lr'] = self.lr * 10
|
60 |
+
else:
|
61 |
+
self.optim.defaults['lr'] = self.lr
|
62 |
+
self.it += 1
|
63 |
+
self.optim.step()
|
64 |
+
if self.it == self.warmup_steps+2:
|
65 |
+
logger.info('==> warmup done, start to implement poly lr strategy')
|
66 |
+
|
67 |
+
def zero_grad(self):
|
68 |
+
self.optim.zero_grad()
|
69 |
+
|
ConsistentID/lib/BiSeNet/prepropess_data.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
import os.path as osp
|
5 |
+
import os
|
6 |
+
import cv2
|
7 |
+
from transform import *
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
face_data = '/home/zll/data/CelebAMask-HQ/CelebA-HQ-img'
|
11 |
+
face_sep_mask = '/home/zll/data/CelebAMask-HQ/CelebAMask-HQ-mask-anno'
|
12 |
+
mask_path = '/home/zll/data/CelebAMask-HQ/mask'
|
13 |
+
counter = 0
|
14 |
+
total = 0
|
15 |
+
for i in range(15):
|
16 |
+
|
17 |
+
atts = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r',
|
18 |
+
'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat']
|
19 |
+
|
20 |
+
for j in range(i * 2000, (i + 1) * 2000):
|
21 |
+
|
22 |
+
mask = np.zeros((512, 512))
|
23 |
+
|
24 |
+
for l, att in enumerate(atts, 1):
|
25 |
+
total += 1
|
26 |
+
file_name = ''.join([str(j).rjust(5, '0'), '_', att, '.png'])
|
27 |
+
path = osp.join(face_sep_mask, str(i), file_name)
|
28 |
+
|
29 |
+
if os.path.exists(path):
|
30 |
+
counter += 1
|
31 |
+
sep_mask = np.array(Image.open(path).convert('P'))
|
32 |
+
# print(np.unique(sep_mask))
|
33 |
+
|
34 |
+
mask[sep_mask == 225] = l
|
35 |
+
cv2.imwrite('{}/{}.png'.format(mask_path, j), mask)
|
36 |
+
print(j)
|
37 |
+
|
38 |
+
print(counter, total)
|
ConsistentID/lib/BiSeNet/resnet.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torch.utils.model_zoo as modelzoo
|
8 |
+
|
9 |
+
# from modules.bn import InPlaceABNSync as BatchNorm2d
|
10 |
+
|
11 |
+
resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
|
12 |
+
|
13 |
+
|
14 |
+
def conv3x3(in_planes, out_planes, stride=1):
|
15 |
+
"""3x3 convolution with padding"""
|
16 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
17 |
+
padding=1, bias=False)
|
18 |
+
|
19 |
+
|
20 |
+
class BasicBlock(nn.Module):
|
21 |
+
def __init__(self, in_chan, out_chan, stride=1):
|
22 |
+
super(BasicBlock, self).__init__()
|
23 |
+
self.conv1 = conv3x3(in_chan, out_chan, stride)
|
24 |
+
self.bn1 = nn.BatchNorm2d(out_chan)
|
25 |
+
self.conv2 = conv3x3(out_chan, out_chan)
|
26 |
+
self.bn2 = nn.BatchNorm2d(out_chan)
|
27 |
+
self.relu = nn.ReLU(inplace=True)
|
28 |
+
self.downsample = None
|
29 |
+
if in_chan != out_chan or stride != 1:
|
30 |
+
self.downsample = nn.Sequential(
|
31 |
+
nn.Conv2d(in_chan, out_chan,
|
32 |
+
kernel_size=1, stride=stride, bias=False),
|
33 |
+
nn.BatchNorm2d(out_chan),
|
34 |
+
)
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
residual = self.conv1(x)
|
38 |
+
residual = F.relu(self.bn1(residual))
|
39 |
+
residual = self.conv2(residual)
|
40 |
+
residual = self.bn2(residual)
|
41 |
+
|
42 |
+
shortcut = x
|
43 |
+
if self.downsample is not None:
|
44 |
+
shortcut = self.downsample(x)
|
45 |
+
|
46 |
+
out = shortcut + residual
|
47 |
+
out = self.relu(out)
|
48 |
+
return out
|
49 |
+
|
50 |
+
|
51 |
+
def create_layer_basic(in_chan, out_chan, bnum, stride=1):
|
52 |
+
layers = [BasicBlock(in_chan, out_chan, stride=stride)]
|
53 |
+
for i in range(bnum-1):
|
54 |
+
layers.append(BasicBlock(out_chan, out_chan, stride=1))
|
55 |
+
return nn.Sequential(*layers)
|
56 |
+
|
57 |
+
|
58 |
+
class Resnet18(nn.Module):
|
59 |
+
def __init__(self):
|
60 |
+
super(Resnet18, self).__init__()
|
61 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
|
62 |
+
bias=False)
|
63 |
+
self.bn1 = nn.BatchNorm2d(64)
|
64 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
65 |
+
self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
|
66 |
+
self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
|
67 |
+
self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
|
68 |
+
self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
|
69 |
+
self.init_weight()
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
x = self.conv1(x)
|
73 |
+
x = F.relu(self.bn1(x))
|
74 |
+
x = self.maxpool(x)
|
75 |
+
|
76 |
+
x = self.layer1(x)
|
77 |
+
feat8 = self.layer2(x) # 1/8
|
78 |
+
feat16 = self.layer3(feat8) # 1/16
|
79 |
+
feat32 = self.layer4(feat16) # 1/32
|
80 |
+
return feat8, feat16, feat32
|
81 |
+
|
82 |
+
def init_weight(self):
|
83 |
+
state_dict = modelzoo.load_url(resnet18_url)
|
84 |
+
self_state_dict = self.state_dict()
|
85 |
+
for k, v in state_dict.items():
|
86 |
+
if 'fc' in k: continue
|
87 |
+
self_state_dict.update({k: v})
|
88 |
+
self.load_state_dict(self_state_dict)
|
89 |
+
|
90 |
+
def get_params(self):
|
91 |
+
wd_params, nowd_params = [], []
|
92 |
+
for name, module in self.named_modules():
|
93 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
94 |
+
wd_params.append(module.weight)
|
95 |
+
if not module.bias is None:
|
96 |
+
nowd_params.append(module.bias)
|
97 |
+
elif isinstance(module, nn.BatchNorm2d):
|
98 |
+
nowd_params += list(module.parameters())
|
99 |
+
return wd_params, nowd_params
|
100 |
+
|
101 |
+
|
102 |
+
if __name__ == "__main__":
|
103 |
+
net = Resnet18()
|
104 |
+
x = torch.randn(16, 3, 224, 224)
|
105 |
+
out = net(x)
|
106 |
+
print(out[0].size())
|
107 |
+
print(out[1].size())
|
108 |
+
print(out[2].size())
|
109 |
+
net.get_params()
|
ConsistentID/lib/BiSeNet/test.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
from logger import setup_logger
|
5 |
+
import BiSeNet
|
6 |
+
|
7 |
+
import torch
|
8 |
+
|
9 |
+
import os
|
10 |
+
import os.path as osp
|
11 |
+
import numpy as np
|
12 |
+
from PIL import Image
|
13 |
+
import torchvision.transforms as transforms
|
14 |
+
import cv2
|
15 |
+
|
16 |
+
def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg'):
|
17 |
+
# Colors for all 20 parts
|
18 |
+
part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
|
19 |
+
[255, 0, 85], [255, 0, 170],
|
20 |
+
[0, 255, 0], [85, 255, 0], [170, 255, 0],
|
21 |
+
[0, 255, 85], [0, 255, 170],
|
22 |
+
[0, 0, 255], [85, 0, 255], [170, 0, 255],
|
23 |
+
[0, 85, 255], [0, 170, 255],
|
24 |
+
[255, 255, 0], [255, 255, 85], [255, 255, 170],
|
25 |
+
[255, 0, 255], [255, 85, 255], [255, 170, 255],
|
26 |
+
[0, 255, 255], [85, 255, 255], [170, 255, 255]]
|
27 |
+
|
28 |
+
im = np.array(im)
|
29 |
+
vis_im = im.copy().astype(np.uint8)
|
30 |
+
vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
|
31 |
+
vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
|
32 |
+
vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255
|
33 |
+
|
34 |
+
num_of_class = np.max(vis_parsing_anno)
|
35 |
+
|
36 |
+
for pi in range(1, num_of_class + 1):
|
37 |
+
index = np.where(vis_parsing_anno == pi)
|
38 |
+
vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi]
|
39 |
+
|
40 |
+
vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
|
41 |
+
# print(vis_parsing_anno_color.shape, vis_im.shape)
|
42 |
+
vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0)
|
43 |
+
|
44 |
+
# Save result or not
|
45 |
+
if save_im:
|
46 |
+
cv2.imwrite(save_path[:-4] +'.png', vis_parsing_anno)
|
47 |
+
cv2.imwrite(save_path, vis_im, [int(cv2.IMWRITE_JPEG_QUALITY), 100])
|
48 |
+
|
49 |
+
# return vis_im
|
50 |
+
|
51 |
+
def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'):
|
52 |
+
|
53 |
+
if not os.path.exists(respth):
|
54 |
+
os.makedirs(respth)
|
55 |
+
|
56 |
+
n_classes = 19
|
57 |
+
net = BiSeNet(n_classes=n_classes)
|
58 |
+
net.cuda()
|
59 |
+
save_pth = osp.join('res/cp', cp)
|
60 |
+
net.load_state_dict(torch.load(save_pth))
|
61 |
+
net.eval()
|
62 |
+
|
63 |
+
to_tensor = transforms.Compose([
|
64 |
+
transforms.ToTensor(),
|
65 |
+
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
66 |
+
])
|
67 |
+
with torch.no_grad():
|
68 |
+
for image_path in os.listdir(dspth):
|
69 |
+
img = Image.open(osp.join(dspth, image_path))
|
70 |
+
image = img.resize((512, 512), Image.BILINEAR)
|
71 |
+
img = to_tensor(image)
|
72 |
+
img = torch.unsqueeze(img, 0)
|
73 |
+
img = img.cuda()
|
74 |
+
out = net(img)[0]
|
75 |
+
parsing = out.squeeze(0).cpu().numpy().argmax(0)
|
76 |
+
# print(parsing)
|
77 |
+
print(np.unique(parsing))
|
78 |
+
|
79 |
+
vis_parsing_maps(image, parsing, stride=1, save_im=True, save_path=osp.join(respth, image_path))
|
80 |
+
|
81 |
+
|
82 |
+
|
83 |
+
|
84 |
+
|
85 |
+
|
86 |
+
|
87 |
+
if __name__ == "__main__":
|
88 |
+
evaluate(dspth='/home/zll/data/CelebAMask-HQ/test-img', cp='79999_iter.pth')
|
89 |
+
|
90 |
+
|
ConsistentID/lib/BiSeNet/train.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
from logger import setup_logger
|
5 |
+
import BiSeNet
|
6 |
+
from face_dataset import FaceMask
|
7 |
+
from loss import OhemCELoss
|
8 |
+
from evaluate import evaluate
|
9 |
+
from optimizer import Optimizer
|
10 |
+
import cv2
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
from torch.utils.data import DataLoader
|
16 |
+
import torch.nn.functional as F
|
17 |
+
import torch.distributed as dist
|
18 |
+
|
19 |
+
import os
|
20 |
+
import os.path as osp
|
21 |
+
import logging
|
22 |
+
import time
|
23 |
+
import datetime
|
24 |
+
import argparse
|
25 |
+
|
26 |
+
|
27 |
+
respth = './res'
|
28 |
+
if not osp.exists(respth):
|
29 |
+
os.makedirs(respth)
|
30 |
+
logger = logging.getLogger()
|
31 |
+
|
32 |
+
|
33 |
+
def parse_args():
|
34 |
+
parse = argparse.ArgumentParser()
|
35 |
+
parse.add_argument(
|
36 |
+
'--local_rank',
|
37 |
+
dest = 'local_rank',
|
38 |
+
type = int,
|
39 |
+
default = -1,
|
40 |
+
)
|
41 |
+
return parse.parse_args()
|
42 |
+
|
43 |
+
|
44 |
+
def train():
|
45 |
+
args = parse_args()
|
46 |
+
torch.cuda.set_device(args.local_rank)
|
47 |
+
dist.init_process_group(
|
48 |
+
backend = 'nccl',
|
49 |
+
init_method = 'tcp://127.0.0.1:33241',
|
50 |
+
world_size = torch.cuda.device_count(),
|
51 |
+
rank=args.local_rank
|
52 |
+
)
|
53 |
+
setup_logger(respth)
|
54 |
+
|
55 |
+
# dataset
|
56 |
+
n_classes = 19
|
57 |
+
n_img_per_gpu = 16
|
58 |
+
n_workers = 8
|
59 |
+
cropsize = [448, 448]
|
60 |
+
data_root = '/home/zll/data/CelebAMask-HQ/'
|
61 |
+
|
62 |
+
ds = FaceMask(data_root, cropsize=cropsize, mode='train')
|
63 |
+
sampler = torch.utils.data.distributed.DistributedSampler(ds)
|
64 |
+
dl = DataLoader(ds,
|
65 |
+
batch_size = n_img_per_gpu,
|
66 |
+
shuffle = False,
|
67 |
+
sampler = sampler,
|
68 |
+
num_workers = n_workers,
|
69 |
+
pin_memory = True,
|
70 |
+
drop_last = True)
|
71 |
+
|
72 |
+
# model
|
73 |
+
ignore_idx = -100
|
74 |
+
net = BiSeNet(n_classes=n_classes)
|
75 |
+
net.cuda()
|
76 |
+
net.train()
|
77 |
+
net = nn.parallel.DistributedDataParallel(net,
|
78 |
+
device_ids = [args.local_rank, ],
|
79 |
+
output_device = args.local_rank
|
80 |
+
)
|
81 |
+
score_thres = 0.7
|
82 |
+
n_min = n_img_per_gpu * cropsize[0] * cropsize[1]//16
|
83 |
+
LossP = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
|
84 |
+
Loss2 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
|
85 |
+
Loss3 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
|
86 |
+
|
87 |
+
## optimizer
|
88 |
+
momentum = 0.9
|
89 |
+
weight_decay = 5e-4
|
90 |
+
lr_start = 1e-2
|
91 |
+
max_iter = 80000
|
92 |
+
power = 0.9
|
93 |
+
warmup_steps = 1000
|
94 |
+
warmup_start_lr = 1e-5
|
95 |
+
optim = Optimizer(
|
96 |
+
model = net.module,
|
97 |
+
lr0 = lr_start,
|
98 |
+
momentum = momentum,
|
99 |
+
wd = weight_decay,
|
100 |
+
warmup_steps = warmup_steps,
|
101 |
+
warmup_start_lr = warmup_start_lr,
|
102 |
+
max_iter = max_iter,
|
103 |
+
power = power)
|
104 |
+
|
105 |
+
## train loop
|
106 |
+
msg_iter = 50
|
107 |
+
loss_avg = []
|
108 |
+
st = glob_st = time.time()
|
109 |
+
diter = iter(dl)
|
110 |
+
epoch = 0
|
111 |
+
for it in range(max_iter):
|
112 |
+
try:
|
113 |
+
im, lb = next(diter)
|
114 |
+
if not im.size()[0] == n_img_per_gpu:
|
115 |
+
raise StopIteration
|
116 |
+
except StopIteration:
|
117 |
+
epoch += 1
|
118 |
+
sampler.set_epoch(epoch)
|
119 |
+
diter = iter(dl)
|
120 |
+
im, lb = next(diter)
|
121 |
+
im = im.cuda()
|
122 |
+
lb = lb.cuda()
|
123 |
+
H, W = im.size()[2:]
|
124 |
+
lb = torch.squeeze(lb, 1)
|
125 |
+
|
126 |
+
optim.zero_grad()
|
127 |
+
out, out16, out32 = net(im)
|
128 |
+
lossp = LossP(out, lb)
|
129 |
+
loss2 = Loss2(out16, lb)
|
130 |
+
loss3 = Loss3(out32, lb)
|
131 |
+
loss = lossp + loss2 + loss3
|
132 |
+
loss.backward()
|
133 |
+
optim.step()
|
134 |
+
|
135 |
+
loss_avg.append(loss.item())
|
136 |
+
|
137 |
+
# print training log message
|
138 |
+
if (it+1) % msg_iter == 0:
|
139 |
+
loss_avg = sum(loss_avg) / len(loss_avg)
|
140 |
+
lr = optim.lr
|
141 |
+
ed = time.time()
|
142 |
+
t_intv, glob_t_intv = ed - st, ed - glob_st
|
143 |
+
eta = int((max_iter - it) * (glob_t_intv / it))
|
144 |
+
eta = str(datetime.timedelta(seconds=eta))
|
145 |
+
msg = ', '.join([
|
146 |
+
'it: {it}/{max_it}',
|
147 |
+
'lr: {lr:4f}',
|
148 |
+
'loss: {loss:.4f}',
|
149 |
+
'eta: {eta}',
|
150 |
+
'time: {time:.4f}',
|
151 |
+
]).format(
|
152 |
+
it = it+1,
|
153 |
+
max_it = max_iter,
|
154 |
+
lr = lr,
|
155 |
+
loss = loss_avg,
|
156 |
+
time = t_intv,
|
157 |
+
eta = eta
|
158 |
+
)
|
159 |
+
logger.info(msg)
|
160 |
+
loss_avg = []
|
161 |
+
st = ed
|
162 |
+
if dist.get_rank() == 0:
|
163 |
+
if (it+1) % 5000 == 0:
|
164 |
+
state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
|
165 |
+
if dist.get_rank() == 0:
|
166 |
+
torch.save(state, './res/cp/{}_iter.pth'.format(it))
|
167 |
+
evaluate(dspth='/home/zll/data/CelebAMask-HQ/test-img', cp='{}_iter.pth'.format(it))
|
168 |
+
|
169 |
+
# dump the final model
|
170 |
+
save_pth = osp.join(respth, 'model_final_diss.pth')
|
171 |
+
# net.cpu()
|
172 |
+
state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
|
173 |
+
if dist.get_rank() == 0:
|
174 |
+
torch.save(state, save_pth)
|
175 |
+
logger.info('training done, model saved to: {}'.format(save_pth))
|
176 |
+
|
177 |
+
|
178 |
+
if __name__ == "__main__":
|
179 |
+
train()
|
ConsistentID/lib/BiSeNet/transform.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
|
5 |
+
from PIL import Image
|
6 |
+
import PIL.ImageEnhance as ImageEnhance
|
7 |
+
import random
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
class RandomCrop(object):
|
11 |
+
def __init__(self, size, *args, **kwargs):
|
12 |
+
self.size = size
|
13 |
+
|
14 |
+
def __call__(self, im_lb):
|
15 |
+
im = im_lb['im']
|
16 |
+
lb = im_lb['lb']
|
17 |
+
assert im.size == lb.size
|
18 |
+
W, H = self.size
|
19 |
+
w, h = im.size
|
20 |
+
|
21 |
+
if (W, H) == (w, h): return dict(im=im, lb=lb)
|
22 |
+
if w < W or h < H:
|
23 |
+
scale = float(W) / w if w < h else float(H) / h
|
24 |
+
w, h = int(scale * w + 1), int(scale * h + 1)
|
25 |
+
im = im.resize((w, h), Image.BILINEAR)
|
26 |
+
lb = lb.resize((w, h), Image.NEAREST)
|
27 |
+
sw, sh = random.random() * (w - W), random.random() * (h - H)
|
28 |
+
crop = int(sw), int(sh), int(sw) + W, int(sh) + H
|
29 |
+
return dict(
|
30 |
+
im = im.crop(crop),
|
31 |
+
lb = lb.crop(crop)
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
class HorizontalFlip(object):
|
36 |
+
def __init__(self, p=0.5, *args, **kwargs):
|
37 |
+
self.p = p
|
38 |
+
|
39 |
+
def __call__(self, im_lb):
|
40 |
+
if random.random() > self.p:
|
41 |
+
return im_lb
|
42 |
+
else:
|
43 |
+
im = im_lb['im']
|
44 |
+
lb = im_lb['lb']
|
45 |
+
|
46 |
+
# atts = [1 'skin', 2 'l_brow', 3 'r_brow', 4 'l_eye', 5 'r_eye', 6 'eye_g', 7 'l_ear', 8 'r_ear', 9 'ear_r',
|
47 |
+
# 10 'nose', 11 'mouth', 12 'u_lip', 13 'l_lip', 14 'neck', 15 'neck_l', 16 'cloth', 17 'hair', 18 'hat']
|
48 |
+
|
49 |
+
flip_lb = np.array(lb)
|
50 |
+
flip_lb[lb == 2] = 3
|
51 |
+
flip_lb[lb == 3] = 2
|
52 |
+
flip_lb[lb == 4] = 5
|
53 |
+
flip_lb[lb == 5] = 4
|
54 |
+
flip_lb[lb == 7] = 8
|
55 |
+
flip_lb[lb == 8] = 7
|
56 |
+
flip_lb = Image.fromarray(flip_lb)
|
57 |
+
return dict(im = im.transpose(Image.FLIP_LEFT_RIGHT),
|
58 |
+
lb = flip_lb.transpose(Image.FLIP_LEFT_RIGHT),
|
59 |
+
)
|
60 |
+
|
61 |
+
|
62 |
+
class RandomScale(object):
|
63 |
+
def __init__(self, scales=(1, ), *args, **kwargs):
|
64 |
+
self.scales = scales
|
65 |
+
|
66 |
+
def __call__(self, im_lb):
|
67 |
+
im = im_lb['im']
|
68 |
+
lb = im_lb['lb']
|
69 |
+
W, H = im.size
|
70 |
+
scale = random.choice(self.scales)
|
71 |
+
w, h = int(W * scale), int(H * scale)
|
72 |
+
return dict(im = im.resize((w, h), Image.BILINEAR),
|
73 |
+
lb = lb.resize((w, h), Image.NEAREST),
|
74 |
+
)
|
75 |
+
|
76 |
+
|
77 |
+
class ColorJitter(object):
|
78 |
+
def __init__(self, brightness=None, contrast=None, saturation=None, *args, **kwargs):
|
79 |
+
if not brightness is None and brightness>0:
|
80 |
+
self.brightness = [max(1-brightness, 0), 1+brightness]
|
81 |
+
if not contrast is None and contrast>0:
|
82 |
+
self.contrast = [max(1-contrast, 0), 1+contrast]
|
83 |
+
if not saturation is None and saturation>0:
|
84 |
+
self.saturation = [max(1-saturation, 0), 1+saturation]
|
85 |
+
|
86 |
+
def __call__(self, im_lb):
|
87 |
+
im = im_lb['im']
|
88 |
+
lb = im_lb['lb']
|
89 |
+
r_brightness = random.uniform(self.brightness[0], self.brightness[1])
|
90 |
+
r_contrast = random.uniform(self.contrast[0], self.contrast[1])
|
91 |
+
r_saturation = random.uniform(self.saturation[0], self.saturation[1])
|
92 |
+
im = ImageEnhance.Brightness(im).enhance(r_brightness)
|
93 |
+
im = ImageEnhance.Contrast(im).enhance(r_contrast)
|
94 |
+
im = ImageEnhance.Color(im).enhance(r_saturation)
|
95 |
+
return dict(im = im,
|
96 |
+
lb = lb,
|
97 |
+
)
|
98 |
+
|
99 |
+
|
100 |
+
class MultiScale(object):
|
101 |
+
def __init__(self, scales):
|
102 |
+
self.scales = scales
|
103 |
+
|
104 |
+
def __call__(self, img):
|
105 |
+
W, H = img.size
|
106 |
+
sizes = [(int(W*ratio), int(H*ratio)) for ratio in self.scales]
|
107 |
+
imgs = []
|
108 |
+
[imgs.append(img.resize(size, Image.BILINEAR)) for size in sizes]
|
109 |
+
return imgs
|
110 |
+
|
111 |
+
|
112 |
+
class Compose(object):
|
113 |
+
def __init__(self, do_list):
|
114 |
+
self.do_list = do_list
|
115 |
+
|
116 |
+
def __call__(self, im_lb):
|
117 |
+
for comp in self.do_list:
|
118 |
+
im_lb = comp(im_lb)
|
119 |
+
return im_lb
|
120 |
+
|
121 |
+
|
122 |
+
|
123 |
+
|
124 |
+
if __name__ == '__main__':
|
125 |
+
flip = HorizontalFlip(p = 1)
|
126 |
+
crop = RandomCrop((321, 321))
|
127 |
+
rscales = RandomScale((0.75, 1.0, 1.5, 1.75, 2.0))
|
128 |
+
img = Image.open('data/img.jpg')
|
129 |
+
lb = Image.open('data/label.png')
|
ConsistentID/lib/attention.py
ADDED
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from diffusers.models.lora import LoRALinearLayer
|
5 |
+
from .functions import AttentionMLP
|
6 |
+
|
7 |
+
class FuseModule(nn.Module):
|
8 |
+
def __init__(self, embed_dim):
|
9 |
+
super().__init__()
|
10 |
+
self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False)
|
11 |
+
self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True)
|
12 |
+
self.layer_norm = nn.LayerNorm(embed_dim)
|
13 |
+
|
14 |
+
def fuse_fn(self, prompt_embeds, id_embeds):
|
15 |
+
stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1)
|
16 |
+
stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds
|
17 |
+
stacked_id_embeds = self.mlp2(stacked_id_embeds)
|
18 |
+
stacked_id_embeds = self.layer_norm(stacked_id_embeds)
|
19 |
+
return stacked_id_embeds
|
20 |
+
|
21 |
+
def forward(
|
22 |
+
self,
|
23 |
+
prompt_embeds,
|
24 |
+
id_embeds,
|
25 |
+
class_tokens_mask,
|
26 |
+
valid_id_mask,
|
27 |
+
) -> torch.Tensor:
|
28 |
+
id_embeds = id_embeds.to(prompt_embeds.dtype)
|
29 |
+
batch_size, max_num_inputs = id_embeds.shape[:2] # 1,5
|
30 |
+
seq_length = prompt_embeds.shape[1] # 77
|
31 |
+
flat_id_embeds = id_embeds.view(-1, id_embeds.shape[-2], id_embeds.shape[-1])
|
32 |
+
# flat_id_embeds torch.Size([5, 1, 768])
|
33 |
+
valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()]
|
34 |
+
# valid_id_embeds torch.Size([4, 1, 768])
|
35 |
+
prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1]) # torch.Size([77, 768])
|
36 |
+
class_tokens_mask = class_tokens_mask.view(-1) # torch.Size([77])
|
37 |
+
valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1]) # torch.Size([4, 768])
|
38 |
+
image_token_embeds = prompt_embeds[class_tokens_mask] # torch.Size([4, 768])
|
39 |
+
stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds) # torch.Size([4, 768])
|
40 |
+
assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}"
|
41 |
+
prompt_embeds.masked_scatter_(class_tokens_mask[:, None], stacked_id_embeds.to(prompt_embeds.dtype))
|
42 |
+
updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1)
|
43 |
+
|
44 |
+
return updated_prompt_embeds
|
45 |
+
|
46 |
+
class MLP(nn.Module):
|
47 |
+
def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True):
|
48 |
+
super().__init__()
|
49 |
+
if use_residual:
|
50 |
+
assert in_dim == out_dim
|
51 |
+
self.layernorm = nn.LayerNorm(in_dim)
|
52 |
+
self.fc1 = nn.Linear(in_dim, hidden_dim)
|
53 |
+
self.fc2 = nn.Linear(hidden_dim, out_dim)
|
54 |
+
self.use_residual = use_residual
|
55 |
+
self.act_fn = nn.GELU()
|
56 |
+
|
57 |
+
def forward(self, x):
|
58 |
+
|
59 |
+
residual = x
|
60 |
+
x = self.layernorm(x)
|
61 |
+
x = self.fc1(x)
|
62 |
+
x = self.act_fn(x)
|
63 |
+
x = self.fc2(x)
|
64 |
+
if self.use_residual:
|
65 |
+
x = x + residual
|
66 |
+
return x
|
67 |
+
|
68 |
+
class FacialEncoder(nn.Module):
|
69 |
+
def __init__(self):
|
70 |
+
super().__init__()
|
71 |
+
self.visual_projection = AttentionMLP()
|
72 |
+
self.fuse_module = FuseModule(768)
|
73 |
+
|
74 |
+
def forward(self, prompt_embeds, multi_image_embeds, class_tokens_mask, valid_id_mask):
|
75 |
+
bs, num_inputs, token_length, image_dim = multi_image_embeds.shape
|
76 |
+
multi_image_embeds_view = multi_image_embeds.view(bs * num_inputs, token_length, image_dim)
|
77 |
+
id_embeds = self.visual_projection(multi_image_embeds_view) # torch.Size([5, 1, 768])
|
78 |
+
id_embeds = id_embeds.view(bs, num_inputs, 1, -1)
|
79 |
+
# fuse_module replaces the class tokens in prompt_embeds with the fused (id_embeds, prompt_embeds[class_tokens_mask])
|
80 |
+
# whose indices are specified by class_tokens_mask.
|
81 |
+
updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask, valid_id_mask)
|
82 |
+
return updated_prompt_embeds
|
83 |
+
|
84 |
+
class Consistent_AttProcessor(nn.Module):
|
85 |
+
|
86 |
+
def __init__(
|
87 |
+
self,
|
88 |
+
hidden_size=None,
|
89 |
+
cross_attention_dim=None,
|
90 |
+
rank=4,
|
91 |
+
network_alpha=None,
|
92 |
+
lora_scale=1.0,
|
93 |
+
):
|
94 |
+
super().__init__()
|
95 |
+
|
96 |
+
self.rank = rank
|
97 |
+
self.lora_scale = lora_scale
|
98 |
+
|
99 |
+
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
100 |
+
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
101 |
+
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
102 |
+
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
103 |
+
|
104 |
+
def __call__(
|
105 |
+
self,
|
106 |
+
attn,
|
107 |
+
hidden_states,
|
108 |
+
encoder_hidden_states=None,
|
109 |
+
attention_mask=None,
|
110 |
+
temb=None,
|
111 |
+
):
|
112 |
+
residual = hidden_states
|
113 |
+
|
114 |
+
if attn.spatial_norm is not None:
|
115 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
116 |
+
|
117 |
+
input_ndim = hidden_states.ndim
|
118 |
+
|
119 |
+
if input_ndim == 4:
|
120 |
+
batch_size, channel, height, width = hidden_states.shape
|
121 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
122 |
+
|
123 |
+
batch_size, sequence_length, _ = (
|
124 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
125 |
+
)
|
126 |
+
|
127 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
128 |
+
|
129 |
+
if attn.group_norm is not None:
|
130 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
131 |
+
|
132 |
+
query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
|
133 |
+
|
134 |
+
if encoder_hidden_states is None:
|
135 |
+
encoder_hidden_states = hidden_states
|
136 |
+
elif attn.norm_cross:
|
137 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
138 |
+
|
139 |
+
key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
|
140 |
+
value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
|
141 |
+
|
142 |
+
query = attn.head_to_batch_dim(query)
|
143 |
+
key = attn.head_to_batch_dim(key)
|
144 |
+
value = attn.head_to_batch_dim(value)
|
145 |
+
|
146 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
147 |
+
hidden_states = torch.bmm(attention_probs, value)
|
148 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
149 |
+
|
150 |
+
# linear proj
|
151 |
+
hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
|
152 |
+
# dropout
|
153 |
+
hidden_states = attn.to_out[1](hidden_states)
|
154 |
+
|
155 |
+
if input_ndim == 4:
|
156 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
157 |
+
|
158 |
+
if attn.residual_connection:
|
159 |
+
hidden_states = hidden_states + residual
|
160 |
+
|
161 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
162 |
+
|
163 |
+
return hidden_states
|
164 |
+
|
165 |
+
|
166 |
+
class Consistent_IPAttProcessor(nn.Module):
|
167 |
+
|
168 |
+
def __init__(
|
169 |
+
self,
|
170 |
+
hidden_size,
|
171 |
+
cross_attention_dim=None,
|
172 |
+
rank=4,
|
173 |
+
network_alpha=None,
|
174 |
+
lora_scale=1.0,
|
175 |
+
scale=1.0,
|
176 |
+
num_tokens=4):
|
177 |
+
super().__init__()
|
178 |
+
|
179 |
+
self.rank = rank
|
180 |
+
self.lora_scale = lora_scale
|
181 |
+
self.num_tokens = num_tokens
|
182 |
+
|
183 |
+
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
184 |
+
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
185 |
+
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
|
186 |
+
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
|
187 |
+
|
188 |
+
|
189 |
+
self.hidden_size = hidden_size
|
190 |
+
self.cross_attention_dim = cross_attention_dim
|
191 |
+
self.scale = scale
|
192 |
+
|
193 |
+
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
194 |
+
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
195 |
+
|
196 |
+
for module in [self.to_q_lora, self.to_k_lora, self.to_v_lora, self.to_out_lora, self.to_k_ip, self.to_v_ip]:
|
197 |
+
for param in module.parameters():
|
198 |
+
param.requires_grad = False
|
199 |
+
|
200 |
+
def __call__(
|
201 |
+
self,
|
202 |
+
attn,
|
203 |
+
hidden_states,
|
204 |
+
encoder_hidden_states=None,
|
205 |
+
attention_mask=None,
|
206 |
+
scale=1.0,
|
207 |
+
temb=None,
|
208 |
+
):
|
209 |
+
residual = hidden_states
|
210 |
+
|
211 |
+
if attn.spatial_norm is not None:
|
212 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
213 |
+
|
214 |
+
input_ndim = hidden_states.ndim
|
215 |
+
|
216 |
+
if input_ndim == 4:
|
217 |
+
batch_size, channel, height, width = hidden_states.shape
|
218 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
219 |
+
|
220 |
+
batch_size, sequence_length, _ = (
|
221 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
222 |
+
)
|
223 |
+
|
224 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
225 |
+
|
226 |
+
if attn.group_norm is not None:
|
227 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
228 |
+
|
229 |
+
query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states)
|
230 |
+
|
231 |
+
if encoder_hidden_states is None:
|
232 |
+
encoder_hidden_states = hidden_states
|
233 |
+
else:
|
234 |
+
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
235 |
+
encoder_hidden_states, ip_hidden_states = (
|
236 |
+
encoder_hidden_states[:, :end_pos, :],
|
237 |
+
encoder_hidden_states[:, end_pos:, :],
|
238 |
+
)
|
239 |
+
if attn.norm_cross:
|
240 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
241 |
+
|
242 |
+
key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states)
|
243 |
+
value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states)
|
244 |
+
|
245 |
+
inner_dim = key.shape[-1]
|
246 |
+
head_dim = inner_dim // attn.heads
|
247 |
+
|
248 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
249 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
250 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
251 |
+
|
252 |
+
hidden_states = F.scaled_dot_product_attention(
|
253 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
254 |
+
)
|
255 |
+
|
256 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
257 |
+
hidden_states = hidden_states.to(query.dtype)
|
258 |
+
|
259 |
+
ip_key = self.to_k_ip(ip_hidden_states)
|
260 |
+
ip_value = self.to_v_ip(ip_hidden_states)
|
261 |
+
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
262 |
+
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
263 |
+
|
264 |
+
|
265 |
+
ip_hidden_states = F.scaled_dot_product_attention(
|
266 |
+
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
267 |
+
)
|
268 |
+
|
269 |
+
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
270 |
+
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
271 |
+
|
272 |
+
hidden_states = hidden_states + self.scale * ip_hidden_states
|
273 |
+
|
274 |
+
# linear proj
|
275 |
+
hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states)
|
276 |
+
# dropout
|
277 |
+
hidden_states = attn.to_out[1](hidden_states)
|
278 |
+
|
279 |
+
if input_ndim == 4:
|
280 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
281 |
+
|
282 |
+
if attn.residual_connection:
|
283 |
+
hidden_states = hidden_states + residual
|
284 |
+
|
285 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
286 |
+
|
287 |
+
return hidden_states
|
ConsistentID/lib/functions.py
ADDED
@@ -0,0 +1,606 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import math
|
3 |
+
import types
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import numpy as np
|
7 |
+
import cv2
|
8 |
+
import re
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from einops import rearrange
|
11 |
+
from einops.layers.torch import Rearrange
|
12 |
+
from PIL import Image
|
13 |
+
|
14 |
+
def extract_first_sentence(text):
|
15 |
+
end_index = text.find('.')
|
16 |
+
if end_index != -1:
|
17 |
+
first_sentence = text[:end_index + 1]
|
18 |
+
return first_sentence.strip()
|
19 |
+
else:
|
20 |
+
return text.strip()
|
21 |
+
|
22 |
+
import re
|
23 |
+
def remove_duplicate_keywords(text, keywords):
|
24 |
+
keyword_counts = {}
|
25 |
+
|
26 |
+
words = re.findall(r'\b\w+\b|[.,;!?]', text)
|
27 |
+
|
28 |
+
for keyword in keywords:
|
29 |
+
keyword_counts[keyword] = 0
|
30 |
+
for i, word in enumerate(words):
|
31 |
+
if word.lower() == keyword.lower():
|
32 |
+
keyword_counts[keyword] += 1
|
33 |
+
if keyword_counts[keyword] > 1:
|
34 |
+
words[i] = ""
|
35 |
+
processed_text = " ".join(words)
|
36 |
+
|
37 |
+
return processed_text
|
38 |
+
|
39 |
+
# text: 'The person has one nose , two eyes , two ears , and a mouth .'
|
40 |
+
def insert_markers_to_prompt(text, parsing_mask_dict):
|
41 |
+
keywords = ["face", "ears", "eyes", "nose", "mouth"]
|
42 |
+
text = remove_duplicate_keywords(text, keywords)
|
43 |
+
key_parsing_mask_markers = ["Nose", "Face", "Left_Ear", "Right_Ear", "Left_Eye", "Right_Eye", "Upper_Lip", "Lower_Lip"]
|
44 |
+
mapping = {
|
45 |
+
"Face": "face",
|
46 |
+
"Left_Ear": "ears",
|
47 |
+
"Right_Ear": "ears",
|
48 |
+
"Left_Eye": "eyes",
|
49 |
+
"Right_Eye": "eyes",
|
50 |
+
"Nose": "nose",
|
51 |
+
"Upper_Lip": "mouth",
|
52 |
+
"Lower_Lip": "mouth",
|
53 |
+
}
|
54 |
+
facial_features_align = []
|
55 |
+
markers_align = []
|
56 |
+
for key in key_parsing_mask_markers:
|
57 |
+
if key in parsing_mask_dict:
|
58 |
+
mapped_key = mapping.get(key, key.lower())
|
59 |
+
if mapped_key not in facial_features_align:
|
60 |
+
facial_features_align.append(mapped_key)
|
61 |
+
markers_align.append("<|" + mapped_key + "|>")
|
62 |
+
|
63 |
+
text_marked = text
|
64 |
+
align_parsing_mask_dict = parsing_mask_dict
|
65 |
+
for feature, marker in zip(facial_features_align[::-1], markers_align[::-1]):
|
66 |
+
pattern = rf'\b{feature}\b'
|
67 |
+
text_marked_new = re.sub(pattern, f'{feature} {marker}', text_marked, count=1)
|
68 |
+
if text_marked == text_marked_new:
|
69 |
+
for key, value in mapping.items():
|
70 |
+
if value == feature:
|
71 |
+
if key in align_parsing_mask_dict:
|
72 |
+
del align_parsing_mask_dict[key]
|
73 |
+
|
74 |
+
text_marked = text_marked_new
|
75 |
+
|
76 |
+
text_marked = text_marked.replace('\n', '')
|
77 |
+
|
78 |
+
ordered_text = []
|
79 |
+
text_none_makers = []
|
80 |
+
facial_marked_count = 0
|
81 |
+
skip_count = 0
|
82 |
+
for marker in markers_align:
|
83 |
+
start_idx = text_marked.find(marker)
|
84 |
+
end_idx = start_idx + len(marker)
|
85 |
+
|
86 |
+
while start_idx > 0 and text_marked[start_idx - 1] not in [",", ".", ";"]:
|
87 |
+
start_idx -= 1
|
88 |
+
|
89 |
+
while end_idx < len(text_marked) and text_marked[end_idx] not in [",", ".", ";"]:
|
90 |
+
end_idx += 1
|
91 |
+
|
92 |
+
context = text_marked[start_idx:end_idx].strip()
|
93 |
+
if context == "":
|
94 |
+
text_none_makers.append(text_marked[:end_idx])
|
95 |
+
else:
|
96 |
+
if skip_count!=0:
|
97 |
+
skip_count -= 1
|
98 |
+
continue
|
99 |
+
else:
|
100 |
+
ordered_text.append(context + ", ")
|
101 |
+
text_delete_makers = text_marked[:start_idx] + text_marked[end_idx:]
|
102 |
+
text_marked = text_delete_makers
|
103 |
+
facial_marked_count += 1
|
104 |
+
|
105 |
+
# ordered_text: ['The person has one nose <|nose|>, ', 'two ears <|ears|>, ',
|
106 |
+
# 'two eyes <|eyes|>, ', 'and a mouth <|mouth|>, ']
|
107 |
+
# align_parsing_mask_dict.keys(): ['Right_Eye', 'Right_Ear', 'Nose', 'Upper_Lip']
|
108 |
+
align_marked_text = "".join(ordered_text)
|
109 |
+
replace_list = ["<|face|>", "<|ears|>", "<|nose|>", "<|eyes|>", "<|mouth|>"]
|
110 |
+
for item in replace_list:
|
111 |
+
align_marked_text = align_marked_text.replace(item, "<|facial|>")
|
112 |
+
|
113 |
+
# align_marked_text: 'The person has one nose <|facial|>, two ears <|facial|>, two eyes <|facial|>, and a mouth <|facial|>, '
|
114 |
+
return align_marked_text, align_parsing_mask_dict
|
115 |
+
|
116 |
+
def tokenize_and_mask_noun_phrases_ends(text, image_token_id, facial_token_id, tokenizer):
|
117 |
+
input_ids = tokenizer.encode(text)
|
118 |
+
image_noun_phrase_end_mask = [False for _ in input_ids]
|
119 |
+
facial_noun_phrase_end_mask = [False for _ in input_ids]
|
120 |
+
clean_input_ids = []
|
121 |
+
clean_index = 0
|
122 |
+
image_num = 0
|
123 |
+
|
124 |
+
for i, id in enumerate(input_ids):
|
125 |
+
if id == image_token_id:
|
126 |
+
image_noun_phrase_end_mask[clean_index + image_num - 1] = True
|
127 |
+
image_num += 1
|
128 |
+
elif id == facial_token_id:
|
129 |
+
facial_noun_phrase_end_mask[clean_index - 1] = True
|
130 |
+
else:
|
131 |
+
clean_input_ids.append(id)
|
132 |
+
clean_index += 1
|
133 |
+
|
134 |
+
max_len = tokenizer.model_max_length
|
135 |
+
|
136 |
+
if len(clean_input_ids) > max_len:
|
137 |
+
clean_input_ids = clean_input_ids[:max_len]
|
138 |
+
else:
|
139 |
+
clean_input_ids = clean_input_ids + [tokenizer.pad_token_id] * (
|
140 |
+
max_len - len(clean_input_ids)
|
141 |
+
)
|
142 |
+
|
143 |
+
if len(image_noun_phrase_end_mask) > max_len:
|
144 |
+
image_noun_phrase_end_mask = image_noun_phrase_end_mask[:max_len]
|
145 |
+
else:
|
146 |
+
image_noun_phrase_end_mask = image_noun_phrase_end_mask + [False] * (
|
147 |
+
max_len - len(image_noun_phrase_end_mask)
|
148 |
+
)
|
149 |
+
|
150 |
+
if len(facial_noun_phrase_end_mask) > max_len:
|
151 |
+
facial_noun_phrase_end_mask = facial_noun_phrase_end_mask[:max_len]
|
152 |
+
else:
|
153 |
+
facial_noun_phrase_end_mask = facial_noun_phrase_end_mask + [False] * (
|
154 |
+
max_len - len(facial_noun_phrase_end_mask)
|
155 |
+
)
|
156 |
+
clean_input_ids = torch.tensor(clean_input_ids, dtype=torch.long)
|
157 |
+
image_noun_phrase_end_mask = torch.tensor(image_noun_phrase_end_mask, dtype=torch.bool)
|
158 |
+
facial_noun_phrase_end_mask = torch.tensor(facial_noun_phrase_end_mask, dtype=torch.bool)
|
159 |
+
|
160 |
+
return clean_input_ids.unsqueeze(0), image_noun_phrase_end_mask.unsqueeze(0), facial_noun_phrase_end_mask.unsqueeze(0)
|
161 |
+
|
162 |
+
def prepare_image_token_idx(image_token_mask, facial_token_mask, max_num_objects=2, max_num_facials=5):
|
163 |
+
image_token_idx = torch.nonzero(image_token_mask, as_tuple=True)[1]
|
164 |
+
image_token_idx_mask = torch.ones_like(image_token_idx, dtype=torch.bool)
|
165 |
+
if len(image_token_idx) < max_num_objects:
|
166 |
+
image_token_idx = torch.cat(
|
167 |
+
[
|
168 |
+
image_token_idx,
|
169 |
+
torch.zeros(max_num_objects - len(image_token_idx), dtype=torch.long),
|
170 |
+
]
|
171 |
+
)
|
172 |
+
image_token_idx_mask = torch.cat(
|
173 |
+
[
|
174 |
+
image_token_idx_mask,
|
175 |
+
torch.zeros(
|
176 |
+
max_num_objects - len(image_token_idx_mask),
|
177 |
+
dtype=torch.bool,
|
178 |
+
),
|
179 |
+
]
|
180 |
+
)
|
181 |
+
facial_token_idx = torch.nonzero(facial_token_mask, as_tuple=True)[1]
|
182 |
+
facial_token_idx_mask = torch.ones_like(facial_token_idx, dtype=torch.bool)
|
183 |
+
if len(facial_token_idx) < max_num_facials:
|
184 |
+
facial_token_idx = torch.cat(
|
185 |
+
[
|
186 |
+
facial_token_idx,
|
187 |
+
torch.zeros(max_num_facials - len(facial_token_idx), dtype=torch.long),
|
188 |
+
]
|
189 |
+
)
|
190 |
+
facial_token_idx_mask = torch.cat(
|
191 |
+
[
|
192 |
+
facial_token_idx_mask,
|
193 |
+
torch.zeros(
|
194 |
+
max_num_facials - len(facial_token_idx_mask),
|
195 |
+
dtype=torch.bool,
|
196 |
+
),
|
197 |
+
]
|
198 |
+
)
|
199 |
+
image_token_idx = image_token_idx.unsqueeze(0)
|
200 |
+
image_token_idx_mask = image_token_idx_mask.unsqueeze(0)
|
201 |
+
|
202 |
+
facial_token_idx = facial_token_idx.unsqueeze(0)
|
203 |
+
facial_token_idx_mask = facial_token_idx_mask.unsqueeze(0)
|
204 |
+
|
205 |
+
return image_token_idx, image_token_idx_mask, facial_token_idx, facial_token_idx_mask
|
206 |
+
|
207 |
+
def get_object_localization_loss_for_one_layer(
|
208 |
+
cross_attention_scores,
|
209 |
+
object_segmaps,
|
210 |
+
object_token_idx,
|
211 |
+
object_token_idx_mask,
|
212 |
+
loss_fn,
|
213 |
+
):
|
214 |
+
bxh, num_noise_latents, num_text_tokens = cross_attention_scores.shape
|
215 |
+
b, max_num_objects, _, _ = object_segmaps.shape
|
216 |
+
size = int(num_noise_latents**0.5)
|
217 |
+
|
218 |
+
object_segmaps = F.interpolate(object_segmaps, size=(size, size), mode="bilinear", antialias=True)
|
219 |
+
|
220 |
+
object_segmaps = object_segmaps.view(
|
221 |
+
b, max_num_objects, -1
|
222 |
+
)
|
223 |
+
|
224 |
+
num_heads = bxh // b
|
225 |
+
cross_attention_scores = cross_attention_scores.view(b, num_heads, num_noise_latents, num_text_tokens)
|
226 |
+
|
227 |
+
|
228 |
+
object_token_attn_prob = torch.gather(
|
229 |
+
cross_attention_scores,
|
230 |
+
dim=3,
|
231 |
+
index=object_token_idx.view(b, 1, 1, max_num_objects).expand(
|
232 |
+
b, num_heads, num_noise_latents, max_num_objects
|
233 |
+
),
|
234 |
+
)
|
235 |
+
object_segmaps = (
|
236 |
+
object_segmaps.permute(0, 2, 1)
|
237 |
+
.unsqueeze(1)
|
238 |
+
.expand(b, num_heads, num_noise_latents, max_num_objects)
|
239 |
+
)
|
240 |
+
loss = loss_fn(object_token_attn_prob, object_segmaps)
|
241 |
+
|
242 |
+
loss = loss * object_token_idx_mask.view(b, 1, max_num_objects)
|
243 |
+
object_token_cnt = object_token_idx_mask.sum(dim=1).view(b, 1) + 1e-5
|
244 |
+
loss = (loss.sum(dim=2) / object_token_cnt).mean()
|
245 |
+
|
246 |
+
return loss
|
247 |
+
|
248 |
+
|
249 |
+
def get_object_localization_loss(
|
250 |
+
cross_attention_scores,
|
251 |
+
object_segmaps,
|
252 |
+
image_token_idx,
|
253 |
+
image_token_idx_mask,
|
254 |
+
loss_fn,
|
255 |
+
):
|
256 |
+
num_layers = len(cross_attention_scores)
|
257 |
+
loss = 0
|
258 |
+
for k, v in cross_attention_scores.items():
|
259 |
+
layer_loss = get_object_localization_loss_for_one_layer(
|
260 |
+
v, object_segmaps, image_token_idx, image_token_idx_mask, loss_fn
|
261 |
+
)
|
262 |
+
loss += layer_loss
|
263 |
+
return loss / num_layers
|
264 |
+
|
265 |
+
def unet_store_cross_attention_scores(unet, attention_scores, layers=5):
|
266 |
+
from diffusers.models.attention_processor import Attention
|
267 |
+
|
268 |
+
UNET_LAYER_NAMES = [
|
269 |
+
"down_blocks.0",
|
270 |
+
"down_blocks.1",
|
271 |
+
"down_blocks.2",
|
272 |
+
"mid_block",
|
273 |
+
"up_blocks.1",
|
274 |
+
"up_blocks.2",
|
275 |
+
"up_blocks.3",
|
276 |
+
]
|
277 |
+
|
278 |
+
start_layer = (len(UNET_LAYER_NAMES) - layers) // 2
|
279 |
+
end_layer = start_layer + layers
|
280 |
+
applicable_layers = UNET_LAYER_NAMES[start_layer:end_layer]
|
281 |
+
|
282 |
+
def make_new_get_attention_scores_fn(name):
|
283 |
+
def new_get_attention_scores(module, query, key, attention_mask=None):
|
284 |
+
attention_probs = module.old_get_attention_scores(
|
285 |
+
query, key, attention_mask
|
286 |
+
)
|
287 |
+
attention_scores[name] = attention_probs
|
288 |
+
return attention_probs
|
289 |
+
|
290 |
+
return new_get_attention_scores
|
291 |
+
|
292 |
+
for name, module in unet.named_modules():
|
293 |
+
if isinstance(module, Attention) and "attn1" in name:
|
294 |
+
if not any(layer in name for layer in applicable_layers):
|
295 |
+
continue
|
296 |
+
|
297 |
+
module.old_get_attention_scores = module.get_attention_scores
|
298 |
+
module.get_attention_scores = types.MethodType(
|
299 |
+
make_new_get_attention_scores_fn(name), module
|
300 |
+
)
|
301 |
+
return unet
|
302 |
+
|
303 |
+
class BalancedL1Loss(nn.Module):
|
304 |
+
def __init__(self, threshold=1.0, normalize=False):
|
305 |
+
super().__init__()
|
306 |
+
self.threshold = threshold
|
307 |
+
self.normalize = normalize
|
308 |
+
|
309 |
+
def forward(self, object_token_attn_prob, object_segmaps):
|
310 |
+
if self.normalize:
|
311 |
+
object_token_attn_prob = object_token_attn_prob / (
|
312 |
+
object_token_attn_prob.max(dim=2, keepdim=True)[0] + 1e-5
|
313 |
+
)
|
314 |
+
background_segmaps = 1 - object_segmaps
|
315 |
+
background_segmaps_sum = background_segmaps.sum(dim=2) + 1e-5
|
316 |
+
object_segmaps_sum = object_segmaps.sum(dim=2) + 1e-5
|
317 |
+
|
318 |
+
background_loss = (object_token_attn_prob * background_segmaps).sum(
|
319 |
+
dim=2
|
320 |
+
) / background_segmaps_sum
|
321 |
+
|
322 |
+
object_loss = (object_token_attn_prob * object_segmaps).sum(
|
323 |
+
dim=2
|
324 |
+
) / object_segmaps_sum
|
325 |
+
|
326 |
+
return background_loss - object_loss
|
327 |
+
|
328 |
+
def apply_mask_to_raw_image(raw_image, mask_image):
|
329 |
+
mask_image = mask_image.resize(raw_image.size)
|
330 |
+
mask_raw_image = Image.composite(raw_image, Image.new('RGB', raw_image.size, (0, 0, 0)), mask_image)
|
331 |
+
return mask_raw_image
|
332 |
+
|
333 |
+
mapping_table = [
|
334 |
+
{"Mask Value": 0, "Body Part": "Background", "RGB Color": [0, 0, 0]},
|
335 |
+
{"Mask Value": 1, "Body Part": "Face", "RGB Color": [255, 0, 0]},
|
336 |
+
{"Mask Value": 2, "Body Part": "Left_Eyebrow", "RGB Color": [255, 85, 0]},
|
337 |
+
{"Mask Value": 3, "Body Part": "Right_Eyebrow", "RGB Color": [255, 170, 0]},
|
338 |
+
{"Mask Value": 4, "Body Part": "Left_Eye", "RGB Color": [255, 0, 85]},
|
339 |
+
{"Mask Value": 5, "Body Part": "Right_Eye", "RGB Color": [255, 0, 170]},
|
340 |
+
{"Mask Value": 6, "Body Part": "Hair", "RGB Color": [0, 0, 255]},
|
341 |
+
{"Mask Value": 7, "Body Part": "Left_Ear", "RGB Color": [85, 0, 255]},
|
342 |
+
{"Mask Value": 8, "Body Part": "Right_Ear", "RGB Color": [170, 0, 255]},
|
343 |
+
{"Mask Value": 9, "Body Part": "Mouth_External Contour", "RGB Color": [0, 255, 85]},
|
344 |
+
{"Mask Value": 10, "Body Part": "Nose", "RGB Color": [0, 255, 0]},
|
345 |
+
{"Mask Value": 11, "Body Part": "Mouth_Inner_Contour", "RGB Color": [0, 255, 170]},
|
346 |
+
{"Mask Value": 12, "Body Part": "Upper_Lip", "RGB Color": [85, 255, 0]},
|
347 |
+
{"Mask Value": 13, "Body Part": "Lower_Lip", "RGB Color": [170, 255, 0]},
|
348 |
+
{"Mask Value": 14, "Body Part": "Neck", "RGB Color": [0, 85, 255]},
|
349 |
+
{"Mask Value": 15, "Body Part": "Neck_Inner Contour", "RGB Color": [0, 170, 255]},
|
350 |
+
{"Mask Value": 16, "Body Part": "Cloth", "RGB Color": [255, 255, 0]},
|
351 |
+
{"Mask Value": 17, "Body Part": "Hat", "RGB Color": [255, 0, 255]},
|
352 |
+
{"Mask Value": 18, "Body Part": "Earring", "RGB Color": [255, 85, 255]},
|
353 |
+
{"Mask Value": 19, "Body Part": "Necklace", "RGB Color": [255, 255, 85]},
|
354 |
+
{"Mask Value": 20, "Body Part": "Glasses", "RGB Color": [255, 170, 255]},
|
355 |
+
{"Mask Value": 21, "Body Part": "Hand", "RGB Color": [255, 0, 255]},
|
356 |
+
{"Mask Value": 22, "Body Part": "Wristband", "RGB Color": [0, 255, 255]},
|
357 |
+
{"Mask Value": 23, "Body Part": "Clothes_Upper", "RGB Color": [85, 255, 255]},
|
358 |
+
{"Mask Value": 24, "Body Part": "Clothes_Lower", "RGB Color": [170, 255, 255]}
|
359 |
+
]
|
360 |
+
|
361 |
+
|
362 |
+
def masks_for_unique_values(image_raw_mask):
|
363 |
+
|
364 |
+
image_array = np.array(image_raw_mask)
|
365 |
+
unique_values, counts = np.unique(image_array, return_counts=True)
|
366 |
+
masks_dict = {}
|
367 |
+
for value in unique_values:
|
368 |
+
binary_image = np.uint8(image_array == value) * 255
|
369 |
+
contours, _ = cv2.findContours(binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
370 |
+
|
371 |
+
mask = np.zeros_like(image_array)
|
372 |
+
for contour in contours:
|
373 |
+
cv2.drawContours(mask, [contour], -1, (255), thickness=cv2.FILLED)
|
374 |
+
|
375 |
+
if value == 0:
|
376 |
+
body_part="WithoutBackground"
|
377 |
+
mask2 = np.where(mask == 255, 0, 255).astype(mask.dtype)
|
378 |
+
masks_dict[body_part] = Image.fromarray(mask2)
|
379 |
+
|
380 |
+
body_part = next((entry["Body Part"] for entry in mapping_table if entry["Mask Value"] == value), f"Unknown_{value}")
|
381 |
+
if body_part.startswith("Unknown_"):
|
382 |
+
continue
|
383 |
+
|
384 |
+
masks_dict[body_part] = Image.fromarray(mask)
|
385 |
+
|
386 |
+
return masks_dict
|
387 |
+
# FFN
|
388 |
+
def FeedForward(dim, mult=4):
|
389 |
+
inner_dim = int(dim * mult)
|
390 |
+
return nn.Sequential(
|
391 |
+
nn.LayerNorm(dim),
|
392 |
+
nn.Linear(dim, inner_dim, bias=False),
|
393 |
+
nn.GELU(),
|
394 |
+
nn.Linear(inner_dim, dim, bias=False),
|
395 |
+
)
|
396 |
+
|
397 |
+
|
398 |
+
def reshape_tensor(x, heads):
|
399 |
+
bs, length, width = x.shape
|
400 |
+
x = x.view(bs, length, heads, -1)
|
401 |
+
x = x.transpose(1, 2)
|
402 |
+
x = x.reshape(bs, heads, length, -1)
|
403 |
+
return x
|
404 |
+
|
405 |
+
class PerceiverAttention(nn.Module):
|
406 |
+
def __init__(self, *, dim, dim_head=64, heads=8):
|
407 |
+
super().__init__()
|
408 |
+
self.scale = dim_head**-0.5
|
409 |
+
self.dim_head = dim_head
|
410 |
+
self.heads = heads
|
411 |
+
inner_dim = dim_head * heads
|
412 |
+
|
413 |
+
self.norm1 = nn.LayerNorm(dim)
|
414 |
+
self.norm2 = nn.LayerNorm(dim)
|
415 |
+
|
416 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
417 |
+
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
418 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
419 |
+
|
420 |
+
# x -> kv, latents -> q
|
421 |
+
def forward(self, x, latents):
|
422 |
+
"""
|
423 |
+
Args:
|
424 |
+
x (torch.Tensor): image features
|
425 |
+
shape (b, n1, D)
|
426 |
+
latent (torch.Tensor): latent features
|
427 |
+
shape (b, n2, D)
|
428 |
+
"""
|
429 |
+
|
430 |
+
x = self.norm1(x)
|
431 |
+
latents = self.norm2(latents)
|
432 |
+
|
433 |
+
b, l, _ = latents.shape
|
434 |
+
|
435 |
+
q = self.to_q(latents)
|
436 |
+
kv_input = torch.cat((x, latents), dim=-2)
|
437 |
+
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
438 |
+
|
439 |
+
q = reshape_tensor(q, self.heads)
|
440 |
+
k = reshape_tensor(k, self.heads)
|
441 |
+
v = reshape_tensor(v, self.heads)
|
442 |
+
|
443 |
+
# attention
|
444 |
+
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
445 |
+
weight = (q * scale) @ (k * scale).transpose(-2, -1)
|
446 |
+
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
447 |
+
out = weight @ v
|
448 |
+
|
449 |
+
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
450 |
+
|
451 |
+
return self.to_out(out)
|
452 |
+
|
453 |
+
class FacePerceiverResampler(torch.nn.Module):
|
454 |
+
def __init__(
|
455 |
+
self,
|
456 |
+
*,
|
457 |
+
dim=768,
|
458 |
+
depth=4,
|
459 |
+
dim_head=64,
|
460 |
+
heads=16,
|
461 |
+
embedding_dim=1280,
|
462 |
+
output_dim=768,
|
463 |
+
ff_mult=4,
|
464 |
+
):
|
465 |
+
super().__init__()
|
466 |
+
|
467 |
+
self.proj_in = torch.nn.Linear(embedding_dim, dim)
|
468 |
+
self.proj_out = torch.nn.Linear(dim, output_dim)
|
469 |
+
self.norm_out = torch.nn.LayerNorm(output_dim)
|
470 |
+
self.layers = torch.nn.ModuleList([])
|
471 |
+
for _ in range(depth):
|
472 |
+
self.layers.append(
|
473 |
+
torch.nn.ModuleList(
|
474 |
+
[
|
475 |
+
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
476 |
+
FeedForward(dim=dim, mult=ff_mult),
|
477 |
+
]
|
478 |
+
)
|
479 |
+
)
|
480 |
+
# x -> kv, latents -> q
|
481 |
+
def forward(self, latents, x): # latents.torch.Size([2, 4, 768]) x.torch.Size([2, 257, 1280])
|
482 |
+
x = self.proj_in(x) # x.torch.Size([2, 257, 768])
|
483 |
+
for attn, ff in self.layers:
|
484 |
+
# x -> kv, latents -> q
|
485 |
+
latents = attn(x, latents) + latents # latents.torch.Size([2, 4, 768])
|
486 |
+
latents = ff(latents) + latents # latents.torch.Size([2, 4, 768])
|
487 |
+
latents = self.proj_out(latents)
|
488 |
+
return self.norm_out(latents)
|
489 |
+
|
490 |
+
class ProjPlusModel(torch.nn.Module):
|
491 |
+
def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4):
|
492 |
+
super().__init__()
|
493 |
+
|
494 |
+
self.cross_attention_dim = cross_attention_dim
|
495 |
+
self.num_tokens = num_tokens
|
496 |
+
|
497 |
+
self.proj = torch.nn.Sequential(
|
498 |
+
torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
|
499 |
+
torch.nn.GELU(),
|
500 |
+
torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
|
501 |
+
)
|
502 |
+
self.norm = torch.nn.LayerNorm(cross_attention_dim)
|
503 |
+
|
504 |
+
self.perceiver_resampler = FacePerceiverResampler(
|
505 |
+
dim=cross_attention_dim,
|
506 |
+
depth=4,
|
507 |
+
dim_head=64,
|
508 |
+
heads=cross_attention_dim // 64,
|
509 |
+
embedding_dim=clip_embeddings_dim,
|
510 |
+
output_dim=cross_attention_dim,
|
511 |
+
ff_mult=4,
|
512 |
+
)
|
513 |
+
|
514 |
+
def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0):
|
515 |
+
|
516 |
+
x = self.proj(id_embeds)
|
517 |
+
x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
|
518 |
+
x = self.norm(x)
|
519 |
+
# id_embeds -> x -> kv, clip_embeds -> q
|
520 |
+
out = self.perceiver_resampler(x, clip_embeds)
|
521 |
+
if shortcut:
|
522 |
+
out = scale * x + out
|
523 |
+
return out
|
524 |
+
|
525 |
+
class AttentionMLP(nn.Module):
|
526 |
+
def __init__(
|
527 |
+
self,
|
528 |
+
dtype=torch.float16,
|
529 |
+
dim=1024,
|
530 |
+
depth=8,
|
531 |
+
dim_head=64,
|
532 |
+
heads=16,
|
533 |
+
single_num_tokens=1,
|
534 |
+
embedding_dim=1280,
|
535 |
+
output_dim=768,
|
536 |
+
ff_mult=4,
|
537 |
+
max_seq_len: int = 257*2,
|
538 |
+
apply_pos_emb: bool = False,
|
539 |
+
num_latents_mean_pooled: int = 0,
|
540 |
+
):
|
541 |
+
super().__init__()
|
542 |
+
self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
|
543 |
+
|
544 |
+
self.single_num_tokens = single_num_tokens
|
545 |
+
self.latents = nn.Parameter(torch.randn(1, self.single_num_tokens, dim) / dim**0.5)
|
546 |
+
|
547 |
+
self.proj_in = nn.Linear(embedding_dim, dim)
|
548 |
+
|
549 |
+
self.proj_out = nn.Linear(dim, output_dim)
|
550 |
+
self.norm_out = nn.LayerNorm(output_dim)
|
551 |
+
|
552 |
+
self.to_latents_from_mean_pooled_seq = (
|
553 |
+
nn.Sequential(
|
554 |
+
nn.LayerNorm(dim),
|
555 |
+
nn.Linear(dim, dim * num_latents_mean_pooled),
|
556 |
+
Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
|
557 |
+
)
|
558 |
+
if num_latents_mean_pooled > 0
|
559 |
+
else None
|
560 |
+
)
|
561 |
+
|
562 |
+
self.layers = nn.ModuleList([])
|
563 |
+
for _ in range(depth):
|
564 |
+
self.layers.append(
|
565 |
+
nn.ModuleList(
|
566 |
+
[
|
567 |
+
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
568 |
+
FeedForward(dim=dim, mult=ff_mult),
|
569 |
+
]
|
570 |
+
)
|
571 |
+
)
|
572 |
+
|
573 |
+
def forward(self, x):
|
574 |
+
if self.pos_emb is not None:
|
575 |
+
n, device = x.shape[1], x.device
|
576 |
+
pos_emb = self.pos_emb(torch.arange(n, device=device))
|
577 |
+
x = x + pos_emb
|
578 |
+
# x torch.Size([5, 257, 1280])
|
579 |
+
latents = self.latents.repeat(x.size(0), 1, 1)
|
580 |
+
|
581 |
+
x = self.proj_in(x) # torch.Size([5, 257, 1024])
|
582 |
+
|
583 |
+
if self.to_latents_from_mean_pooled_seq:
|
584 |
+
meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
|
585 |
+
meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
|
586 |
+
latents = torch.cat((meanpooled_latents, latents), dim=-2)
|
587 |
+
|
588 |
+
for attn, ff in self.layers:
|
589 |
+
latents = attn(x, latents) + latents
|
590 |
+
latents = ff(latents) + latents
|
591 |
+
|
592 |
+
latents = self.proj_out(latents)
|
593 |
+
return self.norm_out(latents)
|
594 |
+
|
595 |
+
|
596 |
+
def masked_mean(t, *, dim, mask=None):
|
597 |
+
if mask is None:
|
598 |
+
return t.mean(dim=dim)
|
599 |
+
|
600 |
+
denom = mask.sum(dim=dim, keepdim=True)
|
601 |
+
mask = rearrange(mask, "b n -> b n 1")
|
602 |
+
masked_t = t.masked_fill(~mask, 0.0)
|
603 |
+
|
604 |
+
return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
|
605 |
+
|
606 |
+
|
ConsistentID/lib/pipeline_ConsistentID.py
ADDED
@@ -0,0 +1,605 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
|
2 |
+
import cv2
|
3 |
+
import PIL
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
import torch
|
7 |
+
from torchvision import transforms
|
8 |
+
from insightface.app import FaceAnalysis
|
9 |
+
### insight-face installation can be found at https://github.com/deepinsight/insightface
|
10 |
+
from safetensors import safe_open
|
11 |
+
from huggingface_hub.utils import validate_hf_hub_args
|
12 |
+
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
13 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
14 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
|
15 |
+
from .functions import insert_markers_to_prompt, masks_for_unique_values, apply_mask_to_raw_image, tokenize_and_mask_noun_phrases_ends, prepare_image_token_idx
|
16 |
+
from .functions import ProjPlusModel, masks_for_unique_values
|
17 |
+
from .attention import Consistent_IPAttProcessor, Consistent_AttProcessor, FacialEncoder
|
18 |
+
from easydict import EasyDict as edict
|
19 |
+
from huggingface_hub import hf_hub_download
|
20 |
+
### Model can be imported from https://github.com/zllrunning/face-parsing.PyTorch?tab=readme-ov-file
|
21 |
+
### We use the ckpt of 79999_iter.pth: https://drive.google.com/open?id=154JgKpzCPW82qINcVieuPH3fZ2e0P812
|
22 |
+
### Thanks for the open source of face-parsing model.
|
23 |
+
from .BiSeNet.model import BiSeNet
|
24 |
+
import os
|
25 |
+
|
26 |
+
PipelineImageInput = Union[
|
27 |
+
PIL.Image.Image,
|
28 |
+
torch.FloatTensor,
|
29 |
+
List[PIL.Image.Image],
|
30 |
+
List[torch.FloatTensor],
|
31 |
+
]
|
32 |
+
|
33 |
+
### Download the pretrained model from huggingface and put it locally, then place the model in a local directory and specify the directory location.
|
34 |
+
class ConsistentIDPipeline(StableDiffusionPipeline):
|
35 |
+
# to() should be only called after all modules are loaded.
|
36 |
+
def to(
|
37 |
+
self,
|
38 |
+
torch_device: Optional[Union[str, torch.device]] = None,
|
39 |
+
dtype: Optional[torch.dtype] = None,
|
40 |
+
):
|
41 |
+
super().to(torch_device, dtype=dtype)
|
42 |
+
self.bise_net.to(torch_device, dtype=dtype)
|
43 |
+
self.clip_encoder.to(torch_device, dtype=dtype)
|
44 |
+
self.image_proj_model.to(torch_device, dtype=dtype)
|
45 |
+
self.FacialEncoder.to(torch_device, dtype=dtype)
|
46 |
+
# If the unet is not released, the ip_layers should be moved to the specified device and dtype.
|
47 |
+
if not isinstance(self.unet, edict):
|
48 |
+
self.ip_layers.to(torch_device, dtype=dtype)
|
49 |
+
return self
|
50 |
+
|
51 |
+
@validate_hf_hub_args
|
52 |
+
def load_ConsistentID_model(
|
53 |
+
self,
|
54 |
+
consistentID_weight_path: str,
|
55 |
+
bise_net_weight_path: str,
|
56 |
+
trigger_word_facial: str = '<|facial|>',
|
57 |
+
# A CLIP ViT-H/14 model trained with the LAION-2B English subset of LAION-5B using OpenCLIP.
|
58 |
+
# output dim: 1280.
|
59 |
+
image_encoder_path: str = 'laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
|
60 |
+
torch_dtype = torch.float16,
|
61 |
+
num_tokens = 4,
|
62 |
+
lora_rank= 128,
|
63 |
+
**kwargs,
|
64 |
+
):
|
65 |
+
self.lora_rank = lora_rank
|
66 |
+
self.torch_dtype = torch_dtype
|
67 |
+
self.num_tokens = num_tokens
|
68 |
+
self.set_ip_adapter()
|
69 |
+
self.image_encoder_path = image_encoder_path
|
70 |
+
self.clip_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path)
|
71 |
+
self.clip_preprocessor = CLIPImageProcessor()
|
72 |
+
self.id_image_processor = CLIPImageProcessor()
|
73 |
+
self.crop_size = 512
|
74 |
+
|
75 |
+
# face_app: FaceAnalysis object
|
76 |
+
self.face_app = FaceAnalysis(name="buffalo_l", root='models/insightface', providers=['CPUExecutionProvider'])
|
77 |
+
# The original det_size=(640, 640) is too large and face_app often fails to detect faces.
|
78 |
+
self.face_app.prepare(ctx_id=0, det_size=(512, 512))
|
79 |
+
|
80 |
+
if not os.path.exists(consistentID_weight_path):
|
81 |
+
### Download pretrained models
|
82 |
+
hf_hub_download(repo_id="JackAILab/ConsistentID", repo_type="model",
|
83 |
+
filename=os.path.basename(consistentID_weight_path),
|
84 |
+
local_dir=os.path.dirname(consistentID_weight_path))
|
85 |
+
if not os.path.exists(bise_net_weight_path):
|
86 |
+
hf_hub_download(repo_id="JackAILab/ConsistentID",
|
87 |
+
filename=os.path.basename(bise_net_weight_path),
|
88 |
+
local_dir=os.path.dirname(bise_net_weight_path))
|
89 |
+
|
90 |
+
bise_net = BiSeNet(n_classes = 19)
|
91 |
+
bise_net.load_state_dict(torch.load(bise_net_weight_path, map_location="cpu"))
|
92 |
+
bise_net.eval()
|
93 |
+
self.bise_net = bise_net
|
94 |
+
|
95 |
+
# Colors for all 20 parts
|
96 |
+
self.part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
|
97 |
+
[255, 0, 85], [255, 0, 170],
|
98 |
+
[0, 255, 0], [85, 255, 0], [170, 255, 0],
|
99 |
+
[0, 255, 85], [0, 255, 170],
|
100 |
+
[0, 0, 255], [85, 0, 255], [170, 0, 255],
|
101 |
+
[0, 85, 255], [0, 170, 255],
|
102 |
+
[255, 255, 0], [255, 255, 85], [255, 255, 170],
|
103 |
+
[255, 0, 255], [255, 85, 255], [255, 170, 255],
|
104 |
+
[0, 255, 255], [85, 255, 255], [170, 255, 255]]
|
105 |
+
|
106 |
+
# image_proj_model maps 1280-dim OpenCLIP embeddings to 768-dim face prompt embeddings.
|
107 |
+
self.image_proj_model = ProjPlusModel(
|
108 |
+
cross_attention_dim=self.unet.config.cross_attention_dim,
|
109 |
+
id_embeddings_dim=512,
|
110 |
+
clip_embeddings_dim=self.clip_encoder.config.hidden_size,
|
111 |
+
num_tokens=self.num_tokens, # 4 - inspirsed by IPAdapter and Midjourney
|
112 |
+
)
|
113 |
+
self.FacialEncoder = FacialEncoder()
|
114 |
+
|
115 |
+
if consistentID_weight_path.endswith(".safetensors"):
|
116 |
+
state_dict = {"id_encoder": {}, "lora_weights": {}}
|
117 |
+
with safe_open(consistentID_weight_path, framework="pt", device="cpu") as f:
|
118 |
+
### TODO safetensors add
|
119 |
+
for key in f.keys():
|
120 |
+
if key.startswith("FacialEncoder."):
|
121 |
+
state_dict["FacialEncoder"][key.replace("FacialEncoder.", "")] = f.get_tensor(key)
|
122 |
+
elif key.startswith("image_proj."):
|
123 |
+
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
|
124 |
+
else:
|
125 |
+
state_dict = torch.load(consistentID_weight_path, map_location="cpu")
|
126 |
+
|
127 |
+
self.trigger_word_facial = trigger_word_facial
|
128 |
+
|
129 |
+
self.FacialEncoder.load_state_dict(state_dict["FacialEncoder"], strict=True)
|
130 |
+
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
|
131 |
+
self.ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values())
|
132 |
+
self.ip_layers.load_state_dict(state_dict["adapter_modules"], strict=True)
|
133 |
+
print(f"Successfully loaded weights from checkpoint")
|
134 |
+
|
135 |
+
# Add trigger word token
|
136 |
+
if self.tokenizer is not None:
|
137 |
+
self.tokenizer.add_tokens([self.trigger_word_facial], special_tokens=True)
|
138 |
+
|
139 |
+
def set_ip_adapter(self):
|
140 |
+
unet = self.unet
|
141 |
+
attn_procs = {}
|
142 |
+
for name in unet.attn_processors.keys():
|
143 |
+
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
144 |
+
if name.startswith("mid_block"):
|
145 |
+
hidden_size = unet.config.block_out_channels[-1]
|
146 |
+
elif name.startswith("up_blocks"):
|
147 |
+
block_id = int(name[len("up_blocks.")])
|
148 |
+
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
149 |
+
elif name.startswith("down_blocks"):
|
150 |
+
block_id = int(name[len("down_blocks.")])
|
151 |
+
hidden_size = unet.config.block_out_channels[block_id]
|
152 |
+
if cross_attention_dim is None:
|
153 |
+
attn_procs[name] = Consistent_AttProcessor(
|
154 |
+
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank,
|
155 |
+
)
|
156 |
+
else:
|
157 |
+
attn_procs[name] = Consistent_IPAttProcessor(
|
158 |
+
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens,
|
159 |
+
)
|
160 |
+
|
161 |
+
unet.set_attn_processor(attn_procs)
|
162 |
+
|
163 |
+
@torch.inference_mode()
|
164 |
+
# parsed_image_parts2 is a batched tensor of parsed_image_parts with bs=1. It only contains the facial areas of one input image.
|
165 |
+
# clip_encoder maps image parts to image-space diffusion prompts.
|
166 |
+
# Then the facial class token embeddings are replaced with the fused (multi_facial_embeds, prompt_embeds[class_tokens_mask]).
|
167 |
+
def extract_local_facial_embeds(self, prompt_embeds, uncond_prompt_embeds, parsed_image_parts2,
|
168 |
+
facial_token_masks, valid_facial_token_idx_mask, calc_uncond=True):
|
169 |
+
|
170 |
+
hidden_states = []
|
171 |
+
uncond_hidden_states = []
|
172 |
+
for parsed_image_parts in parsed_image_parts2:
|
173 |
+
hidden_state = self.clip_encoder(parsed_image_parts.to(self.device, dtype=self.torch_dtype), output_hidden_states=True).hidden_states[-2]
|
174 |
+
uncond_hidden_state = self.clip_encoder(torch.zeros_like(parsed_image_parts, dtype=self.torch_dtype).to(self.device), output_hidden_states=True).hidden_states[-2]
|
175 |
+
hidden_states.append(hidden_state)
|
176 |
+
uncond_hidden_states.append(uncond_hidden_state)
|
177 |
+
multi_facial_embeds = torch.stack(hidden_states)
|
178 |
+
uncond_multi_facial_embeds = torch.stack(uncond_hidden_states)
|
179 |
+
|
180 |
+
# conditional prompt.
|
181 |
+
# FacialEncoder maps multi_facial_embeds to facial ID embeddings, and replaces the class tokens in prompt_embeds
|
182 |
+
# with the fused (facial ID embeddings, prompt_embeds[class_tokens_mask]).
|
183 |
+
# multi_facial_embeds: [1, 5, 257, 1280].
|
184 |
+
facial_prompt_embeds = self.FacialEncoder(prompt_embeds, multi_facial_embeds, facial_token_masks, valid_facial_token_idx_mask)
|
185 |
+
|
186 |
+
if not calc_uncond:
|
187 |
+
return facial_prompt_embeds, None
|
188 |
+
# unconditional prompt.
|
189 |
+
uncond_facial_prompt_embeds = self.FacialEncoder(uncond_prompt_embeds, uncond_multi_facial_embeds, facial_token_masks, valid_facial_token_idx_mask)
|
190 |
+
|
191 |
+
return facial_prompt_embeds, uncond_facial_prompt_embeds
|
192 |
+
|
193 |
+
@torch.inference_mode()
|
194 |
+
# Extrat OpenCLIP embeddings from the input image and map them to face prompt embeddings.
|
195 |
+
def extract_global_id_embeds(self, face_image_obj, s_scale=1.0, shortcut=False):
|
196 |
+
clip_image_ts = self.clip_preprocessor(images=face_image_obj, return_tensors="pt").pixel_values
|
197 |
+
clip_image_ts = clip_image_ts.to(self.device, dtype=self.torch_dtype)
|
198 |
+
clip_image_embeds = self.clip_encoder(clip_image_ts, output_hidden_states=True).hidden_states[-2]
|
199 |
+
uncond_clip_image_embeds = self.clip_encoder(torch.zeros_like(clip_image_ts), output_hidden_states=True).hidden_states[-2]
|
200 |
+
|
201 |
+
faceid_embeds = self.extract_faceid(face_image_obj)
|
202 |
+
faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
|
203 |
+
# image_proj_model maps 1280-dim OpenCLIP embeddings to 768-dim face prompt embeddings.
|
204 |
+
# clip_image_embeds are used as queries to transform faceid_embeds.
|
205 |
+
# faceid_embeds -> kv, clip_image_embeds -> q
|
206 |
+
global_id_embeds = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale)
|
207 |
+
uncond_global_id_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, shortcut=shortcut, scale=s_scale)
|
208 |
+
|
209 |
+
return global_id_embeds, uncond_global_id_embeds
|
210 |
+
|
211 |
+
def set_scale(self, scale):
|
212 |
+
for attn_processor in self.pipe.unet.attn_processors.values():
|
213 |
+
if isinstance(attn_processor, Consistent_IPAttProcessor):
|
214 |
+
attn_processor.scale = scale
|
215 |
+
|
216 |
+
@torch.inference_mode()
|
217 |
+
def extract_faceid(self, face_image_obj):
|
218 |
+
faceid_image = np.array(face_image_obj)
|
219 |
+
faces = self.face_app.get(faceid_image)
|
220 |
+
if faces==[]:
|
221 |
+
faceid_embeds = torch.zeros_like(torch.empty((1, 512)))
|
222 |
+
else:
|
223 |
+
faceid_embeds = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
|
224 |
+
|
225 |
+
return faceid_embeds
|
226 |
+
|
227 |
+
@torch.inference_mode()
|
228 |
+
def parse_face_mask(self, raw_image_refer):
|
229 |
+
|
230 |
+
to_tensor = transforms.Compose([
|
231 |
+
transforms.ToTensor(),
|
232 |
+
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
233 |
+
])
|
234 |
+
to_pil = transforms.ToPILImage()
|
235 |
+
|
236 |
+
with torch.no_grad():
|
237 |
+
image = raw_image_refer.resize((512, 512), Image.BILINEAR)
|
238 |
+
image_resize_PIL = image
|
239 |
+
img = to_tensor(image)
|
240 |
+
img = torch.unsqueeze(img, 0)
|
241 |
+
img = img.to(self.device, dtype=self.torch_dtype)
|
242 |
+
out = self.bise_net(img)[0]
|
243 |
+
parsing_anno = out.squeeze(0).cpu().numpy().argmax(0)
|
244 |
+
|
245 |
+
im = np.array(image_resize_PIL)
|
246 |
+
vis_im = im.copy().astype(np.uint8)
|
247 |
+
stride=1
|
248 |
+
vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
|
249 |
+
vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
|
250 |
+
vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255
|
251 |
+
|
252 |
+
num_of_class = np.max(vis_parsing_anno)
|
253 |
+
|
254 |
+
for pi in range(1, num_of_class + 1): # num_of_class=17 pi=1~16
|
255 |
+
index = np.where(vis_parsing_anno == pi)
|
256 |
+
vis_parsing_anno_color[index[0], index[1], :] = self.part_colors[pi]
|
257 |
+
|
258 |
+
vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
|
259 |
+
vis_parsing_anno_color = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0)
|
260 |
+
|
261 |
+
return vis_parsing_anno_color, vis_parsing_anno
|
262 |
+
|
263 |
+
@torch.inference_mode()
|
264 |
+
def extract_facemask(self, input_image_obj):
|
265 |
+
vis_parsing_anno_color, vis_parsing_anno = self.parse_face_mask(input_image_obj)
|
266 |
+
parsing_mask_list = masks_for_unique_values(vis_parsing_anno)
|
267 |
+
|
268 |
+
key_parsing_mask_dict = {}
|
269 |
+
key_list = ["Face", "Left_Ear", "Right_Ear", "Left_Eye", "Right_Eye", "Nose", "Upper_Lip", "Lower_Lip"]
|
270 |
+
processed_keys = set()
|
271 |
+
for key, mask_image in parsing_mask_list.items():
|
272 |
+
if key in key_list:
|
273 |
+
if "_" in key:
|
274 |
+
prefix = key.split("_")[1]
|
275 |
+
if prefix in processed_keys:
|
276 |
+
continue
|
277 |
+
else:
|
278 |
+
key_parsing_mask_dict[key] = mask_image
|
279 |
+
processed_keys.add(prefix)
|
280 |
+
|
281 |
+
key_parsing_mask_dict[key] = mask_image
|
282 |
+
|
283 |
+
return key_parsing_mask_dict, vis_parsing_anno_color
|
284 |
+
|
285 |
+
def augment_prompt_with_trigger_word(
|
286 |
+
self,
|
287 |
+
prompt: str,
|
288 |
+
face_caption: str,
|
289 |
+
key_parsing_mask_dict = None,
|
290 |
+
facial_token = "<|facial|>",
|
291 |
+
max_num_facials = 5,
|
292 |
+
num_id_images: int = 1,
|
293 |
+
device: Optional[torch.device] = None,
|
294 |
+
):
|
295 |
+
device = device or self._execution_device
|
296 |
+
|
297 |
+
# face_caption_align: 'The person has one nose <|facial|>, two ears <|facial|>, two eyes <|facial|>, and a mouth <|facial|>, '
|
298 |
+
face_caption_align, key_parsing_mask_dict_align = insert_markers_to_prompt(face_caption, key_parsing_mask_dict)
|
299 |
+
|
300 |
+
prompt_face = prompt + " Detail: " + face_caption_align
|
301 |
+
|
302 |
+
max_text_length=330
|
303 |
+
if len(self.tokenizer(prompt_face, max_length=self.tokenizer.model_max_length,
|
304 |
+
padding="max_length", truncation=False, return_tensors="pt").input_ids[0]) != 77:
|
305 |
+
# Put face_caption_align at the beginning of the prompt, so that the original prompt is truncated,
|
306 |
+
# but the face_caption_align is well kept.
|
307 |
+
prompt_face = "Detail: " + face_caption_align + " Caption:" + prompt
|
308 |
+
|
309 |
+
# Remove "<|facial|>" from prompt_face.
|
310 |
+
# augmented_prompt: 'A person, police officer, half body shot Detail:
|
311 |
+
# The person has one nose , two ears , two eyes , and a mouth , '
|
312 |
+
augmented_prompt = prompt_face.replace("<|facial|>", "")
|
313 |
+
tokenizer = self.tokenizer
|
314 |
+
facial_token_id = tokenizer.convert_tokens_to_ids(facial_token)
|
315 |
+
image_token_id = None
|
316 |
+
|
317 |
+
# image_token_id: the token id of "<|image|>". Disabled, as it's set to None.
|
318 |
+
# facial_token_id: the token id of "<|facial|>".
|
319 |
+
clean_input_id, image_token_mask, facial_token_mask = \
|
320 |
+
tokenize_and_mask_noun_phrases_ends(prompt_face, image_token_id, facial_token_id, tokenizer)
|
321 |
+
|
322 |
+
image_token_idx, image_token_idx_mask, facial_token_idx, facial_token_idx_mask = \
|
323 |
+
prepare_image_token_idx(image_token_mask, facial_token_mask, num_id_images, max_num_facials)
|
324 |
+
|
325 |
+
return augmented_prompt, clean_input_id, key_parsing_mask_dict_align, facial_token_mask, facial_token_idx, facial_token_idx_mask
|
326 |
+
|
327 |
+
@torch.inference_mode()
|
328 |
+
def extract_parsed_image_parts(self, input_image_obj, key_parsing_mask_dict, image_size=512, max_num_facials=5):
|
329 |
+
facial_masks = []
|
330 |
+
parsed_image_parts = []
|
331 |
+
key_masked_raw_images_dict = {}
|
332 |
+
transform_mask = transforms.Compose([transforms.CenterCrop(size=image_size), transforms.ToTensor(),])
|
333 |
+
clip_preprocessor = CLIPImageProcessor()
|
334 |
+
|
335 |
+
num_facial_part = len(key_parsing_mask_dict)
|
336 |
+
|
337 |
+
for key in key_parsing_mask_dict:
|
338 |
+
key_mask=key_parsing_mask_dict[key]
|
339 |
+
facial_masks.append(transform_mask(key_mask))
|
340 |
+
key_masked_raw_image = apply_mask_to_raw_image(input_image_obj, key_mask)
|
341 |
+
key_masked_raw_images_dict[key] = key_masked_raw_image
|
342 |
+
# clip_preprocessor normalizes key_masked_raw_image, so that (masked) zero pixels become non-zero.
|
343 |
+
# It also resizes the image to 224x224.
|
344 |
+
parsed_image_part = clip_preprocessor(images=key_masked_raw_image, return_tensors="pt").pixel_values
|
345 |
+
parsed_image_parts.append(parsed_image_part)
|
346 |
+
|
347 |
+
padding_ficial_clip_image = torch.zeros_like(torch.zeros([1, 3, 224, 224]))
|
348 |
+
padding_ficial_mask = torch.zeros_like(torch.zeros([1, image_size, image_size]))
|
349 |
+
|
350 |
+
if num_facial_part < max_num_facials:
|
351 |
+
parsed_image_parts += [ torch.zeros_like(padding_ficial_clip_image) for _ in range(max_num_facials - num_facial_part) ]
|
352 |
+
facial_masks += [ torch.zeros_like(padding_ficial_mask) for _ in range(max_num_facials - num_facial_part) ]
|
353 |
+
|
354 |
+
parsed_image_parts = torch.stack(parsed_image_parts, dim=1).squeeze(0)
|
355 |
+
facial_masks = torch.stack(facial_masks, dim=0).squeeze(dim=1)
|
356 |
+
|
357 |
+
return parsed_image_parts, facial_masks, key_masked_raw_images_dict
|
358 |
+
|
359 |
+
# Release the unet/vae/text_encoder to save memory.
|
360 |
+
def release_components(self, released_components=["unet", "vae", "text_encoder"]):
|
361 |
+
if "unet" in released_components:
|
362 |
+
unet = edict()
|
363 |
+
# Only keep the config and in_channels attributes that are used in the pipeline.
|
364 |
+
unet.config = self.unet.config
|
365 |
+
self.unet = unet
|
366 |
+
|
367 |
+
if "vae" in released_components:
|
368 |
+
self.vae = None
|
369 |
+
if "text_encoder" in released_components:
|
370 |
+
self.text_encoder = None
|
371 |
+
|
372 |
+
# input_subj_image_obj: an Image object.
|
373 |
+
def extract_double_id_prompt_embeds(self, prompt, negative_prompt, input_subj_image_obj, device, calc_uncond=True):
|
374 |
+
face_caption = "The person has one nose, two eyes, two ears, and a mouth."
|
375 |
+
key_parsing_mask_dict, vis_parsing_anno_color = self.extract_facemask(input_subj_image_obj)
|
376 |
+
|
377 |
+
augmented_prompt, clean_input_id, key_parsing_mask_dict_align, \
|
378 |
+
facial_token_mask, facial_token_idx, facial_token_idx_mask \
|
379 |
+
= self.augment_prompt_with_trigger_word(
|
380 |
+
prompt = prompt,
|
381 |
+
face_caption = face_caption,
|
382 |
+
key_parsing_mask_dict=key_parsing_mask_dict,
|
383 |
+
device=device,
|
384 |
+
max_num_facials = 5,
|
385 |
+
num_id_images = 1
|
386 |
+
)
|
387 |
+
|
388 |
+
text_embeds, uncond_text_embeds = self.encode_prompt(
|
389 |
+
augmented_prompt,
|
390 |
+
device=device,
|
391 |
+
num_images_per_prompt=1,
|
392 |
+
do_classifier_free_guidance=calc_uncond,
|
393 |
+
negative_prompt=negative_prompt,
|
394 |
+
)
|
395 |
+
|
396 |
+
# 5. Prepare the input ID images
|
397 |
+
# global_id_embeds: [1, 4, 768]
|
398 |
+
# extract_global_id_embeds() extrats OpenCLIP embeddings from the input image and map them to global face prompt embeddings.
|
399 |
+
global_id_embeds, uncond_global_id_embeds = \
|
400 |
+
self.extract_global_id_embeds(face_image_obj=input_subj_image_obj, s_scale=1.0, shortcut=False)
|
401 |
+
|
402 |
+
# parsed_image_parts: [5, 3, 224, 224]. 5 parts, each part is a 3-channel 224x224 image (resized by CLIP Preprocessor).
|
403 |
+
parsed_image_parts, facial_masks, key_masked_raw_images_dict = \
|
404 |
+
self.extract_parsed_image_parts(input_subj_image_obj, key_parsing_mask_dict_align, image_size=512, max_num_facials=5)
|
405 |
+
parsed_image_parts2 = parsed_image_parts.unsqueeze(0).to(device, dtype=self.torch_dtype)
|
406 |
+
facial_token_mask = facial_token_mask.to(device)
|
407 |
+
facial_token_idx_mask = facial_token_idx_mask.to(device)
|
408 |
+
|
409 |
+
# key_masked_raw_images_dict: ['Right_Eye', 'Right_Ear', 'Nose', 'Upper_Lip']
|
410 |
+
# for key in key_masked_raw_images_dict:
|
411 |
+
# key_masked_raw_images_dict[key].save(f"{key}.png")
|
412 |
+
|
413 |
+
# 6. Get the update text embedding
|
414 |
+
# parsed_image_parts2: the facial areas of the input image
|
415 |
+
# extract_local_facial_embeds() maps parsed_image_parts2 to multi_facial_embeds, and then replaces the class tokens in prompt_embeds
|
416 |
+
# with the fused (id_embeds, prompt_embeds[class_tokens_mask]) whose indices are specified by class_tokens_mask.
|
417 |
+
# parsed_image_parts2: [1, 5, 3, 224, 224]
|
418 |
+
text_local_id_embeds, uncond_text_local_id_embeds = \
|
419 |
+
self.extract_local_facial_embeds(text_embeds, uncond_text_embeds, \
|
420 |
+
parsed_image_parts2, facial_token_mask, facial_token_idx_mask,
|
421 |
+
calc_uncond=calc_uncond)
|
422 |
+
|
423 |
+
# text_global_id_embeds, text_local_global_id_embeds: [1, 81, 768]
|
424 |
+
# text_local_id_embeds: [1, 77, 768], only differs with text_embeds on 4 ID embeddings, and is identical
|
425 |
+
# to text_embeds on the rest 73 tokens.
|
426 |
+
text_global_id_embeds = torch.cat([text_embeds, global_id_embeds], dim=1)
|
427 |
+
text_local_global_id_embeds = torch.cat([text_local_id_embeds, global_id_embeds], dim=1)
|
428 |
+
|
429 |
+
if calc_uncond:
|
430 |
+
uncond_text_global_id_embeds = torch.cat([uncond_text_local_id_embeds, uncond_global_id_embeds], dim=1)
|
431 |
+
coarse_prompt_embeds = torch.cat([uncond_text_global_id_embeds, text_global_id_embeds], dim=0)
|
432 |
+
fine_prompt_embeds = torch.cat([uncond_text_global_id_embeds, text_local_global_id_embeds], dim=0)
|
433 |
+
else:
|
434 |
+
coarse_prompt_embeds = text_global_id_embeds
|
435 |
+
fine_prompt_embeds = text_local_global_id_embeds
|
436 |
+
|
437 |
+
# fine_prompt_embeds: the conditional part is
|
438 |
+
# (text_global_id_embeds + text_local_global_id_embeds) / 2.
|
439 |
+
fine_prompt_embeds = (coarse_prompt_embeds + fine_prompt_embeds) / 2
|
440 |
+
|
441 |
+
return coarse_prompt_embeds, fine_prompt_embeds
|
442 |
+
|
443 |
+
@torch.no_grad()
|
444 |
+
def __call__(
|
445 |
+
self,
|
446 |
+
prompt: Union[str, List[str]] = None,
|
447 |
+
height: Optional[int] = None,
|
448 |
+
width: Optional[int] = None,
|
449 |
+
num_inference_steps: int = 50,
|
450 |
+
guidance_scale: float = 5.0,
|
451 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
452 |
+
num_images_per_prompt: Optional[int] = 1,
|
453 |
+
eta: float = 0.0,
|
454 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
455 |
+
latents: Optional[torch.FloatTensor] = None,
|
456 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
457 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
458 |
+
output_type: Optional[str] = "pil",
|
459 |
+
return_dict: bool = True,
|
460 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
461 |
+
original_size: Optional[Tuple[int, int]] = None,
|
462 |
+
target_size: Optional[Tuple[int, int]] = None,
|
463 |
+
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
464 |
+
callback_steps: int = 1,
|
465 |
+
input_subj_image_objs: PipelineImageInput = None,
|
466 |
+
start_merge_step: int = 0,
|
467 |
+
):
|
468 |
+
# 0. Default height and width to unet
|
469 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
470 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
471 |
+
|
472 |
+
original_size = original_size or (height, width)
|
473 |
+
target_size = target_size or (height, width)
|
474 |
+
|
475 |
+
# 1. Check inputs. Raise error if not correct
|
476 |
+
self.check_inputs(
|
477 |
+
prompt,
|
478 |
+
height,
|
479 |
+
width,
|
480 |
+
callback_steps,
|
481 |
+
negative_prompt,
|
482 |
+
prompt_embeds,
|
483 |
+
negative_prompt_embeds,
|
484 |
+
)
|
485 |
+
|
486 |
+
# 2. Define call parameters
|
487 |
+
if prompt is not None and isinstance(prompt, str):
|
488 |
+
batch_size = 1
|
489 |
+
elif prompt is not None and isinstance(prompt, list):
|
490 |
+
batch_size = len(prompt)
|
491 |
+
else:
|
492 |
+
batch_size = prompt_embeds.shape[0]
|
493 |
+
|
494 |
+
device = self._execution_device
|
495 |
+
do_classifier_free_guidance = guidance_scale >= 1.0
|
496 |
+
assert do_classifier_free_guidance
|
497 |
+
|
498 |
+
if input_subj_image_objs is not None:
|
499 |
+
if not isinstance(input_subj_image_objs, list):
|
500 |
+
input_subj_image_objs = [input_subj_image_objs]
|
501 |
+
|
502 |
+
# 3. Encode input prompt
|
503 |
+
coarse_prompt_embeds, fine_prompt_embeds = \
|
504 |
+
self.extract_double_id_prompt_embeds(prompt, negative_prompt, input_subj_image_objs[0], device)
|
505 |
+
else:
|
506 |
+
# Replace the coarse_prompt_embeds and fine_prompt_embeds with the input prompt_embeds.
|
507 |
+
# This is used when prompt_embeds are computed in advance.
|
508 |
+
cfg_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
509 |
+
coarse_prompt_embeds = cfg_prompt_embeds
|
510 |
+
fine_prompt_embeds = cfg_prompt_embeds
|
511 |
+
|
512 |
+
# 7. Prepare timesteps
|
513 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
514 |
+
timesteps = self.scheduler.timesteps
|
515 |
+
|
516 |
+
# 8. Prepare latent variables
|
517 |
+
num_channels_latents = self.unet.config.in_channels
|
518 |
+
latents = self.prepare_latents(
|
519 |
+
batch_size * num_images_per_prompt,
|
520 |
+
num_channels_latents,
|
521 |
+
height,
|
522 |
+
width,
|
523 |
+
self.dtype,
|
524 |
+
device,
|
525 |
+
generator,
|
526 |
+
latents,
|
527 |
+
)
|
528 |
+
|
529 |
+
# {'eta': 0.0, 'generator': None}. eta is 0 for DDIM.
|
530 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
531 |
+
cross_attention_kwargs = {}
|
532 |
+
|
533 |
+
# 9. Denoising loop
|
534 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
535 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
536 |
+
for i, t in enumerate(timesteps):
|
537 |
+
latent_model_input = (
|
538 |
+
torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
539 |
+
)
|
540 |
+
# DDIM doesn't scale latent_model_input.
|
541 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
542 |
+
|
543 |
+
if i <= start_merge_step:
|
544 |
+
current_prompt_embeds = coarse_prompt_embeds
|
545 |
+
else:
|
546 |
+
current_prompt_embeds = fine_prompt_embeds
|
547 |
+
|
548 |
+
# predict the noise residual
|
549 |
+
noise_pred = self.unet(
|
550 |
+
latent_model_input,
|
551 |
+
t,
|
552 |
+
encoder_hidden_states=current_prompt_embeds,
|
553 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
554 |
+
).sample
|
555 |
+
|
556 |
+
# perform guidance
|
557 |
+
if do_classifier_free_guidance:
|
558 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
559 |
+
noise_pred = noise_pred_uncond + guidance_scale * (
|
560 |
+
noise_pred_text - noise_pred_uncond
|
561 |
+
)
|
562 |
+
else:
|
563 |
+
assert 0, "Not Implemented"
|
564 |
+
|
565 |
+
# compute the previous noisy sample x_t -> x_t-1
|
566 |
+
latents = self.scheduler.step(
|
567 |
+
noise_pred, t, latents, **extra_step_kwargs
|
568 |
+
).prev_sample
|
569 |
+
|
570 |
+
# call the callback, if provided
|
571 |
+
if i == len(timesteps) - 1 or \
|
572 |
+
( (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 ):
|
573 |
+
progress_bar.update()
|
574 |
+
if callback is not None and i % callback_steps == 0:
|
575 |
+
callback(i, t, latents)
|
576 |
+
|
577 |
+
if output_type == "latent":
|
578 |
+
image = latents
|
579 |
+
elif output_type == "pil":
|
580 |
+
# 9.1 Post-processing
|
581 |
+
image = self.decode_latents(latents)
|
582 |
+
# 9.3 Convert to PIL
|
583 |
+
image = self.numpy_to_pil(image)
|
584 |
+
else:
|
585 |
+
# 9.1 Post-processing
|
586 |
+
image = self.decode_latents(latents)
|
587 |
+
|
588 |
+
# Offload last model to CPU
|
589 |
+
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
590 |
+
self.final_offload_hook.offload()
|
591 |
+
|
592 |
+
if not return_dict:
|
593 |
+
return (image, None)
|
594 |
+
|
595 |
+
return StableDiffusionPipelineOutput(
|
596 |
+
images=image, nsfw_content_detected=None
|
597 |
+
)
|
598 |
+
|
599 |
+
|
600 |
+
|
601 |
+
|
602 |
+
|
603 |
+
|
604 |
+
|
605 |
+
|