import gradio as gr
import torch
import torchaudio
import hydra
from hydra import compose, initialize
import random
import os
from remfx import effects

cfg = None
classifier = None
models = {}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ALL_EFFECTS = effects.Pedalboard_Effects


def init_hydra():
    global cfg
    initialize(config_path="cfg", job_name="remfx", version_base="2.0")
    cfg = compose(config_name="config", overrides=["+exp=remfx_detect"])


def load_models():
    global classifier
    print("Loading models")
    # Effect-presence classifier
    classifier = hydra.utils.instantiate(cfg.classifier, _convert_="partial")
    ckpt_path = cfg.classifier_ckpt
    state_dict = torch.load(ckpt_path, map_location=device)["state_dict"]
    classifier.load_state_dict(state_dict)
    classifier.to(device)

    # One removal model per effect
    for effect in cfg.ckpts:
        model = hydra.utils.instantiate(cfg.ckpts[effect].model, _convert_="partial")
        ckpt_path = cfg.ckpts[effect].ckpt_path
        state_dict = torch.load(ckpt_path, map_location=device)["state_dict"]
        model.load_state_dict(state_dict)
        model.to(device)
        models[effect] = model


def audio_classification(audio_file):
    audio, sr = torchaudio.load(audio_file)
    audio = torchaudio.transforms.Resample(sr, cfg.sample_rate)(audio)
    # Convert to mono
    audio = audio.mean(0, keepdim=True)
    # Add dimension for batch
    audio = audio.unsqueeze(0)
    audio = audio.to(device)

    with torch.no_grad():
        # Classify
        print("Detecting effects")
        labels = torch.tensor(classifier(audio))
    labels_dict = {
        ALL_EFFECTS[i].__name__.replace("RandomPedalboard", ""): labels[i].item()
        for i in range(len(ALL_EFFECTS))
    }
    return labels_dict


def audio_removal(audio_file, labels, threshold):
    audio, sr = torchaudio.load(audio_file)
    audio = torchaudio.transforms.Resample(sr, cfg.sample_rate)(audio)
    # Convert to mono
    audio = audio.mean(0, keepdim=True)
    # Add dimension for batch
    audio = audio.unsqueeze(0)
    audio = audio.to(device)

    label_names = [f"RandomPedalboard{lab['label']}" for lab in labels["confidences"]]
    logits = torch.tensor([lab["confidence"] for lab in labels["confidences"]])
    rem_fx_labels = torch.where(logits > threshold, 1.0, 0.0)
    effects_present = [
        name for name, effect in zip(label_names, rem_fx_labels) if effect == 1.0
    ]
    print("Removing effects:", effects_present)

    # Remove effects
    # Shuffle effects order
    effects_order = cfg.inference_effects_ordering
    random.shuffle(effects_order)
    # Get the correct effects by searching for their names in effects_order
    effects = [effect for effect in effects_order if effect in effects_present]
    elem = audio
    with torch.no_grad():
        for effect in effects:
            # Sample the model
            elem = models[effect].model.sample(elem)

    # Move back to CPU before converting to numpy for the waveform video
    output = elem.squeeze(0).cpu()
    waveform = gr.make_waveform((cfg.sample_rate, output[0].numpy()))
    return waveform


def ui():
    css = """
        #classifier {
            padding-top: 40px;
        }
        #classifier .output-class {
            display: none;
        }
    """
    with gr.Blocks(css=css) as interface:
        gr.HTML(
            """

RemFX: General Purpose Audio Effect Removal

[Paper] [Project page]

""" ) gr.HTML( """
This is our demo for the paper General Purpose Audio Effect Removal. It uses the RemFX Detect system described in the paper to detect the audio effects that are present and remove them.
To use the demo, use one of our curated examples or upload your own audio file and click submit. The system will then detect the effects present in the audio and remove them if they meet the threshold.
""" ) with gr.Row(): with gr.Column(): effected_audio = gr.Audio( source="upload", type="filepath", label="File", interactive=True, elem_id="melody-input", ) submit = gr.Button("Submit") threshold = gr.Slider( minimum=0.0, maximum=1.0, step=0.1, value=0.5, label="Detection Threshold", ) with gr.Column(): classifier = gr.Label( num_top_classes=5, label="Effects Present", elem_id="classifier" ) audio_output = gr.Video(label="Output") gr.Examples( fn=audio_removal, examples=[ ["./input_examples/guitar.wav"], ["./input_examples/vocal.wav"], ["./input_examples/bass.wav"], ["./input_examples/drums.wav"], ["./input_examples/crazy_guitar.wav"], ], inputs=effected_audio, ) submit.click( audio_classification, inputs=[effected_audio], outputs=[classifier], queue=False, show_progress=False, ).then( audio_removal, inputs=[effected_audio, classifier, threshold], outputs=[audio_output], ) interface.queue().launch() if __name__ == "__main__": init_hydra() load_models() ui()