# code_SAS_VLM2Vec/src/model/model_add_mlp.py
import os
from typing import Dict, Optional, Tuple, Union
import torch
import torch.distributed as dist
from torch import nn, Tensor
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoConfig
from peft import LoraConfig, get_peft_model, PeftModel
from src.arguments import ModelArguments
from src.model.processor import (
    LLAVA_NEXT, QWEN2_VL, QWEN2_5_VL, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION,
    PHI3V, E5_V, INTERNVIDEO2, GME, LamRA, LamRA_QWEN2_5, COLPALI, VLM_IMAGE_TOKENS,
    backbone2model, get_backbone_name, print_master,
)
from src.model.baseline_backbone.colpali import ColPali
from src.model.baseline_backbone.gme.gme_inference import GmeQwen2VL
from src.model.baseline_backbone.lamra.lamra_inference import LamRAQwen2VL
from src.model.baseline_backbone.lamra.lamra_qwen25_inference import LamRAQwen25VL
from src.model.baseline_backbone.phi3_v.modeling_phi3_v import Phi3VForCausalLM
from src.model.baseline_backbone.llava_next import LlavaNextForConditionalGeneration
from transformers import modeling_utils
# Compatibility shim: some transformers versions leave ALL_PARALLEL_STYLES unset,
# which can break from_pretrained's tensor-parallel plan checks.
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", "rowwise"]
try:
    from safetensors.torch import load_file as safe_load
except ImportError:
    safe_load = None
# Token-wise MLP used as the early-exit head (defined here, before MMEBModel).
class TokenWiseMLP(nn.Module):
def __init__(self, hidden_size: int, mlp_hidden_size: Optional[int] = None, dropout: float = 0.1):
super().__init__()
h = mlp_hidden_size or hidden_size
self.proj = nn.Sequential(
nn.Linear(hidden_size, h),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(h, hidden_size),
)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, T, H]; nn.Linear operates on the last dim, so no reshape is needed
        return self.proj(x)
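# A minimal shape check for TokenWiseMLP (hypothetical sizes): the module acts
# per token, so it preserves [B, T, H].
#   mlp = TokenWiseMLP(hidden_size=1024, mlp_hidden_size=2048)
#   y = mlp(torch.randn(2, 16, 1024))  # y.shape == torch.Size([2, 16, 1024])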
class MMEBModel(nn.Module):
TRANSFORMER_CLS = AutoModelForCausalLM
def __init__(self,
encoder: PreTrainedModel,
pooling: str = 'last',
normalize: bool = False,
temperature: float = 0.02,
):
super().__init__()
self.config = encoder.config
self.encoder = encoder
self.pooling = pooling
self.normalize = normalize
self.temperature = temperature
self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
self.is_ddp = dist.is_initialized() and dist.get_world_size() > 1
self.process_rank = dist.get_rank() if self.is_ddp else 0
self.world_size = dist.get_world_size() if self.is_ddp else 1
        # Early-exit settings: disabled by default; build()/load() decide whether to
        # enable them based on ModelArguments.
        self.enable_early_mlp = False
        self.early_layer_index = 20  # exit after hidden layer 20 by default
        # > 0 enables a joint loss in stage two; in stage one only the early head is
        # trained, which the trainer routes through compute_early_only.
        self.early_loss_weight = 0.0
def add_early_mlp(self, layer_index: int = 20,
mlp_hidden_size: Optional[int] = None,
dropout: float = 0.1):
        # Attach early_mlp under the encoder so it is saved/loaded together with it.
        hidden_size = getattr(self.config, "hidden_size", None)
        if hidden_size is None:
            raise ValueError("config.hidden_size not found; cannot initialize the early MLP")
        self.enable_early_mlp = True
        self.early_layer_index = int(layer_index)
        # Hanging the module off the encoder makes its state_dict keys encoder.early_mlp.*
        self.encoder.early_mlp = TokenWiseMLP(hidden_size, mlp_hidden_size, dropout)
        # Record the settings in encoder.config so they are persisted to config.json.
        setattr(self.encoder.config, "enable_early_mlp", True)
        setattr(self.encoder.config, "early_layer_index", self.early_layer_index)
        setattr(self.encoder.config, "early_mlp_hidden_size", mlp_hidden_size or hidden_size)
        setattr(self.encoder.config, "early_mlp_dropout", dropout)
def _apply_early(self, hidden: torch.Tensor) -> torch.Tensor:
        # Pass the token hidden states from layer `early_layer_index` through the MLP
        # to produce the early-exit representation.
        if not self.enable_early_mlp or not hasattr(self.encoder, "early_mlp"):
            raise RuntimeError("early_mlp is not enabled; cannot compute the early-exit representation")
        return self.encoder.early_mlp(hidden)
    def encode_input(self, input, return_early: bool = False) -> Union[Tensor, Tuple[Tensor, Tensor]]:
"""
返回:
- 当 return_early=False: 只返回 final_pooled
- 当 return_early=True: 返回 (final_pooled, early_pooled)
"""
if getattr(self, "model_backbone", None) == INTERNVIDEO2:
if "input_ids" in input.keys():
# text side
text_output = self.encoder.get_text_encoder()(
input["input_ids"],
attention_mask=input["attention_mask"],
return_dict=True,
mode="text",
)
text_embeds = text_output.last_hidden_state
pooled_text_embeds = text_embeds[:, 0]
pooled_output = self.encoder.text_proj(pooled_text_embeds)
pooled_output /= pooled_output.norm(dim=-1, keepdim=True)
return pooled_output
else:
_, vfeat = self.encoder.encode_vision(input["pixel_values"], test=True)
vfeat = self.encoder.vision_proj(vfeat)
vfeat /= vfeat.norm(dim=-1, keepdim=True)
return vfeat
elif getattr(self, "model_backbone", None) in [GME, LamRA, LamRA_QWEN2_5]:
            # Strip the image token the processor may have prepended; these are video
            # queries, so the token should not normally be present.
            texts = [text.replace(VLM_IMAGE_TOKENS[QWEN2_VL] + '\n', '') for text in input["texts"]]
images = []
            for imgs in input['images']:
                # If multiple frames are given, keep only the middle one.
                if isinstance(imgs, list):
                    imgs = imgs[len(imgs) // 2]
                    assert not isinstance(imgs, list)  # middle frame extracted; no longer a list
                images.append(imgs)
pooled_output = self.encoder.get_fused_embeddings(texts=texts, images=images)
return pooled_output
elif getattr(self, "model_backbone", None) == COLPALI:
pooled_output = self.encoder(**input, return_dict=True, output_hidden_states=True)
return pooled_output
elif getattr(self, "model_backbone", None) == LLAVA_NEXT:
input['pixel_values'] = input['pixel_values'].squeeze(dim=1)
input['image_sizes'] = input['image_sizes'].squeeze(dim=1)
hidden_states = self.encoder(**input, return_dict=True, output_hidden_states=True)
hidden_states = hidden_states.hidden_states[-1]
pooled_output = self._pooling(hidden_states, input['attention_mask'])
return pooled_output
else:
            # Generic path: most HF models expose hidden_states here.
outputs = self.encoder(**input, return_dict=True, output_hidden_states=True)
hidden_states = outputs.hidden_states
final_hidden = hidden_states[-1]
final_pooled = self._pooling(final_hidden, input['attention_mask'])
if return_early and self.enable_early_mlp:
idx = min(max(0, self.early_layer_index), len(hidden_states) - 1)
early_hidden = hidden_states[idx]
early_hidden = self._apply_early(early_hidden)
early_pooled = self._pooling(early_hidden, input['attention_mask'])
return final_pooled, early_pooled
return final_pooled
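    # Example call patterns for encode_input (hypothetical batch dict `inputs`
    # holding input_ids / attention_mask and any vision tensors):
    #   final_rep = model.encode_input(inputs)                        # [B, H]
    #   final_rep, early_rep = model.encode_input(inputs, return_early=True)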
def _pooling(self, last_hidden_state, attention_mask):
        if self.pooling in ('last', 'eos'):
            # With left padding every sequence ends at the last position; with right
            # padding the EOS token sits at index (number of real tokens - 1) per row.
            left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0]) if attention_mask is not None else False
batch_size = last_hidden_state.shape[0]
if attention_mask is None:
reps = last_hidden_state[:, -1, :]
elif left_padding:
reps = last_hidden_state[torch.arange(batch_size), -1, :]
else:
eos_indices = attention_mask.sum(dim=1) - 1
reps = last_hidden_state[
torch.arange(batch_size, device=last_hidden_state.device), eos_indices]
else:
raise NotImplementedError
if self.normalize:
reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
return reps
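    # Worked example for the right-padding branch above (hypothetical mask):
    #   attention_mask = [[1, 1, 1, 0],
    #                     [1, 1, 0, 0]]
    #   eos_indices = attention_mask.sum(dim=1) - 1 = [2, 1]
    # so reps takes the hidden state at position 2 for row 0 and position 1 for row 1.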
@classmethod
def build(cls, model_args: ModelArguments, **kwargs):
config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
model_backbone = get_backbone_name(hf_config=config)
print_master(f'Loading backbone [{model_backbone}] from {model_args.model_name}')
# Loading the base model
if model_backbone == PHI3V:
config._attn_implementation = "eager"
config.padding_side = "right"
config.use_cache = False
base_model = Phi3VForCausalLM.from_pretrained(
model_args.model_name,
config=config,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
)
elif model_backbone == LLAVA_NEXT:
config.use_cache = False
config.padding_side = "left"
base_model = LlavaNextForConditionalGeneration.from_pretrained(
model_args.model_name,
config=config,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
)
elif model_backbone in [QWEN2_VL, QWEN2_5_VL]:
config._attn_implementation = "flash_attention_2"
config.padding_side = "left"
config.use_cache = False
base_model = backbone2model[model_backbone].from_pretrained(
model_args.model_name,
config=config,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
)
elif model_backbone in [QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION]:
config._attn_implementation = "flash_attention_2"
config.padding_side = "left"
config.use_cache = False
from .utils import parse_layer_type
lm_qwen_layer = 28
vis_qwen_layer = 32
lm_skip_layer = parse_layer_type(model_args.lm_skip_layer, lm_qwen_layer)
vis_skip_layer = parse_layer_type(model_args.vis_skip_layer, vis_qwen_layer)
base_model = backbone2model[model_backbone].from_pretrained(
model_args.model_name,
config=config,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
lm_skip_layer=lm_skip_layer,
vis_skip_layer=vis_skip_layer,
)
else:
config.use_cache = False
base_model = cls.TRANSFORMER_CLS.from_pretrained(
model_args.model_name, **kwargs, config=config,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
trust_remote_code=True)
if model_args.lora:
            # Prefer loading an existing LoRA adapter from model_args.model_name;
            # otherwise create a new one.
def _has_adapter_files(path: str):
if not (path and os.path.isdir(path)):
return False
for fname in ("adapter_model.safetensors", "adapter_model.bin", "adapter_config.json"):
if os.path.exists(os.path.join(path, fname)):
return True
return False
if _has_adapter_files(model_args.model_name):
print_master(f"[build] detected LoRA adapter in '{model_args.model_name}', loading pretrained adapter.")
lora_config = LoraConfig.from_pretrained(model_args.model_name)
lora_model = PeftModel.from_pretrained(
base_model, model_args.model_name, config=lora_config, is_trainable=True
)
                # No merge needed here: keep the adapter live; freezing/training is
                # controlled by the caller.
else:
print_master(f"[build] no adapter files in '{model_args.model_name}', create a new LoRA adapter.")
lora_config = LoraConfig(
r=model_args.lora_r,
lora_alpha=model_args.lora_alpha,
target_modules=model_args.lora_target_modules.split(','),
lora_dropout=model_args.lora_dropout,
init_lora_weights="gaussian",
use_dora=True,
inference_mode=False
)
lora_model = get_peft_model(base_model, lora_config)
model = cls(
encoder=lora_model,
pooling=model_args.pooling,
normalize=model_args.normalize,
temperature=model_args.temperature
)
else:
model = cls(
encoder=base_model,
pooling=model_args.pooling,
normalize=model_args.normalize,
temperature=model_args.temperature
)
        # Optional early-exit MLP
if getattr(model_args, "enable_early_mlp", False):
model.add_early_mlp(
layer_index=getattr(model_args, "early_layer_index", 20),
mlp_hidden_size=getattr(model_args, "early_mlp_hidden_size", None),
dropout=getattr(model_args, "early_mlp_dropout", 0.1),
)
            # For a joint loss in stage two, pass model_args.early_loss_weight (> 0).
model.early_loss_weight = float(getattr(model_args, "early_loss_weight", 0.0))
model.model_backbone = model_backbone
return model
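    # A minimal build sketch (hypothetical values; only the ModelArguments fields
    # referenced above are assumed to exist):
    #   model_args = ModelArguments(model_name="Qwen/Qwen2-VL-7B-Instruct", lora=True,
    #                               pooling="last", normalize=True, temperature=0.02)
    #   model = MMEBModel.build(model_args)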
@classmethod
def load(cls, model_args: ModelArguments, is_trainable=True, **kwargs):
# Loading the base model
model_name_or_path = model_args.checkpoint_path if model_args.checkpoint_path else model_args.model_name
config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
if not hasattr(model_args, "model_backbone") or not model_args.model_backbone:
model_backbone = get_backbone_name(hf_config=config, model_type=model_args.model_type)
setattr(model_args, 'model_backbone', model_backbone)
print_master(f'Loading backbone [{model_args.model_backbone}] from {model_name_or_path}')
if model_args.model_backbone in {LLAVA_NEXT, QWEN2_VL, QWEN2_5_VL, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION, E5_V}:
config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
config._attn_implementation = "flash_attention_2"
config.vision_config._attn_implementation = "flash_attention_2"
base_model = backbone2model[model_args.model_backbone].from_pretrained(
model_args.model_name,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
config=config
)
elif model_args.model_backbone == PHI3V:
config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
config.use_cache = False
config.padding_side = "right"
base_model = Phi3VForCausalLM.from_pretrained(model_args.model_name, **kwargs, config=config,
torch_dtype=torch.bfloat16, trust_remote_code=True)
base_model.padding_side = "right"
elif model_args.model_backbone == INTERNVIDEO2:
            internvideo2_path = "src/model/vlm_backbone/internvideo2/"
            print_master(f'Loading backbone [{model_args.model_backbone}] from {internvideo2_path}')
            config = AutoConfig.from_pretrained(internvideo2_path, trust_remote_code=True)
            base_model = backbone2model[model_args.model_backbone].from_pretrained(
                internvideo2_path, config=config, trust_remote_code=True)
elif model_args.model_backbone == GME:
base_model = GmeQwen2VL(model_args.model_name, processor=kwargs['processor'])
setattr(base_model, 'config', config)
elif model_args.model_backbone == LamRA:
base_model = LamRAQwen2VL(model_args.model_name)
setattr(base_model, 'config', config)
elif model_args.model_backbone == LamRA_QWEN2_5:
base_model = LamRAQwen25VL(model_args.model_name)
setattr(base_model, 'config', config)
elif model_args.model_backbone == COLPALI:
base_model = ColPali.from_pretrained(model_args.model_name)
setattr(base_model, 'config', config)
else:
# Loading external base model from HF
config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
config.use_cache = False
base_model = cls.TRANSFORMER_CLS.from_pretrained(
model_name_or_path, **kwargs, config=config,
torch_dtype=torch.bfloat16,
trust_remote_code=True)
# Building the model on top of the base
if model_args.lora:
print_master(f'Loading LoRA from {model_name_or_path}')
            # Detect whether adapter files exist; if not, fall back to
            # model_args.model_name.
def _has_adapter_files(path: str):
if not (path and os.path.isdir(path)):
return False
for fname in ("adapter_model.safetensors", "adapter_model.bin", "adapter_config.json"):
if os.path.exists(os.path.join(path, fname)):
return True
return False
adapter_source = model_name_or_path if _has_adapter_files(model_name_or_path) else model_args.model_name
if adapter_source != model_name_or_path:
print_master(f"[load] adapter files not found in '{model_name_or_path}', fallback to '{adapter_source}'")
            # Some peft versions require the config to be loaded separately.
lora_config = LoraConfig.from_pretrained(adapter_source)
lora_model = PeftModel.from_pretrained(base_model, adapter_source, config=lora_config, is_trainable=is_trainable)
lora_model.load_adapter(adapter_source, lora_model.active_adapter, is_trainable=is_trainable)
if not is_trainable:
lora_model = lora_model.merge_and_unload()
model = cls(
encoder=lora_model,
pooling=model_args.pooling,
normalize=model_args.normalize,
temperature=model_args.temperature
)
else:
model = cls(
encoder=base_model,
pooling=model_args.pooling,
normalize=model_args.normalize,
temperature=model_args.temperature
)
model.model_backbone = model_args.model_backbone
        # Early-exit MLP: enable per args/checkpoint config, then try to load the
        # early_mlp weights from the checkpoint.
enable_early = bool(getattr(model_args, "enable_early_mlp", False) or getattr(model.encoder.config, "enable_early_mlp", False))
if enable_early:
layer_index = int(getattr(model_args, "early_layer_index", getattr(model.encoder.config, "early_layer_index", 20)))
mlp_hidden_size = getattr(model_args, "early_mlp_hidden_size", getattr(model.encoder.config, "early_mlp_hidden_size", None))
dropout = getattr(model_args, "early_mlp_dropout", getattr(model.encoder.config, "early_mlp_dropout", 0.1))
model.add_early_mlp(layer_index=layer_index, mlp_hidden_size=mlp_hidden_size, dropout=dropout)
model.early_loss_weight = float(getattr(model_args, "early_loss_weight", 0.0))
            # Load early_mlp.* weights from the checkpoint.
ckpt_dir = model_args.checkpoint_path if model_args.checkpoint_path else model_args.model_name
if os.path.isdir(ckpt_dir):
pt_path = os.path.join(ckpt_dir, "pytorch_model.bin")
st_path = os.path.join(ckpt_dir, "model.safetensors")
state = None
if os.path.exists(pt_path):
state = torch.load(pt_path, map_location="cpu")
elif os.path.exists(st_path) and safe_load is not None:
state = safe_load(st_path)
if state is not None:
early_state = {
k.replace("early_mlp.", "", 1): v
for k, v in state.items()
if k.startswith("early_mlp.")
}
if early_state:
missing, unexpected = model.encoder.early_mlp.load_state_dict(early_state, strict=False)
print_master(
f"Loaded early_mlp weights from {ckpt_dir} (missing={missing}, unexpected={unexpected})"
)
            # Fallback: try loading from a standalone early_mlp.bin.
bin_path = os.path.join(ckpt_dir, "early_mlp.bin")
if os.path.exists(bin_path):
try:
early_bin = torch.load(bin_path, map_location="cpu")
missing, unexpected = model.encoder.early_mlp.load_state_dict(early_bin, strict=False)
print_master(
f"Loaded early_mlp weights from early_mlp.bin (missing={missing}, unexpected={unexpected})"
)
except Exception as e:
print_master(f"Failed to load early_mlp from early_mlp.bin: {e}")
return model
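    # A minimal load sketch (hypothetical path): checkpoint_path may point at a LoRA
    # adapter directory and/or a directory holding early_mlp weights saved by save().
    #   model_args.checkpoint_path = "outputs/stage1_ckpt"
    #   model = MMEBModel.load(model_args, is_trainable=False)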
    def save(self, output_dir: str):
        # early_mlp is registered under the encoder, so it is saved alongside it.
        self.encoder.save_pretrained(output_dir)
def forward(self, qry: Dict[str, torch.Tensor] = None, tgt: Dict[str, torch.Tensor] = None,
compute_early_only: bool = False, *args, **kwargs):
        # Support single-sided forward (GradCache calls this way when collecting
        # representations).
only_q = (qry is not None) and (tgt is None)
only_t = (tgt is not None) and (qry is None)
if only_q or only_t:
            # GradCache's get_rep_fn expects the final-layer retrieval representation;
            # do not return the early one here.
single = qry if only_q else tgt
rep = self.encode_input(single, return_early=False) # [B, D]
return {"qry_reps": rep if only_q else None, "tgt_reps": rep if only_t else None}
        # If neither side is given, split_and_process packaging is broken; raise a
        # clearer error than the default.
if qry is None and tgt is None:
raise ValueError("MMEBModel.forward expected 'qry' and/or 'tgt' but got none. "
"Check split_and_process_vlm_inputs / training_step packaging.")
        # Two-sided path (normal loss computation).
if compute_early_only:
if not self.enable_early_mlp:
                raise RuntimeError("compute_early_only=True but early_mlp is not enabled")
qry_final, qry_early = self.encode_input(qry, return_early=True)
tgt_final, tgt_early = self.encode_input(tgt, return_early=True)
assert qry_early is not None and tgt_early is not None
qry_reps, tgt_reps = qry_early, tgt_early
else:
if self.enable_early_mlp and (self.training and self.early_loss_weight > 0):
qry_final, qry_early = self.encode_input(qry, return_early=True)
tgt_final, tgt_early = self.encode_input(tgt, return_early=True)
else:
qry_final = self.encode_input(qry, return_early=False)
tgt_final = self.encode_input(tgt, return_early=False)
qry_early, tgt_early = None, None
qry_reps, tgt_reps = qry_final, tgt_final
# DDP gather
if self.is_ddp:
all_qry_reps = self._dist_gather_tensor(qry_reps)
all_tgt_reps = self._dist_gather_tensor(tgt_reps)
else:
all_qry_reps = qry_reps
all_tgt_reps = tgt_reps
        # Main contrastive loss. When each query has k candidate targets (one positive
        # plus k-1 negatives), the positive for query i sits at column i * k, so the
        # stride is num_targets // num_queries.
        scores = self.compute_similarity(all_qry_reps, all_tgt_reps).view(all_qry_reps.size(0), -1)
        target = torch.arange(scores.size(0), device=scores.device, dtype=torch.long)
        target = target * (all_tgt_reps.size(0) // all_qry_reps.size(0))
        loss = self.cross_entropy(scores / self.temperature, target)
if self.is_ddp:
loss = loss * self.world_size
        # Auxiliary early-exit loss (optional)
if (not compute_early_only) and self.training and self.enable_early_mlp and self.early_loss_weight > 0:
assert qry_early is not None and tgt_early is not None
if self.is_ddp:
all_qry_e = self._dist_gather_tensor(qry_early)
all_tgt_e = self._dist_gather_tensor(tgt_early)
else:
all_qry_e, all_tgt_e = qry_early, tgt_early
            scores_e = self.compute_similarity(all_qry_e, all_tgt_e).view(all_qry_e.size(0), -1)
            target_e = torch.arange(scores_e.size(0), device=scores_e.device, dtype=torch.long)
            # Same positive-index stride as the main loss: num_targets // num_queries.
            target_e = target_e * (all_tgt_e.size(0) // all_qry_e.size(0))
loss_e = self.cross_entropy(scores_e / self.temperature, target_e)
if self.is_ddp:
loss_e = loss_e * self.world_size
loss = loss + self.early_loss_weight * loss_e
return loss
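    # Example call patterns (hypothetical batches):
    #   reps = model(qry=qry_batch)["qry_reps"]     # single-sided, as GradCache does
    #   loss = model(qry=qry_batch, tgt=tgt_batch)  # two-sided, returns a scalar loss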
    def _dist_gather_tensor(self, t: Tensor):
        t = t.contiguous()
        all_tensors = [torch.empty_like(t) for _ in range(self.world_size)]
        dist.all_gather(all_tensors, t)
        # all_gather returns detached tensors; re-insert the local tensor so that
        # gradients flow through this rank's slice of the gathered batch.
        all_tensors[self.process_rank] = t
        all_tensors = torch.cat(all_tensors, dim=0)
        return all_tensors
def compute_similarity(self, q_reps, p_reps):
return torch.matmul(q_reps, p_reps.transpose(0, 1))
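    # With normalize=True the representations are unit vectors, so these dot products
    # are cosine similarities; forward() divides by `temperature`, giving the standard
    # temperature-scaled InfoNCE objective.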