# code_SAS_VLM2Vec / eval_early_exit_phaseA.py
# (Hugging Face repo page residue kept as comments: uploader "MgGladys",
#  "Add files using upload-large-folder tool", commit 2a40e7a verified)
import datetime
import logging
import json
import random
import time
import numpy as np
from torch import nn
import os
import pickle
import sys
import torch
import torch.distributed as dist
import torch.nn.functional as F
import yaml
import transformers
import math
# from src.ee_controller.model_phaseA import EEExitClassifier # NEW
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import HfArgumentParser, AutoConfig, AutoTokenizer
from datasets import Dataset, concatenate_datasets
from datasets.distributed import split_dataset_by_node
from src.model.vlm_backbone.qwen2_vl.modeling_qwen2_vl_train_tokrnpooling import Qwen2VLForConditionalGeneration as _Qwen2VLForConditionalGeneration_src
from src.arguments import ModelArguments, DataArguments, TrainingArguments
from src.data.collator.eval_collator import MultimodalEvalDataCollator
from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset
from src.eval_utils.metrics import RankingMetrics
from src.model.model_cut_layer_AOP_add_text_cut import MMEBModel
from src.model.processor import get_backbone_name, load_processor, COLPALI
from src.utils import batch_to_device, print_rank, print_master
from dataclasses import dataclass
class EEExitClassifier(nn.Module):
    """Tiny MLP gate mapping a feature vector to an exit probability in (0, 1).

    The layer layout (Linear/ReLU/Dropout/Linear/ReLU/Linear inside one
    ``nn.Sequential`` attribute named ``net``) is kept exactly as before so
    checkpoints written for the original definition still load via
    ``load_state_dict``.
    """

    def __init__(self, in_dim: int, hidden: int = 128, dropout: float = 0.1):
        super().__init__()
        stack = [
            nn.Linear(in_dim, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, 1),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-sample exit probabilities with the trailing dim squeezed."""
        raw = self.net(x)
        return torch.sigmoid(raw.squeeze(-1))
def get_env_mid_layer():
    """Read the optional MID_LM_LAYER environment variable.

    Returns:
        int | None: the layer index, or None when the variable is unset,
        empty, or one of {"none", "null"} (case-insensitive). A malformed
        value logs a warning and also yields None.
    """
    v = os.environ.get("MID_LM_LAYER", "").strip()
    if v == "" or v.lower() in {"none", "null"}:
        return None
    try:
        return int(v)
    except ValueError:
        # Bare `except:` would also swallow KeyboardInterrupt/SystemExit;
        # only a malformed integer is expected here.
        logger.warning(f"Invalid MID_LM_LAYER={v}, ignore.")
        return None
# ------------- AOP-Prune config parsing -------------
def _parse_bool(v: str, default=False):
if v is None: return default
v = v.strip().lower()
return v in {"1","true","yes","y","t","on"}
def _parse_float(v: str, default=None):
try: return float(v) if v is not None else default
except: return default
def _parse_int(v: str, default=None):
try: return int(v) if v is not None else default
except: return default
def get_env_aop_config():
    """Read the AOP pruning configuration from environment variables.

    These are driver-side switches only; the actual pruning logic lives in
    the backbone (Qwen2-VLModel.forward), which consumes the returned dict
    via ``model.encoder.aop_prune_config``.

    Returns:
        dict: the AOP config (keys unchanged from the original layout).
    """
    enabled = _parse_bool(os.environ.get("AOP_ENABLED"), False)
    apply_to = os.environ.get("AOP_APPLY", "qry").strip().lower()  # qry|cand|both
    layer_idx = _parse_int(os.environ.get("AOP_LAYER"), None)
    mode = os.environ.get("AOP_MODE", "delta").strip().lower()
    # Generic fallback knobs (used when the per-modality values below are unset).
    delta = _parse_float(os.environ.get("AOP_DELTA"), 0.10)
    khat = _parse_float(os.environ.get("AOP_KHAT"), 1.0)
    keep_ratio = _parse_float(os.environ.get("AOP_KEEP_RATIO"), 1.0)
    min_keep = _parse_int(os.environ.get("AOP_MIN_KEEP"), 64)
    use_bias = _parse_bool(os.environ.get("AOP_USE_BIAS"), True)
    # Per-modality switches.
    prune_vision = _parse_bool(os.environ.get("AOP_PRUNE_VISION"), True)
    prune_text = _parse_bool(os.environ.get("AOP_PRUNE_TEXT"), False)
    delta_v = _parse_float(os.environ.get("AOP_DELTA_VISION"), None)
    khat_v = _parse_float(os.environ.get("AOP_KHAT_VISION"), None)
    keep_ratio_v = _parse_float(os.environ.get("AOP_KEEP_RATIO_VISION"), None)
    min_keep_v = _parse_int(os.environ.get("AOP_MIN_KEEP_VISION"), None)
    delta_t = _parse_float(os.environ.get("AOP_DELTA_TEXT"), None)
    khat_t = _parse_float(os.environ.get("AOP_KHAT_TEXT"), None)
    keep_ratio_t = _parse_float(os.environ.get("AOP_KEEP_RATIO_TEXT"), None)
    min_keep_t = _parse_int(os.environ.get("AOP_MIN_KEEP_TEXT"), 32)
    protect_text_last = _parse_int(os.environ.get("AOP_PROTECT_TEXT_LAST"), 16)
    protect_special = _parse_bool(os.environ.get("AOP_PROTECT_SPECIAL"), True)
    margin_src = os.environ.get("AOP_MARGIN", "").strip().lower()  # "" or "mid"
    attn_impl = os.environ.get("AOP_ATTN_IMPL", "").strip().lower()  # "" or "sdpa"
    if layer_idx is None and enabled:
        # Cannot prune without a target layer: disable AOP and warn
        # (message kept verbatim from the original).
        logger.warning("AOP_ENABLED=1 但未设置 AOP_LAYER,关闭 AOP。")
        enabled = False
    # Token-selection strategy: aop | random | attention.
    # (The original parsed this block twice back to back; once is enough.)
    selection = os.environ.get("AOP_SELECTION", "aop").strip().lower()
    if _parse_bool(os.environ.get("AOP_RANDOM"), False):
        selection = "random"
    random_seed = _parse_int(os.environ.get("AOP_RANDOM_SEED"), None)
    attn_agg = os.environ.get("AOP_ATTENTION_AGG", "mean").strip().lower()  # mean|max|sum
    cfg = {
        "enabled": enabled,
        "apply_to": apply_to,
        "layer_idx": layer_idx,
        "mode": mode,
        # Fallbacks
        "delta": delta, "K_hat": khat,
        "keep_ratio": keep_ratio, "min_keep": min_keep,
        "use_bias": use_bias, "eps": 1e-6,
        # Modality switches
        "prune_vision": prune_vision,
        "prune_text": prune_text,
        # Vision bucket
        "delta_vision": delta_v,
        "K_hat_vision": khat_v,
        "keep_ratio_vision": keep_ratio_v,
        "min_keep_vision": min_keep_v,
        # Text bucket
        "delta_text": delta_t,
        "K_hat_text": khat_t,
        "keep_ratio_text": keep_ratio_t,
        "min_keep_text": min_keep_t,
        # Text protection
        "protect_text_last": protect_text_last,
        "protect_special": protect_special,
        # Optional ranking-safety budget
        "margin_mid": None if margin_src != "mid" else "USE_MID_MARGIN",
        "epsilon_hat": None,
        "attn_impl_override": attn_impl if attn_impl in {"sdpa"} else "",
        # Selection strategy
        "selection": selection,     # "aop" / "random" / "attention"
        "random_seed": random_seed, # optional
        "attn_agg": attn_agg,
    }
    return cfg
def get_env_eval_layers():
    """Resolve which LM layers to evaluate.

    Parses LM_LAYERS (preferred), e.g. "4,8,12,last": comma-separated
    positive ints, with 'last'/'none'/'null'/'-1' standing for the final
    layer (encoded as None). Falls back to the legacy MID_LM_LAYER logic
    when LM_LAYERS is unset or blank: MID_LM_LAYER=None -> [None];
    otherwise [mid, None].

    Returns:
        list[int | None]: e.g. [4, 8, 12, None]; None denotes the last layer.
    """
    raw = os.environ.get("LM_LAYERS", None)
    if raw is not None:
        raw = raw.strip()
        if raw:
            parsed = []
            for tok in (t.strip() for t in raw.split(',')):
                if not tok:
                    continue
                if tok.lower() in {"last", "none", "null", "-1"}:
                    parsed.append(None)
                    continue
                try:
                    num = int(tok)
                except Exception:
                    logger.warning(f"Invalid token '{tok}' in LM_LAYERS; must be int or 'last'/'none'.")
                    continue
                if num > 0:
                    parsed.append(num)
                else:
                    logger.warning(f"Ignoring non-positive layer '{tok}' in LM_LAYERS.")
            # De-duplicate while preserving order (None keyed as -1).
            seen = set()
            ordered = []
            for layer in parsed:
                marker = -1 if layer is None else layer
                if marker not in seen:
                    seen.add(marker)
                    ordered.append(layer)
            return ordered if ordered else [None]
    # Legacy fallback.
    mid = get_env_mid_layer()
    return [None] if mid is None else [mid, None]
# === Early-Exit config & helpers ===
def get_env_ee_config():
    """Collect the early-exit (EE) knobs from environment variables.

    Returns:
        dict with keys: enabled, layer (defaults to AOP_LAYER, then 12),
        method (margin|p1p2|entropy|gini|combined), tau, topk, temp, save,
        and the combined-method weights w_margin/w_conf/w_sq.
    """
    truthy = {"1", "true", "yes", "on", "y", "t"}
    ee_enabled = os.environ.get("EE_ENABLED", "0").strip().lower() in truthy
    layer = int(os.environ.get("EE_LAYER", os.environ.get("AOP_LAYER", "12")))
    method = os.environ.get("EE_METHOD", "margin").strip().lower()
    tau = float(os.environ.get("EE_TAU", "0.2"))
    topk = int(os.environ.get("EE_TOPK", "1024"))
    temp = float(os.environ.get("EE_TEMP", "0.05"))
    save = os.environ.get("EE_SAVE", "1").strip().lower() in truthy
    # Weights for the "combined" confidence; malformed input -> defaults.
    try:
        weights = [float(x) for x in os.environ.get("EE_COMB_WEIGHTS", "1.0,0.5,0.5").split(",")]
        w_margin, w_conf, w_sq = weights
    except Exception:
        w_margin, w_conf, w_sq = 1.0, 0.5, 0.5
    return dict(
        enabled=ee_enabled, layer=layer, method=method, tau=tau,
        topk=topk, temp=temp, save=save,
        w_margin=w_margin, w_conf=w_conf, w_sq=w_sq,
    )
def _softmax_np(x: np.ndarray, temp: float = 1.0) -> np.ndarray:
x = x - np.max(x)
ex = np.exp(x / max(1e-6, temp))
s = np.sum(ex)
return ex / max(s, 1e-12)
def confidence_from_topk(scores: np.ndarray, method="margin", temp=0.05, w_margin=1.0, w_conf=0.5, w_sq=0.5) -> float:
    """Turn a descending top-K score vector into one confidence value.

    Methods: "margin" (s1 - s2), "p1p2" (softmax prob gap), "entropy"
    (1 - normalized entropy), "gini" (sum of squared probs). Any other
    method returns the weighted combination of margin/entropy/gini.
    """
    if scores.size == 0:
        return 0.0
    if scores.size == 1:
        # A single candidate is treated as maximally confident.
        return 1e9
    margin = float(scores[0] - scores[1])
    probs = _softmax_np(scores, temp=temp)
    p1p2 = float(probs[0] - probs[1])
    # Normalized entropy in [0, 1]; confidence is its complement.
    norm_entropy = -float(np.sum(probs * np.log(probs + 1e-12))) / np.log(len(probs))
    conf = 1.0 - norm_entropy
    sqsum = float(np.sum(probs ** 2))  # Gini-style concentration (larger = peakier)
    per_method = {
        "margin": margin,
        "p1p2": p1p2,
        "entropy": conf,
        "gini": sqsum,
    }
    if method in per_method:
        return per_method[method]
    # "combined" (and unknown methods) fall through to the weighted mix.
    return w_margin * margin + w_conf * conf + w_sq * sqsum
# === Phase A: 轻量特征构造(I/T/I+T 通用,缺失模态置0 + one-hot 标识) ===
@torch.no_grad()
def build_phaseA_features_global(
    reps_mid_t: torch.Tensor,   # [B, D] query reps pooled at the mid layer (GPU)
    cand_mid_t: torch.Tensor,   # [Nc, D] candidate reps at the mid layer (GPU)
    am_mid: torch.Tensor,       # [B, L] 2-D attention mask after mid-layer pruning
    input_ids: torch.Tensor,    # [B, L] original token ids (used to tell vision/text apart)
    cfg,                        # model.encoder.config
    topk: int = 200,
    temp: float = 0.05
) -> torch.Tensor:
    """Build the 13-dim Phase-A gating features against the FULL candidate pool.

    Per-sample feature layout: [s1, s2, margin, H, sum_p2, L_txt, L_vis,
    L_tot, r_txt, r_vis, is_I, is_T, is_IT], where H is the normalized
    top-k softmax entropy and is_I/is_T/is_IT one-hot the query modality
    (image-only / text-only / image+text). Missing modalities simply
    contribute zero counts.

    Returns:
        torch.Tensor: [B, 13] feature matrix.
    """
    device = reps_mid_t.device  # NOTE(review): assigned but unused below
    B = reps_mid_t.size(0)
    # 1) Similarity top-K -> s1, s2, margin, entropy H, sum of squared probs.
    scores_t = reps_mid_t @ cand_mid_t.T  # [B, Nc]
    k = min(topk, scores_t.size(1))
    vals_t, _ = torch.topk(scores_t, k=k, dim=1)  # [B, k]
    s1 = vals_t[:, 0]
    s2 = vals_t[:, 1] if k >= 2 else torch.zeros_like(s1)
    margin = s1 - s2
    p_t = torch.softmax(vals_t / max(temp, 1e-6), dim=1)  # [B, k]
    # NOTE(review): math.log(max(k, 1)) is 0 when k == 1, which would divide
    # by zero; this appears to assume k >= 2 — confirm upstream guarantees.
    H = -(p_t * (torch.log(p_t + 1e-12))).sum(dim=1) / math.log(max(k, 1))  # normalized entropy
    sum_p2 = (p_t ** 2).sum(dim=1)
    # 2) Length / ratio features, based on the post-prune mid mask.
    am = am_mid.to(torch.bool)  # [B, L]
    iid = input_ids
    image_token_id = getattr(cfg, "image_token_id", None)
    video_token_id = getattr(cfg, "video_token_id", None)
    bos_id = getattr(cfg, "bos_token_id", None)
    eos_id = getattr(cfg, "eos_token_id", None)
    pad_id = getattr(cfg, "pad_token_id", None)
    is_image = (iid == image_token_id) if (image_token_id is not None and image_token_id >= 0) else torch.zeros_like(iid, dtype=torch.bool)
    is_video = (iid == video_token_id) if (video_token_id is not None and video_token_id >= 0) else torch.zeros_like(iid, dtype=torch.bool)
    is_vision = (is_image | is_video) & am
    is_special = torch.zeros_like(iid, dtype=torch.bool)
    for tid in [bos_id, eos_id, pad_id]:
        if tid is not None and tid >= 0:
            is_special |= (iid == tid)
    # Text tokens = surviving tokens that are neither vision nor special.
    is_text = (am & (~is_vision) & (~is_special))
    L_vis = is_vision.sum(dim=1).float()  # [B]
    L_txt = is_text.sum(dim=1).float()
    L_tot = am.sum(dim=1).float().clamp(min=1.0)  # clamp avoids div-by-zero
    r_vis = L_vis / L_tot
    r_txt = L_txt / L_tot
    # 3) Modality one-hot flags.
    is_I = ((L_vis > 0) & (L_txt == 0)).float()
    is_T = ((L_txt > 0) & (L_vis == 0)).float()
    is_IT = ((L_txt > 0) & (L_vis > 0)).float()
    # 4) Stack into the [B, F=13] feature matrix.
    feats = torch.stack([s1, s2, margin, H, sum_p2, L_txt, L_vis, L_tot, r_txt, r_vis, is_I, is_T, is_IT], dim=1)
    return feats  # [B, 13]
@torch.no_grad()
def build_phaseA_features_local(
    reps_mid_t: torch.Tensor,   # [B, D] query reps pooled at the mid layer (GPU)
    cand_mid_t: torch.Tensor,   # [Nc, D] candidate reps at the mid layer (GPU)
    am_mid: torch.Tensor,       # [B, L] post-prune attention mask
    input_ids: torch.Tensor,    # [B, L] original token ids
    cfg,                        # model.encoder.config
    per_sample_rows: list,      # list[list[int]]: candidate row indices per sample b
    topk: int = 200,
    temp: float = 0.05
) -> torch.Tensor:
    """Same 13-dim feature layout as ``build_phaseA_features_global``, but the
    similarity statistics are computed only over each sample's own candidate
    subset (``per_sample_rows``). Samples with an empty subset get neutral
    statistics (s1 = s2 = 0, H = 1, sum_p2 = 0).

    Returns:
        torch.Tensor: [B, 13] feature matrix.
    """
    device = reps_mid_t.device
    B = reps_mid_t.size(0)
    s1_list, s2_list, H_list, sum_p2_list = [], [], [], []
    for b in range(B):
        rows = per_sample_rows[b]
        if len(rows) == 0:
            # No candidates for this sample: fall back to neutral statistics.
            s1_list.append(torch.tensor(0.0, device=device))
            s2_list.append(torch.tensor(0.0, device=device))
            H_list.append(torch.tensor(1.0, device=device))
            sum_p2_list.append(torch.tensor(0.0, device=device))
            continue
        cmat_t = cand_mid_t[rows]  # [Nl, D]
        sv_t = (reps_mid_t[b:b+1] @ cmat_t.T)[0]  # [Nl]
        k = min(topk, sv_t.size(0))
        vals_t, _ = torch.topk(sv_t, k=k, dim=0)
        s1_list.append(vals_t[0])
        s2_list.append(vals_t[1] if k >= 2 else torch.tensor(0.0, device=device, dtype=vals_t.dtype))
        p_t = torch.softmax(vals_t / max(temp, 1e-6), dim=0)
        # NOTE(review): math.log(max(k, 1)) is 0 when k == 1 -> division by
        # zero; this appears to assume k >= 2 — confirm upstream guarantees.
        H_list.append((-(p_t * (torch.log(p_t + 1e-12))).sum() / math.log(max(k, 1))))
        sum_p2_list.append((p_t ** 2).sum())
    s1 = torch.stack(s1_list)
    s2 = torch.stack(s2_list)
    H = torch.stack(H_list)
    sum_p2 = torch.stack(sum_p2_list)
    margin = s1 - s2
    # Length / modality features (same construction as the global variant).
    am = am_mid.to(torch.bool)
    iid = input_ids
    image_token_id = getattr(cfg, "image_token_id", None)
    video_token_id = getattr(cfg, "video_token_id", None)
    bos_id = getattr(cfg, "bos_token_id", None)
    eos_id = getattr(cfg, "eos_token_id", None)
    pad_id = getattr(cfg, "pad_token_id", None)
    is_image = (iid == image_token_id) if (image_token_id is not None and image_token_id >= 0) else torch.zeros_like(iid, dtype=torch.bool)
    is_video = (iid == video_token_id) if (video_token_id is not None and video_token_id >= 0) else torch.zeros_like(iid, dtype=torch.bool)
    is_vision = (is_image | is_video) & am
    is_special = torch.zeros_like(iid, dtype=torch.bool)
    for tid in [bos_id, eos_id, pad_id]:
        if tid is not None and tid >= 0:
            is_special |= (iid == tid)
    is_text = (am & (~is_vision) & (~is_special))
    L_vis = is_vision.sum(dim=1).float()  # [B]
    L_txt = is_text.sum(dim=1).float()
    L_tot = am.sum(dim=1).float().clamp(min=1.0)
    r_vis = L_vis / L_tot
    r_txt = L_txt / L_tot
    is_I = ((L_vis > 0) & (L_txt == 0)).float()
    is_T = ((L_txt > 0) & (L_vis == 0)).float()
    is_IT = ((L_txt > 0) & (L_vis > 0)).float()
    feats = torch.stack([s1, s2, margin, H, sum_p2, L_txt, L_vis, L_tot, r_txt, r_vis, is_I, is_T, is_IT], dim=1)
    return feats  # [B, 13]
def load_phaseA_classifier(device: torch.device):
    """Load the Phase-A exit classifier and its feature scaler from env vars.

    Environment variables:
        EE_CTRL_CKPT: classifier weights (.pt state_dict).
        EE_CTRL_SCALER: standardization json {"mean": [...], "std": [...], "in_dim": 13}.
        EE_CTRL_IN_DIM: input feature dimension (default 13).
        EE_TAU_CTRL: early-exit threshold (defaults to EE_TAU, then 0.2).
        EE_FEAT_TOPK: top-K used when building features (default 200).

    Returns:
        tuple: (model or None, (mean, std) or None, tau_ctrl, feat_topk).
    """
    ckpt = os.environ.get("EE_CTRL_CKPT", "").strip()
    in_dim = int(os.environ.get("EE_CTRL_IN_DIM", "13"))
    scaler_path = os.environ.get("EE_CTRL_SCALER", "").strip()
    tau_ctrl = float(os.environ.get("EE_TAU_CTRL", os.environ.get("EE_TAU", "0.2")))
    feat_topk = int(os.environ.get("EE_FEAT_TOPK", "200"))
    model = None
    scaler = None
    if ckpt and os.path.exists(ckpt):
        model = EEExitClassifier(in_dim=in_dim, hidden=128, dropout=0.1).to(device)
        # NOTE(review): torch.load without weights_only=True runs pickle;
        # acceptable only for trusted local checkpoints.
        sd = torch.load(ckpt, map_location=device)
        model.load_state_dict(sd)
        model.eval()
    if scaler_path and os.path.exists(scaler_path):
        with open(scaler_path, "r", encoding="utf-8") as f:
            stat = json.load(f)
        mean = torch.tensor(stat.get("mean", [0.0]*in_dim), dtype=torch.float32, device=device)
        # std clamped away from zero so standardization never divides by 0
        std = torch.tensor(stat.get("std", [1.0]*in_dim), dtype=torch.float32, device=device).clamp(min=1e-6)
        scaler = (mean, std)
    return model, scaler, tau_ctrl, feat_topk
def run_early_exit_queries(
    model: MMEBModel,
    processor,
    model_args: ModelArguments,
    data_args: DataArguments,
    training_args: TrainingArguments,
    qry_dataset: Dataset,
    cand_mid_dict: dict,
    cand_last_dict: dict,
    ee_cfg: dict,
    dataset_name: str,
    out_dir: str,
    global_ranking: bool = True,
):
    """Encode queries with a mid-layer early-exit gate and rank candidates.

    Each query batch is first run up to ee_cfg["layer"]; a per-sample gate
    (the Phase-A classifier when a checkpoint is configured, otherwise a
    top1-top2 margin threshold) decides whether the mid-layer ranking is
    final. Remaining samples resume from the saved intermediate state to the
    last layer. cand_mid_dict / cand_last_dict map candidate id -> embedding
    at the mid / last layer. global_ranking=False ranks each query only
    against its own candidate subset (infos[b]["cand_names"]).

    Returns:
        tuple: (score dict from RankingMetrics, elapsed wall-clock seconds).
    """
    device = training_args.device
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    is_main = (not dist.is_initialized()) or (local_rank == 0)
    # Candidate matrices -> GPU (bfloat16).
    cand_ids = list(cand_mid_dict.keys())
    cand_id2row = {str(cid): i for i, cid in enumerate(cand_ids)}
    cand_mid = np.stack([cand_mid_dict[c] for c in cand_ids]).astype(np.float32)
    cand_last = np.stack([cand_last_dict[c] for c in cand_ids]).astype(np.float32)
    # Move cand_mid to GPU now; cand_last is deferred until a resume needs it.
    cand_mid_t = torch.from_numpy(cand_mid).to(device=device, dtype=torch.bfloat16)
    cand_last_t = None  # lazily built on first use inside the need_idx > 0 branch
    # Phase-A classifier (None when no CKPT is configured -> margin fallback).
    ctrl_model, ctrl_scaler, tau_ctrl, feat_topk = load_phaseA_classifier(device)
    # DataLoader over queries only.
    collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry")
    loader = DataLoader(
        qry_dataset,
        batch_size=training_args.per_device_eval_batch_size,
        collate_fn=collator,
        num_workers=training_args.dataloader_num_workers
    )
    pred_dicts = []
    details = []  # NOTE(review): never appended to below — appears unused
    # AOP side gating (only the query side matters here).
    aop_cfg = getattr(model.encoder, "aop_prune_config", None)
    _orig_enabled = None
    side_enable = True
    if isinstance(aop_cfg, dict) and aop_cfg:
        _orig_enabled = aop_cfg.get("enabled", False)
        apply_to = aop_cfg.get("apply_to", "qry")
        side_enable = (apply_to == "both") or (apply_to == "qry")
    # Gate k (margin / p1p2 only need the top-2).
    k_conf = 2
    tau = float(ee_cfg["tau"])
    method= ee_cfg["method"]
    temp = float(ee_cfg["temp"])
    start_time = time.time()
    idx_global = 0
    for inputs, infos in tqdm(loader, desc=f"[EE] {dataset_name}@L{ee_cfg['layer']} (rank {local_rank})", disable=local_rank>0):
        inputs = batch_to_device(inputs, device)
        orig_cfg = None
        if isinstance(aop_cfg, dict) and aop_cfg:
            orig_cfg = dict(aop_cfg)  # back up; restored after the mid pass
            aop_layer = aop_cfg.get("layer_idx", None)
            ee_layer = int(ee_cfg["layer"])
            apply_to = aop_cfg.get("apply_to", "qry").strip().lower()
            # Enable AOP during the mid pass only when the prune layer fires
            # strictly before the exit layer and this side is in scope.
            aop_on_mid = bool(
                _orig_enabled and side_enable and
                (aop_layer is not None) and (aop_layer < ee_layer) and
                (apply_to in {"qry", "both"})
            )
            aop_cfg_mid = dict(aop_cfg)
            aop_cfg_mid["enabled"] = aop_on_mid
            setattr(model.encoder, "aop_prune_config", aop_cfg_mid)
        # 1) First half: run up to the mid layer (stop_at_layer), no logits.
        with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
            out_mid = model.encoder(
                **inputs,
                return_dict=True,
                output_hidden_states=False,
                stop_at_layer=int(ee_cfg["layer"]),
                compute_lm_head=False,  # skip the LM head
            )
        # Restore the original AOP config so the resume pass is unaffected.
        if isinstance(orig_cfg, dict):
            setattr(model.encoder, "aop_prune_config", orig_cfg)
        # EOS pooling -> GPU.
        hs_mid = getattr(out_mid, "last_hidden_state", None)
        if hs_mid is None:
            assert out_mid.hidden_states is not None and len(out_mid.hidden_states) > 0
            hs_mid = out_mid.hidden_states[-1]
        am_mid = getattr(out_mid, "attention_mask", None)
        if am_mid is None:
            am_mid = inputs.get("attention_mask", None)
        if hasattr(am_mid, "device") and am_mid.device != hs_mid.device:
            am_mid = am_mid.to(hs_mid.device)
        reps_mid_t = model._pooling(hs_mid, am_mid).detach().to(device=device, dtype=torch.bfloat16)  # [B,D]
        B = reps_mid_t.size(0)
        use_local = (not global_ranking)
        # 2) Prefer the classifier gate; otherwise fall back to margin gating.
        use_ctrl = (ctrl_model is not None)
        if use_ctrl:
            # Build the Phase-A lightweight features (GPU).
            if not use_local:
                feats_t = build_phaseA_features_global(
                    reps_mid_t=reps_mid_t,
                    cand_mid_t=cand_mid_t,
                    am_mid=am_mid,
                    input_ids=inputs["input_ids"],
                    cfg=model.encoder.config,
                    topk=feat_topk,
                    temp=float(ee_cfg["temp"]),
                )  # [B,13]
            else:
                rows_list = []
                for b in range(B):
                    cand_local = infos[b]["cand_names"]
                    rows = [cand_id2row.get(str(cid), -1) for cid in cand_local]
                    rows = [r for r in rows if r >= 0]
                    rows_list.append(rows)
                feats_t = build_phaseA_features_local(
                    reps_mid_t=reps_mid_t,
                    cand_mid_t=cand_mid_t,
                    am_mid=am_mid,
                    input_ids=inputs["input_ids"],
                    cfg=model.encoder.config,
                    per_sample_rows=rows_list,
                    topk=feat_topk,
                    temp=float(ee_cfg["temp"]),
                )  # [B,13]
            # Standardize features with the loaded scaler, when present.
            if ctrl_scaler is not None:
                mean, std = ctrl_scaler
                feats_t = (feats_t - mean) / std
            # Classifier inference -> per-sample exit probability.
            with torch.no_grad():
                p_exit = ctrl_model(feats_t.float()).detach().cpu().numpy()  # [B]
            exit_mask = (p_exit >= tau_ctrl)
        else:
            # Margin fallback (same as the original gating logic):
            # global -> top-2 margin over all candidates; local -> over the subset.
            if not use_local:
                scores_t = reps_mid_t @ cand_mid_t.T
                vals_t, _ = torch.topk(scores_t, k=2, dim=1)
                margin_t = (vals_t[:, 0] - vals_t[:, 1]).detach().cpu().numpy()
                exit_mask = (margin_t >= float(ee_cfg["tau"]))
            else:
                margins = []
                for b in range(B):
                    cand_local = infos[b]["cand_names"]
                    rows = [cand_id2row.get(str(cid), -1) for cid in cand_local]
                    rows = [r for r in rows if r >= 0]
                    if len(rows) == 0:
                        margins.append(0.0)
                        continue
                    cmat_t = cand_mid_t[rows]
                    sv_t = (reps_mid_t[b:b+1] @ cmat_t.T)[0]
                    k2 = min(2, sv_t.size(0))
                    vals_t, _ = torch.topk(sv_t, k=k2, dim=0)
                    if k2 >= 2:
                        margins.append(float((vals_t[0] - vals_t[1]).item()))
                    else:
                        # Only one candidate: treat as maximally confident.
                        margins.append(float("inf"))
                exit_mask = (np.array(margins) >= float(ee_cfg["tau"]))
        # 3) Early exit: rank directly with the mid-layer reps.
        for j in np.where(exit_mask)[0].tolist():
            if not use_local:
                scores_j = (reps_mid_t[j:j+1] @ cand_mid_t.T)[0]  # [Nc]
                order = torch.argsort(scores_j, dim=0, descending=True).detach().cpu().numpy()
                cids = [cand_ids[i] for i in order]
            else:
                cand_local = infos[j]["cand_names"]
                rows = [cand_id2row.get(str(cid), -1) for cid in cand_local]
                rows = [r for r in rows if r >= 0]
                if len(rows) == 0:
                    cids = []
                else:
                    cmat_t = cand_mid_t[rows]
                    vec = (reps_mid_t[j:j+1] @ cmat_t.T)[0]
                    order_local = torch.argsort(vec, dim=0, descending=True).detach().cpu().numpy()
                    cids = [str(cand_local[i]) for i in order_local]
            rel_docids = infos[j]["label_name"] if isinstance(infos[j]["label_name"], list) else [infos[j]["label_name"]]
            pred_dicts.append({"prediction": cids, "label": rel_docids, "rel_scores": None})
        # 4) Resume: only the non-exited subset continues from the saved mid
        # state to the last layer (still skipping logits).
        need_idx = np.where(~exit_mask)[0].tolist()
        if len(need_idx) > 0:
            if cand_last_t is None:
                cand_last_t = torch.from_numpy(cand_last).to(device=device, dtype=torch.bfloat16)
            if isinstance(aop_cfg, dict) and aop_cfg:
                aop_resume = dict(aop_cfg)
                aop_resume["enabled"] = bool(_orig_enabled and side_enable)
                setattr(model.encoder, "aop_prune_config", aop_resume)
            interm = getattr(out_mid, "intermediate_state", None)
            assert interm is not None, "Model must return intermediate_state when stop_at_layer is set."
            hs = interm["hidden_states"].detach()
            am = interm["attention_mask"].detach()
            pos = interm["position_ids"].detach()
            vm = interm.get("vision_mask", None)
            tm = interm.get("text_mask", None)
            next_layer = int(interm["next_layer_idx"])
            # Slice the batch down to the samples that still need the tail.
            hs_sub = hs[need_idx]
            am_sub = am[need_idx]
            pos_sub = pos[:, need_idx, :]
            vm_sub = vm[need_idx] if vm is not None else None
            tm_sub = tm[need_idx] if tm is not None else None
            resume_state = {
                "hidden_states": hs_sub,
                "attention_mask": am_sub,
                "position_ids": pos_sub,
                "vision_mask": vm_sub,
                "text_mask": tm_sub,
                "next_layer_idx": next_layer,
            }
            with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                out_last = model.encoder(
                    return_dict=True,
                    output_hidden_states=False,
                    stop_at_layer=None,
                    resume_state=resume_state,
                    compute_lm_head=False,  # key: still no logits
                )
            hs_last = getattr(out_last, "last_hidden_state", None)
            if hs_last is None:
                assert out_last.hidden_states is not None and len(out_last.hidden_states) > 0
                hs_last = out_last.hidden_states[-1]
            am_last = getattr(out_last, "attention_mask", None)
            if am_last is None:
                am_last = am_sub
            if hasattr(am_last, "device") and am_last.device != hs_last.device:
                am_last = am_last.to(hs_last.device)
            reps_last_t = model._pooling(hs_last, am_last).detach().to(device=device, dtype=torch.bfloat16)
            if not use_local:
                scores_last_t = reps_last_t @ cand_last_t.T
                order_t = torch.argsort(scores_last_t, dim=1, descending=True)
                for k, j in enumerate(need_idx):
                    order = order_t[k].detach().cpu().numpy()
                    cids = [cand_ids[i] for i in order]
                    rel_docids = infos[j]["label_name"] if isinstance(infos[j]["label_name"], list) else [infos[j]["label_name"]]
                    pred_dicts.append({"prediction": cids, "label": rel_docids, "rel_scores": None})
            else:
                for k, j in enumerate(need_idx):
                    cand_local = infos[j]["cand_names"]
                    rows = [cand_id2row.get(str(cid), -1) for cid in cand_local]
                    rows = [r for r in rows if r >= 0]
                    if len(rows) == 0:
                        cids = []
                    else:
                        cmat_last_t = cand_last_t[rows]
                        vec_t = (reps_last_t[k:k+1] @ cmat_last_t.T)[0]
                        order_local = torch.argsort(vec_t, dim=0, descending=True).detach().cpu().numpy()
                        cids = [str(cand_local[i]) for i in order_local]
                    rel_docids = infos[j]["label_name"] if isinstance(infos[j]["label_name"], list) else [infos[j]["label_name"]]
                    pred_dicts.append({"prediction": cids, "label": rel_docids, "rel_scores": None})
        idx_global += B
    # Evaluate and persist (main rank only writes files).
    metrics_to_report = ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"]
    score = RankingMetrics(metrics_to_report).evaluate(pred_dicts)
    if is_main:
        os.makedirs(out_dir, exist_ok=True)
        with open(os.path.join(out_dir, f"{dataset_name}_score_earlyexit.json"), "w") as f:
            json.dump(score, f, indent=4)
        # For timing runs prefer EE_SAVE=0 so per-query details are not written.
        if ee_cfg.get("save", False):
            with open(os.path.join(out_dir, f"{dataset_name}_pred_earlyexit.jsonl"), "w", encoding="utf-8") as f:
                for row in pred_dicts: f.write(json.dumps(row, ensure_ascii=False) + "\n")
    elapsed = time.time() - start_time
    return score, elapsed
def make_layer_tag(keep_layers: int | None):
return f"layer{keep_layers}" if keep_layers and keep_layers > 0 else "layerlast"
def dot_sim(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Pairwise dot-product similarity: a [Nq, D] x b [Nc, D] -> [Nq, Nc].

    Equals cosine similarity when both sides were already L2-normalized.
    """
    return np.matmul(a, b.T)
def build_score_details(qid: int, cand_ids: list, score_vec: np.ndarray, ranked_indices: np.ndarray):
    """Serialize one query's ranked candidate scores into a JSON-friendly dict."""
    ranked_scores = []
    for idx in ranked_indices:
        ranked_scores.append({"cand_id": str(cand_ids[idx]), "score": float(score_vec[idx])})
    return {"qid": int(qid), "cand_scores": ranked_scores}
def top1_top2_margin(score_vec: np.ndarray) -> float:
    """Gap between the best and second-best score; +inf for < 2 candidates."""
    if len(score_vec) < 2:
        # A single candidate is treated as maximally confident.
        return float("inf")
    best_two = np.sort(np.partition(score_vec, -2)[-2:])
    return float(best_two[-1] - best_two[-2])
def simulate_early_exit_by_margin(
    sims_mid: list[dict], sims_last: list[dict], labels: list[list[str]], metrics_to_report: list[str],
    taus: list[float], rank_global: bool
):
    """Offline early-exit simulation over a tau sweep.

    For each tau: queries whose mid-layer top1-top2 margin >= tau are ranked
    with mid-layer scores, the rest with last-layer scores, and the mixture
    is evaluated.

    Args:
        sims_mid / sims_last: one {cand_id: score} dict per query.
        labels: positive cand_id list per query.
        metrics_to_report: metric names passed to RankingMetrics.
        taus: thresholds to sweep.
        rank_global: accepted for interface parity (not referenced below).

    Returns:
        list[dict]: one row per tau with "tau", early-exit "coverage", and
        the metric scores from RankingMetrics.evaluate.
    """
    assert len(sims_mid) == len(sims_last) == len(labels)
    N = len(labels)
    results = []
    from src.eval_utils.metrics import RankingMetrics
    metrics = RankingMetrics(metrics_to_report)
    # Build the pred_dict list consumed by metrics.evaluate for a given mask.
    def to_pred_dicts(use_mid_mask: list[bool]) -> list[dict]:
        pred_dicts = []
        for qid in range(N):
            sims_use = sims_mid[qid] if use_mid_mask[qid] else sims_last[qid]
            # Rank candidates by descending score.
            ranked = sorted(sims_use.items(), key=lambda x: -x[1])
            pred_dicts.append({
                "prediction": [cid for cid, _ in ranked],
                "label": labels[qid],
                "rel_scores": None
            })
        return pred_dicts
    # Mid-layer top1-top2 margin per query (0.0 when a query has no scores).
    margins = []
    for qid in range(N):
        if len(sims_mid[qid]) == 0:
            margins.append(0.0)
            continue
        scores = np.array(list(sims_mid[qid].values()), dtype=np.float32)
        margins.append(top1_top2_margin(scores))
    margins = np.array(margins, dtype=np.float32)
    for tau in taus:
        use_mid_mask = (margins >= tau).tolist()
        pred_dicts = to_pred_dicts(use_mid_mask)
        score_dict = metrics.evaluate(pred_dicts)
        coverage = float(np.mean(use_mid_mask))  # fraction of queries that early-exit
        results.append({
            "tau": tau,
            "coverage": coverage,
            **score_dict
        })
    return results
def top1_top2_margin_from_array(score_vec: np.ndarray) -> float:
    """Like top1_top2_margin, but None/empty input maps to 0.0 instead of raising."""
    if score_vec is None or len(score_vec) == 0:
        return 0.0
    if len(score_vec) == 1:
        # Single candidate: maximally confident.
        return float('inf')
    # Isolate the two largest values, then subtract smaller from larger.
    pair = np.partition(score_vec, -2)[-2:]
    return float(np.max(pair)) - float(np.min(pair))
# Module-wide logging setup; `logger` is used by the env parsers above.
logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s')
logger = logging.getLogger(__name__)
# --- Global Dictionaries for Hooks (will be cleared before each encode_embeddings call) ---
# timing_info maps id(module) -> [(timestamp, 'pre'|'post', class_name), ...]
timing_info = {}
token_info = {
    "vision_tokens": 0,
    "text_input_tokens": 0,  # Refers to the original text token count
    "text_output_tokens": 0,  # Not directly applicable here as we are encoding, not generating. Will be 0.
    "total_llm_input_tokens": 0,  # Refers to the total tokens LLM receives (visual + formatted text)
}
# --- Hook Functions Definition ---
def timing_pre_hook(module, input):
    """Forward pre-hook: append an entry timestamp for *module* to timing_info."""
    record = (time.time(), 'pre', module.__class__.__name__)
    timing_info.setdefault(id(module), []).append(record)
def timing_post_hook(module, input, output):
    """Forward post-hook: record the exit timestamp, and capture the vision
    token count when the module name looks like the vision transformer."""
    key = id(module)
    if key not in timing_info:
        # No matching pre-hook record; skip rather than create a dangling one.
        return
    cls_name = module.__class__.__name__
    timing_info[key].append((time.time(), 'post', cls_name))
    lowered = cls_name.lower()
    if "vision" in lowered and "transformer" in lowered:
        if isinstance(output, torch.Tensor):
            # NOTE(review): reads dim 0 — assumes the visual encoder emits a
            # (num_tokens, hidden) tensor rather than a batched layout; confirm.
            token_info["vision_tokens"] = output.shape[0]
        elif hasattr(output, 'last_hidden_state'):
            token_info["vision_tokens"] = output.last_hidden_state.shape[1]
def register_model_hooks(model):
    """Attach the timing pre/post hooks to the major sub-modules of *model*.

    Hooks are registered, when present, on: the vision tower (`visual`), its
    `merger`, the LLM body (`model`), and the `lm_head`. A missing module
    only logs a warning so this works across architectures.

    Returns:
        list: the modules that received hooks (empty when none were found).
    """
    registered_modules = []
    # Hook the wrapped encoder when one is exposed, otherwise the model itself.
    core_model = model.encoder if hasattr(model, "encoder") and model.encoder is not None else model
    # Vision module
    if hasattr(core_model, 'visual') and core_model.visual is not None:
        vision_module = core_model.visual
        vision_module.register_forward_pre_hook(timing_pre_hook)
        vision_module.register_forward_hook(timing_post_hook)
        registered_modules.append(vision_module)
        print_master(f"Registered hooks for vision module: {vision_module.__class__.__name__}")
    else:
        print_master(f"WARNING: No 'visual' attribute found on core_model ({type(core_model)}).")
    # Merger module (if inside visual) - it's part of the vision component
    if hasattr(core_model, 'visual') and hasattr(core_model.visual, 'merger') and core_model.visual.merger is not None:
        merger_module = core_model.visual.merger
        merger_module.register_forward_pre_hook(timing_pre_hook)
        merger_module.register_forward_hook(timing_post_hook)
        registered_modules.append(merger_module)
        print_master(f"Registered hooks for merger module: {merger_module.__class__.__name__}")
    else:
        print_master(f"WARNING: No 'merger' attribute found on core_model.visual ({type(getattr(core_model, 'visual', 'N/A'))}).")
    # Language model body
    if hasattr(core_model, 'model') and core_model.model is not None:
        llm_main_module = core_model.model
        llm_main_module.register_forward_pre_hook(timing_pre_hook)
        llm_main_module.register_forward_hook(timing_post_hook)
        registered_modules.append(llm_main_module)
        print_master(f"Registered hooks for LLM main module: {llm_main_module.__class__.__name__}")
    else:
        print_master(f"WARNING: No 'model' attribute found on core_model ({type(core_model)}).")
    # LM Head
    if hasattr(core_model, 'lm_head') and core_model.lm_head is not None:
        lm_head_module = core_model.lm_head
        lm_head_module.register_forward_pre_hook(timing_pre_hook)
        lm_head_module.register_forward_hook(timing_post_hook)
        registered_modules.append(lm_head_module)
        print_master(f"Registered hooks for LM head module: {lm_head_module.__class__.__name__}")
    else:
        print_master(f"WARNING: No 'lm_head' attribute found on core_model ({type(core_model)}).")
    if not registered_modules:
        print_master("Warning: No major modules found for hook registration. Check model architecture.")
    return registered_modules
def pad_dataset_to_divisible(dataset, world_size):
    """Pad *dataset* (by cycling its own rows) until len % world_size == 0.

    Returns:
        tuple: (possibly padded dataset, its resulting length).
    """
    total = len(dataset)
    remainder = total % world_size
    if remainder == 0:
        return dataset, total
    extra = world_size - remainder
    filler = dataset.select([i % total for i in range(extra)])
    return concatenate_datasets([dataset, filler]), total + extra
def encode_embeddings(
    model: MMEBModel,
    loader: DataLoader,
    training_args: TrainingArguments,
    model_args: ModelArguments,
    full_dataset: Dataset,
    encode_side: str,
    description: str = "Encoding"
) -> tuple[np.ndarray, list, list, list, list, list]:  # FIXED: annotation now matches the 6-value return
    """
    Encodes embeddings for a given dataset using the model, handling both standard and
    late-interaction models in a DDP-safe manner.

    Args:
        model: Wrapped MMEB model used for encoding.
        loader: DataLoader yielding (inputs, dataset_info) pairs.
        training_args: Provides the target device.
        model_args: Provides the backbone name (late-interaction detection).
        full_dataset: Full (padded) dataset; its num_rows sizes the DDP gather buffer.
        encode_side: "qry" or "cand"; selects the model side and gates AOP pruning.
        description: Progress-bar description prefix.

    Returns:
        - embeddings: np.ndarray
        - infos_or_ids: list
        - batch_stats_list: list
        - img_token_masks: list[None | list[bool]]
        - txt_token_masks: list[None | list[bool]]
        - token_records: list[dict] of per-sample pre/post/delta token counts
    """
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    # Check if the model is a late-interaction type
    is_late_interaction = (model_args.model_backbone == COLPALI)
    local_embeds = []
    local_gt_infos = []
    local_max_len = 0
    # Statistics collected for each batch
    batch_stats_list = []
    # Post-prune masks collected per sample
    local_img_token_masks = []   # post image mask per sample
    local_txt_token_masks = []   # post text mask per sample
    local_post_attn_masks = []   # post attention_mask per sample (after prune, 1/0)
    # Per-sample token reduction records: one dict per sample with pre/post/delta counts
    local_token_records = []
    model.eval()
    # Register hooks for the model once per encode_embeddings call.
    # NOTE(review): hooks are never removed here; confirm register_model_hooks is
    # idempotent or returns removable handles, otherwise repeated calls stack hooks.
    registered_hooks = register_model_hooks(model)

    # --- Helpers to fetch masks and make them serializable ---
    def _search_key(obj, key: str):
        # Recursively search nested dict/list/tuple structures for the given key.
        if isinstance(obj, dict):
            if key in obj:
                return obj[key]
            for v in obj.values():
                r = _search_key(v, key)
                if r is not None:
                    return r
        elif isinstance(obj, (list, tuple)):
            for v in obj:
                r = _search_key(v, key)
                if r is not None:
                    return r
        return None

    def _to_serializable_mask_list(mask_list, batch_size: int):
        # Convert a model-returned mask (list/tensor/ndarray/None) into
        # [None | list] * batch_size so it can be gathered/pickled.
        if mask_list is None:
            return [None] * batch_size
        out = []
        if isinstance(mask_list, (list, tuple)):
            for m in mask_list:
                if m is None:
                    out.append(None)
                elif torch.is_tensor(m):
                    out.append(m.detach().cpu().tolist())
                elif isinstance(m, np.ndarray):
                    out.append(m.tolist())
                else:
                    # already python list/bool
                    out.append(m)
        elif torch.is_tensor(mask_list):
            # A 2D tensor (B, L) becomes list[list[bool/int]] via tolist()
            out = mask_list.detach().cpu().tolist()
        elif isinstance(mask_list, np.ndarray):
            out = mask_list.tolist()
        else:
            # Unknown type: conservatively fill with None placeholders
            out = [None] * batch_size
        # Pad/trim so the length matches batch_size
        if isinstance(out, list):
            if len(out) < batch_size:
                out = out + [None] * (batch_size - len(out))
            elif len(out) > batch_size:
                out = out[:batch_size]
        return out

    def _to_bool_lists(m, batch_size: int):
        lst = _to_serializable_mask_list(m, batch_size)
        # Normalize into list[ list[bool] | None ]
        out = []
        for x in lst:
            if x is None:
                out.append(None)
            else:
                # x may be list[int] or list[bool]
                out.append([bool(int(v)) for v in x])
        return out

    with torch.no_grad():
        for inputs, dataset_info in tqdm(loader, desc=f"{description} (rank {local_rank})", disable=local_rank > 0):
            # Reset statistics for each inference pass
            timing_info.clear()
            token_info["vision_tokens"] = 0
            token_info["text_input_tokens"] = 0
            token_info["text_output_tokens"] = 0
            token_info["total_llm_input_tokens"] = 0
            inputs = batch_to_device(inputs, training_args.device)
            current_batch_size = inputs['input_ids'].shape[0] if 'input_ids' in inputs and inputs['input_ids'] is not None else 1
            with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
                start_inference_time = time.time()
                # Enable/disable AOP pruning depending on which side we encode
                aop_cfg = getattr(model.encoder, "aop_prune_config", None)
                _orig_enabled = None
                if isinstance(aop_cfg, dict) and aop_cfg:
                    _orig_enabled = aop_cfg.get("enabled", False)
                    apply_to = aop_cfg.get("apply_to", "qry")
                    side_enable = (apply_to == "both") or (apply_to == encode_side)
                    aop_cfg["enabled"] = bool(side_enable and _orig_enabled)
                    setattr(model.encoder, "aop_prune_config", aop_cfg)
                if encode_side == "qry":
                    output = model(qry=inputs)
                    reps = output["qry_reps"].detach()
                    local_gt_infos.extend(dataset_info)
                else:
                    output = model(tgt=inputs)
                    reps = output["tgt_reps"].detach()
                    local_gt_infos.extend([info["cand_name"] for info in dataset_info])
                # Restore "enabled" so the next encode_side is not affected
                if isinstance(aop_cfg, dict) and _orig_enabled is not None:
                    aop_cfg["enabled"] = _orig_enabled
                    setattr(model.encoder, "aop_prune_config", aop_cfg)
                end_inference_time = time.time()
            # Extract the post-prune image/text masks and the post attention_mask
            img_masks_raw = None
            txt_masks_raw = None
            post_attn_raw = None
            if isinstance(output, dict):
                img_masks_raw = _search_key(output, "image_token_bool_masks")
                txt_masks_raw = _search_key(output, "text_token_bool_masks")
                post_attn_raw = _search_key(output, "post_attention_mask")  # MMEBModel.forward carries this key
            # Fallback: masks may be attached directly to the model object
            if img_masks_raw is None and hasattr(model, "image_token_bool_masks"):
                img_masks_raw = getattr(model, "image_token_bool_masks")
            if txt_masks_raw is None and hasattr(model, "text_token_bool_masks"):
                txt_masks_raw = getattr(model, "text_token_bool_masks")
            if post_attn_raw is None and hasattr(model, "post_attention_mask"):
                post_attn_raw = getattr(model, "post_attention_mask")
            img_masks_serializable = _to_serializable_mask_list(img_masks_raw, current_batch_size)
            txt_masks_serializable = _to_serializable_mask_list(txt_masks_raw, current_batch_size)
            post_attn_serializable = _to_serializable_mask_list(post_attn_raw, current_batch_size)
            local_img_token_masks.extend(img_masks_serializable)
            local_txt_token_masks.extend(txt_masks_serializable)
            local_post_attn_masks.extend(post_attn_serializable)
            # Compute this batch's pre/post/delta token counts and accumulate
            cfg = getattr(model.encoder, "config", None)
            # Pre-prune masks come from the inputs (before any deletion)
            input_ids = inputs.get("input_ids", None)
            attn2d_pre = inputs.get("attention_mask", None)
            if input_ids is None or attn2d_pre is None or cfg is None:
                # Cannot compute stats; leave zeros
                pre_vis_counts = [0] * current_batch_size
                pre_txt_counts = [0] * current_batch_size
                pre_tot_counts = [0] * current_batch_size
            else:
                iid = input_ids
                am = attn2d_pre.to(torch.bool)
                image_token_id = getattr(cfg, "image_token_id", None)
                video_token_id = getattr(cfg, "video_token_id", None)
                bos_id = getattr(cfg, "bos_token_id", None)
                eos_id = getattr(cfg, "eos_token_id", None)
                pad_id = getattr(cfg, "pad_token_id", None)
                is_image = (iid == image_token_id) if (image_token_id is not None and image_token_id >= 0) else torch.zeros_like(iid, dtype=torch.bool)
                is_video = (iid == video_token_id) if (video_token_id is not None and video_token_id >= 0) else torch.zeros_like(iid, dtype=torch.bool)
                is_vision = is_image | is_video
                is_special = torch.zeros_like(iid, dtype=torch.bool)
                for tid in [bos_id, eos_id, pad_id]:
                    if tid is not None and tid >= 0:
                        is_special |= (iid == tid)
                pre_txt_mask = am & (~is_vision) & (~is_special)
                pre_vis_mask = am & is_vision
                pre_vis_counts = pre_vis_mask.sum(dim=1).tolist()
                pre_txt_counts = pre_txt_mask.sum(dim=1).tolist()
                pre_tot_counts = am.sum(dim=1).tolist()
            # Post-prune masks come from the model output; AND them with post_attn
            post_text_masks = _to_bool_lists(txt_masks_raw, current_batch_size)  # list[ list[bool] | None ]
            post_image_masks = _to_bool_lists(img_masks_raw, current_batch_size)
            post_attn_masks = _to_bool_lists(post_attn_raw, current_batch_size)
            sum_pre_text = 0; sum_post_text = 0
            sum_pre_vis = 0; sum_post_vis = 0
            sum_pre_tot = 0; sum_post_tot = 0
            for i in range(current_batch_size):
                pre_text = int(pre_txt_counts[i]) if i < len(pre_txt_counts) else 0
                pre_vis = int(pre_vis_counts[i]) if i < len(pre_vis_counts) else 0
                pre_tot = int(pre_tot_counts[i]) if i < len(pre_tot_counts) else 0
                # Post counts: any mask may be None
                m_text = post_text_masks[i] if post_text_masks is not None and i < len(post_text_masks) else None
                m_img = post_image_masks[i] if post_image_masks is not None and i < len(post_image_masks) else None
                m_attn = post_attn_masks[i] if post_attn_masks is not None and i < len(post_attn_masks) else None
                if m_attn is None:
                    post_text = 0; post_vis = 0; post_tot = 0
                else:
                    # Count True entries after AND-ing with the attention mask
                    if m_text is not None:
                        post_text = sum(1 for a, t in zip(m_attn, m_text) if a and t)
                    else:
                        post_text = 0
                    if m_img is not None:
                        post_vis = sum(1 for a, v in zip(m_attn, m_img) if a and v)
                    else:
                        post_vis = 0
                    post_tot = sum(1 for a in m_attn if a)
                # Accumulate batch-level sums
                sum_pre_text += pre_text; sum_post_text += post_text
                sum_pre_vis += pre_vis; sum_post_vis += post_vis
                sum_pre_tot += pre_tot; sum_post_tot += post_tot
                # Save the per-sample record (used for JSONL export)
                local_token_records.append({
                    "side": encode_side,
                    "pre": {"text": pre_text, "vision": pre_vis, "total": pre_tot},
                    "post": {"text": post_text, "vision": post_vis, "total": post_tot},
                    "delta":{"text": pre_text - post_text, "vision": pre_vis - post_vis, "total": pre_tot - post_tot},
                })
            # Update total LLM input tokens after the model call
            if 'input_ids' in inputs and inputs['input_ids'] is not None:
                token_info["total_llm_input_tokens"] = inputs['input_ids'].shape[1]
                token_info["text_input_tokens"] = token_info["total_llm_input_tokens"] - token_info["vision_tokens"]
                token_info["text_input_tokens"] = max(0, token_info["text_input_tokens"])
            # Collect and store batch statistics
            batch_inference_time = end_inference_time - start_inference_time
            current_batch_stats = {
                "batch_size": current_batch_size,
                "total_inference_time_seconds": batch_inference_time,
                "module_inference_times": {},
                "token_counts": {
                    "visual_tokens": token_info["vision_tokens"],
                    "language_input_tokens_raw": token_info["text_input_tokens"],
                    "llm_total_input_tokens": token_info["total_llm_input_tokens"],
                    "language_output_tokens": token_info["text_output_tokens"],
                }
            }
            current_batch_stats["token_reduction"] = {
                "sum_pre_text": sum_pre_text,
                "sum_post_text": sum_post_text,
                "sum_pre_vision": sum_pre_vis,
                "sum_post_vision": sum_post_vis,
                "sum_pre_total": sum_pre_tot,
                "sum_post_total": sum_post_tot,
            }
            # Calculate and store module timings for the current batch by pairing
            # each 'pre' event with the following 'post' event.
            for module_obj in registered_hooks:
                module_id = id(module_obj)
                module_name = module_obj.__class__.__name__
                times = timing_info.get(module_id, [])
                durations = []
                pre_times = {}
                for t, event_type, _ in times:
                    if event_type == 'pre':
                        pre_times[module_id] = t
                    elif event_type == 'post' and module_id in pre_times:
                        duration = t - pre_times.pop(module_id)
                        durations.append(duration)
                if durations:
                    current_batch_stats["module_inference_times"][module_name] = {
                        "total": sum(durations),
                        "count": len(durations),
                        "avg": sum(durations) / len(durations)
                    }
                else:
                    current_batch_stats["module_inference_times"][module_name] = {
                        "total": 0.0,
                        "count": 0,
                        "avg": 0.0
                    }
            batch_stats_list.append(current_batch_stats)
            # Debug prints (optional)
            print_rank(f"\n--- Inference Statistics for {encode_side} batch (Rank {local_rank}) ---")
            print_rank(f"Batch Inference took: {batch_inference_time:.4f} seconds")
            print_rank("--- Module Inference Timing Statistics ---")
            for module_name, stats in current_batch_stats["module_inference_times"].items():
                print_rank(f"**{module_name}**: Total: {stats['total']:.6f}s, Count: {stats['count']}, Avg: {stats['avg']:.6f}s")
            print_rank("--- Token Count Statistics ---")
            print_rank(f"**视觉 token 数量**: {current_batch_stats['token_counts']['visual_tokens']}")
            print_rank(f"**语言输入 token 数量 (仅原始文本)**: {current_batch_stats['token_counts']['language_input_tokens_raw']}")
            print_rank(f"**LLM总输入 token 数量 (包含视觉 + 格式化文本)**: {current_batch_stats['token_counts']['llm_total_input_tokens']}")
            print_rank(f"**语言输出 token 数量**: {current_batch_stats['token_counts']['language_output_tokens']}")
            if is_late_interaction and reps.dim() == 3:
                local_max_len = max(local_max_len, reps.shape[1])
            local_embeds.append(reps)
    if not local_embeds:
        # Handle cases where a rank gets no data.
        # FIXED: must mirror the 6-value arity of the normal return below;
        # previously only 4 values were returned, crashing callers that unpack six.
        return np.array([]), [], [], [], [], []
    # === DDP synchronization and padding for late-interaction models ===
    if is_late_interaction:
        if dist.is_initialized():
            # 1: global max length across all ranks
            local_max_len_tensor = torch.tensor(local_max_len, device=training_args.device)
            dist.all_reduce(local_max_len_tensor, op=dist.ReduceOp.MAX)
            global_max_len = local_max_len_tensor.item()
        else:
            global_max_len = local_max_len
        # 2: pad every batch to the global max length
        padded_embeds = []
        for reps_batch in local_embeds:
            if reps_batch.dim() == 3:
                B, L, H = reps_batch.shape
                padding_size = global_max_len - L
                padded_batch = F.pad(reps_batch, (0, 0, 0, padding_size), "constant", 0)
                padded_embeds.append(padded_batch)
            else:
                padded_embeds.append(reps_batch)
        embeds_tensor = torch.cat(padded_embeds, dim=0).contiguous()
    else:
        embeds_tensor = torch.cat(local_embeds, dim=0).contiguous()
    # === Gather embeddings and keys from all ranks ===
    if dist.is_initialized() and full_dataset.num_rows >= world_size:
        print_master(f"Gathering {encode_side} embeddings across all ranks...")
        # tensor gather (requires equal shard sizes, hence the padded dataset)
        output_shape = list(embeds_tensor.shape)
        output_shape[0] = full_dataset.num_rows
        embeds_tensor = embeds_tensor.to(training_args.device)
        gathered_embeds_tensor = torch.empty(output_shape, dtype=embeds_tensor.dtype, device=training_args.device)
        dist.all_gather_into_tensor(gathered_embeds_tensor, embeds_tensor)
        final_embeddings = gathered_embeds_tensor.cpu().float().numpy()
        # object gather for infos and stats
        gathered_gt_infos = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_gt_infos, local_gt_infos)
        all_gt_infos = [key for rank_keys in gathered_gt_infos for key in rank_keys]
        gathered_batch_stats = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_batch_stats, batch_stats_list)
        all_batch_stats = [stats for rank_stats in gathered_batch_stats for stats in rank_stats]
        # gather image masks
        gathered_masks = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_masks, local_img_token_masks)
        all_img_token_masks = [m for rank_list in gathered_masks for m in rank_list]
        # gather text masks
        gathered_txt_masks = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_txt_masks, local_txt_token_masks)
        all_txt_token_masks = [m for rank_list in gathered_txt_masks for m in rank_list]
        # gather post attention masks (collected but not part of the return value)
        gathered_post_attn = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_post_attn, local_post_attn_masks)
        all_post_attn_masks = [m for rank_list in gathered_post_attn for m in rank_list]
        # gather per-sample token records
        gathered_token_recs = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_token_recs, local_token_records)
        all_token_records = [r for rank_list in gathered_token_recs for r in rank_list]
    else:
        all_gt_infos = local_gt_infos
        final_embeddings = embeds_tensor.cpu().float().numpy()
        all_batch_stats = batch_stats_list
        all_img_token_masks = local_img_token_masks
        all_txt_token_masks = local_txt_token_masks
        all_post_attn_masks = local_post_attn_masks
        all_token_records = local_token_records
    return final_embeddings, all_gt_infos, all_batch_stats, all_img_token_masks, all_txt_token_masks, all_token_records
# === Single forward pass exporting both mid-layer and last-layer candidate vectors ===
def encode_candidates_both_layers(
    model: MMEBModel,
    loader: DataLoader,
    training_args: TrainingArguments,
    model_args: ModelArguments,
    full_dataset: Dataset,
    mid_layer: int,
) -> tuple[np.ndarray, np.ndarray, list]:
    """
    Run a single forward pass through all layers and read directly from hidden_states:
        - mid_hidden = hidden_states[mid_layer]  # state after passing through mid_layer layers
                                                 # (see the all_hidden_states definition in Qwen2_5_VLModel)
        - last_hidden = hidden_states[-1]        # final state after the last norm
    Then pool with _pooling(attention_mask) to get sentence vectors. Returns:
        - cand_mid_embeds: np.ndarray [Nc, D]
        - cand_last_embeds: np.ndarray [Nc, D]
        - cand_ids: list[str]
    Notes:
        - The candidate side does no AOP pruning by default (naturally off when
          AOP_APPLY=qry), so mid/last sequence lengths match and the original
          attention_mask can be used directly for pooling.
    """
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    model.eval()
    all_mid = []
    all_last = []
    all_ids = []
    with torch.no_grad():
        for inputs, dataset_info in tqdm(loader, desc=f"Candidates[BOTH] (rank {local_rank})", disable=local_rank > 0):
            inputs = batch_to_device(inputs, training_args.device)
            # Ensure the candidate side does not trigger AOP (the backbone already
            # gates by side when AOP_APPLY=qry/both; this is an extra safeguard).
            aop_cfg = getattr(model.encoder, "aop_prune_config", None)
            _orig_enabled = None
            if isinstance(aop_cfg, dict) and aop_cfg:
                _orig_enabled = aop_cfg.get("enabled", False)
                apply_to = aop_cfg.get("apply_to", "qry")
                side_enable = (apply_to == "both") or (apply_to == "cand")
                aop_cfg["enabled"] = bool(side_enable and _orig_enabled)
                setattr(model.encoder, "aop_prune_config", aop_cfg)
            with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
                # Key point: one forward pass yields hidden_states for all layers
                out = model.encoder(
                    **inputs,
                    return_dict=True,
                    output_hidden_states=True,  # required
                    stop_at_layer=None,  # run through all layers
                )
                # Take hidden_states and index the mid and last layers
                hs_list = out.hidden_states
                assert hs_list is not None and len(hs_list) > mid_layer, \
                    f"hidden_states is None or too short. Need index {mid_layer}, got len={0 if hs_list is None else len(hs_list)}"
                mid_hs = hs_list[mid_layer]  # [B, L, D]: state after mid_layer layers (i.e. pre-layer(mid_layer+1))
                last_hs = hs_list[-1]  # [B, L, D]: final state after the last norm
                # Pool with the original attention_mask (candidate side is unpruned)
                am = inputs.get("attention_mask", None)
                if am is not None and hasattr(am, "device"):
                    if am.device != mid_hs.device:
                        am = am.to(mid_hs.device)
                reps_mid = model._pooling(mid_hs, am)  # [B, D]
                reps_last = model._pooling(last_hs, am)  # [B, D]
            all_mid.append(reps_mid.detach().float().cpu())
            all_last.append(reps_last.detach().float().cpu())
            all_ids.extend([info["cand_name"] for info in dataset_info])
            # Restore the AOP switch (avoid affecting the other side)
            if isinstance(aop_cfg, dict) and _orig_enabled is not None:
                aop_cfg["enabled"] = _orig_enabled
                setattr(model.encoder, "aop_prune_config", aop_cfg)
    if not all_mid:
        return np.array([]), np.array([]), []
    cand_mid_embeds = torch.cat(all_mid, dim=0).numpy()
    cand_last_embeds = torch.cat(all_last, dim=0).numpy()
    return cand_mid_embeds, cand_last_embeds, all_ids
def main():
    """EE-only evaluation entry point.

    Initializes distributed state, loads the model/processor, injects the
    AOP pruning config from environment variables, then for each dataset in
    the YAML config: encodes candidate embeddings at both a mid layer and the
    last layer (cached to disk), and runs online early-exit query evaluation
    on rank 0.
    """
    # ----------------------- Distributed init -----------------------
    if "RANK" in os.environ and dist.is_available() and not dist.is_initialized():
        dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60))
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    print_master("Distributed init debug info:")
    print_master(f"RANK: {os.environ.get('RANK')}")
    print_master(f"LOCAL_RANK: {os.environ.get('LOCAL_RANK')}")
    print_master(f"WORLD_SIZE: {os.environ.get('WORLD_SIZE')}")
    print_master(f"MASTER_ADDR: {os.environ.get('MASTER_ADDR')}")
    print_master(f"MASTER_PORT: {os.environ.get('MASTER_PORT')}")
    if dist.is_initialized():
        print_rank(f"dist.get_rank(): {dist.get_rank()}")
        print_rank(f"dist.get_world_size(): {dist.get_world_size()}")
    # Compatibility shim: rewrite torchrun's "--local-rank=N" into "--local_rank N"
    for arg in sys.argv:
        if arg.startswith("--local-rank="):
            rank = arg.split("=")[1]
            sys.argv.remove(arg)
            sys.argv.append('--local_rank')
            sys.argv.append(rank)
    # ----------------------- Parse args -----------------------
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model_args: ModelArguments
    data_args: DataArguments
    training_args: TrainingArguments
    os.makedirs(data_args.encode_output_path, exist_ok=True)
    # Support multi-layer evaluation (LM_LAYERS takes precedence; MID_LM_LAYER kept for compat)
    layers_to_eval = get_env_eval_layers()
    print_master(f"Eval layers (qry/tgt): {layers_to_eval}")
    # ----------------------- Model loading -----------------------
    hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
    if not getattr(model_args, "model_backbone", None):
        model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type)
        setattr(model_args, 'model_backbone', model_backbone)
        setattr(training_args, 'model_backbone', model_backbone)
    print_master(f'Model Backbone: {model_args.model_backbone}')
    # Only rank 0 downloads; other ranks wait and load from the cache
    if local_rank == 0:
        processor = load_processor(model_args, data_args)
        model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
        print_master(f"[rank=0] Loading the model from Huggingface: {model_args.model_name}...")
    if torch.distributed.is_initialized():
        torch.distributed.barrier()
    if local_rank != 0:
        print_rank(f"Loading the model from cache...")
        processor = load_processor(model_args, data_args)
        # Stagger cache reads across ranks to avoid contention
        time.sleep(random.randint(2 * local_rank, 3 * local_rank))
        model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
    model.eval()
    model = model.to(training_args.device, dtype=torch.bfloat16)
    # ---- AOP pruning config injection (drives the AOP logic implemented in the backbone) ----
    aop_cfg = get_env_aop_config()
    if aop_cfg["enabled"]:
        # Store the config on the backbone; its forward reads this dict and performs pruning
        setattr(model.encoder, "aop_prune_config", aop_cfg)
        # Optional: override the attention implementation so the decision layer can
        # expose attention weights or compute qk by hand
        attn_override = aop_cfg.get("attn_impl_override", "")
        if attn_override:
            try:
                if hasattr(model.encoder, "model") and hasattr(model.encoder.model, "config"):
                    prev = model.encoder.model.config._attn_implementation
                    model.encoder.model.config._attn_implementation = attn_override
                    print_master(f"[AOP] override attn impl: {prev} -> {attn_override} (仅测试建议)")
            except Exception as e:
                print_master(f"[AOP] try override attn impl failed: {e}")
        print_master("[AOP] AOP-Prune enabled with config: " + json.dumps({
            "apply_to": aop_cfg["apply_to"],
            "layer_idx": aop_cfg["layer_idx"],
            "mode": aop_cfg["mode"],
            "delta": aop_cfg["delta"],
            "K_hat": aop_cfg["K_hat"],
            "keep_ratio": aop_cfg["keep_ratio"],
            "min_keep": aop_cfg["min_keep"],
            "use_bias": aop_cfg["use_bias"],
            "margin_mid?": (aop_cfg["margin_mid"] is not None),
            "prune_text": aop_cfg.get("prune_text", False),
            "keep_ratio_text": aop_cfg.get("keep_ratio_text", None),
            "keep_ratio_vision": aop_cfg.get("keep_ratio_vision", None),
            "selection": aop_cfg.get("selection", "aop"),
            "attn_agg": aop_cfg.get("attn_agg", "mean"),
        }))
    else:
        print_master("[AOP] disabled (set AOP_ENABLED=1 to enable)")
    # Ensure no layer truncation for "last layer" (avoids the class's default-20-layer trap)
    model.set_inference_layers(qry_layers=None, tgt_layers=None)
    with open(data_args.dataset_config, 'r') as yaml_file:
        dataset_configs = yaml.safe_load(yaml_file)
    # ----------------------- Main evaluation loop -----------------------
    for dataset_idx, (dataset_name, task_config) in enumerate(dataset_configs.items()):
        if dist.is_initialized():
            dist.barrier()
        print_master(f"\n--- Evaluating {dataset_name} ---")
        # Fix up data paths relative to data_basedir
        if data_args.data_basedir is not None:
            for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]:
                if data_args.data_basedir and task_config.get(key):
                    task_config[key] = os.path.join(data_args.data_basedir, task_config[key])
        # Build the datasets
        full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_config)
        full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus)
        eval_qry_dataset, eval_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset
        if dist.is_initialized():
            world_size = dist.get_world_size()
            padded_qry_dataset, _ = pad_dataset_to_divisible(full_eval_qry_dataset, world_size)
            padded_cand_dataset, _ = pad_dataset_to_divisible(full_eval_cand_dataset, world_size)
            eval_qry_dataset = split_dataset_by_node(padded_qry_dataset, rank=local_rank, world_size=world_size)
            eval_cand_dataset = split_dataset_by_node(padded_cand_dataset, rank=local_rank, world_size=world_size)
        else:
            padded_qry_dataset, padded_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset
        # === EE-only: online early-exit inference only (first ensure both candidate embedding files exist) ===
        ee_cfg = get_env_ee_config()
        assert ee_cfg["enabled"], "EE_ENABLED must be 1 for EE-only pipeline."
        # Build the tag from EE_LAYER
        mid_layer = int(ee_cfg["layer"])
        mid_tag = make_layer_tag(mid_layer)  # e.g., layer12
        last_tag = "layerlast"
        # Prepare cache paths
        cand_mid_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_{mid_tag}")
        cand_last_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_{last_tag}")
        # Build the candidate DataLoader (single pass, not sharded)
        eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand")
        eval_cand_loader = DataLoader(
            full_eval_cand_dataset,
            batch_size=training_args.per_device_eval_batch_size,
            collate_fn=eval_cand_collator,
            num_workers=training_args.dataloader_num_workers
        )
        # === One forward pass exporting both mid and last candidate vectors ===
        need_mid = (not os.path.exists(cand_mid_path))
        need_last = (not os.path.exists(cand_last_path))
        if need_mid or need_last:
            print_master(f"[{dataset_name}] EE-only: encoding candidates BOTH layers in one pass (mid={mid_tag}, last={last_tag}) ...")
            # Run through all layers (no truncation)
            model.set_inference_layers(qry_layers=None, tgt_layers=None)
            cand_embeds_mid, cand_embeds_last, all_cand_ids = encode_candidates_both_layers(
                model=model,
                loader=eval_cand_loader,
                training_args=training_args,
                model_args=model_args,
                full_dataset=full_eval_cand_dataset,
                mid_layer=mid_layer,
            )
            if local_rank == 0:
                if need_mid:
                    cand_embed_dict_mid = {cid: emb for cid, emb in zip(all_cand_ids, cand_embeds_mid)}
                    with open(cand_mid_path, "wb") as f:
                        pickle.dump(cand_embed_dict_mid, f)
                    print_master(f"[{dataset_name}] EE-only: saved {mid_tag} candidate embeddings -> {cand_mid_path}")
                if need_last:
                    cand_embed_dict_last = {cid: emb for cid, emb in zip(all_cand_ids, cand_embeds_last)}
                    with open(cand_last_path, "wb") as f:
                        pickle.dump(cand_embed_dict_last, f)
                    print_master(f"[{dataset_name}] EE-only: saved {last_tag} candidate embeddings -> {cand_last_path}")
        else:
            print_master(f"[{dataset_name}] EE-only: reuse existing candidates (mid={cand_mid_path}, last={cand_last_path})")
        if dist.is_initialized():
            dist.barrier()
        # 3) Online early-exit gating + subset continuation (no offline per-layer scoring/curves)
        if local_rank == 0:
            with open(cand_mid_path, "rb") as f:
                cand_mid_dict = pickle.load(f)
            with open(cand_last_path, "rb") as f:
                cand_last_dict = pickle.load(f)
            rank_global = task_config.get("eval_type", "global") == "global"
            print_master(f"[{dataset_name}] Run ONLINE early-exit at layer={ee_cfg['layer']}, method={ee_cfg['method']}, tau={ee_cfg['tau']}, topk={ee_cfg['topk']}, global={rank_global}")
            run_early_exit_queries(
                model=model,
                processor=processor,
                model_args=model_args,
                data_args=data_args,
                training_args=training_args,
                qry_dataset=full_eval_qry_dataset,  # full query set
                cand_mid_dict=cand_mid_dict,
                cand_last_dict=cand_last_dict,
                ee_cfg=ee_cfg,
                dataset_name=dataset_name,
                out_dir=data_args.encode_output_path,
                global_ranking=rank_global,
            )
        if dist.is_initialized():
            dist.barrier()
        # === End of EE-only; proceed directly to the next dataset ===
        continue
# Script entry point: run the EE-only evaluation pipeline.
if __name__ == '__main__':
    main()