from io import BytesIO
from typing import Any, Dict, List, Optional

import torch
from PIL import Image
from sentence_transformers.models import Transformer as BaseTransformer
from transformers import AutoProcessor, MllamaForConditionalGeneration

class MultiModalTransformer(BaseTransformer):
    def __init__(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        super().__init__(model_name_or_path, **kwargs)
        if tokenizer_args is None:
            tokenizer_args = {}
        tokenizer_args.pop("trust_remote_code", None)

        # Initialize the processor, which bundles the tokenizer and the
        # image processor for the Mllama checkpoint.
        self.processor = AutoProcessor.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **tokenizer_args
        )

    def _load_model(
        self,
        model_name_or_path: str,
        config,
        cache_dir: str,
        backend: str,
        is_peft_model: bool,
        **model_args,
    ) -> None:
        # Load the vision-language model in bfloat16 rather than the default
        # float32 to halve memory use.
        model_args.pop("trust_remote_code", None)
        self.auto_model = MllamaForConditionalGeneration.from_pretrained(
            model_name_or_path, torch_dtype=torch.bfloat16, cache_dir=cache_dir, **model_args
        )

    def forward(
        self, features: Dict[str, torch.Tensor], **kwargs
    ) -> Dict[str, torch.Tensor]:
        # Run the model with hidden states exposed and publish the last
        # layer's hidden states as the per-token embeddings.
        outputs = self.auto_model(
            **features,
            return_dict=True,
            output_hidden_states=True,
            **kwargs,
        )
        features.update({"token_embeddings": outputs.hidden_states[-1]})
        return features
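
    # A minimal sketch of how the "token_embeddings" produced above are
    # typically consumed (an assumption: the pooling module is configured
    # separately in the sentence-transformers pipeline, not in this file):
    #
    #   mask = features["attention_mask"].unsqueeze(-1).float()
    #   sentence_emb = (features["token_embeddings"] * mask).sum(1) / mask.sum(1)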

    def tokenize(self, texts: List[Dict] | List[str]) -> Dict[str, torch.Tensor]:
        def process_text_item(item):
            # Plain strings are text-only inputs with no image.
            if isinstance(item, str):
                return item, None
            text, img = "", None
            if "image" in item:
                # Prepend the Mllama image token and decode the image from
                # raw bytes, a file path, or an existing PIL image.
                text += "<|image|>"
                img = item["image"]
                if isinstance(img, bytes):
                    img = Image.open(BytesIO(img)).convert("RGB")
                elif isinstance(img, str):
                    img = Image.open(img).convert("RGB")
                elif not isinstance(img, Image.Image):
                    raise ValueError(f"Unknown image type {type(img)}")
            if "text" in item:
                if text:
                    text += "<|begin_of_text|> "
                text += item["text"].lstrip()
            return text, img

        all_texts, all_images = [], []
        for item in texts:
            text, image = process_text_item(item)
            all_texts.append(text)
            all_images.append(image)

        # Only pass the images argument when at least one item actually
        # contains an image; otherwise tokenize as a text-only batch.
        if all_images != [None] * len(all_images):
            inputs = self.processor(
                text=all_texts,
                images=all_images,
                padding="longest",
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors="pt",
            )
        else:
            inputs = self.processor(
                text=all_texts,
                padding="longest",
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors="pt",
            )
        return inputs
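
# Usage sketch (an illustration, not part of the original module). It assumes
# a compatible Mllama checkpoint path and a mean-pooling setup; adjust both to
# match the actual model card.
#
#   from sentence_transformers import SentenceTransformer
#   from sentence_transformers.models import Pooling
#
#   transformer = MultiModalTransformer("path/to/mllama-checkpoint")
#   pooling = Pooling(transformer.get_word_embedding_dimension(), pooling_mode="mean")
#   model = SentenceTransformer(modules=[transformer, pooling])
#
#   embeddings = model.encode([
#       "a plain text query",
#       {"image": "photo.jpg", "text": "caption for the photo"},
#   ])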