File size: 3,211 Bytes
3ec9877 f398fe3 3ec9877 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import torch
import numpy as np
import gradio as gr
import plotly.express as px
import plotly.graph_objects as go
from sklearn.decomposition import PCA
from torchvision import transforms as T
from sklearn.preprocessing import MinMaxScaler
# Pick the compute device: CUDA when torch reports it available, else CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Fetch the DINOv2 ViT-B/14 entry point through torch.hub, switch it to
# evaluation mode and move it onto the selected device. Both ``eval()`` and
# ``to()`` return the module itself, so the calls can be chained.
dino = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
dino = dino.eval().to(device)

# Shared estimators used by ``app_fn``: a 3-component PCA (one component per
# colour channel of the rendered image) and a min-max scaler with clipping.
# Both are re-fitted on every call via ``fit_transform``.
pca = PCA(n_components=3)
scaler = MinMaxScaler(clip=True)
def plot_img(img_array: np.ndarray) -> go.Figure:
    """Build a Plotly image figure from ``img_array`` with tick labels hidden.

    Args:
        img_array: Image data handed unchanged to ``plotly.express.imshow``.

    Returns:
        The figure produced by ``px.imshow`` after its layout has been
        updated to hide the tick labels on both axes.
    """
    fig = px.imshow(img_array)
    # Hide the tick labels on both axes.
    fig.update_layout(
        xaxis=dict(showticklabels=False),
        yaxis=dict(showticklabels=False),
    )
    return fig
def app_fn(
    img: np.ndarray,
    threshold: float,
    object_larger_than_bg: bool
) -> go.Figure:
    """Colour the DINOv2 patch features of ``img`` with a two-stage PCA.

    The image is resized to a 40x40 grid of 14-pixel patches, normalised and
    passed through the module-level ``dino`` model. A first PCA over all patch
    features is min-max scaled and its first component is thresholded to
    split patches into background and foreground. A second PCA is then fitted
    on the foreground patches only; its three scaled components become the
    RGB colour of each foreground patch, while background patches stay black.

    Args:
        img: Image array indexed as (height, width, channel) with values on a
            0-255 scale (it is permuted with ``(2, 0, 1)`` and divided by
            255). Three channels are required by the 3-valued mean/std used
            for normalisation.
        threshold: Cut-off applied to the scaled first PCA component.
        object_larger_than_bg: When true, patches whose first component is
            *above* ``threshold`` are treated as background; otherwise the
            patches *below* ``threshold`` are.

    Returns:
        A Plotly figure showing the (40, 40, 3) patch colour map.

    Raises:
        gr.Error: If no image was supplied, or if the threshold leaves fewer
            foreground patches than the PCA has components (the second PCA
            cannot be fitted in that case).
    """
    # Gradio passes None when the image input is empty; fail with a readable
    # message instead of an opaque TypeError from torch.from_numpy.
    if img is None:
        raise gr.Error("Please provide an image before running the visualization.")

    IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

    # Side length of one patch in pixels; presumably the "14" in the hub
    # entry name dinov2_vitb14 -- confirm if the model is changed.
    patch_size = 14
    patch_h = 40
    patch_w = 40

    transform = T.Compose([
        T.Resize((patch_size * patch_h, patch_size * patch_w)),
        T.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ])

    # HWC array on a 0-255 scale -> CHW float tensor on a 0-1 scale.
    img = torch.from_numpy(img).type(torch.float).permute(2, 0, 1) / 255
    img_tensor = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        out = dino.forward_features(img_tensor)

    # Drop the first token along the sequence axis (presumably the class
    # token -- confirm against the DINOv2 implementation), leaving one
    # feature row per patch.
    features = out["x_prenorm"][:, 1:, :]
    features = features.squeeze(0).cpu().numpy()

    # Stage 1: PCA over every patch, scaled so the threshold is comparable.
    pca_features = pca.fit_transform(features)
    pca_features = scaler.fit_transform(pca_features)

    if object_larger_than_bg:
        pca_features_bg = pca_features[:, 0] > threshold
    else:
        pca_features_bg = pca_features[:, 0] < threshold
    pca_features_fg = ~pca_features_bg

    # PCA cannot extract more components than it has samples, so extreme
    # thresholds (which leave only a handful of foreground patches) used to
    # crash with a ValueError inside scikit-learn.
    fg_count = int(np.count_nonzero(pca_features_fg))
    if fg_count < pca.n_components:
        raise gr.Error(
            f"Only {fg_count} foreground patch(es) remain at threshold "
            f"{threshold}; at least {pca.n_components} are needed. "
            "Please adjust the threshold."
        )

    # Stage 2: PCA restricted to the foreground patches.
    pca_features_fg_seg = pca.fit_transform(features[pca_features_fg])
    pca_features_fg_seg = scaler.fit_transform(pca_features_fg_seg)

    # Background rows keep their initial zeros (black); only the foreground
    # rows receive colours.
    pca_features_rgb = np.zeros((patch_h * patch_w, 3))
    pca_features_rgb[pca_features_fg] = pca_features_fg_seg
    pca_features_rgb = pca_features_rgb.reshape(patch_h, patch_w, 3)

    fig_pca = plot_img(pca_features_rgb)
    return fig_pca
if __name__ == "__main__":
    # Build and launch the Gradio UI: a threshold slider, a checkbox choosing
    # the side of the threshold treated as background, an image input and a
    # plot output, all wired to ``app_fn``.
    title = "DINOv2"
    with gr.Blocks(title=title) as demo:
        gr.Markdown(f"# {title}")
        gr.Markdown(
            """
            """
        )
        with gr.Row():
            threshold = gr.Slider(minimum=0, maximum=1, value=0.6, step=0.05, label="Threshold")
            object_larger_than_bg = gr.Checkbox(label="Object Larger than Background", value=False)
            # The button caption is passed through ``value``; ``label`` is not
            # a Button parameter (ignored on Gradio 3.x, rejected on 4.x).
            btn = gr.Button(value="Visualize")
        with gr.Row():
            img = gr.Image()
            fig_pca = gr.Plot(label="PCA Features")

        btn.click(fn=app_fn, inputs=[img, threshold, object_larger_than_bg], outputs=[fig_pca])

        # With cache_examples=True the example outputs are computed by
        # calling ``app_fn`` ahead of time rather than on each click.
        examples = gr.Examples(
            examples=[
                ["assets/photo-1.jpg", 0.6, True],
                ["assets/photo-2.jpg", 0.7, True],
                ["assets/photo-3.jpg", 0.8, False],
            ],
            inputs=[img, threshold, object_larger_than_bg],
            outputs=[fig_pca],
            fn=app_fn,
            cache_examples=True,
        )

    # Limit the request queue to five pending jobs before launching.
    demo.queue(max_size=5).launch()