terryyz committed
Commit c148ec9 · verified · 1 Parent(s): 1af71b3

Upload app.py with huggingface_hub
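For reference, a minimal sketch of the kind of call that typically produces a commit like this via huggingface_hub (the Space id below is a placeholder, not taken from this page):

from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="app.py",
    path_in_repo="app.py",
    repo_id="<user>/<space-name>",  # placeholder Space id
    repo_type="space",
    commit_message="Upload app.py with huggingface_hub",
)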

Files changed (1)
app.py +158 -0
app.py ADDED
@@ -0,0 +1,158 @@
+ import spaces
+ import torch
+ import gradio as gr
+ from diffusers import StableDiffusionPipeline
+ from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig
+
+ # --- 1. Model Loading and Optimization (AoT Compilation) ---
+
+ # Choose a Stable Diffusion model
+ MODEL_ID = "runwayml/stable-diffusion-v1-5"
+
+ # Initialize the pipeline; disable the safety checker for faster compilation and inference.
+ # Use torch.float16 for efficiency on CUDA hardware.
+ pipe = StableDiffusionPipeline.from_pretrained(
+     MODEL_ID,
+     torch_dtype=torch.float16,
+     safety_checker=None,
+     requires_safety_checker=False
+ )
+ pipe.to('cuda')
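+ # On ZeroGPU Spaces, calling .to('cuda') at import time is the expected pattern:
+ # the GPU is actually attached when a @spaces.GPU-decorated function runs.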
+ pipe.scheduler.set_timesteps(50)  # Pre-set timesteps for consistent performance testing
+
+ print("Starting AoT Compilation...")
+
+ @spaces.GPU(duration=1500)  # Reserve maximum time for startup compilation
+ def compile_optimized_unet():
+     # 1. Apply FP8 quantization (optional; requires H200/H100-class hardware for maximum benefit)
+     try:
+         quantize_(pipe.unet, Float8DynamicActivationFloat8WeightConfig())
+         print("✅ Applied FP8 quantization to UNet.")
+     except Exception as e:
+         print(f"⚠️ FP8 quantization failed (may require specific hardware/libraries): {e}")
+
+     # 2. Define and capture example inputs for the UNet (the core engine)
+     # Standard Stable Diffusion UNet inputs (batch_size=2 for classifier-free guidance)
+     bsz = 2
+     latent_model_input = torch.randn(bsz, 4, 64, 64, device="cuda", dtype=torch.float16)
+     t = torch.randint(0, 1000, (bsz,), device="cuda")
+     encoder_hidden_states = torch.randn(bsz, 77, 768, device="cuda", dtype=torch.float16)
+
+     with spaces.aoti_capture(pipe.unet) as call:
+         pipe.unet(latent_model_input, t, encoder_hidden_states)
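+     # aoti_capture intercepts the UNet call above and records its arguments
+     # (exposed as call.args / call.kwargs) so they can feed torch.export below.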
+
+     # 3. Export the model
+     exported = torch.export.export(
+         pipe.unet,
+         args=call.args,
+         kwargs=call.kwargs,
+     )
+
+     # 4. Compile the exported model ahead of time
+     return spaces.aoti_compile(exported)
+
+ # Execute compilation during startup
+ compiled_unet = compile_optimized_unet()
+ # 5. Apply the compiled model to the pipeline's UNet component
+ spaces.aoti_apply(compiled_unet, pipe.unet)
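+ # aoti_apply swaps the compiled artifact into pipe.unet in place, so the
+ # standard pipeline call in generate_image() below runs the optimized forward.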
+
+ print("✅ AoT Compilation completed successfully.")
+
+ # --- 2. Inference Function (Running on GPU) ---
+
+ @spaces.GPU(duration=60)  # Standard duration for image generation
+ def generate_image(
+     prompt: str,
+     negative_prompt: str,
+     steps: int,
+     seed: int
+ ):
+     if not prompt:
+         raise gr.Error("Prompt cannot be empty.")
+
+     generator = torch.Generator(device="cuda").manual_seed(int(seed)) if seed != -1 else None
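+     # A fixed generator makes results reproducible; with seed == -1 the
+     # generator is left unset so each call draws a fresh random seed.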
+
+     steps = int(steps)
+
+     # Run inference using the optimized pipeline
+     result = pipe(
+         prompt=prompt,
+         negative_prompt=negative_prompt,
+         num_inference_steps=steps,
+         guidance_scale=7.5,
+         generator=generator
+     ).images
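+     # .images is a list of PIL.Image objects, which gr.Gallery renders directly.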
+
+     return result
+
+ # --- 3. Gradio Interface ---
+
+ with gr.Blocks(title="Optimized Vision Model (AoT Powered)") as demo:
+     gr.HTML(
+         """
+         <div style="text-align: center; max-width: 800px; margin: 0 auto;">
+             <h1><a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">Built with anycoder</a></h1>
+             <h2>High-Performance Creative VLM Simulator (AoT Optimized)</h2>
+             <p>This demo simulates a creative Vision Language Model using AoT-compiled Stable Diffusion for lightning-fast image generation.</p>
+         </div>
+         """
+     )
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             prompt = gr.Textbox(
+                 label="Prompt (Input to VLM)",
+                 placeholder="A futuristic city painted by Van Gogh, highly detailed.",
+                 lines=3
+             )
+             negative_prompt = gr.Textbox(
+                 label="Negative Prompt (What to avoid)",
+                 placeholder="Blurry, bad quality, low resolution",
+                 lines=2
+             )
+
+             with gr.Accordion("Generation Settings", open=True):
+                 steps = gr.Slider(
+                     minimum=10,
+                     maximum=50,
+                     step=1,
+                     value=30,
+                     label="Inference Steps (Higher = Slower/Better)"
+                 )
+                 seed = gr.Number(
+                     value=-1,
+                     precision=0,  # Keep the seed an integer for torch.Generator
+                     label="Seed (-1 for random)"
+                 )
+
+             generate_btn = gr.Button("Generate Image (AoT Fast!)", variant="primary")
+
+         with gr.Column(scale=2):
+             output_gallery = gr.Gallery(
+                 label="Creative VLM Output",
+                 show_label=True,
+                 height=512,
+                 columns=2,
+                 object_fit="contain"
+             )
+
+     generate_btn.click(
+         fn=generate_image,
+         inputs=[prompt, negative_prompt, steps, seed],
+         outputs=output_gallery
+     )
+
+     # Example rows carry a seed of -1 so they match generate_image's signature
+     gr.Examples(
+         examples=[
+             ["A majestic wolf standing on a snowy mountain peak, cinematic lighting", "ugly, deformed, low detail", 30, -1],
+             ["Cyberpunk cat sitting in a neon-lit alley, 8k, digital art", "human, blurry, messy background", 40, -1],
+             ["A vintage photograph of a space shuttle launching from a tropical island", "modern, cartoon, painting", 25, -1]
+         ],
+         inputs=[prompt, negative_prompt, steps, seed],
+         outputs=output_gallery,
+         fn=generate_image,
+         cache_examples=False,
+     )
+
+ demo.queue()
+ demo.launch()