# mmE5-mllama-11b-instruct / custom_st.py

from io import BytesIO
from typing import Any, Dict, Optional, List
import torch
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration
from sentence_transformers.models import Transformer as BaseTransformer
class MultiModalTransformer(BaseTransformer):
    """Sentence Transformers module that wraps MllamaForConditionalGeneration
    so that plain text, images, and interleaved image+text inputs can all be
    encoded through the same pipeline."""

    def __init__(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        super().__init__(model_name_or_path, **kwargs)
        if tokenizer_args is None:
            tokenizer_args = {}
        tokenizer_args.pop("trust_remote_code", None)

        # Initialize the multimodal processor (tokenizer + image processor)
        self.processor = AutoProcessor.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **tokenizer_args
        )

    def _load_model(
        self,
        model_name_or_path: str,
        config,
        cache_dir: str,
        backend: str,
        is_peft_model: bool,
        **model_args,
    ) -> None:
        model_args.pop("trust_remote_code", None)
        # Load the Mllama backbone in bfloat16 to keep the memory footprint down
        self.auto_model = MllamaForConditionalGeneration.from_pretrained(
            model_name_or_path, torch_dtype=torch.bfloat16, cache_dir=cache_dir, **model_args
        )

    def forward(
        self, features: Dict[str, torch.Tensor], **kwargs
    ) -> Dict[str, torch.Tensor]:
        # Run the model and expose the last hidden state as token embeddings,
        # which downstream Sentence Transformers modules (e.g. pooling) consume
        outputs = self.auto_model(
            **features,
            return_dict=True,
            output_hidden_states=True,
            **kwargs,
        )
        features.update({"token_embeddings": outputs.hidden_states[-1]})
        return features

    def tokenize(self, texts: List[Dict[str, Any]] | List[str]) -> Dict[str, torch.Tensor]:
        def process_text_item(item):
            # Each item is either a plain string or a dict with optional
            # "image" (bytes, file path, or PIL image) and "text" keys
            if isinstance(item, str):
                return item, None

            text, img = "", None
            if "image" in item:
                text += "<|image|>"
                img = item["image"]
                if isinstance(img, bytes):
                    img = Image.open(BytesIO(img)).convert("RGB")
                elif isinstance(img, str):
                    img = Image.open(img).convert("RGB")
                elif not isinstance(img, Image.Image):
                    raise ValueError(f"Unknown image type {type(img)}")
            if "text" in item:
                if text:
                    text += "<|begin_of_text|> "
                text += item["text"].lstrip()
            return text, img

        all_texts, all_images = [], []
        for item in texts:
            text, image = process_text_item(item)
            all_texts.append(text)
            all_images.append(image)

        # Pass images to the processor only when at least one item contains one;
        # otherwise fall back to the text-only path
        if any(image is not None for image in all_images):
            inputs = self.processor(
                text=all_texts,
                images=all_images,
                padding="longest",
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors="pt",
            )
        else:
            inputs = self.processor(
                text=all_texts,
                padding="longest",
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors="pt",
            )
        return inputs
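

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It
# assumes this file is wired up as the first module of the Sentence
# Transformers repository (via modules.json) and that the repository id is
# intfloat/mmE5-mllama-11b-instruct; adjust the id if the model lives
# elsewhere. The image path below is a placeholder.
if __name__ == "__main__":
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer(
        "intfloat/mmE5-mllama-11b-instruct", trust_remote_code=True
    )

    # Items follow the format accepted by MultiModalTransformer.tokenize:
    # plain strings, or dicts with optional "image" and "text" keys.
    embeddings = model.encode(
        [
            {"image": "cat.jpg", "text": "A photo of a cat."},
            {"text": "Find images of small pets."},
        ]
    )
    print(embeddings.shape)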