aryswisnu commited on
Commit
8e44bc6
·
1 Parent(s): fa45949

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -23
app.py CHANGED
@@ -44,36 +44,35 @@ def get_masks(prompts, img, threhsold):
44
  return masks
45
 
46
 
47
- def extract_image(img, pos_prompts, neg_prompts, threshold, alpha_value=0.5):
48
  positive_masks = get_masks(pos_prompts, img, threshold)
49
  negative_masks = get_masks(neg_prompts, img, threshold)
50
 
51
- # combine masks into one mask, logic OR
52
  pos_mask = np.any(np.stack(positive_masks), axis=0)
53
  neg_mask = np.any(np.stack(negative_masks), axis=0)
54
  final_mask = pos_mask & ~neg_mask
55
 
56
- # threshold the mask
57
- bmask = final_mask > threshold
58
- # zero out values below the threshold
59
- final_mask[final_mask < threshold] = 0
60
-
61
- # convert PIL image to RGBA numpy array
62
- img_np = np.array(img.convert("RGBA"))
63
- # create an empty RGBA image with the same size
64
- output_image = np.zeros_like(img_np, dtype=np.uint8)
65
-
66
- # apply the final_mask as alpha channel on the output image
67
- output_image[:, :, :3] = img_np[:, :, :3]
68
- output_image[:, :, 3] = (final_mask * alpha_value * 255).astype(np.uint8)
69
-
70
- # convert the output_image, final_mask, and bmask back to PIL.Image objects
71
- output_image = Image.fromarray(output_image, "RGBA")
72
- final_mask = Image.fromarray((final_mask * 255).astype(np.uint8), "L")
73
- bmask = Image.fromarray((bmask * 255).astype(np.uint8), "L")
74
-
75
- return output_image, final_mask, bmask
76
-
77
 
78
 
79
 
 
44
  return masks
45
 
46
 
47
+ def extract_image(img, pos_prompts, neg_prompts, threshold, alpha_value=0.5, blur_radius=5):
48
  positive_masks = get_masks(pos_prompts, img, threshold)
49
  negative_masks = get_masks(neg_prompts, img, threshold)
50
 
51
+ # combine masks into one masks, logic OR
52
  pos_mask = np.any(np.stack(positive_masks), axis=0)
53
  neg_mask = np.any(np.stack(negative_masks), axis=0)
54
  final_mask = pos_mask & ~neg_mask
55
 
56
+ # apply Gaussian blur for feathering
57
+ final_mask_img = Image.fromarray((final_mask * 255).astype(np.uint8), "L")
58
+ final_mask_img = final_mask_img.filter(ImageFilter.GaussianBlur(radius=blur_radius))
59
+ final_mask = np.array(final_mask_img) / 255
60
+ final_mask = final_mask > threshold
61
+
62
+ # blend the original image and the mask using the alpha value
63
+ fig, ax = plt.subplots()
64
+ ax.imshow(img)
65
+ ax.imshow(final_mask, alpha=alpha_value, cmap="jet")
66
+ ax.axis("off")
67
+ plt.tight_layout()
68
+
69
+ # extract the final image
70
+ final_mask = Image.fromarray(final_mask.astype(np.uint8) * 255, "L")
71
+ inverse_mask = np.invert(final_mask)
72
+ output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
73
+ output_image.paste(img, mask=final_mask)
74
+
75
+ return output_image, final_mask, inverse_mask
 
76
 
77
 
78