YouLiXiya committed on
Commit 7dbe662
1 Parent(s): 27d04f5

Upload 22 files

sam_extension/distillation_models/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .dino import DINO
+ from .sam import SAMEncoderViT, DINOSAMViT
+ from .fastertinyvit import FasterTinyViT
+ # from .flashvision_transformer import FlashVisionTransformer
sam_extension/distillation_models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (322 Bytes).
 
sam_extension/distillation_models/__pycache__/dino.cpython-38.pyc ADDED
Binary file (4.72 kB).
 
sam_extension/distillation_models/__pycache__/fastertinyvit.cpython-38.pyc ADDED
Binary file (6.26 kB).
 
sam_extension/distillation_models/__pycache__/fastervit.cpython-38.pyc ADDED
Binary file (18 kB).
 
sam_extension/distillation_models/__pycache__/sam.cpython-38.pyc ADDED
Binary file (10.7 kB).
 
sam_extension/distillation_models/dino.py ADDED
@@ -0,0 +1,122 @@
+ import PIL
+ from PIL.Image import Image
+ from typing import Union
+
+ from sklearn.decomposition import PCA
+
+ import torch
+ from torch import nn
+ from torchvision import transforms as tfs
+
+
+ MEAN = [0.485, 0.456, 0.406]
+ STD = [0.229, 0.224, 0.225]
+ DINO_MODEL_HUB = 'facebookresearch/dino:main'
+ DINO_MODEL_TYPE = ['dino_vits16',
+                    'dino_vits8',
+                    'dino_vitb16',
+                    'dino_vitb8',
+                    'dino_xcit_small_12_p16',
+                    'dino_xcit_small_12_p8',
+                    'dino_xcit_medium_24_p16',
+                    'dino_xcit_medium_24_p8',
+                    'dino_resnet50']
+
+ DINOV2_MODEL_HUB = 'facebookresearch/dinov2:main'
+ DINOV2_MODEL_TYPE = ['dinov2_vits14',
+                      'dinov2_vitb14',
+                      'dinov2_vitl14',
+                      'dinov2_vitg14']
+
+ class DINO(nn.Module):
+     def __init__(self, model_type, device='cuda', img_size=224, pca_dim=None):
+         super(DINO, self).__init__()
+         assert model_type in DINO_MODEL_TYPE, 'Given DINO model type must be in DINO_MODEL_TYPE!'
+         self.model = torch.hub.load(DINO_MODEL_HUB, model_type).to(device)
+         self.device = device
+         for param in self.model.parameters():
+             param.requires_grad = False
+         self.model.eval()
+         self.img_size = img_size
+         self.pca_dim = pca_dim
+         self.pca = self.set_pca(pca_dim) if pca_dim else None
+     def set_pca(self, dim=64):
+         return PCA(n_components=dim)
+     @torch.no_grad()
+     def extract_features(
+         self, img: Union[Image, torch.Tensor], transform=True, size=None
+     ):
+         if transform and isinstance(img, Image):
+             img = self.transform(img, self.img_size).unsqueeze(0)  # Nx3xHxW
+         with torch.no_grad():
+             out = self.model.get_intermediate_layers(img.to(self.device), n=1)[0]
+             out = out[:, 1:, :]  # we discard the [CLS] token
+             h, w = int(img.shape[2] / self.model.patch_embed.patch_size), int(
+                 img.shape[3] / self.model.patch_embed.patch_size
+             )
+             dim = out.shape[-1]
+             out = out.reshape(-1, h, w, dim)
+             dtype = out.dtype
+             if size is not None:
+                 out = torch.nn.functional.interpolate(out.permute(0, 3, 1, 2), size=size, mode='bilinear').permute(0, 2, 3, 1)
+             if self.pca:
+                 B, H, W, C = out.shape
+                 out = out.view(-1, C).cpu().numpy()
+                 out = self.pca.fit_transform(out)
+                 out = torch.tensor(out.reshape(B, H, W, self.pca_dim), dtype=dtype).to(self.device)
+         return out
+     def forward(self, img: Union[Image, torch.Tensor], transform=True, size=None):
+         return self.extract_features(img, transform, size)
+     @staticmethod
+     def transform(img, image_size):
+         transforms = tfs.Compose(
+             [tfs.Resize((image_size, image_size)), tfs.ToTensor(), tfs.Normalize(MEAN, STD)]
+         )
+         img = transforms(img)
+         return img
+
+ class DINOV2(nn.Module):
+     def __init__(self, model_type, device='cuda', img_size=224, pca_dim=None):
+         super(DINOV2, self).__init__()
+         assert model_type in DINOV2_MODEL_TYPE, 'Given DINOv2 model type must be in DINOV2_MODEL_TYPE!'
+         self.model = torch.hub.load(DINOV2_MODEL_HUB, model_type).to(device)
+         self.device = device
+         for param in self.model.parameters():
+             param.requires_grad = False
+         self.model.eval()
+         self.img_size = img_size
+         self.pca_dim = pca_dim
+         self.pca = self.set_pca(pca_dim) if pca_dim else None
+     def set_pca(self, dim=64):
+         return PCA(n_components=dim)
+     @torch.no_grad()
+     def extract_features(
+         self, img: Union[Image, torch.Tensor], transform=True, size=None
+     ):
+         if transform and isinstance(img, Image):
+             img = self.transform(img, self.img_size).unsqueeze(0)  # Nx3xHxW
+         with torch.no_grad():
+             out = self.model.forward_features(img.to(self.device))['x_norm_patchtokens']
+             h, w = int(img.shape[2] / self.model.patch_size), int(
+                 img.shape[3] / self.model.patch_size
+             )
+             dim = out.shape[-1]
+             out = out.reshape(-1, h, w, dim)
+             dtype = out.dtype
+             if size is not None:
+                 out = torch.nn.functional.interpolate(out.permute(0, 3, 1, 2), size=size, mode='bilinear').permute(0, 2, 3, 1)
+             if self.pca:
+                 B, H, W, C = out.shape
+                 out = out.view(-1, C).cpu().numpy()
+                 out = self.pca.fit_transform(out)
+                 out = torch.tensor(out.reshape(B, H, W, self.pca_dim), dtype=dtype).to(self.device)
+         return out
+     def forward(self, img: Union[Image, torch.Tensor], transform=True, size=None):
+         return self.extract_features(img, transform, size)
+     @staticmethod
+     def transform(img, image_size):
+         transforms = tfs.Compose(
+             [tfs.Resize((image_size, image_size)), tfs.ToTensor(), tfs.Normalize(MEAN, STD)]
+         )
+         img = transforms(img)
+         return img
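
Usage note: a minimal sketch of driving the DINO wrapper above, assuming a CUDA device, torch.hub access to facebookresearch/dino, and Pillow; the image path 'sample.jpg' is purely illustrative:

    import PIL.Image
    from sam_extension.distillation_models.dino import DINO

    dino = DINO('dino_vits8', device='cuda', img_size=224)   # frozen ViT-S/8 backbone
    img = PIL.Image.open('sample.jpg').convert('RGB')        # hypothetical image path
    feats = dino(img, transform=True, size=(64, 64))         # (1, 64, 64, 384) patch features
    print(feats.shape)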
sam_extension/distillation_models/fastertinyvit.py ADDED
@@ -0,0 +1,233 @@
+ from typing import Tuple, List, Union
+ import torch
+ from torch import nn
+ from torch.utils.checkpoint import checkpoint
+ import torch.nn.functional as F
+ from timm.models.layers import trunc_normal_
+ from sam_extension.distillation_models.fastervit import FasterViTLayer
+ from segment_anything.mobile_encoder.tiny_vit_sam import PatchEmbed, Conv2d_BN, LayerNorm2d, MBConv
+ class PatchMerging(nn.Module):
+     def __init__(self, input_resolution, dim, out_dim, activation):
+         super().__init__()
+
+         self.input_resolution = input_resolution
+         self.dim = dim
+         self.out_dim = out_dim
+         self.act = activation()
+         self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
+         stride_c = 2
+         if out_dim == 320 or out_dim == 448 or out_dim == 576:  # keep spatial size for these dims
+             stride_c = 1
+         self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
+         self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
+
+     def forward(self, x):
+         if x.ndim == 3:
+             H, W = self.input_resolution
+             B = len(x)
+             # (B, C, H, W)
+             x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
+
+         x = self.conv1(x)
+         x = self.act(x)
+
+         x = self.conv2(x)
+         x = self.act(x)
+         x = self.conv3(x)
+         return x
+
+
+ class ConvLayer(nn.Module):
+     def __init__(self, dim, input_resolution, depth,
+                  activation,
+                  drop_path=0., downsample=None, use_checkpoint=False,
+                  out_dim=None,
+                  conv_expand_ratio=4.,
+                  ):
+
+         super().__init__()
+         self.dim = dim
+         self.input_resolution = input_resolution
+         self.depth = depth
+         self.use_checkpoint = use_checkpoint
+
+         # build blocks
+         self.blocks = nn.ModuleList([
+             MBConv(dim, dim, conv_expand_ratio, activation,
+                    drop_path[i] if isinstance(drop_path, list) else drop_path,
+                    )
+             for i in range(depth)])
+
+         # patch merging layer
+         if downsample is not None:
+             self.downsample = downsample(
+                 input_resolution, dim=dim, out_dim=out_dim, activation=activation)
+         else:
+             self.downsample = None
+
+     def forward(self, x):
+         for blk in self.blocks:
+             if self.use_checkpoint:
+                 x = checkpoint(blk, x)
+             else:
+                 x = blk(x)
+         if self.downsample is not None:
+             x = self.downsample(x)
+         return x
+
+ class FasterTinyViT(nn.Module):
+     def __init__(self, img_size=224,
+                  in_chans=3,
+                  out_chans=256,
+                  embed_dims=[96, 192, 384, 768], depths=[2, 2, 6, 2],
+                  num_heads=[3, 6, 12, 24],
+                  window_sizes=[7, 7, 14, 7],
+                  mlp_ratio=4.,
+                  drop_rate=0.,
+                  drop_path_rate=0.1,
+                  use_checkpoint=False,
+                  mbconv_expand_ratio=4.0,
+                  ct_size=2,
+                  conv=False,
+                  multi_scale=False,
+                  output_shape=None,
+                  ):
+         super().__init__()
+         self.img_size = img_size
+         self.depths = depths
+         self.num_layers = len(depths)
+         self.mlp_ratio = mlp_ratio
+         self.multi_scale = multi_scale
+         self.output_shape = tuple(output_shape) if output_shape else None
+
+         activation = nn.GELU
+
+         self.patch_embed = PatchEmbed(in_chans=in_chans,
+                                       embed_dim=embed_dims[0],
+                                       resolution=img_size,
+                                       activation=activation)
+
+         patches_resolution = self.patch_embed.patches_resolution
+         self.patches_resolution = patches_resolution
+
+         # stochastic depth
+         dpr = [x.item() for x in torch.linspace(0, drop_path_rate,
+                                                 sum(depths))]  # stochastic depth decay rule
+
+         # build layers
+         self.layers = nn.ModuleList()
+         for i_layer in range(self.num_layers):
+             kwargs_0 = dict(dim=embed_dims[i_layer],
+                             input_resolution=(patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
+                                               patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer))),
+                             # input_resolution=(patches_resolution[0] // (2 ** i_layer),
+                             #                   patches_resolution[1] // (2 ** i_layer)),
+                             depth=depths[i_layer],
+                             drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+                             downsample=PatchMerging if (
+                                 i_layer < self.num_layers - 1) else None,
+                             use_checkpoint=use_checkpoint,
+                             out_dim=embed_dims[min(
+                                 i_layer + 1, len(embed_dims) - 1)],
+                             activation=activation,
+                             )
+             kwargs_1 = dict(dim=embed_dims[i_layer],
+                             out_dim=embed_dims[i_layer+1] if (
+                                 i_layer < self.num_layers - 1) else embed_dims[i_layer],
+                             input_resolution=patches_resolution[0] // (2 ** i_layer),
+                             depth=depths[i_layer],
+                             drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+                             downsample=True if (i_layer < self.num_layers - 1) else False,
+                             ct_size=ct_size,
+                             conv=conv,
+                             )
+             if i_layer == 0:
+                 layer = ConvLayer(
+                     conv_expand_ratio=mbconv_expand_ratio,
+                     **kwargs_0,
+                 )
+             else:
+                 layer = FasterViTLayer(
+                     num_heads=num_heads[i_layer],
+                     window_size=window_sizes[i_layer],
+                     mlp_ratio=self.mlp_ratio,
+                     drop=drop_rate,
+                     **kwargs_1)
+             self.layers.append(layer)
+
+         # init weights
+         self.apply(self._init_weights)
+
+         self.neck = nn.Sequential(
+             nn.Conv2d(
+                 sum(embed_dims)+embed_dims[-1] if self.multi_scale and self.output_shape else embed_dims[-1],
+                 out_chans,
+                 kernel_size=1,
+                 bias=False,
+             ),
+             LayerNorm2d(out_chans),
+             nn.Conv2d(
+                 out_chans,
+                 out_chans,
+                 kernel_size=3,
+                 padding=1,
+                 bias=False,
+             ),
+             LayerNorm2d(out_chans),
+         )
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     @torch.jit.ignore
+     def no_weight_decay_keywords(self):
+         return {'attention_biases'}
+
+     def forward_features(self, x):
+         if self.multi_scale and self.output_shape:
+             output_list = []
+             # x: (N, C, H, W)
+             x = self.patch_embed(x)
+             output_list.append(F.interpolate(x, size=self.output_shape, mode='bilinear'))
+             for layer in self.layers:
+                 x = layer(x)
+                 output_list.append(F.interpolate(x, size=self.output_shape, mode='bilinear'))
+             x = self.neck(torch.cat(output_list, dim=1))
+
+         else:
+             x = self.patch_embed(x)
+             for layer in self.layers:
+                 x = layer(x)
+             x = self.neck(x)
+         return x
+
+
+     def forward(self, x):
+         x = self.forward_features(x)
+
+         return x
+
+ if __name__ == '__main__':
+     from distillation.utils import get_parameter_number
+     x = torch.randn(1, 3, 1024, 1024).cuda()
+     fastertinyvit = FasterTinyViT(img_size=1024, in_chans=3,
+                                   embed_dims=[64, 128, 256],
+                                   depths=[1, 2, 1],
+                                   num_heads=[2, 4, 8],
+                                   window_sizes=[8, 8, 8],
+                                   mlp_ratio=4.,
+                                   drop_rate=0.,
+                                   drop_path_rate=0.0,
+                                   use_checkpoint=False,
+                                   mbconv_expand_ratio=4.0,
+                                   multi_scale=False,
+                                   output_shape='').cuda()
+     print(fastertinyvit(x).shape)
+     print(get_parameter_number(fastertinyvit))
+     # torch.save(fastertinyvit, 'fastertinyvit.pt')
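
Usage note: the __main__ block above exercises the single-scale path. The sketch below shows the multi-scale variant, where the patch embedding and every stage output are resized to output_shape and concatenated before the neck (sum(embed_dims) + embed_dims[-1] input channels). The configuration values are illustrative, not a tested setting:

    import torch
    from sam_extension.distillation_models.fastertinyvit import FasterTinyViT

    model = FasterTinyViT(img_size=1024, in_chans=3, out_chans=256,
                          embed_dims=[64, 128, 256],
                          depths=[1, 2, 1],
                          num_heads=[2, 4, 8],
                          window_sizes=[8, 8, 8],
                          multi_scale=True,
                          output_shape=(64, 64))   # neck sees 64 + 128 + 256 + 256 channels
    x = torch.randn(1, 3, 1024, 1024)
    print(model(x).shape)                          # expected: (1, 256, 64, 64)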
sam_extension/distillation_models/fastervit.py ADDED
@@ -0,0 +1,659 @@
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn as nn
4
+ from timm.models.layers import DropPath, LayerNorm2d
5
+ def window_partition(x, window_size):
6
+ B, C, H, W = x.shape
7
+ x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
8
+ windows = x.permute(0, 2, 4, 3, 5, 1).reshape(-1, window_size*window_size, C)
9
+ return windows
10
+
11
+
12
+ def window_reverse(windows, window_size, H, W, B):
13
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
14
+ x = x.permute(0, 5, 1, 3, 2, 4).reshape(B, windows.shape[2], H, W)
15
+ return x
16
+
17
+
18
+ def ct_dewindow(ct, W, H, window_size):
19
+ bs = ct.shape[0]
20
+ N=ct.shape[2]
21
+ ct2 = ct.view(-1, W//window_size, H//window_size, window_size, window_size, N).permute(0, 5, 1, 3, 2, 4)
22
+ ct2 = ct2.reshape(bs, N, W*H).transpose(1, 2)
23
+ return ct2
24
+
25
+
26
+ def ct_window(ct, W, H, window_size):
27
+ bs = ct.shape[0]
28
+ N = ct.shape[2]
29
+ ct = ct.view(bs, H // window_size, window_size, W // window_size, window_size, N)
30
+ ct = ct.permute(0, 1, 3, 2, 4, 5)
31
+ return ct
32
+
33
+ class PosEmbMLPSwinv2D(nn.Module):
34
+ def __init__(self,
35
+ window_size,
36
+ pretrained_window_size,
37
+ num_heads, seq_length,
38
+ ct_correct=False,
39
+ no_log=False):
40
+ super().__init__()
41
+ self.window_size = window_size
42
+ self.num_heads = num_heads
43
+ self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
44
+ nn.ReLU(inplace=True),
45
+ nn.Linear(512, num_heads, bias=False))
46
+ relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
47
+ relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
48
+ relative_coords_table = torch.stack(
49
+ torch.meshgrid([relative_coords_h,
50
+ relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
51
+ if pretrained_window_size[0] > 0:
52
+ relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
53
+ relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
54
+ else:
55
+ relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
56
+ relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
57
+
58
+ if not no_log:
59
+ relative_coords_table *= 8 # normalize to -8, 8
60
+ relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
61
+ torch.abs(relative_coords_table) + 1.0) / np.log2(8)
62
+
63
+ self.register_buffer("relative_coords_table", relative_coords_table)
64
+ coords_h = torch.arange(self.window_size[0])
65
+ coords_w = torch.arange(self.window_size[1])
66
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
67
+ coords_flatten = torch.flatten(coords, 1)
68
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
69
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
70
+ relative_coords[:, :, 0] += self.window_size[0] - 1
71
+ relative_coords[:, :, 1] += self.window_size[1] - 1
72
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
73
+ relative_position_index = relative_coords.sum(-1)
74
+ self.register_buffer("relative_position_index", relative_position_index)
75
+ self.grid_exists = False
76
+ self.pos_emb = None
77
+ self.deploy = False
78
+ relative_bias = torch.zeros(1, num_heads, seq_length, seq_length)
79
+ self.seq_length = seq_length
80
+ self.register_buffer("relative_bias", relative_bias)
81
+ self.ct_correct=ct_correct
82
+
83
+ def switch_to_deploy(self):
84
+ self.deploy = True
85
+
86
+ def forward(self, input_tensor, local_window_size):
87
+ if self.deploy:
88
+ input_tensor += self.relative_bias
89
+ return input_tensor
90
+ else:
91
+ self.grid_exists = False
92
+
93
+ if not self.grid_exists:
94
+ self.grid_exists = True
95
+
96
+ relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
97
+ relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
98
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1],
99
+ -1)
100
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
101
+ relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
102
+ n_global_feature = input_tensor.shape[2] - local_window_size
103
+ if n_global_feature > 0 and self.ct_correct:
104
+
105
+ step_for_ct=self.window_size[0]/(n_global_feature**0.5+1)
106
+ seq_length = int(n_global_feature ** 0.5)
107
+ indices = []
108
+ for i in range(seq_length):
109
+ for j in range(seq_length):
110
+ ind = (i+1)*step_for_ct*self.window_size[0] + (j+1)*step_for_ct
111
+ indices.append(int(ind))
112
+
113
+ top_part = relative_position_bias[:, indices, :]
114
+ lefttop_part = relative_position_bias[:, indices, :][:, :, indices]
115
+ left_part = relative_position_bias[:, :, indices]
116
+ relative_position_bias = torch.nn.functional.pad(relative_position_bias, (n_global_feature,
117
+ 0,
118
+ n_global_feature,
119
+ 0)).contiguous()
120
+ if n_global_feature>0 and self.ct_correct:
121
+ relative_position_bias = relative_position_bias*0.0
122
+ relative_position_bias[:, :n_global_feature, :n_global_feature] = lefttop_part
123
+ relative_position_bias[:, :n_global_feature, n_global_feature:] = top_part
124
+ relative_position_bias[:, n_global_feature:, :n_global_feature] = left_part
125
+
126
+ self.pos_emb = relative_position_bias.unsqueeze(0)
127
+ self.relative_bias = self.pos_emb
128
+
129
+ input_tensor += self.pos_emb
130
+ return input_tensor
131
+
132
+
133
+ class PosEmbMLPSwinv1D(nn.Module):
134
+ def __init__(self,
135
+ dim,
136
+ rank=2,
137
+ seq_length=4,
138
+ conv=False):
139
+ super().__init__()
140
+ self.rank = rank
141
+ if not conv:
142
+ self.cpb_mlp = nn.Sequential(nn.Linear(self.rank, 512, bias=True),
143
+ nn.ReLU(),
144
+ nn.Linear(512, dim, bias=False))
145
+ else:
146
+ self.cpb_mlp = nn.Sequential(nn.Conv1d(self.rank, 512, 1,bias=True),
147
+ nn.ReLU(),
148
+ nn.Conv1d(512, dim, 1,bias=False))
149
+ self.grid_exists = False
150
+ self.pos_emb = None
151
+ self.deploy = False
152
+ relative_bias = torch.zeros(1,seq_length, dim)
153
+ self.register_buffer("relative_bias", relative_bias)
154
+ self.conv = conv
155
+
156
+ def switch_to_deploy(self):
157
+ self.deploy = True
158
+
159
+ def forward(self, input_tensor):
160
+ seq_length = input_tensor.shape[1] if not self.conv else input_tensor.shape[2]
161
+ if self.deploy:
162
+ return input_tensor + self.relative_bias
163
+ else:
164
+ self.grid_exists = False
165
+ if not self.grid_exists:
166
+ self.grid_exists = True
167
+ if self.rank == 1:
168
+ relative_coords_h = torch.arange(0, seq_length, device=input_tensor.device, dtype = input_tensor.dtype)
169
+ relative_coords_h -= seq_length//2
170
+ relative_coords_h /= (seq_length//2)
171
+ relative_coords_table = relative_coords_h
172
+ self.pos_emb = self.cpb_mlp(relative_coords_table.unsqueeze(0).unsqueeze(2))
173
+ self.relative_bias = self.pos_emb
174
+ else:
175
+ seq_length = int(seq_length**0.5)
176
+ relative_coords_h = torch.arange(0, seq_length, device=input_tensor.device, dtype = input_tensor.dtype)
177
+ relative_coords_w = torch.arange(0, seq_length, device=input_tensor.device, dtype = input_tensor.dtype)
178
+ relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w])).contiguous().unsqueeze(0)
179
+ relative_coords_table -= seq_length // 2
180
+ relative_coords_table /= (seq_length // 2)
181
+ if not self.conv:
182
+ self.pos_emb = self.cpb_mlp(relative_coords_table.flatten(2).transpose(1,2))
183
+ else:
184
+ self.pos_emb = self.cpb_mlp(relative_coords_table.flatten(2))
185
+ self.relative_bias = self.pos_emb
186
+ input_tensor = input_tensor + self.pos_emb
187
+ return input_tensor
188
+
189
+
190
+ class Mlp(nn.Module):
191
+ """
192
+ Multi-Layer Perceptron (MLP) block
193
+ """
194
+
195
+ def __init__(self,
196
+ in_features,
197
+ hidden_features=None,
198
+ out_features=None,
199
+ act_layer=nn.GELU,
200
+ drop=0.):
201
+ """
202
+ Args:
203
+ in_features: input features dimension.
204
+ hidden_features: hidden features dimension.
205
+ out_features: output features dimension.
206
+ act_layer: activation function.
207
+ drop: dropout rate.
208
+ """
209
+
210
+ super().__init__()
211
+ out_features = out_features or in_features
212
+ hidden_features = hidden_features or in_features
213
+ self.fc1 = nn.Linear(in_features, hidden_features)
214
+ self.act = act_layer()
215
+ self.fc2 = nn.Linear(hidden_features, out_features)
216
+ self.drop = nn.Dropout(drop)
217
+
218
+ def forward(self, x):
219
+ x_size = x.size()
220
+ x = x.view(-1, x_size[-1])
221
+ x = self.fc1(x)
222
+ x = self.act(x)
223
+ x = self.drop(x)
224
+ x = self.fc2(x)
225
+ x = self.drop(x)
226
+ x = x.view(x_size)
227
+ return x
228
+
229
+ class Downsample(nn.Module):
230
+ """
231
+ Down-sampling block based on: "Hatamizadeh et al.,
232
+ FasterViT: Fast Vision Transformers with Hierarchical Attention
233
+ """
234
+
235
+ def __init__(self,
236
+ dim,
237
+ out_dim,
238
+ keep_dim=False,
239
+ stride=2,
240
+ ):
241
+ """
242
+ Args:
243
+ dim: feature size dimension.
244
+ norm_layer: normalization layer.
245
+ keep_dim: bool argument for maintaining the resolution.
246
+ """
247
+
248
+ super().__init__()
249
+ if keep_dim:
250
+ out_dim = dim
251
+ self.norm = LayerNorm2d(dim)
252
+ self.reduction = nn.Sequential(
253
+ nn.Conv2d(dim, out_dim, 3, stride, 1, bias=False),
254
+ )
255
+
256
+ def forward(self, x):
257
+ x = self.norm(x)
258
+ x = self.reduction(x)
259
+ return x
260
+ class PatchEmbed(nn.Module):
261
+ """
262
+ Patch embedding block based on: "Hatamizadeh et al.,
263
+ FasterViT: Fast Vision Transformers with Hierarchical Attention
264
+ """
265
+
266
+ def __init__(self, in_chans=3, in_dim=64, dim=96):
267
+ """
268
+ Args:
269
+ in_chans: number of input channels.
270
+ dim: feature size dimension.
271
+ """
272
+ super().__init__()
273
+ self.proj = nn.Identity()
274
+ self.conv_down = nn.Sequential(
275
+ nn.Conv2d(in_chans, in_dim, 3, 2, 1, bias=False),
276
+ nn.BatchNorm2d(in_dim, eps=1e-4),
277
+ nn.ReLU(),
278
+ nn.Conv2d(in_dim, dim, 3, 2, 1, bias=False),
279
+ nn.BatchNorm2d(dim, eps=1e-4),
280
+ nn.ReLU()
281
+ )
282
+
283
+ def forward(self, x):
284
+ x = self.proj(x)
285
+ x = self.conv_down(x)
286
+ return x
287
+
288
+
289
+ class ConvBlock(nn.Module):
290
+ """
291
+ Conv block based on: "Hatamizadeh et al.,
292
+ FasterViT: Fast Vision Transformers with Hierarchical Attention
293
+ """
294
+
295
+ def __init__(self, dim,
296
+ drop_path=0.,
297
+ layer_scale=None,
298
+ kernel_size=3):
299
+ super().__init__()
300
+ """
301
+ Args:
302
+ drop_path: drop path.
303
+ layer_scale: layer scale coefficient.
304
+ kernel_size: kernel size.
305
+ """
306
+ self.conv1 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
307
+ self.norm1 = nn.BatchNorm2d(dim, eps=1e-5)
308
+ self.act1 = nn.GELU()
309
+ self.conv2 = nn.Conv2d(dim, dim, kernel_size=kernel_size, stride=1, padding=1)
310
+ self.norm2 = nn.BatchNorm2d(dim, eps=1e-5)
311
+ self.layer_scale = layer_scale
312
+ if layer_scale is not None and type(layer_scale) in [int, float]:
313
+ self.gamma = nn.Parameter(layer_scale * torch.ones(dim))
314
+ self.layer_scale = True
315
+ else:
316
+ self.layer_scale = False
317
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
318
+
319
+ def forward(self, x, global_feature=None):
320
+ input = x
321
+ x = self.conv1(x)
322
+ x = self.norm1(x)
323
+ x = self.act1(x)
324
+ x = self.conv2(x)
325
+ x = self.norm2(x)
326
+ if self.layer_scale:
327
+ x = x * self.gamma.view(1, -1, 1, 1)
328
+ x = input + self.drop_path(x)
329
+ return x, global_feature
330
+
331
+
332
+ class WindowAttention(nn.Module):
333
+ """
334
+ Window attention based on: "Hatamizadeh et al.,
335
+ FasterViT: Fast Vision Transformers with Hierarchical Attention
336
+ """
337
+ def __init__(self,
338
+ dim,
339
+ num_heads=8,
340
+ qkv_bias=False,
341
+ qk_scale=None,
342
+ attn_drop=0.,
343
+ proj_drop=0.,
344
+ resolution=0,
345
+ seq_length=0):
346
+ super().__init__()
347
+ """
348
+ Args:
349
+ dim: feature size dimension.
350
+ num_heads: number of attention head.
351
+ qkv_bias: bool argument for query, key, value learnable bias.
352
+ qk_scale: bool argument to scaling query, key.
353
+ attn_drop: attention dropout rate.
354
+ proj_drop: output dropout rate.
355
+ resolution: feature resolution.
356
+ seq_length: sequence length.
357
+ """
358
+ self.num_heads = num_heads
359
+ head_dim = dim // num_heads
360
+ self.head_dim = dim // num_heads
361
+ self.scale = qk_scale or head_dim ** -0.5
362
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
363
+ self.attn_drop = nn.Dropout(attn_drop)
364
+ self.proj = nn.Linear(dim, dim)
365
+ self.proj_drop = nn.Dropout(proj_drop)
366
+ # attention positional bias
367
+ self.pos_emb_funct = PosEmbMLPSwinv2D(window_size=[resolution, resolution],
368
+ pretrained_window_size=[resolution, resolution],
369
+ num_heads=num_heads,
370
+ seq_length=seq_length)
371
+
372
+ self.resolution = resolution
373
+
374
+ def forward(self, x):
375
+ B, N, C = x.shape
376
+ qkv = self.qkv(x).reshape(B, -1, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
377
+ q, k, v = qkv[0], qkv[1], qkv[2]
378
+ attn = (q @ k.transpose(-2, -1)) * self.scale
379
+ attn = self.pos_emb_funct(attn, self.resolution ** 2)
380
+ attn = attn.softmax(dim=-1)
381
+ attn = self.attn_drop(attn)
382
+ x = (attn @ v).transpose(1, 2).reshape(B, -1, C)
383
+ x = self.proj(x)
384
+ x = self.proj_drop(x)
385
+ return x
386
+
387
+
388
+ class HAT(nn.Module):
389
+ """
390
+ Hierarchical attention (HAT) based on: "Hatamizadeh et al.,
391
+ FasterViT: Fast Vision Transformers with Hierarchical Attention
392
+ """
393
+ def __init__(self,
394
+ dim,
395
+ num_heads,
396
+ mlp_ratio=4.,
397
+ qkv_bias=False,
398
+ qk_scale=None,
399
+ drop=0.,
400
+ attn_drop=0.,
401
+ drop_path=0.,
402
+ act_layer=nn.GELU,
403
+ norm_layer=nn.LayerNorm,
404
+ sr_ratio=1.,
405
+ window_size=7,
406
+ last=False,
407
+ layer_scale=None,
408
+ ct_size=1,
409
+ do_propagation=False):
410
+ super().__init__()
411
+ """
412
+ Args:
413
+ dim: feature size dimension.
414
+ num_heads: number of attention head.
415
+ mlp_ratio: MLP ratio.
416
+ qkv_bias: bool argument for query, key, value learnable bias.
417
+ qk_scale: bool argument to scaling query, key.
418
+ drop: dropout rate.
419
+ attn_drop: attention dropout rate.
420
+ proj_drop: output dropout rate.
421
+ act_layer: activation function.
422
+ norm_layer: normalization layer.
423
+ sr_ratio: input to window size ratio.
424
+ window_size: window size.
425
+ last: last layer flag.
426
+ layer_scale: layer scale coefficient.
427
+ ct_size: spatial dimension of carrier token local window.
428
+ do_propagation: enable carrier token propagation.
429
+ """
430
+ # positional encoding for windowed attention tokens
431
+ self.pos_embed = PosEmbMLPSwinv1D(dim, rank=2, seq_length=window_size**2)
432
+ self.norm1 = norm_layer(dim)
433
+ # number of carrier tokens per every window
434
+ cr_tokens_per_window = ct_size**2 if sr_ratio > 1 else 0
435
+ # total number of carrier tokens
436
+ cr_tokens_total = cr_tokens_per_window*sr_ratio*sr_ratio
437
+ self.cr_window = ct_size
438
+ self.attn = WindowAttention(dim,
439
+ num_heads=num_heads,
440
+ qkv_bias=qkv_bias,
441
+ qk_scale=qk_scale,
442
+ attn_drop=attn_drop,
443
+ proj_drop=drop,
444
+ resolution=window_size,
445
+ seq_length=window_size**2 + cr_tokens_per_window)
446
+
447
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
448
+ self.norm2 = norm_layer(dim)
449
+ mlp_hidden_dim = int(dim * mlp_ratio)
450
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
451
+ self.window_size = window_size
452
+
453
+ use_layer_scale = layer_scale is not None and type(layer_scale) in [int, float]
454
+ self.gamma3 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1
455
+ self.gamma4 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1
456
+
457
+ self.sr_ratio = sr_ratio
458
+ if sr_ratio > 1:
459
+ # if do hierarchical attention, this part is for carrier tokens
460
+ self.hat_norm1 = norm_layer(dim)
461
+ self.hat_norm2 = norm_layer(dim)
462
+ self.hat_attn = WindowAttention(
463
+ dim,
464
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
465
+ attn_drop=attn_drop, proj_drop=drop, resolution=int(cr_tokens_total**0.5),
466
+ seq_length=cr_tokens_total)
467
+
468
+ self.hat_mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
469
+ self.hat_drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
470
+ self.hat_pos_embed = PosEmbMLPSwinv1D(dim, rank=2, seq_length=cr_tokens_total)
471
+ self.gamma1 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1
472
+ self.gamma2 = nn.Parameter(layer_scale * torch.ones(dim)) if use_layer_scale else 1
473
+ self.upsampler = nn.Upsample(size=window_size, mode='nearest')
474
+
475
+ # keep track for the last block to explicitly add carrier tokens to feature maps
476
+ self.last = last
477
+ self.do_propagation = do_propagation
478
+
479
+ def forward(self, x, carrier_tokens):
480
+ B, T, N = x.shape
481
+ ct = carrier_tokens
482
+ x = self.pos_embed(x)
483
+
484
+ if self.sr_ratio > 1:
485
+ # do hierarchical attention via carrier tokens
486
+ # first do attention for carrier tokens
487
+ Bg, Ng, Hg = ct.shape
488
+
489
+ # ct are located quite differently
490
+ ct = ct_dewindow(ct, self.cr_window*self.sr_ratio, self.cr_window*self.sr_ratio, self.cr_window)
491
+
492
+ # positional bias for carrier tokens
493
+ ct = self.hat_pos_embed(ct)
494
+
495
+ # attention plus mlp
496
+ ct = ct + self.hat_drop_path(self.gamma1*self.hat_attn(self.hat_norm1(ct)))
497
+ ct = ct + self.hat_drop_path(self.gamma2*self.hat_mlp(self.hat_norm2(ct)))
498
+
499
+ # ct are put back to windows
500
+ ct = ct_window(ct, self.cr_window * self.sr_ratio, self.cr_window * self.sr_ratio, self.cr_window)
501
+
502
+ ct = ct.reshape(x.shape[0], -1, N)
503
+ # concatenate carrier_tokens to the windowed tokens
504
+ x = torch.cat((ct, x), dim=1)
505
+
506
+ # window attention together with carrier tokens
507
+ x = x + self.drop_path(self.gamma3*self.attn(self.norm1(x)))
508
+ x = x + self.drop_path(self.gamma4*self.mlp(self.norm2(x)))
509
+
510
+ if self.sr_ratio > 1:
511
+ # for hierarchical attention we need to split carrier tokens and window tokens back
512
+ ctr, x = x.split([x.shape[1] - self.window_size*self.window_size, self.window_size*self.window_size], dim=1)
513
+ ct = ctr.reshape(Bg, Ng, Hg) # reshape carrier tokens.
514
+ if self.last and self.do_propagation:
515
+ # propagate carrier token information into the image
516
+ ctr_image_space = ctr.transpose(1, 2).reshape(B, N, self.cr_window, self.cr_window)
517
+ x = x + self.gamma1 * self.upsampler(ctr_image_space.to(dtype=torch.float32)).flatten(2).transpose(1, 2).to(dtype=x.dtype)
518
+ return x, ct
519
+
520
+
521
+ class TokenInitializer(nn.Module):
522
+ """
523
+ Carrier token Initializer based on: "Hatamizadeh et al.,
524
+ FasterViT: Fast Vision Transformers with Hierarchical Attention
525
+ """
526
+ def __init__(self,
527
+ dim,
528
+ input_resolution,
529
+ window_size,
530
+ ct_size=1):
531
+ """
532
+ Args:
533
+ dim: feature size dimension.
534
+ input_resolution: input image resolution.
535
+ window_size: window size.
536
+ ct_size: spatial dimension of carrier token local window
537
+ """
538
+ super().__init__()
539
+
540
+ output_size = int(ct_size * input_resolution/window_size)
541
+ stride_size = int(input_resolution/output_size)
542
+ kernel_size = input_resolution - (output_size - 1) * stride_size
543
+ self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
544
+ to_global_feature = nn.Sequential()
545
+ to_global_feature.add_module("pos", self.pos_embed)
546
+ to_global_feature.add_module("pool", nn.AvgPool2d(kernel_size=kernel_size, stride=stride_size))
547
+ self.to_global_feature = to_global_feature
548
+ self.window_size = ct_size
549
+
550
+ def forward(self, x):
551
+ x = self.to_global_feature(x)
552
+ B, C, H, W = x.shape
553
+ ct = x.view(B, C, H // self.window_size, self.window_size, W // self.window_size, self.window_size)
554
+ ct = ct.permute(0, 2, 4, 3, 5, 1).reshape(-1, H*W, C)
555
+ return ct
556
+ class FasterViTLayer(nn.Module):
557
+ """
558
+ GCViT layer based on: "Hatamizadeh et al.,
559
+ Global Context Vision Transformers <https://arxiv.org/abs/2206.09959>"
560
+ """
561
+
562
+ def __init__(self,
563
+ dim,
564
+ out_dim,
565
+ depth,
566
+ input_resolution,
567
+ num_heads,
568
+ window_size,
569
+ ct_size=1,
570
+ conv=False,
571
+ downsample=True,
572
+ mlp_ratio=4.,
573
+ qkv_bias=True,
574
+ qk_scale=None,
575
+ drop=0.,
576
+ attn_drop=0.,
577
+ drop_path=0.,
578
+ layer_scale=None,
579
+ layer_scale_conv=None,
580
+ only_local=False,
581
+ hierarchy=True,
582
+ do_propagation=False
583
+ ):
584
+ """
585
+ Args:
586
+ dim: feature size dimension.
587
+ depth: layer depth.
588
+ input_resolution: input resolution.
589
+ num_heads: number of attention head.
590
+ window_size: window size.
591
+ ct_size: spatial dimension of carrier token local window.
592
+ conv: conv_based stage flag.
593
+ downsample: downsample flag.
594
+ mlp_ratio: MLP ratio.
595
+ qkv_bias: bool argument for query, key, value learnable bias.
596
+ qk_scale: bool argument to scaling query, key.
597
+ drop: dropout rate.
598
+ attn_drop: attention dropout rate.
599
+ drop_path: drop path rate.
600
+ layer_scale: layer scale coefficient.
601
+ layer_scale_conv: conv layer scale coefficient.
602
+ only_local: local attention flag.
603
+ hierarchy: hierarchical attention flag.
604
+ do_propagation: enable carrier token propagation.
605
+ """
606
+ super().__init__()
607
+ self.conv = conv
608
+ self.transformer_block = False
609
+ if conv:
610
+ self.blocks = nn.ModuleList([
611
+ ConvBlock(dim=dim,
612
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
613
+ layer_scale=layer_scale_conv)
614
+ for i in range(depth)])
615
+ self.transformer_block = False
616
+ else:
617
+ sr_ratio = input_resolution // window_size if not only_local else 1
618
+ self.blocks = nn.ModuleList([
619
+ HAT(dim=dim,
620
+ num_heads=num_heads,
621
+ mlp_ratio=mlp_ratio,
622
+ qkv_bias=qkv_bias,
623
+ qk_scale=qk_scale,
624
+ drop=drop,
625
+ attn_drop=attn_drop,
626
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
627
+ sr_ratio=sr_ratio,
628
+ window_size=window_size,
629
+ last=(i == depth-1),
630
+ layer_scale=layer_scale,
631
+ ct_size=ct_size,
632
+ do_propagation=do_propagation,
633
+ )
634
+ for i in range(depth)])
635
+ self.transformer_block = True
636
+ self.downsample = Downsample(dim=dim, out_dim=out_dim, stride=1) if not downsample else Downsample(dim=dim, out_dim=out_dim, stride=2)
637
+ if len(self.blocks) and not only_local and input_resolution // window_size > 1 and hierarchy and not self.conv:
638
+ self.global_tokenizer = TokenInitializer(dim,
639
+ input_resolution,
640
+ window_size,
641
+ ct_size=ct_size)
642
+ self.do_gt = True
643
+ else:
644
+ self.do_gt = False
645
+
646
+ self.window_size = window_size
647
+
648
+ def forward(self, x):
649
+ ct = self.global_tokenizer(x) if self.do_gt else None
650
+ B, C, H, W = x.shape
651
+ if self.transformer_block:
652
+ x = window_partition(x, self.window_size)
653
+ for bn, blk in enumerate(self.blocks):
654
+ x, ct = blk(x, ct)
655
+ if self.transformer_block:
656
+ x = window_reverse(x, self.window_size, H, W, B)
657
+ if self.downsample is None:
658
+ return x
659
+ return self.downsample(x)
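
Usage note: a quick sanity check on the window_partition / window_reverse helpers defined at the top of this file; the two ops should be exact inverses whenever H and W are divisible by the window size (shapes below are arbitrary):

    import torch
    from sam_extension.distillation_models.fastervit import window_partition, window_reverse

    B, C, H, W, ws = 2, 96, 56, 56, 7
    x = torch.randn(B, C, H, W)
    windows = window_partition(x, ws)           # (B * H//ws * W//ws, ws*ws, C)
    y = window_reverse(windows, ws, H, W, B)    # back to (B, C, H, W)
    print(torch.allclose(x, y))                 # True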
sam_extension/distillation_models/sam.py ADDED
@@ -0,0 +1,369 @@
1
+ import os
2
+ import functools
3
+
4
+ import torch
5
+ from torch import nn
6
+ import torch.nn.functional as F
7
+
8
+ from huggingface_hub import hf_hub_download
9
+
10
+ from typing import Optional, List, Union, Tuple, Type
11
+
12
+ from segment_anything import build_sam
13
+ from segment_anything.mobile_encoder.tiny_vit_sam import TinyViT
14
+ from segment_anything.modeling import PromptEncoder, MaskDecoder, TwoWayTransformer
15
+ from segment_anything.modeling.image_encoder import ImageEncoderViT, LayerNorm2d, PatchEmbed, Block, Attention
16
+ from segment_anything.mobile_encoder.setup_mobile_sam import load_mobile_sam
17
+ from segment_anything.modeling.sam import Sam
18
+
19
+ from sam_extension.distillation_models.fastertinyvit import FasterTinyViT
20
+ from sam_extension.distillation_models.dino import DINO
21
+ # from sam_extension.distillation_models.flashvision_transformer import FlashVisionTransformer
22
+
23
+ SAM_REPO_ID = 'YouLiXiya/YL-SAM'
24
+ hf_sam_download = functools.partial(hf_hub_download, repo_id=SAM_REPO_ID, local_dir_use_symlinks=True)
25
+
26
+
27
+ class SAMImageEncoder(nn.Module):
28
+ def __init__(self,
29
+ sam_checkpoint_path,
30
+ device='cuda'):
31
+ super(SAMImageEncoder, self).__init__()
32
+ sam = build_sam(sam_checkpoint_path).to(device)
33
+ self.image_encoder = sam.image_encoder
34
+ del sam
35
+ torch.cuda.empty_cache()
36
+ def forward(self, x):
37
+ return self.image_encoder(x)
38
+
39
+
40
+
41
+ class MobileSAMImageEncoder(nn.Module):
42
+ def __init__(self,
43
+ sam_checkpoint_path,
44
+ device='cuda'):
45
+ super(MobileSAMImageEncoder, self).__init__()
46
+ sam = load_mobile_sam(sam_checkpoint_path, device)
47
+ self.image_encoder = sam.image_encoder
48
+ del sam
49
+ torch.cuda.empty_cache()
50
+ def forward(self, x):
51
+ return self.image_encoder(x)
52
+
53
+ class SAMEncoderViT(nn.Module):
54
+ def __init__(
55
+ self,
56
+ img_size: int = 1024,
57
+ patch_size: int = 16,
58
+ in_chans: int = 3,
59
+ embed_dim: int = 768,
60
+ depth: int = 12,
61
+ num_heads: int = 12,
62
+ mlp_ratio: float = 4.0,
63
+ out_chans: int = 256,
64
+ qkv_bias: bool = True,
65
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
66
+ act_layer: Type[nn.Module] = nn.GELU,
67
+ use_abs_pos: bool = True,
68
+ use_rel_pos: bool = False,
69
+ rel_pos_zero_init: bool = True,
70
+ window_size: int = 0,
71
+ global_attn_indexes: Tuple[int, ...] = (),
72
+ multi_scale: bool = False,
73
+ output_shape: Union[Tuple, List] = None
74
+ ) -> None:
75
+ """
76
+ Args:
77
+ img_size (int): Input image size.
78
+ patch_size (int): Patch size.
79
+ in_chans (int): Number of input image channels.
80
+ embed_dim (int): Patch embedding dimension.
81
+ depth (int): Depth of ViT.
82
+ num_heads (int): Number of attention heads in each ViT block.
83
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
84
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
85
+ norm_layer (nn.Module): Normalization layer.
86
+ act_layer (nn.Module): Activation layer.
87
+ use_abs_pos (bool): If True, use absolute positional embeddings.
88
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
89
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
90
+ window_size (int): Window size for window attention blocks.
91
+ global_attn_indexes (list): Indexes for blocks using global attention.
92
+ """
93
+ super().__init__()
94
+ self.img_size = img_size
95
+ self.multi_scale = multi_scale
96
+ self.output_shape = tuple(output_shape) if output_shape else None
97
+
98
+
99
+ self.patch_embed = PatchEmbed(
100
+ kernel_size=(patch_size, patch_size),
101
+ stride=(patch_size, patch_size),
102
+ in_chans=in_chans,
103
+ embed_dim=embed_dim,
104
+ )
105
+
106
+ self.pos_embed: Optional[nn.Parameter] = None
107
+ if use_abs_pos:
108
+ # Initialize absolute positional embedding with pretrain image size.
109
+ self.pos_embed = nn.Parameter(
110
+ torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
111
+ )
112
+
113
+ self.blocks = nn.ModuleList()
114
+ for i in range(depth):
115
+ block = Block(
116
+ dim=embed_dim,
117
+ num_heads=num_heads,
118
+ mlp_ratio=mlp_ratio,
119
+ qkv_bias=qkv_bias,
120
+ norm_layer=norm_layer,
121
+ act_layer=act_layer,
122
+ use_rel_pos=use_rel_pos,
123
+ rel_pos_zero_init=rel_pos_zero_init,
124
+ window_size=window_size if i not in global_attn_indexes else 0,
125
+ input_size=(img_size // patch_size, img_size // patch_size),
126
+ )
127
+ self.blocks.append(block)
128
+
129
+ self.neck = nn.Sequential(
130
+ nn.Conv2d(
131
+ embed_dim*depth if self.multi_scale and self.output_shape else embed_dim,
132
+ out_chans,
133
+ kernel_size=1,
134
+ bias=False,
135
+ ),
136
+ LayerNorm2d(out_chans),
137
+ nn.Conv2d(
138
+ out_chans,
139
+ out_chans,
140
+ kernel_size=3,
141
+ padding=1,
142
+ bias=False,
143
+ ),
144
+ LayerNorm2d(out_chans),
145
+ )
146
+
147
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
148
+ x = self.patch_embed(x)
149
+ if self.pos_embed is not None:
150
+ x = x + self.pos_embed
151
+
152
+ if self.multi_scale and self.output_shape:
153
+ output_list = []
154
+ for blk in self.blocks:
155
+ x = blk(x)
156
+ output_list.append(F.interpolate(x.permute(0, 3, 1, 2), size=self.output_shape, mode='bilinear'))
157
+
158
+ x = self.neck(torch.cat(output_list, dim=1))
159
+ else:
160
+ for blk in self.blocks:
161
+ x = blk(x)
162
+ x = self.neck(x.permute(0, 3, 1, 2))
163
+ return x
164
+
165
+ class SAMEncoderAdaptor(nn.Module):
166
+ def __init__(self,
167
+ img_size: int,
168
+ input_size: Optional[Tuple[int, int]],
169
+ embed_dim: int = 768,
170
+ depth: int = 12,
171
+ num_heads: int = 12,
172
+ mlp_ratio: float = 4.0,
173
+ out_chans: int = 256,
174
+ qkv_bias: bool = True,
175
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
176
+ act_layer: Type[nn.Module] = nn.GELU,
177
+ use_abs_pos: bool = True,
178
+ use_rel_pos: bool = False,
179
+ rel_pos_zero_init: bool = True,
180
+ window_size: int = 0,
181
+ global_attn_indexes: Tuple[int, ...] = (),
182
+ multi_scale: bool = False,
183
+ output_shape: Union[Tuple, List] = None):
184
+ super(SAMEncoderAdaptor, self).__init__()
185
+ self.img_size = img_size
186
+ self.multi_scale = multi_scale
187
+ self.output_shape = tuple(output_shape) if output_shape else None
188
+
189
+ self.pos_embed: Optional[nn.Parameter] = None
190
+ if use_abs_pos:
191
+ # Initialize absolute positional embedding with pretrain image size.
192
+ self.pos_embed = nn.Parameter(
193
+ torch.zeros(1, input_size[0], input_size[1], embed_dim)
194
+ )
195
+ self.blocks = nn.ModuleList()
196
+ for i in range(depth):
197
+ block = Block(
198
+ dim=embed_dim,
199
+ num_heads=num_heads,
200
+ mlp_ratio=mlp_ratio,
201
+ qkv_bias=qkv_bias,
202
+ norm_layer=norm_layer,
203
+ act_layer=act_layer,
204
+ use_rel_pos=use_rel_pos,
205
+ rel_pos_zero_init=rel_pos_zero_init,
206
+ window_size=window_size if i not in global_attn_indexes else 0,
207
+ input_size=input_size,
208
+ )
209
+ self.blocks.append(block)
210
+
211
+ self.neck = nn.Sequential(
212
+ nn.Conv2d(
213
+ embed_dim * depth if self.multi_scale and self.output_shape else embed_dim,
214
+ out_chans,
215
+ kernel_size=1,
216
+ bias=False,
217
+ ),
218
+ LayerNorm2d(out_chans),
219
+ nn.Conv2d(
220
+ out_chans,
221
+ out_chans,
222
+ kernel_size=3,
223
+ padding=1,
224
+ bias=False,
225
+ ),
226
+ LayerNorm2d(out_chans),
227
+ )
228
+
229
+ def forward(self, x: torch.Tensor, original_size: Union[Tuple, List] = None) -> torch.Tensor:
230
+ if original_size:
231
+ original_size = torch.LongTensor(original_size)
232
+ output_shape = x.shape[-2:]
233
+ if original_size.ndim == 1:
234
+ original_size = original_size[None, ...]
235
+ adaptor_inputs = []
236
+ for i in range(original_size.shape[0]):
237
+ h, w = original_size[i]
238
+ if h > w:
239
+ new_h = output_shape[0]
240
+ new_w = int(w * new_h / h)
241
+ else:
242
+ new_w = output_shape[1]
243
+ new_h = int(h * new_w / w)
244
+ encoder_output = x[0].unsqueeze(0)
245
+ encoder_output = F.interpolate(encoder_output, size=(new_h, new_w), mode='bilinear')
246
+ pad_h = output_shape[0] - new_h
247
+ pad_w = output_shape[1] - new_w
248
+ encoder_output = F.pad(encoder_output, (0, pad_w, 0, pad_h))
249
+ adaptor_inputs.append(encoder_output)
250
+ adaptor_inputs = torch.cat(adaptor_inputs, dim=0)
251
+ x = adaptor_inputs.permute(0, 2, 3, 1)
252
+ if self.pos_embed is not None:
253
+ x = x + self.pos_embed
254
+ if self.multi_scale and self.output_shape:
255
+ output_list = []
256
+ for blk in self.blocks:
257
+ x = blk(x)
258
+ output_list.append(F.interpolate(x.permute(0, 3, 1, 2), size=self.output_shape, mode='bilinear'))
259
+
260
+ x = self.neck(torch.cat(output_list, dim=1))
261
+ else:
262
+ for blk in self.blocks:
263
+ x = blk(x)
264
+ x = self.neck(x.permute(0, 3, 1, 2))
265
+ return x
266
+
267
+
268
+ class DINOSAMViT(nn.Module):
269
+ def __init__(self,
270
+ dino_model_type,
271
+ device='cuda',
272
+ pca_dim=None,
273
+ **kwargs
274
+ ):
275
+ super(DINOSAMViT, self).__init__()
276
+ self.img_size = kwargs['img_size']
277
+ if not pca_dim:
278
+ pca_dim = None
279
+ self.dino = DINO(dino_model_type, device, self.img_size, pca_dim)
280
+ self.input_size = tuple(kwargs['output_shape'])
281
+ # input_size = self.dino.model.patch_embed.img_size // self.dino.model.patch_embed.img_size
282
+ # self.input_size = (input_size, input_size)
283
+ embed_dim = pca_dim if pca_dim is not None else self.dino.model.embed_dim
284
+ kwargs.update({'input_size': self.input_size, 'embed_dim': embed_dim})
285
+ self.adaptor = SAMEncoderAdaptor(**kwargs).to(device)
286
+ def extract_dino_features(self, x, transform=False, size = None):
287
+ return self.dino.extract_features(x, transform, size)
288
+ def forward(self, x, transform=False, size = None):
289
+ dino_feature = F.normalize(self.extract_dino_features(x, transform, size), dim=3)
290
+ adaptor_input = F.interpolate(dino_feature.permute(0, 3, 1, 2), size=self.input_size, mode='bilinear').permute(0, 2, 3, 1)
291
+ return self.adaptor(adaptor_input)
292
+ def setup_model(model_config):
293
+ prompt_embed_dim = 256
294
+ image_size = 1024
295
+ vit_patch_size = 16
296
+ image_embedding_size = image_size // vit_patch_size
297
+ model = eval(model_config.pop('type'))(**model_config)
298
+ if model.__class__.__name__ == 'SAMEncoderAdaptor':
299
+ adaptor = model
300
+ image_encoder = load_sam('weights/sam/mobile_sam.pt', 'mobile_sam', 'cpu').image_encoder
301
+ else:
302
+ adaptor = None
303
+ image_encoder = model
304
+ sam = Sam(
305
+ image_encoder=image_encoder,
306
+ prompt_encoder=PromptEncoder(
307
+ embed_dim=prompt_embed_dim,
308
+ image_embedding_size=(image_embedding_size, image_embedding_size),
309
+ input_image_size=(image_size, image_size),
310
+ mask_in_chans=16,
311
+ ),
312
+ mask_decoder=MaskDecoder(
313
+ num_multimask_outputs=3,
314
+ transformer=TwoWayTransformer(
315
+ depth=2,
316
+ embedding_dim=prompt_embed_dim,
317
+ mlp_dim=2048,
318
+ num_heads=8,
319
+ ),
320
+ transformer_dim=prompt_embed_dim,
321
+ iou_head_depth=3,
322
+ iou_head_hidden_dim=256,
323
+ ),
324
+ adaptor=adaptor,
325
+ pixel_mean=[123.675, 116.28, 103.53],
326
+ pixel_std=[58.395, 57.12, 57.375],
327
+ )
328
+ return sam
329
+
330
+ def load_distillation_sam(distillation_sam_ckpt_path,
331
+ device='cuda'):
332
+ ckpt = torch.load(distillation_sam_ckpt_path)
333
+ sam = setup_model(ckpt['model_config'])
334
+ sam.load_state_dict(ckpt['model'])
335
+ return sam.to(device)
336
+
337
+ def load_sam(sam_ckpt_path, sam_version, device):
338
+ if not os.path.exists(sam_ckpt_path):
339
+ parent_dir = os.path.dirname(sam_ckpt_path)
340
+ os.makedirs(parent_dir, exist_ok=True)
341
+ hf_sam_download(filename=os.path.basename(sam_ckpt_path), local_dir=parent_dir)
342
+ if sam_version == 'sam':
343
+ sam = build_sam(sam_ckpt_path).to(device)
344
+ elif sam_version == 'mobile_sam':
345
+ sam = load_mobile_sam(sam_ckpt_path, device)
346
+ elif sam_version == 'distillation_sam':
347
+ sam = load_distillation_sam(sam_ckpt_path, device)
348
+ else:
349
+ raise ValueError('sam version error, please give sam version in [sam, mobile_sam, distillation_sam]')
350
+ return sam
351
+
352
+ if __name__ == '__main__':
353
+ from distillation.utils import get_parameter_number
354
+ vit = SAMEncoderViT(depth=3,
355
+ embed_dim=256,
356
+ img_size=512,
357
+ mlp_ratio=4,
358
+ num_heads=16,
359
+ patch_size=8,
360
+ qkv_bias=True,
361
+ use_rel_pos=True,
362
+ global_attn_indexes=[1],
363
+ window_size=16,
364
+ out_chans=256,
365
+ multi_scale=False,
366
+ output_shape='').cuda()
367
+ x = torch.randn((1, 3, 512, 512)).cuda()
368
+ print(vit(x).shape)
369
+ print(get_parameter_number(vit))
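
Usage note: a small sketch of loading a checkpoint through load_sam above. The checkpoint filename and local path are assumptions (whatever layout the YouLiXiya/YL-SAM hub repo actually uses applies), so treat this as illustrative only:

    import torch
    from sam_extension.distillation_models.sam import load_sam

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # 'weights/sam/mobile_sam.pt' is an assumed path; load_sam fetches it from the hub repo if missing
    sam = load_sam('weights/sam/mobile_sam.pt', 'mobile_sam', device)
    embedding = sam.image_encoder(torch.randn(1, 3, 1024, 1024, device=device))
    print(embedding.shape)  # SAM-style encoders typically emit (1, 256, 64, 64)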
sam_extension/pipeline/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .base import Pipeline
+ from .sam import SAMEncoderPipeline, SAMDecoderPipeline
+ from .owlvit import OwlViTVisionEncoderPipeline, OwlViTDecoderPipeline
+ from .groundingdino import GroundingDinoPipeline
sam_extension/pipeline/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (421 Bytes).
 
sam_extension/pipeline/__pycache__/base.cpython-38.pyc ADDED
Binary file (1.14 kB).
 
sam_extension/pipeline/__pycache__/groundingdino.cpython-38.pyc ADDED
Binary file (3.28 kB).
 
sam_extension/pipeline/__pycache__/owlvit.cpython-38.pyc ADDED
Binary file (10.8 kB).
 
sam_extension/pipeline/__pycache__/sam.cpython-38.pyc ADDED
Binary file (19.6 kB).
 
sam_extension/pipeline/base.py ADDED
@@ -0,0 +1,20 @@
+ import torch
+ from torch import nn
+ from typing import Union, Dict
+ from dataclasses import dataclass
+
+ @dataclass(repr=True)
+ class Output:
+     pass
+
+ class Pipeline(nn.Module):
+     def __init__(self, *args, **kwargs):
+         super(Pipeline, self).__init__()
+         self.args = args
+         self.kwargs = kwargs
+     @classmethod
+     def from_pretrained(cls, ckpt_path, device='cuda', *args, **kwargs):
+         pass
+     def forward(self, *args, **kwargs):
+         pass
+
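
Usage note: Pipeline is a thin nn.Module shell; subclasses are expected to construct themselves in from_pretrained and do the work in forward. The toy subclass below is hypothetical and only illustrates that contract:

    import torch
    from sam_extension.pipeline.base import Pipeline

    class IdentityPipeline(Pipeline):
        def __init__(self, scale, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.scale = scale

        @classmethod
        def from_pretrained(cls, ckpt_path=None, device='cuda', *args, **kwargs):
            # real subclasses would restore weights from ckpt_path and move them to device
            return cls(1.0, *args, **kwargs)

        def forward(self, x):
            return self.scale * x

    pipe = IdentityPipeline.from_pretrained(device='cpu')
    print(pipe(torch.ones(2, 2)))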
sam_extension/pipeline/groundingdino.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import functools
3
+ import PIL
4
+ from PIL.Image import Image
5
+ import numpy as np
6
+ from typing import List, Union
7
+ import supervision as sv
8
+
9
+ import torch
10
+ import torchvision
11
+
12
+ from huggingface_hub import hf_hub_download
13
+ from sam_extension.pipeline import Pipeline
14
+ from groundingdino.util.inference import Model
15
+
16
+ GROUNDING_DINO_CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
17
+ GROUNDING_DINO_CHECKPOINT_PATH = "groundingdino_swint_ogc.pth"
18
+ SAM_REPO_ID = 'YouLiXiya/YL-SAM'
19
+ LOCAL_DIR = "weights/groundingdino"
20
+ hf_sam_download = functools.partial(hf_hub_download, repo_id=SAM_REPO_ID, local_dir=LOCAL_DIR, local_dir_use_symlinks=True)
21
+ class GroundingDinoPipeline(Pipeline):
22
+ def __init__(self,
23
+ grounding_dino_config_path,
24
+ grounfing_dino_ckpt_path,
25
+ grounding_dino_model,
26
+ device,
27
+ *args,
28
+ **kwargs):
29
+ super(GroundingDinoPipeline, self).__init__(*args, **kwargs)
30
+ self.grounding_dino_config_path = grounding_dino_config_path
31
+ self.grounfing_dino_ckpt_path = grounfing_dino_ckpt_path
32
+ self.grounding_dino_model = grounding_dino_model
33
+ self.device = device
34
+
35
+
36
+ @classmethod
37
+ def from_pretrained(cls, grounding_dino_config_path, grounfing_dino_ckpt_path,device='cuda', *args, **kwargs):
38
+ if not os.path.exists(grounfing_dino_ckpt_path):
39
+ hf_sam_download(filename=os.path.basename(grounfing_dino_ckpt_path))
40
+ grounding_dino_model = Model(model_config_path=grounding_dino_config_path,
41
+ model_checkpoint_path=grounfing_dino_ckpt_path,
42
+ device=device)
43
+ return cls(grounding_dino_config_path,
44
+ grounfing_dino_ckpt_path,
45
+ grounding_dino_model,
46
+ device,
47
+ *args,
48
+ **kwargs)
49
+
50
+ def visualize_results(self,
51
+ img: Union[Image, np.ndarray],
52
+ class_list: [List],
53
+ box_threshold: float=0.25,
54
+ text_threshold: float=0.25,
55
+ nms_threshold: float=0.8,
56
+ pil: bool=True):
57
+ detections = self.forward(img, class_list, box_threshold, text_threshold)
58
+ box_annotator = sv.BoxAnnotator()
59
+ nms_idx = torchvision.ops.nms(
60
+ torch.from_numpy(detections.xyxy),
61
+ torch.from_numpy(detections.confidence),
62
+ nms_threshold
63
+ ).numpy().tolist()
64
+
65
+ detections.xyxy = detections.xyxy[nms_idx]
66
+ detections.confidence = detections.confidence[nms_idx]
67
+ detections.class_id = detections.class_id[nms_idx]
68
+ labels = [
69
+ f"{class_list[class_id]} {confidence:0.2f}"
70
+ for _, _, confidence, class_id, _
71
+ in detections]
72
+ annotated_frame = box_annotator.annotate(scene=img.copy(), detections=detections, labels=labels)
73
+ if pil:
74
+ return PIL.Image.fromarray(annotated_frame[:, :, ::-1]), detections
75
+ else:
76
+ return annotated_frame, detections
77
+
78
+
79
+ @torch.no_grad()
80
+ def forward(self,
81
+ img: Union[Image, np.ndarray],
82
+ class_list: List[str],
83
+ box_threshold: float=0.25,
84
+ text_threshold: float=0.25
85
+ )->sv.Detections:
86
+ if isinstance(img, Image):
87
+ img = np.uint8(img)[:, :, ::-1]
88
+ detections = self.grounding_dino_model.predict_with_classes(
89
+ image=img,
90
+ classes=class_list,
91
+ box_threshold=box_threshold,
92
+ text_threshold=text_threshold
93
+ )
94
+ return detections
95
+
96
+
97
+
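A short usage sketch for GroundingDinoPipeline follows; the checkpoint location, image file and class names are illustrative assumptions, while the config constant is the one defined at the top of this file.

# Hedged usage sketch; paths, image and class names are placeholders.
import cv2
from sam_extension.pipeline.groundingdino import (GroundingDinoPipeline,
                                                  GROUNDING_DINO_CONFIG_PATH)

pipeline = GroundingDinoPipeline.from_pretrained(
    GROUNDING_DINO_CONFIG_PATH,
    'weights/groundingdino/groundingdino_swint_ogc.pth',  # assumed download target under LOCAL_DIR
    device='cuda')

image_bgr = cv2.imread('example.jpg')                      # hypothetical image, BGR as cv2 loads it
annotated, detections = pipeline.visualize_results(image_bgr,
                                                   ['person', 'dog'],
                                                   box_threshold=0.25,
                                                   text_threshold=0.25,
                                                   nms_threshold=0.8,
                                                   pil=True)
annotated.save('grounding_dino_result.jpg')                # PIL image when pil=True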
sam_extension/pipeline/owlvit.py ADDED
@@ -0,0 +1,372 @@
1
+ from typing import Optional, Tuple, Union, List
2
+ import numpy as np
3
+ import PIL
4
+ from PIL.Image import Image
5
+ import supervision as sv
6
+
7
+ import torch
8
+ from torch import nn
9
+
10
+ from transformers import OwlViTProcessor, OwlViTForObjectDetection, OwlViTVisionModel
11
+ from transformers.models.owlvit.modeling_owlvit import center_to_corners_format, box_iou, generalized_box_iou, OwlViTObjectDetectionOutput
12
+
13
+ from sam_extension.pipeline.base import Pipeline, Output
14
+
15
+ class OwlViTVisionEncoderPipeline(Pipeline):
16
+
17
+ def __init__(self,
18
+ vision_model,
19
+ layer_norm,
20
+ processor,
21
+ device='cuda',
22
+ *args,
23
+ **kwargs):
24
+ super().__init__(*args, **kwargs)
25
+ self.vision_model = vision_model
26
+ self.layer_norm = layer_norm
27
+ self.processor = processor
28
+ self.device = device
29
+ torch.cuda.empty_cache()
30
+ @classmethod
31
+ def from_pretrained(cls, model_type, device='cuda', *args, **kwargs):
32
+ owlvit_for_object_detection = OwlViTForObjectDetection.from_pretrained(model_type).to(device)
33
+ processor = OwlViTProcessor.from_pretrained(model_type)
34
+ return cls(owlvit_for_object_detection.owlvit.vision_model,
35
+ owlvit_for_object_detection.layer_norm,
36
+ processor,
37
+ device,
38
+ *args,
39
+ **kwargs)
40
+ def process_image(self, image:Image):
41
+ image = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device)
42
+ return image
43
+ @torch.no_grad()
44
+ def forward(
45
+ self,
46
+ pixel_values: Union[torch.FloatTensor, Image] = None,
47
+ output_attentions: Optional[bool] = None,
48
+ output_hidden_states: Optional[bool] = None,
49
+ return_dict: Optional[bool] = None,
50
+ ) -> torch.FloatTensor:
51
+ if isinstance(pixel_values, Image):
52
+ pixel_values = self.process_image(pixel_values)
53
+ pixel_values = pixel_values.to(self.device)
54
+ vision_outputs = self.vision_model(
55
+ pixel_values=pixel_values,
56
+ output_attentions=output_attentions,
57
+ output_hidden_states=output_hidden_states,
58
+ return_dict=return_dict,
59
+ )
60
+ # Get image embeddings
61
+ last_hidden_state = vision_outputs[0]
62
+ image_embeds = self.vision_model.post_layernorm(last_hidden_state)
63
+ new_size = tuple(np.array(image_embeds.shape) - np.array((0, 1, 0)))
64
+ class_token_out = torch.broadcast_to(image_embeds[:, :1, :], new_size)
65
+
66
+ # Merge image embedding with class tokens
67
+ image_embeds = image_embeds[:, 1:, :] * class_token_out
68
+ image_embeds = self.layer_norm(image_embeds)
69
+
70
+ # Resize to [batch_size, num_patches, num_patches, hidden_size]
71
+ new_size = (
72
+ image_embeds.shape[0],
73
+ int(np.sqrt(image_embeds.shape[1])),
74
+ int(np.sqrt(image_embeds.shape[1])),
75
+ image_embeds.shape[-1],
76
+ )
77
+ image_embeds = image_embeds.reshape(new_size)
78
+ return image_embeds
79
+
80
+
81
+
82
+ class OwlViTDecoderPipeline(Pipeline):
83
+ prompt_template: str = 'a photo of a '
84
+ def __init__(self,
85
+ owlvit_text,
86
+ text_projection,
87
+ class_head,
88
+ box_head,
89
+ processor,
90
+ device='cuda',
91
+ *args,
92
+ **kwargs):
93
+ super().__init__(*args, **kwargs)
94
+
95
+ self.owlvit_text = owlvit_text
96
+ self.text_projection = text_projection
97
+ self.class_head = class_head
98
+ self.box_head = box_head
99
+
100
+ self.sigmoid = nn.Sigmoid()
101
+ self.processor = processor
102
+ self.device = device
103
+ torch.cuda.empty_cache()
104
+
105
+ @classmethod
106
+ def from_pretrained(cls, model_type, device='cuda', *args, **kwargs):
107
+ owlvit_for_object_detection = OwlViTForObjectDetection.from_pretrained(model_type).to(device)
108
+ processor = OwlViTProcessor.from_pretrained(model_type)
109
+ return cls(owlvit_for_object_detection.owlvit.text_model,
110
+ owlvit_for_object_detection.owlvit.text_projection,
111
+ owlvit_for_object_detection.class_head,
112
+ owlvit_for_object_detection.box_head,
113
+ processor,
114
+ device,
115
+ *args,
116
+ **kwargs)
117
+ def set_template(self, template: str):
118
+ self.prompt_template = template
119
+ def process_text(self, text:List, use_template:bool = True):
120
+ if use_template:
121
+ text = [[self.prompt_template+i for i in text[0]]]
122
+ inputs = self.processor(text=text, return_tensors="pt")
123
+ return inputs
124
+ def normalize_grid_corner_coordinates(self, feature_map: torch.FloatTensor):
125
+ # Computes normalized xy corner coordinates from feature_map.
126
+ if not feature_map.ndim == 4:
127
+ raise ValueError("Expected input shape is [batch_size, num_patches, num_patches, hidden_dim]")
128
+
129
+ device = feature_map.device
130
+ num_patches = feature_map.shape[1]
131
+
132
+ box_coordinates = np.stack(
133
+ np.meshgrid(np.arange(1, num_patches + 1), np.arange(1, num_patches + 1)), axis=-1
134
+ ).astype(np.float32)
135
+ box_coordinates /= np.array([num_patches, num_patches], np.float32)
136
+
137
+ # Flatten (h, w, 2) -> (h*w, 2)
138
+ box_coordinates = box_coordinates.reshape(
139
+ box_coordinates.shape[0] * box_coordinates.shape[1], box_coordinates.shape[2]
140
+ )
141
+ box_coordinates = torch.from_numpy(box_coordinates).to(device)
142
+
143
+ return box_coordinates
144
+
145
+ def compute_box_bias(self, feature_map: torch.FloatTensor) -> torch.FloatTensor:
146
+ # The box center is biased to its position on the feature grid
147
+ box_coordinates = self.normalize_grid_corner_coordinates(feature_map)
148
+ box_coordinates = torch.clip(box_coordinates, 0.0, 1.0)
149
+
150
+ # Unnormalize xy
151
+ box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4)
152
+
153
+ # The box size is biased to the patch size
154
+ box_size = torch.full_like(box_coord_bias, 1.0 / feature_map.shape[-2])
155
+ box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4)
156
+
157
+ # Compute box bias
158
+ box_bias = torch.cat([box_coord_bias, box_size_bias], dim=-1)
159
+ return box_bias
160
+
161
+ def box_predictor(
162
+ self,
163
+ image_feats: torch.FloatTensor,
164
+ feature_map: torch.FloatTensor,
165
+ ) -> torch.FloatTensor:
166
+ """
167
+ Args:
168
+ image_feats:
169
+ Features extracted from the image, returned by the `image_text_embedder` method.
170
+ feature_map:
171
+ A spatial re-arrangement of image_features, also returned by the `image_text_embedder` method.
172
+ Returns:
173
+ pred_boxes:
174
+ List of predicted boxes (cxcywh normalized to 0, 1) nested within a dictionary.
175
+ """
176
+ # Bounding box detection head [batch_size, num_boxes, 4].
177
+ pred_boxes = self.box_head(image_feats)
178
+
179
+ # Compute the location of each token on the grid and use it to compute a bias for the bbox prediction
180
+ pred_boxes += self.compute_box_bias(feature_map)
181
+ pred_boxes = self.sigmoid(pred_boxes)
182
+ return pred_boxes
183
+
184
+ def class_predictor(
185
+ self,
186
+ image_feats: torch.FloatTensor,
187
+ query_embeds: Optional[torch.FloatTensor] = None,
188
+ query_mask: Optional[torch.Tensor] = None,
189
+ ) -> Tuple[torch.FloatTensor]:
190
+ """
191
+ Args:
192
+ image_feats:
193
+ Features extracted from the `image_text_embedder`.
194
+ query_embeds:
195
+ Text query embeddings.
196
+ query_mask:
197
+ Must be provided with query_embeddings. A mask indicating which query embeddings are valid.
198
+ """
199
+ (pred_logits, image_class_embeds) = self.class_head(image_feats, query_embeds, query_mask)
200
+
201
+ return (pred_logits, image_class_embeds)
202
+
203
+ def image_text_embedder(
204
+ self,
205
+ input_ids: torch.Tensor,
206
+ image_embeds: torch.FloatTensor,
207
+ attention_mask: torch.Tensor,
208
+ output_attentions: Optional[bool] = None,
209
+ output_hidden_states: Optional[bool] = None,
210
+ ) -> Tuple[torch.FloatTensor]:
211
+
212
+ # Encode text and image
213
+ text_outputs = self.owlvit_text(
214
+ input_ids=input_ids,
215
+ attention_mask=attention_mask,
216
+ output_attentions=output_attentions,
217
+ output_hidden_states=output_hidden_states,
218
+ return_dict=True,
219
+ )
220
+ text_embeds = text_outputs[1]
221
+ text_embeds = self.text_projection(text_embeds)
222
+ text_embeds = text_embeds / torch.linalg.norm(text_embeds, ord=2, dim=-1, keepdim=True)
223
+
224
+ return (text_embeds, image_embeds, text_outputs)
225
+
226
+ def embed_image_query(
227
+ self, query_image_features: torch.FloatTensor, query_feature_map: torch.FloatTensor
228
+ ) -> torch.FloatTensor:
229
+
230
+ _, class_embeds = self.class_predictor(query_image_features)
231
+ pred_boxes = self.box_predictor(query_image_features, query_feature_map)
232
+ pred_boxes_as_corners = center_to_corners_format(pred_boxes)
233
+
234
+ # Loop over query images
235
+ best_class_embeds = []
236
+ best_box_indices = []
237
+ pred_boxes_device = pred_boxes_as_corners.device
238
+
239
+ for i in range(query_image_features.shape[0]):
240
+ each_query_box = torch.tensor([[0, 0, 1, 1]], device=pred_boxes_device)
241
+ each_query_pred_boxes = pred_boxes_as_corners[i]
242
+ ious, _ = box_iou(each_query_box, each_query_pred_boxes)
243
+
244
+ # If there are no overlapping boxes, fall back to generalized IoU
245
+ if torch.all(ious[0] == 0.0):
246
+ ious = generalized_box_iou(each_query_box, each_query_pred_boxes)
247
+
248
+ # Use an adaptive threshold to include all boxes within 80% of the best IoU
249
+ iou_threshold = torch.max(ious) * 0.8
250
+
251
+ selected_inds = (ious[0] >= iou_threshold).nonzero()
252
+ if selected_inds.numel():
253
+ selected_embeddings = class_embeds[i][selected_inds[0]]
254
+ mean_embeds = torch.mean(class_embeds[i], axis=0)
255
+ mean_sim = torch.einsum("d,id->i", mean_embeds, selected_embeddings)
256
+ best_box_ind = selected_inds[torch.argmin(mean_sim)]
257
+ best_class_embeds.append(class_embeds[i][best_box_ind])
258
+ best_box_indices.append(best_box_ind)
259
+
260
+ if best_class_embeds:
261
+ query_embeds = torch.stack(best_class_embeds)
262
+ box_indices = torch.stack(best_box_indices)
263
+ else:
264
+ query_embeds, box_indices = None, None
265
+
266
+ return query_embeds, box_indices, pred_boxes
267
+
268
+ @torch.no_grad()
269
+ def forward(
270
+ self,
271
+ image_embeds: torch.FloatTensor,
272
+ input_ids: Optional[torch.Tensor] = None,
273
+ text: Optional[List] = None,
274
+ attention_mask: Optional[torch.Tensor] = None,
275
+ output_attentions: Optional[bool] = None,
276
+ output_hidden_states: Optional[bool] = None,
277
+ return_dict: Optional[bool] = None,
278
+ ) -> OwlViTObjectDetectionOutput:
279
+ if text is not None:
280
+ inputs = self.process_text(text)
281
+ input_ids = inputs.input_ids.to(self.device)
282
+ attention_mask = inputs.attention_mask.to(self.device)
283
+ input_ids = input_ids.to(self.device)
284
+ image_embeds = image_embeds.to(self.device)
285
+ attention_mask = attention_mask.to(self.device)
286
+ output_attentions = output_attentions if output_attentions is not None else False
287
+ output_hidden_states = (
288
+ output_hidden_states if output_hidden_states is not None else False
289
+ )
290
+ return_dict = return_dict if return_dict is not None else True
291
+
292
+ # Embed images and text queries
293
+ query_embeds, feature_map, text_outputs = self.image_text_embedder(
294
+ input_ids=input_ids,
295
+ image_embeds=image_embeds,
296
+ attention_mask=attention_mask,
297
+ output_attentions=output_attentions,
298
+ output_hidden_states=output_hidden_states,
299
+ )
300
+
301
+ # Text and vision model outputs
302
+
303
+ batch_size, num_patches, num_patches, hidden_dim = feature_map.shape
304
+ image_feats = torch.reshape(feature_map, (batch_size, num_patches * num_patches, hidden_dim))
305
+
306
+ # Reshape from [batch_size * max_text_queries, hidden_dim] -> [batch_size, max_text_queries, hidden_dim]
307
+ max_text_queries = input_ids.shape[0] // batch_size
308
+ query_embeds = query_embeds.reshape(batch_size, max_text_queries, query_embeds.shape[-1])
309
+
310
+ # If first token is 0, then this is a padded query [batch_size, num_queries].
311
+ input_ids = input_ids.reshape(batch_size, max_text_queries, input_ids.shape[-1])
312
+ query_mask = input_ids[..., 0] > 0
313
+
314
+ # Predict object classes [batch_size, num_patches, num_queries+1]
315
+ (pred_logits, class_embeds) = self.class_predictor(image_feats, query_embeds, query_mask)
316
+
317
+ # Predict object boxes
318
+ pred_boxes = self.box_predictor(image_feats, feature_map)
319
+
320
+ if not return_dict:
321
+ output = (
322
+ pred_logits,
323
+ pred_boxes,
324
+ query_embeds,
325
+ feature_map,
326
+ class_embeds,
327
+ text_outputs.to_tuple(),
328
+ None,
329
+ )
330
+ output = tuple(x for x in output if x is not None)
331
+ return output
332
+
333
+ return OwlViTObjectDetectionOutput(
334
+ image_embeds=feature_map,
335
+ text_embeds=query_embeds,
336
+ pred_boxes=pred_boxes.cpu(),
337
+ logits=pred_logits.cpu(),
338
+ class_embeds=class_embeds,
339
+ text_model_output=text_outputs,
340
+ vision_model_output=None,
341
+ )
342
+
343
+ def owlvit_visualize(self,
344
+ image: Image,
345
+ texts: List,
346
+ owlvit_objectdetection_output: OwlViTObjectDetectionOutput,
347
+ score_threshold: float = 0.1,
348
+ pil=True):
349
+ target_sizes = torch.Tensor([image.size[::-1]])
350
+ # Convert outputs (bounding boxes and class logits) to COCO API
351
+ results = self.processor.post_process(outputs=owlvit_objectdetection_output, target_sizes=target_sizes)
352
+
353
+ text = texts[0]
354
+ boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
355
+ boxes_np = []
356
+ labels_list = []
357
+ # Print detected objects and rescaled box coordinates
358
+ for box, score, label in zip(boxes, scores, labels):
359
+ box = [int(i) for i in box.tolist()]
360
+ if score >= score_threshold:
361
+ labels_list.append(f"{text[label]} {round(score.item(), 3)}")
362
+ boxes_np.append(box)
363
+ print(f"Detected {text[label]} with confidence {round(score.item(), 3)} at location {box}")
364
+ boxes_np = np.array(boxes_np)
365
+ detections = sv.Detections(xyxy=boxes_np)
366
+ image_np = np.uint8(image)[:, :, ::-1]
367
+ box_annotator = sv.BoxAnnotator()
368
+ annotated_frame = box_annotator.annotate(scene=image_np.copy(), detections=detections, labels=labels_list)
369
+ if pil:
370
+ return PIL.Image.fromarray(annotated_frame[:, :, ::-1])
371
+ else:
372
+ return annotated_frame[:, :, ::-1]
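The vision encoder and the text/box decoder are kept as separate pipelines so the image embedding can be computed once and reused across text queries. A hedged usage sketch, assuming the public 'google/owlvit-base-patch32' checkpoint and a placeholder image:

# Usage sketch; the image path and text queries are placeholders.
import PIL.Image
from sam_extension.pipeline.owlvit import (OwlViTVisionEncoderPipeline,
                                           OwlViTDecoderPipeline)

model_type = 'google/owlvit-base-patch32'
encoder = OwlViTVisionEncoderPipeline.from_pretrained(model_type, device='cuda')
decoder = OwlViTDecoderPipeline.from_pretrained(model_type, device='cuda')

image = PIL.Image.open('example.jpg').convert('RGB')       # hypothetical image
texts = [['a cat', 'a remote control']]                    # one query list per image

image_embeds = encoder(image)                              # cached vision features, [1, P, P, C]
detection_output = decoder(image_embeds, text=texts)       # OwlViTObjectDetectionOutput
annotated = decoder.owlvit_visualize(image, texts, detection_output,
                                     score_threshold=0.1, pil=True)
annotated.save('owlvit_result.jpg')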
sam_extension/pipeline/sam.py ADDED
@@ -0,0 +1,722 @@
1
+ import functools
2
+ from dataclasses import dataclass
3
+ import PIL
4
+ from PIL.Image import Image
5
+ import numpy as np
6
+ from typing import Union, Tuple, List, Optional, Callable
7
+ from sklearn.decomposition import PCA
8
+ import supervision as sv
9
+
10
+ import torch
11
+ from torch import nn
12
+ import torch.nn.functional as F
13
+ import torchvision
14
+ import torchvision.transforms as T
15
+
16
+ from segment_anything.utils.transforms import ResizeLongestSide
17
+ from segment_anything.predictor import preprocess, postprocess_masks
18
+ from segment_anything import build_sam, load_mobile_sam
19
+
20
+ from sam_extension.utils import add_prompts_tag, get_empty_detections, transform_coords
21
+ from sam_extension.pipeline.base import Pipeline, Output
22
+ from sam_extension.pipeline.groundingdino import GroundingDinoPipeline
23
+ from sam_extension.distillation_models.sam import load_distillation_sam, load_sam
24
+ from sam_extension.distillation_models import *
25
+
26
+ ORIGINAL_SAM_IMG_SIZE: int = 1024
27
+ PIXEL_MEAN = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
28
+ PIXEL_STD = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
29
+ PREPROCESS = functools.partial(preprocess, ORIGINAL_SAM_IMG_SIZE, PIXEL_MEAN, PIXEL_STD)
30
+ POSTPROCESS_MASKS = functools.partial(postprocess_masks, ORIGINAL_SAM_IMG_SIZE)
31
+
32
+ @dataclass(repr=True)
33
+ class SAMEncoderOutput(Output):
34
+ features: torch.Tensor
35
+ interm_features: List[torch.Tensor]
36
+ original_size: Tuple
37
+ input_size: Tuple
38
+
39
+ @dataclass(repr=True)
40
+ class SAMEncoderProcesImgOutput(Output):
41
+ input_image: torch.Tensor
42
+ original_size: Tuple
43
+ input_size: Tuple
44
+
45
+ @dataclass(repr=True)
46
+ class SAMDecoderPredictOutput(Output):
47
+ masks_np: np.ndarray
48
+ iou_predictions_np: np.ndarray
49
+ low_res_masks_np: np.ndarray
50
+
51
+ @dataclass(repr=True)
52
+ class SAMDecoderPredictTorchOutput(Output):
53
+ masks: torch.Tensor
54
+ iou_predictions: torch.Tensor
55
+ low_res_masks: torch.Tensor
56
+
57
+
58
+ class SAMEncoderPipeline(Pipeline):
59
+ def __init__(self,
60
+ encoder: nn.Module,
61
+ input_img_size: Tuple,
62
+ multi_output: bool,
63
+ preprocess: Callable,
64
+ transform: ResizeLongestSide,
65
+ device: str,
66
+ *args,
67
+ **kwargs):
68
+ super(SAMEncoderPipeline, self).__init__(*args, **kwargs)
69
+ self.encoder = encoder
70
+ self.input_img_size = input_img_size
71
+ self.multi_output = multi_output
72
+ self.preprocess = preprocess
73
+ self.transform = transform
74
+ self.device = device
75
+ @classmethod
76
+ def from_pretrained(cls, ckpt_path, device='cuda', *args, **kwargs):
77
+ if 'sam_version' not in kwargs.keys():
78
+ sam_version = 'sam'
79
+ else:
80
+ sam_version = kwargs['sam_version']
81
+ sam = load_sam(ckpt_path, sam_version, device)
82
+ encoder = sam.image_encoder
83
+ encoder_type = encoder.__class__.__name__
84
+ if encoder_type in ['TinyViT', 'FasterTinyViT', 'SAMEncoderViT', 'DINOSAMViT', 'FlashVisionTransformer']:
85
+ multi_output = False
86
+ if encoder_type in ['FasterTinyViT', 'SAMEncoderViT', 'DINOSAMViT', 'FlashVisionTransformer']:
87
+ input_img_size = (encoder.img_size, encoder.img_size)
88
+ if encoder_type == 'DINOSAMViT':
89
+ encoder = encoder.dino
90
+ else:
91
+ input_img_size = (ORIGINAL_SAM_IMG_SIZE, ORIGINAL_SAM_IMG_SIZE)
92
+ else:
93
+ multi_output = True
94
+ input_img_size = (ORIGINAL_SAM_IMG_SIZE, ORIGINAL_SAM_IMG_SIZE)
95
+ if sam.adaptor is None:
96
+ transform = ResizeLongestSide(ORIGINAL_SAM_IMG_SIZE)
97
+ preprocess_ = functools.partial(preprocess, ORIGINAL_SAM_IMG_SIZE, PIXEL_MEAN.to(device), PIXEL_STD.to(device))
98
+ else:
99
+ transform = T.Compose([
100
+ T.Resize(input_img_size),
101
+ T.ToTensor(),
102
+ T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
103
+ ])
104
+ preprocess_ = None
105
+ pipeline = cls(encoder=encoder,
106
+ input_img_size=input_img_size,
107
+ multi_output=multi_output,
108
+ preprocess=preprocess_,
109
+ transform=transform,
110
+ device=device)
111
+ del sam, encoder
112
+ torch.cuda.empty_cache()
113
+ return pipeline
114
+
115
+ def process_img(self, img: Union[Image, np.ndarray]) -> SAMEncoderProcesImgOutput:
116
+ if self.preprocess is not None:
117
+ if isinstance(img, Image):
118
+ img = np.uint8(img)
119
+ input_image = self.transform.apply_image(img)
120
+ input_image_torch = torch.as_tensor(input_image, device=self.device)
121
+ input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
122
+ original_size = tuple(img.shape[:2])
123
+ input_size = tuple(input_image_torch.shape[-2:])
124
+ input_image = F.interpolate(self.preprocess(input_image_torch), size=self.input_img_size, mode='bilinear')
125
+ else:
126
+ if isinstance(img, np.ndarray):
127
+ img = PIL.Image.fromarray(img)
128
+ original_size = (img.size[1], img.size[0])
129
+ if original_size[0] > original_size[1]:
130
+ input_h = 1024
131
+ input_w = int((1024 / original_size[0]) * original_size[1])
132
+ else:
133
+ input_w = 1024
134
+ input_h = int((1024 / original_size[1]) * original_size[0])
135
+ input_size = (input_h, input_w)
136
+ input_image = self.transform(img)[None, ...].to(self.device)
137
+ return SAMEncoderProcesImgOutput(input_image, original_size, input_size)
138
+ @torch.no_grad()
139
+ def get_visual_feature(self, x: Union[torch.Tensor, Image, np.ndarray]=None, **kwargs):
140
+ pca_rgb = PCA(n_components=3)
141
+ if 'sam_feature' in kwargs.keys() and 'original_size' in kwargs.keys():
142
+ sam_feature = kwargs['sam_feature']
143
+ original_size = kwargs['original_size']
144
+ else:
145
+ assert x is not None, 'please give x type Union[torch.Tensor, Image, np.ndarray] !'
146
+ sam_encoder_output = self.forward(x, **kwargs)
147
+ sam_feature = sam_encoder_output.features
148
+ original_size = sam_encoder_output.original_size
149
+ assert original_size is not None, 'please give original_size!'
150
+ sam_feature = F.interpolate(sam_feature, size=original_size, mode='bilinear').permute(0, 2, 3, 1)
151
+ b, h, w, c = sam_feature.shape
152
+ sam_feature = sam_feature.view(-1, c).cpu().numpy()
153
+ sam_feature = pca_rgb.fit_transform(sam_feature)
154
+ sam_feature = torch.Tensor(sam_feature.reshape(h, w, 3))
155
+ min_f, _ = sam_feature.min(-1)
156
+ max_f, _ = sam_feature.max(-1)
157
+ sam_feature = (sam_feature - min_f[..., None]) / (max_f[..., None] - min_f[..., None])
158
+ sam_feature = sam_feature.cpu().numpy()
159
+ sam_feature_image = PIL.Image.fromarray((sam_feature * 255).astype(np.uint8))
160
+ return sam_feature_image
161
+ def forward(self, x: Union[torch.Tensor, Image, np.ndarray], **kwargs) -> SAMEncoderOutput:
162
+ if isinstance(x, (Image, np.ndarray)):
163
+ process_img_output = self.process_img(x)
164
+ x = process_img_output.input_image
165
+ original_size = process_img_output.original_size
166
+ input_size = process_img_output.input_size
167
+ else:
168
+ original_size = kwargs.pop('original_size') if 'original_size' in kwargs.keys() else None
169
+ input_size = x.shape[-2:]
170
+ with torch.no_grad():
171
+ if self.multi_output:
172
+ features, interm_features = self.encoder(x, **kwargs)
173
+ else:
174
+ features = self.encoder(x, **kwargs)
175
+ if self.encoder.__class__.__name__ == 'DINO':
176
+ features = features.permute(0, 3, 1, 2)
177
+ interm_features = None
178
+ return SAMEncoderOutput(features, interm_features, original_size, input_size)
179
+
180
+ class SAMDecoderPipeline(Pipeline):
181
+ def __init__(self,
182
+ prompt_encoder: nn.Module,
183
+ mask_decoder: nn.Module,
184
+ adaptor: nn.Module,
185
+ mask_threshold: float,
186
+ transform: ResizeLongestSide,
187
+ postprocess_masks: Callable,
188
+ img_size: int,
189
+ device: str,
190
+ *args,
191
+ **kwargs):
192
+ super(SAMDecoderPipeline, self).__init__(*args, **kwargs)
193
+ self.prompt_encoder = prompt_encoder
194
+ self.mask_decoder = mask_decoder
195
+ self.adaptor = adaptor
196
+ self.mask_threshold = mask_threshold
197
+ self.transform = transform
198
+ self.postprocess_masks = postprocess_masks
199
+ self.img_size = img_size
200
+ self.device = device
201
+ @classmethod
202
+ def from_pretrained(cls, ckpt_path, device='cuda', *args, **kwargs):
203
+ if 'sam_version' not in kwargs.keys():
204
+ sam_version = 'sam'
205
+ else:
206
+ sam_version = kwargs['sam_version']
207
+ sam = load_sam(ckpt_path, sam_version, device)
208
+ if sam.image_encoder.__class__.__name__ == 'DINOSAMViT':
209
+ adaptor = sam.image_encoder.adaptor
210
+ elif sam.adaptor is not None:
211
+ adaptor = sam.adaptor
212
+ else:
213
+ adaptor = None
214
+ img_size = sam.image_encoder.img_size
215
+ prompt_encoder = sam.prompt_encoder
216
+ mask_decoder = sam.mask_decoder
217
+ transform = ResizeLongestSide(ORIGINAL_SAM_IMG_SIZE)
218
+ pipeline = cls(prompt_encoder=prompt_encoder,
219
+ mask_decoder=mask_decoder,
220
+ adaptor=adaptor,
221
+ mask_threshold=sam.mask_threshold,
222
+ transform=transform,
223
+ postprocess_masks=POSTPROCESS_MASKS,
224
+ img_size=img_size,
225
+ device=device)
226
+ del sam, prompt_encoder, mask_decoder
227
+ torch.cuda.empty_cache()
228
+ return pipeline
229
+ def visualize_prompt(self,
230
+ img: Union[Image, np.ndarray],
231
+ des_img: Union[Image, np.ndarray] = None,
232
+ point_labels: Union[List[int], np.ndarray] = None,
233
+ point_coords: Union[List[List[int]], np.ndarray] = None,
234
+ boxes: Union[List[List[int]], np.ndarray] = None,
235
+ pil: bool = False
236
+ ) -> Union[Image, np.ndarray]:
237
+ if des_img is not None:
238
+ if isinstance(des_img, np.ndarray):
239
+ des_shape = tuple(des_img.shape[:2])
240
+
241
+ else:
242
+ des_shape = (des_img.size[1], des_img.size[0])
243
+ src_shape = (img.size[1], img.size[0])
244
+ point_coords, boxes = transform_coords(src_shape, des_shape, point_coords, boxes)
245
+ return add_prompts_tag(des_img, point_labels, point_coords, boxes, pil)
246
+ else:
247
+ return add_prompts_tag(img, point_labels, point_coords, boxes, pil)
248
+
249
+ def visualize_results(self,
250
+ img: Union[Image, np.ndarray],
251
+ des_img: Union[Image, np.ndarray] = None,
252
+ sam_encoder_output: Optional[SAMEncoderOutput] = None,
253
+ features: Optional[torch.Tensor] = None,
254
+ interm_features: Optional[List[torch.Tensor]] = None,
255
+ original_size: Optional[Tuple] = None,
256
+ input_size: Optional[Tuple] = None,
257
+ point_coords: Optional[np.ndarray] = None,
258
+ point_labels: Optional[np.ndarray] = None,
259
+ boxes: Optional[np.ndarray] = None,
260
+ texts: Optional[List] = None,
261
+ grounding_dino_pipeline: GroundingDinoPipeline = None,
262
+ box_threshold: float = 0.25,
263
+ text_threshold: float = 0.25,
264
+ nms_threshold: float = 0.8,
265
+ detections: Optional[sv.Detections] = None,
266
+ multimask_output: bool = True,
267
+ visualize_promts: bool = True,
268
+ pil: bool = False):
269
+ if isinstance(img, Image):
270
+ img = np.uint8(img)
271
+ if des_img is not None:
272
+ if isinstance(des_img, np.ndarray):
273
+ des_shape = tuple(des_img.shape[:2])
274
+ else:
275
+ des_shape = (des_img.size[1], des_img.size[0])
276
+ src_shape = img.shape[:2]
277
+ if point_coords is not None or boxes is not None:
278
+ des_point_coords, des_boxes = transform_coords(src_shape, des_shape, point_coords, boxes)
279
+ else:
280
+ des_point_coords = None
281
+ des_boxes = None
282
+ else:
283
+ des_point_coords = None
284
+ des_boxes = None
285
+ src_shape = None
286
+ des_shape = None
287
+ detections = get_empty_detections() if detections is None else detections
288
+ mask_annotator = sv.MaskAnnotator()
289
+ result_list = []
290
+ mask_result_list = []
291
+ mask_list = []
292
+ if boxes is None and point_coords is None and point_labels is None and texts is None or \
293
+ (point_coords is not None and point_labels is not None and point_coords.shape[0] != point_labels.shape[0]):
294
+ print('no prompt given!')
295
+ result_list.append(img)
296
+ return result_list, mask_result_list, mask_list
297
+ # if boxes is not None and point_coords is not None and point_labels is not None:
298
+ # multimask_output = False
299
+ def get_annotated_image(mask_annotator,
300
+ detections,
301
+ img,
302
+ point_labels=None,
303
+ point_coords=None,
304
+ boxes=None,
305
+ visualize_promts=True,
306
+ pil=False):
307
+ annotated_image = mask_annotator.annotate(scene=img.copy(), detections=detections)
308
+ if visualize_promts:
309
+ annotated_image = add_prompts_tag(annotated_image, point_labels, point_coords, boxes=boxes, pil=pil)
310
+ else:
311
+ if pil:
312
+ annotated_image = PIL.Image.fromarray(annotated_image)
313
+ return annotated_image
314
+ def get_masked_image(img,
315
+ masks,
316
+ pil=True):
317
+ masked_image_list = []
318
+ for i in range(masks.shape[0]):
319
+ object_rgb = img * (masks[i].reshape(img.shape[0], img.shape[1], 1))
320
+ object_rgb = object_rgb.astype(np.uint8)
321
+ bkgd_mask = np.where(object_rgb == 0, 1, 0)
322
+ bkgd_mask *= 255
323
+ bkgd_mask = bkgd_mask.astype(np.uint8)
324
+ object_rgb += bkgd_mask
325
+ if pil:
326
+ masked_image_list.append(PIL.Image.fromarray(object_rgb))
327
+ else:
328
+ masked_image_list.append(object_rgb)
329
+ return masked_image_list
330
+ def interpolate_mask(mask_np, des_shape):
331
+ mask_tensor = torch.tensor(mask_np, dtype=torch.float32).unsqueeze(0)
332
+ mask_interpolate = F.interpolate(mask_tensor, size=des_shape, mode='bilinear')
333
+ mask_interpolate = (mask_interpolate+0.5).long()
334
+ mask_np = mask_interpolate.squeeze(0).numpy().astype(bool)
335
+ return mask_np
336
+
337
+ if point_coords is not None and point_labels is not None:
338
+
339
+ if src_shape is not None:
340
+ point_result = self.forward(sam_encoder_output,
341
+ features,
342
+ interm_features,
343
+ original_size,
344
+ input_size,
345
+ des_point_coords,
346
+ point_labels)
347
+ masks_np = interpolate_mask(point_result.masks_np, src_shape)
348
+ else:
349
+ point_result = self.forward(sam_encoder_output,
350
+ features,
351
+ interm_features,
352
+ original_size,
353
+ input_size,
354
+ point_coords,
355
+ point_labels)
356
+ masks_np = point_result.masks_np
357
+ if multimask_output:
358
+ for i in range(masks_np.shape[0]):
359
+ detections.mask = masks_np[i][None, ...]
360
+ mask_list.append(masks_np[i])
361
+ result_list.append(get_annotated_image(mask_annotator,
362
+ detections,
363
+ img,
364
+ point_labels=point_labels,
365
+ point_coords=point_coords,
366
+ visualize_promts=visualize_promts,
367
+ pil=pil))
368
+ mask_result_list += get_masked_image(img,
369
+ detections.mask,
370
+ pil=pil)
371
+ else:
372
+ index = np.argmax(point_result.iou_predictions_np)
373
+ detections.mask = masks_np[index][None, ...]
374
+ mask_list.append(masks_np[index])
375
+ result_list.append(get_annotated_image(mask_annotator,
376
+ detections,
377
+ img,
378
+ point_labels=point_labels,
379
+ point_coords=point_coords,
380
+ visualize_promts=visualize_promts,
381
+ pil=pil))
382
+ mask_result_list += get_masked_image(img,
383
+ detections.mask,
384
+ pil=pil)
385
+
386
+ if boxes is not None:
387
+ result_masks = []
388
+ if src_shape is not None:
389
+ boxes_ = des_boxes
390
+ else:
391
+ boxes_ = boxes
392
+ if boxes_.shape[0] > 1:
393
+ for i in range(len(boxes)):
394
+ box_result = self.forward(sam_encoder_output,
395
+ features,
396
+ interm_features,
397
+ original_size,
398
+ input_size,
399
+ box=boxes_[i])
400
+ index = np.argmax(box_result.iou_predictions_np)
401
+ result_masks.append(box_result.masks_np[index])
402
+ mask = np.array(result_masks)
403
+ if src_shape is not None:
404
+ masks_np = interpolate_mask(mask, src_shape)
405
+ else:
406
+ masks_np = mask
407
+ mask_list.append(masks_np)
408
+ detections.mask = masks_np
409
+ result_list.append(get_annotated_image(mask_annotator,
410
+ detections,
411
+ img,
412
+ boxes=boxes,
413
+ visualize_promts=visualize_promts,
414
+ pil=pil))
415
+ mask_result_list += get_masked_image(img,
416
+ detections.mask,
417
+ pil=pil)
418
+ else:
419
+ box_result = self.forward(sam_encoder_output,
420
+ features,
421
+ interm_features,
422
+ original_size,
423
+ input_size,
424
+ box=boxes_)
425
+ if src_shape is not None:
426
+ masks_np = interpolate_mask(box_result.masks_np, src_shape)
427
+ else:
428
+ masks_np = box_result.masks_np
429
+
430
+ if multimask_output:
431
+ for i in range(masks_np.shape[0]):
432
+ detections.mask = masks_np[i][None, ...]
433
+ mask_list.append(masks_np[i])
434
+ result_list.append(get_annotated_image(mask_annotator,
435
+ detections,
436
+ img,
437
+ boxes=boxes,
438
+ visualize_promts=visualize_promts,
439
+ pil=pil))
440
+ mask_result_list += get_masked_image(img,
441
+ detections.mask,
442
+ pil=pil)
443
+ else:
444
+ index = np.argmax(box_result.iou_predictions_np)
445
+ detections.mask = masks_np[index][None, ...]
446
+ mask_list.append(masks_np[index])
447
+ result_list.append(get_annotated_image(mask_annotator, detections, img, boxes=boxes, visualize_promts=visualize_promts, pil=pil))
448
+ mask_result_list += get_masked_image(img,
449
+ detections.mask,
450
+ pil=pil)
451
+
452
+ if texts is not None and grounding_dino_pipeline is not None:
453
+ detections = grounding_dino_pipeline(img[:, :, ::-1], texts, box_threshold, text_threshold)
454
+ box_annotator = sv.BoxAnnotator()
455
+ nms_idx = torchvision.ops.nms(
456
+ torch.from_numpy(detections.xyxy),
457
+ torch.from_numpy(detections.confidence),
458
+ nms_threshold
459
+ ).numpy().tolist()
460
+
461
+ detections.xyxy = detections.xyxy[nms_idx]
462
+ detections.confidence = detections.confidence[nms_idx]
463
+ detections.class_id = detections.class_id[nms_idx]
464
+ labels = [
465
+ f"{texts[class_id]} {confidence:0.2f}"
466
+ for _, _, confidence, class_id, _
467
+ in detections]
468
+ result_masks = []
469
+ if src_shape is not None:
470
+ _, boxes_ = transform_coords(src_shape, des_shape, boxes=detections.xyxy)
471
+ else:
472
+ boxes_ = detections.xyxy
473
+ for box in boxes_:
474
+ box_result = self.forward(sam_encoder_output,
475
+ features,
476
+ interm_features,
477
+ original_size,
478
+ input_size,
479
+ box=box)
480
+ index = np.argmax(box_result.iou_predictions_np)
481
+ result_masks.append(box_result.masks_np[index])
482
+ mask = np.array(result_masks)
483
+ if src_shape is not None:
484
+ detections.mask = interpolate_mask(mask, src_shape)
485
+ else:
486
+ detections.mask = mask
487
+ for i in range(detections.mask.shape[0]):
488
+ mask_list.append(detections.mask[i, ...])
489
+ if visualize_promts:
490
+ annotated_image = mask_annotator.annotate(scene=img[:, :, ::-1].copy(), detections=detections)
491
+ annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections, labels=labels)
492
+ else:
493
+ annotated_image = mask_annotator.annotate(scene=img[:, :, ::-1].copy(), detections=detections)
494
+
495
+ if pil:
496
+ result_list.append(PIL.Image.fromarray(annotated_image[:, :, ::-1]))
497
+ else:
498
+ result_list.append(annotated_image[:, :, ::-1])
499
+ mask_result_list += get_masked_image(img,
500
+ detections.mask,
501
+ pil=pil)
502
+
503
+ return result_list, mask_result_list, mask_list
504
+
505
+ def predict(
506
+ self,
507
+ features: torch.Tensor,
508
+ interm_features: List[torch.Tensor],
509
+ original_size: Tuple,
510
+ input_size: Tuple,
511
+ point_coords: Optional[np.ndarray] = None,
512
+ point_labels: Optional[np.ndarray] = None,
513
+ box: Optional[np.ndarray] = None,
514
+ mask_input: Optional[np.ndarray] = None,
515
+ multimask_output: bool = True,
516
+ return_logits: bool = False,
517
+ hq_token_only: bool = False,
518
+ ) -> SAMDecoderPredictOutput:
519
+ """
520
+ Predict masks for the given input prompts, using the currently set image.
521
+
522
+ Arguments:
523
+ point_coords (np.ndarray or None): A Nx2 array of point prompts to the
524
+ model. Each point is in (X,Y) in pixels.
525
+ point_labels (np.ndarray or None): A length N array of labels for the
526
+ point prompts. 1 indicates a foreground point and 0 indicates a
527
+ background point.
528
+ box (np.ndarray or None): A length 4 array given a box prompt to the
529
+ model, in XYXY format.
530
+ mask_input (np.ndarray): A low resolution mask input to the model, typically
531
+ coming from a previous prediction iteration. Has form 1xHxW, where
532
+ for SAM, H=W=256.
533
+ multimask_output (bool): If true, the model will return three masks.
534
+ For ambiguous input prompts (such as a single click), this will often
535
+ produce better masks than a single prediction. If only a single
536
+ mask is needed, the model's predicted quality score can be used
537
+ to select the best mask. For non-ambiguous prompts, such as multiple
538
+ input prompts, multimask_output=False can give better results.
539
+ return_logits (bool): If true, returns un-thresholded masks logits
540
+ instead of a binary mask.
541
+
542
+ Returns:
543
+ (np.ndarray): The output masks in CxHxW format, where C is the
544
+ number of masks, and (H, W) is the original image size.
545
+ (np.ndarray): An array of length C containing the model's
546
+ predictions for the quality of each mask.
547
+ (np.ndarray): An array of shape CxHxW, where C is the number
548
+ of masks and H=W=256. These low resolution logits can be passed to
549
+ a subsequent iteration as mask input.
550
+ """
551
+ # Transform input prompts
552
+
553
+ coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
554
+ if point_coords is not None:
555
+ assert (
556
+ point_labels is not None
557
+ ), "point_labels must be supplied if point_coords is supplied."
558
+ point_coords = self.transform.apply_coords(point_coords, original_size)
559
+ coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
560
+ labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
561
+ coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
562
+ if box is not None:
563
+ box = self.transform.apply_boxes(box, original_size)
564
+ box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
565
+ box_torch = box_torch[None, :]
566
+ if mask_input is not None:
567
+ mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
568
+ mask_input_torch = mask_input_torch[None, :, :, :]
569
+
570
+ sam_decoder_predict_torch_output = self.predict_torch(
571
+ features,
572
+ interm_features,
573
+ original_size,
574
+ input_size,
575
+ coords_torch,
576
+ labels_torch,
577
+ box_torch,
578
+ mask_input_torch,
579
+ multimask_output,
580
+ return_logits=return_logits,
581
+ hq_token_only=hq_token_only,
582
+ )
583
+
584
+ masks_np = sam_decoder_predict_torch_output.masks[0].detach().cpu().numpy()
585
+ iou_predictions_np = sam_decoder_predict_torch_output.iou_predictions[0].detach().cpu().numpy()
586
+ low_res_masks_np = sam_decoder_predict_torch_output.low_res_masks[0].detach().cpu().numpy()
587
+ return SAMDecoderPredictOutput(masks_np, iou_predictions_np, low_res_masks_np)
588
+
589
+ @torch.no_grad()
590
+ def predict_torch(
591
+ self,
592
+ features: torch.Tensor,
593
+ interm_features: List[torch.Tensor],
594
+ original_size: Tuple,
595
+ input_size: Tuple,
596
+ point_coords: Optional[torch.Tensor],
597
+ point_labels: Optional[torch.Tensor],
598
+ boxes: Optional[torch.Tensor] = None,
599
+ mask_input: Optional[torch.Tensor] = None,
600
+ multimask_output: bool = True,
601
+ return_logits: bool = False,
602
+ hq_token_only: bool = False,
603
+ ) -> SAMDecoderPredictTorchOutput:
604
+ """
605
+ Predict masks for the given input prompts, using the currently set image.
606
+ Input prompts are batched torch tensors and are expected to already be
607
+ transformed to the input frame using ResizeLongestSide.
608
+
609
+ Arguments:
610
+ point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
611
+ model. Each point is in (X,Y) in pixels.
612
+ point_labels (torch.Tensor or None): A BxN array of labels for the
613
+ point prompts. 1 indicates a foreground point and 0 indicates a
614
+ background point.
615
+ boxes (np.ndarray or None): A Bx4 array given a box prompt to the
616
+ model, in XYXY format.
617
+ mask_input (np.ndarray): A low resolution mask input to the model, typically
618
+ coming from a previous prediction iteration. Has form Bx1xHxW, where
619
+ for SAM, H=W=256. Masks returned by a previous iteration of the
620
+ predict method do not need further transformation.
621
+ multimask_output (bool): If true, the model will return three masks.
622
+ For ambiguous input prompts (such as a single click), this will often
623
+ produce better masks than a single prediction. If only a single
624
+ mask is needed, the model's predicted quality score can be used
625
+ to select the best mask. For non-ambiguous prompts, such as multiple
626
+ input prompts, multimask_output=False can give better results.
627
+ return_logits (bool): If true, returns un-thresholded masks logits
628
+ instead of a binary mask.
629
+
630
+ Returns:
631
+ (torch.Tensor): The output masks in BxCxHxW format, where C is the
632
+ number of masks, and (H, W) is the original image size.
633
+ (torch.Tensor): An array of shape BxC containing the model's
634
+ predictions for the quality of each mask.
635
+ (torch.Tensor): An array of shape BxCxHxW, where C is the number
636
+ of masks and H=W=256. These low res logits can be passed to
637
+ a subsequent iteration as mask input.
638
+ """
639
+
640
+ if point_coords is not None:
641
+ points = (point_coords, point_labels)
642
+ else:
643
+ points = None
644
+
645
+ # Embed prompts
646
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
647
+ points=points,
648
+ boxes=boxes,
649
+ masks=mask_input,
650
+ )
651
+
652
+ # Predict masks
653
+ low_res_masks, iou_predictions = self.mask_decoder(
654
+ image_embeddings=features,
655
+ image_pe=self.prompt_encoder.get_dense_pe(),
656
+ sparse_prompt_embeddings=sparse_embeddings,
657
+ dense_prompt_embeddings=dense_embeddings,
658
+ multimask_output=multimask_output,
659
+ hq_token_only=hq_token_only,
660
+ interm_embeddings=interm_features,
661
+ )
662
+
663
+ # Upscale the masks to the original image resolution
664
+ # masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
665
+ masks = self.postprocess_masks(low_res_masks, input_size, original_size)
666
+
667
+ if not return_logits:
668
+ masks = masks > self.mask_threshold
669
+
670
+ return SAMDecoderPredictTorchOutput(masks, iou_predictions, low_res_masks)
671
+ def forward(self,
672
+ sam_encoder_output: Optional[SAMEncoderOutput]=None,
673
+ features: Optional[torch.Tensor]=None,
674
+ interm_features: Optional[List[torch.Tensor]]=None,
675
+ original_size: Optional[Tuple]=None,
676
+ input_size: Optional[Tuple]=None,
677
+ point_coords: Optional[np.ndarray] = None,
678
+ point_labels: Optional[np.ndarray] = None,
679
+ box: Optional[np.ndarray] = None,
680
+ mask_input: Optional[np.ndarray] = None,
681
+ multimask_output: bool = True,
682
+ return_logits: bool = False,
683
+ hq_token_only: bool = False,
684
+ dino: bool = False
685
+ ) -> SAMDecoderPredictOutput:
686
+ assert sam_encoder_output or (features is not None and original_size is not None and input_size is not None), 'one of sam_encoder_output and four necessary inputs must be given!'
687
+ if sam_encoder_output:
688
+ features = sam_encoder_output.features
689
+ interm_features = sam_encoder_output.interm_features
690
+ original_size = sam_encoder_output.original_size
691
+ input_size = sam_encoder_output.input_size
692
+ if self.adaptor is not None:
693
+ if dino:
694
+ features = F.interpolate(F.normalize(features, dim=1), size=(64, 64), mode='bilinear').permute(0, 2, 3, 1)
695
+ features = self.adaptor(features)
696
+ #
697
+ # else:
698
+ # features = self.adaptor(features, original_size)
699
+
700
+ return self.predict(features,
701
+ interm_features,
702
+ original_size,
703
+ input_size,
704
+ point_coords,
705
+ point_labels,
706
+ box,
707
+ mask_input,
708
+ multimask_output,
709
+ return_logits,
710
+ hq_token_only)
711
+
712
+ '''
713
+ class SAMPipeline(Pipeline):
714
+ @classmethod
715
+ def from_pretrained(cls, ckpt_path, device='cuda', *args, **kwargs):
716
+ sam_encoder_pipeline = SAMEncoderPipeline(ckpt_path, device, *args, **kwargs)
717
+ sam_decoder_pipeline = SAMDecoderPipeline(ckpt_path, device, *args, **kwargs)
718
+ pipeline = cls(**dict(sam_encoder_pipeline=sam_encoder_pipeline,
719
+ sam_decoder_pipeline=sam_decoder_pipeline,
720
+ device=device))
721
+ return pipeline
722
+ '''
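A hedged end-to-end sketch of the SAM encoder/decoder pair; the checkpoint path, image file and click coordinates below are placeholders, and sam_version falls back to 'sam' when it is not passed.

# Usage sketch; checkpoint, image and point prompt are placeholders.
import numpy as np
import PIL.Image
from sam_extension.pipeline.sam import SAMEncoderPipeline, SAMDecoderPipeline

ckpt_path = 'weights/sam/sam_ckpt.pth'                     # assumed local checkpoint
encoder = SAMEncoderPipeline.from_pretrained(ckpt_path, device='cuda')
decoder = SAMDecoderPipeline.from_pretrained(ckpt_path, device='cuda')

image = PIL.Image.open('example.jpg').convert('RGB')       # hypothetical image
encoder_output = encoder(image)                            # SAMEncoderOutput

point_coords = np.array([[250, 300]])                      # one foreground click, (x, y) in pixels
point_labels = np.array([1])
decoder_output = decoder(sam_encoder_output=encoder_output,
                         point_coords=point_coords,
                         point_labels=point_labels,
                         multimask_output=True)
print(decoder_output.masks_np.shape)                       # (num_masks, H, W) boolean masks
print(decoder_output.iou_predictions_np)                   # per-mask quality scores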
sam_extension/utils/__init__.py ADDED
@@ -0,0 +1,175 @@
1
+ import math
2
+ import cv2
3
+ import PIL
4
+ import torch
5
+ from PIL.Image import Image
6
+ from typing import Union, Tuple, List, Optional
7
+ import numpy as np
8
+ import supervision as sv
9
+ from sklearn.decomposition import PCA
10
+
11
+ # def add_points_tag(img: Union[Image, np.ndarray],
12
+ # point_labels: Union[List[int], np.ndarray] = None,
13
+ # point_coords: Union[List[List[int]], np.ndarray] = None,
14
+ # pil: bool = False):
15
+ # if point_labels is None or point_coords is None or \
16
+ # not isinstance(point_labels, (List, np.ndarray)) or \
17
+ # not isinstance(point_coords, (List, np.ndarray)):
18
+ # return img
19
+ # if len(point_labels) != len(point_coords):
20
+ # print('length of point_label and point_coordinate must be same!')
21
+ # return img
22
+ # if isinstance(img, Image):
23
+ # img = np.uint8(img)
24
+ # start_angle = 40
25
+ # x = 8
26
+ # y = 2
27
+ # def get_point(angle, d, base):
28
+ # angle = angle / 180.0 * math.pi
29
+ # _x, _y = math.cos(angle) * d, math.sin(angle) * d
30
+ # return [base[0] + _x, base[1] - _y]
31
+ # # assert len(point_labels) == len(point_coords), ''
32
+ # for i in range(len(point_labels)):
33
+ # points = []
34
+ # for j in range(5):
35
+ # _x, _y = math.cos(start_angle), math.sin(start_angle)
36
+ # points.append(get_point(start_angle, x, point_coords[i]))
37
+ # start_angle -= 36
38
+ # points.append(get_point(start_angle, y, point_coords[i]))
39
+ # start_angle -= 36
40
+ # points = np.array([points], np.int32)
41
+ # color = (255, 0, 0) if point_labels[i] == 0 else (0, 255, 0)
42
+ # cv2.fillPoly(img, points, color, cv2.LINE_AA)
43
+ # if pil:
44
+ # img = PIL.Image.fromarray(img)
45
+ # return img
46
+ def add_points_tag(img: Union[Image, np.ndarray],
47
+ point_labels: Union[List[int], np.ndarray] = None,
48
+ point_coords: Union[List[List[int]], np.ndarray] = None,
49
+ pil: bool = False):
50
+ if point_labels is None or point_coords is None or \
51
+ not isinstance(point_labels, (List, np.ndarray)) or \
52
+ not isinstance(point_coords, (List, np.ndarray)):
53
+ return img
54
+ if len(point_labels) != len(point_coords):
55
+ print('length of point_label and point_coordinate must be same!')
56
+ return img
57
+ if isinstance(img, Image):
58
+ img = np.array(img)
59
+ # img.flags.writeable = True
60
+ h, w = img.shape[:2]
61
+ x_start_list, x_end_list = np.where((point_coords[:, 0] - 4) > 0, point_coords[:, 0] - 4, 0), np.where((point_coords[:, 0] + 4) < w, point_coords[:, 0] + 4, w)
62
+ y_start_list, y_end_list = np.where((point_coords[:, 1] - 4) > 0, point_coords[:, 1] - 4, 0), np.where((point_coords[:, 1] + 4) < h, point_coords[:, 1] + 4, h)
63
+ for i in range(len(point_labels)):
64
+ x_start, x_end = x_start_list[i], x_end_list[i]
65
+ y_start, y_end = y_start_list[i], y_end_list[i]
66
+ label = point_labels[i]
67
+ color = [0, 255, 0] if int(label) == 1 else [255, 0, 0]
68
+ for x in range(x_start, x_end):
69
+ for y in range(y_start, y_end):
70
+ img[y, x, :] = color
71
+ if pil:
72
+ img = PIL.Image.fromarray(img)
73
+ return img
74
+ def add_boxes_tag(img: Union[Image, np.ndarray],
75
+ boxes: Union[List[List[int]], np.ndarray] = None,
76
+ pil: bool = False):
77
+ if boxes is None or not isinstance(boxes, (List, np.ndarray)):
78
+ return img
79
+ # if isinstance(boxes, np.ndarray):
80
+ # if not boxes.all():
81
+ # return img
82
+ # else:
83
+ # if not boxes:
84
+ # return img
85
+ if isinstance(img, Image):
86
+ img = np.uint8(img)
87
+ thickness = 2
88
+ for i in range(len(boxes)):
89
+ color = (0, 255, 0)
90
+ img = cv2.rectangle(img, (boxes[i][0], boxes[i][1]), (boxes[i][2], boxes[i][3]), color, thickness)
91
+ if pil:
92
+ img = PIL.Image.fromarray(img)
93
+ return img
94
+
95
+ def add_prompts_tag(img: Union[Image, np.ndarray],
96
+ point_labels: Union[List[int], np.ndarray] = None,
97
+ point_coords: Union[List[List[int]], np.ndarray] = None,
98
+ boxes: Union[List[List[int]], np.ndarray] = None,
99
+ pil: bool = False):
100
+ img = add_points_tag(img, point_labels, point_coords, pil=pil)
101
+ img = add_boxes_tag(img, boxes, pil=pil)
102
+ return img
103
+
104
+
105
+ def get_empty_detections():
106
+ detections = sv.Detections(xyxy=np.array([0, 0, 0, 0]).reshape(1, 4))
107
+ detections.xyxy = None
108
+ return detections
109
+
110
+
111
+ def pca_feature(feature: torch.Tensor, dim: int = 3, return_np: bool = True):
112
+ pca = PCA(n_components=dim)
113
+ H, W, C = feature.shape
114
+ feature = feature.view(-1, C).cpu().numpy()
115
+ feature = pca.fit_transform(feature)
116
+ feature = torch.tensor(feature.reshape(H, W, dim))
117
+ if return_np:
118
+ return feature.numpy()
119
+ else:
120
+ return feature
121
+
122
+ def visual_feature_rgb(feature: torch.Tensor, pil:bool = True):
123
+ assert feature.ndim >= 3, 'the dim of feature must >= 3!'
124
+ if feature.ndim == 4:
125
+ feature = feature.squeeze(0)
126
+ if feature.shape[-1] != 3:
127
+ feature = pca_feature(feature, 3, False)
128
+ max_f, _ = feature.max(-1)
129
+ min_f, _ = feature.min(-1)
130
+ feature = (feature - min_f[..., None]) / (max_f[..., None] - min_f[..., None])
131
+ feature = np.uint8((feature*255).cpu().numpy())
132
+ if pil:
133
+ return PIL.Image.fromarray(feature)
134
+ else:
135
+ return feature
136
+
137
+ def transform_coords(src_shape, des_shape, points = None, boxes = None):
138
+ assert points is not None or boxes is not None, 'one of points and boxes must be given!'
139
+ scale_h = des_shape[0] / src_shape[0]
140
+ scale_w = des_shape[1] / src_shape[1]
141
+ if points is not None:
142
+ new_points = np.full_like(points, 0)
143
+ new_points[:, 0] = points[:, 0] * scale_w
144
+ new_points[:, 1] = points[:, 1] * scale_h
145
+ new_points = new_points.astype(np.int64)
146
+ else:
147
+ new_points = None
148
+ if boxes is not None:
149
+ new_boxes = np.full_like(boxes, 0)
150
+ new_boxes[:, 0] = boxes[:, 0] * scale_w
151
+ new_boxes[:, 1] = boxes[:, 1] * scale_h
152
+ new_boxes[:, 2] = boxes[:, 2] * scale_w
153
+ new_boxes[:, 3] = boxes[:, 3] * scale_h
154
+ new_boxes = new_boxes.astype(np.int64)
155
+ else:
156
+ new_boxes = None
157
+ return new_points, new_boxes
158
+
159
+
160
+ def mask2greyimg(mask_list, pil=True):
161
+ grey_img_list = []
162
+ for mask in mask_list:
163
+ if pil:
164
+ grey_img_list.append(PIL.Image.fromarray(np.uint8(mask*255)))
165
+ else:
166
+ grey_img_list.append(np.uint8(mask * 255))
167
+ return grey_img_list
168
+ if __name__ == '__main__':
169
+ src_shape = (100,100)
170
+ des_shape = (200,200)
171
+ points = np.array([[20,20],[40,40]])
172
+ boxes = np.array([[10,10,20,20]])
173
+ new_points, new_boxes = transform_coords(src_shape, des_shape, points, boxes)
174
+ print(new_points, new_boxes)
175
+
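Beyond the __main__ check above, the tagging and mask helpers can be exercised directly; the image path and prompts in this sketch are placeholders.

# Usage sketch for the prompt-tagging helpers; image and prompts are placeholders.
import numpy as np
import PIL.Image
from sam_extension.utils import add_prompts_tag, mask2greyimg

img = np.array(PIL.Image.open('example.jpg').convert('RGB'))
tagged = add_prompts_tag(img,
                         point_labels=np.array([1, 0]),                # label 1 drawn green, 0 drawn red
                         point_coords=np.array([[60, 80], [200, 40]]),
                         boxes=np.array([[30, 30, 220, 180]]),
                         pil=True)
tagged.save('prompts_preview.jpg')

dummy_mask = np.zeros((img.shape[0], img.shape[1]), dtype=bool)
dummy_mask[50:150, 50:150] = True
grey = mask2greyimg([dummy_mask], pil=True)[0]                         # greyscale PIL image of the mask
grey.save('mask_grey.png')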
sam_extension/utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (4.51 kB). View file