Husr committed
Commit: f278e43
Parent(s): d96ae75

Attempt to speed up inference

Files changed (1)
  1. app.py +30 -22
app.py CHANGED
@@ -18,13 +18,17 @@ LORA_PATH = os.environ.get("LORA_PATH", os.path.join("lora", "zit-mystic-xxx.saf
 HF_TOKEN = os.environ.get("HF_TOKEN")
 ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "false").lower() == "true"
 ENABLE_WARMUP = os.environ.get("ENABLE_WARMUP", "false").lower() == "true"
-ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3")
-OFFLOAD_TO_CPU_AFTER_RUN = os.environ.get("OFFLOAD_TO_CPU_AFTER_RUN", "true").lower() == "true"
-ENABLE_AOTI = os.environ.get("ENABLE_AOTI", "false").lower() == "true"
+ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "_flash_3")
+OFFLOAD_TO_CPU_AFTER_RUN = os.environ.get("OFFLOAD_TO_CPU_AFTER_RUN", "false").lower() == "true"
+ENABLE_AOTI = os.environ.get("ENABLE_AOTI", "true").lower() == "true"
 AOTI_REPO = os.environ.get("AOTI_REPO", "zerogpu-aoti/Z-Image")
 AOTI_VARIANT = os.environ.get("AOTI_VARIANT", "fa3")
 DEFAULT_CFG = float(os.environ.get("DEFAULT_CFG", "0.0"))
 
+if torch.cuda.is_available():
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.set_float32_matmul_precision("high")
+
 warnings.filterwarnings("ignore")
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
@@ -165,7 +169,7 @@ def set_lora_scale(pipeline: ZImagePipeline, scale: float) -> None:
 
 
 def load_models() -> Tuple[ZImagePipeline, bool, str | None]:
-    global pipe, lora_loaded, lora_error
+    global pipe, lora_loaded, lora_error, pipe_on_gpu
     if pipe is not None and getattr(pipe, "transformer", None) is not None:
         return pipe, lora_loaded, lora_error
 
@@ -173,26 +177,31 @@ def load_models() -> Tuple[ZImagePipeline, bool, str | None]:
     hf_kwargs = {"use_auth_token": use_auth_token} if use_auth_token else {}
     print(f"Loading Z-Image from {MODEL_PATH}...")
 
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is not available. This app requires a GPU.")
+
     if not os.path.exists(MODEL_PATH):
         vae = AutoencoderKL.from_pretrained(
             MODEL_PATH,
             subfolder="vae",
             torch_dtype=torch.bfloat16,
             **hf_kwargs,
-        )
+        ).to("cuda", torch.bfloat16)
         text_encoder = AutoModelForCausalLM.from_pretrained(
             MODEL_PATH,
             subfolder="text_encoder",
             torch_dtype=torch.bfloat16,
             **hf_kwargs,
-        ).eval()
+        ).to("cuda", torch.bfloat16).eval()
         tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, subfolder="tokenizer", **hf_kwargs)
     else:
-        vae = AutoencoderKL.from_pretrained(os.path.join(MODEL_PATH, "vae"), torch_dtype=torch.bfloat16)
+        vae = AutoencoderKL.from_pretrained(os.path.join(MODEL_PATH, "vae"), torch_dtype=torch.bfloat16).to(
+            "cuda", torch.bfloat16
+        )
         text_encoder = AutoModelForCausalLM.from_pretrained(
             os.path.join(MODEL_PATH, "text_encoder"),
             torch_dtype=torch.bfloat16,
-        ).eval()
+        ).to("cuda", torch.bfloat16).eval()
         tokenizer = AutoTokenizer.from_pretrained(os.path.join(MODEL_PATH, "tokenizer"))
 
     tokenizer.padding_side = "left"
@@ -215,7 +224,8 @@ def load_models() -> Tuple[ZImagePipeline, bool, str | None]:
     applied_backend = set_attention_backend_safe(transformer, ATTENTION_BACKEND)
     print(f"Attention backend: {applied_backend}")
 
-    pipeline.transformer = transformer
+    pipeline.transformer = transformer.to("cuda", torch.bfloat16)
+    pipeline.to("cuda", torch.bfloat16)
 
     loaded, error = attach_lora(pipeline)
     lora_loaded, lora_error = loaded, error
@@ -225,6 +235,7 @@ def load_models() -> Tuple[ZImagePipeline, bool, str | None]:
     print(f"LoRA loaded: {lora_loaded} ({LORA_PATH})")
 
     pipe = pipeline
+    pipe_on_gpu = True
     return pipe, lora_loaded, lora_error
 
 
@@ -241,7 +252,7 @@ def ensure_models_loaded() -> Tuple[ZImagePipeline, bool, str | None]:
 
 
 def ensure_on_gpu() -> None:
-    global pipe_on_gpu, aoti_loaded
+    global pipe_on_gpu
     if pipe is None:
         raise gr.Error("Model not loaded.")
     if getattr(pipe, "transformer", None) is None:
@@ -250,24 +261,12 @@ def ensure_on_gpu() -> None:
         raise gr.Error("CUDA is not available. This Space requires a GPU.")
     if pipe_on_gpu:
         return
-
-    print("Moving model to GPU...")
-    pipe.to("cuda", torch.bfloat16)
     pipe_on_gpu = True
 
     if ENABLE_COMPILE:
         print("Compiling transformer (torch.compile)...")
         pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=False)
 
-    if ENABLE_AOTI and not aoti_loaded:
-        try:
-            pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
-            spaces.aoti_blocks_load(pipe.transformer.layers, AOTI_REPO, variant=AOTI_VARIANT)
-            aoti_loaded = True
-            print(f"AoTI loaded: {AOTI_REPO} (variant={AOTI_VARIANT})")
-        except Exception as exc:  # noqa: BLE001
-            print(f"AoTI load failed (continuing without AoTI): {exc}")
-
 
 def offload_to_cpu() -> None:
     global pipe_on_gpu
@@ -388,8 +387,17 @@ def warmup_model(pipeline: ZImagePipeline, resolutions: List[str]) -> None:
 
 
 def init_app() -> None:
+    global aoti_loaded
     try:
         ensure_models_loaded()
+        if ENABLE_AOTI and not aoti_loaded and pipe is not None and getattr(pipe, "transformer", None) is not None:
+            try:
+                pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
+                spaces.aoti_blocks_load(pipe.transformer.layers, AOTI_REPO, variant=AOTI_VARIANT)
+                aoti_loaded = True
+                print(f"AoTI loaded: {AOTI_REPO} (variant={AOTI_VARIANT})")
+            except Exception as exc:  # noqa: BLE001
+                print(f"AoTI load failed (continuing without AoTI): {exc}")
        if ENABLE_WARMUP and pipe is not None:
            ensure_on_gpu()
            try:
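
For reference, the speedup pattern this commit applies (TF32 matmuls, keeping weights resident on the GPU in bfloat16, optional torch.compile) can be sketched in isolation. This is a minimal sketch using only the torch calls visible in the diff above; the apply_speedups helper name and the generic torch.nn.Module argument are illustrative and not part of app.py.

import torch

def apply_speedups(transformer: torch.nn.Module, enable_compile: bool = False) -> torch.nn.Module:
    """Apply the inference speedups used in the diff to any transformer module."""
    if torch.cuda.is_available():
        # TF32 matmuls trade a little precision for throughput on Ampere+ GPUs.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.set_float32_matmul_precision("high")
        # Keep weights resident on the GPU in bfloat16 instead of moving them per request.
        transformer = transformer.to("cuda", torch.bfloat16)
    if enable_compile:
        # Same torch.compile mode the diff uses for the Z-Image transformer.
        transformer = torch.compile(transformer, mode="max-autotune-no-cudagraphs", fullgraph=False)
    return transformer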