kushagra124 committed
Commit
d1d4db7
1 Parent(s): 582506c

adding app with CLIP image segmentation

Files changed (2)
  1. app.py +15 -6
  2. requirements.txt +3 -1
app.py CHANGED

@@ -5,6 +5,8 @@ import numpy as np
 from PIL import Image
 import torch
 import cv2
+from matplotlib import pyplot as plt
+from segmentation_mask_overlay import overlay_masks
 from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation,AutoProcessor,AutoConfig
 
 processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
@@ -33,17 +35,17 @@ def detect_using_clip(image,prompts=[],threshould=0.4):
     for i,prompt in enumerate(prompts):
         predicted_image = torch.sigmoid(preds[i][0]).detach().cpu().numpy()
         predicted_image = np.where(predicted_image>threshould,255,0)
-        predicted_masks.append(create_rgb_mask(predicted_image))
-
-    return predicted_masks
+        predicted_masks.append(predicted_image)
+    bool_masks = [predicted_mask.astype('bool') for predicted_mask in predicted_masks]
+    return bool_masks
 
 def visualize_images(image,predicted_images,brightness=15,contrast=1.8):
     alpha = 0.7
     image_resize = cv2.resize(image,(352,352))
     resize_image_copy = image_resize.copy()
 
-    for mask_image in predicted_images:
-        resize_image_copy = cv2.addWeighted(resize_image_copy,alpha,mask_image,1-alpha,10)
+    # for mask_image in predicted_images:
+    #     resize_image_copy = cv2.addWeighted(resize_image_copy,alpha,mask_image,1-alpha,10)
 
     return cv2.convertScaleAbs(resize_image_copy, alpha=contrast, beta=brightness)
 
@@ -52,10 +54,17 @@ def shot(brightness,contrast,image,labels_text):
         prompts = labels_text.split(',')
     else:
         prompts = [labels_text]
+
     prompts = list(map(lambda x: x.strip(),prompts))
+
+    mask_labels = [f"{prompt}_{i}" for i,prompt in enumerate(prompts)]
+    cmap = plt.cm.tab20(np.arange(len(mask_labels)))[..., :-1]
+
+    resize_image = cv2.resize(image,(352,352))
+
     predicted_images = detect_using_clip(image,prompts=prompts)
-
-    category_image = visualize_images(image=image,predicted_images=predicted_images,brightness=brightness,contrast=contrast)
+    category_image = overlay_masks(resize_image,np.stack(predicted_images,-1),labels=mask_labels,colors=cmap,alpha=0.4,beta=1)
+
     return category_image
 
 iface = gr.Interface(fn=shot,
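
For reference, a minimal sketch of how the touched pieces fit together after this commit. It assumes the CLIPSeg forward pass (which lives outside these hunks) follows the standard transformers usage, and `segment` is a hypothetical helper that folds `detect_using_clip` and the new `overlay_masks` call from `shot` into one function:

```python
import numpy as np
import torch
import cv2
from matplotlib import pyplot as plt
from segmentation_mask_overlay import overlay_masks
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

def segment(image, prompts, threshold=0.4):
    # Score every pixel of the image against each text prompt.
    inputs = processor(text=prompts, images=[image] * len(prompts),
                       padding=True, return_tensors="pt")
    with torch.no_grad():
        preds = model(**inputs).logits  # (num_prompts, 352, 352)
    if preds.ndim == 2:
        preds = preds.unsqueeze(0)  # a single prompt comes back unbatched
    # Sigmoid turns logits into probabilities; thresholding yields one
    # boolean mask per prompt, matching what detect_using_clip now returns.
    masks = [(torch.sigmoid(p).numpy() > threshold) for p in preds]
    mask_labels = [f"{prompt}_{i}" for i, prompt in enumerate(prompts)]
    cmap = plt.cm.tab20(np.arange(len(mask_labels)))[..., :-1]  # RGB rows, alpha dropped
    resized = cv2.resize(image, (352, 352))  # match CLIPSeg's 352x352 output
    # overlay_masks expects masks stacked on the last axis: (H, W, num_prompts).
    return overlay_masks(resized, np.stack(masks, -1),
                         labels=mask_labels, colors=cmap, alpha=0.4, beta=1)
```

Returning boolean masks instead of the old `create_rgb_mask` output is what lets `overlay_masks` color and label each prompt's region in a single call; it is also why the `cv2.addWeighted` loop in `visualize_images` is commented out rather than extended.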
requirements.txt CHANGED

@@ -8,4 +8,6 @@ opencv-python
 Pillow
 requests
 urllib3<2
-git+https://github.com/facebookresearch/segment-anything.git
\ No newline at end of file
+git+https://github.com/facebookresearch/segment-anything.git
+segmentation_mask_overlay
+matplotlib
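
Both new entries back the overlay path above: `segmentation_mask_overlay` provides `overlay_masks`, and `matplotlib` supplies the `tab20` colormap used to color the masks. pip normalizes underscores and hyphens in package names, so `segmentation_mask_overlay` resolves to the PyPI project `segmentation-mask-overlay`, and a plain `pip install -r requirements.txt` pulls in everything, including the git-sourced segment-anything package.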