audiosr / app.py
j
initial commit
564c686
raw
history blame
No virus
3.25 kB
#!/usr/bin/python3
import os
import torch
from audiosr import super_resolution, build_model, save_wave, get_time, read_list
from pyharp import ModelCard, build_endpoint
from audiotools import AudioSignal
import scipy
import torch
import gradio as gr
# Metadata describing this model; handed to build_endpoint() further down
# together with the processing function and the Gradio widgets.
card = ModelCard(
    name='Versatile Audio Super Resolution',
    description='Upsample audio and predict upper spectrum.',
    author='Team Audio',
    tags=[
        'AudioSR',
        'Diffusion',
        'Super Resolution',
        'Upsampling',
        'Sample Rate Conversion',
    ],
)
# --- One-time module setup (runs at import) --------------------------------

# Forwarded unchanged to super_resolution() by process_fn.
# NOTE(review): the original author was unsure about this value
# ("not sure about this??"); presumably it mirrors the library default —
# TODO confirm against the audiosr package.
latent_t_per_second = 12.8

# Environment / backend tuning applied before the model is constructed.
os.environ["TOKENIZERS_PARALLELISM"] = "true"
torch.set_float32_matmul_precision("high")

# Shared model instance used by every call to process_fn.
audiosr = build_model(model_name="basic", device="auto")
def _segment_sample_bounds(total_samples, sample_rate, segment_seconds=10):
    """Return ``(start, stop)`` sample-index pairs that tile ``total_samples``.

    Every pair spans ``segment_seconds`` worth of samples except possibly the
    last one, which holds whatever is left over. An empty list is returned
    when ``total_samples`` is 0.
    """
    step = max(1, int(round(segment_seconds * sample_rate)))
    return [
        (start, min(start + step, total_samples))
        for start in range(0, total_samples, step)
    ]


def process_fn(input_audio_path, seed, guidance_scale, num_inference_steps):
    """
    Upsample an audio file with AudioSR, processing it in 10-second segments.

    The input is cut into consecutive segments, each segment is written to a
    temporary file and passed through ``super_resolution``, and the results
    are joined along the time axis and written to a single WAV file.

    Args:
        input_audio_path (str): the audio filepath to be processed.
        seed (int | float): random seed forwarded to ``super_resolution``.
        guidance_scale (float): guidance scale forwarded to
            ``super_resolution``.
        num_inference_steps (int | float): forwarded as ``ddim_steps``.

        NOTE: These correspond to and match the order of the UI elements
        defined below.

    Returns:
        output_audio_path (str): the filepath of the processed audio.

    Raises:
        ValueError: if the input file contains no samples.
    """
    # Function-scope imports so this block does not depend on edits to the
    # file-level import list.
    import tempfile

    import numpy as np
    from scipy.io import wavfile

    segment_seconds = 10
    # AudioSR generates 48 kHz audio regardless of the input sample rate, so
    # the result must not be written at the input rate.
    output_sample_rate = 48000

    sig = AudioSignal(input_audio_path)
    outfile = "./output.wav"

    # audio_data is laid out as (batch, channels, samples); segment along the
    # last axis. Indexing the AudioSignal itself would select batch items.
    total_samples = sig.audio_data.shape[-1]
    bounds = _segment_sample_bounds(total_samples, sig.sample_rate, segment_seconds)
    if not bounds:
        raise ValueError("Input audio is empty; nothing to process.")

    upsampled = []
    # A private scratch directory avoids clashes between concurrent requests
    # and is removed automatically afterwards.
    with tempfile.TemporaryDirectory() as workdir:
        segment_path = os.path.join(workdir, "temp.wav")
        for start, stop in bounds:
            # get segment of audio from original file
            segment = AudioSignal(sig.audio_data[..., start:stop], sig.sample_rate)
            segment.write(segment_path)

            audio = super_resolution(
                audiosr,
                segment_path,
                seed=int(seed),
                guidance_scale=guidance_scale,
                ddim_steps=int(num_inference_steps),
                latent_t_per_second=latent_t_per_second,
            )
            if torch.is_tensor(audio):
                audio = audio.detach().cpu().numpy()
            audio = np.asarray(audio)
            # Collapse any leading batch/channel axes and keep the first
            # waveform so segments can be joined in time.
            audio = audio.reshape(-1, audio.shape[-1])[0]

            # The model may pad its input; trim the output back to the length
            # that corresponds to this segment so no gaps are introduced.
            expected = int(round((stop - start) * output_sample_rate / sig.sample_rate))
            upsampled.append(audio[:expected])

    wavfile.write(outfile, rate=output_sample_rate, data=np.concatenate(upsampled))
    return outfile
# Build the Gradio app and register the processing endpoint.
with gr.Blocks() as webapp:
    # Input widgets. Their order must match the positional parameters of
    # process_fn: (input_audio_path, seed, guidance_scale, num_inference_steps).
    inputs = [
        gr.Audio(
            label="Audio Input",
            type="filepath",
        ),
        # Slider bounds, value and step are numeric; they were previously
        # passed as strings.
        gr.Slider(
            label="seed",
            minimum=0,
            maximum=65535,
            value=0,
            step=1,
        ),
        gr.Slider(
            minimum=0,
            maximum=10,
            value=3.5,
            label="Guidance Scale",
        ),
        gr.Slider(
            minimum=1,
            maximum=500,
            step=1,
            value=50,
            label="Inference Steps",
        ),
    ]

    # make an output audio widget
    output = gr.Audio(label="Audio Output", type="filepath")

    # Build the endpoint
    ctrls_data, ctrls_button, process_button, cancel_button = build_endpoint(
        inputs, output, process_fn, card
    )

# NOTE(review): request queueing (webapp.queue()) was disabled in the
# original and is left disabled here.
webapp.launch(share=True)