Spaces: Running on Zero
update dependencies
- .gitignore +2 -1
- annotator/dsine/__init__.py +0 -0
- annotator/dsine/dsine.py +303 -0
- annotator/dsine/rotation.py +85 -0
- annotator/dsine/submodules.py +237 -0
- annotator/dsine/utils.py +104 -0
- annotator/dsine_local.py +63 -0
- app.py +5 -2
.gitignore
CHANGED
@@ -159,4 +159,5 @@ cython_debug/
 # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
+#.idea/
+gradio/
annotator/dsine/__init__.py
ADDED
File without changes
annotator/dsine/dsine.py
ADDED
@@ -0,0 +1,303 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .submodules import (
+    Encoder,
+    ConvGRU,
+    UpSampleBN,
+    UpSampleGN,
+    RayReLU,
+    convex_upsampling,
+    get_unfold,
+    get_prediction_head,
+    INPUT_CHANNELS_DICT,
+)
+from .rotation import axis_angle_to_matrix
+
+
+class Decoder(nn.Module):
+    def __init__(self, output_dims, B=5, NF=2048, BN=False, downsample_ratio=8):
+        super(Decoder, self).__init__()
+        input_channels = INPUT_CHANNELS_DICT[B]
+        output_dim, feature_dim, hidden_dim = output_dims
+        features = bottleneck_features = NF
+        self.downsample_ratio = downsample_ratio
+
+        UpSample = UpSampleBN if BN else UpSampleGN
+        self.conv2 = nn.Conv2d(
+            bottleneck_features + 2, features, kernel_size=1, stride=1, padding=0
+        )
+        self.up1 = UpSample(
+            skip_input=features // 1 + input_channels[1] + 2,
+            output_features=features // 2,
+            align_corners=False,
+        )
+        self.up2 = UpSample(
+            skip_input=features // 2 + input_channels[2] + 2,
+            output_features=features // 4,
+            align_corners=False,
+        )
+
+        # prediction heads
+        i_dim = features // 4
+        h_dim = 128
+        self.normal_head = get_prediction_head(i_dim + 2, h_dim, output_dim)
+        self.feature_head = get_prediction_head(i_dim + 2, h_dim, feature_dim)
+        self.hidden_head = get_prediction_head(i_dim + 2, h_dim, hidden_dim)
+
+    def forward(self, features, uvs):
+        _, _, x_block2, x_block3, x_block4 = (
+            features[4],
+            features[5],
+            features[6],
+            features[8],
+            features[11],
+        )
+        uv_32, uv_16, uv_8 = uvs
+
+        x_d0 = self.conv2(torch.cat([x_block4, uv_32], dim=1))
+        x_d1 = self.up1(x_d0, torch.cat([x_block3, uv_16], dim=1))
+        x_feat = self.up2(x_d1, torch.cat([x_block2, uv_8], dim=1))
+        x_feat = torch.cat([x_feat, uv_8], dim=1)
+
+        normal = self.normal_head(x_feat)
+        normal = F.normalize(normal, dim=1)
+        f = self.feature_head(x_feat)
+        h = self.hidden_head(x_feat)
+        return normal, f, h
+
+
+class DSINE(nn.Module):
+    def __init__(self):
+        super(DSINE, self).__init__()
+        self.downsample_ratio = 8
+        self.ps = 5  # patch size
+        self.num_iter = 5  # num iterations
+
+        # define encoder
+        self.encoder = Encoder(
+            B=5,
+            pretrained=True,
+        )
+
+        # define decoder
+        self.output_dim = output_dim = 3
+        self.feature_dim = feature_dim = 64
+        self.hidden_dim = hidden_dim = 64
+        self.decoder = Decoder(
+            [output_dim, feature_dim, hidden_dim], B=5, NF=2048, BN=False
+        )
+
+        # ray direction-based ReLU
+        self.ray_relu = RayReLU(eps=1e-2)
+
+        # pixel_coords (1, 3, H, W)
+        # NOTE: this is set to some arbitrarily high number,
+        # if your input is 2000+ pixels wide/tall, increase these values
+        h = 2000
+        w = 2000
+        pixel_coords = np.ones((3, h, w)).astype(np.float32)
+        x_range = np.concatenate([np.arange(w).reshape(1, w)] * h, axis=0)
+        y_range = np.concatenate([np.arange(h).reshape(h, 1)] * w, axis=1)
+        pixel_coords[0, :, :] = x_range + 0.5
+        pixel_coords[1, :, :] = y_range + 0.5
+        self.pixel_coords = torch.from_numpy(pixel_coords).unsqueeze(0)
+
+        # define ConvGRU cell
+        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=feature_dim + 2, ks=self.ps)
+
+        # padding used during NRN
+        self.pad = (self.ps - 1) // 2
+
+        # prediction heads
+        self.prob_head = get_prediction_head(
+            self.hidden_dim + 2, 64, self.ps * self.ps
+        )  # weights assigned for each nghbr pixel
+        self.xy_head = get_prediction_head(
+            self.hidden_dim + 2, 64, self.ps * self.ps * 2
+        )  # rotation axis for each nghbr pixel
+        self.angle_head = get_prediction_head(
+            self.hidden_dim + 2, 64, self.ps * self.ps
+        )  # rotation angle for each nghbr pixel
+
+        # prediction heads - weights used for upsampling the coarse resolution output
+        self.up_prob_head = get_prediction_head(
+            self.hidden_dim + 2, 64, 9 * self.downsample_ratio * self.downsample_ratio
+        )
+
+    def get_ray(self, intrins, H, W, orig_H, orig_W, return_uv=False):
+        B, _, _ = intrins.shape
+        fu = intrins[:, 0, 0][:, None, None] * (W / orig_W)
+        cu = intrins[:, 0, 2][:, None, None] * (W / orig_W)
+        fv = intrins[:, 1, 1][:, None, None] * (H / orig_H)
+        cv = intrins[:, 1, 2][:, None, None] * (H / orig_H)
+
+        # (B, 2, H, W)
+        ray = self.pixel_coords[:, :, :H, :W].repeat(B, 1, 1, 1)
+        ray[:, 0, :, :] = (ray[:, 0, :, :] - cu) / fu
+        ray[:, 1, :, :] = (ray[:, 1, :, :] - cv) / fv
+
+        if return_uv:
+            return ray[:, :2, :, :]
+        else:
+            return F.normalize(ray, dim=1)
+
+    def upsample(self, h, pred_norm, uv_8):
+        up_mask = self.up_prob_head(torch.cat([h, uv_8], dim=1))
+        up_pred_norm = convex_upsampling(pred_norm, up_mask, self.downsample_ratio)
+        up_pred_norm = F.normalize(up_pred_norm, dim=1)
+        return up_pred_norm
+
+    def refine(self, h, feat_map, pred_norm, intrins, orig_H, orig_W, uv_8, ray_8):
+        B, C, H, W = pred_norm.shape
+        fu = intrins[:, 0, 0][:, None, None, None] * (W / orig_W)  # (B, 1, 1, 1)
+        cu = intrins[:, 0, 2][:, None, None, None] * (W / orig_W)
+        fv = intrins[:, 1, 1][:, None, None, None] * (H / orig_H)
+        cv = intrins[:, 1, 2][:, None, None, None] * (H / orig_H)
+
+        h_new = self.gru(h, feat_map)
+
+        # get nghbr prob (B, 1, ps*ps, h, w)
+        nghbr_prob = self.prob_head(torch.cat([h_new, uv_8], dim=1)).unsqueeze(1)
+        nghbr_prob = torch.sigmoid(nghbr_prob)
+
+        # get nghbr normals (B, 3, ps*ps, h, w)
+        nghbr_normals = get_unfold(pred_norm, ps=self.ps, pad=self.pad)
+
+        # get nghbr xy (B, 2, ps*ps, h, w)
+        nghbr_xys = self.xy_head(torch.cat([h_new, uv_8], dim=1))
+        nghbr_xs, nghbr_ys = torch.split(
+            nghbr_xys, [self.ps * self.ps, self.ps * self.ps], dim=1
+        )
+        nghbr_xys = torch.cat([nghbr_xs.unsqueeze(1), nghbr_ys.unsqueeze(1)], dim=1)
+        nghbr_xys = F.normalize(nghbr_xys, dim=1)
+
+        # get nghbr theta (B, 1, ps*ps, h, w)
+        nghbr_angle = self.angle_head(torch.cat([h_new, uv_8], dim=1)).unsqueeze(1)
+        nghbr_angle = torch.sigmoid(nghbr_angle) * np.pi
+
+        # get nghbr pixel coord (1, 3, ps*ps, h, w)
+        nghbr_pixel_coord = get_unfold(
+            self.pixel_coords[:, :, :H, :W], ps=self.ps, pad=self.pad
+        )
+
+        # nghbr axes (B, 3, ps*ps, h, w)
+        nghbr_axes = torch.zeros_like(nghbr_normals)
+
+        du_over_fu = nghbr_xys[:, 0, ...] / fu  # (B, ps*ps, h, w)
+        dv_over_fv = nghbr_xys[:, 1, ...] / fv  # (B, ps*ps, h, w)
+
+        term_u = (
+            nghbr_pixel_coord[:, 0, ...] + nghbr_xys[:, 0, ...] - cu
+        ) / fu  # (B, ps*ps, h, w)
+        term_v = (
+            nghbr_pixel_coord[:, 1, ...] + nghbr_xys[:, 1, ...] - cv
+        ) / fv  # (B, ps*ps, h, w)
+
+        nx = nghbr_normals[:, 0, ...]  # (B, ps*ps, h, w)
+        ny = nghbr_normals[:, 1, ...]  # (B, ps*ps, h, w)
+        nz = nghbr_normals[:, 2, ...]  # (B, ps*ps, h, w)
+
+        nghbr_delta_z_num = -(du_over_fu * nx + dv_over_fv * ny)
+        nghbr_delta_z_denom = term_u * nx + term_v * ny + nz
+        nghbr_delta_z_denom[torch.abs(nghbr_delta_z_denom) < 1e-8] = 1e-8 * torch.sign(
+            nghbr_delta_z_denom[torch.abs(nghbr_delta_z_denom) < 1e-8]
+        )
+        nghbr_delta_z = nghbr_delta_z_num / nghbr_delta_z_denom
+
+        nghbr_axes[:, 0, ...] = du_over_fu + nghbr_delta_z * term_u
+        nghbr_axes[:, 1, ...] = dv_over_fv + nghbr_delta_z * term_v
+        nghbr_axes[:, 2, ...] = nghbr_delta_z
+        nghbr_axes = F.normalize(nghbr_axes, dim=1)  # (B, 3, ps*ps, h, w)
+
+        # make sure axes are all valid
+        invalid = (
+            torch.sum(
+                torch.logical_or(
+                    torch.isnan(nghbr_axes), torch.isinf(nghbr_axes)
+                ).float(),
+                dim=1,
+            )
+            > 0.5
+        )  # (B, ps*ps, h, w)
+        nghbr_axes[:, 0, ...][invalid] = 0.0
+        nghbr_axes[:, 1, ...][invalid] = 0.0
+        nghbr_axes[:, 2, ...][invalid] = 0.0
+
+        # nghbr_axes_angle (B, 3, ps*ps, h, w)
+        nghbr_axes_angle = nghbr_axes * nghbr_angle
+        nghbr_axes_angle = nghbr_axes_angle.permute(
+            0, 2, 3, 4, 1
+        )  # (B, ps*ps, h, w, 3)
+        nghbr_R = axis_angle_to_matrix(nghbr_axes_angle)  # (B, ps*ps, h, w, 3, 3)
+
+        # (B, 3, ps*ps, h, w)
+        nghbr_normals_rot = (
+            torch.bmm(
+                nghbr_R.reshape(B * self.ps * self.ps * H * W, 3, 3),
+                nghbr_normals.permute(0, 2, 3, 4, 1)
+                .reshape(B * self.ps * self.ps * H * W, 3)
+                .unsqueeze(-1),
+            )
+            .reshape(B, self.ps * self.ps, H, W, 3, 1)
+            .squeeze(-1)
+            .permute(0, 4, 1, 2, 3)
+        )  # (B, 3, ps*ps, h, w)
+        nghbr_normals_rot = F.normalize(nghbr_normals_rot, dim=1)
+
+        # ray ReLU
+        nghbr_normals_rot = torch.cat(
+            [
+                self.ray_relu(nghbr_normals_rot[:, :, i, :, :], ray_8).unsqueeze(2)
+                for i in range(nghbr_normals_rot.size(2))
+            ],
+            dim=2,
+        )
+
+        # (B, 1, ps*ps, h, w) * (B, 3, ps*ps, h, w)
+        pred_norm = torch.sum(nghbr_prob * nghbr_normals_rot, dim=2)  # (B, C, H, W)
+        pred_norm = F.normalize(pred_norm, dim=1)
+
+        up_mask = self.up_prob_head(torch.cat([h_new, uv_8], dim=1))
+        up_pred_norm = convex_upsampling(pred_norm, up_mask, self.downsample_ratio)
+        up_pred_norm = F.normalize(up_pred_norm, dim=1)
+
+        return h_new, pred_norm, up_pred_norm
+
+    def forward(self, img, intrins=None):
+        # Step 1. encoder
+        features = self.encoder(img)
+
+        # Step 2. get uv encoding
+        B, _, orig_H, orig_W = img.shape
+        intrins[:, 0, 2] += 0.5
+        intrins[:, 1, 2] += 0.5
+        uv_32 = self.get_ray(
+            intrins, orig_H // 32, orig_W // 32, orig_H, orig_W, return_uv=True
+        )
+        uv_16 = self.get_ray(
+            intrins, orig_H // 16, orig_W // 16, orig_H, orig_W, return_uv=True
+        )
+        uv_8 = self.get_ray(
+            intrins, orig_H // 8, orig_W // 8, orig_H, orig_W, return_uv=True
+        )
+        ray_8 = self.get_ray(intrins, orig_H // 8, orig_W // 8, orig_H, orig_W)
+
+        # Step 3. decoder - initial prediction
+        pred_norm, feat_map, h = self.decoder(features, uvs=(uv_32, uv_16, uv_8))
+        pred_norm = self.ray_relu(pred_norm, ray_8)
+
+        # Step 4. add ray direction encoding
+        feat_map = torch.cat([feat_map, uv_8], dim=1)
+
+        # iterative refinement
+        up_pred_norm = self.upsample(h, pred_norm, uv_8)
+        pred_list = [up_pred_norm]
+        for i in range(self.num_iter):
+            h, pred_norm, up_pred_norm = self.refine(
+                h, feat_map, pred_norm.detach(), intrins, orig_H, orig_W, uv_8, ray_8
+            )
+            pred_list.append(up_pred_norm)
+        return pred_list
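A minimal sketch of how this module can be exercised on its own (illustrative only: the 480×640 input, the 60° FoV, and the CPU device are assumptions, and without a trained checkpoint the output is meaningless; the Space's real preprocessing lives in annotator/dsine_local.py below):

import torch
from annotator.dsine.dsine import DSINE
from annotator.dsine.utils import get_intrins_from_fov, load_checkpoint

model = DSINE().eval()
# model = load_checkpoint("dsine.pt", model)  # required for meaningful normals
img = torch.rand(1, 3, 480, 640)  # H and W must be multiples of 32
intrins = get_intrins_from_fov(new_fov=60.0, H=480, W=640, device="cpu").unsqueeze(0)
with torch.no_grad():
    pred_list = model(img, intrins=intrins)  # initial prediction + one entry per refinement iteration
pred_norm = pred_list[-1]  # (1, 3, 480, 640), unit-length normals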
annotator/dsine/rotation.py
ADDED
@@ -0,0 +1,85 @@
+import torch
+import numpy as np
+
+
+# NOTE: from PyTorch3D
+def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
+    """
+    Convert rotations given as axis/angle to quaternions.
+
+    Args:
+        axis_angle: Rotations given as a vector in axis angle form,
+            as a tensor of shape (..., 3), where the magnitude is
+            the angle turned anticlockwise in radians around the
+            vector's direction.
+
+    Returns:
+        quaternions with real part first, as tensor of shape (..., 4).
+    """
+    angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
+    half_angles = angles * 0.5
+    eps = 1e-6
+    small_angles = angles.abs() < eps
+    sin_half_angles_over_angles = torch.empty_like(angles)
+    sin_half_angles_over_angles[~small_angles] = (
+        torch.sin(half_angles[~small_angles]) / angles[~small_angles]
+    )
+    # for x small, sin(x/2) is about x/2 - (x/2)^3/6
+    # so sin(x/2)/x is about 1/2 - (x*x)/48
+    sin_half_angles_over_angles[small_angles] = (
+        0.5 - (angles[small_angles] * angles[small_angles]) / 48
+    )
+    quaternions = torch.cat(
+        [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
+    )
+    return quaternions
+
+
+# NOTE: from PyTorch3D
+def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
+    """
+    Convert rotations given as quaternions to rotation matrices.
+
+    Args:
+        quaternions: quaternions with real part first,
+            as tensor of shape (..., 4).
+
+    Returns:
+        Rotation matrices as tensor of shape (..., 3, 3).
+    """
+    r, i, j, k = torch.unbind(quaternions, -1)
+    # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
+    two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+    o = torch.stack(
+        (
+            1 - two_s * (j * j + k * k),
+            two_s * (i * j - k * r),
+            two_s * (i * k + j * r),
+            two_s * (i * j + k * r),
+            1 - two_s * (i * i + k * k),
+            two_s * (j * k - i * r),
+            two_s * (i * k - j * r),
+            two_s * (j * k + i * r),
+            1 - two_s * (i * i + j * j),
+        ),
+        -1,
+    )
+    return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+
+# NOTE: from PyTorch3D
+def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
+    """
+    Convert rotations given as axis/angle to rotation matrices.
+
+    Args:
+        axis_angle: Rotations given as a vector in axis angle form,
+            as a tensor of shape (..., 3), where the magnitude is
+            the angle turned anticlockwise in radians around the
+            vector's direction.
+
+    Returns:
+        Rotation matrices as tensor of shape (..., 3, 3).
+    """
+    return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
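The three helpers above are lifted verbatim from PyTorch3D. A quick sanity check, assuming nothing beyond torch itself: a rotation of π/2 about +z should map the x axis onto the y axis.

import math
import torch
from annotator.dsine.rotation import axis_angle_to_matrix

axis_angle = torch.tensor([0.0, 0.0, math.pi / 2])  # angle is encoded as the vector's magnitude
R = axis_angle_to_matrix(axis_angle)                # (3, 3)
print(R @ torch.tensor([1.0, 0.0, 0.0]))            # ~tensor([0., 1., 0.])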
annotator/dsine/submodules.py
ADDED
@@ -0,0 +1,237 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import geffnet
+
+
+INPUT_CHANNELS_DICT = {
+    0: [1280, 112, 40, 24, 16],
+    1: [1280, 112, 40, 24, 16],
+    2: [1408, 120, 48, 24, 16],
+    3: [1536, 136, 48, 32, 24],
+    4: [1792, 160, 56, 32, 24],
+    5: [2048, 176, 64, 40, 24],
+    6: [2304, 200, 72, 40, 32],
+    7: [2560, 224, 80, 48, 32],
+}
+
+
+class Encoder(nn.Module):
+    def __init__(self, B=5, pretrained=True):
+        """e.g. B=5 will return EfficientNet-B5"""
+        super(Encoder, self).__init__()
+        basemodel_name = 'tf_efficientnet_b%s_ap' % B
+        basemodel = geffnet.create_model(basemodel_name, pretrained=pretrained)
+        # Remove last layer
+        basemodel.global_pool = nn.Identity()
+        basemodel.classifier = nn.Identity()
+        self.original_model = basemodel
+
+    def forward(self, x):
+        features = [x]
+        for k, v in self.original_model._modules.items():
+            if k == "blocks":
+                for ki, vi in v._modules.items():
+                    features.append(vi(features[-1]))
+            else:
+                features.append(v(features[-1]))
+        return features
+
+
+class ConvGRU(nn.Module):
+    def __init__(self, hidden_dim, input_dim, ks=3):
+        super(ConvGRU, self).__init__()
+        p = (ks - 1) // 2
+        self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, ks, padding=p)
+        self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, ks, padding=p)
+        self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, ks, padding=p)
+
+    def forward(self, h, x):
+        hx = torch.cat([h, x], dim=1)
+        z = torch.sigmoid(self.convz(hx))
+        r = torch.sigmoid(self.convr(hx))
+        q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))
+        h = (1 - z) * h + z * q
+        return h
+
+
+class RayReLU(nn.Module):
+    def __init__(self, eps=1e-2):
+        super(RayReLU, self).__init__()
+        self.eps = eps
+
+    def forward(self, pred_norm, ray):
+        # angle between the predicted normal and ray direction
+        cos = torch.cosine_similarity(pred_norm, ray, dim=1).unsqueeze(
+            1
+        )  # (B, 1, H, W)
+
+        # component of pred_norm along view
+        norm_along_view = ray * cos
+
+        # cos should be bigger than eps
+        norm_along_view_relu = ray * (torch.relu(cos - self.eps) + self.eps)
+
+        # difference
+        diff = norm_along_view_relu - norm_along_view
+
+        # updated pred_norm
+        new_pred_norm = pred_norm + diff
+        new_pred_norm = F.normalize(new_pred_norm, dim=1)
+
+        return new_pred_norm
+
+
+class UpSampleBN(nn.Module):
+    def __init__(self, skip_input, output_features, align_corners=True):
+        super(UpSampleBN, self).__init__()
+        self._net = nn.Sequential(
+            nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1),
+            nn.BatchNorm2d(output_features),
+            nn.LeakyReLU(),
+            nn.Conv2d(
+                output_features, output_features, kernel_size=3, stride=1, padding=1
+            ),
+            nn.BatchNorm2d(output_features),
+            nn.LeakyReLU(),
+        )
+        self.align_corners = align_corners
+
+    def forward(self, x, concat_with):
+        up_x = F.interpolate(
+            x,
+            size=[concat_with.size(2), concat_with.size(3)],
+            mode="bilinear",
+            align_corners=self.align_corners,
+        )
+        f = torch.cat([up_x, concat_with], dim=1)
+        return self._net(f)
+
+
+class Conv2d_WS(nn.Conv2d):
+    """weight standardization"""
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride=1,
+        padding=0,
+        dilation=1,
+        groups=1,
+        bias=True,
+    ):
+        super(Conv2d_WS, self).__init__(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+            dilation,
+            groups,
+            bias,
+        )
+
+    def forward(self, x):
+        weight = self.weight
+        weight_mean = (
+            weight.mean(dim=1, keepdim=True)
+            .mean(dim=2, keepdim=True)
+            .mean(dim=3, keepdim=True)
+        )
+        weight = weight - weight_mean
+        std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
+        weight = weight / std.expand_as(weight)
+        return F.conv2d(
+            x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups
+        )
+
+
+class UpSampleGN(nn.Module):
+    """UpSample with GroupNorm"""
+
+    def __init__(self, skip_input, output_features, align_corners=True):
+        super(UpSampleGN, self).__init__()
+        self._net = nn.Sequential(
+            Conv2d_WS(skip_input, output_features, kernel_size=3, stride=1, padding=1),
+            nn.GroupNorm(8, output_features),
+            nn.LeakyReLU(),
+            Conv2d_WS(
+                output_features, output_features, kernel_size=3, stride=1, padding=1
+            ),
+            nn.GroupNorm(8, output_features),
+            nn.LeakyReLU(),
+        )
+        self.align_corners = align_corners
+
+    def forward(self, x, concat_with):
+        up_x = F.interpolate(
+            x,
+            size=[concat_with.size(2), concat_with.size(3)],
+            mode="bilinear",
+            align_corners=self.align_corners,
+        )
+        f = torch.cat([up_x, concat_with], dim=1)
+        return self._net(f)
+
+
+def upsample_via_bilinear(out, up_mask, downsample_ratio):
+    """bilinear upsampling (up_mask is a dummy variable)"""
+    return F.interpolate(
+        out, scale_factor=downsample_ratio, mode="bilinear", align_corners=True
+    )
+
+
+def upsample_via_mask(out, up_mask, downsample_ratio):
+    """convex upsampling"""
+    # out: low-resolution output (B, o_dim, H, W)
+    # up_mask: (B, 9*k*k, H, W)
+    k = downsample_ratio
+
+    N, o_dim, H, W = out.shape
+    up_mask = up_mask.view(N, 1, 9, k, k, H, W)
+    up_mask = torch.softmax(up_mask, dim=2)  # (B, 1, 9, k, k, H, W)
+
+    up_out = F.unfold(out, [3, 3], padding=1)  # (B, 2, H, W) -> (B, 2 X 3*3, H*W)
+    up_out = up_out.view(N, o_dim, 9, 1, 1, H, W)  # (B, 2, 3*3, 1, 1, H, W)
+    up_out = torch.sum(up_mask * up_out, dim=2)  # (B, 2, k, k, H, W)
+
+    up_out = up_out.permute(0, 1, 4, 2, 5, 3)  # (B, 2, H, k, W, k)
+    return up_out.reshape(N, o_dim, k * H, k * W)  # (B, 2, kH, kW)
+
+
+def convex_upsampling(out, up_mask, k):
+    # out: low-resolution output (B, C, H, W)
+    # up_mask: (B, 9*k*k, H, W)
+    B, C, H, W = out.shape
+    up_mask = up_mask.view(B, 1, 9, k, k, H, W)
+    up_mask = torch.softmax(up_mask, dim=2)  # (B, 1, 9, k, k, H, W)
+
+    out = F.pad(out, pad=(1, 1, 1, 1), mode="replicate")
+    up_out = F.unfold(out, [3, 3], padding=0)  # (B, C, H, W) -> (B, C X 3*3, H*W)
+    up_out = up_out.view(B, C, 9, 1, 1, H, W)  # (B, C, 9, 1, 1, H, W)
+
+    up_out = torch.sum(up_mask * up_out, dim=2)  # (B, C, k, k, H, W)
+    up_out = up_out.permute(0, 1, 4, 2, 5, 3)  # (B, C, H, k, W, k)
+    return up_out.reshape(B, C, k * H, k * W)  # (B, C, kH, kW)
+
+
+def get_unfold(pred_norm, ps, pad):
+    B, C, H, W = pred_norm.shape
+    pred_norm = F.pad(
+        pred_norm, pad=(pad, pad, pad, pad), mode="replicate"
+    )  # (B, C, h, w)
+    pred_norm_unfold = F.unfold(pred_norm, [ps, ps], padding=0)  # (B, C X ps*ps, h*w)
+    pred_norm_unfold = pred_norm_unfold.view(B, C, ps * ps, H, W)  # (B, C, ps*ps, h, w)
+    return pred_norm_unfold
+
+
+def get_prediction_head(input_dim, hidden_dim, output_dim):
+    return nn.Sequential(
+        nn.Conv2d(input_dim, hidden_dim, 3, padding=1),
+        nn.ReLU(inplace=True),
+        nn.Conv2d(hidden_dim, hidden_dim, 1),
+        nn.ReLU(inplace=True),
+        nn.Conv2d(hidden_dim, output_dim, 1),
+    )
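The shape bookkeeping in convex_upsampling is the subtle part: up_mask carries softmax weights over a 3×3 neighborhood for each of the k×k sub-pixels. A self-contained shape check with random tensors (values are meaningless, only the shapes matter):

import torch
from annotator.dsine.submodules import convex_upsampling

B, C, H, W, k = 2, 3, 10, 12, 8
out = torch.rand(B, C, H, W)              # coarse prediction
up_mask = torch.rand(B, 9 * k * k, H, W)  # logits; softmax over the 9 neighbors happens inside
up = convex_upsampling(out, up_mask, k)
print(up.shape)                           # torch.Size([2, 3, 80, 96])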
annotator/dsine/utils.py
ADDED
@@ -0,0 +1,104 @@
+""" utils
+"""
+
+import os
+import torch
+import numpy as np
+
+
+def load_checkpoint(fpath, model):
+    print("loading checkpoint... {}".format(fpath))
+
+    ckpt = torch.load(fpath, map_location="cpu")["model"]
+
+    load_dict = {}
+    for k, v in ckpt.items():
+        if k.startswith("module."):
+            k_ = k.replace("module.", "")
+            load_dict[k_] = v
+        else:
+            load_dict[k] = v
+
+    model.load_state_dict(load_dict)
+    print("loading checkpoint... / done")
+    return model
+
+
+def compute_normal_error(pred_norm, gt_norm):
+    pred_error = torch.cosine_similarity(pred_norm, gt_norm, dim=1)
+    pred_error = torch.clamp(pred_error, min=-1.0, max=1.0)
+    pred_error = torch.acos(pred_error) * 180.0 / np.pi
+    pred_error = pred_error.unsqueeze(1)  # (B, 1, H, W)
+    return pred_error
+
+
+def compute_normal_metrics(total_normal_errors):
+    total_normal_errors = total_normal_errors.detach().cpu().numpy()
+    num_pixels = total_normal_errors.shape[0]
+
+    metrics = {
+        "mean": np.average(total_normal_errors),
+        "median": np.median(total_normal_errors),
+        "rmse": np.sqrt(np.sum(total_normal_errors * total_normal_errors) / num_pixels),
+        "a1": 100.0 * (np.sum(total_normal_errors < 5) / num_pixels),
+        "a2": 100.0 * (np.sum(total_normal_errors < 7.5) / num_pixels),
+        "a3": 100.0 * (np.sum(total_normal_errors < 11.25) / num_pixels),
+        "a4": 100.0 * (np.sum(total_normal_errors < 22.5) / num_pixels),
+        "a5": 100.0 * (np.sum(total_normal_errors < 30) / num_pixels),
+    }
+
+    return metrics
+
+
+def pad_input(orig_H, orig_W):
+    if orig_W % 32 == 0:
+        l = 0
+        r = 0
+    else:
+        new_W = 32 * ((orig_W // 32) + 1)
+        l = (new_W - orig_W) // 2
+        r = (new_W - orig_W) - l
+
+    if orig_H % 32 == 0:
+        t = 0
+        b = 0
+    else:
+        new_H = 32 * ((orig_H // 32) + 1)
+        t = (new_H - orig_H) // 2
+        b = (new_H - orig_H) - t
+    return l, r, t, b
+
+
+def get_intrins_from_fov(new_fov, H, W, device):
+    # NOTE: top-left pixel should be (0,0)
+    if W >= H:
+        new_fu = (W / 2.0) / np.tan(np.deg2rad(new_fov / 2.0))
+        new_fv = (W / 2.0) / np.tan(np.deg2rad(new_fov / 2.0))
+    else:
+        new_fu = (H / 2.0) / np.tan(np.deg2rad(new_fov / 2.0))
+        new_fv = (H / 2.0) / np.tan(np.deg2rad(new_fov / 2.0))
+
+    new_cu = (W / 2.0) - 0.5
+    new_cv = (H / 2.0) - 0.5
+
+    new_intrins = torch.tensor(
+        [[new_fu, 0, new_cu], [0, new_fv, new_cv], [0, 0, 1]],
+        dtype=torch.float32,
+        device=device,
+    )
+
+    return new_intrins
+
+
+def get_intrins_from_txt(intrins_path, device):
+    # NOTE: top-left pixel should be (0,0)
+    with open(intrins_path, "r") as f:
+        intrins_ = f.readlines()[0].split()[0].split(",")
+    intrins_ = [float(i) for i in intrins_]
+    fx, fy, cx, cy = intrins_
+
+    intrins = torch.tensor(
+        [[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=torch.float32, device=device
+    )
+
+    return intrins
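pad_input rounds each dimension up to the next multiple of 32 (the encoder's total stride) and splits the padding as evenly as possible between the two sides. For example:

from annotator.dsine.utils import pad_input

# 800 is already a multiple of 32; 600 is padded up to 608 (4 px top, 4 px bottom)
l, r, t, b = pad_input(orig_H=600, orig_W=800)
print(l, r, t, b)  # 0 0 4 4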
annotator/dsine_local.py
ADDED
@@ -0,0 +1,63 @@
+import torch
+import numpy as np
+from PIL import Image
+from torchvision import transforms
+import torch.nn.functional as F
+from .dsine.dsine import DSINE
+from .dsine import utils as dsine_utils
+
+
+class NormalDetector:
+    def __init__(self, model_path):
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.model = DSINE()
+        self.model = dsine_utils.load_checkpoint(model_path, self.model)
+        self.normalize = transforms.Normalize(
+            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+        )
+        self.fov = 60
+
+    @torch.no_grad()
+    def __call__(self, image):
+        self.model.to(self.device)
+        self.model.pixel_coords = self.model.pixel_coords.to(self.device)
+
+        img = np.array(image).astype(np.float32) / 255.0
+        img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(self.device)
+        _, _, orig_H, orig_W = img.shape
+        l, r, t, b = dsine_utils.pad_input(orig_H, orig_W)
+        img = F.pad(img, (l, r, t, b), mode="constant", value=0.0)
+        img = self.normalize(img)
+        intrinsics = dsine_utils.get_intrins_from_fov(
+            new_fov=self.fov, H=orig_H, W=orig_W, device=self.device
+        ).unsqueeze(0)
+
+        intrinsics[:, 0, 2] += l
+        intrinsics[:, 1, 2] += t
+
+        pred_norm = self.model(img, intrins=intrinsics)[-1]
+        pred_norm = pred_norm[:, :, t : t + orig_H, l : l + orig_W]
+        pred_norm_np = (
+            pred_norm.cpu().detach().numpy()[0, :, :, :].transpose(1, 2, 0)
+        )  # (H, W, 3)
+        pred_norm_np = ((pred_norm_np + 1.0) / 2.0 * 255.0).astype(np.uint8)
+        normal_img = Image.fromarray(pred_norm_np).resize((orig_W, orig_H))
+
+        self.model.to("cpu")
+        self.model.pixel_coords = self.model.pixel_coords.to("cpu")
+        return normal_img
+
+
+if __name__ == "__main__":
+    from diffusers.utils import load_image
+
+    image = load_image(
+        "https://qhstaticssl.kujiale.com/image/jpeg/1716177580588/9AAA49344B9CE33512C4EBD0A287495F.jpg"
+    )
+    image = np.asarray(image)
+    normal_detector = NormalDetector(
+        model_path="/juicefs/training/models/open_source/dsine/dsine.pt",
+        # NOTE: __init__ takes only model_path; the encoder weights
+        # (tf_efficientnet_b5_ap-9e82fae8.pth) are fetched by geffnet via pretrained=True
+    )
+    normal_image = normal_detector(image)
+    normal_image.save("normal_image.jpg")
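Note that __call__ packs the normals into an RGB image via (n + 1) / 2 * 255. To recover approximate unit vectors from the returned PIL image (up to uint8 quantization), invert that mapping:

import numpy as np

normal_rgb = np.asarray(normal_image).astype(np.float32)
normal_vec = normal_rgb / 255.0 * 2.0 - 1.0  # (H, W, 3), roughly unit length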
app.py
CHANGED
@@ -7,10 +7,11 @@ from diffusers import (
     UniPCMultistepScheduler,
 )
 import gradio as gr
+from huggingface_hub import hf_hub_download
 
 from annotator.util import resize_image, HWC3
 from annotator.midas import DepthDetector
-from annotator.
+from annotator.dsine_local import NormalDetector
 from annotator.upernet import SegmDetector
 
 controlnet_checkpoint = "kujiale-ai/controlnet"
@@ -26,7 +27,9 @@ pipe = StableDiffusionControlNetPipeline.from_pretrained(
 pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
 
 apply_depth = DepthDetector()
-apply_normal = NormalDetector(
+apply_normal = NormalDetector(
+    hf_hub_download("camenduru/DSINE", filename="dsine.pt")
+)
 apply_segm = SegmDetector()
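For context, a sketch of how the newly wired apply_normal is presumably consumed further down in app.py. The make_normal_condition helper and the 512 resolution are hypothetical; resize_image and HWC3 are the annotator.util helpers already imported above:

import numpy as np

def make_normal_condition(input_image: np.ndarray, resolution: int = 512) -> np.ndarray:
    # hypothetical glue code: resize, run DSINE, return an HWC uint8 condition map
    img = resize_image(HWC3(input_image), resolution)
    normal_map = apply_normal(img)  # PIL.Image from NormalDetector.__call__
    return HWC3(np.asarray(normal_map))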