Hbvsa committed on
Commit 1e60ef6
1 Parent(s): a884c06

Update app.py

Files changed (1)
  1. app.py +194 -196
app.py CHANGED
@@ -2,6 +2,7 @@ import subprocess
 import sys
 from os.path import abspath, dirname,join
 sys.path.append(join(dirname(abspath(__file__)),'GroundingDINO'))
+
 def run_commands():
     commands = [
         "git clone https://github.com/IDEA-Research/GroundingDINO.git",
@@ -20,204 +21,199 @@ def run_commands():
         except subprocess.CalledProcessError as e:
             print(f"Command '{command}' failed with error: {e.stderr.decode()}")

-# Call the function to run the commands
-
-if __name__ == "__main__":
-
-class DinoVisionTransformerClassifier(nn.Module):
-    def __init__(self):
-        super(DinoVisionTransformerClassifier, self).__init__()
-        self.transformer = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")
-        self.classifier = nn.Sequential(nn.Linear(384, 256), nn.ReLU(), nn.Linear(256, 2))
-
-    def forward(self, x):
-        x = self.transformer(x)
-        x = self.transformer.norm(x)
-        x = self.classifier(x)
-        return x
-
-
-class ImageClassifier:
-
-    def __init__(self):
-        with open(f"{dirname(abspath(__file__))}/config.yaml", 'r') as f:
-            config = yaml.load(f, Loader=yaml.FullLoader)
-            labels = config["labels"]
-
-        self.labels = labels
-        self.dino = DinoVisionTransformerClassifier()
-        model_path = f"{dirname(abspath(__file__))}/model.pth"
-        state_dict = torch.load(model_path)
-        self.dino.load_state_dict(state_dict)
-
-    def preprocess(self, image: np.ndarray) -> torch.Tensor:
-        data_transforms = {
-            "test": transforms.Compose(
+@spaces.GPU(30)
+def video_app_setup_and_run_pipeline():
+
+    run_commands()
+
+    class DinoVisionTransformerClassifier(nn.Module):
+        def __init__(self):
+            super(DinoVisionTransformerClassifier, self).__init__()
+            self.transformer = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")
+            self.classifier = nn.Sequential(nn.Linear(384, 256), nn.ReLU(), nn.Linear(256, 2))
+
+        def forward(self, x):
+            x = self.transformer(x)
+            x = self.transformer.norm(x)
+            x = self.classifier(x)
+            return x
+
+    class ImageClassifier:
+
+        def __init__(self):
+            with open(f"{dirname(abspath(__file__))}/config.yaml", 'r') as f:
+                config = yaml.load(f, Loader=yaml.FullLoader)
+                labels = config["labels"]
+
+            self.labels = labels
+            self.dino = DinoVisionTransformerClassifier()
+            model_path = f"{dirname(abspath(__file__))}/model.pth"
+            state_dict = torch.load(model_path)
+            self.dino.load_state_dict(state_dict)
+
+        def preprocess(self, image: np.ndarray) -> torch.Tensor:
+            data_transforms = {
+                "test": transforms.Compose(
+                    [
+                        transforms.Resize((224, 224)),
+                        transforms.ToTensor(),
+                        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
+                    ]
+                )
+            }
+            image_pillow = Image.fromarray(image)
+            img_transformed = data_transforms['test'](image_pillow)
+
+            return img_transformed
+
+        def predict(self, image):
+            image = self.preprocess(image)
+            image = image.unsqueeze(0)
+            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+            self.dino.to(device)
+            self.dino.eval()
+            with torch.no_grad():
+                output = self.dino(image.to(device))
+
+            logit, predicted = torch.max(output.data, 1)
+            return self.labels[predicted[0].item()], logit[0].item()
+
+    class VideoObjectDetection:
+
+        def __init__(self,
+                     text_prompt: str
+                     ):
+
+            self.text_prompt = text_prompt
+
+        def crop(self, frame, boxes):
+
+            h, w, _ = frame.shape
+            boxes = boxes * torch.Tensor([w, h, w, h])
+            xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
+            min_col, min_row, max_col, max_row = map(int, xyxy[0])
+            crop_image = frame[min_row:max_row, min_col:max_col, :]
+
+            return crop_image
+
+        def annotate(self,
+                     image_source: np.ndarray,
+                     boxes: torch.Tensor,
+                     logits: torch.Tensor,
+                     phrases: List[str],
+                     frame_rgb: np.ndarray,
+                     classifier) -> np.ndarray:
+
+            h, w, _ = image_source.shape
+            boxes = boxes * torch.Tensor([w, h, w, h])
+            xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
+            detections = sv.Detections(xyxy=xyxy)
+            print(xyxy.shape)
+            custom_labels = []
+            custom_logits = []
+
+            for box in xyxy:
+                min_col, min_row, max_col, max_row = map(int, box)
+                crop_image = frame_rgb[min_row:max_row, min_col:max_col, :]
+                label, logit = classifier.predict(crop_image)
+                print()
+                if logit >= 1:
+                    custom_labels.append(label)
+                    custom_logits.append(logit)
+                else:
+                    custom_labels.append('unknown human face')
+                    custom_logits.append(logit)
+
+            labels = [
+                f"{phrase} {logit:.2f}"
+                for phrase, logit
+                in zip(custom_labels, custom_logits)
+            ]
+
+            box_annotator = sv.BoxAnnotator()
+            annotated_frame = box_annotator.annotate(scene=image_source, detections=detections, labels=labels)
+            return annotated_frame
+
+        def preprocess_image(self, image: np.ndarray) -> torch.Tensor:
+            transform = T.Compose(
                 [
-                    transforms.Resize((224, 224)),
-                    transforms.ToTensor(),
-                    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
+                    T.RandomResize([800], max_size=1333),
+                    T.ToTensor(),
+                    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                 ]
             )
-        }
-        image_pillow = Image.fromarray(image)
-        img_transformed = data_transforms['test'](image_pillow)
-
-        return img_transformed
-
-    def predict(self, image):
-        image = self.preprocess(image)
-        image = image.unsqueeze(0)
-        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-        self.dino.to(device)
-        self.dino.eval()
-        with torch.no_grad():
-            output = self.dino(image.to(device))
-
-        logit, predicted = torch.max(output.data, 1)
-        return self.labels[predicted[0].item()], logit[0].item()
-
-
-class VideoObjectDetection:
-
-    def __init__(self,
-                 text_prompt: str
-                 ):
-
-        self.text_prompt = text_prompt
-
-    def crop(self, frame, boxes):
-
-        h, w, _ = frame.shape
-        boxes = boxes * torch.Tensor([w, h, w, h])
-        xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
-        min_col, min_row, max_col, max_row = map(int, xyxy[0])
-        crop_image = frame[min_row:max_row, min_col:max_col, :]
-
-        return crop_image
-
-    def annotate(self,
-                 image_source: np.ndarray,
-                 boxes: torch.Tensor,
-                 logits: torch.Tensor,
-                 phrases: List[str],
-                 frame_rgb: np.ndarray,
-                 classifier) -> np.ndarray:
-
-        h, w, _ = image_source.shape
-        boxes = boxes * torch.Tensor([w, h, w, h])
-        xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
-        detections = sv.Detections(xyxy=xyxy)
-        print(xyxy.shape)
-        custom_labels = []
-        custom_logits = []
-
-        for box in xyxy:
-            min_col, min_row, max_col, max_row = map(int, box)
-            crop_image = frame_rgb[min_row:max_row, min_col:max_col, :]
-            label, logit = classifier.predict(crop_image)
-            print()
-            if logit >= 1:
-                custom_labels.append(label)
-                custom_logits.append(logit)
-            else:
-                custom_labels.append('unknown human face')
-                custom_logits.append(logit)
-
-        labels = [
-            f"{phrase} {logit:.2f}"
-            for phrase, logit
-            in zip(custom_labels, custom_logits)
-        ]
-
-        box_annotator = sv.BoxAnnotator()
-        annotated_frame = box_annotator.annotate(scene=image_source, detections=detections, labels=labels)
-        return annotated_frame
-
-    def preprocess_image(self, image: np.ndarray) -> torch.Tensor:
-        transform = T.Compose(
-            [
-                T.RandomResize([800], max_size=1333),
-                T.ToTensor(),
-                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
-            ]
-        )
-
-        image_pillow = Image.fromarray(image)
-        image_transformed, _ = transform(image_pillow, None)
-        return image_transformed
-
-    @spaces.GPU(duration=30)
-    def generate_video(self, video_path) -> None:
-        run_commands()
-
-        from typing import List
-        from Utils import get_video_properties
-        from GroundingDINO.groundingdino.util.inference import load_model, predict
-        import cv2
-        import numpy as np
-        import torch
-        from PIL import Image
-        import GroundingDINO.groundingdino.datasets.transforms as T
-        from torchvision.ops import box_convert
-        from torchvision import transforms
-        from torch import nn
-        from os.path import dirname, abspath
-        import yaml
-        import supervision as sv
-        import gradio as gr
-        import spaces
-
-        # Load model, set up variables and get video properties
-        cap, fps, width, height, fourcc = get_video_properties(video_path)
-        model = load_model("GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
-                           "GroundingDINO/weights/groundingdino_swint_ogc.pth")
-        predictor = ImageClassifier()
-        TEXT_PROMPT = self.text_prompt
-        BOX_TRESHOLD = 0.6
-        TEXT_TRESHOLD = 0.6
-
-        # Read video frames, crop image based on text prompt object detection and generate dataset_train
-        import time
-        frame_count = 0
-        delay = 1 / fps  # Delay in seconds between frames
-        while cap.isOpened():
-            start_time = time.time()
-            ret, frame = cap.read()
-            if not ret:
-                break
-
-            if cv2.waitKey(1) & 0xff == ord('q'):
-                break
-
-            # Convert bgr frame to rgb frame to image to torch tensor transformed
-            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-            image_transformed = self.preprocess_image(frame_rgb)
-
-            boxes, logits, phrases = predict(
-                model=model,
-                image=image_transformed,
-                caption=TEXT_PROMPT,
-                box_threshold=BOX_TRESHOLD,
-                text_threshold=TEXT_TRESHOLD
-            )
-
-            # Get boxes
-            if boxes.size()[0] > 0:
-                annotated_frame = self.annotate(image_source=frame, boxes=boxes, logits=logits,
-                                                phrases=phrases, frame_rgb=frame_rgb, classifier=predictor)
-                # cv2.imshow('Object detection', annotated_frame)
-                frame_rgb = cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
-
-            yield frame_rgb
-            elapsed_time = time.time() - start_time
-            time_to_wait = max(delay - elapsed_time, 0)
-            time.sleep(time_to_wait)
-
-            frame_count += 1
-
-
-def video_object_classification_pipeline():
+
+            image_pillow = Image.fromarray(image)
+            image_transformed, _ = transform(image_pillow, None)
+            return image_transformed
+
+        def generate_video(self, video_path) -> None:
+
+            # Load model, set up variables and get video properties
+            cap, fps, width, height, fourcc = get_video_properties(video_path)
+            model = load_model("GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
+                               "GroundingDINO/weights/groundingdino_swint_ogc.pth")
+            predictor = ImageClassifier()
+            TEXT_PROMPT = self.text_prompt
+            BOX_TRESHOLD = 0.6
+            TEXT_TRESHOLD = 0.6
+
+            # Read video frames, crop image based on text prompt object detection and generate dataset_train
+            import time
+            frame_count = 0
+            delay = 1 / fps  # Delay in seconds between frames
+            while cap.isOpened():
+                start_time = time.time()
+                ret, frame = cap.read()
+                if not ret:
+                    break
+
+                if cv2.waitKey(1) & 0xff == ord('q'):
+                    break
+
+                # Convert bgr frame to rgb frame to image to torch tensor transformed
+                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                image_transformed = self.preprocess_image(frame_rgb)
+
+                boxes, logits, phrases = predict(
+                    model=model,
+                    image=image_transformed,
+                    caption=TEXT_PROMPT,
+                    box_threshold=BOX_TRESHOLD,
+                    text_threshold=TEXT_TRESHOLD
+                )
+
+                # Get boxes
+                if boxes.size()[0] > 0:
+                    annotated_frame = self.annotate(image_source=frame, boxes=boxes, logits=logits,
+                                                    phrases=phrases, frame_rgb=frame_rgb, classifier=predictor)
+                    # cv2.imshow('Object detection', annotated_frame)
+                    frame_rgb = cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
+
+                yield frame_rgb
+                elapsed_time = time.time() - start_time
+                time_to_wait = max(delay - elapsed_time, 0)
+                time.sleep(time_to_wait)
+
+                frame_count += 1
+
+    from typing import List
+    from Utils import get_video_properties
+    from GroundingDINO.groundingdino.util.inference import load_model, predict
+    import cv2
+    import numpy as np
+    import torch
+    from PIL import Image
+    import GroundingDINO.groundingdino.datasets.transforms as T
+    from torchvision.ops import box_convert
+    from torchvision import transforms
+    from torch import nn
+    from os.path import dirname, abspath
+    import yaml
+    import supervision as sv
+    import gradio as gr
+    import spaces
+
     video_annotator = VideoObjectDetection(
         text_prompt='human face')
 
@@ -229,5 +225,7 @@ if __name__ == "__main__":
                            outputs=output_image)
 
     iface.launch(share=False, debug=True)
-
-video_object_classification_pipeline()
+
+if __name__ == "__main__":
+    video_app_setup_and_run_pipeline()
+
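The new entry point combines two runtime features that appear in the diff: Hugging Face ZeroGPU's spaces.GPU decorator (requesting a GPU only while the decorated call runs) and Gradio's support for generator functions that stream results to the UI via yield. Below is a minimal, hedged sketch of that pattern; the function and component names (stream_frames, the placeholder loop) are illustrative assumptions, not code from this commit, and the decorator is shown with the explicit duration keyword used in the previous revision of app.py.

import gradio as gr
import spaces


@spaces.GPU(duration=30)  # assumption: request a ZeroGPU slice for roughly the length of one call
def stream_frames(video_path):
    # Heavy setup (model loading, framework imports) would live here, inside the GPU context.
    # Yielding lets Gradio update the output component as each frame finishes processing.
    for i in range(3):  # placeholder loop standing in for per-frame video processing
        yield f"processed frame {i} of {video_path}"


iface = gr.Interface(fn=stream_frames, inputs=gr.Video(), outputs=gr.Textbox())

if __name__ == "__main__":
    iface.launch()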