import streamlit as st
import numpy as np
import librosa
import librosa.display  # needed for librosa.display.specshow below
import soundfile as sf
import os
import tempfile
from pathlib import Path
import torch
from tqdm import tqdm
import base64
import io
from PIL import Image
import matplotlib.pyplot as plt
# Page configuration
st.set_page_config(
    page_title="Music Stem Splitter",
    page_icon="🎵",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Set maximum audio duration (in seconds) and file size (in MB)
MAX_AUDIO_DURATION = 300  # 5 minutes
MAX_FILE_SIZE_MB = 100
# Load the pretrained separator model, cached so it is only loaded once per session
@st.cache_resource
def load_separator_model():
    try:
        # Import here to avoid loading until needed
        from demucs.pretrained import get_model
        model = get_model('htdemucs')
        model.eval()
        if torch.cuda.is_available():
            model.cuda()
        return model
    except ImportError:
        st.error("Required package 'demucs' not found. Please install it with 'pip install demucs'.")
        return None
# Function to check audio length
def check_audio_length(audio_path):
    try:
        duration = librosa.get_duration(path=audio_path)
        return duration
    except Exception as e:
        st.error(f"Could not determine audio length: {str(e)}")
        return MAX_AUDIO_DURATION + 1  # Return a value that will fail the check
# Function to separate stems from an audio file
def separate_stems(audio_path, model, sample_rate=44100):
    from demucs.apply import apply_model
    import torchaudio

    # Load audio with potential resampling to save memory
    waveform, original_sample_rate = torchaudio.load(audio_path)

    # Resample if needed to optimize memory usage
    if original_sample_rate > sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=sample_rate)
        waveform = resampler(waveform)
        st.info(f"Audio resampled from {original_sample_rate}Hz to {sample_rate}Hz to optimize performance.")
    else:
        sample_rate = original_sample_rate

    # Mono mixdown (currently unused; kept for potential visualization use)
    if waveform.shape[0] > 1:
        waveform_mono = torch.mean(waveform, dim=0, keepdim=True)
    else:
        waveform_mono = waveform

    # Get the audio length in seconds for progress tracking
    audio_length = waveform.shape[1] / sample_rate

    # Create a progress bar
    progress_bar = st.progress(0)
    status_text = st.empty()

    # Prepare the model input
    if torch.cuda.is_available():
        waveform = waveform.cuda()

    # Demucs expects audio shaped (batch, channels, time)
    if waveform.dim() == 2:  # (channels, time)
        waveform = waveform.unsqueeze(0)
    # Create a temp directory for saving stems
    temp_dir = tempfile.mkdtemp()
    stems = {}

    # Process and separate stems
    status_text.text("Separating stems... This may take a while depending on the audio length.")

    # Optimize memory usage by processing in chunks if needed
    with torch.no_grad():
        # Use smaller chunks for CPU, larger for GPU
        chunk_size = 10 * sample_rate if torch.cuda.is_available() else 5 * sample_rate

        if waveform.shape[-1] > chunk_size and waveform.shape[-1] > 30 * sample_rate:
            # Process in chunks for very long audio
            st.info("Processing long audio in chunks to optimize memory usage...")
            sources = []

            # Calculate number of chunks
            num_chunks = int(np.ceil(waveform.shape[-1] / chunk_size))

            for i in range(num_chunks):
                # Update progress
                progress = i / num_chunks * 0.7  # 70% of the bar is reserved for separation
                progress_bar.progress(progress)
                status_text.text(f"Processing chunk {i+1}/{num_chunks}...")

                # Extract chunk
                start = i * chunk_size
                end = min(start + chunk_size, waveform.shape[-1])
                chunk = waveform[:, :, start:end]

                # Process chunk
                chunk_sources = apply_model(model, chunk, device="cuda" if torch.cuda.is_available() else "cpu")

                # Append to sources
                if i == 0:
                    sources = chunk_sources
                else:
                    # Concatenate along the time dimension
                    sources = torch.cat([sources, chunk_sources], dim=-1)

                # Clear GPU memory if needed
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
        else:
            # Process the entire audio at once for shorter clips
            sources = apply_model(model, waveform, device="cuda" if torch.cuda.is_available() else "cpu")

    # sources is (batch, source, channels, time)
    sources = sources[0]  # Remove batch dimension
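
    # Note (assumption about the installed demucs version): stitching per-chunk outputs
    # without any overlap, as above, can leave audible seams at chunk boundaries. Recent
    # demucs releases let apply_model do its own windowed segmentation, roughly:
    #     apply_model(model, waveform, split=True, overlap=0.25,
    #                 device="cuda" if torch.cuda.is_available() else "cpu")
    # which blends overlapping windows internally; check the installed demucs API before
    # relying on these parameters.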
    # Save each source (this order matches the htdemucs output: drums, bass, other, vocals)
    source_names = ["drums", "bass", "other", "vocals"]
    for i, source_name in enumerate(source_names):
        stems[source_name] = sources[i].cpu().numpy()

        # Update progress
        progress = 0.7 + (i + 1) / len(source_names) * 0.2  # 20% of the bar for stem extraction
        progress_bar.progress(progress)
        status_text.text(f"Processed {source_name} stem ({i+1}/{len(source_names)})")
    # Create visualizations (at reduced resolution to save memory)
    visualizations = {}
    for stem_name, audio_data in stems.items():
        # Create spectrogram visualization
        plt.figure(figsize=(10, 4))

        # Use a smaller portion of audio for visualization if it's too long
        max_samples = min(sample_rate * 30, audio_data.shape[1])  # 30 seconds max
        visualization_data = audio_data[0, :max_samples] if audio_data.shape[1] > max_samples else audio_data[0]

        # Create spectrogram with reduced resolution
        D = librosa.amplitude_to_db(np.abs(librosa.stft(visualization_data, n_fft=1024, hop_length=512)), ref=np.max)
        plt.subplot(1, 1, 1)
        librosa.display.specshow(D, y_axis='log', x_axis='time', sr=sample_rate)
        plt.title(f'{stem_name.capitalize()} Spectrogram')
        plt.colorbar(format='%+2.0f dB')
        plt.tight_layout()

        # Save figure to bytes
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=100)  # Lower DPI to save memory
        buf.seek(0)
        visualizations[stem_name] = buf
        plt.close()

    # Clear GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Update progress to complete
    progress_bar.progress(1.0)
    status_text.text("Stem separation complete!")

    return stems, sample_rate, visualizations
# Function to create a download link for audio files
def get_binary_file_downloader_html(bin_data, file_label, file_extension):
    b64data = base64.b64encode(bin_data).decode()
    href = f'<a href="data:audio/{file_extension};base64,{b64data}" download="{file_label}.{file_extension}">Download {file_label}</a>'
    return href
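
# Alternative worth noting (a sketch, not wired into the app): Streamlit's built-in
# st.download_button avoids embedding large base64 blobs in the page HTML, which can get
# slow for multi-minute WAV stems, e.g.:
#     st.download_button(label="Download vocals", data=wav_bytes,
#                        file_name="vocals.wav", mime="audio/wav")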
# Title and description
st.title("🎵 Music Stem Splitter")
st.markdown("""
This application separates music tracks into individual stems:

- **Vocals**: Lead and background vocals
- **Drums**: Drum kit and percussion
- **Bass**: Bass guitar, synth bass, etc.
- **Other**: All other instruments and sounds

Upload an audio file (MP3, WAV, FLAC, or OGG) to get started.
""")

# Add warning about HF Spaces limitations
st.warning(f"""
⚠️ **Hugging Face Spaces Limitations**:

- Maximum file size: {MAX_FILE_SIZE_MB}MB
- Maximum audio duration: {MAX_AUDIO_DURATION} seconds ({MAX_AUDIO_DURATION//60} minutes)
- Processing may take several minutes depending on server load
""")
# Initialize session state for storing results
if 'stems' not in st.session_state:
    st.session_state.stems = None
if 'sample_rate' not in st.session_state:
    st.session_state.sample_rate = None
if 'visualizations' not in st.session_state:
    st.session_state.visualizations = None
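
# Streamlit reruns this script from the top on every interaction (uploads, button clicks),
# so separation results are kept in st.session_state to survive those reruns instead of
# being recomputed or lost.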
# File uploader
st.subheader("Upload Audio File")
uploaded_file = st.file_uploader("Choose an audio file", type=["mp3", "wav", "flac", "ogg"])

# Placeholder used to show model-loading status messages
model_load_state = st.empty()
# Process the uploaded file
if uploaded_file is not None:
    # Check file size
    file_size_mb = uploaded_file.size / 1e6
    if file_size_mb > MAX_FILE_SIZE_MB:
        st.error(f"File too large: {file_size_mb:.1f}MB. Maximum allowed size is {MAX_FILE_SIZE_MB}MB.")
    else:
        # Display file info
        file_details = {"Filename": uploaded_file.name, "FileSize": f"{file_size_mb:.2f} MB"}
        st.write(file_details)

        # Write the upload to a temporary file so librosa/torchaudio can read it from disk
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
            tmp_file.write(uploaded_file.getvalue())
            tmp_path = tmp_file.name

        # Check audio duration
        audio_duration = check_audio_length(tmp_path)
        if audio_duration > MAX_AUDIO_DURATION:
            st.error(f"Audio duration too long: {audio_duration:.1f} seconds. Maximum allowed duration is {MAX_AUDIO_DURATION} seconds ({MAX_AUDIO_DURATION//60} minutes).")
            # Clean up temporary file
            os.unlink(tmp_path)
        else:
            st.info(f"Audio duration: {audio_duration:.1f} seconds")

            # Load model (cached by load_separator_model, so it is only loaded once)
            with model_load_state:
                st.info("Loading AI model... This may take a moment the first time.")
                model = load_separator_model()

            if model is not None:
                # Process button
                if st.button("Split into Stems"):
                    try:
                        # Select processing sample rate based on file duration:
                        # shorter files keep higher quality, longer files are downsampled to save memory
                        if audio_duration < 60:  # Less than 1 minute
                            processing_sample_rate = 44100
                        elif audio_duration < 180:  # 1-3 minutes
                            processing_sample_rate = 32000
                        else:  # 3-5 minutes
                            processing_sample_rate = 22050

                        # Perform stem separation
                        st.session_state.stems, st.session_state.sample_rate, st.session_state.visualizations = separate_stems(
                            tmp_path, model, sample_rate=processing_sample_rate
                        )
                        st.success("Stem separation completed! Scroll down to preview and download individual stems.")
                    except Exception as e:
                        st.error(f"An error occurred during processing: {str(e)}")
                        st.info("Try with a shorter audio clip or a different file format.")
            else:
                st.warning("Required packages not available. To run locally, install with 'pip install demucs librosa soundfile'.")

            # Clean up temporary file
            os.unlink(tmp_path)
# Display results if available
if st.session_state.stems is not None:
    st.header("Separated Stems")

    # Create tabs for each stem
    stem_tabs = st.tabs(["Vocals", "Drums", "Bass", "Other"])

    # Stem names in the same order as the tabs
    stem_names = ["vocals", "drums", "bass", "other"]

    # Process each stem
    for i, (stem_tab, stem_name) in enumerate(zip(stem_tabs, stem_names)):
        with stem_tab:
            # Create columns for audio player and visualization
            col1, col2 = st.columns([1, 1])

            with col1:
                st.subheader(f"{stem_name.capitalize()} Stem")

                # Stem audio as a (channels, samples) numpy array
                audio_data = st.session_state.stems[stem_name]

                # Write the stem to an in-memory WAV buffer; soundfile expects
                # (samples, channels), hence the transpose
                buf = io.BytesIO()
                sf.write(buf, audio_data.T, st.session_state.sample_rate, format='WAV')
                buf.seek(0)

                # Display audio player
                st.audio(buf, format='audio/wav')

                # Download link
                st.markdown(get_binary_file_downloader_html(buf.getvalue(), f"{stem_name}", "wav"), unsafe_allow_html=True)

                # Additional information
                if stem_name == "vocals":
                    st.info("Contains lead vocals and backing vocals.")
                elif stem_name == "drums":
                    st.info("Contains drums and percussion elements.")
                elif stem_name == "bass":
                    st.info("Contains bass guitar and low-frequency elements.")
                else:  # other
                    st.info("Contains all other instruments (guitars, keys, synths, etc.).")

            with col2:
                # Display visualization
                if st.session_state.visualizations and stem_name in st.session_state.visualizations:
                    st.image(st.session_state.visualizations[stem_name], caption=f"{stem_name.capitalize()} Spectrogram")
# Show instructions for using the downloaded stems
st.header("Usage Tips")
st.markdown("""
### What can you do with these stems?
- Create remixes or mashups
- Practice playing along with isolated instrument tracks
- Create karaoke versions by removing vocals
- Analyze individual instrument parts for educational purposes

### Next steps:
1. Download each stem you want to use
2. Import them into your DAW (Digital Audio Workstation)
3. Mix, process, and create!
""")
# Sidebar: about the app
st.sidebar.header("About This App")
st.sidebar.markdown("""
This application uses the Demucs model to separate audio tracks into individual stems. The model was developed by Facebook AI Research.

### How it works
The separation process uses a deep neural network to identify and isolate:
- Vocals
- Drums
- Bass
- Other instruments

### Source code
[GitHub Repository](https://github.com/huggingface/music-stem-splitter)
(Link to your repo once created)
""")
# Add a note about processing time
st.sidebar.markdown("""
### Processing Time
The processing time depends on:
- Length of the audio file
- Available computational resources
- File quality

For best results, use high-quality audio files without excessive background noise.
""")
# Show model information
st.sidebar.markdown("""
### Model Information
This app uses the HTDemucs model, which is trained to separate music into four stems.

Audio processing is optimized based on file length:
- Short files (< 1 min): 44.1kHz processing
- Medium files (1-3 min): 32kHz processing
- Longer files (3-5 min): 22.05kHz processing
""")