# waveformer / app.py
# Author: Bandhav Veluri
import os
import json

import wget
import torch
import torchaudio
import gradio as gr

from dcc_tf import Net as Waveformer
# Target sound classes supported by the pretrained checkpoint
TARGETS = [
    "Acoustic_guitar", "Applause", "Bark", "Bass_drum",
    "Burping_or_eructation", "Bus", "Cello", "Chime", "Clarinet",
    "Computer_keyboard", "Cough", "Cowbell", "Double_bass",
    "Drawer_open_or_close", "Electric_piano", "Fart", "Finger_snapping",
    "Fireworks", "Flute", "Glockenspiel", "Gong", "Gunshot_or_gunfire",
    "Harmonica", "Hi-hat", "Keys_jangling", "Knock", "Laughter", "Meow",
    "Microwave_oven", "Oboe", "Saxophone", "Scissors", "Shatter",
    "Snare_drum", "Squeak", "Tambourine", "Tearing", "Telephone",
    "Trumpet", "Violin_or_fiddle", "Writing"
]

# Download the model configuration and checkpoint if not already present
if not os.path.exists('default_config.json'):
    config_url = 'https://targetsound.cs.washington.edu/files/default_config.json'
    print("Downloading model configuration from %s:" % config_url)
    wget.download(config_url)

if not os.path.exists('default_ckpt.pt'):
    ckpt_url = 'https://targetsound.cs.washington.edu/files/default_ckpt.pt'
    print("\nDownloading the checkpoint from %s:" % ckpt_url)
    wget.download(ckpt_url)

# Instantiate the model and load the pretrained weights on CPU
with open('default_config.json') as f:
    params = json.load(f)

model = Waveformer(**params['model_params'])
model.load_state_dict(
    torch.load('default_ckpt.pt', map_location=torch.device('cpu'))['model_state_dict'])
model.eval()
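
# Minimal sanity check (a sketch added here, not part of the original app). It
# assumes the model takes a (batch, 1, samples) float mixture and a
# (batch, len(TARGETS)) multi-hot query, and returns a tensor shaped like the
# mixture, which is what waveformer() below relies on. One second of noise at
# 44.1 kHz with an empty query is enough to confirm the weights loaded and a
# forward pass runs on CPU.
with torch.no_grad():
    _check = model(torch.randn(1, 1, 44100), torch.zeros(1, len(TARGETS)))
print("Sanity check: output shape %s" % (tuple(_check.shape),))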

def waveformer(audio, label_choices):
    # Gradio passes audio as a (sample_rate, int16 numpy array) tuple
    fs, mixture = audio
    if fs != 44100:
        raise ValueError("Sampling rate must be 44100, but got %d" % fs)
    if mixture.ndim > 1:
        # Stereo input arrives as (samples, channels); keep the first channel
        mixture = mixture[:, 0]
    mixture = torch.from_numpy(
        mixture).unsqueeze(0).unsqueeze(0).to(torch.float) / (2.0 ** 15)

    # Construct the multi-hot query vector over the selected target classes
    query = torch.zeros(1, len(TARGETS))
    for t in label_choices:
        query[0, TARGETS.index(t)] = 1.

    # Run the model and rescale the output back to the int16 range
    with torch.no_grad():
        output = (2.0 ** 15) * model(mixture, query)

    return fs, output.squeeze(0).squeeze(0).to(torch.short).numpy()
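
# Sketch of calling waveformer() outside the UI, assuming a 44.1 kHz mono
# WAV named 'mixture.wav' (a hypothetical filename) sits next to this script.
# It mimics the (sample_rate, int16 array) tuple that Gradio provides.
if os.path.exists('mixture.wav'):
    wav, in_fs = torchaudio.load('mixture.wav')  # wav: float tensor in [-1, 1]
    int16_mix = (wav[0].numpy() * (2.0 ** 15)).astype('int16')
    _, extracted = waveformer((in_fs, int16_mix), ["Acoustic_guitar"])
    print("Extracted %d samples" % len(extracted))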

# Build the Gradio interface: mixture + target class selection in, audio out
input_audio = gr.Audio(label="Input audio")
label_checkbox = gr.CheckboxGroup(choices=TARGETS, label="Input target selection(s)")
output_audio = gr.Audio(label="Output audio")

demo = gr.Interface(fn=waveformer, inputs=[input_audio, label_checkbox], outputs=output_audio)
demo.launch(show_error=True)