zejunyang committed on
Commit
9667e74
1 Parent(s): 9464d6e
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ */ref_img.png filter=lfs diff=lfs merge=lfs -text
NTED/NTED_module.py ADDED
@@ -0,0 +1,101 @@
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ import random
5
+
6
+ import mediapipe as mp
7
+ from lite_openpose.body_bbox_detector import BodyPoseEstimator
8
+ from NTED.extraction_distribution_model import Generator
9
+ from NTED.demo_dataset import DemoDataset
10
+ from NTED.base_function import accumulate
11
+ from NTED.config import Config
12
+
13
+
14
+ def set_random_seed(seed):
15
+ r"""Set random seeds for everything.
16
+
17
+ Args:
18
+ seed (int): Random seed.
20
+ """
21
+ random.seed(seed)
22
+ np.random.seed(seed)
23
+ torch.manual_seed(seed)
24
+ torch.cuda.manual_seed(seed)
25
+ torch.cuda.manual_seed_all(seed)
26
+
27
+ class NTED():
28
+ def __init__(self):
29
+ super(NTED, self).__init__()
30
+
31
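+ # lightweight OpenPose body-keypoint detector, kept on CPU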
+ self.openpose_module = BodyPoseEstimator('cpu')
32
+ set_random_seed(0)
33
+ self.opt = Config('NTED/fashion_512.yaml', is_train=False)
34
+
35
+ net_G = Generator(**self.opt.gen.param).to('cpu')
36
+ net_G_ema = Generator(**self.opt.gen.param).to('cpu')
37
+ net_G_ema.eval()
38
+ accumulate(net_G_ema, net_G, 0)
39
+
40
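+ # restore the trained EMA generator weights; map_location keeps all tensors on CPU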
+ checkpoint = torch.load('NTED/nted_checkpoint.pt', map_location=lambda storage, loc: storage)
41
+ net_G_ema.load_state_dict(checkpoint['net_G_ema'])
42
+ self.net_G = net_G_ema.eval()
43
+
44
+ self.data_loader = DemoDataset()
45
+
46
+ mp_hands = mp.solutions.hands
47
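+ # MediaPipe Hands in static-image mode; the low confidence threshold keeps weak detections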
+ self.hands = mp_hands.Hands(static_image_mode=True, max_num_hands=2, min_detection_confidence=0.1)
48
+
49
+ self.ref_img = cv2.imread('example/ref_img.png')
50
+ self.ref_img = cv2.resize(self.ref_img, (352, 512))
51
+
52
+ def hand_pose_est(self, img):
53
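+ # MediaPipe expects RGB; the image is mirrored here and the x coordinates are mirrored back below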
+ results = self.hands.process(cv2.cvtColor(cv2.flip(img, 1), cv2.COLOR_BGR2RGB))
54
+ image_height, image_width, _ = img.shape
55
+ pose_data = []
56
+
57
+ if results.multi_hand_landmarks is not None:
58
+ for hand_landmarks in results.multi_hand_landmarks:
59
+ for joint_idx in range(21):
60
+ pose_data.append([image_width - hand_landmarks.landmark[joint_idx].x * image_width, hand_landmarks.landmark[joint_idx].y * image_height])
61
+ if len(results.multi_hand_landmarks) == 2:
62
+ if results.multi_handedness[0].classification[0].label == 'Right':
63
+ # swap so the order is left hand first, then right hand
64
+ tmp = pose_data[:21].copy()
65
+ pose_data[:21] = pose_data[21:]
66
+ pose_data[21:] = tmp
67
+ elif len(results.multi_hand_landmarks) == 1:
68
+ miss_hand = [[-1, -1] for _ in range(21)]
69
+ if results.multi_handedness[0].classification[0].label == 'Left':
70
+ pose_data += miss_hand
71
+ else:
72
+ pose_data = miss_hand + pose_data
73
+ else:
74
+ for _ in range(42):
75
+ pose_data.append([-1, -1])
76
+ pose_data = np.array(pose_data, dtype=np.int32)
77
+
78
+ return pose_data
79
+
80
+
81
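+ # full demo pipeline: body pose -> hand pose -> skeleton label -> generator -> fake image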
+ def inference(self, img):
82
+
83
+ img = cv2.resize(img, (352, 512))
84
+
85
+ body_pose, bbox = self.openpose_module.detect_body_pose(img.copy())
86
+
87
+ hand_pose = self.hand_pose_est(img.copy())
88
+
89
+ data = self.data_loader.load_item(self.ref_img, body_pose[0], hand_pose)
90
+
91
+ output = self.net_G(
92
+ data['reference_image'],
93
+ data['target_skeleton'],
94
+ )
95
+ fake_image = output['fake_image'][0]
96
+
97
+ fake_image = self.data_loader.tensor2im(fake_image)
98
+
99
+ fake_image = cv2.resize(fake_image, (288, 480))
100
+
101
+ return data['skeleton_img'], fake_image
NTED/base_function.py ADDED
@@ -0,0 +1,434 @@
1
+ import sys
2
+ import math
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from NTED.op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix
9
+
10
+ class ExtractionOperation(nn.Module):
11
+ def __init__(self, in_channel, num_label, match_kernel):
12
+ super(ExtractionOperation, self).__init__()
13
+ self.value_conv = EqualConv2d(in_channel, in_channel, match_kernel, 1, match_kernel//2, bias=True)
14
+ self.semantic_extraction_filter = EqualConv2d(in_channel, num_label, match_kernel, 1, match_kernel//2, bias=False)
15
+
16
+ self.softmax = nn.Softmax(dim=-1)
17
+ self.num_label = num_label
18
+
19
+ def forward(self, value, recoder):
20
+ key = value
21
+ b,c,h,w = value.shape
22
+ key = self.semantic_extraction_filter(self.feature_norm(key))
23
+ extraction_softmax = self.softmax(key.view(b, -1, h*w)) #bkm
24
+ values_flatten = self.value_conv(value).view(b, -1, h*w)
25
+ neural_textures = torch.einsum('bkm,bvm->bvk', extraction_softmax, values_flatten)
26
+ recoder['extraction_softmax'].insert(0, extraction_softmax)
27
+ recoder['neural_textures'].insert(0, neural_textures)
28
+ return neural_textures, extraction_softmax
29
+
30
+
31
+ def feature_norm(self, input_tensor):
32
+ input_tensor = input_tensor - input_tensor.mean(dim=1, keepdim=True)
33
+ norm = torch.norm(input_tensor, 2, 1, keepdim=True) + sys.float_info.epsilon
34
+ out = torch.div(input_tensor, norm)
35
+ return out
36
+
37
+ class DistributionOperation(nn.Module):
38
+ def __init__(self, num_label, input_dim, match_kernel=3):
39
+ super(DistributionOperation, self).__init__()
40
+ self.semantic_distribution_filter = EqualConv2d(input_dim, num_label,
41
+ kernel_size=match_kernel,
42
+ stride=1,
43
+ padding=match_kernel//2)
44
+ self.num_label = num_label
45
+
46
+ def forward(self, query, extracted_feature, recoder):
47
+ b,c,h,w = query.shape
48
+
49
+ query = self.semantic_distribution_filter(query)
50
+ query_flatten = query.view(b, self.num_label, -1)
51
+ query_softmax = F.softmax(query_flatten, 1)
52
+ values_q = torch.einsum('bkm,bkv->bvm', query_softmax, extracted_feature.permute(0,2,1))
53
+ attn_out = values_q.view(b,-1,h,w)
54
+ recoder['semantic_distribution'].append(query)
55
+ return attn_out
56
+
57
+ class EncoderLayer(nn.Sequential):
58
+ def __init__(
59
+ self,
60
+ in_channel,
61
+ out_channel,
62
+ kernel_size,
63
+ downsample=False,
64
+ blur_kernel=[1, 3, 3, 1],
65
+ bias=True,
66
+ activate=True,
67
+ use_extraction=False,
68
+ num_label=None,
69
+ match_kernel=None,
70
+ num_extractions=2
71
+ ):
72
+ super().__init__()
73
+
74
+ if downsample:
75
+ factor = 2
76
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
77
+ pad0 = (p + 1) // 2
78
+ pad1 = p // 2
79
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
80
+
81
+ stride = 2
82
+ padding = 0
83
+
84
+ else:
85
+ self.blur = None
86
+ stride = 1
87
+ padding = kernel_size // 2
88
+
89
+
90
+ self.conv = EqualConv2d(
91
+ in_channel,
92
+ out_channel,
93
+ kernel_size,
94
+ padding=padding,
95
+ stride=stride,
96
+ bias=bias and not activate,
97
+ )
98
+
99
+ self.activate = FusedLeakyReLU(out_channel, bias=bias) if activate else None
100
+ self.use_extraction = use_extraction
101
+ if self.use_extraction:
102
+ self.extraction_operations = nn.ModuleList()
103
+ for _ in range(num_extractions):
104
+ self.extraction_operations.append(
105
+ ExtractionOperation(
106
+ out_channel,
107
+ num_label,
108
+ match_kernel
109
+ )
110
+ )
111
+
112
+ def forward(self, input, recoder=None):
113
+ out = self.blur(input) if self.blur is not None else input
114
+ out = self.conv(out)
115
+ out = self.activate(out) if self.activate is not None else out
116
+ if self.use_extraction:
117
+ for extraction_operation in self.extraction_operations:
118
+ extraction_operation(out, recoder)
119
+ return out
120
+
121
+
122
+ class DecoderLayer(nn.Module):
123
+ def __init__(
124
+ self,
125
+ in_channel,
126
+ out_channel,
127
+ kernel_size,
128
+ upsample=False,
129
+ blur_kernel=[1, 3, 3, 1],
130
+ bias=True,
131
+ activate=True,
132
+ use_distribution=True,
133
+ num_label=16,
134
+ match_kernel=3,
135
+ ):
136
+ super().__init__()
137
+ if upsample:
138
+ factor = 2
139
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
140
+ pad0 = (p + 1) // 2 + factor - 1
141
+ pad1 = p // 2 + 1
142
+
143
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
144
+ self.conv = EqualTransposeConv2d(
145
+ in_channel,
146
+ out_channel,
147
+ kernel_size,
148
+ stride=2,
149
+ padding=0,
150
+ bias=bias and not activate,
151
+ )
152
+ else:
153
+ self.conv = EqualConv2d(
154
+ in_channel,
155
+ out_channel,
156
+ kernel_size,
157
+ stride=1,
158
+ padding=kernel_size//2,
159
+ bias=bias and not activate,
160
+ )
161
+ self.blur = None
162
+
163
+ self.distribution_operation = DistributionOperation(
164
+ num_label,
165
+ out_channel,
166
+ match_kernel=match_kernel
167
+ ) if use_distribution else None
168
+ self.activate = FusedLeakyReLU(out_channel, bias=bias) if activate else None
169
+ self.use_distribution = use_distribution
170
+
171
+ def forward(self, input, neural_texture=None, recoder=None):
172
+ out = self.conv(input)
173
+ out = self.blur(out) if self.blur is not None else out
174
+ if self.use_distribution and neural_texture is not None:
175
+ out_attn = self.distribution_operation(out, neural_texture, recoder)
176
+ out = (out + out_attn) / math.sqrt(2)
177
+
178
+ out = self.activate(out.contiguous()) if self.activate is not None else out
179
+
180
+ return out
181
+
182
+ class EqualConv2d(nn.Module):
183
+ def __init__(
184
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
185
+ ):
186
+ super().__init__()
187
+
188
+ self.weight = nn.Parameter(
189
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
190
+ )
191
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
192
+
193
+ self.stride = stride
194
+ self.padding = padding
195
+
196
+ if bias:
197
+ self.bias = nn.Parameter(torch.zeros(out_channel))
198
+
199
+ else:
200
+ self.bias = None
201
+
202
+ def forward(self, input):
203
+ out = conv2d_gradfix.conv2d(
204
+ input,
205
+ self.weight * self.scale,
206
+ bias=self.bias,
207
+ stride=self.stride,
208
+ padding=self.padding,
209
+ )
210
+
211
+ return out
212
+
213
+ def __repr__(self):
214
+ return (
215
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
216
+ f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
217
+ )
218
+
219
+
220
+ class EqualTransposeConv2d(nn.Module):
221
+ def __init__(
222
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
223
+ ):
224
+ super().__init__()
225
+
226
+ self.weight = nn.Parameter(
227
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
228
+ )
229
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
230
+
231
+ self.stride = stride
232
+ self.padding = padding
233
+
234
+ if bias:
235
+ self.bias = nn.Parameter(torch.zeros(out_channel))
236
+
237
+ else:
238
+ self.bias = None
239
+
240
+ def forward(self, input):
241
+ weight = self.weight.transpose(0,1)
242
+ out = conv2d_gradfix.conv_transpose2d(
243
+ input,
244
+ weight * self.scale,
245
+ bias=self.bias,
246
+ stride=self.stride,
247
+ padding=self.padding,
248
+ )
249
+
250
+ return out
251
+
252
+ def __repr__(self):
253
+ return (
254
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
255
+ f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
256
+ )
257
+
258
+ class ToRGB(nn.Module):
259
+ def __init__(
260
+ self,
261
+ in_channel,
262
+ upsample=True,
263
+ blur_kernel=[1, 3, 3, 1]
264
+ ):
265
+ super().__init__()
266
+
267
+ if upsample:
268
+ self.upsample = Upsample(blur_kernel)
269
+ self.conv = EqualConv2d(in_channel, 3, 3, stride=1, padding=1)
270
+
271
+ def forward(self, input, skip=None):
272
+ out = self.conv(input)
273
+ if skip is not None:
274
+ skip = self.upsample(skip)
275
+ out = out + skip
276
+ return out
277
+
278
+
279
+ class EqualLinear(nn.Module):
280
+ def __init__(
281
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
282
+ ):
283
+ super().__init__()
284
+
285
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
286
+
287
+ if bias:
288
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
289
+
290
+ else:
291
+ self.bias = None
292
+
293
+ self.activation = activation
294
+
295
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
296
+ self.lr_mul = lr_mul
297
+
298
+ def forward(self, input):
299
+ if self.activation:
300
+ out = F.linear(input, self.weight * self.scale)
301
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
302
+
303
+ else:
304
+ out = F.linear(
305
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
306
+ )
307
+
308
+ return out
309
+
310
+ def __repr__(self):
311
+ return (
312
+ f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
313
+ )
314
+
315
+ class Upsample(nn.Module):
316
+ def __init__(self, kernel, factor=2):
317
+ super().__init__()
318
+
319
+ self.factor = factor
320
+ kernel = make_kernel(kernel) * (factor ** 2)
321
+ self.register_buffer("kernel", kernel)
322
+
323
+ p = kernel.shape[0] - factor
324
+
325
+ pad0 = (p + 1) // 2 + factor - 1
326
+ pad1 = p // 2
327
+
328
+ self.pad = (pad0, pad1)
329
+
330
+ def forward(self, input):
331
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
332
+
333
+ return out
334
+
335
+ class ResBlock(nn.Module):
336
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
337
+ super().__init__()
338
+
339
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
340
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
341
+
342
+ self.skip = ConvLayer(
343
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
344
+ )
345
+
346
+ def forward(self, input):
347
+ out = self.conv1(input)
348
+ out = self.conv2(out)
349
+
350
+ skip = self.skip(input)
351
+ out = (out + skip) / math.sqrt(2)
352
+
353
+ return out
354
+
355
+ class ConvLayer(nn.Sequential):
356
+ def __init__(
357
+ self,
358
+ in_channel,
359
+ out_channel,
360
+ kernel_size,
361
+ downsample=False,
362
+ blur_kernel=[1, 3, 3, 1],
363
+ bias=True,
364
+ activate=True,
365
+ ):
366
+ layers = []
367
+
368
+ if downsample:
369
+ factor = 2
370
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
371
+ pad0 = (p + 1) // 2
372
+ pad1 = p // 2
373
+
374
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
375
+
376
+ stride = 2
377
+ self.padding = 0
378
+
379
+ else:
380
+ stride = 1
381
+ self.padding = kernel_size // 2
382
+
383
+ layers.append(
384
+ EqualConv2d(
385
+ in_channel,
386
+ out_channel,
387
+ kernel_size,
388
+ padding=self.padding,
389
+ stride=stride,
390
+ bias=bias and not activate,
391
+ )
392
+ )
393
+
394
+ if activate:
395
+ layers.append(FusedLeakyReLU(out_channel, bias=bias))
396
+
397
+ super().__init__(*layers)
398
+
399
+
400
+ class Blur(nn.Module):
401
+ def __init__(self, kernel, pad, upsample_factor=1):
402
+ super().__init__()
403
+
404
+ kernel = make_kernel(kernel)
405
+
406
+ if upsample_factor > 1:
407
+ kernel = kernel * (upsample_factor ** 2)
408
+
409
+ self.register_buffer("kernel", kernel)
410
+
411
+ self.pad = pad
412
+
413
+ def forward(self, input):
414
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
415
+
416
+ return out
417
+
418
+
419
+ def make_kernel(k):
420
+ k = torch.tensor(k, dtype=torch.float32)
421
+
422
+ if k.ndim == 1:
423
+ k = k[None, :] * k[:, None]
424
+
425
+ k /= k.sum()
426
+
427
+ return k
428
+
429
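+ # exponential moving average of model2's parameters into model1; decay=0 is a plain copy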
+ def accumulate(model1, model2, decay=0.999):
430
+ par1 = dict(model1.named_parameters())
431
+ par2 = dict(model2.named_parameters())
432
+
433
+ for k in par1.keys():
434
+ par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay)
NTED/base_module.py ADDED
@@ -0,0 +1,115 @@
1
+ import math
2
+ import functools
3
+ import sys
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from NTED.base_function import EncoderLayer, DecoderLayer, ToRGB
9
+ from NTED.edge_attention_layer import Edge_Attn
10
+
11
+ class Encoder(nn.Module):
12
+ def __init__(
13
+ self,
14
+ size,
15
+ input_dim,
16
+ channels,
17
+ num_labels=None,
18
+ match_kernels=None,
19
+ blur_kernel=[1, 3, 3, 1],
20
+ ):
21
+ super().__init__()
22
+ self.first = EncoderLayer(input_dim, channels[size], 1)
23
+ self.convs = nn.ModuleList()
24
+
25
+ log_size = int(math.log(size, 2))
26
+ self.log_size = log_size
27
+
28
+ in_channel = channels[size]
29
+ for i in range(log_size-1, 3, -1):
30
+ out_channel = channels[2 ** i]
31
+ num_label = num_labels[2 ** i] if num_labels is not None else None
32
+ match_kernel = match_kernels[2 ** i] if match_kernels is not None else None
33
+ use_extraction = num_label and match_kernel
34
+ conv = EncoderLayer(
35
+ in_channel,
36
+ out_channel,
37
+ kernel_size=3,
38
+ downsample=True,
39
+ blur_kernel=blur_kernel,
40
+ use_extraction=use_extraction,
41
+ num_label=num_label,
42
+ match_kernel=match_kernel
43
+ )
44
+
45
+ self.convs.append(conv)
46
+ in_channel = out_channel
47
+
48
+ def forward(self, input, recoder=None):
49
+ out = self.first(input)
50
+ for idx, layer in enumerate(self.convs):
51
+ out = layer(out, recoder)
52
+ return out
53
+
54
+ class Decoder(nn.Module):
55
+ def __init__(
56
+ self,
57
+ size,
58
+ channels,
59
+ num_labels,
60
+ match_kernels,
61
+ blur_kernel=[1, 3, 3, 1],
62
+ ):
63
+ super().__init__()
64
+
65
+
66
+ self.convs = nn.ModuleList()
67
+ # input at resolution 16*16
68
+ in_channel = channels[16]
69
+ self.log_size = int(math.log(size, 2))
70
+
71
+ for i in range(4, self.log_size + 1):
72
+ out_channel = channels[2 ** i]
73
+ num_label, match_kernel = num_labels[2 ** i], match_kernels[2 ** i]
74
+ use_distribution = num_label and match_kernel
75
+ upsample = (i != 4)
76
+
77
+ base_layer = functools.partial(
78
+ DecoderLayer,
79
+ out_channel=out_channel,
80
+ kernel_size=3,
81
+ blur_kernel=blur_kernel,
82
+ use_distribution=use_distribution,
83
+ num_label=num_label,
84
+ match_kernel=match_kernel
85
+ )
86
+
87
+ up = nn.Module()
88
+ up.conv0 = base_layer(in_channel=in_channel, upsample=upsample)
89
+ up.conv1 = base_layer(in_channel=out_channel, upsample=False)
90
+ up.to_rgb = ToRGB(out_channel, upsample=upsample)
91
+ self.convs.append(up)
92
+ in_channel = out_channel
93
+
94
+ self.num_labels, self.match_kernels = num_labels, match_kernels
95
+
96
+ self.edge_attn_block = Edge_Attn(in_channels=3)
97
+
98
+ def forward(self, input, neural_textures, recoder):
99
+ counter = 0
100
+ out, skip = input, None
101
+ for i, up in enumerate(self.convs):
102
+ if self.num_labels[2**(i+4)] and self.match_kernels[2**(i+4)]:
103
+ neural_texture_conv0 = neural_textures[counter]
104
+ neural_texture_conv1 = neural_textures[counter+1]
105
+ counter += 2
106
+ else:
107
+ neural_texture_conv0, neural_texture_conv1 = None, None
108
+ out = up.conv0(out, neural_texture=neural_texture_conv0, recoder=recoder)
109
+ out = up.conv1(out, neural_texture=neural_texture_conv1, recoder=recoder)
110
+
111
+ skip = up.to_rgb(out, skip)
112
+ image = self.edge_attn_block(skip)
113
+ # image = skip
114
+ return image
115
+
NTED/config.py ADDED
@@ -0,0 +1,202 @@
1
+ import collections
2
+ import functools
3
+ import os
4
+ import re
5
+
6
+ import yaml
7
+
8
+ class AttrDict(dict):
9
+ """Dict as attribute trick."""
10
+
11
+ def __init__(self, *args, **kwargs):
12
+ super(AttrDict, self).__init__(*args, **kwargs)
13
+ self.__dict__ = self
14
+ for key, value in self.__dict__.items():
15
+ if isinstance(value, dict):
16
+ self.__dict__[key] = AttrDict(value)
17
+ elif isinstance(value, (list, tuple)):
18
+ if isinstance(value[0], dict):
19
+ self.__dict__[key] = [AttrDict(item) for item in value]
20
+ else:
21
+ self.__dict__[key] = value
22
+
23
+ def yaml(self):
24
+ """Convert object to yaml dict and return."""
25
+ yaml_dict = {}
26
+ for key, value in self.__dict__.items():
27
+ if isinstance(value, AttrDict):
28
+ yaml_dict[key] = value.yaml()
29
+ elif isinstance(value, list):
30
+ if isinstance(value[0], AttrDict):
31
+ new_l = []
32
+ for item in value:
33
+ new_l.append(item.yaml())
34
+ yaml_dict[key] = new_l
35
+ else:
36
+ yaml_dict[key] = value
37
+ else:
38
+ yaml_dict[key] = value
39
+ return yaml_dict
40
+
41
+ def __repr__(self):
42
+ """Print all variables."""
43
+ ret_str = []
44
+ for key, value in self.__dict__.items():
45
+ if isinstance(value, AttrDict):
46
+ ret_str.append('{}:'.format(key))
47
+ child_ret_str = value.__repr__().split('\n')
48
+ for item in child_ret_str:
49
+ ret_str.append(' ' + item)
50
+ elif isinstance(value, list):
51
+ if isinstance(value[0], AttrDict):
52
+ ret_str.append('{}:'.format(key))
53
+ for item in value:
54
+ # Treat as AttrDict above.
55
+ child_ret_str = item.__repr__().split('\n')
56
+ for item in child_ret_str:
57
+ ret_str.append(' ' + item)
58
+ else:
59
+ ret_str.append('{}: {}'.format(key, value))
60
+ else:
61
+ ret_str.append('{}: {}'.format(key, value))
62
+ return '\n'.join(ret_str)
63
+
64
+
65
+ class Config(AttrDict):
66
+ r"""Configuration class. This should include every human specifiable
67
+ hyperparameter values for your training."""
68
+
69
+ def __init__(self, filename=None, verbose=False, is_train=True):
70
+ super(Config, self).__init__()
71
+ # Set default parameters.
72
+ # Logging.
73
+
74
+ large_number = 1000000000
75
+ self.snapshot_save_iter = large_number
76
+ self.snapshot_save_epoch = large_number
77
+ self.snapshot_save_start_iter = 0
78
+ self.snapshot_save_start_epoch = 0
79
+ self.image_save_iter = large_number
80
+ self.eval_epoch = large_number
81
+ self.start_eval_epoch = large_number
82
+ self.eval_epoch = large_number
83
+ self.max_epoch = large_number
84
+ self.max_iter = large_number
85
+ self.logging_iter = 100
86
+ self.image_to_tensorboard=False
87
+ self.which_iter = None
88
+ self.resume = True
89
+
90
+
91
+ self.checkpoints_dir = 'NTED'
92
+ self.name = 'nted_checkpoint.pt'
93
+ self.phase = 'train' if is_train else 'test'
94
+
95
+ # Networks.
96
+ self.gen = AttrDict(type='generators.dummy')
97
+ self.dis = AttrDict(type='discriminators.dummy')
98
+
99
+ # Optimizers.
100
+ self.gen_optimizer = AttrDict(type='adam',
101
+ lr=0.0001,
102
+ adam_beta1=0.0,
103
+ adam_beta2=0.999,
104
+ eps=1e-8,
105
+ lr_policy=AttrDict(iteration_mode=False,
106
+ type='step',
107
+ step_size=large_number,
108
+ gamma=1))
109
+ self.dis_optimizer = AttrDict(type='adam',
110
+ lr=0.0001,
111
+ adam_beta1=0.0,
112
+ adam_beta2=0.999,
113
+ eps=1e-8,
114
+ lr_policy=AttrDict(iteration_mode=False,
115
+ type='step',
116
+ step_size=large_number,
117
+ gamma=1))
118
+ # Data.
119
+ self.data = AttrDict(name='dummy',
120
+ type='datasets.images',
121
+ num_workers=0)
122
+ self.test_data = AttrDict(name='dummy',
123
+ type='datasets.images',
124
+ num_workers=0,
125
+ test=AttrDict(is_lmdb=False,
126
+ roots='',
127
+ batch_size=1))
128
+ self.trainer = AttrDict(
129
+ image_to_tensorboard=False,
130
+ hparam_to_tensorboard=False)
131
+
132
+ # Cudnn.
133
+ self.cudnn = AttrDict(deterministic=False,
134
+ benchmark=True)
135
+
136
+ # Others.
137
+ self.pretrained_weight = ''
138
+ self.inference_args = AttrDict()
139
+
140
+
141
+ # Update with given configurations.
142
+ assert os.path.exists(filename), 'File {} does not exist.'.format(filename)
143
+ loader = yaml.SafeLoader
144
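+ # let SafeLoader parse scientific-notation floats such as 1e-8, which stock YAML 1.1 reads as strings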
+ loader.add_implicit_resolver(
145
+ u'tag:yaml.org,2002:float',
146
+ re.compile(u'''^(?:
147
+ [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
148
+ |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
149
+ |\\.[0-9_]+(?:[eE][-+][0-9]+)?
150
+ |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
151
+ |[-+]?\\.(?:inf|Inf|INF)
152
+ |\\.(?:nan|NaN|NAN))$''', re.X),
153
+ list(u'-+0123456789.'))
154
+ try:
155
+ with open(filename, 'r') as f:
156
+ cfg_dict = yaml.load(f, Loader=loader)
157
+ except EnvironmentError:
158
+ print('Please check the file named "{}".'.format(filename))
159
+ recursive_update(self, cfg_dict)
160
+
161
+ # Put common opts in both gen and dis.
162
+ if 'common' in cfg_dict:
163
+ self.common = AttrDict(**cfg_dict['common'])
164
+ self.gen.common = self.common
165
+ self.dis.common = self.common
166
+
167
+
168
+ if verbose:
169
+ print(' config '.center(80, '-'))
170
+ print(self.__repr__())
171
+ print(''.center(80, '-'))
172
+
173
+
174
+ def rsetattr(obj, attr, val):
175
+ """Recursively find object and set value"""
176
+ pre, _, post = attr.rpartition('.')
177
+ return setattr(rgetattr(obj, pre) if pre else obj, post, val)
178
+
179
+
180
+ def rgetattr(obj, attr, *args):
181
+ """Recursively find object and return value"""
182
+
183
+ def _getattr(obj, attr):
184
+ r"""Get attribute."""
185
+ return getattr(obj, attr, *args)
186
+
187
+ return functools.reduce(_getattr, [obj] + attr.split('.'))
188
+
189
+
190
+ def recursive_update(d, u):
191
+ """Recursively update AttrDict d with AttrDict u"""
192
+ for key, value in u.items():
193
+ if isinstance(value, collections.abc.Mapping):
194
+ d.__dict__[key] = recursive_update(d.get(key, AttrDict({})), value)
195
+ elif isinstance(value, (list, tuple)):
196
+ if isinstance(value[0], dict):
197
+ d.__dict__[key] = [AttrDict(item) for item in value]
198
+ else:
199
+ d.__dict__[key] = value
200
+ else:
201
+ d.__dict__[key] = value
202
+ return d
NTED/demo_dataset.py ADDED
@@ -0,0 +1,182 @@
1
+
2
+ import os
3
+ import cv2
4
+ import math
5
+ import numpy as np
6
+ from PIL import Image
7
+
8
+ import torch
9
+ import torchvision.transforms.functional as F
10
+
11
+ class DemoDataset(object):
12
+ def __init__(self):
13
+ super().__init__()
14
+ self.LIMBSEQ = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
15
+ [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
16
+ [1, 16], [16, 18], [3, 17], [6, 18]]
17
+
18
+ self.COLORS = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
19
+ [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
20
+ [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
21
+
22
+ self.LIMBSEQ_hands = [[0, 1], [1, 2], [2, 3], [3, 4], \
23
+ [0, 5], [5, 6], [6, 7], [7, 8], \
24
+ [0, 9], [9, 10], [10, 11], [11, 12], \
25
+ [0, 13], [13, 14], [14, 15], [15, 16], \
26
+ [0, 17], [17, 18], [18, 19], [19, 20], \
27
+ [21, 22], [22, 23], [23, 24], [24, 25], \
28
+ [21, 26], [26, 27], [27, 28], [28, 29], \
29
+ [21, 30], [30, 31], [31, 32], [32, 33], \
30
+ [21, 34], [34, 35], [35, 36], [36, 37], \
31
+ [21, 38], [38, 39], [39, 40], [40, 41]]
32
+
33
+ self.COLORS_hands = [[85, 0, 0], [170, 0, 0], [85, 85, 0], [85, 170, 0], [170, 85, 0], [170, 170, 0], [85, 85, 85], \
34
+ [85, 85, 170], [85, 170, 85], [85, 170, 170], [0, 85, 0], [0, 170, 0], [0, 85, 85], [0, 85, 170], \
35
+ [0, 170, 85], [0, 170, 170], [50, 0, 0], [135, 0, 0], [50, 50, 0], [50, 135, 0], [135, 50, 0], \
36
+ [135, 135, 0], [50, 50, 50], [50, 50, 135], [50, 135, 50], [50, 135, 135], [0, 50, 0], [0, 135, 0], \
37
+ [0, 50, 50], [0, 50, 135], [0, 135, 50], [0, 135, 135], [100, 0, 0], [200, 0, 0], [100, 100, 0], \
38
+ [100, 200, 0], [200, 100, 0], [200, 200, 0], [100, 100, 100], [100, 100, 200], [100, 200, 100], [100, 200, 200]
39
+ ]
40
+
41
+ self.img_size = tuple([512, 352])
42
+
43
+ def load_item(self, img, pose, handpose=None):
44
+
45
+ reference_img = self.get_image_tensor(img)[None,:]
46
+ label, ske = self.get_label_tensor(pose, handpose)
47
+ label = label[None,:]
48
+
49
+ return {'reference_image':reference_img, 'target_skeleton':label, 'skeleton_img': ske}
50
+
51
+ def get_image_tensor(self, bgr_img):
52
+ img = Image.fromarray(cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB))
53
+ img = F.resize(img, self.img_size)
54
+ img = F.to_tensor(img)
55
+ img = F.normalize(img, (0.5, 0.5, 0.5),(0.5, 0.5, 0.5))
56
+ return img
57
+
58
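+ # draw body and hand keypoints/limbs onto a canvas and build per-limb distance-transform channels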
+ def get_label_tensor(self, pose, hand_pose=None):
59
+ canvas = np.zeros((self.img_size[0], self.img_size[1], 3)).astype(np.uint8)
60
+ keypoint = np.array(pose)
61
+ if hand_pose is not None:
62
+ keypoint_hands = np.array(hand_pose)
63
+ else:
64
+ keypoint_hands = None
65
+
66
+ # keypoint = self.trans_keypoins(keypoint)
67
+
68
+ stickwidth = 4
69
+ for i in range(18):
70
+ x, y = keypoint[i, 0:2]
71
+ if x == -1 or y == -1:
72
+ continue
73
+ cv2.circle(canvas, (int(x), int(y)), 4, self.COLORS[i], thickness=-1)
74
+ if keypoint_hands is not None:
75
+ for i in range(42):
76
+ x, y = keypoint_hands[i, 0:2]
77
+ if x == -1 or y == -1:
78
+ continue
79
+ cv2.circle(canvas, (int(x), int(y)), 4, self.COLORS_hands[i], thickness=-1)
80
+
81
+ joints = []
82
+ for i in range(17):
83
+ Y = keypoint[np.array(self.LIMBSEQ[i])-1, 0]
84
+ X = keypoint[np.array(self.LIMBSEQ[i])-1, 1]
85
+ cur_canvas = canvas.copy()
86
+ if -1 in Y or -1 in X:
87
+ joints.append(np.zeros_like(cur_canvas[:, :, 0]))
88
+ continue
89
+ mX = np.mean(X)
90
+ mY = np.mean(Y)
91
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
92
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
93
+ polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
94
+ cv2.fillConvexPoly(cur_canvas, polygon, self.COLORS[i])
95
+ canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
96
+
97
+ joint = np.zeros_like(cur_canvas[:, :, 0])
98
+ cv2.fillConvexPoly(joint, polygon, 255)
99
+ joint = cv2.addWeighted(joint, 0.4, joint, 0.6, 0)
100
+ joints.append(joint)
101
+ if keypoint_hands is not None:
102
+ for i in range(40):
103
+ Y = keypoint_hands[np.array(self.LIMBSEQ_hands[i]), 0]
104
+ X = keypoint_hands[np.array(self.LIMBSEQ_hands[i]), 1]
105
+ cur_canvas = canvas.copy()
106
+ if -1 in Y or -1 in X:
107
+ if (i+1) % 4 == 0:
108
+ joints.append(np.zeros_like(cur_canvas[:, :, 0]))
109
+ continue
110
+ mX = np.mean(X)
111
+ mY = np.mean(Y)
112
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
113
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
114
+ polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), int(stickwidth/2)), int(angle), 0, 360, 1)
115
+ cv2.fillConvexPoly(cur_canvas, polygon, self.COLORS_hands[i])
116
+ canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
117
+
118
+ # one channel per finger
119
+ if i % 4 == 0:
120
+ joint = np.zeros_like(cur_canvas[:, :, 0])
121
+ cv2.fillConvexPoly(joint, polygon, 255)
122
+ joint = cv2.addWeighted(joint, 0.4, joint, 0.6, 0)
123
+ if (i+1) % 4 == 0:
124
+ joints.append(joint)
125
+
126
+ pose = F.to_tensor(Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)))
127
+
128
+ tensors_dist = 0
129
+ e = 1
130
+ for i in range(len(joints)):
131
+ im_dist = cv2.distanceTransform(255-joints[i], cv2.DIST_L1, 3)
132
+ im_dist = np.clip((im_dist / 3), 0, 255).astype(np.uint8)
133
+ tensor_dist = F.to_tensor(Image.fromarray(im_dist))
134
+ tensors_dist = tensor_dist if e == 1 else torch.cat([tensors_dist, tensor_dist])
135
+ e += 1
136
+
137
+ label_tensor = torch.cat((pose, tensors_dist), dim=0)
138
+
139
+ return label_tensor, canvas
140
+
141
+ def tensor2im(self, image_tensor, imtype=np.uint8, normalize=True,
142
+ three_channel_output=True):
143
+ r"""Convert tensor to image.
144
+
145
+ Args:
146
+ image_tensor (torch.tensor or list of torch.tensor): If tensor then
147
+ (NxCxHxW) or (NxTxCxHxW) or (CxHxW).
148
+ imtype (np.dtype): Type of output image.
149
+ normalize (bool): Is the input image normalized or not?
150
+ three_channel_output (bool): Should single channel images be made 3
151
+ channel in output?
152
+
153
+ Returns:
154
+ (numpy.ndarray or list): a list in cases 1 and 2 above.
155
+ """
156
+ if image_tensor is None:
157
+ return None
158
+ if isinstance(image_tensor, list):
159
+ return [self.tensor2im(x, imtype, normalize) for x in image_tensor]
160
+ if image_tensor.dim() == 5 or image_tensor.dim() == 4:
161
+ return [self.tensor2im(image_tensor[idx], imtype, normalize)
162
+ for idx in range(image_tensor.size(0))]
163
+
164
+ if image_tensor.dim() == 3:
165
+ image_numpy = image_tensor.detach().float().numpy()
166
+ if normalize:
167
+ image_numpy = (np.transpose(
168
+ image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
169
+ else:
170
+ image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
171
+ image_numpy = np.clip(image_numpy, 0, 255)
172
+ if image_numpy.shape[2] == 1 and three_channel_output:
173
+ image_numpy = np.repeat(image_numpy, 3, axis=2)
174
+ elif image_numpy.shape[2] > 3:
175
+ image_numpy = image_numpy[:, :, :3]
176
+ return image_numpy.astype(imtype)
177
+
178
+ def trans_keypoins(self, keypoints):
179
+ missing_keypoint_index = keypoints == -1
180
+
181
+ keypoints[missing_keypoint_index] = -1
182
+ return keypoints
NTED/edge_attention_layer.py ADDED
@@ -0,0 +1,116 @@
1
+ # Date: 2023-03-14
2
+ # Creator: zejunyang
3
+ # Function: edge attention layer.
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from NTED.base_function import Blur
10
+
11
+
12
+ class ResBlock(nn.Module):
13
+ def __init__(self, in_nc, out_nc, scale='down'): # , norm_layer=nn.BatchNorm2d
14
+ super(ResBlock, self).__init__()
15
+ use_bias = True
16
+ assert scale in ['up', 'down', 'same'], "ResBlock scale must be in 'up' 'down' 'same'"
17
+
18
+ if scale == 'same':
19
+ # self.scale = nn.Conv2d(in_nc, out_nc, kernel_size=1, bias=True)
20
+ self.scale = nn.Conv2d(in_nc, out_nc, kernel_size=3, stride=1, padding=1, bias=True)
21
+ if scale == 'up':
22
+ self.scale = nn.Sequential(
23
+ nn.Upsample(scale_factor=2, mode='bilinear'),
24
+ nn.Conv2d(in_nc, out_nc, kernel_size=1,bias=True)
25
+ )
26
+ if scale == 'down':
27
+ self.scale = nn.Conv2d(in_nc, out_nc, kernel_size=3, stride=2, padding=1, bias=use_bias)
28
+
29
+ self.block = nn.Sequential(
30
+ nn.Conv2d(out_nc, out_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
31
+ # norm_layer(out_nc),
32
+ nn.ReLU(inplace=True),
33
+ nn.Conv2d(out_nc, out_nc, kernel_size=3, stride=1, padding=1, bias=use_bias),
34
+ # norm_layer(out_nc)
35
+ )
36
+ self.relu = nn.ReLU(inplace=True)
37
+ # self.padding = nn.ReplicationPad2d(padding=(0, 1, 0, 0))
38
+
39
+ def forward(self, x):
40
+ residual = self.scale(x)
41
+ return self.relu(residual + self.block(residual))
42
+
43
+
44
+ class Edge_Attn(nn.Module):
45
+ def __init__(self, in_channels=3):
46
+ super(Edge_Attn, self).__init__()
47
+ self.in_channels = in_channels
48
+
49
+ blur_kernel=[1, 3, 3, 3, 1]
50
+ self.blur = Blur(blur_kernel, pad=(2, 2), upsample_factor=1)
51
+
52
+ # self.conv = nn.Conv2d(self.in_channels, self.in_channels, 3, padding=1, bias=False)
53
+ self.res_block = ResBlock(self.in_channels, self.in_channels, scale='same')
54
+ self.sigmoid = nn.Sigmoid()
55
+
56
+ def gradient(self, x):
57
+ h_x = x.size()[2]
58
+ w_x = x.size()[3]
59
+ stride = 3
60
+ r = F.pad(x, (0, stride, 0, 0), mode='replicate')[:, :, :, stride:]
61
+ l = F.pad(x, (stride, 0, 0, 0), mode='replicate')[:, :, :, :w_x]
62
+ t = F.pad(x, (0, 0, stride, 0), mode='replicate')[:, :, :h_x, :]
63
+ b = F.pad(x, (0, 0, 0, stride), mode='replicate')[:, :, stride:, :]
64
+ xgrad = torch.pow(torch.pow((r - l) * 0.5, 2) + torch.pow((t - b) * 0.5, 2), 0.5)
65
+ xgrad = self.blur(xgrad)
66
+ return xgrad
67
+
68
+ def forward(self, x):
69
+ # feature_edge = self.gradient(x).detach()
70
+ # attn = self.conv(feature_edge)
71
+
72
+ for b in range(x.shape[0]):
73
+ for c in range(x.shape[1]):
74
+ if c == 0:
75
+ channel_edge = self.gradient(x[b:b+1, c:c+1])
76
+ else:
77
+ channel_edge = torch.concat([channel_edge, self.gradient(x[b:b+1, c:c+1])], dim=1)
78
+ if b == 0:
79
+ feature_edge = channel_edge
80
+ else:
81
+ feature_edge = torch.concat([feature_edge, channel_edge], dim=0)
82
+ feature_edge = feature_edge.detach()
83
+ feature_edge = x * feature_edge
84
+ attn = self.res_block(feature_edge)
85
+ attn = self.sigmoid(attn)
86
+
87
+ # out = x * attn
88
+
89
+ out = x * attn + x
90
+
91
+ return out
92
+
93
+
94
+
95
+ if __name__ == '__main__':
96
+ from PIL import Image
97
+ import numpy as np
98
+ import cv2
99
+
100
+ edg_atten = Edge_Attn()
101
+
102
+ im = Image.open('/apdcephfs/share_1474453/zejunzhang/dataset/pose_trans_dataset/fake_images/001400.png')
103
+ npim = np.array(im,dtype=np.float32)
104
+ npim = cv2.cvtColor(npim, cv2.COLOR_RGB2GRAY)
105
+
106
+ # npim = npim[:, :, 2]
107
+ tim = torch.from_numpy(npim).unsqueeze_(0).unsqueeze_(0)
108
+ edge = edg_atten.gradient(tim)
109
+ npgrad = edge.squeeze(0).squeeze(0).data.clamp(0,255).numpy()
110
+ Image.fromarray(npgrad.astype('uint8')).save('tmp.png')
111
+
112
+ # tim = torch.from_numpy(npim).unsqueeze_(0)
113
+ # edge = edg_atten.gradient_1order(tim)
114
+ # npgrad = edge.squeeze(0).data.clamp(0,255).numpy()[:, :, 0]
115
+ # Image.fromarray(npgrad.astype('uint8')).save('tmp.png')
116
+
NTED/extraction_distribution_model.py ADDED
@@ -0,0 +1,62 @@
1
+ import collections
2
+ from torch import nn
3
+ from NTED.base_module import Encoder, Decoder
4
+
5
+ from torch.cuda.amp import autocast as autocast
6
+
7
+ class Generator(nn.Module):
8
+ def __init__(
9
+ self,
10
+ size,
11
+ semantic_dim,
12
+ channels,
13
+ num_labels,
14
+ match_kernels,
15
+ blur_kernel=[1, 3, 3, 1],
16
+ ):
17
+ super().__init__()
18
+ self.size = size
19
+ self.reference_encoder = Encoder(
20
+ size, 3, channels, num_labels, match_kernels, blur_kernel
21
+ )
22
+
23
+ self.skeleton_encoder = Encoder(
24
+ size, semantic_dim, channels,
25
+ )
26
+
27
+ self.target_image_renderer = Decoder(
28
+ size, channels, num_labels, match_kernels, blur_kernel
29
+ )
30
+
31
+ def _cal_temp(self, module):
32
+ return sum(p.numel() for p in module.parameters() if p.requires_grad)
33
+
34
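+ # encode the target skeleton, extract neural textures from the reference image, then render the target image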
+ def forward(
35
+ self,
36
+ source_image,
37
+ skeleton,
38
+ amp_flag=False,
39
+ ):
40
+ if amp_flag:
41
+ with autocast():
42
+ output_dict={}
43
+ recoder = collections.defaultdict(list)
44
+ skeleton_feature = self.skeleton_encoder(skeleton)
45
+ _ = self.reference_encoder(source_image, recoder)
46
+ neural_textures = recoder["neural_textures"]
47
+ output_dict['fake_image'] = self.target_image_renderer(
48
+ skeleton_feature, neural_textures, recoder
49
+ )
50
+ output_dict['info'] = recoder
51
+ return output_dict
52
+ else:
53
+ output_dict={}
54
+ recoder = collections.defaultdict(list)
55
+ skeleton_feature = self.skeleton_encoder(skeleton)
56
+ _ = self.reference_encoder(source_image, recoder)
57
+ neural_textures = recoder["neural_textures"]
58
+ output_dict['fake_image'] = self.target_image_renderer(
59
+ skeleton_feature, neural_textures, recoder
60
+ )
61
+ output_dict['info'] = recoder
62
+ return output_dict
NTED/fashion_512.yaml ADDED
@@ -0,0 +1,129 @@
1
+ distributed: True
2
+ image_to_tensorboard: True
3
+ snapshot_save_iter: 50000
4
+ snapshot_save_epoch: 20
5
+ snapshot_save_start_iter: 20000
6
+ snapshot_save_start_epoch: 100
7
+ image_save_iter: 1000
8
+ max_epoch: 400
9
+ logging_iter: 100
10
+ amp: False
11
+
12
+ gen_optimizer:
13
+ type: adam
14
+ lr: 0.002
15
+ adam_beta1: 0.
16
+ adam_beta2: 0.99
17
+ lr_policy:
18
+ iteration_mode: False
19
+ type: step
20
+ step_size: 1000000
21
+ gamma: 1
22
+
23
+ dis_optimizer:
24
+ type: adam
25
+ lr: 0.001882
26
+ adam_beta1: 0.
27
+ adam_beta2: 0.9905
28
+ lr_policy:
29
+ iteration_mode: False
30
+ type: step
31
+ step_size: 1000000
32
+ gamma: 1
33
+
34
+
35
+ trainer:
36
+ type: NTED.extraction_distribution_trainer::Trainer
37
+ gan_mode: style_gan2
38
+ gan_start_iteration: 1000 # 0
39
+ face_crop_method: util.face_crop::crop_face_from_output
40
+ hand_crop_method: util.face_crop::crop_hands_from_output
41
+ d_reg_every: 16
42
+ r1: 10
43
+ loss_weight:
44
+ weight_perceptual: 1
45
+ weight_gan: 1.5
46
+ weight_attn_rec: 15
47
+ weight_face: 1
48
+ weight_hand: 1
49
+ weight_l1: 1
50
+ weight_l1_hand: 0.8
51
+ weight_edge: 100
52
+ attn_weights:
53
+ 8: 1
54
+ 16: 1
55
+ 32: 1
56
+ 64: 1
57
+ 128: 1
58
+ 256: 1
59
+ vgg_param:
60
+ network: vgg19
61
+ layers: ['relu_1_1', 'relu_2_1', 'relu_3_1', 'relu_4_1', 'relu_5_1']
62
+ num_scales: 3
63
+ use_style_loss: True
64
+ style_to_perceptual: 1000
65
+ vgg_hand_param:
66
+ network: vgg19
67
+ layers: ['relu_1_1', 'relu_2_1', 'relu_3_1','relu_3_3', 'relu_4_1', 'relu_4_3', 'relu_5_1']
68
+
69
+ gen:
70
+ type: NTED.extraction_distribution_model::Generator
71
+ param:
72
+ size: 512
73
+ semantic_dim: 30
74
+ channels:
75
+ 16: 512
76
+ 32: 512
77
+ 64: 512
78
+ 128: 256
79
+ 256: 128
80
+ 512: 64
81
+ 1024: 32
82
+ num_labels:
83
+ 16: 16
84
+ 32: 32
85
+ 64: 32
86
+ 128: 64
87
+ 256: 64
88
+ 512: False
89
+ match_kernels:
90
+ 16: 1
91
+ 32: 3
92
+ 64: 3
93
+ 128: 3
94
+ 256: 3
95
+ 512: False
96
+
97
+ dis:
98
+ type: generators.discriminator::Discriminator
99
+ param:
100
+ size: 512
101
+ channels:
102
+ 4: 512
103
+ 8: 512
104
+ 16: 512
105
+ 32: 512
106
+ 64: 512
107
+ 128: 256
108
+ 256: 128
109
+ 512: 64
110
+ is_square_image: False
111
+
112
+
113
+ data:
114
+ type: data.fashion_data::Dataset
115
+ preprocess_mode: resize_and_crop # resize_and_crop
116
+ path: /apdcephfs/share_1474453/zejunzhang/dataset/pose_trans_dataset_2d
117
+ num_workers: 16
118
+ sub_path: 512-352
119
+ resolution: 512
120
+ scale_param: 0.1
121
+ train:
122
+ batch_size: 4 # real_batch_size: 2 * 2 (source-->target & target --> source) * 4 (GPUs) = 16
123
+ distributed: True
124
+ val:
125
+ batch_size: 4
126
+ distributed: True
127
+ hand_keypoint: True
128
+
129
+
NTED/nted_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:359d3d3bac365afe04aa8b906f1dc8891f0dd87ff1dfe5e60059b4fb9bb96af8
3
+ size 284375285
NTED/op/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .fused_act import FusedLeakyReLU, fused_leaky_relu
2
+ from .upfirdn2d import upfirdn2d
NTED/op/conv2d_gradfix.py ADDED
@@ -0,0 +1,227 @@
1
+ import contextlib
2
+ import warnings
3
+
4
+ import torch
5
+ from torch import autograd
6
+ from torch.nn import functional as F
7
+
8
+ enabled = True
9
+ weight_gradients_disabled = False
10
+
11
+
12
+ @contextlib.contextmanager
13
+ def no_weight_gradients():
14
+ global weight_gradients_disabled
15
+
16
+ old = weight_gradients_disabled
17
+ weight_gradients_disabled = True
18
+ yield
19
+ weight_gradients_disabled = old
20
+
21
+
22
+ def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
23
+ if could_use_op(input):
24
+ return conv2d_gradfix(
25
+ transpose=False,
26
+ weight_shape=weight.shape,
27
+ stride=stride,
28
+ padding=padding,
29
+ output_padding=0,
30
+ dilation=dilation,
31
+ groups=groups,
32
+ ).apply(input, weight, bias)
33
+
34
+ return F.conv2d(
35
+ input=input,
36
+ weight=weight,
37
+ bias=bias,
38
+ stride=stride,
39
+ padding=padding,
40
+ dilation=dilation,
41
+ groups=groups,
42
+ )
43
+
44
+
45
+ def conv_transpose2d(
46
+ input,
47
+ weight,
48
+ bias=None,
49
+ stride=1,
50
+ padding=0,
51
+ output_padding=0,
52
+ groups=1,
53
+ dilation=1,
54
+ ):
55
+ if could_use_op(input):
56
+ return conv2d_gradfix(
57
+ transpose=True,
58
+ weight_shape=weight.shape,
59
+ stride=stride,
60
+ padding=padding,
61
+ output_padding=output_padding,
62
+ groups=groups,
63
+ dilation=dilation,
64
+ ).apply(input, weight, bias)
65
+
66
+ return F.conv_transpose2d(
67
+ input=input,
68
+ weight=weight,
69
+ bias=bias,
70
+ stride=stride,
71
+ padding=padding,
72
+ output_padding=output_padding,
73
+ dilation=dilation,
74
+ groups=groups,
75
+ )
76
+
77
+
78
+ def could_use_op(input):
79
+ if (not enabled) or (not torch.backends.cudnn.enabled):
80
+ return False
81
+
82
+ if input.device.type != "cuda":
83
+ return False
84
+
85
+ if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
86
+ return True
87
+
88
+ warnings.warn(
89
+ f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
90
+ )
91
+
92
+ return False
93
+
94
+
95
+ def ensure_tuple(xs, ndim):
96
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
97
+
98
+ return xs
99
+
100
+
101
+ conv2d_gradfix_cache = dict()
102
+
103
+
104
+ def conv2d_gradfix(
105
+ transpose, weight_shape, stride, padding, output_padding, dilation, groups
106
+ ):
107
+ ndim = 2
108
+ weight_shape = tuple(weight_shape)
109
+ stride = ensure_tuple(stride, ndim)
110
+ padding = ensure_tuple(padding, ndim)
111
+ output_padding = ensure_tuple(output_padding, ndim)
112
+ dilation = ensure_tuple(dilation, ndim)
113
+
114
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
115
+ if key in conv2d_gradfix_cache:
116
+ return conv2d_gradfix_cache[key]
117
+
118
+ common_kwargs = dict(
119
+ stride=stride, padding=padding, dilation=dilation, groups=groups
120
+ )
121
+
122
+ def calc_output_padding(input_shape, output_shape):
123
+ if transpose:
124
+ return [0, 0]
125
+
126
+ return [
127
+ input_shape[i + 2]
128
+ - (output_shape[i + 2] - 1) * stride[i]
129
+ - (1 - 2 * padding[i])
130
+ - dilation[i] * (weight_shape[i + 2] - 1)
131
+ for i in range(ndim)
132
+ ]
133
+
134
+ class Conv2d(autograd.Function):
135
+ @staticmethod
136
+ def forward(ctx, input, weight, bias):
137
+ if not transpose:
138
+ out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
139
+
140
+ else:
141
+ out = F.conv_transpose2d(
142
+ input=input,
143
+ weight=weight,
144
+ bias=bias,
145
+ output_padding=output_padding,
146
+ **common_kwargs,
147
+ )
148
+
149
+ ctx.save_for_backward(input, weight)
150
+
151
+ return out
152
+
153
+ @staticmethod
154
+ def backward(ctx, grad_output):
155
+ input, weight = ctx.saved_tensors
156
+ grad_input, grad_weight, grad_bias = None, None, None
157
+
158
+ if ctx.needs_input_grad[0]:
159
+ p = calc_output_padding(
160
+ input_shape=input.shape, output_shape=grad_output.shape
161
+ )
162
+ grad_input = conv2d_gradfix(
163
+ transpose=(not transpose),
164
+ weight_shape=weight_shape,
165
+ output_padding=p,
166
+ **common_kwargs,
167
+ ).apply(grad_output, weight, None)
168
+
169
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
170
+ grad_weight = Conv2dGradWeight.apply(grad_output, input)
171
+
172
+ if ctx.needs_input_grad[2]:
173
+ grad_bias = grad_output.sum((0, 2, 3))
174
+
175
+ return grad_input, grad_weight, grad_bias
176
+
177
+ class Conv2dGradWeight(autograd.Function):
178
+ @staticmethod
179
+ def forward(ctx, grad_output, input):
180
+ op = torch._C._jit_get_operation(
181
+ "aten::cudnn_convolution_backward_weight"
182
+ if not transpose
183
+ else "aten::cudnn_convolution_transpose_backward_weight"
184
+ )
185
+ flags = [
186
+ torch.backends.cudnn.benchmark,
187
+ torch.backends.cudnn.deterministic,
188
+ torch.backends.cudnn.allow_tf32,
189
+ ]
190
+ grad_weight = op(
191
+ weight_shape,
192
+ grad_output,
193
+ input,
194
+ padding,
195
+ stride,
196
+ dilation,
197
+ groups,
198
+ *flags,
199
+ )
200
+ ctx.save_for_backward(grad_output, input)
201
+
202
+ return grad_weight
203
+
204
+ @staticmethod
205
+ def backward(ctx, grad_grad_weight):
206
+ grad_output, input = ctx.saved_tensors
207
+ grad_grad_output, grad_grad_input = None, None
208
+
209
+ if ctx.needs_input_grad[0]:
210
+ grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
211
+
212
+ if ctx.needs_input_grad[1]:
213
+ p = calc_output_padding(
214
+ input_shape=input.shape, output_shape=grad_output.shape
215
+ )
216
+ grad_grad_input = conv2d_gradfix(
217
+ transpose=(not transpose),
218
+ weight_shape=weight_shape,
219
+ output_padding=p,
220
+ **common_kwargs,
221
+ ).apply(grad_output, grad_grad_weight, None)
222
+
223
+ return grad_grad_output, grad_grad_input
224
+
225
+ conv2d_gradfix_cache[key] = Conv2d
226
+
227
+ return Conv2d
NTED/op/fused_act.py ADDED
@@ -0,0 +1,127 @@
1
+ import os
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ from torch.autograd import Function
7
+ from torch.utils.cpp_extension import load
8
+
9
+
10
+ module_path = os.path.dirname(__file__)
11
+ fused = load(
12
+ "fused",
13
+ sources=[
14
+ os.path.join(module_path, "fused_bias_act.cpp"),
15
+ os.path.join(module_path, "fused_bias_act_kernel.cu"),
16
+ ],
17
+ )
18
+
19
+
20
+ class FusedLeakyReLUFunctionBackward(Function):
21
+ @staticmethod
22
+ def forward(ctx, grad_output, out, bias, negative_slope, scale):
23
+ ctx.save_for_backward(out)
24
+ ctx.negative_slope = negative_slope
25
+ ctx.scale = scale
26
+
27
+ empty = grad_output.new_empty(0)
28
+
29
+ grad_input = fused.fused_bias_act(
30
+ grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale
31
+ )
32
+
33
+ dim = [0]
34
+
35
+ if grad_input.ndim > 2:
36
+ dim += list(range(2, grad_input.ndim))
37
+
38
+ if bias:
39
+ grad_bias = grad_input.sum(dim).detach()
40
+
41
+ else:
42
+ grad_bias = empty
43
+
44
+ return grad_input, grad_bias
45
+
46
+ @staticmethod
47
+ def backward(ctx, gradgrad_input, gradgrad_bias):
48
+ out, = ctx.saved_tensors
49
+ gradgrad_out = fused.fused_bias_act(
50
+ gradgrad_input.contiguous(),
51
+ gradgrad_bias.to(gradgrad_input.dtype),
52
+ out,
53
+ 3,
54
+ 1,
55
+ ctx.negative_slope,
56
+ ctx.scale,
57
+ )
58
+
59
+ return gradgrad_out, None, None, None, None
60
+
61
+
62
+ class FusedLeakyReLUFunction(Function):
63
+ @staticmethod
64
+ def forward(ctx, input, bias, negative_slope, scale):
65
+ empty = input.new_empty(0)
66
+
67
+ ctx.bias = bias is not None
68
+
69
+ if bias is None:
70
+ bias = empty
71
+
72
+ out = fused.fused_bias_act(input, bias.to(input.dtype), empty, 3, 0, negative_slope, scale)
73
+ ctx.save_for_backward(out)
74
+ ctx.negative_slope = negative_slope
75
+ ctx.scale = scale
76
+
77
+ return out
78
+
79
+ @staticmethod
80
+ def backward(ctx, grad_output):
81
+ out, = ctx.saved_tensors
82
+
83
+ grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
84
+ grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
85
+ )
86
+
87
+ if not ctx.bias:
88
+ grad_bias = None
89
+
90
+ return grad_input, grad_bias, None, None
91
+
92
+
93
+ class FusedLeakyReLU(nn.Module):
94
+ def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
95
+ super().__init__()
96
+
97
+ if bias:
98
+ self.bias = nn.Parameter(torch.zeros(channel))
99
+
100
+ else:
101
+ self.bias = None
102
+
103
+ self.negative_slope = negative_slope
104
+ self.scale = scale
105
+
106
+ def forward(self, input):
107
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
108
+
109
+
110
+ def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
111
+ if input.device.type == "cpu":
112
+ if bias is not None:
113
+ rest_dim = [1] * (input.ndim - bias.ndim - 1)
114
+ return (
115
+ F.leaky_relu(
116
+ input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2
117
+ )
118
+ * scale
119
+ )
120
+
121
+ else:
122
+ return F.leaky_relu(input, negative_slope=0.2) * scale
123
+
124
+ else:
125
+ return FusedLeakyReLUFunction.apply(
126
+ input.contiguous(), bias, negative_slope, scale
127
+ )
NTED/op/fused_bias_act.cpp ADDED
@@ -0,0 +1,32 @@
1
+
2
+ #include <ATen/ATen.h>
3
+ #include <torch/extension.h>
4
+
5
+ torch::Tensor fused_bias_act_op(const torch::Tensor &input,
6
+ const torch::Tensor &bias,
7
+ const torch::Tensor &refer, int act, int grad,
8
+ float alpha, float scale);
9
+
10
+ #define CHECK_CUDA(x) \
11
+ TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
12
+ #define CHECK_CONTIGUOUS(x) \
13
+ TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
14
+ #define CHECK_INPUT(x) \
15
+ CHECK_CUDA(x); \
16
+ CHECK_CONTIGUOUS(x)
17
+
18
+ torch::Tensor fused_bias_act(const torch::Tensor &input,
19
+ const torch::Tensor &bias,
20
+ const torch::Tensor &refer, int act, int grad,
21
+ float alpha, float scale) {
22
+ CHECK_INPUT(input);
23
+ CHECK_INPUT(bias);
24
+
25
+ at::DeviceGuard guard(input.device());
26
+
27
+ return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
28
+ }
29
+
30
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
31
+ m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
32
+ }
NTED/op/fused_bias_act_kernel.cu ADDED
@@ -0,0 +1,105 @@
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
12
+ #include <ATen/cuda/CUDAContext.h>
13
+
14
+
15
+ #include <cuda.h>
16
+ #include <cuda_runtime.h>
17
+
18
+ template <typename scalar_t>
19
+ static __global__ void
20
+ fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b,
21
+ const scalar_t *p_ref, int act, int grad, scalar_t alpha,
22
+ scalar_t scale, int loop_x, int size_x, int step_b,
23
+ int size_b, int use_bias, int use_ref) {
24
+ int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
25
+
26
+ scalar_t zero = 0.0;
27
+
28
+ for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
29
+ loop_idx++, xi += blockDim.x) {
30
+ scalar_t x = p_x[xi];
31
+
32
+ if (use_bias) {
33
+ x += p_b[(xi / step_b) % size_b];
34
+ }
35
+
36
+ scalar_t ref = use_ref ? p_ref[xi] : zero;
37
+
38
+ scalar_t y;
39
+
40
+ switch (act * 10 + grad) {
41
+ default:
42
+ case 10:
43
+ y = x;
44
+ break;
45
+ case 11:
46
+ y = x;
47
+ break;
48
+ case 12:
49
+ y = 0.0;
50
+ break;
51
+
52
+ case 30:
53
+ y = (x > 0.0) ? x : x * alpha;
54
+ break;
55
+ case 31:
56
+ y = (ref > 0.0) ? x : x * alpha;
57
+ break;
58
+ case 32:
59
+ y = 0.0;
60
+ break;
61
+ }
62
+
63
+ out[xi] = y * scale;
64
+ }
65
+ }
66
+
67
+ torch::Tensor fused_bias_act_op(const torch::Tensor &input,
68
+ const torch::Tensor &bias,
69
+ const torch::Tensor &refer, int act, int grad,
70
+ float alpha, float scale) {
71
+ int curDevice = -1;
72
+ cudaGetDevice(&curDevice);
73
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
74
+
75
+ auto x = input.contiguous();
76
+ auto b = bias.contiguous();
77
+ auto ref = refer.contiguous();
78
+
79
+ int use_bias = b.numel() ? 1 : 0;
80
+ int use_ref = ref.numel() ? 1 : 0;
81
+
82
+ int size_x = x.numel();
83
+ int size_b = b.numel();
84
+ int step_b = 1;
85
+
86
+ for (int i = 1 + 1; i < x.dim(); i++) {
87
+ step_b *= x.size(i);
88
+ }
89
+
90
+ int loop_x = 4;
91
+ int block_size = 4 * 32;
92
+ int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
93
+
94
+ auto y = torch::empty_like(x);
95
+
96
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
97
+ x.scalar_type(), "fused_bias_act_kernel", [&] {
98
+ fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
99
+ y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
100
+ b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
101
+ scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
102
+ });
103
+
104
+ return y;
105
+ }
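
For reference, the act * 10 + grad switch in the kernel above encodes the activation (1 = linear, 3 = leaky ReLU) together with the derivative order (0 = forward, 1 = first backward pass, 2 = second backward pass, which is zero for both activations). A hedged pure-PyTorch sketch of the same table, useful as a cross-check against the kernel (bias addition omitted):

import torch

def bias_act_reference(x, ref, act, grad, alpha, scale):
    key = act * 10 + grad
    if key in (10, 11):                # linear: forward and first derivative pass through
        y = x
    elif key == 30:                    # leaky ReLU forward
        y = torch.where(x > 0, x, x * alpha)
    elif key == 31:                    # leaky ReLU backward, gated by the saved reference output
        y = torch.where(ref > 0, x, x * alpha)
    else:                              # keys 12 and 32: second derivative is zero
        y = torch.zeros_like(x)
    return y * scale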
NTED/op/upfirdn2d.cpp ADDED
@@ -0,0 +1,31 @@
1
+ #include <ATen/ATen.h>
2
+ #include <torch/extension.h>
3
+
4
+ torch::Tensor upfirdn2d_op(const torch::Tensor &input,
5
+ const torch::Tensor &kernel, int up_x, int up_y,
6
+ int down_x, int down_y, int pad_x0, int pad_x1,
7
+ int pad_y0, int pad_y1);
8
+
9
+ #define CHECK_CUDA(x) \
10
+ TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
11
+ #define CHECK_CONTIGUOUS(x) \
12
+ TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
13
+ #define CHECK_INPUT(x) \
14
+ CHECK_CUDA(x); \
15
+ CHECK_CONTIGUOUS(x)
16
+
17
+ torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel,
18
+ int up_x, int up_y, int down_x, int down_y, int pad_x0,
19
+ int pad_x1, int pad_y0, int pad_y1) {
20
+ CHECK_INPUT(input);
21
+ CHECK_INPUT(kernel);
22
+
23
+ at::DeviceGuard guard(input.device());
24
+
25
+ return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
26
+ pad_y0, pad_y1);
27
+ }
28
+
29
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
30
+ m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
31
+ }
NTED/op/upfirdn2d.py ADDED
@@ -0,0 +1,209 @@
1
+ from collections import abc
2
+ import os
3
+
4
+ import torch
5
+ from torch.nn import functional as F
6
+ from torch.autograd import Function
7
+ from torch.utils.cpp_extension import load
8
+
9
+
10
+ module_path = os.path.dirname(__file__)
11
+ upfirdn2d_op = load(
12
+ "upfirdn2d",
13
+ sources=[
14
+ os.path.join(module_path, "upfirdn2d.cpp"),
15
+ os.path.join(module_path, "upfirdn2d_kernel.cu"),
16
+ ],
17
+ )
18
+
19
+
20
+ class UpFirDn2dBackward(Function):
21
+ @staticmethod
22
+ def forward(
23
+ ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
24
+ ):
25
+
26
+ up_x, up_y = up
27
+ down_x, down_y = down
28
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
29
+
30
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
31
+
32
+ grad_input = upfirdn2d_op.upfirdn2d(
33
+ grad_output,
34
+ grad_kernel.to(grad_output.dtype),
35
+ down_x,
36
+ down_y,
37
+ up_x,
38
+ up_y,
39
+ g_pad_x0,
40
+ g_pad_x1,
41
+ g_pad_y0,
42
+ g_pad_y1,
43
+ )
44
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
45
+
46
+ ctx.save_for_backward(kernel)
47
+
48
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
49
+
50
+ ctx.up_x = up_x
51
+ ctx.up_y = up_y
52
+ ctx.down_x = down_x
53
+ ctx.down_y = down_y
54
+ ctx.pad_x0 = pad_x0
55
+ ctx.pad_x1 = pad_x1
56
+ ctx.pad_y0 = pad_y0
57
+ ctx.pad_y1 = pad_y1
58
+ ctx.in_size = in_size
59
+ ctx.out_size = out_size
60
+
61
+ return grad_input
62
+
63
+ @staticmethod
64
+ def backward(ctx, gradgrad_input):
65
+ kernel, = ctx.saved_tensors
66
+
67
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
68
+
69
+ gradgrad_out = upfirdn2d_op.upfirdn2d(
70
+ gradgrad_input,
71
+ kernel.to(gradgrad_input.dtype),
72
+ ctx.up_x,
73
+ ctx.up_y,
74
+ ctx.down_x,
75
+ ctx.down_y,
76
+ ctx.pad_x0,
77
+ ctx.pad_x1,
78
+ ctx.pad_y0,
79
+ ctx.pad_y1,
80
+ )
81
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
82
+ gradgrad_out = gradgrad_out.view(
83
+ ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
84
+ )
85
+
86
+ return gradgrad_out, None, None, None, None, None, None, None, None
87
+
88
+
89
+ class UpFirDn2d(Function):
90
+ @staticmethod
91
+ def forward(ctx, input, kernel, up, down, pad):
92
+ up_x, up_y = up
93
+ down_x, down_y = down
94
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
95
+
96
+ kernel_h, kernel_w = kernel.shape
97
+ batch, channel, in_h, in_w = input.shape
98
+ ctx.in_size = input.shape
99
+
100
+ input = input.reshape(-1, in_h, in_w, 1)
101
+
102
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
103
+
104
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
105
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
106
+ ctx.out_size = (out_h, out_w)
107
+
108
+ ctx.up = (up_x, up_y)
109
+ ctx.down = (down_x, down_y)
110
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
111
+
112
+ g_pad_x0 = kernel_w - pad_x0 - 1
113
+ g_pad_y0 = kernel_h - pad_y0 - 1
114
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
115
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
116
+
117
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
118
+
119
+ out = upfirdn2d_op.upfirdn2d(
120
+ input, kernel.to(input.dtype), up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
121
+ )
122
+ # out = out.view(major, out_h, out_w, minor)
123
+ out = out.view(-1, channel, out_h, out_w)
124
+
125
+ return out
126
+
127
+ @staticmethod
128
+ def backward(ctx, grad_output):
129
+ kernel, grad_kernel = ctx.saved_tensors
130
+
131
+ grad_input = None
132
+
133
+ if ctx.needs_input_grad[0]:
134
+ grad_input = UpFirDn2dBackward.apply(
135
+ grad_output,
136
+ kernel,
137
+ grad_kernel,
138
+ ctx.up,
139
+ ctx.down,
140
+ ctx.pad,
141
+ ctx.g_pad,
142
+ ctx.in_size,
143
+ ctx.out_size,
144
+ )
145
+
146
+ return grad_input, None, None, None, None
147
+
148
+
149
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
150
+ if not isinstance(up, abc.Iterable):
151
+ up = (up, up)
152
+
153
+ if not isinstance(down, abc.Iterable):
154
+ down = (down, down)
155
+
156
+ if len(pad) == 2:
157
+ pad = (pad[0], pad[1], pad[0], pad[1])
158
+
159
+ if input.device.type == "cpu":
160
+ out = upfirdn2d_native(input, kernel, *up, *down, *pad)
161
+
162
+ else:
163
+ out = UpFirDn2d.apply(input, kernel, up, down, pad)
164
+
165
+ return out
166
+
167
+
168
+ def upfirdn2d_native(
169
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
170
+ ):
171
+ _, channel, in_h, in_w = input.shape
172
+ input = input.reshape(-1, in_h, in_w, 1)
173
+
174
+ _, in_h, in_w, minor = input.shape
175
+ kernel_h, kernel_w = kernel.shape
176
+
177
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
178
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
179
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
180
+
181
+ out = F.pad(
182
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
183
+ )
184
+ out = out[
185
+ :,
186
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
187
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
188
+ :,
189
+ ]
190
+
191
+ out = out.permute(0, 3, 1, 2)
192
+ out = out.reshape(
193
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
194
+ )
195
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
196
+ out = F.conv2d(out, w)
197
+ out = out.reshape(
198
+ -1,
199
+ minor,
200
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
201
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
202
+ )
203
+ out = out.permute(0, 2, 3, 1)
204
+ out = out[:, ::down_y, ::down_x, :]
205
+
206
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
207
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
208
+
209
+ return out.view(-1, channel, out_h, out_w)
NTED/op/upfirdn2d_kernel.cu ADDED
@@ -0,0 +1,369 @@
1
+ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2
+ //
3
+ // This work is made available under the Nvidia Source Code License-NC.
4
+ // To view a copy of this license, visit
5
+ // https://nvlabs.github.io/stylegan2/license.html
6
+
7
+ #include <torch/types.h>
8
+
9
+ #include <ATen/ATen.h>
10
+ #include <ATen/AccumulateType.h>
11
+ #include <ATen/cuda/CUDAApplyUtils.cuh>
12
+ #include <ATen/cuda/CUDAContext.h>
13
+
14
+ #include <cuda.h>
15
+ #include <cuda_runtime.h>
16
+
17
+ static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
18
+ int c = a / b;
19
+
20
+ if (c * b > a) {
21
+ c--;
22
+ }
23
+
24
+ return c;
25
+ }
26
+
27
+ struct UpFirDn2DKernelParams {
28
+ int up_x;
29
+ int up_y;
30
+ int down_x;
31
+ int down_y;
32
+ int pad_x0;
33
+ int pad_x1;
34
+ int pad_y0;
35
+ int pad_y1;
36
+
37
+ int major_dim;
38
+ int in_h;
39
+ int in_w;
40
+ int minor_dim;
41
+ int kernel_h;
42
+ int kernel_w;
43
+ int out_h;
44
+ int out_w;
45
+ int loop_major;
46
+ int loop_x;
47
+ };
48
+
49
+ template <typename scalar_t>
50
+ __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
51
+ const scalar_t *kernel,
52
+ const UpFirDn2DKernelParams p) {
53
+ int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
54
+ int out_y = minor_idx / p.minor_dim;
55
+ minor_idx -= out_y * p.minor_dim;
56
+ int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
57
+ int major_idx_base = blockIdx.z * p.loop_major;
58
+
59
+ if (out_x_base >= p.out_w || out_y >= p.out_h ||
60
+ major_idx_base >= p.major_dim) {
61
+ return;
62
+ }
63
+
64
+ int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
65
+ int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
66
+ int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
67
+ int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
68
+
69
+ for (int loop_major = 0, major_idx = major_idx_base;
70
+ loop_major < p.loop_major && major_idx < p.major_dim;
71
+ loop_major++, major_idx++) {
72
+ for (int loop_x = 0, out_x = out_x_base;
73
+ loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
74
+ int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
75
+ int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
76
+ int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
77
+ int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
78
+
79
+ const scalar_t *x_p =
80
+ &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
81
+ minor_idx];
82
+ const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
83
+ int x_px = p.minor_dim;
84
+ int k_px = -p.up_x;
85
+ int x_py = p.in_w * p.minor_dim;
86
+ int k_py = -p.up_y * p.kernel_w;
87
+
88
+ scalar_t v = 0.0f;
89
+
90
+ for (int y = 0; y < h; y++) {
91
+ for (int x = 0; x < w; x++) {
92
+ v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
93
+ x_p += x_px;
94
+ k_p += k_px;
95
+ }
96
+
97
+ x_p += x_py - w * x_px;
98
+ k_p += k_py - w * k_px;
99
+ }
100
+
101
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
102
+ minor_idx] = v;
103
+ }
104
+ }
105
+ }
106
+
107
+ template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
108
+ int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
109
+ __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
110
+ const scalar_t *kernel,
111
+ const UpFirDn2DKernelParams p) {
112
+ const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
113
+ const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
114
+
115
+ __shared__ volatile float sk[kernel_h][kernel_w];
116
+ __shared__ volatile float sx[tile_in_h][tile_in_w];
117
+
118
+ int minor_idx = blockIdx.x;
119
+ int tile_out_y = minor_idx / p.minor_dim;
120
+ minor_idx -= tile_out_y * p.minor_dim;
121
+ tile_out_y *= tile_out_h;
122
+ int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
123
+ int major_idx_base = blockIdx.z * p.loop_major;
124
+
125
+ if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
126
+ major_idx_base >= p.major_dim) {
127
+ return;
128
+ }
129
+
130
+ for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
131
+ tap_idx += blockDim.x) {
132
+ int ky = tap_idx / kernel_w;
133
+ int kx = tap_idx - ky * kernel_w;
134
+ scalar_t v = 0.0;
135
+
136
+ if (kx < p.kernel_w & ky < p.kernel_h) {
137
+ v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
138
+ }
139
+
140
+ sk[ky][kx] = v;
141
+ }
142
+
143
+ for (int loop_major = 0, major_idx = major_idx_base;
144
+ loop_major < p.loop_major & major_idx < p.major_dim;
145
+ loop_major++, major_idx++) {
146
+ for (int loop_x = 0, tile_out_x = tile_out_x_base;
147
+ loop_x < p.loop_x & tile_out_x < p.out_w;
148
+ loop_x++, tile_out_x += tile_out_w) {
149
+ int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
150
+ int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
151
+ int tile_in_x = floor_div(tile_mid_x, up_x);
152
+ int tile_in_y = floor_div(tile_mid_y, up_y);
153
+
154
+ __syncthreads();
155
+
156
+ for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
157
+ in_idx += blockDim.x) {
158
+ int rel_in_y = in_idx / tile_in_w;
159
+ int rel_in_x = in_idx - rel_in_y * tile_in_w;
160
+ int in_x = rel_in_x + tile_in_x;
161
+ int in_y = rel_in_y + tile_in_y;
162
+
163
+ scalar_t v = 0.0;
164
+
165
+ if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
166
+ v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
167
+ p.minor_dim +
168
+ minor_idx];
169
+ }
170
+
171
+ sx[rel_in_y][rel_in_x] = v;
172
+ }
173
+
174
+ __syncthreads();
175
+ for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
176
+ out_idx += blockDim.x) {
177
+ int rel_out_y = out_idx / tile_out_w;
178
+ int rel_out_x = out_idx - rel_out_y * tile_out_w;
179
+ int out_x = rel_out_x + tile_out_x;
180
+ int out_y = rel_out_y + tile_out_y;
181
+
182
+ int mid_x = tile_mid_x + rel_out_x * down_x;
183
+ int mid_y = tile_mid_y + rel_out_y * down_y;
184
+ int in_x = floor_div(mid_x, up_x);
185
+ int in_y = floor_div(mid_y, up_y);
186
+ int rel_in_x = in_x - tile_in_x;
187
+ int rel_in_y = in_y - tile_in_y;
188
+ int kernel_x = (in_x + 1) * up_x - mid_x - 1;
189
+ int kernel_y = (in_y + 1) * up_y - mid_y - 1;
190
+
191
+ scalar_t v = 0.0;
192
+
193
+ #pragma unroll
194
+ for (int y = 0; y < kernel_h / up_y; y++)
195
+ #pragma unroll
196
+ for (int x = 0; x < kernel_w / up_x; x++)
197
+ v += sx[rel_in_y + y][rel_in_x + x] *
198
+ sk[kernel_y + y * up_y][kernel_x + x * up_x];
199
+
200
+ if (out_x < p.out_w & out_y < p.out_h) {
201
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
202
+ minor_idx] = v;
203
+ }
204
+ }
205
+ }
206
+ }
207
+ }
208
+
209
+ torch::Tensor upfirdn2d_op(const torch::Tensor &input,
210
+ const torch::Tensor &kernel, int up_x, int up_y,
211
+ int down_x, int down_y, int pad_x0, int pad_x1,
212
+ int pad_y0, int pad_y1) {
213
+ int curDevice = -1;
214
+ cudaGetDevice(&curDevice);
215
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
216
+
217
+ UpFirDn2DKernelParams p;
218
+
219
+ auto x = input.contiguous();
220
+ auto k = kernel.contiguous();
221
+
222
+ p.major_dim = x.size(0);
223
+ p.in_h = x.size(1);
224
+ p.in_w = x.size(2);
225
+ p.minor_dim = x.size(3);
226
+ p.kernel_h = k.size(0);
227
+ p.kernel_w = k.size(1);
228
+ p.up_x = up_x;
229
+ p.up_y = up_y;
230
+ p.down_x = down_x;
231
+ p.down_y = down_y;
232
+ p.pad_x0 = pad_x0;
233
+ p.pad_x1 = pad_x1;
234
+ p.pad_y0 = pad_y0;
235
+ p.pad_y1 = pad_y1;
236
+
237
+ p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
238
+ p.down_y;
239
+ p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
240
+ p.down_x;
241
+
242
+ auto out =
243
+ at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
244
+
245
+ int mode = -1;
246
+
247
+ int tile_out_h = -1;
248
+ int tile_out_w = -1;
249
+
250
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
251
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
252
+ mode = 1;
253
+ tile_out_h = 16;
254
+ tile_out_w = 64;
255
+ }
256
+
257
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
258
+ p.kernel_h <= 3 && p.kernel_w <= 3) {
259
+ mode = 2;
260
+ tile_out_h = 16;
261
+ tile_out_w = 64;
262
+ }
263
+
264
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
265
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
266
+ mode = 3;
267
+ tile_out_h = 16;
268
+ tile_out_w = 64;
269
+ }
270
+
271
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
272
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
273
+ mode = 4;
274
+ tile_out_h = 16;
275
+ tile_out_w = 64;
276
+ }
277
+
278
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
279
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
280
+ mode = 5;
281
+ tile_out_h = 8;
282
+ tile_out_w = 32;
283
+ }
284
+
285
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
286
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
287
+ mode = 6;
288
+ tile_out_h = 8;
289
+ tile_out_w = 32;
290
+ }
291
+
292
+ dim3 block_size;
293
+ dim3 grid_size;
294
+
295
+ if (tile_out_h > 0 && tile_out_w > 0) {
296
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
297
+ p.loop_x = 1;
298
+ block_size = dim3(32 * 8, 1, 1);
299
+ grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
300
+ (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
301
+ (p.major_dim - 1) / p.loop_major + 1);
302
+ } else {
303
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
304
+ p.loop_x = 4;
305
+ block_size = dim3(4, 32, 1);
306
+ grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
307
+ (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
308
+ (p.major_dim - 1) / p.loop_major + 1);
309
+ }
310
+
311
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
312
+ switch (mode) {
313
+ case 1:
314
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
315
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
316
+ x.data_ptr<scalar_t>(),
317
+ k.data_ptr<scalar_t>(), p);
318
+
319
+ break;
320
+
321
+ case 2:
322
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
323
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
324
+ x.data_ptr<scalar_t>(),
325
+ k.data_ptr<scalar_t>(), p);
326
+
327
+ break;
328
+
329
+ case 3:
330
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
331
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
332
+ x.data_ptr<scalar_t>(),
333
+ k.data_ptr<scalar_t>(), p);
334
+
335
+ break;
336
+
337
+ case 4:
338
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
339
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
340
+ x.data_ptr<scalar_t>(),
341
+ k.data_ptr<scalar_t>(), p);
342
+
343
+ break;
344
+
345
+ case 5:
346
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
347
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
348
+ x.data_ptr<scalar_t>(),
349
+ k.data_ptr<scalar_t>(), p);
350
+
351
+ break;
352
+
353
+ case 6:
354
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
355
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
356
+ x.data_ptr<scalar_t>(),
357
+ k.data_ptr<scalar_t>(), p);
358
+
359
+ break;
360
+
361
+ default:
362
+ upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
363
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
364
+ k.data_ptr<scalar_t>(), p);
365
+ }
366
+ });
367
+
368
+ return out;
369
+ }
app.py CHANGED
@@ -1,17 +1,29 @@
1
  import gradio as gr
2
 
3
- def greet(年龄预测器_输入您的年龄):
4
- return "恭喜,您今年" + 年龄预测器_输入您的年龄 + "岁了!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
 
8
- demo.launch()
9
 
10
  '''
11
  TODO
12
- First integrate openpose light and test it
13
-
14
  Test the video display feature
15
-
16
-
17
  '''
 
1
  import gradio as gr
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ from NTED.NTED_module import NTED
6
 
7
+ NTED_Module = NTED()
 
8
 
9
+ def pose_transfer(上传人体姿态图):
10
+ img = 上传人体姿态图
11
+ fake_img = NTED_Module.inference(img)
12
+
13
+ return fake_img
14
 
15
+ with gr.Column():
16
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
17
+
18
+ gr.Interface(fn=pose_transfer,
19
+ inputs=["image"],
20
+ outputs=[result_gallery],
21
+ title="谷小雨姿态驱动图像",
22
+ examples=[["example/exp1.png"], ["example/exp2.png"], ["example/exp3.png"],\
23
+ ["example/exp4.png"], ["example/exp5.png"], ["example/exp6.png"]],
24
+ ).launch(server_name='0.0.0.0')
25
 
26
  '''
27
  TODO
28
  Test the video display feature
29
  '''
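
A minimal local smoke test of the new inference path, independent of the Gradio UI (it assumes NTED/nted_checkpoint.pt, lite_openpose/checkpoint_iter_370000.pth, and the example images added in this commit are all in place):

import cv2
from NTED.NTED_module import NTED

nted = NTED()                               # loads both checkpoints and example/ref_img.png on CPU
pose_img = cv2.imread("example/exp1.png")   # one of the driving-pose examples above
fake_img = nted.inference(pose_img)         # the same call pose_transfer() makes
print(type(fake_img))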
example/exp1.png ADDED
example/exp2.png ADDED
example/exp3.png ADDED
example/exp4.png ADDED
example/exp5.png ADDED
example/exp6.png ADDED
example/ref_img.png ADDED

Git LFS Details

  • SHA256: b3396e7f8e0a18f0c8dc50d1f98cabf26c13d8629e5b454a680531ea6daf31ed
  • Pointer size: 132 Bytes
  • Size of remote file: 1.01 MB
lite_openpose/body_bbox_detector.py ADDED
@@ -0,0 +1,179 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+
3
+ import os
4
+ import os.path as osp
5
+ import sys
6
+ import numpy as np
7
+ import cv2
8
+ import math
9
+
10
+ import torch
11
+ import torchvision.transforms as transforms
12
+ # from PIL import Image
13
+
14
+ # Code from https://github.com/Daniil-Osokin/lightweight-human-pose-estimation.pytorch/blob/master/demo.py
15
+
16
+ # 2D body pose estimator
17
+ sys.path.append('/apdcephfs/share_1474453/zejunzhang/workspace/HR-VITON/dataset_process_utils/lite_openpose')
18
+ from pose2d_models.with_mobilenet import PoseEstimationWithMobileNet
19
+ from modules.load_state import load_state
20
+ from modules.pose import Pose, track_poses
21
+ from modules.keypoints import extract_keypoints, group_keypoints
22
+
23
+
24
+ def normalize(img, img_mean, img_scale):
25
+ img = np.array(img, dtype=np.float32)
26
+ img = (img - img_mean) * img_scale
27
+ return img
28
+
29
+
30
+ def pad_width(img, stride, pad_value, min_dims):
31
+ h, w, _ = img.shape
32
+ h = min(min_dims[0], h)
33
+ min_dims[0] = math.ceil(min_dims[0] / float(stride)) * stride
34
+ min_dims[1] = max(min_dims[1], w)
35
+ min_dims[1] = math.ceil(min_dims[1] / float(stride)) * stride
36
+ pad = []
37
+ pad.append(int(math.floor((min_dims[0] - h) / 2.0)))
38
+ pad.append(int(math.floor((min_dims[1] - w) / 2.0)))
39
+ pad.append(int(min_dims[0] - h - pad[0]))
40
+ pad.append(int(min_dims[1] - w - pad[1]))
41
+ padded_img = cv2.copyMakeBorder(img, pad[0], pad[2], pad[1], pad[3],
42
+ cv2.BORDER_CONSTANT, value=pad_value)
43
+ return padded_img, pad
44
+
45
+
46
+ class BodyPoseEstimator(object):
47
+ """
48
+ Hand Detector for third-view input.
49
+ It combines a body pose estimator (https://github.com/jhugestar/lightweight-human-pose-estimation.pytorch.git)
50
+ """
51
+ def __init__(self, device='cpu'):
52
+ # print("Loading Body Pose Estimator")
53
+ self.device=device
54
+ self.__load_body_estimator()
55
+
56
+
57
+
58
+ def __load_body_estimator(self):
59
+ net = PoseEstimationWithMobileNet()
60
+ pose2d_checkpoint = "lite_openpose/checkpoint_iter_370000.pth"
61
+ checkpoint = torch.load(pose2d_checkpoint, map_location='cpu')
62
+ load_state(net, checkpoint)
63
+ net = net.eval()
64
+ net = net.to(self.device)
65
+ self.model = net
66
+
67
+
68
+ #Code from https://github.com/Daniil-Osokin/lightweight-human-pose-estimation.pytorch/demo.py
69
+ def __infer_fast(self, img, input_height_size, stride, upsample_ratio,
70
+ cpu=False, pad_value=(0, 0, 0), img_mean=(128, 128, 128), img_scale=1/256):
71
+ height, width, _ = img.shape
72
+ scale = input_height_size / height
73
+
74
+ scaled_img = cv2.resize(img, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
75
+ scaled_img = normalize(scaled_img, img_mean, img_scale)
76
+ min_dims = [input_height_size, max(scaled_img.shape[1], input_height_size)]
77
+ padded_img, pad = pad_width(scaled_img, stride, pad_value, min_dims)
78
+
79
+ tensor_img = torch.from_numpy(padded_img).permute(2, 0, 1).unsqueeze(0).float()
80
+ if not cpu:
81
+ tensor_img = tensor_img.to(self.device)
82
+
83
+ with torch.no_grad():
84
+ stages_output = self.model(tensor_img)
85
+
86
+ stage2_heatmaps = stages_output[-2]
87
+ heatmaps = np.transpose(stage2_heatmaps.squeeze().cpu().data.numpy(), (1, 2, 0))
88
+ heatmaps = cv2.resize(heatmaps, (0, 0), fx=upsample_ratio, fy=upsample_ratio, interpolation=cv2.INTER_CUBIC)
89
+
90
+ stage2_pafs = stages_output[-1]
91
+ pafs = np.transpose(stage2_pafs.squeeze().cpu().data.numpy(), (1, 2, 0))
92
+ pafs = cv2.resize(pafs, (0, 0), fx=upsample_ratio, fy=upsample_ratio, interpolation=cv2.INTER_CUBIC)
93
+
94
+ return heatmaps, pafs, scale, pad
95
+
96
+ def detect_body_pose(self, img):
97
+ """
98
+ Output:
99
+ current_bbox: BBOX_XYWH
100
+ """
101
+ stride = 8
102
+ upsample_ratio = 4
103
+ orig_img = img.copy()
104
+
105
+ # forward
106
+ heatmaps, pafs, scale, pad = self.__infer_fast(img,
107
+ input_height_size=256, stride=stride, upsample_ratio=upsample_ratio)
108
+
109
+ total_keypoints_num = 0
110
+ all_keypoints_by_type = []
111
+ num_keypoints = Pose.num_kpts
112
+ for kpt_idx in range(num_keypoints): # 19th for bg
113
+ total_keypoints_num += extract_keypoints(heatmaps[:, :, kpt_idx], all_keypoints_by_type, total_keypoints_num)
114
+
115
+ pose_entries, all_keypoints = group_keypoints(all_keypoints_by_type, pafs, demo=True)
116
+ for kpt_id in range(all_keypoints.shape[0]):
117
+ all_keypoints[kpt_id, 0] = (all_keypoints[kpt_id, 0] * stride / upsample_ratio - pad[1]) / scale
118
+ all_keypoints[kpt_id, 1] = (all_keypoints[kpt_id, 1] * stride / upsample_ratio - pad[0]) / scale
119
+
120
+ '''
121
+ # print(len(pose_entries))
122
+ if len(pose_entries)>1:
123
+ pose_entries = pose_entries[:1]
124
+ print("We only support one person currently")
125
+ # assert len(pose_entries) == 1, "We only support one person currently"
126
+ '''
127
+
128
+ current_poses, current_bbox = list(), list()
129
+ for n in range(len(pose_entries)):
130
+ if len(pose_entries[n]) == 0:
131
+ continue
132
+ pose_keypoints = np.ones((num_keypoints, 2), dtype=np.int32) * -1
133
+ for kpt_id in range(num_keypoints):
134
+ if pose_entries[n][kpt_id] != -1.0: # keypoint was found
135
+ pose_keypoints[kpt_id, 0] = int(all_keypoints[int(pose_entries[n][kpt_id]), 0])
136
+ pose_keypoints[kpt_id, 1] = int(all_keypoints[int(pose_entries[n][kpt_id]), 1])
137
+ pose = Pose(pose_keypoints, pose_entries[n][18])
138
+ current_poses.append(pose.keypoints)
139
+ current_bbox.append(np.array(pose.bbox))
140
+
141
+ # enlarge the bbox
142
+ for i, bbox in enumerate(current_bbox):
143
+ x, y, w, h = bbox
144
+ margin = 0.2
145
+ x_margin = int(w * margin)
146
+ y_margin = int(h * margin)
147
+ x0 = max(x-x_margin, 0)
148
+ y0 = max(y-y_margin, 0)
149
+ x1 = min(x+w+x_margin, orig_img.shape[1])
150
+ y1 = min(y+h+y_margin, orig_img.shape[0])
151
+ current_bbox[i] = np.array((x0, y0, x1, y1)).astype(np.int32) # ltrb
152
+
153
+ # only keep one person
154
+ body_point_list = []
155
+ if len(current_poses) > 0:
156
+ for item in current_poses[0]:
157
+ if item[0] == item[1] == -1:
158
+ body_point_list += [0.0, 0.0, 0.0]
159
+ else:
160
+ body_point_list += [float(item[0]), float(item[1]), 1.0]
161
+ else:
162
+ for i in range(18):
163
+ body_point_list += [0.0, 0.0, 0.0]
164
+
165
+ pose_dict = dict()
166
+ pose_dict["people"] = []
167
+ pose_dict["people"].append({
168
+ "person_id": [-1],
169
+ "pose_keypoints_2d": body_point_list,
170
+ "hand_left_keypoints_2d": [],
171
+ "hand_right_keypoints_2d": [],
172
+ "face_keypoints_2d": [],
173
+ "pose_keypoints_3d": [],
174
+ "face_keypoints_3d": [],
175
+ "hand_left_keypoints_3d": [],
176
+ "hand_right_keypoints_3d": [],
177
+ })
178
+
179
+ return current_poses, current_bbox
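
A hedged usage sketch for the wrapper above. Note that this file imports pose2d_models.* and modules.* as top-level packages, so the lite_openpose directory itself has to be on sys.path (the hardcoded sys.path.append near the top points at an internal cluster path and will not exist elsewhere):

import sys
import cv2

sys.path.append("lite_openpose")            # so pose2d_models/ and modules/ resolve
from lite_openpose.body_bbox_detector import BodyPoseEstimator

estimator = BodyPoseEstimator(device="cpu") # loads checkpoint_iter_370000.pth
img = cv2.imread("example/exp1.png")        # any BGR image containing a person
poses, bboxes = estimator.detect_body_pose(img)
print(len(poses), "person(s) detected")
if len(bboxes) > 0:
    print("first bbox (l, t, r, b):", bboxes[0])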
lite_openpose/checkpoint_iter_370000.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:517c86f769c6636583083f1467e3d212a0006c27109edb3aeffc19a79622d411
3
+ size 87959810
lite_openpose/modules/__init__.py ADDED
File without changes
lite_openpose/modules/conv.py ADDED
@@ -0,0 +1,32 @@
1
+ from torch import nn
2
+
3
+
4
+ def conv(in_channels, out_channels, kernel_size=3, padding=1, bn=True, dilation=1, stride=1, relu=True, bias=True):
5
+ modules = [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)]
6
+ if bn:
7
+ modules.append(nn.BatchNorm2d(out_channels))
8
+ if relu:
9
+ modules.append(nn.ReLU(inplace=True))
10
+ return nn.Sequential(*modules)
11
+
12
+
13
+ def conv_dw(in_channels, out_channels, kernel_size=3, padding=1, stride=1, dilation=1):
14
+ return nn.Sequential(
15
+ nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, dilation=dilation, groups=in_channels, bias=False),
16
+ nn.BatchNorm2d(in_channels),
17
+ nn.ReLU(inplace=True),
18
+
19
+ nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False),
20
+ nn.BatchNorm2d(out_channels),
21
+ nn.ReLU(inplace=True),
22
+ )
23
+
24
+
25
+ def conv_dw_no_bn(in_channels, out_channels, kernel_size=3, padding=1, stride=1, dilation=1):
26
+ return nn.Sequential(
27
+ nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, dilation=dilation, groups=in_channels, bias=False),
28
+ nn.ELU(inplace=True),
29
+
30
+ nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False),
31
+ nn.ELU(inplace=True),
32
+ )
lite_openpose/modules/get_parameters.py ADDED
@@ -0,0 +1,23 @@
1
+ from torch import nn
2
+
3
+
4
+ def get_parameters(model, predicate):
5
+ for module in model.modules():
6
+ for param_name, param in module.named_parameters():
7
+ if predicate(module, param_name):
8
+ yield param
9
+
10
+
11
+ def get_parameters_conv(model, name):
12
+ return get_parameters(model, lambda m, p: isinstance(m, nn.Conv2d) and m.groups == 1 and p == name)
13
+
14
+
15
+ def get_parameters_conv_depthwise(model, name):
16
+ return get_parameters(model, lambda m, p: isinstance(m, nn.Conv2d)
17
+ and m.groups == m.in_channels
18
+ and m.in_channels == m.out_channels
19
+ and p == name)
20
+
21
+
22
+ def get_parameters_bn(model, name):
23
+ return get_parameters(model, lambda m, p: isinstance(m, nn.BatchNorm2d) and p == name)
lite_openpose/modules/keypoints.py ADDED
@@ -0,0 +1,201 @@
1
+ import math
2
+ import numpy as np
3
+ from operator import itemgetter
4
+
5
+ BODY_PARTS_KPT_IDS = [[1, 2], [1, 5], [2, 3], [3, 4], [5, 6], [6, 7], [1, 8], [8, 9], [9, 10], [1, 11],
6
+ [11, 12], [12, 13], [1, 0], [0, 14], [14, 16], [0, 15], [15, 17], [2, 16], [5, 17]]
7
+ BODY_PARTS_PAF_IDS = ([12, 13], [20, 21], [14, 15], [16, 17], [22, 23], [24, 25], [0, 1], [2, 3], [4, 5],
8
+ [6, 7], [8, 9], [10, 11], [28, 29], [30, 31], [34, 35], [32, 33], [36, 37], [18, 19], [26, 27])
9
+
10
+
11
+ def linspace2d(start, stop, n=10):
12
+ points = 1 / (n - 1) * (stop - start)
13
+ return points[:, None] * np.arange(n) + start[:, None]
14
+
15
+
16
+ def extract_keypoints(heatmap, all_keypoints, total_keypoint_num):
17
+ heatmap[heatmap < 0.1] = 0
18
+ heatmap_with_borders = np.pad(heatmap, [(2, 2), (2, 2)], mode='constant')
19
+ heatmap_center = heatmap_with_borders[1:heatmap_with_borders.shape[0]-1, 1:heatmap_with_borders.shape[1]-1]
20
+ heatmap_left = heatmap_with_borders[1:heatmap_with_borders.shape[0]-1, 2:heatmap_with_borders.shape[1]]
21
+ heatmap_right = heatmap_with_borders[1:heatmap_with_borders.shape[0]-1, 0:heatmap_with_borders.shape[1]-2]
22
+ heatmap_up = heatmap_with_borders[2:heatmap_with_borders.shape[0], 1:heatmap_with_borders.shape[1]-1]
23
+ heatmap_down = heatmap_with_borders[0:heatmap_with_borders.shape[0]-2, 1:heatmap_with_borders.shape[1]-1]
24
+
25
+ heatmap_peaks = (heatmap_center > heatmap_left) &\
26
+ (heatmap_center > heatmap_right) &\
27
+ (heatmap_center > heatmap_up) &\
28
+ (heatmap_center > heatmap_down)
29
+ heatmap_peaks = heatmap_peaks[1:heatmap_center.shape[0]-1, 1:heatmap_center.shape[1]-1]
30
+ keypoints = list(zip(np.nonzero(heatmap_peaks)[1], np.nonzero(heatmap_peaks)[0])) # (w, h)
31
+ keypoints = sorted(keypoints, key=itemgetter(0))
32
+
33
+ suppressed = np.zeros(len(keypoints), np.uint8)
34
+ keypoints_with_score_and_id = []
35
+ keypoint_num = 0
36
+ for i in range(len(keypoints)):
37
+ if suppressed[i]:
38
+ continue
39
+ for j in range(i+1, len(keypoints)):
40
+ if math.sqrt((keypoints[i][0] - keypoints[j][0]) ** 2 +
41
+ (keypoints[i][1] - keypoints[j][1]) ** 2) < 6:
42
+ suppressed[j] = 1
43
+ keypoint_with_score_and_id = (keypoints[i][0], keypoints[i][1], heatmap[keypoints[i][1], keypoints[i][0]],
44
+ total_keypoint_num + keypoint_num)
45
+ keypoints_with_score_and_id.append(keypoint_with_score_and_id)
46
+ keypoint_num += 1
47
+ all_keypoints.append(keypoints_with_score_and_id)
48
+ return keypoint_num
49
+
50
+
51
+ def group_keypoints(all_keypoints_by_type, pafs, pose_entry_size=20, min_paf_score=0.05, demo=False):
52
+ pose_entries = []
53
+ all_keypoints = np.array([item for sublist in all_keypoints_by_type for item in sublist])
54
+ for part_id in range(len(BODY_PARTS_PAF_IDS)):
55
+ part_pafs = pafs[:, :, BODY_PARTS_PAF_IDS[part_id]]
56
+ kpts_a = all_keypoints_by_type[BODY_PARTS_KPT_IDS[part_id][0]]
57
+ kpts_b = all_keypoints_by_type[BODY_PARTS_KPT_IDS[part_id][1]]
58
+ num_kpts_a = len(kpts_a)
59
+ num_kpts_b = len(kpts_b)
60
+ kpt_a_id = BODY_PARTS_KPT_IDS[part_id][0]
61
+ kpt_b_id = BODY_PARTS_KPT_IDS[part_id][1]
62
+
63
+ if num_kpts_a == 0 and num_kpts_b == 0: # no keypoints for such body part
64
+ continue
65
+ elif num_kpts_a == 0: # body part has just 'b' keypoints
66
+ for i in range(num_kpts_b):
67
+ num = 0
68
+ for j in range(len(pose_entries)): # check if already in some pose, was added by another body part
69
+ if pose_entries[j][kpt_b_id] == kpts_b[i][3]:
70
+ num += 1
71
+ continue
72
+ if num == 0:
73
+ pose_entry = np.ones(pose_entry_size) * -1
74
+ pose_entry[kpt_b_id] = kpts_b[i][3] # keypoint idx
75
+ pose_entry[-1] = 1 # num keypoints in pose
76
+ pose_entry[-2] = kpts_b[i][2] # pose score
77
+ pose_entries.append(pose_entry)
78
+ continue
79
+ elif num_kpts_b == 0: # body part has just 'a' keypoints
80
+ for i in range(num_kpts_a):
81
+ num = 0
82
+ for j in range(len(pose_entries)):
83
+ if pose_entries[j][kpt_a_id] == kpts_a[i][3]:
84
+ num += 1
85
+ continue
86
+ if num == 0:
87
+ pose_entry = np.ones(pose_entry_size) * -1
88
+ pose_entry[kpt_a_id] = kpts_a[i][3]
89
+ pose_entry[-1] = 1
90
+ pose_entry[-2] = kpts_a[i][2]
91
+ pose_entries.append(pose_entry)
92
+ continue
93
+
94
+ connections = []
95
+ for i in range(num_kpts_a):
96
+ kpt_a = np.array(kpts_a[i][0:2])
97
+ for j in range(num_kpts_b):
98
+ kpt_b = np.array(kpts_b[j][0:2])
99
+ mid_point = [(), ()]
100
+ mid_point[0] = (int(round((kpt_a[0] + kpt_b[0]) * 0.5)),
101
+ int(round((kpt_a[1] + kpt_b[1]) * 0.5)))
102
+ mid_point[1] = mid_point[0]
103
+
104
+ vec = [kpt_b[0] - kpt_a[0], kpt_b[1] - kpt_a[1]]
105
+ vec_norm = math.sqrt(vec[0] ** 2 + vec[1] ** 2)
106
+ if vec_norm == 0:
107
+ continue
108
+ vec[0] /= vec_norm
109
+ vec[1] /= vec_norm
110
+ cur_point_score = (vec[0] * part_pafs[mid_point[0][1], mid_point[0][0], 0] +
111
+ vec[1] * part_pafs[mid_point[1][1], mid_point[1][0], 1])
112
+
113
+ height_n = pafs.shape[0] // 2
114
+ success_ratio = 0
115
+ point_num = 10 # number of points to integration over paf
116
+ if cur_point_score > -100:
117
+ passed_point_score = 0
118
+ passed_point_num = 0
119
+ x, y = linspace2d(kpt_a, kpt_b)
120
+ for point_idx in range(point_num):
121
+ if not demo:
122
+ px = int(round(x[point_idx]))
123
+ py = int(round(y[point_idx]))
124
+ else:
125
+ px = int(x[point_idx])
126
+ py = int(y[point_idx])
127
+ paf = part_pafs[py, px, 0:2]
128
+ cur_point_score = vec[0] * paf[0] + vec[1] * paf[1]
129
+ if cur_point_score > min_paf_score:
130
+ passed_point_score += cur_point_score
131
+ passed_point_num += 1
132
+ success_ratio = passed_point_num / point_num
133
+ ratio = 0
134
+ if passed_point_num > 0:
135
+ ratio = passed_point_score / passed_point_num
136
+ ratio += min(height_n / vec_norm - 1, 0)
137
+ if ratio > 0 and success_ratio > 0.8:
138
+ score_all = ratio + kpts_a[i][2] + kpts_b[j][2]
139
+ connections.append([i, j, ratio, score_all])
140
+ if len(connections) > 0:
141
+ connections = sorted(connections, key=itemgetter(2), reverse=True)
142
+
143
+ num_connections = min(num_kpts_a, num_kpts_b)
144
+ has_kpt_a = np.zeros(num_kpts_a, dtype=np.int32)
145
+ has_kpt_b = np.zeros(num_kpts_b, dtype=np.int32)
146
+ filtered_connections = []
147
+ for row in range(len(connections)):
148
+ if len(filtered_connections) == num_connections:
149
+ break
150
+ i, j, cur_point_score = connections[row][0:3]
151
+ if not has_kpt_a[i] and not has_kpt_b[j]:
152
+ filtered_connections.append([kpts_a[i][3], kpts_b[j][3], cur_point_score])
153
+ has_kpt_a[i] = 1
154
+ has_kpt_b[j] = 1
155
+ connections = filtered_connections
156
+ if len(connections) == 0:
157
+ continue
158
+
159
+ if part_id == 0:
160
+ pose_entries = [np.ones(pose_entry_size) * -1 for _ in range(len(connections))]
161
+ for i in range(len(connections)):
162
+ pose_entries[i][BODY_PARTS_KPT_IDS[0][0]] = connections[i][0]
163
+ pose_entries[i][BODY_PARTS_KPT_IDS[0][1]] = connections[i][1]
164
+ pose_entries[i][-1] = 2
165
+ pose_entries[i][-2] = np.sum(all_keypoints[connections[i][0:2], 2]) + connections[i][2]
166
+ elif part_id == 17 or part_id == 18:
167
+ kpt_a_id = BODY_PARTS_KPT_IDS[part_id][0]
168
+ kpt_b_id = BODY_PARTS_KPT_IDS[part_id][1]
169
+ for i in range(len(connections)):
170
+ for j in range(len(pose_entries)):
171
+ if pose_entries[j][kpt_a_id] == connections[i][0] and pose_entries[j][kpt_b_id] == -1:
172
+ pose_entries[j][kpt_b_id] = connections[i][1]
173
+ elif pose_entries[j][kpt_b_id] == connections[i][1] and pose_entries[j][kpt_a_id] == -1:
174
+ pose_entries[j][kpt_a_id] = connections[i][0]
175
+ continue
176
+ else:
177
+ kpt_a_id = BODY_PARTS_KPT_IDS[part_id][0]
178
+ kpt_b_id = BODY_PARTS_KPT_IDS[part_id][1]
179
+ for i in range(len(connections)):
180
+ num = 0
181
+ for j in range(len(pose_entries)):
182
+ if pose_entries[j][kpt_a_id] == connections[i][0]:
183
+ pose_entries[j][kpt_b_id] = connections[i][1]
184
+ num += 1
185
+ pose_entries[j][-1] += 1
186
+ pose_entries[j][-2] += all_keypoints[connections[i][1], 2] + connections[i][2]
187
+ if num == 0:
188
+ pose_entry = np.ones(pose_entry_size) * -1
189
+ pose_entry[kpt_a_id] = connections[i][0]
190
+ pose_entry[kpt_b_id] = connections[i][1]
191
+ pose_entry[-1] = 2
192
+ pose_entry[-2] = np.sum(all_keypoints[connections[i][0:2], 2]) + connections[i][2]
193
+ pose_entries.append(pose_entry)
194
+
195
+ filtered_entries = []
196
+ for i in range(len(pose_entries)):
197
+ if pose_entries[i][-1] < 3 or (pose_entries[i][-2] / pose_entries[i][-1] < 0.2):
198
+ continue
199
+ filtered_entries.append(pose_entries[i])
200
+ pose_entries = np.asarray(filtered_entries)
201
+ return pose_entries, all_keypoints
lite_openpose/modules/load_state.py ADDED
@@ -0,0 +1,32 @@
1
+ import collections
2
+
3
+
4
+ def load_state(net, checkpoint):
5
+ source_state = checkpoint['state_dict']
6
+ target_state = net.state_dict()
7
+ new_target_state = collections.OrderedDict()
8
+ for target_key, target_value in target_state.items():
9
+ if target_key in source_state and source_state[target_key].size() == target_state[target_key].size():
10
+ new_target_state[target_key] = source_state[target_key]
11
+ else:
12
+ new_target_state[target_key] = target_state[target_key]
13
+ print('[WARNING] Not found pre-trained parameters for {}'.format(target_key))
14
+
15
+ net.load_state_dict(new_target_state)
16
+
17
+
18
+ def load_from_mobilenet(net, checkpoint):
19
+ source_state = checkpoint['state_dict']
20
+ target_state = net.state_dict()
21
+ new_target_state = collections.OrderedDict()
22
+ for target_key, target_value in target_state.items():
23
+ k = target_key
24
+ if k.find('model') != -1:
25
+ k = k.replace('model', 'module.model')
26
+ if k in source_state and source_state[k].size() == target_state[target_key].size():
27
+ new_target_state[target_key] = source_state[k]
28
+ else:
29
+ new_target_state[target_key] = target_state[target_key]
30
+ print('[WARNING] Not found pre-trained parameters for {}'.format(target_key))
31
+
32
+ net.load_state_dict(new_target_state)
lite_openpose/modules/loss.py ADDED
@@ -0,0 +1,5 @@
1
+ def l2_loss(input, target, mask, batch_size):
2
+ loss = (input - target) * mask
3
+ loss = (loss * loss) / 2 / batch_size
4
+
5
+ return loss.sum()
lite_openpose/modules/one_euro_filter.py ADDED
@@ -0,0 +1,51 @@
1
+ import math
2
+
3
+
4
+ def get_alpha(rate=30, cutoff=1):
5
+ tau = 1 / (2 * math.pi * cutoff)
6
+ te = 1 / rate
7
+ return 1 / (1 + tau / te)
8
+
9
+
10
+ class LowPassFilter:
11
+ def __init__(self):
12
+ self.x_previous = None
13
+
14
+ def __call__(self, x, alpha=0.5):
15
+ if self.x_previous is None:
16
+ self.x_previous = x
17
+ return x
18
+ x_filtered = alpha * x + (1 - alpha) * self.x_previous
19
+ self.x_previous = x_filtered
20
+ return x_filtered
21
+
22
+
23
+ class OneEuroFilter:
24
+ def __init__(self, freq=15, mincutoff=1, beta=0.05, dcutoff=1):
25
+ self.freq = freq
26
+ self.mincutoff = mincutoff
27
+ self.beta = beta
28
+ self.dcutoff = dcutoff
29
+ self.filter_x = LowPassFilter()
30
+ self.filter_dx = LowPassFilter()
31
+ self.x_previous = None
32
+ self.dx = None
33
+
34
+ def __call__(self, x):
35
+ if self.dx is None:
36
+ self.dx = 0
37
+ else:
38
+ self.dx = (x - self.x_previous) * self.freq
39
+ dx_smoothed = self.filter_dx(self.dx, get_alpha(self.freq, self.dcutoff))
40
+ cutoff = self.mincutoff + self.beta * abs(dx_smoothed)
41
+ x_filtered = self.filter_x(x, get_alpha(self.freq, cutoff))
42
+ self.x_previous = x
43
+ return x_filtered
44
+
45
+
46
+ if __name__ == '__main__':
47
+ filter = OneEuroFilter(freq=15, beta=0.1)
48
+ for val in range(10):
49
+ x = val + (-1)**(val % 2)
50
+ x_filtered = filter(x)
51
+ print(x_filtered, x)
lite_openpose/modules/pose.py ADDED
@@ -0,0 +1,118 @@
1
+ import cv2
2
+ import numpy as np
3
+
4
+ from modules.keypoints import BODY_PARTS_KPT_IDS, BODY_PARTS_PAF_IDS
5
+ from modules.one_euro_filter import OneEuroFilter
6
+
7
+
8
+ class Pose:
9
+ num_kpts = 18
10
+ kpt_names = ['nose', 'neck',
11
+ 'r_sho', 'r_elb', 'r_wri', 'l_sho', 'l_elb', 'l_wri',
12
+ 'r_hip', 'r_knee', 'r_ank', 'l_hip', 'l_knee', 'l_ank',
13
+ 'r_eye', 'l_eye',
14
+ 'r_ear', 'l_ear']
15
+ sigmas = np.array([.26, .79, .79, .72, .62, .79, .72, .62, 1.07, .87, .89, 1.07, .87, .89, .25, .25, .35, .35],
16
+ dtype=np.float32) / 10.0
17
+ vars = (sigmas * 2) ** 2
18
+ last_id = -1
19
+ color = [0, 224, 255]
20
+
21
+ def __init__(self, keypoints, confidence):
22
+ super().__init__()
23
+ self.keypoints = keypoints
24
+ self.confidence = confidence
25
+ self.bbox = Pose.get_bbox(self.keypoints)
26
+ self.id = None
27
+ self.filters = [[OneEuroFilter(), OneEuroFilter()] for _ in range(Pose.num_kpts)]
28
+
29
+ @staticmethod
30
+ def get_bbox(keypoints):
31
+ found_keypoints = np.zeros((np.count_nonzero(keypoints[:, 0] != -1), 2), dtype=np.int32)
32
+ found_kpt_id = 0
33
+ for kpt_id in range(Pose.num_kpts):
34
+ if keypoints[kpt_id, 0] == -1:
35
+ continue
36
+ found_keypoints[found_kpt_id] = keypoints[kpt_id]
37
+ found_kpt_id += 1
38
+ bbox = cv2.boundingRect(found_keypoints)
39
+ return bbox
40
+
41
+ def update_id(self, id=None):
42
+ self.id = id
43
+ if self.id is None:
44
+ self.id = Pose.last_id + 1
45
+ Pose.last_id += 1
46
+
47
+ def draw(self, img):
48
+ assert self.keypoints.shape == (Pose.num_kpts, 2)
49
+
50
+ for part_id in range(len(BODY_PARTS_PAF_IDS) - 2):
51
+ kpt_a_id = BODY_PARTS_KPT_IDS[part_id][0]
52
+ global_kpt_a_id = self.keypoints[kpt_a_id, 0]
53
+ if global_kpt_a_id != -1:
54
+ x_a, y_a = self.keypoints[kpt_a_id]
55
+ cv2.circle(img, (int(x_a), int(y_a)), 3, Pose.color, -1)
56
+ kpt_b_id = BODY_PARTS_KPT_IDS[part_id][1]
57
+ global_kpt_b_id = self.keypoints[kpt_b_id, 0]
58
+ if global_kpt_b_id != -1:
59
+ x_b, y_b = self.keypoints[kpt_b_id]
60
+ cv2.circle(img, (int(x_b), int(y_b)), 3, Pose.color, -1)
61
+ if global_kpt_a_id != -1 and global_kpt_b_id != -1:
62
+ cv2.line(img, (int(x_a), int(y_a)), (int(x_b), int(y_b)), Pose.color, 2)
63
+
64
+
65
+ def get_similarity(a, b, threshold=0.5):
66
+ num_similar_kpt = 0
67
+ for kpt_id in range(Pose.num_kpts):
68
+ if a.keypoints[kpt_id, 0] != -1 and b.keypoints[kpt_id, 0] != -1:
69
+ distance = np.sum((a.keypoints[kpt_id] - b.keypoints[kpt_id]) ** 2)
70
+ area = max(a.bbox[2] * a.bbox[3], b.bbox[2] * b.bbox[3])
71
+ similarity = np.exp(-distance / (2 * (area + np.spacing(1)) * Pose.vars[kpt_id]))
72
+ if similarity > threshold:
73
+ num_similar_kpt += 1
74
+ return num_similar_kpt
75
+
76
+
77
+ def track_poses(previous_poses, current_poses, threshold=3, smooth=False):
78
+ """Propagate poses ids from previous frame results. Id is propagated,
79
+ if there are at least `threshold` similar keypoints between pose from previous frame and current.
80
+ If correspondence between pose on previous and current frame was established, pose keypoints are smoothed.
81
+
82
+ :param previous_poses: poses from previous frame with ids
83
+ :param current_poses: poses from current frame to assign ids
84
+ :param threshold: minimal number of similar keypoints between poses
85
+ :param smooth: smooth pose keypoints between frames
86
+ :return: None
87
+ """
88
+ current_poses = sorted(current_poses, key=lambda pose: pose.confidence, reverse=True) # match confident poses first
89
+ mask = np.ones(len(previous_poses), dtype=np.int32)
90
+ for current_pose in current_poses:
91
+ best_matched_id = None
92
+ best_matched_pose_id = None
93
+ best_matched_iou = 0
94
+ for id, previous_pose in enumerate(previous_poses):
95
+ if not mask[id]:
96
+ continue
97
+ iou = get_similarity(current_pose, previous_pose)
98
+ if iou > best_matched_iou:
99
+ best_matched_iou = iou
100
+ best_matched_pose_id = previous_pose.id
101
+ best_matched_id = id
102
+ if best_matched_iou >= threshold:
103
+ mask[best_matched_id] = 0
104
+ else: # pose not similar to any previous
105
+ best_matched_pose_id = None
106
+ current_pose.update_id(best_matched_pose_id)
107
+
108
+ if smooth:
109
+ for kpt_id in range(Pose.num_kpts):
110
+ if current_pose.keypoints[kpt_id, 0] == -1:
111
+ continue
112
+ # reuse filter if previous pose has valid filter
113
+ if (best_matched_pose_id is not None
114
+ and previous_poses[best_matched_id].keypoints[kpt_id, 0] != -1):
115
+ current_pose.filters[kpt_id] = previous_poses[best_matched_id].filters[kpt_id]
116
+ current_pose.keypoints[kpt_id, 0] = current_pose.filters[kpt_id][0](current_pose.keypoints[kpt_id, 0])
117
+ current_pose.keypoints[kpt_id, 1] = current_pose.filters[kpt_id][1](current_pose.keypoints[kpt_id, 1])
118
+ current_pose.bbox = Pose.get_bbox(current_pose.keypoints)
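
A small sketch of id propagation with track_poses(), assuming it is run from inside lite_openpose/ so that the modules package resolves; the keypoint coordinates are made up for illustration:

import numpy as np
from modules.pose import Pose, track_poses

kpts = np.full((Pose.num_kpts, 2), -1, dtype=np.int32)
kpts[1] = (100, 80)    # neck
kpts[2] = (80, 80)     # right shoulder
kpts[5] = (120, 80)    # left shoulder
kpts[8] = (90, 140)    # right hip

prev_pose = Pose(kpts.copy(), confidence=10.0)
prev_pose.update_id()                        # assigns id 0 on first use

kpts2 = kpts.copy()
kpts2[kpts2[:, 0] != -1] += 2                # same person, shifted by two pixels
curr_pose = Pose(kpts2, confidence=9.0)

track_poses([prev_pose], [curr_pose], threshold=3, smooth=False)
print(curr_pose.id)                          # inherits the id of the matching previous pose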
lite_openpose/pose2d_models/__init__.py ADDED
File without changes
lite_openpose/pose2d_models/with_mobilenet.py ADDED
@@ -0,0 +1,123 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+ from modules.conv import conv, conv_dw, conv_dw_no_bn
5
+
6
+
7
+ class Cpm(nn.Module):
8
+ def __init__(self, in_channels, out_channels):
9
+ super().__init__()
10
+ self.align = conv(in_channels, out_channels, kernel_size=1, padding=0, bn=False)
11
+ self.trunk = nn.Sequential(
12
+ conv_dw_no_bn(out_channels, out_channels),
13
+ conv_dw_no_bn(out_channels, out_channels),
14
+ conv_dw_no_bn(out_channels, out_channels)
15
+ )
16
+ self.conv = conv(out_channels, out_channels, bn=False)
17
+
18
+ def forward(self, x):
19
+ x = self.align(x)
20
+ x = self.conv(x + self.trunk(x))
21
+ return x
22
+
23
+
24
+ class InitialStage(nn.Module):
25
+ def __init__(self, num_channels, num_heatmaps, num_pafs):
26
+ super().__init__()
27
+ self.trunk = nn.Sequential(
28
+ conv(num_channels, num_channels, bn=False),
29
+ conv(num_channels, num_channels, bn=False),
30
+ conv(num_channels, num_channels, bn=False)
31
+ )
32
+ self.heatmaps = nn.Sequential(
33
+ conv(num_channels, 512, kernel_size=1, padding=0, bn=False),
34
+ conv(512, num_heatmaps, kernel_size=1, padding=0, bn=False, relu=False)
35
+ )
36
+ self.pafs = nn.Sequential(
37
+ conv(num_channels, 512, kernel_size=1, padding=0, bn=False),
38
+ conv(512, num_pafs, kernel_size=1, padding=0, bn=False, relu=False)
39
+ )
40
+
41
+ def forward(self, x):
42
+ trunk_features = self.trunk(x)
43
+ heatmaps = self.heatmaps(trunk_features)
44
+ pafs = self.pafs(trunk_features)
45
+ return [heatmaps, pafs]
46
+
47
+
48
+ class RefinementStageBlock(nn.Module):
49
+ def __init__(self, in_channels, out_channels):
50
+ super().__init__()
51
+ self.initial = conv(in_channels, out_channels, kernel_size=1, padding=0, bn=False)
52
+ self.trunk = nn.Sequential(
53
+ conv(out_channels, out_channels),
54
+ conv(out_channels, out_channels, dilation=2, padding=2)
55
+ )
56
+
57
+ def forward(self, x):
58
+ initial_features = self.initial(x)
59
+ trunk_features = self.trunk(initial_features)
60
+ return initial_features + trunk_features
61
+
62
+
63
+ class RefinementStage(nn.Module):
64
+ def __init__(self, in_channels, out_channels, num_heatmaps, num_pafs):
65
+ super().__init__()
66
+ self.trunk = nn.Sequential(
67
+ RefinementStageBlock(in_channels, out_channels),
68
+ RefinementStageBlock(out_channels, out_channels),
69
+ RefinementStageBlock(out_channels, out_channels),
70
+ RefinementStageBlock(out_channels, out_channels),
71
+ RefinementStageBlock(out_channels, out_channels)
72
+ )
73
+ self.heatmaps = nn.Sequential(
74
+ conv(out_channels, out_channels, kernel_size=1, padding=0, bn=False),
75
+ conv(out_channels, num_heatmaps, kernel_size=1, padding=0, bn=False, relu=False)
76
+ )
77
+ self.pafs = nn.Sequential(
78
+ conv(out_channels, out_channels, kernel_size=1, padding=0, bn=False),
79
+ conv(out_channels, num_pafs, kernel_size=1, padding=0, bn=False, relu=False)
80
+ )
81
+
82
+ def forward(self, x):
83
+ trunk_features = self.trunk(x)
84
+ heatmaps = self.heatmaps(trunk_features)
85
+ pafs = self.pafs(trunk_features)
86
+ return [heatmaps, pafs]
87
+
88
+
89
+ class PoseEstimationWithMobileNet(nn.Module):
90
+ def __init__(self, num_refinement_stages=1, num_channels=128, num_heatmaps=19, num_pafs=38):
91
+ super().__init__()
92
+ self.model = nn.Sequential(
93
+ conv( 3, 32, stride=2, bias=False),
94
+ conv_dw( 32, 64),
95
+ conv_dw( 64, 128, stride=2),
96
+ conv_dw(128, 128),
97
+ conv_dw(128, 256, stride=2),
98
+ conv_dw(256, 256),
99
+ conv_dw(256, 512), # conv4_2
100
+ conv_dw(512, 512, dilation=2, padding=2),
101
+ conv_dw(512, 512),
102
+ conv_dw(512, 512),
103
+ conv_dw(512, 512),
104
+ conv_dw(512, 512) # conv5_5
105
+ )
106
+ self.cpm = Cpm(512, num_channels)
107
+
108
+ self.initial_stage = InitialStage(num_channels, num_heatmaps, num_pafs)
109
+ self.refinement_stages = nn.ModuleList()
110
+ for idx in range(num_refinement_stages):
111
+ self.refinement_stages.append(RefinementStage(num_channels + num_heatmaps + num_pafs, num_channels,
112
+ num_heatmaps, num_pafs))
113
+
114
+ def forward(self, x):
115
+ backbone_features = self.model(x)
116
+ backbone_features = self.cpm(backbone_features)
117
+
118
+ stages_output = self.initial_stage(backbone_features)
119
+ for refinement_stage in self.refinement_stages:
120
+ stages_output.extend(
121
+ refinement_stage(torch.cat([backbone_features, stages_output[-2], stages_output[-1]], dim=1)))
122
+
123
+ return stages_output
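
A quick shape check of the network above, assuming it is run from inside lite_openpose/ so that modules.conv resolves; the input resolution is arbitrary, the backbone simply downsamples by a factor of 8:

import torch
from pose2d_models.with_mobilenet import PoseEstimationWithMobileNet

net = PoseEstimationWithMobileNet().eval()
x = torch.randn(1, 3, 256, 456)          # NCHW dummy input
with torch.no_grad():
    stages = net(x)

# initial stage + one refinement stage -> [heatmaps, pafs, heatmaps, pafs]
print(len(stages))                       # 4
print(stages[-2].shape)                  # torch.Size([1, 19, 32, 57])  heatmaps at 1/8 resolution
print(stages[-1].shape)                  # torch.Size([1, 38, 32, 57])  PAFs at 1/8 resolution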