huzey committed 3f7fee9 (1 parent: a48bd1b)

update gpu

Files changed (1):
  1. app.py +71 -42
app.py CHANGED
@@ -11,11 +11,11 @@ import time
 
 import gradio as gr
 
-use_cuda = torch.cuda.is_available()
+import spaces
 
-# use_cuda = False
+USE_CUDA = torch.cuda.is_available()
 
-print("CUDA is available:", use_cuda)
+print("CUDA is available:", USE_CUDA)
 
 class MobileSAM(nn.Module):
     def __init__(self, **kwargs):
@@ -32,7 +32,7 @@ class MobileSAM(nn.Module):
             with open(sam_checkpoint, 'wb') as f:
                 f.write(r.content)
 
-        device = 'cuda' if use_cuda else 'cpu'
+        device = 'cuda' if USE_CUDA else 'cpu'
 
         mobile_sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
 
@@ -135,6 +135,7 @@ class MobileSAM(nn.Module):
             block_outputs.append(blk.block_output)
         return attn_outputs, mlp_outputs, block_outputs
 
+mobilesam = MobileSAM()
 
 def image_mobilesam_feature(
     images,
@@ -152,13 +153,15 @@ def image_mobilesam_feature(
     )
 
 
-    feat_extractor = MobileSAM()
+    feat_extractor = mobilesam
+    if USE_CUDA:
+        feat_extractor = feat_extractor.cuda()
 
     # attn_outputs, mlp_outputs, block_outputs = [], [], []
     outputs = []
     for i, image in enumerate(images):
         torch_image = transform(image)
-        if use_cuda:
+        if USE_CUDA:
             torch_image = torch_image.cuda()
         attn_output, mlp_output, block_output = feat_extractor(
             torch_image.unsqueeze(0)
@@ -172,15 +175,25 @@ def image_mobilesam_feature(
         out = out[layer]
         outputs.append(out.cpu())
     outputs = torch.cat(outputs, dim=0)
+
+    mobilesam = mobilesam.cpu()
     return outputs
 
 
 
 class SAM(torch.nn.Module):
-    def __init__(self, checkpoint="/data/sam_model/sam_vit_b_01ec64.pth", **kwargs):
+    def __init__(self, **kwargs):
         super().__init__(**kwargs)
         from segment_anything import sam_model_registry, SamPredictor
         from segment_anything.modeling.sam import Sam
+
+        checkpoint = "sam_vit_b_01ec64.pth"
+        if not os.path.exists(checkpoint):
+            checkpoint_url = 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth'
+            import requests
+            r = requests.get(checkpoint_url)
+            with open(checkpoint, 'wb') as f:
+                f.write(r.content)
 
         sam: Sam = sam_model_registry["vit_b"](checkpoint=checkpoint)
 
@@ -215,7 +228,7 @@ class SAM(torch.nn.Module):
 
         self.image_encoder = sam.image_encoder
         self.image_encoder.eval()
-        if use_cuda:
+        if USE_CUDA:
             self.image_encoder = self.image_encoder.cuda()
 
     @torch.no_grad()
@@ -234,6 +247,7 @@ class SAM(torch.nn.Module):
         block_outputs = torch.stack(block_outputs)
         return attn_outputs, mlp_outputs, block_outputs
 
+sam = SAM()
 
 def image_sam_feature(
     images,
@@ -249,22 +263,16 @@ def image_sam_feature(
             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
         ]
     )
-
-    checkpoint = "sam_vit_b_01ec64.pth"
-    if not os.path.exists(checkpoint):
-        checkpoint_url = 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth'
-        import requests
-        r = requests.get(checkpoint_url)
-        with open(checkpoint, 'wb') as f:
-            f.write(r.content)
 
-    feat_extractor = SAM(checkpoint=checkpoint)
+    feat_extractor = sam
+    if USE_CUDA:
+        feat_extractor = feat_extractor.cuda()
 
     # attn_outputs, mlp_outputs, block_outputs = [], [], []
     outputs = []
     for i, image in enumerate(images):
         torch_image = transform(image)
-        if use_cuda:
+        if USE_CUDA:
             torch_image = torch_image.cuda()
         attn_output, mlp_output, block_output = feat_extractor(
             torch_image.unsqueeze(0)
@@ -278,6 +286,9 @@ def image_sam_feature(
         out = out[layer]
         outputs.append(out.cpu())
     outputs = torch.cat(outputs, dim=0)
+
+    sam = sam.cpu()
+
     return outputs
 
 
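Note on the hunks above: each backbone (MobileSAM and SAM here, DiNOv2 and CLIP below) is now instantiated once at module import and only shuttled onto the GPU for the duration of a feature-extraction call, then moved back to CPU. A minimal sketch of that pattern with a placeholder model follows; the `global` declaration is an assumption about the intent, since in plain Python reassigning the module-level name inside the function (as in `mobilesam = mobilesam.cpu()`) otherwise makes it a local variable.

```python
import torch

USE_CUDA = torch.cuda.is_available()

# Placeholder module-level model; in app.py this would be MobileSAM(), SAM(), etc.
extractor = torch.nn.Linear(8, 4).eval()

def extract(batch: torch.Tensor) -> torch.Tensor:
    global extractor                 # needed because the shared name is reassigned below
    model = extractor
    if USE_CUDA:
        model = model.cuda()         # borrow the GPU only for this call
        batch = batch.cuda()
    with torch.no_grad():
        out = model(batch)
    out = out.cpu()                  # keep results on CPU for caching / plotting
    extractor = extractor.cpu()      # release GPU memory between requests
    return out

print(extract(torch.randn(2, 8)).shape)  # torch.Size([2, 4])
```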
 
@@ -287,7 +298,7 @@ class DiNOv2(torch.nn.Module):
         self.dinov2 = torch.hub.load("facebookresearch/dinov2", ver)
         self.dinov2.requires_grad_(False)
         self.dinov2.eval()
-        if use_cuda:
+        if USE_CUDA:
             self.dinov2 = self.dinov2.cuda()
 
     def new_block_forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -325,6 +336,7 @@ class DiNOv2(torch.nn.Module):
         block_outputs = torch.stack(block_outputs)
         return attn_outputs, mlp_outputs, block_outputs
 
+dinov2 = DiNOv2()
 
 def image_dino_feature(images, resolution=(448, 448), node_type="block", layer=-1):
 
@@ -336,12 +348,14 @@ def image_dino_feature(images, resolution=(448, 448), node_type="block", layer=-1):
         ]
     )
 
-    feat_extractor = DiNOv2()
+    feat_extractor = dinov2
+    if USE_CUDA:
+        feat_extractor = feat_extractor.cuda()
 
    outputs = []
     for i, image in enumerate(images):
         torch_image = transform(image)
-        if use_cuda:
+        if USE_CUDA:
             torch_image = torch_image.cuda()
         attn_output, mlp_output, block_output = feat_extractor(
             torch_image.unsqueeze(0)
@@ -356,6 +370,8 @@ def image_dino_feature(images, resolution=(448, 448), node_type="block", layer=-1):
         outputs.append(out.cpu())
     outputs = torch.cat(outputs, dim=0)
     outputs = rearrange(outputs[:, 5:, :], "b (h w) c -> b h w c", h=32, w=32)
+
+    dinov2 = dinov2.cpu()
     return outputs
 
 
@@ -368,7 +384,7 @@ class CLIP(torch.nn.Module):
         model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
         # processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
         self.model = model.eval()
-        if use_cuda:
+        if USE_CUDA:
             self.model = self.model.cuda()
 
     def new_forward(
@@ -424,6 +440,7 @@ class CLIP(torch.nn.Module):
         block_outputs = torch.stack(block_outputs)
         return attn_outputs, mlp_outputs, block_outputs
 
+clip = CLIP()
 
 def image_clip_feature(
     images, resolution=(224, 224), node_type="block", layer=-1
@@ -442,12 +459,14 @@ def image_clip_feature(
         ]
     )
 
-    feat_extractor = CLIP()
+    feat_extractor = clip
+    if USE_CUDA:
+        feat_extractor = feat_extractor.cuda()
 
     outputs = []
     for i, image in enumerate(images):
         torch_image = transform(image)
-        if use_cuda:
+        if USE_CUDA:
             torch_image = torch_image.cuda()
         attn_output, mlp_output, block_output = feat_extractor(
             torch_image.unsqueeze(0)
@@ -461,6 +480,8 @@ def image_clip_feature(
         out = out[layer]
         outputs.append(out.cpu())
     outputs = torch.cat(outputs, dim=0)
+
+    clip = clip.cpu()
     return outputs
 
 
@@ -505,6 +526,27 @@ def compute_hash(*args, **kwargs):
     return hasher.hexdigest()
 
 
+@spaces.GPU(duration=30)
+def run_model_on_image(image, model_name="sam", node_type="block", layer=-1):
+    global USE_CUDA
+    USE_CUDA = True
+
+    if model_name == "SAM(sam_vit_b)":
+        if not USE_CUDA:
+            gr.warning("GPU not detected. Running SAM on CPU, ~30s/image.")
+        result = image_sam_feature([image], node_type=node_type, layer=layer)
+    elif model_name == 'MobileSAM':
+        result = image_mobilesam_feature([image], node_type=node_type, layer=layer)
+    elif model_name == "DiNO(dinov2_vitb14_reg)":
+        result = image_dino_feature([image], node_type=node_type, layer=layer)
+    elif model_name == "CLIP(openai/clip-vit-base-patch16)":
+        result = image_clip_feature([image], node_type=node_type, layer=layer)
+    else:
+        raise ValueError(f"Model {model_name} not supported.")
+
+    USE_CUDA = False
+    return result
+
 def extract_features(images, model_name="sam", node_type="block", layer=-1):
     # Compute the cache key
     cache_key = compute_hash(images, model_name, node_type, layer)
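The `@spaces.GPU(duration=30)` decorator added above is the Hugging Face Spaces ZeroGPU mechanism: a GPU is attached only while the decorated function runs, which is why all model calls are funneled through `run_model_on_image`. A minimal sketch of how such a decorated function is typically wired into a Gradio app follows; the placeholder model and the `gr.Interface` wiring are illustrative assumptions, not part of this commit.

```python
import gradio as gr
import spaces
import torch

model = torch.nn.Linear(8, 2)    # placeholder for a real feature extractor

@spaces.GPU(duration=30)         # request a ZeroGPU slice for up to ~30 s per call
def predict(seed: float) -> str:
    # Inside the decorated call, CUDA should be available on ZeroGPU hardware.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(1, 8, generator=torch.Generator().manual_seed(int(seed)))
    with torch.no_grad():
        y = model.to(device)(x.to(device))
    model.to("cpu")              # hand the GPU back cleanly between requests
    return f"ran on {device}: {y.cpu().tolist()}"

demo = gr.Interface(fn=predict, inputs=gr.Number(value=0), outputs="text")

if __name__ == "__main__":
    demo.launch()
```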
@@ -514,20 +556,7 @@ def extract_features(images, model_name="sam", node_type="block", layer=-1):
         print("Cache hit!")
         return cache[cache_key]
 
-
-    # Compute the result if not in cache
-    if model_name == "SAM(sam_vit_b)":
-        if not use_cuda:
-            gr.warning("GPU not detected. Running SAM on CPU, ~30s/image.")
-        result = image_sam_feature(images, node_type=node_type, layer=layer)
-    elif model_name == 'MobileSAM':
-        result = image_mobilesam_feature(images, node_type=node_type, layer=layer)
-    elif model_name == "DiNO(dinov2_vitb14_reg)":
-        result = image_dino_feature(images, node_type=node_type, layer=layer)
-    elif model_name == "CLIP(openai/clip-vit-base-patch16)":
-        result = image_clip_feature(images, node_type=node_type, layer=layer)
-    else:
-        raise ValueError(f"Model {model_name} not supported.")
+    result = run_model_on_image(images[0], model_name=model_name, node_type=node_type, layer=layer)
 
     # Store the result in the cache
     cache[cache_key] = result
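For context, `extract_features` still memoizes on `compute_hash(images, model_name, node_type, layer)`, so repeated requests skip the GPU path entirely. A self-contained sketch of that content-hash caching idea is below; the sha256-over-pickled-arguments body is an assumption, since the commit does not show `compute_hash` itself.

```python
import hashlib
import pickle

cache = {}

def compute_hash(*args, **kwargs):
    # Assumption: hash the pickled arguments; app.py's real compute_hash may differ.
    hasher = hashlib.sha256()
    hasher.update(pickle.dumps((args, sorted(kwargs.items()))))
    return hasher.hexdigest()

def expensive(x):
    cache_key = compute_hash(x)
    if cache_key in cache:
        print("Cache hit!")
        return cache[cache_key]
    result = x * x               # stand-in for the real feature extraction
    cache[cache_key] = result
    return result

print(expensive(7), expensive(7))  # the second call is served from the cache
```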
@@ -550,11 +579,11 @@ def compute_ncut(
     eigvecs, eigvals = NCUT(
         num_eig=num_eig,
         num_sample=num_sample_ncut,
-        device="cuda" if use_cuda else "cpu",
+        device="cpu",
         affinity_focal_gamma=affinity_focal_gamma,
         knn=knn_ncut,
     ).fit_transform(features.reshape(-1, features.shape[-1]))
-    print(f"NCUT time: {time.time() - start:.2f}s")
+    print(f"NCUT time (cpu): {time.time() - start:.2f}s")
 
     start = time.time()
     X_3d, rgb = rgb_from_tsne_3d(
@@ -563,7 +592,7 @@ def compute_ncut(
         perplexity=perplexity,
         knn=knn_tsne,
     )
-    print(f"t-SNE time: {time.time() - start:.2f}s")
+    print(f"t-SNE time (cpu): {time.time() - start:.2f}s")
 
     # print("input shape:", features.shape)
     # print("output shape:", rgb.shape)
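With the GPU confined to the decorated extraction call, the NCUT eigendecomposition and the t-SNE coloring above are pinned to CPU and timed per stage. A rough timing sketch that reuses the same calls on random features is below; the `from ncut_pytorch import ...` line, the feature shape, and the parameter values are assumptions for illustration.

```python
import time
import torch
from ncut_pytorch import NCUT, rgb_from_tsne_3d  # assumed import, used elsewhere in app.py

# Illustrative stand-in features: (batch, height, width, channels).
features = torch.randn(4, 32, 32, 64)

start = time.time()
eigvecs, eigvals = NCUT(
    num_eig=50,                  # illustrative values; the app passes its UI settings
    num_sample=10000,
    device="cpu",                # this commit pins NCUT to CPU
    affinity_focal_gamma=0.3,
    knn=10,
).fit_transform(features.reshape(-1, features.shape[-1]))
print(f"NCUT time (cpu): {time.time() - start:.2f}s")

start = time.time()
X_3d, rgb = rgb_from_tsne_3d(    # assumption: the eigenvectors are the first argument
    eigvecs,
    perplexity=150,
    knn=10,
)
print(f"t-SNE time (cpu): {time.time() - start:.2f}s")
```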
@@ -613,7 +642,7 @@ def main_fn(
     features = extract_features(
         images, model_name=model_name, node_type=node_type, layer=layer
     )
-    print(f"Feature extraction time: {time.time() - start:.2f}s")
+    print(f"Feature extraction time (gpu): {time.time() - start:.2f}s")
 
     rgb = compute_ncut(
         features,
 