thinh-huynh-re committed
Commit 1def0a4
1 Parent(s): 6b89aad
Files changed (2)
  1. requirements.txt +1 -0
  2. run_opencv.py +106 -92
requirements.txt CHANGED
@@ -6,3 +6,4 @@ black
 opencv-python
 opencv-python-headless
 streamlit-webrtc
+typed-argument-parser
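
The new dependency, typed-argument-parser, provides the `Tap` base class that run_opencv.py now builds its CLI on. A minimal sketch of the pattern (the `DemoArgs` class and its fields are illustrative, not part of this repo):

    # Hypothetical minimal Tap usage, shown only to illustrate the dependency.
    from tap import Tap

    class DemoArgs(Tap):
        name: str = "world"  # exposed as the --name flag
        repeat: int = 1      # exposed as the --repeat flag

    args = DemoArgs().parse_args()
    for _ in range(args.repeat):
        print(f"hello {args.name}")

Each annotated class attribute becomes a typed command-line flag, which is what lets the script below replace its module-level constants (model_name, SKIP_FRAMES) with ArgParser fields.
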
run_opencv.py CHANGED
@@ -1,134 +1,148 @@
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 import cv2
 import numpy as np
 import pandas as pd
 import torch
+from tap import Tap
 from torch import Tensor
 from transformers import AutoFeatureExtractor, TimesformerForVideoClassification
 
 from utils.img_container import ImgContainer
 
 
-def load_model(model_name: str):
-    if "base-finetuned-k400" in model_name or "base-finetuned-k600" in model_name:
-        feature_extractor = AutoFeatureExtractor.from_pretrained(
-            "MCG-NJU/videomae-base-finetuned-kinetics"
-        )
-    else:
-        feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
-    model = TimesformerForVideoClassification.from_pretrained(model_name)
-    return feature_extractor, model
-
-
-def inference():
-    if not img_container.ready:
-        return
-
-    inputs = feature_extractor(list(img_container.imgs), return_tensors="pt")
-
-    with torch.no_grad():
-        outputs = model(**inputs)
-    logits: Tensor = outputs.logits
-
-    # model predicts one of the 400 Kinetics-400 classes
-    max_index = logits.argmax(-1).item()
-    predicted_label = model.config.id2label[max_index]
-
-    img_container.frame_rate.label = f"{predicted_label}_{logits[0][max_index]:.2f}%"
-
-    TOP_K = 12
-    # logits = np.squeeze(logits)
-    logits = logits.squeeze().numpy()
-    indices = np.argsort(logits)[::-1][:TOP_K]
-    values = logits[indices]
-
-    results: List[Tuple[str, float]] = []
-    for index, value in zip(indices, values):
-        predicted_label = model.config.id2label[index]
-        # print(f"Label: {predicted_label} - {value:.2f}%")
-        results.append((predicted_label, value))
-
-    img_container.rs = pd.DataFrame(results, columns=("Label", "Confidence"))
-
-
-def get_frames_per_video(model_name: str) -> int:
-    if "base-finetuned" in model_name:
-        return 8
-    elif "hr-finetuned" in model_name:
-        return 16
-    else:
-        return 96
-
-
-model_name = "facebook/timesformer-base-finetuned-k400"
-# "facebook/timesformer-base-finetuned-k400"
-# "facebook/timesformer-base-finetuned-k600",
-# "facebook/timesformer-base-finetuned-ssv2",
-# "facebook/timesformer-hr-finetuned-k600",
-# "facebook/timesformer-hr-finetuned-k400",
-# "facebook/timesformer-hr-finetuned-ssv2",
-# "fcakyon/timesformer-large-finetuned-k400",
-# "fcakyon/timesformer-large-finetuned-k600",
-feature_extractor, model = load_model(model_name)
-
-
-frames_per_video = get_frames_per_video(model_name)
-print(f"Frames per video: {frames_per_video}")
-
-img_container = ImgContainer(frames_per_video)
-
-SKIP_FRAMES = 4
-
-num_skips = 0
-
-# define a video capture object
-camera = cv2.VideoCapture(0)
-
-frame_width = int(camera.get(3))
-frame_height = int(camera.get(4))
-size = (frame_width, frame_height)
-
-video_output = cv2.VideoWriter(
-    "activities.mp4", cv2.VideoWriter_fourcc(*"MJPG"), 10, size
-)
-
-if camera.isOpened() == False:
-    print("Error reading video file")
-
-while camera.isOpened():
-    # Capture the video frame
-    # by frame
-    ret, frame = camera.read()
-
-    num_skips = (num_skips + 1) % SKIP_FRAMES
-
-    img_container.img = frame
-    img_container.frame_rate.count()
-
-    if num_skips == 0:
-        img_container.add_frame(frame)
-        # inference()
-    rs = img_container.frame_rate.show_fps(frame, img_container.is_recording)
-
-    # Display the resulting frame
-    cv2.imshow("ActivityTracking", rs)
-
-    if img_container.is_recording:
-        video_output.write(rs)
-
-    # the 'q' button is set as the
-    # quitting button you may use any
-    # desired button of your choice
-    k = cv2.waitKey(1)
-
-    if k == ord("q"):
-        break
-    elif k == ord("r"):
-        img_container.toggle_recording()
-
-# After the loop release the cap object
-camera.release()
-video_output.release()
-# Destroy all the windows
-cv2.destroyAllWindows()
+class ArgParser(Tap):
+    is_recording: Optional[bool] = False
+
+    # "facebook/timesformer-base-finetuned-k400"
+    # "facebook/timesformer-base-finetuned-k600",
+    # "facebook/timesformer-base-finetuned-ssv2",
+    # "facebook/timesformer-hr-finetuned-k600",
+    # "facebook/timesformer-hr-finetuned-k400",
+    # "facebook/timesformer-hr-finetuned-ssv2",
+    # "fcakyon/timesformer-large-finetuned-k400",
+    # "fcakyon/timesformer-large-finetuned-k600",
+    model_name: Optional[str] = "facebook/timesformer-base-finetuned-k400"
+
+    num_skip_frames: Optional[int] = 4
+
+
+class ActivityModel:
+    def __init__(self, args: ArgParser):
+        self.feature_extractor, self.model = self.load_model(args.model_name)
+
+        self.frames_per_video = self.get_frames_per_video(args.model_name)
+        print(f"Frames per video: {self.frames_per_video}")
+
+    def load_model(self, model_name: str):
+        if "base-finetuned-k400" in model_name or "base-finetuned-k600" in model_name:
+            feature_extractor = AutoFeatureExtractor.from_pretrained(
+                "MCG-NJU/videomae-base-finetuned-kinetics"
+            )
+        else:
+            feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
+        model = TimesformerForVideoClassification.from_pretrained(model_name)
+        return feature_extractor, model
+
+    def inference(self, img_container: ImgContainer):
+        if not img_container.ready:
+            return
+
+        inputs = self.feature_extractor(list(img_container.imgs), return_tensors="pt")
+
+        with torch.no_grad():
+            outputs = self.model(**inputs)
+        logits: Tensor = outputs.logits
+
+        # model predicts one of the 400 Kinetics-400 classes
+        max_index = logits.argmax(-1).item()
+        predicted_label = self.model.config.id2label[max_index]
+
+        img_container.frame_rate.label = (
+            f"{predicted_label}_{logits[0][max_index]:.2f}%"
+        )
+
+        TOP_K = 12
+        # logits = np.squeeze(logits)
+        logits = logits.squeeze().numpy()
+        indices = np.argsort(logits)[::-1][:TOP_K]
+        values = logits[indices]
+
+        results: List[Tuple[str, float]] = []
+        for index, value in zip(indices, values):
+            predicted_label = self.model.config.id2label[index]
+            # print(f"Label: {predicted_label} - {value:.2f}%")
+            results.append((predicted_label, value))
+
+        img_container.rs = pd.DataFrame(results, columns=("Label", "Confidence"))
+
+    def get_frames_per_video(self, model_name: str) -> int:
+        if "base-finetuned" in model_name:
+            return 8
+        elif "hr-finetuned" in model_name:
+            return 16
+        else:
+            return 96
+
+
+def main(args: ArgParser):
+    activity_model = ActivityModel(args)
+    img_container = ImgContainer(activity_model.frames_per_video)
+
+    num_skips = 0
+
+    # define a video capture object
+    camera = cv2.VideoCapture(0)
+
+    frame_width = int(camera.get(3))
+    frame_height = int(camera.get(4))
+    size = (frame_width, frame_height)
+
+    video_output = cv2.VideoWriter(
+        "activities.mp4", cv2.VideoWriter_fourcc(*"MP4V"), 10, size
+    )
+
+    if camera.isOpened() == False:
+        print("Error reading video file")
+
+    while camera.isOpened():
+        # Capture the video frame
+        # by frame
+        ret, frame = camera.read()
+
+        num_skips = (num_skips + 1) % args.num_skip_frames
+
+        img_container.img = frame
+        img_container.frame_rate.count()
+
+        if num_skips == 0:
+            img_container.add_frame(frame)
+            activity_model.inference(img_container)
+        rs = img_container.frame_rate.show_fps(frame, img_container.is_recording)
+
+        # Display the resulting frame
+        cv2.imshow("ActivityTracking", rs)
+
+        if img_container.is_recording:
+            video_output.write(rs)
+
+        # the 'q' button is set as the
+        # quitting button you may use any
+        # desired button of your choice
+        k = cv2.waitKey(1)
+
+        if k == ord("q"):
+            break
+        elif k == ord("r"):
+            img_container.toggle_recording()
+
+    # After the loop release the cap object
+    camera.release()
+    video_output.release()
+    # Destroy all the windows
+    cv2.destroyAllWindows()
+
+
+if __name__ == "__main__":
+    args = ArgParser().parse_args()
+    main(args)
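
With the commit applied, the script is driven from the command line via the ArgParser fields. A hedged sketch of programmatic use (assuming run_opencv.py is importable from the working directory; the flag values are examples, and Tap's parse_args accepts an explicit argument list like argparse's):

    # Hypothetical driver; the __main__ guard keeps the camera loop
    # from running on import.
    from run_opencv import ArgParser, main

    args = ArgParser().parse_args(
        ["--model_name", "facebook/timesformer-hr-finetuned-k400",
         "--num_skip_frames", "2"]
    )
    main(args)

Inside the loop, q quits and r toggles recording to activities.mp4; the fourcc change from MJPG to MP4V makes the codec consistent with the .mp4 container.
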