# Spaces: Running
import dataclasses | |
import gradio as gr | |
import spaces | |
import torch | |
from PIL import Image | |
from diffusers import DiffusionPipeline | |
from diffusers.utils import make_image_grid | |
# Model identifiers that are handed to ``DiffusionPipeline.from_pretrained``
# exactly as written.
DIFFUSERS_MODEL_IDS = [
    # SD Models
    "stabilityai/stable-diffusion-3-medium-diffusers",
    "stabilityai/stable-diffusion-xl-base-1.0",
    "stabilityai/stable-diffusion-2-1",
    "runwayml/stable-diffusion-v1-5",
    # Other Models
    "Prgckwb/trpfrog-diffusion",
]

# Display name -> checkpoint location for models that are selected by a
# friendly name instead of an identifier (presumably local directories).
EXTERNAL_MODEL_MAPPING = {
    "Beautiful Realistic Asians": "checkpoints/diffusers/Beautiful Realistic Asians v7",
}

# Everything offered in the model dropdown: identifiers first, then the
# display names of the external checkpoints.
MODEL_CHOICES = [*DIFFUSERS_MODEL_IDS, *EXTERNAL_MODEL_MAPPING]

# Module-level state that ``inference`` overwrites on every call.
current_model_id = "stabilityai/stable-diffusion-3-medium-diffusers"
pipe = None

# Run on the GPU whenever one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
@dataclasses.dataclass
class Input:
    """One set of generation settings, used to build rows of ``EXAMPLES``.

    Only ``prompt`` is required; every other field has a default. The class
    must be a dataclass so that the annotated fields below become keyword
    arguments of the generated ``__init__`` (the examples are created with
    ``Input(prompt=..., width=..., ...)``).
    """

    prompt: str
    model_id: str = "stabilityai/stable-diffusion-3-medium-diffusers"
    negative_prompt: str = ''
    width: int = 1024
    height: int = 1024
    guidance_scale: float = 7.5
    num_inference_step: int = 28
    num_images: int = 4
    use_safety_checker: bool = True
    use_model_offload: bool = False
    seed: int = 8888

    def to_list(self):
        """Return the settings as a flat list in a fixed positional order.

        The order is: prompt, model id, negative prompt, width, height,
        guidance scale, inference steps, image count, safety-checker flag,
        model-offload flag, seed. It is spelled out explicitly so that it
        does not silently change if the fields are ever reordered.
        """
        return [
            self.prompt, self.model_id, self.negative_prompt,
            self.width, self.height, self.guidance_scale,
            self.num_inference_step, self.num_images,
            self.use_safety_checker, self.use_model_offload,
            self.seed,
        ]
# Keyword settings for each example; anything omitted falls back to the
# defaults declared on ``Input``.
_EXAMPLE_SETTINGS = (
    dict(prompt='A cat holding a sign that says Hello world'),
    dict(
        prompt='Beautiful pixel art of a Wizard with hovering text "Achivement unlocked: Diffusion models can spell now"',
    ),
    dict(prompt='A corgi wearing sunglasses says "U-Net is OVER!!"'),
    dict(
        prompt='Cinematic Photo of a beautiful korean fashion model bokeh train',
        model_id='Beautiful Realistic Asians',
        negative_prompt='worst_quality, BadNegAnatomyV1-neg, bradhands cartoon, cgi, render, illustration, painting, drawing',
        width=512,
        height=512,
        guidance_scale=5.0,
        num_inference_step=50,
    ),
)

# One flat row of values per example, in the order produced by
# ``Input.to_list``.
EXAMPLES = [Input(**settings).to_list() for settings in _EXAMPLE_SETTINGS]
# NOTE(review): `spaces` is imported at the top of the file but never used in
# the visible code; if this app runs on ZeroGPU hardware this function
# presumably needs the `@spaces.GPU` decorator -- confirm against the deployment.
def inference(
    prompt: str,
    model_id: str = "stabilityai/stable-diffusion-3-medium-diffusers",
    negative_prompt: str = "",
    width: int = 512,
    height: int = 512,
    guidance_scale: float = 7.5,
    num_inference_steps: int = 50,
    num_images: int = 4,
    safety_checker: bool = True,
    use_model_offload: bool = False,
    seed: int = 8888,
    progress=gr.Progress(track_tqdm=True),
) -> Image.Image:
    """Load the requested pipeline, generate images and return them as a grid.

    The pipeline is loaded afresh on every call and stored in the
    module-level ``pipe``; ``current_model_id`` is set to the identifier or
    path that was actually loaded.

    Args:
        prompt: Text prompt passed to the pipeline.
        model_id: Either an entry of ``DIFFUSERS_MODEL_IDS`` or a key of
            ``EXTERNAL_MODEL_MAPPING``.
        negative_prompt: Negative prompt passed to the pipeline.
        width: Output width passed to the pipeline.
        height: Output height passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        num_inference_steps: Number of denoising steps.
        num_images: Number of images generated for the prompt.
        safety_checker: When false, ``pipe.safety_checker`` is set to ``None``.
        use_model_offload: When true (and a GPU is in use), call
            ``enable_model_cpu_offload`` instead of moving the whole
            pipeline to the device.
        seed: Seed for the random generator.
        progress: Gradio progress tracker.

    Returns:
        A single image containing every generated image: one column when
        ``num_images`` is odd, two rows otherwise.

    Raises:
        KeyError: If ``model_id`` is in neither ``DIFFUSERS_MODEL_IDS`` nor
            ``EXTERNAL_MODEL_MAPPING``.
    """
    global current_model_id, pipe

    progress(0, "Starting inference...")

    # Values coming from numeric UI fields may arrive as floats (presumably
    # depending on the widget configuration); the grid arithmetic and
    # ``manual_seed`` below need real integers.
    width, height = int(width), int(height)
    num_inference_steps = int(num_inference_steps)
    num_images = int(num_images)
    seed = int(seed)

    progress(0.1, 'Loading pipeline...')

    # Decide once whether this is an external checkpoint, *before* the name
    # is replaced by its path, so the later check cannot drift.
    is_external = model_id not in DIFFUSERS_MODEL_IDS
    if is_external:
        model_id = EXTERNAL_MODEL_MAPPING[model_id]

    on_gpu = device == "cuda"
    pipe = DiffusionPipeline.from_pretrained(
        model_id,
        # Half precision is only used on the GPU; fall back to full
        # precision when running on the CPU.
        torch_dtype=torch.float16 if on_gpu else torch.float32,
    )
    current_model_id = model_id

    if not safety_checker:
        pipe.safety_checker = None

    if is_external:
        progress(0.3, 'Loading Textual Inversion...')
        # Load Textual Inversion
        pipe.load_textual_inversion('checkpoints/embeddings/BadNegAnatomyV1 neg.pt', token='BadNegAnatomyV1-neg')
        pipe.load_textual_inversion('checkpoints/embeddings/Deep Negative V1 75T.pt', token='DeepNegativeV1')
        pipe.load_textual_inversion('checkpoints/embeddings/easynegative.safetensors', token='EasyNegative')

    # Generation
    progress(0.4, 'Generating images...')
    if use_model_offload and on_gpu:
        pipe.enable_model_cpu_offload()
    else:
        # Use the detected device instead of a hard-coded 'cuda' so the
        # function also works on hosts without a GPU.
        pipe = pipe.to(device)

    generator = torch.Generator(device=device).manual_seed(seed)

    images = pipe(
        prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=num_images,
        generator=generator,
    ).images

    # An odd count cannot be split into two equal rows, so stack the images
    # in a single column instead.
    if num_images % 2 == 1:
        image = make_image_grid(images, rows=num_images, cols=1)
    else:
        image = make_image_grid(images, rows=2, cols=num_images // 2)

    return image
if __name__ == "__main__":
    # NOTE(review): the nesting below was reconstructed from a listing whose
    # indentation had been flattened -- confirm the intended layout.
    emerald_theme = gr.themes.Default(primary_hue=gr.themes.colors.emerald)

    with gr.Blocks(theme=emerald_theme) as demo:
        gr.Markdown("# Stable Diffusion Demo")

        with gr.Row():
            # First column: prompt, model choice, settings and notes.
            with gr.Column():
                prompt_box = gr.Text(label="Prompt", placeholder="Enter a prompt here")
                model_dropdown = gr.Dropdown(
                    label="Model ID",
                    choices=MODEL_CHOICES,
                    value="stabilityai/stable-diffusion-3-medium-diffusers",
                )

                # Additional Input Settings
                with gr.Accordion("Additional Settings", open=False):
                    negative_prompt_box = gr.Text(label="Negative Prompt", value="")

                    with gr.Row():
                        width_field = gr.Number(
                            label="Width", value=512, step=64, minimum=64, maximum=2048
                        )
                        height_field = gr.Number(
                            label="Height", value=512, step=64, minimum=64, maximum=2048
                        )

                    num_images_field = gr.Number(
                        label="Num Images", value=4, minimum=1, maximum=10, step=1
                    )
                    seed_field = gr.Number(label="Seed", value=8888, step=1)

                    guidance_slider = gr.Slider(
                        label="Guidance Scale", value=7.5, step=0.5, minimum=0, maximum=10
                    )
                    steps_slider = gr.Slider(
                        label="Num Inference Steps", value=50, minimum=1, maximum=100, step=2
                    )

                    with gr.Row():
                        safety_checkbox = gr.Checkbox(value=True, label='Use Safety Checker')
                        offload_checkbox = gr.Checkbox(value=False, label='Use Model Offload')

                with gr.Accordion(label='Notes', open=False):
                    # language=HTML
                    notes = gr.HTML(
                        """
                        <h2>Negative Embeddings</h2>
                        <p>If you want to use negative embedding, use the following tokens in the prompt.</p>
                        <ul>
                        <li><a href='https://civitai.com/models/59614/badneganatomy-textual-inversion'>BadNegAnatomyV1-neg</a></li>
                        <li><a href='https://civitai.com/models/4629/deep-negative-v1x'>DeepNegativeV1</a> </li>
                        <li><a href='https://civitai.com/models/7808/easynegative'>EasyNegative</a></li>
                        </ul>
                        """
                    )

            # Second column: the generated image grid.
            with gr.Column():
                output_image = gr.Image(label="Image", type="pil")

        # Positional order must match the parameters of ``inference`` and the
        # rows stored in ``EXAMPLES``.
        input_components = [
            prompt_box,
            model_dropdown,
            negative_prompt_box,
            width_field,
            height_field,
            guidance_slider,
            steps_slider,
            num_images_field,
            safety_checkbox,
            offload_checkbox,
            seed_field,
        ]

        generate_button = gr.Button("Generate")
        generate_button.click(
            fn=inference,
            inputs=input_components,
            outputs=output_image,
        )

        gr.Examples(
            examples=EXAMPLES,
            inputs=input_components,
        )

    demo.queue().launch()