Koke_Cacao commited on
Commit
0974835
·
1 Parent(s): 5b08d3b

:sparkles: add models

Browse files
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -1,5 +1,4 @@
1
  *.pt
2
  *.yaml
3
- converted
4
  __pycache__
5
  *.png
 
1
  *.pt
2
  *.yaml
 
3
  __pycache__
4
  *.png
model_index.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "MVDreamStableDiffusionPipeline",
3
+ "_diffusers_version": "0.21.4",
4
+ "requires_safety_checker": false,
5
+ "scheduler": [
6
+ "diffusers",
7
+ "DDIMScheduler"
8
+ ],
9
+ "text_encoder": [
10
+ "transformers",
11
+ "CLIPTextModel"
12
+ ],
13
+ "tokenizer": [
14
+ "transformers",
15
+ "CLIPTokenizer"
16
+ ],
17
+ "unet": [
18
+ "models",
19
+ "MultiViewUNetWrapperModel"
20
+ ],
21
+ "vae": [
22
+ "diffusers",
23
+ "AutoencoderKL"
24
+ ]
25
+ }
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "DDIMScheduler",
3
+ "_diffusers_version": "0.21.4",
4
+ "beta_end": 0.012,
5
+ "beta_schedule": "scaled_linear",
6
+ "beta_start": 0.00085,
7
+ "clip_sample": false,
8
+ "clip_sample_range": 1.0,
9
+ "dynamic_thresholding_ratio": 0.995,
10
+ "num_train_timesteps": 1000,
11
+ "prediction_type": "epsilon",
12
+ "rescale_betas_zero_snr": false,
13
+ "sample_max_value": 1.0,
14
+ "set_alpha_to_one": false,
15
+ "steps_offset": 1,
16
+ "thresholding": false,
17
+ "timestep_spacing": "leading",
18
+ "trained_betas": null
19
+ }
scripts/README.md CHANGED
@@ -1,6 +1,6 @@
1
  # Convert original weights to diffusers
2
 
3
- Download original MVDream checkpoint under `ckpts` through one of the following sources:
4
 
5
  ```bash
6
  # for sd-v1.5 (recommended for production)
@@ -14,5 +14,5 @@ wget https://raw.githubusercontent.com/bytedance/MVDream/main/mvdream/configs/sd
14
 
15
  Hugging Face diffusers weights are converted by script:
16
  ```bash
17
- python ./scripts/convert_mvdream_to_diffusers.py --checkpoint_path ./sd-v1.5-4view.pt --dump_path ./converted --original_config_file ./sd-v1.yaml
18
  ```
 
1
  # Convert original weights to diffusers
2
 
3
+ Download original MVDream checkpoint through one of the following sources:
4
 
5
  ```bash
6
  # for sd-v1.5 (recommended for production)
 
14
 
15
  Hugging Face diffusers weights are converted by script:
16
  ```bash
17
+ python ./scripts/convert_mvdream_to_diffusers.py --checkpoint_path ./sd-v1.5-4view.pt --dump_path . --original_config_file ./sd-v1.yaml
18
  ```
scripts/attention.py CHANGED
@@ -1,11 +1,14 @@
1
- from inspect import isfunction
 
2
  import math
3
  import torch
4
  import torch.nn.functional as F
 
 
5
  from torch import nn, einsum
 
6
  from einops import rearrange, repeat
7
  from typing import Optional, Any
8
-
9
  from util import checkpoint
10
 
11
 
@@ -20,16 +23,13 @@ except:
20
  import os
21
  _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
22
 
23
- def exists(val):
24
- return val is not None
25
-
26
 
27
  def uniq(arr):
28
  return{el: True for el in arr}.keys()
29
 
30
 
31
  def default(val, d):
32
- if exists(val):
33
  return val
34
  return d() if isfunction(d) else d
35
 
@@ -172,7 +172,7 @@ class CrossAttention(nn.Module):
172
 
173
  # force cast to fp32 to avoid overflowing
174
  if _ATTN_PRECISION =="fp32":
175
- with torch.autocast(enabled=False, device_type = 'cuda'):
176
  q, k = q.float(), k.float()
177
  sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
178
  else:
@@ -180,7 +180,7 @@ class CrossAttention(nn.Module):
180
 
181
  del q, k
182
 
183
- if exists(mask):
184
  mask = rearrange(mask, 'b ... -> b (...)')
185
  max_neg_value = -torch.finfo(sim.dtype).max
186
  mask = repeat(mask, 'b j -> (b h) () j', h=h)
@@ -232,7 +232,7 @@ class MemoryEfficientCrossAttention(nn.Module):
232
  # actually compute the attention, what we cannot get enough of
233
  out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
234
 
235
- if exists(mask):
236
  raise NotImplementedError
237
  out = (
238
  out.unsqueeze(0)
@@ -289,7 +289,8 @@ class SpatialTransformer(nn.Module):
289
  disable_self_attn=False, use_linear=False,
290
  use_checkpoint=True):
291
  super().__init__()
292
- if exists(context_dim) and not isinstance(context_dim, list):
 
293
  context_dim = [context_dim]
294
  self.in_channels = in_channels
295
  inner_dim = n_heads * d_head
@@ -361,7 +362,8 @@ class SpatialTransformer3D(nn.Module):
361
  disable_self_attn=False, use_linear=False,
362
  use_checkpoint=True):
363
  super().__init__()
364
- if exists(context_dim) and not isinstance(context_dim, list):
 
365
  context_dim = [context_dim]
366
  self.in_channels = in_channels
367
  inner_dim = n_heads * d_head
 
1
+ # obtained and modified from https://github.com/bytedance/MVDream
2
+
3
  import math
4
  import torch
5
  import torch.nn.functional as F
6
+
7
+ from inspect import isfunction
8
  from torch import nn, einsum
9
+ from torch.amp.autocast_mode import autocast
10
  from einops import rearrange, repeat
11
  from typing import Optional, Any
 
12
  from util import checkpoint
13
 
14
 
 
23
  import os
24
  _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
25
 
 
 
 
26
 
27
  def uniq(arr):
28
  return{el: True for el in arr}.keys()
29
 
30
 
31
  def default(val, d):
32
+ if val is not None:
33
  return val
34
  return d() if isfunction(d) else d
35
 
 
172
 
173
  # force cast to fp32 to avoid overflowing
174
  if _ATTN_PRECISION =="fp32":
175
+ with autocast(enabled=False, device_type = 'cuda'):
176
  q, k = q.float(), k.float()
177
  sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
178
  else:
 
180
 
181
  del q, k
182
 
183
+ if mask is not None:
184
  mask = rearrange(mask, 'b ... -> b (...)')
185
  max_neg_value = -torch.finfo(sim.dtype).max
186
  mask = repeat(mask, 'b j -> (b h) () j', h=h)
 
232
  # actually compute the attention, what we cannot get enough of
233
  out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
234
 
235
+ if mask is not None:
236
  raise NotImplementedError
237
  out = (
238
  out.unsqueeze(0)
 
289
  disable_self_attn=False, use_linear=False,
290
  use_checkpoint=True):
291
  super().__init__()
292
+ assert context_dim is not None
293
+ if not isinstance(context_dim, list):
294
  context_dim = [context_dim]
295
  self.in_channels = in_channels
296
  inner_dim = n_heads * d_head
 
362
  disable_self_attn=False, use_linear=False,
363
  use_checkpoint=True):
364
  super().__init__()
365
+ assert context_dim is not None
366
+ if not isinstance(context_dim, list):
367
  context_dim = [context_dim]
368
  self.in_channels = in_channels
369
  inner_dim = n_heads * d_head
scripts/models.py CHANGED
@@ -1,12 +1,12 @@
1
- from abc import abstractmethod
2
- import math
3
- from typing import Any, Mapping
4
 
 
5
  import numpy as np
6
  import torch as th
7
  import torch.nn as nn
8
  import torch.nn.functional as F
9
 
 
10
  from util import (
11
  checkpoint,
12
  conv_nd,
@@ -16,58 +16,32 @@ from util import (
16
  normalization,
17
  timestep_embedding,
18
  )
19
- from attention import SpatialTransformer, SpatialTransformer3D, exists
20
-
21
-
22
  from diffusers.configuration_utils import ConfigMixin
23
  from diffusers.models.modeling_utils import ModelMixin
 
 
 
 
24
  class MultiViewUNetWrapperModel(ModelMixin, ConfigMixin):
 
25
  def __init__(self, *args, **kwargs):
26
  super().__init__()
27
  self.unet: MultiViewUNetModel = MultiViewUNetModel(*args, **kwargs)
28
-
29
  def forward(self, *args, **kwargs):
30
  return self.unet(*args, **kwargs)
31
 
 
32
  # dummy replace
33
  def convert_module_to_f16(x):
34
  pass
35
 
 
36
  def convert_module_to_f32(x):
37
  pass
38
 
39
 
40
- ## go
41
- class AttentionPool2d(nn.Module):
42
- """
43
- Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
44
- """
45
-
46
- def __init__(
47
- self,
48
- spacial_dim: int,
49
- embed_dim: int,
50
- num_heads_channels: int,
51
- output_dim: int = None,
52
- ):
53
- super().__init__()
54
- self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
55
- self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
56
- self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
57
- self.num_heads = embed_dim // num_heads_channels
58
- self.attention = QKVAttention(self.num_heads)
59
-
60
- def forward(self, x):
61
- b, c, *_spatial = x.shape
62
- x = x.reshape(b, c, -1) # NC(HW)
63
- x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
64
- x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
65
- x = self.qkv_proj(x)
66
- x = self.attention(x)
67
- x = self.c_proj(x)
68
- return x[:, :, 0]
69
-
70
-
71
  class TimestepBlock(nn.Module):
72
  """
73
  Any module where forward() takes timestep embeddings as a second argument.
@@ -108,39 +82,35 @@ class Upsample(nn.Module):
108
  upsampling occurs in the inner-two dimensions.
109
  """
110
 
111
- def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
 
 
 
 
 
112
  super().__init__()
113
  self.channels = channels
114
  self.out_channels = out_channels or channels
115
  self.use_conv = use_conv
116
  self.dims = dims
117
  if use_conv:
118
- self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
 
 
 
 
119
 
120
  def forward(self, x):
121
  assert x.shape[1] == self.channels
122
  if self.dims == 3:
123
- x = F.interpolate(
124
- x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
125
- )
126
  else:
127
  x = F.interpolate(x, scale_factor=2, mode="nearest")
128
  if self.use_conv:
129
  x = self.conv(x)
130
  return x
131
 
132
- class TransposedUpsample(nn.Module):
133
- 'Learned 2x upsampling without padding'
134
- def __init__(self, channels, out_channels=None, ks=5):
135
- super().__init__()
136
- self.channels = channels
137
- self.out_channels = out_channels or channels
138
-
139
- self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
140
-
141
- def forward(self,x):
142
- return self.up(x)
143
-
144
 
145
  class Downsample(nn.Module):
146
  """
@@ -151,7 +121,12 @@ class Downsample(nn.Module):
151
  downsampling occurs in the inner-two dimensions.
152
  """
153
 
154
- def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
 
 
 
 
 
155
  super().__init__()
156
  self.channels = channels
157
  self.out_channels = out_channels or channels
@@ -159,9 +134,12 @@ class Downsample(nn.Module):
159
  self.dims = dims
160
  stride = 2 if dims != 3 else (1, 2, 2)
161
  if use_conv:
162
- self.op = conv_nd(
163
- dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
164
- )
 
 
 
165
  else:
166
  assert self.channels == self.out_channels
167
  self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
@@ -230,7 +208,8 @@ class ResBlock(TimestepBlock):
230
  nn.SiLU(),
231
  linear(
232
  emb_channels,
233
- 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
 
234
  ),
235
  )
236
  self.out_layers = nn.Sequential(
@@ -238,18 +217,24 @@ class ResBlock(TimestepBlock):
238
  nn.SiLU(),
239
  nn.Dropout(p=dropout),
240
  zero_module(
241
- conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
242
- ),
 
 
 
243
  )
244
 
245
  if self.out_channels == channels:
246
  self.skip_connection = nn.Identity()
247
  elif use_conv:
248
- self.skip_connection = conv_nd(
249
- dims, channels, self.out_channels, 3, padding=1
250
- )
 
 
251
  else:
252
- self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
 
253
 
254
  def forward(self, x, emb):
255
  """
@@ -258,10 +243,8 @@ class ResBlock(TimestepBlock):
258
  :param emb: an [N x emb_channels] Tensor of timestep embeddings.
259
  :return: an [N x C x ...] Tensor of outputs.
260
  """
261
- return checkpoint(
262
- self._forward, (x, emb), self.parameters(), self.use_checkpoint
263
- )
264
-
265
 
266
  def _forward(self, x, emb):
267
  if self.updown:
@@ -323,7 +306,9 @@ class AttentionBlock(nn.Module):
323
  self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
324
 
325
  def forward(self, x):
326
- return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
 
 
327
  #return pt_checkpoint(self._forward, x) # pytorch
328
 
329
  def _forward(self, x):
@@ -351,7 +336,7 @@ def count_flops_attn(model, _x, y):
351
  # We perform two matmuls with the same number of ops.
352
  # The first computes the weight matrix, the second computes
353
  # the combination of the value vectors.
354
- matmul_ops = 2 * b * (num_spatial ** 2) * c
355
  model.total_ops += th.DoubleTensor([matmul_ops])
356
 
357
 
@@ -373,11 +358,12 @@ class QKVAttentionLegacy(nn.Module):
373
  bs, width, length = qkv.shape
374
  assert width % (3 * self.n_heads) == 0
375
  ch = width // (3 * self.n_heads)
376
- q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
 
377
  scale = 1 / math.sqrt(math.sqrt(ch))
378
  weight = th.einsum(
379
- "bct,bcs->bts", q * scale, k * scale
380
- ) # More stable with f16 than dividing afterwards
381
  weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
382
  a = th.einsum("bts,bcs->bct", weight, v)
383
  return a.reshape(bs, -1, length)
@@ -413,7 +399,8 @@ class QKVAttention(nn.Module):
413
  (k * scale).view(bs * self.n_heads, ch, length),
414
  ) # More stable with f16 than dividing afterwards
415
  weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
416
- a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
 
417
  return a.reshape(bs, -1, length)
418
 
419
  @staticmethod
@@ -422,6 +409,7 @@ class QKVAttention(nn.Module):
422
 
423
 
424
  class Timestep(nn.Module):
 
425
  def __init__(self, dim):
426
  super().__init__()
427
  self.dim = dim
@@ -430,395 +418,6 @@ class Timestep(nn.Module):
430
  return timestep_embedding(t, self.dim)
431
 
432
 
433
- class UNetModel(nn.Module):
434
- """
435
- The full UNet model with attention and timestep embedding.
436
- :param in_channels: channels in the input Tensor.
437
- :param model_channels: base channel count for the model.
438
- :param out_channels: channels in the output Tensor.
439
- :param num_res_blocks: number of residual blocks per downsample.
440
- :param attention_resolutions: a collection of downsample rates at which
441
- attention will take place. May be a set, list, or tuple.
442
- For example, if this contains 4, then at 4x downsampling, attention
443
- will be used.
444
- :param dropout: the dropout probability.
445
- :param channel_mult: channel multiplier for each level of the UNet.
446
- :param conv_resample: if True, use learned convolutions for upsampling and
447
- downsampling.
448
- :param dims: determines if the signal is 1D, 2D, or 3D.
449
- :param num_classes: if specified (as an int), then this model will be
450
- class-conditional with `num_classes` classes.
451
- :param use_checkpoint: use gradient checkpointing to reduce memory usage.
452
- :param num_heads: the number of attention heads in each attention layer.
453
- :param num_heads_channels: if specified, ignore num_heads and instead use
454
- a fixed channel width per attention head.
455
- :param num_heads_upsample: works with num_heads to set a different number
456
- of heads for upsampling. Deprecated.
457
- :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
458
- :param resblock_updown: use residual blocks for up/downsampling.
459
- :param use_new_attention_order: use a different attention pattern for potentially
460
- increased efficiency.
461
- """
462
-
463
- def __init__(
464
- self,
465
- image_size,
466
- in_channels,
467
- model_channels,
468
- out_channels,
469
- num_res_blocks,
470
- attention_resolutions,
471
- dropout=0,
472
- channel_mult=(1, 2, 4, 8),
473
- conv_resample=True,
474
- dims=2,
475
- num_classes=None,
476
- use_checkpoint=False,
477
- use_fp16=False,
478
- use_bf16=False,
479
- num_heads=-1,
480
- num_head_channels=-1,
481
- num_heads_upsample=-1,
482
- use_scale_shift_norm=False,
483
- resblock_updown=False,
484
- use_new_attention_order=False,
485
- use_spatial_transformer=False, # custom transformer support
486
- transformer_depth=1, # custom transformer support
487
- context_dim=None, # custom transformer support
488
- n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
489
- legacy=True,
490
- disable_self_attentions=None,
491
- num_attention_blocks=None,
492
- disable_middle_self_attn=False,
493
- use_linear_in_transformer=False,
494
- adm_in_channels=None,
495
- ):
496
- super().__init__()
497
- if use_spatial_transformer:
498
- assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
499
-
500
- if context_dim is not None:
501
- assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
502
- from omegaconf.listconfig import ListConfig
503
- if type(context_dim) == ListConfig:
504
- context_dim = list(context_dim)
505
-
506
- if num_heads_upsample == -1:
507
- num_heads_upsample = num_heads
508
-
509
- if num_heads == -1:
510
- assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
511
-
512
- if num_head_channels == -1:
513
- assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
514
-
515
- self.image_size = image_size
516
- self.in_channels = in_channels
517
- self.model_channels = model_channels
518
- self.out_channels = out_channels
519
- if isinstance(num_res_blocks, int):
520
- self.num_res_blocks = len(channel_mult) * [num_res_blocks]
521
- else:
522
- if len(num_res_blocks) != len(channel_mult):
523
- raise ValueError("provide num_res_blocks either as an int (globally constant) or "
524
- "as a list/tuple (per-level) with the same length as channel_mult")
525
- self.num_res_blocks = num_res_blocks
526
- if disable_self_attentions is not None:
527
- # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
528
- assert len(disable_self_attentions) == len(channel_mult)
529
- if num_attention_blocks is not None:
530
- assert len(num_attention_blocks) == len(self.num_res_blocks)
531
- assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
532
- print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
533
- f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
534
- f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
535
- f"attention will still not be set.")
536
-
537
- self.attention_resolutions = attention_resolutions
538
- self.dropout = dropout
539
- self.channel_mult = channel_mult
540
- self.conv_resample = conv_resample
541
- self.num_classes = num_classes
542
- self.use_checkpoint = use_checkpoint
543
- self.dtype = th.float16 if use_fp16 else th.float32
544
- self.dtype = th.bfloat16 if use_bf16 else self.dtype
545
- self.num_heads = num_heads
546
- self.num_head_channels = num_head_channels
547
- self.num_heads_upsample = num_heads_upsample
548
- self.predict_codebook_ids = n_embed is not None
549
-
550
- time_embed_dim = model_channels * 4
551
- self.time_embed = nn.Sequential(
552
- linear(model_channels, time_embed_dim),
553
- nn.SiLU(),
554
- linear(time_embed_dim, time_embed_dim),
555
- )
556
-
557
- if self.num_classes is not None:
558
- if isinstance(self.num_classes, int):
559
- self.label_emb = nn.Embedding(num_classes, time_embed_dim)
560
- elif self.num_classes == "continuous":
561
- print("setting up linear c_adm embedding layer")
562
- self.label_emb = nn.Linear(1, time_embed_dim)
563
- elif self.num_classes == "sequential":
564
- assert adm_in_channels is not None
565
- self.label_emb = nn.Sequential(
566
- nn.Sequential(
567
- linear(adm_in_channels, time_embed_dim),
568
- nn.SiLU(),
569
- linear(time_embed_dim, time_embed_dim),
570
- )
571
- )
572
- else:
573
- raise ValueError()
574
-
575
- self.input_blocks = nn.ModuleList(
576
- [
577
- TimestepEmbedSequential(
578
- conv_nd(dims, in_channels, model_channels, 3, padding=1)
579
- )
580
- ]
581
- )
582
- self._feature_size = model_channels
583
- input_block_chans = [model_channels]
584
- ch = model_channels
585
- ds = 1
586
- for level, mult in enumerate(channel_mult):
587
- for nr in range(self.num_res_blocks[level]):
588
- layers = [
589
- ResBlock(
590
- ch,
591
- time_embed_dim,
592
- dropout,
593
- out_channels=mult * model_channels,
594
- dims=dims,
595
- use_checkpoint=use_checkpoint,
596
- use_scale_shift_norm=use_scale_shift_norm,
597
- )
598
- ]
599
- ch = mult * model_channels
600
- if ds in attention_resolutions:
601
- if num_head_channels == -1:
602
- dim_head = ch // num_heads
603
- else:
604
- num_heads = ch // num_head_channels
605
- dim_head = num_head_channels
606
- if legacy:
607
- #num_heads = 1
608
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
609
- if exists(disable_self_attentions):
610
- disabled_sa = disable_self_attentions[level]
611
- else:
612
- disabled_sa = False
613
-
614
- if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
615
- layers.append(
616
- AttentionBlock(
617
- ch,
618
- use_checkpoint=use_checkpoint,
619
- num_heads=num_heads,
620
- num_head_channels=dim_head,
621
- use_new_attention_order=use_new_attention_order,
622
- ) if not use_spatial_transformer else SpatialTransformer(
623
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
624
- disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
625
- use_checkpoint=use_checkpoint
626
- )
627
- )
628
- self.input_blocks.append(TimestepEmbedSequential(*layers))
629
- self._feature_size += ch
630
- input_block_chans.append(ch)
631
- if level != len(channel_mult) - 1:
632
- out_ch = ch
633
- self.input_blocks.append(
634
- TimestepEmbedSequential(
635
- ResBlock(
636
- ch,
637
- time_embed_dim,
638
- dropout,
639
- out_channels=out_ch,
640
- dims=dims,
641
- use_checkpoint=use_checkpoint,
642
- use_scale_shift_norm=use_scale_shift_norm,
643
- down=True,
644
- )
645
- if resblock_updown
646
- else Downsample(
647
- ch, conv_resample, dims=dims, out_channels=out_ch
648
- )
649
- )
650
- )
651
- ch = out_ch
652
- input_block_chans.append(ch)
653
- ds *= 2
654
- self._feature_size += ch
655
-
656
- if num_head_channels == -1:
657
- dim_head = ch // num_heads
658
- else:
659
- num_heads = ch // num_head_channels
660
- dim_head = num_head_channels
661
- if legacy:
662
- #num_heads = 1
663
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
664
- self.middle_block = TimestepEmbedSequential(
665
- ResBlock(
666
- ch,
667
- time_embed_dim,
668
- dropout,
669
- dims=dims,
670
- use_checkpoint=use_checkpoint,
671
- use_scale_shift_norm=use_scale_shift_norm,
672
- ),
673
- AttentionBlock(
674
- ch,
675
- use_checkpoint=use_checkpoint,
676
- num_heads=num_heads,
677
- num_head_channels=dim_head,
678
- use_new_attention_order=use_new_attention_order,
679
- ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
680
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
681
- disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
682
- use_checkpoint=use_checkpoint
683
- ),
684
- ResBlock(
685
- ch,
686
- time_embed_dim,
687
- dropout,
688
- dims=dims,
689
- use_checkpoint=use_checkpoint,
690
- use_scale_shift_norm=use_scale_shift_norm,
691
- ),
692
- )
693
- self._feature_size += ch
694
-
695
- self.output_blocks = nn.ModuleList([])
696
- for level, mult in list(enumerate(channel_mult))[::-1]:
697
- for i in range(self.num_res_blocks[level] + 1):
698
- ich = input_block_chans.pop()
699
- layers = [
700
- ResBlock(
701
- ch + ich,
702
- time_embed_dim,
703
- dropout,
704
- out_channels=model_channels * mult,
705
- dims=dims,
706
- use_checkpoint=use_checkpoint,
707
- use_scale_shift_norm=use_scale_shift_norm,
708
- )
709
- ]
710
- ch = model_channels * mult
711
- if ds in attention_resolutions:
712
- if num_head_channels == -1:
713
- dim_head = ch // num_heads
714
- else:
715
- num_heads = ch // num_head_channels
716
- dim_head = num_head_channels
717
- if legacy:
718
- #num_heads = 1
719
- dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
720
- if exists(disable_self_attentions):
721
- disabled_sa = disable_self_attentions[level]
722
- else:
723
- disabled_sa = False
724
-
725
- if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
726
- layers.append(
727
- AttentionBlock(
728
- ch,
729
- use_checkpoint=use_checkpoint,
730
- num_heads=num_heads_upsample,
731
- num_head_channels=dim_head,
732
- use_new_attention_order=use_new_attention_order,
733
- ) if not use_spatial_transformer else SpatialTransformer(
734
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
735
- disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
736
- use_checkpoint=use_checkpoint
737
- )
738
- )
739
- if level and i == self.num_res_blocks[level]:
740
- out_ch = ch
741
- layers.append(
742
- ResBlock(
743
- ch,
744
- time_embed_dim,
745
- dropout,
746
- out_channels=out_ch,
747
- dims=dims,
748
- use_checkpoint=use_checkpoint,
749
- use_scale_shift_norm=use_scale_shift_norm,
750
- up=True,
751
- )
752
- if resblock_updown
753
- else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
754
- )
755
- ds //= 2
756
- self.output_blocks.append(TimestepEmbedSequential(*layers))
757
- self._feature_size += ch
758
-
759
- self.out = nn.Sequential(
760
- normalization(ch),
761
- nn.SiLU(),
762
- zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
763
- )
764
- if self.predict_codebook_ids:
765
- self.id_predictor = nn.Sequential(
766
- normalization(ch),
767
- conv_nd(dims, model_channels, n_embed, 1),
768
- #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
769
- )
770
-
771
- def convert_to_fp16(self):
772
- """
773
- Convert the torso of the model to float16.
774
- """
775
- self.input_blocks.apply(convert_module_to_f16)
776
- self.middle_block.apply(convert_module_to_f16)
777
- self.output_blocks.apply(convert_module_to_f16)
778
-
779
- def convert_to_fp32(self):
780
- """
781
- Convert the torso of the model to float32.
782
- """
783
- self.input_blocks.apply(convert_module_to_f32)
784
- self.middle_block.apply(convert_module_to_f32)
785
- self.output_blocks.apply(convert_module_to_f32)
786
-
787
- def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
788
- """
789
- Apply the model to an input batch.
790
- :param x: an [N x C x ...] Tensor of inputs.
791
- :param timesteps: a 1-D batch of timesteps.
792
- :param context: conditioning plugged in via crossattn
793
- :param y: an [N] Tensor of labels, if class-conditional.
794
- :return: an [N x C x ...] Tensor of outputs.
795
- """
796
- assert (y is not None) == (
797
- self.num_classes is not None
798
- ), "must specify y if and only if the model is class-conditional"
799
- hs = []
800
- t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
801
- emb = self.time_embed(t_emb)
802
-
803
- if self.num_classes is not None:
804
- assert y.shape[0] == x.shape[0]
805
- emb = emb + self.label_emb(y)
806
-
807
- h = x.type(self.dtype)
808
- for module in self.input_blocks:
809
- h = module(h, emb, context)
810
- hs.append(h)
811
- h = self.middle_block(h, emb, context)
812
- for module in self.output_blocks:
813
- h = th.cat([h, hs.pop()], dim=1)
814
- h = module(h, emb, context)
815
- h = h.type(x.dtype)
816
- if self.predict_codebook_ids:
817
- return self.id_predictor(h)
818
- else:
819
- return self.out(h)
820
-
821
-
822
  class MultiViewUNetModel(nn.Module):
823
  """
824
  The full multi-view UNet model with attention, timestep embedding and camera embedding.
@@ -872,10 +471,10 @@ class MultiViewUNetModel(nn.Module):
872
  use_scale_shift_norm=False,
873
  resblock_updown=False,
874
  use_new_attention_order=False,
875
- use_spatial_transformer=False, # custom transformer support
876
- transformer_depth=1, # custom transformer support
877
- context_dim=None, # custom transformer support
878
- n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
879
  legacy=True,
880
  disable_self_attentions=None,
881
  num_attention_blocks=None,
@@ -885,6 +484,7 @@ class MultiViewUNetModel(nn.Module):
885
  camera_dim=None,
886
  ):
887
  super().__init__()
 
888
  if use_spatial_transformer:
889
  assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
890
 
@@ -911,19 +511,26 @@ class MultiViewUNetModel(nn.Module):
911
  self.num_res_blocks = len(channel_mult) * [num_res_blocks]
912
  else:
913
  if len(num_res_blocks) != len(channel_mult):
914
- raise ValueError("provide num_res_blocks either as an int (globally constant) or "
915
- "as a list/tuple (per-level) with the same length as channel_mult")
 
 
916
  self.num_res_blocks = num_res_blocks
917
  if disable_self_attentions is not None:
918
  # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
919
  assert len(disable_self_attentions) == len(channel_mult)
920
  if num_attention_blocks is not None:
921
  assert len(num_attention_blocks) == len(self.num_res_blocks)
922
- assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
923
- print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
924
- f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
925
- f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
926
- f"attention will still not be set.")
 
 
 
 
 
927
 
928
  self.attention_resolutions = attention_resolutions
929
  self.dropout = dropout
@@ -966,25 +573,21 @@ class MultiViewUNetModel(nn.Module):
966
  linear(adm_in_channels, time_embed_dim),
967
  nn.SiLU(),
968
  linear(time_embed_dim, time_embed_dim),
969
- )
970
- )
971
  else:
972
  raise ValueError()
973
 
974
- self.input_blocks = nn.ModuleList(
975
- [
976
- TimestepEmbedSequential(
977
- conv_nd(dims, in_channels, model_channels, 3, padding=1)
978
- )
979
- ]
980
- )
981
  self._feature_size = model_channels
982
  input_block_chans = [model_channels]
983
  ch = model_channels
984
  ds = 1
985
  for level, mult in enumerate(channel_mult):
986
  for nr in range(self.num_res_blocks[level]):
987
- layers = [
988
  ResBlock(
989
  ch,
990
  time_embed_dim,
@@ -1005,12 +608,13 @@ class MultiViewUNetModel(nn.Module):
1005
  if legacy:
1006
  #num_heads = 1
1007
  dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
1008
- if exists(disable_self_attentions):
1009
  disabled_sa = disable_self_attentions[level]
1010
  else:
1011
  disabled_sa = False
1012
 
1013
- if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
 
1014
  layers.append(
1015
  AttentionBlock(
1016
  ch,
@@ -1018,12 +622,16 @@ class MultiViewUNetModel(nn.Module):
1018
  num_heads=num_heads,
1019
  num_head_channels=dim_head,
1020
  use_new_attention_order=use_new_attention_order,
1021
- ) if not use_spatial_transformer else SpatialTransformer3D(
1022
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
1023
- disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
1024
- use_checkpoint=use_checkpoint
1025
- )
1026
- )
 
 
 
 
1027
  self.input_blocks.append(TimestepEmbedSequential(*layers))
1028
  self._feature_size += ch
1029
  input_block_chans.append(ch)
@@ -1040,12 +648,8 @@ class MultiViewUNetModel(nn.Module):
1040
  use_checkpoint=use_checkpoint,
1041
  use_scale_shift_norm=use_scale_shift_norm,
1042
  down=True,
1043
- )
1044
- if resblock_updown
1045
- else Downsample(
1046
- ch, conv_resample, dims=dims, out_channels=out_ch
1047
- )
1048
- )
1049
  )
1050
  ch = out_ch
1051
  input_block_chans.append(ch)
@@ -1075,11 +679,16 @@ class MultiViewUNetModel(nn.Module):
1075
  num_heads=num_heads,
1076
  num_head_channels=dim_head,
1077
  use_new_attention_order=use_new_attention_order,
1078
- ) if not use_spatial_transformer else SpatialTransformer3D( # always uses a self-attn
1079
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
1080
- disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
1081
- use_checkpoint=use_checkpoint
1082
- ),
 
 
 
 
 
1083
  ResBlock(
1084
  ch,
1085
  time_embed_dim,
@@ -1116,12 +725,13 @@ class MultiViewUNetModel(nn.Module):
1116
  if legacy:
1117
  #num_heads = 1
1118
  dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
1119
- if exists(disable_self_attentions):
1120
  disabled_sa = disable_self_attentions[level]
1121
  else:
1122
  disabled_sa = False
1123
 
1124
- if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
 
1125
  layers.append(
1126
  AttentionBlock(
1127
  ch,
@@ -1129,12 +739,16 @@ class MultiViewUNetModel(nn.Module):
1129
  num_heads=num_heads_upsample,
1130
  num_head_channels=dim_head,
1131
  use_new_attention_order=use_new_attention_order,
1132
- ) if not use_spatial_transformer else SpatialTransformer3D(
1133
- ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
1134
- disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
1135
- use_checkpoint=use_checkpoint
1136
- )
1137
- )
 
 
 
 
1138
  if level and i == self.num_res_blocks[level]:
1139
  out_ch = ch
1140
  layers.append(
@@ -1147,10 +761,8 @@ class MultiViewUNetModel(nn.Module):
1147
  use_checkpoint=use_checkpoint,
1148
  use_scale_shift_norm=use_scale_shift_norm,
1149
  up=True,
1150
- )
1151
- if resblock_updown
1152
- else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
1153
- )
1154
  ds //= 2
1155
  self.output_blocks.append(TimestepEmbedSequential(*layers))
1156
  self._feature_size += ch
@@ -1158,14 +770,15 @@ class MultiViewUNetModel(nn.Module):
1158
  self.out = nn.Sequential(
1159
  normalization(ch),
1160
  nn.SiLU(),
1161
- zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
 
1162
  )
1163
  if self.predict_codebook_ids:
1164
  self.id_predictor = nn.Sequential(
1165
- normalization(ch),
1166
- conv_nd(dims, model_channels, n_embed, 1),
1167
- #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
1168
- )
1169
 
1170
  def convert_to_fp16(self):
1171
  """
@@ -1183,7 +796,14 @@ class MultiViewUNetModel(nn.Module):
1183
  self.middle_block.apply(convert_module_to_f32)
1184
  self.output_blocks.apply(convert_module_to_f32)
1185
 
1186
- def forward(self, x, timesteps=None, context=None, y=None, camera=None, num_frames=1, **kwargs):
 
 
 
 
 
 
 
1187
  """
1188
  Apply the model to an input batch.
1189
  :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
@@ -1193,15 +813,19 @@ class MultiViewUNetModel(nn.Module):
1193
  :param num_frames: a integer indicating number of frames for tensor reshaping.
1194
  :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
1195
  """
1196
- assert x.shape[0] % num_frames == 0, "[UNet] input batch size must be dividable by num_frames!"
 
1197
  assert (y is not None) == (
1198
  self.num_classes is not None
1199
  ), "must specify y if and only if the model is class-conditional"
1200
  hs = []
1201
- t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
 
 
1202
  emb = self.time_embed(t_emb)
1203
 
1204
  if self.num_classes is not None:
 
1205
  assert y.shape[0] == x.shape[0]
1206
  emb = emb + self.label_emb(y)
1207
 
@@ -1222,4 +846,4 @@ class MultiViewUNetModel(nn.Module):
1222
  if self.predict_codebook_ids:
1223
  return self.id_predictor(h)
1224
  else:
1225
- return self.out(h)
 
1
+ # obtained and modified from https://github.com/bytedance/MVDream
 
 
2
 
3
+ import math
4
  import numpy as np
5
  import torch as th
6
  import torch.nn as nn
7
  import torch.nn.functional as F
8
 
9
+ from abc import abstractmethod
10
  from util import (
11
  checkpoint,
12
  conv_nd,
 
16
  normalization,
17
  timestep_embedding,
18
  )
19
+ from attention import SpatialTransformer, SpatialTransformer3D
 
 
20
  from diffusers.configuration_utils import ConfigMixin
21
  from diffusers.models.modeling_utils import ModelMixin
22
+ from typing import Any, List, Optional
23
+ from torch import Tensor
24
+
25
+
26
  class MultiViewUNetWrapperModel(ModelMixin, ConfigMixin):
27
+
28
  def __init__(self, *args, **kwargs):
29
  super().__init__()
30
  self.unet: MultiViewUNetModel = MultiViewUNetModel(*args, **kwargs)
31
+
32
  def forward(self, *args, **kwargs):
33
  return self.unet(*args, **kwargs)
34
 
35
+
36
  # dummy replace
37
  def convert_module_to_f16(x):
38
  pass
39
 
40
+
41
  def convert_module_to_f32(x):
42
  pass
43
 
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  class TimestepBlock(nn.Module):
46
  """
47
  Any module where forward() takes timestep embeddings as a second argument.
 
82
  upsampling occurs in the inner-two dimensions.
83
  """
84
 
85
+ def __init__(self,
86
+ channels,
87
+ use_conv,
88
+ dims=2,
89
+ out_channels=None,
90
+ padding=1):
91
  super().__init__()
92
  self.channels = channels
93
  self.out_channels = out_channels or channels
94
  self.use_conv = use_conv
95
  self.dims = dims
96
  if use_conv:
97
+ self.conv = conv_nd(dims,
98
+ self.channels,
99
+ self.out_channels,
100
+ 3,
101
+ padding=padding)
102
 
103
  def forward(self, x):
104
  assert x.shape[1] == self.channels
105
  if self.dims == 3:
106
+ x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
107
+ mode="nearest")
 
108
  else:
109
  x = F.interpolate(x, scale_factor=2, mode="nearest")
110
  if self.use_conv:
111
  x = self.conv(x)
112
  return x
113
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
  class Downsample(nn.Module):
116
  """
 
121
  downsampling occurs in the inner-two dimensions.
122
  """
123
 
124
+ def __init__(self,
125
+ channels,
126
+ use_conv,
127
+ dims=2,
128
+ out_channels=None,
129
+ padding=1):
130
  super().__init__()
131
  self.channels = channels
132
  self.out_channels = out_channels or channels
 
134
  self.dims = dims
135
  stride = 2 if dims != 3 else (1, 2, 2)
136
  if use_conv:
137
+ self.op = conv_nd(dims,
138
+ self.channels,
139
+ self.out_channels,
140
+ 3,
141
+ stride=stride,
142
+ padding=padding)
143
  else:
144
  assert self.channels == self.out_channels
145
  self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
 
208
  nn.SiLU(),
209
  linear(
210
  emb_channels,
211
+ 2 * self.out_channels
212
+ if use_scale_shift_norm else self.out_channels,
213
  ),
214
  )
215
  self.out_layers = nn.Sequential(
 
217
  nn.SiLU(),
218
  nn.Dropout(p=dropout),
219
  zero_module(
220
+ conv_nd(dims,
221
+ self.out_channels,
222
+ self.out_channels,
223
+ 3,
224
+ padding=1)),
225
  )
226
 
227
  if self.out_channels == channels:
228
  self.skip_connection = nn.Identity()
229
  elif use_conv:
230
+ self.skip_connection = conv_nd(dims,
231
+ channels,
232
+ self.out_channels,
233
+ 3,
234
+ padding=1)
235
  else:
236
+ self.skip_connection = conv_nd(dims, channels, self.out_channels,
237
+ 1)
238
 
239
  def forward(self, x, emb):
240
  """
 
243
  :param emb: an [N x emb_channels] Tensor of timestep embeddings.
244
  :return: an [N x C x ...] Tensor of outputs.
245
  """
246
+ return checkpoint(self._forward, (x, emb), self.parameters(),
247
+ self.use_checkpoint)
 
 
248
 
249
  def _forward(self, x, emb):
250
  if self.updown:
 
306
  self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
307
 
308
  def forward(self, x):
309
+ return checkpoint(
310
+ self._forward, (x, ), self.parameters(), True
311
+ ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
312
  #return pt_checkpoint(self._forward, x) # pytorch
313
 
314
  def _forward(self, x):
 
336
  # We perform two matmuls with the same number of ops.
337
  # The first computes the weight matrix, the second computes
338
  # the combination of the value vectors.
339
+ matmul_ops = 2 * b * (num_spatial**2) * c
340
  model.total_ops += th.DoubleTensor([matmul_ops])
341
 
342
 
 
358
  bs, width, length = qkv.shape
359
  assert width % (3 * self.n_heads) == 0
360
  ch = width // (3 * self.n_heads)
361
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch,
362
+ dim=1)
363
  scale = 1 / math.sqrt(math.sqrt(ch))
364
  weight = th.einsum(
365
+ "bct,bcs->bts", q * scale,
366
+ k * scale) # More stable with f16 than dividing afterwards
367
  weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
368
  a = th.einsum("bts,bcs->bct", weight, v)
369
  return a.reshape(bs, -1, length)
 
399
  (k * scale).view(bs * self.n_heads, ch, length),
400
  ) # More stable with f16 than dividing afterwards
401
  weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
402
+ a = th.einsum("bts,bcs->bct", weight,
403
+ v.reshape(bs * self.n_heads, ch, length))
404
  return a.reshape(bs, -1, length)
405
 
406
  @staticmethod
 
409
 
410
 
411
  class Timestep(nn.Module):
412
+
413
  def __init__(self, dim):
414
  super().__init__()
415
  self.dim = dim
 
418
  return timestep_embedding(t, self.dim)
419
 
420
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
421
  class MultiViewUNetModel(nn.Module):
422
  """
423
  The full multi-view UNet model with attention, timestep embedding and camera embedding.
 
471
  use_scale_shift_norm=False,
472
  resblock_updown=False,
473
  use_new_attention_order=False,
474
+ use_spatial_transformer=False, # custom transformer support
475
+ transformer_depth=1, # custom transformer support
476
+ context_dim=None, # custom transformer support
477
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
478
  legacy=True,
479
  disable_self_attentions=None,
480
  num_attention_blocks=None,
 
484
  camera_dim=None,
485
  ):
486
  super().__init__()
487
+ assert num_classes is not None
488
  if use_spatial_transformer:
489
  assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
490
 
 
511
  self.num_res_blocks = len(channel_mult) * [num_res_blocks]
512
  else:
513
  if len(num_res_blocks) != len(channel_mult):
514
+ raise ValueError(
515
+ "provide num_res_blocks either as an int (globally constant) or "
516
+ "as a list/tuple (per-level) with the same length as channel_mult"
517
+ )
518
  self.num_res_blocks = num_res_blocks
519
  if disable_self_attentions is not None:
520
  # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
521
  assert len(disable_self_attentions) == len(channel_mult)
522
  if num_attention_blocks is not None:
523
  assert len(num_attention_blocks) == len(self.num_res_blocks)
524
+ assert all(
525
+ map(
526
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i
527
+ ],
528
+ range(len(num_attention_blocks))))
529
+ print(
530
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
531
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
532
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
533
+ f"attention will still not be set.")
534
 
535
  self.attention_resolutions = attention_resolutions
536
  self.dropout = dropout
 
573
  linear(adm_in_channels, time_embed_dim),
574
  nn.SiLU(),
575
  linear(time_embed_dim, time_embed_dim),
576
+ ))
 
577
  else:
578
  raise ValueError()
579
 
580
+ self.input_blocks = nn.ModuleList([
581
+ TimestepEmbedSequential(
582
+ conv_nd(dims, in_channels, model_channels, 3, padding=1))
583
+ ])
 
 
 
584
  self._feature_size = model_channels
585
  input_block_chans = [model_channels]
586
  ch = model_channels
587
  ds = 1
588
  for level, mult in enumerate(channel_mult):
589
  for nr in range(self.num_res_blocks[level]):
590
+ layers: List[Any] = [
591
  ResBlock(
592
  ch,
593
  time_embed_dim,
 
608
  if legacy:
609
  #num_heads = 1
610
  dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
611
+ if disable_self_attentions is not None:
612
  disabled_sa = disable_self_attentions[level]
613
  else:
614
  disabled_sa = False
615
 
616
+ if num_attention_blocks is None or nr < num_attention_blocks[
617
+ level]:
618
  layers.append(
619
  AttentionBlock(
620
  ch,
 
622
  num_heads=num_heads,
623
  num_head_channels=dim_head,
624
  use_new_attention_order=use_new_attention_order,
625
+ ) if not use_spatial_transformer else
626
+ SpatialTransformer3D(
627
+ ch,
628
+ num_heads,
629
+ dim_head,
630
+ depth=transformer_depth,
631
+ context_dim=context_dim,
632
+ disable_self_attn=disabled_sa,
633
+ use_linear=use_linear_in_transformer,
634
+ use_checkpoint=use_checkpoint))
635
  self.input_blocks.append(TimestepEmbedSequential(*layers))
636
  self._feature_size += ch
637
  input_block_chans.append(ch)
 
648
  use_checkpoint=use_checkpoint,
649
  use_scale_shift_norm=use_scale_shift_norm,
650
  down=True,
651
+ ) if resblock_updown else Downsample(
652
+ ch, conv_resample, dims=dims, out_channels=out_ch))
 
 
 
 
653
  )
654
  ch = out_ch
655
  input_block_chans.append(ch)
 
679
  num_heads=num_heads,
680
  num_head_channels=dim_head,
681
  use_new_attention_order=use_new_attention_order,
682
+ ) if not use_spatial_transformer else
683
+ SpatialTransformer3D( # always uses a self-attn
684
+ ch,
685
+ num_heads,
686
+ dim_head,
687
+ depth=transformer_depth,
688
+ context_dim=context_dim,
689
+ disable_self_attn=disable_middle_self_attn,
690
+ use_linear=use_linear_in_transformer,
691
+ use_checkpoint=use_checkpoint),
692
  ResBlock(
693
  ch,
694
  time_embed_dim,
 
725
  if legacy:
726
  #num_heads = 1
727
  dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
728
+ if disable_self_attentions is not None:
729
  disabled_sa = disable_self_attentions[level]
730
  else:
731
  disabled_sa = False
732
 
733
+ if num_attention_blocks is None or i < num_attention_blocks[
734
+ level]:
735
  layers.append(
736
  AttentionBlock(
737
  ch,
 
739
  num_heads=num_heads_upsample,
740
  num_head_channels=dim_head,
741
  use_new_attention_order=use_new_attention_order,
742
+ ) if not use_spatial_transformer else
743
+ SpatialTransformer3D(
744
+ ch,
745
+ num_heads,
746
+ dim_head,
747
+ depth=transformer_depth,
748
+ context_dim=context_dim,
749
+ disable_self_attn=disabled_sa,
750
+ use_linear=use_linear_in_transformer,
751
+ use_checkpoint=use_checkpoint))
752
  if level and i == self.num_res_blocks[level]:
753
  out_ch = ch
754
  layers.append(
 
761
  use_checkpoint=use_checkpoint,
762
  use_scale_shift_norm=use_scale_shift_norm,
763
  up=True,
764
+ ) if resblock_updown else Upsample(
765
+ ch, conv_resample, dims=dims, out_channels=out_ch))
 
 
766
  ds //= 2
767
  self.output_blocks.append(TimestepEmbedSequential(*layers))
768
  self._feature_size += ch
 
770
  self.out = nn.Sequential(
771
  normalization(ch),
772
  nn.SiLU(),
773
+ zero_module(
774
+ conv_nd(dims, model_channels, out_channels, 3, padding=1)),
775
  )
776
  if self.predict_codebook_ids:
777
  self.id_predictor = nn.Sequential(
778
+ normalization(ch),
779
+ conv_nd(dims, model_channels, n_embed, 1),
780
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
781
+ )
782
 
783
  def convert_to_fp16(self):
784
  """
 
796
  self.middle_block.apply(convert_module_to_f32)
797
  self.output_blocks.apply(convert_module_to_f32)
798
 
799
+ def forward(self,
800
+ x,
801
+ timesteps=None,
802
+ context=None,
803
+ y: Optional[Tensor] = None,
804
+ camera=None,
805
+ num_frames=1,
806
+ **kwargs):
807
  """
808
  Apply the model to an input batch.
809
  :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
 
813
  :param num_frames: a integer indicating number of frames for tensor reshaping.
814
  :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
815
  """
816
+ assert x.shape[
817
+ 0] % num_frames == 0, "[UNet] input batch size must be dividable by num_frames!"
818
  assert (y is not None) == (
819
  self.num_classes is not None
820
  ), "must specify y if and only if the model is class-conditional"
821
  hs = []
822
+ t_emb = timestep_embedding(timesteps,
823
+ self.model_channels,
824
+ repeat_only=False)
825
  emb = self.time_embed(t_emb)
826
 
827
  if self.num_classes is not None:
828
+ assert y is not None
829
  assert y.shape[0] == x.shape[0]
830
  emb = emb + self.label_emb(y)
831
 
 
846
  if self.predict_codebook_ids:
847
  return self.id_predictor(h)
848
  else:
849
+ return self.out(h)
scripts/pipeline_mvdream.py CHANGED
@@ -557,14 +557,30 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
557
  self.scheduler.set_timesteps(num_inference_steps, device=device)
558
  timesteps = self.scheduler.timesteps
559
 
560
- _: torch.Tensor = self._encode_prompt(
561
  prompt=prompt,
562
  device=device,
563
  num_images_per_prompt=num_images_per_prompt,
564
  do_classifier_free_guidance=do_classifier_free_guidance,
565
  negative_prompt=negative_prompt,
566
  ) # type: ignore
567
- prompt_embeds_neg, prompt_embeds_pos = _.chunk(2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
568
 
569
  # 5. Prepare latent variables
570
  latents: torch.Tensor = self.prepare_latents(
@@ -604,7 +620,7 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
604
  timesteps=torch.tensor([t] * 4 * multiplier,
605
  device=device),
606
  context=torch.cat([prompt_embeds_neg] * 4 +
607
- [prompt_embeds_pos] * 4),
608
  num_frames=4,
609
  camera=torch.cat([camera] * multiplier),
610
  )
 
557
  self.scheduler.set_timesteps(num_inference_steps, device=device)
558
  timesteps = self.scheduler.timesteps
559
 
560
+ _prompt_embeds: torch.Tensor = self._encode_prompt(
561
  prompt=prompt,
562
  device=device,
563
  num_images_per_prompt=num_images_per_prompt,
564
  do_classifier_free_guidance=do_classifier_free_guidance,
565
  negative_prompt=negative_prompt,
566
  ) # type: ignore
567
+ prompt_embeds_neg, prompt_embeds_pos = _prompt_embeds.chunk(2)
568
+
569
+ _, prompt_embeds_pos_2 = self._encode_prompt(
570
+ prompt="watermellon",
571
+ device=device,
572
+ num_images_per_prompt=num_images_per_prompt,
573
+ do_classifier_free_guidance=do_classifier_free_guidance,
574
+ negative_prompt=negative_prompt,
575
+ ).chunk(2) # type: ignore
576
+
577
+ _, prompt_embeds_pos_4 = self._encode_prompt(
578
+ prompt="long hair",
579
+ device=device,
580
+ num_images_per_prompt=num_images_per_prompt,
581
+ do_classifier_free_guidance=do_classifier_free_guidance,
582
+ negative_prompt=negative_prompt,
583
+ ).chunk(2) # type: ignore
584
 
585
  # 5. Prepare latent variables
586
  latents: torch.Tensor = self.prepare_latents(
 
620
  timesteps=torch.tensor([t] * 4 * multiplier,
621
  device=device),
622
  context=torch.cat([prompt_embeds_neg] * 4 +
623
+ [prompt_embeds_pos, prompt_embeds_pos_2, prompt_embeds_pos, prompt_embeds_pos_4]),
624
  num_frames=4,
625
  camera=torch.cat([camera] * multiplier),
626
  )
scripts/util.py CHANGED
@@ -7,14 +7,14 @@
7
  #
8
  # thanks!
9
 
10
-
11
- import os
12
  import math
13
  import torch
14
  import torch.nn as nn
15
  import numpy as np
16
- from einops import repeat
17
  import importlib
 
 
 
18
 
19
  def instantiate_from_config(config):
20
  if not "target" in config:
@@ -33,16 +33,22 @@ def get_obj_from_str(string, reload=False):
33
  importlib.reload(module_imp)
34
  return getattr(importlib.import_module(module, package=None), cls)
35
 
36
- def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
 
 
 
 
 
37
  if schedule == "linear":
38
- betas = (
39
- torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
40
- )
 
41
 
42
  elif schedule == "cosine":
43
  timesteps = (
44
- torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
45
- )
46
  alphas = timesteps / (1 + cosine_s) * np.pi / 2
47
  alphas = torch.cos(alphas).pow(2)
48
  alphas = alphas / alphas[0]
@@ -50,22 +56,34 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
50
  betas = np.clip(betas, a_min=0, a_max=0.999)
51
 
52
  elif schedule == "sqrt_linear":
53
- betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
 
 
 
54
  elif schedule == "sqrt":
55
- betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
 
 
 
56
  else:
57
  raise ValueError(f"schedule '{schedule}' unknown.")
58
- return betas.numpy()
59
 
60
 
61
- def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
 
 
 
62
  if ddim_discr_method == 'uniform':
63
  c = num_ddpm_timesteps // num_ddim_timesteps
64
  ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
65
  elif ddim_discr_method == 'quad':
66
- ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
 
67
  else:
68
- raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
 
 
69
 
70
  # assert ddim_timesteps.shape[0] == num_ddim_timesteps
71
  # add one to get the final alpha values right (the ones from first scale to data during sampling)
@@ -75,17 +93,26 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
75
  return steps_out
76
 
77
 
78
- def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
 
 
 
79
  # select alphas for computing the variance schedule
80
  alphas = alphacums[ddim_timesteps]
81
- alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
 
82
 
83
  # according the the formula provided in https://arxiv.org/abs/2010.02502
84
- sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
 
85
  if verbose:
86
- print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
87
- print(f'For the chosen value of eta, which is {eta}, '
88
- f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
 
 
 
 
89
  return sigmas, alphas, alphas_prev
90
 
91
 
@@ -111,7 +138,7 @@ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
111
  def extract_into_tensor(a, t, x_shape):
112
  b, *_ = t.shape
113
  out = a.gather(-1, t)
114
- return out.reshape(b, *((1,) * (len(x_shape) - 1)))
115
 
116
 
117
  def checkpoint(func, inputs, params, flag):
@@ -130,7 +157,9 @@ def checkpoint(func, inputs, params, flag):
130
  else:
131
  return func(*inputs)
132
 
 
133
  class CheckpointFunction(torch.autograd.Function):
 
134
  @staticmethod
135
  def forward(ctx, run_function, length, *args):
136
  ctx.run_function = run_function
@@ -143,7 +172,9 @@ class CheckpointFunction(torch.autograd.Function):
143
 
144
  @staticmethod
145
  def backward(ctx, *output_grads):
146
- ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
 
 
147
  with torch.enable_grad():
148
  # Fixes a bug where the first op in run_function modifies the
149
  # Tensor storage in place, which is not allowed for detach()'d
@@ -174,12 +205,14 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
174
  if not repeat_only:
175
  half = dim // 2
176
  freqs = torch.exp(
177
- -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
178
- ).to(device=timesteps.device)
 
179
  args = timesteps[:, None].float() * freqs[None]
180
  embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
181
  if dim % 2:
182
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
 
183
  else:
184
  embedding = repeat(timesteps, 'b -> b d', d=dim)
185
  # import pdb; pdb.set_trace()
@@ -222,14 +255,17 @@ def normalization(channels):
222
 
223
  # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
224
  class SiLU(nn.Module):
 
225
  def forward(self, x):
226
  return x * torch.sigmoid(x)
227
 
228
 
229
  class GroupNorm32(nn.GroupNorm):
 
230
  def forward(self, x):
231
  return super().forward(x.float()).type(x.dtype)
232
 
 
233
  def conv_nd(dims, *args, **kwargs):
234
  """
235
  Create a 1D, 2D, or 3D convolution module.
@@ -267,8 +303,9 @@ class HybridConditioner(nn.Module):
267
 
268
  def __init__(self, c_concat_config, c_crossattn_config):
269
  super().__init__()
270
- self.concat_conditioner = instantiate_from_config(c_concat_config)
271
- self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
 
272
 
273
  def forward(self, c_concat, c_crossattn):
274
  c_concat = self.concat_conditioner(c_concat)
@@ -277,6 +314,7 @@ class HybridConditioner(nn.Module):
277
 
278
 
279
  def noise_like(shape, device, repeat=False):
280
- repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
 
281
  noise = lambda: torch.randn(shape, device=device)
282
- return repeat_noise() if repeat else noise()
 
7
  #
8
  # thanks!
9
 
 
 
10
  import math
11
  import torch
12
  import torch.nn as nn
13
  import numpy as np
 
14
  import importlib
15
+ from einops import repeat
16
+ from typing import Any
17
+
18
 
19
  def instantiate_from_config(config):
20
  if not "target" in config:
 
33
  importlib.reload(module_imp)
34
  return getattr(importlib.import_module(module, package=None), cls)
35
 
36
+
37
+ def make_beta_schedule(schedule,
38
+ n_timestep,
39
+ linear_start=1e-4,
40
+ linear_end=2e-2,
41
+ cosine_s=8e-3):
42
  if schedule == "linear":
43
+ betas = (torch.linspace(linear_start**0.5,
44
+ linear_end**0.5,
45
+ n_timestep,
46
+ dtype=torch.float64)**2)
47
 
48
  elif schedule == "cosine":
49
  timesteps = (
50
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep +
51
+ cosine_s)
52
  alphas = timesteps / (1 + cosine_s) * np.pi / 2
53
  alphas = torch.cos(alphas).pow(2)
54
  alphas = alphas / alphas[0]
 
56
  betas = np.clip(betas, a_min=0, a_max=0.999)
57
 
58
  elif schedule == "sqrt_linear":
59
+ betas = torch.linspace(linear_start,
60
+ linear_end,
61
+ n_timestep,
62
+ dtype=torch.float64)
63
  elif schedule == "sqrt":
64
+ betas = torch.linspace(linear_start,
65
+ linear_end,
66
+ n_timestep,
67
+ dtype=torch.float64)**0.5
68
  else:
69
  raise ValueError(f"schedule '{schedule}' unknown.")
70
+ return betas.numpy() # type: ignore
71
 
72
 
73
+ def make_ddim_timesteps(ddim_discr_method,
74
+ num_ddim_timesteps,
75
+ num_ddpm_timesteps,
76
+ verbose=True):
77
  if ddim_discr_method == 'uniform':
78
  c = num_ddpm_timesteps // num_ddim_timesteps
79
  ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
80
  elif ddim_discr_method == 'quad':
81
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8),
82
+ num_ddim_timesteps))**2).astype(int)
83
  else:
84
+ raise NotImplementedError(
85
+ f'There is no ddim discretization method called "{ddim_discr_method}"'
86
+ )
87
 
88
  # assert ddim_timesteps.shape[0] == num_ddim_timesteps
89
  # add one to get the final alpha values right (the ones from first scale to data during sampling)
 
93
  return steps_out
94
 
95
 
96
+ def make_ddim_sampling_parameters(alphacums,
97
+ ddim_timesteps,
98
+ eta,
99
+ verbose=True):
100
  # select alphas for computing the variance schedule
101
  alphas = alphacums[ddim_timesteps]
102
+ alphas_prev = np.asarray([alphacums[0]] +
103
+ alphacums[ddim_timesteps[:-1]].tolist())
104
 
105
  # according the the formula provided in https://arxiv.org/abs/2010.02502
106
+ sigmas = eta * np.sqrt(
107
+ (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
108
  if verbose:
109
+ print(
110
+ f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}'
111
+ )
112
+ print(
113
+ f'For the chosen value of eta, which is {eta}, '
114
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}'
115
+ )
116
  return sigmas, alphas, alphas_prev
117
 
118
 
 
138
  def extract_into_tensor(a, t, x_shape):
139
  b, *_ = t.shape
140
  out = a.gather(-1, t)
141
+ return out.reshape(b, *((1, ) * (len(x_shape) - 1)))
142
 
143
 
144
  def checkpoint(func, inputs, params, flag):
 
157
  else:
158
  return func(*inputs)
159
 
160
+
161
  class CheckpointFunction(torch.autograd.Function):
162
+
163
  @staticmethod
164
  def forward(ctx, run_function, length, *args):
165
  ctx.run_function = run_function
 
172
 
173
  @staticmethod
174
  def backward(ctx, *output_grads):
175
+ ctx.input_tensors = [
176
+ x.detach().requires_grad_(True) for x in ctx.input_tensors
177
+ ]
178
  with torch.enable_grad():
179
  # Fixes a bug where the first op in run_function modifies the
180
  # Tensor storage in place, which is not allowed for detach()'d
 
205
  if not repeat_only:
206
  half = dim // 2
207
  freqs = torch.exp(
208
+ -math.log(max_period) *
209
+ torch.arange(start=0, end=half, dtype=torch.float32) /
210
+ half).to(device=timesteps.device)
211
  args = timesteps[:, None].float() * freqs[None]
212
  embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
213
  if dim % 2:
214
+ embedding = torch.cat(
215
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
216
  else:
217
  embedding = repeat(timesteps, 'b -> b d', d=dim)
218
  # import pdb; pdb.set_trace()
 
255
 
256
  # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
257
  class SiLU(nn.Module):
258
+
259
  def forward(self, x):
260
  return x * torch.sigmoid(x)
261
 
262
 
263
  class GroupNorm32(nn.GroupNorm):
264
+
265
  def forward(self, x):
266
  return super().forward(x.float()).type(x.dtype)
267
 
268
+
269
  def conv_nd(dims, *args, **kwargs):
270
  """
271
  Create a 1D, 2D, or 3D convolution module.
 
303
 
304
  def __init__(self, c_concat_config, c_crossattn_config):
305
  super().__init__()
306
+ self.concat_conditioner: Any = instantiate_from_config(c_concat_config)
307
+ self.crossattn_conditioner: Any = instantiate_from_config(
308
+ c_crossattn_config)
309
 
310
  def forward(self, c_concat, c_crossattn):
311
  c_concat = self.concat_conditioner(c_concat)
 
314
 
315
 
316
  def noise_like(shape, device, repeat=False):
317
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
318
+ shape[0], *((1, ) * (len(shape) - 1)))
319
  noise = lambda: torch.randn(shape, device=device)
320
+ return repeat_noise() if repeat else noise()
text_encoder/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "openai/clip-vit-large-patch14",
3
+ "architectures": [
4
+ "CLIPTextModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 0,
8
+ "dropout": 0.0,
9
+ "eos_token_id": 2,
10
+ "hidden_act": "quick_gelu",
11
+ "hidden_size": 768,
12
+ "initializer_factor": 1.0,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 3072,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 77,
17
+ "model_type": "clip_text_model",
18
+ "num_attention_heads": 12,
19
+ "num_hidden_layers": 12,
20
+ "pad_token_id": 1,
21
+ "projection_dim": 768,
22
+ "torch_dtype": "float32",
23
+ "transformers_version": "4.34.1",
24
+ "vocab_size": 49408
25
+ }
text_encoder/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:06da5c5b4b82aff7c4264398cbdd9f85d7cb2debc93e1e27c16a31222211b6e0
3
+ size 492309274
tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|startoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<|endoftext|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<|endoftext|>",
25
+ "lstrip": false,
26
+ "normalized": true,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "49406": {
5
+ "content": "<|startoftext|>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "49407": {
13
+ "content": "<|endoftext|>",
14
+ "lstrip": false,
15
+ "normalized": true,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ }
20
+ },
21
+ "bos_token": "<|startoftext|>",
22
+ "clean_up_tokenization_spaces": true,
23
+ "do_lower_case": true,
24
+ "eos_token": "<|endoftext|>",
25
+ "errors": "replace",
26
+ "model_max_length": 77,
27
+ "pad_token": "<|endoftext|>",
28
+ "tokenizer_class": "CLIPTokenizer",
29
+ "unk_token": "<|endoftext|>"
30
+ }
tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
unet/config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "_class_name": "MultiViewUNetWrapperModel",
3
+ "_diffusers_version": "0.21.4"
4
+ }
unet/diffusion_pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0d04d15df72f825a031626fad29c8478d6b084442b33f7cf61e3d2acb85f7ff9
3
+ size 3445031598
vae/config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.21.4",
4
+ "act_fn": "silu",
5
+ "block_out_channels": [
6
+ 128,
7
+ 256,
8
+ 512,
9
+ 512
10
+ ],
11
+ "down_block_types": [
12
+ "DownEncoderBlock2D",
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D"
16
+ ],
17
+ "force_upcast": true,
18
+ "in_channels": 3,
19
+ "latent_channels": 4,
20
+ "layers_per_block": 2,
21
+ "norm_num_groups": 32,
22
+ "out_channels": 3,
23
+ "sample_size": 256,
24
+ "scaling_factor": 0.18215,
25
+ "up_block_types": [
26
+ "UpDecoderBlock2D",
27
+ "UpDecoderBlock2D",
28
+ "UpDecoderBlock2D",
29
+ "UpDecoderBlock2D"
30
+ ]
31
+ }
vae/diffusion_pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f1b909aa85cc520a2986d6fc379478e0c46c41f853f9a7c73c0150b2c9c9b8b
3
+ size 334716034