waveformer / app.py
bandhav's picture
Base code
e6a6383
raw
history blame
2.03 kB
import argparse
import os
import wget
import torch
import torchaudio
import gradio as gr
from src.helpers import utils
from src.training.dcc_tf import Net as Waveformer
# The closed set of sound classes the model can be queried to extract.
# Order matters: `waveformer` builds its one-hot query vector via
# TARGETS.index(label), so these positions must match the label ordering
# the checkpoint was trained with — do not reorder.
TARGETS = [
    "Acoustic_guitar", "Applause", "Bark", "Bass_drum",
    "Burping_or_eructation", "Bus", "Cello", "Chime", "Clarinet",
    "Computer_keyboard", "Cough", "Cowbell", "Double_bass",
    "Drawer_open_or_close", "Electric_piano", "Fart", "Finger_snapping",
    "Fireworks", "Flute", "Glockenspiel", "Gong", "Gunshot_or_gunfire",
    "Harmonica", "Hi-hat", "Keys_jangling", "Knock", "Laughter", "Meow",
    "Microwave_oven", "Oboe", "Saxophone", "Scissors", "Shatter",
    "Snare_drum", "Squeak", "Tambourine", "Tearing", "Telephone",
    "Trumpet", "Violin_or_fiddle", "Writing"
]
# On first run, fetch the model configuration and checkpoint; files already
# present on disk are left untouched.
_DOWNLOADS = [
    ('default_config.json',
     'https://targetsound.cs.washington.edu/files/default_config.json',
     "Downloading model configuration from %s:"),
    ('default_ckpt.pt',
     'https://targetsound.cs.washington.edu/files/default_ckpt.pt',
     "\nDownloading the checkpoint from %s:"),
]
for _path, _url, _msg in _DOWNLOADS:
    if not os.path.exists(_path):
        print(_msg % _url)
        wget.download(_url)

# Build the Waveformer from the downloaded config, load the trained weights,
# and switch to eval mode for inference.
params = utils.Params('default_config.json')
model = Waveformer(**params.model_params)
utils.load_checkpoint('default_ckpt.pt', model)
model.eval()
def waveformer(audio, label_choices):
    """Extract the selected target sounds from a mixture recording.

    Args:
        audio: Gradio audio value, a ``(sample_rate, samples)`` tuple where
            ``samples`` is a numpy array.
        label_choices: list of class names (each must appear in ``TARGETS``)
            selecting which sounds to extract.

    Returns:
        A ``(sample_rate, samples)`` tuple suitable for Gradio's audio output.

    Raises:
        ValueError: if the sample rate is not 44100 Hz or no labels were
            selected. (Gradio surfaces the message to the user, so it must
            be descriptive rather than a bare value.)
    """
    # Read input audio
    fs, mixture = audio
    if fs != 44100:
        raise ValueError(
            "Unsupported sample rate %d Hz: the model expects 44100 Hz audio."
            % fs)
    # NOTE(review): Gradio can deliver int16 and/or multi-channel samples;
    # the model presumably expects a mono float waveform — confirm upstream.
    mixture = torch.from_numpy(mixture).unsqueeze(0)  # add batch dim

    # Construct the one-hot query vector over the supported target classes.
    if not label_choices:
        raise ValueError("No target sounds selected; choose at least one label.")
    query = torch.zeros(1, len(TARGETS))
    for label in label_choices:
        query[0, TARGETS.index(label)] = 1.

    # Inference only — no gradients needed.
    with torch.no_grad():
        output = model(mixture, query)
    return fs, output.squeeze(0).numpy()
# Wire the separation function into a Gradio UI: a mixture recording plus a
# multi-select of target classes in, the extracted audio out. The variable
# must stay named `demo` — Gradio/HF Spaces tooling looks it up by name.
demo = gr.Interface(
    fn=waveformer,
    inputs=['audio', gr.CheckboxGroup(choices=TARGETS)],
    outputs="audio",
)
demo.launch()