File size: 1,704 Bytes
c8030d6
f2cdf65
 
 
87fc5ec
f2cdf65
 
c8030d6
 
 
 
 
 
 
 
87fc5ec
 
 
a18e7e2
 
c8030d6
 
 
 
87fc5ec
 
 
c8030d6
87fc5ec
 
 
 
 
 
 
 
 
 
 
 
c8030d6
87fc5ec
c8030d6
 
87fc5ec
 
 
 
 
 
 
 
 
 
 
 
 
c8030d6
87fc5ec
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import gradio as gr
import torch
from torch import autocast
from diffusers import StableDiffusionPipeline
from PIL import Image
from io import BytesIO
import base64
import re
import os
import requests


# Hugging Face model id of the Stable Diffusion checkpoint served by this demo.
model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda"

# Load the pipeline from the "fp16" revision with float16 weights,
# authenticating with the locally saved Hugging Face token, then move it
# onto the target device in the same expression.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=True,
).to(device)

# Allow cuDNN to benchmark and pick the fastest convolution algorithms.
torch.backends.cudnn.benchmark = True

# Re-entrancy guard: True while image_generation is running the pipeline.
is_gpu_busy = False


@torch.no_grad()
def image_generation(prompt, samples=4, steps=25, scale=7.5):
    """Generate a batch of images for ``prompt`` with the shared pipeline.

    Args:
        prompt: Text prompt; it is repeated ``samples`` times to form a batch.
        samples: Number of images to generate (batch size).
        steps: Number of denoising steps, passed as ``num_inference_steps``.
        scale: Guidance scale, passed as ``guidance_scale``.

    Returns:
        The list of images produced by the pipeline, or an empty list if
        another generation is already in progress.
    """
    global is_gpu_busy

    # Reject overlapping requests instead of running two batches on the GPU.
    # NOTE(review): this check-then-set is not atomic; if the UI dispatches
    # handlers on several threads, two requests can still pass together.
    # A threading.Lock acquired with blocking=False would close that window.
    if is_gpu_busy:
        return []

    is_gpu_busy = True
    try:
        with autocast("cuda"):
            images = pipe(
                # Slider values may arrive as floats (depending on the Gradio
                # version); list repetition and the step count need ints.
                [prompt] * int(samples),
                num_inference_steps=int(steps),
                guidance_scale=scale,
            ).images
    finally:
        # Always release the guard, even when the pipeline raises (e.g. a
        # CUDA out-of-memory error). Otherwise every later request would be
        # rejected forever.
        is_gpu_busy = False

    return images


with gr.Blocks() as demo:
    # Title and prompt entry.
    gr.Markdown("# Stable Diffusion demo\nType something and generate images!")
    textbox = gr.Textbox(placeholder="Something cool...", interactive=True)

    # Generation parameters, in the same order as image_generation's
    # (samples, steps, scale) parameters that follow the prompt.
    with gr.Column(scale=1):
        samples = gr.Slider(minimum=1, maximum=8, value=4, step=1, label="samples")
        steps = gr.Slider(minimum=10, maximum=50, value=25, step=1, label="steps")
        scale = gr.Slider(minimum=5, maximum=15, value=7.5, step=0.1, label="scale")

    submit = gr.Button("Submit", variant="primary")
    gr.Markdown("Images will appear below")
    with gr.Row():
        gallery = gr.Gallery()

    # Pressing Enter in the textbox and clicking Submit run the same handler
    # with the same inputs; register them in that order.
    generation_inputs = [textbox, samples, steps, scale]
    for trigger in (textbox.submit, submit.click):
        trigger(image_generation, inputs=generation_inputs, outputs=[gallery])

demo.launch()