File size: 3,357 Bytes
4c104a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
501f1f4
4c104a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
501f1f4
4c104a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
501f1f4
4c104a7
 
 
 
 
501f1f4
4c104a7
501f1f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c104a7
 
 
 
 
501f1f4
4c104a7
501f1f4
4c104a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from io import BytesIO
from typing import Any, Dict, Optional, List
import torch
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration
from sentence_transformers.models import Transformer as BaseTransformer


class MultiModalTransformer(BaseTransformer):
    """Sentence-transformers ``Transformer`` module backed by an Mllama model.

    ``tokenize`` accepts plain strings or dicts with optional ``"image"`` and
    ``"text"`` keys and runs them through the model's ``AutoProcessor``;
    ``forward`` exposes the model's last hidden state as
    ``"token_embeddings"``.
    """

    def __init__(
            self,
            model_name_or_path: str,
            cache_dir: Optional[str] = None,
            tokenizer_args: Optional[Dict[str, Any]] = None,
            **kwargs,
    ):
        """Initialise the base module and load the multimodal processor.

        Args:
            model_name_or_path: Hub id or local path of the model/processor.
            cache_dir: Optional download cache directory, used for both the
                model (through the base class / ``_load_model``) and the
                processor.
            tokenizer_args: Extra keyword arguments for
                ``AutoProcessor.from_pretrained``. ``trust_remote_code`` is
                removed; the caller's dict is left unmodified.
            **kwargs: Forwarded unchanged to the base ``Transformer``.
        """
        # Forward cache_dir so the model loaded by the base class ends up in
        # the same cache as the processor loaded below.
        super().__init__(model_name_or_path, cache_dir=cache_dir, **kwargs)
        # Work on a copy so popping keys does not mutate the caller's dict.
        tokenizer_args = dict(tokenizer_args or {})
        tokenizer_args.pop("trust_remote_code", None)

        # The processor handles both text tokenization and image preprocessing.
        self.processor = AutoProcessor.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **tokenizer_args
        )

    def _load_model(
            self,
            model_name_or_path: str,
            config,
            cache_dir: Optional[str],
            backend: str,
            is_peft_model: bool,
            **model_args,
    ) -> None:
        """Load ``MllamaForConditionalGeneration`` in bfloat16 as ``self.auto_model``.

        Overrides the base-class loader. ``config``, ``backend`` and
        ``is_peft_model`` are accepted for signature compatibility but are not
        used here; ``trust_remote_code`` is dropped from ``model_args``.
        """
        model_args.pop("trust_remote_code", None)
        self.auto_model = MllamaForConditionalGeneration.from_pretrained(
            model_name_or_path, torch_dtype=torch.bfloat16, cache_dir=cache_dir, **model_args
        )

    def forward(
            self, features: Dict[str, torch.Tensor], **kwargs
    ) -> Dict[str, torch.Tensor]:
        """Run the model and add its final hidden state as ``"token_embeddings"``.

        Args:
            features: Processor output (as produced by ``tokenize``), passed
                to the model as keyword arguments.
            **kwargs: Extra keyword arguments forwarded to the model call.

        Returns:
            The same ``features`` dict, updated in place with
            ``"token_embeddings"``.
        """
        outputs = self.auto_model(
            **features,
            return_dict=True,
            output_hidden_states=True,
            **kwargs
        )

        # The last entry of hidden_states is the final layer's output.
        features.update({"token_embeddings": outputs.hidden_states[-1]})
        return features

    def tokenize(
            self,
            texts: List[Dict[str, Any] | str],
            padding: str | bool = "longest",
    ) -> Dict[str, torch.Tensor]:
        """Turn strings and/or image-text dicts into model inputs.

        Each item is either a plain string or a dict with an optional
        ``"image"`` (raw ``bytes``, a file path ``str`` or a
        ``PIL.Image.Image``) and an optional ``"text"``. Image items are
        prefixed with ``<|image|>``; when both are present the text follows
        ``<|begin_of_text|> ``.

        Args:
            texts: Batch of items as described above.
            padding: Padding strategy passed to the processor (default
                ``"longest"``).

        Returns:
            The processor output (PyTorch tensors), truncated to
            ``self.max_seq_length``.

        Raises:
            ValueError: If an ``"image"`` value has an unsupported type.
        """

        def load_image(img):
            # Decode bytes / file paths into RGB PIL images; PIL images pass through.
            if isinstance(img, bytes):
                with Image.open(BytesIO(img)) as opened:
                    return opened.convert("RGB")
            if isinstance(img, str):
                # Context manager closes the underlying file once converted.
                with Image.open(img) as opened:
                    return opened.convert("RGB")
            # Image.Image is the class; Image itself is a module and cannot
            # be used with isinstance().
            if not isinstance(img, Image.Image):
                raise ValueError(f"Unknown image type {type(img)}")
            return img

        def process_text_item(item):
            # Return (prompt_text, image_or_None) for one batch item.
            if isinstance(item, str):
                return item, None

            text, img = "", None
            if "image" in item:
                text += "<|image|>"
                img = load_image(item["image"])
            if "text" in item:
                if text:
                    text += "<|begin_of_text|> "
                text += item["text"].lstrip()

            return text, img

        processed = [process_text_item(item) for item in texts]
        all_texts = [text for text, _ in processed]
        all_images = [img for _, img in processed]

        processor_kwargs = dict(
            text=all_texts,
            padding=padding,
            truncation=True,
            max_length=self.max_seq_length,
            return_tensors="pt",
        )
        # Only hand images to the processor when at least one item has one.
        # NOTE(review): a batch mixing image and text-only items passes None
        # entries in `images`; confirm the Mllama processor accepts that.
        if any(img is not None for img in all_images):
            processor_kwargs["images"] = all_images

        return self.processor(**processor_kwargs)