unpairedelectron07 committed on
Commit
e5de9ff
1 Parent(s): ee232aa

Upload manager.py

Browse files
Files changed (1) hide show
  1. audiocraft/utils/samples/manager.py +386 -0
audiocraft/utils/samples/manager.py ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ API that can manage the storage and retrieval of generated samples produced by experiments.
9
+
10
+ It offers the following benefits:
11
+ * Samples are stored in a consistent way across epoch
12
+ * Metadata about the samples can be stored and retrieved
13
+ * Can retrieve audio
14
+ * Identifiers are reliable and deterministic for prompted and conditioned samples
15
+ * Can request the samples for multiple XPs, grouped by sample identifier
16
+ * For no-input samples (not prompt and no conditions), samples across XPs are matched
17
+ by sorting their identifiers
18
+ """
19
+
20
+ from concurrent.futures import ThreadPoolExecutor
21
+ from dataclasses import asdict, dataclass
22
+ from functools import lru_cache
23
+ import hashlib
24
+ import json
25
+ import logging
26
+ from pathlib import Path
27
+ import re
28
+ import typing as tp
29
+ import unicodedata
30
+ import uuid
31
+
32
+ import dora
33
+ import torch
34
+
35
+ from ...data.audio import audio_read, audio_write
36
+
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+
41
@dataclass
class ReferenceSample:
    """Pointer to a stored reference audio file (a prompt or a ground-truth sample).

    Only lightweight metadata is kept here; the audio itself stays on disk at ``path``.
    """
    id: str  # deterministic hash id, or the sample id when reference mapping is enabled
    path: str  # location of the stored audio file
    duration: float  # audio duration, in seconds
46
+
47
+
48
@dataclass
class Sample:
    """A generated audio sample together with its generation metadata.

    Samples hash on their ``id`` only, so they can be collected in sets and
    matched across experiments.
    """
    id: str  # sample identifier (deterministic when prompted/conditioned)
    path: str  # location of the generated audio on disk
    epoch: int  # training epoch the sample was generated at
    duration: float  # audio duration, in seconds
    conditioning: tp.Optional[tp.Dict[str, tp.Any]]  # conditioning attributes, if any
    prompt: tp.Optional[ReferenceSample]  # prompt audio used for generation, if any
    reference: tp.Optional[ReferenceSample]  # ground-truth audio the prompt came from, if any
    generation_args: tp.Optional[tp.Dict[str, tp.Any]]  # extra generation parameters

    def __hash__(self):
        # Identity is fully determined by the sample id.
        return hash(self.id)

    def audio(self) -> tp.Tuple[torch.Tensor, int]:
        """Read the generated audio; returns (waveform, sample_rate)."""
        return audio_read(self.path)

    def audio_prompt(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]:
        """Read the prompt audio if a prompt was used, otherwise None."""
        if self.prompt is None:
            return None
        return audio_read(self.prompt.path)

    def audio_reference(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]:
        """Read the ground-truth reference audio if one was stored, otherwise None."""
        if self.reference is None:
            return None
        return audio_read(self.reference.path)
70
+
71
+
72
class SampleManager:
    """Audio samples IO handling within a given dora xp.

    The sample manager handles the dumping and loading logic for generated and
    references samples across epochs for a given xp, providing a simple API to
    store, retrieve and compare audio samples.

    Args:
        xp (dora.XP): Dora experiment object. The XP contains information on the XP folder
            where all outputs are stored and the configuration of the experiment,
            which is useful to retrieve audio-related parameters.
        map_reference_to_sample_id (bool): Whether to use the sample_id for all reference samples
            instead of generating a dedicated hash id. This is useful to allow easier comparison
            with ground truth sample from the files directly without having to read the JSON metadata
            to do the mapping (at the cost of potentially dumping duplicate prompts/references
            depending on the task).
    """
    def __init__(self, xp: dora.XP, map_reference_to_sample_id: bool = False):
        self.xp = xp
        self.base_folder: Path = xp.folder / xp.cfg.generate.path
        self.reference_folder = self.base_folder / 'reference'
        self.map_reference_to_sample_id = map_reference_to_sample_id
        self.samples: tp.List[Sample] = []
        self._load_samples()

    @property
    def latest_epoch(self):
        """Latest epoch across all samples (0 when no sample has been loaded)."""
        return max(self.samples, key=lambda x: x.epoch).epoch if self.samples else 0

    def _load_samples(self):
        """Scan the sample folder and load existing samples."""
        jsons = self.base_folder.glob('**/*.json')
        # Metadata loading is I/O bound, so a small thread pool speeds up the scan.
        with ThreadPoolExecutor(6) as pool:
            self.samples = list(pool.map(self._load_sample, jsons))

    @staticmethod
    @lru_cache(2**26)
    def _load_sample(json_file: Path) -> Sample:
        """Load one sample from its JSON metadata file.

        Cached on the file path so that repeated scans (e.g. several managers
        over the same XP) do not re-read metadata from disk.
        """
        with open(json_file, 'r') as f:
            data: tp.Dict[str, tp.Any] = json.load(f)
        # fetch prompt data
        prompt_data = data.get('prompt')
        prompt = ReferenceSample(id=prompt_data['id'], path=prompt_data['path'],
                                 duration=prompt_data['duration']) if prompt_data else None
        # fetch reference data
        reference_data = data.get('reference')
        reference = ReferenceSample(id=reference_data['id'], path=reference_data['path'],
                                    duration=reference_data['duration']) if reference_data else None
        # build sample object
        return Sample(id=data['id'], path=data['path'], epoch=data['epoch'], duration=data['duration'],
                      prompt=prompt, conditioning=data.get('conditioning'), reference=reference,
                      generation_args=data.get('generation_args'))

    def _init_hash(self):
        # sha1 is enough here: ids need stability and uniqueness, not cryptographic strength.
        return hashlib.sha1()

    def _get_tensor_id(self, tensor: torch.Tensor) -> str:
        """Deterministic hex digest of a tensor's raw bytes.

        NOTE(review): `tensor.numpy()` requires a CPU tensor — callers must move
        tensors off-device first; confirm against call sites.
        """
        hash_id = self._init_hash()
        hash_id.update(tensor.numpy().data)
        return hash_id.hexdigest()

    def _get_sample_id(self, index: int, prompt_wav: tp.Optional[torch.Tensor],
                       conditions: tp.Optional[tp.Dict[str, str]]) -> str:
        """Computes an id for a sample given its input data.
        This id is deterministic if prompt and/or conditions are provided by using a sha1 hash on the input.
        Otherwise, a random id of the form "noinput_{uuid4().hex}" is returned.

        Args:
            index (int): Batch index, Helpful to differentiate samples from the same batch.
            prompt_wav (torch.Tensor): Prompt used during generation.
            conditions (dict[str, str]): Conditioning used during generation.
        """
        # For totally unconditioned generations we will just use a random UUID.
        # The function get_samples_for_xps will do a simple ordered match with a custom key.
        if prompt_wav is None and not conditions:
            return f"noinput_{uuid.uuid4().hex}"

        # Human readable portion
        hr_label = ""
        # Create a deterministic id using hashing
        hash_id = self._init_hash()
        hash_id.update(f"{index}".encode())
        if prompt_wav is not None:
            hash_id.update(prompt_wav.numpy().data)
            hr_label += "_prompted"
        else:
            hr_label += "_unprompted"
        if conditions:
            encoded_json = json.dumps(conditions, sort_keys=True).encode()
            hash_id.update(encoded_json)
            cond_str = "-".join([f"{key}={slugify(value)}"
                                 for key, value in sorted(conditions.items())])
            cond_str = cond_str[:100]  # some raw text might be too long to be a valid filename
            cond_str = cond_str if len(cond_str) > 0 else "unconditioned"
            hr_label += f"_{cond_str}"
        else:
            hr_label += "_unconditioned"

        return hash_id.hexdigest() + hr_label

    def _store_audio(self, wav: torch.Tensor, stem_path: Path, overwrite: bool = False) -> Path:
        """Stores the audio with the given stem path using the XP's configuration.

        Args:
            wav (torch.Tensor): Audio to store.
            stem_path (Path): Path in sample output directory with file stem to use.
            overwrite (bool): When False (default), skips storing an existing audio file.
        Returns:
            Path: The path at which the audio is stored.
        """
        # Any non-JSON sibling sharing the stem counts as an existing audio dump.
        existing_paths = [
            path for path in stem_path.parent.glob(stem_path.stem + '.*')
            if path.suffix != '.json'
        ]
        exists = len(existing_paths) > 0
        if exists and overwrite:
            logger.warning(f"Overwriting existing audio file with stem path {stem_path}")
        elif exists:
            return existing_paths[0]

        audio_path = audio_write(stem_path, wav, **self.xp.cfg.generate.audio)
        return audio_path

    def add_sample(self, sample_wav: torch.Tensor, epoch: int, index: int = 0,
                   conditions: tp.Optional[tp.Dict[str, str]] = None, prompt_wav: tp.Optional[torch.Tensor] = None,
                   ground_truth_wav: tp.Optional[torch.Tensor] = None,
                   generation_args: tp.Optional[tp.Dict[str, tp.Any]] = None) -> Sample:
        """Adds a single sample.
        The sample is stored in the XP's sample output directory, under a corresponding epoch folder.
        Each sample is assigned an id which is computed using the input data. In addition to the
        sample itself, a json file containing associated metadata is stored next to it.

        Args:
            sample_wav (torch.Tensor): sample audio to store. Tensor of shape [channels, shape].
            epoch (int): current training epoch.
            index (int): helpful to differentiate samples from the same batch.
            conditions (dict[str, str], optional): conditioning used during generation.
            prompt_wav (torch.Tensor, optional): prompt used during generation. Tensor of shape [channels, shape].
            ground_truth_wav (torch.Tensor, optional): reference audio where prompt was extracted from.
                Tensor of shape [channels, shape].
            generation_args (dict[str, any], optional): dictionary of other arguments used during generation.
        Returns:
            Sample: The saved sample.
        """
        sample_id = self._get_sample_id(index, prompt_wav, conditions)
        reuse_id = self.map_reference_to_sample_id
        prompt, ground_truth = None, None
        if prompt_wav is not None:
            # Mix down to mono before hashing so the id only depends on content.
            prompt_id = sample_id if reuse_id else self._get_tensor_id(prompt_wav.sum(0, keepdim=True))
            prompt_duration = prompt_wav.shape[-1] / self.xp.cfg.sample_rate
            prompt_path = self._store_audio(prompt_wav, self.base_folder / str(epoch) / 'prompt' / prompt_id)
            prompt = ReferenceSample(prompt_id, str(prompt_path), prompt_duration)
        if ground_truth_wav is not None:
            ground_truth_id = sample_id if reuse_id else self._get_tensor_id(ground_truth_wav.sum(0, keepdim=True))
            ground_truth_duration = ground_truth_wav.shape[-1] / self.xp.cfg.sample_rate
            ground_truth_path = self._store_audio(ground_truth_wav, self.base_folder / 'reference' / ground_truth_id)
            ground_truth = ReferenceSample(ground_truth_id, str(ground_truth_path), ground_truth_duration)
        sample_path = self._store_audio(sample_wav, self.base_folder / str(epoch) / sample_id, overwrite=True)
        duration = sample_wav.shape[-1] / self.xp.cfg.sample_rate
        sample = Sample(sample_id, str(sample_path), epoch, duration, conditions, prompt, ground_truth, generation_args)
        self.samples.append(sample)
        with open(sample_path.with_suffix('.json'), 'w') as f:
            json.dump(asdict(sample), f, indent=2)
        return sample

    def add_samples(self, samples_wavs: torch.Tensor, epoch: int,
                    conditioning: tp.Optional[tp.List[tp.Dict[str, tp.Any]]] = None,
                    prompt_wavs: tp.Optional[torch.Tensor] = None,
                    ground_truth_wavs: tp.Optional[torch.Tensor] = None,
                    generation_args: tp.Optional[tp.Dict[str, tp.Any]] = None) -> tp.List[Sample]:
        """Adds a batch of samples.
        The samples are stored in the XP's sample output directory, under a corresponding
        epoch folder. Each sample is assigned an id which is computed using the input data and their batch index.
        In addition to the sample itself, a json file containing associated metadata is stored next to it.

        Args:
            sample_wavs (torch.Tensor): Batch of audio wavs to store. Tensor of shape [batch_size, channels, shape].
            epoch (int): Current training epoch.
            conditioning (list of dict[str, str], optional): List of conditions used during generation,
                one per sample in the batch.
            prompt_wavs (torch.Tensor, optional): Prompts used during generation. Tensor of shape
                [batch_size, channels, shape].
            ground_truth_wav (torch.Tensor, optional): Reference audio where prompts were extracted from.
                Tensor of shape [batch_size, channels, shape].
            generation_args (dict[str, Any], optional): Dictionary of other arguments used during generation.
        Returns:
            samples (list of Sample): The saved audio samples with prompts, ground truth and metadata.
        """
        samples = []
        for idx, wav in enumerate(samples_wavs):
            prompt_wav = prompt_wavs[idx] if prompt_wavs is not None else None
            gt_wav = ground_truth_wavs[idx] if ground_truth_wavs is not None else None
            conditions = conditioning[idx] if conditioning is not None else None
            samples.append(self.add_sample(wav, epoch, idx, conditions, prompt_wav, gt_wav, generation_args))
        return samples

    def get_samples(self, epoch: int = -1, max_epoch: int = -1, exclude_prompted: bool = False,
                    exclude_unprompted: bool = False, exclude_conditioned: bool = False,
                    exclude_unconditioned: bool = False) -> tp.Set[Sample]:
        """Returns a set of samples for this XP. Optionally, you can filter which samples to obtain.
        Please note that existing samples are loaded during the manager's initialization, and added samples through this
        manager are also tracked. Any other external changes are not tracked automatically, so creating a new manager
        is the only way to detect them.

        Args:
            epoch (int): If provided, only return samples corresponding to this epoch.
            max_epoch (int): If provided, only return samples corresponding to the latest epoch that is <= max_epoch.
            exclude_prompted (bool): If True, does not include samples that used a prompt.
            exclude_unprompted (bool): If True, does not include samples that did not use a prompt.
            exclude_conditioned (bool): If True, excludes samples that used conditioning.
            exclude_unconditioned (bool): If True, excludes samples that did not use conditioning.
        Returns:
            Samples (set of Sample): The retrieved samples matching the provided filters.
        """
        if max_epoch >= 0:
            # Fix: `default=-1` avoids a ValueError from max() on an empty sequence
            # when no sample has epoch <= max_epoch (or no samples exist at all);
            # -1 matches no sample, so the result is simply an empty set.
            samples_epoch = max((sample.epoch for sample in self.samples if sample.epoch <= max_epoch), default=-1)
        else:
            samples_epoch = self.latest_epoch if epoch < 0 else epoch
        samples = {
            sample
            for sample in self.samples
            if (
                (sample.epoch == samples_epoch) and
                (not exclude_prompted or sample.prompt is None) and
                (not exclude_unprompted or sample.prompt is not None) and
                (not exclude_conditioned or not sample.conditioning) and
                (not exclude_unconditioned or sample.conditioning)
            )
        }
        return samples
303
+
304
+
305
def slugify(value: tp.Any, allow_unicode: bool = False):
    """Process string for safer file naming.

    Taken from https://github.com/django/django/blob/master/django/utils/text.py

    Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
    dashes to single dashes. Remove characters that aren't alphanumerics,
    underscores, or hyphens. Convert to lowercase. Also strip leading and
    trailing whitespace, dashes, and underscores.
    """
    text = str(value)
    if allow_unicode:
        text = unicodedata.normalize("NFKC", text)
    else:
        # Decompose accents, then drop anything that does not survive ASCII.
        decomposed = unicodedata.normalize("NFKD", text)
        text = decomposed.encode("ascii", "ignore").decode("ascii")
    # Keep only word characters, whitespace and hyphens, lowercased.
    text = re.sub(r"[^\w\s-]", "", text.lower())
    # Collapse whitespace/hyphen runs into single hyphens, trim the edges.
    collapsed = re.sub(r"[-\s]+", "-", text)
    return collapsed.strip("-_")
326
+
327
+
328
def _match_stable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> tp.Dict[str, tp.List[Sample]]:
    """Match samples that carry deterministic ("stable") ids across XPs.

    Only prompted or conditioned samples have deterministic ids; ids that are
    missing from any XP are dropped so every returned list has one sample per XP.
    """
    # One id -> sample index per XP, restricted to stable samples.
    indexed_per_xp = [
        {sample.id: sample for sample in samples if sample.prompt is not None or sample.conditioning}
        for samples in samples_per_xp
    ]
    # Union of every stable id seen in any XP.
    candidate_ids: tp.Set[str] = set()
    for indexed in indexed_per_xp:
        candidate_ids.update(indexed)
    matched: tp.Dict[str, tp.List[Sample]] = {}
    for sample_id in candidate_ids:
        row = [indexed.get(sample_id) for indexed in indexed_per_xp]
        # Keep only ids every XP produced; cast keeps mypy happy about Optional.
        if None not in row:
            matched[sample_id] = tp.cast(tp.List[Sample], row)
    return matched
341
+
342
+
343
def _match_unstable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> tp.Dict[str, tp.List[Sample]]:
    """Match no-input samples across XPs by id-sorted position.

    Samples without prompt or conditioning have random ids, so they are paired
    positionally after sorting by id; surplus samples beyond the smallest
    per-XP count are discarded. Keys take the form "noinput_{index}".
    """
    ordered_per_xp = []
    for samples in samples_per_xp:
        no_input = [sample for sample in sorted(samples, key=lambda s: s.id)
                    if sample.prompt is None and not sample.conditioning]
        ordered_per_xp.append(no_input)
    # Only the first `matchable` samples of every XP can be paired up.
    matchable = min(len(group) for group in ordered_per_xp)
    return {f'noinput_{i}': [group[i] for group in ordered_per_xp] for i in range(matchable)}
356
+
357
+
358
def get_samples_for_xps(xps: tp.List[dora.XP], **kwargs) -> tp.Dict[str, tp.List[Sample]]:
    """Gets a dictionary of matched samples across the given XPs.
    Each dictionary entry maps a sample id to a list of samples for that id. The number of samples per id
    will always match the number of XPs provided and will correspond to each XP in the same order given.
    In other words, only samples that can be match across all provided XPs will be returned
    in order to satisfy this rule.

    There are two types of ids that can be returned: stable and unstable.
    * Stable IDs are deterministic ids that were computed by the SampleManager given a sample's inputs
      (prompts/conditioning). This is why we can match them across XPs.
    * Unstable IDs are of the form "noinput_{idx}" and are generated on-the-fly, in order to map samples
      that used non-deterministic, random ids. This is the case for samples that did not use prompts or
      conditioning for their generation. This function will sort these samples by their id and match them
      by their index.

    Args:
        xps: a list of XPs to match samples from.
        start_epoch (int): If provided, only return samples corresponding to this epoch or newer.
        end_epoch (int): If provided, only return samples corresponding to this epoch or older.
        exclude_prompted (bool): If True, does not include samples that used a prompt.
        exclude_unprompted (bool): If True, does not include samples that did not use a prompt.
        exclude_conditioned (bool): If True, excludes samples that used conditioning.
        exclude_unconditioned (bool): If True, excludes samples that did not use conditioning.
    """
    # One filtered sample set per XP, in the order the XPs were given.
    samples_per_xp = [SampleManager(xp).get_samples(**kwargs) for xp in xps]
    stable = _match_stable_samples(samples_per_xp)
    unstable = _match_unstable_samples(samples_per_xp)
    # Stable matches first, then the positionally-matched no-input samples.
    return {**stable, **unstable}