flosstradamus commited on
Commit
b300542
·
verified ·
1 Parent(s): 771145b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -142
app.py CHANGED
@@ -5,15 +5,10 @@ from einops import rearrange, repeat
5
  from diffusers import AutoencoderKL
6
  from transformers import SpeechT5HifiGan
7
  from scipy.io import wavfile
 
8
  import random
9
  import numpy as np
10
  import re
11
- import requests
12
- from urllib.parse import urlparse
13
- import logging
14
-
15
- # Set up logging
16
- logging.basicConfig(level=logging.INFO)
17
 
18
  # Import necessary functions and classes
19
  from utils import load_t5, load_clap
@@ -28,12 +23,44 @@ global_vae = None
28
  global_vocoder = None
29
  global_diffusion = None
30
 
31
- # Set the generations directory
 
32
  GENERATIONS_DIR = "/content/generations"
33
 
34
  def prepare(t5, clip, img, prompt):
35
- # ... [The prepare function remains unchanged]
36
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  def unload_current_model():
39
  global global_model
@@ -42,171 +69,152 @@ def unload_current_model():
42
  torch.cuda.empty_cache()
43
  global_model = None
44
 
45
- def download_model(url):
46
- try:
47
- response = requests.get(url, stream=True)
48
- if response.status_code == 200:
49
- filename = os.path.basename(urlparse(url).path)
50
- model_path = os.path.join("/tmp", filename)
51
- with open(model_path, "wb") as f:
52
- for chunk in response.iter_content(chunk_size=8192):
53
- f.write(chunk)
54
- return model_path
55
- else:
56
- raise Exception(f"Failed to download model from {url}")
57
- except Exception as e:
58
- logging.error(f"Error downloading model: {str(e)}")
59
- raise
60
-
61
- def load_model(url):
62
  global global_model
63
- try:
64
- device = "cuda" if torch.cuda.is_available() else "cpu"
65
-
66
- unload_current_model()
67
-
68
- logging.info(f"Downloading model from {url}")
69
- model_path = download_model(url)
70
-
71
- # Determine model size from filename
72
- filename = os.path.basename(model_path)
73
- if 'musicflow_b' in filename:
74
- model_size = "base"
75
- elif 'musicflow_g' in filename:
76
- model_size = "giant"
77
- elif 'musicflow_l' in filename:
78
- model_size = "large"
79
- elif 'musicflow_s' in filename:
80
- model_size = "small"
81
- else:
82
- model_size = "base" # Default to base if unrecognized
83
-
84
- logging.info(f"Loading {model_size} model: {filename}")
85
-
86
- global_model = build_model(model_size).to(device)
87
- state_dict = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=True)
88
- global_model.load_state_dict(state_dict['ema'])
89
- global_model.eval()
90
- global_model.model_path = model_path
91
- logging.info("Model loaded successfully")
92
- return "Model loaded successfully"
93
- except Exception as e:
94
- logging.error(f"Error loading model: {str(e)}")
95
- return f"Error loading model: {str(e)}"
96
 
97
  def load_resources():
98
  global global_t5, global_clap, global_vae, global_vocoder, global_diffusion
99
 
100
  device = "cuda" if torch.cuda.is_available() else "cpu"
101
 
102
- logging.info("Loading T5 and CLAP models...")
103
  global_t5 = load_t5(device, max_length=256)
104
  global_clap = load_clap(device, max_length=256)
105
 
106
- logging.info("Loading VAE and vocoder...")
107
  global_vae = AutoencoderKL.from_pretrained('cvssp/audioldm2', subfolder="vae").to(device)
108
  global_vocoder = SpeechT5HifiGan.from_pretrained('cvssp/audioldm2', subfolder="vocoder").to(device)
109
 
110
- logging.info("Initializing diffusion...")
111
  global_diffusion = RF()
112
 
113
- logging.info("Base resources loaded successfully!")
114
 
115
- def generate_music(prompt, seed, cfg_scale, steps, duration):
116
  global global_model, global_t5, global_clap, global_vae, global_vocoder, global_diffusion
117
 
118
  if global_model is None:
119
- return "Please load a model first.", None
120
 
121
  if seed == 0:
122
  seed = random.randint(1, 1000000)
123
- logging.info(f"Using seed: {seed}")
124
 
125
  device = "cuda" if torch.cuda.is_available() else "cpu"
126
  torch.manual_seed(seed)
127
  torch.set_grad_enabled(False)
128
 
129
- try:
130
- # Calculate the number of segments needed for the desired duration
131
- segment_duration = 10 # Each segment is 10 seconds
132
- num_segments = int(np.ceil(duration / segment_duration))
133
 
134
- all_waveforms = []
135
 
136
- for i in range(num_segments):
137
- logging.info(f"Generating segment {i+1}/{num_segments}")
138
 
139
- # Use the same seed for all segments
140
- torch.manual_seed(seed + i) # Add i to slightly vary each segment while maintaining consistency
141
 
142
- latent_size = (256, 16)
143
- conds_txt = [prompt]
144
- unconds_txt = ["low quality, gentle"]
145
- L = len(conds_txt)
146
 
147
- init_noise = torch.randn(L, 8, latent_size[0], latent_size[1]).to(device)
148
 
149
- img, conds = prepare(global_t5, global_clap, init_noise, conds_txt)
150
- _, unconds = prepare(global_t5, global_clap, init_noise, unconds_txt)
151
 
152
- with torch.autocast(device_type='cuda'):
153
- images = global_diffusion.sample_with_xps(global_model, img, conds=conds, null_cond=unconds, sample_steps=steps, cfg=cfg_scale)
154
 
155
- images = rearrange(
156
- images[-1],
157
- "b (h w) (c ph pw) -> b c (h ph) (w pw)",
158
- h=128,
159
- w=8,
160
- ph=2,
161
- pw=2,)
162
 
163
- latents = 1 / global_vae.config.scaling_factor * images
164
- mel_spectrogram = global_vae.decode(latents).sample
165
 
166
- x_i = mel_spectrogram[0]
167
- if x_i.dim() == 4:
168
- x_i = x_i.squeeze(1)
169
- waveform = global_vocoder(x_i)
170
- waveform = waveform[0].cpu().float().detach().numpy()
171
 
172
- all_waveforms.append(waveform)
173
 
174
- # Concatenate all waveforms
175
- final_waveform = np.concatenate(all_waveforms)
176
 
177
- # Trim to exact duration
178
- sample_rate = 16000
179
- final_waveform = final_waveform[:int(duration * sample_rate)]
180
 
181
- logging.info("Saving audio file")
182
-
183
- # Create 'generations' folder
184
- os.makedirs(GENERATIONS_DIR, exist_ok=True)
185
-
186
- # Generate filename
187
- prompt_part = re.sub(r'[^\w\s-]', '', prompt)[:10].strip().replace(' ', '_')
188
- model_name = os.path.splitext(os.path.basename(global_model.model_path))[0]
189
- model_suffix = '_mf_b' if model_name == 'musicflow_b' else f'_{model_name}'
190
- base_filename = f"{prompt_part}_{seed}{model_suffix}"
191
- output_path = os.path.join(GENERATIONS_DIR, f"{base_filename}.wav")
192
-
193
- # Check if file exists and add numerical suffix if needed
194
- counter = 1
195
- while os.path.exists(output_path):
196
- output_path = os.path.join(GENERATIONS_DIR, f"{base_filename}_{counter}.wav")
197
- counter += 1
198
 
199
- wavfile.write(output_path, sample_rate, final_waveform)
200
 
201
- logging.info("Audio generation complete")
202
- return f"Generated with seed: {seed}", output_path
203
- except Exception as e:
204
- logging.error(f"Error generating music: {str(e)}")
205
- return f"Error generating music: {str(e)}", None
206
 
207
  # Load base resources at startup
208
  load_resources()
209
 
 
 
 
 
 
 
 
 
 
 
210
  # Set up dark grey theme
211
  theme = gr.themes.Monochrome(
212
  primary_hue="gray",
@@ -225,29 +233,32 @@ with gr.Blocks(theme=theme) as iface:
225
  </div>
226
  """)
227
 
228
- model_url = gr.Textbox(label="Model URL", placeholder="Enter the URL of the model file (.pt)")
229
- load_model_button = gr.Button("Load Model")
230
- model_status = gr.Textbox(label="Model Status")
 
 
 
231
 
232
- prompt = gr.Textbox(label="Prompt")
233
- seed = gr.Number(label="Seed", value=0)
234
- cfg_scale = gr.Slider(minimum=1, maximum=40, step=0.1, label="CFG Scale", value=20)
235
- steps = gr.Slider(minimum=10, maximum=200, step=1, label="Steps", value=100)
236
- duration = gr.Number(label="Duration (seconds)", value=10, minimum=10, maximum=300, step=1)
237
 
238
  generate_button = gr.Button("Generate Music")
239
  output_status = gr.Textbox(label="Generation Status")
240
  output_audio = gr.Audio(type="filepath")
241
 
242
- def load_model_wrapper(url):
243
- return load_model(url)
244
 
245
- def generate_music_wrapper(prompt, seed, cfg_scale, steps, duration):
246
- status, audio_path = generate_music(prompt, seed, cfg_scale, steps, duration)
247
- return status, audio_path if audio_path else None
248
 
249
- load_model_button.click(load_model_wrapper, inputs=[model_url], outputs=[model_status])
250
- generate_button.click(generate_music_wrapper, inputs=[prompt, seed, cfg_scale, steps, duration], outputs=[output_status, output_audio])
 
 
251
 
252
  # Launch the interface
253
  iface.launch()
 
5
  from diffusers import AutoencoderKL
6
  from transformers import SpeechT5HifiGan
7
  from scipy.io import wavfile
8
+ import glob
9
  import random
10
  import numpy as np
11
  import re
 
 
 
 
 
 
12
 
13
  # Import necessary functions and classes
14
  from utils import load_t5, load_clap
 
23
  global_vocoder = None
24
  global_diffusion = None
25
 
26
+ # Set the models directory
27
+ MODELS_DIR = "/content/models"
28
  GENERATIONS_DIR = "/content/generations"
29
 
30
  def prepare(t5, clip, img, prompt):
31
+ bs, c, h, w = img.shape
32
+ if bs == 1 and not isinstance(prompt, str):
33
+ bs = len(prompt)
34
+
35
+ img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
36
+ if img.shape[0] == 1 and bs > 1:
37
+ img = repeat(img, "1 ... -> bs ...", bs=bs)
38
+
39
+ img_ids = torch.zeros(h // 2, w // 2, 3)
40
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
41
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
42
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
43
+
44
+ if isinstance(prompt, str):
45
+ prompt = [prompt]
46
+
47
+ # Generate text embeddings
48
+ txt = t5(prompt)
49
+
50
+ if txt.shape[0] == 1 and bs > 1:
51
+ txt = repeat(txt, "1 ... -> bs ...", bs=bs)
52
+ txt_ids = torch.zeros(bs, txt.shape[1], 3)
53
+
54
+ vec = clip(prompt)
55
+ if vec.shape[0] == 1 and bs > 1:
56
+ vec = repeat(vec, "1 ... -> bs ...", bs=bs)
57
+
58
+ return img, {
59
+ "img_ids": img_ids.to(img.device),
60
+ "txt": txt.to(img.device),
61
+ "txt_ids": txt_ids.to(img.device),
62
+ "y": vec.to(img.device),
63
+ }
64
 
65
  def unload_current_model():
66
  global global_model
 
69
  torch.cuda.empty_cache()
70
  global_model = None
71
 
72
+ def load_model(model_name):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  global global_model
74
+ device = "cuda" if torch.cuda.is_available() else "cpu"
75
+
76
+ unload_current_model()
77
+
78
+ # Determine model size from filename
79
+ if 'musicflow_b' in model_name:
80
+ model_size = "base"
81
+ elif 'musicflow_g' in model_name:
82
+ model_size = "giant"
83
+ elif 'musicflow_l' in model_name:
84
+ model_size = "large"
85
+ elif 'musicflow_s' in model_name:
86
+ model_size = "small"
87
+ else:
88
+ model_size = "base" # Default to base if unrecognized
89
+
90
+ print(f"Loading {model_size} model: {model_name}")
91
+
92
+ model_path = os.path.join(MODELS_DIR, model_name)
93
+ global_model = build_model(model_size).to(device)
94
+ state_dict = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=True)
95
+ global_model.load_state_dict(state_dict['ema'])
96
+ global_model.eval()
97
+ global_model.model_path = model_path
 
 
 
 
 
 
 
 
 
98
 
99
  def load_resources():
100
  global global_t5, global_clap, global_vae, global_vocoder, global_diffusion
101
 
102
  device = "cuda" if torch.cuda.is_available() else "cpu"
103
 
104
+ print("Loading T5 and CLAP models...")
105
  global_t5 = load_t5(device, max_length=256)
106
  global_clap = load_clap(device, max_length=256)
107
 
108
+ print("Loading VAE and vocoder...")
109
  global_vae = AutoencoderKL.from_pretrained('cvssp/audioldm2', subfolder="vae").to(device)
110
  global_vocoder = SpeechT5HifiGan.from_pretrained('cvssp/audioldm2', subfolder="vocoder").to(device)
111
 
112
+ print("Initializing diffusion...")
113
  global_diffusion = RF()
114
 
115
+ print("Base resources loaded successfully!")
116
 
117
+ def generate_music(prompt, seed, cfg_scale, steps, duration, progress=gr.Progress()):
118
  global global_model, global_t5, global_clap, global_vae, global_vocoder, global_diffusion
119
 
120
  if global_model is None:
121
+ return "Please select a model first.", None
122
 
123
  if seed == 0:
124
  seed = random.randint(1, 1000000)
125
+ print(f"Using seed: {seed}")
126
 
127
  device = "cuda" if torch.cuda.is_available() else "cpu"
128
  torch.manual_seed(seed)
129
  torch.set_grad_enabled(False)
130
 
131
+ # Calculate the number of segments needed for the desired duration
132
+ segment_duration = 10 # Each segment is 10 seconds
133
+ num_segments = int(np.ceil(duration / segment_duration))
 
134
 
135
+ all_waveforms = []
136
 
137
+ for i in range(num_segments):
138
+ progress(i / num_segments, desc=f"Generating segment {i+1}/{num_segments}")
139
 
140
+ # Use the same seed for all segments
141
+ torch.manual_seed(seed + i) # Add i to slightly vary each segment while maintaining consistency
142
 
143
+ latent_size = (256, 16)
144
+ conds_txt = [prompt]
145
+ unconds_txt = ["low quality, gentle"]
146
+ L = len(conds_txt)
147
 
148
+ init_noise = torch.randn(L, 8, latent_size[0], latent_size[1]).to(device)
149
 
150
+ img, conds = prepare(global_t5, global_clap, init_noise, conds_txt)
151
+ _, unconds = prepare(global_t5, global_clap, init_noise, unconds_txt)
152
 
153
+ with torch.autocast(device_type='cuda'):
154
+ images = global_diffusion.sample_with_xps(global_model, img, conds=conds, null_cond=unconds, sample_steps=steps, cfg=cfg_scale)
155
 
156
+ images = rearrange(
157
+ images[-1],
158
+ "b (h w) (c ph pw) -> b c (h ph) (w pw)",
159
+ h=128,
160
+ w=8,
161
+ ph=2,
162
+ pw=2,)
163
 
164
+ latents = 1 / global_vae.config.scaling_factor * images
165
+ mel_spectrogram = global_vae.decode(latents).sample
166
 
167
+ x_i = mel_spectrogram[0]
168
+ if x_i.dim() == 4:
169
+ x_i = x_i.squeeze(1)
170
+ waveform = global_vocoder(x_i)
171
+ waveform = waveform[0].cpu().float().detach().numpy()
172
 
173
+ all_waveforms.append(waveform)
174
 
175
+ # Concatenate all waveforms
176
+ final_waveform = np.concatenate(all_waveforms)
177
 
178
+ # Trim to exact duration
179
+ sample_rate = 16000
180
+ final_waveform = final_waveform[:int(duration * sample_rate)]
181
 
182
+ progress(0.9, desc="Saving audio file")
183
+
184
+ # Create 'generations' folder
185
+ os.makedirs(GENERATIONS_DIR, exist_ok=True)
186
+
187
+ # Generate filename
188
+ prompt_part = re.sub(r'[^\w\s-]', '', prompt)[:10].strip().replace(' ', '_')
189
+ model_name = os.path.splitext(os.path.basename(global_model.model_path))[0]
190
+ model_suffix = '_mf_b' if model_name == 'musicflow_b' else f'_{model_name}'
191
+ base_filename = f"{prompt_part}_{seed}{model_suffix}"
192
+ output_path = os.path.join(GENERATIONS_DIR, f"{base_filename}.wav")
193
+
194
+ # Check if file exists and add numerical suffix if needed
195
+ counter = 1
196
+ while os.path.exists(output_path):
197
+ output_path = os.path.join(GENERATIONS_DIR, f"{base_filename}_{counter}.wav")
198
+ counter += 1
199
 
200
+ wavfile.write(output_path, sample_rate, final_waveform)
201
 
202
+ progress(1.0, desc="Audio generation complete")
203
+ return f"Generated with seed: {seed}", output_path
 
 
 
204
 
205
  # Load base resources at startup
206
  load_resources()
207
 
208
+ # Get list of .pt files in the models directory
209
+ model_files = glob.glob(os.path.join(MODELS_DIR, "*.pt"))
210
+ model_choices = [os.path.basename(f) for f in model_files]
211
+
212
+ # Ensure 'musicflow_b.pt' is the default choice if it exists
213
+ default_model = 'musicflow_b.pt'
214
+ if default_model in model_choices:
215
+ model_choices.remove(default_model)
216
+ model_choices.insert(0, default_model)
217
+
218
  # Set up dark grey theme
219
  theme = gr.themes.Monochrome(
220
  primary_hue="gray",
 
233
  </div>
234
  """)
235
 
236
+ with gr.Row():
237
+ model_dropdown = gr.Dropdown(choices=model_choices, label="Select Model", value=default_model if default_model in model_choices else model_choices[0])
238
+
239
+ with gr.Row():
240
+ prompt = gr.Textbox(label="Prompt")
241
+ seed = gr.Number(label="Seed", value=0)
242
 
243
+ with gr.Row():
244
+ cfg_scale = gr.Slider(minimum=1, maximum=40, step=0.1, label="CFG Scale", value=20)
245
+ steps = gr.Slider(minimum=10, maximum=200, step=1, label="Steps", value=100)
246
+ duration = gr.Number(label="Duration (seconds)", value=10, minimum=10, maximum=300, step=1)
 
247
 
248
  generate_button = gr.Button("Generate Music")
249
  output_status = gr.Textbox(label="Generation Status")
250
  output_audio = gr.Audio(type="filepath")
251
 
252
+ def on_model_change(model_name):
253
+ load_model(model_name)
254
 
255
+ model_dropdown.change(on_model_change, inputs=[model_dropdown])
256
+ generate_button.click(generate_music, inputs=[prompt, seed, cfg_scale, steps, duration], outputs=[output_status, output_audio])
 
257
 
258
+ # Load default model on startup
259
+ default_model_path = os.path.join(MODELS_DIR, default_model)
260
+ if os.path.exists(default_model_path):
261
+ iface.load(lambda: load_model(default_model), inputs=None, outputs=None)
262
 
263
  # Launch the interface
264
  iface.launch()