# Spaces: Running on Zero
import yt_dlp | |
import re | |
import subprocess | |
import os | |
import shutil | |
from pydub import AudioSegment, silence | |
import gradio as gr | |
import traceback | |
import logging | |
from inference import proc_folder_direct | |
from pathlib import Path | |
import spaces | |
from pydub.exceptions import CouldntEncodeError | |
from transformers import pipeline | |
import requests | |
# Initialize text generation model
# NOTE(review): `model` is not referenced by any code shown in this file, yet this
# line loads EleutherAI/gpt-neo-125M at import time; confirm it is still needed.
model = pipeline('text-generation', model='EleutherAI/gpt-neo-125M')
# Define constants
# Root folder for separation results. The trailing slash matters: paths are built
# by plain concatenation elsewhere, e.g. f'{OUTPUT_FOLDER}{formatted_title}/...'.
OUTPUT_FOLDER = "separation_results/"
# Input audio is converted into INPUT_FOLDER/wav before inference runs.
INPUT_FOLDER = "input"
# NOTE(review): not referenced by any code shown in this file.
download_path = ""
# URL for the cookies.txt file in the Hugging Face repository
cookies_url = "https://huggingface.co/spaces/Awell00/music_drums_separation/raw/main/cookies.txt"
def download_cookies(timeout=30):
    """
    Fetch cookies.txt from ``cookies_url`` and write it to ./cookies.txt.

    yt-dlp reads this file through its 'cookiefile' option. HTTP and network
    errors are printed and swallowed, so callers go on to use whatever
    cookies.txt already exists (if any).

    Args:
        timeout (float): Seconds to wait for the server. Without a timeout,
            ``requests`` can block indefinitely on an unresponsive host.
    """
    # NOTE(review): if this Space is public, cookies served from its repository
    # are readable by anyone; consider storing them as a Space secret instead.
    try:
        response = requests.get(cookies_url, timeout=timeout)
        response.raise_for_status()  # Check for HTTP errors
    except requests.exceptions.RequestException as e:
        print(f"Error downloading cookies.txt: {e}")
        return
    # Write content to cookies.txt in the working directory
    with open("cookies.txt", "w", encoding="utf-8") as file:
        file.write(response.text)
    print("cookies.txt downloaded successfully.")
class MyLogger:
    """
    Quiet logger passed to yt-dlp through its 'logger' option.

    Debug, info and warning output is discarded; only errors are printed.
    """

    def debug(self, msg):
        # For compatibility with youtube-dl, yt-dlp routes both debug and info
        # output here; lines tagged "[debug] " are dropped, the rest go to info().
        if not msg.startswith('[debug] '):
            self.info(msg)

    def info(self, msg):
        """Discard informational messages."""

    def warning(self, msg):
        """Discard warnings."""

    def error(self, msg):
        """Print errors so download failures stay visible."""
        print(msg)
def my_hook(d):
    """yt-dlp progress hook: announce when a download has finished."""
    status = d['status']
    if status != 'finished':
        return
    print('Done downloading, now post-processing ...')
def sanitize_filename(filename):
    """
    Replace characters that are invalid in file names on common file systems.

    Path separators, the Windows-reserved characters * ? : " < > |, and ASCII
    control characters (NUL..US and DEL) are each replaced with "_".
    NUL is invalid on every file system; control characters are rejected on Windows.

    Args:
        filename (str): The original filename

    Returns:
        str: Sanitized filename (same length as the input)
    """
    return re.sub(r'[\\/*?:"<>|\x00-\x1f\x7f]', '_', filename)
def delete_input_files(input_dir):
    """
    Remove every ``*.wav`` file directly inside ``<input_dir>/wav``.

    Each deleted path is printed. Other file types are left alone.

    Args:
        input_dir (str): Path to the input directory
    """
    targets = (Path(input_dir) / "wav").glob("*.wav")
    for target in targets:
        target.unlink()
        print(f"Deleted {target}")
# Filler words stripped from titles (whole words, case-insensitive).
_UNNECESSARY_WORDS = ("official", "video", "hd", "4k", "lyrics", "music", "audio", "visualizer", "remix")
_UNNECESSARY_WORDS_RE = re.compile(r"\b(?:{})\b".format("|".join(_UNNECESSARY_WORDS)), re.IGNORECASE)
# Artist/title separators: a comma, or a dash with whitespace on at least one side,
# so hyphenated names such as "Jay-Z" or "Anti-Hero" are not split apart.
_SEPARATOR_RE = re.compile(r"\s*,\s*|\s+[-–—]\s*|\s*[-–—]\s+")
# "(with X)", "(feat. X)", "[ft. X]" ... -> X is treated as an additional artist.
_FEATURE_RE = re.compile(r"[\(\[](?:with|feat\.?|ft\.?)\s+(.*?)[\)\]]", re.IGNORECASE)
# A run of letters, optionally joined to more letters by an apostrophe (contractions).
_WORD_RE = re.compile(r"[^\W\d_]+(?:['’][^\W\d_]+)?")


def _title_case(text):
    """Capitalize each word like str.title(), but keep contractions intact ("Don't", not "Don'T")."""
    return _WORD_RE.sub(lambda m: m.group(0).capitalize(), text)


def standardize_title(input_title):
    """
    Standardize a title into "Artist - Title" form.

    Parenthesized/bracketed segments and filler words are removed. The rest is
    split on commas and spaced dashes; the last piece becomes the title and the
    others are joined as the artist. A "(with X)", "(feat. X)" or "(ft. X)"
    segment adds X to the artist. Both parts are title-cased.

    Args:
        input_title (str): The original title

    Returns:
        str: Standardized title in "Artist - Title" format; the artist is
        "Unknown Artist" when no separator is found.
    """
    # Remove content within parentheses or brackets (featured artists are re-added below)
    title_cleaned = re.sub(r"[\(\[].*?[\)\]]", "", input_title)
    title_cleaned = _UNNECESSARY_WORDS_RE.sub("", title_cleaned)

    # Drop empty pieces so leftovers like "Artist - " do not yield a blank title
    parts = [part.strip() for part in _SEPARATOR_RE.split(title_cleaned) if part.strip()]
    if len(parts) >= 2:
        artist_part = ", ".join(parts[:-1])
        title_part = parts[-1]
    else:
        artist_part = "Unknown Artist"
        title_part = parts[0] if parts else ""

    match = _FEATURE_RE.search(input_title)
    if match:
        additional_artist = match.group(1).strip()
        if additional_artist:
            artist_part = (
                additional_artist
                if artist_part == "Unknown Artist"
                else f"{artist_part}, {additional_artist}"
            )

    # Collapse whitespace and capitalize
    artist_part = _title_case(re.sub(r"\s+", " ", artist_part))
    title_part = _title_case(re.sub(r"\s+", " ", title_part))
    return f"{artist_part} - {title_part}".strip()
def get_video_title(video_url):
    """
    Look up a video's title with yt-dlp without downloading any media.

    Args:
        video_url (str): URL of the video.

    Returns:
        str: The 'title' field of the extracted info.
    """
    extractor_options = dict(
        logger=MyLogger(),
        progress_hooks=[my_hook],
        cookiefile='cookies.txt',
        quiet=True,
        ratelimit=500000,
        retries=3,
    )
    with yt_dlp.YoutubeDL(extractor_options) as downloader:
        info = downloader.extract_info(video_url, download=False)
    return info['title']
def download_youtube_audio(youtube_url: str, output_dir: str = './download', delete_existing: bool = True, simulate: bool = False) -> str:
    """
    Download the audio of a YouTube video and convert it to a WAV file.

    The file name is the video title passed through sanitize_filename. The
    .wav file is produced by yt-dlp's FFmpegExtractAudio post-processor.

    Args:
        youtube_url (str): URL of the YouTube video.
        output_dir (str): Directory to save the audio file (created if missing).
        delete_existing (bool): If True, deletes an existing WAV file with the same name first.
        simulate (bool): If True, yt-dlp only simulates the download; the
            returned path will then not exist.

    Returns:
        str: Path to the downloaded ``.wav`` file.
    """
    os.makedirs(output_dir, exist_ok=True)

    download_cookies()
    # The title becomes a path component, so strip separators and other invalid characters
    title = sanitize_filename(get_video_title(youtube_url))
    audio_file = os.path.join(output_dir, title)
    wav_path = audio_file + '.wav'

    # The post-processor below writes WAV, so that is the file to clear
    if delete_existing and os.path.exists(wav_path):
        os.remove(wav_path)

    # Prepare yt-dlp options
    ydl_opts = {
        'logger': MyLogger(),
        'progress_hooks': [my_hook],
        'format': 'bestaudio',
        # outtmpl is an output *template*: escape literal '%' so titles such as
        # "100% Pure" are not parsed as template fields
        'outtmpl': audio_file.replace('%', '%%'),
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'wav',
        }],
        'extractor_retries': 10,
        'force_overwrites': True,
        'cookiefile': 'cookies.txt',
        'verbose': True,
        'ratelimit': 500000,
        'retries': 3,
        'sleep_interval': 10,
        'max_sleep_interval': 30
    }
    if simulate:
        ydl_opts['simulate'] = True

    # Download the audio using yt-dlp
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        ydl.download([youtube_url])
    return wav_path
def handle_file_upload(file):
    """
    Convert an uploaded audio file to WAV under INPUT_FOLDER/wav.

    The output name is the upload's base name (without extension), run
    through standardize_title and then sanitize_filename.

    Args:
        file: Uploaded file object (anything with a ``.name`` path) or a file path string

    Returns:
        tuple: (input_path, formatted_title), or (None, "No file uploaded") when file is None
    """
    if file is None:
        return None, "No file uploaded"

    # Strings are paths already; otherwise use the object's .name attribute
    file_path = file if isinstance(file, str) else file.name
    stem, _extension = os.path.splitext(os.path.basename(file_path))
    formatted_title = sanitize_filename(standardize_title(stem).strip())

    wav_dir = os.path.join(INPUT_FOLDER, "wav")
    os.makedirs(wav_dir, exist_ok=True)
    input_path = os.path.join(wav_dir, f"{formatted_title}.wav")

    # Re-encode whatever format was uploaded as WAV
    AudioSegment.from_file(file_path).export(input_path, format="wav")
    return input_path, formatted_title
def run_inference(model_type, config_path, start_check_point, input_dir, output_dir, device_ids="0"):
    """
    Run inference.py as a subprocess with the given model and paths.

    Args:
        model_type (str): Type of the model
        config_path (str): Path to the model configuration
        start_check_point (str): Path to the model checkpoint
        input_dir (str): Input directory
        output_dir (str): Output directory
        device_ids (str | int): GPU device IDs to use (converted to str)

    Returns:
        subprocess.CompletedProcess: Result with captured text stdout/stderr.

    Raises:
        subprocess.CalledProcessError: If inference.py exits with a non-zero status.
    """
    import sys  # local import: the module header does not import sys

    command = [
        # Use the running interpreter, not whatever "python" resolves to on PATH,
        # so the child process sees the same environment and packages.
        sys.executable, "inference.py",
        "--model_type", model_type,
        "--config_path", config_path,
        "--start_check_point", start_check_point,
        "--INPUT_FOLDER", input_dir,
        "--store_dir", output_dir,
        "--device_ids", str(device_ids),
    ]
    return subprocess.run(command, check=True, capture_output=True, text=True)
# Per-model output folder -> (stem suffix in the file name, target file name, label for messages).
# BS Roformer writes its result with the "_other" suffix, but it is the instrumental.
_STEM_MOVES = {
    'htdemucs': ('bass', 'bass.wav', 'Bass'),
    'mel_band_roformer': ('vocals', 'vocals.wav', 'Vocals'),
    'scnet': ('other', 'other.wav', 'Other'),
    'bs_roformer': ('other', 'instrumental.wav', 'Instrumental'),
}


def move_stems_to_parent(input_dir):
    """
    Move each model's stem file from its per-model folder up into the song folder.

    Expects ``<input_dir>/<song>/<model>/<song>_<stem>.wav`` and moves it to
    ``<input_dir>/<song>/<target>.wav`` (see _STEM_MOVES). Missing files are reported.

    Args:
        input_dir (str): Input directory containing song folders
    """
    for subdir, _dirs, _files in os.walk(input_dir):
        if subdir == input_dir:
            continue
        # Match on the folder's own name: a substring test on the whole path misfires
        # when a song title contains a model name (e.g. ".../X scnet/bs_roformer").
        move = _STEM_MOVES.get(os.path.basename(subdir))
        if move is None:
            continue
        suffix, target_name, label = move
        parent_dir = os.path.dirname(subdir)
        song_name = os.path.basename(parent_dir)
        stem_path = os.path.join(subdir, f"{song_name}_{suffix}.wav")
        if os.path.exists(stem_path):
            shutil.move(stem_path, os.path.join(parent_dir, target_name))
        else:
            print(f"{label} file not found: {stem_path}")
def combine_stems_for_all(input_dir, output_format="mp3"):
    """
    Mix the stems of every complete song folder under input_dir into one file.

    A folder qualifies if it holds vocals.wav, bass.wav, other.wav and
    instrumental.wav. The stems are overlaid, with the instrumental attenuated
    by 20 dB, and trailing silence is trimmed. The result is exported as
    "<folder name>.<output_format>" inside that folder.

    Args:
        input_dir (str): Input directory containing song folders
        output_format (str): Output audio format (default 'mp3'). The libmp3lame
            codec and the 320k bitrate are applied only to MP3 output.

    Returns:
        str | None: Path of the last exported file, or None if no folder had
        all stems or encoding failed.
    """
    fmt = output_format.lower()
    # libmp3lame only encodes MP3; other formats fall back to ffmpeg's defaults
    export_kwargs = {"codec": "libmp3lame", "bitrate": "320k"} if fmt == "mp3" else {}
    output_file = None

    for subdir, _, _ in os.walk(input_dir):
        if subdir == input_dir:
            continue
        song_name = os.path.basename(subdir).strip()  # Remove any trailing spaces
        print(f"Processing {subdir}")
        stem_paths = {
            "vocals": os.path.join(subdir, "vocals.wav"),
            "bass": os.path.join(subdir, "bass.wav"),
            "others": os.path.join(subdir, "other.wav"),
            "instrumental": os.path.join(subdir, "instrumental.wav"),
        }
        # Skip if not all stems are present
        if not all(os.path.exists(path) for path in stem_paths.values()):
            print(f"Skipping {subdir}, not all stems are present.")
            continue

        # Load and combine stems
        stems = {name: AudioSegment.from_file(path) for name, path in stem_paths.items()}
        stems["instrumental"] = stems["instrumental"].apply_gain(-20)
        combined = stems["vocals"].overlay(stems["bass"]).overlay(stems["others"]).overlay(stems["instrumental"])
        trimmed_combined = trim_silence_at_end(combined)

        output_file = os.path.join(subdir, f"{song_name}.{fmt}")
        try:
            trimmed_combined.export(output_file, format=fmt, **export_kwargs)
            print(f"Exported combined stems to {output_format.upper()} format: {output_file}")
        except CouldntEncodeError as e:
            print(f"{output_format.upper()} Encoding failed: {e}")
            return None

    return output_file
def trim_silence_at_end(audio_segment, silence_thresh=-50, chunk_size=10):
    """
    Trim trailing silence from an audio segment.

    Walks backwards from the end in ``chunk_size`` ms steps. Every chunk quieter
    than ``silence_thresh`` dBFS is dropped until a louder chunk is found.
    Silent passages earlier in the track are left untouched. The previous
    version cut at the start of the last silent range anywhere, which truncated
    tracks that do not end in silence.

    Args:
        audio_segment (AudioSegment): Input audio segment
        silence_thresh (int): Silence threshold in dBFS
        chunk_size (int): Size of chunks to analyze in ms (must be > 0)

    Returns:
        AudioSegment: Trimmed audio segment (the input itself if nothing was trimmed)

    Raises:
        ValueError: If chunk_size is not positive.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive number of milliseconds")

    end = len(audio_segment)
    while end > 0:
        start = max(0, end - chunk_size)
        if audio_segment[start:end].dBFS >= silence_thresh:
            break
        end = start

    if end == len(audio_segment):
        return audio_segment
    return audio_segment[:end]
def delete_folders_and_files(input_dir):
    """
    Delete intermediate separation artifacts under input_dir.

    In every sub-folder of input_dir (not input_dir itself) this removes:
      - the per-model folders (htdemucs, mel_band_roformer, scnet, bs_roformer);
      - the stem files (bass/vocals/other/instrumental.wav).
    It then removes every folder under input_dir whose name ends in "_vocals".
    Other files, such as the combined "<song>.mp3", are kept.

    Args:
        input_dir (str): Input directory to clean up
    """
    folders_to_delete = ['htdemucs', 'mel_band_roformer', 'scnet', 'bs_roformer']
    files_to_delete = ['bass.wav', 'vocals.wav', 'other.wav', 'instrumental.wav']

    # Bottom-up walk: a folder's contents are visited before its parent deletes it
    for root, _dirs, _files in os.walk(input_dir, topdown=False):
        if root == input_dir:
            continue
        for folder in folders_to_delete:
            folder_path = os.path.join(root, folder)
            if os.path.isdir(folder_path):
                print(f"Deleting folder: {folder_path}")
                shutil.rmtree(folder_path)
        for file in files_to_delete:
            file_path = os.path.join(root, file)
            if os.path.isfile(file_path):
                print(f"Deleting file: {file_path}")
                os.remove(file_path)

    # Delete "*_vocals" folders under input_dir (this previously swept the global
    # OUTPUT_FOLDER regardless of the argument)
    for root, dirs, _files in os.walk(input_dir):
        kept = []
        for dir_name in dirs:
            if dir_name.endswith('_vocals'):
                dir_path = os.path.join(root, dir_name)
                print(f"Deleting folder: {dir_path}")
                shutil.rmtree(dir_path)
            else:
                kept.append(dir_name)
        # Prune deleted folders so os.walk does not try to descend into them
        dirs[:] = kept
    print("Cleanup completed.")
def process_audio(uploaded_file, link):
    """
    Run the full separation pipeline on an uploaded file or a YouTube link.

    Steps:
      1. Convert the input to WAV in INPUT_FOLDER/wav.
      2. Run SCNet, Mel Band Roformer, HTDemucs and BS Roformer inference.
      3. Move the stems into the song folder.
      4. Mix them into an MP3.
      5. Delete intermediate files.

    Args:
        uploaded_file: Uploaded file object or path string (takes precedence over link)
        link (str | None): YouTube URL, used only when no file is uploaded

    Yields:
        tuple: (status_message, output_file_path). The path is None until the
        final step. On failure, the message holds the error and traceback.
    """
    # NOTE(review): `spaces` is imported at the top of the file, but nothing here is
    # decorated with @spaces.GPU; confirm GPU allocation happens inside proc_folder_direct.
    try:
        yield "Processing audio...", None
        if uploaded_file:
            input_path, formatted_title = handle_file_upload(uploaded_file)
            if input_path is None:
                raise ValueError("File upload failed.")
        elif link and link.strip():
            new_file = download_youtube_audio(link.strip())
            input_path, formatted_title = handle_file_upload(new_file)
        else:
            raise ValueError("Please upload an audio file or provide a YouTube link.")

        wav_dir = f"{INPUT_FOLDER}/wav"
        # OUTPUT_FOLDER ends with "/", so concatenation yields "<root>/<title>"
        song_dir = f"{OUTPUT_FOLDER}{formatted_title}"

        # Run inference for different models
        yield "Starting SCNet inference...", None
        proc_folder_direct("scnet", "configs/config_scnet_other.yaml", "results/model_scnet_other.ckpt", wav_dir, OUTPUT_FOLDER)
        yield "Starting Mel Band Roformer inference...", None
        proc_folder_direct("mel_band_roformer", "configs/config_mel_band_roformer_vocals.yaml", "results/model_mel_band_roformer_vocals.ckpt", wav_dir, OUTPUT_FOLDER, extract_instrumental=True)
        yield "Starting HTDemucs inference...", None
        proc_folder_direct("htdemucs", "configs/config_htdemucs_bass.yaml", "results/model_htdemucs_bass.th", wav_dir, OUTPUT_FOLDER)

        # BS Roformer runs on the mel_band_roformer folder, where the instrumental
        # is first renamed to "<title>.wav"
        source_path = f'{song_dir}/mel_band_roformer/{formatted_title}_instrumental.wav'
        destination_path = f'{song_dir}/mel_band_roformer/{formatted_title}.wav'
        os.rename(source_path, destination_path)
        yield "Starting BS Roformer inference...", None
        proc_folder_direct("bs_roformer", "configs/config_bs_roformer_instrumental.yaml", "results/model_bs_roformer_instrumental.ckpt", f'{song_dir}/mel_band_roformer', OUTPUT_FOLDER)

        # Clean up and organize files
        yield "Moving input files...", None
        delete_input_files(INPUT_FOLDER)
        yield "Moving stems to parent...", None
        move_stems_to_parent(OUTPUT_FOLDER)
        yield "Combining stems...", None
        output_file = combine_stems_for_all(OUTPUT_FOLDER, "mp3")
        yield "Cleaning up...", None
        delete_folders_and_files(OUTPUT_FOLDER)

        # combine_stems_for_all returns None on missing stems or an encoding failure
        if output_file is None:
            raise RuntimeError("Combining stems did not produce an output file.")
        yield "Audio processing completed successfully.", output_file
    except Exception as e:
        error_msg = f"An error occurred: {str(e)}\n{traceback.format_exc()}"
        logging.error(error_msg)
        yield error_msg, None
# Set up Gradio interface.
# Components are laid out in the order they are created inside the Blocks context.
with gr.Blocks() as demo:
    gr.Markdown("# Music Player and Processor")
    # The URL box is locked (interactive=False), so presumably process_audio only
    # ever receives an empty link from the UI and relies on the uploaded file.
    youtube_url = gr.Textbox(
        label="YouTube Song URL",
        placeholder="This feature is currently disabled. You cannot input a URL.",
        interactive=False
    )
    # The file picker only offers .mp3 files.
    file_upload = gr.File(label="Upload MP3 file", file_types=[".mp3"])
    process_button = gr.Button("Process Audio")
    # Shows the status strings yielded by process_audio.
    log_output = gr.Textbox(label="Processing Log", interactive=False)
    # Receives the combined audio file path yielded at the end of process_audio.
    processed_audio_output = gr.File(label="Processed Audio")
    # process_audio is a generator: each yielded (status, file_path) pair
    # updates log_output and processed_audio_output.
    process_button.click(
        fn=process_audio,
        inputs=[file_upload, youtube_url],
        outputs=[log_output, processed_audio_output],
        show_progress=True
    )
# Launch the Gradio app (runs at import time; there is no __main__ guard).
demo.launch()