gxy committed on
Commit
52b6907
1 Parent(s): 3c36c35

FIRST: add model weight

Files changed (5)
  1. config.json +250 -0
  2. demo.py +38 -0
  3. modeling_ziya_blip2.py +287 -0
  4. pytorch_model.bin +3 -0
  5. wzry.jpg +0 -0
config.json ADDED
@@ -0,0 +1,250 @@
+ {
+   "architectures": [
+     "ZiyaBLIP2ForConditionalGeneration"
+   ],
+   "assistant_name": "<bot>",
+   "human_name": "<human>",
+   "initializer_factor": 1.0,
+   "initializer_range": 0.02,
+   "model_type": "blip-2",
+   "num_query_tokens": 32,
+   "prompt_prefix": "",
+   "qformer_config": {
+     "_name_or_path": "",
+     "add_cross_attention": false,
+     "architectures": null,
+     "attention_probs_dropout_prob": 0.1,
+     "bad_words_ids": null,
+     "begin_suppress_tokens": null,
+     "bos_token_id": null,
+     "chunk_size_feed_forward": 0,
+     "classifier_dropout": null,
+     "cross_attention_frequency": 2,
+     "cross_attention_hidden_size": null,
+     "decoder_start_token_id": null,
+     "diversity_penalty": 0.0,
+     "do_sample": false,
+     "early_stopping": false,
+     "encoder_hidden_size": 1408,
+     "encoder_no_repeat_ngram_size": 0,
+     "eos_token_id": null,
+     "exponential_decay_length_penalty": null,
+     "finetuning_task": null,
+     "forced_bos_token_id": null,
+     "forced_eos_token_id": null,
+     "hidden_act": "gelu",
+     "hidden_dropout_prob": 0.1,
+     "hidden_size": 768,
+     "id2label": {
+       "0": "LABEL_0",
+       "1": "LABEL_1"
+     },
+     "initializer_range": 0.02,
+     "intermediate_size": 3072,
+     "is_decoder": false,
+     "is_encoder_decoder": false,
+     "label2id": {
+       "LABEL_0": 0,
+       "LABEL_1": 1
+     },
+     "layer_norm_eps": 1e-12,
+     "length_penalty": 1.0,
+     "max_length": 20,
+     "max_position_embeddings": 512,
+     "min_length": 0,
+     "model_type": "blip_2_qformer",
+     "no_repeat_ngram_size": 0,
+     "num_attention_heads": 12,
+     "num_beam_groups": 1,
+     "num_beams": 1,
+     "num_hidden_layers": 12,
+     "num_return_sequences": 1,
+     "output_attentions": false,
+     "output_hidden_states": false,
+     "output_scores": false,
+     "pad_token_id": 0,
+     "position_embedding_type": "absolute",
+     "prefix": null,
+     "problem_type": null,
+     "pruned_heads": {},
+     "remove_invalid_values": false,
+     "repetition_penalty": 1.0,
+     "return_dict": true,
+     "return_dict_in_generate": false,
+     "sep_token_id": null,
+     "suppress_tokens": null,
+     "task_specific_params": null,
+     "temperature": 1.0,
+     "tf_legacy_loss": false,
+     "tie_encoder_decoder": false,
+     "tie_word_embeddings": true,
+     "tokenizer_class": null,
+     "top_k": 50,
+     "top_p": 1.0,
+     "torch_dtype": null,
+     "torchscript": false,
+     "transformers_version": "4.29.0.dev0",
+     "typical_p": 1.0,
+     "use_bfloat16": false,
+     "vocab_size": 30522
+   },
+   "text_config": {
+     "_name_or_path": "",
+     "add_cross_attention": false,
+     "architectures": [
+       "LlamaForCausalLM"
+     ],
+     "bad_words_ids": null,
+     "begin_suppress_tokens": null,
+     "bos_token_id": 1,
+     "chunk_size_feed_forward": 0,
+     "cross_attention_hidden_size": null,
+     "decoder_start_token_id": null,
+     "diversity_penalty": 0.0,
+     "do_sample": false,
+     "early_stopping": false,
+     "encoder_no_repeat_ngram_size": 0,
+     "eos_token_id": 2,
+     "exponential_decay_length_penalty": null,
+     "finetuning_task": null,
+     "forced_bos_token_id": null,
+     "forced_eos_token_id": null,
+     "hidden_act": "silu",
+     "hidden_size": 5120,
+     "id2label": {
+       "0": "LABEL_0",
+       "1": "LABEL_1"
+     },
+     "initializer_range": 0.02,
+     "intermediate_size": 13824,
+     "is_decoder": false,
+     "is_encoder_decoder": false,
+     "label2id": {
+       "LABEL_0": 0,
+       "LABEL_1": 1
+     },
+     "length_penalty": 1.0,
+     "max_length": 20,
+     "max_position_embeddings": 2048,
+     "min_length": 0,
+     "model_type": "llama",
+     "no_repeat_ngram_size": 0,
+     "num_attention_heads": 40,
+     "num_beam_groups": 1,
+     "num_beams": 1,
+     "num_hidden_layers": 40,
+     "num_return_sequences": 1,
+     "output_attentions": false,
+     "output_hidden_states": false,
+     "output_scores": false,
+     "pad_token_id": 0,
+     "prefix": null,
+     "problem_type": null,
+     "pruned_heads": {},
+     "remove_invalid_values": false,
+     "repetition_penalty": 1.0,
+     "return_dict": true,
+     "return_dict_in_generate": false,
+     "rms_norm_eps": 1e-06,
+     "sep_token_id": null,
+     "suppress_tokens": null,
+     "task_specific_params": null,
+     "temperature": 1.0,
+     "tf_legacy_loss": false,
+     "tie_encoder_decoder": false,
+     "tie_word_embeddings": false,
+     "tokenizer_class": null,
+     "top_k": 50,
+     "top_p": 1.0,
+     "torch_dtype": "float32",
+     "torchscript": false,
+     "transformers_version": "4.29.0.dev0",
+     "typical_p": 1.0,
+     "use_bfloat16": false,
+     "use_cache": true,
+     "vocab_size": 39424
+   },
+   "tie_word_embeddings": false,
+   "torch_dtype": "float32",
+   "transformers_version": null,
+   "use_decoder_only_language_model": true,
+   "vision_config": {
+     "_name_or_path": "",
+     "add_cross_attention": false,
+     "architectures": null,
+     "attention_dropout": 0.0,
+     "bad_words_ids": null,
+     "begin_suppress_tokens": null,
+     "bos_token_id": null,
+     "chunk_size_feed_forward": 0,
+     "cross_attention_hidden_size": null,
+     "decoder_start_token_id": null,
+     "diversity_penalty": 0.0,
+     "do_sample": false,
+     "dropout": 0.0,
+     "early_stopping": false,
+     "encoder_no_repeat_ngram_size": 0,
+     "eos_token_id": null,
+     "exponential_decay_length_penalty": null,
+     "finetuning_task": null,
+     "forced_bos_token_id": null,
+     "forced_eos_token_id": null,
+     "hidden_act": "gelu",
+     "hidden_size": 1408,
+     "id2label": {
+       "0": "LABEL_0",
+       "1": "LABEL_1"
+     },
+     "image_size": 224,
+     "initializer_factor": 1.0,
+     "initializer_range": 1e-10,
+     "intermediate_size": 6144,
+     "is_decoder": false,
+     "is_encoder_decoder": false,
+     "label2id": {
+       "LABEL_0": 0,
+       "LABEL_1": 1
+     },
+     "layer_norm_eps": 1e-05,
+     "length_penalty": 1.0,
+     "max_length": 20,
+     "min_length": 0,
+     "model_type": "blip_2_vision_model",
+     "no_repeat_ngram_size": 0,
+     "num_attention_heads": 16,
+     "num_beam_groups": 1,
+     "num_beams": 1,
+     "num_channels": 3,
+     "num_hidden_layers": 39,
+     "num_return_sequences": 1,
+     "output_attentions": false,
+     "output_hidden_states": false,
+     "output_scores": false,
+     "pad_token_id": null,
+     "patch_size": 14,
+     "prefix": null,
+     "problem_type": null,
+     "projection_dim": 512,
+     "pruned_heads": {},
+     "qkv_bias": true,
+     "remove_invalid_values": false,
+     "repetition_penalty": 1.0,
+     "return_dict": true,
+     "return_dict_in_generate": false,
+     "sep_token_id": null,
+     "suppress_tokens": null,
+     "task_specific_params": null,
+     "temperature": 1.0,
+     "tf_legacy_loss": false,
+     "tie_encoder_decoder": false,
+     "tie_word_embeddings": true,
+     "tokenizer_class": null,
+     "top_k": 50,
+     "top_p": 1.0,
+     "torch_dtype": null,
+     "torchscript": false,
+     "transformers_version": "4.29.0.dev0",
+     "typical_p": 1.0,
+     "use_bfloat16": false
+   }
+ }
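
Note: besides the standard BLIP-2 fields, this config adds three chat-specific top-level keys ("human_name", "assistant_name", "prompt_prefix") that prepare_inputs_for_chat in modeling_ziya_blip2.py reads from self.config when building the prompt. A minimal sketch of inspecting them, assuming the current directory is this repository root as in demo.py (PretrainedConfig keeps unknown keys as plain attributes):

from transformers import Blip2Config

cfg = Blip2Config.from_pretrained(".")
print(cfg.model_type, cfg.num_query_tokens)  # blip-2 32
print(cfg.human_name, cfg.assistant_name)    # <human> <bot>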
demo.py ADDED
@@ -0,0 +1,38 @@
+ from transformers import LlamaForCausalLM, LlamaTokenizer, BlipImageProcessor
+ from modeling_ziya_blip2 import ZiyaBLIP2ForConditionalGeneration
+ from PIL import Image
+
+ # local path of the language model IDEA-CCNL/Ziya-LLaMA-13B-v1
+ LM_MODEL_PATH = "local path of model IDEA-CCNL/Ziya-LLaMA-13B-v1"
+ # e.g. LM_MODEL_PATH = "/cognitive_comp/gaoxinyu/huggingface_model/Ziya-LLaMA-13B-v1"
+ lm_model = LlamaForCausalLM.from_pretrained(LM_MODEL_PATH)
+ tokenizer = LlamaTokenizer.from_pretrained(LM_MODEL_PATH)
+
+
+ # visual model
+ OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
+ OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
+ # demo.py sits in this repository, so the local path "." works; otherwise use "IDEA-CCNL/Ziya-BLIP2-14B-Visual-v1"
+ model = ZiyaBLIP2ForConditionalGeneration.from_pretrained(".", language_model=lm_model)
+ image_size = model.config.vision_config.image_size
+ image_processor = BlipImageProcessor(
+     size={"height": image_size, "width": image_size},
+     image_mean=OPENAI_CLIP_MEAN,
+     image_std=OPENAI_CLIP_STD,
+ )
+ model.cuda()  # comment this line out to run on CPU
+ generate_config = {
+     "max_new_tokens": 128,
+     "top_p": 0.1,
+     "temperature": 0.7
+ }
+ output = model.chat(
+     tokenizer=tokenizer,
+     pixel_values=image_processor(Image.open("wzry.jpg"), return_tensors="pt").pixel_values.to(model.device),
+     query="这是什么游戏",  # "What game is this?"
+     previous_querys=[],
+     previous_outputs=[],
+     **generate_config,
+ )
+ print(output)
+ # 这是一款名为《王者荣耀》的多人在线竞技游戏。在游戏中,玩家扮演不同的角色,并与其他玩家进行战斗。游戏中的人物和环境都是虚拟的,但它们看起来非常逼真。玩家需要使用各种技能和策略来击败对手,并获得胜利。
+ # (English: "This is a multiplayer online battle game called Honor of Kings. In the game, players take on different roles and fight other players. The characters and environments in the game are virtual, but they look very realistic. Players need to use various skills and strategies to defeat their opponents and win.")
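
Note: demo.py makes a single-turn call with an empty history. For a follow-up question, the history is passed back through previous_querys and previous_outputs, which prepare_inputs_for_chat interleaves as "<human>: ..." / "<bot>: ..." turns ahead of the new query. A minimal sketch of a second turn, assuming the objects from demo.py above are still in scope (the follow-up question itself is only illustrative):

pixel_values = image_processor(Image.open("wzry.jpg"), return_tensors="pt").pixel_values.to(model.device)
follow_up = model.chat(
    tokenizer=tokenizer,
    pixel_values=pixel_values,
    query="游戏里有哪些角色?",  # "Which characters are in the game?" (illustrative)
    previous_querys=["这是什么游戏"],
    previous_outputs=[output],  # the answer returned by the first call
    **generate_config,
)
print(follow_up)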
modeling_ziya_blip2.py ADDED
@@ -0,0 +1,287 @@
+ from typing import Optional, Tuple, Union, List
+
+ import torch
+ import torch.utils.checkpoint
+ from torch import nn
+ from transformers.utils import (
+     logging,
+ )
+ from transformers.models.blip_2.configuration_blip_2 import Blip2Config
+ from transformers.models.blip_2.modeling_blip_2 import Blip2ForConditionalGenerationModelOutput
+ from transformers import (
+     Blip2PreTrainedModel,
+     Blip2VisionModel,
+     AutoModelForCausalLM,
+     Blip2QFormerModel,
+     PreTrainedTokenizer,
+     PreTrainedModel,
+ )
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class ZiyaBLIP2ForConditionalGeneration(Blip2PreTrainedModel):
+     config_class = Blip2Config
+     main_input_name = "pixel_values"
+     _keys_to_ignore_on_load_missing = [
+         r"language_model",
+     ]
+     def __init__(self, config: Blip2Config, language_model: PreTrainedModel = None):
+         super().__init__(config)
+
+         self.vision_model = Blip2VisionModel(config.vision_config)
+
+         self.query_tokens = nn.Parameter(torch.zeros(
+             1, config.num_query_tokens, config.qformer_config.hidden_size))
+         self.qformer = Blip2QFormerModel(config.qformer_config)
+
+         self.language_projection = nn.Linear(
+             config.qformer_config.hidden_size, config.text_config.hidden_size)
+         if language_model is None:
+             if config.use_decoder_only_language_model:
+                 language_model = AutoModelForCausalLM.from_config(config.text_config)
+             else:
+                 raise NotImplementedError("encoder-decoder language models are not supported")
+         self.language_model = language_model
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.language_model.get_input_embeddings()
+
+     def set_input_embeddings(self, value):
+         self.language_model.set_input_embeddings(value)
+
+     def set_output_embeddings(self, new_embeddings):
+         self.language_model.set_output_embeddings(new_embeddings)
+
+     def get_output_embeddings(self) -> nn.Module:
+         return self.language_model.get_output_embeddings()
+
+     def get_encoder(self):
+         return self.language_model.get_encoder()
+
+     def get_decoder(self):
+         return self.language_model.get_decoder()
+
+     def _tie_weights(self):
+         if not self.config.use_decoder_only_language_model:
+             self.language_model.encoder.embed_tokens = self.language_model.shared
+             self.language_model.decoder.embed_tokens = self.language_model.shared
+
+     def _preprocess_accelerate(self):
+         r"""
+         Some pre-processing hacks to make the model `accelerate` compatible. Check
+         https://github.com/huggingface/transformers/pull/21707 for more details.
+         """
+         hf_device_map = self.hf_device_map
+
+         if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
+             # warn users about unexpected behavior when using multi-GPU + BLIP-2 + `accelerate`.
+             logger.warning(
+                 "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
+                 " in a multi-GPU environment. This may lead to unexpected behavior when using `accelerate`."
+                 " Please pass a `device_map` that contains `language_model` to remove this warning."
+                 " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
+                 " more details on creating a `device_map` for large models."
+             )
+
+         if hasattr(self.language_model, "_hf_hook"):
+             self.language_model._hf_hook.io_same_device = True  # For `generate` compatibility
+
+     def forward(
+         self,
+         pixel_values: torch.FloatTensor,
+         input_ids_before_image: torch.LongTensor,
+         input_ids_after_image: torch.LongTensor,
+         labels_after_image: torch.LongTensor,
+         # labels never appear before the image, so no labels_before_image is needed;
+         # the positions covered by input_ids_before_image are simply padded with -100 below
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]:
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # step 1: forward the images through the vision encoder,
+         # to get image embeddings of shape (batch_size, seq_len, hidden_size)
+         vision_outputs = self.vision_model(
+             pixel_values=pixel_values,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         image_embeds = vision_outputs[0]
+
+         # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
+         image_attention_mask = torch.ones(
+             image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
+
+         query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+         query_outputs = self.qformer(
+             query_embeds=query_tokens,
+             encoder_hidden_states=image_embeds,
+             encoder_attention_mask=image_attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         query_output = query_outputs[0]
+
+         # step 2.5: build the language-model inputs from the prompt and the projected query output
+         language_model_inputs = self.language_projection(query_output)
+         language_model_attention_mask = torch.ones(
+             language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
+         )
+         # make sure language_model_inputs and the text have the same batch size
+         assert language_model_inputs.shape[0] == input_ids_after_image.shape[0]
+         inputs_embeds_before_image = self.language_model.get_input_embeddings()(input_ids_before_image)
+         inputs_embeds_after_image = self.language_model.get_input_embeddings()(input_ids_after_image)
+         inputs_embeds = torch.cat(
+             [
+                 inputs_embeds_before_image.to(language_model_inputs.device),
+                 language_model_inputs,
+                 inputs_embeds_after_image.to(language_model_inputs.device)
+             ], dim=1)
+
+         attention_mask_before = torch.ones_like(input_ids_before_image)
+         attention_mask_after = torch.ones_like(input_ids_after_image)
+         attention_mask = torch.cat(
+             [
+                 attention_mask_before.to(language_model_attention_mask.device),
+                 language_model_attention_mask,
+                 attention_mask_after.to(language_model_attention_mask.device)
+             ], dim=1
+         )
+         # the labels need the same treatment: pad the leading (pre-image and query-token) positions with -100
+         labels = torch.cat(
+             [
+                 torch.tensor(
+                     [-100]).expand_as(input_ids_before_image).to(language_model_inputs.device),
+                 torch.tensor([-100]).expand(query_tokens.shape[:-1]
+                                             ).to(language_model_inputs.device),
+                 labels_after_image,
+             ], dim=1
+         )
+
+         # step 3: use the language model
+
+         if self.config.use_decoder_only_language_model:
+             outputs = self.language_model(
+                 inputs_embeds=inputs_embeds,
+                 attention_mask=attention_mask,
+                 output_attentions=output_attentions,
+                 output_hidden_states=output_hidden_states,
+                 return_dict=return_dict,
+                 labels=labels,
+             )
+             loss = outputs.loss if return_dict else outputs[0]
+             logits = outputs.logits if return_dict else outputs[1]
+
+         else:
+             raise NotImplementedError("encoder-decoder language models are not supported")
+
+         if not return_dict:
+             output = (logits, vision_outputs, query_outputs, outputs)
+             return ((loss,) + output) if loss is not None else output
+
+         return Blip2ForConditionalGenerationModelOutput(
+             loss=loss,
+             logits=logits,
+             vision_outputs=vision_outputs,
+             qformer_outputs=query_outputs,
+             language_model_outputs=outputs,
+         )
+
+     def prepare_inputs_for_chat(
+         self,
+         tokenizer: PreTrainedTokenizer,
+         query: str,
+         pixel_values: torch.Tensor,
+         previous_querys: List[str],
+         previous_outputs: List[str],
+         max_length: int,
+     ):
+         # 1. process input_ids
+         assert len(previous_querys) == len(previous_outputs)
+         device = self.device
+         prefix = self.config.prompt_prefix
+         human_name = self.config.human_name
+         assistant_name = self.config.assistant_name
+         input_ids_before_image = tokenizer(
+             prefix, return_tensors="pt").input_ids.to(device)
+         inputs_ids_after_image = []
+         for (p, o) in zip(previous_querys, previous_outputs):
+             # each history turn is rendered as "{human_name}: {query}\n{assistant_name}: {output}\n"
+             inputs_ids_after_image += tokenizer(f"{human_name}: {p}\n", add_special_tokens=False).input_ids + \
+                 tokenizer(f"{assistant_name}: {o}\n", add_special_tokens=False).input_ids
+
+         inputs_ids_after_image += tokenizer(f"{human_name}: {query}\n",
+                                             add_special_tokens=False).input_ids + tokenizer(f"{assistant_name} :",
+                                                                                             add_special_tokens=False).input_ids
+         inputs_ids_after_image = torch.IntTensor([inputs_ids_after_image]).to(device)
+         # 2. Prepare embeddings
+         pixel_values = pixel_values.to(device)
+         image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state
+         image_attention_mask = torch.ones(
+             image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
+         query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+         query_outputs = self.qformer(
+             query_embeds=query_tokens,
+             encoder_hidden_states=image_embeds,
+             encoder_attention_mask=image_attention_mask,
+             return_dict=True,
+         )
+         query_output = query_outputs.last_hidden_state
+         language_model_inputs = self.language_projection(query_output)
+
+         # concatenate query embeddings with prompt embeddings
+         prefix_inputs_embeds = self.get_input_embeddings()(input_ids_before_image)
+         prompt_inputs_embeds = self.get_input_embeddings()(inputs_ids_after_image)
+         inputs_embeds = torch.cat([
+             prefix_inputs_embeds.to(language_model_inputs.device),
+             language_model_inputs,
+             prompt_inputs_embeds.to(language_model_inputs.device)], dim=1)
+
+         if inputs_embeds.shape[1] > max_length:
+             inputs_embeds = inputs_embeds[:, -max_length:, :]
+
+         input_ids = torch.concat([
+             input_ids_before_image,
+             torch.tensor([tokenizer.eos_token_id]).expand(
+                 query_tokens.shape[:-1]).to(language_model_inputs.device),
+             inputs_ids_after_image,
+         ], dim=1)
+
+         return input_ids, inputs_embeds
+
+     def chat(self,
+              tokenizer,
+              query: str,
+              pixel_values: torch.Tensor,
+              previous_querys: List[str],
+              previous_outputs: List[str],
+              **generate_kwargs,):
+         """
+         Generate text in a chat style.
+         Args:
+             tokenizer (PreTrainedTokenizer): the LLaMA tokenizer
+             query (str): the current input query
+             pixel_values (torch.Tensor): the image, already processed by the image processor
+             previous_querys (List[str]): the user side of the chat history
+             previous_outputs (List[str]): the assistant side of the chat history
+
+         Returns:
+             text: the generated text
+         """
+         input_ids, inputs_embeds = self.prepare_inputs_for_chat(
+             tokenizer, query, pixel_values, previous_querys, previous_outputs, 2048
+         )
+         response = self.language_model.generate(
+             inputs_embeds=inputs_embeds,
+             attention_mask=torch.ones_like(input_ids),
+             **generate_kwargs,
+         )
+         response = tokenizer.decode(response[0], skip_special_tokens=True)
+         return response
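
Note: for training-style use, forward() expects the text already split at the image position: input_ids_before_image carries the prompt prefix, the 32 projected query tokens are spliced in after it, and input_ids_after_image / labels_after_image carry the rest of the conversation; the prefix and query-token positions are padded with -100 internally. A rough sketch of one such call, assuming model, tokenizer and image_processor are set up as in demo.py (the example text is illustrative, and a real fine-tuning setup would also mask the question tokens in labels_after_image with -100):

from PIL import Image

prefix_ids = tokenizer(model.config.prompt_prefix, return_tensors="pt").input_ids
turn = f"{model.config.human_name}: 这是什么游戏\n{model.config.assistant_name}: 这是《王者荣耀》。"
after_ids = tokenizer(turn, add_special_tokens=False, return_tensors="pt").input_ids
labels_after = after_ids.clone()  # naive: supervises everything after the image
pixel_values = image_processor(Image.open("wzry.jpg"), return_tensors="pt").pixel_values

out = model(
    pixel_values=pixel_values.to(model.device),
    input_ids_before_image=prefix_ids.to(model.device),
    input_ids_after_image=after_ids.to(model.device),
    labels_after_image=labels_after.to(model.device),
)
print(out.loss)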
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6cc4682fb5bf8adee4967316f1421b1782d2389c3ac671c448313c925d1eddc4
+ size 4380450257
wzry.jpg ADDED