|
import torch |
|
import numpy as np |
|
import gradio as gr |
|
import plotly.express as px |
|
import plotly.graph_objects as go |
|
from sklearn.decomposition import PCA |
|
from torchvision import transforms as T |
|
from sklearn.preprocessing import MinMaxScaler |
|
|
|
|
|
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# DINOv2 ViT-B/14 backbone fetched through torch.hub. nn.Module.eval() and
# nn.Module.to() both return the module itself, so the calls can be chained.
dino = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14').eval().to(device)

# Module-level estimators; app_fn calls fit_transform on them for each request.
pca, scaler = PCA(n_components=3), MinMaxScaler(clip=True)
|
|
|
def plot_img(img_array: np.ndarray) -> go.Figure:
    """Render an image array as a plotly figure with axis tick labels hidden.

    Args:
        img_array: Image passed straight to ``px.imshow``. ``app_fn`` calls
            this with a ``(40, 40, 3)`` float array whose values lie in [0, 1].

    Returns:
        The plotly figure showing the image.
    """
    fig = px.imshow(img_array)
    # Pixel/patch indices on the axes carry no meaning for this visualization.
    fig.update_layout(
        xaxis=dict(showticklabels=False),
        yaxis=dict(showticklabels=False),
    )
    return fig
|
|
|
|
|
# ImageNet channel statistics used to normalize inputs before DINOv2.
_IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


def _extract_patch_features(img: np.ndarray, patch_h: int, patch_w: int) -> np.ndarray:
    """Embed *img* with DINOv2 and return its patch tokens as a 2-D numpy array.

    The image is resized to ``(14 * patch_h, 14 * patch_w)`` pixels, so
    ViT-B/14 produces a ``patch_h x patch_w`` grid of patch tokens. The result
    has one row per patch (``patch_h * patch_w`` rows).
    """
    transform = T.Compose([
        # antialias=True matches PIL-style resizing and avoids torchvision's
        # warning about the changing default for tensor inputs.
        T.Resize((14 * patch_h, 14 * patch_w), antialias=True),
        T.Normalize(mean=_IMAGENET_DEFAULT_MEAN, std=_IMAGENET_DEFAULT_STD),
    ])

    # (H, W, C) with 0-255 values -> (C, H, W) float in [0, 1], as Normalize expects.
    img_tensor = torch.from_numpy(img).float().permute(2, 0, 1) / 255
    img_tensor = transform(img_tensor).unsqueeze(0).to(device)

    with torch.no_grad():
        out = dino.forward_features(img_tensor)

    # Take the single batch item and drop token 0 (assumed to be the CLS
    # token; dinov2_vitb14 has no register tokens), keeping the patch tokens.
    return out["x_prenorm"][0, 1:, :].cpu().numpy()


def _foreground_pca_rgb(
    features: np.ndarray,
    threshold: float,
    object_larger_than_bg: bool,
) -> np.ndarray:
    """Split patches into fg/bg and colour the foreground by its own PCA.

    Returns a ``(n_patches, 3)`` array: background rows are 0, and foreground
    rows hold a min-max-scaled 3-component PCA fitted on foreground patches only.

    Raises:
        gr.Error: If fewer foreground patches remain than PCA components.
    """
    # First PCA over all patches; its first component (scaled to [0, 1]) is
    # what the threshold is compared against.
    first_pc = scaler.fit_transform(pca.fit_transform(features))[:, 0]

    if object_larger_than_bg:
        background = first_pc > threshold
    else:
        background = first_pc < threshold
    foreground = ~background

    # PCA needs at least n_components samples; slider extremes (e.g. 1.0 with
    # the checkbox unchecked) can leave only one foreground patch.
    n_foreground = int(foreground.sum())
    if n_foreground < pca.n_components:
        raise gr.Error(
            f"Threshold {threshold} leaves only {n_foreground} foreground patch(es); "
            "adjust the threshold or toggle 'Object Larger than Background'."
        )

    # Background rows stay 0, i.e. they are rendered black.
    rgb = np.zeros((features.shape[0], 3))
    rgb[foreground] = scaler.fit_transform(pca.fit_transform(features[foreground]))
    return rgb


def app_fn(
    img: np.ndarray,
    threshold: float,
    object_larger_than_bg: bool,
) -> go.Figure:
    """Visualize the DINOv2 PCA foreground of an uploaded image.

    The image is embedded on a 40x40 patch grid. The min-max-scaled first
    principal component of the patch features is thresholded to mark
    background patches. A second 3-component PCA over the remaining
    (foreground) patches is rendered as RGB, with background patches in black.

    Args:
        img: Image from ``gr.Image``, assumed to be an ``(H, W, 3)`` array with
            0-255 values (RGB numpy input) — confirm if image_mode changes.
        threshold: Cut-off compared against the scaled first component.
        object_larger_than_bg: When True, patches *above* the threshold are
            background; otherwise patches *below* it are.

    Returns:
        A plotly figure of the ``(40, 40, 3)`` foreground-PCA image.

    Raises:
        gr.Error: If no image was provided, or if the threshold leaves too few
            foreground patches for the second PCA.
    """
    if img is None:
        # Gradio passes None when the image input is empty.
        raise gr.Error("Please upload an image first.")

    patch_h = patch_w = 40

    features = _extract_patch_features(img, patch_h, patch_w)
    rgb = _foreground_pca_rgb(features, threshold, object_larger_than_bg)

    return plot_img(rgb.reshape(patch_h, patch_w, 3))
|
|
|
if __name__ == "__main__":
    title = "DINOv2"

    with gr.Blocks(title=title) as demo:
        gr.Markdown(f"# {title}")
        gr.Markdown(
            "Upload an image and click **Visualize**. The first principal "
            "component of the DINOv2 patch features is thresholded to split "
            "foreground from background; a second PCA over the foreground "
            "patches is then shown as RGB, with background patches in black."
        )

        with gr.Row():
            threshold = gr.Slider(minimum=0, maximum=1, value=0.6, step=0.05, label="Threshold")
            object_larger_than_bg = gr.Checkbox(label="Object Larger than Background", value=False)
            # The button text is its `value`; `label` does not set it (and is
            # rejected outright by gradio 4).
            btn = gr.Button("Visualize")

        with gr.Row():
            # app_fn expects a numpy array (gradio's default for gr.Image, made explicit).
            img = gr.Image(type="numpy")
            fig_pca = gr.Plot(label="PCA Features")

        btn.click(fn=app_fn, inputs=[img, threshold, object_larger_than_bg], outputs=[fig_pca])

        # cache_examples=True runs app_fn on each example when the app starts.
        examples = gr.Examples(
            examples=[
                ["assets/photo-1.jpg", 0.6, False],
                ["assets/photo-2.jpg", 0.7, True],
                ["assets/photo-3.jpg", 0.8, False],
            ],
            inputs=[img, threshold, object_larger_than_bg],
            outputs=[fig_pca],
            fn=app_fn,
            cache_examples=True,
        )

    demo.queue(max_size=5).launch()