ysmao committed
Commit ef0eb1c
1 Parent(s): 6062162

update dependencies

.gitignore CHANGED
@@ -159,4 +159,5 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
- #.idea/
+ #.idea/
+ gradio/
annotator/dsine/__init__.py ADDED
File without changes
annotator/dsine/dsine.py ADDED
@@ -0,0 +1,303 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .submodules import (
    Encoder,
    ConvGRU,
    UpSampleBN,
    UpSampleGN,
    RayReLU,
    convex_upsampling,
    get_unfold,
    get_prediction_head,
    INPUT_CHANNELS_DICT,
)
from .rotation import axis_angle_to_matrix


class Decoder(nn.Module):
    def __init__(self, output_dims, B=5, NF=2048, BN=False, downsample_ratio=8):
        super(Decoder, self).__init__()
        input_channels = INPUT_CHANNELS_DICT[B]
        output_dim, feature_dim, hidden_dim = output_dims
        features = bottleneck_features = NF
        self.downsample_ratio = downsample_ratio

        UpSample = UpSampleBN if BN else UpSampleGN
        self.conv2 = nn.Conv2d(
            bottleneck_features + 2, features, kernel_size=1, stride=1, padding=0
        )
        self.up1 = UpSample(
            skip_input=features // 1 + input_channels[1] + 2,
            output_features=features // 2,
            align_corners=False,
        )
        self.up2 = UpSample(
            skip_input=features // 2 + input_channels[2] + 2,
            output_features=features // 4,
            align_corners=False,
        )

        # prediction heads
        i_dim = features // 4
        h_dim = 128
        self.normal_head = get_prediction_head(i_dim + 2, h_dim, output_dim)
        self.feature_head = get_prediction_head(i_dim + 2, h_dim, feature_dim)
        self.hidden_head = get_prediction_head(i_dim + 2, h_dim, hidden_dim)

    def forward(self, features, uvs):
        _, _, x_block2, x_block3, x_block4 = (
            features[4],
            features[5],
            features[6],
            features[8],
            features[11],
        )
        uv_32, uv_16, uv_8 = uvs

        x_d0 = self.conv2(torch.cat([x_block4, uv_32], dim=1))
        x_d1 = self.up1(x_d0, torch.cat([x_block3, uv_16], dim=1))
        x_feat = self.up2(x_d1, torch.cat([x_block2, uv_8], dim=1))
        x_feat = torch.cat([x_feat, uv_8], dim=1)

        normal = self.normal_head(x_feat)
        normal = F.normalize(normal, dim=1)
        f = self.feature_head(x_feat)
        h = self.hidden_head(x_feat)
        return normal, f, h


class DSINE(nn.Module):
    def __init__(self):
        super(DSINE, self).__init__()
        self.downsample_ratio = 8
        self.ps = 5  # patch size
        self.num_iter = 5  # num iterations

        # define encoder
        self.encoder = Encoder(
            B=5,
            pretrained=True,
        )

        # define decoder
        self.output_dim = output_dim = 3
        self.feature_dim = feature_dim = 64
        self.hidden_dim = hidden_dim = 64
        self.decoder = Decoder(
            [output_dim, feature_dim, hidden_dim], B=5, NF=2048, BN=False
        )

        # ray direction-based ReLU
        self.ray_relu = RayReLU(eps=1e-2)

        # pixel_coords (1, 3, H, W)
        # NOTE: this is set to some arbitrarily high number;
        # if your input is 2000+ pixels wide/tall, increase these values
        h = 2000
        w = 2000
        pixel_coords = np.ones((3, h, w)).astype(np.float32)
        x_range = np.concatenate([np.arange(w).reshape(1, w)] * h, axis=0)
        y_range = np.concatenate([np.arange(h).reshape(h, 1)] * w, axis=1)
        pixel_coords[0, :, :] = x_range + 0.5
        pixel_coords[1, :, :] = y_range + 0.5
        self.pixel_coords = torch.from_numpy(pixel_coords).unsqueeze(0)

        # define ConvGRU cell
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=feature_dim + 2, ks=self.ps)

        # padding used during NRN
        self.pad = (self.ps - 1) // 2

        # prediction heads
        self.prob_head = get_prediction_head(
            self.hidden_dim + 2, 64, self.ps * self.ps
        )  # weights assigned for each nghbr pixel
        self.xy_head = get_prediction_head(
            self.hidden_dim + 2, 64, self.ps * self.ps * 2
        )  # rotation axis for each nghbr pixel
        self.angle_head = get_prediction_head(
            self.hidden_dim + 2, 64, self.ps * self.ps
        )  # rotation angle for each nghbr pixel

        # prediction heads - weights used for upsampling the coarse resolution output
        self.up_prob_head = get_prediction_head(
            self.hidden_dim + 2, 64, 9 * self.downsample_ratio * self.downsample_ratio
        )

    def get_ray(self, intrins, H, W, orig_H, orig_W, return_uv=False):
        B, _, _ = intrins.shape
        fu = intrins[:, 0, 0][:, None, None] * (W / orig_W)
        cu = intrins[:, 0, 2][:, None, None] * (W / orig_W)
        fv = intrins[:, 1, 1][:, None, None] * (H / orig_H)
        cv = intrins[:, 1, 2][:, None, None] * (H / orig_H)

        # (B, 3, H, W)
        ray = self.pixel_coords[:, :, :H, :W].repeat(B, 1, 1, 1)
        ray[:, 0, :, :] = (ray[:, 0, :, :] - cu) / fu
        ray[:, 1, :, :] = (ray[:, 1, :, :] - cv) / fv

        if return_uv:
            return ray[:, :2, :, :]
        else:
            return F.normalize(ray, dim=1)

    def upsample(self, h, pred_norm, uv_8):
        up_mask = self.up_prob_head(torch.cat([h, uv_8], dim=1))
        up_pred_norm = convex_upsampling(pred_norm, up_mask, self.downsample_ratio)
        up_pred_norm = F.normalize(up_pred_norm, dim=1)
        return up_pred_norm

    def refine(self, h, feat_map, pred_norm, intrins, orig_H, orig_W, uv_8, ray_8):
        B, C, H, W = pred_norm.shape
        fu = intrins[:, 0, 0][:, None, None, None] * (W / orig_W)  # (B, 1, 1, 1)
        cu = intrins[:, 0, 2][:, None, None, None] * (W / orig_W)
        fv = intrins[:, 1, 1][:, None, None, None] * (H / orig_H)
        cv = intrins[:, 1, 2][:, None, None, None] * (H / orig_H)

        h_new = self.gru(h, feat_map)

        # get nghbr prob (B, 1, ps*ps, h, w)
        nghbr_prob = self.prob_head(torch.cat([h_new, uv_8], dim=1)).unsqueeze(1)
        nghbr_prob = torch.sigmoid(nghbr_prob)

        # get nghbr normals (B, 3, ps*ps, h, w)
        nghbr_normals = get_unfold(pred_norm, ps=self.ps, pad=self.pad)

        # get nghbr xy (B, 2, ps*ps, h, w)
        nghbr_xys = self.xy_head(torch.cat([h_new, uv_8], dim=1))
        nghbr_xs, nghbr_ys = torch.split(
            nghbr_xys, [self.ps * self.ps, self.ps * self.ps], dim=1
        )
        nghbr_xys = torch.cat([nghbr_xs.unsqueeze(1), nghbr_ys.unsqueeze(1)], dim=1)
        nghbr_xys = F.normalize(nghbr_xys, dim=1)

        # get nghbr theta (B, 1, ps*ps, h, w)
        nghbr_angle = self.angle_head(torch.cat([h_new, uv_8], dim=1)).unsqueeze(1)
        nghbr_angle = torch.sigmoid(nghbr_angle) * np.pi

        # get nghbr pixel coord (1, 3, ps*ps, h, w)
        nghbr_pixel_coord = get_unfold(
            self.pixel_coords[:, :, :H, :W], ps=self.ps, pad=self.pad
        )

        # nghbr axes (B, 3, ps*ps, h, w)
        nghbr_axes = torch.zeros_like(nghbr_normals)

        du_over_fu = nghbr_xys[:, 0, ...] / fu  # (B, ps*ps, h, w)
        dv_over_fv = nghbr_xys[:, 1, ...] / fv  # (B, ps*ps, h, w)

        term_u = (
            nghbr_pixel_coord[:, 0, ...] + nghbr_xys[:, 0, ...] - cu
        ) / fu  # (B, ps*ps, h, w)
        term_v = (
            nghbr_pixel_coord[:, 1, ...] + nghbr_xys[:, 1, ...] - cv
        ) / fv  # (B, ps*ps, h, w)

        nx = nghbr_normals[:, 0, ...]  # (B, ps*ps, h, w)
        ny = nghbr_normals[:, 1, ...]  # (B, ps*ps, h, w)
        nz = nghbr_normals[:, 2, ...]  # (B, ps*ps, h, w)

        nghbr_delta_z_num = -(du_over_fu * nx + dv_over_fv * ny)
        nghbr_delta_z_denom = term_u * nx + term_v * ny + nz
        nghbr_delta_z_denom[torch.abs(nghbr_delta_z_denom) < 1e-8] = 1e-8 * torch.sign(
            nghbr_delta_z_denom[torch.abs(nghbr_delta_z_denom) < 1e-8]
        )
        nghbr_delta_z = nghbr_delta_z_num / nghbr_delta_z_denom

        nghbr_axes[:, 0, ...] = du_over_fu + nghbr_delta_z * term_u
        nghbr_axes[:, 1, ...] = dv_over_fv + nghbr_delta_z * term_v
        nghbr_axes[:, 2, ...] = nghbr_delta_z
        nghbr_axes = F.normalize(nghbr_axes, dim=1)  # (B, 3, ps*ps, h, w)

        # make sure axes are all valid
        invalid = (
            torch.sum(
                torch.logical_or(
                    torch.isnan(nghbr_axes), torch.isinf(nghbr_axes)
                ).float(),
                dim=1,
            )
            > 0.5
        )  # (B, ps*ps, h, w)
        nghbr_axes[:, 0, ...][invalid] = 0.0
        nghbr_axes[:, 1, ...][invalid] = 0.0
        nghbr_axes[:, 2, ...][invalid] = 0.0

        # nghbr_axes_angle (B, 3, ps*ps, h, w)
        nghbr_axes_angle = nghbr_axes * nghbr_angle
        nghbr_axes_angle = nghbr_axes_angle.permute(
            0, 2, 3, 4, 1
        )  # (B, ps*ps, h, w, 3)
        nghbr_R = axis_angle_to_matrix(nghbr_axes_angle)  # (B, ps*ps, h, w, 3, 3)

        # (B, 3, ps*ps, h, w)
        nghbr_normals_rot = (
            torch.bmm(
                nghbr_R.reshape(B * self.ps * self.ps * H * W, 3, 3),
                nghbr_normals.permute(0, 2, 3, 4, 1)
                .reshape(B * self.ps * self.ps * H * W, 3)
                .unsqueeze(-1),
            )
            .reshape(B, self.ps * self.ps, H, W, 3, 1)
            .squeeze(-1)
            .permute(0, 4, 1, 2, 3)
        )  # (B, 3, ps*ps, h, w)
        nghbr_normals_rot = F.normalize(nghbr_normals_rot, dim=1)

        # ray ReLU
        nghbr_normals_rot = torch.cat(
            [
                self.ray_relu(nghbr_normals_rot[:, :, i, :, :], ray_8).unsqueeze(2)
                for i in range(nghbr_normals_rot.size(2))
            ],
            dim=2,
        )

        # (B, 1, ps*ps, h, w) * (B, 3, ps*ps, h, w)
        pred_norm = torch.sum(nghbr_prob * nghbr_normals_rot, dim=2)  # (B, C, H, W)
        pred_norm = F.normalize(pred_norm, dim=1)

        up_mask = self.up_prob_head(torch.cat([h_new, uv_8], dim=1))
        up_pred_norm = convex_upsampling(pred_norm, up_mask, self.downsample_ratio)
        up_pred_norm = F.normalize(up_pred_norm, dim=1)

        return h_new, pred_norm, up_pred_norm

    def forward(self, img, intrins=None):
        # Step 1. encoder
        features = self.encoder(img)

        # Step 2. get uv encoding
        B, _, orig_H, orig_W = img.shape
        intrins[:, 0, 2] += 0.5
        intrins[:, 1, 2] += 0.5
        uv_32 = self.get_ray(
            intrins, orig_H // 32, orig_W // 32, orig_H, orig_W, return_uv=True
        )
        uv_16 = self.get_ray(
            intrins, orig_H // 16, orig_W // 16, orig_H, orig_W, return_uv=True
        )
        uv_8 = self.get_ray(
            intrins, orig_H // 8, orig_W // 8, orig_H, orig_W, return_uv=True
        )
        ray_8 = self.get_ray(intrins, orig_H // 8, orig_W // 8, orig_H, orig_W)

        # Step 3. decoder - initial prediction
        pred_norm, feat_map, h = self.decoder(features, uvs=(uv_32, uv_16, uv_8))
        pred_norm = self.ray_relu(pred_norm, ray_8)

        # Step 4. add ray direction encoding
        feat_map = torch.cat([feat_map, uv_8], dim=1)

        # iterative refinement
        up_pred_norm = self.upsample(h, pred_norm, uv_8)
        pred_list = [up_pred_norm]
        for i in range(self.num_iter):
            h, pred_norm, up_pred_norm = self.refine(
                h, feat_map, pred_norm.detach(), intrins, orig_H, orig_W, uv_8, ray_8
            )
            pred_list.append(up_pred_norm)
        return pred_list
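
Usage sketch for the new module (reviewer note, not part of the commit): the forward pass returns one prediction per refinement iteration, each upsampled back to input resolution. The 480x640 input and 500 px focal length below are illustrative assumptions; pretrained=True makes geffnet fetch the EfficientNet-B5 backbone weights on first use.

import torch
from annotator.dsine.dsine import DSINE

model = DSINE().eval()
img = torch.rand(1, 3, 480, 640)  # H and W must be multiples of 32
fx = fy = 500.0  # illustrative focal length, in pixels
intrins = torch.tensor([[[fx, 0.0, 320.0], [0.0, fy, 240.0], [0.0, 0.0, 1.0]]])
with torch.no_grad():
    pred_list = model(img, intrins=intrins)  # initial + num_iter refined outputs
print(len(pred_list), pred_list[-1].shape)  # 6 torch.Size([1, 3, 480, 640])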
annotator/dsine/rotation.py ADDED
@@ -0,0 +1,85 @@
import torch
import numpy as np


# NOTE: from PyTorch3D
def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as axis/angle to quaternions.

    Args:
        axis_angle: Rotations given as a vector in axis angle form,
            as a tensor of shape (..., 3), where the magnitude is
            the angle turned anticlockwise in radians around the
            vector's direction.

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).
    """
    angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
    half_angles = angles * 0.5
    eps = 1e-6
    small_angles = angles.abs() < eps
    sin_half_angles_over_angles = torch.empty_like(angles)
    sin_half_angles_over_angles[~small_angles] = (
        torch.sin(half_angles[~small_angles]) / angles[~small_angles]
    )
    # for x small, sin(x/2) is about x/2 - (x/2)^3/6
    # so sin(x/2)/x is about 1/2 - (x*x)/48
    sin_half_angles_over_angles[small_angles] = (
        0.5 - (angles[small_angles] * angles[small_angles]) / 48
    )
    quaternions = torch.cat(
        [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
    )
    return quaternions


# NOTE: from PyTorch3D
def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as quaternions to rotation matrices.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    r, i, j, k = torch.unbind(quaternions, -1)
    # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
    two_s = 2.0 / (quaternions * quaternions).sum(-1)

    o = torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        -1,
    )
    return o.reshape(quaternions.shape[:-1] + (3, 3))


# NOTE: from PyTorch3D
def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as axis/angle to rotation matrices.

    Args:
        axis_angle: Rotations given as a vector in axis angle form,
            as a tensor of shape (..., 3), where the magnitude is
            the angle turned anticlockwise in radians around the
            vector's direction.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
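
Quick sanity check for the conversion chain (illustrative, not part of the diff): a rotation of pi/2 about the z-axis should map the x-axis onto the y-axis.

import math
import torch
from annotator.dsine.rotation import axis_angle_to_matrix

axis_angle = torch.tensor([0.0, 0.0, math.pi / 2])  # axis scaled by angle
R = axis_angle_to_matrix(axis_angle)  # (3, 3) rotation matrix
print(R @ torch.tensor([1.0, 0.0, 0.0]))  # ~tensor([0., 1., 0.])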
annotator/dsine/submodules.py ADDED
@@ -0,0 +1,237 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import geffnet


INPUT_CHANNELS_DICT = {
    0: [1280, 112, 40, 24, 16],
    1: [1280, 112, 40, 24, 16],
    2: [1408, 120, 48, 24, 16],
    3: [1536, 136, 48, 32, 24],
    4: [1792, 160, 56, 32, 24],
    5: [2048, 176, 64, 40, 24],
    6: [2304, 200, 72, 40, 32],
    7: [2560, 224, 80, 48, 32],
}


class Encoder(nn.Module):
    def __init__(self, B=5, pretrained=True):
        """e.g. B=5 will return EfficientNet-B5"""
        super(Encoder, self).__init__()
        basemodel_name = 'tf_efficientnet_b%s_ap' % B
        basemodel = geffnet.create_model(basemodel_name, pretrained=pretrained)
        # Remove last layer
        basemodel.global_pool = nn.Identity()
        basemodel.classifier = nn.Identity()
        self.original_model = basemodel

    def forward(self, x):
        features = [x]
        for k, v in self.original_model._modules.items():
            if k == "blocks":
                for ki, vi in v._modules.items():
                    features.append(vi(features[-1]))
            else:
                features.append(v(features[-1]))
        return features


class ConvGRU(nn.Module):
    def __init__(self, hidden_dim, input_dim, ks=3):
        super(ConvGRU, self).__init__()
        p = (ks - 1) // 2
        self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, ks, padding=p)
        self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, ks, padding=p)
        self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, ks, padding=p)

    def forward(self, h, x):
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz(hx))
        r = torch.sigmoid(self.convr(hx))
        q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))
        h = (1 - z) * h + z * q
        return h


class RayReLU(nn.Module):
    def __init__(self, eps=1e-2):
        super(RayReLU, self).__init__()
        self.eps = eps

    def forward(self, pred_norm, ray):
        # angle between the predicted normal and ray direction
        cos = torch.cosine_similarity(pred_norm, ray, dim=1).unsqueeze(
            1
        )  # (B, 1, H, W)

        # component of pred_norm along view
        norm_along_view = ray * cos

        # cos should be bigger than eps
        norm_along_view_relu = ray * (torch.relu(cos - self.eps) + self.eps)

        # difference
        diff = norm_along_view_relu - norm_along_view

        # updated pred_norm
        new_pred_norm = pred_norm + diff
        new_pred_norm = F.normalize(new_pred_norm, dim=1)

        return new_pred_norm


class UpSampleBN(nn.Module):
    def __init__(self, skip_input, output_features, align_corners=True):
        super(UpSampleBN, self).__init__()
        self._net = nn.Sequential(
            nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(output_features),
            nn.LeakyReLU(),
            nn.Conv2d(
                output_features, output_features, kernel_size=3, stride=1, padding=1
            ),
            nn.BatchNorm2d(output_features),
            nn.LeakyReLU(),
        )
        self.align_corners = align_corners

    def forward(self, x, concat_with):
        up_x = F.interpolate(
            x,
            size=[concat_with.size(2), concat_with.size(3)],
            mode="bilinear",
            align_corners=self.align_corners,
        )
        f = torch.cat([up_x, concat_with], dim=1)
        return self._net(f)


class Conv2d_WS(nn.Conv2d):
    """weight standardization"""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
    ):
        super(Conv2d_WS, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
        )

    def forward(self, x):
        weight = self.weight
        weight_mean = (
            weight.mean(dim=1, keepdim=True)
            .mean(dim=2, keepdim=True)
            .mean(dim=3, keepdim=True)
        )
        weight = weight - weight_mean
        std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
        weight = weight / std.expand_as(weight)
        return F.conv2d(
            x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )


class UpSampleGN(nn.Module):
    """UpSample with GroupNorm"""

    def __init__(self, skip_input, output_features, align_corners=True):
        super(UpSampleGN, self).__init__()
        self._net = nn.Sequential(
            Conv2d_WS(skip_input, output_features, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(8, output_features),
            nn.LeakyReLU(),
            Conv2d_WS(
                output_features, output_features, kernel_size=3, stride=1, padding=1
            ),
            nn.GroupNorm(8, output_features),
            nn.LeakyReLU(),
        )
        self.align_corners = align_corners

    def forward(self, x, concat_with):
        up_x = F.interpolate(
            x,
            size=[concat_with.size(2), concat_with.size(3)],
            mode="bilinear",
            align_corners=self.align_corners,
        )
        f = torch.cat([up_x, concat_with], dim=1)
        return self._net(f)


def upsample_via_bilinear(out, up_mask, downsample_ratio):
    """bilinear upsampling (up_mask is a dummy variable)"""
    return F.interpolate(
        out, scale_factor=downsample_ratio, mode="bilinear", align_corners=True
    )


def upsample_via_mask(out, up_mask, downsample_ratio):
    """convex upsampling"""
    # out: low-resolution output (B, o_dim, H, W)
    # up_mask: (B, 9*k*k, H, W)
    k = downsample_ratio

    N, o_dim, H, W = out.shape
    up_mask = up_mask.view(N, 1, 9, k, k, H, W)
    up_mask = torch.softmax(up_mask, dim=2)  # (B, 1, 9, k, k, H, W)

    up_out = F.unfold(out, [3, 3], padding=1)  # (B, o_dim, H, W) -> (B, o_dim * 3*3, H*W)
    up_out = up_out.view(N, o_dim, 9, 1, 1, H, W)  # (B, o_dim, 3*3, 1, 1, H, W)
    up_out = torch.sum(up_mask * up_out, dim=2)  # (B, o_dim, k, k, H, W)

    up_out = up_out.permute(0, 1, 4, 2, 5, 3)  # (B, o_dim, H, k, W, k)
    return up_out.reshape(N, o_dim, k * H, k * W)  # (B, o_dim, kH, kW)


def convex_upsampling(out, up_mask, k):
    # out: low-resolution output (B, C, H, W)
    # up_mask: (B, 9*k*k, H, W)
    B, C, H, W = out.shape
    up_mask = up_mask.view(B, 1, 9, k, k, H, W)
    up_mask = torch.softmax(up_mask, dim=2)  # (B, 1, 9, k, k, H, W)

    out = F.pad(out, pad=(1, 1, 1, 1), mode="replicate")
    up_out = F.unfold(out, [3, 3], padding=0)  # (B, C, H, W) -> (B, C * 3*3, H*W)
    up_out = up_out.view(B, C, 9, 1, 1, H, W)  # (B, C, 9, 1, 1, H, W)

    up_out = torch.sum(up_mask * up_out, dim=2)  # (B, C, k, k, H, W)
    up_out = up_out.permute(0, 1, 4, 2, 5, 3)  # (B, C, H, k, W, k)
    return up_out.reshape(B, C, k * H, k * W)  # (B, C, kH, kW)


def get_unfold(pred_norm, ps, pad):
    B, C, H, W = pred_norm.shape
    pred_norm = F.pad(
        pred_norm, pad=(pad, pad, pad, pad), mode="replicate"
    )  # (B, C, h, w)
    pred_norm_unfold = F.unfold(pred_norm, [ps, ps], padding=0)  # (B, C * ps*ps, h*w)
    pred_norm_unfold = pred_norm_unfold.view(B, C, ps * ps, H, W)  # (B, C, ps*ps, h, w)
    return pred_norm_unfold


def get_prediction_head(input_dim, hidden_dim, output_dim):
    return nn.Sequential(
        nn.Conv2d(input_dim, hidden_dim, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(hidden_dim, hidden_dim, 1),
        nn.ReLU(inplace=True),
        nn.Conv2d(hidden_dim, output_dim, 1),
    )
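
Shape walkthrough for convex_upsampling (illustrative, not part of the diff): every pixel of the k-times upsampled output is a softmax-weighted convex combination of the 3x3 coarse neighbourhood around its source pixel, with the weights predicted per pixel.

import torch
from annotator.dsine.submodules import convex_upsampling

B, C, H, W, k = 1, 3, 10, 12, 8
out = torch.rand(B, C, H, W)  # coarse prediction
up_mask = torch.rand(B, 9 * k * k, H, W)  # predicted convex weights
up = convex_upsampling(out, up_mask, k)
print(up.shape)  # torch.Size([1, 3, 80, 96])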
annotator/dsine/utils.py ADDED
@@ -0,0 +1,104 @@
""" utils
"""

import os
import torch
import numpy as np


def load_checkpoint(fpath, model):
    print("loading checkpoint... {}".format(fpath))

    ckpt = torch.load(fpath, map_location="cpu")["model"]

    load_dict = {}
    for k, v in ckpt.items():
        if k.startswith("module."):
            k_ = k.replace("module.", "")
            load_dict[k_] = v
        else:
            load_dict[k] = v

    model.load_state_dict(load_dict)
    print("loading checkpoint... / done")
    return model


def compute_normal_error(pred_norm, gt_norm):
    pred_error = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
    pred_error = torch.clamp(pred_error, min=-1.0, max=1.0)
    pred_error = torch.acos(pred_error) * 180.0 / np.pi
    pred_error = pred_error.unsqueeze(1)  # (B, 1, H, W)
    return pred_error


def compute_normal_metrics(total_normal_errors):
    total_normal_errors = total_normal_errors.detach().cpu().numpy()
    num_pixels = total_normal_errors.shape[0]

    metrics = {
        "mean": np.average(total_normal_errors),
        "median": np.median(total_normal_errors),
        "rmse": np.sqrt(np.sum(total_normal_errors * total_normal_errors) / num_pixels),
        "a1": 100.0 * (np.sum(total_normal_errors < 5) / num_pixels),
        "a2": 100.0 * (np.sum(total_normal_errors < 7.5) / num_pixels),
        "a3": 100.0 * (np.sum(total_normal_errors < 11.25) / num_pixels),
        "a4": 100.0 * (np.sum(total_normal_errors < 22.5) / num_pixels),
        "a5": 100.0 * (np.sum(total_normal_errors < 30) / num_pixels),
    }

    return metrics


def pad_input(orig_H, orig_W):
    if orig_W % 32 == 0:
        l = 0
        r = 0
    else:
        new_W = 32 * ((orig_W // 32) + 1)
        l = (new_W - orig_W) // 2
        r = (new_W - orig_W) - l

    if orig_H % 32 == 0:
        t = 0
        b = 0
    else:
        new_H = 32 * ((orig_H // 32) + 1)
        t = (new_H - orig_H) // 2
        b = (new_H - orig_H) - t
    return l, r, t, b


def get_intrins_from_fov(new_fov, H, W, device):
    # NOTE: top-left pixel should be (0,0)
    if W >= H:
        new_fu = (W / 2.0) / np.tan(np.deg2rad(new_fov / 2.0))
        new_fv = (W / 2.0) / np.tan(np.deg2rad(new_fov / 2.0))
    else:
        new_fu = (H / 2.0) / np.tan(np.deg2rad(new_fov / 2.0))
        new_fv = (H / 2.0) / np.tan(np.deg2rad(new_fov / 2.0))

    new_cu = (W / 2.0) - 0.5
    new_cv = (H / 2.0) - 0.5

    new_intrins = torch.tensor(
        [[new_fu, 0, new_cu], [0, new_fv, new_cv], [0, 0, 1]],
        dtype=torch.float32,
        device=device,
    )

    return new_intrins


def get_intrins_from_txt(intrins_path, device):
    # NOTE: top-left pixel should be (0,0)
    with open(intrins_path, "r") as f:
        intrins_ = f.readlines()[0].split()[0].split(",")
    intrins_ = [float(i) for i in intrins_]
    fx, fy, cx, cy = intrins_

    intrins = torch.tensor(
        [[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=torch.float32, device=device
    )

    return intrins
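
Worked example for pad_input (illustrative): it splits the padding needed to reach the next multiple of 32 evenly between left/right and top/bottom, since the encoder downsamples by up to 32x.

from annotator.dsine.utils import pad_input

l, r, t, b = pad_input(orig_H=480, orig_W=500)
print(l, r, t, b)  # 6 6 0 0 -> width padded from 500 to 512; 480 is already 15 * 32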
annotator/dsine_local.py ADDED
@@ -0,0 +1,63 @@
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
import torch.nn.functional as F
from .dsine.dsine import DSINE
from .dsine import utils as dsine_utils


class NormalDetector:
    def __init__(self, model_path):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = DSINE()
        self.model = dsine_utils.load_checkpoint(model_path, self.model)
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        self.fov = 60

    @torch.no_grad()
    def __call__(self, image):
        self.model.to(self.device)
        self.model.pixel_coords = self.model.pixel_coords.to(self.device)

        img = np.array(image).astype(np.float32) / 255.0
        img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(self.device)
        _, _, orig_H, orig_W = img.shape
        l, r, t, b = dsine_utils.pad_input(orig_H, orig_W)
        img = F.pad(img, (l, r, t, b), mode="constant", value=0.0)
        img = self.normalize(img)
        intrinsics = dsine_utils.get_intrins_from_fov(
            new_fov=self.fov, H=orig_H, W=orig_W, device=self.device
        ).unsqueeze(0)

        # shift the principal point to account for the padding
        intrinsics[:, 0, 2] += l
        intrinsics[:, 1, 2] += t

        pred_norm = self.model(img, intrins=intrinsics)[-1]
        pred_norm = pred_norm[:, :, t : t + orig_H, l : l + orig_W]
        pred_norm_np = (
            pred_norm.cpu().detach().numpy()[0, :, :, :].transpose(1, 2, 0)
        )  # (H, W, 3)
        pred_norm_np = ((pred_norm_np + 1.0) / 2.0 * 255.0).astype(np.uint8)
        normal_img = Image.fromarray(pred_norm_np).resize((orig_W, orig_H))

        self.model.to("cpu")
        self.model.pixel_coords = self.model.pixel_coords.to("cpu")
        return normal_img


if __name__ == "__main__":
    from diffusers.utils import load_image

    image = load_image(
        "https://qhstaticssl.kujiale.com/image/jpeg/1716177580588/9AAA49344B9CE33512C4EBD0A287495F.jpg"
    )
    image = np.asarray(image)
    # NormalDetector only accepts the DSINE checkpoint path; the EfficientNet
    # backbone weights are resolved by geffnet itself.
    normal_detector = NormalDetector(
        model_path="/juicefs/training/models/open_source/dsine/dsine.pt"
    )
    normal_image = normal_detector(image)
    normal_image.save("normal_image.jpg")
app.py CHANGED
@@ -7,10 +7,11 @@ from diffusers import (
    UniPCMultistepScheduler,
)
import gradio as gr
+ from huggingface_hub import hf_hub_download

from annotator.util import resize_image, HWC3
from annotator.midas import DepthDetector
- from annotator.dsine_hub import NormalDetector
+ from annotator.dsine_local import NormalDetector
from annotator.upernet import SegmDetector

controlnet_checkpoint = "kujiale-ai/controlnet"
@@ -26,7 +27,9 @@ pipe = StableDiffusionControlNetPipeline.from_pretrained(
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

apply_depth = DepthDetector()
- apply_normal = NormalDetector()
+ apply_normal = NormalDetector(
+     hf_hub_download("camenduru/DSINE", filename="dsine.pt")
+ )
apply_segm = SegmDetector()
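
The new wiring can also be exercised outside the Gradio app; a minimal sketch, assuming network access to the camenduru/DSINE repo on the Hub (hf_hub_download caches the checkpoint locally and returns its path, so restarts do not re-download it):

from huggingface_hub import hf_hub_download
from diffusers.utils import load_image
from annotator.dsine_local import NormalDetector

apply_normal = NormalDetector(hf_hub_download("camenduru/DSINE", filename="dsine.pt"))
image = load_image(
    "https://qhstaticssl.kujiale.com/image/jpeg/1716177580588/9AAA49344B9CE33512C4EBD0A287495F.jpg"
)
apply_normal(image).save("normal_image.jpg")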