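"""Gradio demo for RemFX: General Purpose Audio Effect Removal.

The app loads an effect-detection classifier and one removal model per effect
from the Hydra config, detects which effects are present in an uploaded
recording, and removes those whose confidence exceeds a user-set threshold.
"""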
import gradio as gr
import torch
import torchaudio
import hydra
from hydra import compose, initialize
import random
from remfx import effects

cfg = None
classifier = None
models = {}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ALL_EFFECTS = effects.Pedalboard_Effects


def init_hydra():
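    """Compose the global Hydra config for the remfx_detect experiment."""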
    global cfg
    initialize(config_path="cfg", job_name="remfx", version_base="2.0")
    cfg = compose(config_name="config", overrides=["+exp=remfx_detect"])


def load_models():
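    """Instantiate the classifier and the per-effect removal models, loading
    their weights from the checkpoints listed in the config."""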
    global classifier
    print("Loading models")
    classifier = hydra.utils.instantiate(cfg.classifier, _convert_="partial")
    ckpt_path = cfg.classifier_ckpt
    state_dict = torch.load(ckpt_path, map_location=device)["state_dict"]
    classifier.load_state_dict(state_dict)
    classifier.to(device)

    for effect in cfg.ckpts:
        model = hydra.utils.instantiate(cfg.ckpts[effect].model, _convert_="partial")
        ckpt_path = cfg.ckpts[effect].ckpt_path
        state_dict = torch.load(ckpt_path, map_location=device)["state_dict"]
        model.load_state_dict(state_dict)
        model.to(device)
        models[effect] = model


def audio_classification(audio_file):
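    """Resample the input to cfg.sample_rate, run the effect classifier, and
    return an {effect name: confidence} dict for the gr.Label component."""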
    audio, sr = torchaudio.load(audio_file)
    audio = torchaudio.transforms.Resample(sr, cfg.sample_rate)(audio)
    # Convert to mono
    audio = audio.mean(0, keepdim=True)
    # Add dimension for batch
    audio = audio.unsqueeze(0)
    audio = audio.to(device)

    with torch.no_grad():
        # Classify
        print("Detecting effects")
        labels = torch.tensor(classifier(audio))
        labels_dict = {
            ALL_EFFECTS[i].__name__.replace("RandomPedalboard", ""): labels[i].item()
            for i in range(len(ALL_EFFECTS))
        }
    return labels_dict


def audio_removal(audio_file, labels, threshold):
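    """Remove detected effects from the input audio.

    Effects whose classifier confidence exceeds `threshold` are removed one at
    a time, in a shuffled inference ordering, each by its dedicated removal
    model. Returns a waveform video of the processed audio.
    """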
    audio, sr = torchaudio.load(audio_file)
    audio = torchaudio.transforms.Resample(sr, cfg.sample_rate)(audio)
    # Convert to mono
    audio = audio.mean(0, keepdim=True)
    # Add dimension for batch
    audio = audio.unsqueeze(0)
    audio = audio.to(device)

    label_names = [f"RandomPedalboard{lab['label']}" for lab in labels["confidences"]]
    logits = torch.tensor([lab["confidence"] for lab in labels["confidences"]])
    rem_fx_labels = torch.where(logits > threshold, 1.0, 0.0)
    effects_present = [
        name for name, effect in zip(label_names, rem_fx_labels) if effect == 1.0
    ]
    print("Removing effects:", effects_present)
    # Remove effects
    # Shuffle effects order
    effects_order = cfg.inference_effects_ordering
    random.shuffle(effects_order)
    # Get the correct effects by searching for their names in effects_order
    effects = [effect for effect in effects_order if effect in effects_present]
    elem = audio
    with torch.no_grad():
        for effect in effects:
            # Sample the model
            elem = models[effect].model.sample(elem)
        output = elem.squeeze(0)
    waveform = gr.make_waveform((cfg.sample_rate, output[0].cpu().numpy()))

    return waveform


def ui():
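    """Build and launch the Gradio Blocks interface for the demo."""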
    css = """

    #classifier {
        padding-top: 40px;
    }
    #classifier .output-class {
        display: none;

    }
    """
    with gr.Blocks(css=css) as interface:
        gr.HTML(
            """
                <div style="text-align: center; max-width: 700px; margin: 0 auto;">
                <div
                    style="
                    display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;
                    "
                >
                    <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
                    RemFX: General Purpose Audio Effect Removal
                    </h1>
                </div> <p style="margin-bottom: 10px; font-size: 94%">
                    <a href="https://arxiv.org/abs/2301.12503">[Paper]</a> <a href="https://csteinmetz1.github.io/RemFX/">[Project
                    page]</a>
                </p>
                </div>
            """
        )
        gr.HTML(
            """
                <div style="text-align: left;"> This is our demo for the paper General Purpose Audio Effect Removal. It uses the RemFX Detect system described in the paper to detect the audio effects that are present and remove them. <br>
                To use the demo, choose one of our curated examples or upload your own audio file and click Submit. The system will then detect the effects present in the audio and remove those that exceed the detection threshold. </div>
            """
        )
        with gr.Row():
            with gr.Column():
                effected_audio = gr.Audio(
                    source="upload",
                    type="filepath",
                    label="File",
                    interactive=True,
                    elem_id="melody-input",
                )
                submit = gr.Button("Submit")
                threshold = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                    value=0.5,
                    label="Detection Threshold",
                )
            with gr.Column():
                classifier = gr.Label(
                    num_top_classes=5, label="Effects Present", elem_id="classifier"
                )
                audio_output = gr.Video(label="Output")
        gr.Examples(
            fn=audio_removal,
            examples=[
                ["./input_examples/guitar.wav"],
                ["./input_examples/vocal.wav"],
                ["./input_examples/bass.wav"],
                ["./input_examples/drums.wav"],
                ["./input_examples/crazy_guitar.wav"],
            ],
            inputs=effected_audio,
        )
        submit.click(
            audio_classification,
            inputs=[effected_audio],
            outputs=[classifier],
            queue=False,
            show_progress=False,
        ).then(
            audio_removal,
            inputs=[effected_audio, classifier, threshold],
            outputs=[audio_output],
        )

        interface.queue().launch()


if __name__ == "__main__":
    init_hydra()
    load_models()
    ui()