obvious-research committed
Commit • 67d652f
1 Parent(s): f907d71

Uploaded frozen models
Browse files
- frozen_models/.DS_Store +0 -0
- frozen_models/pytorch_i3d/LICENSE.txt +202 -0
- frozen_models/pytorch_i3d/README.md +25 -0
- frozen_models/pytorch_i3d/charades_dataset.py +125 -0
- frozen_models/pytorch_i3d/charades_dataset_full.py +123 -0
- frozen_models/pytorch_i3d/extract_features.py +90 -0
- frozen_models/pytorch_i3d/models/flow_charades.pt +3 -0
- frozen_models/pytorch_i3d/models/flow_imagenet.pt +3 -0
- frozen_models/pytorch_i3d/models/rgb_charades.pt +3 -0
- frozen_models/pytorch_i3d/models/rgb_imagenet.pt +3 -0
- frozen_models/pytorch_i3d/pytorch_i3d.py +338 -0
- frozen_models/pytorch_i3d/train_i3d.py +133 -0
- frozen_models/pytorch_i3d/videotransforms.py +102 -0
frozen_models/.DS_Store
ADDED
Binary file (6.15 kB)
frozen_models/pytorch_i3d/LICENSE.txt
ADDED
@@ -0,0 +1,202 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
frozen_models/pytorch_i3d/README.md
ADDED
@@ -0,0 +1,25 @@
# I3D models trained on Kinetics

## Overview

This repository contains trained models reported in the paper "[Quo Vadis,
Action Recognition? A New Model and the Kinetics
Dataset](https://arxiv.org/abs/1705.07750)" by Joao Carreira and Andrew
Zisserman.

This code is based on DeepMind's [Kinetics-I3D](https://github.com/deepmind/kinetics-i3d) and includes PyTorch versions of their models.

## Note
This code was written for PyTorch 0.3. Version 0.4 and newer may cause issues.


# Fine-tuning and Feature Extraction
We provide code to extract I3D features and fine-tune I3D for Charades. Our fine-tuned models on Charades are also available in the models directory (in addition to DeepMind's trained models). The DeepMind pre-trained models were converted to PyTorch and give identical results (flow_imagenet.pt and rgb_imagenet.pt). These models were pre-trained on ImageNet and Kinetics (see [Kinetics-I3D](https://github.com/deepmind/kinetics-i3d) for details).

## Fine-tuning I3D
[train_i3d.py](train_i3d.py) contains the code to fine-tune I3D based on the details in the paper and obtained from the authors. Specifically, this version follows the settings used to fine-tune on the [Charades](allenai.org/plato/charades/) dataset, based on the authors' implementation that won the Charades 2017 challenge. Our fine-tuned RGB and Flow I3D models are available in the models directory (rgb_charades.pt and flow_charades.pt).

This relies on having the optical flow and RGB frames extracted and saved as images on disk. [charades_dataset.py](charades_dataset.py) contains our code to load video segments for training.

## Feature Extraction
[extract_features.py](extract_features.py) contains the code to load a pre-trained I3D model, extract the features, and save them as numpy arrays. The [charades_dataset_full.py](charades_dataset_full.py) script loads an entire video to extract per-segment features.
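For orientation, the snippet below is a minimal sketch of programmatic feature extraction with the bundled weights (rather than the CLI in extract_features.py). It assumes models/rgb_imagenet.pt is present and uses a random clip in place of real frames; on PyTorch 0.3 the input tensor would additionally be wrapped in torch.autograd.Variable.

import torch
from pytorch_i3d import InceptionI3d

# Load the Kinetics RGB model shipped in this commit (400 output classes).
i3d = InceptionI3d(400, in_channels=3)
i3d.load_state_dict(torch.load('models/rgb_imagenet.pt', map_location='cpu'))
i3d.train(False)

# Dummy clip standing in for real frames: B x C x T x H x W, values in [-1, 1].
clip = torch.randn(1, 3, 64, 224, 224)
features = i3d.extract_features(clip)
print(features.shape)  # 1 x 1024 x 7 x 1 x 1 for a 64-frame 224x224 clip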
frozen_models/pytorch_i3d/charades_dataset.py
ADDED
@@ -0,0 +1,125 @@
import torch
import torch.utils.data as data_utl
from torch.utils.data.dataloader import default_collate

import numpy as np
import json
import csv
import h5py
import random
import os
import os.path

import cv2

def video_to_tensor(pic):
    """Convert a ``numpy.ndarray`` to tensor.
    Converts a numpy.ndarray (T x H x W x C)
    to a torch.FloatTensor of shape (C x T x H x W)

    Args:
        pic (numpy.ndarray): Video to be converted to tensor.
    Returns:
        Tensor: Converted video.
    """
    return torch.from_numpy(pic.transpose([3, 0, 1, 2]))


def load_rgb_frames(image_dir, vid, start, num):
    frames = []
    for i in range(start, start + num):
        img = cv2.imread(os.path.join(image_dir, vid, vid + '-' + str(i).zfill(6) + '.jpg'))[:, :, [2, 1, 0]]
        w, h, c = img.shape
        if w < 226 or h < 226:
            d = 226. - min(w, h)
            sc = 1 + d / min(w, h)
            img = cv2.resize(img, dsize=(0, 0), fx=sc, fy=sc)
        img = (img / 255.) * 2 - 1
        frames.append(img)
    return np.asarray(frames, dtype=np.float32)

def load_flow_frames(image_dir, vid, start, num):
    frames = []
    for i in range(start, start + num):
        imgx = cv2.imread(os.path.join(image_dir, vid, vid + '-' + str(i).zfill(6) + 'x.jpg'), cv2.IMREAD_GRAYSCALE)
        imgy = cv2.imread(os.path.join(image_dir, vid, vid + '-' + str(i).zfill(6) + 'y.jpg'), cv2.IMREAD_GRAYSCALE)

        w, h = imgx.shape
        if w < 224 or h < 224:
            d = 224. - min(w, h)
            sc = 1 + d / min(w, h)
            imgx = cv2.resize(imgx, dsize=(0, 0), fx=sc, fy=sc)
            imgy = cv2.resize(imgy, dsize=(0, 0), fx=sc, fy=sc)

        imgx = (imgx / 255.) * 2 - 1
        imgy = (imgy / 255.) * 2 - 1
        img = np.asarray([imgx, imgy]).transpose([1, 2, 0])
        frames.append(img)
    return np.asarray(frames, dtype=np.float32)


def make_dataset(split_file, split, root, mode, num_classes=157):
    dataset = []
    with open(split_file, 'r') as f:
        data = json.load(f)

    i = 0
    for vid in data.keys():
        if data[vid]['subset'] != split:
            continue

        if not os.path.exists(os.path.join(root, vid)):
            continue
        num_frames = len(os.listdir(os.path.join(root, vid)))
        if mode == 'flow':
            num_frames = num_frames // 2

        if num_frames < 66:
            continue

        label = np.zeros((num_classes, num_frames), np.float32)

        fps = num_frames / data[vid]['duration']
        for ann in data[vid]['actions']:
            for fr in range(0, num_frames, 1):
                if fr / fps > ann[1] and fr / fps < ann[2]:
                    label[ann[0], fr] = 1  # binary classification
        dataset.append((vid, label, data[vid]['duration'], num_frames))
        i += 1

    return dataset


class Charades(data_utl.Dataset):

    def __init__(self, split_file, split, root, mode, transforms=None):

        self.data = make_dataset(split_file, split, root, mode)
        self.split_file = split_file
        self.transforms = transforms
        self.mode = mode
        self.root = root

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        vid, label, dur, nf = self.data[index]
        start_f = random.randint(1, nf - 65)

        if self.mode == 'rgb':
            imgs = load_rgb_frames(self.root, vid, start_f, 64)
        else:
            imgs = load_flow_frames(self.root, vid, start_f, 64)
        label = label[:, start_f:start_f + 64]

        imgs = self.transforms(imgs)

        return video_to_tensor(imgs), torch.from_numpy(label)

    def __len__(self):
        return len(self.data)
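For reference, a minimal sketch of how this loader is typically wired up. The split file and frame root are placeholder paths; their layout is whatever make_dataset() expects (a Charades-style json split and one directory of extracted frames per video).

import torch
from torchvision import transforms
import videotransforms
from charades_dataset import Charades

train_transforms = transforms.Compose([videotransforms.RandomCrop(224),
                                       videotransforms.RandomHorizontalFlip()])

# Placeholder paths: point these at the Charades split json and extracted frames.
dataset = Charades('charades/charades.json', 'training', '/path/to/Charades_v1_rgb',
                   'rgb', train_transforms)
loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)

clip, label = dataset[0]
print(clip.shape)   # 3 x 64 x 224 x 224 (C x T x H x W)
print(label.shape)  # 157 x 64 (classes x frames)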
frozen_models/pytorch_i3d/charades_dataset_full.py
ADDED
@@ -0,0 +1,123 @@
import torch
import torch.utils.data as data_utl
from torch.utils.data.dataloader import default_collate

import numpy as np
import json
import csv
import h5py

import os
import os.path

import cv2

def video_to_tensor(pic):
    """Convert a ``numpy.ndarray`` to tensor.
    Converts a numpy.ndarray (T x H x W x C)
    to a torch.FloatTensor of shape (C x T x H x W)

    Args:
        pic (numpy.ndarray): Video to be converted to tensor.
    Returns:
        Tensor: Converted video.
    """
    return torch.from_numpy(pic.transpose([3, 0, 1, 2]))


def load_rgb_frames(image_dir, vid, start, num):
    frames = []
    for i in range(start, start + num):
        img = cv2.imread(os.path.join(image_dir, vid, vid + '-' + str(i).zfill(6) + '.jpg'))[:, :, [2, 1, 0]]
        w, h, c = img.shape
        if w < 226 or h < 226:
            d = 226. - min(w, h)
            sc = 1 + d / min(w, h)
            img = cv2.resize(img, dsize=(0, 0), fx=sc, fy=sc)
        img = (img / 255.) * 2 - 1
        frames.append(img)
    return np.asarray(frames, dtype=np.float32)

def load_flow_frames(image_dir, vid, start, num):
    frames = []
    for i in range(start, start + num):
        imgx = cv2.imread(os.path.join(image_dir, vid, vid + '-' + str(i).zfill(6) + 'x.jpg'), cv2.IMREAD_GRAYSCALE)
        imgy = cv2.imread(os.path.join(image_dir, vid, vid + '-' + str(i).zfill(6) + 'y.jpg'), cv2.IMREAD_GRAYSCALE)

        w, h = imgx.shape
        if w < 224 or h < 224:
            d = 224. - min(w, h)
            sc = 1 + d / min(w, h)
            imgx = cv2.resize(imgx, dsize=(0, 0), fx=sc, fy=sc)
            imgy = cv2.resize(imgy, dsize=(0, 0), fx=sc, fy=sc)

        imgx = (imgx / 255.) * 2 - 1
        imgy = (imgy / 255.) * 2 - 1
        img = np.asarray([imgx, imgy]).transpose([1, 2, 0])
        frames.append(img)
    return np.asarray(frames, dtype=np.float32)


def make_dataset(split_file, split, root, mode, num_classes=157):
    dataset = []
    with open(split_file, 'r') as f:
        data = json.load(f)

    i = 0
    for vid in data.keys():
        if data[vid]['subset'] != split:
            continue

        if not os.path.exists(os.path.join(root, vid)):
            continue
        num_frames = len(os.listdir(os.path.join(root, vid)))
        if mode == 'flow':
            num_frames = num_frames // 2

        label = np.zeros((num_classes, num_frames), np.float32)

        fps = num_frames / data[vid]['duration']
        for ann in data[vid]['actions']:
            for fr in range(0, num_frames, 1):
                if fr / fps > ann[1] and fr / fps < ann[2]:
                    label[ann[0], fr] = 1  # binary classification
        dataset.append((vid, label, data[vid]['duration'], num_frames))
        i += 1

    return dataset


class Charades(data_utl.Dataset):

    def __init__(self, split_file, split, root, mode, transforms=None, save_dir='', num=0):

        self.data = make_dataset(split_file, split, root, mode)
        self.split_file = split_file
        self.transforms = transforms
        self.mode = mode
        self.root = root
        self.save_dir = save_dir

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        vid, label, dur, nf = self.data[index]
        if os.path.exists(os.path.join(self.save_dir, vid + '.npy')):
            return 0, 0, vid

        if self.mode == 'rgb':
            imgs = load_rgb_frames(self.root, vid, 1, nf)
        else:
            imgs = load_flow_frames(self.root, vid, 1, nf)

        imgs = self.transforms(imgs)

        return video_to_tensor(imgs), torch.from_numpy(label), vid

    def __len__(self):
        return len(self.data)
frozen_models/pytorch_i3d/extract_features.py
ADDED
@@ -0,0 +1,90 @@
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
import sys
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-mode', type=str, help='rgb or flow')
parser.add_argument('-load_model', type=str)
parser.add_argument('-root', type=str)
parser.add_argument('-gpu', type=str)
parser.add_argument('-save_dir', type=str)

args = parser.parse_args()
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable

import torchvision
from torchvision import datasets, transforms
import videotransforms


import numpy as np

from pytorch_i3d import InceptionI3d

from charades_dataset_full import Charades as Dataset


def run(max_steps=64e3, mode='rgb', root='/ssd2/charades/Charades_v1_rgb', split='charades/charades.json', batch_size=1, load_model='', save_dir=''):
    # setup dataset
    test_transforms = transforms.Compose([videotransforms.CenterCrop(224)])

    dataset = Dataset(split, 'training', root, mode, test_transforms, num=-1, save_dir=save_dir)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)

    val_dataset = Dataset(split, 'testing', root, mode, test_transforms, num=-1, save_dir=save_dir)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)

    dataloaders = {'train': dataloader, 'val': val_dataloader}
    datasets = {'train': dataset, 'val': val_dataset}


    # setup the model
    if mode == 'flow':
        i3d = InceptionI3d(400, in_channels=2)
    else:
        i3d = InceptionI3d(400, in_channels=3)
    i3d.replace_logits(157)
    i3d.load_state_dict(torch.load(load_model))
    i3d.cuda()

    for phase in ['train', 'val']:
        i3d.train(False)  # Set model to evaluate mode

        tot_loss = 0.0
        tot_loc_loss = 0.0
        tot_cls_loss = 0.0

        # Iterate over data.
        for data in dataloaders[phase]:
            # get the inputs
            inputs, labels, name = data
            if os.path.exists(os.path.join(save_dir, name[0] + '.npy')):
                continue

            b, c, t, h, w = inputs.shape
            if t > 1600:
                # process long videos in overlapping chunks to fit in GPU memory
                features = []
                for start in range(1, t - 56, 1600):
                    end = min(t - 1, start + 1600 + 56)
                    start = max(1, start - 48)
                    ip = Variable(torch.from_numpy(inputs.numpy()[:, :, start:end]).cuda(), volatile=True)
                    features.append(i3d.extract_features(ip).squeeze(0).permute(1, 2, 3, 0).data.cpu().numpy())
                np.save(os.path.join(save_dir, name[0]), np.concatenate(features, axis=0))
            else:
                # wrap them in Variable
                inputs = Variable(inputs.cuda(), volatile=True)
                features = i3d.extract_features(inputs)
                np.save(os.path.join(save_dir, name[0]), features.squeeze(0).permute(1, 2, 3, 0).data.cpu().numpy())


if __name__ == '__main__':
    run(mode=args.mode, root=args.root, load_model=args.load_model, save_dir=args.save_dir)
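The arrays this script writes are straightforward to consume later. A small sketch with a hypothetical save_dir and video id; the saved layout follows from the permute(1, 2, 3, 0) call above.

import numpy as np

feats = np.load('save_dir/VIDEO_ID.npy')   # hypothetical path and video id
print(feats.shape)                         # (T', 1, 1, 1024) per-segment features
feats = feats.reshape(feats.shape[0], -1)  # T' x 1024, one row per temporal segment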
frozen_models/pytorch_i3d/models/flow_charades.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:74a8b6f6226ec6850aec384dd91f86a617be70a585417d9cc71ceec59289cd7e
size 49802179
frozen_models/pytorch_i3d/models/flow_imagenet.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:81c8650f33698a9ad5a81ded395a496f1363497e83e8e63ab364c7539dc740b0
size 50795330
frozen_models/pytorch_i3d/models/rgb_charades.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:65a54c0f09ef7aa0b1028f15a83792dd3d023fc52a25e6ffbef252eb55da0933
size 49886838
frozen_models/pytorch_i3d/models/rgb_imagenet.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2609088c2e8c868187c9921c50bc225329a9057ed75e76120e0b4a397a2c7538
size 50883138
frozen_models/pytorch_i3d/pytorch_i3d.py
ADDED
@@ -0,0 +1,338 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

import numpy as np

import os
import sys
from collections import OrderedDict


class MaxPool3dSamePadding(nn.MaxPool3d):

    def compute_pad(self, dim, s):
        if s % self.stride[dim] == 0:
            return max(self.kernel_size[dim] - self.stride[dim], 0)
        else:
            return max(self.kernel_size[dim] - (s % self.stride[dim]), 0)

    def forward(self, x):
        # compute 'same' padding
        (batch, channel, t, h, w) = x.size()
        pad_t = self.compute_pad(0, t)
        pad_h = self.compute_pad(1, h)
        pad_w = self.compute_pad(2, w)

        pad_t_f = pad_t // 2
        pad_t_b = pad_t - pad_t_f
        pad_h_f = pad_h // 2
        pad_h_b = pad_h - pad_h_f
        pad_w_f = pad_w // 2
        pad_w_b = pad_w - pad_w_f

        pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
        x = F.pad(x, pad)
        return super(MaxPool3dSamePadding, self).forward(x)


class Unit3D(nn.Module):

    def __init__(self, in_channels,
                 output_channels,
                 kernel_shape=(1, 1, 1),
                 stride=(1, 1, 1),
                 padding=0,
                 activation_fn=F.relu,
                 use_batch_norm=True,
                 use_bias=False,
                 name='unit_3d'):
        """Initializes Unit3D module."""
        super(Unit3D, self).__init__()

        self._output_channels = output_channels
        self._kernel_shape = kernel_shape
        self._stride = stride
        self._use_batch_norm = use_batch_norm
        self._activation_fn = activation_fn
        self._use_bias = use_bias
        self.name = name
        self.padding = padding

        self.conv3d = nn.Conv3d(in_channels=in_channels,
                                out_channels=self._output_channels,
                                kernel_size=self._kernel_shape,
                                stride=self._stride,
                                padding=0,  # we always want padding to be 0 here; we will dynamically pad based on input size in the forward function
                                bias=self._use_bias)

        if self._use_batch_norm:
            self.bn = nn.BatchNorm3d(self._output_channels, eps=0.001, momentum=0.01)

    def compute_pad(self, dim, s):
        if s % self._stride[dim] == 0:
            return max(self._kernel_shape[dim] - self._stride[dim], 0)
        else:
            return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0)

    def forward(self, x):
        # compute 'same' padding
        (batch, channel, t, h, w) = x.size()
        pad_t = self.compute_pad(0, t)
        pad_h = self.compute_pad(1, h)
        pad_w = self.compute_pad(2, w)

        pad_t_f = pad_t // 2
        pad_t_b = pad_t - pad_t_f
        pad_h_f = pad_h // 2
        pad_h_b = pad_h - pad_h_f
        pad_w_f = pad_w // 2
        pad_w_b = pad_w - pad_w_f

        pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
        x = F.pad(x, pad)

        x = self.conv3d(x)
        if self._use_batch_norm:
            x = self.bn(x)
        if self._activation_fn is not None:
            x = self._activation_fn(x)
        return x


class InceptionModule(nn.Module):
    def __init__(self, in_channels, out_channels, name):
        super(InceptionModule, self).__init__()

        self.b0 = Unit3D(in_channels=in_channels, output_channels=out_channels[0], kernel_shape=[1, 1, 1], padding=0,
                         name=name + '/Branch_0/Conv3d_0a_1x1')
        self.b1a = Unit3D(in_channels=in_channels, output_channels=out_channels[1], kernel_shape=[1, 1, 1], padding=0,
                          name=name + '/Branch_1/Conv3d_0a_1x1')
        self.b1b = Unit3D(in_channels=out_channels[1], output_channels=out_channels[2], kernel_shape=[3, 3, 3],
                          name=name + '/Branch_1/Conv3d_0b_3x3')
        self.b2a = Unit3D(in_channels=in_channels, output_channels=out_channels[3], kernel_shape=[1, 1, 1], padding=0,
                          name=name + '/Branch_2/Conv3d_0a_1x1')
        self.b2b = Unit3D(in_channels=out_channels[3], output_channels=out_channels[4], kernel_shape=[3, 3, 3],
                          name=name + '/Branch_2/Conv3d_0b_3x3')
        self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3],
                                        stride=(1, 1, 1), padding=0)
        self.b3b = Unit3D(in_channels=in_channels, output_channels=out_channels[5], kernel_shape=[1, 1, 1], padding=0,
                          name=name + '/Branch_3/Conv3d_0b_1x1')
        self.name = name

    def forward(self, x):
        b0 = self.b0(x)
        b1 = self.b1b(self.b1a(x))
        b2 = self.b2b(self.b2a(x))
        b3 = self.b3b(self.b3a(x))
        return torch.cat([b0, b1, b2, b3], dim=1)


class InceptionI3d(nn.Module):
    """Inception-v1 I3D architecture.
    The model is introduced in:
        Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset
        Joao Carreira, Andrew Zisserman
        https://arxiv.org/pdf/1705.07750v1.pdf.
    See also the Inception architecture, introduced in:
        Going deeper with convolutions
        Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
        Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
        http://arxiv.org/pdf/1409.4842v1.pdf.
    """

    # Endpoints of the model in order. During construction, all the endpoints up
    # to a designated `final_endpoint` are returned in a dictionary as the
    # second return value.
    VALID_ENDPOINTS = (
        'Conv3d_1a_7x7',
        'MaxPool3d_2a_3x3',
        'Conv3d_2b_1x1',
        'Conv3d_2c_3x3',
        'MaxPool3d_3a_3x3',
        'Mixed_3b',
        'Mixed_3c',
        'MaxPool3d_4a_3x3',
        'Mixed_4b',
        'Mixed_4c',
        'Mixed_4d',
        'Mixed_4e',
        'Mixed_4f',
        'MaxPool3d_5a_2x2',
        'Mixed_5b',
        'Mixed_5c',
        'Logits',
        'Predictions',
    )

    def __init__(self, num_classes=400, spatial_squeeze=True,
                 final_endpoint='Logits', name='inception_i3d', in_channels=3, dropout_keep_prob=0.5):
        """Initializes I3D model instance.
        Args:
          num_classes: The number of outputs in the logit layer (default 400, which
              matches the Kinetics dataset).
          spatial_squeeze: Whether to squeeze the spatial dimensions for the logits
              before returning (default True).
          final_endpoint: The model contains many possible endpoints.
              `final_endpoint` specifies the last endpoint for the model to be built
              up to. In addition to the output at `final_endpoint`, all the outputs
              at endpoints up to `final_endpoint` will also be returned, in a
              dictionary. `final_endpoint` must be one of
              InceptionI3d.VALID_ENDPOINTS (default 'Logits').
          name: A string (optional). The name of this module.
        Raises:
          ValueError: if `final_endpoint` is not recognized.
        """

        if final_endpoint not in self.VALID_ENDPOINTS:
            raise ValueError('Unknown final endpoint %s' % final_endpoint)

        super(InceptionI3d, self).__init__()
        self._num_classes = num_classes
        self._spatial_squeeze = spatial_squeeze
        self._final_endpoint = final_endpoint
        self.logits = None

        self.end_points = {}
        end_point = 'Conv3d_1a_7x7'
        self.end_points[end_point] = Unit3D(in_channels=in_channels, output_channels=64, kernel_shape=[7, 7, 7],
                                            stride=(2, 2, 2), padding=(3, 3, 3), name=name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'MaxPool3d_2a_3x3'
        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
                                                          padding=0)
        if self._final_endpoint == end_point: return

        end_point = 'Conv3d_2b_1x1'
        self.end_points[end_point] = Unit3D(in_channels=64, output_channels=64, kernel_shape=[1, 1, 1], padding=0,
                                            name=name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'Conv3d_2c_3x3'
        self.end_points[end_point] = Unit3D(in_channels=64, output_channels=192, kernel_shape=[3, 3, 3], padding=1,
                                            name=name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'MaxPool3d_3a_3x3'
        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
                                                          padding=0)
        if self._final_endpoint == end_point: return

        end_point = 'Mixed_3b'
        self.end_points[end_point] = InceptionModule(192, [64, 96, 128, 16, 32, 32], name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'Mixed_3c'
        self.end_points[end_point] = InceptionModule(256, [128, 128, 192, 32, 96, 64], name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'MaxPool3d_4a_3x3'
        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[3, 3, 3], stride=(2, 2, 2),
                                                          padding=0)
        if self._final_endpoint == end_point: return

        end_point = 'Mixed_4b'
        self.end_points[end_point] = InceptionModule(128 + 192 + 96 + 64, [192, 96, 208, 16, 48, 64], name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'Mixed_4c'
        self.end_points[end_point] = InceptionModule(192 + 208 + 48 + 64, [160, 112, 224, 24, 64, 64], name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'Mixed_4d'
        self.end_points[end_point] = InceptionModule(160 + 224 + 64 + 64, [128, 128, 256, 24, 64, 64], name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'Mixed_4e'
        self.end_points[end_point] = InceptionModule(128 + 256 + 64 + 64, [112, 144, 288, 32, 64, 64], name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'Mixed_4f'
        self.end_points[end_point] = InceptionModule(112 + 288 + 64 + 64, [256, 160, 320, 32, 128, 128], name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'MaxPool3d_5a_2x2'
        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[2, 2, 2], stride=(2, 2, 2),
                                                          padding=0)
        if self._final_endpoint == end_point: return

        end_point = 'Mixed_5b'
        self.end_points[end_point] = InceptionModule(256 + 320 + 128 + 128, [256, 160, 320, 32, 128, 128], name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'Mixed_5c'
        self.end_points[end_point] = InceptionModule(256 + 320 + 128 + 128, [384, 192, 384, 48, 128, 128], name + end_point)
        if self._final_endpoint == end_point: return

        end_point = 'Logits'
        self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7],
                                     stride=(1, 1, 1))
        self.dropout = nn.Dropout(dropout_keep_prob)
        self.logits = Unit3D(in_channels=384 + 384 + 128 + 128, output_channels=self._num_classes,
                             kernel_shape=[1, 1, 1],
                             padding=0,
                             activation_fn=None,
                             use_batch_norm=False,
                             use_bias=True,
                             name='logits')

        self.build()


    def replace_logits(self, num_classes):
        self._num_classes = num_classes
        self.logits = Unit3D(in_channels=384 + 384 + 128 + 128, output_channels=self._num_classes,
                             kernel_shape=[1, 1, 1],
                             padding=0,
                             activation_fn=None,
                             use_batch_norm=False,
                             use_bias=True,
                             name='logits')


    def build(self):
        for k in self.end_points.keys():
            self.add_module(k, self.end_points[k])

    def forward(self, x):
        for end_point in self.VALID_ENDPOINTS:
            if end_point in self.end_points:
                x = self._modules[end_point](x)  # use _modules to work with dataparallel

        logits = self.logits(self.dropout(self.avg_pool(x)))
        if self._spatial_squeeze:
            logits = logits.squeeze(3).squeeze(3)
        # logits is batch X classes X time, which is what we want to work with
        return logits


    def extract_features(self, x):
        for end_point in self.VALID_ENDPOINTS:
            if end_point in self.end_points:
                x = self._modules[end_point](x)
        return self.avg_pool(x)
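To make the 'same' padding rule concrete: for kernel 7 and stride 2 on a length-224 axis, 224 % 2 == 0, so pad = max(7 - 2, 0) = 5, split as 2 before and 3 after, giving ceil(224 / 2) = 112 outputs, mirroring TensorFlow's padding='SAME'. Below is a minimal shape-check sketch of the full model on a random clip with untrained weights.

import torch
from pytorch_i3d import InceptionI3d

i3d = InceptionI3d(num_classes=400, in_channels=3)
i3d.train(False)

x = torch.randn(1, 3, 64, 224, 224)  # B x C x T x H x W
logits = i3d(x)
print(logits.shape)  # 1 x 400 x 7: batch x classes x time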
frozen_models/pytorch_i3d/train_i3d.py
ADDED
@@ -0,0 +1,133 @@
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
#os.environ["CUDA_VISIBLE_DEVICES"]='0,1,2,3'
import sys
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-mode', type=str, help='rgb or flow')
parser.add_argument('-save_model', type=str)
parser.add_argument('-root', type=str)

args = parser.parse_args()


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable

import torchvision
from torchvision import datasets, transforms
import videotransforms


import numpy as np

from pytorch_i3d import InceptionI3d

from charades_dataset import Charades as Dataset


def run(init_lr=0.1, max_steps=64e3, mode='rgb', root='/ssd/Charades_v1_rgb', train_split='charades/charades.json', batch_size=8*5, save_model=''):
    # setup dataset
    train_transforms = transforms.Compose([videotransforms.RandomCrop(224),
                                           videotransforms.RandomHorizontalFlip(),
    ])
    test_transforms = transforms.Compose([videotransforms.CenterCrop(224)])

    dataset = Dataset(train_split, 'training', root, mode, train_transforms)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=36, pin_memory=True)

    val_dataset = Dataset(train_split, 'testing', root, mode, test_transforms)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=36, pin_memory=True)

    dataloaders = {'train': dataloader, 'val': val_dataloader}
    datasets = {'train': dataset, 'val': val_dataset}


    # setup the model
    if mode == 'flow':
        i3d = InceptionI3d(400, in_channels=2)
        i3d.load_state_dict(torch.load('models/flow_imagenet.pt'))
    else:
        i3d = InceptionI3d(400, in_channels=3)
        i3d.load_state_dict(torch.load('models/rgb_imagenet.pt'))
    i3d.replace_logits(157)
    #i3d.load_state_dict(torch.load('/ssd/models/000920.pt'))
    i3d.cuda()
    i3d = nn.DataParallel(i3d)

    lr = init_lr
    optimizer = optim.SGD(i3d.parameters(), lr=lr, momentum=0.9, weight_decay=0.0000001)
    lr_sched = optim.lr_scheduler.MultiStepLR(optimizer, [300, 1000])


    num_steps_per_update = 4  # accum gradient
    steps = 0
    # train it
    while steps < max_steps:  # for epoch in range(num_epochs):
        print('Step {}/{}'.format(steps, max_steps))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                i3d.train(True)
            else:
                i3d.train(False)  # Set model to evaluate mode

            tot_loss = 0.0
            tot_loc_loss = 0.0
            tot_cls_loss = 0.0
            num_iter = 0
            optimizer.zero_grad()

            # Iterate over data.
            for data in dataloaders[phase]:
                num_iter += 1
                # get the inputs
                inputs, labels = data

                # wrap them in Variable
                inputs = Variable(inputs.cuda())
                t = inputs.size(2)
                labels = Variable(labels.cuda())

                per_frame_logits = i3d(inputs)
                # upsample to input size
                per_frame_logits = F.upsample(per_frame_logits, t, mode='linear')

                # compute localization loss
                loc_loss = F.binary_cross_entropy_with_logits(per_frame_logits, labels)
                tot_loc_loss += loc_loss.data[0]

                # compute classification loss (with max-pooling along time B x C x T)
                cls_loss = F.binary_cross_entropy_with_logits(torch.max(per_frame_logits, dim=2)[0], torch.max(labels, dim=2)[0])
                tot_cls_loss += cls_loss.data[0]

                loss = (0.5 * loc_loss + 0.5 * cls_loss) / num_steps_per_update
                tot_loss += loss.data[0]
                loss.backward()

                if num_iter == num_steps_per_update and phase == 'train':
                    steps += 1
                    num_iter = 0
                    optimizer.step()
                    optimizer.zero_grad()
                    lr_sched.step()
                    if steps % 10 == 0:
                        print('{} Loc Loss: {:.4f} Cls Loss: {:.4f} Tot Loss: {:.4f}'.format(phase, tot_loc_loss / (10 * num_steps_per_update), tot_cls_loss / (10 * num_steps_per_update), tot_loss / 10))
                        # save model
                        torch.save(i3d.module.state_dict(), save_model + str(steps).zfill(6) + '.pt')
                        tot_loss = tot_loc_loss = tot_cls_loss = 0.
            if phase == 'val':
                print('{} Loc Loss: {:.4f} Cls Loss: {:.4f} Tot Loss: {:.4f}'.format(phase, tot_loc_loss / num_iter, tot_cls_loss / num_iter, (tot_loss * num_steps_per_update) / num_iter))


if __name__ == '__main__':
    run(mode=args.mode, root=args.root, save_model=args.save_model)
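Note the gradient-accumulation pattern above: the loss is divided by num_steps_per_update and gradients are summed over four mini-batches before each optimizer step, so one update sees an effective batch of 4 * batch_size. A stand-alone sketch of the same pattern on a toy model:

import torch
import torch.nn.functional as F

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
num_steps_per_update = 4

optimizer.zero_grad()
for num_iter in range(1, 9):
    x, y = torch.randn(8, 10), torch.randn(8, 1)
    loss = F.mse_loss(model(x), y) / num_steps_per_update  # scale so the accumulated sum averages
    loss.backward()                                        # gradients accumulate in .grad
    if num_iter % num_steps_per_update == 0:
        optimizer.step()        # apply the accumulated gradient
        optimizer.zero_grad()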
frozen_models/pytorch_i3d/videotransforms.py
ADDED
@@ -0,0 +1,102 @@
import numpy as np
import numbers
import random

class RandomCrop(object):
    """Crop the given video sequences (t x h x w) at a random location.
    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    @staticmethod
    def get_params(img, output_size):
        """Get parameters for ``crop`` for a random crop.
        Args:
            img (numpy.ndarray): Video clip to be cropped.
            output_size (tuple): Expected output size of the crop.
        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
        """
        t, h, w, c = img.shape
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w

        i = random.randint(0, h - th) if h != th else 0
        j = random.randint(0, w - tw) if w != tw else 0
        return i, j, th, tw

    def __call__(self, imgs):

        i, j, h, w = self.get_params(imgs, self.size)

        imgs = imgs[:, i:i + h, j:j + w, :]
        return imgs

    def __repr__(self):
        return self.__class__.__name__ + '(size={0})'.format(self.size)

class CenterCrop(object):
    """Crops the given seq Images at the center.
    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, imgs):
        """
        Args:
            imgs (numpy.ndarray): Video clip to be cropped.
        Returns:
            numpy.ndarray: Cropped video clip.
        """
        t, h, w, c = imgs.shape
        th, tw = self.size
        i = int(np.round((h - th) / 2.))
        j = int(np.round((w - tw) / 2.))

        return imgs[:, i:i + th, j:j + tw, :]


    def __repr__(self):
        return self.__class__.__name__ + '(size={0})'.format(self.size)


class RandomHorizontalFlip(object):
    """Horizontally flip the given seq Images randomly with a given probability.
    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, imgs):
        """
        Args:
            imgs (seq Images): seq Images to be flipped.
        Returns:
            seq Images: Randomly flipped seq images.
        """
        if random.random() < self.p:
            # flip along the width axis of a t x h x w x c array
            return np.flip(imgs, axis=2).copy()
        return imgs

    def __repr__(self):
        return self.__class__.__name__ + '(p={})'.format(self.p)
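A quick check of these transforms on a dummy clip; every transform in this module expects a T x H x W x C numpy array, matching what the dataset loaders produce:

import numpy as np
from torchvision import transforms
import videotransforms

clip = np.random.randn(64, 256, 320, 3).astype(np.float32)  # T x H x W x C
pipeline = transforms.Compose([videotransforms.RandomCrop(224),
                               videotransforms.RandomHorizontalFlip()])
out = pipeline(clip)
print(out.shape)  # (64, 224, 224, 3)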