Dionyssos committed on
Commit 8639464 · 1 Parent(s): d8e2a3d

prompt_token defaults to torch.zeros()
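In short: LMModel.generate no longer infers the batch size from its inputs, and when no audio prompt is given, generation starts from an empty torch.zeros token tensor. A minimal sketch of that default, distilled from the lm.py diff below (the codebook count is an assumption; EnCodec-based checkpoints typically use 4):

    import torch

    # With no audio prompt, generation starts from an empty
    # [num_samples, num_codebooks, 0] LongTensor and sampled tokens
    # are appended from there.
    num_samples = 1     # one sample is one full wav (see lm.py diff)
    num_codebooks = 4   # assumption: typical EnCodec codebook count
    prompt = torch.zeros((num_samples, num_codebooks, 0), dtype=torch.long)
    B, K, T = prompt.shape
    print(B, K, T)      # -> 1 4 0, i.e. zero prompt tokens, generate from scratch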
audiocraft/audiogen.py CHANGED
@@ -16,37 +16,6 @@ from audiocraft.encodec import CompressionModel
 from audiocraft.genmodel import BaseGenModel
 from audiocraft.lm import LMModel
 from audiocraft.loaders import load_compression_model, load_lm_model
-from .utils.audio_utils import f32_pcm, normalize_audio
-
-
-def audio_write(stem_name,
-                wav,
-                sample_rate,
-                format='wav',
-                mp3_rate=320,
-                ogg_rate=None,
-                normalize=True,
-                strategy='peak',
-                peak_clip_headroom_db=1,
-                rms_headroom_db=18,
-                loudness_headroom_db=14,
-                loudness_compressor=False,
-                log_clipping=True,
-                make_parent_dir=True,
-                add_suffix=True):
-
-    assert wav.dtype.is_floating_point, "wav is not floating point"
-    if wav.dim() == 1:
-        wav = wav[None]
-    elif wav.dim() > 2:
-        raise ValueError("Input wav should be at most 2 dimension.")
-    assert wav.isfinite().all()
-    wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
-                          rms_headroom_db, loudness_headroom_db, loudness_compressor,
-                          log_clipping=log_clipping, sample_rate=sample_rate,
-                          stem_name=str(stem_name))
-    return wav
-# ===
 
 class AudioGen(BaseGenModel):
     """AudioGen main model with convenient generation API.
audiocraft/genmodel.py CHANGED
@@ -1,12 +1,10 @@
-from abc import ABC, abstractmethod
 import typing as tp
-
 import omegaconf
 import torch
 
+from abc import ABC, abstractmethod
 from .encodec import CompressionModel
 from .lm import LMModel
-from .utils.audio_utils import convert_audio
 from .conditioners import ConditioningAttributes
 from .utils.autocast import TorchAutocast
 
@@ -16,11 +14,9 @@ class BaseGenModel(ABC):
 
     Args:
         name (str): name of the model.
-        compression_model (CompressionModel): Compression model
-            used to map audio to invertible discrete representations.
-        lm (LMModel): Language model over discrete representations.
-        max_duration (float, optional): maximum duration the model can produce,
-            otherwise, inferred from the training params.
+        compression_model (CompressionModel): EnCodec compression model with a SEANet decoder.
+        lm (LMModel): Language model over discrete representations.
+        max_duration (float, optional): as sampling draws each token from the top-250 logits,
+            we can generate xN different sequences.
     """
     def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
                  max_duration: tp.Optional[float] = None):
@@ -46,19 +42,16 @@ class BaseGenModel(ABC):
 
         self.max_duration: float = max_duration
         self.duration = self.max_duration
-
-        # self.extend_stride is the length of audio extension when generating samples longer
-        # than self.max_duration. NOTE: the derived class must set self.extend_stride to a
-        # positive float value when generating with self.duration > self.max_duration.
-        self.extend_stride: tp.Optional[float] = None
         self.device = next(iter(lm.parameters())).device
-        self.generation_params: dict = {}
-        self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
+        self.generation_params = {}
+        self._progress_callback = None
         if self.device.type == 'cpu':
             self.autocast = TorchAutocast(enabled=False)
         else:
             self.autocast = TorchAutocast(
-                enabled=True, device_type=self.device.type, dtype=torch.float16)
+                enabled=True,
+                device_type=self.device.type,
+                dtype=torch.float16)
 
     @property
     def frame_rate(self) -> float:
@@ -92,64 +85,36 @@ class BaseGenModel(ABC):
     @torch.no_grad()
     def _prepare_tokens_and_attributes(
             self,
-            descriptions: tp.Sequence[tp.Optional[str]],
-            prompt: tp.Optional[torch.Tensor],
-    ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
-        """Prepare model inputs.
-
-        Args:
-            descriptions (list of str): A list of strings used as text conditioning.
-            prompt (torch.Tensor): A batch of waveforms used for continuation.
-        """
+            descriptions,
+            prompt,
+            ):
         attributes = [
-            ConditioningAttributes(text={'description': description})
-            for description in descriptions]
-
-        if prompt is not None:
-            if descriptions is not None:
-                assert len(descriptions) == len(prompt), "Prompt and nb. descriptions doesn't match"
-            prompt = prompt.to(self.device)
-            prompt_tokens, scale = self.compression_model.encode(prompt)
-            assert scale is None
-        else:
-            prompt_tokens = None
+            ConditioningAttributes(text={'description': description}) for description in descriptions]
+        prompt_tokens = None
         return attributes, prompt_tokens
 
-    def generate_unconditional(self, num_samples: int, progress: bool = False,
-                               return_tokens: bool = False) -> tp.Union[torch.Tensor,
-                                                                        tp.Tuple[torch.Tensor, torch.Tensor]]:
-        """Generate samples in an unconditional manner.
-
-        Args:
-            num_samples (int): Number of samples to be generated.
-            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
-        """
+    def generate_unconditional(self,
+                               num_samples,
+                               progress=False,
+                               return_tokens=False):
         descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
-        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
-        tokens = self._generate_tokens(attributes, prompt_tokens, progress)
+        attributes, _ = self._prepare_tokens_and_attributes(descriptions, None)
+        tokens = self._generate_tokens(attributes)
         if return_tokens:
             return self.generate_audio(tokens), tokens
         return self.generate_audio(tokens)
 
     def generate(self, descriptions, progress=False, return_tokens=False):
-        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
-        assert prompt_tokens is None
-        tokens = self._generate_tokens(attributes, prompt_tokens, progress)
+        attributes, _ = self._prepare_tokens_and_attributes(descriptions, None)
+        tokens = self._generate_tokens(attributes)
         if return_tokens:
             return self.generate_audio(tokens), tokens
         return self.generate_audio(tokens)
 
-    def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
-                         prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
-        """Generate discrete audio tokens given audio prompt and/or conditions.
-
-        Args:
-            attributes (list of ConditioningAttributes): Conditions used for generation (here text).
-            prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
-            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
-        Returns:
-            torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
-        """
+    def _generate_tokens(self, attributes,
+                         prompt_tokens=None,
+                         progress=False):
+
         total_gen_len = int(self.duration * self.frame_rate)
         max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
         current_gen_offset: int = 0
@@ -163,49 +128,20 @@ class BaseGenModel(ABC):
             else:
                 print(f'{generated_tokens: 6d} / {tokens_to_generate: 6d}', end='\r')
 
-        if prompt_tokens is not None:
-            assert max_prompt_len >= prompt_tokens.shape[-1], \
-                "Prompt is longer than audio to generate"
-
         callback = None
         if progress:
             callback = _progress_callback
 
         if self.duration <= self.max_duration:
             # generate by sampling from LM, simple case.
+
             with self.autocast:
-                gen_tokens = self.lm.generate(
-                    prompt_tokens, attributes,
-                    callback=callback, max_gen_len=total_gen_len, **self.generation_params)
-
+                gen_tokens = self.lm.generate(conditions=attributes,
+                                              callback=callback,
+                                              max_gen_len=total_gen_len,
+                                              **self.generation_params)
         else:
-            assert self.extend_stride is not None, "Stride should be defined to generate beyond max_duration"
-            assert self.extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
-            all_tokens = []
-            if prompt_tokens is None:
-                prompt_length = 0
-            else:
-                all_tokens.append(prompt_tokens)
-                prompt_length = prompt_tokens.shape[-1]
-
-            stride_tokens = int(self.frame_rate * self.extend_stride)
-            while current_gen_offset + prompt_length < total_gen_len:
-                time_offset = current_gen_offset / self.frame_rate
-                chunk_duration = min(self.duration - time_offset, self.max_duration)
-                max_gen_len = int(chunk_duration * self.frame_rate)
-                with self.autocast:
-                    gen_tokens = self.lm.generate(
-                        prompt_tokens, attributes,
-                        callback=callback, max_gen_len=max_gen_len, **self.generation_params)
-                if prompt_tokens is None:
-                    all_tokens.append(gen_tokens)
-                else:
-                    all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
-                prompt_tokens = gen_tokens[:, :, stride_tokens:]
-                prompt_length = prompt_tokens.shape[-1]
-                current_gen_offset += stride_tokens
-
-            gen_tokens = torch.cat(all_tokens, dim=-1)
+            print('<>Long gen ?<>')
         return gen_tokens
 
     def generate_audio(self, gen_tokens: torch.Tensor) -> torch.Tensor:
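Since the long-generation striding branch is now a stub, only requests with self.duration <= self.max_duration produce audio. A usage sketch against the simplified API (assumes `model` is an already-loaded AudioGen instance, e.g. built via the loaders imported in audiogen.py):

    # After this commit only self.duration <= self.max_duration generates;
    # the else branch just prints '<>Long gen ?<>' and would then fail at
    # `return gen_tokens`, since gen_tokens is never assigned there.
    model.duration = min(model.duration, model.max_duration)  # stay on the simple path
    wav = model.generate(['dogs barking in the rain'])        # decoded audio, [B, C, T]
    wav, tokens = model.generate(['a siren'], return_tokens=True)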
audiocraft/lm.py CHANGED
@@ -323,7 +323,7 @@ class LMModel(StreamingModule):
     def generate(self,
                  prompt=None,
                  conditions=[],
-                 num_samples=None,
+                 num_samples=1,  # THIS IS HOW MANY GENERATIONS - A SAMPLE IS A FULL WAV
                  max_gen_len: int = 256,
                  use_sampling: bool = True,
                  temp: float = 1.0,
@@ -335,30 +335,10 @@ class LMModel(StreamingModule):
                  check: bool = False,
                  callback: tp.Optional[tp.Callable[[int, int], None]] = None,
                  **kwargs) -> torch.Tensor:
-        """Default generation takes random token of top_250 logits
-
-        Args:
-
-        Returns:
-            torch.Tensor: tokens
-        """
-        assert not self.training, "generation shouldn't be used in training mode."
+        print(f'{num_samples=}')
         first_param = next(iter(self.parameters()))
         device = first_param.device
-
-        # Checking all input shapes are consistent.
-        possible_num_samples = []
-        if num_samples is not None:
-            possible_num_samples.append(num_samples)
-        elif prompt is not None:
-            possible_num_samples.append(prompt.shape[0])
-        elif conditions:
-            possible_num_samples.append(len(conditions))
-        else:
-            possible_num_samples.append(1)
-        assert [x == possible_num_samples[0] for x in possible_num_samples], "Inconsistent inputs shapes"
-        num_samples = possible_num_samples[0]
-
+
         # below we create set of conditions: one conditional and one unconditional
         # to do that we merge the regular condition together with the null condition
         # we then do 1 forward pass instead of 2.
@@ -380,6 +360,7 @@ class LMModel(StreamingModule):
         if prompt is None:
             assert num_samples > 0
             prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)
+            print('\n\n\n\n DEFAULT PROMPT ZERO \n\n-')
 
         B, K, T = prompt.shape
         start_offset = T
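Note the changed contract: the shape-consistency checks are gone and num_samples now defaults to 1 rather than being inferred from prompt or conditions. In genmodel.py above, _generate_tokens does not pass num_samples at all, so it always falls back to 1 regardless of how many descriptions were given; a caller invoking the LM directly should keep the two in sync itself. A sketch (the surrounding names are assumed from the code above):

    # Keep num_samples consistent with the conditions yourself,
    # since LMModel.generate no longer infers or validates it.
    gen_tokens = lm.generate(conditions=attributes,        # one entry per sample
                             num_samples=len(attributes),  # no longer inferred
                             max_gen_len=int(duration * frame_rate))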
audiocraft/utils/audio_utils.py DELETED
@@ -1,176 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-"""Various utilities for audio conversion (pcm format, sample rate and channels),
-and volume normalization."""
-import sys
-import typing as tp
-
-import julius
-import torch
-import torchaudio
-
-
-def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor:
-    """Convert audio to the given number of channels.
-
-    Args:
-        wav (torch.Tensor): Audio wave of shape [B, C, T].
-        channels (int): Expected number of channels as output.
-    Returns:
-        torch.Tensor: Downmixed or unchanged audio wave [B, C, T].
-    """
-    *shape, src_channels, length = wav.shape
-    if src_channels == channels:
-        pass
-    elif channels == 1:
-        # Case 1:
-        # The caller asked 1-channel audio, and the stream has multiple
-        # channels, downmix all channels.
-        wav = wav.mean(dim=-2, keepdim=True)
-    elif src_channels == 1:
-        # Case 2:
-        # The caller asked for multiple channels, but the input file has
-        # a single channel, replicate the audio over all channels.
-        wav = wav.expand(*shape, channels, length)
-    elif src_channels >= channels:
-        # Case 3:
-        # The caller asked for multiple channels, and the input file has
-        # more channels than requested. In that case return the first channels.
-        wav = wav[..., :channels, :]
-    else:
-        # Case 4: What is a reasonable choice here?
-        raise ValueError('The audio file has less channels than requested but is not mono.')
-    return wav
-
-
-def convert_audio(wav: torch.Tensor, from_rate: float,
-                  to_rate: float, to_channels: int) -> torch.Tensor:
-    """Convert audio to new sample rate and number of audio channels."""
-    wav = julius.resample_frac(wav, int(from_rate), int(to_rate))
-    wav = convert_audio_channels(wav, to_channels)
-    return wav
-
-
-def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14,
-                       loudness_compressor: bool = False, energy_floor: float = 2e-3):
-    """Normalize an input signal to a user loudness in dB LKFS.
-    Audio loudness is defined according to the ITU-R BS.1770-4 recommendation.
-
-    Args:
-        wav (torch.Tensor): Input multichannel audio data.
-        sample_rate (int): Sample rate.
-        loudness_headroom_db (float): Target loudness of the output in dB LUFS.
-        loudness_compressor (bool): Uses tanh for soft clipping.
-        energy_floor (float): anything below that RMS level will not be rescaled.
-    Returns:
-        torch.Tensor: Loudness normalized output data.
-    """
-    energy = wav.pow(2).mean().sqrt().item()
-    if energy < energy_floor:
-        return wav
-    transform = torchaudio.transforms.Loudness(sample_rate)
-    input_loudness_db = transform(wav).item()
-    # calculate the gain needed to scale to the desired loudness level
-    delta_loudness = -loudness_headroom_db - input_loudness_db
-    gain = 10.0 ** (delta_loudness / 20.0)
-    output = gain * wav
-    if loudness_compressor:
-        output = torch.tanh(output)
-    assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
-    return output
-
-
-def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optional[str] = None) -> None:
-    """Utility function to clip the audio with logging if specified."""
-    max_scale = wav.abs().max()
-    if log_clipping and max_scale > 1:
-        clamp_prob = (wav.abs() > 1).float().mean().item()
-        print(f"CLIPPING {stem_name or ''} happening with proba (a bit of clipping is okay):",
-              clamp_prob, "maximum scale: ", max_scale.item(), file=sys.stderr)
-    wav.clamp_(-1, 1)
-
-
-def normalize_audio(wav: torch.Tensor, normalize: bool = True,
-                    strategy: str = 'peak', peak_clip_headroom_db: float = 1,
-                    rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
-                    loudness_compressor: bool = False, log_clipping: bool = False,
-                    sample_rate: tp.Optional[int] = None,
-                    stem_name: tp.Optional[str] = None) -> torch.Tensor:
-    """Normalize the audio according to the prescribed strategy (see after).
-
-    Args:
-        wav (torch.Tensor): Audio data.
-        normalize (bool): if `True` (default), normalizes according to the prescribed
-            strategy (see after). If `False`, the strategy is only used in case clipping
-            would happen.
-        strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
-            i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
-            with extra headroom to avoid clipping. 'clip' just clips.
-        peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
-        rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
-            than the `peak_clip` one to avoid further clipping.
-        loudness_headroom_db (float): Target loudness for loudness normalization.
-        loudness_compressor (bool): If True, uses tanh based soft clipping.
-        log_clipping (bool): If True, basic logging on stderr when clipping still
-            occurs despite strategy (only for 'rms').
-        sample_rate (int): Sample rate for the audio data (required for loudness).
-        stem_name (str, optional): Stem name for clipping logging.
-    Returns:
-        torch.Tensor: Normalized audio.
-    """
-    scale_peak = 10 ** (-peak_clip_headroom_db / 20)
-    scale_rms = 10 ** (-rms_headroom_db / 20)
-    if strategy == 'peak':
-        rescaling = (scale_peak / wav.abs().max())
-        if normalize or rescaling < 1:
-            wav = wav * rescaling
-    elif strategy == 'clip':
-        wav = wav.clamp(-scale_peak, scale_peak)
-    elif strategy == 'rms':
-        mono = wav.mean(dim=0)
-        rescaling = scale_rms / mono.pow(2).mean().sqrt()
-        if normalize or rescaling < 1:
-            wav = wav * rescaling
-        _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
-    elif strategy == 'loudness':
-        assert sample_rate is not None, "Loudness normalization requires sample rate."
-        wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor)
-        _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
-    else:
-        assert wav.abs().max() < 1
-        assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'"
-    return wav
-
-
-def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
-    """Convert audio to float 32 bits PCM format.
-    """
-    if wav.dtype.is_floating_point:
-        return wav
-    elif wav.dtype == torch.int16:
-        return wav.float() / 2**15
-    elif wav.dtype == torch.int32:
-        return wav.float() / 2**31
-    raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
-
-
-def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
-    """Convert audio to int 16 bits PCM format.
-
-    ..Warning:: There exist many formulas for doing this conversion. None are perfect
-    due to the asymmetry of the int16 range. One either has possible clipping, DC offset,
-    or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom,
-    it is possible that `i16_pcm(f32_pcm(wav)) != wav`.
-    """
-    if wav.dtype.is_floating_point:
-        assert wav.abs().max() <= 1
-        candidate = (wav * 2 ** 15).round()
-        if candidate.max() >= 2 ** 15:  # clipping would occur
-            candidate = (wav * (2 ** 15 - 1)).round()
-        return candidate.short()
-    else:
-        assert wav.dtype == torch.int16
-        return wav
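Deleting audio_utils.py takes the resampling, loudness, and PCM helpers with it. If loudness normalization is still needed downstream, the deleted normalize_loudness reduces to a gain computed from the measured ITU-R BS.1770-4 loudness; a sketch of that core, keeping the deleted defaults:

    import torch
    import torchaudio

    def loudness_gain(wav: torch.Tensor, sample_rate: int,
                      loudness_headroom_db: float = 14.0) -> torch.Tensor:
        """Core of the deleted normalize_loudness(): rescale to -headroom dB LUFS."""
        input_db = torchaudio.transforms.Loudness(sample_rate)(wav).item()
        delta = -loudness_headroom_db - input_db   # dB to move by
        return wav * (10.0 ** (delta / 20.0))      # dB -> linear gain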
audiocraft/utils/samples/__init__.py DELETED
@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
audiocraft/utils/samples/manager.py DELETED
@@ -1,386 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""
-API that can manage the storage and retrieval of generated samples produced by experiments.
-
-It offers the following benefits:
-* Samples are stored in a consistent way across epochs
-* Metadata about the samples can be stored and retrieved
-* Can retrieve audio
-* Identifiers are reliable and deterministic for prompted and conditioned samples
-* Can request the samples for multiple XPs, grouped by sample identifier
-* For no-input samples (no prompt and no conditions), samples across XPs are matched
-  by sorting their identifiers
-"""
-
-from concurrent.futures import ThreadPoolExecutor
-from dataclasses import asdict, dataclass
-from functools import lru_cache
-import hashlib
-import json
-import logging
-from pathlib import Path
-import re
-import typing as tp
-import unicodedata
-import uuid
-
-import dora
-import torch
-
-from ...data.audio import audio_read, audio_write
-
-
-logger = logging.getLogger(__name__)
-
-
-@dataclass
-class ReferenceSample:
-    id: str
-    path: str
-    duration: float
-
-
-@dataclass
-class Sample:
-    id: str
-    path: str
-    epoch: int
-    duration: float
-    conditioning: tp.Optional[tp.Dict[str, tp.Any]]
-    prompt: tp.Optional[ReferenceSample]
-    reference: tp.Optional[ReferenceSample]
-    generation_args: tp.Optional[tp.Dict[str, tp.Any]]
-
-    def __hash__(self):
-        return hash(self.id)
-
-    def audio(self) -> tp.Tuple[torch.Tensor, int]:
-        return audio_read(self.path)
-
-    def audio_prompt(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]:
-        return audio_read(self.prompt.path) if self.prompt is not None else None
-
-    def audio_reference(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]:
-        return audio_read(self.reference.path) if self.reference is not None else None
-
-
-class SampleManager:
-    """Audio samples IO handling within a given dora xp.
-
-    The sample manager handles the dumping and loading logic for generated and
-    reference samples across epochs for a given xp, providing a simple API to
-    store, retrieve and compare audio samples.
-
-    Args:
-        xp (dora.XP): Dora experiment object. The XP contains information on the XP folder
-            where all outputs are stored and the configuration of the experiment,
-            which is useful to retrieve audio-related parameters.
-        map_reference_to_sample_id (bool): Whether to use the sample_id for all reference samples
-            instead of generating a dedicated hash id. This is useful to allow easier comparison
-            with ground truth samples from the files directly without having to read the JSON metadata
-            to do the mapping (at the cost of potentially dumping duplicate prompts/references
-            depending on the task).
-    """
-    def __init__(self, xp: dora.XP, map_reference_to_sample_id: bool = False):
-        self.xp = xp
-        self.base_folder: Path = xp.folder / xp.cfg.generate.path
-        self.reference_folder = self.base_folder / 'reference'
-        self.map_reference_to_sample_id = map_reference_to_sample_id
-        self.samples: tp.List[Sample] = []
-        self._load_samples()
-
-    @property
-    def latest_epoch(self):
-        """Latest epoch across all samples."""
-        return max(self.samples, key=lambda x: x.epoch).epoch if self.samples else 0
-
-    def _load_samples(self):
-        """Scan the sample folder and load existing samples."""
-        jsons = self.base_folder.glob('**/*.json')
-        with ThreadPoolExecutor(6) as pool:
-            self.samples = list(pool.map(self._load_sample, jsons))
-
-    @staticmethod
-    @lru_cache(2**26)
-    def _load_sample(json_file: Path) -> Sample:
-        with open(json_file, 'r') as f:
-            data: tp.Dict[str, tp.Any] = json.load(f)
-        # fetch prompt data
-        prompt_data = data.get('prompt')
-        prompt = ReferenceSample(id=prompt_data['id'], path=prompt_data['path'],
-                                 duration=prompt_data['duration']) if prompt_data else None
-        # fetch reference data
-        reference_data = data.get('reference')
-        reference = ReferenceSample(id=reference_data['id'], path=reference_data['path'],
-                                    duration=reference_data['duration']) if reference_data else None
-        # build sample object
-        return Sample(id=data['id'], path=data['path'], epoch=data['epoch'], duration=data['duration'],
-                      prompt=prompt, conditioning=data.get('conditioning'), reference=reference,
-                      generation_args=data.get('generation_args'))
-
-    def _init_hash(self):
-        return hashlib.sha1()
-
-    def _get_tensor_id(self, tensor: torch.Tensor) -> str:
-        hash_id = self._init_hash()
-        hash_id.update(tensor.numpy().data)
-        return hash_id.hexdigest()
-
-    def _get_sample_id(self, index: int, prompt_wav: tp.Optional[torch.Tensor],
-                       conditions: tp.Optional[tp.Dict[str, str]]) -> str:
-        """Computes an id for a sample given its input data.
-        This id is deterministic if prompt and/or conditions are provided by using a sha1 hash on the input.
-        Otherwise, a random id of the form "noinput_{uuid4().hex}" is returned.
-
-        Args:
-            index (int): Batch index, helpful to differentiate samples from the same batch.
-            prompt_wav (torch.Tensor): Prompt used during generation.
-            conditions (dict[str, str]): Conditioning used during generation.
-        """
-        # For totally unconditioned generations we will just use a random UUID.
-        # The function get_samples_for_xps will do a simple ordered match with a custom key.
-        if prompt_wav is None and not conditions:
-            return f"noinput_{uuid.uuid4().hex}"
-
-        # Human readable portion
-        hr_label = ""
-        # Create a deterministic id using hashing
-        hash_id = self._init_hash()
-        hash_id.update(f"{index}".encode())
-        if prompt_wav is not None:
-            hash_id.update(prompt_wav.numpy().data)
-            hr_label += "_prompted"
-        else:
-            hr_label += "_unprompted"
-        if conditions:
-            encoded_json = json.dumps(conditions, sort_keys=True).encode()
-            hash_id.update(encoded_json)
-            cond_str = "-".join([f"{key}={slugify(value)}"
-                                 for key, value in sorted(conditions.items())])
-            cond_str = cond_str[:100]  # some raw text might be too long to be a valid filename
-            cond_str = cond_str if len(cond_str) > 0 else "unconditioned"
-            hr_label += f"_{cond_str}"
-        else:
-            hr_label += "_unconditioned"
-
-        return hash_id.hexdigest() + hr_label
-
-    def _store_audio(self, wav: torch.Tensor, stem_path: Path, overwrite: bool = False) -> Path:
-        """Stores the audio with the given stem path using the XP's configuration.
-
-        Args:
-            wav (torch.Tensor): Audio to store.
-            stem_path (Path): Path in sample output directory with file stem to use.
-            overwrite (bool): When False (default), skips storing an existing audio file.
-        Returns:
-            Path: The path at which the audio is stored.
-        """
-        existing_paths = [
-            path for path in stem_path.parent.glob(stem_path.stem + '.*')
-            if path.suffix != '.json'
-        ]
-        exists = len(existing_paths) > 0
-        if exists and overwrite:
-            logger.warning(f"Overwriting existing audio file with stem path {stem_path}")
-        elif exists:
-            return existing_paths[0]
-
-        audio_path = audio_write(stem_path, wav, **self.xp.cfg.generate.audio)
-        return audio_path
-
-    def add_sample(self, sample_wav: torch.Tensor, epoch: int, index: int = 0,
-                   conditions: tp.Optional[tp.Dict[str, str]] = None, prompt_wav: tp.Optional[torch.Tensor] = None,
-                   ground_truth_wav: tp.Optional[torch.Tensor] = None,
-                   generation_args: tp.Optional[tp.Dict[str, tp.Any]] = None) -> Sample:
-        """Adds a single sample.
-        The sample is stored in the XP's sample output directory, under a corresponding epoch folder.
-        Each sample is assigned an id which is computed using the input data. In addition to the
-        sample itself, a json file containing associated metadata is stored next to it.
-
-        Args:
-            sample_wav (torch.Tensor): sample audio to store. Tensor of shape [channels, shape].
-            epoch (int): current training epoch.
-            index (int): helpful to differentiate samples from the same batch.
-            conditions (dict[str, str], optional): conditioning used during generation.
-            prompt_wav (torch.Tensor, optional): prompt used during generation. Tensor of shape [channels, shape].
-            ground_truth_wav (torch.Tensor, optional): reference audio where prompt was extracted from.
-                Tensor of shape [channels, shape].
-            generation_args (dict[str, any], optional): dictionary of other arguments used during generation.
-        Returns:
-            Sample: The saved sample.
-        """
-        sample_id = self._get_sample_id(index, prompt_wav, conditions)
-        reuse_id = self.map_reference_to_sample_id
-        prompt, ground_truth = None, None
-        if prompt_wav is not None:
-            prompt_id = sample_id if reuse_id else self._get_tensor_id(prompt_wav.sum(0, keepdim=True))
-            prompt_duration = prompt_wav.shape[-1] / self.xp.cfg.sample_rate
-            prompt_path = self._store_audio(prompt_wav, self.base_folder / str(epoch) / 'prompt' / prompt_id)
-            prompt = ReferenceSample(prompt_id, str(prompt_path), prompt_duration)
-        if ground_truth_wav is not None:
-            ground_truth_id = sample_id if reuse_id else self._get_tensor_id(ground_truth_wav.sum(0, keepdim=True))
-            ground_truth_duration = ground_truth_wav.shape[-1] / self.xp.cfg.sample_rate
-            ground_truth_path = self._store_audio(ground_truth_wav, self.base_folder / 'reference' / ground_truth_id)
-            ground_truth = ReferenceSample(ground_truth_id, str(ground_truth_path), ground_truth_duration)
-        sample_path = self._store_audio(sample_wav, self.base_folder / str(epoch) / sample_id, overwrite=True)
-        duration = sample_wav.shape[-1] / self.xp.cfg.sample_rate
-        sample = Sample(sample_id, str(sample_path), epoch, duration, conditions, prompt, ground_truth, generation_args)
-        self.samples.append(sample)
-        with open(sample_path.with_suffix('.json'), 'w') as f:
-            json.dump(asdict(sample), f, indent=2)
-        return sample
-
-    def add_samples(self, samples_wavs: torch.Tensor, epoch: int,
-                    conditioning: tp.Optional[tp.List[tp.Dict[str, tp.Any]]] = None,
-                    prompt_wavs: tp.Optional[torch.Tensor] = None,
-                    ground_truth_wavs: tp.Optional[torch.Tensor] = None,
-                    generation_args: tp.Optional[tp.Dict[str, tp.Any]] = None) -> tp.List[Sample]:
-        """Adds a batch of samples.
-        The samples are stored in the XP's sample output directory, under a corresponding
-        epoch folder. Each sample is assigned an id which is computed using the input data and their batch index.
-        In addition to the sample itself, a json file containing associated metadata is stored next to it.
-
-        Args:
-            samples_wavs (torch.Tensor): Batch of audio wavs to store. Tensor of shape [batch_size, channels, shape].
-            epoch (int): Current training epoch.
-            conditioning (list of dict[str, str], optional): List of conditions used during generation,
-                one per sample in the batch.
-            prompt_wavs (torch.Tensor, optional): Prompts used during generation. Tensor of shape
-                [batch_size, channels, shape].
-            ground_truth_wavs (torch.Tensor, optional): Reference audio where prompts were extracted from.
-                Tensor of shape [batch_size, channels, shape].
-            generation_args (dict[str, Any], optional): Dictionary of other arguments used during generation.
-        Returns:
-            samples (list of Sample): The saved audio samples with prompts, ground truth and metadata.
-        """
-        samples = []
-        for idx, wav in enumerate(samples_wavs):
-            prompt_wav = prompt_wavs[idx] if prompt_wavs is not None else None
-            gt_wav = ground_truth_wavs[idx] if ground_truth_wavs is not None else None
-            conditions = conditioning[idx] if conditioning is not None else None
-            samples.append(self.add_sample(wav, epoch, idx, conditions, prompt_wav, gt_wav, generation_args))
-        return samples
-
-    def get_samples(self, epoch: int = -1, max_epoch: int = -1, exclude_prompted: bool = False,
-                    exclude_unprompted: bool = False, exclude_conditioned: bool = False,
-                    exclude_unconditioned: bool = False) -> tp.Set[Sample]:
-        """Returns a set of samples for this XP. Optionally, you can filter which samples to obtain.
-        Please note that existing samples are loaded during the manager's initialization, and samples
-        added through this manager are also tracked. Any other external changes are not tracked
-        automatically, so creating a new manager is the only way to detect them.
-
-        Args:
-            epoch (int): If provided, only return samples corresponding to this epoch.
-            max_epoch (int): If provided, only return samples corresponding to the latest epoch that is <= max_epoch.
-            exclude_prompted (bool): If True, does not include samples that used a prompt.
-            exclude_unprompted (bool): If True, does not include samples that did not use a prompt.
-            exclude_conditioned (bool): If True, excludes samples that used conditioning.
-            exclude_unconditioned (bool): If True, excludes samples that did not use conditioning.
-        Returns:
-            Samples (set of Sample): The retrieved samples matching the provided filters.
-        """
-        if max_epoch >= 0:
-            samples_epoch = max(sample.epoch for sample in self.samples if sample.epoch <= max_epoch)
-        else:
-            samples_epoch = self.latest_epoch if epoch < 0 else epoch
-        samples = {
-            sample
-            for sample in self.samples
-            if (
-                (sample.epoch == samples_epoch) and
-                (not exclude_prompted or sample.prompt is None) and
-                (not exclude_unprompted or sample.prompt is not None) and
-                (not exclude_conditioned or not sample.conditioning) and
-                (not exclude_unconditioned or sample.conditioning)
-            )
-        }
-        return samples
-
-
-def slugify(value: tp.Any, allow_unicode: bool = False):
-    """Process string for safer file naming.
-
-    Taken from https://github.com/django/django/blob/master/django/utils/text.py
-
-    Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
-    dashes to single dashes. Remove characters that aren't alphanumerics,
-    underscores, or hyphens. Convert to lowercase. Also strip leading and
-    trailing whitespace, dashes, and underscores.
-    """
-    value = str(value)
-    if allow_unicode:
-        value = unicodedata.normalize("NFKC", value)
-    else:
-        value = (
-            unicodedata.normalize("NFKD", value)
-            .encode("ascii", "ignore")
-            .decode("ascii")
-        )
-    value = re.sub(r"[^\w\s-]", "", value.lower())
-    return re.sub(r"[-\s]+", "-", value).strip("-_")
-
-
-def _match_stable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> tp.Dict[str, tp.List[Sample]]:
-    # Create a dictionary of stable id -> sample per XP
-    stable_samples_per_xp = [{
-        sample.id: sample for sample in samples
-        if sample.prompt is not None or sample.conditioning
-    } for samples in samples_per_xp]
-    # Set of all stable ids
-    stable_ids = {id for samples in stable_samples_per_xp for id in samples.keys()}
-    # Dictionary of stable id -> list of samples. If an XP does not have it, assign None
-    stable_samples = {id: [xp.get(id) for xp in stable_samples_per_xp] for id in stable_ids}
-    # Filter out ids that contain None values (we only want matched samples after all)
-    # cast is necessary to avoid mypy linter errors.
-    return {id: tp.cast(tp.List[Sample], samples) for id, samples in stable_samples.items() if None not in samples}
-
-
-def _match_unstable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> tp.Dict[str, tp.List[Sample]]:
-    # For unstable ids, we use a sorted list since we'll match them in order
-    unstable_samples_per_xp = [[
-        sample for sample in sorted(samples, key=lambda x: x.id)
-        if sample.prompt is None and not sample.conditioning
-    ] for samples in samples_per_xp]
-    # Trim samples per xp so all samples can have a match
-    min_len = min([len(samples) for samples in unstable_samples_per_xp])
-    unstable_samples_per_xp = [samples[:min_len] for samples in unstable_samples_per_xp]
-    # Dictionary of index -> list of matched samples
-    return {
-        f'noinput_{i}': [samples[i] for samples in unstable_samples_per_xp] for i in range(min_len)
-    }
-
-
-def get_samples_for_xps(xps: tp.List[dora.XP], **kwargs) -> tp.Dict[str, tp.List[Sample]]:
-    """Gets a dictionary of matched samples across the given XPs.
-    Each dictionary entry maps a sample id to a list of samples for that id. The number of samples per id
-    will always match the number of XPs provided and will correspond to each XP in the same order given.
-    In other words, only samples that can be matched across all provided XPs will be returned
-    in order to satisfy this rule.
-
-    There are two types of ids that can be returned: stable and unstable.
-    * Stable IDs are deterministic ids that were computed by the SampleManager given a sample's inputs
-      (prompts/conditioning). This is why we can match them across XPs.
-    * Unstable IDs are of the form "noinput_{idx}" and are generated on-the-fly, in order to map samples
-      that used non-deterministic, random ids. This is the case for samples that did not use prompts or
-      conditioning for their generation. This function will sort these samples by their id and match them
-      by their index.
-
-    Args:
-        xps: a list of XPs to match samples from.
-        start_epoch (int): If provided, only return samples corresponding to this epoch or newer.
-        end_epoch (int): If provided, only return samples corresponding to this epoch or older.
-        exclude_prompted (bool): If True, does not include samples that used a prompt.
-        exclude_unprompted (bool): If True, does not include samples that did not use a prompt.
-        exclude_conditioned (bool): If True, excludes samples that used conditioning.
-        exclude_unconditioned (bool): If True, excludes samples that did not use conditioning.
-    """
-    managers = [SampleManager(xp) for xp in xps]
-    samples_per_xp = [manager.get_samples(**kwargs) for manager in managers]
-    stable_samples = _match_stable_samples(samples_per_xp)
-    unstable_samples = _match_unstable_samples(samples_per_xp)
-    return dict(stable_samples, **unstable_samples)
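The deleted SampleManager derived deterministic sample ids by sha1-hashing the batch index, the raw prompt tensor bytes, and the sorted conditioning JSON. If similar bookkeeping is needed without it, the core is small; a sketch distilled from the deleted _get_sample_id (dropping the human-readable suffix):

    import hashlib
    import json
    import uuid

    def sample_id(index, prompt_wav=None, conditions=None) -> str:
        """Deterministic id for prompted/conditioned samples, random otherwise."""
        if prompt_wav is None and not conditions:
            return f"noinput_{uuid.uuid4().hex}"
        h = hashlib.sha1()
        h.update(f"{index}".encode())
        if prompt_wav is not None:
            h.update(prompt_wav.numpy().data)  # hash the raw tensor bytes
        if conditions:
            h.update(json.dumps(conditions, sort_keys=True).encode())
        return h.hexdigest()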