"""VLAlert-X v2 Phase 2: dual-stream cache extractor (leak-free).

For each (video, 8-frame) tick, build a prompt that contains the per-frame
BELIEF reasoning text but NO action tokens (this is the key: GT actions
never enter causal attention so neither stream leaks).

    Scene: ...                           (optional, from manifest)
    Critical: ...                        (optional)
    <|BELIEF|> {belief_text_0} </|BELIEF|>
    <|BELIEF|> {belief_text_1} </|BELIEF|>
    ...
    <|BELIEF|> {belief_text_7} </|BELIEF|>

Forward through Qwen3-VL-4B (SFT'd, `checkpoints/sft_x_v2/best`) with
`output_hidden_states=True`, then extract two complementary features per frame:

  (A) BELIEF_CONTENT[f]   "perception/risk-cue register"
      = mean-pool hidden states over tokens BETWEEN
        the f-th `<|BELIEF|>` and the matching `</|BELIEF|>`,
        EXCLUDING the two tags themselves.
      Concat hidden_states from layers {20, 24, 28, 32}.
      shape: [8, 4 × 2560] = [8, 10240]

  (B) POLICY_POSITION[f]  "decision-time register"
      = hidden state AT the position of the f-th `</|BELIEF|>` closing tag.
      Single layer 33.
      shape: [8, 2560]

The hidden state AT `</|BELIEF|>` is the one the SFT model used to predict
the next token (= the action). At that position the model has just finished
reading the belief reasoning and is about to emit the action, so the hidden
state encodes its commitment state.
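
A minimal sketch of the two registers on a toy sequence, assuming hypothetical
tag ids (`OPEN_ID`, `CLOSE_ID`) and plain Python lists in place of hidden-state
tensors. It only illustrates the indexing for a single layer; the extractor
below works on real model outputs and concatenates several layers for (A).

```python
from typing import List, Tuple

OPEN_ID, CLOSE_ID = 1001, 1002  # hypothetical token ids for the two BELIEF tags


def belief_spans(input_ids: List[int]) -> List[Tuple[int, int]]:
    # Pair the f-th open tag with the f-th close tag as (open_pos, close_pos).
    opens = [i for i, t in enumerate(input_ids) if t == OPEN_ID]
    closes = [i for i, t in enumerate(input_ids) if t == CLOSE_ID]
    return list(zip(opens, closes))


def dual_pool(hidden: List[List[float]], input_ids: List[int]):
    # Returns (belief_content, policy_position), one entry per frame.
    belief, policy = [], []
    for o, c in belief_spans(input_ids):
        inner = hidden[o + 1:c]          # tokens strictly between the tags
        if not inner:                    # empty belief text: skip this frame
            continue
        dim = len(inner[0])
        belief.append([sum(row[d] for row in inner) / len(inner)
                       for d in range(dim)])
        policy.append(hidden[c])         # state AT the closing tag
    return belief, policy


ids = [7, OPEN_ID, 11, 12, CLOSE_ID, OPEN_ID, 13, CLOSE_ID]
hid = [[float(i), 2.0 * i] for i in range(len(ids))]  # toy 2-d "hidden states"
belief, policy = dual_pool(hid, ids)
print(belief)   # [[2.5, 5.0], [6.0, 12.0]]
print(policy)   # [[4.0, 8.0], [7.0, 14.0]]
```

The tags themselves never enter the mean in (A), while (B) reads exactly the
closing-tag position.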

Output cache:
    data/belief_cache_v2/{tag}__{split}.pt = {
        "ids":              list[str]                (N,)
        "belief_content":   tensor   [N, 8, 10240]   fp16
        "policy_position":  tensor   [N, 8, 2560]    fp16
        "valid_frames":     tensor   [N, 8]          bool
        "actions_pf":       tensor   [N, 8]          long
        "danger_pf":        tensor   [N, 8]          fp32
        "tta_pf":           tensor   [N, 8]          fp32
        "tick_action":      tensor   [N]             long
        "tick_tta_raw":     tensor   [N]             fp32
        "category":         list[str]
        "source":           list[str]
        "video_id":         list[str]
        "schema":           "vlalert_x_v2_dual_pool"
        "belief_layers":    [20, 24, 28, 32]
        "policy_layer":     33
        "pool_mode":        str                      (value of --pool_mode)
        "ckpt":             str                      (checkpoint directory)
    }
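
A rough size estimate for the two feature tensors, assuming fp16 storage
(2 bytes per value) and a hypothetical record count of 100,000. The small
label tensors and the id lists are ignored, so treat the result as a lower
bound on the file size.

```python
def cache_bytes(n_records, n_frames=8, hidden=2560, n_belief_layers=4,
                bytes_per_value=2):
    # belief_content: [N, n_frames, n_belief_layers * hidden] in fp16
    belief = n_records * n_frames * n_belief_layers * hidden * bytes_per_value
    # policy_position: [N, n_frames, hidden] in fp16
    policy = n_records * n_frames * hidden * bytes_per_value
    return belief, policy


belief_b, policy_b = cache_bytes(100_000)  # 100_000 is an illustrative N
print(belief_b, policy_b)  # 16384000000 4096000000
```

At these shapes belief_content costs 163,840 bytes per record, which is why
the save path below avoids making a second copy of it.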

Usage:
    python tools/make_cache_x_v2.py --split train
    python tools/make_cache_x_v2.py --split val
"""
from __future__ import annotations

# PR patch must run BEFORE Qwen3-VL import
import sys
sys.path.insert(0, ".")
from tools import run_train_cot_belief_fast  # noqa: F401

import argparse
import json
import logging
import time
from pathlib import Path
from typing import Dict, List, Tuple

import torch
from tqdm import tqdm

ROOT = Path(__file__).resolve().parents[1]
logging.basicConfig(level=logging.INFO,
                    format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("make_cache_x_v2")

ACTION_NAME_TO_IDX = {"SILENT": 0, "OBSERVE": 1, "ALERT": 2}


def build_extraction_assistant(beliefs_per_frame: List[str],
                                 scene: str = "",
                                 critical: str = "") -> str:
    """Same as SFT format_assistant_v2 but ACTION TOKENS REMOVED.

    This is the key leak-mitigation: at cache time the prompt has the
    belief reasoning content (perception, not decision) wrapped by
    `<|BELIEF|>...</|BELIEF|>` and NO `<|ACTION|>` tokens anywhere.
    Causal attention cannot leak GT actions because they don't exist.
    """
    from training.VLA.cot_belief_dataset_v2 import BELIEF_OPEN, BELIEF_CLOSE
    assert len(beliefs_per_frame) == 8
    lines: List[str] = []
    scene = (scene or "").strip()
    critical = (critical or "").strip()
    if scene:
        lines.append(f"Scene: {scene}")
    if critical:
        lines.append(f"Critical: {critical}")
    if lines:
        lines.append("")
    for b in beliefs_per_frame:
        # Flatten to one line and cap the belief text at 25 whitespace-separated words.
        b_clean = (b or "").strip().replace("\n", " ")
        b_clean = " ".join(b_clean.split()[:25])
        lines.append(f"{BELIEF_OPEN} {b_clean} {BELIEF_CLOSE}")
    return "\n".join(lines)


@torch.no_grad()
def extract_split(ckpt_dir: Path, base_model: Path,
                    manifest_path: Path, out_path: Path,
                    belief_layers: Tuple[int, ...] = (20, 24, 28, 32),
                    policy_layer: int = 33,
                    n_frames: int = 8,
                    limit: int = 0,
                    batch_size: int = 4,
                    pool_mode: str = "range",
                    random_span_seed: int = 0):
    if out_path.exists():
        logger.info(f"[skip] {out_path} exists; delete to re-extract")
        return

    from transformers import AutoProcessor, AutoModelForImageTextToText
    from peft import PeftModel
    from training.VLA.cot_belief_dataset_v2 import (
        ALL_SPECIAL, BELIEF_OPEN, BELIEF_CLOSE, build_chat_v2,
    )
    from training.VLA.frame_utils import sample_frames

    logger.info(f"[load] base_model={base_model} ckpt={ckpt_dir}")
    logger.info(f"       belief_layers={belief_layers}  policy_layer={policy_layer}  "
                f"batch_size={batch_size}")
    processor = AutoProcessor.from_pretrained(base_model, trust_remote_code=True)
    processor.tokenizer.add_special_tokens({"additional_special_tokens": ALL_SPECIAL})
    # IMPORTANT: right padding so BELIEF token positions stay correct in batched mode
    processor.tokenizer.padding_side = "right"
    model = AutoModelForImageTextToText.from_pretrained(
        base_model, dtype=torch.bfloat16, device_map="auto",
        trust_remote_code=True)
    model.resize_token_embeddings(len(processor.tokenizer))
    if (ckpt_dir / "adapter_config.json").exists():
        model = PeftModel.from_pretrained(model, ckpt_dir)
    model.eval()

    tok = processor.tokenizer
    belief_open_id  = tok.convert_tokens_to_ids(BELIEF_OPEN)
    belief_close_id = tok.convert_tokens_to_ids(BELIEF_CLOSE)
    logger.info(f"[tok] BELIEF_OPEN={belief_open_id}  BELIEF_CLOSE={belief_close_id}")

    # ── load manifest ──
    records: List[Dict] = []
    with open(manifest_path) as f:
        for ln in f:
            ln = ln.strip()
            if not ln:
                continue
            try:
                r = json.loads(ln)
            except json.JSONDecodeError:
                continue
            if (isinstance(r.get("beliefs_per_frame"), list)
                 and len(r["beliefs_per_frame"]) == n_frames
                 and r.get("video_path")):
                records.append(r)
    if limit > 0:
        records = records[:limit]
    logger.info(f"[load] {manifest_path}  n={len(records)}")

    # output tensors (lazy-alloc after first forward to know hidden_dim)
    N = len(records)
    n_belief_layers = len(belief_layers)
    out_belief: torch.Tensor = None   # [N, 8, n_belief_layers * D]
    out_policy: torch.Tensor = None   # [N, 8, D]
    out_valid = torch.zeros(N, n_frames, dtype=torch.bool)
    out_actions = torch.zeros(N, n_frames, dtype=torch.long)
    out_danger = torch.zeros(N, n_frames, dtype=torch.float32)
    out_tta = torch.zeros(N, n_frames, dtype=torch.float32)
    out_tick_action = torch.zeros(N, dtype=torch.long)
    out_tick_tta = torch.zeros(N, dtype=torch.float32)
    ids_list: List[str] = [None] * N
    cat_list: List[str] = [""] * N
    src_list: List[str] = [""] * N
    vid_list: List[str] = [""] * N

    n_failed = 0
    n_pool_fallback = 0
    t0 = time.time()

    def _prepare_one(rec):
        """Decode frames + build text for a single record. Returns
        (frames, full_text) or None on failure."""
        frames = sample_frames(rec["video_path"], n_frames=n_frames,
                                 resize_short=336,
                                 frame_indices=rec["frame_indices"])
        assistant_text = build_extraction_assistant(
            rec["beliefs_per_frame"],
            scene=rec.get("scene", ""),
            critical=rec.get("critical", ""),
        )
        full_msgs = build_chat_v2(frames, assistant_text=assistant_text)
        full_text = processor.apply_chat_template(
            full_msgs, tokenize=False, add_generation_prompt=False)
        return frames, full_text

    # Process in batches of `batch_size` for parallel GPU utilisation.
    # With batch_size=4 on Qwen3-VL-4B + Conv3d→Linear patch, expect ~3-4× the
    # batch=1 throughput on RTX 5090 with ≤30 GB VRAM.
    for batch_start in tqdm(range(0, N, batch_size), ncols=80, desc="cache_v2"):
        batch_end = min(N, batch_start + batch_size)
        batch_recs = records[batch_start:batch_end]

        # ── prepare batch (CPU: decode + tokenize text) ──
        batch_frames = []
        batch_texts  = []
        keep_idx     = []                # indices within this batch that succeeded prep
        for j, rec in enumerate(batch_recs):
            try:
                frames, full_text = _prepare_one(rec)
                batch_frames.append(frames)
                batch_texts.append(full_text)
                keep_idx.append(j)
            except Exception as e:
                n_failed += 1
                logger.warning(f"[skip] {rec.get('id')}: {e}")
                global_i = batch_start + j
                ids_list[global_i] = rec.get("id", str(global_i))

        if not keep_idx:
            continue

        try:
            # batched tokenisation (right padding, so BELIEF positions stay correct)
            inputs = processor(text=batch_texts, images=batch_frames,
                                 return_tensors="pt", padding=True,
                                 truncation=True, max_length=4096)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}

            out = model(**inputs, output_hidden_states=True, return_dict=True)
            hs_tuple = out.hidden_states                 # tuple of [B, T, D]
            ids_b_all   = inputs["input_ids"]            # [B, T]
            attn_b_all  = inputs["attention_mask"]       # [B, T]
            D = hs_tuple[-1].shape[-1]
        except torch.cuda.OutOfMemoryError as e:
            logger.error(f"[OOM] batch {batch_start}..{batch_end}: {e}")
            torch.cuda.empty_cache()
            n_failed += len(keep_idx)
            for j in keep_idx:
                global_i = batch_start + j
                ids_list[global_i] = batch_recs[j].get("id", str(global_i))
            continue
        except Exception as e:
            logger.error(f"[fwd-err] batch {batch_start}..{batch_end}: {e}")
            n_failed += len(keep_idx)
            for j in keep_idx:
                global_i = batch_start + j
                ids_list[global_i] = batch_recs[j].get("id", str(global_i))
            continue

        # ── per-sample extraction ──
        # lazy-allocate output tensors (need D from first forward)
        if out_belief is None:
            out_belief = torch.zeros(N, n_frames, n_belief_layers * D,
                                        dtype=torch.float16)
            out_policy = torch.zeros(N, n_frames, D, dtype=torch.float16)
            logger.info(f"[alloc] belief shape={tuple(out_belief.shape)}  "
                        f"policy shape={tuple(out_policy.shape)}")

        for b, j in enumerate(keep_idx):
            global_i = batch_start + j
            rec = batch_recs[j]
            ids_t  = ids_b_all[b]
            attn_t = attn_b_all[b]

            # restrict to valid (non-pad) region
            valid_mask = attn_t.bool()
            open_pos  = ((ids_t == belief_open_id) & valid_mask).nonzero(
                as_tuple=False).flatten().tolist()
            close_pos = ((ids_t == belief_close_id) & valid_mask).nonzero(
                as_tuple=False).flatten().tolist()
            n_blocks = min(len(open_pos), len(close_pos), n_frames)

            if n_blocks == 0:
                n_pool_fallback += 1
                ids_list[global_i] = rec["id"]
                cat_list[global_i] = rec.get("category", "")
                src_list[global_i] = rec.get("source", "")
                vid_list[global_i] = rec.get("video_id", rec["id"])
                continue

            belief_concat = torch.zeros(n_blocks, n_belief_layers * D,
                                            dtype=torch.float16)
            policy_vec = torch.zeros(n_blocks, D, dtype=torch.float16)

            # Pre-compute pool spans per frame, depending on pool_mode.
            # For each frame f we need (inner_start, inner_end) on the same
            # token stream as the original (range) extractor.
            T_valid = int(valid_mask.sum().item())
            pairs_default = list(zip(open_pos[:n_blocks], close_pos[:n_blocks]))

            if pool_mode == "range":
                pool_spans = [(o + 1, c) for (o, c) in pairs_default]
            elif pool_mode == "open":
                # single-token pool at <|BELIEF|> open position (length-1 span)
                pool_spans = [(o, o + 1) for (o, c) in pairs_default]
            elif pool_mode == "token_mean":
                # Format-agnostic baseline: mean over the assistant-response span
                # (first OPEN → last CLOSE), replicated across n_blocks frames.
                resp_start = open_pos[0]
                resp_end   = close_pos[min(len(close_pos), n_blocks) - 1] + 1
                pool_spans = [(resp_start, resp_end)] * n_blocks
            elif pool_mode == "random_span":
                # Control: spans of same length as the average BELIEF span on
                # this sample, but at random positions inside the response.
                import random as _rnd
                rng = _rnd.Random(int(random_span_seed) * 100003 + global_i)
                span_lens = [c - (o + 1) for (o, c) in pairs_default if c > o + 1]
                L_span = max(3, int(round(sum(span_lens) / max(len(span_lens), 1))))
                resp_start = open_pos[0]
                resp_end   = close_pos[min(len(close_pos), n_blocks) - 1] + 1
                pool_spans = []
                for f in range(n_blocks):
                    if resp_end - resp_start <= L_span:
                        pool_spans.append((resp_start, resp_end))
                    else:
                        s = rng.randint(resp_start, resp_end - L_span)
                        pool_spans.append((s, s + L_span))
            else:
                raise ValueError(f"unknown pool_mode={pool_mode}")

            for f, ((o, c), (s, e)) in enumerate(zip(pairs_default, pool_spans)):
                if e <= s:
                    n_pool_fallback += 1
                    continue
                parts = []
                for L in belief_layers:
                    Lh = hs_tuple[L][b, s:e]
                    parts.append(Lh.mean(dim=0).to(torch.float16))
                belief_concat[f] = torch.cat(parts, dim=-1).cpu()
                # policy_position stays as the hidden state AT the f-th close-tag
                # so downstream PolicyHead receives the same register regardless
                # of pool_mode, isolating the ablation to belief_content only.
                policy_vec[f] = hs_tuple[policy_layer][b, c].to(torch.float16).cpu()
                out_valid[global_i, f] = True

            out_belief[global_i, :n_blocks] = belief_concat
            out_policy[global_i, :n_blocks] = policy_vec

            ids_list[global_i] = rec["id"]
            cat_list[global_i] = rec.get("category", "")
            src_list[global_i] = rec.get("source", "")
            vid_list[global_i] = rec.get("video_id", rec["id"])
            out_actions[global_i] = torch.tensor(
                [ACTION_NAME_TO_IDX.get(a, 0) for a in rec["actions_per_frame"]],
                dtype=torch.long)
            out_danger[global_i] = torch.tensor(rec["danger_per_frame"],
                                                  dtype=torch.float32)
            out_tta[global_i]    = torch.tensor(rec["tta_per_frame"],
                                                  dtype=torch.float32)
            out_tick_action[global_i] = ACTION_NAME_TO_IDX.get(
                rec.get("tick_action", "SILENT"), 0)
            out_tick_tta[global_i] = float(rec.get("tick_tta_raw", -1.0))

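    # Keep only entries whose id was set. Note that every handled failure path
    # above (prep error, OOM, forward error, no BELIEF blocks) also sets the id,
    # so those records stay in the cache as zero rows with valid_frames=False.
    # MEMORY-SAFE: avoid a fancy-index COPY of the 30 GB belief tensor, which
    # OOM-kills the process at save time. If every row is kept (the typical
    # case), pass the tensors through directly. Otherwise use torch.index_select,
    # which is memory-equivalent to fancy indexing but cleaner to free.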
    keep = [k for k, x in enumerate(ids_list) if x is not None]
    all_valid = (len(keep) == N)

    if all_valid:
        belief_save = out_belief
        policy_save = out_policy
        valid_save = out_valid
        actions_save = out_actions
        danger_save = out_danger
        tta_save = out_tta
        tick_action_save = out_tick_action
        tick_tta_save = out_tick_tta
    else:
        keep_t = torch.tensor(keep, dtype=torch.long)
        belief_save = (out_belief.index_select(0, keep_t)
                        if out_belief is not None else None)
        policy_save = (out_policy.index_select(0, keep_t)
                        if out_policy is not None else None)
        valid_save = out_valid.index_select(0, keep_t)
        actions_save = out_actions.index_select(0, keep_t)
        danger_save = out_danger.index_select(0, keep_t)
        tta_save = out_tta.index_select(0, keep_t)
        tick_action_save = out_tick_action.index_select(0, keep_t)
        tick_tta_save = out_tick_tta.index_select(0, keep_t)
        # Free the original full tensors before torch.save (avoid 2x peak RAM)
        out_belief = out_policy = None
        out_valid = out_actions = out_danger = out_tta = None
        out_tick_action = out_tick_tta = None
        import gc; gc.collect()

    out_dict = {
        "ids":             [ids_list[k] for k in keep],
        "belief_content":  belief_save,
        "policy_position": policy_save,
        "valid_frames":    valid_save,
        "actions_pf":      actions_save,
        "danger_pf":       danger_save,
        "tta_pf":          tta_save,
        "tick_action":     tick_action_save,
        "tick_tta_raw":    tick_tta_save,
        "category":        [cat_list[k] for k in keep],
        "source":          [src_list[k] for k in keep],
        "video_id":        [vid_list[k] for k in keep],
        "schema":          "vlalert_x_v2_dual_pool",
        "belief_layers":   list(belief_layers),
        "policy_layer":    policy_layer,
        "pool_mode":       pool_mode,
        "ckpt":            str(ckpt_dir),
    }
    out_path.parent.mkdir(parents=True, exist_ok=True)
    logger.info(f"[save] writing → {out_path}  "
                f"(belief {tuple(belief_save.shape) if belief_save is not None else None}, "
                f"policy {tuple(policy_save.shape) if policy_save is not None else None})")
    # Atomic write: save to .tmp then rename (avoids partial files on crash)
    tmp_path = out_path.with_suffix(out_path.suffix + ".tmp")
    torch.save(out_dict, tmp_path)
    import os
    os.replace(str(tmp_path), str(out_path))
    dt = time.time() - t0
    logger.info(f"[save] DONE → {out_path}")
    if belief_save is not None:
        logger.info(f"  belief_content shape={tuple(belief_save.shape)}")
        logger.info(f"  policy_position shape={tuple(policy_save.shape)}")
    logger.info(f"  n={len(keep)}  failed={n_failed}  fallback={n_pool_fallback}  "
                f"elapsed={dt:.0f}s ({len(keep)/max(dt,1):.2f} it/s)")


def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--split", required=True,
                    help="Tag for output filename. Common: train|val|"
                         "multisrc_val_full|adasto_val|nexar_test|...")
    ap.add_argument("--manifest", type=Path)
    ap.add_argument("--ckpt", type=Path,
                    default=ROOT / "checkpoints/sft_x_v2/best")
    ap.add_argument("--base_model", type=Path,
                    default=ROOT / "models/Qwen3-VL-4B-Instruct")
    ap.add_argument("--tag", default="sft_x_v2")
    ap.add_argument("--out_dir", type=Path,
                    default=ROOT / "data/belief_cache_v2")
    ap.add_argument("--belief_layers", nargs="+", type=int,
                    default=[20, 24, 28, 32])
    ap.add_argument("--policy_layer", type=int, default=33)
    ap.add_argument("--limit", type=int, default=0)
    ap.add_argument("--batch_size", type=int, default=4,
                    help="Forward batch size. 4 fits in ~30 GB on RTX 5090 "
                         "with Qwen3-VL-4B + Conv3d patch + bf16.")
    ap.add_argument("--pool_mode",
                    choices=["range", "open", "token_mean", "random_span"],
                    default="range",
                    # Note: "action" mode is not supported here because the
                    # extraction prompt only contains <|BELIEF|>...</|BELIEF|>
                    # spans (no action tokens fed to the model). Add a separate
                    # extraction prompt if you want action-position pooling.
                    help="How to pool hidden states to form belief_content: "
                         "range=mean inside <|BELIEF|>...</|BELIEF|> span (default); "
                         "open=hidden at <|BELIEF|> open token; "
                         "token_mean=mean over the whole response (format-agnostic); "
                         "random_span=same-length span at random positions (control).")
    ap.add_argument("--random_span_seed", type=int, default=0)
    args = ap.parse_args()

    if args.manifest is None:
        args.manifest = ROOT / f"data/cot_corpus_v2/vlalert_x_perframe_v2_{args.split}.jsonl"
    out_path = args.out_dir / f"{args.tag}__{args.split}.pt"
    extract_split(ckpt_dir=args.ckpt, base_model=args.base_model,
                   manifest_path=args.manifest, out_path=out_path,
                   belief_layers=tuple(args.belief_layers),
                   policy_layer=args.policy_layer,
                   limit=args.limit,
                   batch_size=args.batch_size,
                   pool_mode=args.pool_mode,
                   random_span_seed=args.random_span_seed)


if __name__ == "__main__":
    main()