import spaces
import torch
import gradio as gr
from diffusers import StableDiffusionPipeline
from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig
import os
# --- 1. Model Loading and Optimization (AoT Compilation) ---
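# ZeroGPU AoT workflow used below: capture example UNet inputs, torch.export the UNet,
# compile it ahead of time with spaces.aoti_compile, then patch the compiled artifact
# back into the pipeline with spaces.aoti_apply.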
# Choose a stable diffusion model
MODEL_ID = "runwayml/stable-diffusion-v1-5"
# Initialize pipeline, disable safety checker for faster compilation and inference
# Use torch.float16 for efficiency on CUDA hardware
pipe = StableDiffusionPipeline.from_pretrained(
MODEL_ID,
torch_dtype=torch.float16,
safety_checker=None,
requires_safety_checker=False
)
pipe.to('cuda')
pipe.scheduler.set_timesteps(50) # Set max steps for consistent performance testing
print("Starting AoT Compilation...")
@spaces.GPU(duration=1500) # Reserve maximum time for startup compilation
def compile_optimized_unet():
# 1. Apply FP8 quantization (optional, requires H200/H100 for maximum benefit)
try:
quantize_(pipe.unet, Float8DynamicActivationFloat8WeightConfig())
print("βœ… Applied FP8 quantization to UNet.")
except Exception as e:
print(f"⚠️ FP8 Quantization failed (may require specific hardware/libraries): {e}")
# 2. Define and capture example inputs for the UNet (the core engine)
# Standard Stable Diffusion UNet inputs (batch_size=2 for classifier-free guidance)
bsz = 2
latent_model_input = torch.randn(bsz, 4, 64, 64, device="cuda", dtype=torch.float16)
t = torch.randint(0, 1000, (bsz,), device="cuda")
encoder_hidden_states = torch.randn(bsz, 77, 768, device="cuda", dtype=torch.float16)
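# Shapes match SD v1.5: 64x64 latents decode to 512x512 images (the VAE downsamples by 8x),
# and 77x768 is the CLIP text encoder's sequence length and hidden size.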
with spaces.aoti_capture(pipe.unet) as call:
pipe.unet(latent_model_input, t, encoder_hidden_states)
# 3. Export the model
exported = torch.export.export(
pipe.unet,
args=call.args,
kwargs=call.kwargs,
)
# 4. Compile the exported model using AoT
return spaces.aoti_compile(exported)
# Execute compilation during startup
compiled_unet = compile_optimized_unet()
# 5. Apply compiled model to the pipeline's UNet component
spaces.aoti_apply(compiled_unet, pipe.unet)
print("βœ… AoT Compilation completed successfully.")
# --- 2. Inference Function (Running on GPU) ---
@spaces.GPU(duration=60) # Standard duration for image generation
def generate_image(
prompt: str,
negative_prompt: str,
steps: int,
seed: int
):
if not prompt:
raise gr.Error("Prompt cannot be empty.")
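# A seed of -1 skips the generator entirely, so each call produces a different image.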
seed = int(seed)
steps = int(steps)
generator = torch.Generator(device="cuda").manual_seed(seed) if seed != -1 else None
# Run inference using the optimized pipeline
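# guidance_scale=7.5 is the standard classifier-free guidance strength for SD v1.5.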
result = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
num_inference_steps=steps,
guidance_scale=7.5,
generator=generator
).images
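# .images is a list of PIL.Image objects, which gr.Gallery can render directly.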
return result
# --- 3. Gradio Interface ---
with gr.Blocks(title="Optimized Vision Model (AoT Powered)") as demo:
gr.HTML(
"""
<div style="text-align: center; max-width: 800px; margin: 0 auto;">
<h1><a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">Built with anycoder</a></h1>
<h2>High-Performance Creative VLM Simulator (AoT Optimized)</h2>
<p>This demo simulates a creative Vision Language Model using AoT-compiled Stable Diffusion for lightning-fast image generation.</p>
</div>
"""
)
with gr.Row():
with gr.Column(scale=1):
prompt = gr.Textbox(
label="Prompt (Input to VLM)",
placeholder="A futuristic city painted by Van Gogh, highly detailed.",
lines=3
)
negative_prompt = gr.Textbox(
label="Negative Prompt (What to avoid)",
placeholder="Blurry, bad quality, low resolution",
lines=2
)
with gr.Accordion("Generation Settings", open=True):
steps = gr.Slider(
minimum=10,
maximum=50,
step=1,
value=30,
label="Inference Steps (Higher = Slower/Better)"
)
seed = gr.Number(
value=-1,
label="Seed (-1 for random)"
)
generate_btn = gr.Button("Generate Image (AoT Fast!)", variant="primary")
with gr.Column(scale=2):
output_gallery = gr.Gallery(
label="Creative VLM Output",
show_label=True,
height=512,
columns=2,
object_fit="contain"
)
generate_btn.click(
fn=generate_image,
inputs=[prompt, negative_prompt, steps, seed],
outputs=output_gallery
)
gr.Examples(
examples=[
["A majestic wolf standing on a snowy mountain peak, cinematic lighting", "ugly, deformed, low detail", 30],
["Cyberpunk cat sitting in a neon-lit alley, 8k, digital art", "human, blurry, messy background", 40],
["A vintage photograph of a space shuttle launching from a tropical island", "modern, cartoon, painting", 25]
],
inputs=[prompt, negative_prompt, steps, seed],
outputs=output_gallery,
fn=generate_image,
cache_examples=False,
)
demo.queue()
demo.launch()