bhuvanmdev commited on
Commit
e67454a
·
verified ·
1 Parent(s): ddf3dab

Update unet/mv_unet.py

Browse files
Files changed (1) hide show
  1. unet/mv_unet.py +1004 -1004
unet/mv_unet.py CHANGED
@@ -1,1005 +1,1005 @@
1
- import math
2
- import numpy as np
3
- from inspect import isfunction
4
- from typing import Optional, Any, List
5
-
6
- import torch
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
- from einops import rearrange, repeat
10
-
11
- from diffusers.configuration_utils import ConfigMixin
12
- from diffusers.models.modeling_utils import ModelMixin
13
-
14
- # require xformers!
15
- import xformers
16
- import xformers.ops
17
-
18
- from kiui.cam import orbit_camera
19
-
20
- def get_camera(
21
- num_frames, elevation=15, azimuth_start=0, azimuth_span=360, blender_coord=True, extra_view=False,
22
- ):
23
- angle_gap = azimuth_span / num_frames
24
- cameras = []
25
- for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):
26
-
27
- pose = orbit_camera(-elevation, azimuth, radius=1) # kiui's elevation is negated, [4, 4]
28
-
29
- # opengl to blender
30
- if blender_coord:
31
- pose[2] *= -1
32
- pose[[1, 2]] = pose[[2, 1]]
33
-
34
- cameras.append(pose.flatten())
35
-
36
- if extra_view:
37
- cameras.append(np.zeros_like(cameras[0]))
38
-
39
- return torch.from_numpy(np.stack(cameras, axis=0)).float() # [num_frames, 16]
40
-
41
-
42
- def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
43
- """
44
- Create sinusoidal timestep embeddings.
45
- :param timesteps: a 1-D Tensor of N indices, one per batch element.
46
- These may be fractional.
47
- :param dim: the dimension of the output.
48
- :param max_period: controls the minimum frequency of the embeddings.
49
- :return: an [N x dim] Tensor of positional embeddings.
50
- """
51
- if not repeat_only:
52
- half = dim // 2
53
- freqs = torch.exp(
54
- -math.log(max_period)
55
- * torch.arange(start=0, end=half, dtype=torch.float32)
56
- / half
57
- ).to(device=timesteps.device)
58
- args = timesteps[:, None] * freqs[None]
59
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
60
- if dim % 2:
61
- embedding = torch.cat(
62
- [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
63
- )
64
- else:
65
- embedding = repeat(timesteps, "b -> b d", d=dim)
66
- # import pdb; pdb.set_trace()
67
- return embedding
68
-
69
-
70
- def zero_module(module):
71
- """
72
- Zero out the parameters of a module and return it.
73
- """
74
- for p in module.parameters():
75
- p.detach().zero_()
76
- return module
77
-
78
-
79
- def conv_nd(dims, *args, **kwargs):
80
- """
81
- Create a 1D, 2D, or 3D convolution module.
82
- """
83
- if dims == 1:
84
- return nn.Conv1d(*args, **kwargs)
85
- elif dims == 2:
86
- return nn.Conv2d(*args, **kwargs)
87
- elif dims == 3:
88
- return nn.Conv3d(*args, **kwargs)
89
- raise ValueError(f"unsupported dimensions: {dims}")
90
-
91
-
92
- def avg_pool_nd(dims, *args, **kwargs):
93
- """
94
- Create a 1D, 2D, or 3D average pooling module.
95
- """
96
- if dims == 1:
97
- return nn.AvgPool1d(*args, **kwargs)
98
- elif dims == 2:
99
- return nn.AvgPool2d(*args, **kwargs)
100
- elif dims == 3:
101
- return nn.AvgPool3d(*args, **kwargs)
102
- raise ValueError(f"unsupported dimensions: {dims}")
103
-
104
-
105
- def default(val, d):
106
- if val is not None:
107
- return val
108
- return d() if isfunction(d) else d
109
-
110
-
111
- class GEGLU(nn.Module):
112
- def __init__(self, dim_in, dim_out):
113
- super().__init__()
114
- self.proj = nn.Linear(dim_in, dim_out * 2)
115
-
116
- def forward(self, x):
117
- x, gate = self.proj(x).chunk(2, dim=-1)
118
- return x * F.gelu(gate)
119
-
120
-
121
- class FeedForward(nn.Module):
122
- def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
123
- super().__init__()
124
- inner_dim = int(dim * mult)
125
- dim_out = default(dim_out, dim)
126
- project_in = (
127
- nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
128
- if not glu
129
- else GEGLU(dim, inner_dim)
130
- )
131
-
132
- self.net = nn.Sequential(
133
- project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
134
- )
135
-
136
- def forward(self, x):
137
- return self.net(x)
138
-
139
-
140
- class MemoryEfficientCrossAttention(nn.Module):
141
- # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
142
- def __init__(
143
- self,
144
- query_dim,
145
- context_dim=None,
146
- heads=8,
147
- dim_head=64,
148
- dropout=0.0,
149
- ip_dim=0,
150
- ip_weight=1,
151
- ):
152
- super().__init__()
153
-
154
- inner_dim = dim_head * heads
155
- context_dim = default(context_dim, query_dim)
156
-
157
- self.heads = heads
158
- self.dim_head = dim_head
159
-
160
- self.ip_dim = ip_dim
161
- self.ip_weight = ip_weight
162
-
163
- if self.ip_dim > 0:
164
- self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
165
- self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
166
-
167
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
168
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
169
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
170
-
171
- self.to_out = nn.Sequential(
172
- nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
173
- )
174
- self.attention_op: Optional[Any] = None
175
-
176
- def forward(self, x, context=None):
177
- q = self.to_q(x)
178
- context = default(context, x)
179
-
180
- if self.ip_dim > 0:
181
- # context: [B, 77 + 16(ip), 1024]
182
- token_len = context.shape[1]
183
- context_ip = context[:, -self.ip_dim :, :]
184
- k_ip = self.to_k_ip(context_ip)
185
- v_ip = self.to_v_ip(context_ip)
186
- context = context[:, : (token_len - self.ip_dim), :]
187
-
188
- k = self.to_k(context)
189
- v = self.to_v(context)
190
-
191
- b, _, _ = q.shape
192
- q, k, v = map(
193
- lambda t: t.unsqueeze(3)
194
- .reshape(b, t.shape[1], self.heads, self.dim_head)
195
- .permute(0, 2, 1, 3)
196
- .reshape(b * self.heads, t.shape[1], self.dim_head)
197
- .contiguous(),
198
- (q, k, v),
199
- )
200
-
201
- # actually compute the attention, what we cannot get enough of
202
- out = xformers.ops.memory_efficient_attention(
203
- q, k, v, attn_bias=None, op=self.attention_op
204
- )
205
-
206
- if self.ip_dim > 0:
207
- k_ip, v_ip = map(
208
- lambda t: t.unsqueeze(3)
209
- .reshape(b, t.shape[1], self.heads, self.dim_head)
210
- .permute(0, 2, 1, 3)
211
- .reshape(b * self.heads, t.shape[1], self.dim_head)
212
- .contiguous(),
213
- (k_ip, v_ip),
214
- )
215
- # actually compute the attention, what we cannot get enough of
216
- out_ip = xformers.ops.memory_efficient_attention(
217
- q, k_ip, v_ip, attn_bias=None, op=self.attention_op
218
- )
219
- out = out + self.ip_weight * out_ip
220
-
221
- out = (
222
- out.unsqueeze(0)
223
- .reshape(b, self.heads, out.shape[1], self.dim_head)
224
- .permute(0, 2, 1, 3)
225
- .reshape(b, out.shape[1], self.heads * self.dim_head)
226
- )
227
- return self.to_out(out)
228
-
229
-
230
- class BasicTransformerBlock3D(nn.Module):
231
-
232
- def __init__(
233
- self,
234
- dim,
235
- n_heads,
236
- d_head,
237
- context_dim,
238
- dropout=0.0,
239
- gated_ff=True,
240
- ip_dim=0,
241
- ip_weight=1,
242
- ):
243
- super().__init__()
244
-
245
- self.attn1 = MemoryEfficientCrossAttention(
246
- query_dim=dim,
247
- context_dim=None, # self-attention
248
- heads=n_heads,
249
- dim_head=d_head,
250
- dropout=dropout,
251
- )
252
- self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
253
- self.attn2 = MemoryEfficientCrossAttention(
254
- query_dim=dim,
255
- context_dim=context_dim,
256
- heads=n_heads,
257
- dim_head=d_head,
258
- dropout=dropout,
259
- # ip only applies to cross-attention
260
- ip_dim=ip_dim,
261
- ip_weight=ip_weight,
262
- )
263
- self.norm1 = nn.LayerNorm(dim)
264
- self.norm2 = nn.LayerNorm(dim)
265
- self.norm3 = nn.LayerNorm(dim)
266
-
267
- def forward(self, x, context=None, num_frames=1):
268
- x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
269
- x = self.attn1(self.norm1(x), context=None) + x
270
- x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
271
- x = self.attn2(self.norm2(x), context=context) + x
272
- x = self.ff(self.norm3(x)) + x
273
- return x
274
-
275
-
276
- class SpatialTransformer3D(nn.Module):
277
-
278
- def __init__(
279
- self,
280
- in_channels,
281
- n_heads,
282
- d_head,
283
- context_dim, # cross attention input dim
284
- depth=1,
285
- dropout=0.0,
286
- ip_dim=0,
287
- ip_weight=1,
288
- ):
289
- super().__init__()
290
-
291
- if not isinstance(context_dim, list):
292
- context_dim = [context_dim]
293
-
294
- self.in_channels = in_channels
295
-
296
- inner_dim = n_heads * d_head
297
- self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
298
- self.proj_in = nn.Linear(in_channels, inner_dim)
299
-
300
- self.transformer_blocks = nn.ModuleList(
301
- [
302
- BasicTransformerBlock3D(
303
- inner_dim,
304
- n_heads,
305
- d_head,
306
- context_dim=context_dim[d],
307
- dropout=dropout,
308
- ip_dim=ip_dim,
309
- ip_weight=ip_weight,
310
- )
311
- for d in range(depth)
312
- ]
313
- )
314
-
315
- self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
316
-
317
-
318
- def forward(self, x, context=None, num_frames=1):
319
- # note: if no context is given, cross-attention defaults to self-attention
320
- if not isinstance(context, list):
321
- context = [context]
322
- b, c, h, w = x.shape
323
- x_in = x
324
- x = self.norm(x)
325
- x = rearrange(x, "b c h w -> b (h w) c").contiguous()
326
- x = self.proj_in(x)
327
- for i, block in enumerate(self.transformer_blocks):
328
- x = block(x, context=context[i], num_frames=num_frames)
329
- x = self.proj_out(x)
330
- x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
331
-
332
- return x + x_in
333
-
334
-
335
- class PerceiverAttention(nn.Module):
336
- def __init__(self, *, dim, dim_head=64, heads=8):
337
- super().__init__()
338
- self.scale = dim_head ** -0.5
339
- self.dim_head = dim_head
340
- self.heads = heads
341
- inner_dim = dim_head * heads
342
-
343
- self.norm1 = nn.LayerNorm(dim)
344
- self.norm2 = nn.LayerNorm(dim)
345
-
346
- self.to_q = nn.Linear(dim, inner_dim, bias=False)
347
- self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
348
- self.to_out = nn.Linear(inner_dim, dim, bias=False)
349
-
350
- def forward(self, x, latents):
351
- """
352
- Args:
353
- x (torch.Tensor): image features
354
- shape (b, n1, D)
355
- latent (torch.Tensor): latent features
356
- shape (b, n2, D)
357
- """
358
- x = self.norm1(x)
359
- latents = self.norm2(latents)
360
-
361
- b, l, _ = latents.shape
362
-
363
- q = self.to_q(latents)
364
- kv_input = torch.cat((x, latents), dim=-2)
365
- k, v = self.to_kv(kv_input).chunk(2, dim=-1)
366
-
367
- q, k, v = map(
368
- lambda t: t.reshape(b, t.shape[1], self.heads, -1)
369
- .transpose(1, 2)
370
- .reshape(b, self.heads, t.shape[1], -1)
371
- .contiguous(),
372
- (q, k, v),
373
- )
374
-
375
- # attention
376
- scale = 1 / math.sqrt(math.sqrt(self.dim_head))
377
- weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
378
- weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
379
- out = weight @ v
380
-
381
- out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
382
-
383
- return self.to_out(out)
384
-
385
-
386
- class Resampler(nn.Module):
387
- def __init__(
388
- self,
389
- dim=1024,
390
- depth=8,
391
- dim_head=64,
392
- heads=16,
393
- num_queries=8,
394
- embedding_dim=768,
395
- output_dim=1024,
396
- ff_mult=4,
397
- ):
398
- super().__init__()
399
- self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
400
- self.proj_in = nn.Linear(embedding_dim, dim)
401
- self.proj_out = nn.Linear(dim, output_dim)
402
- self.norm_out = nn.LayerNorm(output_dim)
403
-
404
- self.layers = nn.ModuleList([])
405
- for _ in range(depth):
406
- self.layers.append(
407
- nn.ModuleList(
408
- [
409
- PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
410
- nn.Sequential(
411
- nn.LayerNorm(dim),
412
- nn.Linear(dim, dim * ff_mult, bias=False),
413
- nn.GELU(),
414
- nn.Linear(dim * ff_mult, dim, bias=False),
415
- )
416
- ]
417
- )
418
- )
419
-
420
- def forward(self, x):
421
- latents = self.latents.repeat(x.size(0), 1, 1)
422
- x = self.proj_in(x)
423
- for attn, ff in self.layers:
424
- latents = attn(x, latents) + latents
425
- latents = ff(latents) + latents
426
-
427
- latents = self.proj_out(latents)
428
- return self.norm_out(latents)
429
-
430
-
431
- class CondSequential(nn.Sequential):
432
- """
433
- A sequential module that passes timestep embeddings to the children that
434
- support it as an extra input.
435
- """
436
-
437
- def forward(self, x, emb, context=None, num_frames=1):
438
- for layer in self:
439
- if isinstance(layer, ResBlock):
440
- x = layer(x, emb)
441
- elif isinstance(layer, SpatialTransformer3D):
442
- x = layer(x, context, num_frames=num_frames)
443
- else:
444
- x = layer(x)
445
- return x
446
-
447
-
448
- class Upsample(nn.Module):
449
- """
450
- An upsampling layer with an optional convolution.
451
- :param channels: channels in the inputs and outputs.
452
- :param use_conv: a bool determining if a convolution is applied.
453
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
454
- upsampling occurs in the inner-two dimensions.
455
- """
456
-
457
- def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
458
- super().__init__()
459
- self.channels = channels
460
- self.out_channels = out_channels or channels
461
- self.use_conv = use_conv
462
- self.dims = dims
463
- if use_conv:
464
- self.conv = conv_nd(
465
- dims, self.channels, self.out_channels, 3, padding=padding
466
- )
467
-
468
- def forward(self, x):
469
- assert x.shape[1] == self.channels
470
- if self.dims == 3:
471
- x = F.interpolate(
472
- x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
473
- )
474
- else:
475
- x = F.interpolate(x, scale_factor=2, mode="nearest")
476
- if self.use_conv:
477
- x = self.conv(x)
478
- return x
479
-
480
-
481
- class Downsample(nn.Module):
482
- """
483
- A downsampling layer with an optional convolution.
484
- :param channels: channels in the inputs and outputs.
485
- :param use_conv: a bool determining if a convolution is applied.
486
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
487
- downsampling occurs in the inner-two dimensions.
488
- """
489
-
490
- def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
491
- super().__init__()
492
- self.channels = channels
493
- self.out_channels = out_channels or channels
494
- self.use_conv = use_conv
495
- self.dims = dims
496
- stride = 2 if dims != 3 else (1, 2, 2)
497
- if use_conv:
498
- self.op = conv_nd(
499
- dims,
500
- self.channels,
501
- self.out_channels,
502
- 3,
503
- stride=stride,
504
- padding=padding,
505
- )
506
- else:
507
- assert self.channels == self.out_channels
508
- self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
509
-
510
- def forward(self, x):
511
- assert x.shape[1] == self.channels
512
- return self.op(x)
513
-
514
-
515
- class ResBlock(nn.Module):
516
- """
517
- A residual block that can optionally change the number of channels.
518
- :param channels: the number of input channels.
519
- :param emb_channels: the number of timestep embedding channels.
520
- :param dropout: the rate of dropout.
521
- :param out_channels: if specified, the number of out channels.
522
- :param use_conv: if True and out_channels is specified, use a spatial
523
- convolution instead of a smaller 1x1 convolution to change the
524
- channels in the skip connection.
525
- :param dims: determines if the signal is 1D, 2D, or 3D.
526
- :param up: if True, use this block for upsampling.
527
- :param down: if True, use this block for downsampling.
528
- """
529
-
530
- def __init__(
531
- self,
532
- channels,
533
- emb_channels,
534
- dropout,
535
- out_channels=None,
536
- use_conv=False,
537
- use_scale_shift_norm=False,
538
- dims=2,
539
- up=False,
540
- down=False,
541
- ):
542
- super().__init__()
543
- self.channels = channels
544
- self.emb_channels = emb_channels
545
- self.dropout = dropout
546
- self.out_channels = out_channels or channels
547
- self.use_conv = use_conv
548
- self.use_scale_shift_norm = use_scale_shift_norm
549
-
550
- self.in_layers = nn.Sequential(
551
- nn.GroupNorm(32, channels),
552
- nn.SiLU(),
553
- conv_nd(dims, channels, self.out_channels, 3, padding=1),
554
- )
555
-
556
- self.updown = up or down
557
-
558
- if up:
559
- self.h_upd = Upsample(channels, False, dims)
560
- self.x_upd = Upsample(channels, False, dims)
561
- elif down:
562
- self.h_upd = Downsample(channels, False, dims)
563
- self.x_upd = Downsample(channels, False, dims)
564
- else:
565
- self.h_upd = self.x_upd = nn.Identity()
566
-
567
- self.emb_layers = nn.Sequential(
568
- nn.SiLU(),
569
- nn.Linear(
570
- emb_channels,
571
- 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
572
- ),
573
- )
574
- self.out_layers = nn.Sequential(
575
- nn.GroupNorm(32, self.out_channels),
576
- nn.SiLU(),
577
- nn.Dropout(p=dropout),
578
- zero_module(
579
- conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
580
- ),
581
- )
582
-
583
- if self.out_channels == channels:
584
- self.skip_connection = nn.Identity()
585
- elif use_conv:
586
- self.skip_connection = conv_nd(
587
- dims, channels, self.out_channels, 3, padding=1
588
- )
589
- else:
590
- self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
591
-
592
- def forward(self, x, emb):
593
- if self.updown:
594
- in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
595
- h = in_rest(x)
596
- h = self.h_upd(h)
597
- x = self.x_upd(x)
598
- h = in_conv(h)
599
- else:
600
- h = self.in_layers(x)
601
- emb_out = self.emb_layers(emb).type(h.dtype)
602
- while len(emb_out.shape) < len(h.shape):
603
- emb_out = emb_out[..., None]
604
- if self.use_scale_shift_norm:
605
- out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
606
- scale, shift = torch.chunk(emb_out, 2, dim=1)
607
- h = out_norm(h) * (1 + scale) + shift
608
- h = out_rest(h)
609
- else:
610
- h = h + emb_out
611
- h = self.out_layers(h)
612
- return self.skip_connection(x) + h
613
-
614
-
615
- class MultiViewUNetModel(ModelMixin, ConfigMixin):
616
- """
617
- The full multi-view UNet model with attention, timestep embedding and camera embedding.
618
- :param in_channels: channels in the input Tensor.
619
- :param model_channels: base channel count for the model.
620
- :param out_channels: channels in the output Tensor.
621
- :param num_res_blocks: number of residual blocks per downsample.
622
- :param attention_resolutions: a collection of downsample rates at which
623
- attention will take place. May be a set, list, or tuple.
624
- For example, if this contains 4, then at 4x downsampling, attention
625
- will be used.
626
- :param dropout: the dropout probability.
627
- :param channel_mult: channel multiplier for each level of the UNet.
628
- :param conv_resample: if True, use learned convolutions for upsampling and
629
- downsampling.
630
- :param dims: determines if the signal is 1D, 2D, or 3D.
631
- :param num_classes: if specified (as an int), then this model will be
632
- class-conditional with `num_classes` classes.
633
- :param num_heads: the number of attention heads in each attention layer.
634
- :param num_heads_channels: if specified, ignore num_heads and instead use
635
- a fixed channel width per attention head.
636
- :param num_heads_upsample: works with num_heads to set a different number
637
- of heads for upsampling. Deprecated.
638
- :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
639
- :param resblock_updown: use residual blocks for up/downsampling.
640
- :param use_new_attention_order: use a different attention pattern for potentially
641
- increased efficiency.
642
- :param camera_dim: dimensionality of camera input.
643
- """
644
-
645
- def __init__(
646
- self,
647
- image_size,
648
- in_channels,
649
- model_channels,
650
- out_channels,
651
- num_res_blocks,
652
- attention_resolutions,
653
- dropout=0,
654
- channel_mult=(1, 2, 4, 8),
655
- conv_resample=True,
656
- dims=2,
657
- num_classes=None,
658
- num_heads=-1,
659
- num_head_channels=-1,
660
- num_heads_upsample=-1,
661
- use_scale_shift_norm=False,
662
- resblock_updown=False,
663
- transformer_depth=1,
664
- context_dim=None,
665
- n_embed=None,
666
- num_attention_blocks=None,
667
- adm_in_channels=None,
668
- camera_dim=None,
669
- ip_dim=0, # imagedream uses ip_dim > 0
670
- ip_weight=1.0,
671
- **kwargs,
672
- ):
673
- super().__init__()
674
- assert context_dim is not None
675
-
676
- if num_heads_upsample == -1:
677
- num_heads_upsample = num_heads
678
-
679
- if num_heads == -1:
680
- assert (
681
- num_head_channels != -1
682
- ), "Either num_heads or num_head_channels has to be set"
683
-
684
- if num_head_channels == -1:
685
- assert (
686
- num_heads != -1
687
- ), "Either num_heads or num_head_channels has to be set"
688
-
689
- self.image_size = image_size
690
- self.in_channels = in_channels
691
- self.model_channels = model_channels
692
- self.out_channels = out_channels
693
- if isinstance(num_res_blocks, int):
694
- self.num_res_blocks = len(channel_mult) * [num_res_blocks]
695
- else:
696
- if len(num_res_blocks) != len(channel_mult):
697
- raise ValueError(
698
- "provide num_res_blocks either as an int (globally constant) or "
699
- "as a list/tuple (per-level) with the same length as channel_mult"
700
- )
701
- self.num_res_blocks = num_res_blocks
702
-
703
- if num_attention_blocks is not None:
704
- assert len(num_attention_blocks) == len(self.num_res_blocks)
705
- assert all(
706
- map(
707
- lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
708
- range(len(num_attention_blocks)),
709
- )
710
- )
711
- print(
712
- f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
713
- f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
714
- f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
715
- f"attention will still not be set."
716
- )
717
-
718
- self.attention_resolutions = attention_resolutions
719
- self.dropout = dropout
720
- self.channel_mult = channel_mult
721
- self.conv_resample = conv_resample
722
- self.num_classes = num_classes
723
- self.num_heads = num_heads
724
- self.num_head_channels = num_head_channels
725
- self.num_heads_upsample = num_heads_upsample
726
- self.predict_codebook_ids = n_embed is not None
727
-
728
- self.ip_dim = ip_dim
729
- self.ip_weight = ip_weight
730
-
731
- if self.ip_dim > 0:
732
- self.image_embed = Resampler(
733
- dim=context_dim,
734
- depth=4,
735
- dim_head=64,
736
- heads=12,
737
- num_queries=ip_dim, # num token
738
- embedding_dim=1280,
739
- output_dim=context_dim,
740
- ff_mult=4,
741
- )
742
-
743
- time_embed_dim = model_channels * 4
744
- self.time_embed = nn.Sequential(
745
- nn.Linear(model_channels, time_embed_dim),
746
- nn.SiLU(),
747
- nn.Linear(time_embed_dim, time_embed_dim),
748
- )
749
-
750
- if camera_dim is not None:
751
- time_embed_dim = model_channels * 4
752
- self.camera_embed = nn.Sequential(
753
- nn.Linear(camera_dim, time_embed_dim),
754
- nn.SiLU(),
755
- nn.Linear(time_embed_dim, time_embed_dim),
756
- )
757
-
758
- if self.num_classes is not None:
759
- if isinstance(self.num_classes, int):
760
- self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
761
- elif self.num_classes == "continuous":
762
- # print("setting up linear c_adm embedding layer")
763
- self.label_emb = nn.Linear(1, time_embed_dim)
764
- elif self.num_classes == "sequential":
765
- assert adm_in_channels is not None
766
- self.label_emb = nn.Sequential(
767
- nn.Sequential(
768
- nn.Linear(adm_in_channels, time_embed_dim),
769
- nn.SiLU(),
770
- nn.Linear(time_embed_dim, time_embed_dim),
771
- )
772
- )
773
- else:
774
- raise ValueError()
775
-
776
- self.input_blocks = nn.ModuleList(
777
- [
778
- CondSequential(
779
- conv_nd(dims, in_channels, model_channels, 3, padding=1)
780
- )
781
- ]
782
- )
783
- self._feature_size = model_channels
784
- input_block_chans = [model_channels]
785
- ch = model_channels
786
- ds = 1
787
- for level, mult in enumerate(channel_mult):
788
- for nr in range(self.num_res_blocks[level]):
789
- layers: List[Any] = [
790
- ResBlock(
791
- ch,
792
- time_embed_dim,
793
- dropout,
794
- out_channels=mult * model_channels,
795
- dims=dims,
796
- use_scale_shift_norm=use_scale_shift_norm,
797
- )
798
- ]
799
- ch = mult * model_channels
800
- if ds in attention_resolutions:
801
- if num_head_channels == -1:
802
- dim_head = ch // num_heads
803
- else:
804
- num_heads = ch // num_head_channels
805
- dim_head = num_head_channels
806
-
807
- if num_attention_blocks is None or nr < num_attention_blocks[level]:
808
- layers.append(
809
- SpatialTransformer3D(
810
- ch,
811
- num_heads,
812
- dim_head,
813
- context_dim=context_dim,
814
- depth=transformer_depth,
815
- ip_dim=self.ip_dim,
816
- ip_weight=self.ip_weight,
817
- )
818
- )
819
- self.input_blocks.append(CondSequential(*layers))
820
- self._feature_size += ch
821
- input_block_chans.append(ch)
822
- if level != len(channel_mult) - 1:
823
- out_ch = ch
824
- self.input_blocks.append(
825
- CondSequential(
826
- ResBlock(
827
- ch,
828
- time_embed_dim,
829
- dropout,
830
- out_channels=out_ch,
831
- dims=dims,
832
- use_scale_shift_norm=use_scale_shift_norm,
833
- down=True,
834
- )
835
- if resblock_updown
836
- else Downsample(
837
- ch, conv_resample, dims=dims, out_channels=out_ch
838
- )
839
- )
840
- )
841
- ch = out_ch
842
- input_block_chans.append(ch)
843
- ds *= 2
844
- self._feature_size += ch
845
-
846
- if num_head_channels == -1:
847
- dim_head = ch // num_heads
848
- else:
849
- num_heads = ch // num_head_channels
850
- dim_head = num_head_channels
851
-
852
- self.middle_block = CondSequential(
853
- ResBlock(
854
- ch,
855
- time_embed_dim,
856
- dropout,
857
- dims=dims,
858
- use_scale_shift_norm=use_scale_shift_norm,
859
- ),
860
- SpatialTransformer3D(
861
- ch,
862
- num_heads,
863
- dim_head,
864
- context_dim=context_dim,
865
- depth=transformer_depth,
866
- ip_dim=self.ip_dim,
867
- ip_weight=self.ip_weight,
868
- ),
869
- ResBlock(
870
- ch,
871
- time_embed_dim,
872
- dropout,
873
- dims=dims,
874
- use_scale_shift_norm=use_scale_shift_norm,
875
- ),
876
- )
877
- self._feature_size += ch
878
-
879
- self.output_blocks = nn.ModuleList([])
880
- for level, mult in list(enumerate(channel_mult))[::-1]:
881
- for i in range(self.num_res_blocks[level] + 1):
882
- ich = input_block_chans.pop()
883
- layers = [
884
- ResBlock(
885
- ch + ich,
886
- time_embed_dim,
887
- dropout,
888
- out_channels=model_channels * mult,
889
- dims=dims,
890
- use_scale_shift_norm=use_scale_shift_norm,
891
- )
892
- ]
893
- ch = model_channels * mult
894
- if ds in attention_resolutions:
895
- if num_head_channels == -1:
896
- dim_head = ch // num_heads
897
- else:
898
- num_heads = ch // num_head_channels
899
- dim_head = num_head_channels
900
-
901
- if num_attention_blocks is None or i < num_attention_blocks[level]:
902
- layers.append(
903
- SpatialTransformer3D(
904
- ch,
905
- num_heads,
906
- dim_head,
907
- context_dim=context_dim,
908
- depth=transformer_depth,
909
- ip_dim=self.ip_dim,
910
- ip_weight=self.ip_weight,
911
- )
912
- )
913
- if level and i == self.num_res_blocks[level]:
914
- out_ch = ch
915
- layers.append(
916
- ResBlock(
917
- ch,
918
- time_embed_dim,
919
- dropout,
920
- out_channels=out_ch,
921
- dims=dims,
922
- use_scale_shift_norm=use_scale_shift_norm,
923
- up=True,
924
- )
925
- if resblock_updown
926
- else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
927
- )
928
- ds //= 2
929
- self.output_blocks.append(CondSequential(*layers))
930
- self._feature_size += ch
931
-
932
- self.out = nn.Sequential(
933
- nn.GroupNorm(32, ch),
934
- nn.SiLU(),
935
- zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
936
- )
937
- if self.predict_codebook_ids:
938
- self.id_predictor = nn.Sequential(
939
- nn.GroupNorm(32, ch),
940
- conv_nd(dims, model_channels, n_embed, 1),
941
- # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
942
- )
943
-
944
- def forward(
945
- self,
946
- x,
947
- timesteps=None,
948
- context=None,
949
- y=None,
950
- camera=None,
951
- num_frames=1,
952
- ip=None,
953
- ip_img=None,
954
- **kwargs,
955
- ):
956
- """
957
- Apply the model to an input batch.
958
- :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
959
- :param timesteps: a 1-D batch of timesteps.
960
- :param context: conditioning plugged in via crossattn
961
- :param y: an [N] Tensor of labels, if class-conditional.
962
- :param num_frames: a integer indicating number of frames for tensor reshaping.
963
- :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
964
- """
965
- assert (
966
- x.shape[0] % num_frames == 0
967
- ), "input batch size must be dividable by num_frames!"
968
- assert (y is not None) == (
969
- self.num_classes is not None
970
- ), "must specify y if and only if the model is class-conditional"
971
-
972
- hs = []
973
-
974
- t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
975
-
976
- emb = self.time_embed(t_emb)
977
-
978
- if self.num_classes is not None:
979
- assert y is not None
980
- assert y.shape[0] == x.shape[0]
981
- emb = emb + self.label_emb(y)
982
-
983
- # Add camera embeddings
984
- if camera is not None:
985
- emb = emb + self.camera_embed(camera)
986
-
987
- # imagedream variant
988
- if self.ip_dim > 0:
989
- x[(num_frames - 1) :: num_frames, :, :, :] = ip_img # place at [4, 9]
990
- ip_emb = self.image_embed(ip)
991
- context = torch.cat((context, ip_emb), 1)
992
-
993
- h = x
994
- for module in self.input_blocks:
995
- h = module(h, emb, context, num_frames=num_frames)
996
- hs.append(h)
997
- h = self.middle_block(h, emb, context, num_frames=num_frames)
998
- for module in self.output_blocks:
999
- h = torch.cat([h, hs.pop()], dim=1)
1000
- h = module(h, emb, context, num_frames=num_frames)
1001
- h = h.type(x.dtype)
1002
- if self.predict_codebook_ids:
1003
- return self.id_predictor(h)
1004
- else:
1005
  return self.out(h)
 
1
+ import math
2
+ import numpy as np
3
+ from inspect import isfunction
4
+ from typing import Optional, Any, List
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange, repeat
10
+
11
+ from diffusers.configuration_utils import ConfigMixin
12
+ from diffusers.models.modeling_utils import ModelMixin
13
+
14
+ # require xformers!
15
+ import xformers
16
+ import xformers.ops
17
+
18
+ from kiui.cam import orbit_camera
19
+
20
+ def get_camera(
21
+ num_frames, elevation=15, azimuth_start=0, azimuth_span=360, blender_coord=True, extra_view=False,
22
+ ):
23
+ angle_gap = azimuth_span / num_frames
24
+ cameras = []
25
+ for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):
26
+
27
+ pose = orbit_camera(-elevation, azimuth, radius=1) # kiui's elevation is negated, [4, 4]
28
+
29
+ # opengl to blender
30
+ if blender_coord:
31
+ pose[2] *= -1
32
+ pose[[1, 2]] = pose[[2, 1]]
33
+
34
+ cameras.append(pose.flatten())
35
+
36
+ if extra_view:
37
+ cameras.append(np.zeros_like(cameras[0]))
38
+
39
+ return torch.from_numpy(np.stack(cameras, axis=0)).float() # [num_frames, 16]
40
+
41
+
42
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
43
+ """
44
+ Create sinusoidal timestep embeddings.
45
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
46
+ These may be fractional.
47
+ :param dim: the dimension of the output.
48
+ :param max_period: controls the minimum frequency of the embeddings.
49
+ :return: an [N x dim] Tensor of positional embeddings.
50
+ """
51
+ if not repeat_only:
52
+ half = dim // 2
53
+ freqs = torch.exp(
54
+ -math.log(max_period)
55
+ * torch.arange(start=0, end=half, dtype=torch.float32)
56
+ / half
57
+ ).to(device=timesteps.device)
58
+ args = timesteps[:, None] * freqs[None]
59
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
60
+ if dim % 2:
61
+ embedding = torch.cat(
62
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
63
+ )
64
+ else:
65
+ embedding = repeat(timesteps, "b -> b d", d=dim)
66
+ # import pdb; pdb.set_trace()
67
+ return embedding
68
+
69
+
70
+ def zero_module(module):
71
+ """
72
+ Zero out the parameters of a module and return it.
73
+ """
74
+ for p in module.parameters():
75
+ p.detach().zero_()
76
+ return module
77
+
78
+
79
+ def conv_nd(dims, *args, **kwargs):
80
+ """
81
+ Create a 1D, 2D, or 3D convolution module.
82
+ """
83
+ if dims == 1:
84
+ return nn.Conv1d(*args, **kwargs)
85
+ elif dims == 2:
86
+ return nn.Conv2d(*args, **kwargs)
87
+ elif dims == 3:
88
+ return nn.Conv3d(*args, **kwargs)
89
+ raise ValueError(f"unsupported dimensions: {dims}")
90
+
91
+
92
+ def avg_pool_nd(dims, *args, **kwargs):
93
+ """
94
+ Create a 1D, 2D, or 3D average pooling module.
95
+ """
96
+ if dims == 1:
97
+ return nn.AvgPool1d(*args, **kwargs)
98
+ elif dims == 2:
99
+ return nn.AvgPool2d(*args, **kwargs)
100
+ elif dims == 3:
101
+ return nn.AvgPool3d(*args, **kwargs)
102
+ raise ValueError(f"unsupported dimensions: {dims}")
103
+
104
+
105
+ def default(val, d):
106
+ if val is not None:
107
+ return val
108
+ return d() if isfunction(d) else d
109
+
110
+
111
+ class GEGLU(nn.Module):
112
+ def __init__(self, dim_in, dim_out):
113
+ super().__init__()
114
+ self.proj = nn.Linear(dim_in, dim_out * 2)
115
+
116
+ def forward(self, x):
117
+ x, gate = self.proj(x).chunk(2, dim=-1)
118
+ return x * F.gelu(gate)
119
+
120
+
121
+ class FeedForward(nn.Module):
122
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
123
+ super().__init__()
124
+ inner_dim = int(dim * mult)
125
+ dim_out = default(dim_out, dim)
126
+ project_in = (
127
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
128
+ if not glu
129
+ else GEGLU(dim, inner_dim)
130
+ )
131
+
132
+ self.net = nn.Sequential(
133
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
134
+ )
135
+
136
+ def forward(self, x):
137
+ return self.net(x)
138
+
139
+
140
+ class MemoryEfficientCrossAttention(nn.Module):
141
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
142
+ def __init__(
143
+ self,
144
+ query_dim,
145
+ context_dim=None,
146
+ heads=8,
147
+ dim_head=64,
148
+ dropout=0.0,
149
+ ip_dim=0,
150
+ ip_weight=1,
151
+ ):
152
+ super().__init__()
153
+
154
+ inner_dim = dim_head * heads
155
+ context_dim = default(context_dim, query_dim)
156
+
157
+ self.heads = heads
158
+ self.dim_head = dim_head
159
+
160
+ self.ip_dim = ip_dim
161
+ self.ip_weight = ip_weight
162
+
163
+ if self.ip_dim > 0:
164
+ self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
165
+ self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
166
+
167
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
168
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
169
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
170
+
171
+ self.to_out = nn.Sequential(
172
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
173
+ )
174
+ self.attention_op: Optional[Any] = None
175
+
176
+ def forward(self, x, context=None):
177
+ q = self.to_q(x)
178
+ context = default(context, x)
179
+
180
+ if self.ip_dim > 0:
181
+ # context: [B, 77 + 16(ip), 1024]
182
+ token_len = context.shape[1]
183
+ context_ip = context[:, -self.ip_dim :, :]
184
+ k_ip = self.to_k_ip(context_ip)
185
+ v_ip = self.to_v_ip(context_ip)
186
+ context = context[:, : (token_len - self.ip_dim), :]
187
+
188
+ k = self.to_k(context)
189
+ v = self.to_v(context)
190
+
191
+ b, _, _ = q.shape
192
+ q, k, v = map(
193
+ lambda t: t.unsqueeze(3)
194
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
195
+ .permute(0, 2, 1, 3)
196
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
197
+ .contiguous(),
198
+ (q, k, v),
199
+ )
200
+
201
+ # actually compute the attention, what we cannot get enough of
202
+ out = xformers.ops.memory_efficient_attention(
203
+ q, k, v, attn_bias=None, op=self.attention_op
204
+ )
205
+
206
+ if self.ip_dim > 0:
207
+ k_ip, v_ip = map(
208
+ lambda t: t.unsqueeze(3)
209
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
210
+ .permute(0, 2, 1, 3)
211
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
212
+ .contiguous(),
213
+ (k_ip, v_ip),
214
+ )
215
+ # actually compute the attention, what we cannot get enough of
216
+ out_ip = xformers.ops.memory_efficient_attention(
217
+ q, k_ip, v_ip, attn_bias=None, op=self.attention_op
218
+ )
219
+ out = out + self.ip_weight * out_ip
220
+
221
+ out = (
222
+ out.unsqueeze(0)
223
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
224
+ .permute(0, 2, 1, 3)
225
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
226
+ )
227
+ return self.to_out(out)
228
+
229
+
230
+ class BasicTransformerBlock3D(nn.Module):
231
+
232
+ def __init__(
233
+ self,
234
+ dim,
235
+ n_heads,
236
+ d_head,
237
+ context_dim,
238
+ dropout=0.0,
239
+ gated_ff=True,
240
+ ip_dim=0,
241
+ ip_weight=1,
242
+ ):
243
+ super().__init__()
244
+
245
+ self.attn1 = MemoryEfficientCrossAttention(
246
+ query_dim=dim,
247
+ context_dim=None, # self-attention
248
+ heads=n_heads,
249
+ dim_head=d_head,
250
+ dropout=dropout,
251
+ )
252
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
253
+ self.attn2 = MemoryEfficientCrossAttention(
254
+ query_dim=dim,
255
+ context_dim=context_dim,
256
+ heads=n_heads,
257
+ dim_head=d_head,
258
+ dropout=dropout,
259
+ # ip only applies to cross-attention
260
+ ip_dim=ip_dim,
261
+ ip_weight=ip_weight,
262
+ )
263
+ self.norm1 = nn.LayerNorm(dim)
264
+ self.norm2 = nn.LayerNorm(dim)
265
+ self.norm3 = nn.LayerNorm(dim)
266
+
267
+ def forward(self, x, context=None, num_frames=1):
268
+ x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
269
+ x = self.attn1(self.norm1(x), context=None) + x
270
+ x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
271
+ x = self.attn2(self.norm2(x), context=context) + x
272
+ x = self.ff(self.norm3(x)) + x
273
+ return x
274
+
275
+
276
+ class SpatialTransformer3D(nn.Module):
277
+
278
+ def __init__(
279
+ self,
280
+ in_channels,
281
+ n_heads,
282
+ d_head,
283
+ context_dim, # cross attention input dim
284
+ depth=1,
285
+ dropout=0.0,
286
+ ip_dim=0,
287
+ ip_weight=1,
288
+ ):
289
+ super().__init__()
290
+
291
+ if not isinstance(context_dim, list):
292
+ context_dim = [context_dim]
293
+
294
+ self.in_channels = in_channels
295
+
296
+ inner_dim = n_heads * d_head
297
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
298
+ self.proj_in = nn.Linear(in_channels, inner_dim)
299
+
300
+ self.transformer_blocks = nn.ModuleList(
301
+ [
302
+ BasicTransformerBlock3D(
303
+ inner_dim,
304
+ n_heads,
305
+ d_head,
306
+ context_dim=context_dim[d],
307
+ dropout=dropout,
308
+ ip_dim=ip_dim,
309
+ ip_weight=ip_weight,
310
+ )
311
+ for d in range(depth)
312
+ ]
313
+ )
314
+
315
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
316
+
317
+
318
+ def forward(self, x, context=None, num_frames=1):
319
+ # note: if no context is given, cross-attention defaults to self-attention
320
+ if not isinstance(context, list):
321
+ context = [context]
322
+ b, c, h, w = x.shape
323
+ x_in = x
324
+ x = self.norm(x)
325
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
326
+ x = self.proj_in(x)
327
+ for i, block in enumerate(self.transformer_blocks):
328
+ x = block(x, context=context[i], num_frames=num_frames)
329
+ x = self.proj_out(x)
330
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
331
+
332
+ return x + x_in
333
+
334
+
335
+ class PerceiverAttention(nn.Module):
336
+ def __init__(self, *, dim, dim_head=64, heads=8):
337
+ super().__init__()
338
+ self.scale = dim_head ** -0.5
339
+ self.dim_head = dim_head
340
+ self.heads = heads
341
+ inner_dim = dim_head * heads
342
+
343
+ self.norm1 = nn.LayerNorm(dim)
344
+ self.norm2 = nn.LayerNorm(dim)
345
+
346
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
347
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
348
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
349
+
350
+ def forward(self, x, latents):
351
+ """
352
+ Args:
353
+ x (torch.Tensor): image features
354
+ shape (b, n1, D)
355
+ latent (torch.Tensor): latent features
356
+ shape (b, n2, D)
357
+ """
358
+ x = self.norm1(x)
359
+ latents = self.norm2(latents)
360
+
361
+ b, l, _ = latents.shape
362
+
363
+ q = self.to_q(latents)
364
+ kv_input = torch.cat((x, latents), dim=-2)
365
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
366
+
367
+ q, k, v = map(
368
+ lambda t: t.reshape(b, t.shape[1], self.heads, -1)
369
+ .transpose(1, 2)
370
+ .reshape(b, self.heads, t.shape[1], -1)
371
+ .contiguous(),
372
+ (q, k, v),
373
+ )
374
+
375
+ # attention
376
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
377
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
378
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
379
+ out = weight @ v
380
+
381
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
382
+
383
+ return self.to_out(out)
384
+
385
+
386
+ class Resampler(nn.Module):
387
+ def __init__(
388
+ self,
389
+ dim=1024,
390
+ depth=8,
391
+ dim_head=64,
392
+ heads=16,
393
+ num_queries=8,
394
+ embedding_dim=768,
395
+ output_dim=1024,
396
+ ff_mult=4,
397
+ ):
398
+ super().__init__()
399
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
400
+ self.proj_in = nn.Linear(embedding_dim, dim)
401
+ self.proj_out = nn.Linear(dim, output_dim)
402
+ self.norm_out = nn.LayerNorm(output_dim)
403
+
404
+ self.layers = nn.ModuleList([])
405
+ for _ in range(depth):
406
+ self.layers.append(
407
+ nn.ModuleList(
408
+ [
409
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
410
+ nn.Sequential(
411
+ nn.LayerNorm(dim),
412
+ nn.Linear(dim, dim * ff_mult, bias=False),
413
+ nn.GELU(),
414
+ nn.Linear(dim * ff_mult, dim, bias=False),
415
+ )
416
+ ]
417
+ )
418
+ )
419
+
420
+ def forward(self, x):
421
+ latents = self.latents.repeat(x.size(0), 1, 1)
422
+ x = self.proj_in(x)
423
+ for attn, ff in self.layers:
424
+ latents = attn(x, latents) + latents
425
+ latents = ff(latents) + latents
426
+
427
+ latents = self.proj_out(latents)
428
+ return self.norm_out(latents)
429
+
430
+
431
+ class CondSequential(nn.Sequential):
432
+ """
433
+ A sequential module that passes timestep embeddings to the children that
434
+ support it as an extra input.
435
+ """
436
+
437
+ def forward(self, x, emb, context=None, num_frames=1):
438
+ for layer in self:
439
+ if isinstance(layer, ResBlock):
440
+ x = layer(x, emb)
441
+ elif isinstance(layer, SpatialTransformer3D):
442
+ x = layer(x, context, num_frames=num_frames)
443
+ else:
444
+ x = layer(x)
445
+ return x
446
+
447
+
448
+ class Upsample(nn.Module):
449
+ """
450
+ An upsampling layer with an optional convolution.
451
+ :param channels: channels in the inputs and outputs.
452
+ :param use_conv: a bool determining if a convolution is applied.
453
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
454
+ upsampling occurs in the inner-two dimensions.
455
+ """
456
+
457
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
458
+ super().__init__()
459
+ self.channels = channels
460
+ self.out_channels = out_channels or channels
461
+ self.use_conv = use_conv
462
+ self.dims = dims
463
+ if use_conv:
464
+ self.conv = conv_nd(
465
+ dims, self.channels, self.out_channels, 3, padding=padding
466
+ )
467
+
468
+ def forward(self, x):
469
+ assert x.shape[1] == self.channels
470
+ if self.dims == 3:
471
+ x = F.interpolate(
472
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
473
+ )
474
+ else:
475
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
476
+ if self.use_conv:
477
+ x = self.conv(x)
478
+ return x
479
+
480
+
481
+ class Downsample(nn.Module):
482
+ """
483
+ A downsampling layer with an optional convolution.
484
+ :param channels: channels in the inputs and outputs.
485
+ :param use_conv: a bool determining if a convolution is applied.
486
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
487
+ downsampling occurs in the inner-two dimensions.
488
+ """
489
+
490
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
491
+ super().__init__()
492
+ self.channels = channels
493
+ self.out_channels = out_channels or channels
494
+ self.use_conv = use_conv
495
+ self.dims = dims
496
+ stride = 2 if dims != 3 else (1, 2, 2)
497
+ if use_conv:
498
+ self.op = conv_nd(
499
+ dims,
500
+ self.channels,
501
+ self.out_channels,
502
+ 3,
503
+ stride=stride,
504
+ padding=padding,
505
+ )
506
+ else:
507
+ assert self.channels == self.out_channels
508
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
509
+
510
+ def forward(self, x):
511
+ assert x.shape[1] == self.channels
512
+ return self.op(x)
513
+
514
+
515
+ class ResBlock(nn.Module):
516
+ """
517
+ A residual block that can optionally change the number of channels.
518
+ :param channels: the number of input channels.
519
+ :param emb_channels: the number of timestep embedding channels.
520
+ :param dropout: the rate of dropout.
521
+ :param out_channels: if specified, the number of out channels.
522
+ :param use_conv: if True and out_channels is specified, use a spatial
523
+ convolution instead of a smaller 1x1 convolution to change the
524
+ channels in the skip connection.
525
+ :param dims: determines if the signal is 1D, 2D, or 3D.
526
+ :param up: if True, use this block for upsampling.
527
+ :param down: if True, use this block for downsampling.
528
+ """
529
+
530
+ def __init__(
531
+ self,
532
+ channels,
533
+ emb_channels,
534
+ dropout,
535
+ out_channels=None,
536
+ use_conv=False,
537
+ use_scale_shift_norm=False,
538
+ dims=2,
539
+ up=False,
540
+ down=False,
541
+ ):
542
+ super().__init__()
543
+ self.channels = channels
544
+ self.emb_channels = emb_channels
545
+ self.dropout = dropout
546
+ self.out_channels = out_channels or channels
547
+ self.use_conv = use_conv
548
+ self.use_scale_shift_norm = use_scale_shift_norm
549
+
550
+ self.in_layers = nn.Sequential(
551
+ nn.GroupNorm(32, channels),
552
+ nn.SiLU(),
553
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
554
+ )
555
+
556
+ self.updown = up or down
557
+
558
+ if up:
559
+ self.h_upd = Upsample(channels, False, dims)
560
+ self.x_upd = Upsample(channels, False, dims)
561
+ elif down:
562
+ self.h_upd = Downsample(channels, False, dims)
563
+ self.x_upd = Downsample(channels, False, dims)
564
+ else:
565
+ self.h_upd = self.x_upd = nn.Identity()
566
+
567
+ self.emb_layers = nn.Sequential(
568
+ nn.SiLU(),
569
+ nn.Linear(
570
+ emb_channels,
571
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
572
+ ),
573
+ )
574
+ self.out_layers = nn.Sequential(
575
+ nn.GroupNorm(32, self.out_channels),
576
+ nn.SiLU(),
577
+ nn.Dropout(p=dropout),
578
+ zero_module(
579
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
580
+ ),
581
+ )
582
+
583
+ if self.out_channels == channels:
584
+ self.skip_connection = nn.Identity()
585
+ elif use_conv:
586
+ self.skip_connection = conv_nd(
587
+ dims, channels, self.out_channels, 3, padding=1
588
+ )
589
+ else:
590
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
591
+
592
+ def forward(self, x, emb):
593
+ if self.updown:
594
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
595
+ h = in_rest(x)
596
+ h = self.h_upd(h)
597
+ x = self.x_upd(x)
598
+ h = in_conv(h)
599
+ else:
600
+ h = self.in_layers(x)
601
+ emb_out = self.emb_layers(emb).type(h.dtype)
602
+ while len(emb_out.shape) < len(h.shape):
603
+ emb_out = emb_out[..., None]
604
+ if self.use_scale_shift_norm:
605
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
606
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
607
+ h = out_norm(h) * (1 + scale) + shift
608
+ h = out_rest(h)
609
+ else:
610
+ h = h + emb_out
611
+ h = self.out_layers(h)
612
+ return self.skip_connection(x) + h
613
+
614
+
615
+ class MultiViewUNetModel(ModelMixin, ConfigMixin):
616
+ """
617
+ The full multi-view UNet model with attention, timestep embedding and camera embedding.
618
+ :param in_channels: channels in the input Tensor.
619
+ :param model_channels: base channel count for the model.
620
+ :param out_channels: channels in the output Tensor.
621
+ :param num_res_blocks: number of residual blocks per downsample.
622
+ :param attention_resolutions: a collection of downsample rates at which
623
+ attention will take place. May be a set, list, or tuple.
624
+ For example, if this contains 4, then at 4x downsampling, attention
625
+ will be used.
626
+ :param dropout: the dropout probability.
627
+ :param channel_mult: channel multiplier for each level of the UNet.
628
+ :param conv_resample: if True, use learned convolutions for upsampling and
629
+ downsampling.
630
+ :param dims: determines if the signal is 1D, 2D, or 3D.
631
+ :param num_classes: if specified (as an int), then this model will be
632
+ class-conditional with `num_classes` classes.
633
+ :param num_heads: the number of attention heads in each attention layer.
634
+ :param num_heads_channels: if specified, ignore num_heads and instead use
635
+ a fixed channel width per attention head.
636
+ :param num_heads_upsample: works with num_heads to set a different number
637
+ of heads for upsampling. Deprecated.
638
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
639
+ :param resblock_updown: use residual blocks for up/downsampling.
640
+ :param use_new_attention_order: use a different attention pattern for potentially
641
+ increased efficiency.
642
+ :param camera_dim: dimensionality of camera input.
643
+ """
644
+
645
+ def __init__(
646
+ self,
647
+ image_size,
648
+ in_channels,
649
+ model_channels,
650
+ out_channels,
651
+ num_res_blocks,
652
+ attention_resolutions,
653
+ dropout=0,
654
+ channel_mult=(1, 2, 4, 8),
655
+ conv_resample=True,
656
+ dims=2,
657
+ num_classes=None,
658
+ num_heads=-1,
659
+ num_head_channels=-1,
660
+ num_heads_upsample=-1,
661
+ use_scale_shift_norm=False,
662
+ resblock_updown=False,
663
+ transformer_depth=1,
664
+ context_dim=None,
665
+ n_embed=None,
666
+ num_attention_blocks=None,
667
+ adm_in_channels=None,
668
+ camera_dim=None,
669
+ ip_dim=0, # imagedream uses ip_dim > 0
670
+ ip_weight=1.0,
671
+ **kwargs,
672
+ ):
673
+ super().__init__()
674
+ assert context_dim is not None
675
+
676
+ if num_heads_upsample == -1:
677
+ num_heads_upsample = num_heads
678
+
679
+ if num_heads == -1:
680
+ assert (
681
+ num_head_channels != -1
682
+ ), "Either num_heads or num_head_channels has to be set"
683
+
684
+ if num_head_channels == -1:
685
+ assert (
686
+ num_heads != -1
687
+ ), "Either num_heads or num_head_channels has to be set"
688
+
689
+ self.image_size = image_size
690
+ self.in_channels = in_channels
691
+ self.model_channels = model_channels
692
+ self.out_channels = out_channels
693
+ if isinstance(num_res_blocks, int):
694
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
695
+ else:
696
+ if len(num_res_blocks) != len(channel_mult):
697
+ raise ValueError(
698
+ "provide num_res_blocks either as an int (globally constant) or "
699
+ "as a list/tuple (per-level) with the same length as channel_mult"
700
+ )
701
+ self.num_res_blocks = num_res_blocks
702
+
703
+ if num_attention_blocks is not None:
704
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
705
+ assert all(
706
+ map(
707
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
708
+ range(len(num_attention_blocks)),
709
+ )
710
+ )
711
+ print(
712
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
713
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
714
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
715
+ f"attention will still not be set."
716
+ )
717
+
718
+ self.attention_resolutions = attention_resolutions
719
+ self.dropout = dropout
720
+ self.channel_mult = channel_mult
721
+ self.conv_resample = conv_resample
722
+ self.num_classes = num_classes
723
+ self.num_heads = num_heads
724
+ self.num_head_channels = num_head_channels
725
+ self.num_heads_upsample = num_heads_upsample
726
+ self.predict_codebook_ids = n_embed is not None
727
+
728
+ self.ip_dim = ip_dim
729
+ self.ip_weight = ip_weight
730
+
731
+ if self.ip_dim > 0:
732
+ self.image_embed = Resampler(
733
+ dim=context_dim,
734
+ depth=4,
735
+ dim_head=64,
736
+ heads=12,
737
+ num_queries=ip_dim, # num token
738
+ embedding_dim=1280,
739
+ output_dim=context_dim,
740
+ ff_mult=4,
741
+ )
742
+
743
+ time_embed_dim = model_channels * 4
744
+ self.time_embed = nn.Sequential(
745
+ nn.Linear(model_channels, time_embed_dim),
746
+ nn.SiLU(),
747
+ nn.Linear(time_embed_dim, time_embed_dim),
748
+ )
749
+
750
+ if camera_dim is not None:
751
+ time_embed_dim = model_channels * 4
752
+ self.camera_embed = nn.Sequential(
753
+ nn.Linear(camera_dim, time_embed_dim),
754
+ nn.SiLU(),
755
+ nn.Linear(time_embed_dim, time_embed_dim),
756
+ )
757
+
758
+ if self.num_classes is not None:
759
+ if isinstance(self.num_classes, int):
760
+ self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
761
+ elif self.num_classes == "continuous":
762
+ # print("setting up linear c_adm embedding layer")
763
+ self.label_emb = nn.Linear(1, time_embed_dim)
764
+ elif self.num_classes == "sequential":
765
+ assert adm_in_channels is not None
766
+ self.label_emb = nn.Sequential(
767
+ nn.Sequential(
768
+ nn.Linear(adm_in_channels, time_embed_dim),
769
+ nn.SiLU(),
770
+ nn.Linear(time_embed_dim, time_embed_dim),
771
+ )
772
+ )
773
+ else:
774
+ raise ValueError()
775
+
776
+ self.input_blocks = nn.ModuleList(
777
+ [
778
+ CondSequential(
779
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
780
+ )
781
+ ]
782
+ )
783
+ self._feature_size = model_channels
784
+ input_block_chans = [model_channels]
785
+ ch = model_channels
786
+ ds = 1
787
+ for level, mult in enumerate(channel_mult):
788
+ for nr in range(self.num_res_blocks[level]):
789
+ layers: List[Any] = [
790
+ ResBlock(
791
+ ch,
792
+ time_embed_dim,
793
+ dropout,
794
+ out_channels=mult * model_channels,
795
+ dims=dims,
796
+ use_scale_shift_norm=use_scale_shift_norm,
797
+ )
798
+ ]
799
+ ch = mult * model_channels
800
+ if ds in attention_resolutions:
801
+ if num_head_channels == -1:
802
+ dim_head = ch // num_heads
803
+ else:
804
+ num_heads = ch // num_head_channels
805
+ dim_head = num_head_channels
806
+
807
+ if num_attention_blocks is None or nr < num_attention_blocks[level]:
808
+ layers.append(
809
+ SpatialTransformer3D(
810
+ ch,
811
+ num_heads,
812
+ dim_head,
813
+ context_dim=context_dim,
814
+ depth=transformer_depth,
815
+ ip_dim=self.ip_dim,
816
+ ip_weight=self.ip_weight,
817
+ )
818
+ )
819
+ self.input_blocks.append(CondSequential(*layers))
820
+ self._feature_size += ch
821
+ input_block_chans.append(ch)
822
+ if level != len(channel_mult) - 1:
823
+ out_ch = ch
824
+ self.input_blocks.append(
825
+ CondSequential(
826
+ ResBlock(
827
+ ch,
828
+ time_embed_dim,
829
+ dropout,
830
+ out_channels=out_ch,
831
+ dims=dims,
832
+ use_scale_shift_norm=use_scale_shift_norm,
833
+ down=True,
834
+ )
835
+ if resblock_updown
836
+ else Downsample(
837
+ ch, conv_resample, dims=dims, out_channels=out_ch
838
+ )
839
+ )
840
+ )
841
+ ch = out_ch
842
+ input_block_chans.append(ch)
843
+ ds *= 2
844
+ self._feature_size += ch
845
+
846
+ if num_head_channels == -1:
847
+ dim_head = ch // num_heads
848
+ else:
849
+ num_heads = ch // num_head_channels
850
+ dim_head = num_head_channels
851
+
852
+ self.middle_block = CondSequential(
853
+ ResBlock(
854
+ ch,
855
+ time_embed_dim,
856
+ dropout,
857
+ dims=dims,
858
+ use_scale_shift_norm=use_scale_shift_norm,
859
+ ),
860
+ SpatialTransformer3D(
861
+ ch,
862
+ num_heads,
863
+ dim_head,
864
+ context_dim=context_dim,
865
+ depth=transformer_depth,
866
+ ip_dim=self.ip_dim,
867
+ ip_weight=self.ip_weight,
868
+ ),
869
+ ResBlock(
870
+ ch,
871
+ time_embed_dim,
872
+ dropout,
873
+ dims=dims,
874
+ use_scale_shift_norm=use_scale_shift_norm,
875
+ ),
876
+ )
877
+ self._feature_size += ch
878
+
879
+ self.output_blocks = nn.ModuleList([])
880
+ for level, mult in list(enumerate(channel_mult))[::-1]:
881
+ for i in range(self.num_res_blocks[level] + 1):
882
+ ich = input_block_chans.pop()
883
+ layers = [
884
+ ResBlock(
885
+ ch + ich,
886
+ time_embed_dim,
887
+ dropout,
888
+ out_channels=model_channels * mult,
889
+ dims=dims,
890
+ use_scale_shift_norm=use_scale_shift_norm,
891
+ )
892
+ ]
893
+ ch = model_channels * mult
894
+ if ds in attention_resolutions:
895
+ if num_head_channels == -1:
896
+ dim_head = ch // num_heads
897
+ else:
898
+ num_heads = ch // num_head_channels
899
+ dim_head = num_head_channels
900
+
901
+ if num_attention_blocks is None or i < num_attention_blocks[level]:
902
+ layers.append(
903
+ SpatialTransformer3D(
904
+ ch,
905
+ num_heads,
906
+ dim_head,
907
+ context_dim=context_dim,
908
+ depth=transformer_depth,
909
+ ip_dim=self.ip_dim,
910
+ ip_weight=self.ip_weight,
911
+ )
912
+ )
913
+ if level and i == self.num_res_blocks[level]:
914
+ out_ch = ch
915
+ layers.append(
916
+ ResBlock(
917
+ ch,
918
+ time_embed_dim,
919
+ dropout,
920
+ out_channels=out_ch,
921
+ dims=dims,
922
+ use_scale_shift_norm=use_scale_shift_norm,
923
+ up=True,
924
+ )
925
+ if resblock_updown
926
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
927
+ )
928
+ ds //= 2
929
+ self.output_blocks.append(CondSequential(*layers))
930
+ self._feature_size += ch
931
+
932
+ self.out = nn.Sequential(
933
+ nn.GroupNorm(32, ch),
934
+ nn.SiLU(),
935
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
936
+ )
937
+ if self.predict_codebook_ids:
938
+ self.id_predictor = nn.Sequential(
939
+ nn.GroupNorm(32, ch),
940
+ conv_nd(dims, model_channels, n_embed, 1),
941
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
942
+ )
943
+
944
+ def forward(
945
+ self,
946
+ x,
947
+ timesteps=None,
948
+ context=None,
949
+ y=None,
950
+ camera=None,
951
+ num_frames=1,
952
+ ip=None,
953
+ ip_img=None,
954
+ **kwargs,
955
+ ):
956
+ """
957
+ Apply the model to an input batch.
958
+ :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
959
+ :param timesteps: a 1-D batch of timesteps.
960
+ :param context: conditioning plugged in via crossattn
961
+ :param y: an [N] Tensor of labels, if class-conditional.
962
+ :param num_frames: a integer indicating number of frames for tensor reshaping.
963
+ :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
964
+ """
965
+ assert (
966
+ x.shape[0] % num_frames == 0
967
+ ), "input batch size must be dividable by num_frames!"
968
+ assert (y is not None) == (
969
+ self.num_classes is not None
970
+ ), "must specify y if and only if the model is class-conditional"
971
+
972
+ hs = []
973
+
974
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
975
+
976
+ emb = self.time_embed(t_emb)
977
+
978
+ if self.num_classes is not None:
979
+ assert y is not None
980
+ assert y.shape[0] == x.shape[0]
981
+ emb = emb + self.label_emb(y)
982
+
983
+ # Add camera embeddings
984
+ if camera is not None:
985
+ emb = emb + self.camera_embed(camera)
986
+
987
+ # imagedream variant
988
+ if self.ip_dim > 0 and ip:
989
+ x[(num_frames - 1) :: num_frames, :, :, :] = ip_img # place at [4, 9]
990
+ ip_emb = self.image_embed(ip)
991
+ context = torch.cat((context, ip_emb), 1)
992
+
993
+ h = x
994
+ for module in self.input_blocks:
995
+ h = module(h, emb, context, num_frames=num_frames)
996
+ hs.append(h)
997
+ h = self.middle_block(h, emb, context, num_frames=num_frames)
998
+ for module in self.output_blocks:
999
+ h = torch.cat([h, hs.pop()], dim=1)
1000
+ h = module(h, emb, context, num_frames=num_frames)
1001
+ h = h.type(x.dtype)
1002
+ if self.predict_codebook_ids:
1003
+ return self.id_predictor(h)
1004
+ else:
1005
  return self.out(h)