Dionyssos committed
Commit d912185 · 1 Parent(s): 60fbcf9

instantiate audiogen in demo

audiocraft/audiogen.py CHANGED
@@ -1,9 +1,3 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
 import typing as tp
 import torch
 from audiocraft.loaders import load_compression_model, load_lm_model
@@ -11,7 +5,6 @@ import typing as tp
 import omegaconf
 import torch
 import numpy as np
-from abc import ABC, abstractmethod
 from .lm import LMModel
 from .conditioners import ConditioningAttributes
 from .utils.autocast import TorchAutocast
@@ -25,7 +18,7 @@ def _shift(x):
     return x
 
 
-class BaseGenModel(ABC):
+class AudioGen():
     """Base generative model with convenient generation API.
 
     Args:
@@ -90,21 +83,13 @@ class BaseGenModel(ABC):
         attributes = [
             ConditioningAttributes(text={'description': d}) for d in descriptions]
         tokens = self._generate_tokens(attributes)
+        print(f'\n{tokens.shape=}\n{tokens=} FINAL 5 AUD')
         return self.generate_audio(tokens)
 
     def _generate_tokens(self, attributes):
 
         total_gen_len = int(self.duration * self.frame_rate)
 
-
-        # # print(f'{self.generation_params=}')
-        # self.generation_params={'use_sampling': True,
-        #                         'temp': 1.0, 'top_k': 250,
-        #                         'top_p': 0.0, 'cfg_coef': 2.4, 'two_step_cfg': False}
-
-
-
-
         if self.duration <= self.max_duration:
             # generate by sampling from LM, simple case.
 
@@ -127,66 +112,18 @@ class BaseGenModel(ABC):
         gen_audio = self.compression_model.decode(gen_tokens, None)
         return gen_audio
 
-
-
-
-
-
-class AudioGen(BaseGenModel):
 
-    def __init__(self, name, compression_model, lm, max_duration=None):
-        # print(f'Using {compression_model=}\n-----=-----')
-        super().__init__(name, compression_model, lm, max_duration)
-        self.set_generation_params(duration=5)  # default duration
-
-    @staticmethod
-    def get_pretrained(name: str = 'facebook/audiogen-medium', device=None):
-        """Return pretrained model, we provide a single model for now:
-        - facebook/audiogen-medium (1.5B), text to sound,
-        # see: https://huggingface.co/facebook/audiogen-medium
-        """
-        if device is None:
-            if torch.cuda.device_count():
-                device = 'cuda'
-            else:
-                device = 'cpu'
 
-
 
-        compression_model = load_compression_model(name, device=device)
-        lm = load_lm_model(name, device=device)
-        assert 'self_wav' not in lm.condition_provider.conditioners, \
-            "AudioGen do not support waveform conditioning for now"
-        return AudioGen(name, compression_model, lm)
-
-    def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
-                              top_p: float = 0.0, temperature: float = 1.0,
-                              duration: float = 10.0, cfg_coef: float = 2.4,
-                              two_step_cfg: bool = False, extend_stride: float = 2):
-        """Set the generation parameters for AudioGen.
-
-        Args:
-            use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
-            top_k (int, optional): top_k used for sampling. Defaults to 250.
-            top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
-            temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
-            duration (float, optional): Duration of the generated waveform. Defaults to 10.0.
-            cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
-            two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
-                instead of batching together the two. This has some impact on how things
-                are padded but seems to have little impact in practice.
-            extend_stride: when doing extended generation (i.e. more than 10 seconds), by how much
-                should we extend the audio each time. Larger values will mean less context is
-                preserved, and shorter value will require extra computations.
-        """
-        assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
-        self.extend_stride = extend_stride
-        self.duration = duration
-        self.generation_params = {
-            'use_sampling': use_sampling,
-            'temp': temperature,
-            'top_k': top_k,
-            'top_p': top_p,
-            'cfg_coef': cfg_coef,
-            'two_step_cfg': two_step_cfg,
-        }
+def get_pretrained(name='facebook/audiogen-medium',
+                   device=None):
+    """Return pretrained model, we provide a single model for now:
+    - facebook/audiogen-medium (1.5B), text to sound,
+    # see: https://huggingface.co/facebook/audiogen-medium
+    """
+    compression_model = load_compression_model(name, device=device)
+    lm = load_lm_model(name, device=device)
+    assert 'self_wav' not in lm.condition_provider.conditioners, \
+        "AudioGen do not support waveform conditioning for now"
+    return AudioGen(name, compression_model, lm)
+
audiocraft/builders.py CHANGED
@@ -12,16 +12,18 @@ from .lm import LMModel
 from .seanet import SEANetDecoder
 from .codebooks_patterns import DelayedPatternProvider
 from .conditioners import (
-    BaseConditioner,
     ConditionFuser,
     ConditioningProvider,
     T5Conditioner,
 )
 from .unet import DiffusionUnet
 from .vq import ResidualVectorQuantizer
-from .utils.utils import dict_from_config
+
 from .diffusion_schedule import MultiBandProcessor, SampleProcessor
 
+def dict_from_config(cfg):
+    dct = omegaconf.OmegaConf.to_container(cfg, resolve=True)
+    return dct
 
 def get_quantizer(quantizer, cfg, dimension):
     klass = {
@@ -118,7 +120,7 @@ def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> Cond
     duration = cfg.dataset.segment_duration
     cfg = getattr(cfg, 'conditioners')
     dict_cfg = {} if cfg is None else dict_from_config(cfg)
-    conditioners: tp.Dict[str, BaseConditioner] = {}
+    conditioners: tp.Dict[str, T5Conditioner] = {}
     condition_provider_args = dict_cfg.pop('args', {})
     condition_provider_args.pop('merge_text_conditions_p', None)
     condition_provider_args.pop('drop_desc_p', None)
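
Side note (not part of the commit): the inlined dict_from_config mirrors the helper removed from audiocraft/utils/utils.py, minus its isinstance assertion. A minimal sketch of the expected behaviour, assuming the diff applies as shown:

    import omegaconf
    from audiocraft.builders import dict_from_config

    cfg = omegaconf.OmegaConf.create({'name': 't5-large', 'dim': 1024})
    # OmegaConf.to_container with resolve=True returns a plain dict with interpolations resolved
    assert dict_from_config(cfg) == {'name': 't5-large', 'dim': 1024}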
audiocraft/conditioners.py CHANGED
@@ -8,12 +8,6 @@ import soundfile
 from transformers import T5EncoderModel, T5Tokenizer  # type: ignore
 import torch
 from torch import nn
-
-
-from .utils.autocast import TorchAutocast
-
-
-
 logger = logging.getLogger(__name__)
 TextCondition = tp.Optional[str]  # a text condition can be a string or None (if doesn't exist)
 ConditionType = tp.Tuple[torch.Tensor, torch.Tensor]  # condition, mask
@@ -21,7 +15,6 @@ ConditionType = tp.Tuple[torch.Tensor, torch.Tensor]  # condition, mask
 
 
 
-
 class JointEmbedCondition(tp.NamedTuple):
     wav: torch.Tensor
     text: tp.List[tp.Optional[str]]
@@ -76,18 +69,6 @@ class ConditioningAttributes:
         return out
 
 
-
-
-
-
-
-
-
-
-
-
-
-
 class Tokenizer:
     """Base tokenizer implementation
     (in case we want to introduce more advances tokenizers in the future).
@@ -95,59 +76,10 @@ class Tokenizer:
     def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError()
 
-
-
-
-
-
-
-
-class BaseConditioner(nn.Module):
-    """Base model for all conditioner modules.
-    We allow the output dim to be different than the hidden dim for two reasons:
-    1) keep our LUTs small when the vocab is large;
-    2) make all condition dims consistent.
-
-    Args:
-        dim (int): Hidden dim of the model.
-        output_dim (int): Output dim of the conditioner.
-    """
-    def __init__(self, dim: int, output_dim: int):
-        super().__init__()
-        self.dim = dim
-        self.output_dim = output_dim
-        self.output_proj = nn.Linear(dim, output_dim)
-
-    def tokenize(self, *args, **kwargs) -> tp.Any:
-        """Should be any part of the processing that will lead to a synchronization
-        point, e.g. BPE tokenization with transfer to the GPU.
-
-        The returned value will be saved and return later when calling forward().
-        """
-        raise NotImplementedError()
-
 
 
-class TextConditioner(BaseConditioner):
-    ...
-
-
-
-
-
-class T5Conditioner(TextConditioner):
-    """T5-based TextConditioner.
-
-    Args:
-        name (str): Name of the T5 model.
-        output_dim (int): Output dim of the conditioner.
-        finetune (bool): Whether to fine-tune T5 at train time.
-        device (str): Device for T5 Conditioner.
-        autocast_dtype (tp.Optional[str], optional): Autocast dtype.
-        word_dropout (float, optional): Word dropout probability.
-        normalize_text (bool, optional): Whether to apply text normalization.
-    """
+class T5Conditioner(nn.Module):
+
     MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
               "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
               "google/flan-t5-xl", "google/flan-t5-xxl"]
@@ -164,24 +96,23 @@ class T5Conditioner(TextConditioner):
        "google/flan-t5-11b": 1024,
    }
 
-    def __init__(self, name: str, output_dim: int, finetune: bool, device: str,
-                 autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0.,
-                 normalize_text: bool = False):
+    def __init__(self,
+                 name: str,
+                 output_dim: int,
+                 device: str,
+                 word_dropout: float = 0.,
+                 normalize_text: bool = False,
+                 finetune=False):
+        print(f'{finetune=}')
         assert name in self.MODELS, f"Unrecognized t5 model name (should in {self.MODELS})"
-        super().__init__(self.MODELS_DIMS[name], output_dim)
+        super().__init__()
+        self.dim = self.MODELS_DIMS[name]
+        self.output_dim = output_dim
+        self.output_proj = nn.Linear(self.dim, output_dim)
         self.device = device
         self.name = name
-        self.finetune = finetune
         self.word_dropout = word_dropout
-        if autocast_dtype is None or self.device == 'cpu':
-            self.autocast = TorchAutocast(enabled=False)
-            if self.device != 'cpu':
-                logger.warning("T5 has no autocast, this might lead to NaN")
-        else:
-            dtype = getattr(torch, autocast_dtype)
-            assert isinstance(dtype, torch.dtype)
-            logger.info(f"T5 will be evaluated with autocast as {autocast_dtype}")
-            self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
+
         # Let's disable logging temporarily because T5 will vomit some errors otherwise.
         # thanks https://gist.github.com/simon-weber/7853144
         previous_level = logging.root.manager.disable
@@ -190,7 +121,7 @@ class T5Conditioner(TextConditioner):
             warnings.simplefilter("ignore")
             try:
                 self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
-                t5 = T5EncoderModel.from_pretrained(name).train(mode=finetune)
+                t5 = T5EncoderModel.from_pretrained(name).eval()  # .train(mode=finetune)
             finally:
                 logging.disable(previous_level)
             if finetune:
@@ -223,15 +154,15 @@ class T5Conditioner(TextConditioner):
         mask[empty_idx, :] = 0  # zero-out index where the input is non-existant
         return inputs
 
-    def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType:
+    def forward(self, inputs):
         mask = inputs['attention_mask']
-        with torch.set_grad_enabled(self.finetune), self.autocast:
+        with torch.no_grad():
             embeds = self.t5(**inputs).last_hidden_state
         embeds = self.output_proj(embeds.to(self.output_proj.weight))
         embeds = (embeds * mask.unsqueeze(-1))
 
         # T5 torch.Size([2, 4, 1536]) dict_keys(['input_ids', 'attention_mask'])
-        # print(f'{inputs["input_ids"].shape=}')  # inputs["input_ids"].shape=torch.Size([2, 4])
+        print(f'{embeds.dtype=}')  # inputs["input_ids"].shape=torch.Size([2, 4])
         return embeds, mask
 
 
@@ -248,7 +179,9 @@ class ConditioningProvider(nn.Module):
         conditioners (dict): Dictionary of conditioners.
         device (torch.device or str, optional): Device for conditioners and output condition types.
     """
-    def __init__(self, conditioners: tp.Dict[str, BaseConditioner], device: tp.Union[torch.device, str] = "cpu"):
+    def __init__(self,
+                 conditioners,
+                 device="cpu"):
         super().__init__()
         self.device = device
         self.conditioners = nn.ModuleDict(conditioners)
@@ -263,7 +196,7 @@ class ConditioningProvider(nn.Module):
 
     @property
     def text_conditions(self):
-        return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]
+        return [k for k, v in self.conditioners.items() if isinstance(v, T5Conditioner)]
 
 
 
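For orientation (not part of the commit): after this change T5Conditioner always runs the frozen T5 encoder under torch.no_grad() and projects it to output_dim. A minimal sketch of the call contract, assuming tokenize() (unchanged and only partially shown above) still returns the usual input_ids/attention_mask dict:

    from audiocraft.conditioners import T5Conditioner

    cond = T5Conditioner(name='t5-base', output_dim=1536, device='cpu')
    inputs = cond.tokenize(['dogs barking in the street', ''])  # dict with 'input_ids', 'attention_mask'
    embeds, mask = cond(inputs)                                 # embeds: [B, T, 1536], mask: [B, T]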
audiocraft/lm.py CHANGED
@@ -10,9 +10,35 @@ from audiocraft.transformer import StreamingTransformer, create_norm_fn
 from dataclasses import dataclass
 from functools import partial
 from torch import nn
-from audiocraft.utils import utils
 from audiocraft.activations import get_activation_fn
 
+def sample_top_k(p, k=250, n_draw=None):
+    """
+    p probabs 2048 ?
+    num_draw : how many tokens to sample (for duplicate elongation)
+    """
+
+    p = torch.softmax(p, dim=-1)  # p/temp
+
+
+
+    top_k_value, i250 = torch.topk(p, k, dim=-1)  # probs: [1, 4, 2048]
+    # print('\n_____TOPK________\n', top_k_value.shape, top_k_value[0, 0, :10], '\n___________END_TOPK____________\n')
+    min_value_top_k = top_k_value[..., [-1]]  #
+    p *= (p >= min_value_top_k).float()
+    p.div_(p.sum(dim=-1, keepdim=True))
+    # -- next_token = multinomial(probs, num_samples=num_draw)
+
+    # RESHAPED into bs, 4, 250
+    p_ = p.reshape(-1, p.shape[-1])
+
+
+    out = torch.multinomial(p_,
+                            num_samples=n_draw,
+                            replacement=False)  # [4, num_draw]
+    return out.transpose(0, 1)[:, :, None]  # [num_draw, 4, 1]
+
+
 
 # ============================================== From LM.py
 
@@ -147,7 +173,7 @@ class LMModel(nn.Module):
         super().__init__()
         self.cfg_coef = cfg_coef
 
-        self.n_draw = 2
+        self.n_draw = 1
         self.condition_provider = condition_provider
         self.fuser = fuser
         self.card = card  # 2048 ?
@@ -332,7 +358,7 @@ class LMModel(nn.Module):
                                  token_count=offset)
 
            # print(f'BEF {logits.shape=} BEF utils.SampleTop5')  # AGREES 4 BEF logits.shape=torch.Size([1, 4, 1, 2048]) BEF utils.SampleTop5
-            next_token = utils.sample_top_k(logits, n_draw=self.n_draw)  # [1,4,2048] logits
+            next_token = sample_top_k(logits, n_draw=self.n_draw)  # [1,4,2048] logits
 
 
 
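For orientation (not part of the commit): per the debug comment above, logits enter the sampler with shape [1, 4, 1, 2048] and the new module-level sample_top_k returns [n_draw, 4, 1]. A minimal shape check, assuming the diff applies as shown:

    import torch
    from audiocraft.lm import sample_top_k

    logits = torch.randn(1, 4, 1, 2048)                 # [batch, n_codebooks, 1, card]
    next_token = sample_top_k(logits, k=250, n_draw=2)
    print(next_token.shape)                             # torch.Size([2, 4, 1]) -> [n_draw, n_codebooks, 1]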
audiocraft/loaders.py CHANGED
@@ -120,32 +120,7 @@ def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu',
     return model
 
 
-def load_lm_model_magnet(file_or_url_or_id: tp.Union[Path, str], compression_model_frame_rate: int,
-                         device='cpu', cache_dir: tp.Optional[str] = None):
-    pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
-    cfg = OmegaConf.create(pkg['xp.cfg'])
-    cfg.device = str(device)
-    if cfg.device == 'cpu':
-        cfg.dtype = 'float32'
-    else:
-        cfg.dtype = 'float16'
-    _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
-    _delete_param(cfg, 'conditioners.args.drop_desc_p')
-
-    cfg.transformer_lm.compression_model_framerate = compression_model_frame_rate
-    cfg.transformer_lm.segment_duration = cfg.dataset.segment_duration
-    cfg.transformer_lm.span_len = cfg.masking.span_len
 
-    # MAGNeT models v1 support only xformers backend.
-    from .transformer import set_efficient_attention_backend
-    if cfg.transformer_lm.memory_efficient:
-        set_efficient_attention_backend("xformers")
-
-    model = builders.get_lm_model(cfg)
-    model.load_state_dict(pkg['best_state'])
-    model.eval()
-    model.cfg = cfg
-    return model
 
 
 def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str],
audiocraft/utils/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
- """Utilities."""
 
 
 
 
 
 
 
audiocraft/utils/autocast.py DELETED
@@ -1,40 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- import torch
8
-
9
-
10
- class TorchAutocast:
11
- """TorchAutocast utility class.
12
- Allows you to enable and disable autocast. This is specially useful
13
- when dealing with different architectures and clusters with different
14
- levels of support.
15
-
16
- Args:
17
- enabled (bool): Whether to enable torch.autocast or not.
18
- args: Additional args for torch.autocast.
19
- kwargs: Additional kwargs for torch.autocast
20
- """
21
- def __init__(self, enabled: bool, *args, **kwargs):
22
- self.autocast = torch.autocast(*args, **kwargs) if enabled else None
23
-
24
- def __enter__(self):
25
- if self.autocast is None:
26
- return
27
- try:
28
- self.autocast.__enter__()
29
- except RuntimeError:
30
- device = self.autocast.device
31
- dtype = self.autocast.fast_dtype
32
- raise RuntimeError(
33
- f"There was an error autocasting with dtype={dtype} device={device}\n"
34
- "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
35
- )
36
-
37
- def __exit__(self, *args, **kwargs):
38
- if self.autocast is None:
39
- return
40
- self.autocast.__exit__(*args, **kwargs)
audiocraft/utils/cache.py DELETED
@@ -1,324 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- from concurrent.futures import ThreadPoolExecutor
8
- from collections import deque
9
- from functools import partial
10
- from hashlib import sha1
11
- import logging
12
- from pathlib import Path
13
- import sys
14
- import typing as tp
15
- import zipfile
16
-
17
- import flashy
18
- import torch
19
-
20
-
21
- logger = logging.getLogger(__name__)
22
-
23
-
24
- def get_full_embed(full_embed: torch.Tensor, x: tp.Any, idx: int, device: tp.Union[str, torch.device]) -> torch.Tensor:
25
- """Utility function for the EmbeddingCache, returning the full embedding without any chunking.
26
- This method can be used in case there is no need in extracting a chunk of the full embedding
27
- read from the cache.
28
-
29
- Args:
30
- full_embed (torch.Tensor): The full embedding.
31
- x (any): Batch object from which the full embedding is derived.
32
- idx (torch.Tensor): Index of object to consider in the batch object.
33
- Returns:
34
- full_embed (torch.Tensor): The full embedding
35
- """
36
- return full_embed.to(device)
37
-
38
-
39
- class EmbeddingCache:
40
- """Cache around embeddings computation for faster execution.
41
- The EmbeddingCache is storing pre-computed embeddings on disk and provides a simple API
42
- to retrieve the pre-computed embeddings on full inputs and extract only a given chunk
43
- using a user-provided function. When the cache is warm (all embeddings are pre-computed),
44
- the EmbeddingCache allows for faster training as it removes the need of computing the embeddings.
45
- Additionally, it provides in-memory cache around the loaded embeddings to limit IO footprint
46
- and synchronization points in the forward calls.
47
-
48
- Args:
49
- cache_path (Path): Path to folder where all pre-computed embeddings are saved on disk.
50
- device (str or torch.device): Device on which the embedding is returned.
51
- compute_embed_fn (callable[[Path, any, int], torch.Tensor], optional): Function to compute
52
- the embedding from a given object and path. This user provided function can compute the
53
- embedding from the provided object or using the provided path as entry point. The last parameter
54
- specify the index corresponding to the current embedding in the object that can represent batch metadata.
55
- extract_embed_fn (callable[[torch.Tensor, any, int], torch.Tensor], optional): Function to extract
56
- the desired embedding chunk from the full embedding loaded from the cache. The last parameter
57
- specify the index corresponding to the current embedding in the object that can represent batch metadata.
58
- If not specified, will return the full embedding unmodified.
59
- """
60
- def __init__(self, cache_path: tp.Union[str, Path], device: tp.Union[str, torch.device],
61
- compute_embed_fn: tp.Callable[[Path, tp.Any, int], torch.Tensor],
62
- extract_embed_fn: tp.Optional[tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor]] = None):
63
- self.cache_path = Path(cache_path)
64
- self.device = device
65
- self._compute_embed_fn = compute_embed_fn
66
- self._extract_embed_fn: tp.Callable[[torch.Tensor, tp.Any, int], torch.Tensor]
67
- if extract_embed_fn is not None:
68
- self._extract_embed_fn = extract_embed_fn
69
- else:
70
- self._extract_embed_fn = partial(get_full_embed, device=device)
71
- if self.cache_path is not None:
72
- self.cache_path.mkdir(exist_ok=True, parents=True)
73
- logger.info(f"Cache instantiated at: {self.cache_path}")
74
- self.pool = ThreadPoolExecutor(8)
75
- self.pool.__enter__()
76
- self._current_batch_cache: dict = {}
77
- self._memory_cache: dict = {}
78
-
79
- def _get_cache_path(self, path: tp.Union[Path, str]):
80
- """Get cache path for the given file path."""
81
- sig = sha1(str(path).encode()).hexdigest()
82
- return self.cache_path / sig
83
-
84
- @staticmethod
85
- def _get_full_embed_from_cache(cache: Path):
86
- """Loads full pre-computed embedding from the cache."""
87
- try:
88
- embed = torch.load(cache, 'cpu')
89
- except Exception as exc:
90
- logger.error("Error loading %s: %r", cache, exc)
91
- embed = None
92
- return embed
93
-
94
- def get_embed_from_cache(self, paths: tp.List[Path], x: tp.Any) -> torch.Tensor:
95
- """Get embedding from cache, computing and storing it to cache if not already cached.
96
- The EmbeddingCache first tries to load the embedding from the in-memory cache
97
- containing the pre-computed chunks populated through `populate_embed_cache`.
98
- If not found, the full embedding is computed and stored on disk to be later accessed
99
- to populate the in-memory cache, and the desired embedding chunk is extracted and returned.
100
-
101
- Args:
102
- paths (list[Path or str]): List of paths from where the embeddings can be loaded.
103
- x (any): Object from which the embedding is extracted.
104
- """
105
- embeds = []
106
- for idx, path in enumerate(paths):
107
- cache = self._get_cache_path(path)
108
- if cache in self._current_batch_cache:
109
- embed = self._current_batch_cache[cache]
110
- else:
111
- full_embed = self._compute_embed_fn(path, x, idx)
112
- try:
113
- with flashy.utils.write_and_rename(cache, pid=True) as f:
114
- torch.save(full_embed.cpu(), f)
115
- except Exception as exc:
116
- logger.error('Error saving embed %s (%s): %r', cache, full_embed.shape, exc)
117
- else:
118
- logger.info('New embed cache saved: %s (%s)', cache, full_embed.shape)
119
- embed = self._extract_embed_fn(full_embed, x, idx)
120
- embeds.append(embed)
121
- embed = torch.stack(embeds, dim=0)
122
- return embed
123
-
124
- def populate_embed_cache(self, paths: tp.List[Path], x: tp.Any) -> None:
125
- """Populate in-memory caches for embeddings reading from the embeddings stored on disk.
126
- The in-memory caches consist in a cache for the full embedding and another cache for the
127
- final embedding chunk. Such caches are used to limit the IO access when computing the actual embeddings
128
- and reduce the IO footprint and synchronization points during forward passes.
129
-
130
- Args:
131
- paths (list[Path]): List of paths from where the embeddings can be loaded.
132
- x (any): Object from which the embedding is extracted.
133
- """
134
- self._current_batch_cache.clear()
135
- if self.cache_path is not None:
136
- futures: list = []
137
- for path in paths:
138
- assert path is not None, "Path is required for computation from cache"
139
- cache = self._get_cache_path(path)
140
- if cache in self._memory_cache or not cache.exists():
141
- futures.append(None)
142
- else:
143
- futures.append(self.pool.submit(EmbeddingCache._get_full_embed_from_cache, cache))
144
- for idx, (path, future) in enumerate(zip(paths, futures)):
145
- assert path is not None
146
- cache = self._get_cache_path(path)
147
- full_embed = None
148
- if future is None:
149
- if cache in self._memory_cache:
150
- full_embed = self._memory_cache[cache]
151
- else:
152
- full_embed = future.result()
153
- if full_embed is not None:
154
- self._memory_cache[cache] = full_embed
155
- full_embed = full_embed.to(self.device)
156
- if full_embed is not None:
157
- embed = self._extract_embed_fn(full_embed, x, idx)
158
- self._current_batch_cache[cache] = embed
159
-
160
-
161
- class CachedBatchWriter:
162
- """Write pre computed caches for mini batches. This can
163
- make loading a lot more efficient depending on your filesystem.
164
-
165
- Args:
166
- cache_folder (Path): folder in which the cached minibatches
167
- will be stored.
168
-
169
- Inside cache folder, the structure is the following:
170
- `epoch_number / update_number.zip`
171
- And the zip file contains one entry per batch item.
172
-
173
- It is possible to use the cache with a batch size smaller than
174
- created with but obviously not larger. Make sure to call the
175
- `start_epoch(epoch)` method for indicating changes of epochs.
176
-
177
- See the grid `audiocraft/grids/musicgen/musicgen_warmup_cache.py`
178
- for an example of how to warmup the cache.
179
- """
180
- def __init__(self, cache_folder: Path):
181
- self.cache_folder = cache_folder
182
- self._current_epoch: tp.Optional[int] = None
183
- self._current_index = 0
184
-
185
- def start_epoch(self, epoch: int):
186
- """Call at the beginning of each epoch.
187
- """
188
- self._current_epoch = epoch
189
- self._current_index = 0
190
- self._zip_path.parent.mkdir(exist_ok=True, parents=True)
191
-
192
- @staticmethod
193
- def _get_zip_path(cache_folder: Path, epoch: int, index: int):
194
- return cache_folder / f"{epoch:05d}" / f"{index:06d}.zip"
195
-
196
- @property
197
- def _zip_path(self):
198
- assert self._current_epoch is not None
199
- return CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch, self._current_index)
200
-
201
- def save(self, *content):
202
- """Save one mini batch. This function is distributed-aware
203
- and will automatically merge all the items from the different
204
- workers.
205
- """
206
- all_contents = []
207
- for rank in range(flashy.distrib.world_size()):
208
- their_content = flashy.distrib.broadcast_object(content, src=rank)
209
- all_contents.append(their_content)
210
-
211
- if flashy.distrib.is_rank_zero():
212
- idx = 0
213
- with flashy.utils.write_and_rename(self._zip_path) as tmp:
214
- with zipfile.ZipFile(tmp, 'w') as zf:
215
- for content in all_contents:
216
- for vals in zip(*content):
217
- with zf.open(f'{idx}', 'w') as f: # type: ignore
218
- torch.save(vals, f)
219
- idx += 1
220
- flashy.distrib.barrier()
221
- self._current_index += 1
222
-
223
-
224
- class CachedBatchLoader:
225
- """Loader for cached mini-batches dumped with `CachedBatchWriter`.
226
-
227
- Args:
228
- cache_folder (Path): folder in which the cached minibatches are stored.
229
- batch_size (int): batch size (per GPU) expected.
230
- num_workers (int): number of workers to use for loading.
231
- min_length (int): minimum expected length for each epoch. If some
232
- mini-batches are missing, and error is raised.
233
-
234
- This is iterable just like a regular DataLoader.
235
- """
236
-
237
- def __init__(self, cache_folder: Path, batch_size: int,
238
- num_workers: int = 10, min_length: int = 1):
239
- self.cache_folder = cache_folder
240
- self.batch_size = batch_size
241
- self.num_workers = num_workers
242
- self.min_length = min_length
243
- self._current_epoch: tp.Optional[int] = None
244
- self.sampler = None # for compatibility with the regular DataLoader
245
-
246
- def __len__(self):
247
- path = CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch or 0, 0).parent
248
- return len([p for p in path.iterdir() if p.suffix == ".zip"])
249
-
250
- def start_epoch(self, epoch: int):
251
- """Call at the beginning of each epoch.
252
- """
253
- self._current_epoch = epoch
254
-
255
- def _zip_path(self, index: int):
256
- assert self._current_epoch is not None
257
- return CachedBatchWriter._get_zip_path(self.cache_folder, self._current_epoch, index)
258
-
259
- def _load_one(self, index: int):
260
- zip_path = self._zip_path(index)
261
- if not zip_path.exists():
262
- if index < self.min_length:
263
- raise RuntimeError(f"Cache should have at least {self.min_length} batches, but {index} doesn't exist")
264
-
265
- return None
266
- mode = "rb" if sys.version_info >= (3, 9) else "r"
267
- try:
268
- with zipfile.ZipFile(zip_path, 'r') as zf:
269
- rank = flashy.distrib.rank()
270
- world_size = flashy.distrib.world_size()
271
- root = zipfile.Path(zf)
272
- items = list(root.iterdir())
273
- total_batch_size = self.batch_size * world_size
274
- if len(items) < total_batch_size:
275
- raise RuntimeError(
276
- f"The cache can handle a max batch size of {len(items)}, "
277
- f"but {total_batch_size} is needed.")
278
- start = rank * self.batch_size
279
- items = items[start: start + self.batch_size]
280
- assert len(items) == self.batch_size
281
- entries = []
282
- entries = [torch.load(item.open(mode), 'cpu') for item in items] # type: ignore
283
- transposed = zip(*entries)
284
- out = []
285
- for part in transposed:
286
- assert len(part) > 0
287
- if isinstance(part[0], torch.Tensor):
288
- out.append(torch.stack(part))
289
- else:
290
- assert isinstance(part, torch.Tensor)
291
- out.append(part)
292
- return out
293
- except Exception:
294
- logger.error("Error when reading zip path %s", zip_path)
295
- raise
296
-
297
- def __iter__(self):
298
- """This will yields tuples, exactly as provided to the
299
- `CachedBatchWriter.save` method.
300
- """
301
- pool = ThreadPoolExecutor(self.num_workers)
302
- next_index = 0
303
- queue = deque()
304
-
305
- def _get_next():
306
- nonlocal next_index
307
- r = queue.popleft().result()
308
- if r is None:
309
- return None
310
- else:
311
- queue.append(pool.submit(self._load_one, next_index))
312
- next_index += 1
313
- return r
314
-
315
- with pool:
316
- # fill the buffer of fetching jobs.
317
- for _ in range(2 * self.num_workers):
318
- queue.append(pool.submit(self._load_one, next_index))
319
- next_index += 1
320
- while True:
321
- batch = _get_next()
322
- if batch is None:
323
- return
324
- yield batch
audiocraft/utils/checkpoint.py DELETED
@@ -1,161 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- from enum import Enum
8
- import logging
9
- from pathlib import Path
10
- import re
11
- import typing as tp
12
-
13
- import flashy
14
- import torch
15
-
16
- from ..environment import AudioCraftEnvironment
17
-
18
-
19
- logger = logging.getLogger(__name__)
20
-
21
-
22
- class CheckpointSource(Enum):
23
- CURRENT_XP = "current_xp"
24
- PRETRAINED = "pretrained"
25
- OTHER = "other"
26
-
27
-
28
- def checkpoint_name(name: tp.Optional[str] = None, rank: tp.Optional[int] = None, use_fsdp: bool = False) -> str:
29
- """Checkpoint name formatted for all use in AudioCraft codebase and has the following format:
30
- `checkpoint_<name>.th(.<rank>)`. By convention, name is expected to be empty for last checkpoint,
31
- 'best' for the best checkpoint or the epoch number.
32
-
33
- Args:
34
- name (str, optional): Name suffix for the checkpoint file stem.
35
- rank (optional, int): Rank for distributed processing, retrieved with flashy if not provided.
36
- use_fsdp (bool): Whether the calling solver relies on FSDP.
37
- Returns:
38
- str: The checkpoint name.
39
- """
40
- suffix = ''
41
- if rank is None:
42
- rank = flashy.distrib.rank()
43
- if rank > 0 and use_fsdp:
44
- suffix = '.' + str(rank)
45
- name_part = ''
46
- if name is not None:
47
- name_part = f'_{name}'
48
- return f'checkpoint{name_part}.th{suffix}'
49
-
50
-
51
- def is_sharded_checkpoint(path: Path) -> bool:
52
- """Whether the checkpoint at the given path corresponds to a sharded checkpoint across rank."""
53
- return re.search(r'\.th\.\d+$', path.name) is not None
54
-
55
-
56
- def resolve_checkpoint_path(sig_or_path: tp.Union[Path, str], name: tp.Optional[str] = None,
57
- use_fsdp: bool = False) -> tp.Optional[Path]:
58
- """Resolve a given checkpoint path for a provided dora sig or path.
59
-
60
- Args:
61
- sig_or_path (Path or str): Checkpoint path or dora signature.
62
- name (str, optional): Name suffix for the checkpoint file stem.
63
- rank (optional, int): Rank for distributed processing, retrieved with flashy if not provided.
64
- use_fsdp (bool): Whether the calling solver relies on FSDP.
65
- Returns:
66
- Path, optional: Resolved checkpoint path, if it exists.
67
- """
68
- from audiocraft import train
69
- xps_root = train.main.dora.dir / 'xps'
70
- sig_or_path = str(sig_or_path)
71
- if sig_or_path.startswith('//sig/'):
72
- sig = sig_or_path[len('//sig/'):]
73
- path = xps_root / sig
74
- else:
75
- path = Path(sig_or_path)
76
- path = AudioCraftEnvironment.resolve_reference_path(path)
77
-
78
- if path.is_dir():
79
- path = path / checkpoint_name(name, use_fsdp=use_fsdp)
80
-
81
- if path.exists():
82
- return path
83
- else:
84
- return None
85
-
86
-
87
- def load_checkpoint(checkpoint_path: Path, is_sharded: bool = False) -> tp.Any:
88
- """Load state from checkpoints at the specified checkpoint path."""
89
- if is_sharded:
90
- rank0_checkpoint_path = checkpoint_path.parent / checkpoint_name(use_fsdp=False)
91
- if rank0_checkpoint_path.exists():
92
- check_sharded_checkpoint(checkpoint_path, rank0_checkpoint_path)
93
- state = torch.load(checkpoint_path, 'cpu')
94
- logger.info("Checkpoint loaded from %s", checkpoint_path)
95
- return state
96
-
97
-
98
- def save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None:
99
- """Save state to disk to the specified checkpoint_path."""
100
- _safe_save_checkpoint(state, checkpoint_path, is_sharded)
101
- logger.info("Checkpoint saved to %s", checkpoint_path)
102
-
103
-
104
- def flush_stale_checkpoints(checkpoint_path: Path, keep_last: tp.Optional[int] = None) -> None:
105
- """Flush checkpoints to only keep last N checkpoints."""
106
- if keep_last is None or keep_last <= 0:
107
- return
108
- checkpoint_dir = checkpoint_path.parent
109
- suffix = ''
110
- if flashy.distrib.rank() > 0:
111
- suffix = f'.{flashy.distrib.rank()}'
112
- checkpoint_files_with_epoch = []
113
- for path in Path(checkpoint_dir).glob(f'checkpoint_*.th{suffix}'):
114
- epoch_part = path.name.split('.', 1)[0].split('_', 1)[1]
115
- if epoch_part.isdigit():
116
- checkpoint_files_with_epoch.append((path, int(epoch_part)))
117
- checkpoint_files = [path for path, _ in list(sorted(checkpoint_files_with_epoch, key=lambda t: t[1]))]
118
- total_to_flush = max(0, len(checkpoint_files) - keep_last)
119
- files_to_flush = checkpoint_files[:total_to_flush]
120
- for path in files_to_flush:
121
- logger.debug("Removing checkpoint: %s", str(path))
122
- path.unlink(missing_ok=True)
123
-
124
-
125
- def check_sharded_checkpoint(checkpoint_path: Path, rank0_checkpoint_path: Path) -> None:
126
- """Check sharded checkpoint state, ensuring the checkpoints are not corrupted."""
127
- # Finish the work of a previous run that got interrupted while dumping.
128
- old_path = Path(str(checkpoint_path) + '.old')
129
- if old_path.exists():
130
- raise RuntimeError(
131
- f"Old checkpoint {old_path} from previous version of this code exist, cannot safely proceed.")
132
- token = Path(str(rank0_checkpoint_path) + '.tmp.done')
133
- tmp_path = Path(str(checkpoint_path) + '.tmp')
134
- if token.exists():
135
- if tmp_path.exists():
136
- tmp_path.rename(checkpoint_path)
137
- flashy.distrib.barrier()
138
- if flashy.distrib.is_rank_zero() and token.exists():
139
- token.unlink()
140
-
141
-
142
- def _safe_save_checkpoint(state: tp.Any, checkpoint_path: Path, is_sharded: bool = False) -> None:
143
- """Save checkpoints in a safe manner even with when sharded checkpoints across nodes."""
144
- def _barrier_if_sharded():
145
- if is_sharded:
146
- flashy.distrib.barrier()
147
-
148
- if flashy.distrib.is_rank_zero():
149
- token = Path(str(checkpoint_path) + '.tmp.done')
150
- if token.exists():
151
- token.unlink()
152
- _barrier_if_sharded()
153
- with flashy.utils.write_and_rename(checkpoint_path) as f:
154
- torch.save(state, f)
155
- _barrier_if_sharded()
156
- if flashy.distrib.is_rank_zero():
157
- token.touch()
158
- _barrier_if_sharded()
159
- _barrier_if_sharded()
160
- if flashy.distrib.rank() == 0:
161
- token.unlink()
audiocraft/utils/export.py DELETED
@@ -1,79 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- """
8
- Utility to export a training checkpoint to a lightweight release checkpoint.
9
- """
10
-
11
- from pathlib import Path
12
- import typing as tp
13
-
14
- from omegaconf import OmegaConf
15
- import torch
16
-
17
- from audiocraft import __version__
18
-
19
-
20
- def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
21
- """Export only the best state from the given EnCodec checkpoint. This
22
- should be used if you trained your own EnCodec model.
23
- """
24
- pkg = torch.load(checkpoint_path, 'cpu')
25
- new_pkg = {
26
- 'best_state': pkg['best_state']['model'],
27
- 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
28
- 'version': __version__,
29
- 'exported': True,
30
- }
31
- Path(out_file).parent.mkdir(exist_ok=True, parents=True)
32
- torch.save(new_pkg, out_file)
33
- return out_file
34
-
35
-
36
- def export_pretrained_compression_model(pretrained_encodec: str, out_file: tp.Union[Path, str]):
37
- """Export a compression model (potentially EnCodec) from a pretrained model.
38
- This is required for packaging the audio tokenizer along a MusicGen or AudioGen model.
39
- Do not include the //pretrained/ prefix. For instance if you trained a model
40
- with `facebook/encodec_32khz`, just put that as a name. Same for `dac_44khz`.
41
-
42
- In that case, this will not actually include a copy of the model, simply the reference
43
- to the model used.
44
- """
45
- if Path(pretrained_encodec).exists():
46
- pkg = torch.load(pretrained_encodec)
47
- assert 'best_state' in pkg
48
- assert 'xp.cfg' in pkg
49
- assert 'version' in pkg
50
- assert 'exported' in pkg
51
- else:
52
- pkg = {
53
- 'pretrained': pretrained_encodec,
54
- 'exported': True,
55
- 'version': __version__,
56
- }
57
- Path(out_file).parent.mkdir(exist_ok=True, parents=True)
58
- torch.save(pkg, out_file)
59
-
60
-
61
- def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
62
- """Export only the best state from the given MusicGen or AudioGen checkpoint.
63
- """
64
- pkg = torch.load(checkpoint_path, 'cpu')
65
- if pkg['fsdp_best_state']:
66
- best_state = pkg['fsdp_best_state']['model']
67
- else:
68
- assert pkg['best_state']
69
- best_state = pkg['best_state']['model']
70
- new_pkg = {
71
- 'best_state': best_state,
72
- 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
73
- 'version': __version__,
74
- 'exported': True,
75
- }
76
-
77
- Path(out_file).parent.mkdir(exist_ok=True, parents=True)
78
- torch.save(new_pkg, out_file)
79
- return out_file
audiocraft/utils/export_legacy.py DELETED
@@ -1,70 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- """
8
- Legacy functions used at the time of the first release, kept for referencd.
9
- """
10
-
11
- from pathlib import Path
12
- import typing as tp
13
-
14
- from omegaconf import OmegaConf, DictConfig
15
- import torch
16
-
17
- from audiocraft import __version__
18
-
19
-
20
- def _clean_lm_cfg(cfg: DictConfig):
21
- OmegaConf.set_struct(cfg, False)
22
- # This used to be set automatically in the LM solver, need a more robust solution
23
- # for the future.
24
- cfg['transformer_lm']['card'] = 2048
25
- n_q = 4
26
- stereo_cfg = getattr(cfg, 'interleave_stereo_codebooks', None)
27
- if stereo_cfg is not None and stereo_cfg.use:
28
- if 'downsample' in stereo_cfg:
29
- del stereo_cfg['downsample']
30
- n_q = 8
31
- cfg['transformer_lm']['n_q'] = n_q
32
- # Experimental params no longer supported.
33
- bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters',
34
- 'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop']
35
- for name in bad_params:
36
- del cfg['transformer_lm'][name]
37
- OmegaConf.set_struct(cfg, True)
38
- return cfg
39
-
40
-
41
- def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
42
- pkg = torch.load(checkpoint_path, 'cpu')
43
- new_pkg = {
44
- 'best_state': pkg['ema']['state']['model'],
45
- 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
46
- # The following params were NOT exported for the first release of MusicGen.
47
- 'version': __version__,
48
- 'exported': True,
49
- }
50
- Path(out_file).parent.mkdir(exist_ok=True, parents=True)
51
- torch.save(new_pkg, out_file)
52
- return out_file
53
-
54
-
55
- def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]):
56
- pkg = torch.load(checkpoint_path, 'cpu')
57
- if pkg['fsdp_best_state']:
58
- best_state = pkg['fsdp_best_state']['model']
59
- else:
60
- best_state = pkg['best_state']['model']
61
- new_pkg = {
62
- 'best_state': best_state,
63
- 'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg'])),
64
- # The following params were NOT exported for the first release of MusicGen.
65
- 'version': __version__,
66
- 'exported': True,
67
- }
68
- Path(out_file).parent.mkdir(exist_ok=True, parents=True)
69
- torch.save(new_pkg, out_file)
70
- return out_file
audiocraft/utils/notebook.py DELETED
@@ -1,32 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- try:
8
- import IPython.display as ipd # type: ignore
9
- except ImportError:
10
- # Note in a notebook...
11
- pass
12
-
13
-
14
- import torch
15
-
16
-
17
- def display_audio(samples: torch.Tensor, sample_rate: int):
18
- """Renders an audio player for the given audio samples.
19
-
20
- Args:
21
- samples (torch.Tensor): a Tensor of decoded audio samples
22
- with shapes [B, C, T] or [C, T]
23
- sample_rate (int): sample rate audio should be displayed with.
24
- """
25
- assert samples.dim() == 2 or samples.dim() == 3
26
-
27
- samples = samples.detach().cpu()
28
- if samples.dim() == 2:
29
- samples = samples[None, ...]
30
-
31
- for audio in samples:
32
- ipd.display(ipd.Audio(audio, rate=sample_rate))
audiocraft/utils/profiler.py DELETED
@@ -1,38 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- import logging
8
- import typing as tp
9
-
10
- import dora
11
- import torch
12
-
13
-
14
- logger = logging.getLogger(__name__)
15
-
16
-
17
- class Profiler:
18
- """Context manager wrapper for xformers profiler.
19
- """
20
- def __init__(self, module: torch.nn.Module, enabled: bool = False):
21
- self.profiler: tp.Optional[tp.Any] = None
22
- if enabled:
23
- from xformers.profiler import profile
24
- output_dir = dora.get_xp().folder / 'profiler_data'
25
- logger.info("Profiling activated, results with be saved to %s", output_dir)
26
- self.profiler = profile(output_dir=output_dir, module=module)
27
-
28
- def step(self):
29
- if self.profiler is not None:
30
- self.profiler.step() # type: ignore
31
-
32
- def __enter__(self):
33
- if self.profiler is not None:
34
- return self.profiler.__enter__() # type: ignore
35
-
36
- def __exit__(self, exc_type, exc_value, exc_tb):
37
- if self.profiler is not None:
38
- return self.profiler.__exit__(exc_type, exc_value, exc_tb) # type: ignore
audiocraft/utils/utils.py DELETED
@@ -1,98 +0,0 @@
1
- import hashlib
2
- import json
3
- import logging
4
- import typing as tp
5
- import flashy
6
- import flashy.distrib
7
- import omegaconf
8
- import torch
9
-
10
-
11
- logger = logging.getLogger(__name__)
12
-
13
-
14
- def model_hash(model: torch.nn.Module) -> str:
15
- """Return a model hash. This should allow us to track regressions in model init
16
- from the logs of past experiments.
17
- """
18
- hasher = hashlib.sha1()
19
- for p in model.parameters():
20
- hasher.update(p.data.cpu().numpy().tobytes())
21
- return hasher.hexdigest()
22
-
23
-
24
- def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
25
- """Convenience function to map an omegaconf configuration to a dictionary.
26
-
27
- Args:
28
- cfg (omegaconf.DictConfig): Original configuration to map to dict.
29
- Returns:
30
- dict: Config as dictionary object.
31
- """
32
- dct = omegaconf.OmegaConf.to_container(cfg, resolve=True)
33
- assert isinstance(dct, dict)
34
- return dct
35
-
36
-
37
-
38
-
39
-
40
- def get_loader(dataset, num_samples: tp.Optional[int], batch_size: int,
41
- num_workers: int, seed: int, **kwargs) -> torch.utils.data.DataLoader:
42
- """Convenience function to load dataset into a dataloader with optional subset sampling.
43
-
44
- Args:
45
- dataset: Dataset to load.
46
- num_samples (Optional[int]): Number of samples to limit subset size.
47
- batch_size (int): Batch size.
48
- num_workers (int): Number of workers for data loading.
49
- seed (int): Random seed.
50
- """
51
- if num_samples is not None:
52
- dataset = random_subset(dataset, num_samples, seed)
53
-
54
- dataloader = flashy.distrib.loader(
55
- dataset,
56
- batch_size=batch_size,
57
- num_workers=num_workers,
58
- **kwargs
59
- )
60
- return dataloader
61
-
62
-
63
- def get_dataset_from_loader(dataloader):
64
- dataset = dataloader.dataset
65
- if isinstance(dataset, torch.utils.data.Subset):
66
- return dataset.dataset
67
- else:
68
- return dataset
69
-
70
-
71
-
72
-
73
-
74
- def sample_top_k(p, k=250, n_draw=None):
75
- """
76
- p probabs 2048 ?
77
- num_draw : how many tokens to sample (for duplicate elongation)
78
- """
79
-
80
- p = torch.softmax(p, dim=-1) # p/temp
81
-
82
-
83
-
84
- top_k_value, i250 = torch.topk(p, k, dim=-1) # probs: [1, 4, 2048]
85
- # print('\n_____TOPK________\n', top_k_value.shape, top_k_value[0, 0, :10], '\n___________END_TOPK____________\n')
86
- min_value_top_k = top_k_value[..., [-1]] #
87
- p *= (p >= min_value_top_k).float()
88
- p.div_(p.sum(dim=-1, keepdim=True))
89
- # -- next_token = multinomial(probs, num_samples=num_draw)
90
-
91
- # RESHAPED into bs, 4, 250
92
- p_ = p.reshape(-1, p.shape[-1])
93
-
94
-
95
- out = torch.multinomial(p_,
96
- num_samples=n_draw,
97
- replacement=False) # [4, num_draw]
98
- return out.transpose(0, 1)[:, :, None] # [num_draw, 4, 1]
demo.py CHANGED
@@ -1,14 +1,84 @@
-from audiocraft.audiogen import AudioGen  #, audio_write
 import audiofile
 import numpy as np
+import typing as tp
+import torch
+
+from audiocraft.loaders import load_compression_model, load_lm_model
+from audiocraft.lm import LMModel
+from audiocraft.conditioners import ConditioningAttributes
+
+
+
+
+class AudioGen():
+
+    def __init__(self,
+                 compression_model=None,
+                 lm=None,
+                 duration=.04,
+                 top_k=249):
+
+        self.compression_model = compression_model
+        self.lm = lm
+        self.top_k = top_k
+        self.compression_model.eval()
+        self.lm.eval()
+        self.duration = duration
+        self.device = next(iter(lm.parameters())).device
+
+    @property
+    def frame_rate(self) -> float:
+        """Roughly the number of AR steps per seconds."""
+        return self.compression_model.frame_rate
+
+    @property
+    def sample_rate(self) -> int:
+        """Sample rate of the generated audio."""
+        return self.compression_model.sample_rate
+
+    def generate(self, descriptions):
+        attributes = [
+            ConditioningAttributes(text={'description': d}) for d in descriptions]
+        tokens = self._generate_tokens(attributes)
+        print(f'\n{tokens.shape=}\n{tokens=} FINAL 5 AUD')
+        return self.generate_audio(tokens)
+
+    def _generate_tokens(self, attributes):
+        total_gen_len = int(self.duration * self.frame_rate)
+        gen_tokens = self.lm.generate(conditions=attributes,
+                                      max_gen_len=total_gen_len)
+        gen_tokens = gen_tokens.transpose(0, 1).reshape(4, -1)[None, :, :]
+        return gen_tokens
+
+    def generate_audio(self, gen_tokens: torch.Tensor) -> torch.Tensor:
+        """Generate Audio from tokens."""
+        assert gen_tokens.dim() == 3
+        with torch.no_grad():
+            gen_audio = self.compression_model.decode(gen_tokens, None)
+        return gen_audio
+
+device = 'cuda:0'
+# https://huggingface.co/facebook/audiogen-medium
+
+
+sound_generator = AudioGen(
+    compression_model=load_compression_model('facebook/audiogen-medium', device=device),
+    lm=load_lm_model('facebook/audiogen-medium', device=device).to(torch.float),
+    duration=.04,
+    top_k=1)
+
+
+
 
 print('\n\n\n\n___________________')
 
 txt = 'dogs barging in the street'
 
-sound_generator = AudioGen.get_pretrained('facebook/audiogen-medium')
-sound_generator.set_generation_params(duration=.46)  # why is generating so long at 14 seconds
-
 x = sound_generator.generate([txt])[0].detach().cpu().numpy()[0, :]
 x /= np.abs(x).max() + 1e-7
 
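
The hunk above stops right after normalization; a sketch of how the waveform could be written out with the audiofile package already imported at the top of demo.py (the filename is illustrative, not from the commit):

    # (1, samples) array at the sound generator's sample rate
    audiofile.write('generated.wav', x[None, :], sound_generator.sample_rate)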