kxhit commited on
Commit
71f5049
1 Parent(s): b9865ef

rembg->carvekit gpu

Browse files
Files changed (2) hide show
  1. app.py +26 -20
  2. dust3r/utils/image.py +2 -2
app.py CHANGED
@@ -74,7 +74,7 @@ from pipeline_zero1to3 import Zero1to3StableDiffusionPipeline
74
  from segment_anything import sam_model_registry, SamPredictor
75
 
76
  import rembg
77
- # from carvekit.api.high import HiInterface
78
 
79
 
80
  pretrained_model_name_or_path = "kxic/EscherNet_demo"
@@ -130,25 +130,31 @@ def sam_init():
130
  predictor = SamPredictor(sam)
131
  return predictor
132
 
133
- # @spaces.GPU
134
- # def create_carvekit_interface():
135
- # # Check doc strings for more information
136
- # interface = HiInterface(object_type="object", # Can be "object" or "hairs-like".
137
- # batch_size_seg=6,
138
- # batch_size_matting=1,
139
- # device=device,
140
- # seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net
141
- # matting_mask_size=2048,
142
- # trimap_prob_threshold=231,
143
- # trimap_dilation=30,
144
- # trimap_erosion_iters=5,
145
- # fp16=True)
146
- #
147
- # return interface
148
-
149
-
150
- rembg_session = rembg.new_session()
151
- # rembg_session = create_carvekit_interface()
 
 
 
 
 
 
152
 
153
  predictor = sam_init()
154
 
 
74
  from segment_anything import sam_model_registry, SamPredictor
75
 
76
  import rembg
77
+ from carvekit.api.high import HiInterface
78
 
79
 
80
  pretrained_model_name_or_path = "kxic/EscherNet_demo"
 
130
  predictor = SamPredictor(sam)
131
  return predictor
132
 
133
+ def create_carvekit_interface():
134
+ # Check doc strings for more information
135
+ interface = HiInterface(object_type="object", # Can be "object" or "hairs-like".
136
+ batch_size_seg=6,
137
+ batch_size_matting=1,
138
+ device="cpu",
139
+ seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net
140
+ matting_mask_size=2048,
141
+ trimap_prob_threshold=231,
142
+ trimap_dilation=30,
143
+ trimap_erosion_iters=5,
144
+ fp16=False)
145
+
146
+ return interface
147
+
148
+
149
+ # rembg_session = rembg.new_session()
150
+ rembg_session = create_carvekit_interface()
151
+ rembg_session.u2net = rembg_session.u2net.to(device)
152
+ rembg_session.fba = rembg_session.fba.to(device)
153
+ rembg_session.fba.device = device
154
+ rembg_session.device = device
155
+ rembg_session.u2net.device = device
156
+ # rembg_session.postprocessing_pipeline = rembg_session.postprocessing_pipeline.to(device)
157
+ # rembg_session.postprocessing_pipeline.device = device
158
 
159
  predictor = sam_init()
160
 
dust3r/utils/image.py CHANGED
@@ -119,9 +119,9 @@ def load_images(folder_or_list, size, square_ok=False, verbose=True, do_remove_b
119
  # remove background if needed
120
  if do_remove_background:
121
  # use rembg
122
- image_nobg = remove(img, alpha_matting=True, session=rembg_session)
123
  # use carvekit
124
- # image_nobg = rembg_session([img])[0]
125
  arr = np.asarray(image_nobg)[:, :, -1]
126
  x_nonzero = np.nonzero(arr.sum(axis=0))
127
  y_nonzero = np.nonzero(arr.sum(axis=1))
 
119
  # remove background if needed
120
  if do_remove_background:
121
  # use rembg
122
+ # image_nobg = remove(img, alpha_matting=True, session=rembg_session)
123
  # use carvekit
124
+ image_nobg = rembg_session([img])[0]
125
  arr = np.asarray(image_nobg)[:, :, -1]
126
  x_nonzero = np.nonzero(arr.sum(axis=0))
127
  y_nonzero = np.nonzero(arr.sum(axis=1))