Surn committed
Commit c228235
1 Parent(s): aef7fad

Update for app_batched

Files changed (2):
  1. app.py +1 -1
  2. app_batched.py +158 -71
app.py CHANGED
@@ -13,7 +13,7 @@ import gradio as gr
 import os
 from audiocraft.models import MusicGen
 from audiocraft.data.audio import audio_write
-from audiocraft.utils.extend import generate_music_segments, add_settings_to_image, sanitize_file_name
+from audiocraft.utils.extend import generate_music_segments, add_settings_to_image
 import numpy as np
 import random
 
app_batched.py CHANGED
@@ -6,7 +6,12 @@ This source code is licensed under the license found in the
 LICENSE file in the root directory of this source tree.
 """
 
+import argparse
+from concurrent.futures import ProcessPoolExecutor
+import subprocess as sp
 from tempfile import NamedTemporaryFile
+import time
+import warnings
 import torch
 import gradio as gr
 from audiocraft.data.audio_utils import convert_audio
@@ -16,6 +21,29 @@ from audiocraft.models import MusicGen
 
 MODEL = None
 
+_old_call = sp.call
+
+
+def _call_nostderr(*args, **kwargs):
+    # Avoid ffmpeg vomitting on the logs.
+    kwargs['stderr'] = sp.DEVNULL
+    kwargs['stdout'] = sp.DEVNULL
+    _old_call(*args, **kwargs)
+
+
+sp.call = _call_nostderr
+pool = ProcessPoolExecutor(3)
+pool.__enter__()
+
+
+def make_waveform(*args, **kwargs):
+    be = time.time()
+    with warnings.catch_warnings():
+        warnings.simplefilter('ignore')
+        out = gr.make_waveform(*args, **kwargs)
+        print("Make a video took", time.time() - be)
+        return out
+
 
 def load_model():
     print("Loading model")
@@ -28,11 +56,13 @@ def predict(texts, melodies):
     MODEL = load_model()
 
     duration = 12
+    max_text_length = 512
+    texts = [text[:max_text_length] for text in texts]
     MODEL.set_generation_params(duration=duration)
 
-    print(texts, melodies)
+    print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
+    be = time.time()
     processed_melodies = []
-
     target_sr = 32000
     target_ac = 1
     for melody in melodies:
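New in this hunk: each prompt is truncated to 512 characters before generation, so one oversized request cannot slow down a shared batch, and every batch is logged with its size and melody shapes. The clamp is a single list comprehension; a runnable illustration:

```python
# Cap each description before it reaches the model (values as in the commit).
max_text_length = 512
texts = ["x" * 2000, "short prompt"]
texts = [text[:max_text_length] for text in texts]
print([len(t) for t in texts])  # [512, 12]
```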
@@ -40,8 +70,6 @@
             processed_melodies.append(None)
         else:
             sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
-            duration = min(duration, melody.shape[-1] / sr)
-            MODEL.set_generation_params(duration=duration)
             if melody.dim() == 1:
                 melody = melody[None]
             melody = melody[..., :int(sr * duration)]
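The two removed lines clamped `duration` to the shortest uploaded melody and reset the generation parameters inside the loop; in a batched handler that would let one short clip shorten every result in the batch, so the batched version keeps the fixed 12-second window for all items. A sketch of the per-melody preparation that remains, pulled into a hypothetical standalone helper (`prepare_melody` is not in the commit):

```python
import numpy as np
import torch


def prepare_melody(melody, duration=12, device='cpu'):
    """Turn a Gradio (sample_rate, ndarray) upload into a [channels, time] tensor."""
    if melody is None:
        return None
    sr, wav = melody[0], torch.from_numpy(melody[1]).to(device).float().t()
    if wav.dim() == 1:
        wav = wav[None]  # add a channel axis for mono uploads
    return wav[..., :int(sr * duration)]  # trim to the fixed generation window


# A 20-second stereo upload at 44.1 kHz is trimmed to 12 seconds:
fake_upload = (44100, np.zeros((44100 * 20, 2), dtype=np.float32))
print(prepare_melody(fake_upload).shape)  # torch.Size([2, 529200])
```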
@@ -52,7 +80,7 @@
         descriptions=texts,
         melody_wavs=processed_melodies,
         melody_sample_rate=target_sr,
-        progress=True
+        progress=False
     )
 
     outputs = outputs.detach().cpu().float()
@@ -62,73 +90,132 @@ def predict(texts, melodies):
         audio_write(
             file.name, output, MODEL.sample_rate, strategy="loudness",
             loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
-        waveform_video = gr.make_waveform(file.name)
-        out_files.append(waveform_video)
-    return [out_files]
-
-
-with gr.Blocks() as demo:
-    gr.Markdown(
-        """
-        # MusicGen
-
-        This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
-        presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
-        <br/>
-        <a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
-        <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
-        for longer sequences, more control and no queue.</p>
-        """
-    )
-    with gr.Row():
-        with gr.Column():
-            with gr.Row():
-                text = gr.Text(label="Describe your music", lines=2, interactive=True)
-                melody = gr.Audio(source="upload", type="numpy", label="Condition on a melody (optional)", interactive=True)
-            with gr.Row():
-                submit = gr.Button("Generate")
-        with gr.Column():
-            output = gr.Video(label="Generated Music")
-    submit.click(predict, inputs=[text, melody], outputs=[output], batch=True, max_batch_size=12)
-    gr.Examples(
-        fn=predict,
-        examples=[
-            [
-                "An 80s driving pop song with heavy drums and synth pads in the background",
-                "./assets/bach.mp3",
-            ],
-            [
-                "A cheerful country song with acoustic guitars",
-                "./assets/bolero_ravel.mp3",
+        out_files.append(pool.submit(make_waveform, file.name))
+    res = [[out_file.result() for out_file in out_files]]
+    print("batch finished", len(texts), time.time() - be)
+    return res
+
+
+def ui(**kwargs):
+    with gr.Blocks() as demo:
+        gr.Markdown(
+            """
+            # MusicGen
+
+            This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
+            presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
+            <br/>
+            <a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
+            <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
+            for longer sequences, more control and no queue.</p>
+            """
+        )
+        with gr.Row():
+            with gr.Column():
+                with gr.Row():
+                    text = gr.Text(label="Describe your music", lines=2, interactive=True)
+                    melody = gr.Audio(source="upload", type="numpy", label="Condition on a melody (optional)", interactive=True)
+                with gr.Row():
+                    submit = gr.Button("Generate")
+            with gr.Column():
+                output = gr.Video(label="Generated Music")
+        submit.click(predict, inputs=[text, melody], outputs=[output], batch=True, max_batch_size=8)
+        gr.Examples(
+            fn=predict,
+            examples=[
+                [
+                    "An 80s driving pop song with heavy drums and synth pads in the background",
+                    "./assets/bach.mp3",
+                ],
+                [
+                    "A cheerful country song with acoustic guitars",
+                    "./assets/bolero_ravel.mp3",
+                ],
+                [
+                    "90s rock song with electric guitar and heavy drums",
+                    None,
+                ],
+                [
+                    "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130",
+                    "./assets/bach.mp3",
+                ],
+                [
+                    "lofi slow bpm electro chill with organic samples",
+                    None,
+                ],
             ],
-            [
-                "90s rock song with electric guitar and heavy drums",
-                None,
-            ],
-            [
-                "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130",
-                "./assets/bach.mp3",
-            ],
-            [
-                "lofi slow bpm electro chill with organic samples",
-                None,
-            ],
-        ],
-        inputs=[text, melody],
-        outputs=[output]
+            inputs=[text, melody],
+            outputs=[output]
+        )
+        gr.Markdown("""
+        ### More details
+
+        The model will generate 12 seconds of audio based on the description you provided.
+        You can optionaly provide a reference audio from which a broad melody will be extracted.
+        The model will then try to follow both the description and melody provided.
+        All samples are generated with the `melody` model.
+
+        You can also use your own GPU or a Google Colab by following the instructions on our repo.
+
+        See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
+        for more details.
+        """)
+
+    # Show the interface
+    launch_kwargs = {}
+    username = kwargs.get('username')
+    password = kwargs.get('password')
+    server_port = kwargs.get('server_port', 0)
+    inbrowser = kwargs.get('inbrowser', False)
+    share = kwargs.get('share', False)
+    server_name = kwargs.get('listen')
+
+    launch_kwargs['server_name'] = server_name
+
+    if username and password:
+        launch_kwargs['auth'] = (username, password)
+    if server_port > 0:
+        launch_kwargs['server_port'] = server_port
+    if inbrowser:
+        launch_kwargs['inbrowser'] = inbrowser
+    if share:
+        launch_kwargs['share'] = share
+    demo.queue(max_size=60).launch(**launch_kwargs)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--listen',
+        type=str,
+        default='127.0.0.1',
+        help='IP to listen on for connections to Gradio',
+    )
+    parser.add_argument(
+        '--username', type=str, default='', help='Username for authentication'
+    )
+    parser.add_argument(
+        '--password', type=str, default='', help='Password for authentication'
+    )
+    parser.add_argument(
+        '--server_port',
+        type=int,
+        default=0,
+        help='Port to run the server listener on',
+    )
+    parser.add_argument(
+        '--inbrowser', action='store_true', help='Open in browser'
+    )
+    parser.add_argument(
+        '--share', action='store_true', help='Share the gradio UI'
     )
-    gr.Markdown("""
-    ### More details
-
-    The model will generate 12 seconds of audio based on the description you provided.
-    You can optionaly provide a reference audio from which a broad melody will be extracted.
-    The model will then try to follow both the description and melody provided.
-    All samples are generated with the `melody` model.
-
-    You can also use your own GPU or a Google Colab by following the instructions on our repo.
 
-    See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
-    for more details.
-    """)
+    args = parser.parse_args()
 
-demo.queue(max_size=15).launch()
+    ui(
+        username=args.username,
+        password=args.password,
+        inbrowser=args.inbrowser,
+        server_port=args.server_port,
+        share=args.share,
+        listen=args.listen
+    )
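The new `ui(**kwargs)` entry point translates the CLI flags into Gradio `launch()` options, passing the optional ones only when they are actually set. A compact sketch of that mapping plus a programmatic launch, assuming the file is importable as `app_batched` (the `build_launch_kwargs` helper below is illustrative, not part of the commit):

```python
def build_launch_kwargs(listen='127.0.0.1', username='', password='',
                        server_port=0, inbrowser=False, share=False):
    # Mirrors the if-chain in ui(): server_name is always set, the rest
    # only when a non-default value was supplied.
    kwargs = {'server_name': listen}
    if username and password:
        kwargs['auth'] = (username, password)
    if server_port > 0:
        kwargs['server_port'] = server_port
    if inbrowser:
        kwargs['inbrowser'] = True
    if share:
        kwargs['share'] = True
    return kwargs


if __name__ == '__main__':
    print(build_launch_kwargs(listen='0.0.0.0', server_port=7860))
    # {'server_name': '0.0.0.0', 'server_port': 7860}

    # Equivalent to: python app_batched.py --listen 0.0.0.0 --server_port 7860
    # from app_batched import ui
    # ui(listen='0.0.0.0', server_port=7860)
```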