Spaces:
Running
Running
Update gui.py
Browse files
gui.py
CHANGED
|
@@ -12,36 +12,40 @@ import soundfile as sf
|
|
| 12 |
from ensemble import ensemble_files
|
| 13 |
import shutil
|
| 14 |
import gradio_client.utils as client_utils
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
-
#
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
if isinstance(schema, bool):
|
|
|
|
| 19 |
return "boolean"
|
| 20 |
-
if
|
| 21 |
-
|
| 22 |
-
if "enum" in schema:
|
| 23 |
-
return f"Union[{', '.join(repr(e) for e in schema['enum'])}]"
|
| 24 |
-
if "type" not in schema:
|
| 25 |
return "Any"
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
return f"
|
| 29 |
-
|
| 30 |
-
return
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
|
| 35 |
-
client_utils.
|
| 36 |
|
| 37 |
# Device and autocast setup
|
| 38 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 39 |
use_autocast = device == "cuda"
|
| 40 |
|
| 41 |
-
# Logging setup
|
| 42 |
-
logging.basicConfig(level=logging.INFO)
|
| 43 |
-
logger = logging.getLogger(__name__)
|
| 44 |
-
|
| 45 |
# Model dictionaries organized by category
|
| 46 |
ROFORMER_MODELS = {
|
| 47 |
"Vocals": {
|
|
@@ -136,7 +140,7 @@ ROFORMER_MODELS = {
|
|
| 136 |
|
| 137 |
OUTPUT_FORMATS = ['wav', 'flac', 'mp3', 'ogg', 'opus', 'm4a', 'aiff', 'ac3']
|
| 138 |
|
| 139 |
-
# CSS for UI styling
|
| 140 |
CSS = """
|
| 141 |
/* Modern ve Etkileşimli Tema */
|
| 142 |
#app-container {
|
|
@@ -353,8 +357,8 @@ button:hover {
|
|
| 353 |
# Functions
|
| 354 |
def download_audio(url, out_dir="ytdl"):
|
| 355 |
"""Download audio from a URL using yt-dlp."""
|
| 356 |
-
if not url:
|
| 357 |
-
raise ValueError("
|
| 358 |
|
| 359 |
if os.path.exists(out_dir):
|
| 360 |
shutil.rmtree(out_dir)
|
|
@@ -371,6 +375,7 @@ def download_audio(url, out_dir="ytdl"):
|
|
| 371 |
info_dict = ydl.extract_info(url, download=True)
|
| 372 |
return ydl.prepare_filename(info_dict).rsplit('.', 1)[0] + '.wav'
|
| 373 |
except Exception as e:
|
|
|
|
| 374 |
raise RuntimeError(f"Download failed: {e}")
|
| 375 |
|
| 376 |
def roformer_separator(audio, model_key, seg_size, override_seg_size, overlap, pitch_shift, model_dir, output_dir, out_format, norm_thresh, amp_thresh, batch_size, exclude_stems="", progress=gr.Progress(track_tqdm=True)):
|
|
@@ -489,11 +494,15 @@ def auto_ensemble_process(audio, model_keys, seg_size, overlap, out_format, use_
|
|
| 489 |
|
| 490 |
def update_roformer_models(category):
|
| 491 |
"""Update Roformer model dropdown based on selected category."""
|
| 492 |
-
|
|
|
|
|
|
|
| 493 |
|
| 494 |
def update_ensemble_models(category):
|
| 495 |
"""Update ensemble model dropdown based on selected category."""
|
| 496 |
-
|
|
|
|
|
|
|
| 497 |
|
| 498 |
# Interface creation
|
| 499 |
def create_interface():
|
|
@@ -507,7 +516,7 @@ def create_interface():
|
|
| 507 |
with gr.Group(elem_classes="dubbing-theme"):
|
| 508 |
gr.Markdown("### General Settings")
|
| 509 |
model_file_dir = gr.Textbox(value="/tmp/audio-separator-models/", label="📂 Model Cache", placeholder="Path to model directory", interactive=True)
|
| 510 |
-
output_dir = gr.Textbox(value="output", label="📤 Output Directory", placeholder="Where to save results", interactive=True)
|
| 511 |
output_format = gr.Dropdown(value="wav", choices=OUTPUT_FORMATS, label="🎶 Output Format", interactive=True)
|
| 512 |
norm_threshold = gr.Slider(0.1, 1.0, value=0.9, step=0.1, label="🔊 Normalization Threshold", interactive=True)
|
| 513 |
amp_threshold = gr.Slider(0.1, 1.0, value=0.3, step=0.1, label="📈 Amplification Threshold", interactive=True)
|
|
@@ -594,7 +603,7 @@ if __name__ == "__main__":
|
|
| 594 |
|
| 595 |
app = create_interface()
|
| 596 |
try:
|
| 597 |
-
# For Hugging Face Spaces
|
| 598 |
app.launch(server_name="0.0.0.0", server_port=args.port, share=True)
|
| 599 |
except Exception as e:
|
| 600 |
logger.error(f"Failed to launch app: {e}")
|
|
|
|
| 12 |
from ensemble import ensemble_files
|
| 13 |
import shutil
|
| 14 |
import gradio_client.utils as client_utils
|
| 15 |
+
import validators
|
| 16 |
+
import matchering as mg
|
| 17 |
+
from typing import Any, Optional
|
| 18 |
|
| 19 |
+
# Logging setup
|
| 20 |
+
logging.basicConfig(level=logging.INFO)
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
# Patch gradio_client.utils._json_schema_to_python_type to handle enum schemas.
# The library's own converter is kept under a separate name so the wrapper
# below can delegate to it for every schema it does not special-case.
original_json_schema_to_python_type = client_utils._json_schema_to_python_type


def patched_json_schema_to_python_type(schema: Any, defs: Optional[dict] = None) -> str:
    """Convert a JSON-schema fragment into a Python type description string.

    Wraps the converter saved in ``original_json_schema_to_python_type`` and
    short-circuits a few schema shapes before delegating to it.

    Args:
        schema: The schema fragment to describe. Booleans and non-dict values
            are answered directly without calling the wrapped converter.
        defs: Optional definitions mapping, forwarded unchanged to the
            wrapped converter.

    Returns:
        ``"boolean"`` when ``schema`` is a ``bool``; ``"Any"`` when it is any
        other non-dict value; ``"Literal[...]"`` listing the ``repr`` of each
        enum member when the schema has an ``"enum"`` key and its ``"type"``
        is ``"string"``; otherwise whatever the wrapped converter returns, or
        ``"str"`` if that converter raises ``APIInfoParseError``.
    """
    logger.debug(f"Parsing schema: {schema}")

    # A bare boolean is not a mapping, so it is answered here instead of
    # being handed to the wrapped converter.
    if isinstance(schema, bool):
        logger.info("Found boolean schema, returning 'boolean'")
        return "boolean"

    if not isinstance(schema, dict):
        logger.warning(f"Unexpected schema type: {type(schema)}, returning 'Any'")
        return "Any"

    is_string_enum = schema.get("type") == "string" and "enum" in schema
    if is_string_enum:
        enum_values = schema["enum"]
        logger.info(f"Handling enum schema: {enum_values}")
        members = ", ".join([repr(value) for value in enum_values])
        return f"Literal[{members}]"

    try:
        python_type = original_json_schema_to_python_type(schema, defs)
    except client_utils.APIInfoParseError as e:
        # Best-effort fallback: report the failure and describe the value as
        # a plain string rather than letting the parse error propagate.
        logger.error(f"Failed to parse schema {schema}: {e}")
        python_type = "str"
    return python_type


client_utils._json_schema_to_python_type = patched_json_schema_to_python_type
|
| 44 |
|
| 45 |
# Device and autocast setup
|
| 46 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 47 |
use_autocast = device == "cuda"
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
# Model dictionaries organized by category
|
| 50 |
ROFORMER_MODELS = {
|
| 51 |
"Vocals": {
|
|
|
|
| 140 |
|
| 141 |
OUTPUT_FORMATS = ['wav', 'flac', 'mp3', 'ogg', 'opus', 'm4a', 'aiff', 'ac3']
|
| 142 |
|
| 143 |
+
# CSS for UI styling
|
| 144 |
CSS = """
|
| 145 |
/* Modern ve Etkileşimli Tema */
|
| 146 |
#app-container {
|
|
|
|
| 357 |
# Functions
|
| 358 |
def download_audio(url, out_dir="ytdl"):
|
| 359 |
"""Download audio from a URL using yt-dlp."""
|
| 360 |
+
if not url or not validators.url(url):
|
| 361 |
+
raise ValueError("Invalid or missing URL.")
|
| 362 |
|
| 363 |
if os.path.exists(out_dir):
|
| 364 |
shutil.rmtree(out_dir)
|
|
|
|
| 375 |
info_dict = ydl.extract_info(url, download=True)
|
| 376 |
return ydl.prepare_filename(info_dict).rsplit('.', 1)[0] + '.wav'
|
| 377 |
except Exception as e:
|
| 378 |
+
logger.error(f"Download failed: {e}")
|
| 379 |
raise RuntimeError(f"Download failed: {e}")
|
| 380 |
|
| 381 |
def roformer_separator(audio, model_key, seg_size, override_seg_size, overlap, pitch_shift, model_dir, output_dir, out_format, norm_thresh, amp_thresh, batch_size, exclude_stems="", progress=gr.Progress(track_tqdm=True)):
|
|
|
|
| 494 |
|
| 495 |
def update_roformer_models(category):
    """Build the dropdown update for the Roformer model selector.

    Looks up ``category`` in ``ROFORMER_MODELS`` and returns a ``gr.update``
    whose ``choices`` are the model names registered under it. A category
    that is not present yields an empty list of choices.
    """
    models_in_category = ROFORMER_MODELS.get(category, {})
    # Iterating a dict yields its keys, i.e. the model display names.
    choices = list(models_in_category)
    logger.debug(f"Updating roformer models for category {category}: {choices}")
    return gr.update(choices=choices)
|
| 500 |
|
| 501 |
def update_ensemble_models(category):
    """Build the dropdown update for the ensemble model selector.

    Looks up ``category`` in ``ROFORMER_MODELS`` and returns a ``gr.update``
    whose ``choices`` are the model names registered under it. A category
    that is not present yields an empty list of choices.
    """
    models_in_category = ROFORMER_MODELS.get(category, {})
    # Iterating a dict yields its keys, i.e. the model display names.
    choices = list(models_in_category)
    logger.debug(f"Updating ensemble models for category {category}: {choices}")
    return gr.update(choices=choices)
|
| 506 |
|
| 507 |
# Interface creation
|
| 508 |
def create_interface():
|
|
|
|
| 516 |
with gr.Group(elem_classes="dubbing-theme"):
|
| 517 |
gr.Markdown("### General Settings")
|
| 518 |
model_file_dir = gr.Textbox(value="/tmp/audio-separator-models/", label="📂 Model Cache", placeholder="Path to model directory", interactive=True)
|
| 519 |
+
# Default to a plain "output" folder name (no glob or operator characters).
output_dir = gr.Textbox(value="output", label="📤 Output Directory", placeholder="Where to save results", interactive=True)
|
| 520 |
output_format = gr.Dropdown(value="wav", choices=OUTPUT_FORMATS, label="🎶 Output Format", interactive=True)
|
| 521 |
norm_threshold = gr.Slider(0.1, 1.0, value=0.9, step=0.1, label="🔊 Normalization Threshold", interactive=True)
|
| 522 |
amp_threshold = gr.Slider(0.1, 1.0, value=0.3, step=0.1, label="📈 Amplification Threshold", interactive=True)
|
|
|
|
| 603 |
|
| 604 |
app = create_interface()
|
| 605 |
try:
|
| 606 |
+
# For Hugging Face Spaces
|
| 607 |
app.launch(server_name="0.0.0.0", server_port=args.port, share=True)
|
| 608 |
except Exception as e:
|
| 609 |
logger.error(f"Failed to launch app: {e}")
|