thienphuc12339 committed on
Commit 9f83ce9 · 1 Parent(s): 4cc0c69

Add all source code

Files changed (40)
  1. .dockerignore +12 -0
  2. Dockerfile +27 -0
  3. README.md +29 -9
  4. __init__.py +3 -0
  5. app.py +311 -0
  6. models/dsta_slr_joint_motion_v3_0.onnx +3 -0
  7. models/sl_gcn_joint_v3_0.onnx +3 -0
  8. models/spoter_v3.0.onnx +3 -0
  9. request.py +15 -0
  10. requirements.txt +23 -0
  11. src/configs/__init__.py +1 -0
  12. src/configs/arguments.py +174 -0
  13. src/data/__init__.py +1 -0
  14. src/data/__pycache__/__init__.cpython-312.pyc +0 -0
  15. src/data/__pycache__/__init__.cpython-39.pyc +0 -0
  16. src/data/__pycache__/utils.cpython-312.pyc +0 -0
  17. src/data/__pycache__/utils.cpython-39.pyc +0 -0
  18. src/data/utils.py +157 -0
  19. src/inference.py +271 -0
  20. src/main.py +51 -0
  21. src/tools/__init__.py +3 -0
  22. src/tools/__pycache__/__init__.cpython-312.pyc +0 -0
  23. src/tools/__pycache__/__init__.cpython-39.pyc +0 -0
  24. src/tools/__pycache__/features.cpython-39.pyc +0 -0
  25. src/tools/__pycache__/models.cpython-312.pyc +0 -0
  26. src/tools/__pycache__/models.cpython-39.pyc +0 -0
  27. src/tools/features.py +29 -0
  28. src/tools/models.py +441 -0
  29. src/utils/__init__.py +2 -0
  30. src/utils/__pycache__/__init__.cpython-312.pyc +0 -0
  31. src/utils/__pycache__/constants.cpython-312.pyc +0 -0
  32. src/utils/__pycache__/loggers.cpython-312.pyc +0 -0
  33. src/utils/constants.py +158 -0
  34. src/utils/loggers.py +24 -0
  35. src/visualization/__init__.py +1 -0
  36. src/visualization/__pycache__/__init__.cpython-312.pyc +0 -0
  37. src/visualization/__pycache__/__init__.cpython-39.pyc +0 -0
  38. src/visualization/__pycache__/utils.cpython-312.pyc +0 -0
  39. src/visualization/__pycache__/utils.cpython-39.pyc +0 -0
  40. src/visualization/utils.py +55 -0
.dockerignore ADDED
@@ -0,0 +1,12 @@
+ # Ignore build artifacts
+ *.log
+ *.tmp
+
+ # Ignore compiled Python files
+ __pycache__/
+ *.pyc
+ *.pyo
+ *.pyd
+
+ # Ignore files/directories
+ # engines/data/
Dockerfile ADDED
@@ -0,0 +1,27 @@
+ FROM python:3.10-slim
+
+ # Disable buffering so logs reach the terminal immediately
+ ENV PYTHONUNBUFFERED=1
+
+ # Install the required system libraries
+ RUN apt-get update && apt-get install -y \
+     libgl1-mesa-glx \
+     libglib2.0-0 \
+     && rm -rf /var/lib/apt/lists/*
+
+ WORKDIR /app
+
+ # Copy requirements.txt into the container and install the dependencies
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy the entire codebase into the container
+ COPY . .
+
+ # Set the PORT environment variable (Hugging Face routes traffic to this port)
+ ENV PORT 7860
+ EXPOSE 7860
+
+ # Run the FastAPI application with uvicorn
+ # This assumes the main app file is app.py and the FastAPI instance is named `app`
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,10 +1,30 @@
- ---
- title: SignLanguage
- emoji: 🦀
- colorFrom: purple
- colorTo: indigo
- sdk: docker
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Vietnamese Sign Language Translation
+
+ Vietnamese Sign Language Translation is a project focused on developing advanced AI technology to accurately interpret Vietnamese sign language from body movements.
+
+ ## Installation
+ 1. Create an environment with `Python == 3.9.19`
+ 2. Install `PyTorchVideo`
+ ```
+ cd src/libs
+ git clone https://github.com/facebookresearch/pytorchvideo.git
+ pip install -e pytorchvideo
+ ```
+ 3. Install the other requirements
+ ```
+ cd ../..
+ pip install -r requirements.txt
+ ```
+
+ ## Inference
+ 1. Prepare a configuration for inference. Templates for each architecture can be found at `src/configs`.
+ 2. Modify the inference config:
+ ```
+ inference:
+   source: webcam or path/to/video.mp4
+   output_dir: path/to/output/dir
+ ```
+ 3. Run this command from the `root` directory of the project to start inference:
+ ```
+ python src/inference.py --config_path path/to/config.yaml
+ ```
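For readers who prefer calling the pipeline from Python rather than the CLI, a minimal sketch is given below. It assumes the code is run from the `src` directory with the requirements installed; the architecture, checkpoint, and video path are placeholders to replace with your own.

```python
# Minimal programmatic equivalent of `python src/inference.py --config_path ...`
# (a sketch; arch, pretrained and source below are placeholder values).
from configs import ModelConfig, InferenceConfig
from utils import config_logger, POSE_BASED_MODELS
from tools import load_pipeline
from inference import inference

model_config = ModelConfig(arch="spoter", pretrained="vsltranslation/spoter_v3.0")
inference_config = InferenceConfig(
    source="path/to/video.mp4",   # or "webcam"; must exist, per InferenceConfig.__post_init__
    output_dir="demo/run_1",
    use_onnx=True,
    visualize=False,
)
# Pose-based architectures feed MediaPipe landmarks (not raw frames) to the pipeline
inference_config.use_pose_model = model_config.arch in POSE_BASED_MODELS

config_logger(inference_config.output_dir / "inference.log")
pipeline = load_pipeline(model_config, inference_config)
inference(model_config, inference_config, pipeline)
```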
__init__.py ADDED
@@ -0,0 +1,3 @@
+ # WRITER: PhucNTT2 # EMAIL: thienphuc12339@gmail.com # DATE: 11/2023
+ # FROM: akaOCR Team
+ # ALL USE CASES MUST BE APPROVED BY AKAOCR TEAM
app.py ADDED
@@ -0,0 +1,311 @@
+ import logging
+ from time import time
+ import pandas as pd
+ import numpy as np
+ import cv2
+ from typing import Optional
+ from pathlib import Path
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Query
+ from fastapi.responses import JSONResponse
+ import mediapipe as mp
+
+ from configs import ModelConfig, InferenceConfig
+ from utils import config_logger, POSE_BASED_MODELS
+ from data import Arm, get_sample_timestamp, ok_to_get_frame
+ from tools import load_pipeline, Predictions
+ from visualization import draw_text_on_image
+
+ app = FastAPI()
+
+ # Define the three model presets
+ MODEL_PRESETS = {
+     "dsta_slr": {
+         "model": ModelConfig(
+             arch="dsta_slr",
+             pretrained="vsltranslation/dsta_slr_joint_motion_v3_0",
+         ),
+         "inference": InferenceConfig(
+             source="upload",  # use the uploaded video, not the webcam
+             output_dir="demo/run_1",
+             use_onnx=True,
+             show_skeleton=True,
+             visualize=True,
+             bone_stream=False,
+             motion_stream=True,
+         ),
+     },
+     "sl_gcn": {
+         "model": ModelConfig(
+             arch="sl_gcn",
+             pretrained="models/dsta_slr_joint_motion_v3_0.onnx",
+         ),
+         "inference": InferenceConfig(
+             source="upload",
+             output_dir="demo/run_1",
+             use_onnx=True,
+             show_skeleton=True,
+             visualize=True,
+             bone_stream=True,
+             motion_stream=False,
+         ),
+     },
+     "spoter": {
+         "model": ModelConfig(
+             arch="spoter",
+             pretrained="vsltranslation/spoter_v3.0",
+         ),
+         "inference": InferenceConfig(
+             source="upload",
+             output_dir="demo/run_1",
+             use_onnx=True,
+             show_skeleton=True,
+             visualize=True,
+         ),
+     },
+ }
+
+ config_logger("inference.log")
+ logging.info("API started")
+
+ SPOTER_POSE_LANDMARKS = [
+     mp.solutions.pose.PoseLandmark.NOSE,
+     mp.solutions.pose.PoseLandmark.LEFT_EYE,
+     mp.solutions.pose.PoseLandmark.RIGHT_EYE,
+     mp.solutions.pose.PoseLandmark.RIGHT_SHOULDER,
+     mp.solutions.pose.PoseLandmark.LEFT_SHOULDER,
+     mp.solutions.pose.PoseLandmark.RIGHT_ELBOW,
+     mp.solutions.pose.PoseLandmark.LEFT_ELBOW,
+     mp.solutions.pose.PoseLandmark.RIGHT_WRIST,
+     mp.solutions.pose.PoseLandmark.LEFT_WRIST
+ ]
+
+ SPOTER_HAND_LANDMARKS = [
+     mp.solutions.hands.HandLandmark.WRIST,
+     mp.solutions.hands.HandLandmark.INDEX_FINGER_TIP, mp.solutions.hands.HandLandmark.INDEX_FINGER_DIP,
+     mp.solutions.hands.HandLandmark.INDEX_FINGER_PIP, mp.solutions.hands.HandLandmark.INDEX_FINGER_MCP,
+     mp.solutions.hands.HandLandmark.MIDDLE_FINGER_TIP, mp.solutions.hands.HandLandmark.MIDDLE_FINGER_DIP,
+     mp.solutions.hands.HandLandmark.MIDDLE_FINGER_PIP, mp.solutions.hands.HandLandmark.MIDDLE_FINGER_MCP,
+     mp.solutions.hands.HandLandmark.RING_FINGER_TIP, mp.solutions.hands.HandLandmark.RING_FINGER_DIP,
+     mp.solutions.hands.HandLandmark.RING_FINGER_PIP, mp.solutions.hands.HandLandmark.RING_FINGER_MCP,
+     mp.solutions.hands.HandLandmark.PINKY_TIP, mp.solutions.hands.HandLandmark.PINKY_DIP,
+     mp.solutions.hands.HandLandmark.PINKY_PIP, mp.solutions.hands.HandLandmark.PINKY_MCP,
+     mp.solutions.hands.HandLandmark.THUMB_TIP, mp.solutions.hands.HandLandmark.THUMB_IP,
+     mp.solutions.hands.HandLandmark.THUMB_MCP, mp.solutions.hands.HandLandmark.THUMB_CMC,
+ ]
+
+
+ @app.get("/healthcheck")
+ async def healthcheck():
+     return JSONResponse(status_code=200, content={"status": "UP"})
+
+
+ def run_inference(model_config, inference_config, input_frames):
+     pipeline = load_pipeline(model_config, inference_config)
+     logging.info("Pipeline loaded")
+
+     right_arm = Arm("right", inference_config.visibility)
+     left_arm = Arm("left", inference_config.visibility)
+     data = []
+     results = None
+     predictions = Predictions()
+
+     mp_holistic = mp.solutions.holistic
+     mp_drawing = mp.solutions.drawing_utils
+     mp_drawing_styles = mp.solutions.drawing_styles
+
+     custom_pose_style = mp_drawing_styles.get_default_pose_landmarks_style()
+     custom_right_hand_style = mp_drawing_styles.get_default_hand_landmarks_style()
+     custom_left_hand_style = mp_drawing_styles.get_default_hand_landmarks_style()
+     custom_pose_connections = list(mp_holistic.POSE_CONNECTIONS)
+     custom_hand_connections = list(mp_holistic.HAND_CONNECTIONS)
+
+     if inference_config.show_skeleton:
+         pose_landmarks = SPOTER_POSE_LANDMARKS
+         hand_landmarks = SPOTER_HAND_LANDMARKS
+         for landmark in mp.solutions.pose.PoseLandmark:
+             if landmark in pose_landmarks:
+                 custom_pose_style[landmark] = mp_drawing.DrawingSpec(color=(0,255,0), thickness=2, circle_radius=2)
+             else:
+                 custom_pose_style[landmark] = mp_drawing.DrawingSpec(color=(0,0,0), thickness=0, circle_radius=0)
+                 for connection_tuple in list(custom_pose_connections):
+                     if landmark.value in connection_tuple:
+                         custom_pose_connections.remove(connection_tuple)
+         for landmark in mp.solutions.hands.HandLandmark:
+             if landmark in hand_landmarks:
+                 custom_right_hand_style[landmark] = mp_drawing.DrawingSpec(color=(0,0,255), thickness=2, circle_radius=2)
+                 custom_left_hand_style[landmark] = mp_drawing.DrawingSpec(color=(255,0,0), thickness=2, circle_radius=2)
+             else:
+                 custom_right_hand_style[landmark] = mp_drawing.DrawingSpec(color=(0,0,0), thickness=0, circle_radius=0)
+                 custom_left_hand_style[landmark] = mp_drawing.DrawingSpec(color=(0,0,0), thickness=0, circle_radius=0)
+                 for connection_tuple in list(custom_hand_connections):
+                     if landmark.value in connection_tuple:
+                         custom_hand_connections.remove(connection_tuple)
+
+     writer = None
+     if inference_config.output_dir is not None:
+         out_path = Path(inference_config.output_dir)
+         out_path.mkdir(parents=True, exist_ok=True)
+         if len(input_frames) > 0 and isinstance(input_frames[0], np.ndarray):
+             h, w, _ = input_frames[0].shape
+             writer = cv2.VideoWriter(str(out_path / "output.mp4"), cv2.VideoWriter_fourcc(*"mp4v"), 30, (w, h))
+
+     with mp_holistic.Holistic(min_detection_confidence=0.9, min_tracking_confidence=0.5) as holistic:
+         # Assume each frame lasts ~33 ms; this is demo logic only
+         current_time_ms = 0
+         for frame in input_frames:
+             rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+             rgb_frame.flags.writeable = False
+             detection_results = holistic.process(rgb_frame)
+
+             try:
+                 landmarks = detection_results.pose_landmarks.landmark
+             except Exception:
+                 current_time_ms += 33
+                 continue
+
+             left_arm.set_pose(landmarks)
+             right_arm.set_pose(landmarks)
+
+             left_arm_ok_to_get_frame = ok_to_get_frame(
+                 arm=left_arm,
+                 angle_threshold=inference_config.angle_threshold,
+                 min_num_up_frames=inference_config.min_num_up_frames,
+                 min_num_down_frames=inference_config.min_num_down_frames,
+                 current_time=current_time_ms,
+                 delay=inference_config.delay,
+             )
+             right_arm_ok_to_get_frame = ok_to_get_frame(
+                 arm=right_arm,
+                 angle_threshold=inference_config.angle_threshold,
+                 min_num_up_frames=inference_config.min_num_up_frames,
+                 min_num_down_frames=inference_config.min_num_down_frames,
+                 current_time=current_time_ms,
+                 delay=inference_config.delay,
+             )
+
+             if left_arm_ok_to_get_frame or right_arm_ok_to_get_frame:
+                 predictions = Predictions()
+                 data.append(detection_results if inference_config.use_pose_model else frame)
+
+             start_time, end_time = get_sample_timestamp(left_arm, right_arm)
+             start_time /= 1000
+             end_time /= 1000
+
+             if start_time != 0 and end_time != 0:
+                 start_inference_time = time()
+                 predictions = Predictions(predictions=pipeline(np.array(data)))
+                 predictions.inference_time = time() - start_inference_time
+                 predictions.start_time = start_time
+                 predictions.end_time = end_time
+                 logging.info(str(predictions))
+                 results = predictions.merge_results(results)
+
+                 # Reset
+                 start_time = 0
+                 end_time = 0
+                 left_arm.reset_state()
+                 right_arm.reset_state()
+                 data = []
+
+             # Draw the results on the frame
+             frame = left_arm.visualize(frame, (20, 10), "Left arm angle")
+             frame = right_arm.visualize(frame, (20, 40), "Right arm angle")
+             frame = predictions.visualize(frame, (20, 70))
+
+             if inference_config.show_skeleton:
+                 mp_drawing.draw_landmarks(
+                     frame,
+                     detection_results.pose_landmarks,
+                     connections=custom_pose_connections,
+                     landmark_drawing_spec=custom_pose_style
+                 )
+                 mp_drawing.draw_landmarks(
+                     frame,
+                     detection_results.right_hand_landmarks,
+                     connections=custom_hand_connections,
+                     landmark_drawing_spec=custom_right_hand_style
+                 )
+                 mp_drawing.draw_landmarks(
+                     frame,
+                     detection_results.left_hand_landmarks,
+                     connections=custom_hand_connections,
+                     landmark_drawing_spec=custom_left_hand_style
+                 )
+
+             if writer is not None:
+                 writer.write(frame)
+
+             current_time_ms += 33
+
+     if writer is not None:
+         writer.release()
+     if results is not None:
+         pd.DataFrame(results).to_csv(Path(inference_config.output_dir) / "results.csv", index=False)
+
+     return predictions.predictions, results
+
+
+ @app.post("/inference")
+ async def inference_endpoint(
+     model_name: str = Query(..., description="Choose model: dsta_slr, sl_gcn, spoter"),
+     output_option: str = Query("all", description="Output option: 'predictions', 'csv', 'video', 'all'"),
+     output_dir: str = Query("demo/run_1", description="Output directory for results"),
+     file: UploadFile = File(...)
+ ):
+     """
+     Inference endpoint:
+     - model_name: which model to use: dsta_slr, sl_gcn, spoter
+     - output_option: 'predictions', 'csv', 'video', or 'all'
+     - output_dir: output directory, e.g. 'my_results'
+     - file: a single uploaded video file
+     """
+
+     if model_name not in MODEL_PRESETS:
+         raise HTTPException(status_code=400, detail="Invalid model_name")
+
+     # Read the video from the uploaded file
+     video_bytes = np.asarray(bytearray(await file.read()), dtype=np.uint8)
+     temp_video_path = Path("temp_input.mp4")
+     with open(temp_video_path, "wb") as f:
+         f.write(video_bytes)
+     cap = cv2.VideoCapture(str(temp_video_path))
+
+     input_frames = []
+     while True:
+         ret, frame = cap.read()
+         if not ret:
+             break
+         input_frames.append(frame)
+     cap.release()
+
+     # Load the configs from the selected preset
+     model_config = MODEL_PRESETS[model_name]["model"]
+     inference_config = MODEL_PRESETS[model_name]["inference"]
+
+     # Override output_dir with the user-provided value
+     inference_config.output_dir = output_dir
+
+     if model_config.arch in POSE_BASED_MODELS:
+         inference_config.use_pose_model = True
+     else:
+         inference_config.use_pose_model = False
+
+     predictions, results = run_inference(model_config, inference_config, input_frames)
+
+     resp = {}
+     out_dir = Path(inference_config.output_dir)
+     if predictions is None:
+         predictions = []
+
+     if output_option in ["predictions", "all"]:
+         resp["predictions"] = predictions
+
+     if output_option in ["csv", "all"]:
+         csv_path = str(out_dir / "results.csv")
+         resp["csv_path"] = csv_path if Path(csv_path).exists() else None
+
+     if output_option in ["video", "all"]:
+         video_path = str(out_dir / "output.mp4")
+         resp["video_path"] = video_path if Path(video_path).exists() else None
+
+     return resp
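One caveat in `run_inference` above is the hard-coded ~33 ms per frame when mapping frame indices to timestamps. A small sketch (not part of this commit, plain OpenCV only) of deriving per-frame timestamps from the uploaded video's actual frame rate instead:

```python
import cv2

cap = cv2.VideoCapture("temp_input.mp4")      # the temp file written by the endpoint above
fps = cap.get(cv2.CAP_PROP_FPS) or 30.0       # fall back to 30 FPS if metadata is missing
frame_interval_ms = 1000.0 / fps

input_frames, timestamps_ms = [], []
frame_idx = 0
while True:
    ret, frame = cap.read()
    if not ret:
        break
    input_frames.append(frame)
    timestamps_ms.append(frame_idx * frame_interval_ms)  # millisecond offset of this frame
    frame_idx += 1
cap.release()
```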
models/dsta_slr_joint_motion_v3_0.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ecfcb2b459fd68bfe838569d41bdb502f7cd21ddd675790146034cf0e6f71632
+ size 29678372
models/sl_gcn_joint_v3_0.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3ab4e3b86ec2a828c9e8f72f1f80ca131c0b7439539412fe15244dbcb64fb2a1
+ size 17046336
models/spoter_v3.0.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:38c21cd96446475cdc110f7748b11ad58b84cd055133379684f9f463dea8fcbd
+ size 24208453
request.py ADDED
@@ -0,0 +1,15 @@
+ import requests
+
+ url = 'https://<your-hf-space-url>.hf.space/inference'  # the actual URL after deploying to HF
+ video_path = '/path/to/your_video.mp4'
+ params = {
+     'model_name': 'spoter',
+     'output_option': 'all',
+     'output_dir': 'custom_output_folder'  # users can choose the output folder
+ }
+ files = {
+     'file': open(video_path, 'rb')
+ }
+
+ response = requests.post(url=url, files=files, params=params)
+ print(response.json())
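A sketch of consuming the response from the snippet above; the keys mirror what `app.py` puts into `resp`, and the paths refer to locations inside the Space's container.

```python
data = response.json()

for pred in data.get("predictions") or []:
    print(pred)  # one entry per returned prediction, as produced by the pipeline

if data.get("csv_path"):
    print("Per-sample results CSV (server-side):", data["csv_path"])
if data.get("video_path"):
    print("Annotated output video (server-side):", data["video_path"])
```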
requirements.txt ADDED
@@ -0,0 +1,23 @@
+ transformers
+ pandas
+ evaluate
+ simple-parsing
+ torch
+ torchvision
+ hf-transfer
+ decord
+ accelerate
+ scikit-learn
+ wandb
+ pose-format
+ torchsummary
+ mediapipe
+ opencv-python
+ onnxruntime
+ onnx
+ imageio
+ tk
+ timm
+ einops
+ fastapi
+ uvicorn
src/configs/__init__.py ADDED
@@ -0,0 +1 @@
+ from .arguments import *
src/configs/arguments.py ADDED
@@ -0,0 +1,174 @@
+ from pathlib import Path
+ from typing import Any
+ from dataclasses import dataclass, field
+ from utils import MODELS, VIDEO_EXTENSIONS
+
+
+ @dataclass
+ class TransformConfig:
+     # RGB specific
+     horizontal_flip_prob: float = 0.5
+     aug_type: str = "augmix"
+     aug_paras: dict = field(
+         default_factory=lambda: {
+             "magnitude": 3,
+             "alpha": 1.0,
+             "width": 5,
+             "depth": -1,
+         }
+     )
+     sample_rate: int = 4
+
+     # Pose specific
+     normalization: bool = True
+
+     # SL-GCN, DSTA-SLR specific
+     random_choose: bool = False
+     random_shift: bool = False
+     random_move: bool = False
+     random_mirror: bool = False
+     random_mirror_p: float = 0.5
+     bone_stream: bool = False
+     motion_stream: bool = False
+
+     # SPOTER specific
+     augmentation: bool = True
+     aug_prob: float = 0.5
+     noise: bool = True
+
+     def __post_init__(self):
+         assert self.aug_type in ["augmix", "mixup"], \
+             "Only AugMix and MixUp are supported for now"
+
+
+ @dataclass
+ class DataConfig:
+     dataset: str = "vsl"
+     modality: str = "rgb"
+     subset: str = None
+     data_dir: str = "data/processed/vsl"
+     transform: Any = None
+     fps: int = 30
+     debug: bool = False
+     # transform: TransformConfig = TransformConfig()
+     transform: TransformConfig = field(default_factory=TransformConfig)
+
+
+     def __post_init__(self):
+         assert self.dataset in ["vsl_98", "vsl_400"], \
+             "Only VSL dataset is supported for now"
+         assert self.modality in ["rgb", "pose"], \
+             "Only RGB and Pose modalities are supported for now"
+
+
+ @dataclass
+ class ModelConfig:
+     arch: str = "sl_gcn"
+     pretrained: str = "vsltranslation/sl_gcn_joint_v3_0"
+     num_frozen_layers: int = 0
+     ignored_weights: list = field(default_factory=lambda: [])
+     num_frames: int = 16
+
+     # SL-GCN specific
+     num_points: int = 27
+     groups: int = 8
+     block_size: int = 41
+     in_channels: int = 3
+     labeling_mode: str = "spatial"
+     is_vector: bool = False
+
+     # DSTA-SLR specific
+     graph: str = "wlasl"
+     inner_dim: int = 64
+     drop_layers: int = 2
+     depth: int = 4
+     s_num_heads: int = 1
+     window_size: int = 120
+
+     # SPOTER specific
+     hidden_dim: int = 108
+
+     def __post_init__(self):
+         assert self.arch in MODELS, f"Model {self.arch} is not supported"
+
+
+ @dataclass
+ class TrainingConfig:
+     output_dir: str = "experiments"
+     remove_unused_columns: bool = False
+     do_train: bool = True
+     use_cpu: bool = False
+
+     eval_strategy: str = "epoch"
+     logging_strategy: str = "epoch"
+     save_strategy: str = "epoch"
+     logging_steps: int = 1
+     save_steps: int = 1
+     eval_steps: int = 1
+
+     learning_rate: float = 5e-5
+     weight_decay: float = 0
+     adam_beta1: float = 0.9
+     adam_beta2: float = 0.999
+     adam_epsilon: float = 1e-8
+     warmup_ratio: float = 0.1
+
+     num_train_epochs: int = 10
+     per_device_train_batch_size: int = 8
+     per_device_eval_batch_size: int = 8
+     dataloader_num_workers: int = 0
+
+     load_best_model_at_end: bool = True
+     metric_for_best_model: str = "accuracy"
+     resume_from_checkpoint: str = None
+
+     run_name: str = "swin3d"
+     report_to: str = None
+     push_to_hub: bool = False
+     hub_model_id: str = None
+     hub_strategy: str = "checkpoint"
+     hub_private_repo: bool = True
+
+     def __post_init__(self):
+         self.output_dir = Path(self.output_dir)
+         if str(self.output_dir) == "experiments":
+             self.output_dir = self.output_dir / self.run_name
+         self.output_dir.mkdir(parents=True, exist_ok=True)
+
+         if self.hub_model_id is not None:
+             self.push_to_hub = True
+             if len(self.hub_model_id.split("/")) == 1:
+                 self.hub_model_id = f"{self.hub_model_id}/{self.run_name}"
+
+
+ @dataclass
+ class InferenceConfig:
+     source: str = "webcam"
+     output_dir: str = "demo"
+     use_onnx: bool = False
+     device: str = "cpu"
+     cache_dir: str = "models/huggingface"
+
+     visualize: bool = False
+     show_skeleton: bool = False
+
+     visibility: float = 0.5
+     angle_threshold: int = 140
+     min_num_up_frames: int = 10
+     min_num_down_frames: int = 10
+     delay: int = 400
+
+     top_k: int = 3
+     # SL-GCN, DSTA-SLR specific
+     bone_stream: bool = False
+     motion_stream: bool = False
+
+     def __post_init__(self):
+         self.source = Path(self.source)
+         assert any((
+             str(self.source) == "webcam",
+             (self.source.exists() and str(self.source).endswith(VIDEO_EXTENSIONS))
+         )), \
+             f"Only Webcam and Video sources are supported for now (got {self.source})"
+         self.output_dir = Path(self.output_dir)
+         self.output_dir.mkdir(parents=True, exist_ok=True)
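A quick sketch of how these dataclasses behave when instantiated; the values are examples, and `__post_init__` runs the asserts and creates the output directory.

```python
from configs import ModelConfig, InferenceConfig

model_config = ModelConfig(arch="sl_gcn")    # must be one of MODELS, otherwise the assert fires
inference_config = InferenceConfig(
    source="webcam",                         # "webcam" or the path of an existing video file
    output_dir="demo/run_1",                 # created automatically in __post_init__
    use_onnx=True,
)
print(type(inference_config.source))         # pathlib.Path, after __post_init__ conversion
print(inference_config.output_dir.exists())  # True
```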
src/data/__init__.py ADDED
@@ -0,0 +1 @@
+ from .utils import *
src/data/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (178 Bytes).
src/data/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (212 Bytes).
src/data/__pycache__/utils.cpython-312.pyc ADDED
Binary file (7.1 kB).
src/data/__pycache__/utils.cpython-39.pyc ADDED
Binary file (3.61 kB).
src/data/utils.py ADDED
@@ -0,0 +1,157 @@
+ import numpy as np
+ from mediapipe.python.solutions import pose
+ from visualization import draw_text_on_image
+
+
+ class Arm:
+     def __init__(
+         self,
+         side: str,
+         visibility: float = 0.5,
+     ) -> None:
+         if side == "left":
+             self.shoulder_idx = pose.PoseLandmark.LEFT_SHOULDER.value
+             self.elbow_idx = pose.PoseLandmark.LEFT_ELBOW.value
+             self.wrist_idx = pose.PoseLandmark.LEFT_WRIST.value
+         elif side == "right":
+             self.shoulder_idx = pose.PoseLandmark.RIGHT_SHOULDER.value
+             self.elbow_idx = pose.PoseLandmark.RIGHT_ELBOW.value
+             self.wrist_idx = pose.PoseLandmark.RIGHT_WRIST.value
+         else:
+             raise ValueError("Side must be either 'left' or 'right'")
+         self.visibility = visibility
+
+         self.is_up = False
+         self.num_up_frames = 0
+         self.num_down_frames = 0
+         self.start_time = 0
+         self.end_time = 0
+         self.shoulder = None
+         self.elbow = None
+         self.wrist = None
+         self.angle = 0
+
+     def reset_state(self) -> None:
+         self.is_up = False
+         self.num_up_frames = 0
+         self.num_down_frames = 0
+         self.start_time = 0
+         self.end_time = 0
+         self.shoulder = None
+         self.elbow = None
+         self.wrist = None
+         self.angle = 0
+
+     def set_pose(self, landmarks) -> bool:
+         if landmarks[self.shoulder_idx].visibility < self.visibility:
+             return False
+         self.shoulder = (
+             landmarks[self.shoulder_idx].x,
+             landmarks[self.shoulder_idx].y,
+         )
+
+         if landmarks[self.elbow_idx].visibility < self.visibility:
+             return False
+         self.elbow = (
+             landmarks[self.elbow_idx].x,
+             landmarks[self.elbow_idx].y,
+         )
+
+         if landmarks[self.wrist_idx].visibility < self.visibility:
+             return False
+         self.wrist = (
+             landmarks[self.wrist_idx].x,
+             landmarks[self.wrist_idx].y,
+         )
+
+         self.angle = calculate_angle(self.shoulder, self.elbow, self.wrist)
+         return True
+
+     def visualize(
+         self,
+         frame: np.ndarray,
+         position: tuple = (20, 50),
+         prefix: str = "Angle",
+         color: tuple = (0, 0, 255),
+     ) -> np.ndarray:
+         text = prefix + ": " + str(round(self.angle, 2))
+         return draw_text_on_image(
+             image=frame,
+             text=text,
+             position=position,
+             color=color,
+             font_size=20,
+         )
+
+
+ def get_sample_timestamp(left_arm: Arm, right_arm: Arm) -> tuple:
+     start_time, end_time = 0, 0
+     left_arm_available = left_arm.start_time > 0 and left_arm.end_time > 0
+     right_arm_available = right_arm.start_time > 0 and right_arm.end_time > 0
+
+     if left_arm_available and right_arm.start_time == 0:
+         start_time = left_arm.start_time
+         end_time = left_arm.end_time
+     if right_arm_available and left_arm.start_time == 0:
+         start_time = right_arm.start_time
+         end_time = right_arm.end_time
+     if all((
+         left_arm_available, not left_arm.is_up,
+         right_arm_available, not right_arm.is_up,
+     )):
+         start_time = min(left_arm.start_time, right_arm.start_time)
+         end_time = max(left_arm.end_time, right_arm.end_time)
+
+     # Convert milliseconds to seconds
+     start_time /= 1000
+     end_time /= 1000
+     return start_time, end_time
+
+
+ def calculate_angle(a: tuple, b: tuple, c: tuple) -> float:
+     a = np.array(a)  # First
+     b = np.array(b)  # Mid
+     c = np.array(c)  # End
+
+     radians = np.arctan2(c[1] - b[1], c[0] - b[0]) - np.arctan2(a[1] - b[1], a[0] - b[0])
+     angle = np.abs(radians * 180.0 / np.pi)
+
+     return 360 - angle if angle > 180 else angle
+
+
+ def ok_to_get_frame(
+     arm: Arm,
+     angle_threshold: int,
+     min_num_up_frames: int,
+     min_num_down_frames: int,
+     current_time: int,
+     delay: int,
+ ) -> bool:
+     if 0 < arm.angle < angle_threshold:
+         if arm.is_up:
+             arm.num_down_frames = 0
+             arm.end_time = 0
+         else:
+             if arm.num_up_frames == min_num_up_frames:
+                 arm.is_up = True
+                 arm.num_up_frames = 0
+             else:
+                 if arm.num_up_frames == 0:
+                     arm.start_time = current_time - delay
+                 arm.num_up_frames += 1
+                 return False
+     else:
+         if arm.is_up:
+             if arm.num_down_frames == min_num_down_frames:
+                 arm.is_up = False
+                 arm.num_down_frames = 0
+             else:
+                 if arm.num_down_frames == 0:
+                     arm.end_time = current_time + delay
+                 arm.num_down_frames += 1
+                 return True
+         else:
+             arm.num_up_frames = 0
+             arm.start_time = 0
+
+     return arm.is_up
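A worked example of the angle computation above, assuming it is run from `src/` like the rest of the code; the coordinates are made-up values in MediaPipe's normalized [0, 1] space.

```python
from data.utils import calculate_angle

# hypothetical landmark positions (x, y) for one arm
shoulder = (0.50, 0.30)
elbow = (0.55, 0.50)
wrist = (0.80, 0.55)

angle = calculate_angle(shoulder, elbow, wrist)
print(f"elbow angle: {angle:.1f} degrees")  # ~115 degrees, below the default 140-degree threshold,
                                            # so ok_to_get_frame would treat this arm as raised
```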
src/inference.py ADDED
@@ -0,0 +1,271 @@
+ import shutil
+ import logging
+ from time import time
+
+ import numpy as np
+ import pandas as pd
+ import cv2
+ from traceback import format_exc
+ from argparse import Namespace
+ from transformers import Pipeline
+ from simple_parsing import ArgumentParser
+ import mediapipe as mp
+ from mediapipe.python.solutions.pose import PoseLandmark
+ from mediapipe.python.solutions.hands import HandLandmark
+ from mediapipe.python.solutions.drawing_utils import DrawingSpec
+
+ from visualization import draw_text_on_image
+ from configs import ModelConfig, InferenceConfig
+ from utils import config_logger, POSE_BASED_MODELS
+ from data import Arm, get_sample_timestamp, ok_to_get_frame
+ from tools import load_pipeline, Predictions
+
+
+ SPOTER_POSE_LANDMARKS = [
+     PoseLandmark.NOSE,
+     PoseLandmark.LEFT_EYE,
+     PoseLandmark.RIGHT_EYE,
+     PoseLandmark.RIGHT_SHOULDER,
+     PoseLandmark.LEFT_SHOULDER,
+     PoseLandmark.RIGHT_ELBOW,
+     PoseLandmark.LEFT_ELBOW,
+     PoseLandmark.RIGHT_WRIST,
+     PoseLandmark.LEFT_WRIST,
+ ]
+
+ SPOTER_HAND_LANDMARKS = [
+     HandLandmark.WRIST,
+     HandLandmark.INDEX_FINGER_TIP, HandLandmark.INDEX_FINGER_DIP, HandLandmark.INDEX_FINGER_PIP, HandLandmark.INDEX_FINGER_MCP,
+     HandLandmark.MIDDLE_FINGER_TIP, HandLandmark.MIDDLE_FINGER_DIP, HandLandmark.MIDDLE_FINGER_PIP, HandLandmark.MIDDLE_FINGER_MCP,
+     HandLandmark.RING_FINGER_TIP, HandLandmark.RING_FINGER_DIP, HandLandmark.RING_FINGER_PIP, HandLandmark.RING_FINGER_MCP,
+     HandLandmark.PINKY_TIP, HandLandmark.PINKY_DIP, HandLandmark.PINKY_PIP, HandLandmark.PINKY_MCP,
+     HandLandmark.THUMB_TIP, HandLandmark.THUMB_IP, HandLandmark.THUMB_MCP, HandLandmark.THUMB_CMC,
+ ]
+
+ def get_args() -> Namespace:
+     parser = ArgumentParser(
+         description="Run inference on VSL",
+         add_config_path_arg=True,
+     )
+     parser.add_arguments(ModelConfig, "model")
+     parser.add_arguments(InferenceConfig, "inference")
+     return parser.parse_args()
+
+
+ def inference(model_config, inference_config: InferenceConfig, pipeline: Pipeline) -> None:
+     # Load video
+     source = str(inference_config.source) if inference_config.source.is_file() else 0
+     cap = cv2.VideoCapture(source)
+     if inference_config.output_dir is not None:
+         writer = cv2.VideoWriter(
+             str(inference_config.output_dir / "output.mp4"),
+             cv2.VideoWriter_fourcc(*"mp4v"),
+             cap.get(cv2.CAP_PROP_FPS),
+             (int(cap.get(3)), int(cap.get(4))),
+         )
+
+     # Init Mediapipe
+     mp_holistic = mp.solutions.holistic
+     mp_drawing = mp.solutions.drawing_utils
+     mp_drawing_styles = mp.solutions.drawing_styles
+
+     custom_pose_style = mp_drawing_styles.get_default_pose_landmarks_style()
+     custom_right_hand_style = mp_drawing_styles.get_default_hand_landmarks_style()
+     custom_left_hand_style = mp_drawing_styles.get_default_hand_landmarks_style()
+     custom_pose_connections = list(mp_holistic.POSE_CONNECTIONS)
+     custom_hand_connections = list(mp_holistic.HAND_CONNECTIONS)
+
+     if inference_config.show_skeleton:
+         # if model_config.arch == 'spoter':
+         pose_landmarks = SPOTER_POSE_LANDMARKS
+         hand_landmarks = SPOTER_HAND_LANDMARKS
+
+         for landmark in PoseLandmark:
+             if landmark in pose_landmarks:
+                 custom_pose_style[landmark] = DrawingSpec(color=(0,255,0), thickness=2, circle_radius=2)
+             else:
+                 custom_pose_style[landmark] = DrawingSpec(color=(0,0,0), thickness=0, circle_radius=0)
+                 for connection_tuple in list(custom_pose_connections):
+                     if landmark.value in connection_tuple:
+                         custom_pose_connections.remove(connection_tuple)
+
+         for landmark in HandLandmark:
+             if landmark in hand_landmarks:
+                 custom_right_hand_style[landmark] = DrawingSpec(color=(0,0,255), thickness=2, circle_radius=2)
+                 custom_left_hand_style[landmark] = DrawingSpec(color=(255,0,0), thickness=2, circle_radius=2)
+             else:
+                 custom_right_hand_style[HandLandmark[landmark.name]] = DrawingSpec(color=(0,0,0), thickness=0, circle_radius=0)
+                 custom_left_hand_style[HandLandmark[landmark.name]] = DrawingSpec(color=(0,0,0), thickness=0, circle_radius=0)
+                 for connection_tuple in list(custom_hand_connections):
+                     if landmark.value in connection_tuple:
+                         custom_hand_connections.remove(connection_tuple)
+
+     # Init variables
+     right_arm = Arm("right", inference_config.visibility)
+     left_arm = Arm("left", inference_config.visibility)
+     data = []
+     results = None
+     predictions = Predictions()
+
+     with mp_holistic.Holistic(min_detection_confidence=0.9, min_tracking_confidence=0.5) as holistic:
+         while cap.isOpened():
+             success, frame = cap.read()
+             if not success:
+                 break
+
+             # Recolor image to RGB, because mp processes on RGB image
+             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+             frame.flags.writeable = False
+
+             # Make detections
+             detection_results = holistic.process(frame)
+
+             # Recolor image back to BGR, because cv2 processes on BGR image
+             frame.flags.writeable = True
+             frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+
+             # Extract landmarks
+             try:
+                 landmarks = detection_results.pose_landmarks.landmark
+             except Exception:
+                 continue
+
+             left_arm.set_pose(landmarks)
+             right_arm.set_pose(landmarks)
+
+             # Check if arms are up or down
+             left_arm_ok_to_get_frame = ok_to_get_frame(
+                 arm=left_arm,
+                 angle_threshold=inference_config.angle_threshold,
+                 min_num_up_frames=inference_config.min_num_up_frames,
+                 min_num_down_frames=inference_config.min_num_down_frames,
+                 current_time=cap.get(cv2.CAP_PROP_POS_MSEC),
+                 delay=inference_config.delay,
+             )
+             right_arm_ok_to_get_frame = ok_to_get_frame(
+                 arm=right_arm,
+                 angle_threshold=inference_config.angle_threshold,
+                 min_num_up_frames=inference_config.min_num_up_frames,
+                 min_num_down_frames=inference_config.min_num_down_frames,
+                 current_time=cap.get(cv2.CAP_PROP_POS_MSEC),
+                 delay=inference_config.delay,
+             )
+             if left_arm_ok_to_get_frame or right_arm_ok_to_get_frame:
+                 # logging.info("Frame added to the list")
+                 predictions = Predictions()
+                 data.append(detection_results if inference_config.use_pose_model else frame)
+
+             # Calculate the start and end time of the sign
+             start_time, end_time = get_sample_timestamp(left_arm, right_arm)
+
+             # Convert from milliseconds to seconds
+             start_time /= 1_000
+             end_time /= 1_000
+
+             # logging.info(f"start_time: {start_time} - end_time: {end_time}")
+             # logging.info(f"\tLeft arm: {left_arm.start_time} - {left_arm.end_time} - {left_arm.is_up}")
+             # logging.info(f"\tRight arm: {right_arm.start_time} - {right_arm.end_time} - {right_arm.is_up}")
+
+             if start_time != 0 and end_time != 0:
+                 # Render waiting screen
+                 if inference_config.visualize:
+                     wait_frame = draw_text_on_image(
+                         np.zeros_like(frame),
+                         text="Please wait for the prediction...",
+                         position=(20, 20),
+                         color=(255, 255, 255),
+                         font_size=20,
+                     )
+                     cv2.imshow("Video Visualization", wait_frame)
+                     if cv2.waitKey(1) & 0xFF == ord('q'):
+                         break
+
+                 start_inference_time = time()
+                 predictions = Predictions(predictions=pipeline(np.array(data)))
+                 predictions.inference_time = time() - start_inference_time
+
+                 predictions.start_time = start_time
+                 predictions.end_time = end_time
+                 logging.info(str(predictions))
+                 results = predictions.merge_results(results)
+
+                 # Reset variables
+                 start_time = 0
+                 end_time = 0
+                 left_arm.reset_state()
+                 right_arm.reset_state()
+                 data = []
+
+             # Render detections
+             frame = left_arm.visualize(frame, (20, 10), "Left arm angle")
+             frame = right_arm.visualize(frame, (20, 40), "Right arm angle")
+             frame = predictions.visualize(frame, (20, 70))
+             if inference_config.show_skeleton:
+                 mp_drawing.draw_landmarks(
+                     frame,
+                     detection_results.pose_landmarks,
+                     connections=custom_pose_connections,  # passing the modified connections list
+                     landmark_drawing_spec=custom_pose_style)  # and drawing style
+
+                 mp_drawing.draw_landmarks(
+                     frame,
+                     detection_results.right_hand_landmarks,
+                     connections=custom_hand_connections,  # passing the modified connections list
+                     landmark_drawing_spec=custom_right_hand_style)  # and drawing style
+
+                 mp_drawing.draw_landmarks(
+                     frame,
+                     detection_results.left_hand_landmarks,
+                     connections=custom_hand_connections,  # passing the modified connections list
+                     landmark_drawing_spec=custom_left_hand_style)  # and drawing style
+
+             if inference_config.output_dir is not None:
+                 writer.write(frame)
+
+             if inference_config.visualize:
+                 cv2.imshow("Video Visualization", frame)
+                 if cv2.waitKey(1) & 0xFF == ord('q'):
+                     break
+
+     cap.release()
+     cv2.destroyAllWindows()
+
+     if inference_config.output_dir is not None:
+         writer.release()
+         logging.info(f"Video is recorded and saved to {inference_config.output_dir / 'output.mp4'}")
+         pd.DataFrame(results).to_csv(inference_config.output_dir / "results.csv", index=False)
+         logging.info(f"Results saved to {inference_config.output_dir / 'results.csv'}")
+
+
+ def main(args: Namespace) -> None:
+     model_config = args.model
+     logging.info(model_config)
+     inference_config = args.inference
+     logging.info(inference_config)
+
+     if model_config.arch in POSE_BASED_MODELS:
+         inference_config.use_pose_model = True
+     else:
+         inference_config.use_pose_model = False
+
+     pipeline = load_pipeline(model_config, inference_config)
+     logging.info("Pipeline loaded")
+
+     inference(model_config, inference_config, pipeline)
+     logging.info("Inference completed")
+
+
+ if __name__ == "__main__":
+     try:
+         args = get_args()
+
+         config_logger(args.inference.output_dir / "inference.log")
+         logging.info(f"Config file loaded from {args.config_path[0]}")
+
+         shutil.copy(args.config_path[0], args.inference.output_dir / "inference.yaml")
+         logging.info(f"Config file saved to {args.inference.output_dir}")
+
+         main(args=args)
+     except Exception:
+         print(format_exc())
src/main.py ADDED
@@ -0,0 +1,51 @@
+ from fastapi import FastAPI, File, UploadFile, HTTPException
+ from fastapi.responses import JSONResponse
+ from pathlib import Path
+ import shutil
+ import logging
+ from inference import inference, get_args
+ from utils import config_logger
+ from tools import load_pipeline
+ from configs import ModelConfig, InferenceConfig
+
+ app = FastAPI()
+
+ @app.post("/upload-video/")
+ async def upload_video(file: UploadFile = File(...)):
+     if not file.filename.endswith(('.mp4', '.avi', '.mov', '.mkv')):
+         raise HTTPException(status_code=400, detail="Invalid file type. Only video files are allowed.")
+
+     # Save the uploaded file to a temporary location
+     temp_file_path = Path(f"temp_{file.filename}")
+     with temp_file_path.open("wb") as buffer:
+         shutil.copyfileobj(file.file, buffer)
+
+     # Load configurations
+     args = get_args()
+     model_config = args.model
+     inference_config = args.inference
+
+     # Update the source to the uploaded file
+     inference_config.source = temp_file_path
+
+     # Configure logger
+     config_logger(inference_config.output_dir / "inference.log")
+
+     # Load the pipeline
+     pipeline = load_pipeline(model_config, inference_config)
+
+     # Run inference
+     try:
+         inference(model_config, inference_config, pipeline)
+     except Exception as e:
+         logging.error(f"Error during inference: {str(e)}")
+         raise HTTPException(status_code=500, detail="Error during video processing")
+
+     # Clean up the temporary file
+     temp_file_path.unlink()
+
+     return JSONResponse(content={"message": "Video processed successfully"})
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=8000)
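A sketch of calling the `/upload-video/` endpoint above from a client, assuming the service is running locally on port 8000 as in the `__main__` block; the video path is a placeholder.

```python
import requests

with open("/path/to/your_video.mp4", "rb") as f:
    resp = requests.post("http://localhost:8000/upload-video/", files={"file": f})

print(resp.status_code, resp.json())  # {"message": "Video processed successfully"} on success
```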
src/tools/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .models import *
+ from .features import *
+ # from .utils import exists_on_hf
src/tools/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (203 Bytes).
src/tools/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (234 Bytes).
src/tools/__pycache__/features.cpython-39.pyc ADDED
Binary file (1.51 kB).
src/tools/__pycache__/models.cpython-312.pyc ADDED
Binary file (15.4 kB).
src/tools/__pycache__/models.cpython-39.pyc ADDED
Binary file (9.63 kB).
src/tools/features.py ADDED
@@ -0,0 +1,29 @@
+ import torch
+ from configs import DataConfig
+ from features import BaseDataset, VSL98Dataset, VSL400Dataset
+
+
+ def load_dataset(data_config: DataConfig) -> BaseDataset:
+     '''
+     Load the dataset class that matches `data_config.dataset`.
+     '''
+     datasets = {
+         "vsl_98": VSL98Dataset,
+         "vsl_400": VSL400Dataset,
+     }
+     return datasets[data_config.dataset](data_config)
+
+
+ def rgb_collate_fn(examples) -> dict:
+     # permute to (num_frames, num_channels, height, width)
+     pixel_values = torch.stack(
+         [example["video"].permute(1, 0, 2, 3) for example in examples]
+     )
+     labels = torch.tensor([example["label"] for example in examples])
+     return {"pixel_values": pixel_values, "labels": labels}
+
+
+ def pose_collate_fn(examples) -> dict:
+     # stack pose tensors and labels into a batch
+     poses = torch.stack([example["pose"] for example in examples])
+     labels = torch.tensor([example["label"] for example in examples])
+     return {"poses": poses, "labels": labels}
src/tools/models.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import logging
3
+ import onnxruntime as ort
4
+ from time import time
5
+ from typing import Union
6
+ from configs import ModelConfig, InferenceConfig
7
+ from utils import (
8
+ POSE_BASED_MODELS,
9
+ RGB_BASED_MODELS,
10
+ HUGGINGFACE_RGB_BASED_MODELS,
11
+ TORCHHUB_RGB_BASED_MODELS,
12
+ )
13
+ from transformers import (
14
+ ImageProcessingMixin,
15
+ FeatureExtractionMixin,
16
+ AutoModelForVideoClassification,
17
+ AutoModel,
18
+ Pipeline,
19
+ pipeline,
20
+ )
21
+ from transformers.pipelines import PIPELINE_REGISTRY
22
+ from visualization import draw_text_on_image
23
+ from utils import exists_on_hf
24
+ from models import (
25
+ Swin3DConfig, Swin3DImageProcessor, Swin3DForVideoClassification,
26
+ S3DConfig, S3DImageProcessor, S3DForVideoClassification,
27
+ VideoResNetConfig, VideoResNetImageProcessor, VideoResNetForVideoClassification,
28
+ MViTConfig, MViTImageProcessor, MViTForVideoClassification,
29
+ SLGCNConfig, SLGCNFeatureExtractor, SLGCNForGraphClassification,
30
+ SPOTERConfig, SPOTERFeatureExtractor, SPOTERForGraphClassification,
31
+ DSTASLRConfig, DSTASLRFeatureExtractor, DSTASLRForGraphClassification,
32
+ VideoMAEConfig, VideoMAEImageProcessor, VideoMAEForVideoClassification
33
+ )
34
+ from pipelines import (
35
+ VideoClassificationPipeline,
36
+ SLGCNGraphClassificationPipeline,
37
+ SPOTERGraphClassificationPipeline,
38
+ )
39
+
40
+
41
+ def load_model(
42
+ model_config: ModelConfig,
43
+ label2id: dict = None,
44
+ id2label: dict = None,
45
+ do_train: bool = False,
46
+ ) -> tuple:
47
+ '''
48
+ '''
49
+ if do_train:
50
+ if model_config.arch in POSE_BASED_MODELS:
51
+ return load_pose_model_for_training(model_config, label2id, id2label)
52
+ return load_rgb_model_for_training(model_config, label2id, id2label)
53
+
54
+ if model_config.arch in POSE_BASED_MODELS:
55
+ processor = FeatureExtractionMixin.from_pretrained(
56
+ model_config.pretrained,
57
+ trust_remote_code=True,
58
+ cache_dir="models/huggingface",
59
+ )
60
+ model = AutoModel.from_pretrained(
61
+ model_config.pretrained,
62
+ trust_remote_code=True,
63
+ cache_dir="models/huggingface",
64
+ )
65
+ else:
66
+ processor = ImageProcessingMixin.from_pretrained(
67
+ model_config.pretrained,
68
+ trust_remote_code=True,
69
+ cache_dir="models/huggingface",
70
+ )
71
+ model = AutoModelForVideoClassification.from_pretrained(
72
+ model_config.pretrained,
73
+ trust_remote_code=True,
74
+ cache_dir="models/huggingface",
75
+ )
76
+ model.eval()
77
+ return model.config, processor, model
78
+
79
+
80
+ def load_rgb_model_for_training(
81
+ model_config: ModelConfig,
82
+ label2id: dict = None,
83
+ id2label: dict = None,
84
+ ) -> tuple:
85
+ '''
86
+ '''
87
+ if model_config.arch in HUGGINGFACE_RGB_BASED_MODELS:
88
+ if model_config.arch == "videomae":
89
+ config_class = VideoMAEConfig
90
+ processor_class = VideoMAEImageProcessor
91
+ model_class = VideoMAEForVideoClassification
92
+ elif exists_on_hf(model_config.pretrained):
93
+ processor = ImageProcessingMixin.from_pretrained(
94
+ model_config.pretrained,
95
+ trust_remote_code=True,
96
+ cache_dir="models/huggingface",
97
+ )
98
+ model = AutoModelForVideoClassification.from_pretrained(
99
+ model_config.pretrained,
100
+ label2id,
101
+ id2label,
102
+ ignore_mismatched_sizes=True,
103
+ trust_remote_code=True,
104
+ cache_dir="models/huggingface",
105
+ )
106
+ return model.config, processor, model
107
+ elif model_config.arch in TORCHHUB_RGB_BASED_MODELS:
108
+ if model_config.arch in ['swin3d_t', 'swin3d_s', 'swin3d_b']:
109
+ config_class = Swin3DConfig
110
+ processor_class = Swin3DImageProcessor
111
+ model_class = Swin3DForVideoClassification
112
+ elif model_config.arch in ['r3d_18', 'mc3_18', 'r2plus1d_18']:
113
+ config_class = VideoResNetConfig
114
+ processor_class = VideoResNetImageProcessor
115
+ model_class = VideoResNetForVideoClassification
116
+ elif model_config.arch in ['s3d']:
117
+ config_class = S3DConfig
118
+ processor_class = S3DImageProcessor
119
+ model_class = S3DForVideoClassification
120
+ elif model_config.arch in ['mvit_v1_b', 'mvit_v2_s']:
121
+ config_class = MViTConfig
122
+ processor_class = MViTImageProcessor
123
+ model_class = MViTForVideoClassification
124
+ else:
125
+ logging.error(f"Model {model_config.arch} is not supported")
126
+ exit(1)
127
+
128
+ config_class.register_for_auto_class()
129
+ processor_class.register_for_auto_class("AutoImageProcessor")
130
+ model_class.register_for_auto_class("AutoModel")
131
+ model_class.register_for_auto_class("AutoModelForVideoClassification")
132
+ logging.info(f"{model_config.arch} classes registered")
133
+
134
+ config = config_class(**vars(model_config))
135
+ processor = processor_class(config=config)
136
+ model = model_class(config=config, label2id=label2id, id2label=id2label)
137
+
138
+ return config, processor, model
139
+
140
+
141
+ def load_pose_model_for_training(
142
+ model_config: ModelConfig,
143
+ label2id: dict = None,
144
+ id2label: dict = None,
145
+ ) -> tuple:
146
+ '''
147
+ '''
148
+ if exists_on_hf(model_config.pretrained):
149
+ processor = FeatureExtractionMixin.from_pretrained(
150
+ model_config.pretrained,
151
+ trust_remote_code=True,
152
+ cache_dir="models/huggingface",
153
+ )
154
+ model = AutoModel.from_pretrained(
155
+ model_config.pretrained,
156
+ label2id=label2id,
157
+ id2label=id2label,
158
+ ignore_mismatched_sizes=True,
159
+ trust_remote_code=True,
160
+ cache_dir="models/huggingface",
161
+ )
162
+ return model.config, processor, model
163
+ elif model_config.arch in POSE_BASED_MODELS:
164
+ if model_config.arch == "spoter":
165
+ config_class = SPOTERConfig
166
+ processor_class = SPOTERFeatureExtractor
167
+ model_class = SPOTERForGraphClassification
168
+ elif model_config.arch == "sl_gcn":
169
+ config_class = SLGCNConfig
170
+ processor_class = SLGCNFeatureExtractor
171
+ model_class = SLGCNForGraphClassification
172
+ elif model_config.arch == "dsta_slr":
173
+ config_class = DSTASLRConfig
174
+ processor_class = DSTASLRFeatureExtractor
175
+ model_class = DSTASLRForGraphClassification
176
+ else:
177
+ logging.error(f"Model {model_config.arch} is not supported")
178
+ exit(1)
179
+
180
+ config_class.register_for_auto_class()
181
+ processor_class.register_for_auto_class("AutoFeatureExtractor")
182
+ model_class.register_for_auto_class("AutoModel")
183
+ logging.info(F"Registering {model_config.arch} classes")
184
+
185
+ config = config_class(**vars(model_config))
186
+ processor = processor_class(config=config)
187
+ model = model_class(config=config, label2id=label2id, id2label=id2label)
188
+
189
+ return config, processor, model
190
+
191
+
192
+ class Predictions:
193
+ def __init__(
194
+ self,
195
+ predictions: list[dict] = None,
196
+ inference_time: float = 0,
197
+ start_time: float = 0,
198
+ end_time: float = 0,
199
+ ) -> None:
200
+ self.predictions = predictions
201
+ self.inference_time = inference_time
202
+ self.start_time = start_time
203
+ self.end_time = end_time
204
+
205
+ def visualize(
206
+ self,
207
+ frame: torch.Tensor,
208
+ position: tuple = (20, 100),
209
+ prefix: str = "Predictions",
210
+ color: tuple = (0, 0, 255),
211
+ ) -> None:
212
+ text = prefix + ": " + self.get_pred_message()
213
+ return draw_text_on_image(
214
+ image=frame,
215
+ text=text,
216
+ position=position,
217
+ color=color,
218
+ font_size=20,
219
+ )
220
+
221
+ def get_pred_message(self) -> str:
222
+ if not any((
223
+ self.start_time,
224
+ self.end_time,
225
+ self.inference_time,
226
+ self.predictions
227
+ )):
228
+ return ""
229
+
230
+ return ', '.join(
231
+ [
232
+ f"{pred['gloss']} ({pred['score']*100:.2f}%)"
233
+ for pred in self.predictions
234
+ ]
235
+ )
236
+
237
+ def __str__(self) -> str:
238
+ if not any((
239
+ self.start_time,
240
+ self.end_time,
241
+ self.inference_time,
242
+ self.predictions
243
+ )):
244
+ return ""
245
+
246
+ predictions = self.get_pred_message()
247
+ message = "Sample start: {:.2f}s - end: {:.2f}s | Runtime: {:.2f}s | Predictions: {}"
248
+ return message.format(self.start_time, self.end_time, self.inference_time, predictions)
249
+
250
+ def merge_results(self, results: dict = None) -> dict:
251
+ if results is None:
252
+ results = {
253
+ "start_time": [],
254
+ "end_time": [],
255
+ "inference_time": [],
256
+ "prediction": [],
257
+ }
258
+ results["start_time"].append(self.start_time)
259
+ results["end_time"].append(self.end_time)
260
+ results["inference_time"].append(self.inference_time)
261
+ results["prediction"].append(self.predictions)
262
+ return results
263
+
264
+
265
+ def get_predictions(
266
+ inputs: torch.Tensor,
267
+ model: Union[ort.InferenceSession, AutoModel],
268
+ id2gloss: dict,
269
+ k: int = 3,
270
+ ) -> Predictions:
271
+ '''
272
+ Get the top-k predictions.
273
+ Parameters
274
+ ----------
275
+ inputs : torch.Tensor
276
+ Model inputs (Time, Height, Width, Channels).
277
+ model : Union[ort.InferenceSession, AutoModel]
278
+ Model to get predictions from.
279
+ id2gloss : dict
280
+ Mapping of class indices to glosses.
281
+ k : int, optional
282
+ Number of predictions to return, by default 3.
283
+ Returns
284
+ -------
285
+ tuple
286
+ List of top-k predictions and inference time.
287
+ '''
288
+ if inputs is None:
289
+ return Predictions()
290
+
291
+ # Get logits
292
+ start_time = time()
293
+ if isinstance(model, ort.InferenceSession):
294
+ inputs = inputs.cpu().numpy()
295
+ logits = torch.from_numpy(model.run(None, {"pixel_values": inputs})[0])
296
+ else:
297
+ logits = model(inputs.to(model.device)).logits
298
+ inference_time = time() - start_time
299
+
300
+ # Get top-3 predictions
301
+ topk_scores, topk_indices = torch.topk(logits, k, dim=1)
302
+ topk_scores = torch.nn.functional.softmax(topk_scores, dim=1).squeeze().detach().numpy()
303
+ topk_indices = topk_indices.squeeze().detach().numpy()
304
+ predictions = [
305
+ {
306
+ 'gloss': id2gloss[str(topk_indices[i])],
307
+ 'score': topk_scores[i],
308
+ }
309
+ for i in range(k)
310
+ ]
311
+
312
+ return Predictions(predictions=predictions, inference_time=inference_time)
313
+
314
+
315
+ def register_pipeline(model_config: ModelConfig) -> Pipeline:
316
+ '''
317
+ '''
318
+ _, processor, model = load_model(model_config)
319
+
320
+ if model_config.arch == "spoter":
321
+ PIPELINE_REGISTRY.register_pipeline(
322
+ "video-classification",
323
+ pipeline_class=SPOTERGraphClassificationPipeline,
324
+ pt_model=AutoModel,
325
+ default={"pt": ("vsltranslation/spoter_v3.0", "main")},
326
+ type="multimodal",
327
+ )
328
+ return SPOTERGraphClassificationPipeline(
329
+ model=model,
330
+ feature_extractor=processor,
331
+ )
332
+ elif model_config.arch in ["sl_gcn", "dsta_slr"]:
333
+ PIPELINE_REGISTRY.register_pipeline(
334
+ "video-classification",
335
+ pipeline_class=SLGCNGraphClassificationPipeline,
336
+ pt_model=AutoModel,
337
+ default={"pt": ("vsltranslation/sl_gcn_joint_v1.0", "main")},
338
+ type="multimodal",
339
+ )
340
+ return SLGCNGraphClassificationPipeline(
341
+ model=model,
342
+ feature_extractor=processor,
343
+ )
344
+
345
+ PIPELINE_REGISTRY.register_pipeline(
346
+ "video-classification",
347
+ pipeline_class=VideoClassificationPipeline,
348
+ pt_model=AutoModelForVideoClassification,
349
+ default={"pt": ("vsltranslation/swin3d_t_v1.0", "main")},
350
+ type="multimodal",
351
+ )
352
+ return VideoClassificationPipeline(
353
+ model=model,
354
+ image_processor=processor,
355
+ )
356
+
357
+
+ def load_pipeline(
+     model_config: ModelConfig,
+     inference_config: InferenceConfig,
+ ) -> Pipeline:
+     '''
+     Load a video-classification pipeline for the configured model, passing the
+     pose-specific keyword arguments only for pose-based architectures.
+     '''
+     if model_config.arch in POSE_BASED_MODELS:
+         return pipeline(
+             "video-classification",
+             model=model_config.pretrained,
+             feature_extractor=model_config.pretrained,
+             device=inference_config.device,
+             model_kwargs={
+                 "cache_dir": inference_config.cache_dir,
+             },
+             trust_remote_code=True,
+             use_onnx=inference_config.use_onnx,
+             top_k=inference_config.top_k,
+             bone_stream=inference_config.bone_stream,
+             motion_stream=inference_config.motion_stream,
+         )
+ 
+     return pipeline(
+         "video-classification",
+         model=model_config.pretrained,
+         image_processor=model_config.pretrained,
+         device=inference_config.device,
+         model_kwargs={
+             "cache_dir": inference_config.cache_dir,
+         },
+         trust_remote_code=True,
+         use_onnx=inference_config.use_onnx,
+         top_k=inference_config.top_k,
+     )
+ 
+ 
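A hedged example of driving load_pipeline; the ModelConfig / InferenceConfig field names are inferred from how they are accessed above and may not match the dataclasses in src/configs/arguments.py exactly.

# Sketch: constructor arguments are inferred, not copied from src/configs/arguments.py.
model_config = ModelConfig(arch="sl_gcn", pretrained="vsltranslation/sl_gcn_joint_v1.0")
inference_config = InferenceConfig(
    device="cpu",
    cache_dir=".cache",
    use_onnx=True,
    top_k=3,
    bone_stream=False,
    motion_stream=False,
)
classifier = load_pipeline(model_config, inference_config)
result = classifier("path/to/sign_video.mp4")   # hypothetical input video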
+ def get_input_shape(
+     arch: str,
+     processor: Union[ImageProcessingMixin, FeatureExtractionMixin],
+     batch_size: int = 1,
+ ) -> tuple:
+     '''
+     Get the input shape for the model.
+ 
+     Parameters
+     ----------
+     arch : str
+         Model architecture name.
+     processor : Union[ImageProcessingMixin, FeatureExtractionMixin]
+         Model processor.
+     batch_size : int, optional
+         Batch size, by default 1.
+ 
+     Returns
+     -------
+     tuple
+         Input shape.
+     '''
+     if arch in RGB_BASED_MODELS:
+         return (
+             batch_size,
+             processor.num_frames,
+             3,
+             processor.size["height"],
+             processor.size["width"],
+         )
+     elif arch in POSE_BASED_MODELS:
+         if arch == "spoter":
+             return (
+                 batch_size,
+                 processor.num_frames,
+                 processor.num_points,
+                 processor.in_channels,
+             )
+         elif arch in ["sl_gcn", "dsta_slr"]:
+             return (
+                 batch_size,
+                 processor.in_channels,
+                 processor.window_size,
+                 processor.num_points,
+                 processor.num_people,
+             )
+         else:
+             logging.error(f"Model {arch} is not supported")
+             exit(1)
+     else:
+         logging.error(f"Model {arch} is not supported")
+         exit(1)
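For reference, given a processor returned by load_model, the shapes above work out as follows; the commented dimensions are only indicative, since concrete values come from the processor config.

# Indicative shapes only; real values are read from the loaded processor.
# RGB models (e.g. swin3d_t): (batch, num_frames, 3, height, width)
# spoter:                     (batch, num_frames, num_points, in_channels)
# sl_gcn / dsta_slr:          (batch, in_channels, window_size, num_points, num_people)
dummy = torch.randn(*get_input_shape("dsta_slr", processor))   # e.g. for a warm-up pass or ONNX export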
src/utils/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .loggers import *
+ from .constants import *
src/utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (205 Bytes).
src/utils/__pycache__/constants.cpython-312.pyc ADDED
Binary file (4.35 kB).
src/utils/__pycache__/loggers.cpython-312.pyc ADDED
Binary file (1.59 kB).
src/utils/constants.py ADDED
@@ -0,0 +1,158 @@
+ import numpy as np
+ 
+ 
+ VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
+ 
+ TORCHHUB_RGB_BASED_MODELS = (
+     'swin3d_t',
+     'swin3d_s',
+     'swin3d_b',
+     "r3d_18",
+     "mc3_18",
+     "r2plus1d_18",
+     "s3d",
+     "mvit_v1_b",
+     "mvit_v2_s",
+ )
+ HUGGINGFACE_RGB_BASED_MODELS = (
+     "videomae",
+ )
+ RGB_BASED_MODELS = HUGGINGFACE_RGB_BASED_MODELS + TORCHHUB_RGB_BASED_MODELS
+ 
+ POSE_BASED_MODELS = (
+     "spoter",
+     "sl_gcn",
+     "dsta_slr"
+ )
+ 
+ MODELS = RGB_BASED_MODELS + POSE_BASED_MODELS
+ 
+ HAND_LANDMARKS = [
+     "wrist",
+     "indexTip",
+     "indexDIP",
+     "indexPIP",
+     "indexMCP",
+     "middleTip",
+     "middleDIP",
+     "middlePIP",
+     "middleMCP",
+     "ringTip",
+     "ringDIP",
+     "ringPIP",
+     "ringMCP",
+     "littleTip",
+     "littleDIP",
+     "littlePIP",
+     "littleMCP",
+     "thumbTip",
+     "thumbIP",
+     "thumbMP",
+     "thumbCMC",
+ ]
+ BODY_LANDMARKS = [
+     "nose",
+     "neck",
+     "rightEye",
+     "leftEye",
+     "rightEar",
+     "leftEar",
+     "rightShoulder",
+     "leftShoulder",
+     "rightElbow",
+     "leftElbow",
+     "rightWrist",
+     "leftWrist",
+ ]
+ ARM_LANDMARKS_ORDER = ["neck", "$side$Shoulder", "$side$Elbow", "$side$Wrist"]
+ 
+ FLIP_IDXS = np.concatenate(
+     (
+         [0, 2, 1, 4, 3, 6, 5],
+         [17, 18, 19, 20, 21, 22, 23, 24, 25, 26],
+         [7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
+     ),
+     axis=0,
+ )
+ 
+ SLGCN_JOINTS = {
+     59: np.concatenate((np.arange(0, 17), np.arange(91, 133)), axis=0),  # 59
+     31: np.concatenate(
+         (
+             np.arange(0, 11),
+             [91, 95, 96, 99, 100, 103, 104, 107, 108, 111],
+             [112, 116, 117, 120, 121, 124, 125, 128, 129, 132],
+         ),
+         axis=0,
+     ),  # 31
+     27: np.concatenate(
+         (
+             [0, 5, 6, 7, 8, 9, 10],
+             [91, 95, 96, 99, 100, 103, 104, 107, 108, 111],
+             [112, 116, 117, 120, 121, 124, 125, 128, 129, 132],
+         ),
+         axis=0,
+     ),  # 27
+ }
+ 
+ COCO_TO_POSE_FORMAT = {
+     0: ("POSE_LANDMARKS", "NOSE"),
+     1: ("POSE_LANDMARKS", "LEFT_EYE"),
+     2: ("POSE_LANDMARKS", "RIGHT_EYE"),
+     3: ("POSE_LANDMARKS", "LEFT_EAR"),
+     4: ("POSE_LANDMARKS", "RIGHT_EAR"),
+     5: ("POSE_LANDMARKS", "LEFT_SHOULDER"),
+     6: ("POSE_LANDMARKS", "RIGHT_SHOULDER"),
+     7: ("POSE_LANDMARKS", "LEFT_ELBOW"),
+     8: ("POSE_LANDMARKS", "RIGHT_ELBOW"),
+     9: ("POSE_LANDMARKS", "LEFT_WRIST"),
+     10: ("POSE_LANDMARKS", "RIGHT_WRIST"),
+     11: ("POSE_LANDMARKS", "LEFT_HIP"),
+     12: ("POSE_LANDMARKS", "RIGHT_HIP"),
+     13: ("POSE_LANDMARKS", "LEFT_KNEE"),
+     14: ("POSE_LANDMARKS", "RIGHT_KNEE"),
+     15: ("POSE_LANDMARKS", "LEFT_ANKLE"),
+     16: ("POSE_LANDMARKS", "RIGHT_ANKLE"),
+     91: ("LEFT_HAND_LANDMARKS", "WRIST"),
+     92: ("LEFT_HAND_LANDMARKS", "THUMB_CMC"),
+     93: ("LEFT_HAND_LANDMARKS", "THUMB_MCP"),
+     94: ("LEFT_HAND_LANDMARKS", "THUMB_IP"),
+     95: ("LEFT_HAND_LANDMARKS", "THUMB_TIP"),
+     96: ("LEFT_HAND_LANDMARKS", "INDEX_FINGER_MCP"),
+     97: ("LEFT_HAND_LANDMARKS", "INDEX_FINGER_PIP"),
+     98: ("LEFT_HAND_LANDMARKS", "INDEX_FINGER_DIP"),
+     99: ("LEFT_HAND_LANDMARKS", "INDEX_FINGER_TIP"),
+     100: ("LEFT_HAND_LANDMARKS", "MIDDLE_FINGER_MCP"),
+     101: ("LEFT_HAND_LANDMARKS", "MIDDLE_FINGER_PIP"),
+     102: ("LEFT_HAND_LANDMARKS", "MIDDLE_FINGER_DIP"),
+     103: ("LEFT_HAND_LANDMARKS", "MIDDLE_FINGER_TIP"),
+     104: ("LEFT_HAND_LANDMARKS", "RING_FINGER_MCP"),
+     105: ("LEFT_HAND_LANDMARKS", "RING_FINGER_PIP"),
+     106: ("LEFT_HAND_LANDMARKS", "RING_FINGER_DIP"),
+     107: ("LEFT_HAND_LANDMARKS", "RING_FINGER_TIP"),
+     108: ("LEFT_HAND_LANDMARKS", "PINKY_MCP"),
+     109: ("LEFT_HAND_LANDMARKS", "PINKY_PIP"),
+     110: ("LEFT_HAND_LANDMARKS", "PINKY_DIP"),
+     111: ("LEFT_HAND_LANDMARKS", "PINKY_TIP"),
+     112: ("RIGHT_HAND_LANDMARKS", "WRIST"),
+     113: ("RIGHT_HAND_LANDMARKS", "THUMB_CMC"),
+     114: ("RIGHT_HAND_LANDMARKS", "THUMB_MCP"),
+     115: ("RIGHT_HAND_LANDMARKS", "THUMB_IP"),
+     116: ("RIGHT_HAND_LANDMARKS", "THUMB_TIP"),
+     117: ("RIGHT_HAND_LANDMARKS", "INDEX_FINGER_MCP"),
+     118: ("RIGHT_HAND_LANDMARKS", "INDEX_FINGER_PIP"),
+     119: ("RIGHT_HAND_LANDMARKS", "INDEX_FINGER_DIP"),
+     120: ("RIGHT_HAND_LANDMARKS", "INDEX_FINGER_TIP"),
+     121: ("RIGHT_HAND_LANDMARKS", "MIDDLE_FINGER_MCP"),
+     122: ("RIGHT_HAND_LANDMARKS", "MIDDLE_FINGER_PIP"),
+     123: ("RIGHT_HAND_LANDMARKS", "MIDDLE_FINGER_DIP"),
+     124: ("RIGHT_HAND_LANDMARKS", "MIDDLE_FINGER_TIP"),
+     125: ("RIGHT_HAND_LANDMARKS", "RING_FINGER_MCP"),
+     126: ("RIGHT_HAND_LANDMARKS", "RING_FINGER_PIP"),
+     127: ("RIGHT_HAND_LANDMARKS", "RING_FINGER_DIP"),
+     128: ("RIGHT_HAND_LANDMARKS", "RING_FINGER_TIP"),
+     129: ("RIGHT_HAND_LANDMARKS", "PINKY_MCP"),
+     130: ("RIGHT_HAND_LANDMARKS", "PINKY_PIP"),
+     131: ("RIGHT_HAND_LANDMARKS", "PINKY_DIP"),
+     132: ("RIGHT_HAND_LANDMARKS", "PINKY_TIP"),
+ }
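As an illustration of how the joint-index tables are intended to be used, here is a sketch; the keypoint array is random, standing in for a COCO-WholeBody style (frames, 133, channels) pose sequence.

import numpy as np

pose_sequence = np.random.rand(150, 133, 3)        # placeholder (T, 133, C) keypoints
reduced = pose_sequence[:, SLGCN_JOINTS[27], :]    # keep the 27-joint SL-GCN subset -> (150, 27, 3)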
src/utils/loggers.py ADDED
@@ -0,0 +1,24 @@
+ import sys
+ import logging
+ from pathlib import Path
+ from transformers import TrainerCallback
+ 
+ 
+ class TrainingCallback(TrainerCallback):
+     def on_log(self, args, state, control, logs=None, **kwargs):
+         logging.info(logs)
+ 
+ 
+ def config_logger(log_file: str = None) -> None:
+     handlers = [logging.StreamHandler(sys.stdout)]
+     if log_file is not None:
+         log_dir = Path(log_file).parent
+         if not log_dir.exists():
+             log_dir.mkdir(parents=True, exist_ok=True)
+         handlers.append(logging.FileHandler(filename=log_file))
+     logging.basicConfig(
+         datefmt="%m/%d/%Y %H:%M:%S",
+         level=logging.INFO,
+         format="[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s",
+         handlers=handlers
+     )
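A typical call, assuming the caller wants both console and file logging; the log path below is illustrative.

import logging

config_logger(log_file="logs/inference.log")   # hypothetical path; omit to log to stdout only
logging.info("Pipeline initialised")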
src/visualization/__init__.py ADDED
@@ -0,0 +1 @@
+ from .utils import *
src/visualization/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (187 Bytes).
src/visualization/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (221 Bytes).
src/visualization/__pycache__/utils.cpython-312.pyc ADDED
Binary file (2.44 kB).
src/visualization/__pycache__/utils.cpython-39.pyc ADDED
Binary file (1.7 kB).
src/visualization/utils.py ADDED
@@ -0,0 +1,55 @@
+ import torch
+ import numpy as np
+ from imageio import mimsave
+ from PIL import Image, ImageDraw, ImageFont
+ 
+ 
+ def unnormalize_img(image: np.ndarray, std: tuple, mean: tuple) -> np.ndarray:
+     image = (image * std) + mean
+     image = np.clip(image * 255, 0, 255).astype('uint8')  # clip before casting to avoid uint8 overflow
+     return image
+ 
+ 
+ def save_as_gif(
+     video_tensor: torch.Tensor,
+     save_path: str = 'sample.gif',
+     std: tuple = None,
+     mean: tuple = None,
+ ):
+     frames = []
+     for video_frame in video_tensor:
+         frame_unnormalized = unnormalize_img(
+             image=video_frame.permute(1, 2, 0).numpy(),
+             std=std,
+             mean=mean,
+         )
+         frames.append(frame_unnormalized)
+     kwargs = {'duration': 0.25}
+     mimsave(save_path, frames, 'GIF', **kwargs)
+     return save_path
+ 
+ 
+ def display_gif(gif_path: str) -> Image.Image:
+     return Image.open(gif_path)  # PIL's Image module is not callable; open the file instead
+ 
+ 
+ def draw_text_on_image(
+     image: np.ndarray,
+     text: str,
+     position: tuple = (20, 20),
+     color: tuple = (0, 0, 255),
+     font_size: int = 20,
+ ) -> np.ndarray:
+     font = ImageFont.truetype(
+         font="fonts/OpenSans-Regular.ttf",
+         size=font_size,
+     )
+     pil_image = Image.fromarray(image)
+     draw = ImageDraw.Draw(pil_image)
+     draw.text(
+         xy=position,
+         text=text,
+         fill=color,
+         font=font,
+     )
+     return np.array(pil_image)
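A short usage sketch for these helpers; the frame contents and the predicted gloss/score are placeholders, and draw_text_on_image expects fonts/OpenSans-Regular.ttf to exist on disk.

import numpy as np

frame = np.zeros((480, 640, 3), dtype=np.uint8)                 # placeholder video frame
annotated = draw_text_on_image(frame, text="xin chào (0.92)")   # overlay a hypothetical prediction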