theSure commited on
Commit
005f7bb
·
verified ·
1 Parent(s): 7492032

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -4,7 +4,7 @@ import shutil
4
  import uuid
5
  import torch
6
  import random
7
-
8
  import gradio as gr
9
  import numpy as np
10
 
@@ -16,7 +16,7 @@ from pipeline_flux_control_removal import FluxControlRemovalPipeline
16
 
17
  torch.set_grad_enabled(False)
18
  os.environ['GRADIO_TEMP_DIR'] = './tmp'
19
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
  print(device)
21
  image_path = mask_path = None
22
  image_examples = [...]
@@ -52,7 +52,7 @@ image_examples = [
52
  ]
53
 
54
  ]
55
-
56
  def load_model(base_model_path, lora_path):
57
  global pipe
58
  transformer = FluxTransformer2DModel.from_pretrained(base_model_path, subfolder='transformer', torch_dtype=torch.bfloat16)
@@ -86,7 +86,8 @@ def load_model(base_model_path, lora_path):
86
  gr.Info(str(f"Inject LoRA: {lora_path}"))
87
  pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors")
88
  gr.Info(str(f"Model loading: {int((100 / 100) * 100)}%"))
89
-
 
90
  def set_seed(seed):
91
  torch.manual_seed(seed)
92
  torch.cuda.manual_seed(seed)
@@ -94,7 +95,7 @@ def set_seed(seed):
94
  np.random.seed(seed)
95
  random.seed(seed)
96
 
97
-
98
  def predict(
99
  input_image,
100
  prompt,
 
4
  import uuid
5
  import torch
6
  import random
7
+ import spaces
8
  import gradio as gr
9
  import numpy as np
10
 
 
16
 
17
  torch.set_grad_enabled(False)
18
  os.environ['GRADIO_TEMP_DIR'] = './tmp'
19
+ device = "cuda"
20
  print(device)
21
  image_path = mask_path = None
22
  image_examples = [...]
 
52
  ]
53
 
54
  ]
55
+ @spaces.GPU
56
  def load_model(base_model_path, lora_path):
57
  global pipe
58
  transformer = FluxTransformer2DModel.from_pretrained(base_model_path, subfolder='transformer', torch_dtype=torch.bfloat16)
 
86
  gr.Info(str(f"Inject LoRA: {lora_path}"))
87
  pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors")
88
  gr.Info(str(f"Model loading: {int((100 / 100) * 100)}%"))
89
+
90
+ @spaces.GPU
91
  def set_seed(seed):
92
  torch.manual_seed(seed)
93
  torch.cuda.manual_seed(seed)
 
95
  np.random.seed(seed)
96
  random.seed(seed)
97
 
98
+ @spaces.GPU
99
  def predict(
100
  input_image,
101
  prompt,