Spaces:
Running on Zero
Running on Zero
Patch image_process.py rmbg_net None guard; fix RMBG-2.0 loading with device_map
Browse files
app.py
CHANGED
|
@@ -191,6 +191,26 @@ def load_triposg():
|
|
| 191 |
if str(triposg_src) not in sys.path:
|
| 192 |
sys.path.insert(0, str(triposg_src))
|
| 193 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
# Safety net: patch inference_utils.py to make diso import optional.
|
| 195 |
# Even if diso compiled with submodules, guard against any residual link errors.
|
| 196 |
_iu_path = triposg_src / "triposg" / "inference_utils.py"
|
|
@@ -229,9 +249,11 @@ def load_triposg():
|
|
| 229 |
|
| 230 |
try:
|
| 231 |
from transformers import AutoModelForImageSegmentation
|
|
|
|
|
|
|
| 232 |
_rmbg_net = AutoModelForImageSegmentation.from_pretrained(
|
| 233 |
-
"1038lab/RMBG-2.0", trust_remote_code=True,
|
| 234 |
-
)
|
| 235 |
_rmbg_net.eval()
|
| 236 |
_rmbg_version = "2.0"
|
| 237 |
print("[load_triposg] TripoSG + RMBG-2.0 loaded.")
|
|
|
|
| 191 |
if str(triposg_src) not in sys.path:
|
| 192 |
sys.path.insert(0, str(triposg_src))
|
| 193 |
|
| 194 |
+
# Patch image_process.py: guard rmbg_net=None in load_image.
|
| 195 |
+
# TripoSG calls rmbg(rgb_image_resized) unconditionally when alpha is None,
|
| 196 |
+
# with no check for rmbg_net being None. Fallback: all-white alpha (full foreground).
|
| 197 |
+
_ip_path = triposg_src / "scripts" / "image_process.py"
|
| 198 |
+
if _ip_path.exists():
|
| 199 |
+
_ip_text = _ip_path.read_text()
|
| 200 |
+
if "rmbg_net_none_guard" not in _ip_text:
|
| 201 |
+
_ip_text = _ip_text.replace(
|
| 202 |
+
" # seg from rmbg\n alpha_gpu_rmbg = rmbg(rgb_image_resized)",
|
| 203 |
+
" # seg from rmbg\n"
|
| 204 |
+
" if rmbg_net is None: # rmbg_net_none_guard\n"
|
| 205 |
+
" alpha_gpu_rmbg = torch.ones(\n"
|
| 206 |
+
" 1, rgb_image_resized.shape[1], rgb_image_resized.shape[2],\n"
|
| 207 |
+
" device=rgb_image_resized.device)\n"
|
| 208 |
+
" else:\n"
|
| 209 |
+
" alpha_gpu_rmbg = rmbg(rgb_image_resized)",
|
| 210 |
+
)
|
| 211 |
+
_ip_path.write_text(_ip_text)
|
| 212 |
+
print("[load_triposg] Patched image_process.py: rmbg_net None guard")
|
| 213 |
+
|
| 214 |
# Safety net: patch inference_utils.py to make diso import optional.
|
| 215 |
# Even if diso compiled with submodules, guard against any residual link errors.
|
| 216 |
_iu_path = triposg_src / "triposg" / "inference_utils.py"
|
|
|
|
| 249 |
|
| 250 |
try:
|
| 251 |
from transformers import AutoModelForImageSegmentation
|
| 252 |
+
# device_map loads weights directly to GPU, avoiding the meta tensor
|
| 253 |
+
# intermediate that causes "Tensor.item() cannot be called on meta tensors"
|
| 254 |
_rmbg_net = AutoModelForImageSegmentation.from_pretrained(
|
| 255 |
+
"1038lab/RMBG-2.0", trust_remote_code=True, device_map=DEVICE
|
| 256 |
+
)
|
| 257 |
_rmbg_net.eval()
|
| 258 |
_rmbg_version = "2.0"
|
| 259 |
print("[load_triposg] TripoSG + RMBG-2.0 loaded.")
|