dooraven committed on
Commit
d3112c8
1 Parent(s): deb3ac7
Files changed (3) hide show
  1. app.py +3 -5
  2. pc.py +174 -0
  3. sampler.py +263 -0
app.py CHANGED
@@ -8,11 +8,9 @@ import numpy as np
8
  import argparse
9
 
10
  from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
11
- from point_e.diffusion.sampler import PointCloudSampler
12
  from point_e.models.download import load_checkpoint
13
  from point_e.models.configs import MODEL_CONFIGS, model_from_config
14
- from point_e.util.plotting import plot_point_cloud
15
- from point_e.util.ply_util import write_ply
16
 
17
  from diffusers import StableDiffusionPipeline
18
 
@@ -119,7 +117,7 @@ def generate_3D(input, model_name='base1B', guidance_scale=3.0, grid_size=128):
119
  set_state('Converting to mesh...')
120
 
121
  uniqid = uuid.uuid4()
122
- file_path = f'/tmp/mesh-{uniqid}.ply'
123
  save_ply(pc, file_path)
124
 
125
  set_state('')
@@ -153,7 +151,7 @@ def ply_to_glb(ply_file, glb_file):
153
  def save_ply(pc, file_name):
154
  # Produce a mesh (with vertex colors)
155
  with open(file_name, 'wb') as f:
156
- pc.write_ply(f)
157
 
158
 
159
  def create_gif(pc):
 
8
  import argparse
9
 
10
  from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
11
+ from .sampler import PointCloudSampler
12
  from point_e.models.download import load_checkpoint
13
  from point_e.models.configs import MODEL_CONFIGS, model_from_config
 
 
14
 
15
  from diffusers import StableDiffusionPipeline
16
 
 
117
  set_state('Converting to mesh...')
118
 
119
  uniqid = uuid.uuid4()
120
+ file_path = f'/tmp/mesh-{uniqid}.npy'
121
  save_ply(pc, file_path)
122
 
123
  set_state('')
 
151
  def save_ply(pc, file_name):
152
  # Produce a mesh (with vertex colors)
153
  with open(file_name, 'wb') as f:
154
+ pc.save(f)
155
 
156
 
157
  def create_gif(pc):
pc.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from dataclasses import dataclass
3
+ from typing import BinaryIO, Dict, List, Optional, Union
4
+
5
+ import numpy as np
6
+
7
+ from .ply_util import write_ply
8
+
9
# Channel names whose values are stored as [0, 1] floats but exported as 0-255.
COLORS = frozenset(["R", "G", "B", "A"])


def preprocess(data, channel):
    """Convert color-channel data from [0, 1] floats to rounded 0-255 values.

    Non-color channels are returned unchanged.
    """
    if channel not in COLORS:
        return data
    return np.round(data * 255.0)
16
+
17
+
18
@dataclass
class PointCloud:
    """
    An array of points sampled on a surface. Each point may have zero or more
    channel attributes.

    :param coords: an [N x 3] array of point coordinates.
    :param channels: a dict mapping names to [N] arrays of channel values.
    """

    coords: np.ndarray
    channels: Dict[str, np.ndarray]

    @classmethod
    def load(cls, f: Union[str, BinaryIO]) -> "PointCloud":
        """
        Load the point cloud from a .npz file.

        :param f: a path or a readable binary file-like object.
        """
        if isinstance(f, str):
            with open(f, "rb") as reader:
                return cls.load(reader)
        else:
            obj = np.load(f)
            keys = list(obj.keys())
            return PointCloud(
                coords=obj["coords"],
                channels={k: obj[k] for k in keys if k != "coords"},
            )

    def save(self, f: Union[str, BinaryIO]):
        """
        Save the point cloud to a .npz file.

        :param f: a path or a writable binary file-like object.
        """
        if isinstance(f, str):
            with open(f, "wb") as writer:
                self.save(writer)
        else:
            # Bug fix: this previously called np.save(), which takes a single
            # positional array and rejects keyword array arguments. np.savez()
            # writes the keyed .npz archive that load() expects.
            np.savez(f, coords=self.coords, **self.channels)

    def write_ply(self, raw_f: BinaryIO):
        """
        Write this point cloud as a PLY file, attaching RGB vertex colors
        only when all of the "R", "G", "B" channels are present.
        """
        write_ply(
            raw_f,
            coords=self.coords,
            rgb=(
                np.stack([self.channels[x] for x in "RGB"], axis=1)
                if all(x in self.channels for x in "RGB")
                else None
            ),
        )

    def random_sample(self, num_points: int, **subsample_kwargs) -> "PointCloud":
        """
        Sample a random subset of this PointCloud.

        :param num_points: maximum number of points to sample.
        :param subsample_kwargs: arguments to self.subsample().
        :return: a reduced PointCloud, or self if num_points is not less than
                 the current number of points.
        """
        if len(self.coords) <= num_points:
            return self
        indices = np.random.choice(len(self.coords), size=(num_points,), replace=False)
        return self.subsample(indices, **subsample_kwargs)

    def farthest_point_sample(
        self, num_points: int, init_idx: Optional[int] = None, **subsample_kwargs
    ) -> "PointCloud":
        """
        Sample a subset of the point cloud that is evenly distributed in space.

        First, a random point is selected. Then each successive point is chosen
        such that it is furthest from the currently selected points.

        The time complexity of this operation is O(NM), where N is the original
        number of points and M is the reduced number. Therefore, performance
        can be improved by randomly subsampling points with random_sample()
        before running farthest_point_sample().

        :param num_points: maximum number of points to sample.
        :param init_idx: if specified, the first point to sample.
        :param subsample_kwargs: arguments to self.subsample().
        :return: a reduced PointCloud, or self if num_points is not less than
                 the current number of points.
        """
        if len(self.coords) <= num_points:
            return self
        init_idx = random.randrange(len(self.coords)) if init_idx is None else init_idx
        indices = np.zeros([num_points], dtype=np.int64)
        indices[0] = init_idx
        sq_norms = np.sum(self.coords**2, axis=-1)

        def compute_dists(idx: int):
            # Utilize equality: ||A-B||^2 = ||A||^2 + ||B||^2 - 2*(A @ B).
            return sq_norms + sq_norms[idx] - 2 * (self.coords @ self.coords[idx])

        # Greedily pick the point with the largest distance to the current
        # selection; cur_dists[i] tracks each point's distance to its nearest
        # already-selected point.
        cur_dists = compute_dists(init_idx)
        for i in range(1, num_points):
            idx = np.argmax(cur_dists)
            indices[i] = idx
            cur_dists = np.minimum(cur_dists, compute_dists(idx))
        return self.subsample(indices, **subsample_kwargs)

    def subsample(self, indices: np.ndarray, average_neighbors: bool = False) -> "PointCloud":
        """
        Reduce this point cloud to the points at the given indices.

        :param indices: an array of indices into self.coords.
        :param average_neighbors: if True, each kept point's channel values are
               averaged over all original points nearest to it, instead of
               simply taking the kept point's own values.
        :return: a new, reduced PointCloud.
        """
        if not average_neighbors:
            return PointCloud(
                coords=self.coords[indices],
                channels={k: v[indices] for k, v in self.channels.items()},
            )

        new_coords = self.coords[indices]
        neighbor_indices = PointCloud(coords=new_coords, channels={}).nearest_points(self.coords)

        # Make sure every point points to itself, which might not
        # be the case if points are duplicated or there is rounding
        # error.
        neighbor_indices[indices] = np.arange(len(indices))

        new_channels = {}
        for k, v in self.channels.items():
            # Scatter-add channel values (and counts) onto each point's nearest
            # kept neighbor, then divide to get a per-neighbor mean.
            v_sum = np.zeros_like(v[: len(indices)])
            v_count = np.zeros_like(v[: len(indices)])
            np.add.at(v_sum, neighbor_indices, v)
            np.add.at(v_count, neighbor_indices, 1)
            new_channels[k] = v_sum / v_count
        return PointCloud(coords=new_coords, channels=new_channels)

    def select_channels(self, channel_names: List[str]) -> np.ndarray:
        """
        Stack the named channels into an [N x len(channel_names)] array,
        converting color channels to 0-255 values via preprocess().
        """
        data = np.stack([preprocess(self.channels[name], name) for name in channel_names], axis=-1)
        return data

    def nearest_points(self, points: np.ndarray, batch_size: int = 16384) -> np.ndarray:
        """
        For each point in another set of points, compute the point in this
        pointcloud which is closest.

        :param points: an [N x 3] array of points.
        :param batch_size: the number of neighbor distances to compute at once.
                           Smaller values save memory, while larger values may
                           make the computation faster.
        :return: an [N] array of indices into self.coords.
        """
        norms = np.sum(self.coords**2, axis=-1)
        all_indices = []
        for i in range(0, len(points), batch_size):
            batch = points[i : i + batch_size]
            # Squared distances via ||A-B||^2 = ||A||^2 + ||B||^2 - 2*(A @ B).
            dists = norms + np.sum(batch**2, axis=-1)[:, None] - 2 * (batch @ self.coords.T)
            all_indices.append(np.argmin(dists, axis=-1))
        return np.concatenate(all_indices, axis=0)

    def combine(self, other: "PointCloud") -> "PointCloud":
        """
        Concatenate this point cloud with another one that has exactly the
        same set of channels.
        """
        assert self.channels.keys() == other.channels.keys()
        return PointCloud(
            coords=np.concatenate([self.coords, other.coords], axis=0),
            channels={
                k: np.concatenate([v, other.channels[k]], axis=0) for k, v in self.channels.items()
            },
        )
sampler.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Helpers for sampling from a single- or multi-stage point cloud diffusion model.
3
+ """
4
+
5
+ from typing import Any, Callable, Dict, Iterator, List, Sequence, Tuple
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from .pc import PointCloud
11
+
12
+ from point_e.diffusion.gaussian_diffusion import GaussianDiffusion
13
+ from point_e.diffusion.k_diffusion import karras_sample_progressive
14
+
15
+
16
class PointCloudSampler:
    """
    A wrapper around a model or stack of models that produces conditional or
    unconditional sample tensors.

    Each "stage" is one (model, diffusion) pair; later stages act as
    upsamplers that condition on the previous stage's output via the
    "low_res" model kwarg. If you want to modify the sampler arguments of an
    existing sampler, call with_options().
    """

    def __init__(
        self,
        device: torch.device,
        models: Sequence[nn.Module],
        diffusions: Sequence[GaussianDiffusion],
        num_points: Sequence[int],
        aux_channels: Sequence[str],
        model_kwargs_key_filter: Sequence[str] = ("*",),
        guidance_scale: Sequence[float] = (3.0, 3.0),
        clip_denoised: bool = True,
        use_karras: Sequence[bool] = (True, True),
        karras_steps: Sequence[int] = (64, 64),
        sigma_min: Sequence[float] = (1e-3, 1e-3),
        sigma_max: Sequence[float] = (120, 160),
        s_churn: Sequence[float] = (3, 0),
    ):
        # All per-stage options must end up with one entry per model;
        # length-1 sequences are broadcast to every stage below.
        n = len(models)
        assert n > 0

        if n > 1:
            if len(guidance_scale) == 1:
                # Don't guide the upsamplers by default.
                guidance_scale = list(guidance_scale) + [1.0] * (n - 1)
            if len(use_karras) == 1:
                use_karras = use_karras * n
            if len(karras_steps) == 1:
                karras_steps = karras_steps * n
            if len(sigma_min) == 1:
                sigma_min = sigma_min * n
            if len(sigma_max) == 1:
                sigma_max = sigma_max * n
            if len(s_churn) == 1:
                s_churn = s_churn * n
            if len(model_kwargs_key_filter) == 1:
                model_kwargs_key_filter = model_kwargs_key_filter * n
        if len(model_kwargs_key_filter) == 0:
            # Empty filter means "pass everything" for every stage.
            model_kwargs_key_filter = ["*"] * n
        assert len(guidance_scale) == n
        assert len(use_karras) == n
        assert len(karras_steps) == n
        assert len(sigma_min) == n
        assert len(sigma_max) == n
        assert len(s_churn) == n
        assert len(model_kwargs_key_filter) == n

        self.device = device
        self.num_points = num_points
        self.aux_channels = aux_channels
        self.model_kwargs_key_filter = model_kwargs_key_filter
        self.guidance_scale = guidance_scale
        self.clip_denoised = clip_denoised
        self.use_karras = use_karras
        self.karras_steps = karras_steps
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.s_churn = s_churn

        self.models = models
        self.diffusions = diffusions

    @property
    def num_stages(self) -> int:
        """Number of (model, diffusion) stages in this sampler."""
        return len(self.models)

    def sample_batch(self, batch_size: int, model_kwargs: Dict[str, Any]) -> torch.Tensor:
        """
        Run all stages to completion and return only the final sample tensor
        (the last value yielded by sample_batch_progressive()).
        """
        samples = None
        for x in self.sample_batch_progressive(batch_size, model_kwargs):
            samples = x
        return samples

    def sample_batch_progressive(
        self, batch_size: int, model_kwargs: Dict[str, Any]
    ) -> Iterator[torch.Tensor]:
        """
        Sample stage by stage, yielding the current (possibly intermediate)
        sample tensor after every diffusion step of every stage.

        :param batch_size: number of samples to generate.
        :param model_kwargs: conditioning kwargs; per-stage subsets are
            selected via model_kwargs_key_filter ("*" passes everything,
            otherwise a comma-separated list of keys).
        """
        samples = None
        for (
            model,
            diffusion,
            stage_num_points,
            stage_guidance_scale,
            stage_use_karras,
            stage_karras_steps,
            stage_sigma_min,
            stage_sigma_max,
            stage_s_churn,
            stage_key_filter,
        ) in zip(
            self.models,
            self.diffusions,
            self.num_points,
            self.guidance_scale,
            self.use_karras,
            self.karras_steps,
            self.sigma_min,
            self.sigma_max,
            self.s_churn,
            self.model_kwargs_key_filter,
        ):
            stage_model_kwargs = model_kwargs.copy()
            if stage_key_filter != "*":
                use_keys = set(stage_key_filter.split(","))
                stage_model_kwargs = {k: v for k, v in stage_model_kwargs.items() if k in use_keys}
            if samples is not None:
                # Condition upsampler stages on the previous stage's output.
                stage_model_kwargs["low_res"] = samples
            if hasattr(model, "cached_model_kwargs"):
                # Let the model pre-embed/caches its conditioning if supported.
                stage_model_kwargs = model.cached_model_kwargs(batch_size, stage_model_kwargs)
            sample_shape = (batch_size, 3 + len(self.aux_channels), stage_num_points)

            if stage_guidance_scale != 1 and stage_guidance_scale != 0:
                # Classifier-free guidance: double the batch, with the second
                # half conditioned on zeroed-out kwargs (the unconditional half).
                for k, v in stage_model_kwargs.copy().items():
                    stage_model_kwargs[k] = torch.cat([v, torch.zeros_like(v)], dim=0)

            if stage_use_karras:
                # Karras sampler handles guidance internally via guidance_scale.
                samples_it = karras_sample_progressive(
                    diffusion=diffusion,
                    model=model,
                    shape=sample_shape,
                    steps=stage_karras_steps,
                    clip_denoised=self.clip_denoised,
                    model_kwargs=stage_model_kwargs,
                    device=self.device,
                    sigma_min=stage_sigma_min,
                    sigma_max=stage_sigma_max,
                    s_churn=stage_s_churn,
                    guidance_scale=stage_guidance_scale,
                )
            else:
                # Plain ancestral sampling; guidance is implemented by wrapping
                # the model and doubling the internal batch.
                internal_batch_size = batch_size
                if stage_guidance_scale:
                    model = self._uncond_guide_model(model, stage_guidance_scale)
                    internal_batch_size *= 2
                samples_it = diffusion.p_sample_loop_progressive(
                    model,
                    shape=(internal_batch_size, *sample_shape[1:]),
                    model_kwargs=stage_model_kwargs,
                    device=self.device,
                    clip_denoised=self.clip_denoised,
                )
            for x in samples_it:
                # Keep only the conditional half when the batch was doubled.
                samples = x["pred_xstart"][:batch_size]
                if "low_res" in stage_model_kwargs:
                    # Prepend the (fixed) low-res points so downstream code
                    # always sees the full point count.
                    samples = torch.cat(
                        [stage_model_kwargs["low_res"][: len(samples)], samples], dim=-1
                    )
                yield samples

    @classmethod
    def combine(cls, *samplers: "PointCloudSampler") -> "PointCloudSampler":
        """
        Concatenate several samplers' stages into a single multi-stage
        sampler. All samplers must share device, aux_channels, and
        clip_denoised.
        """
        assert all(x.device == samplers[0].device for x in samplers[1:])
        assert all(x.aux_channels == samplers[0].aux_channels for x in samplers[1:])
        assert all(x.clip_denoised == samplers[0].clip_denoised for x in samplers[1:])
        return cls(
            device=samplers[0].device,
            models=[x for y in samplers for x in y.models],
            diffusions=[x for y in samplers for x in y.diffusions],
            num_points=[x for y in samplers for x in y.num_points],
            aux_channels=samplers[0].aux_channels,
            model_kwargs_key_filter=[x for y in samplers for x in y.model_kwargs_key_filter],
            guidance_scale=[x for y in samplers for x in y.guidance_scale],
            clip_denoised=samplers[0].clip_denoised,
            use_karras=[x for y in samplers for x in y.use_karras],
            karras_steps=[x for y in samplers for x in y.karras_steps],
            sigma_min=[x for y in samplers for x in y.sigma_min],
            sigma_max=[x for y in samplers for x in y.sigma_max],
            s_churn=[x for y in samplers for x in y.s_churn],
        )

    def _uncond_guide_model(
        self, model: Callable[..., torch.Tensor], scale: float
    ) -> Callable[..., torch.Tensor]:
        """
        Wrap a model for classifier-free guidance: the input batch is assumed
        to hold [conditional | unconditional] halves, and the epsilon outputs
        are blended as uncond + scale * (cond - uncond).
        """

        def model_fn(x_t, ts, **kwargs):
            half = x_t[: len(x_t) // 2]
            combined = torch.cat([half, half], dim=0)
            model_out = model(combined, ts, **kwargs)
            # First 3 output channels are epsilon (xyz); the rest pass through.
            eps, rest = model_out[:, :3], model_out[:, 3:]
            cond_eps, uncond_eps = torch.chunk(eps, 2, dim=0)
            half_eps = uncond_eps + scale * (cond_eps - uncond_eps)
            eps = torch.cat([half_eps, half_eps], dim=0)
            return torch.cat([eps, rest], dim=1)

        return model_fn

    def split_model_output(
        self,
        output: torch.Tensor,
        rescale_colors: bool = False,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Split a raw [B x C x N] sample tensor into position and aux channels.

        :param output: sample tensor whose channel dim is 3 (xyz) followed by
            one channel per entry of self.aux_channels.
        :param rescale_colors: if True, map R/G/B/A channels back to [0, 1].
        :return: (positions [B x 3 x N], dict of per-name [B x N] aux tensors).
        """
        assert (
            len(self.aux_channels) + 3 == output.shape[1]
        ), "there must be three spatial channels before aux"
        pos, joined_aux = output[:, :3], output[:, 3:]

        aux = {}
        for i, name in enumerate(self.aux_channels):
            v = joined_aux[:, i]
            if name in {"R", "G", "B", "A"}:
                # Color channels are modeled in 0-255; snap to valid bytes.
                v = v.clamp(0, 255).round()
                if rescale_colors:
                    v = v / 255.0
            aux[name] = v
        return pos, aux

    def output_to_point_clouds(self, output: torch.Tensor) -> List[PointCloud]:
        """
        Convert a batch of sample tensors into PointCloud objects, with
        coords as [N x 3] numpy arrays and colors rescaled to [0, 1].
        """
        res = []
        for sample in output:
            xyz, aux = self.split_model_output(sample[None], rescale_colors=True)
            res.append(
                PointCloud(
                    coords=xyz[0].t().cpu().numpy(),
                    channels={k: v[0].cpu().numpy() for k, v in aux.items()},
                )
            )
        return res

    def with_options(
        self,
        guidance_scale: float,
        clip_denoised: bool,
        use_karras: Sequence[bool] = (True, True),
        karras_steps: Sequence[int] = (64, 64),
        sigma_min: Sequence[float] = (1e-3, 1e-3),
        sigma_max: Sequence[float] = (120, 160),
        s_churn: Sequence[float] = (3, 0),
    ) -> "PointCloudSampler":
        """
        Return a new sampler sharing this one's models/diffusions but with
        different sampling options.
        """
        return PointCloudSampler(
            device=self.device,
            models=self.models,
            diffusions=self.diffusions,
            num_points=self.num_points,
            aux_channels=self.aux_channels,
            model_kwargs_key_filter=self.model_kwargs_key_filter,
            guidance_scale=guidance_scale,
            clip_denoised=clip_denoised,
            use_karras=use_karras,
            karras_steps=karras_steps,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            s_churn=s_churn,
        )
+ )