materight committed on
Commit
ce3dce6
1 Parent(s): fd30f81
Files changed (9)
  1. .gitignore +134 -0
  2. README.md +9 -6
  3. convert_weights.py +203 -0
  4. repnet/__init__.py +0 -0
  5. repnet/model.py +192 -0
  6. repnet/plots.py +66 -0
  7. repnet/utils.py +41 -0
  8. requirements.txt +7 -0
  9. run.py +116 -0
.gitignore ADDED
@@ -0,0 +1,134 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
130
+
131
+ # Project specific
132
+ checkpoints
133
+ videos
134
+ visualizations
README.md CHANGED
@@ -8,15 +8,18 @@ datasets:
8
  ---
9
 
10
  # RepNet PyTorch
11
  A PyTorch port with pre-trained weights of **RepNet**, from *Counting Out Time: Class Agnostic Video Repetition Counting in the Wild* (CVPR 2020) [[paper]](https://arxiv.org/abs/2006.15418) [[project]](https://sites.google.com/view/repnet) [[notebook]](https://colab.research.google.com/github/google-research/google-research/blob/master/repnet/repnet_colab.ipynb#scrollTo=FUg2vSYhmsT0).
12
 
13
  This repo provides an implementation of RepNet written in PyTorch and a script to convert the pre-trained TensorFlow weights provided by the authors. The outputs of the two implementations are almost identical, with a small deviation (less than $10^{-6}$ at most) probably caused by the [limited precision of floating point operations](https://pytorch.org/docs/stable/notes/numerical_accuracy.html).
14
 
15
  <div align="center">
16
- <img src="img/example1.gif" height="160" />
17
- <img src="img/example2.gif" height="160" />
18
- <img src="img/example3.gif" height="160" />
19
- <img src="img/example4.gif" height="160" />
20
  </div>
21
 
22
  ## Get Started
@@ -45,6 +48,6 @@ If the model does not produce good results, try to run the script with more stri
45
 
46
  Example of generated videos showing the repetition count, with the periodicity score and the temporal self-similarity matrix:
47
  <div align="center">
48
- <img src="img/example5_score.gif" height="200" />
49
- <img src="img/example5_tsm.png" height="200" />
50
  </div>
 
8
  ---
9
 
10
  # RepNet PyTorch
11
+
12
+ GitHub repository: https://github.com/materight/RepNet-pytorch.
13
+
14
  A PyTorch port with pre-trained weights of **RepNet**, from *Counting Out Time: Class Agnostic Video Repetition Counting in the Wild* (CVPR 2020) [[paper]](https://arxiv.org/abs/2006.15418) [[project]](https://sites.google.com/view/repnet) [[notebook]](https://colab.research.google.com/github/google-research/google-research/blob/master/repnet/repnet_colab.ipynb#scrollTo=FUg2vSYhmsT0).
15
 
16
  This repo provides an implementation of RepNet written in PyTorch and a script to convert the pre-trained TensorFlow weights provided by the authors. The outputs of the two implementations are almost identical, with a small deviation (less than $10^{-6}$ at most) probably caused by the [limited precision of floating point operations](https://pytorch.org/docs/stable/notes/numerical_accuracy.html).
17
 
18
  <div align="center">
19
+ <img src="https://raw.githubusercontent.com/materight/RepNet-pytorch/main/img/example1.gif" height="160" />
20
+ <img src="https://raw.githubusercontent.com/materight/RepNet-pytorch/main/img/example2.gif" height="160" />
21
+ <img src="https://raw.githubusercontent.com/materight/RepNet-pytorch/main/img/example3.gif" height="160" />
22
+ <img src="https://raw.githubusercontent.com/materight/RepNet-pytorch/main/img/example4.gif" height="160" />
23
  </div>
24
 
25
  ## Get Started
 
48
 
49
  Example of generated videos showing the repetition count, with the periodicity score and the temporal self-similarity matrix:
50
  <div align="center">
51
+ <img src="https://raw.githubusercontent.com/materight/RepNet-pytorch/main/img/example5_score.gif" height="200" />
52
+ <img src="https://raw.githubusercontent.com/materight/RepNet-pytorch/main/img/example5_tsm.png" height="200" />
53
  </div>
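Side note for readers of this diff: below is a minimal usage sketch of the ported model, not part of the commit itself. It assumes the PyTorch weights produced by `convert_weights.py` (next file) have been saved to `checkpoints/pytorch_weights.pth`, mirroring what `run.py` does at the end of this commit.

```python
# Illustrative sketch (not part of this commit): load the converted weights and
# count repetitions in a single 64-frame clip, as run.py does with real videos.
import torch
from repnet.model import RepNet

model = RepNet()  # defaults to 64 frames per clip
model.load_state_dict(torch.load('checkpoints/pytorch_weights.pth'))
model.eval()

# Dummy clip standing in for 64 pre-processed frames, shape N x C x D x H x W.
clip = torch.randn(1, 3, 64, 112, 112)
with torch.no_grad():
    raw_period_length, raw_periodicity, embeddings = model(clip)
confidence, period_length, counts, periodicity = model.get_counts(
    raw_period_length[0], raw_periodicity[0], stride=1)
print(f'Estimated period length: {period_length:.1f} frames (confidence {confidence:.2f})')
```

For the full pre-processing (resizing frames to 112x112 and normalizing to [-1, 1]) and the multi-stride selection, see `run.py` below.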
convert_weights.py ADDED
@@ -0,0 +1,203 @@
1
+ """Script to download the pre-trained tensorflow weights and convert them to pytorch weights."""
2
+ import os
3
+ import argparse
4
+ import torch
5
+ import numpy as np
6
+ from tensorflow.python.training import py_checkpoint_reader
7
+
8
+ from repnet import utils
9
+ from repnet.model import RepNet
10
+
11
+
12
+ # Relevant paths
13
+ PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
14
+ TF_CHECKPOINT_BASE_URL = 'https://storage.googleapis.com/repnet_ckpt'
15
+ TF_CHECKPOINT_FILES = ['checkpoint', 'ckpt-88.data-00000-of-00002', 'ckpt-88.data-00001-of-00002', 'ckpt-88.index']
16
+ OUT_CHECKPOINTS_DIR = os.path.join(PROJECT_ROOT, 'checkpoints')
17
+
18
+ # Mapping of ndim -> permutation to go from tf to pytorch
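+ # (e.g. TF stores 2D conv kernels as H x W x in x out and dense kernels as in x out, while PyTorch expects out x in x H x W and out x in)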
19
+ WEIGHTS_PERMUTATION = {
20
+ 2: (1, 0),
21
+ 4: (3, 2, 0, 1),
22
+ 5: (4, 3, 0, 1, 2)
23
+ }
24
+
25
+ # Mapping of tf attributes -> pytorch attributes
26
+ ATTR_MAPPING = {
27
+ 'kernel': 'weight',
28
+ 'bias': 'bias',
29
+ 'beta': 'bias',
30
+ 'gamma': 'weight',
31
+ 'moving_mean': 'running_mean',
32
+ 'moving_variance': 'running_var'
33
+ }
34
+
35
+ # Mapping of tf checkpoint -> tf model -> pytorch model
36
+ WEIGHTS_MAPPING = [
37
+ # Base frame encoder
38
+ ('base_model.layer-2', 'conv1_conv', 'encoder.stem.conv'),
39
+ ('base_model.layer-5', 'conv2_block1_preact_bn', 'encoder.stages.0.blocks.0.norm1'),
40
+ ('base_model.layer-7', 'conv2_block1_1_conv', 'encoder.stages.0.blocks.0.conv1'),
41
+ ('base_model.layer-8', 'conv2_block1_1_bn', 'encoder.stages.0.blocks.0.norm2'),
42
+ ('base_model.layer_with_weights-4', 'conv2_block1_2_conv', 'encoder.stages.0.blocks.0.conv2'),
43
+ ('base_model.layer_with_weights-5', 'conv2_block1_2_bn', 'encoder.stages.0.blocks.0.norm3'),
44
+ ('base_model.layer_with_weights-6', 'conv2_block1_0_conv', 'encoder.stages.0.blocks.0.downsample.conv'),
45
+ ('base_model.layer_with_weights-7', 'conv2_block1_3_conv', 'encoder.stages.0.blocks.0.conv3'),
46
+ ('base_model.layer_with_weights-8', 'conv2_block2_preact_bn', 'encoder.stages.0.blocks.1.norm1'),
47
+ ('base_model.layer_with_weights-9', 'conv2_block2_1_conv', 'encoder.stages.0.blocks.1.conv1'),
48
+ ('base_model.layer_with_weights-10', 'conv2_block2_1_bn', 'encoder.stages.0.blocks.1.norm2'),
49
+ ('base_model.layer_with_weights-11', 'conv2_block2_2_conv', 'encoder.stages.0.blocks.1.conv2'),
50
+ ('base_model.layer_with_weights-12', 'conv2_block2_2_bn', 'encoder.stages.0.blocks.1.norm3'),
51
+ ('base_model.layer_with_weights-13', 'conv2_block2_3_conv', 'encoder.stages.0.blocks.1.conv3'),
52
+ ('base_model.layer_with_weights-14', 'conv2_block3_preact_bn', 'encoder.stages.0.blocks.2.norm1'),
53
+ ('base_model.layer_with_weights-15', 'conv2_block3_1_conv', 'encoder.stages.0.blocks.2.conv1'),
54
+ ('base_model.layer_with_weights-16', 'conv2_block3_1_bn', 'encoder.stages.0.blocks.2.norm2'),
55
+ ('base_model.layer_with_weights-17', 'conv2_block3_2_conv', 'encoder.stages.0.blocks.2.conv2'),
56
+ ('base_model.layer_with_weights-18', 'conv2_block3_2_bn', 'encoder.stages.0.blocks.2.norm3'),
57
+ ('base_model.layer_with_weights-19', 'conv2_block3_3_conv', 'encoder.stages.0.blocks.2.conv3'),
58
+ ('base_model.layer_with_weights-20', 'conv3_block1_preact_bn', 'encoder.stages.1.blocks.0.norm1'),
59
+ ('base_model.layer_with_weights-21', 'conv3_block1_1_conv', 'encoder.stages.1.blocks.0.conv1'),
60
+ ('base_model.layer_with_weights-22', 'conv3_block1_1_bn', 'encoder.stages.1.blocks.0.norm2'),
61
+ ('base_model.layer_with_weights-23', 'conv3_block1_2_conv', 'encoder.stages.1.blocks.0.conv2'),
62
+ ('base_model.layer-47', 'conv3_block1_2_bn', 'encoder.stages.1.blocks.0.norm3'),
63
+ ('base_model.layer_with_weights-25', 'conv3_block1_0_conv', 'encoder.stages.1.blocks.0.downsample.conv'),
64
+ ('base_model.layer_with_weights-26', 'conv3_block1_3_conv', 'encoder.stages.1.blocks.0.conv3'),
65
+ ('base_model.layer_with_weights-27', 'conv3_block2_preact_bn', 'encoder.stages.1.blocks.1.norm1'),
66
+ ('base_model.layer_with_weights-28', 'conv3_block2_1_conv', 'encoder.stages.1.blocks.1.conv1'),
67
+ ('base_model.layer_with_weights-29', 'conv3_block2_1_bn', 'encoder.stages.1.blocks.1.norm2'),
68
+ ('base_model.layer_with_weights-30', 'conv3_block2_2_conv', 'encoder.stages.1.blocks.1.conv2'),
69
+ ('base_model.layer_with_weights-31', 'conv3_block2_2_bn', 'encoder.stages.1.blocks.1.norm3'),
70
+ ('base_model.layer-61', 'conv3_block2_3_conv', 'encoder.stages.1.blocks.1.conv3'),
71
+ ('base_model.layer-63', 'conv3_block3_preact_bn', 'encoder.stages.1.blocks.2.norm1'),
72
+ ('base_model.layer-65', 'conv3_block3_1_conv', 'encoder.stages.1.blocks.2.conv1'),
73
+ ('base_model.layer-66', 'conv3_block3_1_bn', 'encoder.stages.1.blocks.2.norm2'),
74
+ ('base_model.layer-69', 'conv3_block3_2_conv', 'encoder.stages.1.blocks.2.conv2'),
75
+ ('base_model.layer-70', 'conv3_block3_2_bn', 'encoder.stages.1.blocks.2.norm3'),
76
+ ('base_model.layer_with_weights-38', 'conv3_block3_3_conv', 'encoder.stages.1.blocks.2.conv3'),
77
+ ('base_model.layer-74', 'conv3_block4_preact_bn', 'encoder.stages.1.blocks.3.norm1'),
78
+ ('base_model.layer_with_weights-40', 'conv3_block4_1_conv', 'encoder.stages.1.blocks.3.conv1'),
79
+ ('base_model.layer_with_weights-41', 'conv3_block4_1_bn', 'encoder.stages.1.blocks.3.norm2'),
80
+ ('base_model.layer_with_weights-42', 'conv3_block4_2_conv', 'encoder.stages.1.blocks.3.conv2'),
81
+ ('base_model.layer_with_weights-43', 'conv3_block4_2_bn', 'encoder.stages.1.blocks.3.norm3'),
82
+ ('base_model.layer_with_weights-44', 'conv3_block4_3_conv', 'encoder.stages.1.blocks.3.conv3'),
83
+ ('base_model.layer_with_weights-45', 'conv4_block1_preact_bn', 'encoder.stages.2.blocks.0.norm1'),
84
+ ('base_model.layer_with_weights-46', 'conv4_block1_1_conv', 'encoder.stages.2.blocks.0.conv1'),
85
+ ('base_model.layer_with_weights-47', 'conv4_block1_1_bn', 'encoder.stages.2.blocks.0.norm2'),
86
+ ('base_model.layer-92', 'conv4_block1_2_conv', 'encoder.stages.2.blocks.0.conv2'),
87
+ ('base_model.layer-93', 'conv4_block1_2_bn', 'encoder.stages.2.blocks.0.norm3'),
88
+ ('base_model.layer-95', 'conv4_block1_0_conv', 'encoder.stages.2.blocks.0.downsample.conv'),
89
+ ('base_model.layer-96', 'conv4_block1_3_conv', 'encoder.stages.2.blocks.0.conv3'),
90
+ ('base_model.layer-98', 'conv4_block2_preact_bn', 'encoder.stages.2.blocks.1.norm1'),
91
+ ('base_model.layer-100', 'conv4_block2_1_conv', 'encoder.stages.2.blocks.1.conv1'),
92
+ ('base_model.layer-101', 'conv4_block2_1_bn', 'encoder.stages.2.blocks.1.norm2'),
93
+ ('base_model.layer-104', 'conv4_block2_2_conv', 'encoder.stages.2.blocks.1.conv2'),
94
+ ('base_model.layer-105', 'conv4_block2_2_bn', 'encoder.stages.2.blocks.1.norm3'),
95
+ ('base_model.layer-107', 'conv4_block2_3_conv', 'encoder.stages.2.blocks.1.conv3'),
96
+ ('base_model.layer-109', 'conv4_block3_preact_bn', 'encoder.stages.2.blocks.2.norm1'),
97
+ ('base_model.layer-111', 'conv4_block3_1_conv', 'encoder.stages.2.blocks.2.conv1'),
98
+ ('base_model.layer-112', 'conv4_block3_1_bn', 'encoder.stages.2.blocks.2.norm2'),
99
+ ('base_model.layer-115', 'conv4_block3_2_conv', 'encoder.stages.2.blocks.2.conv2'),
100
+ ('base_model.layer-116', 'conv4_block3_2_bn', 'encoder.stages.2.blocks.2.norm3'),
101
+ ('base_model.layer-118', 'conv4_block3_3_conv', 'encoder.stages.2.blocks.2.conv3'),
102
+ # Temporal convolution
103
+ ('temporal_conv_layers.0', 'conv3d', 'temporal_conv.0'),
104
+ ('temporal_bn_layers.0', 'batch_normalization', 'temporal_conv.1'),
105
+ ('conv_3x3_layer', 'conv2d', 'tsm_conv.0'),
106
+ # Period length head
107
+ ('input_projection', 'dense', 'period_length_head.0.input_projection'),
108
+ ('pos_encoding', None, 'period_length_head.0.pos_encoding'),
109
+ ('transformer_layers.0.ffn.layer-0', None, 'period_length_head.0.transformer_layer.linear1'),
110
+ ('transformer_layers.0.ffn.layer-1', None, 'period_length_head.0.transformer_layer.linear2'),
111
+ ('transformer_layers.0.layernorm1', None, 'period_length_head.0.transformer_layer.norm1'),
112
+ ('transformer_layers.0.layernorm2', None, 'period_length_head.0.transformer_layer.norm2'),
113
+ ('transformer_layers.0.mha.w_weight', None, 'period_length_head.0.transformer_layer.self_attn.in_proj_weight'),
114
+ ('transformer_layers.0.mha.w_bias', None, 'period_length_head.0.transformer_layer.self_attn.in_proj_bias'),
115
+ ('transformer_layers.0.mha.dense', None, 'period_length_head.0.transformer_layer.self_attn.out_proj'),
116
+ ('fc_layers.0', 'dense_14', 'period_length_head.1'),
117
+ ('fc_layers.1', 'dense_15', 'period_length_head.3'),
118
+ ('fc_layers.2', 'dense_16', 'period_length_head.5'),
119
+ # Periodicity head
120
+ ('input_projection2', 'dense_1', 'periodicity_head.0.input_projection'),
121
+ ('pos_encoding2', None, 'periodicity_head.0.pos_encoding'),
122
+ ('transformer_layers2.0.ffn.layer-0', None, 'periodicity_head.0.transformer_layer.linear1'),
123
+ ('transformer_layers2.0.ffn.layer-1', None, 'periodicity_head.0.transformer_layer.linear2'),
124
+ ('transformer_layers2.0.layernorm1', None, 'periodicity_head.0.transformer_layer.norm1'),
125
+ ('transformer_layers2.0.layernorm2', None, 'periodicity_head.0.transformer_layer.norm2'),
126
+ ('transformer_layers2.0.mha.w_weight', None, 'periodicity_head.0.transformer_layer.self_attn.in_proj_weight'),
127
+ ('transformer_layers2.0.mha.w_bias', None, 'periodicity_head.0.transformer_layer.self_attn.in_proj_bias'),
128
+ ('transformer_layers2.0.mha.dense', None, 'periodicity_head.0.transformer_layer.self_attn.out_proj'),
129
+ ('within_period_fc_layers.0', 'dense_17', 'periodicity_head.1'),
130
+ ('within_period_fc_layers.1', 'dense_18', 'periodicity_head.3'),
131
+ ('within_period_fc_layers.2', 'dense_19', 'periodicity_head.5'),
132
+ ]
133
+
134
+ # Script arguments
135
+ parser = argparse.ArgumentParser(description='Download and convert the pre-trained weights from tensorflow to pytorch.')
136
+
137
+
138
+ if __name__ == '__main__':
139
+ args = parser.parse_args()
140
+
141
+ # Download tensorflow checkpoints
142
+ print('Downloading checkpoints...')
143
+ tf_checkpoint_dir = os.path.join(OUT_CHECKPOINTS_DIR, 'tf_checkpoint')
144
+ os.makedirs(tf_checkpoint_dir, exist_ok=True)
145
+ for file in TF_CHECKPOINT_FILES:
146
+ dst = os.path.join(tf_checkpoint_dir, file)
147
+ if not os.path.exists(dst):
148
+ utils.download_file(f'{TF_CHECKPOINT_BASE_URL}/{file}', dst)
149
+
150
+ # Load tensorflow weights into a dictionary
151
+ print('Loading tensorflow checkpoint...')
152
+ checkpoint_path = os.path.join(tf_checkpoint_dir, 'ckpt-88')
153
+ checkpoint_reader = py_checkpoint_reader.NewCheckpointReader(checkpoint_path)
154
+ shape_map = checkpoint_reader.get_variable_to_shape_map()
155
+ tf_state_dict = {}
156
+ for var_name in sorted(shape_map.keys()):
157
+ var_tensor = checkpoint_reader.get_tensor(var_name)
158
+ if not var_name.startswith('model') or '.OPTIMIZER_SLOT' in var_name:
159
+ continue # Skip variables that are not part of the model, e.g. from the optimizer
160
+ # Split var_name into path
161
+ var_path = var_name.split('/')[1:] # Remove the `model` key from the path
162
+ var_path = [p for p in var_path if p not in ['.ATTRIBUTES', 'VARIABLE_VALUE']]
163
+ # Map weights into a nested dictionary
164
+ current_dict = tf_state_dict
165
+ for path in var_path[:-1]:
166
+ current_dict = current_dict.setdefault(path, {})
167
+ current_dict[var_path[-1]] = var_tensor
168
+
169
+ # Merge transformer self-attention weights into a single tensor
170
+ for k in ['transformer_layers', 'transformer_layers2']:
171
+ v = tf_state_dict[k]['0']['mha']
172
+ v['w_weight'] = np.concatenate([v['wq']['kernel'].T, v['wk']['kernel'].T, v['wv']['kernel'].T], axis=0)
173
+ v['w_bias'] = np.concatenate([v['wq']['bias'].T, v['wk']['bias'].T, v['wv']['bias'].T], axis=0)
174
+ del v['wk'], v['wq'], v['wv']
175
+ tf_state_dict = utils.flatten_dict(tf_state_dict, keep_last=True)
176
+ # Add missing final level for some weights
177
+ for k, v in tf_state_dict.items():
178
+ if not isinstance(v, dict):
179
+ tf_state_dict[k] = {None: v}
180
+
181
+ # Convert to a format compatible with PyTorch and save
182
+ print('Converting to PyTorch format...')
183
+ pt_checkpoint_path = os.path.join(OUT_CHECKPOINTS_DIR, 'pytorch_weights.pth')
184
+ pt_state_dict = {}
185
+ for k_tf, _, k_pt in WEIGHTS_MAPPING:
186
+ assert k_pt not in pt_state_dict
187
+ pt_state_dict[k_pt] = {}
188
+ for attr in tf_state_dict[k_tf]:
189
+ new_attr = ATTR_MAPPING.get(attr, attr)
190
+ pt_state_dict[k_pt][new_attr] = torch.from_numpy(tf_state_dict[k_tf][attr])
191
+ if attr == 'kernel':
192
+ weights_permutation = WEIGHTS_PERMUTATION[pt_state_dict[k_pt][new_attr].ndim] # Permute weights if needed
193
+ pt_state_dict[k_pt][new_attr] = pt_state_dict[k_pt][new_attr].permute(weights_permutation)
194
+ pt_state_dict = utils.flatten_dict(pt_state_dict, skip_none=True)
195
+ torch.save(pt_state_dict, pt_checkpoint_path)
196
+
197
+ # Initialize the model and try to load the weights
198
+ print('Check that the weights can be loaded into the model...')
199
+ model = RepNet()
200
+ pt_state_dict = torch.load(pt_checkpoint_path)
201
+ model.load_state_dict(pt_state_dict)
202
+
203
+ print(f'Done. PyTorch weights saved to {pt_checkpoint_path}.')
repnet/__init__.py ADDED
File without changes
repnet/model.py ADDED
@@ -0,0 +1,192 @@
1
+ """PyTorch implementation of RepNet."""
2
+ import torch
3
+ from torch import nn
4
+ from typing import Tuple
5
+
6
+
7
+ # List of ResNet50V2 conv layers that use a bias in the TensorFlow implementation
8
+ CONVS_WITH_BIAS = [
9
+ 'stem.conv',
10
+ 'stages.0.blocks.0.downsample.conv', 'stages.0.blocks.0.conv3', 'stages.0.blocks.1.conv3', 'stages.0.blocks.2.conv3',
11
+ 'stages.1.blocks.0.downsample.conv', 'stages.1.blocks.0.conv3', 'stages.1.blocks.1.conv3', 'stages.1.blocks.2.conv3', 'stages.1.blocks.3.conv3',
12
+ 'stages.2.blocks.0.downsample.conv', 'stages.2.blocks.0.conv3', 'stages.2.blocks.1.conv3', 'stages.2.blocks.2.conv3',
13
+ ]
14
+
15
+ # List of ResNet50V2 conv layers that use stride 1 in the TensorFlow implementation
16
+ CONVS_WITHOUT_STRIDE = [
17
+ 'stages.1.blocks.0.downsample.conv', 'stages.1.blocks.0.conv2',
18
+ 'stages.2.blocks.0.downsample.conv', 'stages.2.blocks.0.conv2',
19
+ ]
20
+
21
+ # List of ResNet50V2 blocks that use max pooling instead of stride 2 in the TensorFlow implementation
22
+ FINAL_BLOCKS_WITH_MAX_POOL = [
23
+ 'stages.0.blocks.2', 'stages.1.blocks.3',
24
+ ]
25
+
26
+
27
+ class RepNet(nn.Module):
28
+ """RepNet model."""
29
+ def __init__(self, num_frames: int = 64, temperature: float = 13.544):
30
+ super().__init__()
31
+ self.num_frames = num_frames
32
+ self.temperature = temperature
33
+ self.encoder = self._init_encoder()
34
+ self.temporal_conv = nn.Sequential(
35
+ nn.Conv3d(1024, 512, kernel_size=3, dilation=(3, 1, 1), padding=(3, 1, 1)),
36
+ nn.BatchNorm3d(512, eps=0.001),
37
+ nn.ReLU(inplace=True),
38
+ nn.AdaptiveMaxPool3d((None, 1, 1)),
39
+ nn.Flatten(2, 4),
40
+ )
41
+ self.tsm_conv = nn.Sequential(
42
+ nn.Conv2d(1, 32, kernel_size=3, padding=1),
43
+ nn.ReLU(inplace=True),
44
+ )
45
+ self.period_length_head = self._init_transformer_head(num_frames, 2048, 4, 512, num_frames // 2)
46
+ self.periodicity_head = self._init_transformer_head(num_frames, 2048, 4, 512, 1)
47
+
48
+
49
+ @staticmethod
50
+ def _init_encoder() -> nn.Module:
51
+ """Initialize the encoder network using ResNet50 V2."""
52
+ encoder = torch.hub.load('huggingface/pytorch-image-models', 'resnetv2_50')
53
+ # Remove unused layers
54
+ del encoder.stages[2].blocks[3:6], encoder.stages[3]
55
+ encoder.norm = nn.Identity()
56
+ encoder.head.global_pool = nn.Identity()
57
+ encoder.head.fc = nn.Identity()
58
+ encoder.head.flatten = nn.Identity()
59
+ # Change padding from -inf to 0 on max pool to have the same behavior as tensorflow
60
+ encoder.stem.pool.padding = 0
61
+ encoder.stem.pool = nn.Sequential(nn.ZeroPad2d((1, 1, 1, 1)), encoder.stem.pool)
62
+ # Change properties of existing layers
63
+ for name, module in encoder.named_modules():
64
+ # Add missing bias to conv layers
65
+ if name in CONVS_WITH_BIAS:
66
+ module.bias = nn.Parameter(torch.zeros(module.out_channels))
67
+ # Remove stride from the first block in the later stages
68
+ if name in CONVS_WITHOUT_STRIDE:
69
+ module.stride = (1, 1)
70
+ # Change stride and add max pooling to final block
71
+ if name in FINAL_BLOCKS_WITH_MAX_POOL:
72
+ module.conv2.stride = (2, 2)
73
+ module.downsample = nn.MaxPool2d(1, stride=2)
74
+ # Change the forward function so that the input of max pooling is the raw `x` instead of the pre-activation result
75
+ bound_method = _max_pool_block_forward.__get__(module, module.__class__)
76
+ setattr(module, 'forward', bound_method)
77
+ # Change eps in batchnorm layers
78
+ if isinstance(module, nn.BatchNorm2d):
79
+ module.eps = 1.001e-5
80
+ return encoder
81
+
82
+
83
+ @staticmethod
84
+ def _init_transformer_head(num_frames: int, in_features: int, n_head: int, hidden_features: int, out_features: int) -> nn.Module:
85
+ """Initialize the fully-connected head for the final output."""
86
+ return nn.Sequential(
87
+ TransformerLayer(in_features, n_head, hidden_features, num_frames),
88
+ nn.Linear(hidden_features, hidden_features),
89
+ nn.ReLU(inplace=True),
90
+ nn.Linear(hidden_features, hidden_features),
91
+ nn.ReLU(inplace=True),
92
+ nn.Linear(hidden_features, out_features),
93
+ )
94
+
95
+
96
+ def extract_feat(self, x: torch.Tensor) -> torch.Tensor:
97
+ """Forward pass of the encoder network to extract per-frame embeddings. Expected input shape: N x C x D x H x W."""
98
+ batch_size, _, seq_len, _, _ = x.shape
99
+ torch._assert(seq_len == self.num_frames, f'Expected {self.num_frames} frames, got {seq_len}')
100
+ # Extract features frame-by-frame
101
+ x = x.movedim(1, 2).flatten(0, 1)
102
+ x = self.encoder(x)
103
+ x = x.unflatten(0, (batch_size, seq_len)).movedim(1, 2)
104
+ # Temporal convolution
105
+ x = self.temporal_conv(x)
106
+ x = x.movedim(1, 2) # Convert to N x D x C
107
+ return x
108
+
109
+
110
+ def period_predictor(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
111
+ """Forward pass of the period predictor network from the extracted embeddings. Expected input shape: N x D x C."""
112
+ batch_size, seq_len, _ = x.shape
113
+ torch._assert(seq_len == self.num_frames, f'Expected {self.num_frames} frames, got {seq_len}')
114
+ # Compute temporal self-similarity matrix
115
+ x = torch.cdist(x, x)**2 # N x D x D
116
+ x = -x / self.temperature
117
+ x = x.softmax(dim=-1)
118
+ # Conv layer on top of the TSM
119
+ x = self.tsm_conv(x.unsqueeze(1))
120
+ x = x.movedim(1, 3).reshape(batch_size, seq_len, -1) # Flatten channels into N x D x C
121
+ # Final prediction heads
122
+ period_length = self.period_length_head(x)
123
+ periodicity = self.periodicity_head(x)
124
+ return period_length, periodicity
125
+
126
+
127
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
128
+ """Forward pass. Expected input shape: N x C x D x H x W."""
129
+ embeddings = self.extract_feat(x)
130
+ period_length, periodicity = self.period_predictor(embeddings)
131
+ return period_length, periodicity, embeddings
132
+
133
+
134
+ @staticmethod
135
+ def get_counts(raw_period_length: torch.Tensor, raw_periodicity: torch.Tensor, stride: int,
136
+ periodicity_threshold: float = 0.5) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
137
+ """Compute the final scores from the period length and periodicity predictions."""
138
+ # Repeat the input to account for the stride
139
+ raw_period_length = raw_period_length.repeat_interleave(stride, dim=0)
140
+ raw_periodicity = raw_periodicity.repeat_interleave(stride, dim=0)
141
+ # Compute the final scores in [0, 1]
142
+ periodicity_score = torch.sigmoid(raw_periodicity).squeeze(-1)
143
+ period_length_confidence, period_length = torch.max(torch.softmax(raw_period_length, dim=-1), dim=-1)
144
+ # Remove the confidence for short periods and convert to the correct stride
145
+ period_length_confidence[period_length < 2] = 0
146
+ period_length = (period_length + 1) * stride
147
+ periodicity_score = torch.sqrt(periodicity_score * period_length_confidence)
148
+ # Generate the final counts and set them to 0 if the periodicity is too low
149
+ period_count = 1 / period_length
150
+ period_count[periodicity_score < periodicity_threshold] = 0
151
+ period_length = 1 / (torch.mean(period_count) + 1e-6)
152
+ period_count = torch.cumsum(period_count, dim=0)
153
+ confidence = torch.mean(periodicity_score)
154
+ return confidence, period_length, period_count, periodicity_score
155
+
156
+
157
+
158
+ class TransformerLayer(nn.Module):
159
+ """A single transformer layer with self-attention and positional encoding."""
160
+
161
+ def __init__(self, in_features: int, n_head: int, out_features: int, num_frames: int):
162
+ super().__init__()
163
+ self.input_projection = nn.Linear(in_features, out_features)
164
+ self.pos_encoding = nn.Parameter(torch.normal(mean=0, std=0.02, size=(1, num_frames, 1)))
165
+ self.transformer_layer = nn.TransformerEncoderLayer(
166
+ d_model=out_features, nhead=n_head, dim_feedforward=out_features, activation='relu',
167
+ layer_norm_eps=1e-6, batch_first=True, norm_first=True
168
+ )
169
+
170
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
171
+ """Forward pass. Expected input shape: N x D x C."""
172
+ x = self.input_projection(x)
173
+ x = x + self.pos_encoding
174
+ x = self.transformer_layer(x)
175
+ return x
176
+
177
+
178
+
179
+ def _max_pool_block_forward(self, x):
180
+ """
181
+ Custom `forward` function for the last block of each stage in ResNetV2, to have the same behavior as tensorflow.
182
+ Original implementation: https://github.com/huggingface/pytorch-image-models/blob/4b8cfa6c0a355a9b3cb2a77298b240213fb3b921/timm/models/resnetv2.py#L197
183
+ """
184
+ x_preact = self.norm1(x)
185
+ shortcut = x
186
+ if self.downsample is not None:
187
+ shortcut = self.downsample(x) # Changed here from `x_preact` to `x`
188
+ x = self.conv1(x_preact)
189
+ x = self.conv2(self.norm2(x))
190
+ x = self.conv3(self.norm3(x))
191
+ x = self.drop_path(x)
192
+ return x + shortcut
repnet/plots.py ADDED
@@ -0,0 +1,66 @@
1
+ """Utility functions for plotting."""
2
+ import cv2
3
+ import numpy as np
4
+ from typing import List, Optional
5
+ from sklearn.decomposition import PCA
6
+
7
+
8
+ def plot_heatmap(dist: np.ndarray, log_scale: bool = False) -> np.ndarray:
9
+ """Plot the temporal self-similarity matrix into an OpenCV image."""
10
+ np.fill_diagonal(dist, np.nan)
11
+ if log_scale:
12
+ dist = np.log(1 + dist)
13
+ dist = -dist # Invert the distance
14
+ zmin, zmax = np.nanmin(dist), np.nanmax(dist)
15
+ heatmap = (dist - zmin) / (zmax - zmin) # Normalize into [0, 1]
16
+ heatmap = np.nan_to_num(heatmap, nan=1)
17
+ heatmap = np.clip(heatmap * 255, 0, 255).astype(np.uint8)
18
+ heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_VIRIDIS)
19
+ return heatmap
20
+
21
+
22
+ def plot_pca(embeddings: List[np.ndarray]) -> np.ndarray:
23
+ """Plot the 1D PCA of the embeddings into an OpenCV image."""
24
+ projection = PCA(n_components=1).fit_transform(embeddings).flatten()
25
+ projection = (projection - projection.min()) / (projection.max() - projection.min())
26
+ h, w = 200, len(projection) * 4
27
+ img = np.full((h, w, 3), 255, dtype=np.uint8)
28
+ y = ((1 - projection) * h).astype(np.int32)
29
+ x = (np.arange(len(y)) / len(y) * w).astype(np.int32)
30
+ pts = np.stack([x, y], axis=1).reshape((-1, 1, 2))
31
+ img = cv2.polylines(img, [pts], False, (102, 60, 0), 1, cv2.LINE_AA)
32
+ return img
33
+
34
+
35
+ def plot_repetitions(frames: List[np.ndarray], counts: List[float], periodicity: Optional[List[float]]) -> List[np.ndarray]:
36
+ """Generate video with repetition counts and return frames."""
37
+ blue_dark, blue_light = (102, 60, 0), (215, 175, 121)
38
+ h, w, _ = frames[0].shape
39
+ pbar_r = max(int(min(w, h) * 0.1), 20)
40
+ pbar_c = (pbar_r + 5, pbar_r + 5)
41
+ txt_s = pbar_r / 30
42
+ assert len(frames) == len(counts), 'Number of frames and counts must match.'
43
+ out_frames = []
44
+ for i, (frame, count) in enumerate(zip(frames, counts)):
45
+ frame = frame.copy()
46
+ # Draw progress bar
47
+ frame = cv2.ellipse(frame, pbar_c, (pbar_r, pbar_r), -90, 0, 360, blue_dark, -1, cv2.LINE_AA)
48
+ frame = cv2.ellipse(frame, pbar_c, (pbar_r, pbar_r), -90, 0, 360 * (count % 1.0), blue_light, -1, cv2.LINE_AA)
49
+ txt_box, _ = cv2.getTextSize(str(int(count)), cv2.FONT_HERSHEY_SIMPLEX, txt_s, 2)
50
+ txt_c = (pbar_c[0] - txt_box[0] // 2, pbar_c[1] + txt_box[1] // 2)
51
+ frame = cv2.putText(frame, str(int(count)), txt_c, cv2.FONT_HERSHEY_SIMPLEX, txt_s, (255, 255, 255), 2, cv2.LINE_AA)
52
+ # Draw periodicity plot on the right if available
53
+ if periodicity is not None:
54
+ periodicity = np.asarray(periodicity)
55
+ padx, pady, window_size = 5, 10, 64
56
+ pcanvas_h, pcanvas_w = frame.shape[0], min(frame.shape[0], frame.shape[1])
57
+ pcanvas = np.full((pcanvas_h, pcanvas_w, 3), 255, dtype=np.uint8)
58
+ pcanvas[pady::int((pcanvas_h - pady*2) / 10), :, :] = (235, 235, 235) # Draw horizontal grid
59
+ y = ((1 - periodicity[:i+1][-window_size:]) * (pcanvas_h - pady*2) + pady).astype(np.int32)
60
+ x = ((np.arange(len(y)) / window_size) * (pcanvas_w - padx*2)).astype(np.int32)
61
+ pts = np.stack([x, y], axis=1).reshape((-1, 1, 2))
62
+ pcanvas = cv2.polylines(pcanvas, [pts], False, blue_dark, 1, cv2.LINE_AA)
63
+ pcanvas = cv2.circle(pcanvas, (x[-1], y[-1]), 2, (0, 0, 255), -1, cv2.LINE_AA)
64
+ frame = np.concatenate([frame, pcanvas], axis=1)
65
+ out_frames.append(frame)
66
+ return out_frames
repnet/utils.py ADDED
@@ -0,0 +1,41 @@
1
+ """Utility functions."""
2
+ import os
3
+ import shutil
4
+ import requests
5
+ import yt_dlp
6
+
7
+
8
+
9
+ def flatten_dict(dictionary: dict, parent_key: str = '', sep: str = '.', keep_last: bool = False, skip_none: bool = False):
10
+ """Flatten a nested dictionary into a single dictionary with keys separated by `sep`."""
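+ # e.g. flatten_dict({'a': {'b': 1, 'c': 2}}) -> {'a.b': 1, 'a.c': 2}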
11
+ items = {}
12
+ for k, v in dictionary.items():
13
+ key_prefix = parent_key if parent_key else ''
14
+ key_suffix = k if not skip_none or k is not None else ''
15
+ key_sep = sep if key_prefix and key_suffix else ''
16
+ new_key = key_prefix + key_sep + key_suffix
17
+ if isinstance(v, dict) and (not keep_last or isinstance(next(iter(v.values())), dict)):
18
+ items.update(flatten_dict(v, new_key, sep=sep, keep_last=keep_last, skip_none=skip_none))
19
+ else:
20
+ items[new_key] = v
21
+ return items
22
+
23
+
24
+
25
+ YOUTUBE_DL_DOMAINS = ['youtube.com', 'imgur.com', 'reddit.com']
26
+ def download_file(url: str, dst: str):
27
+ """Download a file from a given url."""
28
+ if any(domain in url for domain in YOUTUBE_DL_DOMAINS):
29
+ # Download video from YouTube
30
+ with yt_dlp.YoutubeDL(dict(format='bestvideo[ext=mp4]/mp4', outtmpl=dst, quiet=True)) as ydl:
31
+ ydl.download([url])
32
+ elif url.startswith('http://') or url.startswith('https://'):
33
+ # Download file from HTTP
34
+ response = requests.get(url, timeout=10)
35
+ with open(dst, 'wb') as file:
36
+ file.write(response.content)
37
+ elif os.path.exists(url) and os.path.isfile(url):
38
+ # Copy file from local path
39
+ shutil.copyfile(url, dst)
40
+ else:
41
+ raise ValueError(f'Invalid url: {url}')
requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ torch==1.10.0
2
+ torchvision==0.11.1
3
+ numpy==1.21.6
4
+ opencv_python==4.6.0.66
5
+ requests==2.27.1
6
+ tensorflow==2.11.0
7
+ yt_dlp==2023.2.17
run.py ADDED
@@ -0,0 +1,116 @@
1
+ """Run the RepNet model on a given video."""
2
+ import os
3
+ import cv2
4
+ import argparse
5
+ import torch
6
+ import torchvision.transforms as T
7
+
8
+ from repnet import utils, plots
9
+ from repnet.model import RepNet
10
+
11
+
12
+ PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
13
+ OUT_VISUALIZATIONS_DIR = os.path.join(PROJECT_ROOT, 'visualizations')
14
+ SAMPLE_VIDEOS_URLS = [
15
+ 'https://imgur.com/t/hummingbird/m2e2Nfa', # Hummingbird
16
+ 'https://www.youtube.com/watch?v=w0JOoC-5_Lk', # Chopping
17
+ 'https://www.youtube.com/watch?v=t9OE3nxnI2Y', # Hammer training
18
+ 'https://www.youtube.com/watch?v=aY3TrpiUOqE', # Bouncing ball
19
+ 'https://www.youtube.com/watch?v=5EYY2J3nb5c', # Cooking
20
+ 'https://www.reddit.com/r/gifs/comments/4qfif6/cheetah_running_at_63_mph_102_kph', # Cheetah
21
+ 'https://www.youtube.com/watch?v=cMWb7NvWWuI', # Pendulum
22
+ 'https://www.youtube.com/watch?v=5g1T-ff07kM', # Exercise
23
+ 'https://www.youtube.com/watch?v=-Q3_7T5w4nE', # Exercise
24
+
25
+ ]
26
+
27
+ # Script arguments
28
+ parser = argparse.ArgumentParser(description='Run the RepNet model on a given video.')
29
+ parser.add_argument('--weights', type=str, default=os.path.join(PROJECT_ROOT, 'checkpoints', 'pytorch_weights.pth'), help='Path to the model weights (default: %(default)s).')
30
+ parser.add_argument('--video', type=str, default=SAMPLE_VIDEOS_URLS[0], help='Video to test the model on: a YouTube URL, an HTTP(S) URL, or a local file path (default: %(default)s).')
31
+ parser.add_argument('--strides', nargs='+', type=int, default=[1, 2, 3, 4, 8], help='Temporal strides to try when testing on the sample video (default: %(default)s).')
32
+ parser.add_argument('--device', type=str, default='cuda', help='Device to use for inference (default: %(default)s).')
33
+ parser.add_argument('--no-score', action='store_true', help='If specified, do not plot the periodicity score.')
34
+
35
+ if __name__ == '__main__':
36
+ args = parser.parse_args()
37
+
38
+ # Download the video sample if needed
39
+ print(f'Downloading {args.video}...')
40
+ video_path = os.path.join(PROJECT_ROOT, 'videos', os.path.basename(args.video) + '.mp4')
41
+ if not os.path.exists(video_path):
42
+ os.makedirs(os.path.dirname(video_path), exist_ok=True)
43
+ utils.download_file(args.video, video_path)
44
+
45
+ # Read frames and apply preprocessing
46
+ print('Reading video file and pre-processing frames...')
47
+ transform = T.Compose([
48
+ T.ToPILImage(),
49
+ T.Resize((112, 112)),
50
+ T.ToTensor(),
51
+ T.Normalize(mean=0.5, std=0.5),
52
+ ])
53
+ cap = cv2.VideoCapture(video_path)
54
+ fps = cap.get(cv2.CAP_PROP_FPS)
55
+ raw_frames, frames = [], []
56
+ while cap.isOpened():
57
+ ret, frame = cap.read()
58
+ if not ret or frame is None:
59
+ break
60
+ raw_frames.append(frame)
61
+ frame = transform(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
62
+ frames.append(frame)
63
+ cap.release()
64
+
65
+ # Load model
66
+ model = RepNet()
67
+ state_dict = torch.load(args.weights)
68
+ model.load_state_dict(state_dict)
69
+ model.eval()
70
+ model.to(args.device)
71
+
72
+ # Test multiple strides and pick the best one
73
+ print('Running inference on multiple stride values...')
74
+ best_stride, best_confidence, best_period_length, best_period_count, best_periodicity_score, best_embeddings = None, None, None, None, None, None
75
+ for stride in args.strides:
76
+ # Apply stride
77
+ stride_frames = frames[::stride]
78
+ stride_frames = stride_frames[:(len(stride_frames) // 64) * 64]
79
+ if len(stride_frames) < 64:
80
+ continue # Skip this stride if there are not enough frames
81
+ stride_frames = torch.stack(stride_frames, axis=0).unflatten(0, (-1, 64)).movedim(1, 2) # Convert to N x C x D x H x W
82
+ stride_frames = stride_frames.to(args.device)
83
+ # Run inference
84
+ raw_period_length, raw_periodicity_score, embeddings = [], [], []
85
+ with torch.no_grad():
86
+ for i in range(stride_frames.shape[0]): # Process each batch separately to avoid OOM
87
+ batch_period_length, batch_periodicity, batch_embeddings = model(stride_frames[i].unsqueeze(0))
88
+ raw_period_length.append(batch_period_length[0].cpu())
89
+ raw_periodicity_score.append(batch_periodicity[0].cpu())
90
+ embeddings.append(batch_embeddings[0].cpu())
91
+ # Post-process results
92
+ raw_period_length, raw_periodicity_score, embeddings = torch.cat(raw_period_length), torch.cat(raw_periodicity_score), torch.cat(embeddings)
93
+ confidence, period_length, period_count, periodicity_score = model.get_counts(raw_period_length, raw_periodicity_score, stride)
94
+ if best_confidence is None or confidence > best_confidence:
95
+ best_stride, best_confidence, best_period_length, best_period_count, best_periodicity_score, best_embeddings = stride, confidence, period_length, period_count, periodicity_score, embeddings
96
+ if best_stride is None:
97
+ raise RuntimeError('The stride values used are too large and no 64-frame video chunk could be sampled. Try different values for --strides.')
98
+ print(f'Predicted a period length of {best_period_length/fps:.1f} seconds (~{int(best_period_length)} frames) with a confidence of {best_confidence:.2f} using a stride of {best_stride} frames.')
99
+
100
+ # Generate plots and videos
101
+ print(f'Save plots and video with counts to {OUT_VISUALIZATIONS_DIR}...')
102
+ os.makedirs(OUT_VISUALIZATIONS_DIR, exist_ok=True)
103
+ dist = torch.cdist(best_embeddings, best_embeddings, p=2)**2
104
+ tsm_img = plots.plot_heatmap(dist.numpy(), log_scale=True)
105
+ pca_img = plots.plot_pca(best_embeddings.numpy())
106
+ cv2.imwrite(os.path.join(OUT_VISUALIZATIONS_DIR, 'tsm.png'), tsm_img)
107
+ cv2.imwrite(os.path.join(OUT_VISUALIZATIONS_DIR, 'pca.png'), pca_img)
108
+
109
+ # Generate video with counts
110
+ rep_frames = plots.plot_repetitions(raw_frames[:len(best_period_count)], best_period_count.tolist(), best_periodicity_score.tolist() if not args.no_score else None)
111
+ video = cv2.VideoWriter(os.path.join(OUT_VISUALIZATIONS_DIR, 'repetitions.mp4'), cv2.VideoWriter_fourcc(*'mp4v'), fps, rep_frames[0].shape[:2][::-1])
112
+ for frame in rep_frames:
113
+ video.write(frame)
114
+ video.release()
115
+
116
+ print('Done')