# SPRIGHT-T2I / app.py
# Hugging Face Space page header (scraped): sayakpaul's picture
# Author: sayakpaul (HF staff)
# Commit message: modify space
# Commit: 7c36275
import os
import random
import gradio as gr
import numpy as np
import torch
from diffusers import DiffusionPipeline, UNet2DConditionModel
import spaces
import uuid
# Markdown rendered at the top of the demo page.
DESCRIPTION = """# SPRIGHT T2I
#### [SPRIGHT T2I](https://spright.github.io/) is a framework to improve the spatial consistency of text-to-image models WITHOUT compromising their fidelity aspects.
"""

# Upper bound for seeds: the largest value representable as a signed 32-bit int.
MAX_SEED = np.iinfo(np.int32).max

# Runtime configuration, overridable through environment variables.
# CACHE_EXAMPLES is enabled only when the variable is exactly "1" (the default).
CACHE_EXAMPLES = os.environ.get("CACHE_EXAMPLES", "1") == "1"
# Maximum allowed value of the width/height sliders, in pixels.
MAX_IMAGE_SIZE = int(os.environ.get("MAX_IMAGE_SIZE", "768"))
# Hugging Face access token; None when HF_TOKEN is not set.
TOKEN = os.environ.get("HF_TOKEN")
# Hub repository that hosts the SPRIGHT-T2I checkpoint.
pipe_id = "SPRIGHT-T2I/spright-t2i-v1"

# Load the weights stored in the repo's "unet_ema" subfolder and plug them into
# the pipeline in place of its default UNet. The same access token is passed to
# both loads so they authenticate consistently against the repository.
unet = UNet2DConditionModel.from_pretrained(
    pipe_id,
    subfolder="unet_ema",
    torch_dtype=torch.float16,
    token=TOKEN,
)

# Build the full pipeline in half precision and move it to the GPU.
# NOTE: this runs at import time, so importing the module requires CUDA.
pipe = DiffusionPipeline.from_pretrained(
    pipe_id,
    unet=unet,
    torch_dtype=torch.float16,
    use_safetensors=True,
    token=TOKEN,
).to("cuda")
def save_image(img):
    """Write ``img`` to a uniquely named PNG file and return that file name.

    The name is a random UUID4 followed by ``.png``; it is relative, so the
    file is created in the current working directory. ``img`` only needs to
    provide a ``save(path)`` method.
    """
    filename = f"{uuid.uuid4()}.png"
    img.save(filename)
    return filename
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return the seed to use for a generation run.

    When ``randomize_seed`` is truthy, a new seed is drawn uniformly from
    ``[0, MAX_SEED]``; otherwise ``seed`` is returned unchanged.
    """
    return random.randint(0, MAX_SEED) if randomize_seed else seed
@spaces.GPU
def generate(
    prompt: str,
    seed: int = 0,
    width: int = 768,
    height: int = 768,
    guidance_scale: float = 7.5,
    num_inference_steps: int = 50,
    randomize_seed: bool = False,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image for ``prompt`` and save it to disk.

    Args:
        prompt: Text prompt passed to the pipeline.
        seed: Seed for the random generator; ignored when ``randomize_seed``
            is true.
        width: Output width passed to the pipeline.
        height: Output height passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        num_inference_steps: Number of denoising steps passed to the pipeline.
        randomize_seed: When true, replace ``seed`` with a random value.
        progress: Gradio progress tracker; with ``track_tqdm=True`` it mirrors
            tqdm progress bars in the UI. Not used directly in the body.

    Returns:
        A tuple ``([image_path], seed)``: a one-element list with the saved
        image's path (the shape the gallery output expects) and the seed that
        was actually used, so the UI can display it.
    """
    seed = randomize_seed_fn(seed, randomize_seed)
    # torch.Generator() with no device argument is a CPU generator.
    generator = torch.Generator().manual_seed(seed)
    image = pipe(
        prompt=prompt,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
    ).images[0]
    image_path = save_image(image)
    print(image_path)
    return [image_path], seed
# Example prompts shown below the UI; each describes a spatial relationship
# between two objects.
examples = [
    "A cat next to a suitcase",
    "A candle on the left of a mouse",
    "A bag on the right of a dog",
    "A mouse on the top of a bowl",
]
# Build the Gradio UI: prompt box + run button, result gallery, advanced
# options, example prompts, and the event wiring that calls `generate`.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    # Only shown when the SHOW_DUPLICATE_BUTTON environment variable is "1".
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
        visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
    )
    with gr.Group():
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Gallery(label="Result", columns=1, show_label=False)
    with gr.Accordion("Advanced options", open=False):
        seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=MAX_SEED,
            step=1,
            value=0,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
        # The size sliders are hidden, so their default values are what every
        # interactive run sends to `generate`. The default must stay within the
        # slider's own maximum and match `generate`'s 768x768 defaults, so it
        # is clamped to MAX_IMAGE_SIZE instead of being hard-coded to 1024.
        with gr.Row(visible=False):
            width = gr.Slider(
                label="Width",
                minimum=256,
                maximum=MAX_IMAGE_SIZE,
                step=32,
                value=min(768, MAX_IMAGE_SIZE),
            )
            height = gr.Slider(
                label="Height",
                minimum=256,
                maximum=MAX_IMAGE_SIZE,
                step=32,
                value=min(768, MAX_IMAGE_SIZE),
            )
        with gr.Row():
            guidance_scale = gr.Slider(
                label="Guidance scale",
                minimum=1,
                maximum=20,
                step=0.1,
                value=7.5,
            )
            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=10,
                maximum=100,
                step=1,
                value=50,
            )
    # Examples pass only the prompt, so `generate` runs with its own defaults
    # for every other argument.
    gr.Examples(
        examples=examples,
        inputs=prompt,
        outputs=[result, seed],
        fn=generate,
        cache_examples=CACHE_EXAMPLES,
    )
    # Pressing Enter in the prompt box and clicking "Run" trigger the same call.
    gr.on(
        triggers=[
            prompt.submit,
            run_button.click,
        ],
        fn=generate,
        inputs=[prompt, seed, width, height, guidance_scale, num_inference_steps, randomize_seed],
        outputs=[result, seed],
        api_name="run",
    )
if __name__ == "__main__":
    # Enable request queuing with at most 20 queued requests, then start the app.
    demo.queue(max_size=20)
    demo.launch()