kushagra124 committed 42545c2 (parent: b2f9f4f)
adding text box
app.py CHANGED
@@ -13,6 +13,15 @@ processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
 model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
 classes = list()
 
+def create_mask(image,image_mask,alpha=0.7):
+    mask = np.zeros_like(image)
+    # copy your image_mask to all dimensions (i.e. colors) of your image
+    for i in range(3):
+        mask[:,:,i] = image_mask.copy()
+    # apply the mask to your image
+    overlay_image = cv2.addWeighted(mask,alpha,image,1-alpha,0)
+    return overlay_image
+
 def rescale_bbox(bbox,orig_image_shape=(1024,1024),model_shape=352):
     bbox = np.asarray(bbox)/model_shape
     y1,y2 = bbox[::2] *orig_image_shape[0]
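Note: the new create_mask helper blends a single-channel mask over the RGB frame with cv2.addWeighted, which requires both inputs to match in size and dtype. A minimal usage sketch under that assumption (the variable names here are illustrative, not from the commit); the mask is thresholded, resized to the frame, and cast to uint8 first:

import cv2
import numpy as np

# hypothetical inputs: `frame` is an HxWx3 uint8 image, `pred` stands in for the
# 352x352 sigmoid map CLIPSeg produces for one prompt (values in [0, 1])
frame = cv2.imread("images/room.jpg")
pred = np.random.rand(352, 352)

# threshold, resize to the original frame, and match dtype before blending
binary = np.where(pred > 0.4, 255, 0).astype(np.uint8)
binary = cv2.resize(binary, (frame.shape[1], frame.shape[0]),
                    interpolation=cv2.INTER_NEAREST)

overlay = create_mask(frame, binary, alpha=0.7)  # semi-transparent white overlay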
@@ -30,7 +39,10 @@ def detect_using_clip(image,prompts=[],threshould=0.4):
     with torch.no_grad(): # Use 'torch.no_grad()' to disable gradient computation
         outputs = model(**inputs)
     preds = outputs.logits.unsqueeze(1)
+
+    # tensor_images = [torch.sigmoid(preds[i][0]) for i in range(len(prompts))]
     detection = outputs.logits[0] # Assuming class index 0
+
     for i,prompt in enumerate(prompts):
         predicted_image = torch.sigmoid(preds[i][0]).detach().cpu().numpy()
         predicted_image = np.where(predicted_image>threshould,255,0)
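For context, the `inputs` consumed here would come from a CLIPSegProcessor call earlier in detect_using_clip, which this diff does not show. A hedged sketch of that step, following the standard transformers usage for CLIPSeg (the prompt list and file path are assumptions):

from PIL import Image
import torch
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

image = Image.open("images/room.jpg")      # any RGB image
prompts = ["bed", "table", "plant"]
# one copy of the image per prompt, so the model returns one logit map per prompt
inputs = processor(text=prompts, images=[image] * len(prompts),
                   padding=True, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
preds = outputs.logits.unsqueeze(1)        # (num_prompts, 1, 352, 352) for this checkpoint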
@@ -39,7 +51,7 @@ def detect_using_clip(image,prompts=[],threshould=0.4):
         props = regionprops(lbl_0)
         model_detections[prompt] = [rescale_bbox(prop.bbox,orig_image_shape=image.shape[:2],model_shape=predicted_image.shape[0]) for prop in props]
 
-    return model_detections
+    return model_detections
 
 def visualize_images(image,detections,prompt):
     H,W = image.shape[:2]
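The detections dictionary maps each prompt to boxes from skimage's regionprops, whose bbox is (min_row, min_col, max_row, max_col) in the 352x352 model grid; rescale_bbox then divides by model_shape and scales rows by the original height (and, presumably, columns by the original width in the lines not shown). A short sketch of that path, assuming lbl_0 comes from skimage.measure.label on the thresholded mask:

import numpy as np
from skimage.measure import label, regionprops

pred = np.random.rand(352, 352)                  # stand-in for torch.sigmoid(preds[i][0]).cpu().numpy()
predicted_image = np.where(pred > 0.4, 255, 0)   # binary mask at the 0.4 threshold
lbl_0 = label(predicted_image)                   # connected components (assumed source of lbl_0)
props = regionprops(lbl_0)

for prop in props:
    # prop.bbox is (min_row, min_col, max_row, max_col) in model coordinates;
    # rescale_bbox maps it back to the original image size, here assumed 768x1024
    box = rescale_bbox(prop.bbox, orig_image_shape=(768, 1024), model_shape=352)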
@@ -49,26 +61,20 @@ def visualize_images(image,detections,prompt):
         return image_copy
     for bbox in detections[prompt]:
         cv2.rectangle(image_copy, (int(bbox[1]), int(bbox[0])), (int(bbox[3]), int(bbox[2])), (255, 0, 0), 2)
+        cv2.putText(image_copy,(int(bbox[1]), int(bbox[0])),cv2.FONT_HERSHEY_SIMPLEX, 2, 255)
     return image_copy
 
 
 def shot(image, labels_text,selected_categoty):
-    print("Labels Text ",labels_text)
     prompts = labels_text.split(',')
-
-
-
-    print("Image shape ",image.shape )
-
-    model_detections = detect_using_clip(image,prompts=prompts)
-    print("detections :",model_detections)
-    print("Ctegory ",selected_categoty)
-    return visualize_images(image=image,detections=model_detections,prompt=selected_categoty)
+    model_detections = detect_using_clip(image,prompts=prompts)
+    category_image = visualize_images(image=image,detections=model_detections,prompt=selected_categoty)
+    return category_image
 
 iface = gr.Interface(fn=shot,
                     inputs = ["image","text","text"],
                     outputs="image",
-                    description="Add
+                    description="Add an Image and list of category to be detected separated by commas",
                     title="Zero-shot Image Classification with Prompt ",
                     examples=[["images/room.jpg","bed,table,plant",'plant']],
                     # allow_flagging=False,
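One caution on the added cv2.putText call: OpenCV's signature is putText(img, text, org, fontFace, fontScale, color[, thickness]), so the call in this commit omits the text argument and would raise an error at runtime. A hedged corrected version of the loop inside visualize_images, labelling each box with its prompt (the offset and thickness are illustrative choices):

for bbox in detections[prompt]:
    cv2.rectangle(image_copy, (int(bbox[1]), int(bbox[0])),
                  (int(bbox[3]), int(bbox[2])), (255, 0, 0), 2)
    # draw the prompt text just above the top-left corner of the box
    cv2.putText(image_copy, prompt, (int(bbox[1]), int(bbox[0]) - 5),
                cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)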
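For reference, the bundled example roughly corresponds to the direct call below, assuming images/room.jpg exists and that the Gradio "image" input hands shot a numpy array (colour-channel order may differ from cv2.imread's BGR):

import cv2

image = cv2.imread("images/room.jpg")             # numpy array, similar to what Gradio passes in
result = shot(image, "bed,table,plant", "plant")  # boxes drawn only for the 'plant' prompt
cv2.imwrite("room_plant.jpg", result)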