Spaces: Running on Zero
update gpu
app.py CHANGED
@@ -11,11 +11,11 @@ import time
 
 import gradio as gr
 
-
+import spaces
 
-
+USE_CUDA = torch.cuda.is_available()
 
-print("CUDA is available:",
+print("CUDA is available:", USE_CUDA)
 
 class MobileSAM(nn.Module):
     def __init__(self, **kwargs):
@@ -32,7 +32,7 @@ class MobileSAM(nn.Module):
             with open(sam_checkpoint, 'wb') as f:
                 f.write(r.content)
 
-        device = 'cuda' if
+        device = 'cuda' if USE_CUDA else 'cpu'
 
         mobile_sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
 
@@ -135,6 +135,7 @@ class MobileSAM(nn.Module):
             block_outputs.append(blk.block_output)
         return attn_outputs, mlp_outputs, block_outputs
 
+mobilesam = MobileSAM()
 
 def image_mobilesam_feature(
     images,
@@ -152,13 +153,15 @@ def image_mobilesam_feature(
     )
 
 
-    feat_extractor =
+    feat_extractor = mobilesam
+    if USE_CUDA:
+        feat_extractor = feat_extractor.cuda()
 
     # attn_outputs, mlp_outputs, block_outputs = [], [], []
     outputs = []
     for i, image in enumerate(images):
         torch_image = transform(image)
-        if
+        if USE_CUDA:
             torch_image = torch_image.cuda()
         attn_output, mlp_output, block_output = feat_extractor(
             torch_image.unsqueeze(0)
@@ -172,15 +175,25 @@ def image_mobilesam_feature(
             out = out[layer]
         outputs.append(out.cpu())
     outputs = torch.cat(outputs, dim=0)
+
+    mobilesam = mobilesam.cpu()
     return outputs
 
 
 
 class SAM(torch.nn.Module):
-    def __init__(self,
+    def __init__(self, **kwargs):
         super().__init__(**kwargs)
         from segment_anything import sam_model_registry, SamPredictor
         from segment_anything.modeling.sam import Sam
+
+        checkpoint = "sam_vit_b_01ec64.pth"
+        if not os.path.exists(checkpoint):
+            checkpoint_url = 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth'
+            import requests
+            r = requests.get(checkpoint_url)
+            with open(checkpoint, 'wb') as f:
+                f.write(r.content)
 
         sam: Sam = sam_model_registry["vit_b"](checkpoint=checkpoint)
 
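The hunk above establishes the device shuffle this commit applies to every backbone: the module-level model is pushed to CUDA for the forward passes and parked back on the CPU before the function returns, presumably so the short-lived GPU allocation is not held longer than necessary. For comparison only, the same idea can be written as a small context manager; this is a hypothetical sketch (on_gpu is an invented helper, not code from app.py):

from contextlib import contextmanager

import torch


@contextmanager
def on_gpu(model: torch.nn.Module, use_cuda: bool):
    # Temporarily move the model to CUDA, then hand it back to the CPU.
    try:
        yield model.cuda() if use_cuda else model
    finally:
        if use_cuda:
            model.cpu()                  # release GPU memory once inference is done
            torch.cuda.empty_cache()


# Hypothetical usage, mirroring the pattern in the diff:
# with on_gpu(mobilesam, USE_CUDA) as feat_extractor:
#     outputs = feat_extractor(torch_image.unsqueeze(0).cuda())

Putting the .cpu() call in a finally block would release the GPU memory even if the forward pass raises.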
@@ -215,7 +228,7 @@ class SAM(torch.nn.Module):
 
         self.image_encoder = sam.image_encoder
         self.image_encoder.eval()
-        if
+        if USE_CUDA:
             self.image_encoder = self.image_encoder.cuda()
 
     @torch.no_grad()
@@ -234,6 +247,7 @@ class SAM(torch.nn.Module):
         block_outputs = torch.stack(block_outputs)
         return attn_outputs, mlp_outputs, block_outputs
 
+sam = SAM()
 
 def image_sam_feature(
     images,
@@ -249,22 +263,16 @@ def image_sam_feature(
             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
         ]
     )
-
-    checkpoint = "sam_vit_b_01ec64.pth"
-    if not os.path.exists(checkpoint):
-        checkpoint_url = 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth'
-        import requests
-        r = requests.get(checkpoint_url)
-        with open(checkpoint, 'wb') as f:
-            f.write(r.content)
 
-    feat_extractor =
+    feat_extractor = sam
+    if USE_CUDA:
+        feat_extractor = feat_extractor.cuda()
 
     # attn_outputs, mlp_outputs, block_outputs = [], [], []
    outputs = []
     for i, image in enumerate(images):
         torch_image = transform(image)
-        if
+        if USE_CUDA:
             torch_image = torch_image.cuda()
         attn_output, mlp_output, block_output = feat_extractor(
             torch_image.unsqueeze(0)
@@ -278,6 +286,9 @@ def image_sam_feature(
         out = out[layer]
         outputs.append(out.cpu())
     outputs = torch.cat(outputs, dim=0)
+
+    sam = sam.cpu()
+
     return outputs
 
 
@@ -287,7 +298,7 @@ class DiNOv2(torch.nn.Module):
         self.dinov2 = torch.hub.load("facebookresearch/dinov2", ver)
         self.dinov2.requires_grad_(False)
         self.dinov2.eval()
-        if
+        if USE_CUDA:
             self.dinov2 = self.dinov2.cuda()
 
     def new_block_forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -325,6 +336,7 @@ class DiNOv2(torch.nn.Module):
         block_outputs = torch.stack(block_outputs)
         return attn_outputs, mlp_outputs, block_outputs
 
+dinov2 = DiNOv2()
 
 def image_dino_feature(images, resolution=(448, 448), node_type="block", layer=-1):
 
@@ -336,12 +348,14 @@ def image_dino_feature(images, resolution=(448, 448), node_type="block", layer=-1):
         ]
     )
 
-    feat_extractor =
+    feat_extractor = dinov2
+    if USE_CUDA:
+        feat_extractor = feat_extractor.cuda()
 
     outputs = []
     for i, image in enumerate(images):
         torch_image = transform(image)
-        if
+        if USE_CUDA:
             torch_image = torch_image.cuda()
         attn_output, mlp_output, block_output = feat_extractor(
             torch_image.unsqueeze(0)
@@ -356,6 +370,8 @@ def image_dino_feature(images, resolution=(448, 448), node_type="block", layer=-1):
         outputs.append(out.cpu())
     outputs = torch.cat(outputs, dim=0)
     outputs = rearrange(outputs[:, 5:, :], "b (h w) c -> b h w c", h=32, w=32)
+
+    dinov2 = dinov2.cpu()
     return outputs
 
 
@@ -368,7 +384,7 @@ class CLIP(torch.nn.Module):
         model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
         # processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
         self.model = model.eval()
-        if
+        if USE_CUDA:
             self.model = self.model.cuda()
 
     def new_forward(
@@ -424,6 +440,7 @@ class CLIP(torch.nn.Module):
         block_outputs = torch.stack(block_outputs)
         return attn_outputs, mlp_outputs, block_outputs
 
+clip = CLIP()
 
 def image_clip_feature(
     images, resolution=(224, 224), node_type="block", layer=-1
@@ -442,12 +459,14 @@ def image_clip_feature(
         ]
     )
 
-    feat_extractor =
+    feat_extractor = clip
+    if USE_CUDA:
+        feat_extractor = feat_extractor.cuda()
 
     outputs = []
     for i, image in enumerate(images):
         torch_image = transform(image)
-        if
+        if USE_CUDA:
             torch_image = torch_image.cuda()
         attn_output, mlp_output, block_output = feat_extractor(
             torch_image.unsqueeze(0)
@@ -461,6 +480,8 @@ def image_clip_feature(
         out = out[layer]
         outputs.append(out.cpu())
     outputs = torch.cat(outputs, dim=0)
+
+    clip = clip.cpu()
     return outputs
 
 
@@ -505,6 +526,27 @@ def compute_hash(*args, **kwargs):
     return hasher.hexdigest()
 
 
+@spaces.GPU(duration=30)
+def run_model_on_image(image, model_name="sam", node_type="block", layer=-1):
+    global USE_CUDA
+    USE_CUDA = True
+
+    if model_name == "SAM(sam_vit_b)":
+        if not USE_CUDA:
+            gr.warning("GPU not detected. Running SAM on CPU, ~30s/image.")
+        result = image_sam_feature([image], node_type=node_type, layer=layer)
+    elif model_name == 'MobileSAM':
+        result = image_mobilesam_feature([image], node_type=node_type, layer=layer)
+    elif model_name == "DiNO(dinov2_vitb14_reg)":
+        result = image_dino_feature([image], node_type=node_type, layer=layer)
+    elif model_name == "CLIP(openai/clip-vit-base-patch16)":
+        result = image_clip_feature([image], node_type=node_type, layer=layer)
+    else:
+        raise ValueError(f"Model {model_name} not supported.")
+
+    USE_CUDA = False
+    return result
+
 def extract_features(images, model_name="sam", node_type="block", layer=-1):
     # Compute the cache key
     cache_key = compute_hash(images, model_name, node_type, layer)
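The run_model_on_image function added above is what actually requests hardware on a ZeroGPU Space: spaces.GPU (from the Hugging Face spaces package imported at the top of the file) hands the decorated call a GPU for at most duration seconds. A minimal, self-contained sketch of that pattern, using a placeholder model and function rather than the ones in app.py:

import gradio as gr
import spaces          # provided on Hugging Face Spaces (ZeroGPU)
import torch

# Placeholder model, constructed on the CPU at import time.
model = torch.nn.Linear(4, 2)


@spaces.GPU(duration=30)            # request a GPU for up to 30 s per call
def predict(x1, x2, x3, x4):
    model.cuda()                    # move weights onto the allocated GPU
    x = torch.tensor([[x1, x2, x3, x4]], dtype=torch.float32, device="cuda")
    y = model(x).detach().cpu()
    model.cpu()                     # hand the memory back before the allocation ends
    return y.tolist()


demo = gr.Interface(fn=predict, inputs=["number"] * 4, outputs="json")

if __name__ == "__main__":
    demo.launch()

With duration=30, each decorated call is expected to finish inside that window, which is presumably why extract_features below passes a single image per call.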
@@ -514,20 +556,7 @@ def extract_features(images, model_name="sam", node_type="block", layer=-1):
         print("Cache hit!")
         return cache[cache_key]
 
-
-    # Compute the result if not in cache
-    if model_name == "SAM(sam_vit_b)":
-        if not use_cuda:
-            gr.warning("GPU not detected. Running SAM on CPU, ~30s/image.")
-        result = image_sam_feature(images, node_type=node_type, layer=layer)
-    elif model_name == 'MobileSAM':
-        result = image_mobilesam_feature(images, node_type=node_type, layer=layer)
-    elif model_name == "DiNO(dinov2_vitb14_reg)":
-        result = image_dino_feature(images, node_type=node_type, layer=layer)
-    elif model_name == "CLIP(openai/clip-vit-base-patch16)":
-        result = image_clip_feature(images, node_type=node_type, layer=layer)
-    else:
-        raise ValueError(f"Model {model_name} not supported.")
+    result = run_model_on_image(images[0], model_name=model_name, node_type=node_type, layer=layer)
 
     # Store the result in the cache
     cache[cache_key] = result
@@ -550,11 +579,11 @@ def compute_ncut(
     eigvecs, eigvals = NCUT(
         num_eig=num_eig,
         num_sample=num_sample_ncut,
-        device="
+        device="cpu",
         affinity_focal_gamma=affinity_focal_gamma,
         knn=knn_ncut,
     ).fit_transform(features.reshape(-1, features.shape[-1]))
-    print(f"NCUT time: {time.time() - start:.2f}s")
+    print(f"NCUT time (cpu): {time.time() - start:.2f}s")
 
     start = time.time()
     X_3d, rgb = rgb_from_tsne_3d(
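With feature extraction handled inside the GPU-decorated call, the hunk above pins the NCUT eigendecomposition to the CPU (and the t-SNE coloring below gets the same treatment). In isolation, the call looks roughly like the sketch below; the import path, the first argument to rgb_from_tsne_3d, and the concrete parameter values are illustrative assumptions, not taken from app.py:

import torch
from ncut_pytorch import NCUT, rgb_from_tsne_3d   # assumed import path

# Dummy stand-in for the (images, h, w, channels) feature maps returned by the backbones.
features = torch.randn(6, 64, 64, 768)

eigvecs, eigvals = NCUT(
    num_eig=100,                 # example value
    num_sample=10000,            # example value
    device="cpu",                # keep the spectral step off the short GPU allocation
    affinity_focal_gamma=0.3,    # example value
    knn=10,                      # example value
).fit_transform(features.reshape(-1, features.shape[-1]))

# Assumes eigvecs is the first positional argument, as in compute_ncut.
X_3d, rgb = rgb_from_tsne_3d(eigvecs, perplexity=150, knn=10)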
@@ -563,7 +592,7 @@ def compute_ncut(
         perplexity=perplexity,
         knn=knn_tsne,
     )
-    print(f"t-SNE time: {time.time() - start:.2f}s")
+    print(f"t-SNE time (cpu): {time.time() - start:.2f}s")
 
     # print("input shape:", features.shape)
     # print("output shape:", rgb.shape)
@@ -613,7 +642,7 @@ def main_fn(
     features = extract_features(
         images, model_name=model_name, node_type=node_type, layer=layer
     )
-    print(f"Feature extraction time: {time.time() - start:.2f}s")
+    print(f"Feature extraction time (gpu): {time.time() - start:.2f}s")
 
     rgb = compute_ncut(
         features,
|