# custom_st.py: SentenceTransformers integration for gme-Qwen2-VL-2B-Instruct
from io import BytesIO
from typing import Any, Dict, Optional, List
import torch
from PIL import Image
from sentence_transformers.models import Transformer as BaseTransformer
from transformers import AutoModelForVision2Seq, AutoProcessor


class MultiModalTransformer(BaseTransformer):
    def __init__(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        min_image_tokens: int = 256,
        max_image_tokens: int = 1280,
        max_length: int = 1800,
        **kwargs,
    ):
        super().__init__(model_name_or_path, **kwargs)
        if tokenizer_args is None:
            tokenizer_args = {}
        tokenizer_args.pop("trust_remote_code", None)

        # Initialize the processor; the image token budget is expressed in pixels
        # (each visual token covers a 28x28 pixel area).
        min_pixels = min_image_tokens * 28 * 28
        max_pixels = max_image_tokens * 28 * 28
        self.processor = AutoProcessor.from_pretrained(
            model_name_or_path, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs
        )
        self.processor.tokenizer.padding_side = 'right'
        self.sep = ' '
        self.max_length = max_length
        self.normalize = True

    def _load_model(
        self,
        model_name_or_path: str,
        config,
        cache_dir: str,
        backend: str,
        is_peft_model: bool,
        **model_args,
    ) -> None:
        model_args.pop("trust_remote_code", None)
        self.auto_model = AutoModelForVision2Seq.from_pretrained(
            model_name_or_path, torch_dtype=torch.float16, **model_args
        )

    def forward(
        self, features: Dict[str, torch.Tensor], **kwargs
    ) -> Dict[str, torch.Tensor]:
        # Embed the text tokens, then splice the vision-encoder outputs into the
        # positions occupied by image-pad tokens.
        if features.get("inputs_embeds", None) is None:
            features["inputs_embeds"] = self.auto_model.base_model.embed_tokens(features["input_ids"])
        if features.get("pixel_values", None) is not None:
            features["pixel_values"] = features["pixel_values"].type(self.auto_model.visual.get_dtype())
            image_embeds = self.auto_model.visual(
                features["pixel_values"], grid_thw=features["image_grid_thw"]
            )
            image_mask = features["input_ids"] == self.auto_model.config.image_token_id
            features["inputs_embeds"][image_mask] = image_embeds
            features.pop("pixel_values")
            features.pop("image_grid_thw")
        features.pop("input_ids")

        outputs = self.auto_model.model(
            **features,
            return_dict=True,
            output_hidden_states=True,
            # **kwargs
        )

        # Last-token pooling: with right padding, the hidden state of the last
        # non-padded token of each sequence is used as the embedding.
        pooling_mask = features["attention_mask"] if features.get("pooling_mask", None) is None else features["pooling_mask"]
        left_padding = (pooling_mask[:, -1].sum() == pooling_mask.shape[0])  # TODO
        if left_padding:
            embeddings = outputs.last_hidden_state
        else:
            sequence_lengths = pooling_mask.sum(dim=1) - 1
            embeddings = outputs.last_hidden_state[torch.arange(
                outputs.last_hidden_state.shape[0], device=outputs.last_hidden_state.device
            ), sequence_lengths]
        features.update({"token_embeddings": embeddings})
        return features

    def tokenize(self, texts: List[Dict[str, Image.Image]] | List[str]) -> Dict[str, torch.Tensor]:
        split_token = "<|im_end|>\n"

        def process_text_item(item):
            # Each item is either a plain string or a dict with optional
            # "image" (path, raw bytes, or PIL.Image) and "text" keys.
            if isinstance(item, str):
                return item, None

            text, img = "", None
            if "image" in item:
                text += "<|vision_start|><|image_pad|><|vision_end|>"
                img = item["image"]
                if isinstance(img, bytes):
                    img = Image.open(BytesIO(img)).convert("RGB")
                elif isinstance(img, str):
                    img = Image.open(img).convert("RGB")
                elif not isinstance(img, Image.Image):
                    raise ValueError(f"Unknown image type {type(img)}")
            if "text" in item:
                text += item["text"].lstrip()

            # If the text already carries an instruction turn, keep it and wrap
            # the remainder as the user turn of the chat template.
            if split_token in text:
                instruction, text = text.split(split_token, 1)
                text = f"{instruction}{split_token}<|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>"
            else:
                text = f"<|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>"
            return text, img
        all_texts, all_images = [], []
        for item in texts:
            text, images = process_text_item(item)
            all_texts.append(text)
            all_images.append(images)

        # Only pass images to the processor when at least one item contains one.
        if all_images != [None] * len(all_images):
            inputs = self.processor(
                text=all_texts,
                images=all_images,
                padding="longest",
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors="pt"
            )
        else:
            inputs = self.processor(
                text=all_texts,
                padding="longest",
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors="pt"
            )
        return inputs
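

# Illustrative usage sketch, not part of the module: it assumes the repository's
# modules.json registers MultiModalTransformer so that SentenceTransformer can
# load it with trust_remote_code. The model id and image path are placeholders.
if __name__ == "__main__":
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer(
        "Alibaba-NLP/gme-Qwen2-VL-2B-Instruct", trust_remote_code=True
    )

    # Plain strings go through the text-only branch of tokenize(); dicts with
    # "image" and/or "text" keys go through the multimodal branch.
    query_emb = model.encode(["Find a picture of a cat sitting on a sofa."])
    doc_emb = model.encode([{"image": "cat.jpg", "text": "a cat on a sofa"}])
    print(query_emb.shape, doc_emb.shape)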