| from typing import Dict |
| import os |
| import torch |
| import torch.distributed as dist |
| from torch import nn, Tensor |
| import torch.nn.functional as F |
| from transformers import PreTrainedModel, AutoModelForCausalLM, AutoConfig |
| from peft import LoraConfig, get_peft_model, PeftModel |
| from src.model.processor import QWEN2_5_VL_TOKENSELECTION |
| from src.arguments_multi_layer import ModelArguments, TrainingArguments |
| from src.model.processor import LLAVA_NEXT, QWEN2_VL, PHI3V, get_backbone_name, print_master, QWEN2_5_VL, \ |
| backbone2model, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION, E5_V |
|
|
| from src.model.processor import LLAVA_NEXT, QWEN2_VL, PHI3V, get_backbone_name, print_master, QWEN2_5_VL, INTERNVIDEO2, \ |
| QWEN2_VL_TOKENSELECTION, backbone2model, GME, VLM_IMAGE_TOKENS, LamRA, LamRA_QWEN2_5, COLPALI |
| from src.model.baseline_backbone.colpali import ColPali |
| from src.model.baseline_backbone.gme.gme_inference import GmeQwen2VL |
| from src.model.baseline_backbone.lamra.lamra_inference import LamRAQwen2VL |
| from src.model.baseline_backbone.lamra.lamra_qwen25_inference import LamRAQwen25VL |
| from src.model.baseline_backbone.phi3_v.modeling_phi3_v import Phi3VForCausalLM |
| from src.model.baseline_backbone.llava_next import LlavaNextForConditionalGeneration |
|
|
| from transformers import modeling_utils |
# Compatibility shim: some transformers versions leave ALL_PARALLEL_STYLES
# unset/None, which breaks tensor-parallel-aware loading paths; patch in the
# known default so model loading does not fail.
if not hasattr(modeling_utils, "ALL_PARALLEL_STYLES") or modeling_utils.ALL_PARALLEL_STYLES is None:
    modeling_utils.ALL_PARALLEL_STYLES = ["tp", "none", "colwise", 'rowwise']
| from contextlib import contextmanager |
|
|
| class _AOPSwitch: |
| """ |
| Temporarily toggle encoder.aop_prune_config.enabled for one forward call. |
| """ |
| def __init__(self, module: nn.Module, enable: bool): |
| self.module = module |
| self.enable = bool(enable) |
| self._old = getattr(module, "aop_prune_config", None) |
|
|
| def __enter__(self): |
| |
| if self._old is None: |
| return self |
| if not self.enable: |
| |
| if isinstance(self._old, dict): |
| cfg = dict(self._old) |
| cfg["enabled"] = False |
| setattr(self.module, "aop_prune_config", cfg) |
| else: |
| setattr(self.module, "aop_prune_config", None) |
| |
| return self |
|
|
| def __exit__(self, exc_type, exc, tb): |
| |
| setattr(self.module, "aop_prune_config", self._old) |
| return False |
|
|
| class _VPoolSwitch: |
| """ |
| Temporarily toggle encoder.vision_pooling_config.enabled for one forward call. |
| """ |
| def __init__(self, module: nn.Module, enable: bool): |
| self.module = module |
| self.enable = bool(enable) |
| self._old = getattr(module, "vision_pooling_config", None) |
| def __enter__(self): |
| if self._old is None: return self |
| if not self.enable: |
| if isinstance(self._old, dict): |
| cfg = dict(self._old); cfg["enabled"] = False |
| setattr(self.module, "vision_pooling_config", cfg) |
| else: |
| setattr(self.module, "vision_pooling_config", None) |
| return self |
| def __exit__(self, exc_type, exc, tb): |
| setattr(self.module, "vision_pooling_config", self._old) |
| return False |
|
|
class MMEBModel(nn.Module):
    """Contrastive multimodal embedding model.

    Wraps a (V)LM encoder, pools hidden states from one or more supervised
    layers into embeddings (optionally L2-normalized), and trains them with
    an in-batch InfoNCE loss (see ``forward``).
    """
    # Fallback HF class used when no specialized backbone branch matches.
    TRANSFORMER_CLS = AutoModelForCausalLM


    def __init__(self,
                 encoder: PreTrainedModel,
                 pooling: str = 'last',
                 normalize: bool = False,
                 temperature: float = 0.02,
                 ):
        """
        Args:
            encoder: backbone producing hidden states (may be LoRA-wrapped).
            pooling: pooling strategy; only 'last'/'eos' are implemented.
            normalize: if True, L2-normalize pooled embeddings.
            temperature: softmax temperature for the contrastive loss.
        """
        super().__init__()
        self.config = encoder.config
        self.encoder = encoder
        self.pooling = pooling
        self.normalize = normalize
        self.temperature = temperature
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
        # When DDP is initialized, in-batch negatives are gathered across
        # ranks in forward().
        self.is_ddp = dist.is_initialized()
        if self.is_ddp:
            self.process_rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        # Default supervised layers/weights; build()/load() may overwrite
        # these (negative indices count from the last hidden state).
        self.layer_indices = [20, -1]

        self.supervise_layers = [20, -1]
        self.supervise_weights = [0.15, 0.85]
| |
| @property |
| def device(self) -> torch.device: |
| try: |
| return next(self.parameters()).device |
| except StopIteration: |
| |
| return torch.device("cpu") |
| |
| def _want_prune_for(self, side: str) -> bool: |
| """ |
| side: "qry" or "tgt" |
| """ |
| cfg = getattr(self.encoder, "aop_prune_config", None) |
| if not isinstance(cfg, dict) or not cfg.get("enabled", False): |
| return False |
| apply_to = str(cfg.get("apply_to", "both")).lower() |
| return (apply_to == "both") or (apply_to == side.lower()) |
|
|
| def _want_pool_for(self, side: str) -> bool: |
| cfg = getattr(self.encoder, "vision_pooling_config", None) |
| if not isinstance(cfg, dict) or not cfg.get("enabled", False): return False |
| apply_to = str(cfg.get("apply_to", "both")).lower() |
| return (apply_to == "both") or (apply_to == side.lower()) |
| |
| def _normalize_layers(self, hs_len: int, layers: list[int]) -> list[int]: |
| Lmax = hs_len - 1 |
| out = [] |
| for idx in layers: |
| if idx < 0: |
| idx = hs_len + idx |
| idx = max(1, min(idx, Lmax)) |
| out.append(idx) |
| if (hs_len - 1) not in out: |
| out.append(hs_len - 1) |
| return out |
|
|
| def _encode_multi(self, input): |
| """ |
| 通用多层编码:返回 [B, K, D],K=len(self.supervise_layers,经规范化且包含最后一层)。 |
| """ |
| mb = getattr(self, "model_backbone", None) |
|
|
| def norm(x): |
| return F.normalize(x, p=2, dim=-1) if self.normalize else x |
|
|
| |
| if mb not in [GME, LamRA, LamRA_QWEN2_5, INTERNVIDEO2, COLPALI]: |
| out = self.encoder( |
| **input, |
| return_dict=True, |
| output_hidden_states=True, |
| ) |
| hs_list = out.hidden_states |
| |
| post_mask = getattr(out, "attention_mask", None) |
| pre_mask = input['attention_mask'] |
|
|
| |
| idxs = self._normalize_layers(len(hs_list), list(dict.fromkeys(self.supervise_layers))) |
|
|
| |
| aop_cfg = getattr(self.encoder, "aop_prune_config", None) |
| cut_layer = None |
| if isinstance(aop_cfg, dict) and aop_cfg.get("enabled", False): |
| try: |
| cut_layer = int(aop_cfg.get("layer_idx") or 0) |
| if cut_layer <= 0: |
| cut_layer = None |
| except Exception: |
| cut_layer = None |
| vpool_cfg = getattr(self.encoder, "vision_pooling_config", None) |
| pool_layer = None |
| if isinstance(vpool_cfg, dict) and vpool_cfg.get("enabled", False): |
| try: |
| pool_layer = int(vpool_cfg.get("layer_idx") or 0) |
| if pool_layer <= 0: |
| pool_layer = None |
| except Exception: |
| pool_layer = None |
|
|
| reps = [] |
| for idx in idxs: |
| h = hs_list[idx] |
| use_post = False |
| if post_mask is not None: |
| if (cut_layer is not None and idx >= cut_layer + 1) or (pool_layer is not None and idx >= pool_layer + 1): |
| use_post = True |
|
|
| mask_this = post_mask if use_post else pre_mask |
|
|
| if mask_this is not None and h.size(1) != mask_this.size(1): |
| if pre_mask is not None and pre_mask.size(1) == h.size(1): |
| mask_this = pre_mask |
| elif post_mask is not None and post_mask.size(1) == h.size(1): |
| mask_this = post_mask |
| else: |
| mask_this = torch.ones(h.size(0), h.size(1), dtype=torch.long, device=h.device) |
|
|
| r = self._pooling(h, mask_this) |
| reps.append(F.normalize(r, p=2, dim=-1) if self.normalize else r) |
|
|
| return torch.stack(reps, dim=1) |
|
|
| |
| def encode_input(self, input, layer_indices=None): |
| if getattr(self, "model_backbone", None) == INTERNVIDEO2: |
| if "input_ids" in input.keys(): |
| |
| text_output = self.encoder.get_text_encoder()( |
| input["input_ids"], |
| attention_mask=input["attention_mask"], |
| return_dict=True, |
| mode="text", |
| ) |
| text_embeds = text_output.last_hidden_state |
| pooled_text_embeds = text_embeds[:, 0] |
| pooled_output = self.encoder.text_proj(pooled_text_embeds) |
| pooled_output /= pooled_output.norm(dim=-1, keepdim=True) |
| return pooled_output |
| else: |
| _, vfeat = self.encoder.encode_vision(input["pixel_values"], test=True) |
| vfeat = self.encoder.vision_proj(vfeat) |
| vfeat /= vfeat.norm(dim=-1, keepdim=True) |
| return vfeat |
| elif getattr(self, "model_backbone", None) in [GME, LamRA, LamRA_QWEN2_5]: |
| |
| texts = [text.replace(VLM_IMAGE_TOKENS[QWEN2_VL] + '\n', '') for text in input["texts"]] |
| images = [] |
| for imgs in input['images']: |
| |
| if isinstance(imgs, list): |
| imgs = imgs[len(imgs) // 2] |
| assert not isinstance(imgs, list) |
| images.append(imgs) |
| else: |
| images.append(imgs) |
| pooled_output = self.encoder.get_fused_embeddings(texts=texts, images=images) |
| return pooled_output |
| elif getattr(self, "model_backbone", None) == COLPALI: |
| pooled_output = self.encoder(**input, return_dict=True, output_hidden_states=True) |
| return pooled_output |
| elif getattr(self, "model_backbone", None) == LLAVA_NEXT: |
| input['pixel_values'] = input['pixel_values'].squeeze(dim=1) |
| input['image_sizes'] = input['image_sizes'].squeeze(dim=1) |
| hidden_states = self.encoder(**input, return_dict=True, output_hidden_states=True) |
| hidden_states = hidden_states.hidden_states[-1] |
| pooled_output = self._pooling(hidden_states, input['attention_mask']) |
| return pooled_output |
| else: |
| |
| out = self.encoder(**input, return_dict=True, output_hidden_states=True) |
| hs_list = out.hidden_states |
| post_mask = getattr(out, "attention_mask", None) |
| pre_mask = input['attention_mask'] |
|
|
| |
| if os.getenv("AOP_MONITOR", "0") == "1": |
| try: |
| B = pre_mask.size(0) if pre_mask is not None else hs_list[-1].size(0) |
| |
| pre_len = pre_mask.sum(dim=1).detach().cpu().tolist() if pre_mask is not None else [hs_list[-1].size(1)] * B |
| post_len = post_mask.sum(dim=1).detach().cpu().tolist() if post_mask is not None else pre_len |
|
|
| |
| aop_cfg = getattr(self.encoder, "aop_prune_config", None) |
| kr_t = aop_cfg.get("_last_sampled_keep_ratio_text") if isinstance(aop_cfg, dict) else None |
| kr_v = aop_cfg.get("_last_sampled_keep_ratio_vision") if isinstance(aop_cfg, dict) else None |
|
|
| |
| pre_txt_cnt = pre_vis_cnt = post_txt_cnt = post_vis_cnt = None |
| input_ids = input.get("input_ids", None) |
| if input_ids is not None and pre_mask is not None: |
| cfg = self.encoder.config |
| valid_pre = pre_mask.bool() |
| vis_pre = (input_ids == getattr(cfg, "image_token_id", -999)) |
| if hasattr(cfg, "video_token_id") and cfg.video_token_id is not None and cfg.video_token_id >= 0: |
| vis_pre = vis_pre | (input_ids == cfg.video_token_id) |
| special_pre = torch.zeros_like(input_ids, dtype=torch.bool) |
| for name in ["bos_token_id", "eos_token_id", "pad_token_id"]: |
| tid = getattr(cfg, name, None) |
| if tid is not None and tid >= 0: |
| special_pre |= (input_ids == tid) |
| pre_vis_cnt = (vis_pre & valid_pre).sum(dim=1).detach().cpu().tolist() |
| pre_txt_cnt = (valid_pre & (~vis_pre) & (~special_pre)).sum(dim=1).detach().cpu().tolist() |
|
|
| vis_post_mask = getattr(out, "image_token_bool_masks", None) |
| txt_post_mask = getattr(out, "text_token_bool_masks", None) |
| if vis_post_mask is not None: |
| post_vis_cnt = vis_post_mask.sum(dim=1).detach().cpu().tolist() |
| if txt_post_mask is not None: |
| post_txt_cnt = txt_post_mask.sum(dim=1).detach().cpu().tolist() |
|
|
| |
| if not hasattr(self, "_aop_mon_prints"): |
| self._aop_mon_prints = 0 |
| if self._aop_mon_prints < 3: |
| print(f"[AOP][monitor] B={B} sampled: kr_text={kr_t}, kr_vision={kr_v}") |
| for b in range(min(B, 8)): |
| preL = int(pre_len[b]); postL = int(post_len[b]); keep = (postL / (preL + 1e-9)) |
| msg = f" b={b}: pre_len={preL}, post_len={postL}, keep={keep:.3f}" |
| if pre_txt_cnt is not None and post_txt_cnt is not None: |
| kt = (post_txt_cnt[b] / (pre_txt_cnt[b] + 1e-9)) if pre_txt_cnt[b] > 0 else float('nan') |
| msg += f", txt_keep={kt:.3f}" |
| if pre_vis_cnt is not None and post_vis_cnt is not None: |
| kv = (post_vis_cnt[b] / (pre_vis_cnt[b] + 1e-9)) if pre_vis_cnt[b] > 0 else float('nan') |
| msg += f", vis_keep={kv:.3f}" |
| print(msg) |
| self._aop_mon_prints += 1 |
| except Exception as e: |
| |
| print(f"[AOP][monitor] warn: monitor failed with error: {e}") |
|
|
| def _pooling(self, last_hidden_state, attention_mask): |
| if self.pooling == 'last' or self.pooling == 'eos': |
| left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0]) |
| batch_size = last_hidden_state.shape[0] |
| if left_padding: |
| |
| reps = last_hidden_state[torch.arange(batch_size), -1, :] |
| else: |
| |
| eos_indices = attention_mask.sum(dim=1) - 1 |
| |
| reps = last_hidden_state[ |
| torch.arange(batch_size, device=last_hidden_state.device), eos_indices] |
| else: |
| raise NotImplementedError |
| if self.normalize: |
| reps = torch.nn.functional.normalize(reps, p=2, dim=-1) |
| return reps |
|
|
| @classmethod |
| def build(cls, model_args: ModelArguments, **kwargs): |
| config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True) |
| variant = getattr(config, "backbone_variant", None) |
| if variant == "layerprune": |
| model_backbone = "QWEN2_VL_LayerPrune" |
| else: |
| model_backbone = get_backbone_name(hf_config=config) |
| print_master(f'Loading backbone [{model_backbone}] from {model_args.model_name}') |
| |
| if model_backbone == PHI3V: |
| config._attn_implementation = "eager" |
| config.padding_side = "right" |
| config.use_cache = False |
| base_model = Phi3VForCausalLM.from_pretrained( |
| model_args.model_name, |
| config=config, |
| torch_dtype=torch.bfloat16, |
| low_cpu_mem_usage=True, |
| ) |
| elif model_backbone == LLAVA_NEXT: |
| config.use_cache = False |
| config.padding_side = "left" |
| base_model = LlavaNextForConditionalGeneration.from_pretrained( |
| model_args.model_name, |
| config=config, |
| torch_dtype=torch.bfloat16, |
| low_cpu_mem_usage=True, |
| ) |
| elif model_backbone in [QWEN2_VL, QWEN2_5_VL]: |
| config._attn_implementation = "flash_attention_2" |
| config.padding_side = "left" |
| config.use_cache = False |
| if model_backbone == QWEN2_5_VL: |
| |
| try: |
| from src.model.vlm_backbone.qwen2_5_vl_mldaop_pooling.modeling_qwen2_5_vl import ( |
| Qwen2_5_VLForConditionalGeneration as Qwen2_5VL_Variant |
| ) |
| base_model = Qwen2_5VL_Variant.from_pretrained( |
| model_args.model_name, |
| config=config, |
| torch_dtype=torch.bfloat16, |
| low_cpu_mem_usage=True, |
| ) |
| print_master("[Backbone] Using qwen2_5_vl_mldaop_pooling variant (default).") |
| except Exception as e: |
| print_master(f"[Backbone] mldaop_pooling import failed, fallback to vanilla. err={e}") |
| base_model = backbone2model[model_backbone].from_pretrained( |
| model_args.model_name, |
| config=config, |
| torch_dtype=torch.bfloat16, |
| low_cpu_mem_usage=True, |
| ) |
| else: |
| try: |
| from src.model.vlm_backbone.qwen2_vl_mldaop_pooling.modeling_qwen2_vl import ( |
| Qwen2VLForConditionalGeneration as Qwen2VL_Variant |
| ) |
| base_model = Qwen2VL_Variant.from_pretrained( |
| model_args.model_name, |
| config=config, |
| torch_dtype=torch.bfloat16, |
| low_cpu_mem_usage=True, |
| ) |
| print_master("[Backbone] Using qwen2_vl_mldaop_pooling variant (default).") |
| except Exception as e: |
| print_master(f"[Backbone] qwen2_vl_mldaop_pooling import failed, fallback to vanilla. err={e}") |
| base_model = backbone2model[model_backbone].from_pretrained( |
| model_args.model_name, |
| config=config, |
| torch_dtype=torch.bfloat16, |
| low_cpu_mem_usage=True, |
| ) |
| elif model_backbone in ["QWEN2_VL_LayerPrune"]: |
| config._attn_implementation = "flash_attention_2" |
| config.padding_side = "left" |
| config.use_cache = False |
| base_model = backbone2model[model_backbone].from_pretrained( |
| model_args.model_name, |
| config=config, |
| torch_dtype=torch.bfloat16, |
| low_cpu_mem_usage=True, |
| ) |
| elif model_backbone in [QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION]: |
| config._attn_implementation = "flash_attention_2" |
| config.padding_side = "left" |
| config.use_cache = False |
|
|
| from .utils import parse_layer_type |
| lm_qwen_layer = 28 |
| vis_qwen_layer = 32 |
| lm_skip_layer = parse_layer_type(model_args.lm_skip_layer, lm_qwen_layer) |
| vis_skip_layer = parse_layer_type(model_args.vis_skip_layer, vis_qwen_layer) |
|
|
| base_model = backbone2model[model_backbone].from_pretrained( |
| model_args.model_name, |
| config=config, |
| torch_dtype=torch.bfloat16, |
| low_cpu_mem_usage=True, |
| lm_skip_layer=lm_skip_layer, |
| vis_skip_layer=vis_skip_layer, |
| ) |
| else: |
| config.use_cache = False |
| base_model = cls.TRANSFORMER_CLS.from_pretrained( |
| model_args.model_name, **kwargs, config=config, |
| attn_implementation="flash_attention_2", |
| torch_dtype=torch.bfloat16, |
| trust_remote_code=True) |
|
|
| if model_args.lora: |
| print_master(f'Loading lora adapter from {base_model}') |
| lora_config = LoraConfig( |
| r=model_args.lora_r, |
| lora_alpha=model_args.lora_alpha, |
| target_modules=model_args.lora_target_modules.split(','), |
| lora_dropout=model_args.lora_dropout, |
| init_lora_weights="gaussian", |
| use_dora=True, |
| inference_mode=False |
| ) |
| lora_model = get_peft_model(base_model, lora_config) |
| model = cls( |
| encoder=lora_model, |
| pooling=model_args.pooling, |
| normalize=model_args.normalize, |
| temperature=model_args.temperature |
| ) |
| else: |
| model = cls( |
| encoder=base_model, |
| pooling=model_args.pooling, |
| normalize=model_args.normalize, |
| temperature=model_args.temperature |
| ) |
| |
| def _parse_list(val, tp=float): |
| if val is None: return None |
| if isinstance(val, (list, tuple)): return [tp(x) for x in val] |
| s = str(val).strip() |
| if s == "": return None |
| return [tp(v.strip()) for v in s.split(",") if v.strip() != ""] |
|
|
| layers = _parse_list(getattr(model_args, "supervise_layers", None), tp=int) |
| weights = _parse_list(getattr(model_args, "supervise_weights", None), tp=float) |
|
|
| if layers is None: |
| |
| layers = [getattr(model_args, 'dual_layer_idx', 20), -1] |
| if -1 not in layers: |
| layers = list(layers) + [-1] |
|
|
| if weights is None or len(weights) != len(layers): |
| |
| K = len(layers) |
| base = [1.0/(K-1)]*(K-1) if K>1 else [1.0] |
| weights = base + [max(0.0, 1.0 - sum(base))] |
|
|
| |
| s = sum(max(0.0, w) for w in weights) |
| weights = [max(0.0, w)/s for w in weights] |
|
|
| setattr(model, 'supervise_layers', layers) |
| setattr(model, 'supervise_weights', weights) |
| |
| setattr(model, 'dual_layer_idx', layers[0] if len(layers)>1 else layers[0]) |
| setattr(model, 'dual_alpha', weights[0] if len(weights)>1 else 1.0) |
| setattr(model, 'layer_indices', layers) |
| return model |
|
|
|
|
    @classmethod
    def load(cls, model_args: ModelArguments, is_trainable=True, **kwargs):
        """Load an MMEBModel (or baseline backbone wrapper) for evaluation/resume.

        ``model_args.checkpoint_path`` (when set) is used for config and LoRA
        adapter resolution; base backbone weights always come from
        ``model_args.model_name``.
        """
        # Prefer the checkpoint directory when one is given.
        model_name_or_path = model_args.checkpoint_path if model_args.checkpoint_path else model_args.model_name
        config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
        # Resolve the backbone name once and cache it on model_args.
        if not hasattr(model_args, "model_backbone") or not model_args.model_backbone:
            model_backbone = get_backbone_name(hf_config=config, model_type=model_args.model_type)
            setattr(model_args, 'model_backbone', model_backbone)
        print_master(f'Loading backbone [{model_args.model_backbone}] from {model_name_or_path}')
        if model_args.model_backbone in {LLAVA_NEXT, QWEN2_VL, QWEN2_5_VL, QWEN2_VL_TOKENSELECTION, QWEN2_5_VL_TOKENSELECTION, E5_V}:
            # NOTE(review): config is re-read from model_name here, discarding
            # the checkpoint config loaded above — confirm this is intended.
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config._attn_implementation = "flash_attention_2"
            if hasattr(config, "vision_config") and getattr(config, "vision_config") is not None:
                config.vision_config._attn_implementation = "flash_attention_2"
            if model_args.model_backbone == QWEN2_5_VL:
                # Prefer the token-pruning mldaop_pooling variant, falling
                # back to the vanilla backbone when the import fails.
                try:
                    from src.model.vlm_backbone.qwen2_5_vl_mldaop_pooling.modeling_qwen2_5_vl import (
                        Qwen2_5_VLForConditionalGeneration as Qwen2_5VL_Variant
                    )
                    base_model = Qwen2_5VL_Variant.from_pretrained(
                        model_args.model_name,
                        torch_dtype=torch.bfloat16,
                        low_cpu_mem_usage=True,
                        config=config
                    )
                    print_master("[Backbone:load] Using qwen2_5_vl_mldaop_pooling variant (default).")
                except Exception as e:
                    print_master(f"[Backbone:load] mldaop_pooling import failed, fallback to vanilla. err={e}")
                    base_model = backbone2model[model_args.model_backbone].from_pretrained(
                        model_args.model_name,
                        torch_dtype=torch.bfloat16,
                        low_cpu_mem_usage=True,
                        config=config
                    )
            elif model_args.model_backbone == QWEN2_VL:
                # Same variant-with-fallback scheme for Qwen2-VL.
                try:
                    from src.model.vlm_backbone.qwen2_vl_mldaop_pooling.modeling_qwen2_vl import (
                        Qwen2VLForConditionalGeneration as Qwen2VL_Variant
                    )
                    base_model = Qwen2VL_Variant.from_pretrained(
                        model_args.model_name,
                        torch_dtype=torch.bfloat16,
                        low_cpu_mem_usage=True,
                        config=config
                    )
                    print_master("[Backbone:load] Using qwen2_vl_mldaop_pooling variant (default).")
                except Exception as e:
                    print_master(f"[Backbone:load] qwen2_vl_mldaop_pooling import failed, fallback to vanilla. err={e}")
                    base_model = backbone2model[model_args.model_backbone].from_pretrained(
                        model_args.model_name,
                        torch_dtype=torch.bfloat16,
                        low_cpu_mem_usage=True,
                        config=config
                    )
            else:
                # Other flash-attention backbones load through the registry.
                base_model = backbone2model[model_args.model_backbone].from_pretrained(
                    model_args.model_name,
                    torch_dtype=torch.bfloat16,
                    low_cpu_mem_usage=True,
                    config=config
                )
        elif model_args.model_backbone == PHI3V:
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config.use_cache = False
            config.padding_side = "right"
            base_model = Phi3VForCausalLM.from_pretrained(model_args.model_name, **kwargs, config=config,
                                                          torch_dtype=torch.bfloat16, trust_remote_code=True)
            base_model.padding_side = "right"
        elif model_args.model_backbone == INTERNVIDEO2:
            # InternVideo2 ships with the repo at a fixed local path.
            print_master(f'Loading backbone [{model_args.model_backbone}] from {"src/model/vlm_backbone/internvideo2/"}')
            config = AutoConfig.from_pretrained("src/model/vlm_backbone/internvideo2/",
                                                trust_remote_code=True)
            base_model = backbone2model[model_args.model_backbone].from_pretrained("src/model/vlm_backbone/internvideo2/", config=config,
                                                                                  trust_remote_code=True)
        elif model_args.model_backbone == GME:
            # Baseline wrappers manage their own loading; attach the HF config
            # so MMEBModel.__init__ can read encoder.config.
            base_model = GmeQwen2VL(model_args.model_name, processor=kwargs['processor'])
            setattr(base_model, 'config', config)
        elif model_args.model_backbone == LamRA:
            base_model = LamRAQwen2VL(model_args.model_name)
            setattr(base_model, 'config', config)
        elif model_args.model_backbone == LamRA_QWEN2_5:
            base_model = LamRAQwen25VL(model_args.model_name)
            setattr(base_model, 'config', config)
        elif model_args.model_backbone == COLPALI:
            base_model = ColPali.from_pretrained(model_args.model_name)
            setattr(base_model, 'config', config)
        else:
            # Generic causal-LM fallback.
            config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
            config.use_cache = False
            base_model = cls.TRANSFORMER_CLS.from_pretrained(
                model_name_or_path, **kwargs, config=config,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True)

        if model_args.lora:
            # Attach (and optionally merge) the LoRA adapter saved at the
            # checkpoint path.
            print_master(f'Loading LoRA from {model_name_or_path}')
            lora_config = LoraConfig.from_pretrained(model_name_or_path)
            lora_model = PeftModel.from_pretrained(base_model, model_name_or_path, config=lora_config, is_trainable=is_trainable)
            lora_model.load_adapter(model_name_or_path, lora_model.active_adapter, is_trainable=is_trainable)
            if not is_trainable:
                # Merge adapters into base weights for faster inference.
                lora_model = lora_model.merge_and_unload()
            model = cls(
                encoder=lora_model,
                pooling=model_args.pooling,
                normalize=model_args.normalize,
                temperature=model_args.temperature
            )
        else:
            model = cls(
                encoder=base_model,
                pooling=model_args.pooling,
                normalize=model_args.normalize,
                temperature=model_args.temperature
            )

        model.model_backbone = model_args.model_backbone
        return model
|
|
    def save(self, output_dir: str):
        """Persist the wrapped encoder (or its LoRA adapter) to ``output_dir``."""
        self.encoder.save_pretrained(output_dir)
|
|
| def forward(self, qry: Dict[str, Tensor] = None, tgt: Dict[str, Tensor] = None, *args, **kwargs): |
| |
| if qry is not None and tgt is None: |
| with _AOPSwitch(self.encoder, self._want_prune_for("qry")): |
| with _VPoolSwitch(self.encoder, self._want_pool_for("qry")): |
| qry_reps = self._encode_multi(qry) |
| return {"qry_reps": qry_reps, "tgt_reps": None} |
| if tgt is not None and qry is None: |
| with _AOPSwitch(self.encoder, self._want_prune_for("tgt")): |
| with _VPoolSwitch(self.encoder, self._want_pool_for("tgt")): |
| tgt_reps = self._encode_multi(tgt) |
| return {"qry_reps": None, "tgt_reps": tgt_reps} |
|
|
| with _AOPSwitch(self.encoder, self._want_prune_for("qry")): |
| with _VPoolSwitch(self.encoder, self._want_pool_for("qry")): |
| q_multi = self._encode_multi(qry) |
| with _AOPSwitch(self.encoder, self._want_prune_for("tgt")): |
| with _VPoolSwitch(self.encoder, self._want_pool_for("tgt")): |
| p_multi = self._encode_multi(tgt) |
|
|
| |
| if self.is_ddp: |
| q_multi_all = self._dist_gather_tensor(q_multi) |
| p_multi_all = self._dist_gather_tensor(p_multi) |
| else: |
| q_multi_all, p_multi_all = q_multi, p_multi |
|
|
| Bglob, K, D = q_multi_all.shape |
| assert p_multi_all.shape[:2] == (Bglob, K), f"Shape mismatch: q {q_multi_all.shape}, p {p_multi_all.shape}" |
| target = torch.arange(Bglob, device=q_multi_all.device, dtype=torch.long) |
|
|
| w = torch.tensor(self.supervise_weights, dtype=torch.float32, device=q_multi_all.device) |
| w = torch.clamp(w, min=0) |
| w = w / max(w.sum().item(), 1e-8) |
|
|
| loss = 0.0 |
| for k in range(K): |
| |
| logits_k = torch.matmul(q_multi_all[:, k, :], p_multi_all[:, k, :].transpose(0, 1)) / self.temperature |
| loss_k = self.cross_entropy(logits_k, target) |
| loss = loss + w[k] * loss_k |
|
|
| if self.is_ddp: |
| loss = loss * self.world_size |
|
|
| return loss |
|
|
| def _dist_gather_tensor(self, t: Tensor): |
| t = t.contiguous() |
| all_tensors = [torch.empty_like(t) for _ in range(self.world_size)] |
| dist.all_gather(all_tensors, t) |
| all_tensors[self.process_rank] = t |
| all_tensors = torch.cat(all_tensors, dim=0) |
| return all_tensors |
|
|
| def compute_similarity(self, q_reps, p_reps): |
| return torch.matmul(q_reps, p_reps.transpose(0, 1)) |
|
|