schnik committed
Commit 1744a8e
1 parent: a36c1ef

Upload 9 files

code/code_adaptations_audiocraft/audiocraft_lm_generate.py ADDED
@@ -0,0 +1,534 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import dataclass
8
+ from functools import partial
9
+ import logging
10
+ import math
11
+ import typing as tp
12
+
13
+ import torch
14
+ from torch import nn
15
+
16
+ from audiocraft.utils import utils
17
+ from audiocraft.modules.streaming import StreamingModule, State
18
+ from audiocraft.modules.transformer import StreamingTransformer, create_norm_fn
19
+ from audiocraft.modules.conditioners import (
20
+ ConditionFuser,
21
+ ClassifierFreeGuidanceDropout,
22
+ AttributeDropout,
23
+ ConditioningProvider,
24
+ ConditioningAttributes,
25
+ ConditionType,
26
+ )
27
+ from audiocraft.modules.codebooks_patterns import CodebooksPatternProvider
28
+ from audiocraft.modules.activations import get_activation_fn
29
+
30
+
31
+ logger = logging.getLogger(__name__)
32
+ ConditionTensors = tp.Dict[str, ConditionType]
33
+ CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]
34
+
35
+
36
+ def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None):
37
+ """LM layer initialization.
38
+ Inspired from xlformers: https://github.com/fairinternal/xlformers
39
+
40
+ Args:
41
+ method (str): Method name for init function. Valid options are:
42
+ 'gaussian', 'uniform'.
43
+ input_dim (int): Input dimension of the initialized module.
44
+ init_depth (int, optional): Optional init depth value used to rescale
45
+ the standard deviation if defined.
46
+ """
47
+ # Compute std
48
+ std = 1 / math.sqrt(input_dim)
49
+ # Rescale with depth
50
+ if init_depth is not None:
51
+ std = std / math.sqrt(2 * init_depth)
52
+
53
+ if method == 'gaussian':
54
+ return partial(
55
+ torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std
56
+ )
57
+ elif method == 'uniform':
58
+ bound = math.sqrt(3) * std # ensure the standard deviation is `std`
59
+ return partial(torch.nn.init.uniform_, a=-bound, b=bound)
60
+ else:
61
+ raise ValueError("Unsupported layer initialization method")
62
+
63
+
64
+ def init_layer(m: nn.Module,
65
+ method: str,
66
+ init_depth: tp.Optional[int] = None,
67
+ zero_bias_init: bool = False):
68
+ """Wrapper around ``get_init_fn`` for proper initialization of LM modules.
69
+
70
+ Args:
71
+ m (nn.Module): Module to initialize.
72
+ method (str): Method name for the init function.
73
+ init_depth (int, optional): Optional init depth value used to rescale
74
+ the standard deviation if defined.
75
+ zero_bias_init (bool): Whether to initialize the bias to 0 or not.
76
+ """
77
+ if isinstance(m, nn.Linear):
78
+ init_fn = get_init_fn(method, m.in_features, init_depth=init_depth)
79
+ if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
80
+ weight = m.weight.float()
81
+ init_fn(weight)
82
+ m.weight.data[:] = weight.half()
83
+ else:
84
+ init_fn(m.weight)
85
+ if zero_bias_init and m.bias is not None:
86
+ nn.init.constant_(m.bias, 0)
87
+ elif isinstance(m, nn.Embedding):
88
+ init_fn = get_init_fn(method, m.embedding_dim, init_depth=None)
89
+ if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
90
+ weight = m.weight.float()
91
+ init_fn(weight)
92
+ m.weight.data[:] = weight.half()
93
+ else:
94
+ init_fn(m.weight)
95
+
96
+
97
+ class ScaledEmbedding(nn.Embedding):
98
+ """Boost learning rate for embeddings (with `scale`).
99
+ """
100
+ def __init__(self, *args, lr=None, **kwargs):
101
+ super().__init__(*args, **kwargs)
102
+ self.lr = lr
103
+
104
+ def make_optim_group(self):
105
+ group = {"params": list(self.parameters())}
106
+ if self.lr is not None:
107
+ group["lr"] = self.lr
108
+ return group
109
+
110
+
111
+ @dataclass
112
+ class LMOutput:
113
+ # The logits are already re-aligned with the input codes
114
+ # hence no extra shift is required, e.g. when computing CE
115
+ logits: torch.Tensor # [B, K, T, card]
116
+ mask: torch.Tensor # [B, K, T]
117
+
118
+
119
+ class LMModel(StreamingModule):
120
+ """Transformer-based language model on multiple streams of codes.
121
+
122
+ Args:
123
+ pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving.
124
+ condition_provider (MusicConditioningProvider): Conditioning provider from metadata.
125
+ fuser (ConditionFuser): Fuser handling the fusing of conditions with language model input.
126
+ n_q (int): Number of parallel streams to model.
127
+ card (int): Cardinality, vocabulary size.
128
+ dim (int): Dimension of the transformer encoder.
129
+ num_heads (int): Number of heads for the transformer encoder.
130
+ hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
131
+ norm (str): Normalization method.
132
+ norm_first (bool): Use pre-norm instead of post-norm.
133
+ emb_lr (float, optional): Embedding-specific learning rate.
134
+ bias_proj (bool): Use bias for output projections.
135
+ weight_init (str, optional): Method for weight initialization.
136
+ depthwise_init (str, optional): Method for depthwise weight initialization.
137
+ zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros.
138
+ cfg_dropout (float): Classifier-free guidance dropout.
139
+ cfg_coef (float): Classifier-free guidance coefficient.
140
+ attribute_dropout (dict): Attribute dropout probabilities.
141
+ two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps.
142
+ **kwargs: Additional parameters for the transformer encoder.
143
+ """
144
+ def __init__(self, pattern_provider: CodebooksPatternProvider, condition_provider: ConditioningProvider,
145
+ fuser: ConditionFuser, n_q: int = 8, card: int = 1024, dim: int = 128, num_heads: int = 8,
146
+ hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False,
147
+ emb_lr: tp.Optional[float] = None, bias_proj: bool = True,
148
+ weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None,
149
+ zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0,
150
+ attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {}, two_step_cfg: bool = False,
151
+ **kwargs):
152
+ super().__init__()
153
+ self.cfg_coef = cfg_coef
154
+ self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout)
155
+ self.att_dropout = AttributeDropout(p=attribute_dropout)
156
+ self.condition_provider = condition_provider
157
+ self.fuser = fuser
158
+ self.card = card
159
+ embed_dim = self.card + 1
160
+ self.n_q = n_q
161
+ self.dim = dim
162
+ self.pattern_provider = pattern_provider
163
+ self.two_step_cfg = two_step_cfg
164
+ self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)])
165
+ if 'activation' in kwargs:
166
+ kwargs['activation'] = get_activation_fn(kwargs['activation'])
167
+ self.transformer = StreamingTransformer(
168
+ d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim),
169
+ norm=norm, norm_first=norm_first, **kwargs)
170
+ self.out_norm: tp.Optional[nn.Module] = None
171
+ if norm_first:
172
+ self.out_norm = create_norm_fn(norm, dim)
173
+ self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=bias_proj) for _ in range(n_q)])
174
+ self._init_weights(weight_init, depthwise_init, zero_bias_init)
175
+ self._fsdp: tp.Optional[nn.Module]
176
+ self.__dict__['_fsdp'] = None
177
+
178
+ def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool):
179
+ """Initialization of the transformer module weights.
180
+
181
+ Args:
182
+ weight_init (str, optional): Weight initialization strategy. See ``get_init_fn`` for valid options.
183
+ depthwise_init (str, optional): Depthwise initialization strategy. The following options are valid:
184
+ 'current' where the depth corresponds to the current layer index or 'global' where the total number
185
+ of layer is used as depth. If not set, no depthwise initialization strategy is used.
186
+ zero_bias_init (bool): Whether to initialize bias to zero or not.
187
+ """
188
+ assert depthwise_init is None or depthwise_init in ['current', 'global']
189
+ assert depthwise_init is None or weight_init is not None, \
190
+ "If 'depthwise_init' is defined, a 'weight_init' method should be provided."
191
+ assert not zero_bias_init or weight_init is not None, \
192
+ "If 'zero_bias_init', a 'weight_init' method should be provided"
193
+
194
+ if weight_init is None:
195
+ return
196
+
197
+ for emb_layer in self.emb:
198
+ init_layer(emb_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
199
+
200
+ for layer_idx, tr_layer in enumerate(self.transformer.layers):
201
+ depth = None
202
+ if depthwise_init == 'current':
203
+ depth = layer_idx + 1
204
+ elif depthwise_init == 'global':
205
+ depth = len(self.transformer.layers)
206
+ init_fn = partial(init_layer, method=weight_init, init_depth=depth, zero_bias_init=zero_bias_init)
207
+ tr_layer.apply(init_fn)
208
+
209
+ for linear in self.linears:
210
+ init_layer(linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
211
+
212
+ @property
213
+ def special_token_id(self) -> int:
214
+ return self.card
215
+
216
+ @property
217
+ def num_codebooks(self) -> int:
218
+ return self.n_q
219
+
220
+ def forward(self, sequence: torch.Tensor,
221
+ conditions: tp.List[ConditioningAttributes],
222
+ condition_tensors: tp.Optional[ConditionTensors] = None) -> torch.Tensor:
223
+ """Apply language model on sequence and conditions.
224
+ Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and
225
+ S the sequence steps, return the logits with shape [B, card, K, S].
226
+
227
+ Args:
228
+ sequence (torch.Tensor): Codes to model, of shape [B, K, S].
229
+ conditions (list of ConditioningAttributes): Conditions to use when modeling
230
+ the given codes. Note that when evaluating multiple times with the same conditioning
231
+ you should pre-compute those and pass them as `condition_tensors`.
232
+ condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning
233
+ tensors, see `conditions`.
234
+ Returns:
235
+ torch.Tensor: Logits.
236
+ """
237
+ B, K, S = sequence.shape
238
+ assert K == self.num_codebooks, "Sequence shape must match the specified number of codebooks"
239
+ input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
240
+ if condition_tensors is None:
241
+ assert not self._is_streaming, "Conditions tensors should be precomputed when streaming."
242
+ # apply dropout modules
243
+ conditions = self.cfg_dropout(conditions)
244
+ conditions = self.att_dropout(conditions)
245
+ tokenized = self.condition_provider.tokenize(conditions)
246
+ # encode conditions and fuse, both have a streaming cache to not recompute when generating.
247
+ condition_tensors = self.condition_provider(tokenized)
248
+ else:
249
+ assert not conditions, "Shouldn't pass both conditions and condition_tensors."
250
+
251
+ input_, cross_attention_input = self.fuser(input_, condition_tensors)
252
+
253
+ out = self.transformer(input_, cross_attention_src=cross_attention_input)
254
+ if self.out_norm:
255
+ out = self.out_norm(out)
256
+ logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1) # [B, K, S, card]
257
+
258
+ # remove the prefix from the model outputs
259
+ if len(self.fuser.fuse2cond['prepend']) > 0:
260
+ logits = logits[:, :, -S:]
261
+
262
+ return logits # [B, K, S, card]
263
+
264
+ def compute_predictions(
265
+ self, codes: torch.Tensor,
266
+ conditions: tp.List[ConditioningAttributes],
267
+ condition_tensors: tp.Optional[ConditionTensors] = None) -> LMOutput:
268
+ """Given an input tensor of codes [B, K, T] and list of conditions, runs the model
269
+ forward using the specified codes interleaving pattern.
270
+
271
+ Args:
272
+ codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size,
273
+ K the number of codebooks and T the number of timesteps.
274
+ conditions (list of ConditioningAttributes): conditionings to use when modeling
275
+ the given codes. Note that when evaluating multiple time with the same conditioning
276
+ you should pre-compute those and pass them as `condition_tensors`.
277
+ condition_tensors (dict[str, ConditionType], optional): pre-computed conditioning
278
+ tensors, see `conditions`.
279
+ Returns:
280
+ LMOutput: Language model outputs
281
+ logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes,
282
+ i.e. the first item corresponds to logits to predict the first code, meaning that
283
+ no additional shifting of codes and logits is required.
284
+ mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions.
285
+ Given the specified interleaving strategies, parts of the logits and codes should
286
+ not be considered as valid predictions because of invalid context.
287
+ """
288
+ B, K, T = codes.shape
289
+ codes = codes.contiguous()
290
+ # map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
291
+ pattern = self.pattern_provider.get_pattern(T)
292
+ sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
293
+ codes, self.special_token_id, keep_only_valid_steps=True
294
+ )
295
+ # apply model on pattern sequence
296
+ model = self if self._fsdp is None else self._fsdp
297
+ logits = model(sequence_codes, conditions, condition_tensors) # [B, K, S, card]
298
+ # map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card]
299
+ # and provide the corresponding mask over invalid positions of tokens
300
+ logits = logits.permute(0, 3, 1, 2) # [B, card, K, S]
301
+ # note: we use nans as special token to make it obvious if we feed unexpected logits
302
+ logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
303
+ logits, float('nan'), keep_only_valid_steps=True
304
+ )
305
+ logits = logits.permute(0, 2, 3, 1) # [B, K, T, card]
306
+ logits_mask = logits_mask[None, :, :].expand(B, -1, -1) # [K, T] -> [B, K, T]
307
+ return LMOutput(logits, logits_mask)
308
+
309
+ def _sample_next_token(self,
310
+ sequence: torch.Tensor,
311
+ cfg_conditions: CFGConditions,
312
+ unconditional_state: State,
313
+ use_sampling: bool = False,
314
+ temp: float = 1.0,
315
+ top_k: int = 0,
316
+ top_p: float = 0.0,
317
+ cfg_coef: tp.Optional[float] = None) -> torch.Tensor:
318
+ """Sample next token from the model given a sequence and a set of conditions. The model supports
319
+ multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).
320
+
321
+ Args:
322
+ sequence (torch.Tensor): Current sequence of shape [B, K, S]
323
+ with K corresponding to the number of codebooks and S the number of sequence steps.
324
+ S = 1 in streaming mode, except for the first step that contains a bigger prompt.
325
+ condition_tensors (dict[str, ConditionType): Set of conditions. If CFG is used,
326
+ should be twice the batch size, being the concatenation of the conditions + null conditions.
327
+ use_sampling (bool): Whether to use a sampling strategy or not.
328
+ temp (float): Sampling temperature.
329
+ top_k (int): K for "top-k" sampling.
330
+ top_p (float): P for "top-p" sampling.
331
+ cfg_coef (float, optional): classifier free guidance coefficient
332
+ Returns:
333
+ next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
334
+ """
335
+ B = sequence.shape[0]
336
+ cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
337
+ model = self if self._fsdp is None else self._fsdp
338
+ if self.two_step_cfg and cfg_conditions != {}:
339
+ assert isinstance(cfg_conditions, tuple), type(cfg_conditions)
340
+ condition_tensors, null_condition_tensors = cfg_conditions
341
+ cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors)
342
+ state = self.get_streaming_state()
343
+ self.set_streaming_state(unconditional_state)
344
+ uncond_logits = model(sequence, conditions=[], condition_tensors=null_condition_tensors)
345
+ unconditional_state.update(self.get_streaming_state())
346
+ self.set_streaming_state(state)
347
+ logits = uncond_logits + (cond_logits - uncond_logits) * self.cfg_coef
348
+ else:
349
+ assert isinstance(cfg_conditions, dict)
350
+ condition_tensors = cfg_conditions
351
+ if condition_tensors:
352
+ # Preparing for CFG, predicting both conditional and unconditional logits.
353
+ sequence = torch.cat([sequence, sequence], dim=0)
354
+ all_logits = model(
355
+ sequence,
356
+ conditions=[], condition_tensors=condition_tensors)
357
+ if condition_tensors:
358
+ cond_logits, uncond_logits = all_logits.split(B, dim=0) # [B, K, T, card]
359
+ logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
360
+ else:
361
+ logits = all_logits
362
+
363
+ logits = logits.permute(0, 1, 3, 2) # [B, K, card, T]
364
+ logits = logits[..., -1] # [B x K x card]
365
+
366
+ # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
367
+ if use_sampling and temp > 0.0:
368
+ probs = torch.softmax(logits / temp, dim=-1)
369
+ if top_p > 0.0:
370
+ next_token = utils.sample_top_p(probs, p=top_p)
371
+ elif top_k > 0:
372
+ next_token = utils.sample_top_k(probs, k=top_k)
373
+ else:
374
+ next_token = utils.multinomial(probs, num_samples=1)
375
+ else:
376
+ next_token = torch.argmax(logits, dim=-1, keepdim=True)
377
+
378
+ return next_token
379
+
380
+ @torch.no_grad()
381
+ def generate(self,
382
+ prompt: tp.Optional[torch.Tensor] = None,
383
+ conditions: tp.List[ConditioningAttributes] = [],
384
+ condition_tensors: tp.Optional[ConditionTensors] = None,
385
+ num_samples: tp.Optional[int] = None,
386
+ max_gen_len: int = 256,
387
+ use_sampling: bool = True,
388
+ temp: float = 1.0,
389
+ top_k: int = 250,
390
+ top_p: float = 0.0,
391
+ cfg_coef: tp.Optional[float] = None,
392
+ two_step_cfg: tp.Optional[bool] = None,
393
+ remove_prompts: bool = False,
394
+ check: bool = False,
395
+ callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> torch.Tensor:
396
+ """Generate tokens sampling from the model given a prompt or unconditionally. Generation can
397
+ be performed in a greedy fashion or using sampling with top-k and top-p strategies.
398
+
399
+ Args:
400
+ prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T].
401
+ conditions (list of ConditioningAttributes): List of conditions.
+ condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning tensors; when provided, they take precedence over `conditions`.
402
+ num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given.
403
+ max_gen_len (int): Maximum generation length.
404
+ use_sampling (bool): Whether to use a sampling strategy or not.
405
+ temp (float): Sampling temperature.
406
+ top_k (int): K for "top-k" sampling.
407
+ top_p (float): P for "top-p" sampling.
408
+ cfg_coef (float, optional): Classifier-free guidance coefficient.
409
+ two_step_cfg (bool, optional): Whether to perform classifier-free guidance with two steps generation.
410
+ remove_prompts (bool): Whether to remove prompts from generation or not.
411
+ check (bool): Whether to apply further checks on generated sequence.
412
+ callback (Callback, optional): Callback function to report generation progress.
413
+ Returns:
414
+ torch.Tensor: Generated tokens.
415
+ """
416
+ assert not self.training, "generation shouldn't be used in training mode."
417
+ first_param = next(iter(self.parameters()))
418
+ device = first_param.device
419
+
420
+ # Checking all input shapes are consistent.
421
+ possible_num_samples = []
422
+ if num_samples is not None:
423
+ possible_num_samples.append(num_samples)
424
+ elif prompt is not None:
425
+ possible_num_samples.append(prompt.shape[0])
426
+ elif conditions:
427
+ possible_num_samples.append(len(conditions))
428
+ else:
429
+ possible_num_samples.append(1)
430
+ assert all(x == possible_num_samples[0] for x in possible_num_samples), "Inconsistent input shapes"
431
+ num_samples = possible_num_samples[0]
432
+
433
+ # below we create set of conditions: one conditional and one unconditional
434
+ # to do that we merge the regular condition together with the null condition
435
+ # we then do 1 forward pass instead of 2.
436
+ # the reason for that is two-fold:
437
+ # 1. it is about x2 faster than doing 2 forward passes
438
+ # 2. avoid the streaming API treating the 2 passes as part of different time steps
439
+ # We also support doing two different passes, in particular to ensure that
440
+ # the padding structure is exactly the same between train and test.
441
+ # With a batch size of 1, this can be slower though.
442
+ cfg_conditions: CFGConditions
443
+ two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
444
+ if condition_tensors is not None:
445
+ cfg_conditions = condition_tensors
446
+ elif conditions:
447
+ null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
448
+ if two_step_cfg:
449
+ cfg_conditions = (
450
+ self.condition_provider(self.condition_provider.tokenize(conditions)),
451
+ self.condition_provider(self.condition_provider.tokenize(null_conditions)),
452
+ )
453
+ else:
454
+ conditions = conditions + null_conditions
455
+ tokenized = self.condition_provider.tokenize(conditions)
456
+ cfg_conditions = self.condition_provider(tokenized)
457
+ else:
458
+ cfg_conditions = {}
459
+
460
+ if prompt is None:
461
+ assert num_samples > 0
462
+ prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)
463
+
464
+ B, K, T = prompt.shape
465
+ start_offset = T
466
+ assert start_offset < max_gen_len
467
+
468
+ pattern = self.pattern_provider.get_pattern(max_gen_len)
469
+ # this token is used as default value for codes that are not generated yet
470
+ unknown_token = -1
471
+
472
+ # we generate codes up to the max_gen_len that will be mapped to the pattern sequence
473
+ gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device)
474
+ # filling the gen_codes with the prompt if needed
475
+ gen_codes[..., :start_offset] = prompt
476
+ # create the gen_sequence with proper interleaving from the pattern: [B, K, S]
477
+ gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
478
+ # retrieve the start_offset in the sequence:
479
+ # it is the first sequence step that contains the `start_offset` timestep
480
+ start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
481
+ assert start_offset_sequence is not None
482
+
483
+ with self.streaming():
484
+ unconditional_state = self.get_streaming_state()
485
+ prev_offset = 0
486
+ gen_sequence_len = gen_sequence.shape[-1] # gen_sequence shape is [B, K, S]
487
+ for offset in range(start_offset_sequence, gen_sequence_len):
488
+ # get current sequence (note that the streaming API is providing the caching over previous offsets)
489
+ curr_sequence = gen_sequence[..., prev_offset:offset]
490
+ curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1)
491
+ if check:
492
+ # check coherence between mask and sequence
493
+ assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all()
494
+ # should never happen as gen_sequence is filled progressively
495
+ assert not (curr_sequence == unknown_token).any()
496
+ # sample next token from the model, next token shape is [B, K, 1]
497
+ next_token = self._sample_next_token(
498
+ curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
499
+ cfg_coef=cfg_coef)
500
+ # ensure the tokens that should be masked are properly set to special_token_id
501
+ # as the model never output special_token_id
502
+ valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
503
+ next_token[~valid_mask] = self.special_token_id
504
+ # ensure we don't overwrite prompt tokens, we only write over unknown tokens
505
+ # (then mask tokens should be left as is as well, which is correct)
506
+ gen_sequence[..., offset:offset+1] = torch.where(
507
+ gen_sequence[..., offset:offset+1] == unknown_token,
508
+ next_token, gen_sequence[..., offset:offset+1]
509
+ )
510
+ prev_offset = offset
511
+ if callback is not None:
512
+ callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
513
+ unconditional_state.clear()
514
+
515
+ # ensure sequence has been entirely filled
516
+ assert not (gen_sequence == unknown_token).any()
517
+ # ensure gen_sequence pattern and mask are matching
518
+ # which means the gen_sequence is valid according to the pattern
519
+ assert (
520
+ gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, self.special_token_id)
521
+ ).all()
522
+ # get back the codes, trimming the prompt if needed and cutting potentially incomplete timesteps
523
+ out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
524
+
525
+ # sanity checks over the returned codes and corresponding masks
526
+ assert (out_codes[..., :max_gen_len] != unknown_token).all()
527
+ assert (out_mask[..., :max_gen_len] == 1).all()
528
+
529
+ out_start_offset = start_offset if remove_prompts else 0
530
+ out_codes = out_codes[..., out_start_offset:max_gen_len]
531
+
532
+ # ensure the returned codes are all valid
533
+ assert (out_codes >= 0).all() and (out_codes <= self.card).all()
534
+ return out_codes
code/code_adaptations_audiocraft/important_note.md ADDED
@@ -0,0 +1,6 @@
+ ### Important information about `audiocraft_lm_generate.py`
+ 
+ Because the audiocraft model does not allow passing precalculated text embeddings directly into the `generate()`
+ method, the code within the audiocraft library was altered. `audiocraft_lm_generate.py` contains an updated version of
+ `audiocraft/audiocraft/models/lm.py`. This is also the reason why audiocraft has to be installed from the local folder
+ `/audiocraft`. This file is only a backup that tracks the changes applied to `lm.py`.
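
For reference, a minimal sketch (not part of the repository) of how the patched `generate()` can be driven with precomputed condition tensors, mirroring the call in `code/inference/inference.py`; the random tensors and the 30x1024 shape (30 tokens at 1 fps, small-model embedding size) are illustrative assumptions:

```python
import torch
from audiocraft.models import MusicGen

model = MusicGen.get_pretrained("facebook/musicgen-small")

# Precomputed embeddings in place of the usual T5 text embeddings, plus an attention mask,
# stored under the 'description' conditioning key (illustrative random values).
video_embeddings = torch.randn(1, 30, 1024)        # [batch, tokens, embedding_dim]
mask = torch.ones(1, 30, dtype=torch.int)
condition_tensors = {"description": (video_embeddings, mask)}

with torch.no_grad(), model.autocast:
    # The patched LMModel.generate accepts condition_tensors directly, so the internal
    # tokenization / text-encoding path is skipped entirely.
    gen_tokens = model.lm.generate(None, [], condition_tensors, max_gen_len=256,
                                   use_sampling=True, top_k=250, temp=1.0, cfg_coef=3.0)
    audio = model.compression_model.decode(gen_tokens)  # [batch, channels, samples]
```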
code/inference/gradio_app/app.py ADDED
@@ -0,0 +1,78 @@
1
+ import gradio as gr
2
+ import os
3
+ import sys
4
+ sys.path.insert(1, '..')
5
+ import inference
6
+ import torch
7
+
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+
10
+ def generate_background_music(video_path, dataset, use_peft, musicgen_size):
11
+ print(f"Start generating background music for {video_path} with model \"{'peft' if use_peft else 'audiocraft'}_{dataset}_{musicgen_size}\"")
12
+
13
+ new_video_path = inference.generate_background_music(
14
+ video_path=video_path,
15
+ dataset=dataset,
16
+ musicgen_size=musicgen_size,
17
+ use_stereo=True,
18
+ use_peft=use_peft,
19
+ musicgen_temperature=1.0,
20
+ musicgen_guidance_scale=3.0,
21
+ top_k_sampling=250,
22
+ device=device
23
+ )
24
+ return gr.Video(new_video_path)
25
+
26
+
27
+ interface = gr.Interface(fn=generate_background_music,
28
+ inputs=[
29
+ gr.Video(
30
+ label="video input",
31
+ min_length=5,
32
+ max_length=20,
33
+ sources=['upload'],
34
+ show_download_button=True,
35
+ include_audio=True
36
+ ),
37
+ gr.Radio(["nature", "symmv"],
38
+ label="Video Encoder Version",
39
+ value="nature",
40
+ info="Choose one of the available Video Encoders."),
41
+ gr.Radio([False, True],
42
+ label="Use MusicGen Audio Decoder Model trained with PEFT",
43
+ value=False,
44
+ info="If set to 'True' the MusicGen Audio Decoder models trained with LoRA "
45
+ "(Low Rank Adaptation) are used. If set to 'False', the original "
46
+ "MusicGen models are used."),
47
+ gr.Radio(["small", "medium", "large"],
48
+ label="MusicGen Audio Decoder Size",
49
+ value="small",
50
+ info="Choose the size of the MusicGen audio decoder."),
51
+ ],
52
+
53
+ outputs=[gr.Video(label="video output")],
54
+ examples=[
55
+ [os.path.abspath("../../../videos/originals/n_1.mp4"), "nature", True, "small"],
56
+ [os.path.abspath("../../../videos/originals/n_2.mp4"), "nature", True, "small"],
57
+ [os.path.abspath("../../../videos/originals/n_3.mp4"), "nature", True, "small"],
58
+ [os.path.abspath("../../../videos/originals/n_4.mp4"), "nature", True, "small"],
59
+ [os.path.abspath("../../../videos/originals/n_5.mp4"), "nature", True, "small"],
60
+ [os.path.abspath("../../../videos/originals/n_6.mp4"), "nature", True, "small"],
61
+ [os.path.abspath("../../../videos/originals/n_7.mp4"), "nature", True, "small"],
62
+ [os.path.abspath("../../../videos/originals/n_8.mp4"), "nature", True, "small"],
63
+ [os.path.abspath("../../../videos/originals/s_1.mp4"), "nature", True, "small"],
64
+ [os.path.abspath("../../../videos/originals/s_2.mp4"), "nature", True, "small"],
65
+ [os.path.abspath("../../../videos/originals/s_3.mp4"), "nature", True, "small"],
66
+ [os.path.abspath("../../../videos/originals/s_4.mp4"), "nature", True, "small"],
67
+ [os.path.abspath("../../../videos/originals/s_5.mp4"), "nature", True, "small"],
68
+ [os.path.abspath("../../../videos/originals/s_6.mp4"), "nature", True, "small"],
69
+ [os.path.abspath("../../../videos/originals/s_7.mp4"), "nature", True, "small"],
70
+ [os.path.abspath("../../../videos/originals/s_8.mp4"), "nature", True, "small"],
71
+ ],
72
+ cache_examples=False
73
+ )
74
+
75
+ if __name__ == "__main__":
76
+ interface.launch(
77
+ share=False
78
+ )
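
For completeness, the pipeline can also be driven without the UI. Below is a minimal sketch reusing the arguments the app passes through; it assumes the same working directory and trained-model layout that `app.py` and `inference.py` expect:

```python
import sys
import torch

sys.path.insert(1, '..')   # make code/inference importable, as app.py does
import inference

device = "cuda" if torch.cuda.is_available() else "cpu"
result_video = inference.generate_background_music(
    video_path="../../../videos/originals/n_1.mp4",  # one of the bundled example videos
    dataset="nature",
    musicgen_size="small",
    use_stereo=True,
    use_peft=True,
    musicgen_temperature=1.0,
    musicgen_guidance_scale=3.0,
    top_k_sampling=250,
    device=device,
)
print(result_video)  # path of the rendered video with the generated background music
```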
code/inference/inference.py ADDED
@@ -0,0 +1,117 @@
1
+ from omegaconf import OmegaConf
2
+ from peft import PeftConfig, get_peft_model
3
+ from audiocraft.models import MusicGen
4
+ from moviepy.editor import AudioFileClip
5
+ from code.inference.inference_utils import *
6
+ import re
7
+ import time
8
+
9
+ re_file_name = re.compile('([^/]+$)')
10
+
11
+
12
+ def generate_background_music(video_path: str,
13
+ dataset: str,
14
+ musicgen_size: str,
15
+ use_stereo: bool,
16
+ use_peft: bool,
17
+ device: str,
18
+ musicgen_temperature: float = 1.0,
19
+ musicgen_guidance_scale: float = 3.0,
20
+ top_k_sampling: int = 250) -> str:
21
+ start = time.time()
22
+ model_path = "../training/"
23
+ model_path += "models_peft" if use_peft else "models_audiocraft"
24
+ model_path += f"/{dataset}" + f"_{musicgen_size}"
25
+
26
+ conf = OmegaConf.load(model_path + '/configuration.yml')
27
+ use_sampling = top_k_sampling > 0
28
+ video = mpe.VideoFileClip(video_path)
29
+
30
+ musicgen_model_id = "facebook/musicgen-" + "stereo-" if use_stereo else ""
31
+ musicgen_model_id += musicgen_size
32
+
33
+ result_dir = "./results"
34
+ os.makedirs(result_dir, exist_ok=True)
35
+
36
+ encoder_output_dimension = None
37
+ if "small" in conf.musicgen_model_id:
38
+ encoder_output_dimension = 1024
39
+ elif "medium" in conf.musicgen_model_id:
40
+ encoder_output_dimension = 1536
41
+ elif "large" in conf.musicgen_model_id:
42
+ encoder_output_dimension = 2048
43
+ assert encoder_output_dimension, f"Video Encoder output dimension could not be determined by {conf.musicgen_model_id}"
44
+
45
+ musicgen_model = MusicGen.get_pretrained(musicgen_model_id)
46
+ musicgen_model.lm.to(device)
47
+ musicgen_model.compression_model.to(device)
48
+ if use_peft:
49
+ peft_path = model_path + "/musicgen_peft_final"
50
+ peft_config = PeftConfig.from_pretrained(peft_path)
51
+ musicgen_model.lm = get_peft_model(musicgen_model.lm, peft_config)
52
+ musicgen_model.lm.load_adapter(peft_path, "default")
53
+
54
+ print("MusicGen Model loaded.")
55
+
56
+ video_to_t5 = VideoToT5(
57
+ video_extraction_framerate=conf.video_extraction_framerate,
58
+ encoder_input_dimension=conf.encoder_input_dimension,
59
+ encoder_output_dimension=encoder_output_dimension,
60
+ encoder_heads=conf.encoder_heads,
61
+ encoder_dim_feedforward=conf.encoder_dim_feedforward,
62
+ encoder_layers=conf.encoder_layers,
63
+ device=device
64
+ )
65
+
66
+ video_to_t5.load_state_dict(torch.load(model_path + "/lm_final.pt", map_location=device))
67
+ print("Video Encoder Model loaded.")
68
+
69
+ print("Starting Video Feature Extraction.")
70
+ video_embedding_t5 = video_to_t5(video_paths=[video_path])
71
+
72
+ condition_tensors = create_condition_tensors(
73
+ video_embeddings=video_embedding_t5,
74
+ batch_size=1,
75
+ video_extraction_framerate=video_to_t5.video_extraction_framerate,
76
+ device=device
77
+ )
78
+
79
+ musicgen_model.generation_params = {
80
+ 'max_gen_len': int(video.duration * musicgen_model.frame_rate),
81
+ 'use_sampling': use_sampling,
82
+ 'temp': musicgen_temperature,
83
+ 'cfg_coef': musicgen_guidance_scale,
84
+ 'two_step_cfg': False,
85
+ }
86
+ if use_sampling:
87
+ musicgen_model.generation_params['top_k'] = top_k_sampling
88
+
89
+ print("Starting Audio Generation.")
90
+ prompt_tokens = None
91
+ with torch.no_grad():
92
+ with musicgen_model.autocast:
93
+ gen_tokens = musicgen_model.lm.generate(prompt_tokens, [], condition_tensors, callback=None,
94
+ **musicgen_model.generation_params)
95
+ gen_audio = musicgen_model.compression_model.decode(gen_tokens)
96
+
97
+ end = time.time()
98
+ print("Elapsed time for generation: " + str(end - start))
99
+
100
+ _, video_file_name = os.path.split(video_path)
101
+ video_file_name = video_file_name[:-4] # remove .mp4
102
+
103
+ re_result = re_file_name.search(video_file_name) # get video file name
104
+ result_path = f"{'peft' if use_peft else 'audiocraft'}_{dataset}_{musicgen_size}_{re_result.group(1)}"
105
+ audio_result_path = f"{result_dir}/tmp.wav"
106
+ video_result_path = f"{result_dir}/{result_path}_video.mp4"
107
+
108
+ gen_audio = torch.squeeze(gen_audio.detach().cpu()) # remove mini-batch dimension, move to CPU for saving
109
+ sample_rate = musicgen_model.sample_rate
110
+ torchaudio.save(audio_result_path, gen_audio, sample_rate)
111
+ audio_file_clip = AudioFileClip(audio_result_path)
112
+ video.audio = audio_file_clip
113
+
114
+ print("Rendering Video.")
115
+ video.write_videofile(video_result_path)
116
+
117
+ return video_result_path
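
For orientation, `model_path` above points at a directory written by `code/training/training.py`; its rough layout (exact contents depend on `use_peft` and intermediate save steps):

```
../training/models_{audiocraft|peft}/<dataset>_<musicgen_size>/
├── configuration.yml       # OmegaConf dump of training_conf.yml
├── dataset_split.json      # training/validation file lists
├── lm_final.pt             # VideoToT5 (video encoder) weights
└── musicgen_peft_final/    # LoRA adapter, present only for PEFT runs
```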
code/inference/inference_utils.py ADDED
@@ -0,0 +1,283 @@
1
+ from torch.utils.data import Dataset
2
+ import torch
3
+ from torch import nn, Tensor
4
+ import torch.nn.functional as F
5
+ import torchaudio
6
+ import os
7
+ import logging
8
+ from torchvision.models import resnet50, ResNet50_Weights, resnet152, resnet18, resnet34, ResNet152_Weights
9
+ from PIL import Image
10
+ from time import strftime
11
+ import math
12
+ import numpy as np
13
+ import moviepy.editor as mpe
14
+
15
+
16
+ class VideoDataset(Dataset):
17
+ def __init__(self, data_dir):
18
+ self.data_dir = data_dir
19
+ self.data_map = []
20
+
21
+ dir_map = os.listdir(data_dir)
22
+ for d in dir_map:
23
+ name, extension = os.path.splitext(d)
24
+ if extension == ".mp4":
25
+ self.data_map.append({"video": os.path.join(data_dir, d)})
26
+
27
+ def __len__(self):
28
+ return len(self.data_map)
29
+
30
+ def __getitem__(self, idx):
31
+ return self.data_map[idx]["video"]
32
+
33
+
34
+ # input: video_path, output: wav_music
35
+ class VideoToT5(nn.Module):
36
+ def __init__(self,
37
+ device: str,
38
+ video_extraction_framerate: int,
39
+ encoder_input_dimension: int,
40
+ encoder_output_dimension: int,
41
+ encoder_heads: int,
42
+ encoder_dim_feedforward: int,
43
+ encoder_layers: int
44
+ ):
45
+ super().__init__()
46
+ self.video_extraction_framerate = video_extraction_framerate
47
+ self.video_feature_extractor = VideoFeatureExtractor(video_extraction_framerate=video_extraction_framerate,
48
+ device=device)
49
+ self.video_encoder = VideoEncoder(
50
+ device,
51
+ encoder_input_dimension,
52
+ encoder_output_dimension,
53
+ encoder_heads,
54
+ encoder_dim_feedforward,
55
+ encoder_layers
56
+ )
57
+
58
+ def forward(self, video_paths: [str]):
59
+ image_embeddings = []
60
+ for video_path in video_paths:
61
+ video = mpe.VideoFileClip(video_path)
62
+ video_embedding = self.video_feature_extractor(video)
63
+ image_embeddings.append(video_embedding)
64
+ video_embedding = torch.stack(
65
+ image_embeddings) # resulting shape: [batch_size, video_extraction_framerate, resnet_output_dimension]
66
+ # not used, gives worse results!
67
+ # video_embeddings = torch.mean(video_embeddings, 0, True) # average out all image embedding to one video embedding
68
+
69
+ t5_embeddings = self.video_encoder(video_embedding) # T5 output: [batch_size, num_tokens,
70
+ # t5_embedding_size]
71
+ return t5_embeddings
72
+
73
+
74
+ class VideoEncoder(nn.Module):
75
+ def __init__(self,
76
+ device: str,
77
+ encoder_input_dimension: int,
78
+ encoder_output_dimension: int,
79
+ encoder_heads: int,
80
+ encoder_dim_feedforward: int,
81
+ encoder_layers: int
82
+ ):
83
+ super().__init__()
84
+ self.device = device
85
+ self.encoder = (nn.TransformerEncoder(
86
+ nn.TransformerEncoderLayer(
87
+ d_model=encoder_input_dimension,
88
+ nhead=encoder_heads,
89
+ dim_feedforward=encoder_dim_feedforward
90
+ ),
91
+ num_layers=encoder_layers,
92
+ )
93
+ ).to(device)
94
+
95
+ # linear layer to match T5 embedding dimension
96
+ self.linear = (nn.Linear(
97
+ in_features=encoder_input_dimension,
98
+ out_features=encoder_output_dimension)
99
+ .to(device))
100
+
101
+ def forward(self, x):
102
+ assert x.dim() == 3
103
+ x = torch.transpose(x, 0, 1) # encoder expects [sequence_length, batch_size, embedding_dimension]
104
+ x = self.encoder(x) # encoder forward pass
105
+ x = self.linear(x) # forward pass through the linear layer
106
+ x = torch.transpose(x, 0, 1) # shape: [batch_size, sequence_length, embedding_dimension]
107
+ return x
108
+
109
+
110
+ class VideoFeatureExtractor(nn.Module):
111
+ def __init__(self,
112
+ device: str,
113
+ video_extraction_framerate: int = 1,
114
+ resnet_output_dimension: int = 2048):
115
+ super().__init__()
116
+ self.device = device
117
+
118
+ # using a ResNet trained on ImageNet
119
+ self.resnet = resnet50(weights="IMAGENET1K_V2").eval()
120
+ self.resnet = torch.nn.Sequential(*(list(self.resnet.children())[:-1])).to(device) # drop the final classification layer, keep the pooled 2048-dim features
121
+ self.resnet_preprocessor = ResNet50_Weights.DEFAULT.transforms().to(device)
122
+ self.video_extraction_framerate = video_extraction_framerate # setting the fps at which the video is processed
123
+ self.positional_encoder = PositionalEncoding(resnet_output_dimension).to(device)
124
+
125
+ def forward(self, video: mpe.VideoFileClip):
126
+ embeddings = []
127
+ for i in range(0, 30 * self.video_extraction_framerate):
+ frame = video.get_frame(i) # get the frame at time i (in seconds) as a numpy array
+ frame = Image.fromarray(frame) # create PIL image from numpy array
+ frame = self.resnet_preprocessor(frame) # preprocess image
+ frame = frame.to(self.device)
+ frame = frame.unsqueeze(0) # add a batch dimension
+ frame = self.resnet(frame).squeeze() # ResNet forward pass, squeeze away the singleton dimensions
+ embeddings.append(frame) # collect one embedding per extracted frame
136
+
137
+ embeddings = torch.stack(embeddings) # concatenate all frame embeddings into one video embedding
138
+ embeddings = embeddings.unsqueeze(1)
139
+ embeddings = self.positional_encoder(embeddings) # apply positional encoding with a sequence length of 30
140
+ embeddings = embeddings.squeeze()
141
+ return embeddings
142
+
143
+
144
+ # from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
145
+ class PositionalEncoding(nn.Module):
146
+ def __init__(self, d_model: int, dropout: float = 0.1, max_length: int = 5000):
147
+ super().__init__()
148
+ self.dropout = nn.Dropout(p=dropout)
149
+ position = torch.arange(30).unsqueeze(1)
150
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
151
+ pe = torch.zeros(30, 1, d_model)
152
+ pe[:, 0, 0::2] = torch.sin(position * div_term)
153
+ pe[:, 0, 1::2] = torch.cos(position * div_term)
154
+ self.register_buffer('pe', pe)
155
+
156
+ def forward(self, x: Tensor) -> Tensor:
157
+ x = x + self.pe[:x.size(0)]
158
+ return self.dropout(x)
159
+
160
+
161
+ def freeze_model(model: nn.Module):
162
+ for param in model.parameters():
163
+ param.requires_grad = False
164
+ model.eval()
165
+
166
+
167
+ def split_dataset_randomly(dataset, validation_split: float, test_split: float, seed: int = None):
168
+ dataset_size = len(dataset)
169
+ indices = list(range(dataset_size))
170
+ datapoints_validation = int(np.floor(validation_split * dataset_size))
171
+ datapoints_testing = int(np.floor(test_split * dataset_size))
172
+
173
+ if seed:
174
+ np.random.seed(seed)
175
+
176
+ np.random.shuffle(indices) # in-place operation
177
+ training = indices[datapoints_validation + datapoints_testing:]
178
+ validation = indices[datapoints_testing:datapoints_testing + datapoints_validation]
179
+ testing = indices[:datapoints_testing]
180
+
181
+ assert len(validation) == datapoints_validation, "Validation set length incorrect"
182
+ assert len(testing) == datapoints_testing, "Testing set length incorrect"
183
+ assert len(training) == dataset_size - (datapoints_validation + datapoints_testing), "Training set length incorrect"
184
+ assert not any([item in training for item in validation]), "Training and Validation overlap"
185
+ assert not any([item in training for item in testing]), "Training and Testing overlap"
186
+ assert not any([item in validation for item in testing]), "Validation and Testing overlap"
187
+
188
+ return training, validation, testing
189
+
190
+
191
+ ### private function from audiocraft.solver.musicgen.py => _compute_cross_entropy
192
+ def compute_cross_entropy(logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor):
193
+ """Compute cross entropy between multi-codebook targets and model's logits.
194
+ The cross entropy is computed per codebook to provide codebook-level cross entropy.
195
+ Valid timesteps for each of the codebook are pulled from the mask, where invalid
196
+ timesteps are set to 0.
197
+
198
+ Args:
199
+ logits (torch.Tensor): Model's logits of shape [B, K, T, card].
200
+ targets (torch.Tensor): Target codes, of shape [B, K, T].
201
+ mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
202
+ Returns:
203
+ ce (torch.Tensor): Cross entropy averaged over the codebooks
204
+ ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
205
+ """
206
+ B, K, T = targets.shape
207
+ assert logits.shape[:-1] == targets.shape
208
+ assert mask.shape == targets.shape
209
+ ce = torch.zeros([], device=targets.device)
210
+ ce_per_codebook = []
211
+ for k in range(K):
212
+ logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card]
213
+ targets_k = targets[:, k, ...].contiguous().view(-1) # [B x T]
214
+ mask_k = mask[:, k, ...].contiguous().view(-1) # [B x T]
215
+ ce_targets = targets_k[mask_k]
216
+ ce_logits = logits_k[mask_k]
217
+ q_ce = F.cross_entropy(ce_logits, ce_targets)
218
+ ce += q_ce
219
+ ce_per_codebook.append(q_ce.detach())
220
+ # average cross entropy across codebooks
221
+ ce = ce / K
222
+ return ce, ce_per_codebook
223
+
224
+
225
+ def generate_audio_codes(audio_paths: [str],
226
+ audiocraft_compression_model: torch.nn.Module,
227
+ device: str) -> torch.Tensor:
228
+ audio_duration = 30
229
+ encodec_sample_rate = audiocraft_compression_model.sample_rate
230
+
231
+ torch_audios = []
232
+ for audio_path in audio_paths:
233
+ wav, original_sample_rate = torchaudio.load(audio_path) # load audio from file
234
+ wav = torchaudio.functional.resample(wav, original_sample_rate,
235
+ encodec_sample_rate) # cast audio to model sample rate
236
+ wav = wav[:, :encodec_sample_rate * audio_duration] # enforce an exact audio length of 30 seconds
237
+
238
+ assert len(wav.shape) == 2, "audio data is not of shape [channels, duration]"
+ assert wav.shape[0] == 2, "audio data should be in stereo, but does not have 2 channels"
240
+
241
+ torch_audios.append(wav)
242
+
243
+ torch_audios = torch.stack(torch_audios)
244
+ torch_audios = torch_audios.to(device)
245
+
246
+ with torch.no_grad():
247
+ gen_audio = audiocraft_compression_model.encode(torch_audios)
248
+
249
+ codes, scale = gen_audio
250
+ assert scale is None
251
+
252
+ return codes
253
+
254
+
255
+ def create_condition_tensors(
256
+ video_embeddings: torch.Tensor,
257
+ batch_size: int,
258
+ video_extraction_framerate: int,
259
+ device: str
260
+ ):
261
+ # model T5 mask
262
+ mask = torch.ones((batch_size, video_extraction_framerate * 30), dtype=torch.int).to(device)
263
+
264
+ condition_tensors = {
265
+ 'description': (video_embeddings, mask)
266
+ }
267
+ return condition_tensors
268
+
269
+
270
+ def get_current_timestamp():
271
+ return strftime("%Y_%m_%d___%H_%M_%S")
272
+
273
+
274
+ def configure_logging(output_dir: str, filename: str, log_level):
275
+ # create logs folder, if not existing
276
+ os.makedirs(output_dir, exist_ok=True)
277
+ level = getattr(logging, log_level)
278
+ file_path = output_dir + "/" + filename
279
+ logging.basicConfig(filename=file_path, encoding='utf-8', level=level)
280
+ logger = logging.getLogger()
281
+ # only add a StreamHandler if it is not present yet
282
+ if len(logger.handlers) <= 1:
283
+ logger.addHandler(logging.StreamHandler())
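
A minimal sketch of the conditioning path defined above, run from `code/inference` so the module is importable: dummy ResNet frame features go through `VideoEncoder` into the T5 embedding space and are wrapped by `create_condition_tensors`. The hyperparameters mirror `training_conf.yml` and the small model's 1024-dimensional embedding space; the input tensor is random and purely illustrative:

```python
import torch
from inference_utils import VideoEncoder, create_condition_tensors

encoder = VideoEncoder(device="cpu", encoder_input_dimension=2048, encoder_output_dimension=1024,
                       encoder_heads=16, encoder_dim_feedforward=8192, encoder_layers=6)

frame_features = torch.randn(2, 30, 2048)       # [batch, 30 frames at 1 fps, ResNet-50 features]
video_embeddings = encoder(frame_features)      # -> [2, 30, 1024], the T5 embedding space

condition_tensors = create_condition_tensors(video_embeddings, batch_size=2,
                                             video_extraction_framerate=1, device="cpu")
assert condition_tensors["description"][0].shape == (2, 30, 1024)
```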
code/training/__init__.py ADDED
File without changes
code/training/training.py ADDED
@@ -0,0 +1,239 @@
1
+ from audiocraft.models import MusicGen
2
+ from torch.optim import AdamW
3
+ from torch.utils.data import DataLoader
4
+ import wandb
5
+ from omegaconf import OmegaConf
6
+ from training_utils import *
7
+ import json
8
+ import os
9
+ import sys
10
+ import shutil
11
+ import copy
12
+ from peft import get_peft_model, LoraConfig
13
+
14
+
15
+ def get_trainable_parameters(model):
16
+ """
17
+ Returns a summary string with the number of trainable parameters in the model.
18
+ """
19
+ trainable_params = 0
20
+ all_param = 0
21
+ for _, param in model.named_parameters():
22
+ all_param += param.numel()
23
+ if param.requires_grad:
24
+ trainable_params += param.numel()
25
+ return f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
26
+
27
+
28
+ def save_model(step: int, is_final_save: bool):
29
+ logging.info(f"Model save at step {int(step)}")
30
+ torch.save(video_to_t5.state_dict(),
31
+ f"{model_save_path}/lm_{'final' if is_final_save else int(step)}.pt")
32
+
33
+ if conf.use_peft:
34
+ musicgen_model.lm.save_pretrained(f"{model_save_path}/musicgen_peft_{'final' if is_final_save else int(step)}")
35
+
36
+
37
+ conf = OmegaConf.load('training_conf.yml')
38
+ start_timestamp = get_current_timestamp()
39
+ model_save_path = f"./{'models_peft' if conf.use_peft else 'models_audiocraft'}/{conf.output_dir_name}"
40
+
41
+ if os.path.isdir(model_save_path):
42
+ i = input(f"Model output directory {model_save_path} already exists, overwrite directory? confirm with [y]\n")
43
+ if i == "y" or i == "yes":
44
+ shutil.rmtree(model_save_path)
45
+ else:
46
+ print("Aborting.")
47
+ sys.exit()
48
+
49
+ configure_logging(model_save_path, f"{start_timestamp}.log", conf.log_level)
50
+ os.makedirs(model_save_path, exist_ok=True)
51
+
52
+ if conf.use_wandb:
53
+ wandb = wandb.init(project=conf.wandb_project_name,
54
+ config=OmegaConf.to_container(conf))
55
+ logging.info(f"Wandb project_name: {conf.wandb_project_name}, run_id: {wandb.id}, run_name: {wandb.id}")
56
+
57
+ logging.info("Start Training")
58
+ musicgen_model = MusicGen.get_pretrained(conf.musicgen_model_id, device=conf.device)
59
+ musicgen_model.compression_model = musicgen_model.compression_model.to(conf.device)
60
+ musicgen_model.lm = musicgen_model.lm.to(conf.device)
61
+ musicgen_model.lm = musicgen_model.lm.train()
62
+
63
+ encoder_output_dimension = None
64
+ if "small" in conf.musicgen_model_id:
65
+ encoder_output_dimension = 1024
66
+ elif "medium" in conf.musicgen_model_id:
67
+ encoder_output_dimension = 1536
68
+ elif "large" in conf.musicgen_model_id:
69
+ encoder_output_dimension = 2048
70
+ assert encoder_output_dimension, f"Video Encoder output dimension could not be determined by {conf.musicgen_model_id}"
71
+
72
+ # initialize video-to-text model
73
+ video_to_t5 = VideoToT5(video_extraction_framerate=conf.video_extraction_framerate,
74
+ encoder_input_dimension=conf.encoder_input_dimension,
75
+ encoder_output_dimension=encoder_output_dimension,
76
+ encoder_heads=conf.encoder_heads,
77
+ encoder_dim_feedforward=conf.encoder_dim_feedforward,
78
+ encoder_layers=conf.encoder_layers,
79
+ device=conf.device)
80
+
81
+ # freeze all model layers except the video-to-text encoder
82
+ freeze_model(video_to_t5.video_feature_extractor)
83
+ freeze_model(musicgen_model.compression_model)
84
+ if not conf.use_peft:
85
+ freeze_model(musicgen_model.lm)
86
+
87
+ logging.info(f"Trainable parameters video_to_t5: {get_trainable_parameters(video_to_t5)}")
88
+
89
+ if conf.use_peft:
90
+ lora_config = LoraConfig(
91
+ r=conf.lora_r,
92
+ lora_alpha=conf.lora_alpha,
93
+ target_modules=["out_proj", "linear1", "linear2"],
94
+ lora_dropout=conf.lora_dropout,
95
+ bias="none",
96
+ modules_to_save=["classifier"]
97
+ )
98
+
99
+ logging.info(f"Trainable parameters MusicGen before LoRA: {get_trainable_parameters(musicgen_model.lm)}")
100
+ musicgen_model.lm = get_peft_model(musicgen_model.lm, lora_config)
101
+ logging.info(f"Trainable parameters MusicGen with LoRA: {get_trainable_parameters(musicgen_model.lm)}")
102
+
103
+ logging.info(f"Training on {conf.musicgen_model_id}")
104
+
105
+ # create dataset train and validation split
106
+ dataset = VideoDataset(conf.dataset_video_folder)
107
+ train_indices, validation_indices, test_indices = split_dataset_randomly(dataset,
108
+ conf.dataset_validation_split,
109
+ conf.dataset_test_split,
110
+ seed=conf.dataset_shuffling_seed)
111
+
112
+ train_dataset = copy.copy(dataset)
113
+ train_dataset.data_map = [dataset.data_map[i] for i in train_indices]
114
+ validation_dataset = copy.copy(dataset)
115
+ validation_dataset.data_map = [dataset.data_map[i] for i in validation_indices]
116
+ test_dataset = copy.copy(dataset)
117
+ test_dataset.data_map = [dataset.data_map[i] for i in test_indices]
118
+ train_dataloader = DataLoader(train_dataset, batch_size=conf.batch_size)
119
+ validation_dataloader = DataLoader(validation_dataset, batch_size=conf.batch_size)
120
+ test_dataloader = DataLoader(test_dataset, batch_size=1)
121
+
122
+ with open(f"{model_save_path}/dataset_split.json", 'w') as f:
123
+ json.dump({
124
+ "dataset": dataset.data_map,
125
+ "training": train_dataloader.dataset.data_map,
126
+ "validation": validation_dataloader.dataset.data_map,
127
+ }, f)
128
+
129
+ with open(f"{model_save_path}/configuration.yml", 'w') as f:
130
+ OmegaConf.save(conf, f)
131
+
132
+ logging.info(f"Video path: {conf.dataset_video_folder}, "
133
+ f"Audio path: {conf.dataset_audio_folder} with {len(dataset)} examples, "
134
+ f"Batch Size: {conf.batch_size}.")
135
+
136
+ optimizer = AdamW(
137
+ video_to_t5.video_encoder.parameters(),
138
+ betas=(conf.beta1, conf.beta2),
139
+ weight_decay=conf.weight_decay,
140
+ lr=conf.learning_rate
141
+ )
142
+
143
+
144
+ def forward_pass(video_path_list: [str]):
145
+ optimizer.zero_grad()
146
+
147
+ # get corresponding audio for the video data
148
+ audio_paths = []
149
+ for video_path in video_path_list:
150
+ # load corresponding audio file
151
+ _, video_file_name = os.path.split(video_path)
152
+ video_file_name = video_file_name[:-4] # remove .mp4
153
+ if conf.use_demucs_folder_structure:
154
+ audio_path = f"{conf.dataset_audio_folder}/htdemucs/{video_file_name}/no_vocals.wav"
155
+ else:
156
+ audio_path = f"{conf.dataset_audio_folder}/{video_file_name}.wav"
157
+ audio_paths.append(audio_path)
158
+
159
+ # batch encode audio data
160
+ audio_batches = generate_audio_codes(audio_paths=audio_paths,
161
+ audiocraft_compression_model=musicgen_model.compression_model,
162
+ device=conf.device)
163
+
164
+ # batch encode video data
165
+ video_embedding_batches = video_to_t5(video_path_list)
166
+
167
+ condition_tensors = create_condition_tensors(video_embedding_batches,
168
+ conf.batch_size,
169
+ video_to_t5.video_extraction_framerate,
170
+ device=conf.device)
171
+
172
+ # forward pass with MusicGen
173
+ with musicgen_model.autocast:
174
+ musicgen_output = musicgen_model.lm.compute_predictions(
175
+ codes=audio_batches,
176
+ conditions=[],
177
+ condition_tensors=condition_tensors
178
+ )
179
+ loss, _ = compute_cross_entropy(logits=musicgen_output.logits,
180
+ targets=audio_batches,
181
+ mask=musicgen_output.mask)
182
+ return musicgen_output, loss
183
+
184
+
185
+ training_step = 0
186
+ for epoch in range(conf.num_epochs):
187
+ epoch_training_loss = []
188
+ epoch_validation_loss = []
189
+ logging.info("Starting next Epoch.")
190
+ for batch_idx, video_paths in enumerate(train_dataloader):
191
+ _, training_loss = forward_pass(video_paths)
192
+ epoch_training_loss.append(training_loss)
193
+
194
+ training_loss.backward()
195
+ # gradient clipping
196
+ torch.nn.utils.clip_grad_norm_(musicgen_model.lm.parameters(), conf.gradient_clipping)
197
+ optimizer.step()
198
+ training_step += 1
199
+
200
+ # update metrics
201
+ if conf.use_wandb:
202
+ wandb.log(dict(training_loss=training_loss.item()))
203
+ logging.info(
204
+ f"Epoch: {epoch + 1}/{conf.num_epochs}, "
205
+ f"Batch: {batch_idx}/{len(train_dataloader)}, "
206
+ f"Loss: {training_loss.item()}"
207
+ )
208
+
209
+ # save model after each epoch
210
+ save_model(training_step, False)
211
+
212
+ # validation
213
+ logging.info("Start Validation.")
214
+ with torch.no_grad():
215
+ for batch_idx, video_paths in enumerate(validation_dataloader):
216
+ _, validation_loss = forward_pass(video_paths)
217
+ epoch_validation_loss.append(validation_loss)
218
+ if conf.use_wandb:
219
+ wandb.log(dict(validation_loss=validation_loss.item()))
220
+ logging.info(
221
+ f"Epoch: {epoch + 1}/{conf.num_epochs}, "
222
+ f"Batch: {batch_idx}/{len(validation_dataloader)}, "
223
+ f"Loss: {validation_loss.item()}"
224
+ )
225
+ logging.info(
226
+ f"Epoch results: epoch_training_loss {epoch_training_loss}, epoch_validation_loss {epoch_validation_loss}")
227
+ save_model(training_step, True)
228
+ logging.info(f"Finished Training. Start Testing")
229
+ with torch.no_grad():
230
+ for batch_idx, video_paths in enumerate(test_dataloader):
231
+ _, testing_loss = forward_pass(video_paths)
232
+ if conf.use_wandb:
233
+ wandb.log(dict(testing_loss=testing_loss.item()))
234
+ logging.info(
235
+ f"Epoch: {epoch + 1}/{conf.num_epochs}, "
236
+ f"Batch: {batch_idx}/{len(test_dataloader)}, "
237
+ f"Loss: {testing_loss.item()}"
238
+ )
239
+ logging.info(f"Finished Testing.")
code/training/training_conf.yml ADDED
@@ -0,0 +1,47 @@
1
+ # LOGGING
2
+ log_level: "INFO"
3
+
4
+ # TRAINING LOOP
5
+ num_epochs: 4
6
+ batch_size: 2
7
+ save_step: 800
8
+ device: "cuda"
9
+
10
+ # WANDB
11
+ use_wandb: true
12
+ wandb_project_name: "pe_30"
13
+
14
+ # OPTIMIZER
15
+ gradient_clipping: 1.0
16
+ beta1: 0.9 # from audiocraft
17
+ beta2: 0.95 # from audiocraft
18
+ weight_decay: 0.1 # from audiocraft
19
+ learning_rate: 1e-5
20
+
21
+ # VIDEO ENCODER
22
+ video_extraction_framerate: 1
23
+ encoder_input_dimension: 2048 # output dimension of the ResNet-50 feature extractor
24
+ encoder_dim_feedforward: 8192
25
+ encoder_heads: 16
26
+ encoder_layers: 6
27
+
28
+ # MUSICGEN
29
+ musicgen_model_id: "facebook/musicgen-stereo-small"
30
+
31
+ # PEFT PARAMETERS
32
+ use_peft: false
33
+ lora_r: 16
34
+ lora_alpha: 16
35
+ lora_dropout: 0.1
36
+
37
+ # OUTPUT FOLDER
38
+ output_dir_name: "nature_small"
39
+
40
+ # DATASET
41
+ dataset_video_folder: "../../datasets/nature/videos_30sec"
42
+ dataset_audio_folder: "../../datasets/nature/audio_30sec"
43
+ #dataset_audio_folder: "../../datasets/symmv/audio_30sec_separated"
44
+ use_demucs_folder_structure: false
45
+ dataset_validation_split: 0.1
46
+ dataset_test_split: 0.1
47
+ dataset_shuffling_seed: 42
code/training/training_utils.py ADDED
@@ -0,0 +1,282 @@
1
+ from torch.utils.data import Dataset
2
+ import torch
3
+ from torch import nn, Tensor
4
+ import torch.nn.functional as F
5
+ import torchaudio
6
+ import os
7
+ import logging
8
+ from torchvision.models import resnet50, ResNet50_Weights
9
+ from PIL import Image
10
+ from time import strftime
11
+ import math
12
+ import numpy as np
13
+ import moviepy.editor as mpe
14
+
15
+
16
+ class VideoDataset(Dataset):
17
+ def __init__(self, data_dir):
18
+ self.data_dir = data_dir
19
+ self.data_map = []
20
+
21
+ dir_map = os.listdir(data_dir)
22
+ for d in dir_map:
23
+ name, extension = os.path.splitext(d)
24
+ if extension == ".mp4":
25
+ self.data_map.append({"video": os.path.join(data_dir, d)})
26
+
27
+ def __len__(self):
28
+ return len(self.data_map)
29
+
30
+ def __getitem__(self, idx):
31
+ return self.data_map[idx]["video"]
32
+
33
+
34
+ # input: video_path, output: wav_music
35
+ class VideoToT5(nn.Module):
36
+ def __init__(self,
37
+ device: str,
38
+ video_extraction_framerate: int,
39
+ encoder_input_dimension: int,
40
+ encoder_output_dimension: int,
41
+ encoder_heads: int,
42
+ encoder_dim_feedforward: int,
43
+ encoder_layers: int
44
+ ):
45
+ super().__init__()
46
+ self.video_extraction_framerate = video_extraction_framerate
47
+ self.video_feature_extractor = VideoFeatureExtractor(video_extraction_framerate=video_extraction_framerate,
48
+ device=device)
49
+ self.video_encoder = VideoEncoder(
50
+ device,
51
+ encoder_input_dimension,
52
+ encoder_output_dimension,
53
+ encoder_heads,
54
+ encoder_dim_feedforward,
55
+ encoder_layers
56
+ )
57
+
58
+ def forward(self, video_paths: list[str]):
59
+ image_embeddings = []
60
+ for video_path in video_paths:
61
+ video = mpe.VideoFileClip(video_path)
62
+ video_embedding = self.video_feature_extractor(video)
63
+ image_embeddings.append(video_embedding)
64
+ video_embeddings = torch.stack(
65
+ image_embeddings) # resulting shape: [batch_size, 30 * video_extraction_framerate, resnet_output_dimension]
66
+ # not used, gives worse results!
67
+ # video_embeddings = torch.mean(video_embeddings, 0, True) # average out all image embedding to one video embedding
68
+
69
+ t5_embeddings = self.video_encoder(video_embeddings) # encoder output shape: [batch_size, num_tokens,
70
+ # t5_embedding_size]
71
+ return t5_embeddings
72
+
73
+
74
+ class VideoEncoder(nn.Module):
75
+ def __init__(self,
76
+ device: str,
77
+ encoder_input_dimension: int,
78
+ encoder_output_dimension: int,
79
+ encoder_heads: int,
80
+ encoder_dim_feedforward: int,
81
+ encoder_layers: int
82
+ ):
83
+ super().__init__()
84
+ self.device = device
85
+ self.encoder = (nn.TransformerEncoder(
86
+ nn.TransformerEncoderLayer(
87
+ d_model=encoder_input_dimension,
88
+ nhead=encoder_heads,
89
+ dim_feedforward=encoder_dim_feedforward
90
+ ),
91
+ num_layers=encoder_layers,
92
+ )
93
+ ).to(device)
94
+
95
+ # linear layer to match T5 embedding dimension
96
+ self.linear = (nn.Linear(
97
+ in_features=encoder_input_dimension,
98
+ out_features=encoder_output_dimension)
99
+ .to(device))
100
+
101
+ def forward(self, x):
102
+ assert x.dim() == 3
103
+ x = torch.transpose(x, 0, 1) # encoder expects [sequence_length, batch_size, embedding_dimension]
104
+ x = self.encoder(x) # encoder forward pass
105
+ x = self.linear(x) # forward pass through the linear layer
106
+ x = torch.transpose(x, 0, 1) # shape: [batch_size, sequence_length, embedding_dimension]
107
+ return x
108
+
109
+
110
+ class VideoFeatureExtractor(nn.Module):
111
+ def __init__(self,
112
+ device: str,
113
+ video_extraction_framerate: int = 1,
114
+ resnet_output_dimension: int = 2048):
115
+ super().__init__()
116
+ self.device = device
117
+
118
+ # using a ResNet trained on ImageNet
119
+ self.resnet = resnet50(weights="IMAGENET1K_V2").eval()
120
+ self.resnet = torch.nn.Sequential(*(list(self.resnet.children())[:-1])).to(device) # remove the final classification layer, keeping the pooled 2048-dim features
121
+ self.resnet_preprocessor = ResNet50_Weights.DEFAULT.transforms().to(device)
122
+ self.video_extraction_framerate = video_extraction_framerate # setting the fps at which the video is processed
123
+ self.positional_encoder = PositionalEncoding(resnet_output_dimension, max_length=30 * video_extraction_framerate).to(device)
124
+
125
+ def forward(self, video: mpe.VideoFileClip):
126
+ embeddings = []
127
+ for i in range(0, 30 * self.video_extraction_framerate):
128
+ frame = video.get_frame(i / self.video_extraction_framerate) # get the frame at this timestamp as a numpy array
129
+ frame = Image.fromarray(frame) # create a PIL image from the numpy array
130
+ frame = self.resnet_preprocessor(frame) # preprocess the image for ResNet
131
+ frame = frame.to(self.device)
132
+ frame = frame.unsqueeze(0) # add a batch dimension
133
+ embedding = self.resnet(frame).squeeze() # ResNet forward pass; squeeze removes the batch and spatial dimensions
134
+ embeddings.append(embedding) # collect the per-frame embeddings
136
+
137
+ embeddings = torch.stack(embeddings) # concatenate all frame embeddings into one video embedding
138
+ embeddings = embeddings.unsqueeze(1) # add a batch dimension: [sequence_length, 1, resnet_output_dimension]
139
+ embeddings = self.positional_encoder(embeddings) # apply positional encoding over the sequence of extracted frames
140
+ embeddings = embeddings.squeeze()
141
+ return embeddings
142
+
143
+
144
+ # from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
145
+ class PositionalEncoding(nn.Module):
146
+ def __init__(self, d_model: int, dropout: float = 0.1, max_length: int = 30):
147
+ super().__init__()
148
+ self.dropout = nn.Dropout(p=dropout)
149
+ position = torch.arange(max_length).unsqueeze(1)
150
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
151
+ pe = torch.zeros(max_length, 1, d_model)
152
+ pe[:, 0, 0::2] = torch.sin(position * div_term)
153
+ pe[:, 0, 1::2] = torch.cos(position * div_term)
154
+ self.register_buffer('pe', pe)
155
+
156
+ def forward(self, x: Tensor) -> Tensor:
157
+ x = x + self.pe[:x.size(0)]
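+ # x is expected as [sequence_length, batch_size, d_model]; the buffer is sliced to the sequence length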
158
+ return self.dropout(x)
159
+
160
+
161
+ def freeze_model(model: nn.Module):
162
+ for param in model.parameters():
163
+ param.requires_grad = False
164
+ model.eval()
165
+
166
+
167
+ def split_dataset_randomly(dataset, validation_split: float, test_split: float, seed: int = None):
168
+ dataset_size = len(dataset)
169
+ indices = list(range(dataset_size))
170
+ datapoints_validation = int(np.floor(validation_split * dataset_size))
171
+ datapoints_testing = int(np.floor(test_split * dataset_size))
172
+
173
+ if seed is not None:
174
+ np.random.seed(seed)
175
+
176
+ np.random.shuffle(indices) # in-place operation
177
+ training = indices[datapoints_validation + datapoints_testing:]
178
+ validation = indices[datapoints_testing:datapoints_testing + datapoints_validation]
179
+ testing = indices[:datapoints_testing]
180
+
181
+ assert len(validation) == datapoints_validation, "Validation set length incorrect"
182
+ assert len(testing) == datapoints_testing, "Testing set length incorrect"
183
+ assert len(training) == dataset_size - (datapoints_validation + datapoints_testing), "Training set length incorrect"
184
+ assert not any([item in training for item in validation]), "Training and Validation overlap"
185
+ assert not any([item in training for item in testing]), "Training and Testing overlap"
186
+ assert not any([item in validation for item in testing]), "Validation and Testing overlap"
187
+
188
+ return training, validation, testing
189
+
190
+
191
+ ### adapted from the private method audiocraft.solvers.musicgen.MusicGenSolver._compute_cross_entropy
192
+ def compute_cross_entropy(logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor):
193
+ """Compute cross entropy between multi-codebook targets and model's logits.
194
+ The cross entropy is computed per codebook to provide codebook-level cross entropy.
195
+ Valid timesteps for each of the codebook are pulled from the mask, where invalid
196
+ timesteps are set to 0.
197
+
198
+ Args:
199
+ logits (torch.Tensor): Model's logits of shape [B, K, T, card].
200
+ targets (torch.Tensor): Target codes, of shape [B, K, T].
201
+ mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
202
+ Returns:
203
+ ce (torch.Tensor): Cross entropy averaged over the codebooks
204
+ ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
205
+ """
206
+ B, K, T = targets.shape
207
+ assert logits.shape[:-1] == targets.shape
208
+ assert mask.shape == targets.shape
209
+ ce = torch.zeros([], device=targets.device)
210
+ ce_per_codebook = []
211
+ for k in range(K):
212
+ logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card]
213
+ targets_k = targets[:, k, ...].contiguous().view(-1) # [B x T]
214
+ mask_k = mask[:, k, ...].contiguous().view(-1) # [B x T]
215
+ ce_targets = targets_k[mask_k]
216
+ ce_logits = logits_k[mask_k]
217
+ q_ce = F.cross_entropy(ce_logits, ce_targets)
218
+ ce += q_ce
219
+ ce_per_codebook.append(q_ce.detach())
220
+ # average cross entropy across codebooks
221
+ ce = ce / K
222
+ return ce, ce_per_codebook
223
+
224
+
225
+ def generate_audio_codes(audio_paths: list[str],
226
+ audiocraft_compression_model: torch.nn.Module,
227
+ device: str) -> torch.Tensor:
228
+ audio_duration = 30
229
+ encodec_sample_rate = audiocraft_compression_model.sample_rate
230
+
231
+ torch_audios = []
232
+ for audio_path in audio_paths:
233
+ wav, original_sample_rate = torchaudio.load(audio_path) # load audio from file
234
+ wav = torchaudio.functional.resample(wav, original_sample_rate,
235
+ encodec_sample_rate) # cast audio to model sample rate
236
+ wav = wav[:, :encodec_sample_rate * audio_duration] # truncate the audio to at most 30 seconds
237
+
238
+ assert len(wav.shape) == 2, "audio data is not of shape [channels, duration]"
239
+ assert wav.shape[0] == 2, "audio data should be stereo, but does not have 2 channels"
240
+
241
+ torch_audios.append(wav)
242
+
243
+ torch_audios = torch.stack(torch_audios)
244
+ torch_audios = torch_audios.to(device)
245
+
246
+ with torch.no_grad():
247
+ gen_audio = audiocraft_compression_model.encode(torch_audios)
248
+
249
+ codes, scale = gen_audio
250
+ assert scale is None
251
+
252
+ return codes
253
+
254
+
255
+ def create_condition_tensors(
256
+ video_embeddings: torch.Tensor,
257
+ batch_size: int,
258
+ video_extraction_framerate: int,
259
+ device: str
260
+ ):
261
+ mask = torch.ones((batch_size, video_extraction_framerate * 30), dtype=torch.int).to(device)
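+ # all conditioning positions are marked valid: one token per extracted frame over the 30-second clip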
262
+
263
+ condition_tensors = {
264
+ 'description': (video_embeddings, mask)
265
+ }
266
+ return condition_tensors
267
+
268
+
269
+ def get_current_timestamp():
270
+ return strftime("%Y_%m_%d___%H_%M_%S")
271
+
272
+
273
+ def configure_logging(output_dir: str, filename: str, log_level):
274
+ # create logs folder, if not existing
275
+ os.makedirs(output_dir, exist_ok=True)
276
+ level = getattr(logging, log_level)
277
+ file_path = output_dir + "/" + filename
278
+ logging.basicConfig(filename=file_path, encoding='utf-8', level=level)
279
+ logger = logging.getLogger()
280
+ # only add a StreamHandler if it is not present yet
281
+ if len(logger.handlers) <= 1:
282
+ logger.addHandler(logging.StreamHandler())