File size: 5,381 Bytes
27bce5c
3e79bbd
 
 
 
 
35afeb7
3e79bbd
 
 
27bce5c
3e79bbd
877e056
 
 
3e79bbd
 
 
 
 
27bce5c
 
 
 
 
 
 
3e79bbd
 
 
 
 
b0c7882
 
 
 
 
 
 
 
 
 
73d4af2
b0c7882
 
 
3e79bbd
 
 
 
 
8a2ae69
 
 
3e79bbd
 
78e5ece
8a00b4b
 
 
 
 
3e79bbd
 
 
877e056
 
3e79bbd
b0c7882
 
 
3e79bbd
35afeb7
 
 
 
27bce5c
3e79bbd
 
 
 
 
 
 
b0c7882
 
 
 
 
3e79bbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
877e056
3e79bbd
b0c7882
 
 
 
 
3e79bbd
8a2ae69
877e056
b0c7882
35afeb7
27bce5c
a2735a4
 
 
 
b0c7882
92dbcc7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import random
import tempfile

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image

from spectro import wav_bytes_from_spectrogram_image

# Run on the GPU in half precision when one is available; otherwise fall
# back to the CPU with full precision.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.float16
else:
    device, dtype = "cpu", torch.float32

# Stable Diffusion pipeline used by get_bg_image for the video background.
# NOTE(review): the "runwayml/stable-diffusion-v1-5" repo id may no longer
# resolve on the Hugging Face Hub -- confirm it is still available.
model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=dtype,
    revision="fp16",
).to(device)

# Riffusion checkpoint; get_music turns its output image into audio via
# wav_bytes_from_spectrogram_image.
model_id2 = "riffusion/riffusion-model-v1"
pipe2 = StableDiffusionPipeline.from_pretrained(
    model_id2,
    torch_dtype=dtype,
).to(device)

# Colour pairs passed as `bars_color` to gr.make_waveform; infer picks one
# at random for each request.
COLORS = [
    ["#ff0000", "#00ff00"],
    ["#00ff00", "#0000ff"],
    ["#0000ff", "#ff0000"],
]

# HTML header rendered at the top of the app via gr.HTML(title). The app only
# runs inference (image + audio generation); it does not train anything.
title = """
        <div style="text-align: center; max-width: 650px; margin: 0 auto 10px;">
          <div style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;">
            <h1 style="font-weight: 950; margin-bottom: 7px; color: #000; font-weight: bold;">Riffusion and Stable Diffusion</h1>
          </div>
          <p style="text-align: center;font-size: 94%">
              Duplicate this Space and run it on your own profile using a (paid) private T4-small or A10G-small GPU for faster inference: 
              <span style="display: flex;align-items: center;justify-content: center;height: 30px;">
              <a href="https://huggingface.co/spaces/juancopi81/sd-riffusion?duplicate=true">
              <img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a>
              </span>
          </p>
          <p style="text-align: center;font-size: 94%">
              You can buy me a coffee to support this space: 
              <span style="display: flex;align-items: center;justify-content: center;height: 30px;">
              <a href="https://www.buymeacoffee.com/juancopi81j">
              <img src="https://badgen.net/badge/icon/Buy%20Me%20A%20Coffee?icon=buymeacoffee&label" alt="Buy me a coffee"></a>. Depending on the support, I'll keep this space running and add more features!
              </span>
          </p>
        </div>
        """
def get_bg_image(prompt):
    """Generate a background image for *prompt* with the Stable Diffusion pipeline.

    Args:
        prompt: Text prompt passed to the module-level ``pipe``.

    Returns:
        Path to a newly created PNG file that holds the first generated image.
        Each call gets its own file, so concurrent requests cannot overwrite
        each other's result (a single shared file name would allow that).
    """
    result = pipe(prompt)
    print("Image generated!")
    image_output = result.images[0]
    # delete=False: the file must outlive this function so Gradio can read it.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        image_output.save(tmp, format="PNG")
    return tmp.name

def spectrogram_width(duration):
    """Return the spectrogram width in pixels requested from ``pipe2`` for *duration*.

    A duration of 5 maps to 512 px, and every additional unit adds 128 px
    (``512 + (duration - 5) * 128``). The unit is presumably seconds of
    audio -- TODO confirm against the spectrogram settings in ``spectro``.
    *duration* is truncated with ``int()``.

    Raises:
        ValueError: if the resulting width would not be positive.
    """
    width = 512 + (int(duration) - 5) * 128
    if width <= 0:
        raise ValueError(f"duration must be at least 2, got {duration!r}")
    return width


def get_music(prompt, duration=10):
    """Generate audio for *prompt* with the Riffusion pipeline.

    Args:
        prompt: Text prompt passed to the module-level ``pipe2``.
        duration: Length control that sets the spectrogram width via
            ``spectrogram_width``. Defaults to 10, the value this function
            previously hard-coded.

    Returns:
        Path to a newly created WAV file. Each call gets its own file, so
        concurrent requests cannot overwrite each other's audio.

    Raises:
        ValueError: if *duration* is too small to give a positive width.
    """
    width = spectrogram_width(duration)
    spec = pipe2(prompt, height=512, width=width).images[0]
    print(spec)
    # The first element of the result exposes getbuffer() (presumably an
    # io.BytesIO holding WAV data -- confirm in spectro).
    wav = wav_bytes_from_spectrogram_image(spec)
    # delete=False: the file must outlive this function so Gradio can read it.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp.write(wav[0].getbuffer())
    return tmp.name

def build_image_prompt(prompt, style):
    """Combine the music prompt and the optional style into one image prompt.

    Each part is stripped of surrounding whitespace. Parts that are ``None``,
    empty or whitespace-only are dropped. The remaining parts are joined with
    ", " so the style text does not run into the prompt's last word.
    """
    parts = [part.strip() for part in (prompt, style) if part and part.strip()]
    return ", ".join(parts)


def infer(prompt, style):
    """Generate a waveform video for *prompt*, with its background styled by *style*.

    Args:
        prompt: Music prompt. It is used as-is for the audio and, combined
            with *style*, for the background image.
        style: Optional extra text for the background image prompt.

    Returns:
        The same video path twice: once for the ``gr.Video`` output and once
        for the ``gr.File`` download output.
    """
    image = get_bg_image(build_image_prompt(prompt, style))
    audio = get_music(prompt)
    video = gr.make_waveform(audio,
                             bg_image=image,
                             bars_color=random.choice(COLORS))
    return video, video

# Custom stylesheet passed to gr.Blocks(css=css). Each #id selector matches an
# elem_id given to a component in the UI layout: col-container (the main
# column), prompt-in / prompt-style (the two textboxes), btn-container (the
# row holding the button) and submit-btn (the Send button).
css = """
    #col-container {max-width: 700px; margin-left: auto; margin-right: auto;}
    #prompt-in {
        border: 2px solid #666;
        border-radius: 2px;
        padding: 8px;
    }
    #prompt-style {
        border: 2px solid #666;
        border-radius: 2px;
        padding: 8px;
    }
    #btn-container {
      display: flex;
      align-items: center;
      justify-content: center;
      width: calc(15% - 16px);
      height: calc(15% - 16px);
    }
    /* Style the submit button */
    #submit-btn {
      background-color: #382a1d;
      color: #fff;
      border: 1px solid #000;
      border-radius: 4px;
      padding: 8px;
      font-size: 16px;
      cursor: pointer;
    }
"""
# Build the UI: a header, the two prompt boxes, a Send button wired to infer,
# and the video/file outputs that infer fills in.
with gr.Blocks(css=css) as demo:
    gr.HTML(title)
    with gr.Column(elem_id="col-container"):
        prompt_input = gr.Textbox(placeholder="The Beatles playing for the queen",
                                  elem_id="prompt-in",
                                  label="Enter your music prompt.")
        style_input = gr.Textbox(placeholder="In the style of Vincent van Gogh",
                                 elem_id="prompt-style",
                                 label="(Optional) Add styles to your background image.",
                                 value="")
        with gr.Row(elem_id="btn-container"):
            send_btn = gr.Button(value="Send", elem_id="submit-btn")
        # infer returns the same video path twice: one copy for the player,
        # one for the downloadable file. The output components are created
        # inline, so they render at this point in the column.
        send_btn.click(infer,
                       inputs=[prompt_input, style_input],
                       outputs=[gr.Video(), gr.File()])

        gr.Markdown("""
            [![Twitter Follow](https://img.shields.io/twitter/follow/juancopi81?style=social)](https://twitter.com/juancopi81)
            ![visitors](https://visitor-badge.glitch.me/badge?page_id=Juancopi81.sd-riffusion)
          """)

# Route events through Gradio's queue. Each request runs two diffusion
# pipelines on the GPU, so letting Gradio schedule requests is safer than
# running them all at once.
demo.queue()
demo.launch(debug=True)