#!/usr/bin/python3
import os

import gradio as gr
import numpy as np
import torch
from scipy.io import wavfile

from audiosr import build_model, super_resolution
from audiotools import AudioSignal
from pyharp import ModelCard, build_endpoint
card = ModelCard(
    name='Versatile Audio Super Resolution',
    description='Upsample audio and predict upper spectrum.',
    author='Team Audio',
    tags=['AudioSR', 'Diffusion', 'Super Resolution', 'Upsampling', 'Sample Rate Conversion']
)
os.environ["TOKENIZERS_PARALLELISM"] = "true"
torch.set_float32_matmul_precision("high")

# 12.8 latent tokens per second is the default value used in the AudioSR examples.
latent_t_per_second = 12.8

# Build the AudioSR diffusion model once at startup.
audiosr = build_model(model_name="basic", device="auto")
def process_fn(input_audio_path, seed, guidance_scale, num_inference_steps):
    """
    Defines the audio processing steps.

    Args:
        input_audio_path (str): filepath of the audio to be processed.
        seed (int): random seed for the diffusion sampler.
        guidance_scale (float): classifier-free guidance scale.
        num_inference_steps (int): number of DDIM sampling steps.
        NOTE: These arguments must match the order of the UI elements defined below.

    Returns:
        output_audio_path (str): filepath of the processed audio.
    """
    sig = AudioSignal(input_audio_path)
    outfile = "./output.wav"
    audio_concat = None

    # Process the file in 10-second segments so the diffusion model never has
    # to upsample a long file in a single pass.
    total_length = sig.duration
    seg_length = 10
    num_segs = int(np.ceil(total_length / seg_length))

    for i in range(num_segs):
        start = i * seg_length
        end = min(start + seg_length, total_length)

        # Slice the segment out of the original signal (by sample index)
        # and write it to a temporary file for AudioSR to read.
        sig_seg = AudioSignal(
            sig.audio_data[..., int(start * sig.sample_rate):int(end * sig.sample_rate)],
            sample_rate=sig.sample_rate,
        )
        sig_seg.write("temp.wav")

        audio = super_resolution(
            audiosr,
            "temp.wav",
            seed=int(seed),
            guidance_scale=guidance_scale,
            ddim_steps=int(num_inference_steps),
            latent_t_per_second=latent_t_per_second,
        )

        # super_resolution returns a (batch, channels, samples) waveform;
        # flatten it and append it to the running output.
        audio = np.squeeze(audio)
        if audio_concat is None:
            audio_concat = audio
        else:
            audio_concat = np.concatenate((audio_concat, audio), axis=-1)

    # AudioSR renders its output at 48 kHz regardless of the input rate.
    wavfile.write(outfile, rate=48000, data=audio_concat)
    return outfile
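
# A minimal local smoke test for process_fn, left here only as a hedged sketch:
# "example_input.wav" is a hypothetical path, and the parameter values simply
# mirror the slider defaults below. Uncomment to verify the pipeline end-to-end
# before launching the Gradio app.
# if __name__ == "__main__":
#     print(process_fn("example_input.wav", seed=0, guidance_scale=3.5, num_inference_steps=50))
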
# Build the Gradio app
with gr.Blocks() as webapp:
    # Define the Gradio input widgets (order must match process_fn's arguments)
    inputs = [
        gr.Audio(
            label="Audio Input",
            type="filepath"
        ),
        gr.Slider(
            label="Seed",
            minimum=0,
            maximum=65535,
            value=0,
            step=1
        ),
        gr.Slider(
            minimum=0, maximum=10,
            value=3.5,
            label="Guidance Scale"
        ),
        gr.Slider(
            minimum=1, maximum=500,
            step=1, value=50,
            label="Inference Steps"
        ),
    ]
    # Make an output audio widget
    output = gr.Audio(label="Audio Output", type="filepath")

    # Wire the widgets, processing function, and model card into a HARP endpoint
    ctrls_data, ctrls_button, process_button, cancel_button = build_endpoint(inputs, output, process_fn, card)
#webapp.queue()
webapp.launch(share=True)
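
# Usage note (assumption, not confirmed by this file): launch(share=True) prints a
# public URL, which is what the HARP plug-in expects to be pointed at in order to
# run this endpoint from inside a DAW. See the pyharp documentation for the exact steps.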