sayakpaul HF Staff commited on
Commit
6580845
·
1 Parent(s): ebbd677
Files changed (2) hide show
  1. app.py +4 -1
  2. optimization.py +0 -1
app.py CHANGED
@@ -12,8 +12,11 @@ dtype = torch.bfloat16
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
 
14
  # Load the model pipeline
15
- pipe = DiffusionPipeline.from_pretrained("black-forest-labs/Flux.1-Dev", torch_dtype=dtype).to(device)
 
 
16
 
 
17
  def push_to_hub(repo_id, filename, oauth_token: gr.OAuthToken, progress=gr.Progress(track_tqdm=True)):
18
  if not filename.endswith(".pt2"):
19
  raise NotImplementedError("The filename must end with a `.pt2` extension.")
 
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
 
14
  # Load the model pipeline
15
+ pipe = DiffusionPipeline.from_pretrained(
16
+ "black-forest-labs/Flux.1-Dev", torch_dtype=dtype
17
+ ).to(device)
18
 
19
+ @spaces.GPU(duration=1500)
20
  def push_to_hub(repo_id, filename, oauth_token: gr.OAuthToken, progress=gr.Progress(track_tqdm=True)):
21
  if not filename.endswith(".pt2"):
22
  raise NotImplementedError("The filename must end with a `.pt2` extension.")
optimization.py CHANGED
@@ -27,7 +27,6 @@ INDUCTOR_CONFIGS = {
27
 
28
 
29
  def compile_transformer(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
30
- @spaces.GPU(duration=1500)
31
  def f():
32
  with spaces.aoti_capture(pipeline.transformer) as call:
33
  pipeline(*args, **kwargs)
 
27
 
28
 
29
  def compile_transformer(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
 
30
  def f():
31
  with spaces.aoti_capture(pipeline.transformer) as call:
32
  pipeline(*args, **kwargs)