instantiate audiogen in demo
- audiocraft/audiogen.py +14 -77
- audiocraft/builders.py +5 -3
- audiocraft/conditioners.py +22 -89
- audiocraft/lm.py +29 -3
- audiocraft/loaders.py +0 -25
- audiocraft/utils/__init__.py +0 -6
- audiocraft/utils/autocast.py +0 -40
- audiocraft/utils/cache.py +0 -324
- audiocraft/utils/checkpoint.py +0 -161
- audiocraft/utils/export.py +0 -79
- audiocraft/utils/export_legacy.py +0 -70
- audiocraft/utils/notebook.py +0 -32
- audiocraft/utils/profiler.py +0 -38
- audiocraft/utils/utils.py +0 -98
- demo.py +74 -4
audiocraft/audiogen.py CHANGED

@@ -1,9 +1,3 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
 import typing as tp
 import torch
 from audiocraft.loaders import load_compression_model, load_lm_model
@@ -11,7 +5,6 @@ import typing as tp
 import omegaconf
 import torch
 import numpy as np
-from abc import ABC, abstractmethod
 from .lm import LMModel
 from .conditioners import ConditioningAttributes
 from .utils.autocast import TorchAutocast
@@ -25,7 +18,7 @@ def _shift(x):
     return x


-class BaseGenModel(ABC):
+class AudioGen():
     """Base generative model with convenient generation API.

     Args:
@@ -90,21 +83,13 @@ class BaseGenModel(ABC):
         attributes = [
             ConditioningAttributes(text={'description': d}) for d in descriptions]
         tokens = self._generate_tokens(attributes)
+        print(f'\n{tokens.shape=}\n{tokens=} FINAL 5 AUD')
         return self.generate_audio(tokens)

     def _generate_tokens(self, attributes):

         total_gen_len = int(self.duration * self.frame_rate)

-        # # print(f'{self.generation_params=}')
-        # self.generation_params={'use_sampling': True,
-        #                         'temp': 1.0, 'top_k': 250,
-        #                         'top_p': 0.0, 'cfg_coef': 2.4, 'two_step_cfg': False}
-
         if self.duration <= self.max_duration:
             # generate by sampling from LM, simple case.
@@ -127,66 +112,18 @@
         gen_audio = self.compression_model.decode(gen_tokens, None)
         return gen_audio

-
-class AudioGen(BaseGenModel):
-
-        if torch.cuda.device_count():
-            device = 'cuda'
-        else:
-            device = 'cpu'
-
-        compression_model = load_compression_model(name, device=device)
-        lm = load_lm_model(name, device=device)
-        assert 'self_wav' not in lm.condition_provider.conditioners, \
-            "AudioGen do not support waveform conditioning for now"
-        return AudioGen(name, compression_model, lm)
-
-    def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
-                              top_p: float = 0.0, temperature: float = 1.0,
-                              duration: float = 10.0, cfg_coef: float = 2.4,
-                              two_step_cfg: bool = False, extend_stride: float = 2):
-        """Set the generation parameters for AudioGen.
-
-        Args:
-            use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
-            top_k (int, optional): top_k used for sampling. Defaults to 250.
-            top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
-            temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
-            duration (float, optional): Duration of the generated waveform. Defaults to 10.0.
-            cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
-            two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
-                instead of batching together the two. This has some impact on how things
-                are padded but seems to have little impact in practice.
-            extend_stride: when doing extended generation (i.e. more than 10 seconds), by how much
-                should we extend the audio each time. Larger values will mean less context is
-                preserved, and shorter value will require extra computations.
-        """
-        assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
-        self.extend_stride = extend_stride
-        self.duration = duration
-        self.generation_params = {
-            'use_sampling': use_sampling,
-            'temp': temperature,
-            'top_k': top_k,
-            'top_p': top_p,
-            'cfg_coef': cfg_coef,
-            'two_step_cfg': two_step_cfg,
-        }
+def get_pretrained(name='facebook/audiogen-medium',
+                   device=None):
+    """Return pretrained model, we provide a single model for now:
+    - facebook/audiogen-medium (1.5B), text to sound,
+    # see: https://huggingface.co/facebook/audiogen-medium
+    """
+    compression_model = load_compression_model(name, device=device)
+    lm = load_lm_model(name, device=device)
+    assert 'self_wav' not in lm.condition_provider.conditioners, \
+        "AudioGen do not support waveform conditioning for now"
+    return AudioGen(name, compression_model, lm)
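For the demo, this reduces the public API to loading the pretrained checkpoint and calling the generation path shown in the hunks above. A minimal usage sketch, assuming `get_pretrained` ends up exposed at module level and that the renamed `AudioGen` class keeps a `generate()` entry point that wraps the description-to-attributes code visible above (both of those names sit outside the visible hunks, so they are assumptions, not confirmed by this diff):

import torch
from audiocraft.audiogen import get_pretrained  # assumed import path after this change

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = get_pretrained('facebook/audiogen-medium', device=device)
# generate() (assumed name) builds ConditioningAttributes from the text descriptions,
# samples tokens with the LM, then decodes them with the compression model.
wav = model.generate(['dogs barking in the distance'])
print(wav.shape)  # decoded waveform tensor, roughly [batch, channels, samples]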
audiocraft/builders.py CHANGED

@@ -12,16 +12,18 @@ from .lm import LMModel
 from .seanet import SEANetDecoder
 from .codebooks_patterns import DelayedPatternProvider
 from .conditioners import (
-    BaseConditioner,
     ConditionFuser,
     ConditioningProvider,
     T5Conditioner,
 )
 from .unet import DiffusionUnet
 from .vq import ResidualVectorQuantizer

 from .diffusion_schedule import MultiBandProcessor, SampleProcessor

+def dict_from_config(cfg):
+    dct = omegaconf.OmegaConf.to_container(cfg, resolve=True)
+    return dct

 def get_quantizer(quantizer, cfg, dimension):
     klass = {
@@ -118,7 +120,7 @@ def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> Cond
     duration = cfg.dataset.segment_duration
     cfg = getattr(cfg, 'conditioners')
     dict_cfg = {} if cfg is None else dict_from_config(cfg)
-    conditioners: tp.Dict[str,
+    conditioners: tp.Dict[str, T5Conditioner] = {}
     condition_provider_args = dict_cfg.pop('args', {})
     condition_provider_args.pop('merge_text_conditions_p', None)
     condition_provider_args.pop('drop_desc_p', None)
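The inlined `dict_from_config` replaces the helper that previously lived in `audiocraft/utils/utils.py` (deleted further down). A tiny illustration of what it does, assuming an omegaconf config shaped roughly like the conditioner configs the builders consume (these keys are made up for the example):

import omegaconf

def dict_from_config(cfg):
    # mirrors the helper added to builders.py above
    return omegaconf.OmegaConf.to_container(cfg, resolve=True)

cfg = omegaconf.OmegaConf.create({'description': {'model': 't5', 't5': {'name': 't5-large'}}})
print(dict_from_config(cfg))  # plain nested dict: {'description': {'model': 't5', 't5': {'name': 't5-large'}}}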
audiocraft/conditioners.py CHANGED

@@ -8,12 +8,6 @@ import soundfile
 from transformers import T5EncoderModel, T5Tokenizer  # type: ignore
 import torch
 from torch import nn
-
-
-from .utils.autocast import TorchAutocast
-
-
-
 logger = logging.getLogger(__name__)
 TextCondition = tp.Optional[str]  # a text condition can be a string or None (if doesn't exist)
 ConditionType = tp.Tuple[torch.Tensor, torch.Tensor]  # condition, mask
@@ -21,7 +15,6 @@ ConditionType = tp.Tuple[torch.Tensor, torch.Tensor]  # condition, mask



-
 class JointEmbedCondition(tp.NamedTuple):
     wav: torch.Tensor
     text: tp.List[tp.Optional[str]]
@@ -76,18 +69,6 @@
     return out


-
-
-
-
-
-
-
-
-
-
-
-
 class Tokenizer:
     """Base tokenizer implementation
     (in case we want to introduce more advances tokenizers in the future).
@@ -95,59 +76,10 @@
     def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError()

-
-class BaseConditioner(nn.Module):
-    """Base model for all conditioner modules.
-    We allow the output dim to be different than the hidden dim for two reasons:
-    1) keep our LUTs small when the vocab is large;
-    2) make all condition dims consistent.
-
-    Args:
-        dim (int): Hidden dim of the model.
-        output_dim (int): Output dim of the conditioner.
-    """
-    def __init__(self, dim: int, output_dim: int):
-        super().__init__()
-        self.dim = dim
-        self.output_dim = output_dim
-        self.output_proj = nn.Linear(dim, output_dim)
-
-    def tokenize(self, *args, **kwargs) -> tp.Any:
-        """Should be any part of the processing that will lead to a synchronization
-        point, e.g. BPE tokenization with transfer to the GPU.
-
-        The returned value will be saved and return later when calling forward().
-        """
-        raise NotImplementedError()
-

-class TextConditioner(BaseConditioner):
-    ...
-
-
-class T5Conditioner(TextConditioner):
-    """T5-based TextConditioner.
-
-    Args:
-        name (str): Name of the T5 model.
-        output_dim (int): Output dim of the conditioner.
-        finetune (bool): Whether to fine-tune T5 at train time.
-        device (str): Device for T5 Conditioner.
-        autocast_dtype (tp.Optional[str], optional): Autocast dtype.
-        word_dropout (float, optional): Word dropout probability.
-        normalize_text (bool, optional): Whether to apply text normalization.
-    """
+class T5Conditioner(nn.Module):
     MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
               "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
               "google/flan-t5-xl", "google/flan-t5-xxl"]
@@ -164,24 +96,23 @@ class T5Conditioner(TextConditioner):
     "google/flan-t5-11b": 1024,
     }

-    def __init__(self,
+    def __init__(self,
+                 name: str,
+                 output_dim: int,
+                 device: str,
+                 word_dropout: float = 0.,
+                 normalize_text: bool = False,
+                 finetune=False):
+        print(f'{finetune=}')
         assert name in self.MODELS, f"Unrecognized t5 model name (should in {self.MODELS})"
-        super().__init__(
+        super().__init__()
+        self.dim = self.MODELS_DIMS[name]
+        self.output_dim = output_dim
+        self.output_proj = nn.Linear(self.dim, output_dim)
         self.device = device
         self.name = name
-        self.finetune = finetune
         self.word_dropout = word_dropout
-
-        self.autocast = TorchAutocast(enabled=False)
-        if self.device != 'cpu':
-            logger.warning("T5 has no autocast, this might lead to NaN")
-        else:
-            dtype = getattr(torch, autocast_dtype)
-            assert isinstance(dtype, torch.dtype)
-            logger.info(f"T5 will be evaluated with autocast as {autocast_dtype}")
-            self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
+
         # Let's disable logging temporarily because T5 will vomit some errors otherwise.
         # thanks https://gist.github.com/simon-weber/7853144
         previous_level = logging.root.manager.disable
@@ -190,7 +121,7 @@ class T5Conditioner(TextConditioner):
         warnings.simplefilter("ignore")
         try:
             self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
-            t5 = T5EncoderModel.from_pretrained(name).train(mode=finetune)
+            t5 = T5EncoderModel.from_pretrained(name).eval()  # .train(mode=finetune)
         finally:
             logging.disable(previous_level)
         if finetune:
@@ -223,15 +154,15 @@ class T5Conditioner(TextConditioner):
         mask[empty_idx, :] = 0  # zero-out index where the input is non-existant
         return inputs

-    def forward(self, inputs
+    def forward(self, inputs):
         mask = inputs['attention_mask']
-        with torch.
+        with torch.no_grad():
             embeds = self.t5(**inputs).last_hidden_state
         embeds = self.output_proj(embeds.to(self.output_proj.weight))
         embeds = (embeds * mask.unsqueeze(-1))

         # T5 torch.Size([2, 4, 1536]) dict_keys(['input_ids', 'attention_mask'])
-
+        print(f'{embeds.dtype=}')  # inputs["input_ids"].shape=torch.Size([2, 4])
         return embeds, mask

@@ -248,7 +179,9 @@ class ConditioningProvider(nn.Module):
     conditioners (dict): Dictionary of conditioners.
     device (torch.device or str, optional): Device for conditioners and output condition types.
     """
-    def __init__(self,
+    def __init__(self,
+                 conditioners,
+                 device="cpu"):
         super().__init__()
         self.device = device
         self.conditioners = nn.ModuleDict(conditioners)
@@ -263,7 +196,7 @@ class ConditioningProvider(nn.Module):

     @property
     def text_conditions(self):
-        return [k for k, v in self.conditioners.items() if isinstance(v,
+        return [k for k, v in self.conditioners.items() if isinstance(v, T5Conditioner)]
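With `BaseConditioner` and `TextConditioner` gone, the conditioner now runs T5 frozen in eval mode under `torch.no_grad()` instead of through `TorchAutocast`. A rough standalone sketch of the same computation, assuming the `tokenize()` step kept from the original class still produces the usual `input_ids`/`attention_mask` pair; the 768-to-1536 projection is only an example (768 is t5-base's hidden size, 1536 is taken from the debug print in the hunk above):

import torch
from torch import nn
from transformers import T5EncoderModel, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained('t5-base')
t5 = T5EncoderModel.from_pretrained('t5-base').eval()   # frozen encoder, as in the diff
output_proj = nn.Linear(768, 1536)                       # hidden dim -> conditioning dim (example sizes)

inputs = tokenizer(['dogs barking', ''], return_tensors='pt', padding=True)
mask = inputs['attention_mask']
with torch.no_grad():                                    # mirrors the new forward()
    embeds = t5(**inputs).last_hidden_state
embeds = output_proj(embeds) * mask.unsqueeze(-1)        # zero out padded / empty-text positions
print(embeds.shape, embeds.dtype)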
audiocraft/lm.py CHANGED

@@ -10,9 +10,35 @@ from audiocraft.transformer import StreamingTransformer, create_norm_fn
 from dataclasses import dataclass
 from functools import partial
 from torch import nn
-from audiocraft.utils import utils
 from audiocraft.activations import get_activation_fn

+def sample_top_k(p, k=250, n_draw=None):
+    """
+    p probabs 2048 ?
+    num_draw : how many tokens to sample (for duplicate elongation)
+    """
+
+    p = torch.softmax(p, dim=-1)  # p/temp
+
+    top_k_value, i250 = torch.topk(p, k, dim=-1)  # probs: [1, 4, 2048]
+    # print('\n_____TOPK________\n', top_k_value.shape, top_k_value[0, 0, :10], '\n___________END_TOPK____________\n')
+    min_value_top_k = top_k_value[..., [-1]]
+    p *= (p >= min_value_top_k).float()
+    p.div_(p.sum(dim=-1, keepdim=True))
+    # -- next_token = multinomial(probs, num_samples=num_draw)
+
+    # RESHAPED into bs, 4, 250
+    p_ = p.reshape(-1, p.shape[-1])
+
+    out = torch.multinomial(p_,
+                            num_samples=n_draw,
+                            replacement=False)  # [4, num_draw]
+    return out.transpose(0, 1)[:, :, None]  # [num_draw, 4, 1]
+

 # ============================================== From LM.py

@@ -147,7 +173,7 @@
     super().__init__()
     self.cfg_coef = cfg_coef

-    self.n_draw =
+    self.n_draw = 1
     self.condition_provider = condition_provider
     self.fuser = fuser
     self.card = card  # 2048 ?
@@ -332,7 +358,7 @@
         token_count=offset)

     # print(f'BEF {logits.shape=} BEF utils.SampleTop5')  # AGREES 4 BEF logits.shape=torch.Size([1, 4, 1, 2048]) BEF utils.SampleTop5
-    next_token =
+    next_token = sample_top_k(logits, n_draw=self.n_draw)  # [1,4,2048] logits
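The sampling helper that used to come from `audiocraft.utils.utils` is now defined directly in `lm.py`, with an `n_draw` argument so several candidate tokens can be drawn per step. A quick shape check on fake logits, assuming the function is importable from `audiocraft.lm` after this change (the [1, 4, 1, 2048] shape comes from the debug comment above):

import torch
from audiocraft.lm import sample_top_k  # assumed import path after this change

logits = torch.randn(1, 4, 1, 2048)             # [batch, n_codebooks, 1, card]
tokens = sample_top_k(logits, k=250, n_draw=3)  # keep the top-250 probs, draw 3 without replacement
print(tokens.shape)                             # [3, 4, 1]: n_draw candidate tokens per codebook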
audiocraft/loaders.py CHANGED

@@ -120,32 +120,7 @@ def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu',
     return model


-def load_lm_model_magnet(file_or_url_or_id: tp.Union[Path, str], compression_model_frame_rate: int,
-                         device='cpu', cache_dir: tp.Optional[str] = None):
-    pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
-    cfg = OmegaConf.create(pkg['xp.cfg'])
-    cfg.device = str(device)
-    if cfg.device == 'cpu':
-        cfg.dtype = 'float32'
-    else:
-        cfg.dtype = 'float16'
-    _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
-    _delete_param(cfg, 'conditioners.args.drop_desc_p')
-
-    cfg.transformer_lm.compression_model_framerate = compression_model_frame_rate
-    cfg.transformer_lm.segment_duration = cfg.dataset.segment_duration
-    cfg.transformer_lm.span_len = cfg.masking.span_len

-    # MAGNeT models v1 support only xformers backend.
-    from .transformer import set_efficient_attention_backend
-    if cfg.transformer_lm.memory_efficient:
-        set_efficient_attention_backend("xformers")
-
-    model = builders.get_lm_model(cfg)
-    model.load_state_dict(pkg['best_state'])
-    model.eval()
-    model.cfg = cfg
-    return model


 def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str],
audiocraft/utils/__init__.py DELETED
Entire file removed (6 lines): the Meta license header and the """Utilities.""" package docstring.
audiocraft/utils/autocast.py DELETED
Entire file removed (40 lines): the `TorchAutocast` context-manager wrapper around `torch.autocast` (an enable/disable switch plus a friendlier RuntimeError message). It backed the autocast path that the conditioners.py change above removes from the T5 conditioner.
audiocraft/utils/cache.py DELETED
Entire file removed (324 lines): the training-time caching utilities, namely `get_full_embed`, `EmbeddingCache` (disk plus in-memory cache around pre-computed embeddings), and `CachedBatchWriter`/`CachedBatchLoader` (zip-based caching and loading of pre-computed mini-batches).
audiocraft/utils/checkpoint.py DELETED
Entire file removed (161 lines): training checkpoint helpers, namely `CheckpointSource`, `checkpoint_name()`, `is_sharded_checkpoint()`, `resolve_checkpoint_path()`, `load_checkpoint()`, `save_checkpoint()`, `flush_stale_checkpoints()`, `check_sharded_checkpoint()` and `_safe_save_checkpoint()`.
audiocraft/utils/export.py DELETED
Entire file removed (79 lines): utilities for exporting a training checkpoint to a lightweight release checkpoint, namely `export_encodec()`, `export_pretrained_compression_model()` and `export_lm()`.
audiocraft/utils/export_legacy.py DELETED
Entire file removed (70 lines): legacy export functions kept for reference from the first release, namely `_clean_lm_cfg()`, `export_encodec()` and `export_lm()`.
audiocraft/utils/notebook.py DELETED
Entire file removed (32 lines): the `display_audio()` helper that rendered an IPython audio player for decoded sample tensors of shape [B, C, T] or [C, T].
audiocraft/utils/profiler.py
DELETED
@@ -1,38 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import logging
-import typing as tp
-
-import dora
-import torch
-
-
-logger = logging.getLogger(__name__)
-
-
-class Profiler:
-    """Context manager wrapper for xformers profiler.
-    """
-    def __init__(self, module: torch.nn.Module, enabled: bool = False):
-        self.profiler: tp.Optional[tp.Any] = None
-        if enabled:
-            from xformers.profiler import profile
-            output_dir = dora.get_xp().folder / 'profiler_data'
-            logger.info("Profiling activated, results with be saved to %s", output_dir)
-            self.profiler = profile(output_dir=output_dir, module=module)
-
-    def step(self):
-        if self.profiler is not None:
-            self.profiler.step()  # type: ignore
-
-    def __enter__(self):
-        if self.profiler is not None:
-            return self.profiler.__enter__()  # type: ignore
-
-    def __exit__(self, exc_type, exc_value, exc_tb):
-        if self.profiler is not None:
-            return self.profiler.__exit__(exc_type, exc_value, exc_tb)  # type: ignore
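The wrapper was meant to sit around a training loop. A minimal sketch with a hypothetical model and step count; with `enabled=False` every call is a no-op, and `enabled=True` additionally requires xformers plus a running dora experiment:

import torch
# Profiler as defined above.
model = torch.nn.Linear(8, 8)
profiler = Profiler(model, enabled=False)
with profiler:
    for _ in range(3):
        model(torch.randn(2, 8))   # one training "step"
        profiler.step()            # advances the xformers profiling schedule when enabled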
audiocraft/utils/utils.py
DELETED
@@ -1,98 +0,0 @@
-import hashlib
-import json
-import logging
-import typing as tp
-import flashy
-import flashy.distrib
-import omegaconf
-import torch
-
-
-logger = logging.getLogger(__name__)
-
-
-def model_hash(model: torch.nn.Module) -> str:
-    """Return a model hash. This should allow us to track regressions in model init
-    from the logs of past experiments.
-    """
-    hasher = hashlib.sha1()
-    for p in model.parameters():
-        hasher.update(p.data.cpu().numpy().tobytes())
-    return hasher.hexdigest()
-
-
-def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
-    """Convenience function to map an omegaconf configuration to a dictionary.
-
-    Args:
-        cfg (omegaconf.DictConfig): Original configuration to map to dict.
-    Returns:
-        dict: Config as dictionary object.
-    """
-    dct = omegaconf.OmegaConf.to_container(cfg, resolve=True)
-    assert isinstance(dct, dict)
-    return dct
-
-
-
-
-
-def get_loader(dataset, num_samples: tp.Optional[int], batch_size: int,
-               num_workers: int, seed: int, **kwargs) -> torch.utils.data.DataLoader:
-    """Convenience function to load dataset into a dataloader with optional subset sampling.
-
-    Args:
-        dataset: Dataset to load.
-        num_samples (Optional[int]): Number of samples to limit subset size.
-        batch_size (int): Batch size.
-        num_workers (int): Number of workers for data loading.
-        seed (int): Random seed.
-    """
-    if num_samples is not None:
-        dataset = random_subset(dataset, num_samples, seed)
-
-    dataloader = flashy.distrib.loader(
-        dataset,
-        batch_size=batch_size,
-        num_workers=num_workers,
-        **kwargs
-    )
-    return dataloader
-
-
-def get_dataset_from_loader(dataloader):
-    dataset = dataloader.dataset
-    if isinstance(dataset, torch.utils.data.Subset):
-        return dataset.dataset
-    else:
-        return dataset
-
-
-
-
-
-def sample_top_k(p, k=250, n_draw=None):
-    """
-    p probabs 2048 ?
-    num_draw : how many tokens to sample (for duplicate elongation)
-    """
-
-    p = torch.softmax(p, dim=-1)  # p/temp
-
-
-
-    top_k_value, i250 = torch.topk(p, k, dim=-1)  # probs: [1, 4, 2048]
-    # print('\n_____TOPK________\n', top_k_value.shape, top_k_value[0, 0, :10], '\n___________END_TOPK____________\n')
-    min_value_top_k = top_k_value[..., [-1]]  #
-    p *= (p >= min_value_top_k).float()
-    p.div_(p.sum(dim=-1, keepdim=True))
-    # -- next_token = multinomial(probs, num_samples=num_draw)
-
-    # RESHAPED into bs, 4, 250
-    p_ = p.reshape(-1, p.shape[-1])
-
-
-    out = torch.multinomial(p_,
-                            num_samples=n_draw,
-                            replacement=False)  # [4, num_draw]
-    return out.transpose(0, 1)[:, :, None]  # [num_draw, 4, 1]
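Of the helpers removed here, `sample_top_k` is the one most relevant to generation: it keeps the top-k probabilities per codebook, renormalises them, and draws `n_draw` tokens without replacement. A small shape check with illustrative values:

import torch
# sample_top_k as shown above; input is treated as logits (softmax is applied inside).
logits = torch.randn(1, 4, 2048)             # [B=1, K=4 codebooks, card=2048]
tokens = sample_top_k(logits, k=250, n_draw=2)
print(tokens.shape)                           # torch.Size([2, 4, 1]) -> [n_draw, K, 1]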
demo.py
CHANGED
@@ -1,14 +1,84 @@
-from audiocraft.audiogen import AudioGen #, audio_write
 import audiofile
 import numpy as np
+import typing as tp
+import torch
+
+from audiocraft.loaders import load_compression_model, load_lm_model
+from audiocraft.lm import LMModel
+from audiocraft.conditioners import ConditioningAttributes
+
+
+class AudioGen():
+
+    def __init__(self,
+                 compression_model=None,
+                 lm=None,
+                 duration=.04,
+                 top_k=249):
+
+        self.compression_model = compression_model
+        self.lm = lm
+        self.top_k = top_k
+        self.compression_model.eval()
+        self.lm.eval()
+        self.duration = duration
+        self.device = next(iter(lm.parameters())).device
+
+    @property
+    def frame_rate(self) -> float:
+        """Roughly the number of AR steps per seconds."""
+        return self.compression_model.frame_rate
+
+    @property
+    def sample_rate(self) -> int:
+        """Sample rate of the generated audio."""
+        return self.compression_model.sample_rate
+
+    def generate(self, descriptions):
+        attributes = [
+            ConditioningAttributes(text={'description': d}) for d in descriptions]
+        tokens = self._generate_tokens(attributes)
+        print(f'\n{tokens.shape=}\n{tokens=} FINAL 5 AUD')
+        return self.generate_audio(tokens)
+
+    def _generate_tokens(self, attributes):
+        total_gen_len = int(self.duration * self.frame_rate)
+        gen_tokens = self.lm.generate(conditions=attributes,
+                                      max_gen_len=total_gen_len)
+        gen_tokens = gen_tokens.transpose(0, 1).reshape(4, -1)[None, :, :]
+        return gen_tokens
+
+    def generate_audio(self, gen_tokens: torch.Tensor) -> torch.Tensor:
+        """Generate Audio from tokens."""
+        assert gen_tokens.dim() == 3
+        with torch.no_grad():
+            gen_audio = self.compression_model.decode(gen_tokens, None)
+        return gen_audio
+
+
+device = 'cuda:0'
+# https://huggingface.co/facebook/audiogen-medium
+
+sound_generator = AudioGen(
+    compression_model=load_compression_model('facebook/audiogen-medium', device=device),
+    lm=load_lm_model('facebook/audiogen-medium', device=device).to(torch.float),
+    duration=.04,
+    top_k=1)
+
 
 print('\n\n\n\n___________________')
 
 txt = 'dogs barging in the street'
 
-sound_generator = AudioGen.get_pretrained('facebook/audiogen-medium')
-sound_generator.set_generation_params(duration=.46) # why is generating so long at 14 seconds
-
 x = sound_generator.generate([txt])[0].detach().cpu().numpy()[0, :]
 x /= np.abs(x).max() + 1e-7
 
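The hunk ends before `x` is written anywhere; presumably the remainder of demo.py saves the normalised clip with the `audiofile` package imported at the top. A sketch of that step, with an assumed output path not shown in the diff:

# Hypothetical continuation of demo.py: write the normalised mono clip to disk.
# 'sound.wav' is an assumed filename; sample_rate comes from the wrapper's property.
audiofile.write('sound.wav', x, sound_generator.sample_rate)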