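"""Demucs stem separator served as a PyHARP/Gradio app.

Loads a pretrained Demucs model, separates an uploaded track into the
selected stem (or an instrumental mix of everything except vocals), and
returns the result through a PyHARP endpoint.
"""
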
import torch
import torchaudio
import gradio as gr
from demucs import pretrained
from demucs.apply import apply_model
from audiotools import AudioSignal
from pyharp import *


# Map UI stem choices onto Demucs output indices; four-stem Demucs models
# emit stems in the order [drums, bass, other, vocals].
STEM_CHOICES = {
    "Vocals": 3,
    "Drums": 0,
    "Bass": 1,
    "Other": 2,
    "Instrumental (No Vocals)": "instrumental"
}


# Load a single pretrained Demucs model; other checkpoints such as
# "mdx_extra", "htdemucs", or "mdx_q" can be swapped in the same way.
model = pretrained.get_model('mdx_extra_q')
model.eval()
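
# Sanity check: the indices in STEM_CHOICES assume the standard four-stem
# ordering, which the model reports via `model.sources`:
#
#   assert list(model.sources) == ["drums", "bass", "other", "vocals"]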


def separate_stem(audio_file_path: str, stem_choice: str) -> AudioSignal:
    """Separate `audio_file_path` with Demucs and return the chosen stem."""
    waveform, sr = torchaudio.load(audio_file_path)

    # Demucs expects stereo input; duplicate the channel of a mono file.
    is_mono = waveform.shape[0] == 1
    if is_mono:
        waveform = waveform.repeat(2, 1)

    # Demucs checkpoints are trained at a fixed sample rate (44.1 kHz here);
    # resample if the input file differs.
    if sr != model.samplerate:
        waveform = torchaudio.functional.resample(waveform, sr, model.samplerate)
        sr = model.samplerate

    with torch.no_grad():
        stems_batch = apply_model(
            model,
            waveform.unsqueeze(0),  # add a batch dimension
            overlap=0.2,
            shifts=1,
            split=True,  # chunked processing keeps memory bounded
            progress=True,
            num_workers=4
        )

    stems = stems_batch[0]  # drop the batch dim: (stems, channels, samples)

    if stem_choice == "Instrumental (No Vocals)":
        # Everything except vocals: drums + bass + other.
        stem = stems[0] + stems[1] + stems[2]
    else:
        stem = stems[STEM_CHOICES[stem_choice]]

    # Return mono output for mono input.
    if is_mono:
        stem = stem.mean(dim=0, keepdim=True)

    return AudioSignal(stem.cpu().numpy().astype('float32'), sample_rate=sr)
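
# A minimal standalone sketch (outside the Gradio app), assuming a local
# file `song.wav` exists:
#
#   vocals = separate_stem("song.wav", "Vocals")
#   vocals.write("vocals.wav")  # audiotools.AudioSignal.write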

# Gradio callback invoked by PyHARP for each request.

def process_fn_stem(audio_file_path: str, stem_choice: str):
    """
    PyHARP process function:
      - Separates the chosen stem using Demucs.
      - Saves the stem as a .wav file and returns its path plus an
        (empty) label list.
    """
    stem_signal = separate_stem(audio_file_path, stem_choice=stem_choice)
    stem_path = save_audio(stem_signal, f"{stem_choice.lower().replace(' ', '_')}.wav")
    return stem_path, LabelList(labels=[])


# Model card shown to users in the HARP client
model_card = ModelCard(
    name="Demucs Stem Separator",
    description="Uses Demucs to separate a music track into a selected stem.",
    author="Alexandre Défossez, Nicolas Usunier, Léon Bottou, Francis Bach",
    tags=["demucs", "source-separation", "pyharp", "stems"]
)

# Gradio UI
with gr.Blocks() as demo:

    dropdown_stem = gr.Dropdown(
        label="Stem to Separate",
        choices=list(STEM_CHOICES.keys()),
        value="Vocals"
    )

    app = build_endpoint(
        model_card=model_card,
        components=[dropdown_stem],
        process_fn=process_fn_stem
    )

if __name__ == "__main__":
    demo.queue()
    demo.launch(show_error=True)
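
# To use from HARP: run this script and point the HARP plug-in at the local
# Gradio URL it prints (typical pyharp workflow; exact steps depend on your
# pyharp/HARP versions).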