gheinrich committed on
Commit
be257a4
1 Parent(s): 42f66ec

Upload model

Files changed (8)
  1. adaptor_generic.py +29 -0
  2. adaptor_mlp.py +150 -0
  3. adaptor_registry.py +37 -0
  4. eradio_model.py +18 -431
  5. hf_model.py +13 -42
  6. open_clip_adaptor.py +41 -0
  7. radio_model.py +1 -7
  8. vitdet.py +173 -0
adaptor_generic.py ADDED
@@ -0,0 +1,29 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+ from argparse import Namespace
+
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+ from .adaptor_base import AdaptorBase, AdaptorInput, RadioOutput
+ from .adaptor_mlp import create_mlp_from_state
+
+
+ class GenericAdaptor(AdaptorBase):
+     def __init__(self, main_config: Namespace, adaptor_config, state):
+         super().__init__()
+
+         self.head_mlp = create_mlp_from_state(main_config.mlp_version, state, 'summary.')
+         self.feat_mlp = create_mlp_from_state(main_config.mlp_version, state, 'feature.')
+
+     def forward(self, input: AdaptorInput) -> RadioOutput:
+         summary = self.head_mlp(input.summary)
+         feat = self.feat_mlp(input.features)
+
+         return RadioOutput(summary, feat)
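
For reference, a minimal usage sketch of the adaptor above, assuming a checkpoint whose adaptor weights are stored flat with 'summary.' and 'feature.' key prefixes; the file path and mlp_version value are illustrative, not part of this commit:

    from argparse import Namespace
    import torch

    # Hypothetical checkpoint: keys look like 'summary.fc1.weight', 'feature.final.2.weight', ...
    adaptor_state = torch.load('adaptor_checkpoint.pth', map_location='cpu')

    # mlp_version must match the MLP_FACTORY variant the weights were trained with.
    main_config = Namespace(mlp_version='v2')
    adaptor = GenericAdaptor(main_config, adaptor_config={}, state=adaptor_state)

    # Calling adaptor(AdaptorInput(...)) then maps the backbone summary and patch
    # features through the two MLP heads and returns a RadioOutput(summary, features) pair.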
adaptor_mlp.py ADDED
@@ -0,0 +1,150 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+ import math
+ from typing import Dict
+
+ import torch
+ from torch import nn
+
+ from einops import rearrange
+ from timm.models.vision_transformer import Block
+
+
+ class MLP(nn.Module):
+     def __init__(self, input_size: int, hidden_size: int, output_size: int,
+                  num_inner: int = 0, device: torch.device = None, **kwargs):
+         super(MLP, self).__init__()
+         self.fc1 = nn.Linear(input_size, hidden_size, device=device)
+         self.norm = nn.LayerNorm(hidden_size, device=device)
+         self.relu = nn.ReLU()
+
+         inner = []
+         for _ in range(num_inner):
+             inner.extend([
+                 nn.Linear(hidden_size, hidden_size, device=device),
+                 nn.LayerNorm(hidden_size, device=device),
+                 nn.ReLU(),
+             ])
+         if inner:
+             self.inner = nn.Sequential(*inner)
+         else:
+             self.inner = nn.Identity()
+
+         self.fc2 = nn.Linear(hidden_size, output_size, device=device)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.fc1(x)
+         x = self.norm(x)
+         x = self.relu(x)
+         x = self.inner(x)
+         x = self.fc2(x)
+         return x
+
+
+ class MLP2(nn.Module):
+     def __init__(self, input_size: int, hidden_size: int, output_size: int,
+                  num_inner: int = 0,
+                  pre_norm: bool = False, device: torch.device = None,
+                  upsample_factor: int = 1,
+                  **kwargs):
+         super().__init__()
+
+         self.pre_norm = nn.Sequential(
+             nn.LayerNorm(input_size),
+             nn.GELU(),
+         ) if pre_norm else nn.Identity()
+
+         self.upsample_factor = upsample_factor
+         self._real_output_dim = output_size
+
+         hidden_size *= upsample_factor
+         output_size *= (upsample_factor ** 2)
+
+         self.fc1 = nn.Linear(input_size, hidden_size, device=device)
+
+         blocks = []
+         for _ in range(num_inner):
+             blocks.append(nn.Sequential(
+                 nn.LayerNorm(hidden_size, device=device),
+                 nn.GELU(),
+                 nn.Linear(hidden_size, hidden_size, device=device),
+             ))
+         self.blocks = nn.ModuleList(blocks)
+
+         self.final = nn.Sequential(
+             nn.LayerNorm(hidden_size, device=device),
+             nn.GELU(),
+             nn.Linear(hidden_size, output_size, device=device),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.pre_norm(x)
+         x = self.fc1(x)
+         for block in self.blocks:
+             x = x + block(x)
+         x = self.final(x)
+
+         if self.upsample_factor > 1:
+             h = w = int(math.sqrt(x.shape[1]))
+             x = rearrange(x, 'b (h w) (u1 u2 c) -> b (u1 h u2 w) c',
+                           h=h, w=w, u1=self.upsample_factor, u2=self.upsample_factor,
+                           c=self._real_output_dim)
+
+         return x
+
+
+ MLP_FACTORY = {
+     'v1': MLP,
+     'v2': MLP2,
+ }
+
+
+ def strip_prefix(state: Dict[str, torch.Tensor], prefix: str):
+     state = {
+         k[len(prefix):]: v
+         for k, v in state.items()
+         if k.startswith(prefix)
+     }
+     return state
+
+
+ def get_mlp_info_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = ''):
+     state = strip_prefix(state, prefix)
+
+     if version == 'v1':
+         hidden_dim, input_dim = state['fc1.weight'].shape
+         output_dim = state['fc2.weight'].shape[0]
+
+         for num_inner in range(1000):
+             k = f'inner.{num_inner}.0.weight'
+             if k not in state:
+                 break
+     elif version == 'v2':
+         hidden_dim, input_dim = state['fc1.weight'].shape
+         output_dim = state['final.2.weight'].shape[0]
+
+         for num_inner in range(1000):
+             k = f'blocks.{num_inner}.0.weight'
+             if k not in state:
+                 break
+     else:
+         raise ValueError(f'Unsupported MLP version: {version}')
+
+     return input_dim, hidden_dim, output_dim, num_inner
+
+
+ def create_mlp_from_state(version: str, state: Dict[str, torch.Tensor], prefix: str = ''):
+     state = strip_prefix(state, prefix)
+
+     input_dim, hidden_dim, output_dim, num_inner = get_mlp_info_from_state(version, state)
+
+     ret: nn.Module = MLP_FACTORY[version](input_dim, hidden_dim, output_dim, num_inner)
+
+     ret.load_state_dict(state)
+
+     return ret
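
A short round-trip sketch (not part of the commit) of how get_mlp_info_from_state recovers an MLP's shape from its weights alone; the dimensions below are made up:

    import torch

    ref = MLP2(input_size=768, hidden_size=1024, output_size=512, num_inner=2)
    state = ref.state_dict()

    # Shapes are read from 'fc1.weight', 'final.2.weight' and the 'blocks.N.*' keys.
    input_dim, hidden_dim, output_dim, num_inner = get_mlp_info_from_state('v2', state)
    # -> (768, 1024, 512, 2)

    rebuilt = create_mlp_from_state('v2', state)
    with torch.no_grad():
        x = torch.randn(4, 768)
        assert torch.allclose(ref(x), rebuilt(x))

    # The same works for prefixed keys, which is how GenericAdaptor consumes a merged state dict:
    head = create_mlp_from_state('v2', {f'summary.{k}': v for k, v in state.items()}, prefix='summary.')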
adaptor_registry.py ADDED
@@ -0,0 +1,37 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+ from argparse import Namespace
+ from typing import Dict, Any
+
+ import torch
+
+ from .adaptor_generic import GenericAdaptor, AdaptorBase
+
+ dict_t = Dict[str, Any]
+ state_t = Dict[str, torch.Tensor]
+
+
+ class AdaptorRegistry:
+     def __init__(self):
+         self._registry = {}
+
+     def register_adaptor(self, name):
+         def decorator(factory_function):
+             if name in self._registry:
+                 raise ValueError(f"Model '{name}' already registered")
+             self._registry[name] = factory_function
+             return factory_function
+         return decorator
+
+     def create_adaptor(self, name, main_config: Namespace, adaptor_config: dict_t, state: state_t) -> AdaptorBase:
+         if name not in self._registry:
+             return GenericAdaptor(main_config, adaptor_config, state)
+         return self._registry[name](main_config, adaptor_config, state)
+
+ # Creating an instance of the registry
+ adaptor_registry = AdaptorRegistry()
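
A sketch of how the registry is meant to be extended; the adaptor name below is made up, and unknown names simply fall back to GenericAdaptor:

    # Hypothetical custom adaptor registration.
    @adaptor_registry.register_adaptor("my_adaptor")
    def create_my_adaptor(main_config, adaptor_config, state):
        # Any AdaptorBase subclass works; here we just reuse the generic MLP heads.
        return GenericAdaptor(main_config, adaptor_config, state)

    # adaptor = adaptor_registry.create_adaptor("my_adaptor", main_config, adaptor_config, state)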
eradio_model.py CHANGED
@@ -8,7 +8,7 @@
  # distribution of this software and related documentation without an express
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
- # E-RADIO (FasterViTv2) model from
+ # E-RADIO model from
  # Mike Ranzinger, Greg Heinrich, Jan Kautz, and Pavlo Molchanov. "AM-RADIO: Agglomerative Model--Reduce All Domains Into One." arXiv preprint arXiv:2312.06709 (2023).
 
  # based on FasterViT, Swin Transformer, YOLOv8
@@ -638,7 +638,7 @@ class Downsample(nn.Module):
  else:
  # removed layer norm for better, in this formulation we are getting 10% better speed
  # LayerNorm for high resolution inputs will be a pain as it pools over the entire spatial dimension
- # therefore we remove it compared to the original implementation in FasterViTv1
+ # therefore we remove it compared to the original implementation in FasterViT
  self.norm = nn.Identity()
  self.reduction = Conv2d_BN(dim, dim_out, 3, 2, 1, bias=False)
 
@@ -790,9 +790,9 @@ class WindowAttention(nn.Module):
 
 
 
- class FasterViTLayer(nn.Module):
+ class ERADIOLayer(nn.Module):
  """
- fastervitlayer
+ E-RADIO Layer
  """
 
  def __init__(self,
@@ -960,7 +960,7 @@ class InterpolateLayer(nn.Module):
  class HiResNeck(nn.Module):
  """
  The block is used to output dense features from all stages
- Otherwise, by default, only the last stage features are returned with FasterViTv2
+ Otherwise, by default, only the last stage features are returned with E-RADIO
  """
  def __init__(self, dim, depths, neck_start_stage, full_features_head_dim, downsample_enabled):
 
@@ -1017,9 +1017,9 @@ class HiResNeck(nn.Module):
  full_features = full_features + feature_projection
  return full_features
 
- class FasterViT(nn.Module):
+ class ERADIO(nn.Module):
  """
- FasterViT
+ Efficient RADIO
  """
 
  def __init__(self,
@@ -1104,7 +1104,7 @@ class FasterViT(nn.Module):
  for i in range(len(depths)):
  conv = True if (i == 0 or i == 1) else False
 
- level = FasterViTLayer(dim=int(dim * 2 ** i),
+ level = ERADIOLayer(dim=int(dim * 2 ** i),
  depth=depths[i],
  num_heads=num_heads[i],
  window_size=window_size[i],
@@ -1208,9 +1208,9 @@ class FasterViT(nn.Module):
 
  def change_window_size(self, new_window_size):
  """
- FasterViT employs windowed attention, which may be sensitive to the choice of this parameter,
+ E-RADIO employs windowed attention, which may be sensitive to the choice of this parameter,
  especially in cases of uneven partitioning of the feature maps.
- FasterViT allows for the adjustment of the window size after training,
+ E-RADIO allows for the adjustment of the window size after training,
  making it adaptable to different input image resolutions.
  The recommended values for window size based on input resolution are as follows:
 
@@ -1243,9 +1243,9 @@ class FasterViT(nn.Module):
  """
  Using hand picked window size for various resolutions.
 
- FasterViT employs windowed attention, which may be sensitive to the choice of this parameter,
+ E-RADIO employs windowed attention, which may be sensitive to the choice of this parameter,
  especially in cases of uneven partitioning of the feature maps.
- FasterViT allows for the adjustment of the window size after training,
+ E-RADIO allows for the adjustment of the window size after training,
  making it adaptable to different input image resolutions.
  The recommended values for window size based on input resolution are as follows:
 
@@ -1288,271 +1288,10 @@ class FasterViT(nn.Module):
1288
 
1289
  self.change_window_size(new_window_size = new_window_size)
1290
 
1291
- # 83.44200001953125
1292
- @register_model
1293
- def fastervit2_small(pretrained=False, **kwargs): #,
1294
- model = FasterViT(depths=[3, 3, 5, 5],
1295
- num_heads=[2, 4, 8, 16],
1296
- window_size=[8, 8, [7, 7], 7],
1297
- dim=96,
1298
- in_dim=64,
1299
- mlp_ratio=4,
1300
- drop_path_rate=0.2,
1301
- sr_ratio=[1, 1, [1, 2], 1],
1302
- use_swiglu=False,
1303
- downsample_shuffle=False,
1304
- yolo_arch=True,
1305
- shuffle_down=False,
1306
- **kwargs)
1307
- if pretrained:
1308
- model.load_state_dict(torch.load(pretrained)["state_dict"])
1309
- return model
1310
-
1311
- # 82.61
1312
- @register_model
1313
- def fastervit2_tiny(pretrained=False, **kwargs): #,
1314
- model = FasterViT(depths=[1, 3, 4, 5],
1315
- num_heads=[2, 4, 8, 16],
1316
- window_size=[8, 8, [7, 7], 7],
1317
- dim=80,
1318
- in_dim=64,
1319
- mlp_ratio=4,
1320
- drop_path_rate=0.2,
1321
- sr_ratio=[1, 1, [2, 1], 1],
1322
- use_swiglu=False,
1323
- downsample_shuffle=False,
1324
- yolo_arch=True,
1325
- shuffle_down=False,
1326
- **kwargs)
1327
- if pretrained:
1328
- model.load_state_dict(torch.load(pretrained)["state_dict"])
1329
- return model
1330
-
1331
- #'top1', 84.31800001220704
1332
- @register_model
1333
- def fastervit2_base(pretrained=False, **kwargs):
1334
- model = FasterViT(depths=[3, 3, 5, 5],
1335
- num_heads=[2, 4, 8, 16],
1336
- window_size=[8, 8, [7, 7], 7],
1337
- dim=128,
1338
- in_dim=64,
1339
- mlp_ratio=4,
1340
- drop_path_rate=0.2,
1341
- sr_ratio=[1, 1, [2, 1], 1],
1342
- use_swiglu=False,
1343
- yolo_arch=True,
1344
- shuffle_down=False,
1345
- conv_base=True,
1346
- **kwargs)
1347
- if pretrained:
1348
- model.load_state_dict(torch.load(pretrained)["state_dict"])
1349
- return model
1350
-
1351
- #84.39999999267579
1352
- @register_model
1353
- def fastervit2_base_v1(pretrained=False, **kwargs):
1354
- model = FasterViT(depths=[4, 4, 5, 5],
1355
- num_heads=[2, 4, 8, 16],
1356
- window_size=[8, 8, [7, 7], 7],
1357
- dim=128,
1358
- in_dim=64,
1359
- mlp_ratio=4,
1360
- drop_path_rate=0.2,
1361
- sr_ratio=[1, 1, [2, 1], 1],
1362
- use_swiglu=False,
1363
- yolo_arch=True,
1364
- shuffle_down=False,
1365
- conv_base=True,
1366
- downsample_shuffle=False,
1367
- **kwargs)
1368
- if pretrained:
1369
- model.load_state_dict(torch.load(pretrained)["state_dict"])
1370
- return model
1371
-
1372
- @register_model
1373
- def fastervit2_base_fullres1(pretrained=False, **kwargs):
1374
- model = FasterViT(depths=[3, 3, 5, 5],
1375
- num_heads=[2, 4, 8, 16],
1376
- window_size=[8, 8, [7, 7], 7],
1377
- dim=128,
1378
- in_dim=64,
1379
- mlp_ratio=4,
1380
- drop_path_rate=0.2,
1381
- sr_ratio=[1, 1, [2, 1], 1],
1382
- use_swiglu=False,
1383
- yolo_arch=True,
1384
- shuffle_down=False,
1385
- conv_base=True,
1386
- use_neck=True,
1387
- full_features_head_dim=1024,
1388
- neck_start_stage=2,
1389
- **kwargs)
1390
- if pretrained:
1391
- model.load_state_dict(torch.load(pretrained)["state_dict"])
1392
- return model
1393
-
1394
- @register_model
1395
- def fastervit2_base_fullres2(pretrained=False, **kwargs):
1396
- model = FasterViT(depths=[3, 3, 5, 5],
1397
- num_heads=[2, 4, 8, 16],
1398
- window_size=[8, 8, [7, 7], 7],
1399
- dim=128,
1400
- in_dim=64,
1401
- mlp_ratio=4,
1402
- drop_path_rate=0.2,
1403
- sr_ratio=[1, 1, [2, 1], 1],
1404
- use_swiglu=False,
1405
- yolo_arch=True,
1406
- shuffle_down=False,
1407
- conv_base=True,
1408
- use_neck=True,
1409
- full_features_head_dim=512,
1410
- neck_start_stage=1,
1411
- **kwargs)
1412
- if pretrained:
1413
- model.load_state_dict(torch.load(pretrained)["state_dict"])
1414
- return model
1415
-
1416
- @register_model
1417
- def fastervit2_base_fullres3(pretrained=False, **kwargs):
1418
- model = FasterViT(depths=[3, 3, 5, 5],
1419
- num_heads=[2, 4, 8, 16],
1420
- window_size=[8, 8, [7, 7], 7],
1421
- dim=128,
1422
- in_dim=64,
1423
- mlp_ratio=4,
1424
- drop_path_rate=0.2,
1425
- sr_ratio=[1, 1, [2, 1], 1],
1426
- use_swiglu=False,
1427
- yolo_arch=True,
1428
- shuffle_down=False,
1429
- conv_base=True,
1430
- use_neck=True,
1431
- full_features_head_dim=256,
1432
- neck_start_stage=1,
1433
- **kwargs)
1434
- if pretrained:
1435
- model.load_state_dict(torch.load(pretrained)["state_dict"])
1436
- return model
1437
-
1438
- @register_model
1439
- def fastervit2_base_fullres4(pretrained=False, **kwargs):
1440
- model = FasterViT(depths=[3, 3, 5, 5],
1441
- num_heads=[2, 4, 8, 16],
1442
- window_size=[8, 8, [7, 7], 7],
1443
- dim=128,
1444
- in_dim=64,
1445
- mlp_ratio=4,
1446
- drop_path_rate=0.2,
1447
- sr_ratio=[1, 1, [2, 1], 1],
1448
- use_swiglu=False,
1449
- yolo_arch=True,
1450
- shuffle_down=False,
1451
- conv_base=True,
1452
- use_neck=True,
1453
- full_features_head_dim=256,
1454
- neck_start_stage=2,
1455
- **kwargs)
1456
- if pretrained:
1457
- model.load_state_dict(torch.load(pretrained)["state_dict"])
1458
- return model
1459
-
1460
- @register_model
1461
- def fastervit2_base_fullres5(pretrained=False, **kwargs):
1462
- model = FasterViT(depths=[3, 3, 5, 5],
1463
- num_heads=[2, 4, 8, 16],
1464
- window_size=[8, 8, [7, 7], 7],
1465
- dim=128,
1466
- in_dim=64,
1467
- mlp_ratio=4,
1468
- drop_path_rate=0.2,
1469
- sr_ratio=[1, 1, [2, 1], 1],
1470
- use_swiglu=False,
1471
- yolo_arch=True,
1472
- shuffle_down=False,
1473
- conv_base=True,
1474
- use_neck=True,
1475
- full_features_head_dim=512,
1476
- neck_start_stage=2,
1477
- **kwargs)
1478
- if pretrained:
1479
- model.load_state_dict(torch.load(pretrained)["state_dict"])
1480
- return model
1481
 
1482
- #84.87
1483
  @register_model
1484
- def fastervit2_large(pretrained=False, **kwargs):
1485
- model = FasterViT(depths=[3, 3, 5, 5],
1486
- num_heads=[2, 4, 8, 16],
1487
- window_size=[8, 8, [7, 7], 7],
1488
- dim=128+64,
1489
- in_dim=64,
1490
- mlp_ratio=4,
1491
- drop_path_rate=0.3,
1492
- sr_ratio=[1, 1, [2, 1], 1],
1493
- use_swiglu=False,
1494
- yolo_arch=False,
1495
- shuffle_down=False,
1496
- cpb_mlp_hidden=64,
1497
- conv_base=True,
1498
- **kwargs)
1499
- if pretrained:
1500
- model.load_state_dict(torch.load(pretrained)["state_dict"])
1501
- return model
1502
-
1503
- @register_model
1504
- def fastervit2_large_fullres(pretrained=False, **kwargs):
1505
- model = FasterViT(
1506
- depths=[3, 3, 5, 5],
1507
- num_heads=[2, 4, 8, 16],
1508
- window_size=[None, None, [7, 7], 7],
1509
- dim=192,
1510
- in_dim=64,
1511
- mlp_ratio=4,
1512
- drop_path_rate=0.0,
1513
- sr_ratio=[1, 1, [2, 1], 1],
1514
- use_swiglu=False,
1515
- yolo_arch=True,
1516
- shuffle_down=False,
1517
- conv_base=True,
1518
- use_neck=True,
1519
- full_features_head_dim=1536,
1520
- neck_start_stage=2,
1521
- **kwargs,
1522
- )
1523
- if pretrained:
1524
- model.load_state_dict(torch.load(pretrained)["state_dict"])
1525
- return model
1526
-
1527
-
1528
- @register_model
1529
- def fastervit2_large_fullres_ws8(pretrained=False, **kwargs):
1530
- model = FasterViT(
1531
- depths=[3, 3, 5, 5],
1532
- num_heads=[2, 4, 8, 16],
1533
- window_size=[None, None, [8, 8], 8],
1534
- dim=192,
1535
- in_dim=64,
1536
- mlp_ratio=4,
1537
- drop_path_rate=0.0,
1538
- sr_ratio=[1, 1, [2, 1], 1],
1539
- use_swiglu=False,
1540
- yolo_arch=True,
1541
- shuffle_down=False,
1542
- conv_base=True,
1543
- use_neck=True,
1544
- full_features_head_dim=1536,
1545
- neck_start_stage=2,
1546
- **kwargs,
1547
- )
1548
- if pretrained:
1549
- model.load_state_dict(torch.load(pretrained)["state_dict"])
1550
- return model
1551
-
1552
-
1553
- @register_model
1554
- def fastervit2_large_fullres_ws16(pretrained=False, **kwargs):
- model = FasterViT(
+ def eradio_large_fullres_ws16(pretrained=False, **kwargs):
+ model = ERADIO(
  depths=[3, 3, 5, 5],
1557
  num_heads=[2, 4, 8, 16],
1558
  window_size=[None, None, [16, 16], 16],
@@ -1575,161 +1314,9 @@ def fastervit2_large_fullres_ws16(pretrained=False, **kwargs):
1575
  return model
1576
 
1577
 
1578
- @register_model
1579
- def fastervit2_large_fullres_ws32(pretrained=False, **kwargs):
1580
- model = FasterViT(
1581
- depths=[3, 3, 5, 5],
1582
- num_heads=[2, 4, 8, 16],
1583
- window_size=[None, None, [32, 32], 32],
1584
- dim=192,
1585
- in_dim=64,
1586
- mlp_ratio=4,
1587
- drop_path_rate=0.0,
1588
- sr_ratio=[1, 1, [2, 1], 1],
1589
- use_swiglu=False,
1590
- yolo_arch=True,
1591
- shuffle_down=False,
1592
- conv_base=True,
1593
- use_neck=True,
1594
- full_features_head_dim=1536,
1595
- neck_start_stage=2,
1596
- **kwargs,
1597
- )
1598
- if pretrained:
1599
- model.load_state_dict(torch.load(pretrained)["state_dict"])
1600
- return model
1601
-
1602
- #85.23% top1
1603
- @register_model
1604
- def fastervit2_xlarge(pretrained=False, **kwargs):
1605
- model = FasterViT(depths=[3, 3, 5, 5],
1606
- num_heads=[2, 4, 8, 16],
1607
- window_size=[8, 8, [7, 7], 7],
1608
- dim=128+128+64,
1609
- in_dim=64,
1610
- mlp_ratio=4,
1611
- drop_path_rate=0.4,
1612
- sr_ratio=[1, 1, [2, 1], 1],
1613
- use_swiglu=False,
1614
- yolo_arch=False,
1615
- shuffle_down=False,
1616
- cpb_mlp_hidden=64,
1617
- **kwargs)
1618
- if pretrained:
1619
- model.load_state_dict(torch.load(pretrained)["state_dict"])
1620
- return model
1621
-
1622
- @register_model
1623
- def fastervit2_huge(pretrained=False, **kwargs):
1624
- model = FasterViT(depths=[3, 3, 5, 5],
1625
- num_heads=[2, 4, 8, 16],
1626
- window_size=[8, 8, [7, 7], 7],
1627
- dim=128+128+128+64,
1628
- in_dim=64,
1629
- mlp_ratio=4,
1630
- drop_path_rate=0.2,
1631
- sr_ratio=[1, 1, [2, 1], 1],
1632
- use_swiglu=False,
1633
- yolo_arch=True,
1634
- shuffle_down=False,
1635
- **kwargs)
1636
- if pretrained:
1637
- model.load_state_dict(torch.load(pretrained)["state_dict"])
1638
- return model
1639
-
1640
-
1641
- # 81.61
1642
- @register_model
1643
- def fastervit2_xtiny(pretrained=False, **kwargs): #,
1644
- model = FasterViT(depths=[1, 3, 4, 5],
1645
- num_heads=[2, 4, 8, 16],
1646
- window_size=[8, 8, [7, 7], 7],
1647
- dim=64,
1648
- in_dim=64,
1649
- mlp_ratio=4,
1650
- drop_path_rate=0.1,
1651
- sr_ratio=[1, 1, [2, 1], 1],
1652
- use_swiglu=False,
1653
- downsample_shuffle=False,
1654
- yolo_arch=True,
1655
- shuffle_down=False,
1656
- cpb_mlp_hidden=64,
1657
- **kwargs)
1658
- if pretrained:
1659
- model.load_state_dict(torch.load(pretrained)["state_dict"])
1660
- return model
1661
-
1662
-
1663
- # 80.19
1664
- @register_model
1665
- def fastervit2_xxtiny(pretrained=False, **kwargs): #,
1666
- model = FasterViT(depths=[1, 3, 4, 5],
1667
- num_heads=[2, 4, 8, 16],
1668
- window_size=[8, 8, [7, 7], 7],
1669
- dim=48,
1670
- in_dim=64,
1671
- mlp_ratio=4,
1672
- drop_path_rate=0.05,
1673
- sr_ratio=[1, 1, [2, 1], 1],
1674
- use_swiglu=False,
1675
- downsample_shuffle=False,
1676
- yolo_arch=True,
1677
- shuffle_down=False,
1678
- cpb_mlp_hidden=64,
1679
- **kwargs)
1680
- if pretrained:
1681
- model.load_state_dict(torch.load(pretrained)["state_dict"])
1682
- return model
1683
-
1684
- @register_model
1685
- # 77.0
1686
- def fastervit2_xxxtiny(pretrained=False, **kwargs): #,
1687
- model = FasterViT(depths=[1, 3, 4, 5],
1688
- num_heads=[2, 4, 8, 16],
1689
- window_size=[8, 8, [7, 7], 7],
1690
- dim=32,
1691
- in_dim=32,
1692
- mlp_ratio=4,
1693
- drop_path_rate=0.0,
1694
- sr_ratio=[1, 1, [2, 1], 1],
1695
- use_swiglu=False,
1696
- downsample_shuffle=False,
1697
- yolo_arch=True,
1698
- shuffle_down=False,
1699
- cpb_mlp_hidden=64,
1700
- **kwargs)
1701
- if pretrained:
1702
- model.load_state_dict(torch.load(pretrained)["state_dict"])
1703
- return model
1704
-
1705
-
1706
- @register_model
1707
- def fastervit2_xxxtiny_fullres(pretrained=False, **kwargs):
1708
- model = FasterViT(depths=[1, 3, 4, 5],
1709
- num_heads=[2, 4, 8, 16],
1710
- window_size=[8, 8, [7, 7], 7],
1711
- dim=32,
1712
- in_dim=32,
1713
- mlp_ratio=4,
1714
- drop_path_rate=0.0,
1715
- sr_ratio=[1, 1, [2, 1], 1],
1716
- use_swiglu=False,
1717
- downsample_shuffle=False,
1718
- yolo_arch=True,
1719
- shuffle_down=False,
1720
- cpb_mlp_hidden=64,
1721
- use_neck=True,
1722
- full_features_head_dim=128,
1723
- neck_start_stage=1,
1724
- conv_groups_ratio = 1,
1725
- **kwargs)
1726
- if pretrained:
1727
- model.load_state_dict(torch.load(pretrained)["state_dict"])
1728
- return model
1729
-
1730
  @register_model
1731
  def eradio_xxxtiny(pretrained=False, **kwargs): # ,
1732
- model = FasterViT(
+ model = ERADIO(
  depths=[1, 3, 4, 5],
1734
  num_heads=[2, 4, 8, 16],
1735
  window_size=[None, None, [16, 16], 16],
@@ -1753,7 +1340,7 @@ def eradio_xxxtiny(pretrained=False, **kwargs): # ,
1753
 
1754
  @register_model
1755
  def eradio_xxxtiny_8x_ws12(pretrained=False, **kwargs):
1756
- model = FasterViT(depths=[1, 3, 4, 5],
+ model = ERADIO(depths=[1, 3, 4, 5],
  num_heads=[2, 4, 8, 16],
1758
  window_size=[None, None, [12, 12], 12],
1759
  dim=32,
@@ -1778,7 +1365,7 @@ def eradio_xxxtiny_8x_ws12(pretrained=False, **kwargs):
1778
 
1779
  @register_model
1780
  def eradio_xxxtiny_8x_ws16(pretrained=False, **kwargs):
1781
- model = FasterViT(depths=[1, 3, 4, 5],
+ model = ERADIO(depths=[1, 3, 4, 5],
  num_heads=[2, 4, 8, 16],
1783
  window_size=[None, None, [16, 16], 16],
1784
  dim=32,
@@ -1802,4 +1389,4 @@ def eradio_xxxtiny_8x_ws16(pretrained=False, **kwargs):
1802
 
1803
  @register_model
1804
  def eradio(pretrained=False, **kwargs):
1805
- return fastervit2_large_fullres_ws16(pretrained=pretrained, **kwargs)
+ return eradio_large_fullres_ws16(pretrained=pretrained, **kwargs)
 
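A minimal construction sketch for the renamed entry points, not part of the commit; the window size is illustrative, and the forward call is left commented because the expected input resolution depends on the configured window size:

    import torch

    model = eradio(pretrained=False)             # wraps eradio_large_fullres_ws16()
    model.change_window_size(new_window_size=16)

    # with torch.no_grad():
    #     out = model(torch.randn(1, 3, 512, 512))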
hf_model.py CHANGED
@@ -12,22 +12,30 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  from collections import namedtuple
- from typing import Callable, Optional, List, Union
+ from typing import Optional, List, Union
 
  from timm.models import VisionTransformer
  import torch
- from torch import nn
  from transformers import PretrainedConfig, PreTrainedModel
 
 
  from .common import RESOURCE_MAP, DEFAULT_VERSION
 
- # Force import of eradio_model in order to register it.
+ # Import all required modules.
+ from .adaptor_base import AdaptorBase, RadioOutput, AdaptorInput
+ from .adaptor_generic import GenericAdaptor, AdaptorBase
+ from .adaptor_mlp import create_mlp_from_state
+ from .adaptor_registry import adaptor_registry
+ from .cls_token import ClsToken
+ from .enable_cpe_support import enable_cpe
+ from .enable_spectral_reparam import configure_spectral_reparam_from_args
  from .eradio_model import eradio
  from .radio_model import create_model_from_args
  from .radio_model import RADIOModel as RADIOModelBase, Resolution
  from .input_conditioner import get_default_conditioner, InputConditioner
-
+ from .open_clip_adaptor import OpenCLIP_RADIO
+ from .vit_patch_generator import ViTPatchGenerator
+ from .vitdet import apply_vitdet_arch, VitDetArgs
 
  # Register extra models
  from .extra_timm_models import *
@@ -75,7 +83,7 @@ class RADIOModel(PreTrainedModel):
 
  config_class = RADIOConfig
 
- def __init__(self, config: RADIOConfig):
+ def __init__(self, config):
  super().__init__(config)
 
  RADIOArgs = namedtuple("RADIOArgs", config.args.keys())
@@ -116,10 +124,6 @@ class RADIOModel(PreTrainedModel):
  adaptors=adaptors,
  )
 
- @property
- def adaptors(self) -> nn.ModuleDict:
- return self.radio_model.adaptors
-
  @property
  def model(self) -> VisionTransformer:
  return self.radio_model.model
@@ -128,38 +132,5 @@ class RADIOModel(PreTrainedModel):
  def input_conditioner(self) -> InputConditioner:
  return self.radio_model.input_conditioner
 
- @property
- def num_summary_tokens(self) -> int:
- return self.radio_model.num_summary_tokens
-
- @property
- def patch_size(self) -> int:
- return self.radio_model.patch_size
-
- @property
- def max_resolution(self) -> int:
- return self.radio_model.max_resolution
-
- @property
- def preferred_resolution(self) -> Resolution:
- return self.radio_model.preferred_resolution
-
- @property
- def window_size(self) -> int:
- return self.radio_model.window_size
-
- @property
- def min_resolution_step(self) -> int:
- return self.radio_model.min_resolution_step
-
- def make_preprocessor_external(self) -> Callable[[torch.Tensor], torch.Tensor]:
- return self.radio_model.make_preprocessor_external()
-
- def get_nearest_supported_resolution(self, height: int, width: int) -> Resolution:
- return self.radio_model.get_nearest_supported_resolution(height, width)
-
- def switch_to_deploy(self):
- return self.radio_model.switch_to_deploy()
-
  def forward(self, x: torch.Tensor):
  return self.radio_model.forward(x)
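
With the adaptor modules now imported alongside the model code, the checkpoint can be loaded through transformers remote code as usual. A hedged loading sketch; the Hub repo id and the two-way unpacking of the output are assumptions based on typical RADIO usage, not specified by this commit:

    import torch
    from transformers import AutoModel

    model = AutoModel.from_pretrained("nvidia/RADIO", trust_remote_code=True).eval()

    x = torch.rand(1, 3, 224, 224)    # pixel values in [0, 1]; the InputConditioner normalizes internally
    with torch.no_grad():
        summary, features = model(x)  # pooled summary vector and per-patch features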
open_clip_adaptor.py ADDED
@@ -0,0 +1,41 @@
+ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+ #
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
+ # and proprietary rights in and to this software, related documentation
+ # and any modifications thereto. Any use, reproduction, disclosure or
+ # distribution of this software and related documentation without an express
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
+ from argparse import Namespace
+
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+ from .adaptor_registry import adaptor_registry, dict_t, state_t
+
+ from .adaptor_generic import GenericAdaptor
+
+
+ class OpenCLIP_RADIO(GenericAdaptor):
+     def __init__(self, main_config: Namespace, adaptor_config: dict_t, state: state_t):
+         super().__init__(main_config, adaptor_config, state)
+
+         import open_clip
+
+         self.oc_model = open_clip.create_model_from_pretrained(
+             model_name=adaptor_config['model'],
+             pretrained=adaptor_config['pretrained'],
+             return_transform=False,
+         )
+         # Unload these parameters
+         self.oc_model.visual = None
+
+         self.tokenizer = open_clip.get_tokenizer(model_name=adaptor_config['model'])
+
+     def encode_text(self, text, normalize: bool = False):
+         return self.oc_model.encode_text(text, normalize=normalize)
+
+
+ @adaptor_registry.register_adaptor("open_clip")
+ def create_open_clip_adaptor(main_config: Namespace, adaptor_config: dict_t, state: state_t):
+     return OpenCLIP_RADIO(main_config, adaptor_config, state)
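
A small sketch of how the text side of this adaptor can be used for zero-shot scoring, assuming clip_adaptor was produced by adaptor_registry.create_adaptor('open_clip', ...) and image_summary came from the matching RADIO head; the helper name is made up:

    import torch
    import torch.nn.functional as F

    def zero_shot_scores(clip_adaptor, image_summary: torch.Tensor, prompts):
        tokens = clip_adaptor.tokenizer(prompts)
        with torch.no_grad():
            text_emb = clip_adaptor.encode_text(tokens, normalize=True)
            img_emb = F.normalize(image_summary, dim=-1)
        return img_emb @ text_emb.T   # cosine similarities, one row per image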
radio_model.py CHANGED
@@ -107,12 +107,6 @@ class RADIOModel(nn.Module):
  fn()
 
  def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
- res_step = self.min_resolution_step
- if res_step is not None and (x.shape[-2] % res_step != 0 or x.shape[-1] % res_step != 0):
- raise ValueError('The input resolution must be a multiple of `self.min_resolution_step`. '
- '`self.get_nearest_supported_resolution(<height>, <width>) is provided as a convenience API. '
- f'Input: {x.shape[-2:]}, Nearest: {self.get_nearest_supported_resolution(*x.shape[-2:])}')
-
  x = self.input_conditioner(x)
  y = self.model.forward_features(x)
 
@@ -133,7 +127,7 @@ class RADIOModel(nn.Module):
  all_summary = y[:, 0]
  bb_summary = all_summary
  all_feat = y[:, 1:]
- elif isinstance(self.model, eradio_model.FasterViT):
+ elif isinstance(self.model, eradio_model.ERADIO):
  _, f = y
  all_feat = f.flatten(2).transpose(1, 2)
  all_summary = all_feat.mean(dim=1)
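
Since this hunk drops the resolution guard from forward(), callers feeding arbitrary image sizes may want to snap inputs themselves. A hedged caller-side sketch; the helper name is made up:

    def snap_resolution(model, height: int, width: int):
        # Mirrors the removed check: inputs should be a multiple of min_resolution_step.
        # model.get_nearest_supported_resolution(height, width) offers the same service when available.
        step = getattr(model, 'min_resolution_step', None)
        if step:
            height, width = (height // step) * step, (width // step) * step
        return height, width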
vitdet.py ADDED
@@ -0,0 +1,173 @@
+ from collections import defaultdict
+ from contextlib import contextmanager
+ from logging import getLogger
+ import math
+ import sys
+ from typing import List, Union, Iterable
+
+ import numpy as np
+ import torch
+ from torch import nn
+
+ from timm.models import VisionTransformer
+ from einops import rearrange
+
+ DEFAULT_NUM_WINDOWED = 5
+
+
+ class VitDetArgs:
+     def __init__(self,
+                  window_size: int,
+                  num_summary_tokens: int,
+                  num_windowed: int = DEFAULT_NUM_WINDOWED,
+     ):
+         self.window_size = window_size
+         self.num_summary_tokens = num_summary_tokens
+         self.num_windowed = num_windowed
+
+
+ def apply_vitdet_arch(model: VisionTransformer, args: VitDetArgs):
+     if isinstance(model, VisionTransformer):
+         patch_embed = getattr(model, 'patch_generator', model.patch_embed)
+
+         return ViTDetHook(patch_embed, model.blocks, args)
+     else:
+         print(f'Warning: Unable to apply VitDet aug!', file=sys.stderr)
+
+
+ class ViTDetHook:
+     def __init__(self,
+                  embedder: nn.Module,
+                  blocks: nn.Sequential,
+                  args: VitDetArgs,
+     ):
+         self.blocks = blocks
+         self.num_summary_tokens = args.num_summary_tokens
+         self.window_size = args.window_size
+
+         self._input_resolution = None
+         self._num_windows = None
+         self._cls_patch = None
+         self._order_cache = dict()
+
+         embedder.register_forward_pre_hook(self._enter_model)
+
+         # This will decide if we window-fy the patches
+         # and enable vit-det for this iteration, and if so,
+         # rearrange the patches for efficient mode switching
+         blocks.register_forward_pre_hook(self._enter_blocks)
+
+         is_global = True
+         period = args.num_windowed + 1
+         for i, layer in enumerate(blocks[:-1]):
+             ctr = i % period
+             if ctr == 0:
+                 layer.register_forward_pre_hook(self._to_windows)
+                 is_global = False
+             elif ctr == args.num_windowed:
+                 layer.register_forward_pre_hook(self._to_global)
+                 is_global = True
+
+         # Always ensure the final layer is a global layer
+         if not is_global:
+             blocks[-1].register_forward_pre_hook(self._to_global)
+
+         blocks.register_forward_hook(self._exit_model)
+
+     def _enter_model(self, _, input: List[torch.Tensor]):
+         self._input_resolution = input[0].shape[-2:]
+
+     def _enter_blocks(self, _, input: List[torch.Tensor]):
+         # print(f'{get_rank()} - ViTDet Window Size: {self._window_size}', file=sys.stderr)
+
+         patches = input[0]
+         patches = self._rearrange_patches(patches)
+
+         return (patches,) + input[1:]
+
+     def _to_windows(self, _, input: List[torch.Tensor]):
+         patches = input[0]
+
+         if self.num_summary_tokens:
+             self._cls_patch = patches[:, :self.num_summary_tokens]
+             patches = patches[:, self.num_summary_tokens:]
+
+         patches = rearrange(
+             patches, 'b (p t) c -> (b p) t c',
+             p=self._num_windows, t=self.window_size ** 2,
+         )
+
+         return (patches,) + input[1:]
+
+     def _to_global(self, _, input: List[torch.Tensor]):
+         patches = input[0]
+
+         patches = rearrange(
+             patches, '(b p) t c -> b (p t) c',
+             p=self._num_windows, t=self.window_size ** 2,
+             b=patches.shape[0] // self._num_windows,
+         )
+
+         if self.num_summary_tokens:
+             patches = torch.cat([
+                 self._cls_patch,
+                 patches,
+             ], dim=1)
+
+         return (patches,) + input[1:]
+
+     def _exit_model(self, _, inputs: List[torch.Tensor], patches: torch.Tensor):
+         # Return patches to their original order
+         patch_order = self._order_cache[self._input_resolution][0]
+         patch_order = patch_order.reshape(1, -1, 1).expand_as(patches)
+
+         ret_patches = torch.empty_like(patches)
+         ret_patches = torch.scatter(
+             ret_patches,
+             dim=1,
+             index=patch_order,
+             src=patches,
+         )
+
+         return ret_patches
+
+     def _rearrange_patches(self, patches: torch.Tensor):
+         # We rearrange the patches so that we can efficiently
+         # switch between windowed and global mode by just
+         # reshaping the tensor
+
+         patch_order, self._num_windows = self._order_cache.get(self._input_resolution, (None, None))
+         if patch_order is None:
+             num_feat_patches = patches.shape[1] - self.num_summary_tokens
+             num_pixels = self._input_resolution[0] * self._input_resolution[1]
+
+             patch_size = int(round(math.sqrt(num_pixels / num_feat_patches)))
+             rows = self._input_resolution[-2] // patch_size
+             cols = self._input_resolution[-1] // patch_size
+
+             w_rows = rows // self.window_size
+             w_cols = cols // self.window_size
+
+             patch_order = torch.arange(0, num_feat_patches, device=patches.device)
+
+             patch_order = rearrange(
+                 patch_order, '(wy py wx px) -> (wy wx py px)',
+                 wy=w_rows, wx=w_cols,
+                 py=self.window_size, px=self.window_size,
+             )
+
+             if self.num_summary_tokens:
+                 patch_order = torch.cat([
+                     torch.arange(self.num_summary_tokens, dtype=patch_order.dtype, device=patch_order.device),
+                     patch_order + self.num_summary_tokens,
+                 ])
+
+             self._num_windows = w_rows * w_cols
+             self._order_cache[self._input_resolution] = (
+                 patch_order,
+                 self._num_windows,
+             )
+
+         patch_order = patch_order.reshape(1, -1, 1).expand_as(patches)
+         patches = torch.gather(patches, dim=1, index=patch_order)
+         return patches
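
A usage sketch for the hook above on a plain timm ViT; the model name and sizes are illustrative, not part of the commit:

    import timm
    import torch

    vit = timm.create_model('vit_base_patch16_224', pretrained=False)

    # 14x14 patch grid -> 2x2 windows of 7x7 patches; one CLS token as the summary token.
    hook = apply_vitdet_arch(vit, VitDetArgs(window_size=7, num_summary_tokens=1))

    with torch.no_grad():
        _ = vit(torch.randn(1, 3, 224, 224))   # blocks now alternate windowed and global attention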