ashawkey committed
Commit 8241d5f (1 parent: 5d590ff)

imagedream framework

imagedream/attention.py ADDED
@@ -0,0 +1,396 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp.autocast_mode import autocast

from inspect import isfunction
from einops import rearrange, repeat
from typing import Optional, Any
from .util import checkpoint, zero_module

try:
    import xformers  # type: ignore
    import xformers.ops  # type: ignore
    XFORMERS_IS_AVAILABLE = True
except ImportError:
    print("[WARN] xformers is unavailable!")
    XFORMERS_IS_AVAILABLE = False

# CrossAttn precision handling
import os

_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")


def default(val, d):
    if val is not None:
        return val
    return d() if isfunction(d) else d


class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = (
            nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
            if not glu
            else GEGLU(dim, inner_dim)
        )

        self.net = nn.Sequential(
            project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)


class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))

        # force cast to fp32 to avoid overflowing
        if _ATTN_PRECISION == "fp32":
            with autocast(enabled=False, device_type="cuda"):
                q, k = q.float(), k.float()
                sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
        else:
            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale

        del q, k

        if mask is not None:
            mask = rearrange(mask, "b ... -> b (...)")
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, "b j -> (b h) () j", h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        sim = sim.softmax(dim=-1)

        out = torch.einsum("b i j, b j d -> b i d", sim, v)
        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
        return self.to_out(out)


class MemoryEfficientCrossAttention(nn.Module):
    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        # print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using {heads} heads.")
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )
        self.attention_op: Optional[Any] = None

    def forward(self, x, context=None, mask=None):
        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        b, _, _ = q.shape
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, t.shape[1], self.heads, self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * self.heads, t.shape[1], self.dim_head)
            .contiguous(),
            (q, k, v),
        )

        # actually compute the attention, what we cannot get enough of
        out = xformers.ops.memory_efficient_attention(
            q, k, v, attn_bias=None, op=self.attention_op
        )

        if mask is not None:
            raise NotImplementedError
        out = (
            out.unsqueeze(0)
            .reshape(b, self.heads, out.shape[1], self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, out.shape[1], self.heads * self.dim_head)
        )
        return self.to_out(out)


class BasicTransformerBlock(nn.Module):
    ATTENTION_MODES = {
        "softmax": CrossAttention,  # vanilla attention
        "softmax-xformers": MemoryEfficientCrossAttention,
    }

    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        context_dim=None,
        gated_ff=True,
        checkpoint=True,
        disable_self_attn=False,
    ):
        super().__init__()
        attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILABLE else "softmax"
        assert attn_mode in self.ATTENTION_MODES
        attn_cls = self.ATTENTION_MODES[attn_mode]
        self.disable_self_attn = disable_self_attn
        self.attn1 = attn_cls(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim if self.disable_self_attn else None,
        )  # is a self-attention if not self.disable_self_attn
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = attn_cls(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
        )  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        return checkpoint(
            self._forward, (x, context), self.parameters(), self.checkpoint
        )

    def _forward(self, x, context=None):
        x = (
            self.attn1(
                self.norm1(x), context=context if self.disable_self_attn else None
            )
            + x
        )
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x


class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding) and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape back to an image.
    NEW: use_linear for more efficiency instead of the 1x1 convs.
    """

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        context_dim=None,
        disable_self_attn=False,
        use_linear=False,
        use_checkpoint=True,
    ):
        super().__init__()
        assert context_dim is not None
        if not isinstance(context_dim, list):
            context_dim = [context_dim]
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )
        if not use_linear:
            self.proj_in = nn.Conv2d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0
            )
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    n_heads,
                    d_head,
                    dropout=dropout,
                    context_dim=context_dim[d],
                    disable_self_attn=disable_self_attn,
                    checkpoint=use_checkpoint,
                )
                for d in range(depth)
            ]
        )
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
            )
        else:
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context]
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c").contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context[i])
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in


class BasicTransformerBlock3D(BasicTransformerBlock):
    def forward(self, x, context=None, num_frames=1):
        return checkpoint(
            self._forward, (x, context, num_frames), self.parameters(), self.checkpoint
        )

    def _forward(self, x, context=None, num_frames=1):
        # fold the frame (view) axis into the token axis so attn1 attends across views
        x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
        x = (
            self.attn1(
                self.norm1(x), context=context if self.disable_self_attn else None
            )
            + x
        )
        x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x


class SpatialTransformer3D(nn.Module):
    """3D self-attention"""

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        context_dim=None,
        disable_self_attn=False,
        use_linear=True,
        use_checkpoint=True,
    ):
        super().__init__()
        assert context_dim is not None
        if not isinstance(context_dim, list):
            context_dim = [context_dim]
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )
        if not use_linear:
            self.proj_in = nn.Conv2d(
                in_channels, inner_dim, kernel_size=1, stride=1, padding=0
            )
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock3D(
                    inner_dim,
                    n_heads,
                    d_head,
                    dropout=dropout,
                    context_dim=context_dim[d],
                    disable_self_attn=disable_self_attn,
                    checkpoint=use_checkpoint,
                )
                for d in range(depth)
            ]
        )
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
            )
        else:
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None, num_frames=1):
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context]
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c").contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context[i], num_frames=num_frames)
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in
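The 3D block differs from the 2D one only in the rearrange around attn1: the tokens of all F views of an object are folded into one long sequence, so self-attention runs across views, while cross-attention to the text context still runs per view. A minimal sketch of that bookkeeping (illustrative shapes only, independent of the classes above):

import torch
from einops import rearrange

B, F_views, L, C = 2, 4, 16, 8          # batch, views per object, tokens per view, channels
x = torch.randn(B * F_views, L, C)      # the UNet stacks views along the batch axis

tokens = rearrange(x, "(b f) l c -> b (f l) c", f=F_views)    # one sequence per object
assert tokens.shape == (B, F_views * L, C)                     # self-attention now spans all views
x_back = rearrange(tokens, "b (f l) c -> (b f) l c", f=F_views)
assert torch.equal(x_back, x)                                  # cross-attention then runs per view again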
imagedream/models.py ADDED
@@ -0,0 +1,608 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin
from diffusers.models.modeling_utils import ModelMixin
from typing import Any, List, Optional
from torch import Tensor

from .util import (
    checkpoint,
    conv_nd,
    avg_pool_nd,
    zero_module,
    timestep_embedding,
)
from .attention import SpatialTransformer, SpatialTransformer3D


class CondSequential(nn.Sequential):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb, context=None, num_frames=1):
        for layer in self:
            if isinstance(layer, ResBlock):
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer3D):
                x = layer(x, context, num_frames=num_frames)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, context)
            else:
                x = layer(x)
        return x


class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        upsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(
                dims, self.channels, self.out_channels, 3, padding=padding
            )

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            x = F.interpolate(
                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
            )
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        if self.use_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        downsampling occurs in the inner-two dimensions.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        stride = 2 if dims != 3 else (1, 2, 2)
        if use_conv:
            self.op = conv_nd(
                dims,
                self.channels,
                self.out_channels,
                3,
                stride=stride,
                padding=padding,
            )
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)


class ResBlock(nn.Module):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            nn.GroupNorm(32, channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            nn.GroupNorm(32, self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )

    def _forward(self, x, emb):
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h


class MultiViewUNetModel(ModelMixin, ConfigMixin):
    """
    The full multi-view UNet model with attention, timestep embedding and camera embedding.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_head_channels: if specified, ignore num_heads and instead use
        a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
        of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
        increased efficiency.
    :param camera_dim: dimensionality of camera input.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        transformer_depth=1,  # custom transformer support
        context_dim=None,  # custom transformer support
        n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        adm_in_channels=None,
        camera_dim=None,
        **kwargs,
    ):
        super().__init__()
        assert context_dim is not None

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert (
                num_head_channels != -1
            ), "Either num_heads or num_head_channels has to be set"

        if num_head_channels == -1:
            assert (
                num_heads != -1
            ), "Either num_heads or num_head_channels has to be set"

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError(
                    "provide num_res_blocks either as an int (globally constant) or "
                    "as a list/tuple (per-level) with the same length as channel_mult"
                )
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(
                map(
                    lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
                    range(len(num_attention_blocks)),
                )
            )
            print(
                f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                f"attention will still not be set."
            )

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            nn.Linear(model_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

        if camera_dim is not None:
            time_embed_dim = model_channels * 4
            self.camera_embed = nn.Sequential(
                nn.Linear(camera_dim, time_embed_dim),
                nn.SiLU(),
                nn.Linear(time_embed_dim, time_embed_dim),
            )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                # print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        nn.Linear(adm_in_channels, time_embed_dim),
                        nn.SiLU(),
                        nn.Linear(time_embed_dim, time_embed_dim),
                    )
                )
            else:
                raise ValueError()

        self.input_blocks = nn.ModuleList(
            [
                CondSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers: List[Any] = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels

                    if disable_self_attentions is not None:
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if num_attention_blocks is None or nr < num_attention_blocks[level]:
                        layers.append(
                            SpatialTransformer3D(
                                ch,
                                num_heads,
                                dim_head,
                                depth=transformer_depth,
                                context_dim=context_dim,
                                disable_self_attn=disabled_sa,
                                use_checkpoint=use_checkpoint,
                            )
                        )
                self.input_blocks.append(CondSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    CondSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels

        self.middle_block = CondSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            SpatialTransformer3D(
                ch,
                num_heads,
                dim_head,
                depth=transformer_depth,
                context_dim=context_dim,
                disable_self_attn=disable_middle_self_attn,
                use_checkpoint=use_checkpoint,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(self.num_res_blocks[level] + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels

                    if disable_self_attentions is not None:
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if num_attention_blocks is None or i < num_attention_blocks[level]:
                        layers.append(
                            SpatialTransformer3D(
                                ch,
                                num_heads,
                                dim_head,
                                depth=transformer_depth,
                                context_dim=context_dim,
                                disable_self_attn=disabled_sa,
                                use_checkpoint=use_checkpoint,
                            )
                        )
                if level and i == self.num_res_blocks[level]:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(CondSequential(*layers))
                self._feature_size += ch

        self.out = nn.Sequential(
            nn.GroupNorm(32, ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )
        if self.predict_codebook_ids:
            self.id_predictor = nn.Sequential(
                nn.GroupNorm(32, ch),
                conv_nd(dims, model_channels, n_embed, 1),
                # nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
            )

    def forward(
        self,
        x,
        timesteps=None,
        context=None,
        y: Optional[Tensor] = None,
        camera=None,
        num_frames=1,
        **kwargs,
    ):
        """
        Apply the model to an input batch.
        :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn.
        :param y: an [N] Tensor of labels, if class-conditional.
        :param num_frames: an integer indicating the number of frames for tensor reshaping.
        :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
        """
        assert (
            x.shape[0] % num_frames == 0
        ), "[UNet] input batch size must be divisible by num_frames!"
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs = []
        t_emb = timestep_embedding(
            timesteps, self.model_channels, repeat_only=False
        ).to(x.dtype)

        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y is not None
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        # Add camera embeddings
        if camera is not None:
            assert camera.shape[0] == emb.shape[0]
            emb = emb + self.camera_embed(camera)

        h = x
        for module in self.input_blocks:
            h = module(h, emb, context, num_frames=num_frames)
            hs.append(h)
        h = self.middle_block(h, emb, context, num_frames=num_frames)
        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context, num_frames=num_frames)
        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        else:
            return self.out(h)
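For a quick shape-level smoke test of MultiViewUNetModel, a minimal sketch follows; the channel widths, context_dim and camera_dim here are illustrative assumptions, not the released ImageDream configuration.

# Hypothetical, down-scaled configuration for a quick shape check only.
import torch
from imagedream.models import MultiViewUNetModel

unet = MultiViewUNetModel(
    image_size=32,
    in_channels=4,
    model_channels=32,
    out_channels=4,
    num_res_blocks=1,
    attention_resolutions=[2],   # attend at 2x downsampling
    channel_mult=(1, 2),
    num_head_channels=16,
    transformer_depth=1,
    context_dim=1024,            # assumed text-embedding width
    camera_dim=16,               # flattened 4x4 camera-to-world matrix
)

num_frames = 4                                   # four views of one object
x = torch.randn(num_frames, 4, 32, 32)           # latents, views stacked on the batch axis
t = torch.full((num_frames,), 10)                # one timestep per view
context = torch.randn(num_frames, 77, 1024)      # per-view text conditioning
camera = torch.randn(num_frames, 16)             # per-view camera input

with torch.no_grad():
    eps = unet(x, timesteps=t, context=context, camera=camera, num_frames=num_frames)
print(eps.shape)  # torch.Size([4, 4, 32, 32])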
imagedream/pipeline_imagedream.py ADDED
@@ -0,0 +1,571 @@
import torch
import inspect
import numpy as np
from typing import Callable, List, Optional, Union
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPImageProcessor
from diffusers import AutoencoderKL, DiffusionPipeline
from diffusers.utils import (
    deprecate,
    is_accelerate_available,
    is_accelerate_version,
    logging,
)
from diffusers.configuration_utils import FrozenDict
from diffusers.schedulers import DDIMScheduler
from diffusers.utils.torch_utils import randn_tensor

from .models import MultiViewUNetModel

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def create_camera_to_world_matrix(elevation, azimuth):
    elevation = np.radians(elevation)
    azimuth = np.radians(azimuth)
    # Convert elevation and azimuth angles to Cartesian coordinates on a unit sphere
    x = np.cos(elevation) * np.sin(azimuth)
    y = np.sin(elevation)
    z = np.cos(elevation) * np.cos(azimuth)

    # Calculate camera position, target, and up vectors
    camera_pos = np.array([x, y, z])
    target = np.array([0, 0, 0])
    up = np.array([0, 1, 0])

    # Construct view matrix
    forward = target - camera_pos
    forward /= np.linalg.norm(forward)
    right = np.cross(forward, up)
    right /= np.linalg.norm(right)
    new_up = np.cross(right, forward)
    new_up /= np.linalg.norm(new_up)
    cam2world = np.eye(4)
    cam2world[:3, :3] = np.array([right, new_up, -forward]).T
    cam2world[:3, 3] = camera_pos
    return cam2world


def convert_opengl_to_blender(camera_matrix):
    if isinstance(camera_matrix, np.ndarray):
        # Construct transformation matrix to convert from OpenGL space to Blender space
        flip_yz = np.array([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
        camera_matrix_blender = np.dot(flip_yz, camera_matrix)
    else:
        # Construct transformation matrix to convert from OpenGL space to Blender space
        flip_yz = torch.tensor(
            [[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]
        )
        if camera_matrix.ndim == 3:
            flip_yz = flip_yz.unsqueeze(0)
        camera_matrix_blender = torch.matmul(flip_yz.to(camera_matrix), camera_matrix)
    return camera_matrix_blender


def get_camera(
    num_frames, elevation=15, azimuth_start=0, azimuth_span=360, blender_coord=True
):
    angle_gap = azimuth_span / num_frames
    cameras = []
    for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):
        camera_matrix = create_camera_to_world_matrix(elevation, azimuth)
        if blender_coord:
            camera_matrix = convert_opengl_to_blender(camera_matrix)
        cameras.append(camera_matrix.flatten())
    return torch.tensor(np.stack(cameras, 0)).float()


class ImageDreamPipeline(DiffusionPipeline):
    def __init__(
        self,
        vae: AutoencoderKL,
        unet: MultiViewUNetModel,
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModel,
        scheduler: DDIMScheduler,
        feature_extractor: CLIPImageProcessor,
        image_encoder: CLIPVisionModel,
        requires_safety_checker: bool = False,
    ):
        super().__init__()

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:  # type: ignore
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "  # type: ignore
                "to update the config accordingly, as leaving `steps_offset` might lead to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate(
                "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
            )
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:  # type: ignore
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate(
                "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
            )
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            unet=unet,
            scheduler=scheduler,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding.

        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
        steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_slicing()

    def enable_vae_tiling(self):
        r"""
        Enable tiled VAE decoding.

        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
        several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
        """
        self.vae.enable_tiling()

    def disable_vae_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
        computing decoding in one step.
        """
        self.vae.disable_tiling()

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
        Note that offloading happens on a submodule basis. Memory savings are higher than with
        `enable_model_cpu_offload`, but performance is lower.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
            from accelerate import cpu_offload
        else:
            raise ImportError(
                "`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher"
            )

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
            cpu_offload(cpu_offloaded_model, device)

    def enable_model_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
        """
        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
            from accelerate import cpu_offload_with_hook
        else:
            raise ImportError(
                "`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher."
            )

        device = torch.device(f"cuda:{gpu_id}")

        if self.device.type != "cpu":
            self.to("cpu", silence_dtype_warnings=True)
            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)

        hook = None
        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
            _, hook = cpu_offload_with_hook(
                cpu_offloaded_model, device, prev_module_hook=hook
            )

        # We'll offload the last model manually.
        self.final_offload_hook = hook

    @property
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance: bool,
        negative_prompt=None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`):
                prompt to be encoded
            device (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. Ignored when not using guidance
                (i.e., ignored if `guidance_scale` is less than `1`).
        """
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(
                f"`prompt` should be either a string or a list of strings, but got {type(prompt)}."
            )

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(
            prompt, padding="longest", return_tensors="pt"
        ).input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        if (
            hasattr(self.text_encoder.config, "use_attention_mask")
            and self.text_encoder.config.use_attention_mask
        ):
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        prompt_embeds = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        prompt_embeds = prompt_embeds[0]

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(
            bs_embed * num_images_per_prompt, seq_len, -1
        )

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if (
                hasattr(self.text_encoder.config, "use_attention_mask")
                and self.text_encoder.config.use_attention_mask
            ):
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(
                dtype=self.text_encoder.dtype, device=device
            )

            negative_prompt_embeds = negative_prompt_embeds.repeat(
                1, num_images_per_prompt, 1
            )
            negative_prompt_embeds = negative_prompt_embeds.view(
                batch_size * num_images_per_prompt, seq_len, -1
            )

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        return prompt_embeds

    def decode_latents(self, latents):
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in the DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        shape = (
            batch_size,
            num_channels_latents,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(
                shape, generator=generator, device=device, dtype=dtype
            )
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    @torch.no_grad()
    def __call__(
        self,
        image,  # input image (TODO: pil?)
        prompt: str = "a car",
        height: int = 256,
        width: int = 256,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.0,
        negative_prompt: str = "bad quality",
        num_images_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: Optional[str] = "image",
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        batch_size: int = 4,
        device=torch.device("cuda:0"),
    ):
        self.unet = self.unet.to(device=device)
        self.vae = self.vae.to(device=device)

        self.text_encoder = self.text_encoder.to(device=device)

        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        _prompt_embeds: torch.Tensor = self._encode_prompt(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
        )  # type: ignore
        prompt_embeds_neg, prompt_embeds_pos = _prompt_embeds.chunk(2)

        # Prepare latent variables
        latents: torch.Tensor = self.prepare_latents(
            batch_size * num_images_per_prompt,
            4,
            height,
            width,
            prompt_embeds_pos.dtype,
            device,
            generator,
            None,
        )

        camera = get_camera(batch_size).to(dtype=latents.dtype, device=device)

        # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                multiplier = 2 if do_classifier_free_guidance else 1
                latent_model_input = torch.cat([latents] * multiplier)
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t
                )

                # predict the noise residual
                noise_pred = self.unet.forward(
                    x=latent_model_input,
                    timesteps=torch.tensor(
                        [t] * 4 * multiplier,
                        dtype=latent_model_input.dtype,
                        device=device,
                    ),
                    context=torch.cat(
                        [prompt_embeds_neg] * 4 + [prompt_embeds_pos] * 4
                    ),
                    num_frames=4,
                    camera=torch.cat([camera] * multiplier),
                )

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )

                # compute the previous noisy sample x_t -> x_t-1
                # latents = self.scheduler.step(noise_pred.to(dtype=torch.float32), t, latents.to(dtype=torch.float32)).prev_sample.to(prompt_embeds.dtype)
                latents: torch.Tensor = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs, return_dict=False
                )[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)  # type: ignore

        # Post-processing
        if output_type == "latent":
            image = latents
        elif output_type == "pil":
            image = self.decode_latents(latents)
            image = self.numpy_to_pil(image)
        else:
            image = self.decode_latents(latents)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        return image
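A hypothetical end-to-end usage sketch of the pipeline follows; the checkpoint path is a placeholder for wherever converted weights live, and note that the `image` argument is accepted but not yet consumed in this commit (see the TODO in `__call__`): the camera poses for the four views are built internally via `get_camera(batch_size)`.

# Hypothetical usage; "path/to/imagedream-diffusers" is a placeholder.
import torch
from imagedream.pipeline_imagedream import ImageDreamPipeline

pipe = ImageDreamPipeline.from_pretrained(
    "path/to/imagedream-diffusers", torch_dtype=torch.float16
)

images = pipe(
    image=None,                 # image conditioning is still a TODO in this commit
    prompt="a DSLR photo of a corgi",
    negative_prompt="bad quality",
    num_inference_steps=30,
    guidance_scale=7.0,
    output_type="pil",
    device=torch.device("cuda:0"),
)
# `images` is a list of 4 PIL images, one per view (batch_size defaults to 4).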
imagedream/util.py ADDED
@@ -0,0 +1,116 @@
import math
import torch
import torch.nn as nn
from einops import repeat


def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
        explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads


def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
        These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(device=timesteps.device)
        args = timesteps[:, None] * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
    else:
        embedding = repeat(timesteps, "b -> b d", d=dim)
    return embedding


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
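As a quick illustration of the helpers above (a sketch, nothing beyond what the functions themselves define):

import torch
from imagedream.util import timestep_embedding, checkpoint, zero_module

# Sinusoidal embeddings: one 320-dim vector per timestep.
t = torch.tensor([1, 250, 999])
emb = timestep_embedding(t, dim=320)
print(emb.shape)  # torch.Size([3, 320])

# Gradient checkpointing: same result, activations recomputed in backward when flag=True.
lin = torch.nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)
y = checkpoint(lambda inp: lin(inp).relu(), (x,), lin.parameters(), True)
y.sum().backward()
print(x.grad.shape)  # torch.Size([2, 8])

# zero_module: newly added layers start out contributing nothing to the residual path.
print(zero_module(torch.nn.Linear(4, 4)).weight.abs().sum())  # tensor(0., grad_fn=...)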