Daankular commited on
Commit
78469a2
·
1 Parent(s): dfd6c0e

Fix RMBG loading with torch.device(cpu) context; fix None fallback alpha shape [1,1,H,W]

Browse files
Files changed (1) hide show
  1. app.py +11 -6
app.py CHANGED
@@ -197,13 +197,13 @@ def load_triposg():
197
  _ip_path = triposg_src / "scripts" / "image_process.py"
198
  if _ip_path.exists():
199
  _ip_text = _ip_path.read_text()
200
- if "rmbg_net_none_guard" not in _ip_text:
201
  _ip_text = _ip_text.replace(
202
  " # seg from rmbg\n alpha_gpu_rmbg = rmbg(rgb_image_resized)",
203
  " # seg from rmbg\n"
204
- " if rmbg_net is None: # rmbg_net_none_guard\n"
205
  " alpha_gpu_rmbg = torch.ones(\n"
206
- " 1, rgb_image_resized.shape[1], rgb_image_resized.shape[2],\n"
207
  " device=rgb_image_resized.device)\n"
208
  " else:\n"
209
  " alpha_gpu_rmbg = rmbg(rgb_image_resized)",
@@ -249,9 +249,14 @@ def load_triposg():
249
 
250
  try:
251
  from transformers import AutoModelForImageSegmentation
252
- _rmbg_net = AutoModelForImageSegmentation.from_pretrained(
253
- "1038lab/RMBG-2.0", trust_remote_code=True
254
- )
 
 
 
 
 
255
  torch.set_float32_matmul_precision("high")
256
  _rmbg_net.to(DEVICE)
257
  _rmbg_net.eval()
 
197
  _ip_path = triposg_src / "scripts" / "image_process.py"
198
  if _ip_path.exists():
199
  _ip_text = _ip_path.read_text()
200
+ if "rmbg_net_none_guard_v2" not in _ip_text:
201
  _ip_text = _ip_text.replace(
202
  " # seg from rmbg\n alpha_gpu_rmbg = rmbg(rgb_image_resized)",
203
  " # seg from rmbg\n"
204
+ " if rmbg_net is None: # rmbg_net_none_guard_v2\n"
205
  " alpha_gpu_rmbg = torch.ones(\n"
206
+ " 1, 1, rgb_image_resized.shape[1], rgb_image_resized.shape[2],\n"
207
  " device=rgb_image_resized.device)\n"
208
  " else:\n"
209
  " alpha_gpu_rmbg = rmbg(rgb_image_resized)",
 
249
 
250
  try:
251
  from transformers import AutoModelForImageSegmentation
252
+ # torch.device('cpu') context forces all tensor creation to real CPU memory,
253
+ # bypassing any meta-device context left active by TripoSGPipeline loading.
254
+ # BiRefNet's __init__ creates Config() instances and calls eval() on class
255
+ # names — these fire during meta-device init and crash with .item() errors.
256
+ with torch.device("cpu"):
257
+ _rmbg_net = AutoModelForImageSegmentation.from_pretrained(
258
+ "1038lab/RMBG-2.0", trust_remote_code=True, low_cpu_mem_usage=False
259
+ )
260
  torch.set_float32_matmul_precision("high")
261
  _rmbg_net.to(DEVICE)
262
  _rmbg_net.eval()