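"""Gradio demo: Deep Spectral Segmentation.

Extracts dense self-supervised ViT features from an input image, builds a
patch-level affinity matrix, and visualizes the first K nontrivial
eigenvectors of the resulting graph Laplacian as soft segments.
"""
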
import io
import os.path
import sys

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import scipy.sparse
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms.functional as TF
from gradio.inputs import Image as GradioInputImage
from gradio.outputs import Image as GradioOutputImage
from matplotlib.pyplot import get_cmap
from PIL import Image
from scipy.sparse.linalg import eigsh
from skimage.color import label2rgb
from torch.utils.hooks import RemovableHandle
from torchvision import transforms
from torchvision.utils import make_grid


def get_model(name: str):
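    """Load a self-supervised ViT backbone by name and return
    (model, val_transform, patch_size, num_heads)."""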
    if 'dino' in name:
        model = torch.hub.load('facebookresearch/dino:main', name)
        model.fc = torch.nn.Identity()
        val_transform = get_transform(name)
        patch_size = model.patch_embed.patch_size
        num_heads = model.blocks[0].attn.num_heads
    elif name in ['mocov3_vits16', 'mocov3_vitb16']:
        # MoCo v3 uses the same ViT architecture as DINO, so load the DINO
        # architecture and overwrite its weights with the MoCo v3 checkpoint.
        model = torch.hub.load('facebookresearch/dino:main', name.replace('mocov3', 'dino'))
        checkpoint_file, size_char = {
            'mocov3_vits16': ('vit-s-300ep-timm-format.pth', 's'),
            'mocov3_vitb16': ('vit-b-300ep-timm-format.pth', 'b'),
        }[name]
        url = f'https://dl.fbaipublicfiles.com/moco-v3/vit-{size_char}-300ep/vit-{size_char}-300ep.pth.tar'
        checkpoint = torch.hub.load_state_dict_from_url(url)
        model.load_state_dict(checkpoint['model'])
        model.fc = torch.nn.Identity()
        val_transform = get_transform(name)
        patch_size = model.patch_embed.patch_size
        num_heads = model.blocks[0].attn.num_heads
    else:
        raise ValueError(f'Unsupported model: {name}')
    model = model.eval()
    return model, val_transform, patch_size, num_heads


def get_transform(name: str):
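    """Return the validation preprocessing transform (ImageNet normalization)."""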
    if any(x in name for x in ('dino', 'mocov3', 'convnext')):
        normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        transform = transforms.Compose([transforms.ToTensor(), normalize])
    else:
        raise NotImplementedError()
    return transform


def get_diagonal(W: scipy.sparse.csr_matrix, threshold: float = 1e-12):
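    """Build the degree matrix D of an affinity matrix W, with the row sums
    on the diagonal: D_ii = sum_j W_ij. Degrees below `threshold` are clamped
    to 1.0 to avoid division by zero when D is later inverted in the
    generalized eigenproblem. Dense arrays also work despite the annotation."""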
    D = W.dot(np.ones(W.shape[1], W.dtype))
    D[D < threshold] = 1.0  # Prevent division by zero.
    D = scipy.sparse.diags(D)
    return D


# Parameters
model_name = 'dino_vitb16'  # TODO: Figure out how to make this user-editable
K = 5

# Fixed parameters
MAX_SIZE = 384

# Load model
model, val_transform, patch_size, num_heads = get_model(model_name)

# GPU
if torch.cuda.is_available():
    print("CUDA is available, using GPU.")
    device = torch.device("cuda")
    model.to(device)
else:
    print("CUDA is not available, using CPU.")
    device = torch.device("cpu")


def segment(inp: Image.Image):
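    """Segment an input image: extract ViT features, build a patch affinity
    graph, and return visualizations of the first K nontrivial Laplacian
    eigenvectors."""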
    # NOTE: The image is already resized to the desired size.

    # Preprocess image
    images: torch.Tensor = val_transform(inp)
    images = images.unsqueeze(0).to(device)

    # Add a forward hook on the qkv projection of the last attention block,
    # so we can read out the per-token query/key/value features.
    which_block = -1
    if 'dino' in model_name or 'mocov3' in model_name:
        feat_out = {}

        def hook_fn_forward_qkv(module, input, output):
            feat_out["qkv"] = output

        handle: RemovableHandle = model._modules["blocks"][which_block]._modules["attn"]._modules["qkv"].register_forward_hook(
            hook_fn_forward_qkv
        )
    else:
        raise ValueError(model_name)

    # Reshape image
    P = patch_size
    B, C, H, W = images.shape
    H_patch, W_patch = H // P, W // P
    H_pad, W_pad = H_patch * P, W_patch * P
    T = H_patch * W_patch + 1  # number of tokens, add 1 for [CLS]

    # Crop image to be a multiple of the patch size
    images = images[:, :, :H_pad, :W_pad]

    # Extract features
    if 'dino' in model_name or 'mocov3' in model_name:
        # The forward pass is only needed to trigger the hook; its output is unused.
        model.get_intermediate_layers(images)[0].squeeze(0)
        # (3, B, num_heads, T, head_dim): unstack q, k, v
        output_qkv = feat_out["qkv"].reshape(B, T, 3, num_heads, -1).permute(2, 0, 3, 1, 4)
        # Use the key features, dropping the [CLS] token
        feats = output_qkv[1].transpose(1, 2).reshape(B, T, -1)[:, 1:, :].squeeze(0)
    else:
        raise ValueError(model_name)

    # Remove hook from the model
    handle.remove()

    # Normalize features
    normalize = True
    if normalize:
        feats = F.normalize(feats, p=2, dim=-1)
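
    # With unit-normalized features, the Gram matrix feats @ feats.T holds the
    # cosine similarity between every pair of patches: W_ij = <f_i, f_j>.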
    # Compute affinity matrix
    W_feat = (feats @ feats.T)

    # Feature affinities: keep only nonnegative (attractive) affinities
    threshold_at_zero = True
    if threshold_at_zero:
        W_feat = (W_feat * (W_feat > 0))
    W_feat = W_feat / W_feat.max()  # NOTE: with unit-normalized features the max is 1, so this is a no-op
    W_feat = W_feat.cpu().numpy()

    # NOTE: Here is where we would add the color information. For simplicity, we will not add it here.
    # W_comb = W_feat + W_color * image_color_lambda  # combination
    # D_comb = np.array(get_diagonal(W_comb).todense())  # is dense or sparse faster? not sure, should check
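
    # A minimal sketch of one way to build W_color (hypothetical, not used in
    # this demo): a Gaussian affinity over patch-averaged RGB colors. The
    # helper name, the bandwidth 0.1, and the patch-mean pooling are all
    # illustrative assumptions, not the method's exact formulation.
    # def get_color_affinity(patch_colors: np.ndarray) -> np.ndarray:
    #     # patch_colors: (H_patch * W_patch, 3) array of patch-mean RGB in [0, 1]
    #     dists = np.linalg.norm(patch_colors[:, None, :] - patch_colors[None, :, :], axis=-1)
    #     return np.exp(-(dists ** 2) / 0.1)  # dense Gaussian color affinity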

    # Diagonal
    W_comb = W_feat
    D_comb = np.array(get_diagonal(W_comb).todense())  # is dense or sparse faster? not sure, should check

    # Compute eigenvectors
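    # We solve the generalized eigenproblem (D - W) v = lambda * D v for the
    # K + 1 smallest eigenpairs (the normalized graph Laplacian, as in
    # normalized cuts). Shift-invert (sigma=0) is fast but can fail to
    # converge, in which case we fall back to directly requesting the
    # smallest-magnitude eigenvalues.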
    try:
        eigenvalues, eigenvectors = eigsh(D_comb - W_comb, k=(K + 1), sigma=0, which='LM', M=D_comb)
    except Exception:
        eigenvalues, eigenvectors = eigsh(D_comb - W_comb, k=(K + 1), which='SM', M=D_comb)
    eigenvalues = torch.from_numpy(eigenvalues)
    eigenvectors = torch.from_numpy(eigenvectors.T).float()

    # Resolve sign ambiguity
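    # Eigenvectors are only defined up to sign; flip each one so that the
    # positive region covers the minority of patches, which tends to make the
    # highlighted segment the foreground.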
    for k in range(eigenvectors.shape[0]):
        if 0.5 < torch.mean((eigenvectors[k] > 0).float()).item() < 1.0:  # reverse segment
            eigenvectors[k] = -eigenvectors[k]

    # Arrange eigenvectors into a grid of visualizations, skipping the first
    # (trivial, near-constant) eigenvector
    cmap = get_cmap('viridis')
    output_images = []
    for i in range(1, K + 1):
        eigenvector = eigenvectors[i].reshape(1, 1, H_patch, W_patch)
        eigenvector: torch.Tensor = F.interpolate(
            eigenvector, size=(H_pad, W_pad), mode='bilinear', align_corners=False
        )  # slightly off, but for visualizations this is okay
        buffer = io.BytesIO()
        plt.imsave(buffer, eigenvector.squeeze().numpy(), format='png', cmap=cmap)  # save to an in-memory buffer
        buffer.seek(0)
        eigenvector_vis = Image.open(buffer).convert('RGB')
        # eigenvector_vis = TF.to_tensor(eigenvector_vis).unsqueeze(0)
        eigenvector_vis = np.array(eigenvector_vis)
        output_images.append(eigenvector_vis)
    # output_images = torch.cat(output_images, dim=0)
    # output_images = make_grid(output_images, nrow=8, pad_value=1)
    # # Postprocess for Gradio
    # output_images = np.array(TF.to_pil_image(output_images))
    print(f'{len(output_images)=}')
    return output_images


# Placeholders
input_placeholders = GradioInputImage(source="upload", tool="editor", type="pil")
# output_placeholders = GradioOutputImage(type="numpy", label="Eigenvectors")
output_placeholders = [GradioOutputImage(type="numpy", label=f"Eigenvector {i}") for i in range(1, K + 1)]

# Metadata
examples = [f"examples/{stem}.jpg" for stem in [
    '2008_000099', '2008_000499', '2007_009446', '2007_001586', '2010_001256', '2008_000764', '2008_000705',  # '2007_000039'
]]
title = "Deep Spectral Segmentation"
description = "Deep spectral segmentation..."
thumbnail = "https://raw.githubusercontent.com/gradio-app/hub-echonet/master/thumbnail.png"

# Gradio
gr.Interface(
    segment,
    input_placeholders,
    output_placeholders,
    examples=examples,
    allow_flagging=False,
    analytics_enabled=False,
    title=title,
    description=description,
    thumbnail=thumbnail,
).launch()