SPRIGHT-T2I / app.py
sayakpaul's picture
sayakpaul HF staff
Update app.py
0bbb523 verified
import os
import random
import gradio as gr
import numpy as np
import torch
from diffusers import DiffusionPipeline
import spaces
import uuid
DESCRIPTION = """# SPRIGHT T2I
[SPRIGHT T2I](https://spright-t2i.github.io/) is a framework to improve the spatial consistency of text-to-image models WITHOUT compromising their fidelity aspects.
"""
MAX_SEED = np.iinfo(np.int32).max
CACHE_EXAMPLES = os.getenv("CACHE_EXAMPLES", "1") == "1"
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "768"))
TOKEN = os.getenv("HF_TOKEN")
pipe_id = "SPRIGHT-T2I/spright-t2i-sd2"
pipe = DiffusionPipeline.from_pretrained(
pipe_id,
torch_dtype=torch.float16,
use_safetensors=True,
token=TOKEN,
).to("cuda")
def save_image(img):
unique_name = str(uuid.uuid4()) + ".png"
img.save(unique_name)
return unique_name
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
if randomize_seed:
seed = random.randint(0, MAX_SEED)
return seed
@spaces.GPU
def generate(
prompt: str,
seed: int = 0,
width: int = 768,
height: int = 768,
guidance_scale: float = 7.5,
num_inference_steps: int = 50,
randomize_seed: bool = False,
progress=gr.Progress(track_tqdm=True),
):
seed = randomize_seed_fn(seed, randomize_seed)
generator = torch.Generator().manual_seed(seed)
image = pipe(
prompt=prompt,
width=width,
height=height,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
generator=generator,
).images[0]
image_path = save_image(image)
print(image_path)
return [image_path], seed
examples = [
"A cat next to a suitcase",
"A candle on the left of a mouse",
"A bag on the right of a dog",
"A mouse on the top of a bowl",
]
with gr.Blocks(css="style.css") as demo:
gr.Markdown(DESCRIPTION)
gr.DuplicateButton(
value="Duplicate Space for private use",
elem_id="duplicate-button",
visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
)
with gr.Group():
with gr.Row():
prompt = gr.Text(
label="Prompt",
show_label=False,
max_lines=1,
placeholder="Enter your prompt",
container=False,
)
run_button = gr.Button("Run", scale=0)
result = gr.Gallery(label="Result", columns=1, show_label=False)
with gr.Accordion("Advanced options", open=False):
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=0,
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
with gr.Row(visible=False):
width = gr.Slider(
label="Width",
minimum=256,
maximum=MAX_IMAGE_SIZE,
step=32,
value=1024,
)
height = gr.Slider(
label="Height",
minimum=256,
maximum=MAX_IMAGE_SIZE,
step=32,
value=1024,
)
with gr.Row():
guidance_scale = gr.Slider(
label="Guidance scale",
minimum=1,
maximum=20,
step=0.1,
value=7.5,
)
num_inference_steps = gr.Slider(
label="Number of inference steps",
minimum=10,
maximum=100,
step=1,
value=50,
)
gr.Examples(
examples=examples,
inputs=prompt,
outputs=[result, seed],
fn=generate,
cache_examples=CACHE_EXAMPLES,
)
gr.on(
triggers=[
prompt.submit,
run_button.click,
],
fn=generate,
inputs=[prompt, seed, width, height, guidance_scale, num_inference_steps, randomize_seed],
outputs=[result, seed],
api_name="run",
)
if __name__ == "__main__":
demo.queue(max_size=20).launch()