| import os |
| import sys |
| import importlib |
| from dataclasses import dataclass |
| from typing import Optional, Tuple, Union |
|
|
| import torch |
| from huggingface_hub import snapshot_download |
| from safetensors.torch import load_file |
| from transformers import PreTrainedModel |
| from transformers.modeling_outputs import ModelOutput |
|
|
| from .configuration_m2_encoder import M2EncoderConfig |
|
|
|
|
@dataclass
class M2EncoderOutput(ModelOutput):
    """Output container for `M2EncoderModel.forward`.

    Fields are `None` when the corresponding modality was not provided
    (e.g. `image_embeds` is `None` for a text-only call); `ModelOutput`
    drops `None` fields from tuple/dict views.
    """

    # Contrastive loss; never populated by the visible forward() — reserved
    # for training wrappers.
    loss: Optional[torch.FloatTensor] = None
    # Pooled text embeddings from VLMo's `cls_vlffn_feats`.
    text_embeds: Optional[torch.FloatTensor] = None
    # Pooled image embeddings from VLMo's `cls_vlffn_feats`.
    image_embeds: Optional[torch.FloatTensor] = None
    # Scaled image->text similarity logits (only set when both modalities given).
    logits_per_image: Optional[torch.FloatTensor] = None
    # Transpose of `logits_per_image`.
    logits_per_text: Optional[torch.FloatTensor] = None
|
|
|
|
class M2EncoderModel(PreTrainedModel):
    """Hugging Face wrapper around the M2-Encoder (VLMo) vision-language model.

    The actual network code (the ``vlmo`` package) ships inside the model
    repository; `from_pretrained` resolves the directory, and ``__init__``
    adds it to ``sys.path`` and imports it dynamically.
    """

    config_class = M2EncoderConfig
    base_model_prefix = "m2_encoder"
    main_input_name = "pixel_values"

    def __init__(self, config: M2EncoderConfig):
        """Build the underlying VLMo model from a resolved checkpoint directory.

        Args:
            config: Must carry a ``_model_dir`` attribute (set by
                `from_pretrained`) pointing at the downloaded repository.

        Raises:
            ValueError: If ``config._model_dir`` is missing, i.e. the config
                did not come from `M2EncoderModel.from_pretrained`.
        """
        super().__init__(config)
        model_dir = getattr(config, "_model_dir", None)
        if model_dir is None:
            raise ValueError(
                "M2EncoderConfig is missing `_model_dir`. Use "
                "`M2EncoderModel.from_pretrained(...)` so the checkpoint path can be resolved."
            )
        # The bundled `vlmo` package lives inside the model directory; make it
        # importable without requiring a separate install.
        if model_dir not in sys.path:
            sys.path.insert(0, model_dir)

        vlmo_default_config = importlib.import_module("vlmo.config").config
        VLMo = importlib.import_module("vlmo.modules").VLMo

        vlmo_config = vlmo_default_config()
        vlmo_config.update(config.to_vlmo_overrides(model_dir))
        load_path = vlmo_config["load_path"]
        use_safetensors = load_path.endswith(".safetensors")
        if use_safetensors:
            # VLMo's internal loader cannot read safetensors; blank the path so
            # it skips loading, then load the weights ourselves below.
            vlmo_config["load_path"] = ""

        if vlmo_config["flash_attn"]:
            patch_torch_scale_with_flash_attn = importlib.import_module(
                "vlmo.utils.patch_utils"
            ).patch_torch_scale_with_flash_attn
            patch_torch_scale_with_flash_attn()

        self.model = VLMo(vlmo_config)
        if use_safetensors:
            state_dict = load_file(load_path)
            # strict=False tolerates keys the checkpoint omits or adds.
            # NOTE(review): mismatches are silently discarded here — consider
            # logging the missing/unexpected keys load_state_dict returns.
            self.model.load_state_dict(state_dict, strict=False)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        *model_args,
        config: Optional[M2EncoderConfig] = None,
        **kwargs,
    ):
        """Resolve a local directory or Hub repo and construct the model.

        Args:
            pretrained_model_name_or_path: Local directory or Hub repo id.
            config: Optional pre-built config; loaded from the repo otherwise.
            **kwargs: ``m2_checkpoint_name`` optionally overrides the weight
                file name; remaining kwargs go to the config loader.

        Raises:
            FileNotFoundError: If the resolved checkpoint file does not exist.
        """
        # Pop our private kwarg BEFORE forwarding kwargs to the config loader;
        # otherwise `M2EncoderConfig.from_pretrained` receives the unexpected
        # `m2_checkpoint_name` key (HF would persist it as a config attribute).
        checkpoint_name = kwargs.pop("m2_checkpoint_name", None)

        model_dir = pretrained_model_name_or_path
        if not os.path.isdir(model_dir):
            model_dir = snapshot_download(repo_id=pretrained_model_name_or_path)

        if config is None:
            config = M2EncoderConfig.from_pretrained(model_dir, **kwargs)
        checkpoint_path = os.path.join(
            model_dir,
            checkpoint_name if checkpoint_name is not None else config.model_file,
        )
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(
                f"Missing M2-Encoder checkpoint: {checkpoint_path}"
            )
        config._model_dir = model_dir
        return cls(config, *model_args)

    def get_text_features(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
    ) -> torch.FloatTensor:
        """Return pooled text embeddings (VLMo's ``cls_vlffn_feats``)."""
        outputs = self.model.infer_text(
            {
                "text_ids": input_ids,
                "text_masks": attention_mask,
                "text_labels": None,
            }
        )
        return outputs["cls_vlffn_feats"]

    def get_image_features(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
        """Return pooled image embeddings (VLMo's ``cls_vlffn_feats``)."""
        # infer_image expects the batch wrapped in a single-element list.
        outputs = self.model.infer_image({"image": [pixel_values]})
        return outputs["cls_vlffn_feats"]

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = True,
        **kwargs,
    ) -> Union[M2EncoderOutput, Tuple[torch.FloatTensor, ...]]:
        """CLIP-style forward: embed either or both modalities.

        When both text and image inputs are given, also computes the scaled
        similarity logits. With ``return_dict=False`` returns a tuple of the
        non-``None`` values in (text_embeds, image_embeds, logits_per_image,
        logits_per_text) order.
        """
        text_embeds = None
        image_embeds = None

        if input_ids is not None:
            if attention_mask is None:
                # Default to attending over every token.
                attention_mask = torch.ones_like(input_ids)
            text_embeds = self.get_text_features(
                input_ids=input_ids, attention_mask=attention_mask
            )

        if pixel_values is not None:
            image_embeds = self.get_image_features(pixel_values=pixel_values)

        logits_per_image = None
        logits_per_text = None
        if image_embeds is not None and text_embeds is not None:
            # CLIP-style temperature: logit_scale is stored in log space.
            logit_scale = self.model.logit_scale.exp()
            logits_per_image = logit_scale * image_embeds @ text_embeds.t()
            logits_per_text = logits_per_image.t()

        if not return_dict:
            return tuple(
                value
                for value in (
                    text_embeds,
                    image_embeds,
                    logits_per_image,
                    logits_per_text,
                )
                if value is not None
            )

        return M2EncoderOutput(
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
        )
|
|