sayakpaul HF Staff commited on
Commit
f61fb8b
·
1 Parent(s): da862e5
Files changed (2) hide show
  1. app.py +1 -1
  2. optimization.py +2 -0
app.py CHANGED
@@ -18,7 +18,7 @@ pipe.transformer.fuse_qkv_projections()
18
  pipe.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())
19
 
20
  @spaces.GPU(duration=1200)
21
- def push_to_hub(repo_id, filename, oauth_token: gr.OAuthToken, progress=gr.Progress(track_tqdm=True)):
22
  if not filename.endswith(".pt2"):
23
  raise NotImplementedError("The filename must end with a `.pt2` extension.")
24
 
 
18
  pipe.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())
19
 
20
  @spaces.GPU(duration=1200)
21
+ def push_to_hub(repo_id, filename, oauth_token: gr.OAuthToken):
22
  if not filename.endswith(".pt2"):
23
  raise NotImplementedError("The filename must end with a `.pt2` extension.")
24
 
optimization.py CHANGED
@@ -41,5 +41,7 @@ def compile_transformer(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.k
41
  print("Export done.")
42
  return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
43
 
 
44
  compiled_transformer = f()
 
45
  return compiled_transformer
 
41
  print("Export done.")
42
  return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
43
 
44
+ print(f"{pipeline.transformer.device=}")
45
  compiled_transformer = f()
46
+ print("Compilation done.")
47
  return compiled_transformer