Dionyssos committed
Commit 27d24be · 1 Parent(s): 8639464

index out of bounds 2048 @ dequantize()
audiocraft/audiogen.py CHANGED
@@ -12,24 +12,12 @@ and provide easy access to the generation API.
 import typing as tp
 import torch
 
-from audiocraft.encodec import CompressionModel
 from audiocraft.genmodel import BaseGenModel
-from audiocraft.lm import LMModel
 from audiocraft.loaders import load_compression_model, load_lm_model
 
 class AudioGen(BaseGenModel):
-    """AudioGen main model with convenient generation API.
-
-    Args:
-        name (str): name of the model.
-        compression_model (CompressionModel): Compression model
-            used to map audio to invertible discrete representations.
-        lm (LMModel): Language model over discrete representations.
-        max_duration (float, optional): maximum duration the model can produce,
-            otherwise, inferred from the training params.
-    """
-    def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
-                 max_duration: tp.Optional[float] = None):
+
+    def __init__(self, name, compression_model, lm, max_duration=None):
         # print(f'Using {compression_model=}\n-----=-----')
         super().__init__(name, compression_model, lm, max_duration)
         self.set_generation_params(duration=5)  # default duration
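After this change the constructor takes plain positional arguments and the type-only imports are gone. A minimal construction sketch, assuming this fork keeps the upstream loader signatures (load_compression_model(name_or_path, device=...), load_lm_model(name_or_path, device=...)) and using 'facebook/audiogen-medium' only as an illustrative checkpoint name:

    # sketch only; the checkpoint name and device handling are assumptions, not part of this commit
    import torch
    from audiocraft.loaders import load_compression_model, load_lm_model
    from audiocraft.audiogen import AudioGen

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    compression_model = load_compression_model('facebook/audiogen-medium', device=device)
    lm = load_lm_model('facebook/audiogen-medium', device=device)
    model = AudioGen('audiogen-medium', compression_model, lm)  # set_generation_params(duration=5) runs in __init__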
audiocraft/codebooks_patterns.py CHANGED
@@ -6,11 +6,9 @@
 
 from collections import namedtuple
 from dataclasses import dataclass
-from functools import lru_cache
+
 import logging
 import typing as tp
-
-from abc import ABC, abstractmethod
 import torch
 
 LayoutCoord = namedtuple('LayoutCoord', ['t', 'q'])  # (timestep, codebook index)
@@ -50,8 +48,8 @@ class Pattern:
     def __post_init__(self):
         assert len(self.layout) > 0
         self._validate_layout()
-        self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
-        self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
+        self._build_reverted_sequence_scatter_indexes = self._build_reverted_sequence_scatter_indexes
+        self._build_pattern_sequence_scatter_indexes = self._build_pattern_sequence_scatter_indexes
         print("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))
 
     def _validate_layout(self):
@@ -74,6 +72,53 @@ class Pattern:
                 # each sequence step contains at max 1 coordinate per codebook
                 assert len(qs) == len(seq_coords), \
                     f"Multiple entries for a same codebook are found at step {s}"
+                print(f'{qs=}\n\n\n\n QS VALIDATE LAYOUT')
+                # Would qs show something different from {0, 1, 2, 3} if q_timesteps contained the special index?
+                # Observed output:
+                #   QS VALIDATE LAYOUT  qs={0, 1}
+                #   QS VALIDATE LAYOUT  qs={0, 1, 2}
+                #   QS VALIDATE LAYOUT  qs={0, 1, 2, 3}   (then {0, 1, 2, 3} for every remaining step)
 
     @property
     def num_sequence_steps(self):
@@ -151,22 +196,10 @@ class Pattern:
         mask = torch.from_numpy(mask).to(device)
         return indexes, mask
 
-    def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
-        """Build sequence corresponding to the pattern from the input tensor z.
-        The sequence is built using up to sequence_steps if specified, and non-pattern
-        coordinates are filled with the special token.
-
-        Args:
-            z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
-            special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
-            keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
-                Steps that are beyond valid steps will be replaced by the special_token in that case.
-        Returns:
-            values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
-                corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
-            indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
-            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
-        """
+    def build_pattern_sequence(self,
+                               z,
+                               special_token,
+                               keep_only_valid_steps=False):
         B, K, T = z.shape
         indexes, mask = self._build_pattern_sequence_scatter_indexes(
             T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
@@ -176,6 +209,11 @@ class Pattern:
         z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
         values = z[:, indexes.view(-1)]
         values = values.view(B, K, indexes.shape[-1])
+
+        # print(values.shape, indexes.shape, mask.shape, 'BUILD PATTERN')
+        # --
+        # torch.Size([1, 4, 39]) torch.Size([4, 39]) torch.Size([4, 39]) BUILD PATTERN
+
         return values, indexes, mask
 
     def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
@@ -216,25 +254,26 @@ class Pattern:
             if s < sequence_steps:
                 for code in sequence_codes:
                     if code.t < timesteps:
-                        indexes[code.q, code.t] = s + code.q * sequence_steps
+                        indexes[code.q, code.t] = s + code.q * sequence_steps  # note the jump: the codes are linearised across codebooks
                         mask[code.q, code.t] = 1
         indexes = torch.from_numpy(indexes).to(device)
         mask = torch.from_numpy(mask).to(device)
         return indexes, mask
 
-    def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
-        """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
-        The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
-        are filled with the special token.
+    def revert_pattern_sequence(self,
+                                s,
+                                special_token,
+                                keep_only_valid_steps=False):
+        """NOTE: the special token is NOT deleted here.
 
         Args:
             s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
             special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
         Returns:
-            values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T
-                corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
+            values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T],
+                with T the original number of timesteps.
             indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
-            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
+            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
+                Should this mask delete the special token id?
         """
         B, K, S = s.shape
         indexes, mask = self._build_reverted_sequence_scatter_indexes(
@@ -245,64 +284,44 @@ class Pattern:
         s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
         values = s[:, indexes.view(-1)]
         values = values.view(B, K, indexes.shape[-1])
+
         return values, indexes, mask
-
-    def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
-        """Revert model logits obtained on a sequence built from the pattern
-        back to a tensor matching the original sequence.
-
-        This method is similar to ``revert_pattern_sequence`` with the following specificities:
-        1. It is designed to work with the extra cardinality dimension
-        2. We return the logits for the first sequence item that matches the special_token and
-           which matching target in the original sequence is the first item of the sequence,
-           while we skip the last logits as there is no matching target
-        """
-        B, card, K, S = logits.shape
-        indexes, mask = self._build_reverted_sequence_scatter_indexes(
-            S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
-        )
-        logits = logits.reshape(B, card, -1)
-        # we append the special token as the last index of our flattened z tensor
-        logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1)  # [B, card, K x S]
-        values = logits[:, :, indexes.view(-1)]
-        values = values.view(B, card, K, indexes.shape[-1])
-        return values, indexes, mask
-
-
-class CodebooksPatternProvider(ABC):
-    """Abstraction around providing pattern for interleaving codebooks.
-
-    The CodebooksPatternProvider abstraction allows to implement various strategies to
-    define interleaving pattern of sequences composed of multiple codebooks. For a given
-    number of codebooks `n_q`, the pattern provider can generate a specified pattern
-    corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern
-    can be used to construct a new sequence from the original codes respecting the specified
-    pattern. The pattern is defined as a list of list of code coordinates, code coordinate
-    being a tuple with the original timestep and codebook to build the new sequence.
-    Note that all patterns must start with an empty list that is then used to insert a first
-    sequence step of special tokens in the newly generated sequence.
-
-    Args:
-        n_q (int): number of codebooks.
-        cached (bool): if True, patterns for a given length are cached. In general
-            that should be true for efficiency reason to avoid synchronization points.
-    """
-    def __init__(self, n_q: int, cached: bool = True):
-        assert n_q > 0
-        self.n_q = n_q
-        self.get_pattern = lru_cache(100)(self.get_pattern)  # type: ignore
-
-    @abstractmethod
-    def get_pattern(self, timesteps: int) -> Pattern:
-        """Builds pattern with specific interleaving between codebooks.
-
-        Args:
-            timesteps (int): Total number of timesteps.
-        """
-        raise NotImplementedError()
-
-
-class DelayedPatternProvider(CodebooksPatternProvider):
+
+    # def revert_pattern_logits(self, logits,
+    #                           special_token,
+    #                           keep_only_valid_steps=False):
+    #     """Similar to ``revert_pattern_sequence`` with the following specificities:
+    #     1. It is designed to work with the extra cardinality dimension
+    #     2. We return the logits for the first sequence item that matches the special_token and
+    #        which matching target in the original sequence is the first item of the sequence,
+    #        while we skip the last logits as there is no matching target
+    #     """
+    #     B, card, K, S = logits.shape
+    #     indexes, mask = self._build_reverted_sequence_scatter_indexes(
+    #         S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
+    #     )
+    #     logits = logits.reshape(B, card, -1)
+    #     # we append the special token as the last index of our flattened z tensor
+    #     logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1)  # [B, card, K x S]
+    #     values = logits[:, :, indexes.view(-1)]
+    #     values = values.view(B, card, K, indexes.shape[-1])
+    #     return values, indexes, mask
+
+
+class DelayedPatternProvider():
     """Provider for delayed pattern across delayed codebooks.
     Codebooks are delayed in the sequence and sequence steps will contain codebooks
     from different timesteps.
@@ -325,9 +344,12 @@ class DelayedPatternProvider(CodebooksPatternProvider):
         flatten_first (int): Flatten the first N timesteps.
         empty_initial (int): Prepend with N empty list of coordinates.
     """
-    def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
-                 flatten_first: int = 0, empty_initial: int = 0):
-        super().__init__(n_q)
+    def __init__(self,
+                 n_q,
+                 delays,
+                 flatten_first=0,
+                 empty_initial=0):
+        self.n_q = n_q
         if delays is None:
             delays = list(range(n_q))
         self.delays = delays
@@ -336,8 +358,12 @@ class DelayedPatternProvider(CodebooksPatternProvider):
         assert len(self.delays) == self.n_q
         assert sorted(self.delays) == self.delays
 
-    def get_pattern(self, timesteps: int) -> Pattern:
-        omit_special_token = self.empty_initial < 0
+    def get_pattern(self, timesteps):
+        # get the pattern for the desired length
+        # print(f'{timesteps=} GET_PATTERN')  # 35
+        # print(f'{self.empty_initial=}')
+        omit_special_token = self.empty_initial < 0  # False, as empty_initial = 0 (unset)
+
         out: PatternLayout = [] if omit_special_token else [[]]
         max_delay = max(self.delays)
         if self.empty_initial:
@@ -353,6 +379,7 @@ class DelayedPatternProvider(CodebooksPatternProvider):
                 if t_for_q >= self.flatten_first:
                     v.append(LayoutCoord(t_for_q, q))
             out.append(v)
+        # print(self.n_q, 'N_Q in PATTERN')  # 4 N_Q in PATTERN
         return Pattern(out, n_q=self.n_q, timesteps=timesteps)
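The shape comments above follow from the delay layout itself: with n_q = 4 codebooks, delays = [0, 1, 2, 3] and timesteps = 35, the layout has one leading special-token step plus 35 + max(delays) = 38 delayed steps, i.e. 39 sequence steps, which matches the [1, 4, 39] printed in build_pattern_sequence, while revert_pattern_sequence trims back to 35. A standalone sketch of that layout (toy reimplementation with assumed sizes, not this file's API):

    # sketch of the delayed interleaving; mirrors DelayedPatternProvider.get_pattern for the default delays
    n_q, timesteps, delays = 4, 35, [0, 1, 2, 3]

    layout = [[]]                                   # step 0 holds only the special token
    for t in range(timesteps + max(delays)):
        # coordinates (original_timestep, codebook); near the end, t - delay may reach
        # timesteps or beyond, and those positions are later filled with the special token
        layout.append([(t - d, q) for q, d in enumerate(delays) if t - d >= 0])

    print(len(layout))      # 39 -> sequence length S for T = 35
    print(layout[1])        # [(0, 0)]                          only codebook 0 is defined yet
    print(layout[4])        # [(3, 0), (2, 1), (1, 2), (0, 3)]  all 4 codebooks defined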
 
audiocraft/genmodel.py CHANGED
@@ -44,7 +44,7 @@ class BaseGenModel(ABC):
         self.duration = self.max_duration
         self.device = next(iter(lm.parameters())).device
         self.generation_params={}
-        self._progress_callback=None
+
         if self.device.type == 'cpu':
             self.autocast = TorchAutocast(enabled=False)
         else:
@@ -68,9 +68,7 @@
         """Audio channels of the generated audio."""
         return self.compression_model.channels
 
-    def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
-        """Override the default progress callback."""
-        self._progress_callback = progress_callback
+
 
     @abstractmethod
     def set_generation_params(self, *args, **kwargs):
@@ -119,25 +117,16 @@
         max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
         current_gen_offset: int = 0
 
-        def _progress_callback(generated_tokens: int, tokens_to_generate: int):
-            generated_tokens += current_gen_offset
-            if self._progress_callback is not None:
-                # Note that total_gen_len might be quite wrong depending on the
-                # codebook pattern used, but with delay it is almost accurate.
-                self._progress_callback(generated_tokens, tokens_to_generate)
-            else:
-                print(f'{generated_tokens: 6d} / {tokens_to_generate: 6d}', end='\r')
+
 
-        callback = None
-        if progress:
-            callback = _progress_callback
+
 
         if self.duration <= self.max_duration:
             # generate by sampling from LM, simple case.
 
             with self.autocast:
                 gen_tokens = self.lm.generate(conditions=attributes,
-                                              callback=callback,
+                                              callback=None,
                                               max_gen_len=total_gen_len,
                                               **self.generation_params)
         else:
audiocraft/lm.py CHANGED
@@ -373,8 +373,8 @@ class LMModel(StreamingModule):
 
         gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device)
 
-        gen_codes[..., :start_offset] = prompt
-        # create the gen_sequence with proper interleaving from the pattern: [B, K, S]
+        gen_codes[..., :start_offset] = prompt  # the prompt occupies positions [0, start_offset)
+
         gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
 
         start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
@@ -397,26 +397,26 @@
 
             curr_sequence = gen_sequence[..., prev_offset:offset]
             curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1)
-            if check:
-                # check coherence between mask and sequence
-                assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all()
-                # should never happen as gen_sequence is filled progressively
-                assert not (curr_sequence == unknown_token).any()
-            # sample next token from the model, next token shape is [B, K, 1]
+
             next_token = self._sample_next_token(
                 curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
                 cfg_coef=cfg_coef, two_step_cfg=two_step_cfg)
+
             # ensure the tokens that should be masked are properly set to special_token_id
             # as the model never output special_token_id
-            valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
+            # valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
 
             # next_token[~valid_mask] = self.special_token_id
 
             # print(f'{unconditional_state=} \n
             # print('Set All to Special')
 
-            # RUNS with = 2047 just different of self.special_token_id -> 2047 is drill noise
-            # next_token[:] = self.special_token_id
+            # Runs with 2047, i.e. anything different from self.special_token_id, although 2047 sounds like drill noise.
+            # Is special_token_id only a filler for the codebook pattern?
+
+            # next_token[:] = self.special_token_id  # seanet's codebook embedding (torch.embedding) has no entry for this id -> index out of bounds in detokenize
 
@@ -427,17 +427,34 @@
                 next_token, gen_sequence[..., offset:offset+1]
             )
             prev_offset = offset
-            if callback is not None:
-                callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
+
         unconditional_state.clear()
-
+
+        # revert_pattern_logits is NOT called explicitly here
         out_codes, _, _ = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
-        print(f' <=> CODES {out_codes.shape=} {out_codes.min()} {out_codes.max()}\n')  # ARRIVES here also if special
-        out_start_offset = start_offset if remove_prompts else 0
-        out_codes = out_codes[..., out_start_offset:max_gen_len]
+
+        # set(out_codes.unique().tolist()) - set(gen_sequence.unique().tolist())  # -> set()
+        # The unique values are the same, so is this only a rearrangement?
+        # Are some parts ignored or re-arranged?
+
+        # print(f'{unknown_token=} {gen_sequence.shape=} {out_codes.shape=}')
+        # -> is unknown_token -1 or 2048? Observed: -1
+        # print(f' <=> CODES {out_codes.shape=} {out_codes.min()} {out_codes.max()}\n')  # arrives here also when forcing the special token
+        # Observed:
+        #   unknown_token=-1 gen_sequence.shape=torch.Size([1, 4, 39]) out_codes.shape=torch.Size([1, 4, 35])
+        #    <=> CODES out_codes.shape=torch.Size([1, 4, 35]) 30 2024
 
-        # ensure the returned codes are all valid
-        # assert (out_codes >= 0).all() and (out_codes <= self.card).all()
-        return out_codes
+        return out_codes  # supposedly still contains the prompt
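Given the commit title, the practical invariant here is that every id in out_codes must be a real codebook entry: in upstream audiocraft special_token_id equals self.card (2048 here), so valid ids are 0..2047 and a leaked 2048 only crashes later, inside dequantize(). A possible guard before the return (a sketch, not part of this commit; it is stricter than the removed assert, which allowed values up to self.card):

        # hypothetical bounds check before detokenization; 'out_codes' and 'self.card' are this file's names
        bad = (out_codes < 0) | (out_codes >= self.card)   # in upstream audiocraft, special_token_id == card == 2048
        if bad.any():
            raise ValueError(
                f'{int(bad.sum())} out-of-range code ids '
                f'(min={int(out_codes.min())}, max={int(out_codes.max())}); '
                f'these would hit "index out of bounds" in EuclideanCodebook.dequantize()')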
audiocraft/vq.py CHANGED
@@ -64,6 +64,8 @@ class EuclideanCodebook(nn.Module):
         return embed_ind.view(*shape[:-1])
 
     def dequantize(self, embed_ind):
+        # embed_ind[0] = 2048
+        # print('MAX MAX MAX', embed_ind.shape)
         quantize = F.embedding(embed_ind, self.embed)
         # print('\n\nDE QUANT\n\n', quantize.shape)  # (1, 35, 128) -> also arrives here for special_token
         return quantize
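The commit title can be reproduced in isolation: F.embedding looks ids up in a [codebook_size, dim] table, so an id equal to codebook_size (the special token 2048 for a 2048-entry codebook) is out of range. A minimal sketch on CPU, with the 2048 x 128 shape assumed to match this codebook:

    # standalone reproduction sketch; 'embed' stands in for EuclideanCodebook.embed
    import torch
    import torch.nn.functional as F

    embed = torch.randn(2048, 128)          # valid ids are 0..2047
    ok = torch.tensor([[30, 2024, 2047]])
    print(F.embedding(ok, embed).shape)     # torch.Size([1, 3, 128])

    bad = torch.tensor([[2048]])            # special_token_id leaking into dequantize()
    try:
        F.embedding(bad, embed)
    except IndexError as err:
        print(err)                          # index out of range in self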