Dionyssos committed
Commit 54adc39 (1 parent: e70ad00)

AudioGen class
audiocraft/__init__.py CHANGED
@@ -0,0 +1 @@
+ from .builders import AudioGen
audiocraft/builders.py CHANGED
@@ -1,12 +1,11 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
-
  import typing as tp
  import omegaconf
+ from torch import nn
  import torch
+ from huggingface_hub import hf_hub_download
+ import os
+ from omegaconf import OmegaConf, DictConfig
+
  from .encodec import EncodecModel
  from .lm import LMModel
  from .seanet import SEANetDecoder
@@ -15,143 +14,253 @@ from .conditioners import (
      ConditionFuser,
      ConditioningProvider,
      T5Conditioner,
+     ConditioningAttributes
  )
-
  from .vq import ResidualVectorQuantizer


+ def _delete_param(cfg: DictConfig, full_name: str):
+     parts = full_name.split('.')
+     for part in parts[:-1]:
+         if part in cfg:
+             cfg = cfg[part]
+         else:
+             return
+     OmegaConf.set_struct(cfg, False)
+     if parts[-1] in cfg:
+         del cfg[parts[-1]]
+     OmegaConf.set_struct(cfg, True)
+
+
  def dict_from_config(cfg):
      dct = omegaconf.OmegaConf.to_container(cfg, resolve=True)
      return dct

- def get_quantizer(quantizer, cfg, dimension):
-     klass = {
-         'no_quant': None,
-         'rvq': ResidualVectorQuantizer
-     }[quantizer]
-     kwargs = dict_from_config(getattr(cfg, quantizer))
-     if quantizer != 'no_quant':
-         kwargs['dimension'] = dimension
-     return klass(**kwargs)
-
-
- def get_encodec_autoencoder(cfg):
-     kwargs = dict_from_config(getattr(cfg, 'seanet'))
-     _ = kwargs.pop('encoder')
-     decoder_override_kwargs = kwargs.pop('decoder')
-     decoder_kwargs = {**kwargs, **decoder_override_kwargs}
-     decoder = SEANetDecoder(**decoder_kwargs)
-     return decoder
-
-
- def get_compression_model(cfg):
-     """Instantiate a compression model."""
-     if cfg.compression_model == 'encodec':
-         kwargs = dict_from_config(getattr(cfg, 'encodec'))
-         quantizer_name = kwargs.pop('quantizer')
-         decoder = get_encodec_autoencoder(cfg)
-         quantizer = get_quantizer(quantizer_name, cfg, 128)
-         renormalize = kwargs.pop('renormalize', False)
-         # deprecated params
-         # print(f'{frame_rate=} {encoder.dimension=}')  frame_rate=50 encoder.dimension=128
-         kwargs.pop('renorm', None)
-         # remaining kwargs: {'autoencoder': 'seanet', 'sample_rate': 16000, 'channels': 1, 'causal': False}
-
-         return EncodecModel(decoder=decoder,
-                             quantizer=quantizer,
-                             frame_rate=50,
-                             renormalize=renormalize,
-                             sample_rate=16000,
-                             channels=1,
-                             causal=False
-                             ).to(cfg.device)
-     else:
-         raise KeyError(f"Unexpected compression model {cfg.compression_model}")
-
-
- def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
-     """Instantiate a transformer LM."""
-     if cfg.lm_model in ['transformer_lm', 'transformer_lm_magnet']:
-         kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
-         n_q = kwargs['n_q']
-         q_modeling = kwargs.pop('q_modeling', None)
-         codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
-         attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
-         cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
-         cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef']
-         fuser = get_condition_fuser(cfg)
-         condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
-         if len(fuser.fuse2cond['cross']) > 0:  # enforce cross-attention programmatically
-             kwargs['cross_attention'] = True
-         if codebooks_pattern_cfg.modeling is None:
-             assert q_modeling is not None, \
-                 "LM model should either have a codebook pattern defined or transformer_lm.q_modeling"
-             codebooks_pattern_cfg = omegaconf.OmegaConf.create(
-                 {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
-             )
-
-         pattern_provider = get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
-         # lm_class = MagnetLMModel if cfg.lm_model == 'transformer_lm_magnet' else LMModel
-         lm_class = LMModel  # hard coded
-         print(f'{lm_class=}\n\n\n\n=====================')
-         return lm_class(
-             pattern_provider=pattern_provider,
-             condition_provider=condition_provider,
-             fuser=fuser,
-             cfg_dropout=cfg_prob,
-             cfg_coef=cfg_coef,
-             attribute_dropout=attribute_dropout,
-             dtype=getattr(torch, cfg.dtype),
-             device=cfg.device,
-             **kwargs
-         ).to(cfg.device)
-     else:
-         raise KeyError(f"Unexpected LM model {cfg.lm_model}")
-
-
- def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider:
-     """Instantiate a conditioning model."""
-     device = cfg.device
-     duration = cfg.dataset.segment_duration
-     cfg = getattr(cfg, 'conditioners')
-     dict_cfg = {} if cfg is None else dict_from_config(cfg)
-     conditioners: tp.Dict[str, T5Conditioner] = {}
-     condition_provider_args = dict_cfg.pop('args', {})
-     condition_provider_args.pop('merge_text_conditions_p', None)
-     condition_provider_args.pop('drop_desc_p', None)
-
-     for cond, cond_cfg in dict_cfg.items():
-         model_type = cond_cfg['model']
-         model_args = cond_cfg[model_type]
-         if model_type == 't5':
-             conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args)
-         else:
-             raise ValueError(f"Unrecognized conditioning model: {model_type}")
-     conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args)
-     return conditioner
-
-
- def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
-     """Instantiate a condition fuser object."""
-     fuser_cfg = getattr(cfg, 'fuser')
-     fuser_methods = ['sum', 'cross', 'prepend', 'input_interpolate']
-     fuse2cond = {k: fuser_cfg[k] for k in fuser_methods}
-     kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods}
-     fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
-     return fuser
-
-
- def get_codebooks_pattern_provider(n_q, cfg):
-     pattern_providers = {
-         'delay': DelayedPatternProvider,  # THIS
-     }
-     name = cfg.modeling
-     kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
-
-     klass = pattern_providers[name]
-     return klass(n_q, **kwargs)
+
+ # ============================================== DEFINE AUDIOGEN
+
+ class AudioGen(nn.Module):
+
+     # https://huggingface.co/facebook/audiogen-medium
+
+     def __init__(self,
+                  duration=0.024,
+                  device='cpu'):
+         super().__init__()
+         self.device = device  # needed for loading & selecting the float16 LM
+         self.load_compression_model()
+         self.load_lm_model()
+         self.duration = duration
+
+     @property
+     def frame_rate(self):
+         return self.compression_model.frame_rate
+
+     def generate(self, descriptions):
+         with torch.no_grad():
+             attributes = [
+                 ConditioningAttributes(text={'description': d}) for d in descriptions]
+             gen_tokens = self.lm.generate(
+                 conditions=attributes,
+                 max_gen_len=int(self.duration * self.frame_rate))  # [n_draw, 4, 37]
+             x = self.compression_model.decode(gen_tokens, None)  # [n_draw, 1, 11840]
+             n_draw, _, n_time_samples = x.shape
+             x = x.reshape(1, n_draw * n_time_samples)  # linearise n_draw
+             print('______________\nGENTOk 5', gen_tokens)
+             print('GENAUD 5', x.sum())
+         return x
+
+     # == BUILD Fn
+     def get_quantizer(self, quantizer, cfg, dimension):
+         klass = {
+             'no_quant': None,
+             'rvq': ResidualVectorQuantizer
+         }[quantizer]
+         kwargs = dict_from_config(getattr(cfg, quantizer))
+         if quantizer != 'no_quant':
+             kwargs['dimension'] = dimension
+         return klass(**kwargs)
+
+     def get_encodec_autoencoder(self, cfg):
+         kwargs = dict_from_config(getattr(cfg, 'seanet'))
+         _ = kwargs.pop('encoder')
+         decoder_override_kwargs = kwargs.pop('decoder')
+         decoder_kwargs = {**kwargs, **decoder_override_kwargs}
+         decoder = SEANetDecoder(**decoder_kwargs)
+         return decoder
+
+     def get_compression_model(self, cfg):
+         """Instantiate a compression model."""
+         if cfg.compression_model == 'encodec':
+             kwargs = dict_from_config(getattr(cfg, 'encodec'))
+             quantizer_name = kwargs.pop('quantizer')
+             decoder = self.get_encodec_autoencoder(cfg)
+             quantizer = self.get_quantizer(quantizer_name, cfg, 128)
+             renormalize = kwargs.pop('renormalize', False)
+             kwargs.pop('renorm', None)  # deprecated param
+             # remaining kwargs: {'autoencoder': 'seanet', 'sample_rate': 16000, 'channels': 1, 'causal': False}
+             return EncodecModel(decoder=decoder,
+                                 quantizer=quantizer,
+                                 frame_rate=50,
+                                 renormalize=renormalize,
+                                 sample_rate=16000,
+                                 channels=1,
+                                 causal=False
+                                 ).to(cfg.device)
+         else:
+             raise KeyError(f"Unexpected compression model {cfg.compression_model}")
+
+     def get_lm_model(self, cfg):
+         """Instantiate a transformer LM."""
+         if cfg.lm_model in ['transformer_lm', 'transformer_lm_magnet']:
+             kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
+             n_q = kwargs['n_q']
+             q_modeling = kwargs.pop('q_modeling', None)
+             codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
+             attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
+             cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
+             cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef']
+             fuser = self.get_condition_fuser(cfg)
+             condition_provider = self.get_conditioner_provider(kwargs["dim"], cfg).to(self.device)
+
+             if len(fuser.fuse2cond['cross']) > 0:  # enforce cross-attention programmatically
+                 kwargs['cross_attention'] = True
+             if codebooks_pattern_cfg.modeling is None:
+                 print('Q MODELING\n=\n=><')
+                 assert q_modeling is not None, \
+                     "LM model should either have a codebook pattern defined or transformer_lm.q_modeling"
+                 codebooks_pattern_cfg = omegaconf.OmegaConf.create(
+                     {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
+                 )
+
+             pattern_provider = self.get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
+             return LMModel(
+                 pattern_provider=pattern_provider,
+                 condition_provider=condition_provider,
+                 fuser=fuser,
+                 cfg_dropout=cfg_prob,
+                 cfg_coef=cfg_coef,
+                 attribute_dropout=attribute_dropout,
+                 dtype=getattr(torch, cfg.dtype),
+                 device=self.device,
+                 **kwargs
+             ).to(cfg.device)
+         else:
+             raise KeyError(f"Unexpected LM model {cfg.lm_model}")
+
+     def get_conditioner_provider(self, output_dim, cfg):
+         """Instantiate the T5 text conditioner."""
+         cfg = getattr(cfg, 'conditioners')
+         dict_cfg = {} if cfg is None else dict_from_config(cfg)
+         conditioners = {}
+         condition_provider_args = dict_cfg.pop('args', {})
+         condition_provider_args.pop('merge_text_conditions_p', None)
+         condition_provider_args.pop('drop_desc_p', None)
+
+         for cond, cond_cfg in dict_cfg.items():
+             model_type = cond_cfg['model']
+             model_args = cond_cfg[model_type]
+             if model_type == 't5':
+                 conditioners[str(cond)] = T5Conditioner(output_dim=output_dim,
+                                                         device=self.device,
+                                                         **model_args)
+             else:
+                 raise ValueError(f"Unrecognized conditioning model: {model_type}")
+
+         # print(f'{condition_provider_args=}')
+         return ConditioningProvider(conditioners)
+
+     def get_condition_fuser(self, cfg):
+         """Instantiate a condition fuser object."""
+         fuser_cfg = getattr(cfg, 'fuser')
+         fuser_methods = ['sum', 'cross', 'prepend', 'input_interpolate']
+         fuse2cond = {k: fuser_cfg[k] for k in fuser_methods}
+         kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods}
+         fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
+         return fuser
+
+     def get_codebooks_pattern_provider(self, n_q, cfg):
+         pattern_providers = {
+             'delay': DelayedPatternProvider,  # THIS
+         }
+         name = cfg.modeling
+         kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
+
+         klass = pattern_providers[name]
+         return klass(n_q, **kwargs)
+
+     # ======================
+     def load_compression_model(self):
+         file = hf_hub_download(
+             repo_id='facebook/audiogen-medium',
+             filename="compression_state_dict.bin",
+             cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
+             library_name="audiocraft",
+             library_version='1.3.0a1')  # found at __init__.py  # audiocraft.__version__
+         pkg = torch.load(file, map_location='cpu')
+         # if 'pretrained' in pkg:
+         #     return EncodecModel.get_pretrained(pkg['pretrained'], device='cpu')
+         cfg = OmegaConf.create(pkg['xp.cfg'])
+         cfg.device = 'cpu'
+         model = self.get_compression_model(cfg)
+         model.load_state_dict(pkg['best_state'], strict=False)  # ckpt also holds unused encoder weights
+         self.compression_model = model
+
+     def load_lm_model(self):
+         file = hf_hub_download(
+             repo_id='facebook/audiogen-medium',
+             filename="state_dict.bin",
+             cache_dir=os.environ.get('AUDIOCRAFT_CACHE_DIR', None),
+             library_name="audiocraft",
+             library_version='1.3.0a1')  # found at __init__.py  # audiocraft.__version__
+         pkg = torch.load(file, map_location=self.device)
+         cfg = OmegaConf.create(pkg['xp.cfg'])
+         if self.device == 'cpu':
+             cfg.dtype = 'float32'
+         else:
+             cfg.dtype = 'float16'
+         _delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
+         _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
+         _delete_param(cfg, 'conditioners.args.drop_desc_p')
+         model = self.get_lm_model(cfg)
+         model.load_state_dict(pkg['best_state'])
+         model.cfg = cfg
+         self.lm = model.to(torch.float)
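
For orientation, a minimal usage sketch of the new class (it mirrors the updated demo.py at the end of this commit; instantiation downloads the facebook/audiogen-medium weights via hf_hub_download, and `duration` is in seconds):

# Minimal usage sketch of the new AudioGen class; mirrors demo.py below.
import numpy as np
from audiocraft import AudioGen

gen = AudioGen(duration=2.0, device='cpu').eval()    # duration in seconds
wav = gen.generate(['dogs barking in the street'])   # [1, n_draw * n_time_samples]
wav = wav[0].cpu().numpy()
wav /= np.abs(wav).max() + 1e-7                      # peak-normalise, as in demo.py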
audiocraft/conditioners.py CHANGED
@@ -173,27 +173,12 @@ class T5Conditioner(nn.Module):


  class ConditioningProvider(nn.Module):
-     """Prepare and provide conditions given all the supported conditioners.
-
-     Args:
-         conditioners (dict): Dictionary of conditioners.
-         device (torch.device or str, optional): Device for conditioners and output condition types.
-     """
+
      def __init__(self,
-                  conditioners,
-                  device="cpu"):
+                  conditioners):
          super().__init__()
-         self.device = device
          self.conditioners = nn.ModuleDict(conditioners)

-     # @property
-     # def joint_embed_conditions(self):
-     #     return [m.attribute for m in self.conditioners.values() if isinstance(m, JointEmbeddingConditioner)]
-
-     # @property
-     # def has_joint_embed_conditions(self):
-     #     return len(self.joint_embed_conditions) > 0
-
      @property
      def text_conditions(self):
          return [k for k, v in self.conditioners.items() if isinstance(v, T5Conditioner)]
@@ -201,19 +186,6 @@ class ConditioningProvider(nn.Module):

      def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
-         """Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly.
-         This should be called before starting any real GPU work to avoid synchronization points.
-         This will return a dict matching conditioner names to their arbitrary tokenized representations.
-
-         Args:
-             inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing
-                 text and wav conditions.
-         """
-         assert all([isinstance(x, ConditioningAttributes) for x in inputs]), (
-             "Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]",
-             f" but types were {set([type(x) for x in inputs])}"
-         )
-
          output = {}
          text = self._collate_text(inputs)
          # wavs = self._collate_wavs(inputs)
@@ -223,9 +195,9 @@ class ConditioningProvider(nn.Module):
          #     f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ",
          #     f"got {text.keys(), wavs.keys(), joint_embeds.keys()}"
          # )
-
          for attribute, batch in text.items():  # , joint_embeds.items()):
              output[attribute] = self.conditioners[attribute].tokenize(batch)
+         print(f'COndProvToknz {output=}\n==')
          return output

      def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
@@ -246,7 +218,7 @@ class ConditioningProvider(nn.Module):
              output[attribute] = (condition, mask)
          return output

-     def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
+     def _collate_text(self, samples):
          """Given a list of ConditioningAttributes objects, compile a dictionary where the keys
          are the attributes and the values are the aggregated input per attribute.
          For example:
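
The `For example:` in `_collate_text`'s docstring is cut off at the hunk boundary above; a hypothetical illustration of the collation it describes (the exact output of the real method may differ):

# Hypothetical illustration of _collate_text: per-attribute aggregation
# across the batch (the real docstring example is truncated above).
from audiocraft.conditioners import ConditioningAttributes

batch = [ConditioningAttributes(text={'description': 'dogs barking'}),
         ConditioningAttributes(text={'description': 'rain on a tent'})]
# provider._collate_text(batch) would return roughly:
# {'description': ['dogs barking', 'rain on a tent']}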
audiocraft/lm.py CHANGED
@@ -10,7 +10,7 @@ from functools import partial
  from torch import nn
  from audiocraft.activations import get_activation_fn

- def sample_top_k(p, k=250, n_draw=None):
+ def sample_top_k(p, k=1, n_draw=None):
      """
      p probabs 2048 ?
      num_draw : how many tokens to sample (for duplicate elongation)
@@ -32,8 +32,8 @@ def sample_top_k(p, k=250, n_draw=None):


      out = torch.multinomial(p_,
-                             num_samples=n_draw,
-                             replacement=False)  # [4, num_draw]
+                             num_samples=n_draw,
+                             replacement=False)  # [4, num_draw]
      return out.transpose(0, 1)[:, :, None]  # [num_draw, 4, 1]


@@ -171,7 +171,7 @@ class LMModel(nn.Module):
          super().__init__()
          self.cfg_coef = cfg_coef

-         self.n_draw = 3
+         self.n_draw = 1
          self.condition_provider = condition_provider
          self.fuser = fuser
          self.card = card  # 2048 ?
@@ -265,7 +265,7 @@ class LMModel(nn.Module):
          # input_, cross_attention_input = self.fuser(input_, condition_tensors)
          cross_attention_input = condition_tensors['description'][0]

-         print(f'{input_.shape=}')
+         # print(f'{input_.shape=}')
          out = self.transformer(input_,
                                 cross_attention_src=cross_attention_input,
                                 token_count=token_count)
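
A small shape sketch of the multi-draw top-k sampling and the time-concatenation done in AudioGen.generate (assumes 4 codebooks and a 2048-entry codebook, per the comments above; this commit sets k=1 and n_draw=1, but n_draw=3 is used here to make the shapes visible):

# Shape sketch for sample_top_k's multi-draw output (assumed values:
# 4 codebooks, 2048-entry codebook; n_draw=3 chosen only for illustration).
import torch

p_ = torch.softmax(torch.randn(4, 2048), dim=-1)               # per-codebook probs
out = torch.multinomial(p_, num_samples=3, replacement=False)  # [4, n_draw]
tokens = out.transpose(0, 1)[:, :, None]                       # [n_draw, 4, 1]
# AudioGen.generate later flattens the decoded draws in time:
# [n_draw, 1, n_time_samples] -> [1, n_draw * n_time_samples]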
audiocraft/loaders.py DELETED
@@ -1,130 +0,0 @@
1
- from pathlib import Path
2
- from huggingface_hub import hf_hub_download
3
- import typing as tp
4
- import os
5
- from omegaconf import OmegaConf, DictConfig
6
- import torch
7
- from . import builders
8
- from .encodec import EncodecModel
9
-
10
-
11
- def get_audiocraft_cache_dir() -> tp.Optional[str]:
12
- return os.environ.get('AUDIOCRAFT_CACHE_DIR', None)
13
-
14
-
15
- def _get_state_dict(
16
- file_or_url_or_id: tp.Union[Path, str],
17
- filename: tp.Optional[str] = None,
18
- device='cpu',
19
- cache_dir: tp.Optional[str] = None,
20
- ):
21
- if cache_dir is None:
22
- cache_dir = get_audiocraft_cache_dir()
23
- # Return the state dict either from a file or url
24
- file_or_url_or_id = str(file_or_url_or_id)
25
- assert isinstance(file_or_url_or_id, str)
26
-
27
- if os.path.isfile(file_or_url_or_id):
28
- return torch.load(file_or_url_or_id, map_location=device)
29
-
30
- if os.path.isdir(file_or_url_or_id):
31
- file = f"{file_or_url_or_id}/{filename}"
32
- return torch.load(file, map_location=device)
33
-
34
- elif file_or_url_or_id.startswith('https://'):
35
- return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)
36
-
37
- else:
38
- assert filename is not None, "filename needs to be defined if using HF checkpoints"
39
-
40
- file = hf_hub_download(
41
- repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir,
42
- library_name="audiocraft",
43
- library_version= '1.3.0a1') # Found at __init__.py #audiocraft.__version__)
44
- return torch.load(file, map_location=device)
45
-
46
-
47
- def load_compression_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
48
- return _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir)
49
-
50
-
51
- def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
52
- pkg = load_compression_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
53
- if 'pretrained' in pkg:
54
- return EncodecModel.get_pretrained(pkg['pretrained'], device=device)
55
- cfg = OmegaConf.create(pkg['xp.cfg'])
56
- cfg.device = str(device)
57
- model = builders.get_compression_model(cfg)
58
- model.load_state_dict(pkg['best_state'], strict=False) # ckpt contains uninstantiated encoder
59
- model.eval()
60
- return model
61
-
62
-
63
- def load_lm_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
64
- return _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir)
65
-
66
-
67
- def _delete_param(cfg: DictConfig, full_name: str):
68
- parts = full_name.split('.')
69
- for part in parts[:-1]:
70
- if part in cfg:
71
- cfg = cfg[part]
72
- else:
73
- return
74
- OmegaConf.set_struct(cfg, False)
75
- if parts[-1] in cfg:
76
- del cfg[parts[-1]]
77
- OmegaConf.set_struct(cfg, True)
78
-
79
-
80
- def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu',
81
- cache_dir: tp.Optional[str] = None):
82
- pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
83
- cfg = OmegaConf.create(pkg['xp.cfg'])
84
- cfg.device = str(device)
85
- if cfg.device == 'cpu':
86
- cfg.dtype = 'float32'
87
- else:
88
- cfg.dtype = 'float16'
89
- _delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
90
- _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
91
- _delete_param(cfg, 'conditioners.args.drop_desc_p')
92
- model = builders.get_lm_model(cfg)
93
- model.load_state_dict(pkg['best_state'])
94
- model.eval()
95
- model.cfg = cfg
96
- return model
97
-
98
-
99
-
100
-
101
-
102
- def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str],
103
- filename: tp.Optional[str] = None,
104
- cache_dir: tp.Optional[str] = None):
105
- return _get_state_dict(file_or_url_or_id, filename=filename, cache_dir=cache_dir)
106
-
107
-
108
- def load_diffusion_models(file_or_url_or_id: tp.Union[Path, str],
109
- device='cpu',
110
- filename: tp.Optional[str] = None,
111
- cache_dir: tp.Optional[str] = None):
112
- pkg = load_mbd_ckpt(file_or_url_or_id, filename=filename, cache_dir=cache_dir)
113
- models = []
114
- processors = []
115
- cfgs = []
116
- sample_rate = pkg['sample_rate']
117
- for i in range(pkg['n_bands']):
118
- cfg = pkg[i]['cfg']
119
- model = builders.get_diffusion_model(cfg)
120
- model_dict = pkg[i]['model_state']
121
- model.load_state_dict(model_dict)
122
- model.to(device)
123
- processor = builders.get_processor(cfg=cfg.processor, sample_rate=sample_rate)
124
- processor_dict = pkg[i]['processor_state']
125
- processor.load_state_dict(processor_dict)
126
- processor.to(device)
127
- models.append(model)
128
- processors.append(processor)
129
- cfgs.append(cfg)
130
- return models, processors, cfgs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audiocraft/lstm.py DELETED
@@ -1,25 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- from torch import nn
8
-
9
-
10
- class StreamableLSTM(nn.Module):
11
- """LSTM without worrying about the hidden state, nor the layout of the data.
12
- Expects input as convolutional layout.
13
- """
14
- def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
15
- super().__init__()
16
- self.skip = skip
17
- self.lstm = nn.LSTM(dimension, dimension, num_layers)
18
-
19
- def forward(self, x):
20
- x = x.permute(2, 0, 1)
21
- y, _ = self.lstm(x)
22
- if self.skip:
23
- y = y + x
24
- y = y.permute(1, 2, 0)
25
- return y
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audiocraft/seanet.py CHANGED
@@ -5,12 +5,30 @@
  # LICENSE file in the root directory of this source tree.

  import typing as tp
-
  import numpy as np
  import torch.nn as nn
-
  from .conv import StreamableConv1d, StreamableConvTranspose1d
- from .lstm import StreamableLSTM
+
+
+ class StreamableLSTM(nn.Module):
+     """LSTM without worrying about the hidden state, nor the layout of the data.
+     Expects input as convolutional layout.
+     """
+     def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
+         super().__init__()
+         self.skip = skip
+         self.lstm = nn.LSTM(dimension, dimension, num_layers)
+
+     def forward(self, x):
+         print('LSTM called 1c')
+         x = x.permute(2, 0, 1)
+         y, _ = self.lstm(x)
+         if self.skip:
+             y = y + x
+         y = y.permute(1, 2, 0)
+         return y
+


  class SEANetResnetBlock(nn.Module):
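
A quick shape check for the relocated StreamableLSTM; "convolutional layout" means [batch, channels, time], and the module returns the same layout:

# Shape check for StreamableLSTM (now defined in seanet.py):
# input and output are in convolutional layout [batch, channels, time].
import torch
from audiocraft.seanet import StreamableLSTM

lstm = StreamableLSTM(dimension=128)
y = lstm(torch.randn(2, 128, 50))  # prints 'LSTM called 1c'; y: [2, 128, 50]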
demo.py CHANGED
@@ -1,64 +1,10 @@
  import audiofile
  import numpy as np
- import torch
- from audiocraft.loaders import load_compression_model, load_lm_model
- from audiocraft.conditioners import ConditioningAttributes
-
-
- class AudioGen():
-
-     def __init__(self,
-                  compression_model=None,
-                  lm=None,
-                  duration=.74):
-
-         self.compression_model = compression_model
-         self.lm = lm
-         self.duration = duration
-
-     @property
-     def frame_rate(self):
-         return self.compression_model.frame_rate
-
-     def generate(self,
-                  descriptions):
-         with torch.no_grad():
-             attributes = [
-                 ConditioningAttributes(text={'description': d}) for d in descriptions]
-             gen_tokens = self.lm.generate(
-                 conditions=attributes,
-                 max_gen_len=int(self.duration * self.frame_rate))  # [n_draw, 4, 37]
-             x = self.compression_model.decode(gen_tokens, None)  # [n_draw, 1, 11840]
-             n_draw, _, n_time_samples = x.shape
-             x = x.reshape(1, n_draw * n_time_samples)  # linearise n_draw
-             return x
-
-
- device = 'cuda:0'
- # https://huggingface.co/facebook/audiogen-medium
-
- sound_generator = AudioGen(
-     compression_model=load_compression_model('facebook/audiogen-medium', device=device).eval(),
-     lm=load_lm_model('facebook/audiogen-medium', device=device).to(torch.float).eval(),
-     duration=.74)
-
- print('\n\n\n\n___________________')
-
+ from audiocraft import AudioGen
+
  txt = 'dogs barking in the street'

+ sound_generator = AudioGen(duration=.04,
+                            device='cuda:0').to('cuda:0').eval()
  x = sound_generator.generate([txt])[0].detach().cpu().numpy()
  x /= np.abs(x).max() + 1e-7
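
The updated demo is cut off at its final line here; presumably it writes the normalised signal with the already-imported audiofile package. A sketch of that last step (the filename is illustrative; 16 kHz matches the EncodecModel config in builders.py):

# Hypothetical final step of demo.py (truncated above): save the mono
# signal at the model's 16 kHz rate. 'out.wav' is an illustrative name.
audiofile.write('out.wav', x, 16000)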