# code_SAS_VLM2Vec / eval_test_time_early_exit_mid2mid.py
# MgGladys's picture
# Add files using upload-large-folder tool
# ac8b25b verified
import datetime
import logging
import json
import random
import time
import numpy as np
import os
import pickle
import sys
import torch
import torch.distributed as dist
import torch.nn.functional as F
import yaml
import transformers
import math
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import HfArgumentParser, AutoConfig, AutoTokenizer
from datasets import Dataset, concatenate_datasets
from datasets.distributed import split_dataset_by_node
from src.model.vlm_backbone.qwen2_vl.modeling_qwen2_vl_train_tokrnpooling import Qwen2VLForConditionalGeneration as _Qwen2VLForConditionalGeneration_src
from src.arguments import ModelArguments, DataArguments, TrainingArguments
from src.data.collator.eval_collator import MultimodalEvalDataCollator
from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset
from src.eval_utils.metrics import RankingMetrics
from src.model.model_cut_layer_AOP_add_text_cut import MMEBModel
from src.model.processor import get_backbone_name, load_processor, COLPALI
from src.utils import batch_to_device, print_rank, print_master
from dataclasses import dataclass
def get_env_mid_layer():
    """Read the legacy MID_LM_LAYER environment variable.

    Returns:
        int | None: the configured mid-layer index, or None when the variable
        is unset, empty, "none"/"null", or not a valid integer.
    """
    v = os.environ.get("MID_LM_LAYER", "").strip()
    if v == "" or v.lower() in {"none", "null"}:
        return None
    try:
        return int(v)
    except ValueError:  # was a bare `except:`; only int() parsing can fail here
        logger.warning(f"Invalid MID_LM_LAYER={v}, ignore.")
        return None
# ------------- AOP-Prune config parsing -------------
def _parse_bool(v: str, default=False):
if v is None: return default
v = v.strip().lower()
return v in {"1","true","yes","y","t","on"}
def _parse_float(v: str, default=None):
try: return float(v) if v is not None else default
except: return default
def _parse_int(v: str, default=None):
try: return int(v) if v is not None else default
except: return default
def get_env_aop_config():
    """
    Read the AOP pruning configuration from environment variables.

    This is only a thin driver-level switch for quick experiments; the actual
    pruning logic lives in the backbone (Qwen2-VLModel.forward).

    Returns:
        dict: the config consumed as `model.encoder.aop_prune_config`.
    """
    enabled = _parse_bool(os.environ.get("AOP_ENABLED"), False)
    apply_to = os.environ.get("AOP_APPLY", "qry").strip().lower()  # qry|cand|both
    layer_idx = _parse_int(os.environ.get("AOP_LAYER"), None)
    mode = os.environ.get("AOP_MODE", "delta").strip().lower()
    # Generic fallback knobs (used when no per-type override is given).
    delta = _parse_float(os.environ.get("AOP_DELTA"), 0.10)
    khat = _parse_float(os.environ.get("AOP_KHAT"), 1.0)
    keep_ratio = _parse_float(os.environ.get("AOP_KEEP_RATIO"), 1.0)
    min_keep = _parse_int(os.environ.get("AOP_MIN_KEEP"), 64)
    use_bias = _parse_bool(os.environ.get("AOP_USE_BIAS"), True)
    # Per-token-type switches.
    prune_vision = _parse_bool(os.environ.get("AOP_PRUNE_VISION"), True)
    prune_text = _parse_bool(os.environ.get("AOP_PRUNE_TEXT"), False)
    # Vision-bucket overrides (None means "fall back to the generic knob").
    delta_v = _parse_float(os.environ.get("AOP_DELTA_VISION"), None)
    khat_v = _parse_float(os.environ.get("AOP_KHAT_VISION"), None)
    keep_ratio_v = _parse_float(os.environ.get("AOP_KEEP_RATIO_VISION"), None)
    min_keep_v = _parse_int(os.environ.get("AOP_MIN_KEEP_VISION"), None)
    # Text-bucket overrides.
    delta_t = _parse_float(os.environ.get("AOP_DELTA_TEXT"), None)
    khat_t = _parse_float(os.environ.get("AOP_KHAT_TEXT"), None)
    keep_ratio_t = _parse_float(os.environ.get("AOP_KEEP_RATIO_TEXT"), None)
    min_keep_t = _parse_int(os.environ.get("AOP_MIN_KEEP_TEXT"), 32)
    # Text protection.
    protect_text_last = _parse_int(os.environ.get("AOP_PROTECT_TEXT_LAST"), 16)
    protect_special = _parse_bool(os.environ.get("AOP_PROTECT_SPECIAL"), True)
    margin_src = os.environ.get("AOP_MARGIN", "").strip().lower()  # "" or "mid"
    attn_impl = os.environ.get("AOP_ATTN_IMPL", "").strip().lower()  # "" or "sdpa"
    if layer_idx is None and enabled:
        logger.warning("AOP_ENABLED=1 但未设置 AOP_LAYER,关闭 AOP。"); enabled = False
    # Selection strategy: aop | random | attention.
    # NOTE: the original parsed selection/AOP_RANDOM/AOP_RANDOM_SEED twice in a
    # row (identical code); the duplicate block was removed — semantics unchanged.
    selection = os.environ.get("AOP_SELECTION", "aop").strip().lower()
    if _parse_bool(os.environ.get("AOP_RANDOM"), False):
        selection = "random"
    random_seed = _parse_int(os.environ.get("AOP_RANDOM_SEED"), None)
    attn_agg = os.environ.get("AOP_ATTENTION_AGG", "mean").strip().lower()  # mean|max|sum
    cfg = {
        "enabled": enabled,
        "apply_to": apply_to,
        "layer_idx": layer_idx,
        "mode": mode,
        # Generic fallback.
        "delta": delta, "K_hat": khat,
        "keep_ratio": keep_ratio, "min_keep": min_keep,
        "use_bias": use_bias, "eps": 1e-6,
        # Per-type switches.
        "prune_vision": prune_vision,
        "prune_text": prune_text,
        # Vision bucket.
        "delta_vision": delta_v,
        "K_hat_vision": khat_v,
        "keep_ratio_vision": keep_ratio_v,
        "min_keep_vision": min_keep_v,
        # Text bucket.
        "delta_text": delta_t,
        "K_hat_text": khat_t,
        "keep_ratio_text": keep_ratio_t,
        "min_keep_text": min_keep_t,
        # Text protection.
        "protect_text_last": protect_text_last,
        "protect_special": protect_special,
        # Optional ranking-safety budget.
        "margin_mid": None if margin_src != "mid" else "USE_MID_MARGIN",
        "epsilon_hat": None,
        "attn_impl_override": attn_impl if attn_impl in {"sdpa"} else "",
        # Selection strategy ("aop" or "random") with optional seed.
        "selection": selection,
        "random_seed": random_seed,
        "attn_agg": attn_agg,
    }
    return cfg
def get_env_eval_layers():
    """
    Parse LM_LAYERS (preferred) or fall back to the legacy MID_LM_LAYER.

    - LM_LAYERS example: "4,8,12,last"; the tokens 'last'/'none'/'null'/'-1'
      all denote the final layer, encoded as None.
    - Without LM_LAYERS (or when it is empty), legacy behavior applies:
      MID_LM_LAYER unset -> [None]; otherwise [mid, None].

    Returns:
        list[int | None], e.g. [4, 8, 12, None]; None means the last layer.
    """
    raw = os.environ.get("LM_LAYERS", None)
    if raw is not None:
        raw = raw.strip()
    if not raw:
        # Legacy fallback path.
        mid = get_env_mid_layer()
        return [None] if mid is None else [mid, None]
    parsed = []
    for tok in (t.strip() for t in raw.split(',')):
        if not tok:
            continue
        lowered = tok.lower()
        if lowered in {"last", "none", "null", "-1"}:
            parsed.append(None)
            continue
        try:
            num = int(tok)
        except Exception:
            logger.warning(f"Invalid token '{tok}' in LM_LAYERS; must be int or 'last'/'none'.")
            continue
        if num > 0:
            parsed.append(num)
        else:
            logger.warning(f"Ignoring non-positive layer '{tok}' in LM_LAYERS.")
    # De-duplicate while keeping first-seen order (None is keyed as -1, which
    # cannot collide with the kept, strictly-positive layer indices).
    seen = set()
    deduped = []
    for layer in parsed:
        marker = -1 if layer is None else layer
        if marker not in seen:
            seen.add(marker)
            deduped.append(layer)
    return deduped if deduped else [None]
# === Early-Exit config & helpers ===
def get_env_ee_config():
    """Collect early-exit (EE) settings from environment variables into a dict."""
    truthy = {"1", "true", "yes", "on", "y", "t"}
    ee_enabled = os.environ.get("EE_ENABLED", "0").strip().lower() in truthy
    # The exit layer defaults to AOP_LAYER when EE_LAYER is not given.
    layer = int(os.environ.get("EE_LAYER", os.environ.get("AOP_LAYER", "12")))
    method = os.environ.get("EE_METHOD", "margin").strip().lower()  # margin|p1p2|entropy|gini|combined
    tau = float(os.environ.get("EE_TAU", "0.2"))
    topk = int(os.environ.get("EE_TOPK", "1024"))
    temp = float(os.environ.get("EE_TEMP", "0.05"))
    save = os.environ.get("EE_SAVE", "1").strip().lower() in truthy
    raw_weights = os.environ.get("EE_COMB_WEIGHTS", "1.0,0.5,0.5")
    try:
        w_margin, w_conf, w_sq = (float(part) for part in raw_weights.split(","))
    except Exception:
        # Malformed weight string: fall back to the documented defaults.
        w_margin, w_conf, w_sq = 1.0, 0.5, 0.5
    return dict(
        enabled=ee_enabled, layer=layer, method=method, tau=tau,
        topk=topk, temp=temp, save=save,
        w_margin=w_margin, w_conf=w_conf, w_sq=w_sq
    )
def _softmax_np(x: np.ndarray, temp: float = 1.0) -> np.ndarray:
x = x - np.max(x)
ex = np.exp(x / max(1e-6, temp))
s = np.sum(ex)
return ex / max(s, 1e-12)
def confidence_from_topk(scores: np.ndarray, method="margin", temp=0.05, w_margin=1.0, w_conf=0.5, w_sq=0.5) -> float:
    """Compute an early-exit confidence value from top-K similarity scores.

    `scores` must already be sorted in descending order. Unknown methods fall
    through to the weighted "combined" score.
    """
    n = scores.size
    if n == 0:
        return 0.0
    if n == 1:
        return 1e9  # a lone candidate counts as maximally confident
    margin = float(scores[0] - scores[1])
    # Temperature softmax over the top-K scores (stable: shift by the max).
    shifted = scores - np.max(scores)
    exps = np.exp(shifted / max(1e-6, temp))
    probs = exps / max(np.sum(exps), 1e-12)
    p1p2 = float(probs[0] - probs[1])
    # Normalized entropy in [0, 1]; confidence is its complement.
    norm_entropy = -float(np.sum(probs * np.log(probs + 1e-12))) / np.log(n)
    conf = 1.0 - norm_entropy
    # Gini-equivalent concentration measure (larger = peakier distribution).
    sqsum = float(np.sum(probs ** 2))
    per_method = {"margin": margin, "p1p2": p1p2, "entropy": conf, "gini": sqsum}
    if method in per_method:
        return per_method[method]
    # "combined" (and any unrecognized method string).
    return w_margin * margin + w_conf * conf + w_sq * sqsum
def run_mid_layer_retrieval(
    model: MMEBModel,
    processor,
    model_args: ModelArguments,
    data_args: DataArguments,
    training_args: TrainingArguments,
    qry_dataset: Dataset,
    cand_mid_dict: dict,
    mid_layer: int,
    dataset_name: str,
    out_dir: str,
    global_ranking: bool = True,
    topk_save: int = 200,
    save_qry_embeds: bool = True,
):
    """
    Online retrieval using ONLY mid-layer representations: similarity and
    ranking use mid-layer query vs mid-layer candidate embeddings, and the
    mid-layer embeddings plus similarity records are written to disk.

    Args:
        cand_mid_dict: {cand_id: embedding(mid_layer)}, produced by
            encode_candidates_both_layers.
        mid_layer: mid-layer index (must match the supervise_layers used in
            training, e.g. 12 / 16).
        global_ranking: True = rank against the full candidate pool;
            False = rank only within each query's gt_info["cand_names"] subset.
        topk_save: number of top candidates whose similarities are saved
            per query.

    Output files:
        {out_dir}/{dataset}_qry_mid_L{mid_layer}.pkl            # all query mid-layer embeddings (np.ndarray [Nq,D])
        {out_dir}/{dataset}_score_midonly_L{mid_layer}.json     # retrieval metrics
        {out_dir}/{dataset}_pred_midonly_L{mid_layer}.jsonl     # ranking results
        {out_dir}/{dataset}_midonly_L{mid_layer}_sims_top{K}.jsonl  # per-query top-K similarity info
    """
    device = training_args.device
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    is_main = (not dist.is_initialized()) or (local_rank == 0)
    # ------- 1) Build the candidate mid-layer matrix -------
    cand_ids = list(cand_mid_dict.keys())
    cand_id2idx = {str(cid): i for i, cid in enumerate(cand_ids)}
    cand_mid = np.stack([cand_mid_dict[c] for c in cand_ids]).astype(np.float32)
    cand_mid_t = torch.from_numpy(cand_mid).to(device=device, dtype=torch.bfloat16)  # [Nc,D]
    Nc = cand_mid_t.size(0)
    # ------- 2) Build the query DataLoader -------
    collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry")
    loader = DataLoader(
        qry_dataset,
        batch_size=training_args.per_device_eval_batch_size,
        collate_fn=collator,
        num_workers=training_args.dataloader_num_workers,
    )
    pred_dicts = []
    all_qry_embeds = []  # all query mid-layer embeddings, concatenated and saved at the end
    sim_records = [] if is_main else None
    qid_offset = 0
    start_time = time.time()
    for inputs, infos in tqdm(
        loader,
        desc=f"[MID-ONLY] {dataset_name}@L{mid_layer}",
        disable=local_rank > 0,
    ):
        inputs = batch_to_device(inputs, device)
        # -------- 3) Run only up to mid_layer and take its hidden states --------
        with torch.no_grad(), torch.autocast(
            device_type="cuda", dtype=torch.bfloat16, enabled=True
        ):
            out_mid = model.encoder(
                **inputs,
                return_dict=True,
                output_hidden_states=False,
                stop_at_layer=int(mid_layer),
                compute_lm_head=False,
            )
        hs_mid = getattr(out_mid, "last_hidden_state", None)
        if hs_mid is None:
            # Fall back to the last entry of hidden_states when the truncated
            # forward does not populate last_hidden_state.
            assert out_mid.hidden_states is not None and len(out_mid.hidden_states) > 0
            hs_mid = out_mid.hidden_states[-1]
        am_mid = getattr(out_mid, "attention_mask", None)
        if am_mid is None:
            am_mid = inputs.get("attention_mask", None)
        if hasattr(am_mid, "device") and am_mid.device != hs_mid.device:
            am_mid = am_mid.to(hs_mid.device)
        # EOS pooling -> mid-layer query representations [B,D]
        reps_mid_t = model._pooling(hs_mid, am_mid).detach().to(
            device=device, dtype=torch.bfloat16
        )  # [B,D]
        B = reps_mid_t.size(0)
        # Keep the embeddings (float32, on CPU) for saving later.
        all_qry_embeds.append(reps_mid_t.float().cpu())
        # -------- 4) Compute mid->mid similarities and rank --------
        if global_ranking:
            # Full-pool retrieval against the whole cand_mid_t matrix.
            scores = (reps_mid_t @ cand_mid_t.T).detach().float().cpu().numpy()  # [B,Nc]
            for b in range(B):
                info = infos[b]
                score_vec = scores[b]  # [Nc]
                # Full ordering, used for metric evaluation.
                full_order = np.argsort(-score_vec)  # indices into cand_ids
                # Build pred_dicts (complete ranking).
                rel_docids = info["label_name"]
                if not isinstance(rel_docids, list):
                    rel_docids = [rel_docids]
                pred_dicts.append({
                    "prediction": [cand_ids[i] for i in full_order],
                    "label": rel_docids,
                    "rel_scores": None,
                })
                # Record the top-K similarities (main process only).
                if is_main and sim_records is not None:
                    K = min(topk_save, Nc)
                    top_idx = full_order[:K]
                    sim_records.append({
                        "qid": int(qid_offset + b),
                        "label": rel_docids,
                        "topk_cand_ids": [cand_ids[i] for i in top_idx],
                        "topk_scores": score_vec[top_idx].astype(float).tolist(),
                    })
        else:
            # Local retrieval: each query ranks only within its cand_names subset.
            reps_cpu = reps_mid_t.detach().float().cpu()  # NOTE(review): unused below — candidate for removal
            for b in range(B):
                info = infos[b]
                cand_local = info["cand_names"]
                rows = [cand_id2idx.get(str(cid), -1) for cid in cand_local]
                rows = [r for r in rows if r >= 0]  # drop ids missing from the pool
                if len(rows) == 0:
                    preds = []
                    score_vec = np.zeros((0,), dtype=np.float32)
                else:
                    cmat = cand_mid_t[rows]  # [Nl,D]
                    sv = (reps_mid_t[b:b+1] @ cmat.T)[0].detach().float().cpu().numpy()  # [Nl]
                    order_local = np.argsort(-sv)
                    preds = [str(cand_local[i]) for i in order_local]
                    score_vec = sv
                rel_docids = info["label_name"]
                if not isinstance(rel_docids, list):
                    rel_docids = [rel_docids]
                pred_dicts.append({
                    "prediction": preds,
                    "label": rel_docids,
                    "rel_scores": None,
                })
                if is_main and sim_records is not None:
                    K = min(topk_save, len(preds))
                    top_ids = preds[:K]
                    # order_local only exists when rows was non-empty; the
                    # len(score_vec) > 0 guard keeps this branch safe.
                    top_scores = score_vec[order_local[:K]].astype(float).tolist() if len(score_vec) > 0 else []
                    sim_records.append({
                        "qid": int(qid_offset + b),
                        "label": rel_docids,
                        "topk_cand_ids": top_ids,
                        "topk_scores": top_scores,
                    })
        qid_offset += B
    # -------- 5) Aggregate and save the mid-layer query embeddings --------
    # NOTE(review): this write happens before the os.makedirs in step 6 —
    # it assumes out_dir already exists; confirm against callers.
    if save_qry_embeds and is_main and len(all_qry_embeds) > 0:
        all_qry_mid = torch.cat(all_qry_embeds, dim=0).numpy()  # [Nq,D]
        qry_mid_path = os.path.join(out_dir, f"{dataset_name}_qry_mid_L{mid_layer}.pkl")
        with open(qry_mid_path, "wb") as f:
            pickle.dump(all_qry_mid, f)
        print_master(f"[MID-ONLY] Saved query mid-layer embeddings -> {qry_mid_path}")
    # -------- 6) Evaluate + write results --------
    metrics_to_report = ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"]
    score = RankingMetrics(metrics_to_report).evaluate(pred_dicts)
    if is_main:
        os.makedirs(out_dir, exist_ok=True)
        score_path = os.path.join(out_dir, f"{dataset_name}_score_midonly_L{mid_layer}.json")
        pred_path = os.path.join(out_dir, f"{dataset_name}_pred_midonly_L{mid_layer}.jsonl")
        sims_path = os.path.join(out_dir, f"{dataset_name}_midonly_L{mid_layer}_sims_top{topk_save}.jsonl")
        score["num_pred"] = len(pred_dicts)
        with open(score_path, "w") as f:
            json.dump(score, f, indent=4)
        with open(pred_path, "w", encoding="utf-8") as f:
            for row in pred_dicts:
                f.write(json.dumps(row, ensure_ascii=False) + "\n")
        if sim_records is not None:
            with open(sims_path, "w", encoding="utf-8") as f:
                for rec in sim_records:
                    f.write(json.dumps(rec, ensure_ascii=False) + "\n")
        print_master(
            f"[MID-ONLY] {dataset_name}@L{mid_layer} score: "
            + json.dumps({k: (f"{v:.4f}" if isinstance(v, (int, float)) else v) for k, v in score.items()})
        )
        print_master(f"[MID-ONLY] Saved mid-layer similarity records -> {sims_path}")
    elapsed = time.time() - start_time
    return score, elapsed
def run_early_exit_queries(
    model: MMEBModel,
    processor,
    model_args: ModelArguments,
    data_args: DataArguments,
    training_args: TrainingArguments,
    qry_dataset: Dataset,
    cand_mid_dict: dict,
    cand_last_dict: dict,
    ee_cfg: dict,
    dataset_name: str,
    out_dir: str,
    global_ranking: bool = True,
):
    """
    Online early-exit inference only (no curve plotting). In addition to the
    ranking itself it records:
        - each query's mid-layer vs cand_last top-K similarities (mid2last)
        - for queries that did NOT exit early, the last-layer vs cand_last
          top-K similarities (last2last)

    Output files:
        {out_dir}/{dataset}_score_earlyexit.json  - retrieval metrics
        {out_dir}/{dataset}_pred_earlyexit.jsonl  - prediction lists (pre-existing format)
        {out_dir}/{dataset}_sim_earlyexit.jsonl   - similarity records added by
                                                    this function (requires EE_SAVE=1)
    """
    device = training_args.device
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    is_main = (not dist.is_initialized()) or (local_rank == 0)
    # Candidate matrices (mid-layer and last-layer share the same row order).
    cand_ids = list(cand_mid_dict.keys())
    cand_id2row = {str(cid): i for i, cid in enumerate(cand_ids)}
    cand_mid = np.stack([cand_mid_dict[c] for c in cand_ids]).astype(np.float32)
    cand_last = np.stack([cand_last_dict[c] for c in cand_ids]).astype(np.float32)
    cand_mid_t = torch.from_numpy(cand_mid).to(device=device, dtype=torch.bfloat16)
    cand_last_t = torch.from_numpy(cand_last).to(device=device, dtype=torch.bfloat16)
    # query DataLoader
    collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry")
    loader = DataLoader(
        qry_dataset,
        batch_size=training_args.per_device_eval_batch_size,
        collate_fn=collator,
        num_workers=training_args.dataloader_num_workers
    )
    pred_dicts = []
    # Whether to save similarity records (reuses the EE_SAVE flag).
    save_scores = ee_cfg.get("save", False)
    topk_sim = int(ee_cfg.get("topk", 1024))
    sim_records = [] if (save_scores and is_main) else None
    # AOP may be enabled per side (query / candidate / both).
    aop_cfg = getattr(model.encoder, "aop_prune_config", None)
    _orig_enabled = None
    side_enable = True
    if isinstance(aop_cfg, dict) and aop_cfg:
        _orig_enabled = aop_cfg.get("enabled", False)
        apply_to = aop_cfg.get("apply_to", "qry")
        side_enable = (apply_to == "both") or (apply_to == "qry")
    # Gating parameters.
    k_conf = 2
    tau = float(ee_cfg["tau"])
    method= ee_cfg["method"]
    temp = float(ee_cfg["temp"])
    idx_global = 0
    start_time = time.time()
    for inputs, infos in tqdm(
        loader,
        desc=f"[EE] {dataset_name}@L{ee_cfg['layer']} (rank {local_rank})",
        disable=local_rank > 0,
    ):
        inputs = batch_to_device(inputs, device)
        # -------- 1) Run up to the mid layer (stop_at_layer); skip logits --------
        orig_cfg = None
        if isinstance(aop_cfg, dict) and aop_cfg:
            orig_cfg = dict(aop_cfg)
            aop_layer = aop_cfg.get("layer_idx", None)
            ee_layer = int(ee_cfg["layer"])
            apply_to = aop_cfg.get("apply_to", "qry").strip().lower()
            # AOP applies to the mid pass only when its prune layer comes
            # strictly before the early-exit layer and the query side is on.
            aop_on_mid = bool(
                _orig_enabled and side_enable and
                (aop_layer is not None) and (aop_layer < ee_layer) and
                (apply_to in {"qry", "both"})
            )
            aop_cfg_mid = dict(aop_cfg)
            aop_cfg_mid["enabled"] = aop_on_mid
            setattr(model.encoder, "aop_prune_config", aop_cfg_mid)
        with torch.no_grad(), torch.autocast(
            device_type="cuda", dtype=torch.bfloat16, enabled=True
        ):
            out_mid = model.encoder(
                **inputs,
                return_dict=True,
                output_hidden_states=False,
                stop_at_layer=int(ee_cfg["layer"]),
                compute_lm_head=False,
            )
        # Restore the original AOP config after the mid pass.
        if isinstance(orig_cfg, dict):
            setattr(model.encoder, "aop_prune_config", orig_cfg)
        # EOS pooling -> mid-layer representations.
        hs_mid = getattr(out_mid, "last_hidden_state", None)
        if hs_mid is None:
            assert out_mid.hidden_states is not None and len(out_mid.hidden_states) > 0
            hs_mid = out_mid.hidden_states[-1]
        am_mid = getattr(out_mid, "attention_mask", None)
        if am_mid is None:
            am_mid = inputs.get("attention_mask", None)
        if hasattr(am_mid, "device") and am_mid.device != hs_mid.device:
            am_mid = am_mid.to(hs_mid.device)
        reps_mid_t = model._pooling(hs_mid, am_mid).detach().to(device=device, dtype=torch.bfloat16)  # [B,D]
        B = reps_mid_t.size(0)
        use_local = (not global_ranking)
        # -------- 2) Gate: confidence from the mid->mid top-2 scores --------
        if not use_local:
            # Full pool: score against all of cand_mid_t.
            scores_t = reps_mid_t @ cand_mid_t.T  # [B, Nc]
            vals_t, idxs_t = torch.topk(
                scores_t, k=min(k_conf, scores_t.size(1)), dim=1
            )  # [B,2]
            p_t = torch.softmax(vals_t / max(temp, 1e-6), dim=1)
            if vals_t.size(1) >= 2:
                margin_t = vals_t[:, 0] - vals_t[:, 1]
                p1p2_t = p_t[:, 0] - p_t[:, 1]
            else:
                # A single candidate counts as maximally confident.
                margin_t = torch.full((B,), float("inf"), device=device, dtype=vals_t.dtype)
                p1p2_t = torch.ones(B, device=device, dtype=vals_t.dtype)
            # Normalized entropy over the top-k probabilities.
            # NOTE(review): with a single candidate, math.log(1) == 0 makes
            # this a division by zero — confirm Nc >= 2 is guaranteed upstream.
            H_t = -(p_t * (torch.log(p_t + 1e-12))).sum(dim=1) / math.log(max(vals_t.size(1),1))
            conf_map = {
                "margin": margin_t,
                "p1p2": p1p2_t,
                "entropy": 1.0 - H_t,
                "gini": (p_t ** 2).sum(dim=1),
            }
            confs_t = conf_map.get(method, margin_t)
            exit_mask = (confs_t >= tau).detach().cpu().numpy().astype(bool)
        else:
            # Local pool: each query uses only its own cand_mid_t[rows].
            confs = []
            for b in range(B):
                cand_local = infos[b]["cand_names"]
                rows = [cand_id2row.get(str(cid), -1) for cid in cand_local]
                rows = [r for r in rows if r >= 0]
                if len(rows) == 0:
                    confs.append(0.0)  # no candidates -> never exit early
                    continue
                cmat_t = cand_mid_t[rows]  # [Nl, D]
                sv_t = (reps_mid_t[b:b+1] @ cmat_t.T)[0]  # [Nl]
                k = 2 if sv_t.size(0) >= 2 else 1
                vals_t, _ = torch.topk(sv_t, k=k, dim=0)
                p_t = torch.softmax(vals_t / max(temp, 1e-6), dim=0)
                if k >= 2:
                    margin = (vals_t[0] - vals_t[1]).item()
                    p1p2 = (p_t[0] - p_t[1]).item()
                else:
                    margin, p1p2 = float("inf"), 1.0
                # NOTE(review): for k == 1 math.log(1) == 0 -> H is inf/nan;
                # only "margin"/"p1p2" are safe in that case — confirm.
                H = (-(p_t * (torch.log(p_t + 1e-12))).sum() / math.log(max(k,1))).item()
                gini = ((p_t ** 2).sum()).item()
                d = {"margin": margin, "p1p2": p1p2, "entropy": 1.0 - H, "gini": gini}
                confs.append(d.get(method, margin))
            exit_mask = (np.array(confs) >= tau)
        # -------- 3) Retrieval + similarity recording --------
        # Queries confident enough to exit at the mid layer:
        exit_indices = np.where(exit_mask)[0].tolist()
        # Queries that must continue to the last layer:
        need_indices = np.where(~exit_mask)[0].tolist()
        # A. Early exit: rank directly with mid->mid, and additionally record
        #    the mid->last top-K similarities.
        for j in exit_indices:
            # 1) Ranking (pred_dicts).
            if not use_local:
                scores_mid_mid = (reps_mid_t[j:j+1] @ cand_mid_t.T)[0]  # [Nc]
                order = torch.argsort(scores_mid_mid, dim=0, descending=True).detach().cpu().numpy()
                cids = [cand_ids[i] for i in order]
            else:
                cand_local = infos[j]["cand_names"]
                rows = [cand_id2row.get(str(cid), -1) for cid in cand_local]
                rows = [r for r in rows if r >= 0]
                if len(rows) == 0:
                    cids = []
                else:
                    cmat_t = cand_mid_t[rows]
                    sv = (reps_mid_t[j:j+1] @ cmat_t.T)[0]
                    order_local = torch.argsort(sv, dim=0, descending=True).detach().cpu().numpy()
                    cids = [str(cand_local[i]) for i in order_local]
            rel_docids = infos[j]["label_name"]
            if not isinstance(rel_docids, list):
                rel_docids = [rel_docids]
            pred_dicts.append({"prediction": cids, "label": rel_docids, "rel_scores": None})
            # 2) Similarity record: mid->last.
            if save_scores and is_main:
                if not use_local:
                    scores_mid_last = (reps_mid_t[j:j+1] @ cand_last_t.T)[0].detach().float().cpu()  # [Nc]
                    Nc = scores_mid_last.size(0)
                    K = min(topk_sim, Nc)
                    mid_vals, mid_inds = torch.topk(scores_mid_last, k=K, dim=0)
                    mid_ids = [cand_ids[i] for i in mid_inds.tolist()]
                else:
                    cand_local = infos[j]["cand_names"]
                    rows = [cand_id2row.get(str(cid), -1) for cid in cand_local]
                    rows = [r for r in rows if r >= 0]
                    if len(rows) == 0:
                        mid_vals = torch.empty(0)
                        mid_ids = []
                    else:
                        cmat_last = cand_last_t[rows]  # [Nl, D]
                        sv_last = (reps_mid_t[j:j+1] @ cmat_last.T)[0].detach().float().cpu()
                        K = min(topk_sim, sv_last.size(0))
                        mid_vals, mid_inds = torch.topk(sv_last, k=K, dim=0)
                        mid_ids = [str(cand_local[i]) for i in mid_inds.tolist()]
                rec = {
                    "qid": int(idx_global + j),
                    "early_exit": True,
                    "mid_topk_scores": mid_vals.tolist() if mid_vals.numel() > 0 else [],
                    "mid_topk_cand_ids": mid_ids,
                    "last_topk_scores": None,
                    "last_topk_cand_ids": None,
                }
                sim_records.append(rec)
        # B. Continue: resume mid->last, then rank with last->last; record
        #    both mid->last and last->last similarities.
        if len(need_indices) > 0:
            # Resume from the cached intermediate state.
            if isinstance(aop_cfg, dict) and aop_cfg:
                aop_resume = dict(aop_cfg)
                aop_resume["enabled"] = bool(_orig_enabled and side_enable)
                setattr(model.encoder, "aop_prune_config", aop_resume)
            interm = getattr(out_mid, "intermediate_state", None)
            assert interm is not None, "Model must return intermediate_state when stop_at_layer is set."
            hs = interm["hidden_states"].detach()
            am = interm["attention_mask"].detach()
            pos = interm["position_ids"].detach()
            vm = interm.get("vision_mask", None)
            tm = interm.get("text_mask", None)
            next_layer = int(interm["next_layer_idx"])
            # Slice out only the queries that still need the full pass.
            # (position_ids is indexed on dim 1 — batch dimension differs.)
            hs_sub = hs[need_indices]
            am_sub = am[need_indices]
            pos_sub = pos[:, need_indices, :]
            vm_sub = vm[need_indices] if vm is not None else None
            tm_sub = tm[need_indices] if tm is not None else None
            resume_state = {
                "hidden_states": hs_sub,
                "attention_mask": am_sub,
                "position_ids": pos_sub,
                "vision_mask": vm_sub,
                "text_mask": tm_sub,
                "next_layer_idx": next_layer,
            }
            with torch.no_grad(), torch.autocast(
                device_type="cuda", dtype=torch.bfloat16, enabled=True
            ):
                out_last = model.encoder(
                    return_dict=True,
                    output_hidden_states=False,
                    stop_at_layer=None,
                    resume_state=resume_state,
                    compute_lm_head=False,
                )
            hs_last = getattr(out_last, "last_hidden_state", None)
            if hs_last is None:
                assert out_last.hidden_states is not None and len(out_last.hidden_states) > 0
                hs_last = out_last.hidden_states[-1]
            am_last = getattr(out_last, "attention_mask", None)
            if am_last is None:
                am_last = am_sub
            if hasattr(am_last, "device") and am_last.device != hs_last.device:
                am_last = am_last.to(hs_last.device)
            reps_last_t = model._pooling(hs_last, am_last).detach().to(device=device, dtype=torch.bfloat16)
            if not use_local:
                scores_last_all = (reps_last_t @ cand_last_t.T).detach().float().cpu()  # [N_need, Nc]
                for k, j in enumerate(need_indices):
                    # 1) Ranking prediction (last->last).
                    row = scores_last_all[k]
                    order = torch.argsort(row, dim=0, descending=True).tolist()
                    cids = [cand_ids[i] for i in order]
                    rel_docids = infos[j]["label_name"]
                    if not isinstance(rel_docids, list):
                        rel_docids = [rel_docids]
                    pred_dicts.append({"prediction": cids, "label": rel_docids, "rel_scores": None})
                    # 2) mid->last & last->last similarities.
                    if save_scores and is_main:
                        # mid2last
                        scores_mid_last = (reps_mid_t[j:j+1] @ cand_last_t.T)[0].detach().float().cpu()
                        Nc = scores_mid_last.size(0)
                        K = min(topk_sim, Nc)
                        mid_vals, mid_inds = torch.topk(scores_mid_last, k=K, dim=0)
                        mid_ids = [cand_ids[i] for i in mid_inds.tolist()]
                        # last2last
                        last_row = row
                        last_vals, last_inds = torch.topk(last_row, k=K, dim=0)
                        last_ids = [cand_ids[i] for i in last_inds.tolist()]
                        rec = {
                            "qid": int(idx_global + j),
                            "early_exit": False,
                            "mid_topk_scores": mid_vals.tolist(),
                            "mid_topk_cand_ids": mid_ids,
                            "last_topk_scores": last_vals.tolist(),
                            "last_topk_cand_ids": last_ids,
                        }
                        sim_records.append(rec)
            else:
                # local ranking
                for k, j in enumerate(need_indices):
                    cand_local = infos[j]["cand_names"]
                    rows = [cand_id2row.get(str(cid), -1) for cid in cand_local]
                    rows = [r for r in rows if r >= 0]
                    if len(rows) == 0:
                        # No valid local candidates: emit an empty prediction
                        # and (optionally) an empty similarity record.
                        cids = []
                        rel_docids = infos[j]["label_name"]
                        if not isinstance(rel_docids, list):
                            rel_docids = [rel_docids]
                        pred_dicts.append({"prediction": cids, "label": rel_docids, "rel_scores": None})
                        if save_scores and is_main:
                            rec = {
                                "qid": int(idx_global + j),
                                "early_exit": False,
                                "mid_topk_scores": [],
                                "mid_topk_cand_ids": [],
                                "last_topk_scores": [],
                                "last_topk_cand_ids": [],
                            }
                            sim_records.append(rec)
                        continue
                    # 1) Ranking (last->last).
                    cmat_last = cand_last_t[rows]
                    sv_last = (reps_last_t[k:k+1] @ cmat_last.T)[0].detach().float().cpu()
                    order_local = torch.argsort(sv_last, dim=0, descending=True).tolist()
                    cids = [str(cand_local[i]) for i in order_local]
                    rel_docids = infos[j]["label_name"]
                    if not isinstance(rel_docids, list):
                        rel_docids = [rel_docids]
                    pred_dicts.append({"prediction": cids, "label": rel_docids, "rel_scores": None})
                    # 2) mid->last & last->last similarities.
                    if save_scores and is_main:
                        cmat_last = cand_last_t[rows]
                        # mid2last
                        sv_mid_last = (reps_mid_t[j:j+1] @ cmat_last.T)[0].detach().float().cpu()
                        K = min(topk_sim, sv_mid_last.size(0))
                        mid_vals, mid_inds = torch.topk(sv_mid_last, k=K, dim=0)
                        mid_ids = [str(cand_local[i]) for i in mid_inds.tolist()]
                        # last2last
                        sv_last_row = sv_last
                        last_vals, last_inds = torch.topk(sv_last_row, k=K, dim=0)
                        last_ids = [str(cand_local[i]) for i in last_inds.tolist()]
                        rec = {
                            "qid": int(idx_global + j),
                            "early_exit": False,
                            "mid_topk_scores": mid_vals.tolist(),
                            "mid_topk_cand_ids": mid_ids,
                            "last_topk_scores": last_vals.tolist(),
                            "last_topk_cand_ids": last_ids,
                        }
                        sim_records.append(rec)
        idx_global += B
    # -------- 4) Evaluate + write out --------
    metrics_to_report = ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"]
    score = RankingMetrics(metrics_to_report).evaluate(pred_dicts)
    if is_main:
        os.makedirs(out_dir, exist_ok=True)
        # Pre-existing early-exit retrieval outputs.
        with open(os.path.join(out_dir, f"{dataset_name}_score_earlyexit.json"), "w") as f:
            json.dump(score, f, indent=4)
        if ee_cfg.get("save", False):
            with open(os.path.join(out_dir, f"{dataset_name}_pred_earlyexit.jsonl"), "w", encoding="utf-8") as f:
                for row in pred_dicts:
                    f.write(json.dumps(row, ensure_ascii=False) + "\n")
        # Newly added mid/last similarity output.
        if save_scores and sim_records is not None:
            sims_path = os.path.join(out_dir, f"{dataset_name}_sim_earlyexit.jsonl")
            with open(sims_path, "w", encoding="utf-8") as f:
                for rec in sim_records:
                    f.write(json.dumps(rec, ensure_ascii=False) + "\n")
            print_master(f"[EE] Saved mid/last similarity records -> {sims_path}")
    elapsed = time.time() - start_time
    return score, elapsed
def make_layer_tag(keep_layers: int | None):
return f"layer{keep_layers}" if keep_layers and keep_layers > 0 else "layerlast"
def dot_sim(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Pairwise dot-product similarity: a is [Nq, D], b is [Nc, D] -> [Nq, Nc].

    Equals cosine similarity when both inputs are already L2-normalized.
    """
    return np.matmul(a, b.T)
def build_score_details(qid: int, cand_ids: list, score_vec: np.ndarray, ranked_indices: np.ndarray):
    """Pack per-candidate scores, in ranked order, into a JSON-friendly dict."""
    ranked_scores = []
    for idx in ranked_indices:
        ranked_scores.append({"cand_id": str(cand_ids[idx]), "score": float(score_vec[idx])})
    return {"qid": int(qid), "cand_scores": ranked_scores}
def top1_top2_margin(score_vec: np.ndarray) -> float:
    """Gap between the largest and second-largest score.

    Fewer than two candidates is treated as an infinitely large margin.
    """
    if len(score_vec) < 2:
        return float("inf")
    two_largest = np.partition(score_vec, -2)[-2:]
    return float(np.max(two_largest) - np.min(two_largest))
def simulate_early_exit_by_margin(
    sims_mid: list[dict], sims_last: list[dict], labels: list[list[str]], metrics_to_report: list[str],
    taus: list[float], rank_global: bool
):
    """
    Offline early-exit simulation driven by the mid-layer top1-top2 margin.

    Args:
        sims_mid / sims_last: one {cand_id: score} dict per query.
        labels: positive cand_id list per query.
        taus: margin thresholds to sweep.

    Returns:
        One record per tau with the early-exit coverage and retrieval metrics.
    """
    assert len(sims_mid) == len(sims_last) == len(labels)
    num_queries = len(labels)
    results = []
    from src.eval_utils.metrics import RankingMetrics
    metrics = RankingMetrics(metrics_to_report)

    def to_pred_dicts(use_mid_mask: list[bool]) -> list[dict]:
        # Build the prediction list RankingMetrics.evaluate expects, picking
        # mid-layer or last-layer scores per query according to the mask.
        preds = []
        for qid in range(num_queries):
            chosen = sims_mid[qid] if use_mid_mask[qid] else sims_last[qid]
            ordered = sorted(chosen.items(), key=lambda kv: -kv[1])
            preds.append({
                "prediction": [cid for cid, _ in ordered],
                "label": labels[qid],
                "rel_scores": None
            })
        return preds

    # Mid-layer margins; an empty score dict gets margin 0.0 (never exits early).
    margins = []
    for qid in range(num_queries):
        vals = sims_mid[qid]
        if not vals:
            margins.append(0.0)
        else:
            margins.append(top1_top2_margin(np.array(list(vals.values()), dtype=np.float32)))
    margins = np.array(margins, dtype=np.float32)

    for tau in taus:
        use_mid_mask = (margins >= tau).tolist()
        score_dict = metrics.evaluate(to_pred_dicts(use_mid_mask))
        results.append({
            "tau": tau,
            "coverage": float(np.mean(use_mid_mask)),  # fraction of early-exited queries
            **score_dict
        })
    return results
def top1_top2_margin_from_array(score_vec: np.ndarray) -> float:
    """Like top1_top2_margin, but maps empty/None input to 0.0 (not inf)."""
    if score_vec is None or len(score_vec) == 0:
        return 0.0
    if len(score_vec) == 1:
        return float('inf')
    # Difference between the two largest entries.
    pair = np.partition(score_vec, -2)[-2:]
    return float(np.max(pair) - np.min(pair))
# Root logging setup. NOTE: `logger` is referenced by the env-parsing helpers
# defined earlier in the file; this works because Python resolves module
# globals at call time, not at definition time.
logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s')
logger = logging.getLogger(__name__)
# --- Global Dictionaries for Hooks (will be cleared before each encode_embeddings call) ---
# timing_info maps id(module) -> list of (timestamp, 'pre'|'post', class name).
timing_info = {}
token_info = {
    "vision_tokens": 0,
    "text_input_tokens": 0,  # Refers to the original text token count
    "text_output_tokens": 0,  # Not directly applicable here as we are encoding, not generating. Will be 0.
    "total_llm_input_tokens": 0,  # Refers to the total tokens LLM receives (visual + formatted text)
}
# --- Hook Functions Definition ---
def timing_pre_hook(module, input):
    """Forward pre-hook: append an entry timestamp for this module instance."""
    key = id(module)
    timing_info.setdefault(key, []).append((time.time(), 'pre', module.__class__.__name__))
def timing_post_hook(module, input, output):
    """Forward post-hook: record the exit timestamp; capture vision token count."""
    key = id(module)
    if key not in timing_info:
        # No matching pre-hook entry was recorded for this module; skip it.
        # print(f"Warning: No pre-hook data for module {module.__class__.__name__} ({module_id})")
        return
    cls_name = module.__class__.__name__
    timing_info[key].append((time.time(), 'post', cls_name))
    # Only the vision transformer module reports its token count here.
    lowered = cls_name.lower()
    if "vision" in lowered and "transformer" in lowered:
        if isinstance(output, torch.Tensor):
            # NOTE(review): shape[0] is taken as the token count — assumes the
            # visual output is (num_tokens, hidden); confirm against the backbone.
            token_info["vision_tokens"] = output.shape[0]
        elif hasattr(output, 'last_hidden_state'):
            token_info["vision_tokens"] = output.last_hidden_state.shape[1]
def register_model_hooks(model):
    """Attach the timing pre/post hooks to the model's major submodules.

    Looks for the vision tower, its merger, the LLM body and the LM head on
    `model.encoder` (falling back to `model` itself) and registers
    `timing_pre_hook` / `timing_post_hook` on each one found.

    Returns:
        list of the module objects that received hooks; callers look up timing
        events via `id(module)` in the global `timing_info`.

    NOTE(review): the handles returned by register_forward_*_hook are discarded,
    so hooks are never removed — calling this more than once on the same model
    stacks duplicate hooks. Confirm this is called once per model, or keep the
    handles and `.remove()` them.
    """
    registered_modules = []
    core_model = model.encoder if hasattr(model, "encoder") and model.encoder is not None else model

    def _attach(module, label):
        # Register both hooks on `module` and remember it for timing lookup.
        module.register_forward_pre_hook(timing_pre_hook)
        module.register_forward_hook(timing_post_hook)
        registered_modules.append(module)
        print_master(f"Registered hooks for {label}: {module.__class__.__name__}")

    # Vision module
    if hasattr(core_model, 'visual') and core_model.visual is not None:
        _attach(core_model.visual, "vision module")
    else:
        print_master(f"WARNING: No 'visual' attribute found on core_model ({type(core_model)}).")
    # Merger module (if inside visual) - it's part of the vision component
    if hasattr(core_model, 'visual') and hasattr(core_model.visual, 'merger') and core_model.visual.merger is not None:
        _attach(core_model.visual.merger, "merger module")
    else:
        print_master(f"WARNING: No 'merger' attribute found on core_model.visual ({type(getattr(core_model, 'visual', 'N/A'))}).")
    # Language model body
    if hasattr(core_model, 'model') and core_model.model is not None:
        _attach(core_model.model, "LLM main module")
    else:
        print_master(f"WARNING: No 'model' attribute found on core_model ({type(core_model)}).")
    # LM Head
    if hasattr(core_model, 'lm_head') and core_model.lm_head is not None:
        _attach(core_model.lm_head, "LM head module")
    else:
        print_master(f"WARNING: No 'lm_head' attribute found on core_model ({type(core_model)}).")
    if not registered_modules:
        print_master("Warning: No major modules found for hook registration. Check model architecture.")
    return registered_modules
def pad_dataset_to_divisible(dataset, world_size):
    """Pad *dataset* (by cycling its own rows) until its length is divisible
    by *world_size*.

    Returns:
        (dataset_or_padded_copy, resulting_length)
    """
    total = len(dataset)
    remainder = total % world_size
    if remainder == 0:
        return dataset, total
    extra = world_size - remainder
    # Repeat rows from the start of the dataset to fill up to the next multiple.
    filler = dataset.select([i % total for i in range(extra)])
    return concatenate_datasets([dataset, filler]), total + extra
def encode_embeddings(
    model: MMEBModel,
    loader: DataLoader,
    training_args: TrainingArguments,
    model_args: ModelArguments,
    full_dataset: Dataset,
    encode_side: str,
    description: str = "Encoding"
) -> tuple[np.ndarray, list, list, list, list, list]:
    """
    Encodes embeddings for a given dataset using the model, handling both standard and
    late-interaction models in a DDP-safe manner.

    Args:
        model: wrapped MMEB model; its encoder may carry an `aop_prune_config` dict.
        loader: DataLoader yielding (inputs, dataset_info) batches.
        training_args / model_args: HF-style argument dataclasses (device, backbone, ...).
        full_dataset: the (padded) dataset across all ranks; sizes the DDP gather.
        encode_side: "qry" for the query side, anything else for the candidate side.
        description: tqdm progress-bar label.

    Returns a 6-tuple:
        - embeddings: np.ndarray gathered across ranks
        - infos_or_ids: list (query infos, or candidate names)
        - batch_stats_list: list of per-batch timing/token statistics dicts
        - img_token_masks: list[None | list[bool]]  (post-prune image-token masks)
        - txt_token_masks: list[None | list[bool]]  (post-prune text-token masks)
        - token_records: list[dict] per-sample pre/post/delta token counts
    """
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    # Check if the model is a late-interaction type
    is_late_interaction = (model_args.model_backbone == COLPALI)
    local_embeds = []
    local_gt_infos = []
    local_max_len = 0
    # List to store statistics for each batch
    batch_stats_list = []
    # Collected masks (post-prune, one entry per sample)
    local_img_token_masks = []  # post image mask per sample
    local_txt_token_masks = []  # post text mask per sample
    local_post_attn_masks = []  # post attention_mask per sample (after prune, 1/0)
    # NOTE(review): local_post_attn_masks is collected and gathered below but is
    # not part of the return value; kept for parity with the original behavior.
    # Per-sample token reduction records: one dict per sample with pre/post/delta counts.
    local_token_records = []
    model.eval()
    # Register hooks for the model once per encode_embeddings call
    registered_hooks = register_model_hooks(model)

    # --- helpers to fetch masks from the model output and serialize them ---
    def _search_key(obj, key: str):
        # Recursively search nested dict/list/tuple structures for `key`.
        if isinstance(obj, dict):
            if key in obj:
                return obj[key]
            for v in obj.values():
                r = _search_key(v, key)
                if r is not None:
                    return r
        elif isinstance(obj, (list, tuple)):
            for v in obj:
                r = _search_key(v, key)
                if r is not None:
                    return r
        return None

    def _to_serializable_mask_list(mask_list, batch_size: int):
        # Convert a model-returned mask (list/tensor/ndarray/None) into a
        # batch-aligned list: [None | list[bool]] * batch_size.
        if mask_list is None:
            return [None] * batch_size
        out = []
        if isinstance(mask_list, (list, tuple)):
            for m in mask_list:
                if m is None:
                    out.append(None)
                elif torch.is_tensor(m):
                    out.append(m.detach().cpu().tolist())
                elif isinstance(m, np.ndarray):
                    out.append(m.tolist())
                else:
                    # already python list/bool
                    out.append(m)
        elif torch.is_tensor(mask_list):
            # A 2D tensor (B, L): tolist() -> list[list[bool/int]]
            out = mask_list.detach().cpu().tolist()
        elif isinstance(mask_list, np.ndarray):
            out = mask_list.tolist()
        else:
            # Unknown type: conservatively return None placeholders.
            out = [None] * batch_size
        # Align length with batch_size.
        if isinstance(out, list):
            if len(out) < batch_size:
                out = out + [None] * (batch_size - len(out))
            elif len(out) > batch_size:
                out = out[:batch_size]
        return out

    def _to_bool_lists(m, batch_size: int):
        lst = _to_serializable_mask_list(m, batch_size)
        # Normalize into list[ list[bool] | None ].
        out = []
        for x in lst:
            if x is None:
                out.append(None)
            else:
                # x may be list[int] or list[bool]
                out.append([bool(int(v)) for v in x])
        return out

    with torch.no_grad():
        for inputs, dataset_info in tqdm(loader, desc=f"{description} (rank {local_rank})", disable=local_rank > 0):
            # Reset statistics for each inference pass
            timing_info.clear()
            token_info["vision_tokens"] = 0
            token_info["text_input_tokens"] = 0
            token_info["text_output_tokens"] = 0
            token_info["total_llm_input_tokens"] = 0
            inputs = batch_to_device(inputs, training_args.device)
            current_batch_size = inputs['input_ids'].shape[0] if 'input_ids' in inputs and inputs['input_ids'] is not None else 1
            with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
                start_inference_time = time.time()
                # Enable/disable AOP pruning depending on which side is being encoded.
                aop_cfg = getattr(model.encoder, "aop_prune_config", None)
                _orig_enabled = None
                if isinstance(aop_cfg, dict) and aop_cfg:
                    _orig_enabled = aop_cfg.get("enabled", False)
                    apply_to = aop_cfg.get("apply_to", "qry")
                    side_enable = (apply_to == "both") or (apply_to == encode_side)
                    aop_cfg["enabled"] = bool(side_enable and _orig_enabled)
                    setattr(model.encoder, "aop_prune_config", aop_cfg)
                if encode_side == "qry":
                    output = model(qry=inputs)
                    reps = output["qry_reps"].detach()
                    local_gt_infos.extend(dataset_info)
                else:
                    output = model(tgt=inputs)
                    reps = output["tgt_reps"].detach()
                    local_gt_infos.extend([info["cand_name"] for info in dataset_info])
                # Restore the enabled flag (so the next encode_side is unaffected).
                if isinstance(aop_cfg, dict) and _orig_enabled is not None:
                    aop_cfg["enabled"] = _orig_enabled
                    setattr(model.encoder, "aop_prune_config", aop_cfg)
                end_inference_time = time.time()
                # Extract the post-prune image/text masks and post attention_mask.
                img_masks_raw = None
                txt_masks_raw = None
                post_attn_raw = None
                if isinstance(output, dict):
                    img_masks_raw = _search_key(output, "image_token_bool_masks")
                    txt_masks_raw = _search_key(output, "text_token_bool_masks")
                    post_attn_raw = _search_key(output, "post_attention_mask")  # key set by MMEBModel.forward
                # Compatibility: masks may also be attached directly on the model.
                if img_masks_raw is None and hasattr(model, "image_token_bool_masks"):
                    img_masks_raw = getattr(model, "image_token_bool_masks")
                if txt_masks_raw is None and hasattr(model, "text_token_bool_masks"):
                    txt_masks_raw = getattr(model, "text_token_bool_masks")
                if post_attn_raw is None and hasattr(model, "post_attention_mask"):
                    post_attn_raw = getattr(model, "post_attention_mask")
                img_masks_serializable = _to_serializable_mask_list(img_masks_raw, current_batch_size)
                txt_masks_serializable = _to_serializable_mask_list(txt_masks_raw, current_batch_size)
                post_attn_serializable = _to_serializable_mask_list(post_attn_raw, current_batch_size)
                local_img_token_masks.extend(img_masks_serializable)
                local_txt_token_masks.extend(txt_masks_serializable)
                local_post_attn_masks.extend(post_attn_serializable)
                # Compute this batch's pre/post/delta token counts and accumulate.
                cfg = getattr(model.encoder, "config", None)
                # Pre-prune masks come from the raw inputs (before any deletion).
                input_ids = inputs.get("input_ids", None)
                attn2d_pre = inputs.get("attention_mask", None)
                if input_ids is None or attn2d_pre is None or cfg is None:
                    # Cannot compute statistics; leave zeros.
                    pre_vis_counts = [0] * current_batch_size
                    pre_txt_counts = [0] * current_batch_size
                    pre_tot_counts = [0] * current_batch_size
                else:
                    iid = input_ids
                    am = attn2d_pre.to(torch.bool)
                    image_token_id = getattr(cfg, "image_token_id", None)
                    video_token_id = getattr(cfg, "video_token_id", None)
                    bos_id = getattr(cfg, "bos_token_id", None)
                    eos_id = getattr(cfg, "eos_token_id", None)
                    pad_id = getattr(cfg, "pad_token_id", None)
                    is_image = (iid == image_token_id) if (image_token_id is not None and image_token_id >= 0) else torch.zeros_like(iid, dtype=torch.bool)
                    is_video = (iid == video_token_id) if (video_token_id is not None and video_token_id >= 0) else torch.zeros_like(iid, dtype=torch.bool)
                    is_vision = is_image | is_video
                    is_special = torch.zeros_like(iid, dtype=torch.bool)
                    for tid in [bos_id, eos_id, pad_id]:
                        if tid is not None and tid >= 0:
                            is_special |= (iid == tid)
                    pre_txt_mask = am & (~is_vision) & (~is_special)
                    pre_vis_mask = am & is_vision
                    pre_vis_counts = pre_vis_mask.sum(dim=1).tolist()
                    pre_txt_counts = pre_txt_mask.sum(dim=1).tolist()
                    pre_tot_counts = am.sum(dim=1).tolist()
                # Post-prune masks come from the model output; ANDed with post_attn below.
                post_text_masks = _to_bool_lists(txt_masks_raw, current_batch_size)  # list[ list[bool] | None ]
                post_image_masks = _to_bool_lists(img_masks_raw, current_batch_size)
                post_attn_masks = _to_bool_lists(post_attn_raw, current_batch_size)
                sum_pre_text = 0; sum_post_text = 0
                sum_pre_vis = 0; sum_post_vis = 0
                sum_pre_tot = 0; sum_post_tot = 0
                for i in range(current_batch_size):
                    pre_text = int(pre_txt_counts[i]) if i < len(pre_txt_counts) else 0
                    pre_vis = int(pre_vis_counts[i]) if i < len(pre_vis_counts) else 0
                    pre_tot = int(pre_tot_counts[i]) if i < len(pre_tot_counts) else 0
                    # Post counts: each mask may be None.
                    m_text = post_text_masks[i] if post_text_masks is not None and i < len(post_text_masks) else None
                    m_img = post_image_masks[i] if post_image_masks is not None and i < len(post_image_masks) else None
                    m_attn = post_attn_masks[i] if post_attn_masks is not None and i < len(post_attn_masks) else None
                    if m_attn is None:
                        post_text = 0; post_vis = 0; post_tot = 0
                    else:
                        # Count True entries after ANDing with the attention mask.
                        if m_text is not None:
                            post_text = sum(1 for a, t in zip(m_attn, m_text) if a and t)
                        else:
                            post_text = 0
                        if m_img is not None:
                            post_vis = sum(1 for a, v in zip(m_attn, m_img) if a and v)
                        else:
                            post_vis = 0
                        post_tot = sum(1 for a in m_attn if a)
                    # Accumulate batch-level sums.
                    sum_pre_text += pre_text; sum_post_text += post_text
                    sum_pre_vis += pre_vis; sum_post_vis += post_vis
                    sum_pre_tot += pre_tot; sum_post_tot += post_tot
                    # Save the per-sample record (for JSONL export).
                    local_token_records.append({
                        "side": encode_side,
                        "pre": {"text": pre_text, "vision": pre_vis, "total": pre_tot},
                        "post": {"text": post_text, "vision": post_vis, "total": post_tot},
                        "delta":{"text": pre_text - post_text, "vision": pre_vis - post_vis, "total": pre_tot - post_tot},
                    })
            # Update total LLM input tokens after the model call
            if 'input_ids' in inputs and inputs['input_ids'] is not None:
                token_info["total_llm_input_tokens"] = inputs['input_ids'].shape[1]
                token_info["text_input_tokens"] = token_info["total_llm_input_tokens"] - token_info["vision_tokens"]
                token_info["text_input_tokens"] = max(0, token_info["text_input_tokens"])
            # Collect and store batch statistics
            batch_inference_time = end_inference_time - start_inference_time
            current_batch_stats = {
                "batch_size": current_batch_size,
                "total_inference_time_seconds": batch_inference_time,
                "module_inference_times": {},
                "token_counts": {
                    "visual_tokens": token_info["vision_tokens"],
                    "language_input_tokens_raw": token_info["text_input_tokens"],
                    "llm_total_input_tokens": token_info["total_llm_input_tokens"],
                    "language_output_tokens": token_info["text_output_tokens"],
                }
            }
            current_batch_stats["token_reduction"] = {
                "sum_pre_text": sum_pre_text,
                "sum_post_text": sum_post_text,
                "sum_pre_vision": sum_pre_vis,
                "sum_post_vision": sum_post_vis,
                "sum_pre_total": sum_pre_tot,
                "sum_post_total": sum_post_tot,
            }
            # Calculate and store module timings for the current batch
            for module_obj in registered_hooks:
                module_id = id(module_obj)
                module_name = module_obj.__class__.__name__
                times = timing_info.get(module_id, [])
                durations = []
                pre_times = {}
                for t, event_type, _ in times:
                    if event_type == 'pre':
                        pre_times[module_id] = t
                    elif event_type == 'post' and module_id in pre_times:
                        duration = t - pre_times.pop(module_id)
                        durations.append(duration)
                if durations:
                    current_batch_stats["module_inference_times"][module_name] = {
                        "total": sum(durations),
                        "count": len(durations),
                        "avg": sum(durations) / len(durations)
                    }
                else:
                    current_batch_stats["module_inference_times"][module_name] = {
                        "total": 0.0,
                        "count": 0,
                        "avg": 0.0
                    }
            batch_stats_list.append(current_batch_stats)
            # Debug prints (optional)
            print_rank(f"\n--- Inference Statistics for {encode_side} batch (Rank {local_rank}) ---")
            print_rank(f"Batch Inference took: {batch_inference_time:.4f} seconds")
            print_rank("--- Module Inference Timing Statistics ---")
            for module_name, stats in current_batch_stats["module_inference_times"].items():
                print_rank(f"**{module_name}**: Total: {stats['total']:.6f}s, Count: {stats['count']}, Avg: {stats['avg']:.6f}s")
            print_rank("--- Token Count Statistics ---")
            print_rank(f"**视觉 token 数量**: {current_batch_stats['token_counts']['visual_tokens']}")
            print_rank(f"**语言输入 token 数量 (仅原始文本)**: {current_batch_stats['token_counts']['language_input_tokens_raw']}")
            print_rank(f"**LLM总输入 token 数量 (包含视觉 + 格式化文本)**: {current_batch_stats['token_counts']['llm_total_input_tokens']}")
            print_rank(f"**语言输出 token 数量**: {current_batch_stats['token_counts']['language_output_tokens']}")
            if is_late_interaction and reps.dim() == 3:
                local_max_len = max(local_max_len, reps.shape[1])
            local_embeds.append(reps)
    if not local_embeds:
        # Handle cases where a rank gets no data.
        # FIX: return the same 6-value arity as the normal path below — the
        # original returned only 4 values here, crashing callers that unpack six.
        return np.array([]), [], [], [], [], []
    # === DDP Synchronization and Padding for Late-Interaction Models ===
    if is_late_interaction:
        if dist.is_initialized():
            # 1: global max length
            local_max_len_tensor = torch.tensor(local_max_len, device=training_args.device)
            dist.all_reduce(local_max_len_tensor, op=dist.ReduceOp.MAX)
            global_max_len = local_max_len_tensor.item()
        else:
            global_max_len = local_max_len
        # 2: pad to global max length
        padded_embeds = []
        for reps_batch in local_embeds:
            if reps_batch.dim() == 3:
                B, L, H = reps_batch.shape
                padding_size = global_max_len - L
                padded_batch = F.pad(reps_batch, (0, 0, 0, padding_size), "constant", 0)
                padded_embeds.append(padded_batch)
            else:
                padded_embeds.append(reps_batch)
        embeds_tensor = torch.cat(padded_embeds, dim=0).contiguous()
    else:
        embeds_tensor = torch.cat(local_embeds, dim=0).contiguous()
    # === Gather embeddings and keys from all ranks ===
    if dist.is_initialized() and full_dataset.num_rows >= world_size:
        print_master(f"Gathering {encode_side} embeddings across all ranks...")
        # tensor gather
        output_shape = list(embeds_tensor.shape)
        output_shape[0] = full_dataset.num_rows
        embeds_tensor = embeds_tensor.to(training_args.device)
        gathered_embeds_tensor = torch.empty(output_shape, dtype=embeds_tensor.dtype, device=training_args.device)
        dist.all_gather_into_tensor(gathered_embeds_tensor, embeds_tensor)
        final_embeddings = gathered_embeds_tensor.cpu().float().numpy()
        # object gather for infos and stats
        gathered_gt_infos = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_gt_infos, local_gt_infos)
        all_gt_infos = [key for rank_keys in gathered_gt_infos for key in rank_keys]
        gathered_batch_stats = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_batch_stats, batch_stats_list)
        all_batch_stats = [stats for rank_stats in gathered_batch_stats for stats in rank_stats]
        # gather image-token masks
        gathered_masks = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_masks, local_img_token_masks)
        all_img_token_masks = [m for rank_list in gathered_masks for m in rank_list]
        # gather text-token masks
        gathered_txt_masks = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_txt_masks, local_txt_token_masks)
        all_txt_token_masks = [m for rank_list in gathered_txt_masks for m in rank_list]
        # gather post attention masks (not returned; see NOTE near the top)
        gathered_post_attn = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_post_attn, local_post_attn_masks)
        all_post_attn_masks = [m for rank_list in gathered_post_attn for m in rank_list]
        # gather per-sample token records
        gathered_token_recs = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_token_recs, local_token_records)
        all_token_records = [r for rank_list in gathered_token_recs for r in rank_list]
    else:
        all_gt_infos = local_gt_infos
        final_embeddings = embeds_tensor.cpu().float().numpy()
        all_batch_stats = batch_stats_list
        all_img_token_masks = local_img_token_masks
        all_txt_token_masks = local_txt_token_masks
        all_post_attn_masks = local_post_attn_masks
        all_token_records = local_token_records
    return final_embeddings, all_gt_infos, all_batch_stats, all_img_token_masks, all_txt_token_masks, all_token_records
# === One forward pass exporting BOTH the mid-layer and last-layer candidate vectors ===
def encode_candidates_both_layers(
    model: MMEBModel,
    loader: DataLoader,
    training_args: TrainingArguments,
    model_args: ModelArguments,
    full_dataset: Dataset,
    mid_layer: int,
) -> tuple[np.ndarray, np.ndarray, list]:
    """
    Runs a single forward pass through all layers and reads from hidden_states:
      - mid_hidden  = hidden_states[mid_layer]  # state after `mid_layer` layers (see all_hidden_states in Qwen2_5_VLModel)
      - last_hidden = hidden_states[-1]         # final, post-norm state
    Sentence vectors are then pooled with _pooling(attention_mask). Returns:
      - cand_mid_embeds:  np.ndarray [Nc, D]
      - cand_last_embeds: np.ndarray [Nc, D]
      - cand_ids: list[str]
    Note:
      - The candidate side does no AOP pruning by default (naturally off when
        AOP_APPLY=qry), so mid/last share the same sequence length and the
        original attention_mask can be used for pooling.
    Raises:
        RuntimeError: if the encoder returned no / too few hidden states to index
        `mid_layer` (FIX: replaces the original `assert`, which is stripped under -O).
    """
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    model.eval()
    all_mid = []
    all_last = []
    all_ids = []
    with torch.no_grad():
        for inputs, dataset_info in tqdm(loader, desc=f"Candidates[BOTH] (rank {local_rank})", disable=local_rank > 0):
            inputs = batch_to_device(inputs, training_args.device)
            # Ensure AOP is not triggered on the candidate side (the backbone already
            # gates per side when AOP_APPLY=qry/both; this is an extra safeguard).
            aop_cfg = getattr(model.encoder, "aop_prune_config", None)
            _orig_enabled = None
            if isinstance(aop_cfg, dict) and aop_cfg:
                _orig_enabled = aop_cfg.get("enabled", False)
                apply_to = aop_cfg.get("apply_to", "qry")
                side_enable = (apply_to == "both") or (apply_to == "cand")
                aop_cfg["enabled"] = bool(side_enable and _orig_enabled)
                setattr(model.encoder, "aop_prune_config", aop_cfg)
            with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
                # Key point: one forward pass returning hidden states for ALL layers.
                out = model.encoder(
                    **inputs,
                    return_dict=True,
                    output_hidden_states=True,  # required
                    stop_at_layer=None,  # run through every layer
                )
                # Index the mid layer / last layer out of hidden_states.
                hs_list = out.hidden_states
                if hs_list is None or len(hs_list) <= mid_layer:
                    raise RuntimeError(
                        f"hidden_states is None or too short. Need index {mid_layer}, got len={0 if hs_list is None else len(hs_list)}")
                mid_hs = hs_list[mid_layer]  # [B, L, D]: state after mid_layer layers (i.e. pre-layer(mid_layer+1))
                last_hs = hs_list[-1]  # [B, L, D]: final post-norm state
                # Pool with the original attention_mask (candidate side is unpruned).
                am = inputs.get("attention_mask", None)
                if am is not None and hasattr(am, "device"):
                    if am.device != mid_hs.device:
                        am = am.to(mid_hs.device)
                reps_mid = model._pooling(mid_hs, am)  # [B, D]
                reps_last = model._pooling(last_hs, am)  # [B, D]
            all_mid.append(reps_mid.detach().float().cpu())
            all_last.append(reps_last.detach().float().cpu())
            all_ids.extend([info["cand_name"] for info in dataset_info])
            # Restore the AOP switch (so other sides are unaffected).
            if isinstance(aop_cfg, dict) and _orig_enabled is not None:
                aop_cfg["enabled"] = _orig_enabled
                setattr(model.encoder, "aop_prune_config", aop_cfg)
    if not all_mid:
        return np.array([]), np.array([]), []
    cand_mid_embeds = torch.cat(all_mid, dim=0).numpy()
    cand_last_embeds = torch.cat(all_last, dim=0).numpy()
    return cand_mid_embeds, cand_last_embeds, all_ids
def main():
    """Entry point: distributed EE-only evaluation pipeline.

    Initializes torch.distributed when launched via torchrun, parses HF-style
    arguments, loads the MMEB model (rank 0 downloads first; other ranks reuse
    the cache), injects the env-driven AOP pruning config, then for each dataset
    in the YAML config encodes candidate embeddings at the mid layer and the
    last layer in a single pass and runs mid-layer-only retrieval on rank 0.
    """
    # ----------------------- Distributed init -----------------------
    if "RANK" in os.environ and dist.is_available() and not dist.is_initialized():
        dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60))
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    print_master("Distributed init debug info:")
    print_master(f"RANK: {os.environ.get('RANK')}")
    print_master(f"LOCAL_RANK: {os.environ.get('LOCAL_RANK')}")
    print_master(f"WORLD_SIZE: {os.environ.get('WORLD_SIZE')}")
    print_master(f"MASTER_ADDR: {os.environ.get('MASTER_ADDR')}")
    print_master(f"MASTER_PORT: {os.environ.get('MASTER_PORT')}")
    if dist.is_initialized():
        print_rank(f"dist.get_rank(): {dist.get_rank()}")
        print_rank(f"dist.get_world_size(): {dist.get_world_size()}")
    # torchrun compatibility: rewrite "--local-rank=N" into "--local_rank N".
    # FIX: iterate over a snapshot of sys.argv — the original mutated the list
    # while iterating it, which can skip elements.
    for arg in list(sys.argv):
        if arg.startswith("--local-rank="):
            rank = arg.split("=")[1]
            sys.argv.remove(arg)
            sys.argv.append('--local_rank')
            sys.argv.append(rank)
    # ----------------------- Parse args -----------------------
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model_args: ModelArguments
    data_args: DataArguments
    training_args: TrainingArguments
    os.makedirs(data_args.encode_output_path, exist_ok=True)
    # Multi-layer evaluation support (LM_LAYERS takes priority; MID_LM_LAYER kept for compat).
    layers_to_eval = get_env_eval_layers()
    print_master(f"Eval layers (qry/tgt): {layers_to_eval}")
    # ----------------------- Model loading -----------------------
    hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
    if not getattr(model_args, "model_backbone", None):
        model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type)
        setattr(model_args, 'model_backbone', model_backbone)
        setattr(training_args, 'model_backbone', model_backbone)
    print_master(f'Model Backbone: {model_args.model_backbone}')
    # Only rank 0 downloads; the other ranks wait for the cache.
    if local_rank == 0:
        processor = load_processor(model_args, data_args)
        model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
        print_master(f"[rank=0] Loading the model from Huggingface: {model_args.model_name}...")
    if torch.distributed.is_initialized():
        torch.distributed.barrier()
    if local_rank != 0:
        print_rank(f"Loading the model from cache...")
        processor = load_processor(model_args, data_args)
        time.sleep(random.randint(2 * local_rank, 3 * local_rank))
        model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
    model.eval()
    model = model.to(training_args.device, dtype=torch.bfloat16)
    # Inject the AOP pruning config (drives the AOP logic implemented in the backbone).
    aop_cfg = get_env_aop_config()
    if aop_cfg["enabled"]:
        # Hand the config to the backbone; its forward reads this dict and prunes.
        setattr(model.encoder, "aop_prune_config", aop_cfg)
        # Optional: override the attention implementation so the decision layer can
        # expose attention weights (or allow manual q·k computation).
        attn_override = aop_cfg.get("attn_impl_override", "")
        if attn_override:
            try:
                if hasattr(model.encoder, "model") and hasattr(model.encoder.model, "config"):
                    prev = model.encoder.model.config._attn_implementation
                    model.encoder.model.config._attn_implementation = attn_override
                    print_master(f"[AOP] override attn impl: {prev} -> {attn_override} (仅测试建议)")
            except Exception as e:
                print_master(f"[AOP] try override attn impl failed: {e}")
        print_master("[AOP] AOP-Prune enabled with config: " + json.dumps({
            "apply_to": aop_cfg["apply_to"],
            "layer_idx": aop_cfg["layer_idx"],
            "mode": aop_cfg["mode"],
            "delta": aop_cfg["delta"],
            "K_hat": aop_cfg["K_hat"],
            "keep_ratio": aop_cfg["keep_ratio"],
            "min_keep": aop_cfg["min_keep"],
            "use_bias": aop_cfg["use_bias"],
            "margin_mid?": (aop_cfg["margin_mid"] is not None),
            "prune_text": aop_cfg.get("prune_text", False),
            "keep_ratio_text": aop_cfg.get("keep_ratio_text", None),
            "keep_ratio_vision": aop_cfg.get("keep_ratio_vision", None),
            "selection": aop_cfg.get("selection", "aop"),
            "attn_agg": aop_cfg.get("attn_agg", "mean"),
        }))
    else:
        print_master("[AOP] disabled (set AOP_ENABLED=1 to enable)")
    # Ensure "last layer" does not truncate layers (avoid the class default of 20 layers).
    model.set_inference_layers(qry_layers=None, tgt_layers=None)
    with open(data_args.dataset_config, 'r') as yaml_file:
        dataset_configs = yaml.safe_load(yaml_file)
    # ----------------------- Main evaluation loop -----------------------
    for dataset_idx, (dataset_name, task_config) in enumerate(dataset_configs.items()):
        if dist.is_initialized():
            dist.barrier()
        print_master(f"\n--- Evaluating {dataset_name} ---")
        # Fix up paths according to data_basedir.
        if data_args.data_basedir is not None:
            for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]:
                if data_args.data_basedir and task_config.get(key):
                    task_config[key] = os.path.join(data_args.data_basedir, task_config[key])
        # Build datasets.
        full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_config)
        full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus)
        eval_qry_dataset, eval_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset
        if dist.is_initialized():
            world_size = dist.get_world_size()
            padded_qry_dataset, _ = pad_dataset_to_divisible(full_eval_qry_dataset, world_size)
            padded_cand_dataset, _ = pad_dataset_to_divisible(full_eval_cand_dataset, world_size)
            eval_qry_dataset = split_dataset_by_node(padded_qry_dataset, rank=local_rank, world_size=world_size)
            eval_cand_dataset = split_dataset_by_node(padded_cand_dataset, rank=local_rank, world_size=world_size)
        else:
            padded_qry_dataset, padded_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset
        # === EE-only: online early-exit inference (first make sure both candidate embedding files exist) ===
        ee_cfg = get_env_ee_config()
        # FIX: explicit raise instead of `assert` (asserts are stripped under -O).
        if not ee_cfg["enabled"]:
            raise RuntimeError("EE_ENABLED must be 1 for EE-only pipeline.")
        # Build the tag from EE_LAYER.
        mid_layer = int(ee_cfg["layer"])
        mid_tag = make_layer_tag(mid_layer)  # e.g., layer12
        last_tag = "layerlast"
        # Prepare output paths.
        cand_mid_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_{mid_tag}")
        cand_last_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_{last_tag}")
        # Build the candidate DataLoader (single pass, no sharding).
        eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand")
        eval_cand_loader = DataLoader(
            full_eval_cand_dataset,
            batch_size=training_args.per_device_eval_batch_size,
            collate_fn=eval_cand_collator,
            num_workers=training_args.dataloader_num_workers
        )
        # === One forward pass exporting BOTH mid/last candidate vectors ===
        need_mid = (not os.path.exists(cand_mid_path))
        need_last = (not os.path.exists(cand_last_path))
        if need_mid or need_last:
            print_master(f"[{dataset_name}] EE-only: encoding candidates BOTH layers in one pass (mid={mid_tag}, last={last_tag}) ...")
            # Run all layers (no truncation).
            model.set_inference_layers(qry_layers=None, tgt_layers=None)
            cand_embeds_mid, cand_embeds_last, all_cand_ids = encode_candidates_both_layers(
                model=model,
                loader=eval_cand_loader,
                training_args=training_args,
                model_args=model_args,
                full_dataset=full_eval_cand_dataset,
                mid_layer=mid_layer,
            )
            if local_rank == 0:
                if need_mid:
                    cand_embed_dict_mid = {cid: emb for cid, emb in zip(all_cand_ids, cand_embeds_mid)}
                    with open(cand_mid_path, "wb") as f:
                        pickle.dump(cand_embed_dict_mid, f)
                    print_master(f"[{dataset_name}] EE-only: saved {mid_tag} candidate embeddings -> {cand_mid_path}")
                if need_last:
                    cand_embed_dict_last = {cid: emb for cid, emb in zip(all_cand_ids, cand_embeds_last)}
                    with open(cand_last_path, "wb") as f:
                        pickle.dump(cand_embed_dict_last, f)
                    print_master(f"[{dataset_name}] EE-only: saved {last_tag} candidate embeddings -> {cand_last_path}")
        else:
            print_master(f"[{dataset_name}] EE-only: reuse existing candidates (mid={cand_mid_path}, last={cand_last_path})")
        if dist.is_initialized():
            dist.barrier()
        # 3) Online retrieval (mid<->mid): no early-exit; predict with the mid layer only.
        if local_rank == 0:
            with open(cand_mid_path, "rb") as f:
                cand_mid_dict = pickle.load(f)
            # EE_LAYER serves as the mid-layer index; alternatively hard-code a value
            # matching the training-time supervise_layers[0].
            mid_layer = int(ee_cfg["layer"])  # e.g. 12 or 16
            rank_global = task_config.get("eval_type", "global") == "global"
            print_master(f"[{dataset_name}] Run MID-ONLY retrieval at layer={mid_layer}, global={rank_global}")
            run_mid_layer_retrieval(
                model=model,
                processor=processor,
                model_args=model_args,
                data_args=data_args,
                training_args=training_args,
                qry_dataset=full_eval_qry_dataset,  # all queries
                cand_mid_dict=cand_mid_dict,
                mid_layer=mid_layer,
                dataset_name=dataset_name,
                out_dir=data_args.encode_output_path,
                global_ranking=rank_global,
                topk_save=200,  # keep top-200 similarities per query; tune as needed
                save_qry_embeds=True,  # save every query's mid-layer embedding
            )
        if dist.is_initialized():
            dist.barrier()
        continue  # done with this dataset; move on to the next one
# Script entry point: run the EE-only evaluation pipeline.
if __name__ == '__main__':
    main()