raoulritter committed
Commit 331a0e7
1 Parent(s): 82fe504

Provide examples to GitHub and checkpoints to get help from the Edge Impulse forums

.gitattributes CHANGED
@@ -32,3 +32,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ ckpt_e09.pth.tar filter=lfs diff=lfs merge=lfs -text
+ ckpt_e10.pth.tar filter=lfs diff=lfs merge=lfs -text
+ ckpt_e49.pth.tar filter=lfs diff=lfs merge=lfs -text
+ ckpt_pytorch_1_11_e00.pth.tar filter=lfs diff=lfs merge=lfs -text
Test_ONNX_Convert.ipynb ADDED
@@ -0,0 +1,259 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "d0cd0fce",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-05-15T16:14:15.749744Z",
+ "start_time": "2023-05-15T16:14:15.540642Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import onnx\n",
+ "\n",
+ "\n",
+ "onnx_model = onnx.load(\"ckpt/model.onnx\")\n",
+ "onnx.checker.check_model(onnx_model)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "27f06a8c",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-05-15T16:14:15.751689Z",
+ "start_time": "2023-05-15T16:14:15.748975Z"
+ }
+ },
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "f9167299",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-05-15T16:14:15.777873Z",
+ "start_time": "2023-05-15T16:14:15.753825Z"
+ }
+ },
+ "outputs": [
+ {
+ "ename": "ModuleNotFoundError",
+ "evalue": "No module named 'onnx_tf'",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[10], line 8\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mautograd\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m Variable\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01monnx\u001b[39;00m\n\u001b[0;32m----> 8\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01monnx_tf\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mbackend\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m prepare\n\u001b[1;32m 10\u001b[0m model \u001b[38;5;241m=\u001b[39m onnx\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mckpt/model_5x.onnx\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 11\u001b[0m tf_rep \u001b[38;5;241m=\u001b[39m prepare(model)\n",
+ "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'onnx_tf'"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "import torch.optim as optim\n",
+ "from torchvision import datasets, transforms\n",
+ "from torch.autograd import Variable\n",
+ "import onnx\n",
+ "from onnx_tf.backend import prepare\n",
+ "\n",
+ "model = onnx.load('ckpt/model_5x.onnx')\n",
+ "tf_rep = prepare(model)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2d1db936",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f4eb7d23",
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.onnx\n",
+ "import onnx\n",
+ "\n",
+ "\n",
+ "import torch.nn as nn\n",
+ "\n",
+ "from models.model import STBVMM\n",
+ "\n",
+ "# # Initialize model with checkpointing enabled\n",
+ "# model = STBVMM(img_size=384, patch_size=1, in_chans=3,\n",
+ "#                embed_dim=48, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],\n",
+ "#                window_size=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,\n",
+ "#                drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,\n",
+ "#                norm_layer=nn.LayerNorm, ape=False, patch_norm=True,\n",
+ "#                use_checkpoint=True, img_range=1., resi_connection='1conv',\n",
+ "#                manipulator_num_resblk = 1)\n",
+ "\n",
+ "model = STBVMM(img_size=384, patch_size=1, in_chans=3,\n",
+ "               embed_dim=192, depths=[6, 6, 6, 6, 6, 6], num_heads=[6, 6, 6, 6, 6, 6],\n",
+ "               window_size=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,\n",
+ "               drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,\n",
+ "               norm_layer=nn.LayerNorm, ape=False, patch_norm=True,\n",
+ "               use_checkpoint=False, img_range=1., resi_connection='1conv',\n",
+ "               manipulator_num_resblk=1)\n",
+ "\n",
+ "# Load pretrained weights from checkpoint\n",
+ "checkpoint = torch.load('ckpt/ckpt_e49.pth.tar')\n",
+ "# print(checkpoint.keys())\n",
+ "\n",
+ "# print(checkpoint['state_dict'])\n",
+ "\n",
+ "model.load_state_dict(checkpoint['state_dict'], strict=False)\n",
+ "\n",
+ "# Set the model to eval mode\n",
+ "model.eval()\n",
+ "\n",
+ "# Export model to ONNX\n",
+ "inputs = (torch.randn(1, 3, 384, 384), torch.randn(1, 3, 384, 384), 5)\n",
+ "input_names = [\"a\", \"b\", \"amp\"]\n",
+ "output_names = [\"output\"]\n",
+ "dynamic_axes = {\"a\": {0: \"batch_size\", 2: \"height\", 3: \"width\"},\n",
+ "                \"b\": {0: \"batch_size\", 2: \"height\", 3: \"width\"},\n",
+ "                \"output\": {0: \"batch_size\", 2: \"height\", 3: \"width\"}}\n",
+ "torch.onnx.export(model, inputs, \"model_checkpoint_5x_50ep.onnx\", input_names=input_names, output_names=output_names,\n",
+ "                  dynamic_axes=dynamic_axes, opset_version=11)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "8d70c0e9",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-05-16T04:50:57.341575Z",
+ "start_time": "2023-05-16T04:50:57.144003Z"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import onnx\n",
+ "\n",
+ "onnx_model = onnx.load(\"ckpt/model_checkpoint_5x_50ep.onnx\")\n",
+ "onnx.checker.check_model(onnx_model)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ec9bacfd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import onnxruntime as ort\n",
+ "import numpy as np\n",
+ "import cv2\n",
+ "x, y = test_data[0][0], test_data[0][1]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "1bea7afb",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-05-16T05:23:43.230764Z",
+ "start_time": "2023-05-16T05:23:29.950857Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Using device: cpu\r\n",
+ "demo_video/STB-VMM_Freezer_x20_mag\r\n",
+ "processing sample: 0\r\n",
+ "Traceback (most recent call last):\r\n",
+ "  File \"/Users/raoulritter/STB-VMM/onnxrun.py\", line 112, in <module>\r\n",
+ "    main(args)\r\n",
+ "  File \"/Users/raoulritter/STB-VMM/onnxrun.py\", line 53, in main\r\n",
+ "    ort_outs = ort_session.run(None, ort_inputs)\r\n",
+ "  File \"/opt/anaconda3/envs/afstudeer/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py\", line 200, in run\r\n",
+ "    return self._sess.run(output_names, input_feed, run_options)\r\n",
+ "RuntimeError: Input must be a list of dictionaries or a single numpy array for input 'a'.\r\n"
+ ]
+ }
+ ],
+ "source": [
+ "!python onnxrun.py -j4 -b1 --load_ckpt ckpt/model_checkpoint_5x_50ep.onnx --save_dir demo_video/STB-VMM_Freezer_x20_mag -m 5 --video_path demo_video/STB-VMM_Freezer_x20_original/frame --num_data 6644 --mode static\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "c77312e7",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-05-16T05:14:51.342430Z",
+ "start_time": "2023-05-16T05:14:48.496955Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "1.11.0\n",
+ "0.12.0\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "import torchvision\n",
+ "print(torch.__version__)\n",
+ "print(torchvision.__version__)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "702d9d85",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python [conda env:afstudeer]",
+ "language": "python",
+ "name": "conda-env-afstudeer-py"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
ckpt_e09.pth.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4b503d280322ad5a257fd760878447c3e13e80800baa06b8288ee37bd79173ce
+ size 149374251
ckpt_e10.pth.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b63b6fdfbd487d5482c0e4b821040df85e315407a2205e755b426cf9c94492ce
+ size 149368279
ckpt_e49.pth.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1f1df7bebba895be14728293138812826c1affeb4777f76be960e8eb100ed362
+ size 149368983
ckpt_pytorch_1_11_e00.pth.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:449bb57e7b0c3a17580f5512cde77397abf6178ea30f28a726d700eac2343920
+ size 149368087
model.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:de4e1e8b51f1cf371159c53a9efdc93cdd879c8f4406941e20301d05d3718c67
+ size 137258146
model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:32c2cea3c5ef96e308f5c8b8a0b6418d8e00cacf1c6c5b3e388e796f00ccb079
+ size 149306703
model_2.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e72d64cec630d95ba0bc4e49ee7c7e2f1a4a71dcef83ab3b2b1e24ab75fd7a9c
+ size 136977868
model_5x.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:de4e1e8b51f1cf371159c53a9efdc93cdd879c8f4406941e20301d05d3718c67
+ size 137258146
model_checkpoint_5x_50ep.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7681377d1db055c8b4f4a345052830beb20a027526235fcc6c4c2232edb876bb
+ size 136977913
onnxrun.py ADDED
@@ -0,0 +1,112 @@
+ 
+ import argparse
+ import os
+ import numpy as np
+ import torch
+ import torch.utils.data as data
+ from PIL import Image
+ from utils.data_loader import ImageFromFolderTest
+ import onnxruntime as ort
+ 
+ 
+ def main(args):
+     # Device choice (auto)
+     if args.device == 'auto':
+         device = 'cuda' if torch.cuda.is_available() else 'cpu'
+     else:
+         device = args.device
+ 
+     print(f'Using device: {device}')
+ 
+     # Create ONNX Inference Session
+     ort_session = ort.InferenceSession(args.load_ckpt)
+ 
+     # Check saving directory
+     save_dir = args.save_dir
+     if not os.path.exists(save_dir):
+         os.makedirs(save_dir)
+     print(save_dir)
+ 
+     # Data loader
+     dataset_mag = ImageFromFolderTest(
+         args.video_path, mag=args.mag, mode=args.mode, num_data=args.num_data, preprocessing=False)
+     data_loader = data.DataLoader(dataset_mag,
+                                   batch_size=args.batch_size,
+                                   shuffle=False,
+                                   num_workers=args.workers,
+                                   pin_memory=False)
+ 
+     # Magnification
+     for i, (xa, xb, mag_factor) in enumerate(data_loader):
+         if i % args.print_freq == 0:
+             print('processing sample: %d' % i)
+ 
+         xa = xa.to(device)
+         xb = xb.to(device)
+ 
+         # Infer using ONNX model; ONNX Runtime expects numpy arrays, not torch tensors
+         mag_factor = np.array([[args.mag]], dtype=np.float32)  # constant magnification factor as numpy
+         ort_inputs = {ort_session.get_inputs()[0].name: xa.cpu().numpy(),
+                       ort_session.get_inputs()[1].name: xb.cpu().numpy(),
+                       ort_session.get_inputs()[2].name: mag_factor}
+         # y_hat, _, _, _ = ort_session.run(ort_inputs)
+         ort_outs = ort_session.run(None, ort_inputs)
+         y_hat = ort_outs[0]
+ 
+         # ort_inputs = {ort_session.get_inputs()[0].name: xa,
+         #               ort_session.get_inputs()[1].name: xb}
+         # y_hat, _, _, _ = ort_session.run(None, ort_inputs)
+ 
+         if i == 0:
+             # Back to image scale (0-255)
+             tmp = xa.permute(0, 2, 3, 1).cpu().detach().numpy()
+             tmp = np.clip(tmp, -1.0, 1.0)
+             tmp = ((tmp + 1.0) * 127.5).astype(np.uint8)
+ 
+             # Save first frame
+             fn = os.path.join(save_dir, 'STBVMM_%s_%06d.png' % (args.mode, i))
+             im = Image.fromarray(np.concatenate(tmp, 0))
+             im.save(fn)
+ 
+         # Back to image scale (0-255); ONNX Runtime outputs are already numpy arrays
+         y_hat = np.transpose(y_hat, (0, 2, 3, 1))
+         y_hat = np.clip(y_hat, -1.0, 1.0)
+         y_hat = ((y_hat + 1.0) * 127.5).astype(np.uint8)
+ 
+         # Save frames
+         fn = os.path.join(save_dir, 'STBVMM_%s_%06d.png' % (args.mode, i+1))
+         im = Image.fromarray(np.concatenate(y_hat, 0))
+         im.save(fn)
+ 
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(
+         description='Swin Transformer Based Video Motion Magnification')
+ 
+     # Application parameters
+     parser.add_argument('-i', '--video_path', type=str, metavar='PATH', required=True,
+                         help='path to video input frames')
+     parser.add_argument('-c', '--load_ckpt', type=str, metavar='PATH', required=True,
+                         help='path to load ONNX model')
+     parser.add_argument('-o', '--save_dir', default='demo', type=str, metavar='PATH',
+                         help='path to save generated frames (default: demo)')
+     parser.add_argument('-m', '--mag', metavar='N', default=20.0, type=float,
+                         help='magnification factor (default: 20.0)')
+     parser.add_argument('--mode', default='static', type=str, choices=['static', 'dynamic'],
+                         help='magnification mode (static, dynamic)')
+     parser.add_argument('-n', '--num_data', type=int, metavar='N', required=True,
+                         help='number of frames')
+ 
+     # Execute parameters
+     parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
+                         help='number of data loading workers (default: 16)')
+     parser.add_argument('-b', '--batch_size', default=1, type=int,
+                         metavar='N', help='batch size (default: 1)')
+     parser.add_argument('-p', '--print_freq', default=100, type=int,
+                         metavar='N', help='print frequency (default: 100)')
+ 
+     # Device
+     parser.add_argument('--device', type=str, default='auto',
+                         choices=['auto', 'cpu', 'cuda', 'mps', 'xla'],
+                         help='select device [auto/cpu/cuda] (default: auto)')
+     args = parser.parse_args()
+     main(args)
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch==2.0
+ Pillow==9.3
+ torchvision
+ torchaudio
+ numpy
run.py ADDED
@@ -0,0 +1,133 @@
+ import argparse
+ import os
+ 
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.utils.data as data
+ import torchvision.datasets as datasets
+ from PIL import Image
+ 
+ from utils.data_loader import ImageFromFolderTest
+ from models.model import STBVMM
+ 
+ 
+ def main(args):
+     # Device choice (auto)
+     if args.device == 'auto':
+         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     else:
+         device = args.device
+ 
+     print(f'Using device: {device}')
+ 
+     # Create model
+     model = STBVMM(img_size=384, patch_size=1, in_chans=3,
+                    embed_dim=192, depths=[6, 6, 6, 6, 6, 6], num_heads=[6, 6, 6, 6, 6, 6],
+                    window_size=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,
+                    drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
+                    norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
+                    use_checkpoint=False, img_range=1., resi_connection='1conv',
+                    manipulator_num_resblk=1).to(device)
+ 
+     # Load checkpoint
+     if os.path.isfile(args.load_ckpt):
+         print("=> loading checkpoint '{}'".format(args.load_ckpt))
+         checkpoint = torch.load(args.load_ckpt)
+         args.start_epoch = checkpoint['epoch']
+ 
+         model.load_state_dict(checkpoint['state_dict'])
+ 
+         print("=> loaded checkpoint '{}' (epoch {})"
+               .format(args.load_ckpt, checkpoint['epoch']))
+     else:
+         print("=> no checkpoint found at '{}'".format(args.load_ckpt))
+         assert (False)
+ 
+     # Check saving directory
+     save_dir = args.save_dir
+     if not os.path.exists(save_dir):
+         os.makedirs(save_dir)
+     print(save_dir)
+ 
+     # Data loader
+     dataset_mag = ImageFromFolderTest(
+         args.video_path, mag=args.mag, mode=args.mode, num_data=args.num_data, preprocessing=False)
+     data_loader = data.DataLoader(dataset_mag,
+                                   batch_size=args.batch_size,
+                                   shuffle=False,
+                                   num_workers=args.workers,
+                                   pin_memory=False)
+ 
+     # Generate frames
+     model.eval()
+ 
+     # Magnification
+     for i, (xa, xb, mag_factor) in enumerate(data_loader):
+         if i % args.print_freq == 0:
+             print('processing sample: %d' % i)
+ 
+         mag_factor = mag_factor.unsqueeze(1).unsqueeze(1).unsqueeze(1)
+ 
+         xa = xa.to(device)
+         xb = xb.to(device)
+         mag_factor = mag_factor.to(device)
+ 
+         y_hat, _, _, _ = model(xa, xb, mag_factor)
+ 
+         if i == 0:
+             # Back to image scale (0-255)
+             tmp = xa.permute(0, 2, 3, 1).cpu().detach().numpy()
+             tmp = np.clip(tmp, -1.0, 1.0)
+             tmp = ((tmp + 1.0) * 127.5).astype(np.uint8)
+ 
+             # Save first frame
+             fn = os.path.join(save_dir, 'STBVMM_%s_%06d.png' % (args.mode, i))
+             im = Image.fromarray(np.concatenate(tmp, 0))
+             im.save(fn)
+ 
+         # back to image scale (0-255)
+         y_hat = y_hat.permute(0, 2, 3, 1).cpu().detach().numpy()
+         y_hat = np.clip(y_hat, -1.0, 1.0)
+         y_hat = ((y_hat + 1.0) * 127.5).astype(np.uint8)
+ 
+         # Save frames
+         fn = os.path.join(save_dir, 'STBVMM_%s_%06d.png' % (args.mode, i+1))
+         im = Image.fromarray(np.concatenate(y_hat, 0))
+         im.save(fn)
+ 
+ 
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(
+         description='Swin Transformer Based Video Motion Magnification')
+ 
+     # Application parameters
+     parser.add_argument('-i', '--video_path', type=str, metavar='PATH', required=True,
+                         help='path to video input frames')
+     parser.add_argument('-c', '--load_ckpt', type=str, metavar='PATH', required=True,
+                         help='path to load checkpoint')
+     parser.add_argument('-o', '--save_dir', default='demo', type=str, metavar='PATH',
+                         help='path to save generated frames (default: demo)')
+     parser.add_argument('-m', '--mag', metavar='N', default=20.0, type=float,
+                         help='magnification factor (default: 20.0)')
+     parser.add_argument('--mode', default='static', type=str, choices=['static', 'dynamic'],
+                         help='magnification mode (static, dynamic)')
+     parser.add_argument('-n', '--num_data', type=int, metavar='N', required=True,
+                         help='number of frames')
+ 
+     # Execute parameters
+     parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
+                         help='number of data loading workers (default: 16)')
+     parser.add_argument('-b', '--batch_size', default=1, type=int,
+                         metavar='N', help='batch size (default: 1)')
+     parser.add_argument('-p', '--print_freq', default=100, type=int,
+                         metavar='N', help='print frequency (default: 100)')
+ 
+     # Device
+     parser.add_argument('--device', type=str, default='auto',
+                         choices=['auto', 'cpu', 'cuda', 'mps', 'xla'],
+                         help='select device [auto/cpu/cuda] (default: auto)')
+ 
+     args = parser.parse_args()
+ 
+     main(args)
test_convert.py ADDED
@@ -0,0 +1,36 @@
+ import torch
+ import torch.nn as nn
+ from models.model import STBVMM
+ 
+ 
+ model = STBVMM(img_size=384, patch_size=1, in_chans=3,
+                embed_dim=192, depths=[6, 6, 6, 6, 6, 6], num_heads=[6, 6, 6, 6, 6, 6],
+                window_size=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,
+                drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
+                norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
+                use_checkpoint=False, img_range=1., resi_connection='1conv',
+                manipulator_num_resblk=1).to("cpu")
+ 
+ 
+ 
+ checkpoint = torch.load('ckpt/ckpt_e10.pth.tar')
+ # print(checkpoint.keys())
+ 
+ print(checkpoint['state_dict'])
+ 
+ model.load_state_dict(checkpoint['state_dict'], strict=False)
+ # Get the keys in the checkpoint's state_dict
+ checkpoint_keys = set(checkpoint['state_dict'].keys())
+ 
+ # Get the keys in the current model's state_dict
+ model_keys = set(model.state_dict().keys())
+ 
+ # Find the difference between the keys
+ keys_only_in_checkpoint = checkpoint_keys - model_keys
+ keys_only_in_model = model_keys - checkpoint_keys
+ 
+ # Print the results
+ print("Keys only in the checkpoint's state_dict:")
+ print(keys_only_in_checkpoint)
+ 
+ 