votuongquan2004@gmail.com committed on
Commit
e317171
1 Parent(s): ede92c4
Files changed (9)
  1. .gitattributes copy +35 -0
  2. .gitignore +160 -0
  3. README copy.md +12 -0
  4. VSL_SAM_SLR_V2.onnx +3 -0
  5. app.py +99 -0
  6. gloss.csv +100 -0
  7. requirements.txt +9 -0
  8. utils/data.py +325 -0
  9. utils/model.py +47 -0
.gitattributes copy ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,160 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
README copy.md ADDED
@@ -0,0 +1,12 @@
+ ---
+ title: SAM SLR V2
+ emoji: 📉
+ colorFrom: blue
+ colorTo: red
+ sdk: gradio
+ sdk_version: 4.29.0
+ app_file: app.py
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
VSL_SAM_SLR_V2.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:509a2c017b6539366c59ee8a90408f62cec2484bb10078241f0af1845263ad90
+ size 16709176
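The three lines above are a Git LFS pointer; the actual ~16 MB ONNX graph is stored in LFS (routed there by the `*.onnx` rule in `.gitattributes copy`). A minimal sketch, assuming the LFS object has been pulled locally, for inspecting the input and output signatures the exported model expects:

```python
import onnxruntime as ort

# Load the exported graph and print its tensor signatures.
# app.py later feeds a single input named 'x' via ort_session.run(None, {'x': inputs}).
session = ort.InferenceSession('VSL_SAM_SLR_V2.onnx')
for tensor in session.get_inputs():
    print('input :', tensor.name, tensor.shape, tensor.type)
for tensor in session.get_outputs():
    print('output:', tensor.name, tensor.shape, tensor.type)
```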
app.py ADDED
@@ -0,0 +1,99 @@
+ import time
+ import pandas as pd
+ import gradio as gr
+ import onnxruntime as ort
+ from mediapipe.python.solutions import holistic
+ from utils.data import preprocess
+ from utils.model import get_predictions
+
+ title = '''
+
+ '''
+
+ cite_markdown = '''
+
+ '''
+
+ description = '''
+
+ '''
+
+ examples = [
+     ['000_con_cho.mp4'],
+ ]
+
+ # Load the ONNX model.
+ ort_session = ort.InferenceSession('VSL_SAM_SLR_V2.onnx')
+
+ # Load the id-to-gloss mapping.
+ id2gloss = pd.read_csv('gloss.csv', names=['id', 'gloss']).to_dict()['gloss']
+
+
+ def inference(
+     video: str,
+     progress: gr.Progress = gr.Progress(),
+ ) -> str:
+     '''
+     Video-based inference for Vietnamese Sign Language recognition.
+
+     Parameters
+     ----------
+     video : str
+         The path to the video.
+     progress : gr.Progress, optional
+         The progress bar, by default gr.Progress().
+
+     Returns
+     -------
+     str
+         The inference message.
+     '''
+     keypoints_detector = holistic.Holistic(
+         static_image_mode=False,
+         model_complexity=2,
+         enable_segmentation=True,
+         refine_face_landmarks=True,
+     )
+
+     progress(0, desc='Preprocessing video')
+     start_time = time.time()
+     inputs = preprocess(
+         source=video,
+         keypoints_detector=keypoints_detector,
+     )
+     end_time = time.time()
+     data_time = end_time - start_time
+
+     progress(1/2, desc='Getting predictions')
+     start_time = time.time()
+     predictions = get_predictions(
+         inputs=inputs, ort_session=ort_session, id2gloss=id2gloss, k=3
+     )
+     end_time = time.time()
+     model_time = end_time - start_time
+
+     if len(predictions) == 0:
+         output_message = 'No sign language detected in the video. Please try again.'
+     else:
+         output_message = 'The top-3 predictions are:\n'
+         for i, prediction in enumerate(predictions):
+             output_message += f'\t{i+1}. {prediction["label"]} ({prediction["score"]:.2f})\n'
+         output_message += f'Data processing time: {data_time:.2f} seconds\n'
+         output_message += f'Model inference time: {model_time:.2f} seconds\n'
+         output_message += f'Total time: {data_time + model_time:.2f} seconds'
+
+     progress(1, desc='Completed')
+
+     return output_message
+
+
+ iface = gr.Interface(
+     fn=inference,
+     inputs='video',
+     outputs='text',
+     examples=examples,
+     title=title,
+     description=description,
+ )
+ iface.launch()
+ # print(inference('000_con_cho.mp4'))
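The commented-out last line hints at running the pipeline without the Gradio UI. A minimal sketch of that idea (the file name `run_local.py` and the reuse of the bundled example video are assumptions, not part of this commit):

```python
# run_local.py -- hypothetical helper; wires the same pieces as app.py without Gradio.
import onnxruntime as ort
import pandas as pd
from mediapipe.python.solutions import holistic

from utils.data import preprocess
from utils.model import get_predictions

ort_session = ort.InferenceSession('VSL_SAM_SLR_V2.onnx')
id2gloss = pd.read_csv('gloss.csv', names=['id', 'gloss']).to_dict()['gloss']
detector = holistic.Holistic(static_image_mode=False, model_complexity=2)

# Extract keypoints from the clip and rank the top-3 glosses.
inputs = preprocess(source='000_con_cho.mp4', keypoints_detector=detector)
for prediction in get_predictions(inputs, ort_session, id2gloss, k=3):
    print(f"{prediction['label']}: {prediction['score']:.2f}")
```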
gloss.csv ADDED
@@ -0,0 +1,100 @@
+ 0,Con chó
+ 1,Con mèo
+ 2,Con gà
+ 3,Con vịt
+ 4,Con rùa
+ 5,Con thỏ
+ 6,Con trâu
+ 7,Con bò
+ 8,Con dê
+ 9,Con heo
+ 10,Màu đen
+ 11,Màu trắng
+ 12,Màu đỏ
+ 13,Màu cam
+ 14,Màu vàng
+ 15,Màu lá cây
+ 16,Màu da trời
+ 17,Màu hồng
+ 18,Màu tím
+ 19,Màu nâu
+ 20,Quả dâu
+ 21,Quả mận
+ 22,Quả dứa
+ 23,Quả đào
+ 24,Quả đu đủ
+ 25,Quả cam
+ 26,Quả bơ
+ 27,Quả chuối
+ 28,Quả xoài
+ 29,Quả dừa
+ 30,Bố
+ 31,Mẹ
+ 32,Con trai
+ 33,Con gái
+ 34,Vợ
+ 35,Chồng
+ 36,Ông nội
+ 37,Bà nội
+ 38,Ông ngoại
+ 39,Bà ngoại
+ 40,Ăn
+ 41,Uống
+ 42,Xem
+ 43,Thèm
+ 44,Mách
+ 45,Khóc
+ 46,Cười
+ 47,Học
+ 48,Dỗi
+ 49,Chết
+ 50,Đi
+ 51,Chạy
+ 52,Bận
+ 53,Hát
+ 54,Múa
+ 55,Nấu
+ 56,Nướng
+ 57,Nhầm lẫn
+ 58,Quan sát
+ 59,Cắm trại
+ 60,Cung cấp
+ 61,Bắt chước
+ 62,Bắt buộc
+ 63,Báo cáo
+ 64,Mua bán
+ 65,Không quen
+ 66,Không nên
+ 67,Không cần
+ 68,Không cho
+ 69,Không nghe lời
+ 70,Mặn
+ 71,Đắng
+ 72,Cay
+ 73,Ngọt
+ 74,Đậm
+ 75,Nhạt
+ 76,Ngon miệng
+ 77,Xấu
+ 78,Đẹp
+ 79,Chật
+ 80,Hẹp
+ 81,Rộng
+ 82,Dài
+ 83,Cao
+ 84,Lùn
+ 85,Ốm
+ 86,Mập
+ 87,Ngoan
+ 88,Hư
+ 89,Khỏe
+ 90,Mệt
+ 91,Đau
+ 92,Giỏi
+ 93,Chăm chỉ
+ 94,Lười biếng
+ 95,Tốt bụng
+ 96,Thú vị
+ 97,Hài hước
+ 98,Dũng cảm
+ 99,Sáng tạo
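gloss.csv has no header row, so app.py loads it with explicit column names. A minimal sketch of what that line produces (the example values are taken from the rows above):

```python
import pandas as pd

# names=['id', 'gloss'] labels the two headerless columns; DataFrame.to_dict()
# keys each column by the default RangeIndex, so ['gloss'] maps class index -> label.
id2gloss = pd.read_csv('gloss.csv', names=['id', 'gloss']).to_dict()['gloss']

print(id2gloss[0])   # 'Con chó'
print(id2gloss[99])  # 'Sáng tạo'
```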
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ numpy
4
+ opencv-python
5
+ mediapipe
6
+ timm
7
+ einops
8
+ yacs
9
+ onnxruntime
utils/data.py ADDED
@@ -0,0 +1,325 @@
+ import cv2
+ import random
+ import numpy as np
+ from mediapipe.python.solutions import pose
+
+
+ SELECTED_JOINTS = {
+     27: {
+         'pose': [0, 11, 12, 13, 14, 15, 16],
+         'hand': [0, 4, 5, 8, 9, 12, 13, 16, 17, 20],
+     },  # 27
+ }
+
+
+ def pad(joints: np.ndarray, num_frames: int = 150) -> np.ndarray:
+     '''
+     Add padding to the joints.
+
+     Parameters
+     ----------
+     joints : np.ndarray
+         The joints to pad.
+     num_frames : int, default=150
+         The number of frames to pad to.
+
+     Returns
+     -------
+     np.ndarray
+         The padded joints.
+     '''
+     if joints.shape[0] < num_frames:
+         L = joints.shape[0]
+         padded_joints = np.zeros((num_frames, joints.shape[1], joints.shape[2]))
+         padded_joints[:L, :, :] = joints
+         rest = num_frames - L
+         num = int(np.ceil(rest / L))
+         # Repeat the sequence until the remaining frames are filled.
+         padding = np.concatenate([joints for _ in range(num)], 0)[:rest]
+         padded_joints[L:, :, :] = padding
+     else:
+         padded_joints = joints[:num_frames]
+     return padded_joints
+
+
+ def extract_joints(
+     source: str,
+     keypoints_detector,
+     resize_to: tuple = (256, 256),
+     num_joints: int = 27,
+     num_frames: int = 150,
+     num_bodies: int = 1,
+     num_channels: int = 3,
+ ) -> np.ndarray:
+     '''
+     Extract the joints from the video.
+
+     Parameters
+     ----------
+     source : str
+         The path to the video.
+     keypoints_detector : mediapipe.python.solutions.holistic.Holistic
+         The keypoints detector.
+     resize_to : tuple, default=(256, 256)
+         The size to resize the frames to.
+     num_joints : int, default=27
+         The number of joints.
+     num_frames : int, default=150
+         The number of frames.
+     num_bodies : int, default=1
+         The number of bodies.
+     num_channels : int, default=3
+         The number of channels.
+
+     Returns
+     -------
+     np.ndarray
+         The extracted joints with shape (C, T, V, M).
+     '''
+     cap = cv2.VideoCapture(source)
+
+     extracted_joints = []
+     while cap.isOpened():
+         success, image = cap.read()
+         if not success:
+             break
+         image = cv2.resize(image, resize_to)
+         image = cv2.flip(image, flipCode=1)
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         frame_joints = []
+
+         results = keypoints_detector.process(image)
+
+         # Upper-body pose joints (zeros if no person is detected in the frame).
+         pose_joints = [(0.0, 0.0, 0.0)] * len(SELECTED_JOINTS[num_joints]['pose'])
+         if results.pose_landmarks is not None:
+             pose_joints = [
+                 (landmark.x * resize_to[0], landmark.y * resize_to[1], landmark.visibility)
+                 for i, landmark in enumerate(results.pose_landmarks.landmark)
+                 if i in SELECTED_JOINTS[num_joints]['pose']
+             ]
+         frame_joints.extend(pose_joints)
+
+         left_hand = [(0.0, 0.0, 0.0)] * len(SELECTED_JOINTS[num_joints]['hand'])
+         if results.left_hand_landmarks is not None:
+             left_hand = [
+                 (landmark.x * resize_to[0], landmark.y * resize_to[1], landmark.visibility)
+                 for i, landmark in enumerate(results.left_hand_landmarks.landmark)
+                 if i in SELECTED_JOINTS[num_joints]['hand']
+             ]
+         frame_joints.extend(left_hand)
+
+         right_hand = [(0.0, 0.0, 0.0)] * len(SELECTED_JOINTS[num_joints]['hand'])
+         if results.right_hand_landmarks is not None:
+             right_hand = [
+                 (landmark.x * resize_to[0], landmark.y * resize_to[1], landmark.visibility)
+                 for i, landmark in enumerate(results.right_hand_landmarks.landmark)
+                 if i in SELECTED_JOINTS[num_joints]['hand']
+             ]
+         frame_joints.extend(right_hand)
+
+         assert len(frame_joints) == num_joints, \
+             f'Expected {num_joints} joints, got {len(frame_joints)} joints.'
+         extracted_joints.append(frame_joints)
+     cap.release()
+
+     extracted_joints = np.array(extracted_joints)
+     extracted_joints = pad(extracted_joints, num_frames=num_frames)
+
+     fp = np.zeros(
+         (num_frames, num_joints, num_channels, num_bodies),
+         dtype=np.float32,
+     )
+     fp[:, :, :, 0] = extracted_joints
+
+     return np.transpose(fp, [2, 0, 1, 3])
+
+
+ def preprocess(
+     source: str,
+     keypoints_detector,
+     data_args: dict = None,
+ ) -> np.ndarray:
+     '''
+     Preprocess the video.
+
+     Parameters
+     ----------
+     source : str
+         The path to the video.
+     keypoints_detector : mediapipe.python.solutions.holistic.Holistic
+         The keypoints detector.
+     data_args : dict, optional
+         Sampling options with keys 'random_choose' and 'window_size'.
+         If None, the padded sequence is used without resampling.
+
+     Returns
+     -------
+     np.ndarray
+         The model inputs with shape (T, V, C).
+     '''
+     print('Extracting joints from pose...')
+     inputs = extract_joints(source=source, keypoints_detector=keypoints_detector)
+
+     if data_args is not None:
+         print('Sampling video...')
+         if data_args['random_choose']:
+             inputs = random_sample_np(inputs, data_args['window_size'])
+         else:
+             inputs = uniform_sample_np(inputs, data_args['window_size'])
+
+     return np.squeeze(inputs).transpose(1, 2, 0).astype(np.float32)
+
+
+ def random_sample_np(data: np.ndarray, size: int) -> np.ndarray:
+     '''
+     Sample the data randomly.
+
+     Parameters
+     ----------
+     data : np.ndarray
+         The data to sample.
+     size : int
+         The number of frames to sample.
+
+     Returns
+     -------
+     np.ndarray
+         The sampled data.
+     '''
+     C, T, V, M = data.shape
+     if T == size:
+         return data
+     interval = int(np.ceil(size / T))
+     random_list = sorted(random.sample(list(range(T)) * interval, size))
+     return data[:, random_list]
+
+
+ def uniform_sample_np(data: np.ndarray, size: int) -> np.ndarray:
+     '''
+     Sample the data uniformly.
+
+     Parameters
+     ----------
+     data : np.ndarray
+         The data to sample.
+     size : int
+         The number of frames to sample.
+
+     Returns
+     -------
+     np.ndarray
+         The sampled data.
+     '''
+     C, T, V, M = data.shape
+     if T == size:
+         return data
+     interval = T / size
+     uniform_list = [int(i * interval) for i in range(size)]
+     return data[:, uniform_list]
+
+
+ def calculate_angle(
+     shoulder: list,
+     elbow: list,
+     wrist: list,
+ ) -> float:
+     '''
+     Calculate the angle between the shoulder, elbow, and wrist.
+
+     Parameters
+     ----------
+     shoulder : list
+         Shoulder coordinates.
+     elbow : list
+         Elbow coordinates.
+     wrist : list
+         Wrist coordinates.
+
+     Returns
+     -------
+     float
+         Angle in degrees between the shoulder, elbow, and wrist.
+     '''
+     shoulder = np.array(shoulder)
+     elbow = np.array(elbow)
+     wrist = np.array(wrist)
+
+     radians = np.arctan2(wrist[1] - elbow[1], wrist[0] - elbow[0]) \
+         - np.arctan2(shoulder[1] - elbow[1], shoulder[0] - elbow[0])
+     angle = np.abs(radians * 180.0 / np.pi)
+
+     if angle > 180.0:
+         angle = 360 - angle
+     return angle
+
+
+ def do_hands_relax(
+     pose_landmarks: list,
+     angle_threshold: float = 160.0,
+ ) -> bool:
+     '''
+     Check if both hands are down (relaxed).
+
+     Parameters
+     ----------
+     pose_landmarks : list
+         Pose landmarks.
+     angle_threshold : float, optional
+         Angle threshold, by default 160.0.
+
+     Returns
+     -------
+     bool
+         True if both hands are down, False otherwise.
+     '''
+     if pose_landmarks is None:
+         return True
+
+     landmarks = pose_landmarks.landmark
+     left_shoulder = [
+         landmarks[pose.PoseLandmark.LEFT_SHOULDER.value].x,
+         landmarks[pose.PoseLandmark.LEFT_SHOULDER.value].y,
+         landmarks[pose.PoseLandmark.LEFT_SHOULDER.value].visibility,
+     ]
+     left_elbow = [
+         landmarks[pose.PoseLandmark.LEFT_ELBOW.value].x,
+         landmarks[pose.PoseLandmark.LEFT_ELBOW.value].y,
+         landmarks[pose.PoseLandmark.LEFT_ELBOW.value].visibility,
+     ]
+     left_wrist = [
+         landmarks[pose.PoseLandmark.LEFT_WRIST.value].x,
+         landmarks[pose.PoseLandmark.LEFT_WRIST.value].y,
+         landmarks[pose.PoseLandmark.LEFT_WRIST.value].visibility,
+     ]
+     left_angle = calculate_angle(left_shoulder, left_elbow, left_wrist)
+
+     right_shoulder = [
+         landmarks[pose.PoseLandmark.RIGHT_SHOULDER.value].x,
+         landmarks[pose.PoseLandmark.RIGHT_SHOULDER.value].y,
+         landmarks[pose.PoseLandmark.RIGHT_SHOULDER.value].visibility,
+     ]
+     right_elbow = [
+         landmarks[pose.PoseLandmark.RIGHT_ELBOW.value].x,
+         landmarks[pose.PoseLandmark.RIGHT_ELBOW.value].y,
+         landmarks[pose.PoseLandmark.RIGHT_ELBOW.value].visibility,
+     ]
+     right_wrist = [
+         landmarks[pose.PoseLandmark.RIGHT_WRIST.value].x,
+         landmarks[pose.PoseLandmark.RIGHT_WRIST.value].y,
+         landmarks[pose.PoseLandmark.RIGHT_WRIST.value].visibility,
+     ]
+     right_angle = calculate_angle(right_shoulder, right_elbow, right_wrist)
+
+     is_visible = all(
+         [
+             left_shoulder[2] > 0,
+             left_elbow[2] > 0,
+             left_wrist[2] > 0,
+             right_shoulder[2] > 0,
+             right_elbow[2] > 0,
+             right_wrist[2] > 0,
+         ]
+     )
+
+     return all(
+         [
+             is_visible,
+             left_angle < angle_threshold,
+             right_angle < angle_threshold,
+         ]
+     )
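The skeleton tensor passes through several layouts: per-frame (frames, joints, channels), channel-first (C, T, V, M) for sampling, then frame-first (T, V, C) for the ONNX input. A minimal sketch with dummy data, assuming the signatures above (the 90-frame clip and the 120-frame window size are arbitrary illustrative values):

```python
import numpy as np

from utils.data import pad, uniform_sample_np

# 90 frames of 27 joints with (x, y, visibility) channels, as built frame by frame.
joints = np.random.rand(90, 27, 3)

# pad() repeats the sequence until it reaches 150 frames.
padded = pad(joints, num_frames=150)
print(padded.shape)  # (150, 27, 3)

# extract_joints() returns channel-first (C, T, V, M); emulate that layout here.
fp = np.transpose(padded[..., np.newaxis], [2, 0, 1, 3])
print(fp.shape)  # (3, 150, 27, 1)

# uniform_sample_np() subsamples along the frame axis T.
sampled = uniform_sample_np(fp, size=120)
print(sampled.shape)  # (3, 120, 27, 1)

# preprocess() finally squeezes and transposes to (T, V, C) float32 for the ONNX graph.
model_input = np.squeeze(sampled).transpose(1, 2, 0).astype(np.float32)
print(model_input.shape)  # (120, 27, 3)
```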
utils/model.py ADDED
@@ -0,0 +1,47 @@
+ import torch
+ import numpy as np
+ import onnxruntime as ort
+
+
+ def get_predictions(
+     inputs: np.ndarray,
+     ort_session: ort.InferenceSession,
+     id2gloss: dict,
+     k: int = 3,
+ ) -> list:
+     '''
+     Get the top-k predictions.
+
+     Parameters
+     ----------
+     inputs : np.ndarray
+         Model inputs.
+     ort_session : ort.InferenceSession
+         ONNX Runtime session.
+     id2gloss : dict
+         Mapping from class index to class label.
+     k : int, optional
+         Number of predictions to return, by default 3.
+
+     Returns
+     -------
+     list
+         Top-k predictions as dicts with 'label' and 'score' keys.
+     '''
+     if inputs is None:
+         return []
+
+     logits = torch.from_numpy(ort_session.run(None, {'x': inputs})[0])
+
+     # Keep the top-k classes and softmax-normalize their scores.
+     topk_scores, topk_indices = torch.topk(logits, k, dim=1)
+     topk_scores = torch.nn.functional.softmax(topk_scores, dim=1).squeeze().detach().numpy()
+     topk_indices = topk_indices.squeeze().detach().numpy()
+
+     return [
+         {
+             'label': id2gloss[topk_indices[i]],
+             'score': topk_scores[i],
+         }
+         for i in range(k)
+     ]
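Note that the softmax is applied over only the k retained logits, so the reported scores are relative confidences among the top-k classes rather than probabilities over the full 100-class distribution. A small sketch with fake logits, mirroring the ranking logic above, to make that concrete:

```python
import torch

# Fake logits for a 5-class problem (batch size 1).
logits = torch.tensor([[2.0, 0.5, 1.0, 3.0, -1.0]])

topk_scores, topk_indices = torch.topk(logits, k=3, dim=1)
topk_scores = torch.nn.functional.softmax(topk_scores, dim=1).squeeze()

print(topk_indices.squeeze().tolist())  # [3, 0, 2] -- classes ranked by logit
print(topk_scores.tolist())             # roughly [0.665, 0.245, 0.090], summing to 1
```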