liuq641968816 committed on
Commit
92394ad
•
1 parent: 7e7d108

Upload 78 files

Files changed (3)
  1. run/gradio_ootd.py +274 -260
  2. run/run_ootd.py +87 -87
  3. run/utils_ootd.py +170 -170
run/gradio_ootd.py CHANGED
@@ -1,260 +1,274 @@
1
- import gradio as gr
2
- import os
3
- from pathlib import Path
4
- import sys
5
- import torch
6
- from PIL import Image, ImageOps
7
-
8
- from utils_ootd import get_mask_location
9
-
10
- PROJECT_ROOT = Path(__file__).absolute().parents[1].absolute()
11
- sys.path.insert(0, str(PROJECT_ROOT))
12
-
13
- import time
14
- from preprocess.openpose.run_openpose import OpenPose
15
- from preprocess.humanparsing.run_parsing import Parsing
16
- from ootd.inference_ootd_hd import OOTDiffusionHD
17
- from ootd.inference_ootd_dc import OOTDiffusionDC
18
-
19
-
20
- openpose_model_hd = OpenPose(0)
21
- parsing_model_hd = Parsing(0)
22
- ootd_model_hd = OOTDiffusionHD(0)
23
-
24
- openpose_model_dc = OpenPose(1)
25
- parsing_model_dc = Parsing(1)
26
- ootd_model_dc = OOTDiffusionDC(1)
27
-
28
-
29
- category_dict = ['upperbody', 'lowerbody', 'dress']
30
- category_dict_utils = ['upper_body', 'lower_body', 'dresses']
31
-
32
-
33
- example_path = os.path.join(os.path.dirname(__file__), 'examples')
34
- model_hd = os.path.join(example_path, 'model/model_1.png')
35
- garment_hd = os.path.join(example_path, 'garment/03244_00.jpg')
36
- model_dc = os.path.join(example_path, 'model/model_8.png')
37
- garment_dc = os.path.join(example_path, 'garment/048554_1.jpg')
38
-
39
- def process_hd(vton_img, garm_img, n_samples, n_steps, image_scale, seed):
40
- model_type = 'hd'
41
- category = 0 # 0:upperbody; 1:lowerbody; 2:dress
42
-
43
- with torch.no_grad():
44
- garm_img = Image.open(garm_img).resize((768, 1024))
45
- vton_img = Image.open(vton_img).resize((768, 1024))
46
- keypoints = openpose_model_hd(vton_img.resize((384, 512)))
47
- model_parse, _ = parsing_model_hd(vton_img.resize((384, 512)))
48
-
49
- mask, mask_gray = get_mask_location(model_type, category_dict_utils[category], model_parse, keypoints)
50
- mask = mask.resize((768, 1024), Image.NEAREST)
51
- mask_gray = mask_gray.resize((768, 1024), Image.NEAREST)
52
-
53
- masked_vton_img = Image.composite(mask_gray, vton_img, mask)
54
-
55
- images = ootd_model_hd(
56
- model_type=model_type,
57
- category=category_dict[category],
58
- image_garm=garm_img,
59
- image_vton=masked_vton_img,
60
- mask=mask,
61
- image_ori=vton_img,
62
- num_samples=n_samples,
63
- num_steps=n_steps,
64
- image_scale=image_scale,
65
- seed=seed,
66
- )
67
-
68
- return images
69
-
70
- def process_dc(vton_img, garm_img, category, n_samples, n_steps, image_scale, seed):
71
- model_type = 'dc'
72
- if category == 'Upper-body':
73
- category = 0
74
- elif category == 'Lower-body':
75
- category = 1
76
- else:
77
- category =2
78
-
79
- with torch.no_grad():
80
- garm_img = Image.open(garm_img).resize((768, 1024))
81
- vton_img = Image.open(vton_img).resize((768, 1024))
82
- keypoints = openpose_model_dc(vton_img.resize((384, 512)))
83
- model_parse, _ = parsing_model_dc(vton_img.resize((384, 512)))
84
-
85
- mask, mask_gray = get_mask_location(model_type, category_dict_utils[category], model_parse, keypoints)
86
- mask = mask.resize((768, 1024), Image.NEAREST)
87
- mask_gray = mask_gray.resize((768, 1024), Image.NEAREST)
88
-
89
- masked_vton_img = Image.composite(mask_gray, vton_img, mask)
90
-
91
- images = ootd_model_dc(
92
- model_type=model_type,
93
- category=category_dict[category],
94
- image_garm=garm_img,
95
- image_vton=masked_vton_img,
96
- mask=mask,
97
- image_ori=vton_img,
98
- num_samples=n_samples,
99
- num_steps=n_steps,
100
- image_scale=image_scale,
101
- seed=seed,
102
- )
103
-
104
- return images
105
-
106
-
107
- block = gr.Blocks().queue()
108
- with block:
109
- with gr.Row():
110
- gr.Markdown("# OOTDiffusion Demo")
111
- with gr.Row():
112
- gr.Markdown("## Half-body")
113
- with gr.Row():
114
- gr.Markdown("***Support upper-body garments***")
115
- with gr.Row():
116
- with gr.Column():
117
- vton_img = gr.Image(label="Model", sources='upload', type="filepath", height=384, value=model_hd)
118
- example = gr.Examples(
119
- inputs=vton_img,
120
- examples_per_page=14,
121
- examples=[
122
- os.path.join(example_path, 'model/model_1.png'),
123
- os.path.join(example_path, 'model/model_2.png'),
124
- os.path.join(example_path, 'model/model_3.png'),
125
- os.path.join(example_path, 'model/model_4.png'),
126
- os.path.join(example_path, 'model/model_5.png'),
127
- os.path.join(example_path, 'model/model_6.png'),
128
- os.path.join(example_path, 'model/model_7.png'),
129
- os.path.join(example_path, 'model/01008_00.jpg'),
130
- os.path.join(example_path, 'model/07966_00.jpg'),
131
- os.path.join(example_path, 'model/05997_00.jpg'),
132
- os.path.join(example_path, 'model/02849_00.jpg'),
133
- os.path.join(example_path, 'model/14627_00.jpg'),
134
- os.path.join(example_path, 'model/09597_00.jpg'),
135
- os.path.join(example_path, 'model/01861_00.jpg'),
136
- ])
137
- with gr.Column():
138
- garm_img = gr.Image(label="Garment", sources='upload', type="filepath", height=384, value=garment_hd)
139
- example = gr.Examples(
140
- inputs=garm_img,
141
- examples_per_page=14,
142
- examples=[
143
- os.path.join(example_path, 'garment/03244_00.jpg'),
144
- os.path.join(example_path, 'garment/00126_00.jpg'),
145
- os.path.join(example_path, 'garment/03032_00.jpg'),
146
- os.path.join(example_path, 'garment/06123_00.jpg'),
147
- os.path.join(example_path, 'garment/02305_00.jpg'),
148
- os.path.join(example_path, 'garment/00055_00.jpg'),
149
- os.path.join(example_path, 'garment/00470_00.jpg'),
150
- os.path.join(example_path, 'garment/02015_00.jpg'),
151
- os.path.join(example_path, 'garment/10297_00.jpg'),
152
- os.path.join(example_path, 'garment/07382_00.jpg'),
153
- os.path.join(example_path, 'garment/07764_00.jpg'),
154
- os.path.join(example_path, 'garment/00151_00.jpg'),
155
- os.path.join(example_path, 'garment/12562_00.jpg'),
156
- os.path.join(example_path, 'garment/04825_00.jpg'),
157
- ])
158
- with gr.Column():
159
- result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", preview=True, scale=1)
160
- with gr.Column():
161
- run_button = gr.Button(value="Run")
162
- n_samples = gr.Slider(label="Images", minimum=1, maximum=4, value=1, step=1)
163
- n_steps = gr.Slider(label="Steps", minimum=20, maximum=40, value=20, step=1)
164
- # scale = gr.Slider(label="Scale", minimum=1.0, maximum=12.0, value=5.0, step=0.1)
165
- image_scale = gr.Slider(label="Guidance scale", minimum=1.0, maximum=5.0, value=2.0, step=0.1)
166
- seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=-1)
167
-
168
- ips = [vton_img, garm_img, n_samples, n_steps, image_scale, seed]
169
- run_button.click(fn=process_hd, inputs=ips, outputs=[result_gallery])
170
-
171
-
172
- with gr.Row():
173
- gr.Markdown("## Full-body")
174
- with gr.Row():
175
- gr.Markdown("***Support upper-body/lower-body/dresses; garment category must be paired!!!***")
176
- with gr.Row():
177
- with gr.Column():
178
- vton_img_dc = gr.Image(label="Model", sources='upload', type="filepath", height=384, value=model_dc)
179
- example = gr.Examples(
180
- label="Examples (upper-body/lower-body)",
181
- inputs=vton_img_dc,
182
- examples_per_page=7,
183
- examples=[
184
- os.path.join(example_path, 'model/model_8.png'),
185
- os.path.join(example_path, 'model/049447_0.jpg'),
186
- os.path.join(example_path, 'model/049713_0.jpg'),
187
- os.path.join(example_path, 'model/051482_0.jpg'),
188
- os.path.join(example_path, 'model/051918_0.jpg'),
189
- os.path.join(example_path, 'model/051962_0.jpg'),
190
- os.path.join(example_path, 'model/049205_0.jpg'),
191
- ])
192
- example = gr.Examples(
193
- label="Examples (dress)",
194
- inputs=vton_img_dc,
195
- examples_per_page=7,
196
- examples=[
197
- os.path.join(example_path, 'model/model_9.png'),
198
- os.path.join(example_path, 'model/052767_0.jpg'),
199
- os.path.join(example_path, 'model/052472_0.jpg'),
200
- os.path.join(example_path, 'model/053514_0.jpg'),
201
- os.path.join(example_path, 'model/053228_0.jpg'),
202
- os.path.join(example_path, 'model/052964_0.jpg'),
203
- os.path.join(example_path, 'model/053700_0.jpg'),
204
- ])
205
- with gr.Column():
206
- garm_img_dc = gr.Image(label="Garment", sources='upload', type="filepath", height=384, value=garment_dc)
207
- category_dc = gr.Dropdown(label="Garment category (important option!!!)", choices=["Upper-body", "Lower-body", "Dress"], value="Upper-body")
208
- example = gr.Examples(
209
- label="Examples (upper-body)",
210
- inputs=garm_img_dc,
211
- examples_per_page=7,
212
- examples=[
213
- os.path.join(example_path, 'garment/048554_1.jpg'),
214
- os.path.join(example_path, 'garment/049920_1.jpg'),
215
- os.path.join(example_path, 'garment/049965_1.jpg'),
216
- os.path.join(example_path, 'garment/049949_1.jpg'),
217
- os.path.join(example_path, 'garment/050181_1.jpg'),
218
- os.path.join(example_path, 'garment/049805_1.jpg'),
219
- os.path.join(example_path, 'garment/050105_1.jpg'),
220
- ])
221
- example = gr.Examples(
222
- label="Examples (lower-body)",
223
- inputs=garm_img_dc,
224
- examples_per_page=7,
225
- examples=[
226
- os.path.join(example_path, 'garment/051827_1.jpg'),
227
- os.path.join(example_path, 'garment/051946_1.jpg'),
228
- os.path.join(example_path, 'garment/051473_1.jpg'),
229
- os.path.join(example_path, 'garment/051515_1.jpg'),
230
- os.path.join(example_path, 'garment/051517_1.jpg'),
231
- os.path.join(example_path, 'garment/051988_1.jpg'),
232
- os.path.join(example_path, 'garment/051412_1.jpg'),
233
- ])
234
- example = gr.Examples(
235
- label="Examples (dress)",
236
- inputs=garm_img_dc,
237
- examples_per_page=7,
238
- examples=[
239
- os.path.join(example_path, 'garment/053290_1.jpg'),
240
- os.path.join(example_path, 'garment/053744_1.jpg'),
241
- os.path.join(example_path, 'garment/053742_1.jpg'),
242
- os.path.join(example_path, 'garment/053786_1.jpg'),
243
- os.path.join(example_path, 'garment/053790_1.jpg'),
244
- os.path.join(example_path, 'garment/053319_1.jpg'),
245
- os.path.join(example_path, 'garment/052234_1.jpg'),
246
- ])
247
- with gr.Column():
248
- result_gallery_dc = gr.Gallery(label='Output', show_label=False, elem_id="gallery", preview=True, scale=1)
249
- with gr.Column():
250
- run_button_dc = gr.Button(value="Run")
251
- n_samples_dc = gr.Slider(label="Images", minimum=1, maximum=4, value=1, step=1)
252
- n_steps_dc = gr.Slider(label="Steps", minimum=20, maximum=40, value=20, step=1)
253
- # scale_dc = gr.Slider(label="Scale", minimum=1.0, maximum=12.0, value=5.0, step=0.1)
254
- image_scale_dc = gr.Slider(label="Guidance scale", minimum=1.0, maximum=5.0, value=2.0, step=0.1)
255
- seed_dc = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=-1)
256
-
257
- ips_dc = [vton_img_dc, garm_img_dc, category_dc, n_samples_dc, n_steps_dc, image_scale_dc, seed_dc]
258
- run_button_dc.click(fn=process_dc, inputs=ips_dc, outputs=[result_gallery_dc])
259
-
260
- block.launch(server_name='0.0.0.0', server_port=7865)
 
1
+ import gradio as gr
2
+ import os
3
+ from pathlib import Path
4
+ import sys
5
+ import torch
6
+ from PIL import Image, ImageOps
7
+
8
+ from utils_ootd import get_mask_location
9
+
10
+ PROJECT_ROOT = Path(__file__).absolute().parents[1].absolute()
11
+ sys.path.insert(0, str(PROJECT_ROOT))
12
+
13
+ from preprocess.openpose.run_openpose import OpenPose
14
+ from preprocess.humanparsing.run_parsing import Parsing
15
+ from ootd.inference_ootd_hd import OOTDiffusionHD
16
+ from ootd.inference_ootd_dc import OOTDiffusionDC
17
+
18
+
19
+ openpose_model_hd = OpenPose(0)
20
+ parsing_model_hd = Parsing(0)
21
+ ootd_model_hd = OOTDiffusionHD(0)
22
+
23
+ openpose_model_dc = OpenPose(1)
24
+ parsing_model_dc = Parsing(1)
25
+ ootd_model_dc = OOTDiffusionDC(1)
26
+
27
+
28
+ category_dict = ['upperbody', 'lowerbody', 'dress']
29
+ category_dict_utils = ['upper_body', 'lower_body', 'dresses']
30
+
31
+
32
+ example_path = os.path.join(os.path.dirname(__file__), 'examples')
33
+ model_hd = os.path.join(example_path, 'model/model_1.png')
34
+ garment_hd = os.path.join(example_path, 'garment/03244_00.jpg')
35
+ model_dc = os.path.join(example_path, 'model/model_8.png')
36
+ garment_dc = os.path.join(example_path, 'garment/048554_1.jpg')
37
+
38
+
39
+ import spaces
40
+
41
+ @spaces.GPU
42
+ def process_hd(vton_img, garm_img, n_samples, n_steps, image_scale, seed):
43
+ model_type = 'hd'
44
+ category = 0 # 0:upperbody; 1:lowerbody; 2:dress
45
+
46
+ with torch.no_grad():
47
+ openpose_model_hd.preprocessor.body_estimation.model.to('cuda')
48
+ ootd_model_hd.pipe.to('cuda')
49
+ ootd_model_hd.image_encoder.to('cuda')
50
+ ootd_model_hd.text_encoder.to('cuda')
51
+
52
+ garm_img = Image.open(garm_img).resize((768, 1024))
53
+ vton_img = Image.open(vton_img).resize((768, 1024))
54
+ keypoints = openpose_model_hd(vton_img.resize((384, 512)))
55
+ model_parse, _ = parsing_model_hd(vton_img.resize((384, 512)))
56
+
57
+ mask, mask_gray = get_mask_location(model_type, category_dict_utils[category], model_parse, keypoints)
58
+ mask = mask.resize((768, 1024), Image.NEAREST)
59
+ mask_gray = mask_gray.resize((768, 1024), Image.NEAREST)
60
+
61
+ masked_vton_img = Image.composite(mask_gray, vton_img, mask)
62
+
63
+ images = ootd_model_hd(
64
+ model_type=model_type,
65
+ category=category_dict[category],
66
+ image_garm=garm_img,
67
+ image_vton=masked_vton_img,
68
+ mask=mask,
69
+ image_ori=vton_img,
70
+ num_samples=n_samples,
71
+ num_steps=n_steps,
72
+ image_scale=image_scale,
73
+ seed=seed,
74
+ )
75
+
76
+ return images
77
+
78
+ @spaces.GPU
79
+ def process_dc(vton_img, garm_img, category, n_samples, n_steps, image_scale, seed):
80
+ model_type = 'dc'
81
+ if category == 'Upper-body':
82
+ category = 0
83
+ elif category == 'Lower-body':
84
+ category = 1
85
+ else:
86
+ category =2
87
+
88
+ with torch.no_grad():
89
+ openpose_model_dc.preprocessor.body_estimation.model.to('cuda')
90
+ ootd_model_dc.pipe.to('cuda')
91
+ ootd_model_dc.image_encoder.to('cuda')
92
+ ootd_model_dc.text_encoder.to('cuda')
93
+
94
+ garm_img = Image.open(garm_img).resize((768, 1024))
95
+ vton_img = Image.open(vton_img).resize((768, 1024))
96
+ keypoints = openpose_model_dc(vton_img.resize((384, 512)))
97
+ model_parse, _ = parsing_model_dc(vton_img.resize((384, 512)))
98
+
99
+ mask, mask_gray = get_mask_location(model_type, category_dict_utils[category], model_parse, keypoints)
100
+ mask = mask.resize((768, 1024), Image.NEAREST)
101
+ mask_gray = mask_gray.resize((768, 1024), Image.NEAREST)
102
+
103
+ masked_vton_img = Image.composite(mask_gray, vton_img, mask)
104
+
105
+ images = ootd_model_dc(
106
+ model_type=model_type,
107
+ category=category_dict[category],
108
+ image_garm=garm_img,
109
+ image_vton=masked_vton_img,
110
+ mask=mask,
111
+ image_ori=vton_img,
112
+ num_samples=n_samples,
113
+ num_steps=n_steps,
114
+ image_scale=image_scale,
115
+ seed=seed,
116
+ )
117
+
118
+ return images
119
+
120
+
121
+ block = gr.Blocks().queue()
122
+ with block:
123
+ with gr.Row():
124
+ gr.Markdown("# OOTDiffusion Demo")
125
+ with gr.Row():
126
+ gr.Markdown("## Half-body")
127
+ with gr.Row():
128
+ gr.Markdown("***Support upper-body garments***")
129
+ with gr.Row():
130
+ with gr.Column():
131
+ vton_img = gr.Image(label="Model", sources='upload', type="filepath", height=384, value=model_hd)
132
+ example = gr.Examples(
133
+ inputs=vton_img,
134
+ examples_per_page=14,
135
+ examples=[
136
+ os.path.join(example_path, 'model/model_1.png'),
137
+ os.path.join(example_path, 'model/model_2.png'),
138
+ os.path.join(example_path, 'model/model_3.png'),
139
+ os.path.join(example_path, 'model/model_4.png'),
140
+ os.path.join(example_path, 'model/model_5.png'),
141
+ os.path.join(example_path, 'model/model_6.png'),
142
+ os.path.join(example_path, 'model/model_7.png'),
143
+ os.path.join(example_path, 'model/01008_00.jpg'),
144
+ os.path.join(example_path, 'model/07966_00.jpg'),
145
+ os.path.join(example_path, 'model/05997_00.jpg'),
146
+ os.path.join(example_path, 'model/02849_00.jpg'),
147
+ os.path.join(example_path, 'model/14627_00.jpg'),
148
+ os.path.join(example_path, 'model/09597_00.jpg'),
149
+ os.path.join(example_path, 'model/01861_00.jpg'),
150
+ ])
151
+ with gr.Column():
152
+ garm_img = gr.Image(label="Garment", sources='upload', type="filepath", height=384, value=garment_hd)
153
+ example = gr.Examples(
154
+ inputs=garm_img,
155
+ examples_per_page=14,
156
+ examples=[
157
+ os.path.join(example_path, 'garment/03244_00.jpg'),
158
+ os.path.join(example_path, 'garment/00126_00.jpg'),
159
+ os.path.join(example_path, 'garment/03032_00.jpg'),
160
+ os.path.join(example_path, 'garment/06123_00.jpg'),
161
+ os.path.join(example_path, 'garment/02305_00.jpg'),
162
+ os.path.join(example_path, 'garment/00055_00.jpg'),
163
+ os.path.join(example_path, 'garment/00470_00.jpg'),
164
+ os.path.join(example_path, 'garment/02015_00.jpg'),
165
+ os.path.join(example_path, 'garment/10297_00.jpg'),
166
+ os.path.join(example_path, 'garment/07382_00.jpg'),
167
+ os.path.join(example_path, 'garment/07764_00.jpg'),
168
+ os.path.join(example_path, 'garment/00151_00.jpg'),
169
+ os.path.join(example_path, 'garment/12562_00.jpg'),
170
+ os.path.join(example_path, 'garment/04825_00.jpg'),
171
+ ])
172
+ with gr.Column():
173
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", preview=True, scale=1)
174
+ with gr.Column():
175
+ run_button = gr.Button(value="Run")
176
+ n_samples = gr.Slider(label="Images", minimum=1, maximum=4, value=1, step=1)
177
+ n_steps = gr.Slider(label="Steps", minimum=20, maximum=40, value=20, step=1)
178
+ # scale = gr.Slider(label="Scale", minimum=1.0, maximum=12.0, value=5.0, step=0.1)
179
+ image_scale = gr.Slider(label="Guidance scale", minimum=1.0, maximum=5.0, value=2.0, step=0.1)
180
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=-1)
181
+
182
+ ips = [vton_img, garm_img, n_samples, n_steps, image_scale, seed]
183
+ run_button.click(fn=process_hd, inputs=ips, outputs=[result_gallery])
184
+
185
+
186
+ with gr.Row():
187
+ gr.Markdown("## Full-body")
188
+ with gr.Row():
189
+ gr.Markdown("***Support upper-body/lower-body/dresses; garment category must be paired!!!***")
190
+ with gr.Row():
191
+ with gr.Column():
192
+ vton_img_dc = gr.Image(label="Model", sources='upload', type="filepath", height=384, value=model_dc)
193
+ example = gr.Examples(
194
+ label="Examples (upper-body/lower-body)",
195
+ inputs=vton_img_dc,
196
+ examples_per_page=7,
197
+ examples=[
198
+ os.path.join(example_path, 'model/model_8.png'),
199
+ os.path.join(example_path, 'model/049447_0.jpg'),
200
+ os.path.join(example_path, 'model/049713_0.jpg'),
201
+ os.path.join(example_path, 'model/051482_0.jpg'),
202
+ os.path.join(example_path, 'model/051918_0.jpg'),
203
+ os.path.join(example_path, 'model/051962_0.jpg'),
204
+ os.path.join(example_path, 'model/049205_0.jpg'),
205
+ ])
206
+ example = gr.Examples(
207
+ label="Examples (dress)",
208
+ inputs=vton_img_dc,
209
+ examples_per_page=7,
210
+ examples=[
211
+ os.path.join(example_path, 'model/model_9.png'),
212
+ os.path.join(example_path, 'model/052767_0.jpg'),
213
+ os.path.join(example_path, 'model/052472_0.jpg'),
214
+ os.path.join(example_path, 'model/053514_0.jpg'),
215
+ os.path.join(example_path, 'model/053228_0.jpg'),
216
+ os.path.join(example_path, 'model/052964_0.jpg'),
217
+ os.path.join(example_path, 'model/053700_0.jpg'),
218
+ ])
219
+ with gr.Column():
220
+ garm_img_dc = gr.Image(label="Garment", sources='upload', type="filepath", height=384, value=garment_dc)
221
+ category_dc = gr.Dropdown(label="Garment category (important option!!!)", choices=["Upper-body", "Lower-body", "Dress"], value="Upper-body")
222
+ example = gr.Examples(
223
+ label="Examples (upper-body)",
224
+ inputs=garm_img_dc,
225
+ examples_per_page=7,
226
+ examples=[
227
+ os.path.join(example_path, 'garment/048554_1.jpg'),
228
+ os.path.join(example_path, 'garment/049920_1.jpg'),
229
+ os.path.join(example_path, 'garment/049965_1.jpg'),
230
+ os.path.join(example_path, 'garment/049949_1.jpg'),
231
+ os.path.join(example_path, 'garment/050181_1.jpg'),
232
+ os.path.join(example_path, 'garment/049805_1.jpg'),
233
+ os.path.join(example_path, 'garment/050105_1.jpg'),
234
+ ])
235
+ example = gr.Examples(
236
+ label="Examples (lower-body)",
237
+ inputs=garm_img_dc,
238
+ examples_per_page=7,
239
+ examples=[
240
+ os.path.join(example_path, 'garment/051827_1.jpg'),
241
+ os.path.join(example_path, 'garment/051946_1.jpg'),
242
+ os.path.join(example_path, 'garment/051473_1.jpg'),
243
+ os.path.join(example_path, 'garment/051515_1.jpg'),
244
+ os.path.join(example_path, 'garment/051517_1.jpg'),
245
+ os.path.join(example_path, 'garment/051988_1.jpg'),
246
+ os.path.join(example_path, 'garment/051412_1.jpg'),
247
+ ])
248
+ example = gr.Examples(
249
+ label="Examples (dress)",
250
+ inputs=garm_img_dc,
251
+ examples_per_page=7,
252
+ examples=[
253
+ os.path.join(example_path, 'garment/053290_1.jpg'),
254
+ os.path.join(example_path, 'garment/053744_1.jpg'),
255
+ os.path.join(example_path, 'garment/053742_1.jpg'),
256
+ os.path.join(example_path, 'garment/053786_1.jpg'),
257
+ os.path.join(example_path, 'garment/053790_1.jpg'),
258
+ os.path.join(example_path, 'garment/053319_1.jpg'),
259
+ os.path.join(example_path, 'garment/052234_1.jpg'),
260
+ ])
261
+ with gr.Column():
262
+ result_gallery_dc = gr.Gallery(label='Output', show_label=False, elem_id="gallery", preview=True, scale=1)
263
+ with gr.Column():
264
+ run_button_dc = gr.Button(value="Run")
265
+ n_samples_dc = gr.Slider(label="Images", minimum=1, maximum=4, value=1, step=1)
266
+ n_steps_dc = gr.Slider(label="Steps", minimum=20, maximum=40, value=20, step=1)
267
+ # scale_dc = gr.Slider(label="Scale", minimum=1.0, maximum=12.0, value=5.0, step=0.1)
268
+ image_scale_dc = gr.Slider(label="Guidance scale", minimum=1.0, maximum=5.0, value=2.0, step=0.1)
269
+ seed_dc = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, value=-1)
270
+
271
+ ips_dc = [vton_img_dc, garm_img_dc, category_dc, n_samples_dc, n_steps_dc, image_scale_dc, seed_dc]
272
+ run_button_dc.click(fn=process_dc, inputs=ips_dc, outputs=[result_gallery_dc])
273
+
274
+ block.launch()
run/run_ootd.py CHANGED
@@ -1,87 +1,87 @@
1
- from pathlib import Path
2
- import sys
3
- from PIL import Image
4
- from utils_ootd import get_mask_location
5
-
6
- PROJECT_ROOT = Path(__file__).absolute().parents[1].absolute()
7
- sys.path.insert(0, str(PROJECT_ROOT))
8
-
9
- from preprocess.openpose.run_openpose import OpenPose
10
- from preprocess.humanparsing.run_parsing import Parsing
11
- from ootd.inference_ootd_hd import OOTDiffusionHD
12
- from ootd.inference_ootd_dc import OOTDiffusionDC
13
-
14
-
15
- import argparse
16
- parser = argparse.ArgumentParser(description='run ootd')
17
- parser.add_argument('--gpu_id', '-g', type=int, default=0, required=False)
18
- parser.add_argument('--model_path', type=str, default="", required=True)
19
- parser.add_argument('--cloth_path', type=str, default="", required=True)
20
- parser.add_argument('--model_type', type=str, default="hd", required=False)
21
- parser.add_argument('--category', '-c', type=int, default=0, required=False)
22
- parser.add_argument('--scale', type=float, default=2.0, required=False)
23
- parser.add_argument('--step', type=int, default=20, required=False)
24
- parser.add_argument('--sample', type=int, default=4, required=False)
25
- parser.add_argument('--seed', type=int, default=-1, required=False)
26
- args = parser.parse_args()
27
-
28
-
29
- openpose_model = OpenPose(args.gpu_id)
30
- parsing_model = Parsing(args.gpu_id)
31
-
32
-
33
- category_dict = ['upperbody', 'lowerbody', 'dress']
34
- category_dict_utils = ['upper_body', 'lower_body', 'dresses']
35
-
36
- model_type = args.model_type # "hd" or "dc"
37
- category = args.category # 0:upperbody; 1:lowerbody; 2:dress
38
- cloth_path = args.cloth_path
39
- model_path = args.model_path
40
-
41
- image_scale = args.scale
42
- n_steps = args.step
43
- n_samples = args.sample
44
- seed = args.seed
45
-
46
- if model_type == "hd":
47
- model = OOTDiffusionHD(args.gpu_id)
48
- elif model_type == "dc":
49
- model = OOTDiffusionDC(args.gpu_id)
50
- else:
51
- raise ValueError("model_type must be \'hd\' or \'dc\'!")
52
-
53
-
54
- if __name__ == '__main__':
55
-
56
- if model_type == 'hd' and category != 0:
57
- raise ValueError("model_type \'hd\' requires category == 0 (upperbody)!")
58
-
59
- cloth_img = Image.open(cloth_path).resize((768, 1024))
60
- model_img = Image.open(model_path).resize((768, 1024))
61
- keypoints = openpose_model(model_img.resize((384, 512)))
62
- model_parse, _ = parsing_model(model_img.resize((384, 512)))
63
-
64
- mask, mask_gray = get_mask_location(model_type, category_dict_utils[category], model_parse, keypoints)
65
- mask = mask.resize((768, 1024), Image.NEAREST)
66
- mask_gray = mask_gray.resize((768, 1024), Image.NEAREST)
67
-
68
- masked_vton_img = Image.composite(mask_gray, model_img, mask)
69
- masked_vton_img.save('./images_output/mask.jpg')
70
-
71
- images = model(
72
- model_type=model_type,
73
- category=category_dict[category],
74
- image_garm=cloth_img,
75
- image_vton=masked_vton_img,
76
- mask=mask,
77
- image_ori=model_img,
78
- num_samples=n_samples,
79
- num_steps=n_steps,
80
- image_scale=image_scale,
81
- seed=seed,
82
- )
83
-
84
- image_idx = 0
85
- for image in images:
86
- image.save('./images_output/out_' + model_type + '_' + str(image_idx) + '.png')
87
- image_idx += 1
 
1
+ from pathlib import Path
2
+ import sys
3
+ from PIL import Image
4
+ from utils_ootd import get_mask_location
5
+
6
+ PROJECT_ROOT = Path(__file__).absolute().parents[1].absolute()
7
+ sys.path.insert(0, str(PROJECT_ROOT))
8
+
9
+ from preprocess.openpose.run_openpose import OpenPose
10
+ from preprocess.humanparsing.run_parsing import Parsing
11
+ from ootd.inference_ootd_hd import OOTDiffusionHD
12
+ from ootd.inference_ootd_dc import OOTDiffusionDC
13
+
14
+
15
+ import argparse
16
+ parser = argparse.ArgumentParser(description='run ootd')
17
+ parser.add_argument('--gpu_id', '-g', type=int, default=0, required=False)
18
+ parser.add_argument('--model_path', type=str, default="", required=True)
19
+ parser.add_argument('--cloth_path', type=str, default="", required=True)
20
+ parser.add_argument('--model_type', type=str, default="hd", required=False)
21
+ parser.add_argument('--category', '-c', type=int, default=0, required=False)
22
+ parser.add_argument('--scale', type=float, default=2.0, required=False)
23
+ parser.add_argument('--step', type=int, default=20, required=False)
24
+ parser.add_argument('--sample', type=int, default=4, required=False)
25
+ parser.add_argument('--seed', type=int, default=-1, required=False)
26
+ args = parser.parse_args()
27
+
28
+
29
+ openpose_model = OpenPose(args.gpu_id)
30
+ parsing_model = Parsing(args.gpu_id)
31
+
32
+
33
+ category_dict = ['upperbody', 'lowerbody', 'dress']
34
+ category_dict_utils = ['upper_body', 'lower_body', 'dresses']
35
+
36
+ model_type = args.model_type # "hd" or "dc"
37
+ category = args.category # 0:upperbody; 1:lowerbody; 2:dress
38
+ cloth_path = args.cloth_path
39
+ model_path = args.model_path
40
+
41
+ image_scale = args.scale
42
+ n_steps = args.step
43
+ n_samples = args.sample
44
+ seed = args.seed
45
+
46
+ if model_type == "hd":
47
+ model = OOTDiffusionHD(args.gpu_id)
48
+ elif model_type == "dc":
49
+ model = OOTDiffusionDC(args.gpu_id)
50
+ else:
51
+ raise ValueError("model_type must be \'hd\' or \'dc\'!")
52
+
53
+
54
+ if __name__ == '__main__':
55
+
56
+ if model_type == 'hd' and category != 0:
57
+ raise ValueError("model_type \'hd\' requires category == 0 (upperbody)!")
58
+
59
+ cloth_img = Image.open(cloth_path).resize((768, 1024))
60
+ model_img = Image.open(model_path).resize((768, 1024))
61
+ keypoints = openpose_model(model_img.resize((384, 512)))
62
+ model_parse, _ = parsing_model(model_img.resize((384, 512)))
63
+
64
+ mask, mask_gray = get_mask_location(model_type, category_dict_utils[category], model_parse, keypoints)
65
+ mask = mask.resize((768, 1024), Image.NEAREST)
66
+ mask_gray = mask_gray.resize((768, 1024), Image.NEAREST)
67
+
68
+ masked_vton_img = Image.composite(mask_gray, model_img, mask)
69
+ masked_vton_img.save('./images_output/mask.jpg')
70
+
71
+ images = model(
72
+ model_type=model_type,
73
+ category=category_dict[category],
74
+ image_garm=cloth_img,
75
+ image_vton=masked_vton_img,
76
+ mask=mask,
77
+ image_ori=model_img,
78
+ num_samples=n_samples,
79
+ num_steps=n_steps,
80
+ image_scale=image_scale,
81
+ seed=seed,
82
+ )
83
+
84
+ image_idx = 0
85
+ for image in images:
86
+ image.save('./images_output/out_' + model_type + '_' + str(image_idx) + '.png')
87
+ image_idx += 1
run/utils_ootd.py CHANGED
@@ -1,170 +1,170 @@
1
- import pdb
2
-
3
- import numpy as np
4
- import cv2
5
- from PIL import Image, ImageDraw
6
-
7
- label_map = {
8
- "background": 0,
9
- "hat": 1,
10
- "hair": 2,
11
- "sunglasses": 3,
12
- "upper_clothes": 4,
13
- "skirt": 5,
14
- "pants": 6,
15
- "dress": 7,
16
- "belt": 8,
17
- "left_shoe": 9,
18
- "right_shoe": 10,
19
- "head": 11,
20
- "left_leg": 12,
21
- "right_leg": 13,
22
- "left_arm": 14,
23
- "right_arm": 15,
24
- "bag": 16,
25
- "scarf": 17,
26
- }
27
-
28
- def extend_arm_mask(wrist, elbow, scale):
29
- wrist = elbow + scale * (wrist - elbow)
30
- return wrist
31
-
32
- def hole_fill(img):
33
- img = np.pad(img[1:-1, 1:-1], pad_width = 1, mode = 'constant', constant_values=0)
34
- img_copy = img.copy()
35
- mask = np.zeros((img.shape[0] + 2, img.shape[1] + 2), dtype=np.uint8)
36
-
37
- cv2.floodFill(img, mask, (0, 0), 255)
38
- img_inverse = cv2.bitwise_not(img)
39
- dst = cv2.bitwise_or(img_copy, img_inverse)
40
- return dst
41
-
42
- def refine_mask(mask):
43
- contours, hierarchy = cv2.findContours(mask.astype(np.uint8),
44
- cv2.RETR_CCOMP, cv2.CHAIN_APPROX_TC89_L1)
45
- area = []
46
- for j in range(len(contours)):
47
- a_d = cv2.contourArea(contours[j], True)
48
- area.append(abs(a_d))
49
- refine_mask = np.zeros_like(mask).astype(np.uint8)
50
- if len(area) != 0:
51
- i = area.index(max(area))
52
- cv2.drawContours(refine_mask, contours, i, color=255, thickness=-1)
53
-
54
- return refine_mask
55
-
56
- def get_mask_location(model_type, category, model_parse: Image.Image, keypoint: dict, width=384,height=512):
57
- im_parse = model_parse.resize((width, height), Image.NEAREST)
58
- parse_array = np.array(im_parse)
59
-
60
- if model_type == 'hd':
61
- arm_width = 60
62
- elif model_type == 'dc':
63
- arm_width = 45
64
- else:
65
- raise ValueError("model_type must be \'hd\' or \'dc\'!")
66
-
67
- parse_head = (parse_array == 1).astype(np.float32) + \
68
- (parse_array == 3).astype(np.float32) + \
69
- (parse_array == 11).astype(np.float32)
70
-
71
- parser_mask_fixed = (parse_array == label_map["left_shoe"]).astype(np.float32) + \
72
- (parse_array == label_map["right_shoe"]).astype(np.float32) + \
73
- (parse_array == label_map["hat"]).astype(np.float32) + \
74
- (parse_array == label_map["sunglasses"]).astype(np.float32) + \
75
- (parse_array == label_map["bag"]).astype(np.float32)
76
-
77
- parser_mask_changeable = (parse_array == label_map["background"]).astype(np.float32)
78
-
79
- arms_left = (parse_array == 14).astype(np.float32)
80
- arms_right = (parse_array == 15).astype(np.float32)
81
- arms = arms_left + arms_right
82
-
83
- if category == 'dresses':
84
- parse_mask = (parse_array == 7).astype(np.float32) + \
85
- (parse_array == 4).astype(np.float32) + \
86
- (parse_array == 5).astype(np.float32) + \
87
- (parse_array == 6).astype(np.float32)
88
-
89
- parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
90
-
91
- elif category == 'upper_body':
92
- parse_mask = (parse_array == 4).astype(np.float32) + (parse_array == 7).astype(np.float32)
93
- parser_mask_fixed_lower_cloth = (parse_array == label_map["skirt"]).astype(np.float32) + \
94
- (parse_array == label_map["pants"]).astype(np.float32)
95
- parser_mask_fixed += parser_mask_fixed_lower_cloth
96
- parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
97
- elif category == 'lower_body':
98
- parse_mask = (parse_array == 6).astype(np.float32) + \
99
- (parse_array == 12).astype(np.float32) + \
100
- (parse_array == 13).astype(np.float32) + \
101
- (parse_array == 5).astype(np.float32)
102
- parser_mask_fixed += (parse_array == label_map["upper_clothes"]).astype(np.float32) + \
103
- (parse_array == 14).astype(np.float32) + \
104
- (parse_array == 15).astype(np.float32)
105
- parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
106
- else:
107
- raise NotImplementedError
108
-
109
- # Load pose points
110
- pose_data = keypoint["pose_keypoints_2d"]
111
- pose_data = np.array(pose_data)
112
- pose_data = pose_data.reshape((-1, 2))
113
-
114
- im_arms_left = Image.new('L', (width, height))
115
- im_arms_right = Image.new('L', (width, height))
116
- arms_draw_left = ImageDraw.Draw(im_arms_left)
117
- arms_draw_right = ImageDraw.Draw(im_arms_right)
118
- if category == 'dresses' or category == 'upper_body':
119
- shoulder_right = np.multiply(tuple(pose_data[2][:2]), height / 512.0)
120
- shoulder_left = np.multiply(tuple(pose_data[5][:2]), height / 512.0)
121
- elbow_right = np.multiply(tuple(pose_data[3][:2]), height / 512.0)
122
- elbow_left = np.multiply(tuple(pose_data[6][:2]), height / 512.0)
123
- wrist_right = np.multiply(tuple(pose_data[4][:2]), height / 512.0)
124
- wrist_left = np.multiply(tuple(pose_data[7][:2]), height / 512.0)
125
- ARM_LINE_WIDTH = int(arm_width / 512 * height)
126
- size_left = [shoulder_left[0] - ARM_LINE_WIDTH // 2, shoulder_left[1] - ARM_LINE_WIDTH // 2, shoulder_left[0] + ARM_LINE_WIDTH // 2, shoulder_left[1] + ARM_LINE_WIDTH // 2]
127
- size_right = [shoulder_right[0] - ARM_LINE_WIDTH // 2, shoulder_right[1] - ARM_LINE_WIDTH // 2, shoulder_right[0] + ARM_LINE_WIDTH // 2,
128
- shoulder_right[1] + ARM_LINE_WIDTH // 2]
129
-
130
-
131
- if wrist_right[0] <= 1. and wrist_right[1] <= 1.:
132
- im_arms_right = arms_right
133
- else:
134
- wrist_right = extend_arm_mask(wrist_right, elbow_right, 1.2)
135
- arms_draw_right.line(np.concatenate((shoulder_right, elbow_right, wrist_right)).astype(np.uint16).tolist(), 'white', ARM_LINE_WIDTH, 'curve')
136
- arms_draw_right.arc(size_right, 0, 360, 'white', ARM_LINE_WIDTH // 2)
137
-
138
- if wrist_left[0] <= 1. and wrist_left[1] <= 1.:
139
- im_arms_left = arms_left
140
- else:
141
- wrist_left = extend_arm_mask(wrist_left, elbow_left, 1.2)
142
- arms_draw_left.line(np.concatenate((wrist_left, elbow_left, shoulder_left)).astype(np.uint16).tolist(), 'white', ARM_LINE_WIDTH, 'curve')
143
- arms_draw_left.arc(size_left, 0, 360, 'white', ARM_LINE_WIDTH // 2)
144
-
145
- hands_left = np.logical_and(np.logical_not(im_arms_left), arms_left)
146
- hands_right = np.logical_and(np.logical_not(im_arms_right), arms_right)
147
- parser_mask_fixed += hands_left + hands_right
148
-
149
- parser_mask_fixed = np.logical_or(parser_mask_fixed, parse_head)
150
- parse_mask = cv2.dilate(parse_mask, np.ones((5, 5), np.uint16), iterations=5)
151
- if category == 'dresses' or category == 'upper_body':
152
- neck_mask = (parse_array == 18).astype(np.float32)
153
- neck_mask = cv2.dilate(neck_mask, np.ones((5, 5), np.uint16), iterations=1)
154
- neck_mask = np.logical_and(neck_mask, np.logical_not(parse_head))
155
- parse_mask = np.logical_or(parse_mask, neck_mask)
156
- arm_mask = cv2.dilate(np.logical_or(im_arms_left, im_arms_right).astype('float32'), np.ones((5, 5), np.uint16), iterations=4)
157
- parse_mask += np.logical_or(parse_mask, arm_mask)
158
-
159
- parse_mask = np.logical_and(parser_mask_changeable, np.logical_not(parse_mask))
160
-
161
- parse_mask_total = np.logical_or(parse_mask, parser_mask_fixed)
162
- inpaint_mask = 1 - parse_mask_total
163
- img = np.where(inpaint_mask, 255, 0)
164
- dst = hole_fill(img.astype(np.uint8))
165
- dst = refine_mask(dst)
166
- inpaint_mask = dst / 255 * 1
167
- mask = Image.fromarray(inpaint_mask.astype(np.uint8) * 255)
168
- mask_gray = Image.fromarray(inpaint_mask.astype(np.uint8) * 127)
169
-
170
- return mask, mask_gray
 
1
+ import pdb
2
+
3
+ import numpy as np
4
+ import cv2
5
+ from PIL import Image, ImageDraw
6
+
7
+ label_map = {
8
+ "background": 0,
9
+ "hat": 1,
10
+ "hair": 2,
11
+ "sunglasses": 3,
12
+ "upper_clothes": 4,
13
+ "skirt": 5,
14
+ "pants": 6,
15
+ "dress": 7,
16
+ "belt": 8,
17
+ "left_shoe": 9,
18
+ "right_shoe": 10,
19
+ "head": 11,
20
+ "left_leg": 12,
21
+ "right_leg": 13,
22
+ "left_arm": 14,
23
+ "right_arm": 15,
24
+ "bag": 16,
25
+ "scarf": 17,
26
+ }
27
+
28
+ def extend_arm_mask(wrist, elbow, scale):
29
+ wrist = elbow + scale * (wrist - elbow)
30
+ return wrist
31
+
32
+ def hole_fill(img):
33
+ img = np.pad(img[1:-1, 1:-1], pad_width = 1, mode = 'constant', constant_values=0)
34
+ img_copy = img.copy()
35
+ mask = np.zeros((img.shape[0] + 2, img.shape[1] + 2), dtype=np.uint8)
36
+
37
+ cv2.floodFill(img, mask, (0, 0), 255)
38
+ img_inverse = cv2.bitwise_not(img)
39
+ dst = cv2.bitwise_or(img_copy, img_inverse)
40
+ return dst
41
+
42
+ def refine_mask(mask):
43
+ contours, hierarchy = cv2.findContours(mask.astype(np.uint8),
44
+ cv2.RETR_CCOMP, cv2.CHAIN_APPROX_TC89_L1)
45
+ area = []
46
+ for j in range(len(contours)):
47
+ a_d = cv2.contourArea(contours[j], True)
48
+ area.append(abs(a_d))
49
+ refine_mask = np.zeros_like(mask).astype(np.uint8)
50
+ if len(area) != 0:
51
+ i = area.index(max(area))
52
+ cv2.drawContours(refine_mask, contours, i, color=255, thickness=-1)
53
+
54
+ return refine_mask
55
+
56
+ def get_mask_location(model_type, category, model_parse: Image.Image, keypoint: dict, width=384,height=512):
57
+ im_parse = model_parse.resize((width, height), Image.NEAREST)
58
+ parse_array = np.array(im_parse)
59
+
60
+ if model_type == 'hd':
61
+ arm_width = 60
62
+ elif model_type == 'dc':
63
+ arm_width = 45
64
+ else:
65
+ raise ValueError("model_type must be \'hd\' or \'dc\'!")
66
+
67
+ parse_head = (parse_array == 1).astype(np.float32) + \
68
+ (parse_array == 3).astype(np.float32) + \
69
+ (parse_array == 11).astype(np.float32)
70
+
71
+ parser_mask_fixed = (parse_array == label_map["left_shoe"]).astype(np.float32) + \
72
+ (parse_array == label_map["right_shoe"]).astype(np.float32) + \
73
+ (parse_array == label_map["hat"]).astype(np.float32) + \
74
+ (parse_array == label_map["sunglasses"]).astype(np.float32) + \
75
+ (parse_array == label_map["bag"]).astype(np.float32)
76
+
77
+ parser_mask_changeable = (parse_array == label_map["background"]).astype(np.float32)
78
+
79
+ arms_left = (parse_array == 14).astype(np.float32)
80
+ arms_right = (parse_array == 15).astype(np.float32)
81
+ arms = arms_left + arms_right
82
+
83
+ if category == 'dresses':
84
+ parse_mask = (parse_array == 7).astype(np.float32) + \
85
+ (parse_array == 4).astype(np.float32) + \
86
+ (parse_array == 5).astype(np.float32) + \
87
+ (parse_array == 6).astype(np.float32)
88
+
89
+ parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
90
+
91
+ elif category == 'upper_body':
92
+ parse_mask = (parse_array == 4).astype(np.float32) + (parse_array == 7).astype(np.float32)
93
+ parser_mask_fixed_lower_cloth = (parse_array == label_map["skirt"]).astype(np.float32) + \
94
+ (parse_array == label_map["pants"]).astype(np.float32)
95
+ parser_mask_fixed += parser_mask_fixed_lower_cloth
96
+ parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
97
+ elif category == 'lower_body':
98
+ parse_mask = (parse_array == 6).astype(np.float32) + \
99
+ (parse_array == 12).astype(np.float32) + \
100
+ (parse_array == 13).astype(np.float32) + \
101
+ (parse_array == 5).astype(np.float32)
102
+ parser_mask_fixed += (parse_array == label_map["upper_clothes"]).astype(np.float32) + \
103
+ (parse_array == 14).astype(np.float32) + \
104
+ (parse_array == 15).astype(np.float32)
105
+ parser_mask_changeable += np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
106
+ else:
107
+ raise NotImplementedError
108
+
109
+ # Load pose points
110
+ pose_data = keypoint["pose_keypoints_2d"]
111
+ pose_data = np.array(pose_data)
112
+ pose_data = pose_data.reshape((-1, 2))
113
+
114
+ im_arms_left = Image.new('L', (width, height))
115
+ im_arms_right = Image.new('L', (width, height))
116
+ arms_draw_left = ImageDraw.Draw(im_arms_left)
117
+ arms_draw_right = ImageDraw.Draw(im_arms_right)
118
+ if category == 'dresses' or category == 'upper_body':
119
+ shoulder_right = np.multiply(tuple(pose_data[2][:2]), height / 512.0)
120
+ shoulder_left = np.multiply(tuple(pose_data[5][:2]), height / 512.0)
121
+ elbow_right = np.multiply(tuple(pose_data[3][:2]), height / 512.0)
122
+ elbow_left = np.multiply(tuple(pose_data[6][:2]), height / 512.0)
123
+ wrist_right = np.multiply(tuple(pose_data[4][:2]), height / 512.0)
124
+ wrist_left = np.multiply(tuple(pose_data[7][:2]), height / 512.0)
125
+ ARM_LINE_WIDTH = int(arm_width / 512 * height)
126
+ size_left = [shoulder_left[0] - ARM_LINE_WIDTH // 2, shoulder_left[1] - ARM_LINE_WIDTH // 2, shoulder_left[0] + ARM_LINE_WIDTH // 2, shoulder_left[1] + ARM_LINE_WIDTH // 2]
127
+ size_right = [shoulder_right[0] - ARM_LINE_WIDTH // 2, shoulder_right[1] - ARM_LINE_WIDTH // 2, shoulder_right[0] + ARM_LINE_WIDTH // 2,
128
+ shoulder_right[1] + ARM_LINE_WIDTH // 2]
129
+
130
+
131
+ if wrist_right[0] <= 1. and wrist_right[1] <= 1.:
132
+ im_arms_right = arms_right
133
+ else:
134
+ wrist_right = extend_arm_mask(wrist_right, elbow_right, 1.2)
135
+ arms_draw_right.line(np.concatenate((shoulder_right, elbow_right, wrist_right)).astype(np.uint16).tolist(), 'white', ARM_LINE_WIDTH, 'curve')
136
+ arms_draw_right.arc(size_right, 0, 360, 'white', ARM_LINE_WIDTH // 2)
137
+
138
+ if wrist_left[0] <= 1. and wrist_left[1] <= 1.:
139
+ im_arms_left = arms_left
140
+ else:
141
+ wrist_left = extend_arm_mask(wrist_left, elbow_left, 1.2)
142
+ arms_draw_left.line(np.concatenate((wrist_left, elbow_left, shoulder_left)).astype(np.uint16).tolist(), 'white', ARM_LINE_WIDTH, 'curve')
143
+ arms_draw_left.arc(size_left, 0, 360, 'white', ARM_LINE_WIDTH // 2)
144
+
145
+ hands_left = np.logical_and(np.logical_not(im_arms_left), arms_left)
146
+ hands_right = np.logical_and(np.logical_not(im_arms_right), arms_right)
147
+ parser_mask_fixed += hands_left + hands_right
148
+
149
+ parser_mask_fixed = np.logical_or(parser_mask_fixed, parse_head)
150
+ parse_mask = cv2.dilate(parse_mask, np.ones((5, 5), np.uint16), iterations=5)
151
+ if category == 'dresses' or category == 'upper_body':
152
+ neck_mask = (parse_array == 18).astype(np.float32)
153
+ neck_mask = cv2.dilate(neck_mask, np.ones((5, 5), np.uint16), iterations=1)
154
+ neck_mask = np.logical_and(neck_mask, np.logical_not(parse_head))
155
+ parse_mask = np.logical_or(parse_mask, neck_mask)
156
+ arm_mask = cv2.dilate(np.logical_or(im_arms_left, im_arms_right).astype('float32'), np.ones((5, 5), np.uint16), iterations=4)
157
+ parse_mask += np.logical_or(parse_mask, arm_mask)
158
+
159
+ parse_mask = np.logical_and(parser_mask_changeable, np.logical_not(parse_mask))
160
+
161
+ parse_mask_total = np.logical_or(parse_mask, parser_mask_fixed)
162
+ inpaint_mask = 1 - parse_mask_total
163
+ img = np.where(inpaint_mask, 255, 0)
164
+ dst = hole_fill(img.astype(np.uint8))
165
+ dst = refine_mask(dst)
166
+ inpaint_mask = dst / 255 * 1
167
+ mask = Image.fromarray(inpaint_mask.astype(np.uint8) * 255)
168
+ mask_gray = Image.fromarray(inpaint_mask.astype(np.uint8) * 127)
169
+
170
+ return mask, mask_gray