SakuraD commited on
Commit
469404a
1 Parent(s): 9a445d0
Files changed (9) hide show
  1. README.md +6 -6
  2. app.py +130 -0
  3. hitting_baseball.mp4 +0 -0
  4. hoverboarding.mp4 +0 -0
  5. kinetics_class_index.py +402 -0
  6. requirements.txt +6 -0
  7. transforms.py +443 -0
  8. uniformerv2.py +510 -0
  9. yoga.mp4 +0 -0
README.md CHANGED
@@ -1,13 +1,13 @@
1
  ---
2
- title: Uniformerv2 Demo
3
- emoji: 💩
4
- colorFrom: green
5
- colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 3.10.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Uniformerv2_demo
3
+ emoji: 📹
4
+ colorFrom: pink
5
+ colorTo: green
6
  sdk: gradio
7
+ sdk_version: 3.0.3
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
  ---
12
 
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
app.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ import numpy as np
5
+ import torch.nn.functional as F
6
+ import torchvision.transforms as T
7
+ from PIL import Image
8
+ from decord import VideoReader
9
+ from decord import cpu
10
+ from uniformerv2 import uniformerv2_b16
11
+ from kinetics_class_index import kinetics_classnames
12
+ from transforms import (
13
+ GroupNormalize, GroupScale, GroupCenterCrop,
14
+ Stack, ToTorchFormatTensor
15
+ )
16
+
17
+ import gradio as gr
18
+ from huggingface_hub import hf_hub_download
19
+
20
+ class Uniformerv2(nn.Module):
21
+ def __init__(self, model):
22
+ super().__init__()
23
+ self.backbone = model
24
+
25
+ def forward(self, x):
26
+ return self.backbone(x)
27
+
28
+ # Device on which to run the model
29
+ # Set to cuda to load on GPU
30
+ device = "cpu"
31
+ model_path = hf_hub_download(repo_id="Andy1621/uniformerv2", filename="k400+k710_uniformerv2_b16_8x224.pyth")
32
+ # Pick a pretrained model
33
+ model = Uniformerv2(uniformerv2_b16(pretrained=False, t_size=8, no_lmhra=True, temporal_downsample=False))
34
+ state_dict = torch.load(model_path, map_location='cpu')
35
+ model.load_state_dict(state_dict)
36
+
37
+ # Set to eval mode and move to desired device
38
+ model = model.to(device)
39
+ model = model.eval()
40
+
41
+ # Create an id to label name mapping
42
+ kinetics_id_to_classname = {}
43
+ for k, v in kinetics_classnames.items():
44
+ kinetics_id_to_classname[k] = v
45
+
46
+
47
+ def get_index(num_frames, num_segments=8):
48
+ seg_size = float(num_frames - 1) / num_segments
49
+ start = int(seg_size / 2)
50
+ offsets = np.array([
51
+ start + int(np.round(seg_size * idx)) for idx in range(num_segments)
52
+ ])
53
+ return offsets
54
+
55
+
56
+ def load_video(video_path):
57
+ vr = VideoReader(video_path, ctx=cpu(0))
58
+ num_frames = len(vr)
59
+ frame_indices = get_index(num_frames, 8)
60
+
61
+ # transform
62
+ crop_size = 224
63
+ scale_size = 256
64
+ input_mean = [0.485, 0.456, 0.406]
65
+ input_std = [0.229, 0.224, 0.225]
66
+
67
+ transform = T.Compose([
68
+ GroupScale(int(scale_size)),
69
+ GroupCenterCrop(crop_size),
70
+ Stack(),
71
+ ToTorchFormatTensor(),
72
+ GroupNormalize(input_mean, input_std)
73
+ ])
74
+
75
+ images_group = list()
76
+ for frame_index in frame_indices:
77
+ img = Image.fromarray(vr[frame_index].asnumpy())
78
+ images_group.append(img)
79
+ torch_imgs = transform(images_group)
80
+ return torch_imgs
81
+
82
+
83
+ def inference(video):
84
+ vid = load_video(video)
85
+
86
+ # The model expects inputs of shape: B x C x H x W
87
+ TC, H, W = vid.shape
88
+ inputs = vid.reshape(1, TC//3, 3, H, W).permute(0, 2, 1, 3, 4)
89
+
90
+ prediction = model(inputs)
91
+ prediction = F.softmax(prediction, dim=1).flatten()
92
+
93
+ return {kinetics_id_to_classname[str(i)]: float(prediction[i]) for i in range(400)}
94
+
95
+
96
+ def set_example_video(example: list) -> dict:
97
+ return gr.Video.update(value=example[0])
98
+
99
+
100
+ demo = gr.Blocks()
101
+ with demo:
102
+ gr.Markdown(
103
+ """
104
+ # UniFormerV2-B
105
+ Gradio demo for <a href='https://github.com/OpenGVLab/UniFormerV2' target='_blank'>UniFormerV2</a>: To use it, simply upload your video, or click one of the examples to load them. Read more at the links below.
106
+ """
107
+ )
108
+
109
+ with gr.Box():
110
+ with gr.Row():
111
+ with gr.Column():
112
+ with gr.Row():
113
+ input_video = gr.Video(label='Input Video')
114
+ with gr.Row():
115
+ submit_button = gr.Button('Submit')
116
+ with gr.Column():
117
+ label = gr.Label(num_top_classes=5)
118
+ with gr.Row():
119
+ example_videos = gr.Dataset(components=[input_video], samples=[['hitting_baseball.mp4'], ['hoverboarding.mp4'], ['yoga.mp4']])
120
+
121
+ gr.Markdown(
122
+ """
123
+ <p style='text-align: center'><a href='https://arxiv.org/abs/2211.09552' target='_blank'>[Arxiv] UniFormerV2: Spatiotemporal Learning by Arming Image ViTs with Video UniFormer</a> | <a href='https://github.com/OpenGVLab/UniFormerV2' target='_blank'>Github Repo</a></p>
124
+ """
125
+ )
126
+
127
+ submit_button.click(fn=inference, inputs=input_video, outputs=label)
128
+ example_videos.click(fn=set_example_video, inputs=example_videos, outputs=example_videos.components)
129
+
130
+ demo.launch(enable_queue=True)
hitting_baseball.mp4 ADDED
Binary file (687 kB). View file
 
hoverboarding.mp4 ADDED
Binary file (464 kB). View file
 
kinetics_class_index.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ kinetics_classnames = {
2
+ "0": "riding a bike",
3
+ "1": "marching",
4
+ "2": "dodgeball",
5
+ "3": "playing cymbals",
6
+ "4": "checking tires",
7
+ "5": "roller skating",
8
+ "6": "tasting beer",
9
+ "7": "clapping",
10
+ "8": "drawing",
11
+ "9": "juggling fire",
12
+ "10": "bobsledding",
13
+ "11": "petting animal (not cat)",
14
+ "12": "spray painting",
15
+ "13": "training dog",
16
+ "14": "eating watermelon",
17
+ "15": "building cabinet",
18
+ "16": "applauding",
19
+ "17": "playing harp",
20
+ "18": "balloon blowing",
21
+ "19": "sled dog racing",
22
+ "20": "wrestling",
23
+ "21": "pole vault",
24
+ "22": "hurling (sport)",
25
+ "23": "riding scooter",
26
+ "24": "shearing sheep",
27
+ "25": "sweeping floor",
28
+ "26": "eating carrots",
29
+ "27": "skateboarding",
30
+ "28": "dunking basketball",
31
+ "29": "disc golfing",
32
+ "30": "eating spaghetti",
33
+ "31": "playing flute",
34
+ "32": "riding mechanical bull",
35
+ "33": "making sushi",
36
+ "34": "trapezing",
37
+ "35": "picking fruit",
38
+ "36": "stretching leg",
39
+ "37": "playing ukulele",
40
+ "38": "tying tie",
41
+ "39": "skydiving",
42
+ "40": "playing cello",
43
+ "41": "jumping into pool",
44
+ "42": "shooting goal (soccer)",
45
+ "43": "trimming trees",
46
+ "44": "bookbinding",
47
+ "45": "ski jumping",
48
+ "46": "walking the dog",
49
+ "47": "riding unicycle",
50
+ "48": "shaving head",
51
+ "49": "hopscotch",
52
+ "50": "playing piano",
53
+ "51": "parasailing",
54
+ "52": "bartending",
55
+ "53": "kicking field goal",
56
+ "54": "finger snapping",
57
+ "55": "dining",
58
+ "56": "yawning",
59
+ "57": "peeling potatoes",
60
+ "58": "canoeing or kayaking",
61
+ "59": "front raises",
62
+ "60": "laughing",
63
+ "61": "dancing macarena",
64
+ "62": "digging",
65
+ "63": "reading newspaper",
66
+ "64": "hitting baseball",
67
+ "65": "clay pottery making",
68
+ "66": "exercising with an exercise ball",
69
+ "67": "playing saxophone",
70
+ "68": "shooting basketball",
71
+ "69": "washing hair",
72
+ "70": "lunge",
73
+ "71": "brushing hair",
74
+ "72": "curling hair",
75
+ "73": "kitesurfing",
76
+ "74": "tapping guitar",
77
+ "75": "bending back",
78
+ "76": "skipping rope",
79
+ "77": "situp",
80
+ "78": "folding paper",
81
+ "79": "cracking neck",
82
+ "80": "assembling computer",
83
+ "81": "cleaning gutters",
84
+ "82": "blowing out candles",
85
+ "83": "shaking hands",
86
+ "84": "dancing gangnam style",
87
+ "85": "windsurfing",
88
+ "86": "tap dancing",
89
+ "87": "skiing (not slalom or crosscountry)",
90
+ "88": "bandaging",
91
+ "89": "push up",
92
+ "90": "doing nails",
93
+ "91": "punching person (boxing)",
94
+ "92": "bouncing on trampoline",
95
+ "93": "scrambling eggs",
96
+ "94": "singing",
97
+ "95": "cleaning floor",
98
+ "96": "krumping",
99
+ "97": "drumming fingers",
100
+ "98": "snowmobiling",
101
+ "99": "gymnastics tumbling",
102
+ "100": "headbanging",
103
+ "101": "catching or throwing frisbee",
104
+ "102": "riding elephant",
105
+ "103": "bee keeping",
106
+ "104": "feeding birds",
107
+ "105": "snatch weight lifting",
108
+ "106": "mowing lawn",
109
+ "107": "fixing hair",
110
+ "108": "playing trumpet",
111
+ "109": "flying kite",
112
+ "110": "crossing river",
113
+ "111": "swinging legs",
114
+ "112": "sanding floor",
115
+ "113": "belly dancing",
116
+ "114": "sneezing",
117
+ "115": "clean and jerk",
118
+ "116": "side kick",
119
+ "117": "filling eyebrows",
120
+ "118": "shuffling cards",
121
+ "119": "recording music",
122
+ "120": "cartwheeling",
123
+ "121": "feeding fish",
124
+ "122": "folding clothes",
125
+ "123": "water skiing",
126
+ "124": "tobogganing",
127
+ "125": "blowing leaves",
128
+ "126": "smoking",
129
+ "127": "unboxing",
130
+ "128": "tai chi",
131
+ "129": "waxing legs",
132
+ "130": "riding camel",
133
+ "131": "slapping",
134
+ "132": "tossing salad",
135
+ "133": "capoeira",
136
+ "134": "playing cards",
137
+ "135": "playing organ",
138
+ "136": "playing violin",
139
+ "137": "playing drums",
140
+ "138": "tapping pen",
141
+ "139": "vault",
142
+ "140": "shoveling snow",
143
+ "141": "playing tennis",
144
+ "142": "getting a tattoo",
145
+ "143": "making a sandwich",
146
+ "144": "making tea",
147
+ "145": "grinding meat",
148
+ "146": "squat",
149
+ "147": "eating doughnuts",
150
+ "148": "ice fishing",
151
+ "149": "snowkiting",
152
+ "150": "kicking soccer ball",
153
+ "151": "playing controller",
154
+ "152": "giving or receiving award",
155
+ "153": "welding",
156
+ "154": "throwing discus",
157
+ "155": "throwing axe",
158
+ "156": "ripping paper",
159
+ "157": "swimming butterfly stroke",
160
+ "158": "air drumming",
161
+ "159": "blowing nose",
162
+ "160": "hockey stop",
163
+ "161": "taking a shower",
164
+ "162": "bench pressing",
165
+ "163": "planting trees",
166
+ "164": "pumping fist",
167
+ "165": "climbing tree",
168
+ "166": "tickling",
169
+ "167": "high kick",
170
+ "168": "waiting in line",
171
+ "169": "slacklining",
172
+ "170": "tango dancing",
173
+ "171": "hurdling",
174
+ "172": "carrying baby",
175
+ "173": "celebrating",
176
+ "174": "sharpening knives",
177
+ "175": "passing American football (in game)",
178
+ "176": "headbutting",
179
+ "177": "playing recorder",
180
+ "178": "brush painting",
181
+ "179": "garbage collecting",
182
+ "180": "robot dancing",
183
+ "181": "shredding paper",
184
+ "182": "pumping gas",
185
+ "183": "rock climbing",
186
+ "184": "hula hooping",
187
+ "185": "braiding hair",
188
+ "186": "opening present",
189
+ "187": "texting",
190
+ "188": "decorating the christmas tree",
191
+ "189": "answering questions",
192
+ "190": "playing keyboard",
193
+ "191": "writing",
194
+ "192": "bungee jumping",
195
+ "193": "sniffing",
196
+ "194": "eating burger",
197
+ "195": "playing accordion",
198
+ "196": "making pizza",
199
+ "197": "playing volleyball",
200
+ "198": "tasting food",
201
+ "199": "pushing cart",
202
+ "200": "spinning poi",
203
+ "201": "cleaning windows",
204
+ "202": "arm wrestling",
205
+ "203": "changing oil",
206
+ "204": "swimming breast stroke",
207
+ "205": "tossing coin",
208
+ "206": "deadlifting",
209
+ "207": "hoverboarding",
210
+ "208": "cutting watermelon",
211
+ "209": "cheerleading",
212
+ "210": "snorkeling",
213
+ "211": "washing hands",
214
+ "212": "eating cake",
215
+ "213": "pull ups",
216
+ "214": "surfing water",
217
+ "215": "eating hotdog",
218
+ "216": "holding snake",
219
+ "217": "playing harmonica",
220
+ "218": "ironing",
221
+ "219": "cutting nails",
222
+ "220": "golf chipping",
223
+ "221": "shot put",
224
+ "222": "hugging",
225
+ "223": "playing clarinet",
226
+ "224": "faceplanting",
227
+ "225": "trimming or shaving beard",
228
+ "226": "drinking shots",
229
+ "227": "riding mountain bike",
230
+ "228": "tying bow tie",
231
+ "229": "swinging on something",
232
+ "230": "skiing crosscountry",
233
+ "231": "unloading truck",
234
+ "232": "cleaning pool",
235
+ "233": "jogging",
236
+ "234": "ice climbing",
237
+ "235": "mopping floor",
238
+ "236": "making bed",
239
+ "237": "diving cliff",
240
+ "238": "washing dishes",
241
+ "239": "grooming dog",
242
+ "240": "weaving basket",
243
+ "241": "frying vegetables",
244
+ "242": "stomping grapes",
245
+ "243": "moving furniture",
246
+ "244": "cooking sausages",
247
+ "245": "doing laundry",
248
+ "246": "dying hair",
249
+ "247": "knitting",
250
+ "248": "reading book",
251
+ "249": "baby waking up",
252
+ "250": "punching bag",
253
+ "251": "surfing crowd",
254
+ "252": "cooking chicken",
255
+ "253": "pushing car",
256
+ "254": "springboard diving",
257
+ "255": "swing dancing",
258
+ "256": "massaging legs",
259
+ "257": "beatboxing",
260
+ "258": "breading or breadcrumbing",
261
+ "259": "somersaulting",
262
+ "260": "brushing teeth",
263
+ "261": "stretching arm",
264
+ "262": "juggling balls",
265
+ "263": "massaging person's head",
266
+ "264": "eating ice cream",
267
+ "265": "extinguishing fire",
268
+ "266": "hammer throw",
269
+ "267": "whistling",
270
+ "268": "crawling baby",
271
+ "269": "using remote controller (not gaming)",
272
+ "270": "playing cricket",
273
+ "271": "opening bottle",
274
+ "272": "playing xylophone",
275
+ "273": "motorcycling",
276
+ "274": "driving car",
277
+ "275": "exercising arm",
278
+ "276": "passing American football (not in game)",
279
+ "277": "playing kickball",
280
+ "278": "sticking tongue out",
281
+ "279": "flipping pancake",
282
+ "280": "catching fish",
283
+ "281": "eating chips",
284
+ "282": "shaking head",
285
+ "283": "sword fighting",
286
+ "284": "playing poker",
287
+ "285": "cooking on campfire",
288
+ "286": "doing aerobics",
289
+ "287": "paragliding",
290
+ "288": "using segway",
291
+ "289": "folding napkins",
292
+ "290": "playing bagpipes",
293
+ "291": "gargling",
294
+ "292": "skiing slalom",
295
+ "293": "strumming guitar",
296
+ "294": "javelin throw",
297
+ "295": "waxing back",
298
+ "296": "riding or walking with horse",
299
+ "297": "plastering",
300
+ "298": "long jump",
301
+ "299": "parkour",
302
+ "300": "wrapping present",
303
+ "301": "egg hunting",
304
+ "302": "archery",
305
+ "303": "cleaning toilet",
306
+ "304": "swimming backstroke",
307
+ "305": "snowboarding",
308
+ "306": "catching or throwing baseball",
309
+ "307": "massaging back",
310
+ "308": "blowing glass",
311
+ "309": "playing guitar",
312
+ "310": "playing chess",
313
+ "311": "golf driving",
314
+ "312": "presenting weather forecast",
315
+ "313": "rock scissors paper",
316
+ "314": "high jump",
317
+ "315": "baking cookies",
318
+ "316": "using computer",
319
+ "317": "washing feet",
320
+ "318": "arranging flowers",
321
+ "319": "playing bass guitar",
322
+ "320": "spraying",
323
+ "321": "cutting pineapple",
324
+ "322": "waxing chest",
325
+ "323": "auctioning",
326
+ "324": "jetskiing",
327
+ "325": "drinking",
328
+ "326": "busking",
329
+ "327": "playing monopoly",
330
+ "328": "salsa dancing",
331
+ "329": "waxing eyebrows",
332
+ "330": "watering plants",
333
+ "331": "zumba",
334
+ "332": "chopping wood",
335
+ "333": "pushing wheelchair",
336
+ "334": "carving pumpkin",
337
+ "335": "building shed",
338
+ "336": "making jewelry",
339
+ "337": "catching or throwing softball",
340
+ "338": "bending metal",
341
+ "339": "ice skating",
342
+ "340": "dancing charleston",
343
+ "341": "abseiling",
344
+ "342": "climbing a rope",
345
+ "343": "crying",
346
+ "344": "cleaning shoes",
347
+ "345": "dancing ballet",
348
+ "346": "driving tractor",
349
+ "347": "triple jump",
350
+ "348": "throwing ball",
351
+ "349": "getting a haircut",
352
+ "350": "running on treadmill",
353
+ "351": "climbing ladder",
354
+ "352": "blasting sand",
355
+ "353": "playing trombone",
356
+ "354": "drop kicking",
357
+ "355": "country line dancing",
358
+ "356": "changing wheel",
359
+ "357": "feeding goats",
360
+ "358": "tying knot (not on a tie)",
361
+ "359": "setting table",
362
+ "360": "shaving legs",
363
+ "361": "kissing",
364
+ "362": "riding mule",
365
+ "363": "counting money",
366
+ "364": "laying bricks",
367
+ "365": "barbequing",
368
+ "366": "news anchoring",
369
+ "367": "smoking hookah",
370
+ "368": "cooking egg",
371
+ "369": "peeling apples",
372
+ "370": "yoga",
373
+ "371": "sharpening pencil",
374
+ "372": "dribbling basketball",
375
+ "373": "petting cat",
376
+ "374": "playing ice hockey",
377
+ "375": "milking cow",
378
+ "376": "shining shoes",
379
+ "377": "juggling soccer ball",
380
+ "378": "scuba diving",
381
+ "379": "playing squash or racquetball",
382
+ "380": "drinking beer",
383
+ "381": "sign language interpreting",
384
+ "382": "playing basketball",
385
+ "383": "breakdancing",
386
+ "384": "testifying",
387
+ "385": "making snowman",
388
+ "386": "golf putting",
389
+ "387": "playing didgeridoo",
390
+ "388": "biking through snow",
391
+ "389": "sailing",
392
+ "390": "jumpstyle dancing",
393
+ "391": "water sliding",
394
+ "392": "grooming horse",
395
+ "393": "massaging feet",
396
+ "394": "playing paintball",
397
+ "395": "making a cake",
398
+ "396": "bowling",
399
+ "397": "contact juggling",
400
+ "398": "applying cream",
401
+ "399": "playing badminton"
402
+ }
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ einops
4
+ timm
5
+ Pillow
6
+ decord
transforms.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchvision
2
+ import random
3
+ from PIL import Image, ImageOps
4
+ import numpy as np
5
+ import numbers
6
+ import math
7
+ import torch
8
+
9
+
10
+ class GroupRandomCrop(object):
11
+ def __init__(self, size):
12
+ if isinstance(size, numbers.Number):
13
+ self.size = (int(size), int(size))
14
+ else:
15
+ self.size = size
16
+
17
+ def __call__(self, img_group):
18
+
19
+ w, h = img_group[0].size
20
+ th, tw = self.size
21
+
22
+ out_images = list()
23
+
24
+ x1 = random.randint(0, w - tw)
25
+ y1 = random.randint(0, h - th)
26
+
27
+ for img in img_group:
28
+ assert(img.size[0] == w and img.size[1] == h)
29
+ if w == tw and h == th:
30
+ out_images.append(img)
31
+ else:
32
+ out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
33
+
34
+ return out_images
35
+
36
+
37
+ class MultiGroupRandomCrop(object):
38
+ def __init__(self, size, groups=1):
39
+ if isinstance(size, numbers.Number):
40
+ self.size = (int(size), int(size))
41
+ else:
42
+ self.size = size
43
+ self.groups = groups
44
+
45
+ def __call__(self, img_group):
46
+
47
+ w, h = img_group[0].size
48
+ th, tw = self.size
49
+
50
+ out_images = list()
51
+
52
+ for i in range(self.groups):
53
+ x1 = random.randint(0, w - tw)
54
+ y1 = random.randint(0, h - th)
55
+
56
+ for img in img_group:
57
+ assert(img.size[0] == w and img.size[1] == h)
58
+ if w == tw and h == th:
59
+ out_images.append(img)
60
+ else:
61
+ out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
62
+
63
+ return out_images
64
+
65
+
66
+ class GroupCenterCrop(object):
67
+ def __init__(self, size):
68
+ self.worker = torchvision.transforms.CenterCrop(size)
69
+
70
+ def __call__(self, img_group):
71
+ return [self.worker(img) for img in img_group]
72
+
73
+
74
+ class GroupRandomHorizontalFlip(object):
75
+ """Randomly horizontally flips the given PIL.Image with a probability of 0.5
76
+ """
77
+
78
+ def __init__(self, is_flow=False):
79
+ self.is_flow = is_flow
80
+
81
+ def __call__(self, img_group, is_flow=False):
82
+ v = random.random()
83
+ if v < 0.5:
84
+ ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
85
+ if self.is_flow:
86
+ for i in range(0, len(ret), 2):
87
+ # invert flow pixel values when flipping
88
+ ret[i] = ImageOps.invert(ret[i])
89
+ return ret
90
+ else:
91
+ return img_group
92
+
93
+
94
+ class GroupNormalize(object):
95
+ def __init__(self, mean, std):
96
+ self.mean = mean
97
+ self.std = std
98
+
99
+ def __call__(self, tensor):
100
+ rep_mean = self.mean * (tensor.size()[0] // len(self.mean))
101
+ rep_std = self.std * (tensor.size()[0] // len(self.std))
102
+
103
+ # TODO: make efficient
104
+ for t, m, s in zip(tensor, rep_mean, rep_std):
105
+ t.sub_(m).div_(s)
106
+
107
+ return tensor
108
+
109
+
110
+ class GroupScale(object):
111
+ """ Rescales the input PIL.Image to the given 'size'.
112
+ 'size' will be the size of the smaller edge.
113
+ For example, if height > width, then image will be
114
+ rescaled to (size * height / width, size)
115
+ size: size of the smaller edge
116
+ interpolation: Default: PIL.Image.BILINEAR
117
+ """
118
+
119
+ def __init__(self, size, interpolation=Image.BILINEAR):
120
+ self.worker = torchvision.transforms.Resize(size, interpolation)
121
+
122
+ def __call__(self, img_group):
123
+ return [self.worker(img) for img in img_group]
124
+
125
+
126
+ class GroupOverSample(object):
127
+ def __init__(self, crop_size, scale_size=None, flip=True):
128
+ self.crop_size = crop_size if not isinstance(
129
+ crop_size, int) else (crop_size, crop_size)
130
+
131
+ if scale_size is not None:
132
+ self.scale_worker = GroupScale(scale_size)
133
+ else:
134
+ self.scale_worker = None
135
+ self.flip = flip
136
+
137
+ def __call__(self, img_group):
138
+
139
+ if self.scale_worker is not None:
140
+ img_group = self.scale_worker(img_group)
141
+
142
+ image_w, image_h = img_group[0].size
143
+ crop_w, crop_h = self.crop_size
144
+
145
+ offsets = GroupMultiScaleCrop.fill_fix_offset(
146
+ False, image_w, image_h, crop_w, crop_h)
147
+ oversample_group = list()
148
+ for o_w, o_h in offsets:
149
+ normal_group = list()
150
+ flip_group = list()
151
+ for i, img in enumerate(img_group):
152
+ crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
153
+ normal_group.append(crop)
154
+ flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)
155
+
156
+ if img.mode == 'L' and i % 2 == 0:
157
+ flip_group.append(ImageOps.invert(flip_crop))
158
+ else:
159
+ flip_group.append(flip_crop)
160
+
161
+ oversample_group.extend(normal_group)
162
+ if self.flip:
163
+ oversample_group.extend(flip_group)
164
+ return oversample_group
165
+
166
+
167
+ class GroupFullResSample(object):
168
+ def __init__(self, crop_size, scale_size=None, flip=True):
169
+ self.crop_size = crop_size if not isinstance(
170
+ crop_size, int) else (crop_size, crop_size)
171
+
172
+ if scale_size is not None:
173
+ self.scale_worker = GroupScale(scale_size)
174
+ else:
175
+ self.scale_worker = None
176
+ self.flip = flip
177
+
178
+ def __call__(self, img_group):
179
+
180
+ if self.scale_worker is not None:
181
+ img_group = self.scale_worker(img_group)
182
+
183
+ image_w, image_h = img_group[0].size
184
+ crop_w, crop_h = self.crop_size
185
+
186
+ w_step = (image_w - crop_w) // 4
187
+ h_step = (image_h - crop_h) // 4
188
+
189
+ offsets = list()
190
+ offsets.append((0 * w_step, 2 * h_step)) # left
191
+ offsets.append((4 * w_step, 2 * h_step)) # right
192
+ offsets.append((2 * w_step, 2 * h_step)) # center
193
+
194
+ oversample_group = list()
195
+ for o_w, o_h in offsets:
196
+ normal_group = list()
197
+ flip_group = list()
198
+ for i, img in enumerate(img_group):
199
+ crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
200
+ normal_group.append(crop)
201
+ if self.flip:
202
+ flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)
203
+
204
+ if img.mode == 'L' and i % 2 == 0:
205
+ flip_group.append(ImageOps.invert(flip_crop))
206
+ else:
207
+ flip_group.append(flip_crop)
208
+
209
+ oversample_group.extend(normal_group)
210
+ oversample_group.extend(flip_group)
211
+ return oversample_group
212
+
213
+
214
+ class GroupMultiScaleCrop(object):
215
+
216
+ def __init__(self, input_size, scales=None, max_distort=1,
217
+ fix_crop=True, more_fix_crop=True):
218
+ self.scales = scales if scales is not None else [1, .875, .75, .66]
219
+ self.max_distort = max_distort
220
+ self.fix_crop = fix_crop
221
+ self.more_fix_crop = more_fix_crop
222
+ self.input_size = input_size if not isinstance(input_size, int) else [
223
+ input_size, input_size]
224
+ self.interpolation = Image.BILINEAR
225
+
226
+ def __call__(self, img_group):
227
+
228
+ im_size = img_group[0].size
229
+
230
+ crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
231
+ crop_img_group = [
232
+ img.crop(
233
+ (offset_w,
234
+ offset_h,
235
+ offset_w +
236
+ crop_w,
237
+ offset_h +
238
+ crop_h)) for img in img_group]
239
+ ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation)
240
+ for img in crop_img_group]
241
+ return ret_img_group
242
+
243
+ def _sample_crop_size(self, im_size):
244
+ image_w, image_h = im_size[0], im_size[1]
245
+
246
+ # find a crop size
247
+ base_size = min(image_w, image_h)
248
+ crop_sizes = [int(base_size * x) for x in self.scales]
249
+ crop_h = [
250
+ self.input_size[1] if abs(
251
+ x - self.input_size[1]) < 3 else x for x in crop_sizes]
252
+ crop_w = [
253
+ self.input_size[0] if abs(
254
+ x - self.input_size[0]) < 3 else x for x in crop_sizes]
255
+
256
+ pairs = []
257
+ for i, h in enumerate(crop_h):
258
+ for j, w in enumerate(crop_w):
259
+ if abs(i - j) <= self.max_distort:
260
+ pairs.append((w, h))
261
+
262
+ crop_pair = random.choice(pairs)
263
+ if not self.fix_crop:
264
+ w_offset = random.randint(0, image_w - crop_pair[0])
265
+ h_offset = random.randint(0, image_h - crop_pair[1])
266
+ else:
267
+ w_offset, h_offset = self._sample_fix_offset(
268
+ image_w, image_h, crop_pair[0], crop_pair[1])
269
+
270
+ return crop_pair[0], crop_pair[1], w_offset, h_offset
271
+
272
+ def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
273
+ offsets = self.fill_fix_offset(
274
+ self.more_fix_crop, image_w, image_h, crop_w, crop_h)
275
+ return random.choice(offsets)
276
+
277
+ @staticmethod
278
+ def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
279
+ w_step = (image_w - crop_w) // 4
280
+ h_step = (image_h - crop_h) // 4
281
+
282
+ ret = list()
283
+ ret.append((0, 0)) # upper left
284
+ ret.append((4 * w_step, 0)) # upper right
285
+ ret.append((0, 4 * h_step)) # lower left
286
+ ret.append((4 * w_step, 4 * h_step)) # lower right
287
+ ret.append((2 * w_step, 2 * h_step)) # center
288
+
289
+ if more_fix_crop:
290
+ ret.append((0, 2 * h_step)) # center left
291
+ ret.append((4 * w_step, 2 * h_step)) # center right
292
+ ret.append((2 * w_step, 4 * h_step)) # lower center
293
+ ret.append((2 * w_step, 0 * h_step)) # upper center
294
+
295
+ ret.append((1 * w_step, 1 * h_step)) # upper left quarter
296
+ ret.append((3 * w_step, 1 * h_step)) # upper right quarter
297
+ ret.append((1 * w_step, 3 * h_step)) # lower left quarter
298
+ ret.append((3 * w_step, 3 * h_step)) # lower righ quarter
299
+
300
+ return ret
301
+
302
+
303
+ class GroupRandomSizedCrop(object):
304
+ """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size
305
+ and and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio
306
+ This is popularly used to train the Inception networks
307
+ size: size of the smaller edge
308
+ interpolation: Default: PIL.Image.BILINEAR
309
+ """
310
+
311
+ def __init__(self, size, interpolation=Image.BILINEAR):
312
+ self.size = size
313
+ self.interpolation = interpolation
314
+
315
+ def __call__(self, img_group):
316
+ for attempt in range(10):
317
+ area = img_group[0].size[0] * img_group[0].size[1]
318
+ target_area = random.uniform(0.08, 1.0) * area
319
+ aspect_ratio = random.uniform(3. / 4, 4. / 3)
320
+
321
+ w = int(round(math.sqrt(target_area * aspect_ratio)))
322
+ h = int(round(math.sqrt(target_area / aspect_ratio)))
323
+
324
+ if random.random() < 0.5:
325
+ w, h = h, w
326
+
327
+ if w <= img_group[0].size[0] and h <= img_group[0].size[1]:
328
+ x1 = random.randint(0, img_group[0].size[0] - w)
329
+ y1 = random.randint(0, img_group[0].size[1] - h)
330
+ found = True
331
+ break
332
+ else:
333
+ found = False
334
+ x1 = 0
335
+ y1 = 0
336
+
337
+ if found:
338
+ out_group = list()
339
+ for img in img_group:
340
+ img = img.crop((x1, y1, x1 + w, y1 + h))
341
+ assert(img.size == (w, h))
342
+ out_group.append(
343
+ img.resize(
344
+ (self.size, self.size), self.interpolation))
345
+ return out_group
346
+ else:
347
+ # Fallback
348
+ scale = GroupScale(self.size, interpolation=self.interpolation)
349
+ crop = GroupRandomCrop(self.size)
350
+ return crop(scale(img_group))
351
+
352
+
353
+ class ConvertDataFormat(object):
354
+ def __init__(self, model_type):
355
+ self.model_type = model_type
356
+
357
+ def __call__(self, images):
358
+ if self.model_type == '2D':
359
+ return images
360
+ tc, h, w = images.size()
361
+ t = tc // 3
362
+ images = images.view(t, 3, h, w)
363
+ images = images.permute(1, 0, 2, 3)
364
+ return images
365
+
366
+
367
+ class Stack(object):
368
+
369
+ def __init__(self, roll=False):
370
+ self.roll = roll
371
+
372
+ def __call__(self, img_group):
373
+ if img_group[0].mode == 'L':
374
+ return np.concatenate([np.expand_dims(x, 2)
375
+ for x in img_group], axis=2)
376
+ elif img_group[0].mode == 'RGB':
377
+ if self.roll:
378
+ return np.concatenate([np.array(x)[:, :, ::-1]
379
+ for x in img_group], axis=2)
380
+ else:
381
+ #print(np.concatenate(img_group, axis=2).shape)
382
+ # print(img_group[0].shape)
383
+ return np.concatenate(img_group, axis=2)
384
+
385
+
386
+ class ToTorchFormatTensor(object):
387
+ """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
388
+ to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
389
+
390
+ def __init__(self, div=True):
391
+ self.div = div
392
+
393
+ def __call__(self, pic):
394
+ if isinstance(pic, np.ndarray):
395
+ # handle numpy array
396
+ img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
397
+ else:
398
+ # handle PIL Image
399
+ img = torch.ByteTensor(
400
+ torch.ByteStorage.from_buffer(
401
+ pic.tobytes()))
402
+ img = img.view(pic.size[1], pic.size[0], len(pic.mode))
403
+ # put it from HWC to CHW format
404
+ # yikes, this transpose takes 80% of the loading time/CPU
405
+ img = img.transpose(0, 1).transpose(0, 2).contiguous()
406
+ return img.float().div(255) if self.div else img.float()
407
+
408
+
409
+ class IdentityTransform(object):
410
+
411
+ def __call__(self, data):
412
+ return data
413
+
414
+
415
+ if __name__ == "__main__":
416
+ trans = torchvision.transforms.Compose([
417
+ GroupScale(256),
418
+ GroupRandomCrop(224),
419
+ Stack(),
420
+ ToTorchFormatTensor(),
421
+ GroupNormalize(
422
+ mean=[.485, .456, .406],
423
+ std=[.229, .224, .225]
424
+ )]
425
+ )
426
+
427
+ im = Image.open('../tensorflow-model-zoo.torch/lena_299.png')
428
+
429
+ color_group = [im] * 3
430
+ rst = trans(color_group)
431
+
432
+ gray_group = [im.convert('L')] * 9
433
+ gray_rst = trans(gray_group)
434
+
435
+ trans2 = torchvision.transforms.Compose([
436
+ GroupRandomSizedCrop(256),
437
+ Stack(),
438
+ ToTorchFormatTensor(),
439
+ GroupNormalize(
440
+ mean=[.485, .456, .406],
441
+ std=[.229, .224, .225])
442
+ ])
443
+ print(trans2(color_group))
uniformerv2.py ADDED
@@ -0,0 +1,510 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ import os
3
+ from collections import OrderedDict
4
+
5
+ from timm.models.layers import DropPath
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import MultiheadAttention
9
+ import torch.nn.functional as F
10
+ import torch.utils.checkpoint as checkpoint
11
+
12
+
13
+ MODEL_PATH = './'
14
+ _MODELS = {
15
+ "ViT-B/16": os.path.join(MODEL_PATH, "vit_b16.pth"),
16
+ "ViT-L/14": os.path.join(MODEL_PATH, "vit_l14.pth"),
17
+ "ViT-L/14_336": os.path.join(MODEL_PATH, "vit_l14_336.pth"),
18
+ }
19
+
20
+
21
+ class LayerNorm(nn.LayerNorm):
22
+ """Subclass torch's LayerNorm to handle fp16."""
23
+
24
+ def forward(self, x):
25
+ orig_type = x.dtype
26
+ ret = super().forward(x.type(torch.float32))
27
+ return ret.type(orig_type)
28
+
29
+
30
+ class QuickGELU(nn.Module):
31
+ def forward(self, x):
32
+ return x * torch.sigmoid(1.702 * x)
33
+
34
+
35
+ class Local_MHRA(nn.Module):
36
+ def __init__(self, d_model, dw_reduction=1.5, pos_kernel_size=3):
37
+ super().__init__()
38
+
39
+ padding = pos_kernel_size // 2
40
+ re_d_model = int(d_model // dw_reduction)
41
+ self.pos_embed = nn.Sequential(
42
+ nn.BatchNorm3d(d_model),
43
+ nn.Conv3d(d_model, re_d_model, kernel_size=1, stride=1, padding=0),
44
+ nn.Conv3d(re_d_model, re_d_model, kernel_size=(pos_kernel_size, 1, 1), stride=(1, 1, 1), padding=(padding, 0, 0), groups=re_d_model),
45
+ nn.Conv3d(re_d_model, d_model, kernel_size=1, stride=1, padding=0),
46
+ )
47
+
48
+ # init zero
49
+ print('Init zero for Conv in pos_emb')
50
+ nn.init.constant_(self.pos_embed[3].weight, 0)
51
+ nn.init.constant_(self.pos_embed[3].bias, 0)
52
+
53
+ def forward(self, x):
54
+ return self.pos_embed(x)
55
+
56
+
57
+ class ResidualAttentionBlock(nn.Module):
58
+ def __init__(
59
+ self, d_model, n_head, attn_mask=None, drop_path=0.0,
60
+ dw_reduction=1.5, no_lmhra=False, double_lmhra=True
61
+ ):
62
+ super().__init__()
63
+
64
+ self.n_head = n_head
65
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
66
+ print(f'Drop path rate: {drop_path}')
67
+
68
+ self.no_lmhra = no_lmhra
69
+ self.double_lmhra = double_lmhra
70
+ print(f'No L_MHRA: {no_lmhra}')
71
+ print(f'Double L_MHRA: {double_lmhra}')
72
+ if not no_lmhra:
73
+ self.lmhra1 = Local_MHRA(d_model, dw_reduction=dw_reduction)
74
+ if double_lmhra:
75
+ self.lmhra2 = Local_MHRA(d_model, dw_reduction=dw_reduction)
76
+
77
+ # spatial
78
+ self.attn = MultiheadAttention(d_model, n_head)
79
+ self.ln_1 = LayerNorm(d_model)
80
+ self.mlp = nn.Sequential(OrderedDict([
81
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
82
+ ("gelu", QuickGELU()),
83
+ ("c_proj", nn.Linear(d_model * 4, d_model))
84
+ ]))
85
+ self.ln_2 = LayerNorm(d_model)
86
+ self.attn_mask = attn_mask
87
+
88
+ def attention(self, x):
89
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
90
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
91
+
92
+ def forward(self, x, T=8, use_checkpoint=False):
93
+ # x: 1+HW, NT, C
94
+ if not self.no_lmhra:
95
+ # Local MHRA
96
+ tmp_x = x[1:, :, :]
97
+ L, NT, C = tmp_x.shape
98
+ N = NT // T
99
+ H = W = int(L ** 0.5)
100
+ tmp_x = tmp_x.view(H, W, N, T, C).permute(2, 4, 3, 0, 1).contiguous()
101
+ tmp_x = tmp_x + self.drop_path(self.lmhra1(tmp_x))
102
+ tmp_x = tmp_x.view(N, C, T, L).permute(3, 0, 2, 1).contiguous().view(L, NT, C)
103
+ x = torch.cat([x[:1, :, :], tmp_x], dim=0)
104
+ # MHSA
105
+ if use_checkpoint:
106
+ attn_out = checkpoint.checkpoint(self.attention, self.ln_1(x))
107
+ x = x + self.drop_path(attn_out)
108
+ else:
109
+ x = x + self.drop_path(self.attention(self.ln_1(x)))
110
+ # Local MHRA
111
+ if not self.no_lmhra and self.double_lmhra:
112
+ tmp_x = x[1:, :, :]
113
+ tmp_x = tmp_x.view(H, W, N, T, C).permute(2, 4, 3, 0, 1).contiguous()
114
+ tmp_x = tmp_x + self.drop_path(self.lmhra2(tmp_x))
115
+ tmp_x = tmp_x.view(N, C, T, L).permute(3, 0, 2, 1).contiguous().view(L, NT, C)
116
+ x = torch.cat([x[:1, :, :], tmp_x], dim=0)
117
+ # FFN
118
+ if use_checkpoint:
119
+ mlp_out = checkpoint.checkpoint(self.mlp, self.ln_2(x))
120
+ x = x + self.drop_path(mlp_out)
121
+ else:
122
+ x = x + self.drop_path(self.mlp(self.ln_2(x)))
123
+ return x
124
+
125
+
126
+ class Extractor(nn.Module):
127
+ def __init__(
128
+ self, d_model, n_head, attn_mask=None,
129
+ mlp_factor=4.0, dropout=0.0, drop_path=0.0,
130
+ ):
131
+ super().__init__()
132
+
133
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
134
+ print(f'Drop path rate: {drop_path}')
135
+ self.attn = nn.MultiheadAttention(d_model, n_head)
136
+ self.ln_1 = nn.LayerNorm(d_model)
137
+ d_mlp = round(mlp_factor * d_model)
138
+ self.mlp = nn.Sequential(OrderedDict([
139
+ ("c_fc", nn.Linear(d_model, d_mlp)),
140
+ ("gelu", QuickGELU()),
141
+ ("dropout", nn.Dropout(dropout)),
142
+ ("c_proj", nn.Linear(d_mlp, d_model))
143
+ ]))
144
+ self.ln_2 = nn.LayerNorm(d_model)
145
+ self.ln_3 = nn.LayerNorm(d_model)
146
+ self.attn_mask = attn_mask
147
+
148
+ # zero init
149
+ nn.init.xavier_uniform_(self.attn.in_proj_weight)
150
+ nn.init.constant_(self.attn.out_proj.weight, 0.)
151
+ nn.init.constant_(self.attn.out_proj.bias, 0.)
152
+ nn.init.xavier_uniform_(self.mlp[0].weight)
153
+ nn.init.constant_(self.mlp[-1].weight, 0.)
154
+ nn.init.constant_(self.mlp[-1].bias, 0.)
155
+
156
+ def attention(self, x, y):
157
+ d_model = self.ln_1.weight.size(0)
158
+ q = (x @ self.attn.in_proj_weight[:d_model].T) + self.attn.in_proj_bias[:d_model]
159
+
160
+ k = (y @ self.attn.in_proj_weight[d_model:-d_model].T) + self.attn.in_proj_bias[d_model:-d_model]
161
+ v = (y @ self.attn.in_proj_weight[-d_model:].T) + self.attn.in_proj_bias[-d_model:]
162
+ Tx, Ty, N = q.size(0), k.size(0), q.size(1)
163
+ q = q.view(Tx, N, self.attn.num_heads, self.attn.head_dim).permute(1, 2, 0, 3)
164
+ k = k.view(Ty, N, self.attn.num_heads, self.attn.head_dim).permute(1, 2, 0, 3)
165
+ v = v.view(Ty, N, self.attn.num_heads, self.attn.head_dim).permute(1, 2, 0, 3)
166
+ aff = (q @ k.transpose(-2, -1) / (self.attn.head_dim ** 0.5))
167
+
168
+ aff = aff.softmax(dim=-1)
169
+ out = aff @ v
170
+ out = out.permute(2, 0, 1, 3).flatten(2)
171
+ out = self.attn.out_proj(out)
172
+ return out
173
+
174
+ def forward(self, x, y):
175
+ x = x + self.drop_path(self.attention(self.ln_1(x), self.ln_3(y)))
176
+ x = x + self.drop_path(self.mlp(self.ln_2(x)))
177
+ return x
178
+
179
+
180
+ class Transformer(nn.Module):
181
+ def __init__(
182
+ self, width, layers, heads, attn_mask=None, backbone_drop_path_rate=0.,
183
+ use_checkpoint=False, checkpoint_num=[0], t_size=8, dw_reduction=2,
184
+ no_lmhra=False, double_lmhra=True,
185
+ return_list=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
186
+ n_layers=12, n_dim=768, n_head=12, mlp_factor=4.0, drop_path_rate=0.,
187
+ mlp_dropout=[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
188
+ cls_dropout=0.5, num_classes=400,
189
+ ):
190
+ super().__init__()
191
+ self.T = t_size
192
+ self.return_list = return_list
193
+ # backbone
194
+ b_dpr = [x.item() for x in torch.linspace(0, backbone_drop_path_rate, layers)]
195
+ self.resblocks = nn.ModuleList([
196
+ ResidualAttentionBlock(
197
+ width, heads, attn_mask,
198
+ drop_path=b_dpr[i],
199
+ dw_reduction=dw_reduction,
200
+ no_lmhra=no_lmhra,
201
+ double_lmhra=double_lmhra,
202
+ ) for i in range(layers)
203
+ ])
204
+ # checkpoint
205
+ self.use_checkpoint = use_checkpoint
206
+ self.checkpoint_num = checkpoint_num
207
+ self.n_layers = n_layers
208
+ print(f'Use checkpoint: {self.use_checkpoint}')
209
+ print(f'Checkpoint number: {self.checkpoint_num}')
210
+
211
+ # global block
212
+ assert n_layers == len(return_list)
213
+ if n_layers > 0:
214
+ self.temporal_cls_token = nn.Parameter(torch.zeros(1, 1, n_dim))
215
+ self.dpe = nn.ModuleList([
216
+ nn.Conv3d(n_dim, n_dim, kernel_size=3, stride=1, padding=1, bias=True, groups=n_dim)
217
+ for i in range(n_layers)
218
+ ])
219
+ for m in self.dpe:
220
+ nn.init.constant_(m.bias, 0.)
221
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)]
222
+ self.dec = nn.ModuleList([
223
+ Extractor(
224
+ n_dim, n_head, mlp_factor=mlp_factor,
225
+ dropout=mlp_dropout[i], drop_path=dpr[i],
226
+ ) for i in range(n_layers)
227
+ ])
228
+ self.balance = nn.Parameter(torch.zeros((n_dim)))
229
+ self.sigmoid = nn.Sigmoid()
230
+ # projection
231
+ self.proj = nn.Sequential(
232
+ nn.LayerNorm(n_dim),
233
+ nn.Dropout(cls_dropout),
234
+ nn.Linear(n_dim, num_classes),
235
+ )
236
+
237
+ def forward(self, x):
238
+ T_down = self.T
239
+ L, NT, C = x.shape
240
+ N = NT // T_down
241
+ H = W = int((L - 1) ** 0.5)
242
+
243
+ if self.n_layers > 0:
244
+ cls_token = self.temporal_cls_token.repeat(1, N, 1)
245
+
246
+ j = -1
247
+ for i, resblock in enumerate(self.resblocks):
248
+ if self.use_checkpoint and i < self.checkpoint_num[0]:
249
+ x = resblock(x, self.T, use_checkpoint=True)
250
+ else:
251
+ x = resblock(x, T_down)
252
+ if i in self.return_list:
253
+ j += 1
254
+ tmp_x = x.clone()
255
+ tmp_x = tmp_x.view(L, N, T_down, C)
256
+ # dpe
257
+ _, tmp_feats = tmp_x[:1], tmp_x[1:]
258
+ tmp_feats = tmp_feats.permute(1, 3, 2, 0).reshape(N, C, T_down, H, W)
259
+ tmp_feats = self.dpe[j](tmp_feats).view(N, C, T_down, L - 1).permute(3, 0, 2, 1).contiguous()
260
+ tmp_x[1:] = tmp_x[1:] + tmp_feats
261
+ # global block
262
+ tmp_x = tmp_x.permute(2, 0, 1, 3).flatten(0, 1) # T * L, N, C
263
+ cls_token = self.dec[j](cls_token, tmp_x)
264
+
265
+ if self.n_layers > 0:
266
+ weight = self.sigmoid(self.balance)
267
+ residual = x.view(L, N, T_down, C)[0].mean(1) # L, N, T, C
268
+ return self.proj((1 - weight) * cls_token[0, :, :] + weight * residual)
269
+ else:
270
+ residual = x.view(L, N, T_down, C)[0].mean(1) # L, N, T, C
271
+ return self.proj(residual)
272
+
273
+
274
+ class VisionTransformer(nn.Module):
275
+ def __init__(
276
+ self,
277
+ # backbone
278
+ input_resolution, patch_size, width, layers, heads, output_dim, backbone_drop_path_rate=0.,
279
+ use_checkpoint=False, checkpoint_num=[0], t_size=8, kernel_size=3, dw_reduction=1.5,
280
+ temporal_downsample=True,
281
+ no_lmhra=-False, double_lmhra=True,
282
+ # global block
283
+ return_list=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
284
+ n_layers=12, n_dim=768, n_head=12, mlp_factor=4.0, drop_path_rate=0.,
285
+ mlp_dropout=[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
286
+ cls_dropout=0.5, num_classes=400,
287
+ ):
288
+ super().__init__()
289
+ self.input_resolution = input_resolution
290
+ self.output_dim = output_dim
291
+ padding = (kernel_size - 1) // 2
292
+ if temporal_downsample:
293
+ self.conv1 = nn.Conv3d(3, width, (kernel_size, patch_size, patch_size), (2, patch_size, patch_size), (padding, 0, 0), bias=False)
294
+ t_size = t_size // 2
295
+ else:
296
+ self.conv1 = nn.Conv3d(3, width, (1, patch_size, patch_size), (1, patch_size, patch_size), (0, 0, 0), bias=False)
297
+
298
+ scale = width ** -0.5
299
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
300
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
301
+ self.ln_pre = LayerNorm(width)
302
+
303
+ self.transformer = Transformer(
304
+ width, layers, heads, dw_reduction=dw_reduction,
305
+ backbone_drop_path_rate=backbone_drop_path_rate,
306
+ use_checkpoint=use_checkpoint, checkpoint_num=checkpoint_num, t_size=t_size,
307
+ no_lmhra=no_lmhra, double_lmhra=double_lmhra,
308
+ return_list=return_list, n_layers=n_layers, n_dim=n_dim, n_head=n_head,
309
+ mlp_factor=mlp_factor, drop_path_rate=drop_path_rate, mlp_dropout=mlp_dropout,
310
+ cls_dropout=cls_dropout, num_classes=num_classes,
311
+ )
312
+
313
+ def forward(self, x):
314
+ x = self.conv1(x) # shape = [*, width, grid, grid]
315
+ N, C, T, H, W = x.shape
316
+ x = x.permute(0, 2, 3, 4, 1).reshape(N * T, H * W, C)
317
+
318
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
319
+ x = x + self.positional_embedding.to(x.dtype)
320
+ x = self.ln_pre(x)
321
+
322
+ x = x.permute(1, 0, 2) # NLD -> LND
323
+ out = self.transformer(x)
324
+ return out
325
+
326
+
327
+ def inflate_weight(weight_2d, time_dim, center=True):
328
+ print(f'Init center: {center}')
329
+ if center:
330
+ weight_3d = torch.zeros(*weight_2d.shape)
331
+ weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
332
+ middle_idx = time_dim // 2
333
+ weight_3d[:, :, middle_idx, :, :] = weight_2d
334
+ else:
335
+ weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)
336
+ weight_3d = weight_3d / time_dim
337
+ return weight_3d
338
+
339
+
340
+ def load_state_dict(model, state_dict):
341
+ state_dict_3d = model.state_dict()
342
+ for k in state_dict.keys():
343
+ if state_dict[k].shape != state_dict_3d[k].shape:
344
+ if len(state_dict_3d[k].shape) <= 2:
345
+ print(f'Ignore: {k}')
346
+ continue
347
+ print(f'Inflate: {k}, {state_dict[k].shape} => {state_dict_3d[k].shape}')
348
+ time_dim = state_dict_3d[k].shape[2]
349
+ state_dict[k] = inflate_weight(state_dict[k], time_dim)
350
+ model.load_state_dict(state_dict, strict=False)
351
+
352
+
353
+ def uniformerv2_b16(
354
+ pretrained=True, use_checkpoint=False, checkpoint_num=[0],
355
+ t_size=16, dw_reduction=1.5, backbone_drop_path_rate=0.,
356
+ temporal_downsample=True,
357
+ no_lmhra=False, double_lmhra=True,
358
+ return_list=[8, 9, 10, 11],
359
+ n_layers=4, n_dim=768, n_head=12, mlp_factor=4.0, drop_path_rate=0.,
360
+ mlp_dropout=[0.5, 0.5, 0.5, 0.5],
361
+ cls_dropout=0.5, num_classes=400,
362
+ ):
363
+ model = VisionTransformer(
364
+ input_resolution=224,
365
+ patch_size=16,
366
+ width=768,
367
+ layers=12,
368
+ heads=12,
369
+ output_dim=512,
370
+ use_checkpoint=use_checkpoint,
371
+ checkpoint_num=checkpoint_num,
372
+ t_size=t_size,
373
+ dw_reduction=dw_reduction,
374
+ backbone_drop_path_rate=backbone_drop_path_rate,
375
+ temporal_downsample=temporal_downsample,
376
+ no_lmhra=no_lmhra,
377
+ double_lmhra=double_lmhra,
378
+ return_list=return_list,
379
+ n_layers=n_layers,
380
+ n_dim=n_dim,
381
+ n_head=n_head,
382
+ mlp_factor=mlp_factor,
383
+ drop_path_rate=drop_path_rate,
384
+ mlp_dropout=mlp_dropout,
385
+ cls_dropout=cls_dropout,
386
+ num_classes=num_classes,
387
+ )
388
+
389
+ if pretrained:
390
+ print('load pretrained weights')
391
+ state_dict = torch.load(_MODELS["ViT-B/16"], map_location='cpu')
392
+ load_state_dict(model, state_dict)
393
+ return model.eval()
394
+
395
+
396
+ def uniformerv2_l14(
397
+ pretrained=True, use_checkpoint=False, checkpoint_num=[0],
398
+ t_size=16, dw_reduction=1.5, backbone_drop_path_rate=0.,
399
+ temporal_downsample=True,
400
+ no_lmhra=False, double_lmhra=True,
401
+ return_list=[20, 21, 22, 23],
402
+ n_layers=4, n_dim=1024, n_head=16, mlp_factor=4.0, drop_path_rate=0.,
403
+ mlp_dropout=[0.5, 0.5, 0.5, 0.5],
404
+ cls_dropout=0.5, num_classes=400,
405
+ ):
406
+ model = VisionTransformer(
407
+ input_resolution=224,
408
+ patch_size=14,
409
+ width=1024,
410
+ layers=24,
411
+ heads=16,
412
+ output_dim=768,
413
+ use_checkpoint=use_checkpoint,
414
+ checkpoint_num=checkpoint_num,
415
+ t_size=t_size,
416
+ dw_reduction=dw_reduction,
417
+ backbone_drop_path_rate=backbone_drop_path_rate,
418
+ temporal_downsample=temporal_downsample,
419
+ no_lmhra=no_lmhra,
420
+ double_lmhra=double_lmhra,
421
+ return_list=return_list,
422
+ n_layers=n_layers,
423
+ n_dim=n_dim,
424
+ n_head=n_head,
425
+ mlp_factor=mlp_factor,
426
+ drop_path_rate=drop_path_rate,
427
+ mlp_dropout=mlp_dropout,
428
+ cls_dropout=cls_dropout,
429
+ num_classes=num_classes,
430
+ )
431
+
432
+ if pretrained:
433
+ print('load pretrained weights')
434
+ state_dict = torch.load(_MODELS["ViT-L/14"], map_location='cpu')
435
+ load_state_dict(model, state_dict)
436
+ return model.eval()
437
+
438
+
439
+ def uniformerv2_l14_336(
440
+ pretrained=True, use_checkpoint=False, checkpoint_num=[0],
441
+ t_size=16, dw_reduction=1.5, backbone_drop_path_rate=0.,
442
+ no_temporal_downsample=True,
443
+ no_lmhra=False, double_lmhra=True,
444
+ return_list=[20, 21, 22, 23],
445
+ n_layers=4, n_dim=1024, n_head=16, mlp_factor=4.0, drop_path_rate=0.,
446
+ mlp_dropout=[0.5, 0.5, 0.5, 0.5],
447
+ cls_dropout=0.5, num_classes=400,
448
+ ):
449
+ model = VisionTransformer(
450
+ input_resolution=336,
451
+ patch_size=14,
452
+ width=1024,
453
+ layers=24,
454
+ heads=16,
455
+ output_dim=768,
456
+ use_checkpoint=use_checkpoint,
457
+ checkpoint_num=checkpoint_num,
458
+ t_size=t_size,
459
+ dw_reduction=dw_reduction,
460
+ backbone_drop_path_rate=backbone_drop_path_rate,
461
+ no_temporal_downsample=no_temporal_downsample,
462
+ no_lmhra=no_lmhra,
463
+ double_lmhra=double_lmhra,
464
+ return_list=return_list,
465
+ n_layers=n_layers,
466
+ n_dim=n_dim,
467
+ n_head=n_head,
468
+ mlp_factor=mlp_factor,
469
+ drop_path_rate=drop_path_rate,
470
+ mlp_dropout=mlp_dropout,
471
+ cls_dropout=cls_dropout,
472
+ num_classes=num_classes,
473
+ )
474
+
475
+ if pretrained:
476
+ print('load pretrained weights')
477
+ state_dict = torch.load(_MODELS["ViT-L/14_336"], map_location='cpu')
478
+ load_state_dict(model, state_dict)
479
+ return model.eval()
480
+
481
+
482
+ if __name__ == '__main__':
483
+ import time
484
+ from fvcore.nn import FlopCountAnalysis
485
+ from fvcore.nn import flop_count_table
486
+ import numpy as np
487
+
488
+ seed = 4217
489
+ np.random.seed(seed)
490
+ torch.manual_seed(seed)
491
+ torch.cuda.manual_seed(seed)
492
+ torch.cuda.manual_seed_all(seed)
493
+ num_frames = 16
494
+
495
+ model = uniformerv2_l14(
496
+ pretrained=False,
497
+ t_size=num_frames, backbone_drop_path_rate=0., drop_path_rate=0.,
498
+ dw_reduction=1.5,
499
+ no_lmhra=False,
500
+ temporal_downsample=True,
501
+ return_list=[8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23],
502
+ mlp_dropout=[0.5]*16,
503
+ n_layers=16
504
+ )
505
+ print(model)
506
+
507
+ flops = FlopCountAnalysis(model, torch.rand(1, 3, num_frames, 224, 224))
508
+ s = time.time()
509
+ print(flop_count_table(flops, max_depth=1))
510
+ print(time.time()-s)
yoga.mp4 ADDED
Binary file (776 kB). View file