| import datetime |
| import logging |
| import json |
| import random |
| import time |
| import numpy as np |
| import os |
| import pickle |
| import sys |
| import torch |
| import torch.distributed as dist |
| import torch.nn.functional as F |
| import yaml |
| import transformers |
| import math |
|
|
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
| from transformers import HfArgumentParser, AutoConfig, AutoTokenizer |
| from datasets import Dataset, concatenate_datasets |
| from datasets.distributed import split_dataset_by_node |
| from src.model.vlm_backbone.qwen2_vl.modeling_qwen2_vl_train_tokrnpooling import Qwen2VLForConditionalGeneration as _Qwen2VLForConditionalGeneration_src |
|
|
| from src.arguments import ModelArguments, DataArguments, TrainingArguments |
| from src.data.collator.eval_collator import MultimodalEvalDataCollator |
| from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset |
| from src.eval_utils.metrics import RankingMetrics |
| from src.model.model_cut_layer_AOP_add_text_cut import MMEBModel |
| from src.model.processor import get_backbone_name, load_processor, COLPALI |
| from src.utils import batch_to_device, print_rank, print_master |
|
|
| from dataclasses import dataclass |
|
|
def get_env_mid_layer():
    """Read the MID_LM_LAYER environment variable.

    Returns:
        int | None: the mid-layer index, or None when the variable is unset,
        empty, or "none"/"null" (None means: use the last layer).
    """
    v = os.environ.get("MID_LM_LAYER", "").strip()
    if v == "" or v.lower() in {"none", "null"}:
        return None
    try:
        return int(v)
    except ValueError:
        # Unparsable value: warn and fall back to the last layer.
        logger.warning(f"Invalid MID_LM_LAYER={v}, ignore.")
        return None
|
|
| |
| def _parse_bool(v: str, default=False): |
| if v is None: return default |
| v = v.strip().lower() |
| return v in {"1","true","yes","y","t","on"} |
|
|
| def _parse_float(v: str, default=None): |
| try: return float(v) if v is not None else default |
| except: return default |
|
|
| def _parse_int(v: str, default=None): |
| try: return int(v) if v is not None else default |
| except: return default |
|
|
def get_env_aop_config():
    """Read the AOP pruning configuration from environment variables.

    This is only a thin driver-level switch for quick experiments; the actual
    pruning logic is implemented in the backbone (Qwen2-VLModel.forward).

    Returns:
        dict: the pruning configuration consumed by the encoder.
    """
    enabled = _parse_bool(os.environ.get("AOP_ENABLED"), False)
    apply_to = os.environ.get("AOP_APPLY", "qry").strip().lower()
    layer_idx = _parse_int(os.environ.get("AOP_LAYER"), None)
    mode = os.environ.get("AOP_MODE", "delta").strip().lower()

    # Global (default) pruning hyper-parameters.
    delta = _parse_float(os.environ.get("AOP_DELTA"), 0.10)
    khat = _parse_float(os.environ.get("AOP_KHAT"), 1.0)
    keep_ratio = _parse_float(os.environ.get("AOP_KEEP_RATIO"), 1.0)
    min_keep = _parse_int(os.environ.get("AOP_MIN_KEEP"), 64)
    use_bias = _parse_bool(os.environ.get("AOP_USE_BIAS"), True)

    # Which token groups may be pruned.
    prune_vision = _parse_bool(os.environ.get("AOP_PRUNE_VISION"), True)
    prune_text = _parse_bool(os.environ.get("AOP_PRUNE_TEXT"), False)

    # Vision-specific overrides (None -> fall back to the global value).
    delta_v = _parse_float(os.environ.get("AOP_DELTA_VISION"), None)
    khat_v = _parse_float(os.environ.get("AOP_KHAT_VISION"), None)
    keep_ratio_v = _parse_float(os.environ.get("AOP_KEEP_RATIO_VISION"), None)
    min_keep_v = _parse_int(os.environ.get("AOP_MIN_KEEP_VISION"), None)

    # Text-specific overrides.
    delta_t = _parse_float(os.environ.get("AOP_DELTA_TEXT"), None)
    khat_t = _parse_float(os.environ.get("AOP_KHAT_TEXT"), None)
    keep_ratio_t = _parse_float(os.environ.get("AOP_KEEP_RATIO_TEXT"), None)
    min_keep_t = _parse_int(os.environ.get("AOP_MIN_KEEP_TEXT"), 32)

    # Protection rules: keep trailing text tokens / special tokens unpruned.
    protect_text_last = _parse_int(os.environ.get("AOP_PROTECT_TEXT_LAST"), 16)
    protect_special = _parse_bool(os.environ.get("AOP_PROTECT_SPECIAL"), True)

    margin_src = os.environ.get("AOP_MARGIN", "").strip().lower()
    attn_impl = os.environ.get("AOP_ATTN_IMPL", "").strip().lower()

    if layer_idx is None and enabled:
        # AOP cannot run without a target layer; disable it.
        logger.warning("AOP_ENABLED=1 但未设置 AOP_LAYER,关闭 AOP。")
        enabled = False

    # Token-selection strategy; AOP_RANDOM=1 forces random selection (ablation).
    # NOTE: the original code parsed this block twice (copy/paste duplication);
    # it is now parsed exactly once.
    selection = os.environ.get("AOP_SELECTION", "aop").strip().lower()
    if _parse_bool(os.environ.get("AOP_RANDOM"), False):
        selection = "random"
    random_seed = _parse_int(os.environ.get("AOP_RANDOM_SEED"), None)
    attn_agg = os.environ.get("AOP_ATTENTION_AGG", "mean").strip().lower()

    cfg = {
        "enabled": enabled,
        "apply_to": apply_to,
        "layer_idx": layer_idx,
        "mode": mode,

        # Global pruning knobs.
        "delta": delta, "K_hat": khat,
        "keep_ratio": keep_ratio, "min_keep": min_keep,
        "use_bias": use_bias, "eps": 1e-6,

        # Token groups to prune.
        "prune_vision": prune_vision,
        "prune_text": prune_text,

        # Vision-specific overrides.
        "delta_vision": delta_v,
        "K_hat_vision": khat_v,
        "keep_ratio_vision": keep_ratio_v,
        "min_keep_vision": min_keep_v,

        # Text-specific overrides.
        "delta_text": delta_t,
        "K_hat_text": khat_t,
        "keep_ratio_text": keep_ratio_t,
        "min_keep_text": min_keep_t,

        # Protection rules.
        "protect_text_last": protect_text_last,
        "protect_special": protect_special,

        # Margin source and optional attention-implementation override.
        "margin_mid": None if margin_src != "mid" else "USE_MID_MARGIN",
        "epsilon_hat": None,
        "attn_impl_override": attn_impl if attn_impl in {"sdpa"} else "",

        # Selection strategy (AOP scoring vs random ablation).
        "selection": selection,
        "random_seed": random_seed,
        "attn_agg": attn_agg,
    }
    return cfg
|
|
def get_env_eval_layers():
    """Resolve which LM layers to evaluate.

    Parses LM_LAYERS (preferred) or falls back to the legacy MID_LM_LAYER:
      - LM_LAYERS example: "4,8,12,last"; 'last'/'none'/'null'/'-1' all
        denote the final layer (encoded as None).
      - Without LM_LAYERS: MID_LM_LAYER unset -> [None]; otherwise [mid, None].

    Returns:
        list[int | None]: e.g. [4, 8, 12, None]; None means the last layer.
    """
    raw = os.environ.get("LM_LAYERS", None)
    if raw is not None:
        raw = raw.strip()

    if not raw:
        # Legacy fallback via MID_LM_LAYER.
        mid = get_env_mid_layer()
        return [None] if mid is None else [mid, None]

    parsed = []
    for token in (t.strip() for t in raw.split(',')):
        if not token:
            continue
        if token.lower() in {"last", "none", "null", "-1"}:
            parsed.append(None)
            continue
        try:
            num = int(token)
        except Exception:
            logger.warning(f"Invalid token '{token}' in LM_LAYERS; must be int or 'last'/'none'.")
            continue
        if num > 0:
            parsed.append(num)
        else:
            logger.warning(f"Ignoring non-positive layer '{token}' in LM_LAYERS.")

    # Deduplicate while preserving order; None is keyed as -1.
    deduped = []
    seen = set()
    for layer in parsed:
        marker = -1 if layer is None else layer
        if marker not in seen:
            seen.add(marker)
            deduped.append(layer)
    return deduped if deduped else [None]
|
|
| |
def get_env_ee_config():
    """Collect early-exit (EE) settings from environment variables.

    Returns a dict with: enabled, layer, method, tau, topk, temp, save,
    and the combination weights (w_margin, w_conf, w_sq).
    """
    truthy = {"1", "true", "yes", "on", "y", "t"}
    ee_enabled = os.environ.get("EE_ENABLED", "0").strip().lower() in truthy
    # EE_LAYER falls back to AOP_LAYER, then to 12.
    layer = int(os.environ.get("EE_LAYER", os.environ.get("AOP_LAYER", "12")))
    method = os.environ.get("EE_METHOD", "margin").strip().lower()
    tau = float(os.environ.get("EE_TAU", "0.2"))
    topk = int(os.environ.get("EE_TOPK", "1024"))
    temp = float(os.environ.get("EE_TEMP", "0.05"))
    save = os.environ.get("EE_SAVE", "1").strip().lower() in truthy
    combw = os.environ.get("EE_COMB_WEIGHTS", "1.0,0.5,0.5")
    try:
        w_margin, w_conf, w_sq = (float(x) for x in combw.split(","))
    except Exception:
        # Malformed weight string: use the documented defaults.
        w_margin, w_conf, w_sq = 1.0, 0.5, 0.5
    return {
        "enabled": ee_enabled, "layer": layer, "method": method, "tau": tau,
        "topk": topk, "temp": temp, "save": save,
        "w_margin": w_margin, "w_conf": w_conf, "w_sq": w_sq,
    }
|
|
| def _softmax_np(x: np.ndarray, temp: float = 1.0) -> np.ndarray: |
| x = x - np.max(x) |
| ex = np.exp(x / max(1e-6, temp)) |
| s = np.sum(ex) |
| return ex / max(s, 1e-12) |
|
|
def confidence_from_topk(scores: np.ndarray, method="margin", temp=0.05, w_margin=1.0, w_conf=0.5, w_sq=0.5) -> float:
    """Confidence of a ranked top-k score vector (scores sorted descending).

    Supported methods: "margin", "p1p2", "entropy", "gini"; any other value
    returns the weighted combination of margin, entropy-confidence and gini.
    """
    n = scores.size
    if n == 0:
        return 0.0
    if n == 1:
        # A single candidate is treated as maximally confident.
        return 1e9
    margin = float(scores[0] - scores[1])
    # Temperature softmax over the top-k scores (stable shift-by-max form).
    shifted = scores - np.max(scores)
    exps = np.exp(shifted / max(1e-6, temp))
    probs = exps / max(float(np.sum(exps)), 1e-12)
    p1p2 = float(probs[0] - probs[1])
    # Normalized entropy in [0, 1]; confidence = 1 - entropy.
    entropy = -float(np.sum(probs * np.log(probs + 1e-12))) / np.log(n)
    conf = 1.0 - entropy
    sqsum = float(np.sum(probs ** 2))
    dispatch = {"margin": margin, "p1p2": p1p2, "entropy": conf, "gini": sqsum}
    if method in dispatch:
        return dispatch[method]
    # Fallback: weighted combination of the individual signals.
    return w_margin * margin + w_conf * conf + w_sq * sqsum
|
|
def run_mid_layer_retrieval(
    model: MMEBModel,
    processor,
    model_args: ModelArguments,
    data_args: DataArguments,
    training_args: TrainingArguments,
    qry_dataset: Dataset,
    cand_mid_dict: dict,
    mid_layer: int,
    dataset_name: str,
    out_dir: str,
    global_ranking: bool = True,
    topk_save: int = 200,
    save_qry_embeds: bool = True,
):
    """Online retrieval using ONLY mid-layer query and candidate embeddings;
    saves the mid-layer embeddings and similarity records to disk.

    Args:
        cand_mid_dict: {cand_id: mid-layer embedding}, produced by
            encode_candidates_both_layers.
        mid_layer: intermediate layer index (kept consistent with the
            supervise_layers used in training, e.g. 12 / 16).
        global_ranking: True = rank against the full candidate pool;
            False = rank only within each query's gt_info["cand_names"].
        topk_save: number of top candidates whose similarities are saved
            per query.

    Output files (under out_dir):
        {dataset}_qry_mid_L{mid}.pkl               all query mid-layer embeddings (np.ndarray [Nq, D])
        {dataset}_score_midonly_L{mid}.json        retrieval metrics
        {dataset}_pred_midonly_L{mid}.jsonl        ranking results
        {dataset}_midonly_L{mid}_sims_top{K}.jsonl per-query top-K similarities

    Returns:
        (score_dict, elapsed_seconds)
    """
    device = training_args.device
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    is_main = (not dist.is_initialized()) or (local_rank == 0)

    # Fix a candidate ordering so row index <-> candidate id is stable.
    cand_ids = list(cand_mid_dict.keys())
    cand_id2idx = {str(cid): i for i, cid in enumerate(cand_ids)}

    cand_mid = np.stack([cand_mid_dict[c] for c in cand_ids]).astype(np.float32)
    cand_mid_t = torch.from_numpy(cand_mid).to(device=device, dtype=torch.bfloat16)
    Nc = cand_mid_t.size(0)

    collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry")
    loader = DataLoader(
        qry_dataset,
        batch_size=training_args.per_device_eval_batch_size,
        collate_fn=collator,
        num_workers=training_args.dataloader_num_workers,
    )

    pred_dicts = []
    all_qry_embeds = []
    sim_records = [] if is_main else None

    qid_offset = 0
    start_time = time.time()

    for inputs, infos in tqdm(
        loader,
        desc=f"[MID-ONLY] {dataset_name}@L{mid_layer}",
        disable=local_rank > 0,
    ):
        inputs = batch_to_device(inputs, device)

        # Forward only up to the requested intermediate layer.
        with torch.no_grad(), torch.autocast(
            device_type="cuda", dtype=torch.bfloat16, enabled=True
        ):
            out_mid = model.encoder(
                **inputs,
                return_dict=True,
                output_hidden_states=False,
                stop_at_layer=int(mid_layer),
                compute_lm_head=False,
            )

        hs_mid = getattr(out_mid, "last_hidden_state", None)
        if hs_mid is None:
            assert out_mid.hidden_states is not None and len(out_mid.hidden_states) > 0
            hs_mid = out_mid.hidden_states[-1]

        # The encoder may return a (possibly pruned) attention mask; fall
        # back to the input mask and align devices before pooling.
        am_mid = getattr(out_mid, "attention_mask", None)
        if am_mid is None:
            am_mid = inputs.get("attention_mask", None)
        if hasattr(am_mid, "device") and am_mid.device != hs_mid.device:
            am_mid = am_mid.to(hs_mid.device)

        reps_mid_t = model._pooling(hs_mid, am_mid).detach().to(
            device=device, dtype=torch.bfloat16
        )
        B = reps_mid_t.size(0)

        all_qry_embeds.append(reps_mid_t.float().cpu())

        if global_ranking:
            # One matmul against the whole candidate matrix.
            scores = (reps_mid_t @ cand_mid_t.T).detach().float().cpu().numpy()
            for b in range(B):
                info = infos[b]
                score_vec = scores[b]
                full_order = np.argsort(-score_vec)

                rel_docids = info["label_name"]
                if not isinstance(rel_docids, list):
                    rel_docids = [rel_docids]
                pred_dicts.append({
                    "prediction": [cand_ids[i] for i in full_order],
                    "label": rel_docids,
                    "rel_scores": None,
                })

                if is_main and sim_records is not None:
                    K = min(topk_save, Nc)
                    top_idx = full_order[:K]
                    sim_records.append({
                        "qid": int(qid_offset + b),
                        "label": rel_docids,
                        "topk_cand_ids": [cand_ids[i] for i in top_idx],
                        "topk_scores": score_vec[top_idx].astype(float).tolist(),
                    })
        else:
            # Local ranking: restrict each query to its own candidate subset.
            for b in range(B):
                info = infos[b]
                cand_local = info["cand_names"]
                rows = [cand_id2idx.get(str(cid), -1) for cid in cand_local]
                rows = [r for r in rows if r >= 0]

                if len(rows) == 0:
                    preds = []
                    score_vec = np.zeros((0,), dtype=np.float32)
                    order_local = np.zeros((0,), dtype=np.int64)
                else:
                    cmat = cand_mid_t[rows]
                    sv = (reps_mid_t[b:b+1] @ cmat.T)[0].detach().float().cpu().numpy()
                    order_local = np.argsort(-sv)
                    preds = [str(cand_local[i]) for i in order_local]
                    score_vec = sv

                rel_docids = info["label_name"]
                if not isinstance(rel_docids, list):
                    rel_docids = [rel_docids]

                pred_dicts.append({
                    "prediction": preds,
                    "label": rel_docids,
                    "rel_scores": None,
                })

                if is_main and sim_records is not None:
                    K = min(topk_save, len(preds))
                    top_ids = preds[:K]
                    top_scores = score_vec[order_local[:K]].astype(float).tolist() if len(score_vec) > 0 else []
                    sim_records.append({
                        "qid": int(qid_offset + b),
                        "label": rel_docids,
                        "topk_cand_ids": top_ids,
                        "topk_scores": top_scores,
                    })

        qid_offset += B

    # Save all query mid-layer embeddings (main rank only).
    if save_qry_embeds and is_main and len(all_qry_embeds) > 0:
        # BUGFIX: out_dir may not exist yet at this point (it was only
        # created later, in the metrics-saving branch).
        os.makedirs(out_dir, exist_ok=True)
        all_qry_mid = torch.cat(all_qry_embeds, dim=0).numpy()
        qry_mid_path = os.path.join(out_dir, f"{dataset_name}_qry_mid_L{mid_layer}.pkl")
        with open(qry_mid_path, "wb") as f:
            pickle.dump(all_qry_mid, f)
        print_master(f"[MID-ONLY] Saved query mid-layer embeddings -> {qry_mid_path}")

    # Aggregate ranking metrics over every query.
    metrics_to_report = ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"]
    score = RankingMetrics(metrics_to_report).evaluate(pred_dicts)

    if is_main:
        os.makedirs(out_dir, exist_ok=True)
        score_path = os.path.join(out_dir, f"{dataset_name}_score_midonly_L{mid_layer}.json")
        pred_path = os.path.join(out_dir, f"{dataset_name}_pred_midonly_L{mid_layer}.jsonl")
        sims_path = os.path.join(out_dir, f"{dataset_name}_midonly_L{mid_layer}_sims_top{topk_save}.jsonl")

        score["num_pred"] = len(pred_dicts)
        with open(score_path, "w") as f:
            json.dump(score, f, indent=4)
        with open(pred_path, "w", encoding="utf-8") as f:
            for row in pred_dicts:
                f.write(json.dumps(row, ensure_ascii=False) + "\n")

        if sim_records is not None:
            with open(sims_path, "w", encoding="utf-8") as f:
                for rec in sim_records:
                    f.write(json.dumps(rec, ensure_ascii=False) + "\n")

        print_master(
            f"[MID-ONLY] {dataset_name}@L{mid_layer} score: "
            + json.dumps({k: (f"{v:.4f}" if isinstance(v, (int, float)) else v) for k, v in score.items()})
        )
        print_master(f"[MID-ONLY] Saved mid-layer similarity records -> {sims_path}")

    elapsed = time.time() - start_time
    return score, elapsed
|
|
def run_early_exit_queries(
    model: MMEBModel,
    processor,
    model_args: ModelArguments,
    data_args: DataArguments,
    training_args: TrainingArguments,
    qry_dataset: Dataset,
    cand_mid_dict: dict,
    cand_last_dict: dict,
    ee_cfg: dict,
    dataset_name: str,
    out_dir: str,
    global_ranking: bool = True,
):
    """Online early-exit inference only (no curve plotting); additionally emits:
      - per-query top-K similarities of mid-layer query vs last-layer
        candidates (mid2last)
      - for queries that did NOT exit early, top-K similarities of
        last-layer query vs last-layer candidates (last2last)

    Output files:
        {out_dir}/{dataset}_score_earlyexit.json - retrieval metrics
        {out_dir}/{dataset}_pred_earlyexit.jsonl - prediction lists (as before)
        {out_dir}/{dataset}_sim_earlyexit.jsonl  - similarity records added by
                                                   this function (needs EE_SAVE=1)
    """
    device = training_args.device
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    is_main = (not dist.is_initialized()) or (local_rank == 0)

    # Candidate matrices (mid and last layer) share one fixed row order.
    cand_ids = list(cand_mid_dict.keys())
    cand_id2row = {str(cid): i for i, cid in enumerate(cand_ids)}

    cand_mid = np.stack([cand_mid_dict[c] for c in cand_ids]).astype(np.float32)
    cand_last = np.stack([cand_last_dict[c] for c in cand_ids]).astype(np.float32)

    cand_mid_t = torch.from_numpy(cand_mid).to(device=device, dtype=torch.bfloat16)
    cand_last_t = torch.from_numpy(cand_last).to(device=device, dtype=torch.bfloat16)

    collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry")
    loader = DataLoader(
        qry_dataset,
        batch_size=training_args.per_device_eval_batch_size,
        collate_fn=collator,
        num_workers=training_args.dataloader_num_workers
    )

    pred_dicts = []

    # Optional similarity dump (main rank only).
    save_scores = ee_cfg.get("save", False)
    topk_sim = int(ee_cfg.get("topk", 1024))
    sim_records = [] if (save_scores and is_main) else None

    # AOP pruning config attached to the encoder; remember its original
    # enabled state so it can be toggled per forward segment below.
    aop_cfg = getattr(model.encoder, "aop_prune_config", None)
    _orig_enabled = None
    side_enable = True
    if isinstance(aop_cfg, dict) and aop_cfg:
        _orig_enabled = aop_cfg.get("enabled", False)
        apply_to = aop_cfg.get("apply_to", "qry")
        side_enable = (apply_to == "both") or (apply_to == "qry")

    # Early-exit hyper-parameters; confidence uses only the top-2 scores.
    k_conf = 2
    tau = float(ee_cfg["tau"])
    method = ee_cfg["method"]
    temp = float(ee_cfg["temp"])

    idx_global = 0
    start_time = time.time()

    for inputs, infos in tqdm(
        loader,
        desc=f"[EE] {dataset_name}@L{ee_cfg['layer']} (rank {local_rank})",
        disable=local_rank > 0,
    ):
        inputs = batch_to_device(inputs, device)

        # Enable AOP for the mid-layer segment only if its prune layer lies
        # strictly before the early-exit layer and it applies to queries.
        orig_cfg = None
        if isinstance(aop_cfg, dict) and aop_cfg:
            orig_cfg = dict(aop_cfg)
            aop_layer = aop_cfg.get("layer_idx", None)
            ee_layer = int(ee_cfg["layer"])
            apply_to = aop_cfg.get("apply_to", "qry").strip().lower()
            aop_on_mid = bool(
                _orig_enabled and side_enable and
                (aop_layer is not None) and (aop_layer < ee_layer) and
                (apply_to in {"qry", "both"})
            )
            aop_cfg_mid = dict(aop_cfg)
            aop_cfg_mid["enabled"] = aop_on_mid
            setattr(model.encoder, "aop_prune_config", aop_cfg_mid)

        # First forward segment: run only up to the early-exit layer.
        with torch.no_grad(), torch.autocast(
            device_type="cuda", dtype=torch.bfloat16, enabled=True
        ):
            out_mid = model.encoder(
                **inputs,
                return_dict=True,
                output_hidden_states=False,
                stop_at_layer=int(ee_cfg["layer"]),
                compute_lm_head=False,
            )

        # Restore the encoder's original AOP config immediately.
        if isinstance(orig_cfg, dict):
            setattr(model.encoder, "aop_prune_config", orig_cfg)

        # Pool the mid-layer hidden states into query embeddings.
        hs_mid = getattr(out_mid, "last_hidden_state", None)
        if hs_mid is None:
            assert out_mid.hidden_states is not None and len(out_mid.hidden_states) > 0
            hs_mid = out_mid.hidden_states[-1]
        am_mid = getattr(out_mid, "attention_mask", None)
        if am_mid is None:
            am_mid = inputs.get("attention_mask", None)
        if hasattr(am_mid, "device") and am_mid.device != hs_mid.device:
            am_mid = am_mid.to(hs_mid.device)
        reps_mid_t = model._pooling(hs_mid, am_mid).detach().to(device=device, dtype=torch.bfloat16)

        B = reps_mid_t.size(0)
        use_local = (not global_ranking)

        # --- Confidence estimation on mid-layer scores -> early-exit mask ---
        if not use_local:
            # Global: vectorized over the whole mid-layer candidate matrix.
            scores_t = reps_mid_t @ cand_mid_t.T
            vals_t, idxs_t = torch.topk(
                scores_t, k=min(k_conf, scores_t.size(1)), dim=1
            )
            p_t = torch.softmax(vals_t / max(temp, 1e-6), dim=1)
            if vals_t.size(1) >= 2:
                margin_t = vals_t[:, 0] - vals_t[:, 1]
                p1p2_t = p_t[:, 0] - p_t[:, 1]
            else:
                # Single candidate: force exit (infinite margin).
                margin_t = torch.full((B,), float("inf"), device=device, dtype=vals_t.dtype)
                p1p2_t = torch.ones(B, device=device, dtype=vals_t.dtype)
            # NOTE(review): with a single candidate math.log(1) == 0, so the
            # entropy path divides by zero; only the margin/p1p2 entries are
            # meaningful in that case — confirm if "entropy"/"gini" are ever
            # used with k == 1.
            H_t = -(p_t * (torch.log(p_t + 1e-12))).sum(dim=1) / math.log(max(vals_t.size(1),1))
            conf_map = {
                "margin": margin_t,
                "p1p2": p1p2_t,
                "entropy": 1.0 - H_t,
                "gini": (p_t ** 2).sum(dim=1),
            }
            confs_t = conf_map.get(method, margin_t)
            exit_mask = (confs_t >= tau).detach().cpu().numpy().astype(bool)
        else:
            # Local: per-query confidence over its own candidate subset.
            confs = []
            for b in range(B):
                cand_local = infos[b]["cand_names"]
                rows = [cand_id2row.get(str(cid), -1) for cid in cand_local]
                rows = [r for r in rows if r >= 0]
                if len(rows) == 0:
                    # No known candidates: never exit early for this query.
                    confs.append(0.0)
                    continue
                cmat_t = cand_mid_t[rows]
                sv_t = (reps_mid_t[b:b+1] @ cmat_t.T)[0]
                k = 2 if sv_t.size(0) >= 2 else 1
                vals_t, _ = torch.topk(sv_t, k=k, dim=0)
                p_t = torch.softmax(vals_t / max(temp, 1e-6), dim=0)
                if k >= 2:
                    margin = (vals_t[0] - vals_t[1]).item()
                    p1p2 = (p_t[0] - p_t[1]).item()
                else:
                    margin, p1p2 = float("inf"), 1.0
                H = (-(p_t * (torch.log(p_t + 1e-12))).sum() / math.log(max(k,1))).item()
                gini = ((p_t ** 2).sum()).item()
                d = {"margin": margin, "p1p2": p1p2, "entropy": 1.0 - H, "gini": gini}
                confs.append(d.get(method, margin))
            exit_mask = (np.array(confs) >= tau)

        # Split the batch: queries that exit at the mid layer vs those that
        # must continue to the last layer.
        exit_indices = np.where(exit_mask)[0].tolist()
        need_indices = np.where(~exit_mask)[0].tolist()

        # --- Early-exited queries: rank using mid-layer embeddings ---
        for j in exit_indices:
            if not use_local:
                scores_mid_mid = (reps_mid_t[j:j+1] @ cand_mid_t.T)[0]
                order = torch.argsort(scores_mid_mid, dim=0, descending=True).detach().cpu().numpy()
                cids = [cand_ids[i] for i in order]
            else:
                cand_local = infos[j]["cand_names"]
                rows = [cand_id2row.get(str(cid), -1) for cid in cand_local]
                rows = [r for r in rows if r >= 0]
                if len(rows) == 0:
                    cids = []
                else:
                    cmat_t = cand_mid_t[rows]
                    sv = (reps_mid_t[j:j+1] @ cmat_t.T)[0]
                    order_local = torch.argsort(sv, dim=0, descending=True).detach().cpu().numpy()
                    cids = [str(cand_local[i]) for i in order_local]
            rel_docids = infos[j]["label_name"]
            if not isinstance(rel_docids, list):
                rel_docids = [rel_docids]
            pred_dicts.append({"prediction": cids, "label": rel_docids, "rel_scores": None})

            # Record mid-query vs last-candidate similarities for analysis.
            if save_scores and is_main:
                if not use_local:
                    scores_mid_last = (reps_mid_t[j:j+1] @ cand_last_t.T)[0].detach().float().cpu()
                    Nc = scores_mid_last.size(0)
                    K = min(topk_sim, Nc)
                    mid_vals, mid_inds = torch.topk(scores_mid_last, k=K, dim=0)
                    mid_ids = [cand_ids[i] for i in mid_inds.tolist()]
                else:
                    cand_local = infos[j]["cand_names"]
                    rows = [cand_id2row.get(str(cid), -1) for cid in cand_local]
                    rows = [r for r in rows if r >= 0]
                    if len(rows) == 0:
                        mid_vals = torch.empty(0)
                        mid_ids = []
                    else:
                        cmat_last = cand_last_t[rows]
                        sv_last = (reps_mid_t[j:j+1] @ cmat_last.T)[0].detach().float().cpu()
                        K = min(topk_sim, sv_last.size(0))
                        mid_vals, mid_inds = torch.topk(sv_last, k=K, dim=0)
                        mid_ids = [str(cand_local[i]) for i in mid_inds.tolist()]

                rec = {
                    "qid": int(idx_global + j),
                    "early_exit": True,
                    "mid_topk_scores": mid_vals.tolist() if mid_vals.numel() > 0 else [],
                    "mid_topk_cand_ids": mid_ids,
                    "last_topk_scores": None,
                    "last_topk_cand_ids": None,
                }
                sim_records.append(rec)

        # --- Non-exited queries: resume the forward pass to the last layer ---
        if len(need_indices) > 0:
            # Re-enable AOP (if it was originally on) for the resumed segment.
            if isinstance(aop_cfg, dict) and aop_cfg:
                aop_resume = dict(aop_cfg)
                aop_resume["enabled"] = bool(_orig_enabled and side_enable)
                setattr(model.encoder, "aop_prune_config", aop_resume)

            interm = getattr(out_mid, "intermediate_state", None)
            assert interm is not None, "Model must return intermediate_state when stop_at_layer is set."

            hs = interm["hidden_states"].detach()
            am = interm["attention_mask"].detach()
            pos = interm["position_ids"].detach()
            vm = interm.get("vision_mask", None)
            tm = interm.get("text_mask", None)
            next_layer = int(interm["next_layer_idx"])

            # Slice the cached state down to the still-running queries.
            # (position_ids is indexed on dim 1 — presumably laid out as
            # (3, batch, seq) for Qwen2-VL mrope; confirm against the backbone.)
            hs_sub = hs[need_indices]
            am_sub = am[need_indices]
            pos_sub = pos[:, need_indices, :]
            vm_sub = vm[need_indices] if vm is not None else None
            tm_sub = tm[need_indices] if tm is not None else None
            resume_state = {
                "hidden_states": hs_sub,
                "attention_mask": am_sub,
                "position_ids": pos_sub,
                "vision_mask": vm_sub,
                "text_mask": tm_sub,
                "next_layer_idx": next_layer,
            }

            with torch.no_grad(), torch.autocast(
                device_type="cuda", dtype=torch.bfloat16, enabled=True
            ):
                out_last = model.encoder(
                    return_dict=True,
                    output_hidden_states=False,
                    stop_at_layer=None,
                    resume_state=resume_state,
                    compute_lm_head=False,
                )

            hs_last = getattr(out_last, "last_hidden_state", None)
            if hs_last is None:
                assert out_last.hidden_states is not None and len(out_last.hidden_states) > 0
                hs_last = out_last.hidden_states[-1]
            am_last = getattr(out_last, "attention_mask", None)
            if am_last is None:
                am_last = am_sub
            if hasattr(am_last, "device") and am_last.device != hs_last.device:
                am_last = am_last.to(hs_last.device)

            reps_last_t = model._pooling(hs_last, am_last).detach().to(device=device, dtype=torch.bfloat16)

            if not use_local:
                scores_last_all = (reps_last_t @ cand_last_t.T).detach().float().cpu()
                # k indexes the sliced (resumed) batch; j is the original index.
                for k, j in enumerate(need_indices):
                    row = scores_last_all[k]
                    order = torch.argsort(row, dim=0, descending=True).tolist()
                    cids = [cand_ids[i] for i in order]
                    rel_docids = infos[j]["label_name"]
                    if not isinstance(rel_docids, list):
                        rel_docids = [rel_docids]
                    pred_dicts.append({"prediction": cids, "label": rel_docids, "rel_scores": None})

                    if save_scores and is_main:
                        # mid-query vs last-candidate scores (for comparison).
                        scores_mid_last = (reps_mid_t[j:j+1] @ cand_last_t.T)[0].detach().float().cpu()
                        Nc = scores_mid_last.size(0)
                        K = min(topk_sim, Nc)
                        mid_vals, mid_inds = torch.topk(scores_mid_last, k=K, dim=0)
                        mid_ids = [cand_ids[i] for i in mid_inds.tolist()]

                        # last-query vs last-candidate scores (ranking basis).
                        last_row = row
                        last_vals, last_inds = torch.topk(last_row, k=K, dim=0)
                        last_ids = [cand_ids[i] for i in last_inds.tolist()]

                        rec = {
                            "qid": int(idx_global + j),
                            "early_exit": False,
                            "mid_topk_scores": mid_vals.tolist(),
                            "mid_topk_cand_ids": mid_ids,
                            "last_topk_scores": last_vals.tolist(),
                            "last_topk_cand_ids": last_ids,
                        }
                        sim_records.append(rec)
            else:
                # Local ranking for the resumed queries.
                for k, j in enumerate(need_indices):
                    cand_local = infos[j]["cand_names"]
                    rows = [cand_id2row.get(str(cid), -1) for cid in cand_local]
                    rows = [r for r in rows if r >= 0]
                    if len(rows) == 0:
                        cids = []
                        rel_docids = infos[j]["label_name"]
                        if not isinstance(rel_docids, list):
                            rel_docids = [rel_docids]
                        pred_dicts.append({"prediction": cids, "label": rel_docids, "rel_scores": None})
                        if save_scores and is_main:
                            rec = {
                                "qid": int(idx_global + j),
                                "early_exit": False,
                                "mid_topk_scores": [],
                                "mid_topk_cand_ids": [],
                                "last_topk_scores": [],
                                "last_topk_cand_ids": [],
                            }
                            sim_records.append(rec)
                        continue

                    cmat_last = cand_last_t[rows]
                    sv_last = (reps_last_t[k:k+1] @ cmat_last.T)[0].detach().float().cpu()
                    order_local = torch.argsort(sv_last, dim=0, descending=True).tolist()
                    cids = [str(cand_local[i]) for i in order_local]
                    rel_docids = infos[j]["label_name"]
                    if not isinstance(rel_docids, list):
                        rel_docids = [rel_docids]
                    pred_dicts.append({"prediction": cids, "label": rel_docids, "rel_scores": None})

                    if save_scores and is_main:
                        cmat_last = cand_last_t[rows]
                        # mid-query vs last-candidate scores.
                        sv_mid_last = (reps_mid_t[j:j+1] @ cmat_last.T)[0].detach().float().cpu()
                        K = min(topk_sim, sv_mid_last.size(0))
                        mid_vals, mid_inds = torch.topk(sv_mid_last, k=K, dim=0)
                        mid_ids = [str(cand_local[i]) for i in mid_inds.tolist()]
                        # last-query vs last-candidate scores.
                        sv_last_row = sv_last
                        last_vals, last_inds = torch.topk(sv_last_row, k=K, dim=0)
                        last_ids = [str(cand_local[i]) for i in last_inds.tolist()]

                        rec = {
                            "qid": int(idx_global + j),
                            "early_exit": False,
                            "mid_topk_scores": mid_vals.tolist(),
                            "mid_topk_cand_ids": mid_ids,
                            "last_topk_scores": last_vals.tolist(),
                            "last_topk_cand_ids": last_ids,
                        }
                        sim_records.append(rec)

        idx_global += B

    # Aggregate ranking metrics over all queries (exited or not).
    metrics_to_report = ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"]
    score = RankingMetrics(metrics_to_report).evaluate(pred_dicts)

    if is_main:
        os.makedirs(out_dir, exist_ok=True)
        with open(os.path.join(out_dir, f"{dataset_name}_score_earlyexit.json"), "w") as f:
            json.dump(score, f, indent=4)
        if ee_cfg.get("save", False):
            with open(os.path.join(out_dir, f"{dataset_name}_pred_earlyexit.jsonl"), "w", encoding="utf-8") as f:
                for row in pred_dicts:
                    f.write(json.dumps(row, ensure_ascii=False) + "\n")
        if save_scores and sim_records is not None:
            sims_path = os.path.join(out_dir, f"{dataset_name}_sim_earlyexit.jsonl")
            with open(sims_path, "w", encoding="utf-8") as f:
                for rec in sim_records:
                    f.write(json.dumps(rec, ensure_ascii=False) + "\n")
            print_master(f"[EE] Saved mid/last similarity records -> {sims_path}")

    elapsed = time.time() - start_time
    return score, elapsed
|
|
| def make_layer_tag(keep_layers: int | None): |
| return f"layer{keep_layers}" if keep_layers and keep_layers > 0 else "layerlast" |
|
|
def dot_sim(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Pairwise dot-product similarity: (Na, D) x (Nb, D) -> (Na, Nb)."""
    return np.matmul(a, b.T)
|
|
def build_score_details(qid: int, cand_ids: list, score_vec: np.ndarray, ranked_indices: np.ndarray):
    """Pack per-candidate scores (in ranked order) into a JSON-serializable dict."""
    cand_scores = []
    for idx in ranked_indices:
        cand_scores.append({"cand_id": str(cand_ids[idx]), "score": float(score_vec[idx])})
    return {"qid": int(qid), "cand_scores": cand_scores}
|
|
def top1_top2_margin(score_vec: np.ndarray) -> float:
    """Gap between the highest and second-highest score; +inf when <2 scores."""
    if len(score_vec) < 2:
        return float("inf")
    # partition places the two largest values at the end (unordered), then sort.
    best_two = np.sort(np.partition(score_vec, -2)[-2:])
    return float(best_two[1] - best_two[0])
|
|
def simulate_early_exit_by_margin(
    sims_mid: list[dict], sims_last: list[dict], labels: list[list[str]], metrics_to_report: list[str],
    taus: list[float], rank_global: bool
):
    """Offline simulation of margin-based early exit.

    Args:
        sims_mid / sims_last: one dict per query, {cand_id: score}, for the
            mid layer / last layer respectively.
        labels: per-query list of positive cand_ids.
        metrics_to_report: metric names forwarded to RankingMetrics.
        taus: margin thresholds to sweep.
        rank_global: kept for interface compatibility (not used here).

    Returns:
        list[dict]: per-tau coverage (fraction of queries answered at the
        mid layer) together with the retrieval metrics at that threshold.
    """
    assert len(sims_mid) == len(sims_last) == len(labels)
    N = len(labels)
    results = []

    # Uses the module-level RankingMetrics import (the previous redundant
    # function-local re-import has been removed).
    metrics = RankingMetrics(metrics_to_report)

    def to_pred_dicts(use_mid_mask: list[bool]) -> list[dict]:
        # Build ranked predictions, choosing mid- or last-layer scores per query.
        pred_dicts = []
        for qid in range(N):
            sims_use = sims_mid[qid] if use_mid_mask[qid] else sims_last[qid]
            ranked = sorted(sims_use.items(), key=lambda x: -x[1])
            pred_dicts.append({
                "prediction": [cid for cid, _ in ranked],
                "label": labels[qid],
                "rel_scores": None
            })
        return pred_dicts

    # Mid-layer top1-top2 margins; empty score dicts get margin 0 (never exit).
    margins = []
    for qid in range(N):
        if len(sims_mid[qid]) == 0:
            margins.append(0.0)
            continue
        scores = np.array(list(sims_mid[qid].values()), dtype=np.float32)
        margins.append(top1_top2_margin(scores))
    margins = np.array(margins, dtype=np.float32)

    for tau in taus:
        use_mid_mask = (margins >= tau).tolist()
        pred_dicts = to_pred_dicts(use_mid_mask)
        score_dict = metrics.evaluate(pred_dicts)
        # Guard N == 0: np.mean of an empty mask would be NaN (with a warning).
        coverage = float(np.mean(use_mid_mask)) if N > 0 else 0.0
        results.append({
            "tau": tau,
            "coverage": coverage,
            **score_dict
        })
    return results
|
|
def top1_top2_margin_from_array(score_vec: np.ndarray) -> float:
    """Top1-top2 margin; 0.0 for empty/None input, +inf for a single score."""
    if score_vec is None or len(score_vec) == 0:
        return 0.0
    if len(score_vec) == 1:
        return float('inf')
    ordered = np.sort(score_vec)[::-1]
    return float(ordered[0] - ordered[1])
|
|
# Configure root logging once at import time; per-module logger for this file.
logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s')
logger = logging.getLogger(__name__)


# Mutable module-level stores filled by the forward hooks below.
# timing_info maps id(module) -> [(timestamp, 'pre'|'post', class_name), ...].
timing_info = {}
# Rough token accounting; only "vision_tokens" is updated by timing_post_hook
# in the visible code — the other counters are presumably filled elsewhere.
token_info = {
    "vision_tokens": 0,
    "text_input_tokens": 0,
    "text_output_tokens": 0,
    "total_llm_input_tokens": 0,
}
|
|
| |
def timing_pre_hook(module, input):
    """Forward pre-hook: record the entry timestamp for *module*."""
    entries = timing_info.setdefault(id(module), [])
    entries.append((time.time(), 'pre', module.__class__.__name__))
|
|
def timing_post_hook(module, input, output):
    """Forward post-hook: record exit timestamp and capture vision token count."""
    module_id = id(module)
    if module_id not in timing_info:
        # No matching pre-hook entry: skip to keep pre/post pairs aligned.
        return

    timing_info[module_id].append((time.time(), 'post', module.__class__.__name__))

    # For vision-transformer modules, remember how many vision tokens came out.
    lowered_name = module.__class__.__name__.lower()
    if "vision" in lowered_name and "transformer" in lowered_name:
        if isinstance(output, torch.Tensor):
            token_info["vision_tokens"] = output.shape[0]
        elif hasattr(output, 'last_hidden_state'):
            token_info["vision_tokens"] = output.last_hidden_state.shape[1]
|
|
|
|
def register_model_hooks(model):
    """Attach timing pre/post hooks to the encoder's major submodules.

    Hooks are installed on the vision tower, its merger, the LLM trunk and the
    LM head when present. Returns the list of modules that actually received
    hooks, so callers can later look up their timing records by ``id()``.
    """
    hooked = []
    core_model = model.encoder if hasattr(model, "encoder") and model.encoder is not None else model

    def _attach(mod):
        # One pre + one post hook per module; both feed the global timing_info.
        mod.register_forward_pre_hook(timing_pre_hook)
        mod.register_forward_hook(timing_post_hook)
        hooked.append(mod)

    # Vision tower of the backbone, if present.
    if hasattr(core_model, 'visual') and core_model.visual is not None:
        vision_module = core_model.visual
        _attach(vision_module)
        print_master(f"Registered hooks for vision module: {vision_module.__class__.__name__}")
    else:
        print_master(f"WARNING: No 'visual' attribute found on core_model ({type(core_model)}).")

    # Patch-merger sitting inside the vision tower.
    if hasattr(core_model, 'visual') and hasattr(core_model.visual, 'merger') and core_model.visual.merger is not None:
        merger_module = core_model.visual.merger
        _attach(merger_module)
        print_master(f"Registered hooks for merger module: {merger_module.__class__.__name__}")
    else:
        print_master(f"WARNING: No 'merger' attribute found on core_model.visual ({type(getattr(core_model, 'visual', 'N/A'))}).")

    # Main language-model trunk.
    if hasattr(core_model, 'model') and core_model.model is not None:
        llm_main_module = core_model.model
        _attach(llm_main_module)
        print_master(f"Registered hooks for LLM main module: {llm_main_module.__class__.__name__}")
    else:
        print_master(f"WARNING: No 'model' attribute found on core_model ({type(core_model)}).")

    # Output projection head.
    if hasattr(core_model, 'lm_head') and core_model.lm_head is not None:
        lm_head_module = core_model.lm_head
        _attach(lm_head_module)
        print_master(f"Registered hooks for LM head module: {lm_head_module.__class__.__name__}")
    else:
        print_master(f"WARNING: No 'lm_head' attribute found on core_model ({type(core_model)}).")

    if not hooked:
        print_master("Warning: No major modules found for hook registration. Check model architecture.")
    return hooked
|
|
|
|
def pad_dataset_to_divisible(dataset, world_size):
    """Pad a dataset (by cycling its rows) so its length divides world_size evenly.

    Returns (possibly-padded dataset, resulting length). When the length is
    already divisible the original dataset object is returned untouched.
    """
    n = len(dataset)
    remainder = n % world_size
    if remainder == 0:
        return dataset, n

    extra = world_size - remainder
    # Reuse rows from the front (wrapping around) as the padding samples.
    filler = dataset.select([i % n for i in range(extra)])
    return concatenate_datasets([dataset, filler]), n + extra
|
|
def encode_embeddings(
    model: MMEBModel,
    loader: DataLoader,
    training_args: TrainingArguments,
    model_args: ModelArguments,
    full_dataset: Dataset,
    encode_side: str,
    description: str = "Encoding"
) -> tuple[np.ndarray, list, list, list, list, list]:
    """
    Encodes embeddings for a given dataset using the model, handling both standard and
    late-interaction models in a DDP-safe manner.

    Args:
        model: wrapper whose forward returns {"qry_reps"} or {"tgt_reps"}.
        loader: yields (inputs, dataset_info) batches.
        training_args: supplies the target device.
        model_args: supplies the backbone name (late-interaction detection).
        full_dataset: full (padded) dataset; sizes the DDP all-gather buffer.
        encode_side: "qry" or "tgt"; selects the representation and whether
            AOP pruning (if configured) applies to this side.
        description: tqdm progress-bar label.

    Returns a 6-tuple:
        - embeddings: np.ndarray
        - infos_or_ids: list (ground-truth info for "qry", candidate ids for "tgt")
        - batch_stats_list: list of per-batch timing/token statistics dicts
        - img_token_masks: list[None | list]
        - txt_token_masks: list[None | list]
        - token_records: list of per-sample pre/post token-count records

    NOTE: the previous version returned only 4 items on the empty-loader path
    while the normal path returned 6; the arity is now consistent.
    """
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1

    # COLPALI produces token-level (late-interaction) reps of shape [B, L, H].
    is_late_interaction = (model_args.model_backbone == COLPALI)

    local_embeds = []
    local_gt_infos = []
    local_max_len = 0

    batch_stats_list = []

    # Per-sample masks harvested from the model output (None when unavailable).
    local_img_token_masks = []
    local_txt_token_masks = []
    local_post_attn_masks = []

    # Per-sample pre/post token-count records (text / vision / total).
    local_token_records = []

    model.eval()

    # Attach timing hooks so per-module latencies land in the global timing_info.
    registered_hooks = register_model_hooks(model)

    def _search_key(obj, key: str):
        # Depth-first search for `key` inside nested dict/list/tuple structures.
        if isinstance(obj, dict):
            if key in obj:
                return obj[key]
            for v in obj.values():
                r = _search_key(v, key)
                if r is not None:
                    return r
        elif isinstance(obj, (list, tuple)):
            for v in obj:
                r = _search_key(v, key)
                if r is not None:
                    return r
        return None

    def _to_serializable_mask_list(mask_list, batch_size: int):
        # Normalize masks (tensor / ndarray / list / None) into a plain Python
        # list of exactly `batch_size` entries, padding/truncating with None.
        if mask_list is None:
            return [None] * batch_size

        out = []
        if isinstance(mask_list, (list, tuple)):
            for m in mask_list:
                if m is None:
                    out.append(None)
                elif torch.is_tensor(m):
                    out.append(m.detach().cpu().tolist())
                elif isinstance(m, np.ndarray):
                    out.append(m.tolist())
                else:
                    # Already a plain Python structure; keep as-is.
                    out.append(m)
        elif torch.is_tensor(mask_list):
            # Batched tensor: one row per sample.
            out = mask_list.detach().cpu().tolist()
        elif isinstance(mask_list, np.ndarray):
            out = mask_list.tolist()
        else:
            out = [None] * batch_size

        # Force length == batch_size so downstream zips stay aligned.
        if isinstance(out, list):
            if len(out) < batch_size:
                out = out + [None] * (batch_size - len(out))
            elif len(out) > batch_size:
                out = out[:batch_size]
        return out

    def _to_bool_lists(m, batch_size: int):
        # Like _to_serializable_mask_list, but coerce every element to bool.
        lst = _to_serializable_mask_list(m, batch_size)
        out = []
        for x in lst:
            if x is None:
                out.append(None)
            else:
                out.append([bool(int(v)) for v in x])
        return out

    with torch.no_grad():
        for inputs, dataset_info in tqdm(loader, desc=f"{description} (rank {local_rank})", disable=local_rank > 0):
            # Reset the per-batch bookkeeping shared with the forward hooks.
            timing_info.clear()
            token_info["vision_tokens"] = 0
            token_info["text_input_tokens"] = 0
            token_info["text_output_tokens"] = 0
            token_info["total_llm_input_tokens"] = 0

            inputs = batch_to_device(inputs, training_args.device)
            current_batch_size = inputs['input_ids'].shape[0] if 'input_ids' in inputs and inputs['input_ids'] is not None else 1

            with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
                start_inference_time = time.time()
                # Enable AOP pruning only when it is configured for this side.
                aop_cfg = getattr(model.encoder, "aop_prune_config", None)
                _orig_enabled = None
                if isinstance(aop_cfg, dict) and aop_cfg:
                    _orig_enabled = aop_cfg.get("enabled", False)
                    apply_to = aop_cfg.get("apply_to", "qry")
                    side_enable = (apply_to == "both") or (apply_to == encode_side)
                    aop_cfg["enabled"] = bool(side_enable and _orig_enabled)
                    setattr(model.encoder, "aop_prune_config", aop_cfg)
                if encode_side == "qry":
                    output = model(qry=inputs)
                    reps = output["qry_reps"].detach()
                    local_gt_infos.extend(dataset_info)
                else:
                    output = model(tgt=inputs)
                    reps = output["tgt_reps"].detach()
                    local_gt_infos.extend([info["cand_name"] for info in dataset_info])
                # Restore the caller's AOP "enabled" flag.
                if isinstance(aop_cfg, dict) and _orig_enabled is not None:
                    aop_cfg["enabled"] = _orig_enabled
                    setattr(model.encoder, "aop_prune_config", aop_cfg)
                end_inference_time = time.time()

            # Pull pruning masks out of the output dict (searched recursively) ...
            img_masks_raw = None
            txt_masks_raw = None
            post_attn_raw = None

            if isinstance(output, dict):
                img_masks_raw = _search_key(output, "image_token_bool_masks")
                txt_masks_raw = _search_key(output, "text_token_bool_masks")
                post_attn_raw = _search_key(output, "post_attention_mask")

            # ... falling back to attributes stashed on the model itself.
            if img_masks_raw is None and hasattr(model, "image_token_bool_masks"):
                img_masks_raw = getattr(model, "image_token_bool_masks")
            if txt_masks_raw is None and hasattr(model, "text_token_bool_masks"):
                txt_masks_raw = getattr(model, "text_token_bool_masks")
            if post_attn_raw is None and hasattr(model, "post_attention_mask"):
                post_attn_raw = getattr(model, "post_attention_mask")

            img_masks_serializable = _to_serializable_mask_list(img_masks_raw, current_batch_size)
            txt_masks_serializable = _to_serializable_mask_list(txt_masks_raw, current_batch_size)
            post_attn_serializable = _to_serializable_mask_list(post_attn_raw, current_batch_size)

            local_img_token_masks.extend(img_masks_serializable)
            local_txt_token_masks.extend(txt_masks_serializable)
            local_post_attn_masks.extend(post_attn_serializable)

            # Pre-pruning token counts derived from input_ids + attention_mask.
            cfg = getattr(model.encoder, "config", None)
            input_ids = inputs.get("input_ids", None)
            attn2d_pre = inputs.get("attention_mask", None)
            if input_ids is None or attn2d_pre is None or cfg is None:
                # Without ids/mask/config we cannot classify tokens; record zeros.
                pre_vis_counts = [0] * current_batch_size
                pre_txt_counts = [0] * current_batch_size
                pre_tot_counts = [0] * current_batch_size
            else:
                iid = input_ids
                am = attn2d_pre.to(torch.bool)
                image_token_id = getattr(cfg, "image_token_id", None)
                video_token_id = getattr(cfg, "video_token_id", None)
                bos_id = getattr(cfg, "bos_token_id", None)
                eos_id = getattr(cfg, "eos_token_id", None)
                pad_id = getattr(cfg, "pad_token_id", None)

                is_image = (iid == image_token_id) if (image_token_id is not None and image_token_id >= 0) else torch.zeros_like(iid, dtype=torch.bool)
                is_video = (iid == video_token_id) if (video_token_id is not None and video_token_id >= 0) else torch.zeros_like(iid, dtype=torch.bool)
                is_vision = is_image | is_video

                # Special tokens (bos/eos/pad) are excluded from the text count.
                is_special = torch.zeros_like(iid, dtype=torch.bool)
                for tid in [bos_id, eos_id, pad_id]:
                    if tid is not None and tid >= 0:
                        is_special |= (iid == tid)

                pre_txt_mask = am & (~is_vision) & (~is_special)
                pre_vis_mask = am & is_vision

                pre_vis_counts = pre_vis_mask.sum(dim=1).tolist()
                pre_txt_counts = pre_txt_mask.sum(dim=1).tolist()
                pre_tot_counts = am.sum(dim=1).tolist()

            # Post-pruning counts from the harvested boolean masks.
            post_text_masks = _to_bool_lists(txt_masks_raw, current_batch_size)
            post_image_masks = _to_bool_lists(img_masks_raw, current_batch_size)
            post_attn_masks = _to_bool_lists(post_attn_raw, current_batch_size)

            sum_pre_text = 0; sum_post_text = 0
            sum_pre_vis = 0; sum_post_vis = 0
            sum_pre_tot = 0; sum_post_tot = 0

            for i in range(current_batch_size):
                pre_text = int(pre_txt_counts[i]) if i < len(pre_txt_counts) else 0
                pre_vis = int(pre_vis_counts[i]) if i < len(pre_vis_counts) else 0
                pre_tot = int(pre_tot_counts[i]) if i < len(pre_tot_counts) else 0

                m_text = post_text_masks[i] if post_text_masks is not None and i < len(post_text_masks) else None
                m_img = post_image_masks[i] if post_image_masks is not None and i < len(post_image_masks) else None
                m_attn = post_attn_masks[i] if post_attn_masks is not None and i < len(post_attn_masks) else None

                if m_attn is None:
                    post_text = 0; post_vis = 0; post_tot = 0
                else:
                    # Count only positions that survive the post attention mask.
                    if m_text is not None:
                        post_text = sum(1 for a, t in zip(m_attn, m_text) if a and t)
                    else:
                        post_text = 0
                    if m_img is not None:
                        post_vis = sum(1 for a, v in zip(m_attn, m_img) if a and v)
                    else:
                        post_vis = 0
                    post_tot = sum(1 for a in m_attn if a)

                sum_pre_text += pre_text; sum_post_text += post_text
                sum_pre_vis += pre_vis; sum_post_vis += post_vis
                sum_pre_tot += pre_tot; sum_post_tot += post_tot

                local_token_records.append({
                    "side": encode_side,
                    "pre": {"text": pre_text, "vision": pre_vis, "total": pre_tot},
                    "post": {"text": post_text, "vision": post_vis, "total": post_tot},
                    "delta":{"text": pre_text - post_text, "vision": pre_vis - post_vis, "total": pre_tot - post_tot},
                })

            # Batch-level token totals; vision_tokens comes from the forward hook.
            if 'input_ids' in inputs and inputs['input_ids'] is not None:
                token_info["total_llm_input_tokens"] = inputs['input_ids'].shape[1]
                token_info["text_input_tokens"] = token_info["total_llm_input_tokens"] - token_info["vision_tokens"]
                token_info["text_input_tokens"] = max(0, token_info["text_input_tokens"])

            batch_inference_time = end_inference_time - start_inference_time

            current_batch_stats = {
                "batch_size": current_batch_size,
                "total_inference_time_seconds": batch_inference_time,
                "module_inference_times": {},
                "token_counts": {
                    "visual_tokens": token_info["vision_tokens"],
                    "language_input_tokens_raw": token_info["text_input_tokens"],
                    "llm_total_input_tokens": token_info["total_llm_input_tokens"],
                    "language_output_tokens": token_info["text_output_tokens"],
                }
            }
            current_batch_stats["token_reduction"] = {
                "sum_pre_text": sum_pre_text,
                "sum_post_text": sum_post_text,
                "sum_pre_vision": sum_pre_vis,
                "sum_post_vision": sum_post_vis,
                "sum_pre_total": sum_pre_tot,
                "sum_post_total": sum_post_tot,
            }

            # Pair up pre/post timestamps per hooked module to get durations.
            for module_obj in registered_hooks:
                module_id = id(module_obj)
                module_name = module_obj.__class__.__name__
                times = timing_info.get(module_id, [])
                durations = []
                pre_times = {}
                for t, event_type, _ in times:
                    if event_type == 'pre':
                        pre_times[module_id] = t
                    elif event_type == 'post' and module_id in pre_times:
                        duration = t - pre_times.pop(module_id)
                        durations.append(duration)

                if durations:
                    current_batch_stats["module_inference_times"][module_name] = {
                        "total": sum(durations),
                        "count": len(durations),
                        "avg": sum(durations) / len(durations)
                    }
                else:
                    current_batch_stats["module_inference_times"][module_name] = {
                        "total": 0.0,
                        "count": 0,
                        "avg": 0.0
                    }

            batch_stats_list.append(current_batch_stats)

            # Per-batch diagnostics (printed on every rank).
            print_rank(f"\n--- Inference Statistics for {encode_side} batch (Rank {local_rank}) ---")
            print_rank(f"Batch Inference took: {batch_inference_time:.4f} seconds")
            print_rank("--- Module Inference Timing Statistics ---")
            for module_name, stats in current_batch_stats["module_inference_times"].items():
                print_rank(f"**{module_name}**: Total: {stats['total']:.6f}s, Count: {stats['count']}, Avg: {stats['avg']:.6f}s")
            print_rank("--- Token Count Statistics ---")
            print_rank(f"**视觉 token 数量**: {current_batch_stats['token_counts']['visual_tokens']}")
            print_rank(f"**语言输入 token 数量 (仅原始文本)**: {current_batch_stats['token_counts']['language_input_tokens_raw']}")
            print_rank(f"**LLM总输入 token 数量 (包含视觉 + 格式化文本)**: {current_batch_stats['token_counts']['llm_total_input_tokens']}")
            print_rank(f"**语言输出 token 数量**: {current_batch_stats['token_counts']['language_output_tokens']}")

            # Track the longest token sequence for late-interaction padding below.
            if is_late_interaction and reps.dim() == 3:
                local_max_len = max(local_max_len, reps.shape[1])

            local_embeds.append(reps)

    if not local_embeds:
        # FIX: keep arity consistent with the normal return below (6 items);
        # previously only 4 were returned here, breaking unpacking at call sites.
        return np.array([]), [], [], [], [], []

    # Late-interaction reps vary in length; pad to a globally-agreed max length.
    if is_late_interaction:
        if dist.is_initialized():
            # Agree on the max sequence length across ranks so gathered shapes match.
            local_max_len_tensor = torch.tensor(local_max_len, device=training_args.device)
            dist.all_reduce(local_max_len_tensor, op=dist.ReduceOp.MAX)
            global_max_len = local_max_len_tensor.item()
        else:
            global_max_len = local_max_len

        padded_embeds = []
        for reps_batch in local_embeds:
            if reps_batch.dim() == 3:
                B, L, H = reps_batch.shape
                padding_size = global_max_len - L
                padded_batch = F.pad(reps_batch, (0, 0, 0, padding_size), "constant", 0)
                padded_embeds.append(padded_batch)
            else:
                padded_embeds.append(reps_batch)

        embeds_tensor = torch.cat(padded_embeds, dim=0).contiguous()
    else:
        embeds_tensor = torch.cat(local_embeds, dim=0).contiguous()

    # Gather results across ranks when running distributed.
    if dist.is_initialized() and full_dataset.num_rows >= world_size:
        print_master(f"Gathering {encode_side} embeddings across all ranks...")

        # NOTE(review): all_gather_into_tensor requires every rank to hold
        # exactly num_rows / world_size rows — callers pad the dataset to be
        # divisible; confirm this invariant if reusing elsewhere.
        output_shape = list(embeds_tensor.shape)
        output_shape[0] = full_dataset.num_rows
        embeds_tensor = embeds_tensor.to(training_args.device)
        gathered_embeds_tensor = torch.empty(output_shape, dtype=embeds_tensor.dtype, device=training_args.device)
        dist.all_gather_into_tensor(gathered_embeds_tensor, embeds_tensor)
        final_embeddings = gathered_embeds_tensor.cpu().float().numpy()

        gathered_gt_infos = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_gt_infos, local_gt_infos)
        all_gt_infos = [key for rank_keys in gathered_gt_infos for key in rank_keys]

        gathered_batch_stats = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_batch_stats, batch_stats_list)
        all_batch_stats = [stats for rank_stats in gathered_batch_stats for stats in rank_stats]

        gathered_masks = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_masks, local_img_token_masks)
        all_img_token_masks = [m for rank_list in gathered_masks for m in rank_list]

        gathered_txt_masks = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_txt_masks, local_txt_token_masks)
        all_txt_token_masks = [m for rank_list in gathered_txt_masks for m in rank_list]

        # Post-attention masks are gathered for parity/debug but not returned.
        gathered_post_attn = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_post_attn, local_post_attn_masks)
        all_post_attn_masks = [m for rank_list in gathered_post_attn for m in rank_list]

        gathered_token_recs = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_token_recs, local_token_records)
        all_token_records = [r for rank_list in gathered_token_recs for r in rank_list]
    else:
        all_gt_infos = local_gt_infos
        final_embeddings = embeds_tensor.cpu().float().numpy()
        all_batch_stats = batch_stats_list
        all_img_token_masks = local_img_token_masks
        all_txt_token_masks = local_txt_token_masks
        all_post_attn_masks = local_post_attn_masks
        all_token_records = local_token_records

    return final_embeddings, all_gt_infos, all_batch_stats, all_img_token_masks, all_txt_token_masks, all_token_records
|
|
| |
def encode_candidates_both_layers(
    model: MMEBModel,
    loader: DataLoader,
    training_args: TrainingArguments,
    model_args: ModelArguments,
    full_dataset: Dataset,
    mid_layer: int,
) -> tuple[np.ndarray, np.ndarray, list]:
    """
    Encode candidates once, reading both a mid-layer and the final-layer state.

    A single forward pass to the last layer exposes all hidden states, from which:
      - mid_hidden  = hidden_states[mid_layer]  # state after `mid_layer` layers
        (see the all_hidden_states definition in Qwen2_5_VLModel)
      - last_hidden = hidden_states[-1]         # post-norm final-layer state
    Sentence vectors are then pooled via _pooling(attention_mask). Returns:
      - cand_mid_embeds: np.ndarray [Nc, D]
      - cand_last_embeds: np.ndarray [Nc, D]
      - cand_ids: list[str]
    Note:
      - By default the candidate side performs no AOP pruning (naturally off
        when AOP_APPLY=qry), so mid/last sequence lengths agree and the
        original attention_mask can be used directly for pooling.
    """
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    model.eval()


    all_mid = []
    all_last = []
    all_ids = []


    with torch.no_grad():
        for inputs, dataset_info in tqdm(loader, desc=f"Candidates[BOTH] (rank {local_rank})", disable=local_rank > 0):
            inputs = batch_to_device(inputs, training_args.device)
            # Temporarily toggle AOP pruning depending on whether it is
            # configured to apply to the candidate side; restored below.
            aop_cfg = getattr(model.encoder, "aop_prune_config", None)
            _orig_enabled = None
            if isinstance(aop_cfg, dict) and aop_cfg:
                _orig_enabled = aop_cfg.get("enabled", False)
                apply_to = aop_cfg.get("apply_to", "qry")
                side_enable = (apply_to == "both") or (apply_to == "cand")
                aop_cfg["enabled"] = bool(side_enable and _orig_enabled)
                setattr(model.encoder, "aop_prune_config", aop_cfg)


            with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
                # Full-depth forward; hidden states for every layer are requested.
                out = model.encoder(
                    **inputs,
                    return_dict=True,
                    output_hidden_states=True,
                    stop_at_layer=None,
                )


            # hidden_states must be long enough to index the requested mid layer.
            hs_list = out.hidden_states
            assert hs_list is not None and len(hs_list) > mid_layer, \
                f"hidden_states is None or too short. Need index {mid_layer}, got len={0 if hs_list is None else len(hs_list)}"
            mid_hs = hs_list[mid_layer]
            last_hs = hs_list[-1]


            # Align the attention mask's device with the hidden states before pooling.
            am = inputs.get("attention_mask", None)
            if am is not None and hasattr(am, "device"):
                if am.device != mid_hs.device:
                    am = am.to(mid_hs.device)


            reps_mid = model._pooling(mid_hs, am)
            reps_last = model._pooling(last_hs, am)


            all_mid.append(reps_mid.detach().float().cpu())
            all_last.append(reps_last.detach().float().cpu())
            all_ids.extend([info["cand_name"] for info in dataset_info])


            # Restore the caller's AOP "enabled" flag.
            if isinstance(aop_cfg, dict) and _orig_enabled is not None:
                aop_cfg["enabled"] = _orig_enabled
                setattr(model.encoder, "aop_prune_config", aop_cfg)


    if not all_mid:
        return np.array([]), np.array([]), []


    cand_mid_embeds = torch.cat(all_mid, dim=0).numpy()
    cand_last_embeds = torch.cat(all_last, dim=0).numpy()
    return cand_mid_embeds, cand_last_embeds, all_ids
|
|
def main():
    """Entry point: initialize DDP, load the model, then run EE-only candidate
    encoding plus mid-layer retrieval for every dataset in the YAML config."""
    # Initialize the process group when launched under torchrun/torch.distributed.
    if "RANK" in os.environ and dist.is_available() and not dist.is_initialized():
        dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60))
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1

    print_master("Distributed init debug info:")
    print_master(f"RANK: {os.environ.get('RANK')}")
    print_master(f"LOCAL_RANK: {os.environ.get('LOCAL_RANK')}")
    print_master(f"WORLD_SIZE: {os.environ.get('WORLD_SIZE')}")
    print_master(f"MASTER_ADDR: {os.environ.get('MASTER_ADDR')}")
    print_master(f"MASTER_PORT: {os.environ.get('MASTER_PORT')}")
    if dist.is_initialized():
        print_rank(f"dist.get_rank(): {dist.get_rank()}")
        print_rank(f"dist.get_world_size(): {dist.get_world_size()}")

    # Normalize "--local-rank=N" (hyphen form) into "--local_rank N" for the parser.
    # FIX: iterate over a snapshot — removing from sys.argv while iterating it
    # can skip elements.
    for arg in list(sys.argv):
        if arg.startswith("--local-rank="):
            rank = arg.split("=")[1]
            sys.argv.remove(arg)
            sys.argv.append('--local_rank')
            sys.argv.append(rank)

    # Parse CLI arguments into the three project dataclasses.
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model_args: ModelArguments
    data_args: DataArguments
    training_args: TrainingArguments
    os.makedirs(data_args.encode_output_path, exist_ok=True)

    # Env-driven layer selection (informational print only at this point).
    layers_to_eval = get_env_eval_layers()
    print_master(f"Eval layers (qry/tgt): {layers_to_eval}")

    # Resolve the backbone name from the HF config when not given explicitly.
    hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
    if not getattr(model_args, "model_backbone", None):
        model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type)
        setattr(model_args, 'model_backbone', model_backbone)
        setattr(training_args, 'model_backbone', model_backbone)
    print_master(f'Model Backbone: {model_args.model_backbone}')

    # Rank 0 downloads/loads first; the other ranks wait at the barrier and then
    # load from the local cache (staggered sleeps reduce filesystem contention).
    if local_rank == 0:
        processor = load_processor(model_args, data_args)
        model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
        print_master(f"[rank=0] Loading the model from Huggingface: {model_args.model_name}...")
    if torch.distributed.is_initialized():
        torch.distributed.barrier()
    if local_rank != 0:
        print_rank(f"Loading the model from cache...")
        processor = load_processor(model_args, data_args)
        time.sleep(random.randint(2 * local_rank, 3 * local_rank))
        model = MMEBModel.load(model_args, is_trainable=False, processor=processor)

    model.eval()
    model = model.to(training_args.device, dtype=torch.bfloat16)

    # AOP (attention-oriented pruning) configuration, driven by env vars.
    aop_cfg = get_env_aop_config()
    if aop_cfg["enabled"]:
        setattr(model.encoder, "aop_prune_config", aop_cfg)
        # Optionally force a specific attention implementation (test-only knob).
        attn_override = aop_cfg.get("attn_impl_override", "")
        if attn_override:
            try:
                if hasattr(model.encoder, "model") and hasattr(model.encoder.model, "config"):
                    prev = model.encoder.model.config._attn_implementation
                    model.encoder.model.config._attn_implementation = attn_override
                    print_master(f"[AOP] override attn impl: {prev} -> {attn_override} (仅测试建议)")
            except Exception as e:
                print_master(f"[AOP] try override attn impl failed: {e}")

        print_master("[AOP] AOP-Prune enabled with config: " + json.dumps({
            "apply_to": aop_cfg["apply_to"],
            "layer_idx": aop_cfg["layer_idx"],
            "mode": aop_cfg["mode"],
            "delta": aop_cfg["delta"],
            "K_hat": aop_cfg["K_hat"],
            "keep_ratio": aop_cfg["keep_ratio"],
            "min_keep": aop_cfg["min_keep"],
            "use_bias": aop_cfg["use_bias"],
            "margin_mid?": (aop_cfg["margin_mid"] is not None),
            "prune_text": aop_cfg.get("prune_text", False),
            "keep_ratio_text": aop_cfg.get("keep_ratio_text", None),
            "keep_ratio_vision": aop_cfg.get("keep_ratio_vision", None),
            "selection": aop_cfg.get("selection", "aop"),
            "attn_agg": aop_cfg.get("attn_agg", "mean"),
        }))
    else:
        print_master("[AOP] disabled (set AOP_ENABLED=1 to enable)")

    # Default to full-depth inference on both sides.
    model.set_inference_layers(qry_layers=None, tgt_layers=None)

    with open(data_args.dataset_config, 'r') as yaml_file:
        dataset_configs = yaml.safe_load(yaml_file)

    # Evaluate each configured dataset in turn.
    for dataset_name, task_config in dataset_configs.items():
        if dist.is_initialized():
            dist.barrier()
        print_master(f"\n--- Evaluating {dataset_name} ---")

        # Resolve dataset-relative paths against the common base directory.
        if data_args.data_basedir is not None:
            for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]:
                if data_args.data_basedir and task_config.get(key):
                    task_config[key] = os.path.join(data_args.data_basedir, task_config[key])

        # Build the query dataset and candidate pool.
        full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_config)
        full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus)
        eval_qry_dataset, eval_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset

        # NOTE(review): the padded/per-rank datasets below are computed but never
        # consumed in this EE-only pipeline (the loaders use the full datasets);
        # presumably kept for parity with a non-EE code path — confirm before removing.
        if dist.is_initialized():
            world_size = dist.get_world_size()
            padded_qry_dataset, _ = pad_dataset_to_divisible(full_eval_qry_dataset, world_size)
            padded_cand_dataset, _ = pad_dataset_to_divisible(full_eval_cand_dataset, world_size)
            eval_qry_dataset = split_dataset_by_node(padded_qry_dataset, rank=local_rank, world_size=world_size)
            eval_cand_dataset = split_dataset_by_node(padded_cand_dataset, rank=local_rank, world_size=world_size)
        else:
            padded_qry_dataset, padded_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset

        # Early-exit (EE) settings are mandatory for this script.
        ee_cfg = get_env_ee_config()
        assert ee_cfg["enabled"], "EE_ENABLED must be 1 for EE-only pipeline."

        mid_layer = int(ee_cfg["layer"])
        mid_tag = make_layer_tag(mid_layer)
        last_tag = "layerlast"

        # Cached candidate-embedding files, one per layer tag.
        cand_mid_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_{mid_tag}")
        cand_last_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_{last_tag}")

        eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand")
        eval_cand_loader = DataLoader(
            full_eval_cand_dataset,
            batch_size=training_args.per_device_eval_batch_size,
            collate_fn=eval_cand_collator,
            num_workers=training_args.dataloader_num_workers
        )

        need_mid = (not os.path.exists(cand_mid_path))
        need_last = (not os.path.exists(cand_last_path))

        if need_mid or need_last:
            print_master(f"[{dataset_name}] EE-only: encoding candidates BOTH layers in one pass (mid={mid_tag}, last={last_tag}) ...")
            # Full-depth forward so both mid and last hidden states are produced.
            model.set_inference_layers(qry_layers=None, tgt_layers=None)

            cand_embeds_mid, cand_embeds_last, all_cand_ids = encode_candidates_both_layers(
                model=model,
                loader=eval_cand_loader,
                training_args=training_args,
                model_args=model_args,
                full_dataset=full_eval_cand_dataset,
                mid_layer=mid_layer,
            )
            # Only rank 0 persists embeddings to disk.
            if local_rank == 0:
                if need_mid:
                    cand_embed_dict_mid = {cid: emb for cid, emb in zip(all_cand_ids, cand_embeds_mid)}
                    with open(cand_mid_path, "wb") as f:
                        pickle.dump(cand_embed_dict_mid, f)
                    print_master(f"[{dataset_name}] EE-only: saved {mid_tag} candidate embeddings -> {cand_mid_path}")
                if need_last:
                    cand_embed_dict_last = {cid: emb for cid, emb in zip(all_cand_ids, cand_embeds_last)}
                    with open(cand_last_path, "wb") as f:
                        pickle.dump(cand_embed_dict_last, f)
                    print_master(f"[{dataset_name}] EE-only: saved {last_tag} candidate embeddings -> {cand_last_path}")
        else:
            print_master(f"[{dataset_name}] EE-only: reuse existing candidates (mid={cand_mid_path}, last={cand_last_path})")

        if dist.is_initialized():
            dist.barrier()

        # Rank 0 runs mid-layer retrieval over the cached candidate embeddings.
        if local_rank == 0:
            # Pickle file is produced by this very script, so loading is trusted.
            with open(cand_mid_path, "rb") as f:
                cand_mid_dict = pickle.load(f)

            mid_layer = int(ee_cfg["layer"])
            rank_global = task_config.get("eval_type", "global") == "global"
            print_master(f"[{dataset_name}] Run MID-ONLY retrieval at layer={mid_layer}, global={rank_global}")

            run_mid_layer_retrieval(
                model=model,
                processor=processor,
                model_args=model_args,
                data_args=data_args,
                training_args=training_args,
                qry_dataset=full_eval_qry_dataset,
                cand_mid_dict=cand_mid_dict,
                mid_layer=mid_layer,
                dataset_name=dataset_name,
                out_dir=data_args.encode_output_path,
                global_ranking=rank_global,
                topk_save=200,
                save_qry_embeds=True,
            )

        if dist.is_initialized():
            dist.barrier()
|
|
|
# Script entry point.
if __name__ == '__main__':
    main()