kushagra124 committed
Commit
aae2aac
1 Parent(s): fa1485a

adding application for CLIP model detection

Files changed (1)
  1. app.py +104 -0
app.py ADDED
@@ -0,0 +1,104 @@
+ import os
+ import gradio as gr
+ import numpy as np
+ import torch
+ import cv2
+ from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
+ from skimage.measure import label, regionprops
+
+ # load the pretrained CLIPSeg processor and segmentation model
+ processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
+ model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
+
+ # collect up to ten example images from the random/images/ directory
+ random_images = []
+ images_dir = 'random/images/'
+ for idx, images in enumerate(os.listdir(images_dir)):
+     image = os.path.join(images_dir, images)
+     if os.path.isfile(image) and idx < 10:
+         random_images.append(image)
+
+
+
+ def rescale_bbox(bbox, orig_image_shape=(1024, 1024), model_shape=352):
+     # map a bounding box from the model's output resolution back to the original image size
+     bbox = np.asarray(bbox) / model_shape
+     y1, y2 = bbox[::2] * orig_image_shape[0]
+     x1, x2 = bbox[1::2] * orig_image_shape[1]
+     return [int(y1), int(x1), int(y2), int(x2)]
+
+ def detect_using_clip(image, prompts=[], threshold=0.4):
+     model_detections = dict()
+     inputs = processor(
+         text=prompts,
+         images=[image] * len(prompts),
+         padding="max_length",
+         return_tensors="pt",
+     )
+     with torch.no_grad():  # disable gradient computation for inference
+         outputs = model(**inputs)
+     preds = outputs.logits.unsqueeze(1)
+     for i, prompt in enumerate(prompts):
+         predicted_image = torch.sigmoid(preds[i][0]).detach().cpu().numpy()
+         predicted_image = np.where(predicted_image > threshold, 255, 0)
+         # label connected components and collect their bounding boxes
+         lbl_0 = label(predicted_image)
+         props = regionprops(lbl_0)
+         model_detections[prompt] = [rescale_bbox(prop.bbox, orig_image_shape=image.shape[:2], model_shape=predicted_image.shape[0]) for prop in props]
+
+     return model_detections
+
+ def display_images(image, detections, prompt='traffic light'):
+     # draw the bounding boxes predicted for a single prompt on a copy of the image
+     image_copy = image.copy()
+     if prompt not in detections.keys():
+         print("prompt not in query ..")
+         return image_copy
+     for bbox in detections[prompt]:
+         cv2.rectangle(image_copy, (int(bbox[1]), int(bbox[0])), (int(bbox[3]), int(bbox[2])), (255, 0, 0), 2)
+     return image_copy
+
+
+ def shot(image, labels_text):
+     # parse the comma-separated prompts, run detection, and return the annotated image
+     prompts = [prompt.strip() for prompt in labels_text.split(',') if prompt.strip()]
+     detections = detect_using_clip(image, prompts=prompts)
+     annotated_image = image.copy()
+     for prompt in prompts:
+         annotated_image = display_images(annotated_image, detections, prompt=prompt)
+     return annotated_image
+
+ # Gradio UI
+
+ output = gr.Image(type="numpy", label="Detected Objects with given Category")
+
+ title = "Object Detection Using Prompts with Zero Shot CLIP Model"
+
+ with gr.Blocks(title="Zero Shot Object Detection using Text Prompts") as demo:
+     gr.Markdown(
+         """
+         <center>
+         <h1>
+         The CLIP Model
+         </h1>
+         CLIP is a neural network that efficiently learns visual concepts from natural language supervision. It can be applied to any visual classification benchmark by simply providing the names of the visual categories to be recognized, similar to the “zero-shot” capabilities of GPT-2 and GPT-3.
+         </center>
+         """
+     )
+
+     with gr.Row():
+         with gr.Column():
+             inputt = gr.Image(type="numpy", label="Input Image for Classification")
+             labels = gr.Textbox(placeholder="Enter label(s), e.g. cat,car,door,window", scale=4)
+             button = gr.Button(value="Locate objects")
+         with gr.Column():
+             outputs = gr.Image(type="numpy", label="Detected Objects with Selected Category")
+     button.click(shot, inputs=[inputt, labels], outputs=outputs)
+
+
+ demo.launch()
+ # iface = gr.Interface(fn=shot,
+ #                      inputs = ["image","text","label"],
+ #                      outputs=output,
+ #                      examples=random_images,
+ #                      allow_flagging=False,
+ #                      analytics_enabled=False,
+ #                      )
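
As a quick local check, the helper functions above can be exercised without the Gradio UI. The lines below are a minimal sketch, assuming they are run at the bottom of app.py in place of demo.launch() while debugging; the image path random/images/sample.jpg is hypothetical, and any RGB image on disk would do.

# hypothetical local check, run in place of demo.launch()
test_image = cv2.cvtColor(cv2.imread("random/images/sample.jpg"), cv2.COLOR_BGR2RGB)
boxes = detect_using_clip(test_image, prompts=["car", "person"])
print(boxes)  # e.g. {'car': [[y1, x1, y2, x2], ...], 'person': [...]}
annotated = display_images(test_image, boxes, prompt="car")
cv2.imwrite("annotated.png", cv2.cvtColor(annotated, cv2.COLOR_RGB2BGR))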