p2oileen committed
Commit c34ed4d · 1 Parent(s): 18b58f0

initial commit

Files changed (11)
  1. .gitignore +11 -0
  2. conr.py +292 -0
  3. data_loader.py +273 -0
  4. infer.sh +14 -0
  5. model/__init__.py +1 -0
  6. model/backbone.py +285 -0
  7. model/decoder_small.py +43 -0
  8. model/shader.py +290 -0
  9. model/warplayer.py +56 -0
  10. streamlit.py +52 -0
  11. train.py +229 -0
.gitignore ADDED
@@ -0,0 +1,11 @@
1
+ results/
2
+ test_data/
3
+ test_data_pre/
4
+ weights/
5
+ x264/
6
+ *.mp3
7
+ *.mp4
8
+ *.txt
9
+ *.png
10
+ complex_infer.sh
11
+ __pycache__/
conr.py ADDED
@@ -0,0 +1,292 @@
1
+
2
+ import os
3
+ import torch
4
+
5
+ from model.backbone import ResEncUnet
6
+
7
+ from model.shader import CINN
8
+ from model.decoder_small import RGBADecoderNet
9
+
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+
12
+
13
+ def UDPClip(x):
14
+ return torch.clamp(x, min=0, max=1) # NCHW
15
+
16
+
17
+ class CoNR():
18
+ def __init__(self, args):
19
+ self.args = args
20
+
21
+ self.udpparsernet = ResEncUnet(
22
+ backbone_name='resnet50_danbo',
23
+ classes=4,
24
+ pretrained=(args.local_rank == 0),
25
+ parametric_upsampling=True,
26
+ decoder_filters=(512, 384, 256, 128, 32),
27
+ map_location=device
28
+ )
29
+ self.target_pose_encoder = ResEncUnet(
30
+ backbone_name='resnet18_danbo-4',
31
+ classes=1,
32
+ pretrained=(args.local_rank == 0),
33
+ parametric_upsampling=True,
34
+ decoder_filters=(512, 384, 256, 128, 32),
35
+ map_location=device
36
+ )
37
+ self.DIM_SHADER_REFERENCE = 4
38
+ self.shader = CINN(self.DIM_SHADER_REFERENCE)
39
+ self.rgbadecodernet = RGBADecoderNet(
40
+ )
41
+ self.device()
42
+ self.parser_ckpt = None
43
+
44
+ def dist(self):
45
+ args = self.args
46
+ if args.distributed:
47
+ self.udpparsernet = torch.nn.parallel.DistributedDataParallel(
48
+ self.udpparsernet,
49
+ device_ids=[
50
+ args.local_rank],
51
+ output_device=args.local_rank,
52
+ broadcast_buffers=False,
53
+ find_unused_parameters=True
54
+ )
55
+ self.target_pose_encoder = torch.nn.parallel.DistributedDataParallel(
56
+ self.target_pose_encoder,
57
+ device_ids=[
58
+ args.local_rank],
59
+ output_device=args.local_rank,
60
+ broadcast_buffers=False,
61
+ find_unused_parameters=True
62
+ )
63
+ self.shader = torch.nn.parallel.DistributedDataParallel(
64
+ self.shader,
65
+ device_ids=[
66
+ args.local_rank],
67
+ output_device=args.local_rank,
68
+ broadcast_buffers=True
69
+ )
70
+
71
+ self.rgbadecodernet = torch.nn.parallel.DistributedDataParallel(
72
+ self.rgbadecodernet,
73
+ device_ids=[
74
+ args.local_rank],
75
+ output_device=args.local_rank,
76
+ broadcast_buffers=True
77
+ )
78
+
79
+ def load_model(self, path):
80
+ self.udpparsernet.load_state_dict(
81
+ torch.load('{}/udpparsernet.pth'.format(path), map_location=device))
82
+ self.target_pose_encoder.load_state_dict(
83
+ torch.load('{}/target_pose_encoder.pth'.format(path), map_location=device))
84
+ self.shader.load_state_dict(
85
+ torch.load('{}/shader.pth'.format(path), map_location=device))
86
+ self.rgbadecodernet.load_state_dict(
87
+ torch.load('{}/rgbadecodernet.pth'.format(path), map_location=device))
88
+
89
+ def save_model(self, ite_num):
90
+ self._save_pth(self.udpparsernet,
91
+ model_name="udpparsernet", ite_num=ite_num)
92
+ self._save_pth(self.target_pose_encoder,
93
+ model_name="target_pose_encoder", ite_num=ite_num)
94
+ self._save_pth(self.shader,
95
+ model_name="shader", ite_num=ite_num)
96
+ self._save_pth(self.rgbadecodernet,
97
+ model_name="rgbadecodernet", ite_num=ite_num)
98
+
99
+ def _save_pth(self, net, model_name, ite_num):
100
+ args = self.args
101
+ to_save = None
102
+ if args.distributed:
103
+ if args.local_rank == 0:
104
+ to_save = net.module.state_dict()
105
+ else:
106
+ to_save = net.state_dict()
107
+ if to_save:
108
+ model_dir = os.path.join(
109
+ os.getcwd(), 'saved_models', args.model_name + os.sep + "checkpoints" + os.sep + "itr_%d" % (ite_num)+os.sep)
110
+
111
+ os.makedirs(model_dir, exist_ok=True)
112
+ torch.save(to_save, model_dir + model_name + ".pth")
113
+
114
+ def train(self):
115
+ self.udpparsernet.train()
116
+ self.target_pose_encoder.train()
117
+ self.shader.train()
118
+ self.rgbadecodernet.train()
119
+
120
+ def eval(self):
121
+ self.udpparsernet.eval()
122
+ self.target_pose_encoder.eval()
123
+ self.shader.eval()
124
+ self.rgbadecodernet.eval()
125
+
126
+ def device(self):
127
+ self.udpparsernet.to(device)
128
+ self.target_pose_encoder.to(device)
129
+ self.shader.to(device)
130
+ self.rgbadecodernet.to(device)
131
+
132
+ def data_norm_image(self, data):
133
+
134
+ with torch.cuda.amp.autocast(enabled=False):
135
+ for name in ["character_labels", "pose_label"]:
136
+ if name in data:
137
+ data[name] = data[name].to(
138
+ device, non_blocking=True).float()
139
+ for name in ["pose_images", "pose_mask", "character_images", "character_masks"]:
140
+ if name in data:
141
+ data[name] = data[name].to(
142
+ device, non_blocking=True).float() / 255.0
143
+ if "pose_images" in data:
144
+ data["num_pose_images"] = data["pose_images"].shape[1]
145
+ data["num_samples"] = data["pose_images"].shape[0]
146
+ if "character_images" in data:
147
+ data["num_character_images"] = data["character_images"].shape[1]
148
+ data["num_samples"] = data["character_images"].shape[0]
149
+ if "pose_images" in data and "character_images" in data:
150
+ assert (data["pose_images"].shape[0] ==
151
+ data["character_images"].shape[0])
152
+ return data
153
+
154
+ def reset_charactersheet(self):
155
+ self.parser_ckpt = None
156
+
157
+ def model_step(self, data, training=False):
158
+ self.eval()
159
+ with torch.cuda.amp.autocast(enabled=False):
160
+ pred = {}
161
+ if self.parser_ckpt:
162
+ pred["parser"] = self.parser_ckpt
163
+ else:
164
+ pred = self.character_parser_forward(data, pred)
165
+ self.parser_ckpt = pred["parser"]
166
+ pred = self.pose_parser_sc_forward(data, pred)
167
+ pred = self.shader_pose_encoder_forward(data, pred)
168
+ pred = self.shader_forward(data, pred)
169
+ return pred
170
+
171
+ def shader_forward(self, data, pred={}):
172
+ assert ("num_character_images" in data), "ERROR: No Character Sheet input."
173
+
174
+ character_images_rgb_nmchw, num_character_images = data[
175
+ "character_images"], data["num_character_images"]
176
+ # build x_reference_rgb_a_sudp in the draw call
177
+ shader_character_a_nmchw = data["character_masks"]
178
+ assert torch.any(torch.mean(shader_character_a_nmchw, (0, 2, 3, 4)) >= 0.95) == False, "ERROR: \
179
+ No transparent area found in the image, PLEASE separate the foreground of input character sheets.\
180
+ The website waifucutout.com is recommended to automatically cut out the foreground."
181
+
182
+ if shader_character_a_nmchw is None:
183
+ shader_character_a_nmchw = pred["parser"]["pred"][:, :, 3:4, :, :]
184
+ x_reference_rgb_a = torch.cat([shader_character_a_nmchw[:, :, :, :, :] * character_images_rgb_nmchw[:, :, :, :, :],
185
+ shader_character_a_nmchw[:,
186
+ :, :, :, :],
187
+
188
+ ], 2)
189
+ assert (x_reference_rgb_a.shape[2] == self.DIM_SHADER_REFERENCE)
190
+ # build x_reference_features in the draw call
191
+ x_reference_features = pred["parser"]["features"]
192
+ # run cinn shader
193
+ retdic = self.shader(
194
+ pred["shader"]["target_pose_features"], x_reference_rgb_a, x_reference_features)
195
+ pred["shader"].update(retdic)
196
+
197
+ # decode rgba
198
+ if True:
199
+ dec_out = self.rgbadecodernet(
200
+ retdic["y_last_remote_features"])
201
+ y_weighted_x_reference_RGB = dec_out[:, 0:3, :, :]
202
+ y_weighted_mask_A = dec_out[:, 3:4, :, :]
203
+ y_weighted_warp_decoded_rgba = torch.cat(
204
+ (y_weighted_x_reference_RGB*y_weighted_mask_A, y_weighted_mask_A), dim=1
205
+ )
206
+ assert(y_weighted_warp_decoded_rgba.shape[1] == 4)
207
+ assert(
208
+ y_weighted_warp_decoded_rgba.shape[-1] == character_images_rgb_nmchw.shape[-1])
209
+ # apply decoded mask to decoded rgb, finishing the draw call
210
+ pred["shader"]["y_weighted_warp_decoded_rgba"] = y_weighted_warp_decoded_rgba
211
+ return pred
212
+
213
+ def character_parser_forward(self, data, pred={}):
214
+ if not("num_character_images" in data and "character_images" in data):
215
+ return pred
216
+ pred["parser"] = {"pred": None} # create output
217
+
218
+ inputs_rgb_nmchw, num_samples, num_character_images = data[
219
+ "character_images"], data["num_samples"], data["num_character_images"]
220
+ inputs_rgb_fchw = inputs_rgb_nmchw.view(
221
+ (num_samples * num_character_images, inputs_rgb_nmchw.shape[2], inputs_rgb_nmchw.shape[3], inputs_rgb_nmchw.shape[4]))
222
+
223
+ encoder_out, features = self.udpparsernet(
224
+ (inputs_rgb_fchw-0.6)/0.2970)
225
+
226
+ pred["parser"]["features"] = [features_out.view(
227
+ (num_samples, num_character_images, features_out.shape[1], features_out.shape[2], features_out.shape[3])) for features_out in features]
228
+
229
+ if (encoder_out is not None):
230
+
231
+ pred["parser"]["pred"] = UDPClip(encoder_out.view(
232
+ (num_samples, num_character_images, encoder_out.shape[1], encoder_out.shape[2], encoder_out.shape[3])))
233
+
234
+ return pred
235
+
236
+ def pose_parser_sc_forward(self, data, pred={}):
237
+ if not("num_pose_images" in data and "pose_images" in data):
238
+ return pred
239
+ inputs_aug_rgb_nmchw, num_samples, num_pose_images = data[
240
+ "pose_images"], data["num_samples"], data["num_pose_images"]
241
+ inputs_aug_rgb_fchw = inputs_aug_rgb_nmchw.view(
242
+ (num_samples * num_pose_images, inputs_aug_rgb_nmchw.shape[2], inputs_aug_rgb_nmchw.shape[3], inputs_aug_rgb_nmchw.shape[4]))
243
+
244
+ encoder_out, _ = self.udpparsernet(
245
+ (inputs_aug_rgb_fchw-0.6)/0.2970)
246
+
247
+ encoder_out = encoder_out.view(
248
+ (num_samples, num_pose_images, encoder_out.shape[1], encoder_out.shape[2], encoder_out.shape[3]))
249
+
250
+ # apply sigmoid after eval loss
251
+ pred["pose_parser"] = {"pred":UDPClip(encoder_out)}
252
+
253
+
254
+ return pred
255
+
256
+ def shader_pose_encoder_forward(self, data, pred={}):
257
+ pred["shader"] = {} # create output
258
+ if "pose_images" in data:
259
+ pose_images_rgb_nmchw = data["pose_images"]
260
+ target_gt_rgb = pose_images_rgb_nmchw[:, 0, :, :, :]
261
+ pred["shader"]["target_gt_rgb"] = target_gt_rgb
262
+
263
+ shader_target_a = None
264
+ if "pose_mask" in data:
265
+ pred["shader"]["target_gt_a"] = data["pose_mask"]
266
+ shader_target_a = data["pose_mask"]
267
+
268
+ shader_target_sudp = None
269
+ if "pose_label" in data:
270
+ shader_target_sudp = data["pose_label"][:, :3, :, :]
271
+
272
+ if self.args.test_pose_use_parser_udp:
273
+ shader_target_sudp = None
274
+ if shader_target_sudp is None:
275
+ shader_target_sudp = pred["pose_parser"]["pred"][:, 0:3, :, :]
276
+
277
+ if shader_target_a is None:
278
+ shader_target_a = pred["pose_parser"]["pred"][:, 3:4, :, :]
279
+
280
+ # build x_target_sudp_a in the draw call
281
+ x_target_sudp_a = torch.cat((
282
+ shader_target_sudp*shader_target_a,
283
+ shader_target_a
284
+ ), 1)
285
+ pred["shader"].update({
286
+ "x_target_sudp_a": x_target_sudp_a
287
+ })
288
+ _, features = self.traget_pose_encoder(
289
+ (x_target_sudp_a-0.6)/0.2970, ret_parser_out=False)
290
+
291
+ pred["shader"]["target_pose_features"] = features
292
+ return pred
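
A minimal inference sketch for the class above. Assumptions: the repository modules from this commit, pretrained weights under ./weights as used by infer.sh, and a hypothetical Namespace whose fields mirror train.py's argparse flags; tensor sizes are illustrative. The character sheets are parsed once and cached in parser_ckpt, then each pose frame goes through pose parsing, pose encoding and shading inside model_step:

from argparse import Namespace
import torch
from conr import CoNR

# Hypothetical stand-in for train.py's parsed arguments.
args = Namespace(local_rank=0, distributed=False,
                 test_pose_use_parser_udp=False, model_name="conr_demo")
model = CoNR(args)
model.load_model("./weights")   # expects udpparsernet.pth, target_pose_encoder.pth, shader.pth, rgbadecodernet.pth

# Dummy batch in the NMCHW / NCHW layouts that data_norm_image expects.
data = {
    "character_images": torch.zeros(1, 2, 3, 256, 256, dtype=torch.uint8),  # N samples, M sheet views, RGB
    "character_masks":  torch.zeros(1, 2, 1, 256, 256, dtype=torch.uint8),
    "pose_label":       torch.zeros(1, 4, 256, 256),                        # ground-truth UDP (XYZV), already float
    "pose_mask":        torch.zeros(1, 1, 256, 256, dtype=torch.uint8),
}
data = model.data_norm_image(data)
with torch.no_grad():
    pred = model.model_step(data, training=False)
rgba = pred["shader"]["y_weighted_warp_decoded_rgba"]   # N x 4 x H x W, values in [0, 1]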
data_loader.py ADDED
@@ -0,0 +1,273 @@
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ from torch.utils.data import Dataset
5
+ import os
6
+ cv2.setNumThreads(1)
7
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
8
+
9
+
10
+ class RandomResizedCropWithAutoCenteringAndZeroPadding (object):
11
+ def __init__(self, output_size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), center_jitter=(0.1, 0.1), size_from_alpha_mask=True):
12
+ assert isinstance(output_size, (int, tuple))
13
+ if isinstance(output_size, int):
14
+ self.output_size = (output_size, output_size)
15
+ else:
16
+ assert len(output_size) == 2
17
+ self.output_size = output_size
18
+ assert isinstance(scale, tuple)
19
+ assert isinstance(ratio, tuple)
20
+ if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
21
+ raise ValueError("Scale and ratio should be of kind (min, max)")
22
+ self.size_from_alpha_mask = size_from_alpha_mask
23
+ self.scale = scale
24
+ self.ratio = ratio
25
+ assert isinstance(center_jitter, tuple)
26
+ self.center_jitter = center_jitter
27
+
28
+ def __call__(self, sample):
29
+ imidx, image = sample['imidx'], sample["image_np"]
30
+ if "labels" in sample:
31
+ label = sample["labels"]
32
+ else:
33
+ label = None
34
+
35
+ im_h, im_w = image.shape[:2]
36
+ if self.size_from_alpha_mask and image.shape[2] == 4:
37
+ # compute bbox from alpha mask
38
+ bbox_left, bbox_top, bbox_w, bbox_h = cv2.boundingRect(
39
+ (image[:, :, 3] > 0).astype(np.uint8))
40
+ else:
41
+ bbox_left, bbox_top = 0, 0
42
+ bbox_h, bbox_w = image.shape[:2]
43
+ if bbox_h <= 1 and bbox_w <= 1:
44
+ sample["bad"] = 0
45
+ else:
46
+ # detect too small image here
47
+ alpha_varea = np.sum((image[:, :, 3] > 0).astype(np.uint8))
48
+ image_area = image.shape[0]*image.shape[1]
49
+ if alpha_varea/image_area < 0.001:
50
+ sample["bad"] = alpha_varea
51
+ # detect bad image
52
+ if "bad" in sample:
53
+ # baddata_dir = os.path.join(os.getcwd(), 'test_data', "baddata" + os.sep)
54
+ # save_output(str(imidx)+".png",image,label,baddata_dir)
55
+ bbox_h, bbox_w = image.shape[:2]
56
+ sample["image_np"] = np.zeros(
57
+ [self.output_size[0], self.output_size[1], image.shape[2]], dtype=image.dtype)
58
+ if label is not None:
59
+ sample["labels"] = np.zeros(
60
+ [self.output_size[0], self.output_size[1], 4], dtype=label.dtype)
61
+
62
+ return sample
63
+
64
+ # compute default area by making sure output_size contains bbox_w * bbox_h
65
+
66
+ jitter_h = np.random.uniform(-bbox_h *
67
+ self.center_jitter[0], bbox_h*self.center_jitter[0])
68
+ jitter_w = np.random.uniform(-bbox_w *
69
+ self.center_jitter[1], bbox_w*self.center_jitter[1])
70
+
71
+ # h/w
72
+ target_aspect_ratio = np.exp(
73
+ np.log(self.output_size[0]/self.output_size[1]) +
74
+ np.random.uniform(np.log(self.ratio[0]), np.log(self.ratio[1]))
75
+ )
76
+
77
+ source_aspect_ratio = bbox_h/bbox_w
78
+
79
+ if target_aspect_ratio < source_aspect_ratio:
80
+ # same w, target has larger h, use h to align
81
+ target_height = bbox_h * \
82
+ np.random.uniform(self.scale[0], self.scale[1])
83
+ virtual_h = int(
84
+ round(target_height))
85
+ virtual_w = int(
86
+ round(target_height / target_aspect_ratio)) # h/w
87
+ else:
88
+ # same w, source has larger h, use w to align
89
+ target_width = bbox_w * \
90
+ np.random.uniform(self.scale[0], self.scale[1])
91
+ virtual_h = int(
92
+ round(target_width * target_aspect_ratio)) # h/w
93
+ virtual_w = int(
94
+ round(target_width))
95
+
96
+ # print("required aspect ratio:", target_aspect_ratio)
97
+
98
+ virtual_top = int(round(bbox_top + jitter_h - (virtual_h-bbox_h)/2))
99
+ virtual_left = int(round(bbox_left + jitter_w - (virtual_w-bbox_w)/2))
100
+
101
+ if virtual_top < 0:
102
+ top_padding = abs(virtual_top)
103
+ crop_top = 0
104
+ else:
105
+ top_padding = 0
106
+ crop_top = virtual_top
107
+ if virtual_left < 0:
108
+ left_padding = abs(virtual_left)
109
+ crop_left = 0
110
+ else:
111
+ left_padding = 0
112
+ crop_left = virtual_left
113
+ if virtual_top+virtual_h > im_h:
114
+ bottom_padding = abs(im_h-(virtual_top+virtual_h))
115
+ crop_bottom = im_h
116
+ else:
117
+ bottom_padding = 0
118
+ crop_bottom = virtual_top+virtual_h
119
+ if virtual_left+virtual_w > im_w:
120
+ right_padding = abs(im_w-(virtual_left+virtual_w))
121
+ crop_right = im_w
122
+ else:
123
+ right_padding = 0
124
+ crop_right = virtual_left+virtual_w
125
+ # crop
126
+
127
+ image = image[crop_top:crop_bottom, crop_left: crop_right]
128
+ if label is not None:
129
+ label = label[crop_top:crop_bottom, crop_left: crop_right]
130
+
131
+ # pad
132
+ if top_padding + bottom_padding + left_padding + right_padding > 0:
133
+ padding = ((top_padding, bottom_padding),
134
+ (left_padding, right_padding), (0, 0))
135
+ # print("padding", padding)
136
+ image = np.pad(image, padding, mode='constant')
137
+ if label is not None:
138
+ label = np.pad(label, padding, mode='constant')
139
+
140
+ if image.shape[0]/image.shape[1] - virtual_h/virtual_w > 0.001:
141
+ print("virtual aspect ratio:", virtual_h/virtual_w)
142
+ print("image aspect ratio:", image.shape[0]/image.shape[1])
143
+ assert (image.shape[0]/image.shape[1] - virtual_h/virtual_w < 0.001)
144
+ sample["crop"] = np.array(
145
+ [im_h, im_w, crop_top, crop_bottom, crop_left, crop_right, top_padding, bottom_padding, left_padding, right_padding, image.shape[0], image.shape[1]])
146
+
147
+ # resize
148
+ if self.output_size[1] != image.shape[1] or self.output_size[0] != image.shape[0]:
149
+ if self.output_size[1] > image.shape[1] and self.output_size[0] > image.shape[0]:
150
+ # enlarging
151
+ image = cv2.resize(
152
+ image, (self.output_size[1], self.output_size[0]), interpolation=cv2.INTER_LINEAR)
153
+ else:
154
+ # shrinking
155
+ image = cv2.resize(
156
+ image, (self.output_size[1], self.output_size[0]), interpolation=cv2.INTER_AREA)
157
+
158
+ if label is not None:
159
+ label = cv2.resize(label, (self.output_size[1], self.output_size[0]),
160
+ interpolation=cv2.INTER_NEAREST_EXACT)
161
+
162
+ assert image.shape[0] == self.output_size[0] and image.shape[1] == self.output_size[1]
163
+ sample['imidx'], sample["image_np"] = imidx, image
164
+ if label is not None:
165
+ assert label.shape[0] == self.output_size[0] and label.shape[1] == self.output_size[1]
166
+ sample["labels"] = label
167
+
168
+ return sample
169
+
170
+
171
+ class FileDataset(Dataset):
172
+ def __init__(self, image_names_list, fg_img_lbl_transform=None, shader_pose_use_gt_udp_test=True, shader_target_use_gt_rgb_debug=False):
173
+ self.image_name_list = image_names_list
174
+ self.fg_img_lbl_transform = fg_img_lbl_transform
175
+ self.shader_pose_use_gt_udp_test = shader_pose_use_gt_udp_test
176
+ self.shader_target_use_gt_rgb_debug = shader_target_use_gt_rgb_debug
177
+
178
+ def __len__(self):
179
+ return len(self.image_name_list)
180
+
181
+ def get_gt_from_disk(self, idx, imname, read_label):
182
+ if read_label:
183
+ # read label
184
+ with open(imname, mode="rb") as bio:
185
+ if imname.find(".npz") > 0:
186
+ label_np = np.load(bio, allow_pickle=True)[
187
+ 'i'].astype(np.float32, copy=False)
188
+ else:
189
+ label_np = cv2.cvtColor(cv2.imdecode(np.frombuffer(bio.read(
190
+ ), np.uint8), cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH | cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA)
191
+ assert (4 == label_np.shape[2])
192
+ # fake image out of valid label
193
+ image_np = (label_np*255).clip(0, 255).astype(np.uint8, copy=False)
194
+ # assemble sample
195
+ sample = {'imidx': np.array(
196
+ [idx]), "image_np": image_np, "labels": label_np}
197
+
198
+ else:
199
+ # read image as unit8
200
+ with open(imname, mode="rb") as bio:
201
+ image_np = cv2.cvtColor(cv2.imdecode(np.frombuffer(
202
+ bio.read(), np.uint8), cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA)
203
+ # image_np = Image.open(bio)
204
+ # image_np = np.array(image_np)
205
+ assert (3 == len(image_np.shape))
206
+ if (image_np.shape[2] == 4):
207
+ mask_np = image_np[:, :, 3:4]
208
+ image_np = (image_np[:, :, :3] *
209
+ (image_np[:, :, 3][:, :, np.newaxis]/255.0)).clip(0, 255).astype(np.uint8, copy=False)
210
+ elif (image_np.shape[2] == 3):
211
+ # generate a fake mask
212
+ # Fool-proofing
213
+ mask_np = np.ones(
214
+ (image_np.shape[0], image_np.shape[1], 1), dtype=np.uint8)*255
215
+ print("WARN: transparent background is preferred for image ", imname)
216
+ else:
217
+ raise ValueError("weird shape of image ", imname, image_np)
218
+ image_np = np.concatenate((image_np, mask_np), axis=2)
219
+ sample = {'imidx': np.array(
220
+ [idx]), "image_np": image_np}
221
+
222
+ # apply fg_img_lbl_transform
223
+ if self.fg_img_lbl_transform:
224
+ sample = self.fg_img_lbl_transform(sample)
225
+
226
+ if "labels" in sample:
227
+ # return UDP as 4chn XYZV float tensor
228
+ sample["labels"] = torch.from_numpy(
229
+ sample["labels"].transpose((2, 0, 1)))
230
+ assert (sample["labels"].dtype == torch.float32)
231
+
232
+ if "image_np" in sample:
233
+ # return image as 3chn RGB uint8 tensor and 1chn A uint8 tensor
234
+ sample["mask"] = torch.from_numpy(
235
+ sample["image_np"][:, :, 3:4].transpose((2, 0, 1)))
236
+ assert (sample["mask"].dtype == torch.uint8)
237
+ sample["image"] = torch.from_numpy(
238
+ sample["image_np"][:, :, :3].transpose((2, 0, 1)))
239
+
240
+ assert (sample["image"].dtype == torch.uint8)
241
+ del sample["image_np"]
242
+ return sample
243
+
244
+ def __getitem__(self, idx):
245
+ sample = {
246
+ 'imidx': np.array([idx])}
247
+ target = self.get_gt_from_disk(
248
+ idx, imname=self.image_name_list[idx][0], read_label=self.shader_pose_use_gt_udp_test)
249
+ if self.shader_target_use_gt_rgb_debug:
250
+ sample["pose_images"] = torch.stack([target["image"]])
251
+ sample["pose_mask"] = target["mask"]
252
+ elif self.shader_pose_use_gt_udp_test:
253
+ sample["pose_label"] = target["labels"]
254
+ sample["pose_mask"] = target["mask"]
255
+ else:
256
+ sample["pose_images"] = torch.stack([target["image"]])
257
+ if "crop" in target:
258
+ sample["pose_crop"] = target["crop"]
259
+ character_images = []
260
+ character_masks = []
261
+ for i in range(1, len(self.image_name_list[idx])):
262
+ source = self.get_gt_from_disk(
263
+ idx, self.image_name_list[idx][i], read_label=False)
264
+ character_images.append(source["image"])
265
+ character_masks.append(source["mask"])
266
+ character_images = torch.stack(character_images)
267
+ character_masks = torch.stack(character_masks)
268
+ sample.update({
269
+ "character_images": character_images,
270
+ "character_masks": character_masks
271
+ })
272
+ # do not make fake labels in inference
273
+ return sample
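
A small self-contained sketch of what the transform above returns, using a synthetic RGBA array and the deterministic settings (scale=(1, 1), no jitter) that train.py passes at inference time:

import numpy as np
from data_loader import RandomResizedCropWithAutoCenteringAndZeroPadding

# Synthetic 480x320 RGBA frame with an opaque rectangle as the "foreground".
rgba = np.zeros((480, 320, 4), dtype=np.uint8)
rgba[100:300, 80:240, :3] = 255
rgba[100:300, 80:240, 3] = 255

crop = RandomResizedCropWithAutoCenteringAndZeroPadding(
    (256, 256), scale=(1, 1), ratio=(1.0, 1.0), center_jitter=(0.0, 0.0))
sample = crop({'imidx': np.array([0]), 'image_np': rgba})
print(sample['image_np'].shape)  # (256, 256, 4): alpha-tight crop, zero-padded and resized
print(sample['crop'])            # bookkeeping later used by train.py's save_output to undo the crop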
infer.sh ADDED
@@ -0,0 +1,14 @@
1
+ rm -r "./results"
2
+ mkdir "./results"
3
+
4
+ rlaunch --gpu=1 --cpu=4 --memory=25600 -- python3 -m torch.distributed.launch \
5
+ --nproc_per_node=1 train.py --mode=test \
6
+ --world_size=1 --dataloaders=2 \
7
+ --test_input_poses_images=./test_data/ \
8
+ --test_input_person_images=./character_sheet/ \
9
+ --test_output_dir=./results/ \
10
+ --test_checkpoint_dir=./weights/
11
+
12
+ echo Generating video...
13
+ ffmpeg -r 30 -y -i ./results/%d.png -r 30 -c:v libx264 output.mp4 -r 30
14
+ echo DONE.
model/__init__.py ADDED
@@ -0,0 +1 @@
1
+
model/backbone.py ADDED
@@ -0,0 +1,285 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import models
4
+ from torch.nn import functional as F
5
+
6
+ import torch.nn as nn
7
+ import torch
8
+ from torchvision import models
9
+
10
+
11
+ class AdaptiveConcatPool2d(nn.Module):
12
+ """
13
+ Layer that concats `AdaptiveAvgPool2d` and `AdaptiveMaxPool2d`.
14
+ Source: Fastai. This code was taken from the fastai library at url
15
+ https://github.com/fastai/fastai/blob/master/fastai/layers.py#L176
16
+ """
17
+
18
+ def __init__(self, sz=None):
19
+ "Output will be 2*sz or 2 if sz is None"
20
+ super().__init__()
21
+ self.output_size = sz or 1
22
+ self.ap = nn.AdaptiveAvgPool2d(self.output_size)
23
+ self.mp = nn.AdaptiveMaxPool2d(self.output_size)
24
+
25
+ def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1)
26
+
27
+
28
+ class MyNorm(nn.Module):
29
+ def __init__(self, num_channels):
30
+ super(MyNorm, self).__init__()
31
+ self.norm = nn.InstanceNorm2d(
32
+ num_channels, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
33
+
34
+ def forward(self, x):
35
+ x = self.norm(x)
36
+ return x
37
+
38
+
39
+ def resnet_fastai(model, pretrained, url, replace_first_layer=None, replace_maxpool_layer=None, progress=True, map_location=None, **kwargs):
40
+ cut = -2
41
+ s = model(pretrained=False, **kwargs)
42
+ if replace_maxpool_layer is not None:
43
+ s.maxpool = replace_maxpool_layer
44
+ if replace_first_layer is not None:
45
+ body = nn.Sequential(replace_first_layer, *list(s.children())[1:cut])
46
+ else:
47
+ body = nn.Sequential(*list(s.children())[:cut])
48
+
49
+ if pretrained:
50
+ state = torch.hub.load_state_dict_from_url(url,
51
+ progress=progress, map_location=map_location)
52
+ if replace_first_layer is not None:
53
+ for each in list(state.keys()).copy():
54
+ if each.find("0.0.") == 0:
55
+ del state[each]
56
+ body_tail = nn.Sequential(body)
57
+ ret = body_tail.load_state_dict(state, strict=False)
58
+ return body
59
+
60
+
61
+ def get_backbone(name, pretrained=True, map_location=None):
62
+ """ Loading backbone, defining names for skip-connections and encoder output. """
63
+
64
+ first_layer_for_4chn = nn.Conv2d(
65
+ 4, 64, kernel_size=7, stride=2, padding=3, bias=False)
66
+ max_pool_layer_replace = nn.Conv2d(
67
+ 64, 64, kernel_size=3, stride=2, padding=1, bias=False)
68
+ # loading backbone model
69
+ if name == 'resnet18':
70
+ backbone = models.resnet18(pretrained=pretrained)
71
+ elif name == 'resnet18-4':
72
+ backbone = models.resnet18(pretrained=pretrained)
73
+ backbone.conv1 = first_layer_for_4chn
74
+ elif name == 'resnet34':
75
+ backbone = models.resnet34(pretrained=pretrained)
76
+ elif name == 'resnet50':
77
+ backbone = models.resnet50(pretrained=False, norm_layer=MyNorm)
78
+ backbone.maxpool = max_pool_layer_replace
79
+ elif name == 'resnet101':
80
+ backbone = models.resnet101(pretrained=pretrained)
81
+ elif name == 'resnet152':
82
+ backbone = models.resnet152(pretrained=pretrained)
83
+ elif name == 'vgg16':
84
+ backbone = models.vgg16_bn(pretrained=pretrained).features
85
+ elif name == 'vgg19':
86
+ backbone = models.vgg19_bn(pretrained=pretrained).features
87
+ elif name == 'resnet18_danbo-4':
88
+ backbone = resnet_fastai(models.resnet18, url="https://github.com/RF5/danbooru-pretrained/releases/download/v0.1/resnet18-3f77756f.pth",
89
+ pretrained=pretrained, map_location=map_location, norm_layer=MyNorm, replace_first_layer=first_layer_for_4chn)
90
+ elif name == 'resnet50_danbo':
91
+ backbone = resnet_fastai(models.resnet50, url="https://github.com/RF5/danbooru-pretrained/releases/download/v0.1/resnet50-13306192.pth",
92
+ pretrained=pretrained, map_location=map_location, norm_layer=MyNorm, replace_maxpool_layer=max_pool_layer_replace)
93
+ elif name == 'densenet121':
94
+ backbone = models.densenet121(pretrained=True).features
95
+ elif name == 'densenet161':
96
+ backbone = models.densenet161(pretrained=True).features
97
+ elif name == 'densenet169':
98
+ backbone = models.densenet169(pretrained=True).features
99
+ elif name == 'densenet201':
100
+ backbone = models.densenet201(pretrained=True).features
101
+ else:
102
+ raise NotImplemented(
103
+ '{} backbone model is not implemented so far.'.format(name))
104
+ #print(backbone)
105
+ # specifying skip feature and output names
106
+ if name.startswith('resnet'):
107
+ feature_names = [None, 'relu', 'layer1', 'layer2', 'layer3']
108
+ backbone_output = 'layer4'
109
+ elif name == 'vgg16':
110
+ # TODO: consider using a 'bridge' for VGG models, there is just a MaxPool between last skip and backbone output
111
+ feature_names = ['5', '12', '22', '32', '42']
112
+ backbone_output = '43'
113
+ elif name == 'vgg19':
114
+ feature_names = ['5', '12', '25', '38', '51']
115
+ backbone_output = '52'
116
+ elif name.startswith('densenet'):
117
+ feature_names = [None, 'relu0', 'denseblock1',
118
+ 'denseblock2', 'denseblock3']
119
+ backbone_output = 'denseblock4'
120
+ elif name == 'unet_encoder':
121
+ feature_names = ['module1', 'module2', 'module3', 'module4']
122
+ backbone_output = 'module5'
123
+ else:
124
+ raise NotImplemented(
125
+ '{} backbone model is not implemented so far.'.format(name))
126
+ if name.find('_danbo') > 0:
127
+ feature_names = [None, '2', '4', '5', '6']
128
+ backbone_output = '7'
129
+ return backbone, feature_names, backbone_output
130
+
131
+
132
+ class UpsampleBlock(nn.Module):
133
+
134
+ # TODO: separate parametric and non-parametric classes?
135
+ # TODO: skip connection concatenated OR added
136
+
137
+ def __init__(self, ch_in, ch_out=None, skip_in=0, use_bn=True, parametric=False):
138
+ super(UpsampleBlock, self).__init__()
139
+
140
+ self.parametric = parametric
141
+ ch_out = ch_in/2 if ch_out is None else ch_out
142
+
143
+ # first convolution: either transposed conv, or conv following the skip connection
144
+ if parametric:
145
+ # versions: kernel=4 padding=1, kernel=2 padding=0
146
+ self.up = nn.ConvTranspose2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(4, 4),
147
+ stride=2, padding=1, output_padding=0, bias=(not use_bn))
148
+ self.bn1 = MyNorm(ch_out) if use_bn else None
149
+ else:
150
+ self.up = None
151
+ ch_in = ch_in + skip_in
152
+ self.conv1 = nn.Conv2d(in_channels=ch_in, out_channels=ch_out, kernel_size=(3, 3),
153
+ stride=1, padding=1, bias=(not use_bn))
154
+ self.bn1 = MyNorm(ch_out) if use_bn else None
155
+
156
+ self.relu = nn.ReLU(inplace=True)
157
+
158
+ # second convolution
159
+ conv2_in = ch_out if not parametric else ch_out + skip_in
160
+ self.conv2 = nn.Conv2d(in_channels=conv2_in, out_channels=ch_out, kernel_size=(3, 3),
161
+ stride=1, padding=1, bias=(not use_bn))
162
+ self.bn2 = MyNorm(ch_out) if use_bn else None
163
+
164
+ def forward(self, x, skip_connection=None):
165
+
166
+ x = self.up(x) if self.parametric else F.interpolate(x, size=None, scale_factor=2, mode='bilinear',
167
+ align_corners=None)
168
+ if self.parametric:
169
+ x = self.bn1(x) if self.bn1 is not None else x
170
+ x = self.relu(x)
171
+
172
+ if skip_connection is not None:
173
+ x = torch.cat([x, skip_connection], dim=1)
174
+
175
+ if not self.parametric:
176
+ x = self.conv1(x)
177
+ x = self.bn1(x) if self.bn1 is not None else x
178
+ x = self.relu(x)
179
+ x = self.conv2(x)
180
+ x = self.bn2(x) if self.bn2 is not None else x
181
+ x = self.relu(x)
182
+
183
+ return x
184
+
185
+
186
+ class ResEncUnet(nn.Module):
187
+
188
+ """ U-Net (https://arxiv.org/pdf/1505.04597.pdf) implementation with pre-trained torchvision backbones."""
189
+
190
+ def __init__(self,
191
+ backbone_name,
192
+ pretrained=True,
193
+ encoder_freeze=False,
194
+ classes=21,
195
+ decoder_filters=(512, 256, 128, 64, 32),
196
+ parametric_upsampling=True,
197
+ shortcut_features='default',
198
+ decoder_use_instancenorm=True,
199
+ map_location=None
200
+ ):
201
+ super(ResEncUnet, self).__init__()
202
+
203
+ self.backbone_name = backbone_name
204
+
205
+ self.backbone, self.shortcut_features, self.bb_out_name = get_backbone(
206
+ backbone_name, pretrained=pretrained, map_location=map_location)
207
+ shortcut_chs, bb_out_chs = self.infer_skip_channels()
208
+ if shortcut_features != 'default':
209
+ self.shortcut_features = shortcut_features
210
+
211
+ # build decoder part
212
+ self.upsample_blocks = nn.ModuleList()
213
+ # avoiding having more blocks than skip connections
214
+ decoder_filters = decoder_filters[:len(self.shortcut_features)]
215
+ decoder_filters_in = [bb_out_chs] + list(decoder_filters[:-1])
216
+ num_blocks = len(self.shortcut_features)
217
+ for i, [filters_in, filters_out] in enumerate(zip(decoder_filters_in, decoder_filters)):
218
+ self.upsample_blocks.append(UpsampleBlock(filters_in, filters_out,
219
+ skip_in=shortcut_chs[num_blocks-i-1],
220
+ parametric=parametric_upsampling,
221
+ use_bn=decoder_use_instancenorm))
222
+ self.final_conv = nn.Conv2d(
223
+ decoder_filters[-1], classes, kernel_size=(1, 1))
224
+
225
+ if encoder_freeze:
226
+ self.freeze_encoder()
227
+
228
+ def freeze_encoder(self):
229
+ """ Freezing encoder parameters, the newly initialized decoder parameters are remaining trainable. """
230
+
231
+ for param in self.backbone.parameters():
232
+ param.requires_grad = False
233
+
234
+ def forward(self, *input, ret_parser_out=True):
235
+ """ Forward propagation in U-Net. """
236
+
237
+ x, features = self.forward_backbone(*input)
238
+ output_feature = [x]
239
+ for skip_name, upsample_block in zip(self.shortcut_features[::-1], self.upsample_blocks):
240
+ skip_features = features[skip_name]
241
+ if skip_features is not None:
242
+ output_feature.append(skip_features)
243
+ if ret_parser_out:
244
+ x = upsample_block(x, skip_features)
245
+ if ret_parser_out:
246
+ x = self.final_conv(x)
247
+ # apply sigmoid later
248
+ else:
249
+ x = None
250
+
251
+ return x, output_feature
252
+
253
+ def forward_backbone(self, x):
254
+ """ Forward propagation in backbone encoder network. """
255
+
256
+ features = {None: None} if None in self.shortcut_features else dict()
257
+ for name, child in self.backbone.named_children():
258
+ x = child(x)
259
+ if name in self.shortcut_features:
260
+ features[name] = x
261
+ if name == self.bb_out_name:
262
+ break
263
+
264
+ return x, features
265
+
266
+ def infer_skip_channels(self):
267
+ """ Getting the number of channels at skip connections and at the output of the encoder. """
268
+ if self.backbone_name.find("-4") > 0:
269
+ x = torch.zeros(1, 4, 224, 224)
270
+ else:
271
+ x = torch.zeros(1, 3, 224, 224)
272
+ has_fullres_features = self.backbone_name.startswith(
273
+ 'vgg') or self.backbone_name == 'unet_encoder'
274
+ # only VGG has features at full resolution
275
+ channels = [] if has_fullres_features else [0]
276
+
277
+ # forward run in backbone to count channels (dirty solution but works for *any* Module)
278
+ for name, child in self.backbone.named_children():
279
+ x = child(x)
280
+ if name in self.shortcut_features:
281
+ channels.append(x.shape[1])
282
+ if name == self.bb_out_name:
283
+ out_channels = x.shape[1]
284
+ break
285
+ return channels, out_channels
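
A quick shape check for the U-Net wrapper above, a sketch only: the plain 'resnet50' branch is chosen so nothing is downloaded (the shipped checkpoints use the 'resnet50_danbo' and 'resnet18_danbo-4' variants, which share the same feature layout), and the input size is illustrative:

import torch
from model.backbone import ResEncUnet

net = ResEncUnet(backbone_name='resnet50', classes=4, pretrained=False,
                 parametric_upsampling=True,
                 decoder_filters=(512, 384, 256, 128, 32))
x = torch.zeros(1, 3, 256, 256)
udp, feats = net(x)                     # ret_parser_out=True also runs the decoder head
print(udp.shape)                        # torch.Size([1, 4, 256, 256])
print([f.shape[1] for f in feats])      # [2048, 1024, 512, 256, 64], deepest level first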
model/decoder_small.py ADDED
@@ -0,0 +1,43 @@
1
+
2
+
3
+ from torch import nn
4
+ import torch.nn.functional as F
5
+ import torch
6
+
7
+
8
+ class ResBlock2d(nn.Module):
9
+ def __init__(self, in_features, kernel_size, padding):
10
+ super(ResBlock2d, self).__init__()
11
+ self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
12
+ padding=padding)
13
+ self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
14
+ padding=padding)
15
+
16
+ self.norm1 = nn.Conv2d(
17
+ in_channels=in_features, out_channels=in_features, kernel_size=1)
18
+ self.norm2 = nn.Conv2d(
19
+ in_channels=in_features, out_channels=in_features, kernel_size=1)
20
+
21
+ def forward(self, x):
22
+ out = self.norm1(x)
23
+ out = F.relu(out, inplace=True)
24
+ out = self.conv1(out)
25
+ out = self.norm2(out)
26
+ out = F.relu(out, inplace=True)
27
+ out = self.conv2(out)
28
+ out += x
29
+ return out
30
+
31
+
32
+ class RGBADecoderNet(nn.Module):
33
+ def __init__(self, c=64, out_planes=4, num_bottleneck_blocks=1):
34
+ super(RGBADecoderNet, self).__init__()
35
+ self.conv_rgba = nn.Sequential(nn.Conv2d(c, out_planes, kernel_size=3, stride=1,
36
+ padding=1, dilation=1, bias=True))
37
+ self.bottleneck = torch.nn.Sequential()
38
+ for i in range(num_bottleneck_blocks):
39
+ self.bottleneck.add_module(
40
+ 'r' + str(i), ResBlock2d(c, kernel_size=(3, 3), padding=(1, 1)))
41
+
42
+ def forward(self, features_weighted_mask_atfeaturesscale_list=[]):
43
+ return torch.sigmoid(self.conv_rgba(self.bottleneck(features_weighted_mask_atfeaturesscale_list.pop(0))))
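
The decoder turns the shader's last 64-channel feature map into a sigmoid RGBA image at the same resolution. A tiny sketch of that contract (made-up shapes; note that forward pops from the list it receives):

import torch
from model.decoder_small import RGBADecoderNet

dec = RGBADecoderNet(c=64, out_planes=4)
feats = [torch.randn(1, 64, 128, 128)]   # stands in for pred["shader"]["y_last_remote_features"]
rgba = dec(feats)
print(rgba.shape)                        # torch.Size([1, 4, 128, 128]), values in (0, 1)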
model/shader.py ADDED
@@ -0,0 +1,290 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from .warplayer import warp_features
5
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
6
+
7
+
8
+ class DecoderBlock(nn.Module):
9
+ def __init__(self, in_planes, c=224, out_msgs=0, out_locals=0, block_nums=1, out_masks=1, out_local_flows=32, out_msgs_flows=32, out_feat_flows=0):
10
+
11
+ super(DecoderBlock, self).__init__()
12
+ self.conv0 = nn.Sequential(
13
+ nn.Conv2d(in_planes, c, 3, 2, 1),
14
+ nn.PReLU(c),
15
+ nn.Conv2d(c, c, 3, 2, 1),
16
+ nn.PReLU(c),
17
+ )
18
+
19
+ self.convblocks = nn.ModuleList()
20
+ for i in range(block_nums):
21
+ self.convblocks.append(nn.Sequential(
22
+ nn.Conv2d(c, c, 3, 1, 1),
23
+ nn.PReLU(c),
24
+ nn.Conv2d(c, c, 3, 1, 1),
25
+ nn.PReLU(c),
26
+ nn.Conv2d(c, c, 3, 1, 1),
27
+ nn.PReLU(c),
28
+ nn.Conv2d(c, c, 3, 1, 1),
29
+ nn.PReLU(c),
30
+ nn.Conv2d(c, c, 3, 1, 1),
31
+ nn.PReLU(c),
32
+ nn.Conv2d(c, c, 3, 1, 1),
33
+ nn.PReLU(c),
34
+ ))
35
+ self.out_flows = 2
36
+ self.out_msgs = out_msgs
37
+ self.out_msgs_flows = out_msgs_flows if out_msgs > 0 else 0
38
+ self.out_locals = out_locals
39
+ self.out_local_flows = out_local_flows if out_locals > 0 else 0
40
+ self.out_masks = out_masks
41
+ self.out_feat_flows = out_feat_flows
42
+
43
+ self.conv_last = nn.Sequential(
44
+ nn.ConvTranspose2d(c, c, 4, 2, 1),
45
+ nn.PReLU(c),
46
+ nn.ConvTranspose2d(c, self.out_flows+self.out_msgs+self.out_msgs_flows +
47
+ self.out_locals+self.out_local_flows+self.out_masks+self.out_feat_flows, 4, 2, 1),
48
+ )
49
+
50
+ def forward(self, accumulated_flow, *other):
51
+ x = [accumulated_flow]
52
+ for each in other:
53
+ if each is not None:
54
+ assert(accumulated_flow.shape[-1] == each.shape[-1]), "decoder want {}, but get {}".format(
55
+ accumulated_flow.shape, each.shape)
56
+ x.append(each)
57
+ feat = self.conv0(torch.cat(x, dim=1))
58
+ for convblock1 in self.convblocks:
59
+ feat = convblock1(feat) + feat
60
+ feat = self.conv_last(feat)
61
+ prev = 0
62
+ flow = feat[:, prev:prev+self.out_flows, :, :]
63
+ prev += self.out_flows
64
+ message = feat[:, prev:prev+self.out_msgs,
65
+ :, :] if self.out_msgs > 0 else None
66
+ prev += self.out_msgs
67
+ message_flow = feat[:, prev:prev + self.out_msgs_flows,
68
+ :, :] if self.out_msgs_flows > 0 else None
69
+ prev += self.out_msgs_flows
70
+ local_message = feat[:, prev:prev + self.out_locals,
71
+ :, :] if self.out_locals > 0 else None
72
+ prev += self.out_locals
73
+ local_message_flow = feat[:, prev:prev+self.out_local_flows,
74
+ :, :] if self.out_local_flows > 0 else None
75
+ prev += self.out_local_flows
76
+ mask = torch.sigmoid(
77
+ feat[:, prev:prev+self.out_masks, :, :]) if self.out_masks > 0 else None
78
+ prev += self.out_masks
79
+ feat_flow = feat[:, prev:prev+self.out_feat_flows,
80
+ :, :] if self.out_feat_flows > 0 else None
81
+ prev += self.out_feat_flows
82
+ return flow, mask, message, message_flow, local_message, local_message_flow, feat_flow
83
+
84
+
85
+ class CINN(nn.Module):
86
+ def __init__(self, DIM_SHADER_REFERENCE, target_feature_chns=[512, 256, 128, 64, 64], feature_chns=[2048, 1024, 512, 256, 64], out_msgs_chn=[2048, 1024, 512, 256, 64, 64], out_locals_chn=[2048, 1024, 512, 256, 64, 0], block_num=[1, 1, 1, 1, 1, 2], block_chn_num=[224, 224, 224, 224, 224, 224]):
87
+ super(CINN, self).__init__()
88
+
89
+ self.in_msgs_chn = [0, *out_msgs_chn[:-1]]
90
+ self.in_locals_chn = [0, *out_locals_chn[:-1]]
91
+
92
+ self.decoder_blocks = nn.ModuleList()
93
+ self.feed_weighted = True
94
+ if self.feed_weighted:
95
+ in_planes = 2+2+DIM_SHADER_REFERENCE*2
96
+ else:
97
+ in_planes = 2+DIM_SHADER_REFERENCE
98
+ for each_target_feature_chns, each_feature_chns, each_out_msgs_chn, each_out_locals_chn, each_in_msgs_chn, each_in_locals_chn, each_block_num, each_block_chn_num in zip(target_feature_chns, feature_chns, out_msgs_chn, out_locals_chn, self.in_msgs_chn, self.in_locals_chn, block_num, block_chn_num):
99
+ self.decoder_blocks.append(
100
+ DecoderBlock(in_planes+each_target_feature_chns+each_feature_chns+each_in_locals_chn+each_in_msgs_chn, c=each_block_chn_num, block_nums=each_block_num, out_msgs=each_out_msgs_chn, out_locals=each_out_locals_chn, out_masks=2+each_out_locals_chn))
101
+ for i in range(len(feature_chns), len(out_locals_chn)):
102
+ #print("append extra block", i, "msg",
103
+ # out_msgs_chn[i], "local", out_locals_chn[i], "block", block_num[i])
104
+ self.decoder_blocks.append(
105
+ DecoderBlock(in_planes+self.in_msgs_chn[i]+self.in_locals_chn[i], c=block_chn_num[i], block_nums=block_num[i], out_msgs=out_msgs_chn[i], out_locals=out_locals_chn[i], out_masks=2+out_msgs_chn[i], out_feat_flows=0))
106
+
107
+ def apply_flow(self, mask, message, message_flow, local_message, local_message_flow, x_reference, accumulated_flow, each_x_reference_features=None, each_x_reference_features_flow=None):
108
+ if each_x_reference_features is not None:
109
+ size_from = each_x_reference_features
110
+ else:
111
+ size_from = x_reference
112
+ f_size = (size_from.shape[2], size_from.shape[3])
113
+ accumulated_flow = self.flow_rescale(
114
+ accumulated_flow, size_from)
115
+ # mask = warp_features(F.interpolate(
116
+ # mask, size=f_size, mode="bilinear"), accumulated_flow) if mask is not None else None
117
+ mask = F.interpolate(
118
+ mask, size=f_size, mode="bilinear") if mask is not None else None
119
+ message = F.interpolate(
120
+ message, size=f_size, mode="bilinear") if message is not None else None
121
+ message_flow = self.flow_rescale(
122
+ message_flow, size_from) if message_flow is not None else None
123
+ message = warp_features(
124
+ message, message_flow) if message_flow is not None else message
125
+
126
+ local_message = F.interpolate(
127
+ local_message, size=f_size, mode="bilinear") if local_message is not None else None
128
+ local_message_flow = self.flow_rescale(
129
+ local_message_flow, size_from) if local_message_flow is not None else None
130
+ local_message = warp_features(
131
+ local_message, local_message_flow) if local_message_flow is not None else local_message
132
+
133
+ warp_x_reference = warp_features(F.interpolate(
134
+ x_reference, size=f_size, mode="bilinear"), accumulated_flow)
135
+
136
+ each_x_reference_features_flow = self.flow_rescale(
137
+ each_x_reference_features_flow, size_from) if (each_x_reference_features is not None and each_x_reference_features_flow is not None) else None
138
+ warp_each_x_reference_features = warp_features(
139
+ each_x_reference_features, each_x_reference_features_flow) if each_x_reference_features_flow is not None else each_x_reference_features
140
+
141
+ return mask, message, local_message, warp_x_reference, accumulated_flow, warp_each_x_reference_features, each_x_reference_features_flow
142
+
143
+ def forward(self, x_target_features=[], x_reference=None, x_reference_features=[]):
144
+ y_flow = []
145
+ y_feat_flow = []
146
+
147
+ y_local_message = []
148
+ y_warp_x_reference = []
149
+ y_warp_x_reference_features = []
150
+
151
+ y_weighted_flow = []
152
+ y_weighted_mask = []
153
+ y_weighted_message = []
154
+ y_weighted_x_reference = []
155
+ y_weighted_x_reference_features = []
156
+
157
+ for pyrlevel, ifblock in enumerate(self.decoder_blocks):
158
+ stacked_wref = []
159
+ stacked_feat = []
160
+ stacked_anci = []
161
+ stacked_flow = []
162
+ stacked_mask = []
163
+ stacked_mesg = []
164
+ stacked_locm = []
165
+ stacked_feat_flow = []
166
+ for view_id in range(x_reference.shape[1]): # NMCHW
167
+
168
+ if pyrlevel == 0:
169
+ # create from zero flow
170
+ feat_ev = x_reference_features[pyrlevel][:,
171
+ view_id, :, :, :] if pyrlevel < len(x_reference_features) else None
172
+
173
+ accumulated_flow = torch.zeros_like(
174
+ feat_ev[:, :2, :, :]).to(device)
175
+ accumulated_feat_flow = torch.zeros_like(
176
+ feat_ev[:, :32, :, :]).to(device)
177
+ # domestic inputs
178
+ warp_x_reference = F.interpolate(x_reference[:, view_id, :, :, :], size=(
179
+ feat_ev.shape[-2], feat_ev.shape[-1]), mode="bilinear")
180
+ warp_x_reference_features = feat_ev
181
+
182
+ local_message = None
183
+ # federated inputs
184
+ weighted_flow = accumulated_flow if self.feed_weighted else None
185
+ weighted_wref = warp_x_reference if self.feed_weighted else None
186
+ weighted_message = None
187
+ else:
188
+ # resume from last layer
189
+ accumulated_flow = y_flow[-1][:, view_id, :, :, :]
190
+ accumulated_feat_flow = y_feat_flow[-1][:,
191
+ view_id, :, :, :] if y_feat_flow[-1] is not None else None
192
+ # domestic inputs
193
+ warp_x_reference = y_warp_x_reference[-1][:,
194
+ view_id, :, :, :]
195
+ warp_x_reference_features = y_warp_x_reference_features[-1][:,
196
+ view_id, :, :, :] if y_warp_x_reference_features[-1] is not None else None
197
+ local_message = y_local_message[-1][:, view_id, :,
198
+ :, :] if len(y_local_message) > 0 else None
199
+
200
+ # federated inputs
201
+ weighted_flow = y_weighted_flow[-1] if self.feed_weighted else None
202
+ weighted_wref = y_weighted_x_reference[-1] if self.feed_weighted else None
203
+ weighted_message = y_weighted_message[-1] if len(
204
+ y_weighted_message) > 0 else None
205
+ scaled_x_target = x_target_features[pyrlevel][:, :, :, :].detach() if pyrlevel < len(
206
+ x_target_features) else None
207
+ # compute flow
208
+ residual_flow, mask, message, message_flow, local_message, local_message_flow, residual_feat_flow = ifblock(
209
+ accumulated_flow, scaled_x_target, warp_x_reference, warp_x_reference_features, weighted_flow, weighted_wref, weighted_message, local_message)
210
+ accumulated_flow = residual_flow + accumulated_flow
211
+ accumulated_feat_flow = accumulated_flow
212
+
213
+ feat_ev = x_reference_features[pyrlevel+1][:,
214
+ view_id, :, :, :] if pyrlevel+1 < len(x_reference_features) else None
215
+ mask, message, local_message, warp_x_reference, accumulated_flow, warp_x_reference_features, accumulated_feat_flow = self.apply_flow(
216
+ mask, message, message_flow, local_message, local_message_flow, x_reference[:, view_id, :, :, :], accumulated_flow, feat_ev, accumulated_feat_flow)
217
+ stacked_flow.append(accumulated_flow)
218
+ if accumulated_feat_flow is not None:
219
+ stacked_feat_flow.append(accumulated_feat_flow)
220
+ stacked_mask.append(mask)
221
+ if message is not None:
222
+ stacked_mesg.append(message)
223
+ if local_message is not None:
224
+ stacked_locm.append(local_message)
225
+ stacked_wref.append(warp_x_reference)
226
+ if warp_x_reference_features is not None:
227
+ stacked_feat.append(warp_x_reference_features)
228
+
229
+ stacked_flow = torch.stack(stacked_flow, dim=1) # M*NCHW -> NMCHW
230
+ stacked_feat_flow = torch.stack(stacked_feat_flow, dim=1) if len(
231
+ stacked_feat_flow) > 0 else None
232
+ stacked_mask = torch.stack(
233
+ stacked_mask, dim=1)
234
+
235
+ stacked_mesg = torch.stack(stacked_mesg, dim=1) if len(
236
+ stacked_mesg) > 0 else None
237
+ stacked_locm = torch.stack(stacked_locm, dim=1) if len(
238
+ stacked_locm) > 0 else None
239
+
240
+ stacked_wref = torch.stack(stacked_wref, dim=1)
241
+ stacked_feat = torch.stack(stacked_feat, dim=1) if len(
242
+ stacked_feat) > 0 else None
243
+ stacked_anci = torch.stack(stacked_anci, dim=1) if len(
244
+ stacked_anci) > 0 else None
245
+ y_flow.append(stacked_flow)
246
+ y_feat_flow.append(stacked_feat_flow)
247
+
248
+ y_warp_x_reference.append(stacked_wref)
249
+ y_warp_x_reference_features.append(stacked_feat)
250
+ # compute normalized confidence
251
+ stacked_contrib = torch.nn.functional.softmax(stacked_mask, dim=1)
252
+
253
+ # torch.sum to remove temp dimension M from NMCHW --> NCHW
254
+ weighted_flow = torch.sum(
255
+ stacked_mask[:, :, 0:1, :, :] * stacked_contrib[:, :, 0:1, :, :] * stacked_flow, dim=1)
256
+ weighted_mask = torch.sum(
257
+ stacked_contrib[:, :, 0:1, :, :] * stacked_mask[:, :, 0:1, :, :], dim=1)
258
+ weighted_wref = torch.sum(
259
+ stacked_mask[:, :, 0:1, :, :] * stacked_contrib[:, :, 0:1, :, :] * stacked_wref, dim=1) if stacked_wref is not None else None
260
+ weighted_feat = torch.sum(
261
+ stacked_mask[:, :, 1:2, :, :] * stacked_contrib[:, :, 1:2, :, :] * stacked_feat, dim=1) if stacked_feat is not None else None
262
+ weighted_mesg = torch.sum(
263
+ stacked_mask[:, :, 2:, :, :] * stacked_contrib[:, :, 2:, :, :] * stacked_mesg, dim=1) if stacked_mesg is not None else None
264
+ y_weighted_flow.append(weighted_flow)
265
+ y_weighted_mask.append(weighted_mask)
266
+ if weighted_mesg is not None:
267
+ y_weighted_message.append(weighted_mesg)
268
+ if stacked_locm is not None:
269
+ y_local_message.append(stacked_locm)
270
+ y_weighted_message.append(weighted_mesg)
271
+ y_weighted_x_reference.append(weighted_wref)
272
+ y_weighted_x_reference_features.append(weighted_feat)
273
+
274
+ if weighted_feat is not None:
275
+ y_weighted_x_reference_features.append(weighted_feat)
276
+ return {
277
+ "y_last_remote_features": [weighted_mesg],
278
+ }
279
+
280
+ def flow_rescale(self, prev_flow, each_x_reference_features):
281
+ if prev_flow is None:
282
+ prev_flow = torch.zeros_like(
283
+ each_x_reference_features[:, :2]).to(device)
284
+ else:
285
+ up_scale_factor = each_x_reference_features.shape[-1] / \
286
+ prev_flow.shape[-1]
287
+ if up_scale_factor != 1:
288
+ prev_flow = F.interpolate(prev_flow, scale_factor=up_scale_factor, mode="bilinear",
289
+ align_corners=False, recompute_scale_factor=False) * up_scale_factor
290
+ return prev_flow
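
The view blending at the end of CINN.forward follows one pattern: per-view confidence masks become softmax weights over the M views, and those weights (times the mask itself) average the per-view warped references, features and messages back to NCHW. A simplified toy version of that weighting, with a single confidence channel and made-up shapes:

import torch

stacked_mask = torch.randn(1, 3, 1, 8, 8)     # N, M=3 views, 1 confidence channel, H, W
stacked_wref = torch.randn(1, 3, 4, 8, 8)     # per-view warped RGBA references
contrib = torch.softmax(stacked_mask, dim=1)  # normalized contribution of each view
weighted_wref = torch.sum(stacked_mask * contrib * stacked_wref, dim=1)
print(weighted_wref.shape)                    # torch.Size([1, 4, 8, 8])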
model/warplayer.py ADDED
@@ -0,0 +1,56 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
5
+ backwarp_tenGrid = {}
6
+
7
+
8
+ def warp(tenInput, tenFlow):
9
+ with torch.cuda.amp.autocast(enabled=False):
10
+ k = (str(tenFlow.device), str(tenFlow.size()))
11
+ if k not in backwarp_tenGrid:
12
+ tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
13
+ 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
14
+ tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
15
+ 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
16
+ backwarp_tenGrid[k] = torch.cat(
17
+ [tenHorizontal, tenVertical], 1).to(device)
18
+
19
+ tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
20
+ tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
21
+
22
+ g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
23
+ if tenInput.dtype != g.dtype:
24
+ g = g.to(tenInput.dtype)
25
+ return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
26
+ # "zeros" "border"
27
+
28
+
29
+ def warp_features(inp, flow, ):
30
+ groups = flow.shape[1]//2 # NCHW
31
+ samples = inp.shape[0]
32
+ h = inp.shape[2]
33
+ w = inp.shape[3]
34
+ assert(flow.shape[0] == samples and flow.shape[2]
35
+ == h and flow.shape[3] == w)
36
+ chns = inp.shape[1]
37
+ chns_per_group = chns // groups
38
+ assert(flow.shape[1] % 2 == 0)
39
+ assert(chns % groups == 0)
40
+ inp = inp.contiguous().view(samples*groups, chns_per_group, h, w)
41
+ flow = flow.contiguous().view(samples*groups, 2, h, w)
42
+ feat = warp(inp, flow)
43
+ feat = feat.view(samples, chns, h, w)
44
+ return feat
45
+
46
+
47
+ def flow2rgb(flow_map_np):
48
+ h, w, _ = flow_map_np.shape
49
+ rgb_map = np.ones((h, w, 3)).astype(np.float32)/2.0
50
+ normalized_flow_map = np.concatenate(
51
+ (flow_map_np[:, :, 0:1]/h/2.0, flow_map_np[:, :, 1:2]/w/2.0), axis=2)
52
+ rgb_map[:, :, 0] += normalized_flow_map[:, :, 0]
53
+ rgb_map[:, :, 1] -= 0.5 * \
54
+ (normalized_flow_map[:, :, 0] + normalized_flow_map[:, :, 1])
55
+ rgb_map[:, :, 2] += normalized_flow_map[:, :, 1]
56
+ return (rgb_map.clip(0, 1)*255.0).astype(np.uint8)
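
warp backward-warps an NCHW tensor with a per-pixel flow given in pixels (grid_sample with border padding), and warp_features treats a 2k-channel flow as k independent (dx, dy) groups so stacked features can be warped group-wise. A small sketch with made-up values, placing tensors on the module's own device global:

import torch
from model.warplayer import warp, warp_features, device

x = torch.arange(16.0, device=device).view(1, 1, 4, 4)
flow = torch.zeros(1, 2, 4, 4, device=device)
flow[:, 0] = 1.0                       # sample one pixel to the right
print(warp(x, flow)[0, 0])             # rows appear shifted left by one, border-padded

feats = torch.randn(1, 6, 4, 4, device=device)
flows = torch.zeros(1, 6, 4, 4, device=device)   # 3 (dx, dy) groups for 3 feature groups
print(warp_features(feats, flows).shape)         # torch.Size([1, 6, 4, 4])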
streamlit.py ADDED
@@ -0,0 +1,52 @@
1
+ import cv2
2
+ import numpy as np
3
+ import streamlit as st
4
+ import os
5
+ import base64
6
+
7
+ st.set_page_config(layout="wide", page_title='CoNR demo', page_icon="🪐")
8
+
9
+ st.title('CoNR demo')
10
+ st.markdown(""" <style>
11
+ #MainMenu {visibility: hidden;}
12
+ footer {visibility: hidden;}
13
+ </style> """, unsafe_allow_html=True)
14
+
15
+ def get_base64(bin_file):
16
+ with open(bin_file, 'rb') as f:
17
+ data = f.read()
18
+ return base64.b64encode(data).decode()
19
+
20
+ # def set_background(png_file):
21
+ # bin_str = get_base64(png_file)
22
+ # page_bg_img = '''
23
+ # <style>
24
+ # .stApp {
25
+ # background-image: url("data:image/png;base64,%s");
26
+ # background-size: 1920px 1080px;
27
+ # background-attachment:fixed;
28
+ # background-position:center;
29
+ # background-repeat:no-repeat;
30
+ # }
31
+ # </style>
32
+ # ''' % bin_str
33
+ # st.markdown(page_bg_img, unsafe_allow_html=True)
34
+
35
+ # set_background('ipad_bg.png')
36
+
37
+ upload_img = st.file_uploader("Input character sheet", "png", accept_multiple_files=True)
38
+
39
+ if st.button('RUN!'):
40
+ if upload_img is not None:
41
+ for i in range(len(upload_img)):
42
+ with open('character_sheet/{}.png'.format(i), 'wb') as f:
43
+ f.write(upload_img[i].read())
44
+
45
+ st.info('Running inference...')
46
+ os.system('sh infer.sh')
47
+ st.info('Done!')
48
+ video_file=open('output.mp4', 'rb')
49
+ video_bytes = video_file.read()
50
+ st.video(video_bytes, start_time=0)
51
+ else:
52
+ st.info('No images uploaded yet > <')
train.py ADDED
@@ -0,0 +1,229 @@
1
+ import argparse
2
+ import os
3
+ import time
4
+ from datetime import datetime
5
+ from distutils.util import strtobool
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch.utils.data import DataLoader
10
+ from torchvision import transforms
11
+ from data_loader import (FileDataset,
12
+ RandomResizedCropWithAutoCenteringAndZeroPadding)
13
+ from torch.utils.data.distributed import DistributedSampler
14
+ from conr import CoNR
15
+
16
+ def data_sampler(dataset, shuffle, distributed):
17
+
18
+ if distributed:
19
+ return torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
20
+
21
+ if shuffle:
22
+ return torch.utils.data.RandomSampler(dataset)
23
+
24
+ else:
25
+ return torch.utils.data.SequentialSampler(dataset)
26
+
+ def save_output(image_name, inputs_v, d_dir=".", crop=None):
+     import cv2
+
+     inputs_v = inputs_v.detach().squeeze()
+     input_np = torch.clamp(inputs_v*255, 0, 255).byte().cpu().numpy().transpose(
+         (1, 2, 0))
+     # cv2.setNumThreads(1)
+     out_render_scale = cv2.cvtColor(input_np, cv2.COLOR_RGBA2BGRA)
+     if crop is not None:
+         crop = crop.cpu().numpy()[0]
+         output_img = np.zeros((crop[0], crop[1], 4), dtype=np.uint8)
+         before_resize_scale = cv2.resize(
+             out_render_scale, (crop[5]-crop[4]+crop[8]+crop[9], crop[3]-crop[2]+crop[6]+crop[7]), interpolation=cv2.INTER_AREA)  # w,h
+         output_img[crop[2]:crop[3], crop[4]:crop[5]] = before_resize_scale[crop[6]:before_resize_scale.shape[0] -
+             crop[7], crop[8]:before_resize_scale.shape[1]-crop[9]]
+     else:
+         output_img = out_render_scale
+     cv2.imwrite(d_dir+"/"+image_name.split(os.sep)[-1]+'.png',
+                 output_img
+                 )
+
+
+ def test():
+     source_names_list = []
+     for name in sorted(os.listdir(args.test_input_person_images)):
+         thissource = os.path.join(args.test_input_person_images, name)
+         if os.path.isfile(thissource):
+             source_names_list.append(thissource)
+         if os.path.isdir(thissource):
+             print("skipping folder :"+thissource)
+
+     image_names_list = []
+     for name in sorted(os.listdir(args.test_input_poses_images)):
+         thistarget = os.path.join(args.test_input_poses_images, name)
+         if os.path.isfile(thistarget):
+             image_names_list.append([thistarget, *source_names_list])
+         if os.path.isdir(thistarget):
+             print("skipping folder :"+thistarget)
+     print(image_names_list)
+
+     print("---building models")
+     conrmodel = CoNR(args)
+     conrmodel.load_model(path=args.test_checkpoint_dir)
+     conrmodel.dist()
+     infer(args, conrmodel, image_names_list)
+
+ # def test():
+ #     source_names_list = []
+ #     for name in os.listdir(args.test_input_person_images):
+ #         thissource = os.path.join(args.test_input_person_images, name)
+ #         if os.path.isfile(thissource):
+ #             source_names_list.append([thissource])
+ #         if os.path.isdir(thissource):
+ #             toadd = [os.path.join(thissource, this_file)
+ #                      for this_file in os.listdir(thissource)]
+ #             if (toadd != []):
+ #                 source_names_list.append(toadd)
+ #             else:
+ #                 print("skipping empty folder :"+thissource)
+ #     image_names_list = []
+ #     for eachlist in source_names_list:
+ #         for name in sorted(os.listdir(args.test_input_poses_images)):
+ #             thistarget = os.path.join(args.test_input_poses_images, name)
+ #             if os.path.isfile(thistarget):
+ #                 image_names_list.append([thistarget, *eachlist])
+ #             if os.path.isdir(thistarget):
+ #                 print("skipping folder :"+thistarget)
+
+ #     print(image_names_list)
+ #     print("---building models...")
+ #     conrmodel = CoNR(args)
+ #     conrmodel.load_model(path=args.test_checkpoint_dir)
+ #     conrmodel.dist()
+ #     infer(args, conrmodel, image_names_list)
+
+
+ def infer(args, humanflowmodel, image_names_list):
+     print("---test images: ", len(image_names_list))
+     test_salobj_dataset = FileDataset(image_names_list=image_names_list,
+                                       fg_img_lbl_transform=transforms.Compose([
+                                           RandomResizedCropWithAutoCenteringAndZeroPadding(
+                                               (args.dataloader_imgsize, args.dataloader_imgsize), scale=(1, 1), ratio=(1.0, 1.0), center_jitter=(0.0, 0.0)
+                                           )]),
+                                       shader_pose_use_gt_udp_test=not args.test_pose_use_parser_udp,
+                                       shader_target_use_gt_rgb_debug=False
+                                       )
+     sampler = data_sampler(test_salobj_dataset, shuffle=False,
+                            distributed=args.distributed)
+     train_data = DataLoader(test_salobj_dataset,
+                             batch_size=1,
+                             shuffle=False, sampler=sampler,
+                             num_workers=args.dataloaders)
+
+     # start testing
+
+     train_num = train_data.__len__()
+     time_stamp = time.time()
+     prev_frame_rgb = []
+     prev_frame_a = []
+     for i, data in enumerate(train_data):
+         data_time_interval = time.time() - time_stamp
+         time_stamp = time.time()
+         with torch.no_grad():
+             data["character_images"] = torch.cat(
+                 [data["character_images"], *prev_frame_rgb], dim=1)
+             data["character_masks"] = torch.cat(
+                 [data["character_masks"], *prev_frame_a], dim=1)
+             data = humanflowmodel.data_norm_image(data)
+             pred = humanflowmodel.model_step(data, training=False)
+             # remember to call humanflowmodel.reset_charactersheet() if you change the character
+
+         train_time_interval = time.time() - time_stamp
+         time_stamp = time.time()
+         if i % 5 == 0 and args.local_rank == 0:
+             print("[infer batch: %4d/%4d] time:%2f+%2f" % (
+                 i, train_num,
+                 data_time_interval, train_time_interval
+             ))
+         with torch.no_grad():
+
+             if args.test_output_video:
+                 pred_img = pred["shader"]["y_weighted_warp_decoded_rgba"]
+                 save_output(
+                     str(int(data["imidx"].cpu().item())), pred_img, args.test_output_dir, crop=data["pose_crop"])
+
+             if args.test_output_udp:
+                 pred_img = pred["shader"]["x_target_sudp_a"]
+                 save_output(
+                     "udp_"+str(int(data["imidx"].cpu().item())), pred_img, args.test_output_dir)
+
+
+ def build_args():
+     parser = argparse.ArgumentParser()
+     # distributed learning settings
+     parser.add_argument("--world_size", type=int, default=1,
+                         help='world size')
+     parser.add_argument("--local_rank", type=int, default=0,
+                         help='local_rank, DON\'T change it')
+
+     # model settings
+     parser.add_argument('--dataloader_imgsize', type=int, default=256,
+                         help='Input image size of the model')
+     parser.add_argument('--batch_size', type=int, default=4,
+                         help='minibatch size')
+     parser.add_argument('--model_name', default='model_result',
+                         help='Name of the experiment')
+     parser.add_argument('--dataloaders', type=int, default=2,
+                         help='Number of dataloader workers')
+     parser.add_argument('--mode', default="test", choices=['train', 'test'],
+                         help='Training mode or testing mode')
+
+     # i/o settings
+     parser.add_argument('--test_input_person_images',
+                         type=str, default="./character_sheet/",
+                         help='Directory of input character sheets')
+     parser.add_argument('--test_input_poses_images', type=str,
+                         default="./test_data/",
+                         help='Directory of input UDP sequences or pose images')
+     parser.add_argument('--test_checkpoint_dir', type=str,
+                         default='./weights/',
+                         help='Directory of model weights')
+     parser.add_argument('--test_output_dir', type=str,
+                         default="./results/",
+                         help='Directory for output images')
+
+     # output content settings
+     parser.add_argument('--test_output_video', type=strtobool, default=True,
+                         help='Whether to output the final result of CoNR; \
+                               images are written to test_output_dir when True.')
+     parser.add_argument('--test_output_udp', type=strtobool, default=False,
+                         help='Whether to output the UDP generated by the UDP detector; \
+                               meaningful ONLY when test_input_poses_images \
+                               contains pose images rather than UDP sequences, and \
+                               test_pose_use_parser_udp must also be True.')
+
+     # UDP detector settings
+     parser.add_argument('--test_pose_use_parser_udp',
+                         type=strtobool, default=False,
+                         help='Whether to use the UDP detector to generate UDP from PNGs; \
+                               the pose input MUST be pose images instead of UDP sequences \
+                               when True.')
+
+     args = parser.parse_args()
+
+     args.distributed = (args.world_size > 1)
+     if args.local_rank == 0:
+         print("batch_size:", args.batch_size, flush=True)
+     if args.distributed:
+         if args.local_rank == 0:
+             print("world_size: ", args.world_size)
+         torch.distributed.init_process_group(
+             backend="nccl", init_method="env://", world_size=args.world_size)
+         torch.cuda.set_device(args.local_rank)
+         torch.backends.cudnn.benchmark = True
+     else:
+         args.local_rank = 0
+
+     return args
+
+
+ if __name__ == "__main__":
+     args = build_args()
+     test()
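The test path of train.py is driven entirely by the flags defined in build_args(), i.e. it is normally run as "python train.py --mode test" from the repo root. A minimal sketch of driving the same path from another script, assuming weights/, character_sheet/ and test_data/ are populated as the argparse defaults describe and that no distributed launcher is used (single GPU or CPU):

import os
import sys
import train

os.makedirs("./results/", exist_ok=True)   # save_output() writes PNGs here

# build_args() reads sys.argv, so stage the flags there; every flag and value
# below comes from build_args() above (these are simply its defaults).
sys.argv = [
    "train.py",
    "--mode", "test",
    "--test_input_person_images", "./character_sheet/",
    "--test_input_poses_images", "./test_data/",
    "--test_checkpoint_dir", "./weights/",
    "--test_output_dir", "./results/",
]
train.args = train.build_args()   # test() reads the module-level args
train.test()                      # builds CoNR, loads weights, renders to ./results/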