from timbre_trap.framework.modules import TimbreTrap
from pyharp import *

import gradio as gr
import torchaudio
import torch
import os
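
# Initialize the Timbre-Trap model (hyperparameters must match the pre-trained checkpoint)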
model = TimbreTrap(sample_rate=22050,
                   n_octaves=9,
                   bins_per_octave=60,
                   secs_per_block=3,
                   latent_size=128,
                   model_complexity=2,
                   skip_connections=False)
model.eval()
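
# Load the pre-trained weights onto the CPU and copy them into the model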
model_path_orig = os.path.join('models', 'tt-orig.pt')
tt_weights_orig = torch.load(model_path_orig, map_location='cpu')
model.load_state_dict(tt_weights_orig)
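
# Create a model card with metadata describing the app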
model_card = ModelCard(
    name='Timbre-Trap',
    description='De-timbre your audio!',
    author='Frank Cwitkowitz',
    tags=['example', 'music transcription', 'multi-pitch estimation', 'timbre filtering']
)

def process_fn(audio_path, transcribe):
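    """
    De-timbre or reconstruct the audio at the given path.

    audio_path : str
        Path to the input audio file provided through HARP.
    transcribe : number
        Slider value (0 or 1) selecting transcription (de-timbre) mode.

    Returns the path of the saved output audio and an empty LabelList.
    """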
    # Load the audio with torchaudio
    audio, fs = torchaudio.load(audio_path)
    # Average the channels to obtain mono audio
    audio = torch.mean(audio, dim=0, keepdim=True)
    # Resample the audio to the model's sampling rate
    audio = torchaudio.functional.resample(audio, fs, 22050)
    # Add a batch dimension
    audio = audio.unsqueeze(0)
    # Record the number of samples (at the model's rate) for trimming later
    n_samples = audio.size(-1)
    # Cast the slider value to a boolean flag
    transcribe = bool(transcribe)
    # Obtain transcription or reconstructed spectral coefficients
    coefficients = model.chunked_inference(audio, transcribe)
    # Invert the coefficients to produce audio
    audio = model.sliCQ.decode(coefficients)
    # Trim to the original number of samples
    audio = audio[..., :n_samples]
    # Remove the batch dimension
    audio = audio.squeeze(0)
    # Low-pass filter the audio in an attempt to remove artifacts
    audio = torchaudio.functional.lowpass_biquad(audio, 22050, 8000)
    # Resample the audio back to the original sampling rate
    audio = torchaudio.functional.resample(audio, 22050, fs)
    # Create a directory for the output
    os.makedirs('_outputs', exist_ok=True)
    # Construct a path for saving the audio
    save_path = os.path.join('_outputs', 'output.wav')
    # Save the audio at the original sampling rate
    torchaudio.save(save_path, audio, fs)
    # No output labels to return
    output_labels = LabelList()

    return save_path, output_labels

# Build the Gradio endpoint
with gr.Blocks() as demo:
    components = [
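        # The active Slider serves as a binary toggle for the 'transcribe'
        # flag; alternative input components are left commented out below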
        #gr.Checkbox(
        #    value=False,
        #    label='De-Timbre'
        #),
        gr.Slider(
            minimum=0,
            maximum=1,
            step=1,
            value=0,
            label='De-Timbre'
        ),
        #gr.Number(
        #    value=0,
        #    label='De-Timbre'
        #),
        #gr.Textbox(
        #    value='text',
        #    label='De-Timbre'
        #)
    ]

    app = build_endpoint(model_card=model_card,
                         components=components,
                         process_fn=process_fn)
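
# Queue incoming requests and launch the demo; share=True creates a public link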
demo.queue()
demo.launch(share=True)
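
# For quick local testing (assuming an audio file 'input.wav' exists), the
# processing function can be called directly:
#   output_path, labels = process_fn('input.wav', transcribe=1)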