# NOTE(review): the three lines that preceded the module here were
# web-page residue ("Spaces:" / "Runtime error") captured with the
# file, not Python; commented out so the module parses.
# Show VLAD clustering for set of example images or a user image | |
""" | |
User input: | |
- Domain: Indoor, Aerial, or Urban | |
- Image: Image to be clustered | |
- Cluster numbers (to visualize) | |
- Pixel coordinates (to pick further clusters) | |
- A unique cache ID (to store the DINO forward passes) | |
There are example images for each domain. | |
Output: | |
- All images with cluster assignments | |
Some Gradio links: | |
- Controlling layout | |
- https://www.gradio.app/guides/quickstart#blocks-more-flexibility-and-control | |
- Data state (persistence) | |
- https://www.gradio.app/guides/interface-state | |
- https://www.gradio.app/docs/state | |
- Layout control | |
- https://www.gradio.app/guides/controlling-layout | |
- https://www.gradio.app/guides/blocks-and-event-listeners | |
""" | |
# A markdown string shown at the top of the app (rendered through
# gr.Markdown inside the Blocks layout at the bottom of this file)
header_markdown = """
# AnyLoc Demo
\| [Website](https://anyloc.github.io/) \| \
[GitHub](https://github.com/AnyLoc/AnyLoc) \| \
[YouTube](https://youtu.be/ITo8rMInatk) \|
This space contains a collection of demos for AnyLoc. Each demo is a \
self-contained application in the tabs below. The following \
applications are included
1. **GeM t-SNE Projection**: Upload a set of images and see where \
they land on a t-SNE projection of GeM descriptors from many \
domains. This can be used to guide domain selection (from a few \
representative images).
2. **Cluster Visualization**: This visualizes the VLAD cluster \
assignments for the patch descriptors. You need to select the \
domain for loading VLAD cluster centers (vocabulary).
We do **not** save any images uploaded to the demo. Some errors may \
leave a log. We do not collect any information about the user. The \
example images are attributed in the respective tabs.
🥳 Thanks to HuggingFace for providing a free GPU for this demo.
"""
# %% | |
import os | |
import gradio as gr | |
import numpy as np | |
import cv2 as cv | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
from torchvision import transforms as tvf | |
from torchvision.transforms import functional as T | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
from sklearn.manifold import TSNE | |
import distinctipy as dipy | |
import joblib | |
from typing import Literal, List | |
import gradio as gr | |
import time | |
import glob | |
import shutil | |
import matplotlib.pyplot as plt | |
from copy import deepcopy | |
# DINOv2 imports | |
from utilities import DinoV2ExtractFeatures | |
from utilities import VLAD | |
# All model forward passes and tensors run on GPU when available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# %%
# Configurations
T1 = Literal["query", "key", "value", "token"]  # DINO descriptor facets
T2 = Literal["aerial", "indoor", "urban"]  # Supported domains
DOMAINS = ["aerial", "indoor", "urban"]
T3 = Literal["dinov2_vits14", "dinov2_vitb14", "dinov2_vitl14",
    "dinov2_vitg14"]  # DINOv2 backbone variants
_ex = lambda x: os.path.realpath(os.path.expanduser(x))  # Expand path
dino_model: T3 = "dinov2_vitg14"  # Backbone used for extraction
desc_layer: int = 31  # Transformer layer descriptors are read from
desc_facet: T1 = "value"  # Facet used for the patch descriptors
num_c: int = 8  # Number of VLAD clusters (must match vocabulary)
cache_dir: str = _ex("./cache")  # Directory containing program cache
max_img_size: int = 1024  # Image resolution (max dim/size)
max_num_imgs: int = 16  # Max number of images to upload
share: bool = False  # Share application using .gradio link
# Verify inputs
assert os.path.isdir(cache_dir), "Cache directory not found"
# %%
# Model and transforms (loaded once at import time)
print("Loading DINO model")
# extractor = None  # FIXME: For quick testing only
extractor = DinoV2ExtractFeatures(dino_model, desc_layer, desc_facet,
    device=device)
print("DINO model loaded")
# VLAD path (directory); layout: vocabulary/<model>/l<layer>_<facet>_c<k>
ext_s = f"{dino_model}/l{desc_layer}_{desc_facet}_c{num_c}"
vc_dir = os.path.join(cache_dir, "vocabulary", ext_s)
assert os.path.isdir(vc_dir), f"VLAD directory: {vc_dir} not found"
# GeM path (cache of precomputed GeM descriptors per dataset)
gem_cf = os.path.join(cache_dir, "gem_cache", "result_dino_v2.gz")
assert os.path.isfile(gem_cf), f"GeM cache: {gem_cf} not found"
gem_cache = joblib.load(gem_cf)
# The cached descriptors must match the model configured above
assert gem_cache["model"]["type"] == dino_model
assert gem_cache["model"]["layer"] == desc_layer
assert gem_cache["model"]["facet"] == desc_facet
fig = plt.figure()  # Main figure (reused by plot_tsne)
fig.clear()
# Base image transformations (ImageNet normalization for DINOv2)
base_tf = tvf.Compose([
    tvf.ToTensor(),
    tvf.Normalize(mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])
])
# %%
# Get VLAD object
def get_vlad_clusters(domain, pr = gr.Progress()):
    """Load the VLAD vocabulary (cluster centers) for a domain.

    Parameters
    - domain: Domain name as shown in the UI (e.g. "Indoor")
    - pr: Gradio progress tracker

    Returns (status_message, vlad_object); vlad_object is None when
    the cluster-center file for the domain is missing.
    """
    dm: T2 = str(domain).lower()
    assert dm in DOMAINS, "Invalid domain"
    pr(0, desc="Loading VLAD clusters")
    centers_path = os.path.join(vc_dir, dm, "c_centers.pt")
    if not os.path.isfile(centers_path):
        return f"Cluster centers not found for: {domain}", None
    centers = torch.load(centers_path)
    pr(0.5)
    n_clusters, d_dim = centers.shape[0], centers.shape[1]
    vlad_obj = VLAD(n_clusters, d_dim,
        cache_dir=os.path.dirname(centers_path))
    vlad_obj.fit(None)  # Restore cluster state from the cache dir
    pr(1)
    return f"VLAD clusters loaded for: {domain}", vlad_obj
# %%
# Get VLAD descriptors
def get_descs(imgs_batch, pr = gr.Progress()):
    """Extract DINOv2 patch descriptors for every image in the batch.

    Parameters
    - imgs_batch: List of images as numpy arrays (None entries skipped)
    - pr: Gradio progress tracker

    Returns (patch_descs, status_message) where patch_descs is a list
    of {"img": PIL.Image, "descs": Tensor[1, n_p, d]} dicts.
    """
    imgs_batch: List[np.ndarray] = imgs_batch
    pr(0, desc="Extracting descriptors")

    def _prepare(idx, np_img):
        # Resize (aspect preserved, longest side <= max_img_size) and
        # center-crop so both dims are multiples of the 14 px patch
        pil_img = Image.fromarray(np_img)
        img_pt = base_tf(pil_img).to(device)
        if max(img_pt.shape[-2:]) > max_img_size:
            print(f"Image {idx+1}: {img_pt.shape[-2:]}, outside")
            c, h, w = img_pt.shape
            if h == max(img_pt.shape[-2:]):
                w, h = int(w * max_img_size / h), max_img_size
            else:
                h, w = int(h * max_img_size / w), max_img_size
            img_pt = T.resize(img_pt, (h, w),
                interpolation=T.InterpolationMode.BICUBIC)
            pil_img = pil_img.resize((w, h))  # Keep PIL copy in sync
        c, h, w = img_pt.shape
        h_new, w_new = (h // 14) * 14, (w // 14) * 14
        return pil_img, tvf.CenterCrop((h_new, w_new))(img_pt)[None, ...]

    patch_descs = []
    total = len(imgs_batch)
    for idx, np_img in enumerate(imgs_batch):
        if np_img is None:
            print(f"Image {idx+1} is None")
            continue
        pil_img, img_pt = _prepare(idx, np_img)
        descs = extractor(img_pt).cpu()  # [1, n_p, d]
        patch_descs.append({"img": pil_img, "descs": descs})
        pr((idx+1) / total)
    pr(1.0)
    return patch_descs, \
        f"Descriptors extracted for {total} images"
# %%
# Assign VLAD clusters (descriptor assignment)
def assign_vlad(patch_descs, vlad, pr = gr.Progress()):
    """Assign every patch descriptor to its closest VLAD cluster.

    Parameters
    - patch_descs: List of {"img": PIL.Image, "descs": Tensor} dicts
    - vlad: Fitted VLAD object (from get_vlad_clusters)
    - pr: Gradio progress tracker

    Returns (desc_assignments, status_message) where each assignment
    is an int Tensor of shape ('h', 'w') at full (cropped) image size.
    """
    vlad: VLAD = vlad
    pr(0, desc="Assigning VLAD clusters")
    desc_assignments = []  # List[Tensor; shape=('h', 'w'); int]
    total = len(patch_descs)
    for i, pd in enumerate(patch_descs):
        # Residual vectors; 'n' can differ per image (size dependent)
        res = vlad.generate_res_vec(pd["descs"][0])  # ['n', n_c, d]
        h, w, _ = np.array(pd["img"]).shape
        h_p, w_p = h // 14, w // 14
        h_new, w_new = h_p * 14, w_p * 14
        assert h_p * w_p == res.shape[0], "Residual incorrect!"
        # Nearest cluster per patch (min L1 residual), then upsample
        # from patch grid to pixel grid with nearest-neighbor
        da = res.abs().sum(dim=2).argmin(dim=1).reshape(h_p, w_p)
        da = F.interpolate(da[None, None, ...].to(float),
            (h_new, w_new), mode="nearest")[0, 0].to(da.dtype)
        desc_assignments.append(da)
        pr((i+1) / total)
    pr(1.0)
    return desc_assignments, "VLAD clusters assigned"
# %%
# Cluster assignments to images
def get_ca_images(desc_assignments, patch_descs, alpha,
        pr = gr.Progress()):
    """Blend per-patch cluster-assignment colors over the images.

    Parameters
    - desc_assignments: List[Tensor('h','w'); int] from assign_vlad
    - patch_descs: List of {"img": PIL.Image, "descs": ...} dicts
    - alpha: Blend weight for the cluster overlay (0 to 1)
    - pr: Gradio progress tracker

    Returns (images, status_message); images is a List[PIL.Image] of
    blended results, or None on error.
    """
    # Bug fix: validate alpha unconditionally. Previously the range
    # check was nested inside the "no images" branch, so an invalid
    # alpha went unreported whenever images were loaded.
    if not 0 <= alpha <= 1:
        return None, f"Invalid alpha value: {alpha} (should be "\
            "between 0 and 1)"
    if desc_assignments is None or len(desc_assignments) == 0:
        return None, "First load the images"
    # One distinct (colorblind-friendly) color per cluster
    c_colors = dipy.get_colors(num_c, rng=928,
        colorblind_type="Deuteranomaly")
    np_colors = (np.array(c_colors) * 255).astype(np.uint8)
    # Get images with clusters
    pil_imgs = [pd["img"] for pd in patch_descs]
    res_imgs = []  # List[PIL.Image]
    pr(0, desc="Generating cluster assignment images")
    for i, pil_img in enumerate(pil_imgs):
        # Descriptor assignment image: [h, w, 3]
        da: torch.Tensor = desc_assignments[i]  # ['h', 'w']
        da_img = np.zeros((*da.shape, 3), dtype=np.uint8)
        for c in range(num_c):
            da_img[da == c] = np_colors[c]
        # Background image: [h, w, 3], cropped to multiples of 14
        # so it aligns with the assignment image
        img_np = np.array(pil_img, dtype=np.uint8)
        h, w, c = img_np.shape
        h_p, w_p = (h // 14), (w // 14)
        h_new, w_new = h_p * 14, w_p * 14
        img_np = F.interpolate(torch.tensor(img_np)\
            .permute(2, 0, 1)[None, ...], (h_new, w_new),
            mode='nearest')[0].permute(1, 2, 0).numpy()
        # Weighted blend of the image and the cluster-color overlay
        res_img = cv.addWeighted(img_np, 1 - alpha, da_img, alpha, 0.)
        res_imgs.append(Image.fromarray(res_img))
        pr((i+1) / len(pil_imgs))
    pr(1.0)
    return res_imgs, "Cluster assignment images generated"
# %%
# Get GeM descriptors from cache
def get_gem_descs_cache(use_d, pr = gr.Progress()):
    """Load cached GeM descriptors for the selected domains.

    Parameters
    - use_d: List of selected domain names (checkbox values)
    - pr: Gradio progress tracker

    Returns (status_message, gem_descs); gem_descs is a dict with
    "labels" (List[str], one domain label per descriptor) and
    "descs" (np.ndarray [n, d_dim]) keys, or None on error.
    """
    use_d: List[str] = use_d
    if len(use_d) == 0:
        return "Select at least one domain", None
    use_d = [d.lower() for d in use_d]
    # Dataset names (as stored in the cache) grouped by domain
    indoor_datasets = ["baidu_datasets", "gardens", "17places"]
    urban_datasets = ["pitts30k", "st_lucia", "Oxford"]
    aerial_datasets = ["Tartan_GNSS_test_rotated",
        "Tartan_GNSS_test_notrotated", "VPAir"]
    pr(0, desc="Loading GeM descriptors from cache")
    gem_descs = {
        "labels": [],
        "descs": [],
    }
    for i, ds in enumerate(gem_cache["data"]):
        # GeM descriptors from data: n_desc, desc_dim
        d: np.ndarray = gem_cache["data"][ds]["descriptors"]
        if ds in indoor_datasets and "indoor" in use_d:
            gem_descs["labels"].extend(["indoor"] * d.shape[0])
        elif ds in urban_datasets and "urban" in use_d:
            gem_descs["labels"].extend(["urban"] * d.shape[0])
        elif ds in aerial_datasets and "aerial" in use_d:
            gem_descs["labels"].extend(["aerial"] * d.shape[0])
        else:
            continue
        gem_descs["descs"].append(d)
        pr((i+1) / len(gem_cache["data"]))
    # Bug fix: np.concatenate raises ValueError on an empty list
    # (e.g. when no cached dataset matched the selected domains)
    if len(gem_descs["descs"]) == 0:
        return "No cached descriptors for the selected domains", None
    gem_descs["descs"] = np.concatenate(gem_descs["descs"], axis=0)
    pr(1.0)
    return "GeM descriptors loaded from cache", gem_descs
# %%
# Get GeM pooled features of the uploaded images
def get_add_gem_descs(imgs_batch, gem_descs, pr = gr.Progress()):
    """Append GeM-pooled descriptors of the uploaded images.

    Each image is resized (aspect preserved, longest side at most
    max_img_size), cropped to multiples of the 14 px patch size,
    passed through the DINOv2 extractor, and GeM-pooled (p=3). The
    results are appended to gem_descs["labels"] / ["descs"] and the
    number of processed images is stored under "num_uimgs".

    Returns (gem_descs, status_message).
    """
    imgs_batch: List[np.ndarray] = imgs_batch
    gem_descs: dict = gem_descs
    pr(0, desc="Extracting GeM descriptors")
    n_done = 0
    total = len(imgs_batch)
    for i, arr in enumerate(imgs_batch):
        if arr is None:
            print(f"Image {i+1} is None")
            continue
        pil_img = Image.fromarray(arr)
        img_pt = base_tf(pil_img).to(device)
        if max(img_pt.shape[-2:]) > max_img_size:
            print(f"Image {i+1}: {img_pt.shape[-2:]}, outside")
            c, h, w = img_pt.shape
            # Maintain aspect ratio while capping the longest side
            if h == max(img_pt.shape[-2:]):
                w, h = int(w * max_img_size / h), max_img_size
            else:
                h, w = int(h * max_img_size / w), max_img_size
            img_pt = T.resize(img_pt, (h, w),
                interpolation=T.InterpolationMode.BICUBIC)
            pil_img = pil_img.resize((w, h))  # Keep PIL copy in sync
        # Crop so both dims are multiples of the 14 px patch size
        c, h, w = img_pt.shape
        img_pt = tvf.CenterCrop(((h // 14) * 14, (w // 14) * 14))(
            img_pt)[None, ...]
        ret = extractor(img_pt).cpu()  # [1, n_p, d]
        # GeM pooling (p=3); the complex cube root keeps negative
        # activations usable, the sign is restored afterwards
        x = torch.mean(ret**3, dim=-2)
        pooled = torch.abs(x.to(torch.complex64) ** (1/3)) \
            * torch.sign(x)  # [1, d]
        gem_descs["labels"].append(f"Image{i+1}")
        gem_descs["descs"] = np.concatenate([gem_descs["descs"],
            pooled.numpy()])
        n_done += 1
        pr((i+1) / total)
    pr(1.0)
    gem_descs["num_uimgs"] = n_done
    return gem_descs, "GeM descriptors extracted"
# %%
# Apply tSNE to the GeM descriptors
def get_tsne_fm_gem(gem_descs, pr = gr.Progress()):
    """Project the GeM descriptors to 2D using t-SNE.

    Parameters
    - gem_descs: Dict with "descs" [n, d_dim], "labels" [n] and
      "num_uimgs" keys (from get_add_gem_descs)
    - pr: Gradio progress tracker

    Returns (tsne_pts, status_message); tsne_pts carries the 2D
    points, their labels, and the user-image count for plotting.
    """
    pr(0, desc="Applying tSNE to GeM descriptors")
    projector = TSNE(n_components=2, random_state=30, perplexity=50,
        learning_rate=200, init='random')
    pts_2d = projector.fit_transform(gem_descs["descs"])  # [n, 2]
    tsne_pts = {
        "labels": gem_descs["labels"],
        "pts": pts_2d,
        "num_uimgs": gem_descs["num_uimgs"],  # Number of user imgs
    }
    pr(1.0)
    return tsne_pts, "tSNE projection done"
# %%
# Plot tSNE to matplotlib figure
def plot_tsne(tsne_pts):
    """Draw the t-SNE scatter plot into the module-level figure.

    Parameters
    - tsne_pts: Dict with "pts" [n, 2], "labels" [n], and
      "num_uimgs" keys (from get_tsne_fm_gem)

    Returns (fig, status_message). Domain points use fixed colors and
    'o' markers; user images get distinct colors and 'x' markers.
    """
    colors = {
        "aerial": (80/255, 0/255, 80/255),
        "indoor": ( 0/255, 76/255, 204/255),
        "urban": ( 0/255, 204/255, 0/255),
    }
    ni = int(tsne_pts["num_uimgs"])
    # Custom colors for user images, excluding the domain colors and
    # black/white. Bug fix: list.extend() returns None, so the
    # original passed exclude_colors=None; build the list with '+'.
    excl = list(colors.values()) + [(0, 0, 0), (1, 1, 1)]
    ucs = dipy.get_colors(ni, exclude_colors=excl,
        colorblind_type="Deuteranomaly")
    for i in range(ni):
        colors[f"Image{i+1}"] = ucs[i]
    fig.clear()
    gs = fig.add_gridspec(1, 1)
    ax = fig.add_subplot(gs[0, 0])
    ax.set_title("tSNE Projection")
    labels_arr = np.array(tsne_pts["labels"])
    for domain in list(colors.keys()):
        pts = tsne_pts["pts"][labels_arr == domain]
        # 'x' marker distinguishes user images from domain clouds
        m = "x" if domain.startswith("Image") else "o"
        ax.scatter(pts[:, 0], pts[:, 1], label=domain, marker=m,
            color=colors[domain])
    # Put legend at the bottom of axis
    ax.legend()
    ax.set_xticks([])
    ax.set_yticks([])
    fig.set_tight_layout(True)
    return fig, "tSNE plot created"
# %%
print("Interface build started")
# Tab for VLAD cluster assignment visualization
def tab_cluster_viz():
    """Build the 'Cluster Visualization' tab UI and wire its events.

    NOTE(review): references module-level gr.State objects (vlad,
    imgs_batch, patch_descs, desc_assignments) that are declared
    inside the gr.Blocks context further down this file — this
    function must be called from within that context, after the
    state declarations.
    """
    d_vals = [k.title() for k in DOMAINS]
    domain = gr.Radio(d_vals, value=d_vals[0], label="Domain",
        info="The domain of images (for loading VLAD vocabulary)")
    nimg_s = gr.Number(2, label="How many images?", precision=0,
        info=f"Between '1' and '{max_num_imgs}' images. Press "\
            "enter/return to register")
    with gr.Row():  # Dynamic row (images in columns)
        # All max_num_imgs slots are created up front; only the
        # first nimg_s are visible (toggled by var_num_img on submit)
        imgs = [gr.Image(label=f"Image {i+1}", visible=True) \
                for i in range(int(nimg_s.value))] + \
            [gr.Image(visible=False) \
                for _ in range(max_num_imgs - int(nimg_s.value))]
        for i, img in enumerate(imgs):  # Set image as "input"
            img.change(lambda _: None, img)
    with gr.Row():  # Dynamic row of output (cluster) images
        imgs2 = [gr.Image(label=f"VLAD Clusters {i+1}",
            visible=False) for i in range(max_num_imgs)]
    nimg_s.submit(var_num_img, nimg_s, imgs)
    blend_alpha = gr.Number(0.4, label="Blending alpha",
        info="Weight for cluster centers (between 0 and 1). "\
            "Higher (close to 1) means greater emphasis on cluster "\
            "visibility. Lower (closer to 0) will show the "\
            "underlying image more. "\
            "Press enter/return to register")
    bttn1 = gr.Button("Click Me!")  # Cluster assignment
    gr.Markdown("### Status strings")
    out_msg1 = gr.Markdown("Select domain and upload images")
    out_msg2 = gr.Markdown("For descriptor extraction")
    out_msg3 = gr.Markdown("Followed by VLAD assignment")
    out_msg4 = gr.Markdown("Followed by cluster images")
    # ---- Utility functions ----
    # A wrapper to batch the images
    def batch_images(data):
        # 'data' maps components to values (dict-style event inputs)
        sv = int(data[nimg_s])
        images: List[np.ndarray] = [data[imgs[k]] \
            for k in range(sv)]
        return images
    # A wrapper to unbatch images (and pad to max)
    def unbatch_images(imgs_batch, nimg):
        # Hidden placeholders for all slots beyond the batch
        ret = [gr.Image.update(visible=False) \
            for _ in range(max_num_imgs)]
        if imgs_batch is None or len(imgs_batch) == 0:
            return ret
        for i in range(nimg):  # nimg only to match input layout
            if i < len(imgs_batch):
                img_np = np.array(imgs_batch[i])
            else:
                img_np = None
            ret[i] = gr.Image.update(img_np, visible=True)
        return ret
    # ---- Examples ----
    # Two images from each domain
    gr.Examples(
        [
            ["Aerial", 2,
                "ex_aerial_nardo-air_db-42.png",
                "ex_aerial_nardo-air_qu-42.png",],
            ["Indoor", 2,
                "ex_indoor_17places_db-75.jpg",
                "ex_indoor_17places_qu-75.jpg"],
            ["Urban", 2,
                "ex_urban_oxford_db-75.png",
                "ex_urban_oxford_qu-75.png"],],
        [domain, nimg_s, *imgs],
    )
    # ---- Main pipeline ----
    # Get the VLAD cluster assignment images on click.
    # NOTE(review): the imgs_batch state is reused to carry the
    # blended result images out of get_ca_images.
    bttn1.click(get_vlad_clusters, domain, [out_msg1, vlad])\
        .then(batch_images, {nimg_s, *imgs, imgs_batch}, imgs_batch)\
        .then(get_descs, imgs_batch, [patch_descs, out_msg2])\
        .then(assign_vlad, [patch_descs, vlad],
            [desc_assignments, out_msg3])\
        .then(get_ca_images,
            [desc_assignments, patch_descs, blend_alpha],
            [imgs_batch, out_msg4])\
        .then(unbatch_images, [imgs_batch, nimg_s], imgs2)
    # If the blending changes now, update the cluster images only
    blend_alpha.submit(get_ca_images,
            [desc_assignments, patch_descs, blend_alpha],
            [imgs_batch, out_msg4])\
        .then(unbatch_images, [imgs_batch, nimg_s], imgs2)
# Tab for GeM t-SNE projection plot
def tab_gem_tsne():
    """Build the 'GeM t-SNE Projection' tab UI and wire its events.

    NOTE(review): references module-level gr.State objects
    (imgs_batch, gem_descs, tsne_pts) declared inside the gr.Blocks
    context below — must be called from within that context.
    """
    d_vals = [k.title() for k in DOMAINS]
    dms = gr.CheckboxGroup(d_vals, value=d_vals, label="Domains",
        info="The domains to use for the t-SNE projection")
    nimg_s = gr.Number(2, label="How many images?", precision=0,
        info=f"Between '1' and '{max_num_imgs}' images. Press "\
            "enter/return to register")
    with gr.Row():  # Dynamic row (images in columns)
        # All max_num_imgs slots created; visibility toggled by
        # var_num_img on submit
        imgs = [gr.Image(label=f"Image {i+1}", visible=True) \
                for i in range(int(nimg_s.value))] + \
            [gr.Image(visible=False) \
                for _ in range(max_num_imgs - int(nimg_s.value))]
        for i, img in enumerate(imgs):  # Set image as "input"
            img.change(lambda _: None, img)
    nimg_s.submit(var_num_img, nimg_s, imgs)
    tsne_plot = gr.Plot(None, label="tSNE Plot")
    out_msg1 = gr.Markdown("Select domains")
    out_msg2 = gr.Markdown("Upload images")
    out_msg3 = gr.Markdown("Wait for tSNE plots")
    # A wrapper to batch the images
    def batch_images(data):
        sv = int(data[nimg_s])
        # images: List[np.ndarray] = [data[imgs[k]] \
        #     for k in range(sv)]
        images: List[np.ndarray] = []
        for k in range(sv):
            img = data[imgs[k]]
            if img is None:
                # Abort: downstream steps receive None as the batch
                return None, f"Image {k+1} is None!"
            images.append(img)
        return images, "Images batched"
    bttn1 = gr.Button("Click Me!")
    # ---- Examples ----
    gr.Examples(
        [
            ["./ex_dining_room.jpeg", "./ex_city_road.jpeg"],
            ["./ex_manhattan_aerial.jpeg", "./ex_city_road.jpeg"],
            ["./ex_dining_room.jpeg", "./ex_manhattan_aerial.jpeg"],
        ],
        [*imgs],
    )
    # ---- Main pipeline ----
    # Get the tSNE plot
    bttn1.click(get_gem_descs_cache, dms, [out_msg1, gem_descs])\
        .then(batch_images, {nimg_s, *imgs, imgs_batch},
            [imgs_batch, out_msg2])\
        .then(get_add_gem_descs, [imgs_batch, gem_descs],
            [gem_descs, out_msg2])\
        .then(get_tsne_fm_gem, gem_descs, [tsne_pts, out_msg3])\
        .then(plot_tsne, tsne_pts, [tsne_plot, out_msg3])
# Build the interface
with gr.Blocks() as demo:
    # Main header
    gr.Markdown(header_markdown)
    # ---- Helper functions ----
    # Variable number of input images (show/hide UI image array)
    def var_num_img(s):
        """Return visibility updates for the max_num_imgs slots."""
        n = int(s)  # Slider (string) value as int
        assert 1 <= n <= max_num_imgs, f"Invalid num of images: {n}!"
        return [gr.Image.update(label=f"Image {i+1}", visible=True) \
                for i in range(n)] \
            + [gr.Image.update(visible=False) \
                for _ in range(max_num_imgs - n)]
    # ---- State declarations ----
    # Per-session state; referenced by the tab-builder functions above
    vlad = gr.State()  # VLAD object
    desc_assignments = gr.State()  # Cluster assignments
    imgs_batch = gr.State()  # Images as batch
    patch_descs = gr.State()  # Patch descriptors
    gem_descs = gr.State()  # GeM descriptors (of each state)
    tsne_pts = gr.State()  # tSNE points
    # ---- All UI elements ----
    with gr.Tab("GeM t-SNE Projection"):
        gr.Markdown(
            """
## GeM t-SNE Projection
Select the domains (toggle visibility) for t-SNE plot. \
Enter the number of images to upload and upload images. \
Then click the button to get the t-SNE plot.
You can also directly click on one of the examples (at \
the bottom) to load the data and then click the button \
to get the t-SNE plot.
The examples have the following images
- [Manhattan aerial view](https://www.crushpixel.com/stock-photo/aerial-view-midtown-manhattan-849717.html)
- [Dining room](https://homesfeed.com/formal-dining-room-sets-for-8/)
- [City road](https://pxhere.com/en/photo/824211)
""")
        tab_gem_tsne()
    with gr.Tab("Cluster Visualization"):
        gr.Markdown(
            """
## Cluster Visualizations
Select the domain for the images (all should be from the \
same domain). Enter the number of images to upload. \
Upload the images. Then click the button to get the \
cluster assignment images.
You can also directly click on one of the examples (at \
the bottom) to load the data and then click the button \
to get the cluster assignment images.
- The `aerial` example is from the Tartan Air dataset
- The `indoor` example is from the 17Places dataset
- The `urban` example is from the Oxford dataset
""")
        tab_cluster_viz()
print("Interface build completed")
# %%
# Deploy application (queue enables progress tracking for events)
demo.queue().launch(share=share)
print("Application deployment ended, exiting...")