mattricesound committed on
Commit
0517e91
1 Parent(s): 7c82007

Update requirements.txt to install audiocraft from its git repo

Browse files
Files changed (2) hide show
  1. app.py +15 -12
  2. requirements.txt +1 -1
app.py CHANGED
@@ -102,7 +102,7 @@ def load_model(version='melody'):
102
  MODEL = MusicGen.get_pretrained(version, device=device)
103
 
104
 
105
- def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
106
  MODEL.set_generation_params(duration=duration, **gen_kwargs)
107
  print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
108
  be = time.time()
@@ -135,14 +135,15 @@ def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
135
  out_files = []
136
  for output in outputs:
137
  # Demucs
138
- print("Running demucs")
139
- wav = convert_audio(output, MODEL.sample_rate, demucs_model.samplerate, demucs_model.audio_channels)
140
- wav = wav.unsqueeze(0)
141
- stems = apply_model(demucs_model, wav)
142
- stems = stems[:, stem_idx] # extract stem
143
- stems = stems.sum(1) # merge extracted stems
144
- stems = convert_audio(stems, demucs_model.samplerate, MODEL.sample_rate, 1)
145
- output = stems[0]
 
146
  with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
147
  audio_write(
148
  file.name, output, MODEL.sample_rate, strategy="loudness",
@@ -168,7 +169,7 @@ def predict_batched(texts, melodies):
168
  return [res]
169
 
170
 
171
- def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()):
172
  global INTERRUPTING
173
  INTERRUPTING = False
174
  if temperature < 0:
@@ -189,7 +190,7 @@ def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coe
189
 
190
  outs = _do_predictions(
191
  [text], [melody], duration, progress=True,
192
- top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
193
 
194
 
195
 
@@ -228,10 +229,12 @@ def ui_full(launch_kwargs):
228
  topp = gr.Number(label="Top-p", value=0, interactive=True)
229
  temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
230
  cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
 
 
231
  with gr.Column():
232
  output = gr.Video(label="Generated Music")
233
  submit.click(predict_full,
234
- inputs=[model, text, melody, duration, topk, topp, temperature, cfg_coef],
235
  outputs=[output])
236
  radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
237
  gr.Markdown(
 
102
  MODEL = MusicGen.get_pretrained(version, device=device)
103
 
104
 
105
+ def _do_predictions(texts, melodies, duration, progress=False, drums=True, **gen_kwargs):
106
  MODEL.set_generation_params(duration=duration, **gen_kwargs)
107
  print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
108
  be = time.time()
 
135
  out_files = []
136
  for output in outputs:
137
  # Demucs
138
+ if not drums:
139
+ print("Running demucs")
140
+ wav = convert_audio(output, MODEL.sample_rate, demucs_model.samplerate, demucs_model.audio_channels)
141
+ wav = wav.unsqueeze(0)
142
+ stems = apply_model(demucs_model, wav)
143
+ stems = stems[:, stem_idx] # extract stem
144
+ stems = stems.sum(1) # merge extracted stems
145
+ stems = convert_audio(stems, demucs_model.samplerate, MODEL.sample_rate, 1)
146
+ output = stems[0]
147
  with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
148
  audio_write(
149
  file.name, output, MODEL.sample_rate, strategy="loudness",
 
169
  return [res]
170
 
171
 
172
+ def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coef, drums=drums, progress=gr.Progress()):
173
  global INTERRUPTING
174
  INTERRUPTING = False
175
  if temperature < 0:
 
190
 
191
  outs = _do_predictions(
192
  [text], [melody], duration, progress=True,
193
+ top_k=topk, top_p=topp, temperature=temperature, drums=drums, cfg_coef=cfg_coef)
194
 
195
 
196
 
 
229
  topp = gr.Number(label="Top-p", value=0, interactive=True)
230
  temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
231
  cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
232
+ with gr.Row():
233
+ drums = gr.Checkbox(label="Drums", value=True, interactive=True)
234
  with gr.Column():
235
  output = gr.Video(label="Generated Music")
236
  submit.click(predict_full,
237
+ inputs=[model, text, melody, duration, topk, topp, temperature, cfg_coef, drums],
238
  outputs=[output])
239
  radio.change(toggle_audio_src, radio, [melody], queue=False, show_progress=False)
240
  gr.Markdown(
requirements.txt CHANGED
@@ -12,7 +12,7 @@ arrow==1.2.3
12
  asttokens==2.2.1
13
  async-timeout==4.0.2
14
  attrs==23.1.0
15
- audiocraft==0.0.1
16
  audioread==3.0.0
17
  av==10.0.0
18
  backcall==0.2.0
 
12
  asttokens==2.2.1
13
  async-timeout==4.0.2
14
  attrs==23.1.0
15
+ audiocraft @ git+https://github.com/facebookresearch/audiocraft@0.0.2a2
16
  audioread==3.0.0
17
  av==10.0.0
18
  backcall==0.2.0