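"""Music Stem Splitter: a Streamlit app that separates an uploaded track
into vocals, drums, bass, and other stems using the pretrained Demucs
(htdemucs) model."""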
import streamlit as st
import numpy as np
import librosa
import librosa.display  # must be imported explicitly for librosa.display.specshow
import soundfile as sf
import os
import tempfile
import torch
import base64
import io
import matplotlib.pyplot as plt
# Page configuration
st.set_page_config(
    page_title="Music Stem Splitter",
    page_icon="🎵",
    layout="wide",
    initial_sidebar_state="expanded"
)
# Set maximum audio duration (in seconds) and file size (in MB)
MAX_AUDIO_DURATION = 300 # 5 minutes
MAX_FILE_SIZE_MB = 100
# Load pretrained separator model
@st.cache_resource
def load_separator_model():
    try:
        # Import here to avoid loading until needed
        from demucs.pretrained import get_model
        model = get_model('htdemucs')
        model.eval()
        if torch.cuda.is_available():
            model.cuda()
        return model
    except ImportError:
        st.error("Required package 'demucs' not found. Please install it with 'pip install demucs'.")
        return None
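# st.cache_resource keeps one model instance alive for the whole process, so
# the weights are downloaded and loaded only once across reruns and sessions.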
# Function to check audio length
def check_audio_length(audio_path):
    try:
        duration = librosa.get_duration(path=audio_path)
        return duration
    except Exception as e:
        st.error(f"Could not determine audio length: {str(e)}")
        return MAX_AUDIO_DURATION + 1  # Return a value that will fail the check
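# NOTE: librosa.get_duration(path=...) assumes librosa >= 0.10; older releases
# used the filename=... keyword instead.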
# Function to separate stems from an audio file
def separate_stems(audio_path, model, sample_rate=44100):
    from demucs.apply import apply_model
    import torchaudio

    # Load audio with potential resampling to save memory
    waveform, original_sample_rate = torchaudio.load(audio_path)

    # Resample if needed to optimize memory usage
    if original_sample_rate > sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=sample_rate)
        waveform = resampler(waveform)
        st.info(f"Audio resampled from {original_sample_rate}Hz to {sample_rate}Hz to optimize performance.")
    else:
        sample_rate = original_sample_rate

    # htdemucs is trained on stereo input; duplicate the channel for mono files
    if waveform.shape[0] == 1:
        waveform = waveform.repeat(2, 1)

    # Create a progress bar
    progress_bar = st.progress(0)
    status_text = st.empty()

    # Prepare the model input
    if torch.cuda.is_available():
        waveform = waveform.cuda()

    # For Demucs, we need the audio as (batch, channels, time)
    if waveform.dim() == 2:  # (channels, time)
        waveform = waveform.unsqueeze(0)

    stems = {}

    # Process and separate stems
    status_text.text("Separating stems... This may take a while depending on the audio length.")
    # Optimize memory usage by processing in chunks if needed
    with torch.no_grad():
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Use smaller chunks for CPU, larger for GPU
        chunk_size = 10 * sample_rate if torch.cuda.is_available() else 5 * sample_rate

        if waveform.shape[-1] > chunk_size and waveform.shape[-1] > 30 * sample_rate:
            # Process in chunks for very long audio
            st.info("Processing long audio in chunks to optimize memory usage...")
            sources = None

            # Calculate number of chunks
            num_chunks = int(np.ceil(waveform.shape[-1] / chunk_size))

            for i in range(num_chunks):
                # Update progress (70% of the bar is reserved for separation)
                progress_bar.progress(i / num_chunks * 0.7)
                status_text.text(f"Processing chunk {i+1}/{num_chunks}...")

                # Extract chunk
                start = i * chunk_size
                end = min(start + chunk_size, waveform.shape[-1])
                chunk = waveform[:, :, start:end]

                # Process chunk
                chunk_sources = apply_model(model, chunk, device=device)

                # Concatenate along the time dimension
                if sources is None:
                    sources = chunk_sources
                else:
                    sources = torch.cat([sources, chunk_sources], dim=-1)

                # Clear GPU memory if needed
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
        else:
            # Process entire audio at once for shorter clips
            sources = apply_model(model, waveform, device=device)

    # sources is (batch, source, channels, time)
    sources = sources[0]  # Remove batch dimension
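    # NOTE: the pretrained 'htdemucs' checkpoint emits its stems in the fixed
    # order drums, bass, other, vocals (mirrored by model.sources). Also worth
    # knowing: demucs' apply_model supports split=True with an overlap setting
    # for built-in overlapped chunking, which may avoid the audible seams that
    # the naive concatenation above can introduce at chunk boundaries.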
    # Save each source
    source_names = ["drums", "bass", "other", "vocals"]
    for i, source_name in enumerate(source_names):
        stems[source_name] = sources[i].cpu().numpy()

        # Update progress (20% of the bar is reserved for stem saving)
        progress_bar.progress(0.7 + (i + 1) / len(source_names) * 0.2)
        status_text.text(f"Processed {source_name} stem ({i+1}/{len(source_names)})")
    # Create visualizations (at reduced resolution to save memory)
    visualizations = {}
    for stem_name, audio_data in stems.items():
        # Create spectrogram visualization
        plt.figure(figsize=(10, 4))

        # Use at most the first 30 seconds of audio for visualization
        max_samples = min(sample_rate * 30, audio_data.shape[1])
        visualization_data = audio_data[0, :max_samples]

        # Create spectrogram with reduced resolution
        D = librosa.amplitude_to_db(np.abs(librosa.stft(visualization_data, n_fft=1024, hop_length=512)), ref=np.max)
        librosa.display.specshow(D, y_axis='log', x_axis='time', sr=sample_rate)
        plt.title(f'{stem_name.capitalize()} Spectrogram')
        plt.colorbar(format='%+2.0f dB')
        plt.tight_layout()

        # Save the figure as raw PNG bytes: plain bytes are simpler to reuse
        # across Streamlit reruns than a stateful BytesIO object
        buf = io.BytesIO()
        plt.savefig(buf, format='png', dpi=100)  # Lower DPI to save memory
        visualizations[stem_name] = buf.getvalue()
        plt.close()
    # Clear GPU memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Update progress to complete
    progress_bar.progress(1.0)
    status_text.text("Stem separation complete!")

    return stems, sample_rate, visualizations
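# Example (illustrative): a 30 s stereo clip processed at 44.1 kHz yields
# stems["vocals"] as a float numpy array of shape (2, 1323000), i.e.
# (channels, samples), with sample_rate returned as 44100.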
# Function to create a download link for audio files
def get_binary_file_downloader_html(bin_data, file_label, file_extension):
    b64data = base64.b64encode(bin_data).decode()
    href = f'<a href="data:audio/{file_extension};base64,{b64data}" download="{file_label}.{file_extension}">Download {file_label}</a>'
    return href
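# The generated markup looks roughly like this (base64 payload truncated):
# <a href="data:audio/wav;base64,UklGR..." download="vocals.wav">Download vocals</a>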
# Title and description
st.title("🎵 Music Stem Splitter")
st.markdown("""
This application separates music tracks into individual stems:
- **Vocals**: Lead and background vocals
- **Drums**: Drum kit and percussion
- **Bass**: Bass guitar, synth bass, etc.
- **Other**: All other instruments and sounds
Upload an audio file (MP3, WAV, FLAC, or OGG) to get started.
""")
# Add warning about HF Spaces limitations
st.warning(f"""
⚠️ **Hugging Face Spaces Limitations**:
- Maximum file size: {MAX_FILE_SIZE_MB}MB
- Maximum audio duration: {MAX_AUDIO_DURATION} seconds ({MAX_AUDIO_DURATION//60} minutes)
- Processing may take several minutes depending on server load
""")
# Initialize session state for storing results
if 'stems' not in st.session_state:
    st.session_state.stems = None
if 'sample_rate' not in st.session_state:
    st.session_state.sample_rate = None
if 'visualizations' not in st.session_state:
    st.session_state.visualizations = None
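# Streamlit reruns this script from the top on every interaction; parking the
# separation results in session_state lets them survive those reruns.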
# File uploader
st.subheader("Upload Audio File")
uploaded_file = st.file_uploader("Choose an audio file", type=["mp3", "wav", "flac", "ogg"])
# Model loading (only when needed)
model_load_state = st.empty()
# Process the uploaded file
if uploaded_file is not None:
    # Check file size
    file_size_mb = uploaded_file.size / 1e6
    if file_size_mb > MAX_FILE_SIZE_MB:
        st.error(f"File too large: {file_size_mb:.1f}MB. Maximum allowed size is {MAX_FILE_SIZE_MB}MB.")
    else:
        # Display file info
        file_details = {"Filename": uploaded_file.name, "FileSize": f"{file_size_mb:.2f} MB"}
        st.write(file_details)

        # Create a temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
            tmp_file.write(uploaded_file.getvalue())
            tmp_path = tmp_file.name

        # Check audio duration
        audio_duration = check_audio_length(tmp_path)
        if audio_duration > MAX_AUDIO_DURATION:
            st.error(f"Audio duration too long: {audio_duration:.1f} seconds. Maximum allowed duration is {MAX_AUDIO_DURATION} seconds ({MAX_AUDIO_DURATION//60} minutes).")
            # Clean up temporary file
            os.unlink(tmp_path)
        else:
            st.info(f"Audio duration: {audio_duration:.1f} seconds")

            # Load model (with caching for efficiency)
            with model_load_state:
                st.info("Loading AI model... This may take a moment the first time.")
                model = load_separator_model()

            if model is not None:
                # Process button
                if st.button("Split into Stems"):
                    try:
                        # Select processing sample rate based on file duration
                        # Shorter files can use higher quality, longer files use lower to save memory
                        if audio_duration < 60:  # Less than 1 minute
                            processing_sample_rate = 44100
                        elif audio_duration < 180:  # 1-3 minutes
                            processing_sample_rate = 32000
                        else:  # 3-5 minutes
                            processing_sample_rate = 22050

                        # Perform stem separation
                        st.session_state.stems, st.session_state.sample_rate, st.session_state.visualizations = separate_stems(
                            tmp_path, model, sample_rate=processing_sample_rate
                        )
                        st.success("Stem separation completed! Scroll down to preview and download individual stems.")
                    except Exception as e:
                        st.error(f"An error occurred during processing: {str(e)}")
                        st.info("Try with a shorter audio clip or a different file format.")
            else:
                st.warning("Required packages not available. To run locally, install with 'pip install demucs librosa soundfile'")

            # Clean up temporary file
            os.unlink(tmp_path)
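# Deleting tmp_path at the end of each run is safe even though pressing the
# button triggers a rerun: the uploaded bytes persist in the uploader widget,
# so the temporary file is recreated above before it is needed again.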
# Display results if available
if st.session_state.stems is not None:
    st.header("Separated Stems")

    # Create tabs for each stem
    stem_tabs = st.tabs(["Vocals", "Drums", "Bass", "Other"])

    # Stem names in the same order as the tabs above
    stem_names = ["vocals", "drums", "bass", "other"]

    # Process each stem
    for stem_tab, stem_name in zip(stem_tabs, stem_names):
        with stem_tab:
            # Create columns for audio player and visualization
            col1, col2 = st.columns([1, 1])

            with col1:
                st.subheader(f"{stem_name.capitalize()} Stem")

                # Convert the (channels, time) numpy array to a WAV buffer for playback
                audio_data = st.session_state.stems[stem_name]
                buf = io.BytesIO()
                sf.write(buf, audio_data.T, st.session_state.sample_rate, format='WAV')
                buf.seek(0)

                # Display audio player
                st.audio(buf, format='audio/wav')

                # Download button
                st.markdown(get_binary_file_downloader_html(buf.getvalue(), f"{stem_name}", "wav"), unsafe_allow_html=True)

                # Additional information
                if stem_name == "vocals":
                    st.info("Contains lead vocals and backing vocals.")
                elif stem_name == "drums":
                    st.info("Contains drums and percussion elements.")
                elif stem_name == "bass":
                    st.info("Contains bass guitar and low-frequency elements.")
                else:  # other
                    st.info("Contains all other instruments (guitars, keys, synths, etc).")

            with col2:
                # Display visualization
                if st.session_state.visualizations and stem_name in st.session_state.visualizations:
                    st.image(st.session_state.visualizations[stem_name], caption=f"{stem_name.capitalize()} Spectrogram")
# Show instructions for downloading all stems
st.header("Usage Tips")
st.markdown("""
### What can you do with these stems?
- Create remixes or mashups
- Practice playing along with isolated instrument tracks
- Create karaoke versions by removing vocals
- Analyze individual instrument parts for educational purposes
### Next steps:
1. Download each stem you want to use
2. Import them into your DAW (Digital Audio Workstation)
3. Mix, process, and create!
""")
# Add instructions for local deployment
st.sidebar.header("About This App")
st.sidebar.markdown("""
This application uses the Demucs model to separate audio tracks into individual stems. The model was developed by Facebook AI Research.
### How it works
The separation process uses a deep neural network to identify and isolate:
- Vocals
- Drums
- Bass
- Other instruments
### Source code
[GitHub Repository](https://github.com/huggingface/music-stem-splitter)
(Link to your repo once created)
""")
# Add a note about processing time
st.sidebar.markdown("""
### Processing Time
The processing time depends on:
- Length of the audio file
- Available computational resources
- File quality
For best results, use high-quality audio files without excessive background noise.
""")
# Show model information
st.sidebar.markdown("""
### Model Information
This app uses the HTDemucs model, which is trained to separate music into four stems.
Audio processing is optimized based on file length:
- Short files (< 1 min): 44.1kHz processing
- Medium files (1-3 min): 32kHz processing
- Longer files (3-5 min): 22.05kHz processing
""")