SakuraD commited on
Commit
bc059ff
1 Parent(s): f3a7b0a

init video

Browse files
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ import torch.nn.functional as F
7
+ import torchvision.transforms as T
8
+ from PIL import Image
9
+ from decord import VideoReader
10
+ from decord import cpu
11
+ from uniformer_light_video import uniformer_xxs_video
12
+ from kinetics_class_index import kinetics_classnames
13
+ from transforms import (
14
+ GroupNormalize, GroupScale, GroupCenterCrop,
15
+ Stack, ToTorchFormatTensor
16
+ )
17
+
18
+ import gradio as gr
19
+ from huggingface_hub import hf_hub_download
20
+
21
+
22
# Inference device; switch to "cuda" to run on a GPU.
device = "cpu"

# Fetch the pretrained UniFormer-XXS (16 frames, 160px, Kinetics-400)
# checkpoint from the Hugging Face Hub and load it into the model.
model_path = hf_hub_download(repo_id="Andy1621/uniformer_light", filename="uniformer_xxs16_160_k400.pth")
model = uniformer_xxs_video()
model.load_state_dict(torch.load(model_path, map_location='cpu'))

# Inference only: move to the chosen device and freeze in eval mode.
model = model.to(device).eval()

# Map class id (stringified int, e.g. "0") -> human-readable Kinetics label.
kinetics_id_to_classname = dict(kinetics_classnames)
39
+
40
+
41
def get_index(num_frames, num_segments=8):
    """Return `num_segments` frame indices sampled uniformly from a video.

    The [0, num_frames-1] range is split into equal segments and the
    (rounded) midpoint of each segment is taken, so sampling is evenly
    spread over the whole clip.
    """
    stride = float(num_frames - 1) / num_segments
    base = int(stride / 2)  # midpoint of the first segment
    return np.array([base + int(np.round(stride * seg)) for seg in range(num_segments)])
48
+
49
+
50
def load_video(video_path):
    """Decode a video file and return 8 uniformly sampled frames as one
    stacked, normalized tensor of shape (T*C, H, W).

    Parameters
    ----------
    video_path : str
        Path to the video file readable by decord.
    """
    reader = VideoReader(video_path, ctx=cpu(0))
    frame_indices = get_index(len(reader), 8)

    # Preprocessing: resize short side to 160, center-crop 160x160,
    # stack frames channel-wise, convert to float tensor, normalize
    # with ImageNet statistics.
    crop_size = 160
    scale_size = 160
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    transform = T.Compose([
        GroupScale(int(scale_size)),
        GroupCenterCrop(crop_size),
        Stack(),
        ToTorchFormatTensor(),
        GroupNormalize(mean, std),
    ])

    frames = [Image.fromarray(reader[idx].asnumpy()) for idx in frame_indices]
    return transform(frames)
75
+
76
+
77
def inference(video):
    """Classify a video and return {class name: probability} over all
    Kinetics-400 categories.

    Parameters
    ----------
    video : str
        Path to the input video file (as supplied by the Gradio widget).

    Returns
    -------
    dict
        Maps every Kinetics class name to its softmax probability.
    """
    vid = load_video(video)

    # load_video returns a (T*C, H, W) stack; reshape to the 5-D layout
    # B x C x T x H x W expected by the video model.
    TC, H, W = vid.shape
    inputs = vid.reshape(1, TC // 3, 3, H, W).permute(0, 2, 1, 3, 4)
    # Bug fix: move the input onto the model's device. The original code
    # only worked because `device` happened to be "cpu"; with "cuda" it
    # would crash on a device mismatch.
    inputs = inputs.to(device)

    with torch.no_grad():
        prediction = model(inputs)
        prediction = F.softmax(prediction, dim=1).flatten()

    # Use the label table's size instead of a hard-coded 400.
    num_classes = len(kinetics_id_to_classname)
    return {kinetics_id_to_classname[str(i)]: float(prediction[i])
            for i in range(num_classes)}
89
+
90
+
91
def set_example_video(example: list) -> dict:
    """Load the first sample of the clicked dataset row into the video input."""
    video_path = example[0]
    return gr.Video.update(value=video_path)
93
+
94
+
95
# ---- Gradio UI -----------------------------------------------------------
# Layout: description header, then (video input + submit) next to the
# top-5 label output, a row of example videos, and a footer with links.
demo = gr.Blocks()
with demo:
    gr.Markdown(
        """
        # UniFormer Light
        Gradio demo for <a href='https://github.com/Sense-X/UniFormer' target='_blank'>UniFormer</a>: To use it, simply upload your video, or click one of the examples to load them. Read more at the links below.
        """
    )

    with gr.Box():
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    # Video upload / recording widget.
                    input_video = gr.Video(label='Input Video')
                with gr.Row():
                    submit_button = gr.Button('Submit')
            with gr.Column():
                # Shows the 5 most probable Kinetics classes.
                label = gr.Label(num_top_classes=5)
        with gr.Row():
            # Clickable example videos (paths are relative to the app root).
            example_videos = gr.Dataset(components=[input_video], samples=[['./videos/hitting_baseball.mp4'], ['./videos/hoverboarding.mp4'], ['./videos/yoga.mp4']])

    gr.Markdown(
        """
        <p style='text-align: center'><a href='https://arxiv.org/abs/2201.09450' target='_blank'>[TPAMI] UniFormer: Unifying Convolution and Self-attention for Visual Recognition</a> | <a href='https://github.com/Sense-X/UniFormer' target='_blank'>Github Repo</a></p>
        """
    )

    # Wire events: classify on submit; clicking an example loads it into
    # the input widget.
    submit_button.click(fn=inference, inputs=input_video, outputs=label)
    example_videos.click(fn=set_example_video, inputs=example_videos, outputs=example_videos.components)

demo.launch(enable_queue=True)
images/cat.png ADDED
images/dog.png ADDED
images/panda.png ADDED
kinetics_class_index.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ kinetics_classnames = {
2
+ "0": "riding a bike",
3
+ "1": "marching",
4
+ "2": "dodgeball",
5
+ "3": "playing cymbals",
6
+ "4": "checking tires",
7
+ "5": "roller skating",
8
+ "6": "tasting beer",
9
+ "7": "clapping",
10
+ "8": "drawing",
11
+ "9": "juggling fire",
12
+ "10": "bobsledding",
13
+ "11": "petting animal (not cat)",
14
+ "12": "spray painting",
15
+ "13": "training dog",
16
+ "14": "eating watermelon",
17
+ "15": "building cabinet",
18
+ "16": "applauding",
19
+ "17": "playing harp",
20
+ "18": "balloon blowing",
21
+ "19": "sled dog racing",
22
+ "20": "wrestling",
23
+ "21": "pole vault",
24
+ "22": "hurling (sport)",
25
+ "23": "riding scooter",
26
+ "24": "shearing sheep",
27
+ "25": "sweeping floor",
28
+ "26": "eating carrots",
29
+ "27": "skateboarding",
30
+ "28": "dunking basketball",
31
+ "29": "disc golfing",
32
+ "30": "eating spaghetti",
33
+ "31": "playing flute",
34
+ "32": "riding mechanical bull",
35
+ "33": "making sushi",
36
+ "34": "trapezing",
37
+ "35": "picking fruit",
38
+ "36": "stretching leg",
39
+ "37": "playing ukulele",
40
+ "38": "tying tie",
41
+ "39": "skydiving",
42
+ "40": "playing cello",
43
+ "41": "jumping into pool",
44
+ "42": "shooting goal (soccer)",
45
+ "43": "trimming trees",
46
+ "44": "bookbinding",
47
+ "45": "ski jumping",
48
+ "46": "walking the dog",
49
+ "47": "riding unicycle",
50
+ "48": "shaving head",
51
+ "49": "hopscotch",
52
+ "50": "playing piano",
53
+ "51": "parasailing",
54
+ "52": "bartending",
55
+ "53": "kicking field goal",
56
+ "54": "finger snapping",
57
+ "55": "dining",
58
+ "56": "yawning",
59
+ "57": "peeling potatoes",
60
+ "58": "canoeing or kayaking",
61
+ "59": "front raises",
62
+ "60": "laughing",
63
+ "61": "dancing macarena",
64
+ "62": "digging",
65
+ "63": "reading newspaper",
66
+ "64": "hitting baseball",
67
+ "65": "clay pottery making",
68
+ "66": "exercising with an exercise ball",
69
+ "67": "playing saxophone",
70
+ "68": "shooting basketball",
71
+ "69": "washing hair",
72
+ "70": "lunge",
73
+ "71": "brushing hair",
74
+ "72": "curling hair",
75
+ "73": "kitesurfing",
76
+ "74": "tapping guitar",
77
+ "75": "bending back",
78
+ "76": "skipping rope",
79
+ "77": "situp",
80
+ "78": "folding paper",
81
+ "79": "cracking neck",
82
+ "80": "assembling computer",
83
+ "81": "cleaning gutters",
84
+ "82": "blowing out candles",
85
+ "83": "shaking hands",
86
+ "84": "dancing gangnam style",
87
+ "85": "windsurfing",
88
+ "86": "tap dancing",
89
+ "87": "skiing (not slalom or crosscountry)",
90
+ "88": "bandaging",
91
+ "89": "push up",
92
+ "90": "doing nails",
93
+ "91": "punching person (boxing)",
94
+ "92": "bouncing on trampoline",
95
+ "93": "scrambling eggs",
96
+ "94": "singing",
97
+ "95": "cleaning floor",
98
+ "96": "krumping",
99
+ "97": "drumming fingers",
100
+ "98": "snowmobiling",
101
+ "99": "gymnastics tumbling",
102
+ "100": "headbanging",
103
+ "101": "catching or throwing frisbee",
104
+ "102": "riding elephant",
105
+ "103": "bee keeping",
106
+ "104": "feeding birds",
107
+ "105": "snatch weight lifting",
108
+ "106": "mowing lawn",
109
+ "107": "fixing hair",
110
+ "108": "playing trumpet",
111
+ "109": "flying kite",
112
+ "110": "crossing river",
113
+ "111": "swinging legs",
114
+ "112": "sanding floor",
115
+ "113": "belly dancing",
116
+ "114": "sneezing",
117
+ "115": "clean and jerk",
118
+ "116": "side kick",
119
+ "117": "filling eyebrows",
120
+ "118": "shuffling cards",
121
+ "119": "recording music",
122
+ "120": "cartwheeling",
123
+ "121": "feeding fish",
124
+ "122": "folding clothes",
125
+ "123": "water skiing",
126
+ "124": "tobogganing",
127
+ "125": "blowing leaves",
128
+ "126": "smoking",
129
+ "127": "unboxing",
130
+ "128": "tai chi",
131
+ "129": "waxing legs",
132
+ "130": "riding camel",
133
+ "131": "slapping",
134
+ "132": "tossing salad",
135
+ "133": "capoeira",
136
+ "134": "playing cards",
137
+ "135": "playing organ",
138
+ "136": "playing violin",
139
+ "137": "playing drums",
140
+ "138": "tapping pen",
141
+ "139": "vault",
142
+ "140": "shoveling snow",
143
+ "141": "playing tennis",
144
+ "142": "getting a tattoo",
145
+ "143": "making a sandwich",
146
+ "144": "making tea",
147
+ "145": "grinding meat",
148
+ "146": "squat",
149
+ "147": "eating doughnuts",
150
+ "148": "ice fishing",
151
+ "149": "snowkiting",
152
+ "150": "kicking soccer ball",
153
+ "151": "playing controller",
154
+ "152": "giving or receiving award",
155
+ "153": "welding",
156
+ "154": "throwing discus",
157
+ "155": "throwing axe",
158
+ "156": "ripping paper",
159
+ "157": "swimming butterfly stroke",
160
+ "158": "air drumming",
161
+ "159": "blowing nose",
162
+ "160": "hockey stop",
163
+ "161": "taking a shower",
164
+ "162": "bench pressing",
165
+ "163": "planting trees",
166
+ "164": "pumping fist",
167
+ "165": "climbing tree",
168
+ "166": "tickling",
169
+ "167": "high kick",
170
+ "168": "waiting in line",
171
+ "169": "slacklining",
172
+ "170": "tango dancing",
173
+ "171": "hurdling",
174
+ "172": "carrying baby",
175
+ "173": "celebrating",
176
+ "174": "sharpening knives",
177
+ "175": "passing American football (in game)",
178
+ "176": "headbutting",
179
+ "177": "playing recorder",
180
+ "178": "brush painting",
181
+ "179": "garbage collecting",
182
+ "180": "robot dancing",
183
+ "181": "shredding paper",
184
+ "182": "pumping gas",
185
+ "183": "rock climbing",
186
+ "184": "hula hooping",
187
+ "185": "braiding hair",
188
+ "186": "opening present",
189
+ "187": "texting",
190
+ "188": "decorating the christmas tree",
191
+ "189": "answering questions",
192
+ "190": "playing keyboard",
193
+ "191": "writing",
194
+ "192": "bungee jumping",
195
+ "193": "sniffing",
196
+ "194": "eating burger",
197
+ "195": "playing accordion",
198
+ "196": "making pizza",
199
+ "197": "playing volleyball",
200
+ "198": "tasting food",
201
+ "199": "pushing cart",
202
+ "200": "spinning poi",
203
+ "201": "cleaning windows",
204
+ "202": "arm wrestling",
205
+ "203": "changing oil",
206
+ "204": "swimming breast stroke",
207
+ "205": "tossing coin",
208
+ "206": "deadlifting",
209
+ "207": "hoverboarding",
210
+ "208": "cutting watermelon",
211
+ "209": "cheerleading",
212
+ "210": "snorkeling",
213
+ "211": "washing hands",
214
+ "212": "eating cake",
215
+ "213": "pull ups",
216
+ "214": "surfing water",
217
+ "215": "eating hotdog",
218
+ "216": "holding snake",
219
+ "217": "playing harmonica",
220
+ "218": "ironing",
221
+ "219": "cutting nails",
222
+ "220": "golf chipping",
223
+ "221": "shot put",
224
+ "222": "hugging",
225
+ "223": "playing clarinet",
226
+ "224": "faceplanting",
227
+ "225": "trimming or shaving beard",
228
+ "226": "drinking shots",
229
+ "227": "riding mountain bike",
230
+ "228": "tying bow tie",
231
+ "229": "swinging on something",
232
+ "230": "skiing crosscountry",
233
+ "231": "unloading truck",
234
+ "232": "cleaning pool",
235
+ "233": "jogging",
236
+ "234": "ice climbing",
237
+ "235": "mopping floor",
238
+ "236": "making bed",
239
+ "237": "diving cliff",
240
+ "238": "washing dishes",
241
+ "239": "grooming dog",
242
+ "240": "weaving basket",
243
+ "241": "frying vegetables",
244
+ "242": "stomping grapes",
245
+ "243": "moving furniture",
246
+ "244": "cooking sausages",
247
+ "245": "doing laundry",
248
+ "246": "dying hair",
249
+ "247": "knitting",
250
+ "248": "reading book",
251
+ "249": "baby waking up",
252
+ "250": "punching bag",
253
+ "251": "surfing crowd",
254
+ "252": "cooking chicken",
255
+ "253": "pushing car",
256
+ "254": "springboard diving",
257
+ "255": "swing dancing",
258
+ "256": "massaging legs",
259
+ "257": "beatboxing",
260
+ "258": "breading or breadcrumbing",
261
+ "259": "somersaulting",
262
+ "260": "brushing teeth",
263
+ "261": "stretching arm",
264
+ "262": "juggling balls",
265
+ "263": "massaging person's head",
266
+ "264": "eating ice cream",
267
+ "265": "extinguishing fire",
268
+ "266": "hammer throw",
269
+ "267": "whistling",
270
+ "268": "crawling baby",
271
+ "269": "using remote controller (not gaming)",
272
+ "270": "playing cricket",
273
+ "271": "opening bottle",
274
+ "272": "playing xylophone",
275
+ "273": "motorcycling",
276
+ "274": "driving car",
277
+ "275": "exercising arm",
278
+ "276": "passing American football (not in game)",
279
+ "277": "playing kickball",
280
+ "278": "sticking tongue out",
281
+ "279": "flipping pancake",
282
+ "280": "catching fish",
283
+ "281": "eating chips",
284
+ "282": "shaking head",
285
+ "283": "sword fighting",
286
+ "284": "playing poker",
287
+ "285": "cooking on campfire",
288
+ "286": "doing aerobics",
289
+ "287": "paragliding",
290
+ "288": "using segway",
291
+ "289": "folding napkins",
292
+ "290": "playing bagpipes",
293
+ "291": "gargling",
294
+ "292": "skiing slalom",
295
+ "293": "strumming guitar",
296
+ "294": "javelin throw",
297
+ "295": "waxing back",
298
+ "296": "riding or walking with horse",
299
+ "297": "plastering",
300
+ "298": "long jump",
301
+ "299": "parkour",
302
+ "300": "wrapping present",
303
+ "301": "egg hunting",
304
+ "302": "archery",
305
+ "303": "cleaning toilet",
306
+ "304": "swimming backstroke",
307
+ "305": "snowboarding",
308
+ "306": "catching or throwing baseball",
309
+ "307": "massaging back",
310
+ "308": "blowing glass",
311
+ "309": "playing guitar",
312
+ "310": "playing chess",
313
+ "311": "golf driving",
314
+ "312": "presenting weather forecast",
315
+ "313": "rock scissors paper",
316
+ "314": "high jump",
317
+ "315": "baking cookies",
318
+ "316": "using computer",
319
+ "317": "washing feet",
320
+ "318": "arranging flowers",
321
+ "319": "playing bass guitar",
322
+ "320": "spraying",
323
+ "321": "cutting pineapple",
324
+ "322": "waxing chest",
325
+ "323": "auctioning",
326
+ "324": "jetskiing",
327
+ "325": "drinking",
328
+ "326": "busking",
329
+ "327": "playing monopoly",
330
+ "328": "salsa dancing",
331
+ "329": "waxing eyebrows",
332
+ "330": "watering plants",
333
+ "331": "zumba",
334
+ "332": "chopping wood",
335
+ "333": "pushing wheelchair",
336
+ "334": "carving pumpkin",
337
+ "335": "building shed",
338
+ "336": "making jewelry",
339
+ "337": "catching or throwing softball",
340
+ "338": "bending metal",
341
+ "339": "ice skating",
342
+ "340": "dancing charleston",
343
+ "341": "abseiling",
344
+ "342": "climbing a rope",
345
+ "343": "crying",
346
+ "344": "cleaning shoes",
347
+ "345": "dancing ballet",
348
+ "346": "driving tractor",
349
+ "347": "triple jump",
350
+ "348": "throwing ball",
351
+ "349": "getting a haircut",
352
+ "350": "running on treadmill",
353
+ "351": "climbing ladder",
354
+ "352": "blasting sand",
355
+ "353": "playing trombone",
356
+ "354": "drop kicking",
357
+ "355": "country line dancing",
358
+ "356": "changing wheel",
359
+ "357": "feeding goats",
360
+ "358": "tying knot (not on a tie)",
361
+ "359": "setting table",
362
+ "360": "shaving legs",
363
+ "361": "kissing",
364
+ "362": "riding mule",
365
+ "363": "counting money",
366
+ "364": "laying bricks",
367
+ "365": "barbequing",
368
+ "366": "news anchoring",
369
+ "367": "smoking hookah",
370
+ "368": "cooking egg",
371
+ "369": "peeling apples",
372
+ "370": "yoga",
373
+ "371": "sharpening pencil",
374
+ "372": "dribbling basketball",
375
+ "373": "petting cat",
376
+ "374": "playing ice hockey",
377
+ "375": "milking cow",
378
+ "376": "shining shoes",
379
+ "377": "juggling soccer ball",
380
+ "378": "scuba diving",
381
+ "379": "playing squash or racquetball",
382
+ "380": "drinking beer",
383
+ "381": "sign language interpreting",
384
+ "382": "playing basketball",
385
+ "383": "breakdancing",
386
+ "384": "testifying",
387
+ "385": "making snowman",
388
+ "386": "golf putting",
389
+ "387": "playing didgeridoo",
390
+ "388": "biking through snow",
391
+ "389": "sailing",
392
+ "390": "jumpstyle dancing",
393
+ "391": "water sliding",
394
+ "392": "grooming horse",
395
+ "393": "massaging feet",
396
+ "394": "playing paintball",
397
+ "395": "making a cake",
398
+ "396": "bowling",
399
+ "397": "contact juggling",
400
+ "398": "applying cream",
401
+ "399": "playing badminton"
402
+ }
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch==1.9.1
2
+ torchvision
3
+ einops
4
+ timm
5
+ Pillow
6
+ decord
transforms.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchvision
2
+ import random
3
+ from PIL import Image, ImageOps
4
+ import numpy as np
5
+ import numbers
6
+ import math
7
+ import torch
8
+
9
+
10
class GroupRandomCrop(object):
    """Crop every PIL image in a group at the same random location.

    All images must share one size; a single (x1, y1) offset is drawn so
    the crop is spatially consistent across the whole clip.
    """

    def __init__(self, size):
        # Accept a single int for a square crop, or an explicit (h, w) pair.
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img_group):

        w, h = img_group[0].size
        th, tw = self.size

        out_images = list()

        # One offset for the whole group keeps frames aligned.
        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)

        for img in img_group:
            assert(img.size[0] == w and img.size[1] == h)
            if w == tw and h == th:
                # Crop size equals image size: no-op, reuse the image.
                out_images.append(img)
            else:
                out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))

        return out_images
35
+
36
+
37
class MultiGroupRandomCrop(object):
    """Like GroupRandomCrop, but draws `groups` independent random offsets
    and returns groups * len(img_group) crops (all groups concatenated)."""

    def __init__(self, size, groups=1):
        # Accept a single int for a square crop, or an explicit (h, w) pair.
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.groups = groups

    def __call__(self, img_group):

        w, h = img_group[0].size
        th, tw = self.size

        out_images = list()

        for i in range(self.groups):
            # Fresh offset per crop group; shared across frames inside it.
            x1 = random.randint(0, w - tw)
            y1 = random.randint(0, h - th)

            for img in img_group:
                assert(img.size[0] == w and img.size[1] == h)
                if w == tw and h == th:
                    # Crop size equals image size: no-op, reuse the image.
                    out_images.append(img)
                else:
                    out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))

        return out_images
64
+
65
+
66
class GroupCenterCrop(object):
    """Apply the same torchvision CenterCrop to every image in a group."""

    def __init__(self, size):
        # Delegate the per-image work to torchvision.
        self.worker = torchvision.transforms.CenterCrop(size)

    def __call__(self, img_group):
        return list(map(self.worker, img_group))
72
+
73
+
74
class GroupRandomHorizontalFlip(object):
    """Randomly horizontally flips the given PIL.Image with a probability of 0.5
    """

    def __init__(self, is_flow=False):
        # is_flow: treat the group as interleaved optical-flow frames where
        # even positions hold the x-component (whose sign flips on mirror).
        self.is_flow = is_flow

    def __call__(self, img_group, is_flow=False):
        # NOTE(review): the `is_flow` call argument is unused — only
        # self.is_flow from __init__ takes effect; confirm intended.
        v = random.random()
        if v < 0.5:
            ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
            if self.is_flow:
                for i in range(0, len(ret), 2):
                    # invert flow pixel values when flipping
                    ret[i] = ImageOps.invert(ret[i])
            return ret
        else:
            # No flip: return the input group unchanged.
            return img_group
92
+
93
+
94
class GroupNormalize(object):
    """In-place per-channel normalization of a stacked clip tensor.

    The tensor's first dimension is T*C channels; the given per-frame
    mean/std lists are tiled across all frames, then each channel slice
    is normalized as (x - mean) / std. Mutates and returns `tensor`.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        # Tile the per-frame statistics to cover every stacked channel.
        repeat = tensor.size()[0] // len(self.mean)
        means = self.mean * repeat
        stds = self.std * repeat

        # TODO: make efficient (vectorize instead of per-channel loop)
        for channel, m, s in zip(tensor, means, stds):
            channel.sub_(m).div_(s)

        return tensor
108
+
109
+
110
class GroupScale(object):
    """ Rescales the input PIL.Image to the given 'size'.
    'size' will be the size of the smaller edge.
    For example, if height > width, then image will be
    rescaled to (size * height / width, size)
    size: size of the smaller edge
    interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # Delegate the per-image resize to torchvision.
        self.worker = torchvision.transforms.Resize(size, interpolation)

    def __call__(self, img_group):
        return list(map(self.worker, img_group))
124
+
125
+
126
class GroupOverSample(object):
    """10-crop style oversampling: five fixed crops (four corners plus
    center), each optionally paired with its horizontal flip.

    Returns a flat list: for each offset, all normal crops, then (if
    `flip`) all flipped crops.
    """

    def __init__(self, crop_size, scale_size=None, flip=True):
        # Accept an int for a square crop, or an explicit (w, h) pair.
        self.crop_size = crop_size if not isinstance(
            crop_size, int) else (crop_size, crop_size)

        # Optional resize (short side to scale_size) before cropping.
        if scale_size is not None:
            self.scale_worker = GroupScale(scale_size)
        else:
            self.scale_worker = None
        self.flip = flip

    def __call__(self, img_group):

        if self.scale_worker is not None:
            img_group = self.scale_worker(img_group)

        image_w, image_h = img_group[0].size
        crop_w, crop_h = self.crop_size

        # Fixed corner/center offsets only (no "more_fix_crop" extras).
        offsets = GroupMultiScaleCrop.fill_fix_offset(
            False, image_w, image_h, crop_w, crop_h)
        oversample_group = list()
        for o_w, o_h in offsets:
            normal_group = list()
            flip_group = list()
            for i, img in enumerate(img_group):
                crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
                normal_group.append(crop)
                flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)

                # Grayscale frames at even positions are assumed to be
                # x-direction optical flow, whose sign inverts on flip —
                # TODO confirm against the caller's frame layout.
                if img.mode == 'L' and i % 2 == 0:
                    flip_group.append(ImageOps.invert(flip_crop))
                else:
                    flip_group.append(flip_crop)

            oversample_group.extend(normal_group)
            if self.flip:
                oversample_group.extend(flip_group)
        return oversample_group
165
+
166
+
167
class GroupFullResSample(object):
    """Three-crop sampling: left, right and center crops (vertically
    centered), each optionally paired with its horizontal flip."""

    def __init__(self, crop_size, scale_size=None, flip=True):
        # Accept an int for a square crop, or an explicit (w, h) pair.
        self.crop_size = crop_size if not isinstance(
            crop_size, int) else (crop_size, crop_size)

        # Optional resize (short side to scale_size) before cropping.
        if scale_size is not None:
            self.scale_worker = GroupScale(scale_size)
        else:
            self.scale_worker = None
        self.flip = flip

    def __call__(self, img_group):

        if self.scale_worker is not None:
            img_group = self.scale_worker(img_group)

        image_w, image_h = img_group[0].size
        crop_w, crop_h = self.crop_size

        # Offsets live on a 4-step grid of the leftover space.
        w_step = (image_w - crop_w) // 4
        h_step = (image_h - crop_h) // 4

        offsets = list()
        offsets.append((0 * w_step, 2 * h_step))  # left
        offsets.append((4 * w_step, 2 * h_step))  # right
        offsets.append((2 * w_step, 2 * h_step))  # center

        oversample_group = list()
        for o_w, o_h in offsets:
            normal_group = list()
            flip_group = list()
            for i, img in enumerate(img_group):
                crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
                normal_group.append(crop)
                if self.flip:
                    flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)

                    # Even-index grayscale frames are treated as x-flow;
                    # invert their sign when flipped — TODO confirm.
                    if img.mode == 'L' and i % 2 == 0:
                        flip_group.append(ImageOps.invert(flip_crop))
                    else:
                        flip_group.append(flip_crop)

            oversample_group.extend(normal_group)
            # flip_group stays empty when self.flip is False.
            oversample_group.extend(flip_group)
        return oversample_group
212
+
213
+
214
class GroupMultiScaleCrop(object):
    """Scale-jittered cropping: pick a (w, h) crop from a menu of scaled
    sizes, crop every frame at one shared offset, then resize all crops
    to `input_size`."""

    def __init__(self, input_size, scales=None, max_distort=1,
                 fix_crop=True, more_fix_crop=True):
        self.scales = scales if scales is not None else [1, .875, .75, .66]
        # Max allowed index distance between the chosen w and h scales;
        # limits how much aspect-ratio distortion a crop can have.
        self.max_distort = max_distort
        # fix_crop: restrict offsets to a fixed grid instead of anywhere.
        self.fix_crop = fix_crop
        self.more_fix_crop = more_fix_crop
        self.input_size = input_size if not isinstance(input_size, int) else [
            input_size, input_size]
        self.interpolation = Image.BILINEAR

    def __call__(self, img_group):

        im_size = img_group[0].size

        # One crop geometry shared by all frames keeps the clip aligned.
        crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
        crop_img_group = [
            img.crop(
                (offset_w,
                 offset_h,
                 offset_w +
                 crop_w,
                 offset_h +
                 crop_h)) for img in img_group]
        ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation)
                         for img in crop_img_group]
        return ret_img_group

    def _sample_crop_size(self, im_size):
        # Draw (crop_w, crop_h, offset_w, offset_h) for this clip.
        image_w, image_h = im_size[0], im_size[1]

        # find a crop size
        base_size = min(image_w, image_h)
        crop_sizes = [int(base_size * x) for x in self.scales]
        # Snap any size within 3px of the target to the target exactly.
        crop_h = [
            self.input_size[1] if abs(
                x - self.input_size[1]) < 3 else x for x in crop_sizes]
        crop_w = [
            self.input_size[0] if abs(
                x - self.input_size[0]) < 3 else x for x in crop_sizes]

        # Keep only (w, h) pairs whose scale indices are close enough.
        pairs = []
        for i, h in enumerate(crop_h):
            for j, w in enumerate(crop_w):
                if abs(i - j) <= self.max_distort:
                    pairs.append((w, h))

        crop_pair = random.choice(pairs)
        if not self.fix_crop:
            # Fully random placement anywhere in the image.
            w_offset = random.randint(0, image_w - crop_pair[0])
            h_offset = random.randint(0, image_h - crop_pair[1])
        else:
            # Placement restricted to the fixed offset grid.
            w_offset, h_offset = self._sample_fix_offset(
                image_w, image_h, crop_pair[0], crop_pair[1])

        return crop_pair[0], crop_pair[1], w_offset, h_offset

    def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
        # Pick one offset uniformly from the fixed candidate grid.
        offsets = self.fill_fix_offset(
            self.more_fix_crop, image_w, image_h, crop_w, crop_h)
        return random.choice(offsets)

    @staticmethod
    def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
        # Candidate (x, y) offsets on a 4-step grid of the leftover space.
        w_step = (image_w - crop_w) // 4
        h_step = (image_h - crop_h) // 4

        ret = list()
        ret.append((0, 0))  # upper left
        ret.append((4 * w_step, 0))  # upper right
        ret.append((0, 4 * h_step))  # lower left
        ret.append((4 * w_step, 4 * h_step))  # lower right
        ret.append((2 * w_step, 2 * h_step))  # center

        if more_fix_crop:
            ret.append((0, 2 * h_step))  # center left
            ret.append((4 * w_step, 2 * h_step))  # center right
            ret.append((2 * w_step, 4 * h_step))  # lower center
            ret.append((2 * w_step, 0 * h_step))  # upper center

            ret.append((1 * w_step, 1 * h_step))  # upper left quarter
            ret.append((3 * w_step, 1 * h_step))  # upper right quarter
            ret.append((1 * w_step, 3 * h_step))  # lower left quarter
            ret.append((3 * w_step, 3 * h_step))  # lower righ quarter

        return ret
301
+
302
+
303
class GroupRandomSizedCrop(object):
    """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size
    and and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio
    This is popularly used to train the Inception networks
    size: size of the smaller edge
    interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img_group):
        # Try up to 10 random (area, aspect-ratio) proposals that fit
        # inside the first image; all frames share the winning geometry.
        for attempt in range(10):
            area = img_group[0].size[0] * img_group[0].size[1]
            target_area = random.uniform(0.08, 1.0) * area
            aspect_ratio = random.uniform(3. / 4, 4. / 3)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            # Randomly swap w/h to double the aspect-ratio coverage.
            if random.random() < 0.5:
                w, h = h, w

            if w <= img_group[0].size[0] and h <= img_group[0].size[1]:
                x1 = random.randint(0, img_group[0].size[0] - w)
                y1 = random.randint(0, img_group[0].size[1] - h)
                found = True
                break
            else:
                found = False
                x1 = 0
                y1 = 0

        if found:
            out_group = list()
            for img in img_group:
                img = img.crop((x1, y1, x1 + w, y1 + h))
                assert(img.size == (w, h))
                out_group.append(
                    img.resize(
                        (self.size, self.size), self.interpolation))
            return out_group
        else:
            # Fallback: deterministic scale followed by a random crop.
            scale = GroupScale(self.size, interpolation=self.interpolation)
            crop = GroupRandomCrop(self.size)
            return crop(scale(img_group))
351
+
352
+
353
class ConvertDataFormat(object):
    """Reshape a stacked clip tensor (T*C, H, W) into the 3-D-model
    layout (C, T, H, W); for '2D' models the input passes through
    untouched."""

    def __init__(self, model_type):
        self.model_type = model_type

    def __call__(self, images):
        if self.model_type == '2D':
            return images
        # Split the stacked channel axis back into (T, 3), then move
        # the RGB axis to the front: (T*3, H, W) -> (3, T, H, W).
        tc, h, w = images.size()
        num_frames = tc // 3
        return images.view(num_frames, 3, h, w).permute(1, 0, 2, 3)
365
+
366
+
367
class Stack(object):
    # Concatenate a group of PIL images along the channel axis into one
    # numpy array of shape H x W x (T*C).

    def __init__(self, roll=False):
        # roll=True reverses per-frame channel order (RGB -> BGR).
        self.roll = roll

    def __call__(self, img_group):
        if img_group[0].mode == 'L':
            # Grayscale: promote each frame to H x W x 1, then stack.
            return np.concatenate([np.expand_dims(x, 2)
                                   for x in img_group], axis=2)
        elif img_group[0].mode == 'RGB':
            if self.roll:
                return np.concatenate([np.array(x)[:, :, ::-1]
                                       for x in img_group], axis=2)
            else:
                #print(np.concatenate(img_group, axis=2).shape)
                # print(img_group[0].shape)
                # PIL images are converted to arrays implicitly here.
                return np.concatenate(img_group, axis=2)
384
+
385
+
386
class ToTorchFormatTensor(object):
    """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
    to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """

    def __init__(self, div=True):
        # div=False keeps raw [0, 255] float values instead of scaling.
        self.div = div

    def __call__(self, pic):
        if isinstance(pic, np.ndarray):
            # numpy array: HWC -> CHW directly.
            tensor = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
        else:
            # PIL Image: read raw bytes, reshape to HWC, then move to CHW.
            tensor = torch.ByteTensor(
                torch.ByteStorage.from_buffer(pic.tobytes()))
            tensor = tensor.view(pic.size[1], pic.size[0], len(pic.mode))
            # yikes, this transpose takes 80% of the loading time/CPU
            tensor = tensor.transpose(0, 1).transpose(0, 2).contiguous()
        if self.div:
            return tensor.float().div(255)
        return tensor.float()
407
+
408
+
409
class IdentityTransform(object):
    """No-op transform: returns its input unchanged (pipeline placeholder)."""

    def __call__(self, data):
        return data
413
+
414
+
415
if __name__ == "__main__":
    # Smoke test: run the group transform pipelines on a local image.
    trans = torchvision.transforms.Compose([
        GroupScale(256),
        GroupRandomCrop(224),
        Stack(),
        ToTorchFormatTensor(),
        GroupNormalize(
            mean=[.485, .456, .406],
            std=[.229, .224, .225]
        )]
    )

    # NOTE(review): path points into a sibling checkout — only resolves
    # on the original author's machine.
    im = Image.open('../tensorflow-model-zoo.torch/lena_299.png')

    # Fake a 3-frame RGB clip by repeating the same image.
    color_group = [im] * 3
    rst = trans(color_group)

    # Fake a 9-frame grayscale clip.
    gray_group = [im.convert('L')] * 9
    gray_rst = trans(gray_group)

    trans2 = torchvision.transforms.Compose([
        GroupRandomSizedCrop(256),
        Stack(),
        ToTorchFormatTensor(),
        GroupNormalize(
            mean=[.485, .456, .406],
            std=[.229, .224, .225])
    ])
    print(trans2(color_group))
uniformer_light_video.py ADDED
@@ -0,0 +1,595 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # All rights reserved.
2
+ from math import ceil, sqrt
3
+ from collections import OrderedDict
4
+ import torch
5
+ import torch.nn as nn
6
+ from functools import partial
7
+ from timm.models.vision_transformer import _cfg
8
+ from timm.models.layers import trunc_normal_, DropPath, to_2tuple
9
+ import os
10
+
11
+
12
# Module-level state shared between Attention and EvoSABlock during a
# forward pass: per-token importance scores and original token positions.
global_attn = None
token_indices = None

# Directory that holds the pretrained image checkpoints; adjust as needed.
# Kept as a private constant instead of reusing the name `model_path` for
# both the directory string and the mapping (the original shadowed it).
_MODEL_DIR = 'path_to_models'
# Mapping: checkpoint name -> absolute checkpoint file path.
model_path = {
    name: os.path.join(_MODEL_DIR, f'{name}.pth')
    for name in (
        'uniformer_xxs_128_in1k',
        'uniformer_xxs_160_in1k',
        'uniformer_xxs_192_in1k',
        'uniformer_xxs_224_in1k',
        'uniformer_xs_192_in1k',
        'uniformer_xs_224_in1k',
    )
}
24
+
25
+
26
def conv_3xnxn(inp, oup, kernel_size=3, stride=3, groups=1):
    """3 x k x k conv with temporal stride 2 and temporal padding 1."""
    return nn.Conv3d(inp, oup, kernel_size=(3, kernel_size, kernel_size),
                     stride=(2, stride, stride), padding=(1, 0, 0), groups=groups)


def conv_1xnxn(inp, oup, kernel_size=3, stride=3, groups=1):
    """1 x k x k conv: spatial-only downsampling, no temporal mixing."""
    return nn.Conv3d(inp, oup, kernel_size=(1, kernel_size, kernel_size),
                     stride=(1, stride, stride), padding=(0, 0, 0), groups=groups)


def conv_3xnxn_std(inp, oup, kernel_size=3, stride=3, groups=1):
    """3 x k x k conv with temporal stride 1 (keeps the frame count)."""
    return nn.Conv3d(inp, oup, kernel_size=(3, kernel_size, kernel_size),
                     stride=(1, stride, stride), padding=(1, 0, 0), groups=groups)


def conv_1x1x1(inp, oup, groups=1):
    """Pointwise 3D conv (channel mixing only)."""
    return nn.Conv3d(inp, oup, kernel_size=(1, 1, 1),
                     stride=(1, 1, 1), padding=(0, 0, 0), groups=groups)


def conv_3x3x3(inp, oup, groups=1):
    """3 x 3 x 3 conv with 'same' padding."""
    return nn.Conv3d(inp, oup, kernel_size=(3, 3, 3),
                     stride=(1, 1, 1), padding=(1, 1, 1), groups=groups)


def conv_5x5x5(inp, oup, groups=1):
    """5 x 5 x 5 conv with 'same' padding."""
    return nn.Conv3d(inp, oup, kernel_size=(5, 5, 5),
                     stride=(1, 1, 1), padding=(2, 2, 2), groups=groups)


def bn_3d(dim):
    """BatchNorm over 3D feature maps."""
    return nn.BatchNorm3d(dim)
46
+
47
+
48
+ # code is from https://github.com/YifanXu74/Evo-ViT
49
# code is from https://github.com/YifanXu74/Evo-ViT
def easy_gather(x, indices):
    """Select rows of a (B, N, C) tensor according to per-batch ``indices``.

    ``indices`` is (B, N_new); the output is (B, N_new, C) with
    ``out[b, i] == x[b, indices[b, i]]``.
    """
    batch, num_tokens, channels = x.shape
    kept = indices.shape[1]
    # Turn per-batch indices into flat indices over the (B*N) dimension.
    base = torch.arange(batch, dtype=torch.long, device=x.device).unsqueeze(1) * num_tokens
    flat = (indices + base).reshape(-1)
    # one flat fancy-index selects only the requested tokens
    return x.reshape(batch * num_tokens, channels)[flat].reshape(batch, kept, channels)
59
+
60
+
61
+ # code is from https://github.com/YifanXu74/Evo-ViT
62
# code is from https://github.com/YifanXu74/Evo-ViT
def merge_tokens(x_drop, score):
    """Fuse the pruned tokens into one representative token.

    ``x_drop`` is (B, N_drop, C) and ``score`` is (B, N_drop); the result is
    the score-normalized weighted sum, shape (B, 1, C).
    """
    weights = (score / score.sum(dim=1, keepdim=True)).unsqueeze(-1)
    return (weights * x_drop).sum(dim=1, keepdim=True)
68
+
69
+
70
class Mlp(nn.Module):
    """Token-wise feed-forward block: Linear -> act -> drop -> Linear -> drop."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        # fall back to the input width when a size is not given
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # the same dropout module is applied after each linear layer
        for stage in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = stage(x)
        return x
87
+
88
+
89
class Attention(nn.Module):
    """Multi-head self-attention that additionally maintains the module-level
    ``global_attn`` token-importance score (Evo-ViT style pruning support).

    Token 0 of the input is treated as the class token; its attention row
    (averaged over heads) is folded into ``global_attn`` with EMA weight
    ``trade_off``. EvoSABlock later ranks and prunes tokens by that score.
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., trade_off=1):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        # updating weight for global score
        self.trade_off = trade_off

    def forward(self, x):
        # x: (B, N, C) token sequence; token 0 is the class token.
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        # update global score
        global global_attn
        tradeoff = self.trade_off
        if isinstance(global_attn, int):
            # first attention layer of the stage: global_attn was reset to the
            # int 0, so initialize it from the class-token attention row
            global_attn = torch.mean(attn[:, :, 0, 1:], dim=1)
        elif global_attn.shape[1] == N - 1:
            # no additional token and no pruning, update all global scores
            cls_attn = torch.mean(attn[:, :, 0, 1:], dim=1)
            global_attn = (1 - tradeoff) * global_attn + tradeoff * cls_attn
        else:
            # only update the informative tokens
            # the first one is class token
            # the last one is representative token
            cls_attn = torch.mean(attn[:, :, 0, 1:-1], dim=1)
            if self.training:
                # out-of-place update keeps autograd history intact
                temp_attn = (1 - tradeoff) * global_attn[:, :(N - 2)] + tradeoff * cls_attn
                global_attn = torch.cat((temp_attn, global_attn[:, (N - 2):]), dim=1)
            else:
                # no use torch.cat() for fast inference
                global_attn[:, :(N - 2)] = (1 - tradeoff) * global_attn[:, :(N - 2)] + tradeoff * cls_attn

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
139
+
140
+
141
class CMlp(nn.Module):
    """Convolutional MLP: same structure as Mlp but built from 1x1x1 convs so
    it operates directly on (B, C, T, H, W) feature maps."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        # fall back to the input width when a size is not given
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.fc1 = conv_1x1x1(in_features, hidden)
        self.act = act_layer()
        self.fc2 = conv_1x1x1(hidden, out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # conv1x1 -> act -> drop -> conv1x1 -> drop
        for stage in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = stage(x)
        return x
158
+
159
+
160
class CBlock(nn.Module):
    """Convolutional (local) UniFormer block.

    A depthwise 3x3x3 conv acts as a conditional positional encoding, a
    depthwise 5x5x5 conv plays the role of local attention, and a 1x1x1-conv
    MLP follows — each with a residual connection and optional DropPath.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.pos_embed = conv_3x3x3(dim, dim, groups=dim)
        self.norm1 = bn_3d(dim)
        self.conv1 = conv_1x1x1(dim, dim, 1)
        self.conv2 = conv_1x1x1(dim, dim, 1)
        # depthwise large-kernel conv stands in for token mixing
        self.attn = conv_5x5x5(dim, dim, groups=dim)
        # stochastic depth on both residual branches
        self.drop_path = nn.Identity() if drop_path <= 0. else DropPath(drop_path)
        self.norm2 = bn_3d(dim)
        self.mlp = CMlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
                        act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.pos_embed(x)
        mixed = self.conv2(self.attn(self.conv1(self.norm1(x))))
        x = x + self.drop_path(mixed)
        return x + self.drop_path(self.mlp(self.norm2(x)))
180
+
181
+
182
class EvoSABlock(nn.Module):
    """Self-attention block with Evo-ViT style token pruning.

    With ``prune_ratio == 1`` it behaves like a plain attention block that
    carries a class token. Otherwise it keeps the top ``prune_ratio``
    fraction of tokens (ranked by the module-level ``global_attn`` score),
    merges the rest into one representative token, attends over
    [cls | informative | representative], and finally scatters every token
    back to its original position using ``token_indices``.
    """
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, prune_ratio=1,
                 trade_off=0, downsample=False):
        super().__init__()
        self.pos_embed = conv_3x3x3(dim, dim, groups=dim)
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, trade_off=trade_off)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.prune_ratio = prune_ratio
        self.downsample = downsample
        if downsample:
            # spatial 2x2 pooling used to downsample global_attn between stages
            self.avgpool = nn.AvgPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))

    def forward(self, cls_token, x):
        x = x + self.pos_embed(x)
        B, C, T, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)

        if self.prune_ratio == 1:
            # no pruning: ordinary attention over [cls | all tokens]
            x = torch.cat([cls_token, x], dim=1)
            x = x + self.drop_path(self.attn(self.norm1(x)))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
            cls_token, x = x[:, :1], x[:, 1:]
            x = x.transpose(1, 2).reshape(B, C, T, H, W)
            return cls_token, x
        else:
            global global_attn, token_indices
            # calculate the number of informative tokens
            N = x.shape[1]
            N_ = int(N * self.prune_ratio)
            # sort global attention
            indices = torch.argsort(global_attn, dim=1, descending=True)

            # concatenate x, global attention and token indices => x_ga_ti
            # rearrange the tensor according to new indices
            x_ga_ti = torch.cat((x, global_attn.unsqueeze(-1), token_indices.unsqueeze(-1)), dim=-1)
            x_ga_ti = easy_gather(x_ga_ti, indices)
            x_sorted, global_attn, token_indices = x_ga_ti[:, :, :-2], x_ga_ti[:, :, -2], x_ga_ti[:, :, -1]

            # informative tokens
            x_info = x_sorted[:, :N_]
            # merge dropped tokens
            x_drop = x_sorted[:, N_:]
            score = global_attn[:, N_:]
            # B x N_drop x C => B x 1 x C
            rep_token = merge_tokens(x_drop, score)
            # concatenate new tokens
            x = torch.cat((cls_token, x_info, rep_token), dim=1)

            # slow update: full attention + MLP on [cls | informative | rep];
            # the representative token's deltas are accumulated in fast_update
            fast_update = 0
            tmp_x = self.attn(self.norm1(x))
            fast_update = fast_update + tmp_x[:, -1:]
            x = x + self.drop_path(tmp_x)
            tmp_x = self.mlp(self.norm2(x))
            fast_update = fast_update + tmp_x[:, -1:]
            x = x + self.drop_path(tmp_x)
            # fast update: every pruned token receives the rep token's delta
            x_drop = x_drop + fast_update.expand(-1, N - N_, -1)

            cls_token, x = x[:, :1, :], x[:, 1:-1, :]
            if self.training:
                x_sorted = torch.cat((x, x_drop), dim=1)
            else:
                # in-place writes avoid torch.cat for fast inference
                x_sorted[:, N_:] = x_drop
                x_sorted[:, :N_] = x

            # recover token
            # scale for normalization
            old_global_scale = torch.sum(global_attn, dim=1, keepdim=True)
            # recover order
            indices = torch.argsort(token_indices, dim=1)
            x_ga_ti = torch.cat((x_sorted, global_attn.unsqueeze(-1), token_indices.unsqueeze(-1)), dim=-1)
            x_ga_ti = easy_gather(x_ga_ti, indices)
            x_patch, global_attn, token_indices = x_ga_ti[:, :, :-2], x_ga_ti[:, :, -2], x_ga_ti[:, :, -1]
            x_patch = x_patch.transpose(1, 2).reshape(B, C, T, H, W)

            if self.downsample:
                # downsample global attention
                global_attn = global_attn.reshape(B, 1, T, H, W)
                global_attn = self.avgpool(global_attn).view(B, -1)
                # normalize global attention so its total mass is preserved
                new_global_scale = torch.sum(global_attn, dim=1, keepdim=True)
                scale = old_global_scale / new_global_scale
                global_attn = global_attn * scale

            return cls_token, x_patch
277
+
278
+
279
class SABlock(nn.Module):
    """Global self-attention block: flattens (T, H, W) into one token axis,
    applies pre-norm attention and an MLP with residuals, and restores the
    5-D layout. A depthwise 3x3x3 conv provides positional information."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.pos_embed = conv_3x3x3(dim, dim, groups=dim)
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim,
                              num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
                              attn_drop=attn_drop, proj_drop=drop)
        # stochastic depth on both residual branches
        self.drop_path = nn.Identity() if drop_path <= 0. else DropPath(drop_path)
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.pos_embed(x)
        b, c, t, h, w = x.shape
        tokens = x.flatten(2).transpose(1, 2)   # (B, T*H*W, C)
        tokens = tokens + self.drop_path(self.attn(self.norm1(tokens)))
        tokens = tokens + self.drop_path(self.mlp(self.norm2(tokens)))
        return tokens.transpose(1, 2).reshape(b, c, t, h, w)
303
+
304
+
305
class SplitSABlock(nn.Module):
    """Divided space-time attention block: temporal attention across frames
    first (tokens grouped per spatial location), then spatial attention
    within each frame, then an MLP — each with residual and DropPath."""
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.pos_embed = conv_3x3x3(dim, dim, groups=dim)
        self.t_norm = norm_layer(dim)
        self.t_attn = Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop)
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.pos_embed(x)
        B, C, T, H, W = x.shape
        # temporal attention: regroup to (B*H*W, T, C)
        attn = x.view(B, C, T, H * W).permute(0, 3, 2, 1).contiguous()
        attn = attn.view(B * H * W, T, C)
        attn = attn + self.drop_path(self.t_attn(self.t_norm(attn)))
        # spatial attention: regroup to (B*T, H*W, C)
        attn = attn.view(B, H * W, T, C).permute(0, 2, 1, 3).contiguous()
        attn = attn.view(B * T, H * W, C)
        # residual is taken from the block input (not the temporal output)
        residual = x.view(B, C, T, H * W).permute(0, 2, 3, 1).contiguous()
        residual = residual.view(B * T, H * W, C)
        attn = residual + self.drop_path(self.attn(self.norm1(attn)))
        attn = attn.view(B, T * H * W, C)
        out = attn + self.drop_path(self.mlp(self.norm2(attn)))
        out = out.transpose(1, 2).reshape(B, C, T, H, W)
        return out
341
+
342
+
343
class SpeicalPatchEmbed(nn.Module):
    """Stem patch embedding: two stacked 3x3x3 convs with BatchNorm (and one
    GELU) that downsample space by 4x and time by 2x overall.

    The misspelled class name is kept deliberately — renaming it would break
    callers and checkpoint state-dict keys.
    """
    def __init__(self, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)

        half = embed_dim // 2
        self.proj = nn.Sequential(
            nn.Conv3d(in_chans, half, kernel_size=(3, 3, 3),
                      stride=(1, 2, 2), padding=(1, 1, 1)),
            nn.BatchNorm3d(half),
            nn.GELU(),
            nn.Conv3d(half, embed_dim, kernel_size=(3, 3, 3),
                      stride=(2, 2, 2), padding=(1, 1, 1)),
            nn.BatchNorm3d(embed_dim),
        )

    def forward(self, x):
        x = self.proj(x)
        b, c, t, h, w = x.shape
        # flatten to tokens and back; kept from upstream (round-trips the
        # layout so the result stays channel-first and contiguous)
        tokens = x.flatten(2).transpose(1, 2)
        return tokens.reshape(b, t, h, w, -1).permute(0, 4, 1, 2, 3).contiguous()
369
+
370
+
371
class PatchEmbed(nn.Module):
    """Spatial-only patch embedding between stages: a 1 x k x k strided conv
    followed by a token-wise LayerNorm, returning the 5-D channel-first map."""
    def __init__(self, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.patch_size = to_2tuple(patch_size)
        self.norm = nn.LayerNorm(embed_dim)
        self.proj = conv_1xnxn(in_chans, embed_dim,
                               kernel_size=self.patch_size[0],
                               stride=self.patch_size[0])

    def forward(self, x):
        x = self.proj(x)
        b, c, t, h, w = x.shape
        # LayerNorm expects channels-last tokens, so flatten before norm
        tokens = self.norm(x.flatten(2).transpose(1, 2))
        return tokens.reshape(b, t, h, w, -1).permute(0, 4, 1, 2, 3).contiguous()
392
+
393
+
394
class Uniformer_light(nn.Module):
    """UniFormer-Light video backbone with Evo-ViT style token pruning.

    Stages 1-2 use convolutional CBlocks; stages 3-4 use EvoSABlocks that
    carry a class token and prune uninformative tokens via the module-level
    ``global_attn`` / ``token_indices`` state. Interface follows the
    timm-style Vision Transformer
    (`An Image is Worth 16x16 Words <https://arxiv.org/abs/2010.11929>`_).

    Args:
        depth: number of blocks per stage.
        in_chans: input channels (RGB video => 3).
        num_classes: classifier output size (0 disables the heads).
        embed_dim: per-stage channel widths.
        head_dim: channels per attention head (heads = dim // head_dim).
        mlp_ratio: per-stage MLP expansion ratios.
        qkv_bias, qk_scale: attention projection options.
        representation_size: optional pre-logits bottleneck width.
        drop_rate, attn_drop_rate, drop_path_rate: regularization knobs.
        norm_layer: ignored — LayerNorm(eps=1e-6) is always used.
        prune_ratio, trade_off: per-block pruning schedules for stages 3-4.
    """
    def __init__(self, depth=[3, 4, 8, 3], in_chans=3, num_classes=400, embed_dim=[64, 128, 320, 512],
                 head_dim=64, mlp_ratio=[4., 4., 4., 4.], qkv_bias=True, qk_scale=None, representation_size=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
                 prune_ratio=[[], [], [1, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5]],
                 trade_off=[[], [], [1, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]
                 ):
        super().__init__()

        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.patch_embed1 = SpeicalPatchEmbed(
            patch_size=4, in_chans=in_chans, embed_dim=embed_dim[0])
        self.patch_embed2 = PatchEmbed(
            patch_size=2, in_chans=embed_dim[0], embed_dim=embed_dim[1])
        self.patch_embed3 = PatchEmbed(
            patch_size=2, in_chans=embed_dim[1], embed_dim=embed_dim[2])
        self.patch_embed4 = PatchEmbed(
            patch_size=2, in_chans=embed_dim[2], embed_dim=embed_dim[3])

        # class token is introduced in stage 3 and widened before stage 4
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim[2]))
        self.cls_upsample = nn.Linear(embed_dim[2], embed_dim[3])

        self.pos_drop = nn.Dropout(p=drop_rate)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depth))]  # stochastic depth decay rule
        num_heads = [dim // head_dim for dim in embed_dim]
        self.blocks1 = nn.ModuleList([
            CBlock(
                dim=embed_dim[0], num_heads=num_heads[0], mlp_ratio=mlp_ratio[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth[0])])
        self.blocks2 = nn.ModuleList([
            CBlock(
                dim=embed_dim[1], num_heads=num_heads[1], mlp_ratio=mlp_ratio[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]], norm_layer=norm_layer)
            for i in range(depth[1])])
        # NOTE(review): stage 3 uses mlp_ratio[3] rather than mlp_ratio[2] —
        # kept as-is for pretrained-checkpoint compatibility; confirm upstream.
        self.blocks3 = nn.ModuleList([
            EvoSABlock(
                dim=embed_dim[2], num_heads=num_heads[2], mlp_ratio=mlp_ratio[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]+depth[1]], norm_layer=norm_layer,
                prune_ratio=prune_ratio[2][i], trade_off=trade_off[2][i],
                downsample=True if i == depth[2] - 1 else False)
            for i in range(depth[2])])
        self.blocks4 = nn.ModuleList([
            EvoSABlock(
                dim=embed_dim[3], num_heads=num_heads[3], mlp_ratio=mlp_ratio[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]+depth[1]+depth[2]], norm_layer=norm_layer,
                prune_ratio=prune_ratio[3][i], trade_off=trade_off[3][i])
            for i in range(depth[3])])
        self.norm = bn_3d(embed_dim[-1])
        self.norm_cls = nn.LayerNorm(embed_dim[-1])

        # Representation layer
        if representation_size:
            self.num_features = representation_size
            # BUGFIX: the original passed the whole embed_dim list to
            # nn.Linear; the pre-logits layer operates on the last stage's
            # channel width.
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(embed_dim[-1], representation_size)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()

        # Classifier heads: `head` for patch features, `head_cls` for the
        # class token (the latter is only used during training).
        self.head = nn.Linear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
        self.head_cls = nn.Linear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

        for name, p in self.named_parameters():
            # fill proj weight with 1 here to improve training dynamics. Otherwise temporal attention inputs
            # are multiplied by 0*0, which is hard for the model to move out of.
            if 't_attn.qkv.weight' in name:
                nn.init.constant_(p, 0)
            if 't_attn.qkv.bias' in name:
                nn.init.constant_(p, 0)
            if 't_attn.proj.weight' in name:
                nn.init.constant_(p, 1)
            if 't_attn.proj.bias' in name:
                nn.init.constant_(p, 0)

    def _init_weights(self, m):
        """Truncated-normal init for Linear layers, standard LayerNorm init."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names excluded from weight decay."""
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        """Return the main classification head."""
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        """Replace the classification head for a new number of classes."""
        self.num_classes = num_classes
        # BUGFIX: self.embed_dim is a per-stage list; the head consumes the
        # last stage's feature width.
        self.head = nn.Linear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()

    def inflate_weight(self, weight_2d, time_dim, center=False):
        """Inflate a 2-D conv weight to 3-D over ``time_dim`` frames.

        ``center=True`` places the 2-D kernel at the middle frame (zeros
        elsewhere); otherwise the kernel is repeated and divided by
        ``time_dim`` so the initial response matches the 2-D model.
        """
        if center:
            weight_3d = torch.zeros(*weight_2d.shape)
            weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
            middle_idx = time_dim // 2
            weight_3d[:, :, middle_idx, :, :] = weight_2d
        else:
            weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
            weight_3d = weight_3d / time_dim
        return weight_3d

    def forward_features(self, x):
        """Backbone forward: returns (cls_token, patch feature map)."""
        x = self.patch_embed1(x)
        x = self.pos_drop(x)
        for blk in self.blocks1:
            x = blk(x)
        x = self.patch_embed2(x)
        for blk in self.blocks2:
            x = blk(x)
        x = self.patch_embed3(x)
        # add cls_token in stage3
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        global global_attn, token_indices
        # reset pruning state; Attention re-initializes global_attn from the
        # int sentinel 0 in the first stage-3 block
        global_attn = 0
        token_indices = torch.arange(x.shape[2] * x.shape[3] * x.shape[4], dtype=torch.long, device=x.device).unsqueeze(0)
        token_indices = token_indices.expand(x.shape[0], -1)
        for blk in self.blocks3:
            cls_token, x = blk(cls_token, x)
        # upsample cls_token before stage4
        cls_token = self.cls_upsample(cls_token)
        x = self.patch_embed4(x)
        # whether reset global attention? Now simple avgpool
        token_indices = torch.arange(x.shape[2] * x.shape[3] * x.shape[4], dtype=torch.long, device=x.device).unsqueeze(0)
        token_indices = token_indices.expand(x.shape[0], -1)
        for blk in self.blocks4:
            cls_token, x = blk(cls_token, x)
        if self.training:
            # layer normalization for cls_token
            cls_token = self.norm_cls(cls_token)
        x = self.norm(x)
        x = self.pre_logits(x)
        return cls_token, x

    def forward(self, x):
        """Classify a video clip (B, C, T, H, W).

        Returns logits at inference; at training time returns the tuple
        (patch-head logits, cls-token-head logits) for the dual-head loss.
        """
        cls_token, x = self.forward_features(x)
        x = x.flatten(2).mean(-1)  # global average pool over T*H*W
        if self.training:
            x = self.head(x), self.head_cls(cls_token.squeeze(1))
        else:
            x = self.head(x)
        return x
552
+
553
+
554
def uniformer_xxs_video(**kwargs):
    """Build the UniFormer-XXS video model (token pruning in stages 3-4)."""
    # pruning / EMA schedules: first stage-3 block keeps everything,
    # the rest keep half the tokens
    stage3_schedule = [1] + [0.5] * 7
    stage4_schedule = [0.5, 0.5]
    model = Uniformer_light(
        depth=[2, 5, 8, 2],
        prune_ratio=[[], [], stage3_schedule, stage4_schedule],
        trade_off=[[], [], list(stage3_schedule), list(stage4_schedule)],
        embed_dim=[56, 112, 224, 448], head_dim=28, mlp_ratio=[3, 3, 3, 3],
        qkv_bias=True, **kwargs)
    model.default_cfg = _cfg()
    return model
563
+
564
+
565
def uniformer_xs_video(**kwargs):
    """Build the UniFormer-XS video model (token pruning in stages 3-4)."""
    # pruning / EMA schedules: first stage-3 block keeps everything,
    # the rest keep half the tokens
    stage3_schedule = [1] + [0.5] * 8
    stage4_schedule = [0.5] * 3
    model = Uniformer_light(
        depth=[3, 5, 9, 3],
        prune_ratio=[[], [], stage3_schedule, stage4_schedule],
        trade_off=[[], [], list(stage3_schedule), list(stage4_schedule)],
        embed_dim=[64, 128, 256, 512], head_dim=32, mlp_ratio=[3, 3, 3, 3],
        qkv_bias=True, **kwargs)
    model.default_cfg = _cfg()
    return model
574
+
575
+
576
if __name__ == '__main__':
    # Benchmark script: count FLOPs of the XXS model on a 16-frame clip.
    import time
    import numpy as np
    from fvcore.nn import FlopCountAnalysis
    from fvcore.nn import flop_count_table

    # fix all RNG seeds so the trace is reproducible
    seed = 4217
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    num_frames = 16

    model = uniformer_xxs_video()
    # print(model)

    dummy_clip = torch.rand(1, 3, num_frames, 160, 160)
    flops = FlopCountAnalysis(model, dummy_clip)
    started = time.time()
    print(flop_count_table(flops, max_depth=1))
    print(time.time() - started)
videos/hitting_baseball.mp4 ADDED
Binary file (687 kB). View file
 
videos/hoverboarding.mp4 ADDED
Binary file (464 kB). View file
 
videos/yoga.mp4 ADDED
Binary file (776 kB). View file