m3face committed
Commit 332190f · 1 Parent(s): 92a5ecd

Adding files

README.md ADDED
@@ -0,0 +1,73 @@
---
tags:
- text-to-image
- controlnet
---

# M<sup>3</sup>Face Model Card
We introduce M<sup>3</sup>Face, a unified multi-modal, multilingual framework for controllable face generation and editing. The framework lets users work from text alone: it automatically generates the controlling modalities (a semantic segmentation mask or facial landmarks) from the text and then generates the face image.

## Getting Started

### Installation
1. Clone our repository:

```bash
git clone https://huggingface.co/m3face/m3face
cd m3face
```

2. Install dependencies:

```bash
pip install -r requirements.txt
```

### Resources
- Face generation requires 10 GB+ of VRAM for 512x512 images.
- Face editing requires 14 GB+ of VRAM for 512x512 images.

### Pre-trained Models
The ControlNet checkpoints are available at [`m3face/ControlnetModels`](https://huggingface.co/m3face/ControlnetModels), and the mask/landmark generator checkpoints at [`m3face/FaceConditioning`](https://huggingface.co/m3face/FaceConditioning).

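For reference, here is a minimal sketch of loading the English landmark ControlNet with `diffusers`. The revision name and base model mirror `generate.py`; the prompt and the bundled `data/landmarks/1.jpg` condition image are only illustrative.

```python
# Minimal sketch, not the full generate.py pipeline. Assumes a CUDA GPU and the
# "landmark-english" revision that generate.py uses for the English checkpoints.
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, UniPCMultistepScheduler
from diffusers.utils import load_image

controlnet = ControlNetModel.from_pretrained("m3face/ControlnetModels", revision="landmark-english")
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, safety_checker=None
).to("cuda")
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

condition = load_image("data/landmarks/1.jpg").resize((512, 512))  # landmark image shipped with this repo
image = pipe(
    "This attractive woman has narrow eyes, rosy cheeks, and wears heavy makeup.",
    image=condition,
    num_inference_steps=30,
).images[0]
image.save("face.png")
```
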
### M<sup>3</sup>CelebA Dataset
The M<sup>3</sup>CelebA dataset is available at [`m3face/M3CelebA`](https://huggingface.co/m3face/M3CelebA), where you can browse or download it.

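If the dataset repository is compatible with the `datasets` library loader (an assumption not verified here), it can also be pulled programmatically:

```python
from datasets import load_dataset

# Hypothetical usage: assumes m3face/M3CelebA exposes a configuration that the
# datasets library can load directly; check the dataset card for the exact splits.
dataset = load_dataset("m3face/M3CelebA", split="train")
print(dataset)
```
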
## Face Generation
You can generate faces from text, a segmentation mask, facial landmarks, or any combination of them by running:
```bash
python generate.py --seed 1111 \
    --condition "landmark" \
    --prompt "This attractive woman has narrow eyes, rosy cheeks, and wears heavy makeup." \
    --save_condition
```
The `--condition` argument selects the conditioning modality. By default, the conditioning image is generated by our framework and is saved when `--save_condition` is given. Alternatively, you can supply your own condition image with the `--condition_path` argument, as in the example below.

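For example, to condition on one of the landmark images bundled with this repo instead of generating one (an illustrative invocation; the seed and prompt are arbitrary):

```bash
python generate.py --seed 1111 \
    --condition "landmark" \
    --condition_path "data/landmarks/1.jpg" \
    --prompt "This attractive woman has narrow eyes, rosy cheeks, and wears heavy makeup."
```
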
## Face Editing
For face editing, run:
```bash
python edit.py --enable_xformers_memory_efficient_attention \
    --seed 1111 \
    --condition "landmark" \
    --prompt "She is smiling." \
    --image_path "/path/to/image" \
    --condition_path "/path/to/condition" \
    --edit_condition \
    --embedding_optimize_it 500 \
    --model_finetune_it 1000 \
    --alpha 0.7 1 1.1 \
    --num_inference_steps 30 \
    --unet_layer "2and3"
```
You need to specify the input image and its original conditioning modality. The face can then be edited either with a target conditioning modality you provide (`--edit_condition_path`) or by letting our framework edit the original conditioning modality (`--edit_condition`).
The `--unet_layer` argument selects which Stable Diffusion UNet up-blocks to fine-tune (`only1`, `only2`, `only3`, `1and2`, `2and3`, or `all`).

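The `--alpha` values control the trade-off between staying faithful to the input image and applying the text edit. After `edit.py` optimizes a text embedding to reconstruct the input image (`--embedding_optimize_it` steps) and fine-tunes the selected UNet layers (`--model_finetune_it` steps), it sweeps every `--alpha` and `--num_inference_steps` combination and linearly interpolates between the original prompt embedding and the optimized one. A small self-contained sketch of that interpolation (tensor shapes are placeholders; the blending formula follows `edit.py`):

```python
import torch

def blend_prompt_embeddings(orig_emb: torch.Tensor, optimized_emb: torch.Tensor, alpha: float) -> torch.Tensor:
    # Interpolation from the final stage of edit.py: alpha weighs the original
    # prompt embedding against the embedding optimized for the input image.
    return alpha * orig_emb + (1 - alpha) * optimized_emb

# Placeholder tensors just to make the sketch runnable; in edit.py these come
# from the frozen text encoder and the embedding-optimization stage.
orig_emb = torch.randn(1, 77, 768)
optimized_emb = torch.randn(1, 77, 768)
for alpha in (0.7, 1.0, 1.1):
    new_emb = blend_prompt_embeddings(orig_emb, optimized_emb, alpha)  # passed to the pipeline as prompt_embeds
```
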
> Note: If you don't have the original conditioning modality, you can generate it with the `plot_mask.py` and `plot_landmark.py` scripts (`plot_mask.py` also relies on a DML-CSR checkpoint; see its `--checkpoint_path` argument):
```bash
pip install git+https://github.com/mapillary/inplace_abn
python utils/plot_mask.py --image_path "/path/to/image"
python utils/plot_landmark.py --image_path "/path/to/image"
```

## Training
The code and instructions for training our models will be posted soon!
data/landmarks.pickle ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3e129223b20f017a389b04ffe65ac0fd047f03a2bd9ef5bcb9eb0358b2b50a85
3
+ size 688
data/landmarks/1.jpg ADDED
data/landmarks/2.jpg ADDED
data/landmarks/3.jpg ADDED
data/landmarks/4.jpg ADDED
data/masks/1.png ADDED
data/masks/2.png ADDED
data/masks/3.png ADDED
data/masks/4.png ADDED
docs/pull-figure.png ADDED
edit.py ADDED
@@ -0,0 +1,493 @@
1
+ import os
2
+ import argparse
3
+ from tqdm.auto import tqdm
4
+ from packaging import version
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import torch.utils.checkpoint
9
+ from torchvision import transforms
10
+ from diffusers import (
11
+ AutoencoderKL,
12
+ ControlNetModel,
13
+ DDPMScheduler,
14
+ StableDiffusionControlNetPipeline,
15
+ UNet2DConditionModel,
16
+ UniPCMultistepScheduler,
17
+ PNDMScheduler,
18
+ AmusedInpaintPipeline, AmusedScheduler, VQModel, UVit2DModel
19
+
20
+ )
21
+ from diffusers.utils.import_utils import is_xformers_available
22
+ from diffusers.utils import load_image
23
+ from transformers import AutoTokenizer, CLIPFeatureExtractor, PretrainedConfig
24
+ from PIL import Image
25
+ from utils.mclip import *
26
+
27
+
28
+ def parse_args():
29
+ parser = argparse.ArgumentParser(description="Edit images with M3Face.")
30
+ parser.add_argument(
31
+ "--prompt",
32
+ type=str,
33
+ default="This attractive woman has narrow eyes, rosy cheeks, and wears heavy makeup.",
34
+ help="The input text prompt for image generation."
35
+ )
36
+ parser.add_argument(
37
+ "--condition",
38
+ type=str,
39
+ default="mask",
40
+ choices=["mask", "landmark"],
41
+ help="Use segmentation mask or facial landmarks for image generation."
42
+ )
43
+ parser.add_argument(
44
+ "--image_path",
45
+ type=str,
46
+ default=None,
47
+ help="Path to the input image."
48
+ )
49
+ parser.add_argument(
50
+ "--condition_path",
51
+ type=str,
52
+ default=None,
53
+ help="Path to the original mask/landmark image."
54
+ )
55
+ parser.add_argument(
56
+ "--edit_condition_path",
57
+ type=str,
58
+ default=None,
59
+ help="Path to the target mask/landmark image."
60
+ )
61
+ parser.add_argument(
62
+ "--output_dir",
63
+ type=str,
64
+ default='output/',
65
+ help="The output directory where the results will be written.",
66
+ )
67
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible generation.")
68
+ parser.add_argument(
69
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
70
+ )
71
+ parser.add_argument("--edit_condition", action="store_true")
72
+ parser.add_argument("--load_unet_from_local", action="store_true")
73
+ parser.add_argument("--save_unet", action="store_true")
74
+ parser.add_argument("--unet_local_path", type=str, default=None)
75
+ parser.add_argument("--load_finetune_from_local", action="store_true")
76
+ parser.add_argument("--finetune_path", type=str, default=None)
77
+ parser.add_argument("--use_english", action="store_true", help="Use the English models.")
78
+ parser.add_argument("--embedding_optimize_it", type=int, default=500)
79
+ parser.add_argument("--model_finetune_it", type=int, default=1000)
80
+ parser.add_argument("--alpha", nargs="+", type=float, default=[0.8, 0.9, 1, 1.1])
81
+ parser.add_argument("--num_inference_steps", nargs="+", type=int, default=[20, 40, 50])
82
+ parser.add_argument("--unet_layer", type=str, default="2and3",
83
+ help="Which UNet layers in the SD to finetune.")
84
+
85
+ args = parser.parse_args()
86
+
87
+ return args
88
+
89
+ def get_muse(args):
90
+ muse_model_name = 'm3face/FaceConditioning'
91
+ if args.condition == 'mask':
92
+ muse_revision = 'segmentation'
93
+ elif args.condition == 'landmark':
94
+ muse_revision = 'landmark'
95
+ scheduler = AmusedScheduler.from_pretrained(muse_model_name, revision=muse_revision, subfolder='scheduler')
96
+ vqvae = VQModel.from_pretrained(muse_model_name, revision=muse_revision, subfolder='vqvae')
97
+ uvit2 = UVit2DModel.from_pretrained(muse_model_name, revision=muse_revision, subfolder='transformer')
98
+ text_encoder = MultilingualCLIP.from_pretrained(muse_model_name, revision=muse_revision, subfolder='text_encoder')
99
+ tokenizer = AutoTokenizer.from_pretrained(muse_model_name, revision=muse_revision, subfolder='tokenizer')
100
+
101
+ pipeline = AmusedInpaintPipeline(
102
+ vqvae=vqvae,
103
+ tokenizer=tokenizer,
104
+ text_encoder=text_encoder,
105
+ transformer=uvit2,
106
+ scheduler=scheduler
107
+ ).to("cuda")
108
+
109
+ return pipeline
110
+
111
+ def import_model_class_from_model_name(sd_model_name):
112
+ text_encoder_config = PretrainedConfig.from_pretrained(
113
+ sd_model_name,
114
+ subfolder="text_encoder",
115
+ )
116
+ model_class = text_encoder_config.architectures[0]
117
+
118
+ if model_class == "CLIPTextModel":
119
+ from transformers import CLIPTextModel
120
+
121
+ return CLIPTextModel
122
+ elif model_class == "RobertaSeriesModelWithTransformation":
123
+ from diffusers.pipelines.deprecated.alt_diffusion import RobertaSeriesModelWithTransformation
124
+
125
+ return RobertaSeriesModelWithTransformation
126
+ else:
127
+ raise ValueError(f"{model_class} is not supported.")
128
+
129
+ def preprocess(image, condition, prompt, tokenizer):
130
+ image_transforms = transforms.Compose(
131
+ [
132
+ transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
133
+ transforms.CenterCrop(512),
134
+ transforms.ToTensor(),
135
+ transforms.Normalize([0.5], [0.5]),
136
+ ]
137
+ )
138
+ condition_transforms = transforms.Compose(
139
+ [
140
+ transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
141
+ transforms.CenterCrop(512),
142
+ transforms.ToTensor(),
143
+ ]
144
+ )
145
+ image = image_transforms(image)
146
+ condition = condition_transforms(condition)
147
+ inputs = tokenizer(
148
+ [prompt], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
149
+ )
150
+
151
+ return image, condition, inputs.input_ids, inputs.attention_mask
152
+
153
+ def main(args):
154
+ if args.use_english:
155
+ sd_model_name = 'runwayml/stable-diffusion-v1-5'
156
+ controlnet_model_name = 'm3face/ControlnetModels'
157
+ if args.condition == 'mask':
158
+ controlnet_revision = 'segmentation-english'
159
+ elif args.condition == 'landmark':
160
+ controlnet_revision = 'landmark-english'
161
+ else:
162
+ sd_model_name = 'BAAI/AltDiffusion-m18'
163
+ controlnet_model_name = 'm3face/ControlnetModels'
164
+ if args.condition == 'mask':
165
+ controlnet_revision = 'segmentation-mlin'
166
+ elif args.condition == 'landmark':
167
+ controlnet_revision = 'landmark-mlin'
168
+
169
+ # ========== set up models ==========
170
+ vae = AutoencoderKL.from_pretrained(sd_model_name, subfolder="vae")
171
+ tokenizer = AutoTokenizer.from_pretrained(sd_model_name, subfolder="tokenizer", use_fast=False)
172
+ text_encoder_cls = import_model_class_from_model_name(sd_model_name)
173
+ text_encoder = text_encoder_cls.from_pretrained(sd_model_name, subfolder="text_encoder")
174
+ noise_scheduler = DDPMScheduler.from_pretrained(sd_model_name, subfolder="scheduler")
175
+
176
+ if args.load_unet_from_local:
177
+ unet = UNet2DConditionModel.from_pretrained(args.unet_local_path)
178
+ else:
179
+ unet = UNet2DConditionModel.from_pretrained(sd_model_name, subfolder="unet")
180
+
181
+ controlnet = ControlNetModel.from_pretrained(controlnet_model_name, revision=controlnet_revision)
182
+
183
+ if args.edit_condition:
184
+ muse = get_muse(args)
185
+
186
+ vae.requires_grad_(False)
187
+ text_encoder.requires_grad_(False)
188
+ controlnet.requires_grad_(False)
189
+ unet.requires_grad_(False)
190
+ vae.eval()
191
+ text_encoder.eval()
192
+ controlnet.eval()
193
+ unet.eval()
194
+
195
+ if args.enable_xformers_memory_efficient_attention:
196
+ if is_xformers_available():
197
+ import xformers
198
+
199
+ xformers_version = version.parse(xformers.__version__)
200
+ if xformers_version == version.parse("0.0.16"):
201
+ print(
202
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
203
+ )
204
+ unet.enable_xformers_memory_efficient_attention()
205
+ controlnet.enable_xformers_memory_efficient_attention()
206
+ else:
207
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
208
+
209
+ # ========== select params to optimize ==========
210
+ params = []
211
+ for name, param in unet.named_parameters():
212
+ if(name.startswith('up_blocks')):
213
+ params.append(param)
214
+
215
+ if args.unet_layer == 'only1': # 116 layers
216
+ params_to_optimize = [
217
+ {'params': params[38:154]},
218
+ ]
219
+ elif args.unet_layer == 'only2': # 116 layers
220
+ params_to_optimize = [
221
+ {'params': params[154:270]},
222
+ ]
223
+ elif args.unet_layer == 'only3': # 114 layers
224
+ params_to_optimize = [
225
+ {'params': params[270:]},
226
+ ]
227
+ elif args.unet_layer == '1and2': # 232 layers
228
+ params_to_optimize = [
229
+ {'params': params[38:270]},
230
+ ]
231
+ elif args.unet_layer == '2and3': # 230 layers
232
+ params_to_optimize = [
233
+ {'params': params[154:]},
234
+ ]
235
+ elif args.unet_layer == 'all': # all layers
236
+ params_to_optimize = [
237
+ {'params': params},
238
+ ]
239
+
240
+ image = Image.open(args.image_path).convert('RGB')
241
+ condition = Image.open(args.condition_path).convert('RGB')
242
+ image, condition, input_ids, attention_mask = preprocess(image, condition, args.prompt, tokenizer)
243
+
244
+ # Move to device
245
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
246
+ vae.to(device, dtype=torch.float32)
247
+ unet.to(device, dtype=torch.float32)
248
+ text_encoder.to(device, dtype=torch.float32)
249
+ controlnet.to(device)
250
+ image = image.to(device).unsqueeze(0)
251
+ condition = condition.to(device).unsqueeze(0)
252
+ input_ids = input_ids.to(device)
253
+ attention_mask = attention_mask.to(device)
254
+
255
+ # ========== imagic ==========
256
+ if args.load_finetune_from_local:
257
+ print('Loading embeddings from local ...')
258
+ orig_emb = torch.load(os.path.join(args.finetune_path, 'orig_emb.pt'))
259
+ emb = torch.load(os.path.join(args.finetune_path, 'emb.pt'))
260
+ else:
261
+ init_latent = vae.encode(image.to(dtype=torch.float32)).latent_dist.sample()
262
+ init_latent = init_latent * vae.config.scaling_factor
263
+
264
+ if not args.use_english:
265
+ orig_emb = text_encoder(input_ids, attention_mask=attention_mask)[0]
266
+ else:
267
+ orig_emb = text_encoder(input_ids)[0]
268
+ emb = orig_emb.clone()
269
+ torch.save(orig_emb, os.path.join(args.output_dir, 'orig_emb.pt'))
270
+ torch.save(emb, os.path.join(args.output_dir, 'emb.pt'))
271
+
272
+ # 1. Optimize the embedding
273
+ print('1. Optimize the embedding')
274
+ unet.eval()
275
+ emb.requires_grad = True
276
+ lr = 0.001
277
+ it = args.embedding_optimize_it # 500
278
+ opt = torch.optim.Adam([emb], lr=lr)
279
+ history = []
280
+
281
+ pbar = tqdm(
282
+ range(it),
283
+ initial=0,
284
+ desc="Optimize Steps",
285
+ )
286
+ global_step = 0
287
+
288
+ for i in pbar:
289
+ opt.zero_grad()
290
+
291
+ noise = torch.randn_like(init_latent)
292
+ bsz = init_latent.shape[0]
293
+ t_enc = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=init_latent.device)
294
+ t_enc = t_enc.long()
295
+ z = noise_scheduler.add_noise(init_latent, noise, t_enc)
296
+
297
+ controlnet_image = condition.to(dtype=torch.float32)
298
+
299
+ down_block_res_samples, mid_block_res_sample = controlnet(
300
+ z,
301
+ t_enc,
302
+ encoder_hidden_states=emb,
303
+ controlnet_cond=controlnet_image,
304
+ return_dict=False,
305
+ )
306
+
307
+ # Predict the noise residual
308
+ pred_noise = unet(
309
+ z,
310
+ t_enc,
311
+ encoder_hidden_states=emb,
312
+ down_block_additional_residuals=[
313
+ sample.to(dtype=torch.float32) for sample in down_block_res_samples
314
+ ],
315
+ mid_block_additional_residual=mid_block_res_sample.to(dtype=torch.float32),
316
+ ).sample
317
+
318
+ # Get the target for loss depending on the prediction type
319
+ if noise_scheduler.config.prediction_type == "epsilon":
320
+ target = noise
321
+ elif noise_scheduler.config.prediction_type == "v_prediction":
322
+ target = noise_scheduler.get_velocity(init_latent, noise, t_enc)
323
+ else:
324
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
325
+ loss = F.mse_loss(pred_noise.float(), target.float(), reduction="mean")
326
+
327
+ loss.backward()
328
+ global_step += 1
329
+ pbar.set_postfix({"loss": loss.item()})
330
+ history.append(loss.item())
331
+ opt.step()
332
+ opt.zero_grad()
333
+
334
+ # 2. Finetune the model
335
+ print('2. Finetune the model')
336
+ emb.requires_grad = False
337
+ unet.requires_grad_(True)
338
+ unet.train()
339
+
340
+ lr = 5e-5
341
+ it = args.model_finetune_it # 1000
342
+ opt = torch.optim.Adam(params_to_optimize, lr=lr)
343
+ history = []
344
+
345
+ pbar = tqdm(
346
+ range(it),
347
+ initial=0,
348
+ desc="Finetune Steps",
349
+ )
350
+ global_step = 0
351
+ for i in pbar:
352
+ opt.zero_grad()
353
+
354
+ noise = torch.randn_like(init_latent)
355
+ bsz = init_latent.shape[0]
356
+ t_enc = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=init_latent.device)
357
+ t_enc = t_enc.long()
358
+ z = noise_scheduler.add_noise(init_latent, noise, t_enc)
359
+
360
+ controlnet_image = condition.to(dtype=torch.float32)
361
+
362
+ down_block_res_samples, mid_block_res_sample = controlnet(
363
+ z,
364
+ t_enc,
365
+ encoder_hidden_states=emb,
366
+ controlnet_cond=controlnet_image,
367
+ return_dict=False,
368
+ )
369
+
370
+ # Predict the noise residual
371
+ pred_noise = unet(
372
+ z,
373
+ t_enc,
374
+ encoder_hidden_states=emb,
375
+ down_block_additional_residuals=[
376
+ sample.to(dtype=torch.float32) for sample in down_block_res_samples
377
+ ],
378
+ mid_block_additional_residual=mid_block_res_sample.to(dtype=torch.float32),
379
+ ).sample
380
+
381
+ # Get the target for loss depending on the prediction type
382
+ if noise_scheduler.config.prediction_type == "epsilon":
383
+ target = noise
384
+ elif noise_scheduler.config.prediction_type == "v_prediction":
385
+ target = noise_scheduler.get_velocity(init_latent, noise, t_enc)
386
+ else:
387
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
388
+ loss = F.mse_loss(pred_noise.float(), target.float(), reduction="mean")
389
+
390
+ loss.backward()
391
+ global_step += 1
392
+ pbar.set_postfix({"loss": loss.item()})
393
+ history.append(loss.item())
394
+ opt.step()
395
+ opt.zero_grad()
396
+
397
+ # 3. Generate Images
398
+ print("3. Generating images... ")
399
+
400
+ unet.eval()
401
+ controlnet.eval()
402
+
403
+ if args.edit_condition_path is None:
404
+ edit_condition = load_image(args.condition_path)
405
+ else:
406
+ edit_condition = load_image(args.edit_condition_path)
407
+ if args.edit_condition:
408
+ edit_mask = Image.new("L", (256, 256), 0)
409
+ for i in range(256):
410
+ for j in range(256):
411
+ if 40 < i < 220 and 20 < j < 256:
412
+ edit_mask.putpixel((i, j), 255)
413
+
414
+ if args.condition == 'mask':
415
+ condition = 'segmentation'
416
+ elif args.condition == 'landmark':
417
+ condition = 'landmark'
418
+ edit_prompt = f"Generate face {condition} | " + args.prompt
419
+ input_image = edit_condition.resize((256, 256)).convert("RGB")
420
+ edit_condition = muse(edit_prompt, input_image, edit_mask, num_inference_steps=30).images[0].resize((512, 512))
421
+ edit_condition.save(f'{args.output_dir}/edited_condition.png')
422
+
423
+ # remove muse and empty cache
424
+ del muse
425
+ torch.cuda.empty_cache()
426
+
427
+ if sd_model_name.startswith('BAAI'):
428
+ scheduler = PNDMScheduler.from_pretrained(
429
+ sd_model_name,
430
+ subfolder='scheduler',
431
+ )
432
+ scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
433
+ feature_extractor = CLIPFeatureExtractor.from_pretrained(
434
+ sd_model_name,
435
+ subfolder='feature_extractor',
436
+ )
437
+ pipeline = StableDiffusionControlNetPipeline(
438
+ vae=vae,
439
+ text_encoder=text_encoder,
440
+ tokenizer=tokenizer,
441
+ unet=unet,
442
+ controlnet=controlnet,
443
+ scheduler=scheduler,
444
+ safety_checker=None,
445
+ feature_extractor=feature_extractor
446
+ )
447
+ else:
448
+ pipeline = StableDiffusionControlNetPipeline.from_pretrained(
449
+ sd_model_name,
450
+ vae=vae,
451
+ text_encoder=text_encoder,
452
+ tokenizer=tokenizer,
453
+ unet=unet,
454
+ controlnet=controlnet,
455
+ safety_checker=None,
456
+ )
457
+ pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
458
+ pipeline = pipeline.to(device)
459
+ pipeline.set_progress_bar_config(disable=True)
460
+
461
+ if args.enable_xformers_memory_efficient_attention:
462
+ pipeline.enable_xformers_memory_efficient_attention()
463
+
464
+ if args.seed is None:
465
+ generator = None
466
+ else:
467
+ generator = torch.Generator(device=device).manual_seed(args.seed)
468
+
469
+ with torch.autocast("cuda"):
470
+ image = pipeline(
471
+ image=edit_condition, prompt_embeds=emb, num_inference_steps=20, generator=generator
472
+ ).images[0]
473
+ image.save(f'{args.output_dir}/reconstruct.png')
474
+
475
+ # Interpolate the embedding
476
+ for num_inference_steps in args.num_inference_steps:
477
+ for alpha in args.alpha:
478
+ new_emb = alpha * orig_emb + (1 - alpha) * emb
479
+
480
+ with torch.autocast("cuda"):
481
+ image = pipeline(
482
+ image=edit_condition, prompt_embeds=new_emb, num_inference_steps=num_inference_steps, generator=generator
483
+ ).images[0]
484
+ image.save(f'{args.output_dir}/image_{num_inference_steps}_{alpha}.png')
485
+
486
+ if args.save_unet:
487
+ print('Saving the unet model...')
488
+ unet.save_pretrained(f'{args.output_dir}/unet')
489
+
490
+
491
+ if __name__ == '__main__':
492
+ args = parse_args()
493
+ main(args)
generate.py ADDED
@@ -0,0 +1,176 @@
1
+ import argparse, os, time
2
+ import torch
3
+ from diffusers import (
4
+ AutoencoderKL,
5
+ ControlNetModel,
6
+ StableDiffusionControlNetPipeline,
7
+ UNet2DConditionModel,
8
+ UniPCMultistepScheduler,
9
+ PNDMScheduler,
10
+ AmusedPipeline, AmusedScheduler, VQModel, UVit2DModel
11
+ )
12
+ from transformers import AutoTokenizer, CLIPFeatureExtractor
13
+ from diffusers.pipelines.deprecated.alt_diffusion import RobertaSeriesModelWithTransformation
14
+ from diffusers.utils import load_image
15
+ from utils.mclip import *
16
+
17
+
18
+ def parse_args():
19
+ parser = argparse.ArgumentParser(description="Generate images with M3Face.")
20
+ parser.add_argument(
21
+ "--prompt",
22
+ type=str,
23
+ default="This attractive woman has narrow eyes, rosy cheeks, and wears heavy makeup.",
24
+ help="The input text prompt for image generation."
25
+ )
26
+ parser.add_argument(
27
+ "--condition",
28
+ type=str,
29
+ default="mask",
30
+ choices=["mask", "landmark"],
31
+ help="Use segmentation mask or facial landmarks for image generation."
32
+ )
33
+ parser.add_argument(
34
+ "--condition_path",
35
+ type=str,
36
+ default=None,
37
+ help="Path to the condition mask/landmark image. We will generate the condition if it is not given."
38
+ )
39
+ parser.add_argument("--save_condition", action="store_true", help="Save the generated condition image.")
40
+ parser.add_argument("--use_english", action="store_true", help="Use the English models.")
41
+ parser.add_argument("--enhance_prompt", action="store_true", help="Enhance the given text prompt.")
42
+ parser.add_argument("--num_inference_steps", type=int, default=30)
43
+ parser.add_argument("--num_samples", type=int, default=1)
44
+ parser.add_argument(
45
+ "--additional_prompt",
46
+ type=str,
47
+ default="rim lighting, dslr, ultra quality, sharp focus, dof, Fujifilm XT3, crystal clear, highly detailed glossy eyes, high detailed skin, skin pores, 8K UHD"
48
+ )
49
+ parser.add_argument(
50
+ "--negative_prompt",
51
+ type=str,
52
+ default="low quality, bad quality, worst quality, blurry, disfigured, ugly, immature, cartoon, painting"
53
+ )
54
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible generation.")
55
+ parser.add_argument(
56
+ "--output_dir",
57
+ type=str,
58
+ default="output/",
59
+ help="The output directory where the results will be written.",
60
+ )
61
+ args = parser.parse_args()
62
+
63
+ return args
64
+
65
+ def get_controlnet(args):
66
+ if args.use_english:
67
+ sd_model_name = 'runwayml/stable-diffusion-v1-5'
68
+ controlnet_model_name = 'm3face/ControlnetModels'
69
+ if args.condition == 'mask':
70
+ controlnet_revision = 'segmentation-english'
71
+ elif args.condition == 'landmark':
72
+ controlnet_revision = 'landmark-english'
73
+ controlnet = ControlNetModel.from_pretrained(controlnet_model_name, use_safetensors=True, revision=controlnet_revision)
74
+ pipeline = StableDiffusionControlNetPipeline.from_pretrained(
75
+ sd_model_name, controlnet=controlnet, use_safetensors=True, safety_checker=None
76
+ ).to("cuda")
77
+
78
+ pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
79
+ pipeline.enable_model_cpu_offload()
80
+ else:
81
+ sd_model_name = 'BAAI/AltDiffusion-m18'
82
+ controlnet_model_name = 'm3face/ControlnetModels'
83
+ if args.condition == 'mask':
84
+ controlnet_revision = 'segmentation-mlin'
85
+ elif args.condition == 'landmark':
86
+ controlnet_revision = 'landmark-mlin'
87
+ vae = AutoencoderKL.from_pretrained(sd_model_name, subfolder="vae")
88
+ unet = UNet2DConditionModel.from_pretrained(sd_model_name, subfolder="unet")
89
+ tokenizer = AutoTokenizer.from_pretrained(sd_model_name, subfolder="tokenizer", use_fast=False)
90
+ text_encoder = RobertaSeriesModelWithTransformation.from_pretrained(sd_model_name, subfolder="text_encoder")
91
+ controlnet = ControlNetModel.from_pretrained(controlnet_model_name, revision=controlnet_revision)
92
+
93
+ scheduler = PNDMScheduler.from_pretrained(
94
+ sd_model_name,
95
+ subfolder='scheduler',
96
+ )
97
+ scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
98
+ feature_extractor = CLIPFeatureExtractor.from_pretrained(
99
+ sd_model_name,
100
+ subfolder='feature_extractor',
101
+ )
102
+ pipeline = StableDiffusionControlNetPipeline(
103
+ vae=vae,
104
+ text_encoder=text_encoder,
105
+ tokenizer=tokenizer,
106
+ unet=unet,
107
+ controlnet=controlnet,
108
+ scheduler=scheduler,
109
+ safety_checker=None,
110
+ feature_extractor=feature_extractor,
111
+ ).to('cuda')
112
+
113
+ return pipeline
114
+
115
+
116
+ def get_muse(args):
117
+ muse_model_name = 'm3face/FaceConditioning'
118
+ if args.condition == 'mask':
119
+ muse_revision = 'segmentation'
120
+ elif args.condition == 'landmark':
121
+ muse_revision = 'landmark'
122
+ scheduler = AmusedScheduler.from_pretrained(muse_model_name, revision=muse_revision, subfolder='scheduler')
123
+ vqvae = VQModel.from_pretrained(muse_model_name, revision=muse_revision, subfolder='vqvae')
124
+ uvit2 = UVit2DModel.from_pretrained(muse_model_name, revision=muse_revision, subfolder='transformer')
125
+ text_encoder = MultilingualCLIP.from_pretrained(muse_model_name, revision=muse_revision, subfolder='text_encoder')
126
+ tokenizer = AutoTokenizer.from_pretrained(muse_model_name, revision=muse_revision, subfolder='tokenizer')
127
+
128
+ pipeline = AmusedPipeline(
129
+ vqvae=vqvae,
130
+ tokenizer=tokenizer,
131
+ text_encoder=text_encoder,
132
+ transformer=uvit2,
133
+ scheduler=scheduler
134
+ ).to("cuda")
135
+
136
+ return pipeline
137
+
138
+
139
+ if __name__ == '__main__':
140
+ args = parse_args()
141
+
142
+ # ========== set up face generation pipeline ==========
143
+ controlnet = get_controlnet(args)
144
+
145
+ # ========== set output directory ==========
146
+ os.makedirs(args.output_dir, exist_ok=True)
147
+
148
+ # ========== set random seed ==========
149
+ if args.seed is None:
150
+ generator = None
151
+ else:
152
+ generator = torch.Generator().manual_seed(args.seed)
153
+
154
+ # ========== generation ==========
155
+ id = int(time.time())
156
+ if args.condition_path:
157
+ condition = load_image(args.condition_path).resize((512, 512))
158
+ else:
159
+ # generate condition
160
+ muse = get_muse(args)
161
+ if args.condition == 'mask':
162
+ muse_added_prompt = 'Generate face segmentation | '
163
+ elif args.condition == 'landmark':
164
+ muse_added_prompt = 'Generate face landmark | '
165
+ muse_prompt = muse_added_prompt + args.prompt
166
+ condition = muse(muse_prompt, num_inference_steps=256).images[0].resize((512, 512))
167
+ if args.save_condition:
168
+ condition.save(f'{args.output_dir}/{id}_condition.png')
169
+
170
+ latents = torch.randn((args.num_samples, 4, 64, 64), generator=generator)
171
+ prompt = f'{args.prompt}, {args.additional_prompt}' if args.prompt else args.additional_prompt
172
+ images = controlnet(prompt, image=condition, num_inference_steps=args.num_inference_steps, negative_prompt=args.negative_prompt,
173
+ generator=generator, latents=latents, num_images_per_prompt=args.num_samples).images
174
+
175
+ for i, image in enumerate(images):
176
+ image.save(f'{args.output_dir}/{id}_{i}.png')
requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ diffusers
2
+ datasets
3
+ transformers
4
+ accelerate
5
+ xformers==0.0.21
6
+ face-alignment
7
+ gdown
utils/dml_csr/dml_csr.py ADDED
@@ -0,0 +1,103 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+ """
4
+ @Author : Qingping Zheng
5
+ @Contact : qingpingzheng2014@gmail.com
6
+ @File : dml_csr.py
7
+ @Time : 10/01/21 00:00 PM
8
+ @Desc :
9
+ @License : Licensed under the Apache License, Version 2.0 (the "License");
10
+ @Copyright : Copyright 2015 The Authors. All Rights Reserved.
11
+ """
12
+ from __future__ import absolute_import
13
+ from __future__ import division
14
+ from __future__ import print_function
15
+
16
+
17
+ import torch.nn as nn
18
+
19
+ from torch.nn import functional as F
20
+ from inplace_abn import InPlaceABNSync
21
+ from .modules.ddgcn import DDualGCNHead
22
+ from .modules.parsing import Parsing
23
+ from .modules.edges import Edges
24
+ from .modules.util import Bottleneck
25
+
26
+
27
+ def conv3x3(in_planes, out_planes, stride=1):
28
+ "3x3 convolution with padding"
29
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
30
+ padding=1, bias=False)
31
+
32
+
33
+ class DML_CSR(nn.Module):
34
+ def __init__(self,
35
+ num_classes,
36
+ abn=InPlaceABNSync,
37
+ trained=True):
38
+ super().__init__()
39
+ self.inplanes = 128
40
+ self.is_trained = trained
41
+
42
+ self.conv1 = conv3x3(3, 64, stride=2)
43
+ self.bn1 = abn(64)
44
+ self.relu1 = nn.ReLU(inplace=False)
45
+ self.conv2 = conv3x3(64, 64)
46
+ self.bn2 = abn(64)
47
+ self.relu2 = nn.ReLU(inplace=False)
48
+ self.conv3 = conv3x3(64, 128)
49
+ self.bn3 = abn(128)
50
+ self.relu3 = nn.ReLU(inplace=False)
51
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
52
+ self.layers = [3, 4, 23, 3]
53
+ self.abn = abn
54
+ strides = [1, 2, 1, 1]
55
+ dilations = [1, 1, 1, 2]
56
+
57
+ self.layer1 = self._make_layer(Bottleneck, 64, self.layers[0], stride=strides[0], dilation=dilations[0])
58
+ self.layer2 = self._make_layer(Bottleneck, 128, self.layers[1], stride=strides[1], dilation=dilations[1])
59
+ self.layer3 = self._make_layer(Bottleneck, 256, self.layers[2], stride=strides[2], dilation=dilations[2])
60
+ self.layer4 = self._make_layer(Bottleneck, 512, self.layers[3], stride=strides[3], dilation=dilations[3], multi_grid=(1,1,1))
61
+ # Context Aware
62
+ self.context = DDualGCNHead(2048, 512, abn)
63
+ self.layer6 = Parsing(512, 256, num_classes, abn)
64
+ # edge
65
+ if self.is_trained:
66
+ self.edge_layer = Edges(abn, out_fea=num_classes)
67
+
68
+ def _make_layer(self, block, planes, blocks, stride=1, dilation=1, multi_grid=1):
69
+ downsample = None
70
+ if stride != 1 or self.inplanes != planes * block.expansion:
71
+ downsample = nn.Sequential(
72
+ nn.Conv2d(self.inplanes, planes * block.expansion,
73
+ kernel_size=1, stride=stride, bias=False),
74
+ self.abn(planes * block.expansion, affine=True))
75
+
76
+ layers = []
77
+ generate_multi_grid = lambda index, grids: grids[index%len(grids)] if isinstance(grids, tuple) else 1
78
+ layers.append(block(self.inplanes, planes, stride, abn=self.abn, dilation=dilation, downsample=downsample, multi_grid=generate_multi_grid(0, multi_grid)))
79
+ self.inplanes = planes * block.expansion
80
+ for i in range(1, blocks):
81
+ layers.append(block(self.inplanes, planes, abn=self.abn, dilation=dilation, multi_grid=generate_multi_grid(i, multi_grid)))
82
+
83
+ return nn.Sequential(*layers)
84
+
85
+ def forward(self, x):
86
+ input = x
87
+ x = self.relu1(self.bn1(self.conv1(x)))
88
+ x = self.relu2(self.bn2(self.conv2(x)))
89
+ x1 = self.relu3(self.bn3(self.conv3(x)))
90
+ x = self.maxpool(x1)
91
+ x2 = self.layer1(x) # 119 x 119
92
+ x3 = self.layer2(x2) # 60 x 60
93
+ x4 = self.layer3(x3) # 60 x 60
94
+ x5 = self.layer4(x4) # 60 x 60
95
+ x = self.context(x5)
96
+ seg, x = self.layer6(x, x2)
97
+
98
+ if self.is_trained:
99
+ binary_edge, semantic_edge, edge_fea = self.edge_layer(x2,x3,x4)
100
+ return seg, binary_edge, semantic_edge
101
+
102
+ return seg
103
+
utils/dml_csr/modules/ddgcn.py ADDED
@@ -0,0 +1,182 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+ """
4
+ @Author : Qingping Zheng
5
+ @Contact : qingpingzheng2014@gmail.com
6
+ @File : ddgcn.py
7
+ @Time : 10/01/21 00:00 PM
8
+ @Desc :
9
+ @License : Licensed under the Apache License, Version 2.0 (the "License");
10
+ @Copyright : Copyright 2022 The Authors. All Rights Reserved.
11
+ """
12
+ from __future__ import absolute_import
13
+ from __future__ import division
14
+ from __future__ import print_function
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ import torch.nn as nn
19
+
20
+ from inplace_abn import InPlaceABNSync
21
+
22
+
23
+ class SpatialGCN(nn.Module):
24
+ def __init__(self, plane, abn=InPlaceABNSync):
25
+ super(SpatialGCN, self).__init__()
26
+ inter_plane = plane // 2
27
+ self.node_k = nn.Conv2d(plane, inter_plane, kernel_size=1)
28
+ self.node_v = nn.Conv2d(plane, inter_plane, kernel_size=1)
29
+ self.node_q = nn.Conv2d(plane, inter_plane, kernel_size=1)
30
+
31
+ self.conv_wg = nn.Conv1d(inter_plane, inter_plane, kernel_size=1, bias=False)
32
+ self.bn_wg = nn.BatchNorm1d(inter_plane)
33
+ self.softmax = nn.Softmax(dim=2)
34
+
35
+ self.out = nn.Sequential(nn.Conv2d(inter_plane, plane, kernel_size=1),
36
+ abn(plane))
37
+
38
+ self.gamma = nn.Parameter(torch.zeros(1))
39
+
40
+ def forward(self, x):
41
+ # b, c, h, w = x.size()
42
+ node_k = self.node_k(x)
43
+ node_v = self.node_v(x)
44
+ node_q = self.node_q(x)
45
+ b,c,h,w = node_k.size()
46
+ node_k = node_k.view(b, c, -1).permute(0, 2, 1)
47
+ node_q = node_q.view(b, c, -1)
48
+ node_v = node_v.view(b, c, -1).permute(0, 2, 1)
49
+ # A = k * q
50
+ # AV = k * q * v
51
+ # AVW = k *(q *v) * w
52
+ AV = torch.bmm(node_q,node_v)
53
+ AV = self.softmax(AV)
54
+ AV = torch.bmm(node_k, AV)
55
+ AV = AV.transpose(1, 2).contiguous()
56
+ AVW = self.conv_wg(AV)
57
+ AVW = self.bn_wg(AVW)
58
+ AVW = AVW.view(b, c, h, -1)
59
+ # out = F.relu_(self.out(AVW) + x)
60
+ out = self.gamma * self.out(AVW) + x
61
+ return out
62
+
63
+
64
+ class DDualGCN(nn.Module):
65
+ """
66
+ Feature GCN with coordinate GCN
67
+ """
68
+ def __init__(self, planes, abn=InPlaceABNSync, ratio=4):
69
+ super(DDualGCN, self).__init__()
70
+
71
+ self.phi = nn.Conv2d(planes, planes // ratio * 2, kernel_size=1, bias=False)
72
+ self.bn_phi = abn(planes // ratio * 2)
73
+ self.theta = nn.Conv2d(planes, planes // ratio, kernel_size=1, bias=False)
74
+ self.bn_theta = abn(planes // ratio)
75
+
76
+ # Interaction Space
77
+ # Adjacency Matrix: (-)A_g
78
+ self.conv_adj = nn.Conv1d(planes // ratio, planes // ratio, kernel_size=1, bias=False)
79
+ self.bn_adj = nn.BatchNorm1d(planes // ratio)
80
+
81
+ # State Update Function: W_g
82
+ self.conv_wg = nn.Conv1d(planes // ratio * 2, planes // ratio * 2, kernel_size=1, bias=False)
83
+ self.bn_wg = nn.BatchNorm1d(planes // ratio * 2)
84
+
85
+ # last fc
86
+ self.conv3 = nn.Conv2d(planes // ratio * 2, planes, kernel_size=1, bias=False)
87
+ self.bn3 = abn(planes)
88
+
89
+ self.local = nn.Sequential(
90
+ nn.Conv2d(planes, planes, 3, groups=planes, stride=2, padding=1, bias=False),
91
+ abn(planes),
92
+ nn.Conv2d(planes, planes, 3, groups=planes, stride=2, padding=1, bias=False),
93
+ abn(planes),
94
+ nn.Conv2d(planes, planes, 3, groups=planes, stride=2, padding=1, bias=False),
95
+ abn(planes))
96
+ self.gcn_local_attention = SpatialGCN(planes, abn)
97
+
98
+ self.final = nn.Sequential(nn.Conv2d(planes * 2, planes, kernel_size=1, bias=False),
99
+ abn(planes))
100
+
101
+ self.gamma1 = nn.Parameter(torch.zeros(1))
102
+
103
+ def to_matrix(self, x):
104
+ n, c, h, w = x.size()
105
+ x = x.view(n, c, -1)
106
+ return x
107
+
108
+ def forward(self, feat):
109
+ # # # # Local # # # #
110
+ x = feat
111
+ local = self.local(feat)
112
+ local = self.gcn_local_attention(local)
113
+ local = F.interpolate(local, size=x.size()[2:], mode='bilinear', align_corners=True)
114
+ spatial_local_feat = x * local + x
115
+
116
+ # # # # Projection Space # # # #
117
+ x_sqz, b = x, x
118
+
119
+ x_sqz = self.phi(x_sqz)
120
+ x_sqz = self.bn_phi(x_sqz)
121
+ x_sqz = self.to_matrix(x_sqz)
122
+
123
+ b = self.theta(b)
124
+ b = self.bn_theta(b)
125
+ b = self.to_matrix(b)
126
+
127
+ # Project
128
+ z_idt = torch.matmul(x_sqz, b.transpose(1, 2)) # channel
129
+
130
+ # # # # Interaction Space # # # #
131
+ z = z_idt.transpose(1, 2).contiguous()
132
+
133
+ z = self.conv_adj(z)
134
+ z = self.bn_adj(z)
135
+
136
+ z = z.transpose(1, 2).contiguous()
137
+ # Laplacian smoothing: (I - A_g)Z => Z - A_gZ
138
+ z += z_idt
139
+
140
+ z = self.conv_wg(z)
141
+ z = self.bn_wg(z)
142
+
143
+ # # # # Re-projection Space # # # #
144
+ # Re-project
145
+ y = torch.matmul(z, b)
146
+
147
+ n, _, h, w = x.size()
148
+ y = y.view(n, -1, h, w)
149
+
150
+ y = self.conv3(y)
151
+ y = self.bn3(y)
152
+
153
+ # g_out = x + y
154
+ # g_out = F.relu_(x+y)
155
+ g_out = self.gamma1*y + x
156
+
157
+ # cat or sum, nearly the same results
158
+ out = self.final(torch.cat((spatial_local_feat, g_out), 1))
159
+
160
+ return out
161
+
162
+
163
+ class DDualGCNHead(nn.Module):
164
+ def __init__(self, inplanes, interplanes, abn=InPlaceABNSync):
165
+ super(DDualGCNHead, self).__init__()
166
+ self.conva = nn.Sequential(nn.Conv2d(inplanes, interplanes, 3, padding=1, bias=False),
167
+ abn(interplanes))
168
+ self.dualgcn = DDualGCN(interplanes, abn)
169
+ self.convb = nn.Sequential(nn.Conv2d(interplanes, interplanes, 3, padding=1, bias=False),
170
+ abn(interplanes))
171
+
172
+ self.bottleneck = nn.Sequential(
173
+ nn.Conv2d(inplanes + interplanes, interplanes, kernel_size=3, padding=1, dilation=1, bias=False),
174
+ abn(interplanes)
175
+ )
176
+
177
+ def forward(self, x):
178
+ output = self.conva(x)
179
+ output = self.dualgcn(output)
180
+ output = self.convb(output)
181
+ output = self.bottleneck(torch.cat([x, output], 1))
182
+ return output
utils/dml_csr/modules/edges.py ADDED
@@ -0,0 +1,66 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+ """
4
+ @Author : Qingping Zheng
5
+ @Contact : qingpingzheng2014@gmail.com
6
+ @File : edges.py
7
+ @Time : 10/01/21 00:00 PM
8
+ @Desc :
9
+ @License : Licensed under the Apache License, Version 2.0 (the "License");
10
+ @Copyright : Copyright 2022 The Authors. All Rights Reserved.
11
+ """
12
+ from __future__ import absolute_import
13
+ from __future__ import division
14
+ from __future__ import print_function
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ import torch.nn as nn
19
+
20
+ from inplace_abn import InPlaceABNSync
21
+
22
+
23
+ class Edges(nn.Module):
24
+
25
+ def __init__(self, abn=InPlaceABNSync, in_fea=[256,512,1024], mid_fea=256, out_fea=2):
26
+ super(Edges, self).__init__()
27
+
28
+ self.conv1 = nn.Sequential(
29
+ nn.Conv2d(in_fea[0], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False),
30
+ abn(mid_fea)
31
+ )
32
+ self.conv2 = nn.Sequential(
33
+ nn.Conv2d(in_fea[1], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False),
34
+ abn(mid_fea)
35
+ )
36
+ self.conv3 = nn.Sequential(
37
+ nn.Conv2d(in_fea[2], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False),
38
+ abn(mid_fea)
39
+ )
40
+ self.conv4 = nn.Conv2d(mid_fea,out_fea, kernel_size=3, padding=1, dilation=1, bias=True)
41
+ self.conv5_b = nn.Conv2d(out_fea*3,2, kernel_size=1, padding=0, dilation=1, bias=True)
42
+ self.conv5 = nn.Conv2d(out_fea*3,out_fea, kernel_size=1, padding=0, dilation=1, bias=True)
43
+
44
+
45
+ def forward(self, x1, x2, x3):
46
+ _, _, h, w = x1.size()
47
+
48
+ edge1_fea = self.conv1(x1)
49
+ edge1 = self.conv4(edge1_fea)
50
+ edge2_fea = self.conv2(x2)
51
+ edge2 = self.conv4(edge2_fea)
52
+ edge3_fea = self.conv3(x3)
53
+ edge3 = self.conv4(edge3_fea)
54
+
55
+ edge2_fea = F.interpolate(edge2_fea, size=(h, w), mode='bilinear',align_corners=True)
56
+ edge3_fea = F.interpolate(edge3_fea, size=(h, w), mode='bilinear',align_corners=True)
57
+ edge2 = F.interpolate(edge2, size=(h, w), mode='bilinear',align_corners=True)
58
+ edge3 = F.interpolate(edge3, size=(h, w), mode='bilinear',align_corners=True)
59
+
60
+ edge = torch.cat([edge1, edge2, edge3], dim=1)
61
+ edge_fea = torch.cat([edge1_fea, edge2_fea, edge3_fea], dim=1)
62
+ semantic_edge = self.conv5(edge)
63
+ binary_edge = self.conv5_b(edge)
64
+
65
+ return binary_edge, semantic_edge, edge_fea
66
+
utils/dml_csr/modules/parsing.py ADDED
@@ -0,0 +1,51 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+ """
4
+ @Author : Qingping Zheng
5
+ @Contact : qingpingzheng2014@gmail.com
6
+ @File : parsing.py
7
+ @Time : 10/01/21 00:00 PM
8
+ @Desc :
9
+ @License : Licensed under the Apache License, Version 2.0 (the "License");
10
+ @Copyright : Copyright 2022 The Authors. All Rights Reserved.
11
+ """
12
+ from __future__ import absolute_import
13
+ from __future__ import division
14
+ from __future__ import print_function
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+ import torch.nn as nn
19
+
20
+ from inplace_abn import InPlaceABNSync
21
+
22
+
23
+ class Parsing(nn.Module):
24
+ def __init__(self, in_plane1, in_plane2, num_classes, abn=InPlaceABNSync):
25
+ super(Parsing, self).__init__()
26
+ self.conv1 = nn.Sequential(
27
+ nn.Conv2d(in_plane1, 256, kernel_size=1, padding=0, dilation=1, bias=False),
28
+ abn(256)
29
+ )
30
+ self.conv2 = nn.Sequential(
31
+ nn.Conv2d(in_plane2, 48, kernel_size=1, stride=1, padding=0, dilation=1, bias=False),
32
+ abn(48)
33
+ )
34
+ self.conv3 = nn.Sequential(
35
+ nn.Conv2d(304, 256, kernel_size=1, padding=0, dilation=1, bias=False),
36
+ abn(256),
37
+ nn.Conv2d(256, 256, kernel_size=1, padding=0, dilation=1, bias=False),
38
+ abn(256)
39
+ )
40
+ self.conv4 = nn.Conv2d(256, num_classes, kernel_size=1, padding=0, dilation=1, bias=True)
41
+
42
+ def forward(self, xt, xl):
43
+ _, _, h, w = xl.size()
44
+
45
+ xt = F.interpolate(self.conv1(xt), size=(h, w), mode='bilinear', align_corners=True)
46
+ xl = self.conv2(xl)
47
+ x = torch.cat([xt, xl], dim=1)
48
+ x = self.conv3(x)
49
+ seg = self.conv4(x)
50
+ return seg, x
51
+
utils/dml_csr/modules/util.py ADDED
@@ -0,0 +1,58 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+ """
4
+ @Author : Qingping Zheng
5
+ @Contact : qingpingzheng2014@gmail.com
6
+ @File : util.py
7
+ @Time : 10/01/21 00:00 PM
8
+ @Desc :
9
+ @License : Licensed under the Apache License, Version 2.0 (the "License");
10
+ @Copyright : Copyright 2022 The Authors. All Rights Reserved.
11
+ """
12
+ from __future__ import absolute_import
13
+ from __future__ import division
14
+ from __future__ import print_function
15
+
16
+ import torch.nn as nn
17
+
18
+ from inplace_abn import InPlaceABNSync
19
+
20
+
21
+ class Bottleneck(nn.Module):
22
+ expansion = 4
23
+ def __init__(self, inplanes, planes, stride=1, abn=InPlaceABNSync, dilation=1, downsample=None, fist_dilation=1, multi_grid=1):
24
+ super(Bottleneck, self).__init__()
25
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
26
+ self.bn1 = abn(planes)
27
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
28
+ padding=dilation*multi_grid, dilation=dilation*multi_grid, bias=False)
29
+ self.bn2 = abn(planes)
30
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
31
+ self.bn3 = abn(planes * 4)
32
+ self.relu = nn.ReLU(inplace=False)
33
+ self.relu_inplace = nn.ReLU(inplace=True)
34
+ self.downsample = downsample
35
+ self.dilation = dilation
36
+ self.stride = stride
37
+
38
+ def forward(self, x):
39
+ residual = x
40
+
41
+ out = self.conv1(x)
42
+ out = self.bn1(out)
43
+ out = self.relu(out)
44
+
45
+ out = self.conv2(out)
46
+ out = self.bn2(out)
47
+ out = self.relu(out)
48
+
49
+ out = self.conv3(out)
50
+ out = self.bn3(out)
51
+
52
+ if self.downsample is not None:
53
+ residual = self.downsample(x)
54
+
55
+ out = out + residual
56
+ out = self.relu_inplace(out)
57
+
58
+ return out
utils/dml_csr/transforms.py ADDED
@@ -0,0 +1,122 @@
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+ """
4
+ @Author : Qingping Zheng
5
+ @Contact : qingpingzheng2014@gmail.com
6
+ @File : transforms.py
7
+ @Time : 10/01/21 00:00 PM
8
+ @Desc :
9
+ @License : Licensed under the Apache License, Version 2.0 (the "License");
10
+ @Copyright : Copyright 2022 The Authors. All Rights Reserved.
11
+ """
12
+ from __future__ import absolute_import
13
+ from __future__ import division
14
+ from __future__ import print_function
15
+
16
+
17
+ import numpy as np
18
+ import cv2
19
+
20
+
21
+ def flip_back(output_flipped, matched_parts):
22
+ '''
23
+ output_flipped: numpy.ndarray(batch_size, num_joints, height, width)
24
+ '''
25
+ assert output_flipped.ndim == 4,\
26
+ 'output_flipped should be [batch_size, num_joints, height, width]'
27
+
28
+ output_flipped = output_flipped[:, :, :, ::-1]
29
+
30
+ for pair in matched_parts:
31
+ tmp = output_flipped[:, pair[0], :, :].copy()
32
+ output_flipped[:, pair[0], :, :] = output_flipped[:, pair[1], :, :]
33
+ output_flipped[:, pair[1], :, :] = tmp
34
+
35
+ return output_flipped
36
+
37
+
38
+ def transform_parsing(pred, center, scale, width, height, input_size):
39
+
40
+ if center is not None:
41
+ trans = get_affine_transform(center, scale, 0, input_size, inv=1)
42
+ target_pred = cv2.warpAffine(
43
+ pred,
44
+ trans,
45
+ (int(width), int(height)), #(int(width), int(height)),
46
+ flags=cv2.INTER_NEAREST,
47
+ borderMode=cv2.BORDER_CONSTANT,
48
+ borderValue=(0))
49
+ else:
50
+ target_pred = cv2.resize(pred, (int(width), int(height)), interpolation=cv2.INTER_NEAREST)
51
+
52
+ return target_pred
53
+
54
+
55
+ def get_affine_transform(center,
56
+ scale,
57
+ rot,
58
+ output_size,
59
+ shift=np.array([0, 0], dtype=np.float32),
60
+ inv=0):
61
+ if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
62
+ print(scale)
63
+ scale = np.array([scale, scale])
64
+
65
+ scale_tmp = scale
66
+
67
+ src_w = scale_tmp[0]
68
+ dst_w = output_size[1]
69
+ dst_h = output_size[0]
70
+
71
+ rot_rad = np.pi * rot / 180
72
+ src_dir = get_dir([0, src_w * -0.5], rot_rad)
73
+ dst_dir = np.array([0, dst_w * -0.5], np.float32)
74
+
75
+ src = np.zeros((3, 2), dtype=np.float32)
76
+ dst = np.zeros((3, 2), dtype=np.float32)
77
+ src[0, :] = center + scale_tmp * shift
78
+ src[1, :] = center + src_dir + scale_tmp * shift
79
+ dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
80
+ dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
81
+
82
+ src[2:, :] = get_3rd_point(src[0, :], src[1, :])
83
+ dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
84
+
85
+ if inv:
86
+ trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
87
+ else:
88
+ trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
89
+
90
+ return trans
91
+
92
+
93
+ def affine_transform(pt, t):
94
+ new_pt = np.array([pt[0], pt[1], 1.]).T
95
+ new_pt = np.dot(t, new_pt)
96
+ return new_pt[:2]
97
+
98
+
99
+ def get_3rd_point(a, b):
100
+ direct = a - b
101
+ return b + np.array([-direct[1], direct[0]], dtype=np.float32)
102
+
103
+
104
+ def get_dir(src_point, rot_rad):
105
+ sn, cs = np.sin(rot_rad), np.cos(rot_rad)
106
+
107
+ src_result = [0, 0]
108
+ src_result[0] = src_point[0] * cs - src_point[1] * sn
109
+ src_result[1] = src_point[0] * sn + src_point[1] * cs
110
+
111
+ return src_result
112
+
113
+
114
+ def crop(img, center, scale, output_size, rot=0):
115
+ trans = get_affine_transform(center, scale, rot, output_size)
116
+
117
+ dst_img = cv2.warpAffine(img,
118
+ trans,
119
+ (int(output_size[1]), int(output_size[0])),
120
+ flags=cv2.INTER_LINEAR)
121
+
122
+ return dst_img
utils/mclip.py ADDED
@@ -0,0 +1,70 @@
1
+ import torch
2
+ import transformers
3
+ from typing import Union, Optional, Tuple
4
+ from transformers import AutoConfig, AutoModel
5
+ from transformers.models.clip.modeling_clip import CLIPTextModelOutput
6
+
7
+
8
+ class MCLIPConfig(transformers.PretrainedConfig):
9
+ model_type = "M-CLIP"
10
+
11
+ def __init__(self, modelBase='xlm-roberta-large', transformerDimSize=1024, imageDimSize=768, **kwargs):
12
+ self.transformerDimensions = transformerDimSize
13
+ self.numDims = imageDimSize
14
+ self.modelBase = modelBase
15
+ super().__init__(**kwargs)
16
+
17
+
18
+
19
+ class MultilingualCLIP(transformers.PreTrainedModel):
20
+ config_class = MCLIPConfig
21
+
22
+ def __init__(self, config, *args, **kwargs):
23
+ super().__init__(config, *args, **kwargs)
24
+ self.transformer = transformers.AutoModel.from_pretrained(config.modelBase)
25
+ self.LinearTransformation = torch.nn.Linear(in_features=config.transformerDimensions,
26
+ out_features=config.numDims)
27
+
28
+ def forward(
29
+ self,
30
+ input_ids: Optional[torch.Tensor] = None,
31
+ attention_mask: Optional[torch.Tensor] = None,
32
+ position_ids: Optional[torch.Tensor] = None,
33
+ output_attentions: Optional[bool] = None,
34
+ output_hidden_states: Optional[bool] = None,
35
+ return_dict: Optional[bool] = None,
36
+ ) -> Union[Tuple, CLIPTextModelOutput]:
37
+
38
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
39
+
40
+ text_outputs = self.transformer(
41
+ input_ids=input_ids,
42
+ attention_mask=attention_mask,
43
+ position_ids=position_ids,
44
+ output_attentions=output_attentions,
45
+ output_hidden_states=output_hidden_states,
46
+ return_dict=return_dict,
47
+ )
48
+
49
+ pooled_output = text_outputs[1]
50
+
51
+ text_embeds = self.LinearTransformation(pooled_output)
52
+
53
+ if not return_dict:
54
+ outputs = (text_embeds, text_outputs[0]) + text_outputs[2:]
55
+ return tuple(output for output in outputs if output is not None)
56
+
57
+ return CLIPTextModelOutput(
58
+ text_embeds=text_embeds,
59
+ last_hidden_state=text_outputs.last_hidden_state,
60
+ hidden_states=text_outputs.hidden_states,
61
+ attentions=text_outputs.attentions,
62
+ )
63
+
64
+ @classmethod
65
+ def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or_path, _fast_init=True):
66
+ model.load_state_dict(state_dict)
67
+ return model, [], [], []
68
+
69
+ AutoConfig.register("M-CLIP", MCLIPConfig)
70
+ AutoModel.register(MCLIPConfig, MultilingualCLIP)
utils/plot_landmark.py ADDED
@@ -0,0 +1,138 @@
1
+ import os
2
+ import PIL
3
+ import cv2
4
+ import pickle
5
+ import argparse
6
+ import numpy as np
7
+ import face_alignment
8
+ import matplotlib.pyplot as plt
9
+ import matplotlib.patches as patches
10
+ from matplotlib.path import Path
11
+
12
+
13
+ def parse_args():
14
+ parser = argparse.ArgumentParser(description="Plot facial landmarks from an image.")
15
+ parser.add_argument(
16
+ "--image_path",
17
+ type=str,
18
+ default=None,
19
+ help="Path to the image file."
20
+ )
21
+ parser.add_argument("--size", type=int, default=512)
22
+ parser.add_argument("--crop", action="store_true", help="Crop around the face image.")
23
+ parser.add_argument(
24
+ "--output_dir",
25
+ type=str,
26
+ default="output/landmarks/",
27
+ help="Folder to save landmark images."
28
+ )
29
+ args = parser.parse_args()
30
+
31
+ return args
32
+
33
+ def get_patch(landmarks, color='lime', closed=False):
34
+ contour = landmarks
35
+ ops = [Path.MOVETO] + [Path.LINETO]*(len(contour)-1)
36
+ facecolor = (0, 0, 0, 0) # Transparent fill color, if open
37
+ if closed:
38
+ contour.append(contour[0])
39
+ ops.append(Path.CLOSEPOLY)
40
+ facecolor = color
41
+ path = Path(contour, ops)
42
+ return patches.PathPatch(path, facecolor=facecolor, edgecolor=color, lw=4)
43
+
44
+ def bbox_from_landmarks(landmarks):
45
+ landmarks_x, landmarks_y = zip(*landmarks)
46
+
47
+ x_min, x_max = min(landmarks_x), max(landmarks_x)
48
+ y_min, y_max = min(landmarks_y), max(landmarks_y)
49
+ width = x_max - x_min
50
+ height = y_max - y_min
51
+
52
+ # Add a small margin around the landmark bounding box
53
+ x_min -= 25
54
+ y_min -= 25
55
+ width += 50
56
+ height += 50
57
+ bbox = (x_min, y_min, width, height)
58
+ return bbox
59
+
60
+ def plot_landmarks(landmarks, crop=False, size=512):
61
+ if crop:
62
+ (x_min, y_min, width, height) = bbox_from_landmarks(landmarks)
63
+ # print(x_min, y_min, width, height)
64
+ landmarks_np = np.array(landmarks)
65
+ landmarks_np[:, 0] = (landmarks_np[:, 0] - x_min) * size / width
66
+ landmarks_np[:, 1] = (landmarks_np[:, 1] - y_min) * size / height
67
+ landmarks = landmarks_np.tolist()
68
+ # Precisely control output image size
69
+ dpi = 72
70
+ fig, ax = plt.subplots(1, figsize=[size/dpi, size/dpi], tight_layout={'pad':0})
71
+ fig.set_dpi(dpi)
72
+
73
+ black = np.zeros((size, size, 3))
74
+ ax.imshow(black)
75
+
76
+ face_patch = get_patch(landmarks[0:17])
77
+ l_eyebrow = get_patch(landmarks[17:22], color='yellow')
78
+ r_eyebrow = get_patch(landmarks[22:27], color='yellow')
79
+ nose_v = get_patch(landmarks[27:31], color='orange')
80
+ nose_h = get_patch(landmarks[31:36], color='orange')
81
+ l_eye = get_patch(landmarks[36:42], color='magenta', closed=True)
82
+ r_eye = get_patch(landmarks[42:48], color='magenta', closed=True)
83
+ outer_lips = get_patch(landmarks[48:60], color='cyan', closed=True)
84
+ inner_lips = get_patch(landmarks[60:68], color='blue', closed=True)
85
+
86
+ ax.add_patch(face_patch)
87
+ ax.add_patch(l_eyebrow)
88
+ ax.add_patch(r_eyebrow)
89
+ ax.add_patch(nose_v)
90
+ ax.add_patch(nose_h)
91
+ ax.add_patch(l_eye)
92
+ ax.add_patch(r_eye)
93
+ ax.add_patch(outer_lips)
94
+ ax.add_patch(inner_lips)
95
+
96
+ plt.axis('off')
97
+
98
+ fig.canvas.draw()
99
+ buffer, (width, height) = fig.canvas.print_to_buffer()
100
+ assert width == height
101
+ assert width == size
102
+
103
+ buffer = np.frombuffer(buffer, np.uint8).reshape((height, width, 4))
104
+ buffer = buffer[:, :, 0:3]
105
+ plt.close(fig)
106
+ return PIL.Image.fromarray(buffer)
107
+
108
+ def get_landmarks(image):
109
+ fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=False, face_detector='sfd')
110
+ faces = fa.get_landmarks_from_image(image)
111
+ if faces is None or len(faces) == 0:
112
+ return None
113
+ landmarks = faces[0]
114
+ return landmarks
115
+
116
+ def save_landmarks(args):
117
+ os.makedirs(args.output_dir, exist_ok=True)
118
+
119
+ image_name = os.path.basename(args.image_path)
120
+ image = cv2.imread(args.image_path)
121
+ image = cv2.resize(image, (args.size, args.size))
122
+ landmarks = get_landmarks(image)
123
+ if landmarks is None:
124
+ print(f'No faces found in {image_name}')
125
+ return
126
+
127
+ filename = f'{args.output_dir}/{image_name}'
128
+ if args.crop:
129
+ landmarks_cropped_image = plot_landmarks(landmarks.tolist(), crop=True, size=args.size)
130
+ landmarks_cropped_image.save(filename)
131
+ else:
132
+ landmarks_image = plot_landmarks(landmarks.tolist(), size=args.size)
133
+ landmarks_image.save(filename)
134
+ print(f'Landmark saved in {filename}')
135
+
136
+ if __name__ == '__main__':
137
+ args = parse_args()
138
+ save_landmarks(args)
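
Besides the CLI entry point, `get_landmarks` and `plot_landmarks` can be called directly. A minimal sketch, assuming the repo root is on the Python path and `face.jpg` is a placeholder image containing a single face:

```python
# Minimal sketch: run from the repo root so `utils` is importable;
# "face.jpg" is a hypothetical input image.
import cv2
from utils.plot_landmark import get_landmarks, plot_landmarks

image = cv2.resize(cv2.imread("face.jpg"), (512, 512))
landmarks = get_landmarks(image)  # array of 68 (x, y) points, or None if no face is found
if landmarks is not None:
    plot_landmarks(landmarks.tolist(), crop=True, size=512).save("face_landmark.png")
```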
utils/plot_mask.py ADDED
@@ -0,0 +1,191 @@
+ import os
+ import cv2
+ import gdown
+ import shutil
+ import argparse
+ import numpy as np
+ import torch
+ import torch.backends.cudnn as cudnn
+ import torchvision.transforms as transforms
+ from torchvision.utils import save_image
+
+ from inplace_abn import InPlaceABN
+ from dml_csr import dml_csr
+ from dml_csr import transforms as dml_transforms
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Plot segmentation mask of an image.")
+     parser.add_argument(
+         "--image_path",
+         type=str,
+         default=None,
+         help="Path to the image file."
+     )
+     parser.add_argument("--size", type=int, default=512)
+     parser.add_argument(
+         "--checkpoint_path",
+         type=str,
+         default='ckpt/DML_CSR/dml_csr_celebA.pth',
+         help="Path to the DML-CSR pretrained model."
+     )
+     parser.add_argument(
+         "--output_dir",
+         type=str,
+         default="output/masks/",
+         help="Folder to save segmentation mask."
+     )
+     args = parser.parse_args()
+
+     return args
+
+ def download_checkpoint():
+     os.makedirs('ckpt', exist_ok=True)
+     id = "1xttWuAj633-ujp_vcm5DtL98PP0b-sUm"
+     gdown.download(id=id, output='ckpt/DML_CSR.zip')
+     shutil.unpack_archive('ckpt/DML_CSR.zip', 'ckpt')
+     os.remove('ckpt/DML_CSR.zip')
+
+ def box2cs(box: list) -> tuple:
+     x, y, w, h = box[:4]
+     return xywh2cs(x, y, w, h)
+
+ def xywh2cs(x: float, y: float, w: float, h: float) -> tuple:
+     center = np.zeros((2), dtype=np.float32)
+     center[0] = x + w * 0.5
+     center[1] = y + h * 0.5
+     if w > h:
+         h = w
+     elif w < h:
+         w = h
+     scale = np.array([w * 1.0, h * 1.0], dtype=np.float32)
+
+     return center, scale
+
+ def labelcolormap(N):
+     if N == 19:  # CelebAMask-HQ
+         cmap = np.array([(0, 0, 0), (204, 0, 0), (76, 153, 0),
+                          (204, 204, 0), (204, 0, 204), (204, 0, 204), (255, 204, 204),
+                          (255, 204, 204), (102, 51, 0), (102, 51, 0), (102, 204, 0),
+                          (255, 255, 0), (0, 0, 153), (0, 0, 204), (255, 51, 153),
+                          (0, 204, 204), (0, 51, 0), (255, 153, 51), (0, 204, 0)],
+                         dtype=np.uint8)
+     else:
+         def uint82bin(n, count=8):
+             """returns the binary of integer n, count refers to amount of bits"""
+             return ''.join([str((n >> y) & 1) for y in range(count-1, -1, -1)])
+
+         cmap = np.zeros((N, 3), dtype=np.uint8)
+         for i in range(N):
+             r, g, b = 0, 0, 0
+             id = i
+             for j in range(7):
+                 str_id = uint82bin(id)
+                 r = r ^ (np.uint8(str_id[-1]) << (7-j))
+                 g = g ^ (np.uint8(str_id[-2]) << (7-j))
+                 b = b ^ (np.uint8(str_id[-3]) << (7-j))
+                 id = id >> 3
+             cmap[i, 0] = r
+             cmap[i, 1] = g
+             cmap[i, 2] = b
+     return cmap
+
+ class Colorize(object):
+     def __init__(self, n=19):
+         self.cmap = labelcolormap(n)
+         self.cmap = torch.from_numpy(self.cmap[:n])
+
+     def __call__(self, gray_image):
+         size = gray_image.size()
+         color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)
+
+         for label in range(0, len(self.cmap)):
+             mask = (label == gray_image[0]).cpu()
+             color_image[0][mask] = self.cmap[label][0]
+             color_image[1][mask] = self.cmap[label][1]
+             color_image[2][mask] = self.cmap[label][2]
+
+         return color_image
+
+ def tensor2label(label_tensor, n_label):
+     label_tensor = label_tensor.cpu().float()
+     if label_tensor.size()[0] > 1:
+         label_tensor = label_tensor.max(0, keepdim=True)[1]
+     label_tensor = Colorize(n_label)(label_tensor)
+     # label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0))
+     label_numpy = label_tensor.numpy()
+     label_numpy = label_numpy / 255.0
+
+     return label_numpy
+
+ def generate_label(inputs, imsize):
+     pred_batch = []
+     for input in inputs:
+         input = input.view(1, 19, imsize, imsize)
+         pred = np.squeeze(input.data.max(1)[1].cpu().numpy(), axis=0)
+         pred_batch.append(pred)
+
+     pred_batch = np.array(pred_batch)
+     pred_batch = torch.from_numpy(pred_batch)
+
+     label_batch = []
+     for p in pred_batch:
+         p = p.view(1, imsize, imsize)
+         label_batch.append(tensor2label(p, 19))
+
+     label_batch = np.array(label_batch)
+     label_batch = torch.from_numpy(label_batch)
+
+     return label_batch
+
+ def get_mask(model, image, input_size):
+     interp = torch.nn.Upsample(size=input_size, mode='bilinear', align_corners=True)
+
+     image = image.unsqueeze(0)
+     with torch.no_grad():
+         outputs = model(image.cuda())
+         labels = generate_label(interp(outputs), input_size[0])
+     return labels[0]
+
+ def save_mask(args):
+     os.makedirs(args.output_dir, exist_ok=True)
+
+     cudnn.benchmark = True
+     cudnn.enabled = True
+
+     model = dml_csr.DML_CSR(19, InPlaceABN, False)
+
+     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                      std=[0.229, 0.224, 0.225])
+     transform = transforms.Compose([transforms.ToTensor(), normalize])
+
+     input_size = (args.size, args.size)
+     image = cv2.imread(args.image_path, cv2.IMREAD_COLOR)
+     h, w, _ = image.shape
+     center, s = box2cs([0, 0, w - 1, h - 1])
+     r = 0
+     crop_size = np.asarray(input_size)
+     trans = dml_transforms.get_affine_transform(center, s, r, crop_size)
+     image = cv2.warpAffine(image, trans, (int(crop_size[1]), int(crop_size[0])),
+                            flags=cv2.INTER_LINEAR,
+                            borderMode=cv2.BORDER_CONSTANT,
+                            borderValue=(0, 0, 0))
+     image = transform(image)
+
+     if not os.path.exists(args.checkpoint_path):
+         download_checkpoint()
+     state_dict = torch.load(args.checkpoint_path, map_location='cuda:0')
+     model.load_state_dict(state_dict)
+
+     model.cuda()
+     model.eval()
+
+     mask = get_mask(model, image, input_size)
+     filename = os.path.join(args.output_dir, os.path.basename(args.image_path).split('.')[0] + '.png')
+     save_image(mask, filename)
+     print(f'Mask saved in {filename}')
+
+
+ if __name__ == '__main__':
+     args = parse_args()
+     save_mask(args)
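
`save_mask` can likewise be driven programmatically instead of through `argparse`. A minimal sketch, assuming the repo root is on the Python path, a CUDA device is available, and `face.jpg` is a placeholder input; the DML-CSR checkpoint is downloaded on first use if missing:

```python
# Minimal sketch: "face.jpg" is a hypothetical input image; paths mirror the script defaults.
from types import SimpleNamespace
from utils.plot_mask import save_mask

args = SimpleNamespace(
    image_path="face.jpg",
    size=512,
    checkpoint_path="ckpt/DML_CSR/dml_csr_celebA.pth",  # fetched automatically if missing
    output_dir="output/masks/",
)
save_mask(args)  # writes output/masks/face.png
```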