ILFM committed on
Commit 2219a97 · 1 Parent(s): d8145bb

Fixed for GPU provision
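
The fix moves the GPU-dependent calls behind @spaces.GPU so that, on a ZeroGPU Space (assumed here from the use of the spaces package), a GPU is attached only while those functions run; module-level code executes without one. A minimal, self-contained sketch of the pattern, where probe_gpu is a hypothetical stand-in for any code that needs CUDA:

import spaces   # provided on Hugging Face Spaces; exposes the GPU decorator
import torch

@spaces.GPU(duration=10)   # request a GPU for at most ~10 seconds for this call
def probe_gpu() -> str:
    # CUDA is only usable inside a @spaces.GPU-decorated function on ZeroGPU
    return torch.cuda.get_device_name(0)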

Files changed (1)
  1. app.py +8 -2
app.py CHANGED
@@ -1,6 +1,7 @@
 import torch
 import gradio as gr
 from diffusers import FluxPipeline
+import spaces
 
 from nunchaku import NunchakuFluxTransformer2dModel
 from nunchaku.utils import get_precision
@@ -8,14 +9,19 @@ from nunchaku.utils import get_precision
 dtype=torch.bfloat16
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-precision = get_precision() # auto-detect your precision is 'int4' or 'fp4' based on your GPU
+@spaces.GPU(duration=10)
+def gpu_precision():
+    precision = get_precision()
+    return precision
+
 transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-    f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
+    f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{gpu_precision()}_r32-flux.1-dev.safetensors"
 )
 pipeline = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=dtype
 ).to(device)
 
+@spaces.GPU(duration=60)
 def generate_image(prompt: str, steps: int, guidance_scale: float):
     if not prompt.strip():
         raise gr.Error("Prompt cannot be empty.")
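
With both get_precision() (which picks 'int4' or 'fp4' based on the detected GPU) and generate_image decorated, each call gets its own GPU slot: roughly 10 seconds for the precision probe at startup and up to 60 seconds per generation. Below is a sketch of how the decorated generate_image might be wired into a Gradio UI in the rest of app.py; the component labels, ranges, and defaults are illustrative assumptions, not part of this diff, and it assumes generate_image returns an image.

import gradio as gr  # already imported in app.py; repeated here so the sketch stands alone

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    steps = gr.Slider(1, 50, value=25, step=1, label="Inference steps")
    guidance = gr.Slider(0.0, 10.0, value=3.5, step=0.1, label="Guidance scale")
    output = gr.Image(label="Result")
    run = gr.Button("Generate")
    # Each click provisions a GPU for at most the duration passed to @spaces.GPU (60 s here).
    run.click(generate_image, inputs=[prompt, steps, guidance], outputs=output)

demo.launch()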