ashawkey commited on
Commit
9f446ee
1 Parent(s): b7219f3

Upload folder using huggingface_hub

Browse files
model_index.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "MVDreamPipeline",
3
+ "_diffusers_version": "0.25.0",
4
+ "feature_extractor": [
5
+ null,
6
+ null
7
+ ],
8
+ "image_encoder": [
9
+ null,
10
+ null
11
+ ],
12
+ "requires_safety_checker": false,
13
+ "scheduler": [
14
+ "diffusers",
15
+ "DDIMScheduler"
16
+ ],
17
+ "text_encoder": [
18
+ "transformers",
19
+ "CLIPTextModel"
20
+ ],
21
+ "tokenizer": [
22
+ "transformers",
23
+ "CLIPTokenizer"
24
+ ],
25
+ "unet": [
26
+ "mv_unet",
27
+ "MultiViewUNetModel"
28
+ ],
29
+ "vae": [
30
+ "diffusers",
31
+ "AutoencoderKL"
32
+ ]
33
+ }
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "DDIMScheduler",
3
+ "_diffusers_version": "0.25.0",
4
+ "beta_end": 0.012,
5
+ "beta_schedule": "scaled_linear",
6
+ "beta_start": 0.00085,
7
+ "clip_sample": false,
8
+ "clip_sample_range": 1.0,
9
+ "dynamic_thresholding_ratio": 0.995,
10
+ "num_train_timesteps": 1000,
11
+ "prediction_type": "epsilon",
12
+ "rescale_betas_zero_snr": false,
13
+ "sample_max_value": 1.0,
14
+ "set_alpha_to_one": false,
15
+ "steps_offset": 1,
16
+ "thresholding": false,
17
+ "timestep_spacing": "leading",
18
+ "trained_betas": null
19
+ }
text_encoder/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "stabilityai/stable-diffusion-2-1",
3
+ "architectures": [
4
+ "CLIPTextModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 0,
8
+ "dropout": 0.0,
9
+ "eos_token_id": 2,
10
+ "hidden_act": "gelu",
11
+ "hidden_size": 1024,
12
+ "initializer_factor": 1.0,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 4096,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 77,
17
+ "model_type": "clip_text_model",
18
+ "num_attention_heads": 16,
19
+ "num_hidden_layers": 23,
20
+ "pad_token_id": 1,
21
+ "projection_dim": 512,
22
+ "torch_dtype": "float16",
23
+ "transformers_version": "4.35.2",
24
+ "vocab_size": 49408
25
+ }
text_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc1827c465450322616f06dea41596eac7d493f4e95904dcb51f0fc745c4e13f
3
+ size 680820392
tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|startoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "!",
17
+ "unk_token": {
18
+ "content": "<|endoftext|>",
19
+ "lstrip": false,
20
+ "normalized": true,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "!",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "49406": {
13
+ "content": "<|startoftext|>",
14
+ "lstrip": false,
15
+ "normalized": true,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "49407": {
21
+ "content": "<|endoftext|>",
22
+ "lstrip": false,
23
+ "normalized": true,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ }
28
+ },
29
+ "bos_token": "<|startoftext|>",
30
+ "clean_up_tokenization_spaces": true,
31
+ "do_lower_case": true,
32
+ "eos_token": "<|endoftext|>",
33
+ "errors": "replace",
34
+ "model_max_length": 77,
35
+ "pad_token": "!",
36
+ "tokenizer_class": "CLIPTokenizer",
37
+ "unk_token": "<|endoftext|>"
38
+ }
tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
unet/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "MultiViewUNetModel",
3
+ "_diffusers_version": "0.25.0",
4
+ "attention_resolutions": [
5
+ 4,
6
+ 2,
7
+ 1
8
+ ],
9
+ "camera_dim": 16,
10
+ "channel_mult": [
11
+ 1,
12
+ 2,
13
+ 4,
14
+ 4
15
+ ],
16
+ "context_dim": 1024,
17
+ "image_size": 32,
18
+ "in_channels": 4,
19
+ "model_channels": 320,
20
+ "num_head_channels": 64,
21
+ "num_res_blocks": 2,
22
+ "out_channels": 4,
23
+ "transformer_depth": 1,
24
+ "use_checkpoint": false
25
+ }
unet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9ff839dd8c11591c2faa8efca41ac8145be8878b8ebf7cc92255fdcab0e09e53
3
+ size 1735224544
unet/mv_unet.py ADDED
@@ -0,0 +1,1089 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ from inspect import isfunction
4
+ from typing import Optional, Any, List
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange, repeat
10
+
11
+ from diffusers.configuration_utils import ConfigMixin
12
+ from diffusers.models.modeling_utils import ModelMixin
13
+
14
+ # require xformers!
15
+ import xformers
16
+ import xformers.ops
17
+
18
+ from kiui.cam import orbit_camera
19
+
20
def get_camera(
    num_frames, elevation=15, azimuth_start=0, azimuth_span=360, blender_coord=True, extra_view=False,
):
    """Build a batch of flattened 4x4 camera poses evenly spaced around an orbit.

    :param num_frames: number of views sampled along the azimuth span.
    :param elevation: orbit elevation in degrees (negated below for kiui's convention).
    :param azimuth_start: starting azimuth angle in degrees.
    :param azimuth_span: total azimuth range covered by the sampled views.
    :param blender_coord: if True, convert each pose from OpenGL to Blender axes.
    :param extra_view: if True, append one all-zero pose after the orbit views.
    :return: float tensor of shape [num_frames (+1 if extra_view), 16].
    """
    angle_gap = azimuth_span / num_frames
    cameras = []
    # np.arange over [azimuth_start, azimuth_start + azimuth_span) in angle_gap steps
    # yields exactly num_frames azimuth values.
    for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):

        pose = orbit_camera(-elevation, azimuth, radius=1)  # kiui's elevation is negated, [4, 4]

        # opengl to blender: flip z, then swap the y and z rows.
        if blender_coord:
            pose[2] *= -1
            pose[[1, 2]] = pose[[2, 1]]

        cameras.append(pose.flatten())

    if extra_view:
        # Placeholder (all-zero) camera for an extra conditioning view.
        cameras.append(np.zeros_like(cameras[0]))

    return torch.from_numpy(np.stack(cameras, axis=0)).float()  # [num_frames, 16]
40
+
41
+
42
def checkpoint(func, inputs, params, flag):
    """Evaluate ``func(*inputs)``, optionally under gradient checkpointing.

    When *flag* is truthy, intermediate activations are not cached; they are
    recomputed during the backward pass, trading compute for memory.

    :param func: callable to evaluate.
    :param inputs: positional arguments passed to ``func``.
    :param params: parameters ``func`` depends on without taking them as
        arguments (needed so gradients still flow to them).
    :param flag: falsy disables checkpointing entirely.
    """
    if not flag:
        # Checkpointing disabled: plain direct evaluation.
        return func(*inputs)
    flat_args = tuple(inputs) + tuple(params)
    return CheckpointFunction.apply(func, len(inputs), *flat_args)
57
+
58
+
59
class CheckpointFunction(torch.autograd.Function):
    """Custom autograd Function implementing gradient checkpointing:
    forward runs under ``no_grad`` (no activations cached), and backward
    re-runs the wrapped function with gradients enabled to recompute them.
    """

    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        # The flat *args holds the real inputs first, then extra params the
        # function depends on; split them back apart for backward.
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # Detach and re-enable grad so the recomputation builds a fresh graph.
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            # allow_unused: some params may not participate in this output.
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # The two leading Nones match the non-tensor forward args
        # (run_function, length).
        return (None, None) + input_grads
89
+
90
+
91
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, skip the sinusoidal encoding and simply
                        tile the raw timesteps across `dim` channels.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if repeat_only:
        # Degenerate mode: broadcast each timestep value to every channel.
        return repeat(timesteps, "b -> b d", d=dim)

    half = dim // 2
    # Geometric frequency ladder from 1 down to 1/max_period.
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=half, dtype=torch.float32)
        / half
    ).to(device=timesteps.device)
    args = timesteps[:, None] * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Odd dim: pad with one zero column so the output is exactly `dim` wide.
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
117
+
118
+
119
def zero_module(module):
    """Zero every parameter of *module* in place and return the module.

    Used so residual branches start out as exact identity mappings.
    """
    for param in module.parameters():
        nn.init.zeros_(param)
    return module
126
+
127
+
128
def conv_nd(dims, *args, **kwargs):
    """Create a 1D, 2D, or 3D convolution module for the given dimensionality.

    Extra positional and keyword arguments are forwarded to the
    corresponding ``nn.Conv{1,2,3}d`` constructor.
    """
    conv_classes = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
    if dims not in conv_classes:
        raise ValueError(f"unsupported dimensions: {dims}")
    return conv_classes[dims](*args, **kwargs)
139
+
140
+
141
def avg_pool_nd(dims, *args, **kwargs):
    """Create a 1D, 2D, or 3D average pooling module for the given dimensionality.

    Extra positional and keyword arguments are forwarded to the
    corresponding ``nn.AvgPool{1,2,3}d`` constructor.
    """
    pool_classes = {1: nn.AvgPool1d, 2: nn.AvgPool2d, 3: nn.AvgPool3d}
    if dims not in pool_classes:
        raise ValueError(f"unsupported dimensions: {dims}")
    return pool_classes[dims](*args, **kwargs)
152
+
153
+
154
def default(val, d):
    """Return *val* unless it is None; otherwise return the fallback *d*,
    calling it first when it is a plain function."""
    if val is None:
        return d() if isfunction(d) else d
    return val
158
+
159
+
160
class GEGLU(nn.Module):
    """Gated GELU activation: project to double width, then gate one half
    of the projection with the GELU of the other half."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # Single projection produces both the value and the gate halves.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        value, gate = torch.chunk(self.proj(x), 2, dim=-1)
        return value * F.gelu(gate)
168
+
169
+
170
class FeedForward(nn.Module):
    """Transformer MLP block: input projection (GEGLU when *glu*, else
    Linear + GELU), dropout, and an output projection."""

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        hidden_dim = int(dim * mult)
        out_dim = default(dim_out, dim)
        if glu:
            project_in = GEGLU(dim, hidden_dim)
        else:
            project_in = nn.Sequential(nn.Linear(dim, hidden_dim), nn.GELU())
        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        return self.net(x)
187
+
188
+
189
class MemoryEfficientCrossAttention(nn.Module):
    """Cross-/self-attention using xformers' memory-efficient kernel.

    When ``ip_dim > 0`` the last ``ip_dim`` context tokens are treated as
    "image prompt" tokens: they get their own k/v projections and their
    attention output is blended in scaled by ``ip_weight``.
    """
    # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
    def __init__(
        self,
        query_dim,
        context_dim=None,  # None -> self-attention (context defaults to x)
        heads=8,
        dim_head=64,
        dropout=0.0,
        ip_dim=0,     # number of trailing image-prompt tokens in the context
        ip_weight=1,  # scale applied to the image-prompt attention output
    ):
        super().__init__()

        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.heads = heads
        self.dim_head = dim_head

        self.ip_dim = ip_dim
        self.ip_weight = ip_weight

        if self.ip_dim > 0:
            # Separate k/v projections for the image-prompt tokens.
            self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
            self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
        )
        # Optional xformers attention operator override; None lets xformers pick.
        self.attention_op: Optional[Any] = None

    def forward(self, x, context=None):
        q = self.to_q(x)
        context = default(context, x)

        if self.ip_dim > 0:
            # context: [B, 77 + 16(ip), 1024]
            token_len = context.shape[1]
            # Split the trailing ip_dim image-prompt tokens off the context.
            context_ip = context[:, -self.ip_dim :, :]
            k_ip = self.to_k_ip(context_ip)
            v_ip = self.to_v_ip(context_ip)
            context = context[:, : (token_len - self.ip_dim), :]

        k = self.to_k(context)
        v = self.to_v(context)

        b, _, _ = q.shape
        # [B, L, H*D] -> [B*H, L, D], the layout xformers expects.
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, t.shape[1], self.heads, self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * self.heads, t.shape[1], self.dim_head)
            .contiguous(),
            (q, k, v),
        )

        # actually compute the attention, what we cannot get enough of
        out = xformers.ops.memory_efficient_attention(
            q, k, v, attn_bias=None, op=self.attention_op
        )

        if self.ip_dim > 0:
            # Same [B, L, H*D] -> [B*H, L, D] reshape for the ip keys/values.
            k_ip, v_ip = map(
                lambda t: t.unsqueeze(3)
                .reshape(b, t.shape[1], self.heads, self.dim_head)
                .permute(0, 2, 1, 3)
                .reshape(b * self.heads, t.shape[1], self.dim_head)
                .contiguous(),
                (k_ip, v_ip),
            )
            # actually compute the attention, what we cannot get enough of
            out_ip = xformers.ops.memory_efficient_attention(
                q, k_ip, v_ip, attn_bias=None, op=self.attention_op
            )
            # Blend the image-prompt attention into the main attention output.
            out = out + self.ip_weight * out_ip

        # [B*H, L, D] -> [B, L, H*D] before the output projection.
        out = (
            out.unsqueeze(0)
            .reshape(b, self.heads, out.shape[1], self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, out.shape[1], self.heads * self.dim_head)
        )
        return self.to_out(out)
277
+
278
+
279
class BasicTransformerBlock3D(nn.Module):
    """Transformer block with multi-view ("3D") self-attention, cross-attention
    to the conditioning context, and a feed-forward layer — each applied
    pre-norm with a residual connection.
    """

    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        context_dim,
        dropout=0.0,
        gated_ff=True,          # use GEGLU in the feed-forward
        checkpoint=True,        # enable gradient checkpointing for this block
        ip_dim=0,
        ip_weight=1,
    ):
        super().__init__()

        self.attn1 = MemoryEfficientCrossAttention(
            query_dim=dim,
            context_dim=None,  # self-attention
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
        )
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = MemoryEfficientCrossAttention(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            # ip only applies to cross-attention
            ip_dim=ip_dim,
            ip_weight=ip_weight,
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None, num_frames=1):
        # Route through the checkpoint helper so activations can be recomputed
        # in backward when self.checkpoint is True.
        return checkpoint(
            self._forward, (x, context, num_frames), self.parameters(), self.checkpoint
        )

    def _forward(self, x, context=None, num_frames=1):
        # Fold the num_frames views into the token axis so self-attention
        # runs jointly across all views (cross-view attention).
        x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
        x = self.attn1(self.norm1(x), context=None) + x
        # Unfold back to per-view tokens for cross-attention with the context.
        x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x
330
+
331
+
332
class SpatialTransformer3D(nn.Module):
    """Spatial transformer over 2D feature maps with multi-view attention.

    Normalizes the input, flattens it to tokens, runs a stack of
    BasicTransformerBlock3D blocks, projects back, and adds the input
    as a residual.
    """

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        context_dim,  # cross attention input dim
        depth=1,
        dropout=0.0,
        ip_dim=0,
        ip_weight=1,
        use_checkpoint=True,
    ):
        super().__init__()

        # One context dim per transformer block.
        if not isinstance(context_dim, list):
            context_dim = [context_dim]

        self.in_channels = in_channels

        inner_dim = n_heads * d_head
        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock3D(
                    inner_dim,
                    n_heads,
                    d_head,
                    context_dim=context_dim[d],
                    dropout=dropout,
                    checkpoint=use_checkpoint,
                    ip_dim=ip_dim,
                    ip_weight=ip_weight,
                )
                for d in range(depth)
            ]
        )

        # Zero-initialized so the whole block is an identity mapping at init.
        # NOTE(review): argument order is (in_channels, inner_dim), which is
        # only the inverse of proj_in when inner_dim == in_channels (true for
        # this UNet's configs, num_head_channels * heads == channels) —
        # confirm intended before reusing with inner_dim != in_channels.
        self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))


    def forward(self, x, context=None, num_frames=1):
        # note: if no context is given, cross-attention defaults to self-attention
        if not isinstance(context, list):
            context = [context]
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        # [B, C, H, W] -> [B, H*W, C] token layout for attention.
        x = rearrange(x, "b c h w -> b (h w) c").contiguous()
        x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context[i], num_frames=num_frames)
        x = self.proj_out(x)
        # Back to the [B, C, H, W] feature-map layout.
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()

        return x + x_in
391
+
392
+
393
class PerceiverAttention(nn.Module):
    """Perceiver-style attention: latent query tokens attend over the
    concatenation of input features and the latents themselves."""

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        # Single projection producing both keys and values.
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latent (torch.Tensor): latent features
                shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        b, l, _ = latents.shape

        q = self.to_q(latents)
        # Keys/values span both the image features and the latents.
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        # [b, seq, H*D] -> [b, H, seq, D] per-head layout.
        q, k, v = map(
            lambda t: t.reshape(b, t.shape[1], self.heads, -1)
            .transpose(1, 2)
            .reshape(b, self.heads, t.shape[1], -1)
            .contiguous(),
            (q, k, v),
        )

        # attention
        # The 1/sqrt(d) scale is split evenly between q and k.
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
        # Softmax in float32, then cast back to the working dtype.
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        # [b, H, l, D] -> [b, l, H*D] before the output projection.
        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)
442
+
443
+
444
class Resampler(nn.Module):
    """Perceiver-style resampler: compresses a variable number of input
    feature tokens into a fixed set of ``num_queries`` latent tokens via
    alternating attention and feed-forward layers."""

    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
    ):
        super().__init__()
        # Learned query tokens, scaled down at init for stability.
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
        self.proj_in = nn.Linear(embedding_dim, dim)
        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        # Each layer is an (attention, feed-forward) pair.
        self.layers = nn.ModuleList(
            nn.ModuleList(
                [
                    PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                    nn.Sequential(
                        nn.LayerNorm(dim),
                        nn.Linear(dim, dim * ff_mult, bias=False),
                        nn.GELU(),
                        nn.Linear(dim * ff_mult, dim, bias=False),
                    ),
                ]
            )
            for _ in range(depth)
        )

    def forward(self, x):
        # Tile the learned queries across the batch.
        latents = self.latents.repeat(x.size(0), 1, 1)
        x = self.proj_in(x)
        for attn, ff in self.layers:
            latents = latents + attn(x, latents)
            latents = latents + ff(latents)
        return self.norm_out(self.proj_out(latents))
487
+
488
+
489
class CondSequential(nn.Sequential):
    """
    A sequential container whose forward threads the timestep embedding
    (and, for spatial transformers, the context and frame count) to the
    children that accept them as extra inputs.
    """

    def forward(self, x, emb, context=None, num_frames=1):
        out = x
        for module in self:
            if isinstance(module, ResBlock):
                # Residual blocks consume the timestep embedding.
                out = module(out, emb)
            elif isinstance(module, SpatialTransformer3D):
                # Transformers consume the conditioning context.
                out = module(out, context, num_frames=num_frames)
            else:
                out = module(out)
        return out
504
+
505
+
506
class Upsample(nn.Module):
    """
    Nearest-neighbor 2x upsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        upsampling occurs in the inner-two dimensions only.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(
                dims, self.channels, self.out_channels, 3, padding=padding
            )

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # 3D: keep depth, double only the inner two (spatial) dims.
            target_size = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
            x = F.interpolate(x, target_size, mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(x) if self.use_conv else x
537
+
538
+
539
class Downsample(nn.Module):
    """
    2x downsampling layer: a strided convolution when *use_conv*, otherwise
    average pooling (which requires in/out channel counts to match).
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
        downsampling occurs in the inner-two dimensions only.
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # For 3D signals only the inner two (spatial) dims are strided.
        stride = (1, 2, 2) if dims == 3 else 2
        if use_conv:
            self.op = conv_nd(
                dims,
                self.channels,
                self.out_channels,
                3,
                stride=stride,
                padding=padding,
            )
        else:
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
571
+
572
+
573
class ResBlock(nn.Module):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param use_scale_shift_norm: use FiLM-like scale/shift conditioning from
        the embedding instead of plain addition.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            nn.GroupNorm(32, channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        # When resampling, h (main branch) and x (skip branch) are resampled
        # with separate modules so the two paths stay resolution-consistent.
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(
                emb_channels,
                # Scale-shift norm needs two chunks (scale and shift).
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            nn.GroupNorm(32, self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            # Zero-initialized so the residual branch starts as an identity.
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        # Skip connection: identity when channels match, otherwise a conv
        # (3x3 if use_conv, else a cheap 1x1) to adjust the channel count.
        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )

    def _forward(self, x, emb):
        if self.updown:
            # Resample between the norm+activation and the convolution.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # Broadcast the embedding across the spatial dimensions.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            # FiLM-style conditioning: h = norm(h) * (1 + scale) + shift.
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
685
+
686
+
687
+ class MultiViewUNetModel(ModelMixin, ConfigMixin):
688
+ """
689
+ The full multi-view UNet model with attention, timestep embedding and camera embedding.
690
+ :param in_channels: channels in the input Tensor.
691
+ :param model_channels: base channel count for the model.
692
+ :param out_channels: channels in the output Tensor.
693
+ :param num_res_blocks: number of residual blocks per downsample.
694
+ :param attention_resolutions: a collection of downsample rates at which
695
+ attention will take place. May be a set, list, or tuple.
696
+ For example, if this contains 4, then at 4x downsampling, attention
697
+ will be used.
698
+ :param dropout: the dropout probability.
699
+ :param channel_mult: channel multiplier for each level of the UNet.
700
+ :param conv_resample: if True, use learned convolutions for upsampling and
701
+ downsampling.
702
+ :param dims: determines if the signal is 1D, 2D, or 3D.
703
+ :param num_classes: if specified (as an int), then this model will be
704
+ class-conditional with `num_classes` classes.
705
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
706
+ :param num_heads: the number of attention heads in each attention layer.
707
+ :param num_heads_channels: if specified, ignore num_heads and instead use
708
+ a fixed channel width per attention head.
709
+ :param num_heads_upsample: works with num_heads to set a different number
710
+ of heads for upsampling. Deprecated.
711
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
712
+ :param resblock_updown: use residual blocks for up/downsampling.
713
+ :param use_new_attention_order: use a different attention pattern for potentially
714
+ increased efficiency.
715
+ :param camera_dim: dimensionality of camera input.
716
+ """
717
+
718
+ def __init__(
719
+ self,
720
+ image_size,
721
+ in_channels,
722
+ model_channels,
723
+ out_channels,
724
+ num_res_blocks,
725
+ attention_resolutions,
726
+ dropout=0,
727
+ channel_mult=(1, 2, 4, 8),
728
+ conv_resample=True,
729
+ dims=2,
730
+ num_classes=None,
731
+ use_checkpoint=False,
732
+ num_heads=-1,
733
+ num_head_channels=-1,
734
+ num_heads_upsample=-1,
735
+ use_scale_shift_norm=False,
736
+ resblock_updown=False,
737
+ transformer_depth=1,
738
+ context_dim=None,
739
+ n_embed=None,
740
+ num_attention_blocks=None,
741
+ adm_in_channels=None,
742
+ camera_dim=None,
743
+ ip_dim=0, # imagedream uses ip_dim > 0
744
+ ip_weight=1.0,
745
+ **kwargs,
746
+ ):
747
+ super().__init__()
748
+ assert context_dim is not None
749
+
750
+ if num_heads_upsample == -1:
751
+ num_heads_upsample = num_heads
752
+
753
+ if num_heads == -1:
754
+ assert (
755
+ num_head_channels != -1
756
+ ), "Either num_heads or num_head_channels has to be set"
757
+
758
+ if num_head_channels == -1:
759
+ assert (
760
+ num_heads != -1
761
+ ), "Either num_heads or num_head_channels has to be set"
762
+
763
+ self.image_size = image_size
764
+ self.in_channels = in_channels
765
+ self.model_channels = model_channels
766
+ self.out_channels = out_channels
767
+ if isinstance(num_res_blocks, int):
768
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
769
+ else:
770
+ if len(num_res_blocks) != len(channel_mult):
771
+ raise ValueError(
772
+ "provide num_res_blocks either as an int (globally constant) or "
773
+ "as a list/tuple (per-level) with the same length as channel_mult"
774
+ )
775
+ self.num_res_blocks = num_res_blocks
776
+
777
+ if num_attention_blocks is not None:
778
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
779
+ assert all(
780
+ map(
781
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
782
+ range(len(num_attention_blocks)),
783
+ )
784
+ )
785
+ print(
786
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
787
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
788
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
789
+ f"attention will still not be set."
790
+ )
791
+
792
+ self.attention_resolutions = attention_resolutions
793
+ self.dropout = dropout
794
+ self.channel_mult = channel_mult
795
+ self.conv_resample = conv_resample
796
+ self.num_classes = num_classes
797
+ self.use_checkpoint = use_checkpoint
798
+ self.num_heads = num_heads
799
+ self.num_head_channels = num_head_channels
800
+ self.num_heads_upsample = num_heads_upsample
801
+ self.predict_codebook_ids = n_embed is not None
802
+
803
+ self.ip_dim = ip_dim
804
+ self.ip_weight = ip_weight
805
+
806
+ if self.ip_dim > 0:
807
+ self.image_embed = Resampler(
808
+ dim=context_dim,
809
+ depth=4,
810
+ dim_head=64,
811
+ heads=12,
812
+ num_queries=ip_dim, # num token
813
+ embedding_dim=1280,
814
+ output_dim=context_dim,
815
+ ff_mult=4,
816
+ )
817
+
818
+ time_embed_dim = model_channels * 4
819
+ self.time_embed = nn.Sequential(
820
+ nn.Linear(model_channels, time_embed_dim),
821
+ nn.SiLU(),
822
+ nn.Linear(time_embed_dim, time_embed_dim),
823
+ )
824
+
825
+ if camera_dim is not None:
826
+ time_embed_dim = model_channels * 4
827
+ self.camera_embed = nn.Sequential(
828
+ nn.Linear(camera_dim, time_embed_dim),
829
+ nn.SiLU(),
830
+ nn.Linear(time_embed_dim, time_embed_dim),
831
+ )
832
+
833
+ if self.num_classes is not None:
834
+ if isinstance(self.num_classes, int):
835
+ self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
836
+ elif self.num_classes == "continuous":
837
+ # print("setting up linear c_adm embedding layer")
838
+ self.label_emb = nn.Linear(1, time_embed_dim)
839
+ elif self.num_classes == "sequential":
840
+ assert adm_in_channels is not None
841
+ self.label_emb = nn.Sequential(
842
+ nn.Sequential(
843
+ nn.Linear(adm_in_channels, time_embed_dim),
844
+ nn.SiLU(),
845
+ nn.Linear(time_embed_dim, time_embed_dim),
846
+ )
847
+ )
848
+ else:
849
+ raise ValueError()
850
+
851
+ self.input_blocks = nn.ModuleList(
852
+ [
853
+ CondSequential(
854
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
855
+ )
856
+ ]
857
+ )
858
+ self._feature_size = model_channels
859
+ input_block_chans = [model_channels]
860
+ ch = model_channels
861
+ ds = 1
862
+ for level, mult in enumerate(channel_mult):
863
+ for nr in range(self.num_res_blocks[level]):
864
+ layers: List[Any] = [
865
+ ResBlock(
866
+ ch,
867
+ time_embed_dim,
868
+ dropout,
869
+ out_channels=mult * model_channels,
870
+ dims=dims,
871
+ use_checkpoint=use_checkpoint,
872
+ use_scale_shift_norm=use_scale_shift_norm,
873
+ )
874
+ ]
875
+ ch = mult * model_channels
876
+ if ds in attention_resolutions:
877
+ if num_head_channels == -1:
878
+ dim_head = ch // num_heads
879
+ else:
880
+ num_heads = ch // num_head_channels
881
+ dim_head = num_head_channels
882
+
883
+ if num_attention_blocks is None or nr < num_attention_blocks[level]:
884
+ layers.append(
885
+ SpatialTransformer3D(
886
+ ch,
887
+ num_heads,
888
+ dim_head,
889
+ context_dim=context_dim,
890
+ depth=transformer_depth,
891
+ use_checkpoint=use_checkpoint,
892
+ ip_dim=self.ip_dim,
893
+ ip_weight=self.ip_weight,
894
+ )
895
+ )
896
+ self.input_blocks.append(CondSequential(*layers))
897
+ self._feature_size += ch
898
+ input_block_chans.append(ch)
899
+ if level != len(channel_mult) - 1:
900
+ out_ch = ch
901
+ self.input_blocks.append(
902
+ CondSequential(
903
+ ResBlock(
904
+ ch,
905
+ time_embed_dim,
906
+ dropout,
907
+ out_channels=out_ch,
908
+ dims=dims,
909
+ use_checkpoint=use_checkpoint,
910
+ use_scale_shift_norm=use_scale_shift_norm,
911
+ down=True,
912
+ )
913
+ if resblock_updown
914
+ else Downsample(
915
+ ch, conv_resample, dims=dims, out_channels=out_ch
916
+ )
917
+ )
918
+ )
919
+ ch = out_ch
920
+ input_block_chans.append(ch)
921
+ ds *= 2
922
+ self._feature_size += ch
923
+
924
+ if num_head_channels == -1:
925
+ dim_head = ch // num_heads
926
+ else:
927
+ num_heads = ch // num_head_channels
928
+ dim_head = num_head_channels
929
+
930
+ self.middle_block = CondSequential(
931
+ ResBlock(
932
+ ch,
933
+ time_embed_dim,
934
+ dropout,
935
+ dims=dims,
936
+ use_checkpoint=use_checkpoint,
937
+ use_scale_shift_norm=use_scale_shift_norm,
938
+ ),
939
+ SpatialTransformer3D(
940
+ ch,
941
+ num_heads,
942
+ dim_head,
943
+ context_dim=context_dim,
944
+ depth=transformer_depth,
945
+ use_checkpoint=use_checkpoint,
946
+ ip_dim=self.ip_dim,
947
+ ip_weight=self.ip_weight,
948
+ ),
949
+ ResBlock(
950
+ ch,
951
+ time_embed_dim,
952
+ dropout,
953
+ dims=dims,
954
+ use_checkpoint=use_checkpoint,
955
+ use_scale_shift_norm=use_scale_shift_norm,
956
+ ),
957
+ )
958
+ self._feature_size += ch
959
+
960
+ self.output_blocks = nn.ModuleList([])
961
+ for level, mult in list(enumerate(channel_mult))[::-1]:
962
+ for i in range(self.num_res_blocks[level] + 1):
963
+ ich = input_block_chans.pop()
964
+ layers = [
965
+ ResBlock(
966
+ ch + ich,
967
+ time_embed_dim,
968
+ dropout,
969
+ out_channels=model_channels * mult,
970
+ dims=dims,
971
+ use_checkpoint=use_checkpoint,
972
+ use_scale_shift_norm=use_scale_shift_norm,
973
+ )
974
+ ]
975
+ ch = model_channels * mult
976
+ if ds in attention_resolutions:
977
+ if num_head_channels == -1:
978
+ dim_head = ch // num_heads
979
+ else:
980
+ num_heads = ch // num_head_channels
981
+ dim_head = num_head_channels
982
+
983
+ if num_attention_blocks is None or i < num_attention_blocks[level]:
984
+ layers.append(
985
+ SpatialTransformer3D(
986
+ ch,
987
+ num_heads,
988
+ dim_head,
989
+ context_dim=context_dim,
990
+ depth=transformer_depth,
991
+ use_checkpoint=use_checkpoint,
992
+ ip_dim=self.ip_dim,
993
+ ip_weight=self.ip_weight,
994
+ )
995
+ )
996
+ if level and i == self.num_res_blocks[level]:
997
+ out_ch = ch
998
+ layers.append(
999
+ ResBlock(
1000
+ ch,
1001
+ time_embed_dim,
1002
+ dropout,
1003
+ out_channels=out_ch,
1004
+ dims=dims,
1005
+ use_checkpoint=use_checkpoint,
1006
+ use_scale_shift_norm=use_scale_shift_norm,
1007
+ up=True,
1008
+ )
1009
+ if resblock_updown
1010
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
1011
+ )
1012
+ ds //= 2
1013
+ self.output_blocks.append(CondSequential(*layers))
1014
+ self._feature_size += ch
1015
+
1016
+ self.out = nn.Sequential(
1017
+ nn.GroupNorm(32, ch),
1018
+ nn.SiLU(),
1019
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
1020
+ )
1021
+ if self.predict_codebook_ids:
1022
+ self.id_predictor = nn.Sequential(
1023
+ nn.GroupNorm(32, ch),
1024
+ conv_nd(dims, model_channels, n_embed, 1),
1025
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
1026
+ )
1027
+
1028
    def forward(
        self,
        x,
        timesteps=None,
        context=None,
        y=None,
        camera=None,
        num_frames=1,
        ip=None,
        ip_img=None,
        **kwargs,
    ):
        """
        Apply the model to an input batch.

        :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :param camera: optional camera conditioning vector; projected by
            ``self.camera_embed`` and added to the timestep embedding, so its last
            dimension must equal the ``camera_dim`` given at construction time.
        :param num_frames: a integer indicating number of frames for tensor reshaping.
        :param ip: image-prompt features fed through ``self.image_embed`` (Resampler);
            only consulted when ``self.ip_dim > 0`` (the imagedream variant).
        :param ip_img: conditioning image latent written into every
            ``num_frames``-th slot of ``x``; only consulted when ``self.ip_dim > 0``.
        :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
        """
        assert (
            x.shape[0] % num_frames == 0
        ), "input batch size must be dividable by num_frames!"
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"

        # Stack of intermediate activations used as UNet skip connections.
        hs = []

        # Sinusoidal timestep embedding, cast to the input dtype before the MLP
        # so mixed-precision inference does not upcast the whole network.
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)

        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y is not None
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        # Add camera embeddings
        if camera is not None:
            emb = emb + self.camera_embed(camera)

        # imagedream variant
        if self.ip_dim > 0:
            # NOTE(review): this mutates the caller's tensor in place — the last
            # view of each group of num_frames is overwritten with the image latent.
            x[(num_frames - 1) :: num_frames, :, :, :] = ip_img  # place at [4, 9]
            ip_emb = self.image_embed(ip)
            # Image tokens are appended to the text context along the sequence
            # dimension; presumably the SpatialTransformer3D blocks separate them
            # again via their ip_dim/ip_weight settings — confirm in that class.
            context = torch.cat((context, ip_emb), 1)

        # Standard UNet pass: down path (saving skips), middle, then up path.
        h = x
        for module in self.input_blocks:
            h = module(h, emb, context, num_frames=num_frames)
            hs.append(h)
        h = self.middle_block(h, emb, context, num_frames=num_frames)
        for module in self.output_blocks:
            # Concatenate the matching skip connection on the channel dimension.
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context, num_frames=num_frames)
        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        else:
            return self.out(h)
vae/config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.25.0",
4
+ "act_fn": "silu",
5
+ "block_out_channels": [
6
+ 128,
7
+ 256,
8
+ 512,
9
+ 512
10
+ ],
11
+ "down_block_types": [
12
+ "DownEncoderBlock2D",
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D"
16
+ ],
17
+ "force_upcast": true,
18
+ "in_channels": 3,
19
+ "latent_channels": 4,
20
+ "layers_per_block": 2,
21
+ "norm_num_groups": 32,
22
+ "out_channels": 3,
23
+ "sample_size": 256,
24
+ "scaling_factor": 0.18215,
25
+ "up_block_types": [
26
+ "UpDecoderBlock2D",
27
+ "UpDecoderBlock2D",
28
+ "UpDecoderBlock2D",
29
+ "UpDecoderBlock2D"
30
+ ]
31
+ }
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3e4c08995484ee61270175e9e7a072b66a6e4eeb5f0c266667fe1f45b90daf9a
3
+ size 167335342