Spaces:
Runtime error
Runtime error
wips
Browse files- app.py +3 -5
- pc.py +174 -0
- sampler.py +263 -0
app.py
CHANGED
@@ -8,11 +8,9 @@ import numpy as np
|
|
8 |
import argparse
|
9 |
|
10 |
from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
|
11 |
-
from
|
12 |
from point_e.models.download import load_checkpoint
|
13 |
from point_e.models.configs import MODEL_CONFIGS, model_from_config
|
14 |
-
from point_e.util.plotting import plot_point_cloud
|
15 |
-
from point_e.util.ply_util import write_ply
|
16 |
|
17 |
from diffusers import StableDiffusionPipeline
|
18 |
|
@@ -119,7 +117,7 @@ def generate_3D(input, model_name='base1B', guidance_scale=3.0, grid_size=128):
|
|
119 |
set_state('Converting to mesh...')
|
120 |
|
121 |
uniqid = uuid.uuid4()
|
122 |
-
file_path = f'/tmp/mesh-{uniqid}.
|
123 |
save_ply(pc, file_path)
|
124 |
|
125 |
set_state('')
|
@@ -153,7 +151,7 @@ def ply_to_glb(ply_file, glb_file):
|
|
153 |
def save_ply(pc, file_name):
|
154 |
# Produce a mesh (with vertex colors)
|
155 |
with open(file_name, 'wb') as f:
|
156 |
-
pc.
|
157 |
|
158 |
|
159 |
def create_gif(pc):
|
|
|
8 |
import argparse
|
9 |
|
10 |
from point_e.diffusion.configs import DIFFUSION_CONFIGS, diffusion_from_config
|
11 |
+
from .sampler import PointCloudSampler
|
12 |
from point_e.models.download import load_checkpoint
|
13 |
from point_e.models.configs import MODEL_CONFIGS, model_from_config
|
|
|
|
|
14 |
|
15 |
from diffusers import StableDiffusionPipeline
|
16 |
|
|
|
117 |
set_state('Converting to mesh...')
|
118 |
|
119 |
uniqid = uuid.uuid4()
|
120 |
+
file_path = f'/tmp/mesh-{uniqid}.npy'
|
121 |
save_ply(pc, file_path)
|
122 |
|
123 |
set_state('')
|
|
|
151 |
def save_ply(pc, file_name):
    """
    Persist ``pc`` at ``file_name``.

    The path is opened in binary write mode and the open handle is passed to
    ``pc.save``, which decides the actual on-disk format.
    """
    with open(file_name, mode='wb') as out_file:
        pc.save(out_file)
|
155 |
|
156 |
|
157 |
def create_gif(pc):
|
pc.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from typing import BinaryIO, Dict, List, Optional, Union
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from .ply_util import write_ply
|
8 |
+
|
9 |
+
# Names of the channels that hold color/alpha values.
COLORS = frozenset("RGBA")


def preprocess(data, channel):
    """
    Prepare the values of one channel for export.

    :param data: the channel values.
    :param channel: the channel name.
    :return: ``data`` multiplied by 255 and rounded when ``channel`` is one of
             COLORS; otherwise ``data`` itself, untouched.
    """
    if channel not in COLORS:
        return data
    scaled = data * 255.0
    return np.round(scaled)
|
16 |
+
|
17 |
+
|
18 |
+
@dataclass
class PointCloud:
    """
    An array of points sampled on a surface. Each point may have zero or more
    channel attributes.

    :param coords: an [N x 3] array of point coordinates.
    :param channels: a dict mapping names to [N] arrays of channel values.
    """

    coords: np.ndarray
    channels: Dict[str, np.ndarray]

    @classmethod
    def load(cls, f: Union[str, BinaryIO]) -> "PointCloud":
        """
        Load the point cloud from a .npz file.

        :param f: a file path, or a binary file object open for reading.
        :return: a PointCloud whose coords come from the "coords" entry and
                 whose channels are every other entry in the archive.
        """
        if isinstance(f, str):
            with open(f, "rb") as reader:
                return cls.load(reader)
        else:
            obj = np.load(f)
            keys = list(obj.keys())
            return PointCloud(
                coords=obj["coords"],
                channels={k: obj[k] for k in keys if k != "coords"},
            )

    def save(self, f: Union[str, BinaryIO]):
        """
        Save the point cloud to a .npz file.

        :param f: a file path, or a binary file object open for writing.
        """
        if isinstance(f, str):
            with open(f, "wb") as writer:
                self.save(writer)
        else:
            # np.savez (not np.save) is required here: it stores several
            # named arrays in one archive, which is what load() reads back.
            # np.save only accepts a single array and rejects these keywords.
            np.savez(f, coords=self.coords, **self.channels)

    def write_ply(self, raw_f: BinaryIO):
        """
        Write the points through the module-level write_ply() helper.

        An [N x 3] rgb array is passed only when all of the "R", "G" and "B"
        channels are present; otherwise rgb is None.

        :param raw_f: a binary file object to write to.
        """
        write_ply(
            raw_f,
            coords=self.coords,
            rgb=(
                np.stack([self.channels[x] for x in "RGB"], axis=1)
                if all(x in self.channels for x in "RGB")
                else None
            ),
        )

    def random_sample(self, num_points: int, **subsample_kwargs) -> "PointCloud":
        """
        Sample a random subset of this PointCloud.

        :param num_points: maximum number of points to sample.
        :param subsample_kwargs: arguments to self.subsample().
        :return: a reduced PointCloud, or self if num_points is not less than
                 the current number of points.
        """
        if len(self.coords) <= num_points:
            return self
        indices = np.random.choice(len(self.coords), size=(num_points,), replace=False)
        return self.subsample(indices, **subsample_kwargs)

    def farthest_point_sample(
        self, num_points: int, init_idx: Optional[int] = None, **subsample_kwargs
    ) -> "PointCloud":
        """
        Sample a subset of the point cloud that is evenly distributed in space.

        First, a random point is selected. Then each successive point is chosen
        such that it is furthest from the currently selected points.

        The time complexity of this operation is O(NM), where N is the original
        number of points and M is the reduced number. Therefore, performance
        can be improved by randomly subsampling points with random_sample()
        before running farthest_point_sample().

        :param num_points: maximum number of points to sample.
        :param init_idx: if specified, the first point to sample.
        :param subsample_kwargs: arguments to self.subsample().
        :return: a reduced PointCloud, or self if num_points is not less than
                 the current number of points.
        """
        if len(self.coords) <= num_points:
            return self
        init_idx = random.randrange(len(self.coords)) if init_idx is None else init_idx
        indices = np.zeros([num_points], dtype=np.int64)
        indices[0] = init_idx
        sq_norms = np.sum(self.coords**2, axis=-1)

        def compute_dists(idx: int):
            # Utilize equality: ||A-B||^2 = ||A||^2 + ||B||^2 - 2*(A @ B).
            return sq_norms + sq_norms[idx] - 2 * (self.coords @ self.coords[idx])

        # cur_dists[j] is the squared distance from point j to the closest
        # point selected so far.
        cur_dists = compute_dists(init_idx)
        for i in range(1, num_points):
            idx = np.argmax(cur_dists)
            indices[i] = idx
            cur_dists = np.minimum(cur_dists, compute_dists(idx))
        return self.subsample(indices, **subsample_kwargs)

    def subsample(self, indices: np.ndarray, average_neighbors: bool = False) -> "PointCloud":
        """
        Build a new PointCloud from the points at the given indices.

        :param indices: indices into self.coords of the points to keep.
        :param average_neighbors: if False, channel values are copied from the
                                  kept points. If True, every original point
                                  is assigned to its nearest kept point and
                                  each kept point's channel value is the mean
                                  over the points assigned to it.
        :return: the reduced PointCloud.
        """
        if not average_neighbors:
            return PointCloud(
                coords=self.coords[indices],
                channels={k: v[indices] for k, v in self.channels.items()},
            )

        new_coords = self.coords[indices]
        neighbor_indices = PointCloud(coords=new_coords, channels={}).nearest_points(self.coords)

        # Make sure every point points to itself, which might not
        # be the case if points are duplicated or there is rounding
        # error.
        neighbor_indices[indices] = np.arange(len(indices))

        new_channels = {}
        for k, v in self.channels.items():
            v_sum = np.zeros_like(v[: len(indices)])
            v_count = np.zeros_like(v[: len(indices)])
            # np.add.at accumulates correctly when neighbor_indices repeats.
            np.add.at(v_sum, neighbor_indices, v)
            np.add.at(v_count, neighbor_indices, 1)
            new_channels[k] = v_sum / v_count
        return PointCloud(coords=new_coords, channels=new_channels)

    def select_channels(self, channel_names: List[str]) -> np.ndarray:
        """
        Stack the requested channels along a new last axis.

        Each channel is passed through preprocess() before stacking.

        :param channel_names: names of the channels to select, in order.
        :return: the stacked channel data.
        """
        data = np.stack([preprocess(self.channels[name], name) for name in channel_names], axis=-1)
        return data

    def nearest_points(self, points: np.ndarray, batch_size: int = 16384) -> np.ndarray:
        """
        For each point in another set of points, compute the point in this
        pointcloud which is closest.

        :param points: an [N x 3] array of points.
        :param batch_size: the number of neighbor distances to compute at once.
                           Smaller values save memory, while larger values may
                           make the computation faster.
        :return: an [N] array of indices into self.coords.
        """
        norms = np.sum(self.coords**2, axis=-1)
        all_indices = []
        for i in range(0, len(points), batch_size):
            batch = points[i : i + batch_size]
            # Squared distances via ||A-B||^2 = ||A||^2 + ||B||^2 - 2*(A @ B).
            dists = norms + np.sum(batch**2, axis=-1)[:, None] - 2 * (batch @ self.coords.T)
            all_indices.append(np.argmin(dists, axis=-1))
        return np.concatenate(all_indices, axis=0)

    def combine(self, other: "PointCloud") -> "PointCloud":
        """
        Concatenate this point cloud with another one.

        Both clouds must have exactly the same channel names.

        :param other: the point cloud whose points are appended after self's.
        :return: a new PointCloud holding the points of both.
        """
        assert self.channels.keys() == other.channels.keys()
        return PointCloud(
            coords=np.concatenate([self.coords, other.coords], axis=0),
            channels={
                k: np.concatenate([v, other.channels[k]], axis=0) for k, v in self.channels.items()
            },
        )
|
sampler.py
ADDED
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Helpers for sampling from a single- or multi-stage point cloud diffusion model.
|
3 |
+
"""
|
4 |
+
|
5 |
+
from typing import Any, Callable, Dict, Iterator, List, Sequence, Tuple
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
|
10 |
+
from .pc import PointCloud
|
11 |
+
|
12 |
+
from point_e.diffusion.gaussian_diffusion import GaussianDiffusion
|
13 |
+
from point_e.diffusion.k_diffusion import karras_sample_progressive
|
14 |
+
|
15 |
+
|
16 |
+
class PointCloudSampler:
    """
    A wrapper around a model or stack of models that produces conditional or
    unconditional sample tensors.

    The stages are run in order; the output of each stage is fed to the next
    one as the "low_res" model argument. Every per-stage option is a sequence
    with one entry per model.

    If you want to modify the sampler arguments of an existing sampler, call
    with_options().
    """

    def __init__(
        self,
        device: torch.device,
        models: Sequence[nn.Module],
        diffusions: Sequence["GaussianDiffusion"],
        num_points: Sequence[int],
        aux_channels: Sequence[str],
        model_kwargs_key_filter: Sequence[str] = ("*",),
        guidance_scale: Sequence[float] = (3.0, 3.0),
        clip_denoised: bool = True,
        use_karras: Sequence[bool] = (True, True),
        karras_steps: Sequence[int] = (64, 64),
        sigma_min: Sequence[float] = (1e-3, 1e-3),
        sigma_max: Sequence[float] = (120, 160),
        s_churn: Sequence[float] = (3, 0),
    ):
        """
        :param device: the device passed to the sampling routines.
        :param models: one model per stage; must be non-empty.
        :param diffusions: one diffusion object per stage.
        :param num_points: number of points generated by each stage.
        :param aux_channels: names of the channels following the three
                             spatial channels in the model output.
        :param model_kwargs_key_filter: per stage, either "*" (keep every
                                        model kwarg) or a comma-separated list
                                        of the kwarg names to keep.
        :param guidance_scale: per-stage guidance scale.
        :param clip_denoised: forwarded to the sampling routines.
        :param use_karras: per stage, whether to use the Karras sampler.
        :param karras_steps: per-stage step count for the Karras sampler.
        :param sigma_min: per-stage sigma_min for the Karras sampler.
        :param sigma_max: per-stage sigma_max for the Karras sampler.
        :param s_churn: per-stage s_churn for the Karras sampler.

        When there is more than one stage, any per-stage option given with a
        single entry is repeated for every stage, except guidance_scale,
        whose single entry applies to the first stage only and is padded with
        1.0 for the rest.
        """
        n = len(models)
        assert n > 0

        if n > 1:
            if len(guidance_scale) == 1:
                # Don't guide the upsamplers by default.
                guidance_scale = list(guidance_scale) + [1.0] * (n - 1)
            if len(use_karras) == 1:
                use_karras = use_karras * n
            if len(karras_steps) == 1:
                karras_steps = karras_steps * n
            if len(sigma_min) == 1:
                sigma_min = sigma_min * n
            if len(sigma_max) == 1:
                sigma_max = sigma_max * n
            if len(s_churn) == 1:
                s_churn = s_churn * n
            if len(model_kwargs_key_filter) == 1:
                model_kwargs_key_filter = model_kwargs_key_filter * n
        if len(model_kwargs_key_filter) == 0:
            model_kwargs_key_filter = ["*"] * n
        assert len(guidance_scale) == n
        assert len(use_karras) == n
        assert len(karras_steps) == n
        assert len(sigma_min) == n
        assert len(sigma_max) == n
        assert len(s_churn) == n
        assert len(model_kwargs_key_filter) == n

        self.device = device
        self.num_points = num_points
        self.aux_channels = aux_channels
        self.model_kwargs_key_filter = model_kwargs_key_filter
        self.guidance_scale = guidance_scale
        self.clip_denoised = clip_denoised
        self.use_karras = use_karras
        self.karras_steps = karras_steps
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.s_churn = s_churn

        self.models = models
        self.diffusions = diffusions

    @property
    def num_stages(self) -> int:
        """The number of models (stages) in this sampler."""
        return len(self.models)

    def sample_batch(self, batch_size: int, model_kwargs: Dict[str, Any]) -> torch.Tensor:
        """
        Run every stage to completion and return the final samples.

        :param batch_size: the number of samples to produce.
        :param model_kwargs: conditioning arguments for the models.
        :return: the last tensor yielded by sample_batch_progressive().
        """
        samples = None
        for x in self.sample_batch_progressive(batch_size, model_kwargs):
            samples = x
        return samples

    def sample_batch_progressive(
        self, batch_size: int, model_kwargs: Dict[str, Any]
    ) -> Iterator[torch.Tensor]:
        """
        Run every stage in order, yielding the current prediction after each
        sampling step.

        :param batch_size: the number of samples to produce.
        :param model_kwargs: conditioning arguments for the models; filtered
                             per stage by model_kwargs_key_filter.
        :return: an iterator over intermediate sample tensors. For stages
                 after the first, each yielded tensor is the previous stage's
                 output concatenated with the new prediction along the last
                 dimension.
        """
        samples = None
        for (
            model,
            diffusion,
            stage_num_points,
            stage_guidance_scale,
            stage_use_karras,
            stage_karras_steps,
            stage_sigma_min,
            stage_sigma_max,
            stage_s_churn,
            stage_key_filter,
        ) in zip(
            self.models,
            self.diffusions,
            self.num_points,
            self.guidance_scale,
            self.use_karras,
            self.karras_steps,
            self.sigma_min,
            self.sigma_max,
            self.s_churn,
            self.model_kwargs_key_filter,
        ):
            stage_model_kwargs = model_kwargs.copy()
            if stage_key_filter != "*":
                use_keys = set(stage_key_filter.split(","))
                stage_model_kwargs = {k: v for k, v in stage_model_kwargs.items() if k in use_keys}
            if samples is not None:
                # Condition this stage on the output of the previous one.
                stage_model_kwargs["low_res"] = samples
            if hasattr(model, "cached_model_kwargs"):
                stage_model_kwargs = model.cached_model_kwargs(batch_size, stage_model_kwargs)
            sample_shape = (batch_size, 3 + len(self.aux_channels), stage_num_points)

            if stage_guidance_scale != 1 and stage_guidance_scale != 0:
                # Append a zeroed copy of every kwarg along the batch
                # dimension, doubling its batch size.
                for k, v in stage_model_kwargs.copy().items():
                    stage_model_kwargs[k] = torch.cat([v, torch.zeros_like(v)], dim=0)

            if stage_use_karras:
                samples_it = karras_sample_progressive(
                    diffusion=diffusion,
                    model=model,
                    shape=sample_shape,
                    steps=stage_karras_steps,
                    clip_denoised=self.clip_denoised,
                    model_kwargs=stage_model_kwargs,
                    device=self.device,
                    sigma_min=stage_sigma_min,
                    sigma_max=stage_sigma_max,
                    s_churn=stage_s_churn,
                    guidance_scale=stage_guidance_scale,
                )
            else:
                internal_batch_size = batch_size
                # NOTE(review): the kwargs above are doubled only when the
                # scale is neither 0 nor 1, whereas this branch wraps the
                # model and doubles the batch for any non-zero scale
                # (including 1) -- confirm this asymmetry is intended.
                if stage_guidance_scale:
                    model = self._uncond_guide_model(model, stage_guidance_scale)
                    internal_batch_size *= 2
                samples_it = diffusion.p_sample_loop_progressive(
                    model,
                    shape=(internal_batch_size, *sample_shape[1:]),
                    model_kwargs=stage_model_kwargs,
                    device=self.device,
                    clip_denoised=self.clip_denoised,
                )
            for x in samples_it:
                # Keep only the first batch_size entries of the prediction.
                samples = x["pred_xstart"][:batch_size]
                if "low_res" in stage_model_kwargs:
                    samples = torch.cat(
                        [stage_model_kwargs["low_res"][: len(samples)], samples], dim=-1
                    )
                yield samples

    @classmethod
    def combine(cls, *samplers: "PointCloudSampler") -> "PointCloudSampler":
        """
        Chain several samplers into one multi-stage sampler.

        All samplers must share the same device, aux_channels and
        clip_denoised setting; the per-stage options are concatenated in the
        order the samplers are given.

        :param samplers: the samplers to chain.
        :return: a new sampler containing every stage of every input sampler.
        """
        assert all(x.device == samplers[0].device for x in samplers[1:])
        assert all(x.aux_channels == samplers[0].aux_channels for x in samplers[1:])
        assert all(x.clip_denoised == samplers[0].clip_denoised for x in samplers[1:])
        return cls(
            device=samplers[0].device,
            models=[x for y in samplers for x in y.models],
            diffusions=[x for y in samplers for x in y.diffusions],
            num_points=[x for y in samplers for x in y.num_points],
            aux_channels=samplers[0].aux_channels,
            model_kwargs_key_filter=[x for y in samplers for x in y.model_kwargs_key_filter],
            guidance_scale=[x for y in samplers for x in y.guidance_scale],
            clip_denoised=samplers[0].clip_denoised,
            use_karras=[x for y in samplers for x in y.use_karras],
            karras_steps=[x for y in samplers for x in y.karras_steps],
            sigma_min=[x for y in samplers for x in y.sigma_min],
            sigma_max=[x for y in samplers for x in y.sigma_max],
            s_churn=[x for y in samplers for x in y.s_churn],
        )

    def _uncond_guide_model(
        self, model: Callable[..., torch.Tensor], scale: float
    ) -> Callable[..., torch.Tensor]:
        """
        Wrap a model so that its output is guided by ``scale``.

        The wrapper evaluates the model on the first half of the batch
        duplicated twice, splits the first three output channels into two
        halves along the batch dimension (treated as conditional and
        unconditional), and replaces both halves with
        ``uncond + scale * (cond - uncond)``. The remaining channels are
        passed through unchanged.

        :param model: the model to wrap.
        :param scale: the guidance scale.
        :return: a callable with the signature (x_t, ts, **kwargs).
        """

        def model_fn(x_t, ts, **kwargs):
            half = x_t[: len(x_t) // 2]
            combined = torch.cat([half, half], dim=0)
            model_out = model(combined, ts, **kwargs)
            eps, rest = model_out[:, :3], model_out[:, 3:]
            cond_eps, uncond_eps = torch.chunk(eps, 2, dim=0)
            half_eps = uncond_eps + scale * (cond_eps - uncond_eps)
            eps = torch.cat([half_eps, half_eps], dim=0)
            return torch.cat([eps, rest], dim=1)

        return model_fn

    def split_model_output(
        self,
        output: torch.Tensor,
        rescale_colors: bool = False,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Split a model output into positions and named auxiliary channels.

        :param output: a tensor whose dimension 1 holds three spatial
                       channels followed by one channel per entry of
                       self.aux_channels.
        :param rescale_colors: if True, "R", "G", "B" and "A" channels are
                               divided by 255 after being clamped to
                               [0, 255] and rounded.
        :return: a tuple (pos, aux) of the first three channels and a dict
                 mapping each aux channel name to its slice of the output.
        """
        assert (
            len(self.aux_channels) + 3 == output.shape[1]
        ), "there must be three spatial channels before aux"
        pos, joined_aux = output[:, :3], output[:, 3:]

        aux = {}
        for i, name in enumerate(self.aux_channels):
            v = joined_aux[:, i]
            if name in {"R", "G", "B", "A"}:
                v = v.clamp(0, 255).round()
                if rescale_colors:
                    v = v / 255.0
            aux[name] = v
        return pos, aux

    def output_to_point_clouds(self, output: torch.Tensor) -> List["PointCloud"]:
        """
        Convert a batch of model outputs into PointCloud objects.

        :param output: a batch of sample tensors, one per point cloud.
        :return: one PointCloud per batch entry, with colors rescaled by
                 split_model_output(rescale_colors=True).
        """
        res = []
        for sample in output:
            xyz, aux = self.split_model_output(sample[None], rescale_colors=True)
            res.append(
                PointCloud(
                    coords=xyz[0].t().cpu().numpy(),
                    channels={k: v[0].cpu().numpy() for k, v in aux.items()},
                )
            )
        return res

    def with_options(
        self,
        guidance_scale: float,
        clip_denoised: bool,
        use_karras: Sequence[bool] = (True, True),
        karras_steps: Sequence[int] = (64, 64),
        sigma_min: Sequence[float] = (1e-3, 1e-3),
        sigma_max: Sequence[float] = (120, 160),
        s_churn: Sequence[float] = (3, 0),
    ) -> "PointCloudSampler":
        """
        Create a sampler with the same models but different sampling options.

        :param guidance_scale: either a single number or a per-stage
                               sequence. A single number is treated as a
                               one-element sequence, so with several stages
                               it guides only the first one and the rest
                               default to 1.0.
        :param clip_denoised: forwarded to the new sampler.
        :param use_karras: per-stage flag, forwarded to the new sampler.
        :param karras_steps: per-stage value, forwarded to the new sampler.
        :param sigma_min: per-stage value, forwarded to the new sampler.
        :param sigma_max: per-stage value, forwarded to the new sampler.
        :param s_churn: per-stage value, forwarded to the new sampler.
        :return: a new PointCloudSampler sharing this sampler's device,
                 models, diffusions, num_points, aux_channels and key filter.
        """
        # The constructor calls len() on guidance_scale, so a bare number
        # (which is what the annotation advertises) must be wrapped first.
        if isinstance(guidance_scale, (int, float)):
            guidance_scale = [guidance_scale]
        return PointCloudSampler(
            device=self.device,
            models=self.models,
            diffusions=self.diffusions,
            num_points=self.num_points,
            aux_channels=self.aux_channels,
            model_kwargs_key_filter=self.model_kwargs_key_filter,
            guidance_scale=guidance_scale,
            clip_denoised=clip_denoised,
            use_karras=use_karras,
            karras_steps=karras_steps,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            s_churn=s_churn,
        )
|