timbre-trap / app.py
cwitkowitz's picture
Working standalone.
883013e
raw
history blame
No virus
2.48 kB
from pyharp import ModelCard, build_endpoint
import gradio as gr
import torchaudio
import torch
import os
# Load the complete pickled Timbre-Trap model object onto the CPU.
# weights_only=False is passed explicitly: the checkpoint holds a whole model
# object (its methods are called directly in process_fn), and from PyTorch 2.6
# onward torch.load defaults to weights_only=True, which rejects such files.
# NOTE: unpickling can execute arbitrary code - only load trusted checkpoints.
timbre_trap = torch.load('model-8750.pt', map_location='cpu', weights_only=False)

# Metadata describing the model; handed to build_endpoint alongside process_fn.
card = ModelCard(
    name='Timbre-Trap',
    description='De-timbre your audio!',
    author='Frank Cwitkowitz',
    tags=['example', 'music transcription', 'multi-pitch estimation', 'timbre filtering']
)
def process_fn(audio_path, de_timbre):
    """Run Timbre-Trap on an audio file and write the result to disk.

    The input is mixed down to mono, resampled to 22050 Hz, passed through the
    module-level ``timbre_trap`` model (encode -> skip connections -> decode),
    inverted back to a waveform and saved as a WAV file.

    Args:
        audio_path: Path to the input audio file, as accepted by
            ``torchaudio.load``.
        de_timbre: Flag forwarded as the third argument of
            ``timbre_trap.decode``; the UI supplies 0 or 1. When truthy, the
            output is additionally peak-normalized to [-1, 1].

    Returns:
        Path of the written WAV file. This is always ``_outputs/output.wav``,
        so every call overwrites the result of the previous one.
    """
    # Sampling rate used for both the model input and the saved output
    sample_rate = 22050

    # Load the audio with torchaudio
    audio, fs = torchaudio.load(audio_path)
    # Average channels to obtain mono-channel
    audio = torch.mean(audio, dim=0, keepdim=True)
    # Resample audio to the specified sampling rate
    audio = torchaudio.functional.resample(audio, fs, sample_rate)
    # Add a batch dimension
    audio = audio.unsqueeze(0)
    # Determine original number of samples
    n_samples = audio.size(-1)

    # Inference only: disable autograd so no graph is recorded and the result
    # does not require grad. Grad mode is thread-local, so it has to be set
    # here, inside the function that actually runs the model.
    with torch.no_grad():
        # Pad audio to next multiple of block length
        audio = timbre_trap.sliCQ.pad_to_block_length(audio)
        # Encode raw audio into latent vectors
        latents, embeddings, _ = timbre_trap.encode(audio)
        # Apply skip connections if they are turned on
        embeddings = timbre_trap.apply_skip_connections(embeddings)
        # Obtain transcription or reconstructed spectral coefficients
        coefficients = timbre_trap.decode(latents, embeddings, de_timbre)
        # Invert reconstructed spectral coefficients
        audio = timbre_trap.sliCQ.decode(coefficients)

    # Trim to original number of samples
    audio = audio[..., :n_samples]
    # Remove batch dimension
    audio = audio.squeeze(0)

    if de_timbre:
        # Compute the peak once; skip all-zero output to avoid dividing by zero
        peak = audio.abs().max()
        if peak:
            # Normalize audio to [-1, 1]
            audio = audio / peak

    # Create a temporary directory for output
    os.makedirs('_outputs', exist_ok=True)
    # Create a path for saving the audio
    save_path = os.path.join('_outputs', 'output.wav')
    # Save the audio
    torchaudio.save(save_path, audio, sample_rate)

    return save_path
# Assemble the Gradio interface; components are created in display order
# inside the Blocks context.
with gr.Blocks() as demo:
    audio_input = gr.Audio(type='filepath', label='Audio Input')

    # The De-Timbre option is exposed as a two-position slider (0 or 1).
    # The commented-out alternative previously kept here was
    # gr.Checkbox(value=False, label='De-Timbre').
    de_timbre_toggle = gr.Slider(
        label='De-Timbre',
        minimum=0,
        maximum=1,
        step=1,
        value=0,
    )

    inputs = [audio_input, de_timbre_toggle]

    output = gr.Audio(type='filepath', label='Audio Output')

    # Wire the widgets, the processing function and the model card together
    ctrls_data, ctrls_button, process_button = build_endpoint(inputs, output, process_fn, card)

# Start the app; share=True asks Gradio for a public share link
demo.launch(share=True)