dooraven commited on
Commit
1cecf80
1 Parent(s): d3112c8
Files changed (3) hide show
  1. app.py +1 -1
  2. pc.py +0 -174
  3. sampler.py +0 -263
app.py CHANGED
@@ -117,7 +117,7 @@ def generate_3D(input, model_name='base1B', guidance_scale=3.0, grid_size=128):
117
  set_state('Converting to mesh...')
118
 
119
  uniqid = uuid.uuid4()
120
- file_path = f'/tmp/mesh-{uniqid}.npy'
121
  save_ply(pc, file_path)
122
 
123
  set_state('')
 
117
  set_state('Converting to mesh...')
118
 
119
  uniqid = uuid.uuid4()
120
+ file_path = f'/tmp/mesh-{uniqid}.npz'
121
  save_ply(pc, file_path)
122
 
123
  set_state('')
pc.py DELETED
@@ -1,174 +0,0 @@
1
- import random
2
- from dataclasses import dataclass
3
- from typing import BinaryIO, Dict, List, Optional, Union
4
-
5
- import numpy as np
6
-
7
- from .ply_util import write_ply
8
-
9
- COLORS = frozenset(["R", "G", "B", "A"])
10
-
11
-
12
- def preprocess(data, channel):
13
- if channel in COLORS:
14
- return np.round(data * 255.0)
15
- return data
16
-
17
-
18
- @dataclass
19
- class PointCloud:
20
- """
21
- An array of points sampled on a surface. Each point may have zero or more
22
- channel attributes.
23
-
24
- :param coords: an [N x 3] array of point coordinates.
25
- :param channels: a dict mapping names to [N] arrays of channel values.
26
- """
27
-
28
- coords: np.ndarray
29
- channels: Dict[str, np.ndarray]
30
-
31
- @classmethod
32
- def load(cls, f: Union[str, BinaryIO]) -> "PointCloud":
33
- """
34
- Load the point cloud from a .npz file.
35
- """
36
- if isinstance(f, str):
37
- with open(f, "rb") as reader:
38
- return cls.load(reader)
39
- else:
40
- obj = np.load(f)
41
- keys = list(obj.keys())
42
- return PointCloud(
43
- coords=obj["coords"],
44
- channels={k: obj[k] for k in keys if k != "coords"},
45
- )
46
-
47
- def save(self, f: Union[str, BinaryIO]):
48
- """
49
- Save the point cloud to a .npz file.
50
- """
51
- if isinstance(f, str):
52
- with open(f, "wb") as writer:
53
- self.save(writer)
54
- else:
55
- np.save(f, coords=self.coords, **self.channels)
56
-
57
- def write_ply(self, raw_f: BinaryIO):
58
- write_ply(
59
- raw_f,
60
- coords=self.coords,
61
- rgb=(
62
- np.stack([self.channels[x] for x in "RGB"], axis=1)
63
- if all(x in self.channels for x in "RGB")
64
- else None
65
- ),
66
- )
67
-
68
- def random_sample(self, num_points: int, **subsample_kwargs) -> "PointCloud":
69
- """
70
- Sample a random subset of this PointCloud.
71
-
72
- :param num_points: maximum number of points to sample.
73
- :param subsample_kwargs: arguments to self.subsample().
74
- :return: a reduced PointCloud, or self if num_points is not less than
75
- the current number of points.
76
- """
77
- if len(self.coords) <= num_points:
78
- return self
79
- indices = np.random.choice(len(self.coords), size=(num_points,), replace=False)
80
- return self.subsample(indices, **subsample_kwargs)
81
-
82
- def farthest_point_sample(
83
- self, num_points: int, init_idx: Optional[int] = None, **subsample_kwargs
84
- ) -> "PointCloud":
85
- """
86
- Sample a subset of the point cloud that is evenly distributed in space.
87
-
88
- First, a random point is selected. Then each successive point is chosen
89
- such that it is furthest from the currently selected points.
90
-
91
- The time complexity of this operation is O(NM), where N is the original
92
- number of points and M is the reduced number. Therefore, performance
93
- can be improved by randomly subsampling points with random_sample()
94
- before running farthest_point_sample().
95
-
96
- :param num_points: maximum number of points to sample.
97
- :param init_idx: if specified, the first point to sample.
98
- :param subsample_kwargs: arguments to self.subsample().
99
- :return: a reduced PointCloud, or self if num_points is not less than
100
- the current number of points.
101
- """
102
- if len(self.coords) <= num_points:
103
- return self
104
- init_idx = random.randrange(len(self.coords)) if init_idx is None else init_idx
105
- indices = np.zeros([num_points], dtype=np.int64)
106
- indices[0] = init_idx
107
- sq_norms = np.sum(self.coords**2, axis=-1)
108
-
109
- def compute_dists(idx: int):
110
- # Utilize equality: ||A-B||^2 = ||A||^2 + ||B||^2 - 2*(A @ B).
111
- return sq_norms + sq_norms[idx] - 2 * (self.coords @ self.coords[idx])
112
-
113
- cur_dists = compute_dists(init_idx)
114
- for i in range(1, num_points):
115
- idx = np.argmax(cur_dists)
116
- indices[i] = idx
117
- cur_dists = np.minimum(cur_dists, compute_dists(idx))
118
- return self.subsample(indices, **subsample_kwargs)
119
-
120
- def subsample(self, indices: np.ndarray, average_neighbors: bool = False) -> "PointCloud":
121
- if not average_neighbors:
122
- return PointCloud(
123
- coords=self.coords[indices],
124
- channels={k: v[indices] for k, v in self.channels.items()},
125
- )
126
-
127
- new_coords = self.coords[indices]
128
- neighbor_indices = PointCloud(coords=new_coords, channels={}).nearest_points(self.coords)
129
-
130
- # Make sure every point points to itself, which might not
131
- # be the case if points are duplicated or there is rounding
132
- # error.
133
- neighbor_indices[indices] = np.arange(len(indices))
134
-
135
- new_channels = {}
136
- for k, v in self.channels.items():
137
- v_sum = np.zeros_like(v[: len(indices)])
138
- v_count = np.zeros_like(v[: len(indices)])
139
- np.add.at(v_sum, neighbor_indices, v)
140
- np.add.at(v_count, neighbor_indices, 1)
141
- new_channels[k] = v_sum / v_count
142
- return PointCloud(coords=new_coords, channels=new_channels)
143
-
144
- def select_channels(self, channel_names: List[str]) -> np.ndarray:
145
- data = np.stack([preprocess(self.channels[name], name) for name in channel_names], axis=-1)
146
- return data
147
-
148
- def nearest_points(self, points: np.ndarray, batch_size: int = 16384) -> np.ndarray:
149
- """
150
- For each point in another set of points, compute the point in this
151
- pointcloud which is closest.
152
-
153
- :param points: an [N x 3] array of points.
154
- :param batch_size: the number of neighbor distances to compute at once.
155
- Smaller values save memory, while larger values may
156
- make the computation faster.
157
- :return: an [N] array of indices into self.coords.
158
- """
159
- norms = np.sum(self.coords**2, axis=-1)
160
- all_indices = []
161
- for i in range(0, len(points), batch_size):
162
- batch = points[i : i + batch_size]
163
- dists = norms + np.sum(batch**2, axis=-1)[:, None] - 2 * (batch @ self.coords.T)
164
- all_indices.append(np.argmin(dists, axis=-1))
165
- return np.concatenate(all_indices, axis=0)
166
-
167
- def combine(self, other: "PointCloud") -> "PointCloud":
168
- assert self.channels.keys() == other.channels.keys()
169
- return PointCloud(
170
- coords=np.concatenate([self.coords, other.coords], axis=0),
171
- channels={
172
- k: np.concatenate([v, other.channels[k]], axis=0) for k, v in self.channels.items()
173
- },
174
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
sampler.py DELETED
@@ -1,263 +0,0 @@
1
- """
2
- Helpers for sampling from a single- or multi-stage point cloud diffusion model.
3
- """
4
-
5
- from typing import Any, Callable, Dict, Iterator, List, Sequence, Tuple
6
-
7
- import torch
8
- import torch.nn as nn
9
-
10
- from .pc import PointCloud
11
-
12
- from point_e.diffusion.gaussian_diffusion import GaussianDiffusion
13
- from point_e.diffusion.k_diffusion import karras_sample_progressive
14
-
15
-
16
- class PointCloudSampler:
17
- """
18
- A wrapper around a model or stack of models that produces conditional or
19
- unconditional sample tensors.
20
-
21
- By default, this will load models and configs from files.
22
- If you want to modify the sampler arguments of an existing sampler, call
23
- with_options() or with_args().
24
- """
25
-
26
- def __init__(
27
- self,
28
- device: torch.device,
29
- models: Sequence[nn.Module],
30
- diffusions: Sequence[GaussianDiffusion],
31
- num_points: Sequence[int],
32
- aux_channels: Sequence[str],
33
- model_kwargs_key_filter: Sequence[str] = ("*",),
34
- guidance_scale: Sequence[float] = (3.0, 3.0),
35
- clip_denoised: bool = True,
36
- use_karras: Sequence[bool] = (True, True),
37
- karras_steps: Sequence[int] = (64, 64),
38
- sigma_min: Sequence[float] = (1e-3, 1e-3),
39
- sigma_max: Sequence[float] = (120, 160),
40
- s_churn: Sequence[float] = (3, 0),
41
- ):
42
- n = len(models)
43
- assert n > 0
44
-
45
- if n > 1:
46
- if len(guidance_scale) == 1:
47
- # Don't guide the upsamplers by default.
48
- guidance_scale = list(guidance_scale) + [1.0] * (n - 1)
49
- if len(use_karras) == 1:
50
- use_karras = use_karras * n
51
- if len(karras_steps) == 1:
52
- karras_steps = karras_steps * n
53
- if len(sigma_min) == 1:
54
- sigma_min = sigma_min * n
55
- if len(sigma_max) == 1:
56
- sigma_max = sigma_max * n
57
- if len(s_churn) == 1:
58
- s_churn = s_churn * n
59
- if len(model_kwargs_key_filter) == 1:
60
- model_kwargs_key_filter = model_kwargs_key_filter * n
61
- if len(model_kwargs_key_filter) == 0:
62
- model_kwargs_key_filter = ["*"] * n
63
- assert len(guidance_scale) == n
64
- assert len(use_karras) == n
65
- assert len(karras_steps) == n
66
- assert len(sigma_min) == n
67
- assert len(sigma_max) == n
68
- assert len(s_churn) == n
69
- assert len(model_kwargs_key_filter) == n
70
-
71
- self.device = device
72
- self.num_points = num_points
73
- self.aux_channels = aux_channels
74
- self.model_kwargs_key_filter = model_kwargs_key_filter
75
- self.guidance_scale = guidance_scale
76
- self.clip_denoised = clip_denoised
77
- self.use_karras = use_karras
78
- self.karras_steps = karras_steps
79
- self.sigma_min = sigma_min
80
- self.sigma_max = sigma_max
81
- self.s_churn = s_churn
82
-
83
- self.models = models
84
- self.diffusions = diffusions
85
-
86
- @property
87
- def num_stages(self) -> int:
88
- return len(self.models)
89
-
90
- def sample_batch(self, batch_size: int, model_kwargs: Dict[str, Any]) -> torch.Tensor:
91
- samples = None
92
- for x in self.sample_batch_progressive(batch_size, model_kwargs):
93
- samples = x
94
- return samples
95
-
96
- def sample_batch_progressive(
97
- self, batch_size: int, model_kwargs: Dict[str, Any]
98
- ) -> Iterator[torch.Tensor]:
99
- samples = None
100
- for (
101
- model,
102
- diffusion,
103
- stage_num_points,
104
- stage_guidance_scale,
105
- stage_use_karras,
106
- stage_karras_steps,
107
- stage_sigma_min,
108
- stage_sigma_max,
109
- stage_s_churn,
110
- stage_key_filter,
111
- ) in zip(
112
- self.models,
113
- self.diffusions,
114
- self.num_points,
115
- self.guidance_scale,
116
- self.use_karras,
117
- self.karras_steps,
118
- self.sigma_min,
119
- self.sigma_max,
120
- self.s_churn,
121
- self.model_kwargs_key_filter,
122
- ):
123
- stage_model_kwargs = model_kwargs.copy()
124
- if stage_key_filter != "*":
125
- use_keys = set(stage_key_filter.split(","))
126
- stage_model_kwargs = {k: v for k, v in stage_model_kwargs.items() if k in use_keys}
127
- if samples is not None:
128
- stage_model_kwargs["low_res"] = samples
129
- if hasattr(model, "cached_model_kwargs"):
130
- stage_model_kwargs = model.cached_model_kwargs(batch_size, stage_model_kwargs)
131
- sample_shape = (batch_size, 3 + len(self.aux_channels), stage_num_points)
132
-
133
- if stage_guidance_scale != 1 and stage_guidance_scale != 0:
134
- for k, v in stage_model_kwargs.copy().items():
135
- stage_model_kwargs[k] = torch.cat([v, torch.zeros_like(v)], dim=0)
136
-
137
- if stage_use_karras:
138
- samples_it = karras_sample_progressive(
139
- diffusion=diffusion,
140
- model=model,
141
- shape=sample_shape,
142
- steps=stage_karras_steps,
143
- clip_denoised=self.clip_denoised,
144
- model_kwargs=stage_model_kwargs,
145
- device=self.device,
146
- sigma_min=stage_sigma_min,
147
- sigma_max=stage_sigma_max,
148
- s_churn=stage_s_churn,
149
- guidance_scale=stage_guidance_scale,
150
- )
151
- else:
152
- internal_batch_size = batch_size
153
- if stage_guidance_scale:
154
- model = self._uncond_guide_model(model, stage_guidance_scale)
155
- internal_batch_size *= 2
156
- samples_it = diffusion.p_sample_loop_progressive(
157
- model,
158
- shape=(internal_batch_size, *sample_shape[1:]),
159
- model_kwargs=stage_model_kwargs,
160
- device=self.device,
161
- clip_denoised=self.clip_denoised,
162
- )
163
- for x in samples_it:
164
- samples = x["pred_xstart"][:batch_size]
165
- if "low_res" in stage_model_kwargs:
166
- samples = torch.cat(
167
- [stage_model_kwargs["low_res"][: len(samples)], samples], dim=-1
168
- )
169
- yield samples
170
-
171
- @classmethod
172
- def combine(cls, *samplers: "PointCloudSampler") -> "PointCloudSampler":
173
- assert all(x.device == samplers[0].device for x in samplers[1:])
174
- assert all(x.aux_channels == samplers[0].aux_channels for x in samplers[1:])
175
- assert all(x.clip_denoised == samplers[0].clip_denoised for x in samplers[1:])
176
- return cls(
177
- device=samplers[0].device,
178
- models=[x for y in samplers for x in y.models],
179
- diffusions=[x for y in samplers for x in y.diffusions],
180
- num_points=[x for y in samplers for x in y.num_points],
181
- aux_channels=samplers[0].aux_channels,
182
- model_kwargs_key_filter=[x for y in samplers for x in y.model_kwargs_key_filter],
183
- guidance_scale=[x for y in samplers for x in y.guidance_scale],
184
- clip_denoised=samplers[0].clip_denoised,
185
- use_karras=[x for y in samplers for x in y.use_karras],
186
- karras_steps=[x for y in samplers for x in y.karras_steps],
187
- sigma_min=[x for y in samplers for x in y.sigma_min],
188
- sigma_max=[x for y in samplers for x in y.sigma_max],
189
- s_churn=[x for y in samplers for x in y.s_churn],
190
- )
191
-
192
- def _uncond_guide_model(
193
- self, model: Callable[..., torch.Tensor], scale: float
194
- ) -> Callable[..., torch.Tensor]:
195
- def model_fn(x_t, ts, **kwargs):
196
- half = x_t[: len(x_t) // 2]
197
- combined = torch.cat([half, half], dim=0)
198
- model_out = model(combined, ts, **kwargs)
199
- eps, rest = model_out[:, :3], model_out[:, 3:]
200
- cond_eps, uncond_eps = torch.chunk(eps, 2, dim=0)
201
- half_eps = uncond_eps + scale * (cond_eps - uncond_eps)
202
- eps = torch.cat([half_eps, half_eps], dim=0)
203
- return torch.cat([eps, rest], dim=1)
204
-
205
- return model_fn
206
-
207
- def split_model_output(
208
- self,
209
- output: torch.Tensor,
210
- rescale_colors: bool = False,
211
- ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
212
- assert (
213
- len(self.aux_channels) + 3 == output.shape[1]
214
- ), "there must be three spatial channels before aux"
215
- pos, joined_aux = output[:, :3], output[:, 3:]
216
-
217
- aux = {}
218
- for i, name in enumerate(self.aux_channels):
219
- v = joined_aux[:, i]
220
- if name in {"R", "G", "B", "A"}:
221
- v = v.clamp(0, 255).round()
222
- if rescale_colors:
223
- v = v / 255.0
224
- aux[name] = v
225
- return pos, aux
226
-
227
- def output_to_point_clouds(self, output: torch.Tensor) -> List[PointCloud]:
228
- res = []
229
- for sample in output:
230
- xyz, aux = self.split_model_output(sample[None], rescale_colors=True)
231
- res.append(
232
- PointCloud(
233
- coords=xyz[0].t().cpu().numpy(),
234
- channels={k: v[0].cpu().numpy() for k, v in aux.items()},
235
- )
236
- )
237
- return res
238
-
239
- def with_options(
240
- self,
241
- guidance_scale: float,
242
- clip_denoised: bool,
243
- use_karras: Sequence[bool] = (True, True),
244
- karras_steps: Sequence[int] = (64, 64),
245
- sigma_min: Sequence[float] = (1e-3, 1e-3),
246
- sigma_max: Sequence[float] = (120, 160),
247
- s_churn: Sequence[float] = (3, 0),
248
- ) -> "PointCloudSampler":
249
- return PointCloudSampler(
250
- device=self.device,
251
- models=self.models,
252
- diffusions=self.diffusions,
253
- num_points=self.num_points,
254
- aux_channels=self.aux_channels,
255
- model_kwargs_key_filter=self.model_kwargs_key_filter,
256
- guidance_scale=guidance_scale,
257
- clip_denoised=clip_denoised,
258
- use_karras=use_karras,
259
- karras_steps=karras_steps,
260
- sigma_min=sigma_min,
261
- sigma_max=sigma_max,
262
- s_churn=s_churn,
263
- )