CSH-1220 committed on
Commit d57e374 · 1 Parent(s): 35ff45f

Add application file

APadapter/ap_adapter/__init__.py ADDED
File without changes
APadapter/ap_adapter/attention_processor.py ADDED
@@ -0,0 +1,623 @@
1
+ # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import copy
6
+ import os
+ import logging
+
+ logger = logging.getLogger(__name__)  # module-level logger, referenced by IPAttnProcessor2_0
7
+ class AttnProcessor(nn.Module):
8
+ r"""
9
+ Default processor for performing attention-related computations.
10
+ """
11
+
12
+ def __init__(
13
+ self,
14
+ hidden_size=None,
15
+ cross_attention_dim=None,
16
+ ):
17
+ super().__init__()
18
+
19
+ def __call__(
20
+ self,
21
+ attn,
22
+ hidden_states,
23
+ encoder_hidden_states=None,
24
+ attention_mask=None,
25
+ temb=None,
26
+ ):
27
+ residual = hidden_states
28
+
29
+ if attn.spatial_norm is not None:
30
+ hidden_states = attn.spatial_norm(hidden_states, temb)
31
+
32
+ input_ndim = hidden_states.ndim
33
+
34
+ if input_ndim == 4:
35
+ batch_size, channel, height, width = hidden_states.shape
36
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
37
+
38
+ batch_size, sequence_length, _ = (
39
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
40
+ )
41
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
42
+
43
+ if attn.group_norm is not None:
44
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
45
+
46
+ query = attn.to_q(hidden_states)
47
+
48
+ if encoder_hidden_states is None:
49
+ encoder_hidden_states = hidden_states
50
+ elif attn.norm_cross:
51
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
52
+
53
+ key = attn.to_k(encoder_hidden_states)
54
+ value = attn.to_v(encoder_hidden_states)
55
+
56
+ query = attn.head_to_batch_dim(query)
57
+ key = attn.head_to_batch_dim(key)
58
+ value = attn.head_to_batch_dim(value)
59
+
60
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
61
+ hidden_states = torch.bmm(attention_probs, value)
62
+ hidden_states = attn.batch_to_head_dim(hidden_states)
63
+
64
+ # linear proj
65
+ hidden_states = attn.to_out[0](hidden_states)
66
+ # dropout
67
+ hidden_states = attn.to_out[1](hidden_states)
68
+
69
+ if input_ndim == 4:
70
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
71
+
72
+ if attn.residual_connection:
73
+ hidden_states = hidden_states + residual
74
+
75
+ hidden_states = hidden_states / attn.rescale_output_factor
76
+
77
+ return hidden_states
78
+
79
+
80
+ class IPAttnProcessor(nn.Module):
81
+ r"""
82
+ Attention processor for IP-Adapter.
83
+ Args:
84
+ hidden_size (`int`):
85
+ The hidden size of the attention layer.
86
+ cross_attention_dim (`int`):
87
+ The number of channels in the `encoder_hidden_states`.
88
+ scale (`float`, defaults to 1.0):
89
+ The weight scale of the image prompt.
90
+ num_tokens (`int`, defaults to 4; use 16 for ip_adapter_plus):
91
+ The context length of the image features.
92
+ """
93
+
94
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
95
+ super().__init__()
96
+
97
+ self.hidden_size = hidden_size
98
+ self.cross_attention_dim = cross_attention_dim
99
+ self.scale = scale
100
+ self.num_tokens = num_tokens
101
+
102
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
103
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
104
+
105
+ def __call__(
106
+ self,
107
+ attn,
108
+ hidden_states,
109
+ encoder_hidden_states=None,
110
+ attention_mask=None,
111
+ temb=None,
112
+ ):
113
+ residual = hidden_states
114
+
115
+ if attn.spatial_norm is not None:
116
+ hidden_states = attn.spatial_norm(hidden_states, temb)
117
+
118
+ input_ndim = hidden_states.ndim
119
+
120
+ if input_ndim == 4:
121
+ batch_size, channel, height, width = hidden_states.shape
122
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
123
+ encoder_hidden_states = encoder_hidden_states.squeeze(0)
124
+ batch_size, sequence_length, _ = (
125
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
126
+ )
127
+
128
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
129
+
130
+ if attn.group_norm is not None:
131
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
132
+
133
+ query = attn.to_q(hidden_states)
134
+
135
+ if encoder_hidden_states is None:
136
+ encoder_hidden_states = hidden_states
137
+ else:
138
+ # get encoder_hidden_states, ip_hidden_states
139
+ end_pos = encoder_hidden_states.shape[1]//2
140
+ encoder_hidden_states, ip_hidden_states = (
141
+ encoder_hidden_states[:, :end_pos, :],
142
+ encoder_hidden_states[:, end_pos:, :],
143
+ )
144
+ if attn.norm_cross:
145
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
146
+
147
+ key = attn.to_k(encoder_hidden_states)
148
+ value = attn.to_v(encoder_hidden_states)
149
+
150
+ query = attn.head_to_batch_dim(query)
151
+ key = attn.head_to_batch_dim(key)
152
+ value = attn.head_to_batch_dim(value)
153
+
154
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
155
+ hidden_states = torch.bmm(attention_probs, value)
156
+ hidden_states = attn.batch_to_head_dim(hidden_states)
157
+
158
+ # for ip-adapter: initialize the IP key/value projections from the original cross-attention weights
159
+ self.to_k_ip.weight = copy.deepcopy(attn.to_k.weight)
160
+ self.to_k_ip.bias = copy.deepcopy(attn.to_k.bias)
161
+ self.to_v_ip.weight = copy.deepcopy(attn.to_v.weight)
162
+ self.to_v_ip.bias = copy.deepcopy(attn.to_v.bias)
163
+
164
+ # Set the weights of self.to_k_ip to zero
165
+ # nn.init.zeros_(self.to_k_ip.weight)
166
+
167
+ # # Set the weights of self.to_v_ip to zero
168
+ # nn.init.zeros_(self.to_v_ip.weight)
169
+
170
+ ip_key = self.to_k_ip(ip_hidden_states)
171
+ ip_value = self.to_v_ip(ip_hidden_states)
172
+
173
+ ip_key = attn.head_to_batch_dim(ip_key)
174
+ ip_value = attn.head_to_batch_dim(ip_value)
175
+
176
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
177
+ self.attn_map = ip_attention_probs
178
+ ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
179
+ ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
180
+
181
+ hidden_states = hidden_states + self.scale * ip_hidden_states
182
+
183
+ # linear proj
184
+ hidden_states = attn.to_out[0](hidden_states)
185
+ # dropout
186
+ hidden_states = attn.to_out[1](hidden_states)
187
+
188
+ if input_ndim == 4:
189
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
190
+
191
+ if attn.residual_connection:
192
+ hidden_states = hidden_states + residual
193
+
194
+ hidden_states = hidden_states / attn.rescale_output_factor
195
+
196
+ return hidden_states
197
+
198
+
199
+ class AttnProcessor2_0(torch.nn.Module):
200
+ r"""
201
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
202
+ """
203
+
204
+ def __init__(
205
+ self,
206
+ hidden_size=None,
207
+ cross_attention_dim=None,
208
+ ):
209
+ super().__init__()
210
+ if not hasattr(F, "scaled_dot_product_attention"):
211
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
212
+
213
+
214
+ def __call__(
215
+ self,
216
+ attn,
217
+ hidden_states,
218
+ encoder_hidden_states=None,
219
+ attention_mask=None,
220
+ temb=None,
221
+ ):
222
+ residual = hidden_states
223
+ # print("encoder_hidden_states_attn",encoder_hidden_states.shape)
224
+ if attn.spatial_norm is not None:
225
+ hidden_states = attn.spatial_norm(hidden_states, temb)
226
+
227
+ input_ndim = hidden_states.ndim
228
+ # print("hidden_states",hidden_states.shape)
229
+ if input_ndim == 4:
230
+ batch_size, channel, height, width = hidden_states.shape
231
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
232
+ # encoder_hidden_states = encoder_hidden_states.squeeze(0)
233
+ # if encoder_hidden_states is None:
234
+ # # print(hidden_states.shape)
235
+ # pass
236
+ # else:
237
+ # print(encoder_hidden_states.shape)
238
+ # # encoder_hidden_states = encoder_hidden_states.squeeze(0)
239
+ if encoder_hidden_states is not None and encoder_hidden_states.dim() < 3:
240
+ encoder_hidden_states = encoder_hidden_states.unsqueeze(0)
241
+ batch_size, sequence_length, _ = (
242
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
243
+ )
244
+
245
+ if attention_mask is not None:
246
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
247
+ # scaled_dot_product_attention expects attention_mask shape to be
248
+ # (batch, heads, source_length, target_length)
249
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
250
+
251
+ if attn.group_norm is not None:
252
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
253
+
254
+ query = attn.to_q(hidden_states)
255
+
256
+ if encoder_hidden_states is None:
257
+ encoder_hidden_states = hidden_states
258
+ elif attn.norm_cross:
259
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
260
+
261
+ key = attn.to_k(encoder_hidden_states)
262
+ value = attn.to_v(encoder_hidden_states)
263
+ # print("encoder_hidden_states_attn",encoder_hidden_states.shape)
264
+ inner_dim = key.shape[-1]
265
+ head_dim = inner_dim // attn.heads
266
+
267
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
268
+
269
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
270
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
271
+
272
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
273
+ # TODO: add support for attn.scale when we move to Torch 2.1
274
+ hidden_states = F.scaled_dot_product_attention(
275
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
276
+ )
277
+
278
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
279
+ hidden_states = hidden_states.to(query.dtype)
280
+
281
+ # linear proj
282
+ hidden_states = attn.to_out[0](hidden_states)
283
+ # dropout
284
+ hidden_states = attn.to_out[1](hidden_states)
285
+
286
+ if input_ndim == 4:
287
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
288
+
289
+ if attn.residual_connection:
290
+ hidden_states = hidden_states + residual
291
+
292
+ hidden_states = hidden_states / attn.rescale_output_factor
293
+
294
+ return hidden_states
295
+
296
+
297
+ class IPAttnProcessor2_0(torch.nn.Module):
298
+ r"""
299
+ Attention processor for IP-Adapter (PyTorch 2.0).
300
+
301
+ Args:
302
+ hidden_size (`int`):
303
+ The hidden size of the attention layer.
304
+ cross_attention_dim (`int`):
305
+ The number of channels in the `encoder_hidden_states`.
306
+ num_tokens (`int`, defaults to 4):
307
+ The context length of the image features.
308
+ scale (`float`, defaults to 1.0):
309
+ The weight scale of the image prompt.
310
+ """
311
+
312
+ def __init__(self, hidden_size, name, cross_attention_dim=None, num_tokens=4, scale=1.0, do_copy=False):
313
+ super().__init__()
314
+
315
+ if not hasattr(F, "scaled_dot_product_attention"):
316
+ raise ImportError(
317
+ f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
318
+ )
319
+
320
+ self.hidden_size = hidden_size
321
+ self.cross_attention_dim = cross_attention_dim
322
+ self.num_tokens = num_tokens
323
+ self.scale = scale
324
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
325
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
326
+ self.name = name
327
+ # Optionally initialize to_k_ip / to_v_ip from cross-attention weights that were copied to disk beforehand.
328
+ if do_copy:
329
+ print("do copy")
330
+ current_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
331
+ # Go up one level to the parent directory
332
+ parent_dir = os.path.dirname(current_dir)
333
+ # Construct the path to the weights
334
+ k_weight_path = os.path.join(parent_dir, 'copied_cross_attention', f'{self.name}_k.bin')
335
+ v_weight_path = os.path.join(parent_dir, 'copied_cross_attention', f'{self.name}_v.bin')
336
+ # Load the weights
337
+ k_weight = torch.load(k_weight_path)
338
+ v_weight = torch.load(v_weight_path)
339
+ k_weight = k_weight.to(torch.float32)
340
+ v_weight = v_weight.to(torch.float32)
341
+ self.to_k_ip.weight = nn.Parameter(k_weight)
342
+ self.to_v_ip.weight = nn.Parameter(v_weight)
343
+ self.to_k_ip.weight.requires_grad = True
344
+ self.to_v_ip.weight.requires_grad = True
345
+
346
+
347
+ def __call__(
348
+ self,
349
+ attn,
350
+ hidden_states,
351
+ encoder_hidden_states=None,
352
+ attention_mask=None,
353
+ temb=None,
354
+ scale=1.0,
355
+ ):
356
+ if scale != 1.0:
357
+ logger.warning("`scale` of IPAttnProcessor should be set by `set_ip_adapter_scale`.")
358
+ residual = hidden_states
359
+ # print("original encoder_hidden_states",encoder_hidden_states.shape)
360
+ if attn.spatial_norm is not None:
361
+ hidden_states = attn.spatial_norm(hidden_states, temb)
362
+
363
+ input_ndim = hidden_states.ndim
364
+
365
+ if input_ndim == 4:
366
+ batch_size, channel, height, width = hidden_states.shape
367
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
368
+ # print("hidden_states",hidden_states.shape)
369
+ # print("encoder_hidden_states",encoder_hidden_states.shape)
370
+ # encoder_hidden_states = encoder_hidden_states.squeeze(1)
371
+ if encoder_hidden_states is not None and encoder_hidden_states.dim() < 3:
372
+ encoder_hidden_states = encoder_hidden_states.unsqueeze(0)
373
+ # print("encoder_hidden_states",encoder_hidden_states.shape)
374
+ batch_size, sequence_length, _ = (
375
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
376
+ )
377
+
378
+ if attention_mask is not None:
379
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
380
+ # scaled_dot_product_attention expects attention_mask shape to be
381
+ # (batch, heads, source_length, target_length)
382
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
383
+
384
+ if attn.group_norm is not None:
385
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
386
+
387
+ query = attn.to_q(hidden_states)
388
+
389
+ if encoder_hidden_states is None:
390
+ encoder_hidden_states = hidden_states
391
+ elif attn.norm_cross:
392
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
393
+ # print("in norm cross")
394
+ # print("encoder_hidden_states",encoder_hidden_states.shape)
395
+
396
+ # split hidden states
397
+ # end_pos = encoder_hidden_states.shape[1]//2
398
+ # print("encoder_hidden_states.shape",encoder_hidden_states.shape)
399
+ # print("end_pos",end_pos)
400
+ encoder_hidden_states, ip_hidden_states = (
401
+ encoder_hidden_states[:, :self.num_tokens, :],
402
+ encoder_hidden_states[:, self.num_tokens:, :],
403
+ )
404
+ # print("encoder_hidden_states",encoder_hidden_states.shape)
405
+ # print("ip_hidden_states",ip_hidden_states.shape)
406
+ key = attn.to_k(encoder_hidden_states)
407
+ value = attn.to_v(encoder_hidden_states)
408
+
409
+ inner_dim = key.shape[-1]
410
+ head_dim = inner_dim // attn.heads
411
+
412
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
413
+
414
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
415
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
416
+
417
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
418
+ # TODO: add support for attn.scale when we move to Torch 2.1
419
+ # print("query shape",query.shape)
420
+ # print("key shape",key.shape)
421
+ # print("value shape",value.shape)
422
+ # print("attention_mask",attention_mask)
423
+
424
+ if attention_mask is not None:
425
+ target = attention_mask.shape
426
+ # print("target",target)
427
+ # print("attention_mask.shape",attention_mask.shape)
428
+ attention_mask = attention_mask.split(target[2], dim=3)[0]
429
+ hidden_states = F.scaled_dot_product_attention(
430
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
431
+ )
432
+ # print("hidden_states",hidden_states.shape)
433
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
434
+ hidden_states = hidden_states.to(query.dtype)
435
+ ip_key = self.to_k_ip(ip_hidden_states)
436
+ ip_value = self.to_v_ip(ip_hidden_states)
437
+
438
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
439
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
440
+
441
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
442
+ # TODO: add support for attn.scale when we move to Torch 2.1
443
+ ip_hidden_states = F.scaled_dot_product_attention(
444
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
445
+ )
446
+ # print("query",query.shape)
447
+ # print("ip_key",ip_key.shape)
448
+ # print("ip_value",ip_value.shape)
449
+ # print("ip_hidden_states",ip_hidden_states.shape)
450
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
451
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
452
+ # print("hidden_states",hidden_states)
453
+ # print("ip_hidden_states",ip_hidden_states)
454
+ hidden_states = hidden_states + self.scale * ip_hidden_states
455
+ # print("ip_hidden_states",ip_hidden_states.shape)
456
+ # linear proj
457
+ hidden_states = attn.to_out[0](hidden_states)
458
+ # dropout
459
+ hidden_states = attn.to_out[1](hidden_states)
460
+
461
+ if input_ndim == 4:
462
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
463
+
464
+ if attn.residual_connection:
465
+ print("residual_connection")
466
+ hidden_states = hidden_states + residual
467
+ # print(residual)
468
+ hidden_states = hidden_states / attn.rescale_output_factor
469
+
470
+ return hidden_states
471
+
472
+ ## for controlnet
473
+ class CNAttnProcessor:
474
+ r"""
475
+ Default processor for performing attention-related computations.
476
+ """
477
+
478
+ def __init__(self, num_tokens=4):
479
+ self.num_tokens = num_tokens
480
+
481
+ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
482
+ residual = hidden_states
483
+
484
+ if attn.spatial_norm is not None:
485
+ hidden_states = attn.spatial_norm(hidden_states, temb)
486
+
487
+ input_ndim = hidden_states.ndim
488
+
489
+ if input_ndim == 4:
490
+ batch_size, channel, height, width = hidden_states.shape
491
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
492
+
493
+ batch_size, sequence_length, _ = (
494
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
495
+ )
496
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
497
+
498
+ if attn.group_norm is not None:
499
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
500
+
501
+ query = attn.to_q(hidden_states)
502
+
503
+ if encoder_hidden_states is None:
504
+ encoder_hidden_states = hidden_states
505
+ else:
506
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
507
+ encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
508
+ if attn.norm_cross:
509
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
510
+
511
+ key = attn.to_k(encoder_hidden_states)
512
+ value = attn.to_v(encoder_hidden_states)
513
+
514
+ query = attn.head_to_batch_dim(query)
515
+ key = attn.head_to_batch_dim(key)
516
+ value = attn.head_to_batch_dim(value)
517
+
518
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
519
+ hidden_states = torch.bmm(attention_probs, value)
520
+ hidden_states = attn.batch_to_head_dim(hidden_states)
521
+
522
+ # linear proj
523
+ hidden_states = attn.to_out[0](hidden_states)
524
+ # dropout
525
+ hidden_states = attn.to_out[1](hidden_states)
526
+
527
+ if input_ndim == 4:
528
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
529
+
530
+ if attn.residual_connection:
531
+ hidden_states = hidden_states + residual
532
+
533
+ hidden_states = hidden_states / attn.rescale_output_factor
534
+
535
+ return hidden_states
536
+
537
+
538
+ class CNAttnProcessor2_0:
539
+ r"""
540
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
541
+ """
542
+
543
+ def __init__(self, num_tokens=4):
544
+ if not hasattr(F, "scaled_dot_product_attention"):
545
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
546
+ self.num_tokens = num_tokens
547
+
548
+ def __call__(
549
+ self,
550
+ attn,
551
+ hidden_states,
552
+ encoder_hidden_states=None,
553
+ attention_mask=None,
554
+ temb=None,
555
+ ):
556
+ residual = hidden_states
557
+
558
+ if attn.spatial_norm is not None:
559
+ hidden_states = attn.spatial_norm(hidden_states, temb)
560
+
561
+ input_ndim = hidden_states.ndim
562
+
563
+ if input_ndim == 4:
564
+ batch_size, channel, height, width = hidden_states.shape
565
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
566
+
567
+ batch_size, sequence_length, _ = (
568
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
569
+ )
570
+
571
+ if attention_mask is not None:
572
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
573
+ # scaled_dot_product_attention expects attention_mask shape to be
574
+ # (batch, heads, source_length, target_length)
575
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
576
+
577
+ if attn.group_norm is not None:
578
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
579
+
580
+ query = attn.to_q(hidden_states)
581
+
582
+ if encoder_hidden_states is None:
583
+ encoder_hidden_states = hidden_states
584
+ else:
585
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
586
+ encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
587
+ if attn.norm_cross:
588
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
589
+
590
+ key = attn.to_k(encoder_hidden_states)
591
+ value = attn.to_v(encoder_hidden_states)
592
+
593
+ inner_dim = key.shape[-1]
594
+ head_dim = inner_dim // attn.heads
595
+
596
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
597
+
598
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
599
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
600
+
601
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
602
+ # TODO: add support for attn.scale when we move to Torch 2.1
603
+ hidden_states = F.scaled_dot_product_attention(
604
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
605
+ )
606
+
607
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
608
+ hidden_states = hidden_states.to(query.dtype)
609
+
610
+ # linear proj
611
+ hidden_states = attn.to_out[0](hidden_states)
612
+ # dropout
613
+ hidden_states = attn.to_out[1](hidden_states)
614
+
615
+ if input_ndim == 4:
616
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
617
+
618
+ if attn.residual_connection:
619
+ hidden_states = hidden_states + residual
620
+
621
+ hidden_states = hidden_states / attn.rescale_output_factor
622
+
623
+ return hidden_states
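
These classes follow the standard diffusers attention-processor interface, so they are installed by building a name-to-processor mapping and calling `set_attn_processor` on the UNet. The helper below is a minimal sketch of that wiring rather than code from this commit: `install_ap_adapter_processors`, the block-name parsing, and the `num_tokens`/`scale` defaults are illustrative assumptions for a `UNet2DConditionModel`-style UNet with a single `cross_attention_dim`.

from APadapter.ap_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0


def install_ap_adapter_processors(unet, num_tokens=8, scale=0.5):
    """Sketch: build a name -> processor map for every attention layer and install it."""
    attn_procs = {}
    for name in unet.attn_processors.keys():
        # attn1 layers are self-attention; only attn2 (cross-attention) gets the IP branch.
        is_self_attn = name.endswith("attn1.processor")
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        else:  # down_blocks
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if is_self_attn:
            attn_procs[name] = AttnProcessor2_0()
        else:
            attn_procs[name] = IPAttnProcessor2_0(
                hidden_size=hidden_size,
                name=name,
                cross_attention_dim=unet.config.cross_attention_dim,
                num_tokens=num_tokens,   # illustrative value
                scale=scale,             # illustrative value
            )
    unet.set_attn_processor(attn_procs)
    return attn_procs
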
app.py ADDED
@@ -0,0 +1,72 @@
1
+ import os
2
+ import gradio as gr
3
+ import torchaudio
4
+ import torch
5
+ from pipeline.morph_pipeline_successed_ver1 import AudioLDM2MorphPipeline
6
+
7
+ pipeline = AudioLDM2MorphPipeline.from_pretrained("cvssp/audioldm2-large", torch_dtype=torch.float32)
8
+ pipeline.to("cuda")
9
+
10
+ def morph_audio(audio_file1, audio_file2, prompt1, prompt2, negative_prompt1="Low quality", negative_prompt2="Low quality"):
11
+ save_lora_dir = "output"
12
+ os.makedirs(save_lora_dir, exist_ok=True)
13
+
14
+ waveform, sample_rate = torchaudio.load(audio_file1)
15
+ duration = waveform.shape[1] / sample_rate
16
+ duration = int(duration)
17
+
18
+ _ = pipeline(
19
+ audio_file=audio_file1,
20
+ audio_file2=audio_file2,
21
+ audio_length_in_s=duration,
22
+ time_pooling=2,
23
+ freq_pooling=2,
24
+ prompt_1=prompt1,
25
+ prompt_2=prompt2,
26
+ negative_prompt_1=negative_prompt1,
27
+ negative_prompt_2=negative_prompt2,
28
+ save_lora_dir=save_lora_dir,
29
+ use_adain=True,
30
+ use_reschedule=True,
31
+ num_inference_steps=50,
32
+ lamd=0.6,
33
+ output_path=save_lora_dir,
34
+ num_frames=5,
35
+ fix_lora=None,
36
+ use_lora=True,
37
+ lora_steps=50,
38
+ noisy_latent_with_lora=True,
39
+ morphing_with_lora=True,
40
+ use_morph_prompt=True,
41
+ guidance_scale=7.5,
42
+ )
43
+
44
+ output_paths = [os.path.join(save_lora_dir, file) for file in os.listdir(save_lora_dir) if file.endswith(".wav")]
45
+ return output_paths
46
+
47
+ def interface(audio1, audio2, prompt1, prompt2):
48
+ output_paths = morph_audio(audio1, audio2, prompt1, prompt2)
49
+ return output_paths
50
+
51
+ # Gradio UI
52
+ with gr.Blocks() as demo:
53
+ gr.Markdown("### Audio Morphing Demo with AudioLDM2")
54
+
55
+ with gr.Row():
56
+ audio_file1 = gr.Audio(label="Upload Audio File 1", type="filepath")
57
+ audio_file2 = gr.Audio(label="Upload Audio File 2", type="filepath")
58
+
59
+ with gr.Row():
60
+ prompt1 = gr.Textbox(label="Prompt for Audio File 1")
61
+ prompt2 = gr.Textbox(label="Prompt for Audio File 2")
62
+
63
+ output_audios = gr.Audio(label="Generated Morphing Audios", type="filepath", interactive=False)
64
+ morph_button = gr.Button("Generate Morphing Audio")
65
+
66
+ morph_button.click(
67
+ interface,
68
+ inputs=[audio_file1, audio_file2, prompt1, prompt2],
69
+ outputs=[output_audios]
70
+ )
71
+
72
+ demo.launch()
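
The click handler above binds morph_audio's list of output paths to a single gr.Audio player, so only one file can be played at a time. If every intermediate morph frame should be exposed, one option is a gr.Files output; the sketch below assumes that variant and reuses the morph_audio helper defined above, it is not part of this commit.

# Sketch: expose every generated .wav as a downloadable file list (assumes morph_audio above).
with gr.Blocks() as demo_all_frames:
    gr.Markdown("### Audio Morphing Demo with AudioLDM2 (all frames)")
    with gr.Row():
        in1 = gr.Audio(label="Upload Audio File 1", type="filepath")
        in2 = gr.Audio(label="Upload Audio File 2", type="filepath")
    with gr.Row():
        p1 = gr.Textbox(label="Prompt for Audio File 1")
        p2 = gr.Textbox(label="Prompt for Audio File 2")
    all_frames = gr.Files(label="All morphing frames")
    gr.Button("Generate Morphing Audio").click(
        morph_audio, inputs=[in1, in2, p1, p2], outputs=[all_frames]
    )
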
audio_encoder/AudioMAE.py ADDED
@@ -0,0 +1,424 @@
1
+ """
2
+ Reference Repo: https://github.com/facebookresearch/AudioMAE
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from timm.models.layers import to_2tuple
8
+ from . import models_vit
9
+ from . import models_mae
10
+ import librosa
11
+ import librosa.display
12
+ import matplotlib.pyplot as plt
13
+ import numpy as np
14
+ import torchaudio
15
+
16
+ # model = mae_vit_base_patch16(in_chans=1, audio_exp=True, img_size=(1024, 128))
17
+ class Vanilla_AudioMAE(nn.Module):
18
+ """Audio Masked Autoencoder (MAE) pre-trained on AudioSet (for AudioLDM2)"""
19
+
20
+ def __init__(
21
+ self,
22
+ ):
23
+ super().__init__()
24
+ model = models_mae.__dict__["mae_vit_base_patch16"](
25
+ in_chans=1, audio_exp=True, img_size=(1024, 128)
26
+ )
27
+
28
+ checkpoint_path = '/Data/home/Dennis/DeepMIR-2024/Final_Project/AP-adapter/pretrained.pth'
29
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
30
+ msg = model.load_state_dict(checkpoint['model'], strict=False)
31
+
32
+ # Skip the missing keys of decoder modules (not required)
33
+ # print(f'Load AudioMAE from {checkpoint_path} / message: {msg}')
34
+ self.model = model.eval()
35
+ self.model = model.train()  # note: overrides the eval() call above; gradients stay disabled via no_grad in forward
36
+
37
+ def forward(self, x, mask_ratio=0.0, no_mask=False, no_average=False):
38
+ """
39
+ x: mel fbank [Batch, 1, 1024 (T), 128 (F)]
40
+ mask_ratio: 'masking ratio (percentage of removed patches).'
41
+ """
42
+
43
+ with torch.no_grad():
44
+ # embed: [B, 513, 768] for mask_ratio=0.0
45
+ if no_mask:
46
+ if no_average:
47
+ # raise RuntimeError("This function is deprecated")
48
+ embed = self.model.forward_encoder_no_random_mask_no_average(
49
+ x
50
+ ) # mask_ratio
51
+ else:
52
+ embed = self.model.forward_encoder_no_mask(x) # mask_ratio
53
+ else:
54
+ raise RuntimeError("This function is deprecated")
55
+ embed, _, _, _ = self.model.forward_encoder(x, mask_ratio=mask_ratio)
56
+ return embed
57
+ import torchaudio
58
+ import numpy as np
59
+ import torch
60
+
61
+ # def roll_mag_aug(waveform):
62
+ # idx = np.random.randint(len(waveform))
63
+ # rolled_waveform = np.roll(waveform, idx)
64
+ # mag = np.random.beta(10, 10) + 0.5
65
+ # return torch.Tensor(rolled_waveform * mag)
66
+
67
+ def wav_to_fbank(filename, melbins, target_length, roll_mag_aug_flag=False):
68
+ waveform, sr = torchaudio.load(filename)
69
+ waveform = waveform - waveform.mean()
70
+ fbank = torchaudio.compliance.kaldi.fbank(
71
+ waveform,
72
+ htk_compat=True,
73
+ sample_frequency=sr,
74
+ use_energy=False,
75
+ window_type='hanning',
76
+ num_mel_bins=melbins,
77
+ dither=0.0,
78
+ frame_shift=10
79
+ )
80
+
81
+ n_frames = fbank.shape[0]
82
+ p = target_length - n_frames
83
+
84
+ # Cut and pad
85
+ if p > 0:
86
+ m = torch.nn.ZeroPad2d((0, 0, 0, p))
87
+ fbank = m(fbank)
88
+ elif p < 0:
89
+ fbank = fbank[0:target_length, :]
90
+
91
+ return fbank
92
+
93
+ # Example usage
94
+ import torch.nn.functional as F
95
+ class AudioMAEConditionCTPoolRand(nn.Module):
96
+ """
97
+ audiomae = AudioMAEConditionCTPool2x2()
98
+ data = torch.randn((4, 1024, 128))
99
+ output = audiomae(data)
100
+ import ipdb;ipdb.set_trace()
101
+ exit(0)
102
+ """
103
+
104
+ def __init__(
105
+ self,
106
+ time_pooling_factors=[1, 2, 4, 8],
107
+ freq_pooling_factors=[1, 2, 4, 8],
108
+ eval_time_pooling=8,
109
+ eval_freq_pooling=8,
110
+ mask_ratio=0.0,
111
+ regularization=False,
112
+ no_audiomae_mask=True,
113
+ no_audiomae_average=True,
114
+ ):
115
+ super().__init__()
116
+ self.device = None
117
+ self.time_pooling_factors = time_pooling_factors
118
+ self.freq_pooling_factors = freq_pooling_factors
119
+ self.no_audiomae_mask = no_audiomae_mask
120
+ self.no_audiomae_average = no_audiomae_average
121
+
122
+ self.eval_freq_pooling = eval_freq_pooling
123
+ self.eval_time_pooling = eval_time_pooling
124
+ self.mask_ratio = mask_ratio
125
+ self.use_reg = regularization
126
+
127
+ self.audiomae = Vanilla_AudioMAE()
128
+ self.audiomae.eval()
129
+ for p in self.audiomae.parameters():
130
+ p.requires_grad = False
131
+
132
+ # Required
133
+ def get_unconditional_condition(self, batchsize):
134
+ param = next(self.audiomae.parameters())
135
+ assert param.requires_grad == False
136
+ device = param.device
137
+ # time_pool, freq_pool = max(self.time_pooling_factors), max(self.freq_pooling_factors)
138
+ time_pool, freq_pool = min(self.eval_time_pooling, 64), min(
139
+ self.eval_freq_pooling, 8
140
+ )
141
+ # time_pool = self.time_pooling_factors[np.random.choice(list(range(len(self.time_pooling_factors))))]
142
+ # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))]
143
+ token_num = int(512 / (time_pool * freq_pool))
144
+ return [
145
+ torch.zeros((batchsize, token_num, 768)).to(device).float(),
146
+ torch.ones((batchsize, token_num)).to(device).float(),
147
+ ]
148
+
149
+ def pool(self, representation, time_pool=None, freq_pool=None):
150
+ assert representation.size(-1) == 768
151
+ representation = representation[:, 1:, :].transpose(1, 2)
152
+ # print("representation.shape",representation.shape)
153
+ bs, embedding_dim, token_num = representation.size()
154
+ representation = representation.reshape(bs, embedding_dim, 64, 8)
155
+
156
+ # if self.training:
157
+ # if time_pool is None and freq_pool is None:
158
+ # time_pool = min(
159
+ # 64,
160
+ # self.time_pooling_factors[
161
+ # np.random.choice(list(range(len(self.time_pooling_factors))))
162
+ # ],
163
+ # )
164
+ # # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))]
165
+ # freq_pool = min(8, time_pool) # TODO here I make some modification.
166
+ # else:
167
+ # time_pool, freq_pool = min(self.eval_time_pooling, 64), min(
168
+ # self.eval_freq_pooling, 8
169
+ # )
170
+
171
+ self.avgpooling = nn.AvgPool2d(
172
+ kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool)
173
+ )
174
+ self.maxpooling = nn.MaxPool2d(
175
+ kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool)
176
+ )
177
+
178
+ pooled = (
179
+ self.avgpooling(representation) + self.maxpooling(representation)
180
+ ) / 2 # [bs, embedding_dim, time_token_num, freq_token_num]
181
+ # print("pooled.shape",pooled.shape)
182
+ pooled = pooled.flatten(2).transpose(1, 2)
183
+ return pooled # [bs, token_num, embedding_dim]
184
+
185
+ def regularization(self, x):
186
+ assert x.size(-1) == 768
187
+ x = F.normalize(x, p=2, dim=-1)
188
+ return x
189
+
190
+ # Required
191
+ def forward(self, batch, time_pool=None, freq_pool=None):
192
+ assert batch.size(-2) == 1024 and batch.size(-1) == 128
193
+
194
+ if self.device is None:
195
+ self.device = next(self.audiomae.parameters()).device
196
+
197
+ batch = batch.unsqueeze(1).to(self.device)
198
+ with torch.no_grad():
199
+ representation = self.audiomae(
200
+ batch,
201
+ mask_ratio=self.mask_ratio,
202
+ no_mask=self.no_audiomae_mask,
203
+ no_average=self.no_audiomae_average,
204
+ )
205
+ representation = self.pool(representation, time_pool, freq_pool)
206
+ if self.use_reg:
207
+ representation = self.regularization(representation)
208
+ return [
209
+ representation,
210
+ torch.ones((representation.size(0), representation.size(1)))
211
+ .to(representation.device)
212
+ # .float(),
213
+ ]
214
+
215
+
216
+ class AudioMAEConditionCTPoolRandTFSeparated(nn.Module):
217
+ """
218
+ audiomae = AudioMAEConditionCTPool2x2()
219
+ data = torch.randn((4, 1024, 128))
220
+ output = audiomae(data)
221
+ import ipdb;ipdb.set_trace()
222
+ exit(0)
223
+ """
224
+
225
+ def __init__(
226
+ self,
227
+ time_pooling_factors=[8],
228
+ freq_pooling_factors=[8],
229
+ eval_time_pooling=8,
230
+ eval_freq_pooling=8,
231
+ mask_ratio=0.0,
232
+ regularization=False,
233
+ no_audiomae_mask=True,
234
+ no_audiomae_average=False,
235
+ ):
236
+ super().__init__()
237
+ self.device = None
238
+ self.time_pooling_factors = time_pooling_factors
239
+ self.freq_pooling_factors = freq_pooling_factors
240
+ self.no_audiomae_mask = no_audiomae_mask
241
+ self.no_audiomae_average = no_audiomae_average
242
+
243
+ self.eval_freq_pooling = eval_freq_pooling
244
+ self.eval_time_pooling = eval_time_pooling
245
+ self.mask_ratio = mask_ratio
246
+ self.use_reg = regularization
247
+
248
+ self.audiomae = Vanilla_AudioMAE()
249
+ self.audiomae.eval()
250
+ for p in self.audiomae.parameters():
251
+ p.requires_grad = False
252
+
253
+ # Required
254
+ def get_unconditional_condition(self, batchsize):
255
+ param = next(self.audiomae.parameters())
256
+ assert param.requires_grad == False
257
+ device = param.device
258
+ # time_pool, freq_pool = max(self.time_pooling_factors), max(self.freq_pooling_factors)
259
+ time_pool, freq_pool = min(self.eval_time_pooling, 64), min(
260
+ self.eval_freq_pooling, 8
261
+ )
262
+ # time_pool = self.time_pooling_factors[np.random.choice(list(range(len(self.time_pooling_factors))))]
263
+ # freq_pool = self.freq_pooling_factors[np.random.choice(list(range(len(self.freq_pooling_factors))))]
264
+ token_num = int(512 / (time_pool * freq_pool))
265
+ return [
266
+ torch.zeros((batchsize, token_num, 768)).to(device).float(),
267
+ torch.ones((batchsize, token_num)).to(device).float(),
268
+ ]
269
+
270
+ def pool(self, representation, time_pool=None, freq_pool=None):
271
+ assert representation.size(-1) == 768
272
+ representation = representation[:, 1:, :].transpose(1, 2)
273
+ bs, embedding_dim, token_num = representation.size()
274
+ representation = representation.reshape(bs, embedding_dim, 64, 8)
275
+
276
+ # if self.training:
277
+ # if time_pool is None and freq_pool is None:
278
+ # time_pool = min(
279
+ # 64,
280
+ # self.time_pooling_factors[
281
+ # np.random.choice(list(range(len(self.time_pooling_factors))))
282
+ # ],
283
+ # )
284
+ # freq_pool = min(
285
+ # 8,
286
+ # self.freq_pooling_factors[
287
+ # np.random.choice(list(range(len(self.freq_pooling_factors))))
288
+ # ],
289
+ # )
290
+ # # freq_pool = min(8, time_pool) # TODO here I make some modification.
291
+ # else:
292
+ # time_pool, freq_pool = min(self.eval_time_pooling, 64), min(
293
+ # self.eval_freq_pooling, 8
294
+ # )
295
+
296
+ self.avgpooling = nn.AvgPool2d(
297
+ kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool)
298
+ )
299
+ self.maxpooling = nn.MaxPool2d(
300
+ kernel_size=(time_pool, freq_pool), stride=(time_pool, freq_pool)
301
+ )
302
+
303
+ pooled = (
304
+ self.avgpooling(representation) + self.maxpooling(representation)
305
+ ) / 2 # [bs, embedding_dim, time_token_num, freq_token_num]
306
+ pooled = pooled.flatten(2).transpose(1, 2)
307
+ return pooled # [bs, token_num, embedding_dim]
308
+
309
+ def regularization(self, x):
310
+ assert x.size(-1) == 768
311
+ x = F.normalize(x, p=2, dim=-1)
312
+ return x
313
+
314
+ # Required
315
+ def forward(self, batch, time_pool=None, freq_pool=None):
316
+ assert batch.size(-2) == 1024 and batch.size(-1) == 128
317
+
318
+ if self.device is None:
319
+ self.device = batch.device
320
+
321
+ batch = batch.unsqueeze(1)
322
+ with torch.no_grad():
323
+ representation = self.audiomae(
324
+ batch,
325
+ mask_ratio=self.mask_ratio,
326
+ no_mask=self.no_audiomae_mask,
327
+ no_average=self.no_audiomae_average,
328
+ )
329
+ representation = self.pool(representation, time_pool, freq_pool)
330
+ if self.use_reg:
331
+ representation = self.regularization(representation)
332
+ return [
333
+ representation,
334
+ torch.ones((representation.size(0), representation.size(1)))
335
+ .to(representation.device)
336
+ .float(),
337
+ ]
338
+ def apply_time_mask(spectrogram, mask_width_range=(1000, 1001), max_masks=2):
339
+ """
340
+ Apply time masking to a spectrogram (PyTorch tensor).
341
+
342
+ :param spectrogram: A PyTorch tensor of shape (time_steps, frequency_bands)
343
+ :param mask_width_range: A tuple indicating the min and max width of the mask
344
+ :param max_masks: Maximum number of masks to apply
345
+ :return: Masked spectrogram
346
+ """
347
+ time_steps, frequency_bands = spectrogram.shape
348
+ masked_spectrogram = spectrogram.clone()
349
+
350
+ for _ in range(max_masks):
351
+ mask_width = torch.randint(mask_width_range[0], mask_width_range[1], (1,)).item()
352
+ start_step = torch.randint(0, time_steps - mask_width, (1,)).item()
353
+ masked_spectrogram[start_step:start_step + mask_width, :] = 0  # zero out the sampled time span
354
+
355
+ return masked_spectrogram
356
+
357
+ def extract_kaldi_fbank_feature(waveform, sampling_rate, log_mel_spec= torch.zeros((1024, 128)), num_mels=128):
358
+ norm_mean = -4.2677393
359
+ norm_std = 4.5689974
360
+ if sampling_rate != 16000:
361
+ waveform_16k = torchaudio.functional.resample(
362
+ waveform, orig_freq=sampling_rate, new_freq=16000
363
+ )
364
+ else:
365
+ waveform_16k = waveform
366
+ waveform_16k = waveform_16k - waveform_16k.mean()
367
+ fbank = torchaudio.compliance.kaldi.fbank(
368
+ waveform_16k,
369
+ htk_compat=True,
370
+ sample_frequency=16000,
371
+ use_energy=False,
372
+ window_type="hanning",
373
+ num_mel_bins=num_mels,
374
+ dither=0.0,
375
+ frame_shift=10,
376
+ )
377
+ TARGET_LEN = log_mel_spec.size(0)
378
+ # cut and pad
379
+ n_frames = fbank.shape[0]
380
+ p = TARGET_LEN - n_frames
381
+ # print(TARGET_LEN)
382
+ # print(n_frames)
383
+ if p > 0:
384
+ m = torch.nn.ZeroPad2d((0, 0, 0, p))
385
+ fbank = m(fbank)
386
+ elif p < 0:
387
+ fbank = fbank[:TARGET_LEN, :]
388
+ fbank = (fbank - norm_mean) / (norm_std * 2)
389
+ # fbank = apply_time_mask(fbank)
390
+ return fbank
391
+
392
+ if __name__ == "__main__":
393
+
394
+ filename = '/home/fundwotsai/DreamSound/training_audio_v2/output_slice_18.wav'
395
+ waveform, sr = torchaudio.load(filename)
396
+ fbank = torch.zeros(
397
+ (1024, 128)
398
+ )
399
+ ta_kaldi_fbank = extract_kaldi_fbank_feature(waveform, 16000,fbank)
400
+ print(ta_kaldi_fbank.shape)
401
+ # melbins = 128 # Number of Mel bins
402
+ # target_length = 1024 # Number of frames
403
+ # fbank = wav_to_fbank(file_path, melbins, target_length, roll_mag_aug_flag=False)
404
+ # print(fbank.shape)
405
+ # # Convert to PyTorch tensor and reshape
406
+ mel_spect_tensor = ta_kaldi_fbank.unsqueeze(0)  # [Batch, Time, Frequency]; the model adds the channel dim
407
+
408
+ mel_spect_tensor = mel_spect_tensor.to("cuda")
409
+ # Save the figure
410
+ print("mel_spect_tensor111.shape",mel_spect_tensor.shape)
411
+ model = AudioMAEConditionCTPoolRand().cuda()
412
+ print("The first run")
413
+ embed = model(mel_spect_tensor, time_pool=1, freq_pool=1)
414
+ print(embed[0].shape)
415
+
416
+ # Reshape tensor for 2D pooling: treat each 768 as a channel
417
+ # Example usage
418
+ # Assuming the pooling operation reduces the second dimension from 513 to 8
419
+
420
+
421
+ torch.save(embed[0], "MAE_feature1_stride-no-pool.pt")
422
+ with open('output_tensor.txt', 'w') as f:
423
+ print(embed[0], file=f)
424
+
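
The pooled AudioMAE tokens produced above are what the IP branch in APadapter/ap_adapter/attention_processor.py consumes: they are concatenated after the text tokens, and IPAttnProcessor2_0 later splits the sequence again at num_tokens. The snippet below is a shape-level sketch of that hand-off; the batch size, token counts, concatenation order, and 768-dim feature size are illustrative assumptions.

import torch

batch, num_text_tokens, dim = 2, 8, 768

# stand-ins for the real conditioning tensors
text_tokens = torch.randn(batch, num_text_tokens, dim)   # e.g. projected text-encoder output
audio_tokens = torch.randn(batch, 64, dim)                # e.g. AudioMAEConditionCTPoolRand(...)[0]

# IPAttnProcessor2_0(num_tokens=num_text_tokens) routes the first num_tokens entries
# through the frozen to_k/to_v projections and the remainder through to_k_ip/to_v_ip.
encoder_hidden_states = torch.cat([text_tokens, audio_tokens], dim=1)
print(encoder_hidden_states.shape)  # torch.Size([2, 72, 768])
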
audio_encoder/models_mae.py ADDED
@@ -0,0 +1,738 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
+ # DeiT: https://github.com/facebookresearch/deit
10
+ # --------------------------------------------------------
11
+
12
+ from functools import partial
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import numpy as np
17
+ from timm.models.layers import to_2tuple
18
+ from timm.models.vision_transformer import Block
19
+ class PatchEmbed_org(nn.Module):
20
+ """Image to Patch Embedding"""
21
+
22
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
23
+ super().__init__()
24
+ img_size = to_2tuple(img_size)
25
+ patch_size = to_2tuple(patch_size)
26
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
27
+ self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0])
28
+ self.img_size = img_size
29
+ self.patch_size = patch_size
30
+ self.num_patches = num_patches
31
+
32
+ self.proj = nn.Conv2d(
33
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
34
+ )
35
+
36
+ def forward(self, x):
37
+ # print(x.shape)
38
+ B, C, H, W = x.shape
39
+ # FIXME look at relaxing size constraints
40
+ # assert H == self.img_size[0] and W == self.img_size[1], \
41
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
42
+ x = self.proj(x)
43
+ y = x.flatten(2).transpose(1, 2)
44
+ return y
45
+
46
+
47
+ class PatchEmbed_new(nn.Module):
48
+ """Flexible Image to Patch Embedding"""
49
+
50
+ def __init__(
51
+ self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10
52
+ ):
53
+ super().__init__()
54
+ img_size = to_2tuple(img_size)
55
+ patch_size = to_2tuple(patch_size)
56
+ stride = to_2tuple(stride)
57
+
58
+ self.img_size = img_size
59
+ self.patch_size = patch_size
60
+
61
+ self.proj = nn.Conv2d(
62
+ in_chans, embed_dim, kernel_size=patch_size, stride=stride
63
+ ) # with overlapped patches
64
+ # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
65
+
66
+ # self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0])
67
+ # self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
68
+ _, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w
69
+ self.patch_hw = (h, w)
70
+ self.num_patches = h * w
71
+
72
+ def get_output_shape(self, img_size):
73
+ # todo: don't be lazy..
74
+ return self.proj(torch.randn(1, 1, img_size[0], img_size[1])).shape
75
+
76
+ def forward(self, x):
77
+ B, C, H, W = x.shape
78
+ # FIXME look at relaxing size constraints
79
+ # assert H == self.img_size[0] and W == self.img_size[1], \
80
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
81
+ # x = self.proj(x).flatten(2).transpose(1, 2)
82
+ x = self.proj(x) # 32, 1, 1024, 128 -> 32, 768, 101, 12
83
+ x = x.flatten(2) # 32, 768, 101, 12 -> 32, 768, 1212
84
+ x = x.transpose(1, 2) # 32, 768, 1212 -> 32, 1212, 768
85
+ return x
86
+
87
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
88
+ """
89
+ grid_size: int of the grid height and width
90
+ return:
91
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
92
+ """
93
+ grid_h = np.arange(grid_size, dtype=np.float32)
94
+ grid_w = np.arange(grid_size, dtype=np.float32)
95
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
96
+ grid = np.stack(grid, axis=0)
97
+
98
+ grid = grid.reshape([2, 1, grid_size, grid_size])
99
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
100
+ if cls_token:
101
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
102
+ return pos_embed
103
+
104
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
105
+ """
106
+ embed_dim: output dimension for each position
107
+ pos: a list of positions to be encoded: size (M,)
108
+ out: (M, D)
109
+ """
110
+ assert embed_dim % 2 == 0
111
+ # omega = np.arange(embed_dim // 2, dtype=np.float)
112
+ omega = np.arange(embed_dim // 2, dtype=float)
113
+ omega /= embed_dim / 2.0
114
+ omega = 1.0 / 10000**omega # (D/2,)
115
+
116
+ pos = pos.reshape(-1) # (M,)
117
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
118
+
119
+ emb_sin = np.sin(out) # (M, D/2)
120
+ emb_cos = np.cos(out) # (M, D/2)
121
+
122
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
123
+ return emb
124
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
125
+ assert embed_dim % 2 == 0
126
+ # print("embed_dim",embed_dim)
127
+ # print("[grid[0]",grid[0])
128
+ # print("[grid[1]",grid[1])
129
+ # use half of dimensions to encode grid_h
130
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
131
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
132
+ # print("emb_h",emb_h.shape)
133
+ # print("emb_w",emb_w.shape)
134
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
135
+ return emb
136
+ def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False):
137
+ """
138
+ grid_size: int of the grid height and width
139
+ return:
140
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
141
+ """
142
+ grid_h = np.arange(grid_size[0], dtype=np.float32)
143
+ grid_w = np.arange(grid_size[1], dtype=np.float32)
144
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
145
+ grid = np.stack(grid, axis=0)
146
+
147
+ grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
148
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
149
+ if cls_token:
150
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
151
+ return pos_embed
152
+
153
+ class MaskedAutoencoderViT(nn.Module):
154
+ """Masked Autoencoder with VisionTransformer backbone"""
155
+
156
+ def __init__(
157
+ self,
158
+ img_size=224,
159
+ patch_size=16,
160
+ stride=10,
161
+ in_chans=3,
162
+ embed_dim=1024,
163
+ depth=24,
164
+ num_heads=16,
165
+ decoder_embed_dim=512,
166
+ decoder_depth=8,
167
+ decoder_num_heads=16,
168
+ mlp_ratio=4.0,
169
+ norm_layer=nn.LayerNorm,
170
+ norm_pix_loss=False,
171
+ audio_exp=False,
172
+ alpha=0.0,
173
+ temperature=0.2,
174
+ mode=0,
175
+ contextual_depth=8,
176
+ use_custom_patch=False,
177
+ split_pos=False,
178
+ pos_trainable=False,
179
+ use_nce=False,
180
+ beta=4.0,
181
+ decoder_mode=0,
182
+ mask_t_prob=0.6,
183
+ mask_f_prob=0.5,
184
+ mask_2d=False,
185
+ epoch=0,
186
+ no_shift=False,
187
+ ):
188
+ super().__init__()
189
+
190
+ self.audio_exp = audio_exp
191
+ self.embed_dim = embed_dim
192
+ self.decoder_embed_dim = decoder_embed_dim
193
+ # --------------------------------------------------------------------------
194
+ # MAE encoder specifics
195
+ if use_custom_patch:
196
+ print(
197
+ f"Use custom patch_emb with patch size: {patch_size}, stride: {stride}"
198
+ )
199
+ self.patch_embed = PatchEmbed_new(
200
+ img_size=img_size,
201
+ patch_size=patch_size,
202
+ in_chans=in_chans,
203
+ embed_dim=embed_dim,
204
+ stride=stride,
205
+ )
206
+ else:
207
+ self.patch_embed = PatchEmbed_org(img_size, patch_size, in_chans, embed_dim)
208
+ self.use_custom_patch = use_custom_patch
209
+ num_patches = self.patch_embed.num_patches
210
+
211
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
212
+
213
+ # self.split_pos = split_pos # not useful
214
+ self.pos_embed = nn.Parameter(
215
+ torch.zeros(1, num_patches + 1, embed_dim), requires_grad=pos_trainable
216
+ ) # fixed sin-cos embedding
217
+
218
+ self.encoder_depth = depth
219
+ self.contextual_depth = contextual_depth
220
+ self.blocks = nn.ModuleList(
221
+ [
222
+ Block(
223
+ embed_dim,
224
+ num_heads,
225
+ mlp_ratio,
226
+ qkv_bias=True,
227
+ norm_layer=norm_layer,
228
+ ) # qk_scale=None
229
+ for i in range(depth)
230
+ ]
231
+ )
232
+ self.norm = norm_layer(embed_dim)
233
+
234
+ # --------------------------------------------------------------------------
235
+ # MAE decoder specifics
236
+ self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
237
+
238
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
239
+ self.decoder_pos_embed = nn.Parameter(
240
+ torch.zeros(1, num_patches + 1, decoder_embed_dim),
241
+ requires_grad=pos_trainable,
242
+ ) # fixed sin-cos embedding
243
+
244
+ self.no_shift = no_shift
245
+
246
+ self.decoder_mode = decoder_mode
247
+ if (
248
+ self.use_custom_patch
249
+ ): # overlapped patches as in AST. Similar performance yet compute heavy
250
+ window_size = (6, 6)
251
+ feat_size = (102, 12)
252
+ else:
253
+ window_size = (4, 4)
254
+ feat_size = (64, 8)
255
+ if self.decoder_mode == 1:
256
+ decoder_modules = []
257
+ for index in range(16):
258
+ if self.no_shift:
259
+ shift_size = (0, 0)
260
+ else:
261
+ if (index % 2) == 0:
262
+ shift_size = (0, 0)
263
+ else:
264
+ shift_size = (2, 0)
265
+ # shift_size = tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size])
266
+ decoder_modules.append(
267
+ SwinTransformerBlock(
268
+ dim=decoder_embed_dim,
269
+ num_heads=16,
270
+ feat_size=feat_size,
271
+ window_size=window_size,
272
+ shift_size=shift_size,
273
+ mlp_ratio=mlp_ratio,
274
+ drop=0.0,
275
+ drop_attn=0.0,
276
+ drop_path=0.0,
277
+ extra_norm=False,
278
+ sequential_attn=False,
279
+ norm_layer=norm_layer, # nn.LayerNorm,
280
+ )
281
+ )
282
+ self.decoder_blocks = nn.ModuleList(decoder_modules)
283
+ else:
284
+ # Transformer
285
+ self.decoder_blocks = nn.ModuleList(
286
+ [
287
+ Block(
288
+ decoder_embed_dim,
289
+ decoder_num_heads,
290
+ mlp_ratio,
291
+ qkv_bias=True,
292
+ norm_layer=norm_layer,
293
+ ) # qk_scale=None,
294
+ for i in range(decoder_depth)
295
+ ]
296
+ )
297
+
298
+ self.decoder_norm = norm_layer(decoder_embed_dim)
299
+ self.decoder_pred = nn.Linear(
300
+ decoder_embed_dim, patch_size**2 * in_chans, bias=True
301
+ ) # decoder to patch
302
+
303
+ # --------------------------------------------------------------------------
304
+
305
+ self.norm_pix_loss = norm_pix_loss
306
+
307
+ self.patch_size = patch_size
308
+ self.stride = stride
309
+
310
+ # audio exps
311
+ self.alpha = alpha
312
+ self.T = temperature
313
+ self.mode = mode
314
+ self.use_nce = use_nce
315
+ self.beta = beta
316
+
317
+ self.log_softmax = nn.LogSoftmax(dim=-1)
318
+
319
+ self.mask_t_prob = mask_t_prob
320
+ self.mask_f_prob = mask_f_prob
321
+ self.mask_2d = mask_2d
322
+
323
+ self.epoch = epoch
324
+
325
+ self.initialize_weights()
326
+
327
+ def initialize_weights(self):
328
+ # initialization
329
+ # initialize (and freeze) pos_embed by sin-cos embedding
330
+ if self.audio_exp:
331
+ pos_embed = get_2d_sincos_pos_embed_flexible(
332
+ self.pos_embed.shape[-1], self.patch_embed.patch_hw, cls_token=True
333
+ )
334
+ else:
335
+ pos_embed = get_2d_sincos_pos_embed(
336
+ self.pos_embed.shape[-1],
337
+ int(self.patch_embed.num_patches**0.5),
338
+ cls_token=True,
339
+ )
340
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
341
+
342
+ if self.audio_exp:
343
+ decoder_pos_embed = get_2d_sincos_pos_embed_flexible(
344
+ self.decoder_pos_embed.shape[-1],
345
+ self.patch_embed.patch_hw,
346
+ cls_token=True,
347
+ )
348
+ else:
349
+ decoder_pos_embed = get_2d_sincos_pos_embed(
350
+ self.decoder_pos_embed.shape[-1],
351
+ int(self.patch_embed.num_patches**0.5),
352
+ cls_token=True,
353
+ )
354
+ self.decoder_pos_embed.data.copy_(
355
+ torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)
356
+ )
357
+
358
+ # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
359
+ w = self.patch_embed.proj.weight.data
360
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
361
+
362
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
363
+ torch.nn.init.normal_(self.cls_token, std=0.02)
364
+ torch.nn.init.normal_(self.mask_token, std=0.02)
365
+
366
+ # initialize nn.Linear and nn.LayerNorm
367
+ self.apply(self._init_weights)
368
+
369
+ def _init_weights(self, m):
370
+ if isinstance(m, nn.Linear):
371
+ # we use xavier_uniform following official JAX ViT:
372
+ torch.nn.init.xavier_uniform_(m.weight)
373
+ if isinstance(m, nn.Linear) and m.bias is not None:
374
+ nn.init.constant_(m.bias, 0)
375
+ elif isinstance(m, nn.LayerNorm):
376
+ nn.init.constant_(m.bias, 0)
377
+ nn.init.constant_(m.weight, 1.0)
378
+
379
+ def patchify(self, imgs):
380
+ """
381
+ imgs: (N, 3, H, W)
382
+ x: (N, L, patch_size**2 *3)
383
+ L = (H/p)*(W/p)
384
+ """
385
+ p = self.patch_embed.patch_size[0]
386
+ # assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
387
+
388
+ if self.audio_exp:
389
+ if self.use_custom_patch: # overlapped patch
390
+ h, w = self.patch_embed.patch_hw
391
+ # todo: fixed h/w patch size and stride size. Make hw custom in the future
392
+ x = imgs.unfold(2, self.patch_size, self.stride).unfold(
393
+ 3, self.patch_size, self.stride
394
+ ) # n,1,H,W -> n,1,h,w,p,p
395
+ x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))
396
+ # x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p))
397
+ # x = torch.einsum('nchpwq->nhwpqc', x)
398
+ # x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))
399
+ else:
400
+ h = imgs.shape[2] // p
401
+ w = imgs.shape[3] // p
402
+ # h,w = self.patch_embed.patch_hw
403
+ x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p))
404
+ x = torch.einsum("nchpwq->nhwpqc", x)
405
+ x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))
406
+ else:
407
+ h = w = imgs.shape[2] // p
408
+ x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
409
+ x = torch.einsum("nchpwq->nhwpqc", x)
410
+ x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
411
+
412
+ return x
413
+
414
+ def unpatchify(self, x):
415
+ """
416
+ x: (N, L, patch_size**2 *3)
417
+ specs: (N, 1, H, W)
418
+ """
419
+ p = self.patch_embed.patch_size[0]
420
+ h = 1024 // p
421
+ w = 128 // p
422
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, 1))
423
+ x = torch.einsum("nhwpqc->nchpwq", x)
424
+ specs = x.reshape(shape=(x.shape[0], 1, h * p, w * p))
425
+ return specs
426
+
427
+ def random_masking(self, x, mask_ratio):
428
+ """
429
+ Perform per-sample random masking by per-sample shuffling.
430
+ Per-sample shuffling is done by argsort random noise.
431
+ x: [N, L, D], sequence
432
+ """
433
+ N, L, D = x.shape # batch, length, dim
434
+ len_keep = int(L * (1 - mask_ratio))
435
+
436
+ noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
437
+
438
+ # sort noise for each sample
439
+ ids_shuffle = torch.argsort(
440
+ noise, dim=1
441
+ ) # ascend: small is keep, large is remove
442
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
443
+
444
+ # keep the first subset
445
+ ids_keep = ids_shuffle[:, :len_keep]
446
+ x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
447
+
448
+ # generate the binary mask: 0 is keep, 1 is remove
449
+ mask = torch.ones([N, L], device=x.device)
450
+ mask[:, :len_keep] = 0
451
+ # unshuffle to get the binary mask
452
+ mask = torch.gather(mask, dim=1, index=ids_restore)
453
+
454
+ return x_masked, mask, ids_restore
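+ # Worked example (illustrative numbers, assuming the default 1 x 1024 x 128 mel input
+ # and 16 x 16 patches used elsewhere in this file): L = 64 * 8 = 512 tokens, so
+ # mask_ratio = 0.8 gives len_keep = int(512 * 0.2) = 102; x_masked is [N, 102, D]
+ # and mask is [N, 512] with 410 entries set to 1 (removed).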
455
+
456
+ def random_masking_2d(self, x, mask_t_prob, mask_f_prob):
457
+ """
458
+ 2D: Spectrogram (masking t and f under mask_t_prob and mask_f_prob)
459
+ Perform per-sample random masking by per-sample shuffling.
460
+ Per-sample shuffling is done by argsort random noise.
461
+ x: [N, L, D], sequence
462
+ """
463
+ N, L, D = x.shape # batch, length, dim
464
+ if self.use_custom_patch: # overlapped patch
465
+ T = 101
466
+ F = 12
467
+ else:
468
+ T = 64
469
+ F = 8
470
+ # x = x.reshape(N, T, F, D)
471
+ len_keep_t = int(T * (1 - mask_t_prob))
472
+ len_keep_f = int(F * (1 - mask_f_prob))
473
+
474
+ # noise for mask in time
475
+ noise_t = torch.rand(N, T, device=x.device) # noise in [0, 1]
476
+ # sort noise for each sample along time
477
+ ids_shuffle_t = torch.argsort(
478
+ noise_t, dim=1
479
+ ) # ascend: small is keep, large is remove
480
+ ids_restore_t = torch.argsort(ids_shuffle_t, dim=1)
481
+ ids_keep_t = ids_shuffle_t[:, :len_keep_t]
482
+ # noise mask in freq
483
+ noise_f = torch.rand(N, F, device=x.device) # noise in [0, 1]
484
+ ids_shuffle_f = torch.argsort(
485
+ noise_f, dim=1
486
+ ) # ascend: small is keep, large is remove
487
+ ids_restore_f = torch.argsort(ids_shuffle_f, dim=1)
488
+ ids_keep_f = ids_shuffle_f[:, :len_keep_f] #
489
+
490
+ # generate the binary mask: 0 is keep, 1 is remove
491
+ # mask in freq
492
+ mask_f = torch.ones(N, F, device=x.device)
493
+ mask_f[:, :len_keep_f] = 0
494
+ mask_f = (
495
+ torch.gather(mask_f, dim=1, index=ids_restore_f)
496
+ .unsqueeze(1)
497
+ .repeat(1, T, 1)
498
+ ) # N,T,F
499
+ # mask in time
500
+ mask_t = torch.ones(N, T, device=x.device)
501
+ mask_t[:, :len_keep_t] = 0
502
+ mask_t = (
503
+ torch.gather(mask_t, dim=1, index=ids_restore_t)
504
+ .unsqueeze(1)
505
+ .repeat(1, F, 1)
506
+ .permute(0, 2, 1)
507
+ ) # N,T,F
508
+ mask = 1 - (1 - mask_t) * (1 - mask_f) # N, T, F
509
+
510
+ # get masked x
511
+ id2res = torch.Tensor(list(range(N * T * F))).reshape(N, T, F).to(x.device)
512
+ id2res = id2res + 999 * mask # add a large value for masked elements
513
+ id2res2 = torch.argsort(id2res.flatten(start_dim=1))
514
+ ids_keep = id2res2.flatten(start_dim=1)[:, : len_keep_f * len_keep_t]
515
+ x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
516
+
517
+ ids_restore = torch.argsort(id2res2.flatten(start_dim=1))
518
+ mask = mask.flatten(start_dim=1)
519
+
520
+ return x_masked, mask, ids_restore
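+ # Worked example (illustrative probabilities): with the standard grid above, T = 64 and
+ # F = 8. Taking mask_t_prob = 0.5 and mask_f_prob = 0.25 keeps len_keep_t = 32 time
+ # columns and len_keep_f = 6 frequency rows, so x_masked is [N, 32 * 6, D] = [N, 192, D]
+ # and mask is the flattened [N, T * F] keep/remove map.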
521
+
522
+ def forward_encoder(self, x, mask_ratio, mask_2d=False):
523
+ # embed patches
524
+ x = self.patch_embed(x)
525
+ # add pos embed w/o cls token
526
+ x = x + self.pos_embed[:, 1:, :]
527
+
528
+ # masking: length -> length * mask_ratio
529
+ if mask_2d:
530
+ x, mask, ids_restore = self.random_masking_2d(
531
+ x, mask_t_prob=self.mask_t_prob, mask_f_prob=self.mask_f_prob
532
+ )
533
+ else:
534
+ x, mask, ids_restore = self.random_masking(x, mask_ratio)
535
+
536
+ # append cls token
537
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
538
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
539
+ x = torch.cat((cls_tokens, x), dim=1)
540
+
541
+ # apply Transformer blocks
542
+ for blk in self.blocks:
543
+ x = blk(x)
544
+ x = self.norm(x)
545
+
546
+ return x, mask, ids_restore, None
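+ # Note: the returned sequence includes the cls token, i.e. x is [N, 1 + len_keep, embed_dim],
+ # while mask and ids_restore are defined over the original patch grid (cls token excluded).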
547
+
548
+ def forward_encoder_no_random_mask_no_average(self, x):
549
+ # embed patches
550
+ x = self.patch_embed(x)
551
+ # add pos embed w/o cls token
552
+ x = x + self.pos_embed[:, 1:, :]
553
+
554
+ # masking: length -> length * mask_ratio
555
+ # if mask_2d:
556
+ # x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob=self.mask_t_prob, mask_f_prob=self.mask_f_prob)
557
+ # else:
558
+ # x, mask, ids_restore = self.random_masking(x, mask_ratio)
559
+
560
+ # append cls token
561
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
562
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
563
+ x = torch.cat((cls_tokens, x), dim=1)
564
+
565
+ # apply Transformer blocks
566
+ for blk in self.blocks:
567
+ x = blk(x)
568
+ x = self.norm(x)
569
+
570
+ return x
571
+
572
+ def forward_encoder_no_mask(self, x):
573
+ # embed patches
574
+ x = self.patch_embed(x)
575
+
576
+ # add pos embed w/o cls token
577
+ x = x + self.pos_embed[:, 1:, :]
578
+
579
+ # masking: length -> length * mask_ratio
580
+ # x, mask, ids_restore = self.random_masking(x, mask_ratio)
581
+ # append cls token
582
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
583
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
584
+ x = torch.cat((cls_tokens, x), dim=1)
585
+
586
+ # apply Transformer blocks
587
+ contextual_embs = []
588
+ for n, blk in enumerate(self.blocks):
589
+ x = blk(x)
590
+ if n > self.contextual_depth:
591
+ contextual_embs.append(self.norm(x))
592
+ # x = self.norm(x)
593
+ contextual_emb = torch.stack(contextual_embs, dim=0).mean(dim=0)
594
+
595
+ return contextual_emb
596
+
597
+ def forward_decoder(self, x, ids_restore):
598
+ # embed tokens
599
+ x = self.decoder_embed(x)
600
+
601
+ # append mask tokens to sequence
602
+ mask_tokens = self.mask_token.repeat(
603
+ x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1
604
+ )
605
+ x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
606
+ x_ = torch.gather(
607
+ x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
608
+ ) # unshuffle
609
+ x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
610
+
611
+ # add pos embed
612
+ x = x + self.decoder_pos_embed
613
+
614
+ if self.decoder_mode != 0:
615
+ B, L, D = x.shape
616
+ x = x[:, 1:, :]
617
+ if self.use_custom_patch:
618
+ x = x.reshape(B, 101, 12, D)
619
+ x = torch.cat([x, x[:, -1, :].unsqueeze(1)], dim=1)  # pad one extra time frame so the 101 x 12 grid becomes 102 x 12 = 1224 tokens
620
+ x = x.reshape(B, 1224, D)
621
+ if self.decoder_mode > 3: # mvit
622
+ x = self.decoder_blocks(x)
623
+ else:
624
+ # apply Transformer blocks
625
+ for blk in self.decoder_blocks:
626
+ x = blk(x)
627
+ x = self.decoder_norm(x)
628
+
629
+ # predictor projection
630
+ pred = self.decoder_pred(x)
631
+
632
+ # remove cls token
633
+ if self.decoder_mode != 0:
634
+ if self.use_custom_patch:
635
+ pred = pred.reshape(B, 102, 12, 256)
636
+ pred = pred[:, :101, :, :]
637
+ pred = pred.reshape(B, 1212, 256)
638
+ else:
639
+ pred = pred
640
+ else:
641
+ pred = pred[:, 1:, :]
642
+ return pred, None, None # emb, emb_pixel
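+ # Note: with decoder_mode == 0 the cls token is stripped here, so pred is
+ # [N, L, patch_size**2 * in_chans] (e.g. [N, 512, 256] for 16 x 16 patches on a
+ # 1-channel spectrogram); for decoder_mode != 0 the cls token was already dropped
+ # before the decoder blocks.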
643
+
644
+ def forward_loss(self, imgs, pred, mask, norm_pix_loss=False):
645
+ """
646
+ imgs: [N, 3, H, W]
647
+ pred: [N, L, p*p*3]
648
+ mask: [N, L], 0 is keep, 1 is remove,
649
+ """
650
+ target = self.patchify(imgs)
651
+ if norm_pix_loss:
652
+ mean = target.mean(dim=-1, keepdim=True)
653
+ var = target.var(dim=-1, keepdim=True)
654
+ target = (target - mean) / (var + 1.0e-6) ** 0.5
655
+
656
+ loss = (pred - target) ** 2
657
+ loss = loss.mean(dim=-1) # [N, L], mean loss per patch
658
+
659
+ loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
660
+ return loss
661
+
662
+ def forward(self, imgs, mask_ratio=0.8):
663
+ emb_enc, mask, ids_restore, _ = self.forward_encoder(
664
+ imgs, mask_ratio, mask_2d=self.mask_2d
665
+ )
666
+ pred, _, _ = self.forward_decoder(emb_enc, ids_restore) # [N, L, p*p*3]
667
+ loss_recon = self.forward_loss(
668
+ imgs, pred, mask, norm_pix_loss=self.norm_pix_loss
669
+ )
670
+ loss_contrastive = torch.zeros(1, device=imgs.device)  # placeholder contrastive loss; avoids hard-coding a CUDA device
671
+ return loss_recon, pred, mask, loss_contrastive
672
+
673
+
674
+ def mae_vit_small_patch16_dec512d8b(**kwargs):
675
+ model = MaskedAutoencoderViT(
676
+ patch_size=16,
677
+ embed_dim=384,
678
+ depth=12,
679
+ num_heads=6,
680
+ decoder_embed_dim=512,
681
+ decoder_num_heads=16,
682
+ mlp_ratio=4,
683
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
684
+ **kwargs,
685
+ )
686
+ return model
687
+
688
+
689
+ def mae_vit_base_patch16_dec512d8b(**kwargs):
690
+ model = MaskedAutoencoderViT(
691
+ patch_size=16,
692
+ embed_dim=768,
693
+ depth=12,
694
+ num_heads=12,
695
+ decoder_embed_dim=512,
696
+ decoder_num_heads=16,
697
+ mlp_ratio=4,
698
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
699
+ **kwargs,
700
+ )
701
+ return model
702
+
703
+
704
+ def mae_vit_large_patch16_dec512d8b(**kwargs):
705
+ model = MaskedAutoencoderViT(
706
+ patch_size=16,
707
+ embed_dim=1024,
708
+ depth=24,
709
+ num_heads=16,
710
+ decoder_embed_dim=512,
711
+ decoder_num_heads=16,
712
+ mlp_ratio=4,
713
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
714
+ **kwargs,
715
+ )
716
+ return model
717
+
718
+
719
+ def mae_vit_huge_patch14_dec512d8b(**kwargs):
720
+ model = MaskedAutoencoderViT(
721
+ patch_size=14,
722
+ embed_dim=1280,
723
+ depth=32,
724
+ num_heads=16,
725
+ decoder_embed_dim=512,
726
+ decoder_num_heads=16,
727
+ mlp_ratio=4,
728
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
729
+ **kwargs,
730
+ )
731
+ return model
732
+
733
+
734
+ # set recommended archs
735
+ mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks
736
+ mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks
737
+ mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks
738
+ mae_vit_small_patch16 = mae_vit_small_patch16_dec512d8b # decoder: 512 dim, 8 blocks
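+ # Hedged usage sketch (not part of the upstream file): the constructor keyword names
+ # below (`in_chans`, `audio_exp`) and the 1 x 1024 x 128 input shape are assumptions
+ # inferred from how they are used above; the snippet assumes a CUDA device and is left
+ # commented out as a sketch rather than a tested entry point.
+ # if __name__ == "__main__":
+ #     model = mae_vit_base_patch16(in_chans=1, audio_exp=True).cuda()
+ #     spec = torch.randn(2, 1, 1024, 128).cuda()           # batch of log-mel spectrograms
+ #     loss_recon, pred, mask, _ = model(spec, mask_ratio=0.8)
+ #     print(loss_recon.item(), pred.shape, mask.shape)      # pred expected: [2, 512, 256]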
audio_encoder/models_vit.py ADDED
@@ -0,0 +1,243 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
+ # DeiT: https://github.com/facebookresearch/deit
10
+ # --------------------------------------------------------
11
+
12
+ from functools import partial
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import timm.models.vision_transformer
17
+
18
+
19
+ class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
20
+ """Vision Transformer with support for global average pooling"""
21
+
22
+ def __init__(
23
+ self, global_pool=False, mask_2d=True, use_custom_patch=False, **kwargs
24
+ ):
25
+ super(VisionTransformer, self).__init__(**kwargs)
26
+
27
+ self.global_pool = global_pool
28
+ if self.global_pool:
29
+ norm_layer = kwargs["norm_layer"]
30
+ embed_dim = kwargs["embed_dim"]
31
+ self.fc_norm = norm_layer(embed_dim)
32
+ del self.norm # remove the original norm
33
+ self.mask_2d = mask_2d
34
+ self.use_custom_patch = use_custom_patch
35
+
36
+ def forward_features(self, x):
37
+ B = x.shape[0]
38
+ x = self.patch_embed(x)
39
+ x = x + self.pos_embed[:, 1:, :]
40
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
41
+ cls_tokens = cls_token.expand(
42
+ B, -1, -1
43
+ ) # stole cls_tokens impl from Phil Wang, thanks
44
+ x = torch.cat((cls_tokens, x), dim=1)
45
+ x = self.pos_drop(x)
46
+
47
+ for blk in self.blocks:
48
+ x = blk(x)
49
+
50
+ if self.global_pool:
51
+ x = x[:, 1:, :].mean(dim=1) # global pool without cls token
52
+ outcome = self.fc_norm(x)
53
+ else:
54
+ x = self.norm(x)
55
+ outcome = x[:, 0]
56
+
57
+ return outcome
58
+
59
+ def random_masking(self, x, mask_ratio):
60
+ """
61
+ Perform per-sample random masking by per-sample shuffling.
62
+ Per-sample shuffling is done by argsort random noise.
63
+ x: [N, L, D], sequence
64
+ """
65
+ N, L, D = x.shape # batch, length, dim
66
+ len_keep = int(L * (1 - mask_ratio))
67
+
68
+ noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
69
+
70
+ # sort noise for each sample
71
+ ids_shuffle = torch.argsort(
72
+ noise, dim=1
73
+ ) # ascend: small is keep, large is remove
74
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
75
+
76
+ # keep the first subset
77
+ ids_keep = ids_shuffle[:, :len_keep]
78
+ x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
79
+
80
+ # generate the binary mask: 0 is keep, 1 is remove
81
+ mask = torch.ones([N, L], device=x.device)
82
+ mask[:, :len_keep] = 0
83
+ # unshuffle to get the binary mask
84
+ mask = torch.gather(mask, dim=1, index=ids_restore)
85
+
86
+ return x_masked, mask, ids_restore
87
+
88
+ def random_masking_2d(self, x, mask_t_prob, mask_f_prob):
89
+ """
90
+ 2D: Spectrogram (masking t and f under mask_t_prob and mask_f_prob)
91
+ Perform per-sample random masking by per-sample shuffling.
92
+ Per-sample shuffling is done by argsort random noise.
93
+ x: [N, L, D], sequence
94
+ """
95
+
96
+ N, L, D = x.shape # batch, length, dim
97
+ if self.use_custom_patch:
98
+ # # for AS
99
+ T = 101 # 64,101
100
+ F = 12 # 8,12
101
+ # # for ESC
102
+ # T=50
103
+ # F=12
104
+ # for SPC
105
+ # T=12
106
+ # F=12
107
+ else:
108
+ # ## for AS
109
+ T = 64
110
+ F = 8
111
+ # ## for ESC
112
+ # T=32
113
+ # F=8
114
+ ## for SPC
115
+ # T=8
116
+ # F=8
117
+
118
+ # mask T
119
+ x = x.reshape(N, T, F, D)
120
+ len_keep_T = int(T * (1 - mask_t_prob))
121
+ noise = torch.rand(N, T, device=x.device) # noise in [0, 1]
122
+ # sort noise for each sample
123
+ ids_shuffle = torch.argsort(
124
+ noise, dim=1
125
+ ) # ascend: small is keep, large is remove
126
+ ids_keep = ids_shuffle[:, :len_keep_T]
127
+ index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, F, D)
128
+ # x_masked = torch.gather(x, dim=1, index=index)
129
+ # x_masked = x_masked.reshape(N,len_keep_T*F,D)
130
+ x = torch.gather(x, dim=1, index=index) # N, len_keep_T(T'), F, D
131
+
132
+ # mask F
133
+ # x = x.reshape(N, T, F, D)
134
+ x = x.permute(0, 2, 1, 3) # N T' F D => N F T' D
135
+ len_keep_F = int(F * (1 - mask_f_prob))
136
+ noise = torch.rand(N, F, device=x.device) # noise in [0, 1]
137
+ # sort noise for each sample
138
+ ids_shuffle = torch.argsort(
139
+ noise, dim=1
140
+ ) # ascend: small is keep, large is remove
141
+ ids_keep = ids_shuffle[:, :len_keep_F]
142
+ # index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, T, D)
143
+ index = ids_keep.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, len_keep_T, D)
144
+ x_masked = torch.gather(x, dim=1, index=index)
145
+ x_masked = x_masked.permute(0, 2, 1, 3) # N F' T' D => N T' F' D
146
+ # x_masked = x_masked.reshape(N,len_keep*T,D)
147
+ x_masked = x_masked.reshape(N, len_keep_F * len_keep_T, D)
148
+
149
+ return x_masked, None, None
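+ # Note: unlike the pre-training version in models_mae, this returns (x_masked, None, None);
+ # fine-tuning only needs the surviving tokens, not a reconstruction mask or restore indices.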
150
+
151
+ def forward_features_mask(self, x, mask_t_prob, mask_f_prob):
152
+ B = x.shape[0] # 4,1,1024,128
153
+ x = self.patch_embed(x) # 4, 512, 768
154
+
155
+ x = x + self.pos_embed[:, 1:, :]
156
+ if self.mask_2d:  # use the flag set in __init__; checking the bound method `random_masking_2d` was always truthy
157
+ x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob, mask_f_prob)
158
+ else:
159
+ x, mask, ids_restore = self.random_masking(x, mask_t_prob)
160
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
161
+ cls_tokens = cls_token.expand(B, -1, -1)
162
+ x = torch.cat((cls_tokens, x), dim=1)
163
+ x = self.pos_drop(x)
164
+
165
+ # apply Transformer blocks
166
+ for blk in self.blocks:
167
+ x = blk(x)
168
+
169
+ if self.global_pool:
170
+ x = x[:, 1:, :].mean(dim=1) # global pool without cls token
171
+ outcome = self.fc_norm(x)
172
+ else:
173
+ x = self.norm(x)
174
+ outcome = x[:, 0]
175
+
176
+ return outcome
177
+
178
+ # overwrite original timm
179
+ def forward(self, x, v=None, mask_t_prob=0.0, mask_f_prob=0.0):
180
+ if mask_t_prob > 0.0 or mask_f_prob > 0.0:
181
+ x = self.forward_features_mask(
182
+ x, mask_t_prob=mask_t_prob, mask_f_prob=mask_f_prob
183
+ )
184
+ else:
185
+ x = self.forward_features(x)
186
+ x = self.head(x)
187
+ return x
188
+
189
+
190
+ def vit_small_patch16(**kwargs):
191
+ model = VisionTransformer(
192
+ patch_size=16,
193
+ embed_dim=384,
194
+ depth=12,
195
+ num_heads=6,
196
+ mlp_ratio=4,
197
+ qkv_bias=True,
198
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
199
+ **kwargs
200
+ )
201
+ return model
202
+
203
+
204
+ def vit_base_patch16(**kwargs):
205
+ model = VisionTransformer(
206
+ patch_size=16,
207
+ embed_dim=768,
208
+ depth=12,
209
+ num_heads=12,
210
+ mlp_ratio=4,
211
+ qkv_bias=True,
212
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
213
+ **kwargs
214
+ )
215
+ return model
216
+
217
+
218
+ def vit_large_patch16(**kwargs):
219
+ model = VisionTransformer(
220
+ patch_size=16,
221
+ embed_dim=1024,
222
+ depth=24,
223
+ num_heads=16,
224
+ mlp_ratio=4,
225
+ qkv_bias=True,
226
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
227
+ **kwargs
228
+ )
229
+ return model
230
+
231
+
232
+ def vit_huge_patch14(**kwargs):
233
+ model = VisionTransformer(
234
+ patch_size=14,
235
+ embed_dim=1280,
236
+ depth=32,
237
+ num_heads=16,
238
+ mlp_ratio=4,
239
+ qkv_bias=True,
240
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
241
+ **kwargs
242
+ )
243
+ return model
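+ # Hedged usage sketch (not part of the upstream file): the kwargs below are assumptions
+ # forwarded to timm's VisionTransformer (img_size / in_chans / num_classes), chosen so a
+ # 1 x 1024 x 128 mel input with 16 x 16 patches matches the 64 x 8 grid expected by
+ # random_masking_2d; left commented out because the exact timm version may differ.
+ # if __name__ == "__main__":
+ #     model = vit_base_patch16(img_size=(1024, 128), in_chans=1, num_classes=527)
+ #     spec = torch.randn(2, 1, 1024, 128)
+ #     logits = model(spec, mask_t_prob=0.2, mask_f_prob=0.2)  # masked fine-tuning path
+ #     print(logits.shape)  # expected: torch.Size([2, 527])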
pipeline/modeling_audioldm2.py ADDED
@@ -0,0 +1,1546 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Any, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.utils.checkpoint
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.loaders import UNet2DConditionLoadersMixin
24
+ from diffusers.models.activations import get_activation
25
+ from diffusers.models.attention_processor import (
26
+ # ADDED_KV_ATTENTION_PROCESSORS,
27
+ # CROSS_ATTENTION_PROCESSORS,
28
+ SlicedAttnAddedKVProcessor,
29
+ AttnAddedKVProcessor2_0,
30
+ XFormersAttnAddedKVProcessor,
31
+ AttnProcessor2_0,
32
+ XFormersAttnProcessor,
33
+ SlicedAttnProcessor,
34
+ LoRAAttnProcessor,
35
+ LoRAAttnProcessor2_0,
36
+ LoRAXFormersAttnProcessor,
37
+ LoRAAttnAddedKVProcessor,
38
+ AttentionProcessor,
39
+ AttnAddedKVProcessor,
40
+ AttnProcessor,
41
+ )
42
+
43
+ ADDED_KV_ATTENTION_PROCESSORS = (
44
+ AttnAddedKVProcessor,
45
+ SlicedAttnAddedKVProcessor,
46
+ AttnAddedKVProcessor2_0,
47
+ XFormersAttnAddedKVProcessor,
48
+ LoRAAttnAddedKVProcessor,
49
+ )
50
+ CROSS_ATTENTION_PROCESSORS = (
51
+ AttnProcessor,
52
+ AttnProcessor2_0,
53
+ XFormersAttnProcessor,
54
+ SlicedAttnProcessor,
55
+ LoRAAttnProcessor,
56
+ LoRAAttnProcessor2_0,
57
+ LoRAXFormersAttnProcessor,
58
+ )
59
+
60
+ from diffusers.models.embeddings import (
61
+ TimestepEmbedding,
62
+ Timesteps,
63
+ )
64
+ from diffusers.models.modeling_utils import ModelMixin
65
+ from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
66
+ # from diffusers.models.transformers.transformer_2d import Transformer2DModel
67
+ from diffusers.models.transformer_2d import Transformer2DModel
68
+ from diffusers.models.unet_2d_blocks import DownBlock2D, UpBlock2D
69
+ from diffusers.models.unet_2d_condition import UNet2DConditionOutput
70
+ # from diffusers.models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D
71
+ # from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
72
+ from diffusers.utils import BaseOutput, is_torch_version, logging
73
+
74
+
75
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
76
+
77
+
78
+ def add_special_tokens(hidden_states, attention_mask, sos_token, eos_token):
79
+ batch_size = hidden_states.shape[0]
80
+
81
+ if attention_mask is not None:
82
+ # Add two more steps to attn mask
83
+ new_attn_mask_step = attention_mask.new_ones((batch_size, 1))
84
+ attention_mask = torch.concat([new_attn_mask_step, attention_mask, new_attn_mask_step], dim=-1)
85
+
86
+ # Add the SOS / EOS tokens at the start / end of the sequence respectively
87
+ sos_token = sos_token.expand(batch_size, 1, -1)
88
+ eos_token = eos_token.expand(batch_size, 1, -1)
89
+ hidden_states = torch.concat([sos_token, hidden_states, eos_token], dim=1)
90
+ return hidden_states, attention_mask
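+ # Illustrative shapes: hidden_states goes from (batch, seq_len, dim) to (batch, seq_len + 2, dim)
+ # and attention_mask (if given) from (batch, seq_len) to (batch, seq_len + 2), with the added
+ # SOS / EOS positions always attended to.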
91
+
92
+
93
+ @dataclass
94
+ class AudioLDM2ProjectionModelOutput(BaseOutput):
95
+ """
96
+ Class for AudioLDM2 projection layer's outputs.
97
+ Args:
98
+ hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
99
+ Sequence of hidden-states obtained by linearly projecting the hidden-states for each of the text
100
+ encoders and subsequently concatenating them together.
101
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
102
+ Mask to avoid performing attention on padding token indices, formed by concatenating the attention masks
103
+ for the two text encoders together. Mask values selected in `[0, 1]`:
104
+
105
+ - 1 for tokens that are **not masked**,
106
+ - 0 for tokens that are **masked**.
107
+ """
108
+
109
+ hidden_states: torch.FloatTensor
110
+ attention_mask: Optional[torch.LongTensor] = None
111
+
112
+
113
+ class AudioLDM2ProjectionModel(ModelMixin, ConfigMixin):
114
+ """
115
+ A simple linear projection model to map two text embeddings to a shared latent space. It also inserts learned
116
+ embedding vectors at the start and end of each text embedding sequence respectively. Each variable appended with
117
+ `_1` refers to that corresponding to the second text encoder. Otherwise, it is from the first.
118
+
119
+ Args:
120
+ text_encoder_dim (`int`):
121
+ Dimensionality of the text embeddings from the first text encoder (CLAP).
122
+ text_encoder_1_dim (`int`):
123
+ Dimensionality of the text embeddings from the second text encoder (T5 or VITS).
124
+ langauge_model_dim (`int`):
125
+ Dimensionality of the text embeddings from the language model (GPT2).
126
+ """
127
+
128
+ @register_to_config
129
+ def __init__(self, text_encoder_dim, text_encoder_1_dim, langauge_model_dim):
130
+ super().__init__()
131
+ # additional projection layers for each text encoder
132
+ self.projection = nn.Linear(text_encoder_dim, langauge_model_dim)
133
+ self.projection_1 = nn.Linear(text_encoder_1_dim, langauge_model_dim)
134
+
135
+ # learnable SOS / EOS token embeddings for each text encoder
136
+ self.sos_embed = nn.Parameter(torch.ones(langauge_model_dim))
137
+ self.eos_embed = nn.Parameter(torch.ones(langauge_model_dim))
138
+
139
+ self.sos_embed_1 = nn.Parameter(torch.ones(langauge_model_dim))
140
+ self.eos_embed_1 = nn.Parameter(torch.ones(langauge_model_dim))
141
+
142
+ def forward(
143
+ self,
144
+ hidden_states: Optional[torch.FloatTensor] = None,
145
+ hidden_states_1: Optional[torch.FloatTensor] = None,
146
+ attention_mask: Optional[torch.LongTensor] = None,
147
+ attention_mask_1: Optional[torch.LongTensor] = None,
148
+ ):
149
+ hidden_states = self.projection(hidden_states)
150
+ hidden_states, attention_mask = add_special_tokens(
151
+ hidden_states, attention_mask, sos_token=self.sos_embed, eos_token=self.eos_embed
152
+ )
153
+
154
+ hidden_states_1 = self.projection_1(hidden_states_1)
155
+ hidden_states_1, attention_mask_1 = add_special_tokens(
156
+ hidden_states_1, attention_mask_1, sos_token=self.sos_embed_1, eos_token=self.eos_embed_1
157
+ )
158
+
159
+ # concatenate clap and t5 text encoding
160
+ hidden_states = torch.cat([hidden_states, hidden_states_1], dim=1)
161
+
162
+ # concatenate attention masks
163
+ if attention_mask is None and attention_mask_1 is not None:
164
+ attention_mask = attention_mask_1.new_ones(hidden_states.shape[:2])  # new_ones expects a size, not a tensor slice
165
+ elif attention_mask is not None and attention_mask_1 is None:
166
+ attention_mask_1 = attention_mask.new_ones(hidden_states_1.shape[:2])  # new_ones expects a size, not a tensor slice
167
+
168
+ if attention_mask is not None and attention_mask_1 is not None:
169
+ attention_mask = torch.cat([attention_mask, attention_mask_1], dim=-1)
170
+ else:
171
+ attention_mask = None
172
+
173
+ return AudioLDM2ProjectionModelOutput(
174
+ hidden_states=hidden_states,
175
+ attention_mask=attention_mask,
176
+ )
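+ # Sketch of the expected call pattern (dimensions are hypothetical, for illustration only):
+ # proj = AudioLDM2ProjectionModel(text_encoder_dim=512, text_encoder_1_dim=1024, langauge_model_dim=768)
+ # out = proj(hidden_states=clap_embeds, hidden_states_1=t5_embeds,
+ #            attention_mask=clap_mask, attention_mask_1=t5_mask)
+ # out.hidden_states has shape (batch, (L_clap + 2) + (L_t5 + 2), 768); out.attention_mask matches its length.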
177
+
178
+
179
+ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
180
+ r"""
181
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
182
+ shaped output. Compared to the vanilla [`UNet2DConditionModel`], this variant optionally includes an additional
183
+ self-attention layer in each Transformer block, as well as multiple cross-attention layers. It also allows for up
184
+ to two cross-attention embeddings, `encoder_hidden_states` and `encoder_hidden_states_1`.
185
+
186
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
187
+ for all models (such as downloading or saving).
188
+
189
+ Parameters:
190
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
191
+ Height and width of input/output sample.
192
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
193
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
194
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
195
+ Whether to flip the sin to cos in the time embedding.
196
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
197
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
198
+ The tuple of downsample blocks to use.
199
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
200
+ Block type for middle of UNet, it can only be `UNetMidBlock2DCrossAttn` for AudioLDM2.
201
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
202
+ The tuple of upsample blocks to use.
203
+ only_cross_attention (`bool` or `Tuple[bool]`, *optional*, default to `False`):
204
+ Whether to include self-attention in the basic transformer blocks, see
205
+ [`~models.attention.BasicTransformerBlock`].
206
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
207
+ The tuple of output channels for each block.
208
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
209
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
210
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
211
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
212
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
213
+ If `None`, normalization and activation layers is skipped in post-processing.
214
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
215
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
216
+ The dimension of the cross attention features.
217
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
218
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
219
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
220
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
221
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
222
+ num_attention_heads (`int`, *optional*):
223
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
224
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
225
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
226
+ class_embed_type (`str`, *optional*, defaults to `None`):
227
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
228
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
229
+ num_class_embeds (`int`, *optional*, defaults to `None`):
230
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
231
+ class conditioning with `class_embed_type` equal to `None`.
232
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
233
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
234
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
235
+ An optional override for the dimension of the projected time embedding.
236
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
237
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
238
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
239
+ timestep_post_act (`str`, *optional*, defaults to `None`):
240
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
241
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
242
+ The dimension of `cond_proj` layer in the timestep embedding.
243
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
244
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
245
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
246
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
247
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
248
+ embeddings with the class embeddings.
249
+ """
250
+
251
+ _supports_gradient_checkpointing = True
252
+
253
+ @register_to_config
254
+ def __init__(
255
+ self,
256
+ sample_size: Optional[int] = None,
257
+ in_channels: int = 4,
258
+ out_channels: int = 4,
259
+ flip_sin_to_cos: bool = True,
260
+ freq_shift: int = 0,
261
+ down_block_types: Tuple[str] = (
262
+ "CrossAttnDownBlock2D",
263
+ "CrossAttnDownBlock2D",
264
+ "CrossAttnDownBlock2D",
265
+ "DownBlock2D",
266
+ ),
267
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
268
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
269
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
270
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
271
+ layers_per_block: Union[int, Tuple[int]] = 2,
272
+ downsample_padding: int = 1,
273
+ mid_block_scale_factor: float = 1,
274
+ act_fn: str = "silu",
275
+ norm_num_groups: Optional[int] = 32,
276
+ norm_eps: float = 1e-5,
277
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
278
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
279
+ attention_head_dim: Union[int, Tuple[int]] = 8,
280
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
281
+ use_linear_projection: bool = False,
282
+ class_embed_type: Optional[str] = None,
283
+ num_class_embeds: Optional[int] = None,
284
+ upcast_attention: bool = False,
285
+ resnet_time_scale_shift: str = "default",
286
+ time_embedding_type: str = "positional",
287
+ time_embedding_dim: Optional[int] = None,
288
+ time_embedding_act_fn: Optional[str] = None,
289
+ timestep_post_act: Optional[str] = None,
290
+ time_cond_proj_dim: Optional[int] = None,
291
+ conv_in_kernel: int = 3,
292
+ conv_out_kernel: int = 3,
293
+ projection_class_embeddings_input_dim: Optional[int] = None,
294
+ class_embeddings_concat: bool = False,
295
+ ):
296
+ super().__init__()
297
+
298
+ self.sample_size = sample_size
299
+
300
+ if num_attention_heads is not None:
301
+ raise ValueError(
302
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
303
+ )
304
+
305
+ # If `num_attention_heads` is not defined (which is the case for most models)
306
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
307
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
308
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
309
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
310
+ # which is why we correct for the naming here.
311
+ num_attention_heads = num_attention_heads or attention_head_dim
312
+
313
+ # Check inputs
314
+ if len(down_block_types) != len(up_block_types):
315
+ raise ValueError(
316
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
317
+ )
318
+
319
+ if len(block_out_channels) != len(down_block_types):
320
+ raise ValueError(
321
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
322
+ )
323
+
324
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
325
+ raise ValueError(
326
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
327
+ )
328
+
329
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
330
+ raise ValueError(
331
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
332
+ )
333
+
334
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
335
+ raise ValueError(
336
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
337
+ )
338
+
339
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
340
+ raise ValueError(
341
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
342
+ )
343
+
344
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
345
+ raise ValueError(
346
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
347
+ )
348
+
349
+ # input
350
+ conv_in_padding = (conv_in_kernel - 1) // 2
351
+ self.conv_in = nn.Conv2d(
352
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
353
+ )
354
+
355
+ # time
356
+ if time_embedding_type == "positional":
357
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
358
+
359
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
360
+ timestep_input_dim = block_out_channels[0]
361
+ else:
362
+ raise ValueError(f"{time_embedding_type} does not exist. Please make sure to use `positional`.")
363
+
364
+ self.time_embedding = TimestepEmbedding(
365
+ timestep_input_dim,
366
+ time_embed_dim,
367
+ act_fn=act_fn,
368
+ post_act_fn=timestep_post_act,
369
+ cond_proj_dim=time_cond_proj_dim,
370
+ )
371
+
372
+ # class embedding
373
+ if class_embed_type is None and num_class_embeds is not None:
374
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
375
+ elif class_embed_type == "timestep":
376
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
377
+ elif class_embed_type == "identity":
378
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
379
+ elif class_embed_type == "projection":
380
+ if projection_class_embeddings_input_dim is None:
381
+ raise ValueError(
382
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
383
+ )
384
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
385
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
386
+ # 2. it projects from an arbitrary input dimension.
387
+ #
388
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
389
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
390
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
391
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
392
+ elif class_embed_type == "simple_projection":
393
+ if projection_class_embeddings_input_dim is None:
394
+ raise ValueError(
395
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
396
+ )
397
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
398
+ else:
399
+ self.class_embedding = None
400
+
401
+ if time_embedding_act_fn is None:
402
+ self.time_embed_act = None
403
+ else:
404
+ self.time_embed_act = get_activation(time_embedding_act_fn)
405
+
406
+ self.down_blocks = nn.ModuleList([])
407
+ self.up_blocks = nn.ModuleList([])
408
+
409
+ if isinstance(only_cross_attention, bool):
410
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
411
+
412
+ if isinstance(num_attention_heads, int):
413
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
414
+
415
+ if isinstance(cross_attention_dim, int):
416
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
417
+
418
+ if isinstance(layers_per_block, int):
419
+ layers_per_block = [layers_per_block] * len(down_block_types)
420
+
421
+ if isinstance(transformer_layers_per_block, int):
422
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
423
+
424
+ if class_embeddings_concat:
425
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
426
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
427
+ # regular time embeddings
428
+ blocks_time_embed_dim = time_embed_dim * 2
429
+ else:
430
+ blocks_time_embed_dim = time_embed_dim
431
+
432
+ # down
433
+ output_channel = block_out_channels[0]
434
+ for i, down_block_type in enumerate(down_block_types):
435
+ input_channel = output_channel
436
+ output_channel = block_out_channels[i]
437
+ is_final_block = i == len(block_out_channels) - 1
438
+
439
+ down_block = get_down_block(
440
+ down_block_type,
441
+ num_layers=layers_per_block[i],
442
+ transformer_layers_per_block=transformer_layers_per_block[i],
443
+ in_channels=input_channel,
444
+ out_channels=output_channel,
445
+ temb_channels=blocks_time_embed_dim,
446
+ add_downsample=not is_final_block,
447
+ resnet_eps=norm_eps,
448
+ resnet_act_fn=act_fn,
449
+ resnet_groups=norm_num_groups,
450
+ cross_attention_dim=cross_attention_dim[i],
451
+ num_attention_heads=num_attention_heads[i],
452
+ downsample_padding=downsample_padding,
453
+ use_linear_projection=use_linear_projection,
454
+ only_cross_attention=only_cross_attention[i],
455
+ upcast_attention=upcast_attention,
456
+ resnet_time_scale_shift=resnet_time_scale_shift,
457
+ )
458
+ self.down_blocks.append(down_block)
459
+
460
+ # mid
461
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
462
+ self.mid_block = UNetMidBlock2DCrossAttn(
463
+ transformer_layers_per_block=transformer_layers_per_block[-1],
464
+ in_channels=block_out_channels[-1],
465
+ temb_channels=blocks_time_embed_dim,
466
+ resnet_eps=norm_eps,
467
+ resnet_act_fn=act_fn,
468
+ output_scale_factor=mid_block_scale_factor,
469
+ resnet_time_scale_shift=resnet_time_scale_shift,
470
+ cross_attention_dim=cross_attention_dim[-1],
471
+ num_attention_heads=num_attention_heads[-1],
472
+ resnet_groups=norm_num_groups,
473
+ use_linear_projection=use_linear_projection,
474
+ upcast_attention=upcast_attention,
475
+ )
476
+ else:
477
+ raise ValueError(
478
+ f"unknown mid_block_type : {mid_block_type}. Should be `UNetMidBlock2DCrossAttn` for AudioLDM2."
479
+ )
480
+
481
+ # count how many layers upsample the images
482
+ self.num_upsamplers = 0
483
+
484
+ # up
485
+ reversed_block_out_channels = list(reversed(block_out_channels))
486
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
487
+ reversed_layers_per_block = list(reversed(layers_per_block))
488
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
489
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
490
+ only_cross_attention = list(reversed(only_cross_attention))
491
+
492
+ output_channel = reversed_block_out_channels[0]
493
+ for i, up_block_type in enumerate(up_block_types):
494
+ is_final_block = i == len(block_out_channels) - 1
495
+
496
+ prev_output_channel = output_channel
497
+ output_channel = reversed_block_out_channels[i]
498
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
499
+
500
+ # add upsample block for all BUT final layer
501
+ if not is_final_block:
502
+ add_upsample = True
503
+ self.num_upsamplers += 1
504
+ else:
505
+ add_upsample = False
506
+
507
+ up_block = get_up_block(
508
+ up_block_type,
509
+ num_layers=reversed_layers_per_block[i] + 1,
510
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
511
+ in_channels=input_channel,
512
+ out_channels=output_channel,
513
+ prev_output_channel=prev_output_channel,
514
+ temb_channels=blocks_time_embed_dim,
515
+ add_upsample=add_upsample,
516
+ resnet_eps=norm_eps,
517
+ resnet_act_fn=act_fn,
518
+ resnet_groups=norm_num_groups,
519
+ cross_attention_dim=reversed_cross_attention_dim[i],
520
+ num_attention_heads=reversed_num_attention_heads[i],
521
+ use_linear_projection=use_linear_projection,
522
+ only_cross_attention=only_cross_attention[i],
523
+ upcast_attention=upcast_attention,
524
+ resnet_time_scale_shift=resnet_time_scale_shift,
525
+ )
526
+ self.up_blocks.append(up_block)
527
+ prev_output_channel = output_channel
528
+
529
+ # out
530
+ if norm_num_groups is not None:
531
+ self.conv_norm_out = nn.GroupNorm(
532
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
533
+ )
534
+
535
+ self.conv_act = get_activation(act_fn)
536
+
537
+ else:
538
+ self.conv_norm_out = None
539
+ self.conv_act = None
540
+
541
+ conv_out_padding = (conv_out_kernel - 1) // 2
542
+ self.conv_out = nn.Conv2d(
543
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
544
+ )
545
+
546
+ @property
547
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
548
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
549
+ r"""
550
+ Returns:
551
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
552
+ indexed by its weight name.
553
+ """
554
+ # set recursively
555
+ processors = {}
556
+
557
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
558
+ if hasattr(module, "get_processor"):
559
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
560
+
561
+ for sub_name, child in module.named_children():
562
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
563
+
564
+ return processors
565
+
566
+ for name, module in self.named_children():
567
+ fn_recursive_add_processors(name, module, processors)
568
+
569
+ return processors
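+ # Keys are the dotted module paths with a ".processor" suffix, e.g. something like
+ # "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor" (illustrative path).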
570
+
571
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
572
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
573
+ r"""
574
+ Sets the attention processor to use to compute attention.
575
+
576
+ Parameters:
577
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
578
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
579
+ for **all** `Attention` layers.
580
+
581
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
582
+ processor. This is strongly recommended when setting trainable attention processors.
583
+
584
+ """
585
+ count = len(self.attn_processors.keys())
586
+
587
+ if isinstance(processor, dict) and len(processor) != count:
588
+ raise ValueError(
589
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
590
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
591
+ )
592
+
593
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
594
+ if hasattr(module, "set_processor"):
595
+ if not isinstance(processor, dict):
596
+ module.set_processor(processor)
597
+ else:
598
+ module.set_processor(processor.pop(f"{name}.processor"))
599
+
600
+ for sub_name, child in module.named_children():
601
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
602
+
603
+ for name, module in self.named_children():
604
+ fn_recursive_attn_processor(name, module, processor)
605
+ # print(f"{processor}, Type: {type(processor)}")
606
+
607
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
608
+ def set_default_attn_processor(self):
609
+ """
610
+ Disables custom attention processors and sets the default attention implementation.
611
+ """
612
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
613
+ processor = AttnAddedKVProcessor()
614
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
615
+ processor = AttnProcessor()
616
+ else:
617
+ raise ValueError(
618
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
619
+ )
620
+
621
+ self.set_attn_processor(processor)
622
+
623
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
624
+ def set_attention_slice(self, slice_size):
625
+ r"""
626
+ Enable sliced attention computation.
627
+
628
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
629
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
630
+
631
+ Args:
632
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
633
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
634
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
635
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
636
+ must be a multiple of `slice_size`.
637
+ """
638
+ sliceable_head_dims = []
639
+
640
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
641
+ if hasattr(module, "set_attention_slice"):
642
+ sliceable_head_dims.append(module.sliceable_head_dim)
643
+
644
+ for child in module.children():
645
+ fn_recursive_retrieve_sliceable_dims(child)
646
+
647
+ # retrieve number of attention layers
648
+ for module in self.children():
649
+ fn_recursive_retrieve_sliceable_dims(module)
650
+
651
+ num_sliceable_layers = len(sliceable_head_dims)
652
+
653
+ if slice_size == "auto":
654
+ # half the attention head size is usually a good trade-off between
655
+ # speed and memory
656
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
657
+ elif slice_size == "max":
658
+ # make smallest slice possible
659
+ slice_size = num_sliceable_layers * [1]
660
+
661
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
662
+
663
+ if len(slice_size) != len(sliceable_head_dims):
664
+ raise ValueError(
665
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
666
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
667
+ )
668
+
669
+ for i in range(len(slice_size)):
670
+ size = slice_size[i]
671
+ dim = sliceable_head_dims[i]
672
+ if size is not None and size > dim:
673
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
674
+
675
+ # Recursively walk through all the children.
676
+ # Any children which exposes the set_attention_slice method
677
+ # gets the message
678
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
679
+ if hasattr(module, "set_attention_slice"):
680
+ module.set_attention_slice(slice_size.pop())
681
+
682
+ for child in module.children():
683
+ fn_recursive_set_attention_slice(child, slice_size)
684
+
685
+ reversed_slice_size = list(reversed(slice_size))
686
+ for module in self.children():
687
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
688
+
689
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel._set_gradient_checkpointing
690
+ def _set_gradient_checkpointing(self, module, value=False):
691
+ if hasattr(module, "gradient_checkpointing"):
692
+ module.gradient_checkpointing = value
693
+
694
+ def forward(
695
+ self,
696
+ sample: torch.FloatTensor,
697
+ timestep: Union[torch.Tensor, float, int],
698
+ encoder_hidden_states: torch.Tensor,
699
+ class_labels: Optional[torch.Tensor] = None,
700
+ timestep_cond: Optional[torch.Tensor] = None,
701
+ attention_mask: Optional[torch.Tensor] = None,
702
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
703
+ encoder_attention_mask: Optional[torch.Tensor] = None,
704
+ return_dict: bool = True,
705
+ encoder_hidden_states_1: Optional[torch.Tensor] = None,
706
+ encoder_attention_mask_1: Optional[torch.Tensor] = None,
707
+ ) -> Union[UNet2DConditionOutput, Tuple]:
708
+ r"""
709
+ The [`AudioLDM2UNet2DConditionModel`] forward method.
710
+
711
+ Args:
712
+ sample (`torch.FloatTensor`):
713
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
714
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
715
+ encoder_hidden_states (`torch.FloatTensor`):
716
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
717
+ encoder_attention_mask (`torch.Tensor`):
718
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
719
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
720
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
721
+ return_dict (`bool`, *optional*, defaults to `True`):
722
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
723
+ tuple.
724
+ cross_attention_kwargs (`dict`, *optional*):
725
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
726
+ encoder_hidden_states_1 (`torch.FloatTensor`, *optional*):
727
+ A second set of encoder hidden states with shape `(batch, sequence_length_2, feature_dim_2)`. Can be
728
+ used to condition the model on a different set of embeddings to `encoder_hidden_states`.
729
+ encoder_attention_mask_1 (`torch.Tensor`, *optional*):
730
+ A cross-attention mask of shape `(batch, sequence_length_2)` is applied to `encoder_hidden_states_1`.
731
+ If `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
732
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
733
+
734
+ Returns:
735
+ [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
736
+ If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
737
+ a `tuple` is returned where the first element is the sample tensor.
738
+ """
739
+ # By default samples have to be at least a multiple of the overall upsampling factor.
740
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
741
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
742
+ # on the fly if necessary.
743
+ default_overall_up_factor = 2**self.num_upsamplers
744
+
745
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
746
+ forward_upsample_size = False
747
+ upsample_size = None
748
+
749
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
750
+ logger.info("Forward upsample size to force interpolation output size.")
751
+ forward_upsample_size = True
752
+
753
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
754
+ # expects mask of shape:
755
+ # [batch, key_tokens]
756
+ # adds singleton query_tokens dimension:
757
+ # [batch, 1, key_tokens]
758
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
759
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
760
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
761
+ # if attention_mask is not None:
762
+ # # assume that mask is expressed as:
763
+ # # (1 = keep, 0 = discard)
764
+ # # convert mask into a bias that can be added to attention scores:
765
+ # # (keep = +0, discard = -10000.0)
766
+ # print("type of attention_mask",type(attention_mask))
767
+ # print("attention_mask",attention_mask)
768
+ # attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
769
+ # attention_mask = attention_mask.unsqueeze(1)
770
+
771
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
772
+ if encoder_attention_mask is not None:
773
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
774
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
775
+
776
+ if encoder_attention_mask_1 is not None:
777
+ encoder_attention_mask_1 = (1 - encoder_attention_mask_1.to(sample.dtype)) * -10000.0
778
+ encoder_attention_mask_1 = encoder_attention_mask_1.unsqueeze(1)
779
+
780
+ # 1. time
781
+ timesteps = timestep
782
+ if not torch.is_tensor(timesteps):
783
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
784
+ # This would be a good case for the `match` statement (Python 3.10+)
785
+ is_mps = sample.device.type == "mps"
786
+ if isinstance(timestep, float):
787
+ dtype = torch.float32 if is_mps else torch.float64
788
+ else:
789
+ dtype = torch.int32 if is_mps else torch.int64
790
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
791
+ elif len(timesteps.shape) == 0:
792
+ timesteps = timesteps[None].to(sample.device)
793
+
794
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
795
+ timesteps = timesteps.expand(sample.shape[0])
796
+
797
+ t_emb = self.time_proj(timesteps)
798
+
799
+ # `Timesteps` does not contain any weights and will always return f32 tensors
800
+ # but time_embedding might actually be running in fp16. so we need to cast here.
801
+ # there might be better ways to encapsulate this.
802
+ t_emb = t_emb.to(dtype=sample.dtype)
803
+
804
+ emb = self.time_embedding(t_emb, timestep_cond)
805
+ aug_emb = None
806
+
807
+ if self.class_embedding is not None:
808
+ if class_labels is None:
809
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
810
+
811
+ if self.config.class_embed_type == "timestep":
812
+ class_labels = self.time_proj(class_labels)
813
+
814
+ # `Timesteps` does not contain any weights and will always return f32 tensors
815
+ # there might be better ways to encapsulate this.
816
+ class_labels = class_labels.to(dtype=sample.dtype)
817
+
818
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
819
+
820
+ if self.config.class_embeddings_concat:
821
+ emb = torch.cat([emb, class_emb], dim=-1)
822
+ else:
823
+ emb = emb + class_emb
824
+
825
+ emb = emb + aug_emb if aug_emb is not None else emb
826
+
827
+ if self.time_embed_act is not None:
828
+ emb = self.time_embed_act(emb)
829
+
830
+ # 2. pre-process
831
+ sample = self.conv_in(sample)
832
+
833
+ # 3. down
834
+ down_block_res_samples = (sample,)
835
+ for downsample_block in self.down_blocks:
836
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
837
+ sample, res_samples = downsample_block(
838
+ hidden_states=sample,
839
+ temb=emb,
840
+ encoder_hidden_states=encoder_hidden_states,
841
+ attention_mask=attention_mask,
842
+ cross_attention_kwargs=cross_attention_kwargs,
843
+ encoder_attention_mask=encoder_attention_mask,
844
+ encoder_hidden_states_1=encoder_hidden_states_1,
845
+ encoder_attention_mask_1=encoder_attention_mask_1,
846
+ )
847
+ else:
848
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
849
+
850
+ down_block_res_samples += res_samples
851
+
852
+ # 4. mid
853
+ if self.mid_block is not None:
854
+ sample = self.mid_block(
855
+ sample,
856
+ emb,
857
+ encoder_hidden_states=encoder_hidden_states,
858
+ attention_mask=attention_mask,
859
+ cross_attention_kwargs=cross_attention_kwargs,
860
+ encoder_attention_mask=encoder_attention_mask,
861
+ encoder_hidden_states_1=encoder_hidden_states_1,
862
+ encoder_attention_mask_1=encoder_attention_mask_1,
863
+ )
864
+
865
+ # 5. up
866
+ for i, upsample_block in enumerate(self.up_blocks):
867
+ is_final_block = i == len(self.up_blocks) - 1
868
+
869
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
870
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
871
+
872
+ # if we have not reached the final block and need to forward the
873
+ # upsample size, we do it here
874
+ if not is_final_block and forward_upsample_size:
875
+ upsample_size = down_block_res_samples[-1].shape[2:]
876
+
877
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
878
+ sample = upsample_block(
879
+ hidden_states=sample,
880
+ temb=emb,
881
+ res_hidden_states_tuple=res_samples,
882
+ encoder_hidden_states=encoder_hidden_states,
883
+ cross_attention_kwargs=cross_attention_kwargs,
884
+ upsample_size=upsample_size,
885
+ attention_mask=attention_mask,
886
+ encoder_attention_mask=encoder_attention_mask,
887
+ encoder_hidden_states_1=encoder_hidden_states_1,
888
+ encoder_attention_mask_1=encoder_attention_mask_1,
889
+ )
890
+ else:
891
+ sample = upsample_block(
892
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
893
+ )
894
+
895
+ # 6. post-process
896
+ if self.conv_norm_out:
897
+ sample = self.conv_norm_out(sample)
898
+ sample = self.conv_act(sample)
899
+ sample = self.conv_out(sample)
900
+ # re-wrap the decoded sample as a leaf tensor that tracks gradients
+ sample = sample.detach().clone().requires_grad_(True)
902
+ if not return_dict:
903
+ return (sample,)
904
+
905
+ return UNet2DConditionOutput(sample=sample)
906
+
907
+
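+ # Illustrative call of the forward method above (editor's sketch; the shapes below are
+ # placeholders, not values taken from this repository's config). The AudioLDM2 pipeline feeds
+ # the GPT-2 generated embeddings as `encoder_hidden_states` (dim 768) and the projected
+ # T5/CLAP embeddings as `encoder_hidden_states_1` (dim 1024):
+ #
+ # noisy_latents = torch.randn(1, 8, 256, 16)
+ # out = unet(
+ #     noisy_latents,
+ #     timestep=torch.tensor([10]),
+ #     encoder_hidden_states=torch.randn(1, 8, 768),
+ #     encoder_hidden_states_1=torch.randn(1, 32, 1024),
+ # ).sample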
908
+ def get_down_block(
909
+ down_block_type,
910
+ num_layers,
911
+ in_channels,
912
+ out_channels,
913
+ temb_channels,
914
+ add_downsample,
915
+ resnet_eps,
916
+ resnet_act_fn,
917
+ transformer_layers_per_block=1,
918
+ num_attention_heads=None,
919
+ resnet_groups=None,
920
+ cross_attention_dim=None,
921
+ downsample_padding=None,
922
+ use_linear_projection=False,
923
+ only_cross_attention=False,
924
+ upcast_attention=False,
925
+ resnet_time_scale_shift="default",
926
+ ):
927
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
928
+ if down_block_type == "DownBlock2D":
929
+ return DownBlock2D(
930
+ num_layers=num_layers,
931
+ in_channels=in_channels,
932
+ out_channels=out_channels,
933
+ temb_channels=temb_channels,
934
+ add_downsample=add_downsample,
935
+ resnet_eps=resnet_eps,
936
+ resnet_act_fn=resnet_act_fn,
937
+ resnet_groups=resnet_groups,
938
+ downsample_padding=downsample_padding,
939
+ resnet_time_scale_shift=resnet_time_scale_shift,
940
+ )
941
+ elif down_block_type == "CrossAttnDownBlock2D":
942
+ if cross_attention_dim is None:
943
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
944
+ return CrossAttnDownBlock2D(
945
+ num_layers=num_layers,
946
+ transformer_layers_per_block=transformer_layers_per_block,
947
+ in_channels=in_channels,
948
+ out_channels=out_channels,
949
+ temb_channels=temb_channels,
950
+ add_downsample=add_downsample,
951
+ resnet_eps=resnet_eps,
952
+ resnet_act_fn=resnet_act_fn,
953
+ resnet_groups=resnet_groups,
954
+ downsample_padding=downsample_padding,
955
+ cross_attention_dim=cross_attention_dim,
956
+ num_attention_heads=num_attention_heads,
957
+ use_linear_projection=use_linear_projection,
958
+ only_cross_attention=only_cross_attention,
959
+ upcast_attention=upcast_attention,
960
+ resnet_time_scale_shift=resnet_time_scale_shift,
961
+ )
962
+ raise ValueError(f"{down_block_type} does not exist.")
963
+
964
+
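+ # Illustrative use of get_down_block (editor's sketch; every argument value below is a
+ # placeholder, not taken from the original config): a cross-attention down block with two
+ # conditioning streams, one 768-dim and one 1024-dim.
+ #
+ # block = get_down_block(
+ #     "CrossAttnDownBlock2D",
+ #     num_layers=2, in_channels=320, out_channels=640, temb_channels=1280,
+ #     add_downsample=True, resnet_eps=1e-5, resnet_act_fn="silu",
+ #     num_attention_heads=8, cross_attention_dim=[768, 1024],
+ # )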
965
+ def get_up_block(
966
+ up_block_type,
967
+ num_layers,
968
+ in_channels,
969
+ out_channels,
970
+ prev_output_channel,
971
+ temb_channels,
972
+ add_upsample,
973
+ resnet_eps,
974
+ resnet_act_fn,
975
+ transformer_layers_per_block=1,
976
+ num_attention_heads=None,
977
+ resnet_groups=None,
978
+ cross_attention_dim=None,
979
+ use_linear_projection=False,
980
+ only_cross_attention=False,
981
+ upcast_attention=False,
982
+ resnet_time_scale_shift="default",
983
+ ):
984
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
985
+ if up_block_type == "UpBlock2D":
986
+ return UpBlock2D(
987
+ num_layers=num_layers,
988
+ in_channels=in_channels,
989
+ out_channels=out_channels,
990
+ prev_output_channel=prev_output_channel,
991
+ temb_channels=temb_channels,
992
+ add_upsample=add_upsample,
993
+ resnet_eps=resnet_eps,
994
+ resnet_act_fn=resnet_act_fn,
995
+ resnet_groups=resnet_groups,
996
+ resnet_time_scale_shift=resnet_time_scale_shift,
997
+ )
998
+ elif up_block_type == "CrossAttnUpBlock2D":
999
+ if cross_attention_dim is None:
1000
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
1001
+ return CrossAttnUpBlock2D(
1002
+ num_layers=num_layers,
1003
+ transformer_layers_per_block=transformer_layers_per_block,
1004
+ in_channels=in_channels,
1005
+ out_channels=out_channels,
1006
+ prev_output_channel=prev_output_channel,
1007
+ temb_channels=temb_channels,
1008
+ add_upsample=add_upsample,
1009
+ resnet_eps=resnet_eps,
1010
+ resnet_act_fn=resnet_act_fn,
1011
+ resnet_groups=resnet_groups,
1012
+ cross_attention_dim=cross_attention_dim,
1013
+ num_attention_heads=num_attention_heads,
1014
+ use_linear_projection=use_linear_projection,
1015
+ only_cross_attention=only_cross_attention,
1016
+ upcast_attention=upcast_attention,
1017
+ resnet_time_scale_shift=resnet_time_scale_shift,
1018
+ )
1019
+ raise ValueError(f"{up_block_type} does not exist.")
1020
+
1021
+
1022
+ class CrossAttnDownBlock2D(nn.Module):
1023
+ def __init__(
1024
+ self,
1025
+ in_channels: int,
1026
+ out_channels: int,
1027
+ temb_channels: int,
1028
+ dropout: float = 0.0,
1029
+ num_layers: int = 1,
1030
+ transformer_layers_per_block: int = 1,
1031
+ resnet_eps: float = 1e-6,
1032
+ resnet_time_scale_shift: str = "default",
1033
+ resnet_act_fn: str = "swish",
1034
+ resnet_groups: int = 32,
1035
+ resnet_pre_norm: bool = True,
1036
+ num_attention_heads=1,
1037
+ cross_attention_dim=1280,
1038
+ output_scale_factor=1.0,
1039
+ downsample_padding=1,
1040
+ add_downsample=True,
1041
+ use_linear_projection=False,
1042
+ only_cross_attention=False,
1043
+ upcast_attention=False,
1044
+ ):
1045
+ super().__init__()
1046
+ resnets = []
1047
+ attentions = []
1048
+
1049
+ self.has_cross_attention = True
1050
+ self.num_attention_heads = num_attention_heads
1051
+
1052
+ if isinstance(cross_attention_dim, int):
1053
+ cross_attention_dim = (cross_attention_dim,)
1054
+ if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) > 4:
1055
+ raise ValueError(
1056
+ "Only up to 4 cross-attention layers are supported. Ensure that the length of cross-attention "
1057
+ f"dims is less than or equal to 4. Got cross-attention dims {cross_attention_dim} of length {len(cross_attention_dim)}"
1058
+ )
1059
+ self.cross_attention_dim = cross_attention_dim
1060
+
1061
+ for i in range(num_layers):
1062
+ in_channels = in_channels if i == 0 else out_channels
1063
+ resnets.append(
1064
+ ResnetBlock2D(
1065
+ in_channels=in_channels,
1066
+ out_channels=out_channels,
1067
+ temb_channels=temb_channels,
1068
+ eps=resnet_eps,
1069
+ groups=resnet_groups,
1070
+ dropout=dropout,
1071
+ time_embedding_norm=resnet_time_scale_shift,
1072
+ non_linearity=resnet_act_fn,
1073
+ output_scale_factor=output_scale_factor,
1074
+ pre_norm=resnet_pre_norm,
1075
+ )
1076
+ )
1077
+ for j in range(len(cross_attention_dim)):
1078
+ attentions.append(
1079
+ Transformer2DModel(
1080
+ num_attention_heads,
1081
+ out_channels // num_attention_heads,
1082
+ in_channels=out_channels,
1083
+ num_layers=transformer_layers_per_block,
1084
+ cross_attention_dim=cross_attention_dim[j],
1085
+ norm_num_groups=resnet_groups,
1086
+ use_linear_projection=use_linear_projection,
1087
+ only_cross_attention=only_cross_attention,
1088
+ upcast_attention=upcast_attention,
1089
+ double_self_attention=True if cross_attention_dim[j] is None else False,
1090
+ )
1091
+ )
1092
+ self.attentions = nn.ModuleList(attentions)
1093
+ self.resnets = nn.ModuleList(resnets)
1094
+
1095
+ if add_downsample:
1096
+ self.downsamplers = nn.ModuleList(
1097
+ [
1098
+ Downsample2D(
1099
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
1100
+ )
1101
+ ]
1102
+ )
1103
+ else:
1104
+ self.downsamplers = None
1105
+
1106
+ self.gradient_checkpointing = False
1107
+
1108
+ def forward(
1109
+ self,
1110
+ hidden_states: torch.FloatTensor,
1111
+ temb: Optional[torch.FloatTensor] = None,
1112
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1113
+ attention_mask: Optional[torch.FloatTensor] = None,
1114
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1115
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1116
+ encoder_hidden_states_1: Optional[torch.FloatTensor] = None,
1117
+ encoder_attention_mask_1: Optional[torch.FloatTensor] = None,
1118
+ ):
1119
+ output_states = ()
1120
+ num_layers = len(self.resnets)
1121
+ num_attention_per_layer = len(self.attentions) // num_layers
1122
+
1123
+ encoder_hidden_states_1 = (
1124
+ encoder_hidden_states_1 if encoder_hidden_states_1 is not None else encoder_hidden_states
1125
+ )
1126
+ encoder_attention_mask_1 = (
1127
+ encoder_attention_mask_1 if encoder_hidden_states_1 is not None else encoder_attention_mask
1128
+ )
1129
+
1130
+ for i in range(num_layers):
1131
+ if self.training and self.gradient_checkpointing:
1132
+
1133
+ def create_custom_forward(module, return_dict=None):
1134
+ def custom_forward(*inputs):
1135
+ if return_dict is not None:
1136
+ return module(*inputs, return_dict=return_dict)
1137
+ else:
1138
+ return module(*inputs)
1139
+
1140
+ return custom_forward
1141
+
1142
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1143
+ hidden_states = torch.utils.checkpoint.checkpoint(
1144
+ create_custom_forward(self.resnets[i]),
1145
+ hidden_states,
1146
+ temb,
1147
+ **ckpt_kwargs,
1148
+ )
1149
+ for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
1150
+ if cross_attention_dim is not None and idx <= 1:
1151
+ forward_encoder_hidden_states = encoder_hidden_states
1152
+ forward_encoder_attention_mask = encoder_attention_mask
1153
+ elif cross_attention_dim is not None and idx > 1:
1154
+ forward_encoder_hidden_states = encoder_hidden_states_1
1155
+ forward_encoder_attention_mask = encoder_attention_mask_1
1156
+ else:
1157
+ forward_encoder_hidden_states = None
1158
+ forward_encoder_attention_mask = None
1159
+ hidden_states = torch.utils.checkpoint.checkpoint(
1160
+ create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False),
1161
+ hidden_states,
1162
+ forward_encoder_hidden_states,
1163
+ None, # timestep
1164
+ None, # class_labels
1165
+ cross_attention_kwargs,
1166
+ attention_mask,
1167
+ forward_encoder_attention_mask,
1168
+ **ckpt_kwargs,
1169
+ )[0]
1170
+ else:
1171
+ hidden_states = self.resnets[i](hidden_states, temb)
1172
+ for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
1173
+ if cross_attention_dim is not None and idx <= 1:
1174
+ forward_encoder_hidden_states = encoder_hidden_states
1175
+ forward_encoder_attention_mask = encoder_attention_mask
1176
+ elif cross_attention_dim is not None and idx > 1:
1177
+ forward_encoder_hidden_states = encoder_hidden_states_1
1178
+ forward_encoder_attention_mask = encoder_attention_mask_1
1179
+ else:
1180
+ forward_encoder_hidden_states = None
1181
+ forward_encoder_attention_mask = None
1182
+ hidden_states = self.attentions[i * num_attention_per_layer + idx](
1183
+ hidden_states,
1184
+ attention_mask=attention_mask,
1185
+ encoder_hidden_states=forward_encoder_hidden_states,
1186
+ encoder_attention_mask=forward_encoder_attention_mask,
1187
+ return_dict=False,
1188
+ )[0]
1189
+
1190
+ output_states = output_states + (hidden_states,)
1191
+
1192
+ if self.downsamplers is not None:
1193
+ for downsampler in self.downsamplers:
1194
+ hidden_states = downsampler(hidden_states)
1195
+
1196
+ output_states = output_states + (hidden_states,)
1197
+
1198
+ return hidden_states, output_states
1199
+
1200
+
1201
+ class UNetMidBlock2DCrossAttn(nn.Module):
1202
+ def __init__(
1203
+ self,
1204
+ in_channels: int,
1205
+ temb_channels: int,
1206
+ dropout: float = 0.0,
1207
+ num_layers: int = 1,
1208
+ transformer_layers_per_block: int = 1,
1209
+ resnet_eps: float = 1e-6,
1210
+ resnet_time_scale_shift: str = "default",
1211
+ resnet_act_fn: str = "swish",
1212
+ resnet_groups: int = 32,
1213
+ resnet_pre_norm: bool = True,
1214
+ num_attention_heads=1,
1215
+ output_scale_factor=1.0,
1216
+ cross_attention_dim=1280,
1217
+ use_linear_projection=False,
1218
+ upcast_attention=False,
1219
+ ):
1220
+ super().__init__()
1221
+
1222
+ self.has_cross_attention = True
1223
+ self.num_attention_heads = num_attention_heads
1224
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
1225
+
1226
+ if isinstance(cross_attention_dim, int):
1227
+ cross_attention_dim = (cross_attention_dim,)
1228
+ if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) > 4:
1229
+ raise ValueError(
1230
+ "Only up to 4 cross-attention layers are supported. Ensure that the length of cross-attention "
1231
+ f"dims is less than or equal to 4. Got cross-attention dims {cross_attention_dim} of length {len(cross_attention_dim)}"
1232
+ )
1233
+ self.cross_attention_dim = cross_attention_dim
1234
+
1235
+ # there is always at least one resnet
1236
+ resnets = [
1237
+ ResnetBlock2D(
1238
+ in_channels=in_channels,
1239
+ out_channels=in_channels,
1240
+ temb_channels=temb_channels,
1241
+ eps=resnet_eps,
1242
+ groups=resnet_groups,
1243
+ dropout=dropout,
1244
+ time_embedding_norm=resnet_time_scale_shift,
1245
+ non_linearity=resnet_act_fn,
1246
+ output_scale_factor=output_scale_factor,
1247
+ pre_norm=resnet_pre_norm,
1248
+ )
1249
+ ]
1250
+ attentions = []
1251
+
1252
+ for i in range(num_layers):
1253
+ for j in range(len(cross_attention_dim)):
1254
+ attentions.append(
1255
+ Transformer2DModel(
1256
+ num_attention_heads,
1257
+ in_channels // num_attention_heads,
1258
+ in_channels=in_channels,
1259
+ num_layers=transformer_layers_per_block,
1260
+ cross_attention_dim=cross_attention_dim[j],
1261
+ norm_num_groups=resnet_groups,
1262
+ use_linear_projection=use_linear_projection,
1263
+ upcast_attention=upcast_attention,
1264
+ double_self_attention=True if cross_attention_dim[j] is None else False,
1265
+ )
1266
+ )
1267
+ resnets.append(
1268
+ ResnetBlock2D(
1269
+ in_channels=in_channels,
1270
+ out_channels=in_channels,
1271
+ temb_channels=temb_channels,
1272
+ eps=resnet_eps,
1273
+ groups=resnet_groups,
1274
+ dropout=dropout,
1275
+ time_embedding_norm=resnet_time_scale_shift,
1276
+ non_linearity=resnet_act_fn,
1277
+ output_scale_factor=output_scale_factor,
1278
+ pre_norm=resnet_pre_norm,
1279
+ )
1280
+ )
1281
+
1282
+ self.attentions = nn.ModuleList(attentions)
1283
+ self.resnets = nn.ModuleList(resnets)
1284
+
1285
+ self.gradient_checkpointing = False
1286
+
1287
+ def forward(
1288
+ self,
1289
+ hidden_states: torch.FloatTensor,
1290
+ temb: Optional[torch.FloatTensor] = None,
1291
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1292
+ attention_mask: Optional[torch.FloatTensor] = None,
1293
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1294
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1295
+ encoder_hidden_states_1: Optional[torch.FloatTensor] = None,
1296
+ encoder_attention_mask_1: Optional[torch.FloatTensor] = None,
1297
+ ) -> torch.FloatTensor:
1298
+ hidden_states = self.resnets[0](hidden_states, temb)
1299
+ num_attention_per_layer = len(self.attentions) // (len(self.resnets) - 1)
1300
+
1301
+ encoder_hidden_states_1 = (
1302
+ encoder_hidden_states_1 if encoder_hidden_states_1 is not None else encoder_hidden_states
1303
+ )
1304
+ encoder_attention_mask_1 = (
1305
+ encoder_attention_mask_1 if encoder_hidden_states_1 is not None else encoder_attention_mask
1306
+ )
1307
+
1308
+ for i in range(len(self.resnets[1:])):
1309
+ if self.training and self.gradient_checkpointing:
1310
+
1311
+ def create_custom_forward(module, return_dict=None):
1312
+ def custom_forward(*inputs):
1313
+ if return_dict is not None:
1314
+ return module(*inputs, return_dict=return_dict)
1315
+ else:
1316
+ return module(*inputs)
1317
+
1318
+ return custom_forward
1319
+
1320
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1321
+ for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
1322
+ if cross_attention_dim is not None and idx <= 1:
1323
+ forward_encoder_hidden_states = encoder_hidden_states
1324
+ forward_encoder_attention_mask = encoder_attention_mask
1325
+ elif cross_attention_dim is not None and idx > 1:
1326
+ forward_encoder_hidden_states = encoder_hidden_states_1
1327
+ forward_encoder_attention_mask = encoder_attention_mask_1
1328
+ else:
1329
+ forward_encoder_hidden_states = None
1330
+ forward_encoder_attention_mask = None
1331
+ hidden_states = torch.utils.checkpoint.checkpoint(
1332
+ create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False),
1333
+ hidden_states,
1334
+ forward_encoder_hidden_states,
1335
+ None, # timestep
1336
+ None, # class_labels
1337
+ cross_attention_kwargs,
1338
+ attention_mask,
1339
+ forward_encoder_attention_mask,
1340
+ **ckpt_kwargs,
1341
+ )[0]
1342
+ hidden_states = torch.utils.checkpoint.checkpoint(
1343
+ create_custom_forward(self.resnets[i + 1]),
1344
+ hidden_states,
1345
+ temb,
1346
+ **ckpt_kwargs,
1347
+ )
1348
+ else:
1349
+ for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
1350
+ if cross_attention_dim is not None and idx <= 1:
1351
+ forward_encoder_hidden_states = encoder_hidden_states
1352
+ forward_encoder_attention_mask = encoder_attention_mask
1353
+ elif cross_attention_dim is not None and idx > 1:
1354
+ forward_encoder_hidden_states = encoder_hidden_states_1
1355
+ forward_encoder_attention_mask = encoder_attention_mask_1
1356
+ else:
1357
+ forward_encoder_hidden_states = None
1358
+ forward_encoder_attention_mask = None
1359
+ hidden_states = self.attentions[i * num_attention_per_layer + idx](
1360
+ hidden_states,
1361
+ attention_mask=attention_mask,
1362
+ encoder_hidden_states=forward_encoder_hidden_states,
1363
+ encoder_attention_mask=forward_encoder_attention_mask,
1364
+ return_dict=False,
1365
+ )[0]
1366
+
1367
+ hidden_states = self.resnets[i + 1](hidden_states, temb)
1368
+
1369
+ return hidden_states
1370
+
1371
+
1372
+ class CrossAttnUpBlock2D(nn.Module):
1373
+ def __init__(
1374
+ self,
1375
+ in_channels: int,
1376
+ out_channels: int,
1377
+ prev_output_channel: int,
1378
+ temb_channels: int,
1379
+ dropout: float = 0.0,
1380
+ num_layers: int = 1,
1381
+ transformer_layers_per_block: int = 1,
1382
+ resnet_eps: float = 1e-6,
1383
+ resnet_time_scale_shift: str = "default",
1384
+ resnet_act_fn: str = "swish",
1385
+ resnet_groups: int = 32,
1386
+ resnet_pre_norm: bool = True,
1387
+ num_attention_heads=1,
1388
+ cross_attention_dim=1280,
1389
+ output_scale_factor=1.0,
1390
+ add_upsample=True,
1391
+ use_linear_projection=False,
1392
+ only_cross_attention=False,
1393
+ upcast_attention=False,
1394
+ ):
1395
+ super().__init__()
1396
+ resnets = []
1397
+ attentions = []
1398
+
1399
+ self.has_cross_attention = True
1400
+ self.num_attention_heads = num_attention_heads
1401
+
1402
+ if isinstance(cross_attention_dim, int):
1403
+ cross_attention_dim = (cross_attention_dim,)
1404
+ if isinstance(cross_attention_dim, (list, tuple)) and len(cross_attention_dim) > 4:
1405
+ raise ValueError(
1406
+ "Only up to 4 cross-attention layers are supported. Ensure that the length of cross-attention "
1407
+ f"dims is less than or equal to 4. Got cross-attention dims {cross_attention_dim} of length {len(cross_attention_dim)}"
1408
+ )
1409
+ self.cross_attention_dim = cross_attention_dim
1410
+
1411
+ for i in range(num_layers):
1412
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1413
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
1414
+
1415
+ resnets.append(
1416
+ ResnetBlock2D(
1417
+ in_channels=resnet_in_channels + res_skip_channels,
1418
+ out_channels=out_channels,
1419
+ temb_channels=temb_channels,
1420
+ eps=resnet_eps,
1421
+ groups=resnet_groups,
1422
+ dropout=dropout,
1423
+ time_embedding_norm=resnet_time_scale_shift,
1424
+ non_linearity=resnet_act_fn,
1425
+ output_scale_factor=output_scale_factor,
1426
+ pre_norm=resnet_pre_norm,
1427
+ )
1428
+ )
1429
+ for j in range(len(cross_attention_dim)):
1430
+ attentions.append(
1431
+ Transformer2DModel(
1432
+ num_attention_heads,
1433
+ out_channels // num_attention_heads,
1434
+ in_channels=out_channels,
1435
+ num_layers=transformer_layers_per_block,
1436
+ cross_attention_dim=cross_attention_dim[j],
1437
+ norm_num_groups=resnet_groups,
1438
+ use_linear_projection=use_linear_projection,
1439
+ only_cross_attention=only_cross_attention,
1440
+ upcast_attention=upcast_attention,
1441
+ double_self_attention=True if cross_attention_dim[j] is None else False,
1442
+ )
1443
+ )
1444
+ self.attentions = nn.ModuleList(attentions)
1445
+ self.resnets = nn.ModuleList(resnets)
1446
+
1447
+ if add_upsample:
1448
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1449
+ else:
1450
+ self.upsamplers = None
1451
+
1452
+ self.gradient_checkpointing = False
1453
+
1454
+ def forward(
1455
+ self,
1456
+ hidden_states: torch.FloatTensor,
1457
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
1458
+ temb: Optional[torch.FloatTensor] = None,
1459
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1460
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1461
+ upsample_size: Optional[int] = None,
1462
+ attention_mask: Optional[torch.FloatTensor] = None,
1463
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1464
+ encoder_hidden_states_1: Optional[torch.FloatTensor] = None,
1465
+ encoder_attention_mask_1: Optional[torch.FloatTensor] = None,
1466
+ ):
1467
+ num_layers = len(self.resnets)
1468
+ num_attention_per_layer = len(self.attentions) // num_layers
1469
+
1470
+ encoder_hidden_states_1 = (
1471
+ encoder_hidden_states_1 if encoder_hidden_states_1 is not None else encoder_hidden_states
1472
+ )
1473
+ encoder_attention_mask_1 = (
1474
+ encoder_attention_mask_1 if encoder_hidden_states_1 is not None else encoder_attention_mask
1475
+ )
1476
+
1477
+ for i in range(num_layers):
1478
+ # pop res hidden states
1479
+ res_hidden_states = res_hidden_states_tuple[-1]
1480
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1481
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1482
+
1483
+ if self.training and self.gradient_checkpointing:
1484
+
1485
+ def create_custom_forward(module, return_dict=None):
1486
+ def custom_forward(*inputs):
1487
+ if return_dict is not None:
1488
+ return module(*inputs, return_dict=return_dict)
1489
+ else:
1490
+ return module(*inputs)
1491
+
1492
+ return custom_forward
1493
+
1494
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1495
+ hidden_states = torch.utils.checkpoint.checkpoint(
1496
+ create_custom_forward(self.resnets[i]),
1497
+ hidden_states,
1498
+ temb,
1499
+ **ckpt_kwargs,
1500
+ )
1501
+ for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
1502
+ if cross_attention_dim is not None and idx <= 1:
1503
+ forward_encoder_hidden_states = encoder_hidden_states
1504
+ forward_encoder_attention_mask = encoder_attention_mask
1505
+ elif cross_attention_dim is not None and idx > 1:
1506
+ forward_encoder_hidden_states = encoder_hidden_states_1
1507
+ forward_encoder_attention_mask = encoder_attention_mask_1
1508
+ else:
1509
+ forward_encoder_hidden_states = None
1510
+ forward_encoder_attention_mask = None
1511
+ hidden_states = torch.utils.checkpoint.checkpoint(
1512
+ create_custom_forward(self.attentions[i * num_attention_per_layer + idx], return_dict=False),
1513
+ hidden_states,
1514
+ forward_encoder_hidden_states,
1515
+ None, # timestep
1516
+ None, # class_labels
1517
+ cross_attention_kwargs,
1518
+ attention_mask,
1519
+ forward_encoder_attention_mask,
1520
+ **ckpt_kwargs,
1521
+ )[0]
1522
+ else:
1523
+ hidden_states = self.resnets[i](hidden_states, temb)
1524
+ for idx, cross_attention_dim in enumerate(self.cross_attention_dim):
1525
+ if cross_attention_dim is not None and idx <= 1:
1526
+ forward_encoder_hidden_states = encoder_hidden_states
1527
+ forward_encoder_attention_mask = encoder_attention_mask
1528
+ elif cross_attention_dim is not None and idx > 1:
1529
+ forward_encoder_hidden_states = encoder_hidden_states_1
1530
+ forward_encoder_attention_mask = encoder_attention_mask_1
1531
+ else:
1532
+ forward_encoder_hidden_states = None
1533
+ forward_encoder_attention_mask = None
1534
+ hidden_states = self.attentions[i * num_attention_per_layer + idx](
1535
+ hidden_states,
1536
+ attention_mask=attention_mask,
1537
+ encoder_hidden_states=forward_encoder_hidden_states,
1538
+ encoder_attention_mask=forward_encoder_attention_mask,
1539
+ return_dict=False,
1540
+ )[0]
1541
+
1542
+ if self.upsamplers is not None:
1543
+ for upsampler in self.upsamplers:
1544
+ hidden_states = upsampler(hidden_states, upsample_size)
1545
+
1546
+ return hidden_states
pipeline/morph_pipeline_successed_ver1.py ADDED
@@ -0,0 +1,1435 @@
1
+ from audio_encoder.AudioMAE import AudioMAEConditionCTPoolRand, extract_kaldi_fbank_feature
2
+ import torchaudio
3
+ import torchaudio.transforms as T
4
+ import torch.nn.functional as F
5
+ import inspect
6
+ from typing import Any, Callable, Dict, List, Optional, Union
7
+ from APadapter.ap_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0
8
+ import random
9
+ import os
10
+ import scipy
11
+ import safetensors
12
+ import numpy as np
13
+ import torch
14
+ from transformers import (
15
+ ClapFeatureExtractor,
16
+ ClapModel,
17
+ GPT2Model,
18
+ RobertaTokenizer,
19
+ RobertaTokenizerFast,
20
+ SpeechT5HifiGan,
21
+ T5EncoderModel,
22
+ T5Tokenizer,
23
+ T5TokenizerFast,
24
+ )
25
+
26
+ from diffusers.loaders import AttnProcsLayers
27
+ from diffusers import AutoencoderKL
28
+ from diffusers.schedulers import KarrasDiffusionSchedulers
29
+ from diffusers.utils import (
30
+ is_accelerate_available,
31
+ is_accelerate_version,
32
+ is_librosa_available,
33
+ logging,
34
+ replace_example_docstring,
35
+ )
36
+ from diffusers.utils.torch_utils import randn_tensor
37
+ from diffusers.pipelines.pipeline_utils import AudioPipelineOutput, DiffusionPipeline
38
+ from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
39
+ from diffusers.loaders import TextualInversionLoaderMixin
40
+
41
+ from tqdm import tqdm # for progress bar
42
+ from utils.lora_utils_successed_ver1 import train_lora, load_lora, wav_to_mel
43
+ from utils.model_utils import slerp, do_replace_attn
44
+ from utils.alpha_scheduler import AlphaScheduler
45
+ from audioldm.utils import default_audioldm_config
46
+ from audioldm.audio import TacotronSTFT, read_wav_file
47
+ from audioldm.audio.tools import get_mel_from_wav, _pad_spec, normalize_wav, pad_wav
48
+ if is_librosa_available():
49
+ import librosa
50
+ import warnings
51
+ import matplotlib.pyplot as plt
52
+
53
+
54
+ from .pipeline_audioldm2 import AudioLDM2Pipeline
55
+
56
+ pipeline_trained = AudioLDM2Pipeline.from_pretrained("cvssp/audioldm2-large", torch_dtype=torch.float32)
57
+ pipeline_trained = pipeline_trained.to("cuda")
58
+ layer_num = 0
59
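+ # Cross-attention dims cycled over the UNet's attention processors below: entries equal to 768
+ # mark the layers that receive the AP-adapter (IPAttnProcessor2_0); every other layer keeps the
+ # default AttnProcessor2_0.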
+ cross = [None, None, 768, 768, 1024, 1024, None, None]
60
+ unet = pipeline_trained.unet
61
+
62
+
63
+ attn_procs = {}
64
+ for name in unet.attn_processors.keys():
65
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
66
+ if name.startswith("mid_block"):
67
+ hidden_size = unet.config.block_out_channels[-1]
68
+ elif name.startswith("up_blocks"):
69
+ block_id = int(name[len("up_blocks.")])
70
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
71
+ elif name.startswith("down_blocks"):
72
+ block_id = int(name[len("down_blocks.")])
73
+ hidden_size = unet.config.block_out_channels[block_id]
74
+
75
+ if cross_attention_dim is None:
76
+ attn_procs[name] = AttnProcessor2_0()
77
+ else:
78
+ cross_attention_dim = cross[layer_num % 8]
79
+ layer_num += 1
80
+ if cross_attention_dim == 768:
81
+ attn_procs[name] = IPAttnProcessor2_0(
82
+ hidden_size=hidden_size,
83
+ name=name,
84
+ cross_attention_dim=cross_attention_dim,
85
+ scale=0.5,
86
+ num_tokens=8,
87
+ do_copy=False
88
+ ).to("cuda", dtype=torch.float32)
89
+ else:
90
+ attn_procs[name] = AttnProcessor2_0()
91
+
92
+ state_dict = torch.load('/Data/home/Dennis/DeepMIR-2024/Final_Project/AP-adapter/pytorch_model.bin', map_location="cuda")
93
+ for name, processor in attn_procs.items():
94
+ if hasattr(processor, 'to_v_ip') or hasattr(processor, 'to_k_ip'):
95
+ weight_name_v = name + ".to_v_ip.weight"
96
+ weight_name_k = name + ".to_k_ip.weight"
97
+ processor.to_v_ip.weight = torch.nn.Parameter(state_dict[weight_name_v].half())
98
+ processor.to_k_ip.weight = torch.nn.Parameter(state_dict[weight_name_k].half())
99
+
100
+ unet.set_attn_processor(attn_procs)
101
+ unet.to("cuda", dtype=torch.float32)
102
+
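+ # Optional sanity check (illustrative, not required for the pipeline to run): confirm how many
+ # attention layers received the AP-adapter processor versus the default one.
+ # num_ip = sum(isinstance(p, IPAttnProcessor2_0) for p in unet.attn_processors.values())
+ # num_default = sum(isinstance(p, AttnProcessor2_0) for p in unet.attn_processors.values())
+ # print(f"IP-adapter processors: {num_ip}, default processors: {num_default}")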
103
+
104
+
105
+
106
+ def visualize_mel_spectrogram(mel_spect_tensor, output_path=None):
107
+ mel_spect_array = mel_spect_tensor.squeeze().transpose(1, 0).detach().cpu().numpy()
108
+ plt.figure(figsize=(10, 5))
109
+ plt.imshow(mel_spect_array, aspect='auto', origin='lower', cmap='magma')
110
+ plt.colorbar(label="Log-Mel Energy")
111
+ plt.title("Mel-Spectrogram")
112
+ plt.xlabel("Time")
113
+ plt.ylabel("Mel Frequency Bins")
114
+ plt.tight_layout()
115
+ if output_path:
116
+ plt.savefig(output_path, dpi=300)
117
+ print(f"Mel-spectrogram saved to {output_path}")
118
+ else:
119
+ plt.show()
120
+
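+ # Example usage (illustrative; the tensor is a random placeholder standing in for a real
+ # (time, n_mels) log-mel spectrogram):
+ # mel = torch.randn(1024, 64)
+ # visualize_mel_spectrogram(mel, output_path="mel.png")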
121
+
122
+ warnings.filterwarnings("ignore", category=FutureWarning)
123
+ warnings.filterwarnings("ignore", category=UserWarning)
124
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
125
+
126
+ class StoreProcessor():
127
+ def __init__(self, original_processor, value_dict, name):
128
+ self.original_processor = original_processor
129
+ self.value_dict = value_dict
130
+ self.name = name
131
+ self.value_dict[self.name] = dict()
132
+ self.id = 0
133
+
134
+ def __call__(self, attn, hidden_states, *args, encoder_hidden_states=None, attention_mask=None, **kwargs):
135
+ # Is self attention
136
+ if encoder_hidden_states is None:
137
+ # No encoder_hidden_states means this is a self-attention layer: store the incoming
+ # hidden_states in value_dict under self.name, keyed by the running call index self.id.
139
+ # print(f'In StoreProcessor: {self.name} {self.id}')
140
+ self.value_dict[self.name][self.id] = hidden_states.detach()
141
+ self.id += 1
142
+ # Delegate to the original processor to run the normal attention computation
143
+ res = self.original_processor(attn, hidden_states, *args,
144
+ encoder_hidden_states=encoder_hidden_states,
145
+ attention_mask=attention_mask,
146
+ **kwargs)
147
+ return res
148
+
149
+
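+ # Illustrative wiring (editor's sketch): StoreProcessor wraps an existing self-attention
+ # processor so its hidden states are captured into a dictionary during one denoising pass.
+ # aud1_dict = {}
+ # unet.set_attn_processor(
+ #     {name: StoreProcessor(proc, aud1_dict, name) for name, proc in unet.attn_processors.items()}
+ # )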
150
+ class LoadProcessor():
151
+ def __init__(self, original_processor, name, aud1_dict, aud2_dict, alpha, beta=0, lamd=0.6):
152
+ super().__init__()
153
+ self.original_processor = original_processor
154
+ self.name = name
155
+ self.aud1_dict = aud1_dict
156
+ self.aud2_dict = aud2_dict
157
+ self.alpha = alpha
158
+ self.beta = beta
159
+ self.lamd = lamd
160
+ self.id = 0
161
+
162
+ def __call__(self, attn, hidden_states, *args, encoder_hidden_states=None, attention_mask=None, **kwargs):
163
+ # Is self attention
164
+ # Self-attention if no encoder_hidden_states is given
165
+ if encoder_hidden_states is None:
166
+ # For the first 10 * self.lamd self-attention calls, blend in the stored attention features
167
+ if self.id < 10 * self.lamd:
168
+ map0 = self.aud1_dict[self.name][self.id]
169
+ map1 = self.aud2_dict[self.name][self.id]
170
+ cross_map = self.beta * hidden_states + \
171
+ (1 - self.beta) * ((1 - self.alpha) * map0 + self.alpha * map1)
172
+ # Call the original processor with cross_map passed in as encoder_hidden_states
173
+ res = self.original_processor(attn, hidden_states, *args,
174
+ encoder_hidden_states=cross_map,
175
+ attention_mask=attention_mask,
176
+ **kwargs)
177
+ else:
178
+ # Otherwise fall back to the original encoder_hidden_states (which may be None)
179
+ res = self.original_processor(attn, hidden_states, *args,
180
+ encoder_hidden_states=encoder_hidden_states,
181
+ attention_mask=attention_mask,
182
+ **kwargs)
183
+
184
+ self.id += 1
185
+ # Reset the index once it reaches the number of maps stored for this layer
186
+ if self.id == len(self.aud1_dict[self.name]):
187
+ self.id = 0
188
+ else:
189
+ # Cross-attention (encoder_hidden_states is not None): use the original processor unchanged
190
+ res = self.original_processor(attn, hidden_states, *args,
191
+ encoder_hidden_states=encoder_hidden_states,
192
+ attention_mask=attention_mask,
193
+ **kwargs)
194
+
195
+ return res
196
+
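+ # Per stored self-attention map, the blend above is:
+ #   cross_map = beta * h + (1 - beta) * ((1 - alpha) * map_audio1 + alpha * map_audio2)
+ # i.e. a linear interpolation between the two source audios' attention features, optionally
+ # mixed back with the current hidden states via beta.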
197
+
198
+ def prepare_inputs_for_generation(
199
+ inputs_embeds,
200
+ attention_mask=None,
201
+ past_key_values=None,
202
+ **kwargs,):
203
+ if past_key_values is not None:
204
+ # only last token for inputs_embeds if past is defined in kwargs
205
+ inputs_embeds = inputs_embeds[:, -1:]
206
+
207
+ return {
208
+ "inputs_embeds": inputs_embeds,
209
+ "attention_mask": attention_mask,
210
+ "past_key_values": past_key_values,
211
+ "use_cache": kwargs.get("use_cache"),
212
+ }
213
+
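+ # Illustrative calls (editor's sketch): on the first step there is no KV cache, so the full
+ # embedding sequence is forwarded; once past_key_values is supplied, only the last position is kept.
+ # first = prepare_inputs_for_generation(inputs_embeds, attention_mask=mask)
+ # later = prepare_inputs_for_generation(inputs_embeds, attention_mask=mask, past_key_values=cache)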
214
+
215
+ class AudioLDM2MorphPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
216
+ r"""
217
+ Pipeline for morphing between two audio inputs, built on AudioLDM2 text-to-audio generation.
218
+
219
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
220
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
221
+
222
+ Args:
223
+ vae ([`AutoencoderKL`]):
224
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
225
+ text_encoder ([`~transformers.ClapModel`]):
226
+ First frozen text-encoder. AudioLDM2 uses the joint audio-text embedding model
227
+ [CLAP](https://huggingface.co/docs/transformers/model_doc/clap#transformers.CLAPTextModelWithProjection),
228
+ specifically the [laion/clap-htsat-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant. The
229
+ text branch is used to encode the text prompt to a prompt embedding. The full audio-text model is used to
230
+ rank generated waveforms against the text prompt by computing similarity scores.
231
+ text_encoder_2 ([`~transformers.T5EncoderModel`]):
232
+ Second frozen text-encoder. AudioLDM2 uses the encoder of
233
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
234
+ [google/flan-t5-large](https://huggingface.co/google/flan-t5-large) variant.
235
+ projection_model ([`AudioLDM2ProjectionModel`]):
236
+ A trained model used to linearly project the hidden-states from the first and second text encoder models
237
+ and insert learned SOS and EOS token embeddings. The projected hidden-states from the two text encoders are
238
+ concatenated to give the input to the language model.
239
+ language_model ([`~transformers.GPT2Model`]):
240
+ An auto-regressive language model used to generate a sequence of hidden-states conditioned on the projected
241
+ outputs from the two text encoders.
242
+ tokenizer ([`~transformers.RobertaTokenizer`]):
243
+ Tokenizer to tokenize text for the first frozen text-encoder.
244
+ tokenizer_2 ([`~transformers.T5Tokenizer`]):
245
+ Tokenizer to tokenize text for the second frozen text-encoder.
246
+ feature_extractor ([`~transformers.ClapFeatureExtractor`]):
247
+ Feature extractor to pre-process generated audio waveforms to log-mel spectrograms for automatic scoring.
248
+ unet ([`UNet2DConditionModel`]):
249
+ A `UNet2DConditionModel` to denoise the encoded audio latents.
250
+ scheduler ([`SchedulerMixin`]):
251
+ A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of
252
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
253
+ vocoder ([`~transformers.SpeechT5HifiGan`]):
254
+ Vocoder of class `SpeechT5HifiGan` to convert the mel-spectrogram latents to the final audio waveform.
255
+ """
256
+
257
+ def __init__(
258
+ self,
259
+ vae: AutoencoderKL,
260
+ text_encoder: ClapModel,
261
+ text_encoder_2: T5EncoderModel,
262
+ projection_model: AudioLDM2ProjectionModel,
263
+ language_model: GPT2Model,
264
+ tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
265
+ tokenizer_2: Union[T5Tokenizer, T5TokenizerFast],
266
+ feature_extractor: ClapFeatureExtractor,
267
+ unet: AudioLDM2UNet2DConditionModel,
268
+ scheduler: KarrasDiffusionSchedulers,
269
+ vocoder: SpeechT5HifiGan,
270
+ ):
271
+ super().__init__()
272
+
273
+ self.register_modules(
274
+ vae=vae,
275
+ text_encoder=text_encoder,
276
+ text_encoder_2=text_encoder_2,
277
+ projection_model=projection_model,
278
+ language_model=language_model,
279
+ tokenizer=tokenizer,
280
+ tokenizer_2=tokenizer_2,
281
+ feature_extractor=feature_extractor,
282
+ unet=unet,
283
+ scheduler=scheduler,
284
+ vocoder=vocoder,
285
+ )
286
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
287
+ self.aud1_dict = dict()
288
+ self.aud2_dict = dict()
289
+
290
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
291
+ def enable_vae_slicing(self):
292
+ r"""
293
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
294
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
295
+ """
296
+ self.vae.enable_slicing()
297
+
298
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
299
+ def disable_vae_slicing(self):
300
+ r"""
301
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
302
+ computing decoding in one step.
303
+ """
304
+ self.vae.disable_slicing()
305
+
306
+ def enable_model_cpu_offload(self, gpu_id=0):
307
+ r"""
308
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
309
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
310
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
311
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
312
+ """
313
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
314
+ from accelerate import cpu_offload_with_hook
315
+ else:
316
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
317
+
318
+ device = torch.device(f"cuda:{gpu_id}")
319
+
320
+ if self.device.type != "cpu":
321
+ self.to("cpu", silence_dtype_warnings=True)
322
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
323
+
324
+ model_sequence = [
325
+ self.text_encoder.text_model,
326
+ self.text_encoder.text_projection,
327
+ self.text_encoder_2,
328
+ self.projection_model,
329
+ self.language_model,
330
+ self.unet,
331
+ self.vae,
332
+ self.vocoder,
333
+ self.text_encoder,
334
+ ]
335
+
336
+ hook = None
337
+ for cpu_offloaded_model in model_sequence:
338
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
339
+
340
+ # We'll offload the last model manually.
341
+ self.final_offload_hook = hook
342
+
343
+ def generate_language_model(
344
+ self,
345
+ inputs_embeds: torch.Tensor = None,
346
+ max_new_tokens: int = 512,
347
+ **model_kwargs,
348
+ ):
349
+ """
350
+
351
+ Generates a sequence of hidden-states from the language model, conditioned on the embedding inputs.
352
+
353
+ Parameters:
354
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
355
+ The sequence used as a prompt for the generation.
356
+ max_new_tokens (`int`):
357
+ Number of new tokens to generate.
358
+ model_kwargs (`Dict[str, Any]`, *optional*):
359
+ Ad hoc parametrization of additional model-specific kwargs that will be forwarded to the `forward`
360
+ function of the model.
361
+
362
+ Return:
363
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
364
+ The sequence of generated hidden-states.
365
+ """
366
+ max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
367
+ model_kwargs = self.language_model._get_initial_cache_position(inputs_embeds, model_kwargs)
368
+ for _ in range(max_new_tokens):
369
+ # prepare model inputs
370
+ model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)
371
+
372
+ # forward pass to get next hidden states
373
+ output = self.language_model(**model_inputs, return_dict=True)
374
+
375
+ next_hidden_states = output.last_hidden_state
376
+
377
+ # Update the model input
378
+ inputs_embeds = torch.cat([inputs_embeds, next_hidden_states[:, -1:, :]], dim=1)
379
+
380
+ # Update generated hidden states, model inputs, and length for next step
381
+ model_kwargs = self.language_model._update_model_kwargs_for_generation(output, model_kwargs)
382
+
383
+ return inputs_embeds[:, -max_new_tokens:, :]
384
+
385
+ def encode_prompt(
386
+ self,
387
+ prompt,
388
+ device,
389
+ num_waveforms_per_prompt,
390
+ do_classifier_free_guidance,
391
+ negative_prompt=None,
392
+ prompt_embeds: Optional[torch.FloatTensor] = None,
393
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
394
+ generated_prompt_embeds: Optional[torch.FloatTensor] = None,
395
+ negative_generated_prompt_embeds: Optional[torch.FloatTensor] = None,
396
+ attention_mask: Optional[torch.LongTensor] = None,
397
+ negative_attention_mask: Optional[torch.LongTensor] = None,
398
+ max_new_tokens: Optional[int] = None,
399
+ ):
400
+ r"""
401
+ Encodes the prompt into text encoder hidden states.
402
+
403
+ Args:
404
+ prompt (`str` or `List[str]`, *optional*):
405
+ prompt to be encoded
406
+ device (`torch.device`):
407
+ torch device
408
+ num_waveforms_per_prompt (`int`):
409
+ number of waveforms that should be generated per prompt
410
+ do_classifier_free_guidance (`bool`):
411
+ whether to use classifier free guidance or not
412
+ negative_prompt (`str` or `List[str]`, *optional*):
413
+ The prompt or prompts not to guide the audio generation. If not defined, one has to pass
414
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
415
+ less than `1`).
416
+ prompt_embeds (`torch.FloatTensor`, *optional*):
417
+ Pre-computed text embeddings from the Flan T5 model. Can be used to easily tweak text inputs, *e.g.*
418
+ prompt weighting. If not provided, text embeddings will be computed from `prompt` input argument.
419
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
420
+ Pre-computed negative text embeddings from the Flan T5 model. Can be used to easily tweak text inputs,
421
+ *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
422
+ `negative_prompt` input argument.
423
+ generated_prompt_embeds (`torch.FloatTensor`, *optional*):
424
+ Pre-generated text embeddings from the GPT2 language model. Can be used to easily tweak text inputs,
425
+ *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input
426
+ argument.
427
+ negative_generated_prompt_embeds (`torch.FloatTensor`, *optional*):
428
+ Pre-generated negative text embeddings from the GPT2 language model. Can be used to easily tweak text
429
+ inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
430
+ `negative_prompt` input argument.
431
+ attention_mask (`torch.LongTensor`, *optional*):
432
+ Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will
433
+ be computed from `prompt` input argument.
434
+ negative_attention_mask (`torch.LongTensor`, *optional*):
435
+ Pre-computed attention mask to be applied to the `negative_prompt_embeds`. If not provided, attention
436
+ mask will be computed from `negative_prompt` input argument.
437
+ max_new_tokens (`int`, *optional*, defaults to None):
438
+ The number of new tokens to generate with the GPT2 language model.
439
+ Returns:
440
+ prompt_embeds (`torch.FloatTensor`):
441
+ Text embeddings from the Flan T5 model.
442
+ attention_mask (`torch.LongTensor`):
443
+ Attention mask to be applied to the `prompt_embeds`.
444
+ generated_prompt_embeds (`torch.FloatTensor`):
445
+ Text embeddings generated from the GPT2 language model.
446
+
447
+ Example:
448
+
449
+ ```python
450
+ >>> import scipy
451
+ >>> import torch
452
+ >>> from diffusers import AudioLDM2Pipeline
453
+
454
+ >>> repo_id = "cvssp/audioldm2"
455
+ >>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
456
+ >>> pipe = pipe.to("cuda")
457
+
458
+ >>> # Get text embedding vectors
459
+ >>> prompt_embeds, attention_mask, generated_prompt_embeds = pipe.encode_prompt(
460
+ ... prompt="Techno music with a strong, upbeat tempo and high melodic riffs",
461
+ ... device="cuda",
462
+ ... do_classifier_free_guidance=True,
463
+ ... )
464
+
465
+ >>> # Pass text embeddings to pipeline for text-conditional audio generation
466
+ >>> audio = pipe(
467
+ ... prompt_embeds=prompt_embeds,
468
+ ... attention_mask=attention_mask,
469
+ ... generated_prompt_embeds=generated_prompt_embeds,
470
+ ... num_inference_steps=200,
471
+ ... audio_length_in_s=10.0,
472
+ ... ).audios[0]
473
+
474
+ >>> # save generated audio sample
475
+ >>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio)
476
+ ```"""
477
+ # print("prompt",prompt)
478
+ if prompt is not None and isinstance(prompt, str):
479
+ batch_size = 1
480
+ elif prompt is not None and isinstance(prompt, list):
481
+ batch_size = len(prompt)
482
+ else:
483
+ batch_size = prompt_embeds.shape[0]
484
+
485
+ # Define tokenizers and text encoders
486
+ tokenizers = [self.tokenizer, self.tokenizer_2]
487
+ text_encoders = [self.text_encoder, self.text_encoder_2]
488
+
489
+ if prompt_embeds is None:
490
+ prompt_embeds_list = []
491
+ attention_mask_list = []
492
+
493
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
494
+ text_inputs = tokenizer(
495
+ prompt,
496
+ padding="max_length" if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast)) else True,
497
+ max_length=tokenizer.model_max_length,
498
+ truncation=True,
499
+ return_tensors="pt",
500
+ )
501
+ text_input_ids = text_inputs.input_ids
502
+ attention_mask = text_inputs.attention_mask
503
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
504
+
505
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
506
+ text_input_ids, untruncated_ids
507
+ ):
508
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
509
+ logger.warning(
510
+ f"The following part of your input was truncated because {text_encoder.config.model_type} can "
511
+ f"only handle sequences up to {tokenizer.model_max_length} tokens: {removed_text}"
512
+ )
513
+
514
+ text_input_ids = text_input_ids.to(device)
515
+ attention_mask = attention_mask.to(device)
516
+
517
+ if text_encoder.config.model_type == "clap":
518
+ prompt_embeds = text_encoder.get_text_features(
519
+ text_input_ids,
520
+ attention_mask=attention_mask,
521
+ )
522
+ # append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size)
523
+ prompt_embeds = prompt_embeds[:, None, :]
524
+ # make sure that we attend to this single hidden-state
525
+ attention_mask = attention_mask.new_ones((batch_size, 1))
526
+ else:
527
+ prompt_embeds = text_encoder(
528
+ text_input_ids,
529
+ attention_mask=attention_mask,
530
+ )
531
+ prompt_embeds = prompt_embeds[0]
532
+
533
+ prompt_embeds_list.append(prompt_embeds)
534
+ attention_mask_list.append(attention_mask)
535
+
536
+ projection_output = self.projection_model(
537
+ hidden_states=prompt_embeds_list[0],
538
+ hidden_states_1=prompt_embeds_list[1],
539
+ attention_mask=attention_mask_list[0],
540
+ attention_mask_1=attention_mask_list[1],
541
+ )
542
+ projected_prompt_embeds = projection_output.hidden_states
543
+ projected_attention_mask = projection_output.attention_mask
544
+
545
+ generated_prompt_embeds = self.generate_language_model(
546
+ projected_prompt_embeds,
547
+ attention_mask=projected_attention_mask,
548
+ max_new_tokens=max_new_tokens,
549
+ )
550
+
551
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
552
+ attention_mask = (
553
+ attention_mask.to(device=device)
554
+ if attention_mask is not None
555
+ else torch.ones(prompt_embeds.shape[:2], dtype=torch.long, device=device)
556
+ )
557
+ generated_prompt_embeds = generated_prompt_embeds.to(dtype=self.language_model.dtype, device=device)
558
+
559
+ bs_embed, seq_len, hidden_size = prompt_embeds.shape
560
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
561
+ prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
562
+ prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len, hidden_size)
563
+
564
+ # duplicate attention mask for each generation per prompt
565
+ attention_mask = attention_mask.repeat(1, num_waveforms_per_prompt)
566
+ attention_mask = attention_mask.view(bs_embed * num_waveforms_per_prompt, seq_len)
567
+
568
+ bs_embed, seq_len, hidden_size = generated_prompt_embeds.shape
569
+ # duplicate generated embeddings for each generation per prompt, using mps friendly method
570
+ generated_prompt_embeds = generated_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
571
+ generated_prompt_embeds = generated_prompt_embeds.view(
572
+ bs_embed * num_waveforms_per_prompt, seq_len, hidden_size
573
+ )
574
+
575
+ # get unconditional embeddings for classifier free guidance
576
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
577
+ uncond_tokens: List[str]
578
+ if negative_prompt is None:
579
+ uncond_tokens = [""] * batch_size
580
+ elif type(prompt) is not type(negative_prompt):
581
+ raise TypeError(
582
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
583
+ f" {type(prompt)}."
584
+ )
585
+ elif isinstance(negative_prompt, str):
586
+ uncond_tokens = [negative_prompt]
587
+ elif batch_size != len(negative_prompt):
588
+ raise ValueError(
589
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
590
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
591
+ " the batch size of `prompt`."
592
+ )
593
+ else:
594
+ uncond_tokens = negative_prompt
595
+
596
+ negative_prompt_embeds_list = []
597
+ negative_attention_mask_list = []
598
+ max_length = prompt_embeds.shape[1]
599
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
600
+ uncond_input = tokenizer(
601
+ uncond_tokens,
602
+ padding="max_length",
603
+ max_length=tokenizer.model_max_length
604
+ if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast))
605
+ else max_length,
606
+ truncation=True,
607
+ return_tensors="pt",
608
+ )
609
+
610
+ uncond_input_ids = uncond_input.input_ids.to(device)
611
+ negative_attention_mask = uncond_input.attention_mask.to(device)
612
+
613
+ if text_encoder.config.model_type == "clap":
614
+ negative_prompt_embeds = text_encoder.get_text_features(
615
+ uncond_input_ids,
616
+ attention_mask=negative_attention_mask,
617
+ )
618
+ # append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size)
619
+ negative_prompt_embeds = negative_prompt_embeds[:, None, :]
620
+ # make sure that we attend to this single hidden-state
621
+ negative_attention_mask = negative_attention_mask.new_ones((batch_size, 1))
622
+ else:
623
+ negative_prompt_embeds = text_encoder(
624
+ uncond_input_ids,
625
+ attention_mask=negative_attention_mask,
626
+ )
627
+ negative_prompt_embeds = negative_prompt_embeds[0]
628
+
629
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
630
+ negative_attention_mask_list.append(negative_attention_mask)
631
+
632
+ projection_output = self.projection_model(
633
+ hidden_states=negative_prompt_embeds_list[0],
634
+ hidden_states_1=negative_prompt_embeds_list[1],
635
+ attention_mask=negative_attention_mask_list[0],
636
+ attention_mask_1=negative_attention_mask_list[1],
637
+ )
638
+ negative_projected_prompt_embeds = projection_output.hidden_states
639
+ negative_projected_attention_mask = projection_output.attention_mask
640
+
641
+ negative_generated_prompt_embeds = self.generate_language_model(
642
+ negative_projected_prompt_embeds,
643
+ attention_mask=negative_projected_attention_mask,
644
+ max_new_tokens=max_new_tokens,
645
+ )
646
+
647
+ if do_classifier_free_guidance:
648
+ seq_len = negative_prompt_embeds.shape[1]
649
+
650
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
651
+ negative_attention_mask = (
652
+ negative_attention_mask.to(device=device)
653
+ if negative_attention_mask is not None
654
+ else torch.ones(negative_prompt_embeds.shape[:2], dtype=torch.long, device=device)
655
+ )
656
+ negative_generated_prompt_embeds = negative_generated_prompt_embeds.to(
657
+ dtype=self.language_model.dtype, device=device
658
+ )
659
+
660
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
661
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
662
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len, -1)
663
+
664
+ # duplicate unconditional attention mask for each generation per prompt
665
+ negative_attention_mask = negative_attention_mask.repeat(1, num_waveforms_per_prompt)
666
+ negative_attention_mask = negative_attention_mask.view(batch_size * num_waveforms_per_prompt, seq_len)
667
+
668
+ # duplicate unconditional generated embeddings for each generation per prompt
669
+ seq_len = negative_generated_prompt_embeds.shape[1]
670
+ negative_generated_prompt_embeds = negative_generated_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
671
+ negative_generated_prompt_embeds = negative_generated_prompt_embeds.view(
672
+ batch_size * num_waveforms_per_prompt, seq_len, -1
673
+ )
674
+
675
+ # For classifier free guidance, we need to do two forward passes.
676
+ # Here we concatenate the unconditional and text embeddings into a single batch
677
+ # to avoid doing two forward passes
678
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
679
+ attention_mask = torch.cat([negative_attention_mask, attention_mask])
680
+ generated_prompt_embeds = torch.cat([negative_generated_prompt_embeds, generated_prompt_embeds])
681
+
682
+ return prompt_embeds, attention_mask, generated_prompt_embeds
683
+
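+ # Note on the return layout (a sketch of how the rest of this file consumes it): when
+ # `do_classifier_free_guidance` is True, each returned tensor stacks the negative
+ # (unconditional) half before the positive half along the batch dimension, so callers
+ # can split them back apart with `chunk(2)`, e.g.:
+ #
+ #     negative_embeds, positive_embeds = prompt_embeds.chunk(2)
+ #
+ # This mirrors how `cal_latent` below splits the embeddings before re-invoking the
+ # trained pipeline.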
684
+ # Copied from diffusers.pipelines.audioldm.pipeline_audioldm.AudioLDMPipeline.mel_spectrogram_to_waveform
685
+ def mel_spectrogram_to_waveform(self, mel_spectrogram):
686
+ if mel_spectrogram.dim() == 4:
687
+ mel_spectrogram = mel_spectrogram.squeeze(1)
688
+
689
+ waveform = self.vocoder(mel_spectrogram)
690
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
691
+ waveform = waveform.cpu().float()
692
+ return waveform
693
+
694
+ def score_waveforms(self, text, audio, num_waveforms_per_prompt, device, dtype):
695
+ if not is_librosa_available():
696
+ logger.info(
697
+ "Automatic scoring of the generated audio waveforms against the input prompt text requires the "
698
+ "`librosa` package to resample the generated waveforms. Returning the audios in the order they were "
699
+ "generated. To enable automatic scoring, install `librosa` with: `pip install librosa`."
700
+ )
701
+ return audio
702
+ inputs = self.tokenizer(text, return_tensors="pt", padding=True)
703
+ resampled_audio = librosa.resample(
704
+ audio.numpy(), orig_sr=self.vocoder.config.sampling_rate, target_sr=self.feature_extractor.sampling_rate
705
+ )
706
+ inputs["input_features"] = self.feature_extractor(
707
+ list(resampled_audio), return_tensors="pt", sampling_rate=self.feature_extractor.sampling_rate
708
+ ).input_features.type(dtype)
709
+ inputs = inputs.to(device)
710
+
711
+ # compute the audio-text similarity score using the CLAP model
712
+ logits_per_text = self.text_encoder(**inputs).logits_per_text
713
+ # sort by the highest matching generations per prompt
714
+ indices = torch.argsort(logits_per_text, dim=1, descending=True)[:, :num_waveforms_per_prompt]
715
+ audio = torch.index_select(audio, 0, indices.reshape(-1).cpu())
716
+ return audio
717
+
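+ # Hedged usage sketch: this helper re-ranks a batch of generated waveforms by their
+ # CLAP text-audio similarity and returns them ordered best-first per prompt. The names
+ # below are illustrative, not variables defined in this file:
+ #
+ #     ranked = self.score_waveforms(
+ #         text=prompt, audio=waveforms, num_waveforms_per_prompt=3,
+ #         device=device, dtype=torch.float32,
+ #     )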
718
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
719
+ def prepare_extra_step_kwargs(self, generator, eta):
720
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
721
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
722
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
723
+ # and should be between [0, 1]
724
+
725
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
726
+ extra_step_kwargs = {}
727
+ if accepts_eta:
728
+ extra_step_kwargs["eta"] = eta
729
+
730
+ # check if the scheduler accepts generator
731
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
732
+ if accepts_generator:
733
+ extra_step_kwargs["generator"] = generator
734
+ return extra_step_kwargs
735
+
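+ # Hedged example of what this returns: with a DDIMScheduler, whose `step()` accepts both
+ # keywords, the dict is {"eta": eta, "generator": generator}; for schedulers whose
+ # `step()` takes neither keyword it is empty, so a caller can always expand it safely as
+ # `self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)`.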
736
+ def check_inputs(
737
+ self,
738
+ prompt,
739
+ audio_length_in_s,
740
+ vocoder_upsample_factor,
741
+ callback_steps,
742
+ negative_prompt=None,
743
+ prompt_embeds=None,
744
+ negative_prompt_embeds=None,
745
+ generated_prompt_embeds=None,
746
+ negative_generated_prompt_embeds=None,
747
+ attention_mask=None,
748
+ negative_attention_mask=None,
+ ):
749
+ min_audio_length_in_s = vocoder_upsample_factor * self.vae_scale_factor
750
+ if audio_length_in_s < min_audio_length_in_s:
751
+ raise ValueError(
752
+ f"`audio_length_in_s` has to be a positive value greater than or equal to {min_audio_length_in_s}, but "
753
+ f"is {audio_length_in_s}."
754
+ )
755
+
756
+ if self.vocoder.config.model_in_dim % self.vae_scale_factor != 0:
757
+ raise ValueError(
758
+ f"The number of frequency bins in the vocoder's log-mel spectrogram has to be divisible by the "
759
+ f"VAE scale factor, but got {self.vocoder.config.model_in_dim} bins and a scale factor of "
760
+ f"{self.vae_scale_factor}."
761
+ )
762
+
763
+ if (callback_steps is None) or (
764
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
765
+ ):
766
+ raise ValueError(
767
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
768
+ f" {type(callback_steps)}."
769
+ )
770
+
771
+ if prompt is not None and prompt_embeds is not None:
772
+ raise ValueError(
773
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
774
+ " only forward one of the two."
775
+ )
776
+ elif prompt is None and (prompt_embeds is None or generated_prompt_embeds is None):
777
+ raise ValueError(
778
+ "Provide either `prompt`, or `prompt_embeds` and `generated_prompt_embeds`. Cannot leave "
779
+ "`prompt` undefined without specifying both `prompt_embeds` and `generated_prompt_embeds`."
780
+ )
781
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
782
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
783
+
784
+ if negative_prompt is not None and negative_prompt_embeds is not None:
785
+ raise ValueError(
786
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
787
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
788
+ )
789
+ elif negative_prompt_embeds is not None and negative_generated_prompt_embeds is None:
790
+ raise ValueError(
791
+ "Cannot forward `negative_prompt_embeds` without `negative_generated_prompt_embeds`. Ensure that"
792
+ " both arguments are specified."
793
+ )
794
+
795
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
796
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
797
+ raise ValueError(
798
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
799
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
800
+ f" {negative_prompt_embeds.shape}."
801
+ )
802
+ if attention_mask is not None and attention_mask.shape != prompt_embeds.shape[:2]:
803
+ raise ValueError(
804
+ "`attention_mask` should have the same batch size and sequence length as `prompt_embeds`, but got: "
805
+ f"`attention_mask`: {attention_mask.shape} != `prompt_embeds`: {prompt_embeds.shape}"
806
+ )
807
+
808
+ if generated_prompt_embeds is not None and negative_generated_prompt_embeds is not None:
809
+ if generated_prompt_embeds.shape != negative_generated_prompt_embeds.shape:
810
+ raise ValueError(
811
+ "`generated_prompt_embeds` and `negative_generated_prompt_embeds` must have the same shape when "
812
+ f"passed directly, but got: `generated_prompt_embeds` {generated_prompt_embeds.shape} != "
813
+ f"`negative_generated_prompt_embeds` {negative_generated_prompt_embeds.shape}."
814
+ )
815
+ if (
816
+ negative_attention_mask is not None
817
+ and negative_attention_mask.shape != negative_prompt_embeds.shape[:2]
818
+ ):
819
+ raise ValueError(
820
+ "`negative_attention_mask` should have the same batch size and sequence length as `negative_prompt_embeds`, but got: "
821
+ f"`negative_attention_mask`: {negative_attention_mask.shape} != `negative_prompt_embeds`: {negative_prompt_embeds.shape}"
822
+ )
823
+
824
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with width->self.vocoder.config.model_in_dim
825
+ def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None):
826
+ shape = (
827
+ batch_size,
828
+ num_channels_latents,
829
+ height // self.vae_scale_factor,
830
+ self.vocoder.config.model_in_dim // self.vae_scale_factor,
831
+ )
832
+ if isinstance(generator, list) and len(generator) != batch_size:
833
+ raise ValueError(
834
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
835
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
836
+ )
837
+
838
+ if latents is None:
839
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
840
+ else:
841
+ latents = latents.to(device)
842
+
843
+ # scale the initial noise by the standard deviation required by the scheduler
844
+ latents = latents * self.scheduler.init_noise_sigma
845
+ return latents
846
+
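+ # Shape sketch under stated assumptions (vae_scale_factor == 4 and vocoder
+ # model_in_dim == 64, which we assume for the public AudioLDM2 checkpoints): for
+ # height == 1000 (a 10 s clip, see the commented shapes in `__call__` below) the
+ # prepared latents have shape (batch_size, num_channels_latents, 250, 16), i.e. both
+ # spectrogram dimensions divided by the VAE scale factor.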
847
+ def pre_check(self, audio_length_in_s, prompt, callback_steps, negative_prompt):
848
+ """
849
+ Step 0: Convert the audio input length from seconds to spectrogram height.
850
+ Step 1: Check the inputs and raise an error if they are not valid.
851
+ """
852
+ vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate
853
+
854
+ if audio_length_in_s is None:
855
+ audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor
856
+
857
+ height = int(audio_length_in_s / vocoder_upsample_factor)
858
+
859
+ original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate)
860
+ if height % self.vae_scale_factor != 0:
861
+ height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor
862
+ logger.info(
863
+ f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} "
864
+ f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the "
865
+ f"denoising process."
866
+ )
867
+ # 1. Check inputs. Raise error if not correct
868
+ self.check_inputs(
869
+ prompt,
870
+ audio_length_in_s,
871
+ vocoder_upsample_factor,
872
+ callback_steps,
873
+ negative_prompt,
874
+ )
875
+
876
+ return height, original_waveform_length
877
+
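+ # Worked example (the numbers match the commented shapes in `__call__` below; the exact
+ # vocoder upsample rates are an assumption about the pretrained checkpoint): with a
+ # 16 kHz vocoder that upsamples by a total factor of 160 per spectrogram frame,
+ # vocoder_upsample_factor = 160 / 16000 = 0.01 s per frame, so audio_length_in_s = 10.0
+ # gives height = 1000 and original_waveform_length = 160000 samples.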
878
+ def encode_prompt_for_2_sources(self, prompt_1, prompt_2, negative_prompt_1, negative_prompt_2, max_new_tokens, device, num_waveforms_per_prompt, do_classifier_free_guidance):
879
+ prompt_embeds_1, attention_mask_1, generated_prompt_embeds_1 = self.encode_prompt(
880
+ prompt_1,
881
+ device,
882
+ num_waveforms_per_prompt,
883
+ do_classifier_free_guidance,
884
+ negative_prompt_1,
885
+ max_new_tokens=max_new_tokens,
886
+ )
887
+
888
+ prompt_embeds_2, attention_mask_2, generated_prompt_embeds_2 = self.encode_prompt(
889
+ prompt_2,
890
+ device,
891
+ num_waveforms_per_prompt,
892
+ do_classifier_free_guidance,
893
+ negative_prompt_2,
894
+ max_new_tokens=max_new_tokens,
895
+ )
896
+ return [prompt_embeds_1, attention_mask_1, generated_prompt_embeds_1], [prompt_embeds_2, attention_mask_2, generated_prompt_embeds_2]
897
+
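+ # Usage sketch (mirrors the call made in `__call__` below):
+ #
+ #     encoded_1, encoded_2 = self.encode_prompt_for_2_sources(
+ #         prompt_1, prompt_2, negative_prompt_1, negative_prompt_2,
+ #         max_new_tokens, device, num_waveforms_per_prompt, do_classifier_free_guidance,
+ #     )
+ #     prompt_embeds_1, attention_mask_1, generated_prompt_embeds_1 = encoded_1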
898
+ def process_encoded_prompt(self, encoded_prompt, audio_file, time_pooling, freq_pooling):
899
+ prompt_embeds, attention_mask, generated_prompt_embeds = encoded_prompt
900
+ waveform, sr = torchaudio.load(audio_file)
901
+ fbank = torch.zeros((1024, 128))
902
+ ta_kaldi_fbank = extract_kaldi_fbank_feature(waveform, sr, fbank)
903
+ # print("ta_kaldi_fbank.shape",ta_kaldi_fbank.shape)
904
+ mel_spect_tensor = ta_kaldi_fbank.unsqueeze(0)
905
+ model = AudioMAEConditionCTPoolRand().cuda()
906
+ model.eval()
907
+ LOA_embed = model(mel_spect_tensor, time_pool=time_pooling, freq_pool=freq_pooling)
908
+ uncond_LOA_embed = model(torch.zeros_like(mel_spect_tensor), time_pool=time_pooling, freq_pool=freq_pooling)
909
+ LOA_embeds = LOA_embed[0]
910
+ uncond_LOA_embeds = uncond_LOA_embed[0]
911
+ bs_embed, seq_len, _ = LOA_embeds.shape
912
+ num = prompt_embeds.shape[0] // 2
913
+
914
+ LOA_embeds = LOA_embeds.view(bs_embed, seq_len, -1)
915
+ LOA_embeds = LOA_embeds.repeat(num, 1, 1)
916
+ uncond_LOA_embeds = uncond_LOA_embeds.view(bs_embed, seq_len, -1)
917
+ uncond_LOA_embeds = uncond_LOA_embeds.repeat(num, 1, 1)
918
+
919
+ negative_g, g = generated_prompt_embeds.chunk(2)
920
+ uncond = torch.cat([negative_g, uncond_LOA_embeds], dim=1)
921
+ cond = torch.cat([g, LOA_embeds], dim=1)
922
+ generated_prompt_embeds = torch.cat([uncond, cond], dim=0)
923
+ model_dtype = next(self.unet.parameters()).dtype
924
+ # Convert your tensor to the same dtype as the model
925
+ generated_prompt_embeds = generated_prompt_embeds.to(model_dtype)
926
+
927
+ return prompt_embeds, attention_mask, generated_prompt_embeds
928
+
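+ # What the tensor surgery above does (a descriptive sketch): the AudioMAE (LOA)
+ # embeddings extracted from `audio_file` are appended along the sequence dimension to
+ # the conditional half of `generated_prompt_embeds`, while embeddings of an all-zero
+ # spectrogram are appended to the unconditional half; the two halves are then
+ # re-stacked as [uncond; cond] for classifier-free guidance and cast to the UNet dtype.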
929
+ @torch.no_grad()
930
+ def aud2latent(self, audio_path, audio_length_in_s):
931
+ DEVICE = torch.device(
932
+ "cuda") if torch.cuda.is_available() else torch.device("cpu")
933
+
934
+ # waveform, sr = torchaudio.load(audio_path)
935
+ # fbank = torch.zeros((height, 64))
936
+ # ta_kaldi_fbank = extract_kaldi_fbank_feature(waveform, sr, fbank, num_mels=64)
937
+ # mel_spect_tensor = ta_kaldi_fbank.unsqueeze(0).unsqueeze(0)
938
+
939
+ mel_spect_tensor = wav_to_mel(audio_path, duration=audio_length_in_s).unsqueeze(0)
940
+ output_path = audio_path.replace('.wav', '_fbank.png')
941
+ visualize_mel_spectrogram(mel_spect_tensor, output_path)
942
+ mel_spect_tensor = mel_spect_tensor.to(next(self.vae.parameters()).dtype)
943
+ # print(f'mel_spect_tensor dtype: {mel_spect_tensor.dtype}')
944
+ # print(f'self.vae dtype: {next(self.vae.parameters()).dtype}')
945
+ latents = self.vae.encode(mel_spect_tensor.to(DEVICE))['latent_dist'].mean
946
+ return latents
947
+
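+ # Hedged usage sketch: `aud2latent` returns the clean VAE latent of a reference clip,
+ # which `ddim_inversion` below then pushes back toward the scheduler's noise space:
+ #
+ #     latent = self.aud2latent("reference.wav", audio_length_in_s=10.0)
+ #     noisy_latent = self.ddim_inversion(latent, prompt_embeds, attention_mask,
+ #                                        generated_prompt_embeds, guidance_scale, 200)
+ #
+ # ("reference.wav" and the step count are illustrative values, not defaults.)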
948
+ @torch.no_grad()
949
+ def ddim_inversion(self, start_latents, prompt_embeds, attention_mask, generated_prompt_embeds, guidance_scale,num_inference_steps):
950
+ start_step = 0
951
+ num_inference_steps = num_inference_steps
952
+ device = start_latents.device
953
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
954
+ start_latents *= self.scheduler.init_noise_sigma
955
+ latents = start_latents.clone()
956
+ for i in tqdm(range(start_step, num_inference_steps)):
957
+ t = self.scheduler.timesteps[i]
958
+ latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1. else latents
959
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
960
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=generated_prompt_embeds, encoder_hidden_states_1=prompt_embeds, encoder_attention_mask_1=attention_mask).sample
961
+ if guidance_scale > 1.:
962
+ noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
963
+ noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon)
964
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
965
+ return latents
966
+
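+ # Guidance arithmetic used in the loop above (restating the code): with
+ # guidance_scale > 1 the latent batch is doubled, the UNet prediction is split into
+ # (uncond, cond), and the guided noise estimate is
+ #     noise_pred = uncond + guidance_scale * (cond - uncond)
+ # With guidance_scale <= 1 the duplication is skipped and the raw prediction is used.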
967
+ def generate_morphing_prompt(self, prompt_1, prompt_2, alpha):
968
+ closer_prompt = prompt_1 if alpha <= 0.5 else prompt_2
969
+ prompt = (
970
+ f"A musical performance morphing between '{prompt_1}' and '{prompt_2}'. "
971
+ f"The sound is closer to '{closer_prompt}' with an interpolation factor of alpha={alpha:.2f}, "
972
+ f"where alpha=0 corresponds fully to {prompt_1} and alpha=1 corresponds fully to {prompt_2}."
973
+ )
974
+ return prompt
975
+
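+ # Example of the generated prompt (illustrative inputs, output abridged): with
+ # prompt_1="a cello melody", prompt_2="a jazz drum solo" and alpha=0.25 this returns
+ #     "A musical performance morphing between 'a cello melody' and 'a jazz drum solo'. "
+ #     "The sound is closer to 'a cello melody' with an interpolation factor of alpha=0.25, ..."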
976
+ @torch.no_grad()
977
+ def cal_latent(self,audio_length_in_s,time_pooling, freq_pooling,num_inference_steps, guidance_scale, aud_noise_1, aud_noise_2, prompt_1, prompt_2,
978
+ prompt_embeds_1, attention_mask_1, generated_prompt_embeds_1, prompt_embeds_2, attention_mask_2, generated_prompt_embeds_2,
979
+ alpha, original_processor,attn_processor_dict, use_morph_prompt, morphing_with_lora):
980
+ latents = slerp(aud_noise_1, aud_noise_2, alpha, self.use_adain)
981
+ if not use_morph_prompt:
982
+ max_length = max(prompt_embeds_1.shape[1], prompt_embeds_2.shape[1])
983
+ if prompt_embeds_1.shape[1] < max_length:
984
+ pad_size = max_length - prompt_embeds_1.shape[1]
985
+ padding = torch.zeros(
986
+ (prompt_embeds_1.shape[0], pad_size, prompt_embeds_1.shape[2]),
987
+ device=prompt_embeds_1.device,
988
+ dtype=prompt_embeds_1.dtype
989
+ )
990
+ prompt_embeds_1 = torch.cat([prompt_embeds_1, padding], dim=1)
991
+
992
+ if prompt_embeds_2.shape[1] < max_length:
993
+ pad_size = max_length - prompt_embeds_2.shape[1]
994
+ padding = torch.zeros(
995
+ (prompt_embeds_2.shape[0], pad_size, prompt_embeds_2.shape[2]),
996
+ device=prompt_embeds_2.device,
997
+ dtype=prompt_embeds_2.dtype
998
+ )
999
+ prompt_embeds_2 = torch.cat([prompt_embeds_2, padding], dim=1)
1000
+
1001
+ if attention_mask_1.shape[1] < max_length:
1002
+ pad_size = max_length - attention_mask_1.shape[1]
1003
+ padding = torch.zeros(
1004
+ (attention_mask_1.shape[0], pad_size),
1005
+ device=attention_mask_1.device,
1006
+ dtype=attention_mask_1.dtype
1007
+ )
1008
+ attention_mask_1 = torch.cat([attention_mask_1, padding], dim=1)
1009
+
1010
+ if attention_mask_2.shape[1] < max_length:
1011
+ pad_size = max_length - attention_mask_2.shape[1]
1012
+ padding = torch.zeros(
1013
+ (attention_mask_2.shape[0], pad_size),
1014
+ device=attention_mask_2.device,
1015
+ dtype=attention_mask_2.dtype
1016
+ )
1017
+ attention_mask_2 = torch.cat([attention_mask_2, padding], dim=1)
1018
+
1019
+ prompt_embeds = (1 - alpha) * prompt_embeds_1 + \
1020
+ alpha * prompt_embeds_2
1021
+ generated_prompt_embeds = (1 - alpha) * generated_prompt_embeds_1 + \
1022
+ alpha * generated_prompt_embeds_2
1023
+ attention_mask = attention_mask_1 if alpha < 0.5 else attention_mask_2
1024
+ # attention_mask = attention_mask_1 & attention_mask_2
1025
+ # attention_mask = attention_mask_1 | attention_mask_2
1026
+ # attention_mask = (1 - alpha) * attention_mask_1 + alpha * attention_mask_2
1027
+ # attention_mask = (attention_mask > 0.5).long()
1028
+
1029
+ if morphing_with_lora:
1030
+ pipeline_trained.unet.set_attn_processor(attn_processor_dict)
1031
+ waveform = pipeline_trained(
1032
+ time_pooling= time_pooling,
1033
+ freq_pooling= freq_pooling,
1034
+ latents = latents,
1035
+ num_inference_steps= num_inference_steps,
1036
+ guidance_scale= guidance_scale,
1037
+ num_waveforms_per_prompt= 1,
1038
+ audio_length_in_s=audio_length_in_s,
1039
+ prompt_embeds = prompt_embeds.chunk(2)[1],
1040
+ negative_prompt_embeds = prompt_embeds.chunk(2)[0],
1041
+ generated_prompt_embeds = generated_prompt_embeds.chunk(2)[1],
1042
+ negative_generated_prompt_embeds = generated_prompt_embeds.chunk(2)[0],
1043
+ attention_mask = attention_mask.chunk(2)[1],
1044
+ negative_attention_mask = attention_mask.chunk(2)[0],
1045
+ ).audios[0]
1046
+ if morphing_with_lora:
1047
+ pipeline_trained.unet.set_attn_processor(original_processor)
1048
+ else:
1049
+ latent_model_input = latents
1050
+ morphing_prompt = self.generate_morphing_prompt(prompt_1, prompt_2, alpha)
1051
+ if morphing_with_lora:
1052
+ pipeline_trained.unet.set_attn_processor(attn_processor_dict)
1053
+ waveform = pipeline_trained(
1054
+ time_pooling= time_pooling,
1055
+ freq_pooling= freq_pooling,
1056
+ latents = latent_model_input,
1057
+ num_inference_steps= num_inference_steps,
1058
+ guidance_scale= guidance_scale,
1059
+ num_waveforms_per_prompt= 1,
1060
+ audio_length_in_s=audio_length_in_s,
1061
+ prompt= morphing_prompt,
1062
+ negative_prompt= 'Low quality',
1063
+ ).audios[0]
1064
+ if morphing_with_lora:
1065
+ pipeline_trained.unet.set_attn_processor(original_processor)
1066
+
1067
+ return waveform
1068
+
1069
+ @torch.no_grad()
1070
+ def __call__(
1071
+ self,
1072
+ audio_file = None,
1073
+ audio_file2 = None,
1074
+ save_lora_dir = "./lora",
1075
+ load_lora_path_1 = None,
1076
+ load_lora_path_2 = None,
1077
+ lora_steps = 200,
1078
+ lora_lr = 2e-4,
1079
+ lora_rank = 16,
1080
+ time_pooling = 8,
1081
+ freq_pooling = 8,
1082
+ audio_length_in_s: Optional[float] = None,
1083
+ prompt_1: Union[str, List[str]] = None,
1084
+ prompt_2: Union[str, List[str]] = None,
1085
+ negative_prompt_1: Optional[Union[str, List[str]]] = None,
1086
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
1087
+ use_lora: bool = True,
1088
+ use_adain: bool = True,
1089
+ use_reschedule: bool = True,
1090
+ output_path: Optional[str] = None,
1091
+ num_inference_steps: int = 200,
1092
+ guidance_scale: float = 7.5,
1093
+ num_waveforms_per_prompt: Optional[int] = 1,
1094
+ attn_beta=0,
1095
+ lamd=0.6,
1096
+ fix_lora=None,
1097
+ save_intermediates=True,
1098
+ num_frames=50,
1099
+ max_new_tokens: Optional[int] = None,
1100
+ callback_steps: Optional[int] = 1,
1101
+ noisy_latent_with_lora=False,
1102
+ morphing_with_lora=False,
1103
+ use_morph_prompt=False,
1104
+ ):
1105
+ # 0. Load the pre-trained AP-adapter model
1106
+ layer_num = 0
1107
+ cross = [None, None, 768, 768, 1024, 1024, None, None]
1108
+ attn_procs = {}
1109
+ for name in self.unet.attn_processors.keys():
1110
+ cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim
1111
+ if name.startswith("mid_block"):
1112
+ hidden_size = self.unet.config.block_out_channels[-1]
1113
+ elif name.startswith("up_blocks"):
1114
+ block_id = int(name[len("up_blocks.")])
1115
+ hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id]
1116
+ elif name.startswith("down_blocks"):
1117
+ block_id = int(name[len("down_blocks.")])
1118
+ hidden_size = self.unet.config.block_out_channels[block_id]
1119
+
1120
+ if cross_attention_dim is None:
1121
+ attn_procs[name] = AttnProcessor2_0()
1122
+ else:
1123
+ cross_attention_dim = cross[layer_num % 8]
1124
+ layer_num += 1
1125
+ if cross_attention_dim == 768:
1126
+ attn_procs[name] = IPAttnProcessor2_0(
1127
+ hidden_size=hidden_size,
1128
+ name=name,
1129
+ cross_attention_dim=cross_attention_dim,
1130
+ scale=0.5,
1131
+ num_tokens=8,
1132
+ do_copy=False
1133
+ ).to("cuda", dtype=torch.float32)
1134
+ else:
1135
+ attn_procs[name] = AttnProcessor2_0()
1136
+
1137
+ state_dict = torch.load('/Data/home/Dennis/DeepMIR-2024/Final_Project/AP-adapter/pytorch_model.bin', map_location="cuda")
1138
+ for name, processor in attn_procs.items():
1139
+ if hasattr(processor, 'to_v_ip') or hasattr(processor, 'to_k_ip'):
1140
+ weight_name_v = name + ".to_v_ip.weight"
1141
+ weight_name_k = name + ".to_k_ip.weight"
1142
+ processor.to_v_ip.weight = torch.nn.Parameter(state_dict[weight_name_v].half())
1143
+ processor.to_k_ip.weight = torch.nn.Parameter(state_dict[weight_name_k].half())
1144
+ self.unet.set_attn_processor(attn_procs)
1145
+ self.vae = self.vae.to("cuda", dtype=torch.float32)
1146
+ self.unet = self.unet.to("cuda", dtype=torch.float32)
1147
+ self.language_model = self.language_model.to("cuda", dtype=torch.float32)
1148
+ self.projection_model = self.projection_model.to("cuda", dtype=torch.float32)
1149
+ self.vocoder = self.vocoder.to("cuda", dtype=torch.float32)
1150
+ self.text_encoder = self.text_encoder.to("cuda", dtype=torch.float32)
1151
+ self.text_encoder_2 = self.text_encoder_2.to("cuda", dtype=torch.float32)
1152
+
1153
+
1154
+
1155
+ # 1. Pre-check
1156
+ height, original_waveform_length = self.pre_check(audio_length_in_s, prompt_1, callback_steps, negative_prompt_1)
1157
+ _, _ = self.pre_check(audio_length_in_s, prompt_2, callback_steps, negative_prompt_2)
1158
+ # print(f"height: {height}, original_waveform_length: {original_waveform_length}") # height: 1000, original_waveform_length: 160000
1159
+
1160
+ # # 2. Define call parameters
1161
+ device = "cuda" if torch.cuda.is_available() else "cpu"
1162
+ do_classifier_free_guidance = guidance_scale > 1.0
1163
+ self.use_lora = use_lora
1164
+ self.use_adain = use_adain
1165
+ self.use_reschedule = use_reschedule
1166
+ self.output_path = output_path
1167
+
1168
+ if self.use_lora:
1169
+ print("Loading lora...")
1170
+ if not load_lora_path_1:
1171
+
1172
+ weight_name = f"{output_path.split('/')[-1]}_lora_0.ckpt"
1173
+ load_lora_path_1 = save_lora_dir + "/" + weight_name
1174
+ if not os.path.exists(load_lora_path_1):
1175
+ train_lora(audio_file ,height ,time_pooling ,freq_pooling ,prompt_1, negative_prompt_1, guidance_scale, save_lora_dir, self.tokenizer, self.tokenizer_2,
1176
+ self.text_encoder, self.text_encoder_2, self.language_model, self.projection_model, self.vocoder,
1177
+ self.vae, self.unet, self.scheduler, lora_steps, lora_lr, lora_rank, weight_name=weight_name)
1178
+ print(f"Load from {load_lora_path_1}.")
1179
+
1180
+ if load_lora_path_1.endswith(".safetensors"):
1181
+ lora_1 = safetensors.torch.load_file(
1182
+ load_lora_path_1, device="cpu")
1183
+ else:
1184
+ lora_1 = torch.load(load_lora_path_1, map_location="cpu")
1185
+
1186
+ if not load_lora_path_2:
1187
+ weight_name = f"{output_path.split('/')[-1]}_lora_1.ckpt"
1188
+ load_lora_path_2 = save_lora_dir + "/" + weight_name
1189
+ if not os.path.exists(load_lora_path_2):
1190
+ train_lora(audio_file2 ,height,time_pooling ,freq_pooling ,prompt_2, negative_prompt_2, guidance_scale, save_lora_dir, self.tokenizer, self.tokenizer_2,
1191
+ self.text_encoder, self.text_encoder_2, self.language_model, self.projection_model, self.vocoder,
1192
+ self.vae, self.unet, self.scheduler, lora_steps, lora_lr, lora_rank, weight_name=weight_name)
1193
+ print(f"Load from {load_lora_path_2}.")
1194
+ if load_lora_path_2.endswith(".safetensors"):
1195
+ lora_2 = safetensors.torch.load_file(
1196
+ load_lora_path_2, device="cpu")
1197
+ else:
1198
+ lora_2 = torch.load(load_lora_path_2, map_location="cpu")
1199
+ else:
1200
+ lora_1 = lora_2 = None
1201
+
1202
+ # # 3. Encode input prompt
1203
+ encoded_prompt_1, encoded_prompt_2 = self.encode_prompt_for_2_sources(prompt_1, prompt_2, negative_prompt_1, negative_prompt_2, max_new_tokens, device, num_waveforms_per_prompt, do_classifier_free_guidance)
1204
+ prompt_embeds_1, attention_mask_1, generated_prompt_embeds_1 = self.process_encoded_prompt(encoded_prompt_1, audio_file, time_pooling, freq_pooling)
1205
+ prompt_embeds_2, attention_mask_2, generated_prompt_embeds_2 = self.process_encoded_prompt(encoded_prompt_2, audio_file2, time_pooling, freq_pooling)
1206
+
1207
+
1208
+ # 4. Prepare latent variables
1209
+ # For the first audio file
1210
+ original_processor = list(self.unet.attn_processors.values())[0]
1211
+
1212
+ if noisy_latent_with_lora:
1213
+ self.unet = load_lora(self.unet, lora_1, lora_2, 0)
1214
+ # print(self.unet.attn_processors)
1215
+ # We directly use the latent representation of the audio file for VAE's decoder as the 1st ground truth
1216
+ audio_latent = self.aud2latent(audio_file, audio_length_in_s).to(device)
1217
+ # mel_spectrogram = self.vae.decode(audio_latent).sample
1218
+ # first_audio = self.mel_spectrogram_to_waveform(mel_spectrogram)
1219
+ # first_audio = first_audio[:, :original_waveform_length]
1220
+ # torchaudio.save(f"{self.output_path}/{0:02d}_gt.wav", first_audio, 16000)
1221
+
1222
+ # aud_noise_1 is the noisy latent representation of the audio file 1
1223
+ aud_noise_1 = self.ddim_inversion(audio_latent, prompt_embeds_1, attention_mask_1, generated_prompt_embeds_1, guidance_scale, num_inference_steps)
1224
+ # We use the pre-trained model to generate the audio file from the noisy latent representation
1225
+ # waveform = pipeline_trained(
1226
+ # audio_file = audio_file,
1227
+ # time_pooling= 2,
1228
+ # freq_pooling= 2,
1229
+ # prompt= prompt_1,
1230
+ # latents = aud_noise_1,
1231
+ # negative_prompt= negative_prompt_1,
1232
+ # num_inference_steps= 100,
1233
+ # guidance_scale= guidance_scale,
1234
+ # num_waveforms_per_prompt= 1,
1235
+ # audio_length_in_s=10,
1236
+ # ).audios
1237
+ # file_path = os.path.join(self.output_path, f"{0:02d}_gt2.wav")
1238
+ # scipy.io.wavfile.write(file_path, rate=16000, data=waveform[0])
1239
+
1240
+ # After reconstructed the audio file 1, we set the original processor back
1241
+ if noisy_latent_with_lora:
1242
+ self.unet.set_attn_processor(original_processor)
1243
+ # print(self.unet.attn_processors)
1244
+
1245
+ # For the second audio file
1246
+ if noisy_latent_with_lora:
1247
+ self.unet = load_lora(self.unet, lora_1, lora_2, 1)
1248
+ # print(self.unet.attn_processors)
1249
+ # We directly use the latent representation of the audio file for VAE's decoder as the 1st ground truth
1250
+ audio_latent = self.aud2latent(audio_file2, audio_length_in_s)
1251
+ # mel_spectrogram = self.vae.decode(audio_latent).sample
1252
+ # last_audio = self.mel_spectrogram_to_waveform(mel_spectrogram)
1253
+ # last_audio = last_audio[:, :original_waveform_length]
1254
+ # torchaudio.save(f"{self.output_path}/{num_frames-1:02d}_gt.wav", last_audio, 16000)
1255
+ # aud_noise_2 is the noisy latent representation of the audio file 2
1256
+ aud_noise_2 = self.ddim_inversion(audio_latent, prompt_embeds_2, attention_mask_2, generated_prompt_embeds_2, guidance_scale, num_inference_steps)
1257
+ # waveform = pipeline_trained(
1258
+ # audio_file = audio_file2,
1259
+ # time_pooling= 2,
1260
+ # freq_pooling= 2,
1261
+ # prompt= prompt_2,
1262
+ # latents = aud_noise_2,
1263
+ # negative_prompt= negative_prompt_2,
1264
+ # num_inference_steps= 100,
1265
+ # guidance_scale= guidance_scale,
1266
+ # num_waveforms_per_prompt= 1,
1267
+ # audio_length_in_s=10,
1268
+ # ).audios
1269
+ # file_path = os.path.join(self.output_path, f"{num_frames-1:02d}_gt2.wav")
1270
+ # scipy.io.wavfile.write(file_path, rate=16000, data=waveform[0])
1271
+ if noisy_latent_with_lora:
1272
+ self.unet.set_attn_processor(original_processor)
1273
+ # print(self.unet.attn_processors)
1274
+ # After reconstructed the audio file 1, we set the original processor back
1275
+ original_processor = list(self.unet.attn_processors.values())[0]
1276
+
1277
+
1278
+ def morph(alpha_list, desc):
1279
+ audios = []
1280
+ # if attn_beta is not None:
1281
+ if self.use_lora:
1282
+ self.unet = load_lora(
1283
+ self.unet, lora_1, lora_2, 0 if fix_lora is None else fix_lora)
1284
+ attn_processor_dict = {}
1285
+ # print(self.unet.attn_processors)
1286
+ for k in self.unet.attn_processors.keys():
1287
+ # print(k)
1288
+ if do_replace_attn(k):
1289
+ # print(f"Since the key starts with *up*, we replace the processor with StoreProcessor.")
1290
+ if self.use_lora:
1291
+ attn_processor_dict[k] = StoreProcessor(self.unet.attn_processors[k],
1292
+ self.aud1_dict, k)
1293
+ else:
1294
+ attn_processor_dict[k] = StoreProcessor(original_processor,
1295
+ self.aud1_dict, k)
1296
+ else:
1297
+ attn_processor_dict[k] = self.unet.attn_processors[k]
1298
+ # print(attn_processor_dict)
1299
+
1300
+ # print(attn_processor_dict)
1301
+
1302
+ # print(self.unet.attn_processors)
1303
+ # self.unet.set_attn_processor(attn_processor_dict)
1304
+ # print(self.unet.attn_processors)
1305
+
1306
+ first_audio = self.cal_latent(
1307
+ audio_length_in_s,
1308
+ time_pooling,
1309
+ freq_pooling,
1310
+ num_inference_steps,
1311
+ guidance_scale,
1312
+ aud_noise_1,
1313
+ aud_noise_2,
1314
+ prompt_1,
1315
+ prompt_2,
1316
+ prompt_embeds_1,
1317
+ attention_mask_1,
1318
+ generated_prompt_embeds_1,
1319
+ prompt_embeds_2,
1320
+ attention_mask_2,
1321
+ generated_prompt_embeds_2,
1322
+ alpha_list[0],
1323
+ original_processor,
1324
+ attn_processor_dict,
1325
+ use_morph_prompt,
1326
+ morphing_with_lora
1327
+ )
1328
+
1329
+ self.unet.set_attn_processor(original_processor)
1330
+ file_path = os.path.join(self.output_path, f"{0:02d}.wav")
1331
+ scipy.io.wavfile.write(file_path, rate=16000, data=first_audio)
1332
+
1333
+ if self.use_lora:
1334
+ self.unet = load_lora(
1335
+ self.unet, lora_1, lora_2, 1 if fix_lora is None else fix_lora)
1336
+ attn_processor_dict = {}
1337
+ for k in self.unet.attn_processors.keys():
1338
+ if do_replace_attn(k):
1339
+ # print(f"Since the key starts with *up*, we replace the processor with StoreProcessor.")
1340
+ if self.use_lora:
1341
+ attn_processor_dict[k] = StoreProcessor(self.unet.attn_processors[k],
1342
+ self.aud2_dict, k)
1343
+ else:
1344
+ attn_processor_dict[k] = StoreProcessor(original_processor,
1345
+ self.aud2_dict, k)
1346
+ else:
1347
+ attn_processor_dict[k] = self.unet.attn_processors[k]
1348
+ # self.unet.set_attn_processor(attn_processor_dict)
1349
+ last_audio = self.cal_latent(
1350
+ audio_length_in_s,
1351
+ time_pooling,
1352
+ freq_pooling,
1353
+ num_inference_steps,
1354
+ guidance_scale,
1355
+ aud_noise_1,
1356
+ aud_noise_2,
1357
+ prompt_1,
1358
+ prompt_2,
1359
+ prompt_embeds_1,
1360
+ attention_mask_1,
1361
+ generated_prompt_embeds_1,
1362
+ prompt_embeds_2,
1363
+ attention_mask_2,
1364
+ generated_prompt_embeds_2,
1365
+ alpha_list[-1],
1366
+ original_processor,
1367
+ attn_processor_dict,
1368
+ use_morph_prompt,
1369
+ morphing_with_lora
1370
+ )
1371
+ file_path = os.path.join(self.output_path, f"{num_frames-1:02d}.wav")
1372
+ scipy.io.wavfile.write(file_path, rate=16000, data=last_audio)
1373
+ self.unet.set_attn_processor(original_processor)
1374
+
1375
+ for i in tqdm(range(1, num_frames - 1), desc=desc):
1376
+ alpha = alpha_list[i]
1377
+ if self.use_lora:
1378
+ self.unet = load_lora(
1379
+ self.unet, lora_1, lora_2, alpha if fix_lora is None else fix_lora)
1380
+
1381
+ attn_processor_dict = {}
1382
+ for k in self.unet.attn_processors.keys():
1383
+ if do_replace_attn(k):
1384
+ if self.use_lora:
1385
+ attn_processor_dict[k] = LoadProcessor(
1386
+ self.unet.attn_processors[k], k, self.aud1_dict, self.aud2_dict, alpha, attn_beta, lamd)
1387
+ else:
1388
+ attn_processor_dict[k] = LoadProcessor(
1389
+ original_processor, k, self.aud1_dict, self.aud2_dict, alpha, attn_beta, lamd)
1390
+ else:
1391
+ attn_processor_dict[k] = self.unet.attn_processors[k]
1392
+ # self.unet.set_attn_processor(attn_processor_dict)
1393
+ audio = self.cal_latent(
1394
+ audio_length_in_s,
1395
+ time_pooling,
1396
+ freq_pooling,
1397
+ num_inference_steps,
1398
+ guidance_scale,
1399
+ aud_noise_1,
1400
+ aud_noise_2,
1401
+ prompt_1,
1402
+ prompt_2,
1403
+ prompt_embeds_1,
1404
+ attention_mask_1,
1405
+ generated_prompt_embeds_1,
1406
+ prompt_embeds_2,
1407
+ attention_mask_2,
1408
+ generated_prompt_embeds_2,
1409
+ alpha_list[i],
1410
+ original_processor,
1411
+ attn_processor_dict,
1412
+ use_morph_prompt,
1413
+ morphing_with_lora
1414
+ )
1415
+ file_path = os.path.join(self.output_path, f"{i:02d}.wav")
1416
+ scipy.io.wavfile.write(file_path, rate=16000, data=audio)
1417
+ self.unet.set_attn_processor(original_processor)
1418
+ audios.append(audio)
1419
+ audios = [first_audio] + audios + [last_audio]
1420
+ return audios
1421
+ with torch.no_grad():
1422
+ if self.use_reschedule:
1423
+ alpha_scheduler = AlphaScheduler()
1424
+ alpha_list = list(torch.linspace(0, 1, num_frames))
1425
+ audios_pt = morph(alpha_list, "Sampling...")
1426
+ audios_pt = [torch.tensor(aud).unsqueeze(0)
1427
+ for aud in audios_pt]
1428
+ alpha_scheduler.from_imgs(audios_pt)
1429
+ alpha_list = alpha_scheduler.get_list()
1430
+ audios = morph(alpha_list, "Reschedule...")
1431
+ else:
1432
+ alpha_list = list(torch.linspace(0, 1, num_frames))
1433
+ audios = morph(alpha_list, "Sampling...")
1434
+
1435
+ return audios
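+ # End-to-end usage sketch for this morphing pipeline (hedged: the file names, prompts
+ # and parameter values below are illustrative, not defaults):
+ #
+ #     audios = pipe(
+ #         audio_file="cat.wav",
+ #         audio_file2="dog.wav",
+ #         prompt_1="a cat meowing",
+ #         prompt_2="a dog barking",
+ #         negative_prompt_1="low quality",
+ #         negative_prompt_2="low quality",
+ #         audio_length_in_s=10.0,
+ #         num_inference_steps=100,
+ #         guidance_scale=7.5,
+ #         num_frames=5,
+ #         output_path="./morph_out",
+ #     )  # writes 00.wav ... 04.wav to output_path and returns the waveforms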
pipeline/pipeline_audioldm.py ADDED
@@ -0,0 +1,597 @@
1
+
2
+ import inspect
3
+ from typing import Any, Callable, Dict, List, Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from transformers import ClapTextModelWithProjection, RobertaTokenizer, RobertaTokenizerFast, SpeechT5HifiGan
9
+
10
+ from diffusers import AutoencoderKL, UNet2DConditionModel
11
+ from diffusers.schedulers import KarrasDiffusionSchedulers
12
+ from diffusers.utils.torch_utils import randn_tensor
13
+ from diffusers.utils import is_accelerate_available, logging, replace_example_docstring
14
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, AudioPipelineOutput
15
+ # from diffusers.pipelines.pipeline_utils import AudioPipelineOutput
16
+ from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
17
+
18
+
19
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
20
+
21
+ EXAMPLE_DOC_STRING = """
22
+ Examples:
23
+ ```py
24
+ >>> import torch
25
+ >>> from diffusers import AudioLDMPipeline
26
+
27
+ >>> pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm", torch_dtype=torch.float16)
28
+ >>> pipe = pipe.to("cuda")
29
+
30
+ >>> prompt = "A hammer hitting a wooden surface"
31
+ >>> audio = pipe(prompt).audios[0]
32
+ ```
33
+ """
34
+
35
+
36
+ class AudioLDMPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
37
+ r"""
38
+ Pipeline for text-to-audio generation using AudioLDM.
39
+
40
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
41
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
42
+
43
+ Args:
44
+ vae ([`AutoencoderKL`]):
45
+ Variational Auto-Encoder (VAE) Model to encode and decode audios to and from latent representations.
46
+ text_encoder ([`ClapTextModelWithProjection`]):
47
+ Frozen text-encoder. AudioLDM uses the text portion of
48
+ [CLAP](https://huggingface.co/docs/transformers/main/model_doc/clap#transformers.ClapTextModelWithProjection),
49
+ specifically the [RoBERTa HTSAT-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant.
50
+ tokenizer ([`PreTrainedTokenizer`]):
51
+ Tokenizer of class
52
+ [RobertaTokenizer](https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaTokenizer).
53
+ unet ([`UNet2DConditionModel`]): U-Net architecture to denoise the encoded audio latents.
54
+ scheduler ([`SchedulerMixin`]):
55
+ A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of
56
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
57
+ vocoder ([`SpeechT5HifiGan`]):
58
+ Vocoder of class
59
+ [SpeechT5HifiGan](https://huggingface.co/docs/transformers/main/en/model_doc/speecht5#transformers.SpeechT5HifiGan).
60
+ """
61
+
62
+ def __init__(
63
+ self,
64
+ vae: AutoencoderKL,
65
+ text_encoder: ClapTextModelWithProjection,
66
+ tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
67
+ unet: UNet2DConditionModel,
68
+ scheduler: KarrasDiffusionSchedulers,
69
+ vocoder: SpeechT5HifiGan,
70
+ ):
71
+ super().__init__()
72
+
73
+ self.register_modules(
74
+ vae=vae,
75
+ text_encoder=text_encoder,
76
+ tokenizer=tokenizer,
77
+ unet=unet,
78
+ scheduler=scheduler,
79
+ vocoder=vocoder,
80
+ )
81
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
82
+
83
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
84
+ def enable_vae_slicing(self):
85
+ r"""
86
+ Enable sliced VAE decoding.
87
+
88
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
89
+ steps. This is useful to save some memory and allow larger batch sizes.
90
+ """
91
+ self.vae.enable_slicing()
92
+
93
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
94
+ def disable_vae_slicing(self):
95
+ r"""
96
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
97
+ computing decoding in one step.
98
+ """
99
+ self.vae.disable_slicing()
100
+
101
+ def enable_sequential_cpu_offload(self, gpu_id=0):
102
+ r"""
103
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
104
+ text_encoder, vae and vocoder have their state dicts saved to CPU and then are moved to a `torch.device('meta')
105
+ and loaded to GPU only when their specific submodule has its `forward` method called.
106
+ """
107
+ if is_accelerate_available():
108
+ from accelerate import cpu_offload
109
+ else:
110
+ raise ImportError("Please install accelerate via `pip install accelerate`")
111
+
112
+ device = torch.device(f"cuda:{gpu_id}")
113
+
114
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.vocoder]:
115
+ cpu_offload(cpu_offloaded_model, device)
116
+
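+ # Usage sketch (standard accelerate offloading; call once after constructing the
+ # pipeline and before generating, instead of moving the pipeline to GPU yourself):
+ #
+ #     pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm", torch_dtype=torch.float16)
+ #     pipe.enable_sequential_cpu_offload()
+ #     audio = pipe("A hammer hitting a wooden surface").audios[0]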
117
+ @property
118
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
119
+ def _execution_device(self):
120
+ r"""
121
+ Returns the device on which the pipeline's models will be executed. After calling
122
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
123
+ hooks.
124
+ """
125
+ if not hasattr(self.unet, "_hf_hook"):
126
+ return self.device
127
+ for module in self.unet.modules():
128
+ if (
129
+ hasattr(module, "_hf_hook")
130
+ and hasattr(module._hf_hook, "execution_device")
131
+ and module._hf_hook.execution_device is not None
132
+ ):
133
+ return torch.device(module._hf_hook.execution_device)
134
+ return self.device
135
+
136
+ def _encode_prompt(
137
+ self,
138
+ prompt,
139
+ device,
140
+ num_waveforms_per_prompt,
141
+ do_classifier_free_guidance,
142
+ negative_prompt=None,
143
+ prompt_embeds: Optional[torch.FloatTensor] = None,
144
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
145
+ ):
146
+ r"""
147
+ Encodes the prompt into text encoder hidden states.
148
+
149
+ Args:
150
+ prompt (`str` or `List[str]`, *optional*):
151
+ prompt to be encoded
152
+ device (`torch.device`):
153
+ torch device
154
+ num_waveforms_per_prompt (`int`):
155
+ number of waveforms that should be generated per prompt
156
+ do_classifier_free_guidance (`bool`):
157
+ whether to use classifier free guidance or not
158
+ negative_prompt (`str` or `List[str]`, *optional*):
159
+ The prompt or prompts not to guide the audio generation. If not defined, one has to pass
160
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
161
+ less than `1`).
162
+ prompt_embeds (`torch.FloatTensor`, *optional*):
163
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
164
+ provided, text embeddings will be generated from `prompt` input argument.
165
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
166
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
167
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
168
+ argument.
169
+ """
170
+ if prompt is not None and isinstance(prompt, str):
171
+ batch_size = 1
172
+ elif prompt is not None and isinstance(prompt, list):
173
+ batch_size = len(prompt)
174
+ else:
175
+ batch_size = prompt_embeds.shape[0]
176
+
177
+ if prompt_embeds is None:
178
+ text_inputs = self.tokenizer(
179
+ prompt,
180
+ padding="max_length",
181
+ max_length=self.tokenizer.model_max_length,
182
+ truncation=True,
183
+ return_tensors="pt",
184
+ )
185
+ text_input_ids = text_inputs.input_ids
186
+ # print("text_input_ids: ", text_input_ids.shape)
187
+ attention_mask = text_inputs.attention_mask
188
+ # print("attention_mask: ", attention_mask.shape)
189
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
190
+
191
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
192
+ text_input_ids, untruncated_ids
193
+ ):
194
+ removed_text = self.tokenizer.batch_decode(
195
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
196
+ )
197
+ logger.warning(
198
+ "The following part of your input was truncated because CLAP can only handle sequences up to"
199
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
200
+ )
201
+
202
+ prompt_embeds = self.text_encoder(
203
+ text_input_ids.to(device),
204
+ attention_mask=attention_mask.to(device),
205
+ )
206
+ prompt_embeds = prompt_embeds.text_embeds
207
+ # additional L_2 normalization over each hidden-state
208
+ prompt_embeds = F.normalize(prompt_embeds, dim=-1)
209
+
210
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
211
+
212
+ (
213
+ bs_embed,
214
+ seq_len,
215
+ ) = prompt_embeds.shape
216
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
217
+ prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt)
218
+ prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len)
219
+
220
+ # get unconditional embeddings for classifier free guidance
221
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
222
+ uncond_tokens: List[str]
223
+ if negative_prompt is None:
224
+ uncond_tokens = [""] * batch_size
225
+ elif type(prompt) is not type(negative_prompt):
226
+ raise TypeError(
227
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
228
+ f" {type(prompt)}."
229
+ )
230
+ elif isinstance(negative_prompt, str):
231
+ uncond_tokens = [negative_prompt]
232
+ elif batch_size != len(negative_prompt):
233
+ raise ValueError(
234
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
235
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
236
+ " the batch size of `prompt`."
237
+ )
238
+ else:
239
+ uncond_tokens = negative_prompt
240
+
241
+ max_length = prompt_embeds.shape[1]
242
+ uncond_input = self.tokenizer(
243
+ uncond_tokens,
244
+ padding="max_length",
245
+ max_length=max_length,
246
+ truncation=True,
247
+ return_tensors="pt",
248
+ )
249
+
250
+ uncond_input_ids = uncond_input.input_ids.to(device)
251
+ attention_mask = uncond_input.attention_mask.to(device)
252
+
253
+ negative_prompt_embeds = self.text_encoder(
254
+ uncond_input_ids,
255
+ attention_mask=attention_mask,
256
+ )
257
+ negative_prompt_embeds = negative_prompt_embeds.text_embeds
258
+ # additional L_2 normalization over each hidden-state
259
+ negative_prompt_embeds = F.normalize(negative_prompt_embeds, dim=-1)
260
+
261
+ if do_classifier_free_guidance:
262
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
263
+ seq_len = negative_prompt_embeds.shape[1]
264
+
265
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
266
+
267
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt)
268
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len)
269
+
270
+ # For classifier free guidance, we need to do two forward passes.
271
+ # Here we concatenate the unconditional and text embeddings into a single batch
272
+ # to avoid doing two forward passes
273
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
274
+
275
+ return prompt_embeds
276
+
277
+ def decode_latents(self, latents):
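+ # undo the VAE scaling factor applied at encoding time, then decode the latents back into a mel spectrogram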
278
+ latents = 1 / self.vae.config.scaling_factor * latents
279
+ mel_spectrogram = self.vae.decode(latents).sample
280
+ return mel_spectrogram
281
+
282
+ def mel_spectrogram_to_waveform(self, mel_spectrogram):
283
+ if mel_spectrogram.dim() == 4:
284
+ mel_spectrogram = mel_spectrogram.squeeze(1)
285
+
286
+ waveform = self.vocoder(mel_spectrogram)
287
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
288
+ waveform = waveform.cpu().float()
289
+ return waveform
290
+
291
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
292
+ def prepare_extra_step_kwargs(self, generator, eta):
293
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
294
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
295
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
296
+ # and should be between [0, 1]
297
+
298
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
299
+ extra_step_kwargs = {}
300
+ if accepts_eta:
301
+ extra_step_kwargs["eta"] = eta
302
+
303
+ # check if the scheduler accepts generator
304
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
305
+ if accepts_generator:
306
+ extra_step_kwargs["generator"] = generator
307
+ return extra_step_kwargs
308
+
309
+ def check_inputs(
310
+ self,
311
+ prompt,
312
+ audio_length_in_s,
313
+ vocoder_upsample_factor,
314
+ callback_steps,
315
+ negative_prompt=None,
316
+ prompt_embeds=None,
317
+ negative_prompt_embeds=None,
318
+ ):
319
+ min_audio_length_in_s = vocoder_upsample_factor * self.vae_scale_factor
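+ # i.e. the duration of a single latent time step after VAE downsampling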
320
+ if audio_length_in_s < min_audio_length_in_s:
321
+ raise ValueError(
322
+ f"`audio_length_in_s` has to be a positive value greater than or equal to {min_audio_length_in_s}, but "
323
+ f"is {audio_length_in_s}."
324
+ )
325
+
326
+ if self.vocoder.config.model_in_dim % self.vae_scale_factor != 0:
327
+ raise ValueError(
328
+ f"The number of frequency bins in the vocoder's log-mel spectrogram has to be divisible by the "
329
+ f"VAE scale factor, but got {self.vocoder.config.model_in_dim} bins and a scale factor of "
330
+ f"{self.vae_scale_factor}."
331
+ )
332
+
333
+ if (callback_steps is None) or (
334
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
335
+ ):
336
+ raise ValueError(
337
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
338
+ f" {type(callback_steps)}."
339
+ )
340
+
341
+ if prompt is not None and prompt_embeds is not None:
342
+ raise ValueError(
343
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
344
+ " only forward one of the two."
345
+ )
346
+ elif prompt is None and prompt_embeds is None:
347
+ raise ValueError(
348
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
349
+ )
350
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
351
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
352
+
353
+ if negative_prompt is not None and negative_prompt_embeds is not None:
354
+ raise ValueError(
355
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
356
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
357
+ )
358
+
359
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
360
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
361
+ raise ValueError(
362
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
363
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
364
+ f" {negative_prompt_embeds.shape}."
365
+ )
366
+
367
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with width->self.vocoder.config.model_in_dim
368
+ def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None):
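+ # latents have shape (batch, channels, spectrogram frames // vae_scale_factor, mel bins // vae_scale_factor)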
369
+ shape = (
370
+ batch_size,
371
+ num_channels_latents,
372
+ height // self.vae_scale_factor,
373
+ self.vocoder.config.model_in_dim // self.vae_scale_factor,
374
+ )
375
+ if isinstance(generator, list) and len(generator) != batch_size:
376
+ raise ValueError(
377
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
378
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
379
+ )
380
+
381
+ if latents is None:
382
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
383
+ else:
384
+ latents = latents.to(device)
385
+
386
+ # scale the initial noise by the standard deviation required by the scheduler
387
+ latents = latents * self.scheduler.init_noise_sigma
388
+ return latents
389
+
390
+ @torch.no_grad()
391
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
392
+ def __call__(
393
+ self,
394
+ prompt: Union[str, List[str]] = None,
395
+ audio_length_in_s: Optional[float] = None,
396
+ num_inference_steps: int = 10,
397
+ guidance_scale: float = 2.5,
398
+ negative_prompt: Optional[Union[str, List[str]]] = None,
399
+ num_waveforms_per_prompt: Optional[int] = 1,
400
+ eta: float = 0.0,
401
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
402
+ latents: Optional[torch.FloatTensor] = None,
403
+ prompt_embeds: Optional[torch.FloatTensor] = None,
404
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
405
+ return_dict: bool = True,
406
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
407
+ callback_steps: Optional[int] = 1,
408
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
409
+ output_type: Optional[str] = "np",
410
+ ):
411
+ r"""
412
+ Function invoked when calling the pipeline for generation.
413
+
414
+ Args:
415
+ prompt (`str` or `List[str]`, *optional*):
416
+ The prompt or prompts to guide the audio generation. If not defined, one has to pass `prompt_embeds`
417
+ instead.
418
+ audio_length_in_s (`int`, *optional*, defaults to 5.12):
419
+ The length of the generated audio sample in seconds.
420
+ num_inference_steps (`int`, *optional*, defaults to 10):
421
+ The number of denoising steps. More denoising steps usually lead to a higher quality audio at the
422
+ expense of slower inference.
423
+ guidance_scale (`float`, *optional*, defaults to 2.5):
424
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
425
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
426
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
427
+ 1`. Higher guidance scale encourages to generate audios that are closely linked to the text `prompt`,
428
+ usually at the expense of lower sound quality.
429
+ negative_prompt (`str` or `List[str]`, *optional*):
430
+ The prompt or prompts not to guide the audio generation. If not defined, one has to pass
431
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
432
+ less than `1`).
433
+ num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
434
+ The number of waveforms to generate per prompt.
435
+ eta (`float`, *optional*, defaults to 0.0):
436
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
437
+ [`schedulers.DDIMScheduler`], will be ignored for others.
438
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
439
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
440
+ to make generation deterministic.
441
+ latents (`torch.FloatTensor`, *optional*):
442
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for audio
443
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
444
+ tensor will be generated by sampling using the supplied random `generator`.
445
+ prompt_embeds (`torch.FloatTensor`, *optional*):
446
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
447
+ provided, text embeddings will be generated from `prompt` input argument.
448
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
449
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
450
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
451
+ argument.
452
+ return_dict (`bool`, *optional*, defaults to `True`):
453
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
454
+ plain tuple.
455
+ callback (`Callable`, *optional*):
456
+ A function that will be called every `callback_steps` steps during inference. The function will be
457
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
458
+ callback_steps (`int`, *optional*, defaults to 1):
459
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
460
+ called at every step.
461
+ cross_attention_kwargs (`dict`, *optional*):
462
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
463
+ `self.processor` in
464
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
465
+ output_type (`str`, *optional*, defaults to `"np"`):
466
+ The output format of the generated audio. Choose between:
467
+ - `"np"`: Return Numpy `np.ndarray` objects.
468
+ - `"pt"`: Return PyTorch `torch.Tensor` objects.
469
+
470
+ Examples:
471
+
472
+ Returns:
473
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
474
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
475
+ When returning a tuple, the first element is a list with the generated audios.
476
+ """
477
+ # 0. Convert audio input length from seconds to spectrogram height
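+ # np.prod(upsample_rates) is the number of audio samples the vocoder produces per spectrogram frame, so this factor is the duration of one frame in seconds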
478
+ vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate
479
+
480
+ if audio_length_in_s is None:
481
+ audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor
482
+
483
+ height = int(audio_length_in_s / vocoder_upsample_factor)
484
+
485
+ original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate)
486
+ if height % self.vae_scale_factor != 0:
487
+ height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor
488
+ logger.info(
489
+ f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} "
490
+ f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the "
491
+ f"denoising process."
492
+ )
493
+
494
+ # 1. Check inputs. Raise error if not correct
495
+ self.check_inputs(
496
+ prompt,
497
+ audio_length_in_s,
498
+ vocoder_upsample_factor,
499
+ callback_steps,
500
+ negative_prompt,
501
+ prompt_embeds,
502
+ negative_prompt_embeds,
503
+ )
504
+
505
+ # 2. Define call parameters
506
+ if prompt is not None and isinstance(prompt, str):
507
+ batch_size = 1
508
+ elif prompt is not None and isinstance(prompt, list):
509
+ batch_size = len(prompt)
510
+ else:
511
+ batch_size = prompt_embeds.shape[0]
512
+
513
+ device = self._execution_device
514
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
515
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
516
+ # corresponds to doing no classifier free guidance.
517
+ do_classifier_free_guidance = guidance_scale > 1.0
518
+
519
+ # 3. Encode input prompt
520
+ prompt_embeds = self._encode_prompt(
521
+ prompt,
522
+ device,
523
+ num_waveforms_per_prompt,
524
+ do_classifier_free_guidance,
525
+ negative_prompt,
526
+ prompt_embeds=prompt_embeds,
527
+ negative_prompt_embeds=negative_prompt_embeds,
528
+ )
529
+
530
+ # 4. Prepare timesteps
531
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
532
+ timesteps = self.scheduler.timesteps
533
+
534
+ # 5. Prepare latent variables
535
+ num_channels_latents = self.unet.config.in_channels
536
+ latents = self.prepare_latents(
537
+ batch_size * num_waveforms_per_prompt,
538
+ num_channels_latents,
539
+ height,
540
+ prompt_embeds.dtype,
541
+ device,
542
+ generator,
543
+ latents,
544
+ )
545
+
546
+
547
+
548
+ # 6. Prepare extra step kwargs
549
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
550
+
551
+ # 7. Denoising loop
552
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
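+ # used below to keep the progress bar and callback in sync with schedulers that take several internal steps per iteration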
553
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
554
+ for i, t in enumerate(timesteps):
555
+ # expand the latents if we are doing classifier free guidance
556
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
557
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
558
+
559
+ # predict the noise residual
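+ # the pooled CLAP text embedding conditions the UNet through `class_labels`; no cross-attention hidden states are passed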
560
+ noise_pred = self.unet(
561
+ latent_model_input,
562
+ t,
563
+ encoder_hidden_states=None,
564
+ class_labels=prompt_embeds,
565
+ cross_attention_kwargs=cross_attention_kwargs,
566
+ ).sample
567
+
568
+ # perform guidance
569
+ if do_classifier_free_guidance:
570
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
571
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
572
+
573
+ # compute the previous noisy sample x_t -> x_t-1
574
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
575
+
576
+ # call the callback, if provided
577
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
578
+ progress_bar.update()
579
+ if callback is not None and i % callback_steps == 0:
580
+ callback(i, t, latents)
581
+
582
+ # 8. Post-processing
583
+ mel_spectrogram = self.decode_latents(latents)
584
+
585
+ audio = self.mel_spectrogram_to_waveform(mel_spectrogram)
586
+
587
+ audio = audio[:, :original_waveform_length]
588
+
589
+ if output_type == "np":
590
+ audio = audio.numpy()
591
+
592
+ if not return_dict:
593
+ return (audio,)
594
+
595
+ return AudioPipelineOutput(audios=audio)
596
+
597
+
pipeline/pipeline_audioldm2.py ADDED
@@ -0,0 +1,1080 @@
1
+ # Copyright 2023 CVSSP, ByteDance and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from audio_encoder.AudioMAE import AudioMAEConditionCTPoolRand, extract_kaldi_fbank_feature
15
+ import torchaudio
16
+ import inspect
17
+ from typing import Any, Callable, Dict, List, Optional, Union
18
+ import random
19
+ import numpy as np
20
+ import torch
21
+ from transformers import (
22
+ ClapFeatureExtractor,
23
+ ClapModel,
24
+ GPT2Model,
25
+ RobertaTokenizer,
26
+ RobertaTokenizerFast,
27
+ SpeechT5HifiGan,
28
+ T5EncoderModel,
29
+ T5Tokenizer,
30
+ T5TokenizerFast,
31
+ )
32
+
33
+ from diffusers import AutoencoderKL
34
+ from diffusers.schedulers import KarrasDiffusionSchedulers
35
+ from diffusers.utils import (
36
+ is_accelerate_available,
37
+ is_accelerate_version,
38
+ is_librosa_available,
39
+ logging,
40
+ replace_example_docstring,
41
+ )
42
+ from diffusers.utils.torch_utils import randn_tensor
43
+ from diffusers.pipelines.pipeline_utils import AudioPipelineOutput, DiffusionPipeline
44
+ from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
45
+ from diffusers.loaders import TextualInversionLoaderMixin
46
+ from audioldm.utils import default_audioldm_config
47
+ from audioldm.audio import TacotronSTFT, read_wav_file
48
+ from audioldm.audio.tools import get_mel_from_wav, _pad_spec, normalize_wav, pad_wav
49
+
50
+ if is_librosa_available():
51
+ import librosa
52
+
53
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
54
+
55
+ EXAMPLE_DOC_STRING = """
56
+ Examples:
57
+ ```py
58
+ >>> import scipy
59
+ >>> import torch
60
+ >>> from diffusers import AudioLDM2Pipeline
61
+
62
+ >>> repo_id = "cvssp/audioldm2"
63
+ >>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
64
+ >>> pipe = pipe.to("cuda")
65
+
66
+ >>> # define the prompts
67
+ >>> prompt = "The sound of a hammer hitting a wooden surface."
68
+ >>> negative_prompt = "Low quality."
69
+
70
+ >>> # set the seed for generator
71
+ >>> generator = torch.Generator("cuda").manual_seed(0)
72
+
73
+ >>> # run the generation
74
+ >>> audio = pipe(
75
+ ... prompt,
76
+ ... negative_prompt=negative_prompt,
77
+ ... num_inference_steps=200,
78
+ ... audio_length_in_s=10.0,
79
+ ... num_waveforms_per_prompt=3,
80
+ ... generator=generator,
81
+ ... ).audios
82
+
83
+ >>> # save the best audio sample (index 0) as a .wav file
84
+ >>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio[0])
85
+ ```
86
+ """
87
+
88
+
89
+ def prepare_inputs_for_generation(
90
+ inputs_embeds,
91
+ attention_mask=None,
92
+ past_key_values=None,
93
+ **kwargs,
94
+ ):
95
+ if past_key_values is not None:
96
+ # only last token for inputs_embeds if past is defined in kwargs
97
+ inputs_embeds = inputs_embeds[:, -1:]
98
+
99
+ return {
100
+ "inputs_embeds": inputs_embeds,
101
+ "attention_mask": attention_mask,
102
+ "past_key_values": past_key_values,
103
+ "use_cache": kwargs.get("use_cache"),
104
+ }
105
+
106
+
107
+ class AudioLDM2Pipeline(DiffusionPipeline, TextualInversionLoaderMixin):
108
+ r"""
109
+ Pipeline for text-to-audio generation using AudioLDM2.
110
+
111
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
112
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
113
+
114
+ Args:
115
+ vae ([`AutoencoderKL`]):
116
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
117
+ text_encoder ([`~transformers.ClapModel`]):
118
+ First frozen text-encoder. AudioLDM2 uses the joint audio-text embedding model
119
+ [CLAP](https://huggingface.co/docs/transformers/model_doc/clap#transformers.CLAPTextModelWithProjection),
120
+ specifically the [laion/clap-htsat-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant. The
121
+ text branch is used to encode the text prompt to a prompt embedding. The full audio-text model is used to
122
+ rank generated waveforms against the text prompt by computing similarity scores.
123
+ text_encoder_2 ([`~transformers.T5EncoderModel`]):
124
+ Second frozen text-encoder. AudioLDM2 uses the encoder of
125
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
126
+ [google/flan-t5-large](https://huggingface.co/google/flan-t5-large) variant.
127
+ projection_model ([`AudioLDM2ProjectionModel`]):
128
+ A trained model used to linearly project the hidden-states from the first and second text encoder models
129
+ and insert learned SOS and EOS token embeddings. The projected hidden-states from the two text encoders are
130
+ concatenated to give the input to the language model.
131
+ language_model ([`~transformers.GPT2Model`]):
132
+ An auto-regressive language model used to generate a sequence of hidden-states conditioned on the projected
133
+ outputs from the two text encoders.
134
+ tokenizer ([`~transformers.RobertaTokenizer`]):
135
+ Tokenizer to tokenize text for the first frozen text-encoder.
136
+ tokenizer_2 ([`~transformers.T5Tokenizer`]):
137
+ Tokenizer to tokenize text for the second frozen text-encoder.
138
+ feature_extractor ([`~transformers.ClapFeatureExtractor`]):
139
+ Feature extractor to pre-process generated audio waveforms to log-mel spectrograms for automatic scoring.
140
+ unet ([`UNet2DConditionModel`]):
141
+ A `UNet2DConditionModel` to denoise the encoded audio latents.
142
+ scheduler ([`SchedulerMixin`]):
143
+ A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of
144
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
145
+ vocoder ([`~transformers.SpeechT5HifiGan`]):
146
+ Vocoder of class `SpeechT5HifiGan` to convert the mel-spectrogram latents to the final audio waveform.
147
+ """
148
+
149
+ def __init__(
150
+ self,
151
+ vae: AutoencoderKL,
152
+ text_encoder: ClapModel,
153
+ text_encoder_2: T5EncoderModel,
154
+ projection_model: AudioLDM2ProjectionModel,
155
+ language_model: GPT2Model,
156
+ tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
157
+ tokenizer_2: Union[T5Tokenizer, T5TokenizerFast],
158
+ feature_extractor: ClapFeatureExtractor,
159
+ unet: AudioLDM2UNet2DConditionModel,
160
+ scheduler: KarrasDiffusionSchedulers,
161
+ vocoder: SpeechT5HifiGan,
162
+ ):
163
+ super().__init__()
164
+
165
+ self.register_modules(
166
+ vae=vae,
167
+ text_encoder=text_encoder,
168
+ text_encoder_2=text_encoder_2,
169
+ projection_model=projection_model,
170
+ language_model=language_model,
171
+ tokenizer=tokenizer,
172
+ tokenizer_2=tokenizer_2,
173
+ feature_extractor=feature_extractor,
174
+ unet=unet,
175
+ scheduler=scheduler,
176
+ vocoder=vocoder,
177
+ )
178
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
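+ # each VAE down block halves the resolution, so this is the total downsampling factor between the mel spectrogram and the latent space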
179
+
180
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
181
+ def enable_vae_slicing(self):
182
+ r"""
183
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
184
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
185
+ """
186
+ self.vae.enable_slicing()
187
+
188
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
189
+ def disable_vae_slicing(self):
190
+ r"""
191
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
192
+ computing decoding in one step.
193
+ """
194
+ self.vae.disable_slicing()
195
+
196
+ def enable_model_cpu_offload(self, gpu_id=0):
197
+ r"""
198
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
199
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
200
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
201
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
202
+ """
203
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
204
+ from accelerate import cpu_offload_with_hook
205
+ else:
206
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
207
+
208
+ device = torch.device(f"cuda:{gpu_id}")
209
+
210
+ if self.device.type != "cpu":
211
+ self.to("cpu", silence_dtype_warnings=True)
212
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
213
+
214
+ model_sequence = [
215
+ self.text_encoder.text_model,
216
+ self.text_encoder.text_projection,
217
+ self.text_encoder_2,
218
+ self.projection_model,
219
+ self.language_model,
220
+ self.unet,
221
+ self.vae,
222
+ self.vocoder,
223
+ self.text_encoder,
224
+ ]
225
+
226
+ hook = None
227
+ for cpu_offloaded_model in model_sequence:
228
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
229
+
230
+ # We'll offload the last model manually.
231
+ self.final_offload_hook = hook
232
+
233
+ def generate_language_model(
234
+ self,
235
+ inputs_embeds: torch.Tensor = None,
236
+ max_new_tokens: int = 512,
237
+ **model_kwargs,
238
+ ):
239
+ """
240
+
241
+ Generates a sequence of hidden-states from the language model, conditioned on the embedding inputs.
242
+
243
+ Parameters:
244
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
245
+ The sequence used as a prompt for the generation.
246
+ max_new_tokens (`int`):
247
+ Number of new tokens to generate.
248
+ model_kwargs (`Dict[str, Any]`, *optional*):
249
+ Ad hoc parametrization of additional model-specific kwargs that will be forwarded to the `forward`
250
+ function of the model.
251
+
252
+ Return:
253
+ `inputs_embeds` (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
254
+ The sequence of generated hidden-states.
255
+ """
256
+ max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
257
+ model_kwargs = self.language_model._get_initial_cache_position(inputs_embeds, model_kwargs)
258
+ for _ in range(max_new_tokens):
259
+ # prepare model inputs
260
+ model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)
261
+
262
+ # forward pass to get next hidden states
263
+ output = self.language_model(**model_inputs, return_dict=True)
264
+
265
+ next_hidden_states = output.last_hidden_state
266
+
267
+ # Update the model input
268
+ inputs_embeds = torch.cat([inputs_embeds, next_hidden_states[:, -1:, :]], dim=1)
269
+
270
+ # Update generated hidden states, model inputs, and length for next step
271
+ model_kwargs = self.language_model._update_model_kwargs_for_generation(output, model_kwargs)
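+ # carry the cache (past_key_values) and attention mask forward so the next step only processes the newly appended embedding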
272
+
273
+ return inputs_embeds[:, -max_new_tokens:, :]
274
+
275
+ def encode_prompt(
276
+ self,
277
+ prompt,
278
+ device,
279
+ num_waveforms_per_prompt,
280
+ do_classifier_free_guidance,
281
+ negative_prompt=None,
282
+ prompt_embeds: Optional[torch.FloatTensor] = None,
283
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
284
+ generated_prompt_embeds: Optional[torch.FloatTensor] = None,
285
+ negative_generated_prompt_embeds: Optional[torch.FloatTensor] = None,
286
+ attention_mask: Optional[torch.LongTensor] = None,
287
+ negative_attention_mask: Optional[torch.LongTensor] = None,
288
+ max_new_tokens: Optional[int] = None,
289
+ ):
290
+ r"""
291
+ Encodes the prompt into text encoder hidden states.
292
+
293
+ Args:
294
+ prompt (`str` or `List[str]`, *optional*):
295
+ prompt to be encoded
296
+ device (`torch.device`):
297
+ torch device
298
+ num_waveforms_per_prompt (`int`):
299
+ number of waveforms that should be generated per prompt
300
+ do_classifier_free_guidance (`bool`):
301
+ whether to use classifier free guidance or not
302
+ negative_prompt (`str` or `List[str]`, *optional*):
303
+ The prompt or prompts not to guide the audio generation. If not defined, one has to pass
304
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
305
+ less than `1`).
306
+ prompt_embeds (`torch.FloatTensor`, *optional*):
307
+ Pre-computed text embeddings from the Flan T5 model. Can be used to easily tweak text inputs, *e.g.*
308
+ prompt weighting. If not provided, text embeddings will be computed from `prompt` input argument.
309
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
310
+ Pre-computed negative text embeddings from the Flan T5 model. Can be used to easily tweak text inputs,
311
+ *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
312
+ `negative_prompt` input argument.
313
+ generated_prompt_embeds (`torch.FloatTensor`, *optional*):
314
+ Pre-generated text embeddings from the GPT2 language model. Can be used to easily tweak text inputs,
315
+ *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input
316
+ argument.
317
+ negative_generated_prompt_embeds (`torch.FloatTensor`, *optional*):
318
+ Pre-generated negative text embeddings from the GPT2 language model. Can be used to easily tweak text
319
+ inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
320
+ `negative_prompt` input argument.
321
+ attention_mask (`torch.LongTensor`, *optional*):
322
+ Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will
323
+ be computed from `prompt` input argument.
324
+ negative_attention_mask (`torch.LongTensor`, *optional*):
325
+ Pre-computed attention mask to be applied to the `negative_prompt_embeds`. If not provided, attention
326
+ mask will be computed from `negative_prompt` input argument.
327
+ max_new_tokens (`int`, *optional*, defaults to None):
328
+ The number of new tokens to generate with the GPT2 language model.
329
+ Returns:
330
+ prompt_embeds (`torch.FloatTensor`):
331
+ Text embeddings from the Flan T5 model.
332
+ attention_mask (`torch.LongTensor`):
333
+ Attention mask to be applied to the `prompt_embeds`.
334
+ generated_prompt_embeds (`torch.FloatTensor`):
335
+ Text embeddings generated from the GPT2 language model.
336
+
337
+ Example:
338
+
339
+ ```python
340
+ >>> import scipy
341
+ >>> import torch
342
+ >>> from diffusers import AudioLDM2Pipeline
343
+
344
+ >>> repo_id = "cvssp/audioldm2"
345
+ >>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
346
+ >>> pipe = pipe.to("cuda")
347
+
348
+ >>> # Get text embedding vectors
349
+ >>> prompt_embeds, attention_mask, generated_prompt_embeds = pipe.encode_prompt(
350
+ ... prompt="Techno music with a strong, upbeat tempo and high melodic riffs",
351
+ ... device="cuda",
352
+ ... do_classifier_free_guidance=True,
353
+ ... )
354
+
355
+ >>> # Pass text embeddings to pipeline for text-conditional audio generation
356
+ >>> audio = pipe(
357
+ ... prompt_embeds=prompt_embeds,
358
+ ... attention_mask=attention_mask,
359
+ ... generated_prompt_embeds=generated_prompt_embeds,
360
+ ... num_inference_steps=200,
361
+ ... audio_length_in_s=10.0,
362
+ ... ).audios[0]
363
+
364
+ >>> # save generated audio sample
365
+ >>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio)
366
+ ```"""
367
+ # print("prompt",prompt)
368
+ if prompt is not None and isinstance(prompt, str):
369
+ batch_size = 1
370
+ elif prompt is not None and isinstance(prompt, list):
371
+ batch_size = len(prompt)
372
+ else:
373
+ batch_size = prompt_embeds.shape[0]
374
+
375
+ # Define tokenizers and text encoders
376
+ tokenizers = [self.tokenizer, self.tokenizer_2]
377
+ text_encoders = [self.text_encoder, self.text_encoder_2]
378
+
379
+ if prompt_embeds is None:
380
+ prompt_embeds_list = []
381
+ attention_mask_list = []
382
+
383
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
384
+ text_inputs = tokenizer(
385
+ prompt,
386
+ padding="max_length" if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast)) else True,
387
+ max_length=tokenizer.model_max_length,
388
+ truncation=True,
389
+ return_tensors="pt",
390
+ )
391
+ text_input_ids = text_inputs.input_ids
392
+ attention_mask = text_inputs.attention_mask
393
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
394
+
395
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
396
+ text_input_ids, untruncated_ids
397
+ ):
398
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
399
+ logger.warning(
400
+ f"The following part of your input was truncated because {text_encoder.config.model_type} can "
401
+ f"only handle sequences up to {tokenizer.model_max_length} tokens: {removed_text}"
402
+ )
403
+
404
+ text_input_ids = text_input_ids.to(device)
405
+ attention_mask = attention_mask.to(device)
406
+
407
+ if text_encoder.config.model_type == "clap":
408
+ prompt_embeds = text_encoder.get_text_features(
409
+ text_input_ids,
410
+ attention_mask=attention_mask,
411
+ )
412
+ # append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size)
413
+ prompt_embeds = prompt_embeds[:, None, :]
414
+ # make sure that we attend to this single hidden-state
415
+ attention_mask = attention_mask.new_ones((batch_size, 1))
416
+ else:
417
+ prompt_embeds = text_encoder(
418
+ text_input_ids,
419
+ attention_mask=attention_mask,
420
+ )
421
+ prompt_embeds = prompt_embeds[0]
422
+
423
+ prompt_embeds_list.append(prompt_embeds)
424
+ attention_mask_list.append(attention_mask)
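+ # the CLAP and T5 outputs are then projected into a shared space, concatenated, and fed to the GPT2 language model below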
425
+
426
+ projection_output = self.projection_model(
427
+ hidden_states=prompt_embeds_list[0],
428
+ hidden_states_1=prompt_embeds_list[1],
429
+ attention_mask=attention_mask_list[0],
430
+ attention_mask_1=attention_mask_list[1],
431
+ )
432
+ projected_prompt_embeds = projection_output.hidden_states
433
+ projected_attention_mask = projection_output.attention_mask
434
+
435
+ generated_prompt_embeds = self.generate_language_model(
436
+ projected_prompt_embeds,
437
+ attention_mask=projected_attention_mask,
438
+ max_new_tokens=max_new_tokens,
439
+ )
440
+
441
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
442
+ attention_mask = (
443
+ attention_mask.to(device=device)
444
+ if attention_mask is not None
445
+ else torch.ones(prompt_embeds.shape[:2], dtype=torch.long, device=device)
446
+ )
447
+ generated_prompt_embeds = generated_prompt_embeds.to(dtype=self.language_model.dtype, device=device)
448
+
449
+ bs_embed, seq_len, hidden_size = prompt_embeds.shape
450
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
451
+ prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
452
+ prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len, hidden_size)
453
+
454
+ # duplicate attention mask for each generation per prompt
455
+ attention_mask = attention_mask.repeat(1, num_waveforms_per_prompt)
456
+ attention_mask = attention_mask.view(bs_embed * num_waveforms_per_prompt, seq_len)
457
+
458
+ bs_embed, seq_len, hidden_size = generated_prompt_embeds.shape
459
+ # duplicate generated embeddings for each generation per prompt, using mps friendly method
460
+ generated_prompt_embeds = generated_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
461
+ generated_prompt_embeds = generated_prompt_embeds.view(
462
+ bs_embed * num_waveforms_per_prompt, seq_len, hidden_size
463
+ )
464
+
465
+ # get unconditional embeddings for classifier free guidance
466
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
467
+ uncond_tokens: List[str]
468
+ if negative_prompt is None:
469
+ uncond_tokens = [""] * batch_size
470
+ elif type(prompt) is not type(negative_prompt):
471
+ raise TypeError(
472
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
473
+ f" {type(prompt)}."
474
+ )
475
+ elif isinstance(negative_prompt, str):
476
+ uncond_tokens = [negative_prompt]
477
+ elif batch_size != len(negative_prompt):
478
+ raise ValueError(
479
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
480
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
481
+ " the batch size of `prompt`."
482
+ )
483
+ else:
484
+ uncond_tokens = negative_prompt
485
+
486
+ negative_prompt_embeds_list = []
487
+ negative_attention_mask_list = []
488
+ max_length = prompt_embeds.shape[1]
489
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
490
+ uncond_input = tokenizer(
491
+ uncond_tokens,
492
+ padding="max_length",
493
+ max_length=tokenizer.model_max_length
494
+ if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast))
495
+ else max_length,
496
+ truncation=True,
497
+ return_tensors="pt",
498
+ )
499
+
500
+ uncond_input_ids = uncond_input.input_ids.to(device)
501
+ negative_attention_mask = uncond_input.attention_mask.to(device)
502
+
503
+ if text_encoder.config.model_type == "clap":
504
+ negative_prompt_embeds = text_encoder.get_text_features(
505
+ uncond_input_ids,
506
+ attention_mask=negative_attention_mask,
507
+ )
508
+ # append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size)
509
+ negative_prompt_embeds = negative_prompt_embeds[:, None, :]
510
+ # make sure that we attend to this single hidden-state
511
+ negative_attention_mask = negative_attention_mask.new_ones((batch_size, 1))
512
+ else:
513
+ negative_prompt_embeds = text_encoder(
514
+ uncond_input_ids,
515
+ attention_mask=negative_attention_mask,
516
+ )
517
+ negative_prompt_embeds = negative_prompt_embeds[0]
518
+
519
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
520
+ negative_attention_mask_list.append(negative_attention_mask)
521
+
522
+ projection_output = self.projection_model(
523
+ hidden_states=negative_prompt_embeds_list[0],
524
+ hidden_states_1=negative_prompt_embeds_list[1],
525
+ attention_mask=negative_attention_mask_list[0],
526
+ attention_mask_1=negative_attention_mask_list[1],
527
+ )
528
+ negative_projected_prompt_embeds = projection_output.hidden_states
529
+ negative_projected_attention_mask = projection_output.attention_mask
530
+
531
+ negative_generated_prompt_embeds = self.generate_language_model(
532
+ negative_projected_prompt_embeds,
533
+ attention_mask=negative_projected_attention_mask,
534
+ max_new_tokens=max_new_tokens,
535
+ )
536
+
537
+ if do_classifier_free_guidance:
538
+ seq_len = negative_prompt_embeds.shape[1]
539
+
540
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
541
+ negative_attention_mask = (
542
+ negative_attention_mask.to(device=device)
543
+ if negative_attention_mask is not None
544
+ else torch.ones(negative_prompt_embeds.shape[:2], dtype=torch.long, device=device)
545
+ )
546
+ negative_generated_prompt_embeds = negative_generated_prompt_embeds.to(
547
+ dtype=self.language_model.dtype, device=device
548
+ )
549
+
550
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
551
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
552
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len, -1)
553
+
554
+ # duplicate unconditional attention mask for each generation per prompt
555
+ negative_attention_mask = negative_attention_mask.repeat(1, num_waveforms_per_prompt)
556
+ negative_attention_mask = negative_attention_mask.view(batch_size * num_waveforms_per_prompt, seq_len)
557
+
558
+ # duplicate unconditional generated embeddings for each generation per prompt
559
+ seq_len = negative_generated_prompt_embeds.shape[1]
560
+ negative_generated_prompt_embeds = negative_generated_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
561
+ negative_generated_prompt_embeds = negative_generated_prompt_embeds.view(
562
+ batch_size * num_waveforms_per_prompt, seq_len, -1
563
+ )
564
+
565
+ # For classifier free guidance, we need to do two forward passes.
566
+ # Here we concatenate the unconditional and text embeddings into a single batch
567
+ # to avoid doing two forward passes
568
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
569
+ attention_mask = torch.cat([negative_attention_mask, attention_mask])
570
+ # print("negative_generated_prompt_embeds",negative_generated_prompt_embeds.shape)
571
+ # print("generated_prompt_embeds",generated_prompt_embeds.shape)
572
+ generated_prompt_embeds = torch.cat([negative_generated_prompt_embeds, generated_prompt_embeds])
573
+ # if random.random() < 0.25:
574
+ # num = random.randint(0, 2)
575
+ # if num == 0:
576
+ # audiomae = torch.load("/data/home/fundwotsai/DreamSound/MAE_feature0_stride16.pt")
577
+ # elif num == 1:
578
+ # audiomae = torch.load("/data/home/fundwotsai/DreamSound/MAE_feature1_stride16.pt")
579
+ # else:
580
+ # audiomae = torch.load("/data/home/fundwotsai/DreamSound/MAE_feature2_stride16.pt")
581
+ # audiomae = audiomae.to(torch.float32)
582
+ # audiomae = audiomae.to("cuda")
583
+ # generated_prompt_embeds[1:2] = audiomae
584
+
585
+ return prompt_embeds, attention_mask, generated_prompt_embeds
586
+
587
+ # Copied from diffusers.pipelines.audioldm.pipeline_audioldm.AudioLDMPipeline.mel_spectrogram_to_waveform
588
+ def mel_spectrogram_to_waveform(self, mel_spectrogram):
589
+ if mel_spectrogram.dim() == 4:
590
+ mel_spectrogram = mel_spectrogram.squeeze(1)
591
+
592
+ waveform = self.vocoder(mel_spectrogram)
593
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
594
+ waveform = waveform.cpu().float()
595
+ return waveform
596
+
597
+ def score_waveforms(self, text, audio, num_waveforms_per_prompt, device, dtype):
598
+ if not is_librosa_available():
599
+ logger.info(
600
+ "Automatic scoring of the generated audio waveforms against the input prompt text requires the "
601
+ "`librosa` package to resample the generated waveforms. Returning the audios in the order they were "
602
+ "generated. To enable automatic scoring, install `librosa` with: `pip install librosa`."
603
+ )
604
+ return audio
605
+ inputs = self.tokenizer(text, return_tensors="pt", padding=True)
606
+ resampled_audio = librosa.resample(
607
+ audio.numpy(), orig_sr=self.vocoder.config.sampling_rate, target_sr=self.feature_extractor.sampling_rate
608
+ )
609
+ inputs["input_features"] = self.feature_extractor(
610
+ list(resampled_audio), return_tensors="pt", sampling_rate=self.feature_extractor.sampling_rate
611
+ ).input_features.type(dtype)
612
+ inputs = inputs.to(device)
613
+
614
+ # compute the audio-text similarity score using the CLAP model
615
+ logits_per_text = self.text_encoder(**inputs).logits_per_text
616
+ # sort by the highest matching generations per prompt
617
+ indices = torch.argsort(logits_per_text, dim=1, descending=True)[:, :num_waveforms_per_prompt]
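+ # keep the waveforms with the highest text-audio similarity, best first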
618
+ audio = torch.index_select(audio, 0, indices.reshape(-1).cpu())
619
+ return audio
620
+
621
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
622
+ def prepare_extra_step_kwargs(self, generator, eta):
623
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
624
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
625
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
626
+ # and should be between [0, 1]
627
+
628
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
629
+ extra_step_kwargs = {}
630
+ if accepts_eta:
631
+ extra_step_kwargs["eta"] = eta
632
+
633
+ # check if the scheduler accepts generator
634
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
635
+ if accepts_generator:
636
+ extra_step_kwargs["generator"] = generator
637
+ return extra_step_kwargs
638
+
639
+ def check_inputs(
640
+ self,
641
+ prompt,
642
+ audio_length_in_s,
643
+ vocoder_upsample_factor,
644
+ callback_steps,
645
+ negative_prompt=None,
646
+ prompt_embeds=None,
647
+ negative_prompt_embeds=None,
648
+ generated_prompt_embeds=None,
649
+ negative_generated_prompt_embeds=None,
650
+ attention_mask=None,
651
+ negative_attention_mask=None,
652
+ ):
653
+ min_audio_length_in_s = vocoder_upsample_factor * self.vae_scale_factor
654
+ if audio_length_in_s < min_audio_length_in_s:
655
+ raise ValueError(
656
+ f"`audio_length_in_s` has to be a positive value greater than or equal to {min_audio_length_in_s}, but "
657
+ f"is {audio_length_in_s}."
658
+ )
659
+
660
+ if self.vocoder.config.model_in_dim % self.vae_scale_factor != 0:
661
+ raise ValueError(
662
+ f"The number of frequency bins in the vocoder's log-mel spectrogram has to be divisible by the "
663
+ f"VAE scale factor, but got {self.vocoder.config.model_in_dim} bins and a scale factor of "
664
+ f"{self.vae_scale_factor}."
665
+ )
666
+
667
+ if (callback_steps is None) or (
668
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)):
669
+ raise ValueError(
670
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
671
+ f" {type(callback_steps)}."
672
+ )
673
+
674
+ if prompt is not None and prompt_embeds is not None:
675
+ raise ValueError(
676
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
677
+ " only forward one of the two."
678
+ )
679
+ elif prompt is None and (prompt_embeds is None or generated_prompt_embeds is None):
680
+ raise ValueError(
681
+ "Provide either `prompt`, or `prompt_embeds` and `generated_prompt_embeds`. Cannot leave "
682
+ "`prompt` undefined without specifying both `prompt_embeds` and `generated_prompt_embeds`."
683
+ )
684
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
685
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
686
+
687
+ if negative_prompt is not None and negative_prompt_embeds is not None:
688
+ raise ValueError(
689
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
690
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
691
+ )
692
+ elif negative_prompt_embeds is not None and negative_generated_prompt_embeds is None:
693
+ raise ValueError(
694
+ "Cannot forward `negative_prompt_embeds` without `negative_generated_prompt_embeds`. Ensure that "
695
+ "both arguments are specified."
696
+ )
697
+
698
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
699
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
700
+ raise ValueError(
701
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
702
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
703
+ f" {negative_prompt_embeds.shape}."
704
+ )
705
+ if attention_mask is not None and attention_mask.shape != prompt_embeds.shape[:2]:
706
+ raise ValueError(
707
+ "`attention_mask` should have the same batch size and sequence length as `prompt_embeds`, but got: "
708
+ f"`attention_mask`: {attention_mask.shape} != `prompt_embeds` {prompt_embeds.shape}"
709
+ )
710
+
711
+ if generated_prompt_embeds is not None and negative_generated_prompt_embeds is not None:
712
+ if generated_prompt_embeds.shape != negative_generated_prompt_embeds.shape:
713
+ raise ValueError(
714
+ "`generated_prompt_embeds` and `negative_generated_prompt_embeds` must have the same shape when "
715
+ f"passed directly, but got: `generated_prompt_embeds` {generated_prompt_embeds.shape} != "
716
+ f"`negative_generated_prompt_embeds` {negative_generated_prompt_embeds.shape}."
717
+ )
718
+ if (
719
+ negative_attention_mask is not None
720
+ and negative_attention_mask.shape != negative_prompt_embeds.shape[:2]
721
+ ):
722
+ raise ValueError(
723
+ "`negative_attention_mask` should have the same batch size and sequence length as `negative_prompt_embeds`, but got: "
724
+ f"`negative_attention_mask`: {negative_attention_mask.shape} != `negative_prompt_embeds` {negative_prompt_embeds.shape}"
725
+ )
726
+
727
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with width->self.vocoder.config.model_in_dim
728
+ def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None):
729
+ shape = (
730
+ batch_size,
731
+ num_channels_latents,
732
+ height // self.vae_scale_factor,
733
+ self.vocoder.config.model_in_dim // self.vae_scale_factor,
734
+ )
735
+ if isinstance(generator, list) and len(generator) != batch_size:
736
+ raise ValueError(
737
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
738
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
739
+ )
740
+
741
+ if latents is None:
742
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
743
+ else:
744
+ latents = latents.to(device)
745
+
746
+ # scale the initial noise by the standard deviation required by the scheduler
747
+ latents = latents * self.scheduler.init_noise_sigma
748
+ return latents
749
+
750
+ @torch.no_grad()
751
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
752
+ def __call__(
753
+ self,
754
+ audio_file = None,
755
+ audio_file2 = None,
756
+ time_pooling = 8,
757
+ freq_pooling = 8,
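+ # audio_file / audio_file2 and the pooling factors are additions over the upstream AudioLDM2 signature, presumably consumed by the AudioMAE conditioning path imported above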
758
+ prompt: Union[str, List[str]] = None,
759
+ audio_length_in_s: Optional[float] = None,
760
+ num_inference_steps: int = 200,
761
+ guidance_scale: float = 7.5,
762
+ negative_prompt: Optional[Union[str, List[str]]] = None,
763
+ num_waveforms_per_prompt: Optional[int] = 1,
764
+ eta: float = 0.0,
765
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
766
+ latents: Optional[torch.FloatTensor] = None,
767
+ prompt_embeds: Optional[torch.FloatTensor] = None,
768
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
769
+ generated_prompt_embeds: Optional[torch.FloatTensor] = None,
770
+ negative_generated_prompt_embeds: Optional[torch.FloatTensor] = None,
771
+ attention_mask: Optional[torch.LongTensor] = None,
772
+ negative_attention_mask: Optional[torch.LongTensor] = None,
773
+ max_new_tokens: Optional[int] = None,
774
+ return_dict: bool = True,
775
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
776
+ callback_steps: Optional[int] = 1,
777
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
778
+ output_type: Optional[str] = "np",
779
+ ):
780
+ r"""
781
+ The call function to the pipeline for generation.
782
+
783
+ Args:
784
+ prompt (`str` or `List[str]`, *optional*):
785
+ The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`.
786
+ audio_length_in_s (`int`, *optional*, defaults to 10.24):
787
+ The length of the generated audio sample in seconds.
788
+ num_inference_steps (`int`, *optional*, defaults to 200):
789
+ The number of denoising steps. More denoising steps usually lead to a higher quality audio at the
790
+ expense of slower inference.
791
+ guidance_scale (`float`, *optional*, defaults to 7.5):
792
+ A higher guidance scale value encourages the model to generate audio that is closely linked to the text
793
+ `prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`.
794
+ negative_prompt (`str` or `List[str]`, *optional*):
795
+ The prompt or prompts to guide what to not include in audio generation. If not defined, you need to
796
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
797
+ num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
798
+ The number of waveforms to generate per prompt. If `num_waveforms_per_prompt > 1`, then automatic
799
+ scoring is performed between the generated outputs and the text prompt. This scoring ranks the
800
+ generated waveforms based on their cosine similarity with the text input in the joint text-audio
801
+ embedding space.
802
+ eta (`float`, *optional*, defaults to 0.0):
803
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
804
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
805
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
806
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
807
+ generation deterministic.
808
+ latents (`torch.FloatTensor`, *optional*):
809
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for spectrogram
810
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
811
+ tensor is generated by sampling using the supplied random `generator`.
812
+ prompt_embeds (`torch.FloatTensor`, *optional*):
813
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
814
+ provided, text embeddings are generated from the `prompt` input argument.
815
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
816
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
817
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
818
+ generated_prompt_embeds (`torch.FloatTensor`, *optional*):
819
+ Pre-generated text embeddings from the GPT2 language model. Can be used to easily tweak text inputs,
820
+ *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input
821
+ argument.
822
+ negative_generated_prompt_embeds (`torch.FloatTensor`, *optional*):
823
+ Pre-generated negative text embeddings from the GPT2 language model. Can be used to easily tweak text
824
+ inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
825
+ `negative_prompt` input argument.
826
+ attention_mask (`torch.LongTensor`, *optional*):
827
+ Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will
828
+ be computed from `prompt` input argument.
829
+ negative_attention_mask (`torch.LongTensor`, *optional*):
830
+ Pre-computed attention mask to be applied to the `negative_prompt_embeds`. If not provided, attention
831
+ mask will be computed from `negative_prompt` input argument.
832
+ max_new_tokens (`int`, *optional*, defaults to None):
833
+ Number of new tokens to generate with the GPT2 language model. If not provided, number of tokens will
834
+ be taken from the config of the model.
835
+ return_dict (`bool`, *optional*, defaults to `True`):
836
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
837
+ plain tuple.
838
+ callback (`Callable`, *optional*):
839
+ A function that is called every `callback_steps` steps during inference. The function is called with the
840
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
841
+ callback_steps (`int`, *optional*, defaults to 1):
842
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
843
+ every step.
844
+ cross_attention_kwargs (`dict`, *optional*):
845
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
846
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
847
+ output_type (`str`, *optional*, defaults to `"np"`):
848
+ The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or
849
+ `"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion
850
+ model (LDM) output.
851
+
852
+ Examples:
853
+
854
+ Returns:
855
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
856
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
857
+ otherwise a `tuple` is returned where the first element is a list with the generated audio.
858
+ """
859
+ # 0. Convert audio input length from seconds to spectrogram height
860
+ vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate
861
+
862
+ if audio_length_in_s is None:
863
+ audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor
864
+
865
+ height = int(audio_length_in_s / vocoder_upsample_factor)
866
+
867
+ original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate)
868
+ if height % self.vae_scale_factor != 0:
869
+ height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor
870
+ logger.info(
871
+ f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} "
872
+ f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the "
873
+ f"denoising process."
874
+ )
875
+
876
+ # 1. Check inputs. Raise error if not correct
877
+ self.check_inputs(
878
+ prompt,
879
+ audio_length_in_s,
880
+ vocoder_upsample_factor,
881
+ callback_steps,
882
+ negative_prompt,
883
+ prompt_embeds,
884
+ negative_prompt_embeds,
885
+ generated_prompt_embeds,
886
+ negative_generated_prompt_embeds,
887
+ attention_mask,
888
+ negative_attention_mask,
889
+ )
890
+
891
+ # 2. Define call parameters
892
+ if prompt is not None and isinstance(prompt, str):
893
+ batch_size = 1
894
+ elif prompt is not None and isinstance(prompt, list):
895
+ batch_size = len(prompt)
896
+ else:
897
+ batch_size = prompt_embeds.shape[0]
898
+
899
+ device = self._execution_device
900
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
901
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
902
+ # corresponds to doing no classifier free guidance.
903
+ do_classifier_free_guidance = guidance_scale > 1.0
904
+
905
+ # 3. Encode input prompt
906
+ prompt_embeds, attention_mask, generated_prompt_embeds = self.encode_prompt(
907
+ prompt,
908
+ device,
909
+ num_waveforms_per_prompt,
910
+ do_classifier_free_guidance,
911
+ negative_prompt,
912
+ prompt_embeds=prompt_embeds,
913
+ negative_prompt_embeds=negative_prompt_embeds,
914
+ generated_prompt_embeds=generated_prompt_embeds,
915
+ negative_generated_prompt_embeds=negative_generated_prompt_embeds,
916
+ attention_mask=attention_mask,
917
+ negative_attention_mask=negative_attention_mask,
918
+ max_new_tokens=max_new_tokens,
919
+ )
920
+ # print("prompt_embeds",prompt_embeds.shape)
921
+ # print("attention_mask",attention_mask.shape)
922
+ # print("generated_prompt_embeds",generated_prompt_embeds.shape)
923
+ if audio_file is not None:
924
+ waveform, sr = torchaudio.load(audio_file)
925
+ fbank = torch.zeros((1024, 128))
926
+ # print(sr)
927
+ ta_kaldi_fbank = extract_kaldi_fbank_feature(waveform, sr, fbank)
928
+ # print("ta_kaldi_fbank.shape",ta_kaldi_fbank.shape)
929
+ mel_spect_tensor = ta_kaldi_fbank.unsqueeze(0)
930
+ model = AudioMAEConditionCTPoolRand().cuda()
931
+ model.eval()
932
+ LOA_embed = model(mel_spect_tensor, time_pool=time_pooling, freq_pool=freq_pooling)
933
+ uncond_LOA_embed = model(torch.zeros_like(mel_spect_tensor), time_pool=time_pooling, freq_pool=freq_pooling)
934
+ # print(LOA_embed[0].size(),uncond_LOA_embed[0].size())
935
+ # return LOA_embed[0], uncond_LOA_embed[0]
936
+ LOA_embeds = LOA_embed[0]
937
+ uncond_LOA_embeds = uncond_LOA_embed[0]
938
+ bs_embed, seq_len, _ = LOA_embeds.shape
939
+ num = prompt_embeds.shape[0] // 2
940
+ # print("num",num)
941
+ LOA_embeds = LOA_embeds.view(bs_embed , seq_len, -1)
942
+ LOA_embeds = LOA_embeds.repeat(num, 1, 1)
943
+ uncond_LOA_embeds = uncond_LOA_embeds.view(bs_embed , seq_len, -1)
944
+ uncond_LOA_embeds = uncond_LOA_embeds.repeat(num, 1, 1)
945
+ negative_g, g = generated_prompt_embeds.chunk(2)
946
+ # print("negative_g",negative_g.shape)
947
+ # print("uncond_LOA_embeds",uncond_LOA_embeds.shape)
948
+ # print("LOA_embeds",LOA_embeds.shape)
949
+ uncond = torch.cat([negative_g, uncond_LOA_embeds], dim=1)
950
+ cond = torch.cat([g, LOA_embeds], dim=1)
951
+ # print("uncond",uncond.shape)
952
+ # print("cond",cond.shape)
953
+ generated_prompt_embeds = torch.cat([uncond, cond], dim=0)
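+ # The concatenation order [unconditional, conditional] must match the `noise_pred.chunk(2)` split
+ # performed during classifier-free guidance in the denoising loop below.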
954
+ # generated_prompt_embeds[1:2] = LOA_embeds
955
+ # print("generated_prompt_embeds.shape", generated_prompt_embeds.shape)
956
+ # Assuming 'model' is your pre-defined model
957
+ model_dtype = next(self.unet.parameters()).dtype
958
+ # print(model_dtype)
959
+ # Convert your tensor to the same dtype as the model
960
+ generated_prompt_embeds = generated_prompt_embeds.to(model_dtype)
961
+
962
+ # generated_prompt_embeds = generated_prompt_embeds.to(torch.float32)
963
+ # print("generated_prompt_embeds.shape", generated_prompt_embeds.shape)
964
+ # print("LOA_embeds.shape", LOA_embeds.shape)
965
+ # generated_prompt_embeds[1:2] = LOA_embeds
966
+
967
+ # if random.random() < 0.25:
968
+ # num = random.randint(0, 2)
969
+ # if num == 0:
970
+ # audiomae = torch.load("/home/fundwotsai/DreamSound/MAE_feature1_stride-no-pool.pt")
971
+ # # elif num == 1:
972
+ # # audiomae = torch.load("/data/home/fundwotsai/DreamSound/MAE_feature1_stride16.pt")
973
+ # # else:
974
+ # # audiomae = torch.load("/data/home/fundwotsai/DreamSound/MAE_feature2_stride16.pt")
975
+ # audiomae = audiomae.to(torch.float32)
976
+ # audiomae = audiomae.to("cuda")
977
+ # print("generated_prompt_embeds",generated_prompt_embeds.shape)
978
+ # print("audiomae",audiomae.shape)
979
+ # generated_prompt_embeds[1:2] = audiomae
980
+
981
+ # print("generated_prompt_embeds",generated_prompt_embeds.shape)
982
+ # audiomae = torch.load("/home/fundwotsai/DreamSound/MAE_feature1_stride-no-pool.pt")
983
+ # audiomae = audiomae.to(torch.float16)
984
+ # audiomae = audiomae.to("cuda")
985
+ # generated_prompt_embeds[1:2] = audiomae
986
+ # 4. Prepare timesteps
987
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
988
+ timesteps = self.scheduler.timesteps
989
+
990
+ # 5. Prepare latent variables
991
+ num_channels_latents = self.unet.config.in_channels
992
+ latents = self.prepare_latents(
993
+ batch_size * num_waveforms_per_prompt,
994
+ num_channels_latents,
995
+ height,
996
+ prompt_embeds.dtype,
997
+ device,
998
+ generator,
999
+ latents,
1000
+ )
1001
+
1002
+ # 6. Prepare extra step kwargs
1003
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1004
+
1005
+ # 7. Denoising loop
1006
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1007
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1008
+ for i, t in enumerate(timesteps):
1009
+ # print(f"t: {t}")
1010
+ # expand the latents if we are doing classifier free guidance
1011
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1012
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1013
+
1014
+ # print(f"latent_model_input dtype: {latent_model_input.dtype}")
1015
+ # print(f"generated_prompt_embeds dtype: {generated_prompt_embeds.dtype}")
1016
+ # print(f"prompt_embeds dtype: {prompt_embeds.dtype}")
1017
+ # print(f"attention_mask dtype: {attention_mask.dtype}")
1018
+
1019
+ # print(f"latent_model_input shape: {latent_model_input.shape}")
1020
+ # print(f"generated_prompt_embeds shape: {generated_prompt_embeds.shape}")
1021
+ # print(f"prompt_embeds shape: {prompt_embeds.shape}")
1022
+ # print(f"attention_mask shape: {attention_mask.shape}")
1023
+
1024
+ latent_model_input = latent_model_input.to(generated_prompt_embeds.dtype)
1025
+
1026
+ # predict the noise residual
1027
+ noise_pred = self.unet(
1028
+ latent_model_input,
1029
+ t,
1030
+ encoder_hidden_states=generated_prompt_embeds,
1031
+ encoder_hidden_states_1=prompt_embeds,
1032
+ encoder_attention_mask_1=attention_mask,
1033
+ return_dict=False,
1034
+ )[0]
1035
+
1036
+ # perform guidance
1037
+ if do_classifier_free_guidance:
1038
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1039
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1040
+
1041
+ # compute the previous noisy sample x_t -> x_t-1
1042
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
1043
+ # print(f"latents shape: {latents.shape}")
1044
+ # call the callback, if provided
1045
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1046
+ progress_bar.update()
1047
+ if callback is not None and i % callback_steps == 0:
1048
+ callback(i, t, latents)
1049
+
1050
+ self.maybe_free_model_hooks()
1051
+
1052
+ # 8. Post-processing
1053
+ if not output_type == "latent":
1054
+ latents = 1 / self.vae.config.scaling_factor * latents
1055
+ latents = latents.to(next(self.vae.parameters()).dtype)
1056
+ mel_spectrogram = self.vae.decode(latents).sample
1057
+ else:
1058
+ return AudioPipelineOutput(audios=latents)
1059
+
1060
+ audio = self.mel_spectrogram_to_waveform(mel_spectrogram)
1061
+
1062
+ audio = audio[:, :original_waveform_length]
1063
+
1064
+ # 9. Automatic scoring
1065
+ if num_waveforms_per_prompt > 1 and prompt is not None:
1066
+ audio = self.score_waveforms(
1067
+ text=prompt,
1068
+ audio=audio,
1069
+ num_waveforms_per_prompt=num_waveforms_per_prompt,
1070
+ device=device,
1071
+ dtype=prompt_embeds.dtype,
1072
+ )
1073
+
1074
+ if output_type == "np":
1075
+ audio = audio.numpy()
1076
+
1077
+ if not return_dict:
1078
+ return (audio,)
1079
+
1080
+ return AudioPipelineOutput(audios=audio)
pipeline/style_transfer_pipeline.py ADDED
@@ -0,0 +1,1012 @@
1
+ # Copyright 2023 CVSSP, ByteDance and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from train_ipadapter_v2 import wav_to_mel
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import (
21
+ ClapFeatureExtractor,
22
+ ClapModel,
23
+ GPT2Model,
24
+ RobertaTokenizer,
25
+ RobertaTokenizerFast,
26
+ SpeechT5HifiGan,
27
+ T5EncoderModel,
28
+ T5Tokenizer,
29
+ T5TokenizerFast,
30
+ )
31
+
32
+ from diffusers import AutoencoderKL
33
+ from diffusers.schedulers import KarrasDiffusionSchedulers
34
+ from diffusers.utils import (
35
+ is_accelerate_available,
36
+ is_accelerate_version,
37
+ is_librosa_available,
38
+ logging,
39
+ replace_example_docstring,
40
+ )
41
+ from diffusers.utils.torch_utils import randn_tensor
42
+ from diffusers.pipelines.pipeline_utils import AudioPipelineOutput, DiffusionPipeline
43
+ from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
44
+ from diffusers.loaders import TextualInversionLoaderMixin
45
+
46
+
47
+ if is_librosa_available():
48
+ import librosa
49
+
50
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
51
+
52
+ EXAMPLE_DOC_STRING = """
53
+ Examples:
54
+ ```py
55
+ >>> import scipy
56
+ >>> import torch
57
+ >>> from diffusers import AudioLDM2Pipeline
58
+
59
+ >>> repo_id = "cvssp/audioldm2"
60
+ >>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
61
+ >>> pipe = pipe.to("cuda")
62
+
63
+ >>> # define the prompts
64
+ >>> prompt = "The sound of a hammer hitting a wooden surface."
65
+ >>> negative_prompt = "Low quality."
66
+
67
+ >>> # set the seed for generator
68
+ >>> generator = torch.Generator("cuda").manual_seed(0)
69
+
70
+ >>> # run the generation
71
+ >>> audio = pipe(
72
+ ... prompt,
73
+ ... negative_prompt=negative_prompt,
74
+ ... num_inference_steps=200,
75
+ ... audio_length_in_s=10.0,
76
+ ... num_waveforms_per_prompt=3,
77
+ ... generator=generator,
78
+ ... ).audios
79
+
80
+ >>> # save the best audio sample (index 0) as a .wav file
81
+ >>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio[0])
82
+ ```
83
+ """
84
+
85
+
86
+ def prepare_inputs_for_generation(
87
+ inputs_embeds,
88
+ attention_mask=None,
89
+ past_key_values=None,
90
+ **kwargs,
91
+ ):
92
+ if past_key_values is not None:
93
+ # only last token for inputs_embeds if past is defined in kwargs
94
+ inputs_embeds = inputs_embeds[:, -1:]
95
+
96
+ return {
97
+ "inputs_embeds": inputs_embeds,
98
+ "attention_mask": attention_mask,
99
+ "past_key_values": past_key_values,
100
+ "use_cache": kwargs.get("use_cache"),
101
+ }
102
+
103
+
104
+ class AudioLDM2Pipeline(DiffusionPipeline, TextualInversionLoaderMixin):
105
+ r"""
106
+ Pipeline for text-to-audio generation using AudioLDM2.
107
+
108
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
109
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
110
+
111
+ Args:
112
+ vae ([`AutoencoderKL`]):
113
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
114
+ text_encoder ([`~transformers.ClapModel`]):
115
+ First frozen text-encoder. AudioLDM2 uses the joint audio-text embedding model
116
+ [CLAP](https://huggingface.co/docs/transformers/model_doc/clap#transformers.CLAPTextModelWithProjection),
117
+ specifically the [laion/clap-htsat-unfused](https://huggingface.co/laion/clap-htsat-unfused) variant. The
118
+ text branch is used to encode the text prompt to a prompt embedding. The full audio-text model is used to
119
+ rank generated waveforms against the text prompt by computing similarity scores.
120
+ text_encoder_2 ([`~transformers.T5EncoderModel`]):
121
+ Second frozen text-encoder. AudioLDM2 uses the encoder of
122
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
123
+ [google/flan-t5-large](https://huggingface.co/google/flan-t5-large) variant.
124
+ projection_model ([`AudioLDM2ProjectionModel`]):
125
+ A trained model used to linearly project the hidden-states from the first and second text encoder models
126
+ and insert learned SOS and EOS token embeddings. The projected hidden-states from the two text encoders are
127
+ concatenated to give the input to the language model.
128
+ language_model ([`~transformers.GPT2Model`]):
129
+ An auto-regressive language model used to generate a sequence of hidden-states conditioned on the projected
130
+ outputs from the two text encoders.
131
+ tokenizer ([`~transformers.RobertaTokenizer`]):
132
+ Tokenizer to tokenize text for the first frozen text-encoder.
133
+ tokenizer_2 ([`~transformers.T5Tokenizer`]):
134
+ Tokenizer to tokenize text for the second frozen text-encoder.
135
+ feature_extractor ([`~transformers.ClapFeatureExtractor`]):
136
+ Feature extractor to pre-process generated audio waveforms to log-mel spectrograms for automatic scoring.
137
+ unet ([`UNet2DConditionModel`]):
138
+ A `UNet2DConditionModel` to denoise the encoded audio latents.
139
+ scheduler ([`SchedulerMixin`]):
140
+ A scheduler to be used in combination with `unet` to denoise the encoded audio latents. Can be one of
141
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
142
+ vocoder ([`~transformers.SpeechT5HifiGan`]):
143
+ Vocoder of class `SpeechT5HifiGan` to convert the mel-spectrogram latents to the final audio waveform.
144
+ """
145
+
146
+ def __init__(
147
+ self,
148
+ vae: AutoencoderKL,
149
+ text_encoder: ClapModel,
150
+ text_encoder_2: T5EncoderModel,
151
+ projection_model: AudioLDM2ProjectionModel,
152
+ language_model: GPT2Model,
153
+ tokenizer: Union[RobertaTokenizer, RobertaTokenizerFast],
154
+ tokenizer_2: Union[T5Tokenizer, T5TokenizerFast],
155
+ feature_extractor: ClapFeatureExtractor,
156
+ unet: AudioLDM2UNet2DConditionModel,
157
+ scheduler: KarrasDiffusionSchedulers,
158
+ vocoder: SpeechT5HifiGan,
159
+ ):
160
+ super().__init__()
161
+
162
+ self.register_modules(
163
+ vae=vae,
164
+ text_encoder=text_encoder,
165
+ text_encoder_2=text_encoder_2,
166
+ projection_model=projection_model,
167
+ language_model=language_model,
168
+ tokenizer=tokenizer,
169
+ tokenizer_2=tokenizer_2,
170
+ feature_extractor=feature_extractor,
171
+ unet=unet,
172
+ scheduler=scheduler,
173
+ vocoder=vocoder,
174
+ )
175
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
176
+
177
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
178
+ def enable_vae_slicing(self):
179
+ r"""
180
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
181
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
182
+ """
183
+ self.vae.enable_slicing()
184
+
185
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
186
+ def disable_vae_slicing(self):
187
+ r"""
188
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
189
+ computing decoding in one step.
190
+ """
191
+ self.vae.disable_slicing()
192
+
193
+ def enable_model_cpu_offload(self, gpu_id=0):
194
+ r"""
195
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
196
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
197
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
198
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
199
+ """
200
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
201
+ from accelerate import cpu_offload_with_hook
202
+ else:
203
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
204
+
205
+ device = torch.device(f"cuda:{gpu_id}")
206
+
207
+ if self.device.type != "cpu":
208
+ self.to("cpu", silence_dtype_warnings=True)
209
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
210
+
211
+ model_sequence = [
212
+ self.text_encoder.text_model,
213
+ self.text_encoder.text_projection,
214
+ self.text_encoder_2,
215
+ self.projection_model,
216
+ self.language_model,
217
+ self.unet,
218
+ self.vae,
219
+ self.vocoder,
220
+ self.text_encoder,
221
+ ]
222
+
223
+ hook = None
224
+ for cpu_offloaded_model in model_sequence:
225
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
226
+
227
+ # We'll offload the last model manually.
228
+ self.final_offload_hook = hook
229
+
230
+ def generate_language_model(
231
+ self,
232
+ inputs_embeds: torch.Tensor = None,
233
+ max_new_tokens: int = 8,
234
+ **model_kwargs,
235
+ ):
236
+ """
237
+
238
+ Generates a sequence of hidden-states from the language model, conditioned on the embedding inputs.
239
+
240
+ Parameters:
241
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
242
+ The sequence used as a prompt for the generation.
243
+ max_new_tokens (`int`):
244
+ Number of new tokens to generate.
245
+ model_kwargs (`Dict[str, Any]`, *optional*):
246
+ Ad hoc parametrization of additional model-specific kwargs that will be forwarded to the `forward`
247
+ function of the model.
248
+
249
+ Return:
250
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
251
+ The sequence of generated hidden-states.
252
+ """
253
+ max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
254
+ for _ in range(max_new_tokens):
255
+ # prepare model inputs
256
+ model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)
257
+
258
+ # forward pass to get next hidden states
259
+ output = self.language_model(**model_inputs, return_dict=True)
260
+
261
+ next_hidden_states = output.last_hidden_state
262
+
263
+ # Update the model input
264
+ inputs_embeds = torch.cat([inputs_embeds, next_hidden_states[:, -1:, :]], dim=1)
265
+
266
+ # Update generated hidden states, model inputs, and length for next step
267
+ model_kwargs = self.language_model._update_model_kwargs_for_generation(output, model_kwargs)
268
+
269
+ return inputs_embeds[:, -max_new_tokens:, :]
270
+
271
+ def encode_prompt(
272
+ self,
273
+ prompt,
274
+ device,
275
+ num_waveforms_per_prompt,
276
+ do_classifier_free_guidance,
277
+ negative_prompt=None,
278
+ prompt_embeds: Optional[torch.FloatTensor] = None,
279
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
280
+ generated_prompt_embeds: Optional[torch.FloatTensor] = None,
281
+ negative_generated_prompt_embeds: Optional[torch.FloatTensor] = None,
282
+ attention_mask: Optional[torch.LongTensor] = None,
283
+ negative_attention_mask: Optional[torch.LongTensor] = None,
284
+ max_new_tokens: Optional[int] = None,
285
+ ):
286
+ r"""
287
+ Encodes the prompt into text encoder hidden states.
288
+
289
+ Args:
290
+ prompt (`str` or `List[str]`, *optional*):
291
+ prompt to be encoded
292
+ device (`torch.device`):
293
+ torch device
294
+ num_waveforms_per_prompt (`int`):
295
+ number of waveforms that should be generated per prompt
296
+ do_classifier_free_guidance (`bool`):
297
+ whether to use classifier free guidance or not
298
+ negative_prompt (`str` or `List[str]`, *optional*):
299
+ The prompt or prompts not to guide the audio generation. If not defined, one has to pass
300
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
301
+ less than `1`).
302
+ prompt_embeds (`torch.FloatTensor`, *optional*):
303
+ Pre-computed text embeddings from the Flan T5 model. Can be used to easily tweak text inputs, *e.g.*
304
+ prompt weighting. If not provided, text embeddings will be computed from `prompt` input argument.
305
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
306
+ Pre-computed negative text embeddings from the Flan T5 model. Can be used to easily tweak text inputs,
307
+ *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
308
+ `negative_prompt` input argument.
309
+ generated_prompt_embeds (`torch.FloatTensor`, *optional*):
310
+ Pre-generated text embeddings from the GPT2 language model. Can be used to easily tweak text inputs,
311
+ *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input
312
+ argument.
313
+ negative_generated_prompt_embeds (`torch.FloatTensor`, *optional*):
314
+ Pre-generated negative text embeddings from the GPT2 language model. Can be used to easily tweak text
315
+ inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
316
+ `negative_prompt` input argument.
317
+ attention_mask (`torch.LongTensor`, *optional*):
318
+ Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will
319
+ be computed from `prompt` input argument.
320
+ negative_attention_mask (`torch.LongTensor`, *optional*):
321
+ Pre-computed attention mask to be applied to the `negative_prompt_embeds`. If not provided, attention
322
+ mask will be computed from `negative_prompt` input argument.
323
+ max_new_tokens (`int`, *optional*, defaults to None):
324
+ The number of new tokens to generate with the GPT2 language model.
325
+ Returns:
326
+ prompt_embeds (`torch.FloatTensor`):
327
+ Text embeddings from the Flan T5 model.
328
+ attention_mask (`torch.LongTensor`):
329
+ Attention mask to be applied to the `prompt_embeds`.
330
+ generated_prompt_embeds (`torch.FloatTensor`):
331
+ Text embeddings generated from the GPT2 language model.
332
+
333
+ Example:
334
+
335
+ ```python
336
+ >>> import scipy
337
+ >>> import torch
338
+ >>> from diffusers import AudioLDM2Pipeline
339
+
340
+ >>> repo_id = "cvssp/audioldm2"
341
+ >>> pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
342
+ >>> pipe = pipe.to("cuda")
343
+
344
+ >>> # Get text embedding vectors
345
+ >>> prompt_embeds, attention_mask, generated_prompt_embeds = pipe.encode_prompt(
346
+ ... prompt="Techno music with a strong, upbeat tempo and high melodic riffs",
347
+ ... device="cuda",
348
+ ... do_classifier_free_guidance=True,
349
+ ... )
350
+
351
+ >>> # Pass text embeddings to pipeline for text-conditional audio generation
352
+ >>> audio = pipe(
353
+ ... prompt_embeds=prompt_embeds,
354
+ ... attention_mask=attention_mask,
355
+ ... generated_prompt_embeds=generated_prompt_embeds,
356
+ ... num_inference_steps=200,
357
+ ... audio_length_in_s=10.0,
358
+ ... ).audios[0]
359
+
360
+ >>> # save generated audio sample
361
+ >>> scipy.io.wavfile.write("techno.wav", rate=16000, data=audio)
362
+ ```"""
363
+ if prompt is not None and isinstance(prompt, str):
364
+ batch_size = 1
365
+ elif prompt is not None and isinstance(prompt, list):
366
+ batch_size = len(prompt)
367
+ else:
368
+ batch_size = prompt_embeds.shape[0]
369
+
370
+ # Define tokenizers and text encoders
371
+ tokenizers = [self.tokenizer, self.tokenizer_2]
372
+ text_encoders = [self.text_encoder, self.text_encoder_2]
373
+
374
+ if prompt_embeds is None:
375
+ prompt_embeds_list = []
376
+ attention_mask_list = []
377
+
378
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
379
+ text_inputs = tokenizer(
380
+ prompt,
381
+ padding="max_length" if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast)) else True,
382
+ max_length=tokenizer.model_max_length,
383
+ truncation=True,
384
+ return_tensors="pt",
385
+ )
386
+ text_input_ids = text_inputs.input_ids
387
+ attention_mask = text_inputs.attention_mask
388
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
389
+
390
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
391
+ text_input_ids, untruncated_ids
392
+ ):
393
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
394
+ logger.warning(
395
+ f"The following part of your input was truncated because {text_encoder.config.model_type} can "
396
+ f"only handle sequences up to {tokenizer.model_max_length} tokens: {removed_text}"
397
+ )
398
+
399
+ text_input_ids = text_input_ids.to(device)
400
+ attention_mask = attention_mask.to(device)
401
+
402
+ if text_encoder.config.model_type == "clap":
403
+ prompt_embeds = text_encoder.get_text_features(
404
+ text_input_ids,
405
+ attention_mask=attention_mask,
406
+ )
407
+ # append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size)
408
+ prompt_embeds = prompt_embeds[:, None, :]
409
+ # make sure that we attend to this single hidden-state
410
+ attention_mask = attention_mask.new_ones((batch_size, 1))
411
+ else:
412
+ prompt_embeds = text_encoder(
413
+ text_input_ids,
414
+ attention_mask=attention_mask,
415
+ )
416
+ prompt_embeds = prompt_embeds[0]
417
+
418
+ prompt_embeds_list.append(prompt_embeds)
419
+ attention_mask_list.append(attention_mask)
420
+
421
+ projection_output = self.projection_model(
422
+ hidden_states=prompt_embeds_list[0],
423
+ hidden_states_1=prompt_embeds_list[1],
424
+ attention_mask=attention_mask_list[0],
425
+ attention_mask_1=attention_mask_list[1],
426
+ )
427
+ projected_prompt_embeds = projection_output.hidden_states
428
+ projected_attention_mask = projection_output.attention_mask
429
+
430
+ generated_prompt_embeds = self.generate_language_model(
431
+ projected_prompt_embeds,
432
+ attention_mask=projected_attention_mask,
433
+ max_new_tokens=max_new_tokens,
434
+ )
435
+
436
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
437
+ attention_mask = (
438
+ attention_mask.to(device=device)
439
+ if attention_mask is not None
440
+ else torch.ones(prompt_embeds.shape[:2], dtype=torch.long, device=device)
441
+ )
442
+ generated_prompt_embeds = generated_prompt_embeds.to(dtype=self.language_model.dtype, device=device)
443
+
444
+ bs_embed, seq_len, hidden_size = prompt_embeds.shape
445
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
446
+ prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
447
+ prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len, hidden_size)
448
+
449
+ # duplicate attention mask for each generation per prompt
450
+ attention_mask = attention_mask.repeat(1, num_waveforms_per_prompt)
451
+ attention_mask = attention_mask.view(bs_embed * num_waveforms_per_prompt, seq_len)
452
+
453
+ bs_embed, seq_len, hidden_size = generated_prompt_embeds.shape
454
+ # duplicate generated embeddings for each generation per prompt, using mps friendly method
455
+ generated_prompt_embeds = generated_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
456
+ generated_prompt_embeds = generated_prompt_embeds.view(
457
+ bs_embed * num_waveforms_per_prompt, seq_len, hidden_size
458
+ )
459
+
460
+ # get unconditional embeddings for classifier free guidance
461
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
462
+ uncond_tokens: List[str]
463
+ if negative_prompt is None:
464
+ uncond_tokens = [""] * batch_size
465
+ elif type(prompt) is not type(negative_prompt):
466
+ raise TypeError(
467
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
468
+ f" {type(prompt)}."
469
+ )
470
+ elif isinstance(negative_prompt, str):
471
+ uncond_tokens = [negative_prompt]
472
+ elif batch_size != len(negative_prompt):
473
+ raise ValueError(
474
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
475
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
476
+ " the batch size of `prompt`."
477
+ )
478
+ else:
479
+ uncond_tokens = negative_prompt
480
+
481
+ negative_prompt_embeds_list = []
482
+ negative_attention_mask_list = []
483
+ max_length = prompt_embeds.shape[1]
484
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
485
+ uncond_input = tokenizer(
486
+ uncond_tokens,
487
+ padding="max_length",
488
+ max_length=tokenizer.model_max_length
489
+ if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast))
490
+ else max_length,
491
+ truncation=True,
492
+ return_tensors="pt",
493
+ )
494
+
495
+ uncond_input_ids = uncond_input.input_ids.to(device)
496
+ negative_attention_mask = uncond_input.attention_mask.to(device)
497
+
498
+ if text_encoder.config.model_type == "clap":
499
+ negative_prompt_embeds = text_encoder.get_text_features(
500
+ uncond_input_ids,
501
+ attention_mask=negative_attention_mask,
502
+ )
503
+ # append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size)
504
+ negative_prompt_embeds = negative_prompt_embeds[:, None, :]
505
+ # make sure that we attend to this single hidden-state
506
+ negative_attention_mask = negative_attention_mask.new_ones((batch_size, 1))
507
+ else:
508
+ negative_prompt_embeds = text_encoder(
509
+ uncond_input_ids,
510
+ attention_mask=negative_attention_mask,
511
+ )
512
+ negative_prompt_embeds = negative_prompt_embeds[0]
513
+
514
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
515
+ negative_attention_mask_list.append(negative_attention_mask)
516
+
517
+ projection_output = self.projection_model(
518
+ hidden_states=negative_prompt_embeds_list[0],
519
+ hidden_states_1=negative_prompt_embeds_list[1],
520
+ attention_mask=negative_attention_mask_list[0],
521
+ attention_mask_1=negative_attention_mask_list[1],
522
+ )
523
+ negative_projected_prompt_embeds = projection_output.hidden_states
524
+ negative_projected_attention_mask = projection_output.attention_mask
525
+
526
+ negative_generated_prompt_embeds = self.generate_language_model(
527
+ negative_projected_prompt_embeds,
528
+ attention_mask=negative_projected_attention_mask,
529
+ max_new_tokens=max_new_tokens,
530
+ )
531
+
532
+ if do_classifier_free_guidance:
533
+ seq_len = negative_prompt_embeds.shape[1]
534
+
535
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
536
+ negative_attention_mask = (
537
+ negative_attention_mask.to(device=device)
538
+ if negative_attention_mask is not None
539
+ else torch.ones(negative_prompt_embeds.shape[:2], dtype=torch.long, device=device)
540
+ )
541
+ negative_generated_prompt_embeds = negative_generated_prompt_embeds.to(
542
+ dtype=self.language_model.dtype, device=device
543
+ )
544
+
545
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
546
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
547
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len, -1)
548
+
549
+ # duplicate unconditional attention mask for each generation per prompt
550
+ negative_attention_mask = negative_attention_mask.repeat(1, num_waveforms_per_prompt)
551
+ negative_attention_mask = negative_attention_mask.view(batch_size * num_waveforms_per_prompt, seq_len)
552
+
553
+ # duplicate unconditional generated embeddings for each generation per prompt
554
+ seq_len = negative_generated_prompt_embeds.shape[1]
555
+ negative_generated_prompt_embeds = negative_generated_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
556
+ negative_generated_prompt_embeds = negative_generated_prompt_embeds.view(
557
+ batch_size * num_waveforms_per_prompt, seq_len, -1
558
+ )
559
+
560
+ # For classifier free guidance, we need to do two forward passes.
561
+ # Here we concatenate the unconditional and text embeddings into a single batch
562
+ # to avoid doing two forward passes
563
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
564
+ attention_mask = torch.cat([negative_attention_mask, attention_mask])
565
+ generated_prompt_embeds = torch.cat([negative_generated_prompt_embeds, generated_prompt_embeds])
566
+
567
+ return prompt_embeds, attention_mask, generated_prompt_embeds
568
+
569
+ # Copied from diffusers.pipelines.audioldm.pipeline_audioldm.AudioLDMPipeline.mel_spectrogram_to_waveform
570
+ def mel_spectrogram_to_waveform(self, mel_spectrogram):
571
+ if mel_spectrogram.dim() == 4:
572
+ mel_spectrogram = mel_spectrogram.squeeze(1)
573
+
574
+ waveform = self.vocoder(mel_spectrogram)
575
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
576
+ waveform = waveform.cpu().float()
577
+ return waveform
578
+
579
+ def score_waveforms(self, text, audio, num_waveforms_per_prompt, device, dtype):
580
+ if not is_librosa_available():
581
+ logger.info(
582
+ "Automatic scoring of the generated audio waveforms against the input prompt text requires the "
583
+ "`librosa` package to resample the generated waveforms. Returning the audios in the order they were "
584
+ "generated. To enable automatic scoring, install `librosa` with: `pip install librosa`."
585
+ )
586
+ return audio
587
+ inputs = self.tokenizer(text, return_tensors="pt", padding=True)
588
+ resampled_audio = librosa.resample(
589
+ audio.numpy(), orig_sr=self.vocoder.config.sampling_rate, target_sr=self.feature_extractor.sampling_rate
590
+ )
591
+ inputs["input_features"] = self.feature_extractor(
592
+ list(resampled_audio), return_tensors="pt", sampling_rate=self.feature_extractor.sampling_rate
593
+ ).input_features.type(dtype)
594
+ inputs = inputs.to(device)
595
+
596
+ # compute the audio-text similarity score using the CLAP model
597
+ logits_per_text = self.text_encoder(**inputs).logits_per_text
598
+ # sort by the highest matching generations per prompt
599
+ indices = torch.argsort(logits_per_text, dim=1, descending=True)[:, :num_waveforms_per_prompt]
600
+ audio = torch.index_select(audio, 0, indices.reshape(-1).cpu())
601
+ return audio
602
+
603
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
604
+ def prepare_extra_step_kwargs(self, generator, eta):
605
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
606
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
607
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
608
+ # and should be between [0, 1]
609
+
610
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
611
+ extra_step_kwargs = {}
612
+ if accepts_eta:
613
+ extra_step_kwargs["eta"] = eta
614
+
615
+ # check if the scheduler accepts generator
616
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
617
+ if accepts_generator:
618
+ extra_step_kwargs["generator"] = generator
619
+ return extra_step_kwargs
620
+
621
+ def check_inputs(
622
+ self,
623
+ prompt,
624
+ audio_length_in_s,
625
+ vocoder_upsample_factor,
626
+ callback_steps,
627
+ negative_prompt=None,
628
+ prompt_embeds=None,
629
+ negative_prompt_embeds=None,
630
+ generated_prompt_embeds=None,
631
+ negative_generated_prompt_embeds=None,
632
+ attention_mask=None,
633
+ negative_attention_mask=None,
634
+ ):
635
+ min_audio_length_in_s = vocoder_upsample_factor * self.vae_scale_factor
636
+ if audio_length_in_s < min_audio_length_in_s:
637
+ raise ValueError(
638
+ f"`audio_length_in_s` has to be a positive value greater than or equal to {min_audio_length_in_s}, but "
639
+ f"is {audio_length_in_s}."
640
+ )
641
+
642
+ if self.vocoder.config.model_in_dim % self.vae_scale_factor != 0:
643
+ raise ValueError(
644
+ f"The number of frequency bins in the vocoder's log-mel spectrogram has to be divisible by the "
645
+ f"VAE scale factor, but got {self.vocoder.config.model_in_dim} bins and a scale factor of "
646
+ f"{self.vae_scale_factor}."
647
+ )
648
+
649
+ if (callback_steps is None) or (
650
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
651
+ ):
652
+ raise ValueError(
653
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
654
+ f" {type(callback_steps)}."
655
+ )
656
+
657
+ if prompt is not None and prompt_embeds is not None:
658
+ raise ValueError(
659
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
660
+ " only forward one of the two."
661
+ )
662
+ elif prompt is None and (prompt_embeds is None or generated_prompt_embeds is None):
663
+ raise ValueError(
664
+ "Provide either `prompt`, or `prompt_embeds` and `generated_prompt_embeds`. Cannot leave "
665
+ "`prompt` undefined without specifying both `prompt_embeds` and `generated_prompt_embeds`."
666
+ )
667
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
668
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
669
+
670
+ if negative_prompt is not None and negative_prompt_embeds is not None:
671
+ raise ValueError(
672
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
673
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
674
+ )
675
+ elif negative_prompt_embeds is not None and negative_generated_prompt_embeds is None:
676
+ raise ValueError(
677
+ "Cannot forward `negative_prompt_embeds` without `negative_generated_prompt_embeds`. Ensure that"
678
+ "both arguments are specified"
679
+ )
680
+
681
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
682
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
683
+ raise ValueError(
684
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
685
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
686
+ f" {negative_prompt_embeds.shape}."
687
+ )
688
+ if attention_mask is not None and attention_mask.shape != prompt_embeds.shape[:2]:
689
+ raise ValueError(
690
+ "`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:"
691
+ f"`attention_mask: {attention_mask.shape} != `prompt_embeds` {prompt_embeds.shape}"
692
+ )
693
+
694
+ if generated_prompt_embeds is not None and negative_generated_prompt_embeds is not None:
695
+ if generated_prompt_embeds.shape != negative_generated_prompt_embeds.shape:
696
+ raise ValueError(
697
+ "`generated_prompt_embeds` and `negative_generated_prompt_embeds` must have the same shape when "
698
+ f"passed directly, but got: `generated_prompt_embeds` {generated_prompt_embeds.shape} != "
699
+ f"`negative_generated_prompt_embeds` {negative_generated_prompt_embeds.shape}."
700
+ )
701
+ if (
702
+ negative_attention_mask is not None
703
+ and negative_attention_mask.shape != negative_prompt_embeds.shape[:2]
704
+ ):
705
+ raise ValueError(
706
+ "`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:"
707
+ f"`attention_mask: {negative_attention_mask.shape} != `prompt_embeds` {negative_prompt_embeds.shape}"
708
+ )
709
+
710
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with width->self.vocoder.config.model_in_dim
711
+ def prepare_latents(self, batch_size, num_channels_latents, height, dtype, device, generator, latents=None):
712
+ shape = (
713
+ batch_size,
714
+ num_channels_latents,
715
+ height // self.vae_scale_factor,
716
+ self.vocoder.config.model_in_dim // self.vae_scale_factor,
717
+ )
718
+ if isinstance(generator, list) and len(generator) != batch_size:
719
+ raise ValueError(
720
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
721
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
722
+ )
723
+
724
+ if latents is None:
725
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
726
+ else:
727
+ latents = latents.to(device)
728
+
729
+ # scale the initial noise by the standard deviation required by the scheduler
730
+ latents = latents * self.scheduler.init_noise_sigma
731
+ return latents
732
+
733
+ @torch.no_grad()
734
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
735
+ def __call__(
736
+ self,
737
+ audio_path = None,
738
+ prompt: Union[str, List[str]] = None,
739
+ audio_length_in_s: Optional[float] = None,
740
+ num_inference_steps: int = 200,
741
+ guidance_scale: float = 3.5,
742
+ negative_prompt: Optional[Union[str, List[str]]] = None,
743
+ num_waveforms_per_prompt: Optional[int] = 1,
744
+ eta: float = 0.0,
745
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
746
+ latents: Optional[torch.FloatTensor] = None,
747
+ prompt_embeds: Optional[torch.FloatTensor] = None,
748
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
749
+ generated_prompt_embeds: Optional[torch.FloatTensor] = None,
750
+ negative_generated_prompt_embeds: Optional[torch.FloatTensor] = None,
751
+ attention_mask: Optional[torch.LongTensor] = None,
752
+ negative_attention_mask: Optional[torch.LongTensor] = None,
753
+ max_new_tokens: Optional[int] = None,
754
+ return_dict: bool = True,
755
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
756
+ callback_steps: Optional[int] = 1,
757
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
758
+ output_type: Optional[str] = "np",
759
+ ):
760
+ r"""
761
+ The call function to the pipeline for generation.
762
+
763
+ Args:
764
+ prompt (`str` or `List[str]`, *optional*):
765
+ The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`.
766
+ audio_length_in_s (`float`, *optional*, defaults to 10.24):
767
+ The length of the generated audio sample in seconds.
768
+ num_inference_steps (`int`, *optional*, defaults to 200):
769
+ The number of denoising steps. More denoising steps usually lead to higher quality audio at the
770
+ expense of slower inference.
771
+ guidance_scale (`float`, *optional*, defaults to 3.5):
772
+ A higher guidance scale value encourages the model to generate audio that is closely linked to the text
773
+ `prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`.
774
+ negative_prompt (`str` or `List[str]`, *optional*):
775
+ The prompt or prompts to guide what not to include in audio generation. If not defined, you need to
776
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
777
+ num_waveforms_per_prompt (`int`, *optional*, defaults to 1):
778
+ The number of waveforms to generate per prompt. If `num_waveforms_per_prompt > 1`, then automatic
779
+ scoring is performed between the generated outputs and the text prompt. This scoring ranks the
780
+ generated waveforms based on their cosine similarity with the text input in the joint text-audio
781
+ embedding space.
782
+ eta (`float`, *optional*, defaults to 0.0):
783
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
784
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
785
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
786
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
787
+ generation deterministic.
788
+ latents (`torch.FloatTensor`, *optional*):
789
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for spectrogram
790
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
791
+ tensor is generated by sampling using the supplied random `generator`.
792
+ prompt_embeds (`torch.FloatTensor`, *optional*):
793
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
794
+ provided, text embeddings are generated from the `prompt` input argument.
795
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
796
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
797
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
798
+ generated_prompt_embeds (`torch.FloatTensor`, *optional*):
799
+ Pre-generated text embeddings from the GPT2 language model. Can be used to easily tweak text inputs,
800
+ *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input
801
+ argument.
802
+ negative_generated_prompt_embeds (`torch.FloatTensor`, *optional*):
803
+ Pre-generated negative text embeddings from the GPT2 language model. Can be used to easily tweak text
804
+ inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
805
+ `negative_prompt` input argument.
806
+ attention_mask (`torch.LongTensor`, *optional*):
807
+ Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will
808
+ be computed from `prompt` input argument.
809
+ negative_attention_mask (`torch.LongTensor`, *optional*):
810
+ Pre-computed attention mask to be applied to the `negative_prompt_embeds`. If not provided, attention
811
+ mask will be computed from `negative_prompt` input argument.
812
+ max_new_tokens (`int`, *optional*, defaults to None):
813
+ Number of new tokens to generate with the GPT2 language model. If not provided, number of tokens will
814
+ be taken from the config of the model.
815
+ return_dict (`bool`, *optional*, defaults to `True`):
816
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
817
+ plain tuple.
818
+ callback (`Callable`, *optional*):
819
+ A function that is called every `callback_steps` steps during inference. The function is called with the
820
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
821
+ callback_steps (`int`, *optional*, defaults to 1):
822
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
823
+ every step.
824
+ cross_attention_kwargs (`dict`, *optional*):
825
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
826
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
827
+ output_type (`str`, *optional*, defaults to `"np"`):
828
+ The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or
829
+ `"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion
830
+ model (LDM) output.
831
+
832
+ Examples:
833
+
834
+ Returns:
835
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
836
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
837
+ otherwise a `tuple` is returned where the first element is a list with the generated audio.
838
+ """
839
+ # 0. Convert audio input length from seconds to spectrogram height
840
+ vocoder_upsample_factor = np.prod(self.vocoder.config.upsample_rates) / self.vocoder.config.sampling_rate
841
+
842
+ if audio_length_in_s is None:
843
+ audio_length_in_s = self.unet.config.sample_size * self.vae_scale_factor * vocoder_upsample_factor
844
+
845
+ height = int(audio_length_in_s / vocoder_upsample_factor)
846
+
847
+ original_waveform_length = int(audio_length_in_s * self.vocoder.config.sampling_rate)
848
+ if height % self.vae_scale_factor != 0:
849
+ height = int(np.ceil(height / self.vae_scale_factor)) * self.vae_scale_factor
850
+ logger.info(
851
+ f"Audio length in seconds {audio_length_in_s} is increased to {height * vocoder_upsample_factor} "
852
+ f"so that it can be handled by the model. It will be cut to {audio_length_in_s} after the "
853
+ f"denoising process."
854
+ )
855
+
856
+ # 1. Check inputs. Raise error if not correct
857
+ self.check_inputs(
858
+ prompt,
859
+ audio_length_in_s,
860
+ vocoder_upsample_factor,
861
+ callback_steps,
862
+ negative_prompt,
863
+ prompt_embeds,
864
+ negative_prompt_embeds,
865
+ generated_prompt_embeds,
866
+ negative_generated_prompt_embeds,
867
+ attention_mask,
868
+ negative_attention_mask,
869
+ )
870
+
871
+ # 2. Define call parameters
872
+ if prompt is not None and isinstance(prompt, str):
873
+ batch_size = 1
874
+ elif prompt is not None and isinstance(prompt, list):
875
+ batch_size = len(prompt)
876
+ else:
877
+ batch_size = prompt_embeds.shape[0]
878
+
879
+ device = self._execution_device
880
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
881
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
882
+ # corresponds to doing no classifier free guidance.
883
+ do_classifier_free_guidance = guidance_scale > 1.0
884
+
885
+ # 3. Encode input prompt
886
+ prompt_embeds, attention_mask, generated_prompt_embeds = self.encode_prompt(
887
+ prompt,
888
+ device,
889
+ num_waveforms_per_prompt,
890
+ do_classifier_free_guidance,
891
+ negative_prompt,
892
+ prompt_embeds=prompt_embeds,
893
+ negative_prompt_embeds=negative_prompt_embeds,
894
+ generated_prompt_embeds=generated_prompt_embeds,
895
+ negative_generated_prompt_embeds=negative_generated_prompt_embeds,
896
+ attention_mask=attention_mask,
897
+ negative_attention_mask=negative_attention_mask,
898
+ max_new_tokens=max_new_tokens,
899
+ )
900
+
901
+ # 4. Prepare timesteps
902
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
903
+ timesteps = self.scheduler.timesteps
904
+
905
+ # 5. Prepare latent variables
906
+ num_channels_latents = self.unet.config.in_channels
907
+
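+ # Style transfer via shallow reverse diffusion (SDEdit-style): the reference audio's mel spectrogram
+ # is encoded into VAE latents, noised at an intermediate timestep (roughly halfway through the
+ # schedule), and only the remaining denoising steps are run, so the output keeps the reference's
+ # global structure while the text prompt steers its style.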
908
+ mel = wav_to_mel(audio_path,
909
+ 10,
910
+ augment_data=False,
911
+ mix_data=None,
912
+ snr=None)
913
+ # print("mel shape", mel.shape)
914
+ mel = mel.unsqueeze(0)
915
+ mel = mel.to(device)
916
+ mel = mel.to(next(self.vae.parameters()).dtype)  # match the VAE dtype rather than hard-coding float16
917
+ latents = self.vae.encode(mel).latent_dist.sample()
918
+
919
+ latents = latents * self.vae.config.scaling_factor
920
+ noise = torch.randn_like(latents)
921
+ print("timesteps",timesteps)
922
+ shallow_reverse_step = num_inference_steps // 4 * 2
923
+ print("shallow_reverse_steps",timesteps[shallow_reverse_step:])
924
+ timesteps = timesteps[shallow_reverse_step:]
925
+ timesteps_tensor = torch.tensor([timesteps[0]], dtype=torch.int32)
926
+ noisy_sample = self.scheduler.add_noise(latents,noise,timesteps_tensor)
927
+
928
+ latents = self.prepare_latents(
929
+ batch_size * num_waveforms_per_prompt,
930
+ num_channels_latents,
931
+ height,
932
+ prompt_embeds.dtype,
933
+ device,
934
+ generator,
935
+ noisy_sample,
936
+ )
937
+ # latents = latents.squeeze(0)
938
+ # 6. Prepare extra step kwargs
939
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
940
+
941
+ # 7. Denoising loop
942
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order  # negative here because `timesteps` was truncated above; the progress-bar condition below still behaves correctly
943
+ # timesteps = timesteps[:shallow_reverse_step]
944
+ # print("timesteps",timesteps)
945
+ print("latents",latents.shape)
946
+ latents = latents.repeat(8, 1, 1, 1)  # NOTE: hard-coded replication to a batch of 8 copies of the same noisy latent
947
+ print("latents",latents.shape)
948
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
949
+ for i, t in enumerate(timesteps):
950
+ # expand the latents if we are doing classifier free guidance
951
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
952
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
953
+ # latent_model_input = latent_model_input[:,:,:250,:]
954
+ print("latent_model_input.shape",latent_model_input.shape)
955
+ print("t",t)
956
+ print("generated_prompt_embeds",generated_prompt_embeds.shape)
957
+ print("attention_mask",attention_mask.shape)
958
+ print("prompt_embeds",prompt_embeds.shape)
959
+ # predict the noise residual
960
+ noise_pred = self.unet(
961
+ latent_model_input,
962
+ t,
963
+ encoder_hidden_states=generated_prompt_embeds,
964
+ encoder_hidden_states_1=prompt_embeds,
965
+ encoder_attention_mask_1=attention_mask,
966
+ return_dict=False,
967
+ )[0]
968
+
969
+ # perform guidance
970
+ if do_classifier_free_guidance:
971
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
972
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
973
+
974
+ # compute the previous noisy sample x_t -> x_t-1
975
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
976
+
977
+ # call the callback, if provided
978
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
979
+ progress_bar.update()
980
+ if callback is not None and i % callback_steps == 0:
981
+ callback(i, t, latents)
982
+
983
+ self.maybe_free_model_hooks()
984
+
985
+ # 8. Post-processing
986
+ if output_type != "latent":
987
+ latents = 1 / self.vae.config.scaling_factor * latents
988
+ mel_spectrogram = self.vae.decode(latents).sample
989
+ else:
990
+ return AudioPipelineOutput(audios=latents)
991
+
992
+ audio = self.mel_spectrogram_to_waveform(mel_spectrogram)
993
+
994
+ audio = audio[:, :original_waveform_length]
995
+
996
+ # 9. Automatic scoring
997
+ if num_waveforms_per_prompt > 1 and prompt is not None:
998
+ audio = self.score_waveforms(
999
+ text=prompt,
1000
+ audio=audio,
1001
+ num_waveforms_per_prompt=num_waveforms_per_prompt,
1002
+ device=device,
1003
+ dtype=prompt_embeds.dtype,
1004
+ )
1005
+
1006
+ if output_type == "np":
1007
+ audio = audio.numpy()
1008
+
1009
+ if not return_dict:
1010
+ return (audio,)
1011
+
1012
+ return AudioPipelineOutput(audios=audio)
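The call above implements a shallow-reverse (SDEdit-style) edit: the source clip is encoded to VAE latents, re-noised to an intermediate timestep, and only the tail of the denoising schedule is run, so the output keeps the coarse structure of the input audio while following the text prompt. A minimal usage sketch follows; the pipeline class name, import path, checkpoint id, argument names, and the 16 kHz output rate are assumptions for illustration, not taken from this commit.

import scipy.io.wavfile
import torch

from pipeline.morph_pipeline import MorphAudioPipeline  # hypothetical import path and class name

pipe = MorphAudioPipeline.from_pretrained("cvssp/audioldm2", torch_dtype=torch.float16).to("cuda")
result = pipe(
    audio_path="source.wav",             # clip that is re-noised and then partially denoised
    prompt="a jazz drum loop with vinyl crackle",
    negative_prompt="low quality",
    num_inference_steps=100,             # only timesteps[num_inference_steps // 4 * 2:] are actually run
    guidance_scale=3.5,
    num_waveforms_per_prompt=1,
    output_type="np",
)
scipy.io.wavfile.write("edited.wav", rate=16000, data=result.audios[0])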
utils/alpha_scheduler.py ADDED
@@ -0,0 +1,54 @@
1
+ import bisect
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import lpips
5
+
6
+ perceptual_loss = lpips.LPIPS()
7
+
8
+
9
+ def distance(img_a, img_b):
10
+ # return perceptual_loss(img_a, img_b).item()
11
+ return F.mse_loss(img_a, img_b).item()
12
+
13
+
14
+ class AlphaScheduler:
15
+ def __init__(self):
16
+ ...
17
+
18
+ def from_imgs(self, imgs):
19
+ self.__num_values = len(imgs)
20
+ self.__values = [0]
21
+ for i in range(self.__num_values - 1):
22
+ dis = distance(imgs[i], imgs[i + 1])
23
+ self.__values.append(dis)
24
+ self.__values[i + 1] += self.__values[i]
25
+ for i in range(self.__num_values):
26
+ self.__values[i] /= self.__values[-1]
27
+
28
+ def save(self, filename):
29
+ torch.save(torch.tensor(self.__values), filename)
30
+
31
+ def load(self, filename):
32
+ self.__values = torch.load(filename).tolist()
33
+ self.__num_values = len(self.__values)
34
+
35
+ def get_x(self, y):
36
+ assert y >= 0 and y <= 1
37
+ id = bisect.bisect_left(self.__values, y)
38
+ id -= 1
39
+ if id < 0:
40
+ id = 0
41
+ yl = self.__values[id]
42
+ yr = self.__values[id + 1]
43
+ xl = id * (1 / (self.__num_values - 1))
44
+ xr = (id + 1) * (1 / (self.__num_values - 1))
45
+ x = (y - yl) / (yr - yl) * (xr - xl) + xl
46
+ return x
47
+
48
+ def get_list(self, len=None):  # note: `len` shadows the Python builtin inside this method
49
+ if len is None:
50
+ len = self.__num_values
51
+
52
+ ys = torch.linspace(0, 1, len)
53
+ res = [self.get_x(y) for y in ys]
54
+ return res
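AlphaScheduler builds a normalized cumulative curve of frame-to-frame distances and then inverts it, so resampled interpolation weights are spaced densely where consecutive frames change quickly and sparsely where they barely change. A minimal sketch with random stand-in tensors (the shapes are assumptions; real inputs would be decoded morph frames), noting that importing the module instantiates lpips.LPIPS() even though distance() currently uses plain MSE:

import torch
from utils.alpha_scheduler import AlphaScheduler

# Stand-in "frames"; real inputs would be decoded spectrogram or image tensors of equal shape.
frames = [torch.rand(1, 1, 64, 64) for _ in range(10)]

scheduler = AlphaScheduler()
scheduler.from_imgs(frames)        # cumulative, normalized frame-to-frame MSE distances
alphas = scheduler.get_list(10)    # rebalanced interpolation weights in [0, 1]
print([round(float(a), 3) for a in alphas])

scheduler.save("alphas.pt")        # the curve can be persisted and reloaded
scheduler.load("alphas.pt")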
utils/lora_utils_successed_ver1.py ADDED
@@ -0,0 +1,700 @@
1
+ from timeit import default_timer as timer
2
+ from datetime import timedelta
3
+ from PIL import Image
4
+ import os
5
+ import itertools
6
+ import numpy as np
7
+ from einops import rearrange
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torchvision import transforms
11
+ import transformers
12
+ from accelerate import Accelerator
13
+ from accelerate.utils import set_seed
14
+ from packaging import version
15
+ from PIL import Image
16
+ import tqdm
17
+ from typing import Any, Callable, Dict, List, Optional, Union
18
+ from transformers import AutoTokenizer, PretrainedConfig
19
+ from APadapter.ap_adapter.attention_processor import AttnProcessor2_0,IPAttnProcessor2_0
20
+ import diffusers
21
+ from diffusers import (
22
+ AutoencoderKL,
23
+ DDPMScheduler,
24
+ DiffusionPipeline,
25
+ DPMSolverMultistepScheduler,
26
+ StableDiffusionPipeline,
27
+ UNet2DConditionModel,
28
+ )
29
+ from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
30
+ from diffusers.models.attention_processor import (
31
+ AttnAddedKVProcessor,
32
+ AttnAddedKVProcessor2_0,
33
+ LoRAAttnAddedKVProcessor,
34
+ LoRAAttnProcessor,
35
+ LoRAAttnProcessor2_0,
36
+ SlicedAttnAddedKVProcessor,
37
+ )
38
+ from diffusers.optimization import get_scheduler
39
+ from diffusers.utils import check_min_version
40
+ from diffusers.utils.import_utils import is_xformers_available
41
+ import torchaudio
42
+ from audio_encoder.AudioMAE import AudioMAEConditionCTPoolRand, extract_kaldi_fbank_feature
43
+ from audioldm.utils import default_audioldm_config
44
+ from audioldm.audio import TacotronSTFT, read_wav_file
45
+ from audioldm.audio.tools import get_mel_from_wav, _pad_spec, normalize_wav, pad_wav
46
+ from transformers import (
47
+ ClapFeatureExtractor,
48
+ ClapModel,
49
+ GPT2Model,
50
+ RobertaTokenizer,
51
+ RobertaTokenizerFast,
52
+ SpeechT5HifiGan,
53
+ T5EncoderModel,
54
+ T5Tokenizer,
55
+ T5TokenizerFast,
56
+ )
57
+ from diffusers.utils.torch_utils import randn_tensor
58
+ from peft import (
59
+ prepare_model_for_kbit_training,
60
+ LoraConfig,
61
+ get_peft_model,
62
+ PeftModel
63
+ )
64
+ from torchviz import make_dot
65
+ import json
66
+ from matplotlib import pyplot as plt
67
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
68
+ # check_min_version("0.17.0")
69
+
70
+ def wav_to_fbank(
71
+ filename,
72
+ target_length=1024,
73
+ fn_STFT=None,
74
+ augment_data=False,
75
+ mix_data=False,
76
+ snr=None
77
+ ):
78
+ assert fn_STFT is not None
79
+ waveform = read_wav_file(filename, target_length * 160) # hop size is 160
80
+ waveform = waveform[0, ...]
81
+ waveform = torch.FloatTensor(waveform)
82
+
83
+ fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
84
+
85
+ fbank = torch.FloatTensor(fbank.T)
86
+ log_magnitudes_stft = torch.FloatTensor(log_magnitudes_stft.T)
87
+
88
+ fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
89
+ log_magnitudes_stft, target_length
90
+ )
91
+ fbank = fbank.contiguous()
92
+ log_magnitudes_stft = log_magnitudes_stft.contiguous()
93
+ waveform = waveform.contiguous()
94
+ return fbank, log_magnitudes_stft, waveform
95
+
96
+ def wav_to_mel(
97
+ original_audio_file_path,
98
+ duration,
99
+ augment_data=False,
100
+ mix_data=False,
101
+ snr=None):
102
+ config=default_audioldm_config()
103
+
104
+ fn_STFT = TacotronSTFT(
105
+ config["preprocessing"]["stft"]["filter_length"],
106
+ config["preprocessing"]["stft"]["hop_length"],
107
+ config["preprocessing"]["stft"]["win_length"],
108
+ config["preprocessing"]["mel"]["n_mel_channels"],
109
+ config["preprocessing"]["audio"]["sampling_rate"],
110
+ config["preprocessing"]["mel"]["mel_fmin"],
111
+ config["preprocessing"]["mel"]["mel_fmax"],
112
+ )
113
+
114
+ mel, _, _ = wav_to_fbank(
115
+ original_audio_file_path,
116
+ target_length=int(duration * 102.4),
117
+ fn_STFT=fn_STFT,
118
+ augment_data=augment_data,
119
+ mix_data=mix_data,
120
+ snr=snr
121
+ )
122
+ mel = mel.unsqueeze(0)
123
+ return mel
124
+
125
+ def prepare_inputs_for_generation(
126
+ inputs_embeds,
127
+ attention_mask=None,
128
+ past_key_values=None,
129
+ **kwargs,
130
+ ):
131
+ if past_key_values is not None:
132
+ # only last token for inputs_embeds if past is defined in kwargs
133
+ inputs_embeds = inputs_embeds[:, -1:]
134
+ kwargs["use_cache"] = True
135
+ return {
136
+ "inputs_embeds": inputs_embeds,
137
+ "attention_mask": attention_mask,
138
+ "past_key_values": past_key_values,
139
+ "use_cache": kwargs.get("use_cache"),
140
+ }
141
+
142
+ def generate_language_model(
143
+ language_model,
144
+ inputs_embeds: torch.Tensor = None,
145
+ max_new_tokens: int = 512,
146
+ **model_kwargs,
147
+ ):
148
+ """
149
+
150
+ Generates a sequence of hidden-states from the language model, conditioned on the embedding inputs.
151
+
152
+ Parameters:
153
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
154
+ The sequence used as a prompt for the generation.
155
+ max_new_tokens (`int`):
156
+ Number of new tokens to generate.
157
+ model_kwargs (`Dict[str, Any]`, *optional*):
158
+ Ad hoc parametrization of additional model-specific kwargs that will be forwarded to the `forward`
159
+ function of the model.
160
+
161
+ Return:
162
+ `inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
163
+ The sequence of generated hidden-states.
164
+ """
165
+ max_new_tokens = max_new_tokens if max_new_tokens is not None else language_model.config.max_new_tokens
166
+ model_kwargs = language_model._get_initial_cache_position(inputs_embeds, model_kwargs)
167
+ for _ in range(max_new_tokens):
168
+ # prepare model inputs
169
+ model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)
170
+
171
+ # forward pass to get next hidden states
172
+ output = language_model(**model_inputs, return_dict=True)
173
+ next_hidden_states = output.last_hidden_state
174
+
175
+ # Update the model input
176
+ inputs_embeds = torch.cat([inputs_embeds, next_hidden_states[:, -1:, :]], dim=1)
177
+
178
+ # Update generated hidden states, model inputs, and length for next step
179
+ model_kwargs = language_model._update_model_kwargs_for_generation(output, model_kwargs)
180
+
181
+ return inputs_embeds[:, -max_new_tokens:, :]
182
+
183
+ def encode_prompt(
184
+ tokenizer,
185
+ tokenizer_2,
186
+ text_encoder,
187
+ text_encoder_2,
188
+ projection_model,
189
+ language_model,
190
+ prompt,
191
+ device,
192
+ num_waveforms_per_prompt,
193
+ do_classifier_free_guidance,
194
+ negative_prompt=None,
195
+ prompt_embeds: Optional[torch.FloatTensor] = None,
196
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
197
+ generated_prompt_embeds: Optional[torch.FloatTensor] = None,
198
+ negative_generated_prompt_embeds: Optional[torch.FloatTensor] = None,
199
+ attention_mask: Optional[torch.LongTensor] = None,
200
+ negative_attention_mask: Optional[torch.LongTensor] = None,
201
+ max_new_tokens: Optional[int] = None,
202
+ ):
203
+ if prompt is not None and isinstance(prompt, str):
204
+ batch_size = 1
205
+ elif prompt is not None and isinstance(prompt, list):
206
+ batch_size = len(prompt)
207
+ else:
208
+ batch_size = prompt_embeds.shape[0]
209
+ # Define tokenizers and text encoders
210
+ tokenizers = [tokenizer, tokenizer_2]
211
+ text_encoders = [text_encoder, text_encoder_2]
212
+
213
+ if prompt_embeds is None:
214
+ prompt_embeds_list = []
215
+ attention_mask_list = []
216
+
217
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
218
+ text_inputs = tokenizer(
219
+ prompt,
220
+ padding="max_length" if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast)) else True,
221
+ max_length=tokenizer.model_max_length,
222
+ truncation=True,
223
+ return_tensors="pt",
224
+ )
225
+ text_input_ids = text_inputs.input_ids
226
+ attention_mask = text_inputs.attention_mask
227
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
228
+
229
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
230
+ text_input_ids, untruncated_ids
231
+ ):
232
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
233
+ # logger.warning(
234
+ # f"The following part of your input was truncated because {text_encoder.config.model_type} can "
235
+ # f"only handle sequences up to {tokenizer.model_max_length} tokens: {removed_text}"
236
+ # )
237
+
238
+ text_input_ids = text_input_ids.to(device)
239
+ attention_mask = attention_mask.to(device)
240
+
241
+ if text_encoder.config.model_type == "clap":
242
+ prompt_embeds = text_encoder.get_text_features(
243
+ text_input_ids,
244
+ attention_mask=attention_mask,
245
+ )
246
+ # append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size)
247
+ prompt_embeds = prompt_embeds[:, None, :]
248
+ # make sure that we attend to this single hidden-state
249
+ attention_mask = attention_mask.new_ones((batch_size, 1))
250
+ else:
251
+ prompt_embeds = text_encoder(
252
+ text_input_ids,
253
+ attention_mask=attention_mask,
254
+ )
255
+ prompt_embeds = prompt_embeds[0]
256
+
257
+ prompt_embeds_list.append(prompt_embeds)
258
+ attention_mask_list.append(attention_mask)
259
+ projection_output = projection_model(
260
+ hidden_states=prompt_embeds_list[0],
261
+ hidden_states_1=prompt_embeds_list[1],
262
+ attention_mask=attention_mask_list[0],
263
+ attention_mask_1=attention_mask_list[1],
264
+ )
265
+ projected_prompt_embeds = projection_output.hidden_states
266
+ projected_attention_mask = projection_output.attention_mask
267
+
268
+ generated_prompt_embeds = generate_language_model(
269
+ language_model,
270
+ projected_prompt_embeds,
271
+ attention_mask=projected_attention_mask,
272
+ max_new_tokens=max_new_tokens,
273
+ )
274
+
275
+ prompt_embeds = prompt_embeds.to(dtype=text_encoder_2.dtype, device=device)
276
+ attention_mask = (
277
+ attention_mask.to(device=device)
278
+ if attention_mask is not None
279
+ else torch.ones(prompt_embeds.shape[:2], dtype=torch.long, device=device)
280
+ )
281
+ generated_prompt_embeds = generated_prompt_embeds.to(dtype=language_model.dtype, device=device)
282
+
283
+ bs_embed, seq_len, hidden_size = prompt_embeds.shape
284
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
285
+ prompt_embeds = prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
286
+ prompt_embeds = prompt_embeds.view(bs_embed * num_waveforms_per_prompt, seq_len, hidden_size)
287
+
288
+ # duplicate attention mask for each generation per prompt
289
+ attention_mask = attention_mask.repeat(1, num_waveforms_per_prompt)
290
+ attention_mask = attention_mask.view(bs_embed * num_waveforms_per_prompt, seq_len)
291
+
292
+ bs_embed, seq_len, hidden_size = generated_prompt_embeds.shape
293
+ # duplicate generated embeddings for each generation per prompt, using mps friendly method
294
+ generated_prompt_embeds = generated_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
295
+ generated_prompt_embeds = generated_prompt_embeds.view(
296
+ bs_embed * num_waveforms_per_prompt, seq_len, hidden_size
297
+ )
298
+
299
+ # get unconditional embeddings for classifier free guidance
300
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
301
+ uncond_tokens: List[str]
302
+ if negative_prompt is None:
303
+ uncond_tokens = [""] * batch_size
304
+ elif type(prompt) is not type(negative_prompt):
305
+ raise TypeError(
306
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
307
+ f" {type(prompt)}."
308
+ )
309
+ elif isinstance(negative_prompt, str):
310
+ uncond_tokens = [negative_prompt]
311
+ elif batch_size != len(negative_prompt):
312
+ raise ValueError(
313
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
314
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
315
+ " the batch size of `prompt`."
316
+ )
317
+ else:
318
+ uncond_tokens = negative_prompt
319
+
320
+ negative_prompt_embeds_list = []
321
+ negative_attention_mask_list = []
322
+ max_length = prompt_embeds.shape[1]
323
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
324
+ uncond_input = tokenizer(
325
+ uncond_tokens,
326
+ padding="max_length",
327
+ max_length=tokenizer.model_max_length
328
+ if isinstance(tokenizer, (RobertaTokenizer, RobertaTokenizerFast))
329
+ else max_length,
330
+ truncation=True,
331
+ return_tensors="pt",
332
+ )
333
+
334
+ uncond_input_ids = uncond_input.input_ids.to(device)
335
+ negative_attention_mask = uncond_input.attention_mask.to(device)
336
+
337
+ if text_encoder.config.model_type == "clap":
338
+ negative_prompt_embeds = text_encoder.get_text_features(
339
+ uncond_input_ids,
340
+ attention_mask=negative_attention_mask,
341
+ )
342
+ # append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size)
343
+ negative_prompt_embeds = negative_prompt_embeds[:, None, :]
344
+ # make sure that we attend to this single hidden-state
345
+ negative_attention_mask = negative_attention_mask.new_ones((batch_size, 1))
346
+ else:
347
+ negative_prompt_embeds = text_encoder(
348
+ uncond_input_ids,
349
+ attention_mask=negative_attention_mask,
350
+ )
351
+ negative_prompt_embeds = negative_prompt_embeds[0]
352
+
353
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
354
+ negative_attention_mask_list.append(negative_attention_mask)
355
+
356
+ projection_output = projection_model(
357
+ hidden_states=negative_prompt_embeds_list[0],
358
+ hidden_states_1=negative_prompt_embeds_list[1],
359
+ attention_mask=negative_attention_mask_list[0],
360
+ attention_mask_1=negative_attention_mask_list[1],
361
+ )
362
+ negative_projected_prompt_embeds = projection_output.hidden_states
363
+ negative_projected_attention_mask = projection_output.attention_mask
364
+
365
+ negative_generated_prompt_embeds = generate_language_model(
366
+ language_model,
367
+ negative_projected_prompt_embeds,
368
+ attention_mask=negative_projected_attention_mask,
369
+ max_new_tokens=max_new_tokens,
370
+ )
371
+
372
+ if do_classifier_free_guidance:
373
+ seq_len = negative_prompt_embeds.shape[1]
374
+
375
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder_2.dtype, device=device)
376
+ negative_attention_mask = (
377
+ negative_attention_mask.to(device=device)
378
+ if negative_attention_mask is not None
379
+ else torch.ones(negative_prompt_embeds.shape[:2], dtype=torch.long, device=device)
380
+ )
381
+ negative_generated_prompt_embeds = negative_generated_prompt_embeds.to(
382
+ dtype=language_model.dtype, device=device
383
+ )
384
+
385
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
386
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
387
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_waveforms_per_prompt, seq_len, -1)
388
+
389
+ # duplicate unconditional attention mask for each generation per prompt
390
+ negative_attention_mask = negative_attention_mask.repeat(1, num_waveforms_per_prompt)
391
+ negative_attention_mask = negative_attention_mask.view(batch_size * num_waveforms_per_prompt, seq_len)
392
+
393
+ # duplicate unconditional generated embeddings for each generation per prompt
394
+ seq_len = negative_generated_prompt_embeds.shape[1]
395
+ negative_generated_prompt_embeds = negative_generated_prompt_embeds.repeat(1, num_waveforms_per_prompt, 1)
396
+ negative_generated_prompt_embeds = negative_generated_prompt_embeds.view(
397
+ batch_size * num_waveforms_per_prompt, seq_len, -1
398
+ )
399
+
400
+ # For classifier free guidance, we need to do two forward passes.
401
+ # Here we concatenate the unconditional and text embeddings into a single batch
402
+ # to avoid doing two forward passes
403
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
404
+ attention_mask = torch.cat([negative_attention_mask, attention_mask])
405
+ generated_prompt_embeds = torch.cat([negative_generated_prompt_embeds, generated_prompt_embeds])
406
+
407
+ return prompt_embeds, attention_mask, generated_prompt_embeds
408
+
409
+ def prepare_latents(vae, vocoder, scheduler, batch_size, num_channels_latents, height, dtype, device, generator, latents=None):
410
+ vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
411
+ shape = (
412
+ batch_size,
413
+ num_channels_latents,
414
+ height // vae_scale_factor,
415
+ vocoder.config.model_in_dim // vae_scale_factor,
416
+ )
417
+ if isinstance(generator, list) and len(generator) != batch_size:
418
+ raise ValueError(
419
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
420
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
421
+ )
422
+
423
+ if latents is None:
424
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
425
+ else:
426
+ latents = latents.to(device)
427
+
428
+ # scale the initial noise by the standard deviation required by the scheduler
429
+ latents = latents * scheduler.init_noise_sigma
430
+ return latents
431
+
432
+ def plot_loss(loss_history, loss_plot_path, lora_steps):
433
+ plt.figure(figsize=(10, 6))
434
+ plt.plot(range(1, lora_steps + 1), loss_history, label="Training Loss")
435
+ plt.xlabel("Steps")
436
+ plt.ylabel("Loss")
437
+ plt.title("Training Loss Over Steps")
438
+ plt.legend()
439
+ plt.grid(True)
440
+ plt.savefig(loss_plot_path)
441
+ plt.close()
442
+ # print(f"Loss plot saved to {loss_plot_path}")
443
+
444
+
445
+ # model_path: path of the model
446
+ # image: input image, have not been pre-processed
447
+ # save_lora_dir: the path to save the lora
448
+ # prompt: the user input prompt
449
+ # lora_steps: number of lora training step
450
+ # lora_lr: learning rate of lora training
451
+ # lora_rank: the rank of lora
452
+ def train_lora(audio_path ,height ,time_pooling ,freq_pooling ,prompt, negative_prompt, guidance_scale, save_lora_dir, tokenizer=None, tokenizer_2=None,
453
+ text_encoder=None, text_encoder_2=None, GPT2=None, projection_model=None, vocoder=None,
454
+ vae=None, unet=None, noise_scheduler=None, lora_steps=200, lora_lr=2e-4, lora_rank=16, weight_name=None, safe_serialization=False, progress=tqdm):
455
+ time_pooling = time_pooling
456
+ freq_pooling = freq_pooling
457
+ # initialize accelerator
458
+ # accelerator = Accelerator(
459
+ # gradient_accumulation_steps=1,
460
+ # mixed_precision='no'
461
+ # )
462
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
463
+ set_seed(0)
464
+ # set device and dtype
465
+ # prepare accelerator
466
+ # unet_lora_layers = accelerator.prepare_model(unet_lora_layers)
467
+ # optimizer = accelerator.prepare_optimizer(optimizer)
468
+ # lr_scheduler = accelerator.prepare_scheduler(lr_scheduler)
469
+
470
+ vae.requires_grad_(False)
471
+ text_encoder.requires_grad_(False)
472
+ text_encoder_2.requires_grad_(False)
473
+ GPT2.requires_grad_(False)
474
+ projection_model.requires_grad_(False)
475
+ vocoder.requires_grad_(False)
476
+ unet.requires_grad_(False)
477
+
478
+
479
+
480
+
481
+ for name, param in text_encoder_2.named_parameters():
482
+ if param.requires_grad:
483
+ print(name)
484
+ for name, param in GPT2.named_parameters():
485
+ if param.requires_grad:
486
+ print(name)
487
+ for name, param in vae.named_parameters():
488
+ if param.requires_grad:
489
+ print(name)
490
+ for name, param in vocoder.named_parameters():
491
+ if param.requires_grad:
492
+ print(name)
493
+
494
+ unet.to(device)
495
+ vae.to(device)
496
+ text_encoder.to(device)
497
+
498
+
499
+ # initialize UNet LoRA
500
+ unet_lora_attn_procs = {}
501
+ i = 0 # Counter variable to iterate through the cross-attention dimension array.
502
+ cross = [None, None, 768, 768, 1024, 1024, None, None] # Predefined cross-attention dimensions for different layers.
503
+ do_copy = False
504
+ for name, attn_processor in unet.attn_processors.items():
505
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
506
+ if name.startswith("mid_block"):
507
+ hidden_size = unet.config.block_out_channels[-1]
508
+ elif name.startswith("up_blocks"):
509
+ block_id = int(name[len("up_blocks.")])
510
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
511
+ elif name.startswith("down_blocks"):
512
+ block_id = int(name[len("down_blocks.")])
513
+ hidden_size = unet.config.block_out_channels[block_id]
514
+ else:
515
+ raise NotImplementedError("name must start with up_blocks, mid_block, or down_blocks")
516
+
517
+ # if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
518
+ # lora_attn_processor_class = LoRAAttnAddedKVProcessor
519
+ # else:
520
+ # lora_attn_processor_class = (
521
+ # LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
522
+ # )
523
+
524
+ if cross_attention_dim is None:
525
+ unet_lora_attn_procs[name] = AttnProcessor2_0()
526
+ else:
527
+ cross_attention_dim = cross[i%8]
528
+ i += 1
529
+ if cross_attention_dim == 768:
530
+ unet_lora_attn_procs[name] = IPAttnProcessor2_0(
531
+ hidden_size=hidden_size,
532
+ name = name,
533
+ cross_attention_dim=cross_attention_dim,
534
+ scale=1.0,
535
+ num_tokens=8,
536
+ do_copy = do_copy
537
+ ).to(device, dtype=torch.float32)
538
+ else:
539
+ unet_lora_attn_procs[name] = AttnProcessor2_0()
540
+ unet.set_attn_processor(unet_lora_attn_procs)
541
+ unet_lora_layers = AttnProcsLayers(unet.attn_processors)
542
+
543
+ # Optimizer creation
544
+ params_to_optimize = (unet_lora_layers.parameters())
545
+ optimizer = torch.optim.AdamW(
546
+ params_to_optimize,
547
+ lr=lora_lr,
548
+ betas=(0.9, 0.999),
549
+ weight_decay=1e-2,
550
+ eps=1e-08,
551
+ )
552
+
553
+ lr_scheduler = get_scheduler(
554
+ "constant",
555
+ optimizer=optimizer,
556
+ num_warmup_steps=0,
557
+ num_training_steps=lora_steps,
558
+ num_cycles=1,
559
+ power=1.0,
560
+ )
561
+
562
+
563
+ do_classifier_free_guidance = guidance_scale > 1.0
564
+ # initialize text embeddings
565
+ with torch.no_grad():
566
+ prompt_embeds, attention_mask, generated_prompt_embeds = encode_prompt(
567
+ tokenizer,
568
+ tokenizer_2,
569
+ text_encoder,
570
+ text_encoder_2,
571
+ projection_model,
572
+ GPT2,
573
+ prompt,
574
+ device,
575
+ num_waveforms_per_prompt = 1,
576
+ do_classifier_free_guidance= do_classifier_free_guidance,
577
+ negative_prompt = negative_prompt,
578
+ )
579
+ waveform, sr = torchaudio.load(audio_path)
580
+ fbank = torch.zeros((1024, 128))
581
+ ta_kaldi_fbank = extract_kaldi_fbank_feature(waveform, sr, fbank)
582
+ mel_spect_tensor = ta_kaldi_fbank.unsqueeze(0)
583
+ model = AudioMAEConditionCTPoolRand().to(device).to(dtype=torch.float32)
584
+ model.eval()
585
+ mel_spect_tensor = mel_spect_tensor.to(device, dtype=next(model.parameters()).dtype)
586
+ LOA_embed = model(mel_spect_tensor, time_pool=time_pooling, freq_pool=freq_pooling)
587
+ uncond_LOA_embed = model(torch.zeros_like(mel_spect_tensor), time_pool=time_pooling, freq_pool=freq_pooling)
588
+ LOA_embeds = LOA_embed[0]
589
+ uncond_LOA_embeds = uncond_LOA_embed[0]
590
+ bs_embed, seq_len, _ = LOA_embeds.shape
591
+ num = prompt_embeds.shape[0] // 2
592
+ LOA_embeds = LOA_embeds.view(bs_embed , seq_len, -1)
593
+ LOA_embeds = LOA_embeds.repeat(num, 1, 1)
594
+ uncond_LOA_embeds = uncond_LOA_embeds.view(bs_embed , seq_len, -1)
595
+ uncond_LOA_embeds = uncond_LOA_embeds.repeat(num, 1, 1)
596
+ negative_g, g = generated_prompt_embeds.chunk(2)
597
+ uncond = torch.cat([negative_g, uncond_LOA_embeds], dim=1)
598
+ cond = torch.cat([g, LOA_embeds], dim=1)
599
+ generated_prompt_embeds = torch.cat([uncond, cond], dim=0)
600
+ model_dtype = next(unet.parameters()).dtype
601
+ generated_prompt_embeds = generated_prompt_embeds.to(model_dtype)
602
+
603
+ # num_channels_latents = unet.config.in_channels
604
+ # batch_size = 1
605
+ # num_waveforms_per_prompt = 1
606
+ # generator = None
607
+ # latents = None
608
+ # latents = prepare_latents(
609
+ # vae,
610
+ # vocoder,
611
+ # noise_scheduler,
612
+ # batch_size * num_waveforms_per_prompt,
613
+ # num_channels_latents,
614
+ # height,
615
+ # prompt_embeds.dtype,
616
+ # device,
617
+ # generator,
618
+ # latents,
619
+ # )
620
+
621
+ loss_history = []
622
+ if not os.path.exists(save_lora_dir):
623
+ os.makedirs(save_lora_dir)
624
+ weight_path = os.path.join(save_lora_dir, weight_name)
625
+ base_name, _ = os.path.splitext(weight_path)
626
+ save_image_path = f"{base_name}.png"
627
+ print(f'Save image path: {save_image_path}')
628
+ mel_spect_tensor = wav_to_mel(audio_path, duration = 10).unsqueeze(0).to(next(vae.parameters()).dtype)
629
+
630
+ for step in progress.tqdm(range(lora_steps), desc="Training LoRA..."):
631
+ unet.train()
632
+ # with accelerator.accumulate(unet):
633
+ latents_dist = vae.encode(mel_spect_tensor.to(device)).latent_dist
634
+ model_input = torch.cat([latents_dist.sample()] * 2) if do_classifier_free_guidance else latents_dist.sample()
635
+ model_input = model_input * vae.config.scaling_factor
636
+ # Sample noise that we'll add to the latents
637
+ noise = torch.randn_like(model_input).to(model_input.device)
638
+ bsz, channels, height, width = model_input.shape
639
+ # Sample a random timestep for each image
640
+ timesteps = torch.randint(
641
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
642
+ )
643
+ timesteps = timesteps.long()
644
+ # Add noise to the model input according to the noise magnitude at each timestep (this is the forward diffusion process)
645
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
646
+ generated_prompt_embeds = generated_prompt_embeds.to(device)
647
+ prompt_embeds = prompt_embeds.to(device)
648
+ attention_mask = attention_mask.to(device)
649
+ # Predict the noise residual
650
+ model_pred = unet(sample=noisy_model_input,
651
+ timestep=timesteps,
652
+ encoder_hidden_states=generated_prompt_embeds,
653
+ encoder_hidden_states_1=prompt_embeds,
654
+ encoder_attention_mask_1=attention_mask,
655
+ return_dict=False,
656
+ )[0]
657
+
658
+ # Get the target for loss depending on the prediction type
659
+ if noise_scheduler.config.prediction_type == "epsilon":
660
+ target = noise
661
+ elif noise_scheduler.config.prediction_type == "v_prediction":
662
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
663
+ else:
664
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
665
+ loss = F.mse_loss(model_pred, target, reduction="mean")
666
+ loss_history.append(loss.item())
667
+ loss.requires_grad = True  # NOTE: if `loss` does not already require grad, forcing it here will not propagate gradients to the LoRA layers; verify gradients actually reach unet_lora_layers
668
+ loss.backward()
669
+ optimizer.step()
670
+ lr_scheduler.step()
671
+ optimizer.zero_grad()
672
+ # with open(loss_log_path, "w") as f:
673
+ # json.dump(loss_history, f)
674
+
675
+ plot_loss(loss_history, save_image_path, step+1)
676
+
677
+
678
+ LoraLoaderMixin.save_lora_weights(
679
+ save_directory=save_lora_dir,
680
+ unet_lora_layers=unet_lora_layers,
681
+ text_encoder_lora_layers=None,
682
+ weight_name=weight_name,
683
+ safe_serialization=safe_serialization
684
+ )
685
+
686
+ def load_lora(unet, lora_0, lora_1, alpha):
687
+ attn_procs = unet.attn_processors
688
+ for name, processor in attn_procs.items():
689
+ if hasattr(processor, 'to_v_ip') or hasattr(processor, 'to_k_ip'):
690
+ weight_name_v = name + ".to_v_ip.weight"
691
+ weight_name_k = name + ".to_k_ip.weight"
692
+ if weight_name_v in lora_0 and weight_name_v in lora_1:
693
+ v_weight = (1 - alpha) * lora_0[weight_name_v] + alpha * lora_1[weight_name_v]
694
+ processor.to_v_ip.weight = torch.nn.Parameter(v_weight.half())
695
+
696
+ if weight_name_k in lora_0 and weight_name_k in lora_1:
697
+ k_weight = (1 - alpha) * lora_0[weight_name_k] + alpha * lora_1[weight_name_k]
698
+ processor.to_k_ip.weight = torch.nn.Parameter(k_weight.half())
699
+ unet.set_attn_processor(attn_procs)
700
+ return unet
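train_lora freezes every backbone component and optimizes only the IPAttnProcessor2_0 projections registered on the UNet's cross-attention layers, then saves them with LoraLoaderMixin.save_lora_weights; load_lora linearly blends the to_k_ip / to_v_ip weights of two such checkpoints into the UNet. A hedged sketch of the blending step is below; the file names, the alpha sweep, the use of safetensors, and the assumption that the saved key layout matches what load_lora looks up are all illustrative. It also assumes `unet` already carries the IP-adapter style attention processors set up as above.

import torch
from safetensors.torch import load_file

from utils.lora_utils_successed_ver1 import load_lora

# Assumed checkpoint files produced by two earlier train_lora runs (one per source clip).
lora_a = load_file("lora_clip_a.safetensors")
lora_b = load_file("lora_clip_b.safetensors")

for alpha in torch.linspace(0.0, 1.0, steps=5):
    # alpha = 0 keeps clip A's adapter, alpha = 1 keeps clip B's; values in between mix them linearly.
    unet = load_lora(unet, lora_a, lora_b, float(alpha))
    # ... run the frozen pipeline here to render the intermediate clip for this alpha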
utils/model_utils.py ADDED
@@ -0,0 +1,86 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torchvision import transforms
4
+
5
+ def calc_mean_std(feat, eps=1e-5):
6
+ # eps is a small value added to the variance to avoid divide-by-zero.
7
+ size = feat.size()
8
+
9
+ N, C = size[:2]
10
+ feat_var = feat.view(N, C, -1).var(dim=2) + eps
11
+ if len(size) == 3:
12
+ feat_std = feat_var.sqrt().view(N, C, 1)
13
+ feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1)
14
+ else:
15
+ feat_std = feat_var.sqrt().view(N, C, 1, 1)
16
+ feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
17
+ return feat_mean, feat_std
18
+
19
+
20
+ def get_img(img, resolution=512):
21
+ norm_mean = [0.5, 0.5, 0.5]
22
+ norm_std = [0.5, 0.5, 0.5]
23
+ transform = transforms.Compose([
24
+ transforms.Resize((resolution, resolution)),
25
+ transforms.ToTensor(),
26
+ transforms.Normalize(norm_mean, norm_std)
27
+ ])
28
+ img = transform(img)
29
+ return img.unsqueeze(0)
30
+
31
+ @torch.no_grad()
32
+ def slerp(p0, p1, fract_mixing: float, adain=True):
33
+ r""" Copied from lunarring/latentblending
34
+ Helper function to correctly mix two random variables using spherical interpolation.
35
+ The function will always cast up to float64 for the sake of extra precision.
36
+ Args:
37
+ p0:
38
+ First tensor for interpolation
39
+ p1:
40
+ Second tensor for interpolation
41
+ fract_mixing: float
42
+ Mixing coefficient of interval [0, 1].
43
+ 0 will return p0
44
+ 1 will return p1
45
+ 0.x will return a mix of the two, preserving angular velocity.
46
+ """
47
+ if p0.dtype == torch.float16:
48
+ recast_to = 'fp16'
49
+ else:
50
+ recast_to = 'fp32'
51
+
52
+ p0 = p0.double()
53
+ p1 = p1.double()
54
+
55
+ if adain:
56
+ mean1, std1 = calc_mean_std(p0)
57
+ mean2, std2 = calc_mean_std(p1)
58
+ mean = mean1 * (1 - fract_mixing) + mean2 * fract_mixing
59
+ std = std1 * (1 - fract_mixing) + std2 * fract_mixing
60
+
61
+ norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
62
+ epsilon = 1e-7
63
+ dot = torch.sum(p0 * p1) / norm
64
+ dot = dot.clamp(-1+epsilon, 1-epsilon)
65
+
66
+ theta_0 = torch.arccos(dot)
67
+ sin_theta_0 = torch.sin(theta_0)
68
+ theta_t = theta_0 * fract_mixing
69
+ s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
70
+ s1 = torch.sin(theta_t) / sin_theta_0
71
+ interp = p0*s0 + p1*s1
72
+
73
+ if adain:
74
+ interp = F.instance_norm(interp) * std + mean
75
+
76
+ if recast_to == 'fp16':
77
+ interp = interp.half()
78
+ elif recast_to == 'fp32':
79
+ interp = interp.float()
80
+
81
+ return interp
82
+
83
+
84
+ def do_replace_attn(key: str):
85
+ # return key.startswith('up_blocks.2') or key.startswith('up_blocks.3')
86
+ return key.startswith('up')
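slerp interpolates two tensors along the great circle between them rather than along a straight line, which keeps the interpolant's overall scale close to that of its endpoints (important when mixing Gaussian latents); with adain=True the per-channel mean and std are additionally blended linearly and re-applied through instance normalization. A minimal sketch, with the latent shape chosen only for illustration:

import torch

from utils.model_utils import slerp

# Any two tensors of the same 4-D (N, C, H, W) shape work; this particular shape is an assumption.
latent_a = torch.randn(1, 8, 256, 16)
latent_b = torch.randn(1, 8, 256, 16)

for t in (0.25, 0.5, 0.75):
    mixed = slerp(latent_a, latent_b, t, adain=True)
    print(f"alpha={t}: shape={tuple(mixed.shape)}, std={mixed.std().item():.3f}")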