Hbvsa committed on
Commit d9cb2e6
1 Parent(s): d8bad69

Update app.py

Files changed (1)
  1. app.py +119 -119
app.py CHANGED
@@ -96,133 +96,133 @@ class ImageClassifier:
         logit, predicted = torch.max(output.data, 1)
         return self.labels[predicted[0].item()], logit[0].item()
 
-class VideoObjectDetection:
-
-    def __init__(self,
-                 text_prompt: str
-                 ):
-
-        self.text_prompt = text_prompt
-
-    def crop(self, frame, boxes):
-
-        h, w, _ = frame.shape
-        boxes = boxes * torch.Tensor([w, h, w, h])
-        xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
-        min_col, min_row, max_col, max_row = map(int, xyxy[0])
-        crop_image = frame[min_row:max_row, min_col:max_col, :]
-
-        return crop_image
-
-    def annotate(self,
-                 image_source: np.ndarray,
-                 boxes: torch.Tensor,
-                 logits: torch.Tensor,
-                 phrases: List[str],
-                 frame_rgb: np.ndarray,
-                 classifier) -> np.ndarray:
-
-        h, w, _ = image_source.shape
-        boxes = boxes * torch.Tensor([w, h, w, h])
-        xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
-        detections = sv.Detections(xyxy=xyxy)
-        print(xyxy.shape)
-        custom_labels = []
-        custom_logits = []
-
-        for box in xyxy:
-
-            min_col, min_row, max_col, max_row = map(int, box)
-            crop_image = frame_rgb[min_row:max_row, min_col:max_col, :]
-            label, logit = classifier.predict(crop_image)
-            print()
-            if logit >= 1:
-                custom_labels.append(label)
-                custom_logits.append(logit)
-            else:
-                custom_labels.append('unknown human face')
-                custom_logits.append(logit)
-
-        labels = [
-            f"{phrase} {logit:.2f}"
-            for phrase, logit
-            in zip(custom_labels, custom_logits)
-        ]
-
-        box_annotator = sv.BoxAnnotator()
-        annotated_frame = box_annotator.annotate(scene=image_source, detections=detections, labels=labels)
-        return annotated_frame
-
-    def preprocess_image(self, image: np.ndarray) -> torch.Tensor:
-        transform = T.Compose(
-            [
-                T.RandomResize([800], max_size=1333),
-                T.ToTensor(),
-                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
-            ]
-        )
-
-        image_pillow = Image.fromarray(image)
-        image_transformed, _ = transform(image_pillow, None)
-        return image_transformed
-
-
-    def generate_video(self, video_path) -> None:
-
-        # Load model, set up variables and get video properties
-        cap, fps, width, height, fourcc = get_video_properties(video_path)
-        model = load_model("GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
-                           "GroundingDINO/weights/groundingdino_swint_ogc.pth")
-        predictor = ImageClassifier()
-        TEXT_PROMPT = self.text_prompt
-        BOX_TRESHOLD = 0.6
-        TEXT_TRESHOLD = 0.6
-
-        # Read video frames, crop image based on text prompt object detection and generate dataset_train
-        import time
-        frame_count = 0
-        delay = 1 / fps  # Delay in seconds between frames
-        while cap.isOpened():
-            start_time = time.time()
-            ret, frame = cap.read()
-            if not ret:
-                break
-
-            if cv2.waitKey(1) & 0xff == ord('q'):
-                break
-
-            # Convert bgr frame to rgb frame to image to torch tensor transformed
-            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-            image_transformed = self.preprocess_image(frame_rgb)
-
-            boxes, logits, phrases = predict(
-                model=model,
-                image=image_transformed,
-                caption=TEXT_PROMPT,
-                box_threshold=BOX_TRESHOLD,
-                text_threshold=TEXT_TRESHOLD
-            )
-
-            # Get boxes
-            if boxes.size()[0] > 0:
-                annotated_frame = self.annotate(image_source=frame, boxes=boxes, logits=logits,
-                                                phrases=phrases, frame_rgb=frame_rgb, classifier=predictor)
-                # cv2.imshow('Object detection', annotated_frame)
-                frame_rgb = cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
-
-                yield frame_rgb
-            elapsed_time = time.time() - start_time
-            time_to_wait = max(delay - elapsed_time, 0)
-            time.sleep(time_to_wait)
-
-            frame_count += 1
 
 @spaces.GPU()
 def pipeline_to_setup_with_gpu():
     run_commands()
     from GroundingDINO.groundingdino.util.inference import load_model, predict
     import GroundingDINO.groundingdino.datasets.transforms as T
+class VideoObjectDetection:
+
+    def __init__(self,
+                 text_prompt: str
+                 ):
+
+        self.text_prompt = text_prompt
+
+    def crop(self, frame, boxes):
+
+        h, w, _ = frame.shape
+        boxes = boxes * torch.Tensor([w, h, w, h])
+        xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
+        min_col, min_row, max_col, max_row = map(int, xyxy[0])
+        crop_image = frame[min_row:max_row, min_col:max_col, :]
+
+        return crop_image
+
+    def annotate(self,
+                 image_source: np.ndarray,
+                 boxes: torch.Tensor,
+                 logits: torch.Tensor,
+                 phrases: List[str],
+                 frame_rgb: np.ndarray,
+                 classifier) -> np.ndarray:
+
+        h, w, _ = image_source.shape
+        boxes = boxes * torch.Tensor([w, h, w, h])
+        xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
+        detections = sv.Detections(xyxy=xyxy)
+        print(xyxy.shape)
+        custom_labels = []
+        custom_logits = []
+
+        for box in xyxy:
+
+            min_col, min_row, max_col, max_row = map(int, box)
+            crop_image = frame_rgb[min_row:max_row, min_col:max_col, :]
+            label, logit = classifier.predict(crop_image)
+            print()
+            if logit >= 1:
+                custom_labels.append(label)
+                custom_logits.append(logit)
+            else:
+                custom_labels.append('unknown human face')
+                custom_logits.append(logit)
+
+        labels = [
+            f"{phrase} {logit:.2f}"
+            for phrase, logit
+            in zip(custom_labels, custom_logits)
+        ]
+
+        box_annotator = sv.BoxAnnotator()
+        annotated_frame = box_annotator.annotate(scene=image_source, detections=detections, labels=labels)
+        return annotated_frame
+
+    def preprocess_image(self, image: np.ndarray) -> torch.Tensor:
+        transform = T.Compose(
+            [
+                T.RandomResize([800], max_size=1333),
+                T.ToTensor(),
+                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+            ]
+        )
+
+        image_pillow = Image.fromarray(image)
+        image_transformed, _ = transform(image_pillow, None)
+        return image_transformed
+
+
+    def generate_video(self, video_path) -> None:
+
+        # Load model, set up variables and get video properties
+        cap, fps, width, height, fourcc = get_video_properties(video_path)
+        model = load_model("GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
+                           "GroundingDINO/weights/groundingdino_swint_ogc.pth")
+        predictor = ImageClassifier()
+        TEXT_PROMPT = self.text_prompt
+        BOX_TRESHOLD = 0.6
+        TEXT_TRESHOLD = 0.6
+
+        # Read video frames, crop image based on text prompt object detection and generate dataset_train
+        import time
+        frame_count = 0
+        delay = 1 / fps  # Delay in seconds between frames
+        while cap.isOpened():
+            start_time = time.time()
+            ret, frame = cap.read()
+            if not ret:
+                break
+
+            if cv2.waitKey(1) & 0xff == ord('q'):
+                break
+
+            # Convert bgr frame to rgb frame to image to torch tensor transformed
+            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            image_transformed = self.preprocess_image(frame_rgb)
+
+            boxes, logits, phrases = predict(
+                model=model,
+                image=image_transformed,
+                caption=TEXT_PROMPT,
+                box_threshold=BOX_TRESHOLD,
+                text_threshold=TEXT_TRESHOLD
+            )
+
+            # Get boxes
+            if boxes.size()[0] > 0:
+                annotated_frame = self.annotate(image_source=frame, boxes=boxes, logits=logits,
+                                                phrases=phrases, frame_rgb=frame_rgb, classifier=predictor)
+                # cv2.imshow('Object detection', annotated_frame)
+                frame_rgb = cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
+
+                yield frame_rgb
+            elapsed_time = time.time() - start_time
+            time_to_wait = max(delay - elapsed_time, 0)
+            time.sleep(time_to_wait)
+
+            frame_count += 1
 video_annotator = VideoObjectDetection(
     text_prompt='human face')
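
For context, and not part of this commit: generate_video is a generator that yields annotated RGB frames, so elsewhere in app.py it is presumably streamed into the Space's UI. Below is a minimal sketch of that wiring, assuming Gradio; the stream_annotated_frames wrapper and the component choices are illustrative assumptions, not code from this repository.

import gradio as gr

def stream_annotated_frames(video_path):
    # Assumed wrapper: forward each annotated RGB frame (numpy array)
    # yielded by the detection/classification pipeline to the UI as it arrives.
    for frame_rgb in video_annotator.generate_video(video_path):
        yield frame_rgb

# Gradio treats a generator function as a streaming output:
# each yielded frame replaces the previous Image output.
demo = gr.Interface(
    fn=stream_annotated_frames,
    inputs=gr.Video(label="Input video"),      # provides a file path
    outputs=gr.Image(label="Annotated frame"),
)

if __name__ == "__main__":
    demo.launch()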