lukemelas commited on
Commit
cd30264
·
1 Parent(s): 59b4598
Files changed (1) hide show
  1. app.py +27 -21
app.py CHANGED
@@ -1,5 +1,6 @@
1
- from collections import namedtuple
2
  import io
 
3
  from typing import Tuple
4
 
5
  import gradio as gr
@@ -48,7 +49,11 @@ def get_model(name: str):
48
  def get_transform(name: str):
49
  if any(x in name for x in ('dino', 'mocov3', 'convnext', )):
50
  normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
51
- transform = transforms.Compose([transforms.ToTensor(), normalize])
 
 
 
 
52
  else:
53
  raise NotImplementedError()
54
  return transform
@@ -60,17 +65,28 @@ def get_diagonal(W: scipy.sparse.csr_matrix, threshold: float = 1e-12):
60
  D = scipy.sparse.diags(D)
61
  return D
62
 
 
 
63
 
64
  # Parameters
65
  model_name = 'dino_vitb16' # TODO: Figure out how to make this user-editable
66
  K = 5
67
 
68
- # Fixed parameters
69
- MAX_SIZE = 384
70
-
71
  # Load model
72
  model, val_transform, patch_size, num_heads = get_model(model_name)
73
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  # GPU
76
  if torch.cuda.is_available():
@@ -90,18 +106,6 @@ def segment(inp: Image):
90
  images: torch.Tensor = val_transform(inp)
91
  images = images.unsqueeze(0).to(device)
92
 
93
- # Add hook
94
- which_block = -1
95
- if 'dino' in model_name or 'mocov3' in model_name:
96
- feat_out = {}
97
- def hook_fn_forward_qkv(module, input, output):
98
- feat_out["qkv"] = output
99
- handle: RemovableHandle = model._modules["blocks"][which_block]._modules["attn"]._modules["qkv"].register_forward_hook(
100
- hook_fn_forward_qkv
101
- )
102
- else:
103
- raise ValueError(model_name)
104
-
105
  # Reshape image
106
  P = patch_size
107
  B, C, H, W = images.shape
@@ -119,9 +123,6 @@ def segment(inp: Image):
119
  feats = output_qkv[1].transpose(1, 2).reshape(B, T, -1)[:, 1:, :].squeeze(0)
120
  else:
121
  raise ValueError(model_name)
122
-
123
- # Remove hook from the model
124
- handle.remove()
125
 
126
  # Normalize features
127
  normalize = True
@@ -160,7 +161,7 @@ def segment(inp: Image):
160
  eigenvectors[k] = 0 - eigenvectors[k]
161
 
162
  # Arrange eigenvectors into grid
163
- cmap = get_cmap('viridis')
164
  output_images = []
165
  # eigenvectors_upscaled = []
166
  for i in range(1, K + 1):
@@ -209,6 +210,11 @@ def segment(inp: Image):
209
  # # Postprocess for Gradio
210
  # output_images = np.array(TF.to_pil_image(output_images))
211
  print(f'{len(output_images)=}')
 
 
 
 
 
212
  return output_images
213
 
214
  # Placeholders
 
1
+ import gc
2
  import io
3
+ from collections import namedtuple
4
  from typing import Tuple
5
 
6
  import gradio as gr
 
49
  def get_transform(name: str):
50
  if any(x in name for x in ('dino', 'mocov3', 'convnext', )):
51
  normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
52
+ transform = transforms.Compose([
53
+ transforms.Resize(size=256, interpolation=TF.InterpolationMode.BICUBIC, max_size=384),
54
+ transforms.ToTensor(),
55
+ normalize
56
+ ])
57
  else:
58
  raise NotImplementedError()
59
  return transform
 
65
  D = scipy.sparse.diags(D)
66
  return D
67
 
68
+ # Cache
69
+ torch.cuda.empty_cache()
70
 
71
  # Parameters
72
  model_name = 'dino_vitb16' # TODO: Figure out how to make this user-editable
73
  K = 5
74
 
 
 
 
75
  # Load model
76
  model, val_transform, patch_size, num_heads = get_model(model_name)
77
 
78
+ # Add hook
79
+ which_block = -1
80
+ if 'dino' in model_name or 'mocov3' in model_name:
81
+ feat_out = {}
82
+ def hook_fn_forward_qkv(module, input, output):
83
+ feat_out["qkv"] = output
84
+ handle: RemovableHandle = model._modules["blocks"][which_block]._modules["attn"]._modules["qkv"].register_forward_hook(
85
+ hook_fn_forward_qkv
86
+ )
87
+ else:
88
+ raise ValueError(model_name)
89
+
90
 
91
  # GPU
92
  if torch.cuda.is_available():
 
106
  images: torch.Tensor = val_transform(inp)
107
  images = images.unsqueeze(0).to(device)
108
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  # Reshape image
110
  P = patch_size
111
  B, C, H, W = images.shape
 
123
  feats = output_qkv[1].transpose(1, 2).reshape(B, T, -1)[:, 1:, :].squeeze(0)
124
  else:
125
  raise ValueError(model_name)
 
 
 
126
 
127
  # Normalize features
128
  normalize = True
 
161
  eigenvectors[k] = 0 - eigenvectors[k]
162
 
163
  # Arrange eigenvectors into grid
164
+ # cmap = get_cmap('viridis')
165
  output_images = []
166
  # eigenvectors_upscaled = []
167
  for i in range(1, K + 1):
 
210
  # # Postprocess for Gradio
211
  # output_images = np.array(TF.to_pil_image(output_images))
212
  print(f'{len(output_images)=}')
213
+
214
+ # Garbage collection and other memory-related things
215
+ gc.collect()
216
+ del eigenvector, eigenvector_vis, eigenvectors, W_comb, D_comb
217
+
218
  return output_images
219
 
220
  # Placeholders