Spaces:
Runtime error
Runtime error
Update app
Browse files- app.py +31 -20
- examples/2007_000039.jpg +0 -0
- examples/2007_001586.jpg +0 -0
- examples/2007_009446.jpg +0 -0
- examples/2008_000099.jpg +0 -0
- examples/2008_000499.jpg +0 -0
- examples/2008_000705.jpg +0 -0
- examples/2008_000764.jpg +0 -0
- examples/2010_001256.jpg +0 -0
app.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1 |
-
import
|
2 |
import os.path
|
3 |
import sys
|
4 |
-
from os.path import splitext
|
5 |
|
6 |
import gradio as gr
|
7 |
import matplotlib.pyplot as plt
|
@@ -19,6 +18,7 @@ from skimage.color import label2rgb
|
|
19 |
from torch.utils.hooks import RemovableHandle
|
20 |
from torchvision import transforms
|
21 |
from torchvision.utils import make_grid
|
|
|
22 |
|
23 |
|
24 |
def get_model(name: str):
|
@@ -67,6 +67,8 @@ def get_diagonal(W: scipy.sparse.csr_matrix, threshold: float = 1e-12):
|
|
67 |
model_name = 'dino_vitb16' # TODOL Figure out how to make this user-editable
|
68 |
K = 5
|
69 |
|
|
|
|
|
70 |
|
71 |
# Load model
|
72 |
model, val_transform, patch_size, num_heads = get_model(model_name)
|
@@ -122,7 +124,7 @@ def segment(inp: Image):
|
|
122 |
|
123 |
# Remove hook from the model
|
124 |
handle.remove()
|
125 |
-
|
126 |
# Normalize features
|
127 |
normalize = True
|
128 |
if normalize:
|
@@ -160,27 +162,36 @@ def segment(inp: Image):
|
|
160 |
eigenvectors[k] = 0 - eigenvectors[k]
|
161 |
|
162 |
# Arrange eigenvectors into grid
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
eigenvector =
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
|
177 |
# Placeholders
|
178 |
-
input_placeholders = GradioInputImage(
|
179 |
-
output_placeholders = GradioOutputImage(type="numpy", label=f"Eigenvectors")
|
180 |
-
|
181 |
|
182 |
# Metadata
|
183 |
-
examples = [
|
|
|
|
|
|
|
184 |
title = "Deep Spectral Segmentation"
|
185 |
description = "Deep spectral segmentation..."
|
186 |
thumbnail = "https://raw.githubusercontent.com/gradio-app/hub-echonet/master/thumbnail.png"
|
|
|
1 |
+
import io
|
2 |
import os.path
|
3 |
import sys
|
|
|
4 |
|
5 |
import gradio as gr
|
6 |
import matplotlib.pyplot as plt
|
|
|
18 |
from torch.utils.hooks import RemovableHandle
|
19 |
from torchvision import transforms
|
20 |
from torchvision.utils import make_grid
|
21 |
+
from matplotlib.pyplot import get_cmap
|
22 |
|
23 |
|
24 |
def get_model(name: str):
|
|
|
67 |
model_name = 'dino_vitb16' # TODOL Figure out how to make this user-editable
|
68 |
K = 5
|
69 |
|
70 |
+
# Fixed parameters
|
71 |
+
MAX_SIZE = 384
|
72 |
|
73 |
# Load model
|
74 |
model, val_transform, patch_size, num_heads = get_model(model_name)
|
|
|
124 |
|
125 |
# Remove hook from the model
|
126 |
handle.remove()
|
127 |
+
|
128 |
# Normalize features
|
129 |
normalize = True
|
130 |
if normalize:
|
|
|
162 |
eigenvectors[k] = 0 - eigenvectors[k]
|
163 |
|
164 |
# Arrange eigenvectors into grid
|
165 |
+
cmap = get_cmap('viridis')
|
166 |
+
output_images = []
|
167 |
+
for i in range(1, K + 1):
|
168 |
+
eigenvector = eigenvectors[i].reshape(1, 1, H_patch, W_patch) # .reshape(1, 1, H_pad, W_pad)
|
169 |
+
eigenvector: torch.Tensor = F.interpolate(eigenvector, size=(H_pad, W_pad), mode='bilinear', align_corners=False) # slightly off, but for visualizations this is okay
|
170 |
+
buffer = io.BytesIO()
|
171 |
+
plt.imsave(buffer, eigenvector.squeeze().numpy(), format='png') # save to a temporary location
|
172 |
+
buffer.seek(0)
|
173 |
+
eigenvector_vis = Image.open(buffer).convert('RGB')
|
174 |
+
# eigenvector_vis = TF.to_tensor(eigenvector_vis).unsqueeze(0)
|
175 |
+
eigenvector_vis = np.array(eigenvector_vis)
|
176 |
+
output_images.append(eigenvector_vis)
|
177 |
+
# output_images = torch.cat(output_images, dim=0)
|
178 |
+
# output_images = make_grid(output_images, nrow=8, pad_value=1)
|
179 |
+
|
180 |
+
# # Postprocess for Gradio
|
181 |
+
# output_images = np.array(TF.to_pil_image(output_images))
|
182 |
+
print(f'{len(output_images)=}')
|
183 |
+
return output_images
|
184 |
|
185 |
# Placeholders
|
186 |
+
input_placeholders = GradioInputImage(source="upload", tool="editor", type="pil")
|
187 |
+
# output_placeholders = GradioOutputImage(type="numpy", label=f"Eigenvectors")
|
188 |
+
output_placeholders = [GradioOutputImage(type="numpy", label=f"Eigenvector {i}") for i in range(K)]
|
189 |
|
190 |
# Metadata
|
191 |
+
examples = [f"examples/{stem}.jpg" for stem in [
|
192 |
+
'2008_000099', '2008_000499', '2007_009446', '2007_001586', '2010_001256', '2008_000764', '2008_000705', # '2007_000039'
|
193 |
+
]]
|
194 |
+
|
195 |
title = "Deep Spectral Segmentation"
|
196 |
description = "Deep spectral segmentation..."
|
197 |
thumbnail = "https://raw.githubusercontent.com/gradio-app/hub-echonet/master/thumbnail.png"
|
examples/2007_000039.jpg
ADDED
examples/2007_001586.jpg
ADDED
examples/2007_009446.jpg
ADDED
examples/2008_000099.jpg
ADDED
examples/2008_000499.jpg
ADDED
examples/2008_000705.jpg
ADDED
examples/2008_000764.jpg
ADDED
examples/2010_001256.jpg
ADDED