Hbvsa committed
Commit ccfbdde
1 Parent(s): 6d2982b

Update app.py

Files changed (1)
  1. app.py +234 -233
app.py CHANGED
The commit replaces the whole file, but the previous revision is identical to the listing below except that the setup command list ended at "cd .." and did not yet include the trailing "ls" entry. The updated app.py in full:
import subprocess

def run_commands():
    commands = [
        "apt-get update",
        "apt-get install -y libgl1",
        "git clone https://github.com/IDEA-Research/GroundingDINO.git",
        "pip install -e ./GroundingDINO",
        "cd GroundingDINO",
        "mkdir weights",
        "wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
        "cd .."
        "ls"
    ]

    for command in commands:
        try:
            print(f"Running command: {command}")
            result = subprocess.run(command, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            print(result.stdout.decode())
        except subprocess.CalledProcessError as e:
            print(f"Command '{command}' failed with error: {e.stderr.decode()}")

# Call the function to run the commands

if __name__ == "__main__":
    run_commands()
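A note on the setup block above: as committed, "cd .." and "ls" are not separated by a comma, so Python concatenates them into the single string "cd ..ls", and because every entry runs in its own subprocess.run(..., shell=True) shell, the "cd" commands do not carry over to later commands, so the checkpoint is downloaded to the working directory rather than to GroundingDINO/weights, which is where load_model() looks for it later in this file. A minimal sketch of a setup step that avoids both pitfalls, using the same URLs and target layout (the helper name setup_groundingdino is ours, not part of the commit):

import subprocess
from pathlib import Path

def setup_groundingdino():
    # Hypothetical replacement for run_commands(); same packages, repo and checkpoint as above.
    subprocess.run("apt-get update && apt-get install -y libgl1", shell=True, check=True)
    subprocess.run("git clone https://github.com/IDEA-Research/GroundingDINO.git", shell=True, check=True)
    subprocess.run("pip install -e ./GroundingDINO", shell=True, check=True)

    # Download the checkpoint into GroundingDINO/weights, where load_model() expects it.
    weights_dir = Path("GroundingDINO/weights")
    weights_dir.mkdir(parents=True, exist_ok=True)
    subprocess.run(
        "wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth",
        shell=True, check=True, cwd=weights_dir,
    )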
from typing import List
from Utils import get_video_properties
from GroundingDINO.groundingdino.util.inference import load_model, predict
import cv2
import numpy as np
import torch
from PIL import Image
import GroundingDINO.groundingdino.datasets.transforms as T
from torchvision.ops import box_convert
from torchvision import transforms
from torch import nn
from os.path import dirname, abspath
import yaml
import supervision as sv
import gradio as gr
import spaces
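get_video_properties comes from a local Utils module that is not part of this commit. Judging only from how generate_video unpacks its return value below, a plausible OpenCV-based stand-in would look roughly like this (the return order cap, fps, width, height, fourcc is taken from the call site; everything else is an assumption):

import cv2

def get_video_properties(video_path):
    # Hypothetical sketch of the helper imported from Utils; not the committed implementation.
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")  # only needed if frames are written back out
    return cap, fps, width, height, fourcc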
class DinoVisionTransformerClassifier(nn.Module):
    def __init__(self):
        super(DinoVisionTransformerClassifier, self).__init__()
        self.transformer = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")
        self.classifier = nn.Sequential(nn.Linear(384, 256), nn.ReLU(), nn.Linear(256, 2))

    def forward(self, x):
        x = self.transformer(x)
        x = self.transformer.norm(x)
        x = self.classifier(x)
        return x
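For orientation: dinov2_vits14 yields a 384-dimensional embedding per image, which the two-layer head maps to two class logits, one per label in config.yaml. A quick shape check, purely illustrative (the 224x224 input matches the Resize used in ImageClassifier.preprocess below):

import torch

model = DinoVisionTransformerClassifier()
model.eval()
with torch.no_grad():
    dummy = torch.randn(1, 3, 224, 224)  # one normalized RGB image, as preprocess() produces
    logits = model(dummy)
print(logits.shape)  # torch.Size([1, 2])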
class ImageClassifier:

    def __init__(self):
        with open(f"{dirname(abspath(__file__))}/config.yaml", 'r') as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
            labels = config["labels"]

        self.labels = labels
        self.dino = DinoVisionTransformerClassifier()
        model_path = f"{dirname(abspath(__file__))}/model.pth"
        state_dict = torch.load(model_path)
        self.dino.load_state_dict(state_dict)

    def preprocess(self, image: np.ndarray) -> torch.Tensor:
        data_transforms = {
            "test": transforms.Compose(
                [
                    transforms.Resize((224, 224)),
                    transforms.ToTensor(),
                    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
                ]
            )
        }
        image_pillow = Image.fromarray(image)
        img_transformed = data_transforms['test'](image_pillow)

        return img_transformed

    def predict(self, image):
        image = self.preprocess(image)
        image = image.unsqueeze(0)
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.dino.to(device)
        self.dino.eval()
        with torch.no_grad():
            output = self.dino(image.to(device))

        logit, predicted = torch.max(output.data, 1)
        return self.labels[predicted[0].item()], logit[0].item()
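ImageClassifier.predict returns the largest raw logit, and annotate() below treats logit >= 1 as confident enough to keep the predicted label. Raw logits are unbounded, so if a calibrated score in [0, 1] is preferred, a softmax over the two outputs is one option; a sketch of that variant (an assumption on our part, not what the commit does):

import torch
import torch.nn.functional as F

def predict_with_confidence(classifier: ImageClassifier, image):
    # Hypothetical variant of ImageClassifier.predict that reports a softmax probability.
    x = classifier.preprocess(image).unsqueeze(0)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    classifier.dino.to(device)
    classifier.dino.eval()
    with torch.no_grad():
        output = classifier.dino(x.to(device))
    probs = F.softmax(output, dim=1)
    confidence, predicted = torch.max(probs, 1)
    return classifier.labels[predicted[0].item()], confidence[0].item()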
class VideoObjectDetection:

    def __init__(self,
                 text_prompt: str
                 ):

        self.text_prompt = text_prompt

    def crop(self, frame, boxes):

        h, w, _ = frame.shape
        boxes = boxes * torch.Tensor([w, h, w, h])
        xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
        min_col, min_row, max_col, max_row = map(int, xyxy[0])
        crop_image = frame[min_row:max_row, min_col:max_col, :]

        return crop_image

    def annotate(self,
                 image_source: np.ndarray,
                 boxes: torch.Tensor,
                 logits: torch.Tensor,
                 phrases: List[str],
                 frame_rgb: np.ndarray,
                 classifier) -> np.ndarray:

        h, w, _ = image_source.shape
        boxes = boxes * torch.Tensor([w, h, w, h])
        xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
        detections = sv.Detections(xyxy=xyxy)
        print(xyxy.shape)
        custom_labels = []
        custom_logits = []

        for box in xyxy:
            min_col, min_row, max_col, max_row = map(int, box)
            crop_image = frame_rgb[min_row:max_row, min_col:max_col, :]
            label, logit = classifier.predict(crop_image)
            print()
            if logit >= 1:
                custom_labels.append(label)
                custom_logits.append(logit)
            else:
                custom_labels.append('unknown human face')
                custom_logits.append(logit)

        labels = [
            f"{phrase} {logit:.2f}"
            for phrase, logit
            in zip(custom_labels, custom_logits)
        ]

        box_annotator = sv.BoxAnnotator()
        annotated_frame = box_annotator.annotate(scene=image_source, detections=detections, labels=labels)
        return annotated_frame

    def preprocess_image(self, image: np.ndarray) -> torch.Tensor:
        transform = T.Compose(
            [
                T.RandomResize([800], max_size=1333),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

        image_pillow = Image.fromarray(image)
        image_transformed, _ = transform(image_pillow, None)
        return image_transformed
    def generate_video(self, video_path) -> None:

        # Load model, set up variables and get video properties
        cap, fps, width, height, fourcc = get_video_properties(video_path)
        model = load_model("GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
                           "GroundingDINO/weights/groundingdino_swint_ogc.pth")
        predictor = ImageClassifier()
        TEXT_PROMPT = self.text_prompt
        BOX_TRESHOLD = 0.6
        TEXT_TRESHOLD = 0.6

        # Read video frames, crop image based on text prompt object detection and generate dataset_train
        import time
        frame_count = 0
        delay = 1 / fps  # Delay in seconds between frames
        while cap.isOpened():
            start_time = time.time()
            ret, frame = cap.read()
            if not ret:
                break

            if cv2.waitKey(1) & 0xff == ord('q'):
                break

            # Convert bgr frame to rgb frame to image to torch tensor transformed
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image_transformed = self.preprocess_image(frame_rgb)

            boxes, logits, phrases = predict(
                model=model,
                image=image_transformed,
                caption=TEXT_PROMPT,
                box_threshold=BOX_TRESHOLD,
                text_threshold=TEXT_TRESHOLD
            )

            # Get boxes
            if boxes.size()[0] > 0:
                annotated_frame = self.annotate(image_source=frame, boxes=boxes, logits=logits,
                                                phrases=phrases, frame_rgb=frame_rgb, classifier=predictor)
                # cv2.imshow('Object detection', annotated_frame)
                frame_rgb = cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)

            yield frame_rgb
            elapsed_time = time.time() - start_time
            time_to_wait = max(delay - elapsed_time, 0)
            time.sleep(time_to_wait)

            frame_count += 1
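Both crop() and annotate() above rely on GroundingDINO returning boxes as normalized (cx, cy, w, h) in [0, 1]: they scale by [w, h, w, h] and then convert to pixel (x1, y1, x2, y2) corners with box_convert. A small worked example of that conversion (the numbers are invented for illustration):

import torch
from torchvision.ops import box_convert

h, w = 720, 1280                               # frame height and width in pixels
boxes = torch.tensor([[0.5, 0.5, 0.25, 0.5]])  # one box: centred, 25% of the width, 50% of the height
xyxy = box_convert(boxes * torch.Tensor([w, h, w, h]), in_fmt="cxcywh", out_fmt="xyxy")
print(xyxy)  # tensor([[480., 180., 800., 540.]]) -> x1, y1, x2, y2 in pixels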
@spaces.GPU(duration=200)
def video_object_classification_pipeline():
    video_annotator = VideoObjectDetection(
        text_prompt='human face')

    with gr.Blocks() as iface:
        video_input = gr.Video(label="Upload Video")
        run_button = gr.Button("Start Processing")
        output_image = gr.Image(label="Classified video")
        run_button.click(fn=video_annotator.generate_video, inputs=video_input,
                         outputs=output_image)

    iface.launch(share=False, debug=True)

print("Só me falta a GPU")  # Portuguese: "All I'm missing is the GPU"
video_object_classification_pipeline()
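Because generate_video is a generator, each yield updates the gr.Image output, so the Blocks app streams annotated frames as they are produced instead of waiting for the whole video. A stripped-down sketch of the same pattern with dummy frames in place of the detection pipeline (function and variable names here are ours):

import time
import numpy as np
import gradio as gr

def stream_frames(_video_path):
    # Yielding from the click handler updates the Image component on every iteration.
    for i in range(30):
        frame = np.full((240, 320, 3), i * 8, dtype=np.uint8)  # placeholder RGB frame
        yield frame
        time.sleep(1 / 30)

with gr.Blocks() as demo:
    video_input = gr.Video(label="Upload Video")
    run_button = gr.Button("Start Processing")
    output_image = gr.Image(label="Classified video")
    run_button.click(fn=stream_frames, inputs=video_input, outputs=output_image)

demo.launch()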