SakuraD committed on
Commit
8bb5bb4
1 Parent(s): cf54281
Files changed (7)
  1. app.py +131 -0
  2. clapping.mp4 +0 -0
  3. jumping.mp4 +0 -0
  4. mitv1_class_index.py +341 -0
  5. swimming.mp4 +0 -0
  6. transforms.py +443 -0
  7. uniformerv2.py +510 -0
app.py ADDED
@@ -0,0 +1,131 @@
1
+ import os
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ import torch.nn.functional as F
7
+ import torchvision.transforms as T
8
+ from PIL import Image
9
+ from decord import VideoReader
10
+ from decord import cpu
11
+ from uniformerv2 import uniformerv2_b16
12
+ from mitv1_class_index import mitv1_classnames
13
+ from transforms import (
14
+ GroupNormalize, GroupScale, GroupCenterCrop,
15
+ Stack, ToTorchFormatTensor
16
+ )
17
+
18
+ import gradio as gr
19
+ from huggingface_hub import hf_hub_download
20
+
21
+ class Uniformerv2(nn.Module):
22
+ def __init__(self, model):
23
+ super().__init__()
24
+ self.backbone = model
25
+
26
+ def forward(self, x):
27
+ return self.backbone(x)
28
+
29
+ # Device on which to run the model
30
+ # Set to cuda to load on GPU
31
+ device = "cpu"
32
+ model_path = hf_hub_download(repo_id="Andy1621/uniformerv2", filename="mit_uniformerv2_b16_8x224.pyth")
33
+ # Build the UniFormerV2-B/16 backbone for Moments in Time (339 classes); the checkpoint downloaded above is loaded into it below
34
+ model = Uniformerv2(uniformerv2_b16(pretrained=False, t_size=8, no_lmhra=True, temporal_downsample=False, num_classes=339))
35
+ state_dict = torch.load(model_path, map_location='cpu')
36
+ model.load_state_dict(state_dict)
37
+
38
+ # Set to eval mode and move to desired device
39
+ model = model.to(device)
40
+ model = model.eval()
41
+
42
+ # Create an id to label name mapping
43
+ mitv1_id_to_classname = {}
44
+ for k, v in mitv1_classnames.items():
45
+ mitv1_id_to_classname[k] = v
46
+
47
+
48
+ def get_index(num_frames, num_segments=8):
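+ # Evenly sample `num_segments` frame indices, one from (roughly) the middle of each temporal segment,
+ # e.g. num_frames=81 -> seg_size=10.0, start=5, offsets=[5, 15, 25, 35, 45, 55, 65, 75].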
49
+ seg_size = float(num_frames - 1) / num_segments
50
+ start = int(seg_size / 2)
51
+ offsets = np.array([
52
+ start + int(np.round(seg_size * idx)) for idx in range(num_segments)
53
+ ])
54
+ return offsets
55
+
56
+
57
+ def load_video(video_path):
58
+ vr = VideoReader(video_path, ctx=cpu(0))
59
+ num_frames = len(vr)
60
+ frame_indices = get_index(num_frames, 8)
61
+
62
+ # transform
63
+ crop_size = 224
64
+ scale_size = 256
65
+ input_mean = [0.485, 0.456, 0.406]
66
+ input_std = [0.229, 0.224, 0.225]
67
+
68
+ transform = T.Compose([
69
+ GroupScale(int(scale_size)),
70
+ GroupCenterCrop(crop_size),
71
+ Stack(),
72
+ ToTorchFormatTensor(),
73
+ GroupNormalize(input_mean, input_std)
74
+ ])
75
+
76
+ images_group = list()
77
+ for frame_index in frame_indices:
78
+ img = Image.fromarray(vr[frame_index].asnumpy())
79
+ images_group.append(img)
80
+ torch_imgs = transform(images_group)
81
+ return torch_imgs
82
+
83
+
84
+ def inference(video):
85
+ vid = load_video(video)
86
+
87
+ # The model expects inputs of shape B x C x T x H x W; the transform above returns the 8 frames stacked as a (T*C) x H x W tensor
88
+ TC, H, W = vid.shape
89
+ inputs = vid.reshape(1, TC//3, 3, H, W).permute(0, 2, 1, 3, 4)
90
+
91
+ prediction = model(inputs)
92
+ prediction = F.softmax(prediction, dim=1).flatten()
93
+
94
+ return {mitv1_id_to_classname[str(i)]: float(prediction[i]) for i in range(339)}
95
+
96
+
97
+ def set_example_video(example: list) -> dict:
98
+ return gr.Video.update(value=example[0])
99
+
100
+
101
+ demo = gr.Blocks()
102
+ with demo:
103
+ gr.Markdown(
104
+ """
105
+ # UniFormerV2-B
106
+ Gradio demo for <a href='https://github.com/OpenGVLab/UniFormerV2' target='_blank'>UniFormerV2</a>. To use it, upload a video or click one of the examples to load it. Read more at the links below.
107
+ """
108
+ )
109
+
110
+ with gr.Box():
111
+ with gr.Row():
112
+ with gr.Column():
113
+ with gr.Row():
114
+ input_video = gr.Video(label='Input Video')
115
+ with gr.Row():
116
+ submit_button = gr.Button('Submit')
117
+ with gr.Column():
118
+ label = gr.Label(num_top_classes=5)
119
+ with gr.Row():
120
+ example_videos = gr.Dataset(components=[input_video], samples=[['clapping.mp4'], ['jumping.mp4'], ['swimming.mp4']])
121
+
122
+ gr.Markdown(
123
+ """
124
+ <p style='text-align: center'><a href='https://arxiv.org/abs/2211.09552' target='_blank'>[Arxiv] UniFormerV2: Spatiotemporal Learning by Arming Image ViTs with Video UniFormer</a> | <a href='https://github.com/OpenGVLab/UniFormerV2' target='_blank'>Github Repo</a></p>
125
+ """
126
+ )
127
+
128
+ submit_button.click(fn=inference, inputs=input_video, outputs=label)
129
+ example_videos.click(fn=set_example_video, inputs=example_videos, outputs=example_videos.components)
130
+
131
+ demo.launch(enable_queue=True)
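
For reference, a minimal sketch (not part of this commit) of running a prediction without the Gradio UI; it assumes the definitions in app.py above (model, load_video, torch, F, mitv1_id_to_classname) are already in scope:

    vid = load_video('clapping.mp4')        # 8 sampled frames stacked into a (8*3) x H x W tensor
    TC, H, W = vid.shape
    inputs = vid.reshape(1, TC // 3, 3, H, W).permute(0, 2, 1, 3, 4)  # -> B x C x T x H x W
    with torch.no_grad():                   # inference only, no gradients needed
        probs = F.softmax(model(inputs), dim=1).flatten()
    values, indices = probs.topk(5)         # top-5 Moments in Time classes
    print([(mitv1_id_to_classname[str(i.item())], round(v.item(), 4)) for i, v in zip(indices, values)])
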
clapping.mp4 ADDED
Binary file (362 kB)
 
jumping.mp4 ADDED
Binary file (202 kB)
 
mitv1_class_index.py ADDED
@@ -0,0 +1,341 @@
1
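+ # Moments in Time v1 label map: 339 classes, keyed by the model's output index as a string.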
+ mitv1_classnames = {
2
+ "0": "adult+female+singing",
3
+ "1": "adult+female+speaking",
4
+ "2": "adult+male+singing",
5
+ "3": "adult+male+speaking",
6
+ "4": "aiming",
7
+ "5": "applauding",
8
+ "6": "arresting",
9
+ "7": "ascending",
10
+ "8": "asking",
11
+ "9": "assembling",
12
+ "10": "attacking",
13
+ "11": "autographing",
14
+ "12": "baking",
15
+ "13": "balancing",
16
+ "14": "baptizing",
17
+ "15": "barbecuing",
18
+ "16": "barking",
19
+ "17": "bathing",
20
+ "18": "bending",
21
+ "19": "bicycling",
22
+ "20": "biting",
23
+ "21": "blocking",
24
+ "22": "blowing",
25
+ "23": "boarding",
26
+ "24": "boating",
27
+ "25": "boiling",
28
+ "26": "bouncing",
29
+ "27": "bowing",
30
+ "28": "bowling",
31
+ "29": "boxing",
32
+ "30": "breaking",
33
+ "31": "brushing",
34
+ "32": "bubbling",
35
+ "33": "building",
36
+ "34": "bulldozing",
37
+ "35": "burning",
38
+ "36": "burying",
39
+ "37": "buttoning",
40
+ "38": "buying",
41
+ "39": "calling",
42
+ "40": "camping",
43
+ "41": "carrying",
44
+ "42": "carving",
45
+ "43": "catching",
46
+ "44": "celebrating",
47
+ "45": "chasing",
48
+ "46": "cheering",
49
+ "47": "cheerleading",
50
+ "48": "chewing",
51
+ "49": "child+singing",
52
+ "50": "child+speaking",
53
+ "51": "chopping",
54
+ "52": "clapping",
55
+ "53": "clawing",
56
+ "54": "cleaning",
57
+ "55": "clearing",
58
+ "56": "climbing",
59
+ "57": "clinging",
60
+ "58": "clipping",
61
+ "59": "closing",
62
+ "60": "coaching",
63
+ "61": "colliding",
64
+ "62": "combing",
65
+ "63": "combusting",
66
+ "64": "competing",
67
+ "65": "constructing",
68
+ "66": "cooking",
69
+ "67": "coughing",
70
+ "68": "covering",
71
+ "69": "cracking",
72
+ "70": "crafting",
73
+ "71": "cramming",
74
+ "72": "crashing",
75
+ "73": "crawling",
76
+ "74": "crouching",
77
+ "75": "crushing",
78
+ "76": "crying",
79
+ "77": "cuddling",
80
+ "78": "cutting",
81
+ "79": "dancing",
82
+ "80": "descending",
83
+ "81": "destroying",
84
+ "82": "digging",
85
+ "83": "dining",
86
+ "84": "dipping",
87
+ "85": "discussing",
88
+ "86": "diving",
89
+ "87": "dragging",
90
+ "88": "draining",
91
+ "89": "drawing",
92
+ "90": "drenching",
93
+ "91": "dressing",
94
+ "92": "drilling",
95
+ "93": "drinking",
96
+ "94": "dripping",
97
+ "95": "driving",
98
+ "96": "dropping",
99
+ "97": "drumming",
100
+ "98": "drying",
101
+ "99": "dunking",
102
+ "100": "dusting",
103
+ "101": "eating",
104
+ "102": "emptying",
105
+ "103": "entering",
106
+ "104": "erupting",
107
+ "105": "exercising",
108
+ "106": "exiting",
109
+ "107": "extinguishing",
110
+ "108": "falling",
111
+ "109": "feeding",
112
+ "110": "fencing",
113
+ "111": "fighting",
114
+ "112": "filling",
115
+ "113": "filming",
116
+ "114": "fishing",
117
+ "115": "flicking",
118
+ "116": "flipping",
119
+ "117": "floating",
120
+ "118": "flooding",
121
+ "119": "flowing",
122
+ "120": "flying",
123
+ "121": "folding",
124
+ "122": "frowning",
125
+ "123": "frying",
126
+ "124": "fueling",
127
+ "125": "gambling",
128
+ "126": "gardening",
129
+ "127": "giggling",
130
+ "128": "giving",
131
+ "129": "grilling",
132
+ "130": "grinning",
133
+ "131": "gripping",
134
+ "132": "grooming",
135
+ "133": "guarding",
136
+ "134": "hammering",
137
+ "135": "handcuffing",
138
+ "136": "handwriting",
139
+ "137": "hanging",
140
+ "138": "hiking",
141
+ "139": "hitchhiking",
142
+ "140": "hitting",
143
+ "141": "howling",
144
+ "142": "hugging",
145
+ "143": "hunting",
146
+ "144": "imitating",
147
+ "145": "inflating",
148
+ "146": "injecting",
149
+ "147": "instructing",
150
+ "148": "interviewing",
151
+ "149": "jogging",
152
+ "150": "joining",
153
+ "151": "juggling",
154
+ "152": "jumping",
155
+ "153": "kicking",
156
+ "154": "kissing",
157
+ "155": "kneeling",
158
+ "156": "knitting",
159
+ "157": "knocking",
160
+ "158": "landing",
161
+ "159": "laughing",
162
+ "160": "launching",
163
+ "161": "leaking",
164
+ "162": "leaning",
165
+ "163": "leaping",
166
+ "164": "lecturing",
167
+ "165": "licking",
168
+ "166": "lifting",
169
+ "167": "loading",
170
+ "168": "locking",
171
+ "169": "manicuring",
172
+ "170": "marching",
173
+ "171": "marrying",
174
+ "172": "massaging",
175
+ "173": "measuring",
176
+ "174": "mopping",
177
+ "175": "mowing",
178
+ "176": "officiating",
179
+ "177": "opening",
180
+ "178": "operating",
181
+ "179": "overflowing",
182
+ "180": "packaging",
183
+ "181": "packing",
184
+ "182": "painting",
185
+ "183": "parading",
186
+ "184": "paying",
187
+ "185": "pedaling",
188
+ "186": "peeling",
189
+ "187": "performing",
190
+ "188": "photographing",
191
+ "189": "picking",
192
+ "190": "piloting",
193
+ "191": "pitching",
194
+ "192": "placing",
195
+ "193": "planting",
196
+ "194": "playing",
197
+ "195": "playing+fun",
198
+ "196": "playing+music",
199
+ "197": "playing+sports",
200
+ "198": "playing+videogames",
201
+ "199": "plugging",
202
+ "200": "plunging",
203
+ "201": "pointing",
204
+ "202": "poking",
205
+ "203": "pouring",
206
+ "204": "praying",
207
+ "205": "preaching",
208
+ "206": "pressing",
209
+ "207": "protesting",
210
+ "208": "pulling",
211
+ "209": "punching",
212
+ "210": "punting",
213
+ "211": "pushing",
214
+ "212": "putting",
215
+ "213": "queuing",
216
+ "214": "racing",
217
+ "215": "rafting",
218
+ "216": "raining",
219
+ "217": "raising",
220
+ "218": "reaching",
221
+ "219": "reading",
222
+ "220": "removing",
223
+ "221": "repairing",
224
+ "222": "resting",
225
+ "223": "riding",
226
+ "224": "rinsing",
227
+ "225": "rising",
228
+ "226": "roaring",
229
+ "227": "rocking",
230
+ "228": "rolling",
231
+ "229": "rowing",
232
+ "230": "rubbing",
233
+ "231": "running",
234
+ "232": "sailing",
235
+ "233": "saluting",
236
+ "234": "sanding",
237
+ "235": "sawing",
238
+ "236": "scratching",
239
+ "237": "screwing",
240
+ "238": "scrubbing",
241
+ "239": "selling",
242
+ "240": "serving",
243
+ "241": "sewing",
244
+ "242": "shaking",
245
+ "243": "shaving",
246
+ "244": "shooting",
247
+ "245": "shopping",
248
+ "246": "shouting",
249
+ "247": "shoveling",
250
+ "248": "shredding",
251
+ "249": "shrugging",
252
+ "250": "signing",
253
+ "251": "singing",
254
+ "252": "sitting",
255
+ "253": "skating",
256
+ "254": "sketching",
257
+ "255": "skiing",
258
+ "256": "skipping",
259
+ "257": "slapping",
260
+ "258": "sleeping",
261
+ "259": "slicing",
262
+ "260": "sliding",
263
+ "261": "slipping",
264
+ "262": "smashing",
265
+ "263": "smelling",
266
+ "264": "smiling",
267
+ "265": "smoking",
268
+ "266": "snapping",
269
+ "267": "sneezing",
270
+ "268": "sniffing",
271
+ "269": "snowing",
272
+ "270": "snuggling",
273
+ "271": "socializing",
274
+ "272": "sowing",
275
+ "273": "speaking",
276
+ "274": "spilling",
277
+ "275": "spinning",
278
+ "276": "spitting",
279
+ "277": "splashing",
280
+ "278": "spraying",
281
+ "279": "spreading",
282
+ "280": "sprinkling",
283
+ "281": "sprinting",
284
+ "282": "squatting",
285
+ "283": "squinting",
286
+ "284": "stacking",
287
+ "285": "standing",
288
+ "286": "starting",
289
+ "287": "stealing",
290
+ "288": "steering",
291
+ "289": "stirring",
292
+ "290": "stitching",
293
+ "291": "stomping",
294
+ "292": "stopping",
295
+ "293": "storming",
296
+ "294": "stretching",
297
+ "295": "stroking",
298
+ "296": "studying",
299
+ "297": "submerging",
300
+ "298": "surfing",
301
+ "299": "sweeping",
302
+ "300": "swerving",
303
+ "301": "swimming",
304
+ "302": "swinging",
305
+ "303": "talking",
306
+ "304": "taping",
307
+ "305": "tapping",
308
+ "306": "tattooing",
309
+ "307": "teaching",
310
+ "308": "tearing",
311
+ "309": "telephoning",
312
+ "310": "throwing",
313
+ "311": "tickling",
314
+ "312": "towing",
315
+ "313": "trimming",
316
+ "314": "tripping",
317
+ "315": "tuning",
318
+ "316": "turning",
319
+ "317": "twisting",
320
+ "318": "tying",
321
+ "319": "typing",
322
+ "320": "unloading",
323
+ "321": "unpacking",
324
+ "322": "vacuuming",
325
+ "323": "waking",
326
+ "324": "walking",
327
+ "325": "washing",
328
+ "326": "watering",
329
+ "327": "waving",
330
+ "328": "waxing",
331
+ "329": "weeding",
332
+ "330": "welding",
333
+ "331": "wetting",
334
+ "332": "whistling",
335
+ "333": "winking",
336
+ "334": "working",
337
+ "335": "wrapping",
338
+ "336": "wrestling",
339
+ "337": "writing",
340
+ "338": "yawning"
341
+ }
swimming.mp4 ADDED
Binary file (238 kB)
 
transforms.py ADDED
@@ -0,0 +1,443 @@
1
+ import torchvision
2
+ import random
3
+ from PIL import Image, ImageOps
4
+ import numpy as np
5
+ import numbers
6
+ import math
7
+ import torch
8
+
9
+
10
+ class GroupRandomCrop(object):
11
+ def __init__(self, size):
12
+ if isinstance(size, numbers.Number):
13
+ self.size = (int(size), int(size))
14
+ else:
15
+ self.size = size
16
+
17
+ def __call__(self, img_group):
18
+
19
+ w, h = img_group[0].size
20
+ th, tw = self.size
21
+
22
+ out_images = list()
23
+
24
+ x1 = random.randint(0, w - tw)
25
+ y1 = random.randint(0, h - th)
26
+
27
+ for img in img_group:
28
+ assert(img.size[0] == w and img.size[1] == h)
29
+ if w == tw and h == th:
30
+ out_images.append(img)
31
+ else:
32
+ out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
33
+
34
+ return out_images
35
+
36
+
37
+ class MultiGroupRandomCrop(object):
38
+ def __init__(self, size, groups=1):
39
+ if isinstance(size, numbers.Number):
40
+ self.size = (int(size), int(size))
41
+ else:
42
+ self.size = size
43
+ self.groups = groups
44
+
45
+ def __call__(self, img_group):
46
+
47
+ w, h = img_group[0].size
48
+ th, tw = self.size
49
+
50
+ out_images = list()
51
+
52
+ for i in range(self.groups):
53
+ x1 = random.randint(0, w - tw)
54
+ y1 = random.randint(0, h - th)
55
+
56
+ for img in img_group:
57
+ assert(img.size[0] == w and img.size[1] == h)
58
+ if w == tw and h == th:
59
+ out_images.append(img)
60
+ else:
61
+ out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
62
+
63
+ return out_images
64
+
65
+
66
+ class GroupCenterCrop(object):
67
+ def __init__(self, size):
68
+ self.worker = torchvision.transforms.CenterCrop(size)
69
+
70
+ def __call__(self, img_group):
71
+ return [self.worker(img) for img in img_group]
72
+
73
+
74
+ class GroupRandomHorizontalFlip(object):
75
+ """Randomly horizontally flips the given PIL.Image with a probability of 0.5
76
+ """
77
+
78
+ def __init__(self, is_flow=False):
79
+ self.is_flow = is_flow
80
+
81
+ def __call__(self, img_group, is_flow=False):
82
+ v = random.random()
83
+ if v < 0.5:
84
+ ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
85
+ if self.is_flow:
86
+ for i in range(0, len(ret), 2):
87
+ # invert flow pixel values when flipping
88
+ ret[i] = ImageOps.invert(ret[i])
89
+ return ret
90
+ else:
91
+ return img_group
92
+
93
+
94
+ class GroupNormalize(object):
95
+ def __init__(self, mean, std):
96
+ self.mean = mean
97
+ self.std = std
98
+
99
+ def __call__(self, tensor):
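+ # tensor is (T*C) x H x W: repeat mean/std across all stacked frames, then normalise each channel plane in place.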
100
+ rep_mean = self.mean * (tensor.size()[0] // len(self.mean))
101
+ rep_std = self.std * (tensor.size()[0] // len(self.std))
102
+
103
+ # TODO: make efficient
104
+ for t, m, s in zip(tensor, rep_mean, rep_std):
105
+ t.sub_(m).div_(s)
106
+
107
+ return tensor
108
+
109
+
110
+ class GroupScale(object):
111
+ """ Rescales the input PIL.Image to the given 'size'.
112
+ 'size' will be the size of the smaller edge.
113
+ For example, if height > width, then image will be
114
+ rescaled to (size * height / width, size)
115
+ size: size of the smaller edge
116
+ interpolation: Default: PIL.Image.BILINEAR
117
+ """
118
+
119
+ def __init__(self, size, interpolation=Image.BILINEAR):
120
+ self.worker = torchvision.transforms.Resize(size, interpolation)
121
+
122
+ def __call__(self, img_group):
123
+ return [self.worker(img) for img in img_group]
124
+
125
+
126
+ class GroupOverSample(object):
127
+ def __init__(self, crop_size, scale_size=None, flip=True):
128
+ self.crop_size = crop_size if not isinstance(
129
+ crop_size, int) else (crop_size, crop_size)
130
+
131
+ if scale_size is not None:
132
+ self.scale_worker = GroupScale(scale_size)
133
+ else:
134
+ self.scale_worker = None
135
+ self.flip = flip
136
+
137
+ def __call__(self, img_group):
138
+
139
+ if self.scale_worker is not None:
140
+ img_group = self.scale_worker(img_group)
141
+
142
+ image_w, image_h = img_group[0].size
143
+ crop_w, crop_h = self.crop_size
144
+
145
+ offsets = GroupMultiScaleCrop.fill_fix_offset(
146
+ False, image_w, image_h, crop_w, crop_h)
147
+ oversample_group = list()
148
+ for o_w, o_h in offsets:
149
+ normal_group = list()
150
+ flip_group = list()
151
+ for i, img in enumerate(img_group):
152
+ crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
153
+ normal_group.append(crop)
154
+ flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)
155
+
156
+ if img.mode == 'L' and i % 2 == 0:
157
+ flip_group.append(ImageOps.invert(flip_crop))
158
+ else:
159
+ flip_group.append(flip_crop)
160
+
161
+ oversample_group.extend(normal_group)
162
+ if self.flip:
163
+ oversample_group.extend(flip_group)
164
+ return oversample_group
165
+
166
+
167
+ class GroupFullResSample(object):
168
+ def __init__(self, crop_size, scale_size=None, flip=True):
169
+ self.crop_size = crop_size if not isinstance(
170
+ crop_size, int) else (crop_size, crop_size)
171
+
172
+ if scale_size is not None:
173
+ self.scale_worker = GroupScale(scale_size)
174
+ else:
175
+ self.scale_worker = None
176
+ self.flip = flip
177
+
178
+ def __call__(self, img_group):
179
+
180
+ if self.scale_worker is not None:
181
+ img_group = self.scale_worker(img_group)
182
+
183
+ image_w, image_h = img_group[0].size
184
+ crop_w, crop_h = self.crop_size
185
+
186
+ w_step = (image_w - crop_w) // 4
187
+ h_step = (image_h - crop_h) // 4
188
+
189
+ offsets = list()
190
+ offsets.append((0 * w_step, 2 * h_step)) # left
191
+ offsets.append((4 * w_step, 2 * h_step)) # right
192
+ offsets.append((2 * w_step, 2 * h_step)) # center
193
+
194
+ oversample_group = list()
195
+ for o_w, o_h in offsets:
196
+ normal_group = list()
197
+ flip_group = list()
198
+ for i, img in enumerate(img_group):
199
+ crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
200
+ normal_group.append(crop)
201
+ if self.flip:
202
+ flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)
203
+
204
+ if img.mode == 'L' and i % 2 == 0:
205
+ flip_group.append(ImageOps.invert(flip_crop))
206
+ else:
207
+ flip_group.append(flip_crop)
208
+
209
+ oversample_group.extend(normal_group)
210
+ oversample_group.extend(flip_group)
211
+ return oversample_group
212
+
213
+
214
+ class GroupMultiScaleCrop(object):
215
+
216
+ def __init__(self, input_size, scales=None, max_distort=1,
217
+ fix_crop=True, more_fix_crop=True):
218
+ self.scales = scales if scales is not None else [1, .875, .75, .66]
219
+ self.max_distort = max_distort
220
+ self.fix_crop = fix_crop
221
+ self.more_fix_crop = more_fix_crop
222
+ self.input_size = input_size if not isinstance(input_size, int) else [
223
+ input_size, input_size]
224
+ self.interpolation = Image.BILINEAR
225
+
226
+ def __call__(self, img_group):
227
+
228
+ im_size = img_group[0].size
229
+
230
+ crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
231
+ crop_img_group = [
232
+ img.crop(
233
+ (offset_w,
234
+ offset_h,
235
+ offset_w +
236
+ crop_w,
237
+ offset_h +
238
+ crop_h)) for img in img_group]
239
+ ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation)
240
+ for img in crop_img_group]
241
+ return ret_img_group
242
+
243
+ def _sample_crop_size(self, im_size):
244
+ image_w, image_h = im_size[0], im_size[1]
245
+
246
+ # find a crop size
247
+ base_size = min(image_w, image_h)
248
+ crop_sizes = [int(base_size * x) for x in self.scales]
249
+ crop_h = [
250
+ self.input_size[1] if abs(
251
+ x - self.input_size[1]) < 3 else x for x in crop_sizes]
252
+ crop_w = [
253
+ self.input_size[0] if abs(
254
+ x - self.input_size[0]) < 3 else x for x in crop_sizes]
255
+
256
+ pairs = []
257
+ for i, h in enumerate(crop_h):
258
+ for j, w in enumerate(crop_w):
259
+ if abs(i - j) <= self.max_distort:
260
+ pairs.append((w, h))
261
+
262
+ crop_pair = random.choice(pairs)
263
+ if not self.fix_crop:
264
+ w_offset = random.randint(0, image_w - crop_pair[0])
265
+ h_offset = random.randint(0, image_h - crop_pair[1])
266
+ else:
267
+ w_offset, h_offset = self._sample_fix_offset(
268
+ image_w, image_h, crop_pair[0], crop_pair[1])
269
+
270
+ return crop_pair[0], crop_pair[1], w_offset, h_offset
271
+
272
+ def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
273
+ offsets = self.fill_fix_offset(
274
+ self.more_fix_crop, image_w, image_h, crop_w, crop_h)
275
+ return random.choice(offsets)
276
+
277
+ @staticmethod
278
+ def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
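+ # Enumerate fixed crop origins on a grid of quarter-steps of the spare width/height:
+ # 4 corners + centre, plus edge midpoints and quarter positions when more_fix_crop is set.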
279
+ w_step = (image_w - crop_w) // 4
280
+ h_step = (image_h - crop_h) // 4
281
+
282
+ ret = list()
283
+ ret.append((0, 0)) # upper left
284
+ ret.append((4 * w_step, 0)) # upper right
285
+ ret.append((0, 4 * h_step)) # lower left
286
+ ret.append((4 * w_step, 4 * h_step)) # lower right
287
+ ret.append((2 * w_step, 2 * h_step)) # center
288
+
289
+ if more_fix_crop:
290
+ ret.append((0, 2 * h_step)) # center left
291
+ ret.append((4 * w_step, 2 * h_step)) # center right
292
+ ret.append((2 * w_step, 4 * h_step)) # lower center
293
+ ret.append((2 * w_step, 0 * h_step)) # upper center
294
+
295
+ ret.append((1 * w_step, 1 * h_step)) # upper left quarter
296
+ ret.append((3 * w_step, 1 * h_step)) # upper right quarter
297
+ ret.append((1 * w_step, 3 * h_step)) # lower left quarter
298
+ ret.append((3 * w_step, 3 * h_step)) # lower right quarter
299
+
300
+ return ret
301
+
302
+
303
+ class GroupRandomSizedCrop(object):
304
+ """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size
305
+ and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio
306
+ This is popularly used to train the Inception networks
307
+ size: size of the smaller edge
308
+ interpolation: Default: PIL.Image.BILINEAR
309
+ """
310
+
311
+ def __init__(self, size, interpolation=Image.BILINEAR):
312
+ self.size = size
313
+ self.interpolation = interpolation
314
+
315
+ def __call__(self, img_group):
316
+ for attempt in range(10):
317
+ area = img_group[0].size[0] * img_group[0].size[1]
318
+ target_area = random.uniform(0.08, 1.0) * area
319
+ aspect_ratio = random.uniform(3. / 4, 4. / 3)
320
+
321
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
322
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
323
+
324
+ if random.random() < 0.5:
325
+ w, h = h, w
326
+
327
+ if w <= img_group[0].size[0] and h <= img_group[0].size[1]:
328
+ x1 = random.randint(0, img_group[0].size[0] - w)
329
+ y1 = random.randint(0, img_group[0].size[1] - h)
330
+ found = True
331
+ break
332
+ else:
333
+ found = False
334
+ x1 = 0
335
+ y1 = 0
336
+
337
+ if found:
338
+ out_group = list()
339
+ for img in img_group:
340
+ img = img.crop((x1, y1, x1 + w, y1 + h))
341
+ assert(img.size == (w, h))
342
+ out_group.append(
343
+ img.resize(
344
+ (self.size, self.size), self.interpolation))
345
+ return out_group
346
+ else:
347
+ # Fallback
348
+ scale = GroupScale(self.size, interpolation=self.interpolation)
349
+ crop = GroupRandomCrop(self.size)
350
+ return crop(scale(img_group))
351
+
352
+
353
+ class ConvertDataFormat(object):
354
+ def __init__(self, model_type):
355
+ self.model_type = model_type
356
+
357
+ def __call__(self, images):
358
+ if self.model_type == '2D':
359
+ return images
360
+ tc, h, w = images.size()
361
+ t = tc // 3
362
+ images = images.view(t, 3, h, w)
363
+ images = images.permute(1, 0, 2, 3)
364
+ return images
365
+
366
+
367
+ class Stack(object):
368
+
369
+ def __init__(self, roll=False):
370
+ self.roll = roll
371
+
372
+ def __call__(self, img_group):
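+ # Concatenate the frames along the channel axis -> ndarray of shape H x W x (T*C); roll=True flips RGB to BGR.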
373
+ if img_group[0].mode == 'L':
374
+ return np.concatenate([np.expand_dims(x, 2)
375
+ for x in img_group], axis=2)
376
+ elif img_group[0].mode == 'RGB':
377
+ if self.roll:
378
+ return np.concatenate([np.array(x)[:, :, ::-1]
379
+ for x in img_group], axis=2)
380
+ else:
381
+ #print(np.concatenate(img_group, axis=2).shape)
382
+ # print(img_group[0].shape)
383
+ return np.concatenate(img_group, axis=2)
384
+
385
+
386
+ class ToTorchFormatTensor(object):
387
+ """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
388
+ to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
389
+
390
+ def __init__(self, div=True):
391
+ self.div = div
392
+
393
+ def __call__(self, pic):
394
+ if isinstance(pic, np.ndarray):
395
+ # handle numpy array
396
+ img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
397
+ else:
398
+ # handle PIL Image
399
+ img = torch.ByteTensor(
400
+ torch.ByteStorage.from_buffer(
401
+ pic.tobytes()))
402
+ img = img.view(pic.size[1], pic.size[0], len(pic.mode))
403
+ # put it from HWC to CHW format
404
+ # yikes, this transpose takes 80% of the loading time/CPU
405
+ img = img.transpose(0, 1).transpose(0, 2).contiguous()
406
+ return img.float().div(255) if self.div else img.float()
407
+
408
+
409
+ class IdentityTransform(object):
410
+
411
+ def __call__(self, data):
412
+ return data
413
+
414
+
415
+ if __name__ == "__main__":
416
+ trans = torchvision.transforms.Compose([
417
+ GroupScale(256),
418
+ GroupRandomCrop(224),
419
+ Stack(),
420
+ ToTorchFormatTensor(),
421
+ GroupNormalize(
422
+ mean=[.485, .456, .406],
423
+ std=[.229, .224, .225]
424
+ )]
425
+ )
426
+
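+ # Local test image for this self-test; not included in this commit.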
427
+ im = Image.open('../tensorflow-model-zoo.torch/lena_299.png')
428
+
429
+ color_group = [im] * 3
430
+ rst = trans(color_group)
431
+
432
+ gray_group = [im.convert('L')] * 9
433
+ gray_rst = trans(gray_group)
434
+
435
+ trans2 = torchvision.transforms.Compose([
436
+ GroupRandomSizedCrop(256),
437
+ Stack(),
438
+ ToTorchFormatTensor(),
439
+ GroupNormalize(
440
+ mean=[.485, .456, .406],
441
+ std=[.229, .224, .225])
442
+ ])
443
+ print(trans2(color_group))
uniformerv2.py ADDED
@@ -0,0 +1,510 @@
1
+ #!/usr/bin/env python
2
+ import os
3
+ from collections import OrderedDict
4
+
5
+ from timm.models.layers import DropPath
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import MultiheadAttention
9
+ import torch.nn.functional as F
10
+ import torch.utils.checkpoint as checkpoint
11
+
12
+
13
+ MODEL_PATH = './'
14
+ _MODELS = {
15
+ "ViT-B/16": os.path.join(MODEL_PATH, "vit_b16.pth"),
16
+ "ViT-L/14": os.path.join(MODEL_PATH, "vit_l14.pth"),
17
+ "ViT-L/14_336": os.path.join(MODEL_PATH, "vit_l14_336.pth"),
18
+ }
19
+
20
+
21
+ class LayerNorm(nn.LayerNorm):
22
+ """Subclass torch's LayerNorm to handle fp16."""
23
+
24
+ def forward(self, x):
25
+ orig_type = x.dtype
26
+ ret = super().forward(x.type(torch.float32))
27
+ return ret.type(orig_type)
28
+
29
+
30
+ class QuickGELU(nn.Module):
31
+ def forward(self, x):
32
+ return x * torch.sigmoid(1.702 * x)
33
+
34
+
35
+ class Local_MHRA(nn.Module):
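+ # Local MHRA (Multi-Head Relation Aggregator): a depth-wise temporal conv block (BN -> 1x1x1 -> kx1x1 depth-wise -> 1x1x1),
+ # zero-initialised so it contributes nothing at the start of training.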
36
+ def __init__(self, d_model, dw_reduction=1.5, pos_kernel_size=3):
37
+ super().__init__()
38
+
39
+ padding = pos_kernel_size // 2
40
+ re_d_model = int(d_model // dw_reduction)
41
+ self.pos_embed = nn.Sequential(
42
+ nn.BatchNorm3d(d_model),
43
+ nn.Conv3d(d_model, re_d_model, kernel_size=1, stride=1, padding=0),
44
+ nn.Conv3d(re_d_model, re_d_model, kernel_size=(pos_kernel_size, 1, 1), stride=(1, 1, 1), padding=(padding, 0, 0), groups=re_d_model),
45
+ nn.Conv3d(re_d_model, d_model, kernel_size=1, stride=1, padding=0),
46
+ )
47
+
48
+ # init zero
49
+ print('Init zero for Conv in pos_emb')
50
+ nn.init.constant_(self.pos_embed[3].weight, 0)
51
+ nn.init.constant_(self.pos_embed[3].bias, 0)
52
+
53
+ def forward(self, x):
54
+ return self.pos_embed(x)
55
+
56
+
57
+ class ResidualAttentionBlock(nn.Module):
58
+ def __init__(
59
+ self, d_model, n_head, attn_mask=None, drop_path=0.0,
60
+ dw_reduction=1.5, no_lmhra=False, double_lmhra=True
61
+ ):
62
+ super().__init__()
63
+
64
+ self.n_head = n_head
65
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
66
+ print(f'Drop path rate: {drop_path}')
67
+
68
+ self.no_lmhra = no_lmhra
69
+ self.double_lmhra = double_lmhra
70
+ print(f'No L_MHRA: {no_lmhra}')
71
+ print(f'Double L_MHRA: {double_lmhra}')
72
+ if not no_lmhra:
73
+ self.lmhra1 = Local_MHRA(d_model, dw_reduction=dw_reduction)
74
+ if double_lmhra:
75
+ self.lmhra2 = Local_MHRA(d_model, dw_reduction=dw_reduction)
76
+
77
+ # spatial
78
+ self.attn = MultiheadAttention(d_model, n_head)
79
+ self.ln_1 = LayerNorm(d_model)
80
+ self.mlp = nn.Sequential(OrderedDict([
81
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
82
+ ("gelu", QuickGELU()),
83
+ ("c_proj", nn.Linear(d_model * 4, d_model))
84
+ ]))
85
+ self.ln_2 = LayerNorm(d_model)
86
+ self.attn_mask = attn_mask
87
+
88
+ def attention(self, x):
89
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
90
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
91
+
92
+ def forward(self, x, T=8, use_checkpoint=False):
93
+ # x: 1+HW, NT, C
94
+ if not self.no_lmhra:
95
+ # Local MHRA
96
+ tmp_x = x[1:, :, :]
97
+ L, NT, C = tmp_x.shape
98
+ N = NT // T
99
+ H = W = int(L ** 0.5)
100
+ tmp_x = tmp_x.view(H, W, N, T, C).permute(2, 4, 3, 0, 1).contiguous()
101
+ tmp_x = tmp_x + self.drop_path(self.lmhra1(tmp_x))
102
+ tmp_x = tmp_x.view(N, C, T, L).permute(3, 0, 2, 1).contiguous().view(L, NT, C)
103
+ x = torch.cat([x[:1, :, :], tmp_x], dim=0)
104
+ # MHSA
105
+ if use_checkpoint:
106
+ attn_out = checkpoint.checkpoint(self.attention, self.ln_1(x))
107
+ x = x + self.drop_path(attn_out)
108
+ else:
109
+ x = x + self.drop_path(self.attention(self.ln_1(x)))
110
+ # Local MHRA
111
+ if not self.no_lmhra and self.double_lmhra:
112
+ tmp_x = x[1:, :, :]
113
+ tmp_x = tmp_x.view(H, W, N, T, C).permute(2, 4, 3, 0, 1).contiguous()
114
+ tmp_x = tmp_x + self.drop_path(self.lmhra2(tmp_x))
115
+ tmp_x = tmp_x.view(N, C, T, L).permute(3, 0, 2, 1).contiguous().view(L, NT, C)
116
+ x = torch.cat([x[:1, :, :], tmp_x], dim=0)
117
+ # FFN
118
+ if use_checkpoint:
119
+ mlp_out = checkpoint.checkpoint(self.mlp, self.ln_2(x))
120
+ x = x + self.drop_path(mlp_out)
121
+ else:
122
+ x = x + self.drop_path(self.mlp(self.ln_2(x)))
123
+ return x
124
+
125
+
126
+ class Extractor(nn.Module):
127
+ def __init__(
128
+ self, d_model, n_head, attn_mask=None,
129
+ mlp_factor=4.0, dropout=0.0, drop_path=0.0,
130
+ ):
131
+ super().__init__()
132
+
133
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
134
+ print(f'Drop path rate: {drop_path}')
135
+ self.attn = nn.MultiheadAttention(d_model, n_head)
136
+ self.ln_1 = nn.LayerNorm(d_model)
137
+ d_mlp = round(mlp_factor * d_model)
138
+ self.mlp = nn.Sequential(OrderedDict([
139
+ ("c_fc", nn.Linear(d_model, d_mlp)),
140
+ ("gelu", QuickGELU()),
141
+ ("dropout", nn.Dropout(dropout)),
142
+ ("c_proj", nn.Linear(d_mlp, d_model))
143
+ ]))
144
+ self.ln_2 = nn.LayerNorm(d_model)
145
+ self.ln_3 = nn.LayerNorm(d_model)
146
+ self.attn_mask = attn_mask
147
+
148
+ # zero init
149
+ nn.init.xavier_uniform_(self.attn.in_proj_weight)
150
+ nn.init.constant_(self.attn.out_proj.weight, 0.)
151
+ nn.init.constant_(self.attn.out_proj.bias, 0.)
152
+ nn.init.xavier_uniform_(self.mlp[0].weight)
153
+ nn.init.constant_(self.mlp[-1].weight, 0.)
154
+ nn.init.constant_(self.mlp[-1].bias, 0.)
155
+
156
+ def attention(self, x, y):
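+ # Cross-attention computed manually with nn.MultiheadAttention's projection weights:
+ # queries come from x (the video query token), keys/values from y (the patch tokens).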
157
+ d_model = self.ln_1.weight.size(0)
158
+ q = (x @ self.attn.in_proj_weight[:d_model].T) + self.attn.in_proj_bias[:d_model]
159
+
160
+ k = (y @ self.attn.in_proj_weight[d_model:-d_model].T) + self.attn.in_proj_bias[d_model:-d_model]
161
+ v = (y @ self.attn.in_proj_weight[-d_model:].T) + self.attn.in_proj_bias[-d_model:]
162
+ Tx, Ty, N = q.size(0), k.size(0), q.size(1)
163
+ q = q.view(Tx, N, self.attn.num_heads, self.attn.head_dim).permute(1, 2, 0, 3)
164
+ k = k.view(Ty, N, self.attn.num_heads, self.attn.head_dim).permute(1, 2, 0, 3)
165
+ v = v.view(Ty, N, self.attn.num_heads, self.attn.head_dim).permute(1, 2, 0, 3)
166
+ aff = (q @ k.transpose(-2, -1) / (self.attn.head_dim ** 0.5))
167
+
168
+ aff = aff.softmax(dim=-1)
169
+ out = aff @ v
170
+ out = out.permute(2, 0, 1, 3).flatten(2)
171
+ out = self.attn.out_proj(out)
172
+ return out
173
+
174
+ def forward(self, x, y):
175
+ x = x + self.drop_path(self.attention(self.ln_1(x), self.ln_3(y)))
176
+ x = x + self.drop_path(self.mlp(self.ln_2(x)))
177
+ return x
178
+
179
+
180
+ class Transformer(nn.Module):
181
+ def __init__(
182
+ self, width, layers, heads, attn_mask=None, backbone_drop_path_rate=0.,
183
+ use_checkpoint=False, checkpoint_num=[0], t_size=8, dw_reduction=2,
184
+ no_lmhra=False, double_lmhra=True,
185
+ return_list=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
186
+ n_layers=12, n_dim=768, n_head=12, mlp_factor=4.0, drop_path_rate=0.,
187
+ mlp_dropout=[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
188
+ cls_dropout=0.5, num_classes=400,
189
+ ):
190
+ super().__init__()
191
+ self.T = t_size
192
+ self.return_list = return_list
193
+ # backbone
194
+ b_dpr = [x.item() for x in torch.linspace(0, backbone_drop_path_rate, layers)]
195
+ self.resblocks = nn.ModuleList([
196
+ ResidualAttentionBlock(
197
+ width, heads, attn_mask,
198
+ drop_path=b_dpr[i],
199
+ dw_reduction=dw_reduction,
200
+ no_lmhra=no_lmhra,
201
+ double_lmhra=double_lmhra,
202
+ ) for i in range(layers)
203
+ ])
204
+ # checkpoint
205
+ self.use_checkpoint = use_checkpoint
206
+ self.checkpoint_num = checkpoint_num
207
+ self.n_layers = n_layers
208
+ print(f'Use checkpoint: {self.use_checkpoint}')
209
+ print(f'Checkpoint number: {self.checkpoint_num}')
210
+
211
+ # global block
212
+ assert n_layers == len(return_list)
213
+ if n_layers > 0:
214
+ self.temporal_cls_token = nn.Parameter(torch.zeros(1, 1, n_dim))
215
+ self.dpe = nn.ModuleList([
216
+ nn.Conv3d(n_dim, n_dim, kernel_size=3, stride=1, padding=1, bias=True, groups=n_dim)
217
+ for i in range(n_layers)
218
+ ])
219
+ for m in self.dpe:
220
+ nn.init.constant_(m.bias, 0.)
221
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)]
222
+ self.dec = nn.ModuleList([
223
+ Extractor(
224
+ n_dim, n_head, mlp_factor=mlp_factor,
225
+ dropout=mlp_dropout[i], drop_path=dpr[i],
226
+ ) for i in range(n_layers)
227
+ ])
228
+ self.balance = nn.Parameter(torch.zeros((n_dim)))
229
+ self.sigmoid = nn.Sigmoid()
230
+ # projection
231
+ self.proj = nn.Sequential(
232
+ nn.LayerNorm(n_dim),
233
+ nn.Dropout(cls_dropout),
234
+ nn.Linear(n_dim, num_classes),
235
+ )
236
+
237
+ def forward(self, x):
238
+ T_down = self.T
239
+ L, NT, C = x.shape
240
+ N = NT // T_down
241
+ H = W = int((L - 1) ** 0.5)
242
+
243
+ if self.n_layers > 0:
244
+ cls_token = self.temporal_cls_token.repeat(1, N, 1)
245
+
246
+ j = -1
247
+ for i, resblock in enumerate(self.resblocks):
248
+ if self.use_checkpoint and i < self.checkpoint_num[0]:
249
+ x = resblock(x, self.T, use_checkpoint=True)
250
+ else:
251
+ x = resblock(x, T_down)
252
+ if i in self.return_list:
253
+ j += 1
254
+ tmp_x = x.clone()
255
+ tmp_x = tmp_x.view(L, N, T_down, C)
256
+ # dpe
257
+ _, tmp_feats = tmp_x[:1], tmp_x[1:]
258
+ tmp_feats = tmp_feats.permute(1, 3, 2, 0).reshape(N, C, T_down, H, W)
259
+ tmp_feats = self.dpe[j](tmp_feats).view(N, C, T_down, L - 1).permute(3, 0, 2, 1).contiguous()
260
+ tmp_x[1:] = tmp_x[1:] + tmp_feats
261
+ # global block
262
+ tmp_x = tmp_x.permute(2, 0, 1, 3).flatten(0, 1) # T * L, N, C
263
+ cls_token = self.dec[j](cls_token, tmp_x)
264
+
265
+ if self.n_layers > 0:
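+ # Fuse the learned video query (cls_token) with the frame-averaged spatial [CLS] feature via a per-channel sigmoid gate, then classify.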
266
+ weight = self.sigmoid(self.balance)
267
+ residual = x.view(L, N, T_down, C)[0].mean(1) # L, N, T, C
268
+ return self.proj((1 - weight) * cls_token[0, :, :] + weight * residual)
269
+ else:
270
+ residual = x.view(L, N, T_down, C)[0].mean(1) # L, N, T, C
271
+ return self.proj(residual)
272
+
273
+
274
+ class VisionTransformer(nn.Module):
275
+ def __init__(
276
+ self,
277
+ # backbone
278
+ input_resolution, patch_size, width, layers, heads, output_dim, backbone_drop_path_rate=0.,
279
+ use_checkpoint=False, checkpoint_num=[0], t_size=8, kernel_size=3, dw_reduction=1.5,
280
+ temporal_downsample=True,
281
+ no_lmhra=False, double_lmhra=True,
282
+ # global block
283
+ return_list=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
284
+ n_layers=12, n_dim=768, n_head=12, mlp_factor=4.0, drop_path_rate=0.,
285
+ mlp_dropout=[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
286
+ cls_dropout=0.5, num_classes=400,
287
+ ):
288
+ super().__init__()
289
+ self.input_resolution = input_resolution
290
+ self.output_dim = output_dim
291
+ padding = (kernel_size - 1) // 2
292
+ if temporal_downsample:
293
+ self.conv1 = nn.Conv3d(3, width, (kernel_size, patch_size, patch_size), (2, patch_size, patch_size), (padding, 0, 0), bias=False)
294
+ t_size = t_size // 2
295
+ else:
296
+ self.conv1 = nn.Conv3d(3, width, (1, patch_size, patch_size), (1, patch_size, patch_size), (0, 0, 0), bias=False)
297
+
298
+ scale = width ** -0.5
299
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
300
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
301
+ self.ln_pre = LayerNorm(width)
302
+
303
+ self.transformer = Transformer(
304
+ width, layers, heads, dw_reduction=dw_reduction,
305
+ backbone_drop_path_rate=backbone_drop_path_rate,
306
+ use_checkpoint=use_checkpoint, checkpoint_num=checkpoint_num, t_size=t_size,
307
+ no_lmhra=no_lmhra, double_lmhra=double_lmhra,
308
+ return_list=return_list, n_layers=n_layers, n_dim=n_dim, n_head=n_head,
309
+ mlp_factor=mlp_factor, drop_path_rate=drop_path_rate, mlp_dropout=mlp_dropout,
310
+ cls_dropout=cls_dropout, num_classes=num_classes,
311
+ )
312
+
313
+ def forward(self, x):
314
+ x = self.conv1(x) # shape = [*, width, grid, grid]
315
+ N, C, T, H, W = x.shape
316
+ x = x.permute(0, 2, 3, 4, 1).reshape(N * T, H * W, C)
317
+
318
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
319
+ x = x + self.positional_embedding.to(x.dtype)
320
+ x = self.ln_pre(x)
321
+
322
+ x = x.permute(1, 0, 2) # NLD -> LND
323
+ out = self.transformer(x)
324
+ return out
325
+
326
+
327
+ def inflate_weight(weight_2d, time_dim, center=True):
328
+ print(f'Init center: {center}')
329
+ if center:
330
+ weight_3d = torch.zeros(*weight_2d.shape)
331
+ weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
332
+ middle_idx = time_dim // 2
333
+ weight_3d[:, :, middle_idx, :, :] = weight_2d
334
+ else:
335
+ weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
336
+ weight_3d = weight_3d / time_dim
337
+ return weight_3d
338
+
339
+
340
+ def load_state_dict(model, state_dict):
341
+ state_dict_3d = model.state_dict()
342
+ for k in state_dict.keys():
343
+ if state_dict[k].shape != state_dict_3d[k].shape:
344
+ if len(state_dict_3d[k].shape) <= 2:
345
+ print(f'Ignore: {k}')
346
+ continue
347
+ print(f'Inflate: {k}, {state_dict[k].shape} => {state_dict_3d[k].shape}')
348
+ time_dim = state_dict_3d[k].shape[2]
349
+ state_dict[k] = inflate_weight(state_dict[k], time_dim)
350
+ model.load_state_dict(state_dict, strict=False)
351
+
352
+
353
+ def uniformerv2_b16(
354
+ pretrained=True, use_checkpoint=False, checkpoint_num=[0],
355
+ t_size=16, dw_reduction=1.5, backbone_drop_path_rate=0.,
356
+ temporal_downsample=True,
357
+ no_lmhra=False, double_lmhra=True,
358
+ return_list=[8, 9, 10, 11],
359
+ n_layers=4, n_dim=768, n_head=12, mlp_factor=4.0, drop_path_rate=0.,
360
+ mlp_dropout=[0.5, 0.5, 0.5, 0.5],
361
+ cls_dropout=0.5, num_classes=400,
362
+ ):
363
+ model = VisionTransformer(
364
+ input_resolution=224,
365
+ patch_size=16,
366
+ width=768,
367
+ layers=12,
368
+ heads=12,
369
+ output_dim=512,
370
+ use_checkpoint=use_checkpoint,
371
+ checkpoint_num=checkpoint_num,
372
+ t_size=t_size,
373
+ dw_reduction=dw_reduction,
374
+ backbone_drop_path_rate=backbone_drop_path_rate,
375
+ temporal_downsample=temporal_downsample,
376
+ no_lmhra=no_lmhra,
377
+ double_lmhra=double_lmhra,
378
+ return_list=return_list,
379
+ n_layers=n_layers,
380
+ n_dim=n_dim,
381
+ n_head=n_head,
382
+ mlp_factor=mlp_factor,
383
+ drop_path_rate=drop_path_rate,
384
+ mlp_dropout=mlp_dropout,
385
+ cls_dropout=cls_dropout,
386
+ num_classes=num_classes,
387
+ )
388
+
389
+ if pretrained:
390
+ print('load pretrained weights')
391
+ state_dict = torch.load(_MODELS["ViT-B/16"], map_location='cpu')
392
+ load_state_dict(model, state_dict)
393
+ return model.eval()
394
+
395
+
396
+ def uniformerv2_l14(
397
+ pretrained=True, use_checkpoint=False, checkpoint_num=[0],
398
+ t_size=16, dw_reduction=1.5, backbone_drop_path_rate=0.,
399
+ temporal_downsample=True,
400
+ no_lmhra=False, double_lmhra=True,
401
+ return_list=[20, 21, 22, 23],
402
+ n_layers=4, n_dim=1024, n_head=16, mlp_factor=4.0, drop_path_rate=0.,
403
+ mlp_dropout=[0.5, 0.5, 0.5, 0.5],
404
+ cls_dropout=0.5, num_classes=400,
405
+ ):
406
+ model = VisionTransformer(
407
+ input_resolution=224,
408
+ patch_size=14,
409
+ width=1024,
410
+ layers=24,
411
+ heads=16,
412
+ output_dim=768,
413
+ use_checkpoint=use_checkpoint,
414
+ checkpoint_num=checkpoint_num,
415
+ t_size=t_size,
416
+ dw_reduction=dw_reduction,
417
+ backbone_drop_path_rate=backbone_drop_path_rate,
418
+ temporal_downsample=temporal_downsample,
419
+ no_lmhra=no_lmhra,
420
+ double_lmhra=double_lmhra,
421
+ return_list=return_list,
422
+ n_layers=n_layers,
423
+ n_dim=n_dim,
424
+ n_head=n_head,
425
+ mlp_factor=mlp_factor,
426
+ drop_path_rate=drop_path_rate,
427
+ mlp_dropout=mlp_dropout,
428
+ cls_dropout=cls_dropout,
429
+ num_classes=num_classes,
430
+ )
431
+
432
+ if pretrained:
433
+ print('load pretrained weights')
434
+ state_dict = torch.load(_MODELS["ViT-L/14"], map_location='cpu')
435
+ load_state_dict(model, state_dict)
436
+ return model.eval()
437
+
438
+
439
+ def uniformerv2_l14_336(
440
+ pretrained=True, use_checkpoint=False, checkpoint_num=[0],
441
+ t_size=16, dw_reduction=1.5, backbone_drop_path_rate=0.,
442
+ no_temporal_downsample=True,
443
+ no_lmhra=False, double_lmhra=True,
444
+ return_list=[20, 21, 22, 23],
445
+ n_layers=4, n_dim=1024, n_head=16, mlp_factor=4.0, drop_path_rate=0.,
446
+ mlp_dropout=[0.5, 0.5, 0.5, 0.5],
447
+ cls_dropout=0.5, num_classes=400,
448
+ ):
449
+ model = VisionTransformer(
450
+ input_resolution=336,
451
+ patch_size=14,
452
+ width=1024,
453
+ layers=24,
454
+ heads=16,
455
+ output_dim=768,
456
+ use_checkpoint=use_checkpoint,
457
+ checkpoint_num=checkpoint_num,
458
+ t_size=t_size,
459
+ dw_reduction=dw_reduction,
460
+ backbone_drop_path_rate=backbone_drop_path_rate,
461
+ temporal_downsample=not no_temporal_downsample,
462
+ no_lmhra=no_lmhra,
463
+ double_lmhra=double_lmhra,
464
+ return_list=return_list,
465
+ n_layers=n_layers,
466
+ n_dim=n_dim,
467
+ n_head=n_head,
468
+ mlp_factor=mlp_factor,
469
+ drop_path_rate=drop_path_rate,
470
+ mlp_dropout=mlp_dropout,
471
+ cls_dropout=cls_dropout,
472
+ num_classes=num_classes,
473
+ )
474
+
475
+ if pretrained:
476
+ print('load pretrained weights')
477
+ state_dict = torch.load(_MODELS["ViT-L/14_336"], map_location='cpu')
478
+ load_state_dict(model, state_dict)
479
+ return model.eval()
480
+
481
+
482
+ if __name__ == '__main__':
483
+ import time
484
+ from fvcore.nn import FlopCountAnalysis
485
+ from fvcore.nn import flop_count_table
486
+ import numpy as np
487
+
488
+ seed = 4217
489
+ np.random.seed(seed)
490
+ torch.manual_seed(seed)
491
+ torch.cuda.manual_seed(seed)
492
+ torch.cuda.manual_seed_all(seed)
493
+ num_frames = 16
494
+
495
+ model = uniformerv2_l14(
496
+ pretrained=False,
497
+ t_size=num_frames, backbone_drop_path_rate=0., drop_path_rate=0.,
498
+ dw_reduction=1.5,
499
+ no_lmhra=False,
500
+ temporal_downsample=True,
501
+ return_list=[8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23],
502
+ mlp_dropout=[0.5]*16,
503
+ n_layers=16
504
+ )
505
+ print(model)
506
+
507
+ flops = FlopCountAnalysis(model, torch.rand(1, 3, num_frames, 224, 224))
508
+ s = time.time()
509
+ print(flop_count_table(flops, max_depth=1))
510
+ print(time.time()-s)
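
For cross-reference, app.py above builds the Moments in Time variant of this backbone with the call below (a restatement of the existing call, not a new configuration):

    from uniformerv2 import uniformerv2_b16

    model = uniformerv2_b16(
        pretrained=False,           # weights come from the Hub checkpoint, not _MODELS["ViT-B/16"]
        t_size=8,                   # matches the 8 frames sampled by load_video()
        no_lmhra=True,              # skip the local MHRA convolutions in every block
        temporal_downsample=False,  # keep all 8 frames through the patch-embedding stem
        num_classes=339,            # Moments in Time v1 classes
    )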