Adding files
Browse files- README.md +73 -0
- data/landmarks.pickle +3 -0
- data/landmarks/1.jpg +0 -0
- data/landmarks/2.jpg +0 -0
- data/landmarks/3.jpg +0 -0
- data/landmarks/4.jpg +0 -0
- data/masks/1.png +0 -0
- data/masks/2.png +0 -0
- data/masks/3.png +0 -0
- data/masks/4.png +0 -0
- docs/pull-figure.png +0 -0
- edit.py +493 -0
- generate.py +176 -0
- requirements.txt +7 -0
- utils/dml_csr/dml_csr.py +103 -0
- utils/dml_csr/modules/ddgcn.py +182 -0
- utils/dml_csr/modules/edges.py +66 -0
- utils/dml_csr/modules/parsing.py +51 -0
- utils/dml_csr/modules/util.py +58 -0
- utils/dml_csr/transforms.py +122 -0
- utils/mclip.py +70 -0
- utils/plot_landmark.py +138 -0
- utils/plot_mask.py +191 -0
README.md
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
tags:
|
3 |
+
- text-to-image
|
4 |
+
- controlnet
|
5 |
+
---
|
6 |
+
|
7 |
+
# M<sup>3</sup>Face Model Card
|
8 |
+
We introduce M<sup>3</sup>Face, a unified multi-modal multilingual framework for controllable face generation and editing. This framework enables users to utilize only text input to generate controlling modalities automatically, for instance, semantic segmentation or facial landmarks, and subsequently generate face images.
|
9 |
+
|
10 |
+
## Getting Started
|
11 |
+
|
12 |
+
### Installation
|
13 |
+
1. Clone our repository:
|
14 |
+
|
15 |
+
```bash
|
16 |
+
git clone https://huggingface.co/m3face/m3face
|
17 |
+
cd m3face
|
18 |
+
```
|
19 |
+
|
20 |
+
2. Install dependencies:
|
21 |
+
|
22 |
+
```bash
|
23 |
+
pip install -r requirements.txt
|
24 |
+
```
|
25 |
+
|
26 |
+
### Resources
|
27 |
+
- For face generation, VRAM of 10 GB+ for 512x512 images is required.
|
28 |
+
- For face editing, VRAM of 14 GB+ for 512x512 images is required.
|
29 |
+
|
30 |
+
### Pre-trained Models
|
31 |
+
You can find the checkpoints for the ControlNet model at [`m3face/ControlnetModels`](https://huggingface.co/m3face/ControlnetModels) and the mask/landmark generator model at [`m3face/FaceConditioning`](https://huggingface.co/m3face/FaceConditioning).
|
32 |
+
|
33 |
+
### M<sup>3</sup>CelebA Dataset
|
34 |
+
The M<sup>3</sup>CelebA Dataset is available at [`m3face/M3CelebA`](https://huggingface.co/m3face/M3CelebA). You can view or download it from there.
|
35 |
+
|
36 |
+
## Face Generation
|
37 |
+
You can do face generation with text, segmentation mask, facial landmarks, or a combination of them by running the following command:
|
38 |
+
```bash
|
39 |
+
python generate.py --seed 1111 \
|
40 |
+
--condition "landmark" \
|
41 |
+
--prompt "This attractive woman has narrow eyes, rosy cheeks, and wears heavy makeup." \
|
42 |
+
--save_condition
|
43 |
+
```
|
44 |
+
You can define the type of conditioning modality with `--condition`. By default, a conditioning modality will be generated by our framework and will be saved if the `--save_condition` argument is given. Otherwise, you can use your condition image with the `condition_path` argument.
|
45 |
+
|
46 |
+
## Face Editing
|
47 |
+
For face editing, you can run the following command:
|
48 |
+
```bash
|
49 |
+
python edit.py --enable_xformers_memory_efficient_attention \
|
50 |
+
--seed 1111 \
|
51 |
+
--condition "landmark" \
|
52 |
+
--prompt "She is a smiling." \
|
53 |
+
--image_path "/path/to/image" \
|
54 |
+
--condition_path "/path/to/condition" \
|
55 |
+
--edit_condition \
|
56 |
+
--embedding_optimize_it 500 \
|
57 |
+
--model_finetune_it 1000 \
|
58 |
+
--alpha 0.7 1 1.1 \
|
59 |
+
--num_inference_steps 30 \
|
60 |
+
--unet_layer "2and3"
|
61 |
+
```
|
62 |
+
You need to specify the input image and original conditioning modality. You can edit the face with an edit conditioning modality (specifying `--edit_condition_path`) or by editing the original conditioning modality with our framework (specifying `--edit_condition`).
|
63 |
+
The `--unet_layer` argument specifies which UNet layers in the SD to finetune.
|
64 |
+
|
65 |
+
> Note: If you don't have the original conditioning modality you can simply generate it using the `plot_mask.py` and `plot_landmark.py` scripts:
|
66 |
+
```bash
|
67 |
+
pip install git+https://github.com/mapillary/inplace_abn
|
68 |
+
python utils/plot_mask.py --image_path "/path/to/image"
|
69 |
+
python utils/plot_landmark.py --image_path "/path/to/image"
|
70 |
+
```
|
71 |
+
|
72 |
+
## Training
|
73 |
+
The code and instruction for training our models will be posted soon!
|
data/landmarks.pickle
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3e129223b20f017a389b04ffe65ac0fd047f03a2bd9ef5bcb9eb0358b2b50a85
|
3 |
+
size 688
|
data/landmarks/1.jpg
ADDED
![]() |
data/landmarks/2.jpg
ADDED
![]() |
data/landmarks/3.jpg
ADDED
![]() |
data/landmarks/4.jpg
ADDED
![]() |
data/masks/1.png
ADDED
![]() |
data/masks/2.png
ADDED
![]() |
data/masks/3.png
ADDED
![]() |
data/masks/4.png
ADDED
![]() |
docs/pull-figure.png
ADDED
![]() |
edit.py
ADDED
@@ -0,0 +1,493 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import argparse
|
3 |
+
from tqdm.auto import tqdm
|
4 |
+
from packaging import version
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import torch.utils.checkpoint
|
9 |
+
from torchvision import transforms
|
10 |
+
from diffusers import (
|
11 |
+
AutoencoderKL,
|
12 |
+
ControlNetModel,
|
13 |
+
DDPMScheduler,
|
14 |
+
StableDiffusionControlNetPipeline,
|
15 |
+
UNet2DConditionModel,
|
16 |
+
UniPCMultistepScheduler,
|
17 |
+
PNDMScheduler,
|
18 |
+
AmusedInpaintPipeline, AmusedScheduler, VQModel, UVit2DModel
|
19 |
+
|
20 |
+
)
|
21 |
+
from diffusers.utils.import_utils import is_xformers_available
|
22 |
+
from diffusers.utils import load_image
|
23 |
+
from transformers import AutoTokenizer, CLIPFeatureExtractor, PretrainedConfig
|
24 |
+
from PIL import Image
|
25 |
+
from utils.mclip import *
|
26 |
+
|
27 |
+
|
28 |
+
def parse_args():
|
29 |
+
parser = argparse.ArgumentParser(description="Edit images with M3Face.")
|
30 |
+
parser.add_argument(
|
31 |
+
"--prompt",
|
32 |
+
type=str,
|
33 |
+
default="This attractive woman has narrow eyes, rosy cheeks, and wears heavy makeup.",
|
34 |
+
help="The input text prompt for image generation."
|
35 |
+
)
|
36 |
+
parser.add_argument(
|
37 |
+
"--condition",
|
38 |
+
type=str,
|
39 |
+
default="mask",
|
40 |
+
choices=["mask", "landmark"],
|
41 |
+
help="Use segmentation mask or facial landmarks for image generation."
|
42 |
+
)
|
43 |
+
parser.add_argument(
|
44 |
+
"--image_path",
|
45 |
+
type=str,
|
46 |
+
default=None,
|
47 |
+
help="Path to the input image."
|
48 |
+
)
|
49 |
+
parser.add_argument(
|
50 |
+
"--condition_path",
|
51 |
+
type=str,
|
52 |
+
default=None,
|
53 |
+
help="Path to the original mask/landmark image."
|
54 |
+
)
|
55 |
+
parser.add_argument(
|
56 |
+
"--edit_condition_path",
|
57 |
+
type=str,
|
58 |
+
default=None,
|
59 |
+
help="Path to the target mask/landmark image."
|
60 |
+
)
|
61 |
+
parser.add_argument(
|
62 |
+
"--output_dir",
|
63 |
+
type=str,
|
64 |
+
default='output/',
|
65 |
+
help="The output directory where the results will be written.",
|
66 |
+
)
|
67 |
+
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible generation.")
|
68 |
+
parser.add_argument(
|
69 |
+
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
70 |
+
)
|
71 |
+
parser.add_argument("--edit_condition", action="store_true")
|
72 |
+
parser.add_argument("--load_unet_from_local", action="store_true")
|
73 |
+
parser.add_argument("--save_unet", action="store_true")
|
74 |
+
parser.add_argument("--unet_local_path", type=str, default=None)
|
75 |
+
parser.add_argument("--load_finetune_from_local", action="store_true")
|
76 |
+
parser.add_argument("--finetune_path", type=str, default=None)
|
77 |
+
parser.add_argument("--use_english", action="store_true", help="Use the English models.")
|
78 |
+
parser.add_argument("--embedding_optimize_it", type=int, default=500)
|
79 |
+
parser.add_argument("--model_finetune_it", type=int, default=1000)
|
80 |
+
parser.add_argument("--alpha", nargs="+", type=float, default=[0.8, 0.9, 1, 1.1])
|
81 |
+
parser.add_argument("--num_inference_steps", nargs="+", type=int, default=[20, 40, 50])
|
82 |
+
parser.add_argument("--unet_layer", type=str, default="2and3",
|
83 |
+
help="Which UNet layers in the SD to finetune.")
|
84 |
+
|
85 |
+
args = parser.parse_args()
|
86 |
+
|
87 |
+
return args
|
88 |
+
|
89 |
+
def get_muse(args):
|
90 |
+
muse_model_name = 'm3face/FaceConditioning'
|
91 |
+
if args.condition == 'mask':
|
92 |
+
muse_revision = 'segmentation'
|
93 |
+
elif args.condition == 'landmark':
|
94 |
+
muse_revision = 'landmark'
|
95 |
+
scheduler = AmusedScheduler.from_pretrained(muse_model_name, revision=muse_revision, subfolder='scheduler')
|
96 |
+
vqvae = VQModel.from_pretrained(muse_model_name, revision=muse_revision, subfolder='vqvae')
|
97 |
+
uvit2 = UVit2DModel.from_pretrained(muse_model_name, revision=muse_revision, subfolder='transformer')
|
98 |
+
text_encoder = MultilingualCLIP.from_pretrained(muse_model_name, revision=muse_revision, subfolder='text_encoder')
|
99 |
+
tokenizer = AutoTokenizer.from_pretrained(muse_model_name, revision=muse_revision, subfolder='tokenizer')
|
100 |
+
|
101 |
+
pipeline = AmusedInpaintPipeline(
|
102 |
+
vqvae=vqvae,
|
103 |
+
tokenizer=tokenizer,
|
104 |
+
text_encoder=text_encoder,
|
105 |
+
transformer=uvit2,
|
106 |
+
scheduler=scheduler
|
107 |
+
).to("cuda")
|
108 |
+
|
109 |
+
return pipeline
|
110 |
+
|
111 |
+
def import_model_class_from_model_name(sd_model_name):
|
112 |
+
text_encoder_config = PretrainedConfig.from_pretrained(
|
113 |
+
sd_model_name,
|
114 |
+
subfolder="text_encoder",
|
115 |
+
)
|
116 |
+
model_class = text_encoder_config.architectures[0]
|
117 |
+
|
118 |
+
if model_class == "CLIPTextModel":
|
119 |
+
from transformers import CLIPTextModel
|
120 |
+
|
121 |
+
return CLIPTextModel
|
122 |
+
elif model_class == "RobertaSeriesModelWithTransformation":
|
123 |
+
from diffusers.pipelines.deprecated.alt_diffusion import RobertaSeriesModelWithTransformation
|
124 |
+
|
125 |
+
return RobertaSeriesModelWithTransformation
|
126 |
+
else:
|
127 |
+
raise ValueError(f"{model_class} is not supported.")
|
128 |
+
|
129 |
+
def preprocess(image, condition, prompt, tokenizer):
|
130 |
+
image_transforms = transforms.Compose(
|
131 |
+
[
|
132 |
+
transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
|
133 |
+
transforms.CenterCrop(512),
|
134 |
+
transforms.ToTensor(),
|
135 |
+
transforms.Normalize([0.5], [0.5]),
|
136 |
+
]
|
137 |
+
)
|
138 |
+
condition_transforms = transforms.Compose(
|
139 |
+
[
|
140 |
+
transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
|
141 |
+
transforms.CenterCrop(512),
|
142 |
+
transforms.ToTensor(),
|
143 |
+
]
|
144 |
+
)
|
145 |
+
image = image_transforms(image)
|
146 |
+
condition = condition_transforms(condition)
|
147 |
+
inputs = tokenizer(
|
148 |
+
[prompt], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
|
149 |
+
)
|
150 |
+
|
151 |
+
return image, condition, inputs.input_ids, inputs.attention_mask
|
152 |
+
|
153 |
+
def main(args):
|
154 |
+
if args.use_english:
|
155 |
+
sd_model_name = 'runwayml/stable-diffusion-v1-5'
|
156 |
+
controlnet_model_name = 'm3face/ControlnetModels'
|
157 |
+
if args.condition == 'mask':
|
158 |
+
controlnet_revision = 'segmentation-english'
|
159 |
+
elif args.condition == 'landmark':
|
160 |
+
controlnet_revision = 'landmark-english'
|
161 |
+
else:
|
162 |
+
sd_model_name = 'BAAI/AltDiffusion-m18'
|
163 |
+
controlnet_model_name = 'm3face/ControlnetModels'
|
164 |
+
if args.condition == 'mask':
|
165 |
+
controlnet_revision = 'segmentation-mlin'
|
166 |
+
elif args.condition == 'landmark':
|
167 |
+
controlnet_revision = 'landmark-mlin'
|
168 |
+
|
169 |
+
# ========== set up models ==========
|
170 |
+
vae = AutoencoderKL.from_pretrained(sd_model_name, subfolder="vae")
|
171 |
+
tokenizer = AutoTokenizer.from_pretrained(sd_model_name, subfolder="tokenizer", use_fast=False)
|
172 |
+
text_encoder_cls = import_model_class_from_model_name(sd_model_name)
|
173 |
+
text_encoder = text_encoder_cls.from_pretrained(sd_model_name, subfolder="text_encoder")
|
174 |
+
noise_scheduler = DDPMScheduler.from_pretrained(sd_model_name, subfolder="scheduler")
|
175 |
+
|
176 |
+
if args.load_unet_from_local:
|
177 |
+
unet = UNet2DConditionModel.from_pretrained(args.unet_local_path)
|
178 |
+
else:
|
179 |
+
unet = UNet2DConditionModel.from_pretrained(sd_model_name, subfolder="unet")
|
180 |
+
|
181 |
+
controlnet = ControlNetModel.from_pretrained(controlnet_model_name, revision=controlnet_revision)
|
182 |
+
|
183 |
+
if args.edit_condition:
|
184 |
+
muse = get_muse(args)
|
185 |
+
|
186 |
+
vae.requires_grad_(False)
|
187 |
+
text_encoder.requires_grad_(False)
|
188 |
+
controlnet.requires_grad_(False)
|
189 |
+
unet.requires_grad_(False)
|
190 |
+
vae.eval()
|
191 |
+
text_encoder.eval()
|
192 |
+
controlnet.eval()
|
193 |
+
unet.eval()
|
194 |
+
|
195 |
+
if args.enable_xformers_memory_efficient_attention:
|
196 |
+
if is_xformers_available():
|
197 |
+
import xformers
|
198 |
+
|
199 |
+
xformers_version = version.parse(xformers.__version__)
|
200 |
+
if xformers_version == version.parse("0.0.16"):
|
201 |
+
print(
|
202 |
+
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
203 |
+
)
|
204 |
+
unet.enable_xformers_memory_efficient_attention()
|
205 |
+
controlnet.enable_xformers_memory_efficient_attention()
|
206 |
+
else:
|
207 |
+
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
208 |
+
|
209 |
+
# ========== select params to optimize ==========
|
210 |
+
params = []
|
211 |
+
for name, param in unet.named_parameters():
|
212 |
+
if(name.startswith('up_blocks')):
|
213 |
+
params.append(param)
|
214 |
+
|
215 |
+
if args.unet_layer == 'only1': # 116 layers
|
216 |
+
params_to_optimize = [
|
217 |
+
{'params': params[38:154]},
|
218 |
+
]
|
219 |
+
elif args.unet_layer == 'only2': # 116 layers
|
220 |
+
params_to_optimize = [
|
221 |
+
{'params': params[154:270]},
|
222 |
+
]
|
223 |
+
elif args.unet_layer == 'only3': # 114 layers
|
224 |
+
params_to_optimize = [
|
225 |
+
{'params': params[270:]},
|
226 |
+
]
|
227 |
+
elif args.unet_layer == '1and2': # 232 layers
|
228 |
+
params_to_optimize = [
|
229 |
+
{'params': params[38:270]},
|
230 |
+
]
|
231 |
+
elif args.unet_layer == '2and3': # 230 layers
|
232 |
+
params_to_optimize = [
|
233 |
+
{'params': params[154:]},
|
234 |
+
]
|
235 |
+
elif args.unet_layer == 'all': # all layers
|
236 |
+
params_to_optimize = [
|
237 |
+
{'params': params},
|
238 |
+
]
|
239 |
+
|
240 |
+
image = Image.open(args.image_path).convert('RGB')
|
241 |
+
condition = Image.open(args.condition_path).convert('RGB')
|
242 |
+
image, condition, input_ids, attention_mask = preprocess(image, condition, args.prompt, tokenizer)
|
243 |
+
|
244 |
+
# Move to device
|
245 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
246 |
+
vae.to(device, dtype=torch.float32)
|
247 |
+
unet.to(device, dtype=torch.float32)
|
248 |
+
text_encoder.to(device, dtype=torch.float32)
|
249 |
+
controlnet.to(device)
|
250 |
+
image = image.to(device).unsqueeze(0)
|
251 |
+
condition = condition.to(device).unsqueeze(0)
|
252 |
+
input_ids = input_ids.to(device)
|
253 |
+
attention_mask = attention_mask.to(device)
|
254 |
+
|
255 |
+
# ========== imagic ==========
|
256 |
+
if args.load_finetune_from_local:
|
257 |
+
print('Loading embeddings from local ...')
|
258 |
+
orig_emb = torch.load(os.path.join(args.finetune_path, 'orig_emb.pt'))
|
259 |
+
emb = torch.load(os.path.join(args.finetune_path, 'emb.pt'))
|
260 |
+
else:
|
261 |
+
init_latent = vae.encode(image.to(dtype=torch.float32)).latent_dist.sample()
|
262 |
+
init_latent = init_latent * vae.config.scaling_factor
|
263 |
+
|
264 |
+
if not args.use_english:
|
265 |
+
orig_emb = text_encoder(input_ids, attention_mask=attention_mask)[0]
|
266 |
+
else:
|
267 |
+
orig_emb = text_encoder(input_ids)[0]
|
268 |
+
emb = orig_emb.clone()
|
269 |
+
torch.save(orig_emb, os.path.join(args.output_dir, 'orig_emb.pt'))
|
270 |
+
torch.save(emb, os.path.join(args.output_dir, 'emb.pt'))
|
271 |
+
|
272 |
+
# 1. Optimize the embedding
|
273 |
+
print('1. Optimize the embedding')
|
274 |
+
unet.eval()
|
275 |
+
emb.requires_grad = True
|
276 |
+
lr = 0.001
|
277 |
+
it = args.embedding_optimize_it # 500
|
278 |
+
opt = torch.optim.Adam([emb], lr=lr)
|
279 |
+
history = []
|
280 |
+
|
281 |
+
pbar = tqdm(
|
282 |
+
range(it),
|
283 |
+
initial=0,
|
284 |
+
desc="Optimize Steps",
|
285 |
+
)
|
286 |
+
global_step = 0
|
287 |
+
|
288 |
+
for i in pbar:
|
289 |
+
opt.zero_grad()
|
290 |
+
|
291 |
+
noise = torch.randn_like(init_latent)
|
292 |
+
bsz = init_latent.shape[0]
|
293 |
+
t_enc = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=init_latent.device)
|
294 |
+
t_enc = t_enc.long()
|
295 |
+
z = noise_scheduler.add_noise(init_latent, noise, t_enc)
|
296 |
+
|
297 |
+
controlnet_image = condition.to(dtype=torch.float32)
|
298 |
+
|
299 |
+
down_block_res_samples, mid_block_res_sample = controlnet(
|
300 |
+
z,
|
301 |
+
t_enc,
|
302 |
+
encoder_hidden_states=emb,
|
303 |
+
controlnet_cond=controlnet_image,
|
304 |
+
return_dict=False,
|
305 |
+
)
|
306 |
+
|
307 |
+
# Predict the noise residual
|
308 |
+
pred_noise = unet(
|
309 |
+
z,
|
310 |
+
t_enc,
|
311 |
+
encoder_hidden_states=emb,
|
312 |
+
down_block_additional_residuals=[
|
313 |
+
sample.to(dtype=torch.float32) for sample in down_block_res_samples
|
314 |
+
],
|
315 |
+
mid_block_additional_residual=mid_block_res_sample.to(dtype=torch.float32),
|
316 |
+
).sample
|
317 |
+
|
318 |
+
# Get the target for loss depending on the prediction type
|
319 |
+
if noise_scheduler.config.prediction_type == "epsilon":
|
320 |
+
target = noise
|
321 |
+
elif noise_scheduler.config.prediction_type == "v_prediction":
|
322 |
+
target = noise_scheduler.get_velocity(init_latent, noise, t_enc)
|
323 |
+
else:
|
324 |
+
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
325 |
+
loss = F.mse_loss(pred_noise.float(), target.float(), reduction="mean")
|
326 |
+
|
327 |
+
loss.backward()
|
328 |
+
global_step += 1
|
329 |
+
pbar.set_postfix({"loss": loss.item()})
|
330 |
+
history.append(loss.item())
|
331 |
+
opt.step()
|
332 |
+
opt.zero_grad()
|
333 |
+
|
334 |
+
# 2. Finetune the model
|
335 |
+
print('2. Finetune the model')
|
336 |
+
emb.requires_grad = False
|
337 |
+
unet.requires_grad_(True)
|
338 |
+
unet.train()
|
339 |
+
|
340 |
+
lr = 5e-5
|
341 |
+
it = args.model_finetune_it # 1000
|
342 |
+
opt = torch.optim.Adam(params_to_optimize, lr=lr)
|
343 |
+
history = []
|
344 |
+
|
345 |
+
pbar = tqdm(
|
346 |
+
range(it),
|
347 |
+
initial=0,
|
348 |
+
desc="Finetune Steps",
|
349 |
+
)
|
350 |
+
global_step = 0
|
351 |
+
for i in pbar:
|
352 |
+
opt.zero_grad()
|
353 |
+
|
354 |
+
noise = torch.randn_like(init_latent)
|
355 |
+
bsz = init_latent.shape[0]
|
356 |
+
t_enc = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=init_latent.device)
|
357 |
+
t_enc = t_enc.long()
|
358 |
+
z = noise_scheduler.add_noise(init_latent, noise, t_enc)
|
359 |
+
|
360 |
+
controlnet_image = condition.to(dtype=torch.float32)
|
361 |
+
|
362 |
+
down_block_res_samples, mid_block_res_sample = controlnet(
|
363 |
+
z,
|
364 |
+
t_enc,
|
365 |
+
encoder_hidden_states=emb,
|
366 |
+
controlnet_cond=controlnet_image,
|
367 |
+
return_dict=False,
|
368 |
+
)
|
369 |
+
|
370 |
+
# Predict the noise residual
|
371 |
+
pred_noise = unet(
|
372 |
+
z,
|
373 |
+
t_enc,
|
374 |
+
encoder_hidden_states=emb,
|
375 |
+
down_block_additional_residuals=[
|
376 |
+
sample.to(dtype=torch.float32) for sample in down_block_res_samples
|
377 |
+
],
|
378 |
+
mid_block_additional_residual=mid_block_res_sample.to(dtype=torch.float32),
|
379 |
+
).sample
|
380 |
+
|
381 |
+
# Get the target for loss depending on the prediction type
|
382 |
+
if noise_scheduler.config.prediction_type == "epsilon":
|
383 |
+
target = noise
|
384 |
+
elif noise_scheduler.config.prediction_type == "v_prediction":
|
385 |
+
target = noise_scheduler.get_velocity(init_latent, noise, t_enc)
|
386 |
+
else:
|
387 |
+
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
388 |
+
loss = F.mse_loss(pred_noise.float(), target.float(), reduction="mean")
|
389 |
+
|
390 |
+
loss.backward()
|
391 |
+
global_step += 1
|
392 |
+
pbar.set_postfix({"loss": loss.item()})
|
393 |
+
history.append(loss.item())
|
394 |
+
opt.step()
|
395 |
+
opt.zero_grad()
|
396 |
+
|
397 |
+
# 3. Generate Images
|
398 |
+
print("3. Generating images... ")
|
399 |
+
|
400 |
+
unet.eval()
|
401 |
+
controlnet.eval()
|
402 |
+
|
403 |
+
if args.edit_condition_path is None:
|
404 |
+
edit_condition = load_image(args.condition_path)
|
405 |
+
else:
|
406 |
+
edit_condition = load_image(args.edit_condition_path)
|
407 |
+
if args.edit_condition:
|
408 |
+
edit_mask = Image.new("L", (256, 256), 0)
|
409 |
+
for i in range(256):
|
410 |
+
for j in range(256):
|
411 |
+
if 40 < i < 220 and 20 < j < 256:
|
412 |
+
edit_mask.putpixel((i, j), 256)
|
413 |
+
|
414 |
+
if args.condition == 'mask':
|
415 |
+
condition = 'segmentation'
|
416 |
+
elif args.condition == 'landmark':
|
417 |
+
condition = 'landmark'
|
418 |
+
edit_prompt = f"Generate face {condition} | " + args.prompt
|
419 |
+
input_image = edit_condition.resize((256, 256)).convert("RGB")
|
420 |
+
edit_condition = muse(edit_prompt, input_image, edit_mask, num_inference_steps=30).images[0].resize((512, 512))
|
421 |
+
edit_condition.save(f'{args.output_dir}/edited_condition.png')
|
422 |
+
|
423 |
+
# remove muse and empty cache
|
424 |
+
del muse
|
425 |
+
torch.cuda.empty_cache()
|
426 |
+
|
427 |
+
if sd_model_name.startswith('BAAI'):
|
428 |
+
scheduler = PNDMScheduler.from_pretrained(
|
429 |
+
sd_model_name,
|
430 |
+
subfolder='scheduler',
|
431 |
+
)
|
432 |
+
scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
|
433 |
+
feature_extractor = CLIPFeatureExtractor.from_pretrained(
|
434 |
+
sd_model_name,
|
435 |
+
subfolder='feature_extractor',
|
436 |
+
)
|
437 |
+
pipeline = StableDiffusionControlNetPipeline(
|
438 |
+
vae=vae,
|
439 |
+
text_encoder=text_encoder,
|
440 |
+
tokenizer=tokenizer,
|
441 |
+
unet=unet,
|
442 |
+
controlnet=controlnet,
|
443 |
+
scheduler=scheduler,
|
444 |
+
safety_checker=None,
|
445 |
+
feature_extractor=feature_extractor
|
446 |
+
)
|
447 |
+
else:
|
448 |
+
pipeline = StableDiffusionControlNetPipeline.from_pretrained(
|
449 |
+
sd_model_name,
|
450 |
+
vae=vae,
|
451 |
+
text_encoder=text_encoder,
|
452 |
+
tokenizer=tokenizer,
|
453 |
+
unet=unet,
|
454 |
+
controlnet=controlnet,
|
455 |
+
safety_checker=None,
|
456 |
+
)
|
457 |
+
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
|
458 |
+
pipeline = pipeline.to(device)
|
459 |
+
pipeline.set_progress_bar_config(disable=True)
|
460 |
+
|
461 |
+
if args.enable_xformers_memory_efficient_attention:
|
462 |
+
pipeline.enable_xformers_memory_efficient_attention()
|
463 |
+
|
464 |
+
if args.seed is None:
|
465 |
+
generator = None
|
466 |
+
else:
|
467 |
+
generator = torch.Generator(device=device).manual_seed(args.seed)
|
468 |
+
|
469 |
+
with torch.autocast("cuda"):
|
470 |
+
image = pipeline(
|
471 |
+
image=edit_condition, prompt_embeds=emb, num_inference_steps=20, generator=generator
|
472 |
+
).images[0]
|
473 |
+
image.save(f'{args.output_dir}/reconstruct.png')
|
474 |
+
|
475 |
+
# Interpolate the embedding
|
476 |
+
for num_inference_steps in args.num_inference_steps:
|
477 |
+
for alpha in args.alpha:
|
478 |
+
new_emb = alpha * orig_emb + (1 - alpha) * emb
|
479 |
+
|
480 |
+
with torch.autocast("cuda"):
|
481 |
+
image = pipeline(
|
482 |
+
image=edit_condition, prompt_embeds=new_emb, num_inference_steps=num_inference_steps, generator=generator
|
483 |
+
).images[0]
|
484 |
+
image.save(f'{args.output_dir}/image_{num_inference_steps}_{alpha}.png')
|
485 |
+
|
486 |
+
if args.save_unet:
|
487 |
+
print('Saving the unet model...')
|
488 |
+
unet.save_pretrained(f'{args.output_dir}/unet')
|
489 |
+
|
490 |
+
|
491 |
+
if __name__ == '__main__':
|
492 |
+
args = parse_args()
|
493 |
+
main(args)
|
generate.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse, os, time
|
2 |
+
import torch
|
3 |
+
from diffusers import (
|
4 |
+
AutoencoderKL,
|
5 |
+
ControlNetModel,
|
6 |
+
StableDiffusionControlNetPipeline,
|
7 |
+
UNet2DConditionModel,
|
8 |
+
UniPCMultistepScheduler,
|
9 |
+
PNDMScheduler,
|
10 |
+
AmusedPipeline, AmusedScheduler, VQModel, UVit2DModel
|
11 |
+
)
|
12 |
+
from transformers import AutoTokenizer, CLIPFeatureExtractor
|
13 |
+
from diffusers.pipelines.deprecated.alt_diffusion import RobertaSeriesModelWithTransformation
|
14 |
+
from diffusers.utils import load_image
|
15 |
+
from utils.mclip import *
|
16 |
+
|
17 |
+
|
18 |
+
def parse_args():
|
19 |
+
parser = argparse.ArgumentParser(description="Generate images with M3Face.")
|
20 |
+
parser.add_argument(
|
21 |
+
"--prompt",
|
22 |
+
type=str,
|
23 |
+
default="This attractive woman has narrow eyes, rosy cheeks, and wears heavy makeup.",
|
24 |
+
help="The input text prompt for image generation."
|
25 |
+
)
|
26 |
+
parser.add_argument(
|
27 |
+
"--condition",
|
28 |
+
type=str,
|
29 |
+
default="mask",
|
30 |
+
choices=["mask", "landmark"],
|
31 |
+
help="Use segmentation mask or facial landmarks for image generation."
|
32 |
+
)
|
33 |
+
parser.add_argument(
|
34 |
+
"--condition_path",
|
35 |
+
type=str,
|
36 |
+
default=None,
|
37 |
+
help="Path to the condition mask/landmark image. We will generate the condition if it is not given."
|
38 |
+
)
|
39 |
+
parser.add_argument("--save_condition", action="store_true", help="Save the generated condition image.")
|
40 |
+
parser.add_argument("--use_english", action="store_true", help="Use the English models.")
|
41 |
+
parser.add_argument("--enhance_prompt", action="store_true", help="Enhance the given text prompt.")
|
42 |
+
parser.add_argument("--num_inference_steps", type=int, default=30)
|
43 |
+
parser.add_argument("--num_samples", type=int, default=1)
|
44 |
+
parser.add_argument(
|
45 |
+
"--additional_prompt",
|
46 |
+
type=str,
|
47 |
+
default="rim lighting, dslr, ultra quality, sharp focus, dof, Fujifilm XT3, crystal clear, highly detailed glossy eyes, high detailed skin, skin pores, 8K UHD"
|
48 |
+
)
|
49 |
+
parser.add_argument(
|
50 |
+
"--negative_prompt",
|
51 |
+
type=str,
|
52 |
+
default="low quality, bad quality, worst quality, blurry, disfigured, ugly, immature, cartoon, painting"
|
53 |
+
)
|
54 |
+
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible generation.")
|
55 |
+
parser.add_argument(
|
56 |
+
"--output_dir",
|
57 |
+
type=str,
|
58 |
+
default="output/",
|
59 |
+
help="The output directory where the results will be written.",
|
60 |
+
)
|
61 |
+
args = parser.parse_args()
|
62 |
+
|
63 |
+
return args
|
64 |
+
|
65 |
+
def get_controlnet(args):
|
66 |
+
if args.use_english:
|
67 |
+
sd_model_name = 'runwayml/stable-diffusion-v1-5'
|
68 |
+
controlnet_model_name = 'm3face/ControlnetModels'
|
69 |
+
if args.condition == 'mask':
|
70 |
+
controlnet_revision = 'segmentation-english'
|
71 |
+
elif args.condition == 'landmark':
|
72 |
+
controlnet_revision = 'landmark-english'
|
73 |
+
controlnet = ControlNetModel.from_pretrained(controlnet_model_name, use_safetensors=True, revision=controlnet_revision)
|
74 |
+
pipeline = StableDiffusionControlNetPipeline.from_pretrained(
|
75 |
+
sd_model_name, controlnet=controlnet, use_safetensors=True, safety_checker=None
|
76 |
+
).to("cuda")
|
77 |
+
|
78 |
+
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
|
79 |
+
pipeline.enable_model_cpu_offload()
|
80 |
+
else:
|
81 |
+
sd_model_name = 'BAAI/AltDiffusion-m18'
|
82 |
+
controlnet_model_name = 'm3face/ControlnetModels'
|
83 |
+
if args.condition == 'mask':
|
84 |
+
controlnet_revision = 'segmentation-mlin'
|
85 |
+
elif args.condition == 'landmark':
|
86 |
+
controlnet_revision = 'landmark-mlin'
|
87 |
+
vae = AutoencoderKL.from_pretrained(sd_model_name, subfolder="vae")
|
88 |
+
unet = UNet2DConditionModel.from_pretrained(sd_model_name, subfolder="unet")
|
89 |
+
tokenizer = AutoTokenizer.from_pretrained(sd_model_name, subfolder="tokenizer", use_fast=False)
|
90 |
+
text_encoder = RobertaSeriesModelWithTransformation.from_pretrained(sd_model_name, subfolder="text_encoder")
|
91 |
+
controlnet = ControlNetModel.from_pretrained(controlnet_model_name, revision=controlnet_revision)
|
92 |
+
|
93 |
+
scheduler = PNDMScheduler.from_pretrained(
|
94 |
+
sd_model_name,
|
95 |
+
subfolder='scheduler',
|
96 |
+
)
|
97 |
+
scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
|
98 |
+
feature_extractor = CLIPFeatureExtractor.from_pretrained(
|
99 |
+
sd_model_name,
|
100 |
+
subfolder='feature_extractor',
|
101 |
+
)
|
102 |
+
pipeline = StableDiffusionControlNetPipeline(
|
103 |
+
vae=vae,
|
104 |
+
text_encoder=text_encoder,
|
105 |
+
tokenizer=tokenizer,
|
106 |
+
unet=unet,
|
107 |
+
controlnet=controlnet,
|
108 |
+
scheduler=scheduler,
|
109 |
+
safety_checker=None,
|
110 |
+
feature_extractor=feature_extractor,
|
111 |
+
).to('cuda')
|
112 |
+
|
113 |
+
return pipeline
|
114 |
+
|
115 |
+
|
116 |
+
def get_muse(args):
|
117 |
+
muse_model_name = 'm3face/FaceConditioning'
|
118 |
+
if args.condition == 'mask':
|
119 |
+
muse_revision = 'segmentation'
|
120 |
+
elif args.condition == 'landmark':
|
121 |
+
muse_revision = 'landmark'
|
122 |
+
scheduler = AmusedScheduler.from_pretrained(muse_model_name, revision=muse_revision, subfolder='scheduler')
|
123 |
+
vqvae = VQModel.from_pretrained(muse_model_name, revision=muse_revision, subfolder='vqvae')
|
124 |
+
uvit2 = UVit2DModel.from_pretrained(muse_model_name, revision=muse_revision, subfolder='transformer')
|
125 |
+
text_encoder = MultilingualCLIP.from_pretrained(muse_model_name, revision=muse_revision, subfolder='text_encoder')
|
126 |
+
tokenizer = AutoTokenizer.from_pretrained(muse_model_name, revision=muse_revision, subfolder='tokenizer')
|
127 |
+
|
128 |
+
pipeline = AmusedPipeline(
|
129 |
+
vqvae=vqvae,
|
130 |
+
tokenizer=tokenizer,
|
131 |
+
text_encoder=text_encoder,
|
132 |
+
transformer=uvit2,
|
133 |
+
scheduler=scheduler
|
134 |
+
).to("cuda")
|
135 |
+
|
136 |
+
return pipeline
|
137 |
+
|
138 |
+
|
139 |
+
if __name__ == '__main__':
|
140 |
+
args = parse_args()
|
141 |
+
|
142 |
+
# ========== set up face generation pipeline ==========
|
143 |
+
controlnet = get_controlnet(args)
|
144 |
+
|
145 |
+
# ========== set output directory ==========
|
146 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
147 |
+
|
148 |
+
# ========== set random seed ==========
|
149 |
+
if args.seed is None:
|
150 |
+
generator = None
|
151 |
+
else:
|
152 |
+
generator = torch.Generator().manual_seed(args.seed)
|
153 |
+
|
154 |
+
# ========== generation ==========
|
155 |
+
id = int(time.time())
|
156 |
+
if args.condition_path:
|
157 |
+
condition = load_image(args.condition_path).resize((512, 512))
|
158 |
+
else:
|
159 |
+
# generate condition
|
160 |
+
muse = get_muse(args)
|
161 |
+
if args.condition == 'mask':
|
162 |
+
muse_added_prompt = 'Generate face segmentation | '
|
163 |
+
elif args.condition == 'landmark':
|
164 |
+
muse_added_prompt = 'Generate face landmark | '
|
165 |
+
muse_prompt = muse_added_prompt + args.prompt
|
166 |
+
condition = muse(muse_prompt, num_inference_steps=256).images[0].resize((512, 512))
|
167 |
+
if args.save_condition:
|
168 |
+
condition.save(f'{args.output_dir}/{id}_condition.png')
|
169 |
+
|
170 |
+
latents = torch.randn((args.num_samples, 4, 64, 64), generator=generator)
|
171 |
+
prompt = f'{args.prompt}, {args.additional_prompt}' if args.prompt else args.additional_prompt
|
172 |
+
images = controlnet(prompt, image=condition, num_inference_steps=args.num_inference_steps, negative_prompt=args.negative_prompt,
|
173 |
+
generator=generator, latents=latents, num_images_per_prompt=args.num_samples).images
|
174 |
+
|
175 |
+
for i, image in enumerate(images):
|
176 |
+
image.save(f'{args.output_dir}/{id}_{i}.png')
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
diffusers
|
2 |
+
datasets
|
3 |
+
transformers
|
4 |
+
accelerate
|
5 |
+
xformers==0.0.21
|
6 |
+
face-alignment
|
7 |
+
gdown
|
utils/dml_csr/dml_csr.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@Author : Qingping Zheng
|
5 |
+
@Contact : qingpingzheng2014@gmail.com
|
6 |
+
@File : dml_csr.py
|
7 |
+
@Time : 10/01/21 00:00 PM
|
8 |
+
@Desc :
|
9 |
+
@License : Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
@Copyright : Copyright 2015 The Authors. All Rights Reserved.
|
11 |
+
"""
|
12 |
+
from __future__ import absolute_import
|
13 |
+
from __future__ import division
|
14 |
+
from __future__ import print_function
|
15 |
+
|
16 |
+
|
17 |
+
import torch.nn as nn
|
18 |
+
|
19 |
+
from torch.nn import functional as F
|
20 |
+
from inplace_abn import InPlaceABNSync
|
21 |
+
from .modules.ddgcn import DDualGCNHead
|
22 |
+
from .modules.parsing import Parsing
|
23 |
+
from .modules.edges import Edges
|
24 |
+
from .modules.util import Bottleneck
|
25 |
+
|
26 |
+
|
27 |
+
def conv3x3(in_planes, out_planes, stride=1):
|
28 |
+
"3x3 convolution with padding"
|
29 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
30 |
+
padding=1, bias=False)
|
31 |
+
|
32 |
+
|
33 |
+
class DML_CSR(nn.Module):
|
34 |
+
def __init__(self,
|
35 |
+
num_classes,
|
36 |
+
abn=InPlaceABNSync,
|
37 |
+
trained=True):
|
38 |
+
super().__init__()
|
39 |
+
self.inplanes = 128
|
40 |
+
self.is_trained = trained
|
41 |
+
|
42 |
+
self.conv1 = conv3x3(3, 64, stride=2)
|
43 |
+
self.bn1 = abn(64)
|
44 |
+
self.relu1 = nn.ReLU(inplace=False)
|
45 |
+
self.conv2 = conv3x3(64, 64)
|
46 |
+
self.bn2 = abn(64)
|
47 |
+
self.relu2 = nn.ReLU(inplace=False)
|
48 |
+
self.conv3 = conv3x3(64, 128)
|
49 |
+
self.bn3 = abn(128)
|
50 |
+
self.relu3 = nn.ReLU(inplace=False)
|
51 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
52 |
+
self.layers = [3, 4, 23, 3]
|
53 |
+
self.abn = abn
|
54 |
+
strides = [1, 2, 1, 1]
|
55 |
+
dilations = [1, 1, 1, 2]
|
56 |
+
|
57 |
+
self.layer1 = self._make_layer(Bottleneck, 64, self.layers[0], stride=strides[0], dilation=dilations[0])
|
58 |
+
self.layer2 = self._make_layer(Bottleneck, 128, self.layers[1], stride=strides[1], dilation=dilations[1])
|
59 |
+
self.layer3 = self._make_layer(Bottleneck, 256, self.layers[2], stride=strides[2], dilation=dilations[2])
|
60 |
+
self.layer4 = self._make_layer(Bottleneck, 512, self.layers[3], stride=strides[3], dilation=dilations[3], multi_grid=(1,1,1))
|
61 |
+
# Context Aware
|
62 |
+
self.context = DDualGCNHead(2048, 512, abn)
|
63 |
+
self.layer6 = Parsing(512, 256, num_classes, abn)
|
64 |
+
# edge
|
65 |
+
if self.is_trained:
|
66 |
+
self.edge_layer = Edges(abn, out_fea=num_classes)
|
67 |
+
|
68 |
+
def _make_layer(self, block, planes, blocks, stride=1, dilation=1, multi_grid=1):
|
69 |
+
downsample = None
|
70 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
71 |
+
downsample = nn.Sequential(
|
72 |
+
nn.Conv2d(self.inplanes, planes * block.expansion,
|
73 |
+
kernel_size=1, stride=stride, bias=False),
|
74 |
+
self.abn(planes * block.expansion, affine=True))
|
75 |
+
|
76 |
+
layers = []
|
77 |
+
generate_multi_grid = lambda index, grids: grids[index%len(grids)] if isinstance(grids, tuple) else 1
|
78 |
+
layers.append(block(self.inplanes, planes, stride, abn=self.abn, dilation=dilation, downsample=downsample, multi_grid=generate_multi_grid(0, multi_grid)))
|
79 |
+
self.inplanes = planes * block.expansion
|
80 |
+
for i in range(1, blocks):
|
81 |
+
layers.append(block(self.inplanes, planes, abn=self.abn, dilation=dilation, multi_grid=generate_multi_grid(i, multi_grid)))
|
82 |
+
|
83 |
+
return nn.Sequential(*layers)
|
84 |
+
|
85 |
+
def forward(self, x):
|
86 |
+
input = x
|
87 |
+
x = self.relu1(self.bn1(self.conv1(x)))
|
88 |
+
x = self.relu2(self.bn2(self.conv2(x)))
|
89 |
+
x1 = self.relu3(self.bn3(self.conv3(x)))
|
90 |
+
x = self.maxpool(x1)
|
91 |
+
x2 = self.layer1(x) # 119 x 119
|
92 |
+
x3 = self.layer2(x2) # 60 x 60
|
93 |
+
x4 = self.layer3(x3) # 60 x 60
|
94 |
+
x5 = self.layer4(x4) # 60 x 60
|
95 |
+
x = self.context(x5)
|
96 |
+
seg, x = self.layer6(x, x2)
|
97 |
+
|
98 |
+
if self.is_trained:
|
99 |
+
binary_edge, semantic_edge, edge_fea = self.edge_layer(x2,x3,x4)
|
100 |
+
return seg, binary_edge, semantic_edge
|
101 |
+
|
102 |
+
return seg
|
103 |
+
|
utils/dml_csr/modules/ddgcn.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@Author : Qingping Zheng
|
5 |
+
@Contact : qingpingzheng2014@gmail.com
|
6 |
+
@File : ddgcn.py
|
7 |
+
@Time : 10/01/21 00:00 PM
|
8 |
+
@Desc :
|
9 |
+
@License : Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
@Copyright : Copyright 2022 The Authors. All Rights Reserved.
|
11 |
+
"""
|
12 |
+
from __future__ import absolute_import
|
13 |
+
from __future__ import division
|
14 |
+
from __future__ import print_function
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
import torch.nn as nn
|
19 |
+
|
20 |
+
from inplace_abn import InPlaceABNSync
|
21 |
+
|
22 |
+
|
23 |
+
class SpatialGCN(nn.Module):
|
24 |
+
def __init__(self, plane, abn=InPlaceABNSync):
|
25 |
+
super(SpatialGCN, self).__init__()
|
26 |
+
inter_plane = plane // 2
|
27 |
+
self.node_k = nn.Conv2d(plane, inter_plane, kernel_size=1)
|
28 |
+
self.node_v = nn.Conv2d(plane, inter_plane, kernel_size=1)
|
29 |
+
self.node_q = nn.Conv2d(plane, inter_plane, kernel_size=1)
|
30 |
+
|
31 |
+
self.conv_wg = nn.Conv1d(inter_plane, inter_plane, kernel_size=1, bias=False)
|
32 |
+
self.bn_wg = nn.BatchNorm1d(inter_plane)
|
33 |
+
self.softmax = nn.Softmax(dim=2)
|
34 |
+
|
35 |
+
self.out = nn.Sequential(nn.Conv2d(inter_plane, plane, kernel_size=1),
|
36 |
+
abn(plane))
|
37 |
+
|
38 |
+
self.gamma = nn.Parameter(torch.zeros(1))
|
39 |
+
|
40 |
+
def forward(self, x):
|
41 |
+
# b, c, h, w = x.size()
|
42 |
+
node_k = self.node_k(x)
|
43 |
+
node_v = self.node_v(x)
|
44 |
+
node_q = self.node_q(x)
|
45 |
+
b,c,h,w = node_k.size()
|
46 |
+
node_k = node_k.view(b, c, -1).permute(0, 2, 1)
|
47 |
+
node_q = node_q.view(b, c, -1)
|
48 |
+
node_v = node_v.view(b, c, -1).permute(0, 2, 1)
|
49 |
+
# A = k * q
|
50 |
+
# AV = k * q * v
|
51 |
+
# AVW = k *(q *v) * w
|
52 |
+
AV = torch.bmm(node_q,node_v)
|
53 |
+
AV = self.softmax(AV)
|
54 |
+
AV = torch.bmm(node_k, AV)
|
55 |
+
AV = AV.transpose(1, 2).contiguous()
|
56 |
+
AVW = self.conv_wg(AV)
|
57 |
+
AVW = self.bn_wg(AVW)
|
58 |
+
AVW = AVW.view(b, c, h, -1)
|
59 |
+
# out = F.relu_(self.out(AVW) + x)
|
60 |
+
out = self.gamma * self.out(AVW) + x
|
61 |
+
return out
|
62 |
+
|
63 |
+
|
64 |
+
class DDualGCN(nn.Module):
|
65 |
+
"""
|
66 |
+
Feature GCN with coordinate GCN
|
67 |
+
"""
|
68 |
+
def __init__(self, planes, abn=InPlaceABNSync, ratio=4):
|
69 |
+
super(DDualGCN, self).__init__()
|
70 |
+
|
71 |
+
self.phi = nn.Conv2d(planes, planes // ratio * 2, kernel_size=1, bias=False)
|
72 |
+
self.bn_phi = abn(planes // ratio * 2)
|
73 |
+
self.theta = nn.Conv2d(planes, planes // ratio, kernel_size=1, bias=False)
|
74 |
+
self.bn_theta = abn(planes // ratio)
|
75 |
+
|
76 |
+
# Interaction Space
|
77 |
+
# Adjacency Matrix: (-)A_g
|
78 |
+
self.conv_adj = nn.Conv1d(planes // ratio, planes // ratio, kernel_size=1, bias=False)
|
79 |
+
self.bn_adj = nn.BatchNorm1d(planes // ratio)
|
80 |
+
|
81 |
+
# State Update Function: W_g
|
82 |
+
self.conv_wg = nn.Conv1d(planes // ratio * 2, planes // ratio * 2, kernel_size=1, bias=False)
|
83 |
+
self.bn_wg = nn.BatchNorm1d(planes // ratio * 2)
|
84 |
+
|
85 |
+
# last fc
|
86 |
+
self.conv3 = nn.Conv2d(planes // ratio * 2, planes, kernel_size=1, bias=False)
|
87 |
+
self.bn3 = abn(planes)
|
88 |
+
|
89 |
+
self.local = nn.Sequential(
|
90 |
+
nn.Conv2d(planes, planes, 3, groups=planes, stride=2, padding=1, bias=False),
|
91 |
+
abn(planes),
|
92 |
+
nn.Conv2d(planes, planes, 3, groups=planes, stride=2, padding=1, bias=False),
|
93 |
+
abn(planes),
|
94 |
+
nn.Conv2d(planes, planes, 3, groups=planes, stride=2, padding=1, bias=False),
|
95 |
+
abn(planes))
|
96 |
+
self.gcn_local_attention = SpatialGCN(planes, abn)
|
97 |
+
|
98 |
+
self.final = nn.Sequential(nn.Conv2d(planes * 2, planes, kernel_size=1, bias=False),
|
99 |
+
abn(planes))
|
100 |
+
|
101 |
+
self.gamma1 = nn.Parameter(torch.zeros(1))
|
102 |
+
|
103 |
+
def to_matrix(self, x):
|
104 |
+
n, c, h, w = x.size()
|
105 |
+
x = x.view(n, c, -1)
|
106 |
+
return x
|
107 |
+
|
108 |
+
def forward(self, feat):
|
109 |
+
# # # # Local # # # #
|
110 |
+
x = feat
|
111 |
+
local = self.local(feat)
|
112 |
+
local = self.gcn_local_attention(local)
|
113 |
+
local = F.interpolate(local, size=x.size()[2:], mode='bilinear', align_corners=True)
|
114 |
+
spatial_local_feat = x * local + x
|
115 |
+
|
116 |
+
# # # # Projection Space # # # #
|
117 |
+
x_sqz, b = x, x
|
118 |
+
|
119 |
+
x_sqz = self.phi(x_sqz)
|
120 |
+
x_sqz = self.bn_phi(x_sqz)
|
121 |
+
x_sqz = self.to_matrix(x_sqz)
|
122 |
+
|
123 |
+
b = self.theta(b)
|
124 |
+
b = self.bn_theta(b)
|
125 |
+
b = self.to_matrix(b)
|
126 |
+
|
127 |
+
# Project
|
128 |
+
z_idt = torch.matmul(x_sqz, b.transpose(1, 2)) # channel
|
129 |
+
|
130 |
+
# # # # Interaction Space # # # #
|
131 |
+
z = z_idt.transpose(1, 2).contiguous()
|
132 |
+
|
133 |
+
z = self.conv_adj(z)
|
134 |
+
z = self.bn_adj(z)
|
135 |
+
|
136 |
+
z = z.transpose(1, 2).contiguous()
|
137 |
+
# Laplacian smoothing: (I - A_g)Z => Z - A_gZ
|
138 |
+
z += z_idt
|
139 |
+
|
140 |
+
z = self.conv_wg(z)
|
141 |
+
z = self.bn_wg(z)
|
142 |
+
|
143 |
+
# # # # Re-projection Space # # # #
|
144 |
+
# Re-project
|
145 |
+
y = torch.matmul(z, b)
|
146 |
+
|
147 |
+
n, _, h, w = x.size()
|
148 |
+
y = y.view(n, -1, h, w)
|
149 |
+
|
150 |
+
y = self.conv3(y)
|
151 |
+
y = self.bn3(y)
|
152 |
+
|
153 |
+
# g_out = x + y
|
154 |
+
# g_out = F.relu_(x+y)
|
155 |
+
g_out = self.gamma1*y + x
|
156 |
+
|
157 |
+
# cat or sum, nearly the same results
|
158 |
+
out = self.final(torch.cat((spatial_local_feat, g_out), 1))
|
159 |
+
|
160 |
+
return out
|
161 |
+
|
162 |
+
|
163 |
+
class DDualGCNHead(nn.Module):
|
164 |
+
def __init__(self, inplanes, interplanes, abn=InPlaceABNSync):
|
165 |
+
super(DDualGCNHead, self).__init__()
|
166 |
+
self.conva = nn.Sequential(nn.Conv2d(inplanes, interplanes, 3, padding=1, bias=False),
|
167 |
+
abn(interplanes))
|
168 |
+
self.dualgcn = DDualGCN(interplanes, abn)
|
169 |
+
self.convb = nn.Sequential(nn.Conv2d(interplanes, interplanes, 3, padding=1, bias=False),
|
170 |
+
abn(interplanes))
|
171 |
+
|
172 |
+
self.bottleneck = nn.Sequential(
|
173 |
+
nn.Conv2d(inplanes + interplanes, interplanes, kernel_size=3, padding=1, dilation=1, bias=False),
|
174 |
+
abn(interplanes)
|
175 |
+
)
|
176 |
+
|
177 |
+
def forward(self, x):
|
178 |
+
output = self.conva(x)
|
179 |
+
output = self.dualgcn(output)
|
180 |
+
output = self.convb(output)
|
181 |
+
output = self.bottleneck(torch.cat([x, output], 1))
|
182 |
+
return output
|
utils/dml_csr/modules/edges.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@Author : Qingping Zheng
|
5 |
+
@Contact : qingpingzheng2014@gmail.com
|
6 |
+
@File : edges.py
|
7 |
+
@Time : 10/01/21 00:00 PM
|
8 |
+
@Desc :
|
9 |
+
@License : Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
@Copyright : Copyright 2022 The Authors. All Rights Reserved.
|
11 |
+
"""
|
12 |
+
from __future__ import absolute_import
|
13 |
+
from __future__ import division
|
14 |
+
from __future__ import print_function
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
import torch.nn as nn
|
19 |
+
|
20 |
+
from inplace_abn import InPlaceABNSync
|
21 |
+
|
22 |
+
|
23 |
+
class Edges(nn.Module):
|
24 |
+
|
25 |
+
def __init__(self, abn=InPlaceABNSync, in_fea=[256,512,1024], mid_fea=256, out_fea=2):
|
26 |
+
super(Edges, self).__init__()
|
27 |
+
|
28 |
+
self.conv1 = nn.Sequential(
|
29 |
+
nn.Conv2d(in_fea[0], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False),
|
30 |
+
abn(mid_fea)
|
31 |
+
)
|
32 |
+
self.conv2 = nn.Sequential(
|
33 |
+
nn.Conv2d(in_fea[1], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False),
|
34 |
+
abn(mid_fea)
|
35 |
+
)
|
36 |
+
self.conv3 = nn.Sequential(
|
37 |
+
nn.Conv2d(in_fea[2], mid_fea, kernel_size=1, padding=0, dilation=1, bias=False),
|
38 |
+
abn(mid_fea)
|
39 |
+
)
|
40 |
+
self.conv4 = nn.Conv2d(mid_fea,out_fea, kernel_size=3, padding=1, dilation=1, bias=True)
|
41 |
+
self.conv5_b = nn.Conv2d(out_fea*3,2, kernel_size=1, padding=0, dilation=1, bias=True)
|
42 |
+
self.conv5 = nn.Conv2d(out_fea*3,out_fea, kernel_size=1, padding=0, dilation=1, bias=True)
|
43 |
+
|
44 |
+
|
45 |
+
def forward(self, x1, x2, x3):
|
46 |
+
_, _, h, w = x1.size()
|
47 |
+
|
48 |
+
edge1_fea = self.conv1(x1)
|
49 |
+
edge1 = self.conv4(edge1_fea)
|
50 |
+
edge2_fea = self.conv2(x2)
|
51 |
+
edge2 = self.conv4(edge2_fea)
|
52 |
+
edge3_fea = self.conv3(x3)
|
53 |
+
edge3 = self.conv4(edge3_fea)
|
54 |
+
|
55 |
+
edge2_fea = F.interpolate(edge2_fea, size=(h, w), mode='bilinear',align_corners=True)
|
56 |
+
edge3_fea = F.interpolate(edge3_fea, size=(h, w), mode='bilinear',align_corners=True)
|
57 |
+
edge2 = F.interpolate(edge2, size=(h, w), mode='bilinear',align_corners=True)
|
58 |
+
edge3 = F.interpolate(edge3, size=(h, w), mode='bilinear',align_corners=True)
|
59 |
+
|
60 |
+
edge = torch.cat([edge1, edge2, edge3], dim=1)
|
61 |
+
edge_fea = torch.cat([edge1_fea, edge2_fea, edge3_fea], dim=1)
|
62 |
+
semantic_edge = self.conv5(edge)
|
63 |
+
binary_edge = self.conv5_b(edge)
|
64 |
+
|
65 |
+
return binary_edge, semantic_edge, edge_fea
|
66 |
+
|
utils/dml_csr/modules/parsing.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@Author : Qingping Zheng
|
5 |
+
@Contact : qingpingzheng2014@gmail.com
|
6 |
+
@File : parsing.py
|
7 |
+
@Time : 10/01/21 00:00 PM
|
8 |
+
@Desc :
|
9 |
+
@License : Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
@Copyright : Copyright 2022 The Authors. All Rights Reserved.
|
11 |
+
"""
|
12 |
+
from __future__ import absolute_import
|
13 |
+
from __future__ import division
|
14 |
+
from __future__ import print_function
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
import torch.nn as nn
|
19 |
+
|
20 |
+
from inplace_abn import InPlaceABNSync
|
21 |
+
|
22 |
+
|
23 |
+
class Parsing(nn.Module):
|
24 |
+
def __init__(self, in_plane1, in_plane2, num_classes, abn=InPlaceABNSync):
|
25 |
+
super(Parsing, self).__init__()
|
26 |
+
self.conv1 = nn.Sequential(
|
27 |
+
nn.Conv2d(in_plane1, 256, kernel_size=1, padding=0, dilation=1, bias=False),
|
28 |
+
abn(256)
|
29 |
+
)
|
30 |
+
self.conv2 = nn.Sequential(
|
31 |
+
nn.Conv2d(in_plane2, 48, kernel_size=1, stride=1, padding=0, dilation=1, bias=False),
|
32 |
+
abn(48)
|
33 |
+
)
|
34 |
+
self.conv3 = nn.Sequential(
|
35 |
+
nn.Conv2d(304, 256, kernel_size=1, padding=0, dilation=1, bias=False),
|
36 |
+
abn(256),
|
37 |
+
nn.Conv2d(256, 256, kernel_size=1, padding=0, dilation=1, bias=False),
|
38 |
+
abn(256)
|
39 |
+
)
|
40 |
+
self.conv4 = nn.Conv2d(256, num_classes, kernel_size=1, padding=0, dilation=1, bias=True)
|
41 |
+
|
42 |
+
def forward(self, xt, xl):
|
43 |
+
_, _, h, w = xl.size()
|
44 |
+
|
45 |
+
xt = F.interpolate(self.conv1(xt), size=(h, w), mode='bilinear', align_corners=True)
|
46 |
+
xl = self.conv2(xl)
|
47 |
+
x = torch.cat([xt, xl], dim=1)
|
48 |
+
x = self.conv3(x)
|
49 |
+
seg = self.conv4(x)
|
50 |
+
return seg, x
|
51 |
+
|
utils/dml_csr/modules/util.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@Author : Qingping Zheng
|
5 |
+
@Contact : qingpingzheng2014@gmail.com
|
6 |
+
@File : util.py
|
7 |
+
@Time : 10/01/21 00:00 PM
|
8 |
+
@Desc :
|
9 |
+
@License : Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
@Copyright : Copyright 2022 The Authors. All Rights Reserved.
|
11 |
+
"""
|
12 |
+
from __future__ import absolute_import
|
13 |
+
from __future__ import division
|
14 |
+
from __future__ import print_function
|
15 |
+
|
16 |
+
import torch.nn as nn
|
17 |
+
|
18 |
+
from inplace_abn import InPlaceABNSync
|
19 |
+
|
20 |
+
|
21 |
+
class Bottleneck(nn.Module):
|
22 |
+
expansion = 4
|
23 |
+
def __init__(self, inplanes, planes, stride=1, abn=InPlaceABNSync, dilation=1, downsample=None, fist_dilation=1, multi_grid=1):
|
24 |
+
super(Bottleneck, self).__init__()
|
25 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
26 |
+
self.bn1 = abn(planes)
|
27 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
|
28 |
+
padding=dilation*multi_grid, dilation=dilation*multi_grid, bias=False)
|
29 |
+
self.bn2 = abn(planes)
|
30 |
+
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
31 |
+
self.bn3 = abn(planes * 4)
|
32 |
+
self.relu = nn.ReLU(inplace=False)
|
33 |
+
self.relu_inplace = nn.ReLU(inplace=True)
|
34 |
+
self.downsample = downsample
|
35 |
+
self.dilation = dilation
|
36 |
+
self.stride = stride
|
37 |
+
|
38 |
+
def forward(self, x):
|
39 |
+
residual = x
|
40 |
+
|
41 |
+
out = self.conv1(x)
|
42 |
+
out = self.bn1(out)
|
43 |
+
out = self.relu(out)
|
44 |
+
|
45 |
+
out = self.conv2(out)
|
46 |
+
out = self.bn2(out)
|
47 |
+
out = self.relu(out)
|
48 |
+
|
49 |
+
out = self.conv3(out)
|
50 |
+
out = self.bn3(out)
|
51 |
+
|
52 |
+
if self.downsample is not None:
|
53 |
+
residual = self.downsample(x)
|
54 |
+
|
55 |
+
out = out + residual
|
56 |
+
out = self.relu_inplace(out)
|
57 |
+
|
58 |
+
return out
|
utils/dml_csr/transforms.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
@Author : Qingping Zheng
|
5 |
+
@Contact : qingpingzheng2014@gmail.com
|
6 |
+
@File : transforms.py
|
7 |
+
@Time : 10/01/21 00:00 PM
|
8 |
+
@Desc :
|
9 |
+
@License : Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
@Copyright : Copyright 2022 The Authors. All Rights Reserved.
|
11 |
+
"""
|
12 |
+
from __future__ import absolute_import
|
13 |
+
from __future__ import division
|
14 |
+
from __future__ import print_function
|
15 |
+
|
16 |
+
|
17 |
+
import numpy as np
|
18 |
+
import cv2
|
19 |
+
|
20 |
+
|
21 |
+
def flip_back(output_flipped, matched_parts):
|
22 |
+
'''
|
23 |
+
ouput_flipped: numpy.ndarray(batch_size, num_joints, height, width)
|
24 |
+
'''
|
25 |
+
assert output_flipped.ndim == 4,\
|
26 |
+
'output_flipped should be [batch_size, num_joints, height, width]'
|
27 |
+
|
28 |
+
output_flipped = output_flipped[:, :, :, ::-1]
|
29 |
+
|
30 |
+
for pair in matched_parts:
|
31 |
+
tmp = output_flipped[:, pair[0], :, :].copy()
|
32 |
+
output_flipped[:, pair[0], :, :] = output_flipped[:, pair[1], :, :]
|
33 |
+
output_flipped[:, pair[1], :, :] = tmp
|
34 |
+
|
35 |
+
return output_flipped
|
36 |
+
|
37 |
+
|
38 |
+
def transform_parsing(pred, center, scale, width, height, input_size):
|
39 |
+
|
40 |
+
if center is not None:
|
41 |
+
trans = get_affine_transform(center, scale, 0, input_size, inv=1)
|
42 |
+
target_pred = cv2.warpAffine(
|
43 |
+
pred,
|
44 |
+
trans,
|
45 |
+
(int(width), int(height)), #(int(width), int(height)),
|
46 |
+
flags=cv2.INTER_NEAREST,
|
47 |
+
borderMode=cv2.BORDER_CONSTANT,
|
48 |
+
borderValue=(0))
|
49 |
+
else:
|
50 |
+
target_pred = cv2.resize(pred, (int(width), int(height)), interpolation=cv2.INTER_NEAREST)
|
51 |
+
|
52 |
+
return target_pred
|
53 |
+
|
54 |
+
|
55 |
+
def get_affine_transform(center,
|
56 |
+
scale,
|
57 |
+
rot,
|
58 |
+
output_size,
|
59 |
+
shift=np.array([0, 0], dtype=np.float32),
|
60 |
+
inv=0):
|
61 |
+
if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
|
62 |
+
print(scale)
|
63 |
+
scale = np.array([scale, scale])
|
64 |
+
|
65 |
+
scale_tmp = scale
|
66 |
+
|
67 |
+
src_w = scale_tmp[0]
|
68 |
+
dst_w = output_size[1]
|
69 |
+
dst_h = output_size[0]
|
70 |
+
|
71 |
+
rot_rad = np.pi * rot / 180
|
72 |
+
src_dir = get_dir([0, src_w * -0.5], rot_rad)
|
73 |
+
dst_dir = np.array([0, dst_w * -0.5], np.float32)
|
74 |
+
|
75 |
+
src = np.zeros((3, 2), dtype=np.float32)
|
76 |
+
dst = np.zeros((3, 2), dtype=np.float32)
|
77 |
+
src[0, :] = center + scale_tmp * shift
|
78 |
+
src[1, :] = center + src_dir + scale_tmp * shift
|
79 |
+
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
|
80 |
+
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
|
81 |
+
|
82 |
+
src[2:, :] = get_3rd_point(src[0, :], src[1, :])
|
83 |
+
dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])
|
84 |
+
|
85 |
+
if inv:
|
86 |
+
trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
|
87 |
+
else:
|
88 |
+
trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
|
89 |
+
|
90 |
+
return trans
|
91 |
+
|
92 |
+
|
93 |
+
def affine_transform(pt, t):
|
94 |
+
new_pt = np.array([pt[0], pt[1], 1.]).T
|
95 |
+
new_pt = np.dot(t, new_pt)
|
96 |
+
return new_pt[:2]
|
97 |
+
|
98 |
+
|
99 |
+
def get_3rd_point(a, b):
|
100 |
+
direct = a - b
|
101 |
+
return b + np.array([-direct[1], direct[0]], dtype=np.float32)
|
102 |
+
|
103 |
+
|
104 |
+
def get_dir(src_point, rot_rad):
|
105 |
+
sn, cs = np.sin(rot_rad), np.cos(rot_rad)
|
106 |
+
|
107 |
+
src_result = [0, 0]
|
108 |
+
src_result[0] = src_point[0] * cs - src_point[1] * sn
|
109 |
+
src_result[1] = src_point[0] * sn + src_point[1] * cs
|
110 |
+
|
111 |
+
return src_result
|
112 |
+
|
113 |
+
|
114 |
+
def crop(img, center, scale, output_size, rot=0):
|
115 |
+
trans = get_affine_transform(center, scale, rot, output_size)
|
116 |
+
|
117 |
+
dst_img = cv2.warpAffine(img,
|
118 |
+
trans,
|
119 |
+
(int(output_size[1]), int(output_size[0])),
|
120 |
+
flags=cv2.INTER_LINEAR)
|
121 |
+
|
122 |
+
return dst_img
|
utils/mclip.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import transformers
|
3 |
+
from typing import Union, Optional, Tuple
|
4 |
+
from transformers import AutoConfig, AutoModel
|
5 |
+
from transformers.models.clip.modeling_clip import CLIPTextModelOutput
|
6 |
+
|
7 |
+
|
8 |
+
class MCLIPConfig(transformers.PretrainedConfig):
|
9 |
+
model_type = "M-CLIP"
|
10 |
+
|
11 |
+
def __init__(self, modelBase='xlm-roberta-large', transformerDimSize=1024, imageDimSize=768, **kwargs):
|
12 |
+
self.transformerDimensions = transformerDimSize
|
13 |
+
self.numDims = imageDimSize
|
14 |
+
self.modelBase = modelBase
|
15 |
+
super().__init__(**kwargs)
|
16 |
+
|
17 |
+
|
18 |
+
|
19 |
+
class MultilingualCLIP(transformers.PreTrainedModel):
|
20 |
+
config_class = MCLIPConfig
|
21 |
+
|
22 |
+
def __init__(self, config, *args, **kwargs):
|
23 |
+
super().__init__(config, *args, **kwargs)
|
24 |
+
self.transformer = transformers.AutoModel.from_pretrained(config.modelBase)
|
25 |
+
self.LinearTransformation = torch.nn.Linear(in_features=config.transformerDimensions,
|
26 |
+
out_features=config.numDims)
|
27 |
+
|
28 |
+
def forward(
|
29 |
+
self,
|
30 |
+
input_ids: Optional[torch.Tensor] = None,
|
31 |
+
attention_mask: Optional[torch.Tensor] = None,
|
32 |
+
position_ids: Optional[torch.Tensor] = None,
|
33 |
+
output_attentions: Optional[bool] = None,
|
34 |
+
output_hidden_states: Optional[bool] = None,
|
35 |
+
return_dict: Optional[bool] = None,
|
36 |
+
) -> Union[Tuple, CLIPTextModelOutput]:
|
37 |
+
|
38 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
39 |
+
|
40 |
+
text_outputs = self.transformer(
|
41 |
+
input_ids=input_ids,
|
42 |
+
attention_mask=attention_mask,
|
43 |
+
position_ids=position_ids,
|
44 |
+
output_attentions=output_attentions,
|
45 |
+
output_hidden_states=output_hidden_states,
|
46 |
+
return_dict=return_dict,
|
47 |
+
)
|
48 |
+
|
49 |
+
pooled_output = text_outputs[1]
|
50 |
+
|
51 |
+
text_embeds = self.LinearTransformation(pooled_output)
|
52 |
+
|
53 |
+
if not return_dict:
|
54 |
+
outputs = (text_embeds, text_outputs[0]) + text_outputs[2:]
|
55 |
+
return tuple(output for output in outputs if output is not None)
|
56 |
+
|
57 |
+
return CLIPTextModelOutput(
|
58 |
+
text_embeds=text_embeds,
|
59 |
+
last_hidden_state=text_outputs.last_hidden_state,
|
60 |
+
hidden_states=text_outputs.hidden_states,
|
61 |
+
attentions=text_outputs.attentions,
|
62 |
+
)
|
63 |
+
|
64 |
+
@classmethod
|
65 |
+
def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or_path, _fast_init=True):
|
66 |
+
model.load_state_dict(state_dict)
|
67 |
+
return model, [], [], []
|
68 |
+
|
69 |
+
AutoConfig.register("M-CLIP", MCLIPConfig)
|
70 |
+
AutoModel.register(MCLIPConfig, MultilingualCLIP)
|
utils/plot_landmark.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import PIL
|
3 |
+
import cv2
|
4 |
+
import pickle
|
5 |
+
import argparse
|
6 |
+
import numpy as np
|
7 |
+
import face_alignment
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
import matplotlib.patches as patches
|
10 |
+
from matplotlib.path import Path
|
11 |
+
|
12 |
+
|
13 |
+
def parse_args():
|
14 |
+
parser = argparse.ArgumentParser(description="Plot facial landmarks from an image.")
|
15 |
+
parser.add_argument(
|
16 |
+
"--image_path",
|
17 |
+
type=str,
|
18 |
+
default=None,
|
19 |
+
help="Path to the image file."
|
20 |
+
)
|
21 |
+
parser.add_argument("--size", type=int, default=512)
|
22 |
+
parser.add_argument("--crop", action="store_true", help="Crop around the face image.")
|
23 |
+
parser.add_argument(
|
24 |
+
"--output_dir",
|
25 |
+
type=str,
|
26 |
+
default="output/landmarks/",
|
27 |
+
help="Folder to save landmark images."
|
28 |
+
)
|
29 |
+
args = parser.parse_args()
|
30 |
+
|
31 |
+
return args
|
32 |
+
|
33 |
+
def get_patch(landmarks, color='lime', closed=False):
|
34 |
+
contour = landmarks
|
35 |
+
ops = [Path.MOVETO] + [Path.LINETO]*(len(contour)-1)
|
36 |
+
facecolor = (0, 0, 0, 0) # Transparent fill color, if open
|
37 |
+
if closed:
|
38 |
+
contour.append(contour[0])
|
39 |
+
ops.append(Path.CLOSEPOLY)
|
40 |
+
facecolor = color
|
41 |
+
path = Path(contour, ops)
|
42 |
+
return patches.PathPatch(path, facecolor=facecolor, edgecolor=color, lw=4)
|
43 |
+
|
44 |
+
def bbox_from_landmarks(landmarks):
|
45 |
+
landmarks_x, landmarks_y = zip(*landmarks)
|
46 |
+
|
47 |
+
x_min, x_max = min(landmarks_x), max(landmarks_x)
|
48 |
+
y_min, y_max = min(landmarks_y), max(landmarks_y)
|
49 |
+
width = x_max - x_min
|
50 |
+
height = y_max - y_min
|
51 |
+
|
52 |
+
# Give it a little room; I think it works anyway
|
53 |
+
x_min -= 25
|
54 |
+
y_min -= 25
|
55 |
+
width += 50
|
56 |
+
height += 50
|
57 |
+
bbox = (x_min, y_min, width, height)
|
58 |
+
return bbox
|
59 |
+
|
60 |
+
def plot_landmarks(landmarks, crop=False, size=512):
|
61 |
+
if crop:
|
62 |
+
(x_min, y_min, width, height) = bbox_from_landmarks(landmarks)
|
63 |
+
# print(x_min, y_min, width, height)
|
64 |
+
landmarks_np = np.array(landmarks)
|
65 |
+
landmarks_np[:, 0] = (landmarks_np[:, 0] - x_min) * size / width
|
66 |
+
landmarks_np[:, 1] = (landmarks_np[:, 1] - y_min) * size / height
|
67 |
+
landmarks = landmarks_np.tolist()
|
68 |
+
# Precisely control output image size
|
69 |
+
dpi = 72
|
70 |
+
fig, ax = plt.subplots(1, figsize=[size/dpi, size/dpi], tight_layout={'pad':0})
|
71 |
+
fig.set_dpi(dpi)
|
72 |
+
|
73 |
+
black = np.zeros((size, size, 3))
|
74 |
+
ax.imshow(black)
|
75 |
+
|
76 |
+
face_patch = get_patch(landmarks[0:17])
|
77 |
+
l_eyebrow = get_patch(landmarks[17:22], color='yellow')
|
78 |
+
r_eyebrow = get_patch(landmarks[22:27], color='yellow')
|
79 |
+
nose_v = get_patch(landmarks[27:31], color='orange')
|
80 |
+
nose_h = get_patch(landmarks[31:36], color='orange')
|
81 |
+
l_eye = get_patch(landmarks[36:42], color='magenta', closed=True)
|
82 |
+
r_eye = get_patch(landmarks[42:48], color='magenta', closed=True)
|
83 |
+
outer_lips = get_patch(landmarks[48:60], color='cyan', closed=True)
|
84 |
+
inner_lips = get_patch(landmarks[60:68], color='blue', closed=True)
|
85 |
+
|
86 |
+
ax.add_patch(face_patch)
|
87 |
+
ax.add_patch(l_eyebrow)
|
88 |
+
ax.add_patch(r_eyebrow)
|
89 |
+
ax.add_patch(nose_v)
|
90 |
+
ax.add_patch(nose_h)
|
91 |
+
ax.add_patch(l_eye)
|
92 |
+
ax.add_patch(r_eye)
|
93 |
+
ax.add_patch(outer_lips)
|
94 |
+
ax.add_patch(inner_lips)
|
95 |
+
|
96 |
+
plt.axis('off')
|
97 |
+
|
98 |
+
fig.canvas.draw()
|
99 |
+
buffer, (width, height) = fig.canvas.print_to_buffer()
|
100 |
+
assert width == height
|
101 |
+
assert width == size
|
102 |
+
|
103 |
+
buffer = np.frombuffer(buffer, np.uint8).reshape((height, width, 4))
|
104 |
+
buffer = buffer[:, :, 0:3]
|
105 |
+
plt.close(fig)
|
106 |
+
return PIL.Image.fromarray(buffer)
|
107 |
+
|
108 |
+
def get_landmarks(image):
|
109 |
+
fa = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, flip_input=False, face_detector='sfd')
|
110 |
+
faces = fa.get_landmarks_from_image(image)
|
111 |
+
if faces is None or len(faces) == 0:
|
112 |
+
return None
|
113 |
+
landmarks = faces[0]
|
114 |
+
return landmarks
|
115 |
+
|
116 |
+
def save_landmarks(args):
|
117 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
118 |
+
|
119 |
+
image_name = os.path.basename(args.image_path)
|
120 |
+
image = cv2.imread(args.image_path)
|
121 |
+
image = cv2.resize(image, (args.size, args.size))
|
122 |
+
landmarks = get_landmarks(image)
|
123 |
+
if landmarks is None:
|
124 |
+
print(f'No faces found in {image_name}')
|
125 |
+
return
|
126 |
+
|
127 |
+
filename = f'{args.output_dir}/{image_name}'
|
128 |
+
if args.crop:
|
129 |
+
landmarks_cropped_image = plot_landmarks(landmarks.tolist(), crop=True, size=args.size)
|
130 |
+
landmarks_cropped_image.save(filename)
|
131 |
+
else:
|
132 |
+
landmarks_image = plot_landmarks(landmarks.tolist(), size=args.size)
|
133 |
+
landmarks_image.save(filename)
|
134 |
+
print(f'Landmark saved in {filename}')
|
135 |
+
|
136 |
+
if __name__ == '__main__':
|
137 |
+
args = parse_args()
|
138 |
+
save_landmarks(args)
|
utils/plot_mask.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import gdown
|
4 |
+
import shutil
|
5 |
+
import argparse
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.backends.cudnn as cudnn
|
9 |
+
import torchvision.transforms as transforms
|
10 |
+
from torchvision.utils import save_image
|
11 |
+
|
12 |
+
from inplace_abn import InPlaceABN
|
13 |
+
from dml_csr import dml_csr
|
14 |
+
from dml_csr import transforms as dml_transforms
|
15 |
+
|
16 |
+
|
17 |
+
def parse_args():
|
18 |
+
parser = argparse.ArgumentParser(description="Plot segmentation mask of an image.")
|
19 |
+
parser.add_argument(
|
20 |
+
"--image_path",
|
21 |
+
type=str,
|
22 |
+
default=None,
|
23 |
+
help="Path to the image file."
|
24 |
+
)
|
25 |
+
parser.add_argument("--size", type=int, default=512)
|
26 |
+
parser.add_argument(
|
27 |
+
"--checkpoint_path",
|
28 |
+
type=str,
|
29 |
+
default='ckpt/DML_CSR/dml_csr_celebA.pth',
|
30 |
+
help="Path to the DML-CSR pretrained model."
|
31 |
+
)
|
32 |
+
parser.add_argument(
|
33 |
+
"--output_dir",
|
34 |
+
type=str,
|
35 |
+
default="output/masks/",
|
36 |
+
help="Folder to save segmentation mask."
|
37 |
+
)
|
38 |
+
args = parser.parse_args()
|
39 |
+
|
40 |
+
return args
|
41 |
+
|
42 |
+
def download_checkpoint():
|
43 |
+
os.makedirs('ckpt', exist_ok=True)
|
44 |
+
id = "1xttWuAj633-ujp_vcm5DtL98PP0b-sUm"
|
45 |
+
gdown.download(id=id, output='ckpt/DML_CSR.zip')
|
46 |
+
shutil.unpack_archive('ckpt/DML_CSR.zip', 'ckpt')
|
47 |
+
os.remove('ckpt/DML_CSR.zip')
|
48 |
+
|
49 |
+
def box2cs(box: list) -> tuple:
|
50 |
+
x, y, w, h = box[:4]
|
51 |
+
return xywh2cs(x, y, w, h)
|
52 |
+
|
53 |
+
def xywh2cs(x: float, y: float, w: float, h: float) -> tuple:
|
54 |
+
center = np.zeros((2), dtype=np.float32)
|
55 |
+
center[0] = x + w * 0.5
|
56 |
+
center[1] = y + h * 0.5
|
57 |
+
if w > h:
|
58 |
+
h = w
|
59 |
+
elif w < h:
|
60 |
+
w = h
|
61 |
+
scale = np.array([w * 1.0, h * 1.0], dtype=np.float32)
|
62 |
+
|
63 |
+
return center, scale
|
64 |
+
|
65 |
+
def labelcolormap(N):
|
66 |
+
if N == 19: # CelebAMask-HQ
|
67 |
+
cmap = np.array([(0, 0, 0), (204, 0, 0), (76, 153, 0),
|
68 |
+
(204, 204, 0), (204, 0, 204), (204, 0, 204), (255, 204, 204),
|
69 |
+
(255, 204, 204), (102, 51, 0), (102, 51, 0), (102, 204, 0),
|
70 |
+
(255, 255, 0), (0, 0, 153), (0, 0, 204), (255, 51, 153),
|
71 |
+
(0, 204, 204), (0, 51, 0), (255, 153, 51), (0, 204, 0)],
|
72 |
+
dtype=np.uint8)
|
73 |
+
else:
|
74 |
+
def uint82bin(n, count=8):
|
75 |
+
"""returns the binary of integer n, count refers to amount of bits"""
|
76 |
+
return ''.join([str((n >> y) & 1) for y in range(count-1, -1, -1)])
|
77 |
+
|
78 |
+
cmap = np.zeros((N, 3), dtype=np.uint8)
|
79 |
+
for i in range(N):
|
80 |
+
r, g, b = 0, 0, 0
|
81 |
+
id = i
|
82 |
+
for j in range(7):
|
83 |
+
str_id = uint82bin(id)
|
84 |
+
r = r ^ (np.uint8(str_id[-1]) << (7-j))
|
85 |
+
g = g ^ (np.uint8(str_id[-2]) << (7-j))
|
86 |
+
b = b ^ (np.uint8(str_id[-3]) << (7-j))
|
87 |
+
id = id >> 3
|
88 |
+
cmap[i, 0] = r
|
89 |
+
cmap[i, 1] = g
|
90 |
+
cmap[i, 2] = b
|
91 |
+
return cmap
|
92 |
+
|
93 |
+
class Colorize(object):
|
94 |
+
def __init__(self, n=19):
|
95 |
+
self.cmap = labelcolormap(n)
|
96 |
+
self.cmap = torch.from_numpy(self.cmap[:n])
|
97 |
+
|
98 |
+
def __call__(self, gray_image):
|
99 |
+
size = gray_image.size()
|
100 |
+
color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)
|
101 |
+
|
102 |
+
for label in range(0, len(self.cmap)):
|
103 |
+
mask = (label == gray_image[0]).cpu()
|
104 |
+
color_image[0][mask] = self.cmap[label][0]
|
105 |
+
color_image[1][mask] = self.cmap[label][1]
|
106 |
+
color_image[2][mask] = self.cmap[label][2]
|
107 |
+
|
108 |
+
return color_image
|
109 |
+
|
110 |
+
def tensor2label(label_tensor, n_label):
|
111 |
+
label_tensor = label_tensor.cpu().float()
|
112 |
+
if label_tensor.size()[0] > 1:
|
113 |
+
label_tensor = label_tensor.max(0, keepdim=True)[1]
|
114 |
+
label_tensor = Colorize(n_label)(label_tensor)
|
115 |
+
#label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0))
|
116 |
+
label_numpy = label_tensor.numpy()
|
117 |
+
label_numpy = label_numpy / 255.0
|
118 |
+
|
119 |
+
return label_numpy
|
120 |
+
|
121 |
+
def generate_label(inputs, imsize):
|
122 |
+
pred_batch = []
|
123 |
+
for input in inputs:
|
124 |
+
input = input.view(1, 19, imsize, imsize)
|
125 |
+
pred = np.squeeze(input.data.max(1)[1].cpu().numpy(), axis=0)
|
126 |
+
pred_batch.append(pred)
|
127 |
+
|
128 |
+
pred_batch = np.array(pred_batch)
|
129 |
+
pred_batch = torch.from_numpy(pred_batch)
|
130 |
+
|
131 |
+
label_batch = []
|
132 |
+
for p in pred_batch:
|
133 |
+
p = p.view(1, imsize, imsize)
|
134 |
+
label_batch.append(tensor2label(p, 19))
|
135 |
+
|
136 |
+
label_batch = np.array(label_batch)
|
137 |
+
label_batch = torch.from_numpy(label_batch)
|
138 |
+
|
139 |
+
return label_batch
|
140 |
+
|
141 |
+
def get_mask(model, image, input_size):
|
142 |
+
interp = torch.nn.Upsample(size=input_size, mode='bilinear', align_corners=True)
|
143 |
+
|
144 |
+
image = image.unsqueeze(0)
|
145 |
+
with torch.no_grad():
|
146 |
+
outputs = model(image.cuda())
|
147 |
+
labels = generate_label(interp(outputs), input_size[0])
|
148 |
+
return labels[0]
|
149 |
+
|
150 |
+
def save_mask(args):
|
151 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
152 |
+
|
153 |
+
cudnn.benchmark = True
|
154 |
+
cudnn.enabled = True
|
155 |
+
|
156 |
+
model = dml_csr.DML_CSR(19, InPlaceABN, False)
|
157 |
+
|
158 |
+
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
159 |
+
std=[0.229, 0.224, 0.225])
|
160 |
+
transform = transforms.Compose([transforms.ToTensor(), normalize])
|
161 |
+
|
162 |
+
input_size = (args.size, args.size)
|
163 |
+
image = cv2.imread(args.image_path, cv2.IMREAD_COLOR)
|
164 |
+
h, w, _ = image.shape
|
165 |
+
center, s = box2cs([0, 0, w - 1, h - 1])
|
166 |
+
r = 0
|
167 |
+
crop_size = np.asarray(input_size)
|
168 |
+
trans = dml_transforms.get_affine_transform(center, s, r, crop_size)
|
169 |
+
image = cv2.warpAffine(image, trans, (int(crop_size[1]), int(crop_size[0])),
|
170 |
+
flags=cv2.INTER_LINEAR,
|
171 |
+
borderMode=cv2.BORDER_CONSTANT,
|
172 |
+
borderValue=(0, 0, 0))
|
173 |
+
image = transform(image)
|
174 |
+
|
175 |
+
if not os.path.exists(args.checkpoint_path):
|
176 |
+
download_checkpoint()
|
177 |
+
state_dict = torch.load(args.checkpoint_path, map_location='cuda:0')
|
178 |
+
model.load_state_dict(state_dict)
|
179 |
+
|
180 |
+
model.cuda()
|
181 |
+
model.eval()
|
182 |
+
|
183 |
+
mask = get_mask(model, image, input_size)
|
184 |
+
filename = os.path.join(args.output_dir, os.path.basename(args.image_path).split('.')[0] + '.png')
|
185 |
+
save_image(mask, filename)
|
186 |
+
print(f'Mask saved in {filename}')
|
187 |
+
|
188 |
+
|
189 |
+
if __name__ == '__main__':
|
190 |
+
args = parse_args()
|
191 |
+
save_mask(args)
|