ashawkey commited on
Commit
0cf459d
•
1 Parent(s): 1d16c41

merged to one file for loading from huggingface...

Browse files
convert_mvdream_to_diffusers.py CHANGED
@@ -15,10 +15,10 @@ from diffusers.utils import logging
15
  from typing import Any
16
  from accelerate import init_empty_weights
17
  from accelerate.utils import set_module_tensor_to_device
18
- from mvdream.models import MultiViewUNetModel
19
- from mvdream.pipeline_mvdream import MVDreamPipeline
20
  from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPImageProcessor
21
 
 
 
22
  import kiui
23
 
24
  logger = logging.get_logger(__name__)
 
15
  from typing import Any
16
  from accelerate import init_empty_weights
17
  from accelerate.utils import set_module_tensor_to_device
 
 
18
  from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPImageProcessor
19
 
20
+ from mv_unet import MultiViewUNetModel
21
+ from pipeline_mvdream import MVDreamPipeline
22
  import kiui
23
 
24
  logger = logging.get_logger(__name__)
mvdream/models.py → mv_unet.py RENAMED
@@ -1,19 +1,490 @@
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
 
 
4
  from diffusers.configuration_utils import ConfigMixin
5
  from diffusers.models.modeling_utils import ModelMixin
6
- from typing import Any, List
7
-
8
- from .util import (
9
- checkpoint,
10
- conv_nd,
11
- avg_pool_nd,
12
- zero_module,
13
- timestep_embedding,
14
- )
15
- from .attention import SpatialTransformer3D
16
- from .adaptor import Resampler
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  class CondSequential(nn.Sequential):
19
  """
@@ -615,4 +1086,4 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
615
  if self.predict_codebook_ids:
616
  return self.id_predictor(h)
617
  else:
618
- return self.out(h)
 
1
+ import math
2
+ import numpy as np
3
+ from inspect import isfunction
4
+ from typing import Optional, Any, List
5
+
6
  import torch
7
  import torch.nn as nn
8
  import torch.nn.functional as F
9
+ from einops import rearrange, repeat
10
+
11
  from diffusers.configuration_utils import ConfigMixin
12
  from diffusers.models.modeling_utils import ModelMixin
13
+
14
+ # require xformers!
15
+ import xformers
16
+ import xformers.ops
17
+
18
+ from kiui.cam import orbit_camera
19
+
20
+ def get_camera(
21
+ num_frames, elevation=15, azimuth_start=0, azimuth_span=360, blender_coord=True, extra_view=False,
22
+ ):
23
+ angle_gap = azimuth_span / num_frames
24
+ cameras = []
25
+ for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):
26
+
27
+ pose = orbit_camera(-elevation, azimuth, radius=1) # kiui's elevation is negated, [4, 4]
28
+
29
+ # opengl to blender
30
+ if blender_coord:
31
+ pose[2] *= -1
32
+ pose[[1, 2]] = pose[[2, 1]]
33
+
34
+ cameras.append(pose.flatten())
35
+
36
+ if extra_view:
37
+ cameras.append(np.zeros_like(cameras[0]))
38
+
39
+ return torch.from_numpy(np.stack(cameras, axis=0)).float() # [num_frames, 16]
40
+
41
+
42
+ def checkpoint(func, inputs, params, flag):
43
+ """
44
+ Evaluate a function without caching intermediate activations, allowing for
45
+ reduced memory at the expense of extra compute in the backward pass.
46
+ :param func: the function to evaluate.
47
+ :param inputs: the argument sequence to pass to `func`.
48
+ :param params: a sequence of parameters `func` depends on but does not
49
+ explicitly take as arguments.
50
+ :param flag: if False, disable gradient checkpointing.
51
+ """
52
+ if flag:
53
+ args = tuple(inputs) + tuple(params)
54
+ return CheckpointFunction.apply(func, len(inputs), *args)
55
+ else:
56
+ return func(*inputs)
57
+
58
+
59
+ class CheckpointFunction(torch.autograd.Function):
60
+ @staticmethod
61
+ def forward(ctx, run_function, length, *args):
62
+ ctx.run_function = run_function
63
+ ctx.input_tensors = list(args[:length])
64
+ ctx.input_params = list(args[length:])
65
+
66
+ with torch.no_grad():
67
+ output_tensors = ctx.run_function(*ctx.input_tensors)
68
+ return output_tensors
69
+
70
+ @staticmethod
71
+ def backward(ctx, *output_grads):
72
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
73
+ with torch.enable_grad():
74
+ # Fixes a bug where the first op in run_function modifies the
75
+ # Tensor storage in place, which is not allowed for detach()'d
76
+ # Tensors.
77
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
78
+ output_tensors = ctx.run_function(*shallow_copies)
79
+ input_grads = torch.autograd.grad(
80
+ output_tensors,
81
+ ctx.input_tensors + ctx.input_params,
82
+ output_grads,
83
+ allow_unused=True,
84
+ )
85
+ del ctx.input_tensors
86
+ del ctx.input_params
87
+ del output_tensors
88
+ return (None, None) + input_grads
89
+
90
+
91
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
92
+ """
93
+ Create sinusoidal timestep embeddings.
94
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
95
+ These may be fractional.
96
+ :param dim: the dimension of the output.
97
+ :param max_period: controls the minimum frequency of the embeddings.
98
+ :return: an [N x dim] Tensor of positional embeddings.
99
+ """
100
+ if not repeat_only:
101
+ half = dim // 2
102
+ freqs = torch.exp(
103
+ -math.log(max_period)
104
+ * torch.arange(start=0, end=half, dtype=torch.float32)
105
+ / half
106
+ ).to(device=timesteps.device)
107
+ args = timesteps[:, None] * freqs[None]
108
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
109
+ if dim % 2:
110
+ embedding = torch.cat(
111
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
112
+ )
113
+ else:
114
+ embedding = repeat(timesteps, "b -> b d", d=dim)
115
+ # import pdb; pdb.set_trace()
116
+ return embedding
117
+
118
+
119
+ def zero_module(module):
120
+ """
121
+ Zero out the parameters of a module and return it.
122
+ """
123
+ for p in module.parameters():
124
+ p.detach().zero_()
125
+ return module
126
+
127
+
128
+ def conv_nd(dims, *args, **kwargs):
129
+ """
130
+ Create a 1D, 2D, or 3D convolution module.
131
+ """
132
+ if dims == 1:
133
+ return nn.Conv1d(*args, **kwargs)
134
+ elif dims == 2:
135
+ return nn.Conv2d(*args, **kwargs)
136
+ elif dims == 3:
137
+ return nn.Conv3d(*args, **kwargs)
138
+ raise ValueError(f"unsupported dimensions: {dims}")
139
+
140
+
141
+ def avg_pool_nd(dims, *args, **kwargs):
142
+ """
143
+ Create a 1D, 2D, or 3D average pooling module.
144
+ """
145
+ if dims == 1:
146
+ return nn.AvgPool1d(*args, **kwargs)
147
+ elif dims == 2:
148
+ return nn.AvgPool2d(*args, **kwargs)
149
+ elif dims == 3:
150
+ return nn.AvgPool3d(*args, **kwargs)
151
+ raise ValueError(f"unsupported dimensions: {dims}")
152
+
153
+
154
+ def default(val, d):
155
+ if val is not None:
156
+ return val
157
+ return d() if isfunction(d) else d
158
+
159
+
160
+ class GEGLU(nn.Module):
161
+ def __init__(self, dim_in, dim_out):
162
+ super().__init__()
163
+ self.proj = nn.Linear(dim_in, dim_out * 2)
164
+
165
+ def forward(self, x):
166
+ x, gate = self.proj(x).chunk(2, dim=-1)
167
+ return x * F.gelu(gate)
168
+
169
+
170
+ class FeedForward(nn.Module):
171
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
172
+ super().__init__()
173
+ inner_dim = int(dim * mult)
174
+ dim_out = default(dim_out, dim)
175
+ project_in = (
176
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
177
+ if not glu
178
+ else GEGLU(dim, inner_dim)
179
+ )
180
+
181
+ self.net = nn.Sequential(
182
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
183
+ )
184
+
185
+ def forward(self, x):
186
+ return self.net(x)
187
+
188
+
189
+ class MemoryEfficientCrossAttention(nn.Module):
190
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
191
+ def __init__(
192
+ self,
193
+ query_dim,
194
+ context_dim=None,
195
+ heads=8,
196
+ dim_head=64,
197
+ dropout=0.0,
198
+ ip_dim=0,
199
+ ip_weight=1,
200
+ ):
201
+ super().__init__()
202
+
203
+ inner_dim = dim_head * heads
204
+ context_dim = default(context_dim, query_dim)
205
+
206
+ self.heads = heads
207
+ self.dim_head = dim_head
208
+
209
+ self.ip_dim = ip_dim
210
+ self.ip_weight = ip_weight
211
+
212
+ if self.ip_dim > 0:
213
+ self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
214
+ self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
215
+
216
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
217
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
218
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
219
+
220
+ self.to_out = nn.Sequential(
221
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
222
+ )
223
+ self.attention_op: Optional[Any] = None
224
+
225
+ def forward(self, x, context=None):
226
+ q = self.to_q(x)
227
+ context = default(context, x)
228
+
229
+ if self.ip_dim > 0:
230
+ # context: [B, 77 + 16(ip), 1024]
231
+ token_len = context.shape[1]
232
+ context_ip = context[:, -self.ip_dim :, :]
233
+ k_ip = self.to_k_ip(context_ip)
234
+ v_ip = self.to_v_ip(context_ip)
235
+ context = context[:, : (token_len - self.ip_dim), :]
236
+
237
+ k = self.to_k(context)
238
+ v = self.to_v(context)
239
+
240
+ b, _, _ = q.shape
241
+ q, k, v = map(
242
+ lambda t: t.unsqueeze(3)
243
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
244
+ .permute(0, 2, 1, 3)
245
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
246
+ .contiguous(),
247
+ (q, k, v),
248
+ )
249
+
250
+ # actually compute the attention, what we cannot get enough of
251
+ out = xformers.ops.memory_efficient_attention(
252
+ q, k, v, attn_bias=None, op=self.attention_op
253
+ )
254
+
255
+ if self.ip_dim > 0:
256
+ k_ip, v_ip = map(
257
+ lambda t: t.unsqueeze(3)
258
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
259
+ .permute(0, 2, 1, 3)
260
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
261
+ .contiguous(),
262
+ (k_ip, v_ip),
263
+ )
264
+ # actually compute the attention, what we cannot get enough of
265
+ out_ip = xformers.ops.memory_efficient_attention(
266
+ q, k_ip, v_ip, attn_bias=None, op=self.attention_op
267
+ )
268
+ out = out + self.ip_weight * out_ip
269
+
270
+ out = (
271
+ out.unsqueeze(0)
272
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
273
+ .permute(0, 2, 1, 3)
274
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
275
+ )
276
+ return self.to_out(out)
277
+
278
+
279
+ class BasicTransformerBlock3D(nn.Module):
280
+
281
+ def __init__(
282
+ self,
283
+ dim,
284
+ n_heads,
285
+ d_head,
286
+ context_dim,
287
+ dropout=0.0,
288
+ gated_ff=True,
289
+ checkpoint=True,
290
+ ip_dim=0,
291
+ ip_weight=1,
292
+ ):
293
+ super().__init__()
294
+
295
+ self.attn1 = MemoryEfficientCrossAttention(
296
+ query_dim=dim,
297
+ context_dim=None, # self-attention
298
+ heads=n_heads,
299
+ dim_head=d_head,
300
+ dropout=dropout,
301
+ )
302
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
303
+ self.attn2 = MemoryEfficientCrossAttention(
304
+ query_dim=dim,
305
+ context_dim=context_dim,
306
+ heads=n_heads,
307
+ dim_head=d_head,
308
+ dropout=dropout,
309
+ # ip only applies to cross-attention
310
+ ip_dim=ip_dim,
311
+ ip_weight=ip_weight,
312
+ )
313
+ self.norm1 = nn.LayerNorm(dim)
314
+ self.norm2 = nn.LayerNorm(dim)
315
+ self.norm3 = nn.LayerNorm(dim)
316
+ self.checkpoint = checkpoint
317
+
318
+ def forward(self, x, context=None, num_frames=1):
319
+ return checkpoint(
320
+ self._forward, (x, context, num_frames), self.parameters(), self.checkpoint
321
+ )
322
+
323
+ def _forward(self, x, context=None, num_frames=1):
324
+ x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
325
+ x = self.attn1(self.norm1(x), context=None) + x
326
+ x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
327
+ x = self.attn2(self.norm2(x), context=context) + x
328
+ x = self.ff(self.norm3(x)) + x
329
+ return x
330
+
331
+
332
+ class SpatialTransformer3D(nn.Module):
333
+
334
+ def __init__(
335
+ self,
336
+ in_channels,
337
+ n_heads,
338
+ d_head,
339
+ context_dim, # cross attention input dim
340
+ depth=1,
341
+ dropout=0.0,
342
+ ip_dim=0,
343
+ ip_weight=1,
344
+ use_checkpoint=True,
345
+ ):
346
+ super().__init__()
347
+
348
+ if not isinstance(context_dim, list):
349
+ context_dim = [context_dim]
350
+
351
+ self.in_channels = in_channels
352
+
353
+ inner_dim = n_heads * d_head
354
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
355
+ self.proj_in = nn.Linear(in_channels, inner_dim)
356
+
357
+ self.transformer_blocks = nn.ModuleList(
358
+ [
359
+ BasicTransformerBlock3D(
360
+ inner_dim,
361
+ n_heads,
362
+ d_head,
363
+ context_dim=context_dim[d],
364
+ dropout=dropout,
365
+ checkpoint=use_checkpoint,
366
+ ip_dim=ip_dim,
367
+ ip_weight=ip_weight,
368
+ )
369
+ for d in range(depth)
370
+ ]
371
+ )
372
+
373
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
374
+
375
+
376
+ def forward(self, x, context=None, num_frames=1):
377
+ # note: if no context is given, cross-attention defaults to self-attention
378
+ if not isinstance(context, list):
379
+ context = [context]
380
+ b, c, h, w = x.shape
381
+ x_in = x
382
+ x = self.norm(x)
383
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
384
+ x = self.proj_in(x)
385
+ for i, block in enumerate(self.transformer_blocks):
386
+ x = block(x, context=context[i], num_frames=num_frames)
387
+ x = self.proj_out(x)
388
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
389
+
390
+ return x + x_in
391
+
392
+
393
+ class PerceiverAttention(nn.Module):
394
+ def __init__(self, *, dim, dim_head=64, heads=8):
395
+ super().__init__()
396
+ self.scale = dim_head ** -0.5
397
+ self.dim_head = dim_head
398
+ self.heads = heads
399
+ inner_dim = dim_head * heads
400
+
401
+ self.norm1 = nn.LayerNorm(dim)
402
+ self.norm2 = nn.LayerNorm(dim)
403
+
404
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
405
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
406
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
407
+
408
+ def forward(self, x, latents):
409
+ """
410
+ Args:
411
+ x (torch.Tensor): image features
412
+ shape (b, n1, D)
413
+ latent (torch.Tensor): latent features
414
+ shape (b, n2, D)
415
+ """
416
+ x = self.norm1(x)
417
+ latents = self.norm2(latents)
418
+
419
+ b, l, _ = latents.shape
420
+
421
+ q = self.to_q(latents)
422
+ kv_input = torch.cat((x, latents), dim=-2)
423
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
424
+
425
+ q, k, v = map(
426
+ lambda t: t.reshape(b, t.shape[1], self.heads, -1)
427
+ .transpose(1, 2)
428
+ .reshape(b, self.heads, t.shape[1], -1)
429
+ .contiguous(),
430
+ (q, k, v),
431
+ )
432
+
433
+ # attention
434
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
435
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
436
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
437
+ out = weight @ v
438
+
439
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
440
+
441
+ return self.to_out(out)
442
+
443
+
444
+ class Resampler(nn.Module):
445
+ def __init__(
446
+ self,
447
+ dim=1024,
448
+ depth=8,
449
+ dim_head=64,
450
+ heads=16,
451
+ num_queries=8,
452
+ embedding_dim=768,
453
+ output_dim=1024,
454
+ ff_mult=4,
455
+ ):
456
+ super().__init__()
457
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
458
+ self.proj_in = nn.Linear(embedding_dim, dim)
459
+ self.proj_out = nn.Linear(dim, output_dim)
460
+ self.norm_out = nn.LayerNorm(output_dim)
461
+
462
+ self.layers = nn.ModuleList([])
463
+ for _ in range(depth):
464
+ self.layers.append(
465
+ nn.ModuleList(
466
+ [
467
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
468
+ nn.Sequential(
469
+ nn.LayerNorm(dim),
470
+ nn.Linear(dim, dim * ff_mult, bias=False),
471
+ nn.GELU(),
472
+ nn.Linear(dim * ff_mult, dim, bias=False),
473
+ )
474
+ ]
475
+ )
476
+ )
477
+
478
+ def forward(self, x):
479
+ latents = self.latents.repeat(x.size(0), 1, 1)
480
+ x = self.proj_in(x)
481
+ for attn, ff in self.layers:
482
+ latents = attn(x, latents) + latents
483
+ latents = ff(latents) + latents
484
+
485
+ latents = self.proj_out(latents)
486
+ return self.norm_out(latents)
487
+
488
 
489
  class CondSequential(nn.Sequential):
490
  """
 
1086
  if self.predict_codebook_ids:
1087
  return self.id_predictor(h)
1088
  else:
1089
+ return self.out(h)
mvdream/adaptor.py DELETED
@@ -1,113 +0,0 @@
1
- import math
2
- import torch
3
- import torch.nn as nn
4
-
5
- # FFN
6
- def FeedForward(dim, mult=4):
7
- inner_dim = int(dim * mult)
8
- return nn.Sequential(
9
- nn.LayerNorm(dim),
10
- nn.Linear(dim, inner_dim, bias=False),
11
- nn.GELU(),
12
- nn.Linear(inner_dim, dim, bias=False),
13
- )
14
-
15
-
16
- def reshape_tensor(x, heads):
17
- bs, length, width = x.shape
18
- # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
19
- x = x.view(bs, length, heads, -1)
20
- # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
21
- x = x.transpose(1, 2)
22
- # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
23
- x = x.reshape(bs, heads, length, -1)
24
- return x
25
-
26
-
27
- class PerceiverAttention(nn.Module):
28
- def __init__(self, *, dim, dim_head=64, heads=8):
29
- super().__init__()
30
- self.scale = dim_head ** -0.5
31
- self.dim_head = dim_head
32
- self.heads = heads
33
- inner_dim = dim_head * heads
34
-
35
- self.norm1 = nn.LayerNorm(dim)
36
- self.norm2 = nn.LayerNorm(dim)
37
-
38
- self.to_q = nn.Linear(dim, inner_dim, bias=False)
39
- self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
40
- self.to_out = nn.Linear(inner_dim, dim, bias=False)
41
-
42
- def forward(self, x, latents):
43
- """
44
- Args:
45
- x (torch.Tensor): image features
46
- shape (b, n1, D)
47
- latent (torch.Tensor): latent features
48
- shape (b, n2, D)
49
- """
50
- x = self.norm1(x)
51
- latents = self.norm2(latents)
52
-
53
- b, l, _ = latents.shape
54
-
55
- q = self.to_q(latents)
56
- kv_input = torch.cat((x, latents), dim=-2)
57
- k, v = self.to_kv(kv_input).chunk(2, dim=-1)
58
-
59
- q = reshape_tensor(q, self.heads)
60
- k = reshape_tensor(k, self.heads)
61
- v = reshape_tensor(v, self.heads)
62
-
63
- # attention
64
- scale = 1 / math.sqrt(math.sqrt(self.dim_head))
65
- weight = (q * scale) @ (k * scale).transpose(
66
- -2, -1
67
- ) # More stable with f16 than dividing afterwards
68
- weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
69
- out = weight @ v
70
-
71
- out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
72
-
73
- return self.to_out(out)
74
-
75
-
76
- class Resampler(nn.Module):
77
- def __init__(
78
- self,
79
- dim=1024,
80
- depth=8,
81
- dim_head=64,
82
- heads=16,
83
- num_queries=8,
84
- embedding_dim=768,
85
- output_dim=1024,
86
- ff_mult=4,
87
- ):
88
- super().__init__()
89
- self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
90
- self.proj_in = nn.Linear(embedding_dim, dim)
91
- self.proj_out = nn.Linear(dim, output_dim)
92
- self.norm_out = nn.LayerNorm(output_dim)
93
-
94
- self.layers = nn.ModuleList([])
95
- for _ in range(depth):
96
- self.layers.append(
97
- nn.ModuleList(
98
- [
99
- PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
100
- FeedForward(dim=dim, mult=ff_mult),
101
- ]
102
- )
103
- )
104
-
105
- def forward(self, x):
106
- latents = self.latents.repeat(x.size(0), 1, 1)
107
- x = self.proj_in(x)
108
- for attn, ff in self.layers:
109
- latents = attn(x, latents) + latents
110
- latents = ff(latents) + latents
111
-
112
- latents = self.proj_out(latents)
113
- return self.norm_out(latents)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mvdream/attention.py DELETED
@@ -1,251 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
- from inspect import isfunction
6
- from einops import rearrange, repeat
7
- from typing import Optional, Any
8
-
9
- # require xformers
10
- import xformers # type: ignore
11
- import xformers.ops # type: ignore
12
-
13
- from .util import checkpoint, zero_module
14
-
15
- def default(val, d):
16
- if val is not None:
17
- return val
18
- return d() if isfunction(d) else d
19
-
20
-
21
- class GEGLU(nn.Module):
22
- def __init__(self, dim_in, dim_out):
23
- super().__init__()
24
- self.proj = nn.Linear(dim_in, dim_out * 2)
25
-
26
- def forward(self, x):
27
- x, gate = self.proj(x).chunk(2, dim=-1)
28
- return x * F.gelu(gate)
29
-
30
-
31
- class FeedForward(nn.Module):
32
- def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
33
- super().__init__()
34
- inner_dim = int(dim * mult)
35
- dim_out = default(dim_out, dim)
36
- project_in = (
37
- nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
38
- if not glu
39
- else GEGLU(dim, inner_dim)
40
- )
41
-
42
- self.net = nn.Sequential(
43
- project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
44
- )
45
-
46
- def forward(self, x):
47
- return self.net(x)
48
-
49
-
50
- class MemoryEfficientCrossAttention(nn.Module):
51
- # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
52
- def __init__(
53
- self,
54
- query_dim,
55
- context_dim=None,
56
- heads=8,
57
- dim_head=64,
58
- dropout=0.0,
59
- ip_dim=0,
60
- ip_weight=1,
61
- ):
62
- super().__init__()
63
-
64
- inner_dim = dim_head * heads
65
- context_dim = default(context_dim, query_dim)
66
-
67
- self.heads = heads
68
- self.dim_head = dim_head
69
-
70
- self.ip_dim = ip_dim
71
- self.ip_weight = ip_weight
72
-
73
- if self.ip_dim > 0:
74
- self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
75
- self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
76
-
77
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
78
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
79
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
80
-
81
- self.to_out = nn.Sequential(
82
- nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
83
- )
84
- self.attention_op: Optional[Any] = None
85
-
86
- def forward(self, x, context=None):
87
- q = self.to_q(x)
88
- context = default(context, x)
89
-
90
- if self.ip_dim > 0:
91
- # context: [B, 77 + 16(ip), 1024]
92
- token_len = context.shape[1]
93
- context_ip = context[:, -self.ip_dim :, :]
94
- k_ip = self.to_k_ip(context_ip)
95
- v_ip = self.to_v_ip(context_ip)
96
- context = context[:, : (token_len - self.ip_dim), :]
97
-
98
- k = self.to_k(context)
99
- v = self.to_v(context)
100
-
101
- b, _, _ = q.shape
102
- q, k, v = map(
103
- lambda t: t.unsqueeze(3)
104
- .reshape(b, t.shape[1], self.heads, self.dim_head)
105
- .permute(0, 2, 1, 3)
106
- .reshape(b * self.heads, t.shape[1], self.dim_head)
107
- .contiguous(),
108
- (q, k, v),
109
- )
110
-
111
- # actually compute the attention, what we cannot get enough of
112
- out = xformers.ops.memory_efficient_attention(
113
- q, k, v, attn_bias=None, op=self.attention_op
114
- )
115
-
116
- if self.ip_dim > 0:
117
- k_ip, v_ip = map(
118
- lambda t: t.unsqueeze(3)
119
- .reshape(b, t.shape[1], self.heads, self.dim_head)
120
- .permute(0, 2, 1, 3)
121
- .reshape(b * self.heads, t.shape[1], self.dim_head)
122
- .contiguous(),
123
- (k_ip, v_ip),
124
- )
125
- # actually compute the attention, what we cannot get enough of
126
- out_ip = xformers.ops.memory_efficient_attention(
127
- q, k_ip, v_ip, attn_bias=None, op=self.attention_op
128
- )
129
- out = out + self.ip_weight * out_ip
130
-
131
- out = (
132
- out.unsqueeze(0)
133
- .reshape(b, self.heads, out.shape[1], self.dim_head)
134
- .permute(0, 2, 1, 3)
135
- .reshape(b, out.shape[1], self.heads * self.dim_head)
136
- )
137
- return self.to_out(out)
138
-
139
-
140
- class BasicTransformerBlock3D(nn.Module):
141
-
142
- def __init__(
143
- self,
144
- dim,
145
- n_heads,
146
- d_head,
147
- context_dim,
148
- dropout=0.0,
149
- gated_ff=True,
150
- checkpoint=True,
151
- ip_dim=0,
152
- ip_weight=1,
153
- ):
154
- super().__init__()
155
-
156
- self.attn1 = MemoryEfficientCrossAttention(
157
- query_dim=dim,
158
- context_dim=None, # self-attention
159
- heads=n_heads,
160
- dim_head=d_head,
161
- dropout=dropout,
162
- )
163
- self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
164
- self.attn2 = MemoryEfficientCrossAttention(
165
- query_dim=dim,
166
- context_dim=context_dim,
167
- heads=n_heads,
168
- dim_head=d_head,
169
- dropout=dropout,
170
- # ip only applies to cross-attention
171
- ip_dim=ip_dim,
172
- ip_weight=ip_weight,
173
- )
174
- self.norm1 = nn.LayerNorm(dim)
175
- self.norm2 = nn.LayerNorm(dim)
176
- self.norm3 = nn.LayerNorm(dim)
177
- self.checkpoint = checkpoint
178
-
179
- def forward(self, x, context=None, num_frames=1):
180
- return checkpoint(
181
- self._forward, (x, context, num_frames), self.parameters(), self.checkpoint
182
- )
183
-
184
- def _forward(self, x, context=None, num_frames=1):
185
- x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
186
- x = self.attn1(self.norm1(x), context=None) + x
187
- x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
188
- x = self.attn2(self.norm2(x), context=context) + x
189
- x = self.ff(self.norm3(x)) + x
190
- return x
191
-
192
-
193
- class SpatialTransformer3D(nn.Module):
194
-
195
- def __init__(
196
- self,
197
- in_channels,
198
- n_heads,
199
- d_head,
200
- context_dim, # cross attention input dim
201
- depth=1,
202
- dropout=0.0,
203
- ip_dim=0,
204
- ip_weight=1,
205
- use_checkpoint=True,
206
- ):
207
- super().__init__()
208
-
209
- if not isinstance(context_dim, list):
210
- context_dim = [context_dim]
211
-
212
- self.in_channels = in_channels
213
-
214
- inner_dim = n_heads * d_head
215
- self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
216
- self.proj_in = nn.Linear(in_channels, inner_dim)
217
-
218
- self.transformer_blocks = nn.ModuleList(
219
- [
220
- BasicTransformerBlock3D(
221
- inner_dim,
222
- n_heads,
223
- d_head,
224
- context_dim=context_dim[d],
225
- dropout=dropout,
226
- checkpoint=use_checkpoint,
227
- ip_dim=ip_dim,
228
- ip_weight=ip_weight,
229
- )
230
- for d in range(depth)
231
- ]
232
- )
233
-
234
- self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
235
-
236
-
237
- def forward(self, x, context=None, num_frames=1):
238
- # note: if no context is given, cross-attention defaults to self-attention
239
- if not isinstance(context, list):
240
- context = [context]
241
- b, c, h, w = x.shape
242
- x_in = x
243
- x = self.norm(x)
244
- x = rearrange(x, "b c h w -> b (h w) c").contiguous()
245
- x = self.proj_in(x)
246
- for i, block in enumerate(self.transformer_blocks):
247
- x = block(x, context=context[i], num_frames=num_frames)
248
- x = self.proj_out(x)
249
- x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
250
-
251
- return x + x_in
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mvdream/util.py DELETED
@@ -1,140 +0,0 @@
1
- import math
2
- import torch
3
- import torch.nn as nn
4
- import numpy as np
5
- from einops import repeat
6
-
7
- from kiui.cam import orbit_camera
8
-
9
- def get_camera(
10
- num_frames, elevation=15, azimuth_start=0, azimuth_span=360, blender_coord=True, extra_view=False,
11
- ):
12
- angle_gap = azimuth_span / num_frames
13
- cameras = []
14
- for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):
15
-
16
- pose = orbit_camera(-elevation, azimuth, radius=1) # kiui's elevation is negated, [4, 4]
17
-
18
- # opengl to blender
19
- if blender_coord:
20
- pose[2] *= -1
21
- pose[[1, 2]] = pose[[2, 1]]
22
-
23
- cameras.append(pose.flatten())
24
-
25
- if extra_view:
26
- cameras.append(np.zeros_like(cameras[0]))
27
-
28
- return torch.from_numpy(np.stack(cameras, axis=0)).float() # [num_frames, 16]
29
-
30
-
31
- def checkpoint(func, inputs, params, flag):
32
- """
33
- Evaluate a function without caching intermediate activations, allowing for
34
- reduced memory at the expense of extra compute in the backward pass.
35
- :param func: the function to evaluate.
36
- :param inputs: the argument sequence to pass to `func`.
37
- :param params: a sequence of parameters `func` depends on but does not
38
- explicitly take as arguments.
39
- :param flag: if False, disable gradient checkpointing.
40
- """
41
- if flag:
42
- args = tuple(inputs) + tuple(params)
43
- return CheckpointFunction.apply(func, len(inputs), *args)
44
- else:
45
- return func(*inputs)
46
-
47
-
48
- class CheckpointFunction(torch.autograd.Function):
49
- @staticmethod
50
- def forward(ctx, run_function, length, *args):
51
- ctx.run_function = run_function
52
- ctx.input_tensors = list(args[:length])
53
- ctx.input_params = list(args[length:])
54
-
55
- with torch.no_grad():
56
- output_tensors = ctx.run_function(*ctx.input_tensors)
57
- return output_tensors
58
-
59
- @staticmethod
60
- def backward(ctx, *output_grads):
61
- ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
62
- with torch.enable_grad():
63
- # Fixes a bug where the first op in run_function modifies the
64
- # Tensor storage in place, which is not allowed for detach()'d
65
- # Tensors.
66
- shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
67
- output_tensors = ctx.run_function(*shallow_copies)
68
- input_grads = torch.autograd.grad(
69
- output_tensors,
70
- ctx.input_tensors + ctx.input_params,
71
- output_grads,
72
- allow_unused=True,
73
- )
74
- del ctx.input_tensors
75
- del ctx.input_params
76
- del output_tensors
77
- return (None, None) + input_grads
78
-
79
-
80
- def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
81
- """
82
- Create sinusoidal timestep embeddings.
83
- :param timesteps: a 1-D Tensor of N indices, one per batch element.
84
- These may be fractional.
85
- :param dim: the dimension of the output.
86
- :param max_period: controls the minimum frequency of the embeddings.
87
- :return: an [N x dim] Tensor of positional embeddings.
88
- """
89
- if not repeat_only:
90
- half = dim // 2
91
- freqs = torch.exp(
92
- -math.log(max_period)
93
- * torch.arange(start=0, end=half, dtype=torch.float32)
94
- / half
95
- ).to(device=timesteps.device)
96
- args = timesteps[:, None] * freqs[None]
97
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
98
- if dim % 2:
99
- embedding = torch.cat(
100
- [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
101
- )
102
- else:
103
- embedding = repeat(timesteps, "b -> b d", d=dim)
104
- # import pdb; pdb.set_trace()
105
- return embedding
106
-
107
-
108
- def zero_module(module):
109
- """
110
- Zero out the parameters of a module and return it.
111
- """
112
- for p in module.parameters():
113
- p.detach().zero_()
114
- return module
115
-
116
-
117
- def conv_nd(dims, *args, **kwargs):
118
- """
119
- Create a 1D, 2D, or 3D convolution module.
120
- """
121
- if dims == 1:
122
- return nn.Conv1d(*args, **kwargs)
123
- elif dims == 2:
124
- return nn.Conv2d(*args, **kwargs)
125
- elif dims == 3:
126
- return nn.Conv3d(*args, **kwargs)
127
- raise ValueError(f"unsupported dimensions: {dims}")
128
-
129
-
130
- def avg_pool_nd(dims, *args, **kwargs):
131
- """
132
- Create a 1D, 2D, or 3D average pooling module.
133
- """
134
- if dims == 1:
135
- return nn.AvgPool1d(*args, **kwargs)
136
- elif dims == 2:
137
- return nn.AvgPool2d(*args, **kwargs)
138
- elif dims == 3:
139
- return nn.AvgPool3d(*args, **kwargs)
140
- raise ValueError(f"unsupported dimensions: {dims}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mvdream/pipeline_mvdream.py → pipeline_mvdream.py RENAMED
@@ -15,8 +15,7 @@ from diffusers.configuration_utils import FrozenDict
15
  from diffusers.schedulers import DDIMScheduler
16
  from diffusers.utils.torch_utils import randn_tensor
17
 
18
- from .models import MultiViewUNetModel
19
- from .util import get_camera
20
 
21
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
22
 
 
15
  from diffusers.schedulers import DDIMScheduler
16
  from diffusers.utils.torch_utils import randn_tensor
17
 
18
+ from mv_unet import MultiViewUNetModel, get_camera
 
19
 
20
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
21
 
run_imagedream.py CHANGED
@@ -2,12 +2,13 @@ import torch
2
  import kiui
3
  import numpy as np
4
  import argparse
5
- from mvdream.pipeline_mvdream import MVDreamPipeline
6
 
7
  pipe = MVDreamPipeline.from_pretrained(
8
  "./weights_imagedream", # local weights
9
  # "ashawkey/mvdream-sd2.1-diffusers",
10
- torch_dtype=torch.float16
 
11
  )
12
  pipe = pipe.to("cuda")
13
 
 
2
  import kiui
3
  import numpy as np
4
  import argparse
5
+ from pipeline_mvdream import MVDreamPipeline
6
 
7
  pipe = MVDreamPipeline.from_pretrained(
8
  "./weights_imagedream", # local weights
9
  # "ashawkey/mvdream-sd2.1-diffusers",
10
+ torch_dtype=torch.float16,
11
+ trust_remote_code=True,
12
  )
13
  pipe = pipe.to("cuda")
14
 
run_mvdream.py CHANGED
@@ -2,7 +2,7 @@ import torch
2
  import kiui
3
  import numpy as np
4
  import argparse
5
- from mvdream.pipeline_mvdream import MVDreamPipeline
6
 
7
  pipe = MVDreamPipeline.from_pretrained(
8
  "./weights_mvdream", # local weights
 
2
  import kiui
3
  import numpy as np
4
  import argparse
5
+ from pipeline_mvdream import MVDreamPipeline
6
 
7
  pipe = MVDreamPipeline.from_pretrained(
8
  "./weights_mvdream", # local weights