yfan07 commited on
Commit 08ff7f7 · verified · 1 Parent(s): 2751ecb

Add files using upload-large-folder tool

Files changed (3):
  1. SEG_LTPO_results.md +157 -17
  2. load_model.py +272 -25
  3. seg_ltpo.py +618 -32
SEG_LTPO_results.md CHANGED
@@ -317,32 +317,172 @@ QLTPOConfig(
 )
 ```
 ---
 
- ## Next Steps
 
- ### Immediate (full-set confirmation)
 
- Run full evaluations with e0-modulated Stage 1 to confirm quick-validation trends at scale:
 
 ```bash
- # Full Null (~30 min) — expect S ≈ 0.0120 + small increase, less than +5%
- TRANSFORMERS_OFFLINE=1 python -W ignore load_model.py --eval_split test_n
 
- # Full Seen (~35 min) — expect mIoU gain ≥ +0.013
- TRANSFORMERS_OFFLINE=1 python -W ignore load_model.py --eval_split test_s
 
- # Full Unseen (~35 min) — expect mIoU gain ≥ +0.025 (from pre-e0 baseline +0.0295)
- TRANSFORMERS_OFFLINE=1 python -W ignore load_model.py --eval_split test_u
 ```
 
- **Decision criteria to promote e0-Stage1 to final method:**
- - Null S degradation < 5% relative (full set)
- - Seen mIoU gain ≥ +0.012
- - Unseen mIoU gain ≥ +0.022
 
- ### If full-set confirms (future work)
 
- 1. **F-score improvement (Stage 3)**: Current gain is mainly in mIoU (overlap); F-score (boundary precision/recall) lags. Candidate: boundary-oriented reward using SAM's low-res logit gradient sharpness or contour consistency across anchor frames.
- 2. **Stronger e0 suppression ablation**: Test `e0_modulation="sqrt"` (g(e0) = sqrt(e0+ε)) to further compress the Null tail. Only justified if full-set Null degradation exceeds 5%.
- 3. **Stage 2 revisit**: R_align_det hurt at scale due to noisy z_in/z_out from low-quality initial masks. Possible fix: gate the align signal by `R_iou_pred > 0.85` so it is only used when the initial mask is reliable.
+ ### Full Unseen Evaluation with e0 (1656 samples)
+
+ | Method | mIoU | F | Δ mIoU |
+ |--------|------|---|--------|
+ | Baseline | 0.6990 | 0.7926 | — |
+ | q-LTPO S1 (no e0) | 0.7285 | 0.8013 | +0.0295 (+4.22%) |
+ | **q-LTPO S1 (e0)** | **0.7240** | **0.7985** | **+0.0250 (+3.56%)** |
+
+ The e0 variant is slightly lower in mIoU than the no-e0 variant (-0.0045) but is safer on Null. The F gain tracks the mIoU gain at a roughly constant ratio (about 60%).
+
+ **Full-evaluation status (updated):**
+
+ | Split | Baseline | q-LTPO S1 (e0) | Δ | Status |
+ |-------|----------|----------------|---|--------|
+ | Unseen (full, 1656) | 0.6990 / 0.7926 | 0.7240 / 0.7985 | +3.56% mIoU | ✅ Done |
+ | Seen (full) | — | — | — | Pending |
+ | Null (full, S↓) | 0.0120 | — | — | Pending |
+
+ ---
+
+ ## Direction B: Boundary Precision Experiments (concluded; outcome: failure)
+
+ ### B-Step1: Multimask Post-Processing (complete failure)
+
+ Replace single-mask decoding with SAM's multimask output (K=3), selecting the best candidate either by iou_pred or by a Sobel edge score.
+
+ | Method | mIoU | F | ΔF vs s1 |
+ |--------|------|---|----------|
+ | s1 (single mask) | 0.6979 | 0.8024 | — |
+ | s1_mm (iou_pred selection) | 0.6979 | 0.7917 | -0.0107 |
+ | s1_mm_edge (Sobel selection) | 0.5715 | 0.6820 | -0.1204 |
+
+ **Root cause:** SAM's internal single-mask selection is already optimal; external re-selection only makes things worse. In the normalized 1024×1024 space, the Sobel score picks texture fragments rather than the semantic target, failing catastrophically.
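The post-processing tried above amounts to re-selecting among SAM's K=3 candidate masks. A minimal sketch of the iou_pred selection rule, not part of this commit (the function name and shapes are illustrative; the real decode path lives in `decode_full_video`):

```python
import torch

def select_candidate(low_res_masks: torch.Tensor, iou_preds: torch.Tensor) -> torch.Tensor:
    """Pick one of K candidate masks by SAM's predicted IoU.

    low_res_masks: [K, H, W] mask logits from multimask_output=True
    iou_preds:     [K] predicted IoU per candidate
    """
    k = int(iou_preds.argmax().item())
    return low_res_masks[k]

# Toy example: candidate 1 has the highest predicted IoU, so it is returned.
masks = torch.stack([torch.full((4, 4), v) for v in (-1.0, 2.0, 0.5)])
iou_preds = torch.tensor([0.60, 0.91, 0.75])
best = select_candidate(masks, iou_preds)
```

The table above shows why this loses: iou_pred re-selection can only match or undo SAM's internal choice, which already uses the same signal.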
+
+ ### B1: Asymmetric Area-Inflation Penalty (mechanistically ineffective)
+
+ Hypothesis: LTPO inflates the mask into non-target regions (hurting precision), so add a penalty term to suppress the inflation.
+
+ **Experimental conclusion: the hypothesis is wrong.** During LTPO the soft area actually decreases (-16%) rather than increases:
+
+ ```
+ soft area: 0.1507 → 0.1267 (-16%)  ← background logits become more negative
+ hard area: 0.0635 → 0.0650 (+2.4%) ← actual mask area grows slightly
+ ```
+
+ **The "mask sharpening" phenomenon:** driven by R_iou_pred, Adam makes the logits more bimodal (foreground more positive, background more negative); the soft area drops because the 93% of pixels that are background contribute less. The precondition for the B1 penalty (a rising soft area) never occurs:
+
+ ```
+ B1 activation rate : 0.025   ← triggered on only 2.5% of samples
+ B1 mean excess     : 0.00002 ← negligible
+ ```
+
+ **Conclusion:** Direction B failed across the board, from multimask selection to area constraints, and is abandoned. The root cause of F-score lagging mIoU is not mask precision but the quality of the reward proxy signal (see Path A).
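The soft/hard area split that explains the sharpening effect can be sketched as follows (a standalone illustration under the assumption soft area = mean sigmoid of the logits and hard area = fraction of positive logits; `soft_and_hard_area` is a hypothetical helper, not the repo's API):

```python
import torch

def soft_and_hard_area(lrm: torch.Tensor, temp: float = 1.0) -> tuple:
    """Area diagnostics over low-res mask logits [T, H, W].

    soft area: mean sigmoid(logit / temp) — every pixel contributes, so
               pushing background logits more negative lowers it.
    hard area: fraction of pixels with logit > 0 — the actual mask size.
    """
    soft = torch.sigmoid(lrm / temp).mean().item()
    hard = (lrm > 0).float().mean().item()
    return soft, hard

# Sharpening: background logits go more negative, so the soft area
# shrinks while the hard (thresholded) area stays the same.
lrm = torch.full((1, 8, 8), -1.0)
lrm[0, :2, :2] = 3.0  # small positive foreground region
soft_before, hard_before = soft_and_hard_area(lrm)
lrm_sharp = lrm.clone()
lrm_sharp[lrm_sharp < 0] = -5.0  # bimodalize: background more negative
soft_after, hard_after = soft_and_hard_area(lrm_sharp)
```

This reproduces the observed pattern: soft area falls under sharpening while the thresholded mask is untouched, so an area-inflation penalty on the soft area never fires.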
+
+ ---
+
+ ## Direction II: Frame-Adaptive Token Optimization (preliminary exploration, deferred)
+
+ ### Method Design
+
+ Extend the single shared token q into a per-video token trajectory:
+
+ ```
+ q_t = q_global + delta_t
+ ```
+
+ where q_global is the globally shared token and delta_t is a local residual for each anchor frame, initialized to 0. Jointly optimize:
+
+ ```
+ max Σ_t [λ_iou · e0_t · R_iou(q_t) - λ_area · R_area(q_t)]
+     - λ_residual · ||delta||² - λ_smooth · Σ_t ||delta_t - delta_{t+1}||² - λ_reg · ||q_global - q_init||²
+ ```
+
+ Each anchor frame uses its own e0_t (a per-frame existence prior). delta_t is hard-clipped: `||delta_t|| ≤ scale × ||q_init||`.
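The penalty terms and the hard clip above can be sketched in isolation (λ values, helper names, and the `[A, 256]` residual shape are illustrative assumptions; the reward terms are omitted since they need the SAM decoder):

```python
import torch

def fa_regularizers(q_global, delta, q_init,
                    lam_residual=0.001, lam_smooth=0.01, lam_reg=0.05):
    """Penalty side of the frame-adaptive objective:
    lam_residual * ||delta||^2                  keeps residuals small,
    lam_smooth * sum_t ||delta_t - delta_t+1||^2 keeps the trajectory smooth,
    lam_reg * ||q_global - q_init||^2           tethers the shared token.
    """
    r_residual = delta.pow(2).sum()
    r_smooth = (delta[:-1] - delta[1:]).pow(2).sum()
    r_reg = (q_global - q_init).pow(2).sum()
    return lam_residual * r_residual + lam_smooth * r_smooth + lam_reg * r_reg

def clip_delta(delta, q_init, scale=0.3):
    """Hard clip: ||delta_t|| <= scale * ||q_init|| for every anchor frame."""
    max_norm = scale * q_init.norm()
    norms = delta.norm(dim=-1, keepdim=True).clamp_min(1e-12)
    return delta * (max_norm / norms).clamp(max=1.0)

q_init = torch.randn(1, 256)
delta = torch.randn(4, 256) * 10.0  # deliberately oversized residuals
clipped = clip_delta(delta, q_init, scale=0.3)
penalty_at_init = fa_regularizers(q_init, torch.zeros(4, 256), q_init)
```

Note the penalties all vanish at the initialization (delta = 0, q_global = q_init), so the optimizer starts from the shared-token solution and only pays for moving away from it.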
+
+ ### 200-sample Probe Results (Unseen split)
+
+ | Method | mIoU | F | reward gain p50 | delta ‖Δ‖ |
+ |--------|------|---|-----------------|-----------|
+ | baseline | 0.6745 | 0.7763 | — | — |
+ | s1 | 0.6945 | 0.7773 | +0.0053 | — |
+ | fa_base (unconstrained) | 0.6945 | 0.7711 | +0.0112 | 1.675 |
+ | fa_smooth (λ_smooth=0.01) | 0.6960 | 0.7731 | +0.0104 | 1.488 |
+ | fa_c03 (delta clip 0.3×) | 0.6959 | 0.7722 | +0.0112 | — |
+
+ ### Key Findings
+
+ **Reward–metric gap (the core problem):**
+ ```
+ reward gain p50 : s1 = +0.0053   fa_c03 = +0.0112  (fa 2.1× higher)
+ R_iou_pred gain : s1   +0.077    fa_c03   +0.114
+ actual mIoU gain: s1   +2.96%    fa_c03   +3.17%   (only 0.21% apart)
+ ```
+ fa collects far more reward, yet mIoU barely improves further and F even drops slightly.
+
+ **Conclusion:** the bottleneck is not the optimization structure but the weak task relevance of R_iou_pred itself. R_iou_pred measures how clean the mask is, not whether the mask covers the correct audio-referred target. Every architectural variant (single token / frame-adaptive) is capped by the same ceiling.
+
+ Direction II will not be tuned further under the old reward; it will be reconsidered only after Path A (the new reward) shows a positive signal.
+
 ---
 
+ ## Path A: AVT-Aware Reward Redesign
+
+ ### Motivation
+
+ In Ref-AVS the referent is not necessarily the sounding object itself (it may be the person holding the sounding object, or an object related to the sound source). A purely audio-aligned reward would push optimization toward the sound source rather than the referent the text points to. We need referent consistency defined jointly by audio + text + global visual context.
+
+ ### AVT Proxy Reward Design
+
+ **Core insight:** Fseg (= q_init) is already a multimodal fusion token of audio + video + text, so it can serve directly as a frozen AVT teacher.
+
+ ```python
+ R_avt   = mean_t cos(z_in_t, q_init)
+ R_avt_c = mean_t [cos(z_in_t, q_init) - β · cos(z_out_t, q_init)]
+ ```
+
+ - `z_in_t`: the soft-masked image features of anchor frame t (in SAM's 256-dim space)
+ - `q_init`: the frozen Fseg (AVT anchor; excluded from the optimization gradient)
+ - High R_avt → the mask region aligns with the queried referent; low R_avt → the mask points at the wrong target
+
+ Difference from Stage 2: Stage 2 aligns the current (moving) q with z_in from the current mask, creating self-confirmation bias; R_avt uses the fixed q_init as the teacher, which breaks that bias.
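The two formulas above can be written as a standalone sketch (the commit's actual implementation is `_compute_avt_proxy_reward` in seg_ltpo.py; this version only illustrates the math, with toy feature shapes):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def avt_proxy_reward(z_in, z_out, q_init, beta=0.5, contrastive=True):
    """AVT proxy reward over anchor frames.

    z_in:   [A, 256] inside-mask (soft-masked) features per anchor frame
    z_out:  [A, 256] outside-mask features
    q_init: [1, 256] frozen Fseg used as the AVT teacher
    R_avt   = mean_t cos(z_in_t, q_init)
    R_avt_c = mean_t [cos(z_in_t, q_init) - beta * cos(z_out_t, q_init)]
    """
    q = q_init.detach()  # teacher stays fixed: no self-confirmation loop
    r_in = F.cosine_similarity(z_in, q, dim=-1)
    if not contrastive:
        return r_in.mean()
    r_out = F.cosine_similarity(z_out, q, dim=-1)
    return (r_in - beta * r_out).mean()

# A mask whose features align with the teacher scores higher than one
# pointing in the opposite direction.
q_init = torch.randn(1, 256)
z_in_good = q_init.repeat(4, 1) + 0.01 * torch.randn(4, 256)
z_in_bad = -q_init.repeat(4, 1)
z_out = torch.randn(4, 256)
r_good = avt_proxy_reward(z_in_good, z_out, q_init)
r_bad = avt_proxy_reward(z_in_bad, z_out, q_init)
```

Detaching `q_init` is the key design choice: the gradient only flows through `z_in`/`z_out` (i.e. through the mask), never through the teacher.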
+
+ ### Step A0: Reward–Metric Correlation Study (next step)
+
+ **Goal:** before moving to full optimization, verify on data whether the new reward predicts real metric changes better than R_iou_pred does.
+
+ **Setup (200 samples, Unseen split):**
+ For each (video, segment) sample:
+ 1. Baseline decode → IoU_base, F_base
+ 2. q-LTPO s1 → q_best; record reward_gain, r_avt_gain, r_avt_c_gain (all computed inside q_ltpo_autograd)
+ 3. LTPO decode → IoU_ltpo, F_ltpo
+ 4. Δ = LTPO - baseline
+
+ Output a Pearson correlation table:
+
+ ```
+ Pearson r with ΔmIoU:
+   R_iou_pred_gain : +0.xxx  ← current proxy
+   R_avt_gain      : +0.xxx  ← cos(z_in, q_init)
+   R_avt_c_gain    : +0.xxx  ← contrastive variant
+
+ Wrong direction (gain>0 but Δ<0):
+   R_iou / ΔmIoU : 0.xxx
+   R_avt / ΔmIoU : 0.xxx
+ ```
+
+ **Run command:**
 ```bash
+ python load_model.py --eval_split test_u --max_eval_rows 200
+ ```
 
+ **Decision criteria:**
+ - `r(R_avt, ΔmIoU) > r(R_iou, ΔmIoU)` → the AVT proxy is better; proceed to Step A1
+ - The two are comparable → the reward itself is not the bottleneck; rethink the approach
+ - `R_avt / ΔF` wrong fraction clearly below `R_iou / ΔF` → AVT explains why F-score does not track mIoU
+
+ ### Step A1: Hybrid Reward (after Step A0 validates)
+
+ ```
+ R_task = λ1 · e0 · R_iou_pred + λ2 · R_avt_c - λ3 · R_area_soft
 ```
 
+ - R_iou_pred keeps handling mask quality (the shape-quality signal)
+ - R_avt_c handles referent correctness (the task-specific signal)
+ - Only the combination can plausibly maintain IoU while also improving F
+
+ Candidate weight combination: `λ1=0.6, λ2=0.5, λ3=0.2` (AVT as an auxiliary term, not a full replacement for R_iou).
+
+ If Step A1 shows a positive signal, then consider combining Direction II (frame-adaptive) with the new reward.
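The hybrid formula can be sketched as a scalar combination (defaults mirror the candidate weights above; the function name and scalar inputs are illustrative, not the repo's API, which would receive tensors from the decode loop):

```python
def hybrid_reward(r_iou_pred, r_avt_c, r_area_soft, e0,
                  lam1=0.6, lam2=0.5, lam3=0.2):
    """Step A1 candidate: R_task = λ1·e0·R_iou_pred + λ2·R_avt_c - λ3·R_area_soft.

    e0 gates the IoU term, so near-empty (Null) initializations are not
    pushed toward hallucinated masks; R_avt_c adds referent correctness.
    """
    return lam1 * e0 * r_iou_pred + lam2 * r_avt_c - lam3 * r_area_soft

# On a Null-like sample (e0 ≈ 0) the IoU term vanishes and the area
# penalty dominates, so inflating the mask can only lower the reward.
r_null = hybrid_reward(r_iou_pred=0.9, r_avt_c=0.0, r_area_soft=0.3, e0=0.0)
r_seen = hybrid_reward(r_iou_pred=0.9, r_avt_c=0.2, r_area_soft=0.3, e0=1.0)
```

The e0 gate is the same Null-safety mechanism as in Stage 1; λ2 simply adds the AVT term on top without removing it.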
 
 
load_model.py CHANGED
@@ -498,6 +498,8 @@ if __name__ == "__main__":
         get_sam_model, get_anchor_indices,
         QLTPOConfig, q_ltpo_autograd, check_grad_connectivity,
         reset_q_ltpo_stats, get_q_ltpo_stats,
+        q_ltpo_frame_adaptive, decode_full_video_adaptive,
+        _compute_avt_proxy_reward,
     )
 
     def print_q_ltpo_stats(name: str) -> None:
@@ -521,6 +523,14 @@ if __name__ == "__main__":
         gains = sorted(s["reward_gain"] for s in stats)
         def _pct(v, p): return v[max(0, int(len(v) * p / 100) - 1)]
         mean_e0 = sum(s["e0"] for s in stats) / n
+        mean_mask_iou = sum(s.get("mask_soft_iou", 0.0) for s in stats) / n
+        mean_iou_contrib = sum(s.get("R_iou_contrib_gain", 0.0) for s in stats) / n
+        mean_soft_area_init = sum(s.get("r_area_soft_init", 0.0) for s in stats) / n
+        mean_soft_area_best = sum(s.get("r_area_soft_best", 0.0) for s in stats) / n
+        # B1 activation diagnostics
+        b1_excesses = sorted(s.get("b1_peak_excess", 0.0) for s in stats)
+        b1_act_rate = sum(1 for v in b1_excesses if v > 1e-8) / n
+        b1_mean_excess = sum(b1_excesses) / n
         print(f"\n [q-LTPO stats | {name} | n={n}]")
         print(f" acceptance rate : {acc_rate:.3f}")
         print(f" mean e0 (exist prior): {mean_e0:.4f} ← should differ Null vs Seen")
@@ -529,20 +539,30 @@ if __name__ == "__main__":
         print(f" mean drift ‖q−q₀‖ : {mean_drift:.4f}")
         print(f" hit-clip ratio : {clip_rate:.3f}")
         print(f" R_iou_pred init→best : {mean_iou_init:.4f} → {mean_iou_best:.4f}")
+        print(f" R_iou_contrib_gain : {mean_iou_contrib:+.4f} ← λ_iou·e0·Δiou")
+        print(f" mask soft-IoU(init,best): {mean_mask_iou:.4f} ← 1.0 = mask unchanged")
         print(f" area (hard) init→best: {mean_area_init:.4f} → {mean_area_best:.4f}")
+        print(f" soft area init→best : {mean_soft_area_init:.4f} → {mean_soft_area_best:.4f}")
+        print(f" B1 activation rate : {b1_act_rate:.3f} ← frac(peak_area > e0)")
+        print(f" B1 mean excess : {b1_mean_excess:.5f} ← mean ReLU(peak_area - e0)")
+        print(f" B1 excess p10/50/90 : {_pct(b1_excesses,10):.5f} / {_pct(b1_excesses,50):.5f} / {_pct(b1_excesses,90):.5f}")
         print(f" reward↑ & area+20%↑ : {null_risk:.3f} ← Null safety indicator")
+        # Direction II: frame-adaptive delta diagnostics
+        delta_norms = [s.get("delta_norm", 0.0) for s in stats]
+        if any(v > 0 for v in delta_norms):
+            print(f" mean delta ‖Δ‖ : {sum(delta_norms)/n:.4f} ← per-anchor residual norm")
 
-    def valuate_ltpo(model, dataloader, name, ltpo_cfg, optimize_fn=None, max_rows=-1):
+    def valuate_ltpo(model, dataloader, name, ltpo_cfg, optimize_fn=None,
+                     max_rows=-1, multimask=False, use_edge=False):
         if optimize_fn is None:
             optimize_fn = ltpo_optimize
         """
-        Evaluate with SEG-LTPO-simple test-time optimisation.
+        Evaluate with SEG-LTPO test-time optimisation + optional boundary refinement.
 
-        For each sample:
-          1. Run the standard SimToken forward pass once to get initial Fseg.
-          2. Optimise Fseg on 4 anchor frames using antithetic ES (5 steps).
-          3. Decode the full video with the best Fseg found.
-          4. Fall back to the original Fseg when reward gating rejects the update.
+        decode_mode:
+          multimask=False, use_edge=False : original single-mask decode (default)
+          multimask=True, use_edge=False : 3 candidates, SAM iou_pred selection (step 1a)
+          multimask=True, use_edge=True : 3 candidates, boundary-edge score (step 1b)
         """
         model.eval()
         sam_model = get_sam_model(model)
@@ -590,6 +610,7 @@ if __name__ == "__main__":
                 image_embeds_b = input_dict["image_feats"][b]  # [T, 256, 64, 64]
                 resize_b = input_dict["resizes"][b]
                 orgsize_b = input_dict["orgsizes"][b]
+                rgb_b = input_dict["images"][b] if use_edge else None  # [T,3,H,W]
 
                 # Convert initial Fseg to float32 for stable optimisation.
                 # seg_emb_list[b]: [num_seg, 256] in bfloat16
@@ -609,6 +630,7 @@ if __name__ == "__main__":
                     pred_mask = decode_full_video(
                         best_fseg, image_embeds_b, sam_model,
                         resize_b, orgsize_b, model_dtype,
+                        rgb_frames=rgb_b, multimask=multimask,
                     )  # [T, H, W]
                     pred_masks_ltpo.append(pred_mask)
 
@@ -699,6 +721,219 @@ if __name__ == "__main__":
         print(f"\n LTPO valuate on Null: S metric: {total_metric/count:.4f}")
 
 
+    def valuate_ltpo_adaptive(model, dataloader, name, ltpo_cfg, max_rows=-1):
+        """Evaluate with Direction II frame-adaptive token optimization."""
+        model.eval()
+        sam_model = get_sam_model(model)
+        model_dtype = torch.bfloat16
+        num_frames = 10
+        anchor_indices = get_anchor_indices(num_frames, ltpo_cfg.num_anchors)
+
+        total_iou = 0
+        total_fscore = 0
+        count = 0
+
+        _total = min(max_rows, len(dataloader)) if max_rows > 0 else len(dataloader)
+        for i, batch in enumerate(tqdm(dataloader, desc=f"FA-LTPO Evaluating on {name}", total=_total)):
+            if 0 < max_rows <= i:
+                break
+            input_dict = dict_to_cuda(batch)
+
+            with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=True):
+                with torch.no_grad():
+                    output_dict = model.forward(
+                        images=input_dict["images"],
+                        images_clip=input_dict["images_clip"],
+                        audio_features=input_dict["audio_feats"],
+                        image_features=input_dict["image_feats"],
+                        input_ids=input_dict["input_ids"],
+                        labels=input_dict["labels"],
+                        attention_masks=input_dict["attention_masks"],
+                        masks_list=input_dict["masks"],
+                        resize_list=input_dict["resizes"],
+                        orgsize_list=input_dict["orgsizes"],
+                        conversation_list=input_dict["convs"],
+                        refs_num=input_dict["refs_num"],
+                        fids=input_dict["fids"],
+                        vids=input_dict["vids"],
+                        contrast=args.ct_weight,
+                        ref_ids=input_dict["ref_ids"],
+                        inference=True,
+                    )
+
+            gt_masks = output_dict["gt_masks"]  # list[B]:[num_seg, T, H, W]
+            seg_emb_list = output_dict["seg_embeddings"]  # list[B]:[num_seg, 256]
+
+            for b in range(len(input_dict["images"])):
+                image_embeds_b = input_dict["image_feats"][b]
+                resize_b = input_dict["resizes"][b]
+                orgsize_b = input_dict["orgsizes"][b]
+                F_init_b = seg_emb_list[b].detach().float()
+
+                pred_masks_ltpo = []
+                for seg_idx in range(F_init_b.shape[0]):
+                    fseg_init = F_init_b[seg_idx : seg_idx + 1]
+
+                    q_global, delta = q_ltpo_frame_adaptive(
+                        fseg_init, image_embeds_b, anchor_indices,
+                        sam_model, model_dtype, ltpo_cfg,
+                    )
+
+                    pred_mask = decode_full_video_adaptive(
+                        q_global, delta, anchor_indices,
+                        image_embeds_b, sam_model,
+                        resize_b, orgsize_b, model_dtype,
+                    )
+                    pred_masks_ltpo.append(pred_mask)
+
+                pred_masks_b = torch.stack(pred_masks_ltpo, dim=0)
+                num_seg = pred_masks_b.shape[0]
+                T_ = pred_masks_b.shape[1]
+                iou = utility.mask_iou(pred_masks_b, gt_masks[b])
+                fscore = utility.Eval_Fmeasure(pred_masks_b, gt_masks[b], None)
+
+                total_iou += iou * num_seg * T_
+                total_fscore += fscore * num_seg * T_
+                count += num_seg * T_
+
+        print(f"\n FA-LTPO valuate on {name}: miou: {total_iou/count:.4f} fscore: {total_fscore/count:.4f}")
+
+    # ── Step A0: reward–metric correlation study ─────────────────────────
+
+    def _print_correlation_report(per_sample: list) -> None:
+        import numpy as np
+        n = len(per_sample)
+        if n == 0:
+            return
+
+        r_iou = np.array([s["reward_gain"] for s in per_sample], dtype=float)
+        r_avt = np.array([s["r_avt_gain"] for s in per_sample], dtype=float)
+        r_avt_c = np.array([s["r_avt_c_gain"] for s in per_sample], dtype=float)
+        dm = np.array([s["delta_miou"] for s in per_sample], dtype=float)
+        df = np.array([s["delta_f"] for s in per_sample], dtype=float)
+
+        def pearson(x, y):
+            x = x - x.mean(); y = y - y.mean()
+            denom = np.sqrt((x ** 2).sum() * (y ** 2).sum())
+            return float((x * y).sum() / (denom + 1e-12))
+
+        def wrong_frac(gains, deltas):
+            return sum(1 for g, d in zip(gains, deltas) if g > 0 and d < 0) / n
+
+        print(f"\n [Step A0: Reward–Metric Correlation | n={n}]")
+        print(f" mean ΔmIoU : {dm.mean():+.4f} (std {dm.std():.4f})")
+        print(f" mean ΔF : {df.mean():+.4f} (std {df.std():.4f})")
+        print(f"\n Pearson r with ΔmIoU :")
+        print(f" R_iou_pred_gain : {pearson(r_iou, dm):+.3f} ← current proxy")
+        print(f" R_avt_gain : {pearson(r_avt, dm):+.3f} ← cos(z_in, q_init)")
+        print(f" R_avt_c_gain : {pearson(r_avt_c, dm):+.3f} ← cos(z_in,q)-β·cos(z_out,q)")
+        print(f"\n Pearson r with ΔF :")
+        print(f" R_iou_pred_gain : {pearson(r_iou, df):+.3f}")
+        print(f" R_avt_gain : {pearson(r_avt, df):+.3f}")
+        print(f" R_avt_c_gain : {pearson(r_avt_c, df):+.3f}")
+        print(f"\n Wrong direction (gain>0 but Δ<0):")
+        print(f" R_iou / ΔmIoU : {wrong_frac(r_iou, dm):.3f}")
+        print(f" R_avt / ΔmIoU : {wrong_frac(r_avt, dm):.3f}")
+        print(f" R_iou / ΔF : {wrong_frac(r_iou, df):.3f}")
+        print(f" R_avt / ΔF : {wrong_frac(r_avt, df):.3f}")
+
+    def valuate_ltpo_correlation_study(model, dataloader, ltpo_cfg, max_rows=-1):
+        """Step A0: per-sample reward–metric correlation study.
+
+        For each (video, segment) sample runs:
+          1. Baseline decode (q_init → mask → IoU/F)
+          2. q-LTPO s1 (q_best → mask → IoU/F)
+        Records reward signals and ΔmIoU / ΔF per sample, then prints a
+        Pearson correlation table to identify which reward best predicts
+        actual metric improvement.
+        """
+        model.eval()
+        sam_model = get_sam_model(model)
+        model_dtype = torch.bfloat16
+        anchor_indices = get_anchor_indices(10, ltpo_cfg.num_anchors)
+
+        per_sample = []
+
+        _total = min(max_rows, len(dataloader)) if max_rows > 0 else len(dataloader)
+        for i, batch in enumerate(
+            tqdm(dataloader, desc="Correlation study (s1)", total=_total)
+        ):
+            if 0 < max_rows <= i:
+                break
+            input_dict = dict_to_cuda(batch)
+
+            with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=True):
+                with torch.no_grad():
+                    output_dict = model.forward(
+                        images=input_dict["images"],
+                        images_clip=input_dict["images_clip"],
+                        audio_features=input_dict["audio_feats"],
+                        image_features=input_dict["image_feats"],
+                        input_ids=input_dict["input_ids"],
+                        labels=input_dict["labels"],
+                        attention_masks=input_dict["attention_masks"],
+                        masks_list=input_dict["masks"],
+                        resize_list=input_dict["resizes"],
+                        orgsize_list=input_dict["orgsizes"],
+                        conversation_list=input_dict["convs"],
+                        refs_num=input_dict["refs_num"],
+                        fids=input_dict["fids"],
+                        vids=input_dict["vids"],
+                        contrast=args.ct_weight,
+                        ref_ids=input_dict["ref_ids"],
+                        inference=True,
+                    )
+
+            gt_masks = output_dict["gt_masks"]  # list[B]:[num_seg, T, H, W]
+            seg_emb_list = output_dict["seg_embeddings"]  # list[B]:[num_seg, 256]
+
+            for b in range(len(input_dict["images"])):
+                image_embeds_b = input_dict["image_feats"][b]
+                resize_b = input_dict["resizes"][b]
+                orgsize_b = input_dict["orgsizes"][b]
+                F_init_b = seg_emb_list[b].detach().float()
+
+                for seg_idx in range(F_init_b.shape[0]):
+                    q_init = F_init_b[seg_idx : seg_idx + 1]  # [1, 256]
+                    gt_seg = gt_masks[b][seg_idx : seg_idx + 1]  # [1, T, H, W]
+
+                    # Baseline decode (q_init, no LTPO)
+                    with torch.no_grad():
+                        pred_base = decode_full_video(
+                            q_init, image_embeds_b, sam_model,
+                            resize_b, orgsize_b, model_dtype,
+                        ).unsqueeze(0)  # [1, T, H, W]
+                    iou_base = utility.mask_iou(pred_base, gt_seg)
+                    f_base = utility.Eval_Fmeasure(pred_base, gt_seg, None)
+
+                    # LTPO (s1) — also computes r_avt inside q_ltpo_autograd
+                    reset_q_ltpo_stats()
+                    q_best = q_ltpo_autograd(
+                        q_init, image_embeds_b, anchor_indices,
+                        sam_model, model_dtype, ltpo_cfg,
+                    )
+                    stat = get_q_ltpo_stats()[0]
+
+                    with torch.no_grad():
+                        pred_ltpo = decode_full_video(
+                            q_best, image_embeds_b, sam_model,
+                            resize_b, orgsize_b, model_dtype,
+                        ).unsqueeze(0)
+                    iou_ltpo = utility.mask_iou(pred_ltpo, gt_seg)
+                    f_ltpo = utility.Eval_Fmeasure(pred_ltpo, gt_seg, None)
+
+                    per_sample.append({
+                        "reward_gain": stat["reward_gain"],
+                        "r_avt_gain": stat.get("r_avt_gain", 0.0),
+                        "r_avt_c_gain": stat.get("r_avt_c_gain", 0.0),
+                        "e0": stat["e0"],
+                        "accepted": stat["accepted"],
+                        "delta_miou": float(iou_ltpo - iou_base),
+                        "delta_f": float(f_ltpo - f_base),
+                    })
+
+        _print_correlation_report(per_sample)
+
     # ── Stage 0: gradient connectivity check ─────────────────────────────
     # Loads one image_embed directly from disk — no dataloader, no gt_mask,
     # no media frames required. F_init is a unit-scale random vector that
@@ -846,32 +1081,44 @@ if __name__ == "__main__":
 
     # ── Run evaluation ────────────────────────────────────────────────────
 
-    ltpo_cfg = LTPOConfig()
-    q_ltpo_cfg_s1 = QLTPOConfig(stage=1)
-    q_ltpo_cfg_s2 = QLTPOConfig(stage=2)
-    max_rows = args.max_eval_rows  # -1 = all rows
+    ltpo_cfg = LTPOConfig()
+    q_ltpo_cfg_s1 = QLTPOConfig(stage=1)
+    q_ltpo_cfg_s2 = QLTPOConfig(stage=2)
+    q_ltpo_cfg_s21 = QLTPOConfig(stage=21)  # P1a: tether probe
+    q_ltpo_cfg_s22 = QLTPOConfig(stage=22)  # P1b: faithful ext-ref
+
+    # ── Direction B: boundary precision probes ──────────────────────────────
+    q_ltpo_cfg_b1_w03 = QLTPOConfig(stage=1, lambda_area_inc=0.3, area_inc_tau=0.0)
+    q_ltpo_cfg_b1_w10 = QLTPOConfig(stage=1, lambda_area_inc=1.0, area_inc_tau=0.0)
+
+    # ── Direction II: Frame-adaptive token optimization ─────────────────────
+    # fa_c03: delta clipped at 0.3×‖q_init‖ — moderate constraint.
+    # First probe to answer: "does constrained frame-adaptive beat shared q?"
+    # If yes → ablate tighter/looser constraints and smoothness in follow-up.
+    q_ltpo_cfg_fa_c03 = QLTPOConfig(stage=1, lambda_residual=0.001, lambda_smooth_temp=0.0, max_delta_drift_scale=0.3)
+
+    max_rows = args.max_eval_rows  # -1 = all rows
 
     # --max_eval_rows 0 → Stage 0 + bypass equivalence check, then exit
     if max_rows == 0:
         run_stage0_check()
         run_bypass_test()
     elif _split == 'test_n':
-        # Safety check: Baseline vs q-LTPO Stage 1 only.
-        # ES-LTPO / Stage 2 are omitted — ES is no longer the primary method,
-        # and Stage 2 consistently underperforms Stage 1. If Stage 1 shows
-        # notable deterioration here, add a small Best-of-2 ES subset run to
-        # distinguish "reward unsafe on Null" from "autograd more aggressive".
+        # Null safety check: baseline + Stage 1 + frame-adaptive
        valuate_Null(model, _dataloader, max_rows=max_rows)
+        for cfg_name, cfg in [("s1", q_ltpo_cfg_s1)]:
+            reset_q_ltpo_stats()
+            valuate_ltpo_null(model, _dataloader, cfg,
+                              optimize_fn=q_ltpo_autograd, max_rows=max_rows)
+            print_q_ltpo_stats(f"null_q_ltpo_{cfg_name}")
         reset_q_ltpo_stats()
-        valuate_ltpo_null(model, _dataloader, q_ltpo_cfg_s1,
-                          optimize_fn=q_ltpo_autograd, max_rows=max_rows)
-        print_q_ltpo_stats("null_q_ltpo_s1")
+        valuate_ltpo_adaptive(model, _dataloader, "null_fa_c03",
+                              q_ltpo_cfg_fa_c03, max_rows=max_rows)
+        print_q_ltpo_stats("null_fa_c03")
     else:
-        # Baseline + q-LTPO Stage 1 only. ES series omitted — q-autograd is
-        # the primary method; Stage 2 consistently underperforms Stage 1.
         valuate(model, _dataloader, _split, max_rows=max_rows)
-        reset_q_ltpo_stats()
-        valuate_ltpo(model, _dataloader, f'{_split}_q_ltpo_s1', q_ltpo_cfg_s1,
-                     optimize_fn=q_ltpo_autograd, max_rows=max_rows)
-        print_q_ltpo_stats(f'{_split}_q_ltpo_s1')
+        # Step A0: reward–metric correlation study (s1 + AVT proxy signals)
+        valuate_ltpo_correlation_study(
+            model, _dataloader, q_ltpo_cfg_s1, max_rows=max_rows
+        )
seg_ltpo.py CHANGED
@@ -283,31 +283,98 @@ def best_of_2_optimize(
283
  # Full-video decode with a given Fseg
284
  # ---------------------------------------------------------------------------
285
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
  def decode_full_video(
287
- fseg: torch.Tensor, # [1, 256] float32
288
- image_embeds: torch.Tensor, # [T, 256, 64, 64] model dtype on CUDA
289
  sam_model,
290
- resize: tuple, # (H_resized, W_resized) – after ResizeLongestSide
291
- orgsize: tuple, # (H_orig, W_orig)
292
  model_dtype: torch.dtype,
 
 
293
  ) -> torch.Tensor:
294
- """
295
- Decode all T frames with the given Fseg.
 
 
 
 
 
 
296
  Returns raw logit mask [T, H_orig, W_orig] (not yet sigmoid).
297
  """
298
- device = image_embeds.device
299
  dense_emb = _precompute_dense_emb(sam_model, model_dtype, device)
300
  dense_pe = sam_model.prompt_encoder.get_dense_pe().to(device)
301
  sparse_emb = fseg.to(model_dtype).unsqueeze(1) # [1, 1, 256]
302
 
303
  with torch.no_grad():
304
- low_res_masks, _ = sam_model.mask_decoder(
305
- image_embeddings=image_embeds, # [T, 256, 64, 64]
306
  image_pe=dense_pe,
307
- sparse_prompt_embeddings=sparse_emb, # [1, 1, 256]
308
- dense_prompt_embeddings=dense_emb, # [1, 256, 64, 64]
309
- multimask_output=False,
310
- ) # [T, 1, 256, 256]
 
 
 
 
 
 
 
 
 
 
 
311
 
312
  pred_mask = sam_model.postprocess_masks(
313
  low_res_masks, input_size=resize, original_size=orgsize
@@ -401,12 +468,14 @@ def ltpo_optimize(
401
 
402
  @dataclass
403
  class QLTPOConfig:
404
- """Configuration for q_ltpo_autograd (Stages 1–3).
405
 
406
  stage controls which reward terms are active:
407
- 1 R_iou + R_area_soft + reg (gradient connectivity + stability)
408
- 2 Stage 1 + R_align_det (z stopgrad) (semantic alignment)
409
- 3 Stage 2 + R_temp_feat (full reward)
 
 
410
  """
411
  stage: int = 1
412
  T: int = 5
@@ -443,12 +512,44 @@ class QLTPOConfig:
443
  e0_modulation: str = "identity"
444
  e0_eps: float = 1e-4 # epsilon for "sqrt" variant
445
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
446
  # ── Oracle Null-safety gate (analysis only; NOT for final method) ──────
447
  # Derived from test-set distribution (Null area_hard ≈ 0.01, Seen ≈ 0.05)
448
  # so must not be used in reported results. Set null_gate_delta=0 to disable.
449
  null_area_threshold: float = 0.02 # hard area fraction below which guard activates
450
  null_gate_delta: float = 0.0 # 0 = disabled; 0.05 = oracle experiment
451
 
 
 
 
 
 
 
 
 
 
 
 
 
 
452
 
453
  # ---------------------------------------------------------------------------
454
  # e0 helper
@@ -508,10 +609,32 @@ def _task_reward_stage1(
508
  optimizer sees only the area-penalty gradient and naturally tends toward
509
  smaller (more conservative) masks — the correct behavior when the initial
510
  prediction is near-empty (Null frames).
 
 
 
 
 
 
 
 
 
511
  """
512
  r_iou = iou.mean()
513
  r_area = torch.sigmoid(lrm / cfg.area_temp).mean()
514
- return cfg.lambda_iou * e0 * r_iou - cfg.lambda_area * r_area
 
 
 
 
 
 
 
 
 
 
 
 
 
515
 
516
 
517
  def _task_reward_stage2(
@@ -575,6 +698,167 @@ def _task_reward_stage3(
575
  return r_s2 + cfg.lambda_temp * r_temp
576
 
577
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
def _compute_task_reward(
    q: torch.Tensor,
    lrm: torch.Tensor,
@@ -582,12 +866,20 @@ def _compute_task_reward(
    image_embeds_anchor_fp32: torch.Tensor,
    cfg: QLTPOConfig,
    e0: float = 1.0,
) -> torch.Tensor:
    """Dispatch to the correct stage's task reward."""
    if cfg.stage == 1:
        return _task_reward_stage1(lrm, iou, cfg, e0)
    if cfg.stage == 2:
        return _task_reward_stage2(q, lrm, iou, image_embeds_anchor_fp32, cfg, e0)
    return _task_reward_stage3(q, lrm, iou, image_embeds_anchor_fp32, cfg, e0)


@@ -599,9 +891,11 @@ def _compute_full_reward(
    q_init: torch.Tensor,
    cfg: QLTPOConfig,
    e0: float = 1.0,
) -> torch.Tensor:
    """Full reward = task reward + L2 regularization (used for backward)."""
-   r_task = _compute_task_reward(q, lrm, iou, image_embeds_anchor_fp32, cfg, e0)
    r_reg = (q - q_init).pow(2).sum()
    return r_task - cfg.lambda_reg * r_reg

@@ -661,6 +955,53 @@ def check_grad_connectivity(
    }


# ---------------------------------------------------------------------------
# Stage 1–3: q-LTPO-autograd main optimizer
# ---------------------------------------------------------------------------
@@ -697,6 +1038,11 @@ def q_ltpo_autograd(
    lr = cfg.lr if cfg.lr > 0 else 0.01 * rms.item()
    max_drift = cfg.max_drift if cfg.max_drift > 0 else 0.5 * q_init_fp32.norm().item()

    # ── Baseline forward + e0 existence prior ────────────────────────────
    with torch.no_grad():
        lrm0, iou0 = _decode_on_anchors_diff(
@@ -708,7 +1054,8 @@
    e0 = _compute_e0(r_area_soft_init, cfg)

    R_init_task = _compute_task_reward(
-       q_init_fp32, lrm0, iou0, image_embeds_anchor, cfg, e0=e0
    ).item()

    # ── Optimisation setup ────────────────────────────────────────────────
@@ -720,13 +1067,17 @@
    hit_clip = False

    # ── Optimisation loop ─────────────────────────────────────────────────
    for step in range(cfg.T):
        optimizer.zero_grad()

        lrm, iou = _decode_on_anchors_diff(
            q, image_embeds_anchor, dense_emb, mask_dec, dense_pe
        )
-       R_full = _compute_full_reward(q, lrm, iou, image_embeds_anchor, q_init_fp32, cfg, e0=e0)
        R_full.backward()
        optimizer.step()

@@ -744,20 +1095,32 @@
        lrm_eval, iou_eval = _decode_on_anchors_diff(
            q.detach(), image_embeds_anchor, dense_emb, mask_dec, dense_pe
        )
        r_task = _compute_task_reward(
-           q.detach(), lrm_eval, iou_eval, image_embeds_anchor, cfg, e0=e0
        ).item()
        if r_task > best_reward:
            best_reward = r_task
            best_q = q.detach().clone()

    # ── Reward gating: clean re-eval of best_q vs q_init ─────────────────
    with torch.no_grad():
        lrm_b, iou_b = _decode_on_anchors_diff(
            best_q, image_embeds_anchor, dense_emb, mask_dec, dense_pe
        )
        R_best_task = _compute_task_reward(
-           best_q, lrm_b, iou_b, image_embeds_anchor, cfg, e0=e0
        ).item()

    area_init = (lrm0 > 0).float().mean().item()
@@ -768,19 +1131,242 @@
    )
    accepted = R_best_task > R_init_task + effective_gate

    # ── Per-sample diagnostics ────────────────────────────────────────────
    _q_ltpo_stats.append({
-       "accepted": accepted,
-       "reward_gain": R_best_task - R_init_task,
-       "drift": (best_q - q_init_fp32).norm().item(),
-       "hit_clip": hit_clip,
-       "e0": e0,
-       "R_iou_pred_init": iou0.mean().item(),
-       "R_iou_pred_best": iou_b.mean().item(),
-       "area_hard_init": area_init,
-       "area_hard_best": (lrm_b > 0).float().mean().item(),
    })

    if not accepted:
        return F_init.float()
    return best_q

# Full-video decode with a given Fseg
# ---------------------------------------------------------------------------

+ def _sobel_edge(rgb_frames: torch.Tensor) -> torch.Tensor:
+     """Compute Sobel edge magnitude from normalized RGB frames.
+
+     Args:
+         rgb_frames: [T, 3, H, W] float32 (SAM-normalized, CUDA)
+     Returns:
+         edge: [T, 1, H, W] float32, non-negative
+     """
+     gray = rgb_frames.float().mean(dim=1, keepdim=True)  # [T, 1, H, W]
+     kx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
+                       dtype=torch.float32, device=rgb_frames.device).view(1, 1, 3, 3)
+     ky = kx.transpose(2, 3)
+     gx = F.conv2d(gray, kx, padding=1)
+     gy = F.conv2d(gray, ky, padding=1)
+     return torch.sqrt(gx ** 2 + gy ** 2 + 1e-6)  # [T, 1, H, W]
+
+
+ def _boundary_edge_score(
+     low_res_masks: torch.Tensor,  # [T, K, 256, 256] logits
+     rgb_frames: torch.Tensor,     # [T, 3, H, W] float32
+     resize: tuple,                # (H_resized, W_resized)
+     area_temp: float = 5.0,
+ ) -> torch.Tensor:
+     """Score each of K mask candidates by boundary-edge alignment.
+
+     R_edge = <soft_boundary_band, Sobel_edge> / (sum(soft_boundary_band) + ε)
+     Rewards masks whose boundaries coincide with image edges.
+
+     Returns: [T, K] float32 scores (higher = better boundary alignment)
+     """
+     T, K = low_res_masks.shape[:2]
+     H_r, W_r = resize
+
+     # Upsample all candidates to resized image resolution at once
+     masks_up = F.interpolate(
+         low_res_masks.reshape(T * K, 1, 256, 256).float(),
+         size=(H_r, W_r), mode="bilinear", align_corners=False,
+     ).reshape(T, K, H_r, W_r)  # [T, K, H, W]
+
+     E = _sobel_edge(rgb_frames[:, :, :H_r, :W_r])  # [T, 1, H, W]
+
+     m = torch.sigmoid(masks_up / area_temp)  # [T, K, H, W]
+     b = 4.0 * m * (1.0 - m)                  # soft boundary band
+     num = (b * E.squeeze(1).unsqueeze(1)).sum(dim=[2, 3])  # [T, K]
+     den = b.sum(dim=[2, 3]) + 1e-6
+     return num / den  # [T, K]
+
+
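The `_boundary_edge_score` hunk above weights Sobel edge magnitude by the soft boundary band `4·m·(1−m)`, which peaks where `m ≈ 0.5`, i.e. along the mask boundary. A minimal standalone sketch of the same selection rule (toy ramp logits rather than SAM outputs; `sobel_edge` / `boundary_edge_score` are hypothetical names for illustration):

```python
import torch
import torch.nn.functional as F

def sobel_edge(rgb):  # rgb: [T, 3, H, W]
    gray = rgb.float().mean(dim=1, keepdim=True)
    kx = torch.tensor([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]]).view(1, 1, 3, 3)
    gx = F.conv2d(gray, kx, padding=1)
    gy = F.conv2d(gray, kx.transpose(2, 3), padding=1)
    return torch.sqrt(gx ** 2 + gy ** 2 + 1e-6)  # [T, 1, H, W]

def boundary_edge_score(mask_logits, rgb, temp=5.0):
    # mask_logits: [T, K, H, W]; the band 4m(1-m) localizes at m ≈ 0.5,
    # i.e. exactly along the soft mask boundary.
    m = torch.sigmoid(mask_logits / temp)
    band = 4.0 * m * (1.0 - m)
    E = sobel_edge(rgb)  # [T, 1, H, W], broadcasts over the K candidates
    return (band * E).sum(dim=[2, 3]) / (band.sum(dim=[2, 3]) + 1e-6)

# Toy image with a vertical intensity edge at x = 16.
T, H, W = 1, 32, 32
rgb = torch.zeros(T, 3, H, W)
rgb[..., 16:] = 1.0
x = torch.arange(W, dtype=torch.float32)
good = (2.0 * (x - 16.0)).expand(H, W).reshape(T, 1, H, W)  # boundary at x=16
bad = (2.0 * (x - 8.0)).expand(H, W).reshape(T, 1, H, W)    # boundary at x=8
scores = boundary_edge_score(torch.cat([good, bad], dim=1), rgb)
assert scores[0, 0] > scores[0, 1]  # boundary sitting on the image edge wins
```

The ramp-shaped logits make the band a narrow stripe around the 0.5-crossing, so only the candidate whose stripe overlaps the Sobel response scores high.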
  def decode_full_video(
335
+ fseg: torch.Tensor, # [1, 256] float32
336
+ image_embeds: torch.Tensor, # [T, 256, 64, 64] model dtype on CUDA
337
  sam_model,
338
+ resize: tuple, # (H_resized, W_resized)
339
+ orgsize: tuple, # (H_orig, W_orig)
340
  model_dtype: torch.dtype,
341
+ rgb_frames: Optional[torch.Tensor] = None, # [T, 3, H, W]; enables edge selection
342
+ multimask: bool = False, # True = 3 candidates; False = single mask
343
  ) -> torch.Tensor:
344
+ """Decode all T frames with the given Fseg.
345
+
346
+ Selection logic (applied per-frame):
347
+ - multimask=False, rgb_frames=None : original single-mask decode (baseline)
348
+ - multimask=True, rgb_frames=None : 3 candidates, select by SAM iou_pred
349
+ - multimask=True, rgb_frames=* : 3 candidates, select by boundary-edge score
350
+ (boundary band × Sobel edge; directly rewards boundary-image alignment)
351
+
352
  Returns raw logit mask [T, H_orig, W_orig] (not yet sigmoid).
353
  """
354
+ device = image_embeds.device
355
  dense_emb = _precompute_dense_emb(sam_model, model_dtype, device)
356
  dense_pe = sam_model.prompt_encoder.get_dense_pe().to(device)
357
  sparse_emb = fseg.to(model_dtype).unsqueeze(1) # [1, 1, 256]
358
 
359
  with torch.no_grad():
360
+ low_res_masks, iou_preds = sam_model.mask_decoder(
361
+ image_embeddings=image_embeds,
362
  image_pe=dense_pe,
363
+ sparse_prompt_embeddings=sparse_emb,
364
+ dense_prompt_embeddings=dense_emb,
365
+ multimask_output=multimask,
366
+ ) # [T, K, 256, 256], [T, K] where K=1 or K=3
367
+
368
+ if multimask:
369
+ T = low_res_masks.shape[0]
370
+ if rgb_frames is not None:
371
+ # Step 1b: boundary-edge score selects best candidate
372
+ scores = _boundary_edge_score(low_res_masks, rgb_frames, resize)
373
+ else:
374
+ # Step 1a: SAM's own iou_pred selects best candidate
375
+ scores = iou_preds
376
+ best_idx = scores.argmax(dim=1) # [T]
377
+ low_res_masks = low_res_masks[torch.arange(T, device=device), best_idx].unsqueeze(1)
378
 
379
  pred_mask = sam_model.postprocess_masks(
380
  low_res_masks, input_size=resize, original_size=orgsize
 
468
 
469
  @dataclass
470
  class QLTPOConfig:
471
+ """Configuration for q_ltpo_autograd (Stages 1–3 + Stage 2-ext variants).
472
 
473
  stage controls which reward terms are active:
474
+ 1 R_iou + R_area_soft + reg (baseline autograd)
475
+ 2 Stage 1 + R_align_det (z_in/z_out stopgrad) (self-bootstrapped alignment)
476
+ 3 Stage 2 + R_temp_feat (full reward)
477
+ 21 Stage 1 + R_tether (P1a: tether probe) (frozen r_ref via q_init attn)
478
+ 22 Stage 1 + R_faithful (P1b: faithful ext-ref) (z_in/z_out vs frozen r_ref)
479
  """
480
  stage: int = 1
481
  T: int = 5
 
512
  e0_modulation: str = "identity"
513
  e0_eps: float = 1e-4 # epsilon for "sqrt" variant
514
 
515
+ # ── Stage 2-ext: external reference (stages 21 and 22) ────────────────
516
+ # r_ref = AttnPool(image_feats_anchor, q_init): frozen visual anchor derived
517
+ # from q_init's attention over SAM image features. Breaks Stage 2's
518
+ # self-confirming bias by providing a mask-independent teacher.
519
+ # r_ref_temp: softmax temperature for attention pooling (sqrt(256) = 16).
520
+ r_ref_temp: float = 16.0
521
+
522
+ # ── Direction B: boundary precision rewards ────────────────────────────
523
+ # B1: asymmetric area expansion penalty
524
+ # Only penalises growth beyond (1+τ)×e0; allows mask contraction.
525
+ # Targets the observed pattern where LTPO slightly expands masks into
526
+ # non-target regions (recall↑ but precision↓, hurting F-score).
527
+ # B2: boundary sharpness reward
528
+ # -mean(4m(1-m)) with temperature=1.0; rewards bimodal (certain)
529
+ # mask predictions, encouraging cleaner boundary predictions.
530
+ lambda_area_inc: float = 0.0 # B1 weight (0 = disabled)
531
+ area_inc_tau: float = 0.0 # B1 tolerance band: allow (1+τ)×e0
532
+ lambda_sharp: float = 0.0 # B2 weight (0 = disabled)
533
+
534
  # ── Oracle Null-safety gate (analysis only; NOT for final method) ──────
535
  # Derived from test-set distribution (Null area_hard ≈ 0.01, Seen ≈ 0.05)
536
  # so must not be used in reported results. Set null_gate_delta=0 to disable.
537
  null_area_threshold: float = 0.02 # hard area fraction below which guard activates
538
  null_gate_delta: float = 0.0 # 0 = disabled; 0.05 = oracle experiment
539
 
540
+ # ── Direction II: Frame-adaptive token optimization (stage=4) ─────────
541
+ # q_t = q_global + delta_t, where delta_t is a per-anchor residual.
542
+ # Optimizes q_global and {delta_t} jointly with Adam.
543
+ # lambda_residual: soft L2 penalty on delta_t
544
+ # lambda_smooth_temp: temporal smoothness penalty on adjacent delta differences
545
+ # max_delta_drift_scale: per-anchor hard L2 clip = scale × ‖q_init‖
546
+ # Prevents individual anchors from wandering to a completely different visual mode.
547
+ # Keep << max_drift (0.5) so delta stays a "small frame correction" to q_global.
548
+ # 0.1 is tight (delta ≤ 20% of global drift budget), 0.3 is moderate.
549
+ lambda_residual: float = 0.001
550
+ lambda_smooth_temp: float = 0.0
551
+ max_delta_drift_scale: float = 0.1 # per-anchor clip = scale × ‖q_init‖
552
+
553
 
554
  # ---------------------------------------------------------------------------
555
  # e0 helper
 
    optimizer sees only the area-penalty gradient and naturally tends toward
    smaller (more conservative) masks — the correct behavior when the initial
    prediction is near-empty (Null frames).
+
+   Optional boundary precision terms (Direction B):
+     B1 (lambda_area_inc > 0): asymmetric expansion penalty
+         -λ_inc · ReLU(r_area - (1+τ)·e0)
+       Penalises mask growth beyond the initial area (+ tolerance band τ).
+       e0 doubles as the stopgrad initial-area threshold — zero extra cost.
+     B2 (lambda_sharp > 0): boundary sharpness reward
+         -λ_sharp · mean(4m(1-m)) with m = sigmoid(lrm), temperature=1.0
+       Maximises bimodality of mask logits → cleaner boundary predictions.
    """
    r_iou = iou.mean()
    r_area = torch.sigmoid(lrm / cfg.area_temp).mean()
+   R = cfg.lambda_iou * e0 * r_iou - cfg.lambda_area * r_area
+
+   # B1: penalise expansion beyond (1+τ)×e0 (allow contraction freely)
+   if cfg.lambda_area_inc > 0.0:
+       area_ceil = (1.0 + cfg.area_inc_tau) * e0
+       R = R - cfg.lambda_area_inc * F.relu(r_area - area_ceil)
+
+   # B2: reward confident (bimodal) boundary predictions
+   if cfg.lambda_sharp > 0.0:
+       m_sharp = torch.sigmoid(lrm)  # temperature=1.0 (sharp)
+       boundary_uncertain = 4.0 * m_sharp * (1.0 - m_sharp)
+       R = R - cfg.lambda_sharp * boundary_uncertain.mean()
+
+   return R


  def _task_reward_stage2(

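The two Direction-B terms added above are easy to check in isolation. A minimal sketch (standalone helpers named for illustration, not the repo's functions) showing that B1 is one-sided and B2 prefers bimodal logits:

```python
import torch
import torch.nn.functional as F

# B1: asymmetric expansion penalty — only growth beyond (1+tau)*e0 is penalised.
def b1_penalty(r_area, e0, lam=1.0, tau=0.1):
    return lam * F.relu(r_area - (1.0 + tau) * e0)

# B2: boundary-uncertainty term — 4m(1-m) is 1 at m=0.5 and 0 at m in {0, 1}.
def b2_uncertainty(logits):
    m = torch.sigmoid(logits)
    return (4.0 * m * (1.0 - m)).mean()

e0 = 0.05
assert b1_penalty(torch.tensor(0.03), e0).item() == 0.0  # contraction: free
assert b1_penalty(torch.tensor(0.08), e0).item() > 0.0   # expansion: penalised

confident = torch.tensor([-8.0, 8.0])  # bimodal logits
fuzzy = torch.tensor([-0.2, 0.2])      # near the decision boundary
assert b2_uncertainty(confident) < b2_uncertainty(fuzzy)
```

The ReLU makes B1 exactly zero for any area at or below the `(1+τ)·e0` ceiling, which is why contraction toward Null-safe masks is never discouraged.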
    return r_s2 + cfg.lambda_temp * r_temp


+ @torch.no_grad()
+ def _compute_r_ref(
+     q_init: torch.Tensor,               # [1, 256] float32
+     image_embeds_anchor: torch.Tensor,  # [A, 256, 64, 64] float32
+     temp: float = 16.0,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """Frozen external visual reference via attention pooling guided by q_init.
+
+     r_ref: regions most attended by q_init (positive anchor).
+     r_neg: regions least attended by q_init (anti-attended negative).
+     Both are in the SAM 256d space — no projection needed.
+     Computed once before the optimization loop and kept fixed (stopgrad).
+     """
+     img_flat = image_embeds_anchor.flatten(2)  # [A, 256, H*W]
+     q_norm = F.normalize(q_init[0], dim=0)     # [256]
+     img_norm = F.normalize(img_flat, dim=1)    # [A, 256, H*W]
+
+     # cosine similarity between q and each spatial position
+     attn = torch.einsum('d,adp->ap', q_norm, img_norm)  # [A, H*W]
+
+     attn_w_pos = torch.softmax(attn / temp, dim=-1)   # [A, H*W]
+     attn_w_neg = torch.softmax(-attn / temp, dim=-1)  # [A, H*W] anti-attended
+
+     # soft attention pooling in the original (non-normalized) feature space
+     r_ref_frames = torch.einsum('ap,adp->ad', attn_w_pos, img_flat)  # [A, 256]
+     r_neg_frames = torch.einsum('ap,adp->ad', attn_w_neg, img_flat)  # [A, 256]
+
+     r_ref = F.normalize(r_ref_frames.mean(0), dim=0)  # [256]
+     r_neg = F.normalize(r_neg_frames.mean(0), dim=0)  # [256]
+     return r_ref, r_neg
+
+
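The pooling above can be exercised on toy tensors. A sketch under stated assumptions (tiny 8-d features instead of SAM's 256-d, and a sharper temperature than the diff's default 16 so the softmax is visibly peaked; `compute_r_ref` is an illustrative standalone copy):

```python
import torch
import torch.nn.functional as F

def compute_r_ref(q_init, feats, temp=16.0):
    # q_init: [1, D]; feats: [A, D, H, W] — frozen attention-pooled anchors.
    flat = feats.flatten(2)                                    # [A, D, P]
    q = F.normalize(q_init[0], dim=0)                          # [D]
    attn = torch.einsum('d,adp->ap', q, F.normalize(flat, dim=1))
    w_pos = torch.softmax(attn / temp, dim=-1)                 # most-attended positions
    w_neg = torch.softmax(-attn / temp, dim=-1)                # anti-attended positions
    r_ref = F.normalize(torch.einsum('ap,adp->ad', w_pos, flat).mean(0), dim=0)
    r_neg = F.normalize(torch.einsum('ap,adp->ad', w_neg, flat).mean(0), dim=0)
    return r_ref, r_neg

torch.manual_seed(0)
q = torch.randn(1, 8)
feats = torch.randn(2, 8, 4, 4)
r_ref, r_neg = compute_r_ref(q, feats, temp=0.1)
assert abs(r_ref.norm().item() - 1.0) < 1e-5       # pooled anchors are unit-norm
q_unit = F.normalize(q[0], dim=0)
assert (q_unit @ r_ref) > (q_unit @ r_neg)         # positive anchor aligns better with q
```

With cosine logits in [-1, 1], a temperature of 16 yields near-uniform weights (close to mean pooling); the toy uses 0.1 only to make the positive/negative split visible.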
+ def _task_reward_stage2_tether(
+     q: torch.Tensor,       # [1, 256] float32
+     lrm: torch.Tensor,     # [A,1,256,256] float32
+     iou: torch.Tensor,     # [A,1] float32
+     r_ref: torch.Tensor,   # [256] frozen
+     r_neg: torch.Tensor,   # [256] frozen
+     cfg: QLTPOConfig,
+     e0: float = 1.0,
+ ) -> torch.Tensor:
+     """Stage 21 (P1a tether): Stage 1 + R_tether.
+
+     R_tether = cos(q, r_ref) - beta·cos(q, r_neg)
+     q is pulled toward the frozen visual anchor without touching mask features.
+     Tests whether a fixed external reference stabilizes the optimization trajectory.
+     """
+     r_s1 = _task_reward_stage1(lrm, iou, cfg, e0)
+     q_norm = F.normalize(q[0], dim=0)
+     r_tether = q_norm @ r_ref - cfg.beta_align * (q_norm @ r_neg)
+     return r_s1 + cfg.lambda_align * r_tether
+
+
+ def _task_reward_stage2_faithful(
+     q: torch.Tensor,                         # [1, 256] float32
+     lrm: torch.Tensor,                       # [A,1,256,256] float32
+     iou: torch.Tensor,                       # [A,1] float32
+     image_embeds_anchor_fp32: torch.Tensor,  # [A, 256, 64, 64] float32
+     r_ref: torch.Tensor,                     # [256] frozen
+     cfg: QLTPOConfig,
+     e0: float = 1.0,
+ ) -> torch.Tensor:
+     """Stage 22 (P1b faithful): Stage 1 + R_faithful.
+
+     R_faithful = mean_t[ cos(z_in(q,t), r_ref) - beta·cos(z_out(q,t), r_ref) ]
+     z_in/z_out come from the *current* mask (change during optimization), but the
+     teacher r_ref is frozen — breaking Stage 2's self-confirming bias while keeping
+     the same structural form (mask-region vs. reference alignment).
+     """
+     r_s1 = _task_reward_stage1(lrm, iou, cfg, e0)
+     A = lrm.shape[0]
+     masks_64 = F.interpolate(
+         torch.sigmoid(lrm.squeeze(1) / cfg.area_temp).unsqueeze(1),
+         size=(64, 64), mode="bilinear", align_corners=False,
+     ).squeeze(1)  # [A, 64, 64]
+
+     r_align = torch.tensor(0.0, device=q.device)
+     for t in range(A):
+         m = masks_64[t].detach()  # stopgrad on mask weights only
+         img = image_embeds_anchor_fp32[t]  # [256, 64, 64]
+         z_in = F.normalize((img * m.unsqueeze(0)).sum(dim=[1, 2]) / (m.sum() + 1e-6), dim=0)
+         z_out = F.normalize((img * (1 - m).unsqueeze(0)).sum(dim=[1, 2]) / ((1 - m).sum() + 1e-6), dim=0)
+         # teacher is r_ref (frozen), not z_in itself — no confirmation bias
+         r_align = r_align + z_in @ r_ref - cfg.beta_align * (z_out @ r_ref)
+     r_align = r_align / A
+
+     return r_s1 + cfg.lambda_align * r_align
+
+
+ def _decode_on_anchors_diff_adaptive(
+     q_global: torch.Tensor,                  # [1, 256] float32, requires_grad
+     delta: torch.Tensor,                     # [A, 256] float32, requires_grad
+     image_embeds_anchor_fp32: torch.Tensor,  # [A, 256, 64, 64] float32, detached
+     dense_emb_fp32: torch.Tensor,            # [1, 256, 64, 64] float32, detached
+     mask_decoder,
+     dense_pe_fp32: torch.Tensor,             # [1, 256, 64, 64] float32, detached
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """Frame-adaptive differentiable decode: each anchor t uses q_t = q_global + delta[t].
+
+     Loops over A anchors to preserve gradient flow through both q_global and delta.
+     Returns low_res_masks [A,1,256,256] and iou_preds [A,1], both float32.
+     """
+     A = image_embeds_anchor_fp32.shape[0]
+     lrm_list: List[torch.Tensor] = []
+     iou_list: List[torch.Tensor] = []
+     for t in range(A):
+         q_t = q_global + delta[t : t + 1]  # [1, 256]
+         sparse_emb = q_t.unsqueeze(1)      # [1, 1, 256]
+         lrm_t, iou_t = mask_decoder(
+             image_embeddings=image_embeds_anchor_fp32[t : t + 1],
+             image_pe=dense_pe_fp32,
+             sparse_prompt_embeddings=sparse_emb,
+             dense_prompt_embeddings=dense_emb_fp32,
+             multimask_output=False,
+         )  # [1,1,256,256], [1,1]
+         lrm_list.append(lrm_t)
+         iou_list.append(iou_t)
+     return torch.cat(lrm_list, dim=0), torch.cat(iou_list, dim=0)  # [A,1,256,256], [A,1]
+
+
+ def _task_reward_frame_adaptive(
+     lrm: torch.Tensor,    # [A, 1, 256, 256] float32
+     iou: torch.Tensor,    # [A, 1] float32
+     cfg: "QLTPOConfig",
+     e0_vec: List[float],  # per-anchor existence priors [A]
+ ) -> torch.Tensor:
+     """Per-anchor task reward averaged over anchors (no regularization)."""
+     A = lrm.shape[0]
+     R = torch.tensor(0.0, device=lrm.device)
+     for t in range(A):
+         r_iou_t = iou[t].mean()
+         r_area_t = torch.sigmoid(lrm[t] / cfg.area_temp).mean()
+         R = R + cfg.lambda_iou * e0_vec[t] * r_iou_t - cfg.lambda_area * r_area_t
+     return R / A
+
+
+ def _compute_full_reward_adaptive(
+     q_global: torch.Tensor,  # [1, 256]
+     delta: torch.Tensor,     # [A, 256]
+     lrm: torch.Tensor,       # [A, 1, 256, 256]
+     iou: torch.Tensor,       # [A, 1]
+     q_init: torch.Tensor,    # [1, 256] detached
+     cfg: "QLTPOConfig",
+     e0_vec: List[float],
+ ) -> torch.Tensor:
+     """Full adaptive reward = task + residual penalty + temporal smoothness + L2 reg."""
+     r_task = _task_reward_frame_adaptive(lrm, iou, cfg, e0_vec)
+     r_delta = delta.pow(2).sum()
+     r_reg = (q_global - q_init).pow(2).sum()
+     R = r_task - cfg.lambda_residual * r_delta - cfg.lambda_reg * r_reg
+
+     A = delta.shape[0]
+     if A > 1 and cfg.lambda_smooth_temp > 0.0:
+         r_smooth = torch.tensor(0.0, device=delta.device)
+         for t in range(A - 1):
+             r_smooth = r_smooth + (delta[t] - delta[t + 1]).pow(2).sum()
+         R = R - cfg.lambda_smooth_temp * r_smooth / (A - 1)
+
+     return R
+
+
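The two delta penalties in the adaptive reward behave differently: the residual term penalises magnitude regardless of pattern, while the smoothness term fires only on adjacent-anchor differences. A small sketch (vectorized adjacent diff instead of the diff's per-pair loop; `adaptive_penalties` is an illustrative name):

```python
import torch

def adaptive_penalties(delta, lam_res=0.001, lam_smooth=0.1):
    # delta: [A, D] per-anchor residuals
    pen = lam_res * delta.pow(2).sum()                 # soft L2 on every residual
    A = delta.shape[0]
    if A > 1 and lam_smooth > 0.0:
        r_smooth = (delta[:-1] - delta[1:]).pow(2).sum()  # adjacent-anchor differences
        pen = pen + lam_smooth * r_smooth / (A - 1)
    return pen

smooth = torch.ones(4, 8)      # identical residuals on every anchor: no temporal jitter
jitter = torch.ones(4, 8)
jitter[1::2] = -1.0            # alternating sign: maximal adjacent change, same magnitude
assert adaptive_penalties(smooth) < adaptive_penalties(jitter)
```

Both toy inputs have identical residual norms, so the gap comes entirely from the temporal-smoothness term.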
  def _compute_task_reward(
      q: torch.Tensor,
      lrm: torch.Tensor,

      image_embeds_anchor_fp32: torch.Tensor,
      cfg: QLTPOConfig,
      e0: float = 1.0,
+     r_ref: Optional[torch.Tensor] = None,
+     r_neg: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
      """Dispatch to the correct stage's task reward."""
      if cfg.stage == 1:
          return _task_reward_stage1(lrm, iou, cfg, e0)
      if cfg.stage == 2:
          return _task_reward_stage2(q, lrm, iou, image_embeds_anchor_fp32, cfg, e0)
+     if cfg.stage == 21:
+         assert r_ref is not None and r_neg is not None, "stage 21 requires r_ref/r_neg"
+         return _task_reward_stage2_tether(q, lrm, iou, r_ref, r_neg, cfg, e0)
+     if cfg.stage == 22:
+         assert r_ref is not None, "stage 22 requires r_ref"
+         return _task_reward_stage2_faithful(q, lrm, iou, image_embeds_anchor_fp32, r_ref, cfg, e0)
      return _task_reward_stage3(q, lrm, iou, image_embeds_anchor_fp32, cfg, e0)


      q_init: torch.Tensor,
      cfg: QLTPOConfig,
      e0: float = 1.0,
+     r_ref: Optional[torch.Tensor] = None,
+     r_neg: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
      """Full reward = task reward + L2 regularization (used for backward)."""
+     r_task = _compute_task_reward(q, lrm, iou, image_embeds_anchor_fp32, cfg, e0, r_ref, r_neg)
      r_reg = (q - q_init).pow(2).sum()
      return r_task - cfg.lambda_reg * r_reg

  }


+ # ---------------------------------------------------------------------------
+ # AVT proxy reward (Step A0: reward–metric correlation study)
+ # ---------------------------------------------------------------------------
+
+ @torch.no_grad()
+ def _compute_avt_proxy_reward(
+     q_init_fp32: torch.Tensor,               # [1, 256] — frozen AVT anchor (= Fseg)
+     lrm: torch.Tensor,                       # [A, 1, 256, 256] float32
+     image_embeds_anchor_fp32: torch.Tensor,  # [A, 256, 64, 64] float32
+     cfg: "QLTPOConfig",
+     beta: float = 0.5,
+ ) -> Tuple[float, float]:
+     """Task-specific proxy reward using frozen q_init (Fseg) as teacher.
+
+     q_init = Fseg is already the audio+video+text fusion token produced by SimToken.
+     Using it as a frozen reference breaks Stage 2's self-confirming bias while
+     measuring whether the mask region aligns with the correct referent.
+
+     Returns:
+         R_avt   = mean_t cos(z_in_t, q_init)                               [scalar]
+         R_avt_c = mean_t [cos(z_in_t, q_init) - beta·cos(z_out_t, q_init)] [scalar]
+     """
+     A = lrm.shape[0]
+     q_norm = F.normalize(q_init_fp32[0], dim=0)  # [256]
+
+     masks_64 = F.interpolate(
+         torch.sigmoid(lrm.squeeze(1) / cfg.area_temp).unsqueeze(1),
+         size=(64, 64), mode="bilinear", align_corners=False,
+     ).squeeze(1)  # [A, 64, 64]
+
+     r_avt, r_avt_c = 0.0, 0.0
+     for t in range(A):
+         m = masks_64[t]
+         img = image_embeds_anchor_fp32[t]
+         z_in = F.normalize(
+             (img * m.unsqueeze(0)).sum(dim=[1, 2]) / (m.sum() + 1e-6), dim=0
+         )
+         z_out = F.normalize(
+             (img * (1.0 - m).unsqueeze(0)).sum(dim=[1, 2]) / ((1.0 - m).sum() + 1e-6), dim=0
+         )
+         c_in = (q_norm @ z_in).item()
+         c_out = (q_norm @ z_out).item()
+         r_avt += c_in
+         r_avt_c += c_in - beta * c_out
+     return r_avt / A, r_avt_c / A
+
+
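The in/out pooling that the AVT proxy relies on can be demonstrated on a toy feature map where the left half carries the referent feature and the right half its negation. A sketch (tiny 8-d features, hypothetical `avt_proxy` helper mirroring the masked-pooling logic):

```python
import torch
import torch.nn.functional as F

def avt_proxy(q, feats, mask, beta=0.5):
    # q: [D] frozen teacher; feats: [D, H, W]; mask: [H, W] soft mask in [0, 1]
    q = F.normalize(q, dim=0)
    z_in = F.normalize((feats * mask).sum(dim=[1, 2]) / (mask.sum() + 1e-6), dim=0)
    z_out = F.normalize((feats * (1 - mask)).sum(dim=[1, 2]) / ((1 - mask).sum() + 1e-6), dim=0)
    return (q @ z_in).item(), (q @ z_in - beta * (q @ z_out)).item()

D, H, W = 8, 4, 4
q = torch.randn(D)
feats = (-q).view(D, 1, 1).expand(D, H, W).clone()  # background: anti-referent feature
feats[:, :, :2] = q.view(D, 1, 1)                   # left half: the referent feature
mask_good = torch.zeros(H, W)
mask_good[:, :2] = 1.0                              # covers the referent region
mask_bad = 1.0 - mask_good                          # covers the background
r_good, _ = avt_proxy(q, feats, mask_good)
r_bad, _ = avt_proxy(q, feats, mask_bad)
assert r_good > r_bad  # mask on the referent aligns with the frozen teacher
```

Because the teacher `q` never changes, improving this score requires moving the mask onto referent-like features rather than redefining the reference, which is the bias the diff's docstring describes.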
  # ---------------------------------------------------------------------------
  # Stage 1–3: q-LTPO-autograd main optimizer
  # ---------------------------------------------------------------------------

      lr = cfg.lr if cfg.lr > 0 else 0.01 * rms.item()
      max_drift = cfg.max_drift if cfg.max_drift > 0 else 0.5 * q_init_fp32.norm().item()

+     # ── Precompute frozen external reference (stages 21, 22 only) ────────
+     r_ref, r_neg = None, None
+     if cfg.stage in (21, 22):
+         r_ref, r_neg = _compute_r_ref(q_init_fp32, image_embeds_anchor, cfg.r_ref_temp)
+
      # ── Baseline forward + e0 existence prior ────────────────────────────
      with torch.no_grad():
          lrm0, iou0 = _decode_on_anchors_diff(

      e0 = _compute_e0(r_area_soft_init, cfg)

      R_init_task = _compute_task_reward(
+         q_init_fp32, lrm0, iou0, image_embeds_anchor, cfg, e0=e0,
+         r_ref=r_ref, r_neg=r_neg,
      ).item()

      # ── Optimisation setup ────────────────────────────────────────────────

      hit_clip = False

      # ── Optimisation loop ─────────────────────────────────────────────────
+     # Track per-step soft area to diagnose whether B1 penalty ever activates.
+     _step_soft_areas: List[float] = []
+
      for step in range(cfg.T):
          optimizer.zero_grad()

          lrm, iou = _decode_on_anchors_diff(
              q, image_embeds_anchor, dense_emb, mask_dec, dense_pe
          )
+         R_full = _compute_full_reward(q, lrm, iou, image_embeds_anchor, q_init_fp32, cfg, e0=e0,
+                                       r_ref=r_ref, r_neg=r_neg)
          R_full.backward()
          optimizer.step()

          lrm_eval, iou_eval = _decode_on_anchors_diff(
              q.detach(), image_embeds_anchor, dense_emb, mask_dec, dense_pe
          )
+         # Record soft area at this step for B1 activation diagnosis
+         _step_soft_areas.append(
+             torch.sigmoid(lrm_eval / cfg.area_temp).mean().item()
+         )
          r_task = _compute_task_reward(
+             q.detach(), lrm_eval, iou_eval, image_embeds_anchor, cfg, e0=e0,
+             r_ref=r_ref, r_neg=r_neg,
          ).item()
          if r_task > best_reward:
              best_reward = r_task
              best_q = q.detach().clone()

+     # Peak excess: how much did soft area exceed e0 at its highest point?
+     #   b1_peak_excess > 0 ↔ B1 ReLU was non-zero at that step.
+     #   b1_peak_excess = 0 ↔ B1 never activated (area stayed below e0 throughout).
+     _max_step_area = max(_step_soft_areas) if _step_soft_areas else r_area_soft_init
+     b1_peak_excess = max(_max_step_area - e0, 0.0)
+
      # ── Reward gating: clean re-eval of best_q vs q_init ─────────────────
      with torch.no_grad():
          lrm_b, iou_b = _decode_on_anchors_diff(
              best_q, image_embeds_anchor, dense_emb, mask_dec, dense_pe
          )
          R_best_task = _compute_task_reward(
+             best_q, lrm_b, iou_b, image_embeds_anchor, cfg, e0=e0,
+             r_ref=r_ref, r_neg=r_neg,
          ).item()

      area_init = (lrm0 > 0).float().mean().item()

      )
      accepted = R_best_task > R_init_task + effective_gate

+     # ── Mask soft-IoU: how much did the mask actually change? ─────────────
+     # Answers whether q-drift translated into mask change, or fell in a
+     # flat direction of the mask decoder manifold.
+     with torch.no_grad():
+         m0 = torch.sigmoid(lrm0 / cfg.area_temp).squeeze(1)   # [A,256,256]
+         mb = torch.sigmoid(lrm_b / cfg.area_temp).squeeze(1)  # [A,256,256]
+         inter = (m0 * mb).sum(dim=[1, 2])
+         union = (m0 + mb - m0 * mb).sum(dim=[1, 2])
+         mask_soft_iou = (inter / (union + 1e-6)).mean().item()
+
+     # Soft area at best_q — tracks whether B1 asymmetric penalty worked
+     r_area_soft_best = mb.mean().item()  # sigmoid(lrm_b/area_temp).mean()
+
+     # Reward decomposition: iou contribution to reward gain
+     R_iou_contrib_gain = (
+         cfg.lambda_iou * e0 * (iou_b.mean().item() - iou0.mean().item())
+     )
+
+     # AVT proxy reward (Step A0 correlation study)
+     r_avt_init, r_avt_c_init = _compute_avt_proxy_reward(
+         q_init_fp32, lrm0, image_embeds_anchor, cfg
+     )
+     r_avt_best, r_avt_c_best = _compute_avt_proxy_reward(
+         q_init_fp32, lrm_b, image_embeds_anchor, cfg
+     )
+
      # ── Per-sample diagnostics ────────────────────────────────────────────
      _q_ltpo_stats.append({
+         "accepted": accepted,
+         "reward_gain": R_best_task - R_init_task,
+         "drift": (best_q - q_init_fp32).norm().item(),
+         "hit_clip": hit_clip,
+         "e0": e0,
+         "R_iou_pred_init": iou0.mean().item(),
+         "R_iou_pred_best": iou_b.mean().item(),
+         "area_hard_init": area_init,
+         "area_hard_best": (lrm_b > 0).float().mean().item(),
+         "r_area_soft_init": r_area_soft_init,
+         "r_area_soft_best": r_area_soft_best,
+         "b1_peak_excess": b1_peak_excess,
+         "mask_soft_iou": mask_soft_iou,
+         "R_iou_contrib_gain": R_iou_contrib_gain,
+         # AVT proxy: frozen q_init as teacher — task-specific alignment
+         "r_avt_init": r_avt_init,
+         "r_avt_best": r_avt_best,
+         "r_avt_gain": r_avt_best - r_avt_init,
+         "r_avt_c_init": r_avt_c_init,
+         "r_avt_c_best": r_avt_c_best,
+         "r_avt_c_gain": r_avt_c_best - r_avt_c_init,
      })

      if not accepted:
          return F_init.float()
      return best_q
+
+
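The `mask_soft_iou` diagnostic added above is a plain soft IoU between the init and best mask probabilities. A standalone sketch (hypothetical `soft_iou` helper; hard toy squares with strongly saturated logits so identical masks score near 1):

```python
import torch

def soft_iou(logits_a, logits_b, temp=5.0):
    # Soft IoU between two mask logit maps [A, H, W]; ≈1 for identical hard masks.
    ma = torch.sigmoid(logits_a / temp)
    mb = torch.sigmoid(logits_b / temp)
    inter = (ma * mb).sum(dim=[1, 2])
    union = (ma + mb - ma * mb).sum(dim=[1, 2])
    return (inter / (union + 1e-6)).mean().item()

a = torch.full((1, 8, 8), -50.0)
a[0, 2:6, 2:6] = 50.0                 # 4×4 square of confident foreground
b = a.clone()                         # unchanged mask
c = torch.roll(a, shifts=3, dims=2)   # square shifted 3 columns
assert soft_iou(a, b) > 0.99          # no mask change → near 1
assert soft_iou(a, c) < 0.5           # shifted mask → low overlap
```

A value near 1 with nonzero q-drift is exactly the "flat decoder direction" case the diff's comment describes: the token moved but the decoded mask did not.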
1190
+ # ===========================================================================
1191
+ # Direction II: Frame-adaptive token optimization (stage=4)
1192
+ # q_t = q_global + delta_t — shared global token + per-anchor residual
1193
+ # ===========================================================================
1194
+
1195
+ def q_ltpo_frame_adaptive(
1196
+ F_init: torch.Tensor, # [1, 256] any dtype on CUDA
1197
+ image_embeds: torch.Tensor, # [T, 256, 64, 64] any dtype on CUDA
1198
+ anchor_indices: List[int],
1199
+ sam_model,
1200
+ model_dtype: torch.dtype,
1201
+ cfg: QLTPOConfig,
1202
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1203
+ """Frame-adaptive q-LTPO: optimize q_global and per-anchor delta jointly.
1204
+
1205
+ Each anchor frame t gets its own token q_t = q_global + delta_t.
1206
+ delta_t is initialized to zero so q_t starts equal to q_init for all frames.
1207
+ Per-frame existence priors e0_t suppress optimization on near-empty anchors.
1208
+
1209
+ Returns:
1210
+ q_global [1, 256] float32 — shared global token
1211
+ delta [A, 256] float32 — per-anchor residuals (zero if not accepted)
1212
+ """
1213
+ device = F_init.device
1214
+ A = len(anchor_indices)
1215
+
1216
+ q_init_fp32 = F_init.float().detach()
1217
+ image_embeds_anchor = image_embeds[anchor_indices].float().detach()
1218
+ dense_emb = _precompute_dense_emb(sam_model, model_dtype, device).float().detach()
1219
+ dense_pe = sam_model.prompt_encoder.get_dense_pe().to(device).float().detach()
1220
+ mask_dec = sam_model.mask_decoder
1221
+
1222
+ rms = q_init_fp32.norm() / (q_init_fp32.numel() ** 0.5)
1223
+ lr = cfg.lr if cfg.lr > 0 else 0.01 * rms.item()
1224
+ max_drift = cfg.max_drift if cfg.max_drift > 0 else 0.5 * q_init_fp32.norm().item()
1225
+ max_delta_drift = cfg.max_delta_drift_scale * q_init_fp32.norm().item()
1226
+
1227
+ # ── Baseline: per-anchor e0 existence priors ────────────────────────────
1228
+ with torch.no_grad():
1229
+ lrm0, iou0 = _decode_on_anchors_diff(
1230
+ q_init_fp32, image_embeds_anchor, dense_emb, mask_dec, dense_pe
1231
+ )
1232
+ e0_vec: List[float] = []
1233
+ for t in range(A):
1234
+ e0_t = torch.sigmoid(lrm0[t] / cfg.area_temp).mean().item()
1235
+ e0_vec.append(_compute_e0(e0_t, cfg))
1236
+ e0_global = sum(e0_vec) / A
1237
+
1238
+    R_init_task = _task_reward_frame_adaptive(lrm0, iou0, cfg, e0_vec).item()
+
+    # ── Setup optimization ───────────────────────────────────────────────────
+    q_global = torch.nn.Parameter(q_init_fp32.clone())
+    delta = torch.nn.Parameter(torch.zeros(A, 256, device=device, dtype=torch.float32))
+    optimizer = torch.optim.Adam([q_global, delta], lr=lr, maximize=True)
+
+    best_q_global = q_global.detach().clone()
+    best_delta = delta.detach().clone()
+    best_reward = R_init_task
+    hit_clip = False
+
+    # ── Optimization loop ────────────────────────────────────────────────────
+    for step in range(cfg.T):
+        optimizer.zero_grad()
+        lrm, iou = _decode_on_anchors_diff_adaptive(
+            q_global, delta, image_embeds_anchor, dense_emb, mask_dec, dense_pe
+        )
+        R_full = _compute_full_reward_adaptive(
+            q_global, delta, lrm, iou, q_init_fp32, cfg, e0_vec
+        )
+        R_full.backward()
+        optimizer.step()
+
+        # Clip q_global and each per-anchor delta within trust regions
+        with torch.no_grad():
+            diff = q_global - q_init_fp32
+            d = diff.norm()
+            if d > max_drift:
+                q_global.copy_(q_init_fp32 + diff * (max_drift / d))
+                hit_clip = True
+            for t in range(A):
+                dn = delta[t].norm()
+                if dn > max_delta_drift:
+                    delta[t].copy_(delta[t] * (max_delta_drift / dn))
+
+        # Track best (no_grad re-eval of task reward without reg)
+        with torch.no_grad():
+            lrm_eval, iou_eval = _decode_on_anchors_diff_adaptive(
+                q_global.detach(), delta.detach(),
+                image_embeds_anchor, dense_emb, mask_dec, dense_pe
+            )
+            r_task = _task_reward_frame_adaptive(lrm_eval, iou_eval, cfg, e0_vec).item()
+            if r_task > best_reward:
+                best_reward = r_task
+                best_q_global = q_global.detach().clone()
+                best_delta = delta.detach().clone()
+
+    # ── Gating ───────────────────────────────────────────────────────────────
+    with torch.no_grad():
+        lrm_b, iou_b = _decode_on_anchors_diff_adaptive(
+            best_q_global, best_delta, image_embeds_anchor, dense_emb, mask_dec, dense_pe
+        )
+        R_best_task = _task_reward_frame_adaptive(lrm_b, iou_b, cfg, e0_vec).item()
+
+    accepted = R_best_task > R_init_task + cfg.gate_delta
+
+    area_init = (lrm0 > 0).float().mean().item()
+    r_area_soft_init = sum(torch.sigmoid(lrm0[t] / cfg.area_temp).mean().item() for t in range(A)) / A
+    r_area_soft_best = sum(torch.sigmoid(lrm_b[t] / cfg.area_temp).mean().item() for t in range(A)) / A
+
+    # Actual mask soft-IoU between init and best (per anchor, averaged)
+    m0 = torch.sigmoid(lrm0 / cfg.area_temp).squeeze(1)  # [A,256,256]
+    mb = torch.sigmoid(lrm_b / cfg.area_temp).squeeze(1)  # [A,256,256]
+    inter = (m0 * mb).sum(dim=[1, 2])
+    union = (m0 + mb - m0 * mb).sum(dim=[1, 2])
+    mask_soft_iou_fa = (inter / (union + 1e-6)).mean().item()
+
+    _q_ltpo_stats.append({
+        "accepted": accepted,
+        "reward_gain": R_best_task - R_init_task,
+        "drift": (best_q_global - q_init_fp32).norm().item(),
+        "delta_norm": best_delta.norm().item(),
+        "hit_clip": hit_clip,
+        "e0": e0_global,
+        "R_iou_pred_init": iou0.mean().item(),
+        "R_iou_pred_best": iou_b.mean().item(),
+        "area_hard_init": area_init,
+        "area_hard_best": (lrm_b > 0).float().mean().item(),
+        "r_area_soft_init": r_area_soft_init,
+        "r_area_soft_best": r_area_soft_best,
+        "b1_peak_excess": 0.0,
+        "mask_soft_iou": mask_soft_iou_fa,
+        "R_iou_contrib_gain": cfg.lambda_iou * e0_global * (iou_b.mean().item() - iou0.mean().item()),
+    })
+
+    if not accepted:
+        return q_init_fp32, torch.zeros(A, 256, device=device, dtype=torch.float32)
+    return best_q_global, best_delta
+
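The trust-region clipping inside the optimization loop above projects the optimized token back into an L2 ball around its initialization. A minimal standalone sketch of that projection (the function name `project_to_l2_ball` is illustrative, not part of the module):

```python
import torch

def project_to_l2_ball(vec: torch.Tensor, center: torch.Tensor, radius: float) -> torch.Tensor:
    """Project `vec` onto the L2 ball of `radius` around `center`.

    Mirrors the in-place clipping in the loop above: if the drift
    ||vec - center|| exceeds the trust-region radius, rescale the drift so the
    result lies exactly on the ball's boundary; otherwise return vec unchanged.
    """
    diff = vec - center
    d = diff.norm()
    if d > radius:
        return center + diff * (radius / d)
    return vec

# Example: a point drifting norm-5.0 from the origin, clipped to radius 2.0
center = torch.zeros(2)
vec = torch.tensor([3.0, 4.0])  # norm 5.0
clipped = project_to_l2_ball(vec, center, radius=2.0)
# clipped has norm 2.0; the drift direction is preserved
```

The same projection is applied independently to `q_global` (radius `max_drift`) and to each per-anchor row of `delta` (radius `max_delta_drift`, with `center = 0`).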
+
+
+def decode_full_video_adaptive(
+    q_global: torch.Tensor,      # [1, 256] float32
+    delta: torch.Tensor,         # [A, 256] float32
+    anchor_indices: List[int],
+    image_embeds: torch.Tensor,  # [T, 256, 64, 64] model dtype on CUDA
+    sam_model,
+    resize: tuple,
+    orgsize: tuple,
+    model_dtype: torch.dtype,
+) -> torch.Tensor:
+    """Decode all T frames with frame-adaptive tokens.
+
+    Each frame is assigned to its nearest anchor by index distance, then decoded
+    with q_t = q_global + delta[anchor_idx].
+    Returns raw logit masks [T, H_orig, W_orig].
+    """
+    T = image_embeds.shape[0]
+    A = len(anchor_indices)
+    device = image_embeds.device
+
+    dense_emb = _precompute_dense_emb(sam_model, model_dtype, device)
+    dense_pe = sam_model.prompt_encoder.get_dense_pe().to(device)
+
+    # Nearest-anchor assignment for every frame
+    anchor_arr = torch.tensor(anchor_indices, dtype=torch.float32)
+    frame_to_anchor = [int((anchor_arr - t).abs().argmin().item()) for t in range(T)]
+
+    pred_masks: List[torch.Tensor] = []
+    with torch.no_grad():
+        for t in range(T):
+            a = frame_to_anchor[t]
+            q_t = (q_global + delta[a : a + 1]).to(model_dtype)  # [1, 256]
+            sparse_emb = q_t.unsqueeze(1)  # [1, 1, 256]
+            lrm_t, _ = sam_model.mask_decoder(
+                image_embeddings=image_embeds[t : t + 1],
+                image_pe=dense_pe,
+                sparse_prompt_embeddings=sparse_emb,
+                dense_prompt_embeddings=dense_emb,
+                multimask_output=False,
+            )  # [1, 1, 256, 256]
+            pred_t = sam_model.postprocess_masks(lrm_t, input_size=resize, original_size=orgsize)
+            pred_masks.append(pred_t.squeeze(0).squeeze(0))  # [H, W]
+
+    return torch.stack(pred_masks, dim=0)  # [T, H_orig, W_orig]
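The nearest-anchor assignment in `decode_full_video_adaptive` can be illustrated in isolation. The helper name `assign_frames_to_anchors` below is made up for the sketch; the logic matches the `frame_to_anchor` list comprehension above:

```python
import torch

def assign_frames_to_anchors(anchor_indices, num_frames):
    """Map each frame index t to the anchor minimizing |anchor - t|.

    Same nearest-anchor rule as decode_full_video_adaptive: the returned list
    gives, per frame, the row of `delta` used when decoding that frame.
    Ties are resolved by torch.argmin's tie-breaking behavior.
    """
    anchor_arr = torch.tensor(anchor_indices, dtype=torch.float32)
    return [int((anchor_arr - t).abs().argmin().item()) for t in range(num_frames)]

# 8-frame clip with anchors at frames 0, 3, 6: frames 0-1 use anchor 0,
# frames 2-4 use anchor 1, frames 5-7 use anchor 2
mapping = assign_frames_to_anchors([0, 3, 6], 8)
# mapping == [0, 0, 1, 1, 1, 2, 2, 2]
```

With a single anchor, every frame maps to it and the decode degenerates to a global token `q_global + delta[0]` for the whole clip.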