Ashoka74 committed on
Commit
f845948
·
verified ·
1 Parent(s): 2cf63f7

Update inference_i2mv_sdxl.py

Browse files
Files changed (1) hide show
  1. inference_i2mv_sdxl.py +98 -22
inference_i2mv_sdxl.py CHANGED
@@ -151,28 +151,105 @@ def remove_bg(image: Image.Image, net, transform, device, mask: Image.Image = No
151
  # return output_image
152
 
153
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
 
 
 
 
 
155
 
 
 
 
 
156
 
157
- def preprocess_image(image: Image.Image, height, width):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
- alpha = image[..., 3] > 0
160
- # alpha = image
161
-
162
- #if image.mode in ("RGBA", "LA"):
163
- # image = np.array(image)
164
- # alpha = image[..., 3] # Extract the alpha channel
165
- #elif image.mode in ("RGB"):
166
- # image = np.array(image)
167
- # Create default alpha for non-alpha images
168
- # alpha = np.ones(image[..., 0].shape, dtype=np.uint8) * 255 # Create
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  H, W = alpha.shape
170
- # get the bounding box of alpha
171
  y, x = np.where(alpha)
172
  y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
173
  x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
174
- image_center = image[y0:y1, x0:x1]
175
- # resize the longer side to H * 0.9
 
176
  H, W, _ = image_center.shape
177
  if H > W:
178
  W = int(W * (height * 0.9) / H)
@@ -180,18 +257,17 @@ def preprocess_image(image: Image.Image, height, width):
180
  else:
181
  H = int(H * (width * 0.9) / W)
182
  W = int(width * 0.9)
 
183
  image_center = np.array(Image.fromarray(image_center).resize((W, H)))
184
- # pad to H, W
 
185
  start_h = (height - H) // 2
186
  start_w = (width - W) // 2
187
- image = np.zeros((height, width, 4), dtype=np.uint8)
188
- image[start_h : start_h + H, start_w : start_w + W] = image_center
189
- image = image.astype(np.float32) / 255.0
190
- image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
191
- image = (image * 255).clip(0, 255).astype(np.uint8)
192
- image = Image.fromarray(image)
193
 
194
- return image
 
195
 
196
 
197
  def run_pipeline(
 
151
  # return output_image
152
 
153
 
154
def remove_bg(image: Image.Image, net, transform, device, mask: np.ndarray = None):
    """Apply a pre-computed mask to an image as its alpha channel.

    Args:
        image (PIL.Image.Image): The input image. It is left untouched; the
            alpha channel is written to a converted copy.
        net: Unused; kept so existing call sites keep working.
        transform: Unused; kept so existing call sites keep working.
        device: Unused; kept so existing call sites keep working.
        mask (np.ndarray, optional): A 2-D ``(H, W)`` or 3-D ``(H, W, 1)``
            mask. Masks whose maximum is <= 1 (boolean or 0/1) are scaled to
            0-255; masks that already use the 0-255 range are used as-is.
            The mask is resized to the image size if it differs.
            If None, the image is returned unchanged.

    Returns:
        PIL.Image.Image: ``image`` itself when ``mask`` is None, otherwise a
        new RGBA image whose transparency follows the mask.

    Raises:
        ValueError: If the mask is empty or is not ``(H, W)`` / ``(H, W, 1)``.
    """
    if mask is None:
        return image

    mask = np.asarray(mask)
    if mask.ndim == 3 and mask.shape[2] == 1:
        mask = mask[..., 0]  # drop the singleton channel axis
    if mask.ndim != 2 or mask.size == 0:
        raise ValueError(
            f"mask must be a non-empty (H, W) or (H, W, 1) array, got shape {mask.shape}"
        )

    # Scale in float so a 0-255 uint8 mask is not multiplied by 255 and
    # wrapped around (255 * 255 overflows uint8 to 1, i.e. near-transparent).
    alpha = mask.astype(np.float32)
    if alpha.max() <= 1:
        alpha *= 255.0
    alpha = np.clip(alpha, 0, 255).astype(np.uint8)

    # Resize the mask to match the image size.
    mask_pil = Image.fromarray(alpha).resize(image.size, Image.LANCZOS)

    # convert() returns a new image, so the caller's image is not mutated
    # by putalpha() below.
    rgba = image.convert("RGBA")
    rgba.putalpha(mask_pil)

    # Composite onto an empty canvas: pixels the mask hides end up as fully
    # transparent black instead of keeping their original colour.
    output_image = Image.new("RGBA", image.size)
    output_image.paste(rgba, (0, 0), rgba)

    return output_image
194
+
195
+
196
+ # def preprocess_image(image: Image.Image, height, width):
197
 
198
+ # alpha = image[..., 3] > 0
199
+ # # alpha = image
200
+
201
+ # #if image.mode in ("RGBA", "LA"):
202
+ # # image = np.array(image)
203
+ # # alpha = image[..., 3] # Extract the alpha channel
204
+ # #elif image.mode in ("RGB"):
205
+ # # image = np.array(image)
206
+ # # Create default alpha for non-alpha images
207
+ # # alpha = np.ones(image[..., 0].shape, dtype=np.uint8) * 255 # Create
208
+ # H, W = alpha.shape
209
+ # # get the bounding box of alpha
210
+ # y, x = np.where(alpha)
211
+ # y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
212
+ # x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
213
+ # image_center = image[y0:y1, x0:x1]
214
+ # # resize the longer side to H * 0.9
215
+ # H, W, _ = image_center.shape
216
+ # if H > W:
217
+ # W = int(W * (height * 0.9) / H)
218
+ # H = int(height * 0.9)
219
+ # else:
220
+ # H = int(H * (width * 0.9) / W)
221
+ # W = int(width * 0.9)
222
+ # image_center = np.array(Image.fromarray(image_center).resize((W, H)))
223
+ # # pad to H, W
224
+ # start_h = (height - H) // 2
225
+ # start_w = (width - W) // 2
226
+ # image = np.zeros((height, width, 4), dtype=np.uint8)
227
+ # image[start_h : start_h + H, start_w : start_w + W] = image_center
228
+ # image = image.astype(np.float32) / 255.0
229
+ # image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
230
+ # image = (image * 255).clip(0, 255).astype(np.uint8)
231
+ # image = Image.fromarray(image)
232
+
233
+ # return image
234
+
235
def preprocess_image(image: Image.Image, height, width):
    """Crop an image to its visible content and centre it on a transparent canvas.

    The bounding box of all pixels with alpha > 0 is cropped, scaled
    (aspect ratio preserved) so that it fits inside 90% of the target canvas,
    and pasted in the middle of a fully transparent RGBA canvas.

    Args:
        image (PIL.Image.Image): The input image. Any mode is accepted; it is
            converted to RGBA, so images without an alpha channel are treated
            as fully opaque. A numpy array is also accepted and is wrapped
            with ``Image.fromarray`` first.
        height: Height of the output canvas in pixels.
        width: Width of the output canvas in pixels.

    Returns:
        PIL.Image.Image: An RGBA image of size ``(width, height)``.

    Raises:
        ValueError: If the image has no pixel with alpha > 0.
    """
    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.asarray(image))

    # Always work on 4 channels so the crop can be written into the RGBA
    # canvas below; a 3-channel crop cannot be broadcast into it.
    image_np = np.array(image.convert("RGBA"))
    alpha = image_np[..., 3] > 0
    if not alpha.any():
        raise ValueError("preprocess_image: image has no visible (alpha > 0) pixels")

    H, W = alpha.shape
    # Get the bounding box of the visible pixels.
    y, x = np.where(alpha)
    y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
    x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
    image_center = image_np[y0:y1, x0:x1]

    # Scale so the limiting side takes 90% of the canvas. Comparing aspect
    # ratios (rather than H > W) keeps the crop inside non-square canvases;
    # for square canvases this reduces to the plain H > W test.
    H, W, _ = image_center.shape
    if H * width > W * height:
        W = int(W * (height * 0.9) / H)
        H = int(height * 0.9)
    else:
        H = int(H * (width * 0.9) / W)
        W = int(width * 0.9)
    # Very thin crops can round down to 0, which resize() rejects.
    H, W = max(H, 1), max(W, 1)

    image_center = np.array(Image.fromarray(image_center).resize((W, H)))

    # Pad to (height, width) with fully transparent pixels.
    start_h = (height - H) // 2
    start_w = (width - W) // 2
    padded_image = np.zeros((height, width, 4), dtype=np.uint8)
    padded_image[start_h:start_h + H, start_w:start_w + W] = image_center

    # Convert back to PIL Image.
    return Image.fromarray(padded_image)
271
 
272
 
273
  def run_pipeline(