File size: 15,105 Bytes
5f8b972
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
import streamlit as st
import numpy as np
import librosa
import soundfile as sf
import os
import tempfile
from pathlib import Path
import torch
from tqdm import tqdm
import base64
import io
from PIL import Image
import matplotlib.pyplot as plt

# Configure the browser tab and layout before any other st.* output is produced.
st.set_page_config(page_title="Music Stem Splitter",
                   page_icon="🎵",
                   layout="wide",
                   initial_sidebar_state="expanded")

# Upload limits enforced by the UI further down.
MAX_AUDIO_DURATION = 5 * 60  # seconds (5 minutes)
MAX_FILE_SIZE_MB = 100  # compared against uploaded_file.size / 1e6

# Load pretrained separator model
@st.cache_resource
def load_separator_model():
    try:
        # Import here to avoid loading until needed
        from demucs.pretrained import get_model
        model = get_model('htdemucs')
        model.eval()
        if torch.cuda.is_available():
            model.cuda()
        return model
    except ImportError:
        st.error("Required package 'demucs' not found. Please install it with 'pip install demucs'.")
        return None

def check_audio_length(audio_path):
    """Return the duration, in seconds, of the audio file at *audio_path*.

    When librosa cannot read the file, an error is shown in the UI. The
    function then returns ``MAX_AUDIO_DURATION + 1``, so that the caller's
    duration-limit check rejects the file.
    """
    try:
        return librosa.get_duration(path=audio_path)
    except Exception as exc:
        st.error(f"Could not determine audio length: {str(exc)}")
    # Sentinel just above the limit: guarantees the upload is refused.
    return MAX_AUDIO_DURATION + 1

def _match_channels(waveform, channels):
    """Return a (channels, time) copy/view of *waveform* with *channels* rows.

    Mono input is duplicated, extra channels are dropped, and a mono target
    averages all channels. Any other mismatch raises ``ValueError``.
    """
    source_channels = waveform.shape[0]
    if source_channels == channels:
        return waveform
    if channels == 1:
        return waveform.mean(dim=0, keepdim=True)
    if source_channels == 1:
        return waveform.repeat(channels, 1)
    if source_channels > channels:
        return waveform[:channels]
    raise ValueError(
        f"Cannot convert audio with {source_channels} channels to {channels} channels."
    )


def _render_spectrogram(signal, sample_rate, title):
    """Render a log-frequency dB spectrogram of the 1-D *signal* as a PNG buffer."""
    # librosa.display is a submodule; import it explicitly rather than relying
    # on `import librosa` having exposed it.
    import librosa.display

    magnitude = np.abs(librosa.stft(signal, n_fft=1024, hop_length=512))
    decibels = librosa.amplitude_to_db(magnitude, ref=np.max)

    # Use explicit figure/axes handles instead of pyplot's implicit
    # "current figure", and always close the figure to free its memory.
    fig, ax = plt.subplots(figsize=(10, 4))
    try:
        mesh = librosa.display.specshow(decibels, y_axis='log', x_axis='time', sr=sample_rate, ax=ax)
        ax.set_title(title)
        fig.colorbar(mesh, ax=ax, format='%+2.0f dB')
        fig.tight_layout()
        png = io.BytesIO()
        fig.savefig(png, format='png', dpi=100)  # modest DPI keeps images small
    finally:
        plt.close(fig)
    png.seek(0)
    return png


def separate_stems(audio_path, model, sample_rate=44100):
    """Split the audio file at *audio_path* into stems using a demucs *model*.

    Progress is reported through a Streamlit progress bar and status line.

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``.
        model: Pretrained demucs model, e.g. from ``load_separator_model()``.
        sample_rate: Maximum processing rate in Hz. Audio above this rate is
            resampled down to it; audio at or below it is processed as is.

    Returns:
        ``(stems, sample_rate, visualizations)``:
        ``stems`` maps each source name to a ``(channels, time)`` NumPy array;
        ``sample_rate`` is the rate of those arrays;
        ``visualizations`` maps each source name to a PNG spectrogram stored
        in an ``io.BytesIO``.

    Raises:
        ValueError: If the file's channel count cannot be adapted to the
            model's channel count.
    """
    from demucs.apply import apply_model
    import torchaudio

    device = "cuda" if torch.cuda.is_available() else "cpu"

    waveform, original_sample_rate = torchaudio.load(audio_path)

    # Only ever downsample (to bound memory/compute); never upsample.
    if original_sample_rate > sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=sample_rate)
        waveform = resampler(waveform)
        st.info(f"Audio resampled from {original_sample_rate}Hz to {sample_rate}Hz to optimize performance.")
    else:
        sample_rate = original_sample_rate
    # NOTE(review): the model is fed audio at `sample_rate`, which may differ
    # from the rate it was trained at (demucs models expose `model.samplerate`).
    # Lower processing rates likely degrade separation quality; confirm that
    # this memory/quality trade-off is intended.

    # Adapt the channel count to what the model declares. A mono tensor would
    # not match a stereo model's input layer.
    waveform = _match_channels(waveform, getattr(model, "audio_channels", 2))

    progress_bar = st.progress(0)
    status_text = st.empty()

    # demucs takes (batch, channels, time).
    batch = waveform.unsqueeze(0).to(device)
    total_samples = batch.shape[-1]

    status_text.text("Separating stems... This may take a while depending on the audio length.")

    with torch.no_grad():
        # Smaller chunks on CPU, larger on GPU.
        chunk_size = (10 if device == "cuda" else 5) * sample_rate

        if total_samples > 30 * sample_rate:
            st.info("Processing long audio in chunks to optimize memory usage...")
            num_chunks = (total_samples + chunk_size - 1) // chunk_size
            pieces = []
            for index, start in enumerate(range(0, total_samples, chunk_size)):
                # Separation accounts for the first 70% of the progress bar.
                progress_bar.progress(index / num_chunks * 0.7)
                status_text.text(f"Processing chunk {index+1}/{num_chunks}...")

                chunk = batch[:, :, start:start + chunk_size]
                # Move each result to host memory, so device memory holds only
                # one chunk's output at a time; concatenate once at the end.
                pieces.append(apply_model(model, chunk, device=device).cpu())

                if device == "cuda":
                    torch.cuda.empty_cache()
            # NOTE(review): chunks are separated independently with no overlap,
            # so boundaries may show small discontinuities. demucs.apply_model
            # appears to split internally with overlap already; confirm whether
            # this manual chunking can be dropped.
            sources = torch.cat(pieces, dim=-1)
        else:
            sources = apply_model(model, batch, device=device)

    # (batch, source, channels, time) -> (source, channels, time)
    sources = sources[0]
    del batch  # release the (possibly GPU-resident) input before plotting

    source_names = list(getattr(model, "sources", ["drums", "bass", "other", "vocals"]))
    stems = {}
    for index, name in enumerate(source_names):
        stems[name] = sources[index].cpu().numpy()
        # Collecting stems accounts for the next 20% of the progress bar.
        progress_bar.progress(0.7 + (index + 1) / len(source_names) * 0.2)
        status_text.text(f"Processed {name} stem ({index+1}/{len(source_names)})")

    # Plot at most the first 30 seconds of the first channel of each stem.
    max_samples = 30 * sample_rate
    visualizations = {
        name: _render_spectrogram(audio[0, :max_samples], sample_rate, f'{name.capitalize()} Spectrogram')
        for name, audio in stems.items()
    }

    if device == "cuda":
        torch.cuda.empty_cache()

    progress_bar.progress(1.0)
    status_text.text("Stem separation complete!")

    return stems, sample_rate, visualizations

def get_binary_file_downloader_html(bin_data, file_label, file_extension):
    """Return an HTML ``<a>`` tag that downloads *bin_data* as ``file_label.file_extension``.

    The bytes are embedded inline as a base64 ``data:audio/<file_extension>``
    URI. ``file_label`` is inserted into the HTML without escaping, so only
    trusted labels should be passed.
    """
    encoded = base64.b64encode(bin_data).decode()
    data_uri = f"data:audio/{file_extension};base64,{encoded}"
    filename = f"{file_label}.{file_extension}"
    return f'<a href="{data_uri}" download="{filename}">Download {file_label}</a>'

# Title and description
st.title("🎵 Music Stem Splitter")

# Intro text. The formats listed here should match the `type=` list passed to
# st.file_uploader further down (mp3, wav, flac, ogg).
st.markdown("""
This application separates music tracks into individual stems:
- **Vocals**: Lead and background vocals
- **Drums**: Drum kit and percussion
- **Bass**: Bass guitar, synth bass, etc.
- **Other**: All other instruments and sounds

Upload an audio file (MP3, WAV, FLAC, or OGG) to get started.
""")

# Surface the limits that the upload handler enforces (file size / duration).
st.warning(f"""
⚠️ **Hugging Face Spaces Limitations**: 
- Maximum file size: {MAX_FILE_SIZE_MB}MB
- Maximum audio duration: {MAX_AUDIO_DURATION} seconds ({MAX_AUDIO_DURATION//60} minutes)
- Processing may take several minutes depending on server load
""")

# Session-state slots that survive Streamlit reruns. Each starts as None and is
# filled from separate_stems() once the user runs a separation.
for _state_key in ("stems", "sample_rate", "visualizations"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
del _state_key

# File uploader
st.subheader("Upload Audio File")
uploaded_file = st.file_uploader("Choose an audio file", type=["mp3", "wav", "flac", "ogg"])

# Placeholder for the transient "loading model" message; cleared once loaded.
model_load_state = st.empty()

# Process the uploaded file
if uploaded_file is not None:
    # Upload size in decimal megabytes.
    file_size_mb = uploaded_file.size / 1e6

    if file_size_mb > MAX_FILE_SIZE_MB:
        st.error(f"File too large: {file_size_mb:.1f}MB. Maximum allowed size is {MAX_FILE_SIZE_MB}MB.")
    else:
        # Display file info
        file_details = {"Filename": uploaded_file.name, "FileSize": f"{file_size_mb:.2f} MB"}
        st.write(file_details)

        # The audio readers work from a path, so persist the upload to disk.
        # Keep the original extension; decoders may use it to detect the format.
        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
            tmp_file.write(uploaded_file.getvalue())
            tmp_path = tmp_file.name

        # Everything that uses tmp_path runs inside try/finally, so the file is
        # removed even if duration probing or model loading raises.
        try:
            audio_duration = check_audio_length(tmp_path)

            if audio_duration > MAX_AUDIO_DURATION:
                st.error(f"Audio duration too long: {audio_duration:.1f} seconds. Maximum allowed duration is {MAX_AUDIO_DURATION} seconds ({MAX_AUDIO_DURATION//60} minutes).")
            else:
                st.info(f"Audio duration: {audio_duration:.1f} seconds")

                # Show a loading notice only while the (cached) model loads.
                model_load_state.info("Loading AI model... This may take a moment the first time.")
                model = load_separator_model()
                model_load_state.empty()

                if model is not None:
                    # Process button
                    if st.button("Split into Stems"):
                        try:
                            # Shorter files use a higher processing rate; longer
                            # files use a lower one to save memory.
                            if audio_duration < 60:  # Less than 1 minute
                                processing_sample_rate = 44100
                            elif audio_duration < 180:  # 1-3 minutes
                                processing_sample_rate = 32000
                            else:  # 3-5 minutes
                                processing_sample_rate = 22050

                            # Perform stem separation; results persist across reruns.
                            st.session_state.stems, st.session_state.sample_rate, st.session_state.visualizations = separate_stems(
                                tmp_path, model, sample_rate=processing_sample_rate
                            )
                            st.success("Stem separation completed! Scroll down to preview and download individual stems.")
                        except Exception as e:
                            # UI boundary: report any processing failure to the user.
                            st.error(f"An error occurred during processing: {str(e)}")
                            st.info("Try with a shorter audio clip or a different file format.")
                else:
                    st.warning("Required packages not available. To run locally, install with 'pip install demucs librosa soundfile'")
        finally:
            # Clean up temporary file
            os.unlink(tmp_path)

# Render the separated stems stored in session state (they survive reruns).
if st.session_state.stems is not None:
    st.header("Separated Stems")

    # Tab order and the short description shown under each stem's player.
    stem_descriptions = {
        "vocals": "Contains lead vocals and backing vocals.",
        "drums": "Contains drums and percussion elements.",
        "bass": "Contains bass guitar and low-frequency elements.",
        "other": "Contains all other instruments (guitars, keys, synths, etc).",
    }
    stem_tabs = st.tabs(["Vocals", "Drums", "Bass", "Other"])

    for stem_name, stem_tab in zip(stem_descriptions, stem_tabs):
        with stem_tab:
            # Left column: player, download link, description. Right: spectrogram.
            player_col, image_col = st.columns([1, 1])

            with player_col:
                st.subheader(f"{stem_name.capitalize()} Stem")

                # Encode the (channels, time) array as WAV; soundfile takes (frames, channels).
                wav_buffer = io.BytesIO()
                sf.write(wav_buffer, st.session_state.stems[stem_name].T, st.session_state.sample_rate, format='WAV')
                wav_buffer.seek(0)

                st.audio(wav_buffer, format='audio/wav')
                st.markdown(get_binary_file_downloader_html(wav_buffer.getvalue(), f"{stem_name}", "wav"), unsafe_allow_html=True)

                st.info(stem_descriptions[stem_name])

            with image_col:
                spectrograms = st.session_state.visualizations
                if spectrograms and stem_name in spectrograms:
                    st.image(spectrograms[stem_name], caption=f"{stem_name.capitalize()} Spectrogram")

    # Suggestions for what to do with the downloaded stems.
    st.header("Usage Tips")
    st.markdown("""
    ### What can you do with these stems?
    - Create remixes or mashups
    - Practice playing along with isolated instrument tracks
    - Create karaoke versions by removing vocals
    - Analyze individual instrument parts for educational purposes
    
    ### Next steps:
    1. Download each stem you want to use
    2. Import them into your DAW (Digital Audio Workstation)
    3. Mix, process, and create!
    """)

# Sidebar: background on the app and the model it uses.
st.sidebar.header("About This App")
# NOTE(review): the repository link below is a placeholder, per its own
# "(Link to your repo once created)" text; replace it before publishing.
st.sidebar.markdown("""
This application uses the Demucs model to separate audio tracks into individual stems. The model was developed by Facebook AI Research.

### How it works
The separation process uses a deep neural network to identify and isolate:
- Vocals
- Drums
- Bass
- Other instruments

### Source code
[GitHub Repository](https://github.com/huggingface/music-stem-splitter)  
(Link to your repo once created)
""")

# Add a note about processing time
st.sidebar.markdown("""
### Processing Time
The processing time depends on:
- Length of the audio file
- Available computational resources
- File quality

For best results, use high-quality audio files without excessive background noise.
""")

# Show model information. The rates listed here mirror the duration-based
# processing_sample_rate selection in the upload handler; keep them in sync.
st.sidebar.markdown("""
### Model Information
This app uses the HTDemucs model, which is trained to separate music into four stems.

Audio processing is optimized based on file length:
- Short files (< 1 min): 44.1kHz processing
- Medium files (1-3 min): 32kHz processing 
- Longer files (3-5 min): 22kHz processing
""")