import torch
import torch.nn as nn
from transformers import (
    PreTrainedModel,
    AutoModelForCausalLM,
    AutoModel,
    SiglipImageProcessor,
)
from .configuration_llamavision import LlamavisionConfig


class ProjectionModule(nn.Module):
    def __init__(self, mm_hidden_size=1152, hidden_size=4096):
        super().__init__()

        # Two-layer MLP that projects vision-tower features into the language
        # model's embedding space
        self.model = nn.Sequential(
            nn.Linear(mm_hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, x):
        return self.model(x)


class Llamavision(PreTrainedModel):
    config_class = LlamavisionConfig

    def __init__(self, config):
        super().__init__(config)

        # Language model, vision tower, image preprocessor, and the projector
        # that maps vision features into the text embedding space
        self.text_model = AutoModelForCausalLM.from_config(config.text_config)
        self.vision_model = AutoModel.from_config(config.vision_config)
        self.processor = SiglipImageProcessor()
        self.mm_projector = ProjectionModule()

    @property
    def device(self):
        return self.text_model.device

    def tokenizer_image_token(
        self, prompt, tokenizer, image_token_index=-200, return_tensors=None
    ):
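        """Tokenize a prompt containing "<image>" markers, replacing each marker
        with the sentinel `image_token_index` so it can later be swapped for
        projected image embeddings."""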
        prompt_chunks = [
            tokenizer(chunk).input_ids for chunk in prompt.split("<image>")
        ]

        def insert_separator(X, sep):
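            # Interleave `sep` between consecutive chunks: [c0, sep, c1, sep, ...]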
            return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

        input_ids = []
        offset = 0
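        # If the tokenizer adds a BOS token, keep a single copy up front and use
        # `offset` to strip the duplicate BOS from each subsequent chunk; the
        # separator is padded so one image token survives the slice.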
        if (
            len(prompt_chunks) > 0
            and len(prompt_chunks[0]) > 0
            and prompt_chunks[0][0] == tokenizer.bos_token_id
        ):
            offset = 1
            input_ids.append(prompt_chunks[0][0])

        for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
            input_ids.extend(x[offset:])

        if return_tensors is not None:
            if return_tensors == "pt":
                return torch.tensor(input_ids, dtype=torch.long)
            raise ValueError(f"Unsupported tensor type: {return_tensors}")
        return input_ids

    def process_tensors(self, input_ids, image_features, embedding_layer):
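        """Splice the projected image features into the token embeddings at the
        position of the image placeholder token and return the concatenated
        embeddings together with an all-ones attention mask."""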
        # Find the position of the image placeholder token (-200) in input_ids
        split_index = (input_ids == -200).nonzero(as_tuple=True)[1][0]

        # Split the input_ids at the index found, excluding -200
        input_ids_1 = input_ids[:, :split_index]
        input_ids_2 = input_ids[:, split_index + 1 :]

        # Convert input_ids to embeddings
        embeddings_1 = embedding_layer(input_ids_1)
        embeddings_2 = embedding_layer(input_ids_2)

        device = image_features.device
        token_embeddings_part1 = embeddings_1.to(device)
        token_embeddings_part2 = embeddings_2.to(device)

        # Concatenate the token embeddings and image features
        concatenated_embeddings = torch.cat(
            [token_embeddings_part1, image_features, token_embeddings_part2], dim=1
        )

        # Create the corrected attention mask
        attention_mask = torch.ones(
            concatenated_embeddings.shape[:2], dtype=torch.long, device=device
        )
        return concatenated_embeddings, attention_mask

    def answer_question(self, image, question, tokenizer, **kwargs):
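        """Answer a question about an image: build a Llama-3-style chat prompt,
        encode the image with the vision tower, project its features, splice
        them into the prompt embeddings, and generate a response. Returns the
        generated token ids (decode them with the tokenizer)."""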
        question = "<image>" + question

        prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"

        input_ids = (
            self.tokenizer_image_token(prompt, tokenizer, -200, return_tensors="pt")
            .unsqueeze(0)
            .to(self.device)
        )
        terminators = [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]
        with torch.inference_mode():
            image_inputs = self.processor(
                images=[image],
                return_tensors="pt",
                do_resize=True,
                size={"height": 384, "width": 384},
            )

            image_inputs = image_inputs["pixel_values"].to(
                device=self.device, dtype=self.dtype
            )

            image_forward_outs = self.vision_model(
                image_inputs,
                output_hidden_states=True,
            )

            # Use the vision tower's penultimate hidden state as image features
            # (a common LLaVA-style choice)
            image_features = image_forward_outs.hidden_states[-2]

            projected_embeddings = self.mm_projector(image_features).to(self.device)

            embedding_layer = self.text_model.get_input_embeddings()

            new_embeds, attn_mask = self.process_tensors(
                input_ids, projected_embeddings, embedding_layer
            )

            attn_mask = attn_mask.to(self.device)
            new_embeds = new_embeds.to(self.device)
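            # Generate directly from the fused embeddings; when only
            # inputs_embeds is provided, generate() returns just the new tokens.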
            answer = self.text_model.generate(
                inputs_embeds=new_embeds,
                attention_mask=attn_mask,
                eos_token_id=terminators,
                temperature=0.2,
                do_sample=True,
                **kwargs,
            )[0]

            return answer
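

# Example usage (a sketch): assumes the checkpoint ships a matching tokenizer
# and that "your-org/llamavision" is replaced with the real repository id.
#
#   from PIL import Image
#   from transformers import AutoTokenizer
#
#   model = Llamavision.from_pretrained("your-org/llamavision")
#   tokenizer = AutoTokenizer.from_pretrained("your-org/llamavision")
#   image = Image.open("photo.jpg")
#   output_ids = model.answer_question(
#       image, "What is shown in this picture?", tokenizer, max_new_tokens=128
#   )
#   print(tokenizer.decode(output_ids, skip_special_tokens=True))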