# code_SAS_VLM2Vec/src/trainer_early_exit_AOP_pooling.py
import os
from contextlib import contextmanager
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import Trainer
from src.utils import batch_to_device
from src.classifier_utils import HomogeneousBatchSampler
# Manual focal loss implementation (alpha / gamma controllable)
def sigmoid_focal_loss(inputs, targets, alpha: float = 0.25, gamma: float = 2.0, reduction: str = "mean"):
"""
Loss = -alpha * (1 - p)^gamma * log(p)
"""
p = torch.sigmoid(inputs)
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
p_t = p * targets + (1 - p) * (1 - targets)
loss = ce_loss * ((1 - p_t) ** gamma)
if alpha >= 0:
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
loss = alpha_t * loss
if reduction == "mean":
return loss.mean()
elif reduction == "sum":
return loss.sum()
return loss
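# Worked example (illustrative only, not executed anywhere in this module):
#   logits = torch.tensor([3.0]); targets = torch.tensor([1.0])
#   p = sigmoid(3.0) ≈ 0.9526, so BCE ≈ 0.0486, (1 - p_t)^gamma ≈ 0.00225, alpha_t = 0.25
#   => focal loss ≈ 2.7e-5: a confident, correct prediction contributes far less than
#   under plain BCE (≈ 0.0486), while hard examples keep most of their weight.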
class EarlyExitTrainer(Trainer):
def __init__(self, backbone_model, target_layer_idx, model_args, *args, **kwargs):
self.max_length = kwargs.pop("max_length", 512)
super().__init__(*args, **kwargs)
self.backbone = backbone_model.to(self.args.device)
self.backbone.eval()
self.target_layer_idx = target_layer_idx
self.model_args = model_args
        # Per-side AOP on/off (env var AOP_APPLY): "qry" | "tgt" | "both"
self._aop_apply = os.getenv("AOP_APPLY", "both").strip().lower()
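        # Example: launching with AOP_APPLY=qry prunes only during the query-side
        # backbone forward, AOP_APPLY=tgt only during the target side; the default
        # "both" keeps AOP enabled for both forwards (see _enable_for_side / _aop_switch).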
self._grad_check_done = False
def _rank_of_diagonal(self, sim_mat: torch.Tensor):
"""
返回每个样本的正例(diag)在本行中的排名(1-based),以及 topk 命中率。
"""
B = sim_mat.size(0)
        # argsort each row in descending order
order = torch.argsort(sim_mat, dim=1, descending=True) # [B, B]
gt = torch.arange(B, device=sim_mat.device).view(-1, 1) # [B,1]
        # locate the diagonal (ground-truth) entry in each sorted row
ranks = (order == gt).nonzero(as_tuple=False)[:, 1] + 1 # 1-based
top1 = (ranks == 1).float().mean().item()
top5 = (ranks <= 5).float().mean().item() if B >= 5 else float('nan')
top10 = (ranks <= 10).float().mean().item() if B >= 10 else float('nan')
return ranks, top1, top5, top10
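        # Example: for sim_mat = [[0.9, 0.1], [0.2, 0.8]] the diagonal entry is the
        # largest in each row, so ranks = [1, 1] and top1 = 1.0 (top5/top10 are NaN
        # because the batch is smaller than 5).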
# ---------------- Dataloader ----------------
def get_train_dataloader(self) -> DataLoader:
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
train_sampler = HomogeneousBatchSampler(
self.train_dataset,
batch_size=self._train_batch_size,
drop_last=self.args.dataloader_drop_last,
)
return DataLoader(
self.train_dataset,
batch_sampler=train_sampler,
collate_fn=self.data_collator,
num_workers=self.args.dataloader_num_workers,
pin_memory=self.args.dataloader_pin_memory,
)
# ---------------- Optimizer ----------------
def create_optimizer(self):
if self.optimizer is None:
print(f"\n[Debug Rank {self.args.local_rank}] Creating Optimizer...")
decay_parameters = []
no_decay_parameters = []
trainable_count = 0
for name, param in self.model.named_parameters():
if not param.requires_grad:
continue
trainable_count += 1
if "bias" in name or "LayerNorm" in name or "BatchNorm" in name:
no_decay_parameters.append(param)
else:
decay_parameters.append(param)
print(f"[Debug] Found {trainable_count} trainable parameters.")
self.optimizer = torch.optim.AdamW(
[
{"params": decay_parameters, "weight_decay": self.args.weight_decay},
{"params": no_decay_parameters, "weight_decay": 0.0},
],
lr=self.args.learning_rate,
eps=self.args.adam_epsilon,
)
return self.optimizer
    # ---------------- helpers: per-side AOP switches ----------------
    def _enable_for_side(self, side: str) -> bool:
        side = side.lower()
        if self._aop_apply == "both":
            return True
        return self._aop_apply == side
@contextmanager
def _aop_switch(self, enable: bool):
"""
暂时按侧别启停 AOP(仅该 forward),同步 wrapper 与底座。
"""
enc = self.backbone.encoder
old = getattr(enc, "aop_prune_config", None)
def _set_cfg(mod, cfg):
setattr(mod, "aop_prune_config", cfg)
base = mod.get_base_model() if hasattr(mod, "get_base_model") else None
if base is None and hasattr(mod, "model"):
base = mod.model
if base is not None:
setattr(base, "aop_prune_config", cfg)
if hasattr(base, "model"):
setattr(base.model, "aop_prune_config", cfg)
if old is not None and enable is False:
cfg = dict(old) if isinstance(old, dict) else None
            if isinstance(cfg, dict):
                cfg["enabled"] = False
_set_cfg(enc, cfg)
try:
yield
finally:
_set_cfg(enc, old)
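    # Usage (inside this trainer): wrapping a backbone forward in
    #   with self._aop_switch(self._enable_for_side("qry")): ...
    # leaves the existing aop_prune_config untouched when that side is enabled, and
    # writes a disabled copy of the config for just that forward otherwise; the
    # original config is restored in the `finally` block.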
# ---------------- Pooling ----------------
def _perform_pooling(self, hidden_state, attention_mask):
"""
修复版 Pooling:增加索引边界检查,防止 CUDA Index Out of Bounds
"""
pooling_method = self.model_args.pooling
batch_size, seq_len, _ = hidden_state.shape
if pooling_method in ("last", "eos"):
if attention_mask is None:
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long, device=hidden_state.device)
            # Detect the padding side: if the last column of the attention mask is
            # all ones, there is no right padding, so the batch is almost certainly
            # left-padded (or the model uses left padding by design).
is_left_padding = (attention_mask[:, -1].sum() == batch_size)
if is_left_padding:
                # Left padding: valid tokens are packed to the right; take the last token.
reps = hidden_state[:, -1, :]
else:
                # Right padding: valid tokens are on the left; compute each sequence length.
if attention_mask.shape[1] > seq_len:
attention_mask = attention_mask[:, :seq_len]
eos_indices = attention_mask.sum(dim=1) - 1
eos_indices = eos_indices.clamp(min=0, max=seq_len - 1)
indices_expanded = eos_indices.unsqueeze(1).unsqueeze(2).expand(batch_size, 1, hidden_state.size(-1))
reps = torch.gather(hidden_state, 1, indices_expanded).squeeze(1)
else:
            # Fallback for other pooling methods (mean pooling is not implemented here): take the last token.
reps = hidden_state[:, -1, :]
if self.model_args.normalize:
reps = F.normalize(reps, p=2, dim=-1)
return reps
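    # Example of the two branches above: with a right-padded mask row [1, 1, 1, 0, 0]
    # the EOS index is sum(mask) - 1 = 2, so the hidden state at position 2 is gathered;
    # with left padding the last mask column is all ones, so position -1 is taken directly.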
def _match_mask(self, h, pre_mask, post_mask):
"""
选择与该层 hidden_state 长度匹配的 mask(优先 post,再 pre;否则全1兜底)
"""
if post_mask is not None and post_mask.size(1) == h.size(1):
return post_mask
if pre_mask is not None and pre_mask.size(1) == h.size(1):
return pre_mask
return torch.ones(h.size(0), h.size(1), dtype=torch.long, device=h.device)
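    # The pre- and post-forward masks can have different lengths here, presumably
    # because AOP token pruning (and/or visual-token pooling) shortens the sequence
    # inside the encoder; this helper simply picks whichever mask matches the layer.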
# ---------------- Loss entry ----------------
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
loss = self._compute_early_exit_loss(model, inputs)
return (loss, None) if return_outputs else loss
# ---------------- Core loss (V5) ----------------
def _compute_early_exit_loss(self, model, inputs) -> torch.Tensor:
self.backbone.eval()
model.train()
device = self.args.device
qry_inputs, tgt_inputs = inputs
        # === Chunk size for the backbone forward passes ===
CHUNK_SIZE = 128
def forward_chunked(input_batch, side="tgt"):
"""
分块执行 Backbone Forward,提取特征后立即释放显存。
"""
# 1. 确定样本总数
total_len = input_batch["input_ids"].shape[0]
reps_mid_list = []
reps_last_list = []
            # 2. Iterate over chunk slices
            for i in range(0, total_len, CHUNK_SIZE):
                # 2.1 Slice out a small batch
chunk = {}
for k, v in input_batch.items():
if v is None:
chunk[k] = None
continue
if isinstance(v, torch.Tensor):
chunk[k] = v[i : i + CHUNK_SIZE]
elif isinstance(v, list):
chunk[k] = v[i : i + CHUNK_SIZE]
else:
chunk[k] = v
                # 2.2 Move the chunk to the GPU
chunk = batch_to_device(chunk, device)
# 2.3 Backbone Forward (No Grad)
with torch.no_grad():
with self._aop_switch(self._enable_for_side(side)):
outputs = self.backbone.encoder(
**chunk, return_dict=True, output_hidden_states=True
)
                # 2.4 Extract the needed layers immediately
pre_mask = chunk.get("attention_mask", None)
post_mask = getattr(outputs, "attention_mask", None)
                # Mid-layer pooling
h_mid = outputs.hidden_states[self.target_layer_idx]
m_mid = self._match_mask(h_mid, pre_mask, post_mask)
r_mid = self._perform_pooling(h_mid, m_mid)
                # Last-layer pooling
h_last = outputs.hidden_states[-1]
m_last = self._match_mask(h_last, pre_mask, post_mask)
r_last = self._perform_pooling(h_last, m_last)
                # 2.5 Collect the pooled reps
reps_mid_list.append(r_mid)
reps_last_list.append(r_last)
                # 2.6 Drop references explicitly to help free GPU memory
del outputs, h_mid, h_last, chunk, pre_mask, post_mask
# torch.cuda.empty_cache()
            # 3. Concatenate all chunks
return torch.cat(reps_mid_list, dim=0), torch.cat(reps_last_list, dim=0)
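        # Each call returns two tensors of shape [total_len, hidden_dim]
        # (mid-layer and last-layer pooled reps). The backbone is frozen and run under
        # no_grad, so chunking only bounds peak activation memory; CHUNK_SIZE = 128 is
        # a memory/speed trade-off and can be lowered if the backbone still runs out of memory.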
        # === 1. Extract features with the chunked forward ===
tgt_reps_mid, tgt_reps_last = forward_chunked(tgt_inputs, side="tgt")
qry_reps_mid, qry_reps_last = forward_chunked(qry_inputs, side="qry")
        # === 2. Similarity and loss ===
batch_size = qry_reps_mid.size(0)
backbone_ptr = self.backbone.module if hasattr(self.backbone, "module") else self.backbone
temp = getattr(backbone_ptr, "temperature", 0.02)
        # Similarity matrices
cos_mid = torch.matmul(qry_reps_mid, tgt_reps_mid.T)
cos_last = torch.matmul(qry_reps_last, tgt_reps_last.T)
scores_mid = cos_mid / temp
probs_mid = torch.softmax(scores_mid, dim=1)
        # ====== Debug: keep only the rank printout; the mask check has been removed ======
if self.state.global_step < 3 and self.args.local_rank == 0:
            # 1) Positive (diagonal) ranks for mid / last layers
ranks_mid, top1_mid, top5_mid, top10_mid = self._rank_of_diagonal(cos_mid)
ranks_last, top1_last, top5_last, top10_last = self._rank_of_diagonal(cos_last)
            # [Removed] _check_mask and its call, since the hidden states are no longer accessible here
            # 2) Environment / per-side switches
print(
f"[DBG][env] AOP_ENABLED={os.getenv('AOP_ENABLED')} "
f"APPLY={os.getenv('AOP_APPLY')} "
f"LAYER={os.getenv('AOP_LAYER')} "
f"SELECTION={os.getenv('AOP_SELECTION')} "
f"KEEP_T={os.getenv('AOP_KEEP_RATIO_TEXT')} "
f"KEEP_V={os.getenv('AOP_KEEP_RATIO_VISION')} "
f"VPOOL_ENABLED={os.getenv('VPOOL_ENABLED')} "
f"VPOOL_LAYER={os.getenv('VPOOL_LAYER')}",
flush=True
)
            # 3) Positive-rank distribution and top-k hit rates
def _brief(ranks):
r = ranks.detach().cpu()
return {
"min": int(r.min().item()),
"p25": int(r.kthvalue(max(1, int(0.25*len(r)))).values.item()) if len(r) >= 4 else None,
"med": int(r.median().item()),
"p75": int(r.kthvalue(max(1, int(0.75*len(r)))).values.item()) if len(r) >= 4 else None,
"max": int(r.max().item())
}
print(f"[RANK][mid] top1={top1_mid:.2%} top5={top5_mid:.2%} top10={top10_mid:.2%} dist={_brief(ranks_mid)}", flush=True)
print(f"[RANK][last] top1={top1_last:.2%} top5={top5_last:.2%} top10={top10_last:.2%} dist={_brief(ranks_last)}", flush=True)
if top1_last < 0.4:
print("[WARN] last layer top1 < 40%. 建议先 AOP_ENABLED=0/VPOOL_ENABLED=0 进行对照,确认基座检索能力。", flush=True)
        # === Feature construction (27 dims) ===
        # Note: `diag_cos` below is the row-wise top-1 similarity (the row max), not
        # the true diagonal entry, presumably so these features remain computable
        # without labels at inference time.
diag_cos = cos_mid.max(dim=1)[0]
sorted_cos, _ = torch.sort(cos_mid, dim=1, descending=True)
s2_cos = sorted_cos[:, 1] if sorted_cos.size(1) > 1 else sorted_cos[:, 0]
margin_mid = diag_cos - s2_cos
        # Margin statistics
margin_mean = margin_mid.mean()
margin_std = margin_mid.std(unbiased=False) + 1e-6
z_margin_mid = (margin_mid - margin_mean) / margin_std
margin_median = margin_mid.median()
mad = (margin_mid - margin_median).abs().median() + 1e-6
mad_margin_mid = (margin_mid - margin_median) / mad
p1_mid = probs_mid.max(dim=1)[0]
H_mid = -(probs_mid * torch.log(probs_mid + 1e-6)).sum(dim=1)
gini_mid = 1.0 - (probs_mid ** 2).sum(dim=1)
TOPK = min(16, probs_mid.size(1))
topk_vals, _ = torch.topk(probs_mid, k=TOPK, dim=1)
topk_mean = topk_vals.mean(dim=1)
topk_std = topk_vals.std(dim=1, unbiased=False)
topk_cv = topk_std / (topk_mean + 1e-6)
centered = topk_vals - topk_mean.unsqueeze(1)
var = (centered ** 2).mean(dim=1) + 1e-6
m4 = (centered ** 4).mean(dim=1)
topk_kurt = m4 / (var ** 2)
topk_med = topk_vals.median(dim=1).values
row_mean_cos = cos_mid.mean(dim=1)
row_med_cos = cos_mid.median(dim=1).values
s1_over_mean = diag_cos - row_mean_cos
s1_over_med = diag_cos - row_med_cos
sorted_probs, _ = torch.sort(probs_mid, dim=1, descending=True)
p1 = sorted_probs[:, 0]
p2 = sorted_probs[:, 1] if sorted_probs.size(1) > 1 else sorted_probs[:, 0]
shape_H = -(sorted_probs * torch.log(sorted_probs + 1e-6)).sum(dim=1)
shape_gini = 1.0 - (sorted_probs ** 2).sum(dim=1)
R = min(10, sorted_probs.size(1))
x = torch.arange(R, device=device, dtype=sorted_probs.dtype)
x_centered = x - x.mean()
denom = (x_centered ** 2).sum()
y = torch.log(sorted_probs[:, :R] + 1e-6)
slope = (x_centered.unsqueeze(0) * y).sum(dim=1) / denom
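        # The `slope` above is the closed-form least-squares slope of log-probability
        # vs. rank over the top R entries: a steeply negative slope means probability
        # mass is concentrated on a few candidates.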
row_mean_p = probs_mid.mean(dim=1)
row_std_p = probs_mid.std(dim=1, unbiased=False) + 1e-6
z1 = (p1_mid - row_mean_p) / row_std_p
center_p = probs_mid - row_mean_p.unsqueeze(1)
m3 = (center_p ** 3).mean(dim=1)
skew = m3 / (row_std_p ** 3 + 1e-6)
s1_over_sk = p1_mid - skew
TAIL_K = min(10, sorted_probs.size(1))
tail_mean = sorted_probs[:, -TAIL_K:].mean(dim=1)
HEAD_K = min(5, sorted_probs.size(1))
head5_mean = sorted_probs[:, :HEAD_K].mean(dim=1)
mask_ratio = torch.zeros_like(diag_cos)
mask_len = torch.zeros_like(diag_cos)
mask_runs = torch.zeros_like(diag_cos)
scalar_inputs = torch.stack(
[
diag_cos, s2_cos, margin_mid, z_margin_mid, mad_margin_mid,
p1_mid, H_mid, gini_mid,
topk_mean, topk_std, topk_cv, topk_kurt, topk_med,
s1_over_mean, s1_over_med,
p1, p2, shape_H, shape_gini, slope, z1, s1_over_sk,
tail_mean, head5_mean,
mask_ratio, mask_len, mask_runs,
],
dim=1,
)
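        # scalar_inputs has shape [batch_size, 27]; the last three features
        # (mask_ratio / mask_len / mask_runs) are zero placeholders in this variant.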
# Modality Index
modality_idx = torch.zeros(batch_size, dtype=torch.long, device=device)
if "pixel_values" in qry_inputs and qry_inputs["pixel_values"] is not None:
pv = qry_inputs["pixel_values"]
if isinstance(pv, list):
for i, item in enumerate(pv):
if item is not None: modality_idx[i] = 1
elif isinstance(pv, torch.Tensor) and pv.numel() > 0:
modality_idx.fill_(1)
# Labels
gt = torch.arange(batch_size, device=device)
mid_top1 = cos_mid.argmax(dim=1)
last_top1 = cos_last.argmax(dim=1)
mid_hit = mid_top1.eq(gt)
last_hit = last_top1.eq(gt)
need_last = (~mid_hit) & last_hit
labels = need_last.float().unsqueeze(1)
both_correct = mid_hit & last_hit
both_wrong = (~mid_hit) & (~last_hit)
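        # Label semantics: need_last = 1 only when the mid-layer top-1 misses the
        # ground truth but the last layer hits it, i.e. the classifier learns to
        # predict "exiting early would lose this query". both_correct / both_wrong
        # are computed for logging only.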
        # =======================================================
        # Classifier forward (float32)
        # =======================================================
scalar_inputs_f32 = scalar_inputs.float()
qry_reps_mid_f32 = qry_reps_mid.float()
logits = model(scalar_inputs_f32, modality_idx, qry_emb=qry_reps_mid_f32)
        # Loss: V5-tuned focal loss (alpha=0.80, gamma=3.0)
loss = sigmoid_focal_loss(logits, labels, alpha=0.80, gamma=3.0, reduction="mean")
pred_probs = torch.sigmoid(logits)
# Logging
if self.state.global_step < 10 and self.args.local_rank == 0:
pos_ratio = labels.mean().item()
neg_ratio = 1.0 - pos_ratio
print(f"\n[Probe Step {self.state.global_step}] Loss: {loss.item():.4f}", flush=True)
print(f" - Pred Probs (need_last=1): mean={pred_probs.mean().item():.4f}, std={pred_probs.std().item():.4f}", flush=True)
print(f" - Labels: need_last={pos_ratio:.4f}, safe={neg_ratio:.4f}", flush=True)
print(f" - mid_hit: {mid_hit.float().mean().item():.2%}, last_hit: {last_hit.float().mean().item():.2%}", flush=True)
print(f" - both_correct: {both_correct.float().mean().item():.2%}, both_wrong: {both_wrong.float().mean().item():.2%}", flush=True)
return loss
# ---------------- training_step ----------------
def training_step(self, model, inputs, num_items_in_batch=None) -> torch.Tensor:
model.train()
inputs = self._prepare_inputs(inputs)
with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs)
if self.args.n_gpu > 1:
loss = loss.mean()
self.accelerator.backward(loss)
if not self._grad_check_done and self.args.local_rank == 0:
print(f"\n[Gradient Check After Backward - Step {self.state.global_step}]", flush=True)
inner_model = model.module if hasattr(model, "module") else model
has_grad = False
total_grad_norm = 0.0
for name, param in inner_model.named_parameters():
if param.grad is not None:
has_grad = True
grad_norm = param.grad.norm().item()
total_grad_norm += grad_norm ** 2
total_grad_norm = total_grad_norm ** 0.5
print(f" - Total Grad Norm: {total_grad_norm:.6f}", flush=True)
print(f" - Has Gradient: {has_grad}", flush=True)
if self.state.global_step >= 2:
self._grad_check_done = True
return loss.detach() / self.args.gradient_accumulation_steps