audiosr / app.py
j
fix AudioSignal/audio concatenation (probably still wrong)
0c92f84
#!/usr/bin/python3
import os
import torch
from audiosr import super_resolution, build_model, save_wave, get_time, read_list
from pyharp import ModelCard, build_endpoint
from audiotools import AudioSignal
import scipy
import torch
import gradio as gr
# Model card handed to pyharp's build_endpoint together with process_fn.
card = ModelCard(
    name='Versatile Audio Super Resolution',
    author='Team Audio',
    description='Upsample audio and predict upper spectrum.',
    tags=[
        'AudioSR',
        'Diffusion',
        'Super Resolution',
        'Upsampling',
        'Sample Rate Conversion',
    ],
)
# Process-wide runtime configuration, applied before the model is built.
os.environ.update(TOKENIZERS_PARALLELISM="true")
torch.set_float32_matmul_precision("high")

# Passed to super_resolution() for every segment processed by process_fn.
# NOTE(review): the original author was unsure of this value; confirm it is
# what the "basic" AudioSR checkpoint expects.
latent_t_per_second = 12.8

# Loaded once at import time and shared by every request.
audiosr = build_model(model_name="basic", device="auto")
def process_fn(input_audio_path, seed, guidance_scale, num_inference_steps):
    """Super-resolve an audio file with AudioSR, 5.12 seconds at a time.

    The input is split into consecutive 5.12-second segments (the last one may
    be shorter). Each segment is run through ``super_resolution``, and the
    results are concatenated along the time axis and written to
    ``./output.wav`` at 48 kHz.

    Args:
        input_audio_path (str): the audio filepath to be processed.
        seed: random seed forwarded to ``super_resolution``.
        guidance_scale (float): guidance scale forwarded to ``super_resolution``.
        num_inference_steps: forwarded to ``super_resolution`` as ``ddim_steps``.

        NOTE: the parameters after ``input_audio_path`` must match the order of
        the UI input elements defined in the ``gr.Blocks`` section.

    Returns:
        str: the filepath of the processed audio (``./output.wav``).

    Raises:
        ValueError: if the input audio contains no samples.
    """
    import tempfile

    import numpy as np
    from scipy.io import wavfile

    segment_seconds = 5.12  # segment length fed to AudioSR per call
    output_sample_rate = 48000  # rate used when writing the output file
    outfile = "./output.wav"

    # NOTE(review): when given a file path, audiotools may ignore `sample_rate`
    # and keep the file's native rate -- confirm. All slicing below uses
    # sig.sample_rate, so it is correct either way.
    sig = AudioSignal(input_audio_path, sample_rate=16000)

    total_samples = sig.signal_length
    if total_samples == 0:
        raise ValueError(f"Input audio is empty: {input_audio_path}")

    # Work in integer sample indices so segments tile the signal exactly:
    # no gaps or overlaps from float rounding, and a trailing partial segment
    # is processed rather than dropped.
    seg_samples = int(round(segment_seconds * sig.sample_rate))
    seg_starts = range(0, total_samples, seg_samples)
    num_segs = len(seg_starts)
    print(f"Total length: {sig.duration} s, {num_segs} segment(s)")

    chunks = []
    # Temporary segment files are removed automatically when the block exits.
    with tempfile.TemporaryDirectory() as tmpdir:
        for idx, start in enumerate(seg_starts):
            end = min(start + seg_samples, total_samples)
            print(f"Processing segment {idx + 1} of {num_segs} (samples {start}-{end})")

            seg_path = os.path.join(tmpdir, f"temp_{idx}.wav")
            sig[..., start:end].write(seg_path)

            audio = super_resolution(
                audiosr,
                seg_path,
                seed=seed,
                guidance_scale=guidance_scale,
                ddim_steps=num_inference_steps,
                latent_t_per_second=latent_t_per_second,
            )
            if torch.is_tensor(audio):
                audio = audio.detach().cpu().numpy()
            audio = np.asarray(audio)

            # AudioSR may pad its input up to a multiple of 5.12 s (TODO confirm).
            # Trim to this segment's true duration so padding is not stitched
            # into the middle of the output. This is a no-op if not padded.
            expected = int(round((end - start) * output_sample_rate / sig.sample_rate))
            chunks.append(audio[..., :expected])

    # Join segments along the time (last) axis; the original `+=` summed them.
    audio_concat = np.concatenate(chunks, axis=-1)

    # scipy.io.wavfile expects (samples,) or (samples, channels).
    # Assumes super_resolution returns (batch, channels, samples) with
    # batch == 1 -- TODO confirm.
    data = np.squeeze(audio_concat)
    if data.ndim == 2:
        data = data.T
    wavfile.write(outfile, rate=output_sample_rate, data=data)
    return outfile
# Build the Gradio UI and register process_fn with pyharp.
with gr.Blocks() as webapp:
    # Input widgets, in the same order as process_fn's parameters:
    # (input_audio_path, seed, guidance_scale, num_inference_steps).
    inputs = [
        gr.Audio(
            label="Audio Input",
            type="filepath",
        ),
        # Slider bounds/value/step must be numbers; they were strings before.
        gr.Slider(
            label="seed",
            minimum=0,
            maximum=65535,
            value=0,
            step=1,
        ),
        gr.Slider(
            minimum=0,
            maximum=10,
            value=3.5,
            label="Guidance Scale",
        ),
        gr.Slider(
            minimum=1,
            maximum=500,
            step=1,
            value=50,
            label="Inference Steps",
        ),
    ]

    # Output widget; process_fn returns a file path.
    output = gr.Audio(label="Audio Output", type="filepath")

    # Build the endpoint. The returned controls are unpacked but not used below.
    ctrls_data, ctrls_button, process_button, cancel_button = build_endpoint(
        inputs, output, process_fn, card
    )

# NOTE(review): queueing was left disabled in the original (a commented-out
# webapp.queue()). Long diffusion runs may need it depending on the Gradio
# version -- confirm.
webapp.launch()