initial commit
- .gitignore +11 -0
- conr.py +292 -0
- data_loader.py +273 -0
- infer.sh +14 -0
- model/__init__.py +1 -0
- model/backbone.py +285 -0
- model/decoder_small.py +43 -0
- model/shader.py +290 -0
- model/warplayer.py +56 -0
- streamlit.py +52 -0
- train.py +229 -0
.gitignore
ADDED
@@ -0,0 +1,11 @@
results/
test_data/
test_data_pre/
weights/
x264/
*.mp3
*.mp4
*.txt
*.png
complex_infer.sh
__pycache__/

conr.py
ADDED
@@ -0,0 +1,292 @@
import os
import torch

from model.backbone import ResEncUnet

from model.shader import CINN
from model.decoder_small import RGBADecoderNet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def UDPClip(x):
    return torch.clamp(x, min=0, max=1)  # NCHW


class CoNR():
    def __init__(self, args):
        self.args = args

        self.udpparsernet = ResEncUnet(
            backbone_name='resnet50_danbo',
            classes=4,
            pretrained=(args.local_rank == 0),
            parametric_upsampling=True,
            decoder_filters=(512, 384, 256, 128, 32),
            map_location=device
        )
        self.target_pose_encoder = ResEncUnet(
            backbone_name='resnet18_danbo-4',
            classes=1,
            pretrained=(args.local_rank == 0),
            parametric_upsampling=True,
            decoder_filters=(512, 384, 256, 128, 32),
            map_location=device
        )
        self.DIM_SHADER_REFERENCE = 4
        self.shader = CINN(self.DIM_SHADER_REFERENCE)
        self.rgbadecodernet = RGBADecoderNet()
        self.device()
        self.parser_ckpt = None

    def dist(self):
        args = self.args
        if args.distributed:
            self.udpparsernet = torch.nn.parallel.DistributedDataParallel(
                self.udpparsernet,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                broadcast_buffers=False,
                find_unused_parameters=True
            )
            self.target_pose_encoder = torch.nn.parallel.DistributedDataParallel(
                self.target_pose_encoder,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                broadcast_buffers=False,
                find_unused_parameters=True
            )
            self.shader = torch.nn.parallel.DistributedDataParallel(
                self.shader,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                broadcast_buffers=True
            )

            self.rgbadecodernet = torch.nn.parallel.DistributedDataParallel(
                self.rgbadecodernet,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                broadcast_buffers=True
            )

    def load_model(self, path):
        self.udpparsernet.load_state_dict(
            torch.load('{}/udpparsernet.pth'.format(path), map_location=device))
        self.target_pose_encoder.load_state_dict(
            torch.load('{}/target_pose_encoder.pth'.format(path), map_location=device))
        self.shader.load_state_dict(
            torch.load('{}/shader.pth'.format(path), map_location=device))
        self.rgbadecodernet.load_state_dict(
            torch.load('{}/rgbadecodernet.pth'.format(path), map_location=device))

    def save_model(self, ite_num):
        self._save_pth(self.udpparsernet,
                       model_name="udpparsernet", ite_num=ite_num)
        self._save_pth(self.target_pose_encoder,
                       model_name="target_pose_encoder", ite_num=ite_num)
        self._save_pth(self.shader,
                       model_name="shader", ite_num=ite_num)
        self._save_pth(self.rgbadecodernet,
                       model_name="rgbadecodernet", ite_num=ite_num)

    def _save_pth(self, net, model_name, ite_num):
        args = self.args
        to_save = None
        if args.distributed:
            if args.local_rank == 0:
                to_save = net.module.state_dict()
        else:
            to_save = net.state_dict()
        if to_save:
            model_dir = os.path.join(
                os.getcwd(), 'saved_models',
                args.model_name + os.sep + "checkpoints" + os.sep + "itr_%d" % (ite_num) + os.sep)

            os.makedirs(model_dir, exist_ok=True)
            torch.save(to_save, model_dir + model_name + ".pth")

    def train(self):
        self.udpparsernet.train()
        self.target_pose_encoder.train()
        self.shader.train()
        self.rgbadecodernet.train()

    def eval(self):
        self.udpparsernet.eval()
        self.target_pose_encoder.eval()
        self.shader.eval()
        self.rgbadecodernet.eval()

    def device(self):
        self.udpparsernet.to(device)
        self.target_pose_encoder.to(device)
        self.shader.to(device)
        self.rgbadecodernet.to(device)

    def data_norm_image(self, data):
        with torch.cuda.amp.autocast(enabled=False):
            for name in ["character_labels", "pose_label"]:
                if name in data:
                    data[name] = data[name].to(
                        device, non_blocking=True).float()
            for name in ["pose_images", "pose_mask", "character_images", "character_masks"]:
                if name in data:
                    data[name] = data[name].to(
                        device, non_blocking=True).float() / 255.0
            if "pose_images" in data:
                data["num_pose_images"] = data["pose_images"].shape[1]
                data["num_samples"] = data["pose_images"].shape[0]
            if "character_images" in data:
                data["num_character_images"] = data["character_images"].shape[1]
                data["num_samples"] = data["character_images"].shape[0]
            if "pose_images" in data and "character_images" in data:
                assert (data["pose_images"].shape[0] ==
                        data["character_images"].shape[0])
        return data

    def reset_charactersheet(self):
        self.parser_ckpt = None

    def model_step(self, data, training=False):
        self.eval()
        with torch.cuda.amp.autocast(enabled=False):
            pred = {}
            if self.parser_ckpt:
                pred["parser"] = self.parser_ckpt
            else:
                pred = self.character_parser_forward(data, pred)
                self.parser_ckpt = pred["parser"]
            pred = self.pose_parser_sc_forward(data, pred)
            pred = self.shader_pose_encoder_forward(data, pred)
            pred = self.shader_forward(data, pred)
        return pred

    def shader_forward(self, data, pred={}):
        assert ("num_character_images" in data), "ERROR: No Character Sheet input."

        character_images_rgb_nmchw, num_character_images = data[
            "character_images"], data["num_character_images"]
        # build x_reference_rgb_a_sudp in the draw call
        shader_character_a_nmchw = data["character_masks"]
        assert torch.any(torch.mean(shader_character_a_nmchw, (0, 2, 3, 4)) >= 0.95) == False, "ERROR: \
No transparent area found in the image, PLEASE separate the foreground of input character sheets. \
The website waifucutout.com is recommended to automatically cut out the foreground."

        if shader_character_a_nmchw is None:
            shader_character_a_nmchw = pred["parser"]["pred"][:, :, 3:4, :, :]
        x_reference_rgb_a = torch.cat([
            shader_character_a_nmchw[:, :, :, :, :] * character_images_rgb_nmchw[:, :, :, :, :],
            shader_character_a_nmchw[:, :, :, :, :],
        ], 2)
        assert (x_reference_rgb_a.shape[2] == self.DIM_SHADER_REFERENCE)
        # build x_reference_features in the draw call
        x_reference_features = pred["parser"]["features"]
        # run cinn shader
        retdic = self.shader(
            pred["shader"]["target_pose_features"], x_reference_rgb_a, x_reference_features)
        pred["shader"].update(retdic)

        # decode rgba
        if True:
            dec_out = self.rgbadecodernet(
                retdic["y_last_remote_features"])
            y_weighted_x_reference_RGB = dec_out[:, 0:3, :, :]
            y_weighted_mask_A = dec_out[:, 3:4, :, :]
            y_weighted_warp_decoded_rgba = torch.cat(
                (y_weighted_x_reference_RGB * y_weighted_mask_A, y_weighted_mask_A), dim=1
            )
        assert (y_weighted_warp_decoded_rgba.shape[1] == 4)
        assert (
            y_weighted_warp_decoded_rgba.shape[-1] == character_images_rgb_nmchw.shape[-1])
        # apply decoded mask to decoded rgb, finishing the draw call
        pred["shader"]["y_weighted_warp_decoded_rgba"] = y_weighted_warp_decoded_rgba
        return pred

    def character_parser_forward(self, data, pred={}):
        if not ("num_character_images" in data and "character_images" in data):
            return pred
        pred["parser"] = {"pred": None}  # create output

        inputs_rgb_nmchw, num_samples, num_character_images = data[
            "character_images"], data["num_samples"], data["num_character_images"]
        inputs_rgb_fchw = inputs_rgb_nmchw.view(
            (num_samples * num_character_images, inputs_rgb_nmchw.shape[2],
             inputs_rgb_nmchw.shape[3], inputs_rgb_nmchw.shape[4]))

        encoder_out, features = self.udpparsernet(
            (inputs_rgb_fchw - 0.6) / 0.2970)

        pred["parser"]["features"] = [features_out.view(
            (num_samples, num_character_images, features_out.shape[1],
             features_out.shape[2], features_out.shape[3])) for features_out in features]

        if (encoder_out is not None):
            pred["parser"]["pred"] = UDPClip(encoder_out.view(
                (num_samples, num_character_images, encoder_out.shape[1],
                 encoder_out.shape[2], encoder_out.shape[3])))

        return pred

    def pose_parser_sc_forward(self, data, pred={}):
        if not ("num_pose_images" in data and "pose_images" in data):
            return pred
        inputs_aug_rgb_nmchw, num_samples, num_pose_images = data[
            "pose_images"], data["num_samples"], data["num_pose_images"]
        inputs_aug_rgb_fchw = inputs_aug_rgb_nmchw.view(
            (num_samples * num_pose_images, inputs_aug_rgb_nmchw.shape[2],
             inputs_aug_rgb_nmchw.shape[3], inputs_aug_rgb_nmchw.shape[4]))

        encoder_out, _ = self.udpparsernet(
            (inputs_aug_rgb_fchw - 0.6) / 0.2970)

        encoder_out = encoder_out.view(
            (num_samples, num_pose_images, encoder_out.shape[1],
             encoder_out.shape[2], encoder_out.shape[3]))

        # apply sigmoid after eval loss
        pred["pose_parser"] = {"pred": UDPClip(encoder_out)}

        return pred

    def shader_pose_encoder_forward(self, data, pred={}):
        pred["shader"] = {}  # create output
        if "pose_images" in data:
            pose_images_rgb_nmchw = data["pose_images"]
            target_gt_rgb = pose_images_rgb_nmchw[:, 0, :, :, :]
            pred["shader"]["target_gt_rgb"] = target_gt_rgb

        shader_target_a = None
        if "pose_mask" in data:
            pred["shader"]["target_gt_a"] = data["pose_mask"]
            shader_target_a = data["pose_mask"]

        shader_target_sudp = None
        if "pose_label" in data:
            shader_target_sudp = data["pose_label"][:, :3, :, :]

        if self.args.test_pose_use_parser_udp:
            shader_target_sudp = None
        if shader_target_sudp is None:
            shader_target_sudp = pred["pose_parser"]["pred"][:, 0:3, :, :]

        if shader_target_a is None:
            shader_target_a = pred["pose_parser"]["pred"][:, 3:4, :, :]

        # build x_target_sudp_a in the draw call
        x_target_sudp_a = torch.cat((
            shader_target_sudp * shader_target_a,
            shader_target_a
        ), 1)
        pred["shader"].update({
            "x_target_sudp_a": x_target_sudp_a
        })
        _, features = self.target_pose_encoder(
            (x_target_sudp_a - 0.6) / 0.2970, ret_parser_out=False)

        pred["shader"]["target_pose_features"] = features
        return pred

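For reference, a minimal sketch of how the CoNR wrapper above can be exercised end to end. The fields on args are assumptions inferred from how conr.py reads them (local_rank, distributed, test_pose_use_parser_udp, model_name), the 256x256 shapes are illustrative, and the networks stay randomly initialized unless load_model() is pointed at trained checkpoints.

from types import SimpleNamespace
import torch
from conr import CoNR

# Hypothetical flag values; local_rank=-1 keeps the pretrained backbone download off.
args = SimpleNamespace(local_rank=-1, distributed=False,
                       test_pose_use_parser_udp=False, model_name="conr")
model = CoNR(args)
# model.load_model("./weights")  # optional: load trained .pth checkpoints

data = {
    # character sheet: N=1 sample, M=1 view, RGB uint8 plus alpha mask (NMCHW)
    "character_images": torch.randint(0, 255, (1, 1, 3, 256, 256), dtype=torch.uint8),
    "character_masks": torch.zeros(1, 1, 1, 256, 256, dtype=torch.uint8),
    # target pose as a 4-channel UDP label (XYZV) plus its alpha mask (NCHW)
    "pose_label": torch.rand(1, 4, 256, 256),
    "pose_mask": torch.zeros(1, 1, 256, 256, dtype=torch.uint8),
}
data = model.data_norm_image(data)   # move to device, rescale to [0, 1], fill num_* fields
pred = model.model_step(data)
print(pred["shader"]["y_weighted_warp_decoded_rgba"].shape)  # (1, 4, 256, 256) RGBA
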
data_loader.py
ADDED
@@ -0,0 +1,273 @@
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
import os
cv2.setNumThreads(1)
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"


class RandomResizedCropWithAutoCenteringAndZeroPadding(object):
    def __init__(self, output_size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), center_jitter=(0.1, 0.1), size_from_alpha_mask=True):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size
        assert isinstance(scale, tuple)
        assert isinstance(ratio, tuple)
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            raise ValueError("Scale and ratio should be of kind (min, max)")
        self.size_from_alpha_mask = size_from_alpha_mask
        self.scale = scale
        self.ratio = ratio
        assert isinstance(center_jitter, tuple)
        self.center_jitter = center_jitter

    def __call__(self, sample):
        imidx, image = sample['imidx'], sample["image_np"]
        if "labels" in sample:
            label = sample["labels"]
        else:
            label = None

        im_h, im_w = image.shape[:2]
        if self.size_from_alpha_mask and image.shape[2] == 4:
            # compute bbox from alpha mask
            bbox_left, bbox_top, bbox_w, bbox_h = cv2.boundingRect(
                (image[:, :, 3] > 0).astype(np.uint8))
        else:
            bbox_left, bbox_top = 0, 0
            bbox_h, bbox_w = image.shape[:2]
        if bbox_h <= 1 and bbox_w <= 1:
            sample["bad"] = 0
        else:
            # detect too small image here
            alpha_varea = np.sum((image[:, :, 3] > 0).astype(np.uint8))
            image_area = image.shape[0]*image.shape[1]
            if alpha_varea/image_area < 0.001:
                sample["bad"] = alpha_varea
        # detect bad image
        if "bad" in sample:
            # baddata_dir = os.path.join(os.getcwd(), 'test_data', "baddata" + os.sep)
            # save_output(str(imidx)+".png",image,label,baddata_dir)
            bbox_h, bbox_w = image.shape[:2]
            sample["image_np"] = np.zeros(
                [self.output_size[0], self.output_size[1], image.shape[2]], dtype=image.dtype)
            if label is not None:
                sample["labels"] = np.zeros(
                    [self.output_size[0], self.output_size[1], 4], dtype=label.dtype)

            return sample

        # compute default area by making sure output_size contains bbox_w * bbox_h
        jitter_h = np.random.uniform(
            -bbox_h*self.center_jitter[0], bbox_h*self.center_jitter[0])
        jitter_w = np.random.uniform(
            -bbox_w*self.center_jitter[1], bbox_w*self.center_jitter[1])

        # h/w
        target_aspect_ratio = np.exp(
            np.log(self.output_size[0]/self.output_size[1]) +
            np.random.uniform(np.log(self.ratio[0]), np.log(self.ratio[1]))
        )

        source_aspect_ratio = bbox_h/bbox_w

        if target_aspect_ratio < source_aspect_ratio:
            # same w, target has larger h, use h to align
            target_height = bbox_h * \
                np.random.uniform(self.scale[0], self.scale[1])
            virtual_h = int(round(target_height))
            virtual_w = int(round(target_height / target_aspect_ratio))  # h/w
        else:
            # same w, source has larger h, use w to align
            target_width = bbox_w * \
                np.random.uniform(self.scale[0], self.scale[1])
            virtual_h = int(round(target_width * target_aspect_ratio))  # h/w
            virtual_w = int(round(target_width))

        # print("required aspect ratio:", target_aspect_ratio)

        virtual_top = int(round(bbox_top + jitter_h - (virtual_h-bbox_h)/2))
        virtual_left = int(round(bbox_left + jitter_w - (virtual_w-bbox_w)/2))

        if virtual_top < 0:
            top_padding = abs(virtual_top)
            crop_top = 0
        else:
            top_padding = 0
            crop_top = virtual_top
        if virtual_left < 0:
            left_padding = abs(virtual_left)
            crop_left = 0
        else:
            left_padding = 0
            crop_left = virtual_left
        if virtual_top+virtual_h > im_h:
            bottom_padding = abs(im_h-(virtual_top+virtual_h))
            crop_bottom = im_h
        else:
            bottom_padding = 0
            crop_bottom = virtual_top+virtual_h
        if virtual_left+virtual_w > im_w:
            right_padding = abs(im_w-(virtual_left+virtual_w))
            crop_right = im_w
        else:
            right_padding = 0
            crop_right = virtual_left+virtual_w

        # crop
        image = image[crop_top:crop_bottom, crop_left:crop_right]
        if label is not None:
            label = label[crop_top:crop_bottom, crop_left:crop_right]

        # pad
        if top_padding + bottom_padding + left_padding + right_padding > 0:
            padding = ((top_padding, bottom_padding),
                       (left_padding, right_padding), (0, 0))
            # print("padding", padding)
            image = np.pad(image, padding, mode='constant')
            if label is not None:
                label = np.pad(label, padding, mode='constant')

        if image.shape[0]/image.shape[1] - virtual_h/virtual_w > 0.001:
            print("virtual aspect ratio:", virtual_h/virtual_w)
            print("image aspect ratio:", image.shape[0]/image.shape[1])
        assert (image.shape[0]/image.shape[1] - virtual_h/virtual_w < 0.001)
        sample["crop"] = np.array(
            [im_h, im_w, crop_top, crop_bottom, crop_left, crop_right,
             top_padding, bottom_padding, left_padding, right_padding,
             image.shape[0], image.shape[1]])

        # resize
        if self.output_size[1] != image.shape[1] or self.output_size[0] != image.shape[0]:
            if self.output_size[1] > image.shape[1] and self.output_size[0] > image.shape[0]:
                # enlarging
                image = cv2.resize(
                    image, (self.output_size[1], self.output_size[0]), interpolation=cv2.INTER_LINEAR)
            else:
                # shrinking
                image = cv2.resize(
                    image, (self.output_size[1], self.output_size[0]), interpolation=cv2.INTER_AREA)

            if label is not None:
                label = cv2.resize(label, (self.output_size[1], self.output_size[0]),
                                   interpolation=cv2.INTER_NEAREST_EXACT)

        assert image.shape[0] == self.output_size[0] and image.shape[1] == self.output_size[1]
        sample['imidx'], sample["image_np"] = imidx, image
        if label is not None:
            assert label.shape[0] == self.output_size[0] and label.shape[1] == self.output_size[1]
            sample["labels"] = label

        return sample


class FileDataset(Dataset):
    def __init__(self, image_names_list, fg_img_lbl_transform=None, shader_pose_use_gt_udp_test=True, shader_target_use_gt_rgb_debug=False):
        self.image_name_list = image_names_list
        self.fg_img_lbl_transform = fg_img_lbl_transform
        self.shader_pose_use_gt_udp_test = shader_pose_use_gt_udp_test
        self.shader_target_use_gt_rgb_debug = shader_target_use_gt_rgb_debug

    def __len__(self):
        return len(self.image_name_list)

    def get_gt_from_disk(self, idx, imname, read_label):
        if read_label:
            # read label
            with open(imname, mode="rb") as bio:
                if imname.find(".npz") > 0:
                    label_np = np.load(bio, allow_pickle=True)[
                        'i'].astype(np.float32, copy=False)
                else:
                    label_np = cv2.cvtColor(cv2.imdecode(np.frombuffer(bio.read(
                    ), np.uint8), cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH | cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA)
            assert (4 == label_np.shape[2])
            # fake image out of valid label
            image_np = (label_np*255).clip(0, 255).astype(np.uint8, copy=False)
            # assemble sample
            sample = {'imidx': np.array(
                [idx]), "image_np": image_np, "labels": label_np}

        else:
            # read image as uint8
            with open(imname, mode="rb") as bio:
                image_np = cv2.cvtColor(cv2.imdecode(np.frombuffer(
                    bio.read(), np.uint8), cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA)
            # image_np = Image.open(bio)
            # image_np = np.array(image_np)
            assert (3 == len(image_np.shape))
            if (image_np.shape[2] == 4):
                mask_np = image_np[:, :, 3:4]
                image_np = (image_np[:, :, :3] *
                            (image_np[:, :, 3][:, :, np.newaxis]/255.0)).clip(0, 255).astype(np.uint8, copy=False)
            elif (image_np.shape[2] == 3):
                # generate a fake mask
                # Fool-proofing
                mask_np = np.ones(
                    (image_np.shape[0], image_np.shape[1], 1), dtype=np.uint8)*255
                print("WARN: transparent background is preferred for image ", imname)
            else:
                raise ValueError("weird shape of image ", imname, image_np)
            image_np = np.concatenate((image_np, mask_np), axis=2)
            sample = {'imidx': np.array(
                [idx]), "image_np": image_np}

        # apply fg_img_lbl_transform
        if self.fg_img_lbl_transform:
            sample = self.fg_img_lbl_transform(sample)

        if "labels" in sample:
            # return UDP as 4chn XYZV float tensor
            sample["labels"] = torch.from_numpy(
                sample["labels"].transpose((2, 0, 1)))
            assert (sample["labels"].dtype == torch.float32)

        if "image_np" in sample:
            # return image as 3chn RGB uint8 tensor and 1chn A uint8 tensor
            sample["mask"] = torch.from_numpy(
                sample["image_np"][:, :, 3:4].transpose((2, 0, 1)))
            assert (sample["mask"].dtype == torch.uint8)
            sample["image"] = torch.from_numpy(
                sample["image_np"][:, :, :3].transpose((2, 0, 1)))

            assert (sample["image"].dtype == torch.uint8)
            del sample["image_np"]
        return sample

    def __getitem__(self, idx):
        sample = {
            'imidx': np.array([idx])}
        target = self.get_gt_from_disk(
            idx, imname=self.image_name_list[idx][0], read_label=self.shader_pose_use_gt_udp_test)
        if self.shader_target_use_gt_rgb_debug:
            sample["pose_images"] = torch.stack([target["image"]])
            sample["pose_mask"] = target["mask"]
        elif self.shader_pose_use_gt_udp_test:
            sample["pose_label"] = target["labels"]
            sample["pose_mask"] = target["mask"]
        else:
            sample["pose_images"] = torch.stack([target["image"]])
        if "crop" in target:
            sample["pose_crop"] = target["crop"]
        character_images = []
        character_masks = []
        for i in range(1, len(self.image_name_list[idx])):
            source = self.get_gt_from_disk(
                idx, self.image_name_list[idx][i], read_label=False)
            character_images.append(source["image"])
            character_masks.append(source["mask"])
        character_images = torch.stack(character_images)
        character_masks = torch.stack(character_masks)
        sample.update({
            "character_images": character_images,
            "character_masks": character_masks
        })
        # do not make fake labels in inference
        return sample

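A sketch of the intended wiring (the file paths below are placeholders, and the exact arguments train.py passes are an assumption): each entry in the name list pairs one target pose file with one or more character-sheet images.

import torch
from torch.utils.data import DataLoader
from data_loader import FileDataset, RandomResizedCropWithAutoCenteringAndZeroPadding

# [target_pose_udp_or_image, character_view_1, character_view_2, ...] per sample
names = [["./test_data/0000.npz",
          "./character_sheet/front.png", "./character_sheet/back.png"]]
dataset = FileDataset(
    names,
    fg_img_lbl_transform=RandomResizedCropWithAutoCenteringAndZeroPadding(
        256, scale=(1.0, 1.0), center_jitter=(0.0, 0.0)),
    shader_pose_use_gt_udp_test=True,  # read the pose as a ground-truth UDP label
)
loader = DataLoader(dataset, batch_size=1, num_workers=2)
batch = next(iter(loader))
print(batch["character_images"].shape)  # (1, 2, 3, 256, 256) uint8
print(batch["pose_label"].shape)        # (1, 4, 256, 256) float32
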
infer.sh
ADDED
@@ -0,0 +1,14 @@
rm -r "./results"
mkdir "./results"

rlaunch --gpu=1 --cpu=4 --memory=25600 -- python3 -m torch.distributed.launch \
    --nproc_per_node=1 train.py --mode=test \
    --world_size=1 --dataloaders=2 \
    --test_input_poses_images=./test_data/ \
    --test_input_person_images=./character_sheet/ \
    --test_output_dir=./results/ \
    --test_checkpoint_dir=./weights/

echo Generating video...
ffmpeg -r 30 -y -i ./results/%d.png -r 30 -c:v libx264 output.mp4 -r 30
echo DONE.

model/__init__.py
ADDED
@@ -0,0 +1 @@
model/backbone.py
ADDED
@@ -0,0 +1,285 @@
import torch
import torch.nn as nn
from torchvision import models
from torch.nn import functional as F


class AdaptiveConcatPool2d(nn.Module):
    """
    Layer that concats `AdaptiveAvgPool2d` and `AdaptiveMaxPool2d`.
    Source: Fastai. This code was taken from the fastai library at url
    https://github.com/fastai/fastai/blob/master/fastai/layers.py#L176
    """

    def __init__(self, sz=None):
        "Output will be 2*sz or 2 if sz is None"
        super().__init__()
        self.output_size = sz or 1
        self.ap = nn.AdaptiveAvgPool2d(self.output_size)
        self.mp = nn.AdaptiveMaxPool2d(self.output_size)

    def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1)


class MyNorm(nn.Module):
    def __init__(self, num_channels):
        super(MyNorm, self).__init__()
        self.norm = nn.InstanceNorm2d(
            num_channels, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)

    def forward(self, x):
        x = self.norm(x)
        return x


def resnet_fastai(model, pretrained, url, replace_first_layer=None, replace_maxpool_layer=None, progress=True, map_location=None, **kwargs):
    cut = -2
    s = model(pretrained=False, **kwargs)
    if replace_maxpool_layer is not None:
        s.maxpool = replace_maxpool_layer
    if replace_first_layer is not None:
        body = nn.Sequential(replace_first_layer, *list(s.children())[1:cut])
    else:
        body = nn.Sequential(*list(s.children())[:cut])

    if pretrained:
        state = torch.hub.load_state_dict_from_url(
            url, progress=progress, map_location=map_location)
        if replace_first_layer is not None:
            for each in list(state.keys()).copy():
                if each.find("0.0.") == 0:
                    del state[each]
        body_tail = nn.Sequential(body)
        ret = body_tail.load_state_dict(state, strict=False)
    return body


def get_backbone(name, pretrained=True, map_location=None):
    """ Loading backbone, defining names for skip-connections and encoder output. """

    first_layer_for_4chn = nn.Conv2d(
        4, 64, kernel_size=7, stride=2, padding=3, bias=False)
    max_pool_layer_replace = nn.Conv2d(
        64, 64, kernel_size=3, stride=2, padding=1, bias=False)
    # loading backbone model
    if name == 'resnet18':
        backbone = models.resnet18(pretrained=pretrained)
    elif name == 'resnet18-4':
        backbone = models.resnet18(pretrained=pretrained)
        backbone.conv1 = first_layer_for_4chn
    elif name == 'resnet34':
        backbone = models.resnet34(pretrained=pretrained)
    elif name == 'resnet50':
        backbone = models.resnet50(pretrained=False, norm_layer=MyNorm)
        backbone.maxpool = max_pool_layer_replace
    elif name == 'resnet101':
        backbone = models.resnet101(pretrained=pretrained)
    elif name == 'resnet152':
        backbone = models.resnet152(pretrained=pretrained)
    elif name == 'vgg16':
        backbone = models.vgg16_bn(pretrained=pretrained).features
    elif name == 'vgg19':
        backbone = models.vgg19_bn(pretrained=pretrained).features
    elif name == 'resnet18_danbo-4':
        backbone = resnet_fastai(models.resnet18, url="https://github.com/RF5/danbooru-pretrained/releases/download/v0.1/resnet18-3f77756f.pth",
                                 pretrained=pretrained, map_location=map_location, norm_layer=MyNorm, replace_first_layer=first_layer_for_4chn)
    elif name == 'resnet50_danbo':
        backbone = resnet_fastai(models.resnet50, url="https://github.com/RF5/danbooru-pretrained/releases/download/v0.1/resnet50-13306192.pth",
                                 pretrained=pretrained, map_location=map_location, norm_layer=MyNorm, replace_maxpool_layer=max_pool_layer_replace)
    elif name == 'densenet121':
        backbone = models.densenet121(pretrained=True).features
    elif name == 'densenet161':
        backbone = models.densenet161(pretrained=True).features
    elif name == 'densenet169':
        backbone = models.densenet169(pretrained=True).features
    elif name == 'densenet201':
        backbone = models.densenet201(pretrained=True).features
    else:
        raise NotImplementedError(
            '{} backbone model is not implemented so far.'.format(name))
    # print(backbone)
    # specifying skip feature and output names
    if name.startswith('resnet'):
        feature_names = [None, 'relu', 'layer1', 'layer2', 'layer3']
        backbone_output = 'layer4'
    elif name == 'vgg16':
        # TODO: consider using a 'bridge' for VGG models, there is just a MaxPool between last skip and backbone output
        feature_names = ['5', '12', '22', '32', '42']
        backbone_output = '43'
    elif name == 'vgg19':
        feature_names = ['5', '12', '25', '38', '51']
        backbone_output = '52'
    elif name.startswith('densenet'):
        feature_names = [None, 'relu0', 'denseblock1',
                         'denseblock2', 'denseblock3']
        backbone_output = 'denseblock4'
    elif name == 'unet_encoder':
        feature_names = ['module1', 'module2', 'module3', 'module4']
        backbone_output = 'module5'
    else:
        raise NotImplementedError(
            '{} backbone model is not implemented so far.'.format(name))
    if name.find('_danbo') > 0:
        feature_names = [None, '2', '4', '5', '6']
        backbone_output = '7'
    return backbone, feature_names, backbone_output


class UpsampleBlock(nn.Module):

    # TODO: separate parametric and non-parametric classes?
    # TODO: skip connection concatenated OR added

    def __init__(self, ch_in, ch_out=None, skip_in=0, use_bn=True, parametric=False):
        super(UpsampleBlock, self).__init__()

        self.parametric = parametric
        ch_out = ch_in/2 if ch_out is None else ch_out

        # first convolution: either transposed conv, or conv following the skip connection
        if parametric:
            # versions: kernel=4 padding=1, kernel=2 padding=0
            self.up = nn.ConvTranspose2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(4, 4),
                                         stride=2, padding=1, output_padding=0, bias=(not use_bn))
            self.bn1 = MyNorm(ch_out) if use_bn else None
        else:
            self.up = None
            ch_in = ch_in + skip_in
            self.conv1 = nn.Conv2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(3, 3),
                                   stride=1, padding=1, bias=(not use_bn))
            self.bn1 = MyNorm(ch_out) if use_bn else None

        self.relu = nn.ReLU(inplace=True)

        # second convolution
        conv2_in = ch_out if not parametric else ch_out + skip_in
        self.conv2 = nn.Conv2d(in_channels=conv2_in, out_channels=ch_out, kernel_size=(3, 3),
                               stride=1, padding=1, bias=(not use_bn))
        self.bn2 = MyNorm(ch_out) if use_bn else None

    def forward(self, x, skip_connection=None):

        x = self.up(x) if self.parametric else F.interpolate(x, size=None, scale_factor=2, mode='bilinear',
                                                             align_corners=None)
        if self.parametric:
            x = self.bn1(x) if self.bn1 is not None else x
            x = self.relu(x)

        if skip_connection is not None:
            x = torch.cat([x, skip_connection], dim=1)

        if not self.parametric:
            x = self.conv1(x)
            x = self.bn1(x) if self.bn1 is not None else x
            x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x) if self.bn2 is not None else x
        x = self.relu(x)

        return x


class ResEncUnet(nn.Module):

    """ U-Net (https://arxiv.org/pdf/1505.04597.pdf) implementation with pre-trained torchvision backbones."""

    def __init__(self,
                 backbone_name,
                 pretrained=True,
                 encoder_freeze=False,
                 classes=21,
                 decoder_filters=(512, 256, 128, 64, 32),
                 parametric_upsampling=True,
                 shortcut_features='default',
                 decoder_use_instancenorm=True,
                 map_location=None
                 ):
        super(ResEncUnet, self).__init__()

        self.backbone_name = backbone_name

        self.backbone, self.shortcut_features, self.bb_out_name = get_backbone(
            backbone_name, pretrained=pretrained, map_location=map_location)
        shortcut_chs, bb_out_chs = self.infer_skip_channels()
        if shortcut_features != 'default':
            self.shortcut_features = shortcut_features

        # build decoder part
        self.upsample_blocks = nn.ModuleList()
        # avoiding having more blocks than skip connections
        decoder_filters = decoder_filters[:len(self.shortcut_features)]
        decoder_filters_in = [bb_out_chs] + list(decoder_filters[:-1])
        num_blocks = len(self.shortcut_features)
        for i, [filters_in, filters_out] in enumerate(zip(decoder_filters_in, decoder_filters)):
            self.upsample_blocks.append(UpsampleBlock(filters_in, filters_out,
                                                      skip_in=shortcut_chs[num_blocks-i-1],
                                                      parametric=parametric_upsampling,
                                                      use_bn=decoder_use_instancenorm))
        self.final_conv = nn.Conv2d(
            decoder_filters[-1], classes, kernel_size=(1, 1))

        if encoder_freeze:
            self.freeze_encoder()

    def freeze_encoder(self):
        """ Freezing encoder parameters, the newly initialized decoder parameters are remaining trainable. """

        for param in self.backbone.parameters():
            param.requires_grad = False

    def forward(self, *input, ret_parser_out=True):
        """ Forward propagation in U-Net. """

        x, features = self.forward_backbone(*input)
        output_feature = [x]
        for skip_name, upsample_block in zip(self.shortcut_features[::-1], self.upsample_blocks):
            skip_features = features[skip_name]
            if skip_features is not None:
                output_feature.append(skip_features)
            if ret_parser_out:
                x = upsample_block(x, skip_features)
        if ret_parser_out:
            x = self.final_conv(x)
            # apply sigmoid later
        else:
            x = None

        return x, output_feature

    def forward_backbone(self, x):
        """ Forward propagation in backbone encoder network. """

        features = {None: None} if None in self.shortcut_features else dict()
        for name, child in self.backbone.named_children():
            x = child(x)
            if name in self.shortcut_features:
                features[name] = x
            if name == self.bb_out_name:
                break

        return x, features

    def infer_skip_channels(self):
        """ Getting the number of channels at skip connections and at the output of the encoder. """
        if self.backbone_name.find("-4") > 0:
            x = torch.zeros(1, 4, 224, 224)
        else:
            x = torch.zeros(1, 3, 224, 224)
        has_fullres_features = self.backbone_name.startswith(
            'vgg') or self.backbone_name == 'unet_encoder'
        # only VGG has features at full resolution
        channels = [] if has_fullres_features else [0]

        # forward run in backbone to count channels (dirty solution but works for *any* Module)
        for name, child in self.backbone.named_children():
            x = child(x)
            if name in self.shortcut_features:
                channels.append(x.shape[1])
            if name == self.bb_out_name:
                out_channels = x.shape[1]
                break
        return channels, out_channels

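A quick shape check for ResEncUnet as conr.py instantiates it (pretrained=False here so no danbooru-pretrained download is attempted; the 256x256 input is illustrative):

import torch
from model.backbone import ResEncUnet

net = ResEncUnet(backbone_name='resnet50_danbo', classes=4, pretrained=False,
                 parametric_upsampling=True,
                 decoder_filters=(512, 384, 256, 128, 32))
x = torch.zeros(1, 3, 256, 256)
out, features = net(x)                   # ret_parser_out=True: decoder is run too
print(out.shape)                         # (1, 4, 256, 256) UDP-style output
print([f.shape[1] for f in features])    # [2048, 1024, 512, 256, 64] encoder pyramid
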
model/decoder_small.py
ADDED
@@ -0,0 +1,43 @@
from torch import nn
import torch.nn.functional as F
import torch


class ResBlock2d(nn.Module):
    def __init__(self, in_features, kernel_size, padding):
        super(ResBlock2d, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
                               padding=padding)
        self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
                               padding=padding)

        self.norm1 = nn.Conv2d(
            in_channels=in_features, out_channels=in_features, kernel_size=1)
        self.norm2 = nn.Conv2d(
            in_channels=in_features, out_channels=in_features, kernel_size=1)

    def forward(self, x):
        out = self.norm1(x)
        out = F.relu(out, inplace=True)
        out = self.conv1(out)
        out = self.norm2(out)
        out = F.relu(out, inplace=True)
        out = self.conv2(out)
        out += x
        return out


class RGBADecoderNet(nn.Module):
    def __init__(self, c=64, out_planes=4, num_bottleneck_blocks=1):
        super(RGBADecoderNet, self).__init__()
        self.conv_rgba = nn.Sequential(nn.Conv2d(c, out_planes, kernel_size=3, stride=1,
                                                 padding=1, dilation=1, bias=True))
        self.bottleneck = torch.nn.Sequential()
        for i in range(num_bottleneck_blocks):
            self.bottleneck.add_module(
                'r' + str(i), ResBlock2d(c, kernel_size=(3, 3), padding=(1, 1)))

    def forward(self, features_weighted_mask_atfeaturesscale_list=[]):
        return torch.sigmoid(self.conv_rgba(self.bottleneck(features_weighted_mask_atfeaturesscale_list.pop(0))))

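A small sketch of the RGBA head in isolation (the 64-channel, 256x256 feature map is a stand-in for the shader's y_last_remote_features output):

import torch
from model.decoder_small import RGBADecoderNet

decoder = RGBADecoderNet(c=64, out_planes=4, num_bottleneck_blocks=1)
feat = torch.randn(1, 64, 256, 256)
rgba = decoder([feat])          # the forward pops the first tensor from the list
print(rgba.shape)               # (1, 4, 256, 256), sigmoid-activated RGBA
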
model/shader.py
ADDED
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from .warplayer import warp_features
|
5 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
6 |
+
|
7 |
+
|
8 |
+
class DecoderBlock(nn.Module):
|
9 |
+
def __init__(self, in_planes, c=224, out_msgs=0, out_locals=0, block_nums=1, out_masks=1, out_local_flows=32, out_msgs_flows=32, out_feat_flows=0):
|
10 |
+
|
11 |
+
super(DecoderBlock, self).__init__()
|
12 |
+
self.conv0 = nn.Sequential(
|
13 |
+
nn.Conv2d(in_planes, c, 3, 2, 1),
|
14 |
+
nn.PReLU(c),
|
15 |
+
nn.Conv2d(c, c, 3, 2, 1),
|
16 |
+
nn.PReLU(c),
|
17 |
+
)
|
18 |
+
|
19 |
+
self.convblocks = nn.ModuleList()
|
20 |
+
for i in range(block_nums):
|
21 |
+
self.convblocks.append(nn.Sequential(
|
22 |
+
nn.Conv2d(c, c, 3, 1, 1),
|
23 |
+
nn.PReLU(c),
|
24 |
+
nn.Conv2d(c, c, 3, 1, 1),
|
25 |
+
nn.PReLU(c),
|
26 |
+
nn.Conv2d(c, c, 3, 1, 1),
|
27 |
+
nn.PReLU(c),
|
28 |
+
nn.Conv2d(c, c, 3, 1, 1),
|
29 |
+
nn.PReLU(c),
|
30 |
+
nn.Conv2d(c, c, 3, 1, 1),
|
31 |
+
nn.PReLU(c),
|
32 |
+
nn.Conv2d(c, c, 3, 1, 1),
|
33 |
+
nn.PReLU(c),
|
34 |
+
))
|
35 |
+
self.out_flows = 2
|
36 |
+
self.out_msgs = out_msgs
|
37 |
+
self.out_msgs_flows = out_msgs_flows if out_msgs > 0 else 0
|
38 |
+
self.out_locals = out_locals
|
39 |
+
self.out_local_flows = out_local_flows if out_locals > 0 else 0
|
40 |
+
self.out_masks = out_masks
|
41 |
+
self.out_feat_flows = out_feat_flows
|
42 |
+
|
43 |
+
self.conv_last = nn.Sequential(
|
44 |
+
nn.ConvTranspose2d(c, c, 4, 2, 1),
|
45 |
+
nn.PReLU(c),
|
46 |
+
nn.ConvTranspose2d(c, self.out_flows+self.out_msgs+self.out_msgs_flows +
|
47 |
+
self.out_locals+self.out_local_flows+self.out_masks+self.out_feat_flows, 4, 2, 1),
|
48 |
+
)
|
49 |
+
|
50 |
+
def forward(self, accumulated_flow, *other):
|
51 |
+
x = [accumulated_flow]
|
52 |
+
for each in other:
|
53 |
+
if each is not None:
|
54 |
+
assert(accumulated_flow.shape[-1] == each.shape[-1]), "decoder want {}, but get {}".format(
|
55 |
+
accumulated_flow.shape, each.shape)
|
56 |
+
x.append(each)
|
57 |
+
feat = self.conv0(torch.cat(x, dim=1))
|
58 |
+
for convblock1 in self.convblocks:
|
59 |
+
feat = convblock1(feat) + feat
|
60 |
+
feat = self.conv_last(feat)
|
61 |
+
prev = 0
|
62 |
+
flow = feat[:, prev:prev+self.out_flows, :, :]
|
63 |
+
prev += self.out_flows
|
64 |
+
message = feat[:, prev:prev+self.out_msgs,
|
65 |
+
:, :] if self.out_msgs > 0 else None
|
66 |
+
prev += self.out_msgs
|
67 |
+
message_flow = feat[:, prev:prev + self.out_msgs_flows,
|
68 |
+
:, :] if self.out_msgs_flows > 0 else None
|
69 |
+
prev += self.out_msgs_flows
|
70 |
+
local_message = feat[:, prev:prev + self.out_locals,
|
71 |
+
:, :] if self.out_locals > 0 else None
|
72 |
+
prev += self.out_locals
|
73 |
+
local_message_flow = feat[:, prev:prev+self.out_local_flows,
|
74 |
+
:, :] if self.out_local_flows > 0 else None
|
75 |
+
prev += self.out_local_flows
|
76 |
+
mask = torch.sigmoid(
|
77 |
+
feat[:, prev:prev+self.out_masks, :, :]) if self.out_masks > 0 else None
|
78 |
+
prev += self.out_masks
|
79 |
+
feat_flow = feat[:, prev:prev+self.out_feat_flows,
|
80 |
+
:, :] if self.out_feat_flows > 0 else None
|
81 |
+
prev += self.out_feat_flows
|
82 |
+
return flow, mask, message, message_flow, local_message, local_message_flow, feat_flow
|
83 |
+
|
84 |
+
|
85 |
+
class CINN(nn.Module):
|
86 |
+
def __init__(self, DIM_SHADER_REFERENCE, target_feature_chns=[512, 256, 128, 64, 64], feature_chns=[2048, 1024, 512, 256, 64], out_msgs_chn=[2048, 1024, 512, 256, 64, 64], out_locals_chn=[2048, 1024, 512, 256, 64, 0], block_num=[1, 1, 1, 1, 1, 2], block_chn_num=[224, 224, 224, 224, 224, 224]):
|
87 |
+
super(CINN, self).__init__()
|
88 |
+
|
89 |
+
self.in_msgs_chn = [0, *out_msgs_chn[:-1]]
|
90 |
+
self.in_locals_chn = [0, *out_locals_chn[:-1]]
|
91 |
+
|
92 |
+
self.decoder_blocks = nn.ModuleList()
|
93 |
+
self.feed_weighted = True
|
94 |
+
if self.feed_weighted:
|
95 |
+
in_planes = 2+2+DIM_SHADER_REFERENCE*2
|
96 |
+
else:
|
97 |
+
in_planes = 2+DIM_SHADER_REFERENCE
|
98 |
+
for each_target_feature_chns, each_feature_chns, each_out_msgs_chn, each_out_locals_chn, each_in_msgs_chn, each_in_locals_chn, each_block_num, each_block_chn_num in zip(target_feature_chns, feature_chns, out_msgs_chn, out_locals_chn, self.in_msgs_chn, self.in_locals_chn, block_num, block_chn_num):
|
99 |
+
self.decoder_blocks.append(
|
100 |
+
DecoderBlock(in_planes+each_target_feature_chns+each_feature_chns+each_in_locals_chn+each_in_msgs_chn, c=each_block_chn_num, block_nums=each_block_num, out_msgs=each_out_msgs_chn, out_locals=each_out_locals_chn, out_masks=2+each_out_locals_chn))
|
101 |
+
for i in range(len(feature_chns), len(out_locals_chn)):
|
102 |
+
#print("append extra block", i, "msg",
|
103 |
+
# out_msgs_chn[i], "local", out_locals_chn[i], "block", block_num[i])
|
104 |
+
self.decoder_blocks.append(
|
105 |
+
DecoderBlock(in_planes+self.in_msgs_chn[i]+self.in_locals_chn[i], c=block_chn_num[i], block_nums=block_num[i], out_msgs=out_msgs_chn[i], out_locals=out_locals_chn[i], out_masks=2+out_msgs_chn[i], out_feat_flows=0))
|
106 |
+
|
107 |
+
def apply_flow(self, mask, message, message_flow, local_message, local_message_flow, x_reference, accumulated_flow, each_x_reference_features=None, each_x_reference_features_flow=None):
|
108 |
+
if each_x_reference_features is not None:
|
109 |
+
size_from = each_x_reference_features
|
110 |
+
else:
|
111 |
+
size_from = x_reference
|
112 |
+
f_size = (size_from.shape[2], size_from.shape[3])
|
113 |
+
accumulated_flow = self.flow_rescale(
|
114 |
+
accumulated_flow, size_from)
|
115 |
+
# mask = warp_features(F.interpolate(
|
116 |
+
# mask, size=f_size, mode="bilinear"), accumulated_flow) if mask is not None else None
|
117 |
+
mask = F.interpolate(
|
118 |
+
mask, size=f_size, mode="bilinear") if mask is not None else None
|
119 |
+
message = F.interpolate(
|
120 |
+
message, size=f_size, mode="bilinear") if message is not None else None
|
121 |
+
message_flow = self.flow_rescale(
|
122 |
+
message_flow, size_from) if message_flow is not None else None
|
123 |
+
message = warp_features(
|
124 |
+
message, message_flow) if message_flow is not None else message
|
125 |
+
|
126 |
+
local_message = F.interpolate(
|
127 |
+
local_message, size=f_size, mode="bilinear") if local_message is not None else None
|
128 |
+
local_message_flow = self.flow_rescale(
|
129 |
+
local_message_flow, size_from) if local_message_flow is not None else None
|
130 |
+
local_message = warp_features(
|
131 |
+
local_message, local_message_flow) if local_message_flow is not None else local_message
|
132 |
+
|
133 |
+
warp_x_reference = warp_features(F.interpolate(
|
134 |
+
x_reference, size=f_size, mode="bilinear"), accumulated_flow)
|
135 |
+
|
136 |
+
each_x_reference_features_flow = self.flow_rescale(
|
137 |
+
each_x_reference_features_flow, size_from) if (each_x_reference_features is not None and each_x_reference_features_flow is not None) else None
|
138 |
+
warp_each_x_reference_features = warp_features(
|
139 |
+
each_x_reference_features, each_x_reference_features_flow) if each_x_reference_features_flow is not None else each_x_reference_features
|
140 |
+
|
141 |
+
return mask, message, local_message, warp_x_reference, accumulated_flow, warp_each_x_reference_features, each_x_reference_features_flow
|
142 |
+
|
143 |
+
def forward(self, x_target_features=[], x_reference=None, x_reference_features=[]):
|
144 |
+
y_flow = []
|
145 |
+
y_feat_flow = []
|
146 |
+
|
147 |
+
y_local_message = []
|
148 |
+
y_warp_x_reference = []
|
149 |
+
y_warp_x_reference_features = []
|
150 |
+
|
151 |
+
y_weighted_flow = []
|
152 |
+
y_weighted_mask = []
|
153 |
+
y_weighted_message = []
|
154 |
+
y_weighted_x_reference = []
|
155 |
+
y_weighted_x_reference_features = []
|
156 |
+
|
157 |
+
for pyrlevel, ifblock in enumerate(self.decoder_blocks):
|
158 |
+
stacked_wref = []
|
159 |
+
stacked_feat = []
|
160 |
+
stacked_anci = []
|
161 |
+
stacked_flow = []
|
162 |
+
stacked_mask = []
|
163 |
+
stacked_mesg = []
|
164 |
+
stacked_locm = []
|
165 |
+
stacked_feat_flow = []
|
166 |
+
for view_id in range(x_reference.shape[1]): # NMCHW
|
167 |
+
|
168 |
+
if pyrlevel == 0:
|
169 |
+
# create from zero flow
|
170 |
+
feat_ev = x_reference_features[pyrlevel][:,
|
171 |
+
view_id, :, :, :] if pyrlevel < len(x_reference_features) else None
|
172 |
+
|
173 |
+
accumulated_flow = torch.zeros_like(
|
174 |
+
feat_ev[:, :2, :, :]).to(device)
|
175 |
+
accumulated_feat_flow = torch.zeros_like(
|
176 |
+
feat_ev[:, :32, :, :]).to(device)
|
177 |
+
# domestic inputs
|
178 |
+
warp_x_reference = F.interpolate(x_reference[:, view_id, :, :, :], size=(
|
179 |
+
feat_ev.shape[-2], feat_ev.shape[-1]), mode="bilinear")
|
180 |
+
warp_x_reference_features = feat_ev
|
181 |
+
|
182 |
+
local_message = None
|
183 |
+
# federated inputs
|
184 |
+
weighted_flow = accumulated_flow if self.feed_weighted else None
|
185 |
+
weighted_wref = warp_x_reference if self.feed_weighted else None
|
186 |
+
weighted_message = None
|
187 |
+
else:
|
188 |
+
# resume from last layer
|
189 |
+
accumulated_flow = y_flow[-1][:, view_id, :, :, :]
|
190 |
+
accumulated_feat_flow = y_feat_flow[-1][:,
|
191 |
+
view_id, :, :, :] if y_feat_flow[-1] is not None else None
|
192 |
+
# domestic inputs
|
193 |
+
warp_x_reference = y_warp_x_reference[-1][:,
|
194 |
+
view_id, :, :, :]
|
195 |
+
warp_x_reference_features = y_warp_x_reference_features[-1][:,
|
196 |
+
                    view_id, :, :, :] if y_warp_x_reference_features[-1] is not None else None
                local_message = y_local_message[-1][:, view_id, :, :, :] if len(y_local_message) > 0 else None

                # federated inputs
                weighted_flow = y_weighted_flow[-1] if self.feed_weighted else None
                weighted_wref = y_weighted_x_reference[-1] if self.feed_weighted else None
                weighted_message = y_weighted_message[-1] if len(y_weighted_message) > 0 else None
                scaled_x_target = x_target_features[pyrlevel][:, :, :, :].detach() if pyrlevel < len(x_target_features) else None
                # compute flow
                residual_flow, mask, message, message_flow, local_message, local_message_flow, residual_feat_flow = ifblock(
                    accumulated_flow, scaled_x_target, warp_x_reference, warp_x_reference_features,
                    weighted_flow, weighted_wref, weighted_message, local_message)
                accumulated_flow = residual_flow + accumulated_flow
                accumulated_feat_flow = accumulated_flow

                feat_ev = x_reference_features[pyrlevel+1][:, view_id, :, :, :] if pyrlevel+1 < len(x_reference_features) else None
                mask, message, local_message, warp_x_reference, accumulated_flow, warp_x_reference_features, accumulated_feat_flow = self.apply_flow(
                    mask, message, message_flow, local_message, local_message_flow,
                    x_reference[:, view_id, :, :, :], accumulated_flow, feat_ev, accumulated_feat_flow)
                stacked_flow.append(accumulated_flow)
                if accumulated_feat_flow is not None:
                    stacked_feat_flow.append(accumulated_feat_flow)
                stacked_mask.append(mask)
                if message is not None:
                    stacked_mesg.append(message)
                if local_message is not None:
                    stacked_locm.append(local_message)
                stacked_wref.append(warp_x_reference)
                if warp_x_reference_features is not None:
                    stacked_feat.append(warp_x_reference_features)

            stacked_flow = torch.stack(stacked_flow, dim=1)  # M*NCHW -> NMCHW
            stacked_feat_flow = torch.stack(stacked_feat_flow, dim=1) if len(stacked_feat_flow) > 0 else None
            stacked_mask = torch.stack(stacked_mask, dim=1)

            stacked_mesg = torch.stack(stacked_mesg, dim=1) if len(stacked_mesg) > 0 else None
            stacked_locm = torch.stack(stacked_locm, dim=1) if len(stacked_locm) > 0 else None

            stacked_wref = torch.stack(stacked_wref, dim=1)
            stacked_feat = torch.stack(stacked_feat, dim=1) if len(stacked_feat) > 0 else None
            stacked_anci = torch.stack(stacked_anci, dim=1) if len(stacked_anci) > 0 else None
            y_flow.append(stacked_flow)
            y_feat_flow.append(stacked_feat_flow)

            y_warp_x_reference.append(stacked_wref)
            y_warp_x_reference_features.append(stacked_feat)
            # compute normalized confidence
            stacked_contrib = torch.nn.functional.softmax(stacked_mask, dim=1)

            # torch.sum to remove temp dimension M from NMCHW --> NCHW
            weighted_flow = torch.sum(
                stacked_mask[:, :, 0:1, :, :] * stacked_contrib[:, :, 0:1, :, :] * stacked_flow, dim=1)
            weighted_mask = torch.sum(
                stacked_contrib[:, :, 0:1, :, :] * stacked_mask[:, :, 0:1, :, :], dim=1)
            weighted_wref = torch.sum(
                stacked_mask[:, :, 0:1, :, :] * stacked_contrib[:, :, 0:1, :, :] * stacked_wref, dim=1) if stacked_wref is not None else None
            weighted_feat = torch.sum(
                stacked_mask[:, :, 1:2, :, :] * stacked_contrib[:, :, 1:2, :, :] * stacked_feat, dim=1) if stacked_feat is not None else None
            weighted_mesg = torch.sum(
                stacked_mask[:, :, 2:, :, :] * stacked_contrib[:, :, 2:, :, :] * stacked_mesg, dim=1) if stacked_mesg is not None else None
            y_weighted_flow.append(weighted_flow)
            y_weighted_mask.append(weighted_mask)
            if weighted_mesg is not None:
                y_weighted_message.append(weighted_mesg)
            if stacked_locm is not None:
                y_local_message.append(stacked_locm)
                y_weighted_message.append(weighted_mesg)
            y_weighted_x_reference.append(weighted_wref)
            y_weighted_x_reference_features.append(weighted_feat)

            if weighted_feat is not None:
                y_weighted_x_reference_features.append(weighted_feat)
        return {
            "y_last_remote_features": [weighted_mesg],
        }

    def flow_rescale(self, prev_flow, each_x_reference_features):
        if prev_flow is None:
            prev_flow = torch.zeros_like(each_x_reference_features[:, :2]).to(device)
        else:
            up_scale_factor = each_x_reference_features.shape[-1] / prev_flow.shape[-1]
            if up_scale_factor != 1:
                prev_flow = F.interpolate(prev_flow, scale_factor=up_scale_factor, mode="bilinear",
                                          align_corners=False, recompute_scale_factor=False) * up_scale_factor
        return prev_flow
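
A note on the pyramid logic above: flow_rescale upsamples a coarse flow field to the resolution of the current reference features and multiplies its values by the same factor, so displacements stay expressed in pixels of the new resolution. The following minimal sketch (not code from this commit; it assumes only torch is installed) shows that two-part rescale in isolation:

import torch
import torch.nn.functional as F

coarse_flow = torch.randn(1, 2, 32, 32)  # NCHW flow at a coarse pyramid level
fine_flow = F.interpolate(coarse_flow, scale_factor=2.0, mode="bilinear",
                          align_corners=False, recompute_scale_factor=False) * 2.0
assert fine_flow.shape == (1, 2, 64, 64)  # grid and flow values rescaled together
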
model/warplayer.py
ADDED
@@ -0,0 +1,56 @@
import torch
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
backwarp_tenGrid = {}


def warp(tenInput, tenFlow):
    with torch.cuda.amp.autocast(enabled=False):
        k = (str(tenFlow.device), str(tenFlow.size()))
        if k not in backwarp_tenGrid:
            tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
                1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
            tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
                1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
            backwarp_tenGrid[k] = torch.cat(
                [tenHorizontal, tenVertical], 1).to(device)

        tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
                             tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)

        g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
        if tenInput.dtype != g.dtype:
            g = g.to(tenInput.dtype)
        return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
        # "zeros" "border"


def warp_features(inp, flow):
    groups = flow.shape[1] // 2  # NCHW
    samples = inp.shape[0]
    h = inp.shape[2]
    w = inp.shape[3]
    assert flow.shape[0] == samples and flow.shape[2] == h and flow.shape[3] == w
    chns = inp.shape[1]
    chns_per_group = chns // groups
    assert flow.shape[1] % 2 == 0
    assert chns % groups == 0
    inp = inp.contiguous().view(samples*groups, chns_per_group, h, w)
    flow = flow.contiguous().view(samples*groups, 2, h, w)
    feat = warp(inp, flow)
    feat = feat.view(samples, chns, h, w)
    return feat


def flow2rgb(flow_map_np):
    h, w, _ = flow_map_np.shape
    rgb_map = np.ones((h, w, 3)).astype(np.float32) / 2.0
    normalized_flow_map = np.concatenate(
        (flow_map_np[:, :, 0:1] / h / 2.0, flow_map_np[:, :, 1:2] / w / 2.0), axis=2)
    rgb_map[:, :, 0] += normalized_flow_map[:, :, 0]
    rgb_map[:, :, 1] -= 0.5 * \
        (normalized_flow_map[:, :, 0] + normalized_flow_map[:, :, 1])
    rgb_map[:, :, 2] += normalized_flow_map[:, :, 1]
    return (rgb_map.clip(0, 1) * 255.0).astype(np.uint8)
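
A note on usage: warp() backward-warps an NCHW tensor by a 2-channel flow field via grid_sample, and warp_features() applies one 2-channel flow per channel group. A minimal sketch (not code from this commit; it assumes the repository root is on the import path):

import torch
from model.warplayer import device, warp, warp_features

feat = torch.randn(2, 8, 64, 64, device=device)   # N,C,H,W feature maps
flow = torch.zeros(2, 2, 64, 64, device=device)   # zero flow -> (near-)identity warp
out = warp(feat, flow)                             # same shape as feat
grouped = warp_features(feat, flow.repeat(1, 4, 1, 1))  # 4 flow groups x 2 channels each
assert out.shape == feat.shape and grouped.shape == feat.shape
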
streamlit.py
ADDED
@@ -0,0 +1,52 @@
import cv2
import numpy as np
import streamlit as st
import os
import base64

st.set_page_config(layout="wide", page_title='CoNR demo', page_icon="🪐")

st.title('CoNR demo')
st.markdown(""" <style>
#MainMenu {visibility: hidden;}
footer {visibility: hidden;}
</style> """, unsafe_allow_html=True)


def get_base64(bin_file):
    with open(bin_file, 'rb') as f:
        data = f.read()
    return base64.b64encode(data).decode()

# def set_background(png_file):
#     bin_str = get_base64(png_file)
#     page_bg_img = '''
#     <style>
#     .stApp {
#     background-image: url("data:image/png;base64,%s");
#     background-size: 1920px 1080px;
#     background-attachment:fixed;
#     background-position:center;
#     background-repeat:no-repeat;
#     }
#     </style>
#     ''' % bin_str
#     st.markdown(page_bg_img, unsafe_allow_html=True)

# set_background('ipad_bg.png')

upload_img = st.file_uploader("输入character sheet", "png", accept_multiple_files=True)  # label: "Input character sheet"

if st.button('RUN!'):
    if upload_img is not None:
        for i in range(len(upload_img)):
            with open('character_sheet/{}.png'.format(i), 'wb') as f:
                f.write(upload_img[i].read())

        st.info('努力推理中...')  # "Running inference, please wait..."
        os.system('sh infer.sh')
        st.info('Done!')
        video_file = open('output.mp4', 'rb')
        video_bytes = video_file.read()
        st.video(video_bytes, start_time=0)
    else:
        st.info('还没上传图片呢> <')  # "No images uploaded yet > <"
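
For reference, a script like this is normally served with streamlit run streamlit.py. The page writes the uploaded sheets to character_sheet/, shells out to infer.sh, and then plays back the resulting output.mp4.
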
train.py
ADDED
@@ -0,0 +1,229 @@
import argparse
import os
import time
from datetime import datetime
from distutils.util import strtobool

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from data_loader import (FileDataset,
                         RandomResizedCropWithAutoCenteringAndZeroPadding)
from torch.utils.data.distributed import DistributedSampler
from conr import CoNR


def data_sampler(dataset, shuffle, distributed):
    if distributed:
        return torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
    if shuffle:
        return torch.utils.data.RandomSampler(dataset)
    else:
        return torch.utils.data.SequentialSampler(dataset)


def save_output(image_name, inputs_v, d_dir=".", crop=None):
    import cv2

    inputs_v = inputs_v.detach().squeeze()
    input_np = torch.clamp(inputs_v*255, 0, 255).byte().cpu().numpy().transpose((1, 2, 0))
    # cv2.setNumThreads(1)
    out_render_scale = cv2.cvtColor(input_np, cv2.COLOR_RGBA2BGRA)
    if crop is not None:
        crop = crop.cpu().numpy()[0]
        output_img = np.zeros((crop[0], crop[1], 4), dtype=np.uint8)
        before_resize_scale = cv2.resize(
            out_render_scale, (crop[5]-crop[4]+crop[8]+crop[9], crop[3]-crop[2]+crop[6]+crop[7]),
            interpolation=cv2.INTER_AREA)  # w,h
        output_img[crop[2]:crop[3], crop[4]:crop[5]] = before_resize_scale[
            crop[6]:before_resize_scale.shape[0]-crop[7], crop[8]:before_resize_scale.shape[1]-crop[9]]
    else:
        output_img = out_render_scale
    cv2.imwrite(d_dir+"/"+image_name.split(os.sep)[-1]+'.png', output_img)


def test():
    source_names_list = []
    for name in sorted(os.listdir(args.test_input_person_images)):
        thissource = os.path.join(args.test_input_person_images, name)
        if os.path.isfile(thissource):
            source_names_list.append(thissource)
        if os.path.isdir(thissource):
            print("skipping empty folder :"+thissource)

    image_names_list = []
    for name in sorted(os.listdir(args.test_input_poses_images)):
        thistarget = os.path.join(args.test_input_poses_images, name)
        if os.path.isfile(thistarget):
            image_names_list.append([thistarget, *source_names_list])
        if os.path.isdir(thistarget):
            print("skipping folder :"+thistarget)
    print(image_names_list)

    print("---building models")
    conrmodel = CoNR(args)
    conrmodel.load_model(path=args.test_checkpoint_dir)
    conrmodel.dist()
    infer(args, conrmodel, image_names_list)

# def test():
#     source_names_list = []
#     for name in os.listdir(args.test_input_person_images):
#         thissource = os.path.join(args.test_input_person_images, name)
#         if os.path.isfile(thissource):
#             source_names_list.append([thissource])
#         if os.path.isdir(thissource):
#             toadd = [os.path.join(thissource, this_file)
#                      for this_file in os.listdir(thissource)]
#             if (toadd != []):
#                 source_names_list.append(toadd)
#             else:
#                 print("skipping empty folder :"+thissource)
#     image_names_list = []
#     for eachlist in source_names_list:
#         for name in sorted(os.listdir(args.test_input_poses_images)):
#             thistarget = os.path.join(args.test_input_poses_images, name)
#             if os.path.isfile(thistarget):
#                 image_names_list.append([thistarget, *eachlist])
#             if os.path.isdir(thistarget):
#                 print("skipping folder :"+thistarget)
#
#     print(image_names_list)
#     print("---building models...")
#     conrmodel = CoNR(args)
#     conrmodel.load_model(path=args.test_checkpoint_dir)
#     conrmodel.dist()
#     infer(args, conrmodel, image_names_list)


def infer(args, humanflowmodel, image_names_list):
    print("---test images: ", len(image_names_list))
    test_salobj_dataset = FileDataset(
        image_names_list=image_names_list,
        fg_img_lbl_transform=transforms.Compose([
            RandomResizedCropWithAutoCenteringAndZeroPadding(
                (args.dataloader_imgsize, args.dataloader_imgsize),
                scale=(1, 1), ratio=(1.0, 1.0), center_jitter=(0.0, 0.0)
            )]),
        shader_pose_use_gt_udp_test=not args.test_pose_use_parser_udp,
        shader_target_use_gt_rgb_debug=False
    )
    sampler = data_sampler(test_salobj_dataset, shuffle=False,
                           distributed=args.distributed)
    train_data = DataLoader(test_salobj_dataset,
                            batch_size=1,
                            shuffle=False, sampler=sampler,
                            num_workers=args.dataloaders)

    # start testing
    train_num = train_data.__len__()
    time_stamp = time.time()
    prev_frame_rgb = []
    prev_frame_a = []
    for i, data in enumerate(train_data):
        data_time_interval = time.time() - time_stamp
        time_stamp = time.time()
        with torch.no_grad():
            data["character_images"] = torch.cat(
                [data["character_images"], *prev_frame_rgb], dim=1)
            data["character_masks"] = torch.cat(
                [data["character_masks"], *prev_frame_a], dim=1)
            data = humanflowmodel.data_norm_image(data)
            pred = humanflowmodel.model_step(data, training=False)
            # remember to call humanflowmodel.reset_charactersheet() if you change character.

        train_time_interval = time.time() - time_stamp
        time_stamp = time.time()
        if i % 5 == 0 and args.local_rank == 0:
            print("[infer batch: %4d/%4d] time:%2f+%2f" % (
                i, train_num,
                data_time_interval, train_time_interval
            ))
        with torch.no_grad():
            if args.test_output_video:
                pred_img = pred["shader"]["y_weighted_warp_decoded_rgba"]
                save_output(
                    str(int(data["imidx"].cpu().item())), pred_img,
                    args.test_output_dir, crop=data["pose_crop"])

            if args.test_output_udp:
                pred_img = pred["shader"]["x_target_sudp_a"]
                save_output(
                    "udp_"+str(int(data["imidx"].cpu().item())), pred_img, args.test_output_dir)


def build_args():
    parser = argparse.ArgumentParser()
    # distributed learning settings
    parser.add_argument("--world_size", type=int, default=1,
                        help='world size')
    parser.add_argument("--local_rank", type=int, default=0,
                        help='local_rank, DON\'T change it')

    # model settings
    parser.add_argument('--dataloader_imgsize', type=int, default=256,
                        help='Input image size of the model')
    parser.add_argument('--batch_size', type=int, default=4,
                        help='minibatch size')
    parser.add_argument('--model_name', default='model_result',
                        help='Name of the experiment')
    parser.add_argument('--dataloaders', type=int, default=2,
                        help='Num of dataloaders')
    parser.add_argument('--mode', default="test", choices=['train', 'test'],
                        help='Training mode or Testing mode')

    # i/o settings
    parser.add_argument('--test_input_person_images',
                        type=str, default="./character_sheet/",
                        help='Directory of input character sheets')
    parser.add_argument('--test_input_poses_images', type=str,
                        default="./test_data/",
                        help='Directory of input UDP sequences or pose images')
    parser.add_argument('--test_checkpoint_dir', type=str,
                        default='./weights/',
                        help='Directory of model weights')
    parser.add_argument('--test_output_dir', type=str,
                        default="./results/",
                        help='Directory for output images')

    # output content settings
    parser.add_argument('--test_output_video', type=strtobool, default=True,
                        help='Whether to output the final result of CoNR; '
                             'images will be written to test_output_dir when True.')
    parser.add_argument('--test_output_udp', type=strtobool, default=False,
                        help='Whether to output the UDP generated by the UDP detector; '
                             'meaningful ONLY when test_input_poses_images contains '
                             'pose images rather than UDP sequences, and '
                             'test_pose_use_parser_udp is True.')

    # UDP detector settings
    parser.add_argument('--test_pose_use_parser_udp',
                        type=strtobool, default=False,
                        help='Whether to use the UDP detector to generate UDPs from PNGs; '
                             'the pose input MUST be pose images instead of UDP '
                             'sequences when True.')

    args = parser.parse_args()

    args.distributed = (args.world_size > 1)
    if args.local_rank == 0:
        print("batch_size:", args.batch_size, flush=True)
    if args.distributed:
        if args.local_rank == 0:
            print("world_size: ", args.world_size)
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://", world_size=args.world_size)
        torch.cuda.set_device(args.local_rank)
        torch.backends.cudnn.benchmark = True
    else:
        args.local_rank = 0

    return args


if __name__ == "__main__":
    args = build_args()
    test()
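
Going by the argument definitions above, a plain single-process inference run would look something like: python train.py --mode test --test_input_person_images ./character_sheet/ --test_input_poses_images ./test_data/ --test_checkpoint_dir ./weights/ --test_output_dir ./results/ (these paths are also the defaults). The infer.sh script included in this commit presumably wraps an equivalent command.
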