| | from typing import Dict |
| | import torch |
| | import torch.distributed as dist |
| | from torch import nn, Tensor |
| | from transformers import PreTrainedModel, AutoModelForCausalLM, AutoConfig |
| | from peft import LoraConfig, get_peft_model, PeftModel |
| | from src.model.processor import QWEN2_5_VL_TOKENSELECTION |
| | from src.arguments import ModelArguments, TrainingArguments |
| | from src.model.processor import LLAVA_NEXT, QWEN2_VL, PHI3V, get_backbone_name, print_master, QWEN2_5_VL, \ |
| | backbone2model, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION, E5_V |
| |
|
| | from src.arguments import ModelArguments |
| | from src.model.processor import LLAVA_NEXT, QWEN2_VL, PHI3V, get_backbone_name, print_master, QWEN2_5_VL, INTERNVIDEO2, \ |
| | QWEN2_VL_TOKENSELECTION, backbone2model, GME, VLM_IMAGE_TOKENS, LamRA, LamRA_QWEN2_5, COLPALI |
| | from src.model.baseline_backbone.colpali import ColPali |
| | from src.model.baseline_backbone.gme.gme_inference import GmeQwen2VL |
| | from src.model.baseline_backbone.lamra.lamra_inference import LamRAQwen2VL |
| | from src.model.baseline_backbone.lamra.lamra_qwen25_inference import LamRAQwen25VL |
| | from src.model.baseline_backbone.phi3_v.modeling_phi3_v import Phi3VForCausalLM |
| | from src.model.baseline_backbone.llava_next import LlavaNextForConditionalGeneration |
| |
|
| | from transformers import modeling_utils |
# Compatibility shim: make sure ``modeling_utils.ALL_PARALLEL_STYLES`` exists and is
# not None before any model is loaded. An existing non-None value is left untouched.
# NOTE(review): presumably works around transformers versions that leave this unset — confirm.
if getattr(modeling_utils, "ALL_PARALLEL_STYLES", None) is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", 'rowwise']
| |
|
| | from contextlib import contextmanager |
| |
|
class MMEBModel(nn.Module):
    """Wrap a backbone encoder, pool its hidden states into embeddings, and score
    query/target pairs with a temperature-scaled cross-entropy loss."""

    # Loader class used by ``build``/``load`` when no backbone-specific branch matches.
    TRANSFORMER_CLS = AutoModelForCausalLM
| |
|
def __init__(self,
             encoder: PreTrainedModel,
             pooling: str = 'last',
             normalize: bool = False,
             temperature: float = 0.02,
             ):
    """Wrap ``encoder`` and set up pooling, loss and distributed state.

    Args:
        encoder: Backbone model; its ``config`` is mirrored on this module.
        pooling: Pooling mode name consumed by ``_pooling``.
        normalize: Whether ``_pooling`` L2-normalises the pooled vectors.
        temperature: Divisor applied to the similarity scores in ``forward``.
    """
    super().__init__()
    self.config = encoder.config
    self.encoder = encoder
    self.pooling = pooling
    self.normalize = normalize
    self.temperature = temperature
    self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')

    # torch.distributed exposes no API besides is_available() on builds without
    # distributed support, so check availability before asking about initialisation.
    self.is_ddp = dist.is_available() and dist.is_initialized()
    # Always define rank / world size so single-process code can read them safely.
    self.process_rank = dist.get_rank() if self.is_ddp else 0
    self.world_size = dist.get_world_size() if self.is_ddp else 1

    # Number of LM layers kept per side by ``encode_input``;
    # change with ``set_inference_layers`` (``None`` keeps every layer).
    self.qry_lm_layers = 16
    self.tgt_lm_layers = 16

    # Image-token masks reported by the most recent generic-path encoder call, if any.
    self._last_image_token_bool_masks = None
| |
|
| | def set_inference_layers(self, qry_layers: int | None = None, tgt_layers: int | None = None): |
| | """设置推理阶段每侧使用的LM层数:qry用qry_layers层、cand用tgt_layers层。None代表不裁层。""" |
| | self.qry_lm_layers = qry_layers |
| | self.tgt_lm_layers = tgt_layers |
| |
|
| | def _locate_lm_layers(self, enc: nn.Module): |
| | """ |
| | 返回(decoder容器, 'layers'),用于访问 LM 的 ModuleList。 |
| | 兼容常见Qwen/LLaMA路径:model.language_model.layers / model.model.layers / model.layers / transformer.layers |
| | """ |
| | candidates = [ |
| | ("model", "language_model", "layers"), |
| | ("model", "model", "layers"), |
| | ("model", "layers"), |
| | ("language_model", "layers"), |
| | ("transformer", "layers"), |
| | ] |
| | for path in candidates: |
| | obj = enc |
| | ok = True |
| | for p in path: |
| | if hasattr(obj, p): |
| | obj = getattr(obj, p) |
| | else: |
| | ok = False |
| | break |
| | if ok and isinstance(obj, nn.ModuleList): |
| | |
| | parent = enc |
| | for p in path[:-1]: |
| | parent = getattr(parent, p) |
| | return parent, path[-1] |
| | return None, None |
| |
|
| | @contextmanager |
| | def _limit_lm_layers(self, enc: nn.Module, keep_layers: int | None): |
| | """ |
| | 临时把 enc 的 LM layers 截断为前 keep_layers 层;退出时恢复。 |
| | keep_layers 为 None/0/负数 -> 不裁层。 |
| | """ |
| | if not keep_layers or keep_layers <= 0: |
| | yield |
| | return |
| | root, attr = self._locate_lm_layers(enc) |
| | if root is None: |
| | |
| | yield |
| | return |
| | full: nn.ModuleList = getattr(root, attr) |
| | try: |
| | |
| | setattr(root, attr, nn.ModuleList(list(full[:keep_layers]))) |
| | yield |
| | finally: |
| | setattr(root, attr, full) |
| | |
| | |
| | |
def encode_input(self, input, side: str | None = None):
    """Encode one batch into pooled embeddings, dispatching on ``self.model_backbone``.

    Args:
        input: Mapping of model inputs. The keys read here depend on the backbone:
            ``input_ids``/``attention_mask`` or ``pixel_values`` (INTERNVIDEO2),
            ``texts``/``images`` (GME and the LamRA variants),
            ``pixel_values``/``image_sizes``/``attention_mask`` (LLAVA_NEXT).
            Every other backbone receives the whole mapping as keyword arguments.
            (The name shadows the builtin but is kept for keyword callers.)
        side: ``"qry"`` or ``"tgt"`` selects ``self.qry_lm_layers`` or
            ``self.tgt_lm_layers`` as the number of LM layers kept for this call;
            any other value keeps all layers.

    Returns:
        The pooled embeddings. For COLPALI the encoder output is returned as is.
    """
    if side == "qry":
        max_layers = self.qry_lm_layers
    elif side == "tgt":
        max_layers = self.tgt_lm_layers
    else:
        max_layers = None

    # The attribute may be absent (``build`` does not set it); a missing value
    # falls through to the generic branch at the bottom.
    backbone = getattr(self, "model_backbone", None)

    with self._limit_lm_layers(self.encoder, max_layers):
        if backbone == INTERNVIDEO2:
            if "input_ids" in input.keys():
                # Text batch: first-token state -> text projection -> L2 normalisation.
                text_output = self.encoder.get_text_encoder()(
                    input["input_ids"],
                    attention_mask=input["attention_mask"],
                    return_dict=True,
                    mode="text",
                )
                pooled_text_embeds = text_output.last_hidden_state[:, 0]
                pooled_output = self.encoder.text_proj(pooled_text_embeds)
                pooled_output /= pooled_output.norm(dim=-1, keepdim=True)
                return pooled_output
            # Otherwise: vision features -> vision projection -> L2 normalisation.
            _, vfeat = self.encoder.encode_vision(input["pixel_values"], test=True)
            vfeat = self.encoder.vision_proj(vfeat)
            vfeat /= vfeat.norm(dim=-1, keepdim=True)
            return vfeat

        if backbone in [GME, LamRA, LamRA_QWEN2_5]:
            # Strip the image placeholder (plus newline) from the prompts before
            # handing texts and images to the backbone's own fusion routine.
            image_placeholder = VLM_IMAGE_TOKENS[QWEN2_VL] + '\n'
            texts = [text.replace(image_placeholder, '') for text in input["texts"]]
            images = []
            for imgs in input['images']:
                if isinstance(imgs, list):
                    # A list is reduced to its middle element
                    # (presumably the middle frame of a clip — TODO confirm).
                    imgs = imgs[len(imgs) // 2]
                    assert not isinstance(imgs, list)
                images.append(imgs)
            return self.encoder.get_fused_embeddings(texts=texts, images=images)

        if backbone == COLPALI:
            return self.encoder(**input, return_dict=True, output_hidden_states=True)

        if backbone == LLAVA_NEXT:
            # Squeeze on a shallow copy so the caller's batch is not modified and
            # encoding the same batch twice stays idempotent.
            model_inputs = dict(input)
            model_inputs['pixel_values'] = input['pixel_values'].squeeze(dim=1)
            model_inputs['image_sizes'] = input['image_sizes'].squeeze(dim=1)
            outputs = self.encoder(**model_inputs, return_dict=True, output_hidden_states=True)
            return self._pooling(outputs.hidden_states[-1], input['attention_mask'])

        # Generic path: run the encoder and pool its final hidden states.
        outputs = self.encoder(**input, return_dict=True, output_hidden_states=True)

        img_masks = getattr(outputs, "image_token_bool_masks", None)
        if img_masks is None:
            try:
                img_masks = outputs.get("image_token_bool_masks", None)
            except (AttributeError, TypeError):
                # Output is not dict-like: there are simply no masks to report.
                img_masks = None
        # Remembered so ``forward`` can attach the masks to its one-sided result.
        self._last_image_token_bool_masks = img_masks

        hs = outputs.hidden_states
        last_hidden = hs[-1] if isinstance(hs, (list, tuple)) else hs

        # Prefer a mask returned by the encoder itself; fall back to the input mask.
        attn = getattr(outputs, "attention_mask", None)
        if attn is None:
            attn = input["attention_mask"]

        return self._pooling(last_hidden, attn)
| |
|
| | def _pooling(self, last_hidden_state, attention_mask): |
| | if self.pooling == 'last' or self.pooling == 'eos': |
| | left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0]) |
| | batch_size = last_hidden_state.shape[0] |
| | if left_padding: |
| | |
| | reps = last_hidden_state[torch.arange(batch_size), -1, :] |
| | else: |
| | |
| | eos_indices = attention_mask.sum(dim=1) - 1 |
| | |
| | reps = last_hidden_state[ |
| | torch.arange(batch_size, device=last_hidden_state.device), eos_indices] |
| | else: |
| | raise NotImplementedError |
| | if self.normalize: |
| | reps = torch.nn.functional.normalize(reps, p=2, dim=-1) |
| | return reps |
| |
|
@classmethod
def build(cls, model_args: ModelArguments, **kwargs):
    """Create a model from ``model_args.model_name``, optionally wrapped in a new LoRA adapter.

    The backbone is ``"QWEN2_VL_LayerPrune"`` when the config carries
    ``backbone_variant == "layerprune"``, otherwise whatever ``get_backbone_name``
    reports. Each backbone family gets its own config tweaks and loader; unknown
    backbones fall back to ``cls.TRANSFORMER_CLS``, the only branch that receives
    ``**kwargs``.

    NOTE(review): unlike ``load``, this method does not set ``model_backbone`` on
    the returned model.

    Args:
        model_args: Supplies ``model_name``, the LoRA settings and the
            ``pooling``/``normalize``/``temperature`` options.
        **kwargs: Extra arguments for the fallback ``from_pretrained`` call.

    Returns:
        A new ``cls`` instance wrapping the (optionally LoRA-wrapped) backbone.
    """
    config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
    if getattr(config, "backbone_variant", None) == "layerprune":
        model_backbone = "QWEN2_VL_LayerPrune"
    else:
        model_backbone = get_backbone_name(hf_config=config)
    print_master(f'Loading backbone [{model_backbone}] from {model_args.model_name}')

    # Loader arguments shared by every backbone-specific branch below.
    shared_load_kwargs = {
        "config": config,
        "torch_dtype": torch.bfloat16,
        "low_cpu_mem_usage": True,
    }

    if model_backbone == PHI3V:
        config._attn_implementation = "eager"
        config.padding_side = "right"
        config.use_cache = False
        base_model = Phi3VForCausalLM.from_pretrained(model_args.model_name, **shared_load_kwargs)
    elif model_backbone == LLAVA_NEXT:
        config.use_cache = False
        config.padding_side = "left"
        base_model = LlavaNextForConditionalGeneration.from_pretrained(
            model_args.model_name, **shared_load_kwargs)
    elif model_backbone in [QWEN2_VL, QWEN2_5_VL, "QWEN2_VL_LayerPrune"]:
        # The plain Qwen backbones and the layer-prune variant load identically.
        config._attn_implementation = "flash_attention_2"
        config.padding_side = "left"
        config.use_cache = False
        base_model = backbone2model[model_backbone].from_pretrained(
            model_args.model_name, **shared_load_kwargs)
    elif model_backbone in [QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION]:
        config._attn_implementation = "flash_attention_2"
        config.padding_side = "left"
        config.use_cache = False

        from .utils import parse_layer_type
        # Hard-coded layer counts handed to parse_layer_type
        # (presumably the LM / vision tower depths — TODO confirm for other model sizes).
        lm_layer_count = 28
        vis_layer_count = 32
        lm_skip_layer = parse_layer_type(model_args.lm_skip_layer, lm_layer_count)
        vis_skip_layer = parse_layer_type(model_args.vis_skip_layer, vis_layer_count)

        base_model = backbone2model[model_backbone].from_pretrained(
            model_args.model_name,
            lm_skip_layer=lm_skip_layer,
            vis_skip_layer=vis_skip_layer,
            **shared_load_kwargs,
        )
    else:
        config.use_cache = False
        base_model = cls.TRANSFORMER_CLS.from_pretrained(
            model_args.model_name, **kwargs, config=config,
            attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True)

    encoder = base_model
    if model_args.lora:
        # NOTE(review): this message interpolates the model object itself, not a path.
        print_master(f'Loading lora adapter from {base_model}')
        lora_config = LoraConfig(
            r=model_args.lora_r,
            lora_alpha=model_args.lora_alpha,
            target_modules=model_args.lora_target_modules.split(','),
            lora_dropout=model_args.lora_dropout,
            init_lora_weights="gaussian",
            use_dora=True,
            inference_mode=False
        )
        encoder = get_peft_model(base_model, lora_config)

    return cls(
        encoder=encoder,
        pooling=model_args.pooling,
        normalize=model_args.normalize,
        temperature=model_args.temperature
    )
| |
|
| |
|
@classmethod
def load(cls, model_args: ModelArguments, is_trainable=True, **kwargs):
    """Load a backbone, and optionally a saved LoRA adapter, described by ``model_args``.

    ``model_args.checkpoint_path`` is preferred over ``model_args.model_name`` when
    it is set. It is used for the initial config, for the fallback branch's weights
    and for the LoRA adapter; the named backbone branches load their weights from
    ``model_args.model_name``. If ``model_args.model_backbone`` is missing or empty
    it is filled in from ``get_backbone_name``.

    Args:
        model_args: Model options; ``model_backbone`` may be written back to it.
        is_trainable: Passed to PEFT when loading the adapter. When false, the
            adapter is merged into the base weights and unloaded.
        **kwargs: Forwarded to ``from_pretrained`` for PHI3V and the fallback
            branch; the GME branch reads ``kwargs['processor']``.

    Returns:
        A ``cls`` instance with ``model_backbone`` set to ``model_args.model_backbone``.
    """
    model_name_or_path = model_args.checkpoint_path or model_args.model_name
    config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
    if not getattr(model_args, "model_backbone", None):
        model_args.model_backbone = get_backbone_name(hf_config=config, model_type=model_args.model_type)
    print_master(f'Loading backbone [{model_args.model_backbone}] from {model_name_or_path}')
    backbone = model_args.model_backbone

    if backbone in {LLAVA_NEXT, QWEN2_VL, QWEN2_5_VL, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION, E5_V}:
        # Config is re-read from the base model name and switched to flash attention
        # for both the top-level model and its vision sub-config.
        config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
        config._attn_implementation = "flash_attention_2"
        config.vision_config._attn_implementation = "flash_attention_2"
        base_model = backbone2model[backbone].from_pretrained(
            model_args.model_name,
            config=config,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
        )
    elif backbone == PHI3V:
        config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
        config.use_cache = False
        config.padding_side = "right"
        base_model = Phi3VForCausalLM.from_pretrained(
            model_args.model_name, **kwargs, config=config,
            torch_dtype=torch.bfloat16, trust_remote_code=True)
        base_model.padding_side = "right"
    elif backbone == INTERNVIDEO2:
        # This backbone is always loaded from the fixed in-repo directory.
        internvideo2_dir = "src/model/vlm_backbone/internvideo2/"
        print_master(f'Loading backbone [{model_args.model_backbone}] from {"src/model/vlm_backbone/internvideo2/"}')
        config = AutoConfig.from_pretrained(internvideo2_dir, trust_remote_code=True)
        base_model = backbone2model[backbone].from_pretrained(
            internvideo2_dir, config=config, trust_remote_code=True)
    elif backbone == GME:
        base_model = GmeQwen2VL(model_args.model_name, processor=kwargs['processor'])
        base_model.config = config
    elif backbone == LamRA:
        base_model = LamRAQwen2VL(model_args.model_name)
        base_model.config = config
    elif backbone == LamRA_QWEN2_5:
        base_model = LamRAQwen25VL(model_args.model_name)
        base_model.config = config
    elif backbone == COLPALI:
        base_model = ColPali.from_pretrained(model_args.model_name)
        base_model.config = config
    else:
        # Fallback: config from the base model name, weights from the checkpoint path.
        config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
        config.use_cache = False
        base_model = cls.TRANSFORMER_CLS.from_pretrained(
            model_name_or_path, **kwargs, config=config,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True)

    encoder = base_model
    if model_args.lora:
        print_master(f'Loading LoRA from {model_name_or_path}')
        lora_config = LoraConfig.from_pretrained(model_name_or_path)
        encoder = PeftModel.from_pretrained(
            base_model, model_name_or_path, config=lora_config, is_trainable=is_trainable)
        encoder.load_adapter(model_name_or_path, encoder.active_adapter, is_trainable=is_trainable)
        if not is_trainable:
            encoder = encoder.merge_and_unload()

    model = cls(
        encoder=encoder,
        pooling=model_args.pooling,
        normalize=model_args.normalize,
        temperature=model_args.temperature
    )
    model.model_backbone = model_args.model_backbone
    return model
| |
|
def save(self, output_dir: str) -> None:
    """Write the wrapped encoder to ``output_dir`` through its ``save_pretrained`` method."""
    encoder = self.encoder
    encoder.save_pretrained(output_dir)
| |
|
def forward(self, qry: Dict[str, Tensor] = None, tgt: Dict[str, Tensor] = None, *args, **kwargs):
    """Encode the given sides and, when both are present, return the contrastive loss.

    Args:
        qry: Query batch passed to ``encode_input(..., side="qry")``, or ``None``.
        tgt: Target batch passed to ``encode_input(..., side="tgt")``, or ``None``.

    Returns:
        * If either side is missing: a dict with ``qry_reps`` and ``tgt_reps`` (the
          missing one is ``None``), plus ``image_token_bool_masks`` when the last
          encoder call reported such masks.
        * Otherwise: the cross-entropy loss over temperature-scaled similarity
          scores, multiplied by the world size when running distributed.

    Raises:
        ValueError: if the number of targets is not a positive multiple of the
            number of queries.
    """
    qry_reps = self.encode_input(qry, side="qry") if qry is not None else None
    tgt_reps = self.encode_input(tgt, side="tgt") if tgt is not None else None

    if qry_reps is None or tgt_reps is None:
        # Encoding-only call: hand back the embeddings instead of a loss.
        out = {"qry_reps": qry_reps, "tgt_reps": tgt_reps}
        img_masks = getattr(self, "_last_image_token_bool_masks", None)
        if img_masks is not None:
            out["image_token_bool_masks"] = img_masks
        return out

    if self.is_ddp:
        all_qry_reps = self._dist_gather_tensor(qry_reps)
        all_tgt_reps = self._dist_gather_tensor(tgt_reps)
    else:
        all_qry_reps, all_tgt_reps = qry_reps, tgt_reps

    num_qry, num_tgt = all_qry_reps.size(0), all_tgt_reps.size(0)
    if num_qry == 0 or num_tgt == 0 or num_tgt % num_qry != 0:
        raise ValueError(
            f"Number of targets ({num_tgt}) must be a positive multiple of the "
            f"number of queries ({num_qry})."
        )

    scores = self.compute_similarity(all_qry_reps, all_tgt_reps)
    scores = scores.view(num_qry, -1)

    # Each query owns ``targets_per_query`` consecutive targets, so its label is the
    # index of the first target of its group (assumes the positive comes first —
    # TODO confirm against the collator). The stride must be targets-per-query: the
    # previous ``qry // tgt`` evaluated to 0 whenever there were more targets than
    # queries, which collapsed every label to 0.
    targets_per_query = num_tgt // num_qry
    target = torch.arange(num_qry, device=scores.device, dtype=torch.long) * targets_per_query

    loss = self.cross_entropy(scores / self.temperature, target)
    if self.is_ddp:
        # NOTE(review): presumably offsets DDP's gradient averaging across ranks — confirm.
        loss = loss * self.world_size
    return loss
| |
|
def _dist_gather_tensor(self, t: Tensor):
    """Gather ``t`` from every rank and concatenate the shards along dim 0.

    One receive buffer shaped like ``t`` is allocated per rank, so the call relies
    on every rank contributing a tensor of the same shape. After the collective,
    this rank's slot is replaced by the local tensor itself.
    NOTE(review): presumably so gradients still flow through the local shard — confirm.
    """
    local = t.contiguous()
    shards = [torch.empty_like(local) for _ in range(self.world_size)]
    dist.all_gather(shards, local)
    shards[self.process_rank] = local
    return torch.cat(shards, dim=0)
| |
|
def compute_similarity(self, q_reps, p_reps):
    """Return the matrix of dot products between each row of ``q_reps`` and each row of ``p_reps``."""
    p_reps_t = p_reps.transpose(0, 1)
    return q_reps @ p_reps_t
| |
|