Spaces:
Runtime error
Runtime error
howard-hou
commited on
Commit
•
7bcd65d
1
Parent(s):
26f043f
Update modeling_vision.py
Browse files- modeling_vision.py +6 -6
modeling_vision.py
CHANGED
@@ -28,14 +28,14 @@ class VisionEncoder(nn.Module):
|
|
28 |
return self.proj(image_features)
|
29 |
|
30 |
def grid_pooling(self, image_features):
|
|
|
|
|
31 |
if self.args.grid_size == -1: # no grid pooling
|
32 |
-
return image_features
|
33 |
if self.args.grid_size == 0: # take cls token
|
34 |
-
return
|
35 |
if self.args.grid_size == 1: # global avg pooling
|
36 |
-
return image_features.mean(dim=1, keepdim=True)
|
37 |
-
cls_features = image_features[:, 0:1, :]
|
38 |
-
image_features = image_features[:, 1:, :] #drop cls token
|
39 |
B, L, D = image_features.shape
|
40 |
H_or_W = int(L**0.5)
|
41 |
image_features = image_features.view(B, H_or_W, H_or_W, D)
|
@@ -45,4 +45,4 @@ class VisionEncoder(nn.Module):
|
|
45 |
kernel_size=grid_stride,
|
46 |
stride=grid_stride)
|
47 |
image_features = image_features.permute(0, 2, 3, 1).view(B, -1, D)
|
48 |
-
return torch.cat((
|
|
|
28 |
return self.proj(image_features)
|
29 |
|
30 |
def grid_pooling(self, image_features):
|
31 |
+
cls_features = image_features[:, 0:1, :]
|
32 |
+
image_features = image_features[:, 1:, :] #drop cls token
|
33 |
if self.args.grid_size == -1: # no grid pooling
|
34 |
+
return torch.cat((image_features, cls_features), dim=1)
|
35 |
if self.args.grid_size == 0: # take cls token
|
36 |
+
return cls_features
|
37 |
if self.args.grid_size == 1: # global avg pooling
|
38 |
+
return torch.cat((image_features.mean(dim=1, keepdim=True), cls_features), dim=1)
|
|
|
|
|
39 |
B, L, D = image_features.shape
|
40 |
H_or_W = int(L**0.5)
|
41 |
image_features = image_features.view(B, H_or_W, H_or_W, D)
|
|
|
45 |
kernel_size=grid_stride,
|
46 |
stride=grid_stride)
|
47 |
image_features = image_features.permute(0, 2, 3, 1).view(B, -1, D)
|
48 |
+
return torch.cat((image_features, cls_features), dim=1)
|