TheProjectsGuy committed on
Commit
476e478
1 Parent(s): ea57705

Pushed the second version of HF App

Browse files
app.py CHANGED
@@ -23,6 +23,34 @@
23
  - https://www.gradio.app/guides/blocks-and-event-listeners
24
  """
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  # %%
27
  import os
28
  import gradio as gr
@@ -35,12 +63,15 @@ from torchvision import transforms as tvf
35
  from torchvision.transforms import functional as T
36
  from PIL import Image
37
  import matplotlib.pyplot as plt
 
38
  import distinctipy as dipy
 
39
  from typing import Literal, List
40
  import gradio as gr
41
  import time
42
  import glob
43
  import shutil
 
44
  from copy import deepcopy
45
  # DINOv2 imports
46
  from utilities import DinoV2ExtractFeatures
@@ -62,22 +93,33 @@ desc_facet: T1 = "value"
62
  num_c: int = 8
63
  cache_dir: str = _ex("./cache") # Directory containing program cache
64
  max_img_size: int = 1024 # Image resolution (max dim/size)
65
- max_num_imgs: int = 10 # Max number of images to upload
66
  share: bool = False # Share application using .gradio link
67
 
68
  # Verify inputs
69
  assert os.path.isdir(cache_dir), "Cache directory not found"
70
 
 
71
  # %%
72
  # Model and transforms
73
  print("Loading DINO model")
74
- # extractor = None
75
  extractor = DinoV2ExtractFeatures(dino_model, desc_layer, desc_facet,
76
  device=device)
77
  print("DINO model loaded")
78
  # VLAD path (directory)
79
  ext_s = f"{dino_model}/l{desc_layer}_{desc_facet}_c{num_c}"
80
  vc_dir = os.path.join(cache_dir, "vocabulary", ext_s)
 
 
 
 
 
 
 
 
 
 
81
  # Base image transformations
82
  base_tf = tvf.Compose([
83
  tvf.ToTensor(),
@@ -115,6 +157,9 @@ def get_descs(imgs_batch, pr = gr.Progress()):
115
  pr(0, desc="Extracting descriptors")
116
  patch_descs = []
117
  for i, img in enumerate(imgs_batch):
 
 
 
118
  # Convert to PIL image
119
  pil_img = Image.fromarray(img)
120
  img_pt = base_tf(pil_img).to(device)
@@ -139,6 +184,7 @@ def get_descs(imgs_batch, pr = gr.Progress()):
139
  ret = extractor(img_pt).cpu() # [1, n_p, d]
140
  patch_descs.append({"img": pil_img, "descs": ret})
141
  pr((i+1) / len(imgs_batch))
 
142
  return patch_descs, \
143
  f"Descriptors extracted for {len(imgs_batch)} images"
144
 
@@ -173,7 +219,10 @@ def assign_vlad(patch_descs, vlad, pr = gr.Progress()):
173
  def get_ca_images(desc_assignments, patch_descs, alpha,
174
  pr = gr.Progress()):
175
  if desc_assignments is None or len(desc_assignments) == 0:
176
- return None, "First load images"
 
 
 
177
  c_colors = dipy.get_colors(num_c, rng=928,
178
  colorblind_type="Deuteranomaly")
179
  np_colors = (np.array(c_colors) * 255).astype(np.uint8)
@@ -202,43 +251,177 @@ def get_ca_images(desc_assignments, patch_descs, alpha,
202
  return res_imgs, "Cluster assignment images generated"
203
 
204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  # %%
206
  print("Interface build started")
207
- # Build the interface
208
- with gr.Blocks() as demo:
209
- # ---- Helper functions ----
210
- # Variable number of input images
211
- def var_num_img(s):
212
- n = int(s) # Slider value as int
213
- return [gr.Image.update(label=f"Image {i+1}", visible=True) \
214
- for i in range(n)] + [gr.Image.update(visible=False) \
215
- for _ in range(max_num_imgs - n)]
216
-
217
- # ---- State declarations ----
218
- vlad = gr.State() # VLAD object
219
- desc_assignments = gr.State() # Cluster assignments
220
- imgs_batch = gr.State() # Images as batch
221
- patch_descs = gr.State() # Patch descriptors
222
-
223
- # ---- All UI elements ----
224
  d_vals = [k.title() for k in DOMAINS]
225
- domain = gr.Radio(d_vals, value=d_vals[0])
226
- nimg_s = gr.Slider(1, max_num_imgs, value=1, step=1,
227
- label="How many images?") # How many images?
 
 
228
  with gr.Row(): # Dynamic row (images in columns)
229
  imgs = [gr.Image(label=f"Image {i+1}", visible=True) \
230
- for i in range(nimg_s.value)] + \
231
  [gr.Image(visible=False) \
232
- for _ in range(max_num_imgs - nimg_s.value)]
233
  for i, img in enumerate(imgs): # Set image as "input"
234
  img.change(lambda _: None, img)
235
  with gr.Row(): # Dynamic row of output (cluster) images
236
  imgs2 = [gr.Image(label=f"VLAD Clusters {i+1}",
237
  visible=False) for i in range(max_num_imgs)]
238
- nimg_s.change(var_num_img, nimg_s, imgs)
239
- blend_alpha = gr.Slider(0, 1, 0.4, step=0.01, # Cluster centers
240
- label="Blend alpha (weight for cluster centers)")
 
 
 
 
241
  bttn1 = gr.Button("Click Me!") # Cluster assignment
 
242
  out_msg1 = gr.Markdown("Select domain and upload images")
243
  out_msg2 = gr.Markdown("For descriptor extraction")
244
  out_msg3 = gr.Markdown("Followed by VLAD assignment")
@@ -247,21 +430,40 @@ with gr.Blocks() as demo:
247
  # ---- Utility functions ----
248
  # A wrapper to batch the images
249
  def batch_images(data):
250
- sv = data[nimg_s]
251
  images: List[np.ndarray] = [data[imgs[k]] \
252
  for k in range(sv)]
253
  return images
254
  # A wrapper to unbatch images (and pad to max)
255
- def unbatch_images(imgs_batch):
256
  ret = [gr.Image.update(visible=False) \
257
  for _ in range(max_num_imgs)]
258
  if imgs_batch is None or len(imgs_batch) == 0:
259
  return ret
260
- for i, img_pil in enumerate(imgs_batch):
261
- img_np = np.array(img_pil)
 
 
 
262
  ret[i] = gr.Image.update(img_np, visible=True)
263
  return ret
264
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
  # ---- Main pipeline ----
266
  # Get the VLAD cluster assignment images on click
267
  bttn1.click(get_vlad_clusters, domain, [out_msg1, vlad])\
@@ -272,12 +474,113 @@ with gr.Blocks() as demo:
272
  .then(get_ca_images,
273
  [desc_assignments, patch_descs, blend_alpha],
274
  [imgs_batch, out_msg4])\
275
- .then(unbatch_images, imgs_batch, imgs2)
276
- # If the blending changes now, update the cluster images
277
- blend_alpha.change(get_ca_images,
278
  [desc_assignments, patch_descs, blend_alpha],
279
  [imgs_batch, out_msg4])\
280
- .then(unbatch_images, imgs_batch, imgs2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
 
282
  print("Interface build completed")
283
 
 
23
  - https://www.gradio.app/guides/blocks-and-event-listeners
24
  """
25
 
26
+ # A markdown string shown at the top of the app
27
+ header_markdown = """
28
+ # AnyLoc Demo
29
+
30
+ \| [Website](https://anyloc.github.io/) \| \
31
+ [GitHub](https://github.com/AnyLoc/AnyLoc) \| \
32
+ [YouTube](https://youtu.be/ITo8rMInatk) \|
33
+
34
+
35
+ This space contains a collection of demos for AnyLoc. Each demo is a \
36
+ self-contained application in the tabs below. The following \
37
+ applications are included
38
+
39
+ 1. **GeM t-SNE Projection**: Upload a set of images and see where \
40
+ they land on a t-SNE projection of GeM descriptors from many \
41
+ domains. This can be used to guide domain selection (from a few \
42
+ representative images).
43
+ 2. **Cluster Visualization**: This visualizes the VLAD cluster \
44
+ assignments for the patch descriptors. You need to select the \
45
+ domain for loading VLAD cluster centers (vocabulary).
46
+
47
+ We do **not** save any images uploaded to the demo. Some errors may \
48
+ leave a log. We do not collect any information about the user.
49
+
50
+ 🥳 Thanks to HuggingFace for providing a free GPU for this demo.
51
+
52
+ """
53
+
54
  # %%
55
  import os
56
  import gradio as gr
 
63
  from torchvision.transforms import functional as T
64
  from PIL import Image
65
  import matplotlib.pyplot as plt
66
+ from sklearn.manifold import TSNE
67
  import distinctipy as dipy
68
+ import joblib
69
  from typing import Literal, List
70
  import gradio as gr
71
  import time
72
  import glob
73
  import shutil
74
+ import matplotlib.pyplot as plt
75
  from copy import deepcopy
76
  # DINOv2 imports
77
  from utilities import DinoV2ExtractFeatures
 
93
  num_c: int = 8
94
  cache_dir: str = _ex("./cache") # Directory containing program cache
95
  max_img_size: int = 1024 # Image resolution (max dim/size)
96
+ max_num_imgs: int = 16 # Max number of images to upload
97
  share: bool = False # Share application using .gradio link
98
 
99
  # Verify inputs
100
  assert os.path.isdir(cache_dir), "Cache directory not found"
101
 
102
+
103
  # %%
104
  # Model and transforms
105
  print("Loading DINO model")
106
+ # extractor = None # FIXME: For quick testing only
107
  extractor = DinoV2ExtractFeatures(dino_model, desc_layer, desc_facet,
108
  device=device)
109
  print("DINO model loaded")
110
  # VLAD path (directory)
111
  ext_s = f"{dino_model}/l{desc_layer}_{desc_facet}_c{num_c}"
112
  vc_dir = os.path.join(cache_dir, "vocabulary", ext_s)
113
+ assert os.path.isdir(vc_dir), f"VLAD directory: {vc_dir} not found"
114
+ # GeM path (cache)
115
+ gem_cf = os.path.join(cache_dir, "gem_cache", "result_dino_v2.gz")
116
+ assert os.path.isfile(gem_cf), f"GeM cache: {gem_cf} not found"
117
+ gem_cache = joblib.load(gem_cf)
118
+ assert gem_cache["model"]["type"] == dino_model
119
+ assert gem_cache["model"]["layer"] == desc_layer
120
+ assert gem_cache["model"]["facet"] == desc_facet
121
+ fig = plt.figure() # Main figure
122
+ fig.clear()
123
  # Base image transformations
124
  base_tf = tvf.Compose([
125
  tvf.ToTensor(),
 
157
  pr(0, desc="Extracting descriptors")
158
  patch_descs = []
159
  for i, img in enumerate(imgs_batch):
160
+ if img is None:
161
+ print(f"Image {i+1} is None")
162
+ continue
163
  # Convert to PIL image
164
  pil_img = Image.fromarray(img)
165
  img_pt = base_tf(pil_img).to(device)
 
184
  ret = extractor(img_pt).cpu() # [1, n_p, d]
185
  patch_descs.append({"img": pil_img, "descs": ret})
186
  pr((i+1) / len(imgs_batch))
187
+ pr(1.0)
188
  return patch_descs, \
189
  f"Descriptors extracted for {len(imgs_batch)} images"
190
 
 
219
  def get_ca_images(desc_assignments, patch_descs, alpha,
220
  pr = gr.Progress()):
221
  if desc_assignments is None or len(desc_assignments) == 0:
222
+ if not 0 <= alpha <= 1:
223
+ return None, f"Invalid alpha value: {alpha} (should be "\
224
+ "between 0 and 1)"
225
+ return None, "First load the images"
226
  c_colors = dipy.get_colors(num_c, rng=928,
227
  colorblind_type="Deuteranomaly")
228
  np_colors = (np.array(c_colors) * 255).astype(np.uint8)
 
251
  return res_imgs, "Cluster assignment images generated"
252
 
253
 
254
+ # %%
255
+ # Get GeM descriptors from cache
256
# Get GeM descriptors from cache
def get_gem_descs_cache(use_d, pr = gr.Progress()):
    """
    Collect the cached GeM descriptors of every dataset that belongs
    to one of the selected domains.

    Parameters:
    - use_d:    Selected domain names (any case), e.g. ["Indoor"]
    - pr:       Gradio progress tracker

    Returns a tuple (status message, gem_descs). On success 'gem_descs'
    is a dict with
    - "labels":     One domain label (lower case) per descriptor
    - "descs":      All selected descriptors concatenated along axis 0
    It is None when no domain is selected or when no cached dataset
    matches the selection.
    """
    if not use_d:   # None or empty selection
        return "Select at least one domain", None
    selected = {d.lower() for d in use_d}
    # Dataset name (as keyed in the cache; case sensitive) -> domain
    ds_domain = {
        "baidu_datasets": "indoor",
        "gardens": "indoor",
        "17places": "indoor",
        "pitts30k": "urban",
        "st_lucia": "urban",
        "Oxford": "urban",
        "Tartan_GNSS_test_rotated": "aerial",
        "Tartan_GNSS_test_notrotated": "aerial",
        "VPAir": "aerial",
    }
    pr(0, desc="Loading GeM descriptors from cache")
    cached = gem_cache["data"]
    labels = []
    descs = []
    for i, ds in enumerate(cached):
        dom = ds_domain.get(ds)
        if dom is None or dom not in selected:
            continue    # Unknown dataset or domain not requested
        # GeM descriptors from data: n_desc, desc_dim
        d: np.ndarray = cached[ds]["descriptors"]
        labels.extend([dom] * d.shape[0])
        descs.append(d)
        pr((i+1) / len(cached))
    pr(1.0)
    if not descs:
        # np.concatenate raises ValueError on an empty list
        return "No cached GeM descriptors for the selected domains", \
                None
    gem_descs = {
        "labels": labels,
        "descs": np.concatenate(descs, axis=0),
    }
    return "GeM descriptors loaded from cache", gem_descs
287
+
288
+
289
+ # %%
290
+ # Get GeM pooled features of the uploaded images
291
# Get GeM pooled features of the uploaded images
def get_add_gem_descs(imgs_batch, gem_descs, pr = gr.Progress()):
    """
    Extract a GeM pooled descriptor for every uploaded image and
    append it (with an "Image<k>" label) to the cached descriptors.

    Parameters:
    - imgs_batch:   List of uploaded images (numpy arrays); entries
                    that are None are skipped. May be None.
    - gem_descs:    Dict with "labels" (list) and "descs" (array);
                    modified in place. May be None.
    - pr:           Gradio progress tracker

    Returns a tuple (gem_descs, status message). The returned dict
    additionally has "num_uimgs": the number of images for which a
    descriptor was added. If 'gem_descs' is None, (None, message) is
    returned.
    """
    if gem_descs is None:
        # Nothing to append to (e.g. no domain was selected)
        return None, "GeM cache not loaded; select at least one domain"
    if not imgs_batch:
        # Batching failed or no image was requested
        gem_descs["num_uimgs"] = 0
        return gem_descs, "No uploaded images to extract GeM "\
                "descriptors from"
    pr(0, desc="Extracting GeM descriptors")
    num_imgs_extracted = 0
    for i, img in enumerate(imgs_batch):
        if img is None:
            print(f"Image {i+1} is None")
            continue
        # Convert to PIL image
        pil_img = Image.fromarray(img)
        img_pt = base_tf(pil_img).to(device)
        if max(img_pt.shape[-2:]) > max_img_size:
            print(f"Image {i+1}: {img_pt.shape[-2:]}, outside")
            _, h, w = img_pt.shape
            # Maintain aspect ratio (longer side becomes max_img_size)
            if h == max(img_pt.shape[-2:]):
                w = int(w * max_img_size / h)
                h = max_img_size
            else:
                h = int(h * max_img_size / w)
                w = max_img_size
            img_pt = T.resize(img_pt, (h, w),
                    interpolation=T.InterpolationMode.BICUBIC)
        # Make image patchable (sides as multiples of 14)
        _, h, w = img_pt.shape
        h_new, w_new = (h // 14) * 14, (w // 14) * 14
        img_pt = tvf.CenterCrop((h_new, w_new))(img_pt)[None, ...]
        # Extract descriptors
        ret = extractor(img_pt).cpu() # [1, n_p, d]
        # Get the GeM pooled descriptor (p = 3): mean of cubes over
        # dim -2, then a sign-preserving cube root (through complex
        # numbers so that negative means do not give NaN)
        x = torch.mean(ret**3, dim=-2)
        g_res = x.to(torch.complex64) ** (1/3)
        g_res = torch.abs(g_res) * torch.sign(x) # [1, d]
        g_res = g_res.numpy()
        # Add to state
        gem_descs["labels"].append(f"Image{i+1}")
        gem_descs["descs"] = np.concatenate([gem_descs["descs"],
                g_res])
        num_imgs_extracted += 1
        pr((i+1) / len(imgs_batch))
    pr(1.0)
    gem_descs["num_uimgs"] = num_imgs_extracted
    return gem_descs, "GeM descriptors extracted"
336
+
337
+
338
+ # %%
339
+ # Apply tSNE to the GeM descriptors
340
# Apply tSNE to the GeM descriptors
def get_tsne_fm_gem(gem_descs, pr = gr.Progress()):
    """
    Project the GeM descriptors to 2D using tSNE.

    Parameters:
    - gem_descs:    Dict with "descs" (array passed to tSNE), "labels"
                    (one label per descriptor) and "num_uimgs"
    - pr:           Gradio progress tracker

    Returns a tuple (tsne_pts, status message) where 'tsne_pts' is a
    dict with "labels", "pts" (the 2D projection) and "num_uimgs"
    (copied over from the input).
    """
    pr(0, desc="Applying tSNE to GeM descriptors")
    descriptors: np.ndarray = gem_descs["descs"]    # [n, d_dim]
    labels: List[str] = gem_descs["labels"]         # [n]
    # Fixed seed so that the projection is repeatable
    projector = TSNE(n_components=2, random_state=30, perplexity=50,
            learning_rate=200, init='random')
    projected = projector.fit_transform(descriptors)
    result = dict(
        labels=labels,
        pts=projected,
        num_uimgs=gem_descs["num_uimgs"],   # Number of user imgs
    )
    pr(1.0)
    return result, "tSNE projection done"
356
+
357
+
358
+ # %%
359
+ # Plot tSNE to matplotlib figure
360
# Plot tSNE to matplotlib figure
def plot_tsne(tsne_pts):
    """
    Scatter plot the tSNE points on the module level figure 'fig'.

    Parameters:
    - tsne_pts:     Dict with "pts" (2D points), "labels" (one label
                    per point) and "num_uimgs" (number of user images)

    Domain points are drawn as circles in fixed colors. Points labelled
    "Image<k>" (k from 1 to num_uimgs) are drawn as crosses in colors
    generated to be distinct from the domain colors, black and white.

    Returns a tuple (figure, status message).
    """
    colors = {
        "aerial": (80/255, 0/255, 80/255),
        "indoor": ( 0/255, 76/255, 204/255),
        "urban": ( 0/255, 204/255, 0/255),
    }
    ni = int(tsne_pts["num_uimgs"])
    # Custom colors for user images. The exclusion list is built by
    # concatenation: list.extend() returns None, which would silently
    # drop the domain colors from the exclusion.
    exclude = list(colors.values()) + [(0, 0, 0), (1, 1, 1)]
    ucs = dipy.get_colors(ni, exclude_colors=exclude,
            colorblind_type="Deuteranomaly")
    for i in range(ni):
        colors[f"Image{i+1}"] = ucs[i]
    fig.clear()
    gs = fig.add_gridspec(1, 1)
    ax = fig.add_subplot(gs[0, 0])
    ax.set_title("tSNE Projection")
    labels = np.array(tsne_pts["labels"])   # Built once for masking
    for domain, color in colors.items():
        pts = tsne_pts["pts"][labels == domain]
        m = "x" if domain.startswith("Image") else "o"
        ax.scatter(pts[:, 0], pts[:, 1], label=domain, marker=m,
                color=color)
    ax.legend()
    ax.set_xticks([])
    ax.set_yticks([])
    fig.set_tight_layout(True)
    return fig, "tSNE plot created"
392
+
393
+
394
  # %%
395
  print("Interface build started")
396
+
397
+
398
+ # Tab for VLAD cluster assignment visualization
399
+ def tab_cluster_viz():
 
 
 
 
 
 
 
 
 
 
 
 
 
400
  d_vals = [k.title() for k in DOMAINS]
401
+ domain = gr.Radio(d_vals, value=d_vals[0], label="Domain",
402
+ info="The domain of images (for loading VLAD vocabulary)")
403
+ nimg_s = gr.Number(2, label="How many images?", precision=0,
404
+ info=f"Between '1' and '{max_num_imgs}' images. Press "\
405
+ "enter/return to register")
406
  with gr.Row(): # Dynamic row (images in columns)
407
  imgs = [gr.Image(label=f"Image {i+1}", visible=True) \
408
+ for i in range(int(nimg_s.value))] + \
409
  [gr.Image(visible=False) \
410
+ for _ in range(max_num_imgs - int(nimg_s.value))]
411
  for i, img in enumerate(imgs): # Set image as "input"
412
  img.change(lambda _: None, img)
413
  with gr.Row(): # Dynamic row of output (cluster) images
414
  imgs2 = [gr.Image(label=f"VLAD Clusters {i+1}",
415
  visible=False) for i in range(max_num_imgs)]
416
+ nimg_s.submit(var_num_img, nimg_s, imgs)
417
+ blend_alpha = gr.Number(0.4, label="Blending alpha",
418
+ info="Weight for cluster centers (between 0 and 1). "\
419
+ "Higher (close to 1) means greater emphasis on cluster "\
420
+ "visibility. Lower (closer to 0) will show the "\
421
+ "underlying image more. "\
422
+ "Press enter/return to register")
423
  bttn1 = gr.Button("Click Me!") # Cluster assignment
424
+ gr.Markdown("### Status strings")
425
  out_msg1 = gr.Markdown("Select domain and upload images")
426
  out_msg2 = gr.Markdown("For descriptor extraction")
427
  out_msg3 = gr.Markdown("Followed by VLAD assignment")
 
430
  # ---- Utility functions ----
431
  # A wrapper to batch the images
432
  def batch_images(data):
433
+ sv = int(data[nimg_s])
434
  images: List[np.ndarray] = [data[imgs[k]] \
435
  for k in range(sv)]
436
  return images
437
  # A wrapper to unbatch images (and pad to max)
438
+ def unbatch_images(imgs_batch, nimg):
439
  ret = [gr.Image.update(visible=False) \
440
  for _ in range(max_num_imgs)]
441
  if imgs_batch is None or len(imgs_batch) == 0:
442
  return ret
443
+ for i in range(nimg): # nimg only to match input layout
444
+ if i < len(imgs_batch):
445
+ img_np = np.array(imgs_batch[i])
446
+ else:
447
+ img_np = None
448
  ret[i] = gr.Image.update(img_np, visible=True)
449
  return ret
450
 
451
+ # ---- Examples ----
452
+ # Two images from each domain
453
+ gr.Examples(
454
+ [
455
+ ["Aerial", 2,
456
+ "ex_aerial_nardo-air_db-42.png",
457
+ "ex_aerial_nardo-air_qu-42.png",],
458
+ ["Indoor", 2,
459
+ "ex_indoor_17places_db-75.jpg",
460
+ "ex_indoor_17places_qu-75.jpg"],
461
+ ["Urban", 2,
462
+ "ex_urban_oxford_db-75.png",
463
+ "ex_urban_oxford_qu-75.png"],],
464
+ [domain, nimg_s, *imgs],
465
+ )
466
+
467
  # ---- Main pipeline ----
468
  # Get the VLAD cluster assignment images on click
469
  bttn1.click(get_vlad_clusters, domain, [out_msg1, vlad])\
 
474
  .then(get_ca_images,
475
  [desc_assignments, patch_descs, blend_alpha],
476
  [imgs_batch, out_msg4])\
477
+ .then(unbatch_images, [imgs_batch, nimg_s], imgs2)
478
+ # If the blending changes now, update the cluster images only
479
+ blend_alpha.submit(get_ca_images,
480
  [desc_assignments, patch_descs, blend_alpha],
481
  [imgs_batch, out_msg4])\
482
+ .then(unbatch_images, [imgs_batch, nimg_s], imgs2)
483
+
484
+
485
+ # Tab for GeM t-SNE projection plot
486
# Tab for GeM t-SNE projection plot
def tab_gem_tsne():
    """
    Create the UI elements and the event pipeline for the GeM t-SNE
    projection tab. It is called inside a 'gr.Tab' context and uses
    the module level 'var_num_img' helper and the 'gem_descs',
    'imgs_batch' and 'tsne_pts' states.
    """
    d_vals = [k.title() for k in DOMAINS]
    dms = gr.CheckboxGroup(d_vals, value=d_vals, label="Domains",
            info="The domains to use for the t-SNE projection")
    nimg_s = gr.Number(2, label="How many images?", precision=0,
            info=f"Between '1' and '{max_num_imgs}' images. Press "\
                "enter/return to register")
    n_init = int(nimg_s.value)  # Images visible at startup
    with gr.Row():  # Dynamic row (images in columns)
        imgs = [gr.Image(label=f"Image {i+1}", visible=True) \
                    for i in range(n_init)] + \
                [gr.Image(visible=False) \
                    for _ in range(max_num_imgs - n_init)]
        for img in imgs:    # Set image as "input"
            img.change(lambda _: None, img)
    nimg_s.submit(var_num_img, nimg_s, imgs)
    tsne_plot = gr.Plot(None, label="tSNE Plot")
    out_msg1 = gr.Markdown("Select domains")
    out_msg2 = gr.Markdown("Upload images")
    out_msg3 = gr.Markdown("Wait for tSNE plots")

    # A wrapper to batch the images
    def batch_images(data):
        """
        Gather the first 'nimg_s' uploaded images in a list. 'data'
        is the dict (component -> value) that Gradio passes when the
        inputs are given as a set.

        Returns (images, message); 'images' is None if the count is
        missing or too large, or if a requested image is not uploaded.
        """
        raw = data[nimg_s]
        if raw is None:     # Number field was cleared
            return None, "Enter the number of images"
        sv = int(raw)
        if sv > max_num_imgs:   # Would index past the image array
            return None, f"Too many images: {sv} (maximum is "\
                    f"{max_num_imgs})"
        images: List[np.ndarray] = []
        for k in range(sv):
            img = data[imgs[k]]
            if img is None:
                return None, f"Image {k+1} is None!"
            images.append(img)
        return images, "Images batched"

    bttn1 = gr.Button("Click Me!")

    # ---- Main pipeline ----
    # Get the tSNE plot
    bttn1.click(get_gem_descs_cache, dms, [out_msg1, gem_descs])\
        .then(batch_images, {nimg_s, *imgs},
                [imgs_batch, out_msg2])\
        .then(get_add_gem_descs, [imgs_batch, gem_descs],
                [gem_descs, out_msg2])\
        .then(get_tsne_fm_gem, gem_descs, [tsne_pts, out_msg3])\
        .then(plot_tsne, tsne_pts, [tsne_plot, out_msg3])
530
+
531
+
532
+ # Build the interface
533
+ with gr.Blocks() as demo:
534
+ # Main header
535
+ gr.Markdown(header_markdown)
536
+
537
+ # ---- Helper functions ----
538
+ # Variable number of input images (show/hide UI image array)
539
def var_num_img(s):
    """
    Show the first 's' images of the UI image array and hide the rest.

    Parameters:
    - s:    Number of images to show (value of the number field)

    Returns a list of 'max_num_imgs' image updates. Raises 'gr.Error'
    if 's' is missing or not between 1 and 'max_num_imgs'.
    """
    if s is None:   # Number field was cleared
        raise gr.Error("Enter the number of images")
    n = int(s)  # Number field value as int
    # Explicit check instead of 'assert' (asserts are removed with -O)
    if not 1 <= n <= max_num_imgs:
        raise gr.Error(f"Invalid num of images: {n}!")
    return [gr.Image.update(label=f"Image {i+1}", visible=True) \
                for i in range(n)] \
            + [gr.Image.update(visible=False) \
                for _ in range(max_num_imgs - n)]
546
+
547
+ # ---- State declarations ----
548
+ vlad = gr.State() # VLAD object
549
+ desc_assignments = gr.State() # Cluster assignments
550
+ imgs_batch = gr.State() # Images as batch
551
+ patch_descs = gr.State() # Patch descriptors
552
+ gem_descs = gr.State() # GeM descriptors (of each state)
553
+ tsne_pts = gr.State() # tSNE points
554
+
555
+ # ---- All UI elements ----
556
+ with gr.Tab("GeM t-SNE Projection"):
557
+ gr.Markdown(
558
+ """
559
+ ## GeM t-SNE Projection
560
+
561
+ Select the domains (toggle visibility) for t-SNE plot. \
562
+ Enter the number of images to upload and upload images. \
563
+ Then click the button to get the t-SNE plot.
564
+
565
+ """)
566
+ tab_gem_tsne()
567
+
568
+ with gr.Tab("Cluster Visualization"):
569
+ gr.Markdown(
570
+ """
571
+ ## Cluster Visualizations
572
+
573
+ Select the domain for the images (all should be from the \
574
+ same domain). Enter the number of images to upload. \
575
+ Upload the images. Then click the button to get the \
576
+ cluster assignment images.
577
+
578
+ You can also directly click on one of the examples (at \
579
+ the bottom) to load the data and then click the button \
580
+ to get the cluster assignment images.
581
+
582
+ """)
583
+ tab_cluster_viz()
584
 
585
  print("Interface build completed")
586
 
cache/gem_cache/result_dino_v2.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5d27de9e8552eeeca6fd3f99fca135429169adcf926388d65eb44bb1ba9391f5
3
+ size 1990740
cache/gem_cache/result_dino_v2_tsne.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:77d978ec8fcb5f10075db999d037925846f2cbad522e5d0266ddcbfb50cd99de
3
+ size 3192
requirements.txt CHANGED
@@ -8,4 +8,5 @@ matplotlib
8
  distinctipy
9
  einops
10
  fast_pytorch_kmeans
11
-
 
 
8
  distinctipy
9
  einops
10
  fast_pytorch_kmeans
11
+ joblib==1.2.0
12
+ scikit-learn==1.0.2