| | |
| | |
| | |
| |
|
| | import os |
| | import math |
| | import torch |
| | import gradio as gr |
| | from typing import List, Optional |
| | from PIL import Image |
| | from diffusers import ( |
| | DiffusionPipeline, |
| | StableDiffusionPipeline, |
| | AutoPipelineForText2Image, |
| | ) |
| |
|
| | |
# UI label -> Hugging Face model repo id. The label (dict key) is what the
# Dropdown shows; the value is what gets passed to from_pretrained().
MODEL_CHOICES = {
    "Stable Diffusion 1.5 (runwayml/stable-diffusion-v1-5)": "runwayml/stable-diffusion-v1-5",
    "SDXL Turbo (stabilityai/sdxl-turbo)": "stabilityai/sdxl-turbo",
}

# Dropdown's initial selection; must be a key of MODEL_CHOICES.
DEFAULT_MODEL_LABEL = "Stable Diffusion 1.5 (runwayml/stable-diffusion-v1-5)"

# Initial state of the "disable safety checker" checkbox in the UI.
DISABLE_SAFETY_DEFAULT = True
| |
|
| | |
def get_device() -> str:
    """Pick the best available torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return "cuda"
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"
| |
|
def nearest_multiple_of_8(x: int) -> int:
    """Snap *x* to the nearest multiple of 8, with a floor of 64.

    Diffusion UNets require spatial dims divisible by 8; 64 is the
    smallest size accepted here.
    """
    return 64 if x < 64 else int(round(x / 8) * 8)
| |
|
| | |
| | _PIPE_CACHE = {} |
| |
|
def load_pipe(model_id: str, device: str, fp16: bool) -> DiffusionPipeline:
    """Load a diffusers pipeline for *model_id*, caching it in _PIPE_CACHE.

    Args:
        model_id: Hugging Face repo id (a value of MODEL_CHOICES).
        device: "cuda", "mps", or "cpu" (see get_device()).
        fp16: request half precision; only honored when device is "cuda".

    Returns:
        The pipeline moved to *device*. Subsequent calls with the same
        (model_id, device, fp16) return the cached instance.
    """
    key = (model_id, device, fp16)
    if key in _PIPE_CACHE:
        return _PIPE_CACHE[key]

    # fp16 is only reliable on CUDA; MPS/CPU fall back to fp32.
    dtype = torch.float16 if (fp16 and device == "cuda") else torch.float32

    try:
        # BUG FIX: the original referenced `AutoPipelineForTextToImage`, which
        # is not the imported name (`AutoPipelineForText2Image`, L15). The call
        # always raised NameError, was swallowed by the broad except, and every
        # model — including SDXL Turbo — was force-loaded through
        # StableDiffusionPipeline. Use the correct imported class.
        pipe = AutoPipelineForText2Image.from_pretrained(
            model_id,
            torch_dtype=dtype,
            use_safetensors=True,
            trust_remote_code=False,
        )
    except Exception:
        # Fallback for repos the auto-pipeline cannot resolve.
        pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=dtype,
            use_safetensors=True,
        )

    pipe = pipe.to(device)

    if device == "cuda":
        # xformers attention is an optional speed/memory win; silently skip if
        # it isn't installed or the GPU doesn't support it.
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            pass

    _PIPE_CACHE[key] = pipe
    return pipe
| |
|
| | |
def generate(
    prompt: str,
    negative: str,
    model_label: str,
    steps: int,
    guidance: float,
    width: int,
    height: int,
    seed: Optional[int],
    batch_size: int,
    disable_safety: bool,
) -> List[Image.Image]:
    """Run text-to-image generation and return a list of PIL images.

    Args:
        prompt: user prompt; must be non-empty after stripping.
        negative: optional negative prompt ("" disables it).
        model_label: a key of MODEL_CHOICES (Dropdown label).
        steps: inference steps; clamped to 1..6 for SDXL Turbo.
        guidance: CFG scale; forced to 0.0 for SDXL Turbo.
        width, height: requested size, snapped to multiples of 8 (min 64).
        seed: optional integer seed (arrives as a string from the Textbox;
              non-integer input is treated as "no seed").
        batch_size: images per call, clamped to 1..8.
        disable_safety: toggle the pipeline's safety checker for this call.

    Raises:
        gr.Error: if the prompt is empty.
    """
    prompt = (prompt or "").strip()
    if not prompt:
        raise gr.Error("Enter a non-empty prompt.")

    model_id = MODEL_CHOICES[model_label]
    device = get_device()

    # SDXL Turbo is distilled for very few steps and no CFG.
    is_turbo = "sdxl-turbo" in model_id.lower()
    if is_turbo:
        steps = max(1, min(steps, 6))
        guidance = 0.0

    width = nearest_multiple_of_8(width)
    height = nearest_multiple_of_8(height)
    batch_size = max(1, min(batch_size, 8))

    pipe = load_pipe(model_id, device, fp16=(device == "cuda"))

    # BUG FIX: the pipeline is cached across calls, but the original code only
    # ever assigned None (the else-branch re-assigned the already-current
    # value), so once disabled the safety checker could never be re-enabled.
    # Stash the original checker on first sight and restore it on demand.
    if hasattr(pipe, "safety_checker"):
        if not hasattr(pipe, "_orig_safety_checker"):
            pipe._orig_safety_checker = pipe.safety_checker
        pipe.safety_checker = None if disable_safety else pipe._orig_safety_checker

    # Seed arrives from a gr.Textbox, i.e. as a string; silently ignore
    # anything that is not a valid integer (matches original behavior).
    generator = None
    if seed is not None and seed != "":
        try:
            seed = int(seed)
        except ValueError:
            seed = None
    if seed is not None:
        # torch.Generator on "mps" is unreliable; CPU generators are used for
        # both MPS and CPU (the original's two identical branches, merged).
        gen_device = "cuda" if device == "cuda" else "cpu"
        generator = torch.Generator(device=gen_device).manual_seed(seed)

    prompts = [prompt] * batch_size
    negative_prompts = [negative] * batch_size if negative else None

    # autocast is only enabled on CUDA; on other devices it is a no-op.
    with torch.autocast("cuda", enabled=(device == "cuda")):
        out = pipe(
            prompt=prompts,
            negative_prompt=negative_prompts,
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
            width=int(width),
            height=int(height),
            generator=generator,
        )

    return out.images
| |
|
| | |
# --- UI layout and event wiring -------------------------------------------
# Built once at import time; `demo` is launched in the __main__ guard below.
with gr.Blocks(css="footer {visibility: hidden}") as demo:
    gr.Markdown(
        """
        # Text-to-Image (Diffusers)
        - **Models:** SD 1.5 and SDXL Turbo
        - **Tip:** SD 1.5 = better detail on CPU; Turbo = very fast on GPU, fewer steps.
        """
    )

    with gr.Row():
        model_dd = gr.Dropdown(
            label="Model",
            choices=list(MODEL_CHOICES.keys()),
            value=DEFAULT_MODEL_LABEL,
        )
        steps = gr.Slider(1, 75, value=30, step=1, label="Steps")
        guidance = gr.Slider(0.0, 15.0, value=7.5, step=0.1, label="Guidance (CFG)")

    with gr.Row():
        # step=8 keeps slider values aligned with the UNet's multiple-of-8
        # requirement (generate() re-snaps defensively anyway).
        width = gr.Slider(256, 1024, value=768, step=8, label="Width (multiple of 8)")
        height = gr.Slider(256, 1024, value=768, step=8, label="Height (multiple of 8)")
        batch_size = gr.Slider(1, 4, value=1, step=1, label="Batch size")

    prompt = gr.Textbox(label="Prompt", lines=2, placeholder="a cozy cabin at twilight beside a lake, cinematic lighting")
    negative = gr.Textbox(label="Negative Prompt", lines=1, placeholder="blurry, low quality, distorted")
    with gr.Row():
        # Textbox (not Number) so an empty value is representable; generate()
        # parses it back to int.
        seed = gr.Textbox(label="Seed (optional integer)", value="")
        disable_safety = gr.Checkbox(label="Disable safety checker (you are responsible)", value=DISABLE_SAFETY_DEFAULT)

    run_btn = gr.Button("Generate", variant="primary")
    gallery = gr.Gallery(label="Results", columns=2, height=512, preview=True)

    def _on_change_model(label):
        # Reset steps/guidance to sensible defaults for the selected model:
        # Turbo wants few steps and CFG 0; SD 1.5 wants ~30 steps, CFG 7.5.
        # NOTE(review): matches on the "Turbo" substring of the Dropdown
        # label, so MODEL_CHOICES keys must keep that word — confirm if
        # labels are ever renamed.
        if "Turbo" in label:
            return gr.update(value=4), gr.update(value=0.0)
        else:
            return gr.update(value=30), gr.update(value=7.5)

    model_dd.change(_on_change_model, inputs=model_dd, outputs=[steps, guidance])

    run_btn.click(
        fn=generate,
        inputs=[prompt, negative, model_dd, steps, guidance, width, height, seed, batch_size, disable_safety],
        outputs=[gallery],
        api_name="generate",
        scroll_to_output=True,
        concurrency_limit=2,
    )
| |
|
if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable in containers;
    # the PORT env var overrides the default Gradio port (7860).
    port = int(os.environ.get("PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=port, debug=True)
| |
|