flosstradamus committed on
Commit
f8cd83e
·
verified ·
1 Parent(s): 772add9

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -11
app.py CHANGED
@@ -15,9 +15,6 @@ from utils import load_t5, load_clap
15
  from train import RF
16
  from constants import build_model
17
 
18
- # Disable flash attention if not available
19
- torch.backends.cuda.enable_flash_sdp(False)
20
-
21
  # Global variables to store loaded models and resources
22
  global_model = None
23
  global_t5 = None
@@ -31,8 +28,39 @@ MODELS_DIR = "/content/models"
31
  GENERATIONS_DIR = "/content/generations"
32
 
33
  def prepare(t5, clip, img, prompt):
34
- # ... [The prepare function remains unchanged]
35
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  def unload_current_model():
38
  global global_model
@@ -87,12 +115,92 @@ def load_resources():
87
  print("Base resources loaded successfully!")
88
 
89
  def generate_music(prompt, seed, cfg_scale, steps, duration, progress=gr.Progress()):
90
- # ... [The generate_music function remains largely unchanged]
91
- # Update the output directory
92
- output_dir = GENERATIONS_DIR
93
- os.makedirs(output_dir, exist_ok=True)
94
- # ... [Rest of the function remains the same]
95
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  # Load base resources at startup
98
  load_resources()
 
15
  from train import RF
16
  from constants import build_model
17
 
 
 
 
18
  # Global variables to store loaded models and resources
19
  global_model = None
20
  global_t5 = None
 
28
  GENERATIONS_DIR = "/content/generations"
29
 
30
  def prepare(t5, clip, img, prompt):
31
+ bs, c, h, w = img.shape
32
+ if bs == 1 and not isinstance(prompt, str):
33
+ bs = len(prompt)
34
+
35
+ img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
36
+ if img.shape[0] == 1 and bs > 1:
37
+ img = repeat(img, "1 ... -> bs ...", bs=bs)
38
+
39
+ img_ids = torch.zeros(h // 2, w // 2, 3)
40
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
41
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
42
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
43
+
44
+ if isinstance(prompt, str):
45
+ prompt = [prompt]
46
+
47
+ # Generate text embeddings
48
+ txt = t5(prompt)
49
+
50
+ if txt.shape[0] == 1 and bs > 1:
51
+ txt = repeat(txt, "1 ... -> bs ...", bs=bs)
52
+ txt_ids = torch.zeros(bs, txt.shape[1], 3)
53
+
54
+ vec = clip(prompt)
55
+ if vec.shape[0] == 1 and bs > 1:
56
+ vec = repeat(vec, "1 ... -> bs ...", bs=bs)
57
+
58
+ return img, {
59
+ "img_ids": img_ids.to(img.device),
60
+ "txt": txt.to(img.device),
61
+ "txt_ids": txt_ids.to(img.device),
62
+ "y": vec.to(img.device),
63
+ }
64
 
65
  def unload_current_model():
66
  global global_model
 
115
  print("Base resources loaded successfully!")
116
 
117
  def generate_music(prompt, seed, cfg_scale, steps, duration, progress=gr.Progress()):
118
+ global global_model, global_t5, global_clap, global_vae, global_vocoder, global_diffusion
119
+
120
+ if global_model is None:
121
+ return "Please select a model first.", None
122
+
123
+ if seed == 0:
124
+ seed = random.randint(1, 1000000)
125
+ print(f"Using seed: {seed}")
126
+
127
+ device = "cuda" if torch.cuda.is_available() else "cpu"
128
+ torch.manual_seed(seed)
129
+ torch.set_grad_enabled(False)
130
+
131
+ # Calculate the number of segments needed for the desired duration
132
+ segment_duration = 10 # Each segment is 10 seconds
133
+ num_segments = int(np.ceil(duration / segment_duration))
134
+
135
+ all_waveforms = []
136
+
137
+ for i in range(num_segments):
138
+ progress(i / num_segments, desc=f"Generating segment {i+1}/{num_segments}")
139
+
140
+ # Use the same seed for all segments
141
+ torch.manual_seed(seed + i) # Add i to slightly vary each segment while maintaining consistency
142
+
143
+ latent_size = (256, 16)
144
+ conds_txt = [prompt]
145
+ unconds_txt = ["low quality, gentle"]
146
+ L = len(conds_txt)
147
+
148
+ init_noise = torch.randn(L, 8, latent_size[0], latent_size[1]).to(device)
149
+
150
+ img, conds = prepare(global_t5, global_clap, init_noise, conds_txt)
151
+ _, unconds = prepare(global_t5, global_clap, init_noise, unconds_txt)
152
+
153
+ with torch.autocast(device_type='cuda'):
154
+ images = global_diffusion.sample_with_xps(global_model, img, conds=conds, null_cond=unconds, sample_steps=steps, cfg=cfg_scale)
155
+
156
+ images = rearrange(
157
+ images[-1],
158
+ "b (h w) (c ph pw) -> b c (h ph) (w pw)",
159
+ h=128,
160
+ w=8,
161
+ ph=2,
162
+ pw=2,)
163
+
164
+ latents = 1 / global_vae.config.scaling_factor * images
165
+ mel_spectrogram = global_vae.decode(latents).sample
166
+
167
+ x_i = mel_spectrogram[0]
168
+ if x_i.dim() == 4:
169
+ x_i = x_i.squeeze(1)
170
+ waveform = global_vocoder(x_i)
171
+ waveform = waveform[0].cpu().float().detach().numpy()
172
+
173
+ all_waveforms.append(waveform)
174
+
175
+ # Concatenate all waveforms
176
+ final_waveform = np.concatenate(all_waveforms)
177
+
178
+ # Trim to exact duration
179
+ sample_rate = 16000
180
+ final_waveform = final_waveform[:int(duration * sample_rate)]
181
+
182
+ progress(0.9, desc="Saving audio file")
183
+
184
+ # Create 'generations' folder
185
+ os.makedirs(GENERATIONS_DIR, exist_ok=True)
186
+
187
+ # Generate filename
188
+ prompt_part = re.sub(r'[^\w\s-]', '', prompt)[:10].strip().replace(' ', '_')
189
+ model_name = os.path.splitext(os.path.basename(global_model.model_path))[0]
190
+ model_suffix = '_mf_b' if model_name == 'musicflow_b' else f'_{model_name}'
191
+ base_filename = f"{prompt_part}_{seed}{model_suffix}"
192
+ output_path = os.path.join(GENERATIONS_DIR, f"{base_filename}.wav")
193
+
194
+ # Check if file exists and add numerical suffix if needed
195
+ counter = 1
196
+ while os.path.exists(output_path):
197
+ output_path = os.path.join(GENERATIONS_DIR, f"{base_filename}_{counter}.wav")
198
+ counter += 1
199
+
200
+ wavfile.write(output_path, sample_rate, final_waveform)
201
+
202
+ progress(1.0, desc="Audio generation complete")
203
+ return f"Generated with seed: {seed}", output_path
204
 
205
  # Load base resources at startup
206
  load_resources()