ASesYusuf1 committed on
Commit
8205184
·
verified ·
1 Parent(s): b8209ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -44
app.py CHANGED
@@ -434,26 +434,46 @@ def roformer_separator(audio, model_key, seg_size, override_seg_size, overlap, p
434
  logger.warning(f"Failed to clean up temporary file {temp_audio_path}: {e}")
435
 
436
  @spaces.GPU
437
- def auto_ensemble_process(audio, model_keys, seg_size, overlap, out_format, use_tta, model_dir, output_dir, norm_thresh, amp_thresh, batch_size, ensemble_method, exclude_stems="", weights_str=""):
438
- import gc
439
- import torch
440
-
441
- if not audio or not model_keys:
442
- raise ValueError("Audio or models missing.")
443
-
444
  temp_audio_path = None
 
445
  try:
446
- # Limit to 2 models for testing
447
- model_keys = model_keys[:2]
448
 
 
449
  if isinstance(audio, tuple):
450
  sample_rate, data = audio
451
  temp_audio_path = os.path.join("/tmp", "temp_audio.wav")
452
  scipy.io.wavfile.write(temp_audio_path, sample_rate, data)
453
  audio = temp_audio_path
454
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
455
  use_tta = use_tta == "True"
456
 
 
457
  if os.path.exists(output_dir):
458
  shutil.rmtree(output_dir)
459
  os.makedirs(output_dir, exist_ok=True)
@@ -462,48 +482,70 @@ def auto_ensemble_process(audio, model_keys, seg_size, overlap, out_format, use_
462
  logger.info(f"Ensemble for {base_name} with {model_keys} on {device}")
463
 
464
  all_stems = []
465
- total_models = len(model_keys)
466
 
467
- for i, model_key in enumerate(model_keys):
 
468
  for category, models in ROFORMER_MODELS.items():
469
  if model_key in models:
470
  model = models[model_key]
471
  break
472
  else:
 
473
  continue
474
 
475
- separator = Separator(
476
- log_level=logging.INFO,
477
- model_file_dir=model_dir,
478
- output_dir=output_dir,
479
- output_format=out_format,
480
- normalization_threshold=norm_thresh,
481
- amplification_threshold=amp_thresh,
482
- use_autocast=use_autocast,
483
- mdxc_params={"segment_size": seg_size, "overlap": overlap, "use_tta": use_tta, "batch_size": batch_size}
484
- )
485
- logger.info(f"Loading {model_key}")
486
- separator.load_model(model_filename=model)
487
- logger.info(f"Separating with {model_key}")
488
- separation = separator.separate(audio)
489
- stems = [os.path.join(output_dir, file_name) for file_name in separation]
490
-
491
- if exclude_stems.strip():
492
- excluded = [s.strip().lower() for s in exclude_stems.split(',')]
493
- filtered_stems = [stem for stem in stems if not any(ex in os.path.basename(stem).lower() for ex in excluded)]
494
- all_stems.extend(filtered_stems)
495
- else:
496
- all_stems.extend(stems)
497
-
498
- # Clean up model to free memory
499
- separator = None
500
- gc.collect()
501
- if torch.cuda.is_available():
502
- torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
503
 
504
  if not all_stems:
505
  raise ValueError("No valid stems for ensemble after exclusion.")
506
 
 
507
  weights = [float(w.strip()) for w in weights_str.split(',')] if weights_str.strip() else [1.0] * len(all_stems)
508
  if len(weights) != len(all_stems):
509
  weights = [1.0] * len(all_stems)
@@ -515,7 +557,7 @@ def auto_ensemble_process(audio, model_keys, seg_size, overlap, out_format, use_
515
  "--weights", *[str(w) for w in weights],
516
  "--output", output_file
517
  ]
518
- logger.info("Running ensemble...")
519
  ensemble_files(ensemble_args)
520
 
521
  logger.info("Ensemble complete")
@@ -524,12 +566,14 @@ def auto_ensemble_process(audio, model_keys, seg_size, overlap, out_format, use_
524
  logger.error(f"Ensemble failed: {e}")
525
  raise RuntimeError(f"Ensemble failed: {e}")
526
  finally:
527
- if temp_audio_path and os.path.exists(temp_audio_path):
 
528
  try:
529
- os.remove(temp_audio_path)
530
- logger.info(f"Successfully cleaned up {temp_audio_path}")
 
531
  except Exception as e:
532
- logger.error(f"Failed to clean up {temp_audio_path}: {e}")
533
 
534
  def update_roformer_models(category):
535
  """Update Roformer model dropdown based on selected category."""
 
434
  logger.warning(f"Failed to clean up temporary file {temp_audio_path}: {e}")
435
 
436
  @spaces.GPU
437
+ def auto_ensemble_process(audio, model_keys, seg_size=128, overlap=0.1, out_format="wav", use_tta="False", model_dir="/tmp/audio-separator-models/", output_dir="output", norm_thresh=0.9, amp_thresh=0.9, batch_size=1, ensemble_method="avg_wave", exclude_stems="", weights_str=""):
 
 
 
 
 
 
438
  temp_audio_path = None
439
+ chunk_paths = []
440
  try:
441
+ if not audio or not model_keys:
442
+ raise ValueError("Audio or models missing.")
443
 
444
+ # Handle tuple input (sample_rate, data)
445
  if isinstance(audio, tuple):
446
  sample_rate, data = audio
447
  temp_audio_path = os.path.join("/tmp", "temp_audio.wav")
448
  scipy.io.wavfile.write(temp_audio_path, sample_rate, data)
449
  audio = temp_audio_path
450
 
451
+ # Load audio to check duration
452
+ audio_data, sr = librosa.load(audio, sr=None, mono=False)
453
+ duration = librosa.get_duration(y=audio_data, sr=sr)
454
+ logger.info(f"Audio duration: {duration:.2f} seconds")
455
+
456
+ # Split audio if longer than 15 minutes (900 seconds)
457
+ chunk_duration = 300 # 5 minutes in seconds
458
+ chunks = []
459
+ if duration > 900:
460
+ logger.info(f"Audio exceeds 15 minutes, splitting into {chunk_duration}-second chunks")
461
+ num_chunks = int(np.ceil(duration / chunk_duration))
462
+ for i in range(num_chunks):
463
+ start = i * chunk_duration * sr
464
+ end = min((i + 1) * chunk_duration * sr, audio_data.shape[-1])
465
+ chunk_data = audio_data[:, start:end] if audio_data.ndim == 2 else audio_data[start:end]
466
+ chunk_path = os.path.join("/tmp", f"chunk_{i}.wav")
467
+ sf.write(chunk_path, chunk_data.T if audio_data.ndim == 2 else chunk_data, sr)
468
+ chunks.append(chunk_path)
469
+ chunk_paths.append(chunk_path)
470
+ logger.info(f"Created chunk {i}: {chunk_path}")
471
+ else:
472
+ chunks = [audio]
473
+
474
  use_tta = use_tta == "True"
475
 
476
+ # Create output directory
477
  if os.path.exists(output_dir):
478
  shutil.rmtree(output_dir)
479
  os.makedirs(output_dir, exist_ok=True)
 
482
  logger.info(f"Ensemble for {base_name} with {model_keys} on {device}")
483
 
484
  all_stems = []
485
+ model_stems = {} # Store stems per model for concatenation
486
 
487
+ for model_key in model_keys:
488
+ model_stems[model_key] = {"vocals": [], "other": []}
489
  for category, models in ROFORMER_MODELS.items():
490
  if model_key in models:
491
  model = models[model_key]
492
  break
493
  else:
494
+ logger.warning(f"Model {model_key} not found, skipping")
495
  continue
496
 
497
+ for chunk_idx, chunk_path in enumerate(chunks):
498
+ separator = Separator(
499
+ log_level=logging.INFO,
500
+ model_file_dir=model_dir,
501
+ output_dir=output_dir,
502
+ output_format=out_format,
503
+ normalization_threshold=norm_thresh,
504
+ amplification_threshold=amp_thresh,
505
+ use_autocast=use_autocast,
506
+ mdxc_params={"segment_size": seg_size, "overlap": overlap, "use_tta": use_tta, "batch_size": batch_size}
507
+ )
508
+ logger.info(f"Loading {model_key} for chunk {chunk_idx}")
509
+ separator.load_model(model_filename=model)
510
+ logger.info(f"Separating chunk {chunk_idx} with {model_key}")
511
+ separation = separator.separate(chunk_path)
512
+ stems = [os.path.join(output_dir, file_name) for file_name in separation]
513
+
514
+ # Store stems for this chunk
515
+ for stem in stems:
516
+ if "vocals" in os.path.basename(stem).lower():
517
+ model_stems[model_key]["vocals"].append(stem)
518
+ elif "other" in os.path.basename(stem).lower():
519
+ model_stems[model_key]["other"].append(stem)
520
+
521
+ # Clean up memory
522
+ separator = None
523
+ gc.collect()
524
+ if torch.cuda.is_available():
525
+ torch.cuda.empty_cache()
526
+ logger.info(f"Cleared CUDA cache after {model_key} chunk {chunk_idx}")
527
+
528
+ # Combine stems for each model
529
+ for model_key, stems_dict in model_stems.items():
530
+ for stem_type in ["vocals", "other"]:
531
+ if stems_dict[stem_type]:
532
+ combined_path = os.path.join(output_dir, f"{base_name}_{stem_type}_{model_key.replace(' | ', '_').replace(' ', '_')}.wav")
533
+ combined_data = []
534
+ for stem_path in stems_dict[stem_type]:
535
+ data, _ = librosa.load(stem_path, sr=sr, mono=False)
536
+ combined_data.append(data)
537
+ combined_data = np.concatenate(combined_data, axis=-1) if combined_data[0].ndim == 2 else np.concatenate(combined_data)
538
+ sf.write(combined_path, combined_data.T if combined_data.ndim == 2 else combined_data, sr)
539
+ logger.info(f"Combined {stem_type} for {model_key}: {combined_path}")
540
+ if exclude_stems.strip() and stem_type.lower() in [s.strip().lower() for s in exclude_stems.split(',')]:
541
+ logger.info(f"Excluding {stem_type} for {model_key}")
542
+ continue
543
+ all_stems.append(combined_path)
544
 
545
  if not all_stems:
546
  raise ValueError("No valid stems for ensemble after exclusion.")
547
 
548
+ # Ensemble the combined stems
549
  weights = [float(w.strip()) for w in weights_str.split(',')] if weights_str.strip() else [1.0] * len(all_stems)
550
  if len(weights) != len(all_stems):
551
  weights = [1.0] * len(all_stems)
 
557
  "--weights", *[str(w) for w in weights],
558
  "--output", output_file
559
  ]
560
+ logger.info(f"Running ensemble with args: {ensemble_args}")
561
  ensemble_files(ensemble_args)
562
 
563
  logger.info("Ensemble complete")
 
566
  logger.error(f"Ensemble failed: {e}")
567
  raise RuntimeError(f"Ensemble failed: {e}")
568
  finally:
569
+ # Clean up temporary files
570
+ for path in chunk_paths + ([temp_audio_path] if temp_audio_path and os.path.exists(temp_audio_path) else []):
571
  try:
572
+ if os.path.exists(path):
573
+ os.remove(path)
574
+ logger.info(f"Successfully cleaned up {path}")
575
  except Exception as e:
576
+ logger.error(f"Failed to clean up {path}: {e}")
577
 
578
  def update_roformer_models(category):
579
  """Update Roformer model dropdown based on selected category."""