| import math |
| import os |
| import re |
| import json |
| from typing import List, Optional, Dict, Tuple, Union |
|
|
| from PIL import Image |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration |
|
|
| |
| _EMPTY_SENTINELS = {"", "-1", "none", "null", "na", "n/a", "nan", "<na>"} |
|
|
| def _is_empty_cell(x) -> bool: |
| """True if x should be considered 'missing'.""" |
| if x is None: |
| return True |
| |
| try: |
| if isinstance(x, float) and math.isnan(x): |
| return True |
| except Exception: |
| pass |
| s = str(x).strip().lower() |
| return s in _EMPTY_SENTINELS |
|
|
def _clean_text_or_empty(x) -> str:
    """Return ``str(x)`` without surrounding whitespace, or ``''`` when ``x`` is a missing cell."""
    if _is_empty_cell(x):
        return ""
    return str(x).strip()
|
|
# peft is optional: it is only needed when LoRA adapters are requested.
try:
    from peft import LoraConfig, get_peft_model
    HAS_PEFT = True
except Exception:
    # Keep the names bound so a stray reference never raises NameError;
    # callers must check HAS_PEFT before using them.
    LoraConfig = None
    get_peft_model = None
    HAS_PEFT = False
|
|
|
|
| |
|
|
def l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-12) -> torch.Tensor:
    """Divide ``x`` by its L2 norm along ``dim``; ``eps`` is added to the norm to avoid division by zero."""
    magnitude = x.norm(dim=dim, keepdim=True)
    return x / (magnitude + eps)
|
|
def masked_mean_pool(hidden: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Average ``hidden`` over dim 1, weighting each position by ``mask``.

    Without a mask this is a plain mean over dim 1. The mask is cast to the
    dtype of ``hidden`` and its per-row sum is clamped to at least 1e-6, so a
    row whose mask is all zeros pools to zeros instead of dividing by zero.
    """
    if mask is None:
        return hidden.mean(dim=1)
    weights = mask.to(hidden.dtype)
    token_count = weights.sum(dim=1, keepdim=True).clamp_min(1e-6)
    weighted_sum = (hidden * weights.unsqueeze(-1)).sum(dim=1)
    return weighted_sum / token_count
|
|
def to_qwen_grid(img: Image.Image, target: int = 512, patch_size: int = 14, merge_size: int = 2) -> Image.Image:
    """Resize ``img`` to a square whose side is a multiple of ``patch_size * merge_size``.

    ``target`` is rounded DOWN to the nearest multiple of that step (28 with
    the defaults: 512 -> 504, 1024 -> 1008) and never goes below one step.
    Width and height are both set to the same value, so the aspect ratio is
    not preserved. Bilinear resampling is used.
    """
    step = patch_size * merge_size
    side = (target // step) * step
    if side < step:
        side = step
    return img.resize((side, side), Image.BILINEAR)
|
|
def _open_or_none(path: object, root: str = "") -> Optional[Image.Image]:
    """Load an image as RGB, or return ``None`` when it is missing or unreadable.

    Args:
        path: Cell value holding the image location. Values recognised by
            ``_is_empty_cell`` ('', NaN, '-1', '<NA>', ...) yield ``None``.
        root: Optional directory joined in front of ``path`` unless ``path``
            looks like a URL (``scheme://...``).

    Returns:
        An RGB ``PIL.Image``, or ``None`` if the cell is empty or the file
        cannot be opened or decoded.
    """
    if _is_empty_cell(path):
        return None
    location = str(path).strip()
    # URL-like locations are left untouched; everything else is resolved under root.
    if root and not re.match(r'^[a-zA-Z][a-zA-Z0-9+\-.]*://', location):
        location = os.path.join(root, location)
    try:
        # convert() produces an independent, decoded image, so the source file
        # handle can be released as soon as the conversion has finished.
        with Image.open(location) as src:
            return src.convert("RGB")
    except Exception:
        # Best effort by design: an unreadable image is treated as missing.
        return None
|
|
def _first_nonempty_cell(row, keys):
    """Return the first ``row.get(key)`` (in ``keys`` order) that is not an empty cell, else ``None``."""
    for key in keys:
        value = row.get(key)
        if not _is_empty_cell(value):
            return value
    return None


def build_image_map_from_row(row, root: str = "") -> dict:
    """Open the images referenced by ``row`` and key them by placeholder name.

    Mapping:
      - frontal_image <- img_path1 (also used as current_image)
      - lateral_image <- img_path2
      - prior_image   <- img_path3

    Negative images are optional. For each one, the first NON-EMPTY column
    among the candidates is used (neg_image1/neg_path1, neg_image2/neg_path2,
    neg_image3/neg_prior_image/neg_path3), so an empty preferred column no
    longer hides a populated fallback column. A negative that loads is
    published under all of its aliases; one that does not is omitted.

    Args:
        row: Any object with a dict-style ``.get(key)``.
        root: Directory prefix forwarded to ``_open_or_none``.

    Returns:
        Dict of placeholder name to ``PIL.Image`` (or ``None`` for the three
        primary keys when the image is missing or unreadable).
    """
    image_map = {
        "frontal_image": _open_or_none(row.get("img_path1"), root),
        "lateral_image": _open_or_none(row.get("img_path2"), root),
        "prior_image": _open_or_none(row.get("img_path3"), root),
    }

    # (source columns in priority order, keys the loaded image is published under)
    negative_specs = (
        (("neg_image1", "neg_path1"), ("neg_image1", "neg_path1", "neg_frontal_image")),
        (("neg_image2", "neg_path2"), ("neg_image2", "neg_path2", "neg_lateral_image")),
        (("neg_image3", "neg_prior_image", "neg_path3"), ("neg_prior_image", "neg_image3", "neg_path3")),
    )
    for source_keys, alias_keys in negative_specs:
        image = _open_or_none(_first_nonempty_cell(row, source_keys), root)
        if image is not None:
            image_map.update({alias: image for alias in alias_keys})
    return image_map
|
|
| def _s(x): return "" if x is None else str(x) |
|
|
def build_text_map_from_row(row) -> Dict[str, str]:
    """Collect the non-empty text fields of ``row``.

    Reads ``report``, ``prior_report``, ``demographics``, ``lab_test`` and
    ``indication`` via ``row.get`` and cleans each with
    ``_clean_text_or_empty``. Fields that clean to ``''`` are left out.
    """
    field_names = ("report", "prior_report", "demographics", "lab_test", "indication")
    text_map: Dict[str, str] = {}
    for field in field_names:
        cleaned = _clean_text_or_empty(row.get(field))
        if cleaned:
            text_map[field] = cleaned
    return text_map
|
|
def parse_text_placeholders(s) -> dict:
    """Turn a dict, or a string holding a JSON object, into ``{lower-cased key: cleaned text}``.

    Input that is neither a dict nor a non-blank string, that fails to parse,
    or that parses to something other than a dict yields ``{}``. Entries whose
    value cleans to ``''`` are dropped.
    """
    parsed = {}
    if isinstance(s, dict):
        parsed = s
    elif isinstance(s, str) and s.strip():
        try:
            parsed = json.loads(s)
        except Exception:
            parsed = {}
    if not isinstance(parsed, dict):
        return {}

    cleaned = {}
    for key, raw in parsed.items():
        text = _clean_text_or_empty(raw)
        if text:
            cleaned[str(key).lower()] = text
    return cleaned
|
|
|
|
| |
|
|
class LatentAttentionPooler(nn.Module):
    """Pool a token sequence by cross-attending to a bank of trainable latents.

    NV-Embed style: the (layer-normed) tokens act as queries, the
    (layer-normed) latents act as keys and values. Each layer adds the
    attention output and then an MLP output residually. The result is
    mean-pooled over tokens, optionally restricted by a mask.
    """

    def __init__(self, dim: int, num_latents: int = 512, num_layers: int = 1,
                 num_heads: int = 8, mlp_ratio: float = 2.0):
        super().__init__()
        # Latents are drawn first so the RNG consumption order is unchanged.
        self.latents = nn.Parameter(torch.randn(num_latents, dim) / math.sqrt(dim))
        self.layers = nn.ModuleList()
        self.ln_q = nn.LayerNorm(dim)
        self.ln_kv = nn.LayerNorm(dim)

        mlp_width = int(dim * mlp_ratio)
        for _ in range(num_layers):
            cross_attention = nn.MultiheadAttention(dim, num_heads, batch_first=True)
            feed_forward = nn.Sequential(
                nn.Linear(dim, mlp_width),
                nn.GELU(),
                nn.Linear(mlp_width, dim),
            )
            self.layers.append(nn.ModuleDict({"attn": cross_attention, "ffn": feed_forward}))

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Pool ``x`` (unpacked as batch, sequence, dim) to one vector per batch row."""
        batch_size, _, _ = x.shape

        tokens = self.ln_q(x)
        latent_bank = self.latents.unsqueeze(0).expand(batch_size, -1, -1).contiguous()
        memory = self.ln_kv(latent_bank)

        for block in self.layers:
            attended, _ = block["attn"](tokens, memory, memory, need_weights=False)
            tokens = tokens + attended
            tokens = tokens + block["ffn"](tokens)

        return masked_mean_pool(tokens, mask)
|
|
class Projection(nn.Module):
    """Projection head whose output is L2-normalised with ``l2norm``.

    With ``hidden=None`` the head is a single bias-free linear layer.
    Otherwise it is Linear -> GELU -> bias-free Linear with ``hidden`` units
    in between.
    """

    def __init__(self, in_dim: int, out_dim: int = 1024, hidden: Optional[int] = None):
        super().__init__()
        if hidden is not None:
            stages = [nn.Linear(in_dim, hidden), nn.GELU(), nn.Linear(hidden, out_dim, bias=False)]
        else:
            stages = [nn.Linear(in_dim, out_dim, bias=False)]
        self.proj = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        projected = self.proj(x)
        return l2norm(projected)
|
|
|
|
| |
|
|
| class LingshuEmbedder(nn.Module): |
| def __init__( |
| self, |
| model_name: str = "lingshu-medical-mllm/Lingshu-7B", |
| attn_implementation: str = "flash_attention_2", |
| torch_dtype: torch.dtype = torch.bfloat16, |
| embed_dim: int = 1024, |
| |
| |
| pool_mode: str = "latent_attention", |
| num_latents_unified: int = 512, |
| |
| |
| image_size: int = 504, |
| min_grid: int = 256, |
| max_grid: int = 1296, |
| |
| |
| |
| use_lora: bool = False, |
| lora_r: int = 64, lora_alpha: int = 64, lora_dropout: float = 0.0, |
| apply_lora_to_vision: bool = False, |
| |
| |
| bidirectional: bool = True, |
| |
| |
| max_text_tokens: int = 2560, |
| |
| |
| enable_gradient_checkpointing: bool = False, |
| |
| device: Optional[Union[str, torch.device]] = None, |
| ) -> None: |
| super().__init__() |
|
|
| |
| if device is None: |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| else: |
| device = torch.device(device) |
| if device.type != "cuda": |
| attn_implementation = "sdpa" |
| if torch_dtype in (torch.float16, torch.bfloat16): |
| torch_dtype = torch.float32 |
|
|
| |
| self.vl = Qwen2_5_VLForConditionalGeneration.from_pretrained( |
| model_name, torch_dtype=torch_dtype, attn_implementation=attn_implementation |
| ) |
| self.processor = AutoProcessor.from_pretrained( |
| model_name, |
| min_pixels=min_grid * 28 * 28, |
| max_pixels=max_grid * 28 * 28, |
| ) |
| self._propagate_attn_impl(attn_implementation) |
|
|
| |
| for p in self.vl.parameters(): |
| p.requires_grad_(False) |
| |
| |
| |
| unfrozen_modules = [] |
| for name, module in self.vl.named_modules(): |
| |
| if any(x in name.lower() for x in ['visual.merger', 'visual.proj', 'vision_proj', 'mm_projector']): |
| n_params = sum(p.numel() for p in module.parameters()) |
| for p in module.parameters(): |
| p.requires_grad_(True) |
| unfrozen_modules.append((name, n_params)) |
| |
| if unfrozen_modules: |
| print(f"[model] Unfrozen vision projector modules for memorization:") |
| for name, n_params in unfrozen_modules: |
| print(f" - {name}: {n_params:,} parameters") |
|
|
| |
| txt_hidden = getattr(self.vl.config, "text_config", None) |
| vis_hidden = getattr(self.vl.config, "vision_config", None) |
| self.text_hidden = getattr(txt_hidden, "hidden_size", None) |
| self.vision_hidden = getattr(vis_hidden, "out_hidden_size", None) or getattr(vis_hidden, "hidden_size", None) |
|
|
| |
| self.text_proj = Projection(self.text_hidden, embed_dim, hidden=None) |
| self.image_proj = Projection(self.vision_hidden, embed_dim, hidden=None) |
| self.unified_proj = Projection(self.text_hidden, embed_dim, hidden=None) |
|
|
| self.logit_scale = nn.Parameter(torch.tensor(math.log(1/0.07))) |
|
|
| |
| self.pool_mode = pool_mode |
| if self.pool_mode == "latent_attention": |
| self.unified_pooler = LatentAttentionPooler( |
| dim=self.text_hidden, |
| num_latents=num_latents_unified, |
| num_layers=1, |
| num_heads=8 |
| ) |
| else: |
| self.unified_pooler = None |
|
|
| |
| if image_size % 28 != 0: |
| raise ValueError(f"image_size must be a multiple of 28, got {image_size}") |
| self.image_size = image_size |
|
|
| |
| self.peft_active = False |
| if use_lora: |
| if not HAS_PEFT: |
| raise ImportError("peft not installed") |
| targets_text = ("q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj") |
| targets_vision = ("qkv", "proj") |
| targets = list(set(targets_text + (targets_vision if apply_lora_to_vision else tuple()))) |
| cfg = LoraConfig(r=lora_r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, |
| target_modules=targets, bias="none", task_type="CAUSAL_LM") |
| self.vl = get_peft_model(self.vl, cfg) |
| self.peft_active = True |
|
|
| |
| if bidirectional: |
| self._enable_bidirectional_attention() |
|
|
| |
| if enable_gradient_checkpointing: |
| |
| try: |
| self.vl.gradient_checkpointing_enable( |
| gradient_checkpointing_kwargs={"use_reentrant": False} |
| ) |
| except TypeError: |
| |
| self.vl.gradient_checkpointing_enable() |
| try: |
| self.vl.config.use_cache = False |
| except Exception: |
| pass |
|
|
| |
| self.to(device) |
| self.device = device |
|
|
| |
| base_dtype = next(self.vl.parameters()).dtype |
| if getattr(self, "unified_pooler", None) is not None: |
| self.unified_pooler.to(device=device, dtype=base_dtype) |
|
|
| |
| self.max_text_tokens = int(max_text_tokens) |
|
|
| |
|
|
| def _propagate_attn_impl(self, impl: str): |
| cfgs = [getattr(self.vl, "config", None)] |
| if cfgs[0] is not None: |
| for sub in ("text_config", "vision_config"): |
| cfgs.append(getattr(cfgs[0], sub, None)) |
| for cfg in cfgs: |
| if cfg is None: |
| continue |
| try: |
| cfg._attn_implementation = impl |
| cfg.attn_implementation = impl |
| if hasattr(cfg, "use_flash_attention_2"): |
| cfg.use_flash_attention_2 = (impl == "flash_attention_2") |
| except Exception: |
| pass |
| for _, module in self.vl.named_modules(): |
| if hasattr(module, "config"): |
| try: |
| module.config._attn_implementation = impl |
| module.config.attn_implementation = impl |
| if hasattr(module.config, "use_flash_attention_2"): |
| module.config.use_flash_attention_2 = (impl == "flash_attention_2") |
| except Exception: |
| pass |
|
|
| def _enable_bidirectional_attention(self): |
| """Best-effort removal of causal masking.""" |
| cfg = getattr(self.vl, "config", None) |
| if cfg is not None: |
| if hasattr(cfg, "is_decoder"): cfg.is_decoder = False |
| if hasattr(cfg, "use_cache"): cfg.use_cache = False |
| core = getattr(self.vl, "model", self.vl) |
| core_cfg = getattr(core, "config", None) |
| if core_cfg is not None: |
| if hasattr(core_cfg, "is_decoder"): core_cfg.is_decoder = False |
| if hasattr(core_cfg, "use_cache"): core_cfg.use_cache = False |
| for m in self.vl.modules(): |
| if hasattr(m, "is_causal"): |
| try: |
| m.is_causal = False |
| except Exception: |
| pass |
|
|
| def _get_text_module(self): |
| core = getattr(self.vl, "model", self.vl) |
| for attr in ("language_model", "text_model", "lm"): |
| m = getattr(core, attr, None) |
| if m is not None and hasattr(m, "forward"): |
| return m |
| for _, module in self.vl.named_modules(): |
| cname = module.__class__.__name__.lower() |
| if "vision" in cname: |
| continue |
| if hasattr(module, "forward") and hasattr(module, "embed_tokens"): |
| return module |
| raise AttributeError("Could not locate the text submodule in Qwen-VL.") |
|
|
| def _get_vision_module(self): |
| core = getattr(self.vl, "model", self.vl) |
| for attr in ("vision_model", "vision_tower", "visual", "vision"): |
| m = getattr(core, attr, None) |
| if m is not None and hasattr(m, "forward"): |
| return m |
| for _, module in self.vl.named_modules(): |
| if "vision" in module.__class__.__name__.lower(): |
| return module |
| raise AttributeError("Could not locate the vision submodule in Qwen-VL.") |
|
|
| def _get_vision_entry(self): |
| """ |
| Return the top-level VisionModel object that accepts: |
| forward(pixel_values=..., grid_thw=..., output_hidden_states=..., return_dict=True) |
| Avoid returning the inner transformer which expects (hidden_states, grid_thw). |
| """ |
| core = getattr(self.vl, "model", self.vl) |
| |
| vis = getattr(core, "vision_model", None) |
| if vis is not None: |
| return vis |
| |
| for _, m in core.named_modules(): |
| name = m.__class__.__name__.lower() |
| if name.endswith("visionmodel"): |
| return m |
| |
| return self._get_vision_module() |
|
|
| |
|
|
| def _target_from_image_size(self, image_size: Optional[int]) -> int: |
| """ |
| Return a pixel target that will be floored to a multiple of 28 by to_qwen_grid(). |
| Any multiple of 28 works (e.g., 448, 504, 1008). |
| """ |
| sz = image_size if isinstance(image_size, int) and image_size % 28 == 0 else self.image_size |
| return int(sz) |
|
|
| def _build_interleaved_content(self, text: str, imgs: List[Image.Image], append_unused_images: bool = False) -> Tuple[list, list]: |
| """ |
| NUMERIC placeholders: <image1>, <image2>, ... |
| Returns (content_list, images_in_order). |
| """ |
| if text is None: |
| text = "" |
| content: list = [] |
| ordered_images: list = [] |
| imgs = imgs or [] |
|
|
| pat = re.compile(r"<image\s*(\d+)\s*>", re.IGNORECASE) |
| pos = 0 |
| matches = list(pat.finditer(text)) |
|
|
| if not matches: |
| |
| if text.strip(): |
| content.append({"type": "text", "text": text}) |
| if append_unused_images: |
| for im in imgs: |
| content.append({"type": "image", "image": im}) |
| ordered_images.append(im) |
| return content, ordered_images |
|
|
| for m in matches: |
| s, e = m.span() |
| if s > pos: |
| seg = text[pos:s] |
| if seg.strip(): |
| content.append({"type": "text", "text": seg}) |
| idx = int(m.group(1)) - 1 |
| if 0 <= idx < len(imgs): |
| content.append({"type": "image", "image": imgs[idx]}) |
| ordered_images.append(imgs[idx]) |
| pos = e |
|
|
| if pos < len(text): |
| seg = text[pos:] |
| if seg.strip(): |
| content.append({"type": "text", "text": seg}) |
|
|
| if append_unused_images: |
| used = set(ordered_images) |
| for im in imgs: |
| if im not in used: |
| content.append({"type": "image", "image": im}) |
| ordered_images.append(im) |
|
|
| return content, ordered_images |
|
|
| def _build_content_from_template( |
| self, |
| template: str, |
| image_map: Optional[Dict[str, Image.Image]], |
| text_map: Optional[Dict[str, str]], |
| append_unused_images: bool = False, |
| ) -> Tuple[list, list]: |
| """ |
| NAMED placeholders: <frontal_image>, <lateral_image>, <prior_image>, <report>, <prior_report>, <demographics>, ... |
| Also supports alias: <current_image> -> <frontal_image>. |
| """ |
| template = template or "" |
| image_map = {k.lower(): v for k, v in (image_map or {}).items() if v is not None} |
| text_map = {k.lower(): v for k, v in (text_map or {}).items() if v is not None and str(v).strip()} |
|
|
| content: list = [] |
| images_in_order: list = [] |
|
|
| pat = re.compile(r"<\s*([A-Za-z_]\w*)\s*>") |
| pos = 0 |
| for m in pat.finditer(template): |
| s, e = m.span() |
| if s > pos: |
| seg = template[pos:s] |
| if seg.strip(): |
| content.append({"type": "text", "text": seg}) |
|
|
| name = m.group(1).lower() |
| |
| if name == "current_image": |
| name = "frontal_image" |
|
|
| if name in image_map: |
| img = image_map.get(name) |
| if img is not None: |
| content.append({"type": "image", "image": img}) |
| images_in_order.append(img) |
| else: |
| val = text_map.get(name) |
| if val is not None: |
| content.append({"type": "text", "text": str(val)}) |
|
|
| pos = e |
|
|
| if pos < len(template): |
| tail = template[pos:] |
| if tail.strip(): |
| content.append({"type": "text", "text": tail}) |
|
|
| |
| if append_unused_images: |
| for key, img in image_map.items(): |
| if img is not None and img not in images_in_order: |
| content.append({"type": "image", "image": img}) |
| images_in_order.append(img) |
|
|
| return content, images_in_order |
|
|
| def _mask_last_role_block(self, inputs: dict, hidden: torch.Tensor) -> torch.Tensor: |
| """ |
| Boolean mask (B,S) selecting tokens inside the **last** role block (user/assistant), |
| excluding the final <|im_end|>, for **any** batch size. |
| Falls back to attention_mask if special tokens are unavailable. |
| """ |
| device = hidden.device |
| ids = inputs.get("input_ids", None) |
| attn = inputs.get("attention_mask", None) |
| if ids is None: |
| return (attn if attn is not None else torch.ones(hidden.shape[:2], device=device, dtype=torch.long)).bool() |
|
|
| B, S = ids.shape |
| mask = torch.zeros((B, S), device=device, dtype=torch.bool) |
|
|
| |
| try: |
| start_id = self.processor.tokenizer.convert_tokens_to_ids("<|im_start|>") |
| except Exception: |
| start_id = None |
| try: |
| end_id = self.processor.tokenizer.convert_tokens_to_ids("<|im_end|>") |
| except Exception: |
| end_id = None |
|
|
| if end_id is None: |
| return (attn if attn is not None else torch.ones((B, S), device=device, dtype=torch.long)).bool() |
|
|
| for b in range(B): |
| |
| if attn is not None: |
| valid_len = int(attn[b].sum().item()) |
| else: |
| valid_len = S |
| valid_len = max(1, min(valid_len, S)) |
| seq = ids[b, :valid_len] |
|
|
| ends = (seq == end_id).nonzero(as_tuple=False).flatten() |
| if ends.numel() == 0: |
| |
| mask[b, :valid_len] = True |
| continue |
| last_end = int(ends[-1].item()) |
|
|
| last_start = -1 |
| if start_id is not None: |
| starts = (seq == start_id).nonzero(as_tuple=False).flatten() |
| starts_before = starts[starts < last_end] if starts.numel() > 0 else None |
| if starts_before is not None and starts_before.numel() > 0: |
| last_start = int(starts_before[-1].item()) |
| elif ends.numel() >= 2: |
| |
| last_start = int(ends[-2].item()) |
| else: |
| if ends.numel() >= 2: |
| last_start = int(ends[-2].item()) |
|
|
| left = max(last_start + 1, 0) |
| right = max(last_end - 1, left) |
| mask[b, left:right + 1] = True |
|
|
| if attn is not None: |
| mask = mask & attn.bool() |
| return mask |
|
|
| |
|
|
| @torch.no_grad() |
| def encode_text_unified(self, instructions: List[Optional[str]], texts: List[str], role: str = "user", |
| normalize: bool = True) -> torch.Tensor: |
| """Text-only, but still go through the unified VL path for consistency.""" |
| empty_images = [[] for _ in texts] |
| return self.encode_interleaved(instructions, texts, empty_images, role=role, normalize=normalize) |
|
|
| @torch.no_grad() |
| def encode_images_unified(self, instructions: List[Optional[str]], image_templates: List[str], |
| image_maps: List[Dict[str, Image.Image]], role: str = "user", |
| normalize: bool = True, image_size: Optional[int] = None) -> torch.Tensor: |
| """ |
| Image-only via unified path. Pass templates like "<frontal_image>" or "" (images only included if explicitly referenced). |
| """ |
| empty_text_maps = [{} for _ in image_templates] |
| return self.encode_interleaved_with_ph(instructions, image_templates, image_maps, empty_text_maps, |
| role=role, normalize=normalize, image_size=image_size) |
|
|
| @torch.no_grad() |
| def encode_interleaved( |
| self, |
| instructions: List[Optional[str]], |
| contents: List[str], |
| images: List[List[Image.Image]], |
| role: str = "user", |
| normalize: bool = True, |
| image_size: Optional[int] = None, |
| ) -> torch.Tensor: |
| device = self.device |
| vm = self._get_vision_module() |
| vision_dtype = next(vm.parameters()).dtype |
|
|
| assert len(instructions) == len(contents) == len(images), "length mismatch" |
| out_vecs = [] |
| target = self._target_from_image_size(image_size) |
|
|
| for inst, text, imgs in zip(instructions, contents, images): |
| proc_imgs = [to_qwen_grid(im, target=target) for im in (imgs or [])] |
| content_list, images_in_order = self._build_interleaved_content( |
| text or "", proc_imgs, append_unused_images=False |
| ) |
|
|
| msgs = [] |
| if inst and str(inst).strip(): |
| msgs.append({"role": "system", "content": [{"type": "text", "text": inst}]}) |
| msgs.append({"role": role, "content": content_list}) |
|
|
| chat_text = self.processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False) |
|
|
| proc = self.processor( |
| text=[chat_text], |
| images=images_in_order if images_in_order else None, |
| return_tensors="pt", |
| padding=True, |
| truncation=True, |
| do_resize=False, |
| max_length=self.max_text_tokens, |
| ) |
| inputs = {k: v.to(device) for k, v in proc.items()} |
| if "pixel_values" in inputs: |
| inputs["pixel_values"] = inputs["pixel_values"].to(device=device, dtype=vision_dtype) |
| if "image_grid_thw" in inputs: |
| inputs["image_grid_thw"] = inputs["image_grid_thw"].to(device) |
|
|
| out = self.vl(**inputs, output_hidden_states=True, use_cache=False) |
| hidden = out.hidden_states[-1] |
| span_mask = self._mask_last_role_block(inputs, hidden) |
|
|
| if self.pool_mode == "latent_attention": |
| pool_dtype = next(self.unified_pooler.parameters()).dtype |
| if hidden.dtype != pool_dtype: |
| hidden = hidden.to(dtype=pool_dtype) |
| vec = self.unified_pooler(hidden, span_mask).squeeze(0) |
| else: |
| vec = masked_mean_pool(hidden, span_mask).squeeze(0) |
|
|
| out_vecs.append(vec) |
|
|
| embs = torch.stack(out_vecs, dim=0) |
| proj_dtype = next(self.unified_proj.parameters()).dtype |
| emb = self.unified_proj(embs.to(dtype=proj_dtype)) |
| if normalize: |
| emb = emb / emb.norm(dim=-1, keepdim=True).clamp_min(1e-12) |
| return emb |
|
|
| @torch.no_grad() |
| def encode_interleaved_with_ph( |
| self, |
| instructions: List[Optional[str]], |
| templates: List[str], |
| image_maps: List[Optional[Dict[str, Image.Image]]], |
| text_maps: List[Optional[Dict[str, str]]], |
| role: str = "user", |
| normalize: bool = True, |
| image_size: Optional[int] = None, |
| ) -> torch.Tensor: |
| device = self.device |
| vm = self._get_vision_module() |
| vision_dtype = next(vm.parameters()).dtype |
|
|
| assert len(instructions) == len(templates) == len(image_maps) == len(text_maps), "length mismatch" |
|
|
| vecs = [] |
| target = self._target_from_image_size(image_size) |
|
|
| for inst, tmpl, imap, tmap in zip(instructions, templates, image_maps, text_maps): |
| proc_imap: Dict[str, Image.Image] = {} |
| if imap: |
| for k, im in imap.items(): |
| if im is not None: |
| proc_imap[k.lower()] = to_qwen_grid(im, target=target) |
|
|
| content_list, images_in_order = self._build_content_from_template(tmpl or "", proc_imap, (tmap or {})) |
|
|
| msgs = [] |
| if inst and str(inst).strip(): |
| msgs.append({"role": "system", "content": [{"type": "text", "text": inst}]}) |
| msgs.append({"role": role, "content": content_list}) |
|
|
| chat_text = self.processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False) |
|
|
| proc = self.processor( |
| text=[chat_text], |
| images=images_in_order if images_in_order else None, |
| return_tensors="pt", |
| padding=True, |
| truncation=True, |
| do_resize=False, |
| max_length=self.max_text_tokens, |
| ) |
| inputs = {k: v.to(device) for k, v in proc.items()} |
| if "pixel_values" in inputs: |
| inputs["pixel_values"] = inputs["pixel_values"].to(device=device, dtype=vision_dtype) |
| if "image_grid_thw" in inputs: |
| inputs["image_grid_thw"] = inputs["image_grid_thw"].to(device) |
|
|
| out = self.vl(**inputs, output_hidden_states=True, use_cache=False) |
| hidden = out.hidden_states[-1] |
| span_mask = self._mask_last_role_block(inputs, hidden) |
|
|
| if self.pool_mode == "latent_attention": |
| pool_dtype = next(self.unified_pooler.parameters()).dtype |
| if hidden.dtype != pool_dtype: |
| hidden = hidden.to(dtype=pool_dtype) |
| vec = self.unified_pooler(hidden, span_mask).squeeze(0) |
| else: |
| vec = masked_mean_pool(hidden, span_mask).squeeze(0) |
|
|
| vecs.append(vec) |
|
|
| embs = torch.stack(vecs, dim=0) |
| proj_dtype = next(self.unified_proj.parameters()).dtype |
| emb = self.unified_proj(embs.to(dtype=proj_dtype)) |
| if normalize: |
| emb = emb / emb.norm(dim=-1, keepdim=True).clamp_min(1e-12) |
| return emb |
|
|
| |
|
|
| @torch.no_grad() |
| def encode_text_dual(self, texts: List[str], normalize: bool = True) -> torch.Tensor: |
| device = self.device |
| tok = self.processor.tokenizer(text=texts, padding=True, truncation=True, return_tensors="pt", max_length=self.max_text_tokens) |
| tok = {k: v.to(device) for k, v in tok.items()} |
| lm = self._get_text_module() |
| out = lm(**tok, output_hidden_states=True, use_cache=False) |
| hidden = out.last_hidden_state |
| mask = tok.get("attention_mask") |
| pooled = masked_mean_pool(hidden, mask) |
| proj_dtype = next(self.text_proj.parameters()).dtype |
| emb = self.text_proj(pooled.to(dtype=proj_dtype)) |
| if normalize: |
| emb = emb / emb.norm(dim=-1, keepdim=True).clamp_min(1e-12) |
| return emb |
|
|
| @torch.no_grad() |
| def encode_images_dual(self, images: List[List[Image.Image]], normalize: bool = True, |
| image_size: Optional[int] = None) -> torch.Tensor: |
| device = self.device |
| flat = [img for group in images for img in group] |
| counts = [len(g) for g in images] |
| if len(flat) == 0: |
| proj_dtype = next(self.image_proj.parameters()).dtype |
| zeros = torch.zeros((len(images), self.vision_hidden), device=device, dtype=proj_dtype) |
| emb = self.image_proj(zeros) |
| if normalize: |
| emb = emb / emb.norm(dim=-1, keepdim=True).clamp_min(1e-12) |
| return emb |
| target = self._target_from_image_size(image_size) |
| processed = [to_qwen_grid(img, target=target) for img in flat] |
| proc = self.processor.image_processor(images=processed, return_tensors="pt", do_resize=False) |
| vm = self._get_vision_module() |
| vision_dtype = next(vm.parameters()).dtype |
| pixel_values = proc["pixel_values"].to(device=device, dtype=vision_dtype) |
| vis_out = vm(pixel_values=pixel_values, output_hidden_states=True) |
| feats = vis_out[0] if isinstance(vis_out, (tuple, list)) else getattr(vis_out, "last_hidden_state", None) |
| if feats is None: |
| feats = getattr(vis_out, "pooler_output", None) |
| if feats is None: |
| raise RuntimeError("Vision backbone did not return features as expected.") |
| per_img = feats.mean(dim=1) if feats.ndim == 3 else feats |
| splits = torch.split(per_img, counts, dim=0) |
| set_vecs = torch.stack([s.mean(dim=0) if s.ndim > 1 else s for s in splits], dim=0) |
| proj_dtype = next(self.image_proj.parameters()).dtype |
| emb = self.image_proj(set_vecs.to(dtype=proj_dtype)) |
| if normalize: |
| emb = emb / emb.norm(dim=-1, keepdim=True).clamp_min(1e-12) |
| return emb |
|
|
| |
|
|
| def _find_subsequence(self, haystack: list, needle: list) -> list: |
| """Return start indices where 'needle' occurs in 'haystack' (exact match).""" |
| if not haystack or not needle or len(needle) > len(haystack): |
| return [] |
| hits = [] |
| n = len(needle) |
| for i in range(len(haystack) - n + 1): |
| if haystack[i:i+n] == needle: |
| hits.append(i) |
| return hits |
|
|
| def _window_decode_matches(self, tokenizer, ids, target_lower: str) -> list: |
| """Fallback: sliding-window decode match (robust to BPE splits). Returns window (start,end) indices.""" |
| hits = [] |
| L = len(ids) |
| |
| for w in range(1, 8): |
| for i in range(0, L - w + 1): |
| s, e = i, i + w |
| text = tokenizer.decode(ids[s:e], skip_special_tokens=True).lower().replace(" ", "") |
| if target_lower in text: |
| hits.append((s, e)) |
| |
| hits = sorted(set(hits), key=lambda x: (x[1]-x[0], x[0])) |
| return hits |
|
|
| def _resize_heatmap_like(self, hm_np, target_w, target_h): |
| from PIL import Image |
| import numpy as np |
| |
| H, W = hm_np.shape |
| im = Image.fromarray((hm_np * 255.0).astype("uint8"), mode="L") |
| im = im.resize((target_w, target_h), Image.BILINEAR) |
| out = (np.array(im).astype("float32") / 255.0) |
| return out |
|
|
| def _overlay_heatmap_on_image(self, img_pil, hm_np, alpha=0.45): |
| """Return PIL with heatmap overlay; hm_np in [0,1] same size as img.""" |
| import matplotlib |
| import numpy as np |
| from PIL import Image |
|
|
| img = np.array(img_pil.convert("RGB")).astype("float32") / 255.0 |
| H, W = img.shape[:2] |
| hm = np.clip(hm_np, 0.0, 1.0) |
| if hm.shape[:2] != (H, W): |
| raise ValueError("Heatmap and image size mismatch") |
| |
| cmap = matplotlib.cm.get_cmap("jet") |
| color_hm = cmap(hm)[..., :3] |
| blended = (1.0 - alpha) * img + alpha * color_hm |
| blended = np.clip(blended, 0.0, 1.0) |
| return Image.fromarray((blended * 255).astype("uint8")) |
|
|
| def phrase_ground_and_visualize( |
| self, |
| word: str, |
| template: str, |
| row, |
| role: str = "user", |
| instruction: str = None, |
| image_size: int = None, |
| layer_for_text: int = -1, |
| save_dir: str = None, |
| return_arrays: bool = False, |
| ): |
| """ |
| Compute patch-level grounding for a word against images referenced in `template` filled by `row`. |
| Returns a PhraseGroundingOutput, and optionally writes overlay PNGs. |
| |
| Strategy: |
| - Build a single-sample chat like encode_interleaved_with_ph(). |
| - Forward Qwen-VL with hidden_states (+ attention if available). |
| - Locate word tokens inside last role block. |
| - Run vision tower once to get per-patch features per image. |
| - Project (text token avg) with text_proj, patches with image_proj; cosine sim per patch → heatmap. |
| - (Optional) also compute LM self-attn from word tokens to any image placeholders if available. |
| """ |
| import os, numpy as np, torch |
| from PIL import Image |
|
|
| device = self.device |
| tok = self.processor.tokenizer |
| target = self._target_from_image_size(image_size) |
|
|
| |
| imap = build_image_map_from_row(row, root="") |
| |
| |
| proc_imap = {k.lower(): to_qwen_grid(v, target=target) for k, v in (imap or {}).items() if v is not None} |
| tmap = build_text_map_from_row(row) |
|
|
| content_list, images_in_order = self._build_content_from_template(template or "", proc_imap, (tmap or {}), append_unused_images=False) |
|
|
| msgs = [] |
| if instruction and str(instruction).strip(): |
| msgs.append({"role": "system", "content": [{"type": "text", "text": f"INSTRUCTION:\n{instruction}"}]}) |
| msgs.append({"role": role, "content": content_list}) |
| chat_text = self.processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False) |
|
|
| vm = self._get_vision_module() |
| vision_dtype = next(vm.parameters()).dtype |
|
|
| proc = self.processor( |
| text=[chat_text], |
| images=images_in_order if images_in_order else None, |
| return_tensors="pt", |
| padding=True, |
| truncation=True, |
| do_resize=False, |
| max_length=self.max_text_tokens, |
| ) |
| inputs = {k: v.to(device) for k, v in proc.items()} |
| if "pixel_values" in inputs: |
| inputs["pixel_values"] = inputs["pixel_values"].to(device=device, dtype=vision_dtype) |
| if "image_grid_thw" in inputs: |
| inputs["image_grid_thw"] = inputs["image_grid_thw"].to(device) |
|
|
| |
| with torch.no_grad(): |
| out = self.vl(**inputs, output_hidden_states=True, output_attentions=True, use_cache=False, return_dict=True) |
|
|
| hidden = out.hidden_states[layer_for_text] |
| span_mask = self._mask_last_role_block(inputs, hidden)[0].bool() |
| seq_ids = inputs["input_ids"][0].tolist() |
|
|
| |
| |
| tgt_ids = tok(word, add_special_tokens=False)["input_ids"] |
| last_role_positions = [i for i, m in enumerate(span_mask.tolist()) if m] |
| id_seq_in_span = [seq_ids[i] for i in last_role_positions] |
| hits = self._find_subsequence(id_seq_in_span, tgt_ids) |
| token_span = None |
| if hits: |
| start_in_span = hits[0] |
| abs_start = last_role_positions[start_in_span] |
| abs_end = last_role_positions[start_in_span + len(tgt_ids) - 1] + 1 |
| token_span = (abs_start, abs_end) |
| else: |
| |
| win_hits = self._window_decode_matches(tok, id_seq_in_span, target_lower=word.lower().replace(" ", "")) |
| if win_hits: |
| s, e = win_hits[0] |
| abs_start = last_role_positions[s] |
| abs_end = last_role_positions[e - 1] + 1 |
| token_span = (abs_start, abs_end) |
|
|
| if token_span is None: |
| |
| |
| last_idx = last_role_positions[-1] |
| token_span = (last_idx, last_idx + 1) |
|
|
| s_idx, e_idx = token_span |
| word_tokens = hidden[0, s_idx:e_idx, :] |
| |
| word_vec_txt = word_tokens.mean(dim=0, keepdim=True) |
|
|
| |
| heatmaps = [] |
| per_image_debug = [] |
| if "pixel_values" in inputs: |
| |
| vmodel = self._get_vision_entry() |
| with torch.no_grad(): |
| vout = vmodel( |
| pixel_values=inputs["pixel_values"], |
| grid_thw=inputs.get("image_grid_thw", None), |
| output_hidden_states=True, |
| return_dict=True, |
| ) |
|
|
| |
| vlast = vout.last_hidden_state |
| B, Svis, C = vlast.shape |
|
|
| |
| grids = inputs.get("image_grid_thw", None) |
| if grids is not None: |
| |
| thw = grids.detach().cpu().tolist() |
| if isinstance(thw[0], (int, float)): |
| thw = [thw] |
| else: |
| thw = [[1, int(round(Svis ** 0.5)), int(round(Svis ** 0.5))] for _ in range(B)] |
|
|
| |
| per_img = [] |
| offset = 0 |
| for i in range(B): |
| t, h, w = map(int, thw[i]) |
| tokens_per = t * h * w |
| take_from = 1 if (Svis == tokens_per + 1) else 0 |
| patches = vlast[i, take_from:take_from + tokens_per, :] |
| per_img.append((patches, (t, h, w))) |
|
|
| proj_dtype_img = next(self.image_proj.parameters()).dtype |
| proj_dtype_txt = next(self.text_proj.parameters()).dtype |
|
|
| word_vec = self.text_proj(word_vec_txt.to(dtype=proj_dtype_txt)) |
| word_vec = word_vec / (word_vec.norm(dim=-1, keepdim=True) + 1e-12) |
|
|
| for (patches, (t, h, w)) in per_img: |
| patch_emb = self.image_proj(patches.to(dtype=proj_dtype_img)) |
| patch_emb = patch_emb / (patch_emb.norm(dim=-1, keepdim=True) + 1e-12) |
| sim = (patch_emb @ word_vec[0].T).squeeze(-1) |
| sim = sim.reshape(t, h, w).mean(dim=0) |
| smin, smax = float(sim.min()), float(sim.max()) |
| hm = (sim - smin) / max(1e-6, (smax - smin)) |
| heatmaps.append(hm.detach().cpu().numpy()) |
| per_image_debug.append({"tokens_per": t*h*w, "grid": (t, h, w)}) |
|
|
| |
| saved_paths = [] |
| if save_dir and heatmaps: |
| os.makedirs(save_dir, exist_ok=True) |
| for i, im in enumerate(images_in_order): |
| |
| tgt_w, tgt_h = im.size |
| hm_np = self._resize_heatmap_like(heatmaps[i], tgt_w, tgt_h) |
| overlay = self._overlay_heatmap_on_image(im, hm_np, alpha=0.45) |
| fname = os.path.join(save_dir, f"ground_{i:02d}_{word.replace(' ','_')}.png") |
| overlay.save(fname) |
| saved_paths.append(fname) |
|
|
| result = PhraseGroundingOutput( |
| token_span=(int(s_idx), int(e_idx)), |
| per_image=[{ |
| "heatmap": (heatmaps[i] if return_arrays else None), |
| "saved_path": (saved_paths[i] if i < len(saved_paths) else None), |
| "grid": per_image_debug[i].get("grid", None), |
| "tokens_per": per_image_debug[i].get("tokens_per", None), |
| "placeholder_attn": per_image_debug[i].get("placeholder_attn", None), |
| } for i in range(len(heatmaps))] |
| ) |
| return result |
|
|
|
|
class PhraseGroundingOutput:
    """Result container returned by ``phrase_ground_and_visualize``.

    Both constructor arguments are stored unchanged as attributes of the
    same name.
    """

    def __init__(self, token_span, per_image):
        # Span of the grounded word in the token sequence, as passed in.
        self.token_span = token_span
        # Per-image result entries, as passed in.
        self.per_image = per_image

    def __repr__(self) -> str:
        # Report only the entry count: entries may hold large heat-map arrays.
        try:
            n_images = len(self.per_image)
        except TypeError:
            n_images = "?"
        return f"{type(self).__name__}(token_span={self.token_span!r}, n_images={n_images})"
|
|