Spaces:
Sleeping
Sleeping
import gradio as gr | |
import numpy as np | |
from huggingface_hub import hf_hub_download | |
import spaces # [uncomment to use ZeroGPU] | |
from diffusers import DiffusionPipeline | |
import torch | |
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Base checkpoint for the reference (unpruned) pipeline.
# Replace with the model you would like to use.
model_repo_id = "stabilityai/stable-diffusion-xl-base-1.0"

# float16 on CUDA, float32 otherwise.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Reference pipeline, loaded in the chosen precision and moved to `device`.
pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype).to(device)
# Load the EcoDiff-pruned pipeline: same base checkpoint, with the denoiser
# replaced by the pruned module downloaded from the Hub.
pruned_pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
# SECURITY NOTE: `sdxl.pkl` is a pickled module, and unpickling can execute
# arbitrary code. This is acceptable only because the file comes from the
# authors' own model repo; never point this at an untrusted checkpoint.
# `weights_only=False` is needed because torch >= 2.6 defaults to
# `weights_only=True`, which refuses to unpickle whole modules.
pruned_unet = torch.load(
    hf_hub_download("zhangyang-0123/EcoDiffPrunedModels", "model/sdxl/sdxl.pkl"),
    map_location="cpu",
    weights_only=False,
)
# SD-XL pipelines run their denoiser through `.unet`. The previous assignment to
# `.transformer` only set an unused attribute, so the pruned model never ran.
# The cast makes the pruned module match the precision used for the rest of
# this pipeline.
pruned_pipe.unet = pruned_unet.to(dtype=torch_dtype)
pruned_pipe = pruned_pipe.to(device)
# NOTE(review): neither constant is referenced elsewhere in this file as seen;
# they are possibly leftovers from a Space template.
MAX_IMAGE_SIZE = 1024
# Largest signed 32-bit integer, i.e. 2**31 - 1.
MAX_SEED = np.iinfo("int32").max
def generate_images(prompt, seed, steps, pipe, pruned_pipe):
    """Render *prompt* with two pipelines for a side-by-side comparison.

    Each pipeline gets its own freshly seeded generator, so both runs start
    from the same random state.

    Args:
        prompt: Text prompt passed to both pipelines.
        seed: Random seed. It is cast to ``int`` because UI widgets may deliver
            a float.
        steps: Number of inference steps, also cast to ``int``.
        pipe: Baseline pipeline, called as
            ``pipe(prompt=..., generator=..., num_inference_steps=...)``.
        pruned_pipe: Pruned pipeline, called the same way.

    Returns:
        Tuple ``(original_image, ecodiff_image)``: the first entry of each
        pipeline's ``.images``.
    """
    # NOTE(review): if this Space runs on ZeroGPU, decorate this function with
    # @spaces.GPU (`spaces` is imported at the top of the file).
    seed = int(seed)
    steps = int(steps)

    # Use a CPU generator. The old hard-coded "cuda" generator crashed on hosts
    # without CUDA, even though the module falls back to device="cpu".
    original_image = pipe(
        prompt=prompt,
        generator=torch.Generator("cpu").manual_seed(seed),
        num_inference_steps=steps,
    ).images[0]
    ecodiff_image = pruned_pipe(
        prompt=prompt,
        generator=torch.Generator("cpu").manual_seed(seed),
        num_inference_steps=steps,
    ).images[0]
    return original_image, ecodiff_image
# Sample prompts.
# NOTE(review): not referenced elsewhere in this file as seen; the gr.Examples
# widget in the UI defines its own inline list.
examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]

# Centers the element with id "col-container" and caps its width.
# NOTE(review): no component in this file sets elem_id="col-container", so this
# rule currently matches nothing.
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

# Markdown/HTML banner rendered at the top of the demo. The paper badge label
# previously read "ariXv"; it now reads "arXiv".
header = """
# 🌱 Text-to-Image Generation with EcoDiff Pruned SD-XL (20% Pruning Ratio)
# Under Construction!!!
<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
<a href="https://arxiv.org/abs/2412.02852"><img src="https://img.shields.io/badge/arXiv-Paper-A42C25.svg" alt="arXiv"></a>
<a href="https://huggingface.co/zhangyang-0123/EcoDiffPrunedModels"><img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HuggingFace"></a>
<a href="https://github.com/YaNgZhAnG-V5/EcoDiff"><img src="https://img.shields.io/badge/GitHub-Code-blue.svg?logo=github&" alt="GitHub"></a>
</div>
"""
def _generate_with_loaded_pipes(prompt_text, seed_value, num_steps):
    """Gradio callback comparing the module-level ``pipe`` and ``pruned_pipe``.

    Event ``inputs`` must be Gradio components, so the two pipelines are bound
    here instead of being listed as inputs. The previous wiring also passed
    ``pipe`` twice, so the pruned pipeline was never called.
    """
    return generate_images(prompt_text, seed_value, num_steps, pipe, pruned_pipe)


with gr.Blocks(css=css) as demo:
    gr.Markdown(header)
    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            value="A clock tower floating in a sea of clouds",
            scale=3,
        )
        seed = gr.Number(label="Seed", value=44, precision=0, scale=1)
        steps = gr.Slider(
            label="Number of Steps",
            minimum=1,
            maximum=100,
            value=50,
            step=1,
            scale=1,
        )
    generate_btn = gr.Button("Generate Images")
    gr.Examples(
        examples=[
            "A clock tower floating in a sea of clouds",
            "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
            "An astronaut riding a green horse",
            "A delicious ceviche cheesecake slice",
            "A sprawling cyberpunk metropolis at night, with towering skyscrapers emitting neon lights of every color, holographic billboards advertising alien languages",
        ],
        inputs=[prompt],
    )
    with gr.Row():
        original_output = gr.Image(label="Original Output")
        ecodiff_output = gr.Image(label="EcoDiff Output")
    # Clicking the button or pressing Enter in the prompt box renders both images.
    gr.on(
        triggers=[generate_btn.click, prompt.submit],
        fn=_generate_with_loaded_pipes,
        inputs=[prompt, seed, steps],
        outputs=[original_output, ecodiff_output],
        # Expose the endpoint as "generate_images" rather than under the
        # private helper's name.
        api_name="generate_images",
    )

if __name__ == "__main__":
    demo.launch()