import streamlit as st
import numpy as np
import librosa
import librosa.display
import soundfile as sf
import os
import tempfile
import torch
import base64
import io
import matplotlib.pyplot as plt

# Page configuration
st.set_page_config(
    page_title="Music Stem Splitter",
    page_icon="🎵",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Set maximum audio duration (in seconds) and file size (in MB)
MAX_AUDIO_DURATION = 300  # 5 minutes
MAX_FILE_SIZE_MB = 100

# Load pretrained separator model
@st.cache_resource
def load_separator_model():
    try:
        # Import here to avoid loading until needed
        from demucs.pretrained import get_model
        model = get_model('htdemucs')
        model.eval()
        if torch.cuda.is_available():
            model.cuda()
        return model
    except ImportError:
        st.error("Required package 'demucs' not found. Please install it with 'pip install demucs'.")
        return None

# Function to check audio length
def check_audio_length(audio_path):
    try:
        duration = librosa.get_duration(path=audio_path)
        return duration
    except Exception as e:
        st.error(f"Could not determine audio length: {str(e)}")
        return MAX_AUDIO_DURATION + 1  # Return a value that will fail the check

# Function to separate stems from an audio file
def separate_stems(audio_path, model, sample_rate=44100):
    from demucs.apply import apply_model
    import torchaudio

    # Load audio with potential resampling to save memory
    waveform, original_sample_rate = torchaudio.load(audio_path)

    # Resample if needed to optimize memory usage
    if original_sample_rate > sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=sample_rate)
        waveform = resampler(waveform)
        st.info(f"Audio resampled from {original_sample_rate}Hz to {sample_rate}Hz to optimize performance.")
    else:
        sample_rate = original_sample_rate

    # Create a progress bar
    progress_bar = st.progress(0)
    status_text = st.empty()

    # Prepare the model input
    if torch.cuda.is_available():
        waveform = waveform.cuda()

    # Demucs expects the audio as (batch, channels, time)
    if waveform.dim() == 2:  # (channels, time)
        waveform = waveform.unsqueeze(0)

    stems = {}

    # Process and separate stems
    status_text.text("Separating stems... This may take a while depending on the audio length.")
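    # A note on the manual chunking below: concatenating independently separated
    # chunks can leave audible seams at chunk boundaries. A minimal alternative
    # sketch, assuming a recent demucs release, is to let apply_model do the
    # segmentation itself with cross-faded overlaps:
    #
    #     sources = apply_model(model, waveform, split=True, overlap=0.25,
    #                           device="cuda" if torch.cuda.is_available() else "cpu")
    #
    # The split/overlap arguments are an assumption about the installed demucs
    # version; check demucs.apply.apply_model's signature before relying on them.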
    # Optimize memory usage by processing in chunks if needed
    with torch.no_grad():
        # Use larger chunks on GPU, smaller on CPU
        chunk_size = 10 * sample_rate if torch.cuda.is_available() else 5 * sample_rate

        if waveform.shape[-1] > chunk_size and waveform.shape[-1] > 30 * sample_rate:
            # Process in chunks for very long audio
            st.info("Processing long audio in chunks to optimize memory usage...")
            sources = None

            # Calculate number of chunks
            num_chunks = int(np.ceil(waveform.shape[-1] / chunk_size))

            for i in range(num_chunks):
                # Update progress (separation accounts for 70% of the bar)
                progress_bar.progress(i / num_chunks * 0.7)
                status_text.text(f"Processing chunk {i+1}/{num_chunks}...")

                # Extract chunk
                start = i * chunk_size
                end = min(start + chunk_size, waveform.shape[-1])
                chunk = waveform[:, :, start:end]

                # Process chunk
                chunk_sources = apply_model(model, chunk, device="cuda" if torch.cuda.is_available() else "cpu")

                # Concatenate along the time dimension
                if sources is None:
                    sources = chunk_sources
                else:
                    sources = torch.cat([sources, chunk_sources], dim=-1)

                # Clear GPU memory if needed
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
        else:
            # Process entire audio at once for shorter clips
            sources = apply_model(model, waveform, device="cuda" if torch.cuda.is_available() else "cpu")

    # sources is (batch, source, channels, time)
    sources = sources[0]  # Remove batch dimension

    # Save each source (htdemucs output order: drums, bass, other, vocals)
    source_names = ["drums", "bass", "other", "vocals"]
    for i, source_name in enumerate(source_names):
        stems[source_name] = sources[i].cpu().numpy()

        # Update progress (stem saving accounts for 20% of the bar)
        progress_bar.progress(0.7 + (i + 1) / len(source_names) * 0.2)
        status_text.text(f"Processed {source_name} stem ({i+1}/{len(source_names)})")

    # Create visualizations (at reduced resolution to save memory)
    visualizations = {}
    for stem_name, audio_data in stems.items():
        # Create spectrogram visualization
        plt.figure(figsize=(10, 4))

        # Use at most the first 30 seconds of audio for the visualization
        max_samples = min(sample_rate * 30, audio_data.shape[1])
        visualization_data = audio_data[0, :max_samples]

        # Create spectrogram with reduced resolution
        D = librosa.amplitude_to_db(np.abs(librosa.stft(visualization_data, n_fft=1024, hop_length=512)), ref=np.max)
        librosa.display.specshow(D, y_axis='log', x_axis='time', sr=sample_rate)
        plt.title(f'{stem_name.capitalize()} Spectrogram')
        plt.colorbar(format='%+2.0f dB')
        plt.tight_layout()

        # Save figure to bytes
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=100)  # Lower DPI to save memory
        buf.seek(0)
        visualizations[stem_name] = buf
        plt.close()

    # Clear GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Update progress to complete
    progress_bar.progress(1.0)
    status_text.text("Stem separation complete!")

    return stems, sample_rate, visualizations

# Function to create an HTML download link for audio files
def get_binary_file_downloader_html(bin_data, file_label, file_extension):
    b64data = base64.b64encode(bin_data).decode()
    href = (
        f'<a href="data:audio/{file_extension};base64,{b64data}" '
        f'download="{file_label}.{file_extension}">Download {file_label}</a>'
    )
    return href

# Title and description
st.title("🎵 Music Stem Splitter")
st.markdown("""
This application separates music tracks into individual stems:
- **Vocals**: Lead and background vocals
- **Drums**: Drum kit and percussion
- **Bass**: Bass guitar, synth bass, etc.
- **Other**: All other instruments and sounds

Upload an audio file (MP3, WAV, or FLAC) to get started.
""")
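# Deployment note: a minimal requirements.txt sketch, derived from the imports
# used in this file (exact version pins are left to the deployer):
#   streamlit
#   demucs
#   torch
#   torchaudio
#   librosa
#   soundfile
#   matplotlib
#   numpy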
# Add warning about HF Spaces limitations
st.warning(f"""
⚠️ **Hugging Face Spaces Limitations**:
- Maximum file size: {MAX_FILE_SIZE_MB}MB
- Maximum audio duration: {MAX_AUDIO_DURATION} seconds ({MAX_AUDIO_DURATION//60} minutes)
- Processing may take several minutes depending on server load
""")

# Initialize session state for storing results
if 'stems' not in st.session_state:
    st.session_state.stems = None
if 'sample_rate' not in st.session_state:
    st.session_state.sample_rate = None
if 'visualizations' not in st.session_state:
    st.session_state.visualizations = None

# File uploader
st.subheader("Upload Audio File")
uploaded_file = st.file_uploader("Choose an audio file", type=["mp3", "wav", "flac", "ogg"])

# Model loading (only when needed)
model_load_state = st.empty()

# Process the uploaded file
if uploaded_file is not None:
    # Check file size
    file_size_mb = uploaded_file.size / 1e6
    if file_size_mb > MAX_FILE_SIZE_MB:
        st.error(f"File too large: {file_size_mb:.1f}MB. Maximum allowed size is {MAX_FILE_SIZE_MB}MB.")
    else:
        # Display file info
        file_details = {"Filename": uploaded_file.name, "FileSize": f"{file_size_mb:.2f} MB"}
        st.write(file_details)

        # Create a temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
            tmp_file.write(uploaded_file.getvalue())
            tmp_path = tmp_file.name

        # Check audio duration
        audio_duration = check_audio_length(tmp_path)
        if audio_duration > MAX_AUDIO_DURATION:
            st.error(f"Audio duration too long: {audio_duration:.1f} seconds. Maximum allowed duration is {MAX_AUDIO_DURATION} seconds ({MAX_AUDIO_DURATION//60} minutes).")
            # Clean up temporary file
            os.unlink(tmp_path)
        else:
            st.info(f"Audio duration: {audio_duration:.1f} seconds")

            # Load model (with caching for efficiency)
            with model_load_state:
                st.info("Loading AI model... This may take a moment the first time.")
                model = load_separator_model()

            if model is not None:
                # Process button
                if st.button("Split into Stems"):
                    try:
                        # Select processing sample rate based on file duration:
                        # shorter files keep full quality, longer files use a
                        # lower rate to save memory
                        if audio_duration < 60:      # Less than 1 minute
                            processing_sample_rate = 44100
                        elif audio_duration < 180:   # 1-3 minutes
                            processing_sample_rate = 32000
                        else:                        # 3-5 minutes
                            processing_sample_rate = 22050

                        # Perform stem separation
                        st.session_state.stems, st.session_state.sample_rate, st.session_state.visualizations = separate_stems(
                            tmp_path, model, sample_rate=processing_sample_rate
                        )

                        st.success("Stem separation completed! Scroll down to preview and download individual stems.")
                    except Exception as e:
                        st.error(f"An error occurred during processing: {str(e)}")
                        st.info("Try with a shorter audio clip or a different file format.")
            else:
                st.warning("Required packages not available. To run locally, install with 'pip install demucs librosa soundfile'.")

            # Clean up temporary file
            os.unlink(tmp_path)
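# To try the app locally (assuming this file is saved as app.py):
#   pip install streamlit demucs librosa soundfile
#   streamlit run app.py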
# Display results if available
if st.session_state.stems is not None:
    st.header("Separated Stems")

    # Create tabs for each stem
    stem_tabs = st.tabs(["Vocals", "Drums", "Bass", "Other"])

    # Get stem names in display order
    stem_names = ["vocals", "drums", "bass", "other"]

    # Process each stem
    for stem_tab, stem_name in zip(stem_tabs, stem_names):
        with stem_tab:
            # Create columns for audio player and visualization
            col1, col2 = st.columns([1, 1])

            with col1:
                st.subheader(f"{stem_name.capitalize()} Stem")

                # Convert numpy array to audio file for playback
                audio_data = st.session_state.stems[stem_name]

                # Create a temporary buffer for the audio data
                buf = io.BytesIO()
                sf.write(buf, audio_data.T, st.session_state.sample_rate, format='WAV')
                buf.seek(0)

                # Display audio player
                st.audio(buf, format='audio/wav')

                # Download link
                st.markdown(get_binary_file_downloader_html(buf.getvalue(), stem_name, "wav"), unsafe_allow_html=True)

                # Additional information
                if stem_name == "vocals":
                    st.info("Contains lead vocals and backing vocals.")
                elif stem_name == "drums":
                    st.info("Contains drums and percussion elements.")
                elif stem_name == "bass":
                    st.info("Contains bass guitar and low-frequency elements.")
                else:  # other
                    st.info("Contains all other instruments (guitars, keys, synths, etc.).")

            with col2:
                # Display visualization (rewind the buffer in case it was
                # already read on a previous Streamlit rerun)
                if st.session_state.visualizations and stem_name in st.session_state.visualizations:
                    st.session_state.visualizations[stem_name].seek(0)
                    st.image(st.session_state.visualizations[stem_name], caption=f"{stem_name.capitalize()} Spectrogram")

    # Show instructions for downloading all stems
    st.header("Usage Tips")
    st.markdown("""
    ### What can you do with these stems?
    - Create remixes or mashups
    - Practice playing along with isolated instrument tracks
    - Create karaoke versions by removing vocals
    - Analyze individual instrument parts for educational purposes

    ### Next steps:
    1. Download each stem you want to use
    2. Import them into your DAW (Digital Audio Workstation)
    3. Mix, process, and create!
    """)

# Sidebar: about this app
st.sidebar.header("About This App")
st.sidebar.markdown("""
This application uses the Demucs model to separate audio tracks into individual stems.
The model was developed by Facebook AI Research.

### How it works
The separation process uses a deep neural network to identify and isolate:
- Vocals
- Drums
- Bass
- Other instruments

### Source code
[GitHub Repository](https://github.com/huggingface/music-stem-splitter)
(Link to your repo once created)
""")

# Add a note about processing time
st.sidebar.markdown("""
### Processing Time
The processing time depends on:
- Length of the audio file
- Available computational resources
- File quality

For best results, use high-quality audio files without excessive background noise.
""")

# Show model information
st.sidebar.markdown("""
### Model Information
This app uses the HTDemucs model, which is trained to separate music into four stems.

Audio processing is optimized based on file length:
- Short files (< 1 min): 44.1kHz processing
- Medium files (1-3 min): 32kHz processing
- Longer files (3-5 min): 22.05kHz processing
""")
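# Hugging Face Spaces note (an assumption about the hosting environment):
# decoding MP3/OGG uploads generally requires ffmpeg, and on Spaces system
# packages such as ffmpeg can be requested via a packages.txt file at the
# repository root.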