kadirnar committed on
Commit 89455ef
1 Parent(s): 73f1049

Create modeling_llamavision.py

Files changed (1)
  1. modeling_llamavision.py +146 -8
modeling_llamavision.py CHANGED
@@ -1,12 +1,150 @@
- # https://huggingface.co/qresearch/llama-3.1-8B-vision-378/blob/main/configuration_llamavision.py

- from transformers import PretrainedConfig, LlamaConfig, SiglipVisionConfig


- class LlamavisionConfig(PretrainedConfig):
-     model_type = "llamavision"

-     def __init__(self, **kwargs):
-         self.text_config = LlamaConfig(**kwargs.pop("text_config", {}))
-         self.vision_config = SiglipVisionConfig(**kwargs.pop("vision_config", {}))
-         super().__init__(**kwargs)
+ # https://huggingface.co/qresearch/llama-3.1-8B-vision-378/blob/main/modeling_llamavision.py

+ import torch
+ import torch.nn as nn
+ from transformers import (
+     PreTrainedModel,
+     AutoModelForCausalLM,
+     AutoModel,
+     SiglipImageProcessor,
+ )
+ from .configuration_llamavision import LlamavisionConfig


+ class ProjectionModule(nn.Module):
+     def __init__(self, mm_hidden_size=1152, hidden_size=4096):
+         super(ProjectionModule, self).__init__()

+         # Directly set up the sequential model
+         self.model = nn.Sequential(
+             nn.Linear(mm_hidden_size, hidden_size),
+             nn.GELU(),
+             nn.Linear(hidden_size, hidden_size),
+         )
+
+     def forward(self, x):
+         return self.model(x)
+
+
+ class Llamavision(PreTrainedModel):
+     config_class = LlamavisionConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.vision_model = AutoModel.from_config(self.config.vision_config)
+         self.text_model = AutoModelForCausalLM.from_config(self.config.text_config)
+         self.processor = SiglipImageProcessor()
+         self.mm_projector = ProjectionModule(
+             mm_hidden_size=config.vision_config.hidden_size,
+             hidden_size=config.text_config.hidden_size,
+         )
+
+     @property
+     def device(self):
+         return self.text_model.device
+
+     def encode_image(self, image):
+         image = image.convert("RGB")
+         image = self.processor(
+             images=image,
+             return_tensors="pt",
+             do_resize=True,
+             size={"height": 378, "width": 378},
+         )["pixel_values"].to(
+             device=self.vision_model.device, dtype=self.vision_model.dtype
+         )
+         with torch.no_grad():
+             return self.vision_model(image, output_hidden_states=True).hidden_states[-2]
+
+     def input_embeds(self, prompt, image_embeds, tokenizer):
+         def _tokenize(txt):
+             return tokenizer(
+                 txt, return_tensors="pt", add_special_tokens=False
+             ).input_ids.to(self.device)
+
+         text_emb = self.text_model.get_input_embeddings()
+
+         embeds = []
+
+         tokenized_prompt = _tokenize(prompt)
+         if (
+             tokenizer.bos_token_id is not None
+             and tokenized_prompt[0][0] != tokenizer.bos_token_id
+         ):
+             embeds.append(
+                 text_emb(torch.tensor([[tokenizer.bos_token_id]], device=self.device))
+             )
+
+         projected_image_embeds = self.mm_projector(image_embeds.to(self.device))
+         embeds.append(projected_image_embeds)
+
+         embeds.append(text_emb(tokenized_prompt))
+
+         return torch.cat(embeds, dim=1)
+
+     def get_input_embeddings(self):
+         return self.text_model.get_input_embeddings()
+
+     def generate(
+         self,
+         image_embeds,
+         prompt,
+         tokenizer,
+         max_new_tokens=128,
+         **kwargs,
+     ):
+         generate_config = {
+             "eos_token_id": [
+                 tokenizer.eos_token_id,
+                 tokenizer.convert_tokens_to_ids("<|eot_id|>"),
+             ],
+             "bos_token_id": tokenizer.bos_token_id,
+             "pad_token_id": tokenizer.pad_token_id,
+             "max_new_tokens": max_new_tokens,
+             **kwargs,
+         }
+
+         with torch.no_grad():
+             inputs_embeds = self.input_embeds(prompt, image_embeds, tokenizer)
+
+             attention_mask = torch.ones(
+                 inputs_embeds.shape[:2],
+                 dtype=torch.long,
+                 device=inputs_embeds.device
+             )
+
+             output_ids = self.text_model.generate(
+                 inputs_embeds=inputs_embeds,
+                 attention_mask=attention_mask,
+                 **generate_config
+             )
+
+         return tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+
+     def answer_question(self, image, question, tokenizer, **kwargs):
+         image_embeds = self.encode_image(image)
+
+         chat = [
+             {
+                 "role": "system",
+                 "content": "You are a helpful AI assistant that can see images and answer questions about them.",
+             },
+             {"role": "user", "content": question},
+         ]
+         prompt = tokenizer.apply_chat_template(
+             chat, tokenize=False, add_generation_prompt=True
+         )
+
+         # Generate the answer
+         with torch.no_grad():
+             output = self.generate(
+                 image_embeds=image_embeds,
+                 prompt=prompt,
+                 tokenizer=tokenizer,
+                 **kwargs,
+             )[0]
+
+         # Clean and return the answer
+         cleaned_answer = output.strip()
+         return cleaned_answer
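
For reference, a minimal usage sketch of the class added in this commit (not part of the committed file). It assumes the repository also ships the matching configuration_llamavision.py and a config.json whose auto_map points at these classes, so that transformers can import them with trust_remote_code=True; the repo id below is the upstream qresearch model and is only a placeholder for wherever this file is actually hosted.

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder repo id (assumption); substitute the repository that carries this commit.
repo_id = "qresearch/llama-3.1-8B-vision-378"

# trust_remote_code=True lets transformers load modeling_llamavision.py from the repo.
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,
    torch_dtype=torch.float16,
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=True)

# answer_question() encodes the image with the SigLIP tower, projects the features
# into the Llama embedding space via the ProjectionModule, and generates a reply.
image = Image.open("example.jpg")
print(model.answer_question(image, "Briefly describe the image.", tokenizer, max_new_tokens=128))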