# EcoDiff / app.py
from dataclasses import dataclass

import gradio as gr
import spaces
import torch
from diffusers import FluxPipeline, StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download

device = "cuda" if torch.cuda.is_available() else "cpu"
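
# Demo arguments; these presumably mirror the CLI flags of the EcoDiff pruning
# scripts. Only `prompt`, `seed`, and `num_intervention_steps` are set from
# the UI.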
@dataclass
class GradioArgs:
seed: list = None
prompt: str = None
mix_precision: str = "bf16"
num_intervention_steps: int = 50
model: str = "sdxl"
binary: bool = False
masking: str = "binary"
scope: str = "global"
ratio: list = None
width: int = None
height: int = None
epsilon: float = 0.0
lambda_threshold: float = 0.001
def __post_init__(self):
if self.seed is None:
self.seed = [44]
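
# Build the pruned pipeline and a fresh copy of the original model for
# side-by-side comparison. Both stay on CPU here and are moved to the GPU
# inside the @spaces.GPU-decorated call. (`args` is currently unused.)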
def binary_mask_eval(args, model):
model = model.lower()
# load sdxl model
if model == "sdxl":
pruned_pipe = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cpu")
        # The checkpoint is a fully pickled nn.Module; torch>=2.6 defaults to
        # weights_only=True, which would reject it, so opt out explicitly.
        pruned_pipe.unet = torch.load(
            hf_hub_download(
                "zhangyang-0123/EcoDiffPrunedModels", "model/sdxl/sdxl.pkl"
            ),
            map_location="cpu",
            weights_only=False,
        )
elif model == "flux":
pruned_pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cpu")
        # Same caveat as above: a fully pickled module checkpoint.
        pruned_pipe.transformer = torch.load(
            hf_hub_download(
                "zhangyang-0123/EcoDiffPrunedModels", "model/flux/flux.pkl"
            ),
            map_location="cpu",
            weights_only=False,
        )
    else:
        raise ValueError(f"Unsupported model: {model}")
torch.cuda.empty_cache()
# reload the original model
if model == "sdxl":
pipe = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16
).to("cpu")
elif model == "flux":
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cpu")
torch.cuda.empty_cache()
    print("Pruned and original pipelines loaded")
return pipe, pruned_pipe
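
# Illustrative helper, not part of the original demo: compare parameter counts
# to confirm the pruned denoiser is smaller, e.g.
#   param_count(pruned_pipe.unet) / param_count(pipe.unet)
def param_count(module: torch.nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())


# `spaces.GPU` requests a ZeroGPU worker for the duration of the call, so the
# pipelines live on CUDA only while generating.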
@spaces.GPU
def generate_images(prompt, seed, steps, pipe, pruned_pipe):
pipe.to("cuda")
pruned_pipe.to("cuda")
# Run the model and return images directly
    # Seed a CUDA generator so both pipelines start from identical noise.
    generator = torch.Generator("cuda").manual_seed(seed)
    original_image = pipe(
        prompt=prompt, generator=generator, num_inference_steps=steps
    ).images[0]
torch.cuda.empty_cache()
    generator = torch.Generator("cuda").manual_seed(seed)
    ecodiff_image = pruned_pipe(
        prompt=prompt, generator=generator, num_inference_steps=steps
    ).images[0]
return original_image, ecodiff_image
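
# Gradio callbacks: `on_prune_click` loads both pipelines into session state,
# `on_generate_click` renders the same prompt and seed with each of them.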
def on_prune_click(prompt, seed, steps, model):
args = GradioArgs(prompt=prompt, seed=[seed], num_intervention_steps=steps)
pipe, pruned_pipe = binary_mask_eval(args, model)
    return pipe, pruned_pipe, [("Model Initialized", "green")]


def on_generate_click(prompt, seed, steps, pipe, pruned_pipe):
original_image, ecodiff_image = generate_images(
prompt, seed, steps, pipe, pruned_pipe
)
return original_image, ecodiff_image
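
# Assemble the UI: model picker and initialization on top, prompt/seed/steps
# controls below, and the original vs. pruned outputs side by side.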
def create_demo():
with gr.Blocks() as demo:
gr.Markdown("# Text-to-Image Generation with EcoDiff Pruned Model")
with gr.Row():
gr.Markdown(
"""
            **Note: please initialize the models before generating images; loading may take a while.**
"""
)
with gr.Row():
model_choice = gr.Radio(
choices=["SDXL", "FLUX"], value="SDXL", label="Model", scale=2
)
pruning_ratio = gr.Text(
"20% Pruning Ratio for SDXL, FLUX", label="Pruning Ratio", scale=2
)
status_label = gr.HighlightedText(
label="Model Status", value=[("Model Not Initialized", "red")], scale=1
)
prune_btn = gr.Button(
"Initialize Original and Pruned Models", variant="primary", scale=1
)
with gr.Row():
gr.Markdown(
"""
            **Generate images with the original and the pruned model. This may take up to one minute because the GPU is allocated dynamically.**
            **Note: FLUX is pruned from the step-distilled checkpoint, so use 5 steps (instead of 50) when generating with FLUX.**
"""
)
with gr.Row():
prompt = gr.Textbox(
label="Prompt",
value="A clock tower floating in a sea of clouds",
scale=3,
)
seed = gr.Number(label="Seed", value=44, precision=0, scale=1)
steps = gr.Slider(
label="Number of Steps",
minimum=1,
maximum=100,
value=50,
step=1,
scale=1,
)
generate_btn = gr.Button("Generate Images")
gr.Examples(
examples=[
"A clock tower floating in a sea of clouds",
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
"An astronaut riding a green horse",
"A delicious ceviche cheesecake slice",
"A sprawling cyberpunk metropolis at night, with towering skyscrapers emitting neon lights of every color, holographic billboards advertising alien languages",
],
inputs=[prompt],
)
with gr.Row():
original_output = gr.Image(label="Original Output")
ecodiff_output = gr.Image(label="EcoDiff Output")
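
        # Keep the loaded pipelines in per-session state between callbacks.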
pipe_state = gr.State(None)
pruned_pipe_state = gr.State(None)
prompt.submit(
fn=on_generate_click,
inputs=[prompt, seed, steps, pipe_state, pruned_pipe_state],
outputs=[original_output, ecodiff_output],
)
prune_btn.click(
fn=on_prune_click,
inputs=[prompt, seed, steps, model_choice],
outputs=[pipe_state, pruned_pipe_state, status_label],
)
generate_btn.click(
fn=on_generate_click,
inputs=[prompt, seed, steps, pipe_state, pruned_pipe_state],
outputs=[original_output, ecodiff_output],
)
    return demo


if __name__ == "__main__":
demo = create_demo()
demo.launch(share=True)