Add model
- .gitignore +134 -0
- README.md +9 -6
- convert_weights.py +203 -0
- repnet/__init__.py +0 -0
- repnet/model.py +192 -0
- repnet/plots.py +66 -0
- repnet/utils.py +41 -0
- requirements.txt +7 -0
- run.py +116 -0
.gitignore
ADDED
@@ -0,0 +1,134 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# Project specific
checkpoints
videos
visualizations
README.md
CHANGED
@@ -8,15 +8,18 @@ datasets:
---

# RepNet PyTorch
+
+GitHub repository: https://github.com/materight/RepNet-pytorch.
+
A PyTorch port with pre-trained weights of **RepNet**, from *Counting Out Time: Class Agnostic Video Repetition Counting in the Wild* (CVPR 2020) [[paper]](https://arxiv.org/abs/2006.15418) [[project]](https://sites.google.com/view/repnet) [[notebook]](https://colab.research.google.com/github/google-research/google-research/blob/master/repnet/repnet_colab.ipynb#scrollTo=FUg2vSYhmsT0).

This repo provides an implementation of RepNet written in PyTorch and a script to convert the pre-trained TensorFlow weights provided by the authors. The outputs of the two implementations are almost identical, with a small deviation (less than $10^{-6}$ at most) probably caused by the [limited precision of floating point operations](https://pytorch.org/docs/stable/notes/numerical_accuracy.html).

<div align="center">
-<img src="img/example1.gif" height="160" />
-<img src="img/example2.gif" height="160" />
-<img src="img/example3.gif" height="160" />
-<img src="img/example4.gif" height="160" />
+<img src="https://raw.githubusercontent.com/materight/RepNet-pytorch/main/img/example1.gif" height="160" />
+<img src="https://raw.githubusercontent.com/materight/RepNet-pytorch/main/img/example2.gif" height="160" />
+<img src="https://raw.githubusercontent.com/materight/RepNet-pytorch/main/img/example3.gif" height="160" />
+<img src="https://raw.githubusercontent.com/materight/RepNet-pytorch/main/img/example4.gif" height="160" />
</div>

## Get Started

@@ -45,6 +48,6 @@ If the model does not produce good results, try to run the script with more stri

Example of generated videos showing the repetition count, with the periodicity score and the temporal self-similarity matrix:
<div align="center">
-<img src="img/example5_score.gif" height="200" />
-<img src="img/example5_tsm.png" height="200" />
+<img src="https://raw.githubusercontent.com/materight/RepNet-pytorch/main/img/example5_score.gif" height="200" />
+<img src="https://raw.githubusercontent.com/materight/RepNet-pytorch/main/img/example5_tsm.png" height="200" />
</div>
convert_weights.py
ADDED
@@ -0,0 +1,203 @@
"""Script to download the pre-trained tensorflow weights and convert them to pytorch weights."""
import os
import argparse
import torch
import numpy as np
from tensorflow.python.training import py_checkpoint_reader

from repnet import utils
from repnet.model import RepNet


# Relevant paths
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
TF_CHECKPOINT_BASE_URL = 'https://storage.googleapis.com/repnet_ckpt'
TF_CHECKPOINT_FILES = ['checkpoint', 'ckpt-88.data-00000-of-00002', 'ckpt-88.data-00001-of-00002', 'ckpt-88.index']
OUT_CHECKPOINTS_DIR = os.path.join(PROJECT_ROOT, 'checkpoints')

# Mapping of ndim -> permutation to go from tf to pytorch
WEIGHTS_PERMUTATION = {
    2: (1, 0),
    4: (3, 2, 0, 1),
    5: (4, 3, 0, 1, 2)
}

# Mapping of tf attributes -> pytorch attributes
ATTR_MAPPING = {
    'kernel': 'weight',
    'bias': 'bias',
    'beta': 'bias',
    'gamma': 'weight',
    'moving_mean': 'running_mean',
    'moving_variance': 'running_var'
}

# Mapping of tf checkpoint -> tf model -> pytorch model
WEIGHTS_MAPPING = [
    # Base frame encoder
    ('base_model.layer-2', 'conv1_conv', 'encoder.stem.conv'),
    ('base_model.layer-5', 'conv2_block1_preact_bn', 'encoder.stages.0.blocks.0.norm1'),
    ('base_model.layer-7', 'conv2_block1_1_conv', 'encoder.stages.0.blocks.0.conv1'),
    ('base_model.layer-8', 'conv2_block1_1_bn', 'encoder.stages.0.blocks.0.norm2'),
    ('base_model.layer_with_weights-4', 'conv2_block1_2_conv', 'encoder.stages.0.blocks.0.conv2'),
    ('base_model.layer_with_weights-5', 'conv2_block1_2_bn', 'encoder.stages.0.blocks.0.norm3'),
    ('base_model.layer_with_weights-6', 'conv2_block1_0_conv', 'encoder.stages.0.blocks.0.downsample.conv'),
    ('base_model.layer_with_weights-7', 'conv2_block1_3_conv', 'encoder.stages.0.blocks.0.conv3'),
    ('base_model.layer_with_weights-8', 'conv2_block2_preact_bn', 'encoder.stages.0.blocks.1.norm1'),
    ('base_model.layer_with_weights-9', 'conv2_block2_1_conv', 'encoder.stages.0.blocks.1.conv1'),
    ('base_model.layer_with_weights-10', 'conv2_block2_1_bn', 'encoder.stages.0.blocks.1.norm2'),
    ('base_model.layer_with_weights-11', 'conv2_block2_2_conv', 'encoder.stages.0.blocks.1.conv2'),
    ('base_model.layer_with_weights-12', 'conv2_block2_2_bn', 'encoder.stages.0.blocks.1.norm3'),
    ('base_model.layer_with_weights-13', 'conv2_block2_3_conv', 'encoder.stages.0.blocks.1.conv3'),
    ('base_model.layer_with_weights-14', 'conv2_block3_preact_bn', 'encoder.stages.0.blocks.2.norm1'),
    ('base_model.layer_with_weights-15', 'conv2_block3_1_conv', 'encoder.stages.0.blocks.2.conv1'),
    ('base_model.layer_with_weights-16', 'conv2_block3_1_bn', 'encoder.stages.0.blocks.2.norm2'),
    ('base_model.layer_with_weights-17', 'conv2_block3_2_conv', 'encoder.stages.0.blocks.2.conv2'),
    ('base_model.layer_with_weights-18', 'conv2_block3_2_bn', 'encoder.stages.0.blocks.2.norm3'),
    ('base_model.layer_with_weights-19', 'conv2_block3_3_conv', 'encoder.stages.0.blocks.2.conv3'),
    ('base_model.layer_with_weights-20', 'conv3_block1_preact_bn', 'encoder.stages.1.blocks.0.norm1'),
    ('base_model.layer_with_weights-21', 'conv3_block1_1_conv', 'encoder.stages.1.blocks.0.conv1'),
    ('base_model.layer_with_weights-22', 'conv3_block1_1_bn', 'encoder.stages.1.blocks.0.norm2'),
    ('base_model.layer_with_weights-23', 'conv3_block1_2_conv', 'encoder.stages.1.blocks.0.conv2'),
    ('base_model.layer-47', 'conv3_block1_2_bn', 'encoder.stages.1.blocks.0.norm3'),
    ('base_model.layer_with_weights-25', 'conv3_block1_0_conv', 'encoder.stages.1.blocks.0.downsample.conv'),
    ('base_model.layer_with_weights-26', 'conv3_block1_3_conv', 'encoder.stages.1.blocks.0.conv3'),
    ('base_model.layer_with_weights-27', 'conv3_block2_preact_bn', 'encoder.stages.1.blocks.1.norm1'),
    ('base_model.layer_with_weights-28', 'conv3_block2_1_conv', 'encoder.stages.1.blocks.1.conv1'),
    ('base_model.layer_with_weights-29', 'conv3_block2_1_bn', 'encoder.stages.1.blocks.1.norm2'),
    ('base_model.layer_with_weights-30', 'conv3_block2_2_conv', 'encoder.stages.1.blocks.1.conv2'),
    ('base_model.layer_with_weights-31', 'conv3_block2_2_bn', 'encoder.stages.1.blocks.1.norm3'),
    ('base_model.layer-61', 'conv3_block2_3_conv', 'encoder.stages.1.blocks.1.conv3'),
    ('base_model.layer-63', 'conv3_block3_preact_bn', 'encoder.stages.1.blocks.2.norm1'),
    ('base_model.layer-65', 'conv3_block3_1_conv', 'encoder.stages.1.blocks.2.conv1'),
    ('base_model.layer-66', 'conv3_block3_1_bn', 'encoder.stages.1.blocks.2.norm2'),
    ('base_model.layer-69', 'conv3_block3_2_conv', 'encoder.stages.1.blocks.2.conv2'),
    ('base_model.layer-70', 'conv3_block3_2_bn', 'encoder.stages.1.blocks.2.norm3'),
    ('base_model.layer_with_weights-38', 'conv3_block3_3_conv', 'encoder.stages.1.blocks.2.conv3'),
    ('base_model.layer-74', 'conv3_block4_preact_bn', 'encoder.stages.1.blocks.3.norm1'),
    ('base_model.layer_with_weights-40', 'conv3_block4_1_conv', 'encoder.stages.1.blocks.3.conv1'),
    ('base_model.layer_with_weights-41', 'conv3_block4_1_bn', 'encoder.stages.1.blocks.3.norm2'),
    ('base_model.layer_with_weights-42', 'conv3_block4_2_conv', 'encoder.stages.1.blocks.3.conv2'),
    ('base_model.layer_with_weights-43', 'conv3_block4_2_bn', 'encoder.stages.1.blocks.3.norm3'),
    ('base_model.layer_with_weights-44', 'conv3_block4_3_conv', 'encoder.stages.1.blocks.3.conv3'),
    ('base_model.layer_with_weights-45', 'conv4_block1_preact_bn', 'encoder.stages.2.blocks.0.norm1'),
    ('base_model.layer_with_weights-46', 'conv4_block1_1_conv', 'encoder.stages.2.blocks.0.conv1'),
    ('base_model.layer_with_weights-47', 'conv4_block1_1_bn', 'encoder.stages.2.blocks.0.norm2'),
    ('base_model.layer-92', 'conv4_block1_2_conv', 'encoder.stages.2.blocks.0.conv2'),
    ('base_model.layer-93', 'conv4_block1_2_bn', 'encoder.stages.2.blocks.0.norm3'),
    ('base_model.layer-95', 'conv4_block1_0_conv', 'encoder.stages.2.blocks.0.downsample.conv'),
    ('base_model.layer-96', 'conv4_block1_3_conv', 'encoder.stages.2.blocks.0.conv3'),
    ('base_model.layer-98', 'conv4_block2_preact_bn', 'encoder.stages.2.blocks.1.norm1'),
    ('base_model.layer-100', 'conv4_block2_1_conv', 'encoder.stages.2.blocks.1.conv1'),
    ('base_model.layer-101', 'conv4_block2_1_bn', 'encoder.stages.2.blocks.1.norm2'),
    ('base_model.layer-104', 'conv4_block2_2_conv', 'encoder.stages.2.blocks.1.conv2'),
    ('base_model.layer-105', 'conv4_block2_2_bn', 'encoder.stages.2.blocks.1.norm3'),
    ('base_model.layer-107', 'conv4_block2_3_conv', 'encoder.stages.2.blocks.1.conv3'),
    ('base_model.layer-109', 'conv4_block3_preact_bn', 'encoder.stages.2.blocks.2.norm1'),
    ('base_model.layer-111', 'conv4_block3_1_conv', 'encoder.stages.2.blocks.2.conv1'),
    ('base_model.layer-112', 'conv4_block3_1_bn', 'encoder.stages.2.blocks.2.norm2'),
    ('base_model.layer-115', 'conv4_block3_2_conv', 'encoder.stages.2.blocks.2.conv2'),
    ('base_model.layer-116', 'conv4_block3_2_bn', 'encoder.stages.2.blocks.2.norm3'),
    ('base_model.layer-118', 'conv4_block3_3_conv', 'encoder.stages.2.blocks.2.conv3'),
    # Temporal convolution
    ('temporal_conv_layers.0', 'conv3d', 'temporal_conv.0'),
    ('temporal_bn_layers.0', 'batch_normalization', 'temporal_conv.1'),
    ('conv_3x3_layer', 'conv2d', 'tsm_conv.0'),
    # Period length head
    ('input_projection', 'dense', 'period_length_head.0.input_projection'),
    ('pos_encoding', None, 'period_length_head.0.pos_encoding'),
    ('transformer_layers.0.ffn.layer-0', None, 'period_length_head.0.transformer_layer.linear1'),
    ('transformer_layers.0.ffn.layer-1', None, 'period_length_head.0.transformer_layer.linear2'),
    ('transformer_layers.0.layernorm1', None, 'period_length_head.0.transformer_layer.norm1'),
    ('transformer_layers.0.layernorm2', None, 'period_length_head.0.transformer_layer.norm2'),
    ('transformer_layers.0.mha.w_weight', None, 'period_length_head.0.transformer_layer.self_attn.in_proj_weight'),
    ('transformer_layers.0.mha.w_bias', None, 'period_length_head.0.transformer_layer.self_attn.in_proj_bias'),
    ('transformer_layers.0.mha.dense', None, 'period_length_head.0.transformer_layer.self_attn.out_proj'),
    ('fc_layers.0', 'dense_14', 'period_length_head.1'),
    ('fc_layers.1', 'dense_15', 'period_length_head.3'),
    ('fc_layers.2', 'dense_16', 'period_length_head.5'),
    # Periodicity head
    ('input_projection2', 'dense_1', 'periodicity_head.0.input_projection'),
    ('pos_encoding2', None, 'periodicity_head.0.pos_encoding'),
    ('transformer_layers2.0.ffn.layer-0', None, 'periodicity_head.0.transformer_layer.linear1'),
    ('transformer_layers2.0.ffn.layer-1', None, 'periodicity_head.0.transformer_layer.linear2'),
    ('transformer_layers2.0.layernorm1', None, 'periodicity_head.0.transformer_layer.norm1'),
    ('transformer_layers2.0.layernorm2', None, 'periodicity_head.0.transformer_layer.norm2'),
    ('transformer_layers2.0.mha.w_weight', None, 'periodicity_head.0.transformer_layer.self_attn.in_proj_weight'),
    ('transformer_layers2.0.mha.w_bias', None, 'periodicity_head.0.transformer_layer.self_attn.in_proj_bias'),
    ('transformer_layers2.0.mha.dense', None, 'periodicity_head.0.transformer_layer.self_attn.out_proj'),
    ('within_period_fc_layers.0', 'dense_17', 'periodicity_head.1'),
    ('within_period_fc_layers.1', 'dense_18', 'periodicity_head.3'),
    ('within_period_fc_layers.2', 'dense_19', 'periodicity_head.5'),
]

# Script arguments
parser = argparse.ArgumentParser(description='Download and convert the pre-trained weights from tensorflow to pytorch.')


if __name__ == '__main__':
    args = parser.parse_args()

    # Download tensorflow checkpoints
    print('Downloading checkpoints...')
    tf_checkpoint_dir = os.path.join(OUT_CHECKPOINTS_DIR, 'tf_checkpoint')
    os.makedirs(tf_checkpoint_dir, exist_ok=True)
    for file in TF_CHECKPOINT_FILES:
        dst = os.path.join(tf_checkpoint_dir, file)
        if not os.path.exists(dst):
            utils.download_file(f'{TF_CHECKPOINT_BASE_URL}/{file}', dst)

    # Load tensorflow weights into a dictionary
    print('Loading tensorflow checkpoint...')
    checkpoint_path = os.path.join(tf_checkpoint_dir, 'ckpt-88')
    checkpoint_reader = py_checkpoint_reader.NewCheckpointReader(checkpoint_path)
    shape_map = checkpoint_reader.get_variable_to_shape_map()
    tf_state_dict = {}
    for var_name in sorted(shape_map.keys()):
        var_tensor = checkpoint_reader.get_tensor(var_name)
        if not var_name.startswith('model') or '.OPTIMIZER_SLOT' in var_name:
            continue  # Skip variables that are not part of the model, e.g. from the optimizer
        # Split var_name into path
        var_path = var_name.split('/')[1:]  # Remove the `model` key from the path
        var_path = [p for p in var_path if p not in ['.ATTRIBUTES', 'VARIABLE_VALUE']]
        # Map weights into a nested dictionary
        current_dict = tf_state_dict
        for path in var_path[:-1]:
            current_dict = current_dict.setdefault(path, {})
        current_dict[var_path[-1]] = var_tensor

    # Merge transformer self-attention weights into a single tensor
    for k in ['transformer_layers', 'transformer_layers2']:
        v = tf_state_dict[k]['0']['mha']
        v['w_weight'] = np.concatenate([v['wq']['kernel'].T, v['wk']['kernel'].T, v['wv']['kernel'].T], axis=0)
        v['w_bias'] = np.concatenate([v['wq']['bias'].T, v['wk']['bias'].T, v['wv']['bias'].T], axis=0)
        del v['wk'], v['wq'], v['wv']
    tf_state_dict = utils.flatten_dict(tf_state_dict, keep_last=True)
    # Add missing final level for some weights
    for k, v in tf_state_dict.items():
        if not isinstance(v, dict):
            tf_state_dict[k] = {None: v}

    # Convert to a format compatible with PyTorch and save
    print('Converting to PyTorch format...')
    pt_checkpoint_path = os.path.join(OUT_CHECKPOINTS_DIR, 'pytorch_weights.pth')
    pt_state_dict = {}
    for k_tf, _, k_pt in WEIGHTS_MAPPING:
        assert k_pt not in pt_state_dict
        pt_state_dict[k_pt] = {}
        for attr in tf_state_dict[k_tf]:
            new_attr = ATTR_MAPPING.get(attr, attr)
            pt_state_dict[k_pt][new_attr] = torch.from_numpy(tf_state_dict[k_tf][attr])
            if attr == 'kernel':
                weights_permutation = WEIGHTS_PERMUTATION[pt_state_dict[k_pt][new_attr].ndim]  # Permute weights if needed
                pt_state_dict[k_pt][new_attr] = pt_state_dict[k_pt][new_attr].permute(weights_permutation)
    pt_state_dict = utils.flatten_dict(pt_state_dict, skip_none=True)
    torch.save(pt_state_dict, pt_checkpoint_path)

    # Initialize the model and try to load the weights
    print('Check that the weights can be loaded into the model...')
    model = RepNet()
    pt_state_dict = torch.load(pt_checkpoint_path)
    model.load_state_dict(pt_state_dict)

    print(f'Done. PyTorch weights saved to {pt_checkpoint_path}.')
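Note on the conversion: `WEIGHTS_MAPPING` pairs each TensorFlow checkpoint scope with the corresponding PyTorch parameter name, while `WEIGHTS_PERMUTATION` re-orders kernel axes. The sketch below is purely illustrative (not part of the repo) and only shows why those permutations line up TensorFlow layouts with what `nn.Conv2d` and `nn.Linear` expect; the tensor sizes are made up.

```python
# Minimal sketch: TF -> PyTorch weight layout, assuming toy tensor sizes.
import torch

tf_conv_kernel = torch.randn(3, 3, 64, 128)          # TF Conv2D kernel: (kH, kW, C_in, C_out)
pt_conv_weight = tf_conv_kernel.permute(3, 2, 0, 1)  # nn.Conv2d weight: (C_out, C_in, kH, kW)
assert pt_conv_weight.shape == (128, 64, 3, 3)

tf_dense_kernel = torch.randn(512, 256)              # TF Dense kernel: (in_features, out_features)
pt_linear_weight = tf_dense_kernel.permute(1, 0)     # nn.Linear weight: (out_features, in_features)
assert pt_linear_weight.shape == (256, 512)
```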
repnet/__init__.py
ADDED
File without changes
repnet/model.py
ADDED
@@ -0,0 +1,192 @@
"""PyTorch implementation of RepNet."""
import torch
from torch import nn
from typing import Tuple


# List of ResNet50V2 conv layers that use bias in the tensorflow implementation
CONVS_WITH_BIAS = [
    'stem.conv',
    'stages.0.blocks.0.downsample.conv', 'stages.0.blocks.0.conv3', 'stages.0.blocks.1.conv3', 'stages.0.blocks.2.conv3',
    'stages.1.blocks.0.downsample.conv', 'stages.1.blocks.0.conv3', 'stages.1.blocks.1.conv3', 'stages.1.blocks.2.conv3', 'stages.1.blocks.3.conv3',
    'stages.2.blocks.0.downsample.conv', 'stages.2.blocks.0.conv3', 'stages.2.blocks.1.conv3', 'stages.2.blocks.2.conv3',
]

# List of ResNet50V2 conv layers that use stride 1 in the tensorflow implementation
CONVS_WITHOUT_STRIDE = [
    'stages.1.blocks.0.downsample.conv', 'stages.1.blocks.0.conv2',
    'stages.2.blocks.0.downsample.conv', 'stages.2.blocks.0.conv2',
]

# List of ResNet50V2 blocks that use max pooling instead of stride 2 in the tensorflow implementation
FINAL_BLOCKS_WITH_MAX_POOL = [
    'stages.0.blocks.2', 'stages.1.blocks.3',
]


class RepNet(nn.Module):
    """RepNet model."""
    def __init__(self, num_frames: int = 64, temperature: float = 13.544):
        super().__init__()
        self.num_frames = num_frames
        self.temperature = temperature
        self.encoder = self._init_encoder()
        self.temporal_conv = nn.Sequential(
            nn.Conv3d(1024, 512, kernel_size=3, dilation=(3, 1, 1), padding=(3, 1, 1)),
            nn.BatchNorm3d(512, eps=0.001),
            nn.ReLU(inplace=True),
            nn.AdaptiveMaxPool3d((None, 1, 1)),
            nn.Flatten(2, 4),
        )
        self.tsm_conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.period_length_head = self._init_transformer_head(num_frames, 2048, 4, 512, num_frames // 2)
        self.periodicity_head = self._init_transformer_head(num_frames, 2048, 4, 512, 1)


    @staticmethod
    def _init_encoder() -> nn.Module:
        """Initialize the encoder network using ResNet50 V2."""
        encoder = torch.hub.load('huggingface/pytorch-image-models', 'resnetv2_50')
        # Remove unused layers
        del encoder.stages[2].blocks[3:6], encoder.stages[3]
        encoder.norm = nn.Identity()
        encoder.head.global_pool = nn.Identity()
        encoder.head.fc = nn.Identity()
        encoder.head.flatten = nn.Identity()
        # Change padding from -inf to 0 on max pool to have the same behavior as tensorflow
        encoder.stem.pool.padding = 0
        encoder.stem.pool = nn.Sequential(nn.ZeroPad2d((1, 1, 1, 1)), encoder.stem.pool)
        # Change properties of existing layers
        for name, module in encoder.named_modules():
            # Add missing bias to conv layers
            if name in CONVS_WITH_BIAS:
                module.bias = nn.Parameter(torch.zeros(module.out_channels))
            # Remove stride from the first block in the later stages
            if name in CONVS_WITHOUT_STRIDE:
                module.stride = (1, 1)
            # Change stride and add max pooling to final block
            if name in FINAL_BLOCKS_WITH_MAX_POOL:
                module.conv2.stride = (2, 2)
                module.downsample = nn.MaxPool2d(1, stride=2)
                # Change the forward function so that the input of max pooling is the raw `x` instead of the pre-activation result
                bound_method = _max_pool_block_forward.__get__(module, module.__class__)
                setattr(module, 'forward', bound_method)
            # Change eps in batchnorm layers
            if isinstance(module, nn.BatchNorm2d):
                module.eps = 1.001e-5
        return encoder


    @staticmethod
    def _init_transformer_head(num_frames: int, in_features: int, n_head: int, hidden_features: int, out_features: int) -> nn.Module:
        """Initialize the fully-connected head for the final output."""
        return nn.Sequential(
            TranformerLayer(in_features, n_head, hidden_features, num_frames),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_features, out_features),
        )


    def extract_feat(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the encoder network to extract per-frame embeddings. Expected input shape: N x C x D x H x W."""
        batch_size, _, seq_len, _, _ = x.shape
        torch._assert(seq_len == self.num_frames, f'Expected {self.num_frames} frames, got {seq_len}')
        # Extract features frame-by-frame
        x = x.movedim(1, 2).flatten(0, 1)
        x = self.encoder(x)
        x = x.unflatten(0, (batch_size, seq_len)).movedim(1, 2)
        # Temporal convolution
        x = self.temporal_conv(x)
        x = x.movedim(1, 2)  # Convert to N x D x C
        return x


    def period_predictor(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of the period predictor network from the extracted embeddings. Expected input shape: N x D x C."""
        batch_size, seq_len, _ = x.shape
        torch._assert(seq_len == self.num_frames, f'Expected {self.num_frames} frames, got {seq_len}')
        # Compute temporal self-similarity matrix
        x = torch.cdist(x, x)**2  # N x D x D
        x = -x / self.temperature
        x = x.softmax(dim=-1)
        # Conv layer on top of the TSM
        x = self.tsm_conv(x.unsqueeze(1))
        x = x.movedim(1, 3).reshape(batch_size, seq_len, -1)  # Flatten channels into N x D x C
        # Final prediction heads
        period_length = self.period_length_head(x)
        periodicity = self.periodicity_head(x)
        return period_length, periodicity


    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward pass. Expected input shape: N x C x D x H x W."""
        embeddings = self.extract_feat(x)
        period_length, periodicity = self.period_predictor(embeddings)
        return period_length, periodicity, embeddings


    @staticmethod
    def get_counts(raw_period_length: torch.Tensor, raw_periodicity: torch.Tensor, stride: int,
                   periodicity_threshold: float = 0.5) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the final scores from the period length and periodicity predictions."""
        # Repeat the input to account for the stride
        raw_period_length = raw_period_length.repeat_interleave(stride, dim=0)
        raw_periodicity = raw_periodicity.repeat_interleave(stride, dim=0)
        # Compute the final scores in [0, 1]
        periodicity_score = torch.sigmoid(raw_periodicity).squeeze(-1)
        period_length_confidence, period_length = torch.max(torch.softmax(raw_period_length, dim=-1), dim=-1)
        # Remove the confidence for short periods and convert to the correct stride
        period_length_confidence[period_length < 2] = 0
        period_length = (period_length + 1) * stride
        periodicity_score = torch.sqrt(periodicity_score * period_length_confidence)
        # Generate the final counts and set them to 0 if the periodicity is too low
        period_count = 1 / period_length
        period_count[periodicity_score < periodicity_threshold] = 0
        period_length = 1 / (torch.mean(period_count) + 1e-6)
        period_count = torch.cumsum(period_count, dim=0)
        confidence = torch.mean(periodicity_score)
        return confidence, period_length, period_count, periodicity_score



class TranformerLayer(nn.Module):
    """A single transformer layer with self-attention and positional encoding."""

    def __init__(self, in_features: int, n_head: int, out_features: int, num_frames: int):
        super().__init__()
        self.input_projection = nn.Linear(in_features, out_features)
        self.pos_encoding = nn.Parameter(torch.normal(mean=0, std=0.02, size=(1, num_frames, 1)))
        self.transformer_layer = nn.TransformerEncoderLayer(
            d_model=out_features, nhead=n_head, dim_feedforward=out_features, activation='relu',
            layer_norm_eps=1e-6, batch_first=True, norm_first=True
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass, expected input shape: N x C x D."""
        x = self.input_projection(x)
        x = x + self.pos_encoding
        x = self.transformer_layer(x)
        return x



def _max_pool_block_forward(self, x):
    """
    Custom `forward` function for the last block of each stage in ResNetV2, to have the same behavior as tensorflow.
    Original implementation: https://github.com/huggingface/pytorch-image-models/blob/4b8cfa6c0a355a9b3cb2a77298b240213fb3b921/timm/models/resnetv2.py#L197
    """
    x_preact = self.norm1(x)
    shortcut = x
    if self.downsample is not None:
        shortcut = self.downsample(x)  # Changed here from `x_preact` to `x`
    x = self.conv1(x_preact)
    x = self.conv2(self.norm2(x))
    x = self.conv3(self.norm3(x))
    x = self.drop_path(x)
    return x + shortcut
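For reference, a minimal usage sketch (not included in the commit) of the module above. It assumes random, unconverted weights, so the outputs are only illustrative of the expected shapes and of the `get_counts` post-processing; the inference flow itself mirrors what `run.py` does.

```python
# Minimal sketch: run RepNet on a dummy clip, assuming the encoder can be fetched from torch.hub.
import torch
from repnet.model import RepNet

model = RepNet(num_frames=64).eval()
frames = torch.randn(1, 3, 64, 112, 112)  # N x C x D x H x W, frames normalized to [-1, 1]
with torch.no_grad():
    period_length, periodicity, embeddings = model(frames)
print(period_length.shape, periodicity.shape, embeddings.shape)  # (1, 64, 32), (1, 64, 1), (1, 64, 512)

# Post-process a single clip (stride 1) into per-frame counts and scores
confidence, avg_period, counts, periodicity_score = RepNet.get_counts(period_length[0], periodicity[0], stride=1)
```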
repnet/plots.py
ADDED
@@ -0,0 +1,66 @@
"""Utility functions for plotting."""
import cv2
import numpy as np
from typing import List, Optional
from sklearn.decomposition import PCA


def plot_heatmap(dist: np.ndarray, log_scale: bool = False) -> np.ndarray:
    """Plot the temporal self-similarity matrix into an OpenCV image."""
    np.fill_diagonal(dist, np.nan)
    if log_scale:
        dist = np.log(1 + dist)
    dist = -dist  # Invert the distance
    zmin, zmax = np.nanmin(dist), np.nanmax(dist)
    heatmap = (dist - zmin) / (zmax - zmin)  # Normalize into [0, 1]
    heatmap = np.nan_to_num(heatmap, nan=1)
    heatmap = np.clip(heatmap * 255, 0, 255).astype(np.uint8)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_VIRIDIS)
    return heatmap


def plot_pca(embeddings: List[np.ndarray]) -> np.ndarray:
    """Plot the 1D PCA of the embeddings into an OpenCV image."""
    projection = PCA(n_components=1).fit_transform(embeddings).flatten()
    projection = (projection - projection.min()) / (projection.max() - projection.min())
    h, w = 200, len(projection) * 4
    img = np.full((h, w, 3), 255, dtype=np.uint8)
    y = ((1 - projection) * h).astype(np.int32)
    x = (np.arange(len(y)) / len(y) * w).astype(np.int32)
    pts = np.stack([x, y], axis=1).reshape((-1, 1, 2))
    img = cv2.polylines(img, [pts], False, (102, 60, 0), 1, cv2.LINE_AA)
    return img


def plot_repetitions(frames: List[np.ndarray], counts: List[float], periodicity: Optional[List[float]]) -> List[np.ndarray]:
    """Generate video with repetition counts and return frames."""
    blue_dark, blue_light = (102, 60, 0), (215, 175, 121)
    h, w, _ = frames[0].shape
    pbar_r = max(int(min(w, h) * 0.1), 20)
    pbar_c = (pbar_r + 5, pbar_r + 5)
    txt_s = pbar_r / 30
    assert len(frames) == len(counts), 'Number of frames and counts must match.'
    out_frames = []
    for i, (frame, count) in enumerate(zip(frames, counts)):
        frame = frame.copy()
        # Draw progress bar
        frame = cv2.ellipse(frame, pbar_c, (pbar_r, pbar_r), -90, 0, 360, blue_dark, -1, cv2.LINE_AA)
        frame = cv2.ellipse(frame, pbar_c, (pbar_r, pbar_r), -90, 0, 360 * (count % 1.0), blue_light, -1, cv2.LINE_AA)
        txt_box, _ = cv2.getTextSize(str(int(count)), cv2.FONT_HERSHEY_SIMPLEX, txt_s, 2)
        txt_c = (pbar_c[0] - txt_box[0] // 2, pbar_c[1] + txt_box[1] // 2)
        frame = cv2.putText(frame, str(int(count)), txt_c, cv2.FONT_HERSHEY_SIMPLEX, txt_s, (255, 255, 255), 2, cv2.LINE_AA)
        # Draw periodicity plot on the right if available
        if periodicity is not None:
            periodicity = np.asarray(periodicity)
            padx, pady, window_size = 5, 10, 64
            pcanvas_h, pcanvas_w = frame.shape[0], min(frame.shape[0], frame.shape[1])
            pcanvas = np.full((pcanvas_h, pcanvas_w, 3), 255, dtype=np.uint8)
            pcanvas[pady::int((pcanvas_h - pady*2) / 10), :, :] = (235, 235, 235)  # Draw horizontal grid
            y = ((1 - periodicity[:i+1][-window_size:]) * (pcanvas_h - pady*2) + pady).astype(np.int32)
            x = ((np.arange(len(y)) / window_size) * (pcanvas_w - padx*2)).astype(np.int32)
            pts = np.stack([x, y], axis=1).reshape((-1, 1, 2))
            pcanvas = cv2.polylines(pcanvas, [pts], False, blue_dark, 1, cv2.LINE_AA)
            pcanvas = cv2.circle(pcanvas, (x[-1], y[-1]), 2, (0, 0, 255), -1, cv2.LINE_AA)
            frame = np.concatenate([frame, pcanvas], axis=1)
        out_frames.append(frame)
    return out_frames
repnet/utils.py
ADDED
@@ -0,0 +1,41 @@
"""Utility functions."""
import os
import shutil
import requests
import yt_dlp


def flatten_dict(dictionary: dict, parent_key: str = '', sep: str = '.', keep_last: bool = False, skip_none: bool = False):
    """Flatten a nested dictionary into a single dictionary with keys separated by `sep`."""
    items = {}
    for k, v in dictionary.items():
        key_prefix = parent_key if parent_key else ''
        key_suffix = k if not skip_none or k is not None else ''
        key_sep = sep if key_prefix and key_suffix else ''
        new_key = key_prefix + key_sep + key_suffix
        if isinstance(v, dict) and (not keep_last or isinstance(next(iter(v.values())), dict)):
            items.update(flatten_dict(v, new_key, sep=sep, keep_last=keep_last, skip_none=skip_none))
        else:
            items[new_key] = v
    return items


YOUTUB_DL_DOMAINS = ['youtube.com', 'imgur.com', 'reddit.com']
def download_file(url: str, dst: str):
    """Download a file from a given url."""
    if any(domain in url for domain in YOUTUB_DL_DOMAINS):
        # Download video from YouTube
        with yt_dlp.YoutubeDL(dict(format='bestvideo[ext=mp4]/mp4', outtmpl=dst, quiet=True)) as ydl:
            ydl.download([url])
    elif url.startswith('http://') or url.startswith('https://'):
        # Download file from HTTP
        response = requests.get(url, timeout=10)
        with open(dst, 'wb') as file:
            file.write(response.content)
    elif os.path.exists(url) and os.path.isfile(url):
        # Copy file from local path
        shutil.copyfile(url, dst)
    else:
        raise ValueError(f'Invalid url: {url}')
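To make the `keep_last` and `skip_none` flags used by `convert_weights.py` concrete, here is a small illustrative example (not part of the repo); the nested dictionary is made up.

```python
# Minimal sketch of flatten_dict behavior on a hypothetical nested state dict.
from repnet.utils import flatten_dict

nested = {'encoder': {'stem': {'conv': {'weight': 1, 'bias': 2}}}}
print(flatten_dict(nested))
# {'encoder.stem.conv.weight': 1, 'encoder.stem.conv.bias': 2}
print(flatten_dict(nested, keep_last=True))
# {'encoder.stem.conv': {'weight': 1, 'bias': 2}} -- the last level of leaves stays grouped
print(flatten_dict({'layer': {None: 3}}, skip_none=True))
# {'layer': 3} -- a None key is dropped from the flattened name
```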
requirements.txt
ADDED
@@ -0,0 +1,7 @@
torch==1.10.0
torchvision==0.11.1
numpy==1.21.6
opencv_python==4.6.0.66
requests==2.27.1
tensorflow==2.11.0
yt_dlp==2023.2.17
run.py
ADDED
@@ -0,0 +1,116 @@
"""Run the RepNet model on a given video."""
import os
import cv2
import argparse
import torch
import torchvision.transforms as T

from repnet import utils, plots
from repnet.model import RepNet


PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
OUT_VISUALIZATIONS_DIR = os.path.join(PROJECT_ROOT, 'visualizations')
SAMPLE_VIDEOS_URLS = [
    'https://imgur.com/t/hummingbird/m2e2Nfa',  # Hummingbird
    'https://www.youtube.com/watch?v=w0JOoC-5_Lk',  # Chopping
    'https://www.youtube.com/watch?v=t9OE3nxnI2Y',  # Hammer training
    'https://www.youtube.com/watch?v=aY3TrpiUOqE',  # Bouncing ball
    'https://www.youtube.com/watch?v=5EYY2J3nb5c',  # Cooking
    'https://www.reddit.com/r/gifs/comments/4qfif6/cheetah_running_at_63_mph_102_kph',  # Cheetah
    'https://www.youtube.com/watch?v=cMWb7NvWWuI',  # Pendulum
    'https://www.youtube.com/watch?v=5g1T-ff07kM',  # Exercise
    'https://www.youtube.com/watch?v=-Q3_7T5w4nE',  # Exercise
]

# Script arguments
parser = argparse.ArgumentParser(description='Run the RepNet model on a given video.')
parser.add_argument('--weights', type=str, default=os.path.join(PROJECT_ROOT, 'checkpoints', 'pytorch_weights.pth'), help='Path to the model weights (default: %(default)s).')
parser.add_argument('--video', type=str, default=SAMPLE_VIDEOS_URLS[0], help='Video to test the model on, either a YouTube/http/local path (default: %(default)s).')
parser.add_argument('--strides', nargs='+', type=int, default=[1, 2, 3, 4, 8], help='Temporal strides to try when testing on the sample video (default: %(default)s).')
parser.add_argument('--device', type=str, default='cuda', help='Device to use for inference (default: %(default)s).')
parser.add_argument('--no-score', action='store_true', help='If specified, do not plot the periodicity score.')

if __name__ == '__main__':
    args = parser.parse_args()

    # Download the video sample if needed
    print(f'Downloading {args.video}...')
    video_path = os.path.join(PROJECT_ROOT, 'videos', os.path.basename(args.video) + '.mp4')
    if not os.path.exists(video_path):
        os.makedirs(os.path.dirname(video_path), exist_ok=True)
        utils.download_file(args.video, video_path)

    # Read frames and apply preprocessing
    print('Reading video file and pre-processing frames...')
    transform = T.Compose([
        T.ToPILImage(),
        T.Resize((112, 112)),
        T.ToTensor(),
        T.Normalize(mean=0.5, std=0.5),
    ])
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    raw_frames, frames = [], []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret or frame is None:
            break
        raw_frames.append(frame)
        frame = transform(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        frames.append(frame)
    cap.release()

    # Load model
    model = RepNet()
    state_dict = torch.load(args.weights)
    model.load_state_dict(state_dict)
    model.eval()
    model.to(args.device)

    # Test multiple strides and pick the best one
    print('Running inference on multiple stride values...')
    best_stride, best_confidence, best_period_length, best_period_count, best_periodicity_score, best_embeddings = None, None, None, None, None, None
    for stride in args.strides:
        # Apply stride
        stride_frames = frames[::stride]
        stride_frames = stride_frames[:(len(stride_frames) // 64) * 64]
        if len(stride_frames) < 64:
            continue  # Skip this stride if there are not enough frames
        stride_frames = torch.stack(stride_frames, axis=0).unflatten(0, (-1, 64)).movedim(1, 2)  # Convert to N x C x D x H x W
        stride_frames = stride_frames.to(args.device)
        # Run inference
        raw_period_length, raw_periodicity_score, embeddings = [], [], []
        with torch.no_grad():
            for i in range(stride_frames.shape[0]):  # Process each batch separately to avoid OOM
                batch_period_length, batch_periodicity, batch_embeddings = model(stride_frames[i].unsqueeze(0))
                raw_period_length.append(batch_period_length[0].cpu())
                raw_periodicity_score.append(batch_periodicity[0].cpu())
                embeddings.append(batch_embeddings[0].cpu())
        # Post-process results
        raw_period_length, raw_periodicity_score, embeddings = torch.cat(raw_period_length), torch.cat(raw_periodicity_score), torch.cat(embeddings)
        confidence, period_length, period_count, periodicity_score = model.get_counts(raw_period_length, raw_periodicity_score, stride)
        if best_confidence is None or confidence > best_confidence:
            best_stride, best_confidence, best_period_length, best_period_count, best_periodicity_score, best_embeddings = stride, confidence, period_length, period_count, periodicity_score, embeddings
    if best_stride is None:
        raise RuntimeError('The stride values used are too large and no 64-frame video chunk could be sampled. Try different values for --strides.')
    print(f'Predicted a period length of {best_period_length/fps:.1f} seconds (~{int(best_period_length)} frames) with a confidence of {best_confidence:.2f} using a stride of {best_stride} frames.')

    # Generate plots and videos
    print(f'Save plots and video with counts to {OUT_VISUALIZATIONS_DIR}...')
    os.makedirs(OUT_VISUALIZATIONS_DIR, exist_ok=True)
    dist = torch.cdist(best_embeddings, best_embeddings, p=2)**2
    tsm_img = plots.plot_heatmap(dist.numpy(), log_scale=True)
    pca_img = plots.plot_pca(best_embeddings.numpy())
    cv2.imwrite(os.path.join(OUT_VISUALIZATIONS_DIR, 'tsm.png'), tsm_img)
    cv2.imwrite(os.path.join(OUT_VISUALIZATIONS_DIR, 'pca.png'), pca_img)

    # Generate video with counts
    rep_frames = plots.plot_repetitions(raw_frames[:len(best_period_count)], best_period_count.tolist(), best_periodicity_score.tolist() if not args.no_score else None)
    video = cv2.VideoWriter(os.path.join(OUT_VISUALIZATIONS_DIR, 'repetitions.mp4'), cv2.VideoWriter_fourcc(*'mp4v'), fps, rep_frames[0].shape[:2][::-1])
    for frame in rep_frames:
        video.write(frame)
    video.release()

    print('Done')