File size: 3,415 Bytes
bd6e54b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import gradio as gr


class GradioWebUI:
    """Bundle model handles and settings, and build common Gradio widgets.

    The instance stores the objects and numeric settings passed to the
    constructor and offers small factory methods that return pre-configured
    ``gr.Slider`` / ``gr.Radio`` / ``gr.Textbox`` components. Some widget
    defaults are lowered when the configured device is a CPU.
    """

    def __init__(self, device, VAE, uNet, CLAP, CLAP_tokenizer,
                 freq_resolution=512, time_resolution=256, channels=4, timesteps=1000,
                 sample_rate=16000, squared=False, VAE_scale=4,
                 flexible_duration=False, noise_strategy="repeat",
                 GAN_generator=None):
        """Store the supplied objects and settings and create the state holders.

        Args:
            device: Device identifier. Either a string such as ``"cpu"`` or an
                object exposing a ``type`` attribute (e.g. ``torch.device``);
                only used here to pick lighter widget defaults on CPU.
            VAE: Object exposing ``_encoder``, ``_vq_vae`` and ``_decoder``
                attributes, which are stored individually.
            uNet: Stored unchanged as ``self.uNet``.
            CLAP: Stored unchanged as ``self.CLAP``.
            CLAP_tokenizer: Stored unchanged as ``self.CLAP_tokenizer``.
            freq_resolution: Stored unchanged.
            time_resolution: Stored unchanged.
            channels: Stored unchanged.
            timesteps: Stored unchanged.
            sample_rate: Stored unchanged (presumably in Hz; confirm with callers).
            squared: Stored unchanged.
            VAE_scale: Divisor applied to the bounds of the time-resolution slider.
            flexible_duration: When true, the duration slider allows
                fine-grained values instead of whole seconds.
            noise_strategy: Stored unchanged.
            GAN_generator: Optional object, stored unchanged.
        """
        self.device = device

        # Keep direct handles to the three VAE sub-modules.
        self.VAE_encoder = VAE._encoder
        self.VAE_quantizer = VAE._vq_vae
        self.VAE_decoder = VAE._decoder

        self.uNet = uNet
        self.CLAP = CLAP
        self.CLAP_tokenizer = CLAP_tokenizer
        self.GAN_generator = GAN_generator

        self.freq_resolution = freq_resolution
        self.time_resolution = time_resolution
        self.channels = channels
        self.timesteps = timesteps
        self.sample_rate = sample_rate
        self.squared = squared
        self.VAE_scale = VAE_scale
        self.flexible_duration = flexible_duration
        self.noise_strategy = noise_strategy

        # One state holder per feature, each starting out as an empty dict.
        self.text2sound_state = gr.State(value={})
        self.interpolation_state = gr.State(value={})
        self.sound2sound_state = gr.State(value={})
        self.inpaint_state = gr.State(value={})

    def _is_cpu_device(self):
        """Return True when ``self.device`` designates a CPU.

        Accepts plain strings (``"cpu"``, ``"cpu:0"``) as well as device
        objects that expose a ``type`` attribute. A bare ``== "cpu"`` check
        would miss the latter two forms and silently select the heavier
        non-CPU defaults.
        """
        device_type = getattr(self.device, "type", self.device)
        # Drop an optional ":index" suffix before comparing.
        return str(device_type).split(":", 1)[0] == "cpu"

    def get_sample_steps_slider(self):
        """Return the "Sample steps" slider (10-100; default 10 on CPU, else 20)."""
        default_steps = 10 if self._is_cpu_device() else 20
        return gr.Slider(minimum=10, maximum=100, value=default_steps, step=1,
                         label="Sample steps",
                         info="Sampling steps. The more sampling steps, the better the "
                              "theoretical result, but the time it consumes.")

    def get_sampler_radio(self):
        """Return the "Sampler" radio group, defaulting to ``"ddim"``."""
        # NOTE: an earlier, disabled variant of this list also contained
        # "dpmsolver++" and "dpmsolver"; they are currently not offered.
        return gr.Radio(choices=["ddpm", "ddim"], value="ddim", label="Sampler")

    def get_batchsize_slider(self, cpu_batchsize=1):
        """Return the "Batchsize" slider (1-16).

        Args:
            cpu_batchsize: Default value used when running on CPU; on any
                other device the default is 8.
        """
        default_batchsize = cpu_batchsize if self._is_cpu_device() else 8
        return gr.Slider(minimum=1., maximum=16, value=default_batchsize, step=1, label="Batchsize")

    def get_time_resolution_slider(self):
        """Return the interactive "Time resolution" slider.

        The upper bound and default are 1024 and 256 divided by
        ``self.VAE_scale`` (truncated to int); the lower bound is 16.
        """
        upper_bound = int(1024/self.VAE_scale)
        default_value = int(256/self.VAE_scale)
        return gr.Slider(minimum=16., maximum=upper_bound, value=default_value, step=1,
                         label="Time resolution", interactive=True)

    def get_duration_slider(self):
        """Return the "duration in sec" slider (default 3, maximum 8).

        With ``self.flexible_duration`` set, the slider starts at 0.25 and
        moves in 0.01 steps; otherwise it starts at 1 and moves in whole steps.
        """
        if self.flexible_duration:
            return gr.Slider(minimum=0.25, maximum=8., value=3., step=0.01, label="duration in sec")
        return gr.Slider(minimum=1., maximum=8., value=3., step=1., label="duration in sec")

    def get_guidance_scale_slider(self):
        """Return the "Guidance scale" slider (0-20, default 6, step 1)."""
        return gr.Slider(minimum=0., maximum=20., value=6., step=1.,
                         label="Guidance scale",
                         info="The larger this value, the more the generated sound is "
                              "influenced by the condition. Setting it to 0 is equivalent to "
                              "the negative case.")

    def get_noising_strength_slider(self, default_noising_strength=0.7):
        """Return the "noising strength" slider (0-1, step 0.01).

        Args:
            default_noising_strength: Initial slider value.
        """
        return gr.Slider(minimum=0.0, maximum=1.00, value=default_noising_strength, step=0.01,
                         label="noising strength",
                         info="The smaller this value, the more the generated sound is "
                              "closed to the origin.")

    def get_seed_textbox(self):
        """Return the single-line "Seed" textbox, pre-filled with 0."""
        return gr.Textbox(label="Seed", lines=1, placeholder="seed", value=0)