IbrahimHasani committed on
Commit
06499c0
1 Parent(s): 652137b

Create app.py

Files changed (1): app.py (+165, -0)
app.py ADDED
@@ -0,0 +1,165 @@
+ import gradio as gr
+ import torch
+ import numpy as np
+ from transformers import OwlViTProcessor, OwlViTForObjectDetection, ResNetModel
+ from torchvision import transforms
+ from PIL import Image
+ import cv2
+ import torch.nn.functional as F
+ import tempfile
+
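+ # Pipeline overview: OWL-ViT image-guided detection proposes boxes for the query
+ # object in sampled video frames; ResNet-50 embeddings of the crops are then
+ # compared against the query image with cosine similarity and L2 distance.
+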
+ # Load models
+ resnet = ResNetModel.from_pretrained("microsoft/resnet-50")
+ resnet.eval()
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ resnet = resnet.to(device)
+
+ model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32").to(device)
+ model.eval()
+ processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
+
+ # Preprocess an image to the 224x224 ImageNet format expected by ResNet-50
+ def preprocess_image(image):
+     transform = transforms.Compose([
+         transforms.Resize((224, 224)),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+     ])
+     return transform(image).unsqueeze(0)  # add a batch dimension
+
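+ # Embed an image with ResNet-50; the pooled output serves as a global image descriptor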
+ def extract_embedding(image):
+     image_tensor = preprocess_image(image).to(device)
+     with torch.no_grad():
+         output = resnet(image_tensor)
+     embedding = output.pooler_output
+     return embedding
+
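+ # Similarity metrics between two embeddings: a higher cosine similarity and a
+ # lower L2 distance both indicate a closer match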
+ def cosine_similarity(embedding1, embedding2):
+     return F.cosine_similarity(embedding1, embedding2)
+
+ def l2_distance(embedding1, embedding2):
+     return torch.norm(embedding1 - embedding2, p=2)
+
+ def save_array_to_temp_image(arr):
+     # OpenCV arrays are BGR; convert to RGB before handing the array to PIL
+     rgb_arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
+     img = Image.fromarray(rgb_arr)
+     temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
+     temp_file_name = temp_file.name
+     temp_file.close()
+     img.save(temp_file_name)
+     return temp_file_name
+
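+ # Image-guided detection: the query image (rather than text) conditions OWL-ViT,
+ # and regions of the target image that match it are returned as boxes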
+ def detect_and_crop(target_image, query_image, threshold=0.6, nms_threshold=0.3):
+     target_sizes = torch.Tensor([target_image.size[::-1]])
+     inputs = processor(images=target_image, query_images=query_image, return_tensors="pt").to(device)
+     with torch.no_grad():
+         outputs = model.image_guided_detection(**inputs)
+
+     # PIL images are RGB; convert to BGR so the crops match the OpenCV
+     # convention expected by save_array_to_temp_image
+     img = cv2.cvtColor(np.array(target_image), cv2.COLOR_RGB2BGR)
+     outputs.logits = outputs.logits.cpu()
+     outputs.target_pred_boxes = outputs.target_pred_boxes.cpu()
+
+     results = processor.post_process_image_guided_detection(outputs=outputs, threshold=threshold, nms_threshold=nms_threshold, target_sizes=target_sizes)
+     boxes, scores = results[0]["boxes"], results[0]["scores"]
+
+     if len(boxes) == 0:
+         return []
+
+     filtered_boxes = []
+     for box in boxes:
+         x1, y1, x2, y2 = [int(i) for i in box.tolist()]
+         cropped_img = img[y1:y2, x1:x2]
+         if cropped_img.size != 0:  # skip degenerate boxes
+             filtered_boxes.append(cropped_img)
+
+     return filtered_boxes
+
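+ # Walk the video, processing every (skipframes + 1)-th frame, and score each
+ # detected crop against the query image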
+ def process_video(video_path, query_image, skipframes=0):
+     cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened():
+         return []
+
+     # The query embedding is the same for every frame; compute it once
+     query_embedding = extract_embedding(query_image)
+     frame_count = 0
+     all_results = []
+     while True:
+         ret, frame = cap.read()
+         if not ret:
+             break
+         if frame_count % (skipframes + 1) == 0:
+             frame_file = save_array_to_temp_image(frame)
+             result_frames = detect_and_crop(Image.open(frame_file), query_image)
+             for res in result_frames:
+                 saved_res = save_array_to_temp_image(res)
+                 crop_embedding = extract_embedding(Image.open(saved_res))
+                 dist = l2_distance(query_embedding, crop_embedding).item()
+                 cos = cosine_similarity(query_embedding, crop_embedding).item()
+                 all_results.append({'l2_dist': dist, 'cos': cos})
+         frame_count += 1
+     cap.release()
+     return all_results
+
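+ # Aggregate the per-crop scores over the whole video; the object counts as
+ # present when the average cosine similarity clears the threshold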
+ def process_videos_and_compare(image, video, skipframes=5, threshold=0.47):
+     def median(values):
+         n = len(values)
+         return (values[n // 2 - 1] + values[n // 2]) / 2 if n % 2 == 0 else values[n // 2]
+
+     results = process_video(video, image, skipframes)
+     if results:
+         l2_dists = [item['l2_dist'] for item in results]
+         cosines = [item['cos'] for item in results]
+         avg_l2_dist = sum(l2_dists) / len(l2_dists)
+         avg_cos = sum(cosines) / len(cosines)
+         median_l2_dist = median(sorted(l2_dists))
+         median_cos = median(sorted(cosines))
+         result = {
+             "avg_l2_dist": avg_l2_dist,
+             "avg_cos": avg_cos,
+             "median_l2_dist": median_l2_dist,
+             "median_cos": median_cos,
+             "avg_cos_dist": 1 - avg_cos,
+             "median_cos_dist": 1 - median_cos,
+             "is_present": avg_cos >= threshold
+         }
+     else:
+         result = {
+             "avg_l2_dist": float('inf'),
+             "avg_cos": 0,
+             "median_l2_dist": float('inf'),
+             "median_cos": 0,
+             "avg_cos_dist": float('inf'),
+             "median_cos_dist": float('inf'),
+             "is_present": False
+         }
+     return result
+
+ def interface(video, image, skipframes, threshold):
+     return process_videos_and_compare(image, video, skipframes, threshold)
+
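+ # Gradio UI: wires the video, query image, and tuning sliders to the pipeline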
+ iface = gr.Interface(
+     fn=interface,
+     inputs=[
+         gr.Video(label="Upload a Video"),
+         gr.Image(type="pil", label="Upload a Query Image"),
+         gr.Slider(minimum=0, maximum=10, step=1, value=5, label="Skip Frames"),
+         gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.47, label="Threshold")
+     ],
+     outputs=[
+         gr.JSON(label="Result")
+     ],
+     title="Object Detection in Video",
+     description="""
+ **Instructions:**
+
+ 1. **Upload a Video**: Select a video file to upload.
+ 2. **Upload a Query Image**: Select an image that contains the object you want to detect in the video.
+ 3. **Set Skip Frames**: Adjust the slider to set how many frames to skip between processed frames.
+ 4. **Set Threshold**: Adjust the slider to set the cosine-similarity threshold for deciding whether the object is present.
+ 5. **View Results**: The result shows the average and median distances and similarities, and whether the object is present in the video based on the threshold.
+ """
+ )
+
+ if __name__ == "__main__":
+     iface.launch()