File size: 4,666 Bytes
8a3e19c
29dfedd
8a3e19c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bbf5d76
8a3e19c
 
bbf5d76
8a3e19c
 
 
 
 
 
 
 
 
 
 
 
 
bbf5d76
8a3e19c
 
 
 
 
 
bbf5d76
 
 
8a3e19c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af7f8ab
 
8a3e19c
 
 
 
 
 
 
 
 
bbf5d76
 
8a3e19c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bbf5d76
8a3e19c
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import torch
from diffusers import StableDiffusionXLPipeline, DPMSolverSinglestepScheduler
import gradio as gr
import random
import numpy as np
import spaces


if torch.cuda.is_available():
    device = "cuda"
    print("Using GPU")
else:
    device = "cpu"
    print("Using CPU")


# Largest value the seed slider offers (upper bound of a signed 32-bit int).
MAX_SEED = np.iinfo(np.int32).max


# Half precision is only practical on the GPU: many CPU kernels have no
# float16 implementation (or a very slow one), so fall back to float32 there.
dtype = torch.float16 if device == "cuda" else torch.float32

# Initialize the pipeline and download the SDXL Flash model.
pipe = StableDiffusionXLPipeline.from_pretrained("sd-community/sdxl-flash", torch_dtype=dtype)
pipe.to(device)

# Ensure the sampler uses "trailing" timesteps.
pipe.scheduler = DPMSolverSinglestepScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")


# Define the image generation function
@spaces.GPU(duration=60)
def generate_image(prompt, negative_prompt, num_inference_steps, height, width, guidance_scale, seed, num_images_per_prompt, progress=gr.Progress(track_tqdm=True)):
    """Generate images from a text prompt with the module-level SDXL pipeline.

    Args:
        prompt: Text describing the desired image.
        negative_prompt: Text describing what should not appear in the image.
        num_inference_steps: Number of denoising steps.
        height: Output image height in pixels.
        width: Output image width in pixels.
        guidance_scale: How strongly generation follows the prompt.
        seed: Seed for the random generator; 0 means "pick a random seed".
        num_images_per_prompt: How many images to generate.
        progress: Gradio progress tracker; ``track_tqdm=True`` mirrors the
            pipeline's tqdm progress in the UI.

    Returns:
        The ``.images`` list produced by the pipeline call.
    """
    # UI sliders / API clients may hand over floats (e.g. 42.0); the generator
    # seed and the pipeline's size/count arguments must be integers.
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    height = int(height)
    width = int(width)
    num_images_per_prompt = int(num_images_per_prompt)
    guidance_scale = float(guidance_scale)

    if seed == 0:
        seed = random.randint(1, 2**32 - 1)

    # torch.Generator() with no device argument is a CPU generator.
    generator = torch.Generator().manual_seed(seed)

    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        height=height,
        width=width,
        guidance_scale=guidance_scale,
        generator=generator,
        num_images_per_prompt=num_images_per_prompt,
    ).images

    return output



# Create the Gradio interface

# Example prompts shown under the UI; each row fills the single Prompt input.
examples = [
    ["A white car racing fast to the moon."],
    ["A woman in a red dress singing on top of a building."],
    ["An astronaut on mars in a futuristic cyborg suit."],
]

# Page styling: cap the app container at 1000px wide and center <h1> headings.
css = "\n".join(
    [
        "",
        ".gradio-container{max-width: 1000px !important}",
        "h1{text-align:center}",
        "",
    ]
)
# Build the Gradio UI (header, prompt + gallery, advanced settings, example
# prompts), wire Run / Enter to generate_image, and start the server.
with gr.Blocks(css=css) as demo:
    with gr.Row():
        with gr.Column():
            gr.HTML(
            """
            <h1 style='text-align: center'>
            SDXL Flash
            </h1>
            """
            )
            gr.HTML(
                """
               Made by <a href='https://linktr.ee/Nick088' target='_blank'>Nick088</a>
               <br> <a href="https://discord.gg/osai"> <img src="https://img.shields.io/discord/1198701940511617164?color=%23738ADB&label=Discord&style=for-the-badge" alt="Discord"> </a>
                """
            )
    with gr.Group():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", info="Describe the image you want", placeholder="A cat...")
            run_button = gr.Button("Run")
        result = gr.Gallery(label="Generated AI Images", elem_id="gallery")
    with gr.Accordion("Advanced options", open=False):
        with gr.Row():
            negative_prompt = gr.Textbox(label="Negative Prompt", info="Describe what you don't want in the image", value="deformed, distorted, disfigured, poorly drawn, bad anatomy, incorrect anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation", placeholder="Ugly, bad anatomy...")
        with gr.Row():
            num_inference_steps = gr.Slider(label="Number of Inference Steps", info="The number of denoising steps of the image. More denoising steps usually lead to a higher quality image at the cost of slower inference", minimum=1, maximum=15, value=8, step=1)
            guidance_scale = gr.Slider(label="Guidance Scale", info="Controls how much the image generation process follows the text prompt. Higher values make the image stick more closely to the input text.", minimum=0.0, maximum=6.0, value=3.5, step=0.1)
        with gr.Row():
            width = gr.Slider(label="Width", info="Width of the Image", minimum=256, maximum=1344, step=32, value=1024)
            height = gr.Slider(label="Height", info="Height of the Image", minimum=256, maximum=1344, step=32, value=1024)
        with gr.Row():
            seed = gr.Slider(value=42, minimum=0, maximum=MAX_SEED, step=1, label="Seed", info="A starting point to initiate the generation process, put 0 for a random one")
            num_images_per_prompt = gr.Slider(label="Images Per Prompt", info="Number of Images to generate with the settings", minimum=1, maximum=4, step=1, value=2)

    # Clicking an example only fills in the prompt. No fn/outputs are passed:
    # generate_image needs eight inputs, so wiring it here with just [prompt]
    # would fail as soon as example caching was enabled.
    gr.Examples(
        examples=examples,
        inputs=[prompt],
    )

    # Gradio passes `inputs` positionally, so this order must match
    # generate_image's parameters: (..., num_inference_steps, height, width, ...).
    gr.on(
        triggers=[
            prompt.submit,
            run_button.click,
        ],
        fn=generate_image,
        inputs=[
            prompt,
            negative_prompt,
            num_inference_steps,
            height,
            width,
            guidance_scale,
            seed,
            num_images_per_prompt,
        ],
        outputs=[result],
    )

demo.queue().launch(share=False)