Spaces:
Runtime error
Runtime error
Update
Browse files
app.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
-
|
2 |
import io
|
|
|
3 |
from typing import Tuple
|
4 |
|
5 |
import gradio as gr
|
@@ -48,7 +49,11 @@ def get_model(name: str):
|
|
48 |
def get_transform(name: str):
|
49 |
if any(x in name for x in ('dino', 'mocov3', 'convnext', )):
|
50 |
normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
|
51 |
-
transform = transforms.Compose([
|
|
|
|
|
|
|
|
|
52 |
else:
|
53 |
raise NotImplementedError()
|
54 |
return transform
|
@@ -60,17 +65,28 @@ def get_diagonal(W: scipy.sparse.csr_matrix, threshold: float = 1e-12):
|
|
60 |
D = scipy.sparse.diags(D)
|
61 |
return D
|
62 |
|
|
|
|
|
63 |
|
64 |
# Parameters
|
65 |
model_name = 'dino_vitb16' # TODO: Figure out how to make this user-editable
|
66 |
K = 5
|
67 |
|
68 |
-
# Fixed parameters
|
69 |
-
MAX_SIZE = 384
|
70 |
-
|
71 |
# Load model
|
72 |
model, val_transform, patch_size, num_heads = get_model(model_name)
|
73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
|
75 |
# GPU
|
76 |
if torch.cuda.is_available():
|
@@ -90,18 +106,6 @@ def segment(inp: Image):
|
|
90 |
images: torch.Tensor = val_transform(inp)
|
91 |
images = images.unsqueeze(0).to(device)
|
92 |
|
93 |
-
# Add hook
|
94 |
-
which_block = -1
|
95 |
-
if 'dino' in model_name or 'mocov3' in model_name:
|
96 |
-
feat_out = {}
|
97 |
-
def hook_fn_forward_qkv(module, input, output):
|
98 |
-
feat_out["qkv"] = output
|
99 |
-
handle: RemovableHandle = model._modules["blocks"][which_block]._modules["attn"]._modules["qkv"].register_forward_hook(
|
100 |
-
hook_fn_forward_qkv
|
101 |
-
)
|
102 |
-
else:
|
103 |
-
raise ValueError(model_name)
|
104 |
-
|
105 |
# Reshape image
|
106 |
P = patch_size
|
107 |
B, C, H, W = images.shape
|
@@ -119,9 +123,6 @@ def segment(inp: Image):
|
|
119 |
feats = output_qkv[1].transpose(1, 2).reshape(B, T, -1)[:, 1:, :].squeeze(0)
|
120 |
else:
|
121 |
raise ValueError(model_name)
|
122 |
-
|
123 |
-
# Remove hook from the model
|
124 |
-
handle.remove()
|
125 |
|
126 |
# Normalize features
|
127 |
normalize = True
|
@@ -160,7 +161,7 @@ def segment(inp: Image):
|
|
160 |
eigenvectors[k] = 0 - eigenvectors[k]
|
161 |
|
162 |
# Arrange eigenvectors into grid
|
163 |
-
cmap = get_cmap('viridis')
|
164 |
output_images = []
|
165 |
# eigenvectors_upscaled = []
|
166 |
for i in range(1, K + 1):
|
@@ -209,6 +210,11 @@ def segment(inp: Image):
|
|
209 |
# # Postprocess for Gradio
|
210 |
# output_images = np.array(TF.to_pil_image(output_images))
|
211 |
print(f'{len(output_images)=}')
|
|
|
|
|
|
|
|
|
|
|
212 |
return output_images
|
213 |
|
214 |
# Placeholders
|
|
|
1 |
+
import gc
|
2 |
import io
|
3 |
+
from collections import namedtuple
|
4 |
from typing import Tuple
|
5 |
|
6 |
import gradio as gr
|
|
|
49 |
def get_transform(name: str):
|
50 |
if any(x in name for x in ('dino', 'mocov3', 'convnext', )):
|
51 |
normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
|
52 |
+
transform = transforms.Compose([
|
53 |
+
transforms.Resize(size=256, interpolation=TF.InterpolationMode.BICUBIC, max_size=384),
|
54 |
+
transforms.ToTensor(),
|
55 |
+
normalize
|
56 |
+
])
|
57 |
else:
|
58 |
raise NotImplementedError()
|
59 |
return transform
|
|
|
65 |
D = scipy.sparse.diags(D)
|
66 |
return D
|
67 |
|
68 |
+
# Cache
|
69 |
+
torch.cuda.empty_cache()
|
70 |
|
71 |
# Parameters
|
72 |
model_name = 'dino_vitb16' # TODO: Figure out how to make this user-editable
|
73 |
K = 5
|
74 |
|
|
|
|
|
|
|
75 |
# Load model
|
76 |
model, val_transform, patch_size, num_heads = get_model(model_name)
|
77 |
|
78 |
+
# Add hook
|
79 |
+
which_block = -1
|
80 |
+
if 'dino' in model_name or 'mocov3' in model_name:
|
81 |
+
feat_out = {}
|
82 |
+
def hook_fn_forward_qkv(module, input, output):
|
83 |
+
feat_out["qkv"] = output
|
84 |
+
handle: RemovableHandle = model._modules["blocks"][which_block]._modules["attn"]._modules["qkv"].register_forward_hook(
|
85 |
+
hook_fn_forward_qkv
|
86 |
+
)
|
87 |
+
else:
|
88 |
+
raise ValueError(model_name)
|
89 |
+
|
90 |
|
91 |
# GPU
|
92 |
if torch.cuda.is_available():
|
|
|
106 |
images: torch.Tensor = val_transform(inp)
|
107 |
images = images.unsqueeze(0).to(device)
|
108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
# Reshape image
|
110 |
P = patch_size
|
111 |
B, C, H, W = images.shape
|
|
|
123 |
feats = output_qkv[1].transpose(1, 2).reshape(B, T, -1)[:, 1:, :].squeeze(0)
|
124 |
else:
|
125 |
raise ValueError(model_name)
|
|
|
|
|
|
|
126 |
|
127 |
# Normalize features
|
128 |
normalize = True
|
|
|
161 |
eigenvectors[k] = 0 - eigenvectors[k]
|
162 |
|
163 |
# Arrange eigenvectors into grid
|
164 |
+
# cmap = get_cmap('viridis')
|
165 |
output_images = []
|
166 |
# eigenvectors_upscaled = []
|
167 |
for i in range(1, K + 1):
|
|
|
210 |
# # Postprocess for Gradio
|
211 |
# output_images = np.array(TF.to_pil_image(output_images))
|
212 |
print(f'{len(output_images)=}')
|
213 |
+
|
214 |
+
# Garbage collection and other memory-related things
|
215 |
+
gc.collect()
|
216 |
+
del eigenvector, eigenvector_vis, eigenvectors, W_comb, D_comb
|
217 |
+
|
218 |
return output_images
|
219 |
|
220 |
# Placeholders
|