# Author: FilipeR
# Commit 16d2be1: "Get it going."
#!/usr/bin/env python
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
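# [The remainder of this MIT License notice is garbled/missing in this copy;
#  refer to the full MIT License text for the remaining conditions.]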
import os
import random
import uuid
import json
import gradio as gr
import numpy as np
from PIL import Image
import spaces
import torch
from diffusers import (
DiffusionPipeline,
StableDiffusionXLPipeline,
EulerAncestralDiscreteScheduler,
)
from typing import Tuple
from gradio_imagefeed import ImageFeed
# Prompt templates. The parentheses are only for line wrapping; there is no
# trailing comma, so these are plain ``str`` values (not 1-tuples).
# DEF_POSITIVE contains a ``{prompt}`` placeholder, presumably for
# ``str.format`` — neither template is used in the code visible here.
DEF_POSITIVE = (
    "photograph of {prompt}, ultra-detailed, life-like, high-resolution, sharp, vibrant colors, photorealistic, Nikon, 30mm, natural skin imperfections"
)
DEF_NEGATIVE = (
    "cartoon, glossy, fake, unnatural, deformed, disfigured, detached limbs, gross, ugly, fat, low resolution, blurry, abstract, dots, weird"
)
# Largest signed 32-bit integer; upper bound for random seeds and the seed slider.
MAX_SEED = np.iinfo(np.int32).max
# Opt-in via env var, and only honoured when CUDA is available.
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "0") == "1"
# Upper bound (pixels) for the width/height controls; default 2048.
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
# Feature flags: enabled only when the env var is exactly "1".
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
# First CUDA device if present, otherwise CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def save_image(img):
    """Save ``img`` under a fresh random filename and return that filename.

    The name is ``<uuid4>.jpeg``, relative to the current working directory;
    the format is chosen by ``img.save`` from the extension. Files are never
    cleaned up here.
    """
    target = f"{uuid.uuid4()}.jpeg"
    img.save(target)
    return target
def seed_fn(seed, reuse=True):
    """Pick the seed for a generation run.

    Returns ``seed`` untouched when ``reuse`` is truthy; otherwise returns a
    random integer in the inclusive range ``[0, MAX_SEED]``.
    """
    if reuse:
        return seed
    return random.randint(0, MAX_SEED)
# Load the RealVisXL SDXL checkpoint in half precision from its fp16 weights.
# NOTE(review): float16 is chosen even when DEVICE is CPU; fp16 inference on
# CPU is slow or unsupported for some ops — confirm a GPU is always present.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "SG161222/RealVisXL_V4.0",
    torch_dtype=torch.float16,
    use_safetensors=True,
    add_watermarker=False,
    variant="fp16",
)
if ENABLE_CPU_OFFLOAD:
    # Model CPU offload places each sub-model on the accelerator on demand.
    # Moving the whole pipeline to the GPU first is redundant and is
    # advised against by the diffusers docs, so skip .to() in this mode.
    pipe.enable_model_cpu_offload()
else:
    pipe = pipe.to(DEVICE)
if USE_TORCH_COMPILE:
    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
@spaces.GPU()
def generate(
    positive: str,
    negative: str,
    seed: int = 0,
    new_seed: bool = True,
    steps: int = 35,
    width: int = 896,
    height: int = 1152,
    images: int = 2,
    guidance: float = 3.0,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate images with the SDXL pipeline and save them as JPEG files.

    Args:
        positive: Prompt text.
        negative: Negative prompt text.
        seed: Seed to use when ``new_seed`` is False.
        new_seed: When True, draw a fresh random seed instead of reusing ``seed``.
        steps: Number of inference steps.
        width: Requested image width in pixels.
        height: Requested image height in pixels.
        images: Number of images to generate for the prompt.
        guidance: Guidance scale passed to the pipeline.
        progress: Gradio progress tracker; ``track_tqdm=True`` lets Gradio
            mirror tqdm progress bars. Not referenced directly.

    Returns:
        ``(filenames, seed)`` — paths of the saved images and the seed that was
        actually used, so the UI can display/reuse it.
    """
    # "New Seed" checked means: do NOT reuse the incoming seed.
    # int() casts guard against UI sliders delivering floats, which
    # torch.Generator.manual_seed and num_images_per_prompt do not accept.
    seed = seed_fn(int(seed), reuse=not new_seed)
    generator = torch.Generator().manual_seed(seed)
    result = pipe(
        prompt=positive,
        negative_prompt=negative,
        guidance_scale=guidance,
        width=int(width),
        height=int(height),
        num_inference_steps=int(steps),
        num_images_per_prompt=int(images),
        generator=generator,
        # NOTE(review): use_resolution_binning is a PixArt pipeline option;
        # StableDiffusionXLPipeline.__call__ may not define it, in which case
        # it is absorbed by **kwargs and has no effect — verify against the
        # installed diffusers version.
        use_resolution_binning=True,
        output_type="pil",
    )
    return [save_image(img) for img in result.images], seed
# Narrow, centred layout for the app.
css = """
.gradio-container{max-width: 600px !important}
h1{text-align:center}
"""
demo = gr.Blocks(css=css, analytics_enabled=False)
with demo:
    with gr.Group():
        with gr.Row():
            positive = gr.Text(
                label="Positive",
                max_lines=4,
                placeholder="Enter Words",
                container=False,
            )
            run = gr.Button("Run")
        gallery = ImageFeed(label="Generated Images")
    # NOTE(review): the "(term:1.3)" weighting syntax below is not parsed by a
    # plain diffusers pipeline; it is tokenized as literal text. Confirm
    # whether prompt weighting (e.g. via Compel) was intended.
    negative = gr.Text(
        label="Negative",
        max_lines=4,
        placeholder="Enter Words",
        value="(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers:1.4), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation",
        visible=True,
    )
    with gr.Row():
        steps = gr.Slider(label="Steps", minimum=10, maximum=60, step=1, value=35)
    with gr.Row():
        images = gr.Slider(label="Images", minimum=1, maximum=6, step=1, value=2)
        seed = gr.Slider(
            label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, visible=True
        )
        new_seed = gr.Checkbox(label="New Seed", value=True)
    with gr.Row(visible=True):
        # Upper bound follows the MAX_IMAGE_SIZE env setting (default 2048);
        # step 16 keeps both dimensions divisible by 8 as SDXL requires.
        width = gr.Slider(
            label="Width",
            minimum=512,
            maximum=MAX_IMAGE_SIZE,
            step=16,
            value=min(896, MAX_IMAGE_SIZE),
        )
        height = gr.Slider(
            label="Height",
            minimum=512,
            maximum=MAX_IMAGE_SIZE,
            step=16,
            value=min(1152, MAX_IMAGE_SIZE),
        )
    guidance = gr.Slider(
        label="Guidance", minimum=0.1, maximum=20.0, step=0.1, value=6
    )
    # Pressing Enter in either prompt box, or clicking Run, triggers generation.
    # The seed slider is also an output so the seed actually used is shown.
    gr.on(
        triggers=[positive.submit, negative.submit, run.click],
        fn=generate,
        inputs=[
            positive,
            negative,
            seed,
            new_seed,
            steps,
            width,
            height,
            images,
            guidance,
        ],
        outputs=[gallery, seed],
        api_name="run",
    )
if __name__ == "__main__":
    # Turn on request queuing, then serve on all network interfaces.
    # Blocks.queue() returns the same Blocks object, so the calls chain.
    demo.queue().launch(server_name="0.0.0.0")