rookie9 committed
Commit f582ec6 · verified · 1 Parent(s): 39cfe13

Upload 77 files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. app.py +77 -0
  2. models/__pycache__/common.cpython-310.pyc +0 -0
  3. models/__pycache__/content_adapter.cpython-310.pyc +0 -0
  4. models/__pycache__/diffusion.cpython-310.pyc +0 -0
  5. models/__pycache__/diffusion_cfg.cpython-310.pyc +0 -0
  6. models/__pycache__/diffusion_cfg_new.cpython-310.pyc +0 -0
  7. models/__pycache__/diffusion_content_cfg.cpython-310.pyc +0 -0
  8. models/autoencoder/__pycache__/autoencoder_base.cpython-310.pyc +0 -0
  9. models/autoencoder/autoencoder_base.py +22 -0
  10. models/autoencoder/waveform/__pycache__/stable_vae.cpython-310.pyc +0 -0
  11. models/autoencoder/waveform/stable_vae.py +537 -0
  12. models/common.py +69 -0
  13. models/content_encoder/__pycache__/caption_encoder.cpython-310.pyc +0 -0
  14. models/content_encoder/__pycache__/content_encoder.cpython-310.pyc +0 -0
  15. models/content_encoder/__pycache__/content_encoder_add_1024.cpython-310.pyc +0 -0
  16. models/content_encoder/__pycache__/content_encoder_clap.cpython-310.pyc +0 -0
  17. models/content_encoder/__pycache__/content_encoder_clap_test.cpython-310.pyc +0 -0
  18. models/content_encoder/__pycache__/content_encoder_concat.cpython-310.pyc +0 -0
  19. models/content_encoder/__pycache__/content_encoder_concat_4096.cpython-310.pyc +0 -0
  20. models/content_encoder/__pycache__/content_encoder_concat_4096_random.cpython-310.pyc +0 -0
  21. models/content_encoder/__pycache__/content_encoder_full.cpython-310.pyc +0 -0
  22. models/content_encoder/__pycache__/content_encoder_full_non.cpython-310.pyc +0 -0
  23. models/content_encoder/__pycache__/content_encoder_full_non_test.cpython-310.pyc +0 -0
  24. models/content_encoder/__pycache__/content_encoder_full_test.cpython-310.pyc +0 -0
  25. models/content_encoder/__pycache__/content_encoder_full_woonset.cpython-310.pyc +0 -0
  26. models/content_encoder/__pycache__/content_encoder_merge.cpython-310.pyc +0 -0
  27. models/content_encoder/__pycache__/content_encoder_merge_test.cpython-310.pyc +0 -0
  28. models/content_encoder/__pycache__/content_encoder_replace.cpython-310.pyc +0 -0
  29. models/content_encoder/__pycache__/content_encoder_replace_merge.cpython-310.pyc +0 -0
  30. models/content_encoder/__pycache__/content_encoder_replace_new.cpython-310.pyc +0 -0
  31. models/content_encoder/__pycache__/content_encoder_test.cpython-310.pyc +0 -0
  32. models/content_encoder/__pycache__/content_test.cpython-310.pyc +0 -0
  33. models/content_encoder/__pycache__/new_content_encoder.cpython-310.pyc +0 -0
  34. models/content_encoder/__pycache__/text_encoder.cpython-310.pyc +0 -0
  35. models/content_encoder/caption_encoder.py +116 -0
  36. models/content_encoder/text_encoder.py +76 -0
  37. models/diffusion.py +398 -0
  38. models/dit/__pycache__/attention.cpython-310.pyc +0 -0
  39. models/dit/__pycache__/audio_dit.cpython-310.pyc +0 -0
  40. models/dit/__pycache__/mask_dit.cpython-310.pyc +0 -0
  41. models/dit/__pycache__/modules.cpython-310.pyc +0 -0
  42. models/dit/__pycache__/rotary.cpython-310.pyc +0 -0
  43. models/dit/__pycache__/span_mask.cpython-310.pyc +0 -0
  44. models/dit/attention.py +350 -0
  45. models/dit/audio_diffsingernet_dit.py +520 -0
  46. models/dit/audio_dit.py +549 -0
  47. models/dit/mask_dit.py +823 -0
  48. models/dit/modules.py +445 -0
  49. models/dit/rotary.py +88 -0
  50. models/dit/span_mask.py +149 -0
app.py ADDED
@@ -0,0 +1,77 @@
1
+ import gradio as gr
2
+ import os
3
+ import json
4
+ import torch
5
+ import soundfile as sf
6
+ import numpy as np
7
+ from pathlib import Path
8
+ from transformers import AutoModel
9
+ #from utils.llm import get_time_info
10
+ from utils.llm_xiapi import get_time_info
11
+
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ model = AutoModel.from_pretrained("rookie9/PicoAudio2", trust_remote_code=True).to(device)
14
+ print("ok")
15
+ def is_tdc_format_valid(tdc_str):
16
+ try:
17
+ for event_onset in tdc_str.split('--'):
18
+ event, instance = event_onset.split('__')
19
+ for start_end in instance.split('_'):
20
+ start, end = start_end.split('-')
21
+ return True
22
+ except Exception:
23
+ return False
24
+
25
+ def infer(input_text, input_onset, input_length, time_control):
26
+ # para
27
+ if input_onset and not is_tdc_format_valid(input_onset):
28
+ input_onset = "random"
29
+ if time_control:
30
+ if not input_onset or not input_length:
31
+ input_json = json.loads(get_time_info(input_text))
32
+ input_onset, input_length = input_json["onset"], input_json["length"]
33
+ else:
34
+ input_onset = input_onset if input_onset else "random"
35
+ input_length = input_length if input_length else "10.0"
36
+
37
+ content = {
38
+ "caption": input_text,
39
+ "onset": input_onset,
40
+ "length": input_length
41
+ }
42
+
43
+
44
+ with torch.no_grad():
45
+ waveform = model(content)
46
+ output_wav = "output.wav"
47
+ sf.write(
48
+ output_wav,
49
+ waveform[0, 0].cpu().numpy(),
50
+ samplerate=24000,  # audio sample rate used by PicoAudio2 (24 kHz, see decode_data in caption_encoder.py)
51
+ )
52
+ return output_wav, str(input_onset)
53
+
54
+ demo = gr.Interface(
55
+ fn=infer,
56
+ inputs=[
57
+ gr.Textbox(label="TCC (caption, required)", value="a dog barks"),
58
+ gr.Textbox(label="TDC (optional, see format)", value="random"),
59
+ gr.Textbox(label="Length (seconds, optional)", value="10.0"),
60
+ gr.Checkbox(label="Enable Time Control", value=False),
61
+ ],
62
+ outputs=[
63
+ gr.Audio(label="Generated Audio"),
64
+ gr.Textbox(label="Final TDC Used (input_onset)")
65
+ ],
66
+ title="PicoAudio2 Online Inference",
67
+ description=(
68
+ "TCC (caption) is needed to generate audio. "
69
+ "If you need time control, please enter TDC and length (in seconds). "
70
+ "Alternatively, you can let the LLM generate TDC, but API quota limits may affect availability. "
71
+ "TDC format: \"event1__start1-end1_start2-end2--event2__start1-end1\", for example: "
72
+ "\"a_dog_barks__1.0-2.0_3.0-4.0--a_man_speaks__5.0-6.0\"."
73
+ " If the TDC format is invalid or the length is missing, the model will generate audio without temporal control. Sorry!"
74
+ )
75
+ )
76
+ if __name__ == "__main__":
77
+ demo.launch()
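For reference, a minimal standalone sketch of the TDC grammar described in the interface text above ("event1__start1-end1_start2-end2--event2__start1-end1"). `parse_tdc` is a hypothetical helper, not part of app.py; it only mirrors the splitting that is_tdc_format_valid performs.

    def parse_tdc(tdc_str: str) -> dict[str, list[tuple[float, float]]]:
        # Split "event__spans" groups on "--", spans on "_", each span on "-".
        events: dict[str, list[tuple[float, float]]] = {}
        for event_onset in tdc_str.split("--"):
            event, instance = event_onset.split("__")
            events[event] = [tuple(map(float, span.split("-"))) for span in instance.split("_")]
        return events

    print(parse_tdc("a_dog_barks__1.0-2.0_3.0-4.0--a_man_speaks__5.0-6.0"))
    # {'a_dog_barks': [(1.0, 2.0), (3.0, 4.0)], 'a_man_speaks': [(5.0, 6.0)]}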
models/__pycache__/common.cpython-310.pyc ADDED
Binary file (2.94 kB). View file
 
models/__pycache__/content_adapter.cpython-310.pyc ADDED
Binary file (3.87 kB). View file
 
models/__pycache__/diffusion.cpython-310.pyc ADDED
Binary file (10.5 kB). View file
 
models/__pycache__/diffusion_cfg.cpython-310.pyc ADDED
Binary file (18.9 kB). View file
 
models/__pycache__/diffusion_cfg_new.cpython-310.pyc ADDED
Binary file (18.8 kB). View file
 
models/__pycache__/diffusion_content_cfg.cpython-310.pyc ADDED
Binary file (18.5 kB). View file
 
models/autoencoder/__pycache__/autoencoder_base.cpython-310.pyc ADDED
Binary file (1.06 kB). View file
 
models/autoencoder/autoencoder_base.py ADDED
@@ -0,0 +1,22 @@
1
+ from abc import abstractmethod, ABC
2
+ from typing import Sequence
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class AutoEncoderBase(ABC):
8
+ def __init__(
9
+ self, downsampling_ratio: int, sample_rate: int,
10
+ latent_shape: Sequence[int | None]
11
+ ):
12
+ self.downsampling_ratio = downsampling_ratio
13
+ self.sample_rate = sample_rate
14
+ self.latent_token_rate = sample_rate // downsampling_ratio
15
+ self.latent_shape = latent_shape
16
+ self.time_dim = latent_shape.index(None) + 1 # the first dim is batch
17
+
18
+ @abstractmethod
19
+ def encode(
20
+ self, waveform: torch.Tensor, waveform_lengths: torch.Tensor
21
+ ) -> tuple[torch.Tensor, torch.Tensor]:
22
+ ...
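A small usage sketch of the base class above. This is illustrative only: DummyAutoEncoder is a hypothetical subclass, and the 24000 Hz sample rate, 480x downsampling ratio and 128 latent channels are the values quoted in comments elsewhere in this commit.

    import torch
    from models.autoencoder.autoencoder_base import AutoEncoderBase

    class DummyAutoEncoder(AutoEncoderBase):
        def encode(self, waveform: torch.Tensor, waveform_lengths: torch.Tensor):
            raise NotImplementedError  # placeholder; a real subclass returns (latents, latent_mask)

    ae = DummyAutoEncoder(downsampling_ratio=480, sample_rate=24000, latent_shape=(128, None))
    print(ae.latent_token_rate)  # 24000 // 480 = 50 latent frames per second
    print(ae.time_dim)           # (128, None).index(None) + 1 = 2, i.e. (batch, channel, time)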
models/autoencoder/waveform/__pycache__/stable_vae.cpython-310.pyc ADDED
Binary file (12 kB). View file
 
models/autoencoder/waveform/stable_vae.py ADDED
@@ -0,0 +1,537 @@
1
+ from typing import Any, Literal, Callable
2
+ import math
3
+ from pathlib import Path
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.nn.utils.parametrizations import weight_norm
8
+ import torchaudio
9
+ from alias_free_torch import Activation1d
10
+
11
+ from models.common import LoadPretrainedBase
12
+ from models.autoencoder.autoencoder_base import AutoEncoderBase
13
+ from utils.torch_utilities import remove_key_prefix_factory, create_mask_from_length
14
+
15
+
16
+ # jit script make it 1.4x faster and save GPU memory
17
+ @torch.jit.script
18
+ def snake_beta(x, alpha, beta):
19
+ return x + (1.0 / (beta+0.000000001)) * pow(torch.sin(x * alpha), 2)
20
+
21
+
22
+ class SnakeBeta(nn.Module):
23
+ def __init__(
24
+ self,
25
+ in_features,
26
+ alpha=1.0,
27
+ alpha_trainable=True,
28
+ alpha_logscale=True
29
+ ):
30
+ super(SnakeBeta, self).__init__()
31
+ self.in_features = in_features
32
+
33
+ # initialize alpha
34
+ self.alpha_logscale = alpha_logscale
35
+ if self.alpha_logscale:
36
+ # log scale alphas initialized to zeros
37
+ self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
38
+ self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
39
+ else:
40
+ # linear scale alphas initialized to ones
41
+ self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
42
+ self.beta = nn.Parameter(torch.ones(in_features) * alpha)
43
+
44
+ self.alpha.requires_grad = alpha_trainable
45
+ self.beta.requires_grad = alpha_trainable
46
+
47
+ # self.no_div_by_zero = 0.000000001
48
+
49
+ def forward(self, x):
50
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1)
51
+ # line up with x to [B, C, T]
52
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
53
+ if self.alpha_logscale:
54
+ alpha = torch.exp(alpha)
55
+ beta = torch.exp(beta)
56
+ x = snake_beta(x, alpha, beta)
57
+
58
+ return x
59
+
60
+
61
+ def WNConv1d(*args, **kwargs):
62
+ return weight_norm(nn.Conv1d(*args, **kwargs))
63
+
64
+
65
+ def WNConvTranspose1d(*args, **kwargs):
66
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
67
+
68
+
69
+ def get_activation(
70
+ activation: Literal["elu", "snake", "none"],
71
+ antialias=False,
72
+ channels=None
73
+ ) -> nn.Module:
74
+ if activation == "elu":
75
+ act = nn.ELU()
76
+ elif activation == "snake":
77
+ act = SnakeBeta(channels)
78
+ elif activation == "none":
79
+ act = nn.Identity()
80
+ else:
81
+ raise ValueError(f"Unknown activation {activation}")
82
+
83
+ if antialias:
84
+ act = Activation1d(act)
85
+
86
+ return act
87
+
88
+
89
+ class ResidualUnit(nn.Module):
90
+ def __init__(
91
+ self,
92
+ in_channels,
93
+ out_channels,
94
+ dilation,
95
+ use_snake=False,
96
+ antialias_activation=False
97
+ ):
98
+ super().__init__()
99
+
100
+ self.dilation = dilation
101
+
102
+ padding = (dilation * (7-1)) // 2
103
+
104
+ self.layers = nn.Sequential(
105
+ get_activation(
106
+ "snake" if use_snake else "elu",
107
+ antialias=antialias_activation,
108
+ channels=out_channels
109
+ ),
110
+ WNConv1d(
111
+ in_channels=in_channels,
112
+ out_channels=out_channels,
113
+ kernel_size=7,
114
+ dilation=dilation,
115
+ padding=padding
116
+ ),
117
+ get_activation(
118
+ "snake" if use_snake else "elu",
119
+ antialias=antialias_activation,
120
+ channels=out_channels
121
+ ),
122
+ WNConv1d(
123
+ in_channels=out_channels,
124
+ out_channels=out_channels,
125
+ kernel_size=1
126
+ )
127
+ )
128
+
129
+ def forward(self, x):
130
+ res = x
131
+
132
+ #x = checkpoint(self.layers, x)
133
+ x = self.layers(x)
134
+
135
+ return x + res
136
+
137
+
138
+ class EncoderBlock(nn.Module):
139
+ def __init__(
140
+ self,
141
+ in_channels,
142
+ out_channels,
143
+ stride,
144
+ use_snake=False,
145
+ antialias_activation=False
146
+ ):
147
+ super().__init__()
148
+
149
+ self.layers = nn.Sequential(
150
+ ResidualUnit(
151
+ in_channels=in_channels,
152
+ out_channels=in_channels,
153
+ dilation=1,
154
+ use_snake=use_snake
155
+ ),
156
+ ResidualUnit(
157
+ in_channels=in_channels,
158
+ out_channels=in_channels,
159
+ dilation=3,
160
+ use_snake=use_snake
161
+ ),
162
+ ResidualUnit(
163
+ in_channels=in_channels,
164
+ out_channels=in_channels,
165
+ dilation=9,
166
+ use_snake=use_snake
167
+ ),
168
+ get_activation(
169
+ "snake" if use_snake else "elu",
170
+ antialias=antialias_activation,
171
+ channels=in_channels
172
+ ),
173
+ WNConv1d(
174
+ in_channels=in_channels,
175
+ out_channels=out_channels,
176
+ kernel_size=2 * stride,
177
+ stride=stride,
178
+ padding=math.ceil(stride / 2)
179
+ ),
180
+ )
181
+
182
+ def forward(self, x):
183
+ return self.layers(x)
184
+
185
+
186
+ class DecoderBlock(nn.Module):
187
+ def __init__(
188
+ self,
189
+ in_channels,
190
+ out_channels,
191
+ stride,
192
+ use_snake=False,
193
+ antialias_activation=False,
194
+ use_nearest_upsample=False
195
+ ):
196
+ super().__init__()
197
+
198
+ if use_nearest_upsample:
199
+ upsample_layer = nn.Sequential(
200
+ nn.Upsample(scale_factor=stride, mode="nearest"),
201
+ WNConv1d(
202
+ in_channels=in_channels,
203
+ out_channels=out_channels,
204
+ kernel_size=2 * stride,
205
+ stride=1,
206
+ bias=False,
207
+ padding='same'
208
+ )
209
+ )
210
+ else:
211
+ upsample_layer = WNConvTranspose1d(
212
+ in_channels=in_channels,
213
+ out_channels=out_channels,
214
+ kernel_size=2 * stride,
215
+ stride=stride,
216
+ padding=math.ceil(stride / 2)
217
+ )
218
+
219
+ self.layers = nn.Sequential(
220
+ get_activation(
221
+ "snake" if use_snake else "elu",
222
+ antialias=antialias_activation,
223
+ channels=in_channels
224
+ ),
225
+ upsample_layer,
226
+ ResidualUnit(
227
+ in_channels=out_channels,
228
+ out_channels=out_channels,
229
+ dilation=1,
230
+ use_snake=use_snake
231
+ ),
232
+ ResidualUnit(
233
+ in_channels=out_channels,
234
+ out_channels=out_channels,
235
+ dilation=3,
236
+ use_snake=use_snake
237
+ ),
238
+ ResidualUnit(
239
+ in_channels=out_channels,
240
+ out_channels=out_channels,
241
+ dilation=9,
242
+ use_snake=use_snake
243
+ ),
244
+ )
245
+
246
+ def forward(self, x):
247
+ return self.layers(x)
248
+
249
+
250
+ class OobleckEncoder(nn.Module):
251
+ def __init__(
252
+ self,
253
+ in_channels=2,
254
+ channels=128,
255
+ latent_dim=32,
256
+ c_mults=[1, 2, 4, 8],
257
+ strides=[2, 4, 8, 8],
258
+ use_snake=False,
259
+ antialias_activation=False
260
+ ):
261
+ super().__init__()
262
+
263
+ c_mults = [1] + c_mults
264
+
265
+ self.depth = len(c_mults)
266
+
267
+ layers = [
268
+ WNConv1d(
269
+ in_channels=in_channels,
270
+ out_channels=c_mults[0] * channels,
271
+ kernel_size=7,
272
+ padding=3
273
+ )
274
+ ]
275
+
276
+ for i in range(self.depth - 1):
277
+ layers += [
278
+ EncoderBlock(
279
+ in_channels=c_mults[i] * channels,
280
+ out_channels=c_mults[i + 1] * channels,
281
+ stride=strides[i],
282
+ use_snake=use_snake
283
+ )
284
+ ]
285
+
286
+ layers += [
287
+ get_activation(
288
+ "snake" if use_snake else "elu",
289
+ antialias=antialias_activation,
290
+ channels=c_mults[-1] * channels
291
+ ),
292
+ WNConv1d(
293
+ in_channels=c_mults[-1] * channels,
294
+ out_channels=latent_dim,
295
+ kernel_size=3,
296
+ padding=1
297
+ )
298
+ ]
299
+
300
+ self.layers = nn.Sequential(*layers)
301
+
302
+ def forward(self, x):
303
+ return self.layers(x)
304
+
305
+
306
+ class OobleckDecoder(nn.Module):
307
+ def __init__(
308
+ self,
309
+ out_channels=2,
310
+ channels=128,
311
+ latent_dim=32,
312
+ c_mults=[1, 2, 4, 8],
313
+ strides=[2, 4, 8, 8],
314
+ use_snake=False,
315
+ antialias_activation=False,
316
+ use_nearest_upsample=False,
317
+ final_tanh=True
318
+ ):
319
+ super().__init__()
320
+
321
+ c_mults = [1] + c_mults
322
+
323
+ self.depth = len(c_mults)
324
+
325
+ layers = [
326
+ WNConv1d(
327
+ in_channels=latent_dim,
328
+ out_channels=c_mults[-1] * channels,
329
+ kernel_size=7,
330
+ padding=3
331
+ ),
332
+ ]
333
+
334
+ for i in range(self.depth - 1, 0, -1):
335
+ layers += [
336
+ DecoderBlock(
337
+ in_channels=c_mults[i] * channels,
338
+ out_channels=c_mults[i - 1] * channels,
339
+ stride=strides[i - 1],
340
+ use_snake=use_snake,
341
+ antialias_activation=antialias_activation,
342
+ use_nearest_upsample=use_nearest_upsample
343
+ )
344
+ ]
345
+
346
+ layers += [
347
+ get_activation(
348
+ "snake" if use_snake else "elu",
349
+ antialias=antialias_activation,
350
+ channels=c_mults[0] * channels
351
+ ),
352
+ WNConv1d(
353
+ in_channels=c_mults[0] * channels,
354
+ out_channels=out_channels,
355
+ kernel_size=7,
356
+ padding=3,
357
+ bias=False
358
+ ),
359
+ nn.Tanh() if final_tanh else nn.Identity()
360
+ ]
361
+
362
+ self.layers = nn.Sequential(*layers)
363
+
364
+ def forward(self, x):
365
+ return self.layers(x)
366
+
367
+
368
+ class Bottleneck(nn.Module):
369
+ def __init__(self, is_discrete: bool = False):
370
+ super().__init__()
371
+
372
+ self.is_discrete = is_discrete
373
+
374
+ def encode(self, x, return_info=False, **kwargs):
375
+ raise NotImplementedError
376
+
377
+ def decode(self, x):
378
+ raise NotImplementedError
379
+
380
+
381
+ @torch.jit.script
382
+ def vae_sample(mean, scale) -> dict[str, torch.Tensor]:
383
+ stdev = nn.functional.softplus(scale) + 1e-4
384
+ var = stdev * stdev
385
+ logvar = torch.log(var)
386
+ latents = torch.randn_like(mean) * stdev + mean
387
+
388
+ kl = (mean*mean + var - logvar - 1).sum(1).mean()
389
+ return {"latents": latents, "kl": kl}
390
+
391
+
392
+ class VAEBottleneck(Bottleneck):
393
+ def __init__(self):
394
+ super().__init__(is_discrete=False)
395
+
396
+ def encode(self,
397
+ x,
398
+ return_info=False,
399
+ **kwargs) -> dict[str, torch.Tensor] | torch.Tensor:
400
+ mean, scale = x.chunk(2, dim=1)
401
+ sampled = vae_sample(mean, scale)
402
+
403
+ if return_info:
404
+ return sampled["latents"], {"kl": sampled["kl"]}
405
+ else:
406
+ return sampled["latents"]
407
+
408
+ def decode(self, x):
409
+ return x
410
+
411
+
412
+ def compute_mean_kernel(x, y):
413
+ kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
414
+ return torch.exp(-kernel_input).mean()
415
+
416
+
417
+ class Pretransform(nn.Module):
418
+ def __init__(self, enable_grad, io_channels, is_discrete):
419
+ super().__init__()
420
+
421
+ self.is_discrete = is_discrete
422
+ self.io_channels = io_channels
423
+ self.encoded_channels = None
424
+ self.downsampling_ratio = None
425
+
426
+ self.enable_grad = enable_grad
427
+
428
+ def encode(self, x):
429
+ raise NotImplementedError
430
+
431
+ def decode(self, z):
432
+ raise NotImplementedError
433
+
434
+ def tokenize(self, x):
435
+ raise NotImplementedError
436
+
437
+ def decode_tokens(self, tokens):
438
+ raise NotImplementedError
439
+
440
+
441
+ class StableVAE(LoadPretrainedBase, AutoEncoderBase):
442
+ def __init__(
443
+ self,
444
+ encoder,
445
+ decoder,
446
+ latent_dim,
447
+ downsampling_ratio,
448
+ sample_rate,
449
+ io_channels=2,
450
+ bottleneck: Bottleneck = None,
451
+ pretransform: Pretransform = None,
452
+ in_channels=None,
453
+ out_channels=None,
454
+ soft_clip=False,
455
+ pretrained_ckpt: str | Path = None
456
+ ):
457
+ LoadPretrainedBase.__init__(self)
458
+ AutoEncoderBase.__init__(
459
+ self,
460
+ downsampling_ratio=downsampling_ratio,
461
+ sample_rate=sample_rate,
462
+ latent_shape=(latent_dim, None)
463
+ )
464
+
465
+ self.latent_dim = latent_dim
466
+ self.io_channels = io_channels
467
+ self.in_channels = io_channels
468
+ self.out_channels = io_channels
469
+ self.min_length = self.downsampling_ratio
470
+
471
+ if in_channels is not None:
472
+ self.in_channels = in_channels
473
+
474
+ if out_channels is not None:
475
+ self.out_channels = out_channels
476
+
477
+ self.bottleneck = bottleneck
478
+ self.encoder = encoder
479
+ self.decoder = decoder
480
+ self.pretransform = pretransform
481
+ self.soft_clip = soft_clip
482
+ self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete
483
+
484
+ self.remove_autoencoder_prefix_fn: Callable = remove_key_prefix_factory(
485
+ "autoencoder."
486
+ )
487
+ if pretrained_ckpt is not None:
488
+ self.load_pretrained(pretrained_ckpt)
489
+
490
+ def process_state_dict(self, model_dict, state_dict):
491
+ state_dict = state_dict["state_dict"]
492
+ state_dict = self.remove_autoencoder_prefix_fn(model_dict, state_dict)
493
+ return state_dict
494
+
495
+ def encode(
496
+ self, waveform: torch.Tensor, waveform_lengths: torch.Tensor
497
+ ) -> tuple[torch.Tensor, torch.Tensor]:
498
+ z = self.encoder(waveform)
499
+ z = self.bottleneck.encode(z)
500
+ z_length = waveform_lengths // self.downsampling_ratio
501
+ z_mask = create_mask_from_length(z_length)
502
+ return z, z_mask
503
+
504
+ def decode(self, latents: torch.Tensor) -> torch.Tensor:
505
+ waveform = self.decoder(latents)
506
+ return waveform
507
+
508
+
509
+ if __name__ == '__main__':
510
+ import hydra
511
+ from utils.config import generate_config_from_command_line_overrides
512
+ model_config = generate_config_from_command_line_overrides(
513
+ "configs/model/autoencoder/stable_vae.yaml"
514
+ )
515
+ autoencoder: StableVAE = hydra.utils.instantiate(model_config)
516
+ autoencoder.eval()
517
+
518
+ waveform, sr = torchaudio.load(
519
+ "/hpc_stor03/sjtu_home/xuenan.xu/workspace/singing_voice_synthesis/diffsinger/data/raw/opencpop/segments/wavs/2007000230.wav"
520
+ )
521
+ waveform = torchaudio.functional.resample(
522
+ waveform, sr, model_config["sample_rate"]
523
+ )
524
+ print("waveform: ", waveform.shape)
525
+ with torch.no_grad():
526
+ latent, latent_length = autoencoder.encode(
527
+ waveform, torch.as_tensor([waveform.shape[-1]])
528
+ )
529
+ print("latent: ", latent.shape)
530
+ reconstructed = autoencoder.decode(latent)
531
+ print("reconstructed: ", reconstructed.shape)
532
+ import soundfile as sf
533
+ sf.write(
534
+ "./reconstructed.wav",
535
+ reconstructed[0, 0].numpy(),
536
+ samplerate=model_config["sample_rate"]
537
+ )
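As a plain-PyTorch illustration of the reparameterization implemented by vae_sample/VAEBottleneck above: a sketch with made-up shapes, whereas the real code is TorchScripted and receives mean/scale by chunking the encoder output along the channel dimension.

    import torch
    import torch.nn.functional as F

    mean = torch.zeros(2, 64, 50)            # (batch, latent_dim, time), hypothetical sizes
    scale = torch.zeros(2, 64, 50)
    stdev = F.softplus(scale) + 1e-4         # strictly positive std, as in vae_sample
    latents = mean + stdev * torch.randn_like(mean)
    var = stdev * stdev
    kl = (mean * mean + var - torch.log(var) - 1).sum(1).mean()
    print(latents.shape, kl.item())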
models/common.py ADDED
@@ -0,0 +1,69 @@
1
+ from pathlib import Path
2
+ import torch
3
+ import torch.nn as nn
4
+ from utils.torch_utilities import load_pretrained_model, merge_matched_keys
5
+ import warnings
6
+
7
+ class LoadPretrainedBase(nn.Module):
8
+ def process_state_dict(
9
+ self, model_dict: dict[str, torch.Tensor],
10
+ state_dict: dict[str, torch.Tensor]
11
+ ):
12
+ """
13
+ Custom processing functions of each model that transforms `state_dict` loaded from
14
+ checkpoints to the state that can be used in `load_state_dict`.
15
+ Use `merge_matched_keys` to update parameters with matched names and shapes by
16
+ default.
17
+
18
+ Args:
19
+ model_dict:
20
+ The state dict of the current model, which is going to load pretrained parameters
21
+ state_dict:
22
+ A dictionary of parameters from a pre-trained model.
23
+
24
+ Returns:
25
+ dict[str, torch.Tensor]:
26
+ The updated state dict, where parameters with matched keys and shape are
27
+ updated with values in `state_dict`.
28
+ """
29
+ state_dict = merge_matched_keys(model_dict, state_dict)
30
+ return state_dict
31
+
32
+ def load_pretrained(self, ckpt_path: str | Path):
33
+ load_pretrained_model(
34
+ self, ckpt_path, state_dict_process_fn=self.process_state_dict
35
+ )
36
+
37
+
38
+ class CountParamsBase(nn.Module):
39
+ def count_params(self):
40
+ num_params = 0
41
+ trainable_params = 0
42
+ for param in self.parameters():
43
+ num_params += param.numel()
44
+ if param.requires_grad:
45
+ trainable_params += param.numel()
46
+ return num_params, trainable_params
47
+
48
+
49
+ class SaveTrainableParamsBase(nn.Module):
50
+ @property
51
+ def param_names_to_save(self):
52
+ names = []
53
+ for name, param in self.named_parameters():
54
+ if param.requires_grad:
55
+ names.append(name)
56
+ for name, _ in self.named_buffers():
57
+ names.append(name)
58
+ return names
59
+
60
+ def load_state_dict(self, state_dict, strict=True, assign=True):
61
+ print("State dict keys:", list(state_dict.keys()))
62
+ #for key in self.param_names_to_save:
63
+ # if key not in state_dict:
64
+ # raise Exception(
65
+ # f"{key} not found in either pre-trained models (e.g. BERT)"
66
+ # " or resumed checkpoints (e.g. epoch_40/model.pt)"
67
+ # )
68
+ # Pass through the `assign` argument for compatibility across PyTorch/transformers versions
69
+ return super().load_state_dict(state_dict, strict=strict, assign=assign)
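The docstring above describes merging only parameters whose names and shapes match. `utils.torch_utilities.merge_matched_keys` itself is not shown in this commit, so the following is only an assumed sketch of that behaviour.

    import torch

    def merge_matched_keys_sketch(model_dict: dict[str, torch.Tensor],
                                  state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        merged = dict(model_dict)                              # start from the current model state
        for name, tensor in state_dict.items():
            if name in merged and merged[name].shape == tensor.shape:
                merged[name] = tensor                          # adopt the pretrained value
        return merged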
models/content_encoder/__pycache__/caption_encoder.cpython-310.pyc ADDED
Binary file (3.51 kB). View file
 
models/content_encoder/__pycache__/content_encoder.cpython-310.pyc ADDED
Binary file (4.72 kB). View file
 
models/content_encoder/__pycache__/content_encoder_add_1024.cpython-310.pyc ADDED
Binary file (4.62 kB). View file
 
models/content_encoder/__pycache__/content_encoder_clap.cpython-310.pyc ADDED
Binary file (6.11 kB). View file
 
models/content_encoder/__pycache__/content_encoder_clap_test.cpython-310.pyc ADDED
Binary file (6.12 kB). View file
 
models/content_encoder/__pycache__/content_encoder_concat.cpython-310.pyc ADDED
Binary file (4.74 kB). View file
 
models/content_encoder/__pycache__/content_encoder_concat_4096.cpython-310.pyc ADDED
Binary file (4.69 kB). View file
 
models/content_encoder/__pycache__/content_encoder_concat_4096_random.cpython-310.pyc ADDED
Binary file (4.73 kB). View file
 
models/content_encoder/__pycache__/content_encoder_full.cpython-310.pyc ADDED
Binary file (5.01 kB). View file
 
models/content_encoder/__pycache__/content_encoder_full_non.cpython-310.pyc ADDED
Binary file (5 kB). View file
 
models/content_encoder/__pycache__/content_encoder_full_non_test.cpython-310.pyc ADDED
Binary file (4.87 kB). View file
 
models/content_encoder/__pycache__/content_encoder_full_test.cpython-310.pyc ADDED
Binary file (4.48 kB). View file
 
models/content_encoder/__pycache__/content_encoder_full_woonset.cpython-310.pyc ADDED
Binary file (4.59 kB). View file
 
models/content_encoder/__pycache__/content_encoder_merge.cpython-310.pyc ADDED
Binary file (4.78 kB). View file
 
models/content_encoder/__pycache__/content_encoder_merge_test.cpython-310.pyc ADDED
Binary file (4.82 kB). View file
 
models/content_encoder/__pycache__/content_encoder_replace.cpython-310.pyc ADDED
Binary file (4.71 kB). View file
 
models/content_encoder/__pycache__/content_encoder_replace_merge.cpython-310.pyc ADDED
Binary file (4.72 kB). View file
 
models/content_encoder/__pycache__/content_encoder_replace_new.cpython-310.pyc ADDED
Binary file (4.71 kB). View file
 
models/content_encoder/__pycache__/content_encoder_test.cpython-310.pyc ADDED
Binary file (4.58 kB). View file
 
models/content_encoder/__pycache__/content_test.cpython-310.pyc ADDED
Binary file (4.71 kB). View file
 
models/content_encoder/__pycache__/new_content_encoder.cpython-310.pyc ADDED
Binary file (4.73 kB). View file
 
models/content_encoder/__pycache__/text_encoder.cpython-310.pyc ADDED
Binary file (2.71 kB). View file
 
models/content_encoder/caption_encoder.py ADDED
@@ -0,0 +1,116 @@
1
+ from typing import Any
2
+ import torch
3
+ import torch.nn as nn
4
+ import random
5
+ from utils.audiotime_event_merge import replace_event_synonyms
6
+
7
+ def decode_data(line_onset_str, latent_length):
8
+ """
9
+ Extracts a timestamp matrix (event onset indices) from a formatted onset string.
10
+
11
+ Args:
12
+ line_onset_str (str): String containing event names and onset intervals,
13
+ formatted like "event1__start1-end1_start2-end2--event2__start1-end1".
14
+ latent_length (int): Length of the output matrix.
15
+
16
+ Returns:
17
+ line_onset_index (torch.Tensor): Matrix of shape [4, latent_length],
18
+ line_event (list): List of event names extracted from the onset string.
19
+
20
+ Notes:
21
+ - 24000 is the audio sample rate.
22
+ - 480 is the downsample ratio to align with VAE.
23
+ - Each onset interval "start-end" (in seconds) is converted to embedding indices via (time * 24000 / 480).
24
+ """
25
+ line_onset_index = torch.zeros((4, latent_length)) # max for 4 events
26
+ line_event = []
27
+ event_idx = 0
28
+ for event_onset in line_onset_str.split('--'):
29
+ #print(event_onset)
30
+ (event, instance) = event_onset.split('__')
31
+ #print(instance)
32
+ line_event.append(event)
33
+ for start_end in instance.split('_'):
34
+ (start, end) = start_end.split('-')
35
+ start, end = int(float(start)*24000/480), int(float(end)*24000/480)
36
+ if end > (latent_length - 1): break
37
+ line_onset_index[event_idx, start: end] = 1
38
+ event_idx = event_idx + 1
39
+ return line_onset_index, line_event
40
+
41
+
42
+ class ContentEncoder(nn.Module):
43
+ """
44
+ ContentEncoder encodes TCC and TDC information.
45
+ """
46
+ def __init__(
47
+ self,
48
+ text_encoder: nn.Module= None,
49
+ ):
50
+ super().__init__()
51
+ self.text_encoder = text_encoder
52
+ self.pool = nn.AdaptiveAvgPool1d(1)
53
+
54
+ def encode_content(
55
+ self, batch_content: list[Any], device: str | torch.device
56
+ ):
57
+ batch_output = []
58
+ batch_mask = []
59
+ batch_onset = []
60
+ length_list = []
61
+ print(batch_content)
62
+ for content in batch_content:
63
+
64
+ caption = content["caption"]
65
+ onset = content["onset"]
66
+ length = int(float(content["length"]) *24000/480)
67
+ # Replacement for AudioTime
68
+ print(onset)
69
+ replace_label = content.get("replace_label", "False")
70
+ if replace_label == "True":
71
+ caption, onset = replace_event_synonyms(caption, onset)
72
+
73
+ # Handle the "random" onset case for real data without timestamps
74
+ if content["onset"] == "random":
75
+ length_list.append(length)
76
+ """
77
+ fixed embedding. Actually it's a sick sentence, a error during training, kept to match the checkpoint.
78
+ You can change it to sentence that difference to captions in datasets.
79
+ The use of fixed text to obtain encoding is for numerical stability.
80
+ We attempted to use learnable unified encoding during training, but the results were not satisfactory.
81
+ """
82
+ event = "There is no event here"
83
+ event_embed = self.text_encoder([event.replace("_", " ")])["output"]
84
+ event_embed = self.pool(event_embed.permute(0, 2, 1)) # (B, 1024, 1)
85
+ event_embed = event_embed.flatten().unsqueeze(0)
86
+ new_onset = event_embed.repeat(length, 1).T
87
+ else:
88
+ onset_matrix, events = decode_data(onset, length)
89
+ length_list.append(length)
90
+ new_onset = torch.zeros((1024, length), device=device) # 1024 for T5
91
+ # TDC
92
+ for (idx, event) in enumerate(events):
93
+ with torch.no_grad():
94
+ event_embed = self.text_encoder([event.replace("_", " ")])["output"]
95
+ event_embed = self.pool(event_embed.permute(0, 2, 1)) # (B, 1024, 1)
96
+ event_embed = event_embed.flatten().unsqueeze(0)
97
+ mask = (onset_matrix[idx, :] == 0)
98
+ cols = mask.nonzero(as_tuple=True)[0]
99
+ new_onset[:, cols] += event_embed.T.float()
100
+ # TCC
101
+ output_dict = self.text_encoder([caption])
102
+ batch_output.append(output_dict["output"][0])
103
+ batch_mask.append(output_dict["mask"][0])
104
+ batch_onset.append(new_onset)
105
+
106
+ # Pad all sequences in the batch to the same length for batching
107
+ batch_output = nn.utils.rnn.pad_sequence(
108
+ batch_output, batch_first=True, padding_value=0
109
+ )
110
+ batch_mask = nn.utils.rnn.pad_sequence(
111
+ batch_mask, batch_first=True, padding_value=False
112
+ )
113
+ batch_onset = nn.utils.rnn.pad_sequence(
114
+ batch_onset, batch_first=True, padding_value=0
115
+ )
116
+ return batch_output, batch_mask, batch_onset, length_list
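A usage sketch for decode_data above, assuming this module imports cleanly from the repo: a 10 s clip at 24 kHz with the 480x downsampling ratio gives latent_length = 10 * 24000 // 480 = 500, and each onset second maps to 50 latent frames.

    from models.content_encoder.caption_encoder import decode_data

    onset_matrix, events = decode_data(
        "a_dog_barks__1.0-2.0_3.0-4.0--a_man_speaks__5.0-6.0", 500
    )
    print(onset_matrix.shape)             # torch.Size([4, 500]); rows cover up to 4 events
    print(events)                         # ['a_dog_barks', 'a_man_speaks']
    print(onset_matrix[0, 50:100].sum())  # frames for 1.0-2.0 s are set to 1 -> tensor(50.)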
models/content_encoder/text_encoder.py ADDED
@@ -0,0 +1,76 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import AutoTokenizer, AutoModel, T5Tokenizer, T5EncoderModel
4
+ from transformers.modeling_outputs import BaseModelOutput
5
+
6
+ try:
7
+ import torch_npu
8
+ from torch_npu.contrib import transfer_to_npu
9
+ DEVICE_TYPE = "npu"
10
+ except ModuleNotFoundError:
11
+ DEVICE_TYPE = "cuda"
12
+
13
+
14
+ class TransformersTextEncoderBase(nn.Module):
15
+ """
16
+ Base class for text encoding using HuggingFace Transformers models.
17
+
18
+ """
19
+ def __init__(self, model_name: str):
20
+ super().__init__()
21
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
22
+ self.model = AutoModel.from_pretrained(model_name)
23
+
24
+ def forward(
25
+ self,
26
+ text: list[str],
27
+ ):
28
+ device = self.model.device
29
+ batch = self.tokenizer(
30
+ text,
31
+ max_length=self.tokenizer.model_max_length,
32
+ padding=True,
33
+ truncation=True,
34
+ return_tensors="pt"
35
+ )
36
+ input_ids = batch.input_ids.to(device)
37
+ attention_mask = batch.attention_mask.to(device)
38
+ output: BaseModelOutput = self.model(
39
+ input_ids=input_ids, attention_mask=attention_mask
40
+ )
41
+ output = output.last_hidden_state
42
+ mask = (attention_mask == 1).to(device)
43
+
44
+ return {"output": output, "mask": mask}
45
+
46
+
47
+ class T5TextEncoder(TransformersTextEncoderBase):
48
+ """
49
+ Text encoder using T5 encoder model.
50
+ """
51
+ def __init__(self, model_name: str = "/mnt/petrelfs/zhengzihao/cache/google-flan-t5-large"):
52
+ nn.Module.__init__(self)
53
+ self.tokenizer = T5Tokenizer.from_pretrained(model_name)
54
+ self.model = T5EncoderModel.from_pretrained(model_name)
55
+ for param in self.model.parameters():
56
+ param.requires_grad = False
57
+ self.eval()
58
+
59
+ def forward(
60
+ self,
61
+ text: list[str],
62
+ ):
63
+ with torch.no_grad(), torch.amp.autocast(
64
+ device_type=DEVICE_TYPE, enabled=False
65
+ ):
66
+ return super().forward(text)
67
+
68
+
69
+ if __name__ == '__main__':
70
+ text_encoder = T5TextEncoder()
71
+ text = ["dog barking and cat moving"]
72
+ text_encoder.eval()
73
+ with torch.no_grad():
74
+ output = text_encoder(text)
75
+ print(output["output"].shape)
76
+ #print(output)
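Beyond the __main__ check above, the encoder's contract is a dict with padded hidden states and a boolean token mask. A small sketch of that contract; since the default model_name above is a local cache path, the public "google/flan-t5-large" id is substituted here.

    import torch
    from models.content_encoder.text_encoder import T5TextEncoder

    encoder = T5TextEncoder("google/flan-t5-large")
    with torch.no_grad():
        out = encoder(["a dog barks", "a man speaks then a door slams"])
    print(out["output"].shape)  # (2, padded_token_length, 1024) -- flan-t5-large hidden size
    print(out["mask"].shape)    # (2, padded_token_length), True on real (non-padding) tokens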
models/diffusion.py ADDED
@@ -0,0 +1,398 @@
1
+ from typing import Sequence
2
+ import random
3
+ from typing import Any
4
+
5
+ from tqdm import tqdm
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import diffusers.schedulers as noise_schedulers
10
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
11
+ from diffusers.utils.torch_utils import randn_tensor
12
+
13
+ import numpy as np
14
+ from models.autoencoder.autoencoder_base import AutoEncoderBase
15
+ from models.content_encoder.caption_encoder import ContentEncoder
16
+ from models.common import LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase
17
+ from utils.torch_utilities import (
18
+ create_alignment_path, create_mask_from_length, loss_with_mask,
19
+ trim_or_pad_length
20
+ )
21
+
22
+
23
+ class DiffusionMixin:
24
+ def __init__(
25
+ self,
26
+ noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
27
+ snr_gamma: float = None,
28
+ classifier_free_guidance: bool = True,
29
+ cfg_drop_ratio: float = 0.2,
30
+
31
+ ) -> None:
32
+ self.noise_scheduler_name = noise_scheduler_name
33
+ self.snr_gamma = snr_gamma
34
+ self.classifier_free_guidance = classifier_free_guidance
35
+ self.cfg_drop_ratio = cfg_drop_ratio
36
+ self.noise_scheduler = noise_schedulers.DDIMScheduler.from_pretrained(
37
+ self.noise_scheduler_name, subfolder="scheduler"
38
+ )
39
+
40
+ def compute_snr(self, timesteps) -> torch.Tensor:
41
+ """
42
+ Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
43
+ """
44
+ alphas_cumprod = self.noise_scheduler.alphas_cumprod
45
+ sqrt_alphas_cumprod = alphas_cumprod**0.5
46
+ sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod)**0.5
47
+
48
+ # Expand the tensors.
49
+ # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
50
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device
51
+ )[timesteps].float()
52
+ while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
53
+ sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
54
+ alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
55
+
56
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
57
+ device=timesteps.device
58
+ )[timesteps].float()
59
+ while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
60
+ sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[...,
61
+ None]
62
+ sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
63
+
64
+ # Compute SNR.
65
+ snr = (alpha / sigma)**2
66
+ return snr
67
+
68
+ def get_timesteps(
69
+ self,
70
+ batch_size: int,
71
+ device: torch.device,
72
+ training: bool = True
73
+ ) -> torch.Tensor:
74
+ if training:
75
+ timesteps = torch.randint(
76
+ 0,
77
+ self.noise_scheduler.config.num_train_timesteps,
78
+ (batch_size, ),
79
+ device=device
80
+ )
81
+ else:
82
+ # validation on half of the total timesteps
83
+ timesteps = (self.noise_scheduler.config.num_train_timesteps //
84
+ 2) * torch.ones((batch_size, ),
85
+ dtype=torch.int64,
86
+ device=device)
87
+
88
+ timesteps = timesteps.long()
89
+ return timesteps
90
+
91
+ def get_target(
92
+ self, latent: torch.Tensor, noise: torch.Tensor,
93
+ timesteps: torch.Tensor
94
+ ) -> torch.Tensor:
95
+ """
96
+ Get the target for loss depending on the prediction type
97
+ """
98
+ if self.noise_scheduler.config.prediction_type == "epsilon":
99
+ target = noise
100
+ elif self.noise_scheduler.config.prediction_type == "v_prediction":
101
+ target = self.noise_scheduler.get_velocity(
102
+ latent, noise, timesteps
103
+ )
104
+ else:
105
+ raise ValueError(
106
+ f"Unknown prediction type {self.noise_scheduler.config.prediction_type}"
107
+ )
108
+ return target
109
+
110
+ def loss_with_snr(
111
+ self, pred: torch.Tensor, target: torch.Tensor,
112
+ timesteps: torch.Tensor, mask: torch.Tensor
113
+ ) -> torch.Tensor:
114
+ if self.snr_gamma is None:
115
+ loss = F.mse_loss(pred.float(), target.float(), reduction="none")
116
+ loss = loss_with_mask(loss, mask)
117
+ else:
118
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
119
+ # Adapted from huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py
120
+ snr = self.compute_snr(timesteps)
121
+ mse_loss_weights = (
122
+ torch.stack([snr, self.snr_gamma * torch.ones_like(timesteps)],
123
+ dim=1).min(dim=1)[0] / snr
124
+ )
125
+ loss = F.mse_loss(pred.float(), target.float(), reduction="none")
126
+ loss = loss_with_mask(loss, mask, reduce=False) * mse_loss_weights
127
+ loss = loss.mean()
128
+ return loss
129
+
130
+
131
+ class AudioDiffusion(
132
+ LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase,
133
+ DiffusionMixin
134
+ ):
135
+ """
136
+ Args:
137
+ autoencoder (AutoEncoderBase): Pretrained autoencoder module VAE(frozen).
138
+ content_encoder (ContentEncoder): Encodes TCC and TDC information.
139
+ backbone (nn.Module): Main denoising network.
140
+ frame_resolution (float): Resolution for audio frames.
141
+ noise_scheduler_name (str): Noise scheduler identifier.
142
+ snr_gamma (float, optional): SNR gamma for noise scheduler.
143
+ classifier_free_guidance (bool): Enable classifier-free guidance.
144
+ cfg_drop_ratio (float): Ratio for randomly dropping context for classifier-free guidance.
145
+ """
146
+ def __init__(
147
+ self,
148
+ autoencoder: AutoEncoderBase,
149
+ content_encoder: ContentEncoder,
150
+ backbone: nn.Module,
151
+ frame_resolution:float,
152
+ noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
153
+ snr_gamma: float = None,
154
+ classifier_free_guidance: bool = True,
155
+ cfg_drop_ratio: float = 0.2,
156
+ ):
157
+ nn.Module.__init__(self)
158
+ DiffusionMixin.__init__(
159
+ self, noise_scheduler_name, snr_gamma, classifier_free_guidance, cfg_drop_ratio
160
+ )
161
+
162
+ self.autoencoder = autoencoder
163
+ # Freeze autoencoder parameters
164
+ for param in self.autoencoder.parameters():
165
+ param.requires_grad = False
166
+
167
+ self.content_encoder = content_encoder
168
+ self.backbone = backbone
169
+ self.frame_resolution = frame_resolution
170
+ self.dummy_param = nn.Parameter(torch.empty(0))
171
+
172
+ def forward(
173
+ self, content: list[Any], condition: list[Any], task: list[str],
174
+ waveform: torch.Tensor, waveform_lengths: torch.Tensor, **kwargs
175
+ ):
176
+ """
177
+ Training forward pass.
178
+
179
+ Args:
180
+ content (list[Any]): List of content dicts for each sample.
181
+ condition (list[Any]): Conditioning information (unused here).
182
+ task (list[str]): List of task types.
183
+ waveform (Tensor): Batch of waveform tensors.
184
+ waveform_lengths (Tensor): Lengths for each waveform sample.
185
+
186
+ Returns:
187
+ dict: Dictionary containing the diffusion loss.
188
+ """
189
+ device = self.dummy_param.device
190
+ num_train_timesteps = self.noise_scheduler.config.num_train_timesteps
191
+ self.noise_scheduler.set_timesteps(num_train_timesteps, device=device)
192
+
193
+ self.autoencoder.eval()
194
+ with torch.no_grad():
195
+ latent, latent_mask = self.autoencoder.encode(
196
+ waveform.unsqueeze(1), waveform_lengths
197
+ )
198
+ # content(non_time_aligned_content) for TCC and time_aligned_content for TDC
199
+ content, content_mask, onset, _= self.content_encoder.encode_content(
200
+ content, device=device
201
+ )
202
+
203
+ # prepare latent and diffusion-related noise
204
+ time_aligned_content = onset.permute(0,2,1)
205
+ if self.training and self.classifier_free_guidance:
206
+ mask_indices = [
207
+ k for k in range(len(waveform)) if random.random() < self.cfg_drop_ratio
208
+ ]
209
+ if len(mask_indices) > 0:
210
+ content[mask_indices] = 0
211
+ time_aligned_content[mask_indices] = 0
212
+
213
+ batch_size = latent.shape[0]
214
+ timesteps = self.get_timesteps(batch_size, device, self.training)
215
+ noise = torch.randn_like(latent)
216
+ noisy_latent = self.noise_scheduler.add_noise(latent, noise, timesteps)
217
+ target = self.get_target(latent, noise, timesteps)
218
+
219
+ # Denoising prediction
220
+ pred: torch.Tensor = self.backbone(
221
+ x=noisy_latent,
222
+ timesteps=timesteps,
223
+ time_aligned_context=time_aligned_content,
224
+ context=content,
225
+ x_mask=latent_mask,
226
+ context_mask=content_mask
227
+ )
228
+ pred = pred.transpose(1, self.autoencoder.time_dim)
229
+ target = target.transpose(1, self.autoencoder.time_dim)
230
+ diff_loss = self.loss_with_snr(pred, target, timesteps, latent_mask)
231
+ return {
232
+ "diff_loss": diff_loss,
233
+ }
234
+
235
+ @torch.no_grad()
236
+ def inference(
237
+ self,
238
+ content: list[Any],
239
+ num_steps: int = 20,
240
+ guidance_scale: float = 3.0,
241
+ guidance_rescale: float = 0.0,
242
+ disable_progress: bool = True,
243
+ num_samples_per_content: int = 1,
244
+ **kwargs
245
+ ):
246
+ """
247
+ Inference/generation method for audio diffusion.
248
+
249
+ Args:
250
+ content (list[Any]): List of content dicts.
251
+ scheduler (SchedulerMixin): Scheduler for timesteps and noise.
252
+ num_steps (int): Number of denoising steps.
253
+ guidance_scale (float): Classifier-free guidance scale.
254
+ guidance_rescale (float): Rescale factor for guidance.
255
+ disable_progress (bool): Disable progress bar.
256
+ num_samples_per_content (int): How many samples to generate per content.
257
+
258
+ Returns:
259
+ waveform (Tensor): Generated waveform.
260
+ """
261
+ device = self.dummy_param.device
262
+ classifier_free_guidance = guidance_scale > 1.0
263
+ batch_size = len(content) * num_samples_per_content
264
+ print(content)
265
+ if classifier_free_guidance:
266
+ content, content_mask, onset, length_list = self.encode_content_classifier_free(
267
+ content, num_samples_per_content
268
+ )
269
+ else:
270
+ content, content_mask, onset, length_list = self.content_encoder.encode_content(
271
+ content, device=device
272
+ )
273
+ content = content.repeat_interleave(num_samples_per_content, 0)
274
+ content_mask = content_mask.repeat_interleave(
275
+ num_samples_per_content, 0
276
+ )
277
+
278
+ self.noise_scheduler.set_timesteps(num_steps, device=device)
279
+ timesteps = self.noise_scheduler.timesteps
280
+
281
+
282
+ # prepare input latent and context for the backbone
283
+ shape = (batch_size, 128, onset.shape[2]) # 128 for StableVAE channels
284
+ time_aligned_content = onset.permute(0,2,1)
285
+ latent = randn_tensor(
286
+ shape, generator=None, device=device, dtype=content.dtype
287
+ )
288
+
289
+ # scale the initial noise by the standard deviation required by the scheduler
290
+ latent = latent * self.noise_scheduler.init_noise_sigma
291
+ latent_mask = torch.full((batch_size, onset.shape[2]), False, device=device)
292
+
293
+ for i, length in enumerate(length_list):
294
+ # Set latent mask True for valid time steps for each sample
295
+ latent_mask[i, :length] = True
296
+ num_warmup_steps = len(timesteps) - num_steps * self.noise_scheduler.order
297
+ progress_bar = tqdm(range(num_steps), disable=disable_progress)
298
+
299
+ if classifier_free_guidance:
300
+ uncond_time_aligned_content = torch.zeros_like(
301
+ time_aligned_content
302
+ )
303
+ time_aligned_content = torch.cat(
304
+ [uncond_time_aligned_content, time_aligned_content]
305
+ )
306
+ latent_mask = torch.cat(
307
+ [latent_mask, latent_mask.detach().clone()]
308
+ )
309
+
310
+ # iteratively denoising
311
+
312
+ for i, timestep in enumerate(timesteps):
313
+
314
+ latent_input = torch.cat(
315
+ [latent, latent]
316
+ ) if classifier_free_guidance else latent
317
+ latent_input = self.noise_scheduler.scale_model_input(latent_input, timestep)
318
+
319
+ noise_pred = self.backbone(
320
+ x=latent_input,
321
+ x_mask=latent_mask,
322
+ timesteps=timestep,
323
+ time_aligned_context=time_aligned_content,
324
+ context=content,
325
+ context_mask=content_mask,
326
+ )
327
+
328
+ if classifier_free_guidance:
329
+ noise_pred_uncond, noise_pred_content = noise_pred.chunk(2)
330
+ noise_pred = noise_pred_uncond + guidance_scale * (
331
+ noise_pred_content - noise_pred_uncond
332
+ )
333
+ if guidance_rescale != 0.0:
334
+ noise_pred = self.rescale_cfg(
335
+ noise_pred_content, noise_pred, guidance_rescale
336
+ )
337
+ # compute the previous noisy sample x_t -> x_t-1
338
+ latent = self.noise_scheduler.step(noise_pred, timestep, latent).prev_sample
339
+
340
+ # call the callback, if provided
341
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and
342
+ (i+1) % self.noise_scheduler.order == 0):
343
+ progress_bar.update(1)
344
+ #latent = latent.to(next(self.autoencoder.parameters()).device)
345
+ waveform = self.autoencoder.decode(latent)
346
+ return waveform
347
+
348
+ def encode_content_classifier_free(
349
+ self,
350
+ content: list[Any],
351
+ task: list[str],
352
+ num_samples_per_content: int = 1
353
+ ):
354
+ device = self.dummy_param.device
355
+
356
+ content, content_mask, onset, length_list = self.content_encoder.encode_content(
357
+ content, device=device
358
+ )
359
+ content = content.repeat_interleave(num_samples_per_content, 0)
360
+ content_mask = content_mask.repeat_interleave(
361
+ num_samples_per_content, 0
362
+ )
363
+
364
+ # get unconditional embeddings for classifier free guidance
365
+ uncond_content = torch.zeros_like(content)
366
+ uncond_content_mask = content_mask.detach().clone()
367
+
368
+ uncond_content = uncond_content.repeat_interleave(
369
+ num_samples_per_content, 0
370
+ )
371
+ uncond_content_mask = uncond_content_mask.repeat_interleave(
372
+ num_samples_per_content, 0
373
+ )
374
+
375
+ # For classifier free guidance, we need to do two forward passes.
376
+ # We concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes
377
+ content = torch.cat([uncond_content, content])
378
+ content_mask = torch.cat([uncond_content_mask, content_mask])
379
+
380
+ return content, content_mask, onset, length_list
381
+
382
+ def rescale_cfg(
383
+ self, pred_cond: torch.Tensor, pred_cfg: torch.Tensor,
384
+ guidance_rescale: float
385
+ ):
386
+ """
387
+ Rescale `pred_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
388
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
389
+ """
390
+ std_cond = pred_cond.std(
391
+ dim=list(range(1, pred_cond.ndim)), keepdim=True
392
+ )
393
+ std_cfg = pred_cfg.std(dim=list(range(1, pred_cfg.ndim)), keepdim=True)
394
+
395
+ pred_rescaled = pred_cfg * (std_cond / std_cfg)
396
+ pred_cfg = guidance_rescale * pred_rescaled + (
397
+ 1 - guidance_rescale
398
+ ) * pred_cfg
+ return pred_cfg
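For reference, a numeric sketch of the classifier-free guidance combination used in inference() above: the unconditional and conditional predictions come from one batched forward pass, are split with chunk(2), combined with guidance_scale, and optionally rescaled as in rescale_cfg. The shapes below are hypothetical.

    import torch

    guidance_scale, guidance_rescale = 3.0, 0.7
    noise_pred = torch.randn(2 * 4, 128, 500)                # (2 * batch, latent_dim, time)
    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
    guided = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

    # rescale_cfg: pull the std of the guided prediction back toward the conditional one
    std_cond = noise_pred_cond.std(dim=list(range(1, noise_pred_cond.ndim)), keepdim=True)
    std_cfg = guided.std(dim=list(range(1, guided.ndim)), keepdim=True)
    guided = guidance_rescale * (guided * (std_cond / std_cfg)) + (1 - guidance_rescale) * guided
    print(guided.shape)                                      # torch.Size([4, 128, 500])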
models/dit/__pycache__/attention.cpython-310.pyc ADDED
Binary file (7.7 kB). View file
 
models/dit/__pycache__/audio_dit.cpython-310.pyc ADDED
Binary file (8.31 kB). View file
 
models/dit/__pycache__/mask_dit.cpython-310.pyc ADDED
Binary file (14.6 kB). View file
 
models/dit/__pycache__/modules.cpython-310.pyc ADDED
Binary file (14 kB). View file
 
models/dit/__pycache__/rotary.cpython-310.pyc ADDED
Binary file (2.79 kB). View file
 
models/dit/__pycache__/span_mask.cpython-310.pyc ADDED
Binary file (4.75 kB). View file
 
models/dit/attention.py ADDED
@@ -0,0 +1,350 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.utils.checkpoint
5
+ import einops
6
+ from einops import rearrange, repeat
7
+ from inspect import isfunction
8
+ from .rotary import RotaryEmbedding
9
+ from .modules import RMSNorm
10
+
11
+ if hasattr(nn.functional, 'scaled_dot_product_attention'):
12
+ ATTENTION_MODE = 'flash'
13
+ else:
14
+ ATTENTION_MODE = 'math'
15
+ print(f'attention mode is {ATTENTION_MODE}')
16
+
17
+
18
+ def add_mask(sim, mask):
19
+ b, ndim = sim.shape[0], mask.ndim
20
+ if ndim == 3:
21
+ mask = rearrange(mask, "b n m -> b 1 n m")
22
+ if ndim == 2:
23
+ mask = repeat(mask, "n m -> b 1 n m", b=b)
24
+ max_neg_value = -torch.finfo(sim.dtype).max
25
+ sim = sim.masked_fill(~mask, max_neg_value)
26
+ return sim
27
+
28
+
29
+ def create_mask(q_shape, k_shape, device, q_mask=None, k_mask=None):
30
+ def default(val, d):
31
+ return val if val is not None else (d() if isfunction(d) else d)
32
+
33
+ b, i, j, device = q_shape[0], q_shape[-2], k_shape[-2], device
34
+ #print(q_mask)
35
+ q_mask = default(
36
+ q_mask, torch.ones((b, i), device=device, dtype=torch.bool)
37
+ )
38
+ k_mask = default(
39
+ k_mask, torch.ones((b, j), device=device, dtype=torch.bool)
40
+ )
41
+ attn_mask = rearrange(q_mask, 'b i -> b 1 i 1'
42
+ ) * rearrange(k_mask, 'b j -> b 1 1 j')
43
+ return attn_mask
44
+
45
+
46
+ class Attention(nn.Module):
47
+ def __init__(
48
+ self,
49
+ dim,
50
+ context_dim=None,
51
+ num_heads=8,
52
+ qkv_bias=False,
53
+ qk_scale=None,
54
+ qk_norm=None,
55
+ attn_drop=0.,
56
+ proj_drop=0.,
57
+ rope_mode='none'
58
+ ):
59
+ super().__init__()
60
+ self.num_heads = num_heads
61
+ head_dim = dim // num_heads
62
+ self.scale = qk_scale or head_dim**-0.5
63
+
64
+ if context_dim is None:
65
+ self.cross_attn = False
66
+ else:
67
+ self.cross_attn = True
68
+
69
+ context_dim = dim if context_dim is None else context_dim
70
+
71
+ self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
72
+ self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias)
73
+ self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias)
74
+
75
+ if qk_norm is None:
76
+ self.norm_q = nn.Identity()
77
+ self.norm_k = nn.Identity()
78
+ elif qk_norm == 'layernorm':
79
+ self.norm_q = nn.LayerNorm(head_dim)
80
+ self.norm_k = nn.LayerNorm(head_dim)
81
+ elif qk_norm == 'rmsnorm':
82
+ self.norm_q = RMSNorm(head_dim)
83
+ self.norm_k = RMSNorm(head_dim)
84
+ else:
85
+ raise NotImplementedError
86
+
87
+ self.attn_drop_p = attn_drop
88
+ self.attn_drop = nn.Dropout(attn_drop)
89
+ self.proj = nn.Linear(dim, dim)
90
+ self.proj_drop = nn.Dropout(proj_drop)
91
+
92
+ if self.cross_attn:
93
+ assert rope_mode == 'none'
94
+ self.rope_mode = rope_mode
95
+ if self.rope_mode == 'shared' or self.rope_mode == 'x_only':
96
+ self.rotary = RotaryEmbedding(dim=head_dim)
97
+ elif self.rope_mode == 'dual':
98
+ self.rotary_x = RotaryEmbedding(dim=head_dim)
99
+ self.rotary_c = RotaryEmbedding(dim=head_dim)
100
+
101
+ def _rotary(self, q, k, extras):
102
+ if self.rope_mode == 'shared':
103
+ q, k = self.rotary(q=q, k=k)
104
+ elif self.rope_mode == 'x_only':
105
+ q_x, k_x = self.rotary(
106
+ q=q[:, :, extras:, :], k=k[:, :, extras:, :]
107
+ )
108
+ q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :]
109
+ q = torch.cat((q_c, q_x), dim=2)
110
+ k = torch.cat((k_c, k_x), dim=2)
111
+ elif self.rope_mode == 'dual':
112
+ q_x, k_x = self.rotary_x(
113
+ q=q[:, :, extras:, :], k=k[:, :, extras:, :]
114
+ )
115
+ q_c, k_c = self.rotary_c(
116
+ q=q[:, :, :extras, :], k=k[:, :, :extras, :]
117
+ )
118
+ q = torch.cat((q_c, q_x), dim=2)
119
+ k = torch.cat((k_c, k_x), dim=2)
120
+ elif self.rope_mode == 'none':
121
+ pass
122
+ else:
123
+ raise NotImplementedError
124
+ return q, k
125
+
126
+ def _attn(self, q, k, v, mask_binary):
127
+ if ATTENTION_MODE == 'flash':
128
+ x = F.scaled_dot_product_attention(
129
+ q, k, v, dropout_p=self.attn_drop_p, attn_mask=mask_binary
130
+ )
131
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
132
+ elif ATTENTION_MODE == 'math':
133
+ attn = (q @ k.transpose(-2, -1)) * self.scale
134
+ attn = add_mask(
135
+ attn, mask_binary
136
+ ) if mask_binary is not None else attn
137
+ attn = attn.softmax(dim=-1)
138
+ attn = self.attn_drop(attn)
139
+ x = (attn @ v).transpose(1, 2)
140
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
141
+ else:
142
+ raise NotImplementedError
143
+ return x
144
+
145
+ def forward(self, x, context=None, context_mask=None, extras=0):
146
+ B, L, C = x.shape
147
+ if context is None:
148
+ context = x
149
+
150
+ q = self.to_q(x)
151
+ k = self.to_k(context)
152
+ v = self.to_v(context)
153
+
154
+ if context_mask is not None:
155
+ mask_binary = create_mask(
156
+ x.shape, context.shape, x.device, None, context_mask
157
+ )
158
+ else:
159
+ mask_binary = None
160
+
161
+ q = einops.rearrange(q, 'B L (H D) -> B H L D', H=self.num_heads)
162
+ k = einops.rearrange(k, 'B L (H D) -> B H L D', H=self.num_heads)
163
+ v = einops.rearrange(v, 'B L (H D) -> B H L D', H=self.num_heads)
164
+
165
+ q = self.norm_q(q)
166
+ k = self.norm_k(k)
167
+
168
+ q, k = self._rotary(q, k, extras)
169
+
170
+ x = self._attn(q, k, v, mask_binary)
171
+
172
+ x = self.proj(x)
173
+ x = self.proj_drop(x)
174
+ return x
175
+
176
+
177
+ class JointAttention(nn.Module):
178
+ def __init__(
179
+ self,
180
+ dim,
181
+ num_heads=8,
182
+ qkv_bias=False,
183
+ qk_scale=None,
184
+ qk_norm=None,
185
+ attn_drop=0.,
186
+ proj_drop=0.,
187
+ rope_mode='none'
188
+ ):
189
+ super().__init__()
190
+ self.num_heads = num_heads
191
+ head_dim = dim // num_heads
192
+ self.scale = qk_scale or head_dim**-0.5
193
+
194
+ self.to_qx, self.to_kx, self.to_vx = self._make_qkv_layers(
195
+ dim, qkv_bias
196
+ )
197
+ self.to_qc, self.to_kc, self.to_vc = self._make_qkv_layers(
198
+ dim, qkv_bias
199
+ )
200
+
201
+ self.norm_qx, self.norm_kx = self._make_norm_layers(qk_norm, head_dim)
202
+ self.norm_qc, self.norm_kc = self._make_norm_layers(qk_norm, head_dim)
203
+
204
+ self.attn_drop_p = attn_drop
205
+ self.attn_drop = nn.Dropout(attn_drop)
206
+
207
+ self.proj_x = nn.Linear(dim, dim)
208
+ self.proj_drop_x = nn.Dropout(proj_drop)
209
+
210
+ self.proj_c = nn.Linear(dim, dim)
211
+ self.proj_drop_c = nn.Dropout(proj_drop)
212
+
213
+ self.rope_mode = rope_mode
214
+ if self.rope_mode == 'shared' or self.rope_mode == 'x_only':
215
+ self.rotary = RotaryEmbedding(dim=head_dim)
216
+ elif self.rope_mode == 'dual':
217
+ self.rotary_x = RotaryEmbedding(dim=head_dim)
218
+ self.rotary_c = RotaryEmbedding(dim=head_dim)
219
+
220
+ def _make_qkv_layers(self, dim, qkv_bias):
221
+ return (
222
+ nn.Linear(dim, dim,
223
+ bias=qkv_bias), nn.Linear(dim, dim, bias=qkv_bias),
224
+ nn.Linear(dim, dim, bias=qkv_bias)
225
+ )
226
+
227
+ def _make_norm_layers(self, qk_norm, head_dim):
228
+ if qk_norm is None:
229
+ norm_q = nn.Identity()
230
+ norm_k = nn.Identity()
231
+ elif qk_norm == 'layernorm':
232
+ norm_q = nn.LayerNorm(head_dim)
233
+ norm_k = nn.LayerNorm(head_dim)
234
+ elif qk_norm == 'rmsnorm':
235
+ norm_q = RMSNorm(head_dim)
236
+ norm_k = RMSNorm(head_dim)
237
+ else:
238
+ raise NotImplementedError
239
+ return norm_q, norm_k
240
+
241
+ def _rotary(self, q, k, extras):
242
+ if self.rope_mode == 'shared':
243
+ q, k = self.rotary(q=q, k=k)
244
+ elif self.rope_mode == 'x_only':
245
+ q_x, k_x = self.rotary(
246
+ q=q[:, :, extras:, :], k=k[:, :, extras:, :]
247
+ )
248
+ q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :]
249
+ q = torch.cat((q_c, q_x), dim=2)
250
+ k = torch.cat((k_c, k_x), dim=2)
251
+ elif self.rope_mode == 'dual':
252
+ q_x, k_x = self.rotary_x(
253
+ q=q[:, :, extras:, :], k=k[:, :, extras:, :]
254
+ )
255
+ q_c, k_c = self.rotary_c(
256
+ q=q[:, :, :extras, :], k=k[:, :, :extras, :]
257
+ )
258
+ q = torch.cat((q_c, q_x), dim=2)
259
+ k = torch.cat((k_c, k_x), dim=2)
260
+ elif self.rope_mode == 'none':
261
+ pass
262
+ else:
263
+ raise NotImplementedError
264
+ return q, k
265
+
266
+ def _attn(self, q, k, v, mask_binary):
267
+ if ATTENTION_MODE == 'flash':
268
+ x = F.scaled_dot_product_attention(
269
+ q, k, v, dropout_p=self.attn_drop_p, attn_mask=mask_binary
270
+ )
271
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
272
+ elif ATTENTION_MODE == 'math':
273
+ attn = (q @ k.transpose(-2, -1)) * self.scale
274
+ attn = add_mask(
275
+ attn, mask_binary
276
+ ) if mask_binary is not None else attn
277
+ attn = attn.softmax(dim=-1)
278
+ attn = self.attn_drop(attn)
279
+ x = (attn @ v).transpose(1, 2)
280
+ x = einops.rearrange(x, 'B H L D -> B L (H D)')
281
+ else:
282
+ raise NotImplementedError
283
+ return x
284
+
285
+ def _cat_mask(self, x, context, x_mask=None, context_mask=None):
286
+ B = x.shape[0]
287
+ if x_mask is None:
288
+ x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
289
+ if context_mask is None:
290
+ context_mask = torch.ones(
291
+ B, context.shape[-2], device=context.device
292
+ ).bool()
293
+ mask = torch.cat([context_mask, x_mask], dim=1)
294
+ return mask
295
+
296
+ def forward(self, x, context, x_mask=None, context_mask=None, extras=0):
297
+ B, Lx, C = x.shape
298
+ _, Lc, _ = context.shape
299
+ if x_mask is not None or context_mask is not None:
300
+ mask = self._cat_mask(
301
+ x, context, x_mask=x_mask, context_mask=context_mask
302
+ )
303
+ shape = [B, Lx + Lc, C]
304
+ mask_binary = create_mask(
305
+ q_shape=shape,
306
+ k_shape=shape,
307
+ device=x.device,
308
+ q_mask=None,
309
+ k_mask=mask
310
+ )
311
+ else:
312
+ mask_binary = None
313
+
314
+ qx, kx, vx = self.to_qx(x), self.to_kx(x), self.to_vx(x)
315
+ qc, kc, vc = self.to_qc(context), self.to_kc(context
316
+ ), self.to_vc(context)
317
+
318
+ qx, kx, vx = map(
319
+ lambda t: einops.
320
+ rearrange(t, 'B L (H D) -> B H L D', H=self.num_heads),
321
+ [qx, kx, vx]
322
+ )
323
+ qc, kc, vc = map(
324
+ lambda t: einops.
325
+ rearrange(t, 'B L (H D) -> B H L D', H=self.num_heads),
326
+ [qc, kc, vc]
327
+ )
328
+
329
+ qx, kx = self.norm_qx(qx), self.norm_kx(kx)
330
+ qc, kc = self.norm_qc(qc), self.norm_kc(kc)
331
+
332
+ q, k, v = (
333
+ torch.cat([qc, qx],
334
+ dim=2), torch.cat([kc, kx],
335
+ dim=2), torch.cat([vc, vx], dim=2)
336
+ )
337
+
338
+ q, k = self._rotary(q, k, extras)
339
+
340
+ x = self._attn(q, k, v, mask_binary)
341
+
342
+ context, x = x[:, :Lc, :], x[:, Lc:, :]
343
+
344
+ x = self.proj_x(x)
345
+ x = self.proj_drop_x(x)
346
+
347
+ context = self.proj_c(context)
348
+ context = self.proj_drop_c(context)
349
+
350
+ return x, context
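For orientation, a minimal usage sketch of the cross-attention path defined in `attention.py` above; it assumes the file is importable as `models.dit.attention`, that `ATTENTION_MODE` is resolved at import time in that module, and uses arbitrary example sizes.

```python
import torch
from models.dit.attention import Attention  # import path assumed from this repo layout

# Cross-attention: queries come from x, keys/values from a padded context sequence.
attn = Attention(dim=256, context_dim=512, num_heads=8, qk_norm='rmsnorm')

x = torch.randn(2, 100, 256)                        # (B, Lq, dim)
context = torch.randn(2, 32, 512)                   # (B, Lk, context_dim)
context_mask = torch.ones(2, 32, dtype=torch.bool)  # True = valid position
context_mask[:, 20:] = False                        # mask out padded frames

out = attn(x, context=context, context_mask=context_mask)
print(out.shape)  # torch.Size([2, 100, 256])
```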
models/dit/audio_diffsingernet_dit.py ADDED
@@ -0,0 +1,520 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.checkpoint import checkpoint
4
+
5
+ from .mask_dit import DiTBlock, FinalBlock, UDiT
6
+ from .modules import (
7
+ film_modulate,
8
+ PatchEmbed,
9
+ PE_wrapper,
10
+ TimestepEmbedder,
11
+ RMSNorm,
12
+ )
13
+
14
+
15
+ class AudioDiTBlock(DiTBlock):
16
+ """
17
+ A modified DiT block with time-aligned context added to the latent.
18
+ """
19
+ def __init__(
20
+ self,
21
+ dim,
22
+ time_aligned_context_dim,
23
+ dilation,
24
+ context_dim=None,
25
+ num_heads=8,
26
+ mlp_ratio=4.,
27
+ qkv_bias=False,
28
+ qk_scale=None,
29
+ qk_norm=None,
30
+ act_layer='gelu',
31
+ norm_layer=nn.LayerNorm,
32
+ time_fusion='none',
33
+ ada_sola_rank=None,
34
+ ada_sola_alpha=None,
35
+ skip=False,
36
+ skip_norm=False,
37
+ rope_mode='none',
38
+ context_norm=False,
39
+ use_checkpoint=False
40
+ ):
41
+ super().__init__(
42
+ dim=dim,
43
+ context_dim=context_dim,
44
+ num_heads=num_heads,
45
+ mlp_ratio=mlp_ratio,
46
+ qkv_bias=qkv_bias,
47
+ qk_scale=qk_scale,
48
+ qk_norm=qk_norm,
49
+ act_layer=act_layer,
50
+ norm_layer=norm_layer,
51
+ time_fusion=time_fusion,
52
+ ada_sola_rank=ada_sola_rank,
53
+ ada_sola_alpha=ada_sola_alpha,
54
+ skip=skip,
55
+ skip_norm=skip_norm,
56
+ rope_mode=rope_mode,
57
+ context_norm=context_norm,
58
+ use_checkpoint=use_checkpoint
59
+ )
60
+ # time-aligned context projection
61
+ self.ta_context_projection = nn.Linear(
62
+ time_aligned_context_dim, 2 * dim
63
+ )
64
+ self.dilated_conv = nn.Conv1d(
65
+ dim, 2 * dim, kernel_size=3, padding=dilation, dilation=dilation
66
+ )
67
+
68
+ def forward(
69
+ self,
70
+ x,
71
+ time_aligned_context,
72
+ time_token=None,
73
+ time_ada=None,
74
+ skip=None,
75
+ context=None,
76
+ x_mask=None,
77
+ context_mask=None,
78
+ extras=None
79
+ ):
80
+ if self.use_checkpoint:
81
+ return checkpoint(
82
+ self._forward,
83
+ x,
84
+ time_aligned_context,
85
+ time_token,
86
+ time_ada,
87
+ skip,
88
+ context,
89
+ x_mask,
90
+ context_mask,
91
+ extras,
92
+ use_reentrant=False
93
+ )
94
+ else:
95
+ return self._forward(
96
+ x,
97
+ time_aligned_context,
98
+ time_token,
99
+ time_ada,
100
+ skip,
101
+ context,
102
+ x_mask,
103
+ context_mask,
104
+ extras,
105
+ )
106
+
107
+ def _forward(
108
+ self,
109
+ x,
110
+ time_aligned_context,
111
+ time_token=None,
112
+ time_ada=None,
113
+ skip=None,
114
+ context=None,
115
+ x_mask=None,
116
+ context_mask=None,
117
+ extras=None
118
+ ):
119
+ B, T, C = x.shape
120
+ if self.skip_linear is not None:
121
+ assert skip is not None
122
+ cat = torch.cat([x, skip], dim=-1)
123
+ cat = self.skip_norm(cat)
124
+ x = self.skip_linear(cat)
125
+
126
+ if self.use_adanorm:
127
+ time_ada = self.adaln(time_token, time_ada)
128
+ (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
129
+ gate_mlp) = time_ada.chunk(6, dim=1)
130
+
131
+ # self attention
132
+ if self.use_adanorm:
133
+ x_norm = film_modulate(
134
+ self.norm1(x), shift=shift_msa, scale=scale_msa
135
+ )
136
+ x = x + (1-gate_msa) * self.attn(
137
+ x_norm, context=None, context_mask=x_mask, extras=extras
138
+ )
139
+ else:
140
+ # TODO diffusion timestep input is not fused here
141
+ x = x + self.attn(
142
+ self.norm1(x),
143
+ context=None,
144
+ context_mask=x_mask,
145
+ extras=extras
146
+ )
147
+
148
+ # time-aligned context
149
+ time_aligned_context = self.ta_context_projection(time_aligned_context)
150
+ x = self.dilated_conv(x.transpose(1, 2)
151
+ ).transpose(1, 2) + time_aligned_context
152
+
153
+ gate, filter = torch.chunk(x, 2, dim=-1)
154
+ x = torch.sigmoid(gate) * torch.tanh(filter)
155
+
156
+ # cross attention
157
+ if self.use_context:
158
+ assert context is not None
159
+ x = x + self.cross_attn(
160
+ x=self.norm2(x),
161
+ context=self.norm_context(context),
162
+ context_mask=context_mask,
163
+ extras=extras
164
+ )
165
+
166
+ # mlp
167
+ if self.use_adanorm:
168
+ x_norm = film_modulate(
169
+ self.norm3(x), shift=shift_mlp, scale=scale_mlp
170
+ )
171
+ x = x + (1-gate_mlp) * self.mlp(x_norm)
172
+ else:
173
+ x = x + self.mlp(self.norm3(x))
174
+
175
+ return x
176
+
177
+
178
+ class AudioUDiT(UDiT):
179
+ def __init__(
180
+ self,
181
+ img_size=224,
182
+ patch_size=16,
183
+ in_chans=3,
184
+ input_type='2d',
185
+ out_chans=None,
186
+ embed_dim=768,
187
+ depth=12,
188
+ dilation_cycle_length=4,
189
+ num_heads=12,
190
+ mlp_ratio=4,
191
+ qkv_bias=False,
192
+ qk_scale=None,
193
+ qk_norm=None,
194
+ act_layer='gelu',
195
+ norm_layer='layernorm',
196
+ context_norm=False,
197
+ use_checkpoint=False,
198
+ time_fusion='token',
199
+ ada_sola_rank=None,
200
+ ada_sola_alpha=None,
201
+ cls_dim=None,
202
+ time_aligned_context_dim=768,
203
+ context_dim=768,
204
+ context_fusion='concat',
205
+ context_max_length=128,
206
+ context_pe_method='sinu',
207
+ pe_method='abs',
208
+ rope_mode='none',
209
+ use_conv=True,
210
+ skip=True,
211
+ skip_norm=True
212
+ ):
213
+ nn.Module.__init__(self)
214
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
215
+
216
+ # input
217
+ self.in_chans = in_chans
218
+ self.input_type = input_type
219
+ if self.input_type == '2d':
220
+ num_patches = (img_size[0] //
221
+ patch_size) * (img_size[1] // patch_size)
222
+ elif self.input_type == '1d':
223
+ num_patches = img_size // patch_size
224
+ self.patch_embed = PatchEmbed(
225
+ patch_size=patch_size,
226
+ in_chans=in_chans,
227
+ embed_dim=embed_dim,
228
+ input_type=input_type
229
+ )
230
+ out_chans = in_chans if out_chans is None else out_chans
231
+ self.out_chans = out_chans
232
+
233
+ # position embedding
234
+ self.rope = rope_mode
235
+ self.x_pe = PE_wrapper(
236
+ dim=embed_dim, method=pe_method, length=num_patches
237
+ )
238
+
239
+ # time embed
240
+ self.time_embed = TimestepEmbedder(embed_dim)
241
+ self.time_fusion = time_fusion
242
+ self.use_adanorm = False
243
+
244
+ # cls embed
245
+ if cls_dim is not None:
246
+ self.cls_embed = nn.Sequential(
247
+ nn.Linear(cls_dim, embed_dim, bias=True),
248
+ nn.SiLU(),
249
+ nn.Linear(embed_dim, embed_dim, bias=True),
250
+ )
251
+ else:
252
+ self.cls_embed = None
253
+
254
+ # time fusion
255
+ if time_fusion == 'token':
256
+ # put token at the beginning of sequence
257
+ self.extras = 2 if self.cls_embed else 1
258
+ self.time_pe = PE_wrapper(
259
+ dim=embed_dim, method='abs', length=self.extras
260
+ )
261
+ elif time_fusion in ['ada', 'ada_single', 'ada_sola', 'ada_sola_bias']:
262
+ self.use_adanorm = True
263
+ # aviod repetitive silu for each adaln block
264
+ self.time_act = nn.SiLU()
265
+ self.extras = 0
266
+ self.time_ada_final = nn.Linear(
267
+ embed_dim, 2 * embed_dim, bias=True
268
+ )
269
+ if time_fusion in ['ada_single', 'ada_sola', 'ada_sola_bias']:
270
+ # shared adaln
271
+ self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
272
+ else:
273
+ self.time_ada = None
274
+ else:
275
+ raise NotImplementedError
276
+
277
+ # context
278
+ # use a simple projection
279
+ self.use_context = False
280
+ self.context_cross = False
281
+ self.context_max_length = context_max_length
282
+ self.context_fusion = 'none'
283
+ if context_dim is not None:
284
+ self.use_context = True
285
+ self.context_embed = nn.Sequential(
286
+ nn.Linear(context_dim, embed_dim, bias=True),
287
+ nn.SiLU(),
288
+ nn.Linear(embed_dim, embed_dim, bias=True),
289
+ )
290
+ self.context_fusion = context_fusion
291
+ if context_fusion == 'concat' or context_fusion == 'joint':
292
+ self.extras += context_max_length
293
+ self.context_pe = PE_wrapper(
294
+ dim=embed_dim,
295
+ method=context_pe_method,
296
+ length=context_max_length
297
+ )
298
+ # no cross attention layers
299
+ context_dim = None
300
+ elif context_fusion == 'cross':
301
+ self.context_pe = PE_wrapper(
302
+ dim=embed_dim,
303
+ method=context_pe_method,
304
+ length=context_max_length
305
+ )
306
+ self.context_cross = True
307
+ context_dim = embed_dim
308
+ else:
309
+ raise NotImplementedError
310
+
311
+ self.use_skip = skip
312
+
313
+ # norm layers
314
+ if norm_layer == 'layernorm':
315
+ norm_layer = nn.LayerNorm
316
+ elif norm_layer == 'rmsnorm':
317
+ norm_layer = RMSNorm
318
+ else:
319
+ raise NotImplementedError
320
+
321
+ self.in_blocks = nn.ModuleList([
322
+ AudioDiTBlock(
323
+ dim=embed_dim,
324
+ time_aligned_context_dim=time_aligned_context_dim,
325
+ dilation=2**(i % dilation_cycle_length),
326
+ context_dim=context_dim,
327
+ num_heads=num_heads,
328
+ mlp_ratio=mlp_ratio,
329
+ qkv_bias=qkv_bias,
330
+ qk_scale=qk_scale,
331
+ qk_norm=qk_norm,
332
+ act_layer=act_layer,
333
+ norm_layer=norm_layer,
334
+ time_fusion=time_fusion,
335
+ ada_sola_rank=ada_sola_rank,
336
+ ada_sola_alpha=ada_sola_alpha,
337
+ skip=False,
338
+ skip_norm=False,
339
+ rope_mode=self.rope,
340
+ context_norm=context_norm,
341
+ use_checkpoint=use_checkpoint
342
+ ) for i in range(depth // 2)
343
+ ])
344
+
345
+ self.mid_block = AudioDiTBlock(
346
+ dim=embed_dim,
347
+ time_aligned_context_dim=time_aligned_context_dim,
348
+ dilation=1,
349
+ context_dim=context_dim,
350
+ num_heads=num_heads,
351
+ mlp_ratio=mlp_ratio,
352
+ qkv_bias=qkv_bias,
353
+ qk_scale=qk_scale,
354
+ qk_norm=qk_norm,
355
+ act_layer=act_layer,
356
+ norm_layer=norm_layer,
357
+ time_fusion=time_fusion,
358
+ ada_sola_rank=ada_sola_rank,
359
+ ada_sola_alpha=ada_sola_alpha,
360
+ skip=False,
361
+ skip_norm=False,
362
+ rope_mode=self.rope,
363
+ context_norm=context_norm,
364
+ use_checkpoint=use_checkpoint
365
+ )
366
+
367
+ self.out_blocks = nn.ModuleList([
368
+ AudioDiTBlock(
369
+ dim=embed_dim,
370
+ time_aligned_context_dim=time_aligned_context_dim,
371
+ dilation=2**(i % dilation_cycle_length),
372
+ context_dim=context_dim,
373
+ num_heads=num_heads,
374
+ mlp_ratio=mlp_ratio,
375
+ qkv_bias=qkv_bias,
376
+ qk_scale=qk_scale,
377
+ qk_norm=qk_norm,
378
+ act_layer=act_layer,
379
+ norm_layer=norm_layer,
380
+ time_fusion=time_fusion,
381
+ ada_sola_rank=ada_sola_rank,
382
+ ada_sola_alpha=ada_sola_alpha,
383
+ skip=skip,
384
+ skip_norm=skip_norm,
385
+ rope_mode=self.rope,
386
+ context_norm=context_norm,
387
+ use_checkpoint=use_checkpoint
388
+ ) for i in range(depth // 2)
389
+ ])
390
+
391
+ # FinalLayer block
392
+ self.use_conv = use_conv
393
+ self.final_block = FinalBlock(
394
+ embed_dim=embed_dim,
395
+ patch_size=patch_size,
396
+ img_size=img_size,
397
+ in_chans=out_chans,
398
+ input_type=input_type,
399
+ norm_layer=norm_layer,
400
+ use_conv=use_conv,
401
+ use_adanorm=self.use_adanorm
402
+ )
403
+ self.initialize_weights()
404
+
405
+ def forward(
406
+ self,
407
+ x,
408
+ timesteps,
409
+ time_aligned_context,
410
+ context,
411
+ x_mask=None,
412
+ context_mask=None,
413
+ cls_token=None,
414
+ controlnet_skips=None,
415
+ ):
416
+ # make it compatible with int time step during inference
417
+ if timesteps.dim() == 0:
418
+ timesteps = timesteps.expand(x.shape[0]
419
+ ).to(x.device, dtype=torch.long)
420
+
421
+ x = self.patch_embed(x)
422
+ x = self.x_pe(x)
423
+
424
+ B, L, D = x.shape
425
+
426
+ if self.use_context:
427
+ context_token = self.context_embed(context)
428
+ context_token = self.context_pe(context_token)
429
+ if self.context_fusion == 'concat' or self.context_fusion == 'joint':
430
+ x, x_mask = self._concat_x_context(
431
+ x=x,
432
+ context=context_token,
433
+ x_mask=x_mask,
434
+ context_mask=context_mask
435
+ )
436
+ context_token, context_mask = None, None
437
+ else:
438
+ context_token, context_mask = None, None
439
+
440
+ time_token = self.time_embed(timesteps)
441
+ if self.cls_embed:
442
+ cls_token = self.cls_embed(cls_token)
443
+ time_ada = None
444
+ time_ada_final = None
445
+ if self.use_adanorm:
446
+ if self.cls_embed:
447
+ time_token = time_token + cls_token
448
+ time_token = self.time_act(time_token)
449
+ time_ada_final = self.time_ada_final(time_token)
450
+ if self.time_ada is not None:
451
+ time_ada = self.time_ada(time_token)
452
+ else:
453
+ time_token = time_token.unsqueeze(dim=1)
454
+ if self.cls_embed:
455
+ cls_token = cls_token.unsqueeze(dim=1)
456
+ time_token = torch.cat([time_token, cls_token], dim=1)
457
+ time_token = self.time_pe(time_token)
458
+ x = torch.cat((time_token, x), dim=1)
459
+ if x_mask is not None:
460
+ x_mask = torch.cat([
461
+ torch.ones(B, time_token.shape[1],
462
+ device=x_mask.device).bool(), x_mask
463
+ ],
464
+ dim=1)
465
+ time_token = None
466
+
467
+ skips = []
468
+ for blk in self.in_blocks:
469
+ x = blk(
470
+ x=x,
471
+ time_aligned_context=time_aligned_context,
472
+ time_token=time_token,
473
+ time_ada=time_ada,
474
+ skip=None,
475
+ context=context_token,
476
+ x_mask=x_mask,
477
+ context_mask=context_mask,
478
+ extras=self.extras
479
+ )
480
+ if self.use_skip:
481
+ skips.append(x)
482
+
483
+ x = self.mid_block(
484
+ x=x,
485
+ time_aligned_context=time_aligned_context,
486
+ time_token=time_token,
487
+ time_ada=time_ada,
488
+ skip=None,
489
+ context=context_token,
490
+ x_mask=x_mask,
491
+ context_mask=context_mask,
492
+ extras=self.extras
493
+ )
494
+ for blk in self.out_blocks:
495
+ if self.use_skip:
496
+ skip = skips.pop()
497
+ if controlnet_skips:
498
+ # add to skip like u-net controlnet
499
+ skip = skip + controlnet_skips.pop()
500
+ else:
501
+ skip = None
502
+ if controlnet_skips:
503
+ # directly add to x
504
+ x = x + controlnet_skips.pop()
505
+
506
+ x = blk(
507
+ x=x,
508
+ time_aligned_context=time_aligned_context,
509
+ time_token=time_token,
510
+ time_ada=time_ada,
511
+ skip=skip,
512
+ context=context_token,
513
+ x_mask=x_mask,
514
+ context_mask=context_mask,
515
+ extras=self.extras
516
+ )
517
+
518
+ x = self.final_block(x, time_ada=time_ada_final, extras=self.extras)
519
+
520
+ return x
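A shape-level sketch of how the `AudioUDiT` above might be driven. All sizes and hyper-parameters here are illustrative assumptions (the real configurations live in the repo's config files), and it assumes `PatchEmbed` accepts `patch_size=1` for 1d inputs and that the module is importable as `models.dit.audio_diffsingernet_dit`.

```python
import torch
from models.dit.audio_diffsingernet_dit import AudioUDiT  # import path assumed

model = AudioUDiT(
    img_size=250,                  # number of latent frames (1d input)
    patch_size=1,
    in_chans=64,                   # latent channels from the VAE
    input_type='1d',
    embed_dim=256,
    depth=4,
    num_heads=4,
    time_fusion='ada',             # adaLN conditioning, so extras == 0
    context_fusion='cross',        # caption tokens enter via cross-attention
    context_dim=768,
    context_max_length=128,
    time_aligned_context_dim=512,
)

B = 2
x = torch.randn(B, 64, 250)         # noisy latent
t = torch.randint(0, 1000, (B,))    # diffusion timesteps
ta_ctx = torch.randn(B, 250, 512)   # frame-aligned content features (one per latent frame)
caption = torch.randn(B, 128, 768)  # text-encoder output, padded to context_max_length

out = model(x, t, time_aligned_context=ta_ctx, context=caption)
print(out.shape)  # torch.Size([2, 64, 250])
```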
models/dit/audio_dit.py ADDED
@@ -0,0 +1,549 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.checkpoint import checkpoint
4
+
5
+ from .mask_dit import DiTBlock, FinalBlock, UDiT
6
+ from .modules import (
7
+ film_modulate,
8
+ PatchEmbed,
9
+ PE_wrapper,
10
+ TimestepEmbedder,
11
+ RMSNorm,
12
+ )
13
+
14
+
15
+ class AudioDiTBlock(DiTBlock):
16
+ """
17
+ A modified DiT block with time-aligned context added to the latent.
18
+ """
19
+ def __init__(
20
+ self,
21
+ dim,
22
+ ta_context_dim,
23
+ ta_context_norm=False,
24
+ context_dim=None,
25
+ num_heads=8,
26
+ mlp_ratio=4.,
27
+ qkv_bias=False,
28
+ qk_scale=None,
29
+ qk_norm=None,
30
+ act_layer='gelu',
31
+ norm_layer=nn.LayerNorm,
32
+ ta_context_fusion='add',
33
+ time_fusion='none',
34
+ ada_sola_rank=None,
35
+ ada_sola_alpha=None,
36
+ skip=False,
37
+ skip_norm=False,
38
+ rope_mode='none',
39
+ context_norm=False,
40
+ use_checkpoint=False
41
+ ):
42
+ super().__init__(
43
+ dim=dim,
44
+ context_dim=context_dim,
45
+ num_heads=num_heads,
46
+ mlp_ratio=mlp_ratio,
47
+ qkv_bias=qkv_bias,
48
+ qk_scale=qk_scale,
49
+ qk_norm=qk_norm,
50
+ act_layer=act_layer,
51
+ norm_layer=norm_layer,
52
+ time_fusion=time_fusion,
53
+ ada_sola_rank=ada_sola_rank,
54
+ ada_sola_alpha=ada_sola_alpha,
55
+ skip=skip,
56
+ skip_norm=skip_norm,
57
+ rope_mode=rope_mode,
58
+ context_norm=context_norm,
59
+ use_checkpoint=use_checkpoint
60
+ )
61
+ self.ta_context_fusion = ta_context_fusion
62
+ self.ta_context_norm = ta_context_norm
63
+ if self.ta_context_fusion == "add":
64
+ self.ta_context_projection = nn.Linear(ta_context_dim, dim)
65
+ self.ta_context_norm = norm_layer(
66
+ ta_context_dim
67
+ ) if self.ta_context_norm else nn.Identity()
68
+ elif self.ta_context_fusion == "concat":
69
+ self.ta_context_projection = nn.Linear(ta_context_dim + dim, dim)
70
+ self.ta_context_norm = norm_layer(
71
+ ta_context_dim + dim
72
+ ) if self.ta_context_norm else nn.Identity()
73
+
74
+ def forward(
75
+ self,
76
+ x,
77
+ time_aligned_context,
78
+ time_token=None,
79
+ time_ada=None,
80
+ skip=None,
81
+ context=None,
82
+ x_mask=None,
83
+ context_mask=None,
84
+ extras=None
85
+ ):
86
+ if self.use_checkpoint:
87
+ return checkpoint(
88
+ self._forward,
89
+ x,
90
+ time_aligned_context,
91
+ time_token,
92
+ time_ada,
93
+ skip,
94
+ context,
95
+ x_mask,
96
+ context_mask,
97
+ extras,
98
+ use_reentrant=False
99
+ )
100
+ else:
101
+ return self._forward(
102
+ x,
103
+ time_aligned_context,
104
+ time_token,
105
+ time_ada,
106
+ skip,
107
+ context,
108
+ x_mask,
109
+ context_mask,
110
+ extras,
111
+ )
112
+
113
+ def _forward(
114
+ self,
115
+ x,
116
+ time_aligned_context,
117
+ time_token=None,
118
+ time_ada=None,
119
+ skip=None,
120
+ context=None,
121
+ x_mask=None,
122
+ context_mask=None,
123
+ extras=None
124
+ ):
125
+ B, T, C = x.shape
126
+
127
+ # # time aligned context
128
+ # if self.ta_context_fusion == "add":
129
+ # time_aligned_context = self.ta_context_projection(
130
+ # self.ta_context_norm(time_aligned_context)
131
+ # )
132
+ # x = x + time_aligned_context
133
+ # elif self.ta_context_fusion == "concat":
134
+ # cat = torch.cat([x, time_aligned_context], dim=-1)
135
+ # cat = self.ta_context_norm(cat)
136
+ # x = self.ta_context_projection(cat)
137
+
138
+ # skip connection
139
+ if self.skip_linear is not None:
140
+ assert skip is not None
141
+ cat = torch.cat([x, skip], dim=-1)
142
+ cat = self.skip_norm(cat)
143
+ x = self.skip_linear(cat)
144
+ #print('skip')
145
+ #print(x)
146
+ if self.use_adanorm:
147
+ time_ada = self.adaln(time_token, time_ada)
148
+ (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
149
+ gate_mlp) = time_ada.chunk(6, dim=1)
150
+
151
+ # self attention
152
+ if self.use_adanorm:
153
+ x_norm = film_modulate(
154
+ self.norm1(x), shift=shift_msa, scale=scale_msa
155
+ )
156
+ x = x + (1-gate_msa) * self.attn(
157
+ x_norm, context=None, context_mask=x_mask, extras=extras
158
+ )
159
+ else:
160
+ # TODO diffusion timestep input is not fused here
161
+ x = x + self.attn(
162
+ self.norm1(x),
163
+ context=None,
164
+ context_mask=x_mask,
165
+ extras=extras
166
+ )
167
+
168
+ # time aligned context fusion
169
+ if self.ta_context_fusion == "add":
170
+ time_aligned_context = self.ta_context_projection(
171
+ self.ta_context_norm(time_aligned_context)
172
+ )
173
+ x = x + time_aligned_context
174
+ elif self.ta_context_fusion == "concat":
175
+ cat = torch.cat([x, time_aligned_context], dim=-1)
176
+ cat = self.ta_context_norm(cat)
177
+ x = self.ta_context_projection(cat)
178
+
179
+ # cross attention
180
+ if self.use_context:
181
+ assert context is not None
182
+ x = x + self.cross_attn(
183
+ x=self.norm2(x),
184
+ context=self.norm_context(context),
185
+ context_mask=context_mask,
186
+ extras=extras
187
+ )
188
+
189
+ # mlp
190
+ if self.use_adanorm:
191
+ x_norm = film_modulate(
192
+ self.norm3(x), shift=shift_mlp, scale=scale_mlp
193
+ )
194
+ x = x + (1-gate_mlp) * self.mlp(x_norm)
195
+ else:
196
+ x = x + self.mlp(self.norm3(x))
197
+
198
+ return x
199
+
200
+
201
+ class AudioUDiT(UDiT):
202
+ def __init__(
203
+ self,
204
+ img_size=224,
205
+ patch_size=16,
206
+ in_chans=3,
207
+ input_type='2d',
208
+ out_chans=None,
209
+ embed_dim=768,
210
+ depth=12,
211
+ num_heads=12,
212
+ mlp_ratio=4,
213
+ qkv_bias=False,
214
+ qk_scale=None,
215
+ qk_norm=None,
216
+ act_layer='gelu',
217
+ norm_layer='layernorm',
218
+ context_norm=False,
219
+ use_checkpoint=False,
220
+ time_fusion='token',
221
+ ada_sola_rank=None,
222
+ ada_sola_alpha=None,
223
+ cls_dim=None,
224
+ ta_context_dim=768,
225
+ ta_context_fusion='concat',
226
+ ta_context_norm=True,
227
+ context_dim=768,
228
+ context_fusion='concat',
229
+ context_max_length=128,
230
+ context_pe_method='sinu',
231
+ pe_method='abs',
232
+ rope_mode='none',
233
+ use_conv=True,
234
+ skip=True,
235
+ skip_norm=True
236
+ ):
237
+ nn.Module.__init__(self)
238
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
239
+
240
+ # input
241
+ self.in_chans = in_chans
242
+ self.input_type = input_type
243
+ if self.input_type == '2d':
244
+ num_patches = (img_size[0] //
245
+ patch_size) * (img_size[1] // patch_size)
246
+ elif self.input_type == '1d':
247
+ num_patches = img_size // patch_size
248
+ self.patch_embed = PatchEmbed(
249
+ patch_size=patch_size,
250
+ in_chans=in_chans,
251
+ embed_dim=embed_dim,
252
+ input_type=input_type
253
+ )
254
+ out_chans = in_chans if out_chans is None else out_chans
255
+ self.out_chans = out_chans
256
+
257
+ # position embedding
258
+ self.rope = rope_mode
259
+ self.x_pe = PE_wrapper(
260
+ dim=embed_dim, method=pe_method, length=num_patches
261
+ )
262
+
263
+ # time embed
264
+ self.time_embed = TimestepEmbedder(embed_dim)
265
+ self.time_fusion = time_fusion
266
+ self.use_adanorm = False
267
+
268
+ # cls embed
269
+ if cls_dim is not None:
270
+ self.cls_embed = nn.Sequential(
271
+ nn.Linear(cls_dim, embed_dim, bias=True),
272
+ nn.SiLU(),
273
+ nn.Linear(embed_dim, embed_dim, bias=True),
274
+ )
275
+ else:
276
+ self.cls_embed = None
277
+
278
+ # time fusion
279
+ if time_fusion == 'token':
280
+ # put token at the beginning of sequence
281
+ self.extras = 2 if self.cls_embed else 1
282
+ self.time_pe = PE_wrapper(
283
+ dim=embed_dim, method='abs', length=self.extras
284
+ )
285
+ elif time_fusion in ['ada', 'ada_single', 'ada_sola', 'ada_sola_bias']:
286
+ self.use_adanorm = True
287
+ # aviod repetitive silu for each adaln block
288
+ self.time_act = nn.SiLU()
289
+ self.extras = 0
290
+ self.time_ada_final = nn.Linear(
291
+ embed_dim, 2 * embed_dim, bias=True
292
+ )
293
+ if time_fusion in ['ada_single', 'ada_sola', 'ada_sola_bias']:
294
+ # shared adaln
295
+ self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
296
+ else:
297
+ self.time_ada = None
298
+ else:
299
+ raise NotImplementedError
300
+
301
+ # context
302
+ # use a simple projection
303
+ self.use_context = False
304
+ self.context_cross = False
305
+ self.context_max_length = context_max_length
306
+ self.context_fusion = 'none'
307
+ if context_dim is not None:
308
+ self.use_context = True
309
+ self.context_embed = nn.Sequential(
310
+ nn.Linear(context_dim, embed_dim, bias=True),
311
+ nn.SiLU(),
312
+ nn.Linear(embed_dim, embed_dim, bias=True),
313
+ )
314
+ self.context_fusion = context_fusion
315
+ if context_fusion == 'concat' or context_fusion == 'joint':
316
+ self.extras += context_max_length
317
+ self.context_pe = PE_wrapper(
318
+ dim=embed_dim,
319
+ method=context_pe_method,
320
+ length=context_max_length
321
+ )
322
+ # no cross attention layers
323
+ context_dim = None
324
+ elif context_fusion == 'cross':
325
+ self.context_pe = PE_wrapper(
326
+ dim=embed_dim,
327
+ method=context_pe_method,
328
+ length=context_max_length
329
+ )
330
+ self.context_cross = True
331
+ context_dim = embed_dim
332
+ else:
333
+ raise NotImplementedError
334
+
335
+ self.use_skip = skip
336
+
337
+ # norm layers
338
+ if norm_layer == 'layernorm':
339
+ norm_layer = nn.LayerNorm
340
+ elif norm_layer == 'rmsnorm':
341
+ norm_layer = RMSNorm
342
+ else:
343
+ raise NotImplementedError
344
+
345
+ self.in_blocks = nn.ModuleList([
346
+ AudioDiTBlock(
347
+ dim=embed_dim,
348
+ ta_context_dim=ta_context_dim,
349
+ ta_context_fusion=ta_context_fusion,
350
+ ta_context_norm=ta_context_norm,
351
+ context_dim=context_dim,
352
+ num_heads=num_heads,
353
+ mlp_ratio=mlp_ratio,
354
+ qkv_bias=qkv_bias,
355
+ qk_scale=qk_scale,
356
+ qk_norm=qk_norm,
357
+ act_layer=act_layer,
358
+ norm_layer=norm_layer,
359
+ time_fusion=time_fusion,
360
+ ada_sola_rank=ada_sola_rank,
361
+ ada_sola_alpha=ada_sola_alpha,
362
+ skip=False,
363
+ skip_norm=False,
364
+ rope_mode=self.rope,
365
+ context_norm=context_norm,
366
+ use_checkpoint=use_checkpoint
367
+ ) for i in range(depth // 2)
368
+ ])
369
+
370
+ self.mid_block = AudioDiTBlock(
371
+ dim=embed_dim,
372
+ ta_context_dim=ta_context_dim,
373
+ context_dim=context_dim,
374
+ num_heads=num_heads,
375
+ mlp_ratio=mlp_ratio,
376
+ qkv_bias=qkv_bias,
377
+ qk_scale=qk_scale,
378
+ qk_norm=qk_norm,
379
+ act_layer=act_layer,
380
+ norm_layer=norm_layer,
381
+ time_fusion=time_fusion,
382
+ ada_sola_rank=ada_sola_rank,
383
+ ada_sola_alpha=ada_sola_alpha,
384
+ ta_context_fusion=ta_context_fusion,
385
+ ta_context_norm=ta_context_norm,
386
+ skip=False,
387
+ skip_norm=False,
388
+ rope_mode=self.rope,
389
+ context_norm=context_norm,
390
+ use_checkpoint=use_checkpoint
391
+ )
392
+
393
+ self.out_blocks = nn.ModuleList([
394
+ AudioDiTBlock(
395
+ dim=embed_dim,
396
+ ta_context_dim=ta_context_dim,
397
+ context_dim=context_dim,
398
+ num_heads=num_heads,
399
+ mlp_ratio=mlp_ratio,
400
+ qkv_bias=qkv_bias,
401
+ qk_scale=qk_scale,
402
+ qk_norm=qk_norm,
403
+ act_layer=act_layer,
404
+ norm_layer=norm_layer,
405
+ time_fusion=time_fusion,
406
+ ada_sola_rank=ada_sola_rank,
407
+ ada_sola_alpha=ada_sola_alpha,
408
+ ta_context_fusion=ta_context_fusion,
409
+ ta_context_norm=ta_context_norm,
410
+ skip=skip,
411
+ skip_norm=skip_norm,
412
+ rope_mode=self.rope,
413
+ context_norm=context_norm,
414
+ use_checkpoint=use_checkpoint
415
+ ) for i in range(depth // 2)
416
+ ])
417
+
418
+ # FinalLayer block
419
+ self.use_conv = use_conv
420
+ self.final_block = FinalBlock(
421
+ embed_dim=embed_dim,
422
+ patch_size=patch_size,
423
+ img_size=img_size,
424
+ in_chans=out_chans,
425
+ input_type=input_type,
426
+ norm_layer=norm_layer,
427
+ use_conv=use_conv,
428
+ use_adanorm=self.use_adanorm
429
+ )
430
+ self.initialize_weights()
431
+
432
+ def forward(
433
+ self,
434
+ x,
435
+ timesteps,
436
+ time_aligned_context,
437
+ context,
438
+ x_mask=None,
439
+ context_mask=None,
440
+ cls_token=None,
441
+ controlnet_skips=None,
442
+ ):
443
+ # make it compatible with int time step during inference
444
+ if timesteps.dim() == 0:
445
+ timesteps = timesteps.expand(x.shape[0]
446
+ ).to(x.device, dtype=torch.long)
447
+
448
+ x = self.patch_embed(x)
449
+ x = self.x_pe(x)
450
+
451
+ B, L, D = x.shape
452
+
453
+ if self.use_context:
454
+ context_token = self.context_embed(context)
455
+ context_token = self.context_pe(context_token)
456
+ if self.context_fusion == 'concat' or self.context_fusion == 'joint':
457
+ x, x_mask = self._concat_x_context(
458
+ x=x,
459
+ context=context_token,
460
+ x_mask=x_mask,
461
+ context_mask=context_mask
462
+ )
463
+ context_token, context_mask = None, None
464
+ else:
465
+ context_token, context_mask = None, None
466
+
467
+ time_token = self.time_embed(timesteps)
468
+ if self.cls_embed:
469
+ cls_token = self.cls_embed(cls_token)
470
+ time_ada = None
471
+ time_ada_final = None
472
+ if self.use_adanorm:
473
+ if self.cls_embed:
474
+ time_token = time_token + cls_token
475
+ time_token = self.time_act(time_token)
476
+ time_ada_final = self.time_ada_final(time_token)
477
+ if self.time_ada is not None:
478
+ time_ada = self.time_ada(time_token)
479
+ else:
480
+ time_token = time_token.unsqueeze(dim=1)
481
+ if self.cls_embed:
482
+ cls_token = cls_token.unsqueeze(dim=1)
483
+ time_token = torch.cat([time_token, cls_token], dim=1)
484
+ time_token = self.time_pe(time_token)
485
+ x = torch.cat((time_token, x), dim=1)
486
+ if x_mask is not None:
487
+ x_mask = torch.cat([
488
+ torch.ones(B, time_token.shape[1],
489
+ device=x_mask.device).bool(), x_mask
490
+ ],
491
+ dim=1)
492
+ time_token = None
493
+
494
+ skips = []
495
+ for blk in self.in_blocks:
496
+ x = blk(
497
+ x=x,
498
+ time_aligned_context=time_aligned_context,
499
+ time_token=time_token,
500
+ time_ada=time_ada,
501
+ skip=None,
502
+ context=context_token,
503
+ x_mask=x_mask,
504
+ context_mask=context_mask,
505
+ extras=self.extras
506
+ )
507
+
508
+ if self.use_skip:
509
+ skips.append(x)
510
+
511
+ x = self.mid_block(
512
+ x=x,
513
+ time_aligned_context=time_aligned_context,
514
+ time_token=time_token,
515
+ time_ada=time_ada,
516
+ skip=None,
517
+ context=context_token,
518
+ x_mask=x_mask,
519
+ context_mask=context_mask,
520
+ extras=self.extras
521
+ )
522
+
523
+ for blk in self.out_blocks:
524
+ if self.use_skip:
525
+ skip = skips.pop()
526
+ if controlnet_skips:
527
+ # add to skip like u-net controlnet
528
+ skip = skip + controlnet_skips.pop()
529
+ else:
530
+ skip = None
531
+ if controlnet_skips:
532
+ # directly add to x
533
+ x = x + controlnet_skips.pop()
534
+
535
+ x = blk(
536
+ x=x,
537
+ time_aligned_context=time_aligned_context,
538
+ time_token=time_token,
539
+ time_ada=time_ada,
540
+ skip=skip,
541
+ context=context_token,
542
+ x_mask=x_mask,
543
+ context_mask=context_mask,
544
+ extras=self.extras
545
+ )
546
+
547
+ x = self.final_block(x, time_ada=time_ada_final, extras=self.extras)
548
+
549
+ return x
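A small standalone illustration of the two `ta_context_fusion` modes implemented by the block above ('add' vs 'concat'); the optional norm layer is omitted and all sizes are arbitrary.

```python
import torch
import torch.nn as nn

B, T, dim, ta_dim = 2, 250, 256, 512
x = torch.randn(B, T, dim)          # hidden states after self-attention
ta_ctx = torch.randn(B, T, ta_dim)  # frame-aligned content features

# 'add': project the context to the model width and add it to x
proj_add = nn.Linear(ta_dim, dim)
x_add = x + proj_add(ta_ctx)

# 'concat': concatenate along channels, then project back to the model width
proj_cat = nn.Linear(ta_dim + dim, dim)
x_cat = proj_cat(torch.cat([x, ta_ctx], dim=-1))

print(x_add.shape, x_cat.shape)     # both torch.Size([2, 250, 256])
```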
models/dit/mask_dit.py ADDED
@@ -0,0 +1,823 @@
1
+ import logging
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.utils.checkpoint import checkpoint
6
+
7
+ from .modules import (
8
+ film_modulate,
9
+ unpatchify,
10
+ PatchEmbed,
11
+ PE_wrapper,
12
+ TimestepEmbedder,
13
+ FeedForward,
14
+ RMSNorm,
15
+ )
16
+ from .span_mask import compute_mask_indices
17
+ from .attention import Attention
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ class AdaLN(nn.Module):
23
+ def __init__(self, dim, ada_mode='ada', r=None, alpha=None):
24
+ super().__init__()
25
+ self.ada_mode = ada_mode
26
+ self.scale_shift_table = None
27
+ if ada_mode == 'ada':
28
+ # move nn.silu outside
29
+ self.time_ada = nn.Linear(dim, 6 * dim, bias=True)
30
+ elif ada_mode == 'ada_single':
31
+ # adaln used in pixel-art alpha
32
+ self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
33
+ elif ada_mode in ['ada_solo', 'ada_sola_bias']:
34
+ self.lora_a = nn.Linear(dim, r * 6, bias=False)
35
+ self.lora_b = nn.Linear(r * 6, dim * 6, bias=False)
36
+ self.scaling = alpha / r
37
+ if ada_mode == 'ada_sola_bias':
38
+ # take bias out for consistency
39
+ self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
40
+ else:
41
+ raise NotImplementedError
42
+
43
+ def forward(self, time_token=None, time_ada=None):
44
+ if self.ada_mode == 'ada':
45
+ assert time_ada is None
46
+ B = time_token.shape[0]
47
+ time_ada = self.time_ada(time_token).reshape(B, 6, -1)
48
+ elif self.ada_mode == 'ada_single':
49
+ B = time_ada.shape[0]
50
+ time_ada = time_ada.reshape(B, 6, -1)
51
+ time_ada = self.scale_shift_table[None] + time_ada
52
+ elif self.ada_mode in ['ada_sola', 'ada_sola_bias']:
53
+ B = time_ada.shape[0]
54
+ time_ada_lora = self.lora_b(self.lora_a(time_token)) * self.scaling
55
+ time_ada = time_ada + time_ada_lora
56
+ time_ada = time_ada.reshape(B, 6, -1)
57
+ if self.scale_shift_table is not None:
58
+ time_ada = self.scale_shift_table[None] + time_ada
59
+ else:
60
+ raise NotImplementedError
61
+ return time_ada
62
+
63
+
64
+ class DiTBlock(nn.Module):
65
+ """
66
+ A modified PixArt block with adaptive layer norm (adaLN-single) conditioning.
67
+ """
68
+ def __init__(
69
+ self,
70
+ dim,
71
+ context_dim=None,
72
+ num_heads=8,
73
+ mlp_ratio=4.,
74
+ qkv_bias=False,
75
+ qk_scale=None,
76
+ qk_norm=None,
77
+ act_layer='gelu',
78
+ norm_layer=nn.LayerNorm,
79
+ time_fusion='none',
80
+ ada_sola_rank=None,
81
+ ada_sola_alpha=None,
82
+ skip=False,
83
+ skip_norm=False,
84
+ rope_mode='none',
85
+ context_norm=False,
86
+ use_checkpoint=False
87
+ ):
88
+
89
+ super().__init__()
90
+ self.norm1 = norm_layer(dim)
91
+ self.attn = Attention(
92
+ dim=dim,
93
+ num_heads=num_heads,
94
+ qkv_bias=qkv_bias,
95
+ qk_scale=qk_scale,
96
+ qk_norm=qk_norm,
97
+ rope_mode=rope_mode
98
+ )
99
+
100
+ if context_dim is not None:
101
+ self.use_context = True
102
+ self.cross_attn = Attention(
103
+ dim=dim,
104
+ num_heads=num_heads,
105
+ context_dim=context_dim,
106
+ qkv_bias=qkv_bias,
107
+ qk_scale=qk_scale,
108
+ qk_norm=qk_norm,
109
+ rope_mode='none'
110
+ )
111
+ self.norm2 = norm_layer(dim)
112
+ if context_norm:
113
+ self.norm_context = norm_layer(context_dim)
114
+ else:
115
+ self.norm_context = nn.Identity()
116
+ else:
117
+ self.use_context = False
118
+
119
+ self.norm3 = norm_layer(dim)
120
+ self.mlp = FeedForward(
121
+ dim=dim, mult=mlp_ratio, activation_fn=act_layer, dropout=0
122
+ )
123
+
124
+ self.use_adanorm = True if time_fusion != 'token' else False
125
+ if self.use_adanorm:
126
+ self.adaln = AdaLN(
127
+ dim,
128
+ ada_mode=time_fusion,
129
+ r=ada_sola_rank,
130
+ alpha=ada_sola_alpha
131
+ )
132
+ if skip:
133
+ self.skip_norm = norm_layer(2 *
134
+ dim) if skip_norm else nn.Identity()
135
+ self.skip_linear = nn.Linear(2 * dim, dim)
136
+ else:
137
+ self.skip_linear = None
138
+
139
+ self.use_checkpoint = use_checkpoint
140
+
141
+ def forward(
142
+ self,
143
+ x,
144
+ time_token=None,
145
+ time_ada=None,
146
+ skip=None,
147
+ context=None,
148
+ x_mask=None,
149
+ context_mask=None,
150
+ extras=None
151
+ ):
152
+ if self.use_checkpoint:
153
+ return checkpoint(
154
+ self._forward,
155
+ x,
156
+ time_token,
157
+ time_ada,
158
+ skip,
159
+ context,
160
+ x_mask,
161
+ context_mask,
162
+ extras,
163
+ use_reentrant=False
164
+ )
165
+ else:
166
+ return self._forward(
167
+ x, time_token, time_ada, skip, context, x_mask, context_mask,
168
+ extras
169
+ )
170
+
171
+ def _forward(
172
+ self,
173
+ x,
174
+ time_token=None,
175
+ time_ada=None,
176
+ skip=None,
177
+ context=None,
178
+ x_mask=None,
179
+ context_mask=None,
180
+ extras=None
181
+ ):
182
+ B, T, C = x.shape
183
+ if self.skip_linear is not None:
184
+ assert skip is not None
185
+ cat = torch.cat([x, skip], dim=-1)
186
+ cat = self.skip_norm(cat)
187
+ x = self.skip_linear(cat)
188
+
189
+ if self.use_adanorm:
190
+ time_ada = self.adaln(time_token, time_ada)
191
+ (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
192
+ gate_mlp) = time_ada.chunk(6, dim=1)
193
+
194
+ # self attention
195
+ if self.use_adanorm:
196
+ x_norm = film_modulate(
197
+ self.norm1(x), shift=shift_msa, scale=scale_msa
198
+ )
199
+ x = x + (1-gate_msa) * self.attn(
200
+ x_norm, context=None, context_mask=x_mask, extras=extras
201
+ )
202
+ else:
203
+ x = x + self.attn(
204
+ self.norm1(x),
205
+ context=None,
206
+ context_mask=x_mask,
207
+ extras=extras
208
+ )
209
+
210
+ # cross attention
211
+ if self.use_context:
212
+ assert context is not None
213
+ x = x + self.cross_attn(
214
+ x=self.norm2(x),
215
+ context=self.norm_context(context),
216
+ context_mask=context_mask,
217
+ extras=extras
218
+ )
219
+
220
+ # mlp
221
+ if self.use_adanorm:
222
+ x_norm = film_modulate(
223
+ self.norm3(x), shift=shift_mlp, scale=scale_mlp
224
+ )
225
+ x = x + (1-gate_mlp) * self.mlp(x_norm)
226
+ else:
227
+ x = x + self.mlp(self.norm3(x))
228
+
229
+ return x
230
+
231
+
232
+ class FinalBlock(nn.Module):
233
+ def __init__(
234
+ self,
235
+ embed_dim,
236
+ patch_size,
237
+ in_chans,
238
+ img_size,
239
+ input_type='2d',
240
+ norm_layer=nn.LayerNorm,
241
+ use_conv=True,
242
+ use_adanorm=True
243
+ ):
244
+ super().__init__()
245
+ self.in_chans = in_chans
246
+ self.img_size = img_size
247
+ self.input_type = input_type
248
+
249
+ self.norm = norm_layer(embed_dim)
250
+ if use_adanorm:
251
+ self.use_adanorm = True
252
+ else:
253
+ self.use_adanorm = False
254
+
255
+ if input_type == '2d':
256
+ self.patch_dim = patch_size**2 * in_chans
257
+ self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
258
+ if use_conv:
259
+ self.final_layer = nn.Conv2d(
260
+ self.in_chans, self.in_chans, 3, padding=1
261
+ )
262
+ else:
263
+ self.final_layer = nn.Identity()
264
+
265
+ elif input_type == '1d':
266
+ self.patch_dim = patch_size * in_chans
267
+ self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
268
+ if use_conv:
269
+ self.final_layer = nn.Conv1d(
270
+ self.in_chans, self.in_chans, 3, padding=1
271
+ )
272
+ else:
273
+ self.final_layer = nn.Identity()
274
+
275
+ def forward(self, x, time_ada=None, extras=0):
276
+ B, T, C = x.shape
277
+ x = x[:, extras:, :]
278
+ # only handle generation target
279
+ if self.use_adanorm:
280
+ shift, scale = time_ada.reshape(B, 2, -1).chunk(2, dim=1)
281
+ x = film_modulate(self.norm(x), shift, scale)
282
+ else:
283
+ x = self.norm(x)
284
+ x = self.linear(x)
285
+ x = unpatchify(x, self.in_chans, self.input_type, self.img_size)
286
+ x = self.final_layer(x)
287
+ return x
288
+
289
+
290
+ class UDiT(nn.Module):
291
+ def __init__(
292
+ self,
293
+ img_size=224,
294
+ patch_size=16,
295
+ in_chans=3,
296
+ input_type='2d',
297
+ out_chans=None,
298
+ embed_dim=768,
299
+ depth=12,
300
+ num_heads=12,
301
+ mlp_ratio=4.,
302
+ qkv_bias=False,
303
+ qk_scale=None,
304
+ qk_norm=None,
305
+ act_layer='gelu',
306
+ norm_layer='layernorm',
307
+ context_norm=False,
308
+ use_checkpoint=False,
309
+ # time fusion ada or token
310
+ time_fusion='token',
311
+ ada_sola_rank=None,
312
+ ada_sola_alpha=None,
313
+ cls_dim=None,
314
+ # max length is only used for concat
315
+ context_dim=768,
316
+ context_fusion='concat',
317
+ context_max_length=128,
318
+ context_pe_method='sinu',
319
+ pe_method='abs',
320
+ rope_mode='none',
321
+ use_conv=True,
322
+ skip=True,
323
+ skip_norm=True
324
+ ):
325
+ super().__init__()
326
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
327
+
328
+ # input
329
+ self.in_chans = in_chans
330
+ self.input_type = input_type
331
+ if self.input_type == '2d':
332
+ num_patches = (img_size[0] //
333
+ patch_size) * (img_size[1] // patch_size)
334
+ elif self.input_type == '1d':
335
+ num_patches = img_size // patch_size
336
+ self.patch_embed = PatchEmbed(
337
+ patch_size=patch_size,
338
+ in_chans=in_chans,
339
+ embed_dim=embed_dim,
340
+ input_type=input_type
341
+ )
342
+ out_chans = in_chans if out_chans is None else out_chans
343
+ self.out_chans = out_chans
344
+
345
+ # position embedding
346
+ self.rope = rope_mode
347
+ self.x_pe = PE_wrapper(
348
+ dim=embed_dim, method=pe_method, length=num_patches
349
+ )
350
+
351
+ logger.info(f'x position embedding: {pe_method}')
352
+ logger.info(f'rope mode: {self.rope}')
353
+
354
+ # time embed
355
+ self.time_embed = TimestepEmbedder(embed_dim)
356
+ self.time_fusion = time_fusion
357
+ self.use_adanorm = False
358
+
359
+ # cls embed
360
+ if cls_dim is not None:
361
+ self.cls_embed = nn.Sequential(
362
+ nn.Linear(cls_dim, embed_dim, bias=True),
363
+ nn.SiLU(),
364
+ nn.Linear(embed_dim, embed_dim, bias=True),
365
+ )
366
+ else:
367
+ self.cls_embed = None
368
+
369
+ # time fusion
370
+ if time_fusion == 'token':
371
+ # put token at the beginning of sequence
372
+ self.extras = 2 if self.cls_embed else 1
373
+ self.time_pe = PE_wrapper(
374
+ dim=embed_dim, method='abs', length=self.extras
375
+ )
376
+ elif time_fusion in ['ada', 'ada_single', 'ada_sola', 'ada_sola_bias']:
377
+ self.use_adanorm = True
378
+ # aviod repetitive silu for each adaln block
379
+ self.time_act = nn.SiLU()
380
+ self.extras = 0
381
+ self.time_ada_final = nn.Linear(
382
+ embed_dim, 2 * embed_dim, bias=True
383
+ )
384
+ if time_fusion in ['ada_single', 'ada_sola', 'ada_sola_bias']:
385
+ # shared adaln
386
+ self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
387
+ else:
388
+ self.time_ada = None
389
+ else:
390
+ raise NotImplementedError
391
+ logger.info(f'time fusion mode: {self.time_fusion}')
392
+
393
+ # context
394
+ # use a simple projection
395
+ self.use_context = False
396
+ self.context_cross = False
397
+ self.context_max_length = context_max_length
398
+ self.context_fusion = 'none'
399
+ if context_dim is not None:
400
+ self.use_context = True
401
+ self.context_embed = nn.Sequential(
402
+ nn.Linear(context_dim, embed_dim, bias=True),
403
+ nn.SiLU(),
404
+ nn.Linear(embed_dim, embed_dim, bias=True),
405
+ )
406
+ self.context_fusion = context_fusion
407
+ if context_fusion == 'concat' or context_fusion == 'joint':
408
+ self.extras += context_max_length
409
+ self.context_pe = PE_wrapper(
410
+ dim=embed_dim,
411
+ method=context_pe_method,
412
+ length=context_max_length
413
+ )
414
+ # no cross attention layers
415
+ context_dim = None
416
+ elif context_fusion == 'cross':
417
+ self.context_pe = PE_wrapper(
418
+ dim=embed_dim,
419
+ method=context_pe_method,
420
+ length=context_max_length
421
+ )
422
+ self.context_cross = True
423
+ context_dim = embed_dim
424
+ else:
425
+ raise NotImplementedError
426
+ logger.info(f'context fusion mode: {context_fusion}')
427
+ logger.info(f'context position embedding: {context_pe_method}')
428
+
429
+ self.use_skip = skip
430
+
431
+ # norm layers
432
+ if norm_layer == 'layernorm':
433
+ norm_layer = nn.LayerNorm
434
+ elif norm_layer == 'rmsnorm':
435
+ norm_layer = RMSNorm
436
+ else:
437
+ raise NotImplementedError
438
+
439
+ logger.info(f'use long skip connection: {skip}')
440
+ self.in_blocks = nn.ModuleList([
441
+ DiTBlock(
442
+ dim=embed_dim,
443
+ context_dim=context_dim,
444
+ num_heads=num_heads,
445
+ mlp_ratio=mlp_ratio,
446
+ qkv_bias=qkv_bias,
447
+ qk_scale=qk_scale,
448
+ qk_norm=qk_norm,
449
+ act_layer=act_layer,
450
+ norm_layer=norm_layer,
451
+ time_fusion=time_fusion,
452
+ ada_sola_rank=ada_sola_rank,
453
+ ada_sola_alpha=ada_sola_alpha,
454
+ skip=False,
455
+ skip_norm=False,
456
+ rope_mode=self.rope,
457
+ context_norm=context_norm,
458
+ use_checkpoint=use_checkpoint
459
+ ) for _ in range(depth // 2)
460
+ ])
461
+
462
+ self.mid_block = DiTBlock(
463
+ dim=embed_dim,
464
+ context_dim=context_dim,
465
+ num_heads=num_heads,
466
+ mlp_ratio=mlp_ratio,
467
+ qkv_bias=qkv_bias,
468
+ qk_scale=qk_scale,
469
+ qk_norm=qk_norm,
470
+ act_layer=act_layer,
471
+ norm_layer=norm_layer,
472
+ time_fusion=time_fusion,
473
+ ada_sola_rank=ada_sola_rank,
474
+ ada_sola_alpha=ada_sola_alpha,
475
+ skip=False,
476
+ skip_norm=False,
477
+ rope_mode=self.rope,
478
+ context_norm=context_norm,
479
+ use_checkpoint=use_checkpoint
480
+ )
481
+
482
+ self.out_blocks = nn.ModuleList([
483
+ DiTBlock(
484
+ dim=embed_dim,
485
+ context_dim=context_dim,
486
+ num_heads=num_heads,
487
+ mlp_ratio=mlp_ratio,
488
+ qkv_bias=qkv_bias,
489
+ qk_scale=qk_scale,
490
+ qk_norm=qk_norm,
491
+ act_layer=act_layer,
492
+ norm_layer=norm_layer,
493
+ time_fusion=time_fusion,
494
+ ada_sola_rank=ada_sola_rank,
495
+ ada_sola_alpha=ada_sola_alpha,
496
+ skip=skip,
497
+ skip_norm=skip_norm,
498
+ rope_mode=self.rope,
499
+ context_norm=context_norm,
500
+ use_checkpoint=use_checkpoint
501
+ ) for _ in range(depth // 2)
502
+ ])
503
+
504
+ # FinalLayer block
505
+ self.use_conv = use_conv
506
+ self.final_block = FinalBlock(
507
+ embed_dim=embed_dim,
508
+ patch_size=patch_size,
509
+ img_size=img_size,
510
+ in_chans=out_chans,
511
+ input_type=input_type,
512
+ norm_layer=norm_layer,
513
+ use_conv=use_conv,
514
+ use_adanorm=self.use_adanorm
515
+ )
516
+ self.initialize_weights()
517
+
518
+ def _init_ada(self):
519
+ if self.time_fusion == 'ada':
520
+ nn.init.constant_(self.time_ada_final.weight, 0)
521
+ nn.init.constant_(self.time_ada_final.bias, 0)
522
+ for block in self.in_blocks:
523
+ nn.init.constant_(block.adaln.time_ada.weight, 0)
524
+ nn.init.constant_(block.adaln.time_ada.bias, 0)
525
+ nn.init.constant_(self.mid_block.adaln.time_ada.weight, 0)
526
+ nn.init.constant_(self.mid_block.adaln.time_ada.bias, 0)
527
+ for block in self.out_blocks:
528
+ nn.init.constant_(block.adaln.time_ada.weight, 0)
529
+ nn.init.constant_(block.adaln.time_ada.bias, 0)
530
+ elif self.time_fusion == 'ada_single':
531
+ nn.init.constant_(self.time_ada.weight, 0)
532
+ nn.init.constant_(self.time_ada.bias, 0)
533
+ nn.init.constant_(self.time_ada_final.weight, 0)
534
+ nn.init.constant_(self.time_ada_final.bias, 0)
535
+ elif self.time_fusion in ['ada_sola', 'ada_sola_bias']:
536
+ nn.init.constant_(self.time_ada.weight, 0)
537
+ nn.init.constant_(self.time_ada.bias, 0)
538
+ nn.init.constant_(self.time_ada_final.weight, 0)
539
+ nn.init.constant_(self.time_ada_final.bias, 0)
540
+ for block in self.in_blocks:
541
+ nn.init.kaiming_uniform_(
542
+ block.adaln.lora_a.weight, a=math.sqrt(5)
543
+ )
544
+ nn.init.constant_(block.adaln.lora_b.weight, 0)
545
+ nn.init.kaiming_uniform_(
546
+ self.mid_block.adaln.lora_a.weight, a=math.sqrt(5)
547
+ )
548
+ nn.init.constant_(self.mid_block.adaln.lora_b.weight, 0)
549
+ for block in self.out_blocks:
550
+ nn.init.kaiming_uniform_(
551
+ block.adaln.lora_a.weight, a=math.sqrt(5)
552
+ )
553
+ nn.init.constant_(block.adaln.lora_b.weight, 0)
554
+
555
+ def initialize_weights(self):
556
+ # Basic init for all layers
557
+ def _basic_init(module):
558
+ if isinstance(module, nn.Linear):
559
+ torch.nn.init.xavier_uniform_(module.weight)
560
+ if module.bias is not None:
561
+ nn.init.constant_(module.bias, 0)
562
+
563
+ self.apply(_basic_init)
564
+
565
+ # init patch Conv like Linear
566
+ w = self.patch_embed.proj.weight.data
567
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
568
+ nn.init.constant_(self.patch_embed.proj.bias, 0)
569
+
570
+ # Zero-out AdaLN
571
+ if self.use_adanorm:
572
+ self._init_ada()
573
+
574
+ # Zero-out Cross Attention
575
+ if self.context_cross:
576
+ for block in self.in_blocks:
577
+ nn.init.constant_(block.cross_attn.proj.weight, 0)
578
+ nn.init.constant_(block.cross_attn.proj.bias, 0)
579
+ nn.init.constant_(self.mid_block.cross_attn.proj.weight, 0)
580
+ nn.init.constant_(self.mid_block.cross_attn.proj.bias, 0)
581
+ for block in self.out_blocks:
582
+ nn.init.constant_(block.cross_attn.proj.weight, 0)
583
+ nn.init.constant_(block.cross_attn.proj.bias, 0)
584
+
585
+ # Zero-out cls embedding
586
+ if self.cls_embed:
587
+ if self.use_adanorm:
588
+ nn.init.constant_(self.cls_embed[-1].weight, 0)
589
+ nn.init.constant_(self.cls_embed[-1].bias, 0)
590
+
591
+ # Zero-out Output
592
+ # might not zero-out this when using v-prediction
593
+ # it could be good when using noise-prediction
594
+ # nn.init.constant_(self.final_block.linear.weight, 0)
595
+ # nn.init.constant_(self.final_block.linear.bias, 0)
596
+ # if self.use_conv:
597
+ # nn.init.constant_(self.final_block.final_layer.weight.data, 0)
598
+ # nn.init.constant_(self.final_block.final_layer.bias, 0)
599
+
600
+ # init out Conv
601
+ if self.use_conv:
602
+ nn.init.xavier_uniform_(self.final_block.final_layer.weight)
603
+ nn.init.constant_(self.final_block.final_layer.bias, 0)
604
+
605
+ def _concat_x_context(self, x, context, x_mask=None, context_mask=None):
606
+ assert context.shape[-2] == self.context_max_length
607
+ # Check if either x_mask or context_mask is provided
608
+ B = x.shape[0]
609
+ # Create default masks if they are not provided
610
+ if x_mask is None:
611
+ x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
612
+ if context_mask is None:
613
+ context_mask = torch.ones(
614
+ B, context.shape[-2], device=context.device
615
+ ).bool()
616
+ # Concatenate the masks along the second dimension (dim=1)
617
+ x_mask = torch.cat([context_mask, x_mask], dim=1)
618
+ # Concatenate context and x along the second dimension (dim=1)
619
+ x = torch.cat((context, x), dim=1)
620
+ return x, x_mask
621
+
622
+ def forward(
623
+ self,
624
+ x,
625
+ timesteps,
626
+ context,
627
+ x_mask=None,
628
+ context_mask=None,
629
+ cls_token=None,
630
+ controlnet_skips=None,
631
+ ):
632
+ # make it compatible with int time step during inference
633
+ if timesteps.dim() == 0:
634
+ timesteps = timesteps.expand(x.shape[0]
635
+ ).to(x.device, dtype=torch.long)
636
+
637
+ x = self.patch_embed(x)
638
+ x = self.x_pe(x)
639
+
640
+ B, L, D = x.shape
641
+
642
+ if self.use_context:
643
+ context_token = self.context_embed(context)
644
+ context_token = self.context_pe(context_token)
645
+ if self.context_fusion == 'concat' or self.context_fusion == 'joint':
646
+ x, x_mask = self._concat_x_context(
647
+ x=x,
648
+ context=context_token,
649
+ x_mask=x_mask,
650
+ context_mask=context_mask
651
+ )
652
+ context_token, context_mask = None, None
653
+ else:
654
+ context_token, context_mask = None, None
655
+
656
+ time_token = self.time_embed(timesteps)
657
+ if self.cls_embed:
658
+ cls_token = self.cls_embed(cls_token)
659
+ time_ada = None
660
+ time_ada_final = None
661
+ if self.use_adanorm:
662
+ if self.cls_embed:
663
+ time_token = time_token + cls_token
664
+ time_token = self.time_act(time_token)
665
+ time_ada_final = self.time_ada_final(time_token)
666
+ if self.time_ada is not None:
667
+ time_ada = self.time_ada(time_token)
668
+ else:
669
+ time_token = time_token.unsqueeze(dim=1)
670
+ if self.cls_embed:
671
+ cls_token = cls_token.unsqueeze(dim=1)
672
+ time_token = torch.cat([time_token, cls_token], dim=1)
673
+ time_token = self.time_pe(time_token)
674
+ x = torch.cat((time_token, x), dim=1)
675
+ if x_mask is not None:
676
+ x_mask = torch.cat([
677
+ torch.ones(B, time_token.shape[1],
678
+ device=x_mask.device).bool(), x_mask
679
+ ],
680
+ dim=1)
681
+ time_token = None
682
+
683
+ skips = []
684
+ for blk in self.in_blocks:
685
+ x = blk(
686
+ x=x,
687
+ time_token=time_token,
688
+ time_ada=time_ada,
689
+ skip=None,
690
+ context=context_token,
691
+ x_mask=x_mask,
692
+ context_mask=context_mask,
693
+ extras=self.extras
694
+ )
695
+ if self.use_skip:
696
+ skips.append(x)
697
+
698
+ x = self.mid_block(
699
+ x=x,
700
+ time_token=time_token,
701
+ time_ada=time_ada,
702
+ skip=None,
703
+ context=context_token,
704
+ x_mask=x_mask,
705
+ context_mask=context_mask,
706
+ extras=self.extras
707
+ )
708
+ for blk in self.out_blocks:
709
+ if self.use_skip:
710
+ skip = skips.pop()
711
+ if controlnet_skips:
712
+ # add to skip like u-net controlnet
713
+ skip = skip + controlnet_skips.pop()
714
+ else:
715
+ skip = None
716
+ if controlnet_skips:
717
+ # directly add to x
718
+ x = x + controlnet_skips.pop()
719
+
720
+ x = blk(
721
+ x=x,
722
+ time_token=time_token,
723
+ time_ada=time_ada,
724
+ skip=skip,
725
+ context=context_token,
726
+ x_mask=x_mask,
727
+ context_mask=context_mask,
728
+ extras=self.extras
729
+ )
730
+
731
+ x = self.final_block(x, time_ada=time_ada_final, extras=self.extras)
732
+
733
+ return x
734
+
735
+
736
+ class MaskDiT(nn.Module):
737
+ def __init__(
738
+ self,
739
+ model: UDiT,
740
+ mae=False,
741
+ mae_prob=0.5,
742
+ mask_ratio=[0.25, 1.0],
743
+ mask_span=10,
744
+ ):
745
+ super().__init__()
746
+ self.model = model
747
+ self.mae = mae
748
+ if self.mae:
749
+ out_channel = model.out_chans
750
+ self.mask_embed = nn.Parameter(torch.zeros((out_channel)))
751
+ self.mae_prob = mae_prob
752
+ self.mask_ratio = mask_ratio
753
+ self.mask_span = mask_span
754
+
755
+ def random_masking(self, gt, mask_ratios, mae_mask_infer=None):
756
+ B, D, L = gt.shape
757
+ if mae_mask_infer is None:
758
+ # mask = torch.rand(B, L).to(gt.device) < mask_ratios.unsqueeze(1)
759
+ mask_ratios = mask_ratios.cpu().numpy()
760
+ mask = compute_mask_indices(
761
+ shape=[B, L],
762
+ padding_mask=None,
763
+ mask_prob=mask_ratios,
764
+ mask_length=self.mask_span,
765
+ mask_type="static",
766
+ mask_other=0.0,
767
+ min_masks=1,
768
+ no_overlap=False,
769
+ min_space=0,
770
+ )
771
+ mask = mask.to(gt.device).unsqueeze(1).expand_as(gt)
772
+ else:
773
+ mask = mae_mask_infer
774
+ mask = mask.expand_as(gt)
775
+ gt[mask] = self.mask_embed.view(1, D, 1).expand_as(gt)[mask]
776
+ return gt, mask.type_as(gt)
777
+
778
+ def forward(
779
+ self,
780
+ x,
781
+ timesteps,
782
+ context,
783
+ x_mask=None,
784
+ context_mask=None,
785
+ cls_token=None,
786
+ gt=None,
787
+ mae_mask_infer=None,
788
+ forward_model=True
789
+ ):
790
+ # todo: handle controlnet inside
791
+ mae_mask = torch.ones_like(x)
792
+ if self.mae:
793
+ if gt is not None:
794
+ B, D, L = gt.shape
795
+ mask_ratios = torch.FloatTensor(B).uniform_(*self.mask_ratio
796
+ ).to(gt.device)
797
+ gt, mae_mask = self.random_masking(
798
+ gt, mask_ratios, mae_mask_infer
799
+ )
800
+ # apply mae only to the selected batches
801
+ if mae_mask_infer is None:
802
+ # determine mae batch
803
+ mae_batch = torch.rand(B) < self.mae_prob
804
+ gt[~mae_batch] = self.mask_embed.view(
805
+ 1, D, 1
806
+ ).expand_as(gt)[~mae_batch]
807
+ mae_mask[~mae_batch] = 1.0
808
+ else:
809
+ B, D, L = x.shape
810
+ gt = self.mask_embed.view(1, D, 1).expand_as(x)
811
+ x = torch.cat([x, gt, mae_mask[:, 0:1, :]], dim=1)
812
+
813
+ if forward_model:
814
+ x = self.model(
815
+ x=x,
816
+ timesteps=timesteps,
817
+ context=context,
818
+ x_mask=x_mask,
819
+ context_mask=context_mask,
820
+ cls_token=cls_token
821
+ )
822
+ # logger.info(mae_mask[:, 0, :].sum(dim=-1))
823
+ return x, mae_mask
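During MAE training, MaskDiT replaces random spans of the ground-truth latent with the learned mask embedding before the latent is concatenated to the noisy input. A minimal standalone sketch of that masking step, assuming dummy shapes (B=2, D=8, L=100) and the import path implied by this upload's file layout:

import torch
from models.dit.span_mask import compute_mask_indices

B, D, L = 2, 8, 100
gt = torch.randn(B, D, L)        # stand-in for the ground-truth latent
mask_embed = torch.zeros(D)      # stand-in for the learned mask embedding

# per-example mask ratio drawn from [0.25, 1.0], as in MaskDiT.forward
mask_ratios = torch.FloatTensor(B).uniform_(0.25, 1.0)

# span mask of shape [B, L]; mask_length plays the role of mask_span
mask = compute_mask_indices(
    shape=[B, L],
    padding_mask=None,
    mask_prob=mask_ratios.numpy(),
    mask_length=10,
    mask_type="static",
    min_masks=1,
)
mask = mask.unsqueeze(1).expand(B, D, L)    # broadcast over channels

# masked positions take the value of the mask embedding
gt_masked = torch.where(mask, mask_embed.view(1, D, 1).expand(B, D, L), gt)
print(gt_masked.shape, mask.float().mean().item())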
models/dit/modules.py ADDED
@@ -0,0 +1,445 @@
1
+ import warnings
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.utils.checkpoint
6
+ from torch.cuda.amp import autocast
7
+ import math
8
+ import einops
9
+ from einops import rearrange, repeat
10
+ from inspect import isfunction
11
+
12
+
13
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
14
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
15
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
16
+ def norm_cdf(x):
17
+ # Computes standard normal cumulative distribution function
18
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
19
+
20
+ if (mean < a - 2*std) or (mean > b + 2*std):
21
+ warnings.warn(
22
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
23
+ "The distribution of values may be incorrect.",
24
+ stacklevel=2
25
+ )
26
+
27
+ with torch.no_grad():
28
+ # Values are generated by using a truncated uniform distribution and
29
+ # then using the inverse CDF for the normal distribution.
30
+ # Get upper and lower cdf values
31
+ l = norm_cdf((a-mean) / std)
32
+ u = norm_cdf((b-mean) / std)
33
+
34
+ # Uniformly fill tensor with values from [l, u], then translate to
35
+ # [2l-1, 2u-1].
36
+ tensor.uniform_(2*l - 1, 2*u - 1)
37
+
38
+ # Use inverse cdf transform for normal distribution to get truncated
39
+ # standard normal
40
+ tensor.erfinv_()
41
+
42
+ # Transform to proper mean, std
43
+ tensor.mul_(std * math.sqrt(2.))
44
+ tensor.add_(mean)
45
+
46
+ # Clamp to ensure it's in the proper range
47
+ tensor.clamp_(min=a, max=b)
48
+ return tensor
49
+
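# Usage sketch (illustrative): trunc_normal_ fills a tensor in place with
# values from N(mean, std^2) truncated to [a, b], e.g.
#   w = torch.empty(256, 768)
#   trunc_normal_(w, mean=0., std=0.02, a=-0.04, b=0.04)
# after which all entries of w lie in [-0.04, 0.04].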
50
+
51
+ # disable in checkpoint mode
52
+ # @torch.jit.script
53
+ def film_modulate(x, shift, scale):
54
+ return x * (1+scale) + shift
55
+
56
+
57
+ def timestep_embedding(timesteps, dim, max_period=10000):
58
+ """
59
+ Create sinusoidal timestep embeddings.
60
+
61
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
62
+ These may be fractional.
63
+ :param dim: the dimension of the output.
64
+ :param max_period: controls the minimum frequency of the embeddings.
65
+ :return: an [N x dim] Tensor of positional embeddings.
66
+ """
67
+ half = dim // 2
68
+ freqs = torch.exp(
69
+ -math.log(max_period) *
70
+ torch.arange(start=0, end=half, dtype=torch.float32) / half
71
+ ).to(device=timesteps.device)
72
+ args = timesteps[:, None].float() * freqs[None]
73
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
74
+ if dim % 2:
75
+ embedding = torch.cat([embedding,
76
+ torch.zeros_like(embedding[:, :1])],
77
+ dim=-1)
78
+ return embedding
79
+
80
+
81
+ class TimestepEmbedder(nn.Module):
82
+ """
83
+ Embeds scalar timesteps into vector representations.
84
+ """
85
+ def __init__(
86
+ self, hidden_size, frequency_embedding_size=256, out_size=None
87
+ ):
88
+ super().__init__()
89
+ if out_size is None:
90
+ out_size = hidden_size
91
+ self.mlp = nn.Sequential(
92
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
93
+ nn.SiLU(),
94
+ nn.Linear(hidden_size, out_size, bias=True),
95
+ )
96
+ self.frequency_embedding_size = frequency_embedding_size
97
+
98
+ def forward(self, t):
99
+ t_freq = timestep_embedding(t, self.frequency_embedding_size).type(
100
+ self.mlp[0].weight.dtype
101
+ )
102
+ t_emb = self.mlp(t_freq)
103
+ return t_emb
104
+
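# Shape note (illustrative): for t = torch.tensor([0, 250, 999]),
# timestep_embedding(t, dim=256) returns a [3, 256] tensor of sinusoidal
# features, and TimestepEmbedder(hidden_size=768)(t) projects it to [3, 768].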
105
+
106
+ def patchify(imgs, patch_size, input_type='2d'):
107
+ if input_type == '2d':
108
+ x = einops.rearrange(
109
+ imgs,
110
+ 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)',
111
+ p1=patch_size,
112
+ p2=patch_size
113
+ )
114
+ elif input_type == '1d':
115
+ x = einops.rearrange(imgs, 'B C (h p1) -> B h (p1 C)', p1=patch_size)
116
+ return x
117
+
118
+
119
+ def unpatchify(x, channels=3, input_type='2d', img_size=None):
120
+ if input_type == '2d':
121
+ patch_size = int((x.shape[2] // channels)**0.5)
122
+ # h = w = int(x.shape[1] ** .5)
123
+ h, w = img_size[0] // patch_size, img_size[1] // patch_size
124
+ assert h * w == x.shape[1] and patch_size**2 * channels == x.shape[2]
125
+ x = einops.rearrange(
126
+ x,
127
+ 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)',
128
+ h=h,
129
+ p1=patch_size,
130
+ p2=patch_size
131
+ )
132
+ elif input_type == '1d':
133
+ patch_size = int((x.shape[2] // channels))
134
+ h = x.shape[1]
135
+ assert patch_size * channels == x.shape[2]
136
+ x = einops.rearrange(x, 'B h (p1 C) -> B C (h p1)', h=h, p1=patch_size)
137
+ return x
138
+
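# Shape note (illustrative): in the '1d' case, patchify maps a latent of
# shape [B, C, T] to [B, T // patch_size, patch_size * C] (T must be
# divisible by patch_size), and unpatchify(..., channels=C, input_type='1d')
# inverts it exactly.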
139
+
140
+ class PatchEmbed(nn.Module):
141
+ """
142
+ Image to Patch Embedding
143
+ """
144
+ def __init__(self, patch_size, in_chans=3, embed_dim=768, input_type='2d'):
145
+ super().__init__()
146
+ self.patch_size = patch_size
147
+ self.input_type = input_type
148
+ if input_type == '2d':
149
+ self.proj = nn.Conv2d(
150
+ in_chans,
151
+ embed_dim,
152
+ kernel_size=patch_size,
153
+ stride=patch_size,
154
+ bias=True
155
+ )
156
+ elif input_type == '1d':
157
+ self.proj = nn.Conv1d(
158
+ in_chans,
159
+ embed_dim,
160
+ kernel_size=patch_size,
161
+ stride=patch_size,
162
+ bias=True
163
+ )
164
+
165
+ def forward(self, x):
166
+ if self.input_type == '2d':
167
+ B, C, H, W = x.shape
168
+ assert H % self.patch_size == 0 and W % self.patch_size == 0
169
+ elif self.input_type == '1d':
170
+ B, C, H = x.shape
171
+ assert H % self.patch_size == 0
172
+
173
+ x = self.proj(x).flatten(2).transpose(1, 2)
174
+ return x
175
+
176
+
177
+ class PositionalConvEmbedding(nn.Module):
178
+ """
179
+ Relative positional embedding used in HuBERT
180
+ """
181
+ def __init__(self, dim=768, kernel_size=128, groups=16):
182
+ super().__init__()
183
+ self.conv = nn.Conv1d(
184
+ dim,
185
+ dim,
186
+ kernel_size=kernel_size,
187
+ padding=kernel_size // 2,
188
+ groups=groups,
189
+ bias=True
190
+ )
191
+ self.conv = nn.utils.parametrizations.weight_norm(
192
+ self.conv, name="weight", dim=2
193
+ )
194
+
195
+ def forward(self, x):
196
+ # B C T
197
+ x = self.conv(x)
198
+ x = F.gelu(x[:, :, :-1])
199
+ return x
200
+
201
+
202
+ class SinusoidalPositionalEncoding(nn.Module):
203
+ def __init__(self, dim, length):
204
+ super(SinusoidalPositionalEncoding, self).__init__()
205
+ self.length = length
206
+ self.dim = dim
207
+ self.register_buffer(
208
+ 'pe', self._generate_positional_encoding(length, dim)
209
+ )
210
+
211
+ def _generate_positional_encoding(self, length, dim):
212
+ pe = torch.zeros(length, dim)
213
+ position = torch.arange(0, length, dtype=torch.float).unsqueeze(1)
214
+ div_term = torch.exp(
215
+ torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim)
216
+ )
217
+
218
+ pe[:, 0::2] = torch.sin(position * div_term)
219
+ pe[:, 1::2] = torch.cos(position * div_term)
220
+
221
+ pe = pe.unsqueeze(0)
222
+ return pe
223
+
224
+ def forward(self, x):
225
+ x = x + self.pe[:, :x.size(1)]
226
+ return x
227
+
228
+
229
+ class PE_wrapper(nn.Module):
230
+ def __init__(self, dim=768, method='abs', length=None, **kwargs):
231
+ super().__init__()
232
+ self.method = method
233
+ if method == 'abs':
234
+ # init absolute pe like UViT
235
+ self.length = length
236
+ self.abs_pe = nn.Parameter(torch.zeros(1, length, dim))
237
+ trunc_normal_(self.abs_pe, std=.02)
238
+ elif method == 'conv':
239
+ self.conv_pe = PositionalConvEmbedding(dim=dim, **kwargs)
240
+ elif method == 'sinu':
241
+ self.sinu_pe = SinusoidalPositionalEncoding(dim=dim, length=length)
242
+ elif method == 'none':
243
+ # skip pe
244
+ self.id = nn.Identity()
245
+ else:
246
+ raise NotImplementedError
247
+
248
+ def forward(self, x):
249
+ if self.method == 'abs':
250
+ _, L, _ = x.shape
251
+ assert L <= self.length
252
+ x = x + self.abs_pe[:, :L, :]
253
+ elif self.method == 'conv':
254
+ x = x + self.conv_pe(x)
255
+ elif self.method == 'sinu':
256
+ x = self.sinu_pe(x)
257
+ elif self.method == 'none':
258
+ x = self.id(x)
259
+ else:
260
+ raise NotImplementedError
261
+ return x
262
+
263
+
264
+ class RMSNorm(torch.nn.Module):
265
+ def __init__(self, dim: int, eps: float = 1e-6):
266
+ """
267
+ Initialize the RMSNorm normalization layer.
268
+
269
+ Args:
270
+ dim (int): The dimension of the input tensor.
271
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
272
+
273
+ Attributes:
274
+ eps (float): A small value added to the denominator for numerical stability.
275
+ weight (nn.Parameter): Learnable scaling parameter.
276
+
277
+ """
278
+ super().__init__()
279
+ self.eps = eps
280
+ self.weight = nn.Parameter(torch.ones(dim))
281
+
282
+ def _norm(self, x):
283
+ """
284
+ Apply the RMSNorm normalization to the input tensor.
285
+
286
+ Args:
287
+ x (torch.Tensor): The input tensor.
288
+
289
+ Returns:
290
+ torch.Tensor: The normalized tensor.
291
+
292
+ """
293
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
294
+
295
+ def forward(self, x):
296
+ """
297
+ Forward pass through the RMSNorm layer.
298
+
299
+ Args:
300
+ x (torch.Tensor): The input tensor.
301
+
302
+ Returns:
303
+ torch.Tensor: The output tensor after applying RMSNorm.
304
+
305
+ """
306
+ output = self._norm(x.float()).type_as(x)
307
+ return output * self.weight
308
+
309
+
310
+ class GELU(nn.Module):
311
+ def __init__(
312
+ self,
313
+ dim_in: int,
314
+ dim_out: int,
315
+ approximate: str = "none",
316
+ bias: bool = True
317
+ ):
318
+ super().__init__()
319
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
320
+ self.approximate = approximate
321
+
322
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
323
+ if gate.device.type != "mps":
324
+ return F.gelu(gate, approximate=self.approximate)
325
+ # mps: gelu is not implemented for float16
326
+ return F.gelu(
327
+ gate.to(dtype=torch.float32), approximate=self.approximate
328
+ ).to(dtype=gate.dtype)
329
+
330
+ def forward(self, hidden_states):
331
+ hidden_states = self.proj(hidden_states)
332
+ hidden_states = self.gelu(hidden_states)
333
+ return hidden_states
334
+
335
+
336
+ class GEGLU(nn.Module):
337
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
338
+ super().__init__()
339
+ self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
340
+
341
+ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
342
+ if gate.device.type != "mps":
343
+ return F.gelu(gate)
344
+ # mps: gelu is not implemented for float16
345
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
346
+
347
+ def forward(self, hidden_states):
348
+ hidden_states = self.proj(hidden_states)
349
+ hidden_states, gate = hidden_states.chunk(2, dim=-1)
350
+ return hidden_states * self.gelu(gate)
351
+
352
+
353
+ class ApproximateGELU(nn.Module):
354
+ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
355
+ super().__init__()
356
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
357
+
358
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
359
+ x = self.proj(x)
360
+ return x * torch.sigmoid(1.702 * x)
361
+
362
+
363
+ # disable in checkpoint mode
364
+ # @torch.jit.script
365
+ def snake_beta(x, alpha, beta):
366
+ return x + beta * torch.sin(x * alpha).pow(2)
367
+
368
+
369
+ class Snake(nn.Module):
370
+ def __init__(self, dim_in, dim_out, bias, alpha_trainable=True):
371
+ super().__init__()
372
+ self.proj = nn.Linear(dim_in, dim_out, bias=bias)
373
+ self.alpha = nn.Parameter(torch.ones(1, 1, dim_out))
374
+ self.beta = nn.Parameter(torch.ones(1, 1, dim_out))
375
+ self.alpha.requires_grad = alpha_trainable
376
+ self.beta.requires_grad = alpha_trainable
377
+
378
+ def forward(self, x):
379
+ x = self.proj(x)
380
+ x = snake_beta(x, self.alpha, self.beta)
381
+ return x
382
+
383
+
384
+ class GESnake(nn.Module):
385
+ def __init__(self, dim_in, dim_out, bias, alpha_trainable=True):
386
+ super().__init__()
387
+ self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
388
+ self.alpha = nn.Parameter(torch.ones(1, 1, dim_out))
389
+ self.beta = nn.Parameter(torch.ones(1, 1, dim_out))
390
+ self.alpha.requires_grad = alpha_trainable
391
+ self.beta.requires_grad = alpha_trainable
392
+
393
+ def forward(self, x):
394
+ x = self.proj(x)
395
+ x, gate = x.chunk(2, dim=-1)
396
+ return x * snake_beta(gate, self.alpha, self.beta)
397
+
398
+
399
+ class FeedForward(nn.Module):
400
+ def __init__(
401
+ self,
402
+ dim,
403
+ dim_out=None,
404
+ mult=4,
405
+ dropout=0.0,
406
+ activation_fn="geglu",
407
+ final_dropout=False,
408
+ inner_dim=None,
409
+ bias=True,
410
+ ):
411
+ super().__init__()
412
+ if inner_dim is None:
413
+ inner_dim = int(dim * mult)
414
+ dim_out = dim_out if dim_out is not None else dim
415
+
416
+ if activation_fn == "gelu":
417
+ act_fn = GELU(dim, inner_dim, bias=bias)
418
+ elif activation_fn == "gelu-approximate":
419
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
420
+ elif activation_fn == "geglu":
421
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
422
+ elif activation_fn == "geglu-approximate":
423
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
424
+ elif activation_fn == "snake":
425
+ act_fn = Snake(dim, inner_dim, bias=bias)
426
+ elif activation_fn == "gesnake":
427
+ act_fn = GESnake(dim, inner_dim, bias=bias)
428
+ else:
429
+ raise NotImplementedError
430
+
431
+ self.net = nn.ModuleList([])
432
+ # project in
433
+ self.net.append(act_fn)
434
+ # project dropout
435
+ self.net.append(nn.Dropout(dropout))
436
+ # project out
437
+ self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
438
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
439
+ if final_dropout:
440
+ self.net.append(nn.Dropout(dropout))
441
+
442
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
443
+ for module in self.net:
444
+ hidden_states = module(hidden_states)
445
+ return hidden_states
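A small usage sketch for the building blocks above (dummy shapes are illustrative; the import path follows this upload's file layout):

import torch
from models.dit.modules import PatchEmbed, FeedForward, TimestepEmbedder

latent = torch.randn(2, 8, 256)      # [B, C, T] audio latent
patches = PatchEmbed(
    patch_size=1, in_chans=8, embed_dim=768, input_type='1d'
)(latent)
print(patches.shape)                 # torch.Size([2, 256, 768])

ff = FeedForward(dim=768, mult=4, activation_fn="geglu")
print(ff(patches).shape)             # torch.Size([2, 256, 768])

t_emb = TimestepEmbedder(hidden_size=768)(torch.tensor([10, 500]))
print(t_emb.shape)                   # torch.Size([2, 768])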
models/dit/rotary.py ADDED
@@ -0,0 +1,88 @@
1
+ import torch
2
+ # This RoPE implementation is faster than the LLaMA RoPE with jit script.
3
+
4
+
5
+ def rotate_half(x):
6
+ x1, x2 = x.chunk(2, dim=-1)
7
+ return torch.cat((-x2, x1), dim=-1)
8
+
9
+
10
+ # disable in checkpoint mode
11
+ # @torch.jit.script
12
+ def apply_rotary_pos_emb(x, cos, sin):
13
+ # NOTE: This could probably be moved to Triton
14
+ # Handle a possible sequence length mismatch in between q and k
15
+ cos = cos[:, :, :x.shape[-2], :]
16
+ sin = sin[:, :, :x.shape[-2], :]
17
+ return (x*cos) + (rotate_half(x) * sin)
18
+
19
+
20
+ class RotaryEmbedding(torch.nn.Module):
21
+ """
22
+ The rotary position embeddings from RoFormer_ (Su et al.).
23
+ A crucial insight from the method is that the query and keys are
24
+ transformed by rotation matrices which depend on the relative positions.
25
+
26
+ Other implementations are available in the Rotary Transformer repo_ and in
27
+ GPT-NeoX_; GPT-NeoX was an inspiration for this one.
28
+
29
+ .. _RoFormer: https://arxiv.org/abs/2104.09864
30
+ .. _repo: https://github.com/ZhuiyiTechnology/roformer
31
+ .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox
32
+
33
+
34
+ .. warning: Please note that this embedding is not registered on purpose, as it is transformative
35
+ (it does not create the embedding dimension) and will likely be picked up (imported) on an ad-hoc basis
36
+ """
37
+ def __init__(self, dim: int):
38
+ super().__init__()
39
+ # Generate and save the inverse frequency buffer (non trainable)
40
+ inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2).float() / dim))
41
+ self.register_buffer("inv_freq", inv_freq)
42
+ self._seq_len_cached = None
43
+ self._cos_cached = None
44
+ self._sin_cached = None
45
+
46
+ def _update_cos_sin_tables(self, x, seq_dimension=-2):
47
+ # expect input: B, H, L, D
48
+ seq_len = x.shape[seq_dimension]
49
+
50
+ # Reset the tables if the sequence length has changed,
51
+ # or if we're on a new device (possibly due to tracing for instance)
52
+ # also make sure dtype wont change
53
+ if (
54
+ seq_len != self._seq_len_cached or
55
+ self._cos_cached.device != x.device or
56
+ self._cos_cached.dtype != x.dtype
57
+ ):
58
+ self._seq_len_cached = seq_len
59
+ t = torch.arange(
60
+ x.shape[seq_dimension], device=x.device, dtype=torch.float32
61
+ )
62
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype))
63
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
64
+
65
+ self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
66
+ self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)
67
+
68
+ return self._cos_cached, self._sin_cached
69
+
70
+ def forward(self, q, k):
71
+ self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
72
+ q.float(), seq_dimension=-2
73
+ )
74
+ if k is not None:
75
+ return (
76
+ apply_rotary_pos_emb(
77
+ q.float(), self._cos_cached, self._sin_cached
78
+ ).type_as(q),
79
+ apply_rotary_pos_emb(
80
+ k.float(), self._cos_cached, self._sin_cached
81
+ ).type_as(k),
82
+ )
83
+ else:
84
+ return (
85
+ apply_rotary_pos_emb(
86
+ q.float(), self._cos_cached, self._sin_cached
87
+ ).type_as(q), None
88
+ )
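A usage sketch for the rotary embedding (dummy shapes are illustrative; the import path follows this upload's file layout). Queries and keys are expected as [batch, heads, length, head_dim]:

import torch
from models.dit.rotary import RotaryEmbedding

rope = RotaryEmbedding(dim=64)       # dim = per-head dimension
q = torch.randn(2, 8, 100, 64)
k = torch.randn(2, 8, 100, 64)
q_rot, k_rot = rope(q, k)            # same shapes, rotated by position
print(q_rot.shape, k_rot.shape)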
models/dit/span_mask.py ADDED
@@ -0,0 +1,149 @@
1
+ import numpy as np
2
+ import torch
3
+ from typing import Optional, Tuple
4
+
5
+
6
+ def compute_mask_indices(
7
+ shape: Tuple[int, int],
8
+ padding_mask: Optional[torch.Tensor],
9
+ mask_prob: float,
10
+ mask_length: int,
11
+ mask_type: str = "static",
12
+ mask_other: float = 0.0,
13
+ min_masks: int = 0,
14
+ no_overlap: bool = False,
15
+ min_space: int = 0,
16
+ ) -> np.ndarray:
17
+ """
18
+ Computes random mask spans for a given shape
19
+
20
+ Args:
21
+ shape: the shape for which to compute masks.
22
+ should be of size 2 where first element is batch size and 2nd is timesteps
23
+ padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
24
+ mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
25
+ number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
26
+ however due to overlaps, the actual number will be smaller (unless no_overlap is True)
27
+ mask_type: how to compute mask lengths
28
+ static = fixed size
29
+ uniform = sample from uniform distribution [mask_other, mask_length*2]
30
+ normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
31
+ poisson = sample from a Poisson distribution with lambda = mask_length
32
+ min_masks: minimum number of masked spans
33
+ no_overlap: if true, will use an alternative recursive algorithm that prevents spans from overlapping
34
+ min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
35
+ """
36
+
37
+ bsz, all_sz = shape
38
+ mask = np.full((bsz, all_sz), False)
39
+
40
+ # Convert mask_prob to a NumPy array
41
+ mask_prob = np.array(mask_prob)
42
+
43
+ # Calculate all_num_mask for each element in the batch
44
+ all_num_mask = np.floor(
45
+ mask_prob * all_sz / float(mask_length) + np.random.rand(bsz)
46
+ ).astype(int)
47
+
48
+ # Apply the max operation with min_masks for each element
49
+ all_num_mask = np.maximum(min_masks, all_num_mask)
50
+
51
+ mask_idcs = []
52
+ for i in range(bsz):
53
+ if padding_mask is not None:
54
+ sz = all_sz - padding_mask[i].long().sum().item()
55
+ num_mask = int(
56
+ # add a random number for probabilistic rounding
57
+ mask_prob * sz / float(mask_length) + np.random.rand()
58
+ )
59
+ num_mask = max(min_masks, num_mask)
60
+ else:
61
+ sz = all_sz
62
+ num_mask = all_num_mask[i]
63
+
64
+ if mask_type == "static":
65
+ lengths = np.full(num_mask, mask_length)
66
+ elif mask_type == "uniform":
67
+ lengths = np.random.randint(
68
+ mask_other, mask_length*2 + 1, size=num_mask
69
+ )
70
+ elif mask_type == "normal":
71
+ lengths = np.random.normal(mask_length, mask_other, size=num_mask)
72
+ lengths = [max(1, int(round(x))) for x in lengths]
73
+ elif mask_type == "poisson":
74
+ lengths = np.random.poisson(mask_length, size=num_mask)
75
+ lengths = [int(round(x)) for x in lengths]
76
+ else:
77
+ raise Exception("unknown mask selection " + mask_type)
78
+
79
+ if sum(lengths) == 0:
80
+ lengths[0] = min(mask_length, sz - 1)
81
+
82
+ if no_overlap:
83
+ mask_idc = []
84
+
85
+ def arrange(s, e, length, keep_length):
86
+ span_start = np.random.randint(s, e - length)
87
+ mask_idc.extend(span_start + i for i in range(length))
88
+
89
+ new_parts = []
90
+ if span_start - s - min_space >= keep_length:
91
+ new_parts.append((s, span_start - min_space + 1))
92
+ if e - span_start - keep_length - min_space > keep_length:
93
+ new_parts.append((span_start + length + min_space, e))
94
+ return new_parts
95
+
96
+ parts = [(0, sz)]
97
+ min_length = min(lengths)
98
+ for length in sorted(lengths, reverse=True):
99
+ lens = np.fromiter(
100
+ (
101
+ e - s if e - s >= length + min_space else 0
102
+ for s, e in parts
103
+ ),
104
+ int,  # np.int was removed in NumPy >= 1.24; the builtin int is equivalent here
105
+ )
106
+ l_sum = np.sum(lens)
107
+ if l_sum == 0:
108
+ break
109
+ probs = lens / np.sum(lens)
110
+ c = np.random.choice(len(parts), p=probs)
111
+ s, e = parts.pop(c)
112
+ parts.extend(arrange(s, e, length, min_length))
113
+ mask_idc = np.asarray(mask_idc)
114
+ else:
115
+ min_len = min(lengths)
116
+ if sz - min_len <= num_mask:
117
+ min_len = sz - num_mask - 1
118
+
119
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
120
+
121
+ mask_idc = np.asarray([
122
+ mask_idc[j] + offset for j in range(len(mask_idc))
123
+ for offset in range(lengths[j])
124
+ ])
125
+
126
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
127
+ # min_len = min([len(m) for m in mask_idcs])
128
+ for i, mask_idc in enumerate(mask_idcs):
129
+ # if len(mask_idc) > min_len:
130
+ # mask_idc = np.random.choice(mask_idc, min_len, replace=False)
131
+ mask[i, mask_idc] = True
132
+
133
+ return torch.tensor(mask)
134
+
135
+
136
+ if __name__ == '__main__':
137
+ mask = compute_mask_indices(
138
+ shape=[4, 500],
139
+ padding_mask=None,
140
+ mask_prob=[0.65, 0.5, 0.65, 0.65],
141
+ mask_length=10,
142
+ mask_type="static",
143
+ mask_other=0.0,
144
+ min_masks=1,
145
+ no_overlap=False,
146
+ min_space=0,
147
+ )
148
+ print(mask)
149
+ print(mask.sum(dim=1))