#!/usr/bin/python3
"""Gradio/pyharp endpoint that runs AudioSR super resolution on an audio file.

The input is cut into fixed-length segments, each segment is upsampled with the
AudioSR "basic" model, and the upsampled segments are concatenated in time and
written to a WAV file.
"""
import os
import tempfile

import numpy as np
import scipy
import scipy.io.wavfile
import torch
import gradio as gr
from audiosr import super_resolution, build_model, save_wave, get_time, read_list
from pyharp import ModelCard, build_endpoint
from audiotools import AudioSignal

card = ModelCard(
    name='Versatile Audio Super Resolution',
    description='Upsample audio and predict upper spectrum.',
    author='Team Audio',
    tags=['AudioSR', 'Diffusion', 'Super Resolution', 'Upsampling', 'Sample Rate Conversion'],
)

os.environ["TOKENIZERS_PARALLELISM"] = "true"
torch.set_float32_matmul_precision("high")

# Length of each chunk handed to AudioSR, in seconds.
SEGMENT_SECONDS = 5.12
# Sample rate used when writing the concatenated AudioSR output.
OUTPUT_SAMPLE_RATE = 48000

# NOTE(review): value carried over unverified ("not sure about this??") —
# confirm against AudioSR's expected latent frame rate.
latent_t_per_second = 12.8

audiosr = build_model(model_name="basic", device="auto")


def _segment_bounds(total_samples, segment_samples):
    """Return ``(start, end)`` sample indices that tile ``[0, total_samples)``.

    Every segment is ``segment_samples`` long except possibly the last, which
    holds whatever remains, so no input samples are dropped.
    """
    return [
        (start, min(start + segment_samples, total_samples))
        for start in range(0, total_samples, segment_samples)
    ]


def process_fn(input_audio_path, seed, guidance_scale, num_inference_steps):
    """Upsample an audio file with AudioSR and write the result to a WAV file.

    The file is split into ``SEGMENT_SECONDS``-long chunks (the last chunk keeps
    any remainder). Each chunk is processed by ``super_resolution``, and the
    outputs are concatenated along time.

    NOTE: The parameters after ``input_audio_path`` must match the order of
    the UI elements defined below.

    Args:
        input_audio_path (str): path of the audio file to process.
        seed (int | float): random seed for generation (cast to int).
        guidance_scale (float): guidance scale passed to ``super_resolution``.
        num_inference_steps (int | float): DDIM step count (cast to int).

    Returns:
        str: path of the written WAV file (``./output.wav``).

    Raises:
        ValueError: if the input audio contains no samples.
    """
    # NOTE(review): audiotools may ignore ``sample_rate`` when loading from a
    # path (keeping the file's native rate) — confirm whether a 16 kHz
    # resample was intended.
    sig = AudioSignal(input_audio_path, sample_rate=16000)
    outfile = "./output.wav"

    total_samples = sig.signal_length
    if total_samples == 0:
        raise ValueError("Input audio is empty.")

    # Work in integer samples so float rounding cannot shift or drop samples.
    segment_samples = int(round(SEGMENT_SECONDS * sig.sample_rate))
    bounds = _segment_bounds(total_samples, segment_samples)
    print(f"Total length: {sig.duration}")
    print(f"Number of segments: {len(bounds)}")

    outputs = []
    # Per-segment input files live in a private temp dir so concurrent requests
    # don't overwrite each other and nothing is left behind in the CWD.
    with tempfile.TemporaryDirectory() as tmp_dir:
        for index, (start, end) in enumerate(bounds):
            print(f"Processing segment {index} of {len(bounds)} (samples {start}-{end})")
            seg_path = os.path.join(tmp_dir, f"temp_{index}.wav")
            sig[..., start:end].write(seg_path)
            # NOTE(review): if AudioSR pads short inputs up to a full segment,
            # a short final segment may yield padded output — confirm.
            audio = super_resolution(
                audiosr,
                seg_path,
                seed=int(seed),
                guidance_scale=guidance_scale,
                ddim_steps=int(num_inference_steps),
                latent_t_per_second=latent_t_per_second,
            )
            outputs.append(np.asarray(audio))

    # Concatenate (not add) so segments follow each other in time. The output
    # is presumably mono, shaped like (1, 1, samples) — squeezed to 1-D for
    # scipy; TODO confirm for multi-channel output.
    data = np.squeeze(np.concatenate(outputs, axis=-1))
    scipy.io.wavfile.write(outfile, rate=OUTPUT_SAMPLE_RATE, data=data)
    return outfile


# Build the endpoint
with gr.Blocks() as webapp:
    # Define your Gradio interface; order must match process_fn's parameters.
    inputs = [
        gr.Audio(label="Audio Input", type="filepath"),
        # Numeric (not string) bounds so the slider behaves as a number input.
        gr.Slider(label="seed", minimum=0, maximum=65535, value=0, step=1),
        gr.Slider(minimum=0, maximum=10, value=3.5, label="Guidance Scale"),
        gr.Slider(minimum=1, maximum=500, step=1, value=50, label="Inference Steps"),
    ]

    # make an output audio widget
    output = gr.Audio(label="Audio Output", type="filepath")

    # Build the endpoint
    ctrls_data, ctrls_button, process_button, cancel_button = build_endpoint(
        inputs, output, process_fn, card
    )

webapp.launch()