Update app.py
app.py CHANGED
@@ -11,8 +11,8 @@ from diffusers.models.attention_processor import AttnProcessor2_0
 from custom_pipeline import FluxWithCFGPipeline
 
 # --- Torch Optimizations ---
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.benchmark = True # Enable cuDNN benchmark for potentially faster convolutions
+# torch.backends.cuda.matmul.allow_tf32 = True
+# torch.backends.cudnn.benchmark = True # Enable cuDNN benchmark for potentially faster convolutions
 
 # --- Constants ---
 MAX_SEED = np.iinfo(np.int32).max
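Note on this hunk: `allow_tf32` trades a sliver of float32 matmul precision for speed on Ampere-class GPUs, and `cudnn.benchmark` only helps when input shapes stay constant between calls; commenting both out restores PyTorch's conservative defaults. If the speedup is wanted only where the hardware actually supports it, a guarded variant is possible. A minimal sketch, not part of this commit (the capability check is the assumption):

    import torch

    # Hypothetical guard: enable TF32 only on Ampere (SM 8.0) or newer,
    # the first generation with TF32 tensor cores.
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True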
@@ -27,39 +27,30 @@ ENHANCE_STEPS = 2 # Fixed steps for the enhance button
 # --- Device and Model Setup ---
 dtype = torch.float16
 device = "cuda" if torch.cuda.is_available() else "cpu"
-pipe = None # Initialize pipe to None
 
-
-
-
-
-pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
-
-pipe.to(device)
+pipe = FluxWithCFGPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
+)
+pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
 
-
-pipe.unet.set_attn_processor(AttnProcessor2_0())
-pipe.vae.set_attn_processor(AttnProcessor2_0()) # VAE might benefit too
+pipe.to(device)
 
-
-
-
-pipe.unload_lora_weights() # Unload after fusing
+# Apply optimizations
+pipe.unet.set_attn_processor(AttnProcessor2_0())
+pipe.vae.set_attn_processor(AttnProcessor2_0()) # VAE might benefit too
 
-
-
-
+pipe.load_lora_weights('hugovntr/flux-schnell-realism', weight_name='schnell-realism_v2.3.safetensors', adapter_name="better")
+pipe.set_adapters(["better"], adapter_weights=[1.0])
+pipe.fuse_lora(adapter_name=["better"], lora_scale=1.0) # Fuse for potential speedup
+pipe.unload_lora_weights() # Unload after fusing
 
-
-
-
-
-except Exception as e:
-    print(e)
+# --- Compilation (Major Speed Optimization) ---
+pipe.vae.decoder = torch.compile(pipe.vae.decoder, mode="reduce-overhead", fullgraph=True)
+pipe.vae.encoder = torch.compile(pipe.vae.encoder, mode="reduce-overhead", fullgraph=True)
 
 
 # --- Inference Function ---
-@spaces.GPU
+@spaces.GPU
 def generate_image(prompt: str, seed: int = 42, width: int = DEFAULT_WIDTH, height: int = DEFAULT_HEIGHT, randomize_seed: bool = False, num_inference_steps: int = DEFAULT_INFERENCE_STEPS, is_enhance: bool = False):
     """Generates an image using the FLUX pipeline with error handling."""
 
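The rewritten setup drops the old `pipe = None` / try-except scaffolding and builds the pipeline unconditionally: load FLUX.1-schnell, swap in the tiny VAE, apply `AttnProcessor2_0`, fuse a realism LoRA, and compile the VAE. One caution: current diffusers releases appear to spell the fuse keyword `adapter_names` (plural), so `fuse_lora(adapter_name=["better"], ...)` as committed may raise a TypeError. A hedged sketch of the same fuse-then-unload pattern with that spelling (repo, file, and adapter names copied from the diff; the keyword spelling is the assumption):

    # Fuse-then-unload: merge the LoRA into the base weights so inference
    # pays no per-step adapter cost, then drop the now-redundant adapter.
    pipe.load_lora_weights(
        "hugovntr/flux-schnell-realism",
        weight_name="schnell-realism_v2.3.safetensors",
        adapter_name="better",
    )
    pipe.set_adapters(["better"], adapter_weights=[1.0])
    pipe.fuse_lora(adapter_names=["better"], lora_scale=1.0)
    pipe.unload_lora_weights()

Also note that `torch.compile(..., mode="reduce-overhead", fullgraph=True)` makes the first request after startup noticeably slow while the VAE encoder and decoder compile; later requests reuse the cached graphs.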
@@ -119,20 +110,6 @@ def generate_image(prompt: str, seed: int = 42, width: int = DEFAULT_WIDTH, heig
         raise gr.Error(f"An error occurred during generation: {e}")
 
 
-# --- Real-time Generation Wrapper ---
-# This function checks the realtime toggle before calling the main generation function.
-# It's triggered by changes in prompt or sliders when realtime is enabled.
-def handle_realtime_update(realtime_enabled: bool, prompt: str, seed: int, width: int, height: int, randomize_seed: bool, num_inference_steps: int):
-    if realtime_enabled and pipe is not None:
-        # Call generate_image directly. Errors within generate_image will be caught and raised as gr.Error.
-        # We don't set is_enhance=True for realtime updates.
-        return generate_image(prompt, seed, width, height, randomize_seed, num_inference_steps, is_enhance=False)
-    else:
-        # If realtime is disabled or pipe failed, don't update the image, seed, or latency.
-        # Return gr.update() for each output component to indicate no change.
-        return gr.update(), gr.update(), gr.update()
-
-
 # --- Example Prompts ---
 examples = [
     "a tiny astronaut hatching from an egg on the moon",
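The deleted wrapper is still worth noting for its no-op-output pattern: returning bare `gr.update()` for each output leaves the corresponding component untouched when realtime mode is off (its `pipe is not None` check is moot now that the pipeline loads unconditionally). A condensed sketch of that gating logic, since whichever handler now feeds the realtime events presumably does something equivalent:

    # gr.update() with no arguments tells Gradio to leave a component as-is.
    def realtime_gate(realtime_enabled, *gen_args):
        if realtime_enabled:
            return generate_image(*gen_args, is_enhance=False)
        return gr.update(), gr.update(), gr.update()  # image, seed, latency unchanged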
@@ -195,9 +172,7 @@ with gr.Blocks() as demo:
         fn=generate_image,
         inputs=[prompt, seed, width, height],
         outputs=[result, seed, latency],
-        show_progress="full",
-        queue=False,
-        concurrency_limit=None,
+        show_progress="full"
     )
 
     generateBtn.click(
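This hunk and the ones below consistently drop the per-event `queue=False` / `concurrency_limit=None` keywords, so these events now run through Gradio's default queue. If one limit for the whole app is wanted rather than per-event tuning, it can be set once at launch; a hypothetical alternative, not part of this commit (the limit value is a placeholder):

    # App-level queue configuration instead of per-event kwargs.
    demo.queue(default_concurrency_limit=4).launch()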
@@ -206,7 +181,6 @@ with gr.Blocks() as demo:
         outputs=[result, seed, latency],
         show_progress="full",
         api_name="RealtimeFlux",
-        queue=False
     )
 
     def update_ui(realtime_enabled):
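`api_name="RealtimeFlux"` keeps the generate button reachable as a named API route even after the queue kwargs are gone. A hedged usage sketch with `gradio_client` (the Space id and argument values are placeholders, and it assumes this click event takes the same six inputs as `prompt.submit` below):

    from gradio_client import Client

    client = Client("user/realtime-flux")  # placeholder Space id
    image, seed, latency = client.predict(
        "a tiny astronaut hatching from an egg on the moon",  # prompt
        42,     # seed
        1024,   # width (placeholder)
        1024,   # height (placeholder)
        False,  # randomize_seed
        4,      # num_inference_steps (placeholder)
        api_name="/RealtimeFlux",
    )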
@@ -222,21 +196,14 @@ with gr.Blocks() as demo:
     realtime.change(
         fn=update_ui,
         inputs=[realtime],
-        outputs=[prompt, generateBtn],
-        queue=False,
-        concurrency_limit=None
+        outputs=[prompt, generateBtn]
     )
 
-    # Removed the intermediate realtime_generation function.
-    # handle_realtime_update checks the realtime toggle internally.
-
     prompt.submit(
         fn=generate_image,
         inputs=[prompt, seed, width, height, randomize_seed, num_inference_steps],
         outputs=[result, seed, latency],
-        show_progress="full",
-        queue=False,
-        concurrency_limit=None
+        show_progress="full"
     )
 
     for component in [prompt, width, height, num_inference_steps]:
@@ -245,9 +212,7 @@ with gr.Blocks() as demo:
             inputs=[realtime, prompt, seed, width, height, randomize_seed, num_inference_steps],
             outputs=[result, seed, latency],
             show_progress="hidden",
-            trigger_mode="always_last",
-            queue=False,
-            concurrency_limit=None
+            trigger_mode="always_last"
         )
 
     # Launch the app
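`trigger_mode="always_last"` is what keeps the per-keystroke bindings usable: while a generation is in flight, intermediate change events are dropped and only the newest pending one runs, so fast typing produces a single render of the final prompt instead of a backlog. For context, the loop these kwargs sit in plausibly reads as follows (the handler name is a placeholder; inputs, outputs, and kwargs are copied from the hunk):

    for component in [prompt, width, height, num_inference_steps]:
        component.change(
            fn=realtime_handler,  # placeholder for the function the file binds here
            inputs=[realtime, prompt, seed, width, height, randomize_seed, num_inference_steps],
            outputs=[result, seed, latency],
            show_progress="hidden",
            trigger_mode="always_last",  # coalesce rapid events to the newest
        )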