"""Gradio demo for Waveformer target-sound extraction.

At import time this script downloads the default model configuration and
checkpoint (when they are not already present in the working directory),
instantiates the Waveformer network on CPU, and launches a Gradio interface
whose callback is :func:`waveformer`.
"""
import argparse
import json
import os

import gradio as gr
import torch
import torchaudio
import wget

from dcc_tf import Net as Waveformer

# Sound classes selectable in the UI. The position of a name in this list is
# the index that gets set in the query vector passed to the model.
TARGETS = [
    "Acoustic_guitar", "Applause", "Bark", "Bass_drum",
    "Burping_or_eructation", "Bus", "Cello", "Chime", "Clarinet",
    "Computer_keyboard", "Cough", "Cowbell", "Double_bass",
    "Drawer_open_or_close", "Electric_piano", "Fart", "Finger_snapping",
    "Fireworks", "Flute", "Glockenspiel", "Gong", "Gunshot_or_gunfire",
    "Harmonica", "Hi-hat", "Keys_jangling", "Knock", "Laughter", "Meow",
    "Microwave_oven", "Oboe", "Saxophone", "Scissors", "Shatter",
    "Snare_drum", "Squeak", "Tambourine", "Tearing", "Telephone", "Trumpet",
    "Violin_or_fiddle", "Writing"
]

# The only sampling rate the callback accepts.
SAMPLE_RATE = 44100
# Scale between int16 PCM sample values and the floats fed to the model.
PCM_SCALE = 2.0 ** 15

CONFIG_PATH = 'default_config.json'
CONFIG_URL = 'https://targetsound.cs.washington.edu/files/default_config.json'
CKPT_PATH = 'default_ckpt.pt'
CKPT_URL = 'https://targetsound.cs.washington.edu/files/default_ckpt.pt'


def _download_if_missing(path, url, announcement):
    """Download *url* unless *path* already exists in the working directory.

    *announcement* is a ``%``-format string that receives the URL.
    """
    if not os.path.exists(path):
        print(announcement % url)
        # NOTE(review): relies on wget.download() saving the file under the
        # URL's basename, which is what `path` is expected to match.
        wget.download(url)


_download_if_missing(
    CONFIG_PATH, CONFIG_URL, "Downloading model configuration from %s:")
_download_if_missing(
    CKPT_PATH, CKPT_URL, "\nDownloading the checkpoint from %s:")

# Instantiate model
with open(CONFIG_PATH) as f:
    params = json.load(f)
model = Waveformer(**params['model_params'])
# SECURITY NOTE: torch.load() unpickles the checkpoint, and the file comes
# from a network download. Only use checkpoints from a trusted source;
# consider `weights_only=True` if the installed torch version supports it.
model.load_state_dict(
    torch.load(CKPT_PATH, map_location=torch.device('cpu'))['model_state_dict'])
model.eval()


def waveformer(audio, label_choices):
    """Run the model on an input mixture for the selected target classes.

    Args:
        audio: ``(sampling_rate, samples)`` pair as delivered by the Gradio
            audio input. ``samples`` is a NumPy array; it is divided by 2**15
            before inference, i.e. it is assumed to be on the 16-bit PCM
            scale (TODO confirm for non-int16 uploads). A 2-D array is
            treated as ``(samples, channels)`` and averaged to mono.
        label_choices: Iterable of names from ``TARGETS`` to extract. May be
            empty or ``None``, in which case the query vector is all zeros.

    Returns:
        ``(sampling_rate, samples)`` where ``samples`` is an int16 NumPy
        array holding the model output scaled back by 2**15.

    Raises:
        ValueError: If no audio was supplied, if the sampling rate is not
            44100, or if a label is not in ``TARGETS``.
    """
    if audio is None:
        # Without this guard the tuple unpacking below fails with an opaque
        # "cannot unpack non-iterable NoneType" error.
        raise ValueError("No input audio was provided")

    # Read input audio
    fs, mixture = audio
    if fs != SAMPLE_RATE:
        raise ValueError("Sampling rate must be 44100, but got %d" % fs)

    mixture = torch.from_numpy(mixture).to(torch.float)
    if mixture.dim() == 2:
        # Multi-channel input: down-mix so the tensor built below has the
        # same rank as it does for mono input.
        mixture = mixture.mean(dim=1)
    # Add two leading singleton dimensions and rescale from PCM units.
    mixture = mixture.unsqueeze(0).unsqueeze(0) / PCM_SCALE

    # Construct the query vector: one slot per target class, 1.0 if selected.
    query = torch.zeros(1, len(TARGETS))
    for t in label_choices or ():
        query[0, TARGETS.index(t)] = 1.

    with torch.no_grad():
        output = PCM_SCALE * model(mixture, query)

    # Clamp to the representable int16 range first; casting out-of-range
    # floats to int16 would otherwise wrap around and produce loud artifacts.
    output = output.clamp(-PCM_SCALE, PCM_SCALE - 1)
    return fs, output.squeeze(0).squeeze(0).to(torch.short).numpy()


label_checkbox = gr.CheckboxGroup(choices=TARGETS)
demo = gr.Interface(fn=waveformer, inputs=['audio', label_checkbox],
                    outputs="audio")
demo.launch(show_error=True)