prompt_token defaults to torch.zeros()
Files changed:
- audiocraft/audiogen.py +0 -31
- audiocraft/genmodel.py +32 -96
- audiocraft/lm.py +4 -23
- audiocraft/utils/audio_utils.py +0 -176 (deleted)
- audiocraft/utils/samples/__init__.py +0 -5 (deleted)
- audiocraft/utils/samples/manager.py +0 -386 (deleted)
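In short, this commit removes the audio I/O helpers (audio_utils.py) and the experiment sample manager (samples/), trims genmodel.py down to text-conditioned and unconditional generation without audio-prompt continuation, and makes LMModel.generate start from an empty token prompt when none is given. A minimal sketch of that fallback (the values below are illustrative stand-ins for the model's real attributes, not part of the diff):

import torch

num_samples = 2    # how many full waveforms to generate
num_codebooks = 4  # K parallel codebook streams in the LM

# With no prompt supplied, generation now starts from a zero-length
# token tensor of shape [B, K, 0]; B, K, T unpack to (2, 4, 0) here.
prompt = torch.zeros((num_samples, num_codebooks, 0), dtype=torch.long)
print(prompt.shape)  # torch.Size([2, 4, 0])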
audiocraft/audiogen.py
CHANGED
@@ -16,37 +16,6 @@ from audiocraft.encodec import CompressionModel
 from audiocraft.genmodel import BaseGenModel
 from audiocraft.lm import LMModel
 from audiocraft.loaders import load_compression_model, load_lm_model
-from .utils.audio_utils import f32_pcm, normalize_audio
-
-
-def audio_write(stem_name,
-                wav,
-                sample_rate,
-                format= 'wav',
-                mp3_rate=320,
-                ogg_rate= None,
-                normalize= True,
-                strategy= 'peak',
-                peak_clip_headroom_db=1,
-                rms_headroom_db= 18,
-                loudness_headroom_db = 14,
-                loudness_compressor = False,
-                log_clipping = True,
-                make_parent_dir = True,
-                add_suffix = True):
-
-    assert wav.dtype.is_floating_point, "wav is not floating point"
-    if wav.dim() == 1:
-        wav = wav[None]
-    elif wav.dim() > 2:
-        raise ValueError("Input wav should be at most 2 dimension.")
-    assert wav.isfinite().all()
-    wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
-                          rms_headroom_db, loudness_headroom_db, loudness_compressor,
-                          log_clipping=log_clipping, sample_rate=sample_rate,
-                          stem_name=str(stem_name))
-    return wav
-# ===
 
 class AudioGen(BaseGenModel):
     """AudioGen main model with convenient generation API.
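Note the removed audio_write here was a local stub that never wrote a file: it validated the tensor, delegated to normalize_audio, and returned the waveform. For callers that still want its default 'peak' behavior, a minimal stand-in could look like the sketch below (assuming a finite float waveform; the epsilon clamp is an addition of ours to avoid dividing by zero on silent input):

import torch

def peak_normalize(wav: torch.Tensor, headroom_db: float = 1.0) -> torch.Tensor:
    # Scale so the peak sits headroom_db below full scale, as the removed
    # wrapper's default 'peak' strategy did.
    assert wav.dtype.is_floating_point, "wav is not floating point"
    if wav.dim() == 1:
        wav = wav[None]  # promote [T] to [C, T]
    scale = 10 ** (-headroom_db / 20) / wav.abs().max().clamp(min=1e-8)
    return wav * scale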
audiocraft/genmodel.py
CHANGED
@@ -1,12 +1,10 @@
-from abc import ABC, abstractmethod
 import typing as tp
-
 import omegaconf
 import torch
 
+from abc import ABC, abstractmethod
 from .encodec import CompressionModel
 from .lm import LMModel
-from .utils.audio_utils import convert_audio
 from .conditioners import ConditioningAttributes
 from .utils.autocast import TorchAutocast
 
@@ -16,11 +14,9 @@ class BaseGenModel(ABC):
 
     Args:
        name (str): name of the model.
-        compression_model (CompressionModel):
-
-
-        max_duration (float, optional): maximum duration the model can produce,
-            otherwise, inferred from the training params.
+        compression_model (CompressionModel): Encodec with Seanet Decoder
+        lm (LMModel): Language model over discrete representations
+        max_duration (float, optional): As is using top250 token draw() we can gen xN sequences
     """
     def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
                  max_duration: tp.Optional[float] = None):
@@ -46,19 +42,16 @@
 
         self.max_duration: float = max_duration
         self.duration = self.max_duration
-
-        # self.extend_stride is the length of audio extension when generating samples longer
-        # than self.max_duration. NOTE: the derived class must set self.extend_stride to a
-        # positive float value when generating with self.duration > self.max_duration.
-        self.extend_stride: tp.Optional[float] = None
         self.device = next(iter(lm.parameters())).device
-        self.generation_params
-        self._progress_callback
+        self.generation_params={}
+        self._progress_callback=None
         if self.device.type == 'cpu':
             self.autocast = TorchAutocast(enabled=False)
         else:
             self.autocast = TorchAutocast(
-                enabled=True,
+                enabled=True,
+                device_type=self.device.type,
+                dtype=torch.float16)
 
     @property
     def frame_rate(self) -> float:
@@ -92,64 +85,36 @@
     @torch.no_grad()
     def _prepare_tokens_and_attributes(
             self,
-            descriptions: tp.Sequence[tp.Optional[str]],
-            prompt: tp.Optional[torch.Tensor],
-    ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
-        """Prepare model inputs.
-
-        Args:
-            descriptions (list of str): A list of strings used as text conditioning.
-            prompt (torch.Tensor): A batch of waveforms used for continuation.
-        """
+            descriptions,
+            prompt,
+            ):
         attributes = [
-            ConditioningAttributes(text={'description': description})
-            for description in descriptions]
-
-        if prompt is not None:
-            if descriptions is not None:
-                assert len(descriptions) == len(prompt), "Prompt and nb. descriptions doesn't match"
-            prompt = prompt.to(self.device)
-            prompt_tokens, scale = self.compression_model.encode(prompt)
-            assert scale is None
-        else:
-            prompt_tokens = None
+            ConditioningAttributes(text={'description': description}) for description in descriptions]
+        prompt_tokens = None
         return attributes, prompt_tokens
 
-    def generate_unconditional(self,
-
-
-
-
-        Args:
-            num_samples (int): Number of samples to be generated.
-            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
-        """
+    def generate_unconditional(self,
+                               num_samples,
+                               progress=False,
+                               return_tokens=False):
         descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
-        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
-        tokens = self._generate_tokens(attributes, prompt_tokens, progress)
+        attributes, _ = self._prepare_tokens_and_attributes(descriptions, None)
+        tokens = self._generate_tokens(attributes)
         if return_tokens:
             return self.generate_audio(tokens), tokens
         return self.generate_audio(tokens)
 
     def generate(self, descriptions, progress = False, return_tokens= False):
-        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
-
-        tokens = self._generate_tokens(attributes, prompt_tokens, progress)
+        attributes, _ = self._prepare_tokens_and_attributes(descriptions, None)
+        tokens = self._generate_tokens(attributes)
         if return_tokens:
             return self.generate_audio(tokens), tokens
         return self.generate_audio(tokens)
 
-    def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
-                         prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
-        """Generate discrete audio tokens given audio prompt and/or conditions.
-
-        Args:
-            attributes (list of ConditioningAttributes): Conditions used for generation (here text).
-            prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
-            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
-        Returns:
-            torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
-        """
+    def _generate_tokens(self, attributes,
+                         prompt_tokens=None,
+                         progress=False):
+
         total_gen_len = int(self.duration * self.frame_rate)
         max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
         current_gen_offset: int = 0
@@ -163,49 +128,20 @@
             else:
                 print(f'{generated_tokens: 6d} / {tokens_to_generate: 6d}', end='\r')
 
-        if prompt_tokens is not None:
-            assert max_prompt_len >= prompt_tokens.shape[-1], \
-                "Prompt is longer than audio to generate"
-
         callback = None
         if progress:
             callback = _progress_callback
 
         if self.duration <= self.max_duration:
             # generate by sampling from LM, simple case.
+
             with self.autocast:
-                gen_tokens = self.lm.generate(
-
-
-
+                gen_tokens = self.lm.generate(conditions=attributes,
+                                              callback=callback,
+                                              max_gen_len=total_gen_len,
+                                              **self.generation_params)
         else:
-            assert self.extend_stride is not None, "Stride should be defined to generate beyond max_duration"
-            assert self.extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
-            all_tokens = []
-            if prompt_tokens is None:
-                prompt_length = 0
-            else:
-                all_tokens.append(prompt_tokens)
-                prompt_length = prompt_tokens.shape[-1]
-
-            stride_tokens = int(self.frame_rate * self.extend_stride)
-            while current_gen_offset + prompt_length < total_gen_len:
-                time_offset = current_gen_offset / self.frame_rate
-                chunk_duration = min(self.duration - time_offset, self.max_duration)
-                max_gen_len = int(chunk_duration * self.frame_rate)
-                with self.autocast:
-                    gen_tokens = self.lm.generate(
-                        prompt_tokens, attributes,
-                        callback=callback, max_gen_len=max_gen_len, **self.generation_params)
-                if prompt_tokens is None:
-                    all_tokens.append(gen_tokens)
-                else:
-                    all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
-                prompt_tokens = gen_tokens[:, :, stride_tokens:]
-                prompt_length = prompt_tokens.shape[-1]
-                current_gen_offset += stride_tokens
-
-            gen_tokens = torch.cat(all_tokens, dim=-1)
+            print('<>Long gen ?<>')
         return gen_tokens
 
     def generate_audio(self, gen_tokens: torch.Tensor) -> torch.Tensor:
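With _prepare_tokens_and_attributes now always returning prompt_tokens = None, audio-prompt continuation is effectively gone from BaseGenModel, and the long-generation branch only prints a marker (so requesting self.duration > self.max_duration would leave gen_tokens unbound and raise). Typical usage reduces to the following sketch, where model stands for any loaded BaseGenModel subclass such as AudioGen (hypothetical call site, not part of the diff):

# Hypothetical usage; `model` is a loaded BaseGenModel subclass.
wav = model.generate(['dog barking', 'siren of an emergency vehicle'])

# Unconditional generation builds descriptions as [None] * num_samples.
wav, tokens = model.generate_unconditional(num_samples=2, return_tokens=True)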
audiocraft/lm.py
CHANGED
@@ -323,7 +323,7 @@ class LMModel(StreamingModule):
     def generate(self,
                  prompt = None,
                  conditions = [],
-                 num_samples = None,
+                 num_samples = 1, # THIS IS HOW MANY GENERATIONS - A SAMPLE IS A FULL WAV
                  max_gen_len: int = 256,
                  use_sampling: bool = True,
                  temp: float = 1.0,
@@ -335,30 +335,10 @@
                  check: bool = False,
                  callback: tp.Optional[tp.Callable[[int, int], None]] = None,
                  **kwargs) -> torch.Tensor:
-
-
-        Args:
-
-        Returns:
-            torch.Tensor: tokens
-        """
-        assert not self.training, "generation shouldn't be used in training mode."
+        print(f'{num_samples=}')
         first_param = next(iter(self.parameters()))
         device = first_param.device
-
-        # Checking all input shapes are consistent.
-        possible_num_samples = []
-        if num_samples is not None:
-            possible_num_samples.append(num_samples)
-        elif prompt is not None:
-            possible_num_samples.append(prompt.shape[0])
-        elif conditions:
-            possible_num_samples.append(len(conditions))
-        else:
-            possible_num_samples.append(1)
-        assert [x == possible_num_samples[0] for x in possible_num_samples], "Inconsistent inputs shapes"
-        num_samples = possible_num_samples[0]
-
+
         # below we create set of conditions: one conditional and one unconditional
         # to do that we merge the regular condition together with the null condition
         # we then do 1 forward pass instead of 2.
@@ -380,6 +360,7 @@
         if prompt is None:
             assert num_samples > 0
             prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)
+            print('\n\n\n\n DEFAULT PROMPT ZERO \n\n-')
 
         B, K, T = prompt.shape
         start_offset = T
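One behavioral consequence of deleting the shape-consistency block: num_samples is now taken at face value (default 1) rather than inferred from prompt.shape[0] or len(conditions), so a caller passing several conditions must set it explicitly. A sketch of the inference a caller might reinstate (hypothetical helper mirroring the removed logic, not part of this commit):

import typing as tp
import torch

def infer_num_samples(num_samples: tp.Optional[int],
                      prompt: tp.Optional[torch.Tensor],
                      conditions: tp.Sequence) -> int:
    # Prefer an explicit value, then the prompt batch size, then the
    # number of conditions, else 1 -- as the removed block did.
    if num_samples is not None:
        return num_samples
    if prompt is not None:
        return prompt.shape[0]
    if conditions:
        return len(conditions)
    return 1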
audiocraft/utils/audio_utils.py
DELETED
@@ -1,176 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-"""Various utilities for audio convertion (pcm format, sample rate and channels),
-and volume normalization."""
-import sys
-import typing as tp
-
-import julius
-import torch
-import torchaudio
-
-
-def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor:
-    """Convert audio to the given number of channels.
-
-    Args:
-        wav (torch.Tensor): Audio wave of shape [B, C, T].
-        channels (int): Expected number of channels as output.
-    Returns:
-        torch.Tensor: Downmixed or unchanged audio wave [B, C, T].
-    """
-    *shape, src_channels, length = wav.shape
-    if src_channels == channels:
-        pass
-    elif channels == 1:
-        # Case 1:
-        # The caller asked 1-channel audio, and the stream has multiple
-        # channels, downmix all channels.
-        wav = wav.mean(dim=-2, keepdim=True)
-    elif src_channels == 1:
-        # Case 2:
-        # The caller asked for multiple channels, but the input file has
-        # a single channel, replicate the audio over all channels.
-        wav = wav.expand(*shape, channels, length)
-    elif src_channels >= channels:
-        # Case 3:
-        # The caller asked for multiple channels, and the input file has
-        # more channels than requested. In that case return the first channels.
-        wav = wav[..., :channels, :]
-    else:
-        # Case 4: What is a reasonable choice here?
-        raise ValueError('The audio file has less channels than requested but is not mono.')
-    return wav
-
-
-def convert_audio(wav: torch.Tensor, from_rate: float,
-                  to_rate: float, to_channels: int) -> torch.Tensor:
-    """Convert audio to new sample rate and number of audio channels."""
-    wav = julius.resample_frac(wav, int(from_rate), int(to_rate))
-    wav = convert_audio_channels(wav, to_channels)
-    return wav
-
-
-def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14,
-                       loudness_compressor: bool = False, energy_floor: float = 2e-3):
-    """Normalize an input signal to a user loudness in dB LKFS.
-    Audio loudness is defined according to the ITU-R BS.1770-4 recommendation.
-
-    Args:
-        wav (torch.Tensor): Input multichannel audio data.
-        sample_rate (int): Sample rate.
-        loudness_headroom_db (float): Target loudness of the output in dB LUFS.
-        loudness_compressor (bool): Uses tanh for soft clipping.
-        energy_floor (float): anything below that RMS level will not be rescaled.
-    Returns:
-        torch.Tensor: Loudness normalized output data.
-    """
-    energy = wav.pow(2).mean().sqrt().item()
-    if energy < energy_floor:
-        return wav
-    transform = torchaudio.transforms.Loudness(sample_rate)
-    input_loudness_db = transform(wav).item()
-    # calculate the gain needed to scale to the desired loudness level
-    delta_loudness = -loudness_headroom_db - input_loudness_db
-    gain = 10.0 ** (delta_loudness / 20.0)
-    output = gain * wav
-    if loudness_compressor:
-        output = torch.tanh(output)
-    assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
-    return output
-
-
-def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optional[str] = None) -> None:
-    """Utility function to clip the audio with logging if specified."""
-    max_scale = wav.abs().max()
-    if log_clipping and max_scale > 1:
-        clamp_prob = (wav.abs() > 1).float().mean().item()
-        print(f"CLIPPING {stem_name or ''} happening with proba (a bit of clipping is okay):",
-              clamp_prob, "maximum scale: ", max_scale.item(), file=sys.stderr)
-    wav.clamp_(-1, 1)
-
-
-def normalize_audio(wav: torch.Tensor, normalize: bool = True,
-                    strategy: str = 'peak', peak_clip_headroom_db: float = 1,
-                    rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
-                    loudness_compressor: bool = False, log_clipping: bool = False,
-                    sample_rate: tp.Optional[int] = None,
-                    stem_name: tp.Optional[str] = None) -> torch.Tensor:
-    """Normalize the audio according to the prescribed strategy (see after).
-
-    Args:
-        wav (torch.Tensor): Audio data.
-        normalize (bool): if `True` (default), normalizes according to the prescribed
-            strategy (see after). If `False`, the strategy is only used in case clipping
-            would happen.
-        strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
-            i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
-            with extra headroom to avoid clipping. 'clip' just clips.
-        peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
-        rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
-            than the `peak_clip` one to avoid further clipping.
-        loudness_headroom_db (float): Target loudness for loudness normalization.
-        loudness_compressor (bool): If True, uses tanh based soft clipping.
-        log_clipping (bool): If True, basic logging on stderr when clipping still
-            occurs despite strategy (only for 'rms').
-        sample_rate (int): Sample rate for the audio data (required for loudness).
-        stem_name (str, optional): Stem name for clipping logging.
-    Returns:
-        torch.Tensor: Normalized audio.
-    """
-    scale_peak = 10 ** (-peak_clip_headroom_db / 20)
-    scale_rms = 10 ** (-rms_headroom_db / 20)
-    if strategy == 'peak':
-        rescaling = (scale_peak / wav.abs().max())
-        if normalize or rescaling < 1:
-            wav = wav * rescaling
-    elif strategy == 'clip':
-        wav = wav.clamp(-scale_peak, scale_peak)
-    elif strategy == 'rms':
-        mono = wav.mean(dim=0)
-        rescaling = scale_rms / mono.pow(2).mean().sqrt()
-        if normalize or rescaling < 1:
-            wav = wav * rescaling
-        _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
-    elif strategy == 'loudness':
-        assert sample_rate is not None, "Loudness normalization requires sample rate."
-        wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor)
-        _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
-    else:
-        assert wav.abs().max() < 1
-        assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'"
-    return wav
-
-
-def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
-    """Convert audio to float 32 bits PCM format.
-    """
-    if wav.dtype.is_floating_point:
-        return wav
-    elif wav.dtype == torch.int16:
-        return wav.float() / 2**15
-    elif wav.dtype == torch.int32:
-        return wav.float() / 2**31
-    raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
-
-
-def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
-    """Convert audio to int 16 bits PCM format.
-
-    ..Warning:: There exist many formula for doing this conversion. None are perfect
-    due to the asymmetry of the int16 range. One either have possible clipping, DC offset,
-    or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom,
-    it is possible that `i16_pcm(f32_pcm)) != Identity`.
-    """
-    if wav.dtype.is_floating_point:
-        assert wav.abs().max() <= 1
-        candidate = (wav * 2 ** 15).round()
-        if candidate.max() >= 2 ** 15:  # clipping would occur
-            candidate = (wav * (2 ** 15 - 1)).round()
-        return candidate.short()
-    else:
-        assert wav.dtype == torch.int16
-        return wav
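Deleting audio_utils.py removes convert_audio (julius-based resampling plus channel remixing), which genmodel.py stopped importing above. Where downstream code still needs the conversion, torchaudio.functional.resample covers the resampling half; a rough drop-in sketch (ours, not part of this commit, and simplified to the mono/stereo cases):

import torch
import torchaudio

def convert_audio(wav: torch.Tensor, from_rate: int, to_rate: int,
                  to_channels: int) -> torch.Tensor:
    # Resample, then remix channels roughly as the deleted helper did.
    wav = torchaudio.functional.resample(wav, orig_freq=from_rate, new_freq=to_rate)
    if to_channels == 1:
        wav = wav.mean(dim=-2, keepdim=True)  # downmix to mono
    elif wav.shape[-2] == 1:
        wav = wav.expand(*wav.shape[:-2], to_channels, wav.shape[-1])  # replicate mono
    return wav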
audiocraft/utils/samples/__init__.py
DELETED
@@ -1,5 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
audiocraft/utils/samples/manager.py
DELETED
@@ -1,386 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""
-API that can manage the storage and retrieval of generated samples produced by experiments.
-
-It offers the following benefits:
-* Samples are stored in a consistent way across epoch
-* Metadata about the samples can be stored and retrieved
-* Can retrieve audio
-* Identifiers are reliable and deterministic for prompted and conditioned samples
-* Can request the samples for multiple XPs, grouped by sample identifier
-* For no-input samples (not prompt and no conditions), samples across XPs are matched
-  by sorting their identifiers
-"""
-
-from concurrent.futures import ThreadPoolExecutor
-from dataclasses import asdict, dataclass
-from functools import lru_cache
-import hashlib
-import json
-import logging
-from pathlib import Path
-import re
-import typing as tp
-import unicodedata
-import uuid
-
-import dora
-import torch
-
-from ...data.audio import audio_read, audio_write
-
-
-logger = logging.getLogger(__name__)
-
-
-@dataclass
-class ReferenceSample:
-    id: str
-    path: str
-    duration: float
-
-
-@dataclass
-class Sample:
-    id: str
-    path: str
-    epoch: int
-    duration: float
-    conditioning: tp.Optional[tp.Dict[str, tp.Any]]
-    prompt: tp.Optional[ReferenceSample]
-    reference: tp.Optional[ReferenceSample]
-    generation_args: tp.Optional[tp.Dict[str, tp.Any]]
-
-    def __hash__(self):
-        return hash(self.id)
-
-    def audio(self) -> tp.Tuple[torch.Tensor, int]:
-        return audio_read(self.path)
-
-    def audio_prompt(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]:
-        return audio_read(self.prompt.path) if self.prompt is not None else None
-
-    def audio_reference(self) -> tp.Optional[tp.Tuple[torch.Tensor, int]]:
-        return audio_read(self.reference.path) if self.reference is not None else None
-
-
-class SampleManager:
-    """Audio samples IO handling within a given dora xp.
-
-    The sample manager handles the dumping and loading logic for generated and
-    references samples across epochs for a given xp, providing a simple API to
-    store, retrieve and compare audio samples.
-
-    Args:
-        xp (dora.XP): Dora experiment object. The XP contains information on the XP folder
-            where all outputs are stored and the configuration of the experiment,
-            which is useful to retrieve audio-related parameters.
-        map_reference_to_sample_id (bool): Whether to use the sample_id for all reference samples
-            instead of generating a dedicated hash id. This is useful to allow easier comparison
-            with ground truth sample from the files directly without having to read the JSON metadata
-            to do the mapping (at the cost of potentially dumping duplicate prompts/references
-            depending on the task).
-    """
-    def __init__(self, xp: dora.XP, map_reference_to_sample_id: bool = False):
-        self.xp = xp
-        self.base_folder: Path = xp.folder / xp.cfg.generate.path
-        self.reference_folder = self.base_folder / 'reference'
-        self.map_reference_to_sample_id = map_reference_to_sample_id
-        self.samples: tp.List[Sample] = []
-        self._load_samples()
-
-    @property
-    def latest_epoch(self):
-        """Latest epoch across all samples."""
-        return max(self.samples, key=lambda x: x.epoch).epoch if self.samples else 0
-
-    def _load_samples(self):
-        """Scan the sample folder and load existing samples."""
-        jsons = self.base_folder.glob('**/*.json')
-        with ThreadPoolExecutor(6) as pool:
-            self.samples = list(pool.map(self._load_sample, jsons))
-
-    @staticmethod
-    @lru_cache(2**26)
-    def _load_sample(json_file: Path) -> Sample:
-        with open(json_file, 'r') as f:
-            data: tp.Dict[str, tp.Any] = json.load(f)
-        # fetch prompt data
-        prompt_data = data.get('prompt')
-        prompt = ReferenceSample(id=prompt_data['id'], path=prompt_data['path'],
-                                 duration=prompt_data['duration']) if prompt_data else None
-        # fetch reference data
-        reference_data = data.get('reference')
-        reference = ReferenceSample(id=reference_data['id'], path=reference_data['path'],
-                                    duration=reference_data['duration']) if reference_data else None
-        # build sample object
-        return Sample(id=data['id'], path=data['path'], epoch=data['epoch'], duration=data['duration'],
-                      prompt=prompt, conditioning=data.get('conditioning'), reference=reference,
-                      generation_args=data.get('generation_args'))
-
-    def _init_hash(self):
-        return hashlib.sha1()
-
-    def _get_tensor_id(self, tensor: torch.Tensor) -> str:
-        hash_id = self._init_hash()
-        hash_id.update(tensor.numpy().data)
-        return hash_id.hexdigest()
-
-    def _get_sample_id(self, index: int, prompt_wav: tp.Optional[torch.Tensor],
-                       conditions: tp.Optional[tp.Dict[str, str]]) -> str:
-        """Computes an id for a sample given its input data.
-        This id is deterministic if prompt and/or conditions are provided by using a sha1 hash on the input.
-        Otherwise, a random id of the form "noinput_{uuid4().hex}" is returned.
-
-        Args:
-            index (int): Batch index, Helpful to differentiate samples from the same batch.
-            prompt_wav (torch.Tensor): Prompt used during generation.
-            conditions (dict[str, str]): Conditioning used during generation.
-        """
-        # For totally unconditioned generations we will just use a random UUID.
-        # The function get_samples_for_xps will do a simple ordered match with a custom key.
-        if prompt_wav is None and not conditions:
-            return f"noinput_{uuid.uuid4().hex}"
-
-        # Human readable portion
-        hr_label = ""
-        # Create a deterministic id using hashing
-        hash_id = self._init_hash()
-        hash_id.update(f"{index}".encode())
-        if prompt_wav is not None:
-            hash_id.update(prompt_wav.numpy().data)
-            hr_label += "_prompted"
-        else:
-            hr_label += "_unprompted"
-        if conditions:
-            encoded_json = json.dumps(conditions, sort_keys=True).encode()
-            hash_id.update(encoded_json)
-            cond_str = "-".join([f"{key}={slugify(value)}"
-                                 for key, value in sorted(conditions.items())])
-            cond_str = cond_str[:100]  # some raw text might be too long to be a valid filename
-            cond_str = cond_str if len(cond_str) > 0 else "unconditioned"
-            hr_label += f"_{cond_str}"
-        else:
-            hr_label += "_unconditioned"
-
-        return hash_id.hexdigest() + hr_label
-
-    def _store_audio(self, wav: torch.Tensor, stem_path: Path, overwrite: bool = False) -> Path:
-        """Stores the audio with the given stem path using the XP's configuration.
-
-        Args:
-            wav (torch.Tensor): Audio to store.
-            stem_path (Path): Path in sample output directory with file stem to use.
-            overwrite (bool): When False (default), skips storing an existing audio file.
-        Returns:
-            Path: The path at which the audio is stored.
-        """
-        existing_paths = [
-            path for path in stem_path.parent.glob(stem_path.stem + '.*')
-            if path.suffix != '.json'
-        ]
-        exists = len(existing_paths) > 0
-        if exists and overwrite:
-            logger.warning(f"Overwriting existing audio file with stem path {stem_path}")
-        elif exists:
-            return existing_paths[0]
-
-        audio_path = audio_write(stem_path, wav, **self.xp.cfg.generate.audio)
-        return audio_path
-
-    def add_sample(self, sample_wav: torch.Tensor, epoch: int, index: int = 0,
-                   conditions: tp.Optional[tp.Dict[str, str]] = None, prompt_wav: tp.Optional[torch.Tensor] = None,
-                   ground_truth_wav: tp.Optional[torch.Tensor] = None,
-                   generation_args: tp.Optional[tp.Dict[str, tp.Any]] = None) -> Sample:
-        """Adds a single sample.
-        The sample is stored in the XP's sample output directory, under a corresponding epoch folder.
-        Each sample is assigned an id which is computed using the input data. In addition to the
-        sample itself, a json file containing associated metadata is stored next to it.
-
-        Args:
-            sample_wav (torch.Tensor): sample audio to store. Tensor of shape [channels, shape].
-            epoch (int): current training epoch.
-            index (int): helpful to differentiate samples from the same batch.
-            conditions (dict[str, str], optional): conditioning used during generation.
-            prompt_wav (torch.Tensor, optional): prompt used during generation. Tensor of shape [channels, shape].
-            ground_truth_wav (torch.Tensor, optional): reference audio where prompt was extracted from.
-                Tensor of shape [channels, shape].
-            generation_args (dict[str, any], optional): dictionary of other arguments used during generation.
-        Returns:
-            Sample: The saved sample.
-        """
-        sample_id = self._get_sample_id(index, prompt_wav, conditions)
-        reuse_id = self.map_reference_to_sample_id
-        prompt, ground_truth = None, None
-        if prompt_wav is not None:
-            prompt_id = sample_id if reuse_id else self._get_tensor_id(prompt_wav.sum(0, keepdim=True))
-            prompt_duration = prompt_wav.shape[-1] / self.xp.cfg.sample_rate
-            prompt_path = self._store_audio(prompt_wav, self.base_folder / str(epoch) / 'prompt' / prompt_id)
-            prompt = ReferenceSample(prompt_id, str(prompt_path), prompt_duration)
-        if ground_truth_wav is not None:
-            ground_truth_id = sample_id if reuse_id else self._get_tensor_id(ground_truth_wav.sum(0, keepdim=True))
-            ground_truth_duration = ground_truth_wav.shape[-1] / self.xp.cfg.sample_rate
-            ground_truth_path = self._store_audio(ground_truth_wav, self.base_folder / 'reference' / ground_truth_id)
-            ground_truth = ReferenceSample(ground_truth_id, str(ground_truth_path), ground_truth_duration)
-        sample_path = self._store_audio(sample_wav, self.base_folder / str(epoch) / sample_id, overwrite=True)
-        duration = sample_wav.shape[-1] / self.xp.cfg.sample_rate
-        sample = Sample(sample_id, str(sample_path), epoch, duration, conditions, prompt, ground_truth, generation_args)
-        self.samples.append(sample)
-        with open(sample_path.with_suffix('.json'), 'w') as f:
-            json.dump(asdict(sample), f, indent=2)
-        return sample
-
-    def add_samples(self, samples_wavs: torch.Tensor, epoch: int,
-                    conditioning: tp.Optional[tp.List[tp.Dict[str, tp.Any]]] = None,
-                    prompt_wavs: tp.Optional[torch.Tensor] = None,
-                    ground_truth_wavs: tp.Optional[torch.Tensor] = None,
-                    generation_args: tp.Optional[tp.Dict[str, tp.Any]] = None) -> tp.List[Sample]:
-        """Adds a batch of samples.
-        The samples are stored in the XP's sample output directory, under a corresponding
-        epoch folder. Each sample is assigned an id which is computed using the input data and their batch index.
-        In addition to the sample itself, a json file containing associated metadata is stored next to it.
-
-        Args:
-            sample_wavs (torch.Tensor): Batch of audio wavs to store. Tensor of shape [batch_size, channels, shape].
-            epoch (int): Current training epoch.
-            conditioning (list of dict[str, str], optional): List of conditions used during generation,
-                one per sample in the batch.
-            prompt_wavs (torch.Tensor, optional): Prompts used during generation. Tensor of shape
-                [batch_size, channels, shape].
-            ground_truth_wav (torch.Tensor, optional): Reference audio where prompts were extracted from.
-                Tensor of shape [batch_size, channels, shape].
-            generation_args (dict[str, Any], optional): Dictionary of other arguments used during generation.
-        Returns:
-            samples (list of Sample): The saved audio samples with prompts, ground truth and metadata.
-        """
-        samples = []
-        for idx, wav in enumerate(samples_wavs):
-            prompt_wav = prompt_wavs[idx] if prompt_wavs is not None else None
-            gt_wav = ground_truth_wavs[idx] if ground_truth_wavs is not None else None
-            conditions = conditioning[idx] if conditioning is not None else None
-            samples.append(self.add_sample(wav, epoch, idx, conditions, prompt_wav, gt_wav, generation_args))
-        return samples
-
-    def get_samples(self, epoch: int = -1, max_epoch: int = -1, exclude_prompted: bool = False,
-                    exclude_unprompted: bool = False, exclude_conditioned: bool = False,
-                    exclude_unconditioned: bool = False) -> tp.Set[Sample]:
-        """Returns a set of samples for this XP. Optionally, you can filter which samples to obtain.
-        Please note that existing samples are loaded during the manager's initialization, and added samples through this
-        manager are also tracked. Any other external changes are not tracked automatically, so creating a new manager
-        is the only way detect them.
-
-        Args:
-            epoch (int): If provided, only return samples corresponding to this epoch.
-            max_epoch (int): If provided, only return samples corresponding to the latest epoch that is <= max_epoch.
-            exclude_prompted (bool): If True, does not include samples that used a prompt.
-            exclude_unprompted (bool): If True, does not include samples that did not use a prompt.
-            exclude_conditioned (bool): If True, excludes samples that used conditioning.
-            exclude_unconditioned (bool): If True, excludes samples that did not use conditioning.
-        Returns:
-            Samples (set of Sample): The retrieved samples matching the provided filters.
-        """
-        if max_epoch >= 0:
-            samples_epoch = max(sample.epoch for sample in self.samples if sample.epoch <= max_epoch)
-        else:
-            samples_epoch = self.latest_epoch if epoch < 0 else epoch
-        samples = {
-            sample
-            for sample in self.samples
-            if (
-                (sample.epoch == samples_epoch) and
-                (not exclude_prompted or sample.prompt is None) and
-                (not exclude_unprompted or sample.prompt is not None) and
-                (not exclude_conditioned or not sample.conditioning) and
-                (not exclude_unconditioned or sample.conditioning)
-            )
-        }
-        return samples
-
-
-def slugify(value: tp.Any, allow_unicode: bool = False):
-    """Process string for safer file naming.
-
-    Taken from https://github.com/django/django/blob/master/django/utils/text.py
-
-    Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
-    dashes to single dashes. Remove characters that aren't alphanumerics,
-    underscores, or hyphens. Convert to lowercase. Also strip leading and
-    trailing whitespace, dashes, and underscores.
-    """
-    value = str(value)
-    if allow_unicode:
-        value = unicodedata.normalize("NFKC", value)
-    else:
-        value = (
-            unicodedata.normalize("NFKD", value)
-            .encode("ascii", "ignore")
-            .decode("ascii")
-        )
-    value = re.sub(r"[^\w\s-]", "", value.lower())
-    return re.sub(r"[-\s]+", "-", value).strip("-_")
-
-
-def _match_stable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> tp.Dict[str, tp.List[Sample]]:
-    # Create a dictionary of stable id -> sample per XP
-    stable_samples_per_xp = [{
-        sample.id: sample for sample in samples
-        if sample.prompt is not None or sample.conditioning
-    } for samples in samples_per_xp]
-    # Set of all stable ids
-    stable_ids = {id for samples in stable_samples_per_xp for id in samples.keys()}
-    # Dictionary of stable id -> list of samples. If an XP does not have it, assign None
-    stable_samples = {id: [xp.get(id) for xp in stable_samples_per_xp] for id in stable_ids}
-    # Filter out ids that contain None values (we only want matched samples after all)
-    # cast is necessary to avoid mypy linter errors.
-    return {id: tp.cast(tp.List[Sample], samples) for id, samples in stable_samples.items() if None not in samples}
-
-
-def _match_unstable_samples(samples_per_xp: tp.List[tp.Set[Sample]]) -> tp.Dict[str, tp.List[Sample]]:
-    # For unstable ids, we use a sorted list since we'll match them in order
-    unstable_samples_per_xp = [[
-        sample for sample in sorted(samples, key=lambda x: x.id)
-        if sample.prompt is None and not sample.conditioning
-    ] for samples in samples_per_xp]
-    # Trim samples per xp so all samples can have a match
-    min_len = min([len(samples) for samples in unstable_samples_per_xp])
-    unstable_samples_per_xp = [samples[:min_len] for samples in unstable_samples_per_xp]
-    # Dictionary of index -> list of matched samples
-    return {
-        f'noinput_{i}': [samples[i] for samples in unstable_samples_per_xp] for i in range(min_len)
-    }
-
-
-def get_samples_for_xps(xps: tp.List[dora.XP], **kwargs) -> tp.Dict[str, tp.List[Sample]]:
-    """Gets a dictionary of matched samples across the given XPs.
-    Each dictionary entry maps a sample id to a list of samples for that id. The number of samples per id
-    will always match the number of XPs provided and will correspond to each XP in the same order given.
-    In other words, only samples that can be match across all provided XPs will be returned
-    in order to satisfy this rule.
-
-    There are two types of ids that can be returned: stable and unstable.
-    * Stable IDs are deterministic ids that were computed by the SampleManager given a sample's inputs
-      (prompts/conditioning). This is why we can match them across XPs.
-    * Unstable IDs are of the form "noinput_{idx}" and are generated on-the-fly, in order to map samples
-      that used non-deterministic, random ids. This is the case for samples that did not use prompts or
-      conditioning for their generation. This function will sort these samples by their id and match them
-      by their index.
-
-    Args:
-        xps: a list of XPs to match samples from.
-        start_epoch (int): If provided, only return samples corresponding to this epoch or newer.
-        end_epoch (int): If provided, only return samples corresponding to this epoch or older.
-        exclude_prompted (bool): If True, does not include samples that used a prompt.
-        exclude_unprompted (bool): If True, does not include samples that did not use a prompt.
-        exclude_conditioned (bool): If True, excludes samples that used conditioning.
-        exclude_unconditioned (bool): If True, excludes samples that did not use conditioning.
-    """
-    managers = [SampleManager(xp) for xp in xps]
-    samples_per_xp = [manager.get_samples(**kwargs) for manager in managers]
-    stable_samples = _match_stable_samples(samples_per_xp)
-    unstable_samples = _match_unstable_samples(samples_per_xp)
-    return dict(stable_samples, **unstable_samples)
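Nothing in this commit replaces the SampleManager, so deterministic sample ids, prompt/reference bookkeeping, and cross-XP matching go away with it. Callers that still need to persist generations can fall back to plain torchaudio.save; a minimal sketch (assuming wav is the [B, C, T] float output of model.generate(...) and sample_rate comes from the model):

import torchaudio

for i, one in enumerate(wav):
    # one is a [C, T] float tensor; write it out as a WAV file.
    torchaudio.save(f'sample_{i}.wav', one.cpu(), sample_rate)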