maxin-cn committed on
Commit
6d2e0dc
•
1 Parent(s): 235f463

Delete models/dit.py

Files changed (1)
  1. models/dit.py +0 -617
models/dit.py DELETED
@@ -1,617 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------
import math
import torch
import torch.nn as nn
import numpy as np

from einops import rearrange, repeat
from timm.models.vision_transformer import Mlp, PatchEmbed


import os
import sys
# sys.path.append(os.getcwd())
sys.path.append(os.path.split(sys.path[0])[0])
# Explanation:
# sys.path[0] : yields C:\Users\maxu\Desktop\blog_test\pakage2
# os.path.split(sys.path[0]) : yields ('C:\Users\maxu\Desktop\blog_test', 'pakage2')
# cross-package imports inside mmcls work because mmcls is installed


# for i in sys.path:
#     print(i)

# the xformers lib allows lower memory usage and faster training and inference
try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except ImportError:
    XFORMERS_IS_AVAILBLE = False

# LoRA layers are only needed when a block is built with use_lora=True; this assumes
# the `loralib` package (imported as `lora`) provides MergedLinear.
try:
    import loralib as lora
except ImportError:
    lora = None

# from timm.models.layers.helpers import to_2tuple
# from timm.models.layers.trace_utils import _assert

def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
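
# Shape sketch (illustrative, matching how modulate() is called from DiTBlock below):
# x is a (B, N, C) token tensor while shift and scale are (B, C) adaLN outputs, so the
# unsqueeze(1) broadcasts one (shift, scale) pair over all N tokens of each sample, e.g.
#   x = torch.randn(2, 16, 8); shift = torch.zeros(2, 8); scale = torch.ones(2, 8)
#   modulate(x, shift, scale).shape  # -> torch.Size([2, 16, 8])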

#################################################################################
#                          Attention Layers from TIMM                           #
#################################################################################

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_lora=False, attention_mode='math'):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.attention_mode = attention_mode
        self.use_lora = use_lora

        if self.use_lora:
            self.qkv = lora.MergedLinear(dim, dim * 3, r=500, enable_lora=[True, False, True])
        else:
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous()
        q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)

        if self.attention_mode == 'xformers':  # causes NaN loss when used with amp
            x = xformers.ops.memory_efficient_attention(q, k, v).reshape(B, N, C)

        elif self.attention_mode == 'flash':
            # causes NaN loss when used with amp
            # Optionally use the context manager to ensure one of the fused kernels is run
            with torch.backends.cuda.sdp_kernel(enable_math=False):
                x = torch.nn.functional.scaled_dot_product_attention(q, k, v).reshape(B, N, C)  # requires PyTorch 2.0

        elif self.attention_mode == 'math':
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2).reshape(B, N, C)

        else:
            raise NotImplementedError

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
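
# Note on attention_mode (summarizing the branches above, not adding behaviour):
#   'math'     - plain softmax attention; no extra dependencies.
#   'xformers' - needs the optional xformers import at the top of this file to succeed.
#   'flash'    - uses torch.nn.functional.scaled_dot_product_attention (PyTorch >= 2.0).
# Per the in-line comments, both fused paths have produced NaN losses under amp.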


#################################################################################
#               Embedding Layers for Timesteps and Class Labels                 #
#################################################################################

class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb
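
# Shape sketch (illustrative): with the default frequency_embedding_size=256 and
# hidden_size=1152,
#   t = torch.tensor([0, 250, 999])                 # (3,) diffusion timesteps
#   TimestepEmbedder(1152)(t).shape                 # -> torch.Size([3, 1152])
# timestep_embedding() itself produces the intermediate (3, 256) sinusoidal features.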


class LabelEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
            print(drop_ids)
        else:
            drop_ids = force_drop_ids == 1
        labels = torch.where(drop_ids, self.num_classes, labels)
        print('******labels******', labels)
        return labels

    def forward(self, labels, train, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings


#################################################################################
#                                 Core DiT Model                                #
#################################################################################

class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x


class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class DiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        num_frames=16,
        class_dropout_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
        class_guided=False,
        use_lora=False,
        attention_mode='math',
    ):
        super().__init__()
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.class_guided = class_guided
        self.num_frames = num_frames

        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        self.t_embedder = TimestepEmbedder(hidden_size)

        if self.class_guided:
            self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)

        num_patches = self.x_embedder.num_patches
        # Will use fixed sin-cos embedding:
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
        self.time_embed = nn.Parameter(torch.zeros(1, num_frames, hidden_size), requires_grad=False)

        # Blocks are consumed in pairs in forward(): even-indexed blocks attend over the spatial
        # tokens of each frame, odd-indexed blocks over the temporal (frame) axis.
        if use_lora:
            # With LoRA enabled, only the temporal (odd-indexed) blocks get LoRA-augmented qkv projections.
            self.blocks = nn.ModuleList([
                DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attention_mode=attention_mode, use_lora=False if num % 2 == 0 else True) for num in range(depth)
            ])
        else:
            self.blocks = nn.ModuleList([
                DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, attention_mode=attention_mode) for _ in range(depth)
            ])

        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize (and freeze) pos_embed by sin-cos embedding:
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        time_embed = get_1d_sincos_time_embed(self.time_embed.shape[-1], self.time_embed.shape[-2])
        self.time_embed.data.copy_(torch.from_numpy(time_embed).float().unsqueeze(0))

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        if self.class_guided:
            # Initialize label embedding table:
            nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    # @torch.cuda.amp.autocast()
    # @torch.compile
    def forward(self, x, t, y=None):
        """
        Forward pass of DiT.
        x: (N, F, C, H, W) tensor of video inputs (frames of images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        """
        batches, frames, channels, height, width = x.shape   # for example, 3, 16, 4, 32, 32
        # After this rearrange, every group of `frames` consecutive entries along the batch axis
        # belongs to the same video clip.
        x = rearrange(x, 'b f c h w -> (b f) c h w')
        x = self.x_embedder(x) + self.pos_embed   # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(t)                    # (N, D)
        # The repeat for timestep_spatial must give every frame of a clip the same timestep embedding.
        timestep_spatial = repeat(t, 'n d -> (n c) d', c=self.time_embed.shape[1])   # e.g. 48, 1152; c = num_frames
        timestep_time = repeat(t, 'n d -> (n c) d', c=self.pos_embed.shape[1])       # e.g. 768, 1152; c = num tokens

        if self.class_guided:
            y = self.y_embedder(y, self.training)
            y_spatial = repeat(y, 'n d -> (n c) d', c=self.time_embed.shape[1])   # e.g. 48, 1152; c = num_frames
            y_time = repeat(y, 'n d -> (n c) d', c=self.pos_embed.shape[1])       # e.g. 768, 1152; c = num tokens

        # if self.class_guided:
        #     y = self.y_embedder(y, self.training)    # (N, D)
        #     c = timestep_spatial + y
        # else:
        #     c = timestep_spatial

        # for block in self.blocks:
        #     x = block(x, c)                          # (N, T, D)

        # Blocks are consumed in pairs: a spatial block over per-frame tokens,
        # then a temporal block over the frame axis.
        for i in range(0, len(self.blocks), 2):
            spatial_block, time_block = self.blocks[i:i + 2]

            if self.class_guided:
                c = timestep_spatial + y_spatial
            else:
                c = timestep_spatial
            x = spatial_block(x, c)

            x = rearrange(x, '(b f) t d -> (b t) f d', b=batches)   # t is the number of tokens per frame; e.g. 768, 16, 1152
            # Add the temporal embedding once, before the first temporal block.
            if i == 0:
                x = x + self.time_embed   # e.g. 768, 16, 1152

            if self.class_guided:
                c = timestep_time + y_time
            else:
                # timestep_time = repeat(t, 'n d -> (n c) d', c=x.shape[0] // batches)   # 768, 1152
                c = timestep_time

            x = time_block(x, c)
            x = rearrange(x, '(b t) f d -> (b f) t d', b=batches)

        if self.class_guided:
            c = timestep_spatial + y_spatial
        else:
            c = timestep_spatial
        x = self.final_layer(x, c)                # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)                    # (N, out_channels, H, W)
        x = rearrange(x, '(b f) c h w -> b f c h w', b=batches)
        return x

    def forward_motion(self, motions, t, base_frame, y=None):
        """
        Forward pass of DiT on motion latents.
        motions: (N, F, C, H, W) tensor of spatial inputs (frames of images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        """
        batches, frames, channels, height, width = motions.shape   # for example, 3, 16, 4, 32, 32
        # After this rearrange, every group of `frames` consecutive entries along the batch axis
        # belongs to the same video clip.
        x = rearrange(motions, 'b f c h w -> (b f) c h w')
        x = self.x_embedder(x) + self.pos_embed   # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(t)                    # (N, D)
        # The repeat for timestep_spatial must give every frame of a clip the same timestep embedding.
        timestep_spatial = repeat(t, 'n d -> (n c) d', c=self.time_embed.shape[1])   # e.g. 48, 1152; c = num_frames
        timestep_time = repeat(t, 'n d -> (n c) d', c=self.pos_embed.shape[1])       # e.g. 768, 1152; c = num tokens

        if self.class_guided:
            y = self.y_embedder(y, self.training)
            y_spatial = repeat(y, 'n d -> (n c) d', c=self.time_embed.shape[1])   # e.g. 48, 1152; c = num_frames
            y_time = repeat(y, 'n d -> (n c) d', c=self.pos_embed.shape[1])       # e.g. 768, 1152; c = num tokens

        # Same alternating spatial/temporal block schedule as forward(); base_frame is currently unused.
        for i in range(0, len(self.blocks), 2):
            spatial_block, time_block = self.blocks[i:i + 2]

            if self.class_guided:
                c = timestep_spatial + y_spatial
            else:
                c = timestep_spatial
            x = spatial_block(x, c)

            x = rearrange(x, '(b f) t d -> (b t) f d', b=batches)   # t is the number of tokens per frame; e.g. 768, 16, 1152
            # Add the temporal embedding once, before the first temporal block.
            if i == 0:
                x = x + self.time_embed   # e.g. 768, 16, 1152

            if self.class_guided:
                c = timestep_time + y_time
            else:
                c = timestep_time

            x = time_block(x, c)
            x = rearrange(x, '(b t) f d -> (b f) t d', b=batches)

        if self.class_guided:
            c = timestep_spatial + y_spatial
        else:
            c = timestep_spatial
        x = self.final_layer(x, c)                # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x)                    # (N, out_channels, H, W)
        x = rearrange(x, '(b f) c h w -> b f c h w', b=batches)
        return x


    def forward_with_cfg(self, x, t, y, cfg_scale):
        """
        Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, y)
        # For exact reproducibility reasons, we apply classifier-free guidance on only
        # three channels by default. The standard approach to cfg applies it to all channels.
        # This can be done by uncommenting the following line and commenting-out the line following that.
        # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)
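
# Typical driver for DiT.forward_with_cfg (an illustrative sketch of the GLIDE/DiT sampling
# recipe referenced above, not code from this module). It assumes the "null" label id equals
# num_classes, which is how LabelEmbedder.token_drop encodes dropped labels:
#   z = torch.cat([z, z], 0)                              # duplicate the conditional batch
#   y = torch.cat([y, torch.full_like(y, num_classes)], 0)   # second half uses the null class
#   out = model.forward_with_cfg(z, t, y, cfg_scale=4.0)
# The method re-duplicates the first half internally and mixes cond/uncond noise predictions.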


#################################################################################
#                   Sine/Cosine Positional Embedding Functions                  #
#################################################################################
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py

def get_1d_sincos_time_embed(embed_dim, length):
    pos = torch.arange(0, length).unsqueeze(1)
    return get_1d_sincos_pos_embed_from_grid(embed_dim, pos)

def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


#################################################################################
#                                   DiT Configs                                 #
#################################################################################

def DiT_XL_2(**kwargs):
    return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)

def DiT_XL_4(**kwargs):
    return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)

def DiT_XL_8(**kwargs):
    return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)

def DiT_L_2(**kwargs):
    return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)

def DiT_L_4(**kwargs):
    return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)

def DiT_L_8(**kwargs):
    return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)

def DiT_B_2(**kwargs):
    return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)

def DiT_B_4(**kwargs):
    return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)

def DiT_B_8(**kwargs):
    return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)

def DiT_S_2(**kwargs):
    return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)

def DiT_S_4(**kwargs):
    return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)

def DiT_S_8(**kwargs):
    return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)


DiT_models = {
    'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8,
    'DiT-L/2':  DiT_L_2,  'DiT-L/4':  DiT_L_4,  'DiT-L/8':  DiT_L_8,
    'DiT-B/2':  DiT_B_2,  'DiT-B/4':  DiT_B_4,  'DiT-B/8':  DiT_B_8,
    'DiT-S/2':  DiT_S_2,  'DiT-S/4':  DiT_S_4,  'DiT-S/8':  DiT_S_8,
}
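
# Usage sketch (illustrative; the argument values below are examples, not extra defaults):
#   model = DiT_models['DiT-XL/2'](input_size=32, num_frames=16, class_guided=False)
#   x = torch.randn(2, 16, 4, 32, 32)     # (batch, frames, latent channels, H, W)
#   t = torch.randint(0, 1000, (2,))
#   out = model(x, t)                     # -> (2, 16, 8, 32, 32); learn_sigma=True doubles the channels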

if __name__ == '__main__':

    device = "cuda" if torch.cuda.is_available() else "cpu"

    img = torch.randn(3, 16, 4, 32, 32).to(device)
    t = torch.tensor([1, 2, 3]).to(device)
    y = torch.tensor([1, 2, 3]).to(device)
    network = DiT_XL_2().to(device)
    y_embedder = LabelEmbedder(num_classes=100, hidden_size=768, dropout_prob=0.5).to(device)
    # lora.mark_only_lora_as_trainable(network)
    out = y_embedder(y, True)
    # out = network(img, t, y)
    print(out.shape)