File size: 28,356 Bytes
dd3c1c5
4f09ecf
 
 
 
 
 
 
dd3c1c5
 
 
4f09ecf
 
0a3d091
 
 
 
dd3c1c5
e595323
0a3d091
dd3c1c5
 
faa901b
3cbeaeb
ab93c8a
182a98a
dd3c1c5
4f09ecf
 
 
 
 
 
 
a3c38a0
4f09ecf
 
e595323
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212e8c3
 
 
 
8b66bab
 
 
212e8c3
 
 
 
 
 
 
 
 
 
 
 
 
a336540
9780b6e
a336540
 
 
 
 
 
 
 
 
 
 
 
 
 
212e8c3
4f09ecf
e2ca106
4f09ecf
 
 
ece2dc2
4f09ecf
 
 
 
 
 
 
 
 
 
40d2b47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f09ecf
faa901b
 
 
 
dd3c1c5
4f09ecf
dd3c1c5
d8cc1c3
 
 
 
 
 
3cbeaeb
 
 
eb3bad4
 
 
e93ae13
eb3bad4
3cbeaeb
d8cc1c3
e93ae13
 
 
7af86d2
 
 
 
 
a0d401d
 
7af86d2
 
 
 
e93ae13
7af86d2
d8cc1c3
3cbeaeb
0a3d091
e93ae13
 
 
 
 
 
 
 
 
 
 
 
3cbeaeb
0a3d091
3cbeaeb
0a3d091
 
 
3cbeaeb
 
 
 
 
 
 
40d2b47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3cbeaeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a336540
 
 
 
 
 
 
 
 
 
 
ab93c8a
a336540
 
ab93c8a
a9de956
 
 
 
 
 
 
8b66bab
a9de956
 
 
 
 
 
 
 
 
ab93c8a
 
 
 
40d2b47
ab93c8a
 
 
 
40d2b47
ab93c8a
 
e93ae13
 
ab93c8a
 
212e8c3
 
3cbeaeb
 
 
 
cfbbf4b
 
5364222
 
 
dd3c1c5
e595323
3cbeaeb
 
 
 
 
 
 
208b0a9
 
 
 
 
 
 
 
 
3cbeaeb
 
 
 
 
e595323
 
3cbeaeb
208b0a9
 
 
 
 
 
 
3cbeaeb
 
 
 
5364222
 
8b66bab
a9de956
 
5364222
 
8b66bab
 
56bd821
8b66bab
 
 
 
 
 
 
 
 
cfbbf4b
c246744
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7af86d2
 
 
 
 
 
c246744
 
 
 
e93ae13
7af86d2
c246744
cfbbf4b
 
3cbeaeb
e595323
 
3cbeaeb
 
 
 
 
0a3d091
212e8c3
 
e93ae13
 
212e8c3
68a0d40
 
0c5b121
 
 
 
 
 
 
 
 
 
 
 
7af86d2
 
 
68a0d40
 
 
 
7af86d2
 
 
 
 
 
 
 
 
 
 
 
 
 
e93ae13
7af86d2
dd3c1c5
 
 
4f09ecf
dd3c1c5
4f09ecf
dd3c1c5
944833f
a194370
8987ccf
 
cfbbf4b
dd3c1c5
4f09ecf
456f9ce
 
944833f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
456f9ce
 
 
 
dd3c1c5
 
 
2807645
dd3c1c5
 
 
0a3d091
dd3c1c5
 
 
 
f996296
dd3c1c5
 
 
 
0a3d091
dd3c1c5
 
3cbeaeb
4f09ecf
 
4ba9d6e
 
e5efc3c
9bf5bc2
7af86d2
 
3cbeaeb
9bf5bc2
7af86d2
 
 
 
e93ae13
7af86d2
 
 
 
 
 
 
 
 
 
4f09ecf
 
dd3c1c5
4f09ecf
97a2227
d58fe81
dd3c1c5
 
 
 
4c07482
 
e648d5c
c14f0cd
7e159c0
4c07482
 
7e159c0
e648d5c
97a2227
36c3482
 
 
97a2227
2b62f64
 
 
 
 
 
97a2227
4f09ecf
7af86d2
3cbeaeb
97a2227
 
 
4f09ecf
3cbeaeb
 
e93ae13
 
dd3c1c5
 
456f9ce
 
 
f952795
456f9ce
 
dd3c1c5
4f09ecf
dd3c1c5
4f09ecf
212e8c3
eaff0d6
 
212e8c3
 
 
 
a194370
 
212e8c3
 
 
 
 
 
e93ae13
212e8c3
 
e93ae13
 
 
 
 
 
212e8c3
 
 
e93ae13
212e8c3
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
"""
Simplified Gradio demo for Search-TTA evaluation.
This version mirrors the layout of `app_BACKUP.py` but:
1. Loads no OpenCLIP / CLAP / Satellite encoders at import time
   (`ClipSegTTA` is built lazily inside each planner thread; only the RL
   `PolicyNet` checkpoint is loaded at import).
2. Keeps only the satellite image, ground-level image, and optional
   taxonomy inputs.
3. Exposes the high-level wrapper classes `ClipSegTTA` and
   `TestWorker`, and calls `TestWorker.run_episode` inside the
   `process_search_tta` callback.
"""

# ────────────────────────── imports ───────────────────────────────────
from pathlib import Path

# Use non-GUI backend to avoid Tkinter errors in background threads
import matplotlib
matplotlib.use("Agg", force=True)

import gradio as gr
import ctypes  # for safely stopping background threads
import os, glob, threading, time
import torch
from PIL import Image
import json
import copy
import shutil
import spaces   # integration with ZeroGPU on hf

# Import configuration & RL / TTA utilities -------------------------------------------------
# NOTE: we import * so that the global names (e.g. USE_GPU, MODEL_NAME, etc.)
#       are available exactly as referenced later in the unchanged snippet.
from test_parameter import *          # noqa: F403, F401  (wild-import is intentional here)

from model import PolicyNet           # noqa: E402 – after wild import on purpose
from test_multi_robot_worker import TestWorker  # noqa: E402
from Taxabind.TaxaBind.SatBind.clip_seg_tta import ClipSegTTA   # noqa: E402


# Helper to kill a Python thread by injecting SystemExit

def _stop_thread(thread: threading.Thread):
    """Best-effort: ask CPython to raise ``SystemExit`` inside *thread*.

    Uses ``PyThreadState_SetAsyncExc``. The exception is delivered
    asynchronously, so a thread blocked inside a long native call may not
    stop immediately; callers typically ``join(timeout=...)`` afterwards.

    Does nothing when *thread* is ``None``, already finished, or has no ident.
    """
    if thread is None or not thread.is_alive():
        return
    tid = thread.ident
    if tid is None:
        return
    # Since Python 3.7 the C API takes the thread id as an ``unsigned long``;
    # pass it with the matching ctypes type.
    res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
        ctypes.c_ulong(tid), ctypes.py_object(SystemExit)
    )
    if res > 1:
        # More than one thread state was modified: revert (a NULL exc clears
        # the pending exception) to fail safe.
        ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_ulong(tid), None)

# ──────────── Thread Registry for Cleanup on Tab Switch ─────────────
# Module-wide worker-thread registry: appended by _register_thread, stopped and
# emptied by _kill_running_threads.
# NOTE(review): no caller of either helper is visible in this file
# (process_search_tta tracks threads per session via `session_threads`), so this
# registry may be unused — confirm before relying on or removing it.
_running_threads: list[threading.Thread] = []
# Guards mutation of _running_threads.
_running_threads_lock = threading.Lock()

# Map worker threads to their ClipSegTTA instance so UI can read executing_tta flag
# (written by each planner thread in process_search_tta and read by its polling loop).
_thread_clip_map: dict[threading.Thread, ClipSegTTA] = {}

def _register_thread(th: threading.Thread):
    """Add *th* to the module-wide registry so _kill_running_threads can stop it later."""
    with _running_threads_lock:
        _running_threads.extend((th,))

def _kill_running_threads():
    """Ask every registered worker thread to exit, then empty the registry.

    Stopping is best-effort (via _stop_thread); threads are not joined here, and
    the registry is emptied whether or not each thread was still alive.
    """
    with _running_threads_lock:
        pending = tuple(_running_threads)
        for worker in pending:
            _stop_thread(worker)
        del _running_threads[:]

# ──────────── Run directory rotation ─────────────
RUN_HISTORY_LIMIT = 30  # keep at most this many timestamped run directories per instance

def _prune_old_run_dirs(base_dir: str, limit: int = RUN_HISTORY_LIMIT):
    """Delete the oldest run directories in *base_dir*, keeping the newest *limit*.

    Run directories are named with a ``YYYYmmdd_HHMMSS`` timestamp, so sorting
    names lexicographically orders them chronologically. Plain files in
    *base_dir* are never touched. A *limit* of 0 (or less) removes every
    run directory.

    This is best-effort housekeeping: filesystem errors (e.g. *base_dir* does
    not exist) are ignored rather than raised, and individual removals use
    ``ignore_errors=True``.
    """
    try:
        run_dirs = sorted(
            entry
            for entry in os.listdir(base_dir)
            if os.path.isdir(os.path.join(base_dir, entry))
        )
    except OSError:
        return
    # Slice by an explicit count rather than ``[:-limit]``: the latter selects
    # nothing when limit == 0, silently disabling the cleanup.
    excess = len(run_dirs) - max(limit, 0)
    for obsolete in run_dirs[:max(excess, 0)]:
        shutil.rmtree(os.path.join(base_dir, obsolete), ignore_errors=True)


# CHANGE ME! Seconds between polls of the frame directories in process_search_tta.
POLL_INTERVAL = 1.0 # For visualization

# Prepare the model: use CUDA only when enabled in test_parameter (USE_GPU) and available.
device = torch.device('cuda') if USE_GPU and torch.cuda.is_available() else torch.device('cpu')
policy_net = PolicyNet(INPUT_DIM, EMBEDDING_DIM).to(device)
script_dir = Path(__file__).resolve().parent
print("real_script_dir: ", script_dir)
# NOTE: if `model_path` is relative, it resolves against the current working
# directory, not `script_dir` (an earlier variant loaded
# f'{script_dir}/modules/vlm_search/{model_path}/{MODEL_NAME}').
# map_location lets a checkpoint saved from CUDA tensors load on a CPU-only host,
# matching the CPU fallback chosen for `device` above.
checkpoint = torch.load(f'{model_path}/{MODEL_NAME}', map_location=device)
policy_net.load_state_dict(checkpoint['policy_model'])
print('Model loaded!')

# # (ClipSegTTA will now be instantiated lazily inside each planner thread)
# clip_seg_tta_1 = clip_seg_tta_2 = None  # placeholder; real instances created per thread
# if False and TAXABIND_TTA:
#     # Instantiate TWO independent ClipSegTTA objects (one per concurrent run)
#     clip_seg_tta_1 = ClipSegTTA(
#         img_dir=TAXABIND_IMG_DIR,
#         imo_dir=TAXABIND_IMO_DIR,
#         json_path=TAXABIND_INAT_JSON_PATH,
#         sat_to_img_ids_json_path=TAXABIND_SAT_TO_IMG_IDS_JSON_PATH,
#         patch_size=TAXABIND_PATCH_SIZE,
#         sat_checkpoint_path=TAXABIND_SAT_CHECKPOINT_PATH,
#         sample_index = -1,   # Set using 'reset' in worker
#         blur_kernel = TAXABIND_GAUSSIAN_BLUR_KERNEL,
#         device=device, # device,
#         sat_to_img_ids_json_is_train_dict=False, # for search ds val
#         tax_to_filter_val=QUERY_TAX,
#         load_model=USE_CLIP_PREDS,
#         initial_modality=INITIAL_MODALITY,
#         sound_data_path = TAXABIND_SOUND_DATA_PATH,
#         sound_checkpoint_path=TAXABIND_SOUND_CHECKPOINT_PATH,
#         # sat_filtered_json_path=TAXABIND_FILTERED_INAT_JSON_PATH,
#     )
#     clip_seg_tta_2 = ClipSegTTA(
#         img_dir=TAXABIND_IMG_DIR,
#         imo_dir=TAXABIND_IMO_DIR,
#         json_path=TAXABIND_INAT_JSON_PATH,
#         sat_to_img_ids_json_path=TAXABIND_SAT_TO_IMG_IDS_JSON_PATH,
#         patch_size=TAXABIND_PATCH_SIZE,
#         sat_checkpoint_path=TAXABIND_SAT_CHECKPOINT_PATH,
#         sample_index = -1,   # Set using 'reset' in worker
#         blur_kernel = TAXABIND_GAUSSIAN_BLUR_KERNEL,
#         device=device,
#         sat_to_img_ids_json_is_train_dict=False,
#         tax_to_filter_val=QUERY_TAX,
#         load_model=USE_CLIP_PREDS,
#         initial_modality=INITIAL_MODALITY,
#         sound_data_path=TAXABIND_SOUND_DATA_PATH,
#         sound_checkpoint_path=TAXABIND_SOUND_CHECKPOINT_PATH,
#     )
    
    

# Load metadata json: maps a taxonomy name to its entry (process_search_tta reads
# its "target_positions" list). `with` closes the file instead of leaking the handle.
tgts_metadata_json_path = os.path.join(script_dir, "examples/metadata.json")
with open(tgts_metadata_json_path, encoding="utf-8") as _tgts_metadata_file:
    tgts_metadata = json.load(_tgts_metadata_file)


# ────────────────────────── Gradio process fn ─────────────────────────

# Gradio only streams outputs from callbacks it recognises as generator functions.
# NOTE: A lambda that *returns* a generator is NOT itself a generator *function*,
# so Gradio would fail to detect streaming and treat the return value as a plain
# object. `process_search_tta` below is therefore an explicit generator function
# (it uses `yield`), so `inspect.isgeneratorfunction` is True and Gradio streams it.

# # # integration with ZeroGPU on hf
# @spaces.GPU
def process_search_tta(
    sat_path: str | None,
    ground_path: str | None,
    taxonomy: str | None = None,
    session_threads: list[threading.Thread] | None = None,
):
    """Run both TTA and non-TTA search episodes concurrently and stream both heat-maps.

    This is a Gradio generator callback. Every ``yield`` is a 10-tuple in the
    order of the outputs wired to ``run_btn.click``: (run button, TTA image,
    no-TTA image, TTA status, no-TTA status, TTA slider, no-TTA slider,
    TTA frame list, no-TTA frame list, session thread list).

    Args:
        sat_path: Filepath of the satellite image. If None, the run is aborted.
        ground_path: Optional filepath of a ground-level image.
        taxonomy: Optional full taxonomy name. Used to look up target positions
            in ``tgts_metadata`` and as the species / GT-mask name.
        session_threads: Per-session list of worker threads (a ``gr.State``),
            so a tab switch can stop this session's runs. None creates a new list.

    Two daemon threads each build a ``TestWorker`` and run one episode, writing
    integer-named ``.png`` frames into ``<gifs_path>/<timestamp>/{with_tta,no_tta}``.
    This generator polls those directories every ``POLL_INTERVAL`` seconds and
    streams the newest frame per side. If the generator is closed early, the
    ``finally`` clause stops both worker threads.
    """

    if session_threads is None:
        session_threads = []

    # Disable Run button and clear image/status outputs, hide sliders, clear frame states
    yield (
        gr.update(interactive=False),
        gr.update(value=None),
        gr.update(value=None),
        gr.update(value="Initializing model…", visible=True),
        gr.update(value="Initializing model…", visible=True),
        gr.update(visible=False),
        gr.update(visible=False),
        [],
        [],
        session_threads,
    )

    # Bail early if satellite image missing
    if sat_path is None:
        yield (
            gr.update(interactive=True),
            gr.update(value=None),
            gr.update(value=None),
            gr.update(value="No satellite image provided.", visible=True),
            gr.update(value="", visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
            [],
            [],
            session_threads,
        )
        return

    # Prepare PIL images; `with` closes the source files once converted.
    with Image.open(sat_path) as sat_src:
        sat_img = sat_src.convert("RGB")
    ground_img_pil = None
    if ground_path:
        with Image.open(ground_path) as ground_src:
            ground_img_pil = ground_src.convert("RGB")

    # Lookup target positions metadata (may be empty)
    tgt_positions = []
    if taxonomy and taxonomy in tgts_metadata:
        tgt_positions = [tuple(t) for t in tgts_metadata[taxonomy]["target_positions"]]

    # Helper to build a TestWorker with/without TTA
    def build_planner(enable_tta: bool, save_dir: str, clip_obj):
        # Lazily (re)create a ClipSegTTA instance per thread if not provided
        local_clip = clip_obj
        if TAXABIND_TTA and local_clip is None:
            local_clip = ClipSegTTA(
                img_dir=TAXABIND_IMG_DIR,
                imo_dir=TAXABIND_IMO_DIR,
                json_path=TAXABIND_INAT_JSON_PATH,
                sat_to_img_ids_json_path=TAXABIND_SAT_TO_IMG_IDS_JSON_PATH,
                patch_size=TAXABIND_PATCH_SIZE,
                sat_checkpoint_path=TAXABIND_SAT_CHECKPOINT_PATH,
                sample_index=-1,
                blur_kernel=TAXABIND_GAUSSIAN_BLUR_KERNEL,
                device=device,
                sat_to_img_ids_json_is_train_dict=False,
                tax_to_filter_val=QUERY_TAX,
                load_model=USE_CLIP_PREDS,
                initial_modality=INITIAL_MODALITY,
                sound_data_path=TAXABIND_SOUND_DATA_PATH,
                sound_checkpoint_path=TAXABIND_SOUND_CHECKPOINT_PATH,
            )
        if local_clip is not None:
            # Feed inputs to ClipSegTTA copy
            local_clip.img_paths = [ground_path] if ground_path else []
            local_clip.imo_path = sat_path
            local_clip.imgs = ([local_clip.dataset.img_transform(ground_img_pil).to(device)] if ground_img_pil else [])
            local_clip.imo = local_clip.dataset.imo_transform(sat_img).to(device)
            local_clip.sounds = []
            local_clip.sound_ids = []
            local_clip.species_name = taxonomy or ""
            local_clip.gt_mask_name = taxonomy.replace(" ", "_") if taxonomy else ""
            local_clip.target_positions = tgt_positions if tgt_positions else [(0, 0)]

        planner = TestWorker(
            meta_agent_id=0,
            n_agent=1,
            policy_net=policy_net,
            global_step=-1,
            device=device,
            greedy=True,
            save_image=SAVE_GIFS,
            clip_seg_tta=local_clip,
        )
        planner.execute_tta = enable_tta
        planner.gifs_path = save_dir
        return planner

    # ────────────── Per-run output directories ──────────────
    # Ensure base directory exists
    os.makedirs(gifs_path, exist_ok=True)

    run_id = time.strftime("%Y%m%d_%H%M%S")  # timestamp (second resolution)
    run_root = os.path.join(gifs_path, run_id)
    gifs_dir_tta = os.path.join(run_root, "with_tta")
    gifs_dir_no = os.path.join(run_root, "no_tta")

    os.makedirs(gifs_dir_tta, exist_ok=True)
    os.makedirs(gifs_dir_no, exist_ok=True)

    # House-keep old runs so we never keep more than RUN_HISTORY_LIMIT
    _prune_old_run_dirs(gifs_path, RUN_HISTORY_LIMIT)

    # Shared dict to record if a thread hit an exception
    error_flags = {"tta": False, "no": False}

    def _planner_thread(enable_tta: bool, save_dir: str, clip_obj, key: str):
        """Build a planner, run one episode, and record whether it crashed."""
        me = threading.current_thread()
        try:
            planner = build_planner(enable_tta, save_dir, clip_obj)
            _thread_clip_map[me] = planner.clip_seg_tta
            planner.run_episode(0)
        except Exception:
            # Mark that this planner crashed so UI can show an error status
            error_flags[key] = True
            # Log full traceback so developers can debug via console logs
            import traceback
            traceback.print_exc()
        finally:
            # Drop the registry entry so this run's ClipSegTTA (and the models
            # it holds) can be freed; previously entries accumulated forever.
            _thread_clip_map.pop(me, None)

    # Launch both planners in background threads – preparation included
    thread_tta = threading.Thread(
        target=_planner_thread,
        args=(True, gifs_dir_tta, None, "tta"),
        daemon=True,
    )
    thread_no = threading.Thread(
        target=_planner_thread,
        args=(False, gifs_dir_no, None, "no"),
        daemon=True,
    )
    # Track threads for this user session
    session_threads.extend([thread_tta, thread_no])
    thread_tta.start()
    thread_no.start()

    sent_tta: set[str] = set()
    sent_no: set[str] = set()
    last_tta = None
    last_no = None
    # Previous status strings, so an update is emitted when only the status
    # changes (e.g. to "Done.") even if no new frame was produced.
    prev_status_tta = "Initializing model…"
    prev_status_no = "Initializing model…"
    # A finished side's slider is revealed (snapped to its last frame) once, and
    # again only if late frames arrive. Re-sending it every poll would yank a user
    # who is browsing a finished run back to the end while the other run continues.
    slider_shown_tta = False
    slider_shown_no = False

    try:
        while thread_tta.is_alive() or thread_no.is_alive():
            new_tta = _collect_new_frames(gifs_dir_tta, sent_tta)
            new_no = _collect_new_frames(gifs_dir_no, sent_no)
            if new_tta is not None:
                last_tta = new_tta
            if new_no is not None:
                last_no = new_no

            tta_alive = thread_tta.is_alive()
            no_alive = thread_no.is_alive()

            # While adapting, the TTA worker's ClipSegTTA sets `executing_tta`.
            exec_tta_flag = False
            if tta_alive:
                clip_obj = _thread_clip_map.get(thread_tta)
                exec_tta_flag = bool(clip_obj is not None and getattr(clip_obj, "executing_tta", False))

            status_tta = _status_text(last_tta, tta_alive, error_flags["tta"], exec_tta_flag)
            status_no = _status_text(last_no, no_alive, error_flags["no"])

            reveal_tta = (not tta_alive) and last_tta is not None and (not slider_shown_tta or new_tta is not None)
            reveal_no = (not no_alive) and last_no is not None and (not slider_shown_no or new_no is not None)

            slider_tta_upd, frames_tta_upd = gr.update(), gr.update()
            slider_no_upd, frames_no_upd = gr.update(), gr.update()
            if reveal_tta:
                frames_tta_upd = sorted(sent_tta, key=_frame_index)
                n_tta_frames = max(len(frames_tta_upd), 1)
                slider_tta_upd = gr.update(visible=True, minimum=1, maximum=n_tta_frames, value=n_tta_frames)
            if reveal_no:
                frames_no_upd = sorted(sent_no, key=_frame_index)
                n_no_frames = max(len(frames_no_upd), 1)
                slider_no_upd = gr.update(visible=True, minimum=1, maximum=n_no_frames, value=n_no_frames)

            if (
                new_tta is not None
                or new_no is not None
                or status_tta != prev_status_tta
                or status_no != prev_status_no
                or reveal_tta
                or reveal_no
            ):
                yield (
                    gr.update(interactive=False),
                    # Only push an image for a side with a new frame, so a finished
                    # side keeps whichever frame the user picked with its slider.
                    new_tta if new_tta is not None else gr.update(),
                    new_no if new_no is not None else gr.update(),
                    gr.update(value=status_tta, visible=True),
                    gr.update(value=status_no, visible=True),
                    slider_tta_upd,
                    slider_no_upd,
                    frames_tta_upd,
                    frames_no_upd,
                    session_threads,
                )
                prev_status_tta = status_tta
                prev_status_no = status_no
                slider_shown_tta = slider_shown_tta or reveal_tta
                slider_shown_no = slider_shown_no or reveal_no

            time.sleep(POLL_INTERVAL)
    finally:
        # Ensure background threads are stopped on cancel
        for th in (thread_tta, thread_no):
            if th.is_alive():
                _stop_thread(th)
                th.join(timeout=1)

    # Both threads have finished; forget them so a later tab switch has nothing to stop.
    session_threads.clear()

    # Small delay to ensure last frame files are fully flushed
    time.sleep(0.2)
    # Final authoritative listing; catches frames written after the last poll.
    frames_tta = _sorted_frames(gifs_dir_tta)
    frames_no = _sorted_frames(gifs_dir_no)
    # Show the newest frame so each image agrees with its slider set to the maximum.
    last_tta = frames_tta[-1] if frames_tta else None
    last_no = frames_no[-1] if frames_no else None
    n_tta = len(frames_tta) or 1  # prevent zero-range slider
    n_no = len(frames_no) or 1

    # Final emit: re-enable button, show sliders at the last frame, and hide the
    # statuses — except "Error!", which stays visible so a crash isn't masked.
    yield (
        gr.update(interactive=True),
        last_tta,
        last_no,
        gr.update(value="Error!", visible=True) if error_flags["tta"] else gr.update(visible=False),
        gr.update(value="Error!", visible=True) if error_flags["no"] else gr.update(visible=False),
        gr.update(visible=True, minimum=1, maximum=n_tta, value=n_tta),
        gr.update(visible=True, minimum=1, maximum=n_no, value=n_no),
        frames_tta,
        frames_no,
        session_threads,
    )


def _frame_index(path: str) -> int | None:
    """Return the integer step encoded in an ``<int>.png`` frame path, or None if not numeric."""
    try:
        return int(os.path.splitext(os.path.basename(path))[0])
    except ValueError:
        return None


def _sorted_frames(frames_dir: str) -> list[str]:
    """Return the integer-named ``.png`` frames in *frames_dir*, ordered by step.

    PNGs whose stem is not an integer are ignored instead of crashing the sort.
    """
    frames = [p for p in glob.glob(os.path.join(frames_dir, "*.png")) if _frame_index(p) is not None]
    frames.sort(key=_frame_index)
    return frames


def _is_complete_frame(path: str) -> bool:
    """Return True when *path* is non-empty and readable, i.e. not still being created."""
    try:
        if os.path.getsize(path) == 0:
            return False
        with open(path, "rb") as fh:
            fh.read(1)
    except OSError:
        return False
    return True


def _collect_new_frames(frames_dir: str, sent: set[str]) -> str | None:
    """Add not-yet-sent, complete frames from *frames_dir* to *sent*.

    Frames that are still empty/unreadable are skipped and retried on the next call.
    Returns the highest-step frame added by this call, or None if nothing new.
    """
    newest = None
    for path in _sorted_frames(frames_dir):
        if path not in sent and _is_complete_frame(path):
            sent.add(path)
            newest = path
    return newest


def _status_text(last_frame: str | None, thread_alive: bool, errored: bool, running_tta: bool = False) -> str:
    """Map a planner's progress to the status text shown under its heat-map."""
    if errored:
        return "Error!"
    if last_frame is None:
        return "Initializing model…"
    if not thread_alive:
        return "Done."
    return "Executing TTA (Scheduling GPUs)…" if running_tta else "Executing Planner…"


# ────────────────────────── Gradio UI ─────────────────────────────────
# Search-TTA tab UI, built at import time as `demo` (the __main__ section renders it
# into a tabbed app). Component creation order inside the `with` blocks defines layout.
with gr.Blocks(title="Search-TTA (Simplified)", theme=gr.themes.Base()) as demo:

    gr.Markdown(
        """
        # Search-TTA: A Multimodal Test-Time Adaptation Framework for Visual Search in the Wild Demo
        Click on any of the <b>examples below</b> and run the <b>TTA demo</b>. Check out the <b>multimodal heatmap generation feature</b> by switching to the other tab above. <br>
        Note that the model initialization, RL planner, and TTA updates are not fully optimized on GPU for this huggingface demo, and hence may experience some lag during execution. <br>
        If you encounter an 'Error' status, refresh the browser and rerun the demo, or try again the next day. We will improve this in the future.  <br>
        <a href="https://search-tta.github.io">Project Website</a> 
        """
    )
    # gr.Markdown(
    #     """
    #     <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
    #       <div>
    #         <h1>Search-TTA: A Multimodal Test-Time Adaptation Framework for Visual Search in the Wild</h1>
    #         <span></span>
    #         <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
    #         <a href="https://search-tta.github.io">Project Website</a>
    #         </h2>
    #         <span></span>
    #         <h2 style='font-weight: 450; font-size: 0.5rem; margin: 0rem'>[Work in Progress]</h2>
    #       </div>
    #     </div>
    #     """
    # )
    # gr.Markdown(
    #     """
    #     # Search-TTA – Simplified Demo
    #     **Satellite ↔ Ground-level Visual Search** via RL Test-Time Adaptation.
    #     """
    # )

    # Left column: inputs (image components pass file paths, per type="filepath").
    with gr.Row(variant="panel"):
        with gr.Column():
            gr.Markdown("### Model Inputs")
            sat_input = gr.Image(
                label="Satellite Image",
                sources=["upload"],
                type="filepath",
                height=320,
            )
            taxonomy_input = gr.Textbox(
                label="Full Taxonomy Name (optional)",
                placeholder="e.g. Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
            )
            ground_input = gr.Image(
                label="Ground-level Image (optional)",
                sources=["upload"],
                type="filepath",
                height=320,
            )
            run_btn = gr.Button("Run Search-TTA", variant="primary")

        # Right column: streamed heat-maps, status text, and frame sliders (hidden
        # until process_search_tta reveals them) for the TTA and no-TTA runs.
        with gr.Column():
            gr.Markdown("### Live Heatmap Output")
            # gr.Markdown("### Live Heatmap (with TTA)")
            display_img_tta = gr.Image(label="Heatmap (TTA per 20 steps)", type="filepath", height=400)  # 512
            status_tta = gr.Markdown("")
            slider_tta = gr.Slider(label="TTA Frame", minimum=1, maximum=1, step=1, value=1, visible=False)

            display_img_no_tta = gr.Image(label="Heatmap (no TTA)", type="filepath", height=400)  # 512
            status_no_tta = gr.Markdown("")
            slider_no = gr.Slider(label="No-TTA Frame", minimum=1, maximum=1, step=1, value=1, visible=False)

            # Per-session state: ordered frame paths per run (for the sliders) and
            # this session's worker threads (so a tab switch can stop them).
            frames_state_tta = gr.State([])
            frames_state_no = gr.State([])
            session_threads_state = gr.State([])

    # Slider callbacks (updates image when user drags slider)
    def _show_frame(idx: int, frames: list[str]):
        """Return the frame path at 1-based slider position *idx*, or a no-op update if out of range."""
        # Slider is 1-indexed; convert to 0-indexed list access
        if 1 <= idx <= len(frames):
            return frames[idx - 1]
        return gr.update()

    slider_tta.change(_show_frame, inputs=[slider_tta, frames_state_tta], outputs=display_img_tta)
    slider_no.change(_show_frame, inputs=[slider_no, frames_state_no], outputs=display_img_no_tta)

    # Run-button callback is bound below, after the examples.

    # EXAMPLES – copied from original demo (satellite, ground, taxonomy only)
    # NOTE(review): with cache_examples=False, clicking an example presumably only
    # fills the inputs (Gradio's run_on_click defaults to False — confirm), so `fn`
    # and `outputs` look unused here; `outputs` also omits session_threads_state,
    # unlike run_btn.click below. The 4th value of the last example (an .mp3 path)
    # has no matching input component.
    with gr.Row():
        gr.Markdown("### Taxonomy")
    with gr.Row():
        gr.Examples(
            examples=[
                [
                    "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/80645_39.76079_-74.10316.jpg",
                    "examples/Animalia_Chordata_Aves_Charadriiformes_Laridae_Larus_marinus/cc1ebaf9-899d-49f2-81c8-d452249a8087.jpg",
                    "Animalia Chordata Aves Charadriiformes Laridae Larus marinus",
                ],
                [
                    "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/388246_45.49036_7.14796.jpg",
                    "examples/Animalia_Chordata_Mammalia_Rodentia_Sciuridae_Marmota_marmota/327e1f07-692b-4140-8a3e-bd098bc064ff.jpg",
                    "Animalia Chordata Mammalia Rodentia Sciuridae Marmota marmota",
                ],
                [
                    "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/410613_5.35573_100.28948.jpg",
                    "examples/Animalia_Chordata_Reptilia_Squamata_Varanidae_Varanus_salvator/461d8e6c-0e66-4acc-8ecd-bfd9c218bc14.jpg",
                    "Animalia Chordata Reptilia Squamata Varanidae Varanus salvator",
                ],
                [
                    "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/277303_38.72364_-75.07749.jpg",
                    "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/0b9cc264-a2ba-44bd-8e41-0d01a6edd1e8.jpg",
                    "Animalia Arthropoda Malacostraca Decapoda Ocypodidae Ocypode quadrata",
                    "examples/Animalia_Arthropoda_Malacostraca_Decapoda_Ocypodidae_Ocypode_quadrata/12372063.mp3"
                ],
            ],
            inputs=[sat_input, ground_input, taxonomy_input],
            outputs=[run_btn, display_img_tta, display_img_no_tta, status_tta, status_no_tta, slider_tta, slider_no, frames_state_tta, frames_state_no],
            fn=process_search_tta,
            cache_examples=False,
        )


    # Output order must match the 10-tuples yielded by process_search_tta.
    run_btn.click(
        fn=process_search_tta,
        inputs=[sat_input, ground_input, taxonomy_input, session_threads_state],
        outputs=[run_btn, display_img_tta, display_img_no_tta, status_tta, status_no_tta, slider_tta, slider_no, frames_state_tta, frames_state_no, session_threads_state],
    )

    # Footer to point out to model and data from app page.
    gr.Markdown(
        """
        The satellite image CLIP encoder is fine-tuned using [Sentinel-2 Level 2A](https://docs.sentinel-hub.com/api/latest/data/sentinel-2-l2a/) satellite image and taxonomy images (with GPS locations) from [iNaturalist](https://inaturalist.org/). The sound CLIP encoder is fine-tuned with a subset of the same taxonomy images and their corresponding sounds from [iNaturalist](https://inaturalist.org/). Some of these iNaturalist data are also used in [Taxabind](https://arxiv.org/abs/2411.00683). Note that while some of the examples above result in poor probability distributions, they will be improved using our test-time adaptation framework during the search process.
        """
    )

# Script entry point: mount the multimodal-inference demo and this Search-TTA demo as tabs.
if __name__ == "__main__":

    # Build UI with explicit Tabs so we can detect tab selection and clean up
    from app_multimodal_inference import demo as multimodal_demo

    with gr.Blocks() as root:
        with gr.Tabs() as tabs:
            with gr.TabItem("Multimodal Inference"):
                multimodal_demo.render()
            with gr.TabItem("Search-TTA"):
                demo.render()

        # Hidden textbox purely to satisfy Gradio's need for an output component.
        _cleanup_status = gr.Textbox(visible=False)

        outputs_on_tab = [_cleanup_status]

        def _on_tab_change(evt: gr.SelectData, session_threads: list[threading.Thread]):
            """Stop this session's Search-TTA worker threads when switching to the other tab.

            Returns a short message for the hidden `_cleanup_status` textbox ("" if
            nothing was done).
            """
            # evt.value contains the name of the newly-selected tab.
            if evt.value == "Multimodal Inference":
                # Stop only threads started in this session
                for th in list(session_threads):
                    if th is not None and th.is_alive():
                        _stop_thread(th)
                        th.join(timeout=1)
                session_threads.clear()
                return "Stopped running Search-TTA threads."
            return ""

        # session_threads_state is the per-session gr.State created in the Search-TTA UI (`demo`).
        tabs.select(_on_tab_change, inputs=[session_threads_state], outputs=outputs_on_tab)

    # Enable the event queue (max_size caps the number of pending events) and launch;
    # share=True requests a public Gradio share link.
    root.queue(max_size=15)
    root.launch(share=True)