# --- Hugging Face file-viewer page header (not part of the program) ---
# Prgckwb's picture
# :tada: add external model
# c3d14af
# raw
# history blame
# No virus
# 5.2 kB
import dataclasses
import gradio as gr
import requests
import spaces
import torch
from PIL import Image
from diffusers import DiffusionPipeline
from diffusers.utils import make_image_grid
# Identifiers that are handed to ``DiffusionPipeline.from_pretrained`` as-is.
DIFFUSERS_MODEL_IDS = [
    # SD Models
    "stabilityai/stable-diffusion-3-medium-diffusers",
    "stabilityai/stable-diffusion-xl-base-1.0",
    "stabilityai/stable-diffusion-2-1",
    "runwayml/stable-diffusion-v1-5",
    # Other Models
    "Prgckwb/trpfrog-diffusion",
]

# Display name -> value passed to ``from_pretrained`` instead of that name.
EXTERNAL_MODEL_MAPPING = {
    "Beautiful Realistic Asians": "checkpoints/diffusers/Beautiful Realistic Asians v7",
}

# Every selectable model: the plain identifiers first, then the display names
# of the external models (iterating a dict yields its keys in insertion order).
MODEL_CHOICES = [*DIFFUSERS_MODEL_IDS, *EXTERNAL_MODEL_MAPPING]
# Global Variables
# Identifier of the model whose weights are currently held by ``pipe``.
current_model_id = "stabilityai/stable-diffusion-3-medium-diffusers"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Bind ``pipe`` unconditionally so the module-level name also exists on hosts
# without CUDA; it stays ``None`` there because no pipeline is loaded.
pipe = None
if device == "cuda":
    # Load the default model eagerly, in half precision, onto the GPU.
    pipe = DiffusionPipeline.from_pretrained(
        current_model_id,
        torch_dtype=torch.float16,
    ).to(device)
@dataclasses.dataclass
class Input:
    """One complete set of generation settings.

    Only ``prompt`` is required; every other field has a default. The fields
    are declared in the same order as the positional parameters of
    ``inference``, so ``to_list()`` can be used as a positional argument row.
    """

    prompt: str
    model_id: str = "stabilityai/stable-diffusion-3-medium-diffusers"
    negative_prompt: str = ''
    width: int = 1024
    height: int = 1024
    guidance_scale: float = 7.5
    num_inference_step: int = 28
    num_images: int = 4

    def to_list(self):
        """Return the field values as a list, in declaration order."""
        return [getattr(self, spec.name) for spec in dataclasses.fields(self)]
# One argument row per example prompt; every setting other than the prompt
# comes from the defaults declared on ``Input``.
EXAMPLES = [
    Input(prompt=text).to_list()
    for text in (
        'A cat holding a sign that says Hello world',
        'Beautiful pixel art of a Wizard with hovering text "Achivement unlocked: Diffusion models can spell now"',
        'A corgi wearing sunglasses says "U-Net is OVER!!"',
    )
]
@spaces.GPU()
@torch.inference_mode()
def inference(
    prompt: str,
    model_id: str = "stabilityai/stable-diffusion-3-medium-diffusers",
    negative_prompt: str = "",
    width: int = 512,
    height: int = 512,
    guidance_scale: float = 7.5,
    num_inference_steps: int = 50,
    num_images: int = 4,
    progress=gr.Progress(track_tqdm=True),
) -> Image.Image:
    """Generate images for *prompt* and return them tiled into one image.

    Args:
        prompt: Text describing the desired image.
        model_id: An entry of ``DIFFUSERS_MODEL_IDS`` or a key of
            ``EXTERNAL_MODEL_MAPPING``. The pipeline is reloaded only when
            this resolves to something other than the currently loaded model.
        negative_prompt: Forwarded to the pipeline as ``negative_prompt``.
        width: Forwarded to the pipeline as ``width`` (coerced to ``int``).
        height: Forwarded to the pipeline as ``height`` (coerced to ``int``).
        guidance_scale: Forwarded to the pipeline as ``guidance_scale``.
        num_inference_steps: Forwarded to the pipeline (coerced to ``int``).
        num_images: Number of images requested per prompt (coerced to
            ``int``); also decides the grid shape.
        progress: Gradio progress tracker; called once at the start.

    Returns:
        A single image: one column of ``num_images`` rows when the count is
        odd, otherwise two rows of ``num_images // 2`` columns.

    Raises:
        gr.Error: If no CUDA device is available, a numeric setting is not a
            number, ``num_images`` is below 1, ``model_id`` is unknown, or
            loading the requested model fails.
    """
    global current_model_id, pipe

    progress(0, "Starting inference...")

    if device != 'cuda':
        raise gr.Error("This model requires a GPU to run. Please switch to a GPU runtime.")

    # Number inputs can arrive as floats (or None when a field is cleared);
    # the pipeline and the grid helper need real integers.
    try:
        width = int(width)
        height = int(height)
        num_inference_steps = int(num_inference_steps)
        num_images = int(num_images)
    except (TypeError, ValueError) as e:
        raise gr.Error("Width, height, inference steps and image count must be numbers.") from e

    if num_images < 1:
        raise gr.Error("At least one image must be requested.")

    if model_id != current_model_id:
        # Translate the selected name into what ``from_pretrained`` should load.
        if model_id in DIFFUSERS_MODEL_IDS:
            model_source = model_id
        elif model_id in EXTERNAL_MODEL_MAPPING:
            model_source = EXTERNAL_MODEL_MAPPING[model_id]
        else:
            raise gr.Error(f"Unknown model: {model_id!r}")

        # ``current_model_id`` holds the resolved source, so compare again after
        # resolving; otherwise an external model, whose display name never
        # equals its path, would be reloaded on every single call.
        if model_source != current_model_id:
            try:
                loaded = DiffusionPipeline.from_pretrained(
                    model_source,
                    torch_dtype=torch.float16,
                ).to(device)
            except Exception as e:
                # Surface the loader's message in the UI, keeping the cause chained.
                raise gr.Error(str(e)) from e
            # Swap the globals only after a fully successful load.
            pipe = loaded
            current_model_id = model_source

    # Generation
    images = pipe(
        prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=num_images,
    ).images

    # Odd counts cannot fill two equal rows, so stack them in a single column.
    if num_images % 2 == 1:
        rows, cols = num_images, 1
    else:
        rows, cols = 2, num_images // 2

    return make_image_grid(images, rows=rows, cols=cols)
if __name__ == "__main__":
    # Build and launch the Gradio UI only when the file is run as a script.
    # NOTE(review): the nesting of the layout containers below was
    # reconstructed from a listing whose indentation had been flattened —
    # confirm it matches the intended page layout.
    theme = gr.themes.Default(primary_hue=gr.themes.colors.emerald)

    with gr.Blocks(theme=theme) as demo:
        gr.Markdown("# Stable Diffusion Demo")

        with gr.Row():
            with gr.Column():
                prompt = gr.Text(label="Prompt", placeholder="Enter a prompt here")

                model_id = gr.Dropdown(
                    label="Model ID",
                    choices=MODEL_CHOICES,
                    value="stabilityai/stable-diffusion-3-medium-diffusers",
                )

                with gr.Accordion("Additional Settings", open=False):
                    negative_prompt = gr.Text(label="Negative Prompt", value="")

                    with gr.Row():
                        # precision=0 makes these fields hand integers to the
                        # callback; sizes and counts are not fractional.
                        width = gr.Number(
                            label="Width", value=512, step=64, minimum=64, maximum=2048, precision=0,
                        )
                        height = gr.Number(
                            label="Height", value=512, step=64, minimum=64, maximum=2048, precision=0,
                        )
                        num_images = gr.Number(
                            label="Num Images", value=4, minimum=1, maximum=10, step=1, precision=0,
                        )

                    guidance_scale = gr.Slider(
                        label="Guidance Scale", value=7.5, step=0.5, minimum=0, maximum=10,
                    )
                    num_inference_step = gr.Slider(
                        label="Num Inference Steps", value=50, minimum=1, maximum=100, step=1,
                    )

            with gr.Column():
                output_image = gr.Image(label="Image", type="pil")

        # Order matters: these are passed positionally to ``inference`` and
        # must line up with the rows produced by ``Input.to_list()``.
        inputs = [
            prompt,
            model_id,
            negative_prompt,
            width,
            height,
            guidance_scale,
            num_inference_step,
            num_images,
        ]

        btn = gr.Button("Generate")
        btn.click(
            fn=inference,
            inputs=inputs,
            outputs=output_image,
        )

        gr.Examples(
            examples=EXAMPLES,
            inputs=inputs,
        )

    demo.queue().launch()