Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -22,7 +22,6 @@ import time
|
|
| 22 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 23 |
from threading import Lock
|
| 24 |
import scipy.io.wavfile
|
| 25 |
-
import spaces
|
| 26 |
import subprocess
|
| 27 |
|
| 28 |
# Logging setup
|
|
@@ -156,7 +155,7 @@ ROFORMER_MODELS = {
|
|
| 156 |
|
| 157 |
OUTPUT_FORMATS = ['wav', 'flac', 'mp3', 'ogg', 'opus', 'm4a', 'aiff', 'ac3']
|
| 158 |
|
| 159 |
-
# CSS (
|
| 160 |
CSS = """
|
| 161 |
body {
|
| 162 |
background: linear-gradient(to bottom, rgba(45, 11, 11, 0.9), rgba(0, 0, 0, 0.8)), url('/content/logo.jpg') no-repeat center center fixed;
|
|
@@ -382,18 +381,17 @@ def download_audio(url, cookie_file=None):
|
|
| 382 |
@spaces.GPU
|
| 383 |
def roformer_separator(audio, model_key, seg_size, override_seg_size, overlap, pitch_shift, model_dir, output_dir, out_format, norm_thresh, amp_thresh, batch_size, exclude_stems="", progress=gr.Progress(track_tqdm=True)):
|
| 384 |
if not audio:
|
| 385 |
-
raise ValueError("No audio file provided.")
|
| 386 |
temp_audio_path = None
|
| 387 |
extracted_audio_path = None
|
| 388 |
try:
|
| 389 |
-
# Giriş dosyasının uzantısını kontrol et
|
| 390 |
file_extension = os.path.splitext(audio)[1].lower().lstrip('.')
|
| 391 |
-
|
| 392 |
-
|
|
|
|
| 393 |
|
| 394 |
-
# Eğer giriş bir video dosyasıysa, sesi çıkar
|
| 395 |
audio_to_process = audio
|
| 396 |
-
if
|
| 397 |
extracted_audio_path = os.path.join("/tmp", f"extracted_audio_{os.path.basename(audio)}.wav")
|
| 398 |
logger.info(f"Extracting audio from video file: {audio}")
|
| 399 |
ffmpeg_command = [
|
|
@@ -405,8 +403,13 @@ def roformer_separator(audio, model_key, seg_size, override_seg_size, overlap, p
|
|
| 405 |
logger.info(f"Audio extracted to: {extracted_audio_path}")
|
| 406 |
audio_to_process = extracted_audio_path
|
| 407 |
except subprocess.CalledProcessError as e:
|
| 408 |
-
|
| 409 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 410 |
|
| 411 |
if isinstance(audio_to_process, tuple):
|
| 412 |
sample_rate, data = audio_to_process
|
|
@@ -454,17 +457,25 @@ def roformer_separator(audio, model_key, seg_size, override_seg_size, overlap, p
|
|
| 454 |
file_list = stems
|
| 455 |
stem1 = stems[0]
|
| 456 |
stem2 = stems[1] if len(stems) > 1 else None
|
|
|
|
| 457 |
return stem1, stem2, file_list
|
|
|
|
| 458 |
except Exception as e:
|
| 459 |
logger.error(f"Separation error: {e}")
|
| 460 |
raise RuntimeError(f"Separation error: {e}")
|
| 461 |
finally:
|
| 462 |
if temp_audio_path and os.path.exists(temp_audio_path):
|
| 463 |
-
|
| 464 |
-
|
|
|
|
|
|
|
|
|
|
| 465 |
if extracted_audio_path and os.path.exists(extracted_audio_path):
|
| 466 |
-
|
| 467 |
-
|
|
|
|
|
|
|
|
|
|
| 468 |
if torch.cuda.is_available():
|
| 469 |
torch.cuda.empty_cache()
|
| 470 |
logger.info("GPU memory cleared")
|
|
@@ -476,21 +487,20 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
|
|
| 476 |
start_time = time.time()
|
| 477 |
try:
|
| 478 |
if not audio:
|
| 479 |
-
raise ValueError("No audio file provided.")
|
| 480 |
if not model_keys:
|
| 481 |
raise ValueError("No models selected.")
|
| 482 |
if len(model_keys) > max_models:
|
| 483 |
logger.warning(f"Selected {len(model_keys)} models, limiting to {max_models}.")
|
| 484 |
model_keys = model_keys[:max_models]
|
| 485 |
|
| 486 |
-
# Giriş dosyasının uzantısını kontrol et
|
| 487 |
file_extension = os.path.splitext(audio)[1].lower().lstrip('.')
|
| 488 |
-
|
| 489 |
-
|
|
|
|
| 490 |
|
| 491 |
-
# Eğer giriş bir video dosyasıysa, sesi çıkar
|
| 492 |
audio_to_process = audio
|
| 493 |
-
if
|
| 494 |
extracted_audio_path = os.path.join("/tmp", f"extracted_audio_{os.path.basename(audio)}.wav")
|
| 495 |
logger.info(f"Extracting audio from video file: {audio}")
|
| 496 |
ffmpeg_command = [
|
|
@@ -502,10 +512,14 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
|
|
| 502 |
logger.info(f"Audio extracted to: {extracted_audio_path}")
|
| 503 |
audio_to_process = extracted_audio_path
|
| 504 |
except subprocess.CalledProcessError as e:
|
| 505 |
-
|
| 506 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 507 |
|
| 508 |
-
# Audio süresine göre dinamik batch size
|
| 509 |
audio_data, sr = librosa.load(audio_to_process, sr=None, mono=False)
|
| 510 |
duration = librosa.get_duration(y=audio_data, sr=sr)
|
| 511 |
logger.info(f"Audio duration: {duration:.2f} seconds")
|
|
@@ -518,7 +532,6 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
|
|
| 518 |
scipy.io.wavfile.write(temp_audio_path, sample_rate, data)
|
| 519 |
audio_to_process = temp_audio_path
|
| 520 |
|
| 521 |
-
# State kontrolü
|
| 522 |
if not state:
|
| 523 |
state = {
|
| 524 |
"current_audio": None,
|
|
@@ -527,7 +540,6 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
|
|
| 527 |
"model_outputs": {}
|
| 528 |
}
|
| 529 |
|
| 530 |
-
# Yeni audio dosyası kontrolü - yalnızca audio değiştiğinde sıfırlıyoruz
|
| 531 |
if state["current_audio"] != audio:
|
| 532 |
state["current_audio"] = audio
|
| 533 |
state["current_model_idx"] = 0
|
|
@@ -539,28 +551,20 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
|
|
| 539 |
base_name = os.path.splitext(os.path.basename(audio))[0]
|
| 540 |
logger.info(f"Ensemble for {base_name} with {model_keys} on {device}")
|
| 541 |
|
| 542 |
-
# Kalıcı bir klasör oluştur
|
| 543 |
permanent_output_dir = os.path.join(output_dir, "permanent_stems")
|
| 544 |
os.makedirs(permanent_output_dir, exist_ok=True)
|
| 545 |
|
| 546 |
-
# Model cache
|
| 547 |
model_cache = {}
|
| 548 |
all_stems = []
|
| 549 |
total_tasks = len(model_keys)
|
| 550 |
-
|
| 551 |
-
# Şu anki modeli işle
|
| 552 |
current_idx = state["current_model_idx"]
|
| 553 |
logger.info(f"Current model index: {current_idx}, total models: {len(model_keys)}")
|
| 554 |
|
| 555 |
-
# Tüm modeller işlendiyse ensemble işlemini yap
|
| 556 |
if current_idx >= len(model_keys):
|
| 557 |
logger.info("All models processed, running ensemble...")
|
| 558 |
progress(0.9, desc="Running ensemble...")
|
| 559 |
|
| 560 |
-
# "Exclude Stems" listesindeki stem'leri belirle
|
| 561 |
excluded_stems_list = [s.strip().lower() for s in exclude_stems.split(',')] if exclude_stems.strip() else []
|
| 562 |
-
|
| 563 |
-
# Tüm stem’leri topla, ama "Exclude Stems" ile belirtilenleri hariç tut
|
| 564 |
for model_key, stems_dict in state["model_outputs"].items():
|
| 565 |
for stem_type in ["vocals", "other"]:
|
| 566 |
if stems_dict[stem_type]:
|
|
@@ -573,7 +577,6 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
|
|
| 573 |
if not all_stems:
|
| 574 |
raise ValueError("No valid stems found for ensemble after excluding specified stems.")
|
| 575 |
|
| 576 |
-
# Ensemble işlemi
|
| 577 |
weights = [float(w.strip()) for w in weights_str.split(',')] if weights_str.strip() else [1.0] * len(all_stems)
|
| 578 |
if len(weights) != len(all_stems):
|
| 579 |
weights = [1.0] * len(all_stems)
|
|
@@ -590,7 +593,6 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
|
|
| 590 |
if result is None or not os.path.exists(output_file):
|
| 591 |
raise RuntimeError(f"Ensemble failed, output file not created: {output_file}")
|
| 592 |
|
| 593 |
-
# Durumu sıfırla
|
| 594 |
state["current_model_idx"] = 0
|
| 595 |
state["current_audio"] = None
|
| 596 |
state["processed_stems"] = []
|
|
@@ -607,7 +609,6 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
|
|
| 607 |
status += "</ul>"
|
| 608 |
return output_file, status, file_list, state
|
| 609 |
|
| 610 |
-
# Şu anki modeli işle
|
| 611 |
model_key = model_keys[current_idx]
|
| 612 |
logger.info(f"Processing model {current_idx + 1}/{len(model_keys)}: {model_key}")
|
| 613 |
progress(0.1, desc=f"Processing model {model_key}...")
|
|
@@ -615,7 +616,6 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
|
|
| 615 |
with torch.no_grad():
|
| 616 |
for attempt in range(max_retries + 1):
|
| 617 |
try:
|
| 618 |
-
# Modeli bul
|
| 619 |
for category, models in ROFORMER_MODELS.items():
|
| 620 |
if model_key in models:
|
| 621 |
model = models[model_key]
|
|
@@ -625,13 +625,11 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
|
|
| 625 |
state["current_model_idx"] += 1
|
| 626 |
return None, f"Model {model_key} not found, proceeding to next model.", [], state
|
| 627 |
|
| 628 |
-
# Zaman kontrolü
|
| 629 |
elapsed = time.time() - start_time
|
| 630 |
if elapsed > time_budget:
|
| 631 |
logger.error(f"Time budget ({time_budget}s) exceeded")
|
| 632 |
raise TimeoutError("Processing took too long")
|
| 633 |
|
| 634 |
-
# Separator oluştur
|
| 635 |
if model_key not in model_cache:
|
| 636 |
logger.info(f"Loading {model_key} into cache")
|
| 637 |
separator = Separator(
|
|
@@ -654,15 +652,12 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
|
|
| 654 |
else:
|
| 655 |
separator = model_cache[model_key]
|
| 656 |
|
| 657 |
-
# GPU ile işlem
|
| 658 |
with gpu_lock:
|
| 659 |
progress(0.3, desc=f"Separating with {model_key}")
|
| 660 |
logger.info(f"Separating with {model_key}")
|
| 661 |
separation = separator.separate(audio_to_process)
|
| 662 |
stems = [os.path.join(output_dir, file_name) for file_name in separation]
|
| 663 |
result = []
|
| 664 |
-
|
| 665 |
-
# Stem’leri kalıcı klasöre taşı
|
| 666 |
for stem in stems:
|
| 667 |
stem_type = "vocals" if "vocals" in os.path.basename(stem).lower() else "other"
|
| 668 |
permanent_stem_path = os.path.join(permanent_output_dir, f"{base_name}_{stem_type}_{model_key.replace(' | ', '_').replace(' ', '_')}.{out_format}")
|
|
@@ -670,7 +665,6 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
|
|
| 670 |
state["model_outputs"][model_key][stem_type].append(permanent_stem_path)
|
| 671 |
if stem_type not in exclude_stems.lower():
|
| 672 |
result.append(permanent_stem_path)
|
| 673 |
-
|
| 674 |
state["processed_stems"].extend(result)
|
| 675 |
break
|
| 676 |
|
|
@@ -687,24 +681,20 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
|
|
| 687 |
torch.cuda.empty_cache()
|
| 688 |
logger.info(f"Cleared CUDA cache after {model_key}")
|
| 689 |
|
| 690 |
-
# Model cache temizliği
|
| 691 |
model_cache.clear()
|
| 692 |
gc.collect()
|
| 693 |
if torch.cuda.is_available():
|
| 694 |
torch.cuda.empty_cache()
|
| 695 |
logger.info("Cleared model cache and GPU memory")
|
| 696 |
|
| 697 |
-
# Bir sonraki modele geç
|
| 698 |
state["current_model_idx"] += 1
|
| 699 |
elapsed = time.time() - start_time
|
| 700 |
logger.info(f"Model {model_key} completed in {elapsed:.2f}s")
|
| 701 |
|
| 702 |
-
# Eğer bu son modelse, ensemble işlemini hemen başlat
|
| 703 |
if state["current_model_idx"] >= len(model_keys):
|
| 704 |
logger.info("Last model processed, running ensemble immediately...")
|
| 705 |
return auto_ensemble_process(audio, model_keys, state, seg_size, overlap, out_format, use_tta, model_dir, output_dir, norm_thresh, amp_thresh, batch_size, ensemble_method, exclude_stems, weights_str, progress)
|
| 706 |
|
| 707 |
-
# Çıktılar
|
| 708 |
file_list = state["processed_stems"]
|
| 709 |
status = f"Model {model_key} (Model {current_idx + 1}/{len(model_keys)}) completed in {elapsed:.2f}s<br>Click 'Run Ensemble!' to process the next model.<br>Processed stems:<ul>"
|
| 710 |
for file in file_list:
|
|
@@ -715,7 +705,7 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
|
|
| 715 |
|
| 716 |
except Exception as e:
|
| 717 |
logger.error(f"Ensemble error: {e}")
|
| 718 |
-
error_msg = f"Processing failed: {e}. Try fewer models (max {max_models}) or uploading a local WAV file."
|
| 719 |
raise RuntimeError(error_msg)
|
| 720 |
|
| 721 |
finally:
|
|
@@ -736,13 +726,11 @@ def auto_ensemble_process(audio, model_keys, state, seg_size=64, overlap=0.1, ou
|
|
| 736 |
logger.info("GPU memory cleared")
|
| 737 |
|
| 738 |
def update_roformer_models(category):
|
| 739 |
-
"""Update Roformer model dropdown based on selected category."""
|
| 740 |
choices = list(ROFORMER_MODELS.get(category, {}).keys()) or []
|
| 741 |
logger.debug(f"Updating roformer models for category {category}: {choices}")
|
| 742 |
return gr.update(choices=choices, value=choices[0] if choices else None)
|
| 743 |
|
| 744 |
def update_ensemble_models(category):
|
| 745 |
-
"""Update ensemble model dropdown based on selected category."""
|
| 746 |
choices = list(ROFORMER_MODELS.get(category, {}).keys()) or []
|
| 747 |
logger.debug(f"Updating ensemble models for category {category}: {choices}")
|
| 748 |
return gr.update(choices=choices, value=[])
|
|
@@ -756,7 +744,6 @@ def create_interface():
|
|
| 756 |
gr.Markdown("<h1 class='header-text'>🎵 SESA Fast Separation 🎵</h1>")
|
| 757 |
gr.Markdown("**Note**: If YouTube downloads fail, upload a valid cookies file or a local WAV/MP4/MOV file. [Cookie Instructions](https://github.com/yt-dlp/yt-dlp/wiki/Extractors#exporting-youtube-cookies)")
|
| 758 |
gr.Markdown("**Tip**: For best results, use audio/video shorter than 15 minutes or fewer models (up to 6) to ensure smooth processing.")
|
| 759 |
-
# Gradio State bileşeni
|
| 760 |
ensemble_state = gr.State(value={
|
| 761 |
"current_audio": None,
|
| 762 |
"current_model_idx": 0,
|
|
@@ -777,7 +764,7 @@ def create_interface():
|
|
| 777 |
with gr.Group(elem_classes="dubbing-theme"):
|
| 778 |
gr.Markdown("### Audio Separation")
|
| 779 |
with gr.Row():
|
| 780 |
-
roformer_audio = gr.
|
| 781 |
url_ro = gr.Textbox(label="🔗 Or Paste URL", placeholder="YouTube or audio/video URL", interactive=True)
|
| 782 |
cookies_ro = gr.File(label="🍪 Cookies File", file_types=[".txt"], interactive=True)
|
| 783 |
download_roformer = gr.Button("⬇️ Download", variant="secondary")
|
|
@@ -802,7 +789,7 @@ def create_interface():
|
|
| 802 |
gr.Markdown("### Ensemble Processing")
|
| 803 |
gr.Markdown("Note: If weights are not specified, equal weights (1.0) are applied. Use up to 6 models for best results.")
|
| 804 |
with gr.Row():
|
| 805 |
-
ensemble_audio = gr.
|
| 806 |
url_ensemble = gr.Textbox(label="🔗 Or Paste URL", placeholder="YouTube or audio/video URL", interactive=True)
|
| 807 |
cookies_ensemble = gr.File(label="🍪 Cookies File", file_types=[".txt"], interactive=True)
|
| 808 |
download_ensemble = gr.Button("⬇️ Download", variant="secondary")
|
|
|
|
| 22 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 23 |
from threading import Lock
|
| 24 |
import scipy.io.wavfile
|
|
|
|
| 25 |
import subprocess
|
| 26 |
|
| 27 |
# Logging setup
|
|
|
|
| 155 |
|
| 156 |
OUTPUT_FORMATS = ['wav', 'flac', 'mp3', 'ogg', 'opus', 'm4a', 'aiff', 'ac3']
|
| 157 |
|
| 158 |
+
# CSS (orijinal CSS korundu)
|
| 159 |
CSS = """
|
| 160 |
body {
|
| 161 |
background: linear-gradient(to bottom, rgba(45, 11, 11, 0.9), rgba(0, 0, 0, 0.8)), url('/content/logo.jpg') no-repeat center center fixed;
|
|
|
|
| 381 |
@spaces.GPU
|
| 382 |
def roformer_separator(audio, model_key, seg_size, override_seg_size, overlap, pitch_shift, model_dir, output_dir, out_format, norm_thresh, amp_thresh, batch_size, exclude_stems="", progress=gr.Progress(track_tqdm=True)):
|
| 383 |
if not audio:
|
| 384 |
+
raise ValueError("No audio or video file provided.")
|
| 385 |
temp_audio_path = None
|
| 386 |
extracted_audio_path = None
|
| 387 |
try:
|
|
|
|
| 388 |
file_extension = os.path.splitext(audio)[1].lower().lstrip('.')
|
| 389 |
+
supported_formats = ['wav', 'mp3', 'flac', 'ogg', 'opus', 'm4a', 'aiff', 'ac3', 'mp4', 'mov', 'avi', 'mkv', 'flv', 'wmv', 'webm', 'mpeg', 'mpg', 'ts', 'vob']
|
| 390 |
+
if file_extension not in supported_formats:
|
| 391 |
+
raise ValueError(f"Unsupported file format: {file_extension}. Supported formats: {', '.join(supported_formats)}")
|
| 392 |
|
|
|
|
| 393 |
audio_to_process = audio
|
| 394 |
+
if file_extension in ['mp4', 'mov', 'avi', 'mkv', 'flv', 'wmv', 'webm', 'mpeg', 'mpg', 'ts', 'vob']:
|
| 395 |
extracted_audio_path = os.path.join("/tmp", f"extracted_audio_{os.path.basename(audio)}.wav")
|
| 396 |
logger.info(f"Extracting audio from video file: {audio}")
|
| 397 |
ffmpeg_command = [
|
|
|
|
| 403 |
logger.info(f"Audio extracted to: {extracted_audio_path}")
|
| 404 |
audio_to_process = extracted_audio_path
|
| 405 |
except subprocess.CalledProcessError as e:
|
| 406 |
+
error_message = e.stderr.decode() if e.stderr else str(e)
|
| 407 |
+
if "No audio stream" in error_message:
|
| 408 |
+
raise RuntimeError("The provided video file does not contain an audio track.")
|
| 409 |
+
elif "Invalid data" in error_message:
|
| 410 |
+
raise RuntimeError("The video file is corrupted or not supported.")
|
| 411 |
+
else:
|
| 412 |
+
raise RuntimeError(f"Failed to extract audio from video: {error_message}")
|
| 413 |
|
| 414 |
if isinstance(audio_to_process, tuple):
|
| 415 |
sample_rate, data = audio_to_process
|
|
|
|
| 457 |
file_list = stems
|
| 458 |
stem1 = stems[0]
|
| 459 |
stem2 = stems[1] if len(stems) > 1 else None
|
| 460 |
+
|
| 461 |
return stem1, stem2, file_list
|
| 462 |
+
|
| 463 |
except Exception as e:
|
| 464 |
logger.error(f"Separation error: {e}")
|
| 465 |
raise RuntimeError(f"Separation error: {e}")
|
| 466 |
finally:
|
| 467 |
if temp_audio_path and os.path.exists(temp_audio_path):
|
| 468 |
+
try:
|
| 469 |
+
os.remove(temp_audio_path)
|
| 470 |
+
logger.info(f"Temporary file deleted: {temp_audio_path}")
|
| 471 |
+
except Exception as e:
|
| 472 |
+
logger.warning(f"Failed to delete temporary file {temp_audio_path}: {e}")
|
| 473 |
if extracted_audio_path and os.path.exists(extracted_audio_path):
|
| 474 |
+
try:
|
| 475 |
+
os.remove(extracted_audio_path)
|
| 476 |
+
logger.info(f"Extracted audio file deleted: {extracted_audio_path}")
|
| 477 |
+
except Exception as e:
|
| 478 |
+
logger.warning(f"Failed to delete extracted audio file {extracted_audio_path}: {e}")
|
| 479 |
if torch.cuda.is_available():
|
| 480 |
torch.cuda.empty_cache()
|
| 481 |
logger.info("GPU memory cleared")
|
|
|
|
| 487 |
start_time = time.time()
|
| 488 |
try:
|
| 489 |
if not audio:
|
| 490 |
+
raise ValueError("No audio or video file provided.")
|
| 491 |
if not model_keys:
|
| 492 |
raise ValueError("No models selected.")
|
| 493 |
if len(model_keys) > max_models:
|
| 494 |
logger.warning(f"Selected {len(model_keys)} models, limiting to {max_models}.")
|
| 495 |
model_keys = model_keys[:max_models]
|
| 496 |
|
|
|
|
| 497 |
file_extension = os.path.splitext(audio)[1].lower().lstrip('.')
|
| 498 |
+
supported_formats = ['wav', 'mp3', 'flac', 'ogg', 'opus', 'm4a', 'aiff', 'ac3', 'mp4', 'mov', 'avi', 'mkv', 'flv', 'wmv', 'webm', 'mpeg', 'mpg', 'ts', 'vob']
|
| 499 |
+
if file_extension not in supported_formats:
|
| 500 |
+
raise ValueError(f"Unsupported file format: {file_extension}. Supported formats: {', '.join(supported_formats)}")
|
| 501 |
|
|
|
|
| 502 |
audio_to_process = audio
|
| 503 |
+
if file_extension in ['mp4', 'mov', 'avi', 'mkv', 'flv', 'wmv', 'webm', 'mpeg', 'mpg', 'ts', 'vob']:
|
| 504 |
extracted_audio_path = os.path.join("/tmp", f"extracted_audio_{os.path.basename(audio)}.wav")
|
| 505 |
logger.info(f"Extracting audio from video file: {audio}")
|
| 506 |
ffmpeg_command = [
|
|
|
|
| 512 |
logger.info(f"Audio extracted to: {extracted_audio_path}")
|
| 513 |
audio_to_process = extracted_audio_path
|
| 514 |
except subprocess.CalledProcessError as e:
|
| 515 |
+
error_message = e.stderr.decode() if e.stderr else str(e)
|
| 516 |
+
if "No audio stream" in error_message:
|
| 517 |
+
raise RuntimeError("The provided video file does not contain an audio track.")
|
| 518 |
+
elif "Invalid data" in error_message:
|
| 519 |
+
raise RuntimeError("The video file is corrupted or not supported.")
|
| 520 |
+
else:
|
| 521 |
+
raise RuntimeError(f"Failed to extract audio from video: {error_message}")
|
| 522 |
|
|
|
|
| 523 |
audio_data, sr = librosa.load(audio_to_process, sr=None, mono=False)
|
| 524 |
duration = librosa.get_duration(y=audio_data, sr=sr)
|
| 525 |
logger.info(f"Audio duration: {duration:.2f} seconds")
|
|
|
|
| 532 |
scipy.io.wavfile.write(temp_audio_path, sample_rate, data)
|
| 533 |
audio_to_process = temp_audio_path
|
| 534 |
|
|
|
|
| 535 |
if not state:
|
| 536 |
state = {
|
| 537 |
"current_audio": None,
|
|
|
|
| 540 |
"model_outputs": {}
|
| 541 |
}
|
| 542 |
|
|
|
|
| 543 |
if state["current_audio"] != audio:
|
| 544 |
state["current_audio"] = audio
|
| 545 |
state["current_model_idx"] = 0
|
|
|
|
| 551 |
base_name = os.path.splitext(os.path.basename(audio))[0]
|
| 552 |
logger.info(f"Ensemble for {base_name} with {model_keys} on {device}")
|
| 553 |
|
|
|
|
| 554 |
permanent_output_dir = os.path.join(output_dir, "permanent_stems")
|
| 555 |
os.makedirs(permanent_output_dir, exist_ok=True)
|
| 556 |
|
|
|
|
| 557 |
model_cache = {}
|
| 558 |
all_stems = []
|
| 559 |
total_tasks = len(model_keys)
|
|
|
|
|
|
|
| 560 |
current_idx = state["current_model_idx"]
|
| 561 |
logger.info(f"Current model index: {current_idx}, total models: {len(model_keys)}")
|
| 562 |
|
|
|
|
| 563 |
if current_idx >= len(model_keys):
|
| 564 |
logger.info("All models processed, running ensemble...")
|
| 565 |
progress(0.9, desc="Running ensemble...")
|
| 566 |
|
|
|
|
| 567 |
excluded_stems_list = [s.strip().lower() for s in exclude_stems.split(',')] if exclude_stems.strip() else []
|
|
|
|
|
|
|
| 568 |
for model_key, stems_dict in state["model_outputs"].items():
|
| 569 |
for stem_type in ["vocals", "other"]:
|
| 570 |
if stems_dict[stem_type]:
|
|
|
|
| 577 |
if not all_stems:
|
| 578 |
raise ValueError("No valid stems found for ensemble after excluding specified stems.")
|
| 579 |
|
|
|
|
| 580 |
weights = [float(w.strip()) for w in weights_str.split(',')] if weights_str.strip() else [1.0] * len(all_stems)
|
| 581 |
if len(weights) != len(all_stems):
|
| 582 |
weights = [1.0] * len(all_stems)
|
|
|
|
| 593 |
if result is None or not os.path.exists(output_file):
|
| 594 |
raise RuntimeError(f"Ensemble failed, output file not created: {output_file}")
|
| 595 |
|
|
|
|
| 596 |
state["current_model_idx"] = 0
|
| 597 |
state["current_audio"] = None
|
| 598 |
state["processed_stems"] = []
|
|
|
|
| 609 |
status += "</ul>"
|
| 610 |
return output_file, status, file_list, state
|
| 611 |
|
|
|
|
| 612 |
model_key = model_keys[current_idx]
|
| 613 |
logger.info(f"Processing model {current_idx + 1}/{len(model_keys)}: {model_key}")
|
| 614 |
progress(0.1, desc=f"Processing model {model_key}...")
|
|
|
|
| 616 |
with torch.no_grad():
|
| 617 |
for attempt in range(max_retries + 1):
|
| 618 |
try:
|
|
|
|
| 619 |
for category, models in ROFORMER_MODELS.items():
|
| 620 |
if model_key in models:
|
| 621 |
model = models[model_key]
|
|
|
|
| 625 |
state["current_model_idx"] += 1
|
| 626 |
return None, f"Model {model_key} not found, proceeding to next model.", [], state
|
| 627 |
|
|
|
|
| 628 |
elapsed = time.time() - start_time
|
| 629 |
if elapsed > time_budget:
|
| 630 |
logger.error(f"Time budget ({time_budget}s) exceeded")
|
| 631 |
raise TimeoutError("Processing took too long")
|
| 632 |
|
|
|
|
| 633 |
if model_key not in model_cache:
|
| 634 |
logger.info(f"Loading {model_key} into cache")
|
| 635 |
separator = Separator(
|
|
|
|
| 652 |
else:
|
| 653 |
separator = model_cache[model_key]
|
| 654 |
|
|
|
|
| 655 |
with gpu_lock:
|
| 656 |
progress(0.3, desc=f"Separating with {model_key}")
|
| 657 |
logger.info(f"Separating with {model_key}")
|
| 658 |
separation = separator.separate(audio_to_process)
|
| 659 |
stems = [os.path.join(output_dir, file_name) for file_name in separation]
|
| 660 |
result = []
|
|
|
|
|
|
|
| 661 |
for stem in stems:
|
| 662 |
stem_type = "vocals" if "vocals" in os.path.basename(stem).lower() else "other"
|
| 663 |
permanent_stem_path = os.path.join(permanent_output_dir, f"{base_name}_{stem_type}_{model_key.replace(' | ', '_').replace(' ', '_')}.{out_format}")
|
|
|
|
| 665 |
state["model_outputs"][model_key][stem_type].append(permanent_stem_path)
|
| 666 |
if stem_type not in exclude_stems.lower():
|
| 667 |
result.append(permanent_stem_path)
|
|
|
|
| 668 |
state["processed_stems"].extend(result)
|
| 669 |
break
|
| 670 |
|
|
|
|
| 681 |
torch.cuda.empty_cache()
|
| 682 |
logger.info(f"Cleared CUDA cache after {model_key}")
|
| 683 |
|
|
|
|
| 684 |
model_cache.clear()
|
| 685 |
gc.collect()
|
| 686 |
if torch.cuda.is_available():
|
| 687 |
torch.cuda.empty_cache()
|
| 688 |
logger.info("Cleared model cache and GPU memory")
|
| 689 |
|
|
|
|
| 690 |
state["current_model_idx"] += 1
|
| 691 |
elapsed = time.time() - start_time
|
| 692 |
logger.info(f"Model {model_key} completed in {elapsed:.2f}s")
|
| 693 |
|
|
|
|
| 694 |
if state["current_model_idx"] >= len(model_keys):
|
| 695 |
logger.info("Last model processed, running ensemble immediately...")
|
| 696 |
return auto_ensemble_process(audio, model_keys, state, seg_size, overlap, out_format, use_tta, model_dir, output_dir, norm_thresh, amp_thresh, batch_size, ensemble_method, exclude_stems, weights_str, progress)
|
| 697 |
|
|
|
|
| 698 |
file_list = state["processed_stems"]
|
| 699 |
status = f"Model {model_key} (Model {current_idx + 1}/{len(model_keys)}) completed in {elapsed:.2f}s<br>Click 'Run Ensemble!' to process the next model.<br>Processed stems:<ul>"
|
| 700 |
for file in file_list:
|
|
|
|
| 705 |
|
| 706 |
except Exception as e:
|
| 707 |
logger.error(f"Ensemble error: {e}")
|
| 708 |
+
error_msg = f"Processing failed: {e}. Try fewer models (max {max_models}) or uploading a local WAV or video file."
|
| 709 |
raise RuntimeError(error_msg)
|
| 710 |
|
| 711 |
finally:
|
|
|
|
| 726 |
logger.info("GPU memory cleared")
|
| 727 |
|
| 728 |
def update_roformer_models(category):
|
|
|
|
| 729 |
choices = list(ROFORMER_MODELS.get(category, {}).keys()) or []
|
| 730 |
logger.debug(f"Updating roformer models for category {category}: {choices}")
|
| 731 |
return gr.update(choices=choices, value=choices[0] if choices else None)
|
| 732 |
|
| 733 |
def update_ensemble_models(category):
|
|
|
|
| 734 |
choices = list(ROFORMER_MODELS.get(category, {}).keys()) or []
|
| 735 |
logger.debug(f"Updating ensemble models for category {category}: {choices}")
|
| 736 |
return gr.update(choices=choices, value=[])
|
|
|
|
| 744 |
gr.Markdown("<h1 class='header-text'>🎵 SESA Fast Separation 🎵</h1>")
|
| 745 |
gr.Markdown("**Note**: If YouTube downloads fail, upload a valid cookies file or a local WAV/MP4/MOV file. [Cookie Instructions](https://github.com/yt-dlp/yt-dlp/wiki/Extractors#exporting-youtube-cookies)")
|
| 746 |
gr.Markdown("**Tip**: For best results, use audio/video shorter than 15 minutes or fewer models (up to 6) to ensure smooth processing.")
|
|
|
|
| 747 |
ensemble_state = gr.State(value={
|
| 748 |
"current_audio": None,
|
| 749 |
"current_model_idx": 0,
|
|
|
|
| 764 |
with gr.Group(elem_classes="dubbing-theme"):
|
| 765 |
gr.Markdown("### Audio Separation")
|
| 766 |
with gr.Row():
|
| 767 |
+
roformer_audio = gr.File(label="🎧 Upload Audio or Video (WAV, MP3, MP4, MOV, etc.)", file_types=['.wav', '.mp3', '.flac', '.ogg', '.opus', '.m4a', '.aiff', '.ac3', '.mp4', '.mov', '.avi', '.mkv', '.flv', '.wmv', '.webm', '.mpeg', '.mpg', '.ts', '.vob'], interactive=True)
|
| 768 |
url_ro = gr.Textbox(label="🔗 Or Paste URL", placeholder="YouTube or audio/video URL", interactive=True)
|
| 769 |
cookies_ro = gr.File(label="🍪 Cookies File", file_types=[".txt"], interactive=True)
|
| 770 |
download_roformer = gr.Button("⬇️ Download", variant="secondary")
|
|
|
|
| 789 |
gr.Markdown("### Ensemble Processing")
|
| 790 |
gr.Markdown("Note: If weights are not specified, equal weights (1.0) are applied. Use up to 6 models for best results.")
|
| 791 |
with gr.Row():
|
| 792 |
+
ensemble_audio = gr.File(label="🎧 Upload Audio or Video (WAV, MP3, MP4, MOV, etc.)", file_types=['.wav', '.mp3', '.flac', '.ogg', '.opus', '.m4a', '.aiff', '.ac3', '.mp4', '.mov', '.avi', '.mkv', '.flv', '.wmv', '.webm', '.mpeg', '.mpg', '.ts', '.vob'], interactive=True)
|
| 793 |
url_ensemble = gr.Textbox(label="🔗 Or Paste URL", placeholder="YouTube or audio/video URL", interactive=True)
|
| 794 |
cookies_ensemble = gr.File(label="🍪 Cookies File", file_types=[".txt"], interactive=True)
|
| 795 |
download_ensemble = gr.Button("⬇️ Download", variant="secondary")
|