Daankular commited on
Commit
6ea654b
·
1 Parent(s): 4296237

Patch image_process.py rmbg_net None guard; fix RMBG-2.0 loading with device_map

Browse files
Files changed (1) hide show
  1. app.py +24 -2
app.py CHANGED
@@ -191,6 +191,26 @@ def load_triposg():
191
  if str(triposg_src) not in sys.path:
192
  sys.path.insert(0, str(triposg_src))
193
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  # Safety net: patch inference_utils.py to make diso import optional.
195
  # Even if diso compiled with submodules, guard against any residual link errors.
196
  _iu_path = triposg_src / "triposg" / "inference_utils.py"
@@ -229,9 +249,11 @@ def load_triposg():
229
 
230
  try:
231
  from transformers import AutoModelForImageSegmentation
 
 
232
  _rmbg_net = AutoModelForImageSegmentation.from_pretrained(
233
- "1038lab/RMBG-2.0", trust_remote_code=True, low_cpu_mem_usage=False
234
- ).to(DEVICE)
235
  _rmbg_net.eval()
236
  _rmbg_version = "2.0"
237
  print("[load_triposg] TripoSG + RMBG-2.0 loaded.")
 
191
  if str(triposg_src) not in sys.path:
192
  sys.path.insert(0, str(triposg_src))
193
 
194
+ # Patch image_process.py: guard rmbg_net=None in load_image.
195
+ # TripoSG calls rmbg(rgb_image_resized) unconditionally when alpha is None,
196
+ # with no check for rmbg_net being None. Fallback: all-white alpha (full foreground).
197
+ _ip_path = triposg_src / "scripts" / "image_process.py"
198
+ if _ip_path.exists():
199
+ _ip_text = _ip_path.read_text()
200
+ if "rmbg_net_none_guard" not in _ip_text:
201
+ _ip_text = _ip_text.replace(
202
+ " # seg from rmbg\n alpha_gpu_rmbg = rmbg(rgb_image_resized)",
203
+ " # seg from rmbg\n"
204
+ " if rmbg_net is None: # rmbg_net_none_guard\n"
205
+ " alpha_gpu_rmbg = torch.ones(\n"
206
+ " 1, rgb_image_resized.shape[1], rgb_image_resized.shape[2],\n"
207
+ " device=rgb_image_resized.device)\n"
208
+ " else:\n"
209
+ " alpha_gpu_rmbg = rmbg(rgb_image_resized)",
210
+ )
211
+ _ip_path.write_text(_ip_text)
212
+ print("[load_triposg] Patched image_process.py: rmbg_net None guard")
213
+
214
  # Safety net: patch inference_utils.py to make diso import optional.
215
  # Even if diso compiled with submodules, guard against any residual link errors.
216
  _iu_path = triposg_src / "triposg" / "inference_utils.py"
 
249
 
250
  try:
251
  from transformers import AutoModelForImageSegmentation
252
+ # device_map loads weights directly to GPU, avoiding the meta tensor
253
+ # intermediate that causes "Tensor.item() cannot be called on meta tensors"
254
  _rmbg_net = AutoModelForImageSegmentation.from_pretrained(
255
+ "1038lab/RMBG-2.0", trust_remote_code=True, device_map=DEVICE
256
+ )
257
  _rmbg_net.eval()
258
  _rmbg_version = "2.0"
259
  print("[load_triposg] TripoSG + RMBG-2.0 loaded.")