Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -4,7 +4,7 @@ import shutil
|
|
4 |
import uuid
|
5 |
import torch
|
6 |
import random
|
7 |
-
|
8 |
import gradio as gr
|
9 |
import numpy as np
|
10 |
|
@@ -16,7 +16,7 @@ from pipeline_flux_control_removal import FluxControlRemovalPipeline
|
|
16 |
|
17 |
torch.set_grad_enabled(False)
|
18 |
os.environ['GRADIO_TEMP_DIR'] = './tmp'
|
19 |
-
device =
|
20 |
print(device)
|
21 |
image_path = mask_path = None
|
22 |
image_examples = [...]
|
@@ -52,7 +52,7 @@ image_examples = [
|
|
52 |
]
|
53 |
|
54 |
]
|
55 |
-
|
56 |
def load_model(base_model_path, lora_path):
|
57 |
global pipe
|
58 |
transformer = FluxTransformer2DModel.from_pretrained(base_model_path, subfolder='transformer', torch_dtype=torch.bfloat16)
|
@@ -86,7 +86,8 @@ def load_model(base_model_path, lora_path):
|
|
86 |
gr.Info(str(f"Inject LoRA: {lora_path}"))
|
87 |
pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors")
|
88 |
gr.Info(str(f"Model loading: {int((100 / 100) * 100)}%"))
|
89 |
-
|
|
|
90 |
def set_seed(seed):
|
91 |
torch.manual_seed(seed)
|
92 |
torch.cuda.manual_seed(seed)
|
@@ -94,7 +95,7 @@ def set_seed(seed):
|
|
94 |
np.random.seed(seed)
|
95 |
random.seed(seed)
|
96 |
|
97 |
-
|
98 |
def predict(
|
99 |
input_image,
|
100 |
prompt,
|
|
|
4 |
import uuid
|
5 |
import torch
|
6 |
import random
|
7 |
+
import spaces
|
8 |
import gradio as gr
|
9 |
import numpy as np
|
10 |
|
|
|
16 |
|
17 |
torch.set_grad_enabled(False)
|
18 |
os.environ['GRADIO_TEMP_DIR'] = './tmp'
|
19 |
+
device = "cuda"
|
20 |
print(device)
|
21 |
image_path = mask_path = None
|
22 |
image_examples = [...]
|
|
|
52 |
]
|
53 |
|
54 |
]
|
55 |
+
@spaces.GPU
|
56 |
def load_model(base_model_path, lora_path):
|
57 |
global pipe
|
58 |
transformer = FluxTransformer2DModel.from_pretrained(base_model_path, subfolder='transformer', torch_dtype=torch.bfloat16)
|
|
|
86 |
gr.Info(str(f"Inject LoRA: {lora_path}"))
|
87 |
pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors")
|
88 |
gr.Info(str(f"Model loading: {int((100 / 100) * 100)}%"))
|
89 |
+
|
90 |
+
@spaces.GPU
|
91 |
def set_seed(seed):
|
92 |
torch.manual_seed(seed)
|
93 |
torch.cuda.manual_seed(seed)
|
|
|
95 |
np.random.seed(seed)
|
96 |
random.seed(seed)
|
97 |
|
98 |
+
@spaces.GPU
|
99 |
def predict(
|
100 |
input_image,
|
101 |
prompt,
|