File size: 22,118 Bytes
61364af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
476e478
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61364af
 
 
 
 
 
 
 
 
 
 
 
476e478
61364af
476e478
61364af
 
 
 
 
476e478
61364af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
476e478
61364af
 
 
 
 
476e478
61364af
 
 
476e478
1c2f636
 
61364af
 
 
 
476e478
 
 
 
 
 
 
 
 
 
61364af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
476e478
 
 
61364af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
476e478
61364af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
476e478
 
 
 
61364af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
476e478
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61364af
 
476e478
 
 
 
61364af
476e478
 
 
 
 
61364af
 
476e478
61364af
476e478
61364af
 
 
 
 
476e478
 
 
 
 
 
 
61364af
476e478
61364af
 
 
 
 
 
 
 
476e478
61364af
 
 
 
476e478
61364af
 
 
 
476e478
 
 
 
 
61364af
 
 
476e478
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61364af
 
 
 
 
 
 
 
 
 
476e478
 
 
61364af
 
476e478
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61364af
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
# Show VLAD clustering for a set of example images or a user image
"""
    User input:
    - Domain: Indoor, Aerial, or Urban
    - Image: Image to be clustered
    - Cluster numbers (to visualize)
    - Pixel coordinates (to pick further clusters)
    - A unique cache ID (to store the DINO forward passes)
    
    There are example images for each domain.
    
    Output:
    - All images with cluster assignments
    
    Some Gradio links:
    - Controlling layout
        - https://www.gradio.app/guides/quickstart#blocks-more-flexibility-and-control
    - Data state (persistence)
        - https://www.gradio.app/guides/interface-state
        - https://www.gradio.app/docs/state
    - Layout control
        - https://www.gradio.app/guides/controlling-layout
        - https://www.gradio.app/guides/blocks-and-event-listeners
"""

# Markdown string shown at the top of the app.
# Inside the literal, a backslash at the end of a source line is a line
# continuation (it joins the wrapped source lines into one Markdown line).
# The pipes are written as "\\|" so the runtime text still contains a
# literal backslash-pipe (an escaped pipe in Markdown) without relying on
# Python passing the unrecognised escape "\|" through unchanged, which is
# deprecated and emits a SyntaxWarning on recent interpreters.
header_markdown = """
# AnyLoc Demo

\\|  [Website](https://anyloc.github.io/) \\| \
    [GitHub](https://github.com/AnyLoc/AnyLoc) \\| \
    [YouTube](https://youtu.be/ITo8rMInatk) \\|


This space contains a collection of demos for AnyLoc. Each demo is a \
self-contained application in the tabs below. The following \
applications are included

1. **GeM t-SNE Projection**: Upload a set of images and see where \
    they land on a t-SNE projection of GeM descriptors from many \
    domains. This can be used to guide domain selection (from a few \
    representative images).
2. **Cluster Visualization**: This visualizes the VLAD cluster \
    assignments for the patch descriptors. You need to select the \
    domain for loading VLAD cluster centers (vocabulary).

We do **not** save any images uploaded to the demo. Some errors may \
leave a log. We do not collect any information about the user.

🥳 Thanks to HuggingFace for providing a free GPU for this demo.

"""

# %%
import os
import gradio as gr
import numpy as np
import cv2 as cv
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms as tvf
from torchvision.transforms import functional as T
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import distinctipy as dipy
import joblib
from typing import Literal, List
import gradio as gr
import time
import glob
import shutil
import matplotlib.pyplot as plt
from copy import deepcopy
# DINOv2 imports
from utilities import DinoV2ExtractFeatures
from utilities import VLAD

# Run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# %%
# Configurations
# Type aliases (documentation only; they are not enforced at runtime)
T1 = Literal["query", "key", "value", "token"]  # Descriptor facet
T2 = Literal["aerial", "indoor", "urban"]       # Image domain
DOMAINS = ["aerial", "indoor", "urban"]         # Same values, as a list
T3 = Literal["dinov2_vits14", "dinov2_vitb14", "dinov2_vitl14", 
                "dinov2_vitg14"]                # DINOv2 model name
# Expand "~" and resolve a path to its real absolute location
_ex = lambda x: os.path.realpath(os.path.expanduser(x))
# Model, layer and facet passed to DinoV2ExtractFeatures below; they also
# form the vocabulary sub-directory name and must match the GeM cache
dino_model: T3 = "dinov2_vitg14"
desc_layer: int = 31
desc_facet: T1 = "value"
# Number of VLAD clusters (part of the vocabulary sub-directory name and
# the number of colors drawn in the cluster images)
num_c: int = 8
cache_dir: str = _ex("./cache") # Directory containing program cache
max_img_size: int = 1024    # Image resolution (max dim/size)
max_num_imgs: int = 16      # Max number of images to upload
share: bool = False          # Share application using .gradio link

# Verify inputs
# NOTE(review): assert statements are skipped when Python runs with -O
assert os.path.isdir(cache_dir), "Cache directory not found"


# %%
# Model and transforms
# Everything in this cell runs once, at import time.
print("Loading DINO model")
# extractor = None  # FIXME: For quick testing only
# Shared feature extractor used by every descriptor function below
extractor = DinoV2ExtractFeatures(dino_model, desc_layer, desc_facet, 
                                    device=device)
print("DINO model loaded")
# VLAD path (directory)
# Layout: <cache_dir>/vocabulary/<model>/l<layer>_<facet>_c<num_c>/
ext_s = f"{dino_model}/l{desc_layer}_{desc_facet}_c{num_c}"
vc_dir = os.path.join(cache_dir, "vocabulary", ext_s)
assert os.path.isdir(vc_dir), f"VLAD directory: {vc_dir} not found"
# GeM path (cache)
gem_cf = os.path.join(cache_dir, "gem_cache", "result_dino_v2.gz")
assert os.path.isfile(gem_cf), f"GeM cache: {gem_cf} not found"
# NOTE(review): joblib.load unpickles the file; only load trusted caches
gem_cache = joblib.load(gem_cf)
# The cached descriptors must come from the same model/layer/facet that
# is used for the uploaded images, otherwise they are not comparable
assert gem_cache["model"]["type"] == dino_model
assert gem_cache["model"]["layer"] == desc_layer
assert gem_cache["model"]["facet"] == desc_facet
# Single module-level figure that plot_tsne() clears and redraws
fig = plt.figure()  # Main figure
fig.clear()
# Base image transformations
# Convert to a tensor, then normalize each channel with the mean/std
# values below (the commonly used ImageNet statistics)
base_tf = tvf.Compose([
    tvf.ToTensor(),
    tvf.Normalize(mean=[0.485, 0.456, 0.406], 
                    std=[0.229, 0.224, 0.225])
])


# %%
# Get VLAD object
def get_vlad_clusters(domain, pr = gr.Progress()):
    """
        Load the VLAD vocabulary (cluster centers) of one domain.
        
        Parameters:
        - domain:   Domain name; compared case-insensitively against
                    DOMAINS
        - pr:       Gradio progress tracker
        
        Returns (in this order):
        - A status string for the UI
        - The restored VLAD object, or None if the cluster center file
            for the domain does not exist
        
        Raises AssertionError if the domain is not in DOMAINS.
    """
    dm: T2 = str(domain).lower()
    assert dm in DOMAINS, "Invalid domain"
    pr(0, desc="Loading VLAD clusters")
    # The vocabulary of each domain lives in its own sub-directory
    centers_path = os.path.join(vc_dir, dm, "c_centers.pt")
    if not os.path.isfile(centers_path):
        return f"Cluster centers not found for: {domain}", None
    # The centers are only read for their shape: [n_clusters, desc_dim]
    centers = torch.load(centers_path)
    pr(0.5)
    n_clusters, d_dim = centers.shape[0], centers.shape[1]
    vocab_dir = os.path.dirname(centers_path)
    vlad = VLAD(n_clusters, d_dim, cache_dir=vocab_dir)
    # Fitting on None makes VLAD restore itself from its cache directory
    vlad.fit(None)
    pr(1)
    return f"VLAD clusters loaded for: {domain}", vlad


# %%
# Get VLAD descriptors
@torch.no_grad()
def get_descs(imgs_batch, pr = gr.Progress()):
    """
        Extract the patch descriptors of every image in a batch.
        
        Parameters:
        - imgs_batch:   List of images as numpy arrays (entries that are
                        None are skipped with a console message)
        - pr:           Gradio progress tracker
        
        Returns (in this order):
        - A list with one dict per processed image. Key "img" is the PIL
            image (downsized if it was too large) and key "descs" is the
            extractor output on the CPU
        - A status string with the number of images actually processed
    """
    imgs_batch: List[np.ndarray] = imgs_batch
    pr(0, desc="Extracting descriptors")
    patch_descs = []
    for i, img in enumerate(imgs_batch):
        if img is None:
            print(f"Image {i+1} is None")
            continue
        # Convert to PIL image
        pil_img = Image.fromarray(img)
        img_pt = base_tf(pil_img).to(device)
        if max(img_pt.shape[-2:]) > max_img_size:
            print(f"Image {i+1}: {img_pt.shape[-2:]}, outside")
            c, h, w = img_pt.shape
            # Scale the longer side to max_img_size, keep aspect ratio
            if h == max(img_pt.shape[-2:]):
                w = int(w * max_img_size / h)
                h = max_img_size
            else:
                h = int(h * max_img_size / w)
                w = max_img_size
            img_pt = T.resize(img_pt, (h, w), 
                interpolation=T.InterpolationMode.BICUBIC)
            # Keep the stored PIL image at the same size as the tensor
            pil_img = pil_img.resize((w, h))
        # Make image patchable (sides must be multiples of 14)
        c, h, w = img_pt.shape
        h_new, w_new = (h // 14) * 14, (w // 14) * 14
        img_pt = tvf.CenterCrop((h_new, w_new))(img_pt)[None, ...]
        # Extract descriptors
        ret = extractor(img_pt).cpu()  # [1, n_p, d]
        patch_descs.append({"img": pil_img, "descs": ret})
        pr((i+1) / len(imgs_batch))
    pr(1.0)
    # Count the processed images: None entries were skipped above, so
    # len(imgs_batch) would over-report
    return patch_descs, \
            f"Descriptors extracted for {len(patch_descs)} images"


# %%
# Assign VLAD clusters (descriptor assignment)
def assign_vlad(patch_descs, vlad, pr = gr.Progress()):
    """
        Assign every patch descriptor to its closest VLAD cluster.
        
        Parameters:
        - patch_descs:  List of dicts with keys "img" (PIL image) and
                        "descs" (patch descriptors of that image)
        - vlad:         The VLAD object holding the cluster centers
        - pr:           Gradio progress tracker
        
        Returns (in this order):
        - A list with one integer tensor per image. It holds the cluster
            index of every pixel (patch assignments upsampled by 14)
        - A status string
        
        Raises AssertionError if the number of residual vectors does not
        match the number of 14x14 patches of the stored image.
    """
    vlad: VLAD = vlad
    descs_per_img = [entry["descs"] for entry in patch_descs]
    pr(0, desc="Assigning VLAD clusters")
    n_imgs = len(descs_per_img)
    desc_assignments = []   # List[Tensor;shape=('h', 'w');int]
    for idx, (entry, descs) in enumerate(zip(patch_descs, 
            descs_per_img), start=1):
        # Residuals to every cluster center; 'n' varies with image size
        residuals = vlad.generate_res_vec(descs[0]) # ['n', n_c, d]
        h, w, _ = np.array(entry["img"]).shape
        rows, cols = h // 14, w // 14   # Patch grid
        assert rows * cols == residuals.shape[0], "Residual incorrect!"
        # Closest cluster = smallest summed absolute residual
        l1_dists = residuals.abs().sum(dim=2)
        patch_ids = l1_dists.argmin(dim=1).reshape(rows, cols)
        # Nearest-neighbor upsampling from the patch grid to pixels
        upsampled = F.interpolate(patch_ids[None, None, ...].to(float),
                (rows * 14, cols * 14), mode="nearest")
        desc_assignments.append(upsampled[0, 0].to(patch_ids.dtype))
        pr(idx / n_imgs)
    pr(1.0)
    return desc_assignments, "VLAD clusters assigned"


# %%
# Cluster assignments to images
def get_ca_images(desc_assignments, patch_descs, alpha,
            pr = gr.Progress()):
    """
        Blend the cluster assignments (as colors) over the images.
        
        Parameters:
        - desc_assignments: List of per-pixel cluster index tensors, one
                            per image (None or empty if nothing has been
                            processed yet)
        - patch_descs:      List of dicts; key "img" is the PIL image
        - alpha:            Weight of the cluster colors, in [0, 1]. The
                            underlying image gets weight (1 - alpha)
        - pr:               Gradio progress tracker
        
        Returns (in this order):
        - A list of blended PIL images, or None if the inputs are invalid
        - A status string
    """
    # Validate alpha on its own: it must be rejected even when the
    # assignments are already available (None = cleared number box)
    if alpha is None or not 0 <= alpha <= 1:
        return None, f"Invalid alpha value: {alpha} (should be "\
                "between 0 and 1)"
    if desc_assignments is None or len(desc_assignments) == 0:
        return None, "First load the images"
    # Fixed seed, so a cluster gets the same color on every call
    c_colors = dipy.get_colors(num_c, rng=928, 
            colorblind_type="Deuteranomaly")
    np_colors = (np.array(c_colors) * 255).astype(np.uint8)
    # Get images with clusters
    pil_imgs = [pd["img"] for pd in patch_descs]
    res_imgs = []   # List[PIL.Image]
    pr(0, desc="Generating cluster assignment images")
    for i, pil_img in enumerate(pil_imgs):
        # Descriptor assignment image: [h, w, 3]
        da: torch.Tensor = desc_assignments[i]    # ['h', 'w']
        da_img = np.zeros((*da.shape, 3), dtype=np.uint8)
        for c in range(num_c):
            da_img[da == c] = np_colors[c]
        # Background image: [h, w, 3]
        img_np = np.array(pil_img, dtype=np.uint8)
        h, w, _ = img_np.shape
        # Resize to the multiple-of-14 size of the assignment image
        h_new, w_new = (h // 14) * 14, (w // 14) * 14
        img_np = F.interpolate(torch.tensor(img_np)\
                .permute(2, 0, 1)[None, ...], (h_new, w_new),
                mode='nearest')[0].permute(1, 2, 0).numpy()
        res_img = cv.addWeighted(img_np, 1 - alpha, da_img, alpha, 0.)
        res_imgs.append(Image.fromarray(res_img))
        pr((i+1) / len(pil_imgs))
    pr(1.0)
    return res_imgs, "Cluster assignment images generated"


# %%
# Get GeM descriptors from cache
def get_gem_descs_cache(use_d, pr = gr.Progress()):
    """
        Collect the cached GeM descriptors of the selected domains.
        
        Parameters:
        - use_d:    List of selected domain names (any letter case)
        - pr:       Gradio progress tracker
        
        Returns (in this order):
        - A status string
        - A dict with "labels" (one domain name per descriptor) and
            "descs" (all descriptors stacked along axis 0), or None if
            no domain was selected
    """
    use_d: List[str] = use_d
    if len(use_d) == 0:
        return "Select at least one domain", None
    selected = [name.lower() for name in use_d]
    # Dataset name (as stored in the cache) -> domain label
    domain_of = {}
    for label, names in (
            ("indoor", ["baidu_datasets", "gardens", "17places"]),
            ("urban", ["pitts30k", "st_lucia", "Oxford"]),
            ("aerial", ["Tartan_GNSS_test_rotated", 
                "Tartan_GNSS_test_notrotated", "VPAir"])):
        for ds_name in names:
            domain_of[ds_name] = label
    pr(0, desc="Loading GeM descriptors from cache")
    labels, chunks = [], []
    n_datasets = len(gem_cache["data"])
    for idx, ds in enumerate(gem_cache["data"], start=1):
        # GeM descriptors from data: n_desc, desc_dim
        d: np.ndarray = gem_cache["data"][ds]["descriptors"]
        label = domain_of.get(ds)
        # Skip unknown datasets and datasets of unselected domains
        if label is None or label not in selected:
            continue
        labels.extend([label] * d.shape[0])
        chunks.append(d)
        pr(idx / n_datasets)
    gem_descs = {
        "labels": labels,
        "descs": np.concatenate(chunks, axis=0),
    }
    pr(1.0)
    return "GeM descriptors loaded from cache", gem_descs


# %%
# Get GeM pooled features of the uploaded images
def get_add_gem_descs(imgs_batch, gem_descs, pr = gr.Progress()):
    """
        Append the GeM descriptors of the uploaded images to gem_descs.
        
        Parameters:
        - imgs_batch:   List of images as numpy arrays (entries that are
                        None are skipped), or None if batching failed
        - gem_descs:    Dict with "labels" (list) and "descs" (array of
                        descriptors stacked along axis 0), or None if no
                        cached descriptors were loaded. It is modified
                        in place
        - pr:           Gradio progress tracker
        
        Returns (in this order):
        - gem_descs with the new labels ("Image<k>") and descriptors
            appended and "num_uimgs" set to the number of images added;
            None if gem_descs was None
        - A status string
    """
    imgs_batch: List[np.ndarray] = imgs_batch
    gem_descs: dict = gem_descs
    # Earlier pipeline steps signal failure by handing over None; report
    # it instead of raising a TypeError further down
    if gem_descs is None:
        return None, "GeM descriptors not loaded (select at least one "\
                "domain)"
    if imgs_batch is None:
        gem_descs["num_uimgs"] = 0
        return gem_descs, "No user images added (upload every "\
                "requested image first)"
    pr(0, desc="Extracting GeM descriptors")
    new_descs = []  # One [1, d] array per processed image
    for i, img in enumerate(imgs_batch):
        if img is None:
            print(f"Image {i+1} is None")
            continue
        # Convert to PIL image
        pil_img = Image.fromarray(img)
        img_pt = base_tf(pil_img).to(device)
        if max(img_pt.shape[-2:]) > max_img_size:
            print(f"Image {i+1}: {img_pt.shape[-2:]}, outside")
            c, h, w = img_pt.shape
            # Scale the longer side to max_img_size, keep aspect ratio
            if h == max(img_pt.shape[-2:]):
                w = int(w * max_img_size / h)
                h = max_img_size
            else:
                h = int(h * max_img_size / w)
                w = max_img_size
            img_pt = T.resize(img_pt, (h, w), 
                interpolation=T.InterpolationMode.BICUBIC)
        # Make image patchable (sides must be multiples of 14)
        c, h, w = img_pt.shape
        h_new, w_new = (h // 14) * 14, (w // 14) * 14
        img_pt = tvf.CenterCrop((h_new, w_new))(img_pt)[None, ...]
        # Extract descriptors
        ret = extractor(img_pt).cpu()  # [1, n_p, d]
        # GeM pooling with p = 3: signed cube root of the mean cube. The
        # root is taken in the complex domain so that negative means do
        # not give NaN; the sign is restored afterwards
        x = torch.mean(ret**3, dim=-2)
        g_res = x.to(torch.complex64) ** (1/3)
        g_res = torch.abs(g_res) * torch.sign(x)    # [1, d]
        new_descs.append(g_res.numpy())
        gem_descs["labels"].append(f"Image{i+1}")
        pr((i+1) / len(imgs_batch))
    # Concatenate once instead of copying the full matrix per image
    if new_descs:
        gem_descs["descs"] = np.concatenate([gem_descs["descs"], 
                                            *new_descs])
    pr(1.0)
    gem_descs["num_uimgs"] = len(new_descs)
    return gem_descs, "GeM descriptors extracted"


# %%
# Apply tSNE to the GeM descriptors
def get_tsne_fm_gem(gem_descs, pr = gr.Progress()):
    """
        Project all GeM descriptors to 2D using tSNE.
        
        Parameters:
        - gem_descs:    Dict with "descs" (array, [n, d_dim]), "labels"
                        (list of n strings) and "num_uimgs"
        - pr:           Gradio progress tracker
        
        Returns (in this order):
        - A dict with "labels" and "num_uimgs" (both passed through) and
            "pts" (the 2D projection of the descriptors)
        - A status string
    """
    pr(0, desc="Applying tSNE to GeM descriptors")
    descriptors: np.ndarray = gem_descs["descs"]    # [n, d_dim]
    labels: List[str] = gem_descs["labels"]         # [n]
    # Fixed random_state, so the same input gives the same projection
    projector = TSNE(n_components=2, random_state=30, perplexity=50, 
            learning_rate=200, init='random')
    pts_2d = projector.fit_transform(descriptors)
    n_user_imgs = gem_descs["num_uimgs"]    # Number of user imgs
    pr(1.0)
    return {
        "labels": labels,
        "pts": pts_2d,
        "num_uimgs": n_user_imgs,
    }, "tSNE projection done"


# %%
# Plot tSNE to matplotlib figure
def plot_tsne(tsne_pts):
    """
        Draw the tSNE points on the module-level figure.
        
        Parameters:
        - tsne_pts: Dict with "pts" (array, [n, 2]), "labels" (list of n
                    strings; a domain name or "Image<k>") and "num_uimgs"
                    (number of user images)
        
        Returns (in this order):
        - The module-level figure `fig` (cleared and redrawn)
        - A status string
    """
    colors = {
        "aerial": (80/255,  0/255,  80/255),
        "indoor": ( 0/255, 76/255, 204/255),
        "urban":  ( 0/255, 204/255,  0/255),
    }
    ni = int(tsne_pts["num_uimgs"])
    # Custom colors for user images. They should differ from the domain
    # colors and from black and white. The list is built with "+" since
    # list.extend() returns None (which would exclude nothing custom)
    avoid_colors = list(colors.values()) + [(0, 0, 0), (1, 1, 1)]
    ucs = dipy.get_colors(ni, exclude_colors=avoid_colors, 
            colorblind_type="Deuteranomaly")
    for i in range(ni):
        colors[f"Image{i+1}"] = ucs[i]
    fig.clear()
    gs = fig.add_gridspec(1, 1)
    ax = fig.add_subplot(gs[0, 0])
    ax.set_title("tSNE Projection")
    # Convert the labels once; they are compared against every group
    labels = np.array(tsne_pts["labels"])
    for name, color in colors.items():
        pts = tsne_pts["pts"][labels == name]
        # Crosses for user images, dots for the cached domains
        m = "x" if name.startswith("Image") else "o"
        ax.scatter(pts[:, 0], pts[:, 1], label=name, marker=m,
                color=color)
    ax.legend()
    # The tSNE axes have no physical meaning, so hide the ticks
    ax.set_xticks([])
    ax.set_yticks([])
    fig.set_tight_layout(True)
    return fig, "tSNE plot created"


# %%
# Console marker: the functions and the Blocks context below build the UI
print("Interface build started")


# Tab for VLAD cluster assignment visualization
def tab_cluster_viz():
    """
        Create the widgets and the event pipeline of the cluster
        visualization tab.
        
        The function takes no arguments and returns nothing. It reads
        the module-level names `var_num_img`, `vlad`, `imgs_batch`,
        `patch_descs` and `desc_assignments`, which are created inside
        the `gr.Blocks` context, so it has to be called from within that
        context after they exist.
    """
    d_vals = [k.title() for k in DOMAINS]
    domain = gr.Radio(d_vals, value=d_vals[0], label="Domain",
            info="The domain of images (for loading VLAD vocabulary)")
    nimg_s = gr.Number(2, label="How many images?", precision=0,
            info=f"Between '1' and '{max_num_imgs}' images. Press "\
                    "enter/return to register")
    # All max_num_imgs image widgets are created up front; only their
    # visibility changes later (see var_num_img)
    with gr.Row():  # Dynamic row (images in columns)
        imgs = [gr.Image(label=f"Image {i+1}", visible=True) \
                for i in range(int(nimg_s.value))] + \
                [gr.Image(visible=False) \
                for _ in range(max_num_imgs - int(nimg_s.value))]
        for i, img in enumerate(imgs):  # Set image as "input"
            img.change(lambda _: None, img)
    with gr.Row():  # Dynamic row of output (cluster) images
        imgs2 = [gr.Image(label=f"VLAD Clusters {i+1}", 
                visible=False) for i in range(max_num_imgs)]
    # Show/hide the input image widgets when the count is submitted
    nimg_s.submit(var_num_img, nimg_s, imgs)
    blend_alpha = gr.Number(0.4, label="Blending alpha",
        info="Weight for cluster centers (between 0 and 1). "\
            "Higher (close to 1) means greater emphasis on cluster "\
                "visibility. Lower (closer to 0) will show the "\
                "underlying image more. "\
            "Press enter/return to register")
    bttn1 = gr.Button("Click Me!")  # Cluster assignment
    # One status line per pipeline stage (updated by the stages below)
    gr.Markdown("### Status strings")
    out_msg1 = gr.Markdown("Select domain and upload images")
    out_msg2 = gr.Markdown("For descriptor extraction")
    out_msg3 = gr.Markdown("Followed by VLAD assignment")
    out_msg4 = gr.Markdown("Followed by cluster images")
    
    # ---- Utility functions ----
    # A wrapper to batch the images
    # `data` is indexed by component; presumably Gradio passes a dict
    # because the inputs are given as a set (see the click chain below)
    def batch_images(data):
        sv = int(data[nimg_s])
        # Only the first `sv` widgets are read; entries may be None
        images: List[np.ndarray] = [data[imgs[k]] \
                for k in range(sv)]
        return images
    # A wrapper to unbatch images (and pad to max)
    # Returns one update per output widget (always max_num_imgs updates)
    def unbatch_images(imgs_batch, nimg):
        ret = [gr.Image.update(visible=False) \
                for _ in range(max_num_imgs)]
        if imgs_batch is None or len(imgs_batch) == 0:
            return ret
        for i in range(nimg):   # nimg only to match input layout
            if i < len(imgs_batch):
                img_np = np.array(imgs_batch[i])
            else:
                img_np = None   # Visible, but empty
            ret[i] = gr.Image.update(img_np, visible=True)
        return ret
    
    # ---- Examples ----
    # Two images from each domain
    # NOTE(review): each row has 4 values, but 2 + max_num_imgs input
    # components are listed; confirm how Gradio handles the remainder
    gr.Examples(
        [
        ["Aerial", 2, 
            "ex_aerial_nardo-air_db-42.png",
            "ex_aerial_nardo-air_qu-42.png",],
        ["Indoor", 2,
            "ex_indoor_17places_db-75.jpg",
            "ex_indoor_17places_qu-75.jpg"],
        ["Urban", 2,
            "ex_urban_oxford_db-75.png",
            "ex_urban_oxford_qu-75.png"],],
        [domain, nimg_s, *imgs],
    )
    
    # ---- Main pipeline ----
    # Get the VLAD cluster assignment images on click
    # Stages: load vocabulary -> batch images -> extract descriptors
    # -> assign clusters -> blend cluster colors -> show in `imgs2`.
    # `imgs_batch` first holds the inputs, then the blended results
    bttn1.click(get_vlad_clusters, domain, [out_msg1, vlad])\
        .then(batch_images, {nimg_s, *imgs, imgs_batch}, imgs_batch)\
        .then(get_descs, imgs_batch, [patch_descs, out_msg2])\
        .then(assign_vlad, [patch_descs, vlad], 
                [desc_assignments, out_msg3])\
        .then(get_ca_images, 
                [desc_assignments, patch_descs, blend_alpha],
                [imgs_batch, out_msg4])\
        .then(unbatch_images, [imgs_batch, nimg_s], imgs2)
    # If the blending changes now, update the cluster images only
    blend_alpha.submit(get_ca_images, 
            [desc_assignments, patch_descs, blend_alpha],
            [imgs_batch, out_msg4])\
        .then(unbatch_images, [imgs_batch, nimg_s], imgs2)


# Tab for GeM t-SNE projection plot
def tab_gem_tsne():
    """
        Create the widgets and the event pipeline of the GeM t-SNE tab.
        
        The function takes no arguments and returns nothing. It reads
        the module-level names `var_num_img`, `imgs_batch`, `gem_descs`
        and `tsne_pts`, which are created inside the `gr.Blocks`
        context, so it has to be called from within that context after
        they exist.
    """
    d_vals = [k.title() for k in DOMAINS]
    dms = gr.CheckboxGroup(d_vals, value=d_vals, label="Domains",
            info="The domains to use for the t-SNE projection")
    nimg_s = gr.Number(2, label="How many images?", precision=0,
            info=f"Between '1' and '{max_num_imgs}' images. Press "\
                    "enter/return to register")
    # All max_num_imgs image widgets are created up front; only their
    # visibility changes later (see var_num_img)
    with gr.Row():  # Dynamic row (images in columns)
        imgs = [gr.Image(label=f"Image {i+1}", visible=True) \
                for i in range(int(nimg_s.value))] + \
                [gr.Image(visible=False) \
                for _ in range(max_num_imgs - int(nimg_s.value))]
        for i, img in enumerate(imgs):  # Set image as "input"
            img.change(lambda _: None, img)
    # Show/hide the input image widgets when the count is submitted
    nimg_s.submit(var_num_img, nimg_s, imgs)
    tsne_plot = gr.Plot(None, label="tSNE Plot")
    # Status lines (updated by the pipeline stages below)
    out_msg1 = gr.Markdown("Select domains")
    out_msg2 = gr.Markdown("Upload images")
    out_msg3 = gr.Markdown("Wait for tSNE plots")
    
    # A wrapper to batch the images
    # `data` is indexed by component; presumably Gradio passes a dict
    # because the inputs are given as a set (see the click chain below).
    # Returns (images, status); images is None if any requested image
    # is missing
    def batch_images(data):
        sv = int(data[nimg_s])
        # images: List[np.ndarray] = [data[imgs[k]] \
        #         for k in range(sv)]
        images: List[np.ndarray] = []
        for k in range(sv):
            img = data[imgs[k]]
            if img is None:
                return None, f"Image {k+1} is None!"
            images.append(img)
        return images, "Images batched"
    
    bttn1 = gr.Button("Click Me!")
    
    # ---- Main pipeline ----
    # Get the tSNE plot
    # Stages: load cached descriptors -> batch images -> add the image
    # descriptors -> tSNE projection -> plot
    bttn1.click(get_gem_descs_cache, dms, [out_msg1, gem_descs])\
        .then(batch_images, {nimg_s, *imgs, imgs_batch}, 
                [imgs_batch, out_msg2])\
        .then(get_add_gem_descs, [imgs_batch, gem_descs],
                [gem_descs, out_msg2])\
        .then(get_tsne_fm_gem, gem_descs, [tsne_pts, out_msg3])\
        .then(plot_tsne, tsne_pts, [tsne_plot, out_msg3])


# Build the interface
# The helper and the state objects below are module-level names: the tab
# builder functions (defined earlier) look them up when they are called
# at the end of this block, so they must be created before those calls
with gr.Blocks() as demo:
    # Main header
    gr.Markdown(header_markdown)
    
    # ---- Helper functions ----
    # Variable number of input images (show/hide UI image array)
    # Returns max_num_imgs updates: the first `n` visible, the rest
    # hidden. Raises AssertionError if `n` is out of range
    def var_num_img(s):
        n = int(s)  # Value of the "How many images?" box as int
        assert 1 <= n <= max_num_imgs, f"Invalid num of images: {n}!"
        return [gr.Image.update(label=f"Image {i+1}", visible=True) \
                for i in range(n)] \
            + [gr.Image.update(visible=False) \
                for _ in range(max_num_imgs - n)]
    
    # ---- State declarations ----
    # Values passed between the stages of the event pipelines. The two
    # tabs use the same `imgs_batch` object
    vlad = gr.State()   # VLAD object
    desc_assignments = gr.State()   # Cluster assignments
    imgs_batch = gr.State() # Images as batch
    patch_descs = gr.State()    # Patch descriptors
    gem_descs = gr.State()  # GeM descriptors (of each state)
    tsne_pts = gr.State()   # tSNE points
    
    # ---- All UI elements ----
    # Each tab: a short description, then the widgets created by the
    # corresponding builder function
    with gr.Tab("GeM t-SNE Projection"):
        gr.Markdown(
            """
            ## GeM t-SNE Projection
            
            Select the domains (toggle visibility) for t-SNE plot. \
            Enter the number of images to upload and upload images. \
            Then click the button to get the t-SNE plot.
            
            """)
        tab_gem_tsne()
    
    with gr.Tab("Cluster Visualization"):
        gr.Markdown(
            """
            ## Cluster Visualizations
            
            Select the domain for the images (all should be from the \
            same domain). Enter the number of images to upload. \
            Upload the images. Then click the button to get the \
            cluster assignment images.
            
            You can also directly click on one of the examples (at \
            the bottom) to load the data and then click the button \
            to get the cluster assignment images.
            
            """)
        tab_cluster_viz()

print("Interface build completed")


# %%
# Deploy application
# Runs at import time (there is no `if __name__ == "__main__"` guard).
# Queuing is enabled before launch; `share` is the flag from the
# configuration section
demo.queue().launch(share=share)
# Reached only after launch() returns
print("Application deployment ended, exiting...")