Dionyssos committed on
Commit d9889a1
1 Parent(s): 27d24be

duplicate xN_DRAW - for long gen

audiocraft/builders.py CHANGED
@@ -7,7 +7,7 @@
 import typing as tp
 import omegaconf
 import torch
-from .encodec import CompressionModel, EncodecModel
+from .encodec import EncodecModel
 from .lm import LMModel
 from .seanet import SEANetDecoder
 from .codebooks_patterns import DelayedPatternProvider
audiocraft/codebooks_patterns.py CHANGED
@@ -46,84 +46,12 @@ class Pattern:
     n_q: int
 
     def __post_init__(self):
-        assert len(self.layout) > 0
-        self._validate_layout()
+        # assert len(self.layout) > 0
+        # self._validate_layout()
         self._build_reverted_sequence_scatter_indexes = self._build_reverted_sequence_scatter_indexes
         self._build_pattern_sequence_scatter_indexes = self._build_pattern_sequence_scatter_indexes
         print("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))
 
-    def _validate_layout(self):
-        """Runs checks on the layout to ensure a valid pattern is defined.
-        A pattern is considered invalid if:
-            - Multiple timesteps for the same codebook are defined in the same sequence step
-            - The timesteps for a given codebook are not in ascending order as we advance in the sequence
-              (this would mean that we have future timesteps before past timesteps).
-        """
-        q_timesteps = {q: 0 for q in range(self.n_q)}
-        for s, seq_coords in enumerate(self.layout):
-            if len(seq_coords) > 0:
-                qs = set()
-                for coord in seq_coords:
-                    qs.add(coord.q)
-                    last_q_timestep = q_timesteps[coord.q]
-                    assert coord.t >= last_q_timestep, \
-                        f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
-                    q_timesteps[coord.q] = coord.t
-                # each sequence step contains at most 1 coordinate per codebook
-                assert len(qs) == len(seq_coords), \
-                    f"Multiple entries for the same codebook are found at step {s}"
-                print(f'{qs=}\n\n\n\n QS VALIDATE LAYOUT')  # this prints 0,1,2,3, although
-                # if q_timesteps contains special_index, does this show something different from 0123?
-                # observed output: qs={0, 1}, then qs={0, 1, 2}, then qs={0, 1, 2, 3} for every later step
-
-    @property
-    def num_sequence_steps(self):
-        return len(self.layout) - 1
-
     @property
     def max_delay(self):
         max_t_in_seq_coords = 0
@@ -289,36 +217,6 @@ class Pattern:
 
-    # def revert_pattern_logits(self, logits,
-    #                           special_token,
-    #                           keep_only_valid_steps=False):
-    #     """Similar to ``revert_pattern_sequence`` with the following specificities:
-    #     1. It is designed to work with the extra cardinality dimension
-    #     2. We return the logits for the first sequence item that matches the special_token and
-    #        whose matching target in the original sequence is the first item of the sequence,
-    #        while we skip the last logits as there is no matching target
-    #     """
-    #     B, card, K, S = logits.shape
-    #     indexes, mask = self._build_reverted_sequence_scatter_indexes(
-    #         S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
-    #     )
-    #     logits = logits.reshape(B, card, -1)
-    #     # we append the special token as the last index of our flattened z tensor
-    #     logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1)  # [B, card, K x S]
-    #     values = logits[:, :, indexes.view(-1)]
-    #     values = values.view(B, card, K, indexes.shape[-1])
-    #     return values, indexes, mask
 
 
 class DelayedPatternProvider():
@@ -352,6 +250,7 @@ class DelayedPatternProvider():
         self.n_q = n_q
         if delays is None:
             delays = list(range(n_q))
+        print(f'{delays=} PATTERN __ini')
         self.delays = delays
         self.flatten_first = flatten_first
         self.empty_initial = empty_initial
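
Note: with delays = list(range(n_q)), codebook q at timestep t is emitted at sequence step t + q, which is why the validate-layout prints above grow codebook by codebook ({0, 1}, then {0, 1, 2}, then {0, 1, 2, 3}). A minimal sketch of that staggered layout, independent of the library code:

    # sketch only - (t, q) coordinates per sequence step for the delayed pattern
    n_q, timesteps = 4, 4
    layout = [[] for _ in range(timesteps + n_q)]
    for t in range(timesteps):
        for q in range(n_q):
            layout[t + q].append((t, q))
    for s, coords in enumerate(layout):
        print(s, coords)  # early steps miss the later codebooks, e.g. step 1 -> [(1, 0), (0, 1)]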
audiocraft/encodec.py CHANGED
@@ -1,100 +1,14 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-"""Compression models or wrappers around existing models.
-Also defines the main interface that a model must follow to be usable as an audio tokenizer.
-"""
-
-from abc import ABC, abstractmethod
-import logging
-from pathlib import Path
 import typing as tp
-
 from einops import rearrange
 import numpy as np
 import torch
 from torch import nn
-from transformers import EncodecModel as HFEncodecModel
-
-
-logger = logging.getLogger()
-
-
-class CompressionModel(ABC, nn.Module):
-    """Base API for all compression models that aim at being used as audio tokenizers
-    with a language model.
-    """
-
-    @abstractmethod
-    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
-        """See `EncodecModel.decode`."""
-        ...
-
-    @abstractmethod
-    def decode_latent(self, codes: torch.Tensor):
-        """Decode from the discrete codes to continuous latent space."""
-        ...
-
-    @property
-    @abstractmethod
-    def channels(self) -> int:
-        ...
-
-    @property
-    @abstractmethod
-    def frame_rate(self) -> float:
-        ...
-
-    @property
-    @abstractmethod
-    def sample_rate(self) -> int:
-        ...
-
-    @property
-    @abstractmethod
-    def cardinality(self) -> int:
-        ...
-
-    @property
-    @abstractmethod
-    def num_codebooks(self) -> int:
-        ...
-
-    @property
-    @abstractmethod
-    def total_codebooks(self) -> int:
-        ...
-
-    @abstractmethod
-    def set_num_codebooks(self, n: int):
-        """Set the active number of codebooks used by the quantizer."""
-        ...
-
-
-class EncodecModel(CompressionModel):
-    """Encodec model operating on the raw waveform.
-
-    Args:
-        encoder (nn.Module): Encoder network.
-        decoder (nn.Module): Decoder network.
-        quantizer (qt.BaseQuantizer): Quantizer network.
-        frame_rate (int): Frame rate for the latent representation.
-        sample_rate (int): Audio sample rate.
-        channels (int): Number of audio channels.
-        causal (bool): Whether to use a causal version of the model.
-        renormalize (bool): Whether to renormalize the audio before running the model.
-    """
-    # we need assignment to override the property in the abstract class,
-    # I couldn't find a better way...
-    frame_rate: float = 0
-    sample_rate: int = 0
-    channels: int = 0
+class EncodecModel(nn.Module):
 
     def __init__(self,
                  decoder=None,
@@ -104,8 +18,11 @@ class EncodecModel(CompressionModel):
                  channels=None,
                  causal=False,
                  renormalize=False):
         super().__init__()
+        self.frame_rate = 0
+        self.sample_rate = 0
+        self.channels = 0
         self.decoder = decoder
         self.quantizer = quantizer
         self.frame_rate = frame_rate
@@ -117,6 +34,7 @@ class EncodecModel(CompressionModel):
         # we force disabling here to avoid handling linear overlap of segments
         # as supported in original EnCodec codebase.
         assert not self.renormalize, 'Causal model does not support renormalize'
+
 
     @property
     def total_codebooks(self):
@@ -128,7 +46,7 @@ class EncodecModel(CompressionModel):
         """Active number of codebooks used by the quantizer."""
         return self.quantizer.num_codebooks
 
-    def set_num_codebooks(self, n: int):
+    def set_num_codebooks(self, n):
         """Set the active number of codebooks used by the quantizer."""
         self.quantizer.set_num_codebooks(n)
audiocraft/genmodel.py CHANGED
@@ -3,7 +3,7 @@ import omegaconf
 import torch
 
 from abc import ABC, abstractmethod
-from .encodec import CompressionModel
+
 from .lm import LMModel
 from .conditioners import ConditioningAttributes
 from .utils.autocast import TorchAutocast
@@ -18,7 +18,7 @@ class BaseGenModel(ABC):
         lm (LMModel): Language model over discrete representations.
         max_duration (float, optional): As we use the top-250 token draw(), we can generate xN sequences.
     """
-    def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
+    def __init__(self, name: str, compression_model, lm: LMModel,
                  max_duration: tp.Optional[float] = None):
         self.name = name
         self.compression_model = compression_model
@@ -131,6 +131,9 @@ class BaseGenModel(ABC):
                                           **self.generation_params)
         else:
             print('<>Long gen ?<>')
+            # print(f'{gen_tokens.shape=}')  # [5, 4, 35]
+            # FLATTEN BATCH AS EXTRA SEQUENCE (THE BATCH IS VIRTUAL - JUST MULTINOMIAL SAMPLING OF N_DRAW TOKENS)
+            gen_tokens = gen_tokens.transpose(0, 1).reshape(4, -1)[None, :, :]
         return gen_tokens
 
     def generate_audio(self, gen_tokens: torch.Tensor) -> torch.Tensor:
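
Note: a small sketch of the reshape above, assuming n_draw=5 draws, K=4 codebooks and T=35 steps as in the commented shape (the hard-coded 4 in the diff is the codebook count):

    import torch
    gen_tokens = torch.randint(0, 2048, (5, 4, 35))          # [n_draw, K, T]
    flat = gen_tokens.transpose(0, 1).reshape(4, -1)[None, :, :]
    print(flat.shape)  # torch.Size([1, 4, 175]) - the 5 draws concatenated in time, per codebook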
audiocraft/lm.py CHANGED
@@ -148,7 +148,7 @@ class LMModel(StreamingModule):
         super().__init__()
         self.cfg_coef = cfg_coef
 
-
+        self.n_draw = 20
         self.condition_provider = condition_provider
         self.fuser = fuser
         self.card = card  # 2048 ?
@@ -255,23 +255,7 @@ class LMModel(StreamingModule):
                            top_p: float = 0.0,
                            cfg_coef: tp.Optional[float] = None,
                            two_step_cfg: tp.Optional[bool] = None) -> torch.Tensor:
-        """Sample next token from the model given a sequence and a set of conditions. The model supports
-        multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).
-
-        Args:
-            sequence (torch.Tensor): Current sequence of shape [B, K, S]
-                with K corresponding to the number of codebooks and S the number of sequence steps.
-                S = 1 in streaming mode, except for the first step that contains a bigger prompt.
-            condition_tensors (dict[str, ConditionType]): Set of conditions. If CFG is used,
-                should be twice the batch size, being the concatenation of the conditions + null conditions.
-            use_sampling (bool): Whether to use a sampling strategy or not.
-            temp (float): Sampling temperature.
-            top_k (int): K for "top-k" sampling.
-            top_p (float): P for "top-p" sampling.
-            cfg_coef (float, optional): Classifier-free guidance coefficient.
-        Returns:
-            next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
-        """
+        """Draws self.n_draw next tokens per step."""
         B = sequence.shape[0]
         cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
         model = self if self._fsdp is None else self._fsdp
@@ -283,9 +267,11 @@ class LMModel(StreamingModule):
         assert isinstance(cfg_conditions, dict)
         condition_tensors = cfg_conditions
         if condition_tensors:
-            # print('\nD\n')
-            # Preparing for CFG, predicting both conditional and unconditional logits.
-            sequence = torch.cat([sequence, sequence], dim=0)
+            print('\nDcat\n')  # enters here
+
+            sequence = torch.cat([sequence, sequence], dim=0)
+            # concatenates in batch, but we only want to run the 1st sequence (the continuation);
+            # the other paths would be built "blindly"
         all_logits = model(
             sequence,
             conditions=[], condition_tensors=condition_tensors)
@@ -298,24 +284,25 @@ class LMModel(StreamingModule):
             print('\nF!\n')
 
 
-        logits = logits.permute(0, 1, 3, 2)  # [B, K, card, T]
-        logits = logits[..., -1]  # [B x K x card]
+        logits = logits.permute(0, 1, 3, 2)  # [1, 4, 2048, 1]
+        # no crop - this is just a squeeze() of the time axis
+        logits = logits[..., -1]  # [1, 4, 2048]
+
 
         # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
-        if use_sampling and temp > 0.0:
-            # print(f'\nR {temp=} {top_p=} {top_k=}\n')  # -------> R temp=1.0 top_p=0.0 top_k=250
-            probs = torch.softmax(logits / temp, dim=-1)
-            if top_p > 0.0:
-                next_token = utils.sample_top_p(probs, p=top_p)
-            elif top_k > 0:
-                next_token = utils.sample_top_k(probs, k=top_k)
-            else:
-                next_token = utils.multinomial(probs, num_samples=1)
-        else:
-            print('\nNeverHere\n')
-
+
+        # print(f'\nR {temp=} {top_p=} {top_k=}\n')  # -------> R temp=1.0 top_p=0.0 top_k=250
+        probs = torch.softmax(logits / temp, dim=-1)
+
+        next_token = utils.sample_top_k(probs, k=top_k, n_draw=self.n_draw)
+
+        # the decoder will smooth the transitions; if we draw a 2nd token we still
+        # need it for the replicas later, so keep the draws as a batch and reshape
+        # at the final time-inversion
+
+        # returns multiple tokens here (batch_size = n_draw)
         return next_token
 
     # GENERATE class revert_codebook_patterns()
@@ -324,7 +311,7 @@ class LMModel(StreamingModule):
                  prompt=None,
                  conditions=[],
                  num_samples=1,  # THIS IS HOW MANY GENERATIONS - A SAMPLE IS A FULL WAV
-                 max_gen_len: int = 256,
+                 max_gen_len=256,  # unduplicated sequence length - the actual length will be n_draw * max_gen_len
                  use_sampling: bool = True,
                  temp: float = 1.0,
                  top_k: int = 250,
@@ -335,6 +322,7 @@ class LMModel(StreamingModule):
                  check: bool = False,
                  callback: tp.Optional[tp.Callable[[int, int], None]] = None,
                  **kwargs) -> torch.Tensor:
+
         print(f'{num_samples=}')
         first_param = next(iter(self.parameters()))
         device = first_param.device
@@ -364,10 +352,10 @@ class LMModel(StreamingModule):
 
         B, K, T = prompt.shape
         start_offset = T
-        assert start_offset < max_gen_len
+
 
-        pattern = self.pattern_provider.get_pattern(max_gen_len)
-        # this token is used as default value for codes that are not generated yet
+        pattern = self.pattern_provider.get_pattern(max_gen_len)  # duplicate sequence
+        # this token is used as default value for codes that are not generated yet ?
        unknown_token = -1
 
 
@@ -375,32 +363,46 @@ class LMModel(StreamingModule):
 
         gen_codes[..., :start_offset] = prompt  # place 0
 
-        gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
-
-        start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
-        # print('\n=', start_offset_sequence, '\n=')  # 1
-        assert start_offset_sequence is not None
+        _gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
+
 
         with self.streaming():
+
             unconditional_state = self.get_streaming_state()
             prev_offset = 0
-            gen_sequence_len = gen_sequence.shape[-1]  # gen_sequence shape is [B, K, S]
+            gen_sequence_len = _gen_sequence.shape[-1]  # gen_sequence shape is [B, K, S]
 
             # --
             # print(mask.shape, mask.sum(), 'MSK LM')
             # torch.Size([4, 39]) tensor(140, device='cuda:0') MSK LM ? fully 1, normal, no special token
             # --
+            duplicate_draw = [
+                _gen_sequence[:, :, 0:1].repeat(self.n_draw, 1, 1)
+            ]
+            # list holding the next tokens - we draw multiple tokens at each time-step
+            # but continue the sequence with only a single next token
 
-            for offset in range(start_offset_sequence, gen_sequence_len):
-                # get current sequence (note that the streaming API is providing the caching over previous offsets)
-
-                curr_sequence = gen_sequence[..., prev_offset:offset]
-                curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1)
+            for offset in range(1, gen_sequence_len):  # start_offset_sequence=1
+                print(f'{offset=}')
+                # starts from 1, not 0, and thus uses 0:1 as the current sequence,
+                # although this is empty / contains -1 ?
+
+                curr_sequence = _gen_sequence[..., prev_offset:offset]
+
 
                 next_token = self._sample_next_token(
-                    curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
-                    cfg_coef=cfg_coef, two_step_cfg=two_step_cfg)
+                    curr_sequence,
+                    cfg_conditions,
+                    unconditional_state,
+                    use_sampling,
+                    temp, top_k, top_p,
+                    cfg_coef=cfg_coef,
+                    two_step_cfg=two_step_cfg)  # [n_draw, 4, 1], e.g. [5, 4, 1]
+                print(f'{next_token.shape=}')
+                # replicates the sequence to hold n_draw sequences, as we generate n_draw tokens per step
 
@@ -418,23 +420,32 @@ class LMModel(StreamingModule):
 
                 # next_token[:] = self.special_token_id  # seanet.embed torch.embedding does not have this - out of bounds in detokenize
 
-
-                # ensure we don't overwrite prompt tokens, we only write over unknown tokens
-
-                gen_sequence[..., offset:offset+1] = torch.where(
-                    gen_sequence[..., offset:offset+1] == unknown_token,
-                    next_token, gen_sequence[..., offset:offset+1]
-                )
+                _gen_sequence[..., offset:offset+1] = next_token[0, :, :]  # gen_sequence.shape=torch.Size([1, 4, 39])
+                # only concatenate 1 token to the 1 continued sequence - preserve the duplicates in
+                duplicate_draw.append(next_token)
                 prev_offset = offset
 
 
             unconditional_state.clear()
+
+        gen_sequence = torch.cat(duplicate_draw, 2)  # [self.n_draw, 4, len_seq]
+
+        # revert the codes as a "batch"
+
+        # in the decoder - flatten:
+        # _, tokd, len_seq = gen_sequence.shape
+        # gen_sequence = gen_sequence.transpose(0, 1).reshape(tokd, self.n_draw * len_seq)[None, :, :]
+
+        print(f' <=> BEFORE CODES {gen_sequence.shape=} {_gen_sequence.shape=}\n')  # ARRIVES here also if special
 
         # revert_pattern_logits ~ NOT CALLED EXPLICITLY
-        out_codes, _, _ = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
+        out_codes, _, _ = pattern.revert_pattern_sequence(gen_sequence,
+                                                          special_token=unknown_token)
+
 
         # set(out_codes.unique().tolist()) - set(gen_sequence.unique().tolist())  # set()
 
@@ -448,7 +459,7 @@ class LMModel(StreamingModule):
         # -> unknown token = -1 or 2048
         # unknown_token=-1
 
-        # print(f' <=> CODES {out_codes.shape=} {out_codes.min()} {out_codes.max()}\n')  # ARRIVES here also if special
+        print(f' <=> CODES {out_codes.shape=} {out_codes.min()} {out_codes.max()}\n')  # ARRIVES here also if special
 
         # unknown_token=-1 gen_sequence.shape=torch.Size([1, 4, 39]) out_codes.shape=torch.Size([1, 4, 35])
         # <=> CODES out_codes.shape=torch.Size([1, 4, 35]) 30 2024
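
Note: a toy version of the duplicate_draw bookkeeping above, assuming n_draw=3 and K=4 (random tokens stand in for _sample_next_token):

    import torch
    n_draw, K, seq_len = 3, 4, 6
    gen = torch.full((1, K, seq_len), -1)              # the single "real" sequence
    draws = [gen[:, :, 0:1].repeat(n_draw, 1, 1)]      # step 0 replicated per draw
    for offset in range(1, seq_len):
        next_token = torch.randint(0, 2048, (n_draw, K, 1))
        gen[..., offset:offset + 1] = next_token[0]    # continue the sequence with draw 0 only
        draws.append(next_token)                       # but keep every draw
    gen_sequence = torch.cat(draws, 2)
    print(gen_sequence.shape)  # torch.Size([3, 4, 6]) - n_draw parallel sequences to revert as a "batch"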
audiocraft/loaders.py CHANGED
@@ -29,7 +29,7 @@ import torch
 
 import audiocraft
 from . import builders
-from .encodec import CompressionModel
+from .encodec import EncodecModel
 
 
 def get_audiocraft_cache_dir() -> tp.Optional[str]:
@@ -75,7 +75,7 @@ def load_compression_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_di
 def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
     pkg = load_compression_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
     if 'pretrained' in pkg:
-        return CompressionModel.get_pretrained(pkg['pretrained'], device=device)
+        return EncodecModel.get_pretrained(pkg['pretrained'], device=device)
     cfg = OmegaConf.create(pkg['xp.cfg'])
     cfg.device = str(device)
     model = builders.get_compression_model(cfg)
audiocraft/transformer.py CHANGED
@@ -661,34 +661,39 @@ class StreamingTransformer(StreamingModule):
 
     def _apply_layer(self, layer, *args, **kwargs):
         method = self.checkpointing
+        print(f'{method=}')
         if method == 'none':
-            return layer(*args, **kwargs)
-        elif method == 'torch':
-            return torch_checkpoint(layer, *args, use_reentrant=False, **kwargs)
-        elif method.startswith('xformers'):
-            from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy
-            if method == 'xformers_default':
-                # those operations will be saved, and not recomputed.
-                # According to Francisco we can get smarter policies but this is a good start.
-                allow_list = [
-                    "xformers.efficient_attention_forward_cutlass.default",
-                    "xformers_flash.flash_fwd.default",
-                    "aten.addmm.default",
-                    "aten.mm.default",
-                ]
-            elif method == 'xformers_mm':
-                # those operations will be saved, and not recomputed.
-                # According to Francisco we can get smarter policies but this is a good start.
-                allow_list = [
-                    "aten.addmm.default",
-                    "aten.mm.default",
-                ]
-            else:
-                raise ValueError(f"xformers checkpointing xformers policy {method} is not known.")
-            policy_fn = _get_default_policy(allow_list)
-            return checkpoint(layer, *args, policy_fn=policy_fn, **kwargs)
-        else:
-            raise ValueError(f"Checkpointing method {method} is unknown.")
+            print([i.shape for i in args])
+            x = layer(*args, **kwargs)  # [10, 1, 1536] - probably does not detect the batch somewhere
+            return x
+        # elif method == 'torch':
+        #     print('TORCH')
+        #     return torch_checkpoint(layer, *args, use_reentrant=False, **kwargs)
+        # elif method.startswith('xformers'):
+        #     print('XFORMERS')
+        #     from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy
+        #     if method == 'xformers_default':
+        #         # those operations will be saved, and not recomputed.
+        #         # According to Francisco we can get smarter policies but this is a good start.
+        #         allow_list = [
+        #             "xformers.efficient_attention_forward_cutlass.default",
+        #             "xformers_flash.flash_fwd.default",
+        #             "aten.addmm.default",
+        #             "aten.mm.default",
+        #         ]
+        #     elif method == 'xformers_mm':
+        #         # those operations will be saved, and not recomputed.
+        #         # According to Francisco we can get smarter policies but this is a good start.
+        #         allow_list = [
+        #             "aten.addmm.default",
+        #             "aten.mm.default",
+        #         ]
+        #     else:
+        #         raise ValueError(f"xformers checkpointing xformers policy {method} is not known.")
+        #     policy_fn = _get_default_policy(allow_list)
+        #     return checkpoint(layer, *args, policy_fn=policy_fn, **kwargs)
+        # else:
+        #     raise ValueError(f"Checkpointing method {method} is unknown.")
 
     def forward(self, x: torch.Tensor, *args, **kwargs):
         B, T, C = x.shape
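
Note: one reading of the [10, 1, 1536] shape printed above - an assumption, not confirmed by the commit - is n_draw sequences doubled for classifier-free guidance at a single streaming step:

    # assumption only: batch entering each layer = n_draw draws x 2 for CFG
    n_draw, cfg_factor, d_model = 5, 2, 1536
    B = n_draw * cfg_factor
    print([B, 1, d_model])  # [10, 1, 1536] -> [B, T=1, C] at one streaming step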
audiocraft/utils/cluster.py DELETED
@@ -1,75 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""
-Utility functions for SLURM configuration and cluster settings.
-"""
-
-from enum import Enum
-import os
-import socket
-import typing as tp
-
-import omegaconf
-
-
-class ClusterType(Enum):
-    AWS = "aws"
-    FAIR = "fair"
-    RSC = "rsc"
-    LOCAL_DARWIN = "darwin"
-    DEFAULT = "default"  # used for any other cluster.
-
-
-def _guess_cluster_type() -> ClusterType:
-    uname = os.uname()
-    fqdn = socket.getfqdn()
-    if uname.sysname == "Linux" and (uname.release.endswith("-aws") or ".ec2" in fqdn):
-        return ClusterType.AWS
-
-    if fqdn.endswith(".fair"):
-        return ClusterType.FAIR
-
-    if fqdn.endswith(".facebook.com"):
-        return ClusterType.RSC
-
-    if uname.sysname == "Darwin":
-        return ClusterType.LOCAL_DARWIN
-
-    return ClusterType.DEFAULT
-
-
-def get_cluster_type(
-    cluster_type: tp.Optional[ClusterType] = None,
-) -> tp.Optional[ClusterType]:
-    if cluster_type is None:
-        return _guess_cluster_type()
-
-    return cluster_type
-
-
-def get_slurm_parameters(
-    cfg: omegaconf.DictConfig, cluster_type: tp.Optional[ClusterType] = None
-) -> omegaconf.DictConfig:
-    """Update SLURM parameters in configuration based on cluster type.
-    If the cluster type is not specified, it is inferred automatically.
-    """
-    from ..environment import AudioCraftEnvironment
-    cluster_type = get_cluster_type(cluster_type)
-    # apply cluster-specific adjustments
-    if cluster_type == ClusterType.AWS:
-        cfg["mem_per_gpu"] = None
-        cfg["constraint"] = None
-        cfg["setup"] = []
-    elif cluster_type == ClusterType.RSC:
-        cfg["mem_per_gpu"] = None
-        cfg["setup"] = []
-        cfg["constraint"] = None
-        cfg["partition"] = "learn"
-        slurm_exclude = AudioCraftEnvironment.get_slurm_exclude()
-        if slurm_exclude is not None:
-            cfg["exclude"] = slurm_exclude
-    return cfg
audiocraft/utils/deadlock.py DELETED
@@ -1,58 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import logging
-import os
-from queue import Queue, Empty
-import signal
-import sys
-import threading
-import traceback
-
-logger = logging.getLogger(__name__)
-
-
-class DeadlockDetect:
-    def __init__(self, use: bool = False, timeout: float = 120.):
-        self.use = use
-        self.timeout = timeout
-        self._queue: Queue = Queue()
-
-    def update(self, stage: str):
-        if self.use:
-            self._queue.put(stage)
-
-    def __enter__(self):
-        if self.use:
-            self._thread = threading.Thread(target=self._detector_thread)
-            self._thread.start()
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        if self.use:
-            self._queue.put(None)
-            self._thread.join()
-
-    def _detector_thread(self):
-        logger.debug("Deadlock detector started")
-        last_stage = "init"
-        while True:
-            try:
-                stage = self._queue.get(timeout=self.timeout)
-            except Empty:
-                break
-            if stage is None:
-                logger.debug("Exiting deadlock detector thread")
-                return
-            else:
-                last_stage = stage
-        logger.error("Deadlock detector timed out, last stage was %s", last_stage)
-        for th in threading.enumerate():
-            print(th, file=sys.stderr)
-            traceback.print_stack(sys._current_frames()[th.ident])
-            print(file=sys.stderr)
-        sys.stdout.flush()
-        sys.stderr.flush()
-        os.kill(os.getpid(), signal.SIGKILL)
audiocraft/utils/utils.py CHANGED
@@ -86,47 +86,24 @@ def get_dataset_from_loader(dataloader):
     return dataset
 
 
-def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
-    """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.
-
-    Args:
-        input (torch.Tensor): The input tensor containing probabilities.
-        num_samples (int): Number of samples to draw.
-        replacement (bool): Whether to draw with replacement or not.
-    Keyword args:
-        generator (torch.Generator): A pseudorandom number generator for sampling.
-    Returns:
-        torch.Tensor: Last dimension contains num_samples indices
-            sampled from the multinomial probability distribution
-            located in the last dimension of tensor input.
-    """
-    input_ = input.reshape(-1, input.shape[-1])
-    output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
-    output = output_.reshape(*list(input.shape[:-1]), -1)
-    # print('MULTINOmial', input.shape, output.shape)  # MULTINOmial torch.Size([1, 4, 2048]) torch.Size([1, 4, 1])
-    # output = input[..., 0:1]
-    return output
 
 
-def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
-    """Sample next token from top K values along the last dimension of the input probs tensor.
-
-    Args:
-        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
-        k (int): The k in "top-k".
-    Returns:
-        torch.Tensor: Sampled tokens.
+def sample_top_k(p, k, n_draw=None):
+    """
+    p: probabilities over the 2048-token cardinality
+    n_draw: how many tokens to sample (for duplicate elongation)
     """
-    top_k_value, i250 = torch.topk(probs, k, dim=-1)  # probs: [1, 4, 2048]
+    top_k_value, i250 = torch.topk(p, k, dim=-1)  # probs: [1, 4, 2048]
     min_value_top_k = top_k_value[..., [-1]]
-    probs *= (probs >= min_value_top_k).float()  # multiply everything >= min top-k by 1, zeroing the others
-    probs.div_(probs.sum(dim=-1, keepdim=True))  # renormalize by the sum so multinomial gets a distribution
-    next_token = multinomial(probs, num_samples=1)
-    # instead of choosing via multinomial, what happens if we take all 250 top-k tokens?
-    # probs.shape=torch.Size([1, 4, 2048]); print(next_token, f'{probs.shape=}', 'h')  # [1, 4, 1], next token is 4 tokens
-    # next_token = i250
-    return next_token
+    p *= (p >= min_value_top_k).float()
+    p.div_(p.sum(dim=-1, keepdim=True))
+    # -- next_token = multinomial(probs, num_samples=n_draw)
+    p_ = p.reshape(-1, p.shape[-1])
+    out = torch.multinomial(p_,
+                            num_samples=n_draw,
+                            replacement=False)  # [4, n_draw]
+    return out.transpose(0, 1)[:, :, None]  # [n_draw, 4, 1]
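
Note: a usage sketch for the rewritten sample_top_k (assumed in scope as defined above), with probabilities over card=2048 tokens for K=4 codebooks at a single step; shapes follow the comments in the diff:

    import torch
    logits = torch.randn(1, 4, 2048)
    probs = torch.softmax(logits, dim=-1)
    tokens = sample_top_k(probs, k=250, n_draw=5)   # draws 5 distinct tokens per codebook
    print(tokens.shape)  # torch.Size([5, 4, 1]) - [n_draw, K, 1], draw 0 continues the sequence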