adefossez commited on
Commit
756be9c
2 Parent(s): 1897b6f 9138f15

Merge branch 'main' into our_hf2

Browse files
README.md CHANGED
@@ -56,15 +56,21 @@ You can play with MusicGen by running the jupyter notebook at [`demo.ipynb`](./d
56
  ## API
57
 
58
  We provide a simple API and 4 pre-trained models. The pre trained models are:
59
- - `small`: 300M model, text to music only,
60
- - `medium`: 1.5B model, text to music only,
61
- - `melody`: 1.5B model, text to music and text+melody to music,
62
- - `large`: 3.3B model, text to music only.
63
 
64
  We observe the best trade-off between quality and compute with the `medium` or `melody` model.
65
  In order to use MusicGen locally **you must have a GPU**. We recommend 16GB of memory, but smaller
66
  GPUs will be able to generate short sequences, or longer sequences with the `small` model.
67
 
 
 
 
 
 
 
68
  See after a quick example for using the API.
69
 
70
  ```python
@@ -84,7 +90,7 @@ wav = model.generate_with_chroma(descriptions, melody[None].expand(3, -1, -1), s
84
 
85
  for idx, one_wav in enumerate(wav):
86
  # Will save under {idx}.wav, with loudness normalization at -14 db LUFS.
87
- audio_write(f'{idx}', one_wav, model.sample_rate, strategy="loudness")
88
  ```
89
 
90
 
 
56
  ## API
57
 
58
  We provide a simple API and 4 pre-trained models. The pre trained models are:
59
+ - `small`: 300M model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-small)
60
+ - `medium`: 1.5B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-medium)
61
+ - `melody`: 1.5B model, text to music and text+melody to music - [🤗 Hub](https://huggingface.co/facebook/musicgen-melody)
62
+ - `large`: 3.3B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-large)
63
 
64
  We observe the best trade-off between quality and compute with the `medium` or `melody` model.
65
  In order to use MusicGen locally **you must have a GPU**. We recommend 16GB of memory, but smaller
66
  GPUs will be able to generate short sequences, or longer sequences with the `small` model.
67
 
68
+ **Note**: Please make sure to have [ffmpeg](https://ffmpeg.org/download.html) installed when using newer version of `torchaudio`.
69
+ You can install it with:
70
+ ```
71
+ apt get install ffmpeg
72
+ ```
73
+
74
  See after a quick example for using the API.
75
 
76
  ```python
 
90
 
91
  for idx, one_wav in enumerate(wav):
92
  # Will save under {idx}.wav, with loudness normalization at -14 db LUFS.
93
+ audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness")
94
  ```
95
 
96
 
app.py CHANGED
@@ -9,7 +9,7 @@ LICENSE file in the root directory of this source tree.
9
  from tempfile import NamedTemporaryFile
10
  import torch
11
  import gradio as gr
12
- from hf_loading import get_pretrained
13
 
14
  from audiocraft.data.audio import audio_write
15
 
@@ -19,7 +19,7 @@ MODEL = None
19
 
20
  def load_model(version):
21
  print("Loading model", version)
22
- return get_pretrained(version)
23
 
24
 
25
  def predict(model, text, melody, duration, topk, topp, temperature, cfg_coef):
 
9
  from tempfile import NamedTemporaryFile
10
  import torch
11
  import gradio as gr
12
+ from audiocraft.models import MusicGen
13
 
14
  from audiocraft.data.audio import audio_write
15
 
 
19
 
20
  def load_model(version):
21
  print("Loading model", version)
22
+ return MusicGen.get_pretrained(version)
23
 
24
 
25
  def predict(model, text, melody, duration, topk, topp, temperature, cfg_coef):
app_batched.py CHANGED
@@ -11,7 +11,7 @@ import torch
11
  import gradio as gr
12
  from audiocraft.data.audio_utils import convert_audio
13
  from audiocraft.data.audio import audio_write
14
- from hf_loading import get_pretrained
15
 
16
 
17
  MODEL = None
@@ -19,7 +19,7 @@ MODEL = None
19
 
20
  def load_model():
21
  print("Loading model")
22
- return get_pretrained("melody")
23
 
24
 
25
  def predict(texts, melodies):
 
11
  import gradio as gr
12
  from audiocraft.data.audio_utils import convert_audio
13
  from audiocraft.data.audio import audio_write
14
+ from audiocraft.models import MusicGen
15
 
16
 
17
  MODEL = None
 
19
 
20
  def load_model():
21
  print("Loading model")
22
+ return MusicGen.get_pretrained("melody")
23
 
24
 
25
  def predict(texts, melodies):
audiocraft/models/loaders.py CHANGED
@@ -20,7 +20,9 @@ of the returned model.
20
  """
21
 
22
  from pathlib import Path
 
23
  import typing as tp
 
24
 
25
  from omegaconf import OmegaConf
26
  import torch
@@ -28,18 +30,43 @@ import torch
28
  from . import builders
29
 
30
 
31
- def _get_state_dict(file_or_url: tp.Union[Path, str], device='cpu'):
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  # Return the state dict either from a file or url
33
- file_or_url = str(file_or_url)
34
- assert isinstance(file_or_url, str)
35
- if file_or_url.startswith('https://'):
36
- return torch.hub.load_state_dict_from_url(file_or_url, map_location=device, check_hash=True)
 
 
 
 
 
 
 
 
 
 
 
 
37
  else:
38
- return torch.load(file_or_url, device)
39
 
40
 
41
- def load_compression_model(file_or_url: tp.Union[Path, str], device='cpu'):
42
- pkg = _get_state_dict(file_or_url)
43
  cfg = OmegaConf.create(pkg['xp.cfg'])
44
  cfg.device = str(device)
45
  model = builders.get_compression_model(cfg)
@@ -48,8 +75,8 @@ def load_compression_model(file_or_url: tp.Union[Path, str], device='cpu'):
48
  return model
49
 
50
 
51
- def load_lm_model(file_or_url: tp.Union[Path, str], device='cpu'):
52
- pkg = _get_state_dict(file_or_url)
53
  cfg = OmegaConf.create(pkg['xp.cfg'])
54
  cfg.device = str(device)
55
  if cfg.device == 'cpu':
 
20
  """
21
 
22
  from pathlib import Path
23
+ from huggingface_hub import hf_hub_download
24
  import typing as tp
25
+ import os
26
 
27
  from omegaconf import OmegaConf
28
  import torch
 
30
  from . import builders
31
 
32
 
33
+ HF_MODEL_CHECKPOINTS_MAP = {
34
+ "small": "facebook/musicgen-small",
35
+ "medium": "facebook/musicgen-medium",
36
+ "large": "facebook/musicgen-large",
37
+ "melody": "facebook/musicgen-melody",
38
+ }
39
+
40
+
41
+ def _get_state_dict(
42
+ file_or_url_or_id: tp.Union[Path, str],
43
+ filename: tp.Optional[str] = None,
44
+ device='cpu',
45
+ cache_dir: tp.Optional[str] = None,
46
+ ):
47
  # Return the state dict either from a file or url
48
+ file_or_url_or_id = str(file_or_url_or_id)
49
+ assert isinstance(file_or_url_or_id, str)
50
+
51
+ if os.path.isfile(file_or_url_or_id):
52
+ return torch.load(file_or_url_or_id, map_location=device)
53
+
54
+ elif file_or_url_or_id.startswith('https://'):
55
+ return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)
56
+
57
+ elif file_or_url_or_id in HF_MODEL_CHECKPOINTS_MAP:
58
+ assert filename is not None, "filename needs to be defined if using HF checkpoints"
59
+
60
+ repo_id = HF_MODEL_CHECKPOINTS_MAP[file_or_url_or_id]
61
+ file = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir)
62
+ return torch.load(file, map_location=device)
63
+
64
  else:
65
+ raise ValueError(f"{file_or_url_or_id} is not a valid name, path or link that can be loaded.")
66
 
67
 
68
+ def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
69
+ pkg = _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir)
70
  cfg = OmegaConf.create(pkg['xp.cfg'])
71
  cfg.device = str(device)
72
  model = builders.get_compression_model(cfg)
 
75
  return model
76
 
77
 
78
+ def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
79
+ pkg = _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir)
80
  cfg = OmegaConf.create(pkg['xp.cfg'])
81
  cfg.device = str(device)
82
  if cfg.device == 'cpu':
audiocraft/models/musicgen.py CHANGED
@@ -17,7 +17,7 @@ import torch
17
  from .encodec import CompressionModel
18
  from .lm import LMModel
19
  from .builders import get_debug_compression_model, get_debug_lm_model
20
- from .loaders import load_compression_model, load_lm_model
21
  from ..data.audio_utils import convert_audio
22
  from ..modules.conditioners import ConditioningAttributes, WavCondition
23
  from ..utils.autocast import TorchAutocast
@@ -67,10 +67,10 @@ class MusicGen:
67
  @staticmethod
68
  def get_pretrained(name: str = 'melody', device='cuda'):
69
  """Return pretrained model, we provide four models:
70
- - small (300M), text to music,
71
- - medium (1.5B), text to music,
72
- - melody (1.5B) text to music and text+melody to music,
73
- - large (3.3B), text to music.
74
  """
75
 
76
  if name == 'debug':
@@ -79,21 +79,16 @@ class MusicGen:
79
  lm = get_debug_lm_model(device)
80
  return MusicGen(name, compression_model, lm)
81
 
82
- if 'MUSICGEN_ROOT' in os.environ:
83
- ROOT = os.environ['MUSICGEN_ROOT']
84
- if not ROOT.endswith('/'):
85
- ROOT += '/'
86
- else:
87
- ROOT = 'https://dl.fbaipublicfiles.com/audiocraft/musicgen/v0/'
88
- compression_model = load_compression_model(ROOT + 'b0dbef54-37d256b525.th', device=device)
89
- names = {
90
- 'small': 'ba7a97ba-830fe5771e',
91
- 'medium': 'aa73ae27-fbc9f401db',
92
- 'large': '9b6e835c-1f0cf17b5e',
93
- 'melody': 'f79af192-61305ffc49',
94
- }
95
- sig = names[name]
96
- lm = load_lm_model(ROOT + f'{sig}.th', device=device)
97
  return MusicGen(name, compression_model, lm)
98
 
99
  def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
 
17
  from .encodec import CompressionModel
18
  from .lm import LMModel
19
  from .builders import get_debug_compression_model, get_debug_lm_model
20
+ from .loaders import load_compression_model, load_lm_model, HF_MODEL_CHECKPOINTS_MAP
21
  from ..data.audio_utils import convert_audio
22
  from ..modules.conditioners import ConditioningAttributes, WavCondition
23
  from ..utils.autocast import TorchAutocast
 
67
  @staticmethod
68
  def get_pretrained(name: str = 'melody', device='cuda'):
69
  """Return pretrained model, we provide four models:
70
+ - small (300M), text to music, # see: https://huggingface.co/facebook/musicgen-small
71
+ - medium (1.5B), text to music, # see: https://huggingface.co/facebook/musicgen-medium
72
+ - melody (1.5B) text to music and text+melody to music, # see: https://huggingface.co/facebook/musicgen-melody
73
+ - large (3.3B), text to music, # see: https://huggingface.co/facebook/musicgen-large
74
  """
75
 
76
  if name == 'debug':
 
79
  lm = get_debug_lm_model(device)
80
  return MusicGen(name, compression_model, lm)
81
 
82
+ if name not in HF_MODEL_CHECKPOINTS_MAP:
83
+ raise ValueError(
84
+ f"{name} is not a valid checkpoint name. "
85
+ f"Choose one of {', '.join(HF_MODEL_CHECKPOINTS_MAP.keys())}"
86
+ )
87
+
88
+ cache_dir = os.environ.get('MUSICGEN_ROOT', None)
89
+ compression_model = load_compression_model(name, device=device, cache_dir=cache_dir)
90
+ lm = load_lm_model(name, device=device, cache_dir=cache_dir)
91
+
 
 
 
 
 
92
  return MusicGen(name, compression_model, lm)
93
 
94
  def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
audiocraft/utils/utils.py CHANGED
@@ -122,7 +122,7 @@ def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
122
  probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
123
  probs_sum = torch.cumsum(probs_sort, dim=-1)
124
  mask = probs_sum - probs_sort > p
125
- probs_sort *= (~mask).float(0)
126
  probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
127
  next_token = multinomial(probs_sort, num_samples=1)
128
  next_token = torch.gather(probs_idx, -1, next_token)
 
122
  probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
123
  probs_sum = torch.cumsum(probs_sort, dim=-1)
124
  mask = probs_sum - probs_sort > p
125
+ probs_sort *= (~mask).float()
126
  probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
127
  next_token = multinomial(probs_sort, num_samples=1)
128
  next_token = torch.gather(probs_idx, -1, next_token)
hf_loading.py DELETED
@@ -1,61 +0,0 @@
1
- """Utility for loading the models from HF."""
2
- from pathlib import Path
3
- import typing as tp
4
-
5
- from omegaconf import OmegaConf
6
- from huggingface_hub import hf_hub_download
7
- import torch
8
-
9
- from audiocraft.models import builders, MusicGen
10
-
11
- MODEL_CHECKPOINTS_MAP = {
12
- "small": "facebook/musicgen-small",
13
- "medium": "facebook/musicgen-medium",
14
- "large": "facebook/musicgen-large",
15
- "melody": "facebook/musicgen-melody",
16
- }
17
-
18
-
19
- def _get_state_dict(file_or_url: tp.Union[Path, str],
20
- filename="state_dict.bin", device='cpu'):
21
- # Return the state dict either from a file or url
22
- print("loading", file_or_url, filename)
23
- file_or_url = str(file_or_url)
24
- assert isinstance(file_or_url, str)
25
- return torch.load(
26
- hf_hub_download(repo_id=file_or_url, filename=filename), map_location=device)
27
-
28
-
29
- def load_compression_model(file_or_url: tp.Union[Path, str], device='cpu'):
30
- pkg = _get_state_dict(file_or_url, filename="compression_state_dict.bin")
31
- cfg = OmegaConf.create(pkg['xp.cfg'])
32
- cfg.device = str(device)
33
- model = builders.get_compression_model(cfg)
34
- model.load_state_dict(pkg['best_state'])
35
- model.eval()
36
- model.cfg = cfg
37
- return model
38
-
39
-
40
- def load_lm_model(file_or_url: tp.Union[Path, str], device='cpu'):
41
- pkg = _get_state_dict(file_or_url)
42
- cfg = OmegaConf.create(pkg['xp.cfg'])
43
- cfg.device = str(device)
44
- if cfg.device == 'cpu':
45
- cfg.transformer_lm.memory_efficient = False
46
- cfg.transformer_lm.custom = True
47
- cfg.dtype = 'float32'
48
- else:
49
- cfg.dtype = 'float16'
50
- model = builders.get_lm_model(cfg)
51
- model.load_state_dict(pkg['best_state'])
52
- model.eval()
53
- model.cfg = cfg
54
- return model
55
-
56
-
57
- def get_pretrained(name: str = 'small', device='cuda'):
58
- model_id = MODEL_CHECKPOINTS_MAP[name]
59
- compression_model = load_compression_model(model_id, device=device)
60
- lm = load_lm_model(model_id, device=device)
61
- return MusicGen(name, compression_model, lm)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
mypy.ini CHANGED
@@ -1,4 +1,4 @@
1
  [mypy]
2
 
3
- [mypy-treetable,torchaudio.*,soundfile,einops.*,av.*,tqdm.*,num2words.*,spacy,xformers.*,scipy]
4
  ignore_missing_imports = True
 
1
  [mypy]
2
 
3
+ [mypy-treetable,torchaudio.*,soundfile,einops.*,av.*,tqdm.*,num2words.*,spacy,xformers.*,scipy,huggingface_hub]
4
  ignore_missing_imports = True
requirements.txt CHANGED
@@ -11,6 +11,7 @@ sentencepiece
11
  spacy==3.5.2
12
  torch>=2.0.0
13
  torchaudio>=2.0.0
 
14
  tqdm
15
  transformers
16
  xformers
 
11
  spacy==3.5.2
12
  torch>=2.0.0
13
  torchaudio>=2.0.0
14
+ huggingface_hub
15
  tqdm
16
  transformers
17
  xformers