AP123 commited on
Commit
b4ec18d
1 Parent(s): bbd51ad

Upload 15 files

Browse files
Files changed (15) hide show
  1. .gitignore +19 -0
  2. LICENSE +21 -0
  3. cam_utils.py +146 -0
  4. grid_put.py +300 -0
  5. gs_renderer.py +820 -0
  6. main.py +882 -0
  7. main2.py +671 -0
  8. mesh.py +394 -0
  9. mesh_renderer.py +154 -0
  10. mesh_utils.py +147 -0
  11. process.py +92 -0
  12. readme.md +120 -0
  13. requirements.txt +33 -0
  14. sh_utils.py +118 -0
  15. zero123.py +666 -0
.gitignore ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ build/
3
+ *.egg-info/
4
+ *.so
5
+ venv_*/
6
+ .vs/
7
+ .vscode/
8
+
9
+ tmp_*
10
+ data?
11
+ data??
12
+ scripts2
13
+
14
+ model_cache
15
+
16
+ logs
17
+ videos
18
+ images
19
+ *.mp4
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 dreamgaussian
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
cam_utils.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from scipy.spatial.transform import Rotation as R
3
+
4
+ import torch
5
+
6
+ def dot(x, y):
7
+ if isinstance(x, np.ndarray):
8
+ return np.sum(x * y, -1, keepdims=True)
9
+ else:
10
+ return torch.sum(x * y, -1, keepdim=True)
11
+
12
+
13
+ def length(x, eps=1e-20):
14
+ if isinstance(x, np.ndarray):
15
+ return np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps))
16
+ else:
17
+ return torch.sqrt(torch.clamp(dot(x, x), min=eps))
18
+
19
+
20
+ def safe_normalize(x, eps=1e-20):
21
+ return x / length(x, eps)
22
+
23
+
24
+ def look_at(campos, target, opengl=True):
25
+ # campos: [N, 3], camera/eye position
26
+ # target: [N, 3], object to look at
27
+ # return: [N, 3, 3], rotation matrix
28
+ if not opengl:
29
+ # camera forward aligns with -z
30
+ forward_vector = safe_normalize(target - campos)
31
+ up_vector = np.array([0, 1, 0], dtype=np.float32)
32
+ right_vector = safe_normalize(np.cross(forward_vector, up_vector))
33
+ up_vector = safe_normalize(np.cross(right_vector, forward_vector))
34
+ else:
35
+ # camera forward aligns with +z
36
+ forward_vector = safe_normalize(campos - target)
37
+ up_vector = np.array([0, 1, 0], dtype=np.float32)
38
+ right_vector = safe_normalize(np.cross(up_vector, forward_vector))
39
+ up_vector = safe_normalize(np.cross(forward_vector, right_vector))
40
+ R = np.stack([right_vector, up_vector, forward_vector], axis=1)
41
+ return R
42
+
43
+
44
+ # elevation & azimuth to pose (cam2world) matrix
45
+ def orbit_camera(elevation, azimuth, radius=1, is_degree=True, target=None, opengl=True):
46
+ # radius: scalar
47
+ # elevation: scalar, in (-90, 90), from +y to -y is (-90, 90)
48
+ # azimuth: scalar, in (-180, 180), from +z to +x is (0, 90)
49
+ # return: [4, 4], camera pose matrix
50
+ if is_degree:
51
+ elevation = np.deg2rad(elevation)
52
+ azimuth = np.deg2rad(azimuth)
53
+ x = radius * np.cos(elevation) * np.sin(azimuth)
54
+ y = - radius * np.sin(elevation)
55
+ z = radius * np.cos(elevation) * np.cos(azimuth)
56
+ if target is None:
57
+ target = np.zeros([3], dtype=np.float32)
58
+ campos = np.array([x, y, z]) + target # [3]
59
+ T = np.eye(4, dtype=np.float32)
60
+ T[:3, :3] = look_at(campos, target, opengl)
61
+ T[:3, 3] = campos
62
+ return T
63
+
64
+
65
+ class OrbitCamera:
66
+ def __init__(self, W, H, r=2, fovy=60, near=0.01, far=100):
67
+ self.W = W
68
+ self.H = H
69
+ self.radius = r # camera distance from center
70
+ self.fovy = np.deg2rad(fovy) # deg 2 rad
71
+ self.near = near
72
+ self.far = far
73
+ self.center = np.array([0, 0, 0], dtype=np.float32) # look at this point
74
+ self.rot = R.from_matrix(np.eye(3))
75
+ self.up = np.array([0, 1, 0], dtype=np.float32) # need to be normalized!
76
+
77
+ @property
78
+ def fovx(self):
79
+ return 2 * np.arctan(np.tan(self.fovy / 2) * self.W / self.H)
80
+
81
+ @property
82
+ def campos(self):
83
+ return self.pose[:3, 3]
84
+
85
+ # pose (c2w)
86
+ @property
87
+ def pose(self):
88
+ # first move camera to radius
89
+ res = np.eye(4, dtype=np.float32)
90
+ res[2, 3] = self.radius # opengl convention...
91
+ # rotate
92
+ rot = np.eye(4, dtype=np.float32)
93
+ rot[:3, :3] = self.rot.as_matrix()
94
+ res = rot @ res
95
+ # translate
96
+ res[:3, 3] -= self.center
97
+ return res
98
+
99
+ # view (w2c)
100
+ @property
101
+ def view(self):
102
+ return np.linalg.inv(self.pose)
103
+
104
+ # projection (perspective)
105
+ @property
106
+ def perspective(self):
107
+ y = np.tan(self.fovy / 2)
108
+ aspect = self.W / self.H
109
+ return np.array(
110
+ [
111
+ [1 / (y * aspect), 0, 0, 0],
112
+ [0, -1 / y, 0, 0],
113
+ [
114
+ 0,
115
+ 0,
116
+ -(self.far + self.near) / (self.far - self.near),
117
+ -(2 * self.far * self.near) / (self.far - self.near),
118
+ ],
119
+ [0, 0, -1, 0],
120
+ ],
121
+ dtype=np.float32,
122
+ )
123
+
124
+ # intrinsics
125
+ @property
126
+ def intrinsics(self):
127
+ focal = self.H / (2 * np.tan(self.fovy / 2))
128
+ return np.array([focal, focal, self.W // 2, self.H // 2], dtype=np.float32)
129
+
130
+ @property
131
+ def mvp(self):
132
+ return self.perspective @ np.linalg.inv(self.pose) # [4, 4]
133
+
134
+ def orbit(self, dx, dy):
135
+ # rotate along camera up/side axis!
136
+ side = self.rot.as_matrix()[:3, 0]
137
+ rotvec_x = self.up * np.radians(-0.05 * dx)
138
+ rotvec_y = side * np.radians(-0.05 * dy)
139
+ self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot
140
+
141
+ def scale(self, delta):
142
+ self.radius *= 1.1 ** (-delta)
143
+
144
+ def pan(self, dx, dy, dz=0):
145
+ # pan in camera coordinate system (careful on the sensitivity!)
146
+ self.center += 0.0005 * self.rot.as_matrix()[:3, :3] @ np.array([-dx, -dy, dz])
grid_put.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ def stride_from_shape(shape):
5
+ stride = [1]
6
+ for x in reversed(shape[1:]):
7
+ stride.append(stride[-1] * x)
8
+ return list(reversed(stride))
9
+
10
+
11
+ def scatter_add_nd(input, indices, values):
12
+ # input: [..., C], D dimension + C channel
13
+ # indices: [N, D], long
14
+ # values: [N, C]
15
+
16
+ D = indices.shape[-1]
17
+ C = input.shape[-1]
18
+ size = input.shape[:-1]
19
+ stride = stride_from_shape(size)
20
+
21
+ assert len(size) == D
22
+
23
+ input = input.view(-1, C) # [HW, C]
24
+ flatten_indices = (indices * torch.tensor(stride, dtype=torch.long, device=indices.device)).sum(-1) # [N]
25
+
26
+ input.scatter_add_(0, flatten_indices.unsqueeze(1).repeat(1, C), values)
27
+
28
+ return input.view(*size, C)
29
+
30
+
31
+ def scatter_add_nd_with_count(input, count, indices, values, weights=None):
32
+ # input: [..., C], D dimension + C channel
33
+ # count: [..., 1], D dimension
34
+ # indices: [N, D], long
35
+ # values: [N, C]
36
+
37
+ D = indices.shape[-1]
38
+ C = input.shape[-1]
39
+ size = input.shape[:-1]
40
+ stride = stride_from_shape(size)
41
+
42
+ assert len(size) == D
43
+
44
+ input = input.view(-1, C) # [HW, C]
45
+ count = count.view(-1, 1)
46
+
47
+ flatten_indices = (indices * torch.tensor(stride, dtype=torch.long, device=indices.device)).sum(-1) # [N]
48
+
49
+ if weights is None:
50
+ weights = torch.ones_like(values[..., :1])
51
+
52
+ input.scatter_add_(0, flatten_indices.unsqueeze(1).repeat(1, C), values)
53
+ count.scatter_add_(0, flatten_indices.unsqueeze(1), weights)
54
+
55
+ return input.view(*size, C), count.view(*size, 1)
56
+
57
+ def nearest_grid_put_2d(H, W, coords, values, return_count=False):
58
+ # coords: [N, 2], float in [-1, 1]
59
+ # values: [N, C]
60
+
61
+ C = values.shape[-1]
62
+
63
+ indices = (coords * 0.5 + 0.5) * torch.tensor(
64
+ [H - 1, W - 1], dtype=torch.float32, device=coords.device
65
+ )
66
+ indices = indices.round().long() # [N, 2]
67
+
68
+ result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype) # [H, W, C]
69
+ count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype) # [H, W, 1]
70
+ weights = torch.ones_like(values[..., :1]) # [N, 1]
71
+
72
+ result, count = scatter_add_nd_with_count(result, count, indices, values, weights)
73
+
74
+ if return_count:
75
+ return result, count
76
+
77
+ mask = (count.squeeze(-1) > 0)
78
+ result[mask] = result[mask] / count[mask].repeat(1, C)
79
+
80
+ return result
81
+
82
+
83
+ def linear_grid_put_2d(H, W, coords, values, return_count=False):
84
+ # coords: [N, 2], float in [-1, 1]
85
+ # values: [N, C]
86
+
87
+ C = values.shape[-1]
88
+
89
+ indices = (coords * 0.5 + 0.5) * torch.tensor(
90
+ [H - 1, W - 1], dtype=torch.float32, device=coords.device
91
+ )
92
+ indices_00 = indices.floor().long() # [N, 2]
93
+ indices_00[:, 0].clamp_(0, H - 2)
94
+ indices_00[:, 1].clamp_(0, W - 2)
95
+ indices_01 = indices_00 + torch.tensor(
96
+ [0, 1], dtype=torch.long, device=indices.device
97
+ )
98
+ indices_10 = indices_00 + torch.tensor(
99
+ [1, 0], dtype=torch.long, device=indices.device
100
+ )
101
+ indices_11 = indices_00 + torch.tensor(
102
+ [1, 1], dtype=torch.long, device=indices.device
103
+ )
104
+
105
+ h = indices[..., 0] - indices_00[..., 0].float()
106
+ w = indices[..., 1] - indices_00[..., 1].float()
107
+ w_00 = (1 - h) * (1 - w)
108
+ w_01 = (1 - h) * w
109
+ w_10 = h * (1 - w)
110
+ w_11 = h * w
111
+
112
+ result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype) # [H, W, C]
113
+ count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype) # [H, W, 1]
114
+ weights = torch.ones_like(values[..., :1]) # [N, 1]
115
+
116
+ result, count = scatter_add_nd_with_count(result, count, indices_00, values * w_00.unsqueeze(1), weights* w_00.unsqueeze(1))
117
+ result, count = scatter_add_nd_with_count(result, count, indices_01, values * w_01.unsqueeze(1), weights* w_01.unsqueeze(1))
118
+ result, count = scatter_add_nd_with_count(result, count, indices_10, values * w_10.unsqueeze(1), weights* w_10.unsqueeze(1))
119
+ result, count = scatter_add_nd_with_count(result, count, indices_11, values * w_11.unsqueeze(1), weights* w_11.unsqueeze(1))
120
+
121
+ if return_count:
122
+ return result, count
123
+
124
+ mask = (count.squeeze(-1) > 0)
125
+ result[mask] = result[mask] / count[mask].repeat(1, C)
126
+
127
+ return result
128
+
129
+ def mipmap_linear_grid_put_2d(H, W, coords, values, min_resolution=32, return_count=False):
130
+ # coords: [N, 2], float in [-1, 1]
131
+ # values: [N, C]
132
+
133
+ C = values.shape[-1]
134
+
135
+ result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype) # [H, W, C]
136
+ count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype) # [H, W, 1]
137
+
138
+ cur_H, cur_W = H, W
139
+
140
+ while min(cur_H, cur_W) > min_resolution:
141
+
142
+ # try to fill the holes
143
+ mask = (count.squeeze(-1) == 0)
144
+ if not mask.any():
145
+ break
146
+
147
+ cur_result, cur_count = linear_grid_put_2d(cur_H, cur_W, coords, values, return_count=True)
148
+ result[mask] = result[mask] + F.interpolate(cur_result.permute(2,0,1).unsqueeze(0).contiguous(), (H, W), mode='bilinear', align_corners=False).squeeze(0).permute(1,2,0).contiguous()[mask]
149
+ count[mask] = count[mask] + F.interpolate(cur_count.view(1, 1, cur_H, cur_W), (H, W), mode='bilinear', align_corners=False).view(H, W, 1)[mask]
150
+ cur_H //= 2
151
+ cur_W //= 2
152
+
153
+ if return_count:
154
+ return result, count
155
+
156
+ mask = (count.squeeze(-1) > 0)
157
+ result[mask] = result[mask] / count[mask].repeat(1, C)
158
+
159
+ return result
160
+
161
+ def nearest_grid_put_3d(H, W, D, coords, values, return_count=False):
162
+ # coords: [N, 3], float in [-1, 1]
163
+ # values: [N, C]
164
+
165
+ C = values.shape[-1]
166
+
167
+ indices = (coords * 0.5 + 0.5) * torch.tensor(
168
+ [H - 1, W - 1, D - 1], dtype=torch.float32, device=coords.device
169
+ )
170
+ indices = indices.round().long() # [N, 2]
171
+
172
+ result = torch.zeros(H, W, D, C, device=values.device, dtype=values.dtype) # [H, W, C]
173
+ count = torch.zeros(H, W, D, 1, device=values.device, dtype=values.dtype) # [H, W, 1]
174
+ weights = torch.ones_like(values[..., :1]) # [N, 1]
175
+
176
+ result, count = scatter_add_nd_with_count(result, count, indices, values, weights)
177
+
178
+ if return_count:
179
+ return result, count
180
+
181
+ mask = (count.squeeze(-1) > 0)
182
+ result[mask] = result[mask] / count[mask].repeat(1, C)
183
+
184
+ return result
185
+
186
+
187
+ def linear_grid_put_3d(H, W, D, coords, values, return_count=False):
188
+ # coords: [N, 3], float in [-1, 1]
189
+ # values: [N, C]
190
+
191
+ C = values.shape[-1]
192
+
193
+ indices = (coords * 0.5 + 0.5) * torch.tensor(
194
+ [H - 1, W - 1, D - 1], dtype=torch.float32, device=coords.device
195
+ )
196
+ indices_000 = indices.floor().long() # [N, 3]
197
+ indices_000[:, 0].clamp_(0, H - 2)
198
+ indices_000[:, 1].clamp_(0, W - 2)
199
+ indices_000[:, 2].clamp_(0, D - 2)
200
+
201
+ indices_001 = indices_000 + torch.tensor([0, 0, 1], dtype=torch.long, device=indices.device)
202
+ indices_010 = indices_000 + torch.tensor([0, 1, 0], dtype=torch.long, device=indices.device)
203
+ indices_011 = indices_000 + torch.tensor([0, 1, 1], dtype=torch.long, device=indices.device)
204
+ indices_100 = indices_000 + torch.tensor([1, 0, 0], dtype=torch.long, device=indices.device)
205
+ indices_101 = indices_000 + torch.tensor([1, 0, 1], dtype=torch.long, device=indices.device)
206
+ indices_110 = indices_000 + torch.tensor([1, 1, 0], dtype=torch.long, device=indices.device)
207
+ indices_111 = indices_000 + torch.tensor([1, 1, 1], dtype=torch.long, device=indices.device)
208
+
209
+ h = indices[..., 0] - indices_000[..., 0].float()
210
+ w = indices[..., 1] - indices_000[..., 1].float()
211
+ d = indices[..., 2] - indices_000[..., 2].float()
212
+
213
+ w_000 = (1 - h) * (1 - w) * (1 - d)
214
+ w_001 = (1 - h) * w * (1 - d)
215
+ w_010 = h * (1 - w) * (1 - d)
216
+ w_011 = h * w * (1 - d)
217
+ w_100 = (1 - h) * (1 - w) * d
218
+ w_101 = (1 - h) * w * d
219
+ w_110 = h * (1 - w) * d
220
+ w_111 = h * w * d
221
+
222
+ result = torch.zeros(H, W, D, C, device=values.device, dtype=values.dtype) # [H, W, D, C]
223
+ count = torch.zeros(H, W, D, 1, device=values.device, dtype=values.dtype) # [H, W, D, 1]
224
+ weights = torch.ones_like(values[..., :1]) # [N, 1]
225
+
226
+ result, count = scatter_add_nd_with_count(result, count, indices_000, values * w_000.unsqueeze(1), weights * w_000.unsqueeze(1))
227
+ result, count = scatter_add_nd_with_count(result, count, indices_001, values * w_001.unsqueeze(1), weights * w_001.unsqueeze(1))
228
+ result, count = scatter_add_nd_with_count(result, count, indices_010, values * w_010.unsqueeze(1), weights * w_010.unsqueeze(1))
229
+ result, count = scatter_add_nd_with_count(result, count, indices_011, values * w_011.unsqueeze(1), weights * w_011.unsqueeze(1))
230
+ result, count = scatter_add_nd_with_count(result, count, indices_100, values * w_100.unsqueeze(1), weights * w_100.unsqueeze(1))
231
+ result, count = scatter_add_nd_with_count(result, count, indices_101, values * w_101.unsqueeze(1), weights * w_101.unsqueeze(1))
232
+ result, count = scatter_add_nd_with_count(result, count, indices_110, values * w_110.unsqueeze(1), weights * w_110.unsqueeze(1))
233
+ result, count = scatter_add_nd_with_count(result, count, indices_111, values * w_111.unsqueeze(1), weights * w_111.unsqueeze(1))
234
+
235
+ if return_count:
236
+ return result, count
237
+
238
+ mask = (count.squeeze(-1) > 0)
239
+ result[mask] = result[mask] / count[mask].repeat(1, C)
240
+
241
+ return result
242
+
243
+ def mipmap_linear_grid_put_3d(H, W, D, coords, values, min_resolution=32, return_count=False):
244
+ # coords: [N, 3], float in [-1, 1]
245
+ # values: [N, C]
246
+
247
+ C = values.shape[-1]
248
+
249
+ result = torch.zeros(H, W, D, C, device=values.device, dtype=values.dtype) # [H, W, D, C]
250
+ count = torch.zeros(H, W, D, 1, device=values.device, dtype=values.dtype) # [H, W, D, 1]
251
+ cur_H, cur_W, cur_D = H, W, D
252
+
253
+ while min(min(cur_H, cur_W), cur_D) > min_resolution:
254
+
255
+ # try to fill the holes
256
+ mask = (count.squeeze(-1) == 0)
257
+ if not mask.any():
258
+ break
259
+
260
+ cur_result, cur_count = linear_grid_put_3d(cur_H, cur_W, cur_D, coords, values, return_count=True)
261
+ result[mask] = result[mask] + F.interpolate(cur_result.permute(3,0,1,2).unsqueeze(0).contiguous(), (H, W, D), mode='trilinear', align_corners=False).squeeze(0).permute(1,2,3,0).contiguous()[mask]
262
+ count[mask] = count[mask] + F.interpolate(cur_count.view(1, 1, cur_H, cur_W, cur_D), (H, W, D), mode='trilinear', align_corners=False).view(H, W, D, 1)[mask]
263
+ cur_H //= 2
264
+ cur_W //= 2
265
+ cur_D //= 2
266
+
267
+ if return_count:
268
+ return result, count
269
+
270
+ mask = (count.squeeze(-1) > 0)
271
+ result[mask] = result[mask] / count[mask].repeat(1, C)
272
+
273
+ return result
274
+
275
+
276
+ def grid_put(shape, coords, values, mode='linear-mipmap', min_resolution=32, return_raw=False):
277
+ # shape: [D], list/tuple
278
+ # coords: [N, D], float in [-1, 1]
279
+ # values: [N, C]
280
+
281
+ D = len(shape)
282
+ assert D in [2, 3], f'only support D == 2 or 3, but got D == {D}'
283
+
284
+ if mode == 'nearest':
285
+ if D == 2:
286
+ return nearest_grid_put_2d(*shape, coords, values, return_raw)
287
+ else:
288
+ return nearest_grid_put_3d(*shape, coords, values, return_raw)
289
+ elif mode == 'linear':
290
+ if D == 2:
291
+ return linear_grid_put_2d(*shape, coords, values, return_raw)
292
+ else:
293
+ return linear_grid_put_3d(*shape, coords, values, return_raw)
294
+ elif mode == 'linear-mipmap':
295
+ if D == 2:
296
+ return mipmap_linear_grid_put_2d(*shape, coords, values, min_resolution, return_raw)
297
+ else:
298
+ return mipmap_linear_grid_put_3d(*shape, coords, values, min_resolution, return_raw)
299
+ else:
300
+ raise NotImplementedError(f"got mode {mode}")
gs_renderer.py ADDED
@@ -0,0 +1,820 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import numpy as np
4
+ from typing import NamedTuple
5
+ from plyfile import PlyData, PlyElement
6
+
7
+ import torch
8
+ from torch import nn
9
+
10
+ from diff_gaussian_rasterization import (
11
+ GaussianRasterizationSettings,
12
+ GaussianRasterizer,
13
+ )
14
+ from simple_knn._C import distCUDA2
15
+
16
+ from sh_utils import eval_sh, SH2RGB, RGB2SH
17
+ from mesh import Mesh
18
+ from mesh_utils import decimate_mesh, clean_mesh
19
+
20
+ import kiui
21
+
22
+ def inverse_sigmoid(x):
23
+ return torch.log(x/(1-x))
24
+
25
+ def get_expon_lr_func(
26
+ lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
27
+ ):
28
+
29
+ def helper(step):
30
+ if lr_init == lr_final:
31
+ # constant lr, ignore other params
32
+ return lr_init
33
+ if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
34
+ # Disable this parameter
35
+ return 0.0
36
+ if lr_delay_steps > 0:
37
+ # A kind of reverse cosine decay.
38
+ delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
39
+ 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
40
+ )
41
+ else:
42
+ delay_rate = 1.0
43
+ t = np.clip(step / max_steps, 0, 1)
44
+ log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
45
+ return delay_rate * log_lerp
46
+
47
+ return helper
48
+
49
+
50
+ def strip_lowerdiag(L):
51
+ uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
52
+
53
+ uncertainty[:, 0] = L[:, 0, 0]
54
+ uncertainty[:, 1] = L[:, 0, 1]
55
+ uncertainty[:, 2] = L[:, 0, 2]
56
+ uncertainty[:, 3] = L[:, 1, 1]
57
+ uncertainty[:, 4] = L[:, 1, 2]
58
+ uncertainty[:, 5] = L[:, 2, 2]
59
+ return uncertainty
60
+
61
+ def strip_symmetric(sym):
62
+ return strip_lowerdiag(sym)
63
+
64
+ def gaussian_3d_coeff(xyzs, covs):
65
+ # xyzs: [N, 3]
66
+ # covs: [N, 6]
67
+ x, y, z = xyzs[:, 0], xyzs[:, 1], xyzs[:, 2]
68
+ a, b, c, d, e, f = covs[:, 0], covs[:, 1], covs[:, 2], covs[:, 3], covs[:, 4], covs[:, 5]
69
+
70
+ # eps must be small enough !!!
71
+ inv_det = 1 / (a * d * f + 2 * e * c * b - e**2 * a - c**2 * d - b**2 * f + 1e-24)
72
+ inv_a = (d * f - e**2) * inv_det
73
+ inv_b = (e * c - b * f) * inv_det
74
+ inv_c = (e * b - c * d) * inv_det
75
+ inv_d = (a * f - c**2) * inv_det
76
+ inv_e = (b * c - e * a) * inv_det
77
+ inv_f = (a * d - b**2) * inv_det
78
+
79
+ power = -0.5 * (x**2 * inv_a + y**2 * inv_d + z**2 * inv_f) - x * y * inv_b - x * z * inv_c - y * z * inv_e
80
+
81
+ power[power > 0] = -1e10 # abnormal values... make weights 0
82
+
83
+ return torch.exp(power)
84
+
85
+ def build_rotation(r):
86
+ norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
87
+
88
+ q = r / norm[:, None]
89
+
90
+ R = torch.zeros((q.size(0), 3, 3), device='cuda')
91
+
92
+ r = q[:, 0]
93
+ x = q[:, 1]
94
+ y = q[:, 2]
95
+ z = q[:, 3]
96
+
97
+ R[:, 0, 0] = 1 - 2 * (y*y + z*z)
98
+ R[:, 0, 1] = 2 * (x*y - r*z)
99
+ R[:, 0, 2] = 2 * (x*z + r*y)
100
+ R[:, 1, 0] = 2 * (x*y + r*z)
101
+ R[:, 1, 1] = 1 - 2 * (x*x + z*z)
102
+ R[:, 1, 2] = 2 * (y*z - r*x)
103
+ R[:, 2, 0] = 2 * (x*z - r*y)
104
+ R[:, 2, 1] = 2 * (y*z + r*x)
105
+ R[:, 2, 2] = 1 - 2 * (x*x + y*y)
106
+ return R
107
+
108
+ def build_scaling_rotation(s, r):
109
+ L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
110
+ R = build_rotation(r)
111
+
112
+ L[:,0,0] = s[:,0]
113
+ L[:,1,1] = s[:,1]
114
+ L[:,2,2] = s[:,2]
115
+
116
+ L = R @ L
117
+ return L
118
+
119
+ class BasicPointCloud(NamedTuple):
120
+ points: np.array
121
+ colors: np.array
122
+ normals: np.array
123
+
124
+
125
+ class GaussianModel:
126
+
127
+ def setup_functions(self):
128
+ def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
129
+ L = build_scaling_rotation(scaling_modifier * scaling, rotation)
130
+ actual_covariance = L @ L.transpose(1, 2)
131
+ symm = strip_symmetric(actual_covariance)
132
+ return symm
133
+
134
+ self.scaling_activation = torch.exp
135
+ self.scaling_inverse_activation = torch.log
136
+
137
+ self.covariance_activation = build_covariance_from_scaling_rotation
138
+
139
+ self.opacity_activation = torch.sigmoid
140
+ self.inverse_opacity_activation = inverse_sigmoid
141
+
142
+ self.rotation_activation = torch.nn.functional.normalize
143
+
144
+
145
+ def __init__(self, sh_degree : int):
146
+ self.active_sh_degree = 0
147
+ self.max_sh_degree = sh_degree
148
+ self._xyz = torch.empty(0)
149
+ self._features_dc = torch.empty(0)
150
+ self._features_rest = torch.empty(0)
151
+ self._scaling = torch.empty(0)
152
+ self._rotation = torch.empty(0)
153
+ self._opacity = torch.empty(0)
154
+ self.max_radii2D = torch.empty(0)
155
+ self.xyz_gradient_accum = torch.empty(0)
156
+ self.denom = torch.empty(0)
157
+ self.optimizer = None
158
+ self.percent_dense = 0
159
+ self.spatial_lr_scale = 0
160
+ self.setup_functions()
161
+
162
+ def capture(self):
163
+ return (
164
+ self.active_sh_degree,
165
+ self._xyz,
166
+ self._features_dc,
167
+ self._features_rest,
168
+ self._scaling,
169
+ self._rotation,
170
+ self._opacity,
171
+ self.max_radii2D,
172
+ self.xyz_gradient_accum,
173
+ self.denom,
174
+ self.optimizer.state_dict(),
175
+ self.spatial_lr_scale,
176
+ )
177
+
178
+ def restore(self, model_args, training_args):
179
+ (self.active_sh_degree,
180
+ self._xyz,
181
+ self._features_dc,
182
+ self._features_rest,
183
+ self._scaling,
184
+ self._rotation,
185
+ self._opacity,
186
+ self.max_radii2D,
187
+ xyz_gradient_accum,
188
+ denom,
189
+ opt_dict,
190
+ self.spatial_lr_scale) = model_args
191
+ self.training_setup(training_args)
192
+ self.xyz_gradient_accum = xyz_gradient_accum
193
+ self.denom = denom
194
+ self.optimizer.load_state_dict(opt_dict)
195
+
196
+ @property
197
+ def get_scaling(self):
198
+ return self.scaling_activation(self._scaling)
199
+
200
+ @property
201
+ def get_rotation(self):
202
+ return self.rotation_activation(self._rotation)
203
+
204
+ @property
205
+ def get_xyz(self):
206
+ return self._xyz
207
+
208
+ @property
209
+ def get_features(self):
210
+ features_dc = self._features_dc
211
+ features_rest = self._features_rest
212
+ return torch.cat((features_dc, features_rest), dim=1)
213
+
214
+ @property
215
+ def get_opacity(self):
216
+ return self.opacity_activation(self._opacity)
217
+
218
+ @torch.no_grad()
219
+ def extract_fields(self, resolution=128, num_blocks=16, relax_ratio=1.5):
220
+ # resolution: resolution of field
221
+
222
+ block_size = 2 / num_blocks
223
+
224
+ assert resolution % block_size == 0
225
+ split_size = resolution // num_blocks
226
+
227
+ opacities = self.get_opacity
228
+
229
+ # pre-filter low opacity gaussians to save computation
230
+ mask = (opacities > 0.005).squeeze(1)
231
+
232
+ opacities = opacities[mask]
233
+ xyzs = self.get_xyz[mask]
234
+ stds = self.get_scaling[mask]
235
+
236
+ # normalize to ~ [-1, 1]
237
+ mn, mx = xyzs.amin(0), xyzs.amax(0)
238
+ self.center = (mn + mx) / 2
239
+ self.scale = 1.8 / (mx - mn).amax().item()
240
+
241
+ xyzs = (xyzs - self.center) * self.scale
242
+ stds = stds * self.scale
243
+
244
+ covs = self.covariance_activation(stds, 1, self._rotation[mask])
245
+
246
+ # tile
247
+ device = opacities.device
248
+ occ = torch.zeros([resolution] * 3, dtype=torch.float32, device=device)
249
+
250
+ X = torch.linspace(-1, 1, resolution).split(split_size)
251
+ Y = torch.linspace(-1, 1, resolution).split(split_size)
252
+ Z = torch.linspace(-1, 1, resolution).split(split_size)
253
+
254
+
255
+ # loop blocks (assume max size of gaussian is small than relax_ratio * block_size !!!)
256
+ for xi, xs in enumerate(X):
257
+ for yi, ys in enumerate(Y):
258
+ for zi, zs in enumerate(Z):
259
+ xx, yy, zz = torch.meshgrid(xs, ys, zs)
260
+ # sample points [M, 3]
261
+ pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1).to(device)
262
+ # in-tile gaussians mask
263
+ vmin, vmax = pts.amin(0), pts.amax(0)
264
+ vmin -= block_size * relax_ratio
265
+ vmax += block_size * relax_ratio
266
+ mask = (xyzs < vmax).all(-1) & (xyzs > vmin).all(-1)
267
+ # if hit no gaussian, continue to next block
268
+ if not mask.any():
269
+ continue
270
+ mask_xyzs = xyzs[mask] # [L, 3]
271
+ mask_covs = covs[mask] # [L, 6]
272
+ mask_opas = opacities[mask].view(1, -1) # [L, 1] --> [1, L]
273
+
274
+ # query per point-gaussian pair.
275
+ g_pts = pts.unsqueeze(1).repeat(1, mask_covs.shape[0], 1) - mask_xyzs.unsqueeze(0) # [M, L, 3]
276
+ g_covs = mask_covs.unsqueeze(0).repeat(pts.shape[0], 1, 1) # [M, L, 6]
277
+
278
+ # batch on gaussian to avoid OOM
279
+ batch_g = 1024
280
+ val = 0
281
+ for start in range(0, g_covs.shape[1], batch_g):
282
+ end = min(start + batch_g, g_covs.shape[1])
283
+ w = gaussian_3d_coeff(g_pts[:, start:end].reshape(-1, 3), g_covs[:, start:end].reshape(-1, 6)).reshape(pts.shape[0], -1) # [M, l]
284
+ val += (mask_opas[:, start:end] * w).sum(-1)
285
+
286
+ # kiui.lo(val, mask_opas, w)
287
+
288
+ occ[xi * split_size: xi * split_size + len(xs),
289
+ yi * split_size: yi * split_size + len(ys),
290
+ zi * split_size: zi * split_size + len(zs)] = val.reshape(len(xs), len(ys), len(zs))
291
+
292
+ kiui.lo(occ, verbose=1)
293
+
294
+ return occ
295
+
296
+ def extract_mesh(self, path, density_thresh=1, resolution=128, decimate_target=1e5):
297
+
298
+ os.makedirs(os.path.dirname(path), exist_ok=True)
299
+
300
+ occ = self.extract_fields(resolution).detach().cpu().numpy()
301
+
302
+ import mcubes
303
+ vertices, triangles = mcubes.marching_cubes(occ, density_thresh)
304
+ vertices = vertices / (resolution - 1.0) * 2 - 1
305
+
306
+ # transform back to the original space
307
+ vertices = vertices / self.scale + self.center.detach().cpu().numpy()
308
+
309
+ vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=0.015)
310
+ if decimate_target > 0 and triangles.shape[0] > decimate_target:
311
+ vertices, triangles = decimate_mesh(vertices, triangles, decimate_target)
312
+
313
+ v = torch.from_numpy(vertices.astype(np.float32)).contiguous().cuda()
314
+ f = torch.from_numpy(triangles.astype(np.int32)).contiguous().cuda()
315
+
316
+ print(
317
+ f"[INFO] marching cubes result: {v.shape} ({v.min().item()}-{v.max().item()}), {f.shape}"
318
+ )
319
+
320
+ mesh = Mesh(v=v, f=f, device='cuda')
321
+
322
+ return mesh
323
+
324
+ def get_covariance(self, scaling_modifier = 1):
325
+ return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)
326
+
327
+ def oneupSHdegree(self):
328
+ if self.active_sh_degree < self.max_sh_degree:
329
+ self.active_sh_degree += 1
330
+
331
+ def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float = 1):
332
+ self.spatial_lr_scale = spatial_lr_scale
333
+ fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
334
+ fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
335
+ features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda()
336
+ features[:, :3, 0 ] = fused_color
337
+ features[:, 3:, 1:] = 0.0
338
+
339
+ print("Number of points at initialisation : ", fused_point_cloud.shape[0])
340
+
341
+ dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)
342
+ scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3)
343
+ rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
344
+ rots[:, 0] = 1
345
+
346
+ opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
347
+
348
+ self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
349
+ self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
350
+ self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))
351
+ self._scaling = nn.Parameter(scales.requires_grad_(True))
352
+ self._rotation = nn.Parameter(rots.requires_grad_(True))
353
+ self._opacity = nn.Parameter(opacities.requires_grad_(True))
354
+ self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
355
+
356
+ def training_setup(self, training_args):
357
+ self.percent_dense = training_args.percent_dense
358
+ self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
359
+ self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
360
+
361
+ l = [
362
+ {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"},
363
+ {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"},
364
+ {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"},
365
+ {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"},
366
+ {'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"},
367
+ {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"}
368
+ ]
369
+
370
+ self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
371
+ self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale,
372
+ lr_final=training_args.position_lr_final*self.spatial_lr_scale,
373
+ lr_delay_mult=training_args.position_lr_delay_mult,
374
+ max_steps=training_args.position_lr_max_steps)
375
+
376
+ def update_learning_rate(self, iteration):
377
+ ''' Learning rate scheduling per step '''
378
+ for param_group in self.optimizer.param_groups:
379
+ if param_group["name"] == "xyz":
380
+ lr = self.xyz_scheduler_args(iteration)
381
+ param_group['lr'] = lr
382
+ return lr
383
+
384
+ def construct_list_of_attributes(self):
385
+ l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
386
+ # All channels except the 3 DC
387
+ for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
388
+ l.append('f_dc_{}'.format(i))
389
+ for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]):
390
+ l.append('f_rest_{}'.format(i))
391
+ l.append('opacity')
392
+ for i in range(self._scaling.shape[1]):
393
+ l.append('scale_{}'.format(i))
394
+ for i in range(self._rotation.shape[1]):
395
+ l.append('rot_{}'.format(i))
396
+ return l
397
+
398
+ def save_ply(self, path):
399
+ os.makedirs(os.path.dirname(path), exist_ok=True)
400
+
401
+ xyz = self._xyz.detach().cpu().numpy()
402
+ normals = np.zeros_like(xyz)
403
+ f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
404
+ f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
405
+ opacities = self._opacity.detach().cpu().numpy()
406
+ scale = self._scaling.detach().cpu().numpy()
407
+ rotation = self._rotation.detach().cpu().numpy()
408
+
409
+ dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
410
+
411
+ elements = np.empty(xyz.shape[0], dtype=dtype_full)
412
+ attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
413
+ elements[:] = list(map(tuple, attributes))
414
+ el = PlyElement.describe(elements, 'vertex')
415
+ PlyData([el]).write(path)
416
+
417
+ def reset_opacity(self):
418
+ opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01))
419
+ optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
420
+ self._opacity = optimizable_tensors["opacity"]
421
+
422
+ def load_ply(self, path):
423
+ plydata = PlyData.read(path)
424
+
425
+ xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
426
+ np.asarray(plydata.elements[0]["y"]),
427
+ np.asarray(plydata.elements[0]["z"])), axis=1)
428
+ opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
429
+
430
+ print("Number of points at loading : ", xyz.shape[0])
431
+
432
+ features_dc = np.zeros((xyz.shape[0], 3, 1))
433
+ features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
434
+ features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
435
+ features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
436
+
437
+ extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
438
+ assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3
439
+ features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
440
+ for idx, attr_name in enumerate(extra_f_names):
441
+ features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
442
+ # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
443
+ features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
444
+
445
+ scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
446
+ scales = np.zeros((xyz.shape[0], len(scale_names)))
447
+ for idx, attr_name in enumerate(scale_names):
448
+ scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
449
+
450
+ rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
451
+ rots = np.zeros((xyz.shape[0], len(rot_names)))
452
+ for idx, attr_name in enumerate(rot_names):
453
+ rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
454
+
455
+ self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
456
+ self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
457
+ self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
458
+ self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True))
459
+ self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True))
460
+ self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True))
461
+
462
+ self.active_sh_degree = self.max_sh_degree
463
+
464
+ def replace_tensor_to_optimizer(self, tensor, name):
465
+ optimizable_tensors = {}
466
+ for group in self.optimizer.param_groups:
467
+ if group["name"] == name:
468
+ stored_state = self.optimizer.state.get(group['params'][0], None)
469
+ stored_state["exp_avg"] = torch.zeros_like(tensor)
470
+ stored_state["exp_avg_sq"] = torch.zeros_like(tensor)
471
+
472
+ del self.optimizer.state[group['params'][0]]
473
+ group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
474
+ self.optimizer.state[group['params'][0]] = stored_state
475
+
476
+ optimizable_tensors[group["name"]] = group["params"][0]
477
+ return optimizable_tensors
478
+
479
+ def _prune_optimizer(self, mask):
480
+ optimizable_tensors = {}
481
+ for group in self.optimizer.param_groups:
482
+ stored_state = self.optimizer.state.get(group['params'][0], None)
483
+ if stored_state is not None:
484
+ stored_state["exp_avg"] = stored_state["exp_avg"][mask]
485
+ stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
486
+
487
+ del self.optimizer.state[group['params'][0]]
488
+ group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
489
+ self.optimizer.state[group['params'][0]] = stored_state
490
+
491
+ optimizable_tensors[group["name"]] = group["params"][0]
492
+ else:
493
+ group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
494
+ optimizable_tensors[group["name"]] = group["params"][0]
495
+ return optimizable_tensors
496
+
497
+ def prune_points(self, mask):
498
+ valid_points_mask = ~mask
499
+ optimizable_tensors = self._prune_optimizer(valid_points_mask)
500
+
501
+ self._xyz = optimizable_tensors["xyz"]
502
+ self._features_dc = optimizable_tensors["f_dc"]
503
+ self._features_rest = optimizable_tensors["f_rest"]
504
+ self._opacity = optimizable_tensors["opacity"]
505
+ self._scaling = optimizable_tensors["scaling"]
506
+ self._rotation = optimizable_tensors["rotation"]
507
+
508
+ self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
509
+
510
+ self.denom = self.denom[valid_points_mask]
511
+ self.max_radii2D = self.max_radii2D[valid_points_mask]
512
+
513
+ def cat_tensors_to_optimizer(self, tensors_dict):
514
+ optimizable_tensors = {}
515
+ for group in self.optimizer.param_groups:
516
+ assert len(group["params"]) == 1
517
+ extension_tensor = tensors_dict[group["name"]]
518
+ stored_state = self.optimizer.state.get(group['params'][0], None)
519
+ if stored_state is not None:
520
+
521
+ stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0)
522
+ stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0)
523
+
524
+ del self.optimizer.state[group['params'][0]]
525
+ group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
526
+ self.optimizer.state[group['params'][0]] = stored_state
527
+
528
+ optimizable_tensors[group["name"]] = group["params"][0]
529
+ else:
530
+ group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
531
+ optimizable_tensors[group["name"]] = group["params"][0]
532
+
533
+ return optimizable_tensors
534
+
535
+ def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation):
536
+ d = {"xyz": new_xyz,
537
+ "f_dc": new_features_dc,
538
+ "f_rest": new_features_rest,
539
+ "opacity": new_opacities,
540
+ "scaling" : new_scaling,
541
+ "rotation" : new_rotation}
542
+
543
+ optimizable_tensors = self.cat_tensors_to_optimizer(d)
544
+ self._xyz = optimizable_tensors["xyz"]
545
+ self._features_dc = optimizable_tensors["f_dc"]
546
+ self._features_rest = optimizable_tensors["f_rest"]
547
+ self._opacity = optimizable_tensors["opacity"]
548
+ self._scaling = optimizable_tensors["scaling"]
549
+ self._rotation = optimizable_tensors["rotation"]
550
+
551
+ self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
552
+ self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
553
+ self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
554
+
555
+ def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
556
+ n_init_points = self.get_xyz.shape[0]
557
+ # Extract points that satisfy the gradient condition
558
+ padded_grad = torch.zeros((n_init_points), device="cuda")
559
+ padded_grad[:grads.shape[0]] = grads.squeeze()
560
+ selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
561
+ selected_pts_mask = torch.logical_and(selected_pts_mask,
562
+ torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent)
563
+
564
+ stds = self.get_scaling[selected_pts_mask].repeat(N,1)
565
+ means =torch.zeros((stds.size(0), 3),device="cuda")
566
+ samples = torch.normal(mean=means, std=stds)
567
+ rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1)
568
+ new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1)
569
+ new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N))
570
+ new_rotation = self._rotation[selected_pts_mask].repeat(N,1)
571
+ new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1)
572
+ new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1)
573
+ new_opacity = self._opacity[selected_pts_mask].repeat(N,1)
574
+
575
+ self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation)
576
+
577
+ prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool)))
578
+ self.prune_points(prune_filter)
579
+
580
+ def densify_and_clone(self, grads, grad_threshold, scene_extent):
581
+ # Extract points that satisfy the gradient condition
582
+ selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)
583
+ selected_pts_mask = torch.logical_and(selected_pts_mask,
584
+ torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent)
585
+
586
+ new_xyz = self._xyz[selected_pts_mask]
587
+ new_features_dc = self._features_dc[selected_pts_mask]
588
+ new_features_rest = self._features_rest[selected_pts_mask]
589
+ new_opacities = self._opacity[selected_pts_mask]
590
+ new_scaling = self._scaling[selected_pts_mask]
591
+ new_rotation = self._rotation[selected_pts_mask]
592
+
593
+ self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation)
594
+
595
+ def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):
596
+ grads = self.xyz_gradient_accum / self.denom
597
+ grads[grads.isnan()] = 0.0
598
+
599
+ self.densify_and_clone(grads, max_grad, extent)
600
+ self.densify_and_split(grads, max_grad, extent)
601
+
602
+ prune_mask = (self.get_opacity < min_opacity).squeeze()
603
+ if max_screen_size:
604
+ big_points_vs = self.max_radii2D > max_screen_size
605
+ big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
606
+ prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
607
+ self.prune_points(prune_mask)
608
+
609
+ torch.cuda.empty_cache()
610
+
611
+ def prune(self, min_opacity, extent, max_screen_size):
612
+
613
+ prune_mask = (self.get_opacity < min_opacity).squeeze()
614
+ if max_screen_size:
615
+ big_points_vs = self.max_radii2D > max_screen_size
616
+ big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
617
+ prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
618
+ self.prune_points(prune_mask)
619
+
620
+ torch.cuda.empty_cache()
621
+
622
+
623
+ def add_densification_stats(self, viewspace_point_tensor, update_filter):
624
+ self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True)
625
+ self.denom[update_filter] += 1
626
+
627
+ def getProjectionMatrix(znear, zfar, fovX, fovY):
628
+ tanHalfFovY = math.tan((fovY / 2))
629
+ tanHalfFovX = math.tan((fovX / 2))
630
+
631
+ P = torch.zeros(4, 4)
632
+
633
+ z_sign = 1.0
634
+
635
+ P[0, 0] = 1 / tanHalfFovX
636
+ P[1, 1] = 1 / tanHalfFovY
637
+ P[3, 2] = z_sign
638
+ P[2, 2] = z_sign * zfar / (zfar - znear)
639
+ P[2, 3] = -(zfar * znear) / (zfar - znear)
640
+ return P
641
+
642
+
643
+ class MiniCam:
644
+ def __init__(self, c2w, width, height, fovy, fovx, znear, zfar):
645
+ # c2w (pose) should be in NeRF convention.
646
+
647
+ self.image_width = width
648
+ self.image_height = height
649
+ self.FoVy = fovy
650
+ self.FoVx = fovx
651
+ self.znear = znear
652
+ self.zfar = zfar
653
+
654
+ w2c = np.linalg.inv(c2w)
655
+
656
+ # rectify...
657
+ w2c[1:3, :3] *= -1
658
+ w2c[:3, 3] *= -1
659
+
660
+ self.world_view_transform = torch.tensor(w2c).transpose(0, 1).cuda()
661
+ self.projection_matrix = (
662
+ getProjectionMatrix(
663
+ znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy
664
+ )
665
+ .transpose(0, 1)
666
+ .cuda()
667
+ )
668
+ self.full_proj_transform = self.world_view_transform @ self.projection_matrix
669
+ self.camera_center = -torch.tensor(c2w[:3, 3]).cuda()
670
+
671
+
672
+ class Renderer:
673
+ def __init__(self, sh_degree=3, white_background=True, radius=1):
674
+
675
+ self.sh_degree = sh_degree
676
+ self.white_background = white_background
677
+ self.radius = radius
678
+
679
+ self.gaussians = GaussianModel(sh_degree)
680
+
681
+ self.bg_color = torch.tensor(
682
+ [1, 1, 1] if white_background else [0, 0, 0],
683
+ dtype=torch.float32,
684
+ device="cuda",
685
+ )
686
+
687
+ def initialize(self, input=None, num_pts=5000, radius=0.5):
688
+ # load checkpoint
689
+ if input is None:
690
+ # init from random point cloud
691
+
692
+ phis = np.random.random((num_pts,)) * 2 * np.pi
693
+ costheta = np.random.random((num_pts,)) * 2 - 1
694
+ thetas = np.arccos(costheta)
695
+ mu = np.random.random((num_pts,))
696
+ radius = radius * np.cbrt(mu)
697
+ x = radius * np.sin(thetas) * np.cos(phis)
698
+ y = radius * np.sin(thetas) * np.sin(phis)
699
+ z = radius * np.cos(thetas)
700
+ xyz = np.stack((x, y, z), axis=1)
701
+ # xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
702
+
703
+ shs = np.random.random((num_pts, 3)) / 255.0
704
+ pcd = BasicPointCloud(
705
+ points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))
706
+ )
707
+ self.gaussians.create_from_pcd(pcd, 10)
708
+ elif isinstance(input, BasicPointCloud):
709
+ # load from a provided pcd
710
+ self.gaussians.create_from_pcd(input, 1)
711
+ else:
712
+ # load from saved ply
713
+ self.gaussians.load_ply(input)
714
+
715
+ def render(
716
+ self,
717
+ viewpoint_camera,
718
+ scaling_modifier=1.0,
719
+ invert_bg_color=False,
720
+ override_color=None,
721
+ compute_cov3D_python=False,
722
+ convert_SHs_python=False,
723
+ ):
724
+ # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
725
+ screenspace_points = (
726
+ torch.zeros_like(
727
+ self.gaussians.get_xyz,
728
+ dtype=self.gaussians.get_xyz.dtype,
729
+ requires_grad=True,
730
+ device="cuda",
731
+ )
732
+ + 0
733
+ )
734
+ try:
735
+ screenspace_points.retain_grad()
736
+ except:
737
+ pass
738
+
739
+ # Set up rasterization configuration
740
+ tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
741
+ tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
742
+
743
+ raster_settings = GaussianRasterizationSettings(
744
+ image_height=int(viewpoint_camera.image_height),
745
+ image_width=int(viewpoint_camera.image_width),
746
+ tanfovx=tanfovx,
747
+ tanfovy=tanfovy,
748
+ bg=self.bg_color if not invert_bg_color else 1 - self.bg_color,
749
+ scale_modifier=scaling_modifier,
750
+ viewmatrix=viewpoint_camera.world_view_transform,
751
+ projmatrix=viewpoint_camera.full_proj_transform,
752
+ sh_degree=self.gaussians.active_sh_degree,
753
+ campos=viewpoint_camera.camera_center,
754
+ prefiltered=False,
755
+ debug=False,
756
+ )
757
+
758
+ rasterizer = GaussianRasterizer(raster_settings=raster_settings)
759
+
760
+ means3D = self.gaussians.get_xyz
761
+ means2D = screenspace_points
762
+ opacity = self.gaussians.get_opacity
763
+
764
+ # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
765
+ # scaling / rotation by the rasterizer.
766
+ scales = None
767
+ rotations = None
768
+ cov3D_precomp = None
769
+ if compute_cov3D_python:
770
+ cov3D_precomp = self.gaussians.get_covariance(scaling_modifier)
771
+ else:
772
+ scales = self.gaussians.get_scaling
773
+ rotations = self.gaussians.get_rotation
774
+
775
+ # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
776
+ # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
777
+ shs = None
778
+ colors_precomp = None
779
+ if colors_precomp is None:
780
+ if convert_SHs_python:
781
+ shs_view = self.gaussians.get_features.transpose(1, 2).view(
782
+ -1, 3, (self.gaussians.max_sh_degree + 1) ** 2
783
+ )
784
+ dir_pp = self.gaussians.get_xyz - viewpoint_camera.camera_center.repeat(
785
+ self.gaussians.get_features.shape[0], 1
786
+ )
787
+ dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)
788
+ sh2rgb = eval_sh(
789
+ self.gaussians.active_sh_degree, shs_view, dir_pp_normalized
790
+ )
791
+ colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
792
+ else:
793
+ shs = self.gaussians.get_features
794
+ else:
795
+ colors_precomp = override_color
796
+
797
+ # Rasterize visible Gaussians to image, obtain their radii (on screen).
798
+ rendered_image, radii, rendered_depth, rendered_alpha = rasterizer(
799
+ means3D=means3D,
800
+ means2D=means2D,
801
+ shs=shs,
802
+ colors_precomp=colors_precomp,
803
+ opacities=opacity,
804
+ scales=scales,
805
+ rotations=rotations,
806
+ cov3D_precomp=cov3D_precomp,
807
+ )
808
+
809
+ rendered_image = rendered_image.clamp(0, 1)
810
+
811
+ # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
812
+ # They will be excluded from value updates used in the splitting criteria.
813
+ return {
814
+ "image": rendered_image,
815
+ "depth": rendered_depth,
816
+ "alpha": rendered_alpha,
817
+ "viewspace_points": screenspace_points,
818
+ "visibility_filter": radii > 0,
819
+ "radii": radii,
820
+ }
main.py ADDED
@@ -0,0 +1,882 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import time
4
+ import tqdm
5
+ import numpy as np
6
+ import dearpygui.dearpygui as dpg
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+ import rembg
12
+
13
+ from cam_utils import orbit_camera, OrbitCamera
14
+ from gs_renderer import Renderer, MiniCam
15
+
16
+ from grid_put import mipmap_linear_grid_put_2d
17
+ from mesh import Mesh, safe_normalize
18
+
19
+ class GUI:
20
+ def __init__(self, opt):
21
+ self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
22
+ self.gui = opt.gui # enable gui
23
+ self.W = opt.W
24
+ self.H = opt.H
25
+ self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)
26
+
27
+ self.mode = "image"
28
+ self.seed = "random"
29
+
30
+ self.buffer_image = np.ones((self.W, self.H, 3), dtype=np.float32)
31
+ self.need_update = True # update buffer_image
32
+
33
+ # models
34
+ self.device = torch.device("cuda")
35
+ self.bg_remover = None
36
+
37
+ self.guidance_sd = None
38
+ self.guidance_zero123 = None
39
+
40
+ self.enable_sd = False
41
+ self.enable_zero123 = False
42
+
43
+ # renderer
44
+ self.renderer = Renderer(sh_degree=self.opt.sh_degree)
45
+ self.gaussain_scale_factor = 1
46
+
47
+ # input image
48
+ self.input_img = None
49
+ self.input_mask = None
50
+ self.input_img_torch = None
51
+ self.input_mask_torch = None
52
+ self.overlay_input_img = False
53
+ self.overlay_input_img_ratio = 0.5
54
+
55
+ # input text
56
+ self.prompt = ""
57
+ self.negative_prompt = ""
58
+
59
+ # training stuff
60
+ self.training = False
61
+ self.optimizer = None
62
+ self.step = 0
63
+ self.train_steps = 1 # steps per rendering loop
64
+
65
+ # load input data from cmdline
66
+ if self.opt.input is not None:
67
+ self.load_input(self.opt.input)
68
+
69
+ # override prompt from cmdline
70
+ if self.opt.prompt is not None:
71
+ self.prompt = self.opt.prompt
72
+
73
+ # override if provide a checkpoint
74
+ if self.opt.load is not None:
75
+ self.renderer.initialize(self.opt.load)
76
+ else:
77
+ # initialize gaussians to a blob
78
+ self.renderer.initialize(num_pts=self.opt.num_pts)
79
+
80
+ if self.gui:
81
+ dpg.create_context()
82
+ self.register_dpg()
83
+ self.test_step()
84
+
85
+ def __del__(self):
86
+ if self.gui:
87
+ dpg.destroy_context()
88
+
89
+ def seed_everything(self):
90
+ try:
91
+ seed = int(self.seed)
92
+ except (TypeError, ValueError):
93
+ seed = np.random.randint(0, 1000000)
94
+
95
+ os.environ["PYTHONHASHSEED"] = str(seed)
96
+ np.random.seed(seed)
97
+ torch.manual_seed(seed)
98
+ torch.cuda.manual_seed(seed)
99
+ torch.backends.cudnn.deterministic = True
100
+ torch.backends.cudnn.benchmark = True
101
+
102
+ self.last_seed = seed
103
+
104
+ def prepare_train(self):
105
+
106
+ self.step = 0
107
+
108
+ # setup training
109
+ self.renderer.gaussians.training_setup(self.opt)
110
+ # do not do progressive sh-level
111
+ self.renderer.gaussians.active_sh_degree = self.renderer.gaussians.max_sh_degree
112
+ self.optimizer = self.renderer.gaussians.optimizer
113
+
114
+ # default camera
115
+ pose = orbit_camera(self.opt.elevation, 0, self.opt.radius)
116
+ self.fixed_cam = MiniCam(
117
+ pose,
118
+ self.opt.ref_size,
119
+ self.opt.ref_size,
120
+ self.cam.fovy,
121
+ self.cam.fovx,
122
+ self.cam.near,
123
+ self.cam.far,
124
+ )
125
+
126
+ self.enable_sd = self.opt.lambda_sd > 0 and self.prompt != ""
127
+ self.enable_zero123 = self.opt.lambda_zero123 > 0 and self.input_img is not None
128
+
129
+ # lazy load guidance model
130
+ if self.guidance_sd is None and self.enable_sd:
131
+ print(f"[INFO] loading SD...")
132
+ from guidance.sd_utils import StableDiffusion
133
+ self.guidance_sd = StableDiffusion(self.device)
134
+ print(f"[INFO] loaded SD!")
135
+
136
+ if self.guidance_zero123 is None and self.enable_zero123:
137
+ print(f"[INFO] loading zero123...")
138
+ from guidance.zero123_utils import Zero123
139
+ self.guidance_zero123 = Zero123(self.device)
140
+ print(f"[INFO] loaded zero123!")
141
+
142
+ # input image
143
+ if self.input_img is not None:
144
+ self.input_img_torch = torch.from_numpy(self.input_img).permute(2, 0, 1).unsqueeze(0).to(self.device)
145
+ self.input_img_torch = F.interpolate(self.input_img_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False)
146
+
147
+ self.input_mask_torch = torch.from_numpy(self.input_mask).permute(2, 0, 1).unsqueeze(0).to(self.device)
148
+ self.input_mask_torch = F.interpolate(self.input_mask_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False)
149
+
150
+ # prepare embeddings
151
+ with torch.no_grad():
152
+
153
+ if self.enable_sd:
154
+ self.guidance_sd.get_text_embeds([self.prompt], [self.negative_prompt])
155
+
156
+ if self.enable_zero123:
157
+ self.guidance_zero123.get_img_embeds(self.input_img_torch)
158
+
159
+ def train_step(self):
160
+ starter = torch.cuda.Event(enable_timing=True)
161
+ ender = torch.cuda.Event(enable_timing=True)
162
+ starter.record()
163
+
164
+ for _ in range(self.train_steps):
165
+
166
+ self.step += 1
167
+ step_ratio = min(1, self.step / self.opt.iters)
168
+
169
+ # update lr
170
+ self.renderer.gaussians.update_learning_rate(self.step)
171
+
172
+ loss = 0
173
+
174
+ ### known view
175
+ if self.input_img_torch is not None:
176
+ cur_cam = self.fixed_cam
177
+ out = self.renderer.render(cur_cam)
178
+
179
+ # rgb loss
180
+ image = out["image"].unsqueeze(0) # [1, 3, H, W] in [0, 1]
181
+ loss = loss + 10000 * step_ratio * F.mse_loss(image, self.input_img_torch)
182
+
183
+ # mask loss
184
+ mask = out["alpha"].unsqueeze(0) # [1, 1, H, W] in [0, 1]
185
+ loss = loss + 1000 * step_ratio * F.mse_loss(mask, self.input_mask_torch)
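+ # both reference-view losses are scaled by step_ratio, so their weight ramps up linearly until step_ratio reaches 1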
186
+
187
+ ### novel view (manual batch)
188
+ render_resolution = 128 if step_ratio < 0.3 else (256 if step_ratio < 0.6 else 512)
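+ # coarse-to-fine: novel views are rendered at 128, then 256, then 512 pixels as training progresses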
189
+ images = []
190
+ vers, hors, radii = [], [], []
191
+ # avoid too large an elevation (> 80 or < -80), and make sure the sampled range always covers [-30, 30]
192
+ min_ver = max(min(-30, -30 - self.opt.elevation), -80 - self.opt.elevation)
193
+ max_ver = min(max(30, 30 - self.opt.elevation), 80 - self.opt.elevation)
194
+ for _ in range(self.opt.batch_size):
195
+
196
+ # render random view
197
+ ver = np.random.randint(min_ver, max_ver)
198
+ hor = np.random.randint(-180, 180)
199
+ radius = 0
200
+
201
+ vers.append(ver)
202
+ hors.append(hor)
203
+ radii.append(radius)
204
+
205
+ pose = orbit_camera(self.opt.elevation + ver, hor, self.opt.radius + radius)
206
+
207
+ cur_cam = MiniCam(
208
+ pose,
209
+ render_resolution,
210
+ render_resolution,
211
+ self.cam.fovy,
212
+ self.cam.fovx,
213
+ self.cam.near,
214
+ self.cam.far,
215
+ )
216
+
217
+ invert_bg_color = np.random.rand() > self.opt.invert_bg_prob
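+ # randomly invert the background color (controlled by opt.invert_bg_prob), which presumably keeps a fixed background from being baked into the Gaussians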
218
+ out = self.renderer.render(cur_cam, invert_bg_color=invert_bg_color)
219
+
220
+ image = out["image"].unsqueeze(0)  # [1, 3, H, W] in [0, 1]
221
+ images.append(image)
222
+
223
+ images = torch.cat(images, dim=0)
224
+
225
+ # import kiui
226
+ # kiui.lo(hor, ver)
227
+ # kiui.vis.plot_image(image)
228
+
229
+ # guidance loss
230
+ if self.enable_sd:
231
+ loss = loss + self.opt.lambda_sd * self.guidance_sd.train_step(images, step_ratio)
232
+
233
+ if self.enable_zero123:
234
+ loss = loss + self.opt.lambda_zero123 * self.guidance_zero123.train_step(images, vers, hors, radii, step_ratio)
235
+
236
+ # optimize step
237
+ loss.backward()
238
+ self.optimizer.step()
239
+ self.optimizer.zero_grad()
240
+
241
+ # densify and prune
242
+ if self.step >= self.opt.density_start_iter and self.step <= self.opt.density_end_iter:
243
+ viewspace_point_tensor, visibility_filter, radii = out["viewspace_points"], out["visibility_filter"], out["radii"]
244
+ self.renderer.gaussians.max_radii2D[visibility_filter] = torch.max(self.renderer.gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
245
+ self.renderer.gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
246
+
247
+ if self.step % self.opt.densification_interval == 0:
248
+ # size_threshold = 20 if self.step > self.opt.opacity_reset_interval else None
249
+ self.renderer.gaussians.densify_and_prune(self.opt.densify_grad_threshold, min_opacity=0.01, extent=0.5, max_screen_size=1)
250
+
251
+ if self.step % self.opt.opacity_reset_interval == 0:
252
+ self.renderer.gaussians.reset_opacity()
253
+
254
+ ender.record()
255
+ torch.cuda.synchronize()
256
+ t = starter.elapsed_time(ender)
257
+
258
+ self.need_update = True
259
+
260
+ if self.gui:
261
+ dpg.set_value("_log_train_time", f"{t:.4f}ms")
262
+ dpg.set_value(
263
+ "_log_train_log",
264
+ f"step = {self.step: 5d} (+{self.train_steps: 2d}) loss = {loss.item():.4f}",
265
+ )
266
+
267
+ # dynamic train steps (no need for now)
268
+ # max allowed train time per-frame is 500 ms
269
+ # full_t = t / self.train_steps * 16
270
+ # train_steps = min(16, max(4, int(16 * 500 / full_t)))
271
+ # if train_steps > self.train_steps * 1.2 or train_steps < self.train_steps * 0.8:
272
+ # self.train_steps = train_steps
273
+
274
+ @torch.no_grad()
275
+ def test_step(self):
276
+ # ignore if no need to update
277
+ if not self.need_update:
278
+ return
279
+
280
+ starter = torch.cuda.Event(enable_timing=True)
281
+ ender = torch.cuda.Event(enable_timing=True)
282
+ starter.record()
283
+
284
+ # should update image
285
+ if self.need_update:
286
+ # render image
287
+
288
+ cur_cam = MiniCam(
289
+ self.cam.pose,
290
+ self.W,
291
+ self.H,
292
+ self.cam.fovy,
293
+ self.cam.fovx,
294
+ self.cam.near,
295
+ self.cam.far,
296
+ )
297
+
298
+ out = self.renderer.render(cur_cam, self.gaussain_scale_factor)
299
+
300
+ buffer_image = out[self.mode] # [3, H, W]
301
+
302
+ if self.mode in ['depth', 'alpha']:
303
+ buffer_image = buffer_image.repeat(3, 1, 1)
304
+ if self.mode == 'depth':
305
+ buffer_image = (buffer_image - buffer_image.min()) / (buffer_image.max() - buffer_image.min() + 1e-20)
306
+
307
+ buffer_image = F.interpolate(
308
+ buffer_image.unsqueeze(0),
309
+ size=(self.H, self.W),
310
+ mode="bilinear",
311
+ align_corners=False,
312
+ ).squeeze(0)
313
+
314
+ self.buffer_image = (
315
+ buffer_image.permute(1, 2, 0)
316
+ .contiguous()
317
+ .clamp(0, 1)
318
+ .contiguous()
319
+ .detach()
320
+ .cpu()
321
+ .numpy()
322
+ )
323
+
324
+ # display input_image
325
+ if self.overlay_input_img and self.input_img is not None:
326
+ self.buffer_image = (
327
+ self.buffer_image * (1 - self.overlay_input_img_ratio)
328
+ + self.input_img * self.overlay_input_img_ratio
329
+ )
330
+
331
+ self.need_update = False
332
+
333
+ ender.record()
334
+ torch.cuda.synchronize()
335
+ t = starter.elapsed_time(ender)
336
+
337
+ if self.gui:
338
+ dpg.set_value("_log_infer_time", f"{t:.4f}ms ({int(1000/t)} FPS)")
339
+ dpg.set_value(
340
+ "_texture", self.buffer_image
341
+ ) # buffer must be contiguous, else seg fault!
342
+
343
+
344
+ def load_input(self, file):
345
+ # load image
346
+ print(f'[INFO] load image from {file}...')
347
+ img = cv2.imread(file, cv2.IMREAD_UNCHANGED)
348
+ if img.shape[-1] == 3:
349
+ if self.bg_remover is None:
350
+ self.bg_remover = rembg.new_session()
351
+ img = rembg.remove(img, session=self.bg_remover)
352
+
353
+ img = cv2.resize(img, (self.W, self.H), interpolation=cv2.INTER_AREA)
354
+ img = img.astype(np.float32) / 255.0
355
+
356
+ self.input_mask = img[..., 3:]
357
+ # white bg
358
+ self.input_img = img[..., :3] * self.input_mask + (1 - self.input_mask)
359
+ # bgr to rgb
360
+ self.input_img = self.input_img[..., ::-1].copy()
361
+
362
+ # load prompt
363
+ file_prompt = file.replace("_rgba.png", "_caption.txt")
364
+ if os.path.exists(file_prompt):
365
+ print(f'[INFO] load prompt from {file_prompt}...')
366
+ with open(file_prompt, "r") as f:
367
+ self.prompt = f.read().strip()
368
+
369
+ @torch.no_grad()
370
+ def save_model(self, mode='geo', texture_size=1024):
371
+ os.makedirs(self.opt.outdir, exist_ok=True)
372
+ if mode == 'geo':
373
+ path = os.path.join(self.opt.outdir, self.opt.save_path + '_mesh.ply')
374
+ mesh = self.renderer.gaussians.extract_mesh(path, self.opt.density_thresh)
375
+ mesh.write_ply(path)
376
+
377
+ elif mode == 'geo+tex':
378
+ path = os.path.join(self.opt.outdir, self.opt.save_path + '_mesh.obj')
379
+ mesh = self.renderer.gaussians.extract_mesh(path, self.opt.density_thresh)
380
+
381
+ # perform texture extraction
382
+ print(f"[INFO] unwrap uv...")
383
+ h = w = texture_size
384
+ mesh.auto_uv()
385
+ mesh.auto_normal()
386
+
387
+ albedo = torch.zeros((h, w, 3), device=self.device, dtype=torch.float32)
388
+ cnt = torch.zeros((h, w, 1), device=self.device, dtype=torch.float32)
389
+
390
+ # self.prepare_train() # tmp fix for not loading 0123
391
+ # vers = [0]
392
+ # hors = [0]
393
+ vers = [0] * 8 + [-45] * 8 + [45] * 8 + [-89.9, 89.9]
394
+ hors = [0, 45, -45, 90, -90, 135, -135, 180] * 3 + [0, 0]
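+ # 26 fixed viewpoints for texture baking: 8 azimuths at elevations 0, -45 and +45, plus two near-polar views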
395
+
396
+ render_resolution = 512
397
+
398
+ import nvdiffrast.torch as dr
399
+
400
+ if not self.opt.gui or os.name == 'nt':
401
+ glctx = dr.RasterizeGLContext()
402
+ else:
403
+ glctx = dr.RasterizeCudaContext()
404
+
405
+ for ver, hor in zip(vers, hors):
406
+ # render image
407
+ pose = orbit_camera(ver, hor, self.cam.radius)
408
+
409
+ cur_cam = MiniCam(
410
+ pose,
411
+ render_resolution,
412
+ render_resolution,
413
+ self.cam.fovy,
414
+ self.cam.fovx,
415
+ self.cam.near,
416
+ self.cam.far,
417
+ )
418
+
419
+ cur_out = self.renderer.render(cur_cam)
420
+
421
+ rgbs = cur_out["image"].unsqueeze(0) # [1, 3, H, W] in [0, 1]
422
+
423
+ # enhance texture quality with zero123 [not working well]
424
+ # if self.opt.guidance_model == 'zero123':
425
+ # rgbs = self.guidance.refine(rgbs, [ver], [hor], [0])
426
+ # import kiui
427
+ # kiui.vis.plot_image(rgbs)
428
+
429
+ # get coordinate in texture image
430
+ pose = torch.from_numpy(pose.astype(np.float32)).to(self.device)
431
+ proj = torch.from_numpy(self.cam.perspective.astype(np.float32)).to(self.device)
432
+
433
+ v_cam = torch.matmul(F.pad(mesh.v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)
434
+ v_clip = v_cam @ proj.T
435
+ rast, rast_db = dr.rasterize(glctx, v_clip, mesh.f, (render_resolution, render_resolution))
436
+
437
+ depth, _ = dr.interpolate(-v_cam[..., [2]], rast, mesh.f) # [1, H, W, 1]
438
+ depth = depth.squeeze(0) # [H, W, 1]
439
+
440
+ alpha = (rast[0, ..., 3:] > 0).float()
441
+
442
+ uvs, _ = dr.interpolate(mesh.vt.unsqueeze(0), rast, mesh.ft) # [1, 512, 512, 2] in [0, 1]
443
+
444
+ # use normal to produce a back-project mask
445
+ normal, _ = dr.interpolate(mesh.vn.unsqueeze(0).contiguous(), rast, mesh.fn)
446
+ normal = safe_normalize(normal[0])
447
+
448
+ # rotated normal (where [0, 0, 1] always faces camera)
449
+ rot_normal = normal @ pose[:3, :3]
450
+ viewcos = rot_normal[..., [2]]
451
+
452
+ mask = (alpha > 0) & (viewcos > 0.5) # [H, W, 1]
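+ # back-project only pixels that are covered and whose surface faces the camera within ~60 degrees (cos > 0.5), skipping grazing views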
453
+ mask = mask.view(-1)
454
+
455
+ uvs = uvs.view(-1, 2).clamp(0, 1)[mask]
456
+ rgbs = rgbs.view(3, -1).permute(1, 0)[mask].contiguous()
457
+
458
+ # update texture image
459
+ cur_albedo, cur_cnt = mipmap_linear_grid_put_2d(
460
+ h, w,
461
+ uvs[..., [1, 0]] * 2 - 1,
462
+ rgbs,
463
+ min_resolution=256,
464
+ return_count=True,
465
+ )
466
+
467
+ # albedo += cur_albedo
468
+ # cnt += cur_cnt
469
+ mask = cnt.squeeze(-1) < 0.1
470
+ albedo[mask] += cur_albedo[mask]
471
+ cnt[mask] += cur_cnt[mask]
472
+
473
+ mask = cnt.squeeze(-1) > 0
474
+ albedo[mask] = albedo[mask] / cnt[mask].repeat(1, 3)
475
+
476
+ mask = mask.view(h, w)
477
+
478
+ albedo = albedo.detach().cpu().numpy()
479
+ mask = mask.detach().cpu().numpy()
480
+
481
+ # dilate texture
482
+ from sklearn.neighbors import NearestNeighbors
483
+ from scipy.ndimage import binary_dilation, binary_erosion
484
+
485
+ inpaint_region = binary_dilation(mask, iterations=32)
486
+ inpaint_region[mask] = 0
487
+
488
+ search_region = mask.copy()
489
+ not_search_region = binary_erosion(search_region, iterations=3)
490
+ search_region[not_search_region] = 0
491
+
492
+ search_coords = np.stack(np.nonzero(search_region), axis=-1)
493
+ inpaint_coords = np.stack(np.nonzero(inpaint_region), axis=-1)
494
+
495
+ knn = NearestNeighbors(n_neighbors=1, algorithm="kd_tree").fit(
496
+ search_coords
497
+ )
498
+ _, indices = knn.kneighbors(inpaint_coords)
499
+
500
+ albedo[tuple(inpaint_coords.T)] = albedo[tuple(search_coords[indices[:, 0]].T)]
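+ # nearest-neighbor inpainting: each empty texel in the dilated border copies the closest textured texel, hiding gaps at UV seams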
501
+
502
+ mesh.albedo = torch.from_numpy(albedo).to(self.device)
503
+ mesh.write(path)
504
+
505
+ else:
506
+ path = os.path.join(self.opt.outdir, self.opt.save_path + '_model.ply')
507
+ self.renderer.gaussians.save_ply(path)
508
+
509
+ print(f"[INFO] save model to {path}.")
510
+
511
+ def register_dpg(self):
512
+ ### register texture
513
+
514
+ with dpg.texture_registry(show=False):
515
+ dpg.add_raw_texture(
516
+ self.W,
517
+ self.H,
518
+ self.buffer_image,
519
+ format=dpg.mvFormat_Float_rgb,
520
+ tag="_texture",
521
+ )
522
+
523
+ ### register window
524
+
525
+ # the rendered image, as the primary window
526
+ with dpg.window(
527
+ tag="_primary_window",
528
+ width=self.W,
529
+ height=self.H,
530
+ pos=[0, 0],
531
+ no_move=True,
532
+ no_title_bar=True,
533
+ no_scrollbar=True,
534
+ ):
535
+ # add the texture
536
+ dpg.add_image("_texture")
537
+
538
+ # dpg.set_primary_window("_primary_window", True)
539
+
540
+ # control window
541
+ with dpg.window(
542
+ label="Control",
543
+ tag="_control_window",
544
+ width=600,
545
+ height=self.H,
546
+ pos=[self.W, 0],
547
+ no_move=True,
548
+ no_title_bar=True,
549
+ ):
550
+ # button theme
551
+ with dpg.theme() as theme_button:
552
+ with dpg.theme_component(dpg.mvButton):
553
+ dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))
554
+ dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))
555
+ dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))
556
+ dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)
557
+ dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)
558
+
559
+ # timer stuff
560
+ with dpg.group(horizontal=True):
561
+ dpg.add_text("Infer time: ")
562
+ dpg.add_text("no data", tag="_log_infer_time")
563
+
564
+ def callback_setattr(sender, app_data, user_data):
565
+ setattr(self, user_data, app_data)
566
+
567
+ # init stuff
568
+ with dpg.collapsing_header(label="Initialize", default_open=True):
569
+
570
+ # seed stuff
571
+ def callback_set_seed(sender, app_data):
572
+ self.seed = app_data
573
+ self.seed_everything()
574
+
575
+ dpg.add_input_text(
576
+ label="seed",
577
+ default_value=self.seed,
578
+ on_enter=True,
579
+ callback=callback_set_seed,
580
+ )
581
+
582
+ # input stuff
583
+ def callback_select_input(sender, app_data):
584
+ # only one item
585
+ for k, v in app_data["selections"].items():
586
+ dpg.set_value("_log_input", k)
587
+ self.load_input(v)
588
+
589
+ self.need_update = True
590
+
591
+ with dpg.file_dialog(
592
+ directory_selector=False,
593
+ show=False,
594
+ callback=callback_select_input,
595
+ file_count=1,
596
+ tag="file_dialog_tag",
597
+ width=700,
598
+ height=400,
599
+ ):
600
+ dpg.add_file_extension("Images{.jpg,.jpeg,.png}")
601
+
602
+ with dpg.group(horizontal=True):
603
+ dpg.add_button(
604
+ label="input",
605
+ callback=lambda: dpg.show_item("file_dialog_tag"),
606
+ )
607
+ dpg.add_text("", tag="_log_input")
608
+
609
+ # overlay stuff
610
+ with dpg.group(horizontal=True):
611
+
612
+ def callback_toggle_overlay_input_img(sender, app_data):
613
+ self.overlay_input_img = not self.overlay_input_img
614
+ self.need_update = True
615
+
616
+ dpg.add_checkbox(
617
+ label="overlay image",
618
+ default_value=self.overlay_input_img,
619
+ callback=callback_toggle_overlay_input_img,
620
+ )
621
+
622
+ def callback_set_overlay_input_img_ratio(sender, app_data):
623
+ self.overlay_input_img_ratio = app_data
624
+ self.need_update = True
625
+
626
+ dpg.add_slider_float(
627
+ label="ratio",
628
+ min_value=0,
629
+ max_value=1,
630
+ format="%.1f",
631
+ default_value=self.overlay_input_img_ratio,
632
+ callback=callback_set_overlay_input_img_ratio,
633
+ )
634
+
635
+ # prompt stuff
636
+
637
+ dpg.add_input_text(
638
+ label="prompt",
639
+ default_value=self.prompt,
640
+ callback=callback_setattr,
641
+ user_data="prompt",
642
+ )
643
+
644
+ dpg.add_input_text(
645
+ label="negative",
646
+ default_value=self.negative_prompt,
647
+ callback=callback_setattr,
648
+ user_data="negative_prompt",
649
+ )
650
+
651
+ # save current model
652
+ with dpg.group(horizontal=True):
653
+ dpg.add_text("Save: ")
654
+
655
+ def callback_save(sender, app_data, user_data):
656
+ self.save_model(mode=user_data)
657
+
658
+ dpg.add_button(
659
+ label="model",
660
+ tag="_button_save_model",
661
+ callback=callback_save,
662
+ user_data='model',
663
+ )
664
+ dpg.bind_item_theme("_button_save_model", theme_button)
665
+
666
+ dpg.add_button(
667
+ label="geo",
668
+ tag="_button_save_mesh",
669
+ callback=callback_save,
670
+ user_data='geo',
671
+ )
672
+ dpg.bind_item_theme("_button_save_mesh", theme_button)
673
+
674
+ dpg.add_button(
675
+ label="geo+tex",
676
+ tag="_button_save_mesh_with_tex",
677
+ callback=callback_save,
678
+ user_data='geo+tex',
679
+ )
680
+ dpg.bind_item_theme("_button_save_mesh_with_tex", theme_button)
681
+
682
+ dpg.add_input_text(
683
+ label="",
684
+ default_value=self.opt.save_path,
685
+ callback=callback_setattr,
686
+ user_data="save_path",
687
+ )
688
+
689
+ # training stuff
690
+ with dpg.collapsing_header(label="Train", default_open=True):
691
+ # lr and train button
692
+ with dpg.group(horizontal=True):
693
+ dpg.add_text("Train: ")
694
+
695
+ def callback_train(sender, app_data):
696
+ if self.training:
697
+ self.training = False
698
+ dpg.configure_item("_button_train", label="start")
699
+ else:
700
+ self.prepare_train()
701
+ self.training = True
702
+ dpg.configure_item("_button_train", label="stop")
703
+
704
+ # dpg.add_button(
705
+ # label="init", tag="_button_init", callback=self.prepare_train
706
+ # )
707
+ # dpg.bind_item_theme("_button_init", theme_button)
708
+
709
+ dpg.add_button(
710
+ label="start", tag="_button_train", callback=callback_train
711
+ )
712
+ dpg.bind_item_theme("_button_train", theme_button)
713
+
714
+ with dpg.group(horizontal=True):
715
+ dpg.add_text("", tag="_log_train_time")
716
+ dpg.add_text("", tag="_log_train_log")
717
+
718
+ # rendering options
719
+ with dpg.collapsing_header(label="Rendering", default_open=True):
720
+ # mode combo
721
+ def callback_change_mode(sender, app_data):
722
+ self.mode = app_data
723
+ self.need_update = True
724
+
725
+ dpg.add_combo(
726
+ ("image", "depth", "alpha"),
727
+ label="mode",
728
+ default_value=self.mode,
729
+ callback=callback_change_mode,
730
+ )
731
+
732
+ # fov slider
733
+ def callback_set_fovy(sender, app_data):
734
+ self.cam.fovy = np.deg2rad(app_data)
735
+ self.need_update = True
736
+
737
+ dpg.add_slider_int(
738
+ label="FoV (vertical)",
739
+ min_value=1,
740
+ max_value=120,
741
+ format="%d deg",
742
+ default_value=np.rad2deg(self.cam.fovy),
743
+ callback=callback_set_fovy,
744
+ )
745
+
746
+ def callback_set_gaussain_scale(sender, app_data):
747
+ self.gaussain_scale_factor = app_data
748
+ self.need_update = True
749
+
750
+ dpg.add_slider_float(
751
+ label="gaussian scale",
752
+ min_value=0,
753
+ max_value=1,
754
+ format="%.2f",
755
+ default_value=self.gaussain_scale_factor,
756
+ callback=callback_set_gaussain_scale,
757
+ )
758
+
759
+ ### register camera handler
760
+
761
+ def callback_camera_drag_rotate_or_draw_mask(sender, app_data):
762
+ if not dpg.is_item_focused("_primary_window"):
763
+ return
764
+
765
+ dx = app_data[1]
766
+ dy = app_data[2]
767
+
768
+ self.cam.orbit(dx, dy)
769
+ self.need_update = True
770
+
771
+ def callback_camera_wheel_scale(sender, app_data):
772
+ if not dpg.is_item_focused("_primary_window"):
773
+ return
774
+
775
+ delta = app_data
776
+
777
+ self.cam.scale(delta)
778
+ self.need_update = True
779
+
780
+ def callback_camera_drag_pan(sender, app_data):
781
+ if not dpg.is_item_focused("_primary_window"):
782
+ return
783
+
784
+ dx = app_data[1]
785
+ dy = app_data[2]
786
+
787
+ self.cam.pan(dx, dy)
788
+ self.need_update = True
789
+
790
+ def callback_set_mouse_loc(sender, app_data):
791
+ if not dpg.is_item_focused("_primary_window"):
792
+ return
793
+
794
+ # just the pixel coordinate in image
795
+ self.mouse_loc = np.array(app_data)
796
+
797
+ with dpg.handler_registry():
798
+ # for camera moving
799
+ dpg.add_mouse_drag_handler(
800
+ button=dpg.mvMouseButton_Left,
801
+ callback=callback_camera_drag_rotate_or_draw_mask,
802
+ )
803
+ dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
804
+ dpg.add_mouse_drag_handler(
805
+ button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan
806
+ )
807
+
808
+ dpg.create_viewport(
809
+ title="Gaussian3D",
810
+ width=self.W + 600,
811
+ height=self.H + (45 if os.name == "nt" else 0),
812
+ resizable=False,
813
+ )
814
+
815
+ ### global theme
816
+ with dpg.theme() as theme_no_padding:
817
+ with dpg.theme_component(dpg.mvAll):
818
+ # set all padding to 0 to avoid scroll bar
819
+ dpg.add_theme_style(
820
+ dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core
821
+ )
822
+ dpg.add_theme_style(
823
+ dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core
824
+ )
825
+ dpg.add_theme_style(
826
+ dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core
827
+ )
828
+
829
+ dpg.bind_item_theme("_primary_window", theme_no_padding)
830
+
831
+ dpg.setup_dearpygui()
832
+
833
+ ### register a larger font
834
+ # get it from: https://github.com/lxgw/LxgwWenKai/releases/download/v1.300/LXGWWenKai-Regular.ttf
835
+ if os.path.exists("LXGWWenKai-Regular.ttf"):
836
+ with dpg.font_registry():
837
+ with dpg.font("LXGWWenKai-Regular.ttf", 18) as default_font:
838
+ dpg.bind_font(default_font)
839
+
840
+ # dpg.show_metrics()
841
+
842
+ dpg.show_viewport()
843
+
844
+ def render(self):
845
+ assert self.gui
846
+ while dpg.is_dearpygui_running():
847
+ # update texture every frame
848
+ if self.training:
849
+ self.train_step()
850
+ self.test_step()
851
+ dpg.render_dearpygui_frame()
852
+
853
+ # no gui mode
854
+ def train(self, iters=500):
855
+ if iters > 0:
856
+ self.prepare_train()
857
+ for i in tqdm.trange(iters):
858
+ self.train_step()
859
+ # do a last prune
860
+ self.renderer.gaussians.prune(min_opacity=0.01, extent=1, max_screen_size=1)
861
+ # save
862
+ self.save_model(mode='model')
863
+ self.save_model(mode='geo+tex')
864
+
865
+
866
+ if __name__ == "__main__":
867
+ import argparse
868
+ from omegaconf import OmegaConf
869
+
870
+ parser = argparse.ArgumentParser()
871
+ parser.add_argument("--config", required=True, help="path to the yaml config file")
872
+ args, extras = parser.parse_known_args()
873
+
874
+ # override default config from cli
875
+ opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras))
876
+
877
+ gui = GUI(opt)
878
+
879
+ if opt.gui:
880
+ gui.render()
881
+ else:
882
+ gui.train(opt.iters)
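Both entry points follow the same invocation pattern: a required --config YAML is loaded with OmegaConf and any extra key=value arguments on the command line override it. As a rough sketch of driving this stage-1 trainer headlessly (the config path, image path and save name below are placeholders, and the YAML is assumed to define every remaining option the GUI class reads, such as W, H, radius, fovy, sh_degree, ref_size, elevation, the lambda_* weights and the densification settings):

    from omegaconf import OmegaConf
    from main import GUI

    # placeholder paths; the config must provide all options used above
    opt = OmegaConf.merge(
        OmegaConf.load("configs/image.yaml"),
        OmegaConf.from_dotlist([
            "input=data/example_rgba.png",  # RGBA input, handled by load_input()
            "save_path=example",
            "gui=false",                    # headless: GUI.train() instead of GUI.render()
        ]),
    )

    gui = GUI(opt)
    gui.train(opt.iters)  # writes <save_path>_model.ply and <save_path>_mesh.obj into opt.outdir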
main2.py ADDED
@@ -0,0 +1,671 @@
1
+ import os
2
+ import cv2
3
+ import time
4
+ import tqdm
5
+ import numpy as np
6
+ import dearpygui.dearpygui as dpg
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+ import trimesh
12
+ import rembg
13
+
14
+ from cam_utils import orbit_camera, OrbitCamera
15
+ from mesh_renderer import Renderer
16
+
17
+ # from kiui.lpips import LPIPS
18
+
19
+ class GUI:
20
+ def __init__(self, opt):
21
+ self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
22
+ self.gui = opt.gui # enable gui
23
+ self.W = opt.W
24
+ self.H = opt.H
25
+ self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)
26
+
27
+ self.mode = "image"
28
+ self.seed = "random"
29
+
30
+ self.buffer_image = np.ones((self.H, self.W, 3), dtype=np.float32)
31
+ self.need_update = True # update buffer_image
32
+
33
+ # models
34
+ self.device = torch.device("cuda")
35
+ self.bg_remover = None
36
+
37
+ self.guidance_sd = None
38
+ self.guidance_zero123 = None
39
+
40
+ self.enable_sd = False
41
+ self.enable_zero123 = False
42
+
43
+ # renderer
44
+ self.renderer = Renderer(opt).to(self.device)
45
+
46
+ # input image
47
+ self.input_img = None
48
+ self.input_mask = None
49
+ self.input_img_torch = None
50
+ self.input_mask_torch = None
51
+ self.overlay_input_img = False
52
+ self.overlay_input_img_ratio = 0.5
53
+
54
+ # input text
55
+ self.prompt = ""
56
+ self.negative_prompt = ""
57
+
58
+ # training stuff
59
+ self.training = False
60
+ self.optimizer = None
61
+ self.step = 0
62
+ self.train_steps = 1 # steps per rendering loop
63
+ # self.lpips_loss = LPIPS(net='vgg').to(self.device)
64
+
65
+ # load input data from cmdline
66
+ if self.opt.input is not None:
67
+ self.load_input(self.opt.input)
68
+
69
+ # override prompt from cmdline
70
+ if self.opt.prompt is not None:
71
+ self.prompt = self.opt.prompt
72
+
73
+ if self.gui:
74
+ dpg.create_context()
75
+ self.register_dpg()
76
+ self.test_step()
77
+
78
+ def __del__(self):
79
+ if self.gui:
80
+ dpg.destroy_context()
81
+
82
+ def seed_everything(self):
83
+ try:
84
+ seed = int(self.seed)
85
+ except (TypeError, ValueError):
86
+ seed = np.random.randint(0, 1000000)
87
+
88
+ os.environ["PYTHONHASHSEED"] = str(seed)
89
+ np.random.seed(seed)
90
+ torch.manual_seed(seed)
91
+ torch.cuda.manual_seed(seed)
92
+ torch.backends.cudnn.deterministic = True
93
+ torch.backends.cudnn.benchmark = True
94
+
95
+ self.last_seed = seed
96
+
97
+ def prepare_train(self):
98
+
99
+ self.step = 0
100
+
101
+ # setup training
102
+ self.optimizer = torch.optim.Adam(self.renderer.get_params())
103
+
104
+ # default camera
105
+ pose = orbit_camera(self.opt.elevation, 0, self.opt.radius)
106
+ self.fixed_cam = (pose, self.cam.perspective)
107
+
108
+
109
+ self.enable_sd = self.opt.lambda_sd > 0 and self.prompt != ""
110
+ self.enable_zero123 = self.opt.lambda_zero123 > 0 and self.input_img is not None
111
+
112
+ # lazy load guidance model
113
+ if self.guidance_sd is None and self.enable_sd:
114
+ print(f"[INFO] loading SD...")
115
+ from guidance.sd_utils import StableDiffusion
116
+ self.guidance_sd = StableDiffusion(self.device)
117
+ print(f"[INFO] loaded SD!")
118
+
119
+ if self.guidance_zero123 is None and self.enable_zero123:
120
+ print(f"[INFO] loading zero123...")
121
+ from guidance.zero123_utils import Zero123
122
+ self.guidance_zero123 = Zero123(self.device)
123
+ print(f"[INFO] loaded zero123!")
124
+
125
+ # input image
126
+ if self.input_img is not None:
127
+ self.input_img_torch = torch.from_numpy(self.input_img).permute(2, 0, 1).unsqueeze(0).to(self.device)
128
+ self.input_img_torch = F.interpolate(
129
+ self.input_img_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False
130
+ )
131
+
132
+ self.input_mask_torch = torch.from_numpy(self.input_mask).permute(2, 0, 1).unsqueeze(0).to(self.device)
133
+ self.input_mask_torch = F.interpolate(
134
+ self.input_mask_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False
135
+ )
136
+ self.input_img_torch_channel_last = self.input_img_torch[0].permute(1,2,0).contiguous()
137
+
138
+ # prepare embeddings
139
+ with torch.no_grad():
140
+
141
+ if self.enable_sd:
142
+ self.guidance_sd.get_text_embeds([self.prompt], [self.negative_prompt])
143
+
144
+ if self.enable_zero123:
145
+ self.guidance_zero123.get_img_embeds(self.input_img_torch)
146
+
147
+ def train_step(self):
148
+ starter = torch.cuda.Event(enable_timing=True)
149
+ ender = torch.cuda.Event(enable_timing=True)
150
+ starter.record()
151
+
152
+
153
+ for _ in range(self.train_steps):
154
+
155
+ self.step += 1
156
+ step_ratio = min(1, self.step / self.opt.iters_refine)
157
+
158
+ loss = 0
159
+
160
+ ### known view
161
+ if self.input_img_torch is not None:
162
+
163
+ ssaa = min(2.0, max(0.125, 2 * np.random.random()))
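+ # random super/sub-sampling factor in [0.125, 2.0], re-drawn for every step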
164
+ out = self.renderer.render(*self.fixed_cam, self.opt.ref_size, self.opt.ref_size, ssaa=ssaa)
165
+
166
+ # rgb loss
167
+ image = out["image"] # [H, W, 3] in [0, 1]
168
+ valid_mask = ((out["alpha"] > 0) & (out["viewcos"] > 0.5)).detach()
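+ # supervise only pixels that are covered by the mesh and roughly front-facing (viewcos > 0.5)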
169
+ loss = loss + F.mse_loss(image * valid_mask, self.input_img_torch_channel_last * valid_mask)
170
+
171
+ ### novel view (manual batch)
172
+ render_resolution = 512
173
+ images = []
174
+ vers, hors, radii = [], [], []
175
+ # avoid too large an elevation (> 80 or < -80), and make sure the sampled range always covers [-30, 30]
176
+ min_ver = max(min(-30, -30 - self.opt.elevation), -80 - self.opt.elevation)
177
+ max_ver = min(max(30, 30 - self.opt.elevation), 80 - self.opt.elevation)
178
+ for _ in range(self.opt.batch_size):
179
+
180
+ # render random view
181
+ ver = np.random.randint(min_ver, max_ver)
182
+ hor = np.random.randint(-180, 180)
183
+ radius = 0
184
+
185
+ vers.append(ver)
186
+ hors.append(hor)
187
+ radii.append(radius)
188
+
189
+ pose = orbit_camera(self.opt.elevation + ver, hor, self.opt.radius + radius)
190
+
191
+ # random render resolution
192
+ ssaa = min(2.0, max(0.125, 2 * np.random.random()))
193
+ out = self.renderer.render(pose, self.cam.perspective, render_resolution, render_resolution, ssaa=ssaa)
194
+
195
+ image = out["image"] # [H, W, 3] in [0, 1]
196
+ image = image.permute(2,0,1).contiguous().unsqueeze(0) # [1, 3, H, W] in [0, 1]
197
+
198
+ images.append(image)
199
+
200
+ images = torch.cat(images, dim=0)
201
+
202
+ # import kiui
203
+ # kiui.lo(hor, ver)
204
+ # kiui.vis.plot_image(image)
205
+
206
+ # guidance loss
207
+ if self.enable_sd:
208
+
209
+ # loss = loss + self.opt.lambda_sd * self.guidance_sd.train_step(images, step_ratio)
210
+ refined_images = self.guidance_sd.refine(images, strength=0.6).float()
211
+ refined_images = F.interpolate(refined_images, (render_resolution, render_resolution), mode="bilinear", align_corners=False)
212
+ loss = loss + self.opt.lambda_sd * F.mse_loss(images, refined_images)
213
+
214
+ if self.enable_zero123:
215
+ # loss = loss + self.opt.lambda_zero123 * self.guidance_zero123.train_step(images, vers, hors, radii, step_ratio)
216
+ refined_images = self.guidance_zero123.refine(images, vers, hors, radii, strength=0.6).float()
217
+ refined_images = F.interpolate(refined_images, (render_resolution, render_resolution), mode="bilinear", align_corners=False)
218
+ loss = loss + self.opt.lambda_zero123 * F.mse_loss(images, refined_images)
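+ # unlike stage 1 (SDS via train_step, kept above as comments), this refinement stage denoises the current renders with the guidance model and matches them with a plain MSE loss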
219
+ # loss = loss + self.opt.lambda_zero123 * self.lpips_loss(images, refined_images)
220
+
221
+ # optimize step
222
+ loss.backward()
223
+ self.optimizer.step()
224
+ self.optimizer.zero_grad()
225
+
226
+ ender.record()
227
+ torch.cuda.synchronize()
228
+ t = starter.elapsed_time(ender)
229
+
230
+ self.need_update = True
231
+
232
+ if self.gui:
233
+ dpg.set_value("_log_train_time", f"{t:.4f}ms")
234
+ dpg.set_value(
235
+ "_log_train_log",
236
+ f"step = {self.step: 5d} (+{self.train_steps: 2d}) loss = {loss.item():.4f}",
237
+ )
238
+
239
+ # dynamic train steps (no need for now)
240
+ # max allowed train time per-frame is 500 ms
241
+ # full_t = t / self.train_steps * 16
242
+ # train_steps = min(16, max(4, int(16 * 500 / full_t)))
243
+ # if train_steps > self.train_steps * 1.2 or train_steps < self.train_steps * 0.8:
244
+ # self.train_steps = train_steps
245
+
246
+ @torch.no_grad()
247
+ def test_step(self):
248
+ # ignore if no need to update
249
+ if not self.need_update:
250
+ return
251
+
252
+ starter = torch.cuda.Event(enable_timing=True)
253
+ ender = torch.cuda.Event(enable_timing=True)
254
+ starter.record()
255
+
256
+ # should update image
257
+ if self.need_update:
258
+ # render image
259
+
260
+ out = self.renderer.render(self.cam.pose, self.cam.perspective, self.H, self.W)
261
+
262
+ buffer_image = out[self.mode] # [H, W, 3]
263
+
264
+ if self.mode in ['depth', 'alpha']:
265
+ buffer_image = buffer_image.repeat(1, 1, 3)
266
+ if self.mode == 'depth':
267
+ buffer_image = (buffer_image - buffer_image.min()) / (buffer_image.max() - buffer_image.min() + 1e-20)
268
+
269
+ self.buffer_image = buffer_image.contiguous().clamp(0, 1).detach().cpu().numpy()
270
+
271
+ # display input_image
272
+ if self.overlay_input_img and self.input_img is not None:
273
+ self.buffer_image = (
274
+ self.buffer_image * (1 - self.overlay_input_img_ratio)
275
+ + self.input_img * self.overlay_input_img_ratio
276
+ )
277
+
278
+ self.need_update = False
279
+
280
+ ender.record()
281
+ torch.cuda.synchronize()
282
+ t = starter.elapsed_time(ender)
283
+
284
+ if self.gui:
285
+ dpg.set_value("_log_infer_time", f"{t:.4f}ms ({int(1000/t)} FPS)")
286
+ dpg.set_value(
287
+ "_texture", self.buffer_image
288
+ ) # buffer must be contiguous, else seg fault!
289
+
290
+
291
+ def load_input(self, file):
292
+ # load image
293
+ print(f'[INFO] load image from {file}...')
294
+ img = cv2.imread(file, cv2.IMREAD_UNCHANGED)
295
+ if img.shape[-1] == 3:
296
+ if self.bg_remover is None:
297
+ self.bg_remover = rembg.new_session()
298
+ img = rembg.remove(img, session=self.bg_remover)
299
+
300
+ img = cv2.resize(
301
+ img, (self.W, self.H), interpolation=cv2.INTER_AREA
302
+ )
303
+ img = img.astype(np.float32) / 255.0
304
+
305
+ self.input_mask = img[..., 3:]
306
+ # white bg
307
+ self.input_img = img[..., :3] * self.input_mask + (
308
+ 1 - self.input_mask
309
+ )
310
+ # bgr to rgb
311
+ self.input_img = self.input_img[..., ::-1].copy()
312
+
313
+ # load prompt
314
+ file_prompt = file.replace("_rgba.png", "_caption.txt")
315
+ if os.path.exists(file_prompt):
316
+ print(f'[INFO] load prompt from {file_prompt}...')
317
+ with open(file_prompt, "r") as f:
318
+ self.prompt = f.read().strip()
319
+
320
+ def save_model(self):
321
+ os.makedirs(self.opt.outdir, exist_ok=True)
322
+
323
+ path = os.path.join(self.opt.outdir, self.opt.save_path + '.obj')
324
+ self.renderer.export_mesh(path)
325
+
326
+ print(f"[INFO] save model to {path}.")
327
+
328
+ def register_dpg(self):
329
+ ### register texture
330
+
331
+ with dpg.texture_registry(show=False):
332
+ dpg.add_raw_texture(
333
+ self.W,
334
+ self.H,
335
+ self.buffer_image,
336
+ format=dpg.mvFormat_Float_rgb,
337
+ tag="_texture",
338
+ )
339
+
340
+ ### register window
341
+
342
+ # the rendered image, as the primary window
343
+ with dpg.window(
344
+ tag="_primary_window",
345
+ width=self.W,
346
+ height=self.H,
347
+ pos=[0, 0],
348
+ no_move=True,
349
+ no_title_bar=True,
350
+ no_scrollbar=True,
351
+ ):
352
+ # add the texture
353
+ dpg.add_image("_texture")
354
+
355
+ # dpg.set_primary_window("_primary_window", True)
356
+
357
+ # control window
358
+ with dpg.window(
359
+ label="Control",
360
+ tag="_control_window",
361
+ width=600,
362
+ height=self.H,
363
+ pos=[self.W, 0],
364
+ no_move=True,
365
+ no_title_bar=True,
366
+ ):
367
+ # button theme
368
+ with dpg.theme() as theme_button:
369
+ with dpg.theme_component(dpg.mvButton):
370
+ dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))
371
+ dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))
372
+ dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))
373
+ dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)
374
+ dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)
375
+
376
+ # timer stuff
377
+ with dpg.group(horizontal=True):
378
+ dpg.add_text("Infer time: ")
379
+ dpg.add_text("no data", tag="_log_infer_time")
380
+
381
+ def callback_setattr(sender, app_data, user_data):
382
+ setattr(self, user_data, app_data)
383
+
384
+ # init stuff
385
+ with dpg.collapsing_header(label="Initialize", default_open=True):
386
+
387
+ # seed stuff
388
+ def callback_set_seed(sender, app_data):
389
+ self.seed = app_data
390
+ self.seed_everything()
391
+
392
+ dpg.add_input_text(
393
+ label="seed",
394
+ default_value=self.seed,
395
+ on_enter=True,
396
+ callback=callback_set_seed,
397
+ )
398
+
399
+ # input stuff
400
+ def callback_select_input(sender, app_data):
401
+ # only one item
402
+ for k, v in app_data["selections"].items():
403
+ dpg.set_value("_log_input", k)
404
+ self.load_input(v)
405
+
406
+ self.need_update = True
407
+
408
+ with dpg.file_dialog(
409
+ directory_selector=False,
410
+ show=False,
411
+ callback=callback_select_input,
412
+ file_count=1,
413
+ tag="file_dialog_tag",
414
+ width=700,
415
+ height=400,
416
+ ):
417
+ dpg.add_file_extension("Images{.jpg,.jpeg,.png}")
418
+
419
+ with dpg.group(horizontal=True):
420
+ dpg.add_button(
421
+ label="input",
422
+ callback=lambda: dpg.show_item("file_dialog_tag"),
423
+ )
424
+ dpg.add_text("", tag="_log_input")
425
+
426
+ # overlay stuff
427
+ with dpg.group(horizontal=True):
428
+
429
+ def callback_toggle_overlay_input_img(sender, app_data):
430
+ self.overlay_input_img = not self.overlay_input_img
431
+ self.need_update = True
432
+
433
+ dpg.add_checkbox(
434
+ label="overlay image",
435
+ default_value=self.overlay_input_img,
436
+ callback=callback_toggle_overlay_input_img,
437
+ )
438
+
439
+ def callback_set_overlay_input_img_ratio(sender, app_data):
440
+ self.overlay_input_img_ratio = app_data
441
+ self.need_update = True
442
+
443
+ dpg.add_slider_float(
444
+ label="ratio",
445
+ min_value=0,
446
+ max_value=1,
447
+ format="%.1f",
448
+ default_value=self.overlay_input_img_ratio,
449
+ callback=callback_set_overlay_input_img_ratio,
450
+ )
451
+
452
+ # prompt stuff
453
+
454
+ dpg.add_input_text(
455
+ label="prompt",
456
+ default_value=self.prompt,
457
+ callback=callback_setattr,
458
+ user_data="prompt",
459
+ )
460
+
461
+ dpg.add_input_text(
462
+ label="negative",
463
+ default_value=self.negative_prompt,
464
+ callback=callback_setattr,
465
+ user_data="negative_prompt",
466
+ )
467
+
468
+ # save current model
469
+ with dpg.group(horizontal=True):
470
+ dpg.add_text("Save: ")
471
+
472
+ dpg.add_button(
473
+ label="model",
474
+ tag="_button_save_model",
475
+ callback=self.save_model,
476
+ )
477
+ dpg.bind_item_theme("_button_save_model", theme_button)
478
+
479
+ dpg.add_input_text(
480
+ label="",
481
+ default_value=self.opt.save_path,
482
+ callback=callback_setattr,
483
+ user_data="save_path",
484
+ )
485
+
486
+ # training stuff
487
+ with dpg.collapsing_header(label="Train", default_open=True):
488
+ # lr and train button
489
+ with dpg.group(horizontal=True):
490
+ dpg.add_text("Train: ")
491
+
492
+ def callback_train(sender, app_data):
493
+ if self.training:
494
+ self.training = False
495
+ dpg.configure_item("_button_train", label="start")
496
+ else:
497
+ self.prepare_train()
498
+ self.training = True
499
+ dpg.configure_item("_button_train", label="stop")
500
+
501
+ # dpg.add_button(
502
+ # label="init", tag="_button_init", callback=self.prepare_train
503
+ # )
504
+ # dpg.bind_item_theme("_button_init", theme_button)
505
+
506
+ dpg.add_button(
507
+ label="start", tag="_button_train", callback=callback_train
508
+ )
509
+ dpg.bind_item_theme("_button_train", theme_button)
510
+
511
+ with dpg.group(horizontal=True):
512
+ dpg.add_text("", tag="_log_train_time")
513
+ dpg.add_text("", tag="_log_train_log")
514
+
515
+ # rendering options
516
+ with dpg.collapsing_header(label="Rendering", default_open=True):
517
+ # mode combo
518
+ def callback_change_mode(sender, app_data):
519
+ self.mode = app_data
520
+ self.need_update = True
521
+
522
+ dpg.add_combo(
523
+ ("image", "depth", "alpha", "normal"),
524
+ label="mode",
525
+ default_value=self.mode,
526
+ callback=callback_change_mode,
527
+ )
528
+
529
+ # fov slider
530
+ def callback_set_fovy(sender, app_data):
531
+ self.cam.fovy = np.deg2rad(app_data)
532
+ self.need_update = True
533
+
534
+ dpg.add_slider_int(
535
+ label="FoV (vertical)",
536
+ min_value=1,
537
+ max_value=120,
538
+ format="%d deg",
539
+ default_value=np.rad2deg(self.cam.fovy),
540
+ callback=callback_set_fovy,
541
+ )
542
+
543
+ ### register camera handler
544
+
545
+ def callback_camera_drag_rotate_or_draw_mask(sender, app_data):
546
+ if not dpg.is_item_focused("_primary_window"):
547
+ return
548
+
549
+ dx = app_data[1]
550
+ dy = app_data[2]
551
+
552
+ self.cam.orbit(dx, dy)
553
+ self.need_update = True
554
+
555
+ def callback_camera_wheel_scale(sender, app_data):
556
+ if not dpg.is_item_focused("_primary_window"):
557
+ return
558
+
559
+ delta = app_data
560
+
561
+ self.cam.scale(delta)
562
+ self.need_update = True
563
+
564
+ def callback_camera_drag_pan(sender, app_data):
565
+ if not dpg.is_item_focused("_primary_window"):
566
+ return
567
+
568
+ dx = app_data[1]
569
+ dy = app_data[2]
570
+
571
+ self.cam.pan(dx, dy)
572
+ self.need_update = True
573
+
574
+ def callback_set_mouse_loc(sender, app_data):
575
+ if not dpg.is_item_focused("_primary_window"):
576
+ return
577
+
578
+ # just the pixel coordinate in image
579
+ self.mouse_loc = np.array(app_data)
580
+
581
+ with dpg.handler_registry():
582
+ # for camera moving
583
+ dpg.add_mouse_drag_handler(
584
+ button=dpg.mvMouseButton_Left,
585
+ callback=callback_camera_drag_rotate_or_draw_mask,
586
+ )
587
+ dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
588
+ dpg.add_mouse_drag_handler(
589
+ button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan
590
+ )
591
+
592
+ dpg.create_viewport(
593
+ title="Gaussian3D",
594
+ width=self.W + 600,
595
+ height=self.H + (45 if os.name == "nt" else 0),
596
+ resizable=False,
597
+ )
598
+
599
+ ### global theme
600
+ with dpg.theme() as theme_no_padding:
601
+ with dpg.theme_component(dpg.mvAll):
602
+ # set all padding to 0 to avoid scroll bar
603
+ dpg.add_theme_style(
604
+ dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core
605
+ )
606
+ dpg.add_theme_style(
607
+ dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core
608
+ )
609
+ dpg.add_theme_style(
610
+ dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core
611
+ )
612
+
613
+ dpg.bind_item_theme("_primary_window", theme_no_padding)
614
+
615
+ dpg.setup_dearpygui()
616
+
617
+ ### register a larger font
618
+ # get it from: https://github.com/lxgw/LxgwWenKai/releases/download/v1.300/LXGWWenKai-Regular.ttf
619
+ if os.path.exists("LXGWWenKai-Regular.ttf"):
620
+ with dpg.font_registry():
621
+ with dpg.font("LXGWWenKai-Regular.ttf", 18) as default_font:
622
+ dpg.bind_font(default_font)
623
+
624
+ # dpg.show_metrics()
625
+
626
+ dpg.show_viewport()
627
+
628
+ def render(self):
629
+ assert self.gui
630
+ while dpg.is_dearpygui_running():
631
+ # update texture every frame
632
+ if self.training:
633
+ self.train_step()
634
+ self.test_step()
635
+ dpg.render_dearpygui_frame()
636
+
637
+ # no gui mode
638
+ def train(self, iters=500):
639
+ if iters > 0:
640
+ self.prepare_train()
641
+ for i in tqdm.trange(iters):
642
+ self.train_step()
643
+ # save
644
+ self.save_model()
645
+
646
+
647
+ if __name__ == "__main__":
648
+ import argparse
649
+ from omegaconf import OmegaConf
650
+
651
+ parser = argparse.ArgumentParser()
652
+ parser.add_argument("--config", required=True, help="path to the yaml config file")
653
+ args, extras = parser.parse_known_args()
654
+
655
+ # override default config from cli
656
+ opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras))
657
+
658
+ # auto find mesh from stage 1
659
+ if opt.mesh is None:
660
+ default_path = os.path.join(opt.outdir, opt.save_path + '_mesh.obj')
661
+ if os.path.exists(default_path):
662
+ opt.mesh = default_path
663
+ else:
664
+ raise ValueError(f"Cannot find mesh from {default_path}, must specify --mesh explicitly!")
665
+
666
+ gui = GUI(opt)
667
+
668
+ if opt.gui:
669
+ gui.render()
670
+ else:
671
+ gui.train(opt.iters_refine)
mesh.py ADDED
@@ -0,0 +1,394 @@
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import trimesh
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ def dot(x, y):
10
+ return torch.sum(x * y, -1, keepdim=True)
11
+
12
+
13
+ def length(x, eps=1e-20):
14
+ return torch.sqrt(torch.clamp(dot(x, x), min=eps))
15
+
16
+
17
+ def safe_normalize(x, eps=1e-20):
18
+ return x / length(x, eps)
19
+
20
+
21
+ class Mesh:
22
+ def __init__(
23
+ self,
24
+ v=None,
25
+ f=None,
26
+ vn=None,
27
+ fn=None,
28
+ vt=None,
29
+ ft=None,
30
+ albedo=None,
31
+ device=None,
32
+ ):
33
+ self.device = device
34
+ self.v = v
35
+ self.vn = vn
36
+ self.vt = vt
37
+ self.f = f
38
+ self.fn = fn
39
+ self.ft = ft
40
+ # only support a single albedo
41
+ self.albedo = albedo
42
+
43
+ self.ori_center = 0
44
+ self.ori_scale = 1
45
+
46
+ @classmethod
47
+ def load(cls, path=None, resize=True, **kwargs):
48
+ # assume init with kwargs
49
+ if path is None:
50
+ mesh = cls(**kwargs)
51
+ # obj supports face uv
52
+ elif path.endswith(".obj"):
53
+ mesh = cls.load_obj(path, **kwargs)
54
+ # trimesh only supports vertex uv, but can load more formats
55
+ else:
56
+ mesh = cls.load_trimesh(path, **kwargs)
57
+
58
+ print(f"[Mesh loading] v: {mesh.v.shape}, f: {mesh.f.shape}")
59
+ # auto-normalize
60
+ if resize:
61
+ mesh.auto_size()
62
+ # auto-fix normal
63
+ if mesh.vn is None:
64
+ mesh.auto_normal()
65
+ print(f"[Mesh loading] vn: {mesh.vn.shape}, fn: {mesh.fn.shape}")
66
+ # auto-fix texture
67
+ if mesh.vt is None:
68
+ mesh.auto_uv(cache_path=path)
69
+ print(f"[Mesh loading] vt: {mesh.vt.shape}, ft: {mesh.ft.shape}")
70
+
71
+ return mesh
72
+
73
+ # load from obj file
74
+ @classmethod
75
+ def load_obj(cls, path, albedo_path=None, device=None, init_empty_tex=False):
76
+ assert os.path.splitext(path)[-1] == ".obj"
77
+
78
+ mesh = cls()
79
+
80
+ # device
81
+ if device is None:
82
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
83
+
84
+ mesh.device = device
85
+
86
+ # try to find texture from mtl file
87
+ if albedo_path is None:
88
+ mtl_path = path.replace(".obj", ".mtl")
89
+ if os.path.exists(mtl_path):
90
+ with open(mtl_path, "r") as f:
91
+ lines = f.readlines()
92
+ for line in lines:
93
+ split_line = line.split()
94
+ # empty line
95
+ if len(split_line) == 0:
96
+ continue
97
+ prefix = split_line[0]
98
+ # NOTE: simply use the first map_Kd as albedo!
99
+ if "map_Kd" in prefix:
100
+ albedo_path = os.path.join(os.path.dirname(path), split_line[1])
101
+ print(f"[load_obj] use texture from: {albedo_path}")
102
+ break
103
+
104
+ if init_empty_tex or albedo_path is None or not os.path.exists(albedo_path):
105
+ # init an empty texture
106
+ print(f"[load_obj] init empty albedo!")
107
+ # albedo = np.random.rand(1024, 1024, 3).astype(np.float32)
108
+ albedo = np.ones((1024, 1024, 3), dtype=np.float32) * np.array(
109
+ [0.5, 0.5, 0.5]
110
+ ) # default color
111
+ else:
112
+ albedo = cv2.imread(albedo_path, cv2.IMREAD_UNCHANGED)
113
+ albedo = cv2.cvtColor(albedo, cv2.COLOR_BGR2RGB)
114
+ albedo = albedo.astype(np.float32) / 255
115
+ print(f"[load_obj] load texture: {albedo.shape}")
116
+
117
+ # import matplotlib.pyplot as plt
118
+ # plt.imshow(albedo)
119
+ # plt.show()
120
+
121
+ mesh.albedo = torch.tensor(albedo, dtype=torch.float32, device=device)
122
+
123
+ # load obj
124
+ with open(path, "r") as f:
125
+ lines = f.readlines()
126
+
127
+ def parse_f_v(fv):
128
+ # pass in a vertex term of a face, return {v, vt, vn} (-1 if not provided)
129
+ # supported forms:
130
+ # f v1 v2 v3
131
+ # f v1/vt1 v2/vt2 v3/vt3
132
+ # f v1/vt1/vn1 v2/vt2/vn2 v3/vt3/vn3
133
+ # f v1//vn1 v2//vn2 v3//vn3
134
+ xs = [int(x) - 1 if x != "" else -1 for x in fv.split("/")]
135
+ xs.extend([-1] * (3 - len(xs)))
136
+ return xs[0], xs[1], xs[2]
137
+
138
+ # NOTE: we ignore usemtl, and assume the mesh ONLY uses one material (first in mtl)
139
+ vertices, texcoords, normals = [], [], []
140
+ faces, tfaces, nfaces = [], [], []
141
+ for line in lines:
142
+ split_line = line.split()
143
+ # empty line
144
+ if len(split_line) == 0:
145
+ continue
146
+ # v/vn/vt
147
+ prefix = split_line[0].lower()
148
+ if prefix == "v":
149
+ vertices.append([float(v) for v in split_line[1:]])
150
+ elif prefix == "vn":
151
+ normals.append([float(v) for v in split_line[1:]])
152
+ elif prefix == "vt":
153
+ val = [float(v) for v in split_line[1:]]
154
+ texcoords.append([val[0], 1.0 - val[1]])
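+ # flip V: OBJ uses a bottom-left UV origin, while the albedo images loaded here assume a top-left origin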
155
+ elif prefix == "f":
156
+ vs = split_line[1:]
157
+ nv = len(vs)
158
+ v0, t0, n0 = parse_f_v(vs[0])
159
+ for i in range(nv - 2): # triangulate (assume vertices are ordered)
160
+ v1, t1, n1 = parse_f_v(vs[i + 1])
161
+ v2, t2, n2 = parse_f_v(vs[i + 2])
162
+ faces.append([v0, v1, v2])
163
+ tfaces.append([t0, t1, t2])
164
+ nfaces.append([n0, n1, n2])
165
+
166
+ mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)
167
+ mesh.vt = (
168
+ torch.tensor(texcoords, dtype=torch.float32, device=device)
169
+ if len(texcoords) > 0
170
+ else None
171
+ )
172
+ mesh.vn = (
173
+ torch.tensor(normals, dtype=torch.float32, device=device)
174
+ if len(normals) > 0
175
+ else None
176
+ )
177
+
178
+ mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)
179
+ mesh.ft = (
180
+ torch.tensor(tfaces, dtype=torch.int32, device=device)
181
+ if texcoords is not None
182
+ else None
183
+ )
184
+ mesh.fn = (
185
+ torch.tensor(nfaces, dtype=torch.int32, device=device)
186
+ if normals is not None
187
+ else None
188
+ )
189
+
190
+ return mesh
191
+
192
+ @classmethod
193
+ def load_trimesh(cls, path, device=None):
194
+ mesh = cls()
195
+
196
+ # device
197
+ if device is None:
198
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
199
+
200
+ mesh.device = device
201
+
202
+ # use trimesh to load glb; assume the file contains only a single mesh
203
+ _data = trimesh.load(path)
204
+ if isinstance(_data, trimesh.Scene):
205
+ mesh_keys = list(_data.geometry.keys())
206
+ assert (
207
+ len(mesh_keys) == 1
208
+ ), f"{path} contains more than one mesh, not supported!"
209
+ _mesh = _data.geometry[mesh_keys[0]]
210
+
211
+ elif isinstance(_data, trimesh.Trimesh):
212
+ _mesh = _data
213
+
214
+ else:
215
+ raise NotImplementedError(f"type {type(_data)} not supported!")
216
+
217
+ # TODO: exception handling if no material
218
+ _material = _mesh.visual.material
219
+ if isinstance(_material, trimesh.visual.material.PBRMaterial):
220
+ texture = np.array(_material.baseColorTexture).astype(np.float32) / 255
221
+ elif isinstance(_material, trimesh.visual.material.SimpleMaterial):
222
+ texture = (
223
+ np.array(_material.to_pbr().baseColorTexture).astype(np.float32) / 255
224
+ )
225
+ else:
226
+ raise NotImplementedError(f"material type {type(_material)} not supported!")
227
+
228
+ print(f"[load_obj] load texture: {texture.shape}")
229
+ mesh.albedo = torch.tensor(texture, dtype=torch.float32, device=device)
230
+
231
+ vertices = _mesh.vertices
232
+ texcoords = _mesh.visual.uv
233
+ texcoords[:, 1] = 1 - texcoords[:, 1]
234
+ normals = _mesh.vertex_normals
235
+
236
+ # trimesh only supports vertex uv...
237
+ faces = tfaces = nfaces = _mesh.faces
238
+
239
+ mesh.v = torch.tensor(vertices, dtype=torch.float32, device=device)
240
+ mesh.vt = (
241
+ torch.tensor(texcoords, dtype=torch.float32, device=device)
242
+ if len(texcoords) > 0
243
+ else None
244
+ )
245
+ mesh.vn = (
246
+ torch.tensor(normals, dtype=torch.float32, device=device)
247
+ if len(normals) > 0
248
+ else None
249
+ )
250
+
251
+ mesh.f = torch.tensor(faces, dtype=torch.int32, device=device)
252
+ mesh.ft = (
253
+ torch.tensor(tfaces, dtype=torch.int32, device=device)
254
+ if texcoords is not None
255
+ else None
256
+ )
257
+ mesh.fn = (
258
+ torch.tensor(nfaces, dtype=torch.int32, device=device)
259
+ if normals is not None
260
+ else None
261
+ )
262
+
263
+ return mesh
264
+
265
+ # aabb
266
+ def aabb(self):
267
+ return torch.min(self.v, dim=0).values, torch.max(self.v, dim=0).values
268
+
269
+ # unit size
270
+ @torch.no_grad()
271
+ def auto_size(self):
272
+ vmin, vmax = self.aabb()
273
+ self.ori_center = (vmax + vmin) / 2
274
+ self.ori_scale = 1.2 / torch.max(vmax - vmin).item() # to ~ [-0.6, 0.6]
275
+ self.v = (self.v - self.ori_center) * self.ori_scale
276
+
277
+ def auto_normal(self):
278
+ i0, i1, i2 = self.f[:, 0].long(), self.f[:, 1].long(), self.f[:, 2].long()
279
+ v0, v1, v2 = self.v[i0, :], self.v[i1, :], self.v[i2, :]
280
+
281
+ face_normals = torch.cross(v1 - v0, v2 - v0)
282
+
283
+ # Splat face normals to vertices
284
+ vn = torch.zeros_like(self.v)
285
+ vn.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
286
+ vn.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
287
+ vn.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
288
+
289
+ # Normalize, replacing zero (degenerate) normals with a default (+z)
290
+ vn = torch.where(
291
+ dot(vn, vn) > 1e-20,
292
+ vn,
293
+ torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device),
294
+ )
295
+ vn = safe_normalize(vn)
296
+
297
+ self.vn = vn
298
+ self.fn = self.f
299
+
300
+ def auto_uv(self, cache_path=None):
301
+ # try to load cache
302
+ if cache_path is not None:
303
+ cache_path = cache_path.replace(".obj", "_uv.npz")
304
+
305
+ if cache_path is not None and os.path.exists(cache_path):
306
+ data = np.load(cache_path)
307
+ vt_np, ft_np = data["vt"], data["ft"]
308
+ else:
309
+ import xatlas
310
+
311
+ v_np = self.v.detach().cpu().numpy()
312
+ f_np = self.f.detach().int().cpu().numpy()
313
+ atlas = xatlas.Atlas()
314
+ atlas.add_mesh(v_np, f_np)
315
+ chart_options = xatlas.ChartOptions()
316
+ # chart_options.max_iterations = 4
317
+ atlas.generate(chart_options=chart_options)
318
+ vmapping, ft_np, vt_np = atlas[0] # [N], [M, 3], [N, 2]
319
+
320
+ # save to cache
321
+ if cache_path is not None:
322
+ np.savez(cache_path, vt=vt_np, ft=ft_np)
323
+
324
+ vt = torch.from_numpy(vt_np.astype(np.float32)).to(self.device)
325
+ ft = torch.from_numpy(ft_np.astype(np.int32)).to(self.device)
326
+
327
+ self.vt = vt
328
+ self.ft = ft
329
+
330
+ def to(self, device):
331
+ self.device = device
332
+ for name in ["v", "f", "vn", "fn", "vt", "ft", "albedo"]:
333
+ tensor = getattr(self, name)
334
+ if tensor is not None:
335
+ setattr(self, name, tensor.to(device))
336
+ return self
337
+
338
+ # write to ply file (only geom)
339
+ def write_ply(self, path):
340
+ assert path.endswith(".ply")
341
+
342
+ v_np = self.v.detach().cpu().numpy()
343
+ f_np = self.f.detach().cpu().numpy()
344
+
345
+ _mesh = trimesh.Trimesh(vertices=v_np, faces=f_np)
346
+ _mesh.export(path)
347
+
348
+ # write to obj file
349
+ def write(self, path):
350
+ mtl_path = path.replace(".obj", ".mtl")
351
+ albedo_path = path.replace(".obj", "_albedo.png")
352
+
353
+ v_np = self.v.detach().cpu().numpy()
354
+ vt_np = self.vt.detach().cpu().numpy() if self.vt is not None else None
355
+ vn_np = self.vn.detach().cpu().numpy() if self.vn is not None else None
356
+ f_np = self.f.detach().cpu().numpy()
357
+ ft_np = self.ft.detach().cpu().numpy() if self.ft is not None else None
358
+ fn_np = self.fn.detach().cpu().numpy() if self.fn is not None else None
359
+
360
+ with open(path, "w") as fp:
361
+ fp.write(f"mtllib {os.path.basename(mtl_path)} \n")
362
+
363
+ for v in v_np:
364
+ fp.write(f"v {v[0]} {v[1]} {v[2]} \n")
365
+
366
+ if vt_np is not None:
367
+ for v in vt_np:
368
+ fp.write(f"vt {v[0]} {1 - v[1]} \n")
369
+
370
+ if vn_np is not None:
371
+ for v in vn_np:
372
+ fp.write(f"vn {v[0]} {v[1]} {v[2]} \n")
373
+
374
+ fp.write(f"usemtl defaultMat \n")
375
+ for i in range(len(f_np)):
376
+ fp.write(
377
+ f'f {f_np[i, 0] + 1}/{ft_np[i, 0] + 1 if ft_np is not None else ""}/{fn_np[i, 0] + 1 if fn_np is not None else ""} \
378
+ {f_np[i, 1] + 1}/{ft_np[i, 1] + 1 if ft_np is not None else ""}/{fn_np[i, 1] + 1 if fn_np is not None else ""} \
379
+ {f_np[i, 2] + 1}/{ft_np[i, 2] + 1 if ft_np is not None else ""}/{fn_np[i, 2] + 1 if fn_np is not None else ""} \n'
380
+ )
381
+
382
+ with open(mtl_path, "w") as fp:
383
+ fp.write(f"newmtl defaultMat \n")
384
+ fp.write(f"Ka 1 1 1 \n")
385
+ fp.write(f"Kd 1 1 1 \n")
386
+ fp.write(f"Ks 0 0 0 \n")
387
+ fp.write(f"Tr 1 \n")
388
+ fp.write(f"illum 1 \n")
389
+ fp.write(f"Ns 0 \n")
390
+ fp.write(f"map_Kd {os.path.basename(albedo_path)} \n")
391
+
392
+ albedo = self.albedo.detach().cpu().numpy()
393
+ albedo = (albedo * 255).astype(np.uint8)
394
+ cv2.imwrite(albedo_path, cv2.cvtColor(albedo, cv2.COLOR_RGB2BGR))
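
A quick, hedged usage sketch of the `Mesh` class defined in `mesh.py` above; the file paths are placeholders, and the glb is assumed to contain a single textured mesh as `load_trimesh` requires:

```python
# Illustrative sketch only: load a single-mesh glb, normalize it, and export it.
import torch
from mesh import Mesh

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mesh = Mesh.load_trimesh("model.glb", device=device)   # hypothetical input path

mesh.auto_size()      # recenter and rescale vertices to roughly [-0.6, 0.6]
mesh.auto_normal()    # recompute smooth per-vertex normals

mesh.write_ply("model_geom.ply")   # geometry only, via trimesh
mesh.write("model.obj")            # writes model.obj, model.mtl and model_albedo.png
```
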
mesh_renderer.py ADDED
@@ -0,0 +1,154 @@
1
+ import os
2
+ import math
3
+ import cv2
4
+ import trimesh
5
+ import numpy as np
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ import nvdiffrast.torch as dr
12
+ from mesh import Mesh, safe_normalize
13
+
14
+ def scale_img_nhwc(x, size, mag='bilinear', min='bilinear'):
15
+ assert (x.shape[1] >= size[0] and x.shape[2] >= size[1]) or (x.shape[1] < size[0] and x.shape[2] < size[1]), "Trying to magnify image in one dimension and minify in the other"
16
+ y = x.permute(0, 3, 1, 2) # NHWC -> NCHW
17
+ if x.shape[1] > size[0] and x.shape[2] > size[1]: # Minification, previous size was bigger
18
+ y = torch.nn.functional.interpolate(y, size, mode=min)
19
+ else: # Magnification
20
+ if mag == 'bilinear' or mag == 'bicubic':
21
+ y = torch.nn.functional.interpolate(y, size, mode=mag, align_corners=True)
22
+ else:
23
+ y = torch.nn.functional.interpolate(y, size, mode=mag)
24
+ return y.permute(0, 2, 3, 1).contiguous() # NCHW -> NHWC
25
+
26
+ def scale_img_hwc(x, size, mag='bilinear', min='bilinear'):
27
+ return scale_img_nhwc(x[None, ...], size, mag, min)[0]
28
+
29
+ def scale_img_nhw(x, size, mag='bilinear', min='bilinear'):
30
+ return scale_img_nhwc(x[..., None], size, mag, min)[..., 0]
31
+
32
+ def scale_img_hw(x, size, mag='bilinear', min='bilinear'):
33
+ return scale_img_nhwc(x[None, ..., None], size, mag, min)[0, ..., 0]
34
+
35
+ def trunc_rev_sigmoid(x, eps=1e-6):
36
+ x = x.clamp(eps, 1 - eps)
37
+ return torch.log(x / (1 - x))
38
+
39
+ def make_divisible(x, m=8):
40
+ return int(math.ceil(x / m) * m)
41
+
42
+ class Renderer(nn.Module):
43
+ def __init__(self, opt):
44
+
45
+ super().__init__()
46
+
47
+ self.opt = opt
48
+
49
+ self.mesh = Mesh.load(self.opt.mesh, resize=False)
50
+
51
+ if not self.opt.gui or os.name == 'nt':
52
+ self.glctx = dr.RasterizeGLContext()
53
+ else:
54
+ self.glctx = dr.RasterizeCudaContext()
55
+
56
+ # extract trainable parameters
57
+ self.v_offsets = nn.Parameter(torch.zeros_like(self.mesh.v))
58
+ self.raw_albedo = nn.Parameter(trunc_rev_sigmoid(self.mesh.albedo))
59
+
60
+
61
+ def get_params(self):
62
+
63
+ params = [
64
+ {'params': self.raw_albedo, 'lr': self.opt.texture_lr},
65
+ ]
66
+
67
+ if self.opt.train_geo:
68
+ params.append({'params': self.v_offsets, 'lr': self.opt.geom_lr})
69
+
70
+ return params
71
+
72
+ @torch.no_grad()
73
+ def export_mesh(self, save_path):
74
+ self.mesh.v = (self.mesh.v + self.v_offsets).detach()
75
+ self.mesh.albedo = torch.sigmoid(self.raw_albedo.detach())
76
+ self.mesh.write(save_path)
77
+
78
+
79
+ def render(self, pose, proj, h0, w0, ssaa=1, bg_color=1, texture_filter='linear-mipmap-linear'):
80
+
81
+ # do super-sampling
82
+ if ssaa != 1:
83
+ h = make_divisible(h0 * ssaa, 8)
84
+ w = make_divisible(w0 * ssaa, 8)
85
+ else:
86
+ h, w = h0, w0
87
+
88
+ results = {}
89
+
90
+ # get v
91
+ if self.opt.train_geo:
92
+ v = self.mesh.v + self.v_offsets # [N, 3]
93
+ else:
94
+ v = self.mesh.v
95
+
96
+ pose = torch.from_numpy(pose.astype(np.float32)).to(v.device)
97
+ proj = torch.from_numpy(proj.astype(np.float32)).to(v.device)
98
+
99
+ # get v_clip and render rgb
100
+ v_cam = torch.matmul(F.pad(v, pad=(0, 1), mode='constant', value=1.0), torch.inverse(pose).T).float().unsqueeze(0)
101
+ v_clip = v_cam @ proj.T
102
+
103
+ rast, rast_db = dr.rasterize(self.glctx, v_clip, self.mesh.f, (h, w))
104
+
105
+ alpha = (rast[0, ..., 3:] > 0).float()
106
+ depth, _ = dr.interpolate(-v_cam[..., [2]], rast, self.mesh.f) # [1, H, W, 1]
107
+ depth = depth.squeeze(0) # [H, W, 1]
108
+
109
+ texc, texc_db = dr.interpolate(self.mesh.vt.unsqueeze(0).contiguous(), rast, self.mesh.ft, rast_db=rast_db, diff_attrs='all')
110
+ albedo = dr.texture(self.raw_albedo.unsqueeze(0), texc, uv_da=texc_db, filter_mode=texture_filter) # [1, H, W, 3]
111
+ albedo = torch.sigmoid(albedo)
112
+ # get vn and render normal
113
+ if self.opt.train_geo:
114
+ i0, i1, i2 = self.mesh.f[:, 0].long(), self.mesh.f[:, 1].long(), self.mesh.f[:, 2].long()
115
+ v0, v1, v2 = v[i0, :], v[i1, :], v[i2, :]
116
+
117
+ face_normals = torch.cross(v1 - v0, v2 - v0)
118
+ face_normals = safe_normalize(face_normals)
119
+
120
+ vn = torch.zeros_like(v)
121
+ vn.scatter_add_(0, i0[:, None].repeat(1,3), face_normals)
122
+ vn.scatter_add_(0, i1[:, None].repeat(1,3), face_normals)
123
+ vn.scatter_add_(0, i2[:, None].repeat(1,3), face_normals)
124
+
125
+ vn = torch.where(torch.sum(vn * vn, -1, keepdim=True) > 1e-20, vn, torch.tensor([0.0, 0.0, 1.0], dtype=torch.float32, device=vn.device))
126
+ else:
127
+ vn = self.mesh.vn
128
+
129
+ normal, _ = dr.interpolate(vn.unsqueeze(0).contiguous(), rast, self.mesh.fn)
130
+ normal = safe_normalize(normal[0])
131
+
132
+ # rotated normal (where [0, 0, 1] always faces camera)
133
+ rot_normal = normal @ pose[:3, :3]
134
+ viewcos = rot_normal[..., [2]]
135
+
136
+ # antialias
137
+ albedo = dr.antialias(albedo, rast, v_clip, self.mesh.f).squeeze(0) # [H, W, 3]
138
+ albedo = alpha * albedo + (1 - alpha) * bg_color
139
+
140
+ # ssaa
141
+ if ssaa != 1:
142
+ albedo = scale_img_hwc(albedo, (h0, w0))
143
+ alpha = scale_img_hwc(alpha, (h0, w0))
144
+ depth = scale_img_hwc(depth, (h0, w0))
145
+ normal = scale_img_hwc(normal, (h0, w0))
146
+ viewcos = scale_img_hwc(viewcos, (h0, w0))
147
+
148
+ results['image'] = albedo.clamp(0, 1)
149
+ results['alpha'] = alpha
150
+ results['depth'] = depth
151
+ results['normal'] = (normal + 1) / 2
152
+ results['viewcos'] = viewcos
153
+
154
+ return results
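
For context, a minimal sketch of driving `Renderer.render` with a hand-built camera. The OmegaConf fields mirror how the training scripts appear to configure this class, and the pose/projection construction is an assumption, not part of this file:

```python
# Hedged sketch: render one view of a textured mesh with the Renderer above.
import numpy as np
from omegaconf import OmegaConf
from mesh_renderer import Renderer

opt = OmegaConf.create({
    "mesh": "logs/name_mesh.obj",   # hypothetical coarse mesh path
    "gui": False,
    "train_geo": False,
    "texture_lr": 0.2,              # only used by get_params()
    "geom_lr": 1e-4,
})
renderer = Renderer(opt)            # mesh/albedo tensors live on CUDA by default

# cam2world pose: camera 2 units along +z looking at the origin (OpenGL convention)
pose = np.eye(4, dtype=np.float32)
pose[2, 3] = 2.0

# standard OpenGL perspective projection (square image, so no aspect term)
fovy, near, far = np.deg2rad(49.1), 0.01, 100.0
f = 1.0 / np.tan(fovy / 2)
proj = np.array([
    [f, 0, 0, 0],
    [0, f, 0, 0],
    [0, 0, -(far + near) / (far - near), -(2 * far * near) / (far - near)],
    [0, 0, -1, 0],
], dtype=np.float32)

out = renderer.render(pose, proj, 512, 512, ssaa=2)
rgb = out["image"].detach().cpu().numpy()   # [512, 512, 3], values in [0, 1]
```
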
mesh_utils.py ADDED
@@ -0,0 +1,147 @@
1
+ import numpy as np
2
+ import pymeshlab as pml
3
+
4
+
5
+ def poisson_mesh_reconstruction(points, normals=None):
6
+ # points/normals: [N, 3] np.ndarray
7
+
8
+ import open3d as o3d
9
+
10
+ pcd = o3d.geometry.PointCloud()
11
+ pcd.points = o3d.utility.Vector3dVector(points)
12
+
13
+ # outlier removal
14
+ pcd, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=10)
15
+
16
+ # normals
17
+ if normals is None:
18
+ pcd.estimate_normals()
19
+ else:
20
+ pcd.normals = o3d.utility.Vector3dVector(normals[ind])
21
+
22
+ # visualize
23
+ o3d.visualization.draw_geometries([pcd], point_show_normal=False)
24
+
25
+ mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
26
+ pcd, depth=9
27
+ )
28
+ vertices_to_remove = densities < np.quantile(densities, 0.1)
29
+ mesh.remove_vertices_by_mask(vertices_to_remove)
30
+
31
+ # visualize
32
+ o3d.visualization.draw_geometries([mesh])
33
+
34
+ vertices = np.asarray(mesh.vertices)
35
+ triangles = np.asarray(mesh.triangles)
36
+
37
+ print(
38
+ f"[INFO] poisson mesh reconstruction: {points.shape} --> {vertices.shape} / {triangles.shape}"
39
+ )
40
+
41
+ return vertices, triangles
42
+
43
+
44
+ def decimate_mesh(
45
+ verts, faces, target, backend="pymeshlab", remesh=False, optimalplacement=True
46
+ ):
47
+ # optimalplacement: defaults to True, but must be set to False for flat meshes to prevent spike artifacts.
48
+
49
+ _ori_vert_shape = verts.shape
50
+ _ori_face_shape = faces.shape
51
+
52
+ if backend == "pyfqmr":
53
+ import pyfqmr
54
+
55
+ solver = pyfqmr.Simplify()
56
+ solver.setMesh(verts, faces)
57
+ solver.simplify_mesh(target_count=target, preserve_border=False, verbose=False)
58
+ verts, faces, normals = solver.getMesh()
59
+ else:
60
+ m = pml.Mesh(verts, faces)
61
+ ms = pml.MeshSet()
62
+ ms.add_mesh(m, "mesh") # will copy!
63
+
64
+ # filters
65
+ # ms.meshing_decimation_clustering(threshold=pml.Percentage(1))
66
+ ms.meshing_decimation_quadric_edge_collapse(
67
+ targetfacenum=int(target), optimalplacement=optimalplacement
68
+ )
69
+
70
+ if remesh:
71
+ # ms.apply_coord_taubin_smoothing()
72
+ ms.meshing_isotropic_explicit_remeshing(
73
+ iterations=3, targetlen=pml.Percentage(1)
74
+ )
75
+
76
+ # extract mesh
77
+ m = ms.current_mesh()
78
+ verts = m.vertex_matrix()
79
+ faces = m.face_matrix()
80
+
81
+ print(
82
+ f"[INFO] mesh decimation: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}"
83
+ )
84
+
85
+ return verts, faces
86
+
87
+
88
+ def clean_mesh(
89
+ verts,
90
+ faces,
91
+ v_pct=1,
92
+ min_f=64,
93
+ min_d=20,
94
+ repair=True,
95
+ remesh=True,
96
+ remesh_size=0.01,
97
+ ):
98
+ # verts: [N, 3]
99
+ # faces: [N, 3]
100
+
101
+ _ori_vert_shape = verts.shape
102
+ _ori_face_shape = faces.shape
103
+
104
+ m = pml.Mesh(verts, faces)
105
+ ms = pml.MeshSet()
106
+ ms.add_mesh(m, "mesh") # will copy!
107
+
108
+ # filters
109
+ ms.meshing_remove_unreferenced_vertices() # verts not referenced by any faces
110
+
111
+ if v_pct > 0:
112
+ ms.meshing_merge_close_vertices(
113
+ threshold=pml.Percentage(v_pct)
114
+ ) # 1/10000 of bounding box diagonal
115
+
116
+ ms.meshing_remove_duplicate_faces() # faces defined by the same verts
117
+ ms.meshing_remove_null_faces() # faces with area == 0
118
+
119
+ if min_d > 0:
120
+ ms.meshing_remove_connected_component_by_diameter(
121
+ mincomponentdiag=pml.Percentage(min_d)
122
+ )
123
+
124
+ if min_f > 0:
125
+ ms.meshing_remove_connected_component_by_face_number(mincomponentsize=min_f)
126
+
127
+ if repair:
128
+ # ms.meshing_remove_t_vertices(method=0, threshold=40, repeat=True)
129
+ ms.meshing_repair_non_manifold_edges(method=0)
130
+ ms.meshing_repair_non_manifold_vertices(vertdispratio=0)
131
+
132
+ if remesh:
133
+ # ms.apply_coord_taubin_smoothing()
134
+ ms.meshing_isotropic_explicit_remeshing(
135
+ iterations=3, targetlen=pml.AbsoluteValue(remesh_size)
136
+ )
137
+
138
+ # extract mesh
139
+ m = ms.current_mesh()
140
+ verts = m.vertex_matrix()
141
+ faces = m.face_matrix()
142
+
143
+ print(
144
+ f"[INFO] mesh cleaning: {_ori_vert_shape} --> {verts.shape}, {_ori_face_shape} --> {faces.shape}"
145
+ )
146
+
147
+ return verts, faces
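
A hedged example of chaining the two pymeshlab helpers above on an arbitrary mesh; the input/output paths and the target face count are illustrative:

```python
# Sketch: clean up a raw mesh, then decimate it to ~50k faces.
import numpy as np
import trimesh
from mesh_utils import clean_mesh, decimate_mesh

m = trimesh.load("input.obj", force="mesh")          # hypothetical input path
verts = np.asarray(m.vertices, dtype=np.float64)     # pymeshlab expects float64 / int32
faces = np.asarray(m.faces, dtype=np.int32)

# merge close vertices, drop tiny components, and remesh to a uniform edge length
verts, faces = clean_mesh(verts, faces, remesh=True, remesh_size=0.015)

# reduce the triangle budget for faster texturing / rendering
verts, faces = decimate_mesh(verts, faces, target=5e4)

trimesh.Trimesh(vertices=verts, faces=faces, process=False).export("output.obj")
```
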
process.py ADDED
@@ -0,0 +1,92 @@
1
+ import os
2
+ import glob
3
+ import sys
4
+ import cv2
5
+ import argparse
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from torchvision import transforms
13
+ from PIL import Image
14
+ import rembg
15
+
16
+ class BLIP2():
17
+ def __init__(self, device='cuda'):
18
+ self.device = device
19
+ from transformers import AutoProcessor, Blip2ForConditionalGeneration
20
+ self.processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
21
+ self.model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16).to(device)
22
+
23
+ @torch.no_grad()
24
+ def __call__(self, image):
25
+ image = Image.fromarray(image)
26
+ inputs = self.processor(image, return_tensors="pt").to(self.device, torch.float16)
27
+
28
+ generated_ids = self.model.generate(**inputs, max_new_tokens=20)
29
+ generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
30
+
31
+ return generated_text
32
+
33
+
34
+ if __name__ == '__main__':
35
+
36
+ parser = argparse.ArgumentParser()
37
+ parser.add_argument('path', type=str, help="path to image (png, jpeg, etc.)")
38
+ parser.add_argument('--model', default='u2net', type=str, help="rembg model, see https://github.com/danielgatis/rembg#models")
39
+ parser.add_argument('--size', default=256, type=int, help="output resolution")
40
+ parser.add_argument('--border_ratio', default=0.2, type=float, help="output border ratio")
41
+ parser.add_argument('--recenter', type=bool, default=True, help="recenter, potentially not helpful for multiview zero123")
42
+ opt = parser.parse_args()
43
+
44
+ session = rembg.new_session(model_name=opt.model)
45
+
46
+ if os.path.isdir(opt.path):
47
+ print(f'[INFO] processing directory {opt.path}...')
48
+ files = glob.glob(f'{opt.path}/*')
49
+ out_dir = opt.path
50
+ else: # isfile
51
+ files = [opt.path]
52
+ out_dir = os.path.dirname(opt.path)
53
+
54
+ for file in files:
55
+
56
+ out_base = os.path.basename(file).split('.')[0]
57
+ out_rgba = os.path.join(out_dir, out_base + '_rgba.png')
58
+
59
+ # load image
60
+ print(f'[INFO] loading image {file}...')
61
+ image = cv2.imread(file, cv2.IMREAD_UNCHANGED)
62
+
63
+ # carve background
64
+ print(f'[INFO] background removal...')
65
+ carved_image = rembg.remove(image, session=session) # [H, W, 4]
66
+ mask = carved_image[..., -1] > 0
67
+
68
+ # recenter
69
+ if opt.recenter:
70
+ print(f'[INFO] recenter...')
71
+ final_rgba = np.zeros((opt.size, opt.size, 4), dtype=np.uint8)
72
+
73
+ coords = np.nonzero(mask)
74
+ x_min, x_max = coords[0].min(), coords[0].max()
75
+ y_min, y_max = coords[1].min(), coords[1].max()
76
+ h = x_max - x_min
77
+ w = y_max - y_min
78
+ desired_size = int(opt.size * (1 - opt.border_ratio))
79
+ scale = desired_size / max(h, w)
80
+ h2 = int(h * scale)
81
+ w2 = int(w * scale)
82
+ x2_min = (opt.size - h2) // 2
83
+ x2_max = x2_min + h2
84
+ y2_min = (opt.size - w2) // 2
85
+ y2_max = y2_min + w2
86
+ final_rgba[x2_min:x2_max, y2_min:y2_max] = cv2.resize(carved_image[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA)
87
+
88
+ else:
89
+ final_rgba = carved_image
90
+
91
+ # write image
92
+ cv2.imwrite(out_rgba, final_rgba)
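
The `BLIP2` helper above is not exercised by the `__main__` block; a hedged sketch of calling it on a processed image (the path and the printed caption are illustrative):

```python
# Sketch: caption a preprocessed RGBA image with the BLIP2 wrapper above.
import cv2
from process import BLIP2   # importing does not trigger the __main__ CLI

img = cv2.imread("data/name_rgba.png", cv2.IMREAD_UNCHANGED)
img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)   # the wrapper expects an RGB array

captioner = BLIP2(device="cuda")
print(captioner(img))   # e.g. something like "a plush toy on a white background"
```
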
readme.md ADDED
@@ -0,0 +1,120 @@
1
+ # DreamGaussian
2
+
3
+ This repository contains the official implementation for [DreamGaussian: Generative Gaussian Splatting for Efficient 3D Content Creation]().
4
+
5
+ ### [Project Page](https://dreamgaussian.github.io) | [Arxiv]()
6
+
7
+
8
+ https://github.com/dreamgaussian/dreamgaussian/assets/25863658/db860801-7b9c-4b30-9eb9-87330175f5c8
9
+
10
+
11
+ ## Install
12
+ ```bash
13
+ pip install -r requirements.txt
14
+
15
+ # a modified gaussian splatting (+ depth, alpha rendering)
16
+ git clone --recursive https://github.com/ashawkey/diff-gaussian-rasterization
17
+ pip install ./diff-gaussian-rasterization
18
+
19
+ # simple-knn
20
+ pip install ./simple-knn
21
+
22
+ # nvdiffrast
23
+ pip install git+https://github.com/NVlabs/nvdiffrast/
24
+
25
+ # kiuikit
26
+ pip install git+https://github.com/ashawkey/kiuikit
27
+ ```
28
+
29
+ Tested on:
30
+ * Ubuntu 22 with torch 1.12 & CUDA 11.6 on a V100.
31
+ * Windows 10 with torch 2.1 & CUDA 12.1 on a 3070.
32
+
33
+ ## Usage
34
+
35
+ Image-to-3D:
36
+ ```bash
37
+ ### preprocess
38
+ # background removal and recenter, save rgba at 256x256
39
+ python process.py data/name.jpg
40
+
41
+ # save at a larger resolution
42
+ python process.py data/name.jpg --size 512
43
+
44
+ # process all jpg images under a dir
45
+ python process.py data
46
+
47
+ ### training gaussian stage
48
+ # train 500 iters (~1min) and export ckpt & coarse_mesh to logs
49
+ python main.py --config configs/image.yaml input=data/name_rgba.png save_path=name
50
+
51
+ # gui mode (supports visualizing training)
52
+ python main.py --config configs/image.yaml input=data/name_rgba.png save_path=name gui=True
53
+
54
+ # load and visualize a saved ckpt
55
+ python main.py --config configs/image.yaml load=logs/name_model.ply gui=True
56
+
57
+ # use an estimated elevation angle if the image is not front-view (e.g., a typical looking-down image can use -30)
58
+ python main.py --config configs/image.yaml input=data/name_rgba.png save_path=name elevation=-30
59
+
60
+ ### training mesh stage
61
+ # auto load coarse_mesh.obj and refine 50 iters (~1min), export fine_mesh to logs
62
+ python main2.py --config configs/image.yaml input=data/name_rgba.png save_path=name
63
+
64
+ # specify the coarse mesh path explicitly
65
+ python main2.py --config configs/image.yaml input=data/name_rgba.png save_path=name mesh=logs/name_mesh.obj
66
+
67
+ # gui mode
68
+ python main2.py --config configs/image.yaml input=data/name_rgba.png save_path=name gui=True
69
+
70
+ ### visualization
71
+ # gui for visualizing mesh
72
+ python -m kiui.render logs/name.obj
73
+
74
+ # save 360 degree video of mesh (can run without gui)
75
+ python -m kiui.render logs/name.obj --save_video name.mp4 --wogui
76
+
77
+ # save 8 view images of mesh (can run without gui)
78
+ python -m kiui.render logs/name.obj --save images/name/ --wogui
79
+
80
+ ### evaluation of CLIP-similarity
81
+ python -m kiui.cli.clip_sim data/name_rgba.png logs/name.obj
82
+ ```
83
+ Please check `./configs/image.yaml` for more options.
84
+
85
+ Text-to-3D:
86
+ ```bash
87
+ ### training gaussian stage
88
+ python main.py --config configs/text.yaml prompt="a photo of an icecream" save_path=icecream
89
+
90
+ ### training mesh stage
91
+ python main2.py --config configs/text.yaml prompt="a photo of an icecream" save_path=icecream
92
+ ```
93
+ Please check `./configs/text.yaml` for more options.
94
+
95
+ Helper scripts:
96
+ ```bash
97
+ # run all image samples (*_rgba.png) in ./data
98
+ python scripts/runall.py --dir ./data --gpu 0
99
+
100
+ # run all text samples (hardcoded in runall_sd.py)
101
+ python scripts/runall_sd.py --gpu 0
102
+
103
+ # export all ./logs/*.obj to mp4 in ./videos
104
+ python scripts/convert_obj_to_video.py --dir ./logs
105
+ ```
106
+
107
+ ## Acknowledgement
108
+
109
+ This work is built on many amazing research works and open-source projects; many thanks to all the authors for sharing!
110
+
111
+ * [gaussian-splatting](https://github.com/graphdeco-inria/gaussian-splatting) and [diff-gaussian-rasterization](https://github.com/graphdeco-inria/diff-gaussian-rasterization)
112
+ * [threestudio](https://github.com/threestudio-project/threestudio)
113
+ * [nvdiffrast](https://github.com/NVlabs/nvdiffrast)
114
+ * [dearpygui](https://github.com/hoffstadt/DearPyGui)
115
+
116
+ ## Citation
117
+
118
+ ```
119
+
120
+ ```
requirements.txt ADDED
@@ -0,0 +1,33 @@
1
+ tqdm
2
+ rich
3
+ ninja
4
+ numpy
5
+ pandas
6
+ scipy
7
+ scikit-learn
8
+ matplotlib
9
+ opencv-python
10
+ imageio
11
+ imageio-ffmpeg
12
+ omegaconf
13
+
14
+ torch
15
+ einops
16
+ plyfile
17
+
18
+ # for gui
19
+ dearpygui
20
+
21
+ # for stable-diffusion
22
+ huggingface_hub
23
+ diffusers >= 0.9.0
24
+ accelerate
25
+ transformers
26
+
27
+ # for dmtet and mesh export
28
+ xatlas
29
+ trimesh
30
+ PyMCubes
31
+ pymeshlab
32
+
33
+ rembg[gpu,cli]
sh_utils.py ADDED
@@ -0,0 +1,118 @@
1
+ # Copyright 2021 The PlenOctree Authors.
2
+ # Redistribution and use in source and binary forms, with or without
3
+ # modification, are permitted provided that the following conditions are met:
4
+ #
5
+ # 1. Redistributions of source code must retain the above copyright notice,
6
+ # this list of conditions and the following disclaimer.
7
+ #
8
+ # 2. Redistributions in binary form must reproduce the above copyright notice,
9
+ # this list of conditions and the following disclaimer in the documentation
10
+ # and/or other materials provided with the distribution.
11
+ #
12
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
13
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
14
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
15
+ # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
16
+ # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
17
+ # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
18
+ # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
19
+ # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
20
+ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
21
+ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
22
+ # POSSIBILITY OF SUCH DAMAGE.
23
+
24
+ import torch
25
+
26
+ C0 = 0.28209479177387814
27
+ C1 = 0.4886025119029199
28
+ C2 = [
29
+ 1.0925484305920792,
30
+ -1.0925484305920792,
31
+ 0.31539156525252005,
32
+ -1.0925484305920792,
33
+ 0.5462742152960396
34
+ ]
35
+ C3 = [
36
+ -0.5900435899266435,
37
+ 2.890611442640554,
38
+ -0.4570457994644658,
39
+ 0.3731763325901154,
40
+ -0.4570457994644658,
41
+ 1.445305721320277,
42
+ -0.5900435899266435
43
+ ]
44
+ C4 = [
45
+ 2.5033429417967046,
46
+ -1.7701307697799304,
47
+ 0.9461746957575601,
48
+ -0.6690465435572892,
49
+ 0.10578554691520431,
50
+ -0.6690465435572892,
51
+ 0.47308734787878004,
52
+ -1.7701307697799304,
53
+ 0.6258357354491761,
54
+ ]
55
+
56
+
57
+ def eval_sh(deg, sh, dirs):
58
+ """
59
+ Evaluate spherical harmonics at unit directions
60
+ using hardcoded SH polynomials.
61
+ Works with torch/np/jnp.
62
+ ... Can be 0 or more batch dimensions.
63
+ Args:
64
+ deg: int SH deg. Currently, 0-3 supported
65
+ sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
66
+ dirs: jnp.ndarray unit directions [..., 3]
67
+ Returns:
68
+ [..., C]
69
+ """
70
+ assert deg <= 4 and deg >= 0
71
+ coeff = (deg + 1) ** 2
72
+ assert sh.shape[-1] >= coeff
73
+
74
+ result = C0 * sh[..., 0]
75
+ if deg > 0:
76
+ x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
77
+ result = (result -
78
+ C1 * y * sh[..., 1] +
79
+ C1 * z * sh[..., 2] -
80
+ C1 * x * sh[..., 3])
81
+
82
+ if deg > 1:
83
+ xx, yy, zz = x * x, y * y, z * z
84
+ xy, yz, xz = x * y, y * z, x * z
85
+ result = (result +
86
+ C2[0] * xy * sh[..., 4] +
87
+ C2[1] * yz * sh[..., 5] +
88
+ C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
89
+ C2[3] * xz * sh[..., 7] +
90
+ C2[4] * (xx - yy) * sh[..., 8])
91
+
92
+ if deg > 2:
93
+ result = (result +
94
+ C3[0] * y * (3 * xx - yy) * sh[..., 9] +
95
+ C3[1] * xy * z * sh[..., 10] +
96
+ C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
97
+ C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
98
+ C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
99
+ C3[5] * z * (xx - yy) * sh[..., 14] +
100
+ C3[6] * x * (xx - 3 * yy) * sh[..., 15])
101
+
102
+ if deg > 3:
103
+ result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
104
+ C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
105
+ C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
106
+ C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
107
+ C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
108
+ C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
109
+ C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
110
+ C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
111
+ C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
112
+ return result
113
+
114
+ def RGB2SH(rgb):
115
+ return (rgb - 0.5) / C0
116
+
117
+ def SH2RGB(sh):
118
+ return sh * C0 + 0.5
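
A small worked example for `eval_sh`, using random degree-2 coefficients and unit directions with the shapes described in the docstring:

```python
# Sketch: evaluate per-channel SH colour coefficients at random unit directions.
import torch
from sh_utils import eval_sh, RGB2SH, SH2RGB

deg = 2
sh = torch.randn(1024, 3, (deg + 1) ** 2)                            # [N, C=3, 9]
dirs = torch.nn.functional.normalize(torch.randn(1024, 3), dim=-1)   # [N, 3]

rgb = eval_sh(deg, sh, dirs)   # [N, 3]
print(rgb.shape)

# the DC-term helpers are exact inverses of each other
base = torch.rand(4, 3)
assert torch.allclose(SH2RGB(RGB2SH(base)), base, atol=1e-6)
```
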
zero123.py ADDED
@@ -0,0 +1,666 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ import math
17
+ import warnings
18
+ from typing import Any, Callable, Dict, List, Optional, Union
19
+
20
+ import PIL
21
+ import torch
22
+ import torchvision.transforms.functional as TF
23
+ from diffusers.configuration_utils import ConfigMixin, FrozenDict, register_to_config
24
+ from diffusers.image_processor import VaeImageProcessor
25
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
26
+ from diffusers.models.modeling_utils import ModelMixin
27
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
28
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
29
+ from diffusers.pipelines.stable_diffusion.safety_checker import (
30
+ StableDiffusionSafetyChecker,
31
+ )
32
+ from diffusers.schedulers import KarrasDiffusionSchedulers
33
+ from diffusers.utils import deprecate, is_accelerate_available, logging
34
+ from diffusers.utils.torch_utils import randn_tensor
35
+ from packaging import version
36
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+
41
+ class CLIPCameraProjection(ModelMixin, ConfigMixin):
42
+ """
43
+ A Projection layer for CLIP embedding and camera embedding.
44
+
45
+ Parameters:
46
+ embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `clip_embed`
47
+ additional_embeddings (`int`, *optional*, defaults to 4): The number of camera-embedding dimensions
48
+ concatenated to `clip_embed` before projection, so the input dimension is `embedding_dim +
49
+ additional_embeddings`.
50
+ """
51
+
52
+ @register_to_config
53
+ def __init__(self, embedding_dim: int = 768, additional_embeddings: int = 4):
54
+ super().__init__()
55
+ self.embedding_dim = embedding_dim
56
+ self.additional_embeddings = additional_embeddings
57
+
58
+ self.input_dim = self.embedding_dim + self.additional_embeddings
59
+ self.output_dim = self.embedding_dim
60
+
61
+ self.proj = torch.nn.Linear(self.input_dim, self.output_dim)
62
+
63
+ def forward(
64
+ self,
65
+ embedding: torch.FloatTensor,
66
+ ):
67
+ """
68
+ The [`CLIPCameraProjection`] forward method.
69
+
70
+ Args:
71
+ embedding (`torch.FloatTensor` of shape `(batch_size, input_dim)`):
72
+ The input embedding to be projected.
73
+
74
+ Returns:
75
+ The output embedding projection (`torch.FloatTensor` of shape `(batch_size, output_dim)`).
76
+ """
77
+ proj_embedding = self.proj(embedding)
78
+ return proj_embedding
79
+
80
+
81
+ class Zero123Pipeline(DiffusionPipeline):
82
+ r"""
83
+ Pipeline to generate variations from an input image using Stable Diffusion.
84
+
85
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
86
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
87
+
88
+ Args:
89
+ vae ([`AutoencoderKL`]):
90
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
91
+ image_encoder ([`CLIPVisionModelWithProjection`]):
92
+ Frozen CLIP image-encoder. Stable Diffusion Image Variation uses the vision portion of
93
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection),
94
+ specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
95
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
96
+ scheduler ([`SchedulerMixin`]):
97
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
98
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
99
+ safety_checker ([`StableDiffusionSafetyChecker`]):
100
+ Classification module that estimates whether generated images could be considered offensive or harmful.
101
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
102
+ feature_extractor ([`CLIPImageProcessor`]):
103
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
104
+ """
105
+ # TODO: feature_extractor is required to encode images (if they are in PIL format),
106
+ # we should give a descriptive message if the pipeline doesn't have one.
107
+ _optional_components = ["safety_checker"]
108
+
109
+ def __init__(
110
+ self,
111
+ vae: AutoencoderKL,
112
+ image_encoder: CLIPVisionModelWithProjection,
113
+ unet: UNet2DConditionModel,
114
+ scheduler: KarrasDiffusionSchedulers,
115
+ safety_checker: StableDiffusionSafetyChecker,
116
+ feature_extractor: CLIPImageProcessor,
117
+ clip_camera_projection: CLIPCameraProjection,
118
+ requires_safety_checker: bool = True,
119
+ ):
120
+ super().__init__()
121
+
122
+ if safety_checker is None and requires_safety_checker:
123
+ logger.warn(
124
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
125
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
126
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
127
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
128
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
129
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
130
+ )
131
+
132
+ if safety_checker is not None and feature_extractor is None:
133
+ raise ValueError(
134
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
135
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
136
+ )
137
+
138
+ is_unet_version_less_0_9_0 = hasattr(
139
+ unet.config, "_diffusers_version"
140
+ ) and version.parse(
141
+ version.parse(unet.config._diffusers_version).base_version
142
+ ) < version.parse(
143
+ "0.9.0.dev0"
144
+ )
145
+ is_unet_sample_size_less_64 = (
146
+ hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
147
+ )
148
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
149
+ deprecation_message = (
150
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
151
+ " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
152
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
153
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
154
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
155
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
156
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
157
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
158
+ " the `unet/config.json` file"
159
+ )
160
+ deprecate(
161
+ "sample_size<64", "1.0.0", deprecation_message, standard_warn=False
162
+ )
163
+ new_config = dict(unet.config)
164
+ new_config["sample_size"] = 64
165
+ unet._internal_dict = FrozenDict(new_config)
166
+
167
+ self.register_modules(
168
+ vae=vae,
169
+ image_encoder=image_encoder,
170
+ unet=unet,
171
+ scheduler=scheduler,
172
+ safety_checker=safety_checker,
173
+ feature_extractor=feature_extractor,
174
+ clip_camera_projection=clip_camera_projection,
175
+ )
176
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
177
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
178
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
179
+
180
+ def enable_sequential_cpu_offload(self, gpu_id=0):
181
+ r"""
182
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
183
+ image_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
184
+ `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
185
+ """
186
+ if is_accelerate_available():
187
+ from accelerate import cpu_offload
188
+ else:
189
+ raise ImportError("Please install accelerate via `pip install accelerate`")
190
+
191
+ device = torch.device(f"cuda:{gpu_id}")
192
+
193
+ for cpu_offloaded_model in [
194
+ self.unet,
195
+ self.image_encoder,
196
+ self.vae,
197
+ self.safety_checker,
198
+ ]:
199
+ if cpu_offloaded_model is not None:
200
+ cpu_offload(cpu_offloaded_model, device)
201
+
202
+ @property
203
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
204
+ def _execution_device(self):
205
+ r"""
206
+ Returns the device on which the pipeline's models will be executed. After calling
207
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
208
+ hooks.
209
+ """
210
+ if not hasattr(self.unet, "_hf_hook"):
211
+ return self.device
212
+ for module in self.unet.modules():
213
+ if (
214
+ hasattr(module, "_hf_hook")
215
+ and hasattr(module._hf_hook, "execution_device")
216
+ and module._hf_hook.execution_device is not None
217
+ ):
218
+ return torch.device(module._hf_hook.execution_device)
219
+ return self.device
220
+
221
+ def _encode_image(
222
+ self,
223
+ image,
224
+ elevation,
225
+ azimuth,
226
+ distance,
227
+ device,
228
+ num_images_per_prompt,
229
+ do_classifier_free_guidance,
230
+ clip_image_embeddings=None,
231
+ image_camera_embeddings=None,
232
+ ):
233
+ dtype = next(self.image_encoder.parameters()).dtype
234
+
235
+ if image_camera_embeddings is None:
236
+ if image is None:
237
+ assert clip_image_embeddings is not None
238
+ image_embeddings = clip_image_embeddings.to(device=device, dtype=dtype)
239
+ else:
240
+ if not isinstance(image, torch.Tensor):
241
+ image = self.feature_extractor(
242
+ images=image, return_tensors="pt"
243
+ ).pixel_values
244
+
245
+ image = image.to(device=device, dtype=dtype)
246
+ image_embeddings = self.image_encoder(image).image_embeds
247
+ image_embeddings = image_embeddings.unsqueeze(1)
248
+
249
+ bs_embed, seq_len, _ = image_embeddings.shape
250
+
251
+ if isinstance(elevation, float):
252
+ elevation = torch.as_tensor(
253
+ [elevation] * bs_embed, dtype=dtype, device=device
254
+ )
255
+ if isinstance(azimuth, float):
256
+ azimuth = torch.as_tensor(
257
+ [azimuth] * bs_embed, dtype=dtype, device=device
258
+ )
259
+ if isinstance(distance, float):
260
+ distance = torch.as_tensor(
261
+ [distance] * bs_embed, dtype=dtype, device=device
262
+ )
263
+
264
+ camera_embeddings = torch.stack(
265
+ [
266
+ torch.deg2rad(elevation),
267
+ torch.sin(torch.deg2rad(azimuth)),
268
+ torch.cos(torch.deg2rad(azimuth)),
269
+ distance,
270
+ ],
271
+ dim=-1,
272
+ )[:, None, :]
273
+
274
+ image_embeddings = torch.cat([image_embeddings, camera_embeddings], dim=-1)
275
+
276
+ # project (image, camera) embeddings to the same dimension as clip embeddings
277
+ image_embeddings = self.clip_camera_projection(image_embeddings)
278
+ else:
279
+ image_embeddings = image_camera_embeddings.to(device=device, dtype=dtype)
280
+ bs_embed, seq_len, _ = image_embeddings.shape
281
+
282
+ # duplicate image embeddings for each generation per prompt, using mps friendly method
283
+ image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
284
+ image_embeddings = image_embeddings.view(
285
+ bs_embed * num_images_per_prompt, seq_len, -1
286
+ )
287
+
288
+ if do_classifier_free_guidance:
289
+ negative_prompt_embeds = torch.zeros_like(image_embeddings)
290
+
291
+ # For classifier free guidance, we need to do two forward passes.
292
+ # Here we concatenate the unconditional and text embeddings into a single batch
293
+ # to avoid doing two forward passes
294
+ image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])
295
+
296
+ return image_embeddings
297
+
298
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
299
+ def run_safety_checker(self, image, device, dtype):
300
+ if self.safety_checker is None:
301
+ has_nsfw_concept = None
302
+ else:
303
+ if torch.is_tensor(image):
304
+ feature_extractor_input = self.image_processor.postprocess(
305
+ image, output_type="pil"
306
+ )
307
+ else:
308
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
309
+ safety_checker_input = self.feature_extractor(
310
+ feature_extractor_input, return_tensors="pt"
311
+ ).to(device)
312
+ image, has_nsfw_concept = self.safety_checker(
313
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
314
+ )
315
+ return image, has_nsfw_concept
316
+
317
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
318
+ def decode_latents(self, latents):
319
+ warnings.warn(
320
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
321
+ " use VaeImageProcessor instead",
322
+ FutureWarning,
323
+ )
324
+ latents = 1 / self.vae.config.scaling_factor * latents
325
+ image = self.vae.decode(latents, return_dict=False)[0]
326
+ image = (image / 2 + 0.5).clamp(0, 1)
327
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
328
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
329
+ return image
330
+
331
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
332
+ def prepare_extra_step_kwargs(self, generator, eta):
333
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
334
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
335
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
336
+ # and should be between [0, 1]
337
+
338
+ accepts_eta = "eta" in set(
339
+ inspect.signature(self.scheduler.step).parameters.keys()
340
+ )
341
+ extra_step_kwargs = {}
342
+ if accepts_eta:
343
+ extra_step_kwargs["eta"] = eta
344
+
345
+ # check if the scheduler accepts generator
346
+ accepts_generator = "generator" in set(
347
+ inspect.signature(self.scheduler.step).parameters.keys()
348
+ )
349
+ if accepts_generator:
350
+ extra_step_kwargs["generator"] = generator
351
+ return extra_step_kwargs
352
+
353
+ def check_inputs(self, image, height, width, callback_steps):
354
+ # TODO: check image size or adjust image size to (height, width)
355
+
356
+ if height % 8 != 0 or width % 8 != 0:
357
+ raise ValueError(
358
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
359
+ )
360
+
361
+ if (callback_steps is None) or (
362
+ callback_steps is not None
363
+ and (not isinstance(callback_steps, int) or callback_steps <= 0)
364
+ ):
365
+ raise ValueError(
366
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
367
+ f" {type(callback_steps)}."
368
+ )
369
+
370
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
371
+ def prepare_latents(
372
+ self,
373
+ batch_size,
374
+ num_channels_latents,
375
+ height,
376
+ width,
377
+ dtype,
378
+ device,
379
+ generator,
380
+ latents=None,
381
+ ):
382
+ shape = (
383
+ batch_size,
384
+ num_channels_latents,
385
+ height // self.vae_scale_factor,
386
+ width // self.vae_scale_factor,
387
+ )
388
+ if isinstance(generator, list) and len(generator) != batch_size:
389
+ raise ValueError(
390
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
391
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
392
+ )
393
+
394
+ if latents is None:
395
+ latents = randn_tensor(
396
+ shape, generator=generator, device=device, dtype=dtype
397
+ )
398
+ else:
399
+ latents = latents.to(device)
400
+
401
+ # scale the initial noise by the standard deviation required by the scheduler
402
+ latents = latents * self.scheduler.init_noise_sigma
403
+ return latents
404
+
405
+ def _get_latent_model_input(
406
+ self,
407
+ latents: torch.FloatTensor,
408
+ image: Optional[
409
+ Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]
410
+ ],
411
+ num_images_per_prompt: int,
412
+ do_classifier_free_guidance: bool,
413
+ image_latents: Optional[torch.FloatTensor] = None,
414
+ ):
415
+ if isinstance(image, PIL.Image.Image):
416
+ image_pt = TF.to_tensor(image).unsqueeze(0).to(latents)
417
+ elif isinstance(image, list):
418
+ image_pt = torch.stack([TF.to_tensor(img) for img in image], dim=0).to(
419
+ latents
420
+ )
421
+ elif isinstance(image, torch.Tensor):
422
+ image_pt = image
423
+ else:
424
+ image_pt = None
425
+
426
+ if image_pt is None:
427
+ assert image_latents is not None
428
+ image_pt = image_latents.repeat_interleave(num_images_per_prompt, dim=0)
429
+ else:
430
+ image_pt = image_pt * 2.0 - 1.0 # scale to [-1, 1]
431
+ # FIXME: encoded latents should be multiplied with self.vae.config.scaling_factor
432
+ # but zero123 was not trained this way
433
+ image_pt = self.vae.encode(image_pt).latent_dist.mode()
434
+ image_pt = image_pt.repeat_interleave(num_images_per_prompt, dim=0)
435
+ if do_classifier_free_guidance:
436
+ latent_model_input = torch.cat(
437
+ [
438
+ torch.cat([latents, latents], dim=0),
439
+ torch.cat([torch.zeros_like(image_pt), image_pt], dim=0),
440
+ ],
441
+ dim=1,
442
+ )
443
+ else:
444
+ latent_model_input = torch.cat([latents, image_pt], dim=1)
445
+
446
+ return latent_model_input
447
+
448
+ @torch.no_grad()
449
+ def __call__(
450
+ self,
451
+ image: Optional[
452
+ Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]
453
+ ] = None,
454
+ elevation: Optional[Union[float, torch.FloatTensor]] = None,
455
+ azimuth: Optional[Union[float, torch.FloatTensor]] = None,
456
+ distance: Optional[Union[float, torch.FloatTensor]] = None,
457
+ height: Optional[int] = None,
458
+ width: Optional[int] = None,
459
+ num_inference_steps: int = 50,
460
+ guidance_scale: float = 3.0,
461
+ num_images_per_prompt: int = 1,
462
+ eta: float = 0.0,
463
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
464
+ latents: Optional[torch.FloatTensor] = None,
465
+ clip_image_embeddings: Optional[torch.FloatTensor] = None,
466
+ image_camera_embeddings: Optional[torch.FloatTensor] = None,
467
+ image_latents: Optional[torch.FloatTensor] = None,
468
+ output_type: Optional[str] = "pil",
469
+ return_dict: bool = True,
470
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
471
+ callback_steps: int = 1,
472
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
473
+ ):
474
+ r"""
475
+ Function invoked when calling the pipeline for generation.
476
+
477
+ Args:
478
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
479
+ The image or images to guide the image generation. If you provide a tensor, it needs to comply with the
480
+ configuration of
481
+ [this](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
482
+ `CLIPImageProcessor`
483
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
484
+ The height in pixels of the generated image.
485
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
486
+ The width in pixels of the generated image.
487
+ num_inference_steps (`int`, *optional*, defaults to 50):
488
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
489
+ expense of slower inference.
490
+ guidance_scale (`float`, *optional*, defaults to 3.0):
491
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
492
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
493
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
494
+ 1`. A higher guidance scale encourages images that are closely linked to the conditioning image,
495
+ usually at the expense of lower image quality.
496
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
497
+ The number of images to generate per prompt.
498
+ eta (`float`, *optional*, defaults to 0.0):
499
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
500
+ [`schedulers.DDIMScheduler`], will be ignored for others.
501
+ generator (`torch.Generator`, *optional*):
502
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
503
+ to make generation deterministic.
504
+ latents (`torch.FloatTensor`, *optional*):
505
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
506
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
507
+ tensor will be generated by sampling using the supplied random `generator`.
508
+ output_type (`str`, *optional*, defaults to `"pil"`):
509
+ The output format of the generated image. Choose between
510
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
511
+ return_dict (`bool`, *optional*, defaults to `True`):
512
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
513
+ plain tuple.
514
+ callback (`Callable`, *optional*):
515
+ A function that will be called every `callback_steps` steps during inference. The function will be
516
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
517
+ callback_steps (`int`, *optional*, defaults to 1):
518
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
519
+ called at every step.
520
+
521
+ Returns:
522
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
523
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
524
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
525
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
526
+ (nsfw) content, according to the `safety_checker`.
527
+ """
528
+ # 0. Default height and width to unet
529
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
530
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
531
+
532
+ # 1. Check inputs. Raise error if not correct
533
+ # TODO: check input elevation, azimuth, and distance
534
+ # TODO: check image, clip_image_embeddings, image_latents
535
+ self.check_inputs(image, height, width, callback_steps)
536
+
537
+ # 2. Define call parameters
538
+ if isinstance(image, PIL.Image.Image):
539
+ batch_size = 1
540
+ elif isinstance(image, list):
541
+ batch_size = len(image)
542
+ elif isinstance(image, torch.Tensor):
543
+ batch_size = image.shape[0]
544
+ else:
545
+ assert image_latents is not None
546
+ assert (
547
+ clip_image_embeddings is not None or image_camera_embeddings is not None
548
+ )
549
+ batch_size = image_latents.shape[0]
550
+
551
+ device = self._execution_device
552
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
553
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
554
+ # corresponds to doing no classifier free guidance.
555
+ do_classifier_free_guidance = guidance_scale > 1.0
556
+
557
+ # 3. Encode input image
558
+ if isinstance(image, PIL.Image.Image) or isinstance(image, list):
559
+ pil_image = image
560
+ elif isinstance(image, torch.Tensor):
561
+ pil_image = [TF.to_pil_image(image[i]) for i in range(image.shape[0])]
562
+ else:
563
+ pil_image = None
564
+ image_embeddings = self._encode_image(
565
+ pil_image,
566
+ elevation,
567
+ azimuth,
568
+ distance,
569
+ device,
570
+ num_images_per_prompt,
571
+ do_classifier_free_guidance,
572
+ clip_image_embeddings,
573
+ image_camera_embeddings,
574
+ )
575
+
576
+ # 4. Prepare timesteps
577
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
578
+ timesteps = self.scheduler.timesteps
579
+
580
+ # 5. Prepare latent variables
581
+ # num_channels_latents = self.unet.config.in_channels
582
+ num_channels_latents = 4 # FIXME: hard-coded
583
+ latents = self.prepare_latents(
584
+ batch_size * num_images_per_prompt,
585
+ num_channels_latents,
586
+ height,
587
+ width,
588
+ image_embeddings.dtype,
589
+ device,
590
+ generator,
591
+ latents,
592
+ )
593
+
594
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
595
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
596
+
597
+ # 7. Denoising loop
598
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
599
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
600
+ for i, t in enumerate(timesteps):
601
+ # expand the latents if we are doing classifier free guidance
602
+ latent_model_input = self._get_latent_model_input(
603
+ latents,
604
+ image,
605
+ num_images_per_prompt,
606
+ do_classifier_free_guidance,
607
+ image_latents,
608
+ )
609
+ latent_model_input = self.scheduler.scale_model_input(
610
+ latent_model_input, t
611
+ )
612
+
613
+ # predict the noise residual
614
+ noise_pred = self.unet(
615
+ latent_model_input,
616
+ t,
617
+ encoder_hidden_states=image_embeddings,
618
+ cross_attention_kwargs=cross_attention_kwargs,
619
+ ).sample
620
+
621
+ # perform guidance
622
+ if do_classifier_free_guidance:
623
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
624
+ noise_pred = noise_pred_uncond + guidance_scale * (
625
+ noise_pred_text - noise_pred_uncond
626
+ )
627
+
628
+ # compute the previous noisy sample x_t -> x_t-1
629
+ latents = self.scheduler.step(
630
+ noise_pred, t, latents, **extra_step_kwargs
631
+ ).prev_sample
632
+
633
+ # call the callback, if provided
634
+ if i == len(timesteps) - 1 or (
635
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
636
+ ):
637
+ progress_bar.update()
638
+ if callback is not None and i % callback_steps == 0:
639
+ callback(i, t, latents)
640
+
641
+ if not output_type == "latent":
642
+ image = self.vae.decode(
643
+ latents / self.vae.config.scaling_factor, return_dict=False
644
+ )[0]
645
+ image, has_nsfw_concept = self.run_safety_checker(
646
+ image, device, image_embeddings.dtype
647
+ )
648
+ else:
649
+ image = latents
650
+ has_nsfw_concept = None
651
+
652
+ if has_nsfw_concept is None:
653
+ do_denormalize = [True] * image.shape[0]
654
+ else:
655
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
656
+
657
+ image = self.image_processor.postprocess(
658
+ image, output_type=output_type, do_denormalize=do_denormalize
659
+ )
660
+
661
+ if not return_dict:
662
+ return (image, has_nsfw_concept)
663
+
664
+ return StableDiffusionPipelineOutput(
665
+ images=image, nsfw_content_detected=has_nsfw_concept
666
+ )
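
Finally, a hedged sketch of running `Zero123Pipeline` for novel-view synthesis. The checkpoint id is an assumption; any diffusers-format zero123 weights that ship the extra `clip_camera_projection` module and a `feature_extractor` should behave similarly:

```python
# Sketch: generate a novel view from a single RGB image with the pipeline above.
import torch
from PIL import Image
from zero123 import Zero123Pipeline

pipe = Zero123Pipeline.from_pretrained(
    "ashawkey/zero123-xl-diffusers",   # assumed checkpoint id
    torch_dtype=torch.float16,
).to("cuda")

cond = Image.open("data/name_rgba.png").convert("RGB").resize((256, 256))

out = pipe(
    image=cond,
    elevation=30.0,        # relative elevation offset in degrees
    azimuth=90.0,          # relative azimuth offset in degrees
    distance=0.0,          # relative radius offset
    height=256, width=256,
    guidance_scale=3.0,
    num_inference_steps=50,
)
out.images[0].save("novel_view.png")
```
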