clean up
- README.md +8 -1
- main.py +13 -3
- mvdream/models.py +13 -25
- mvdream/util.py +0 -196
README.md CHANGED

````diff
@@ -12,7 +12,14 @@ wget https://raw.githubusercontent.com/bytedance/MVDream/main/mvdream/configs/sd
 python convert_mvdream_to_diffusers.py --checkpoint_path ./sd-v2.1-base-4view.pt --dump_path ./weights --original_config_file ./sd-v2-base.yaml --half --to_safetensors --test
 ```
 
-###
+### usage
+
+example:
+```bash
+python main.py "a cute owl"
+```
+
+detailed usage:
 ```python
 import torch
 import kiui
````
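For context, the new quick-start command assumes converted weights already sit in `./weights`. A combined sketch of the two commands already present in this README (same paths, nothing new) would be:

```bash
# Assumes sd-v2.1-base-4view.pt and sd-v2-base.yaml were downloaded as described earlier in the README.
python convert_mvdream_to_diffusers.py --checkpoint_path ./sd-v2.1-base-4view.pt \
    --dump_path ./weights --original_config_file ./sd-v2-base.yaml --half --to_safetensors --test
python main.py "a cute owl"
```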
main.py CHANGED

```diff
@@ -1,11 +1,21 @@
 import torch
 import kiui
+import numpy as np
+import argparse
 from mvdream.pipeline_mvdream import MVDreamStableDiffusionPipeline
 
 pipe = MVDreamStableDiffusionPipeline.from_pretrained('./weights', torch_dtype=torch.float16)
 pipe = pipe.to("cuda")
 
-prompt = "a photo of an astronaut riding a horse on mars"
-image = pipe(prompt)
 
-
+parser = argparse.ArgumentParser(description='MVDream')
+parser.add_argument('prompt', type=str, default="a cute owl 3d model")
+args = parser.parse_args()
+
+while True:
+    image = pipe(args.prompt)
+    grid = np.concatenate([
+        np.concatenate([image[0], image[2]], axis=0),
+        np.concatenate([image[1], image[3]], axis=0),
+    ], axis=1)
+    kiui.vis.plot_image(grid)
```
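The rewritten main.py loops forever and shows each 2x2 grid of views in kiui's interactive viewer. As a minimal non-interactive sketch (an assumption, not part of this commit), the same grid could be written to disk instead, provided the pipeline returns four HWC float arrays in [0, 1] as the concatenation code implies:

```python
# Hypothetical variant of main.py: generate once and save the 2x2 view grid
# instead of plotting in a loop. Assumes pipe(...) returns four HWC numpy
# arrays in [0, 1], as the np.concatenate code in main.py implies.
import numpy as np
import torch
from PIL import Image

from mvdream.pipeline_mvdream import MVDreamStableDiffusionPipeline

pipe = MVDreamStableDiffusionPipeline.from_pretrained('./weights', torch_dtype=torch.float16)
pipe = pipe.to("cuda")

image = pipe("a cute owl 3d model")
grid = np.concatenate([
    np.concatenate([image[0], image[2]], axis=0),
    np.concatenate([image[1], image[3]], axis=0),
], axis=1)
# Scale to uint8 before saving; drop the scaling if the arrays are already uint8.
Image.fromarray((grid * 255).clip(0, 255).astype(np.uint8)).save("grid.png")
```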
mvdream/models.py CHANGED

```diff
@@ -10,10 +10,8 @@ from abc import abstractmethod
 from .util import (
     checkpoint,
     conv_nd,
-    linear,
     avg_pool_nd,
     zero_module,
-    normalization,
     timestep_embedding,
 )
 from .attention import SpatialTransformer, SpatialTransformer3D
@@ -56,7 +54,7 @@ class MultiViewUNetWrapperModel(ModelMixin, ConfigMixin):
         adm_in_channels=None,
         camera_dim=None,):
         super().__init__()
-        self.unet
+        self.unet = MultiViewUNetModel(
             image_size=image_size,
             in_channels=in_channels,
             model_channels=model_channels,
@@ -218,7 +216,7 @@ class ResBlock(TimestepBlock):
         self.use_scale_shift_norm = use_scale_shift_norm
 
         self.in_layers = nn.Sequential(
-            normalization(channels),
+            nn.GroupNorm(32, channels),
             nn.SiLU(),
             conv_nd(dims, channels, self.out_channels, 3, padding=1),
         )
@@ -236,13 +234,13 @@
 
         self.emb_layers = nn.Sequential(
             nn.SiLU(),
-            linear(
+            nn.Linear(
                 emb_channels,
                 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
             ),
         )
         self.out_layers = nn.Sequential(
-            normalization(self.out_channels),
+            nn.GroupNorm(32, self.out_channels),
             nn.SiLU(),
             nn.Dropout(p=dropout),
             zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)),
@@ -310,7 +308,7 @@ class AttentionBlock(nn.Module):
             assert (channels % num_head_channels == 0), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
             self.num_heads = channels // num_head_channels
         self.use_checkpoint = use_checkpoint
-        self.norm = normalization(channels)
+        self.norm = nn.GroupNorm(32, channels)
         self.qkv = conv_nd(1, channels, channels * 3, 1)
         if use_new_attention_order:
             # split qkv before split heads
@@ -418,16 +416,6 @@ class QKVAttention(nn.Module):
         return count_flops_attn(model, _x, y)
 
 
-class Timestep(nn.Module):
-
-    def __init__(self, dim):
-        super().__init__()
-        self.dim = dim
-
-    def forward(self, t):
-        return timestep_embedding(t, self.dim)
-
-
 class MultiViewUNetModel(nn.Module):
     """
     The full multi-view UNet model with attention, timestep embedding and camera embedding.
@@ -545,17 +533,17 @@ class MultiViewUNetModel(nn.Module):
 
         time_embed_dim = model_channels * 4
         self.time_embed = nn.Sequential(
-            linear(model_channels, time_embed_dim),
+            nn.Linear(model_channels, time_embed_dim),
             nn.SiLU(),
-            linear(time_embed_dim, time_embed_dim),
+            nn.Linear(time_embed_dim, time_embed_dim),
         )
 
         if camera_dim is not None:
             time_embed_dim = model_channels * 4
             self.camera_embed = nn.Sequential(
-                linear(camera_dim, time_embed_dim),
+                nn.Linear(camera_dim, time_embed_dim),
                 nn.SiLU(),
-                linear(time_embed_dim, time_embed_dim),
+                nn.Linear(time_embed_dim, time_embed_dim),
             )
 
         if self.num_classes is not None:
@@ -567,9 +555,9 @@ class MultiViewUNetModel(nn.Module):
             elif self.num_classes == "sequential":
                 assert adm_in_channels is not None
                 self.label_emb = nn.Sequential(nn.Sequential(
-                    linear(adm_in_channels, time_embed_dim),
+                    nn.Linear(adm_in_channels, time_embed_dim),
                     nn.SiLU(),
-                    linear(time_embed_dim, time_embed_dim),
+                    nn.Linear(time_embed_dim, time_embed_dim),
                 ))
             else:
                 raise ValueError()
@@ -722,13 +710,13 @@ class MultiViewUNetModel(nn.Module):
         self._feature_size += ch
 
         self.out = nn.Sequential(
-            normalization(ch),
+            nn.GroupNorm(32, ch),
             nn.SiLU(),
             zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
         )
         if self.predict_codebook_ids:
             self.id_predictor = nn.Sequential(
-                normalization(ch),
+                nn.GroupNorm(32, ch),
                 conv_nd(dims, model_channels, n_embed, 1),
                 #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
             )
```
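The models.py changes are mechanical: the `normalization(...)` and `linear(...)` helpers deleted from `mvdream/util.py` below were thin wrappers around `nn.GroupNorm(32, ...)` and `nn.Linear`, and the removed `GroupNorm32.forward` simply called the parent's forward, so inlining the standard modules keeps parameter names, shapes, and outputs the same. A small sketch of that equivalence (the channel count is illustrative, not from the commit):

```python
import torch
import torch.nn as nn

# What util.normalization(channels) used to build vs. what the commit inlines.
channels = 64  # illustrative value, not from the commit

old_norm = nn.GroupNorm(32, channels)   # GroupNorm32(32, channels) behaved identically here
new_norm = nn.GroupNorm(32, channels)   # inlined replacement in models.py
new_norm.load_state_dict(old_norm.state_dict())  # same parameter names and shapes

x = torch.randn(2, channels, 8, 8)
assert torch.allclose(old_norm(x), new_norm(x))  # identical outputs

# Likewise util.linear(*args) was just nn.Linear(*args), so time_embed / emb_layers
# built with nn.Linear load the same checkpoints as before.
```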
mvdream/util.py CHANGED

```diff
@@ -10,136 +10,7 @@
 import math
 import torch
 import torch.nn as nn
-import numpy as np
-import importlib
 from einops import repeat
-from typing import Any
-
-
-def instantiate_from_config(config):
-    if not "target" in config:
-        if config == '__is_first_stage__':
-            return None
-        elif config == "__is_unconditional__":
-            return None
-        raise KeyError("Expected key `target` to instantiate.")
-    return get_obj_from_str(config["target"])(**config.get("params", dict()))
-
-
-def get_obj_from_str(string, reload=False):
-    module, cls = string.rsplit(".", 1)
-    if reload:
-        module_imp = importlib.import_module(module)
-        importlib.reload(module_imp)
-    return getattr(importlib.import_module(module, package=None), cls)
-
-
-def make_beta_schedule(schedule,
-                       n_timestep,
-                       linear_start=1e-4,
-                       linear_end=2e-2,
-                       cosine_s=8e-3):
-    if schedule == "linear":
-        betas = (torch.linspace(linear_start**0.5,
-                                linear_end**0.5,
-                                n_timestep,
-                                dtype=torch.float64)**2)
-
-    elif schedule == "cosine":
-        timesteps = (
-            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep +
-            cosine_s)
-        alphas = timesteps / (1 + cosine_s) * np.pi / 2
-        alphas = torch.cos(alphas).pow(2)
-        alphas = alphas / alphas[0]
-        betas = 1 - alphas[1:] / alphas[:-1]
-        betas = np.clip(betas, a_min=0, a_max=0.999)
-
-    elif schedule == "sqrt_linear":
-        betas = torch.linspace(linear_start,
-                               linear_end,
-                               n_timestep,
-                               dtype=torch.float64)
-    elif schedule == "sqrt":
-        betas = torch.linspace(linear_start,
-                               linear_end,
-                               n_timestep,
-                               dtype=torch.float64)**0.5
-    else:
-        raise ValueError(f"schedule '{schedule}' unknown.")
-    return betas.numpy()  # type: ignore
-
-
-def make_ddim_timesteps(ddim_discr_method,
-                        num_ddim_timesteps,
-                        num_ddpm_timesteps,
-                        verbose=True):
-    if ddim_discr_method == 'uniform':
-        c = num_ddpm_timesteps // num_ddim_timesteps
-        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
-    elif ddim_discr_method == 'quad':
-        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8),
-                                       num_ddim_timesteps))**2).astype(int)
-    else:
-        raise NotImplementedError(
-            f'There is no ddim discretization method called "{ddim_discr_method}"'
-        )
-
-    # assert ddim_timesteps.shape[0] == num_ddim_timesteps
-    # add one to get the final alpha values right (the ones from first scale to data during sampling)
-    steps_out = ddim_timesteps + 1
-    if verbose:
-        print(f'Selected timesteps for ddim sampler: {steps_out}')
-    return steps_out
-
-
-def make_ddim_sampling_parameters(alphacums,
-                                  ddim_timesteps,
-                                  eta,
-                                  verbose=True):
-    # select alphas for computing the variance schedule
-    alphas = alphacums[ddim_timesteps]
-    alphas_prev = np.asarray([alphacums[0]] +
-                             alphacums[ddim_timesteps[:-1]].tolist())
-
-    # according the the formula provided in https://arxiv.org/abs/2010.02502
-    sigmas = eta * np.sqrt(
-        (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
-    if verbose:
-        print(
-            f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}'
-        )
-        print(
-            f'For the chosen value of eta, which is {eta}, '
-            f'this results in the following sigma_t schedule for ddim sampler {sigmas}'
-        )
-    return sigmas, alphas, alphas_prev
-
-
-def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
-    """
-    Create a beta schedule that discretizes the given alpha_t_bar function,
-    which defines the cumulative product of (1-beta) over time from t = [0,1].
-    :param num_diffusion_timesteps: the number of betas to produce.
-    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
-                      produces the cumulative product of (1-beta) up to that
-                      part of the diffusion process.
-    :param max_beta: the maximum beta to use; use values lower than 1 to
-                     prevent singularities.
-    """
-    betas = []
-    for i in range(num_diffusion_timesteps):
-        t1 = i / num_diffusion_timesteps
-        t2 = (i + 1) / num_diffusion_timesteps
-        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
-    return np.array(betas)
-
-
-def extract_into_tensor(a, t, x_shape):
-    b, *_ = t.shape
-    out = a.gather(-1, t)
-    return out.reshape(b, *((1, ) * (len(x_shape) - 1)))
-
 
 def checkpoint(func, inputs, params, flag):
     """
@@ -227,45 +98,6 @@ def zero_module(module):
         p.detach().zero_()
     return module
 
-
-def scale_module(module, scale):
-    """
-    Scale the parameters of a module and return it.
-    """
-    for p in module.parameters():
-        p.detach().mul_(scale)
-    return module
-
-
-def mean_flat(tensor):
-    """
-    Take the mean over all non-batch dimensions.
-    """
-    return tensor.mean(dim=list(range(1, len(tensor.shape))))
-
-
-def normalization(channels):
-    """
-    Make a standard normalization layer.
-    :param channels: number of input channels.
-    :return: an nn.Module for normalization.
-    """
-    return GroupNorm32(32, channels)
-
-
-# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
-class SiLU(nn.Module):
-
-    def forward(self, x):
-        return x * torch.sigmoid(x)
-
-
-class GroupNorm32(nn.GroupNorm):
-
-    def forward(self, x):
-        return super().forward(x)
-
-
 def conv_nd(dims, *args, **kwargs):
     """
     Create a 1D, 2D, or 3D convolution module.
@@ -279,13 +111,6 @@ def conv_nd(dims, *args, **kwargs):
     raise ValueError(f"unsupported dimensions: {dims}")
 
 
-def linear(*args, **kwargs):
-    """
-    Create a linear module.
-    """
-    return nn.Linear(*args, **kwargs)
-
-
 def avg_pool_nd(dims, *args, **kwargs):
     """
     Create a 1D, 2D, or 3D average pooling module.
@@ -297,24 +122,3 @@ def avg_pool_nd(dims, *args, **kwargs):
     elif dims == 3:
         return nn.AvgPool3d(*args, **kwargs)
     raise ValueError(f"unsupported dimensions: {dims}")
-
-
-class HybridConditioner(nn.Module):
-
-    def __init__(self, c_concat_config, c_crossattn_config):
-        super().__init__()
-        self.concat_conditioner: Any = instantiate_from_config(c_concat_config)
-        self.crossattn_conditioner: Any = instantiate_from_config(
-            c_crossattn_config)
-
-    def forward(self, c_concat, c_crossattn):
-        c_concat = self.concat_conditioner(c_concat)
-        c_crossattn = self.crossattn_conditioner(c_crossattn)
-        return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
-
-
-def noise_like(shape, device, repeat=False):
-    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
-        shape[0], *((1, ) * (len(shape) - 1)))
-    noise = lambda: torch.randn(shape, device=device)
-    return repeat_noise() if repeat else noise()
```
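After this cleanup, `mvdream/util.py` keeps only the helpers that `mvdream/models.py` still imports (`checkpoint`, `conv_nd`, `avg_pool_nd`, `zero_module`, `timestep_embedding`); the deleted beta/DDIM schedule builders, `instantiate_from_config`, `HybridConditioner`, `noise_like`, and the `normalization`/`linear` wrappers are no longer referenced. A quick sanity check one could run after the change (a sketch, not part of the commit):

```python
# Sketch: verify the trimmed util module still exposes everything models.py imports,
# and that the removed helpers are really gone.
from mvdream import util

kept = ["checkpoint", "conv_nd", "avg_pool_nd", "zero_module", "timestep_embedding"]
removed = ["normalization", "linear", "make_beta_schedule", "make_ddim_timesteps",
           "instantiate_from_config", "noise_like"]

missing = [name for name in kept if not hasattr(util, name)]
leftover = [name for name in removed if hasattr(util, name)]
assert not missing, f"helpers models.py needs are missing: {missing}"
assert not leftover, f"helpers expected to be deleted are still present: {leftover}"
print("util.py surface matches the cleanup")
```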