kushagra124 committed on
Commit
42545c2
1 Parent(s): b2f9f4f

adding text box

Files changed (1)
  1. app.py +18 -12
app.py CHANGED
@@ -13,6 +13,15 @@ processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
 model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
 classes = list()
 
+def create_mask(image,image_mask,alpha=0.7):
+    mask = np.zeros_like(image)
+    # copy the single-channel mask into all three color channels of the image
+    for i in range(3):
+        mask[:,:,i] = image_mask.copy()
+    # blend the mask over the image with the given alpha
+    overlay_image = cv2.addWeighted(mask,alpha,image,1-alpha,0)
+    return overlay_image
+
 def rescale_bbox(bbox,orig_image_shape=(1024,1024),model_shape=352):
     bbox = np.asarray(bbox)/model_shape
     y1,y2 = bbox[::2] *orig_image_shape[0]
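The new create_mask helper is a plain alpha blend: cv2.addWeighted computes alpha*mask + (1-alpha)*image per pixel, which is why the single-channel mask is first expanded to three channels to match the image. A minimal sketch of how it might be called (the image and mask below are made-up placeholders, not from the diff):

    import cv2
    import numpy as np

    image = np.zeros((352, 352, 3), dtype=np.uint8)        # placeholder frame
    image_mask = np.full((352, 352), 255, dtype=np.uint8)  # placeholder 0/255 mask
    overlay = create_mask(image, image_mask, alpha=0.7)    # 0.7*mask + 0.3*image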
@@ -30,7 +39,10 @@ def detect_using_clip(image,prompts=[],threshould=0.4):
     with torch.no_grad(): # Use 'torch.no_grad()' to disable gradient computation
         outputs = model(**inputs)
     preds = outputs.logits.unsqueeze(1)
+
+    # tensor_images = [torch.sigmoid(preds[i][0]) for i in range(len(prompts))]
     detection = outputs.logits[0] # Assuming class index 0
+
     for i,prompt in enumerate(prompts):
         predicted_image = torch.sigmoid(preds[i][0]).detach().cpu().numpy()
         predicted_image = np.where(predicted_image>threshould,255,0)
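CLIPSeg produces one logit map per prompt; the loop squashes each map to probabilities with a sigmoid and binarizes it at the threshold. The same step in isolation, assuming a 352x352 logits tensor as a stand-in for preds[i][0]:

    import torch
    import numpy as np

    logits = torch.randn(352, 352)               # stand-in for preds[i][0]
    probs = torch.sigmoid(logits).detach().cpu().numpy()
    binary_mask = np.where(probs > 0.4, 255, 0)  # same rule as threshould=0.4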
@@ -39,7 +51,7 @@ def detect_using_clip(image,prompts=[],threshould=0.4):
         props = regionprops(lbl_0)
         model_detections[prompt] = [rescale_bbox(prop.bbox,orig_image_shape=image.shape[:2],model_shape=predicted_image.shape[0]) for prop in props]
 
-        return model_detections
+    return model_detections
 
 def visualize_images(image,detections,prompt):
     H,W = image.shape[:2]
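skimage's regionprops reports each bounding box as (min_row, min_col, max_row, max_col) in the 352x352 model grid, and rescale_bbox maps it back by normalizing with the model size and scaling by the original height and width. A worked example with a hypothetical box on a 1024x1024 image:

    import numpy as np

    bbox = np.asarray((88, 44, 176, 132)) / 352  # normalize to [0, 1]
    y1, y2 = bbox[::2] * 1024                    # rows scale with image height
    x1, x2 = bbox[1::2] * 1024                   # cols scale with image width
    print(y1, x1, y2, x2)                        # 256.0 128.0 512.0 384.0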
@@ -49,26 +61,20 @@ def visualize_images(image,detections,prompt):
         return image_copy
     for bbox in detections[prompt]:
         cv2.rectangle(image_copy, (int(bbox[1]), int(bbox[0])), (int(bbox[3]), int(bbox[2])), (255, 0, 0), 2)
+        cv2.putText(image_copy, prompt, (int(bbox[1]), int(bbox[0])), cv2.FONT_HERSHEY_SIMPLEX, 2, 255)
     return image_copy
 
 
 def shot(image, labels_text,selected_categoty):
-    print("Labels Text ",labels_text)
     prompts = labels_text.split(',')
-    classes = prompts
-
-    print("prompts :",prompts,classes)
-    print("Image shape ",image.shape )
-
-    model_detections = detect_using_clip(image,prompts=prompts)
-    print("detections :",model_detections)
-    print("Category ",selected_categoty)
-    return visualize_images(image=image,detections=model_detections,prompt=selected_categoty)
+    model_detections = detect_using_clip(image,prompts=prompts)
+    category_image = visualize_images(image=image,detections=model_detections,prompt=selected_categoty)
+    return category_image
 
 iface = gr.Interface(fn=shot,
                      inputs = ["image","text","text"],
                      outputs="image",
-                     description="Add a picture and a list of labels separated by commas",
+                     description="Add an image and a comma-separated list of categories to detect",
                      title="Zero-shot Image Classification with Prompt ",
                      examples=[["images/room.jpg","bed,table,plant",'plant']],
                      # allow_flagging=False,
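For the label overlay, note that cv2.putText takes the text string before the origin point: cv2.putText(img, text, org, fontFace, fontScale, color[, thickness]). A minimal sketch with placeholder values:

    import cv2
    import numpy as np

    canvas = np.zeros((256, 256, 3), dtype=np.uint8)
    cv2.putText(canvas, "plant", (10, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)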
 
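The gr.Interface call is cut off by the diff context at the commented-out allow_flagging line; a Gradio Space like this one would normally close that call and then start the app, roughly as below (this completion is an assumption, not part of the diff):

    iface.launch()  # hypothetical: the closing call sits outside the hunk's context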