Hbvsa committed on
Commit 1b4035e
1 Parent(s): 1ce3359

Update app.py

Files changed (1)
  app.py +203 -200
app.py CHANGED
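Both the old and the new layout of app.py build the same classifier: a DINOv2 ViT-S/14 backbone pulled from torch.hub with a two-layer head (384 is the ViT-S/14 embedding width, 2 matches the labels in config.yaml). For orientation, a minimal standalone sketch of that pattern; the class name and the random input below are illustrative only, and torch.hub needs network access on first load:

import torch
from torch import nn

class DinoClassifier(nn.Module):
    def __init__(self, num_classes: int = 2):
        super().__init__()
        # ViT-S/14 backbone; its forward() returns a (batch, 384) embedding
        self.backbone = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")
        self.head = nn.Sequential(nn.Linear(384, 256), nn.ReLU(), nn.Linear(256, num_classes))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.backbone(x))

logits = DinoClassifier()(torch.randn(1, 3, 224, 224))  # shape (1, 2)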
@@ -2,9 +2,7 @@ import subprocess
 import sys
 from os.path import abspath, dirname,join
 import spaces
-
-sys.path.append(join(dirname(abspath(__file__)),'GroundingDINO'))
-
+@spaces.GPU()
 def run_commands():
     commands = [
         "git clone https://github.com/IDEA-Research/GroundingDINO.git",
@@ -22,211 +20,216 @@ def run_commands():
             print(result.stdout.decode())
         except subprocess.CalledProcessError as e:
             print(f"Command '{command}' failed with error: {e.stderr.decode()}")
-
-@spaces.GPU(30)
-def video_app_setup_and_run_pipeline():
-
-    #Setup needs gpu available to install properly
-    run_commands()
-    from typing import List
-    from Utils import get_video_properties
-    from GroundingDINO.groundingdino.util.inference import load_model, predict
-    import cv2
-    import numpy as np
-    import torch
-    from PIL import Image
-    import GroundingDINO.groundingdino.datasets.transforms as T
-    from torchvision.ops import box_convert
-    from torchvision import transforms
-    from torch import nn
-    from os.path import dirname, abspath
-    import yaml
-    import supervision as sv
-    import gradio as gr
-
-    class DinoVisionTransformerClassifier(nn.Module):
-        def __init__(self):
-            super(DinoVisionTransformerClassifier, self).__init__()
-            self.transformer = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")
-            self.classifier = nn.Sequential(nn.Linear(384, 256), nn.ReLU(), nn.Linear(256, 2))
-
-        def forward(self, x):
-            x = self.transformer(x)
-            x = self.transformer.norm(x)
-            x = self.classifier(x)
-            return x
-
-    class ImageClassifier:
-
-        def __init__(self):
-            with open(f"{dirname(abspath(__file__))}/config.yaml", 'r') as f:
-                config = yaml.load(f, Loader=yaml.FullLoader)
-                labels = config["labels"]
-
-            self.labels = labels
-            self.dino = DinoVisionTransformerClassifier()
-            model_path = f"{dirname(abspath(__file__))}/model.pth"
-            state_dict = torch.load(model_path)
-            self.dino.load_state_dict(state_dict)
-
-        def preprocess(self, image: np.ndarray) -> torch.Tensor:
-            data_transforms = {
-                "test": transforms.Compose(
-                    [
-                        transforms.Resize((224, 224)),
-                        transforms.ToTensor(),
-                        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
-                    ]
-                )
-            }
-            image_pillow = Image.fromarray(image)
-            img_transformed = data_transforms['test'](image_pillow)
-
-            return img_transformed
-
-        def predict(self, image):
-            image = self.preprocess(image)
-            image = image.unsqueeze(0)
-            device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-            self.dino.to(device)
-            self.dino.eval()
-            with torch.no_grad():
-                output = self.dino(image.to(device))
-
-            logit, predicted = torch.max(output.data, 1)
-            return self.labels[predicted[0].item()], logit[0].item()
-
-    class VideoObjectDetection:
-
-        def __init__(self,
-                     text_prompt: str
-                     ):
-
-            self.text_prompt = text_prompt
-
-        def crop(self, frame, boxes):
-
-            h, w, _ = frame.shape
-            boxes = boxes * torch.Tensor([w, h, w, h])
-            xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
-            min_col, min_row, max_col, max_row = map(int, xyxy[0])
-            crop_image = frame[min_row:max_row, min_col:max_col, :]
-
-            return crop_image
-
-        def annotate(self,
-                     image_source: np.ndarray,
-                     boxes: torch.Tensor,
-                     logits: torch.Tensor,
-                     phrases: List[str],
-                     frame_rgb: np.ndarray,
-                     classifier) -> np.ndarray:
-
-            h, w, _ = image_source.shape
-            boxes = boxes * torch.Tensor([w, h, w, h])
-            xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
-            detections = sv.Detections(xyxy=xyxy)
-            print(xyxy.shape)
-            custom_labels = []
-            custom_logits = []
-
-            for box in xyxy:
-                min_col, min_row, max_col, max_row = map(int, box)
-                crop_image = frame_rgb[min_row:max_row, min_col:max_col, :]
-                label, logit = classifier.predict(crop_image)
-                print()
-                if logit >= 1:
-                    custom_labels.append(label)
-                    custom_logits.append(logit)
-                else:
-                    custom_labels.append('unknown human face')
-                    custom_logits.append(logit)
-
-            labels = [
-                f"{phrase} {logit:.2f}"
-                for phrase, logit
-                in zip(custom_labels, custom_logits)
+
+run_commands()
+
+from typing import List
+from Utils import get_video_properties
+from GroundingDINO.groundingdino.util.inference import load_model, predict
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+import GroundingDINO.groundingdino.datasets.transforms as T
+from torchvision.ops import box_convert
+from torchvision import transforms
+from torch import nn
+from os.path import dirname, abspath
+import yaml
+import supervision as sv
+import gradio as gr
+sys.path.append(join(dirname(abspath(__file__)),'GroundingDINO'))
+
+class DinoVisionTransformerClassifier(nn.Module):
+    def __init__(self):
+
+        super(DinoVisionTransformerClassifier, self).__init__()
+        self.transformer = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")
+        self.classifier = nn.Sequential(nn.Linear(384, 256), nn.ReLU(), nn.Linear(256, 2))
+
+    def forward(self, x):
+
+        x = self.transformer(x)
+        x = self.transformer.norm(x)
+        x = self.classifier(x)
+        return x
+
+class ImageClassifier:
+
+    def __init__(self):
+
+        with open(f"{dirname(abspath(__file__))}/config.yaml", 'r') as f:
+            config = yaml.load(f, Loader=yaml.FullLoader)
+            labels = config["labels"]
+
+        self.labels = labels
+        self.dino = DinoVisionTransformerClassifier()
+        model_path = f"{dirname(abspath(__file__))}/model.pth"
+        state_dict = torch.load(model_path)
+        self.dino.load_state_dict(state_dict)
+
+    def preprocess(self, image: np.ndarray) -> torch.Tensor:
+
+        data_transforms = {
+            "test": transforms.Compose(
+                [
+                    transforms.Resize((224, 224)),
+                    transforms.ToTensor(),
+                    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
                 ]
-
-            box_annotator = sv.BoxAnnotator()
-            annotated_frame = box_annotator.annotate(scene=image_source, detections=detections, labels=labels)
-            return annotated_frame
-
-        def preprocess_image(self, image: np.ndarray) -> torch.Tensor:
-            transform = T.Compose(
-                [
-                    T.RandomResize([800], max_size=1333),
-                    T.ToTensor(),
-                    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
-                ]
-            )
-
-            image_pillow = Image.fromarray(image)
-            image_transformed, _ = transform(image_pillow, None)
-            return image_transformed
-
-        def generate_video(self, video_path) -> None:
-
-            # Load model, set up variables and get video properties
-            cap, fps, width, height, fourcc = get_video_properties(video_path)
-            model = load_model("GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
-                               "GroundingDINO/weights/groundingdino_swint_ogc.pth")
-            predictor = ImageClassifier()
-            TEXT_PROMPT = self.text_prompt
-            BOX_TRESHOLD = 0.6
-            TEXT_TRESHOLD = 0.6
-
-            # Read video frames, crop image based on text prompt object detection and generate dataset_train
-            import time
-            frame_count = 0
-            delay = 1 / fps  # Delay in seconds between frames
-            while cap.isOpened():
-                start_time = time.time()
-                ret, frame = cap.read()
-                if not ret:
-                    break
-
-                if cv2.waitKey(1) & 0xff == ord('q'):
-                    break
-
-                # Convert bgr frame to rgb frame to image to torch tensor transformed
-                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-                image_transformed = self.preprocess_image(frame_rgb)
-
-                boxes, logits, phrases = predict(
-                    model=model,
-                    image=image_transformed,
-                    caption=TEXT_PROMPT,
-                    box_threshold=BOX_TRESHOLD,
-                    text_threshold=TEXT_TRESHOLD
-                )
-
-                # Get boxes
-                if boxes.size()[0] > 0:
-                    annotated_frame = self.annotate(image_source=frame, boxes=boxes, logits=logits,
-                                                    phrases=phrases, frame_rgb=frame_rgb, classifier=predictor)
-                    # cv2.imshow('Object detection', annotated_frame)
-                    frame_rgb = cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
-
-                yield frame_rgb
-                elapsed_time = time.time() - start_time
-                time_to_wait = max(delay - elapsed_time, 0)
-                time.sleep(time_to_wait)
-
-                frame_count += 1
-
-    video_annotator = VideoObjectDetection(
-        text_prompt='human face')
-
-    with gr.Blocks() as iface:
-        video_input = gr.Video(label="Upload Video")
-        run_button = gr.Button("Start Processing")
-        output_image = gr.Image(label="Classified video")
-        run_button.click(fn=video_annotator.generate_video, inputs=video_input,
-                         outputs=output_image)
-
-    iface.launch(share=False, debug=True)
-
+            )
+        }
+        image_pillow = Image.fromarray(image)
+        img_transformed = data_transforms['test'](image_pillow)
+
+        return img_transformed
+
+    def predict(self, image):
+
+        image = self.preprocess(image)
+        image = image.unsqueeze(0)
+        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        self.dino.to(device)
+        self.dino.eval()
+        with torch.no_grad():
+            output = self.dino(image.to(device))
+
+        logit, predicted = torch.max(output.data, 1)
+        return self.labels[predicted[0].item()], logit[0].item()
+
+class VideoObjectDetection:
+
+    def __init__(self,
+                 text_prompt: str
+                 ):
+
+        self.text_prompt = text_prompt
+
+    def crop(self, frame, boxes):
+
+        h, w, _ = frame.shape
+        boxes = boxes * torch.Tensor([w, h, w, h])
+        xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
+        min_col, min_row, max_col, max_row = map(int, xyxy[0])
+        crop_image = frame[min_row:max_row, min_col:max_col, :]
+
+        return crop_image
+
+    def annotate(self,
+                 image_source: np.ndarray,
+                 boxes: torch.Tensor,
+                 logits: torch.Tensor,
+                 phrases: List[str],
+                 frame_rgb: np.ndarray,
+                 classifier) -> np.ndarray:
+
+        h, w, _ = image_source.shape
+        boxes = boxes * torch.Tensor([w, h, w, h])
+        xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
+        detections = sv.Detections(xyxy=xyxy)
+        print(xyxy.shape)
+        custom_labels = []
+        custom_logits = []
+
+        for box in xyxy:
+
+            min_col, min_row, max_col, max_row = map(int, box)
+            crop_image = frame_rgb[min_row:max_row, min_col:max_col, :]
+            label, logit = classifier.predict(crop_image)
+            print()
+            if logit >= 1:
+                custom_labels.append(label)
+                custom_logits.append(logit)
+            else:
+                custom_labels.append('unknown human face')
+                custom_logits.append(logit)
+
+        labels = [
+            f"{phrase} {logit:.2f}"
+            for phrase, logit
+            in zip(custom_labels, custom_logits)
+        ]
+
+        box_annotator = sv.BoxAnnotator()
+        annotated_frame = box_annotator.annotate(scene=image_source, detections=detections, labels=labels)
+        return annotated_frame
+
+    def preprocess_image(self, image: np.ndarray) -> torch.Tensor:
+        transform = T.Compose(
+            [
+                T.RandomResize([800], max_size=1333),
+                T.ToTensor(),
+                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+            ]
+        )
+
+        image_pillow = Image.fromarray(image)
+        image_transformed, _ = transform(image_pillow, None)
+        return image_transformed
+
+    @spaces.GPU()
+    def generate_video(self, video_path) -> None:
+
+        # Load model, set up variables and get video properties
+        cap, fps, width, height, fourcc = get_video_properties(video_path)
+        model = load_model("GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
+                           "GroundingDINO/weights/groundingdino_swint_ogc.pth")
+        predictor = ImageClassifier()
+        TEXT_PROMPT = self.text_prompt
+        BOX_TRESHOLD = 0.6
+        TEXT_TRESHOLD = 0.6
+
+        # Read video frames, crop image based on text prompt object detection and generate dataset_train
+        import time
+        frame_count = 0
+        delay = 1 / fps  # Delay in seconds between frames
+        while cap.isOpened():
+            start_time = time.time()
+            ret, frame = cap.read()
+            if not ret:
+                break
+
+            if cv2.waitKey(1) & 0xff == ord('q'):
+                break
+
+            # Convert bgr frame to rgb frame to image to torch tensor transformed
+            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            image_transformed = self.preprocess_image(frame_rgb)
+
+            boxes, logits, phrases = predict(
+                model=model,
+                image=image_transformed,
+                caption=TEXT_PROMPT,
+                box_threshold=BOX_TRESHOLD,
+                text_threshold=TEXT_TRESHOLD
+            )
+
+            # Get boxes
+            if boxes.size()[0] > 0:
+                annotated_frame = self.annotate(image_source=frame, boxes=boxes, logits=logits,
+                                                phrases=phrases, frame_rgb=frame_rgb, classifier=predictor)
+                # cv2.imshow('Object detection', annotated_frame)
+                frame_rgb = cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
+
+            yield frame_rgb
+            elapsed_time = time.time() - start_time
+            time_to_wait = max(delay - elapsed_time, 0)
+            time.sleep(time_to_wait)
+
+            frame_count += 1
+
+
 if __name__ == "__main__":
-    video_app_setup_and_run_pipeline()
-
+
+    video_annotator = VideoObjectDetection(
+        text_prompt='human face')
+
+    with gr.Blocks() as iface:
+        video_input = gr.Video(label="Upload Video")
+        run_button = gr.Button("Start Processing")
+        output_image = gr.Image(label="Classified video")
+        run_button.click(fn=video_annotator.generate_video, inputs=video_input,
+                         outputs=output_image)
+
+    iface.launch(share=False, debug=True)
+
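The new __main__ block depends on Gradio's handling of generator handlers: because generate_video yields RGB frames, each yield updates the gr.Image output, which is what turns the per-frame loop into a streamed preview. A self-contained sketch of that wiring with dummy frames (names are illustrative):

import gradio as gr
import numpy as np

def stream_frames(video_path):
    # a generator event handler streams: every yield replaces the Image contents
    for step in range(3):
        yield np.full((240, 320, 3), step * 80, dtype=np.uint8)

with gr.Blocks() as demo:
    video_input = gr.Video(label="Upload Video")
    run_button = gr.Button("Start Processing")
    output_image = gr.Image(label="Classified video")
    run_button.click(fn=stream_frames, inputs=video_input, outputs=output_image)

demo.launch()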