huzey committed
Commit 9576e83
1 Parent(s): e3b132f

update MobileSAM

Files changed (2):
  1. app.py +167 -1
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,6 +1,7 @@
 from typing import Optional, Tuple
 from einops import rearrange
 import torch
+import torch.nn.functional as F
 from PIL import Image
 import torchvision.transforms as transforms
 from torch import nn
@@ -12,8 +13,168 @@ import gradio as gr
 
 use_cuda = torch.cuda.is_available()
 
+# use_cuda = False
+
 print("CUDA is available:", use_cuda)
 
+class MobileSAM(nn.Module):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        from mobile_sam import sam_model_registry
+
+        url = 'https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/master/weights/mobile_sam.pt'
+        model_type = "vit_t"
+        sam_checkpoint = "mobile_sam.pt"
+        if not os.path.exists(sam_checkpoint):
+            import requests
+            r = requests.get(url)
+            with open(sam_checkpoint, 'wb') as f:
+                f.write(r.content)
+
+        device = 'cuda' if use_cuda else 'cpu'
+
+        mobile_sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
+
+        def new_forward_fn(self, x):
+            shortcut = x
+
+            x = self.conv1(x)
+            x = self.act1(x)
+
+            x = self.conv2(x)
+            x = self.act2(x)
+
+            self.attn_output = rearrange(x.clone(), "b c h w -> b h w c")
+
+            x = self.conv3(x)
+
+            self.mlp_output = rearrange(x.clone(), "b c h w -> b h w c")
+
+            x = self.drop_path(x)
+
+            x += shortcut
+            x = self.act3(x)
+
+            self.block_output = rearrange(x.clone(), "b c h w -> b h w c")
+
+            return x
+
+        setattr(mobile_sam.image_encoder.layers[0].blocks[0].__class__, "forward", new_forward_fn)
+
+        def new_forward_fn2(self, x):
+            H, W = self.input_resolution
+            B, L, C = x.shape
+            assert L == H * W, "input feature has wrong size"
+            res_x = x
+            if H == self.window_size and W == self.window_size:
+                x = self.attn(x)
+            else:
+                x = x.view(B, H, W, C)
+                pad_b = (self.window_size - H %
+                         self.window_size) % self.window_size
+                pad_r = (self.window_size - W %
+                         self.window_size) % self.window_size
+                padding = pad_b > 0 or pad_r > 0
+
+                if padding:
+                    x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
+
+                pH, pW = H + pad_b, W + pad_r
+                nH = pH // self.window_size
+                nW = pW // self.window_size
+                # window partition
+                x = x.view(B, nH, self.window_size, nW, self.window_size, C).transpose(2, 3).reshape(
+                    B * nH * nW, self.window_size * self.window_size, C)
+                x = self.attn(x)
+                # window reverse
+                x = x.view(B, nH, nW, self.window_size, self.window_size,
+                           C).transpose(2, 3).reshape(B, pH, pW, C)
+
+                if padding:
+                    x = x[:, :H, :W].contiguous()
+
+                x = x.view(B, L, C)
+
+            hw = np.sqrt(x.shape[1]).astype(int)
+            self.attn_output = rearrange(x.clone(), "b (h w) c -> b h w c", h=hw)
+
+            x = res_x + self.drop_path(x)
+
+            x = x.transpose(1, 2).reshape(B, C, H, W)
+            x = self.local_conv(x)
+            x = x.view(B, C, L).transpose(1, 2)
+
+            mlp_output = self.mlp(x)
+            self.mlp_output = rearrange(mlp_output.clone(), "b (h w) c -> b h w c", h=hw)
+
+            x = x + self.drop_path(mlp_output)
+            self.block_output = rearrange(x.clone(), "b (h w) c -> b h w c", h=hw)
+            return x
+
+        setattr(mobile_sam.image_encoder.layers[1].blocks[0].__class__, "forward", new_forward_fn2)
+
+
+        mobile_sam.to(device=device)
+        mobile_sam.eval()
+        self.image_encoder = mobile_sam.image_encoder
+
+
+    @torch.no_grad()
+    def forward(self, x):
+        with torch.no_grad():
+            x = torch.nn.functional.interpolate(x, size=(1024, 1024), mode="bilinear")
+            out = self.image_encoder(x)
+
+            attn_outputs, mlp_outputs, block_outputs = [], [], []
+            for i_layer in range(len(self.image_encoder.layers)):
+                for i_block in range(len(self.image_encoder.layers[i_layer].blocks)):
+                    blk = self.image_encoder.layers[i_layer].blocks[i_block]
+                    attn_outputs.append(blk.attn_output)
+                    mlp_outputs.append(blk.mlp_output)
+                    block_outputs.append(blk.block_output)
+            return attn_outputs, mlp_outputs, block_outputs
+
+
+def image_mobilesam_feature(
+    images,
+    resolution=(1024, 1024),
+    node_type="block",
+    layer=-1,
+):
+
+    transform = transforms.Compose(
+        [
+            transforms.Resize(resolution),
+            transforms.ToTensor(),
+            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+        ]
+    )
+
+
+    feat_extractor = MobileSAM()
+
+    # attn_outputs, mlp_outputs, block_outputs = [], [], []
+    outputs = []
+    for i, image in enumerate(images):
+        torch_image = transform(image)
+        if use_cuda:
+            torch_image = torch_image.cuda()
+        attn_output, mlp_output, block_output = feat_extractor(
+            torch_image.unsqueeze(0)
+        )
+        out_dict = {
+            "attn": attn_output,
+            "mlp": mlp_output,
+            "block": block_output,
+        }
+        out = out_dict[node_type]
+        out = out[layer]
+        outputs.append(out.cpu())
+    outputs = torch.cat(outputs, dim=0)
+    return outputs
+
+
 
 class SAM(torch.nn.Module):
     def __init__(self, checkpoint="/data/sam_model/sam_vit_b_01ec64.pth", **kwargs):
@@ -307,6 +468,8 @@ def image_clip_feature(
 def extract_features(images, model_name="sam", node_type="block", layer=-1):
     if model_name == "SAM(sam_vit_b)":
         return image_sam_feature(images, node_type=node_type, layer=layer)
+    elif model_name == 'MobileSAM':
+        return image_mobilesam_feature(images, node_type=node_type, layer=layer)
     elif model_name == "DiNO(dinov2_vitb14_reg)":
         return image_dino_feature(images, node_type=node_type, layer=layer)
     elif model_name == "CLIP(openai/clip-vit-base-patch16)":
@@ -346,6 +509,9 @@ def compute_ncut(
     )
     print(f"t-SNE time: {time.time() - start:.2f}s")
 
+    # print("input shape:", features.shape)
+    # print("output shape:", rgb.shape)
+
     rgb = rgb.reshape(features.shape[:3] + (3,))
     return rgb
 
@@ -413,7 +579,7 @@ demo = gr.Interface(
     main_fn,
     [
         gr.Gallery(value=default_images, label="Select images", show_label=False, elem_id="images", columns=[3], rows=[1], object_fit="contain", height="auto", type="pil"),
-        gr.Dropdown(["SAM(sam_vit_b)", "DiNO(dinov2_vitb14_reg)", "CLIP(openai/clip-vit-base-patch16)"], label="Model", value="SAM(sam_vit_b)", elem_id="model_name"),
+        gr.Dropdown(["MobileSAM", "SAM(sam_vit_b)", "DiNO(dinov2_vitb14_reg)", "CLIP(openai/clip-vit-base-patch16)"], label="Model", value="MobileSAM", elem_id="model_name"),
         gr.Dropdown(["attn", "mlp", "block"], label="Node type", value="block", elem_id="node_type", info="attn: attention output, mlp: mlp output, block: sum of residual stream"),
         gr.Slider(0, 11, step=1, label="Layer", value=11, elem_id="layer", info="which layer of the image backbone features"),
         gr.Slider(1, 1000, step=1, label="Number of eigenvectors", value=100, elem_id="num_eig", info='increase for more object parts, decrease for whole object'),
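
For reference, a minimal usage sketch of the new MobileSAM path added above; it assumes app.py can be imported without launching the Gradio demo, and the image file names are hypothetical placeholders.

    from PIL import Image
    from app import extract_features  # assumes app.py is on the import path

    # hypothetical input images; any RGB PIL images work, matching the gr.Gallery type="pil" input
    images = [Image.open(p).convert("RGB") for p in ["cat.jpg", "dog.jpg"]]

    # "MobileSAM" now routes to image_mobilesam_feature(), which downloads mobile_sam.pt
    # on first use and returns features from the hooked image-encoder blocks
    feats = extract_features(images, model_name="MobileSAM", node_type="block", layer=-1)
    print(feats.shape)  # (num_images, h, w, c) feature grid from the selected block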
requirements.txt CHANGED
@@ -1,6 +1,7 @@
 ncut-pytorch
 transformers
 segment-anything @ git+https://github.com/facebookresearch/segment-anything.git
+mobile-sam @ git+https://github.com/ChaoningZhang/MobileSAM.git
 --extra-index-url https://download.pytorch.org/whl/cpu
 torch
 torchvision
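
A minimal post-install smoke test for the new requirement; the mobile_sam package name and the "vit_t" registry key come from the app.py changes above, while installing the updated requirements.txt beforehand is assumed.

    # assumes `pip install -r requirements.txt` has been run with the updated file
    from mobile_sam import sam_model_registry

    # app.py builds the image encoder from the "vit_t" entry of this registry
    assert "vit_t" in sam_model_registry
    print("mobile_sam import OK; vit_t builder:", sam_model_registry["vit_t"])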