|
|
import spaces |
|
|
import torch |
|
|
import gradio as gr |
|
|
from diffusers import StableDiffusionPipeline |
|
|
from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig |
|
|
import os |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face Hub repository id of the Stable Diffusion checkpoint to load.
# NOTE(review): confirm this repository id still resolves on the Hub.
MODEL_ID = "runwayml/stable-diffusion-v1-5"

# Load the pipeline in float16 with the safety checker disabled.
pipe = StableDiffusionPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    safety_checker=None,
    requires_safety_checker=False,
)
pipe.to("cuda")

# Prime the scheduler with a 50-step schedule; generate_image() passes its own
# num_inference_steps on every call.
pipe.scheduler.set_timesteps(50)

print("Starting AoT Compilation...")
|
|
|
|
|
@spaces.GPU(duration=1500)
def compile_optimized_unet():
    """Quantize, export and ahead-of-time compile ``pipe.unet``.

    Steps:
        1. Try FP8 dynamic-activation / FP8-weight quantization of the UNet.
           If that raises, the error is printed and the function carries on
           with the UNet as it is.
        2. Call the UNet once on random example inputs inside
           ``spaces.aoti_capture`` to record the call's args and kwargs.
        3. Export the UNet with ``torch.export.export`` using that call.
        4. Compile the exported program with ``spaces.aoti_compile``.

    Returns:
        Whatever ``spaces.aoti_compile`` returns for the exported UNet.
    """
    # 1. Best-effort FP8 quantization; a failure here must not abort startup.
    try:
        quantize_(pipe.unet, Float8DynamicActivationFloat8WeightConfig())
        print("✅ Applied FP8 quantization to UNet.")
    except Exception as e:
        print(f"⚠️ FP8 Quantization failed (may require specific hardware/libraries): {e}")

    # 2. Build example inputs and capture one UNet call.
    # Presumably bsz=2 matches classifier-free guidance for a single image and
    # 4x64x64 / 77x768 match 512x512 latents and the text-encoder output for
    # this checkpoint -- TODO confirm.
    # NOTE(review): the export is driven by this exact call; verify it matches
    # how StableDiffusionPipeline invokes the UNet (timestep shape, keyword
    # arguments) so the compiled module accepts the pipeline's real inputs.
    bsz = 2
    latent_model_input = torch.randn(bsz, 4, 64, 64, device="cuda", dtype=torch.float16)
    t = torch.randint(0, 1000, (bsz,), device="cuda")
    encoder_hidden_states = torch.randn(bsz, 77, 768, device="cuda", dtype=torch.float16)

    with spaces.aoti_capture(pipe.unet) as call:
        pipe.unet(latent_model_input, t, encoder_hidden_states)

    # 3. Export the model
    exported = torch.export.export(
        pipe.unet,
        args=call.args,
        kwargs=call.kwargs,
    )

    # 4. Compile the exported model using AoT
    return spaces.aoti_compile(exported)
|
|
|
|
|
# Execute compilation during startup
compiled_unet = compile_optimized_unet()

# 5. Apply compiled model to the pipeline's UNet component
spaces.aoti_apply(compiled_unet, pipe.unet)

print("✅ AoT Compilation completed successfully.")
|
|
|
|
|
# --- 2. Inference Function (Running on GPU) ---

@spaces.GPU(duration=60)  # Standard duration for image generation
def generate_image(
    prompt: str,
    negative_prompt: str,
    steps: int,
    seed: int = -1,
):
    """Generate images for ``prompt`` with the module-level pipeline.

    Args:
        prompt: Text prompt; must be non-empty.
        negative_prompt: Text passed to the pipeline as ``negative_prompt``.
        steps: Number of inference steps; coerced to ``int``.
        seed: Seed for a CUDA ``torch.Generator``. ``-1`` (the default) or
            ``None`` leaves the generator unset, so the pipeline picks its
            own randomness. Float values are truncated to ``int``.

    Returns:
        The ``images`` attribute of the pipeline output.

    Raises:
        gr.Error: If ``prompt`` is empty.
    """
    if not prompt:
        raise gr.Error("Prompt cannot be empty.")

    # The UI's number field may deliver a float (or None when cleared), while
    # torch.Generator.manual_seed only accepts an int, so normalize first.
    generator = None
    if seed is not None and int(seed) != -1:
        generator = torch.Generator(device="cuda").manual_seed(int(seed))

    steps = int(steps)

    # Run inference using the optimized pipeline
    result = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=steps,
        guidance_scale=7.5,
        generator=generator,
    ).images

    return result
|
|
|
|
|
# --- 3. Gradio Interface ---

with gr.Blocks(title="Optimized Vision Model (AoT Powered)") as demo:
    gr.HTML(
        """
        <div style="text-align: center; max-width: 800px; margin: 0 auto;">
            <h1><a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">Built with anycoder</a></h1>
            <h2>High-Performance Creative VLM Simulator (AoT Optimized)</h2>
            <p>This demo simulates a creative Vision Language Model using AoT-compiled Stable Diffusion for lightning-fast image generation.</p>
        </div>
        """
    )

    with gr.Row():
        # Left column: prompts, generation settings and the trigger button.
        with gr.Column(scale=1):
            prompt = gr.Textbox(
                label="Prompt (Input to VLM)",
                placeholder="A futuristic city painted by Van Gogh, highly detailed.",
                lines=3,
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt (What to avoid)",
                placeholder="Blurry, bad quality, low resolution",
                lines=2,
            )

            with gr.Accordion("Generation Settings", open=True):
                steps = gr.Slider(
                    minimum=10,
                    maximum=50,
                    step=1,
                    value=30,
                    label="Inference Steps (Higher = Slower/Better)",
                )
                # precision=0 makes Gradio hand the seed over as an int, which
                # torch.Generator.manual_seed requires.
                seed = gr.Number(
                    value=-1,
                    precision=0,
                    label="Seed (-1 for random)",
                )

            generate_btn = gr.Button("Generate Image (AoT Fast!)", variant="primary")

        # Right column: generated images.
        with gr.Column(scale=2):
            output_gallery = gr.Gallery(
                label="Creative VLM Output",
                show_label=True,
                height=512,
                columns=2,
                object_fit="contain",
            )

    generate_btn.click(
        fn=generate_image,
        inputs=[prompt, negative_prompt, steps, seed],
        outputs=output_gallery,
    )

    # Examples fill in prompt, negative prompt and steps only; the seed field
    # is not among the example inputs.
    gr.Examples(
        examples=[
            ["A majestic wolf standing on a snowy mountain peak, cinematic lighting", "ugly, deformed, low detail", 30],
            ["Cyberpunk cat sitting in a neon-lit alley, 8k, digital art", "human, blurry, messy background", 40],
            ["A vintage photograph of a space shuttle launching from a tropical island", "modern, cartoon, painting", 25],
        ],
        inputs=[prompt, negative_prompt, steps],
        outputs=output_gallery,
        fn=generate_image,
        cache_examples=False,
    )

demo.queue()
demo.launch()