from timbre_trap.framework.modules import TimbreTrap
from pyharp import *

import gradio as gr
import torchaudio
import torch
import os


# Sampling rate (Hz) the model is constructed with; input audio is resampled
# to this rate before inference and back to its original rate afterwards.
SAMPLE_RATE = 22050

# Cutoff frequency (Hz) of the low-pass filter applied to the model output
# in an attempt to remove artifacts.
LOWPASS_CUTOFF = 8000

model = TimbreTrap(sample_rate=SAMPLE_RATE,
                   n_octaves=9,
                   bins_per_octave=60,
                   secs_per_block=3,
                   latent_size=128,
                   model_complexity=2,
                   skip_connections=False)
model.eval()

# Load the pretrained weights onto the CPU
model_path_orig = os.path.join('models', 'tt-orig.pt')
tt_weights_orig = torch.load(model_path_orig, map_location='cpu')
model.load_state_dict(tt_weights_orig)

model_card = ModelCard(
    name='Timbre-Trap',
    description='De-timbre your audio!',
    author='Frank Cwitkowitz',
    tags=['example', 'music transcription', 'multi-pitch estimation', 'timbre filtering']
)


def process_fn(audio_path, transcribe):
    """Run Timbre-Trap on an audio file and save the resulting audio.

    The input is downmixed to mono and resampled to ``SAMPLE_RATE``. It is
    then passed through the model, converted back to audio and low-pass
    filtered. Finally it is resampled to the input's original rate and written
    to ``_outputs/output.wav``, which is overwritten on every call.

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``.
        transcribe: Value of the 'De-Timbre' control (0 or 1). A truthy value
            requests the transcription output. A falsy value requests the
            reconstructed spectral coefficients.

    Returns:
        A tuple ``(save_path, output_labels)``. ``save_path`` is the path of
        the written WAV file and ``output_labels`` is an empty ``LabelList``.
    """
    # Load the audio with torchaudio
    audio, fs = torchaudio.load(audio_path)
    # Average channels to obtain mono-channel audio
    audio = torch.mean(audio, dim=0, keepdim=True)
    # Resample audio to the model's sampling rate
    audio = torchaudio.functional.resample(audio, fs, SAMPLE_RATE)
    # Add a batch dimension
    audio = audio.unsqueeze(0)
    # Determine original number of samples (at the model's sampling rate)
    n_samples = audio.size(-1)

    # Inference only: no_grad avoids building an autograd graph.
    # The slider may deliver 0/1 as int or float, so normalize it to a bool flag.
    with torch.no_grad():
        # Obtain transcription or reconstructed spectral coefficients
        coefficients = model.chunked_inference(audio, bool(transcribe))
        # Invert coefficients to produce audio
        audio = model.sliCQ.decode(coefficients)

    # Trim to original number of samples (decoding may pad the signal)
    audio = audio[..., :n_samples]
    # Remove batch dimension
    audio = audio.squeeze(0)
    # Low-pass filter the audio in an attempt to remove artifacts
    audio = torchaudio.functional.lowpass_biquad(audio, SAMPLE_RATE, LOWPASS_CUTOFF)
    # Resample audio back to the original sampling rate
    audio = torchaudio.functional.resample(audio, SAMPLE_RATE, fs)

    # Create the (persistent, reused) output directory
    os.makedirs('_outputs', exist_ok=True)
    # Create a path for saving the audio; overwritten on every call
    save_path = os.path.join('_outputs', 'output.wav')
    # Save the audio
    torchaudio.save(save_path, audio, fs)

    # No output labels
    output_labels = LabelList()

    return save_path, output_labels


# Build Gradio endpoint
with gr.Blocks() as demo:
    # A Slider restricted to {0, 1} serves as the 'De-Timbre' on/off control.
    # NOTE(review): earlier revisions had gr.Checkbox / gr.Number / gr.Textbox
    # variants commented out here; presumably the Slider was chosen for
    # compatibility with the HARP front end -- confirm before switching.
    components = [
        gr.Slider(
            minimum=0,
            maximum=1,
            step=1,
            value=0,
            label='De-Timbre'
        ),
    ]

    app = build_endpoint(model_card=model_card,
                         components=components,
                         process_fn=process_fn)

demo.queue()
demo.launch(share=True)