# sd3-with-LLM / app.py
# alfredplpl's picture
# Update app.py
# e2a600c verified
# raw
# history blame
# No virus
# 3.65 kB
# Thanks: https://huggingface.co/spaces/markmagic/Stable-Diffusion-3/blob/main/app.py
import os
import random
import uuid
import gradio as gr
import numpy as np
from PIL import Image
import spaces
import torch
from diffusers import StableDiffusion3Pipeline, DPMSolverMultistepScheduler, AutoencoderKL
# Markdown heading for the demo (the text is Japanese and user-facing).
DESCRIPTION = "# 日本語で入力できるStable Diffusion 3"

# Load the Stable Diffusion 3 Medium pipeline once, at import time, in half
# precision. The access token is read from the TOKEN environment variable
# (None when the variable is unset).
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    token=os.environ.get("TOKEN"),
    torch_dtype=torch.float16,
)
@spaces.GPU()
def generate(
    prompt: str,
    negative_prompt: str = "",
    seed: int = 0,
    width: int = 1024,
    height: int = 1024,
    guidance_scale: float = 7,
    num_inference_steps: int = 30,
    progress=gr.Progress(track_tqdm=True),
):
    """Run the module-level Stable Diffusion 3 pipeline for one prompt.

    Args:
        prompt: Text describing the image to generate.
        negative_prompt: Text describing what the image should not contain.
        seed: Seed for the random generator, so a given seed is reproducible.
        width: Output image width in pixels.
        height: Output image height in pixels.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps forwarded to the
            pipeline.
        progress: Gradio progress tracker; declared as a default argument so
            Gradio can mirror tqdm progress in the UI. Not used directly.

    Returns:
        The ``images`` attribute of the pipeline output (PIL images, because
        ``output_type="pil"`` is requested).
    """
    # Move the shared pipeline to the GPU without rebinding the name: writing
    # ``pipe = pipe.to(...)`` here would make ``pipe`` a local variable and
    # raise UnboundLocalError before the pipeline is ever used.
    pipe.to("cuda")

    # UI widgets may hand numeric values over as floats; the seed must be an
    # integer for ``manual_seed``, and sizes/step counts are integral too.
    generator = torch.Generator().manual_seed(int(seed))

    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=int(width),
        height=int(height),
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        generator=generator,
        output_type="pil",
    ).images
    return output
# Example prompts offered in the UI.
examples = ["A red sofa on top of a white building."]

# Custom CSS for the Gradio page: cap the container width and centre <h1>.
css = (
    "\n"
    ".gradio-container{max-width: 1000px !important}\n"
    "h1{text-align:center}\n"
)
# Upper bound of the seed slider (largest signed 32-bit integer). Converted to
# a plain int so it is not passed around as a NumPy scalar.
MAX_SEED = int(np.iinfo(np.int32).max)

# Example caching is opt-in via the CACHE_EXAMPLES environment variable,
# because caching runs the (GPU-backed) generation function for every example.
CACHE_EXAMPLES = os.getenv("CACHE_EXAMPLES", "0") == "1"

# Build the Gradio UI: a title, a prompt box with a Run button, a result
# gallery, and an accordion with advanced generation options.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from a source whose indentation was lost; confirm it matches the intended
# layout.
with gr.Blocks(css=css) as demo:
    with gr.Row():
        with gr.Column():
            gr.HTML(
                "\n"
                "<h1 style='text-align: center'>\n"
                "日本語で入力できるStable Diffusion 3 Medium\n"
                "</h1>\n"
            )
            gr.HTML("\n")

    with gr.Group():
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Gallery(label="Result", elem_id="gallery", show_label=False)

    with gr.Accordion("Advanced options", open=False):
        with gr.Row():
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                value="deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, NSFW",
                visible=True,
            )
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=0,
        )
        steps = gr.Slider(
            label="Steps",
            minimum=1,  # zero denoising steps is not a usable setting
            maximum=60,
            step=1,
            value=30,
        )
        with gr.Row():
            guidance_scale = gr.Slider(
                label="Guidance Scale",
                minimum=0.1,
                maximum=10,
                step=0.1,
                value=7.0,
            )

    # Fixed image size. These sit in the input list only so that the values
    # after them line up with generate()'s positional parameters
    # (prompt, negative_prompt, seed, width, height, guidance_scale, steps).
    width = gr.State(1024)
    height = gr.State(1024)

    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=[result],
        fn=generate,
        cache_examples=CACHE_EXAMPLES,
    )

    # Submitting either text box or clicking Run triggers a generation.
    gr.on(
        triggers=[
            prompt.submit,
            negative_prompt.submit,
            run_button.click,
        ],
        fn=generate,
        # Inputs are passed positionally, so the order must match generate()'s
        # signature; without width/height here, guidance_scale and steps
        # would be received as the image width and height.
        inputs=[
            prompt,
            negative_prompt,
            seed,
            width,
            height,
            guidance_scale,
            steps,
        ],
        outputs=[result],
    )
# Script entry point: enable request queuing, then start the Gradio server.
if __name__ == "__main__":
    queued_demo = demo.queue()
    queued_demo.launch()