ChristophSchuhmann committed
Commit dfd1909 · verified · 1 Parent(s): 193cecb

Add model code, inference script, and examples

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +4 -0
  2. .gitignore +8 -0
  3. Assets/ExampleInput/music.wav +3 -0
  4. Assets/ExampleInput/soundeffects.wav +3 -0
  5. Assets/ExampleInput/speech.wav +3 -0
  6. Assets/Figure.png +3 -0
  7. FlashSR/AudioSR/AudioSRUnet.py +1127 -0
  8. FlashSR/AudioSR/EncoderDecoder.py +1010 -0
  9. FlashSR/AudioSR/Vocoder.py +167 -0
  10. FlashSR/AudioSR/args/mel_argument.yaml +6 -0
  11. FlashSR/AudioSR/args/model_argument.yaml +25 -0
  12. FlashSR/AudioSR/autoencoder.py +370 -0
  13. FlashSR/AudioSR/hifigan/LICENSE +21 -0
  14. FlashSR/AudioSR/hifigan/__init__.py +8 -0
  15. FlashSR/AudioSR/hifigan/models.py +174 -0
  16. FlashSR/AudioSR/hifigan/models_v2.py +395 -0
  17. FlashSR/AudioSR/latent_diffusion/__init__.py +0 -0
  18. FlashSR/AudioSR/latent_diffusion/modules/attention.py +467 -0
  19. FlashSR/AudioSR/latent_diffusion/modules/audiomae/AudioMAE.py +149 -0
  20. FlashSR/AudioSR/latent_diffusion/modules/audiomae/__init__.py +0 -0
  21. FlashSR/AudioSR/latent_diffusion/modules/audiomae/models_mae.py +613 -0
  22. FlashSR/AudioSR/latent_diffusion/modules/audiomae/models_vit.py +243 -0
  23. FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/crop.py +43 -0
  24. FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/datasets.py +67 -0
  25. FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/lars.py +60 -0
  26. FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/lr_decay.py +76 -0
  27. FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/lr_sched.py +28 -0
  28. FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/misc.py +453 -0
  29. FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/patch_embed.py +127 -0
  30. FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/pos_embed.py +206 -0
  31. FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/stat.py +76 -0
  32. FlashSR/AudioSR/latent_diffusion/modules/diffusionmodules/__init__.py +0 -0
  33. FlashSR/AudioSR/latent_diffusion/modules/diffusionmodules/model.py +1069 -0
  34. FlashSR/AudioSR/latent_diffusion/modules/diffusionmodules/openaimodel.py +1103 -0
  35. FlashSR/AudioSR/latent_diffusion/modules/diffusionmodules/util.py +294 -0
  36. FlashSR/AudioSR/latent_diffusion/modules/distributions/__init__.py +0 -0
  37. FlashSR/AudioSR/latent_diffusion/modules/distributions/distributions.py +102 -0
  38. FlashSR/AudioSR/latent_diffusion/modules/ema.py +82 -0
  39. FlashSR/AudioSR/latent_diffusion/modules/encoders/__init__.py +0 -0
  40. FlashSR/AudioSR/latent_diffusion/modules/encoders/modules.py +682 -0
  41. FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/__init__.py +0 -0
  42. FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/attentions.py +430 -0
  43. FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/commons.py +161 -0
  44. FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/encoder.py +50 -0
  45. FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/LICENSE +19 -0
  46. FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/__init__.py +52 -0
  47. FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/cleaners.py +110 -0
  48. FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/symbols.py +16 -0
  49. FlashSR/AudioSR/latent_diffusion/util.py +267 -0
  50. FlashSR/BigVGAN/LICENSE +21 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ Assets/ExampleInput/music.wav filter=lfs diff=lfs merge=lfs -text
+ Assets/ExampleInput/soundeffects.wav filter=lfs diff=lfs merge=lfs -text
+ Assets/ExampleInput/speech.wav filter=lfs diff=lfs merge=lfs -text
+ Assets/Figure.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,8 @@
+ __pycache__/
+ *.pyc
+ *.pyo
+ .eggs/
+ *.egg-info/
+ dist/
+ build/
+ .env
Assets/ExampleInput/music.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:57031bac5cdc14b2be1c50811eca257070bf1fa9fb98d2aebeb7eda86c67ceaa
+ size 491564
Assets/ExampleInput/soundeffects.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:96b31750ddb93397789224ddc047ee21cb34760f536818e9ea4cd2328d8dfe69
+ size 491564
Assets/ExampleInput/speech.wav ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:046907c4a7a5bcb9f1b1bb49623fdc4d01be215d6c2783ff5157150938ed21a2
+ size 491564
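
The three WAV files above are checked in as Git LFS pointers for the bundled example inputs (music, sound effects, speech). As a rough sketch of how they might be loaded for inference — the inference script itself is not among the 50 files shown in this view, so the model call below is a hypothetical placeholder — something like the following would work after running git lfs pull:

import torchaudio

# Requires the LFS objects to be fetched first (git lfs pull); otherwise only
# the small pointer files shown above exist on disk.
waveform, sample_rate = torchaudio.load("Assets/ExampleInput/speech.wav")
print(waveform.shape, sample_rate)  # (channels, num_samples) and the file's sample rate

# Hypothetical call into the FlashSR model added by this commit; the actual
# entry point lives in the inference script, which is outside this 50-file view.
# upsampled = flashsr(waveform)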
Assets/Figure.png ADDED

Git LFS Details

  • SHA256: 70639d9a3d6fd61ac73b97104cc082e816a81b4b211ac9092cfee557d8aff6c7
  • Pointer size: 132 Bytes
  • Size of remote file: 1.09 MB
FlashSR/AudioSR/AudioSRUnet.py ADDED
@@ -0,0 +1,1127 @@
1
+ '''
2
+ from TorchJaekwon.Util.Util import Util
3
+ Util.set_sys_path_to_parent_dir(__file__,3)
4
+ import sys, os
5
+ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
6
+ '''
7
+ ################################################################################
8
+ # copied from audiosr/latent_diffusion/modules/diffusionmodules/openaimodel.py
9
+ from abc import abstractmethod
10
+ import math
11
+
12
+ import numpy as np
13
+ import torch as th
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+
17
+ from FlashSR.AudioSR.latent_diffusion.modules.diffusionmodules.util import (
18
+ checkpoint,
19
+ conv_nd,
20
+ linear,
21
+ avg_pool_nd,
22
+ zero_module,
23
+ normalization,
24
+ timestep_embedding,
25
+ )
26
+ from FlashSR.AudioSR.latent_diffusion.modules.attention import SpatialTransformer
27
+
28
+
29
+ # dummy replace
30
+ def convert_module_to_f16(x):
31
+ pass
32
+
33
+
34
+ def convert_module_to_f32(x):
35
+ pass
36
+
37
+
38
+ ## go
39
+ class AttentionPool2d(nn.Module):
40
+ """
41
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ spacial_dim: int,
47
+ embed_dim: int,
48
+ num_heads_channels: int,
49
+ output_dim: int = None,
50
+ ):
51
+ super().__init__()
52
+ self.positional_embedding = nn.Parameter(
53
+ th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
54
+ )
55
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
56
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
57
+ self.num_heads = embed_dim // num_heads_channels
58
+ self.attention = QKVAttention(self.num_heads)
59
+
60
+ def forward(self, x):
61
+ b, c, *_spatial = x.shape
62
+ x = x.reshape(b, c, -1).contiguous() # NC(HW)
63
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
64
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
65
+ x = self.qkv_proj(x)
66
+ x = self.attention(x)
67
+ x = self.c_proj(x)
68
+ return x[:, :, 0]
69
+
70
+
71
+ class TimestepBlock(nn.Module):
72
+ """
73
+ Any module where forward() takes timestep embeddings as a second argument.
74
+ """
75
+
76
+ @abstractmethod
77
+ def forward(self, x, emb):
78
+ """
79
+ Apply the module to `x` given `emb` timestep embeddings.
80
+ """
81
+
82
+
83
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
84
+ """
85
+ A sequential module that passes timestep embeddings to the children that
86
+ support it as an extra input.
87
+ """
88
+
89
+ def forward(self, x, emb, context_list=None, mask_list=None):
90
+ # The first spatial transformer block does not have context
91
+ spatial_transformer_id = 0
92
+ context_list = [None] + context_list
93
+ mask_list = [None] + mask_list
94
+
95
+ for layer in self:
96
+ if isinstance(layer, TimestepBlock):
97
+ x = layer(x, emb)
98
+ elif isinstance(layer, SpatialTransformer):
99
+ if spatial_transformer_id >= len(context_list):
100
+ context, mask = None, None
101
+ else:
102
+ context, mask = (
103
+ context_list[spatial_transformer_id],
104
+ mask_list[spatial_transformer_id],
105
+ )
106
+
107
+ x = layer(x, context, mask=mask)
108
+ spatial_transformer_id += 1
109
+ else:
110
+ x = layer(x)
111
+ return x
112
+
113
+
114
+ class Upsample(nn.Module):
115
+ """
116
+ An upsampling layer with an optional convolution.
117
+ :param channels: channels in the inputs and outputs.
118
+ :param use_conv: a bool determining if a convolution is applied.
119
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
120
+ upsampling occurs in the inner-two dimensions.
121
+ """
122
+
123
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
124
+ super().__init__()
125
+ self.channels = channels
126
+ self.out_channels = out_channels or channels
127
+ self.use_conv = use_conv
128
+ self.dims = dims
129
+ if use_conv:
130
+ self.conv = conv_nd(
131
+ dims, self.channels, self.out_channels, 3, padding=padding
132
+ )
133
+
134
+ def forward(self, x):
135
+ assert x.shape[1] == self.channels
136
+ if self.dims == 3:
137
+ x = F.interpolate(
138
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
139
+ )
140
+ else:
141
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
142
+ if self.use_conv:
143
+ x = self.conv(x)
144
+ return x
145
+
146
+
147
+ class TransposedUpsample(nn.Module):
148
+ "Learned 2x upsampling without padding"
149
+
150
+ def __init__(self, channels, out_channels=None, ks=5):
151
+ super().__init__()
152
+ self.channels = channels
153
+ self.out_channels = out_channels or channels
154
+
155
+ self.up = nn.ConvTranspose2d(
156
+ self.channels, self.out_channels, kernel_size=ks, stride=2
157
+ )
158
+
159
+ def forward(self, x):
160
+ return self.up(x)
161
+
162
+
163
+ class Downsample(nn.Module):
164
+ """
165
+ A downsampling layer with an optional convolution.
166
+ :param channels: channels in the inputs and outputs.
167
+ :param use_conv: a bool determining if a convolution is applied.
168
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
169
+ downsampling occurs in the inner-two dimensions.
170
+ """
171
+
172
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
173
+ super().__init__()
174
+ self.channels = channels
175
+ self.out_channels = out_channels or channels
176
+ self.use_conv = use_conv
177
+ self.dims = dims
178
+ stride = 2 if dims != 3 else (1, 2, 2)
179
+ if use_conv:
180
+ self.op = conv_nd(
181
+ dims,
182
+ self.channels,
183
+ self.out_channels,
184
+ 3,
185
+ stride=stride,
186
+ padding=padding,
187
+ )
188
+ else:
189
+ assert self.channels == self.out_channels
190
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
191
+
192
+ def forward(self, x):
193
+ assert x.shape[1] == self.channels
194
+ return self.op(x)
195
+
196
+
197
+ class ResBlock(TimestepBlock):
198
+ """
199
+ A residual block that can optionally change the number of channels.
200
+ :param channels: the number of input channels.
201
+ :param emb_channels: the number of timestep embedding channels.
202
+ :param dropout: the rate of dropout.
203
+ :param out_channels: if specified, the number of out channels.
204
+ :param use_conv: if True and out_channels is specified, use a spatial
205
+ convolution instead of a smaller 1x1 convolution to change the
206
+ channels in the skip connection.
207
+ :param dims: determines if the signal is 1D, 2D, or 3D.
208
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
209
+ :param up: if True, use this block for upsampling.
210
+ :param down: if True, use this block for downsampling.
211
+ """
212
+
213
+ def __init__(
214
+ self,
215
+ channels,
216
+ emb_channels,
217
+ dropout,
218
+ out_channels=None,
219
+ use_conv=False,
220
+ use_scale_shift_norm=False,
221
+ dims=2,
222
+ use_checkpoint=False,
223
+ up=False,
224
+ down=False,
225
+ ):
226
+ super().__init__()
227
+ self.channels = channels
228
+ self.emb_channels = emb_channels
229
+ self.dropout = dropout
230
+ self.out_channels = out_channels or channels
231
+ self.use_conv = use_conv
232
+ self.use_checkpoint = use_checkpoint
233
+ self.use_scale_shift_norm = use_scale_shift_norm
234
+
235
+ self.in_layers = nn.Sequential(
236
+ normalization(channels),
237
+ nn.SiLU(),
238
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
239
+ )
240
+
241
+ self.updown = up or down
242
+
243
+ if up:
244
+ self.h_upd = Upsample(channels, False, dims)
245
+ self.x_upd = Upsample(channels, False, dims)
246
+ elif down:
247
+ self.h_upd = Downsample(channels, False, dims)
248
+ self.x_upd = Downsample(channels, False, dims)
249
+ else:
250
+ self.h_upd = self.x_upd = nn.Identity()
251
+
252
+ self.emb_layers = nn.Sequential(
253
+ nn.SiLU(),
254
+ linear(
255
+ emb_channels,
256
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
257
+ ),
258
+ )
259
+ self.out_layers = nn.Sequential(
260
+ normalization(self.out_channels),
261
+ nn.SiLU(),
262
+ nn.Dropout(p=dropout),
263
+ zero_module(
264
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
265
+ ),
266
+ )
267
+
268
+ if self.out_channels == channels:
269
+ self.skip_connection = nn.Identity()
270
+ elif use_conv:
271
+ self.skip_connection = conv_nd(
272
+ dims, channels, self.out_channels, 3, padding=1
273
+ )
274
+ else:
275
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
276
+
277
+ def forward(self, x, emb):
278
+ """
279
+ Apply the block to a Tensor, conditioned on a timestep embedding.
280
+ :param x: an [N x C x ...] Tensor of features.
281
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
282
+ :return: an [N x C x ...] Tensor of outputs.
283
+ """
284
+ return checkpoint(
285
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
286
+ )
287
+
288
+ def _forward(self, x, emb):
289
+ if self.updown:
290
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
291
+ h = in_rest(x)
292
+ h = self.h_upd(h)
293
+ x = self.x_upd(x)
294
+ h = in_conv(h)
295
+ else:
296
+ h = self.in_layers(x)
297
+ emb_out = self.emb_layers(emb).type(h.dtype)
298
+ while len(emb_out.shape) < len(h.shape):
299
+ emb_out = emb_out[..., None]
300
+ if self.use_scale_shift_norm:
301
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
302
+ scale, shift = th.chunk(emb_out, 2, dim=1)
303
+ h = out_norm(h) * (1 + scale) + shift
304
+ h = out_rest(h)
305
+ else:
306
+ h = h + emb_out
307
+ h = self.out_layers(h)
308
+ return self.skip_connection(x) + h
309
+
310
+
311
+ class AttentionBlock(nn.Module):
312
+ """
313
+ An attention block that allows spatial positions to attend to each other.
314
+ Originally ported from here, but adapted to the N-d case.
315
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
316
+ """
317
+
318
+ def __init__(
319
+ self,
320
+ channels,
321
+ num_heads=1,
322
+ num_head_channels=-1,
323
+ use_checkpoint=False,
324
+ use_new_attention_order=False,
325
+ ):
326
+ super().__init__()
327
+ self.channels = channels
328
+ if num_head_channels == -1:
329
+ self.num_heads = num_heads
330
+ else:
331
+ assert (
332
+ channels % num_head_channels == 0
333
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
334
+ self.num_heads = channels // num_head_channels
335
+ self.use_checkpoint = use_checkpoint
336
+ self.norm = normalization(channels)
337
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
338
+ if use_new_attention_order:
339
+ # split qkv before split heads
340
+ self.attention = QKVAttention(self.num_heads)
341
+ else:
342
+ # split heads before split qkv
343
+ self.attention = QKVAttentionLegacy(self.num_heads)
344
+
345
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
346
+
347
+ def forward(self, x):
348
+ return checkpoint(
349
+ self._forward, (x,), self.parameters(), True
350
+ ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
351
+ # return pt_checkpoint(self._forward, x) # pytorch
352
+
353
+ def _forward(self, x):
354
+ b, c, *spatial = x.shape
355
+ x = x.reshape(b, c, -1).contiguous()
356
+ qkv = self.qkv(self.norm(x)).contiguous()
357
+ h = self.attention(qkv).contiguous()
358
+ h = self.proj_out(h).contiguous()
359
+ return (x + h).reshape(b, c, *spatial).contiguous()
360
+
361
+
362
+ def count_flops_attn(model, _x, y):
363
+ """
364
+ A counter for the `thop` package to count the operations in an
365
+ attention operation.
366
+ Meant to be used like:
367
+ macs, params = thop.profile(
368
+ model,
369
+ inputs=(inputs, timestamps),
370
+ custom_ops={QKVAttention: QKVAttention.count_flops},
371
+ )
372
+ """
373
+ b, c, *spatial = y[0].shape
374
+ num_spatial = int(np.prod(spatial))
375
+ # We perform two matmuls with the same number of ops.
376
+ # The first computes the weight matrix, the second computes
377
+ # the combination of the value vectors.
378
+ matmul_ops = 2 * b * (num_spatial**2) * c
379
+ model.total_ops += th.DoubleTensor([matmul_ops])
380
+
381
+
382
+ class QKVAttentionLegacy(nn.Module):
383
+ """
384
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
385
+ """
386
+
387
+ def __init__(self, n_heads):
388
+ super().__init__()
389
+ self.n_heads = n_heads
390
+
391
+ def forward(self, qkv):
392
+ """
393
+ Apply QKV attention.
394
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
395
+ :return: an [N x (H * C) x T] tensor after attention.
396
+ """
397
+ bs, width, length = qkv.shape
398
+ assert width % (3 * self.n_heads) == 0
399
+ ch = width // (3 * self.n_heads)
400
+ q, k, v = (
401
+ qkv.reshape(bs * self.n_heads, ch * 3, length).contiguous().split(ch, dim=1)
402
+ )
403
+ scale = 1 / math.sqrt(math.sqrt(ch))
404
+ weight = th.einsum(
405
+ "bct,bcs->bts", q * scale, k * scale
406
+ ) # More stable with f16 than dividing afterwards
407
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
408
+ a = th.einsum("bts,bcs->bct", weight, v)
409
+ return a.reshape(bs, -1, length).contiguous()
410
+
411
+ @staticmethod
412
+ def count_flops(model, _x, y):
413
+ return count_flops_attn(model, _x, y)
414
+
415
+
416
+ class QKVAttention(nn.Module):
417
+ """
418
+ A module which performs QKV attention and splits in a different order.
419
+ """
420
+
421
+ def __init__(self, n_heads):
422
+ super().__init__()
423
+ self.n_heads = n_heads
424
+
425
+ def forward(self, qkv):
426
+ """
427
+ Apply QKV attention.
428
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
429
+ :return: an [N x (H * C) x T] tensor after attention.
430
+ """
431
+ bs, width, length = qkv.shape
432
+ assert width % (3 * self.n_heads) == 0
433
+ ch = width // (3 * self.n_heads)
434
+ q, k, v = qkv.chunk(3, dim=1)
435
+ scale = 1 / math.sqrt(math.sqrt(ch))
436
+ weight = th.einsum(
437
+ "bct,bcs->bts",
438
+ (q * scale).view(bs * self.n_heads, ch, length),
439
+ (k * scale).view(bs * self.n_heads, ch, length),
440
+ ) # More stable with f16 than dividing afterwards
441
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
442
+ a = th.einsum(
443
+ "bts,bcs->bct",
444
+ weight,
445
+ v.reshape(bs * self.n_heads, ch, length).contiguous(),
446
+ )
447
+ return a.reshape(bs, -1, length).contiguous()
448
+
449
+ @staticmethod
450
+ def count_flops(model, _x, y):
451
+ return count_flops_attn(model, _x, y)
452
+
453
+
454
+ class AudioSRUnet(nn.Module):
455
+ """
456
+ The full UNet model with attention and timestep embedding.
457
+ :param in_channels: channels in the input Tensor.
458
+ :param model_channels: base channel count for the model.
459
+ :param out_channels: channels in the output Tensor.
460
+ :param num_res_blocks: number of residual blocks per downsample.
461
+ :param attention_resolutions: a collection of downsample rates at which
462
+ attention will take place. May be a set, list, or tuple.
463
+ For example, if this contains 4, then at 4x downsampling, attention
464
+ will be used.
465
+ :param dropout: the dropout probability.
466
+ :param channel_mult: channel multiplier for each level of the UNet.
467
+ :param conv_resample: if True, use learned convolutions for upsampling and
468
+ downsampling.
469
+ :param dims: determines if the signal is 1D, 2D, or 3D.
470
+ :param num_classes: if specified (as an int), then this model will be
471
+ class-conditional with `num_classes` classes.
472
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
473
+ :param num_heads: the number of attention heads in each attention layer.
474
+ :param num_heads_channels: if specified, ignore num_heads and instead use
475
+ a fixed channel width per attention head.
476
+ :param num_heads_upsample: works with num_heads to set a different number
477
+ of heads for upsampling. Deprecated.
478
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
479
+ :param resblock_updown: use residual blocks for up/downsampling.
480
+ :param use_new_attention_order: use a different attention pattern for potentially
481
+ increased efficiency.
482
+ """
483
+
484
+ def __init__(
485
+ self,
486
+ image_size:int = 64,
487
+ in_channels:int = 32,
488
+ model_channels:int = 128,
489
+ out_channels:int = 16,
490
+ num_res_blocks:int = 2,
491
+ attention_resolutions:list = [8, 4, 2],
492
+ dropout=0,
493
+ channel_mult=[1, 2, 3, 5],
494
+ conv_resample=True,
495
+ dims=2,
496
+ extra_sa_layer=True,
497
+ num_classes=None,
498
+ extra_film_condition_dim=None,
499
+ use_checkpoint=False,
500
+ use_fp16=False,
501
+ num_heads=-1,
502
+ num_head_channels=32,
503
+ num_heads_upsample=-1,
504
+ use_scale_shift_norm=False,
505
+ resblock_updown=False,
506
+ use_new_attention_order=False,
507
+ use_spatial_transformer=True, # custom transformer support
508
+ transformer_depth=1, # custom transformer support
509
+ context_dim=None, # custom transformer support
510
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
511
+ legacy=True,
512
+ ):
513
+ super().__init__()
514
+ if num_heads_upsample == -1:
515
+ num_heads_upsample = num_heads
516
+
517
+ if num_heads == -1:
518
+ assert (
519
+ num_head_channels != -1
520
+ ), "Either num_heads or num_head_channels has to be set"
521
+
522
+ if num_head_channels == -1:
523
+ assert (
524
+ num_heads != -1
525
+ ), "Either num_heads or num_head_channels has to be set"
526
+
527
+ self.image_size = image_size
528
+ self.in_channels = in_channels
529
+ self.model_channels = model_channels
530
+ self.out_channels = out_channels
531
+ self.num_res_blocks = num_res_blocks
532
+ self.attention_resolutions = attention_resolutions
533
+ self.dropout = dropout
534
+ self.channel_mult = channel_mult
535
+ self.conv_resample = conv_resample
536
+ self.num_classes = num_classes
537
+ self.extra_film_condition_dim = extra_film_condition_dim
538
+ self.use_checkpoint = use_checkpoint
539
+ self.dtype = th.float16 if use_fp16 else th.float32
540
+ self.num_heads = num_heads
541
+ self.num_head_channels = num_head_channels
542
+ self.num_heads_upsample = num_heads_upsample
543
+ self.predict_codebook_ids = n_embed is not None
544
+ time_embed_dim = model_channels * 4
545
+ self.time_embed = nn.Sequential(
546
+ linear(model_channels, time_embed_dim),
547
+ nn.SiLU(),
548
+ linear(time_embed_dim, time_embed_dim),
549
+ )
550
+
551
+ # assert not (
552
+ # self.num_classes is not None and self.extra_film_condition_dim is not None
553
+ # ), "As for the condition of the UNet model, you can only set a class label or an extra embedding vector (such as from CLAP). You cannot set both num_classes and extra_film_condition_dim."
554
+
555
+ if self.num_classes is not None:
556
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
557
+
558
+ self.use_extra_film_by_concat = self.extra_film_condition_dim is not None
559
+
560
+ if self.extra_film_condition_dim is not None:
561
+ self.film_emb = nn.Linear(self.extra_film_condition_dim, time_embed_dim)
562
+ print(
563
+ "+ Use extra condition on UNet channel using Film. Extra condition dimension is %s. "
564
+ % self.extra_film_condition_dim
565
+ )
566
+
567
+ if context_dim is not None and not use_spatial_transformer:
568
+ assert (
569
+ use_spatial_transformer
570
+ ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
571
+
572
+ if context_dim is not None and not isinstance(context_dim, list):
573
+ context_dim = [context_dim]
574
+ elif context_dim is None:
575
+ context_dim = [None] # At least use one spatial transformer
576
+
577
+ self.input_blocks = nn.ModuleList(
578
+ [
579
+ TimestepEmbedSequential(
580
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
581
+ )
582
+ ]
583
+ )
584
+ self._feature_size = model_channels
585
+ input_block_chans = [model_channels]
586
+ ch = model_channels
587
+ ds = 1
588
+ for level, mult in enumerate(channel_mult):
589
+ for _ in range(num_res_blocks):
590
+ layers = [
591
+ ResBlock(
592
+ ch,
593
+ time_embed_dim
594
+ if (not self.use_extra_film_by_concat)
595
+ else time_embed_dim * 2,
596
+ dropout,
597
+ out_channels=mult * model_channels,
598
+ dims=dims,
599
+ use_checkpoint=use_checkpoint,
600
+ use_scale_shift_norm=use_scale_shift_norm,
601
+ )
602
+ ]
603
+ ch = mult * model_channels
604
+ if ds in attention_resolutions:
605
+ if num_head_channels == -1:
606
+ dim_head = ch // num_heads
607
+ else:
608
+ num_heads = ch // num_head_channels
609
+ dim_head = num_head_channels
610
+ if legacy:
611
+ dim_head = (
612
+ ch // num_heads
613
+ if use_spatial_transformer
614
+ else num_head_channels
615
+ )
616
+ if extra_sa_layer:
617
+ layers.append(
618
+ SpatialTransformer(
619
+ ch,
620
+ num_heads,
621
+ dim_head,
622
+ depth=transformer_depth,
623
+ context_dim=None,
624
+ )
625
+ )
626
+ for context_dim_id in range(len(context_dim)):
627
+ layers.append(
628
+ AttentionBlock(
629
+ ch,
630
+ use_checkpoint=use_checkpoint,
631
+ num_heads=num_heads,
632
+ num_head_channels=dim_head,
633
+ use_new_attention_order=use_new_attention_order,
634
+ )
635
+ if not use_spatial_transformer
636
+ else SpatialTransformer(
637
+ ch,
638
+ num_heads,
639
+ dim_head,
640
+ depth=transformer_depth,
641
+ context_dim=context_dim[context_dim_id],
642
+ )
643
+ )
644
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
645
+ self._feature_size += ch
646
+ input_block_chans.append(ch)
647
+ if level != len(channel_mult) - 1:
648
+ out_ch = ch
649
+ self.input_blocks.append(
650
+ TimestepEmbedSequential(
651
+ ResBlock(
652
+ ch,
653
+ time_embed_dim
654
+ if (not self.use_extra_film_by_concat)
655
+ else time_embed_dim * 2,
656
+ dropout,
657
+ out_channels=out_ch,
658
+ dims=dims,
659
+ use_checkpoint=use_checkpoint,
660
+ use_scale_shift_norm=use_scale_shift_norm,
661
+ down=True,
662
+ )
663
+ if resblock_updown
664
+ else Downsample(
665
+ ch, conv_resample, dims=dims, out_channels=out_ch
666
+ )
667
+ )
668
+ )
669
+ ch = out_ch
670
+ input_block_chans.append(ch)
671
+ ds *= 2
672
+ self._feature_size += ch
673
+
674
+ if num_head_channels == -1:
675
+ dim_head = ch // num_heads
676
+ else:
677
+ num_heads = ch // num_head_channels
678
+ dim_head = num_head_channels
679
+ if legacy:
680
+ # num_heads = 1
681
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
682
+ middle_layers = [
683
+ ResBlock(
684
+ ch,
685
+ time_embed_dim
686
+ if (not self.use_extra_film_by_concat)
687
+ else time_embed_dim * 2,
688
+ dropout,
689
+ dims=dims,
690
+ use_checkpoint=use_checkpoint,
691
+ use_scale_shift_norm=use_scale_shift_norm,
692
+ )
693
+ ]
694
+ if extra_sa_layer:
695
+ middle_layers.append(
696
+ SpatialTransformer(
697
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=None
698
+ )
699
+ )
700
+ for context_dim_id in range(len(context_dim)):
701
+ middle_layers.append(
702
+ AttentionBlock(
703
+ ch,
704
+ use_checkpoint=use_checkpoint,
705
+ num_heads=num_heads,
706
+ num_head_channels=dim_head,
707
+ use_new_attention_order=use_new_attention_order,
708
+ )
709
+ if not use_spatial_transformer
710
+ else SpatialTransformer(
711
+ ch,
712
+ num_heads,
713
+ dim_head,
714
+ depth=transformer_depth,
715
+ context_dim=context_dim[context_dim_id],
716
+ )
717
+ )
718
+ middle_layers.append(
719
+ ResBlock(
720
+ ch,
721
+ time_embed_dim
722
+ if (not self.use_extra_film_by_concat)
723
+ else time_embed_dim * 2,
724
+ dropout,
725
+ dims=dims,
726
+ use_checkpoint=use_checkpoint,
727
+ use_scale_shift_norm=use_scale_shift_norm,
728
+ )
729
+ )
730
+ self.middle_block = TimestepEmbedSequential(*middle_layers)
731
+
732
+ self._feature_size += ch
733
+
734
+ self.output_blocks = nn.ModuleList([])
735
+ for level, mult in list(enumerate(channel_mult))[::-1]:
736
+ for i in range(num_res_blocks + 1):
737
+ ich = input_block_chans.pop()
738
+ layers = [
739
+ ResBlock(
740
+ ch + ich,
741
+ time_embed_dim
742
+ if (not self.use_extra_film_by_concat)
743
+ else time_embed_dim * 2,
744
+ dropout,
745
+ out_channels=model_channels * mult,
746
+ dims=dims,
747
+ use_checkpoint=use_checkpoint,
748
+ use_scale_shift_norm=use_scale_shift_norm,
749
+ )
750
+ ]
751
+ ch = model_channels * mult
752
+ if ds in attention_resolutions:
753
+ if num_head_channels == -1:
754
+ dim_head = ch // num_heads
755
+ else:
756
+ num_heads = ch // num_head_channels
757
+ dim_head = num_head_channels
758
+ if legacy:
759
+ # num_heads = 1
760
+ dim_head = (
761
+ ch // num_heads
762
+ if use_spatial_transformer
763
+ else num_head_channels
764
+ )
765
+ if extra_sa_layer:
766
+ layers.append(
767
+ SpatialTransformer(
768
+ ch,
769
+ num_heads,
770
+ dim_head,
771
+ depth=transformer_depth,
772
+ context_dim=None,
773
+ )
774
+ )
775
+ for context_dim_id in range(len(context_dim)):
776
+ layers.append(
777
+ AttentionBlock(
778
+ ch,
779
+ use_checkpoint=use_checkpoint,
780
+ num_heads=num_heads_upsample,
781
+ num_head_channels=dim_head,
782
+ use_new_attention_order=use_new_attention_order,
783
+ )
784
+ if not use_spatial_transformer
785
+ else SpatialTransformer(
786
+ ch,
787
+ num_heads,
788
+ dim_head,
789
+ depth=transformer_depth,
790
+ context_dim=context_dim[context_dim_id],
791
+ )
792
+ )
793
+ if level and i == num_res_blocks:
794
+ out_ch = ch
795
+ layers.append(
796
+ ResBlock(
797
+ ch,
798
+ time_embed_dim
799
+ if (not self.use_extra_film_by_concat)
800
+ else time_embed_dim * 2,
801
+ dropout,
802
+ out_channels=out_ch,
803
+ dims=dims,
804
+ use_checkpoint=use_checkpoint,
805
+ use_scale_shift_norm=use_scale_shift_norm,
806
+ up=True,
807
+ )
808
+ if resblock_updown
809
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
810
+ )
811
+ ds //= 2
812
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
813
+ self._feature_size += ch
814
+
815
+ self.out = nn.Sequential(
816
+ normalization(ch),
817
+ nn.SiLU(),
818
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
819
+ )
820
+ if self.predict_codebook_ids:
821
+ self.id_predictor = nn.Sequential(
822
+ normalization(ch),
823
+ conv_nd(dims, model_channels, n_embed, 1),
824
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
825
+ )
826
+
827
+ self.shape_reported = False
828
+
829
+ def convert_to_fp16(self):
830
+ """
831
+ Convert the torso of the model to float16.
832
+ """
833
+ self.input_blocks.apply(convert_module_to_f16)
834
+ self.middle_block.apply(convert_module_to_f16)
835
+ self.output_blocks.apply(convert_module_to_f16)
836
+
837
+ def convert_to_fp32(self):
838
+ """
839
+ Convert the torso of the model to float32.
840
+ """
841
+ self.input_blocks.apply(convert_module_to_f32)
842
+ self.middle_block.apply(convert_module_to_f32)
843
+ self.output_blocks.apply(convert_module_to_f32)
844
+
845
+ def forward(
846
+ self,
847
+ x,
848
+ timesteps=None,
849
+ y=None,
850
+ context_list=list(),
851
+ context_attn_mask_list=list(),
852
+ **kwargs,
853
+ ):
854
+ """
855
+ Apply the model to an input batch.
856
+ :param x: an [N x C x ...] Tensor of inputs.
857
+ :param timesteps: a 1-D batch of timesteps.
858
+ :param context: conditioning plugged in via crossattn
859
+ :param y: an [N] Tensor of labels, if class-conditional. an [N, extra_film_condition_dim] Tensor if film-embed conditional
860
+ :return: an [N x C x ...] Tensor of outputs.
861
+ """
862
+ x = th.concat([x,y], dim=1) #jakeoneijk added
863
+ y = None
864
+ if not self.shape_reported:
865
+ # print("The shape of UNet input is", x.size())
866
+ self.shape_reported = True
867
+
868
+ assert (y is not None) == (
869
+ self.num_classes is not None or self.extra_film_condition_dim is not None
870
+ ), "must specify y if and only if the model is class-conditional or film embedding conditional"
871
+ hs = []
872
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
873
+ emb = self.time_embed(t_emb)
874
+
875
+ # if self.num_classes is not None:
876
+ # assert y.shape == (x.shape[0],)
877
+ # emb = emb + self.label_emb(y)
878
+
879
+ if self.use_extra_film_by_concat:
880
+ emb = th.cat([emb, self.film_emb(y)], dim=-1)
881
+
882
+ h = x.type(self.dtype)
883
+ for module in self.input_blocks:
884
+ h = module(h, emb, context_list, context_attn_mask_list)
885
+ hs.append(h)
886
+ h = self.middle_block(h, emb, context_list, context_attn_mask_list)
887
+ for module in self.output_blocks:
888
+ concate_tensor = hs.pop()
889
+ h = th.cat([h, concate_tensor], dim=1)
890
+ h = module(h, emb, context_list, context_attn_mask_list)
891
+ h = h.type(x.dtype)
892
+ if self.predict_codebook_ids:
893
+ return self.id_predictor(h)
894
+ else:
895
+ return self.out(h)
896
+
897
+
898
+ class EncoderUNetModel(nn.Module):
899
+ """
900
+ The half UNet model with attention and timestep embedding.
901
+ For usage, see UNet.
902
+ """
903
+
904
+ def __init__(
905
+ self,
906
+ image_size,
907
+ in_channels,
908
+ model_channels,
909
+ out_channels,
910
+ num_res_blocks,
911
+ attention_resolutions,
912
+ dropout=0,
913
+ channel_mult=(1, 2, 4, 8),
914
+ conv_resample=True,
915
+ dims=2,
916
+ use_checkpoint=False,
917
+ use_fp16=False,
918
+ num_heads=1,
919
+ num_head_channels=-1,
920
+ num_heads_upsample=-1,
921
+ use_scale_shift_norm=False,
922
+ resblock_updown=False,
923
+ use_new_attention_order=False,
924
+ pool="adaptive",
925
+ *args,
926
+ **kwargs,
927
+ ):
928
+ super().__init__()
929
+
930
+ if num_heads_upsample == -1:
931
+ num_heads_upsample = num_heads
932
+
933
+ self.in_channels = in_channels
934
+ self.model_channels = model_channels
935
+ self.out_channels = out_channels
936
+ self.num_res_blocks = num_res_blocks
937
+ self.attention_resolutions = attention_resolutions
938
+ self.dropout = dropout
939
+ self.channel_mult = channel_mult
940
+ self.conv_resample = conv_resample
941
+ self.use_checkpoint = use_checkpoint
942
+ self.dtype = th.float16 if use_fp16 else th.float32
943
+ self.num_heads = num_heads
944
+ self.num_head_channels = num_head_channels
945
+ self.num_heads_upsample = num_heads_upsample
946
+
947
+ time_embed_dim = model_channels * 4
948
+ self.time_embed = nn.Sequential(
949
+ linear(model_channels, time_embed_dim),
950
+ nn.SiLU(),
951
+ linear(time_embed_dim, time_embed_dim),
952
+ )
953
+
954
+ self.input_blocks = nn.ModuleList(
955
+ [
956
+ TimestepEmbedSequential(
957
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
958
+ )
959
+ ]
960
+ )
961
+ self._feature_size = model_channels
962
+ input_block_chans = [model_channels]
963
+ ch = model_channels
964
+ ds = 1
965
+ for level, mult in enumerate(channel_mult):
966
+ for _ in range(num_res_blocks):
967
+ layers = [
968
+ ResBlock(
969
+ ch,
970
+ time_embed_dim,
971
+ dropout,
972
+ out_channels=mult * model_channels,
973
+ dims=dims,
974
+ use_checkpoint=use_checkpoint,
975
+ use_scale_shift_norm=use_scale_shift_norm,
976
+ )
977
+ ]
978
+ ch = mult * model_channels
979
+ if ds in attention_resolutions:
980
+ layers.append(
981
+ AttentionBlock(
982
+ ch,
983
+ use_checkpoint=use_checkpoint,
984
+ num_heads=num_heads,
985
+ num_head_channels=num_head_channels,
986
+ use_new_attention_order=use_new_attention_order,
987
+ )
988
+ )
989
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
990
+ self._feature_size += ch
991
+ input_block_chans.append(ch)
992
+ if level != len(channel_mult) - 1:
993
+ out_ch = ch
994
+ self.input_blocks.append(
995
+ TimestepEmbedSequential(
996
+ ResBlock(
997
+ ch,
998
+ time_embed_dim,
999
+ dropout,
1000
+ out_channels=out_ch,
1001
+ dims=dims,
1002
+ use_checkpoint=use_checkpoint,
1003
+ use_scale_shift_norm=use_scale_shift_norm,
1004
+ down=True,
1005
+ )
1006
+ if resblock_updown
1007
+ else Downsample(
1008
+ ch, conv_resample, dims=dims, out_channels=out_ch
1009
+ )
1010
+ )
1011
+ )
1012
+ ch = out_ch
1013
+ input_block_chans.append(ch)
1014
+ ds *= 2
1015
+ self._feature_size += ch
1016
+
1017
+ self.middle_block = TimestepEmbedSequential(
1018
+ ResBlock(
1019
+ ch,
1020
+ time_embed_dim,
1021
+ dropout,
1022
+ dims=dims,
1023
+ use_checkpoint=use_checkpoint,
1024
+ use_scale_shift_norm=use_scale_shift_norm,
1025
+ ),
1026
+ AttentionBlock(
1027
+ ch,
1028
+ use_checkpoint=use_checkpoint,
1029
+ num_heads=num_heads,
1030
+ num_head_channels=num_head_channels,
1031
+ use_new_attention_order=use_new_attention_order,
1032
+ ),
1033
+ ResBlock(
1034
+ ch,
1035
+ time_embed_dim,
1036
+ dropout,
1037
+ dims=dims,
1038
+ use_checkpoint=use_checkpoint,
1039
+ use_scale_shift_norm=use_scale_shift_norm,
1040
+ ),
1041
+ )
1042
+ self._feature_size += ch
1043
+ self.pool = pool
1044
+ if pool == "adaptive":
1045
+ self.out = nn.Sequential(
1046
+ normalization(ch),
1047
+ nn.SiLU(),
1048
+ nn.AdaptiveAvgPool2d((1, 1)),
1049
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
1050
+ nn.Flatten(),
1051
+ )
1052
+ elif pool == "attention":
1053
+ assert num_head_channels != -1
1054
+ self.out = nn.Sequential(
1055
+ normalization(ch),
1056
+ nn.SiLU(),
1057
+ AttentionPool2d(
1058
+ (image_size // ds), ch, num_head_channels, out_channels
1059
+ ),
1060
+ )
1061
+ elif pool == "spatial":
1062
+ self.out = nn.Sequential(
1063
+ nn.Linear(self._feature_size, 2048),
1064
+ nn.ReLU(),
1065
+ nn.Linear(2048, self.out_channels),
1066
+ )
1067
+ elif pool == "spatial_v2":
1068
+ self.out = nn.Sequential(
1069
+ nn.Linear(self._feature_size, 2048),
1070
+ normalization(2048),
1071
+ nn.SiLU(),
1072
+ nn.Linear(2048, self.out_channels),
1073
+ )
1074
+ else:
1075
+ raise NotImplementedError(f"Unexpected {pool} pooling")
1076
+
1077
+ def convert_to_fp16(self):
1078
+ """
1079
+ Convert the torso of the model to float16.
1080
+ """
1081
+ self.input_blocks.apply(convert_module_to_f16)
1082
+ self.middle_block.apply(convert_module_to_f16)
1083
+
1084
+ def convert_to_fp32(self):
1085
+ """
1086
+ Convert the torso of the model to float32.
1087
+ """
1088
+ self.input_blocks.apply(convert_module_to_f32)
1089
+ self.middle_block.apply(convert_module_to_f32)
1090
+
1091
+ def forward(self, x, timesteps):
1092
+ """
1093
+ Apply the model to an input batch.
1094
+ :param x: an [N x C x ...] Tensor of inputs.
1095
+ :param timesteps: a 1-D batch of timesteps.
1096
+ :return: an [N x K] Tensor of outputs.
1097
+ """
1098
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
1099
+
1100
+ results = []
1101
+ h = x.type(self.dtype)
1102
+ for module in self.input_blocks:
1103
+ h = module(h, emb)
1104
+ if self.pool.startswith("spatial"):
1105
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
1106
+ h = self.middle_block(h, emb)
1107
+ if self.pool.startswith("spatial"):
1108
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
1109
+ h = th.cat(results, axis=-1)
1110
+ return self.out(h)
1111
+ else:
1112
+ h = h.type(x.dtype)
1113
+ return self.out(h)
1114
+
1115
+ if __name__ == '__main__':
1116
+ '''
1117
+
1118
+ args = {
1119
+ 'in_channels': 2,
1120
+ 'model_channels': 64,
1121
+ 'out_channels': 1
1122
+ }
1123
+ audio_sr = AudioSRUnet(**args)
1124
+ audio_sr(x = th.randn(1, 1, 128, 256), timesteps = th.tensor([30]), y = th.randn(1, 1, 128, 256)) #jakeoneijk added
1125
+ '''
1126
+ audio_sr = AudioSRUnet()
1127
+ audio_sr(x = th.randn(1, 16, 64, 32), timesteps = th.tensor([30]), y = th.randn(1, 16, 64, 32)) #jakeoneijk added
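
One detail of the listing above worth calling out: unlike the upstream AudioSR UNet, this fork's forward() concatenates the conditioning tensor y with the input x along the channel axis (the lines marked "#jakeoneijk added"), so the default in_channels=32 is effectively 16 noisy-latent channels plus 16 conditioning channels. A minimal smoke test, a sketch mirroring the file's __main__ block (assumes the repository root is on PYTHONPATH):

import torch
from FlashSR.AudioSR.AudioSRUnet import AudioSRUnet

unet = AudioSRUnet()                # defaults: in_channels=32, out_channels=16
x = torch.randn(1, 16, 64, 32)      # noisy latent
y = torch.randn(1, 16, 64, 32)      # conditioning latent, concatenated onto x inside forward()
out = unet(x=x, timesteps=torch.tensor([30]), y=y)
print(out.shape)                    # expected: torch.Size([1, 16, 64, 32])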
FlashSR/AudioSR/EncoderDecoder.py ADDED
@@ -0,0 +1,1010 @@
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from einops import rearrange
7
+
8
+ class LinearAttention(nn.Module):
9
+ def __init__(self, dim, heads=4, dim_head=32):
10
+ super().__init__()
11
+ self.heads = heads
12
+ hidden_dim = dim_head * heads
13
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
14
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
15
+
16
+ def forward(self, x):
17
+ b, c, h, w = x.shape
18
+ qkv = self.to_qkv(x)
19
+ q, k, v = rearrange(
20
+ qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
21
+ )
22
+ k = k.softmax(dim=-1)
23
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
24
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
25
+ out = rearrange(
26
+ out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
27
+ )
28
+ return self.to_out(out)
29
+
30
+ def get_timestep_embedding(timesteps, embedding_dim):
31
+ """
32
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
33
+ From Fairseq.
34
+ Build sinusoidal embeddings.
35
+ This matches the implementation in tensor2tensor, but differs slightly
36
+ from the description in Section 3.5 of "Attention Is All You Need".
37
+ """
38
+ assert len(timesteps.shape) == 1
39
+
40
+ half_dim = embedding_dim // 2
41
+ emb = math.log(10000) / (half_dim - 1)
42
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
43
+ emb = emb.to(device=timesteps.device)
44
+ emb = timesteps.float()[:, None] * emb[None, :]
45
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
46
+ if embedding_dim % 2 == 1: # zero pad
47
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
48
+ return emb
49
+
50
+
51
+ def nonlinearity(x):
52
+ # swish
53
+ return x * torch.sigmoid(x)
54
+
55
+
56
+ def Normalize(in_channels, num_groups=32):
57
+ return torch.nn.GroupNorm(
58
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
59
+ )
60
+
61
+
62
+ class Upsample(nn.Module):
63
+ def __init__(self, in_channels, with_conv):
64
+ super().__init__()
65
+ self.with_conv = with_conv
66
+ if self.with_conv:
67
+ self.conv = torch.nn.Conv2d(
68
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
69
+ )
70
+
71
+ def forward(self, x):
72
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
73
+ if self.with_conv:
74
+ x = self.conv(x)
75
+ return x
76
+
77
+
78
+ class UpsampleTimeStride4(nn.Module):
79
+ def __init__(self, in_channels, with_conv):
80
+ super().__init__()
81
+ self.with_conv = with_conv
82
+ if self.with_conv:
83
+ self.conv = torch.nn.Conv2d(
84
+ in_channels, in_channels, kernel_size=5, stride=1, padding=2
85
+ )
86
+
87
+ def forward(self, x):
88
+ x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest")
89
+ if self.with_conv:
90
+ x = self.conv(x)
91
+ return x
92
+
93
+
94
+ class Downsample(nn.Module):
95
+ def __init__(self, in_channels, with_conv):
96
+ super().__init__()
97
+ self.with_conv = with_conv
98
+ if self.with_conv:
99
+ # Do time downsampling here
100
+ # no asymmetric padding in torch conv, must do it ourselves
101
+ self.conv = torch.nn.Conv2d(
102
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
103
+ )
104
+
105
+ def forward(self, x):
106
+ if self.with_conv:
107
+ pad = (0, 1, 0, 1)
108
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
109
+ x = self.conv(x)
110
+ else:
111
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
112
+ return x
113
+
114
+
115
+ class DownsampleTimeStride4(nn.Module):
116
+ def __init__(self, in_channels, with_conv):
117
+ super().__init__()
118
+ self.with_conv = with_conv
119
+ if self.with_conv:
120
+ # Do time downsampling here
121
+ # no asymmetric padding in torch conv, must do it ourselves
122
+ self.conv = torch.nn.Conv2d(
123
+ in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1
124
+ )
125
+
126
+ def forward(self, x):
127
+ if self.with_conv:
128
+ pad = (0, 1, 0, 1)
129
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
130
+ x = self.conv(x)
131
+ else:
132
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2))
133
+ return x
134
+
135
+
136
+ class ResnetBlock(nn.Module):
137
+ def __init__(
138
+ self,
139
+ *,
140
+ in_channels,
141
+ out_channels=None,
142
+ conv_shortcut=False,
143
+ dropout,
144
+ temb_channels=512,
145
+ ):
146
+ super().__init__()
147
+ self.in_channels = in_channels
148
+ out_channels = in_channels if out_channels is None else out_channels
149
+ self.out_channels = out_channels
150
+ self.use_conv_shortcut = conv_shortcut
151
+
152
+ self.norm1 = Normalize(in_channels)
153
+ self.conv1 = torch.nn.Conv2d(
154
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
155
+ )
156
+ if temb_channels > 0:
157
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
158
+ self.norm2 = Normalize(out_channels)
159
+ self.dropout = torch.nn.Dropout(dropout)
160
+ self.conv2 = torch.nn.Conv2d(
161
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
162
+ )
163
+ if self.in_channels != self.out_channels:
164
+ if self.use_conv_shortcut:
165
+ self.conv_shortcut = torch.nn.Conv2d(
166
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
167
+ )
168
+ else:
169
+ self.nin_shortcut = torch.nn.Conv2d(
170
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
171
+ )
172
+
173
+ def forward(self, x, temb):
174
+ h = x
175
+ h = self.norm1(h)
176
+ h = nonlinearity(h)
177
+ h = self.conv1(h)
178
+
179
+ if temb is not None:
180
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
181
+
182
+ h = self.norm2(h)
183
+ h = nonlinearity(h)
184
+ h = self.dropout(h)
185
+ h = self.conv2(h)
186
+
187
+ if self.in_channels != self.out_channels:
188
+ if self.use_conv_shortcut:
189
+ x = self.conv_shortcut(x)
190
+ else:
191
+ x = self.nin_shortcut(x)
192
+
193
+ return x + h
194
+
195
+
196
+ class LinAttnBlock(LinearAttention):
197
+ """to match AttnBlock usage"""
198
+
199
+ def __init__(self, in_channels):
200
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
201
+
202
+
203
+ class AttnBlock(nn.Module):
204
+ def __init__(self, in_channels):
205
+ super().__init__()
206
+ self.in_channels = in_channels
207
+
208
+ self.norm = Normalize(in_channels)
209
+ self.q = torch.nn.Conv2d(
210
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
211
+ )
212
+ self.k = torch.nn.Conv2d(
213
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
214
+ )
215
+ self.v = torch.nn.Conv2d(
216
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
217
+ )
218
+ self.proj_out = torch.nn.Conv2d(
219
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
220
+ )
221
+
222
+ def forward(self, x):
223
+ h_ = x
224
+ h_ = self.norm(h_)
225
+ q = self.q(h_)
226
+ k = self.k(h_)
227
+ v = self.v(h_)
228
+
229
+ # compute attention
230
+ b, c, h, w = q.shape
231
+ q = q.reshape(b, c, h * w).contiguous()
232
+ q = q.permute(0, 2, 1).contiguous() # b,hw,c
233
+ k = k.reshape(b, c, h * w).contiguous() # b,c,hw
234
+ w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
235
+ w_ = w_ * (int(c) ** (-0.5))
236
+ w_ = torch.nn.functional.softmax(w_, dim=2)
237
+
238
+ # attend to values
239
+ v = v.reshape(b, c, h * w).contiguous()
240
+ w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q)
241
+ h_ = torch.bmm(
242
+ v, w_
243
+ ).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
244
+ h_ = h_.reshape(b, c, h, w).contiguous()
245
+
246
+ h_ = self.proj_out(h_)
247
+
248
+ return x + h_
249
+
250
+
251
+ def make_attn(in_channels, attn_type="vanilla"):
252
+ assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
253
+ # print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
254
+ if attn_type == "vanilla":
255
+ return AttnBlock(in_channels)
256
+ elif attn_type == "none":
257
+ return nn.Identity(in_channels)
258
+ else:
259
+ return LinAttnBlock(in_channels)
260
+
261
+
262
+ class Model(nn.Module):
263
+ def __init__(
264
+ self,
265
+ *,
266
+ ch,
267
+ out_ch,
268
+ ch_mult=(1, 2, 4, 8),
269
+ num_res_blocks,
270
+ attn_resolutions,
271
+ dropout=0.0,
272
+ resamp_with_conv=True,
273
+ in_channels,
274
+ resolution,
275
+ use_timestep=True,
276
+ use_linear_attn=False,
277
+ attn_type="vanilla",
278
+ ):
279
+ super().__init__()
280
+ if use_linear_attn:
281
+ attn_type = "linear"
282
+ self.ch = ch
283
+ self.temb_ch = self.ch * 4
284
+ self.num_resolutions = len(ch_mult)
285
+ self.num_res_blocks = num_res_blocks
286
+ self.resolution = resolution
287
+ self.in_channels = in_channels
288
+
289
+ self.use_timestep = use_timestep
290
+ if self.use_timestep:
291
+ # timestep embedding
292
+ self.temb = nn.Module()
293
+ self.temb.dense = nn.ModuleList(
294
+ [
295
+ torch.nn.Linear(self.ch, self.temb_ch),
296
+ torch.nn.Linear(self.temb_ch, self.temb_ch),
297
+ ]
298
+ )
299
+
300
+ # downsampling
301
+ self.conv_in = torch.nn.Conv2d(
302
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
303
+ )
304
+
305
+ curr_res = resolution
306
+ in_ch_mult = (1,) + tuple(ch_mult)
307
+ self.down = nn.ModuleList()
308
+ for i_level in range(self.num_resolutions):
309
+ block = nn.ModuleList()
310
+ attn = nn.ModuleList()
311
+ block_in = ch * in_ch_mult[i_level]
312
+ block_out = ch * ch_mult[i_level]
313
+ for i_block in range(self.num_res_blocks):
314
+ block.append(
315
+ ResnetBlock(
316
+ in_channels=block_in,
317
+ out_channels=block_out,
318
+ temb_channels=self.temb_ch,
319
+ dropout=dropout,
320
+ )
321
+ )
322
+ block_in = block_out
323
+ if curr_res in attn_resolutions:
324
+ attn.append(make_attn(block_in, attn_type=attn_type))
325
+ down = nn.Module()
326
+ down.block = block
327
+ down.attn = attn
328
+ if i_level != self.num_resolutions - 1:
329
+ down.downsample = Downsample(block_in, resamp_with_conv)
330
+ curr_res = curr_res // 2
331
+ self.down.append(down)
332
+
333
+ # middle
334
+ self.mid = nn.Module()
335
+ self.mid.block_1 = ResnetBlock(
336
+ in_channels=block_in,
337
+ out_channels=block_in,
338
+ temb_channels=self.temb_ch,
339
+ dropout=dropout,
340
+ )
341
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
342
+ self.mid.block_2 = ResnetBlock(
343
+ in_channels=block_in,
344
+ out_channels=block_in,
345
+ temb_channels=self.temb_ch,
346
+ dropout=dropout,
347
+ )
348
+
349
+ # upsampling
350
+ self.up = nn.ModuleList()
351
+ for i_level in reversed(range(self.num_resolutions)):
352
+ block = nn.ModuleList()
353
+ attn = nn.ModuleList()
354
+ block_out = ch * ch_mult[i_level]
355
+ skip_in = ch * ch_mult[i_level]
356
+ for i_block in range(self.num_res_blocks + 1):
357
+ if i_block == self.num_res_blocks:
358
+ skip_in = ch * in_ch_mult[i_level]
359
+ block.append(
360
+ ResnetBlock(
361
+ in_channels=block_in + skip_in,
362
+ out_channels=block_out,
363
+ temb_channels=self.temb_ch,
364
+ dropout=dropout,
365
+ )
366
+ )
367
+ block_in = block_out
368
+ if curr_res in attn_resolutions:
369
+ attn.append(make_attn(block_in, attn_type=attn_type))
370
+ up = nn.Module()
371
+ up.block = block
372
+ up.attn = attn
373
+ if i_level != 0:
374
+ up.upsample = Upsample(block_in, resamp_with_conv)
375
+ curr_res = curr_res * 2
376
+ self.up.insert(0, up) # prepend to get consistent order
377
+
378
+ # end
379
+ self.norm_out = Normalize(block_in)
380
+ self.conv_out = torch.nn.Conv2d(
381
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
382
+ )
383
+
384
+ def forward(self, x, t=None, context=None):
385
+ # assert x.shape[2] == x.shape[3] == self.resolution
386
+ if context is not None:
387
+ # assume aligned context, cat along channel axis
388
+ x = torch.cat((x, context), dim=1)
389
+ if self.use_timestep:
390
+ # timestep embedding
391
+ assert t is not None
392
+ temb = get_timestep_embedding(t, self.ch)
393
+ temb = self.temb.dense[0](temb)
394
+ temb = nonlinearity(temb)
395
+ temb = self.temb.dense[1](temb)
396
+ else:
397
+ temb = None
398
+
399
+ # downsampling
400
+ hs = [self.conv_in(x)]
401
+ for i_level in range(self.num_resolutions):
402
+ for i_block in range(self.num_res_blocks):
403
+ h = self.down[i_level].block[i_block](hs[-1], temb)
404
+ if len(self.down[i_level].attn) > 0:
405
+ h = self.down[i_level].attn[i_block](h)
406
+ hs.append(h)
407
+ if i_level != self.num_resolutions - 1:
408
+ hs.append(self.down[i_level].downsample(hs[-1]))
409
+
410
+ # middle
411
+ h = hs[-1]
412
+ h = self.mid.block_1(h, temb)
413
+ h = self.mid.attn_1(h)
414
+ h = self.mid.block_2(h, temb)
415
+
416
+ # upsampling
417
+ for i_level in reversed(range(self.num_resolutions)):
418
+ for i_block in range(self.num_res_blocks + 1):
419
+ h = self.up[i_level].block[i_block](
420
+ torch.cat([h, hs.pop()], dim=1), temb
421
+ )
422
+ if len(self.up[i_level].attn) > 0:
423
+ h = self.up[i_level].attn[i_block](h)
424
+ if i_level != 0:
425
+ h = self.up[i_level].upsample(h)
426
+
427
+ # end
428
+ h = self.norm_out(h)
429
+ h = nonlinearity(h)
430
+ h = self.conv_out(h)
431
+ return h
432
+
433
+ def get_last_layer(self):
434
+ return self.conv_out.weight
435
+
436
+
437
+ class Encoder(nn.Module):
438
+ def __init__(
439
+ self,
440
+ *,
441
+ ch,
442
+ out_ch,
443
+ ch_mult=(1, 2, 4, 8),
444
+ num_res_blocks,
445
+ attn_resolutions,
446
+ dropout=0.0,
447
+ resamp_with_conv=True,
448
+ in_channels,
449
+ resolution,
450
+ z_channels,
451
+ double_z=True,
452
+ use_linear_attn=False,
453
+ attn_type="vanilla",
454
+ downsample_time_stride4_levels=[],
455
+ **ignore_kwargs,
456
+ ):
457
+ super().__init__()
458
+ if use_linear_attn:
459
+ attn_type = "linear"
460
+ self.ch = ch
461
+ self.temb_ch = 0
462
+ self.num_resolutions = len(ch_mult)
463
+ self.num_res_blocks = num_res_blocks
464
+ self.resolution = resolution
465
+ self.in_channels = in_channels
466
+ self.downsample_time_stride4_levels = downsample_time_stride4_levels
467
+
468
+ if len(self.downsample_time_stride4_levels) > 0:
469
+ assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
470
+ "The level to perform downsample 4 operation need to be smaller than the total resolution number %s"
471
+ % str(self.num_resolutions)
472
+ )
473
+
474
+ # downsampling
475
+ self.conv_in = torch.nn.Conv2d(
476
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
477
+ )
478
+
479
+ curr_res = resolution
480
+ in_ch_mult = (1,) + tuple(ch_mult)
481
+ self.in_ch_mult = in_ch_mult
482
+ self.down = nn.ModuleList()
483
+ for i_level in range(self.num_resolutions):
484
+ block = nn.ModuleList()
485
+ attn = nn.ModuleList()
486
+ block_in = ch * in_ch_mult[i_level]
487
+ block_out = ch * ch_mult[i_level]
488
+ for i_block in range(self.num_res_blocks):
489
+ block.append(
490
+ ResnetBlock(
491
+ in_channels=block_in,
492
+ out_channels=block_out,
493
+ temb_channels=self.temb_ch,
494
+ dropout=dropout,
495
+ )
496
+ )
497
+ block_in = block_out
498
+ if curr_res in attn_resolutions:
499
+ attn.append(make_attn(block_in, attn_type=attn_type))
500
+ down = nn.Module()
501
+ down.block = block
502
+ down.attn = attn
503
+ if i_level != self.num_resolutions - 1:
504
+ if i_level in self.downsample_time_stride4_levels:
505
+ down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv)
506
+ else:
507
+ down.downsample = Downsample(block_in, resamp_with_conv)
508
+ curr_res = curr_res // 2
509
+ self.down.append(down)
510
+
511
+ # middle
512
+ self.mid = nn.Module()
513
+ self.mid.block_1 = ResnetBlock(
514
+ in_channels=block_in,
515
+ out_channels=block_in,
516
+ temb_channels=self.temb_ch,
517
+ dropout=dropout,
518
+ )
519
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
520
+ self.mid.block_2 = ResnetBlock(
521
+ in_channels=block_in,
522
+ out_channels=block_in,
523
+ temb_channels=self.temb_ch,
524
+ dropout=dropout,
525
+ )
526
+
527
+ # end
528
+ self.norm_out = Normalize(block_in)
529
+ self.conv_out = torch.nn.Conv2d(
530
+ block_in,
531
+ 2 * z_channels if double_z else z_channels,
532
+ kernel_size=3,
533
+ stride=1,
534
+ padding=1,
535
+ )
536
+
537
+ def forward(self, x):
538
+ # timestep embedding
539
+ temb = None
540
+ # downsampling
541
+ hs = [self.conv_in(x)]
542
+ for i_level in range(self.num_resolutions):
543
+ for i_block in range(self.num_res_blocks):
544
+ h = self.down[i_level].block[i_block](hs[-1], temb)
545
+ if len(self.down[i_level].attn) > 0:
546
+ h = self.down[i_level].attn[i_block](h)
547
+ hs.append(h)
548
+ if i_level != self.num_resolutions - 1:
549
+ hs.append(self.down[i_level].downsample(hs[-1]))
550
+
551
+ # middle
552
+ h = hs[-1]
553
+ h = self.mid.block_1(h, temb)
554
+ h = self.mid.attn_1(h)
555
+ h = self.mid.block_2(h, temb)
556
+
557
+ # end
558
+ h = self.norm_out(h)
559
+ h = nonlinearity(h)
560
+ h = self.conv_out(h)
561
+ return h
562
+
563
+
564
+ class Decoder(nn.Module):
565
+ def __init__(
566
+ self,
567
+ *,
568
+ ch,
569
+ out_ch,
570
+ ch_mult=(1, 2, 4, 8),
571
+ num_res_blocks,
572
+ attn_resolutions,
573
+ dropout=0.0,
574
+ resamp_with_conv=True,
575
+ in_channels,
576
+ resolution,
577
+ z_channels,
578
+ give_pre_end=False,
579
+ tanh_out=False,
580
+ use_linear_attn=False,
581
+ downsample_time_stride4_levels=[],
582
+ attn_type="vanilla",
583
+ **ignorekwargs,
584
+ ):
585
+ super().__init__()
586
+ if use_linear_attn:
587
+ attn_type = "linear"
588
+ self.ch = ch
589
+ self.temb_ch = 0
590
+ self.num_resolutions = len(ch_mult)
591
+ self.num_res_blocks = num_res_blocks
592
+ self.resolution = resolution
593
+ self.in_channels = in_channels
594
+ self.give_pre_end = give_pre_end
595
+ self.tanh_out = tanh_out
596
+ self.downsample_time_stride4_levels = downsample_time_stride4_levels
597
+
598
+ if len(self.downsample_time_stride4_levels) > 0:
599
+ assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
600
+ "The level to perform downsample 4 operation need to be smaller than the total resolution number %s"
601
+ % str(self.num_resolutions)
602
+ )
603
+
604
+ # compute in_ch_mult, block_in and curr_res at lowest res
605
+ (1,) + tuple(ch_mult)
606
+ block_in = ch * ch_mult[self.num_resolutions - 1]
607
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
608
+ self.z_shape = (1, z_channels, curr_res, curr_res)
609
+ # print(
610
+ # "Working with z of shape {} = {} dimensions.".format(
611
+ # self.z_shape, np.prod(self.z_shape)
612
+ # )
613
+ # )
614
+
615
+ # z to block_in
616
+ self.conv_in = torch.nn.Conv2d(
617
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
618
+ )
619
+
620
+ # middle
621
+ self.mid = nn.Module()
622
+ self.mid.block_1 = ResnetBlock(
623
+ in_channels=block_in,
624
+ out_channels=block_in,
625
+ temb_channels=self.temb_ch,
626
+ dropout=dropout,
627
+ )
628
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
629
+ self.mid.block_2 = ResnetBlock(
630
+ in_channels=block_in,
631
+ out_channels=block_in,
632
+ temb_channels=self.temb_ch,
633
+ dropout=dropout,
634
+ )
635
+
636
+ # upsampling
637
+ self.up = nn.ModuleList()
638
+ for i_level in reversed(range(self.num_resolutions)):
639
+ block = nn.ModuleList()
640
+ attn = nn.ModuleList()
641
+ block_out = ch * ch_mult[i_level]
642
+ for i_block in range(self.num_res_blocks + 1):
643
+ block.append(
644
+ ResnetBlock(
645
+ in_channels=block_in,
646
+ out_channels=block_out,
647
+ temb_channels=self.temb_ch,
648
+ dropout=dropout,
649
+ )
650
+ )
651
+ block_in = block_out
652
+ if curr_res in attn_resolutions:
653
+ attn.append(make_attn(block_in, attn_type=attn_type))
654
+ up = nn.Module()
655
+ up.block = block
656
+ up.attn = attn
657
+ if i_level != 0:
658
+ if i_level - 1 in self.downsample_time_stride4_levels:
659
+ up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv)
660
+ else:
661
+ up.upsample = Upsample(block_in, resamp_with_conv)
662
+ curr_res = curr_res * 2
663
+ self.up.insert(0, up) # prepend to get consistent order
664
+
665
+ # end
666
+ self.norm_out = Normalize(block_in)
667
+ self.conv_out = torch.nn.Conv2d(
668
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
669
+ )
670
+
671
+ def forward(self, z):
672
+ # assert z.shape[1:] == self.z_shape[1:]
673
+ self.last_z_shape = z.shape
674
+
675
+ # timestep embedding
676
+ temb = None
677
+
678
+ # z to block_in
679
+ h = self.conv_in(z)
680
+
681
+ # middle
682
+ h = self.mid.block_1(h, temb)
683
+ h = self.mid.attn_1(h)
684
+ h = self.mid.block_2(h, temb)
685
+
686
+ # upsampling
687
+ for i_level in reversed(range(self.num_resolutions)):
688
+ for i_block in range(self.num_res_blocks + 1):
689
+ h = self.up[i_level].block[i_block](h, temb)
690
+ if len(self.up[i_level].attn) > 0:
691
+ h = self.up[i_level].attn[i_block](h)
692
+ if i_level != 0:
693
+ h = self.up[i_level].upsample(h)
694
+
695
+ # end
696
+ if self.give_pre_end:
697
+ return h
698
+
699
+ h = self.norm_out(h)
700
+ h = nonlinearity(h)
701
+ h = self.conv_out(h)
702
+ if self.tanh_out:
703
+ h = torch.tanh(h)
704
+ return h
705
+
706
+
707
+ class SimpleDecoder(nn.Module):
708
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
709
+ super().__init__()
710
+ self.model = nn.ModuleList(
711
+ [
712
+ nn.Conv2d(in_channels, in_channels, 1),
713
+ ResnetBlock(
714
+ in_channels=in_channels,
715
+ out_channels=2 * in_channels,
716
+ temb_channels=0,
717
+ dropout=0.0,
718
+ ),
719
+ ResnetBlock(
720
+ in_channels=2 * in_channels,
721
+ out_channels=4 * in_channels,
722
+ temb_channels=0,
723
+ dropout=0.0,
724
+ ),
725
+ ResnetBlock(
726
+ in_channels=4 * in_channels,
727
+ out_channels=2 * in_channels,
728
+ temb_channels=0,
729
+ dropout=0.0,
730
+ ),
731
+ nn.Conv2d(2 * in_channels, in_channels, 1),
732
+ Upsample(in_channels, with_conv=True),
733
+ ]
734
+ )
735
+ # end
736
+ self.norm_out = Normalize(in_channels)
737
+ self.conv_out = torch.nn.Conv2d(
738
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
739
+ )
740
+
741
+ def forward(self, x):
742
+ for i, layer in enumerate(self.model):
743
+ if i in [1, 2, 3]:
744
+ x = layer(x, None)
745
+ else:
746
+ x = layer(x)
747
+
748
+ h = self.norm_out(x)
749
+ h = nonlinearity(h)
750
+ x = self.conv_out(h)
751
+ return x
752
+
753
+
754
+ class UpsampleDecoder(nn.Module):
755
+ def __init__(
756
+ self,
757
+ in_channels,
758
+ out_channels,
759
+ ch,
760
+ num_res_blocks,
761
+ resolution,
762
+ ch_mult=(2, 2),
763
+ dropout=0.0,
764
+ ):
765
+ super().__init__()
766
+ # upsampling
767
+ self.temb_ch = 0
768
+ self.num_resolutions = len(ch_mult)
769
+ self.num_res_blocks = num_res_blocks
770
+ block_in = in_channels
771
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
772
+ self.res_blocks = nn.ModuleList()
773
+ self.upsample_blocks = nn.ModuleList()
774
+ for i_level in range(self.num_resolutions):
775
+ res_block = []
776
+ block_out = ch * ch_mult[i_level]
777
+ for i_block in range(self.num_res_blocks + 1):
778
+ res_block.append(
779
+ ResnetBlock(
780
+ in_channels=block_in,
781
+ out_channels=block_out,
782
+ temb_channels=self.temb_ch,
783
+ dropout=dropout,
784
+ )
785
+ )
786
+ block_in = block_out
787
+ self.res_blocks.append(nn.ModuleList(res_block))
788
+ if i_level != self.num_resolutions - 1:
789
+ self.upsample_blocks.append(Upsample(block_in, True))
790
+ curr_res = curr_res * 2
791
+
792
+ # end
793
+ self.norm_out = Normalize(block_in)
794
+ self.conv_out = torch.nn.Conv2d(
795
+ block_in, out_channels, kernel_size=3, stride=1, padding=1
796
+ )
797
+
798
+ def forward(self, x):
799
+ # upsampling
800
+ h = x
801
+ for k, i_level in enumerate(range(self.num_resolutions)):
802
+ for i_block in range(self.num_res_blocks + 1):
803
+ h = self.res_blocks[i_level][i_block](h, None)
804
+ if i_level != self.num_resolutions - 1:
805
+ h = self.upsample_blocks[k](h)
806
+ h = self.norm_out(h)
807
+ h = nonlinearity(h)
808
+ h = self.conv_out(h)
809
+ return h
810
+
811
+
812
+ class LatentRescaler(nn.Module):
813
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
814
+ super().__init__()
815
+ # residual block, interpolate, residual block
816
+ self.factor = factor
817
+ self.conv_in = nn.Conv2d(
818
+ in_channels, mid_channels, kernel_size=3, stride=1, padding=1
819
+ )
820
+ self.res_block1 = nn.ModuleList(
821
+ [
822
+ ResnetBlock(
823
+ in_channels=mid_channels,
824
+ out_channels=mid_channels,
825
+ temb_channels=0,
826
+ dropout=0.0,
827
+ )
828
+ for _ in range(depth)
829
+ ]
830
+ )
831
+ self.attn = AttnBlock(mid_channels)
832
+ self.res_block2 = nn.ModuleList(
833
+ [
834
+ ResnetBlock(
835
+ in_channels=mid_channels,
836
+ out_channels=mid_channels,
837
+ temb_channels=0,
838
+ dropout=0.0,
839
+ )
840
+ for _ in range(depth)
841
+ ]
842
+ )
843
+
844
+ self.conv_out = nn.Conv2d(
845
+ mid_channels,
846
+ out_channels,
847
+ kernel_size=1,
848
+ )
849
+
850
+ def forward(self, x):
851
+ x = self.conv_in(x)
852
+ for block in self.res_block1:
853
+ x = block(x, None)
854
+ x = torch.nn.functional.interpolate(
855
+ x,
856
+ size=(
857
+ int(round(x.shape[2] * self.factor)),
858
+ int(round(x.shape[3] * self.factor)),
859
+ ),
860
+ )
861
+ x = self.attn(x).contiguous()
862
+ for block in self.res_block2:
863
+ x = block(x, None)
864
+ x = self.conv_out(x)
865
+ return x
866
+
867
+
868
+ class MergedRescaleEncoder(nn.Module):
869
+ def __init__(
870
+ self,
871
+ in_channels,
872
+ ch,
873
+ resolution,
874
+ out_ch,
875
+ num_res_blocks,
876
+ attn_resolutions,
877
+ dropout=0.0,
878
+ resamp_with_conv=True,
879
+ ch_mult=(1, 2, 4, 8),
880
+ rescale_factor=1.0,
881
+ rescale_module_depth=1,
882
+ ):
883
+ super().__init__()
884
+ intermediate_chn = ch * ch_mult[-1]
885
+ self.encoder = Encoder(
886
+ in_channels=in_channels,
887
+ num_res_blocks=num_res_blocks,
888
+ ch=ch,
889
+ ch_mult=ch_mult,
890
+ z_channels=intermediate_chn,
891
+ double_z=False,
892
+ resolution=resolution,
893
+ attn_resolutions=attn_resolutions,
894
+ dropout=dropout,
895
+ resamp_with_conv=resamp_with_conv,
896
+ out_ch=None,
897
+ )
898
+ self.rescaler = LatentRescaler(
899
+ factor=rescale_factor,
900
+ in_channels=intermediate_chn,
901
+ mid_channels=intermediate_chn,
902
+ out_channels=out_ch,
903
+ depth=rescale_module_depth,
904
+ )
905
+
906
+ def forward(self, x):
907
+ x = self.encoder(x)
908
+ x = self.rescaler(x)
909
+ return x
910
+
911
+
912
+ class MergedRescaleDecoder(nn.Module):
913
+ def __init__(
914
+ self,
915
+ z_channels,
916
+ out_ch,
917
+ resolution,
918
+ num_res_blocks,
919
+ attn_resolutions,
920
+ ch,
921
+ ch_mult=(1, 2, 4, 8),
922
+ dropout=0.0,
923
+ resamp_with_conv=True,
924
+ rescale_factor=1.0,
925
+ rescale_module_depth=1,
926
+ ):
927
+ super().__init__()
928
+ tmp_chn = z_channels * ch_mult[-1]
929
+ self.decoder = Decoder(
930
+ out_ch=out_ch,
931
+ z_channels=tmp_chn,
932
+ attn_resolutions=attn_resolutions,
933
+ dropout=dropout,
934
+ resamp_with_conv=resamp_with_conv,
935
+ in_channels=None,
936
+ num_res_blocks=num_res_blocks,
937
+ ch_mult=ch_mult,
938
+ resolution=resolution,
939
+ ch=ch,
940
+ )
941
+ self.rescaler = LatentRescaler(
942
+ factor=rescale_factor,
943
+ in_channels=z_channels,
944
+ mid_channels=tmp_chn,
945
+ out_channels=tmp_chn,
946
+ depth=rescale_module_depth,
947
+ )
948
+
949
+ def forward(self, x):
950
+ x = self.rescaler(x)
951
+ x = self.decoder(x)
952
+ return x
953
+
954
+
955
+ class Upsampler(nn.Module):
956
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
957
+ super().__init__()
958
+ assert out_size >= in_size
959
+ num_blocks = int(np.log2(out_size // in_size)) + 1
960
+ factor_up = 1.0 + (out_size % in_size)
961
+ print(
962
+ f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
963
+ )
964
+ self.rescaler = LatentRescaler(
965
+ factor=factor_up,
966
+ in_channels=in_channels,
967
+ mid_channels=2 * in_channels,
968
+ out_channels=in_channels,
969
+ )
970
+ self.decoder = Decoder(
971
+ out_ch=out_channels,
972
+ resolution=out_size,
973
+ z_channels=in_channels,
974
+ num_res_blocks=2,
975
+ attn_resolutions=[],
976
+ in_channels=None,
977
+ ch=in_channels,
978
+ ch_mult=[ch_mult for _ in range(num_blocks)],
979
+ )
980
+
981
+ def forward(self, x):
982
+ x = self.rescaler(x)
983
+ x = self.decoder(x)
984
+ return x
985
+
986
+
987
+ class Resize(nn.Module):
988
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
989
+ super().__init__()
990
+ self.with_conv = learned
991
+ self.mode = mode
992
+ if self.with_conv:
993
+ print(
994
+ f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode"
995
+ )
996
+ raise NotImplementedError()
997
+ assert in_channels is not None
998
+ # no asymmetric padding in torch conv, must do it ourselves
999
+ self.conv = torch.nn.Conv2d(
1000
+ in_channels, in_channels, kernel_size=4, stride=2, padding=1
1001
+ )
1002
+
1003
+ def forward(self, x, scale_factor=1.0):
1004
+ if scale_factor == 1.0:
1005
+ return x
1006
+ else:
1007
+ x = torch.nn.functional.interpolate(
1008
+ x, mode=self.mode, align_corners=False, scale_factor=scale_factor
1009
+ )
1010
+ return x
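
A minimal shape-check sketch for the `Encoder`/`Decoder` pair added above (illustration only, not part of the commit). The `ddconfig` values mirror `args/model_argument.yaml` further below; the `(batch, 1, frames, mel)` input layout is an assumption based on how `autoencoder.py` unsqueezes the fbank.

```python
import torch
from FlashSR.AudioSR.EncoderDecoder import Encoder, Decoder

ddconfig = dict(
    ch=128, out_ch=1, ch_mult=(1, 2, 4, 8), num_res_blocks=2,
    attn_resolutions=[], dropout=0.1, in_channels=1,
    resolution=256, z_channels=16, double_z=True, mel_bins=256,
)
enc = Encoder(**ddconfig)        # extra keys such as mel_bins fall into **ignore_kwargs
dec = Decoder(**ddconfig)

x = torch.randn(1, 1, 64, 256)   # (batch, 1, time frames, mel bins) -- assumed layout
z = enc(x)                       # three stride-2 stages -> (1, 2 * z_channels, 8, 32)
x_hat = dec(z[:, :16])           # Decoder consumes z_channels maps (here the "mean" half)
print(z.shape, x_hat.shape)      # torch.Size([1, 32, 8, 32]) torch.Size([1, 1, 64, 256])
```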
FlashSR/AudioSR/Vocoder.py ADDED
@@ -0,0 +1,167 @@
1
+ import torch
2
+
3
+ import FlashSR.AudioSR.hifigan as hifigan
4
+
5
+
6
+ def get_vocoder_config():
7
+ return {
8
+ "resblock": "1",
9
+ "num_gpus": 6,
10
+ "batch_size": 16,
11
+ "learning_rate": 0.0002,
12
+ "adam_b1": 0.8,
13
+ "adam_b2": 0.99,
14
+ "lr_decay": 0.999,
15
+ "seed": 1234,
16
+ "upsample_rates": [5, 4, 2, 2, 2],
17
+ "upsample_kernel_sizes": [16, 16, 8, 4, 4],
18
+ "upsample_initial_channel": 1024,
19
+ "resblock_kernel_sizes": [3, 7, 11],
20
+ "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
21
+ "segment_size": 8192,
22
+ "num_mels": 64,
23
+ "num_freq": 1025,
24
+ "n_fft": 1024,
25
+ "hop_size": 160,
26
+ "win_size": 1024,
27
+ "sampling_rate": 16000,
28
+ "fmin": 0,
29
+ "fmax": 8000,
30
+ "fmax_for_loss": None,
31
+ "num_workers": 4,
32
+ "dist_config": {
33
+ "dist_backend": "nccl",
34
+ "dist_url": "tcp://localhost:54321",
35
+ "world_size": 1,
36
+ },
37
+ }
38
+
39
+
40
+ def get_vocoder_config_48k():
41
+ return {
42
+ "resblock": "1",
43
+ "num_gpus": 8,
44
+ "batch_size": 128,
45
+ "learning_rate": 0.0001,
46
+ "adam_b1": 0.8,
47
+ "adam_b2": 0.99,
48
+ "lr_decay": 0.999,
49
+ "seed": 1234,
50
+ "upsample_rates": [6, 5, 4, 2, 2],
51
+ "upsample_kernel_sizes": [12, 10, 8, 4, 4],
52
+ "upsample_initial_channel": 1536,
53
+ "resblock_kernel_sizes": [3, 7, 11, 15],
54
+ "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5], [1, 3, 5]],
55
+ "segment_size": 15360,
56
+ "num_mels": 256,
57
+ "n_fft": 2048,
58
+ "hop_size": 480,
59
+ "win_size": 2048,
60
+ "sampling_rate": 48000,
61
+ "fmin": 20,
62
+ "fmax": 24000,
63
+ "fmax_for_loss": None,
64
+ "num_workers": 8,
65
+ "dist_config": {
66
+ "dist_backend": "nccl",
67
+ "dist_url": "tcp://localhost:18273",
68
+ "world_size": 1,
69
+ },
70
+ }
71
+
72
+
73
+ def get_available_checkpoint_keys(model, ckpt):
74
+ state_dict = torch.load(ckpt)["state_dict"]
75
+ current_state_dict = model.state_dict()
76
+ new_state_dict = {}
77
+ for k in state_dict.keys():
78
+ if (
79
+ k in current_state_dict.keys()
80
+ and current_state_dict[k].size() == state_dict[k].size()
81
+ ):
82
+ new_state_dict[k] = state_dict[k]
83
+ else:
84
+ print("==> WARNING: Skipping %s" % k)
85
+ print(
86
+ "%s out of %s keys are matched"
87
+ % (len(new_state_dict.keys()), len(state_dict.keys()))
88
+ )
89
+ return new_state_dict
90
+
91
+
92
+ def get_param_num(model):
93
+ num_param = sum(param.numel() for param in model.parameters())
94
+ return num_param
95
+
96
+
97
+ def torch_version_orig_mod_remove(state_dict):
98
+ new_state_dict = {}
99
+ new_state_dict["generator"] = {}
100
+ for key in state_dict["generator"].keys():
101
+ if "_orig_mod." in key:
102
+ new_state_dict["generator"][key.replace("_orig_mod.", "")] = state_dict[
103
+ "generator"
104
+ ][key]
105
+ else:
106
+ new_state_dict["generator"][key] = state_dict["generator"][key]
107
+ return new_state_dict
108
+
109
+
110
+ def get_vocoder(config, device, mel_bins):
111
+ name = "HiFi-GAN"
112
+ speaker = ""
113
+ if name == "MelGAN":
114
+ if speaker == "LJSpeech":
115
+ vocoder = torch.hub.load(
116
+ "descriptinc/melgan-neurips", "load_melgan", "linda_johnson"
117
+ )
118
+ elif speaker == "universal":
119
+ vocoder = torch.hub.load(
120
+ "descriptinc/melgan-neurips", "load_melgan", "multi_speaker"
121
+ )
122
+ vocoder.mel2wav.eval()
123
+ vocoder.mel2wav.to(device)
124
+ elif name == "HiFi-GAN":
125
+ if mel_bins == 64:
126
+ config = get_vocoder_config()
127
+ config = hifigan.AttrDict(config)
128
+ vocoder = hifigan.Generator_old(config)
129
+ # print("Load hifigan/g_01080000")
130
+ # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_01080000"))
131
+ # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_00660000"))
132
+ # ckpt = torch_version_orig_mod_remove(ckpt)
133
+ # vocoder.load_state_dict(ckpt["generator"])
134
+ vocoder.eval()
135
+ vocoder.remove_weight_norm()
136
+ vocoder.to(device)
137
+ else:
138
+ config = get_vocoder_config_48k()
139
+ config = hifigan.AttrDict(config)
140
+ vocoder = hifigan.Generator_old(config)
141
+ # print("Load hifigan/g_01080000")
142
+ # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_01080000"))
143
+ # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_00660000"))
144
+ # ckpt = torch_version_orig_mod_remove(ckpt)
145
+ # vocoder.load_state_dict(ckpt["generator"])
146
+ vocoder.eval()
147
+ vocoder.remove_weight_norm()
148
+ vocoder.to(device)
149
+ return vocoder
150
+
151
+
152
+ def vocoder_infer(mels, vocoder, lengths=None):
153
+ with torch.no_grad():
154
+ wavs = vocoder(mels).squeeze(1)
155
+
156
+ wavs = (wavs.cpu().numpy() * 32768).astype("int16")
157
+
158
+ if lengths is not None:
159
+ wavs = wavs[:, :lengths]
160
+
161
+ # wavs = [wav for wav in wavs]
162
+
163
+ # for i in range(len(mels)):
164
+ # if lengths is not None:
165
+ # wavs[i] = wavs[i][: lengths[i]]
166
+
167
+ return wavs
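
A minimal usage sketch for the helpers above (illustration, not part of the commit): passing a `mel_bins` value other than 64 selects the 48 kHz configuration. Note that the checkpoint-loading lines inside `get_vocoder` are commented out, so the generator here runs with random weights.

```python
import torch
from FlashSR.AudioSR.Vocoder import get_vocoder, vocoder_infer

vocoder = get_vocoder(None, "cpu", mel_bins=256)  # 256 mel bins -> 48 kHz HiFi-GAN config
mels = torch.randn(2, 256, 100)                   # (batch, num_mels, frames)
wavs = vocoder_infer(mels, vocoder)               # int16 numpy array
print(wavs.shape, wavs.dtype)                     # about 480 output samples per mel frame
```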
FlashSR/AudioSR/args/mel_argument.yaml ADDED
@@ -0,0 +1,6 @@
1
+ nfft: 2048
2
+ hop_size: 480
3
+ sample_rate: 48000
4
+ mel_size: 256
5
+ frequency_min: 20
6
+ frequency_max: 24000
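
As a rough illustration only, these fields could be mapped onto a standard mel-spectrogram front end as below; the repository's actual feature extraction is not shown in this file and may differ in windowing, scaling, or log compression.

```python
import yaml
import torchaudio

with open("FlashSR/AudioSR/args/mel_argument.yaml") as f:
    a = yaml.safe_load(f)

# Hypothetical mapping of the YAML keys onto torchaudio's MelSpectrogram.
mel = torchaudio.transforms.MelSpectrogram(
    sample_rate=a["sample_rate"],   # 48000
    n_fft=a["nfft"],                # 2048
    hop_length=a["hop_size"],       # 480 -> 100 frames per second
    n_mels=a["mel_size"],           # 256
    f_min=a["frequency_min"],       # 20
    f_max=a["frequency_max"],       # 24000 (Nyquist at 48 kHz)
)
```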
FlashSR/AudioSR/args/model_argument.yaml ADDED
@@ -0,0 +1,25 @@
1
+ batchsize: 4
2
+ ddconfig:
3
+ attn_resolutions: []
4
+ ch: 128
5
+ ch_mult:
6
+ - 1
7
+ - 2
8
+ - 4
9
+ - 8
10
+ double_z: true
11
+ downsample_time: false
12
+ dropout: 0.1
13
+ in_channels: 1
14
+ mel_bins: 256
15
+ num_res_blocks: 2
16
+ out_ch: 1
17
+ resolution: 256
18
+ z_channels: 16
19
+ embed_dim: 16
20
+ image_key: fbank
21
+ monitor: val/rec_loss
22
+ reload_from_ckpt: /mnt/bn/lqhaoheliu/project/audio_generation_diffusion/log/vae/vae_48k_256/ds_8_kl_1/checkpoints/ckpt-checkpoint-484999.ckpt
23
+ sampling_rate: 48000
24
+ subband: 1
25
+ time_shuffle: 1
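
The `ddconfig` block feeds the `AutoencoderKL` constructor added in `autoencoder.py` below, and `embed_dim` sizes its latent. A minimal loading sketch (note that the `reload_from_ckpt` entry points at an internal cluster path and is only stored, not loaded, by the constructor):

```python
import yaml
from FlashSR.AudioSR.autoencoder import AutoencoderKL

with open("FlashSR/AudioSR/args/model_argument.yaml") as f:
    cfg = yaml.safe_load(f)

vae = AutoencoderKL(
    ddconfig=cfg["ddconfig"],
    embed_dim=cfg["embed_dim"],
    image_key=cfg["image_key"],          # "fbank" also builds the HiFi-GAN vocoder
    sampling_rate=cfg["sampling_rate"],
    subband=cfg["subband"],
)
```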
FlashSR/AudioSR/autoencoder.py ADDED
@@ -0,0 +1,370 @@
1
+ import os
2
+ import soundfile as sf
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from FlashSR.AudioSR.EncoderDecoder import Encoder, Decoder
9
+ from FlashSR.AudioSR.Vocoder import get_vocoder
10
+
11
+
12
+ class AutoencoderKL(nn.Module):
13
+ def __init__(
14
+ self,
15
+ ddconfig=None,
16
+ lossconfig=None,
17
+ batchsize=None,
18
+ embed_dim=None,
19
+ time_shuffle=1,
20
+ subband=1,
21
+ sampling_rate=16000,
22
+ ckpt_path=None,
23
+ reload_from_ckpt=None,
24
+ ignore_keys=[],
25
+ image_key="fbank",
26
+ colorize_nlabels=None,
27
+ monitor=None,
28
+ base_learning_rate=1e-5,
29
+ ):
30
+ super().__init__()
31
+ self.automatic_optimization = False
32
+ assert (
33
+ "mel_bins" in ddconfig.keys()
34
+ ), "mel_bins is not specified in the Autoencoder config"
35
+ num_mel = ddconfig["mel_bins"]
36
+ self.image_key = image_key
37
+ self.sampling_rate = sampling_rate
38
+ self.encoder = Encoder(**ddconfig)
39
+ self.decoder = Decoder(**ddconfig)
40
+
41
+ self.loss = None
42
+ self.subband = int(subband)
43
+
44
+ if self.subband > 1:
45
+ print("Use subband decomposition %s" % self.subband)
46
+
47
+ assert ddconfig["double_z"]
48
+ self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1)
49
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
50
+
51
+ if self.image_key == "fbank":
52
+ self.vocoder = get_vocoder(None, "cpu", num_mel)
53
+ self.embed_dim = embed_dim
54
+ if colorize_nlabels is not None:
55
+ assert type(colorize_nlabels) == int
56
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
57
+ if monitor is not None:
58
+ self.monitor = monitor
59
+ if ckpt_path is not None:
60
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
61
+ self.learning_rate = float(base_learning_rate)
62
+ # print("Initial learning rate %s" % self.learning_rate)
63
+
64
+ self.time_shuffle = time_shuffle
65
+ self.reload_from_ckpt = reload_from_ckpt
66
+ self.reloaded = False
67
+ self.mean, self.std = None, None
68
+
69
+ self.feature_cache = None
70
+ self.flag_first_run = True
71
+ self.train_step = 0
72
+
73
+ self.logger_save_dir = None
74
+ self.logger_exp_name = None
75
+
76
+ def get_log_dir(self):
77
+ if self.logger_save_dir is None and self.logger_exp_name is None:
78
+ return os.path.join(self.logger.save_dir, self.logger._project)
79
+ else:
80
+ return os.path.join(self.logger_save_dir, self.logger_exp_name)
81
+
82
+ def set_log_dir(self, save_dir, exp_name):
83
+ self.logger_save_dir = save_dir
84
+ self.logger_exp_name = exp_name
85
+
86
+ def init_from_ckpt(self, path, ignore_keys=list()):
87
+ sd = torch.load(path, map_location="cpu")["state_dict"]
88
+ keys = list(sd.keys())
89
+ for k in keys:
90
+ for ik in ignore_keys:
91
+ if k.startswith(ik):
92
+ print("Deleting key {} from state_dict.".format(k))
93
+ del sd[k]
94
+ self.load_state_dict(sd, strict=False)
95
+ print(f"Restored from {path}")
96
+
97
+ def encode(self, x):
98
+ # x = self.time_shuffle_operation(x)
99
+ # x = self.freq_split_subband(x)
100
+ h = self.encoder(x)
101
+ moments = self.quant_conv(h)
102
+ posterior = DiagonalGaussianDistribution(moments)
103
+ return posterior
104
+
105
+ def decode(self, z):
106
+ z = self.post_quant_conv(z)
107
+ dec = self.decoder(z)
108
+ # bs, ch, shuffled_timesteps, fbins = dec.size()
109
+ # dec = self.time_unshuffle_operation(dec, bs, int(ch*shuffled_timesteps), fbins)
110
+ # dec = self.freq_merge_subband(dec)
111
+ return dec
112
+
113
+ def decode_to_waveform(self, dec):
114
+ from FlashSR.AudioSR.Vocoder import vocoder_infer  # use the vendored helper defined in this repo
115
+
116
+ if self.image_key == "fbank":
117
+ dec = dec.squeeze(1).permute(0, 2, 1)
118
+ wav_reconstruction = vocoder_infer(dec, self.vocoder)
119
+ elif self.image_key == "stft":
120
+ dec = dec.squeeze(1).permute(0, 2, 1)
121
+ wav_reconstruction = self.wave_decoder(dec)
122
+ return wav_reconstruction
123
+
124
+ def visualize_latent(self, input):
125
+ import matplotlib.pyplot as plt
126
+
127
+ # for i in range(10):
128
+ # zero_input = torch.zeros_like(input) - 11.59
129
+ # zero_input[:,:,i * 16: i * 16 + 16,:16] += 13.59
130
+
131
+ # posterior = self.encode(zero_input)
132
+ # latent = posterior.sample()
133
+ # avg_latent = torch.mean(latent, dim=1)[0]
134
+ # plt.imshow(avg_latent.cpu().detach().numpy().T)
135
+ # plt.savefig("%s.png" % i)
136
+ # plt.close()
137
+
138
+ np.save("input.npy", input.cpu().detach().numpy())
139
+ # zero_input = torch.zeros_like(input) - 11.59
140
+ time_input = input.clone()
141
+ time_input[:, :, :, :32] *= 0
142
+ time_input[:, :, :, :32] -= 11.59
143
+
144
+ np.save("time_input.npy", time_input.cpu().detach().numpy())
145
+
146
+ posterior = self.encode(time_input)
147
+ latent = posterior.sample()
148
+ np.save("time_latent.npy", latent.cpu().detach().numpy())
149
+ avg_latent = torch.mean(latent, dim=1)
150
+ for i in range(avg_latent.size(0)):
151
+ plt.imshow(avg_latent[i].cpu().detach().numpy().T)
152
+ plt.savefig("freq_%s.png" % i)
153
+ plt.close()
154
+
155
+ freq_input = input.clone()
156
+ freq_input[:, :, :512, :] *= 0
157
+ freq_input[:, :, :512, :] -= 11.59
158
+
159
+ np.save("freq_input.npy", freq_input.cpu().detach().numpy())
160
+
161
+ posterior = self.encode(freq_input)
162
+ latent = posterior.sample()
163
+ np.save("freq_latent.npy", latent.cpu().detach().numpy())
164
+ avg_latent = torch.mean(latent, dim=1)
165
+ for i in range(avg_latent.size(0)):
166
+ plt.imshow(avg_latent[i].cpu().detach().numpy().T)
167
+ plt.savefig("time_%s.png" % i)
168
+ plt.close()
169
+
170
+ def get_input(self, batch):
171
+ fname, text, label_indices, waveform, stft, fbank = (
172
+ batch["fname"],
173
+ batch["text"],
174
+ batch["label_vector"],
175
+ batch["waveform"],
176
+ batch["stft"],
177
+ batch["log_mel_spec"],
178
+ )
179
+ # if(self.time_shuffle != 1):
180
+ # if(fbank.size(1) % self.time_shuffle != 0):
181
+ # pad_len = self.time_shuffle - (fbank.size(1) % self.time_shuffle)
182
+ # fbank = torch.nn.functional.pad(fbank, (0,0,0,pad_len))
183
+
184
+ ret = {}
185
+
186
+ ret["fbank"], ret["stft"], ret["fname"], ret["waveform"] = (
187
+ fbank.unsqueeze(1),
188
+ stft.unsqueeze(1),
189
+ fname,
190
+ waveform.unsqueeze(1),
191
+ )
192
+
193
+ return ret
194
+
195
+ def save_wave(self, batch_wav, fname, save_dir):
196
+ os.makedirs(save_dir, exist_ok=True)
197
+
198
+ for wav, name in zip(batch_wav, fname):
199
+ name = os.path.basename(name)
200
+
201
+ sf.write(os.path.join(save_dir, name), wav, samplerate=self.sampling_rate)
202
+
203
+ def get_last_layer(self):
204
+ return self.decoder.conv_out.weight
205
+
206
+ @torch.no_grad()
207
+ def log_images(self, batch, train=True, only_inputs=False, waveform=None, **kwargs):
208
+ log = dict()
209
+ x = batch.to(self.device)
210
+ if not only_inputs:
211
+ xrec, posterior = self(x)
212
+ log["samples"] = self.decode(posterior.sample())
213
+ log["reconstructions"] = xrec
214
+
215
+ log["inputs"] = x
216
+ wavs = self._log_img(log, train=train, index=0, waveform=waveform)
217
+ return wavs
218
+
219
+ def _log_img(self, log, train=True, index=0, waveform=None):
220
+ images_input = self.tensor2numpy(log["inputs"][index, 0]).T
221
+ images_reconstruct = self.tensor2numpy(log["reconstructions"][index, 0]).T
222
+ images_samples = self.tensor2numpy(log["samples"][index, 0]).T
223
+
224
+ if train:
225
+ name = "train"
226
+ else:
227
+ name = "val"
228
+
229
+ if self.logger is not None:
230
+ self.logger.log_image(
231
+ "img_%s" % name,
232
+ [images_input, images_reconstruct, images_samples],
233
+ caption=["input", "reconstruct", "samples"],
234
+ )
235
+
236
+ inputs, reconstructions, samples = (
237
+ log["inputs"],
238
+ log["reconstructions"],
239
+ log["samples"],
240
+ )
241
+
242
+ if self.image_key == "fbank":
243
+ wav_original, wav_prediction = synth_one_sample(
244
+ inputs[index],
245
+ reconstructions[index],
246
+ labels="validation",
247
+ vocoder=self.vocoder,
248
+ )
249
+ wav_original, wav_samples = synth_one_sample(
250
+ inputs[index], samples[index], labels="validation", vocoder=self.vocoder
251
+ )
252
+ wav_original, wav_samples, wav_prediction = (
253
+ wav_original[0],
254
+ wav_samples[0],
255
+ wav_prediction[0],
256
+ )
257
+ elif self.image_key == "stft":
258
+ wav_prediction = (
259
+ self.decode_to_waveform(reconstructions)[index, 0]
260
+ .cpu()
261
+ .detach()
262
+ .numpy()
263
+ )
264
+ wav_samples = (
265
+ self.decode_to_waveform(samples)[index, 0].cpu().detach().numpy()
266
+ )
267
+ wav_original = waveform[index, 0].cpu().detach().numpy()
268
+
269
+ if self.logger is not None:
270
+ self.logger.experiment.log(
271
+ {
272
+ "original_%s"
273
+ % name: wandb.Audio(
274
+ wav_original, caption="original", sample_rate=self.sampling_rate
275
+ ),
276
+ "reconstruct_%s"
277
+ % name: wandb.Audio(
278
+ wav_prediction,
279
+ caption="reconstruct",
280
+ sample_rate=self.sampling_rate,
281
+ ),
282
+ "samples_%s"
283
+ % name: wandb.Audio(
284
+ wav_samples, caption="samples", sample_rate=self.sampling_rate
285
+ ),
286
+ }
287
+ )
288
+
289
+ return wav_original, wav_prediction, wav_samples
290
+
291
+ def tensor2numpy(self, tensor):
292
+ return tensor.cpu().detach().numpy()
293
+
294
+ def to_rgb(self, x):
295
+ assert self.image_key == "segmentation"
296
+ if not hasattr(self, "colorize"):
297
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
298
+ x = F.conv2d(x, weight=self.colorize)
299
+ x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0
300
+ return x
301
+
302
+
303
+ class IdentityFirstStage(torch.nn.Module):
304
+ def __init__(self, *args, vq_interface=False, **kwargs):
305
+ self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
306
+ super().__init__()
307
+
308
+ def encode(self, x, *args, **kwargs):
309
+ return x
310
+
311
+ def decode(self, x, *args, **kwargs):
312
+ return x
313
+
314
+ def quantize(self, x, *args, **kwargs):
315
+ if self.vq_interface:
316
+ return x, None, [None, None, None]
317
+ return x
318
+
319
+ def forward(self, x, *args, **kwargs):
320
+ return x
321
+
322
+ class DiagonalGaussianDistribution(object):
323
+ def __init__(self, parameters, deterministic=False):
324
+ self.parameters = parameters
325
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
326
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
327
+ self.deterministic = deterministic
328
+ self.std = torch.exp(0.5 * self.logvar)
329
+ self.var = torch.exp(self.logvar)
330
+ if self.deterministic:
331
+ self.var = self.std = torch.zeros_like(self.mean).to(
332
+ device=self.parameters.device
333
+ )
334
+
335
+ def sample(self):
336
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(
337
+ device=self.parameters.device
338
+ )
339
+ return x
340
+
341
+ def kl(self, other=None):
342
+ if self.deterministic:
343
+ return torch.Tensor([0.0])
344
+ else:
345
+ if other is None:
346
+ return 0.5 * torch.mean(
347
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
348
+ dim=[1, 2, 3],
349
+ )
350
+ else:
351
+ return 0.5 * torch.mean(
352
+ torch.pow(self.mean - other.mean, 2) / other.var
353
+ + self.var / other.var
354
+ - 1.0
355
+ - self.logvar
356
+ + other.logvar,
357
+ dim=[1, 2, 3],
358
+ )
359
+
360
+ def nll(self, sample, dims=[1, 2, 3]):
361
+ if self.deterministic:
362
+ return torch.Tensor([0.0])
363
+ logtwopi = np.log(2.0 * np.pi)
364
+ return 0.5 * torch.sum(
365
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
366
+ dim=dims,
367
+ )
368
+
369
+ def mode(self):
370
+ return self.mean
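
A sketch of the latent round trip exposed by `AutoencoderKL` and `DiagonalGaussianDistribution` (illustrative only: `image_key="stft"` is used here just to skip building the vocoder, and no checkpoint is loaded, so the reconstruction is untrained). The input layout is the same assumed `(batch, 1, frames, mel)` fbank as above.

```python
import torch
from FlashSR.AudioSR.autoencoder import AutoencoderKL

ddconfig = dict(
    ch=128, out_ch=1, ch_mult=[1, 2, 4, 8], num_res_blocks=2,
    attn_resolutions=[], dropout=0.1, in_channels=1,
    resolution=256, z_channels=16, double_z=True, mel_bins=256,
)
vae = AutoencoderKL(ddconfig=ddconfig, embed_dim=16, image_key="stft", sampling_rate=48000)

fbank = torch.randn(1, 1, 64, 256)   # (batch, 1, frames, mel bins) -- assumed layout
posterior = vae.encode(fbank)        # DiagonalGaussianDistribution over a (1, 16, 8, 32) latent
z = posterior.sample()               # reparameterised sample; posterior.mode() returns the mean
recon = vae.decode(z)                # back to (1, 1, 64, 256)
print(z.shape, recon.shape, posterior.kl().shape)
```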
FlashSR/AudioSR/hifigan/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2020 Jungil Kong
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
FlashSR/AudioSR/hifigan/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ from .models_v2 import Generator
2
+ from .models import Generator as Generator_old
3
+
4
+
5
+ class AttrDict(dict):
6
+ def __init__(self, *args, **kwargs):
7
+ super(AttrDict, self).__init__(*args, **kwargs)
8
+ self.__dict__ = self
FlashSR/AudioSR/hifigan/models.py ADDED
@@ -0,0 +1,174 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn import Conv1d, ConvTranspose1d
5
+ from torch.nn.utils import weight_norm, remove_weight_norm
6
+
7
+ LRELU_SLOPE = 0.1
8
+
9
+
10
+ def init_weights(m, mean=0.0, std=0.01):
11
+ classname = m.__class__.__name__
12
+ if classname.find("Conv") != -1:
13
+ m.weight.data.normal_(mean, std)
14
+
15
+
16
+ def get_padding(kernel_size, dilation=1):
17
+ return int((kernel_size * dilation - dilation) / 2)
18
+
19
+
20
+ class ResBlock(torch.nn.Module):
21
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
22
+ super(ResBlock, self).__init__()
23
+ self.h = h
24
+ self.convs1 = nn.ModuleList(
25
+ [
26
+ weight_norm(
27
+ Conv1d(
28
+ channels,
29
+ channels,
30
+ kernel_size,
31
+ 1,
32
+ dilation=dilation[0],
33
+ padding=get_padding(kernel_size, dilation[0]),
34
+ )
35
+ ),
36
+ weight_norm(
37
+ Conv1d(
38
+ channels,
39
+ channels,
40
+ kernel_size,
41
+ 1,
42
+ dilation=dilation[1],
43
+ padding=get_padding(kernel_size, dilation[1]),
44
+ )
45
+ ),
46
+ weight_norm(
47
+ Conv1d(
48
+ channels,
49
+ channels,
50
+ kernel_size,
51
+ 1,
52
+ dilation=dilation[2],
53
+ padding=get_padding(kernel_size, dilation[2]),
54
+ )
55
+ ),
56
+ ]
57
+ )
58
+ self.convs1.apply(init_weights)
59
+
60
+ self.convs2 = nn.ModuleList(
61
+ [
62
+ weight_norm(
63
+ Conv1d(
64
+ channels,
65
+ channels,
66
+ kernel_size,
67
+ 1,
68
+ dilation=1,
69
+ padding=get_padding(kernel_size, 1),
70
+ )
71
+ ),
72
+ weight_norm(
73
+ Conv1d(
74
+ channels,
75
+ channels,
76
+ kernel_size,
77
+ 1,
78
+ dilation=1,
79
+ padding=get_padding(kernel_size, 1),
80
+ )
81
+ ),
82
+ weight_norm(
83
+ Conv1d(
84
+ channels,
85
+ channels,
86
+ kernel_size,
87
+ 1,
88
+ dilation=1,
89
+ padding=get_padding(kernel_size, 1),
90
+ )
91
+ ),
92
+ ]
93
+ )
94
+ self.convs2.apply(init_weights)
95
+
96
+ def forward(self, x):
97
+ for c1, c2 in zip(self.convs1, self.convs2):
98
+ xt = F.leaky_relu(x, LRELU_SLOPE)
99
+ xt = c1(xt)
100
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
101
+ xt = c2(xt)
102
+ x = xt + x
103
+ return x
104
+
105
+ def remove_weight_norm(self):
106
+ for l in self.convs1:
107
+ remove_weight_norm(l)
108
+ for l in self.convs2:
109
+ remove_weight_norm(l)
110
+
111
+
112
+ class Generator(torch.nn.Module):
113
+ def __init__(self, h):
114
+ super(Generator, self).__init__()
115
+ self.h = h
116
+ self.num_kernels = len(h.resblock_kernel_sizes)
117
+ self.num_upsamples = len(h.upsample_rates)
118
+ self.conv_pre = weight_norm(
119
+ Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
120
+ )
121
+ resblock = ResBlock
122
+
123
+ self.ups = nn.ModuleList()
124
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
125
+ self.ups.append(
126
+ weight_norm(
127
+ ConvTranspose1d(
128
+ h.upsample_initial_channel // (2**i),
129
+ h.upsample_initial_channel // (2 ** (i + 1)),
130
+ k,
131
+ u,
132
+ padding=(k - u) // 2,
133
+ )
134
+ )
135
+ )
136
+
137
+ self.resblocks = nn.ModuleList()
138
+ for i in range(len(self.ups)):
139
+ ch = h.upsample_initial_channel // (2 ** (i + 1))
140
+ for j, (k, d) in enumerate(
141
+ zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
142
+ ):
143
+ self.resblocks.append(resblock(h, ch, k, d))
144
+
145
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
146
+ self.ups.apply(init_weights)
147
+ self.conv_post.apply(init_weights)
148
+
149
+ def forward(self, x):
150
+ x = self.conv_pre(x)
151
+ for i in range(self.num_upsamples):
152
+ x = F.leaky_relu(x, LRELU_SLOPE)
153
+ x = self.ups[i](x)
154
+ xs = None
155
+ for j in range(self.num_kernels):
156
+ if xs is None:
157
+ xs = self.resblocks[i * self.num_kernels + j](x)
158
+ else:
159
+ xs += self.resblocks[i * self.num_kernels + j](x)
160
+ x = xs / self.num_kernels
161
+ x = F.leaky_relu(x)
162
+ x = self.conv_post(x)
163
+ x = torch.tanh(x)
164
+
165
+ return x
166
+
167
+ def remove_weight_norm(self):
168
+ # print("Removing weight norm...")
169
+ for l in self.ups:
170
+ remove_weight_norm(l)
171
+ for l in self.resblocks:
172
+ l.remove_weight_norm()
173
+ remove_weight_norm(self.conv_pre)
174
+ remove_weight_norm(self.conv_post)
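
A quick shape check for the `Generator` above (re-exported as `Generator_old` in `hifigan/__init__.py`), using the 16 kHz settings returned by `get_vocoder_config()` in `Vocoder.py`; random weights, illustration only.

```python
import torch
from FlashSR.AudioSR import hifigan
from FlashSR.AudioSR.Vocoder import get_vocoder_config

h = hifigan.AttrDict(get_vocoder_config())   # 64 mel bins, upsample_rates [5, 4, 2, 2, 2]
gen = hifigan.Generator_old(h)
wav = gen(torch.randn(1, h.num_mels, 32))    # (batch, num_mels, frames)
print(wav.shape)                             # (1, 1, ~32 * 160): roughly 160 samples per frame
```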
FlashSR/AudioSR/hifigan/models_v2.py ADDED
@@ -0,0 +1,395 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from torch.nn import Conv1d, ConvTranspose1d
5
+ from torch.nn.utils import weight_norm, remove_weight_norm
6
+
7
+ LRELU_SLOPE = 0.1
8
+
9
+
10
+ def init_weights(m, mean=0.0, std=0.01):
11
+ classname = m.__class__.__name__
12
+ if classname.find("Conv") != -1:
13
+ m.weight.data.normal_(mean, std)
14
+
15
+
16
+ def get_padding(kernel_size, dilation=1):
17
+ return int((kernel_size * dilation - dilation) / 2)
18
+
19
+
20
+ class ResBlock1(torch.nn.Module):
21
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
22
+ super(ResBlock1, self).__init__()
23
+ self.h = h
24
+ self.convs1 = nn.ModuleList(
25
+ [
26
+ weight_norm(
27
+ Conv1d(
28
+ channels,
29
+ channels,
30
+ kernel_size,
31
+ 1,
32
+ dilation=dilation[0],
33
+ padding=get_padding(kernel_size, dilation[0]),
34
+ )
35
+ ),
36
+ weight_norm(
37
+ Conv1d(
38
+ channels,
39
+ channels,
40
+ kernel_size,
41
+ 1,
42
+ dilation=dilation[1],
43
+ padding=get_padding(kernel_size, dilation[1]),
44
+ )
45
+ ),
46
+ weight_norm(
47
+ Conv1d(
48
+ channels,
49
+ channels,
50
+ kernel_size,
51
+ 1,
52
+ dilation=dilation[2],
53
+ padding=get_padding(kernel_size, dilation[2]),
54
+ )
55
+ ),
56
+ ]
57
+ )
58
+ self.convs1.apply(init_weights)
59
+
60
+ self.convs2 = nn.ModuleList(
61
+ [
62
+ weight_norm(
63
+ Conv1d(
64
+ channels,
65
+ channels,
66
+ kernel_size,
67
+ 1,
68
+ dilation=1,
69
+ padding=get_padding(kernel_size, 1),
70
+ )
71
+ ),
72
+ weight_norm(
73
+ Conv1d(
74
+ channels,
75
+ channels,
76
+ kernel_size,
77
+ 1,
78
+ dilation=1,
79
+ padding=get_padding(kernel_size, 1),
80
+ )
81
+ ),
82
+ weight_norm(
83
+ Conv1d(
84
+ channels,
85
+ channels,
86
+ kernel_size,
87
+ 1,
88
+ dilation=1,
89
+ padding=get_padding(kernel_size, 1),
90
+ )
91
+ ),
92
+ ]
93
+ )
94
+ self.convs2.apply(init_weights)
95
+
96
+ def forward(self, x):
97
+ for c1, c2 in zip(self.convs1, self.convs2):
98
+ xt = F.leaky_relu(x, LRELU_SLOPE)
99
+ xt = c1(xt)
100
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
101
+ xt = c2(xt)
102
+ x = xt + x
103
+ return x
104
+
105
+ def remove_weight_norm(self):
106
+ for l in self.convs1:
107
+ remove_weight_norm(l)
108
+ for l in self.convs2:
109
+ remove_weight_norm(l)
110
+
111
+
112
+ class ResBlock2(torch.nn.Module):
113
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
114
+ super(ResBlock2, self).__init__()
115
+ self.h = h
116
+ self.convs = nn.ModuleList(
117
+ [
118
+ weight_norm(
119
+ Conv1d(
120
+ channels,
121
+ channels,
122
+ kernel_size,
123
+ 1,
124
+ dilation=dilation[0],
125
+ padding=get_padding(kernel_size, dilation[0]),
126
+ )
127
+ ),
128
+ weight_norm(
129
+ Conv1d(
130
+ channels,
131
+ channels,
132
+ kernel_size,
133
+ 1,
134
+ dilation=dilation[1],
135
+ padding=get_padding(kernel_size, dilation[1]),
136
+ )
137
+ ),
138
+ ]
139
+ )
140
+ self.convs.apply(init_weights)
141
+
142
+ def forward(self, x):
143
+ for c in self.convs:
144
+ xt = F.leaky_relu(x, LRELU_SLOPE)
145
+ xt = c(xt)
146
+ x = xt + x
147
+ return x
148
+
149
+ def remove_weight_norm(self):
150
+ for l in self.convs:
151
+ remove_weight_norm(l)
152
+
153
+
154
+ class Generator(torch.nn.Module):
155
+ def __init__(self, h):
156
+ super(Generator, self).__init__()
157
+ self.h = h
158
+ self.num_kernels = len(h.resblock_kernel_sizes)
159
+ self.num_upsamples = len(h.upsample_rates)
160
+ self.conv_pre = weight_norm(
161
+ Conv1d(256, h.upsample_initial_channel, 7, 1, padding=3)
162
+ )
163
+ resblock = ResBlock1 if h.resblock == "1" else ResBlock2
164
+
165
+ self.ups = nn.ModuleList()
166
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
167
+ self.ups.append(
168
+ weight_norm(
169
+ ConvTranspose1d(
170
+ h.upsample_initial_channel // (2**i),
171
+ h.upsample_initial_channel // (2 ** (i + 1)),
172
+ u * 2,
173
+ u,
174
+ padding=u // 2 + u % 2,
175
+ output_padding=u % 2,
176
+ )
177
+ )
178
+ )
179
+
180
+ self.resblocks = nn.ModuleList()
181
+ for i in range(len(self.ups)):
182
+ ch = h.upsample_initial_channel // (2 ** (i + 1))
183
+ for j, (k, d) in enumerate(
184
+ zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
185
+ ):
186
+ self.resblocks.append(resblock(h, ch, k, d))
187
+
188
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
189
+ self.ups.apply(init_weights)
190
+ self.conv_post.apply(init_weights)
191
+
192
+ def forward(self, x):
193
+ # import ipdb; ipdb.set_trace()
194
+ x = self.conv_pre(x)
195
+ for i in range(self.num_upsamples):
196
+ x = F.leaky_relu(x, LRELU_SLOPE)
197
+ x = self.ups[i](x)
198
+ xs = None
199
+ for j in range(self.num_kernels):
200
+ if xs is None:
201
+ xs = self.resblocks[i * self.num_kernels + j](x)
202
+ else:
203
+ xs += self.resblocks[i * self.num_kernels + j](x)
204
+ x = xs / self.num_kernels
205
+ x = F.leaky_relu(x)
206
+ x = self.conv_post(x)
207
+ x = torch.tanh(x)
208
+
209
+ return x
210
+
211
+ def remove_weight_norm(self):
212
+ # print('Removing weight norm...')
213
+ for l in self.ups:
214
+ remove_weight_norm(l)
215
+ for l in self.resblocks:
216
+ l.remove_weight_norm()
217
+ remove_weight_norm(self.conv_pre)
218
+ remove_weight_norm(self.conv_post)
219
+
220
+
221
+ ##################################################################################################
222
+
223
+ # import torch
224
+ # import torch.nn as nn
225
+ # import torch.nn.functional as F
226
+ # from torch.nn import Conv1d, ConvTranspose1d
227
+ # from torch.nn.utils import weight_norm, remove_weight_norm
228
+
229
+ # LRELU_SLOPE = 0.1
230
+
231
+
232
+ # def init_weights(m, mean=0.0, std=0.01):
233
+ # classname = m.__class__.__name__
234
+ # if classname.find("Conv") != -1:
235
+ # m.weight.data.normal_(mean, std)
236
+
237
+
238
+ # def get_padding(kernel_size, dilation=1):
239
+ # return int((kernel_size * dilation - dilation) / 2)
240
+
241
+
242
+ # class ResBlock(torch.nn.Module):
243
+ # def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
244
+ # super(ResBlock, self).__init__()
245
+ # self.h = h
246
+ # self.convs1 = nn.ModuleList(
247
+ # [
248
+ # weight_norm(
249
+ # Conv1d(
250
+ # channels,
251
+ # channels,
252
+ # kernel_size,
253
+ # 1,
254
+ # dilation=dilation[0],
255
+ # padding=get_padding(kernel_size, dilation[0]),
256
+ # )
257
+ # ),
258
+ # weight_norm(
259
+ # Conv1d(
260
+ # channels,
261
+ # channels,
262
+ # kernel_size,
263
+ # 1,
264
+ # dilation=dilation[1],
265
+ # padding=get_padding(kernel_size, dilation[1]),
266
+ # )
267
+ # ),
268
+ # weight_norm(
269
+ # Conv1d(
270
+ # channels,
271
+ # channels,
272
+ # kernel_size,
273
+ # 1,
274
+ # dilation=dilation[2],
275
+ # padding=get_padding(kernel_size, dilation[2]),
276
+ # )
277
+ # ),
278
+ # ]
279
+ # )
280
+ # self.convs1.apply(init_weights)
281
+
282
+ # self.convs2 = nn.ModuleList(
283
+ # [
284
+ # weight_norm(
285
+ # Conv1d(
286
+ # channels,
287
+ # channels,
288
+ # kernel_size,
289
+ # 1,
290
+ # dilation=1,
291
+ # padding=get_padding(kernel_size, 1),
292
+ # )
293
+ # ),
294
+ # weight_norm(
295
+ # Conv1d(
296
+ # channels,
297
+ # channels,
298
+ # kernel_size,
299
+ # 1,
300
+ # dilation=1,
301
+ # padding=get_padding(kernel_size, 1),
302
+ # )
303
+ # ),
304
+ # weight_norm(
305
+ # Conv1d(
306
+ # channels,
307
+ # channels,
308
+ # kernel_size,
309
+ # 1,
310
+ # dilation=1,
311
+ # padding=get_padding(kernel_size, 1),
312
+ # )
313
+ # ),
314
+ # ]
315
+ # )
316
+ # self.convs2.apply(init_weights)
317
+
318
+ # def forward(self, x):
319
+ # for c1, c2 in zip(self.convs1, self.convs2):
320
+ # xt = F.leaky_relu(x, LRELU_SLOPE)
321
+ # xt = c1(xt)
322
+ # xt = F.leaky_relu(xt, LRELU_SLOPE)
323
+ # xt = c2(xt)
324
+ # x = xt + x
325
+ # return x
326
+
327
+ # def remove_weight_norm(self):
328
+ # for l in self.convs1:
329
+ # remove_weight_norm(l)
330
+ # for l in self.convs2:
331
+ # remove_weight_norm(l)
332
+
333
+ # class Generator(torch.nn.Module):
334
+ # def __init__(self, h):
335
+ # super(Generator, self).__init__()
336
+ # self.h = h
337
+ # self.num_kernels = len(h.resblock_kernel_sizes)
338
+ # self.num_upsamples = len(h.upsample_rates)
339
+ # self.conv_pre = weight_norm(
340
+ # Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
341
+ # )
342
+ # resblock = ResBlock
343
+
344
+ # self.ups = nn.ModuleList()
345
+ # for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
346
+ # self.ups.append(
347
+ # weight_norm(
348
+ # ConvTranspose1d(
349
+ # h.upsample_initial_channel // (2**i),
350
+ # h.upsample_initial_channel // (2 ** (i + 1)),
351
+ # k,
352
+ # u,
353
+ # padding=(k - u) // 2,
354
+ # )
355
+ # )
356
+ # )
357
+
358
+ # self.resblocks = nn.ModuleList()
359
+ # for i in range(len(self.ups)):
360
+ # ch = h.upsample_initial_channel // (2 ** (i + 1))
361
+ # for j, (k, d) in enumerate(
362
+ # zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
363
+ # ):
364
+ # self.resblocks.append(resblock(h, ch, k, d))
365
+
366
+ # self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
367
+ # self.ups.apply(init_weights)
368
+ # self.conv_post.apply(init_weights)
369
+
370
+ # def forward(self, x):
371
+ # x = self.conv_pre(x)
372
+ # for i in range(self.num_upsamples):
373
+ # x = F.leaky_relu(x, LRELU_SLOPE)
374
+ # x = self.ups[i](x)
375
+ # xs = None
376
+ # for j in range(self.num_kernels):
377
+ # if xs is None:
378
+ # xs = self.resblocks[i * self.num_kernels + j](x)
379
+ # else:
380
+ # xs += self.resblocks[i * self.num_kernels + j](x)
381
+ # x = xs / self.num_kernels
382
+ # x = F.leaky_relu(x)
383
+ # x = self.conv_post(x)
384
+ # x = torch.tanh(x)
385
+
386
+ # return x
387
+
388
+ # def remove_weight_norm(self):
389
+ # print("Removing weight norm...")
390
+ # for l in self.ups:
391
+ # remove_weight_norm(l)
392
+ # for l in self.resblocks:
393
+ # l.remove_weight_norm()
394
+ # remove_weight_norm(self.conv_pre)
395
+ # remove_weight_norm(self.conv_post)
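
The `models_v2` variant (exported as `hifigan.Generator`) differs from the module above in two visible ways: `conv_pre` hard-codes a 256-channel input instead of `h.num_mels`, and each transposed convolution uses kernel size `2*u` with output padding, so the output length is exactly `prod(upsample_rates)` times the number of input frames. A short shape sketch (random weights, illustration only):

```python
import torch
from FlashSR.AudioSR import hifigan
from FlashSR.AudioSR.Vocoder import get_vocoder_config_48k

h = hifigan.AttrDict(get_vocoder_config_48k())   # resblock "1", upsample_rates [6, 5, 4, 2, 2]
gen = hifigan.Generator(h)                       # models_v2.Generator
wav = gen(torch.randn(1, 256, 10))               # input must have 256 channels
print(wav.shape)                                 # (1, 1, 4800) = 10 frames * 480
```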
FlashSR/AudioSR/latent_diffusion/__init__.py ADDED
File without changes
FlashSR/AudioSR/latent_diffusion/modules/attention.py ADDED
@@ -0,0 +1,467 @@
1
+ from inspect import isfunction
2
+ import math
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn, einsum
6
+ from einops import rearrange, repeat
7
+
8
+ from FlashSR.AudioSR.latent_diffusion.modules.diffusionmodules.util import checkpoint
9
+
10
+
11
+ def exists(val):
12
+ return val is not None
13
+
14
+
15
+ def uniq(arr):
16
+ return {el: True for el in arr}.keys()
17
+
18
+
19
+ def default(val, d):
20
+ if exists(val):
21
+ return val
22
+ return d() if isfunction(d) else d
23
+
24
+
25
+ def max_neg_value(t):
26
+ return -torch.finfo(t.dtype).max
27
+
28
+
29
+ def init_(tensor):
30
+ dim = tensor.shape[-1]
31
+ std = 1 / math.sqrt(dim)
32
+ tensor.uniform_(-std, std)
33
+ return tensor
34
+
35
+
36
+ # feedforward
37
+ class GEGLU(nn.Module):
38
+ def __init__(self, dim_in, dim_out):
39
+ super().__init__()
40
+ self.proj = nn.Linear(dim_in, dim_out * 2)
41
+
42
+ def forward(self, x):
43
+ x, gate = self.proj(x).chunk(2, dim=-1)
44
+ return x * F.gelu(gate)
45
+
46
+
47
+ class FeedForward(nn.Module):
48
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
49
+ super().__init__()
50
+ inner_dim = int(dim * mult)
51
+ dim_out = default(dim_out, dim)
52
+ project_in = (
53
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
54
+ if not glu
55
+ else GEGLU(dim, inner_dim)
56
+ )
57
+
58
+ self.net = nn.Sequential(
59
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
60
+ )
61
+
62
+ def forward(self, x):
63
+ return self.net(x)
64
+
65
+
66
+ def zero_module(module):
67
+ """
68
+ Zero out the parameters of a module and return it.
69
+ """
70
+ for p in module.parameters():
71
+ p.detach().zero_()
72
+ return module
73
+
74
+
75
+ def Normalize(in_channels):
76
+ return torch.nn.GroupNorm(
77
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
78
+ )
79
+
80
+
81
+ class LinearAttention(nn.Module):
82
+ def __init__(self, dim, heads=4, dim_head=32):
83
+ super().__init__()
84
+ self.heads = heads
85
+ hidden_dim = dim_head * heads
86
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
87
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
88
+
89
+ def forward(self, x):
90
+ b, c, h, w = x.shape
91
+ qkv = self.to_qkv(x)
92
+ q, k, v = rearrange(
93
+ qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
94
+ )
95
+ k = k.softmax(dim=-1)
96
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
97
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
98
+ out = rearrange(
99
+ out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
100
+ )
101
+ return self.to_out(out)
102
+
103
+
104
+ class SpatialSelfAttention(nn.Module):
105
+ def __init__(self, in_channels):
106
+ super().__init__()
107
+ self.in_channels = in_channels
108
+
109
+ self.norm = Normalize(in_channels)
110
+ self.q = torch.nn.Conv2d(
111
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
112
+ )
113
+ self.k = torch.nn.Conv2d(
114
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
115
+ )
116
+ self.v = torch.nn.Conv2d(
117
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
118
+ )
119
+ self.proj_out = torch.nn.Conv2d(
120
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
121
+ )
122
+
123
+ def forward(self, x):
124
+ h_ = x
125
+ h_ = self.norm(h_)
126
+ q = self.q(h_)
127
+ k = self.k(h_)
128
+ v = self.v(h_)
129
+
130
+ # compute attention
131
+ b, c, h, w = q.shape
132
+ q = rearrange(q, "b c h w -> b (h w) c")
133
+ k = rearrange(k, "b c h w -> b c (h w)")
134
+ w_ = torch.einsum("bij,bjk->bik", q, k)
135
+
136
+ w_ = w_ * (int(c) ** (-0.5))
137
+ w_ = torch.nn.functional.softmax(w_, dim=2)
138
+
139
+ # attend to values
140
+ v = rearrange(v, "b c h w -> b c (h w)")
141
+ w_ = rearrange(w_, "b i j -> b j i")
142
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
143
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
144
+ h_ = self.proj_out(h_)
145
+
146
+ return x + h_
147
+
148
+
149
+ # class CrossAttention(nn.Module):
150
+ # """
151
+ # ### Cross Attention Layer
152
+ # This falls back to self-attention when conditional embeddings are not specified.
153
+ # """
154
+
155
+ # use_flash_attention: bool = True
156
+
157
+ # # use_flash_attention: bool = False
158
+ # def __init__(
159
+ # self,
160
+ # query_dim,
161
+ # context_dim=None,
162
+ # heads=8,
163
+ # dim_head=64,
164
+ # dropout=0.0,
165
+ # is_inplace: bool = True,
166
+ # ):
167
+ # # def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True):
168
+ # """
169
+ # :param d_model: is the input embedding size
170
+ # :param n_heads: is the number of attention heads
171
+ # :param d_head: is the size of an attention head
172
+ # :param d_cond: is the size of the conditional embeddings
173
+ # :param is_inplace: specifies whether to perform the attention softmax computation inplace to
174
+ # save memory
175
+ # """
176
+ # super().__init__()
177
+
178
+ # self.is_inplace = is_inplace
179
+ # self.n_heads = heads
180
+ # self.d_head = dim_head
181
+
182
+ # # Attention scaling factor
183
+ # self.scale = dim_head**-0.5
184
+
185
+ # # The normal self-attention layer
186
+ # if context_dim is None:
187
+ # context_dim = query_dim
188
+
189
+ # # Query, key and value mappings
190
+ # d_attn = dim_head * heads
191
+ # self.to_q = nn.Linear(query_dim, d_attn, bias=False)
192
+ # self.to_k = nn.Linear(context_dim, d_attn, bias=False)
193
+ # self.to_v = nn.Linear(context_dim, d_attn, bias=False)
194
+
195
+ # # Final linear layer
196
+ # self.to_out = nn.Sequential(nn.Linear(d_attn, query_dim), nn.Dropout(dropout))
197
+
198
+ # # Setup [flash attention](https://github.com/HazyResearch/flash-attention).
199
+ # # Flash attention is only used if it's installed
200
+ # # and `CrossAttention.use_flash_attention` is set to `True`.
201
+ # try:
202
+ # # You can install flash attention by cloning their Github repo,
203
+ # # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention)
204
+ # # and then running `python setup.py install`
205
+ # from flash_attn.flash_attention import FlashAttention
206
+
207
+ # self.flash = FlashAttention()
208
+ # # Set the scale for scaled dot-product attention.
209
+ # self.flash.softmax_scale = self.scale
210
+ # # Set to `None` if it's not installed
211
+ # except ImportError:
212
+ # self.flash = None
213
+
214
+ # def forward(self, x, context=None, mask=None):
215
+ # """
216
+ # :param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
217
+ # :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
218
+ # """
219
+
220
+ # # If `cond` is `None` we perform self attention
221
+ # has_cond = context is not None
222
+ # if not has_cond:
223
+ # context = x
224
+
225
+ # # Get query, key and value vectors
226
+ # q = self.to_q(x)
227
+ # k = self.to_k(context)
228
+ # v = self.to_v(context)
229
+
230
+ # # Use flash attention if it's available and the head size is less than or equal to `128`
231
+ # if (
232
+ # CrossAttention.use_flash_attention
233
+ # and self.flash is not None
234
+ # and not has_cond
235
+ # and self.d_head <= 128
236
+ # ):
237
+ # return self.flash_attention(q, k, v)
238
+ # # Otherwise, fallback to normal attention
239
+ # else:
240
+ # return self.normal_attention(q, k, v)
241
+
242
+ # def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
243
+ # """
244
+ # #### Flash Attention
245
+ # :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
246
+ # :param k: are the key vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
247
+ # :param v: are the value vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
248
+ # """
249
+
250
+ # # Get batch size and number of elements along sequence axis (`width * height`)
251
+ # batch_size, seq_len, _ = q.shape
252
+
253
+ # # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of
254
+ # # shape `[batch_size, seq_len, 3, n_heads * d_head]`
255
+ # qkv = torch.stack((q, k, v), dim=2)
256
+ # # Split the heads
257
+ # qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)
258
+
259
+ # # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to
260
+ # # fit this size.
261
+ # if self.d_head <= 32:
262
+ # pad = 32 - self.d_head
263
+ # elif self.d_head <= 64:
264
+ # pad = 64 - self.d_head
265
+ # elif self.d_head <= 128:
266
+ # pad = 128 - self.d_head
267
+ # else:
268
+ # raise ValueError(f"Head size ${self.d_head} too large for Flash Attention")
269
+
270
+ # # Pad the heads
271
+ # if pad:
272
+ # qkv = torch.cat(
273
+ # (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1
274
+ # )
275
+
276
+ # # Compute attention
277
+ # # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
278
+ # # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]`
279
+ # # TODO here I add the dtype changing
280
+ # out, _ = self.flash(qkv.type(torch.float16))
281
+ # # Truncate the extra head size
282
+ # out = out[:, :, :, : self.d_head].float()
283
+ # # Reshape to `[batch_size, seq_len, n_heads * d_head]`
284
+ # out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)
285
+
286
+ # # Map to `[batch_size, height * width, d_model]` with a linear layer
287
+ # return self.to_out(out)
288
+
289
+ # def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
290
+ # """
291
+ # #### Normal Attention
292
+
293
+ # :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
294
+ # :param k: are the key vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
295
+ # :param v: are the value vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
296
+ # """
297
+
298
+ # # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
299
+ # q = q.view(*q.shape[:2], self.n_heads, -1) # [bs, 64, 20, 32]
300
+ # k = k.view(*k.shape[:2], self.n_heads, -1) # [bs, 1, 20, 32]
301
+ # v = v.view(*v.shape[:2], self.n_heads, -1)
302
+
303
+ # # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
304
+ # attn = torch.einsum("bihd,bjhd->bhij", q, k) * self.scale
305
+
306
+ # # Compute softmax
307
+ # # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
308
+ # if self.is_inplace:
309
+ # half = attn.shape[0] // 2
310
+ # attn[half:] = attn[half:].softmax(dim=-1)
311
+ # attn[:half] = attn[:half].softmax(dim=-1)
312
+ # else:
313
+ # attn = attn.softmax(dim=-1)
314
+
315
+ # # Compute attention output
316
+ # # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
317
+ # # attn: [bs, 20, 64, 1]
318
+ # # v: [bs, 1, 20, 32]
319
+ # out = torch.einsum("bhij,bjhd->bihd", attn, v)
320
+ # # Reshape to `[batch_size, height * width, n_heads * d_head]`
321
+ # out = out.reshape(*out.shape[:2], -1)
322
+ # # Map to `[batch_size, height * width, d_model]` with a linear layer
323
+ # return self.to_out(out)
324
+
325
+
326
+ class CrossAttention(nn.Module):
327
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
328
+ super().__init__()
329
+ inner_dim = dim_head * heads
330
+ context_dim = default(context_dim, query_dim)
331
+
332
+ self.scale = dim_head**-0.5
333
+ self.heads = heads
334
+
335
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
336
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
337
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
338
+
339
+ self.to_out = nn.Sequential(
340
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
341
+ )
342
+
343
+ def forward(self, x, context=None, mask=None):
344
+ h = self.heads
345
+
346
+ q = self.to_q(x)
347
+ context = default(context, x)
348
+
349
+ k = self.to_k(context)
350
+ v = self.to_v(context)
351
+
352
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
353
+
354
+ sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
355
+
356
+ if exists(mask):
357
+ mask = rearrange(mask, "b ... -> b (...)")
358
+ max_neg_value = -torch.finfo(sim.dtype).max
359
+ mask = repeat(mask, "b j -> (b h) () j", h=h)
360
+ sim.masked_fill_(~(mask == 1), max_neg_value)
361
+
362
+ # attention, what we cannot get enough of
363
+ attn = sim.softmax(dim=-1)
364
+
365
+ out = einsum("b i j, b j d -> b i d", attn, v)
366
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
367
+ return self.to_out(out)
368
+
369
+
370
+ class BasicTransformerBlock(nn.Module):
371
+ def __init__(
372
+ self,
373
+ dim,
374
+ n_heads,
375
+ d_head,
376
+ dropout=0.0,
377
+ context_dim=None,
378
+ gated_ff=True,
379
+ checkpoint=True,
380
+ ):
381
+ super().__init__()
382
+ self.attn1 = CrossAttention(
383
+ query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
384
+ ) # is a self-attention
385
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
386
+ self.attn2 = CrossAttention(
387
+ query_dim=dim,
388
+ context_dim=context_dim,
389
+ heads=n_heads,
390
+ dim_head=d_head,
391
+ dropout=dropout,
392
+ ) # is self-attn if context is none
393
+ self.norm1 = nn.LayerNorm(dim)
394
+ self.norm2 = nn.LayerNorm(dim)
395
+ self.norm3 = nn.LayerNorm(dim)
396
+ self.checkpoint = checkpoint
397
+
398
+ def forward(self, x, context=None, mask=None):
399
+ if context is None:
400
+ return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint)
401
+ else:
402
+ return checkpoint(
403
+ self._forward, (x, context, mask), self.parameters(), self.checkpoint
404
+ )
405
+
406
+ def _forward(self, x, context=None, mask=None):
407
+ x = self.attn1(self.norm1(x)) + x
408
+ x = self.attn2(self.norm2(x), context=context, mask=mask) + x
409
+ x = self.ff(self.norm3(x)) + x
410
+ return x
411
+
412
+
413
+ class SpatialTransformer(nn.Module):
414
+ """
415
+ Transformer block for image-like data.
416
+ First, project the input (aka embedding)
417
+ and reshape to b, t, d.
418
+ Then apply standard transformer action.
419
+ Finally, reshape to image
420
+ """
421
+
422
+ def __init__(
423
+ self,
424
+ in_channels,
425
+ n_heads,
426
+ d_head,
427
+ depth=1,
428
+ dropout=0.0,
429
+ context_dim=None,
430
+ ):
431
+ super().__init__()
432
+
433
+ context_dim = context_dim
434
+
435
+ self.in_channels = in_channels
436
+ inner_dim = n_heads * d_head
437
+ self.norm = Normalize(in_channels)
438
+
439
+ self.proj_in = nn.Conv2d(
440
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
441
+ )
442
+
443
+ self.transformer_blocks = nn.ModuleList(
444
+ [
445
+ BasicTransformerBlock(
446
+ inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim
447
+ )
448
+ for d in range(depth)
449
+ ]
450
+ )
451
+
452
+ self.proj_out = zero_module(
453
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
454
+ )
455
+
456
+ def forward(self, x, context=None, mask=None):
457
+ # note: if no context is given, cross-attention defaults to self-attention
458
+ b, c, h, w = x.shape
459
+ x_in = x
460
+ x = self.norm(x)
461
+ x = self.proj_in(x)
462
+ x = rearrange(x, "b c h w -> b (h w) c")
463
+ for block in self.transformer_blocks:
464
+ x = block(x, context=context, mask=mask)
465
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
466
+ x = self.proj_out(x)
467
+ return x + x_in
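A minimal usage sketch for the SpatialTransformer defined above (illustrative only: the channel count, head configuration, and spatial size are assumptions, and the import assumes the repository root is on PYTHONPATH so the package path mirrors the file layout):

import torch
from FlashSR.AudioSR.latent_diffusion.modules.attention import SpatialTransformer

# in_channels must be divisible by 32 because Normalize() wraps GroupNorm(num_groups=32, ...)
block = SpatialTransformer(in_channels=64, n_heads=2, d_head=32, depth=1)
feats = torch.randn(1, 64, 8, 8)  # (batch, channels, height, width)
out = block(feats)                # context=None, so cross-attention falls back to self-attention
print(out.shape)                  # torch.Size([1, 64, 8, 8])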
FlashSR/AudioSR/latent_diffusion/modules/audiomae/AudioMAE.py ADDED
@@ -0,0 +1,149 @@
1
+ """
2
+ Reference Repo: https://github.com/facebookresearch/AudioMAE
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from timm.models.layers import to_2tuple
8
+ import FlashSR.AudioSR.latent_diffusion.modules.audiomae.models_vit as models_vit
9
+ import FlashSR.AudioSR.latent_diffusion.modules.audiomae.models_mae as models_mae
10
+
11
+ # model = mae_vit_base_patch16(in_chans=1, audio_exp=True, img_size=(1024, 128))
12
+
13
+
14
+ class PatchEmbed_new(nn.Module):
15
+ """Flexible Image to Patch Embedding"""
16
+
17
+ def __init__(
18
+ self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10
19
+ ):
20
+ super().__init__()
21
+ img_size = to_2tuple(img_size)
22
+ patch_size = to_2tuple(patch_size)
23
+ stride = to_2tuple(stride)
24
+
25
+ self.img_size = img_size
26
+ self.patch_size = patch_size
27
+
28
+ self.proj = nn.Conv2d(
29
+ in_chans, embed_dim, kernel_size=patch_size, stride=stride
30
+ ) # with overlapped patches
31
+ # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
32
+
33
+ # self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0])
34
+ # self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
35
+ _, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w
36
+ self.patch_hw = (h, w)
37
+ self.num_patches = h * w
38
+
39
+ def get_output_shape(self, img_size):
40
+ # todo: don't be lazy..
41
+ return self.proj(torch.randn(1, 1, img_size[0], img_size[1])).shape
42
+
43
+ def forward(self, x):
44
+ B, C, H, W = x.shape
45
+ # FIXME look at relaxing size constraints
46
+ # assert H == self.img_size[0] and W == self.img_size[1], \
47
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
48
+ x = self.proj(x)
49
+ x = x.flatten(2).transpose(1, 2)
50
+ return x
51
+
52
+
53
+ class AudioMAE(nn.Module):
54
+ """Audio Masked Autoencoder (MAE) pre-trained and finetuned on AudioSet (for SoundCLIP)"""
55
+
56
+ def __init__(
57
+ self,
58
+ ):
59
+ super().__init__()
60
+ model = models_vit.__dict__["vit_base_patch16"](
61
+ num_classes=527,
62
+ drop_path_rate=0.1,
63
+ global_pool=True,
64
+ mask_2d=True,
65
+ use_custom_patch=False,
66
+ )
67
+
68
+ img_size = (1024, 128)
69
+ emb_dim = 768
70
+
71
+ model.patch_embed = PatchEmbed_new(
72
+ img_size=img_size,
73
+ patch_size=(16, 16),
74
+ in_chans=1,
75
+ embed_dim=emb_dim,
76
+ stride=16,
77
+ )
78
+ num_patches = model.patch_embed.num_patches
79
+ # num_patches = 512 # assume audioset, 1024//16=64, 128//16=8, 512=64x8
80
+ model.pos_embed = nn.Parameter(
81
+ torch.zeros(1, num_patches + 1, emb_dim), requires_grad=False
82
+ ) # fixed sin-cos embedding
83
+
84
+ # checkpoint_path = '/mnt/bn/data-xubo/project/Masked_AudioEncoder/checkpoint/finetuned.pth'
85
+ # checkpoint = torch.load(checkpoint_path, map_location='cpu')
86
+ # msg = model.load_state_dict(checkpoint['model'], strict=False)
87
+ # print(f'Load AudioMAE from {checkpoint_path} / message: {msg}')
88
+
89
+ self.model = model
90
+
91
+ def forward(self, x, mask_t_prob=0.0, mask_f_prob=0.0):
92
+ """
93
+ x: mel fbank [Batch, 1, T, F]
94
+ mask_t_prob: 'T masking ratio (percentage of removed patches).'
95
+ mask_f_prob: 'F masking ratio (percentage of removed patches).'
96
+ """
97
+ return self.model(x=x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob)
98
+
99
+
100
+ class Vanilla_AudioMAE(nn.Module):
101
+ """Audio Masked Autoencoder (MAE) pre-trained on AudioSet (for AudioLDM2)"""
102
+
103
+ def __init__(
104
+ self,
105
+ ):
106
+ super().__init__()
107
+ model = models_mae.__dict__["mae_vit_base_patch16"](
108
+ in_chans=1, audio_exp=True, img_size=(1024, 128)
109
+ )
110
+
111
+ # checkpoint_path = '/mnt/bn/lqhaoheliu/exps/checkpoints/audiomae/pretrained.pth'
112
+ # checkpoint = torch.load(checkpoint_path, map_location='cpu')
113
+ # msg = model.load_state_dict(checkpoint['model'], strict=False)
114
+
115
+ # Skip the missing keys of decoder modules (not required)
116
+ # print(f'Load AudioMAE from {checkpoint_path} / message: {msg}')
117
+
118
+ self.model = model.eval()
119
+
120
+ def forward(self, x, mask_ratio=0.0, no_mask=False, no_average=False):
121
+ """
122
+ x: mel fbank [Batch, 1, 1024 (T), 128 (F)]
123
+ mask_ratio: 'masking ratio (percentage of removed patches).'
124
+ """
125
+ with torch.no_grad():
126
+ # embed: [B, 513, 768] for mask_ratio=0.0
127
+ if no_mask:
128
+ if no_average:
129
+ raise RuntimeError("This function is deprecated")
130
+ embed = self.model.forward_encoder_no_random_mask_no_average(
131
+ x
132
+ ) # mask_ratio
133
+ else:
134
+ embed = self.model.forward_encoder_no_mask(x) # mask_ratio
135
+ else:
136
+ raise RuntimeError("This function is deprecated")
137
+ embed, _, _, _ = self.model.forward_encoder(x, mask_ratio=mask_ratio)
138
+ return embed
139
+
140
+
141
+ if __name__ == "__main__":
142
+ model = Vanilla_AudioMAE().cuda()
143
+ input = torch.randn(4, 1, 1024, 128).cuda()
144
+ print("The first run")
145
+ embed = model(input, mask_ratio=0.0, no_mask=True)
146
+ print(embed)
147
+ print("The second run")
148
+ embed = model(input, mask_ratio=0.0)
149
+ print(embed)
FlashSR/AudioSR/latent_diffusion/modules/audiomae/__init__.py ADDED
File without changes
FlashSR/AudioSR/latent_diffusion/modules/audiomae/models_mae.py ADDED
@@ -0,0 +1,613 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
+ # DeiT: https://github.com/facebookresearch/deit
10
+ # --------------------------------------------------------
11
+
12
+ from functools import partial
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+ from timm.models.vision_transformer import Block
18
+ from FlashSR.AudioSR.latent_diffusion.modules.audiomae.util.pos_embed import (
19
+ get_2d_sincos_pos_embed,
20
+ get_2d_sincos_pos_embed_flexible,
21
+ )
22
+ from FlashSR.AudioSR.latent_diffusion.modules.audiomae.util.patch_embed import (
23
+ PatchEmbed_new,
24
+ PatchEmbed_org,
25
+ )
26
+
27
+
28
+ class MaskedAutoencoderViT(nn.Module):
29
+ """Masked Autoencoder with VisionTransformer backbone"""
30
+
31
+ def __init__(
32
+ self,
33
+ img_size=224,
34
+ patch_size=16,
35
+ stride=10,
36
+ in_chans=3,
37
+ embed_dim=1024,
38
+ depth=24,
39
+ num_heads=16,
40
+ decoder_embed_dim=512,
41
+ decoder_depth=8,
42
+ decoder_num_heads=16,
43
+ mlp_ratio=4.0,
44
+ norm_layer=nn.LayerNorm,
45
+ norm_pix_loss=False,
46
+ audio_exp=False,
47
+ alpha=0.0,
48
+ temperature=0.2,
49
+ mode=0,
50
+ contextual_depth=8,
51
+ use_custom_patch=False,
52
+ split_pos=False,
53
+ pos_trainable=False,
54
+ use_nce=False,
55
+ beta=4.0,
56
+ decoder_mode=0,
57
+ mask_t_prob=0.6,
58
+ mask_f_prob=0.5,
59
+ mask_2d=False,
60
+ epoch=0,
61
+ no_shift=False,
62
+ ):
63
+ super().__init__()
64
+
65
+ self.audio_exp = audio_exp
66
+ self.embed_dim = embed_dim
67
+ self.decoder_embed_dim = decoder_embed_dim
68
+ # --------------------------------------------------------------------------
69
+ # MAE encoder specifics
70
+ if use_custom_patch:
71
+ print(
72
+ f"Use custom patch_emb with patch size: {patch_size}, stride: {stride}"
73
+ )
74
+ self.patch_embed = PatchEmbed_new(
75
+ img_size=img_size,
76
+ patch_size=patch_size,
77
+ in_chans=in_chans,
78
+ embed_dim=embed_dim,
79
+ stride=stride,
80
+ )
81
+ else:
82
+ self.patch_embed = PatchEmbed_org(img_size, patch_size, in_chans, embed_dim)
83
+ self.use_custom_patch = use_custom_patch
84
+ num_patches = self.patch_embed.num_patches
85
+
86
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
87
+
88
+ # self.split_pos = split_pos # not useful
89
+ self.pos_embed = nn.Parameter(
90
+ torch.zeros(1, num_patches + 1, embed_dim), requires_grad=pos_trainable
91
+ ) # fixed sin-cos embedding
92
+
93
+ self.encoder_depth = depth
94
+ self.contextual_depth = contextual_depth
95
+ self.blocks = nn.ModuleList(
96
+ [
97
+ Block(
98
+ embed_dim,
99
+ num_heads,
100
+ mlp_ratio,
101
+ qkv_bias=True,
102
+ norm_layer=norm_layer,
103
+ ) # qk_scale=None
104
+ for i in range(depth)
105
+ ]
106
+ )
107
+ self.norm = norm_layer(embed_dim)
108
+
109
+ # --------------------------------------------------------------------------
110
+ # MAE decoder specifics
111
+ self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
112
+
113
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
114
+ self.decoder_pos_embed = nn.Parameter(
115
+ torch.zeros(1, num_patches + 1, decoder_embed_dim),
116
+ requires_grad=pos_trainable,
117
+ ) # fixed sin-cos embedding
118
+
119
+ self.no_shift = no_shift
120
+
121
+ self.decoder_mode = decoder_mode
122
+ if (
123
+ self.use_custom_patch
124
+ ): # overlapped patches as in AST. Similar performance yet compute heavy
125
+ window_size = (6, 6)
126
+ feat_size = (102, 12)
127
+ else:
128
+ window_size = (4, 4)
129
+ feat_size = (64, 8)
130
+ if self.decoder_mode == 1:
131
+ decoder_modules = []
132
+ for index in range(16):
133
+ if self.no_shift:
134
+ shift_size = (0, 0)
135
+ else:
136
+ if (index % 2) == 0:
137
+ shift_size = (0, 0)
138
+ else:
139
+ shift_size = (2, 0)
140
+ # shift_size = tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size])
141
+ decoder_modules.append(
142
+ SwinTransformerBlock(
143
+ dim=decoder_embed_dim,
144
+ num_heads=16,
145
+ feat_size=feat_size,
146
+ window_size=window_size,
147
+ shift_size=shift_size,
148
+ mlp_ratio=mlp_ratio,
149
+ drop=0.0,
150
+ drop_attn=0.0,
151
+ drop_path=0.0,
152
+ extra_norm=False,
153
+ sequential_attn=False,
154
+ norm_layer=norm_layer, # nn.LayerNorm,
155
+ )
156
+ )
157
+ self.decoder_blocks = nn.ModuleList(decoder_modules)
158
+ else:
159
+ # Transformer
160
+ self.decoder_blocks = nn.ModuleList(
161
+ [
162
+ Block(
163
+ decoder_embed_dim,
164
+ decoder_num_heads,
165
+ mlp_ratio,
166
+ qkv_bias=True,
167
+ norm_layer=norm_layer,
168
+ ) # qk_scale=None,
169
+ for i in range(decoder_depth)
170
+ ]
171
+ )
172
+
173
+ self.decoder_norm = norm_layer(decoder_embed_dim)
174
+ self.decoder_pred = nn.Linear(
175
+ decoder_embed_dim, patch_size**2 * in_chans, bias=True
176
+ ) # decoder to patch
177
+
178
+ # --------------------------------------------------------------------------
179
+
180
+ self.norm_pix_loss = norm_pix_loss
181
+
182
+ self.patch_size = patch_size
183
+ self.stride = stride
184
+
185
+ # audio exps
186
+ self.alpha = alpha
187
+ self.T = temperature
188
+ self.mode = mode
189
+ self.use_nce = use_nce
190
+ self.beta = beta
191
+
192
+ self.log_softmax = nn.LogSoftmax(dim=-1)
193
+
194
+ self.mask_t_prob = mask_t_prob
195
+ self.mask_f_prob = mask_f_prob
196
+ self.mask_2d = mask_2d
197
+
198
+ self.epoch = epoch
199
+
200
+ self.initialize_weights()
201
+
202
+ def initialize_weights(self):
203
+ # initialization
204
+ # initialize (and freeze) pos_embed by sin-cos embedding
205
+ if self.audio_exp:
206
+ pos_embed = get_2d_sincos_pos_embed_flexible(
207
+ self.pos_embed.shape[-1], self.patch_embed.patch_hw, cls_token=True
208
+ )
209
+ else:
210
+ pos_embed = get_2d_sincos_pos_embed(
211
+ self.pos_embed.shape[-1],
212
+ int(self.patch_embed.num_patches**0.5),
213
+ cls_token=True,
214
+ )
215
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
216
+
217
+ if self.audio_exp:
218
+ decoder_pos_embed = get_2d_sincos_pos_embed_flexible(
219
+ self.decoder_pos_embed.shape[-1],
220
+ self.patch_embed.patch_hw,
221
+ cls_token=True,
222
+ )
223
+ else:
224
+ decoder_pos_embed = get_2d_sincos_pos_embed(
225
+ self.decoder_pos_embed.shape[-1],
226
+ int(self.patch_embed.num_patches**0.5),
227
+ cls_token=True,
228
+ )
229
+ self.decoder_pos_embed.data.copy_(
230
+ torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)
231
+ )
232
+
233
+ # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
234
+ w = self.patch_embed.proj.weight.data
235
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
236
+
237
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
238
+ torch.nn.init.normal_(self.cls_token, std=0.02)
239
+ torch.nn.init.normal_(self.mask_token, std=0.02)
240
+
241
+ # initialize nn.Linear and nn.LayerNorm
242
+ self.apply(self._init_weights)
243
+
244
+ def _init_weights(self, m):
245
+ if isinstance(m, nn.Linear):
246
+ # we use xavier_uniform following official JAX ViT:
247
+ torch.nn.init.xavier_uniform_(m.weight)
248
+ if isinstance(m, nn.Linear) and m.bias is not None:
249
+ nn.init.constant_(m.bias, 0)
250
+ elif isinstance(m, nn.LayerNorm):
251
+ nn.init.constant_(m.bias, 0)
252
+ nn.init.constant_(m.weight, 1.0)
253
+
254
+ def patchify(self, imgs):
255
+ """
256
+ imgs: (N, 3, H, W)
257
+ x: (N, L, patch_size**2 *3)
258
+ L = (H/p)*(W/p)
259
+ """
260
+ p = self.patch_embed.patch_size[0]
261
+ # assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
262
+
263
+ if self.audio_exp:
264
+ if self.use_custom_patch: # overlapped patch
265
+ h, w = self.patch_embed.patch_hw
266
+ # todo: fixed h/w patch size and stride size. Make hw custom in the future
267
+ x = imgs.unfold(2, self.patch_size, self.stride).unfold(
268
+ 3, self.patch_size, self.stride
269
+ ) # n,1,H,W -> n,1,h,w,p,p
270
+ x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))
271
+ # x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p))
272
+ # x = torch.einsum('nchpwq->nhwpqc', x)
273
+ # x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))
274
+ else:
275
+ h = imgs.shape[2] // p
276
+ w = imgs.shape[3] // p
277
+ # h,w = self.patch_embed.patch_hw
278
+ x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p))
279
+ x = torch.einsum("nchpwq->nhwpqc", x)
280
+ x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))
281
+ else:
282
+ h = w = imgs.shape[2] // p
283
+ x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
284
+ x = torch.einsum("nchpwq->nhwpqc", x)
285
+ x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
286
+
287
+ return x
288
+
289
+ def unpatchify(self, x):
290
+ """
291
+ x: (N, L, patch_size**2 *3)
292
+ specs: (N, 1, H, W)
293
+ """
294
+ p = self.patch_embed.patch_size[0]
295
+ h = 1024 // p
296
+ w = 128 // p
297
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, 1))
298
+ x = torch.einsum("nhwpqc->nchpwq", x)
299
+ specs = x.reshape(shape=(x.shape[0], 1, h * p, w * p))
300
+ return specs
301
+
302
+ def random_masking(self, x, mask_ratio):
303
+ """
304
+ Perform per-sample random masking by per-sample shuffling.
305
+ Per-sample shuffling is done by argsort random noise.
306
+ x: [N, L, D], sequence
307
+ """
308
+ N, L, D = x.shape # batch, length, dim
309
+ len_keep = int(L * (1 - mask_ratio))
310
+
311
+ noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
312
+
313
+ # sort noise for each sample
314
+ ids_shuffle = torch.argsort(
315
+ noise, dim=1
316
+ ) # ascend: small is keep, large is remove
317
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
318
+
319
+ # keep the first subset
320
+ ids_keep = ids_shuffle[:, :len_keep]
321
+ x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
322
+
323
+ # generate the binary mask: 0 is keep, 1 is remove
324
+ mask = torch.ones([N, L], device=x.device)
325
+ mask[:, :len_keep] = 0
326
+ # unshuffle to get the binary mask
327
+ mask = torch.gather(mask, dim=1, index=ids_restore)
328
+
329
+ return x_masked, mask, ids_restore
330
+
331
+ def random_masking_2d(self, x, mask_t_prob, mask_f_prob):
332
+ """
333
+ 2D: Spectrogram (masking t and f under mask_t_prob and mask_f_prob)
334
+ Perform per-sample random masking by per-sample shuffling.
335
+ Per-sample shuffling is done by argsort random noise.
336
+ x: [N, L, D], sequence
337
+ """
338
+ N, L, D = x.shape # batch, length, dim
339
+ if self.use_custom_patch: # overlapped patch
340
+ T = 101
341
+ F = 12
342
+ else:
343
+ T = 64
344
+ F = 8
345
+ # x = x.reshape(N, T, F, D)
346
+ len_keep_t = int(T * (1 - mask_t_prob))
347
+ len_keep_f = int(F * (1 - mask_f_prob))
348
+
349
+ # noise for mask in time
350
+ noise_t = torch.rand(N, T, device=x.device) # noise in [0, 1]
351
+ # sort noise for each sample along time
352
+ ids_shuffle_t = torch.argsort(
353
+ noise_t, dim=1
354
+ ) # ascend: small is keep, large is remove
355
+ ids_restore_t = torch.argsort(ids_shuffle_t, dim=1)
356
+ ids_keep_t = ids_shuffle_t[:, :len_keep_t]
357
+ # noise mask in freq
358
+ noise_f = torch.rand(N, F, device=x.device) # noise in [0, 1]
359
+ ids_shuffle_f = torch.argsort(
360
+ noise_f, dim=1
361
+ ) # ascend: small is keep, large is remove
362
+ ids_restore_f = torch.argsort(ids_shuffle_f, dim=1)
363
+ ids_keep_f = ids_shuffle_f[:, :len_keep_f] #
364
+
365
+ # generate the binary mask: 0 is keep, 1 is remove
366
+ # mask in freq
367
+ mask_f = torch.ones(N, F, device=x.device)
368
+ mask_f[:, :len_keep_f] = 0
369
+ mask_f = (
370
+ torch.gather(mask_f, dim=1, index=ids_restore_f)
371
+ .unsqueeze(1)
372
+ .repeat(1, T, 1)
373
+ ) # N,T,F
374
+ # mask in time
375
+ mask_t = torch.ones(N, T, device=x.device)
376
+ mask_t[:, :len_keep_t] = 0
377
+ mask_t = (
378
+ torch.gather(mask_t, dim=1, index=ids_restore_t)
379
+ .unsqueeze(1)
380
+ .repeat(1, F, 1)
381
+ .permute(0, 2, 1)
382
+ ) # N,T,F
383
+ mask = 1 - (1 - mask_t) * (1 - mask_f) # N, T, F
384
+
385
+ # get masked x
386
+ id2res = torch.Tensor(list(range(N * T * F))).reshape(N, T, F).to(x.device)
387
+ id2res = id2res + 999 * mask # add a large value for masked elements
388
+ id2res2 = torch.argsort(id2res.flatten(start_dim=1))
389
+ ids_keep = id2res2.flatten(start_dim=1)[:, : len_keep_f * len_keep_t]
390
+ x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
391
+
392
+ ids_restore = torch.argsort(id2res2.flatten(start_dim=1))
393
+ mask = mask.flatten(start_dim=1)
394
+
395
+ return x_masked, mask, ids_restore
396
+
397
+ def forward_encoder(self, x, mask_ratio, mask_2d=False):
398
+ # embed patches
399
+ x = self.patch_embed(x)
400
+ # add pos embed w/o cls token
401
+ x = x + self.pos_embed[:, 1:, :]
402
+
403
+ # masking: length -> length * mask_ratio
404
+ if mask_2d:
405
+ x, mask, ids_restore = self.random_masking_2d(
406
+ x, mask_t_prob=self.mask_t_prob, mask_f_prob=self.mask_f_prob
407
+ )
408
+ else:
409
+ x, mask, ids_restore = self.random_masking(x, mask_ratio)
410
+
411
+ # append cls token
412
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
413
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
414
+ x = torch.cat((cls_tokens, x), dim=1)
415
+
416
+ # apply Transformer blocks
417
+ for blk in self.blocks:
418
+ x = blk(x)
419
+ x = self.norm(x)
420
+
421
+ return x, mask, ids_restore, None
422
+
423
+ def forward_encoder_no_random_mask_no_average(self, x):
424
+ # embed patches
425
+ x = self.patch_embed(x)
426
+ # add pos embed w/o cls token
427
+ x = x + self.pos_embed[:, 1:, :]
428
+
429
+ # masking: length -> length * mask_ratio
430
+ # if mask_2d:
431
+ # x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob=self.mask_t_prob, mask_f_prob=self.mask_f_prob)
432
+ # else:
433
+ # x, mask, ids_restore = self.random_masking(x, mask_ratio)
434
+
435
+ # append cls token
436
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
437
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
438
+ x = torch.cat((cls_tokens, x), dim=1)
439
+
440
+ # apply Transformer blocks
441
+ for blk in self.blocks:
442
+ x = blk(x)
443
+ x = self.norm(x)
444
+
445
+ return x
446
+
447
+ def forward_encoder_no_mask(self, x):
448
+ # embed patches
449
+ x = self.patch_embed(x)
450
+
451
+ # add pos embed w/o cls token
452
+ x = x + self.pos_embed[:, 1:, :]
453
+
454
+ # masking: length -> length * mask_ratio
455
+ # x, mask, ids_restore = self.random_masking(x, mask_ratio)
456
+ # append cls token
457
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
458
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
459
+ x = torch.cat((cls_tokens, x), dim=1)
460
+
461
+ # apply Transformer blocks
462
+ contextual_embs = []
463
+ for n, blk in enumerate(self.blocks):
464
+ x = blk(x)
465
+ if n > self.contextual_depth:
466
+ contextual_embs.append(self.norm(x))
467
+ # x = self.norm(x)
468
+ contextual_emb = torch.stack(contextual_embs, dim=0).mean(dim=0)
469
+
470
+ return contextual_emb
471
+
472
+ def forward_decoder(self, x, ids_restore):
473
+ # embed tokens
474
+ x = self.decoder_embed(x)
475
+
476
+ # append mask tokens to sequence
477
+ mask_tokens = self.mask_token.repeat(
478
+ x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1
479
+ )
480
+ x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
481
+ x_ = torch.gather(
482
+ x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
483
+ ) # unshuffle
484
+ x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
485
+
486
+ # add pos embed
487
+ x = x + self.decoder_pos_embed
488
+
489
+ if self.decoder_mode != 0:
490
+ B, L, D = x.shape
491
+ x = x[:, 1:, :]
492
+ if self.use_custom_patch:
493
+ x = x.reshape(B, 101, 12, D)
494
+ x = torch.cat([x, x[:, -1, :].unsqueeze(1)], dim=1) # hack
495
+ x = x.reshape(B, 1224, D)
496
+ if self.decoder_mode > 3: # mvit
497
+ x = self.decoder_blocks(x)
498
+ else:
499
+ # apply Transformer blocks
500
+ for blk in self.decoder_blocks:
501
+ x = blk(x)
502
+ x = self.decoder_norm(x)
503
+
504
+ # predictor projection
505
+ pred = self.decoder_pred(x)
506
+
507
+ # remove cls token
508
+ if self.decoder_mode != 0:
509
+ if self.use_custom_patch:
510
+ pred = pred.reshape(B, 102, 12, 256)
511
+ pred = pred[:, :101, :, :]
512
+ pred = pred.reshape(B, 1212, 256)
513
+ else:
514
+ pred = pred
515
+ else:
516
+ pred = pred[:, 1:, :]
517
+ return pred, None, None # emb, emb_pixel
518
+
519
+ def forward_loss(self, imgs, pred, mask, norm_pix_loss=False):
520
+ """
521
+ imgs: [N, 3, H, W]
522
+ pred: [N, L, p*p*3]
523
+ mask: [N, L], 0 is keep, 1 is remove,
524
+ """
525
+ target = self.patchify(imgs)
526
+ if norm_pix_loss:
527
+ mean = target.mean(dim=-1, keepdim=True)
528
+ var = target.var(dim=-1, keepdim=True)
529
+ target = (target - mean) / (var + 1.0e-6) ** 0.5
530
+
531
+ loss = (pred - target) ** 2
532
+ loss = loss.mean(dim=-1) # [N, L], mean loss per patch
533
+
534
+ loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
535
+ return loss
536
+
537
+ def forward(self, imgs, mask_ratio=0.8):
538
+ emb_enc, mask, ids_restore, _ = self.forward_encoder(
539
+ imgs, mask_ratio, mask_2d=self.mask_2d
540
+ )
541
+ pred, _, _ = self.forward_decoder(emb_enc, ids_restore) # [N, L, p*p*3]
542
+ loss_recon = self.forward_loss(
543
+ imgs, pred, mask, norm_pix_loss=self.norm_pix_loss
544
+ )
545
+ loss_contrastive = torch.zeros(1, device=imgs.device)
546
+ return loss_recon, pred, mask, loss_contrastive
547
+
548
+
549
+ def mae_vit_small_patch16_dec512d8b(**kwargs):
550
+ model = MaskedAutoencoderViT(
551
+ patch_size=16,
552
+ embed_dim=384,
553
+ depth=12,
554
+ num_heads=6,
555
+ decoder_embed_dim=512,
556
+ decoder_num_heads=16,
557
+ mlp_ratio=4,
558
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
559
+ **kwargs,
560
+ )
561
+ return model
562
+
563
+
564
+ def mae_vit_base_patch16_dec512d8b(**kwargs):
565
+ model = MaskedAutoencoderViT(
566
+ patch_size=16,
567
+ embed_dim=768,
568
+ depth=12,
569
+ num_heads=12,
570
+ decoder_embed_dim=512,
571
+ decoder_num_heads=16,
572
+ mlp_ratio=4,
573
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
574
+ **kwargs,
575
+ )
576
+ return model
577
+
578
+
579
+ def mae_vit_large_patch16_dec512d8b(**kwargs):
580
+ model = MaskedAutoencoderViT(
581
+ patch_size=16,
582
+ embed_dim=1024,
583
+ depth=24,
584
+ num_heads=16,
585
+ decoder_embed_dim=512,
586
+ decoder_num_heads=16,
587
+ mlp_ratio=4,
588
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
589
+ **kwargs,
590
+ )
591
+ return model
592
+
593
+
594
+ def mae_vit_huge_patch14_dec512d8b(**kwargs):
595
+ model = MaskedAutoencoderViT(
596
+ patch_size=14,
597
+ embed_dim=1280,
598
+ depth=32,
599
+ num_heads=16,
600
+ decoder_embed_dim=512,
601
+ decoder_num_heads=16,
602
+ mlp_ratio=4,
603
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
604
+ **kwargs,
605
+ )
606
+ return model
607
+
608
+
609
+ # set recommended archs
610
+ mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks
611
+ mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks
612
+ mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks
613
+ mae_vit_small_patch16 = mae_vit_small_patch16_dec512d8b # decoder: 512 dim, 8 blocks
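A hedged usage sketch for the audio-configured MAE above, mirroring how Vanilla_AudioMAE in AudioMAE.py instantiates it; it assumes mae_vit_base_patch16 is imported from this module, and the batch size and the no-mask encoder call are illustrative choices:

import torch

model = mae_vit_base_patch16(in_chans=1, audio_exp=True, img_size=(1024, 128))
mel = torch.randn(2, 1, 1024, 128)            # (batch, 1, time frames, mel bins)
with torch.no_grad():
    emb = model.forward_encoder_no_mask(mel)  # averaged contextual embedding over late blocks
print(emb.shape)                              # torch.Size([2, 513, 768]): (1024//16)*(128//16)=512 patches + cls token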
FlashSR/AudioSR/latent_diffusion/modules/audiomae/models_vit.py ADDED
@@ -0,0 +1,243 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
+ # DeiT: https://github.com/facebookresearch/deit
10
+ # --------------------------------------------------------
11
+
12
+ from functools import partial
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import timm.models.vision_transformer
17
+
18
+
19
+ class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
20
+ """Vision Transformer with support for global average pooling"""
21
+
22
+ def __init__(
23
+ self, global_pool=False, mask_2d=True, use_custom_patch=False, **kwargs
24
+ ):
25
+ super(VisionTransformer, self).__init__(**kwargs)
26
+
27
+ self.global_pool = global_pool
28
+ if self.global_pool:
29
+ norm_layer = kwargs["norm_layer"]
30
+ embed_dim = kwargs["embed_dim"]
31
+ self.fc_norm = norm_layer(embed_dim)
32
+ del self.norm # remove the original norm
33
+ self.mask_2d = mask_2d
34
+ self.use_custom_patch = use_custom_patch
35
+
36
+ def forward_features(self, x):
37
+ B = x.shape[0]
38
+ x = self.patch_embed(x)
39
+ x = x + self.pos_embed[:, 1:, :]
40
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
41
+ cls_tokens = cls_token.expand(
42
+ B, -1, -1
43
+ ) # stole cls_tokens impl from Phil Wang, thanks
44
+ x = torch.cat((cls_tokens, x), dim=1)
45
+ x = self.pos_drop(x)
46
+
47
+ for blk in self.blocks:
48
+ x = blk(x)
49
+
50
+ if self.global_pool:
51
+ x = x[:, 1:, :].mean(dim=1) # global pool without cls token
52
+ outcome = self.fc_norm(x)
53
+ else:
54
+ x = self.norm(x)
55
+ outcome = x[:, 0]
56
+
57
+ return outcome
58
+
59
+ def random_masking(self, x, mask_ratio):
60
+ """
61
+ Perform per-sample random masking by per-sample shuffling.
62
+ Per-sample shuffling is done by argsort random noise.
63
+ x: [N, L, D], sequence
64
+ """
65
+ N, L, D = x.shape # batch, length, dim
66
+ len_keep = int(L * (1 - mask_ratio))
67
+
68
+ noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
69
+
70
+ # sort noise for each sample
71
+ ids_shuffle = torch.argsort(
72
+ noise, dim=1
73
+ ) # ascend: small is keep, large is remove
74
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
75
+
76
+ # keep the first subset
77
+ ids_keep = ids_shuffle[:, :len_keep]
78
+ x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
79
+
80
+ # generate the binary mask: 0 is keep, 1 is remove
81
+ mask = torch.ones([N, L], device=x.device)
82
+ mask[:, :len_keep] = 0
83
+ # unshuffle to get the binary mask
84
+ mask = torch.gather(mask, dim=1, index=ids_restore)
85
+
86
+ return x_masked, mask, ids_restore
87
+
88
+ def random_masking_2d(self, x, mask_t_prob, mask_f_prob):
89
+ """
90
+ 2D: Spectrogram (masking t and f under mask_t_prob and mask_f_prob)
91
+ Perform per-sample random masking by per-sample shuffling.
92
+ Per-sample shuffling is done by argsort random noise.
93
+ x: [N, L, D], sequence
94
+ """
95
+
96
+ N, L, D = x.shape # batch, length, dim
97
+ if self.use_custom_patch:
98
+ # # for AS
99
+ T = 101 # 64,101
100
+ F = 12 # 8,12
101
+ # # for ESC
102
+ # T=50
103
+ # F=12
104
+ # for SPC
105
+ # T=12
106
+ # F=12
107
+ else:
108
+ # ## for AS
109
+ T = 64
110
+ F = 8
111
+ # ## for ESC
112
+ # T=32
113
+ # F=8
114
+ ## for SPC
115
+ # T=8
116
+ # F=8
117
+
118
+ # mask T
119
+ x = x.reshape(N, T, F, D)
120
+ len_keep_T = int(T * (1 - mask_t_prob))
121
+ noise = torch.rand(N, T, device=x.device) # noise in [0, 1]
122
+ # sort noise for each sample
123
+ ids_shuffle = torch.argsort(
124
+ noise, dim=1
125
+ ) # ascend: small is keep, large is remove
126
+ ids_keep = ids_shuffle[:, :len_keep_T]
127
+ index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, F, D)
128
+ # x_masked = torch.gather(x, dim=1, index=index)
129
+ # x_masked = x_masked.reshape(N,len_keep_T*F,D)
130
+ x = torch.gather(x, dim=1, index=index) # N, len_keep_T(T'), F, D
131
+
132
+ # mask F
133
+ # x = x.reshape(N, T, F, D)
134
+ x = x.permute(0, 2, 1, 3) # N T' F D => N F T' D
135
+ len_keep_F = int(F * (1 - mask_f_prob))
136
+ noise = torch.rand(N, F, device=x.device) # noise in [0, 1]
137
+ # sort noise for each sample
138
+ ids_shuffle = torch.argsort(
139
+ noise, dim=1
140
+ ) # ascend: small is keep, large is remove
141
+ ids_keep = ids_shuffle[:, :len_keep_F]
142
+ # index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, T, D)
143
+ index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, len_keep_T, D)
144
+ x_masked = torch.gather(x, dim=1, index=index)
145
+ x_masked = x_masked.permute(0, 2, 1, 3) # N F' T' D => N T' F' D
146
+ # x_masked = x_masked.reshape(N,len_keep*T,D)
147
+ x_masked = x_masked.reshape(N, len_keep_F * len_keep_T, D)
148
+
149
+ return x_masked, None, None
150
+
151
+ def forward_features_mask(self, x, mask_t_prob, mask_f_prob):
152
+ B = x.shape[0] # 4,1,1024,128
153
+ x = self.patch_embed(x) # 4, 512, 768
154
+
155
+ x = x + self.pos_embed[:, 1:, :]
156
+ if self.mask_2d:
157
+ x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob, mask_f_prob)
158
+ else:
159
+ x, mask, ids_restore = self.random_masking(x, mask_t_prob)
160
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
161
+ cls_tokens = cls_token.expand(B, -1, -1)
162
+ x = torch.cat((cls_tokens, x), dim=1)
163
+ x = self.pos_drop(x)
164
+
165
+ # apply Transformer blocks
166
+ for blk in self.blocks:
167
+ x = blk(x)
168
+
169
+ if self.global_pool:
170
+ x = x[:, 1:, :].mean(dim=1) # global pool without cls token
171
+ outcome = self.fc_norm(x)
172
+ else:
173
+ x = self.norm(x)
174
+ outcome = x[:, 0]
175
+
176
+ return outcome
177
+
178
+ # overwrite original timm
179
+ def forward(self, x, v=None, mask_t_prob=0.0, mask_f_prob=0.0):
180
+ if mask_t_prob > 0.0 or mask_f_prob > 0.0:
181
+ x = self.forward_features_mask(
182
+ x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob
183
+ )
184
+ else:
185
+ x = self.forward_features(x)
186
+ x = self.head(x)
187
+ return x
188
+
189
+
190
+ def vit_small_patch16(**kwargs):
191
+ model = VisionTransformer(
192
+ patch_size=16,
193
+ embed_dim=384,
194
+ depth=12,
195
+ num_heads=6,
196
+ mlp_ratio=4,
197
+ qkv_bias=True,
198
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
199
+ **kwargs
200
+ )
201
+ return model
202
+
203
+
204
+ def vit_base_patch16(**kwargs):
205
+ model = VisionTransformer(
206
+ patch_size=16,
207
+ embed_dim=768,
208
+ depth=12,
209
+ num_heads=12,
210
+ mlp_ratio=4,
211
+ qkv_bias=True,
212
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
213
+ **kwargs
214
+ )
215
+ return model
216
+
217
+
218
+ def vit_large_patch16(**kwargs):
219
+ model = VisionTransformer(
220
+ patch_size=16,
221
+ embed_dim=1024,
222
+ depth=24,
223
+ num_heads=16,
224
+ mlp_ratio=4,
225
+ qkv_bias=True,
226
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
227
+ **kwargs
228
+ )
229
+ return model
230
+
231
+
232
+ def vit_huge_patch14(**kwargs):
233
+ model = VisionTransformer(
234
+ patch_size=14,
235
+ embed_dim=1280,
236
+ depth=32,
237
+ num_heads=16,
238
+ mlp_ratio=4,
239
+ qkv_bias=True,
240
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
241
+ **kwargs
242
+ )
243
+ return model
FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/crop.py ADDED
@@ -0,0 +1,43 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+ import torch
10
+
11
+ from torchvision import transforms
12
+ from torchvision.transforms import functional as F
13
+
14
+
15
+ class RandomResizedCrop(transforms.RandomResizedCrop):
16
+ """
17
+ RandomResizedCrop for matching TF/TPU implementation: no for-loop is used.
18
+ This may lead to results different with torchvision's version.
19
+ Following BYOL's TF code:
20
+ https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206
21
+ """
22
+
23
+ @staticmethod
24
+ def get_params(img, scale, ratio):
25
+ width, height = F._get_image_size(img)
26
+ area = height * width
27
+
28
+ target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
29
+ log_ratio = torch.log(torch.tensor(ratio))
30
+ aspect_ratio = torch.exp(
31
+ torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
32
+ ).item()
33
+
34
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
35
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
36
+
37
+ w = min(w, width)
38
+ h = min(h, height)
39
+
40
+ i = torch.randint(0, height - h + 1, size=(1,)).item()
41
+ j = torch.randint(0, width - w + 1, size=(1,)).item()
42
+
43
+ return i, j, h, w
FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/datasets.py ADDED
@@ -0,0 +1,67 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # DeiT: https://github.com/facebookresearch/deit
9
+ # --------------------------------------------------------
10
+
11
+ import os
12
+ import PIL
13
+
14
+ from torchvision import datasets, transforms
15
+
16
+ from timm.data import create_transform
17
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
18
+
19
+
20
+ def build_dataset(is_train, args):
21
+ transform = build_transform(is_train, args)
22
+
23
+ root = os.path.join(args.data_path, "train" if is_train else "val")
24
+ dataset = datasets.ImageFolder(root, transform=transform)
25
+
26
+ print(dataset)
27
+
28
+ return dataset
29
+
30
+
31
+ def build_transform(is_train, args):
32
+ mean = IMAGENET_DEFAULT_MEAN
33
+ std = IMAGENET_DEFAULT_STD
34
+ # train transform
35
+ if is_train:
36
+ # this should always dispatch to transforms_imagenet_train
37
+ transform = create_transform(
38
+ input_size=args.input_size,
39
+ is_training=True,
40
+ color_jitter=args.color_jitter,
41
+ auto_augment=args.aa,
42
+ interpolation="bicubic",
43
+ re_prob=args.reprob,
44
+ re_mode=args.remode,
45
+ re_count=args.recount,
46
+ mean=mean,
47
+ std=std,
48
+ )
49
+ return transform
50
+
51
+ # eval transform
52
+ t = []
53
+ if args.input_size <= 224:
54
+ crop_pct = 224 / 256
55
+ else:
56
+ crop_pct = 1.0
57
+ size = int(args.input_size / crop_pct)
58
+ t.append(
59
+ transforms.Resize(
60
+ size, interpolation=PIL.Image.BICUBIC
61
+ ), # to maintain same ratio w.r.t. 224 images
62
+ )
63
+ t.append(transforms.CenterCrop(args.input_size))
64
+
65
+ t.append(transforms.ToTensor())
66
+ t.append(transforms.Normalize(mean, std))
67
+ return transforms.Compose(t)
FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/lars.py ADDED
@@ -0,0 +1,60 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # LARS optimizer, implementation from MoCo v3:
8
+ # https://github.com/facebookresearch/moco-v3
9
+ # --------------------------------------------------------
10
+
11
+ import torch
12
+
13
+
14
+ class LARS(torch.optim.Optimizer):
15
+ """
16
+ LARS optimizer, no rate scaling or weight decay for parameters <= 1D.
17
+ """
18
+
19
+ def __init__(
20
+ self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001
21
+ ):
22
+ defaults = dict(
23
+ lr=lr,
24
+ weight_decay=weight_decay,
25
+ momentum=momentum,
26
+ trust_coefficient=trust_coefficient,
27
+ )
28
+ super().__init__(params, defaults)
29
+
30
+ @torch.no_grad()
31
+ def step(self):
32
+ for g in self.param_groups:
33
+ for p in g["params"]:
34
+ dp = p.grad
35
+
36
+ if dp is None:
37
+ continue
38
+
39
+ if p.ndim > 1: # if not normalization gamma/beta or bias
40
+ dp = dp.add(p, alpha=g["weight_decay"])
41
+ param_norm = torch.norm(p)
42
+ update_norm = torch.norm(dp)
43
+ one = torch.ones_like(param_norm)
44
+ q = torch.where(
45
+ param_norm > 0.0,
46
+ torch.where(
47
+ update_norm > 0,
48
+ (g["trust_coefficient"] * param_norm / update_norm),
49
+ one,
50
+ ),
51
+ one,
52
+ )
53
+ dp = dp.mul(q)
54
+
55
+ param_state = self.state[p]
56
+ if "mu" not in param_state:
57
+ param_state["mu"] = torch.zeros_like(p)
58
+ mu = param_state["mu"]
59
+ mu.mul_(g["momentum"]).add_(dp)
60
+ p.add_(mu, alpha=-g["lr"])
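A short sketch of driving the LARS optimizer above on a toy model; the module, learning rate, and weight decay are illustrative assumptions, and LARS is taken from this file:

import torch
import torch.nn as nn

model = nn.Linear(16, 4)
optimizer = LARS(model.parameters(), lr=0.1, weight_decay=1e-4, momentum=0.9)

loss = model(torch.randn(8, 16)).pow(2).mean()
loss.backward()
optimizer.step()        # trust-ratio scaling is applied only to parameters with ndim > 1
optimizer.zero_grad()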
FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/lr_decay.py ADDED
@@ -0,0 +1,76 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # ELECTRA https://github.com/google-research/electra
9
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10
+ # --------------------------------------------------------
11
+
12
+
13
+ def param_groups_lrd(
14
+ model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=0.75
15
+ ):
16
+ """
17
+ Parameter groups for layer-wise lr decay
18
+ Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
19
+ """
20
+ param_group_names = {}
21
+ param_groups = {}
22
+
23
+ num_layers = len(model.blocks) + 1
24
+
25
+ layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))
26
+
27
+ for n, p in model.named_parameters():
28
+ if not p.requires_grad:
29
+ continue
30
+
31
+ # no decay: all 1D parameters and model specific ones
32
+ if p.ndim == 1 or n in no_weight_decay_list:
33
+ g_decay = "no_decay"
34
+ this_decay = 0.0
35
+ else:
36
+ g_decay = "decay"
37
+ this_decay = weight_decay
38
+
39
+ layer_id = get_layer_id_for_vit(n, num_layers)
40
+ group_name = "layer_%d_%s" % (layer_id, g_decay)
41
+
42
+ if group_name not in param_group_names:
43
+ this_scale = layer_scales[layer_id]
44
+
45
+ param_group_names[group_name] = {
46
+ "lr_scale": this_scale,
47
+ "weight_decay": this_decay,
48
+ "params": [],
49
+ }
50
+ param_groups[group_name] = {
51
+ "lr_scale": this_scale,
52
+ "weight_decay": this_decay,
53
+ "params": [],
54
+ }
55
+
56
+ param_group_names[group_name]["params"].append(n)
57
+ param_groups[group_name]["params"].append(p)
58
+
59
+ # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
60
+
61
+ return list(param_groups.values())
62
+
63
+
64
+ def get_layer_id_for_vit(name, num_layers):
65
+ """
66
+ Assign a parameter with its layer id
67
+ Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
68
+ """
69
+ if name in ["cls_token", "pos_embed"]:
70
+ return 0
71
+ elif name.startswith("patch_embed"):
72
+ return 0
73
+ elif name.startswith("blocks"):
74
+ return int(name.split(".")[1]) + 1
75
+ else:
76
+ return num_layers
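A hedged sketch of feeding the layer-wise decay groups above into an optimizer; the timm model name and hyperparameters are illustrative assumptions, not values from this repository:

import timm
import torch

vit = timm.create_model("vit_base_patch16_224", pretrained=False, num_classes=527)
param_groups = param_groups_lrd(
    vit,
    weight_decay=0.05,
    no_weight_decay_list=vit.no_weight_decay(),  # e.g. {"cls_token", "pos_embed"}
    layer_decay=0.75,
)
optimizer = torch.optim.AdamW(param_groups, lr=1e-3)
# each group's "lr_scale" is consumed later by adjust_learning_rate in lr_sched.py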
FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/lr_sched.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+
9
+
10
+ def adjust_learning_rate(optimizer, epoch, args):
11
+ """Decay the learning rate with half-cycle cosine after warmup"""
12
+ if epoch < args.warmup_epochs:
13
+ lr = args.lr * epoch / args.warmup_epochs
14
+ else:
15
+ lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * (
16
+ 1.0
17
+ + math.cos(
18
+ math.pi
19
+ * (epoch - args.warmup_epochs)
20
+ / (args.epochs - args.warmup_epochs)
21
+ )
22
+ )
23
+ for param_group in optimizer.param_groups:
24
+ if "lr_scale" in param_group:
25
+ param_group["lr"] = lr * param_group["lr_scale"]
26
+ else:
27
+ param_group["lr"] = lr
28
+ return lr
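A small worked example of the warmup-then-cosine schedule above; the argparse-style namespace and the probe epochs are illustrative assumptions:

import types
import torch

args = types.SimpleNamespace(lr=1e-3, min_lr=1e-5, warmup_epochs=5, epochs=50)
optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=args.lr)

for epoch in [0, 2.5, 5, 27.5, 50]:
    print(epoch, adjust_learning_rate(optimizer, epoch, args))
# ramps linearly from 0 to 1e-3 over the first 5 epochs, then cosine-decays toward 1e-5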
FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/misc.py ADDED
@@ -0,0 +1,453 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # DeiT: https://github.com/facebookresearch/deit
9
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10
+ # --------------------------------------------------------
11
+
12
+ import builtins
13
+ import datetime
14
+ import os
15
+ import time
16
+ from collections import defaultdict, deque
17
+ from pathlib import Path
18
+
19
+ import torch
20
+ import torch.distributed as dist
21
+ from math import inf  # torch._six was removed from recent PyTorch releases; math.inf is the value it provided
22
+
23
+
24
+ class SmoothedValue(object):
25
+ """Track a series of values and provide access to smoothed values over a
26
+ window or the global series average.
27
+ """
28
+
29
+ def __init__(self, window_size=20, fmt=None):
30
+ if fmt is None:
31
+ fmt = "{median:.4f} ({global_avg:.4f})"
32
+ self.deque = deque(maxlen=window_size)
33
+ self.total = 0.0
34
+ self.count = 0
35
+ self.fmt = fmt
36
+
37
+ def update(self, value, n=1):
38
+ self.deque.append(value)
39
+ self.count += n
40
+ self.total += value * n
41
+
42
+ def synchronize_between_processes(self):
43
+ """
44
+ Warning: does not synchronize the deque!
45
+ """
46
+ if not is_dist_avail_and_initialized():
47
+ return
48
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
49
+ dist.barrier()
50
+ dist.all_reduce(t)
51
+ t = t.tolist()
52
+ self.count = int(t[0])
53
+ self.total = t[1]
54
+
55
+ @property
56
+ def median(self):
57
+ d = torch.tensor(list(self.deque))
58
+ return d.median().item()
59
+
60
+ @property
61
+ def avg(self):
62
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
63
+ return d.mean().item()
64
+
65
+ @property
66
+ def global_avg(self):
67
+ return self.total / self.count
68
+
69
+ @property
70
+ def max(self):
71
+ return max(self.deque)
72
+
73
+ @property
74
+ def value(self):
75
+ return self.deque[-1]
76
+
77
+ def __str__(self):
78
+ return self.fmt.format(
79
+ median=self.median,
80
+ avg=self.avg,
81
+ global_avg=self.global_avg,
82
+ max=self.max,
83
+ value=self.value,
84
+ )
85
+
86
+
87
+ class MetricLogger(object):
88
+ def __init__(self, delimiter="\t"):
89
+ self.meters = defaultdict(SmoothedValue)
90
+ self.delimiter = delimiter
91
+
92
+ def update(self, **kwargs):
93
+ for k, v in kwargs.items():
94
+ if v is None:
95
+ continue
96
+ if isinstance(v, torch.Tensor):
97
+ v = v.item()
98
+ assert isinstance(v, (float, int))
99
+ self.meters[k].update(v)
100
+
101
+ def __getattr__(self, attr):
102
+ if attr in self.meters:
103
+ return self.meters[attr]
104
+ if attr in self.__dict__:
105
+ return self.__dict__[attr]
106
+ raise AttributeError(
107
+ "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
108
+ )
109
+
110
+ def __str__(self):
111
+ loss_str = []
112
+ for name, meter in self.meters.items():
113
+ loss_str.append("{}: {}".format(name, str(meter)))
114
+ return self.delimiter.join(loss_str)
115
+
116
+ def synchronize_between_processes(self):
117
+ for meter in self.meters.values():
118
+ meter.synchronize_between_processes()
119
+
120
+ def add_meter(self, name, meter):
121
+ self.meters[name] = meter
122
+
123
+ def log_every(self, iterable, print_freq, header=None):
124
+ i = 0
125
+ if not header:
126
+ header = ""
127
+ start_time = time.time()
128
+ end = time.time()
129
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
130
+ data_time = SmoothedValue(fmt="{avg:.4f}")
131
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
132
+ log_msg = [
133
+ header,
134
+ "[{0" + space_fmt + "}/{1}]",
135
+ "eta: {eta}",
136
+ "{meters}",
137
+ "time: {time}",
138
+ "data: {data}",
139
+ ]
140
+ if torch.cuda.is_available():
141
+ log_msg.append("max mem: {memory:.0f}")
142
+ log_msg = self.delimiter.join(log_msg)
143
+ MB = 1024.0 * 1024.0
144
+ for obj in iterable:
145
+ data_time.update(time.time() - end)
146
+ yield obj
147
+ iter_time.update(time.time() - end)
148
+ if i % print_freq == 0 or i == len(iterable) - 1:
149
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
150
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
151
+ if torch.cuda.is_available():
152
+ print(
153
+ log_msg.format(
154
+ i,
155
+ len(iterable),
156
+ eta=eta_string,
157
+ meters=str(self),
158
+ time=str(iter_time),
159
+ data=str(data_time),
160
+ memory=torch.cuda.max_memory_allocated() / MB,
161
+ )
162
+ )
163
+ else:
164
+ print(
165
+ log_msg.format(
166
+ i,
167
+ len(iterable),
168
+ eta=eta_string,
169
+ meters=str(self),
170
+ time=str(iter_time),
171
+ data=str(data_time),
172
+ )
173
+ )
174
+ i += 1
175
+ end = time.time()
176
+ total_time = time.time() - start_time
177
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
178
+ print(
179
+ "{} Total time: {} ({:.4f} s / it)".format(
180
+ header, total_time_str, total_time / len(iterable)
181
+ )
182
+ )
183
+
184
+
185
+ def setup_for_distributed(is_master):
186
+ """
187
+ This function disables printing when not in master process
188
+ """
189
+ builtin_print = builtins.print
190
+
191
+ def print(*args, **kwargs):
192
+ force = kwargs.pop("force", False)
193
+ force = force or (get_world_size() > 8)
194
+ if is_master or force:
195
+ now = datetime.datetime.now().time()
196
+ builtin_print("[{}] ".format(now), end="") # print with time stamp
197
+ builtin_print(*args, **kwargs)
198
+
199
+ builtins.print = print
200
+
201
+
202
+ def is_dist_avail_and_initialized():
203
+ if not dist.is_available():
204
+ return False
205
+ if not dist.is_initialized():
206
+ return False
207
+ return True
208
+
209
+
210
+ def get_world_size():
211
+ if not is_dist_avail_and_initialized():
212
+ return 1
213
+ return dist.get_world_size()
214
+
215
+
216
+ def get_rank():
217
+ if not is_dist_avail_and_initialized():
218
+ return 0
219
+ return dist.get_rank()
220
+
221
+
222
+ def is_main_process():
223
+ return get_rank() == 0
224
+
225
+
226
+ def save_on_master(*args, **kwargs):
227
+ if is_main_process():
228
+ torch.save(*args, **kwargs)
229
+
230
+
231
+ def init_distributed_mode(args):
232
+ if args.dist_on_itp:
233
+ args.rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
234
+ args.world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
235
+ args.gpu = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
236
+ args.dist_url = "tcp://%s:%s" % (
237
+ os.environ["MASTER_ADDR"],
238
+ os.environ["MASTER_PORT"],
239
+ )
240
+ os.environ["LOCAL_RANK"] = str(args.gpu)
241
+ os.environ["RANK"] = str(args.rank)
242
+ os.environ["WORLD_SIZE"] = str(args.world_size)
243
+ # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
244
+ elif "RANK" in os.environ and "WORLD_SIZE" in os.environ:
245
+ args.rank = int(os.environ["RANK"])
246
+ args.world_size = int(os.environ["WORLD_SIZE"])
247
+ args.gpu = int(os.environ["LOCAL_RANK"])
248
+ elif "SLURM_PROCID" in os.environ:
249
+ args.rank = int(os.environ["SLURM_PROCID"])
250
+ args.gpu = args.rank % torch.cuda.device_count()
251
+ else:
252
+ print("Not using distributed mode")
253
+ setup_for_distributed(is_master=True) # hack
254
+ args.distributed = False
255
+ return
256
+
257
+ args.distributed = True
258
+
259
+ torch.cuda.set_device(args.gpu)
260
+ args.dist_backend = "nccl"
261
+ print(
262
+ "| distributed init (rank {}): {}, gpu {}".format(
263
+ args.rank, args.dist_url, args.gpu
264
+ ),
265
+ flush=True,
266
+ )
267
+ torch.distributed.init_process_group(
268
+ backend=args.dist_backend,
269
+ init_method=args.dist_url,
270
+ world_size=args.world_size,
271
+ rank=args.rank,
272
+ )
273
+ torch.distributed.barrier()
274
+ setup_for_distributed(args.rank == 0)
275
+
276
+
277
+ class NativeScalerWithGradNormCount:
278
+ state_dict_key = "amp_scaler"
279
+
280
+ def __init__(self):
281
+ self._scaler = torch.cuda.amp.GradScaler()
282
+
283
+ def __call__(
284
+ self,
285
+ loss,
286
+ optimizer,
287
+ clip_grad=None,
288
+ parameters=None,
289
+ create_graph=False,
290
+ update_grad=True,
291
+ ):
292
+ self._scaler.scale(loss).backward(create_graph=create_graph)
293
+ if update_grad:
294
+ if clip_grad is not None:
295
+ assert parameters is not None
296
+ self._scaler.unscale_(
297
+ optimizer
298
+ ) # unscale the gradients of optimizer's assigned params in-place
299
+ norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
300
+ else:
301
+ self._scaler.unscale_(optimizer)
302
+ norm = get_grad_norm_(parameters)
303
+ self._scaler.step(optimizer)
304
+ self._scaler.update()
305
+ else:
306
+ norm = None
307
+ return norm
308
+
309
+ def state_dict(self):
310
+ return self._scaler.state_dict()
311
+
312
+ def load_state_dict(self, state_dict):
313
+ self._scaler.load_state_dict(state_dict)
314
+
315
+
316
+ def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
317
+ if isinstance(parameters, torch.Tensor):
318
+ parameters = [parameters]
319
+ parameters = [p for p in parameters if p.grad is not None]
320
+ norm_type = float(norm_type)
321
+ if len(parameters) == 0:
322
+ return torch.tensor(0.0)
323
+ device = parameters[0].grad.device
324
+ if norm_type == inf:
325
+ total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
326
+ else:
327
+ total_norm = torch.norm(
328
+ torch.stack(
329
+ [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]
330
+ ),
331
+ norm_type,
332
+ )
333
+ return total_norm
334
+
335
+
336
+ def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
337
+ output_dir = Path(args.output_dir)
338
+ epoch_name = str(epoch)
339
+ if loss_scaler is not None:
340
+ checkpoint_paths = [output_dir / ("checkpoint-%s.pth" % epoch_name)]
341
+ for checkpoint_path in checkpoint_paths:
342
+ to_save = {
343
+ "model": model_without_ddp.state_dict(),
344
+ "optimizer": optimizer.state_dict(),
345
+ "epoch": epoch,
346
+ "scaler": loss_scaler.state_dict(),
347
+ "args": args,
348
+ }
349
+
350
+ save_on_master(to_save, checkpoint_path)
351
+ else:
352
+ client_state = {"epoch": epoch}
353
+ model.save_checkpoint(
354
+ save_dir=args.output_dir,
355
+ tag="checkpoint-%s" % epoch_name,
356
+ client_state=client_state,
357
+ )
358
+
359
+
360
+ def load_model(args, model_without_ddp, optimizer, loss_scaler):
361
+ if args.resume:
362
+ if args.resume.startswith("https"):
363
+ checkpoint = torch.hub.load_state_dict_from_url(
364
+ args.resume, map_location="cpu", check_hash=True
365
+ )
366
+ else:
367
+ checkpoint = torch.load(args.resume, map_location="cpu")
368
+ model_without_ddp.load_state_dict(checkpoint["model"])
369
+ print("Resume checkpoint %s" % args.resume)
370
+ if (
371
+ "optimizer" in checkpoint
372
+ and "epoch" in checkpoint
373
+ and not (hasattr(args, "eval") and args.eval)
374
+ ):
375
+ optimizer.load_state_dict(checkpoint["optimizer"])
376
+ args.start_epoch = checkpoint["epoch"] + 1
377
+ if "scaler" in checkpoint:
378
+ loss_scaler.load_state_dict(checkpoint["scaler"])
379
+ print("With optim & sched!")
380
+
381
+
382
+ def all_reduce_mean(x):
383
+ world_size = get_world_size()
384
+ if world_size > 1:
385
+ x_reduce = torch.tensor(x).cuda()
386
+ dist.all_reduce(x_reduce)
387
+ x_reduce /= world_size
388
+ return x_reduce.item()
389
+ else:
390
+ return x
391
+
392
+
393
+ # utils
394
+ @torch.no_grad()
395
+ def concat_all_gather(tensor):
396
+ """
397
+ Performs all_gather operation on the provided tensors.
398
+ *** Warning ***: torch.distributed.all_gather has no gradient.
399
+ """
400
+ tensors_gather = [
401
+ torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
402
+ ]
403
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
404
+
405
+ output = torch.cat(tensors_gather, dim=0)
406
+ return output
407
+
408
+
409
+ def merge_vmae_to_avmae(avmae_state_dict, vmae_ckpt):
410
+ # keys_to_copy=['pos_embed','patch_embed']
411
+ # replaced=0
412
+
413
+ vmae_ckpt["cls_token"] = vmae_ckpt["cls_token_v"]
414
+ vmae_ckpt["mask_token"] = vmae_ckpt["mask_token_v"]
415
+
416
+ # pos_emb % not trainable, use default
417
+ pos_embed_v = vmae_ckpt["pos_embed_v"] # 1,589,768
418
+ pos_embed = pos_embed_v[:, 1:, :] # 1,588,768
419
+ cls_embed = pos_embed_v[:, 0, :].unsqueeze(1)
420
+ pos_embed = pos_embed.reshape(1, 2, 14, 14, 768).sum(dim=1) # 1, 14, 14, 768
421
+ print("Position interpolate from 14,14 to 64,8")
422
+ pos_embed = pos_embed.permute(0, 3, 1, 2) # 1, 14,14,768 -> 1,768,14,14
423
+ pos_embed = torch.nn.functional.interpolate(
424
+ pos_embed, size=(64, 8), mode="bicubic", align_corners=False
425
+ )
426
+ pos_embed = pos_embed.permute(0, 2, 3, 1).flatten(
427
+ 1, 2
428
+ ) # 1, 14, 14, 768 => 1, 196,768
429
+ pos_embed = torch.cat((cls_embed, pos_embed), dim=1)
430
+ assert vmae_ckpt["pos_embed"].shape == pos_embed.shape
431
+ vmae_ckpt["pos_embed"] = pos_embed
432
+ # patch_emb
433
+ # aggregate 3 channels in video-rgb ckpt to 1 channel for audio
434
+ v_weight = vmae_ckpt["patch_embed_v.proj.weight"] # 768,3,2,16,16
435
+ new_proj_weight = torch.nn.Parameter(v_weight.sum(dim=2).sum(dim=1).unsqueeze(1))
436
+ assert new_proj_weight.shape == vmae_ckpt["patch_embed.proj.weight"].shape
437
+ vmae_ckpt["patch_embed.proj.weight"] = new_proj_weight
438
+ vmae_ckpt["patch_embed.proj.bias"] = vmae_ckpt["patch_embed_v.proj.bias"]
439
+
440
+ # hack
441
+ vmae_ckpt["norm.weight"] = vmae_ckpt["norm_v.weight"]
442
+ vmae_ckpt["norm.bias"] = vmae_ckpt["norm_v.bias"]
443
+
444
+ # replace transformer encoder
445
+ for k, v in vmae_ckpt.items():
446
+ if k.startswith("blocks."):
447
+ kk = k.replace("blocks.", "blocks_v.")
448
+ vmae_ckpt[k] = vmae_ckpt[kk]
449
+ elif k.startswith("blocks_v."):
450
+ pass
451
+ else:
452
+ print(k)
453
+ print(k)
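
A standalone sketch of the metric utilities above in a single-process loop (no distributed setup required); the dummy "training step", the meter formats, and the import path are assumptions.

import torch

# Assumed import path, based on this repository's file layout; adjust as needed.
from FlashSR.AudioSR.latent_diffusion.modules.audiomae.util.misc import MetricLogger, SmoothedValue

logger = MetricLogger(delimiter="  ")
logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))

data = [torch.randn(4, 8) for _ in range(10)]
for batch in logger.log_every(data, print_freq=5, header="Epoch: [0]"):
    loss = batch.pow(2).mean()          # stand-in for a real training step
    logger.update(loss=loss, lr=1e-4)   # tensors are unwrapped via .item() inside update()

print("averaged stats:", str(logger))   # per-meter smoothed and global averages
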
FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/patch_embed.py ADDED
@@ -0,0 +1,127 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from timm.models.layers import to_2tuple
4
+
5
+
6
+ class PatchEmbed_org(nn.Module):
7
+ """Image to Patch Embedding"""
8
+
9
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
10
+ super().__init__()
11
+ img_size = to_2tuple(img_size)
12
+ patch_size = to_2tuple(patch_size)
13
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
14
+ self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0])
15
+ self.img_size = img_size
16
+ self.patch_size = patch_size
17
+ self.num_patches = num_patches
18
+
19
+ self.proj = nn.Conv2d(
20
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
21
+ )
22
+
23
+ def forward(self, x):
24
+ B, C, H, W = x.shape
25
+ # FIXME look at relaxing size constraints
26
+ # assert H == self.img_size[0] and W == self.img_size[1], \
27
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
28
+ x = self.proj(x)
29
+ y = x.flatten(2).transpose(1, 2)
30
+ return y
31
+
32
+
33
+ class PatchEmbed_new(nn.Module):
34
+ """Flexible Image to Patch Embedding"""
35
+
36
+ def __init__(
37
+ self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10
38
+ ):
39
+ super().__init__()
40
+ img_size = to_2tuple(img_size)
41
+ patch_size = to_2tuple(patch_size)
42
+ stride = to_2tuple(stride)
43
+
44
+ self.img_size = img_size
45
+ self.patch_size = patch_size
46
+
47
+ self.proj = nn.Conv2d(
48
+ in_chans, embed_dim, kernel_size=patch_size, stride=stride
49
+ ) # with overlapped patches
50
+ # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
51
+
52
+ # self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0])
53
+ # self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
54
+ _, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w
55
+ self.patch_hw = (h, w)
56
+ self.num_patches = h * w
57
+
58
+ def get_output_shape(self, img_size):
59
+ # todo: don't be lazy..
60
+ return self.proj(torch.randn(1, self.proj.in_channels, img_size[0], img_size[1])).shape  # probe with the conv's own channel count so in_chans != 1 also works
61
+
62
+ def forward(self, x):
63
+ B, C, H, W = x.shape
64
+ # FIXME look at relaxing size constraints
65
+ # assert H == self.img_size[0] and W == self.img_size[1], \
66
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
67
+ # x = self.proj(x).flatten(2).transpose(1, 2)
68
+ x = self.proj(x) # 32, 1, 1024, 128 -> 32, 768, 101, 12
69
+ x = x.flatten(2) # 32, 768, 101, 12 -> 32, 768, 1212
70
+ x = x.transpose(1, 2) # 32, 768, 1212 -> 32, 1212, 768
71
+ return x
72
+
73
+
74
+ class PatchEmbed3D_new(nn.Module):
75
+ """Flexible Image to Patch Embedding"""
76
+
77
+ def __init__(
78
+ self,
79
+ video_size=(16, 224, 224),
80
+ patch_size=(2, 16, 16),
81
+ in_chans=3,
82
+ embed_dim=768,
83
+ stride=(2, 16, 16),
84
+ ):
85
+ super().__init__()
86
+
87
+ self.video_size = video_size
88
+ self.patch_size = patch_size
89
+ self.in_chans = in_chans
90
+
91
+ self.proj = nn.Conv3d(
92
+ in_chans, embed_dim, kernel_size=patch_size, stride=stride
93
+ )
94
+ _, _, t, h, w = self.get_output_shape(video_size) # n, emb_dim, h, w
95
+ self.patch_thw = (t, h, w)
96
+ self.num_patches = t * h * w
97
+
98
+ def get_output_shape(self, video_size):
99
+ # todo: don't be lazy..
100
+ return self.proj(
101
+ torch.randn(1, self.in_chans, video_size[0], video_size[1], video_size[2])
102
+ ).shape
103
+
104
+ def forward(self, x):
105
+ B, C, T, H, W = x.shape
106
+ x = self.proj(x) # 32, 3, 16, 224, 224 -> 32, 768, 8, 14, 14
107
+ x = x.flatten(2) # 32, 768, 1568
108
+ x = x.transpose(1, 2) # 32, 768, 1568 -> 32, 1568, 768
109
+ return x
110
+
111
+
112
+ if __name__ == "__main__":
113
+ # patch_emb = PatchEmbed_new(img_size=224, patch_size=16, in_chans=1, embed_dim=64, stride=(16,16))
114
+ # input = torch.rand(8,1,1024,128)
115
+ # output = patch_emb(input)
116
+ # print(output.shape) # (8,512,64)
117
+
118
+ patch_emb = PatchEmbed3D_new(
119
+ video_size=(6, 224, 224),
120
+ patch_size=(2, 16, 16),
121
+ in_chans=3,
122
+ embed_dim=768,
123
+ stride=(2, 16, 16),
124
+ )
125
+ input = torch.rand(8, 3, 6, 224, 224)
126
+ output = patch_emb(input)
127
+ print(output.shape)  # torch.Size([8, 588, 768]): 3 * 14 * 14 patches, 768-dim embeddings
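
A complementary standalone sketch for the 2D `PatchEmbed_new` above on an AudioMAE-style mel spectrogram; the 1 x 1024 x 128 input shape and the import path are illustrative assumptions.

import torch

# Assumed import path, based on this repository's file layout; adjust as needed.
from FlashSR.AudioSR.latent_diffusion.modules.audiomae.util.patch_embed import PatchEmbed_new

# Non-overlapping 16x16 patches over a 1 x 1024 x 128 mel spectrogram.
patch_emb = PatchEmbed_new(
    img_size=(1024, 128), patch_size=(16, 16), in_chans=1, embed_dim=768, stride=16
)
print(patch_emb.patch_hw, patch_emb.num_patches)  # (64, 8) 512

x = torch.rand(2, 1, 1024, 128)
tokens = patch_emb(x)
print(tokens.shape)  # torch.Size([2, 512, 768])
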
FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/pos_embed.py ADDED
@@ -0,0 +1,206 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # Position embedding utils
8
+ # --------------------------------------------------------
9
+
10
+ import numpy as np
11
+
12
+ import torch
13
+
14
+
15
+ # --------------------------------------------------------
16
+ # 2D sine-cosine position embedding
17
+ # References:
18
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
19
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
20
+ # --------------------------------------------------------
21
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
22
+ """
23
+ grid_size: int of the grid height and width
24
+ return:
25
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
26
+ """
27
+ grid_h = np.arange(grid_size, dtype=np.float32)
28
+ grid_w = np.arange(grid_size, dtype=np.float32)
29
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
30
+ grid = np.stack(grid, axis=0)
31
+
32
+ grid = grid.reshape([2, 1, grid_size, grid_size])
33
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
34
+ if cls_token:
35
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
36
+ return pos_embed
37
+
38
+
39
+ def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False):
40
+ """
41
+ grid_size: int of the grid height and width
42
+ return:
43
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
44
+ """
45
+ grid_h = np.arange(grid_size[0], dtype=np.float32)
46
+ grid_w = np.arange(grid_size[1], dtype=np.float32)
47
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
48
+ grid = np.stack(grid, axis=0)
49
+
50
+ grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
51
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
52
+ if cls_token:
53
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
54
+ return pos_embed
55
+
56
+
57
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
58
+ assert embed_dim % 2 == 0
59
+
60
+ # use half of dimensions to encode grid_h
61
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
62
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
63
+
64
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
65
+ return emb
66
+
67
+
68
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
69
+ """
70
+ embed_dim: output dimension for each position
71
+ pos: a list of positions to be encoded: size (M,)
72
+ out: (M, D)
73
+ """
74
+ assert embed_dim % 2 == 0
75
+ # omega = np.arange(embed_dim // 2, dtype=np.float)
76
+ omega = np.arange(embed_dim // 2, dtype=float)
77
+ omega /= embed_dim / 2.0
78
+ omega = 1.0 / 10000**omega # (D/2,)
79
+
80
+ pos = pos.reshape(-1) # (M,)
81
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
82
+
83
+ emb_sin = np.sin(out) # (M, D/2)
84
+ emb_cos = np.cos(out) # (M, D/2)
85
+
86
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
87
+ return emb
88
+
89
+
90
+ # --------------------------------------------------------
91
+ # Interpolate position embeddings for high-resolution
92
+ # References:
93
+ # DeiT: https://github.com/facebookresearch/deit
94
+ # --------------------------------------------------------
95
+ def interpolate_pos_embed(model, checkpoint_model):
96
+ if "pos_embed" in checkpoint_model:
97
+ pos_embed_checkpoint = checkpoint_model["pos_embed"]
98
+ embedding_size = pos_embed_checkpoint.shape[-1]
99
+ num_patches = model.patch_embed.num_patches
100
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
101
+ # height (== width) for the checkpoint position embedding
102
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
103
+ # height (== width) for the new position embedding
104
+ new_size = int(num_patches**0.5)
105
+ # class_token and dist_token are kept unchanged
106
+ if orig_size != new_size:
107
+ print(
108
+ "Position interpolate from %dx%d to %dx%d"
109
+ % (orig_size, orig_size, new_size, new_size)
110
+ )
111
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
112
+ # only the position tokens are interpolated
113
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
114
+ pos_tokens = pos_tokens.reshape(
115
+ -1, orig_size, orig_size, embedding_size
116
+ ).permute(0, 3, 1, 2)
117
+ pos_tokens = torch.nn.functional.interpolate(
118
+ pos_tokens,
119
+ size=(new_size, new_size),
120
+ mode="bicubic",
121
+ align_corners=False,
122
+ )
123
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
124
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
125
+ checkpoint_model["pos_embed"] = new_pos_embed
126
+
127
+
128
+ def interpolate_pos_embed_img2audio(model, checkpoint_model, orig_size, new_size):
129
+ if "pos_embed" in checkpoint_model:
130
+ pos_embed_checkpoint = checkpoint_model["pos_embed"]
131
+ embedding_size = pos_embed_checkpoint.shape[-1]
132
+ num_patches = model.patch_embed.num_patches
133
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
134
+ # height (== width) for the checkpoint position embedding
135
+ # orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
136
+ # height (== width) for the new position embedding
137
+ # new_size = int(num_patches ** 0.5)
138
+ # class_token and dist_token are kept unchanged
139
+ if orig_size != new_size:
140
+ print(
141
+ "Position interpolate from %dx%d to %dx%d"
142
+ % (orig_size[0], orig_size[1], new_size[0], new_size[1])
143
+ )
144
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
145
+ # only the position tokens are interpolated
146
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
147
+ pos_tokens = pos_tokens.reshape(
148
+ -1, orig_size[0], orig_size[1], embedding_size
149
+ ).permute(0, 3, 1, 2)
150
+ pos_tokens = torch.nn.functional.interpolate(
151
+ pos_tokens,
152
+ size=(new_size[0], new_size[1]),
153
+ mode="bicubic",
154
+ align_corners=False,
155
+ )
156
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
157
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
158
+ checkpoint_model["pos_embed"] = new_pos_embed
159
+
160
+
161
+ def interpolate_pos_embed_audio(model, checkpoint_model, orig_size, new_size):
162
+ if "pos_embed" in checkpoint_model:
163
+ pos_embed_checkpoint = checkpoint_model["pos_embed"]
164
+ embedding_size = pos_embed_checkpoint.shape[-1]
165
+ num_patches = model.patch_embed.num_patches
166
+ model.pos_embed.shape[-2] - num_patches
167
+ if orig_size != new_size:
168
+ print(
169
+ "Position interpolate from %dx%d to %dx%d"
170
+ % (orig_size[0], orig_size[1], new_size[0], new_size[1])
171
+ )
172
+ # extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
173
+ # only the position tokens are interpolated
174
+ cls_token = pos_embed_checkpoint[:, 0, :].unsqueeze(1)
175
+ pos_tokens = pos_embed_checkpoint[:, 1:, :] # remove
176
+ pos_tokens = pos_tokens.reshape(
177
+ -1, orig_size[0], orig_size[1], embedding_size
178
+ ) # .permute(0, 3, 1, 2)
179
+ # pos_tokens = torch.nn.functional.interpolate(
180
+ # pos_tokens, size=(new_size[0], new_size[1]), mode='bicubic', align_corners=False)
181
+
182
+ # pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
183
+ pos_tokens = pos_tokens[:, :, : new_size[1], :] # assume only time diff
184
+ pos_tokens = pos_tokens.flatten(1, 2)
185
+ new_pos_embed = torch.cat((cls_token, pos_tokens), dim=1)
186
+ checkpoint_model["pos_embed"] = new_pos_embed
187
+
188
+
189
+ def interpolate_patch_embed_audio(
190
+ model,
191
+ checkpoint_model,
192
+ orig_channel,
193
+ new_channel=1,
194
+ kernel_size=(16, 16),
195
+ stride=(16, 16),
196
+ padding=(0, 0),
197
+ ):
198
+ if orig_channel != new_channel:
199
+ if "patch_embed.proj.weight" in checkpoint_model:
200
+ # aggregate 3 channels in rgb ckpt to 1 channel for audio
201
+ new_proj_weight = torch.nn.Parameter(
202
+ torch.sum(checkpoint_model["patch_embed.proj.weight"], dim=1).unsqueeze(
203
+ 1
204
+ )
205
+ )
206
+ checkpoint_model["patch_embed.proj.weight"] = new_proj_weight
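
A standalone sketch of the fixed 2D sin-cos table above for a 64 x 8 patch grid (the grid produced by the patch-embedding sketch earlier); the import path is an assumption.

import numpy as np

# Assumed import path, based on this repository's file layout; adjust as needed.
from FlashSR.AudioSR.latent_diffusion.modules.audiomae.util.pos_embed import (
    get_2d_sincos_pos_embed_flexible,
)

# Non-learned 2D sine-cosine embedding for a 64 x 8 grid, plus one all-zero cls slot.
pos = get_2d_sincos_pos_embed_flexible(embed_dim=768, grid_size=(64, 8), cls_token=True)
print(pos.shape)             # (513, 768): 1 cls row + 64 * 8 patch rows
print(np.abs(pos[0]).max())  # 0.0 -- the cls row is zeros
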
FlashSR/AudioSR/latent_diffusion/modules/audiomae/util/stat.py ADDED
@@ -0,0 +1,76 @@
1
+ import numpy as np
2
+ from scipy import stats
3
+ from sklearn import metrics
4
+ import torch
5
+
6
+
7
+ def d_prime(auc):
8
+ standard_normal = stats.norm()
9
+ d_prime = standard_normal.ppf(auc) * np.sqrt(2.0)
10
+ return d_prime
11
+
12
+
13
+ @torch.no_grad()
14
+ def concat_all_gather(tensor):
15
+ """
16
+ Performs all_gather operation on the provided tensors.
17
+ *** Warning ***: torch.distributed.all_gather has no gradient.
18
+ """
19
+ tensors_gather = [
20
+ torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
21
+ ]
22
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
23
+
24
+ output = torch.cat(tensors_gather, dim=0)
25
+ return output
26
+
27
+
28
+ def calculate_stats(output, target):
29
+ """Calculate statistics including mAP, AUC, etc.
30
+
31
+ Args:
32
+ output: 2d array, (samples_num, classes_num)
33
+ target: 2d array, (samples_num, classes_num)
34
+
35
+ Returns:
36
+ stats: list of statistic of each class.
37
+ """
38
+
39
+ classes_num = target.shape[-1]
40
+ stats = []
41
+
42
+ # Accuracy, only used for single-label classification such as esc-50, not for multiple label one such as AudioSet
43
+ acc = metrics.accuracy_score(np.argmax(target, 1), np.argmax(output, 1))
44
+
45
+ # Class-wise statistics
46
+ for k in range(classes_num):
47
+ # Average precision
48
+ avg_precision = metrics.average_precision_score(
49
+ target[:, k], output[:, k], average=None
50
+ )
51
+
52
+ # AUC
53
+ # auc = metrics.roc_auc_score(target[:, k], output[:, k], average=None)
54
+
55
+ # Precisions, recalls
56
+ (precisions, recalls, thresholds) = metrics.precision_recall_curve(
57
+ target[:, k], output[:, k]
58
+ )
59
+
60
+ # FPR, TPR
61
+ (fpr, tpr, thresholds) = metrics.roc_curve(target[:, k], output[:, k])
62
+
63
+ save_every_steps = 1000 # Sample statistics to reduce size
64
+ dict = {
65
+ "precisions": precisions[0::save_every_steps],
66
+ "recalls": recalls[0::save_every_steps],
67
+ "AP": avg_precision,
68
+ "fpr": fpr[0::save_every_steps],
69
+ "fnr": 1.0 - tpr[0::save_every_steps],
70
+ # 'auc': auc,
71
+ # note acc is not class-wise, this is just to keep consistent with other metrics
72
+ "acc": acc,
73
+ }
74
+ stats.append(dict)
75
+
76
+ return stats
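
A standalone sketch of `calculate_stats` above on random multi-label predictions; the data sizes and the import path are illustrative assumptions.

import numpy as np

# Assumed import path, based on this repository's file layout; adjust as needed.
from FlashSR.AudioSR.latent_diffusion.modules.audiomae.util.stat import calculate_stats

rng = np.random.default_rng(0)
num_samples, num_classes = 200, 10

output = rng.random((num_samples, num_classes))                         # model scores
target = (rng.random((num_samples, num_classes)) > 0.7).astype(float)   # multi-hot labels

stats = calculate_stats(output, target)
mAP = np.mean([s["AP"] for s in stats])
print(f"classes: {len(stats)}  mAP: {mAP:.3f}")  # roughly the positive rate (~0.3) for random scores
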
FlashSR/AudioSR/latent_diffusion/modules/diffusionmodules/__init__.py ADDED
File without changes
FlashSR/AudioSR/latent_diffusion/modules/diffusionmodules/model.py ADDED
@@ -0,0 +1,1069 @@
1
+ # pytorch_diffusion + derived encoder decoder
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from einops import rearrange
7
+
8
+ from audiosr.latent_diffusion.util import instantiate_from_config
9
+ from audiosr.latent_diffusion.modules.attention import LinearAttention
10
+
11
+
12
+ def get_timestep_embedding(timesteps, embedding_dim):
13
+ """
14
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
15
+ From Fairseq.
16
+ Build sinusoidal embeddings.
17
+ This matches the implementation in tensor2tensor, but differs slightly
18
+ from the description in Section 3.5 of "Attention Is All You Need".
19
+ """
20
+ assert len(timesteps.shape) == 1
21
+
22
+ half_dim = embedding_dim // 2
23
+ emb = math.log(10000) / (half_dim - 1)
24
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
25
+ emb = emb.to(device=timesteps.device)
26
+ emb = timesteps.float()[:, None] * emb[None, :]
27
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
28
+ if embedding_dim % 2 == 1: # zero pad
29
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
30
+ return emb
31
+
32
+
33
+ def nonlinearity(x):
34
+ # swish
35
+ return x * torch.sigmoid(x)
36
+
37
+
38
+ def Normalize(in_channels, num_groups=32):
39
+ return torch.nn.GroupNorm(
40
+ num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
41
+ )
42
+
43
+
44
+ class Upsample(nn.Module):
45
+ def __init__(self, in_channels, with_conv):
46
+ super().__init__()
47
+ self.with_conv = with_conv
48
+ if self.with_conv:
49
+ self.conv = torch.nn.Conv2d(
50
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
51
+ )
52
+
53
+ def forward(self, x):
54
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
55
+ if self.with_conv:
56
+ x = self.conv(x)
57
+ return x
58
+
59
+
60
+ class UpsampleTimeStride4(nn.Module):
61
+ def __init__(self, in_channels, with_conv):
62
+ super().__init__()
63
+ self.with_conv = with_conv
64
+ if self.with_conv:
65
+ self.conv = torch.nn.Conv2d(
66
+ in_channels, in_channels, kernel_size=5, stride=1, padding=2
67
+ )
68
+
69
+ def forward(self, x):
70
+ x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest")
71
+ if self.with_conv:
72
+ x = self.conv(x)
73
+ return x
74
+
75
+
76
+ class Downsample(nn.Module):
77
+ def __init__(self, in_channels, with_conv):
78
+ super().__init__()
79
+ self.with_conv = with_conv
80
+ if self.with_conv:
81
+ # Do time downsampling here
82
+ # no asymmetric padding in torch conv, must do it ourselves
83
+ self.conv = torch.nn.Conv2d(
84
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
85
+ )
86
+
87
+ def forward(self, x):
88
+ if self.with_conv:
89
+ pad = (0, 1, 0, 1)
90
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
91
+ x = self.conv(x)
92
+ else:
93
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
94
+ return x
95
+
96
+
97
+ class DownsampleTimeStride4(nn.Module):
98
+ def __init__(self, in_channels, with_conv):
99
+ super().__init__()
100
+ self.with_conv = with_conv
101
+ if self.with_conv:
102
+ # Do time downsampling here
103
+ # no asymmetric padding in torch conv, must do it ourselves
104
+ self.conv = torch.nn.Conv2d(
105
+ in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1
106
+ )
107
+
108
+ def forward(self, x):
109
+ if self.with_conv:
110
+ pad = (0, 1, 0, 1)
111
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
112
+ x = self.conv(x)
113
+ else:
114
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2))
115
+ return x
116
+
117
+
118
+ class ResnetBlock(nn.Module):
119
+ def __init__(
120
+ self,
121
+ *,
122
+ in_channels,
123
+ out_channels=None,
124
+ conv_shortcut=False,
125
+ dropout,
126
+ temb_channels=512,
127
+ ):
128
+ super().__init__()
129
+ self.in_channels = in_channels
130
+ out_channels = in_channels if out_channels is None else out_channels
131
+ self.out_channels = out_channels
132
+ self.use_conv_shortcut = conv_shortcut
133
+
134
+ self.norm1 = Normalize(in_channels)
135
+ self.conv1 = torch.nn.Conv2d(
136
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
137
+ )
138
+ if temb_channels > 0:
139
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
140
+ self.norm2 = Normalize(out_channels)
141
+ self.dropout = torch.nn.Dropout(dropout)
142
+ self.conv2 = torch.nn.Conv2d(
143
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
144
+ )
145
+ if self.in_channels != self.out_channels:
146
+ if self.use_conv_shortcut:
147
+ self.conv_shortcut = torch.nn.Conv2d(
148
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
149
+ )
150
+ else:
151
+ self.nin_shortcut = torch.nn.Conv2d(
152
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
153
+ )
154
+
155
+ def forward(self, x, temb):
156
+ h = x
157
+ h = self.norm1(h)
158
+ h = nonlinearity(h)
159
+ h = self.conv1(h)
160
+
161
+ if temb is not None:
162
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
163
+
164
+ h = self.norm2(h)
165
+ h = nonlinearity(h)
166
+ h = self.dropout(h)
167
+ h = self.conv2(h)
168
+
169
+ if self.in_channels != self.out_channels:
170
+ if self.use_conv_shortcut:
171
+ x = self.conv_shortcut(x)
172
+ else:
173
+ x = self.nin_shortcut(x)
174
+
175
+ return x + h
176
+
177
+
178
+ class LinAttnBlock(LinearAttention):
179
+ """to match AttnBlock usage"""
180
+
181
+ def __init__(self, in_channels):
182
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
183
+
184
+
185
+ class AttnBlock(nn.Module):
186
+ def __init__(self, in_channels):
187
+ super().__init__()
188
+ self.in_channels = in_channels
189
+
190
+ self.norm = Normalize(in_channels)
191
+ self.q = torch.nn.Conv2d(
192
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
193
+ )
194
+ self.k = torch.nn.Conv2d(
195
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
196
+ )
197
+ self.v = torch.nn.Conv2d(
198
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
199
+ )
200
+ self.proj_out = torch.nn.Conv2d(
201
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
202
+ )
203
+
204
+ def forward(self, x):
205
+ h_ = x
206
+ h_ = self.norm(h_)
207
+ q = self.q(h_)
208
+ k = self.k(h_)
209
+ v = self.v(h_)
210
+
211
+ # compute attention
212
+ b, c, h, w = q.shape
213
+ q = q.reshape(b, c, h * w).contiguous()
214
+ q = q.permute(0, 2, 1).contiguous() # b,hw,c
215
+ k = k.reshape(b, c, h * w).contiguous() # b,c,hw
216
+ w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
217
+ w_ = w_ * (int(c) ** (-0.5))
218
+ w_ = torch.nn.functional.softmax(w_, dim=2)
219
+
220
+ # attend to values
221
+ v = v.reshape(b, c, h * w).contiguous()
222
+ w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q)
223
+ h_ = torch.bmm(
224
+ v, w_
225
+ ).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
226
+ h_ = h_.reshape(b, c, h, w).contiguous()
227
+
228
+ h_ = self.proj_out(h_)
229
+
230
+ return x + h_
231
+
232
+
233
+ def make_attn(in_channels, attn_type="vanilla"):
234
+ assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
235
+ # print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
236
+ if attn_type == "vanilla":
237
+ return AttnBlock(in_channels)
238
+ elif attn_type == "none":
239
+ return nn.Identity(in_channels)
240
+ else:
241
+ return LinAttnBlock(in_channels)
242
+
243
+
244
+ class Model(nn.Module):
245
+ def __init__(
246
+ self,
247
+ *,
248
+ ch,
249
+ out_ch,
250
+ ch_mult=(1, 2, 4, 8),
251
+ num_res_blocks,
252
+ attn_resolutions,
253
+ dropout=0.0,
254
+ resamp_with_conv=True,
255
+ in_channels,
256
+ resolution,
257
+ use_timestep=True,
258
+ use_linear_attn=False,
259
+ attn_type="vanilla",
260
+ ):
261
+ super().__init__()
262
+ if use_linear_attn:
263
+ attn_type = "linear"
264
+ self.ch = ch
265
+ self.temb_ch = self.ch * 4
266
+ self.num_resolutions = len(ch_mult)
267
+ self.num_res_blocks = num_res_blocks
268
+ self.resolution = resolution
269
+ self.in_channels = in_channels
270
+
271
+ self.use_timestep = use_timestep
272
+ if self.use_timestep:
273
+ # timestep embedding
274
+ self.temb = nn.Module()
275
+ self.temb.dense = nn.ModuleList(
276
+ [
277
+ torch.nn.Linear(self.ch, self.temb_ch),
278
+ torch.nn.Linear(self.temb_ch, self.temb_ch),
279
+ ]
280
+ )
281
+
282
+ # downsampling
283
+ self.conv_in = torch.nn.Conv2d(
284
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
285
+ )
286
+
287
+ curr_res = resolution
288
+ in_ch_mult = (1,) + tuple(ch_mult)
289
+ self.down = nn.ModuleList()
290
+ for i_level in range(self.num_resolutions):
291
+ block = nn.ModuleList()
292
+ attn = nn.ModuleList()
293
+ block_in = ch * in_ch_mult[i_level]
294
+ block_out = ch * ch_mult[i_level]
295
+ for i_block in range(self.num_res_blocks):
296
+ block.append(
297
+ ResnetBlock(
298
+ in_channels=block_in,
299
+ out_channels=block_out,
300
+ temb_channels=self.temb_ch,
301
+ dropout=dropout,
302
+ )
303
+ )
304
+ block_in = block_out
305
+ if curr_res in attn_resolutions:
306
+ attn.append(make_attn(block_in, attn_type=attn_type))
307
+ down = nn.Module()
308
+ down.block = block
309
+ down.attn = attn
310
+ if i_level != self.num_resolutions - 1:
311
+ down.downsample = Downsample(block_in, resamp_with_conv)
312
+ curr_res = curr_res // 2
313
+ self.down.append(down)
314
+
315
+ # middle
316
+ self.mid = nn.Module()
317
+ self.mid.block_1 = ResnetBlock(
318
+ in_channels=block_in,
319
+ out_channels=block_in,
320
+ temb_channels=self.temb_ch,
321
+ dropout=dropout,
322
+ )
323
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
324
+ self.mid.block_2 = ResnetBlock(
325
+ in_channels=block_in,
326
+ out_channels=block_in,
327
+ temb_channels=self.temb_ch,
328
+ dropout=dropout,
329
+ )
330
+
331
+ # upsampling
332
+ self.up = nn.ModuleList()
333
+ for i_level in reversed(range(self.num_resolutions)):
334
+ block = nn.ModuleList()
335
+ attn = nn.ModuleList()
336
+ block_out = ch * ch_mult[i_level]
337
+ skip_in = ch * ch_mult[i_level]
338
+ for i_block in range(self.num_res_blocks + 1):
339
+ if i_block == self.num_res_blocks:
340
+ skip_in = ch * in_ch_mult[i_level]
341
+ block.append(
342
+ ResnetBlock(
343
+ in_channels=block_in + skip_in,
344
+ out_channels=block_out,
345
+ temb_channels=self.temb_ch,
346
+ dropout=dropout,
347
+ )
348
+ )
349
+ block_in = block_out
350
+ if curr_res in attn_resolutions:
351
+ attn.append(make_attn(block_in, attn_type=attn_type))
352
+ up = nn.Module()
353
+ up.block = block
354
+ up.attn = attn
355
+ if i_level != 0:
356
+ up.upsample = Upsample(block_in, resamp_with_conv)
357
+ curr_res = curr_res * 2
358
+ self.up.insert(0, up) # prepend to get consistent order
359
+
360
+ # end
361
+ self.norm_out = Normalize(block_in)
362
+ self.conv_out = torch.nn.Conv2d(
363
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
364
+ )
365
+
366
+ def forward(self, x, t=None, context=None):
367
+ # assert x.shape[2] == x.shape[3] == self.resolution
368
+ if context is not None:
369
+ # assume aligned context, cat along channel axis
370
+ x = torch.cat((x, context), dim=1)
371
+ if self.use_timestep:
372
+ # timestep embedding
373
+ assert t is not None
374
+ temb = get_timestep_embedding(t, self.ch)
375
+ temb = self.temb.dense[0](temb)
376
+ temb = nonlinearity(temb)
377
+ temb = self.temb.dense[1](temb)
378
+ else:
379
+ temb = None
380
+
381
+ # downsampling
382
+ hs = [self.conv_in(x)]
383
+ for i_level in range(self.num_resolutions):
384
+ for i_block in range(self.num_res_blocks):
385
+ h = self.down[i_level].block[i_block](hs[-1], temb)
386
+ if len(self.down[i_level].attn) > 0:
387
+ h = self.down[i_level].attn[i_block](h)
388
+ hs.append(h)
389
+ if i_level != self.num_resolutions - 1:
390
+ hs.append(self.down[i_level].downsample(hs[-1]))
391
+
392
+ # middle
393
+ h = hs[-1]
394
+ h = self.mid.block_1(h, temb)
395
+ h = self.mid.attn_1(h)
396
+ h = self.mid.block_2(h, temb)
397
+
398
+ # upsampling
399
+ for i_level in reversed(range(self.num_resolutions)):
400
+ for i_block in range(self.num_res_blocks + 1):
401
+ h = self.up[i_level].block[i_block](
402
+ torch.cat([h, hs.pop()], dim=1), temb
403
+ )
404
+ if len(self.up[i_level].attn) > 0:
405
+ h = self.up[i_level].attn[i_block](h)
406
+ if i_level != 0:
407
+ h = self.up[i_level].upsample(h)
408
+
409
+ # end
410
+ h = self.norm_out(h)
411
+ h = nonlinearity(h)
412
+ h = self.conv_out(h)
413
+ return h
414
+
415
+ def get_last_layer(self):
416
+ return self.conv_out.weight
417
+
418
+
419
+ class Encoder(nn.Module):
420
+ def __init__(
421
+ self,
422
+ *,
423
+ ch,
424
+ out_ch,
425
+ ch_mult=(1, 2, 4, 8),
426
+ num_res_blocks,
427
+ attn_resolutions,
428
+ dropout=0.0,
429
+ resamp_with_conv=True,
430
+ in_channels,
431
+ resolution,
432
+ z_channels,
433
+ double_z=True,
434
+ use_linear_attn=False,
435
+ attn_type="vanilla",
436
+ downsample_time_stride4_levels=[],
437
+ **ignore_kwargs,
438
+ ):
439
+ super().__init__()
440
+ if use_linear_attn:
441
+ attn_type = "linear"
442
+ self.ch = ch
443
+ self.temb_ch = 0
444
+ self.num_resolutions = len(ch_mult)
445
+ self.num_res_blocks = num_res_blocks
446
+ self.resolution = resolution
447
+ self.in_channels = in_channels
448
+ self.downsample_time_stride4_levels = downsample_time_stride4_levels
449
+
450
+ if len(self.downsample_time_stride4_levels) > 0:
451
+ assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
452
+ "The level to perform downsample 4 operation need to be smaller than the total resolution number %s"
453
+ % str(self.num_resolutions)
454
+ )
455
+
456
+ # downsampling
457
+ self.conv_in = torch.nn.Conv2d(
458
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
459
+ )
460
+
461
+ curr_res = resolution
462
+ in_ch_mult = (1,) + tuple(ch_mult)
463
+ self.in_ch_mult = in_ch_mult
464
+ self.down = nn.ModuleList()
465
+ for i_level in range(self.num_resolutions):
466
+ block = nn.ModuleList()
467
+ attn = nn.ModuleList()
468
+ block_in = ch * in_ch_mult[i_level]
469
+ block_out = ch * ch_mult[i_level]
470
+ for i_block in range(self.num_res_blocks):
471
+ block.append(
472
+ ResnetBlock(
473
+ in_channels=block_in,
474
+ out_channels=block_out,
475
+ temb_channels=self.temb_ch,
476
+ dropout=dropout,
477
+ )
478
+ )
479
+ block_in = block_out
480
+ if curr_res in attn_resolutions:
481
+ attn.append(make_attn(block_in, attn_type=attn_type))
482
+ down = nn.Module()
483
+ down.block = block
484
+ down.attn = attn
485
+ if i_level != self.num_resolutions - 1:
486
+ if i_level in self.downsample_time_stride4_levels:
487
+ down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv)
488
+ else:
489
+ down.downsample = Downsample(block_in, resamp_with_conv)
490
+ curr_res = curr_res // 2
491
+ self.down.append(down)
492
+
493
+ # middle
494
+ self.mid = nn.Module()
495
+ self.mid.block_1 = ResnetBlock(
496
+ in_channels=block_in,
497
+ out_channels=block_in,
498
+ temb_channels=self.temb_ch,
499
+ dropout=dropout,
500
+ )
501
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
502
+ self.mid.block_2 = ResnetBlock(
503
+ in_channels=block_in,
504
+ out_channels=block_in,
505
+ temb_channels=self.temb_ch,
506
+ dropout=dropout,
507
+ )
508
+
509
+ # end
510
+ self.norm_out = Normalize(block_in)
511
+ self.conv_out = torch.nn.Conv2d(
512
+ block_in,
513
+ 2 * z_channels if double_z else z_channels,
514
+ kernel_size=3,
515
+ stride=1,
516
+ padding=1,
517
+ )
518
+
519
+ def forward(self, x):
520
+ # timestep embedding
521
+ temb = None
522
+ # downsampling
523
+ hs = [self.conv_in(x)]
524
+ for i_level in range(self.num_resolutions):
525
+ for i_block in range(self.num_res_blocks):
526
+ h = self.down[i_level].block[i_block](hs[-1], temb)
527
+ if len(self.down[i_level].attn) > 0:
528
+ h = self.down[i_level].attn[i_block](h)
529
+ hs.append(h)
530
+ if i_level != self.num_resolutions - 1:
531
+ hs.append(self.down[i_level].downsample(hs[-1]))
532
+
533
+ # middle
534
+ h = hs[-1]
535
+ h = self.mid.block_1(h, temb)
536
+ h = self.mid.attn_1(h)
537
+ h = self.mid.block_2(h, temb)
538
+
539
+ # end
540
+ h = self.norm_out(h)
541
+ h = nonlinearity(h)
542
+ h = self.conv_out(h)
543
+ return h
544
+
545
+
546
+ class Decoder(nn.Module):
547
+ def __init__(
548
+ self,
549
+ *,
550
+ ch,
551
+ out_ch,
552
+ ch_mult=(1, 2, 4, 8),
553
+ num_res_blocks,
554
+ attn_resolutions,
555
+ dropout=0.0,
556
+ resamp_with_conv=True,
557
+ in_channels,
558
+ resolution,
559
+ z_channels,
560
+ give_pre_end=False,
561
+ tanh_out=False,
562
+ use_linear_attn=False,
563
+ downsample_time_stride4_levels=[],
564
+ attn_type="vanilla",
565
+ **ignorekwargs,
566
+ ):
567
+ super().__init__()
568
+ if use_linear_attn:
569
+ attn_type = "linear"
570
+ self.ch = ch
571
+ self.temb_ch = 0
572
+ self.num_resolutions = len(ch_mult)
573
+ self.num_res_blocks = num_res_blocks
574
+ self.resolution = resolution
575
+ self.in_channels = in_channels
576
+ self.give_pre_end = give_pre_end
577
+ self.tanh_out = tanh_out
578
+ self.downsample_time_stride4_levels = downsample_time_stride4_levels
579
+
580
+ if len(self.downsample_time_stride4_levels) > 0:
581
+ assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
582
+ "The level to perform downsample 4 operation need to be smaller than the total resolution number %s"
583
+ % str(self.num_resolutions)
584
+ )
585
+
586
+ # compute in_ch_mult, block_in and curr_res at lowest res
587
+ (1,) + tuple(ch_mult)
588
+ block_in = ch * ch_mult[self.num_resolutions - 1]
589
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
590
+ self.z_shape = (1, z_channels, curr_res, curr_res)
591
+ # print(
592
+ # "Working with z of shape {} = {} dimensions.".format(
593
+ # self.z_shape, np.prod(self.z_shape)
594
+ # )
595
+ # )
596
+
597
+ # z to block_in
598
+ self.conv_in = torch.nn.Conv2d(
599
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
600
+ )
601
+
602
+ # middle
603
+ self.mid = nn.Module()
604
+ self.mid.block_1 = ResnetBlock(
605
+ in_channels=block_in,
606
+ out_channels=block_in,
607
+ temb_channels=self.temb_ch,
608
+ dropout=dropout,
609
+ )
610
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
611
+ self.mid.block_2 = ResnetBlock(
612
+ in_channels=block_in,
613
+ out_channels=block_in,
614
+ temb_channels=self.temb_ch,
615
+ dropout=dropout,
616
+ )
617
+
618
+ # upsampling
619
+ self.up = nn.ModuleList()
620
+ for i_level in reversed(range(self.num_resolutions)):
621
+ block = nn.ModuleList()
622
+ attn = nn.ModuleList()
623
+ block_out = ch * ch_mult[i_level]
624
+ for i_block in range(self.num_res_blocks + 1):
625
+ block.append(
626
+ ResnetBlock(
627
+ in_channels=block_in,
628
+ out_channels=block_out,
629
+ temb_channels=self.temb_ch,
630
+ dropout=dropout,
631
+ )
632
+ )
633
+ block_in = block_out
634
+ if curr_res in attn_resolutions:
635
+ attn.append(make_attn(block_in, attn_type=attn_type))
636
+ up = nn.Module()
637
+ up.block = block
638
+ up.attn = attn
639
+ if i_level != 0:
640
+ if i_level - 1 in self.downsample_time_stride4_levels:
641
+ up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv)
642
+ else:
643
+ up.upsample = Upsample(block_in, resamp_with_conv)
644
+ curr_res = curr_res * 2
645
+ self.up.insert(0, up) # prepend to get consistent order
646
+
647
+ # end
648
+ self.norm_out = Normalize(block_in)
649
+ self.conv_out = torch.nn.Conv2d(
650
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
651
+ )
652
+
653
+ def forward(self, z):
654
+ # assert z.shape[1:] == self.z_shape[1:]
655
+ self.last_z_shape = z.shape
656
+
657
+ # timestep embedding
658
+ temb = None
659
+
660
+ # z to block_in
661
+ h = self.conv_in(z)
662
+
663
+ # middle
664
+ h = self.mid.block_1(h, temb)
665
+ h = self.mid.attn_1(h)
666
+ h = self.mid.block_2(h, temb)
667
+
668
+ # upsampling
669
+ for i_level in reversed(range(self.num_resolutions)):
670
+ for i_block in range(self.num_res_blocks + 1):
671
+ h = self.up[i_level].block[i_block](h, temb)
672
+ if len(self.up[i_level].attn) > 0:
673
+ h = self.up[i_level].attn[i_block](h)
674
+ if i_level != 0:
675
+ h = self.up[i_level].upsample(h)
676
+
677
+ # end
678
+ if self.give_pre_end:
679
+ return h
680
+
681
+ h = self.norm_out(h)
682
+ h = nonlinearity(h)
683
+ h = self.conv_out(h)
684
+ if self.tanh_out:
685
+ h = torch.tanh(h)
686
+ return h
687
+
688
+
689
+ class SimpleDecoder(nn.Module):
690
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
691
+ super().__init__()
692
+ self.model = nn.ModuleList(
693
+ [
694
+ nn.Conv2d(in_channels, in_channels, 1),
695
+ ResnetBlock(
696
+ in_channels=in_channels,
697
+ out_channels=2 * in_channels,
698
+ temb_channels=0,
699
+ dropout=0.0,
700
+ ),
701
+ ResnetBlock(
702
+ in_channels=2 * in_channels,
703
+ out_channels=4 * in_channels,
704
+ temb_channels=0,
705
+ dropout=0.0,
706
+ ),
707
+ ResnetBlock(
708
+ in_channels=4 * in_channels,
709
+ out_channels=2 * in_channels,
710
+ temb_channels=0,
711
+ dropout=0.0,
712
+ ),
713
+ nn.Conv2d(2 * in_channels, in_channels, 1),
714
+ Upsample(in_channels, with_conv=True),
715
+ ]
716
+ )
717
+ # end
718
+ self.norm_out = Normalize(in_channels)
719
+ self.conv_out = torch.nn.Conv2d(
720
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
721
+ )
722
+
723
+ def forward(self, x):
724
+ for i, layer in enumerate(self.model):
725
+ if i in [1, 2, 3]:
726
+ x = layer(x, None)
727
+ else:
728
+ x = layer(x)
729
+
730
+ h = self.norm_out(x)
731
+ h = nonlinearity(h)
732
+ x = self.conv_out(h)
733
+ return x
734
+
735
+
736
+ class UpsampleDecoder(nn.Module):
737
+ def __init__(
738
+ self,
739
+ in_channels,
740
+ out_channels,
741
+ ch,
742
+ num_res_blocks,
743
+ resolution,
744
+ ch_mult=(2, 2),
745
+ dropout=0.0,
746
+ ):
747
+ super().__init__()
748
+ # upsampling
749
+ self.temb_ch = 0
750
+ self.num_resolutions = len(ch_mult)
751
+ self.num_res_blocks = num_res_blocks
752
+ block_in = in_channels
753
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
754
+ self.res_blocks = nn.ModuleList()
755
+ self.upsample_blocks = nn.ModuleList()
756
+ for i_level in range(self.num_resolutions):
757
+ res_block = []
758
+ block_out = ch * ch_mult[i_level]
759
+ for i_block in range(self.num_res_blocks + 1):
760
+ res_block.append(
761
+ ResnetBlock(
762
+ in_channels=block_in,
763
+ out_channels=block_out,
764
+ temb_channels=self.temb_ch,
765
+ dropout=dropout,
766
+ )
767
+ )
768
+ block_in = block_out
769
+ self.res_blocks.append(nn.ModuleList(res_block))
770
+ if i_level != self.num_resolutions - 1:
771
+ self.upsample_blocks.append(Upsample(block_in, True))
772
+ curr_res = curr_res * 2
773
+
774
+ # end
775
+ self.norm_out = Normalize(block_in)
776
+ self.conv_out = torch.nn.Conv2d(
777
+ block_in, out_channels, kernel_size=3, stride=1, padding=1
778
+ )
779
+
780
+ def forward(self, x):
781
+ # upsampling
782
+ h = x
783
+ for k, i_level in enumerate(range(self.num_resolutions)):
784
+ for i_block in range(self.num_res_blocks + 1):
785
+ h = self.res_blocks[i_level][i_block](h, None)
786
+ if i_level != self.num_resolutions - 1:
787
+ h = self.upsample_blocks[k](h)
788
+ h = self.norm_out(h)
789
+ h = nonlinearity(h)
790
+ h = self.conv_out(h)
791
+ return h
792
+
793
+
794
+ class LatentRescaler(nn.Module):
795
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
796
+ super().__init__()
797
+ # residual block, interpolate, residual block
798
+ self.factor = factor
799
+ self.conv_in = nn.Conv2d(
800
+ in_channels, mid_channels, kernel_size=3, stride=1, padding=1
801
+ )
802
+ self.res_block1 = nn.ModuleList(
803
+ [
804
+ ResnetBlock(
805
+ in_channels=mid_channels,
806
+ out_channels=mid_channels,
807
+ temb_channels=0,
808
+ dropout=0.0,
809
+ )
810
+ for _ in range(depth)
811
+ ]
812
+ )
813
+ self.attn = AttnBlock(mid_channels)
814
+ self.res_block2 = nn.ModuleList(
815
+ [
816
+ ResnetBlock(
817
+ in_channels=mid_channels,
818
+ out_channels=mid_channels,
819
+ temb_channels=0,
820
+ dropout=0.0,
821
+ )
822
+ for _ in range(depth)
823
+ ]
824
+ )
825
+
826
+ self.conv_out = nn.Conv2d(
827
+ mid_channels,
828
+ out_channels,
829
+ kernel_size=1,
830
+ )
831
+
832
+ def forward(self, x):
833
+ x = self.conv_in(x)
834
+ for block in self.res_block1:
835
+ x = block(x, None)
836
+ x = torch.nn.functional.interpolate(
837
+ x,
838
+ size=(
839
+ int(round(x.shape[2] * self.factor)),
840
+ int(round(x.shape[3] * self.factor)),
841
+ ),
842
+ )
843
+ x = self.attn(x).contiguous()
844
+ for block in self.res_block2:
845
+ x = block(x, None)
846
+ x = self.conv_out(x)
847
+ return x
848
+
849
+
850
+ class MergedRescaleEncoder(nn.Module):
851
+ def __init__(
852
+ self,
853
+ in_channels,
854
+ ch,
855
+ resolution,
856
+ out_ch,
857
+ num_res_blocks,
858
+ attn_resolutions,
859
+ dropout=0.0,
860
+ resamp_with_conv=True,
861
+ ch_mult=(1, 2, 4, 8),
862
+ rescale_factor=1.0,
863
+ rescale_module_depth=1,
864
+ ):
865
+ super().__init__()
866
+ intermediate_chn = ch * ch_mult[-1]
867
+ self.encoder = Encoder(
868
+ in_channels=in_channels,
869
+ num_res_blocks=num_res_blocks,
870
+ ch=ch,
871
+ ch_mult=ch_mult,
872
+ z_channels=intermediate_chn,
873
+ double_z=False,
874
+ resolution=resolution,
875
+ attn_resolutions=attn_resolutions,
876
+ dropout=dropout,
877
+ resamp_with_conv=resamp_with_conv,
878
+ out_ch=None,
879
+ )
880
+ self.rescaler = LatentRescaler(
881
+ factor=rescale_factor,
882
+ in_channels=intermediate_chn,
883
+ mid_channels=intermediate_chn,
884
+ out_channels=out_ch,
885
+ depth=rescale_module_depth,
886
+ )
887
+
888
+ def forward(self, x):
889
+ x = self.encoder(x)
890
+ x = self.rescaler(x)
891
+ return x
892
+
893
+
894
+ class MergedRescaleDecoder(nn.Module):
895
+ def __init__(
896
+ self,
897
+ z_channels,
898
+ out_ch,
899
+ resolution,
900
+ num_res_blocks,
901
+ attn_resolutions,
902
+ ch,
903
+ ch_mult=(1, 2, 4, 8),
904
+ dropout=0.0,
905
+ resamp_with_conv=True,
906
+ rescale_factor=1.0,
907
+ rescale_module_depth=1,
908
+ ):
909
+ super().__init__()
910
+ tmp_chn = z_channels * ch_mult[-1]
911
+ self.decoder = Decoder(
912
+ out_ch=out_ch,
913
+ z_channels=tmp_chn,
914
+ attn_resolutions=attn_resolutions,
915
+ dropout=dropout,
916
+ resamp_with_conv=resamp_with_conv,
917
+ in_channels=None,
918
+ num_res_blocks=num_res_blocks,
919
+ ch_mult=ch_mult,
920
+ resolution=resolution,
921
+ ch=ch,
922
+ )
923
+ self.rescaler = LatentRescaler(
924
+ factor=rescale_factor,
925
+ in_channels=z_channels,
926
+ mid_channels=tmp_chn,
927
+ out_channels=tmp_chn,
928
+ depth=rescale_module_depth,
929
+ )
930
+
931
+ def forward(self, x):
932
+ x = self.rescaler(x)
933
+ x = self.decoder(x)
934
+ return x
935
+
936
+
937
+ class Upsampler(nn.Module):
938
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
939
+ super().__init__()
940
+ assert out_size >= in_size
941
+ num_blocks = int(np.log2(out_size // in_size)) + 1
942
+ factor_up = 1.0 + (out_size % in_size)
943
+ print(
944
+ f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
945
+ )
946
+ self.rescaler = LatentRescaler(
947
+ factor=factor_up,
948
+ in_channels=in_channels,
949
+ mid_channels=2 * in_channels,
950
+ out_channels=in_channels,
951
+ )
952
+ self.decoder = Decoder(
953
+ out_ch=out_channels,
954
+ resolution=out_size,
955
+ z_channels=in_channels,
956
+ num_res_blocks=2,
957
+ attn_resolutions=[],
958
+ in_channels=None,
959
+ ch=in_channels,
960
+ ch_mult=[ch_mult for _ in range(num_blocks)],
961
+ )
962
+
963
+ def forward(self, x):
964
+ x = self.rescaler(x)
965
+ x = self.decoder(x)
966
+ return x
967
+
968
+
969
+ class Resize(nn.Module):
970
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
971
+ super().__init__()
972
+ self.with_conv = learned
973
+ self.mode = mode
974
+ if self.with_conv:
975
+ print(
976
+ f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode"
977
+ )
978
+ raise NotImplementedError()
979
+ assert in_channels is not None
980
+ # no asymmetric padding in torch conv, must do it ourselves
981
+ self.conv = torch.nn.Conv2d(
982
+ in_channels, in_channels, kernel_size=4, stride=2, padding=1
983
+ )
984
+
985
+ def forward(self, x, scale_factor=1.0):
986
+ if scale_factor == 1.0:
987
+ return x
988
+ else:
989
+ x = torch.nn.functional.interpolate(
990
+ x, mode=self.mode, align_corners=False, scale_factor=scale_factor
991
+ )
992
+ return x
993
+
994
+
995
+ class FirstStagePostProcessor(nn.Module):
996
+ def __init__(
997
+ self,
998
+ ch_mult: list,
999
+ in_channels,
1000
+ pretrained_model: nn.Module = None,
1001
+ reshape=False,
1002
+ n_channels=None,
1003
+ dropout=0.0,
1004
+ pretrained_config=None,
1005
+ ):
1006
+ super().__init__()
1007
+ if pretrained_config is None:
1008
+ assert (
1009
+ pretrained_model is not None
1010
+ ), 'Either "pretrained_model" or "pretrained_config" must not be None'
1011
+ self.pretrained_model = pretrained_model
1012
+ else:
1013
+ assert (
1014
+ pretrained_config is not None
1015
+ ), 'Either "pretrained_model" or "pretrained_config" must not be None'
1016
+ self.instantiate_pretrained(pretrained_config)
1017
+
1018
+ self.do_reshape = reshape
1019
+
1020
+ if n_channels is None:
1021
+ n_channels = self.pretrained_model.encoder.ch
1022
+
1023
+ self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)
1024
+ self.proj = nn.Conv2d(
1025
+ in_channels, n_channels, kernel_size=3, stride=1, padding=1
1026
+ )
1027
+
1028
+ blocks = []
1029
+ downs = []
1030
+ ch_in = n_channels
1031
+ for m in ch_mult:
1032
+ blocks.append(
1033
+ ResnetBlock(
1034
+ in_channels=ch_in, out_channels=m * n_channels, dropout=dropout
1035
+ )
1036
+ )
1037
+ ch_in = m * n_channels
1038
+ downs.append(Downsample(ch_in, with_conv=False))
1039
+
1040
+ self.model = nn.ModuleList(blocks)
1041
+ self.downsampler = nn.ModuleList(downs)
1042
+
1043
+ def instantiate_pretrained(self, config):
1044
+ model = instantiate_from_config(config)
1045
+ self.pretrained_model = model.eval()
1046
+ # self.pretrained_model.train = False
1047
+ for param in self.pretrained_model.parameters():
1048
+ param.requires_grad = False
1049
+
1050
+ @torch.no_grad()
1051
+ def encode_with_pretrained(self, x):
1052
+ c = self.pretrained_model.encode(x)
1053
+ if isinstance(c, DiagonalGaussianDistribution):
1054
+ c = c.mode()
1055
+ return c
1056
+
1057
+ def forward(self, x):
1058
+ z_fs = self.encode_with_pretrained(x)
1059
+ z = self.proj_norm(z_fs)
1060
+ z = self.proj(z)
1061
+ z = nonlinearity(z)
1062
+
1063
+ for submodel, downmodel in zip(self.model, self.downsampler):
1064
+ z = submodel(z, temb=None)
1065
+ z = downmodel(z)
1066
+
1067
+ if self.do_reshape:
1068
+ z = rearrange(z, "b c h w -> b (h w) c")
1069
+ return z
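A minimal sketch of how the rescaling helpers above can be exercised, assuming the repository root is on PYTHONPATH so this module is importable as FlashSR.AudioSR.latent_diffusion.modules.diffusionmodules.model (class and argument names come from the code above; the shapes and factor are illustrative only):

import torch
from FlashSR.AudioSR.latent_diffusion.modules.diffusionmodules.model import LatentRescaler

# conv_in -> resnet blocks -> interpolate by `factor` -> attention -> resnet blocks -> 1x1 conv_out
rescaler = LatentRescaler(factor=2.0, in_channels=16, mid_channels=32, out_channels=8, depth=1)
z = torch.randn(1, 16, 24, 24)  # dummy latent: (batch, channels, height, width)
with torch.no_grad():
    out = rescaler(z)
print(out.shape)  # torch.Size([1, 8, 48, 48]): spatial size scaled by `factor`, channels mapped to out_channels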
FlashSR/AudioSR/latent_diffusion/modules/diffusionmodules/openaimodel.py ADDED
@@ -0,0 +1,1103 @@
1
+ from abc import abstractmethod
2
+ import math
3
+
4
+ import numpy as np
5
+ import torch as th
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from audiosr.latent_diffusion.modules.diffusionmodules.util import (
10
+ checkpoint,
11
+ conv_nd,
12
+ linear,
13
+ avg_pool_nd,
14
+ zero_module,
15
+ normalization,
16
+ timestep_embedding,
17
+ )
18
+ from audiosr.latent_diffusion.modules.attention import SpatialTransformer
19
+
20
+
21
+ # dummy replace
22
+ def convert_module_to_f16(x):
23
+ pass
24
+
25
+
26
+ def convert_module_to_f32(x):
27
+ pass
28
+
29
+
30
+ ## go
31
+ class AttentionPool2d(nn.Module):
32
+ """
33
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
34
+ """
35
+
36
+ def __init__(
37
+ self,
38
+ spacial_dim: int,
39
+ embed_dim: int,
40
+ num_heads_channels: int,
41
+ output_dim: int = None,
42
+ ):
43
+ super().__init__()
44
+ self.positional_embedding = nn.Parameter(
45
+ th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
46
+ )
47
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
48
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
49
+ self.num_heads = embed_dim // num_heads_channels
50
+ self.attention = QKVAttention(self.num_heads)
51
+
52
+ def forward(self, x):
53
+ b, c, *_spatial = x.shape
54
+ x = x.reshape(b, c, -1).contiguous() # NC(HW)
55
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
56
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
57
+ x = self.qkv_proj(x)
58
+ x = self.attention(x)
59
+ x = self.c_proj(x)
60
+ return x[:, :, 0]
61
+
62
+
63
+ class TimestepBlock(nn.Module):
64
+ """
65
+ Any module where forward() takes timestep embeddings as a second argument.
66
+ """
67
+
68
+ @abstractmethod
69
+ def forward(self, x, emb):
70
+ """
71
+ Apply the module to `x` given `emb` timestep embeddings.
72
+ """
73
+
74
+
75
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
76
+ """
77
+ A sequential module that passes timestep embeddings to the children that
78
+ support it as an extra input.
79
+ """
80
+
81
+ def forward(self, x, emb, context_list=None, mask_list=None):
82
+ # The first spatial transformer block does not have context
83
+ spatial_transformer_id = 0
84
+ context_list = [None] + (context_list or [])
85
+ mask_list = [None] + (mask_list or [])
86
+
87
+ for layer in self:
88
+ if isinstance(layer, TimestepBlock):
89
+ x = layer(x, emb)
90
+ elif isinstance(layer, SpatialTransformer):
91
+ if spatial_transformer_id >= len(context_list):
92
+ context, mask = None, None
93
+ else:
94
+ context, mask = (
95
+ context_list[spatial_transformer_id],
96
+ mask_list[spatial_transformer_id],
97
+ )
98
+
99
+ x = layer(x, context, mask=mask)
100
+ spatial_transformer_id += 1
101
+ else:
102
+ x = layer(x)
103
+ return x
104
+
105
+
106
+ class Upsample(nn.Module):
107
+ """
108
+ An upsampling layer with an optional convolution.
109
+ :param channels: channels in the inputs and outputs.
110
+ :param use_conv: a bool determining if a convolution is applied.
111
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
112
+ upsampling occurs in the inner-two dimensions.
113
+ """
114
+
115
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
116
+ super().__init__()
117
+ self.channels = channels
118
+ self.out_channels = out_channels or channels
119
+ self.use_conv = use_conv
120
+ self.dims = dims
121
+ if use_conv:
122
+ self.conv = conv_nd(
123
+ dims, self.channels, self.out_channels, 3, padding=padding
124
+ )
125
+
126
+ def forward(self, x):
127
+ assert x.shape[1] == self.channels
128
+ if self.dims == 3:
129
+ x = F.interpolate(
130
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
131
+ )
132
+ else:
133
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
134
+ if self.use_conv:
135
+ x = self.conv(x)
136
+ return x
137
+
138
+
139
+ class TransposedUpsample(nn.Module):
140
+ "Learned 2x upsampling without padding"
141
+
142
+ def __init__(self, channels, out_channels=None, ks=5):
143
+ super().__init__()
144
+ self.channels = channels
145
+ self.out_channels = out_channels or channels
146
+
147
+ self.up = nn.ConvTranspose2d(
148
+ self.channels, self.out_channels, kernel_size=ks, stride=2
149
+ )
150
+
151
+ def forward(self, x):
152
+ return self.up(x)
153
+
154
+
155
+ class Downsample(nn.Module):
156
+ """
157
+ A downsampling layer with an optional convolution.
158
+ :param channels: channels in the inputs and outputs.
159
+ :param use_conv: a bool determining if a convolution is applied.
160
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
161
+ downsampling occurs in the inner-two dimensions.
162
+ """
163
+
164
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
165
+ super().__init__()
166
+ self.channels = channels
167
+ self.out_channels = out_channels or channels
168
+ self.use_conv = use_conv
169
+ self.dims = dims
170
+ stride = 2 if dims != 3 else (1, 2, 2)
171
+ if use_conv:
172
+ self.op = conv_nd(
173
+ dims,
174
+ self.channels,
175
+ self.out_channels,
176
+ 3,
177
+ stride=stride,
178
+ padding=padding,
179
+ )
180
+ else:
181
+ assert self.channels == self.out_channels
182
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
183
+
184
+ def forward(self, x):
185
+ assert x.shape[1] == self.channels
186
+ return self.op(x)
187
+
188
+
189
+ class ResBlock(TimestepBlock):
190
+ """
191
+ A residual block that can optionally change the number of channels.
192
+ :param channels: the number of input channels.
193
+ :param emb_channels: the number of timestep embedding channels.
194
+ :param dropout: the rate of dropout.
195
+ :param out_channels: if specified, the number of out channels.
196
+ :param use_conv: if True and out_channels is specified, use a spatial
197
+ convolution instead of a smaller 1x1 convolution to change the
198
+ channels in the skip connection.
199
+ :param dims: determines if the signal is 1D, 2D, or 3D.
200
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
201
+ :param up: if True, use this block for upsampling.
202
+ :param down: if True, use this block for downsampling.
203
+ """
204
+
205
+ def __init__(
206
+ self,
207
+ channels,
208
+ emb_channels,
209
+ dropout,
210
+ out_channels=None,
211
+ use_conv=False,
212
+ use_scale_shift_norm=False,
213
+ dims=2,
214
+ use_checkpoint=False,
215
+ up=False,
216
+ down=False,
217
+ ):
218
+ super().__init__()
219
+ self.channels = channels
220
+ self.emb_channels = emb_channels
221
+ self.dropout = dropout
222
+ self.out_channels = out_channels or channels
223
+ self.use_conv = use_conv
224
+ self.use_checkpoint = use_checkpoint
225
+ self.use_scale_shift_norm = use_scale_shift_norm
226
+
227
+ self.in_layers = nn.Sequential(
228
+ normalization(channels),
229
+ nn.SiLU(),
230
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
231
+ )
232
+
233
+ self.updown = up or down
234
+
235
+ if up:
236
+ self.h_upd = Upsample(channels, False, dims)
237
+ self.x_upd = Upsample(channels, False, dims)
238
+ elif down:
239
+ self.h_upd = Downsample(channels, False, dims)
240
+ self.x_upd = Downsample(channels, False, dims)
241
+ else:
242
+ self.h_upd = self.x_upd = nn.Identity()
243
+
244
+ self.emb_layers = nn.Sequential(
245
+ nn.SiLU(),
246
+ linear(
247
+ emb_channels,
248
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
249
+ ),
250
+ )
251
+ self.out_layers = nn.Sequential(
252
+ normalization(self.out_channels),
253
+ nn.SiLU(),
254
+ nn.Dropout(p=dropout),
255
+ zero_module(
256
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
257
+ ),
258
+ )
259
+
260
+ if self.out_channels == channels:
261
+ self.skip_connection = nn.Identity()
262
+ elif use_conv:
263
+ self.skip_connection = conv_nd(
264
+ dims, channels, self.out_channels, 3, padding=1
265
+ )
266
+ else:
267
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
268
+
269
+ def forward(self, x, emb):
270
+ """
271
+ Apply the block to a Tensor, conditioned on a timestep embedding.
272
+ :param x: an [N x C x ...] Tensor of features.
273
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
274
+ :return: an [N x C x ...] Tensor of outputs.
275
+ """
276
+ return checkpoint(
277
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
278
+ )
279
+
280
+ def _forward(self, x, emb):
281
+ if self.updown:
282
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
283
+ h = in_rest(x)
284
+ h = self.h_upd(h)
285
+ x = self.x_upd(x)
286
+ h = in_conv(h)
287
+ else:
288
+ h = self.in_layers(x)
289
+ emb_out = self.emb_layers(emb).type(h.dtype)
290
+ while len(emb_out.shape) < len(h.shape):
291
+ emb_out = emb_out[..., None]
292
+ if self.use_scale_shift_norm:
293
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
294
+ scale, shift = th.chunk(emb_out, 2, dim=1)
295
+ h = out_norm(h) * (1 + scale) + shift
296
+ h = out_rest(h)
297
+ else:
298
+ h = h + emb_out
299
+ h = self.out_layers(h)
300
+ return self.skip_connection(x) + h
301
+
302
+
303
+ class AttentionBlock(nn.Module):
304
+ """
305
+ An attention block that allows spatial positions to attend to each other.
306
+ Originally ported from here, but adapted to the N-d case.
307
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
308
+ """
309
+
310
+ def __init__(
311
+ self,
312
+ channels,
313
+ num_heads=1,
314
+ num_head_channels=-1,
315
+ use_checkpoint=False,
316
+ use_new_attention_order=False,
317
+ ):
318
+ super().__init__()
319
+ self.channels = channels
320
+ if num_head_channels == -1:
321
+ self.num_heads = num_heads
322
+ else:
323
+ assert (
324
+ channels % num_head_channels == 0
325
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
326
+ self.num_heads = channels // num_head_channels
327
+ self.use_checkpoint = use_checkpoint
328
+ self.norm = normalization(channels)
329
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
330
+ if use_new_attention_order:
331
+ # split qkv before split heads
332
+ self.attention = QKVAttention(self.num_heads)
333
+ else:
334
+ # split heads before split qkv
335
+ self.attention = QKVAttentionLegacy(self.num_heads)
336
+
337
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
338
+
339
+ def forward(self, x):
340
+ return checkpoint(
341
+ self._forward, (x,), self.parameters(), True
342
+ ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
343
+ # return pt_checkpoint(self._forward, x) # pytorch
344
+
345
+ def _forward(self, x):
346
+ b, c, *spatial = x.shape
347
+ x = x.reshape(b, c, -1).contiguous()
348
+ qkv = self.qkv(self.norm(x)).contiguous()
349
+ h = self.attention(qkv).contiguous()
350
+ h = self.proj_out(h).contiguous()
351
+ return (x + h).reshape(b, c, *spatial).contiguous()
352
+
353
+
354
+ def count_flops_attn(model, _x, y):
355
+ """
356
+ A counter for the `thop` package to count the operations in an
357
+ attention operation.
358
+ Meant to be used like:
359
+ macs, params = thop.profile(
360
+ model,
361
+ inputs=(inputs, timestamps),
362
+ custom_ops={QKVAttention: QKVAttention.count_flops},
363
+ )
364
+ """
365
+ b, c, *spatial = y[0].shape
366
+ num_spatial = int(np.prod(spatial))
367
+ # We perform two matmuls with the same number of ops.
368
+ # The first computes the weight matrix, the second computes
369
+ # the combination of the value vectors.
370
+ matmul_ops = 2 * b * (num_spatial**2) * c
371
+ model.total_ops += th.DoubleTensor([matmul_ops])
372
+
373
+
374
+ class QKVAttentionLegacy(nn.Module):
375
+ """
376
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
377
+ """
378
+
379
+ def __init__(self, n_heads):
380
+ super().__init__()
381
+ self.n_heads = n_heads
382
+
383
+ def forward(self, qkv):
384
+ """
385
+ Apply QKV attention.
386
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
387
+ :return: an [N x (H * C) x T] tensor after attention.
388
+ """
389
+ bs, width, length = qkv.shape
390
+ assert width % (3 * self.n_heads) == 0
391
+ ch = width // (3 * self.n_heads)
392
+ q, k, v = (
393
+ qkv.reshape(bs * self.n_heads, ch * 3, length).contiguous().split(ch, dim=1)
394
+ )
395
+ scale = 1 / math.sqrt(math.sqrt(ch))
396
+ weight = th.einsum(
397
+ "bct,bcs->bts", q * scale, k * scale
398
+ ) # More stable with f16 than dividing afterwards
399
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
400
+ a = th.einsum("bts,bcs->bct", weight, v)
401
+ return a.reshape(bs, -1, length).contiguous()
402
+
403
+ @staticmethod
404
+ def count_flops(model, _x, y):
405
+ return count_flops_attn(model, _x, y)
406
+
407
+
408
+ class QKVAttention(nn.Module):
409
+ """
410
+ A module which performs QKV attention and splits in a different order.
411
+ """
412
+
413
+ def __init__(self, n_heads):
414
+ super().__init__()
415
+ self.n_heads = n_heads
416
+
417
+ def forward(self, qkv):
418
+ """
419
+ Apply QKV attention.
420
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
421
+ :return: an [N x (H * C) x T] tensor after attention.
422
+ """
423
+ bs, width, length = qkv.shape
424
+ assert width % (3 * self.n_heads) == 0
425
+ ch = width // (3 * self.n_heads)
426
+ q, k, v = qkv.chunk(3, dim=1)
427
+ scale = 1 / math.sqrt(math.sqrt(ch))
428
+ weight = th.einsum(
429
+ "bct,bcs->bts",
430
+ (q * scale).view(bs * self.n_heads, ch, length),
431
+ (k * scale).view(bs * self.n_heads, ch, length),
432
+ ) # More stable with f16 than dividing afterwards
433
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
434
+ a = th.einsum(
435
+ "bts,bcs->bct",
436
+ weight,
437
+ v.reshape(bs * self.n_heads, ch, length).contiguous(),
438
+ )
439
+ return a.reshape(bs, -1, length).contiguous()
440
+
441
+ @staticmethod
442
+ def count_flops(model, _x, y):
443
+ return count_flops_attn(model, _x, y)
444
+
445
+
446
+ class UNetModel(nn.Module):
447
+ """
448
+ The full UNet model with attention and timestep embedding.
449
+ :param in_channels: channels in the input Tensor.
450
+ :param model_channels: base channel count for the model.
451
+ :param out_channels: channels in the output Tensor.
452
+ :param num_res_blocks: number of residual blocks per downsample.
453
+ :param attention_resolutions: a collection of downsample rates at which
454
+ attention will take place. May be a set, list, or tuple.
455
+ For example, if this contains 4, then at 4x downsampling, attention
456
+ will be used.
457
+ :param dropout: the dropout probability.
458
+ :param channel_mult: channel multiplier for each level of the UNet.
459
+ :param conv_resample: if True, use learned convolutions for upsampling and
460
+ downsampling.
461
+ :param dims: determines if the signal is 1D, 2D, or 3D.
462
+ :param num_classes: if specified (as an int), then this model will be
463
+ class-conditional with `num_classes` classes.
464
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
465
+ :param num_heads: the number of attention heads in each attention layer.
466
+ :param num_heads_channels: if specified, ignore num_heads and instead use
467
+ a fixed channel width per attention head.
468
+ :param num_heads_upsample: works with num_heads to set a different number
469
+ of heads for upsampling. Deprecated.
470
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
471
+ :param resblock_updown: use residual blocks for up/downsampling.
472
+ :param use_new_attention_order: use a different attention pattern for potentially
473
+ increased efficiency.
474
+ """
475
+
476
+ def __init__(
477
+ self,
478
+ image_size,
479
+ in_channels,
480
+ model_channels,
481
+ out_channels,
482
+ num_res_blocks,
483
+ attention_resolutions,
484
+ dropout=0,
485
+ channel_mult=(1, 2, 4, 8),
486
+ conv_resample=True,
487
+ dims=2,
488
+ extra_sa_layer=True,
489
+ num_classes=None,
490
+ extra_film_condition_dim=None,
491
+ use_checkpoint=False,
492
+ use_fp16=False,
493
+ num_heads=-1,
494
+ num_head_channels=-1,
495
+ num_heads_upsample=-1,
496
+ use_scale_shift_norm=False,
497
+ resblock_updown=False,
498
+ use_new_attention_order=False,
499
+ use_spatial_transformer=True, # custom transformer support
500
+ transformer_depth=1, # custom transformer support
501
+ context_dim=None, # custom transformer support
502
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
503
+ legacy=True,
504
+ ):
505
+ super().__init__()
506
+ if num_heads_upsample == -1:
507
+ num_heads_upsample = num_heads
508
+
509
+ if num_heads == -1:
510
+ assert (
511
+ num_head_channels != -1
512
+ ), "Either num_heads or num_head_channels has to be set"
513
+
514
+ if num_head_channels == -1:
515
+ assert (
516
+ num_heads != -1
517
+ ), "Either num_heads or num_head_channels has to be set"
518
+
519
+ self.image_size = image_size
520
+ self.in_channels = in_channels
521
+ self.model_channels = model_channels
522
+ self.out_channels = out_channels
523
+ self.num_res_blocks = num_res_blocks
524
+ self.attention_resolutions = attention_resolutions
525
+ self.dropout = dropout
526
+ self.channel_mult = channel_mult
527
+ self.conv_resample = conv_resample
528
+ self.num_classes = num_classes
529
+ self.extra_film_condition_dim = extra_film_condition_dim
530
+ self.use_checkpoint = use_checkpoint
531
+ self.dtype = th.float16 if use_fp16 else th.float32
532
+ self.num_heads = num_heads
533
+ self.num_head_channels = num_head_channels
534
+ self.num_heads_upsample = num_heads_upsample
535
+ self.predict_codebook_ids = n_embed is not None
536
+ time_embed_dim = model_channels * 4
537
+ self.time_embed = nn.Sequential(
538
+ linear(model_channels, time_embed_dim),
539
+ nn.SiLU(),
540
+ linear(time_embed_dim, time_embed_dim),
541
+ )
542
+
543
+ # assert not (
544
+ # self.num_classes is not None and self.extra_film_condition_dim is not None
545
+ # ), "As for the condition of the UNet model, you can only condition on a class label or an extra embedding vector (such as from CLAP). You cannot set both num_classes and extra_film_condition_dim."
546
+
547
+ if self.num_classes is not None:
548
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
549
+
550
+ self.use_extra_film_by_concat = self.extra_film_condition_dim is not None
551
+
552
+ if self.extra_film_condition_dim is not None:
553
+ self.film_emb = nn.Linear(self.extra_film_condition_dim, time_embed_dim)
554
+ print(
555
+ "+ Using extra FiLM conditioning on the UNet. Extra condition dimension is %s."
556
+ % self.extra_film_condition_dim
557
+ )
558
+
559
+ if context_dim is not None and not use_spatial_transformer:
560
+ assert (
561
+ use_spatial_transformer
562
+ ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
563
+
564
+ if context_dim is not None and not isinstance(context_dim, list):
565
+ context_dim = [context_dim]
566
+ elif context_dim is None:
567
+ context_dim = [None] # At least use one spatial transformer
568
+
569
+ self.input_blocks = nn.ModuleList(
570
+ [
571
+ TimestepEmbedSequential(
572
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
573
+ )
574
+ ]
575
+ )
576
+ self._feature_size = model_channels
577
+ input_block_chans = [model_channels]
578
+ ch = model_channels
579
+ ds = 1
580
+ for level, mult in enumerate(channel_mult):
581
+ for _ in range(num_res_blocks):
582
+ layers = [
583
+ ResBlock(
584
+ ch,
585
+ time_embed_dim
586
+ if (not self.use_extra_film_by_concat)
587
+ else time_embed_dim * 2,
588
+ dropout,
589
+ out_channels=mult * model_channels,
590
+ dims=dims,
591
+ use_checkpoint=use_checkpoint,
592
+ use_scale_shift_norm=use_scale_shift_norm,
593
+ )
594
+ ]
595
+ ch = mult * model_channels
596
+ if ds in attention_resolutions:
597
+ if num_head_channels == -1:
598
+ dim_head = ch // num_heads
599
+ else:
600
+ num_heads = ch // num_head_channels
601
+ dim_head = num_head_channels
602
+ if legacy:
603
+ dim_head = (
604
+ ch // num_heads
605
+ if use_spatial_transformer
606
+ else num_head_channels
607
+ )
608
+ if extra_sa_layer:
609
+ layers.append(
610
+ SpatialTransformer(
611
+ ch,
612
+ num_heads,
613
+ dim_head,
614
+ depth=transformer_depth,
615
+ context_dim=None,
616
+ )
617
+ )
618
+ for context_dim_id in range(len(context_dim)):
619
+ layers.append(
620
+ AttentionBlock(
621
+ ch,
622
+ use_checkpoint=use_checkpoint,
623
+ num_heads=num_heads,
624
+ num_head_channels=dim_head,
625
+ use_new_attention_order=use_new_attention_order,
626
+ )
627
+ if not use_spatial_transformer
628
+ else SpatialTransformer(
629
+ ch,
630
+ num_heads,
631
+ dim_head,
632
+ depth=transformer_depth,
633
+ context_dim=context_dim[context_dim_id],
634
+ )
635
+ )
636
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
637
+ self._feature_size += ch
638
+ input_block_chans.append(ch)
639
+ if level != len(channel_mult) - 1:
640
+ out_ch = ch
641
+ self.input_blocks.append(
642
+ TimestepEmbedSequential(
643
+ ResBlock(
644
+ ch,
645
+ time_embed_dim
646
+ if (not self.use_extra_film_by_concat)
647
+ else time_embed_dim * 2,
648
+ dropout,
649
+ out_channels=out_ch,
650
+ dims=dims,
651
+ use_checkpoint=use_checkpoint,
652
+ use_scale_shift_norm=use_scale_shift_norm,
653
+ down=True,
654
+ )
655
+ if resblock_updown
656
+ else Downsample(
657
+ ch, conv_resample, dims=dims, out_channels=out_ch
658
+ )
659
+ )
660
+ )
661
+ ch = out_ch
662
+ input_block_chans.append(ch)
663
+ ds *= 2
664
+ self._feature_size += ch
665
+
666
+ if num_head_channels == -1:
667
+ dim_head = ch // num_heads
668
+ else:
669
+ num_heads = ch // num_head_channels
670
+ dim_head = num_head_channels
671
+ if legacy:
672
+ # num_heads = 1
673
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
674
+ middle_layers = [
675
+ ResBlock(
676
+ ch,
677
+ time_embed_dim
678
+ if (not self.use_extra_film_by_concat)
679
+ else time_embed_dim * 2,
680
+ dropout,
681
+ dims=dims,
682
+ use_checkpoint=use_checkpoint,
683
+ use_scale_shift_norm=use_scale_shift_norm,
684
+ )
685
+ ]
686
+ if extra_sa_layer:
687
+ middle_layers.append(
688
+ SpatialTransformer(
689
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=None
690
+ )
691
+ )
692
+ for context_dim_id in range(len(context_dim)):
693
+ middle_layers.append(
694
+ AttentionBlock(
695
+ ch,
696
+ use_checkpoint=use_checkpoint,
697
+ num_heads=num_heads,
698
+ num_head_channels=dim_head,
699
+ use_new_attention_order=use_new_attention_order,
700
+ )
701
+ if not use_spatial_transformer
702
+ else SpatialTransformer(
703
+ ch,
704
+ num_heads,
705
+ dim_head,
706
+ depth=transformer_depth,
707
+ context_dim=context_dim[context_dim_id],
708
+ )
709
+ )
710
+ middle_layers.append(
711
+ ResBlock(
712
+ ch,
713
+ time_embed_dim
714
+ if (not self.use_extra_film_by_concat)
715
+ else time_embed_dim * 2,
716
+ dropout,
717
+ dims=dims,
718
+ use_checkpoint=use_checkpoint,
719
+ use_scale_shift_norm=use_scale_shift_norm,
720
+ )
721
+ )
722
+ self.middle_block = TimestepEmbedSequential(*middle_layers)
723
+
724
+ self._feature_size += ch
725
+
726
+ self.output_blocks = nn.ModuleList([])
727
+ for level, mult in list(enumerate(channel_mult))[::-1]:
728
+ for i in range(num_res_blocks + 1):
729
+ ich = input_block_chans.pop()
730
+ layers = [
731
+ ResBlock(
732
+ ch + ich,
733
+ time_embed_dim
734
+ if (not self.use_extra_film_by_concat)
735
+ else time_embed_dim * 2,
736
+ dropout,
737
+ out_channels=model_channels * mult,
738
+ dims=dims,
739
+ use_checkpoint=use_checkpoint,
740
+ use_scale_shift_norm=use_scale_shift_norm,
741
+ )
742
+ ]
743
+ ch = model_channels * mult
744
+ if ds in attention_resolutions:
745
+ if num_head_channels == -1:
746
+ dim_head = ch // num_heads
747
+ else:
748
+ num_heads = ch // num_head_channels
749
+ dim_head = num_head_channels
750
+ if legacy:
751
+ # num_heads = 1
752
+ dim_head = (
753
+ ch // num_heads
754
+ if use_spatial_transformer
755
+ else num_head_channels
756
+ )
757
+ if extra_sa_layer:
758
+ layers.append(
759
+ SpatialTransformer(
760
+ ch,
761
+ num_heads,
762
+ dim_head,
763
+ depth=transformer_depth,
764
+ context_dim=None,
765
+ )
766
+ )
767
+ for context_dim_id in range(len(context_dim)):
768
+ layers.append(
769
+ AttentionBlock(
770
+ ch,
771
+ use_checkpoint=use_checkpoint,
772
+ num_heads=num_heads_upsample,
773
+ num_head_channels=dim_head,
774
+ use_new_attention_order=use_new_attention_order,
775
+ )
776
+ if not use_spatial_transformer
777
+ else SpatialTransformer(
778
+ ch,
779
+ num_heads,
780
+ dim_head,
781
+ depth=transformer_depth,
782
+ context_dim=context_dim[context_dim_id],
783
+ )
784
+ )
785
+ if level and i == num_res_blocks:
786
+ out_ch = ch
787
+ layers.append(
788
+ ResBlock(
789
+ ch,
790
+ time_embed_dim
791
+ if (not self.use_extra_film_by_concat)
792
+ else time_embed_dim * 2,
793
+ dropout,
794
+ out_channels=out_ch,
795
+ dims=dims,
796
+ use_checkpoint=use_checkpoint,
797
+ use_scale_shift_norm=use_scale_shift_norm,
798
+ up=True,
799
+ )
800
+ if resblock_updown
801
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
802
+ )
803
+ ds //= 2
804
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
805
+ self._feature_size += ch
806
+
807
+ self.out = nn.Sequential(
808
+ normalization(ch),
809
+ nn.SiLU(),
810
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
811
+ )
812
+ if self.predict_codebook_ids:
813
+ self.id_predictor = nn.Sequential(
814
+ normalization(ch),
815
+ conv_nd(dims, model_channels, n_embed, 1),
816
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
817
+ )
818
+
819
+ self.shape_reported = False
820
+
821
+ def convert_to_fp16(self):
822
+ """
823
+ Convert the torso of the model to float16.
824
+ """
825
+ self.input_blocks.apply(convert_module_to_f16)
826
+ self.middle_block.apply(convert_module_to_f16)
827
+ self.output_blocks.apply(convert_module_to_f16)
828
+
829
+ def convert_to_fp32(self):
830
+ """
831
+ Convert the torso of the model to float32.
832
+ """
833
+ self.input_blocks.apply(convert_module_to_f32)
834
+ self.middle_block.apply(convert_module_to_f32)
835
+ self.output_blocks.apply(convert_module_to_f32)
836
+
837
+ def forward(
838
+ self,
839
+ x,
840
+ timesteps=None,
841
+ y=None,
842
+ context_list=None,
843
+ context_attn_mask_list=None,
844
+ **kwargs,
845
+ ):
846
+ """
847
+ Apply the model to an input batch.
848
+ :param x: an [N x C x ...] Tensor of inputs.
849
+ :param timesteps: a 1-D batch of timesteps.
850
+ :param context_list: a list of conditioning tensors plugged in via cross-attention
851
+ :param y: an [N] Tensor of labels, if class-conditional; an [N, extra_film_condition_dim] Tensor if FiLM-embedding conditional
852
+ :return: an [N x C x ...] Tensor of outputs.
853
+ """
854
+ if not self.shape_reported:
855
+ # print("The shape of UNet input is", x.size())
856
+ self.shape_reported = True
857
+
858
+ assert (y is not None) == (
859
+ self.num_classes is not None or self.extra_film_condition_dim is not None
860
+ ), "must specify y if and only if the model is class-conditional or film embedding conditional"
861
+ hs = []
862
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
863
+ emb = self.time_embed(t_emb)
864
+
865
+ # if self.num_classes is not None:
866
+ # assert y.shape == (x.shape[0],)
867
+ # emb = emb + self.label_emb(y)
868
+
869
+ if self.use_extra_film_by_concat:
870
+ emb = th.cat([emb, self.film_emb(y)], dim=-1)
871
+
872
+ h = x.type(self.dtype)
873
+ for module in self.input_blocks:
874
+ h = module(h, emb, context_list, context_attn_mask_list)
875
+ hs.append(h)
876
+ h = self.middle_block(h, emb, context_list, context_attn_mask_list)
877
+ for module in self.output_blocks:
878
+ concate_tensor = hs.pop()
879
+ h = th.cat([h, concate_tensor], dim=1)
880
+ h = module(h, emb, context_list, context_attn_mask_list)
881
+ h = h.type(x.dtype)
882
+ if self.predict_codebook_ids:
883
+ return self.id_predictor(h)
884
+ else:
885
+ return self.out(h)
886
+
887
+
888
+ class EncoderUNetModel(nn.Module):
889
+ """
890
+ The half UNet model with attention and timestep embedding.
891
+ For usage, see UNet.
892
+ """
893
+
894
+ def __init__(
895
+ self,
896
+ image_size,
897
+ in_channels,
898
+ model_channels,
899
+ out_channels,
900
+ num_res_blocks,
901
+ attention_resolutions,
902
+ dropout=0,
903
+ channel_mult=(1, 2, 4, 8),
904
+ conv_resample=True,
905
+ dims=2,
906
+ use_checkpoint=False,
907
+ use_fp16=False,
908
+ num_heads=1,
909
+ num_head_channels=-1,
910
+ num_heads_upsample=-1,
911
+ use_scale_shift_norm=False,
912
+ resblock_updown=False,
913
+ use_new_attention_order=False,
914
+ pool="adaptive",
915
+ *args,
916
+ **kwargs,
917
+ ):
918
+ super().__init__()
919
+
920
+ if num_heads_upsample == -1:
921
+ num_heads_upsample = num_heads
922
+
923
+ self.in_channels = in_channels
924
+ self.model_channels = model_channels
925
+ self.out_channels = out_channels
926
+ self.num_res_blocks = num_res_blocks
927
+ self.attention_resolutions = attention_resolutions
928
+ self.dropout = dropout
929
+ self.channel_mult = channel_mult
930
+ self.conv_resample = conv_resample
931
+ self.use_checkpoint = use_checkpoint
932
+ self.dtype = th.float16 if use_fp16 else th.float32
933
+ self.num_heads = num_heads
934
+ self.num_head_channels = num_head_channels
935
+ self.num_heads_upsample = num_heads_upsample
936
+
937
+ time_embed_dim = model_channels * 4
938
+ self.time_embed = nn.Sequential(
939
+ linear(model_channels, time_embed_dim),
940
+ nn.SiLU(),
941
+ linear(time_embed_dim, time_embed_dim),
942
+ )
943
+
944
+ self.input_blocks = nn.ModuleList(
945
+ [
946
+ TimestepEmbedSequential(
947
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
948
+ )
949
+ ]
950
+ )
951
+ self._feature_size = model_channels
952
+ input_block_chans = [model_channels]
953
+ ch = model_channels
954
+ ds = 1
955
+ for level, mult in enumerate(channel_mult):
956
+ for _ in range(num_res_blocks):
957
+ layers = [
958
+ ResBlock(
959
+ ch,
960
+ time_embed_dim,
961
+ dropout,
962
+ out_channels=mult * model_channels,
963
+ dims=dims,
964
+ use_checkpoint=use_checkpoint,
965
+ use_scale_shift_norm=use_scale_shift_norm,
966
+ )
967
+ ]
968
+ ch = mult * model_channels
969
+ if ds in attention_resolutions:
970
+ layers.append(
971
+ AttentionBlock(
972
+ ch,
973
+ use_checkpoint=use_checkpoint,
974
+ num_heads=num_heads,
975
+ num_head_channels=num_head_channels,
976
+ use_new_attention_order=use_new_attention_order,
977
+ )
978
+ )
979
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
980
+ self._feature_size += ch
981
+ input_block_chans.append(ch)
982
+ if level != len(channel_mult) - 1:
983
+ out_ch = ch
984
+ self.input_blocks.append(
985
+ TimestepEmbedSequential(
986
+ ResBlock(
987
+ ch,
988
+ time_embed_dim,
989
+ dropout,
990
+ out_channels=out_ch,
991
+ dims=dims,
992
+ use_checkpoint=use_checkpoint,
993
+ use_scale_shift_norm=use_scale_shift_norm,
994
+ down=True,
995
+ )
996
+ if resblock_updown
997
+ else Downsample(
998
+ ch, conv_resample, dims=dims, out_channels=out_ch
999
+ )
1000
+ )
1001
+ )
1002
+ ch = out_ch
1003
+ input_block_chans.append(ch)
1004
+ ds *= 2
1005
+ self._feature_size += ch
1006
+
1007
+ self.middle_block = TimestepEmbedSequential(
1008
+ ResBlock(
1009
+ ch,
1010
+ time_embed_dim,
1011
+ dropout,
1012
+ dims=dims,
1013
+ use_checkpoint=use_checkpoint,
1014
+ use_scale_shift_norm=use_scale_shift_norm,
1015
+ ),
1016
+ AttentionBlock(
1017
+ ch,
1018
+ use_checkpoint=use_checkpoint,
1019
+ num_heads=num_heads,
1020
+ num_head_channels=num_head_channels,
1021
+ use_new_attention_order=use_new_attention_order,
1022
+ ),
1023
+ ResBlock(
1024
+ ch,
1025
+ time_embed_dim,
1026
+ dropout,
1027
+ dims=dims,
1028
+ use_checkpoint=use_checkpoint,
1029
+ use_scale_shift_norm=use_scale_shift_norm,
1030
+ ),
1031
+ )
1032
+ self._feature_size += ch
1033
+ self.pool = pool
1034
+ if pool == "adaptive":
1035
+ self.out = nn.Sequential(
1036
+ normalization(ch),
1037
+ nn.SiLU(),
1038
+ nn.AdaptiveAvgPool2d((1, 1)),
1039
+ zero_module(conv_nd(dims, ch, out_channels, 1)),
1040
+ nn.Flatten(),
1041
+ )
1042
+ elif pool == "attention":
1043
+ assert num_head_channels != -1
1044
+ self.out = nn.Sequential(
1045
+ normalization(ch),
1046
+ nn.SiLU(),
1047
+ AttentionPool2d(
1048
+ (image_size // ds), ch, num_head_channels, out_channels
1049
+ ),
1050
+ )
1051
+ elif pool == "spatial":
1052
+ self.out = nn.Sequential(
1053
+ nn.Linear(self._feature_size, 2048),
1054
+ nn.ReLU(),
1055
+ nn.Linear(2048, self.out_channels),
1056
+ )
1057
+ elif pool == "spatial_v2":
1058
+ self.out = nn.Sequential(
1059
+ nn.Linear(self._feature_size, 2048),
1060
+ normalization(2048),
1061
+ nn.SiLU(),
1062
+ nn.Linear(2048, self.out_channels),
1063
+ )
1064
+ else:
1065
+ raise NotImplementedError(f"Unexpected {pool} pooling")
1066
+
1067
+ def convert_to_fp16(self):
1068
+ """
1069
+ Convert the torso of the model to float16.
1070
+ """
1071
+ self.input_blocks.apply(convert_module_to_f16)
1072
+ self.middle_block.apply(convert_module_to_f16)
1073
+
1074
+ def convert_to_fp32(self):
1075
+ """
1076
+ Convert the torso of the model to float32.
1077
+ """
1078
+ self.input_blocks.apply(convert_module_to_f32)
1079
+ self.middle_block.apply(convert_module_to_f32)
1080
+
1081
+ def forward(self, x, timesteps):
1082
+ """
1083
+ Apply the model to an input batch.
1084
+ :param x: an [N x C x ...] Tensor of inputs.
1085
+ :param timesteps: a 1-D batch of timesteps.
1086
+ :return: an [N x K] Tensor of outputs.
1087
+ """
1088
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
1089
+
1090
+ results = []
1091
+ h = x.type(self.dtype)
1092
+ for module in self.input_blocks:
1093
+ h = module(h, emb)
1094
+ if self.pool.startswith("spatial"):
1095
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
1096
+ h = self.middle_block(h, emb)
1097
+ if self.pool.startswith("spatial"):
1098
+ results.append(h.type(x.dtype).mean(dim=(2, 3)))
1099
+ h = th.cat(results, axis=-1)
1100
+ return self.out(h)
1101
+ else:
1102
+ h = h.type(x.dtype)
1103
+ return self.out(h)
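A rough usage sketch for the UNetModel defined above. The hyperparameters are small illustrative values, not the configuration FlashSR ships with, and the import assumes the repository root is on PYTHONPATH (note that this file itself imports from an `audiosr` package, which therefore also needs to be installed):

import torch
from FlashSR.AudioSR.latent_diffusion.modules.diffusionmodules.openaimodel import UNetModel

unet = UNetModel(
    image_size=64,                 # stored but not used by the forward pass below
    in_channels=8,
    model_channels=64,
    out_channels=8,
    num_res_blocks=1,
    attention_resolutions=[4],     # apply attention once the input is downsampled 4x
    channel_mult=(1, 2, 4),
    extra_film_condition_dim=512,  # e.g. a global embedding used for FiLM conditioning
    num_head_channels=32,
)

x = torch.randn(2, 8, 64, 16)      # dummy latent spectrogram: (batch, channels, time, freq)
t = torch.randint(0, 1000, (2,))   # diffusion timesteps
film = torch.randn(2, 512)         # global conditioning vectors (passed as y, the FiLM input)
with torch.no_grad():
    out = unet(x, timesteps=t, y=film, context_list=[None], context_attn_mask_list=[None])
print(out.shape)                   # same shape as x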
FlashSR/AudioSR/latent_diffusion/modules/diffusionmodules/util.py ADDED
@@ -0,0 +1,294 @@
1
+ # adapted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+
11
+ import math
12
+ import torch
13
+ import torch.nn as nn
14
+ import numpy as np
15
+ from einops import repeat
16
+
17
+ from FlashSR.AudioSR.latent_diffusion.util import instantiate_from_config
18
+
19
+
20
+ def make_beta_schedule(
21
+ schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
22
+ ):
23
+ if schedule == "linear":
24
+ betas = (
25
+ torch.linspace(
26
+ linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
27
+ )
28
+ ** 2
29
+ )
30
+
31
+ elif schedule == "cosine":
32
+ timesteps = (
33
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
34
+ )
35
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
36
+ alphas = torch.cos(alphas).pow(2)
37
+ alphas = alphas / alphas[0]
38
+ betas = 1 - alphas[1:] / alphas[:-1]
39
+ # betas = np.clip(betas, a_min=0, a_max=0.999)
40
+
41
+ elif schedule == "sqrt_linear":
42
+ betas = torch.linspace(
43
+ linear_start, linear_end, n_timestep, dtype=torch.float64
44
+ )
45
+ elif schedule == "sqrt":
46
+ betas = (
47
+ torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
48
+ ** 0.5
49
+ )
50
+ else:
51
+ raise ValueError(f"schedule '{schedule}' unknown.")
52
+ return betas.numpy()
53
+
54
+
55
+ def make_ddim_timesteps(
56
+ ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
57
+ ):
58
+ if ddim_discr_method == "uniform":
59
+ c = num_ddpm_timesteps // num_ddim_timesteps
60
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
61
+ elif ddim_discr_method == "quad":
62
+ ddim_timesteps = (
63
+ (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
64
+ ).astype(int)
65
+ else:
66
+ raise NotImplementedError(
67
+ f'There is no ddim discretization method called "{ddim_discr_method}"'
68
+ )
69
+
70
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
71
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
72
+ steps_out = ddim_timesteps + 1
73
+ if verbose:
74
+ print(f"Selected timesteps for ddim sampler: {steps_out}")
75
+ return steps_out
76
+
77
+
78
+ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
79
+ # select alphas for computing the variance schedule
80
+ alphas = alphacums[ddim_timesteps]
81
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
82
+
83
+ # according to the formula provided in https://arxiv.org/abs/2010.02502
84
+ sigmas = eta * np.sqrt(
85
+ (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
86
+ )
87
+ if verbose:
88
+ print(
89
+ f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
90
+ )
91
+ print(
92
+ f"For the chosen value of eta, which is {eta}, "
93
+ f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
94
+ )
95
+ return sigmas, alphas, alphas_prev
96
+
97
+
98
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
99
+ """
100
+ Create a beta schedule that discretizes the given alpha_t_bar function,
101
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
102
+ :param num_diffusion_timesteps: the number of betas to produce.
103
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
104
+ produces the cumulative product of (1-beta) up to that
105
+ part of the diffusion process.
106
+ :param max_beta: the maximum beta to use; use values lower than 1 to
107
+ prevent singularities.
108
+ """
109
+ betas = []
110
+ for i in range(num_diffusion_timesteps):
111
+ t1 = i / num_diffusion_timesteps
112
+ t2 = (i + 1) / num_diffusion_timesteps
113
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
114
+ return np.array(betas)
115
+
116
+
117
+ def extract_into_tensor(a, t, x_shape):
118
+ b, *_ = t.shape
119
+ out = a.gather(-1, t).contiguous()
120
+ return out.reshape(b, *((1,) * (len(x_shape) - 1))).contiguous()
121
+
122
+
123
+ def checkpoint(func, inputs, params, flag):
124
+ """
125
+ Evaluate a function without caching intermediate activations, allowing for
126
+ reduced memory at the expense of extra compute in the backward pass.
127
+ :param func: the function to evaluate.
128
+ :param inputs: the argument sequence to pass to `func`.
129
+ :param params: a sequence of parameters `func` depends on but does not
130
+ explicitly take as arguments.
131
+ :param flag: if False, disable gradient checkpointing.
132
+ """
133
+ if flag:
134
+ args = tuple(inputs) + tuple(params)
135
+ return CheckpointFunction.apply(func, len(inputs), *args)
136
+ else:
137
+ return func(*inputs)
138
+
139
+
140
+ class CheckpointFunction(torch.autograd.Function):
141
+ @staticmethod
142
+ def forward(ctx, run_function, length, *args):
143
+ ctx.run_function = run_function
144
+ ctx.input_tensors = list(args[:length])
145
+ ctx.input_params = list(args[length:])
146
+
147
+ with torch.no_grad():
148
+ output_tensors = ctx.run_function(*ctx.input_tensors)
149
+ return output_tensors
150
+
151
+ @staticmethod
152
+ def backward(ctx, *output_grads):
153
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
154
+ with torch.enable_grad():
155
+ # Fixes a bug where the first op in run_function modifies the
156
+ # Tensor storage in place, which is not allowed for detach()'d
157
+ # Tensors.
158
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
159
+ output_tensors = ctx.run_function(*shallow_copies)
160
+ input_grads = torch.autograd.grad(
161
+ output_tensors,
162
+ ctx.input_tensors + ctx.input_params,
163
+ output_grads,
164
+ allow_unused=True,
165
+ )
166
+ del ctx.input_tensors
167
+ del ctx.input_params
168
+ del output_tensors
169
+ return (None, None) + input_grads
170
+
171
+
172
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
173
+ """
174
+ Create sinusoidal timestep embeddings.
175
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
176
+ These may be fractional.
177
+ :param dim: the dimension of the output.
178
+ :param max_period: controls the minimum frequency of the embeddings.
179
+ :return: an [N x dim] Tensor of positional embeddings.
180
+ """
181
+ if not repeat_only:
182
+ half = dim // 2
183
+ freqs = torch.exp(
184
+ -math.log(max_period)
185
+ * torch.arange(start=0, end=half, dtype=torch.float32)
186
+ / half
187
+ ).to(device=timesteps.device)
188
+ args = timesteps[:, None].float() * freqs[None]
189
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
190
+ if dim % 2:
191
+ embedding = torch.cat(
192
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
193
+ )
194
+ else:
195
+ embedding = repeat(timesteps, "b -> b d", d=dim)
196
+ return embedding
197
+
198
+
199
+ def zero_module(module):
200
+ """
201
+ Zero out the parameters of a module and return it.
202
+ """
203
+ for p in module.parameters():
204
+ p.detach().zero_()
205
+ return module
206
+
207
+
208
+ def scale_module(module, scale):
209
+ """
210
+ Scale the parameters of a module and return it.
211
+ """
212
+ for p in module.parameters():
213
+ p.detach().mul_(scale)
214
+ return module
215
+
216
+
217
+ def mean_flat(tensor):
218
+ """
219
+ Take the mean over all non-batch dimensions.
220
+ """
221
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
222
+
223
+
224
+ def normalization(channels):
225
+ """
226
+ Make a standard normalization layer.
227
+ :param channels: number of input channels.
228
+ :return: an nn.Module for normalization.
229
+ """
230
+ return GroupNorm32(32, channels)
231
+
232
+
233
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
234
+ class SiLU(nn.Module):
235
+ def forward(self, x):
236
+ return x * torch.sigmoid(x)
237
+
238
+
239
+ class GroupNorm32(nn.GroupNorm):
240
+ def forward(self, x):
241
+ return super().forward(x.float()).type(x.dtype)
242
+
243
+
244
+ def conv_nd(dims, *args, **kwargs):
245
+ """
246
+ Create a 1D, 2D, or 3D convolution module.
247
+ """
248
+ if dims == 1:
249
+ return nn.Conv1d(*args, **kwargs)
250
+ elif dims == 2:
251
+ return nn.Conv2d(*args, **kwargs)
252
+ elif dims == 3:
253
+ return nn.Conv3d(*args, **kwargs)
254
+ raise ValueError(f"unsupported dimensions: {dims}")
255
+
256
+
257
+ def linear(*args, **kwargs):
258
+ """
259
+ Create a linear module.
260
+ """
261
+ return nn.Linear(*args, **kwargs)
262
+
263
+
264
+ def avg_pool_nd(dims, *args, **kwargs):
265
+ """
266
+ Create a 1D, 2D, or 3D average pooling module.
267
+ """
268
+ if dims == 1:
269
+ return nn.AvgPool1d(*args, **kwargs)
270
+ elif dims == 2:
271
+ return nn.AvgPool2d(*args, **kwargs)
272
+ elif dims == 3:
273
+ return nn.AvgPool3d(*args, **kwargs)
274
+ raise ValueError(f"unsupported dimensions: {dims}")
275
+
276
+
277
+ class HybridConditioner(nn.Module):
278
+ def __init__(self, c_concat_config, c_crossattn_config):
279
+ super().__init__()
280
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
281
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
282
+
283
+ def forward(self, c_concat, c_crossattn):
284
+ c_concat = self.concat_conditioner(c_concat)
285
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
286
+ return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]}
287
+
288
+
289
+ def noise_like(shape, device, repeat=False):
290
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
291
+ shape[0], *((1,) * (len(shape) - 1))
292
+ )
293
+ noise = lambda: torch.randn(shape, device=device)
294
+ return repeat_noise() if repeat else noise()
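A small sketch of how the schedule and embedding helpers above fit together, assuming the repository root is on PYTHONPATH (the schedule name, step count, and embedding size are illustrative):

import numpy as np
import torch
from FlashSR.AudioSR.latent_diffusion.modules.diffusionmodules.util import (
    make_beta_schedule,
    timestep_embedding,
)

betas = make_beta_schedule("linear", n_timestep=1000)  # numpy array with 1000 betas
alphas_cumprod = np.cumprod(1.0 - betas)               # cumulative alpha products used by the samplers

t = torch.tensor([0, 250, 999])                        # three example timesteps
emb = timestep_embedding(t, dim=128)                   # sinusoidal embeddings, shape [3, 128]
print(betas.shape, alphas_cumprod[-1], emb.shape)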
FlashSR/AudioSR/latent_diffusion/modules/distributions/__init__.py ADDED
File without changes
FlashSR/AudioSR/latent_diffusion/modules/distributions/distributions.py ADDED
@@ -0,0 +1,102 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ class AbstractDistribution:
6
+ def sample(self):
7
+ raise NotImplementedError()
8
+
9
+ def mode(self):
10
+ raise NotImplementedError()
11
+
12
+
13
+ class DiracDistribution(AbstractDistribution):
14
+ def __init__(self, value):
15
+ self.value = value
16
+
17
+ def sample(self):
18
+ return self.value
19
+
20
+ def mode(self):
21
+ return self.value
22
+
23
+
24
+ class DiagonalGaussianDistribution(object):
25
+ def __init__(self, parameters, deterministic=False):
26
+ self.parameters = parameters
27
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29
+ self.deterministic = deterministic
30
+ self.std = torch.exp(0.5 * self.logvar)
31
+ self.var = torch.exp(self.logvar)
32
+ if self.deterministic:
33
+ self.var = self.std = torch.zeros_like(self.mean).to(
34
+ device=self.parameters.device
35
+ )
36
+
37
+ def sample(self):
38
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(
39
+ device=self.parameters.device
40
+ )
41
+ return x
42
+
43
+ def kl(self, other=None):
44
+ if self.deterministic:
45
+ return torch.Tensor([0.0])
46
+ else:
47
+ if other is None:
48
+ return 0.5 * torch.mean(
49
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
50
+ dim=[1, 2, 3],
51
+ )
52
+ else:
53
+ return 0.5 * torch.mean(
54
+ torch.pow(self.mean - other.mean, 2) / other.var
55
+ + self.var / other.var
56
+ - 1.0
57
+ - self.logvar
58
+ + other.logvar,
59
+ dim=[1, 2, 3],
60
+ )
61
+
62
+ def nll(self, sample, dims=[1, 2, 3]):
63
+ if self.deterministic:
64
+ return torch.Tensor([0.0])
65
+ logtwopi = np.log(2.0 * np.pi)
66
+ return 0.5 * torch.sum(
67
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
68
+ dim=dims,
69
+ )
70
+
71
+ def mode(self):
72
+ return self.mean
73
+
74
+
75
+ def normal_kl(mean1, logvar1, mean2, logvar2):
76
+ """
77
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
78
+ Compute the KL divergence between two gaussians.
79
+ Shapes are automatically broadcasted, so batches can be compared to
80
+ scalars, among other use cases.
81
+ """
82
+ tensor = None
83
+ for obj in (mean1, logvar1, mean2, logvar2):
84
+ if isinstance(obj, torch.Tensor):
85
+ tensor = obj
86
+ break
87
+ assert tensor is not None, "at least one argument must be a Tensor"
88
+
89
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
90
+ # Tensors, but it does not work for torch.exp().
91
+ logvar1, logvar2 = [
92
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
93
+ for x in (logvar1, logvar2)
94
+ ]
95
+
96
+ return 0.5 * (
97
+ -1.0
98
+ + logvar2
99
+ - logvar1
100
+ + torch.exp(logvar1 - logvar2)
101
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
102
+ )
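A minimal sketch of how DiagonalGaussianDistribution is typically driven (illustrative; the parameter tensor and shapes are hypothetical): the input is split into mean and log-variance along dim=1, so it needs 2*C channels.

import torch

params = torch.randn(2, 16, 32, 32)           # hypothetical encoder output with 2*C = 16 channels
posterior = DiagonalGaussianDistribution(params)
z = posterior.sample()                        # [2, 8, 32, 32]
kl = posterior.kl()                           # KL to a standard normal, one value per sample: [2]
nll = posterior.nll(z)                        # negative log-likelihood of a sample: [2]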
FlashSR/AudioSR/latent_diffusion/modules/ema.py ADDED
@@ -0,0 +1,82 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class LitEma(nn.Module):
6
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
7
+ super().__init__()
8
+ if decay < 0.0 or decay > 1.0:
9
+ raise ValueError("Decay must be between 0 and 1")
10
+
11
+ self.m_name2s_name = {}
12
+ self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
13
+ self.register_buffer(
14
+ "num_updates",
15
+ torch.tensor(0, dtype=torch.int)
16
+ if use_num_upates
17
+ else torch.tensor(-1, dtype=torch.int),
18
+ )
19
+
20
+ for name, p in model.named_parameters():
21
+ if p.requires_grad:
22
+ # remove as '.'-character is not allowed in buffers
23
+ s_name = name.replace(".", "")
24
+ self.m_name2s_name.update({name: s_name})
25
+ self.register_buffer(s_name, p.clone().detach().data)
26
+
27
+ self.collected_params = []
28
+
29
+ def forward(self, model):
30
+ decay = self.decay
31
+
32
+ if self.num_updates >= 0:
33
+ self.num_updates += 1
34
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
35
+
36
+ one_minus_decay = 1.0 - decay
37
+
38
+ with torch.no_grad():
39
+ m_param = dict(model.named_parameters())
40
+ shadow_params = dict(self.named_buffers())
41
+
42
+ for key in m_param:
43
+ if m_param[key].requires_grad:
44
+ sname = self.m_name2s_name[key]
45
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
46
+ shadow_params[sname].sub_(
47
+ one_minus_decay * (shadow_params[sname] - m_param[key])
48
+ )
49
+ else:
50
+ assert not key in self.m_name2s_name
51
+
52
+ def copy_to(self, model):
53
+ m_param = dict(model.named_parameters())
54
+ shadow_params = dict(self.named_buffers())
55
+ for key in m_param:
56
+ if m_param[key].requires_grad:
57
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
58
+ else:
59
+ assert not key in self.m_name2s_name
60
+
61
+ def store(self, parameters):
62
+ """
63
+ Save the current parameters for restoring later.
64
+ Args:
65
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
66
+ temporarily stored.
67
+ """
68
+ self.collected_params = [param.clone() for param in parameters]
69
+
70
+ def restore(self, parameters):
71
+ """
72
+ Restore the parameters stored with the `store` method.
73
+ Useful to validate the model with EMA parameters without affecting the
74
+ original optimization process. Store the parameters before the
75
+ `copy_to` method. After validation (or model saving), use this to
76
+ restore the former parameters.
77
+ Args:
78
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
79
+ updated with the stored parameters.
80
+ """
81
+ for c_param, param in zip(self.collected_params, parameters):
82
+ param.data.copy_(c_param.data)
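A minimal sketch of the intended LitEma workflow (illustrative; the toy model and optimizer are assumptions): update the shadow weights after each optimizer step, then temporarily swap them in for evaluation.

import torch
from torch import nn

model = nn.Linear(4, 4)                                   # hypothetical model
ema = LitEma(model, decay=0.999)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

loss = nn.functional.mse_loss(model(torch.randn(8, 4)), torch.randn(8, 4))
loss.backward()
opt.step()
ema(model)                                                # update the EMA shadow parameters

ema.store(model.parameters())                             # keep the raw training weights
ema.copy_to(model)                                        # evaluate with EMA weights ...
ema.restore(model.parameters())                           # ... then restore the training weights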
FlashSR/AudioSR/latent_diffusion/modules/encoders/__init__.py ADDED
File without changes
FlashSR/AudioSR/latent_diffusion/modules/encoders/modules.py ADDED
@@ -0,0 +1,682 @@
1
+ import torch
2
+ import logging
3
+ import torch.nn as nn
4
+ #from audiosr.clap.open_clip import create_model
5
+ #from audiosr.clap.training.data import get_audio_features
6
+ import torchaudio
7
+ #from transformers import RobertaTokenizer, AutoTokenizer, T5EncoderModel
8
+ import torch.nn.functional as F
9
+ from audiosr.latent_diffusion.modules.audiomae.AudioMAE import Vanilla_AudioMAE
10
+ from audiosr.latent_diffusion.modules.phoneme_encoder.encoder import TextEncoder
11
+ from audiosr.latent_diffusion.util import instantiate_from_config
12
+
13
+ from transformers import AutoTokenizer, T5Config
14
+
15
+
16
+ import numpy as np
17
+
18
+ """
19
+ The model forward function can return three types of data:
20
+ 1. tensor: used directly as conditioning signal
21
+ 2. dict: with one main key used as the condition; other keys can carry auxiliary data such as loss terms or intermediate results.
22
+ 3. list: of length 2, where the first element is the tensor and the second element is the attention mask.
23
+
24
+ The output shape for the cross attention condition should be:
25
+ x,x_mask = [bs, seq_len, emb_dim], [bs, seq_len]
26
+
27
+ All returned data that will be used as diffusion input must be of float type.
28
+ """
29
+
30
+
31
+ def disabled_train(self, mode=True):
32
+ """Overwrite model.train with this function to make sure train/eval mode
33
+ does not change anymore."""
34
+ return self
35
+
36
+
37
+ class PhonemeEncoder(nn.Module):
38
+ def __init__(self, vocabs_size=41, pad_length=250, pad_token_id=None):
39
+ super().__init__()
40
+ """
41
+ encoder = PhonemeEncoder(40)
42
+ data = torch.randint(0, 39, (2, 250))
43
+ output = encoder(data)
44
+ import ipdb;ipdb.set_trace()
45
+ """
46
+ assert pad_token_id is not None
47
+
48
+ self.device = None
49
+ self.PAD_LENGTH = int(pad_length)
50
+ self.pad_token_id = pad_token_id
51
+ self.pad_token_sequence = torch.tensor([self.pad_token_id] * self.PAD_LENGTH)
52
+
53
+ self.text_encoder = TextEncoder(
54
+ n_vocab=vocabs_size,
55
+ out_channels=192,
56
+ hidden_channels=192,
57
+ filter_channels=768,
58
+ n_heads=2,
59
+ n_layers=6,
60
+ kernel_size=3,
61
+ p_dropout=0.1,
62
+ )
63
+
64
+ self.learnable_positional_embedding = torch.nn.Parameter(
65
+ torch.zeros((1, 192, self.PAD_LENGTH))
66
+ ) # [1, emb_dim, pad_length]
67
+ self.learnable_positional_embedding.requires_grad = True
68
+
69
+ # Required
70
+ def get_unconditional_condition(self, batchsize):
71
+ unconditional_tokens = self.pad_token_sequence.expand(
72
+ batchsize, self.PAD_LENGTH
73
+ )
74
+ return self(unconditional_tokens) # Need to return float type
75
+
76
+ # def get_unconditional_condition(self, batchsize):
77
+
78
+ # hidden_state = torch.zeros((batchsize, self.PAD_LENGTH, 192)).to(self.device)
79
+ # attention_mask = torch.ones((batchsize, self.PAD_LENGTH)).to(self.device)
80
+ # return [hidden_state, attention_mask] # Need to return float type
81
+
82
+ def _get_src_mask(self, phoneme):
83
+ src_mask = phoneme != self.pad_token_id
84
+ return src_mask
85
+
86
+ def _get_src_length(self, phoneme):
87
+ src_mask = self._get_src_mask(phoneme)
88
+ length = torch.sum(src_mask, dim=-1)
89
+ return length
90
+
91
+ # def make_empty_condition_unconditional(self, src_length, text_emb, attention_mask):
92
+ # # src_length: [bs]
93
+ # # text_emb: [bs, 192, pad_length]
94
+ # # attention_mask: [bs, pad_length]
95
+ # mask = src_length[..., None, None] > 1
96
+ # text_emb = text_emb * mask
97
+
98
+ # attention_mask[src_length < 1] = attention_mask[src_length < 1] * 0.0 + 1.0
99
+ # return text_emb, attention_mask
100
+
101
+ def forward(self, phoneme_idx):
102
+ if self.device is None:
103
+ self.device = self.learnable_positional_embedding.device
104
+ self.pad_token_sequence = self.pad_token_sequence.to(self.device)
105
+
106
+ phoneme_idx = phoneme_idx.to(self.device)
107
+
108
+ src_length = self._get_src_length(phoneme_idx)
109
+ text_emb, m, logs, text_emb_mask = self.text_encoder(phoneme_idx, src_length)
110
+ text_emb = text_emb + self.learnable_positional_embedding
111
+
112
+ # text_emb, text_emb_mask = self.make_empty_condition_unconditional(src_length, text_emb, text_emb_mask)
113
+
114
+ return [
115
+ text_emb.permute(0, 2, 1),
116
+ text_emb_mask.squeeze(1),
117
+ ] # [2, 250, 192], [2, 250]
118
+
119
+
120
+ class VAEFeatureExtract(nn.Module):
121
+ def __init__(self, first_stage_config):
122
+ super().__init__()
123
+ # self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
124
+ self.vae = None
125
+ self.instantiate_first_stage(first_stage_config)
126
+ self.device = None
127
+ self.unconditional_cond = None
128
+
129
+ def get_unconditional_condition(self, batchsize):
130
+ return self.unconditional_cond.unsqueeze(0).expand(batchsize, -1, -1, -1)
131
+
132
+ def instantiate_first_stage(self, config):
133
+ self.vae = instantiate_from_config(config)
134
+ self.vae.eval()
135
+ for p in self.vae.parameters():
136
+ p.requires_grad = False
137
+ self.vae.train = disabled_train
138
+
139
+ def forward(self, batch):
140
+ assert self.vae.training == False
141
+ if self.device is None:
142
+ self.device = next(self.vae.parameters()).device
143
+
144
+ with torch.no_grad():
145
+ vae_embed = self.vae.encode(batch.unsqueeze(1)).sample()
146
+
147
+ self.unconditional_cond = -11.4981 + vae_embed[0].clone() * 0.0
148
+
149
+ return vae_embed.detach()
150
+
151
+
152
+ class FlanT5HiddenState(nn.Module):
153
+ """
154
+ llama = FlanT5HiddenState()
155
+ data = ["","this is not an empty sentence"]
156
+ encoder_hidden_states = llama(data)
157
+ import ipdb;ipdb.set_trace()
158
+ """
159
+
160
+ def __init__(
161
+ self, text_encoder_name="google/flan-t5-large", freeze_text_encoder=True
162
+ ):
163
+ super().__init__()
164
+ self.freeze_text_encoder = freeze_text_encoder
165
+ self.tokenizer = AutoTokenizer.from_pretrained(text_encoder_name)
166
+ #self.model = T5EncoderModel(T5Config.from_pretrained(text_encoder_name))
167
+ if freeze_text_encoder:
168
+ self.model.eval()
169
+ for p in self.model.parameters():
170
+ p.requires_grad = False
171
+ else:
172
+ print("=> The text encoder is learnable")
173
+
174
+ self.empty_hidden_state_cfg = None
175
+ self.device = None
176
+
177
+ # Required
178
+ def get_unconditional_condition(self, batchsize):
179
+ param = next(self.model.parameters())
180
+ if self.freeze_text_encoder:
181
+ assert param.requires_grad == False
182
+
183
+ # device = param.device
184
+ if self.empty_hidden_state_cfg is None:
185
+ self.empty_hidden_state_cfg, _ = self([""])
186
+
187
+ hidden_state = torch.cat([self.empty_hidden_state_cfg] * batchsize).float()
188
+ attention_mask = (
189
+ torch.ones((batchsize, hidden_state.size(1)))
190
+ .to(hidden_state.device)
191
+ .float()
192
+ )
193
+ return [hidden_state, attention_mask] # Need to return float type
194
+
195
+ def forward(self, batch):
196
+ param = next(self.model.parameters())
197
+ if self.freeze_text_encoder:
198
+ assert param.requires_grad == False
199
+
200
+ if self.device is None:
201
+ self.device = param.device
202
+
203
+ # print("Manually change text")
204
+ # for i in range(len(batch)):
205
+ # batch[i] = "dog barking"
206
+ try:
207
+ return self.encode_text(batch)
208
+ except Exception as e:
209
+ print(e, batch)
210
+ logging.exception("An error occurred: %s", str(e))
211
+
212
+ def encode_text(self, prompt):
213
+ device = self.model.device
214
+ batch = self.tokenizer(
215
+ prompt,
216
+ max_length=128, # self.tokenizer.model_max_length
217
+ padding=True,
218
+ truncation=True,
219
+ return_tensors="pt",
220
+ )
221
+ input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(
222
+ device
223
+ )
224
+ # Get text encoding
225
+ if self.freeze_text_encoder:
226
+ with torch.no_grad():
227
+ encoder_hidden_states = self.model(
228
+ input_ids=input_ids, attention_mask=attention_mask
229
+ )[0]
230
+ else:
231
+ encoder_hidden_states = self.model(
232
+ input_ids=input_ids, attention_mask=attention_mask
233
+ )[0]
234
+ return [
235
+ encoder_hidden_states.detach(),
236
+ attention_mask.float(),
237
+ ]
238
+
239
+
240
+ class AudioMAEConditionCTPoolRandTFSeparated(nn.Module):
241
+ """
242
+ audiomae = AudioMAEConditionCTPoolRandTFSeparated()
243
+ data = torch.randn((4, 1024, 128))
244
+ output = audiomae(data)
245
+ import ipdb;ipdb.set_trace()
246
+ exit(0)
247
+ """
248
+
249
+ def __init__(
250
+ self,
251
+ time_pooling_factors=[1, 2, 4, 8],
252
+ freq_pooling_factors=[1, 2, 4, 8],
253
+ eval_time_pooling=None,
254
+ eval_freq_pooling=None,
255
+ mask_ratio=0.0,
256
+ regularization=False,
257
+ no_audiomae_mask=True,
258
+ no_audiomae_average=False,
259
+ ):
260
+ super().__init__()
261
+ self.device = None
262
+ self.time_pooling_factors = time_pooling_factors
263
+ self.freq_pooling_factors = freq_pooling_factors
264
+ self.no_audiomae_mask = no_audiomae_mask
265
+ self.no_audiomae_average = no_audiomae_average
266
+
267
+ self.eval_freq_pooling = eval_freq_pooling
268
+ self.eval_time_pooling = eval_time_pooling
269
+ self.mask_ratio = mask_ratio
270
+ self.use_reg = regularization
271
+
272
+ self.audiomae = Vanilla_AudioMAE()
273
+ self.audiomae.eval()
274
+ for p in self.audiomae.parameters():
275
+ p.requires_grad = False
276
+
277
+ # Required
278
+ def get_unconditional_condition(self, batchsize):
279
+ param = next(self.audiomae.parameters())
280
+ assert param.requires_grad == False
281
+ device = param.device
282
+ # time_pool, freq_pool = max(self.time_pooling_factors), max(self.freq_pooling_factors)
283
+ time_pool, freq_pool = min(self.eval_time_pooling, 64), min(
284
+ self.eval_freq_pooling, 8
285
+ )
286
+ # time_pool = self.time_pooling_factors[np.random.choice(list(range(len(self.time_pooling_factors))))]
287
+ # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))]
288
+ token_num = int(512 / (time_pool * freq_pool))
289
+ return [
290
+ torch.zeros((batchsize, token_num, 768)).to(device).float(),
291
+ torch.ones((batchsize, token_num)).to(device).float(),
292
+ ]
293
+
294
+ def pool(self, representation, time_pool=None, freq_pool=None):
295
+ assert representation.size(-1) == 768
296
+ representation = representation[:, 1:, :].transpose(1, 2)
297
+ bs, embedding_dim, token_num = representation.size()
298
+ representation = representation.reshape(bs, embedding_dim, 64, 8)
299
+
300
+ if self.training:
301
+ if time_pool is None and freq_pool is None:
302
+ time_pool = min(
303
+ 64,
304
+ self.time_pooling_factors[
305
+ np.random.choice(list(range(len(self.time_pooling_factors))))
306
+ ],
307
+ )
308
+ freq_pool = min(
309
+ 8,
310
+ self.freq_pooling_factors[
311
+ np.random.choice(list(range(len(self.freq_pooling_factors))))
312
+ ],
313
+ )
314
+ # freq_pool = min(8, time_pool) # TODO here I make some modification.
315
+ else:
316
+ time_pool, freq_pool = min(self.eval_time_pooling, 64), min(
317
+ self.eval_freq_pooling, 8
318
+ )
319
+
320
+ self.avgpooling = nn.AvgPool2d(
321
+ kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool)
322
+ )
323
+ self.maxpooling = nn.MaxPool2d(
324
+ kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool)
325
+ )
326
+
327
+ pooled = (
328
+ self.avgpooling(representation) + self.maxpooling(representation)
329
+ ) / 2 # [bs, embedding_dim, time_token_num, freq_token_num]
330
+ pooled = pooled.flatten(2).transpose(1, 2)
331
+ return pooled # [bs, token_num, embedding_dim]
332
+
333
+ def regularization(self, x):
334
+ assert x.size(-1) == 768
335
+ x = F.normalize(x, p=2, dim=-1)
336
+ return x
337
+
338
+ # Required
339
+ def forward(self, batch, time_pool=None, freq_pool=None):
340
+ assert batch.size(-2) == 1024 and batch.size(-1) == 128
341
+
342
+ if self.device is None:
343
+ self.device = batch.device
344
+
345
+ batch = batch.unsqueeze(1)
346
+ with torch.no_grad():
347
+ representation = self.audiomae(
348
+ batch,
349
+ mask_ratio=self.mask_ratio,
350
+ no_mask=self.no_audiomae_mask,
351
+ no_average=self.no_audiomae_average,
352
+ )
353
+ representation = self.pool(representation, time_pool, freq_pool)
354
+ if self.use_reg:
355
+ representation = self.regularization(representation)
356
+ return [
357
+ representation,
358
+ torch.ones((representation.size(0), representation.size(1)))
359
+ .to(representation.device)
360
+ .float(),
361
+ ]
362
+
363
+
364
+ class AudioMAEConditionCTPoolRand(nn.Module):
365
+ """
366
+ audiomae = AudioMAEConditionCTPoolRand()
367
+ data = torch.randn((4, 1024, 128))
368
+ output = audiomae(data)
369
+ import ipdb;ipdb.set_trace()
370
+ exit(0)
371
+ """
372
+
373
+ def __init__(
374
+ self,
375
+ time_pooling_factors=[1, 2, 4, 8],
376
+ freq_pooling_factors=[1, 2, 4, 8],
377
+ eval_time_pooling=None,
378
+ eval_freq_pooling=None,
379
+ mask_ratio=0.0,
380
+ regularization=False,
381
+ no_audiomae_mask=True,
382
+ no_audiomae_average=False,
383
+ ):
384
+ super().__init__()
385
+ self.device = None
386
+ self.time_pooling_factors = time_pooling_factors
387
+ self.freq_pooling_factors = freq_pooling_factors
388
+ self.no_audiomae_mask = no_audiomae_mask
389
+ self.no_audiomae_average = no_audiomae_average
390
+
391
+ self.eval_freq_pooling = eval_freq_pooling
392
+ self.eval_time_pooling = eval_time_pooling
393
+ self.mask_ratio = mask_ratio
394
+ self.use_reg = regularization
395
+
396
+ self.audiomae = Vanilla_AudioMAE()
397
+ self.audiomae.eval()
398
+ for p in self.audiomae.parameters():
399
+ p.requires_grad = False
400
+
401
+ # Required
402
+ def get_unconditional_condition(self, batchsize):
403
+ param = next(self.audiomae.parameters())
404
+ assert param.requires_grad == False
405
+ device = param.device
406
+ # time_pool, freq_pool = max(self.time_pooling_factors), max(self.freq_pooling_factors)
407
+ time_pool, freq_pool = min(self.eval_time_pooling, 64), min(
408
+ self.eval_freq_pooling, 8
409
+ )
410
+ # time_pool = self.time_pooling_factors[np.random.choice(list(range(len(self.time_pooling_factors))))]
411
+ # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))]
412
+ token_num = int(512 / (time_pool * freq_pool))
413
+ return [
414
+ torch.zeros((batchsize, token_num, 768)).to(device).float(),
415
+ torch.ones((batchsize, token_num)).to(device).float(),
416
+ ]
417
+
418
+ def pool(self, representation, time_pool=None, freq_pool=None):
419
+ assert representation.size(-1) == 768
420
+ representation = representation[:, 1:, :].transpose(1, 2)
421
+ bs, embedding_dim, token_num = representation.size()
422
+ representation = representation.reshape(bs, embedding_dim, 64, 8)
423
+
424
+ if self.training:
425
+ if time_pool is None and freq_pool is None:
426
+ time_pool = min(
427
+ 64,
428
+ self.time_pooling_factors[
429
+ np.random.choice(list(range(len(self.time_pooling_factors))))
430
+ ],
431
+ )
432
+ # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))]
433
+ freq_pool = min(8, time_pool) # TODO here I make some modification.
434
+ else:
435
+ time_pool, freq_pool = min(self.eval_time_pooling, 64), min(
436
+ self.eval_freq_pooling, 8
437
+ )
438
+
439
+ self.avgpooling = nn.AvgPool2d(
440
+ kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool)
441
+ )
442
+ self.maxpooling = nn.MaxPool2d(
443
+ kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool)
444
+ )
445
+
446
+ pooled = (
447
+ self.avgpooling(representation) + self.maxpooling(representation)
448
+ ) / 2 # [bs, embedding_dim, time_token_num, freq_token_num]
449
+ pooled = pooled.flatten(2).transpose(1, 2)
450
+ return pooled # [bs, token_num, embedding_dim]
451
+
452
+ def regularization(self, x):
453
+ assert x.size(-1) == 768
454
+ x = F.normalize(x, p=2, dim=-1)
455
+ return x
456
+
457
+ # Required
458
+ def forward(self, batch, time_pool=None, freq_pool=None):
459
+ assert batch.size(-2) == 1024 and batch.size(-1) == 128
460
+
461
+ if self.device is None:
462
+ self.device = next(self.audiomae.parameters()).device
463
+
464
+ batch = batch.unsqueeze(1).to(self.device)
465
+ with torch.no_grad():
466
+ representation = self.audiomae(
467
+ batch,
468
+ mask_ratio=self.mask_ratio,
469
+ no_mask=self.no_audiomae_mask,
470
+ no_average=self.no_audiomae_average,
471
+ )
472
+ representation = self.pool(representation, time_pool, freq_pool)
473
+ if self.use_reg:
474
+ representation = self.regularization(representation)
475
+ return [
476
+ representation,
477
+ torch.ones((representation.size(0), representation.size(1)))
478
+ .to(representation.device)
479
+ .float(),
480
+ ]
481
+
482
+
483
+ class CLAPAudioEmbeddingClassifierFreev2(nn.Module):
484
+ def __init__(
485
+ self,
486
+ pretrained_path="",
487
+ enable_cuda=False,
488
+ sampling_rate=16000,
489
+ embed_mode="audio",
490
+ amodel="HTSAT-base",
491
+ unconditional_prob=0.1,
492
+ random_mute=False,
493
+ max_random_mute_portion=0.5,
494
+ training_mode=True,
495
+ ):
496
+ super().__init__()
497
+ self.device = "cpu" # The model itself is on cpu
498
+ self.cuda = enable_cuda
499
+ self.precision = "fp32"
500
+ self.amodel = amodel # or 'PANN-14'
501
+ self.tmodel = "roberta" # the best text encoder in our training
502
+ self.enable_fusion = False # False if you do not want to use the fusion model
503
+ self.fusion_type = "aff_2d"
504
+ self.pretrained = pretrained_path
505
+ self.embed_mode = embed_mode
506
+ self.embed_mode_orig = embed_mode
507
+ self.sampling_rate = sampling_rate
508
+ self.unconditional_prob = unconditional_prob
509
+ self.random_mute = random_mute
510
+ #self.tokenize = RobertaTokenizer.from_pretrained("roberta-base")
511
+ self.max_random_mute_portion = max_random_mute_portion
512
+ self.training_mode = training_mode
513
+ self.model, self.model_cfg = create_model(
514
+ self.amodel,
515
+ self.tmodel,
516
+ self.pretrained,
517
+ precision=self.precision,
518
+ device=self.device,
519
+ enable_fusion=self.enable_fusion,
520
+ fusion_type=self.fusion_type,
521
+ )
522
+ self.model = self.model.to(self.device)
523
+ audio_cfg = self.model_cfg["audio_cfg"]
524
+ self.mel_transform = torchaudio.transforms.MelSpectrogram(
525
+ sample_rate=audio_cfg["sample_rate"],
526
+ n_fft=audio_cfg["window_size"],
527
+ win_length=audio_cfg["window_size"],
528
+ hop_length=audio_cfg["hop_size"],
529
+ center=True,
530
+ pad_mode="reflect",
531
+ power=2.0,
532
+ norm=None,
533
+ onesided=True,
534
+ n_mels=64,
535
+ f_min=audio_cfg["fmin"],
536
+ f_max=audio_cfg["fmax"],
537
+ )
538
+ for p in self.model.parameters():
539
+ p.requires_grad = False
540
+ self.unconditional_token = None
541
+ self.model.eval()
542
+
543
+ def get_unconditional_condition(self, batchsize):
544
+ self.unconditional_token = self.model.get_text_embedding(
545
+ self.tokenizer(["", ""])
546
+ )[0:1]
547
+ return torch.cat([self.unconditional_token.unsqueeze(0)] * batchsize, dim=0)
548
+
549
+ def batch_to_list(self, batch):
550
+ ret = []
551
+ for i in range(batch.size(0)):
552
+ ret.append(batch[i])
553
+ return ret
554
+
555
+ def make_decision(self, probability):
556
+ if float(torch.rand(1)) < probability:
557
+ return True
558
+ else:
559
+ return False
560
+
561
+ def random_uniform(self, start, end):
562
+ val = torch.rand(1).item()
563
+ return start + (end - start) * val
564
+
565
+ def _random_mute(self, waveform):
566
+ # waveform: [bs, t-steps]
567
+ t_steps = waveform.size(-1)
568
+ for i in range(waveform.size(0)):
569
+ mute_size = int(
570
+ self.random_uniform(0, end=int(t_steps * self.max_random_mute_portion))
571
+ )
572
+ mute_start = int(self.random_uniform(0, t_steps - mute_size))
573
+ waveform[i, mute_start : mute_start + mute_size] = 0
574
+ return waveform
575
+
576
+ def cos_similarity(self, waveform, text):
577
+ # waveform: [bs, t_steps]
578
+ original_embed_mode = self.embed_mode
579
+ with torch.no_grad():
580
+ self.embed_mode = "audio"
581
+ # MPS currently does not support ComplexFloat dtype and operator 'aten::_fft_r2c'
582
+ if self.cuda:
583
+ audio_emb = self(waveform.cuda())
584
+ else:
585
+ audio_emb = self(waveform.to("cpu"))
586
+ self.embed_mode = "text"
587
+ text_emb = self(text)
588
+ similarity = F.cosine_similarity(audio_emb, text_emb, dim=2)
589
+ self.embed_mode = original_embed_mode
590
+ return similarity.squeeze()
591
+
592
+ def build_unconditional_emb(self):
593
+ self.unconditional_token = self.model.get_text_embedding(
594
+ self.tokenizer(["", ""])
595
+ )[0:1]
596
+
597
+ def forward(self, batch):
598
+ # If you want this conditioner to be unconditional, set self.unconditional_prob = 1.0
599
+ # If you want this conditioner to be fully conditional, set self.unconditional_prob = 0.0
600
+ if self.model.training == True and not self.training_mode:
601
+ print(
602
+ "The pretrained CLAP model should always be in eval mode. Reloading model just in case you change the parameters."
603
+ )
604
+ self.model, self.model_cfg = create_model(
605
+ self.amodel,
606
+ self.tmodel,
607
+ self.pretrained,
608
+ precision=self.precision,
609
+ device="cuda" if self.cuda else "cpu",
610
+ enable_fusion=self.enable_fusion,
611
+ fusion_type=self.fusion_type,
612
+ )
613
+ for p in self.model.parameters():
614
+ p.requires_grad = False
615
+ self.model.eval()
616
+
617
+ if self.unconditional_token is None:
618
+ self.build_unconditional_emb()
619
+
620
+ # if(self.training_mode):
621
+ # assert self.model.training == True
622
+ # else:
623
+ # assert self.model.training == False
624
+
625
+ # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
626
+ if self.embed_mode == "audio":
627
+ if not self.training:
628
+ print("INFO: clap model calculate the audio embedding as condition")
629
+ with torch.no_grad():
630
+ # assert (
631
+ # self.sampling_rate == 16000
632
+ # ), "We only support 16000 sampling rate"
633
+
634
+ # if self.random_mute:
635
+ # batch = self._random_mute(batch)
636
+ # batch: [bs, 1, t-samples]
637
+ if self.sampling_rate != 48000:
638
+ batch = torchaudio.functional.resample(
639
+ batch, orig_freq=self.sampling_rate, new_freq=48000
640
+ )
641
+ audio_data = batch.squeeze(1).to("cpu")
642
+ self.mel_transform = self.mel_transform.to(audio_data.device)
643
+ mel = self.mel_transform(audio_data)
644
+ audio_dict = get_audio_features(
645
+ audio_data,
646
+ mel,
647
+ 480000,
648
+ data_truncating="fusion",
649
+ data_filling="repeatpad",
650
+ audio_cfg=self.model_cfg["audio_cfg"],
651
+ )
652
+ # [bs, 512]
653
+ embed = self.model.get_audio_embedding(audio_dict)
654
+ elif self.embed_mode == "text":
655
+ with torch.no_grad():
656
+ # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
657
+ text_data = self.tokenizer(batch)
658
+
659
+ if isinstance(batch, str) or (
660
+ isinstance(batch, list) and len(batch) == 1
661
+ ):
662
+ for key in text_data.keys():
663
+ text_data[key] = text_data[key].unsqueeze(0)
664
+
665
+ embed = self.model.get_text_embedding(text_data)
666
+
667
+ embed = embed.unsqueeze(1)
668
+ for i in range(embed.size(0)):
669
+ if self.make_decision(self.unconditional_prob):
670
+ embed[i] = self.unconditional_token
671
+ # embed = torch.randn((batch.size(0), 1, 512)).type_as(batch)
672
+ return embed.detach()
673
+
674
+ def tokenizer(self, text):
675
+ result = self.tokenize(
676
+ text,
677
+ padding="max_length",
678
+ truncation=True,
679
+ max_length=512,
680
+ return_tensors="pt",
681
+ )
682
+ return {k: v.squeeze(0) for k, v in result.items()}
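A minimal sketch of the conditioner contract described in the module docstring above (illustrative; it assumes the pretrained AudioMAE checkpoint used by Vanilla_AudioMAE is available): cross-attention conditioners return [embedding, attention_mask], and get_unconditional_condition supplies the matching null condition for classifier-free guidance.

import torch

# hypothetical pooling settings
cond = AudioMAEConditionCTPoolRand(eval_time_pooling=8, eval_freq_pooling=8).eval()
mel = torch.randn(2, 1024, 128)                           # [bs, time_frames, mel_bins]
emb, mask = cond(mel)                                     # [2, 8, 768], [2, 8]
uncond_emb, uncond_mask = cond.get_unconditional_condition(batchsize=2)
print(emb.shape, uncond_emb.shape)                        # shapes match, as required for CFG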
FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/__init__.py ADDED
File without changes
FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/attentions.py ADDED
@@ -0,0 +1,430 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ import audiosr.latent_diffusion.modules.phoneme_encoder.commons as commons
7
+
8
+ LRELU_SLOPE = 0.1
9
+
10
+
11
+ class LayerNorm(nn.Module):
12
+ def __init__(self, channels, eps=1e-5):
13
+ super().__init__()
14
+ self.channels = channels
15
+ self.eps = eps
16
+
17
+ self.gamma = nn.Parameter(torch.ones(channels))
18
+ self.beta = nn.Parameter(torch.zeros(channels))
19
+
20
+ def forward(self, x):
21
+ x = x.transpose(1, -1)
22
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
23
+ return x.transpose(1, -1)
24
+
25
+
26
+ class Encoder(nn.Module):
27
+ def __init__(
28
+ self,
29
+ hidden_channels,
30
+ filter_channels,
31
+ n_heads,
32
+ n_layers,
33
+ kernel_size=1,
34
+ p_dropout=0.0,
35
+ window_size=4,
36
+ **kwargs
37
+ ):
38
+ super().__init__()
39
+ self.hidden_channels = hidden_channels
40
+ self.filter_channels = filter_channels
41
+ self.n_heads = n_heads
42
+ self.n_layers = n_layers
43
+ self.kernel_size = kernel_size
44
+ self.p_dropout = p_dropout
45
+ self.window_size = window_size
46
+
47
+ self.drop = nn.Dropout(p_dropout)
48
+ self.attn_layers = nn.ModuleList()
49
+ self.norm_layers_1 = nn.ModuleList()
50
+ self.ffn_layers = nn.ModuleList()
51
+ self.norm_layers_2 = nn.ModuleList()
52
+ for i in range(self.n_layers):
53
+ self.attn_layers.append(
54
+ MultiHeadAttention(
55
+ hidden_channels,
56
+ hidden_channels,
57
+ n_heads,
58
+ p_dropout=p_dropout,
59
+ window_size=window_size,
60
+ )
61
+ )
62
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
63
+ self.ffn_layers.append(
64
+ FFN(
65
+ hidden_channels,
66
+ hidden_channels,
67
+ filter_channels,
68
+ kernel_size,
69
+ p_dropout=p_dropout,
70
+ )
71
+ )
72
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
73
+
74
+ def forward(self, x, x_mask):
75
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
76
+ x = x * x_mask
77
+ for i in range(self.n_layers):
78
+ y = self.attn_layers[i](x, x, attn_mask)
79
+ y = self.drop(y)
80
+ x = self.norm_layers_1[i](x + y)
81
+
82
+ y = self.ffn_layers[i](x, x_mask)
83
+ y = self.drop(y)
84
+ x = self.norm_layers_2[i](x + y)
85
+ x = x * x_mask
86
+ return x
87
+
88
+
89
+ class Decoder(nn.Module):
90
+ def __init__(
91
+ self,
92
+ hidden_channels,
93
+ filter_channels,
94
+ n_heads,
95
+ n_layers,
96
+ kernel_size=1,
97
+ p_dropout=0.0,
98
+ proximal_bias=False,
99
+ proximal_init=True,
100
+ **kwargs
101
+ ):
102
+ super().__init__()
103
+ self.hidden_channels = hidden_channels
104
+ self.filter_channels = filter_channels
105
+ self.n_heads = n_heads
106
+ self.n_layers = n_layers
107
+ self.kernel_size = kernel_size
108
+ self.p_dropout = p_dropout
109
+ self.proximal_bias = proximal_bias
110
+ self.proximal_init = proximal_init
111
+
112
+ self.drop = nn.Dropout(p_dropout)
113
+ self.self_attn_layers = nn.ModuleList()
114
+ self.norm_layers_0 = nn.ModuleList()
115
+ self.encdec_attn_layers = nn.ModuleList()
116
+ self.norm_layers_1 = nn.ModuleList()
117
+ self.ffn_layers = nn.ModuleList()
118
+ self.norm_layers_2 = nn.ModuleList()
119
+ for i in range(self.n_layers):
120
+ self.self_attn_layers.append(
121
+ MultiHeadAttention(
122
+ hidden_channels,
123
+ hidden_channels,
124
+ n_heads,
125
+ p_dropout=p_dropout,
126
+ proximal_bias=proximal_bias,
127
+ proximal_init=proximal_init,
128
+ )
129
+ )
130
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
131
+ self.encdec_attn_layers.append(
132
+ MultiHeadAttention(
133
+ hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
134
+ )
135
+ )
136
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
137
+ self.ffn_layers.append(
138
+ FFN(
139
+ hidden_channels,
140
+ hidden_channels,
141
+ filter_channels,
142
+ kernel_size,
143
+ p_dropout=p_dropout,
144
+ causal=True,
145
+ )
146
+ )
147
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
148
+
149
+ def forward(self, x, x_mask, h, h_mask):
150
+ """
151
+ x: decoder input
152
+ h: encoder output
153
+ """
154
+ self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
155
+ device=x.device, dtype=x.dtype
156
+ )
157
+ encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
158
+ x = x * x_mask
159
+ for i in range(self.n_layers):
160
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
161
+ y = self.drop(y)
162
+ x = self.norm_layers_0[i](x + y)
163
+
164
+ y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
165
+ y = self.drop(y)
166
+ x = self.norm_layers_1[i](x + y)
167
+
168
+ y = self.ffn_layers[i](x, x_mask)
169
+ y = self.drop(y)
170
+ x = self.norm_layers_2[i](x + y)
171
+ x = x * x_mask
172
+ return x
173
+
174
+
175
+ class MultiHeadAttention(nn.Module):
176
+ def __init__(
177
+ self,
178
+ channels,
179
+ out_channels,
180
+ n_heads,
181
+ p_dropout=0.0,
182
+ window_size=None,
183
+ heads_share=True,
184
+ block_length=None,
185
+ proximal_bias=False,
186
+ proximal_init=False,
187
+ ):
188
+ super().__init__()
189
+ assert channels % n_heads == 0
190
+
191
+ self.channels = channels
192
+ self.out_channels = out_channels
193
+ self.n_heads = n_heads
194
+ self.p_dropout = p_dropout
195
+ self.window_size = window_size
196
+ self.heads_share = heads_share
197
+ self.block_length = block_length
198
+ self.proximal_bias = proximal_bias
199
+ self.proximal_init = proximal_init
200
+ self.attn = None
201
+
202
+ self.k_channels = channels // n_heads
203
+ self.conv_q = nn.Conv1d(channels, channels, 1)
204
+ self.conv_k = nn.Conv1d(channels, channels, 1)
205
+ self.conv_v = nn.Conv1d(channels, channels, 1)
206
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
207
+ self.drop = nn.Dropout(p_dropout)
208
+
209
+ if window_size is not None:
210
+ n_heads_rel = 1 if heads_share else n_heads
211
+ rel_stddev = self.k_channels**-0.5
212
+ self.emb_rel_k = nn.Parameter(
213
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
214
+ * rel_stddev
215
+ )
216
+ self.emb_rel_v = nn.Parameter(
217
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
218
+ * rel_stddev
219
+ )
220
+
221
+ nn.init.xavier_uniform_(self.conv_q.weight)
222
+ nn.init.xavier_uniform_(self.conv_k.weight)
223
+ nn.init.xavier_uniform_(self.conv_v.weight)
224
+ if proximal_init:
225
+ with torch.no_grad():
226
+ self.conv_k.weight.copy_(self.conv_q.weight)
227
+ self.conv_k.bias.copy_(self.conv_q.bias)
228
+
229
+ def forward(self, x, c, attn_mask=None):
230
+ q = self.conv_q(x)
231
+ k = self.conv_k(c)
232
+ v = self.conv_v(c)
233
+
234
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
235
+
236
+ x = self.conv_o(x)
237
+ return x
238
+
239
+ def attention(self, query, key, value, mask=None):
240
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
241
+ b, d, t_s, t_t = (*key.size(), query.size(2))
242
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
243
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
244
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
245
+
246
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
247
+ if self.window_size is not None:
248
+ assert (
249
+ t_s == t_t
250
+ ), "Relative attention is only available for self-attention."
251
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
252
+ rel_logits = self._matmul_with_relative_keys(
253
+ query / math.sqrt(self.k_channels), key_relative_embeddings
254
+ )
255
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
256
+ scores = scores + scores_local
257
+ if self.proximal_bias:
258
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
259
+ scores = scores + self._attention_bias_proximal(t_s).to(
260
+ device=scores.device, dtype=scores.dtype
261
+ )
262
+ if mask is not None:
263
+ scores = scores.masked_fill(mask == 0, -1e4)
264
+ if self.block_length is not None:
265
+ assert (
266
+ t_s == t_t
267
+ ), "Local attention is only available for self-attention."
268
+ block_mask = (
269
+ torch.ones_like(scores)
270
+ .triu(-self.block_length)
271
+ .tril(self.block_length)
272
+ )
273
+ scores = scores.masked_fill(block_mask == 0, -1e4)
274
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
275
+ p_attn = self.drop(p_attn)
276
+ output = torch.matmul(p_attn, value)
277
+ if self.window_size is not None:
278
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
279
+ value_relative_embeddings = self._get_relative_embeddings(
280
+ self.emb_rel_v, t_s
281
+ )
282
+ output = output + self._matmul_with_relative_values(
283
+ relative_weights, value_relative_embeddings
284
+ )
285
+ output = (
286
+ output.transpose(2, 3).contiguous().view(b, d, t_t)
287
+ ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
288
+ return output, p_attn
289
+
290
+ def _matmul_with_relative_values(self, x, y):
291
+ """
292
+ x: [b, h, l, m]
293
+ y: [h or 1, m, d]
294
+ ret: [b, h, l, d]
295
+ """
296
+ ret = torch.matmul(x, y.unsqueeze(0))
297
+ return ret
298
+
299
+ def _matmul_with_relative_keys(self, x, y):
300
+ """
301
+ x: [b, h, l, d]
302
+ y: [h or 1, m, d]
303
+ ret: [b, h, l, m]
304
+ """
305
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
306
+ return ret
307
+
308
+ def _get_relative_embeddings(self, relative_embeddings, length):
309
+ 2 * self.window_size + 1
310
+ # Pad first before slice to avoid using cond ops.
311
+ pad_length = max(length - (self.window_size + 1), 0)
312
+ slice_start_position = max((self.window_size + 1) - length, 0)
313
+ slice_end_position = slice_start_position + 2 * length - 1
314
+ if pad_length > 0:
315
+ padded_relative_embeddings = F.pad(
316
+ relative_embeddings,
317
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
318
+ )
319
+ else:
320
+ padded_relative_embeddings = relative_embeddings
321
+ used_relative_embeddings = padded_relative_embeddings[
322
+ :, slice_start_position:slice_end_position
323
+ ]
324
+ return used_relative_embeddings
325
+
326
+ def _relative_position_to_absolute_position(self, x):
327
+ """
328
+ x: [b, h, l, 2*l-1]
329
+ ret: [b, h, l, l]
330
+ """
331
+ batch, heads, length, _ = x.size()
332
+ # Concat columns of pad to shift from relative to absolute indexing.
333
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
334
+
335
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
336
+ x_flat = x.view([batch, heads, length * 2 * length])
337
+ x_flat = F.pad(
338
+ x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
339
+ )
340
+
341
+ # Reshape and slice out the padded elements.
342
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
343
+ :, :, :length, length - 1 :
344
+ ]
345
+ return x_final
346
+
347
+ def _absolute_position_to_relative_position(self, x):
348
+ """
349
+ x: [b, h, l, l]
350
+ ret: [b, h, l, 2*l-1]
351
+ """
352
+ batch, heads, length, _ = x.size()
353
+ # pad along column
354
+ x = F.pad(
355
+ x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
356
+ )
357
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
358
+ # add 0's in the beginning that will skew the elements after reshape
359
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
360
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
361
+ return x_final
362
+
363
+ def _attention_bias_proximal(self, length):
364
+ """Bias for self-attention to encourage attention to close positions.
365
+ Args:
366
+ length: an integer scalar.
367
+ Returns:
368
+ a Tensor with shape [1, 1, length, length]
369
+ """
370
+ r = torch.arange(length, dtype=torch.float32)
371
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
372
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
373
+
374
+
375
+ class FFN(nn.Module):
376
+ def __init__(
377
+ self,
378
+ in_channels,
379
+ out_channels,
380
+ filter_channels,
381
+ kernel_size,
382
+ p_dropout=0.0,
383
+ activation=None,
384
+ causal=False,
385
+ ):
386
+ super().__init__()
387
+ self.in_channels = in_channels
388
+ self.out_channels = out_channels
389
+ self.filter_channels = filter_channels
390
+ self.kernel_size = kernel_size
391
+ self.p_dropout = p_dropout
392
+ self.activation = activation
393
+ self.causal = causal
394
+
395
+ if causal:
396
+ self.padding = self._causal_padding
397
+ else:
398
+ self.padding = self._same_padding
399
+
400
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
401
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
402
+ self.drop = nn.Dropout(p_dropout)
403
+
404
+ def forward(self, x, x_mask):
405
+ x = self.conv_1(self.padding(x * x_mask))
406
+ if self.activation == "gelu":
407
+ x = x * torch.sigmoid(1.702 * x)
408
+ else:
409
+ x = torch.relu(x)
410
+ x = self.drop(x)
411
+ x = self.conv_2(self.padding(x * x_mask))
412
+ return x * x_mask
413
+
414
+ def _causal_padding(self, x):
415
+ if self.kernel_size == 1:
416
+ return x
417
+ pad_l = self.kernel_size - 1
418
+ pad_r = 0
419
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
420
+ x = F.pad(x, commons.convert_pad_shape(padding))
421
+ return x
422
+
423
+ def _same_padding(self, x):
424
+ if self.kernel_size == 1:
425
+ return x
426
+ pad_l = (self.kernel_size - 1) // 2
427
+ pad_r = self.kernel_size // 2
428
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
429
+ x = F.pad(x, commons.convert_pad_shape(padding))
430
+ return x
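A minimal sketch of running the relative-position Encoder above on a padded batch (illustrative; it reuses the commons import already made at the top of this file, and the shapes are hypothetical).

import torch

enc = Encoder(hidden_channels=192, filter_channels=768, n_heads=2, n_layers=2, kernel_size=3)
x = torch.randn(2, 192, 50)                                           # [bs, channels, time]
lengths = torch.tensor([50, 35])
x_mask = commons.sequence_mask(lengths, 50).unsqueeze(1).to(x.dtype)  # [bs, 1, time]
out = enc(x * x_mask, x_mask)                                         # [2, 192, 50]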
FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/commons.py ADDED
@@ -0,0 +1,161 @@
1
+ import math
2
+ import torch
3
+ from torch.nn import functional as F
4
+
5
+
6
+ def init_weights(m, mean=0.0, std=0.01):
7
+ classname = m.__class__.__name__
8
+ if classname.find("Conv") != -1:
9
+ m.weight.data.normal_(mean, std)
10
+
11
+
12
+ def get_padding(kernel_size, dilation=1):
13
+ return int((kernel_size * dilation - dilation) / 2)
14
+
15
+
16
+ def convert_pad_shape(pad_shape):
17
+ l = pad_shape[::-1]
18
+ pad_shape = [item for sublist in l for item in sublist]
19
+ return pad_shape
20
+
21
+
22
+ def intersperse(lst, item):
23
+ result = [item] * (len(lst) * 2 + 1)
24
+ result[1::2] = lst
25
+ return result
26
+
27
+
28
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
29
+ """KL(P||Q)"""
30
+ kl = (logs_q - logs_p) - 0.5
31
+ kl += (
32
+ 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
33
+ )
34
+ return kl
35
+
36
+
37
+ def rand_gumbel(shape):
38
+ """Sample from the Gumbel distribution, protect from overflows."""
39
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
40
+ return -torch.log(-torch.log(uniform_samples))
41
+
42
+
43
+ def rand_gumbel_like(x):
44
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
45
+ return g
46
+
47
+
48
+ def slice_segments(x, ids_str, segment_size=4):
49
+ ret = torch.zeros_like(x[:, :, :segment_size])
50
+ for i in range(x.size(0)):
51
+ idx_str = ids_str[i]
52
+ idx_end = idx_str + segment_size
53
+ ret[i] = x[i, :, idx_str:idx_end]
54
+ return ret
55
+
56
+
57
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
58
+ b, d, t = x.size()
59
+ if x_lengths is None:
60
+ x_lengths = t
61
+ ids_str_max = x_lengths - segment_size + 1
62
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
63
+ ret = slice_segments(x, ids_str, segment_size)
64
+ return ret, ids_str
65
+
66
+
67
+ def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
68
+ position = torch.arange(length, dtype=torch.float)
69
+ num_timescales = channels // 2
70
+ log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
71
+ num_timescales - 1
72
+ )
73
+ inv_timescales = min_timescale * torch.exp(
74
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
75
+ )
76
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
77
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
78
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
79
+ signal = signal.view(1, channels, length)
80
+ return signal
81
+
82
+
83
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
84
+ b, channels, length = x.size()
85
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
86
+ return x + signal.to(dtype=x.dtype, device=x.device)
87
+
88
+
89
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
90
+ b, channels, length = x.size()
91
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
92
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
93
+
94
+
95
+ def subsequent_mask(length):
96
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
97
+ return mask
98
+
99
+
100
+ @torch.jit.script
101
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
102
+ n_channels_int = n_channels[0]
103
+ in_act = input_a + input_b
104
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
105
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
106
+ acts = t_act * s_act
107
+ return acts
108
+
109
+
110
+ def convert_pad_shape(pad_shape):
111
+ l = pad_shape[::-1]
112
+ pad_shape = [item for sublist in l for item in sublist]
113
+ return pad_shape
114
+
115
+
116
+ def shift_1d(x):
117
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
118
+ return x
119
+
120
+
121
+ def sequence_mask(length, max_length=None):
122
+ if max_length is None:
123
+ max_length = length.max()
124
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
125
+ return x.unsqueeze(0) < length.unsqueeze(1)
126
+
127
+
128
+ def generate_path(duration, mask):
129
+ """
130
+ duration: [b, 1, t_x]
131
+ mask: [b, 1, t_y, t_x]
132
+ """
133
+ duration.device
134
+
135
+ b, _, t_y, t_x = mask.shape
136
+ cum_duration = torch.cumsum(duration, -1)
137
+
138
+ cum_duration_flat = cum_duration.view(b * t_x)
139
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
140
+ path = path.view(b, t_x, t_y)
141
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
142
+ path = path.unsqueeze(1).transpose(2, 3) * mask
143
+ return path
144
+
145
+
146
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
147
+ if isinstance(parameters, torch.Tensor):
148
+ parameters = [parameters]
149
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
150
+ norm_type = float(norm_type)
151
+ if clip_value is not None:
152
+ clip_value = float(clip_value)
153
+
154
+ total_norm = 0
155
+ for p in parameters:
156
+ param_norm = p.grad.data.norm(norm_type)
157
+ total_norm += param_norm.item() ** norm_type
158
+ if clip_value is not None:
159
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
160
+ total_norm = total_norm ** (1.0 / norm_type)
161
+ return total_norm
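Small illustrative examples of two helpers above, sequence_mask and convert_pad_shape (values are hypothetical):

import torch

mask = sequence_mask(torch.tensor([3, 5]), max_length=6)    # [[T,T,T,F,F,F], [T,T,T,T,T,F]]
pad = convert_pad_shape([[0, 0], [0, 0], [1, 0]])            # -> [1, 0, 0, 0, 0, 0], in F.pad order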
FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/encoder.py ADDED
@@ -0,0 +1,50 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+
5
+ import audiosr.latent_diffusion.modules.phoneme_encoder.commons as commons
6
+ import audiosr.latent_diffusion.modules.phoneme_encoder.attentions as attentions
7
+
8
+
9
+ class TextEncoder(nn.Module):
10
+ def __init__(
11
+ self,
12
+ n_vocab,
13
+ out_channels=192,
14
+ hidden_channels=192,
15
+ filter_channels=768,
16
+ n_heads=2,
17
+ n_layers=6,
18
+ kernel_size=3,
19
+ p_dropout=0.1,
20
+ ):
21
+ super().__init__()
22
+ self.n_vocab = n_vocab
23
+ self.out_channels = out_channels
24
+ self.hidden_channels = hidden_channels
25
+ self.filter_channels = filter_channels
26
+ self.n_heads = n_heads
27
+ self.n_layers = n_layers
28
+ self.kernel_size = kernel_size
29
+ self.p_dropout = p_dropout
30
+
31
+ self.emb = nn.Embedding(n_vocab, hidden_channels)
32
+ nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
33
+
34
+ self.encoder = attentions.Encoder(
35
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
36
+ )
37
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
38
+
39
+ def forward(self, x, x_lengths):
40
+ x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
41
+ x = torch.transpose(x, 1, -1) # [b, h, t]
42
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
43
+ x.dtype
44
+ )
45
+
46
+ x = self.encoder(x * x_mask, x_mask)
47
+ stats = self.proj(x) * x_mask
48
+
49
+ m, logs = torch.split(stats, self.out_channels, dim=1)
50
+ return x, m, logs, x_mask
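A minimal sketch of a TextEncoder forward pass on a padded phoneme batch (illustrative; the ids and lengths are hypothetical):

import torch

encoder = TextEncoder(n_vocab=41)
phoneme_ids = torch.randint(0, 41, (2, 250))       # padded phoneme id batch
lengths = torch.tensor([250, 180])
x, m, logs, x_mask = encoder(phoneme_ids, lengths)
print(x.shape, m.shape, x_mask.shape)              # [2, 192, 250], [2, 192, 250], [2, 1, 250]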
FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/LICENSE ADDED
@@ -0,0 +1,19 @@
1
+ Copyright (c) 2017 Keith Ito
2
+
3
+ Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ of this software and associated documentation files (the "Software"), to deal
5
+ in the Software without restriction, including without limitation the rights
6
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ copies of the Software, and to permit persons to whom the Software is
8
+ furnished to do so, subject to the following conditions:
9
+
10
+ The above copyright notice and this permission notice shall be included in
11
+ all copies or substantial portions of the Software.
12
+
13
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19
+ THE SOFTWARE.
FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/__init__.py ADDED
@@ -0,0 +1,52 @@
1
+ """ from https://github.com/keithito/tacotron """
2
+ from audiosr.latent_diffusion.modules.phoneme_encoder.text import cleaners
3
+ from audiosr.latent_diffusion.modules.phoneme_encoder.text.symbols import symbols
4
+
5
+
6
+ # Mappings from symbol to numeric ID and vice versa:
7
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
8
+ _id_to_symbol = {i: s for i, s in enumerate(symbols)}
9
+
10
+ cleaner = getattr(cleaners, "english_cleaners2")
11
+
12
+
13
+ def text_to_sequence(text, cleaner_names):
14
+ """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
15
+ Args:
16
+ text: string to convert to a sequence
17
+ cleaner_names: names of the cleaner functions to run the text through
18
+ Returns:
19
+ List of integers corresponding to the symbols in the text
20
+ """
21
+ sequence = []
22
+
23
+ clean_text = _clean_text(text, cleaner_names)
24
+ for symbol in clean_text:
25
+ symbol_id = _symbol_to_id[symbol]
26
+ sequence += [symbol_id]
27
+ return sequence
28
+
29
+
30
+ def cleaned_text_to_sequence(cleaned_text):
31
+ """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
32
+ Args:
33
+ text: string to convert to a sequence
34
+ Returns:
35
+ List of integers corresponding to the symbols in the text
36
+ """
37
+ sequence = [_symbol_to_id[symbol] for symbol in cleaned_text]
38
+ return sequence
39
+
40
+
41
+ def sequence_to_text(sequence):
42
+ """Converts a sequence of IDs back to a string"""
43
+ result = ""
44
+ for symbol_id in sequence:
45
+ s = _id_to_symbol[symbol_id]
46
+ result += s
47
+ return result
48
+
49
+
50
+ def _clean_text(text, cleaner_names):
51
+ text = cleaner(text)
52
+ return text
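Illustrative round trip through the symbol table above: cleaned_text_to_sequence and sequence_to_text are inverses as long as every character is in symbols.py (text_to_sequence additionally runs the phonemizer-based cleaner).

ids = cleaned_text_to_sequence("hello world")
assert sequence_to_text(ids) == "hello world"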
FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/cleaners.py ADDED
@@ -0,0 +1,110 @@
+ """ from https://github.com/keithito/tacotron """
+
+ """
+ Cleaners are transformations that run over the input text at both training and eval time.
+
+ Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
+ hyperparameter. Some cleaners are English-specific. You'll typically want to use:
+   1. "english_cleaners" for English text
+   2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
+      the Unidecode library (https://pypi.python.org/pypi/Unidecode)
+   3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
+      the symbols in symbols.py to match your data).
+ """
+
+ import re
+
+ # Optional dependencies: `convert_to_ascii` and the `english_cleaners*` pipelines need
+ # `unidecode` and `phonemize`, and `expand_numbers` needs a `normalize_numbers` helper.
+ # They are left commented out here and must be enabled before those cleaners are used.
+ # from unidecode import unidecode
+ # from phonemizer import phonemize
+
+
+ # Regular expression matching whitespace:
+ _whitespace_re = re.compile(r"\s+")
+
+ # List of (regular expression, replacement) pairs for abbreviations:
+ _abbreviations = [
+     (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
+     for x in [
+         ("mrs", "misess"),
+         ("mr", "mister"),
+         ("dr", "doctor"),
+         ("st", "saint"),
+         ("co", "company"),
+         ("jr", "junior"),
+         ("maj", "major"),
+         ("gen", "general"),
+         ("drs", "doctors"),
+         ("rev", "reverend"),
+         ("lt", "lieutenant"),
+         ("hon", "honorable"),
+         ("sgt", "sergeant"),
+         ("capt", "captain"),
+         ("esq", "esquire"),
+         ("ltd", "limited"),
+         ("col", "colonel"),
+         ("ft", "fort"),
+     ]
+ ]
+
+
+ def expand_abbreviations(text):
+     for regex, replacement in _abbreviations:
+         text = re.sub(regex, replacement, text)
+     return text
+
+
+ def expand_numbers(text):
+     # Requires a `normalize_numbers` implementation (not bundled in this file).
+     return normalize_numbers(text)
+
+
+ def lowercase(text):
+     return text.lower()
+
+
+ def collapse_whitespace(text):
+     return re.sub(_whitespace_re, " ", text)
+
+
+ def convert_to_ascii(text):
+     # Requires `unidecode` (see the commented import above).
+     return unidecode(text)
+
+
+ def basic_cleaners(text):
+     """Basic pipeline that lowercases and collapses whitespace without transliteration."""
+     text = lowercase(text)
+     text = collapse_whitespace(text)
+     return text
+
+
+ def transliteration_cleaners(text):
+     """Pipeline for non-English text that transliterates to ASCII."""
+     text = convert_to_ascii(text)
+     text = lowercase(text)
+     text = collapse_whitespace(text)
+     return text
+
+
+ def english_cleaners(text):
+     """Pipeline for English text, including abbreviation expansion."""
+     text = convert_to_ascii(text)
+     text = lowercase(text)
+     text = expand_abbreviations(text)
+     phonemes = phonemize(text, language="en-us", backend="espeak", strip=True)
+     phonemes = collapse_whitespace(phonemes)
+     return phonemes
+
+
+ def english_cleaners2(text):
+     """Pipeline for English text, including abbreviation expansion, punctuation and stress."""
+     text = convert_to_ascii(text)
+     text = lowercase(text)
+     text = expand_abbreviations(text)
+     phonemes = phonemize(
+         text,
+         language="en-us",
+         backend="espeak",
+         strip=True,
+         preserve_punctuation=True,
+         with_stress=True,
+     )
+     phonemes = collapse_whitespace(phonemes)
+     return phonemes
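As a quick sanity check, the dependency-free cleaners can be exercised directly. This is a sketch: the import path is an assumption, and the phoneme-based cleaners additionally need `unidecode` and `phonemizer` installed.

    # Sketch: only cleaners with no optional dependencies are used here.
    from FlashSR.AudioSR.latent_diffusion.modules.phoneme_encoder.text import cleaners

    print(cleaners.basic_cleaners("Hello,   WORLD!"))               # -> "hello, world!"
    print(cleaners.expand_abbreviations("Dr. Smith at Ft. Knox"))   # -> "doctor Smith at fort Knox"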
FlashSR/AudioSR/latent_diffusion/modules/phoneme_encoder/text/symbols.py ADDED
@@ -0,0 +1,16 @@
+ """ from https://github.com/keithito/tacotron """
+
+ """
+ Defines the set of symbols used in text input to the model.
+ """
+ _pad = "_"
+ _punctuation = ';:,.!?¡¿—…"«»“” '
+ _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
+ _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
+
+
+ # Export all symbols:
+ symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
+
+ # Special symbol ids
+ SPACE_ID = symbols.index(" ")
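For reference, the exported `symbols` list maps one-to-one onto integer IDs, with the padding symbol at index 0. A minimal sketch (the import path is an assumption):

    # Sketch: build symbol <-> ID lookup tables from the exported `symbols` list.
    from FlashSR.AudioSR.latent_diffusion.modules.phoneme_encoder.text.symbols import symbols, SPACE_ID

    _symbol_to_id = {s: i for i, s in enumerate(symbols)}   # "_" (padding) -> 0
    _id_to_symbol = {i: s for i, s in enumerate(symbols)}
    assert _id_to_symbol[SPACE_ID] == " "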
FlashSR/AudioSR/latent_diffusion/util.py ADDED
@@ -0,0 +1,267 @@
+ import importlib
+
+ import torch
+ import numpy as np
+ from collections import abc
+
+ import multiprocessing as mp
+ from threading import Thread
+ from queue import Queue
+
+ from inspect import isfunction
+ from PIL import Image, ImageDraw, ImageFont
+
+ CACHE = {
+     "get_vits_phoneme_ids": {
+         "PAD_LENGTH": 310,
+         "_pad": "_",
+         "_punctuation": ';:,.!?¡¿—…"«»“” ',
+         "_letters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
+         "_letters_ipa": "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ",
+         "_special": "♪☎☒☝⚠",
+     }
+ }
+
+ CACHE["get_vits_phoneme_ids"]["symbols"] = (
+     [CACHE["get_vits_phoneme_ids"]["_pad"]]
+     + list(CACHE["get_vits_phoneme_ids"]["_punctuation"])
+     + list(CACHE["get_vits_phoneme_ids"]["_letters"])
+     + list(CACHE["get_vits_phoneme_ids"]["_letters_ipa"])
+     + list(CACHE["get_vits_phoneme_ids"]["_special"])
+ )
+ CACHE["get_vits_phoneme_ids"]["_symbol_to_id"] = {
+     s: i for i, s in enumerate(CACHE["get_vits_phoneme_ids"]["symbols"])
+ }
+
+
+ def get_vits_phoneme_ids_no_padding(phonemes):
+     pad_token_id = 0
+     pad_length = CACHE["get_vits_phoneme_ids"]["PAD_LENGTH"]
+     _symbol_to_id = CACHE["get_vits_phoneme_ids"]["_symbol_to_id"]
+     batchsize = len(phonemes)
+
+     clean_text = phonemes[0] + "⚠"
+     sequence = []
+
+     for symbol in clean_text:
+         if symbol not in _symbol_to_id.keys():
+             print("%s is not in the vocabulary. %s" % (symbol, clean_text))
+             symbol = "_"
+         symbol_id = _symbol_to_id[symbol]
+         sequence += [symbol_id]
+
+     def _pad_phonemes(phonemes_list):
+         return phonemes_list + [pad_token_id] * (pad_length - len(phonemes_list))
+
+     sequence = sequence[:pad_length]
+
+     return {
+         "phoneme_idx": torch.LongTensor(_pad_phonemes(sequence))
+         .unsqueeze(0)
+         .expand(batchsize, -1)
+     }
+
+
+ def log_txt_as_img(wh, xc, size=10):
+     # wh a tuple of (width, height)
+     # xc a list of captions to plot
+     b = len(xc)
+     txts = list()
+     for bi in range(b):
+         txt = Image.new("RGB", wh, color="white")
+         draw = ImageDraw.Draw(txt)
+         font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
+         nc = int(40 * (wh[0] / 256))
+         lines = "\n".join(
+             xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)
+         )
+
+         try:
+             draw.text((0, 0), lines, fill="black", font=font)
+         except UnicodeEncodeError:
+             print("Can't encode string for logging. Skipping.")
+
+         txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
+         txts.append(txt)
+     txts = np.stack(txts)
+     txts = torch.tensor(txts)
+     return txts
+
+
+ def ismap(x):
+     if not isinstance(x, torch.Tensor):
+         return False
+     return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+
+ def isimage(x):
+     if not isinstance(x, torch.Tensor):
+         return False
+     return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+
+ def int16_to_float32(x):
+     return (x / 32767.0).astype(np.float32)
+
+
+ def float32_to_int16(x):
+     x = np.clip(x, a_min=-1.0, a_max=1.0)
+     return (x * 32767.0).astype(np.int16)
+
+
+ def exists(x):
+     return x is not None
+
+
+ def default(val, d):
+     if exists(val):
+         return val
+     return d() if isfunction(d) else d
+
+
+ def mean_flat(tensor):
+     """
+     https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
+     Take the mean over all non-batch dimensions.
+     """
+     return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+ def count_params(model, verbose=False):
+     total_params = sum(p.numel() for p in model.parameters())
+     if verbose:
+         print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+     return total_params
+
+
+ def instantiate_from_config(config):
+     if "target" not in config:
+         if config == "__is_first_stage__":
+             return None
+         elif config == "__is_unconditional__":
+             return None
+         raise KeyError("Expected key `target` to instantiate.")
+     return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+ def get_obj_from_str(string, reload=False):
+     module, cls = string.rsplit(".", 1)
+     if reload:
+         module_imp = importlib.import_module(module)
+         importlib.reload(module_imp)
+     return getattr(importlib.import_module(module, package=None), cls)
+
+
+ def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
+     # run prefetching
+     if idx_to_fn:
+         res = func(data, worker_id=idx)
+     else:
+         res = func(data)
+     Q.put([idx, res])
+     Q.put("Done")
+
+
+ def parallel_data_prefetch(
+     func: callable,
+     data,
+     n_proc,
+     target_data_type="ndarray",
+     cpu_intensive=True,
+     use_worker_id=False,
+ ):
+     # if target_data_type not in ["ndarray", "list"]:
+     #     raise ValueError(
+     #         "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
+     #     )
+     if isinstance(data, np.ndarray) and target_data_type == "list":
+         raise ValueError("list expected but function got ndarray.")
+     elif isinstance(data, abc.Iterable):
+         if isinstance(data, dict):
+             print(
+                 'WARNING: "data" argument passed to parallel_data_prefetch is a dict: using only its values and disregarding keys.'
+             )
+             data = list(data.values())
+         if target_data_type == "ndarray":
+             data = np.asarray(data)
+         else:
+             data = list(data)
+     else:
+         raise TypeError(
+             f"The data that shall be processed in parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
+         )
+
+     if cpu_intensive:
+         Q = mp.Queue(1000)
+         proc = mp.Process
+     else:
+         Q = Queue(1000)
+         proc = Thread
+     # spawn processes
+     if target_data_type == "ndarray":
+         arguments = [
+             [func, Q, part, i, use_worker_id]
+             for i, part in enumerate(np.array_split(data, n_proc))
+         ]
+     else:
+         step = (
+             int(len(data) / n_proc + 1)
+             if len(data) % n_proc != 0
+             else int(len(data) / n_proc)
+         )
+         arguments = [
+             [func, Q, part, i, use_worker_id]
+             for i, part in enumerate(
+                 [data[i : i + step] for i in range(0, len(data), step)]
+             )
+         ]
+     processes = []
+     for i in range(n_proc):
+         p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
+         processes += [p]
+
+     # start processes
+     print("Start prefetching...")
+     import time
+
+     start = time.time()
+     gather_res = [[] for _ in range(n_proc)]
+     try:
+         for p in processes:
+             p.start()
+
+         k = 0
+         while k < n_proc:
+             # get result
+             res = Q.get()
+             if res == "Done":
+                 k += 1
+             else:
+                 gather_res[res[0]] = res[1]
+
+     except Exception as e:
+         print("Exception: ", e)
+         for p in processes:
+             p.terminate()
+
+         raise e
+     finally:
+         for p in processes:
+             p.join()
+         print(f"Prefetching complete. [{time.time() - start} sec.]")
+
+     if target_data_type == "ndarray":
+         if not isinstance(gather_res[0], np.ndarray):
+             return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
+
+         # order outputs
+         return np.concatenate(gather_res, axis=0)
+     elif target_data_type == "list":
+         out = []
+         for r in gather_res:
+             out.extend(r)
+         return out
+     else:
+         return gather_res
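`instantiate_from_config` expects a dict with a dotted `target` path and optional `params`, and builds the object via `get_obj_from_str`. A small sketch (the `torch.nn.Linear` target is an arbitrary illustrative class, not a FlashSR config; the import path is an assumption):

    # Sketch: build an object from a {"target": ..., "params": ...} dict.
    from FlashSR.AudioSR.latent_diffusion.util import instantiate_from_config

    config = {
        "target": "torch.nn.Linear",                      # arbitrary example class
        "params": {"in_features": 8, "out_features": 4},
    }
    layer = instantiate_from_config(config)               # == torch.nn.Linear(in_features=8, out_features=4)
    print(type(layer).__name__)                           # -> "Linear"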
FlashSR/BigVGAN/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2022 NVIDIA CORPORATION.
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.