mattricesound committed
Commit fc851dd
1 Parent(s): c62c695

Updated all for hf spaces

Files changed (3):
  1. README.md +6 -0
  2. app.py +176 -4
  3. setup.py +1 -0
README.md CHANGED
@@ -1,3 +1,9 @@
+ ---
+ title: RemFx
+ app_file: app.py
+ sdk: gradio
+ sdk_version: 3.41.2
+ ---
  <div align="center">

  # RemFx
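Hugging Face Spaces reads this YAML front matter to pick the SDK, the pinned Gradio version, and the entry-point file (app.py). A quick, purely illustrative sanity check that a local environment matches the pinned sdk_version:

```python
# Illustrative check only: verify the installed Gradio roughly matches the
# sdk_version pinned in the README front matter above (3.41.2).
import gradio

assert gradio.__version__.startswith("3.41"), gradio.__version__
```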
app.py CHANGED
@@ -1,9 +1,181 @@
  import gradio as gr
+ import torch
+ import torchaudio
+ import hydra
+ from hydra import compose, initialize
+ import random
+ from remfx import effects

+ cfg = None
+ classifier = None
+ models = {}
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

- def greet(name):
-     return "Hello " + name + "!!"
+ ALL_EFFECTS = effects.Pedalboard_Effects


- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()
+ def init_hydra():
+     global cfg
+     initialize(config_path="cfg", job_name="remfx", version_base="2.0")
+     cfg = compose(config_name="config", overrides=["+exp=remfx_detect"])
+
+
+ def load_models():
+     global classifier
+     print("Loading models")
+     classifier = hydra.utils.instantiate(cfg.classifier, _convert_="partial")
+     ckpt_path = cfg.classifier_ckpt
+     state_dict = torch.load(ckpt_path, map_location=device)["state_dict"]
+     classifier.load_state_dict(state_dict)
+     classifier.to(device)
+
+     for effect in cfg.ckpts:
+         model = hydra.utils.instantiate(cfg.ckpts[effect].model, _convert_="partial")
+         ckpt_path = cfg.ckpts[effect].ckpt_path
+         state_dict = torch.load(ckpt_path, map_location=device)["state_dict"]
+         model.load_state_dict(state_dict)
+         model.to(device)
+         models[effect] = model
+
+
+ def audio_classification(audio_file):
+     audio, sr = torchaudio.load(audio_file)
+     audio = torchaudio.transforms.Resample(sr, cfg.sample_rate)(audio)
+     # Add dimension for batch
+     audio = audio.unsqueeze(0)
+     # Convert to mono
+     audio = audio.mean(0, keepdim=True)
+     audio = audio.to(device)
+
+     with torch.no_grad():
+         # Classify
+         print("Detecting effects")
+         labels = torch.tensor(classifier(audio))
+         labels_dict = {
+             ALL_EFFECTS[i].__name__.replace("RandomPedalboard", ""): labels[i].item()
+             for i in range(len(ALL_EFFECTS))
+         }
+     return labels_dict
+
+
+ def audio_removal(audio_file, labels, threshold):
+     audio, sr = torchaudio.load(audio_file)
+     audio = torchaudio.transforms.Resample(sr, cfg.sample_rate)(audio)
+     # Add dimension for batch
+     audio = audio.unsqueeze(0)
+     # Convert to mono
+     audio = audio.mean(0, keepdim=True)
+     audio = audio.to(device)
+
+     label_names = [f"RandomPedalboard{lab['label']}" for lab in labels["confidences"]]
+     logits = torch.tensor([lab["confidence"] for lab in labels["confidences"]])
+     rem_fx_labels = torch.where(logits > threshold, 1.0, 0.0)
+     effects_present = [
+         name for name, effect in zip(label_names, rem_fx_labels) if effect == 1.0
+     ]
+     print("Removing effects:", effects_present)
+     # Remove effects
+     # Shuffle effects order
+     effects_order = cfg.inference_effects_ordering
+     random.shuffle(effects_order)
+     # Get the matching effects by searching for their names in effects_order
+     effects = [effect for effect in effects_order if effect in effects_present]
+     elem = audio
+     with torch.no_grad():
+         for effect in effects:
+             # Sample the model
+             elem = models[effect].model.sample(elem)
+     output = elem.squeeze(0)
+     waveform = gr.make_waveform((cfg.sample_rate, output[0].numpy()))
+
+     return waveform
+
+
+ def ui():
+     css = """
+     #classifier {
+         padding-top: 40px;
+     }
+     #classifier .output-class {
+         display: none;
+     }
+     """
+     with gr.Blocks(css=css) as interface:
+         gr.HTML(
+             """
+             <div style="text-align: center; max-width: 700px; margin: 0 auto;">
+               <div
+                 style="
+                   display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;
+                 "
+               >
+                 <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
+                   RemFx: General Purpose Audio Effect Removal
+                 </h1>
+               </div>
+               <p style="margin-bottom: 10px; font-size: 94%">
+                 <a href="https://arxiv.org/abs/2301.12503">[Paper]</a>
+                 <a href="https://csteinmetz1.github.io/RemFX/">[Project page]</a>
+               </p>
+             </div>
+             """
+         )
+         gr.HTML(
+             """
+             <div style="text-align: left;"> This is our demo for the paper General Purpose Audio Effect Removal. It uses the RemFX Detect system described in the paper to detect the audio effects that are present and remove them. <br>
+             To use the demo, pick one of our curated examples or upload your own audio file and click Submit. The system will then detect the effects present in the audio and remove them if they meet the threshold. </div>
+             """
+         )
+         with gr.Row():
+             with gr.Column():
+                 effected_audio = gr.Audio(
+                     source="upload",
+                     type="filepath",
+                     label="File",
+                     interactive=True,
+                     elem_id="melody-input",
+                 )
+                 submit = gr.Button("Submit")
+                 threshold = gr.Slider(
+                     minimum=0.0,
+                     maximum=1.0,
+                     step=0.1,
+                     value=0.5,
+                     label="Detection Threshold",
+                 )
+             with gr.Column():
+                 classifier = gr.Label(
+                     num_top_classes=5, label="Effects Present", elem_id="classifier"
+                 )
+                 audio_output = gr.Video(label="Output")
+         gr.Examples(
+             fn=audio_removal,
+             examples=[
+                 ["./input_examples/guitar.wav"],
+                 ["./input_examples/vocal.wav"],
+                 ["./input_examples/bass.wav"],
+                 ["./input_examples/drums.wav"],
+                 ["./input_examples/crazy_guitar.wav"],
+             ],
+             inputs=effected_audio,
+         )
+         submit.click(
+             audio_classification,
+             inputs=[effected_audio],
+             outputs=[classifier],
+             queue=False,
+             show_progress=False,
+         ).then(
+             audio_removal,
+             inputs=[effected_audio, classifier, threshold],
+             outputs=[audio_output],
+         )
+
+     interface.queue().launch()
+
+
+ if __name__ == "__main__":
+     init_hydra()
+     load_models()
+     ui()
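As a usage sketch (not part of the commit), the detection path in the new app.py can be exercised without the Gradio UI, assuming it is run from the repo root with remfx installed and the cfg/ configs plus the checkpoints referenced by cfg.classifier_ckpt and cfg.ckpts available locally; the input path and printed values below are illustrative:

```python
# Usage sketch: run effect detection from app.py without launching the UI.
# Assumes execution from the repo root with remfx installed and the
# checkpoints referenced in cfg/ downloaded; the .wav path is illustrative.
import app

app.init_hydra()   # compose the Hydra config from the local cfg/ directory
app.load_models()  # load the classifier and one removal model per effect

scores = app.audio_classification("input_examples/guitar.wav")
print(scores)  # per-effect confidence scores, e.g. {"Distortion": 0.87, ...}
```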
setup.py CHANGED
@@ -53,6 +53,7 @@ setup(
          "torchmetrics>=1.0",
          "wav2clip_hear @ git+https://github.com/hohsiangwu/wav2clip-hear.git",
          "panns_hear @ git+https://github.com/qiuqiangkong/HEAR2021_Challenge_PANNs",
+         "gradio",
      ],
      include_package_data=True,
      license="Apache License 2.0",