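"""Gradio demo: generate audio from a text prompt with AudioLDM and, when several
candidates are requested, return the one CLAP scores as the best match."""
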
import gradio as gr
import torch
from diffusers import AudioLDMPipeline

from transformers import AutoProcessor, ClapModel


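# Run everything on CPU in float32; the assumption here is a GPU-less host.
# On a CUDA machine, "cuda" with torch.float16 would be the usual choice.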
device = "cpu"
torch_dtype = torch.float32

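# Load the AudioLDM text-to-audio pipeline and compile the UNet (PyTorch 2.x)
# to speed up repeated inference.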
repo_id = "cvssp/audioldm-m-full"
pipe = AudioLDMPipeline.from_pretrained(repo_id, torch_dtype=torch_dtype).to(device)
pipe.unet = torch.compile(pipe.unet)


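# CLAP scores text-audio similarity; it is used below to pick the best of
# several candidate waveforms.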
clap_model = ClapModel.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full").to(device)
processor = AutoProcessor.from_pretrained("sanchit-gandhi/clap-htsat-unfused-m-full")

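# A single Generator is reused and re-seeded on every request for reproducibility.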
generator = torch.Generator(device=device)


def text2audio(text, negative_prompt, duration, guidance_scale, random_seed, n_candidates):
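    """Generate audio for `text` with AudioLDM and return a waveform video."""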
    # Gradio delivers an empty string (not None) for a blank textbox.
    if not text:
        raise gr.Error("Please provide a text input")

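    # Generate one or more candidate waveforms for the prompt.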
    waveforms = pipe(
        text,
        audio_length_in_s=duration,
        guidance_scale=guidance_scale,
        negative_prompt=negative_prompt,
        num_waveforms_per_prompt=n_candidates if n_candidates else 1,
        generator=generator.manual_seed(int(random_seed)),
    )["audios"]

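    # With multiple candidates, let CLAP pick the best text-audio match.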
    if waveforms.shape[0] > 1:
        waveform = score_waveforms(text, waveforms)
    else:
        waveform = waveforms[0]

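    # AudioLDM outputs 16 kHz audio; render it as a waveform video for Gradio.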
    return gr.make_waveform((16000, waveform))


def score_waveforms(text, waveforms):
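    """Score candidate waveforms with CLAP and return the best match for the text."""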
    inputs = processor(text=text, audios=list(waveforms), return_tensors="pt", padding=True)
    inputs = {key: inputs[key].to(device) for key in inputs}
    with torch.no_grad():
        logits_per_text = clap_model(**inputs).logits_per_text
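        # Softmax over the audio candidates; the most probable one wins.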
        probs = logits_per_text.softmax(dim=-1)
        most_probable = torch.argmax(probs)
    waveform = waveforms[most_probable]
    return waveform


iface = gr.Blocks()

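# Assemble the UI: prompt inputs, advanced options, and a waveform-video output.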
with iface:
    with gr.Group():
        with gr.Box():
            textbox = gr.Textbox(
                max_lines=1,
                label="Prompt",
                info="Describe the audio you want to generate",
                elem_id="prompt-in",
            )
            negative_textbox = gr.Textbox(
                max_lines=1,
                label="Negative prompt",
                info="Describe what the audio should avoid",
                elem_id="negative-prompt-in",
            )

            with gr.Accordion("Expand for more options", open=False):
                seed = gr.Number(
                    value=45,
                    label="Seed",
                    info="Different seeds give different results; the same seed always gives the same result",
                )
                duration = gr.Slider(2.5, 10, value=5, step=2.5, label="Duration (seconds)")
                guidance_scale = gr.Slider(
                    0,
                    4,
                    value=2.5,
                    step=0.5,
                    label="Guidance scale",
                    info="Larger values give better quality and closer adherence to the text; smaller values give more diversity",
                )
                n_candidates = gr.Slider(
                    1,
                    3,
                    value=3,
                    step=1,
                    label="Number of candidates",
                    info="How many waveforms to generate; the best CLAP match is returned",
                )

            outputs = gr.Video(label="Output", elem_id="output-video")
            btn = gr.Button("Submit").style(full_width=True)

        btn.click(
            text2audio,
            inputs=[textbox, negative_textbox, duration, guidance_scale, seed, n_candidates],
            outputs=[outputs],
        )

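# Queue requests (CPU generation is slow) and launch the app.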
iface.queue(max_size=10).launch(debug=True)