Ashley Kleynhans commited on
Commit
8e10a53
1 Parent(s): 10b60b3

Allow gradio command line arguments to be specified (#50)

Browse files
Files changed (1) hide show
  1. app.py +143 -83
app.py CHANGED
@@ -7,10 +7,10 @@ LICENSE file in the root directory of this source tree.
7
  """
8
 
9
  from tempfile import NamedTemporaryFile
 
10
  import torch
11
  import gradio as gr
12
  from audiocraft.models import MusicGen
13
-
14
  from audiocraft.data.audio import audio_write
15
 
16
 
@@ -61,90 +61,150 @@ def predict(model, text, melody, duration, topk, topp, temperature, cfg_coef):
61
  return waveform_video
62
 
63
 
64
- with gr.Blocks() as demo:
65
- gr.Markdown(
66
- """
67
- # MusicGen
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
70
- presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
71
- <br/>
72
- <a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
73
- <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
74
- for longer sequences, more control and no queue.</p>
75
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  )
77
- with gr.Row():
78
- with gr.Column():
79
- with gr.Row():
80
- text = gr.Text(label="Input Text", interactive=True)
81
- melody = gr.Audio(source="upload", type="numpy", label="Melody Condition (optional)", interactive=True)
82
- with gr.Row():
83
- submit = gr.Button("Submit")
84
- with gr.Row():
85
- model = gr.Radio(["melody", "medium", "small", "large"], label="Model", value="melody", interactive=True)
86
- with gr.Row():
87
- duration = gr.Slider(minimum=1, maximum=30, value=10, label="Duration", interactive=True)
88
- with gr.Row():
89
- topk = gr.Number(label="Top-k", value=250, interactive=True)
90
- topp = gr.Number(label="Top-p", value=0, interactive=True)
91
- temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
92
- cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
93
- with gr.Column():
94
- output = gr.Video(label="Generated Music")
95
- submit.click(predict, inputs=[model, text, melody, duration, topk, topp, temperature, cfg_coef], outputs=[output])
96
- gr.Examples(
97
- fn=predict,
98
- examples=[
99
- [
100
- "An 80s driving pop song with heavy drums and synth pads in the background",
101
- "./assets/bach.mp3",
102
- "melody"
103
- ],
104
- [
105
- "A cheerful country song with acoustic guitars",
106
- "./assets/bolero_ravel.mp3",
107
- "melody"
108
- ],
109
- [
110
- "90s rock song with electric guitar and heavy drums",
111
- None,
112
- "medium"
113
- ],
114
- [
115
- "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
116
- "./assets/bach.mp3",
117
- "melody"
118
- ],
119
- [
120
- "lofi slow bpm electro chill with organic samples",
121
- None,
122
- "medium",
123
- ],
124
- ],
125
- inputs=[text, melody, model],
126
- outputs=[output]
127
  )
128
- gr.Markdown(
129
- """
130
- ### More details
131
-
132
- The model will generate a short music extract based on the description you provided.
133
- You can generate up to 30 seconds of audio.
134
-
135
- We present 4 model variations:
136
- 1. Melody -- a music generation model capable of generating music condition on text and melody inputs. **Note**, you can also use text only.
137
- 2. Small -- a 300M transformer decoder conditioned on text only.
138
- 3. Medium -- a 1.5B transformer decoder conditioned on text only.
139
- 4. Large -- a 3.3B transformer decoder conditioned on text only (might OOM for the longest sequences.)
140
-
141
- When using `melody`, ou can optionaly provide a reference audio from
142
- which a broad melody will be extracted. The model will then try to follow both the description and melody provided.
143
-
144
- You can also use your own GPU or a Google Colab by following the instructions on our repo.
145
- See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
146
- for more details.
147
- """
148
  )
 
 
 
 
 
 
 
 
149
 
150
- demo.launch()
 
 
 
 
 
 
 
 
7
  """
8
 
9
  from tempfile import NamedTemporaryFile
10
+ import argparse
11
  import torch
12
  import gradio as gr
13
  from audiocraft.models import MusicGen
 
14
  from audiocraft.data.audio import audio_write
15
 
16
 
 
61
  return waveform_video
62
 
63
 
64
+ def ui(**kwargs):
65
+ with gr.Blocks() as interface:
66
+ gr.Markdown(
67
+ """
68
+ # MusicGen
69
+
70
+ This is the demo for [MusicGen](https://github.com/facebookresearch/audiocraft), a simple and controllable model for music generation
71
+ presented at: ["Simple and Controllable Music Generation"](https://huggingface.co/papers/2306.05284).
72
+ <br/>
73
+ <a href="https://huggingface.co/spaces/musicgen/MusicGen?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank">
74
+ <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
75
+ for longer sequences, more control and no queue.</p>
76
+ """
77
+ )
78
+ with gr.Row():
79
+ with gr.Column():
80
+ with gr.Row():
81
+ text = gr.Text(label="Input Text", interactive=True)
82
+ melody = gr.Audio(source="upload", type="numpy", label="Melody Condition (optional)", interactive=True)
83
+ with gr.Row():
84
+ submit = gr.Button("Submit")
85
+ with gr.Row():
86
+ model = gr.Radio(["melody", "medium", "small", "large"], label="Model", value="melody", interactive=True)
87
+ with gr.Row():
88
+ duration = gr.Slider(minimum=1, maximum=30, value=10, label="Duration", interactive=True)
89
+ with gr.Row():
90
+ topk = gr.Number(label="Top-k", value=250, interactive=True)
91
+ topp = gr.Number(label="Top-p", value=0, interactive=True)
92
+ temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
93
+ cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
94
+ with gr.Column():
95
+ output = gr.Video(label="Generated Music")
96
+ submit.click(predict, inputs=[model, text, melody, duration, topk, topp, temperature, cfg_coef], outputs=[output])
97
+ gr.Examples(
98
+ fn=predict,
99
+ examples=[
100
+ [
101
+ "An 80s driving pop song with heavy drums and synth pads in the background",
102
+ "./assets/bach.mp3",
103
+ "melody"
104
+ ],
105
+ [
106
+ "A cheerful country song with acoustic guitars",
107
+ "./assets/bolero_ravel.mp3",
108
+ "melody"
109
+ ],
110
+ [
111
+ "90s rock song with electric guitar and heavy drums",
112
+ None,
113
+ "medium"
114
+ ],
115
+ [
116
+ "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
117
+ "./assets/bach.mp3",
118
+ "melody"
119
+ ],
120
+ [
121
+ "lofi slow bpm electro chill with organic samples",
122
+ None,
123
+ "medium",
124
+ ],
125
+ ],
126
+ inputs=[text, melody, model],
127
+ outputs=[output]
128
+ )
129
+ gr.Markdown(
130
+ """
131
+ ### More details
132
+
133
+ The model will generate a short music extract based on the description you provided.
134
+ You can generate up to 30 seconds of audio.
135
+
136
+ We present 4 model variations:
137
+ 1. Melody -- a music generation model capable of generating music condition on text and melody inputs. **Note**, you can also use text only.
138
+ 2. Small -- a 300M transformer decoder conditioned on text only.
139
+ 3. Medium -- a 1.5B transformer decoder conditioned on text only.
140
+ 4. Large -- a 3.3B transformer decoder conditioned on text only (might OOM for the longest sequences.)
141
+
142
+ When using `melody`, ou can optionaly provide a reference audio from
143
+ which a broad melody will be extracted. The model will then try to follow both the description and melody provided.
144
+
145
+ You can also use your own GPU or a Google Colab by following the instructions on our repo.
146
+ See [github.com/facebookresearch/audiocraft](https://github.com/facebookresearch/audiocraft)
147
+ for more details.
148
+ """
149
+ )
150
 
151
+ # Show the interface
152
+ launch_kwargs = {}
153
+ username = kwargs.get('username')
154
+ password = kwargs.get('password')
155
+ server_port = kwargs.get('server_port', 0)
156
+ inbrowser = kwargs.get('inbrowser', False)
157
+ share = kwargs.get('share', False)
158
+ server_name = kwargs.get('listen')
159
+
160
+ launch_kwargs['server_name'] = server_name
161
+
162
+ if username and password:
163
+ launch_kwargs['auth'] = (username, password)
164
+ if server_port > 0:
165
+ launch_kwargs['server_port'] = server_port
166
+ if inbrowser:
167
+ launch_kwargs['inbrowser'] = inbrowser
168
+ if share:
169
+ launch_kwargs['share'] = share
170
+
171
+ interface.launch(**launch_kwargs)
172
+
173
+ if __name__ == "__main__":
174
+ # torch.cuda.set_per_process_memory_fraction(0.48)
175
+ parser = argparse.ArgumentParser()
176
+ parser.add_argument(
177
+ '--listen',
178
+ type=str,
179
+ default='127.0.0.1',
180
+ help='IP to listen on for connections to Gradio',
181
  )
182
+ parser.add_argument(
183
+ '--username', type=str, default='', help='Username for authentication'
184
+ )
185
+ parser.add_argument(
186
+ '--password', type=str, default='', help='Password for authentication'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  )
188
+ parser.add_argument(
189
+ '--server_port',
190
+ type=int,
191
+ default=0,
192
+ help='Port to run the server listener on',
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  )
194
+ parser.add_argument(
195
+ '--inbrowser', action='store_true', help='Open in browser'
196
+ )
197
+ parser.add_argument(
198
+ '--share', action='store_true', help='Share the gradio UI'
199
+ )
200
+
201
+ args = parser.parse_args()
202
 
203
+ ui(
204
+ username=args.username,
205
+ password=args.password,
206
+ inbrowser=args.inbrowser,
207
+ server_port=args.server_port,
208
+ share=args.share,
209
+ listen=args.listen
210
+ )