Khushi Dahiya committed
Commit 330d8e8 · 1 parent: 6bc3ca2

reverting batch processing
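For context, the pattern being removed is a background thread that drains a queue of requests, groups them into batches, and resolves each caller's Future; after this revert, predict() simply calls load_model() and _do_predictions() directly. Below is a minimal, self-contained sketch of that queue-and-Future pattern next to the direct call. MiniBatcher and process_batch are illustrative stand-ins, not code from this app; only BATCH_SIZE and BATCH_TIMEOUT echo the constants deleted in the diff.

# Illustrative stand-in only: mirrors the shape of the removed BatchProcessor /
# _do_predictions_batch code, not the MelodyFlow app itself.
import threading
import time
from concurrent.futures import Future
from queue import Queue, Empty

BATCH_SIZE = 4       # same role as the removed BATCH_SIZE constant
BATCH_TIMEOUT = 2.0  # same role as the removed BATCH_TIMEOUT constant


def process_batch(texts):
    # Dummy worker standing in for the removed _do_predictions_batch().
    return [f"audio-for:{t}" for t in texts]


class MiniBatcher:
    # Toy BatchProcessor: a daemon thread drains a queue, groups pending
    # requests into a batch, and resolves each caller's Future.
    def __init__(self):
        self.queue = Queue()
        threading.Thread(target=self._loop, daemon=True).start()

    def submit(self, text):
        future = Future()
        self.queue.put((text, future))
        return future

    def _loop(self):
        while True:
            batch, deadline = [], time.time() + BATCH_TIMEOUT
            while len(batch) < BATCH_SIZE and time.time() < deadline:
                try:
                    batch.append(self.queue.get(timeout=0.1))
                except Empty:
                    continue
            if not batch:
                continue
            results = process_batch([text for text, _ in batch])
            for (_, future), result in zip(batch, results):
                future.set_result(result)


if __name__ == "__main__":
    batcher = MiniBatcher()
    futures = [batcher.submit(f"prompt {i}") for i in range(3)]
    print([f.result(timeout=10) for f in futures])  # queued path (removed by this commit)
    print(process_batch(["prompt"])[0])             # direct call, as predict() now does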

Files changed (1): demos/melodyflow_app.py (+33 -282)

demos/melodyflow_app.py CHANGED
@@ -20,11 +20,6 @@ from tempfile import NamedTemporaryFile
 import time
 import typing as tp
 import warnings
-import asyncio
-import threading
-from concurrent.futures import ThreadPoolExecutor
-from queue import Queue, Empty
-import uuid

 import torch
 import gradio as gr
@@ -35,15 +30,9 @@ from audiocraft.models import MelodyFlow


 MODEL = None  # Last used model
-MODEL_LOCK = threading.Lock()  # Thread lock for model access
-REQUEST_QUEUE = Queue()  # Queue for batch processing
-BATCH_PROCESSOR = None  # Background batch processor
-BATCH_SIZE = 4  # Maximum batch size for concurrent processing
-BATCH_TIMEOUT = 2.0  # Maximum wait time to form a batch (seconds)
 SPACE_ID = os.environ.get('SPACE_ID', '')
 MODEL_PREFIX = os.environ.get('MODEL_PREFIX', 'facebook/')
 IS_HF_SPACE = (MODEL_PREFIX + "MelodyFlow") in SPACE_ID
-MAX_BATCH_SIZE = 12
 N_REPEATS = 1
 INTERRUPTING = False
 MBD = None
@@ -82,213 +71,6 @@ class FileCleaner:
 file_cleaner = FileCleaner()


-class RequestBatch:
-    """Represents a batch of requests to process together"""
-    def __init__(self):
-        self.requests = []
-        self.futures = []
-        self.created_at = time.time()
-
-    def add_request(self, request_data, future):
-        self.requests.append(request_data)
-        self.futures.append(future)
-
-    def is_full(self):
-        return len(self.requests) >= BATCH_SIZE
-
-    def is_expired(self):
-        return time.time() - self.created_at > BATCH_TIMEOUT
-
-    def should_process(self):
-        return self.is_full() or self.is_expired() or len(self.requests) > 0
-
-
-class BatchProcessor:
-    """Handles batched processing of requests"""
-    def __init__(self):
-        self.current_batch = RequestBatch()
-        self.processing = False
-        self.stop_event = threading.Event()
-
-    def start(self):
-        """Start the background batch processing thread"""
-        self.thread = threading.Thread(target=self._process_loop, daemon=True)
-        self.thread.start()
-
-    def stop(self):
-        """Stop the background batch processing"""
-        self.stop_event.set()
-
-    def add_request(self, request_data):
-        """Add a request to the batch and return a future for the result"""
-        from concurrent.futures import Future
-        future = Future()
-
-        # Add to current batch
-        self.current_batch.add_request(request_data, future)
-
-        # Signal that we have a new request
-        REQUEST_QUEUE.put("new_request")
-
-        return future
-
-    def _process_loop(self):
-        """Main processing loop that runs in background thread"""
-        while not self.stop_event.is_set():
-            try:
-                # Wait for a signal or timeout
-                REQUEST_QUEUE.get(timeout=0.5)
-
-                # Check if we should process current batch
-                if self.current_batch.should_process() and not self.processing:
-                    self._process_current_batch()
-
-            except Empty:
-                # Timeout - check if we have an expired batch
-                if self.current_batch.should_process() and not self.processing:
-                    self._process_current_batch()
-                continue
-            except Exception as e:
-                print(f"Error in batch processing loop: {e}")
-
-    @spaces.GPU(duration=45)  # Increased duration for batch processing
-    def _process_current_batch(self):
-        """Process the current batch of requests"""
-        if len(self.current_batch.requests) == 0:
-            return
-
-        self.processing = True
-        batch = self.current_batch
-        self.current_batch = RequestBatch()  # Start new batch
-
-        try:
-            # Extract batch data
-            texts = []
-            melodies = []
-            params_list = []
-
-            print(f"🔄 BATCH PROCESSOR: Processing {len(batch.requests)} requests")
-
-            for request_data in batch.requests:
-                texts.append(request_data['text'])
-                melodies.append(request_data['melody'])
-                params_list.append({
-                    'solver': request_data['solver'],
-                    'steps': request_data['steps'],
-                    'target_flowstep': request_data['target_flowstep'],
-                    'regularize': request_data['regularize'],
-                    'regularization_strength': request_data['regularization_strength'],
-                    'duration': request_data['duration'],
-                    'model': request_data['model']
-                })
-
-            # Load model if needed (use the first request's model)
-            model_version = params_list[0]['model']
-            load_model(model_version)
-
-            # Process batch with unified parameters (use first request's params)
-            params = params_list[0]
-            results = _do_predictions_batch(
-                texts=texts,
-                melodies=melodies,
-                solver=params['solver'],
-                steps=params['steps'],
-                target_flowstep=params['target_flowstep'],
-                regularize=params['regularize'],
-                regularization_strength=params['regularization_strength'],
-                duration=params['duration'],
-                progress=False
-            )
-
-            # Set results for each future
-            for i, future in enumerate(batch.futures):
-                if i < len(results):
-                    future.set_result(results[i])
-                else:
-                    future.set_exception(Exception("Batch processing failed"))
-
-        except Exception as e:
-            # Set exception for all futures in batch
-            for future in batch.futures:
-                future.set_exception(e)
-        finally:
-            self.processing = False
-
-
-def _do_predictions_batch(texts, melodies, solver, steps, target_flowstep,
-                          regularize, regularization_strength, duration, progress=False):
-    """Process a batch of predictions efficiently"""
-    with MODEL_LOCK:
-        MODEL.set_generation_params(solver=solver, steps=steps, duration=duration)
-        MODEL.set_editing_params(
-            solver=solver,
-            steps=steps,
-            target_flowstep=target_flowstep,
-            regularize=regularize,
-            lambda_kl=regularization_strength
-        )
-
-        print(f"Processing batch: {len(texts)} requests")
-        be = time.time()
-
-        processed_melodies = []
-        target_sr = 48000
-        target_ac = 2
-
-        for melody in melodies:
-            if melody is None:
-                processed_melodies.append(None)
-            else:
-                melody, sr = audio_read(melody)
-                if melody.dim() == 2:
-                    melody = melody[None]
-                if melody.shape[-1] > int(sr * MODEL.duration):
-                    melody = melody[..., :int(sr * MODEL.duration)]
-                melody = convert_audio(melody, sr, target_sr, target_ac)
-                melody = MODEL.encode_audio(melody.to(MODEL.device))
-                processed_melodies.append(melody)
-
-        try:
-            # Process all requests in the batch together
-            if any(m is not None for m in processed_melodies):
-                # For editing mode, process each request individually due to melody constraints
-                outputs_list = []
-                for i, (text, melody) in enumerate(zip(texts, processed_melodies)):
-                    if melody is not None:
-                        output = MODEL.edit(
-                            prompt_tokens=melody.repeat(1, 1, 1),
-                            descriptions=[text],
-                            src_descriptions=[""],
-                            progress=progress,
-                            return_tokens=False,
-                        )
-                    else:
-                        output = MODEL.generate([text], progress=progress, return_tokens=False)
-                    outputs_list.append(output[0])
-                outputs = torch.stack(outputs_list)
-            else:
-                # For generation mode, we can batch all requests
-                outputs = MODEL.generate(texts, progress=progress, return_tokens=False)
-
-        except RuntimeError as e:
-            raise gr.Error("Error while generating " + e.args[0])
-
-        outputs = outputs.detach().cpu().float()
-        results = []
-
-        for output in outputs:
-            with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
-                audio_write(
-                    file.name, output, MODEL.sample_rate, strategy="loudness",
-                    loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
-
-                results.append(file.name)
-                file_cleaner.add(file.name)
-
-        print(f"Batch finished: {len(texts)} requests in {time.time() - be:.2f}s")
-        return results
-
-
 def make_waveform(*args, **kwargs):
     # Further remove some warnings.
     be = time.time()
@@ -301,18 +83,18 @@ def make_waveform(*args, **kwargs):

 def load_model(version=(MODEL_PREFIX + "melodyflow-t24-30secs")):
     global MODEL
-    with MODEL_LOCK:
-        print("Loading model", version)
-        if MODEL is None or MODEL.name != version:
-            # Clear PyTorch CUDA cache and delete model
-            del MODEL
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-            MODEL = None  # in case loading would crash
-            MODEL = MelodyFlow.get_pretrained(version)
-            print(f"Model {version} loaded successfully")
-
-
+    print("Loading model", version)
+    if MODEL is None or MODEL.name != version:
+        # Clear PyTorch CUDA cache and delete model
+        del MODEL
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        MODEL = None  # in case loading would crash
+        MODEL = MelodyFlow.get_pretrained(version)
+        print(f"Model {version} loaded successfully")
+
+
+@spaces.GPU(duration=45)
 def _do_predictions(texts,
                     melodies,
                     solver,
@@ -384,9 +166,7 @@ def predict(model, text,
             melody=None,
             model_path=None,
             progress=gr.Progress()):
-    """Non-blocking predict function that uses batch processing"""
-
-    print(f"🎵 PREDICT FUNCTION CALLED - Text: '{text[:50]}...' Model: {model}")
+    """Simple predict function without batch processing"""

    if melody is not None:
        if solver == MIDPOINT:
@@ -394,15 +174,10 @@ def predict(model, text,
         else:
             steps = steps//5

-    global INTERRUPTING, BATCH_PROCESSOR
+    global INTERRUPTING
     INTERRUPTING = False

-    # Initialize batch processor if not already running
-    if BATCH_PROCESSOR is None:
-        BATCH_PROCESSOR = BatchProcessor()
-        BATCH_PROCESSOR.start()
-
-    progress(0, desc="Queuing request...")
+    progress(0, desc="Loading model...")

     if model_path:
         model_path = model_path.strip()
@@ -413,51 +188,30 @@ def predict(model, text,
                            "state_dict.bin and compression_state_dict_.bin.")
         model = model_path

-    # Prepare request data
-    request_data = {
-        'text': text,
-        'melody': melody,
-        'solver': solver,
-        'steps': steps,
-        'target_flowstep': target_flowstep,
-        'regularize': regularize,
-        'regularization_strength': regularization_strength,
-        'duration': duration,
-        'model': model,
-        'request_id': str(uuid.uuid4())
-    }
-
-    # Add to batch processor
-    future = BATCH_PROCESSOR.add_request(request_data)
+    load_model(model)

-    progress(0.3, desc="Waiting for GPU...")
+    progress(0.1, desc="Generating music...")

-    # Wait for result with progress updates
-    max_wait = 60  # Maximum wait time in seconds
-    wait_start = time.time()
-
-    while not future.done():
-        elapsed = time.time() - wait_start
-        if elapsed > max_wait:
-            raise gr.Error("Request timeout")
-
-        # Update progress based on wait time
-        progress_val = min(0.9, 0.3 + (elapsed / max_wait) * 0.6)
-        progress(progress_val, desc="Processing...")
-
-        if INTERRUPTING:
-            raise gr.Error("Interrupted.")
-
-        time.sleep(0.1)
-
-    progress(1.0, desc="Complete!")
-
-    # Get result
+    # Use the simple _do_predictions function for single request
     try:
-        result = future.result()
+        result = _do_predictions(
+            texts=[text],
+            melodies=[melody],
+            solver=solver,
+            steps=steps,
+            target_flowstep=target_flowstep,
+            regularize=regularize,
+            regularization_strength=regularization_strength,
+            duration=duration,
+            progress=True
+        )
+
+        progress(1.0, desc="Complete!")
+
         if isinstance(result, list) and len(result) > 0:
             return result[0]
         return result
+
     except Exception as e:
         raise gr.Error(f"Generation failed: {str(e)}")

@@ -729,9 +483,6 @@ def ui_hf(launch_kwargs):

 def cleanup():
     """Cleanup function for graceful shutdown"""
-    global BATCH_PROCESSOR
-    if BATCH_PROCESSOR:
-        BATCH_PROCESSOR.stop()
     print("Cleanup completed")
