#!/usr/bin/python3
import math
import os

import gradio as gr
import numpy as np
import scipy.io.wavfile
import torch
from audiosr import build_model, super_resolution
from audiotools import AudioSignal
from pyharp import ModelCard, build_endpoint

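# metadata displayed by HARP-compatible hosts when they connect to this endpoint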
card = ModelCard(
    name='Versatile Audio Super Resolution',
    description='Upsample audio and predict upper spectrum.',
    author='Team Audio',
    tags=['AudioSR', 'Diffusion', 'Super Resolution', 'Upsampling', 'Sample Rate Conversion']
)

os.environ["TOKENIZERS_PARALLELISM"] = "true"
torch.set_float32_matmul_precision("high")
latent_t_per_second = 12.8  # 12.8 appears to be the default used by the audiosr inference utilities

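# build the AudioSR model once at startup; "basic" is the general-purpose checkpoint
# ("speech" is the speech-specialized alternative) and device="auto" lets audiosr
# pick CUDA when it is available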
audiosr = build_model(model_name="basic", device="auto")

def process_fn(input_audio_path, seed, guidance_scale, num_inference_steps):
    """
    This function defines the audio processing steps

    Args:
        input_audio_path (str): the audio filepath to be processed.

        <YOUR_KWARGS>: additional keyword arguments necessary for processing.
            NOTE: These should correspond to and match order of UI elements defined below.

    Returns:
        output_audio_path (str): the filepath of the processed audio.
    """

    sig = AudioSignal(input_audio_path, sample_rate=16000)    

    outfile = "./output.wav"

    audio_concat = None

    seg_dur = 5.12  # AudioSR operates on roughly 5.12 second windows
    total_length = sig.duration
    print(f"Total length: {total_length}")
    # round up so the trailing partial segment is processed instead of dropped
    num_segs = math.ceil(total_length / seg_dur)
    print(f"Number of segments: {num_segs}")

    for audio_segment in range(num_segs):
        print(f"Processing segment {audio_segment + 1} of {num_segs}")
        start = audio_segment * seg_dur
        # clamp the final segment to the end of the signal
        end = min(start + seg_dur, total_length)

        # get segment of audio from original file
        sig_seg = sig[..., int(start*sig.sample_rate):int(end*sig.sample_rate)]  # int accounts for float end time on last seg
        print(f"Segment length: {sig_seg.duration}")
        print(f"Segment start: {start}")
        print(f"Segment end: {end}")
        print(f"Segment start sample: {int(start*sig.sample_rate)}")
        print(f"Segment end sample: {int(end*sig.sample_rate)}")
        sig_seg.write(f"temp_{audio_segment}.wav")
        audio = super_resolution(
            audiosr,
            f"temp_{audio_segment}.wav",
            seed=seed,
            guidance_scale=guidance_scale,
            ddim_steps=num_inference_steps,
            latent_t_per_second=latent_t_per_second
        )
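        # housekeeping: remove the temporary segment file now that it has been upsampled
        os.remove(f"temp_{audio_segment}.wav")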

        # append the upsampled segment along the time (last) axis;
        # `+=` on numpy arrays would sum samples element-wise rather than concatenate
        if audio_concat is None:
            audio_concat = audio
        else:
            audio_concat = np.concatenate([audio_concat, audio], axis=-1)

    # AudioSR outputs 48 kHz audio; squeeze any leading batch/channel dimensions so scipy
    # receives a 1-D (mono) sample array
    scipy.io.wavfile.write(outfile, rate=48000, data=np.squeeze(audio_concat))
    return outfile

# Build the endpoint
with gr.Blocks() as webapp:
    # Define your Gradio interface
    inputs = [
        gr.Audio(
            label="Audio Input", 
            type="filepath"
        ), 
        gr.Slider(
            label="Seed",
            minimum=0,
            maximum=65535,
            value=0,
            step=1
        ),
        gr.Slider(
            minimum=0, maximum=10, 
            value=3.5, 
            label="Guidance Scale"
        ),
        gr.Slider(
            minimum=1, maximum=500, 
            step=1, value=50, 
            label="Inference Steps"
        ),
    ]

    # make an output audio widget
    output = gr.Audio(label="Audio Output", type="filepath")

    # wire the inputs, output, and model card to process_fn via pyharp
    ctrls_data, ctrls_button, process_button, cancel_button = build_endpoint(inputs, output, process_fn, card)

#webapp.queue()
webapp.launch()
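# webapp.launch(share=True) can be used instead to expose a temporary public Gradio URL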