dylanebert HF staff commited on
Commit
a3fbb44
1 Parent(s): 1144e46

embed mv_unet

Browse files
Files changed (2) hide show
  1. mv_unet.py +0 -1005
  2. pipeline.py +1064 -25
mv_unet.py DELETED
@@ -1,1005 +0,0 @@
1
- import math
2
- import numpy as np
3
- from inspect import isfunction
4
- from typing import Optional, Any, List
5
-
6
- import torch
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
- from einops import rearrange, repeat
10
-
11
- from diffusers.configuration_utils import ConfigMixin
12
- from diffusers.models.modeling_utils import ModelMixin
13
-
14
- # require xformers!
15
- import xformers
16
- import xformers.ops
17
-
18
- from kiui.cam import orbit_camera
19
-
20
- def get_camera(
21
- num_frames, elevation=15, azimuth_start=0, azimuth_span=360, blender_coord=True, extra_view=False,
22
- ):
23
- angle_gap = azimuth_span / num_frames
24
- cameras = []
25
- for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):
26
-
27
- pose = orbit_camera(-elevation, azimuth, radius=1) # kiui's elevation is negated, [4, 4]
28
-
29
- # opengl to blender
30
- if blender_coord:
31
- pose[2] *= -1
32
- pose[[1, 2]] = pose[[2, 1]]
33
-
34
- cameras.append(pose.flatten())
35
-
36
- if extra_view:
37
- cameras.append(np.zeros_like(cameras[0]))
38
-
39
- return torch.from_numpy(np.stack(cameras, axis=0)).float() # [num_frames, 16]
40
-
41
-
42
- def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
43
- """
44
- Create sinusoidal timestep embeddings.
45
- :param timesteps: a 1-D Tensor of N indices, one per batch element.
46
- These may be fractional.
47
- :param dim: the dimension of the output.
48
- :param max_period: controls the minimum frequency of the embeddings.
49
- :return: an [N x dim] Tensor of positional embeddings.
50
- """
51
- if not repeat_only:
52
- half = dim // 2
53
- freqs = torch.exp(
54
- -math.log(max_period)
55
- * torch.arange(start=0, end=half, dtype=torch.float32)
56
- / half
57
- ).to(device=timesteps.device)
58
- args = timesteps[:, None] * freqs[None]
59
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
60
- if dim % 2:
61
- embedding = torch.cat(
62
- [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
63
- )
64
- else:
65
- embedding = repeat(timesteps, "b -> b d", d=dim)
66
- # import pdb; pdb.set_trace()
67
- return embedding
68
-
69
-
70
- def zero_module(module):
71
- """
72
- Zero out the parameters of a module and return it.
73
- """
74
- for p in module.parameters():
75
- p.detach().zero_()
76
- return module
77
-
78
-
79
- def conv_nd(dims, *args, **kwargs):
80
- """
81
- Create a 1D, 2D, or 3D convolution module.
82
- """
83
- if dims == 1:
84
- return nn.Conv1d(*args, **kwargs)
85
- elif dims == 2:
86
- return nn.Conv2d(*args, **kwargs)
87
- elif dims == 3:
88
- return nn.Conv3d(*args, **kwargs)
89
- raise ValueError(f"unsupported dimensions: {dims}")
90
-
91
-
92
- def avg_pool_nd(dims, *args, **kwargs):
93
- """
94
- Create a 1D, 2D, or 3D average pooling module.
95
- """
96
- if dims == 1:
97
- return nn.AvgPool1d(*args, **kwargs)
98
- elif dims == 2:
99
- return nn.AvgPool2d(*args, **kwargs)
100
- elif dims == 3:
101
- return nn.AvgPool3d(*args, **kwargs)
102
- raise ValueError(f"unsupported dimensions: {dims}")
103
-
104
-
105
- def default(val, d):
106
- if val is not None:
107
- return val
108
- return d() if isfunction(d) else d
109
-
110
-
111
- class GEGLU(nn.Module):
112
- def __init__(self, dim_in, dim_out):
113
- super().__init__()
114
- self.proj = nn.Linear(dim_in, dim_out * 2)
115
-
116
- def forward(self, x):
117
- x, gate = self.proj(x).chunk(2, dim=-1)
118
- return x * F.gelu(gate)
119
-
120
-
121
- class FeedForward(nn.Module):
122
- def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
123
- super().__init__()
124
- inner_dim = int(dim * mult)
125
- dim_out = default(dim_out, dim)
126
- project_in = (
127
- nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
128
- if not glu
129
- else GEGLU(dim, inner_dim)
130
- )
131
-
132
- self.net = nn.Sequential(
133
- project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
134
- )
135
-
136
- def forward(self, x):
137
- return self.net(x)
138
-
139
-
140
- class MemoryEfficientCrossAttention(nn.Module):
141
- # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
142
- def __init__(
143
- self,
144
- query_dim,
145
- context_dim=None,
146
- heads=8,
147
- dim_head=64,
148
- dropout=0.0,
149
- ip_dim=0,
150
- ip_weight=1,
151
- ):
152
- super().__init__()
153
-
154
- inner_dim = dim_head * heads
155
- context_dim = default(context_dim, query_dim)
156
-
157
- self.heads = heads
158
- self.dim_head = dim_head
159
-
160
- self.ip_dim = ip_dim
161
- self.ip_weight = ip_weight
162
-
163
- if self.ip_dim > 0:
164
- self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
165
- self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
166
-
167
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
168
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
169
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
170
-
171
- self.to_out = nn.Sequential(
172
- nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
173
- )
174
- self.attention_op: Optional[Any] = None
175
-
176
- def forward(self, x, context=None):
177
- q = self.to_q(x)
178
- context = default(context, x)
179
-
180
- if self.ip_dim > 0:
181
- # context: [B, 77 + 16(ip), 1024]
182
- token_len = context.shape[1]
183
- context_ip = context[:, -self.ip_dim :, :]
184
- k_ip = self.to_k_ip(context_ip)
185
- v_ip = self.to_v_ip(context_ip)
186
- context = context[:, : (token_len - self.ip_dim), :]
187
-
188
- k = self.to_k(context)
189
- v = self.to_v(context)
190
-
191
- b, _, _ = q.shape
192
- q, k, v = map(
193
- lambda t: t.unsqueeze(3)
194
- .reshape(b, t.shape[1], self.heads, self.dim_head)
195
- .permute(0, 2, 1, 3)
196
- .reshape(b * self.heads, t.shape[1], self.dim_head)
197
- .contiguous(),
198
- (q, k, v),
199
- )
200
-
201
- # actually compute the attention, what we cannot get enough of
202
- out = xformers.ops.memory_efficient_attention(
203
- q, k, v, attn_bias=None, op=self.attention_op
204
- )
205
-
206
- if self.ip_dim > 0:
207
- k_ip, v_ip = map(
208
- lambda t: t.unsqueeze(3)
209
- .reshape(b, t.shape[1], self.heads, self.dim_head)
210
- .permute(0, 2, 1, 3)
211
- .reshape(b * self.heads, t.shape[1], self.dim_head)
212
- .contiguous(),
213
- (k_ip, v_ip),
214
- )
215
- # actually compute the attention, what we cannot get enough of
216
- out_ip = xformers.ops.memory_efficient_attention(
217
- q, k_ip, v_ip, attn_bias=None, op=self.attention_op
218
- )
219
- out = out + self.ip_weight * out_ip
220
-
221
- out = (
222
- out.unsqueeze(0)
223
- .reshape(b, self.heads, out.shape[1], self.dim_head)
224
- .permute(0, 2, 1, 3)
225
- .reshape(b, out.shape[1], self.heads * self.dim_head)
226
- )
227
- return self.to_out(out)
228
-
229
-
230
- class BasicTransformerBlock3D(nn.Module):
231
-
232
- def __init__(
233
- self,
234
- dim,
235
- n_heads,
236
- d_head,
237
- context_dim,
238
- dropout=0.0,
239
- gated_ff=True,
240
- ip_dim=0,
241
- ip_weight=1,
242
- ):
243
- super().__init__()
244
-
245
- self.attn1 = MemoryEfficientCrossAttention(
246
- query_dim=dim,
247
- context_dim=None, # self-attention
248
- heads=n_heads,
249
- dim_head=d_head,
250
- dropout=dropout,
251
- )
252
- self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
253
- self.attn2 = MemoryEfficientCrossAttention(
254
- query_dim=dim,
255
- context_dim=context_dim,
256
- heads=n_heads,
257
- dim_head=d_head,
258
- dropout=dropout,
259
- # ip only applies to cross-attention
260
- ip_dim=ip_dim,
261
- ip_weight=ip_weight,
262
- )
263
- self.norm1 = nn.LayerNorm(dim)
264
- self.norm2 = nn.LayerNorm(dim)
265
- self.norm3 = nn.LayerNorm(dim)
266
-
267
- def forward(self, x, context=None, num_frames=1):
268
- x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
269
- x = self.attn1(self.norm1(x), context=None) + x
270
- x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
271
- x = self.attn2(self.norm2(x), context=context) + x
272
- x = self.ff(self.norm3(x)) + x
273
- return x
274
-
275
-
276
- class SpatialTransformer3D(nn.Module):
277
-
278
- def __init__(
279
- self,
280
- in_channels,
281
- n_heads,
282
- d_head,
283
- context_dim, # cross attention input dim
284
- depth=1,
285
- dropout=0.0,
286
- ip_dim=0,
287
- ip_weight=1,
288
- ):
289
- super().__init__()
290
-
291
- if not isinstance(context_dim, list):
292
- context_dim = [context_dim]
293
-
294
- self.in_channels = in_channels
295
-
296
- inner_dim = n_heads * d_head
297
- self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
298
- self.proj_in = nn.Linear(in_channels, inner_dim)
299
-
300
- self.transformer_blocks = nn.ModuleList(
301
- [
302
- BasicTransformerBlock3D(
303
- inner_dim,
304
- n_heads,
305
- d_head,
306
- context_dim=context_dim[d],
307
- dropout=dropout,
308
- ip_dim=ip_dim,
309
- ip_weight=ip_weight,
310
- )
311
- for d in range(depth)
312
- ]
313
- )
314
-
315
- self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
316
-
317
-
318
- def forward(self, x, context=None, num_frames=1):
319
- # note: if no context is given, cross-attention defaults to self-attention
320
- if not isinstance(context, list):
321
- context = [context]
322
- b, c, h, w = x.shape
323
- x_in = x
324
- x = self.norm(x)
325
- x = rearrange(x, "b c h w -> b (h w) c").contiguous()
326
- x = self.proj_in(x)
327
- for i, block in enumerate(self.transformer_blocks):
328
- x = block(x, context=context[i], num_frames=num_frames)
329
- x = self.proj_out(x)
330
- x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
331
-
332
- return x + x_in
333
-
334
-
335
- class PerceiverAttention(nn.Module):
336
- def __init__(self, *, dim, dim_head=64, heads=8):
337
- super().__init__()
338
- self.scale = dim_head ** -0.5
339
- self.dim_head = dim_head
340
- self.heads = heads
341
- inner_dim = dim_head * heads
342
-
343
- self.norm1 = nn.LayerNorm(dim)
344
- self.norm2 = nn.LayerNorm(dim)
345
-
346
- self.to_q = nn.Linear(dim, inner_dim, bias=False)
347
- self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
348
- self.to_out = nn.Linear(inner_dim, dim, bias=False)
349
-
350
- def forward(self, x, latents):
351
- """
352
- Args:
353
- x (torch.Tensor): image features
354
- shape (b, n1, D)
355
- latent (torch.Tensor): latent features
356
- shape (b, n2, D)
357
- """
358
- x = self.norm1(x)
359
- latents = self.norm2(latents)
360
-
361
- b, l, _ = latents.shape
362
-
363
- q = self.to_q(latents)
364
- kv_input = torch.cat((x, latents), dim=-2)
365
- k, v = self.to_kv(kv_input).chunk(2, dim=-1)
366
-
367
- q, k, v = map(
368
- lambda t: t.reshape(b, t.shape[1], self.heads, -1)
369
- .transpose(1, 2)
370
- .reshape(b, self.heads, t.shape[1], -1)
371
- .contiguous(),
372
- (q, k, v),
373
- )
374
-
375
- # attention
376
- scale = 1 / math.sqrt(math.sqrt(self.dim_head))
377
- weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
378
- weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
379
- out = weight @ v
380
-
381
- out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
382
-
383
- return self.to_out(out)
384
-
385
-
386
- class Resampler(nn.Module):
387
- def __init__(
388
- self,
389
- dim=1024,
390
- depth=8,
391
- dim_head=64,
392
- heads=16,
393
- num_queries=8,
394
- embedding_dim=768,
395
- output_dim=1024,
396
- ff_mult=4,
397
- ):
398
- super().__init__()
399
- self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
400
- self.proj_in = nn.Linear(embedding_dim, dim)
401
- self.proj_out = nn.Linear(dim, output_dim)
402
- self.norm_out = nn.LayerNorm(output_dim)
403
-
404
- self.layers = nn.ModuleList([])
405
- for _ in range(depth):
406
- self.layers.append(
407
- nn.ModuleList(
408
- [
409
- PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
410
- nn.Sequential(
411
- nn.LayerNorm(dim),
412
- nn.Linear(dim, dim * ff_mult, bias=False),
413
- nn.GELU(),
414
- nn.Linear(dim * ff_mult, dim, bias=False),
415
- )
416
- ]
417
- )
418
- )
419
-
420
- def forward(self, x):
421
- latents = self.latents.repeat(x.size(0), 1, 1)
422
- x = self.proj_in(x)
423
- for attn, ff in self.layers:
424
- latents = attn(x, latents) + latents
425
- latents = ff(latents) + latents
426
-
427
- latents = self.proj_out(latents)
428
- return self.norm_out(latents)
429
-
430
-
431
- class CondSequential(nn.Sequential):
432
- """
433
- A sequential module that passes timestep embeddings to the children that
434
- support it as an extra input.
435
- """
436
-
437
- def forward(self, x, emb, context=None, num_frames=1):
438
- for layer in self:
439
- if isinstance(layer, ResBlock):
440
- x = layer(x, emb)
441
- elif isinstance(layer, SpatialTransformer3D):
442
- x = layer(x, context, num_frames=num_frames)
443
- else:
444
- x = layer(x)
445
- return x
446
-
447
-
448
- class Upsample(nn.Module):
449
- """
450
- An upsampling layer with an optional convolution.
451
- :param channels: channels in the inputs and outputs.
452
- :param use_conv: a bool determining if a convolution is applied.
453
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
454
- upsampling occurs in the inner-two dimensions.
455
- """
456
-
457
- def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
458
- super().__init__()
459
- self.channels = channels
460
- self.out_channels = out_channels or channels
461
- self.use_conv = use_conv
462
- self.dims = dims
463
- if use_conv:
464
- self.conv = conv_nd(
465
- dims, self.channels, self.out_channels, 3, padding=padding
466
- )
467
-
468
- def forward(self, x):
469
- assert x.shape[1] == self.channels
470
- if self.dims == 3:
471
- x = F.interpolate(
472
- x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
473
- )
474
- else:
475
- x = F.interpolate(x, scale_factor=2, mode="nearest")
476
- if self.use_conv:
477
- x = self.conv(x)
478
- return x
479
-
480
-
481
- class Downsample(nn.Module):
482
- """
483
- A downsampling layer with an optional convolution.
484
- :param channels: channels in the inputs and outputs.
485
- :param use_conv: a bool determining if a convolution is applied.
486
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
487
- downsampling occurs in the inner-two dimensions.
488
- """
489
-
490
- def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
491
- super().__init__()
492
- self.channels = channels
493
- self.out_channels = out_channels or channels
494
- self.use_conv = use_conv
495
- self.dims = dims
496
- stride = 2 if dims != 3 else (1, 2, 2)
497
- if use_conv:
498
- self.op = conv_nd(
499
- dims,
500
- self.channels,
501
- self.out_channels,
502
- 3,
503
- stride=stride,
504
- padding=padding,
505
- )
506
- else:
507
- assert self.channels == self.out_channels
508
- self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
509
-
510
- def forward(self, x):
511
- assert x.shape[1] == self.channels
512
- return self.op(x)
513
-
514
-
515
- class ResBlock(nn.Module):
516
- """
517
- A residual block that can optionally change the number of channels.
518
- :param channels: the number of input channels.
519
- :param emb_channels: the number of timestep embedding channels.
520
- :param dropout: the rate of dropout.
521
- :param out_channels: if specified, the number of out channels.
522
- :param use_conv: if True and out_channels is specified, use a spatial
523
- convolution instead of a smaller 1x1 convolution to change the
524
- channels in the skip connection.
525
- :param dims: determines if the signal is 1D, 2D, or 3D.
526
- :param up: if True, use this block for upsampling.
527
- :param down: if True, use this block for downsampling.
528
- """
529
-
530
- def __init__(
531
- self,
532
- channels,
533
- emb_channels,
534
- dropout,
535
- out_channels=None,
536
- use_conv=False,
537
- use_scale_shift_norm=False,
538
- dims=2,
539
- up=False,
540
- down=False,
541
- ):
542
- super().__init__()
543
- self.channels = channels
544
- self.emb_channels = emb_channels
545
- self.dropout = dropout
546
- self.out_channels = out_channels or channels
547
- self.use_conv = use_conv
548
- self.use_scale_shift_norm = use_scale_shift_norm
549
-
550
- self.in_layers = nn.Sequential(
551
- nn.GroupNorm(32, channels),
552
- nn.SiLU(),
553
- conv_nd(dims, channels, self.out_channels, 3, padding=1),
554
- )
555
-
556
- self.updown = up or down
557
-
558
- if up:
559
- self.h_upd = Upsample(channels, False, dims)
560
- self.x_upd = Upsample(channels, False, dims)
561
- elif down:
562
- self.h_upd = Downsample(channels, False, dims)
563
- self.x_upd = Downsample(channels, False, dims)
564
- else:
565
- self.h_upd = self.x_upd = nn.Identity()
566
-
567
- self.emb_layers = nn.Sequential(
568
- nn.SiLU(),
569
- nn.Linear(
570
- emb_channels,
571
- 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
572
- ),
573
- )
574
- self.out_layers = nn.Sequential(
575
- nn.GroupNorm(32, self.out_channels),
576
- nn.SiLU(),
577
- nn.Dropout(p=dropout),
578
- zero_module(
579
- conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
580
- ),
581
- )
582
-
583
- if self.out_channels == channels:
584
- self.skip_connection = nn.Identity()
585
- elif use_conv:
586
- self.skip_connection = conv_nd(
587
- dims, channels, self.out_channels, 3, padding=1
588
- )
589
- else:
590
- self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
591
-
592
- def forward(self, x, emb):
593
- if self.updown:
594
- in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
595
- h = in_rest(x)
596
- h = self.h_upd(h)
597
- x = self.x_upd(x)
598
- h = in_conv(h)
599
- else:
600
- h = self.in_layers(x)
601
- emb_out = self.emb_layers(emb).type(h.dtype)
602
- while len(emb_out.shape) < len(h.shape):
603
- emb_out = emb_out[..., None]
604
- if self.use_scale_shift_norm:
605
- out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
606
- scale, shift = torch.chunk(emb_out, 2, dim=1)
607
- h = out_norm(h) * (1 + scale) + shift
608
- h = out_rest(h)
609
- else:
610
- h = h + emb_out
611
- h = self.out_layers(h)
612
- return self.skip_connection(x) + h
613
-
614
-
615
- class MultiViewUNetModel(ModelMixin, ConfigMixin):
616
- """
617
- The full multi-view UNet model with attention, timestep embedding and camera embedding.
618
- :param in_channels: channels in the input Tensor.
619
- :param model_channels: base channel count for the model.
620
- :param out_channels: channels in the output Tensor.
621
- :param num_res_blocks: number of residual blocks per downsample.
622
- :param attention_resolutions: a collection of downsample rates at which
623
- attention will take place. May be a set, list, or tuple.
624
- For example, if this contains 4, then at 4x downsampling, attention
625
- will be used.
626
- :param dropout: the dropout probability.
627
- :param channel_mult: channel multiplier for each level of the UNet.
628
- :param conv_resample: if True, use learned convolutions for upsampling and
629
- downsampling.
630
- :param dims: determines if the signal is 1D, 2D, or 3D.
631
- :param num_classes: if specified (as an int), then this model will be
632
- class-conditional with `num_classes` classes.
633
- :param num_heads: the number of attention heads in each attention layer.
634
- :param num_heads_channels: if specified, ignore num_heads and instead use
635
- a fixed channel width per attention head.
636
- :param num_heads_upsample: works with num_heads to set a different number
637
- of heads for upsampling. Deprecated.
638
- :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
639
- :param resblock_updown: use residual blocks for up/downsampling.
640
- :param use_new_attention_order: use a different attention pattern for potentially
641
- increased efficiency.
642
- :param camera_dim: dimensionality of camera input.
643
- """
644
-
645
- def __init__(
646
- self,
647
- image_size,
648
- in_channels,
649
- model_channels,
650
- out_channels,
651
- num_res_blocks,
652
- attention_resolutions,
653
- dropout=0,
654
- channel_mult=(1, 2, 4, 8),
655
- conv_resample=True,
656
- dims=2,
657
- num_classes=None,
658
- num_heads=-1,
659
- num_head_channels=-1,
660
- num_heads_upsample=-1,
661
- use_scale_shift_norm=False,
662
- resblock_updown=False,
663
- transformer_depth=1,
664
- context_dim=None,
665
- n_embed=None,
666
- num_attention_blocks=None,
667
- adm_in_channels=None,
668
- camera_dim=None,
669
- ip_dim=0, # imagedream uses ip_dim > 0
670
- ip_weight=1.0,
671
- **kwargs,
672
- ):
673
- super().__init__()
674
- assert context_dim is not None
675
-
676
- if num_heads_upsample == -1:
677
- num_heads_upsample = num_heads
678
-
679
- if num_heads == -1:
680
- assert (
681
- num_head_channels != -1
682
- ), "Either num_heads or num_head_channels has to be set"
683
-
684
- if num_head_channels == -1:
685
- assert (
686
- num_heads != -1
687
- ), "Either num_heads or num_head_channels has to be set"
688
-
689
- self.image_size = image_size
690
- self.in_channels = in_channels
691
- self.model_channels = model_channels
692
- self.out_channels = out_channels
693
- if isinstance(num_res_blocks, int):
694
- self.num_res_blocks = len(channel_mult) * [num_res_blocks]
695
- else:
696
- if len(num_res_blocks) != len(channel_mult):
697
- raise ValueError(
698
- "provide num_res_blocks either as an int (globally constant) or "
699
- "as a list/tuple (per-level) with the same length as channel_mult"
700
- )
701
- self.num_res_blocks = num_res_blocks
702
-
703
- if num_attention_blocks is not None:
704
- assert len(num_attention_blocks) == len(self.num_res_blocks)
705
- assert all(
706
- map(
707
- lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
708
- range(len(num_attention_blocks)),
709
- )
710
- )
711
- print(
712
- f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
713
- f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
714
- f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
715
- f"attention will still not be set."
716
- )
717
-
718
- self.attention_resolutions = attention_resolutions
719
- self.dropout = dropout
720
- self.channel_mult = channel_mult
721
- self.conv_resample = conv_resample
722
- self.num_classes = num_classes
723
- self.num_heads = num_heads
724
- self.num_head_channels = num_head_channels
725
- self.num_heads_upsample = num_heads_upsample
726
- self.predict_codebook_ids = n_embed is not None
727
-
728
- self.ip_dim = ip_dim
729
- self.ip_weight = ip_weight
730
-
731
- if self.ip_dim > 0:
732
- self.image_embed = Resampler(
733
- dim=context_dim,
734
- depth=4,
735
- dim_head=64,
736
- heads=12,
737
- num_queries=ip_dim, # num token
738
- embedding_dim=1280,
739
- output_dim=context_dim,
740
- ff_mult=4,
741
- )
742
-
743
- time_embed_dim = model_channels * 4
744
- self.time_embed = nn.Sequential(
745
- nn.Linear(model_channels, time_embed_dim),
746
- nn.SiLU(),
747
- nn.Linear(time_embed_dim, time_embed_dim),
748
- )
749
-
750
- if camera_dim is not None:
751
- time_embed_dim = model_channels * 4
752
- self.camera_embed = nn.Sequential(
753
- nn.Linear(camera_dim, time_embed_dim),
754
- nn.SiLU(),
755
- nn.Linear(time_embed_dim, time_embed_dim),
756
- )
757
-
758
- if self.num_classes is not None:
759
- if isinstance(self.num_classes, int):
760
- self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
761
- elif self.num_classes == "continuous":
762
- # print("setting up linear c_adm embedding layer")
763
- self.label_emb = nn.Linear(1, time_embed_dim)
764
- elif self.num_classes == "sequential":
765
- assert adm_in_channels is not None
766
- self.label_emb = nn.Sequential(
767
- nn.Sequential(
768
- nn.Linear(adm_in_channels, time_embed_dim),
769
- nn.SiLU(),
770
- nn.Linear(time_embed_dim, time_embed_dim),
771
- )
772
- )
773
- else:
774
- raise ValueError()
775
-
776
- self.input_blocks = nn.ModuleList(
777
- [
778
- CondSequential(
779
- conv_nd(dims, in_channels, model_channels, 3, padding=1)
780
- )
781
- ]
782
- )
783
- self._feature_size = model_channels
784
- input_block_chans = [model_channels]
785
- ch = model_channels
786
- ds = 1
787
- for level, mult in enumerate(channel_mult):
788
- for nr in range(self.num_res_blocks[level]):
789
- layers: List[Any] = [
790
- ResBlock(
791
- ch,
792
- time_embed_dim,
793
- dropout,
794
- out_channels=mult * model_channels,
795
- dims=dims,
796
- use_scale_shift_norm=use_scale_shift_norm,
797
- )
798
- ]
799
- ch = mult * model_channels
800
- if ds in attention_resolutions:
801
- if num_head_channels == -1:
802
- dim_head = ch // num_heads
803
- else:
804
- num_heads = ch // num_head_channels
805
- dim_head = num_head_channels
806
-
807
- if num_attention_blocks is None or nr < num_attention_blocks[level]:
808
- layers.append(
809
- SpatialTransformer3D(
810
- ch,
811
- num_heads,
812
- dim_head,
813
- context_dim=context_dim,
814
- depth=transformer_depth,
815
- ip_dim=self.ip_dim,
816
- ip_weight=self.ip_weight,
817
- )
818
- )
819
- self.input_blocks.append(CondSequential(*layers))
820
- self._feature_size += ch
821
- input_block_chans.append(ch)
822
- if level != len(channel_mult) - 1:
823
- out_ch = ch
824
- self.input_blocks.append(
825
- CondSequential(
826
- ResBlock(
827
- ch,
828
- time_embed_dim,
829
- dropout,
830
- out_channels=out_ch,
831
- dims=dims,
832
- use_scale_shift_norm=use_scale_shift_norm,
833
- down=True,
834
- )
835
- if resblock_updown
836
- else Downsample(
837
- ch, conv_resample, dims=dims, out_channels=out_ch
838
- )
839
- )
840
- )
841
- ch = out_ch
842
- input_block_chans.append(ch)
843
- ds *= 2
844
- self._feature_size += ch
845
-
846
- if num_head_channels == -1:
847
- dim_head = ch // num_heads
848
- else:
849
- num_heads = ch // num_head_channels
850
- dim_head = num_head_channels
851
-
852
- self.middle_block = CondSequential(
853
- ResBlock(
854
- ch,
855
- time_embed_dim,
856
- dropout,
857
- dims=dims,
858
- use_scale_shift_norm=use_scale_shift_norm,
859
- ),
860
- SpatialTransformer3D(
861
- ch,
862
- num_heads,
863
- dim_head,
864
- context_dim=context_dim,
865
- depth=transformer_depth,
866
- ip_dim=self.ip_dim,
867
- ip_weight=self.ip_weight,
868
- ),
869
- ResBlock(
870
- ch,
871
- time_embed_dim,
872
- dropout,
873
- dims=dims,
874
- use_scale_shift_norm=use_scale_shift_norm,
875
- ),
876
- )
877
- self._feature_size += ch
878
-
879
- self.output_blocks = nn.ModuleList([])
880
- for level, mult in list(enumerate(channel_mult))[::-1]:
881
- for i in range(self.num_res_blocks[level] + 1):
882
- ich = input_block_chans.pop()
883
- layers = [
884
- ResBlock(
885
- ch + ich,
886
- time_embed_dim,
887
- dropout,
888
- out_channels=model_channels * mult,
889
- dims=dims,
890
- use_scale_shift_norm=use_scale_shift_norm,
891
- )
892
- ]
893
- ch = model_channels * mult
894
- if ds in attention_resolutions:
895
- if num_head_channels == -1:
896
- dim_head = ch // num_heads
897
- else:
898
- num_heads = ch // num_head_channels
899
- dim_head = num_head_channels
900
-
901
- if num_attention_blocks is None or i < num_attention_blocks[level]:
902
- layers.append(
903
- SpatialTransformer3D(
904
- ch,
905
- num_heads,
906
- dim_head,
907
- context_dim=context_dim,
908
- depth=transformer_depth,
909
- ip_dim=self.ip_dim,
910
- ip_weight=self.ip_weight,
911
- )
912
- )
913
- if level and i == self.num_res_blocks[level]:
914
- out_ch = ch
915
- layers.append(
916
- ResBlock(
917
- ch,
918
- time_embed_dim,
919
- dropout,
920
- out_channels=out_ch,
921
- dims=dims,
922
- use_scale_shift_norm=use_scale_shift_norm,
923
- up=True,
924
- )
925
- if resblock_updown
926
- else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
927
- )
928
- ds //= 2
929
- self.output_blocks.append(CondSequential(*layers))
930
- self._feature_size += ch
931
-
932
- self.out = nn.Sequential(
933
- nn.GroupNorm(32, ch),
934
- nn.SiLU(),
935
- zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
936
- )
937
- if self.predict_codebook_ids:
938
- self.id_predictor = nn.Sequential(
939
- nn.GroupNorm(32, ch),
940
- conv_nd(dims, model_channels, n_embed, 1),
941
- # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
942
- )
943
-
944
- def forward(
945
- self,
946
- x,
947
- timesteps=None,
948
- context=None,
949
- y=None,
950
- camera=None,
951
- num_frames=1,
952
- ip=None,
953
- ip_img=None,
954
- **kwargs,
955
- ):
956
- """
957
- Apply the model to an input batch.
958
- :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
959
- :param timesteps: a 1-D batch of timesteps.
960
- :param context: conditioning plugged in via crossattn
961
- :param y: an [N] Tensor of labels, if class-conditional.
962
- :param num_frames: a integer indicating number of frames for tensor reshaping.
963
- :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
964
- """
965
- assert (
966
- x.shape[0] % num_frames == 0
967
- ), "input batch size must be dividable by num_frames!"
968
- assert (y is not None) == (
969
- self.num_classes is not None
970
- ), "must specify y if and only if the model is class-conditional"
971
-
972
- hs = []
973
-
974
- t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
975
-
976
- emb = self.time_embed(t_emb)
977
-
978
- if self.num_classes is not None:
979
- assert y is not None
980
- assert y.shape[0] == x.shape[0]
981
- emb = emb + self.label_emb(y)
982
-
983
- # Add camera embeddings
984
- if camera is not None:
985
- emb = emb + self.camera_embed(camera)
986
-
987
- # imagedream variant
988
- if self.ip_dim > 0:
989
- x[(num_frames - 1) :: num_frames, :, :, :] = ip_img # place at [4, 9]
990
- ip_emb = self.image_embed(ip)
991
- context = torch.cat((context, ip_emb), 1)
992
-
993
- h = x
994
- for module in self.input_blocks:
995
- h = module(h, emb, context, num_frames=num_frames)
996
- hs.append(h)
997
- h = self.middle_block(h, emb, context, num_frames=num_frames)
998
- for module in self.output_blocks:
999
- h = torch.cat([h, hs.pop()], dim=1)
1000
- h = module(h, emb, context, num_frames=num_frames)
1001
- h = h.type(x.dtype)
1002
- if self.predict_codebook_ids:
1003
- return self.id_predictor(h)
1004
- else:
1005
- return self.out(h)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pipeline.py CHANGED
@@ -2,8 +2,13 @@ import torch
2
  import torch.nn.functional as F
3
  import inspect
4
  import numpy as np
5
- from typing import Callable, List, Optional, Union
6
- from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPImageProcessor
 
 
 
 
 
7
  from diffusers import AutoencoderKL, DiffusionPipeline
8
  from diffusers.utils import (
9
  deprecate,
@@ -15,7 +20,1017 @@ from diffusers.configuration_utils import FrozenDict
15
  from diffusers.schedulers import DDIMScheduler
16
  from diffusers.utils.torch_utils import randn_tensor
17
 
18
- from mv_unet import MultiViewUNetModel, get_camera
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
21
 
@@ -404,26 +1419,30 @@ class MVDreamPipeline(DiffusionPipeline):
404
 
405
  if image.dtype == np.float32:
406
  image = (image * 255).astype(np.uint8)
407
-
408
  image = self.feature_extractor(image, return_tensors="pt").pixel_values
409
  image = image.to(device=device, dtype=dtype)
410
-
411
- image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
 
 
412
  image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
413
 
414
  return torch.zeros_like(image_embeds), image_embeds
415
 
416
  def encode_image_latents(self, image, device, num_images_per_prompt):
417
-
418
  dtype = next(self.image_encoder.parameters()).dtype
419
 
420
- image = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2).to(device=device) # [1, 3, H, W]
 
 
421
  image = 2 * image - 1
422
- image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False)
423
  image = image.to(dtype=dtype)
424
 
425
  posterior = self.vae.encode(image).latent_dist
426
- latents = posterior.sample() * self.vae.config.scaling_factor # [B, C, H, W]
427
  latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
428
 
429
  return torch.zeros_like(latents), latents
@@ -442,7 +1461,7 @@ class MVDreamPipeline(DiffusionPipeline):
442
  num_images_per_prompt: int = 1,
443
  eta: float = 0.0,
444
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
445
- output_type: Optional[str] = "numpy", # pil, numpy, latents
446
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
447
  callback_steps: int = 1,
448
  num_frames: int = 4,
@@ -465,9 +1484,13 @@ class MVDreamPipeline(DiffusionPipeline):
465
  if image is not None:
466
  assert isinstance(image, np.ndarray) and image.dtype == np.float32
467
  self.image_encoder = self.image_encoder.to(device=device)
468
- image_embeds_neg, image_embeds_pos = self.encode_image(image, device, num_images_per_prompt)
469
- image_latents_neg, image_latents_pos = self.encode_image_latents(image, device, num_images_per_prompt)
470
-
 
 
 
 
471
  _prompt_embeds = self._encode_prompt(
472
  prompt=prompt,
473
  device=device,
@@ -491,7 +1514,9 @@ class MVDreamPipeline(DiffusionPipeline):
491
  )
492
 
493
  # Get camera
494
- camera = get_camera(num_frames, elevation=elevation, extra_view=(image is not None)).to(dtype=latents.dtype, device=device)
 
 
495
  camera = camera.repeat_interleave(num_images_per_prompt, dim=0)
496
 
497
  # Prepare extra step kwargs.
@@ -504,20 +1529,34 @@ class MVDreamPipeline(DiffusionPipeline):
504
  # expand the latents if we are doing classifier free guidance
505
  multiplier = 2 if do_classifier_free_guidance else 1
506
  latent_model_input = torch.cat([latents] * multiplier)
507
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
 
508
 
509
  unet_inputs = {
510
- 'x': latent_model_input,
511
- 'timesteps': torch.tensor([t] * actual_num_frames * multiplier, dtype=latent_model_input.dtype, device=device),
512
- 'context': torch.cat([prompt_embeds_neg] * actual_num_frames + [prompt_embeds_pos] * actual_num_frames),
513
- 'num_frames': actual_num_frames,
514
- 'camera': torch.cat([camera] * multiplier),
 
 
 
 
 
 
 
515
  }
516
 
517
  if image is not None:
518
- unet_inputs['ip'] = torch.cat([image_embeds_neg] * actual_num_frames + [image_embeds_pos] * actual_num_frames)
519
- unet_inputs['ip_img'] = torch.cat([image_latents_neg] + [image_latents_pos]) # no repeat
520
-
 
 
 
 
 
521
  # predict the noise residual
522
  noise_pred = self.unet.forward(**unet_inputs)
523
 
@@ -547,7 +1586,7 @@ class MVDreamPipeline(DiffusionPipeline):
547
  elif output_type == "pil":
548
  image = self.decode_latents(latents)
549
  image = self.numpy_to_pil(image)
550
- else: # numpy
551
  image = self.decode_latents(latents)
552
 
553
  # Offload last model to CPU
 
2
  import torch.nn.functional as F
3
  import inspect
4
  import numpy as np
5
+ from typing import Callable, List, Optional, Union, Any
6
+ from transformers import (
7
+ CLIPTextModel,
8
+ CLIPTokenizer,
9
+ CLIPVisionModel,
10
+ CLIPImageProcessor,
11
+ )
12
  from diffusers import AutoencoderKL, DiffusionPipeline
13
  from diffusers.utils import (
14
  deprecate,
 
20
  from diffusers.schedulers import DDIMScheduler
21
  from diffusers.utils.torch_utils import randn_tensor
22
 
23
+ import math
24
+ from inspect import isfunction
25
+
26
+ import torch.nn as nn
27
+ from einops import rearrange, repeat
28
+
29
+ from diffusers.configuration_utils import ConfigMixin
30
+ from diffusers.models.modeling_utils import ModelMixin
31
+
32
+ # require xformers!
33
+ import xformers
34
+ import xformers.ops
35
+
36
+ from kiui.cam import orbit_camera
37
+
38
+
39
+ def get_camera(
40
+ num_frames,
41
+ elevation=15,
42
+ azimuth_start=0,
43
+ azimuth_span=360,
44
+ blender_coord=True,
45
+ extra_view=False,
46
+ ):
47
+ angle_gap = azimuth_span / num_frames
48
+ cameras = []
49
+ for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):
50
+
51
+ pose = orbit_camera(
52
+ -elevation, azimuth, radius=1
53
+ ) # kiui's elevation is negated, [4, 4]
54
+
55
+ # opengl to blender
56
+ if blender_coord:
57
+ pose[2] *= -1
58
+ pose[[1, 2]] = pose[[2, 1]]
59
+
60
+ cameras.append(pose.flatten())
61
+
62
+ if extra_view:
63
+ cameras.append(np.zeros_like(cameras[0]))
64
+
65
+ return torch.from_numpy(np.stack(cameras, axis=0)).float() # [num_frames, 16]
66
+
67
+
68
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
69
+ """
70
+ Create sinusoidal timestep embeddings.
71
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
72
+ These may be fractional.
73
+ :param dim: the dimension of the output.
74
+ :param max_period: controls the minimum frequency of the embeddings.
75
+ :return: an [N x dim] Tensor of positional embeddings.
76
+ """
77
+ if not repeat_only:
78
+ half = dim // 2
79
+ freqs = torch.exp(
80
+ -math.log(max_period)
81
+ * torch.arange(start=0, end=half, dtype=torch.float32)
82
+ / half
83
+ ).to(device=timesteps.device)
84
+ args = timesteps[:, None] * freqs[None]
85
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
86
+ if dim % 2:
87
+ embedding = torch.cat(
88
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
89
+ )
90
+ else:
91
+ embedding = repeat(timesteps, "b -> b d", d=dim)
92
+ # import pdb; pdb.set_trace()
93
+ return embedding
94
+
95
+
96
+ def zero_module(module):
97
+ """
98
+ Zero out the parameters of a module and return it.
99
+ """
100
+ for p in module.parameters():
101
+ p.detach().zero_()
102
+ return module
103
+
104
+
105
+ def conv_nd(dims, *args, **kwargs):
106
+ """
107
+ Create a 1D, 2D, or 3D convolution module.
108
+ """
109
+ if dims == 1:
110
+ return nn.Conv1d(*args, **kwargs)
111
+ elif dims == 2:
112
+ return nn.Conv2d(*args, **kwargs)
113
+ elif dims == 3:
114
+ return nn.Conv3d(*args, **kwargs)
115
+ raise ValueError(f"unsupported dimensions: {dims}")
116
+
117
+
118
+ def avg_pool_nd(dims, *args, **kwargs):
119
+ """
120
+ Create a 1D, 2D, or 3D average pooling module.
121
+ """
122
+ if dims == 1:
123
+ return nn.AvgPool1d(*args, **kwargs)
124
+ elif dims == 2:
125
+ return nn.AvgPool2d(*args, **kwargs)
126
+ elif dims == 3:
127
+ return nn.AvgPool3d(*args, **kwargs)
128
+ raise ValueError(f"unsupported dimensions: {dims}")
129
+
130
+
131
+ def default(val, d):
132
+ if val is not None:
133
+ return val
134
+ return d() if isfunction(d) else d
135
+
136
+
137
+ class GEGLU(nn.Module):
138
+ def __init__(self, dim_in, dim_out):
139
+ super().__init__()
140
+ self.proj = nn.Linear(dim_in, dim_out * 2)
141
+
142
+ def forward(self, x):
143
+ x, gate = self.proj(x).chunk(2, dim=-1)
144
+ return x * F.gelu(gate)
145
+
146
+
147
+ class FeedForward(nn.Module):
148
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
149
+ super().__init__()
150
+ inner_dim = int(dim * mult)
151
+ dim_out = default(dim_out, dim)
152
+ project_in = (
153
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
154
+ if not glu
155
+ else GEGLU(dim, inner_dim)
156
+ )
157
+
158
+ self.net = nn.Sequential(
159
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
160
+ )
161
+
162
+ def forward(self, x):
163
+ return self.net(x)
164
+
165
+
166
+ class MemoryEfficientCrossAttention(nn.Module):
167
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
168
+ def __init__(
169
+ self,
170
+ query_dim,
171
+ context_dim=None,
172
+ heads=8,
173
+ dim_head=64,
174
+ dropout=0.0,
175
+ ip_dim=0,
176
+ ip_weight=1,
177
+ ):
178
+ super().__init__()
179
+
180
+ inner_dim = dim_head * heads
181
+ context_dim = default(context_dim, query_dim)
182
+
183
+ self.heads = heads
184
+ self.dim_head = dim_head
185
+
186
+ self.ip_dim = ip_dim
187
+ self.ip_weight = ip_weight
188
+
189
+ if self.ip_dim > 0:
190
+ self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
191
+ self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
192
+
193
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
194
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
195
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
196
+
197
+ self.to_out = nn.Sequential(
198
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
199
+ )
200
+ self.attention_op: Optional[Any] = None
201
+
202
+ def forward(self, x, context=None):
203
+ q = self.to_q(x)
204
+ context = default(context, x)
205
+
206
+ if self.ip_dim > 0:
207
+ # context: [B, 77 + 16(ip), 1024]
208
+ token_len = context.shape[1]
209
+ context_ip = context[:, -self.ip_dim :, :]
210
+ k_ip = self.to_k_ip(context_ip)
211
+ v_ip = self.to_v_ip(context_ip)
212
+ context = context[:, : (token_len - self.ip_dim), :]
213
+
214
+ k = self.to_k(context)
215
+ v = self.to_v(context)
216
+
217
+ b, _, _ = q.shape
218
+ q, k, v = map(
219
+ lambda t: t.unsqueeze(3)
220
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
221
+ .permute(0, 2, 1, 3)
222
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
223
+ .contiguous(),
224
+ (q, k, v),
225
+ )
226
+
227
+ # actually compute the attention, what we cannot get enough of
228
+ out = xformers.ops.memory_efficient_attention(
229
+ q, k, v, attn_bias=None, op=self.attention_op
230
+ )
231
+
232
+ if self.ip_dim > 0:
233
+ k_ip, v_ip = map(
234
+ lambda t: t.unsqueeze(3)
235
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
236
+ .permute(0, 2, 1, 3)
237
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
238
+ .contiguous(),
239
+ (k_ip, v_ip),
240
+ )
241
+ # actually compute the attention, what we cannot get enough of
242
+ out_ip = xformers.ops.memory_efficient_attention(
243
+ q, k_ip, v_ip, attn_bias=None, op=self.attention_op
244
+ )
245
+ out = out + self.ip_weight * out_ip
246
+
247
+ out = (
248
+ out.unsqueeze(0)
249
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
250
+ .permute(0, 2, 1, 3)
251
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
252
+ )
253
+ return self.to_out(out)
254
+
255
+
256
+ class BasicTransformerBlock3D(nn.Module):
257
+
258
+ def __init__(
259
+ self,
260
+ dim,
261
+ n_heads,
262
+ d_head,
263
+ context_dim,
264
+ dropout=0.0,
265
+ gated_ff=True,
266
+ ip_dim=0,
267
+ ip_weight=1,
268
+ ):
269
+ super().__init__()
270
+
271
+ self.attn1 = MemoryEfficientCrossAttention(
272
+ query_dim=dim,
273
+ context_dim=None, # self-attention
274
+ heads=n_heads,
275
+ dim_head=d_head,
276
+ dropout=dropout,
277
+ )
278
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
279
+ self.attn2 = MemoryEfficientCrossAttention(
280
+ query_dim=dim,
281
+ context_dim=context_dim,
282
+ heads=n_heads,
283
+ dim_head=d_head,
284
+ dropout=dropout,
285
+ # ip only applies to cross-attention
286
+ ip_dim=ip_dim,
287
+ ip_weight=ip_weight,
288
+ )
289
+ self.norm1 = nn.LayerNorm(dim)
290
+ self.norm2 = nn.LayerNorm(dim)
291
+ self.norm3 = nn.LayerNorm(dim)
292
+
293
+ def forward(self, x, context=None, num_frames=1):
294
+ x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
295
+ x = self.attn1(self.norm1(x), context=None) + x
296
+ x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
297
+ x = self.attn2(self.norm2(x), context=context) + x
298
+ x = self.ff(self.norm3(x)) + x
299
+ return x
300
+
301
+
302
+ class SpatialTransformer3D(nn.Module):
303
+
304
+ def __init__(
305
+ self,
306
+ in_channels,
307
+ n_heads,
308
+ d_head,
309
+ context_dim, # cross attention input dim
310
+ depth=1,
311
+ dropout=0.0,
312
+ ip_dim=0,
313
+ ip_weight=1,
314
+ ):
315
+ super().__init__()
316
+
317
+ if not isinstance(context_dim, list):
318
+ context_dim = [context_dim]
319
+
320
+ self.in_channels = in_channels
321
+
322
+ inner_dim = n_heads * d_head
323
+ self.norm = nn.GroupNorm(
324
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
325
+ )
326
+ self.proj_in = nn.Linear(in_channels, inner_dim)
327
+
328
+ self.transformer_blocks = nn.ModuleList(
329
+ [
330
+ BasicTransformerBlock3D(
331
+ inner_dim,
332
+ n_heads,
333
+ d_head,
334
+ context_dim=context_dim[d],
335
+ dropout=dropout,
336
+ ip_dim=ip_dim,
337
+ ip_weight=ip_weight,
338
+ )
339
+ for d in range(depth)
340
+ ]
341
+ )
342
+
343
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
344
+
345
+ def forward(self, x, context=None, num_frames=1):
346
+ # note: if no context is given, cross-attention defaults to self-attention
347
+ if not isinstance(context, list):
348
+ context = [context]
349
+ b, c, h, w = x.shape
350
+ x_in = x
351
+ x = self.norm(x)
352
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
353
+ x = self.proj_in(x)
354
+ for i, block in enumerate(self.transformer_blocks):
355
+ x = block(x, context=context[i], num_frames=num_frames)
356
+ x = self.proj_out(x)
357
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
358
+
359
+ return x + x_in
360
+
361
+
362
+ class PerceiverAttention(nn.Module):
363
+ def __init__(self, *, dim, dim_head=64, heads=8):
364
+ super().__init__()
365
+ self.scale = dim_head**-0.5
366
+ self.dim_head = dim_head
367
+ self.heads = heads
368
+ inner_dim = dim_head * heads
369
+
370
+ self.norm1 = nn.LayerNorm(dim)
371
+ self.norm2 = nn.LayerNorm(dim)
372
+
373
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
374
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
375
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
376
+
377
+ def forward(self, x, latents):
378
+ """
379
+ Args:
380
+ x (torch.Tensor): image features
381
+ shape (b, n1, D)
382
+ latent (torch.Tensor): latent features
383
+ shape (b, n2, D)
384
+ """
385
+ x = self.norm1(x)
386
+ latents = self.norm2(latents)
387
+
388
+ b, h, _ = latents.shape
389
+
390
+ q = self.to_q(latents)
391
+ kv_input = torch.cat((x, latents), dim=-2)
392
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
393
+
394
+ q, k, v = map(
395
+ lambda t: t.reshape(b, t.shape[1], self.heads, -1)
396
+ .transpose(1, 2)
397
+ .reshape(b, self.heads, t.shape[1], -1)
398
+ .contiguous(),
399
+ (q, k, v),
400
+ )
401
+
402
+ # attention
403
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
404
+ weight = (q * scale) @ (k * scale).transpose(
405
+ -2, -1
406
+ ) # More stable with f16 than dividing afterwards
407
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
408
+ out = weight @ v
409
+
410
+ out = out.permute(0, 2, 1, 3).reshape(b, h, -1)
411
+
412
+ return self.to_out(out)
413
+
414
+
415
+ class Resampler(nn.Module):
416
+ def __init__(
417
+ self,
418
+ dim=1024,
419
+ depth=8,
420
+ dim_head=64,
421
+ heads=16,
422
+ num_queries=8,
423
+ embedding_dim=768,
424
+ output_dim=1024,
425
+ ff_mult=4,
426
+ ):
427
+ super().__init__()
428
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
429
+ self.proj_in = nn.Linear(embedding_dim, dim)
430
+ self.proj_out = nn.Linear(dim, output_dim)
431
+ self.norm_out = nn.LayerNorm(output_dim)
432
+
433
+ self.layers = nn.ModuleList([])
434
+ for _ in range(depth):
435
+ self.layers.append(
436
+ nn.ModuleList(
437
+ [
438
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
439
+ nn.Sequential(
440
+ nn.LayerNorm(dim),
441
+ nn.Linear(dim, dim * ff_mult, bias=False),
442
+ nn.GELU(),
443
+ nn.Linear(dim * ff_mult, dim, bias=False),
444
+ ),
445
+ ]
446
+ )
447
+ )
448
+
449
+ def forward(self, x):
450
+ latents = self.latents.repeat(x.size(0), 1, 1)
451
+ x = self.proj_in(x)
452
+ for attn, ff in self.layers:
453
+ latents = attn(x, latents) + latents
454
+ latents = ff(latents) + latents
455
+
456
+ latents = self.proj_out(latents)
457
+ return self.norm_out(latents)
458
+
459
+
460
+ class CondSequential(nn.Sequential):
461
+ """
462
+ A sequential module that passes timestep embeddings to the children that
463
+ support it as an extra input.
464
+ """
465
+
466
+ def forward(self, x, emb, context=None, num_frames=1):
467
+ for layer in self:
468
+ if isinstance(layer, ResBlock):
469
+ x = layer(x, emb)
470
+ elif isinstance(layer, SpatialTransformer3D):
471
+ x = layer(x, context, num_frames=num_frames)
472
+ else:
473
+ x = layer(x)
474
+ return x
475
+
476
+
477
+ class Upsample(nn.Module):
478
+ """
479
+ An upsampling layer with an optional convolution.
480
+ :param channels: channels in the inputs and outputs.
481
+ :param use_conv: a bool determining if a convolution is applied.
482
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
483
+ upsampling occurs in the inner-two dimensions.
484
+ """
485
+
486
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
487
+ super().__init__()
488
+ self.channels = channels
489
+ self.out_channels = out_channels or channels
490
+ self.use_conv = use_conv
491
+ self.dims = dims
492
+ if use_conv:
493
+ self.conv = conv_nd(
494
+ dims, self.channels, self.out_channels, 3, padding=padding
495
+ )
496
+
497
+ def forward(self, x):
498
+ assert x.shape[1] == self.channels
499
+ if self.dims == 3:
500
+ x = F.interpolate(
501
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
502
+ )
503
+ else:
504
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
505
+ if self.use_conv:
506
+ x = self.conv(x)
507
+ return x
508
+
509
+
510
+ class Downsample(nn.Module):
511
+ """
512
+ A downsampling layer with an optional convolution.
513
+ :param channels: channels in the inputs and outputs.
514
+ :param use_conv: a bool determining if a convolution is applied.
515
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
516
+ downsampling occurs in the inner-two dimensions.
517
+ """
518
+
519
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
520
+ super().__init__()
521
+ self.channels = channels
522
+ self.out_channels = out_channels or channels
523
+ self.use_conv = use_conv
524
+ self.dims = dims
525
+ stride = 2 if dims != 3 else (1, 2, 2)
526
+ if use_conv:
527
+ self.op = conv_nd(
528
+ dims,
529
+ self.channels,
530
+ self.out_channels,
531
+ 3,
532
+ stride=stride,
533
+ padding=padding,
534
+ )
535
+ else:
536
+ assert self.channels == self.out_channels
537
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
538
+
539
+ def forward(self, x):
540
+ assert x.shape[1] == self.channels
541
+ return self.op(x)
542
+
543
+
544
+ class ResBlock(nn.Module):
545
+ """
546
+ A residual block that can optionally change the number of channels.
547
+ :param channels: the number of input channels.
548
+ :param emb_channels: the number of timestep embedding channels.
549
+ :param dropout: the rate of dropout.
550
+ :param out_channels: if specified, the number of out channels.
551
+ :param use_conv: if True and out_channels is specified, use a spatial
552
+ convolution instead of a smaller 1x1 convolution to change the
553
+ channels in the skip connection.
554
+ :param dims: determines if the signal is 1D, 2D, or 3D.
555
+ :param up: if True, use this block for upsampling.
556
+ :param down: if True, use this block for downsampling.
557
+ """
558
+
559
+ def __init__(
560
+ self,
561
+ channels,
562
+ emb_channels,
563
+ dropout,
564
+ out_channels=None,
565
+ use_conv=False,
566
+ use_scale_shift_norm=False,
567
+ dims=2,
568
+ up=False,
569
+ down=False,
570
+ ):
571
+ super().__init__()
572
+ self.channels = channels
573
+ self.emb_channels = emb_channels
574
+ self.dropout = dropout
575
+ self.out_channels = out_channels or channels
576
+ self.use_conv = use_conv
577
+ self.use_scale_shift_norm = use_scale_shift_norm
578
+
579
+ self.in_layers = nn.Sequential(
580
+ nn.GroupNorm(32, channels),
581
+ nn.SiLU(),
582
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
583
+ )
584
+
585
+ self.updown = up or down
586
+
587
+ if up:
588
+ self.h_upd = Upsample(channels, False, dims)
589
+ self.x_upd = Upsample(channels, False, dims)
590
+ elif down:
591
+ self.h_upd = Downsample(channels, False, dims)
592
+ self.x_upd = Downsample(channels, False, dims)
593
+ else:
594
+ self.h_upd = self.x_upd = nn.Identity()
595
+
596
+ self.emb_layers = nn.Sequential(
597
+ nn.SiLU(),
598
+ nn.Linear(
599
+ emb_channels,
600
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
601
+ ),
602
+ )
603
+ self.out_layers = nn.Sequential(
604
+ nn.GroupNorm(32, self.out_channels),
605
+ nn.SiLU(),
606
+ nn.Dropout(p=dropout),
607
+ zero_module(
608
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
609
+ ),
610
+ )
611
+
612
+ if self.out_channels == channels:
613
+ self.skip_connection = nn.Identity()
614
+ elif use_conv:
615
+ self.skip_connection = conv_nd(
616
+ dims, channels, self.out_channels, 3, padding=1
617
+ )
618
+ else:
619
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
620
+
621
+ def forward(self, x, emb):
622
+ if self.updown:
623
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
624
+ h = in_rest(x)
625
+ h = self.h_upd(h)
626
+ x = self.x_upd(x)
627
+ h = in_conv(h)
628
+ else:
629
+ h = self.in_layers(x)
630
+ emb_out = self.emb_layers(emb).type(h.dtype)
631
+ while len(emb_out.shape) < len(h.shape):
632
+ emb_out = emb_out[..., None]
633
+ if self.use_scale_shift_norm:
634
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
635
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
636
+ h = out_norm(h) * (1 + scale) + shift
637
+ h = out_rest(h)
638
+ else:
639
+ h = h + emb_out
640
+ h = self.out_layers(h)
641
+ return self.skip_connection(x) + h
642
+
643
+
644
+ class MultiViewUNetModel(ModelMixin, ConfigMixin):
645
+ """
646
+ The full multi-view UNet model with attention, timestep embedding and camera embedding.
647
+ :param in_channels: channels in the input Tensor.
648
+ :param model_channels: base channel count for the model.
649
+ :param out_channels: channels in the output Tensor.
650
+ :param num_res_blocks: number of residual blocks per downsample.
651
+ :param attention_resolutions: a collection of downsample rates at which
652
+ attention will take place. May be a set, list, or tuple.
653
+ For example, if this contains 4, then at 4x downsampling, attention
654
+ will be used.
655
+ :param dropout: the dropout probability.
656
+ :param channel_mult: channel multiplier for each level of the UNet.
657
+ :param conv_resample: if True, use learned convolutions for upsampling and
658
+ downsampling.
659
+ :param dims: determines if the signal is 1D, 2D, or 3D.
660
+ :param num_classes: if specified (as an int), then this model will be
661
+ class-conditional with `num_classes` classes.
662
+ :param num_heads: the number of attention heads in each attention layer.
663
+ :param num_heads_channels: if specified, ignore num_heads and instead use
664
+ a fixed channel width per attention head.
665
+ :param num_heads_upsample: works with num_heads to set a different number
666
+ of heads for upsampling. Deprecated.
667
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
668
+ :param resblock_updown: use residual blocks for up/downsampling.
669
+ :param use_new_attention_order: use a different attention pattern for potentially
670
+ increased efficiency.
671
+ :param camera_dim: dimensionality of camera input.
672
+ """
673
+
674
+ def __init__(
675
+ self,
676
+ image_size,
677
+ in_channels,
678
+ model_channels,
679
+ out_channels,
680
+ num_res_blocks,
681
+ attention_resolutions,
682
+ dropout=0,
683
+ channel_mult=(1, 2, 4, 8),
684
+ conv_resample=True,
685
+ dims=2,
686
+ num_classes=None,
687
+ num_heads=-1,
688
+ num_head_channels=-1,
689
+ num_heads_upsample=-1,
690
+ use_scale_shift_norm=False,
691
+ resblock_updown=False,
692
+ transformer_depth=1,
693
+ context_dim=None,
694
+ n_embed=None,
695
+ num_attention_blocks=None,
696
+ adm_in_channels=None,
697
+ camera_dim=None,
698
+ ip_dim=0, # imagedream uses ip_dim > 0
699
+ ip_weight=1.0,
700
+ **kwargs,
701
+ ):
702
+ super().__init__()
703
+ assert context_dim is not None
704
+
705
+ if num_heads_upsample == -1:
706
+ num_heads_upsample = num_heads
707
+
708
+ if num_heads == -1:
709
+ assert (
710
+ num_head_channels != -1
711
+ ), "Either num_heads or num_head_channels has to be set"
712
+
713
+ if num_head_channels == -1:
714
+ assert (
715
+ num_heads != -1
716
+ ), "Either num_heads or num_head_channels has to be set"
717
+
718
+ self.image_size = image_size
719
+ self.in_channels = in_channels
720
+ self.model_channels = model_channels
721
+ self.out_channels = out_channels
722
+ if isinstance(num_res_blocks, int):
723
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
724
+ else:
725
+ if len(num_res_blocks) != len(channel_mult):
726
+ raise ValueError(
727
+ "provide num_res_blocks either as an int (globally constant) or "
728
+ "as a list/tuple (per-level) with the same length as channel_mult"
729
+ )
730
+ self.num_res_blocks = num_res_blocks
731
+
732
+ if num_attention_blocks is not None:
733
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
734
+ assert all(
735
+ map(
736
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
737
+ range(len(num_attention_blocks)),
738
+ )
739
+ )
740
+ print(
741
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
742
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
743
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
744
+ f"attention will still not be set."
745
+ )
746
+
747
+ self.attention_resolutions = attention_resolutions
748
+ self.dropout = dropout
749
+ self.channel_mult = channel_mult
750
+ self.conv_resample = conv_resample
751
+ self.num_classes = num_classes
752
+ self.num_heads = num_heads
753
+ self.num_head_channels = num_head_channels
754
+ self.num_heads_upsample = num_heads_upsample
755
+ self.predict_codebook_ids = n_embed is not None
756
+
757
+ self.ip_dim = ip_dim
758
+ self.ip_weight = ip_weight
759
+
760
+ if self.ip_dim > 0:
761
+ self.image_embed = Resampler(
762
+ dim=context_dim,
763
+ depth=4,
764
+ dim_head=64,
765
+ heads=12,
766
+ num_queries=ip_dim, # num token
767
+ embedding_dim=1280,
768
+ output_dim=context_dim,
769
+ ff_mult=4,
770
+ )
771
+
772
+ time_embed_dim = model_channels * 4
773
+ self.time_embed = nn.Sequential(
774
+ nn.Linear(model_channels, time_embed_dim),
775
+ nn.SiLU(),
776
+ nn.Linear(time_embed_dim, time_embed_dim),
777
+ )
778
+
779
+ if camera_dim is not None:
780
+ time_embed_dim = model_channels * 4
781
+ self.camera_embed = nn.Sequential(
782
+ nn.Linear(camera_dim, time_embed_dim),
783
+ nn.SiLU(),
784
+ nn.Linear(time_embed_dim, time_embed_dim),
785
+ )
786
+
787
+ if self.num_classes is not None:
788
+ if isinstance(self.num_classes, int):
789
+ self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
790
+ elif self.num_classes == "continuous":
791
+ # print("setting up linear c_adm embedding layer")
792
+ self.label_emb = nn.Linear(1, time_embed_dim)
793
+ elif self.num_classes == "sequential":
794
+ assert adm_in_channels is not None
795
+ self.label_emb = nn.Sequential(
796
+ nn.Sequential(
797
+ nn.Linear(adm_in_channels, time_embed_dim),
798
+ nn.SiLU(),
799
+ nn.Linear(time_embed_dim, time_embed_dim),
800
+ )
801
+ )
802
+ else:
803
+ raise ValueError()
804
+
805
+ self.input_blocks = nn.ModuleList(
806
+ [CondSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))]
807
+ )
808
+ self._feature_size = model_channels
809
+ input_block_chans = [model_channels]
810
+ ch = model_channels
811
+ ds = 1
812
+ for level, mult in enumerate(channel_mult):
813
+ for nr in range(self.num_res_blocks[level]):
814
+ layers: List[Any] = [
815
+ ResBlock(
816
+ ch,
817
+ time_embed_dim,
818
+ dropout,
819
+ out_channels=mult * model_channels,
820
+ dims=dims,
821
+ use_scale_shift_norm=use_scale_shift_norm,
822
+ )
823
+ ]
824
+ ch = mult * model_channels
825
+ if ds in attention_resolutions:
826
+ if num_head_channels == -1:
827
+ dim_head = ch // num_heads
828
+ else:
829
+ num_heads = ch // num_head_channels
830
+ dim_head = num_head_channels
831
+
832
+ if num_attention_blocks is None or nr < num_attention_blocks[level]:
833
+ layers.append(
834
+ SpatialTransformer3D(
835
+ ch,
836
+ num_heads,
837
+ dim_head,
838
+ context_dim=context_dim,
839
+ depth=transformer_depth,
840
+ ip_dim=self.ip_dim,
841
+ ip_weight=self.ip_weight,
842
+ )
843
+ )
844
+ self.input_blocks.append(CondSequential(*layers))
845
+ self._feature_size += ch
846
+ input_block_chans.append(ch)
847
+ if level != len(channel_mult) - 1:
848
+ out_ch = ch
849
+ self.input_blocks.append(
850
+ CondSequential(
851
+ ResBlock(
852
+ ch,
853
+ time_embed_dim,
854
+ dropout,
855
+ out_channels=out_ch,
856
+ dims=dims,
857
+ use_scale_shift_norm=use_scale_shift_norm,
858
+ down=True,
859
+ )
860
+ if resblock_updown
861
+ else Downsample(
862
+ ch, conv_resample, dims=dims, out_channels=out_ch
863
+ )
864
+ )
865
+ )
866
+ ch = out_ch
867
+ input_block_chans.append(ch)
868
+ ds *= 2
869
+ self._feature_size += ch
870
+
871
+ if num_head_channels == -1:
872
+ dim_head = ch // num_heads
873
+ else:
874
+ num_heads = ch // num_head_channels
875
+ dim_head = num_head_channels
876
+
877
+ self.middle_block = CondSequential(
878
+ ResBlock(
879
+ ch,
880
+ time_embed_dim,
881
+ dropout,
882
+ dims=dims,
883
+ use_scale_shift_norm=use_scale_shift_norm,
884
+ ),
885
+ SpatialTransformer3D(
886
+ ch,
887
+ num_heads,
888
+ dim_head,
889
+ context_dim=context_dim,
890
+ depth=transformer_depth,
891
+ ip_dim=self.ip_dim,
892
+ ip_weight=self.ip_weight,
893
+ ),
894
+ ResBlock(
895
+ ch,
896
+ time_embed_dim,
897
+ dropout,
898
+ dims=dims,
899
+ use_scale_shift_norm=use_scale_shift_norm,
900
+ ),
901
+ )
902
+ self._feature_size += ch
903
+
904
+ self.output_blocks = nn.ModuleList([])
905
+ for level, mult in list(enumerate(channel_mult))[::-1]:
906
+ for i in range(self.num_res_blocks[level] + 1):
907
+ ich = input_block_chans.pop()
908
+ layers = [
909
+ ResBlock(
910
+ ch + ich,
911
+ time_embed_dim,
912
+ dropout,
913
+ out_channels=model_channels * mult,
914
+ dims=dims,
915
+ use_scale_shift_norm=use_scale_shift_norm,
916
+ )
917
+ ]
918
+ ch = model_channels * mult
919
+ if ds in attention_resolutions:
920
+ if num_head_channels == -1:
921
+ dim_head = ch // num_heads
922
+ else:
923
+ num_heads = ch // num_head_channels
924
+ dim_head = num_head_channels
925
+
926
+ if num_attention_blocks is None or i < num_attention_blocks[level]:
927
+ layers.append(
928
+ SpatialTransformer3D(
929
+ ch,
930
+ num_heads,
931
+ dim_head,
932
+ context_dim=context_dim,
933
+ depth=transformer_depth,
934
+ ip_dim=self.ip_dim,
935
+ ip_weight=self.ip_weight,
936
+ )
937
+ )
938
+ if level and i == self.num_res_blocks[level]:
939
+ out_ch = ch
940
+ layers.append(
941
+ ResBlock(
942
+ ch,
943
+ time_embed_dim,
944
+ dropout,
945
+ out_channels=out_ch,
946
+ dims=dims,
947
+ use_scale_shift_norm=use_scale_shift_norm,
948
+ up=True,
949
+ )
950
+ if resblock_updown
951
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
952
+ )
953
+ ds //= 2
954
+ self.output_blocks.append(CondSequential(*layers))
955
+ self._feature_size += ch
956
+
957
+ self.out = nn.Sequential(
958
+ nn.GroupNorm(32, ch),
959
+ nn.SiLU(),
960
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
961
+ )
962
+ if self.predict_codebook_ids:
963
+ self.id_predictor = nn.Sequential(
964
+ nn.GroupNorm(32, ch),
965
+ conv_nd(dims, model_channels, n_embed, 1),
966
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
967
+ )
968
+
969
+ def forward(
970
+ self,
971
+ x,
972
+ timesteps=None,
973
+ context=None,
974
+ y=None,
975
+ camera=None,
976
+ num_frames=1,
977
+ ip=None,
978
+ ip_img=None,
979
+ **kwargs,
980
+ ):
981
+ """
982
+ Apply the model to an input batch.
983
+ :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
984
+ :param timesteps: a 1-D batch of timesteps.
985
+ :param context: conditioning plugged in via crossattn
986
+ :param y: an [N] Tensor of labels, if class-conditional.
987
+ :param num_frames: a integer indicating number of frames for tensor reshaping.
988
+ :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
989
+ """
990
+ assert (
991
+ x.shape[0] % num_frames == 0
992
+ ), "input batch size must be dividable by num_frames!"
993
+ assert (y is not None) == (
994
+ self.num_classes is not None
995
+ ), "must specify y if and only if the model is class-conditional"
996
+
997
+ hs = []
998
+
999
+ t_emb = timestep_embedding(
1000
+ timesteps, self.model_channels, repeat_only=False
1001
+ ).to(x.dtype)
1002
+
1003
+ emb = self.time_embed(t_emb)
1004
+
1005
+ if self.num_classes is not None:
1006
+ assert y is not None
1007
+ assert y.shape[0] == x.shape[0]
1008
+ emb = emb + self.label_emb(y)
1009
+
1010
+ # Add camera embeddings
1011
+ if camera is not None:
1012
+ emb = emb + self.camera_embed(camera)
1013
+
1014
+ # imagedream variant
1015
+ if self.ip_dim > 0:
1016
+ x[(num_frames - 1) :: num_frames, :, :, :] = ip_img # place at [4, 9]
1017
+ ip_emb = self.image_embed(ip)
1018
+ context = torch.cat((context, ip_emb), 1)
1019
+
1020
+ h = x
1021
+ for module in self.input_blocks:
1022
+ h = module(h, emb, context, num_frames=num_frames)
1023
+ hs.append(h)
1024
+ h = self.middle_block(h, emb, context, num_frames=num_frames)
1025
+ for module in self.output_blocks:
1026
+ h = torch.cat([h, hs.pop()], dim=1)
1027
+ h = module(h, emb, context, num_frames=num_frames)
1028
+ h = h.type(x.dtype)
1029
+ if self.predict_codebook_ids:
1030
+ return self.id_predictor(h)
1031
+ else:
1032
+ return self.out(h)
1033
+
1034
 
1035
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
1036
 
 
1419
 
1420
  if image.dtype == np.float32:
1421
  image = (image * 255).astype(np.uint8)
1422
+
1423
  image = self.feature_extractor(image, return_tensors="pt").pixel_values
1424
  image = image.to(device=device, dtype=dtype)
1425
+
1426
+ image_embeds = self.image_encoder(
1427
+ image, output_hidden_states=True
1428
+ ).hidden_states[-2]
1429
  image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
1430
 
1431
  return torch.zeros_like(image_embeds), image_embeds
1432
 
1433
  def encode_image_latents(self, image, device, num_images_per_prompt):
1434
+
1435
  dtype = next(self.image_encoder.parameters()).dtype
1436
 
1437
+ image = (
1438
+ torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2).to(device=device)
1439
+ ) # [1, 3, H, W]
1440
  image = 2 * image - 1
1441
+ image = F.interpolate(image, (256, 256), mode="bilinear", align_corners=False)
1442
  image = image.to(dtype=dtype)
1443
 
1444
  posterior = self.vae.encode(image).latent_dist
1445
+ latents = posterior.sample() * self.vae.config.scaling_factor # [B, C, H, W]
1446
  latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
1447
 
1448
  return torch.zeros_like(latents), latents
 
1461
  num_images_per_prompt: int = 1,
1462
  eta: float = 0.0,
1463
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1464
+ output_type: Optional[str] = "numpy", # pil, numpy, latents
1465
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1466
  callback_steps: int = 1,
1467
  num_frames: int = 4,
 
1484
  if image is not None:
1485
  assert isinstance(image, np.ndarray) and image.dtype == np.float32
1486
  self.image_encoder = self.image_encoder.to(device=device)
1487
+ image_embeds_neg, image_embeds_pos = self.encode_image(
1488
+ image, device, num_images_per_prompt
1489
+ )
1490
+ image_latents_neg, image_latents_pos = self.encode_image_latents(
1491
+ image, device, num_images_per_prompt
1492
+ )
1493
+
1494
  _prompt_embeds = self._encode_prompt(
1495
  prompt=prompt,
1496
  device=device,
 
1514
  )
1515
 
1516
  # Get camera
1517
+ camera = get_camera(
1518
+ num_frames, elevation=elevation, extra_view=(image is not None)
1519
+ ).to(dtype=latents.dtype, device=device)
1520
  camera = camera.repeat_interleave(num_images_per_prompt, dim=0)
1521
 
1522
  # Prepare extra step kwargs.
 
1529
  # expand the latents if we are doing classifier free guidance
1530
  multiplier = 2 if do_classifier_free_guidance else 1
1531
  latent_model_input = torch.cat([latents] * multiplier)
1532
+ latent_model_input = self.scheduler.scale_model_input(
1533
+ latent_model_input, t
1534
+ )
1535
 
1536
  unet_inputs = {
1537
+ "x": latent_model_input,
1538
+ "timesteps": torch.tensor(
1539
+ [t] * actual_num_frames * multiplier,
1540
+ dtype=latent_model_input.dtype,
1541
+ device=device,
1542
+ ),
1543
+ "context": torch.cat(
1544
+ [prompt_embeds_neg] * actual_num_frames
1545
+ + [prompt_embeds_pos] * actual_num_frames
1546
+ ),
1547
+ "num_frames": actual_num_frames,
1548
+ "camera": torch.cat([camera] * multiplier),
1549
  }
1550
 
1551
  if image is not None:
1552
+ unet_inputs["ip"] = torch.cat(
1553
+ [image_embeds_neg] * actual_num_frames
1554
+ + [image_embeds_pos] * actual_num_frames
1555
+ )
1556
+ unet_inputs["ip_img"] = torch.cat(
1557
+ [image_latents_neg] + [image_latents_pos]
1558
+ ) # no repeat
1559
+
1560
  # predict the noise residual
1561
  noise_pred = self.unet.forward(**unet_inputs)
1562
 
 
1586
  elif output_type == "pil":
1587
  image = self.decode_latents(latents)
1588
  image = self.numpy_to_pil(image)
1589
+ else: # numpy
1590
  image = self.decode_latents(latents)
1591
 
1592
  # Offload last model to CPU