Commit f714a9b by adefossez
Parents: 898c175 756be9c

Merge branch 'main' into our_hf

README.md CHANGED
@@ -56,15 +56,21 @@ You can play with MusicGen by running the jupyter notebook at [`demo.ipynb`](./d
  ## API
 
  We provide a simple API and 4 pre-trained models. The pre-trained models are:
- - `small`: 300M model, text to music only,
- - `medium`: 1.5B model, text to music only,
- - `melody`: 1.5B model, text to music and text+melody to music,
- - `large`: 3.3B model, text to music only.
+ - `small`: 300M model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-small)
+ - `medium`: 1.5B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-medium)
+ - `melody`: 1.5B model, text to music and text+melody to music - [🤗 Hub](https://huggingface.co/facebook/musicgen-melody)
+ - `large`: 3.3B model, text to music only - [🤗 Hub](https://huggingface.co/facebook/musicgen-large)
 
  We observe the best trade-off between quality and compute with the `medium` or `melody` model.
  In order to use MusicGen locally **you must have a GPU**. We recommend 16GB of memory, but smaller
  GPUs will be able to generate short sequences, or longer sequences with the `small` model.
 
+ **Note**: Please make sure to have [ffmpeg](https://ffmpeg.org/download.html) installed when using a newer version of `torchaudio`.
+ You can install it with:
+ ```
+ apt-get install ffmpeg
+ ```
+
  See below for a quick example of using the API.
 
  ```python
@@ -84,7 +90,7 @@ wav = model.generate_with_chroma(descriptions, melody[None].expand(3, -1, -1), s
 
  for idx, one_wav in enumerate(wav):
      # Will save under {idx}.wav, with loudness normalization at -14 dB LUFS.
-     audio_write(f'{idx}', one_wav, model.sample_rate, strategy="loudness")
+     audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness")
  ```
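For context, the quick example the README points to looks roughly as follows after this commit. This is a sketch assembled from the hunks above, not the verbatim file contents: the `duration` keyword to `set_generation_params` is an assumption based on the signature shown in `musicgen.py` below, and the audio path is illustrative.

```python
import torchaudio
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write

# Checkpoints are now fetched from the Hugging Face Hub (see loaders.py below).
model = MusicGen.get_pretrained('melody')
model.set_generation_params(duration=8)  # assumed parameter: length of clips in seconds

descriptions = ['happy rock', 'energetic EDM', 'sad jazz']
wav = model.generate(descriptions)  # one waveform per description

# Melody conditioning, available with the `melody` checkpoint.
melody, sr = torchaudio.load('./assets/bach.mp3')  # illustrative path
wav = model.generate_with_chroma(descriptions, melody[None].expand(3, -1, -1), sr)

for idx, one_wav in enumerate(wav):
    # Saves under {idx}.wav, with loudness normalization at -14 dB LUFS.
    # Note the `.cpu()` call introduced by this commit.
    audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness")
```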
app.py CHANGED
@@ -6,9 +6,12 @@ This source code is licensed under the license found in the
  LICENSE file in the root directory of this source tree.
  """
 
+ from tempfile import NamedTemporaryFile
  import torch
  import gradio as gr
- from hf_loading import get_pretrained
+ from audiocraft.models import MusicGen
+
+ from audiocraft.data.audio import audio_write
 
 
  MODEL = None
@@ -16,7 +19,7 @@ MODEL = None
 
  def load_model(version):
      print("Loading model", version)
-     return get_pretrained(version)
+     return MusicGen.get_pretrained(version)
 
 
  def predict(model, text, melody, duration, topk, topp, temperature, cfg_coef):
@@ -51,8 +54,11 @@ def predict(model, text, melody, duration, topk, topp, temperature, cfg_coef):
      else:
          output = MODEL.generate(descriptions=[text], progress=False)
 
-     output = output.detach().cpu().numpy()
-     return MODEL.sample_rate, output
+     output = output.detach().cpu().float()[0]
+     with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
+         audio_write(file.name, output, MODEL.sample_rate, strategy="loudness", add_suffix=False)
+         waveform_video = gr.make_waveform(file.name)
+         return waveform_video
 
 
  with gr.Blocks() as demo:
@@ -60,25 +66,12 @@ with gr.Blocks() as demo:
      """
      # MusicGen
 
-     This is the demo for MusicGen, a simple and controllable model for music generation presented at: "Simple and Controllable Music Generation".
-
-     Below we present 3 model variations:
-     1. Melody -- a music generation model capable of generating music condition on text and melody inputs. **Note**, you can also use text only.
-     2. Small -- a 300M transformer decoder conditioned on text only.
-     3. Medium -- a 1.5B transformer decoder conditioned on text only.
-     4. Large -- a 3.3B transformer decoder conditioned on text only (might OOM for the longest sequences.)
-
-     When the optional melody conditioning wav is provided, the model will extract
-     a broad melody and try to follow it in the generated samples.
-
-     For skipping queue, you can duplicate this space, and upgrade to GPU in the settings.
+     This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
+     presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
      <br/>
-     <a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true">
-     <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
-     </p>
-
-     See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
-     for more details.
+     <a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
+     <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
+     for longer sequences, more control and no queue.</p>
      """
      )
      with gr.Row():
@@ -98,7 +91,7 @@ with gr.Blocks() as demo:
              temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
              cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
          with gr.Column():
-             output = gr.Audio(label="Generated Music", type="numpy")
+             output = gr.Video(label="Generated Music")
      submit.click(predict, inputs=[model, text, melody, duration, topk, topp, temperature, cfg_coef], outputs=[output])
      gr.Examples(
          fn=predict,
@@ -132,5 +125,21 @@ with gr.Blocks() as demo:
          inputs=[text, melody, model],
          outputs=[output]
      )
+     gr.Markdown(
+         """
+         ### More details
+
+         You type a description of the music you want, plus an optional audio clip used for melody conditioning.
+
+         We present 4 model variations:
+         1. Melody -- a music generation model capable of generating music conditioned on text and melody inputs. **Note**, you can also use text only.
+         2. Small -- a 300M transformer decoder conditioned on text only.
+         3. Medium -- a 1.5B transformer decoder conditioned on text only.
+         4. Large -- a 3.3B transformer decoder conditioned on text only (might OOM for the longest sequences).
+
+         When the optional melody conditioning wav is provided, the model will extract
+         a broad melody and try to follow it in the generated samples.
+         """
+     )
 
  demo.launch()
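The main behavioral change in `app.py` is the output path: instead of returning raw `(sample_rate, numpy_array)` audio, the generated tensor is written to a temporary wav file and rendered as a waveform video for the new `gr.Video` component. A minimal sketch of that pattern in isolation; the helper name `tensor_to_waveform_video` is hypothetical:

```python
from tempfile import NamedTemporaryFile

import gradio as gr
from audiocraft.data.audio import audio_write


def tensor_to_waveform_video(output, sample_rate):
    # `output` is a [channels, time] float tensor already moved to CPU.
    with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
        # add_suffix=False keeps the exact temp-file name chosen above.
        audio_write(file.name, output, sample_rate, strategy="loudness", add_suffix=False)
        # gr.make_waveform renders the wav into a video file and returns its path.
        return gr.make_waveform(file.name)
```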
app_batched.py CHANGED
@@ -11,7 +11,7 @@ import torch
  import gradio as gr
  from audiocraft.data.audio_utils import convert_audio
  from audiocraft.data.audio import audio_write
- from hf_loading import get_pretrained
+ from audiocraft.models import MusicGen
 
 
  MODEL = None
@@ -19,7 +19,7 @@ MODEL = None
 
  def load_model():
      print("Loading model")
-     return get_pretrained("melody")
+     return MusicGen.get_pretrained("melody")
 
 
  def predict(texts, melodies):
@@ -58,8 +58,9 @@ def predict(texts, melodies):
      for output in outputs:
          with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
              audio_write(file.name, output, MODEL.sample_rate, strategy="loudness", add_suffix=False)
-             out_files.append([file.name])
-     return out_files
+             waveform_video = gr.make_waveform(file.name)
+             out_files.append(waveform_video)
+     return [out_files]
 
 
  with gr.Blocks() as demo:
@@ -67,35 +68,23 @@ with gr.Blocks() as demo:
      """
      # MusicGen
 
-     This is the demo for MusicGen, a simple and controllable model for music generation
-     presented at: "Simple and Controllable Music Generation".
-
-     Enter the description of the music you want and an optional audio used for melody conditioning.
-     The model will extract the broad melody from the uploaded wav if provided.
-     This will generate a 12s extract with the `melody` model.
-
-     For generating longer sequences (up to 30 seconds) and skipping queue, you can duplicate
-     to full demo space, which contains more control and upgrade to GPU in the settings.
+     This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
+     presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
      <br/>
-     <a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true">
-     <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
-     </p>
-
-     You can also use your own GPU or a Google Colab by following the instructions on our repo.
-
-     See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
-     for more details.
+     <a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
+     <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
+     for longer sequences, more control and no queue.</p>
      """
      )
      with gr.Row():
          with gr.Column():
              with gr.Row():
-                 text = gr.Text(label="Input Text", interactive=True)
-                 melody = gr.Audio(source="upload", type="numpy", label="Melody Condition (optional)", interactive=True)
+                 text = gr.Text(label="Describe your music", lines=2, interactive=True)
+                 melody = gr.Audio(source="upload", type="numpy", label="Condition on a melody (optional)", interactive=True)
              with gr.Row():
-                 submit = gr.Button("Submit")
+                 submit = gr.Button("Generate")
          with gr.Column():
-             output = gr.Audio(label="Generated Music", type="filepath", format="wav")
+             output = gr.Video(label="Generated Music")
      submit.click(predict, inputs=[text, melody], outputs=[output], batch=True, max_batch_size=12)
      gr.Examples(
          fn=predict,
@@ -124,5 +113,15 @@ with gr.Blocks() as demo:
          inputs=[text, melody],
          outputs=[output]
      )
+     gr.Markdown("""
+     ### More details
+     Type a description of the music you want and, optionally, upload an audio file for melody conditioning;
+     the model will extract the broad melody from the uploaded wav and generate a 12s extract with the `melody` model.
+
+     You can also use your own GPU or a Google Colab by following the instructions on our repo.
+
+     See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
+     for more details.
+     """)
 
  demo.queue(max_size=15).launch()
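Worth noting in `app_batched.py`: `predict` now returns `[out_files]` rather than `out_files`. With `batch=True`, Gradio calls the function once with lists of inputs and expects, for each output component, a list with one entry per batched request. A toy sketch of that contract, with placeholder results standing in for the real model:

```python
import gradio as gr

def predict(texts, melodies):
    # Placeholder results standing in for the generated waveform videos.
    out_files = [f"video_{i}.mp4" for i, _ in enumerate(texts)]
    return [out_files]  # one list per output component

with gr.Blocks() as demo:
    text = gr.Text(label="Describe your music", lines=2)
    melody = gr.Audio(source="upload", type="numpy", label="Condition on a melody (optional)")
    submit = gr.Button("Generate")
    output = gr.Video(label="Generated Music")
    # Gradio collects up to 12 concurrent requests into one predict() call.
    submit.click(predict, inputs=[text, melody], outputs=[output], batch=True, max_batch_size=12)
```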
audiocraft/models/loaders.py CHANGED
@@ -20,7 +20,9 @@ of the returned model.
  """
 
  from pathlib import Path
+ from huggingface_hub import hf_hub_download
  import typing as tp
+ import os
 
  from omegaconf import OmegaConf
  import torch
@@ -28,18 +30,43 @@ import torch
  from . import builders
 
 
- def _get_state_dict(file_or_url: tp.Union[Path, str], device='cpu'):
+ HF_MODEL_CHECKPOINTS_MAP = {
+     "small": "facebook/musicgen-small",
+     "medium": "facebook/musicgen-medium",
+     "large": "facebook/musicgen-large",
+     "melody": "facebook/musicgen-melody",
+ }
+
+
+ def _get_state_dict(
+     file_or_url_or_id: tp.Union[Path, str],
+     filename: tp.Optional[str] = None,
+     device='cpu',
+     cache_dir: tp.Optional[str] = None,
+ ):
      # Return the state dict either from a file or url
-     file_or_url = str(file_or_url)
-     assert isinstance(file_or_url, str)
-     if file_or_url.startswith('https://'):
-         return torch.hub.load_state_dict_from_url(file_or_url, map_location=device, check_hash=True)
+     file_or_url_or_id = str(file_or_url_or_id)
+     assert isinstance(file_or_url_or_id, str)
+
+     if os.path.isfile(file_or_url_or_id):
+         return torch.load(file_or_url_or_id, map_location=device)
+
+     elif file_or_url_or_id.startswith('https://'):
+         return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)
+
+     elif file_or_url_or_id in HF_MODEL_CHECKPOINTS_MAP:
+         assert filename is not None, "filename needs to be defined if using HF checkpoints"
+
+         repo_id = HF_MODEL_CHECKPOINTS_MAP[file_or_url_or_id]
+         file = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir)
+         return torch.load(file, map_location=device)
+
      else:
-         return torch.load(file_or_url, device)
+         raise ValueError(f"{file_or_url_or_id} is not a valid name, path or link that can be loaded.")
 
 
- def load_compression_model(file_or_url: tp.Union[Path, str], device='cpu'):
-     pkg = _get_state_dict(file_or_url)
+ def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
+     pkg = _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir)
      cfg = OmegaConf.create(pkg['xp.cfg'])
      cfg.device = str(device)
      model = builders.get_compression_model(cfg)
@@ -48,8 +75,8 @@ def load_compression_model(file_or_url: tp.Union[Path, str], device='cpu'):
      return model
 
 
- def load_lm_model(file_or_url: tp.Union[Path, str], device='cpu'):
-     pkg = _get_state_dict(file_or_url)
+ def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
+     pkg = _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir)
      cfg = OmegaConf.create(pkg['xp.cfg'])
      cfg.device = str(device)
      if cfg.device == 'cpu':
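The rewritten `_get_state_dict` resolves its first argument in three ways: a local file, an `https://` URL, or a short name from `HF_MODEL_CHECKPOINTS_MAP`. A quick sketch of the three call shapes; the URL and local path are illustrative, not real checkpoints:

```python
from audiocraft.models.loaders import load_lm_model

# 1. Short name: resolved via HF_MODEL_CHECKPOINTS_MAP and hf_hub_download,
#    with the filename supplied by the wrapper ("state_dict.bin" here).
lm = load_lm_model("melody")

# 2. URL: delegated to torch.hub.load_state_dict_from_url (illustrative URL).
lm = load_lm_model("https://example.com/checkpoints/state_dict.bin")

# 3. Local file: loaded directly with torch.load (illustrative path).
lm = load_lm_model("/checkpoints/state_dict.bin")
```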
audiocraft/models/musicgen.py CHANGED
@@ -17,7 +17,7 @@ import torch
  from .encodec import CompressionModel
  from .lm import LMModel
  from .builders import get_debug_compression_model, get_debug_lm_model
- from .loaders import load_compression_model, load_lm_model
+ from .loaders import load_compression_model, load_lm_model, HF_MODEL_CHECKPOINTS_MAP
  from ..data.audio_utils import convert_audio
  from ..modules.conditioners import ConditioningAttributes, WavCondition
  from ..utils.autocast import TorchAutocast
@@ -67,10 +67,10 @@ class MusicGen:
      @staticmethod
      def get_pretrained(name: str = 'melody', device='cuda'):
          """Return pretrained model, we provide four models:
-         - small (300M), text to music,
-         - medium (1.5B), text to music,
-         - melody (1.5B) text to music and text+melody to music,
-         - large (3.3B), text to music.
+         - small (300M), text to music, # see: https://huggingface.co/facebook/musicgen-small
+         - medium (1.5B), text to music, # see: https://huggingface.co/facebook/musicgen-medium
+         - melody (1.5B) text to music and text+melody to music, # see: https://huggingface.co/facebook/musicgen-melody
+         - large (3.3B), text to music, # see: https://huggingface.co/facebook/musicgen-large
          """
 
          if name == 'debug':
@@ -79,21 +79,16 @@ class MusicGen:
              lm = get_debug_lm_model(device)
              return MusicGen(name, compression_model, lm)
 
-         if 'MUSICGEN_ROOT' in os.environ:
-             ROOT = os.environ['MUSICGEN_ROOT']
-             if not ROOT.endswith('/'):
-                 ROOT += '/'
-         else:
-             ROOT = 'https://dl.fbaipublicfiles.com/audiocraft/musicgen/v0/'
-         compression_model = load_compression_model(ROOT + 'b0dbef54-37d256b525.th', device=device)
-         names = {
-             'small': 'ba7a97ba-830fe5771e',
-             'medium': 'aa73ae27-fbc9f401db',
-             'large': '9b6e835c-1f0cf17b5e',
-             'melody': 'f79af192-61305ffc49',
-         }
-         sig = names[name]
-         lm = load_lm_model(ROOT + f'{sig}.th', device=device)
+         if name not in HF_MODEL_CHECKPOINTS_MAP:
+             raise ValueError(
+                 f"{name} is not a valid checkpoint name. "
+                 f"Choose one of {', '.join(HF_MODEL_CHECKPOINTS_MAP.keys())}"
+             )
+
+         cache_dir = os.environ.get('MUSICGEN_ROOT', None)
+         compression_model = load_compression_model(name, device=device, cache_dir=cache_dir)
+         lm = load_lm_model(name, device=device, cache_dir=cache_dir)
+
          return MusicGen(name, compression_model, lm)
 
      def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
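`MUSICGEN_ROOT` changes meaning in this hunk: it used to point at a mirror of the fbaipublicfiles checkpoints, and now serves as the `hf_hub_download` cache directory. A sketch of the new behavior; the cache path is hypothetical:

```python
import os
from audiocraft.models import MusicGen

os.environ["MUSICGEN_ROOT"] = "/data/musicgen_cache"  # hypothetical cache directory

model = MusicGen.get_pretrained("melody")  # downloads from facebook/musicgen-melody

# Unknown names now fail fast with a readable error instead of a KeyError:
try:
    MusicGen.get_pretrained("xlarge")
except ValueError as err:
    print(err)  # "xlarge is not a valid checkpoint name. Choose one of small, medium, large, melody"
```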
audiocraft/utils/utils.py CHANGED
@@ -122,7 +122,7 @@ def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
      probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
      probs_sum = torch.cumsum(probs_sort, dim=-1)
      mask = probs_sum - probs_sort > p
-     probs_sort *= (~mask).float(0)
+     probs_sort *= (~mask).float()
      probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
      next_token = multinomial(probs_sort, num_samples=1)
      next_token = torch.gather(probs_idx, -1, next_token)
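The `sample_top_p` fix is a genuine bug fix: `Tensor.float()` takes no positional argument, so `(~mask).float(0)` raised a `TypeError` at sampling time. A small worked example of the corrected nucleus (top-p) masking step:

```python
import torch

probs = torch.tensor([[0.50, 0.30, 0.15, 0.05]])
p = 0.8

probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)           # [0.50, 0.80, 0.95, 1.00]
mask = probs_sum - probs_sort > p                      # [False, False, False, True]
probs_sort *= (~mask).float()                          # zero out tokens past the nucleus
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))  # renormalize the kept tokens
print(probs_sort)  # tensor([[0.5263, 0.3158, 0.1579, 0.0000]])
```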
hf_loading.py DELETED
@@ -1,61 +0,0 @@
- """Utility for loading the models from HF."""
- from pathlib import Path
- import typing as tp
-
- from omegaconf import OmegaConf
- from huggingface_hub import hf_hub_download
- import torch
-
- from audiocraft.models import builders, MusicGen
-
- MODEL_CHECKPOINTS_MAP = {
-     "small": "facebook/musicgen-small",
-     "medium": "facebook/musicgen-medium",
-     "large": "facebook/musicgen-large",
-     "melody": "facebook/musicgen-melody",
- }
-
-
- def _get_state_dict(file_or_url: tp.Union[Path, str],
-                     filename="state_dict.bin", device='cpu'):
-     # Return the state dict either from a file or url
-     print("loading", file_or_url, filename)
-     file_or_url = str(file_or_url)
-     assert isinstance(file_or_url, str)
-     return torch.load(
-         hf_hub_download(repo_id=file_or_url, filename=filename), map_location=device)
-
-
- def load_compression_model(file_or_url: tp.Union[Path, str], device='cpu'):
-     pkg = _get_state_dict(file_or_url, filename="compression_state_dict.bin")
-     cfg = OmegaConf.create(pkg['xp.cfg'])
-     cfg.device = str(device)
-     model = builders.get_compression_model(cfg)
-     model.load_state_dict(pkg['best_state'])
-     model.eval()
-     model.cfg = cfg
-     return model
-
-
- def load_lm_model(file_or_url: tp.Union[Path, str], device='cpu'):
-     pkg = _get_state_dict(file_or_url)
-     cfg = OmegaConf.create(pkg['xp.cfg'])
-     cfg.device = str(device)
-     if cfg.device == 'cpu':
-         cfg.transformer_lm.memory_efficient = False
-         cfg.transformer_lm.custom = True
-         cfg.dtype = 'float32'
-     else:
-         cfg.dtype = 'float16'
-     model = builders.get_lm_model(cfg)
-     model.load_state_dict(pkg['best_state'])
-     model.eval()
-     model.cfg = cfg
-     return model
-
-
- def get_pretrained(name: str = 'small', device='cuda'):
-     model_id = MODEL_CHECKPOINTS_MAP[name]
-     compression_model = load_compression_model(model_id, device=device)
-     lm = load_lm_model(model_id, device=device)
-     return MusicGen(name, compression_model, lm)
mypy.ini CHANGED
@@ -1,4 +1,4 @@
  [mypy]
 
- [mypy-treetable,torchaudio.*,soundfile,einops.*,av.*,tqdm.*,num2words.*,spacy,xformers.*,scipy]
+ [mypy-treetable,torchaudio.*,soundfile,einops.*,av.*,tqdm.*,num2words.*,spacy,xformers.*,scipy,huggingface_hub]
  ignore_missing_imports = True
requirements.txt CHANGED
@@ -11,6 +11,7 @@ sentencepiece
  spacy==3.5.2
  torch>=2.0.0
  torchaudio>=2.0.0
+ huggingface_hub
  tqdm
  transformers
  xformers