import gradio as gr
import torch
import torchaudio
import hydra
from hydra import compose, initialize
import random
import os
from remfx import effects

cfg = None
classifier = None
models = {}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ALL_EFFECTS = effects.Pedalboard_Effects


def init_hydra():
    # Load the Hydra config used for inference
    global cfg
    initialize(config_path="cfg", job_name="remfx", version_base="2.0")
    cfg = compose(config_name="config", overrides=["+exp=remfx_detect"])


def load_models():
    global classifier
    print("Loading models")
    # Load the effect classifier from its checkpoint
    classifier = hydra.utils.instantiate(cfg.classifier, _convert_="partial")
    ckpt_path = cfg.classifier_ckpt
    state_dict = torch.load(ckpt_path, map_location=device)["state_dict"]
    classifier.load_state_dict(state_dict)
    classifier.to(device)
    # Load one removal model per effect
    for effect in cfg.ckpts:
        model = hydra.utils.instantiate(cfg.ckpts[effect].model, _convert_="partial")
        ckpt_path = cfg.ckpts[effect].ckpt_path
        state_dict = torch.load(ckpt_path, map_location=device)["state_dict"]
        model.load_state_dict(state_dict)
        model.to(device)
        models[effect] = model


def audio_classification(audio_file):
    audio, sr = torchaudio.load(audio_file)
    audio = torchaudio.transforms.Resample(sr, cfg.sample_rate)(audio)
    # Convert to mono
    audio = audio.mean(0, keepdim=True)
    # Add dimension for batch
    audio = audio.unsqueeze(0)
    audio = audio.to(device)
    with torch.no_grad():
        # Classify
        print("Detecting effects")
        labels = torch.tensor(classifier(audio))
    # Map each effect class to its predicted confidence
    labels_dict = {
        ALL_EFFECTS[i].__name__.replace("RandomPedalboard", ""): labels[i].item()
        for i in range(len(ALL_EFFECTS))
    }
    return labels_dict


def audio_removal(audio_file, labels, threshold):
    audio, sr = torchaudio.load(audio_file)
    audio = torchaudio.transforms.Resample(sr, cfg.sample_rate)(audio)
    # Convert to mono
    audio = audio.mean(0, keepdim=True)
    # Add dimension for batch
    audio = audio.unsqueeze(0)
    audio = audio.to(device)
    # Keep only effects whose confidence exceeds the threshold
    label_names = [f"RandomPedalboard{lab['label']}" for lab in labels["confidences"]]
    logits = torch.tensor([lab["confidence"] for lab in labels["confidences"]])
    rem_fx_labels = torch.where(logits > threshold, 1.0, 0.0)
    effects_present = [
        name for name, effect in zip(label_names, rem_fx_labels) if effect == 1.0
    ]
    print("Removing effects:", effects_present)
    # Remove effects
    # Shuffle effects order
    effects_order = cfg.inference_effects_ordering
    random.shuffle(effects_order)
    # Keep only the detected effects, preserving the shuffled ordering
    effects = [effect for effect in effects_order if effect in effects_present]
    elem = audio
    with torch.no_grad():
        for effect in effects:
            # Sample the model
            elem = models[effect].model.sample(elem)
    output = elem.squeeze(0)
    # Move to CPU before converting to numpy for waveform rendering
    waveform = gr.make_waveform((cfg.sample_rate, output[0].cpu().numpy()))
    return waveform


def ui():
    css = """
    #classifier {
        padding-top: 40px;
    }
    #classifier .output-class {
        display: none;
    }
    """
    with gr.Blocks(css=css) as interface:
        gr.HTML(
            """