# timbre-trap / app.py
# Last commit by cwitkowitz: "Updated for new pyharp API." (bb34ae2)
from timbre_trap.framework.modules import TimbreTrap
from pyharp import *
import gradio as gr
import torchaudio
import torch
import os
# Location of the pre-trained weights, relative to the working directory
model_path_orig = os.path.join('models', 'tt-orig.pt')

# Instantiate the Timbre-Trap model with the hyperparameters used by this app
model = TimbreTrap(
    sample_rate=22050,
    n_octaves=9,
    bins_per_octave=60,
    secs_per_block=3,
    latent_size=128,
    model_complexity=2,
    skip_connections=False,
)

# Read the saved state dictionary onto the CPU and copy it into the model
tt_weights_orig = torch.load(model_path_orig, map_location='cpu')
model.load_state_dict(tt_weights_orig)

# Put the model in evaluation mode, since it is only used for inference here
model.eval()
# Metadata describing this model; passed to build_endpoint further down
model_card = ModelCard(
    name="Timbre-Trap",
    description="De-timbre your audio!",
    author="Frank Cwitkowitz",
    tags=[
        "example",
        "music transcription",
        "multi-pitch estimation",
        "timbre filtering",
    ],
)
def process_fn(audio_path, transcribe):
    """
    Run Timbre-Trap on an audio file and write the result to disk.

    The audio is averaged down to one channel, resampled to the rate the model
    was constructed with, passed through the model, inverted back to a
    waveform, low-pass filtered, and resampled to the input file's own rate.

    Parameters
    ----------
    audio_path : str
      Path to the audio file to process
    transcribe : bool or number
      Truthy to obtain transcription coefficients from the model,
      falsy to obtain reconstructed spectral coefficients

    Returns
    ----------
    save_path : str
      Path of the written audio file (always '_outputs/output.wav')
    output_labels : LabelList
      Empty label list (this model produces no labels)
    """

    # Sampling rate the model operates at (same value given to TimbreTrap above)
    model_fs = 22050

    # Load the audio with torchaudio
    audio, fs = torchaudio.load(audio_path)

    # Pure inference: disable gradient tracking for the whole pipeline so no
    # autograd state is kept and the tensor handed to torchaudio.save never
    # requires grad, regardless of what the model does internally
    with torch.no_grad():
        # Average channels to obtain mono-channel
        audio = torch.mean(audio, dim=0, keepdim=True)
        # Resample audio to the model's sampling rate
        audio = torchaudio.functional.resample(audio, fs, model_fs)
        # Add a batch dimension
        audio = audio.unsqueeze(0)
        # Determine original number of samples
        n_samples = audio.size(-1)
        # Obtain transcription or reconstructed spectral coefficients
        # (the UI control delivers a number, so make the flag an explicit bool)
        coefficients = model.chunked_inference(audio, bool(transcribe))
        # Invert coefficients to produce audio
        audio = model.sliCQ.decode(coefficients)
        # Trim to original number of samples
        audio = audio[..., :n_samples]
        # Remove batch dimension
        audio = audio.squeeze(0)
        # Low-pass filter the audio in attempt to remove artifacts
        audio = torchaudio.functional.lowpass_biquad(audio, model_fs, 8000)
        # Resample audio back to the original sampling rate
        audio = torchaudio.functional.resample(audio, model_fs, fs)

    # Create a temporary directory for output
    os.makedirs('_outputs', exist_ok=True)
    # Create a path for saving the audio
    save_path = os.path.join('_outputs', 'output.wav')
    # Save the audio
    torchaudio.save(save_path, audio, fs)

    # No output labels
    output_labels = LabelList()

    return save_path, output_labels
# Build Gradio endpoint
with gr.Blocks() as demo:
    # The only control is a two-position slider (0 or 1) labelled 'De-Timbre'.
    # Checkbox, Number and Textbox variants with the same label were previously
    # kept here as commented-out alternatives.
    components = [
        gr.Slider(minimum=0, maximum=1, step=1, value=0, label='De-Timbre'),
    ]

    # Wire the model card, the control(s) and the processing function together
    app = build_endpoint(
        model_card=model_card,
        components=components,
        process_fn=process_fn,
    )

# Enable request queueing, then start the server with a public share link
demo.queue().launch(share=True)