from io import BytesIO
from typing import Any, Dict, List, Optional

import torch
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration

from sentence_transformers.models import Transformer as BaseTransformer


class MultiModalTransformer(BaseTransformer):
    """A sentence-transformers Transformer module that wraps a Mllama
    vision-language model so text, images, or text-image pairs can be
    embedded through the standard sentence-transformers pipeline."""

    def __init__(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        super().__init__(model_name_or_path, **kwargs)
        if tokenizer_args is None:
            tokenizer_args = {}
        # Drop trust_remote_code if present; it is not forwarded to the processor.
        tokenizer_args.pop("trust_remote_code", None)

        # The processor handles both tokenization and image preprocessing.
        self.processor = AutoProcessor.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **tokenizer_args
        )

    def _load_model(
        self,
        model_name_or_path: str,
        config,
        cache_dir: str,
        backend: str,
        is_peft_model: bool,
        **model_args,
    ) -> None:
        # Override the base class loader so the vision-language model is
        # loaded instead of a plain AutoModel.
        model_args.pop("trust_remote_code", None)
        self.auto_model = MllamaForConditionalGeneration.from_pretrained(
            model_name_or_path, torch_dtype=torch.bfloat16, cache_dir=cache_dir, **model_args
        )

    def forward(
        self, features: Dict[str, torch.Tensor], **kwargs
    ) -> Dict[str, torch.Tensor]:
        outputs = self.auto_model(
            **features,
            return_dict=True,
            output_hidden_states=True,
            **kwargs
        )
        # Expose the last hidden state as "token_embeddings" so downstream
        # sentence-transformers modules (e.g. pooling) can consume it.
        features.update({"token_embeddings": outputs.hidden_states[-1]})
        return features

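    # Shape note (assumed): for a batch of B samples padded to T tokens,
    # "token_embeddings" has shape [B, T, hidden_size]; a Pooling module can
    # then reduce it to [B, hidden_size] using the attention_mask.
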
    def tokenize(self, texts: List[Dict] | List[str]) -> Dict[str, torch.Tensor]:
        def process_text_item(item):
            # Plain strings are passed through as text-only samples.
            if isinstance(item, str):
                return item, None

            text, img = "", None
            if "image" in item:
                # Prepend the Mllama image token so the processor pairs the
                # image with this sample's text.
                text += "<|image|>"
                img = item["image"]
                if isinstance(img, bytes):
                    img = Image.open(BytesIO(img)).convert("RGB")
                elif isinstance(img, str):
                    img = Image.open(img).convert("RGB")
                elif not isinstance(img, Image.Image):
                    raise ValueError(f"Unknown image type {type(img)}")
            if "text" in item:
                if text:
                    text += "<|begin_of_text|> "
                text += item["text"].lstrip()

            return text, img

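        # Example of the mapping above (derived from process_text_item):
        #   {"text": "a cat", "image": <bytes | path | PIL.Image.Image>}
        #   -> ("<|image|><|begin_of_text|> a cat", <PIL.Image.Image>)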
        all_texts, all_images = [], []
        for item in texts:
            text, image = process_text_item(item)
            all_texts.append(text)
            all_images.append(image)

        # Only hand images to the processor when at least one sample has one;
        # text-only samples keep a None placeholder in all_images.
        if any(img is not None for img in all_images):
            inputs = self.processor(
                text=all_texts,
                images=all_images,
                padding="longest",
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors="pt",
            )
        else:
            inputs = self.processor(
                text=all_texts,
                padding="longest",
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors="pt",
            )

        return inputs
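
# Usage sketch (illustrative, not part of the module above). The checkpoint
# name, the choice of mean pooling, and the hidden-size lookup via
# config.text_config are assumptions and may need adjusting for a given
# transformers version.
if __name__ == "__main__":
    from sentence_transformers import SentenceTransformer
    from sentence_transformers.models import Pooling

    transformer = MultiModalTransformer("meta-llama/Llama-3.2-11B-Vision-Instruct")
    pooling = Pooling(
        word_embedding_dimension=transformer.auto_model.config.text_config.hidden_size,
        pooling_mode="mean",
    )
    model = SentenceTransformer(modules=[transformer, pooling])

    # tokenize() accepts plain strings or dicts with "text" and/or "image";
    # "image" may be a file path, raw bytes, or a PIL.Image.Image.
    embeddings = model.encode(
        [
            "a photo of a cat",
            {"text": "a photo of a dog", "image": "dog.jpg"},
        ]
    )
    print(embeddings.shape)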